gstaichi 0.1.18.dev1__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +113 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +122 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +122 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +122 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +122 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +122 -0
- gstaichi-0.1.18.dev1.data/data/bin/SPIRV-Tools-shared.dll +0 -0
- gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-diff.lib +0 -0
- gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-link.lib +0 -0
- gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-lint.lib +0 -0
- gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-opt.lib +0 -0
- gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-reduce.lib +0 -0
- gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-shared.lib +0 -0
- gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools.lib +0 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-0.1.18.dev1.data/data/lib/glfw3.lib +0 -0
- gstaichi-0.1.18.dev1.dist-info/METADATA +108 -0
- gstaichi-0.1.18.dev1.dist-info/RECORD +198 -0
- gstaichi-0.1.18.dev1.dist-info/WHEEL +5 -0
- gstaichi-0.1.18.dev1.dist-info/entry_points.txt +2 -0
- gstaichi-0.1.18.dev1.dist-info/licenses/LICENSE +201 -0
- gstaichi-0.1.18.dev1.dist-info/top_level.txt +1 -0
- taichi/CHANGELOG.md +15 -0
- taichi/__init__.py +44 -0
- taichi/__main__.py +5 -0
- taichi/_funcs.py +706 -0
- taichi/_kernels.py +420 -0
- taichi/_lib/__init__.py +3 -0
- taichi/_lib/c_api/bin/taichi_c_api.dll +0 -0
- taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +1401 -0
- taichi/_lib/c_api/include/taichi/taichi.h +29 -0
- taichi/_lib/c_api/include/taichi/taichi_core.h +1111 -0
- taichi/_lib/c_api/include/taichi/taichi_cpu.h +29 -0
- taichi/_lib/c_api/include/taichi/taichi_cuda.h +36 -0
- taichi/_lib/c_api/include/taichi/taichi_platform.h +55 -0
- taichi/_lib/c_api/include/taichi/taichi_unity.h +64 -0
- taichi/_lib/c_api/include/taichi/taichi_vulkan.h +151 -0
- taichi/_lib/c_api/lib/taichi_c_api.lib +0 -0
- taichi/_lib/c_api/runtime/runtime_cuda.bc +0 -0
- taichi/_lib/c_api/runtime/runtime_x64.bc +0 -0
- taichi/_lib/c_api/runtime/slim_libdevice.10.bc +0 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +29 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +65 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +121 -0
- taichi/_lib/core/__init__.py +0 -0
- taichi/_lib/core/py.typed +0 -0
- taichi/_lib/core/taichi_python.cp310-win_amd64.pyd +0 -0
- taichi/_lib/core/taichi_python.pyi +3077 -0
- taichi/_lib/runtime/runtime_cuda.bc +0 -0
- taichi/_lib/runtime/runtime_x64.bc +0 -0
- taichi/_lib/runtime/slim_libdevice.10.bc +0 -0
- taichi/_lib/utils.py +249 -0
- taichi/_logging.py +131 -0
- taichi/_main.py +552 -0
- taichi/_snode/__init__.py +5 -0
- taichi/_snode/fields_builder.py +189 -0
- taichi/_snode/snode_tree.py +34 -0
- taichi/_ti_module/__init__.py +3 -0
- taichi/_ti_module/cppgen.py +309 -0
- taichi/_ti_module/module.py +145 -0
- taichi/_version.py +1 -0
- taichi/_version_check.py +100 -0
- taichi/ad/__init__.py +3 -0
- taichi/ad/_ad.py +530 -0
- taichi/algorithms/__init__.py +3 -0
- taichi/algorithms/_algorithms.py +117 -0
- taichi/aot/__init__.py +12 -0
- taichi/aot/_export.py +28 -0
- taichi/aot/conventions/__init__.py +3 -0
- taichi/aot/conventions/gfxruntime140/__init__.py +38 -0
- taichi/aot/conventions/gfxruntime140/dr.py +244 -0
- taichi/aot/conventions/gfxruntime140/sr.py +613 -0
- taichi/aot/module.py +253 -0
- taichi/aot/utils.py +151 -0
- taichi/assets/.git +1 -0
- taichi/assets/Go-Regular.ttf +0 -0
- taichi/assets/static/imgs/ti_gallery.png +0 -0
- taichi/examples/minimal.py +28 -0
- taichi/experimental.py +16 -0
- taichi/graph/__init__.py +3 -0
- taichi/graph/_graph.py +292 -0
- taichi/lang/__init__.py +50 -0
- taichi/lang/_ndarray.py +348 -0
- taichi/lang/_ndrange.py +152 -0
- taichi/lang/_texture.py +172 -0
- taichi/lang/_wrap_inspect.py +189 -0
- taichi/lang/any_array.py +99 -0
- taichi/lang/argpack.py +411 -0
- taichi/lang/ast/__init__.py +5 -0
- taichi/lang/ast/ast_transformer.py +1806 -0
- taichi/lang/ast/ast_transformer_utils.py +328 -0
- taichi/lang/ast/checkers.py +106 -0
- taichi/lang/ast/symbol_resolver.py +57 -0
- taichi/lang/ast/transform.py +9 -0
- taichi/lang/common_ops.py +310 -0
- taichi/lang/exception.py +80 -0
- taichi/lang/expr.py +180 -0
- taichi/lang/field.py +464 -0
- taichi/lang/impl.py +1246 -0
- taichi/lang/kernel_arguments.py +157 -0
- taichi/lang/kernel_impl.py +1415 -0
- taichi/lang/matrix.py +1877 -0
- taichi/lang/matrix_ops.py +341 -0
- taichi/lang/matrix_ops_utils.py +190 -0
- taichi/lang/mesh.py +687 -0
- taichi/lang/misc.py +807 -0
- taichi/lang/ops.py +1489 -0
- taichi/lang/runtime_ops.py +13 -0
- taichi/lang/shell.py +35 -0
- taichi/lang/simt/__init__.py +5 -0
- taichi/lang/simt/block.py +94 -0
- taichi/lang/simt/grid.py +7 -0
- taichi/lang/simt/subgroup.py +191 -0
- taichi/lang/simt/warp.py +96 -0
- taichi/lang/snode.py +487 -0
- taichi/lang/source_builder.py +150 -0
- taichi/lang/struct.py +855 -0
- taichi/lang/util.py +381 -0
- taichi/linalg/__init__.py +8 -0
- taichi/linalg/matrixfree_cg.py +310 -0
- taichi/linalg/sparse_cg.py +59 -0
- taichi/linalg/sparse_matrix.py +303 -0
- taichi/linalg/sparse_solver.py +123 -0
- taichi/math/__init__.py +11 -0
- taichi/math/_complex.py +204 -0
- taichi/math/mathimpl.py +886 -0
- taichi/profiler/__init__.py +6 -0
- taichi/profiler/kernel_metrics.py +260 -0
- taichi/profiler/kernel_profiler.py +592 -0
- taichi/profiler/memory_profiler.py +15 -0
- taichi/profiler/scoped_profiler.py +36 -0
- taichi/shaders/Circles_vk.frag +29 -0
- taichi/shaders/Circles_vk.vert +45 -0
- taichi/shaders/Circles_vk_frag.spv +0 -0
- taichi/shaders/Circles_vk_vert.spv +0 -0
- taichi/shaders/Lines_vk.frag +9 -0
- taichi/shaders/Lines_vk.vert +11 -0
- taichi/shaders/Lines_vk_frag.spv +0 -0
- taichi/shaders/Lines_vk_vert.spv +0 -0
- taichi/shaders/Mesh_vk.frag +71 -0
- taichi/shaders/Mesh_vk.vert +68 -0
- taichi/shaders/Mesh_vk_frag.spv +0 -0
- taichi/shaders/Mesh_vk_vert.spv +0 -0
- taichi/shaders/Particles_vk.frag +95 -0
- taichi/shaders/Particles_vk.vert +73 -0
- taichi/shaders/Particles_vk_frag.spv +0 -0
- taichi/shaders/Particles_vk_vert.spv +0 -0
- taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
- taichi/shaders/SceneLines_vk.frag +9 -0
- taichi/shaders/SceneLines_vk.vert +12 -0
- taichi/shaders/SceneLines_vk_frag.spv +0 -0
- taichi/shaders/SceneLines_vk_vert.spv +0 -0
- taichi/shaders/SetImage_vk.frag +21 -0
- taichi/shaders/SetImage_vk.vert +15 -0
- taichi/shaders/SetImage_vk_frag.spv +0 -0
- taichi/shaders/SetImage_vk_vert.spv +0 -0
- taichi/shaders/Triangles_vk.frag +16 -0
- taichi/shaders/Triangles_vk.vert +29 -0
- taichi/shaders/Triangles_vk_frag.spv +0 -0
- taichi/shaders/Triangles_vk_vert.spv +0 -0
- taichi/shaders/lines2quad_vk_comp.spv +0 -0
- taichi/sparse/__init__.py +3 -0
- taichi/sparse/_sparse_grid.py +77 -0
- taichi/tools/__init__.py +12 -0
- taichi/tools/diagnose.py +124 -0
- taichi/tools/np2ply.py +364 -0
- taichi/tools/vtk.py +38 -0
- taichi/types/__init__.py +19 -0
- taichi/types/annotations.py +47 -0
- taichi/types/compound_types.py +90 -0
- taichi/types/enums.py +49 -0
- taichi/types/ndarray_type.py +147 -0
- taichi/types/primitive_types.py +203 -0
- taichi/types/quant.py +88 -0
- taichi/types/texture_type.py +85 -0
- taichi/types/utils.py +13 -0
@@ -0,0 +1,1806 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import ast
|
4
|
+
import collections.abc
|
5
|
+
import dataclasses
|
6
|
+
import inspect
|
7
|
+
import itertools
|
8
|
+
import math
|
9
|
+
import operator
|
10
|
+
import re
|
11
|
+
import warnings
|
12
|
+
from ast import unparse
|
13
|
+
from collections import ChainMap
|
14
|
+
from typing import Any, Iterable, Type
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
|
18
|
+
from taichi._lib import core as _ti_core
|
19
|
+
from taichi.lang import _ndarray, any_array, expr, impl, kernel_arguments, matrix, mesh
|
20
|
+
from taichi.lang import ops as ti_ops
|
21
|
+
from taichi.lang._ndrange import _Ndrange, ndrange
|
22
|
+
from taichi.lang.argpack import ArgPackType
|
23
|
+
from taichi.lang.ast.ast_transformer_utils import (
|
24
|
+
ASTTransformerContext,
|
25
|
+
Builder,
|
26
|
+
LoopStatus,
|
27
|
+
ReturnStatus,
|
28
|
+
)
|
29
|
+
from taichi.lang.ast.symbol_resolver import ASTResolver
|
30
|
+
from taichi.lang.exception import (
|
31
|
+
TaichiIndexError,
|
32
|
+
TaichiRuntimeTypeError,
|
33
|
+
TaichiSyntaxError,
|
34
|
+
TaichiTypeError,
|
35
|
+
handle_exception_from_cpp,
|
36
|
+
)
|
37
|
+
from taichi.lang.expr import Expr, make_expr_group
|
38
|
+
from taichi.lang.field import Field
|
39
|
+
from taichi.lang.matrix import Matrix, MatrixType, Vector
|
40
|
+
from taichi.lang.snode import append, deactivate, length
|
41
|
+
from taichi.lang.struct import Struct, StructType
|
42
|
+
from taichi.lang.util import is_taichi_class, to_taichi_type
|
43
|
+
from taichi.types import annotations, ndarray_type, primitive_types, texture_type
|
44
|
+
from taichi.types.utils import is_integral
|
45
|
+
|
46
|
+
|
47
|
+
def reshape_list(flat_list: list[Any], target_shape: collections.abc.Sequence[int]) -> list[Any]:
    """Recursively reshape a flat list into nested lists matching *target_shape*.

    Args:
        flat_list: The flat list of elements to regroup.
        target_shape: The desired shape, outermost dimension first. Must be a
            sequence (not a bare iterable): it is measured with ``len()``,
            indexed with ``[-1]`` and sliced with ``[:-1]`` below, so the
            previous ``Iterable[int]`` annotation was incorrect.

    Returns:
        A nested list with the requested shape; for shapes of rank < 2 the
        input list is returned unchanged.
    """
    # Rank 0/1: nothing to group.
    if len(target_shape) < 2:
        return flat_list

    # Group consecutive runs of the innermost dimension, then recurse on the
    # remaining (outer) dimensions.
    innermost = target_shape[-1]
    grouped: list[Any] = []
    for i, elem in enumerate(flat_list):
        if i % innermost == 0:
            grouped.append([])
        grouped[-1].append(elem)

    return reshape_list(grouped, target_shape[:-1])
|
59
|
+
|
60
|
+
|
61
|
+
def boundary_type_cast_warning(expression: Expr) -> None:
    """Warn when a range_for boundary expression will be narrowed to i32.

    Anything that is not an integral type, or an integral type that cannot be
    represented losslessly in i32 (i64, u64, u32), may lose information.
    """
    dtype = expression.ptr.get_rvalue_type()
    lossy = not is_integral(dtype) or dtype in (
        primitive_types.i64,
        primitive_types.u64,
        primitive_types.u32,
    )
    if lossy:
        warnings.warn(
            f"Casting range_for boundary values from {dtype} to i32, which may cause numerical issues",
            Warning,
        )
|
72
|
+
|
73
|
+
|
74
|
+
class ASTTransformer(Builder):
|
75
|
+
@staticmethod
def build_Name(ctx: ASTTransformerContext, node: ast.Name):
    """Resolve a name node to its bound variable and attach debug info."""
    resolved = ctx.get_var_by_name(node.id)
    node.ptr = resolved
    # Stamp source-position debug info onto resolved Taichi expressions so
    # later error messages can point back to this name's location.
    if isinstance(node, (ast.stmt, ast.expr)) and isinstance(resolved, Expr):
        dbg = _ti_core.DebugInfo(ctx.get_pos_info(node))
        resolved.dbg_info = dbg
        resolved.ptr.set_dbg_info(dbg)
    return resolved
|
82
|
+
|
83
|
+
@staticmethod
def build_AnnAssign(ctx: ASTTransformerContext, node: ast.AnnAssign):
    """Build an annotated assignment statement: ``target: annotation = value``."""
    build_stmt(ctx, node.value)
    build_stmt(ctx, node.annotation)

    # ti.static(...) on the RHS marks a compile-time (static) assignment.
    static_assign = isinstance(node.value, ast.Call) and node.value.func.ptr is impl.static

    node.ptr = ASTTransformer.build_assign_annotated(
        ctx,
        node.target,
        node.value.ptr,
        static_assign,
        node.annotation.ptr,
    )
    return node.ptr
|
94
|
+
|
95
|
+
@staticmethod
def build_assign_annotated(
    ctx: ASTTransformerContext, target: ast.Name, value, is_static_assign: bool, annotation: Type
):
    """Build an annotated assignment like this: target: annotation = value.

    Args:
        ctx (ast_builder_utils.BuilderContext): The builder context.
        target (ast.Name): A variable name. `target.id` holds the name as
            a string.
        annotation: A type we hope to assign to the target
        value: A node representing the value.
        is_static_assign: A boolean value indicating whether this is a static assignment
    """
    is_local = isinstance(target, ast.Name)
    # Kernel arguments may never be re-bound inside the kernel body.
    if is_local and target.id in ctx.kernel_args:
        raise TaichiSyntaxError(
            f'Kernel argument "{target.id}" is immutable in the kernel. '
            f"If you want to change its value, please create a new variable."
        )
    anno = impl.expr_init(annotation)
    if is_static_assign:
        raise TaichiSyntaxError("Static assign cannot be used on annotated assignment")
    if is_local and not ctx.is_var_declared(target.id):
        # First declaration: cast to the annotated type, then materialize.
        declared = impl.expr_init(ti_ops.cast(value, anno))
        ctx.create_variable(target.id, declared)
        return declared
    # Re-assignment: the annotated type must match the variable's type.
    existing = build_stmt(ctx, target)
    if existing.ptr.get_rvalue_type() != anno:
        raise TaichiSyntaxError("Static assign cannot have type overloading")
    existing._assign(value)
    return existing
|
128
|
+
|
129
|
+
@staticmethod
def build_Assign(ctx: ASTTransformerContext, node: ast.Assign) -> None:
    """Build a (possibly chained) assignment statement like ``a = b = expr``."""
    build_stmt(ctx, node.value)
    static_assign = isinstance(node.value, ast.Call) and node.value.func.ptr is impl.static

    # Materialize the RHS exactly once so every target of a chained
    # assignment shares the same value.
    # Ref https://github.com/taichi-dev/taichi/issues/2659.
    rhs = node.value.ptr if static_assign else impl.expr_init(node.value.ptr)

    for tgt in node.targets:
        ASTTransformer.build_assign_unpack(ctx, tgt, rhs, static_assign)
    return None
|
142
|
+
|
143
|
+
@staticmethod
def build_assign_unpack(ctx: ASTTransformerContext, node_target: list | ast.Tuple, values, is_static_assign: bool):
    """Build the unpack assignments like this: (target1, target2) = (value1, value2).
    The function should be called only if the node target is a tuple.

    Args:
        ctx (ast_builder_utils.BuilderContext): The builder context.
        node_target (ast.Tuple): A list or tuple object. `node_target.elts` holds a
            list of nodes representing the elements.
        values: A node/list representing the values.
        is_static_assign: A boolean value indicating whether this is a static assignment
    """
    # Non-tuple targets degrade to a plain (single-target) assignment.
    if not isinstance(node_target, ast.Tuple):
        return ASTTransformer.build_assign_basic(ctx, node_target, values, is_static_assign)
    targets = node_target.elts

    # Single-column matrices unpack entry by entry.
    if isinstance(values, matrix.Matrix):
        if not values.m == 1:
            raise ValueError("Matrices with more than one columns cannot be unpacked")
        values = values.entries

    # Unpack: a, b, c = ti.Vector([1., 2., 3.])
    if isinstance(values, impl.Expr) and values.ptr.is_tensor():
        if len(values.get_shape()) > 1:
            raise ValueError("Matrices with more than one columns cannot be unpacked")
        values = ctx.ast_builder.expand_exprs([values.ptr])
        if len(values) == 1:
            values = values[0]

    # Struct values expand into their member expressions.
    if isinstance(values, impl.Expr) and values.ptr.is_struct():
        values = ctx.ast_builder.expand_exprs([values.ptr])
        if len(values) == 1:
            values = values[0]

    if not isinstance(values, collections.abc.Sequence):
        raise TaichiSyntaxError(f"Cannot unpack type: {type(values)}")

    if len(values) != len(targets):
        raise TaichiSyntaxError("The number of targets is not equal to value length")

    for tgt, val in zip(targets, values):
        ASTTransformer.build_assign_basic(ctx, tgt, val, is_static_assign)

    return None
|
188
|
+
|
189
|
+
@staticmethod
def build_assign_basic(ctx: ASTTransformerContext, target: ast.Name, value, is_static_assign: bool):
    """Build basic assignment like this: target = value.

    Args:
        ctx (ast_builder_utils.BuilderContext): The builder context.
        target (ast.Name): A variable name. `target.id` holds the name as
            a string.
        value: A node representing the value.
        is_static_assign: A boolean value indicating whether this is a static assignment
    """
    is_local = isinstance(target, ast.Name)
    # Kernel arguments may never be re-bound inside the kernel body.
    if is_local and target.id in ctx.kernel_args:
        raise TaichiSyntaxError(
            f'Kernel argument "{target.id}" is immutable in the kernel. '
            f"If you want to change its value, please create a new variable."
        )
    if is_static_assign:
        # Static assignment binds the Python value directly without
        # materializing a Taichi expression.
        if not is_local:
            raise TaichiSyntaxError("Static assign cannot be used on elements in arrays")
        ctx.create_variable(target.id, value)
        return value
    if is_local and not ctx.is_var_declared(target.id):
        # First use of the name: declare a fresh local variable.
        fresh = impl.expr_init(value)
        ctx.create_variable(target.id, fresh)
        return fresh
    # Existing variable (or subscript/attribute target): assign in place.
    existing = build_stmt(ctx, target)
    try:
        existing._assign(value)
    except AttributeError:
        raise TaichiSyntaxError(
            f"Variable '{unparse(target).strip()}' cannot be assigned. Maybe it is not a Taichi object?"
        )
    return existing
|
223
|
+
|
224
|
+
@staticmethod
def build_NamedExpr(ctx: ASTTransformerContext, node: ast.NamedExpr):
    """Build a walrus (``:=``) expression; behaves like a basic assignment."""
    build_stmt(ctx, node.value)
    static_assign = isinstance(node.value, ast.Call) and node.value.func.ptr is impl.static
    node.ptr = ASTTransformer.build_assign_basic(ctx, node.target, node.value.ptr, static_assign)
    return node.ptr
|
230
|
+
|
231
|
+
@staticmethod
|
232
|
+
def is_tuple(node):
|
233
|
+
if isinstance(node, ast.Tuple):
|
234
|
+
return True
|
235
|
+
if isinstance(node, ast.Index) and isinstance(node.value.ptr, tuple):
|
236
|
+
return True
|
237
|
+
if isinstance(node.ptr, tuple):
|
238
|
+
return True
|
239
|
+
return False
|
240
|
+
|
241
|
+
@staticmethod
def build_Subscript(ctx: ASTTransformerContext, node: ast.Subscript):
    """Build a subscript expression ``value[slice]``."""
    build_stmt(ctx, node.value)
    build_stmt(ctx, node.slice)
    # Normalize a single index into a one-element argument list so that
    # impl.subscript always receives unpacked indices.
    if not ASTTransformer.is_tuple(node.slice):
        node.slice.ptr = [node.slice.ptr]
    node.ptr = impl.subscript(ctx.ast_builder, node.value.ptr, *node.slice.ptr)
    return node.ptr
|
249
|
+
|
250
|
+
@staticmethod
def build_Slice(ctx: ASTTransformerContext, node: ast.Slice):
    """Build ``lower:upper:step`` into a Python ``slice`` object."""
    # Each bound is optional; only evaluate the ones that are present.
    for part in (node.lower, node.upper, node.step):
        if part is not None:
            build_stmt(ctx, part)

    node.ptr = slice(
        node.lower.ptr if node.lower else None,
        node.upper.ptr if node.upper else None,
        node.step.ptr if node.step else None,
    )
    return node.ptr
|
265
|
+
|
266
|
+
@staticmethod
def build_ExtSlice(ctx: ASTTransformerContext, node: ast.ExtSlice):
    """Build a legacy extended slice (e.g. ``a[i, j:k]``) into a tuple of dims."""
    build_stmts(ctx, node.dims)
    node.ptr = tuple(d.ptr for d in node.dims)
    return node.ptr
|
271
|
+
|
272
|
+
@staticmethod
def build_Tuple(ctx: ASTTransformerContext, node: ast.Tuple):
    """Build a tuple literal from its evaluated elements."""
    build_stmts(ctx, node.elts)
    node.ptr = tuple(e.ptr for e in node.elts)
    return node.ptr
|
277
|
+
|
278
|
+
@staticmethod
def build_List(ctx: ASTTransformerContext, node: ast.List):
    """Build a list literal from its evaluated elements."""
    build_stmts(ctx, node.elts)
    node.ptr = [e.ptr for e in node.elts]
    return node.ptr
|
283
|
+
|
284
|
+
@staticmethod
def build_Dict(ctx: ASTTransformerContext, node: ast.Dict):
    """Build a dict literal; a ``None`` key marks a ``**mapping`` expansion."""
    result = {}
    for key_node, value_node in zip(node.keys, node.values):
        if key_node is None:
            # {**other}: merge the expanded mapping into the literal.
            result.update(build_stmt(ctx, value_node))
        else:
            result[build_stmt(ctx, key_node)] = build_stmt(ctx, value_node)
    node.ptr = result
    return node.ptr
|
294
|
+
|
295
|
+
@staticmethod
def process_listcomp(ctx: ASTTransformerContext, node, result) -> None:
    """Evaluate the element of a list comprehension and append it to *result*."""
    result.append(build_stmt(ctx, node.elt))
|
298
|
+
|
299
|
+
@staticmethod
def process_dictcomp(ctx: ASTTransformerContext, node, result) -> None:
    """Evaluate one key/value pair of a dict comprehension into *result*.

    The key is built before the value, preserving statement emission order.
    """
    built_key = build_stmt(ctx, node.key)
    built_value = build_stmt(ctx, node.value)
    result[built_key] = built_value
|
304
|
+
|
305
|
+
@staticmethod
def process_generators(ctx: ASTTransformerContext, node: ast.GeneratorExp, now_comp, func, result):
    """Recursively expand comprehension generators, calling *func* innermost.

    Comprehensions are fully unrolled at compile time: each generator's
    iterable is evaluated in static scope and iterated in Python.
    """
    # All generators consumed: evaluate the comprehension body itself.
    if now_comp >= len(node.generators):
        return func(ctx, node, result)
    gen = node.generators[now_comp]
    # The iterable expression is evaluated at compile time (static scope).
    with ctx.static_scope_guard():
        _iter = build_stmt(ctx, gen.iter)

    # A tensor iterable is flattened into scalar exprs and regrouped back
    # into its original shape so iteration yields rows, not scalars.
    if isinstance(_iter, impl.Expr) and _iter.ptr.is_tensor():
        shape = _iter.ptr.get_shape()
        flattened = [Expr(x) for x in ctx.ast_builder.expand_exprs([_iter.ptr])]
        _iter = reshape_list(flattened, shape)

    for value in _iter:
        with ctx.variable_scope_guard():
            ASTTransformer.build_assign_unpack(ctx, gen.target, value, True)
            with ctx.static_scope_guard():
                build_stmts(ctx, gen.ifs)
            ASTTransformer.process_ifs(ctx, node, now_comp, 0, func, result)
    return None
|
324
|
+
|
325
|
+
@staticmethod
def process_ifs(ctx: ASTTransformerContext, node, now_comp, now_if, func, result):
    """Apply the ``if`` filters of generator *now_comp*, one at a time.

    Note:
        *node* is a comprehension node (ListComp/DictComp/GeneratorExp) —
        the previous ``node: ast.If`` annotation was wrong, since ``ast.If``
        has no ``generators`` attribute; the annotation has been removed.

    Args:
        ctx: The builder context.
        node: The comprehension node being unrolled.
        now_comp: Index of the current generator in ``node.generators``.
        now_if: Index of the current filter in that generator's ``ifs``.
        func: Innermost callback that evaluates the comprehension body.
        result: Accumulator passed through to *func*.
    """
    # All filters of this generator passed: descend into the next generator.
    if now_if >= len(node.generators[now_comp].ifs):
        return ASTTransformer.process_generators(ctx, node, now_comp + 1, func, result)
    # Filters were pre-built in static scope; their values are plain Python.
    cond = node.generators[now_comp].ifs[now_if].ptr
    if cond:
        ASTTransformer.process_ifs(ctx, node, now_comp, now_if + 1, func, result)

    return None
|
334
|
+
|
335
|
+
@staticmethod
def build_ListComp(ctx: ASTTransformerContext, node: ast.ListComp):
    """Build a list comprehension, fully unrolled at compile time."""
    collected = []
    ASTTransformer.process_generators(ctx, node, 0, ASTTransformer.process_listcomp, collected)
    node.ptr = collected
    return node.ptr
|
341
|
+
|
342
|
+
@staticmethod
def build_DictComp(ctx: ASTTransformerContext, node: ast.DictComp):
    """Build a dict comprehension, fully unrolled at compile time."""
    collected = {}
    ASTTransformer.process_generators(ctx, node, 0, ASTTransformer.process_dictcomp, collected)
    node.ptr = collected
    return node.ptr
|
348
|
+
|
349
|
+
@staticmethod
def build_Index(ctx: ASTTransformerContext, node: ast.Index):
    """Build a legacy ``ast.Index`` node by delegating to its wrapped value."""
    node.ptr = build_stmt(ctx, node.value)
    return node.ptr
|
353
|
+
|
354
|
+
@staticmethod
def build_Constant(ctx: ASTTransformerContext, node: ast.Constant):
    """A literal constant evaluates to its plain Python value."""
    node.ptr = node.value
    return node.ptr
|
358
|
+
|
359
|
+
@staticmethod
def build_Num(ctx: ASTTransformerContext, node: ast.Num):
    """Build a legacy numeric literal node (pre-3.8 AST)."""
    node.ptr = node.n
    return node.ptr
|
363
|
+
|
364
|
+
@staticmethod
def build_Str(ctx: ASTTransformerContext, node: ast.Str):
    """Build a legacy string literal node (pre-3.8 AST)."""
    node.ptr = node.s
    return node.ptr
|
368
|
+
|
369
|
+
@staticmethod
def build_Bytes(ctx: ASTTransformerContext, node: ast.Bytes):
    """Build a legacy bytes literal node (pre-3.8 AST)."""
    node.ptr = node.s
    return node.ptr
|
373
|
+
|
374
|
+
@staticmethod
def build_NameConstant(ctx: ASTTransformerContext, node: ast.NameConstant):
    """Build a legacy True/False/None literal node (pre-3.8 AST)."""
    node.ptr = node.value
    return node.ptr
|
378
|
+
|
379
|
+
@staticmethod
def build_keyword(ctx: ASTTransformerContext, node: ast.keyword):
    """Build a call keyword; ``f(**d)`` has ``node.arg is None`` and passes
    the mapping through unchanged, otherwise a single-entry dict is built."""
    build_stmt(ctx, node.value)
    node.ptr = node.value.ptr if node.arg is None else {node.arg: node.value.ptr}
    return node.ptr
|
387
|
+
|
388
|
+
@staticmethod
def build_Starred(ctx: ASTTransformerContext, node: ast.Starred):
    """Build a ``*expr`` node by evaluating the starred value."""
    node.ptr = build_stmt(ctx, node.value)
    return node.ptr
|
392
|
+
|
393
|
+
@staticmethod
def build_FormattedValue(ctx: ASTTransformerContext, node: ast.FormattedValue):
    """Build one ``{...}`` field of an f-string, keeping any format spec."""
    node.ptr = build_stmt(ctx, node.value)
    # No spec: pass the bare value through.
    if node.format_spec is None or len(node.format_spec.values) == 0:
        return node.ptr
    spec_values = node.format_spec.values
    assert len(spec_values) == 1
    format_str = spec_values[0].s
    assert format_str is not None
    # Tagged list, distinguished from a normal list by ti.format machinery.
    return ["__ti_fmt_value__", node.ptr, format_str]
|
404
|
+
|
405
|
+
@staticmethod
def build_JoinedStr(ctx: ASTTransformerContext, node: ast.JoinedStr):
    """Build an f-string into a ti.format() call with ``{}`` placeholders."""
    template = ""
    fmt_args = []
    for sub_node in node.values:
        if isinstance(sub_node, ast.FormattedValue):
            # Each interpolation becomes an empty field plus one argument.
            template += "{}"
            fmt_args.append(build_stmt(ctx, sub_node))
        elif isinstance(sub_node, ast.Constant):
            template += sub_node.value
        elif isinstance(sub_node, ast.Str):
            template += sub_node.s
        else:
            raise TaichiSyntaxError("Invalid value for fstring.")

    node.ptr = impl.ti_format(template, *fmt_args)
    return node.ptr
|
423
|
+
|
424
|
+
@staticmethod
def build_call_if_is_builtin(ctx: ASTTransformerContext, node, args, keywords):
    """Intercept calls to Python builtins that Taichi replaces with its own ops.

    Returns True (with ``node.ptr`` set) when the call was handled here,
    False when the caller should continue with normal call handling.
    """
    from taichi.lang import matrix_ops  # pylint: disable=C0415

    func = node.func.ptr
    # Dispatch keyed by object identity so shadowed names don't match.
    replacements = {
        id(print): impl.ti_print,
        id(min): ti_ops.min,
        id(max): ti_ops.max,
        id(int): impl.ti_int,
        id(bool): impl.ti_bool,
        id(float): impl.ti_float,
        id(any): matrix_ops.any,
        id(all): matrix_ops.all,
        id(abs): abs,
        id(pow): pow,
        id(operator.matmul): matrix_ops.matmul,
    }

    # Builtin 'len' function on Matrix Expr
    if id(func) == id(len) and len(args) == 1:
        if isinstance(args[0], Expr) and args[0].ptr.is_tensor():
            node.ptr = args[0].get_shape()[0]
            return True

    handler = replacements.get(id(func))
    if handler is not None:
        node.ptr = handler(*args, **keywords)
        return True
    return False
|
453
|
+
|
454
|
+
@staticmethod
def build_call_if_is_type(ctx: ASTTransformerContext, node, args, keywords):
    """Intercept calls where the callee is a primitive type, e.g. ``ti.f32(x)``.

    Returns True (with ``node.ptr`` set) when the call was handled here.
    """
    func = node.func.ptr
    if id(func) not in primitive_types.type_ids:
        return False
    if len(args) != 1 or keywords:
        raise TaichiSyntaxError("A primitive type can only decorate a single expression.")
    if is_taichi_class(args[0]):
        raise TaichiSyntaxError("A primitive type cannot decorate an expression with a compound type.")

    if isinstance(args[0], expr.Expr):
        if args[0].ptr.is_tensor():
            raise TaichiSyntaxError("A primitive type cannot decorate an expression with a compound type.")
        # Runtime scalar expression: emit a cast to the requested type.
        node.ptr = ti_ops.cast(args[0], func)
    else:
        # Python-side constant: wrap it directly with the requested dtype.
        node.ptr = expr.Expr(args[0], dtype=func)
    return True
|
471
|
+
|
472
|
+
@staticmethod
def is_external_func(ctx: ASTTransformerContext, func) -> bool:
    """Return True if *func* is a non-Taichi callable invoked in Taichi scope."""
    # Static (compile-time) scope may call arbitrary Python functions.
    if ctx.is_in_static_scope():
        return False
    # Taichi funcs/kernels are handled by the transformer itself.
    if hasattr(func, "_is_taichi_function") or hasattr(func, "_is_wrapped_kernel"):
        return False
    # Anything shipped inside the taichi package is trusted.
    if hasattr(func, "__module__") and func.__module__ and func.__module__.startswith("taichi."):
        return False
    return True
|
481
|
+
|
482
|
+
@staticmethod
def warn_if_is_external_func(ctx: ASTTransformerContext, node):
    """Emit a SyntaxWarning when a call target is a non-Taichi function."""
    func = node.func.ptr
    if not ASTTransformer.is_external_func(ctx, func):
        return
    name = unparse(node.func).strip()
    # ANSI-colored message (yellow, then reset) pointing at the call site.
    message = (
        f"\x1b[38;5;226m"  # Yellow
        f'Calling non-taichi function "{name}". '
        f"Scope inside the function is not processed by the Taichi AST transformer. "
        f"The function may not work as expected. Proceed with caution! "
        f"Maybe you can consider turning it into a @ti.func?"
        f"\x1b[0m"  # Reset
    )
    warnings.warn_explicit(
        message,
        SyntaxWarning,
        ctx.file,
        node.lineno + ctx.lineno_offset,
        module="taichi",
    )
|
500
|
+
|
501
|
+
@staticmethod
|
502
|
+
# Parses a formatted string and extracts format specifiers from it, along with positional and keyword arguments.
|
503
|
+
# This function produces a canonicalized formatted string that includes solely empty replacement fields, e.g. 'qwerty {} {} {} {} {}'.
|
504
|
+
# Note that the arguments can be used multiple times in the string.
|
505
|
+
# e.g.:
|
506
|
+
# origin input: 'qwerty {1} {} {1:.3f} {k:.4f} {k:}'.format(1.0, 2.0, k=k)
|
507
|
+
# raw_string: 'qwerty {1} {} {1:.3f} {k:.4f} {k:}'
|
508
|
+
# raw_args: [1.0, 2.0]
|
509
|
+
# raw_keywords: {'k': <ti.Expr>}
|
510
|
+
# return value: ['qwerty {} {} {} {} {}', 2.0, 1.0, ['__ti_fmt_value__', 2.0, '.3f'], ['__ti_fmt_value__', <ti.Expr>, '.4f'], <ti.Expr>]
|
511
|
+
def canonicalize_formatted_string(raw_string: str, *raw_args: list, **raw_keywords: dict):
|
512
|
+
raw_brackets = re.findall(r"{(.*?)}", raw_string)
|
513
|
+
brackets = []
|
514
|
+
unnamed = 0
|
515
|
+
for bracket in raw_brackets:
|
516
|
+
item, spec = bracket.split(":") if ":" in bracket else (bracket, None)
|
517
|
+
if item.isdigit():
|
518
|
+
item = int(item)
|
519
|
+
# handle unnamed positional args
|
520
|
+
if item == "":
|
521
|
+
item = unnamed
|
522
|
+
unnamed += 1
|
523
|
+
# handle empty spec
|
524
|
+
if spec == "":
|
525
|
+
spec = None
|
526
|
+
brackets.append((item, spec))
|
527
|
+
|
528
|
+
# check for errors in the arguments
|
529
|
+
max_args_index = max([t[0] for t in brackets if isinstance(t[0], int)], default=-1)
|
530
|
+
if max_args_index + 1 != len(raw_args):
|
531
|
+
raise TaichiSyntaxError(
|
532
|
+
f"Expected {max_args_index + 1} positional argument(s), but received {len(raw_args)} instead."
|
533
|
+
)
|
534
|
+
brackets_keywords = [t[0] for t in brackets if isinstance(t[0], str)]
|
535
|
+
for item in brackets_keywords:
|
536
|
+
if item not in raw_keywords:
|
537
|
+
raise TaichiSyntaxError(f"Keyword '{item}' not found.")
|
538
|
+
for item in raw_keywords:
|
539
|
+
if item not in brackets_keywords:
|
540
|
+
raise TaichiSyntaxError(f"Keyword '{item}' not used.")
|
541
|
+
|
542
|
+
# reorganize the arguments based on their positions, keywords, and format specifiers
|
543
|
+
args = []
|
544
|
+
for item, spec in brackets:
|
545
|
+
new_arg = raw_args[item] if isinstance(item, int) else raw_keywords[item]
|
546
|
+
if spec is not None:
|
547
|
+
args.append(["__ti_fmt_value__", new_arg, spec])
|
548
|
+
else:
|
549
|
+
args.append(new_arg)
|
550
|
+
# put the formatted string as the first argument to make ti.format() happy
|
551
|
+
args.insert(0, re.sub(r"{.*?}", "{}", raw_string))
|
552
|
+
return args
|
553
|
+
|
554
|
+
@staticmethod
|
555
|
+
def expand_node_args_dataclasses(args: tuple[ast.AST, ...]) -> tuple[ast.AST, ...]:
|
556
|
+
args_new = []
|
557
|
+
for arg in args:
|
558
|
+
val = arg.ptr
|
559
|
+
if dataclasses.is_dataclass(val):
|
560
|
+
dataclass_type = val
|
561
|
+
for field in dataclasses.fields(dataclass_type):
|
562
|
+
child_name = f"__ti_{arg.id}_{field.name}"
|
563
|
+
load_ctx = ast.Load()
|
564
|
+
arg_node = ast.Name(
|
565
|
+
id=child_name,
|
566
|
+
ctx=load_ctx,
|
567
|
+
lineno=arg.lineno,
|
568
|
+
end_lineno=arg.end_lineno,
|
569
|
+
col_offset=arg.col_offset,
|
570
|
+
end_col_offset=arg.end_col_offset,
|
571
|
+
)
|
572
|
+
args_new.append(arg_node)
|
573
|
+
else:
|
574
|
+
args_new.append(arg)
|
575
|
+
return tuple(args_new)
|
576
|
+
|
577
|
+
@staticmethod
def build_Call(ctx: ASTTransformerContext, node: ast.Call):
    """Build a call expression.

    Evaluates callee and arguments (in static scope for ti.static /
    ti.static_assert), expands dataclass and starred-tensor arguments,
    then dispatches on the resolved callee: str.format templates,
    Matrix/Vector constructors, builtins, types, bound tensor-op callers,
    and finally a plain Python call with a friendlier TypeError.
    """
    if ASTTransformer.get_decorator(ctx, node) in ["static", "static_assert"]:
        with ctx.static_scope_guard():
            build_stmt(ctx, node.func)
            build_stmts(ctx, node.args)
            build_stmts(ctx, node.keywords)
    else:
        build_stmt(ctx, node.func)
        # creates variable for the dataclass itself (as well as other variables,
        # not related to dataclasses). Necessary for calling further child functions
        build_stmts(ctx, node.args)
        node.args = ASTTransformer.expand_node_args_dataclasses(node.args)
        # create variables for the now-expanded dataclass members
        build_stmts(ctx, node.args)
        build_stmts(ctx, node.keywords)

    args = []
    for arg in node.args:
        if isinstance(arg, ast.Starred):
            arg_list = arg.ptr
            if isinstance(arg_list, Expr) and arg_list.is_tensor():
                # Expand Expr with Matrix-type return into list of Exprs
                arg_list = [Expr(x) for x in ctx.ast_builder.expand_exprs([arg_list.ptr])]

            for i in arg_list:
                args.append(i)
        else:
            args.append(arg.ptr)
    # Merge keyword dicts; each keyword node's .ptr is presumably a dict
    # (ChainMap gives earlier entries precedence on duplicate keys).
    keywords = dict(ChainMap(*[keyword.ptr for keyword in node.keywords]))
    func = node.func.ptr

    # Identity check: catches both the builtin print and taichi's ti_print.
    if id(func) in [id(print), id(impl.ti_print)]:
        ctx.func.has_print = True

    # "literal {}".format(...): canonicalize and route through ti.format.
    if isinstance(node.func, ast.Attribute) and isinstance(node.func.value.ptr, str) and node.func.attr == "format":
        raw_string = node.func.value.ptr
        args = ASTTransformer.canonicalize_formatted_string(raw_string, *args, **keywords)
        node.ptr = impl.ti_format(*args)
        return node.ptr

    if id(func) == id(Matrix) or id(func) == id(Vector):
        node.ptr = matrix.make_matrix(*args, **keywords)
        return node.ptr

    if ASTTransformer.build_call_if_is_builtin(ctx, node, args, keywords):
        return node.ptr

    if ASTTransformer.build_call_if_is_type(ctx, node, args, keywords):
        return node.ptr

    # "caller" is attached by build_Attribute for tensor-op method calls;
    # the receiver is passed as the first argument.
    if hasattr(node.func, "caller"):
        node.ptr = func(node.func.caller, *args, **keywords)
        return node.ptr
    ASTTransformer.warn_if_is_external_func(ctx, node)
    try:
        node.ptr = func(*args, **keywords)
    except TypeError as e:
        # Rewrite the message so internal 'Expr' reads as 'Taichi Expression'.
        module = inspect.getmodule(func)
        error_msg = re.sub(r"\bExpr\b", "Taichi Expression", str(e))
        msg = f"TypeError when calling `{func.__name__}`: {error_msg}."
        if ASTTransformer.is_external_func(ctx, node.func.ptr):
            args_has_expr = any([isinstance(arg, Expr) for arg in args])
            if args_has_expr and (module == math or module == np):
                # Probe whether taichi exports a same-named function; if so,
                # suggest it. NOTE(review): the bare except deliberately makes
                # this a best-effort probe — import failure is ignored.
                exec_str = f"from taichi import {func.__name__}"
                try:
                    exec(exec_str, {})
                except:
                    pass
                else:
                    msg += f"\nDid you mean to use `ti.{func.__name__}` instead of `{module.__name__}.{func.__name__}`?"
        raise TaichiTypeError(msg)

    # Propagate print usage from a called Taichi function into this one.
    if getattr(func, "_is_taichi_function", False):
        ctx.func.has_print |= func.func.has_print

    return node.ptr
|
654
|
+
|
655
|
+
@staticmethod
def build_FunctionDef(ctx: ASTTransformerContext, node: ast.FunctionDef):
    """Build a kernel/function definition: declare its parameters and
    return values, then build the body inside a fresh variable scope.

    ti.kernel and real functions go through transform_as_kernel(); plain
    ti.func arguments are bound directly from ctx.argument_data.

    Fixes relative to the previous revision (error paths only):
    - the Matrix `m` mismatch message now reports element_shape[1],
      matching the condition it accompanies;
    - the ndarray type-mismatch message uses argument.name instead of the
      unbound name `arg`, which would have raised a NameError.
    """
    if ctx.visited_funcdef:
        raise TaichiSyntaxError(
            f"Function definition is not allowed in 'ti.{'kernel' if ctx.is_kernel else 'func'}'."
        )
    ctx.visited_funcdef = True

    args = node.args
    assert args.vararg is None
    assert args.kwonlyargs == []
    assert args.kw_defaults == []
    assert args.kwarg is None

    def decl_and_create_variable(
        annotation, name, arg_features, invoke_later_dict, prefix_name, arg_depth
    ) -> tuple[bool, Any]:
        # Returns (True, value) when the argument is fully declared, or
        # (False, (decl_func, params)) when declaration must be deferred.
        full_name = prefix_name + "_" + name
        if not isinstance(annotation, primitive_types.RefType):
            ctx.kernel_args.append(name)
        if isinstance(annotation, ArgPackType):
            kernel_arguments.push_argpack_arg(name)
            d = {}
            items_to_put_in_dict = []
            for j, (_name, anno) in enumerate(annotation.members.items()):
                result, obj = decl_and_create_variable(
                    anno, _name, arg_features[j], invoke_later_dict, full_name, arg_depth + 1
                )
                if not result:
                    d[_name] = None
                    items_to_put_in_dict.append((full_name + "_" + _name, _name, obj))
                else:
                    d[_name] = obj
            argpack = kernel_arguments.decl_argpack_arg(annotation, d)
            for item in items_to_put_in_dict:
                invoke_later_dict[item[0]] = argpack, item[1], *item[2]
            return True, argpack
        if annotation == annotations.template or isinstance(annotation, annotations.template):
            return True, ctx.global_vars[name]
        if isinstance(annotation, annotations.sparse_matrix_builder):
            return False, (
                kernel_arguments.decl_sparse_matrix,
                (
                    to_taichi_type(arg_features),
                    full_name,
                ),
            )
        if isinstance(annotation, ndarray_type.NdarrayType):
            return False, (
                kernel_arguments.decl_ndarray_arg,
                (
                    to_taichi_type(arg_features[0]),
                    arg_features[1],
                    full_name,
                    arg_features[2],
                    arg_features[3],
                ),
            )
        if isinstance(annotation, texture_type.TextureType):
            return False, (kernel_arguments.decl_texture_arg, (arg_features[0], full_name))
        if isinstance(annotation, texture_type.RWTextureType):
            return False, (
                kernel_arguments.decl_rw_texture_arg,
                (arg_features[0], arg_features[1], arg_features[2], full_name),
            )
        if isinstance(annotation, MatrixType):
            return True, kernel_arguments.decl_matrix_arg(annotation, name, arg_depth)
        if isinstance(annotation, StructType):
            return True, kernel_arguments.decl_struct_arg(annotation, name, arg_depth)
        return True, kernel_arguments.decl_scalar_arg(annotation, name, arg_depth)

    def transform_as_kernel() -> None:
        # Declare return values, then every parameter (argpacks and
        # dataclasses are flattened), then finalize the callable signature.
        if node.returns is not None:
            if not isinstance(node.returns, ast.Constant):
                for return_type in ctx.func.return_type:
                    kernel_arguments.decl_ret(return_type)
            impl.get_runtime().compiling_callable.finalize_rets()

        invoke_later_dict: dict[str, tuple[Any, str, Any]] = dict()
        create_variable_later = dict()
        for i, arg in enumerate(args.args):
            argument = ctx.func.arguments[i]
            if isinstance(argument.annotation, ArgPackType):
                kernel_arguments.push_argpack_arg(argument.name)
                d = {}
                items_to_put_in_dict: list[tuple[str, str, Any]] = []
                for j, (name, anno) in enumerate(argument.annotation.members.items()):
                    result, obj = decl_and_create_variable(
                        anno, name, ctx.arg_features[i][j], invoke_later_dict, "__argpack_" + name, 1
                    )
                    if not result:
                        d[name] = None
                        items_to_put_in_dict.append(("__argpack_" + name, name, obj))
                    else:
                        d[name] = obj
                argpack = kernel_arguments.decl_argpack_arg(ctx.func.arguments[i].annotation, d)
                for item in items_to_put_in_dict:
                    invoke_later_dict[item[0]] = argpack, item[1], *item[2]
                create_variable_later[arg.arg] = argpack
            elif dataclasses.is_dataclass(argument.annotation):
                arg_features = ctx.arg_features[i]
                ctx.create_variable(argument.name, argument.annotation)
                for field_idx, field in enumerate(dataclasses.fields(argument.annotation)):
                    flat_name = f"__ti_{argument.name}_{field.name}"
                    result, obj = decl_and_create_variable(
                        field.type,
                        flat_name,
                        arg_features[field_idx],
                        invoke_later_dict,
                        "",
                        0,
                    )
                    if result:
                        ctx.create_variable(flat_name, obj)
                    else:
                        decl_type_func, type_args = obj
                        obj = decl_type_func(*type_args)
                        ctx.create_variable(flat_name, obj)
            else:
                result, obj = decl_and_create_variable(
                    argument.annotation,
                    argument.name,
                    ctx.arg_features[i] if ctx.arg_features is not None else None,
                    invoke_later_dict,
                    "",
                    0,
                )
                if result:
                    ctx.create_variable(arg.arg, obj)
                else:
                    decl_type_func, type_args = obj
                    obj = decl_type_func(*type_args)
                    ctx.create_variable(arg.arg, obj)
        # Deferred declarations (nested inside argpacks) run after all
        # top-level parameters are declared.
        for k, v in invoke_later_dict.items():
            argpack, name, func, params = v
            argpack[name] = func(*params)
        for k, v in create_variable_later.items():
            ctx.create_variable(k, v)

        impl.get_runtime().compiling_callable.finalize_params()
        # remove original args
        node.args.args = []

    if ctx.is_kernel:  # ti.kernel
        transform_as_kernel()

    else:  # ti.func
        if ctx.is_real_function:
            transform_as_kernel()
        else:
            # Plain ti.func: bind the caller-supplied values directly.
            for data_i, data in enumerate(ctx.argument_data):
                argument = ctx.func.arguments[data_i]
                if isinstance(argument.annotation, annotations.template):
                    ctx.create_variable(argument.name, data)
                    continue

                elif dataclasses.is_dataclass(argument.annotation):
                    dataclass_type = argument.annotation
                    for field in dataclasses.fields(dataclass_type):
                        data_child = getattr(data, field.name)
                        if not isinstance(
                            data_child,
                            (
                                _ndarray.ScalarNdarray,
                                matrix.VectorNdarray,
                                matrix.MatrixNdarray,
                                any_array.AnyArray,
                            ),
                        ):
                            raise TaichiSyntaxError(
                                f"Argument {argument.name} of type {dataclass_type} {field.type} is not recognized."
                            )
                        field.type.check_matched(data_child.get_type(), field.name)
                        var_name = f"__ti_{argument.name}_{field.name}"
                        ctx.create_variable(var_name, data_child)
                    continue

                # Ndarray arguments are passed by reference.
                if isinstance(argument.annotation, (ndarray_type.NdarrayType)):
                    if not isinstance(
                        data,
                        (
                            _ndarray.ScalarNdarray,
                            matrix.VectorNdarray,
                            matrix.MatrixNdarray,
                            any_array.AnyArray,
                        ),
                    ):
                        # Fixed: was `arg.arg`, a name not bound in this loop.
                        raise TaichiSyntaxError(
                            f"Argument {argument.name} of type {argument.annotation} is not recognized."
                        )
                    argument.annotation.check_matched(data.get_type(), argument.name)
                    ctx.create_variable(argument.name, data)
                    continue

                # Matrix arguments are passed by value.
                if isinstance(argument.annotation, (MatrixType)):
                    var_name = argument.name
                    # "data" is expected to be an Expr here,
                    # so we simply call "impl.expr_init_func(data)" to perform:
                    #
                    # TensorType* t = alloca()
                    # assign(t, data)
                    #
                    # We created local variable "t" - a copy of the passed-in argument "data"
                    if not isinstance(data, expr.Expr) or not data.ptr.is_tensor():
                        raise TaichiSyntaxError(
                            f"Argument {var_name} of type {argument.annotation} is expected to be a Matrix, but got {type(data)}."
                        )

                    element_shape = data.ptr.get_rvalue_type().shape()
                    if len(element_shape) != argument.annotation.ndim:
                        raise TaichiSyntaxError(
                            f"Argument {var_name} of type {argument.annotation} is expected to be a Matrix with ndim {argument.annotation.ndim}, but got {len(element_shape)}."
                        )

                    assert argument.annotation.ndim > 0
                    if element_shape[0] != argument.annotation.n:
                        raise TaichiSyntaxError(
                            f"Argument {var_name} of type {argument.annotation} is expected to be a Matrix with n {argument.annotation.n}, but got {element_shape[0]}."
                        )

                    # Fixed: report element_shape[1] (the value the check tests),
                    # not element_shape[0].
                    if argument.annotation.ndim == 2 and element_shape[1] != argument.annotation.m:
                        raise TaichiSyntaxError(
                            f"Argument {var_name} of type {argument.annotation} is expected to be a Matrix with m {argument.annotation.m}, but got {element_shape[1]}."
                        )

                    ctx.create_variable(var_name, impl.expr_init_func(data))
                    continue

                if id(argument.annotation) in primitive_types.type_ids:
                    var_name = argument.name
                    ctx.create_variable(var_name, impl.expr_init_func(ti_ops.cast(data, argument.annotation)))
                    continue
                # Create a copy for non-template arguments,
                # so that they are passed by value.
                var_name = argument.name
                ctx.create_variable(var_name, impl.expr_init_func(data))
            for v in ctx.func.orig_arguments:
                if dataclasses.is_dataclass(v.annotation):
                    ctx.create_variable(v.name, v.annotation)

    with ctx.variable_scope_guard():
        build_stmts(ctx, node.body)

    return None
|
901
|
+
|
902
|
+
@staticmethod
def build_Return(ctx: ASTTransformerContext, node: ast.Return) -> None:
    """Build a return statement.

    Kernels and real functions emit a kernel expression-group return after
    validating each returned value against the declared return type; plain
    ti.func stores the (possibly cast) value into ctx.return_data.
    """
    if not ctx.is_real_function:
        if ctx.is_in_non_static_control_flow():
            raise TaichiSyntaxError("Return inside non-static if/for is not supported")
    if node.value is not None:
        build_stmt(ctx, node.value)
    # Bare `return` or a value that built to None: record a void return.
    if node.value is None or node.value.ptr is None:
        if not ctx.is_real_function:
            ctx.returned = ReturnStatus.ReturnedVoid
        return None
    if ctx.is_kernel or ctx.is_real_function:
        # TODO: check if it's at the end of a kernel, throw TaichiSyntaxError if not
        if ctx.func.return_type is None:
            raise TaichiSyntaxError(
                f'A {"kernel" if ctx.is_kernel else "function"} '
                "with a return value must be annotated "
                "with a return type, e.g. def func() -> ti.f32"
            )
        return_exprs = []
        # Normalize a single return value to a one-element list so the
        # zip below handles single and multiple returns uniformly.
        if len(ctx.func.return_type) == 1:
            node.value.ptr = [node.value.ptr]
        assert len(ctx.func.return_type) == len(node.value.ptr)
        for return_type, ptr in zip(ctx.func.return_type, node.value.ptr):
            if id(return_type) in primitive_types.type_ids:
                # Primitive return: accept a scalar Expr or a Python number.
                if isinstance(ptr, Expr):
                    if ptr.is_tensor() or ptr.is_struct() or ptr.element_type() not in primitive_types.all_types:
                        raise TaichiRuntimeTypeError.get_ret(str(return_type), ptr)
                elif not isinstance(ptr, (float, int, np.floating, np.integer)):
                    raise TaichiRuntimeTypeError.get_ret(str(return_type), ptr)
                return_exprs += [ti_ops.cast(expr.Expr(ptr), return_type).ptr]
            elif isinstance(return_type, MatrixType):
                values = ptr
                if isinstance(values, Matrix):
                    # NOTE(review): the check reads ctx.func.return_type.ndim
                    # while the message reports return_type.ndim — confirm
                    # these are intentionally different.
                    if values.ndim != ctx.func.return_type.ndim:
                        raise TaichiRuntimeTypeError(
                            f"Return matrix ndim mismatch, expecting={return_type.ndim}, got={values.ndim}."
                        )
                    elif return_type.get_shape() != values.get_shape():
                        raise TaichiRuntimeTypeError(
                            f"Return matrix shape mismatch, expecting={return_type.get_shape()}, got={values.get_shape()}."
                        )
                    values = (
                        itertools.chain.from_iterable(values.to_list())
                        if values.ndim == 1
                        else iter(values.to_list())
                    )
                elif isinstance(values, Expr):
                    # Tensor-valued Expr: validate element type, ndim, shape.
                    if not values.is_tensor():
                        raise TaichiRuntimeTypeError.get_ret(return_type.to_string(), ptr)
                    elif (
                        return_type.dtype in primitive_types.real_types
                        and not values.element_type() in primitive_types.all_types
                    ):
                        raise TaichiRuntimeTypeError.get_ret(return_type.dtype.to_string(), values.element_type())
                    elif (
                        return_type.dtype in primitive_types.integer_types
                        and not values.element_type() in primitive_types.integer_types
                    ):
                        raise TaichiRuntimeTypeError.get_ret(return_type.dtype.to_string(), values.element_type())
                    elif len(values.get_shape()) != return_type.ndim:
                        raise TaichiRuntimeTypeError(
                            f"Return matrix ndim mismatch, expecting={return_type.ndim}, got={len(values.get_shape())}."
                        )
                    elif return_type.get_shape() != values.get_shape():
                        raise TaichiRuntimeTypeError(
                            f"Return matrix shape mismatch, expecting={return_type.get_shape()}, got={values.get_shape()}."
                        )
                    values = [values]
                else:
                    # Anything else (lists, tuples, numbers): validate via numpy.
                    np_array = np.array(values)
                    dt, shape, ndim = np_array.dtype, np_array.shape, np_array.ndim
                    if return_type.dtype in primitive_types.real_types and dt not in (
                        float,
                        int,
                        np.floating,
                        np.integer,
                    ):
                        raise TaichiRuntimeTypeError.get_ret(return_type.dtype.to_string(), dt)
                    elif return_type.dtype in primitive_types.integer_types and dt not in (int, np.integer):
                        raise TaichiRuntimeTypeError.get_ret(return_type.dtype.to_string(), dt)
                    elif ndim != return_type.ndim:
                        raise TaichiRuntimeTypeError(
                            f"Return matrix ndim mismatch, expecting={return_type.ndim}, got={ndim}."
                        )
                    elif return_type.get_shape() != shape:
                        raise TaichiRuntimeTypeError(
                            f"Return matrix shape mismatch, expecting={return_type.get_shape()}, got={shape}."
                        )
                    values = [values]
                return_exprs += [ti_ops.cast(exp, return_type.dtype) for exp in values]
            elif isinstance(return_type, StructType):
                # NOTE(review): isinstance(ptr, return_type) relies on
                # StructType supporting instance checks — confirm.
                if not isinstance(ptr, Struct) or not isinstance(ptr, return_type):
                    raise TaichiRuntimeTypeError.get_ret(str(return_type), ptr)
                values = ptr
                assert isinstance(values, Struct)
                return_exprs += expr._get_flattened_ptrs(values)
            else:
                raise TaichiSyntaxError("The return type is not supported now!")
        ctx.ast_builder.create_kernel_exprgroup_return(
            expr.make_expr_group(return_exprs), _ti_core.DebugInfo(ctx.get_pos_info(node))
        )
    else:
        # Plain ti.func: stash the value; cast primitives to the declared type.
        ctx.return_data = node.value.ptr
        if ctx.func.return_type is not None:
            if len(ctx.func.return_type) == 1:
                ctx.return_data = [ctx.return_data]
            for i, return_type in enumerate(ctx.func.return_type):
                if id(return_type) in primitive_types.type_ids:
                    ctx.return_data[i] = ti_ops.cast(ctx.return_data[i], return_type)
            if len(ctx.func.return_type) == 1:
                ctx.return_data = ctx.return_data[0]
    if not ctx.is_real_function:
        ctx.returned = ReturnStatus.ReturnedValue
    return None
|
1017
|
+
|
1018
|
+
@staticmethod
def build_Module(ctx: ASTTransformerContext, node: ast.Module) -> None:
    """Build a module body statement-by-statement in one variable scope."""
    with ctx.variable_scope_guard():
        # Do NOT use |build_stmts| which inserts 'del' statements to the
        # end and deletes parameters passed into the module
        for statement in node.body:
            build_stmt(ctx, statement)
    return None
|
1026
|
+
|
1027
|
+
@staticmethod
def build_attribute_if_is_dynamic_snode_method(ctx: ASTTransformerContext, node) -> bool:
    """If *node* is append/deactivate/length on a dynamic-SNode field with
    one index fewer than the field dimension, bind the matching callable
    to node.ptr and return True; otherwise return False."""
    if node.attr not in ("append", "deactivate", "length"):
        return False
    # x[i, j].attr supplies explicit indices; bare x.attr means x[()].attr.
    if isinstance(node.value, ast.Subscript):
        x = node.value.value.ptr
        indices = node.value.slice.ptr
    else:
        x = node.value.ptr
        indices = []
    if not isinstance(x, Field):
        return False
    if not x.parent().ptr.type == _ti_core.SNodeType.dynamic:
        return False
    field_dim = x.snode.ptr.num_active_indices()
    index_dim = make_expr_group(*indices).size()
    # The dynamic axis itself is not indexed, hence the +1.
    if field_dim != index_dim + 1:
        return False
    if node.attr == "append":
        node.ptr = lambda val: append(x.parent(), indices, val)
    elif node.attr == "deactivate":
        node.ptr = lambda: deactivate(x.parent(), indices)
    else:
        node.ptr = lambda: length(x.parent(), indices)
    return True
|
1055
|
+
|
1056
|
+
@staticmethod
def build_Attribute(ctx: ASTTransformerContext, node: ast.Attribute):
    """Build an attribute access, including dynamic-SNode methods and
    matrix swizzles (e.g. `v.xyz`)."""
    # There are two valid cases for the methods of Dynamic SNode:
    #
    # 1. x[i, j].append (where the dimension of the field (3 in this case) is equal to one plus the number of the
    # indices (2 in this case) )
    #
    # 2. x.append (where the dimension of the field is one, equal to x[()].append)
    #
    # For the first case, the AST (simplified) is like node = Attribute(value=Subscript(value=x, slice=[i, j]),
    # attr="append"), when we build_stmt(node.value)(build the expression of the Subscript i.e. x[i, j]),
    # it should build the expression of node.value.value (i.e. x) and node.value.slice (i.e. [i, j]), and raise a
    # TaichiIndexError because the dimension of the field is not equal to the number of the indices. Therefore,
    # when we meet the error, we can detect whether it is a method of Dynamic SNode and build the expression if
    # it is by calling build_attribute_if_is_dynamic_snode_method. If we find that it is not a method of Dynamic
    # SNode, we raise the error again.
    #
    # For the second case, the AST (simplified) is like node = Attribute(value=x, attr="append"), and it does not
    # raise error when we build_stmt(node.value). Therefore, when we do not meet the error, we can also detect
    # whether it is a method of Dynamic SNode and build the expression if it is by calling
    # build_attribute_if_is_dynamic_snode_method. If we find that it is not a method of Dynamic SNode,
    # we continue to process it as a normal attribute node.
    try:
        build_stmt(ctx, node.value)
    except Exception as e:
        e = handle_exception_from_cpp(e)
        if isinstance(e, TaichiIndexError):
            node.value.ptr = None
            if ASTTransformer.build_attribute_if_is_dynamic_snode_method(ctx, node):
                return node.ptr
        raise e

    if ASTTransformer.build_attribute_if_is_dynamic_snode_method(ctx, node):
        return node.ptr

    if isinstance(node.value.ptr, Expr) and not hasattr(node.value.ptr, node.attr):
        if node.attr in Matrix._swizzle_to_keygroup:
            # Swizzle access: validate the pattern, then subscript the
            # tensor by each swizzle character's index in its key group.
            keygroup = Matrix._swizzle_to_keygroup[node.attr]
            Matrix._keygroup_to_checker[keygroup](node.value.ptr, node.attr)
            attr_len = len(node.attr)
            if attr_len == 1:
                node.ptr = Expr(
                    impl.get_runtime()
                    .compiling_callable.ast_builder()
                    .expr_subscript(
                        node.value.ptr.ptr,
                        make_expr_group(keygroup.index(node.attr)),
                        _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
                    )
                )
            else:
                node.ptr = Expr(
                    _ti_core.subscript_with_multiple_indices(
                        node.value.ptr.ptr,
                        [make_expr_group(keygroup.index(ch)) for ch in node.attr],
                        (attr_len,),
                        _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
                    )
                )
        else:
            # Fall back to a tensor op; build_Call passes "caller" as the
            # receiver (first argument).
            from taichi.lang import (  # pylint: disable=C0415
                matrix_ops as tensor_ops,
            )

            node.ptr = getattr(tensor_ops, node.attr)
            setattr(node, "caller", node.value.ptr)
    else:
        node.ptr = getattr(node.value.ptr, node.attr)
    return node.ptr
|
1125
|
+
|
1126
|
+
@staticmethod
def build_BinOp(ctx: ASTTransformerContext, node: ast.BinOp):
    """Build a binary expression by applying the Python operator matching
    node.op to the already-built operands; `@` maps to matmul."""
    build_stmt(ctx, node.left)
    build_stmt(ctx, node.right)
    # pylint: disable-msg=C0415
    from taichi.lang.matrix_ops import matmul

    handlers = {
        ast.Add: lambda a, b: a + b,
        ast.Sub: lambda a, b: a - b,
        ast.Mult: lambda a, b: a * b,
        ast.Div: lambda a, b: a / b,
        ast.FloorDiv: lambda a, b: a // b,
        ast.Mod: lambda a, b: a % b,
        ast.Pow: lambda a, b: a**b,
        ast.LShift: lambda a, b: a << b,
        ast.RShift: lambda a, b: a >> b,
        ast.BitOr: lambda a, b: a | b,
        ast.BitXor: lambda a, b: a ^ b,
        ast.BitAnd: lambda a, b: a & b,
        ast.MatMult: matmul,
    }
    binary_op = handlers.get(type(node.op))
    try:
        node.ptr = binary_op(node.left.ptr, node.right.ptr)
    except TypeError as e:
        # Surface operator type failures as Taichi type errors.
        raise TaichiTypeError(str(e)) from None
    return node.ptr
|
1153
|
+
|
1154
|
+
@staticmethod
def build_AugAssign(ctx: ASTTransformerContext, node: ast.AugAssign):
    """Build an augmented assignment (`+=`, `-=`, ...) via the target's
    _augassign hook; kernel arguments are rejected as immutable."""
    build_stmt(ctx, node.target)
    build_stmt(ctx, node.value)
    target_is_name = isinstance(node.target, ast.Name)
    if target_is_name and node.target.id in ctx.kernel_args:
        raise TaichiSyntaxError(
            f'Kernel argument "{node.target.id}" is immutable in the kernel. '
            f"If you want to change its value, please create a new variable."
        )
    # Dispatch on the op's class name (e.g. "Add" for +=).
    node.ptr = node.target.ptr._augassign(node.value.ptr, type(node.op).__name__)
    return node.ptr
|
1165
|
+
|
1166
|
+
@staticmethod
def build_UnaryOp(ctx: ASTTransformerContext, node: ast.UnaryOp):
    """Build a unary expression (+, -, not, ~) on the built operand."""
    build_stmt(ctx, node.operand)
    handlers = {
        ast.UAdd: lambda v: v,
        ast.USub: lambda v: -v,
        ast.Not: ti_ops.logical_not,
        ast.Invert: lambda v: ~v,
    }
    unary_op = handlers.get(type(node.op))
    node.ptr = unary_op(node.operand.ptr)
    return node.ptr
|
1177
|
+
|
1178
|
+
@staticmethod
|
1179
|
+
def build_bool_op(op):
|
1180
|
+
def inner(operands):
|
1181
|
+
if len(operands) == 1:
|
1182
|
+
return operands[0].ptr
|
1183
|
+
return op(operands[0].ptr, inner(operands[1:]))
|
1184
|
+
|
1185
|
+
return inner
|
1186
|
+
|
1187
|
+
@staticmethod
|
1188
|
+
def build_static_and(operands):
|
1189
|
+
for operand in operands:
|
1190
|
+
if not operand.ptr:
|
1191
|
+
return operand.ptr
|
1192
|
+
return operands[-1].ptr
|
1193
|
+
|
1194
|
+
@staticmethod
|
1195
|
+
def build_static_or(operands):
|
1196
|
+
for operand in operands:
|
1197
|
+
if operand.ptr:
|
1198
|
+
return operand.ptr
|
1199
|
+
return operands[-1].ptr
|
1200
|
+
|
1201
|
+
@staticmethod
def build_BoolOp(ctx: ASTTransformerContext, node: ast.BoolOp):
    """Build `and`/`or`: Python semantics in static scope, logical ops when
    short-circuit operators are enabled, bitwise ops otherwise."""
    build_stmts(ctx, node.values)
    if ctx.is_in_static_scope():
        and_impl = ASTTransformer.build_static_and
        or_impl = ASTTransformer.build_static_or
    elif impl.get_runtime().short_circuit_operators:
        and_impl = ASTTransformer.build_bool_op(ti_ops.logical_and)
        or_impl = ASTTransformer.build_bool_op(ti_ops.logical_or)
    else:
        and_impl = ASTTransformer.build_bool_op(ti_ops.bit_and)
        or_impl = ASTTransformer.build_bool_op(ti_ops.bit_or)
    # ast.BoolOp.op is always ast.And or ast.Or per the Python grammar.
    chosen = and_impl if isinstance(node.op, ast.And) else or_impl
    node.ptr = chosen(node.values)
    return node.ptr
|
1222
|
+
|
1223
|
+
@staticmethod
def build_Compare(ctx: ASTTransformerContext, node: ast.Compare):
    """Build a (possibly chained) comparison.

    Chained comparisons are folded with logical_and; `in`/`not in` are only
    available inside static scope, and `is`/`is not` are rejected.
    """
    build_stmt(ctx, node.left)
    build_stmts(ctx, node.comparators)
    ops = {
        ast.Eq: lambda l, r: l == r,
        ast.NotEq: lambda l, r: l != r,
        ast.Lt: lambda l, r: l < r,
        ast.LtE: lambda l, r: l <= r,
        ast.Gt: lambda l, r: l > r,
        ast.GtE: lambda l, r: l >= r,
    }
    ops_static = {
        ast.In: lambda l, r: l in r,
        ast.NotIn: lambda l, r: l not in r,
    }
    # Static scope additionally permits membership tests.
    if ctx.is_in_static_scope():
        ops = {**ops, **ops_static}
    operands = [node.left.ptr] + [comparator.ptr for comparator in node.comparators]
    val = True
    for i, node_op in enumerate(node.ops):
        if isinstance(node_op, (ast.Is, ast.IsNot)):
            name = "is" if isinstance(node_op, ast.Is) else "is not"
            raise TaichiSyntaxError(f'Operator "{name}" in Taichi scope is not supported.')
        l = operands[i]
        r = operands[i + 1]
        op = ops.get(type(node_op))

        if op is None:
            # Distinguish "static-only" operators from unsupported ones.
            if type(node_op) in ops_static:
                raise TaichiSyntaxError(f'"{type(node_op).__name__}" is only supported inside `ti.static`.')
            else:
                raise TaichiSyntaxError(f'"{type(node_op).__name__}" is not supported in Taichi kernels.')
        val = ti_ops.logical_and(val, op(l, r))
    # Non-Python results are normalized to a 1-bit unsigned value.
    if not isinstance(val, (bool, np.bool_)):
        val = ti_ops.cast(val, primitive_types.u1)
    node.ptr = val
    return node.ptr
|
1261
|
+
|
1262
|
+
@staticmethod
def get_decorator(ctx: ASTTransformerContext, node) -> str:
    """Return the name of the known Taichi decorator call wrapping ``node``.

    Recognizes ``ti.static``, ``ti.static_assert``, ``ti.grouped`` and
    ``ndrange``; returns "" when ``node`` is not a call to any of them.
    """
    if not isinstance(node, ast.Call):
        return ""
    known_decorators = (
        (impl.static, "static"),
        (impl.static_assert, "static_assert"),
        (impl.grouped, "grouped"),
        (ndrange, "ndrange"),
    )
    for func, label in known_decorators:
        if ASTResolver.resolve_to(node.func, func, ctx.global_vars):
            return label
    return ""
|
1275
|
+
|
1276
|
+
@staticmethod
|
1277
|
+
def get_for_loop_targets(node: ast.Name | ast.Tuple | Any) -> list:
|
1278
|
+
"""
|
1279
|
+
Returns the list of indices of the for loop |node|.
|
1280
|
+
See also: https://docs.python.org/3/library/ast.html#ast.For
|
1281
|
+
"""
|
1282
|
+
if isinstance(node.target, ast.Name):
|
1283
|
+
return [node.target.id]
|
1284
|
+
assert isinstance(node.target, ast.Tuple)
|
1285
|
+
return [name.id for name in node.target.elts]
|
1286
|
+
|
1287
|
+
@staticmethod
def build_static_for(ctx: ASTTransformerContext, node: ast.For, is_grouped: bool) -> None:
    """Unroll a ``ti.static`` for loop at compile time.

    The loop body is re-built once per iteration value, each time inside a
    fresh variable scope.  ``break`` / ``continue`` are handled at compile
    time through the context's loop status.

    Args:
        ctx: current transformer context.
        node: the ``ast.For`` node being unrolled.
        is_grouped: True for ``ti.static(ti.grouped(ti.ndrange(...)))``.
    """
    ti_unroll_limit = impl.get_runtime().unrolling_limit

    def _warn_unroll_limit() -> None:
        # Shared by both branches below (previously duplicated verbatim);
        # warn_explicit takes an explicit file/lineno, so no stack
        # introspection depends on the call site.
        warnings.warn_explicit(
            f"""You are unrolling more than
{ti_unroll_limit} iterations, so the compile time may be extremely long.
You can use a non-static for loop if you want to decrease the compile time.
You can disable this warning by setting ti.init(unrolling_limit=0).""",
            SyntaxWarning,
            ctx.file,
            node.lineno + ctx.lineno_offset,
            module="taichi",
        )

    if is_grouped:
        assert len(node.iter.args[0].args) == 1
        ndrange_arg = build_stmt(ctx, node.iter.args[0].args[0])
        if not isinstance(ndrange_arg, _Ndrange):
            raise TaichiSyntaxError("Only 'ti.ndrange' is allowed in 'ti.static(ti.grouped(...))'.")
        targets = ASTTransformer.get_for_loop_targets(node)
        if len(targets) != 1:
            raise TaichiSyntaxError(f"Group for should have 1 loop target, found {len(targets)}")
        target = targets[0]
        iter_time = 0
        alert_already = False

        for value in impl.grouped(ndrange_arg):
            iter_time += 1
            # Warn only once, the first time the unroll count passes the limit.
            if not alert_already and ti_unroll_limit and iter_time > ti_unroll_limit:
                alert_already = True
                _warn_unroll_limit()

            with ctx.variable_scope_guard():
                ctx.create_variable(target, value)
                build_stmts(ctx, node.body)
                status = ctx.loop_status()
                if status == LoopStatus.Break:
                    break
                elif status == LoopStatus.Continue:
                    # Reset so the next unrolled iteration builds normally.
                    ctx.set_loop_status(LoopStatus.Normal)
    else:
        build_stmt(ctx, node.iter)
        targets = ASTTransformer.get_for_loop_targets(node)

        iter_time = 0
        alert_already = False
        for target_values in node.iter.ptr:
            # A single loop target always consumes the whole value.
            if not isinstance(target_values, collections.abc.Sequence) or len(targets) == 1:
                target_values = [target_values]

            iter_time += 1
            if not alert_already and ti_unroll_limit and iter_time > ti_unroll_limit:
                alert_already = True
                _warn_unroll_limit()

            with ctx.variable_scope_guard():
                for target, target_value in zip(targets, target_values):
                    ctx.create_variable(target, target_value)
                build_stmts(ctx, node.body)
                status = ctx.loop_status()
                if status == LoopStatus.Break:
                    break
                elif status == LoopStatus.Continue:
                    ctx.set_loop_status(LoopStatus.Normal)
    return None
|
1359
|
+
|
1360
|
+
@staticmethod
def build_range_for(ctx: ASTTransformerContext, node: ast.For) -> None:
    """Build a frontend range-for from ``for i in range(...)``.

    Accepts ``range(end)`` or ``range(begin, end)``; both bounds are cast
    to i32 before the frontend loop is opened.

    Raises:
        TaichiSyntaxError: if ``range`` got 0 or more than 2 arguments.
    """
    with ctx.variable_scope_guard():
        loop_name = node.target.id
        ctx.check_loop_var(loop_name)
        loop_var = expr.Expr(ctx.ast_builder.make_id_expr(""))
        ctx.create_variable(loop_name, loop_var)
        if len(node.iter.args) not in [1, 2]:
            raise TaichiSyntaxError(f"Range should have 1 or 2 arguments, found {len(node.iter.args)}")
        if len(node.iter.args) == 2:
            begin_expr = expr.Expr(build_stmt(ctx, node.iter.args[0]))
            end_expr = expr.Expr(build_stmt(ctx, node.iter.args[1]))

            # Warning for implicit dtype conversion
            boundary_type_cast_warning(begin_expr)
            boundary_type_cast_warning(end_expr)

            begin = ti_ops.cast(begin_expr, primitive_types.i32)
            end = ti_ops.cast(end_expr, primitive_types.i32)

        else:
            # One-argument form: range(end) starts from 0.
            end_expr = expr.Expr(build_stmt(ctx, node.iter.args[0]))

            # Warning for implicit dtype conversion
            boundary_type_cast_warning(end_expr)

            begin = ti_ops.cast(expr.Expr(0), primitive_types.i32)
            end = ti_ops.cast(end_expr, primitive_types.i32)

        for_di = _ti_core.DebugInfo(ctx.get_pos_info(node))
        ctx.ast_builder.begin_frontend_range_for(loop_var.ptr, begin.ptr, end.ptr, for_di)
        build_stmts(ctx, node.body)
        ctx.ast_builder.end_frontend_range_for()
    return None
|
1394
|
+
|
1395
|
+
@staticmethod
def build_ndrange_for(ctx: ASTTransformerContext, node: ast.For) -> None:
    """Build ``for i, j, ... in ti.ndrange(...)`` as a single flat loop.

    One linear range-for runs over the total element count
    (``acc_dimensions[0]``); each per-dimension index is then recovered
    from the flat index ``I`` by integer division with the remaining
    dimensions' accumulated sizes, and shifted by that dimension's lower
    bound.

    Raises:
        TaichiSyntaxError: if the number of loop targets does not match
            the ndrange dimensionality.
    """
    with ctx.variable_scope_guard():
        ndrange_var = impl.expr_init(build_stmt(ctx, node.iter))
        ndrange_begin = ti_ops.cast(expr.Expr(0), primitive_types.i32)
        ndrange_end = ti_ops.cast(
            expr.Expr(impl.subscript(ctx.ast_builder, ndrange_var.acc_dimensions, 0)),
            primitive_types.i32,
        )
        ndrange_loop_var = expr.Expr(ctx.ast_builder.make_id_expr(""))
        for_di = _ti_core.DebugInfo(ctx.get_pos_info(node))
        ctx.ast_builder.begin_frontend_range_for(ndrange_loop_var.ptr, ndrange_begin.ptr, ndrange_end.ptr, for_di)
        # I is the remaining (not yet decomposed) part of the flat index.
        I = impl.expr_init(ndrange_loop_var)
        targets = ASTTransformer.get_for_loop_targets(node)
        if len(targets) != len(ndrange_var.dimensions):
            raise TaichiSyntaxError(
                "Ndrange for loop with number of the loop variables not equal to "
                "the dimension of the ndrange is not supported. "
                "Please check if the number of arguments of ti.ndrange() is equal to "
                "the number of the loop variables."
            )
        for i, target in enumerate(targets):
            # Leading dimensions: flat index divided by the product of the
            # remaining dimensions; the last target takes whatever is left.
            if i + 1 < len(targets):
                target_tmp = impl.expr_init(I // ndrange_var.acc_dimensions[i + 1])
            else:
                target_tmp = impl.expr_init(I)
            # bounds[i][0] is this dimension's lower bound.
            ctx.create_variable(
                target,
                impl.expr_init(
                    target_tmp
                    + impl.subscript(
                        ctx.ast_builder,
                        impl.subscript(ctx.ast_builder, ndrange_var.bounds, i),
                        0,
                    )
                ),
            )
            if i + 1 < len(targets):
                # Strip the consumed leading component off the flat index.
                I._assign(I - target_tmp * ndrange_var.acc_dimensions[i + 1])
        build_stmts(ctx, node.body)
        ctx.ast_builder.end_frontend_range_for()
    return None
|
1437
|
+
|
1438
|
+
@staticmethod
def build_grouped_ndrange_for(ctx: ASTTransformerContext, node: ast.For) -> None:
    """Build ``for I in ti.grouped(ti.ndrange(...))``.

    Like build_ndrange_for, a single flat range-for runs over the total
    element count, but the recovered per-dimension indices are written
    into the components of one i32 vector loop target instead of separate
    scalar targets.

    Raises:
        TaichiSyntaxError: if there is more than one loop target.
    """
    with ctx.variable_scope_guard():
        ndrange_var = impl.expr_init(build_stmt(ctx, node.iter.args[0]))
        ndrange_begin = ti_ops.cast(expr.Expr(0), primitive_types.i32)
        ndrange_end = ti_ops.cast(
            expr.Expr(impl.subscript(ctx.ast_builder, ndrange_var.acc_dimensions, 0)),
            primitive_types.i32,
        )
        ndrange_loop_var = expr.Expr(ctx.ast_builder.make_id_expr(""))
        for_di = _ti_core.DebugInfo(ctx.get_pos_info(node))
        ctx.ast_builder.begin_frontend_range_for(ndrange_loop_var.ptr, ndrange_begin.ptr, ndrange_end.ptr, for_di)

        targets = ASTTransformer.get_for_loop_targets(node)
        if len(targets) != 1:
            raise TaichiSyntaxError(f"Group for should have 1 loop target, found {len(targets)}")
        target = targets[0]
        # The loop target is an i32 vector with one component per dimension.
        mat = matrix.make_matrix([0] * len(ndrange_var.dimensions), dt=primitive_types.i32)
        target_var = impl.expr_init(mat)

        ctx.create_variable(target, target_var)
        # I is the remaining (not yet decomposed) part of the flat index.
        I = impl.expr_init(ndrange_loop_var)
        for i in range(len(ndrange_var.dimensions)):
            if i + 1 < len(ndrange_var.dimensions):
                target_tmp = I // ndrange_var.acc_dimensions[i + 1]
            else:
                target_tmp = I
            # Shift by the dimension's lower bound and store into component i.
            impl.subscript(ctx.ast_builder, target_var, i)._assign(target_tmp + ndrange_var.bounds[i][0])
            if i + 1 < len(ndrange_var.dimensions):
                I._assign(I - target_tmp * ndrange_var.acc_dimensions[i + 1])
        build_stmts(ctx, node.body)
        ctx.ast_builder.end_frontend_range_for()
    return None
|
1471
|
+
|
1472
|
+
@staticmethod
def build_struct_for(ctx: ASTTransformerContext, node: ast.For, is_grouped: bool) -> None:
    """Build a struct-for over a field's active elements.

    Handles both the scalar-target form and the grouped form where the
    single target is an index vector.

    Raises:
        TaichiSyntaxError: for a grouped loop with more than one target.
    """
    # for i, j in x
    # for I in ti.grouped(x)
    targets = ASTTransformer.get_for_loop_targets(node)

    for target in targets:
        ctx.check_loop_var(target)

    with ctx.variable_scope_guard():
        if is_grouped:
            if len(targets) != 1:
                raise TaichiSyntaxError(f"Group for should have 1 loop target, found {len(targets)}")
            target = targets[0]
            loop_var = build_stmt(ctx, node.iter)
            # One index expression per field dimension, exposed to the body
            # packed into a single i32 vector.
            loop_indices = expr.make_var_list(size=len(loop_var.shape), ast_builder=ctx.ast_builder)
            expr_group = expr.make_expr_group(loop_indices)
            impl.begin_frontend_struct_for(ctx.ast_builder, expr_group, loop_var)
            ctx.create_variable(target, matrix.make_matrix(loop_indices, dt=primitive_types.i32))
            build_stmts(ctx, node.body)
            ctx.ast_builder.end_frontend_struct_for()
        else:
            # One scalar index variable per loop target.
            _vars = []
            for name in targets:
                var = expr.Expr(ctx.ast_builder.make_id_expr(""))
                _vars.append(var)
                ctx.create_variable(name, var)
            loop_var = node.iter.ptr
            expr_group = expr.make_expr_group(*_vars)
            impl.begin_frontend_struct_for(ctx.ast_builder, expr_group, loop_var)
            build_stmts(ctx, node.body)
            ctx.ast_builder.end_frontend_struct_for()
    return None
|
1505
|
+
|
1506
|
+
@staticmethod
def build_mesh_for(ctx: ASTTransformerContext, node: ast.For) -> None:
    """Build a mesh-for over the elements of a MeshTaichi element field.

    Sets ``ctx.mesh`` for the duration of the body so nested relation
    accesses can resolve it, and clears it afterwards.

    Raises:
        TaichiSyntaxError: if there is more than one loop target.
    """
    targets = ASTTransformer.get_for_loop_targets(node)
    if len(targets) != 1:
        # Bug fix: the f-prefix was missing, so "{len(targets)}" was
        # printed verbatim instead of the actual count.
        raise TaichiSyntaxError(f"Mesh for should have 1 loop target, found {len(targets)}")
    target = targets[0]

    with ctx.variable_scope_guard():
        var = expr.Expr(ctx.ast_builder.make_id_expr(""))
        ctx.mesh = node.iter.ptr.mesh
        assert isinstance(ctx.mesh, impl.MeshInstance)
        mesh_idx = mesh.MeshElementFieldProxy(ctx.mesh, node.iter.ptr._type, var.ptr)
        ctx.create_variable(target, mesh_idx)
        ctx.ast_builder.begin_frontend_mesh_for(
            mesh_idx.ptr,
            ctx.mesh.mesh_ptr,
            node.iter.ptr._type,
            _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
        )
        build_stmts(ctx, node.body)
        ctx.mesh = None
        ctx.ast_builder.end_frontend_mesh_for()
    return None
|
1529
|
+
|
1530
|
+
@staticmethod
def build_nested_mesh_for(ctx: ASTTransformerContext, node: ast.For) -> None:
    """Build a nested mesh-for over a relation access (e.g. vertices of a cell).

    Lowered as a range-for over the relation's size; each iteration
    fetches the related element through ``get_relation_access``.

    Raises:
        TaichiSyntaxError: if there is more than one loop target.
    """
    targets = ASTTransformer.get_for_loop_targets(node)
    if len(targets) != 1:
        # Bug fix: the f-prefix was missing, so "{len(targets)}" was
        # printed verbatim instead of the actual count.
        raise TaichiSyntaxError(f"Nested-mesh for should have 1 loop target, found {len(targets)}")
    target = targets[0]

    with ctx.variable_scope_guard():
        ctx.mesh = node.iter.ptr.mesh
        assert isinstance(ctx.mesh, impl.MeshInstance)
        loop_name = node.target.id + "_index__"
        loop_var = expr.Expr(ctx.ast_builder.make_id_expr(""))
        ctx.create_variable(loop_name, loop_var)
        begin = expr.Expr(0)
        end = ti_ops.cast(node.iter.ptr.size, primitive_types.i32)
        for_di = _ti_core.DebugInfo(ctx.get_pos_info(node))
        ctx.ast_builder.begin_frontend_range_for(loop_var.ptr, begin.ptr, end.ptr, for_di)
        entry_expr = _ti_core.get_relation_access(
            ctx.mesh.mesh_ptr,
            node.iter.ptr.from_index.ptr,
            node.iter.ptr.to_element_type,
            loop_var.ptr,
        )
        entry_expr.type_check(impl.get_runtime().prog.config())
        mesh_idx = mesh.MeshElementFieldProxy(ctx.mesh, node.iter.ptr.to_element_type, entry_expr)
        ctx.create_variable(target, mesh_idx)
        build_stmts(ctx, node.body)
        ctx.ast_builder.end_frontend_range_for()

    return None
|
1560
|
+
|
1561
|
+
@staticmethod
def build_For(ctx: ASTTransformerContext, node: ast.For) -> None:
    """Dispatch a `for` statement to the matching loop builder.

    Inspects the (possibly nested) decorator on the iterator —
    ti.static / ti.ndrange / ti.grouped — or a plain `range(...)` call,
    and otherwise falls back to mesh-for / struct-for based on the
    iterator's built value.

    Raises:
        TaichiSyntaxError: for a `for`/`else` clause or an invalid
            decorator combination.
    """
    if node.orelse:
        raise TaichiSyntaxError("'else' clause for 'for' not supported in Taichi kernels")
    decorator = ASTTransformer.get_decorator(ctx, node.iter)
    double_decorator = ""
    if decorator != "" and len(node.iter.args) == 1:
        double_decorator = ASTTransformer.get_decorator(ctx, node.iter.args[0])

    if decorator == "static":
        if double_decorator == "static":
            raise TaichiSyntaxError("'ti.static' cannot be nested")
        with ctx.loop_scope_guard(is_static=True):
            return ASTTransformer.build_static_for(ctx, node, double_decorator == "grouped")
    with ctx.loop_scope_guard():
        if decorator == "ndrange":
            if double_decorator != "":
                # Bug fix: the message had an unbalanced quote
                # ("inside 'ti.ndrange").
                raise TaichiSyntaxError("No decorator is allowed inside 'ti.ndrange'")
            return ASTTransformer.build_ndrange_for(ctx, node)
        if decorator == "grouped":
            if double_decorator == "static":
                raise TaichiSyntaxError("'ti.static' is not allowed inside 'ti.grouped'")
            elif double_decorator == "ndrange":
                return ASTTransformer.build_grouped_ndrange_for(ctx, node)
            elif double_decorator == "grouped":
                raise TaichiSyntaxError("'ti.grouped' cannot be nested")
            else:
                return ASTTransformer.build_struct_for(ctx, node, is_grouped=True)
        elif (
            isinstance(node.iter, ast.Call)
            and isinstance(node.iter.func, ast.Name)
            and node.iter.func.id == "range"
        ):
            return ASTTransformer.build_range_for(ctx, node)
        else:
            # Not a recognized call form: build the iterator and dispatch
            # on the value it produced.
            build_stmt(ctx, node.iter)
            if isinstance(node.iter.ptr, mesh.MeshElementField):
                if not _ti_core.is_extension_supported(impl.default_cfg().arch, _ti_core.Extension.mesh):
                    raise Exception(
                        "Backend " + str(impl.default_cfg().arch) + " doesn't support MeshTaichi extension"
                    )
                return ASTTransformer.build_mesh_for(ctx, node)
            if isinstance(node.iter.ptr, mesh.MeshRelationAccessProxy):
                return ASTTransformer.build_nested_mesh_for(ctx, node)
            # Struct for
            return ASTTransformer.build_struct_for(ctx, node, is_grouped=False)
|
1607
|
+
|
1608
|
+
@staticmethod
def build_While(ctx: ASTTransformerContext, node: ast.While) -> None:
    """Build a `while` loop.

    Lowered as ``while (1) { if (cond) {} else { break } body }`` so that
    the condition is re-evaluated at the top of every iteration.

    Raises:
        TaichiSyntaxError: for a `while`/`else` clause.
    """
    if node.orelse:
        raise TaichiSyntaxError("'else' clause for 'while' not supported in Taichi kernels")

    with ctx.loop_scope_guard():
        stmt_dbg_info = _ti_core.DebugInfo(ctx.get_pos_info(node))
        # Open an unconditional `while (1)` loop ...
        ctx.ast_builder.begin_frontend_while(expr.Expr(1, dtype=primitive_types.i32).ptr, stmt_dbg_info)
        while_cond = build_stmt(ctx, node.test)
        # ... and break out of it when the condition is false.
        impl.begin_frontend_if(ctx.ast_builder, while_cond, stmt_dbg_info)
        ctx.ast_builder.begin_frontend_if_true()
        ctx.ast_builder.pop_scope()
        ctx.ast_builder.begin_frontend_if_false()
        ctx.ast_builder.insert_break_stmt(stmt_dbg_info)
        ctx.ast_builder.pop_scope()
        build_stmts(ctx, node.body)
        ctx.ast_builder.pop_scope()
    return None
|
1626
|
+
|
1627
|
+
@staticmethod
def build_If(ctx: ASTTransformerContext, node: ast.If) -> ast.If | None:
    """Build an `if` statement.

    A ``ti.static(...)`` test is resolved at compile time and only the
    taken branch is built (returning the node itself); otherwise a
    frontend if-stmt with both branches is emitted and None is returned.
    """
    build_stmt(ctx, node.test)
    is_static_if = ASTTransformer.get_decorator(ctx, node.test) == "static"

    if is_static_if:
        if node.test.ptr:
            build_stmts(ctx, node.body)
        else:
            build_stmts(ctx, node.orelse)
        return node

    # Track this non-static `if` so break/continue inside a static loop
    # can be rejected with a precise error location.
    with ctx.non_static_if_guard(node):
        stmt_dbg_info = _ti_core.DebugInfo(ctx.get_pos_info(node))
        impl.begin_frontend_if(ctx.ast_builder, node.test.ptr, stmt_dbg_info)
        ctx.ast_builder.begin_frontend_if_true()
        build_stmts(ctx, node.body)
        ctx.ast_builder.pop_scope()
        ctx.ast_builder.begin_frontend_if_false()
        build_stmts(ctx, node.orelse)
        ctx.ast_builder.pop_scope()
    return None
|
1649
|
+
|
1650
|
+
@staticmethod
def build_Expr(ctx: ASTTransformerContext, node: ast.Expr) -> None:
    """Build an expression statement purely for its side effects; the
    resulting value is discarded."""
    build_stmt(ctx, node.value)
    return None
|
1654
|
+
|
1655
|
+
@staticmethod
def build_IfExp(ctx: ASTTransformerContext, node: ast.IfExp):
    """Build a conditional expression ``a if cond else b``.

    Tensor-valued branch operands route to ``ti.select`` (a tensor-valued
    condition is rejected outright); a ``ti.static`` test is folded at
    compile time; everything else becomes a runtime ``ifte``.
    """
    build_stmt(ctx, node.test)
    build_stmt(ctx, node.body)
    build_stmt(ctx, node.orelse)

    # Any tensor among test/body/orelse switches to the select path.
    has_tensor_type = False
    if isinstance(node.test.ptr, expr.Expr) and node.test.ptr.is_tensor():
        has_tensor_type = True
    if isinstance(node.body.ptr, expr.Expr) and node.body.ptr.is_tensor():
        has_tensor_type = True
    if isinstance(node.orelse.ptr, expr.Expr) and node.orelse.ptr.is_tensor():
        has_tensor_type = True

    if has_tensor_type:
        # A tensor-valued *condition* is no longer supported here.
        if isinstance(node.test.ptr, expr.Expr) and node.test.ptr.is_tensor():
            raise TaichiSyntaxError(
                "Using conditional expression for element-wise select operation on "
                "Taichi vectors/matrices is deprecated and removed starting from Taichi v1.5.0 "
                'Please use "ti.select" instead.'
            )
        node.ptr = ti_ops.select(node.test.ptr, node.body.ptr, node.orelse.ptr)
        return node.ptr

    is_static_if = ASTTransformer.get_decorator(ctx, node.test) == "static"

    if is_static_if:
        # Compile-time fold: only the taken branch is built.
        if node.test.ptr:
            node.ptr = build_stmt(ctx, node.body)
        else:
            node.ptr = build_stmt(ctx, node.orelse)
        return node.ptr

    node.ptr = ti_ops.ifte(node.test.ptr, node.body.ptr, node.orelse.ptr)
    return node.ptr
|
1690
|
+
|
1691
|
+
@staticmethod
|
1692
|
+
def _is_string_mod_args(msg) -> bool:
|
1693
|
+
# 1. str % (a, b, c, ...)
|
1694
|
+
# 2. str % single_item
|
1695
|
+
# Note that |msg.right| may not be a tuple.
|
1696
|
+
if not isinstance(msg, ast.BinOp):
|
1697
|
+
return False
|
1698
|
+
if not isinstance(msg.op, ast.Mod):
|
1699
|
+
return False
|
1700
|
+
if isinstance(msg.left, ast.Str):
|
1701
|
+
return True
|
1702
|
+
if isinstance(msg.left, ast.Constant) and isinstance(msg.left.value, str):
|
1703
|
+
return True
|
1704
|
+
return False
|
1705
|
+
|
1706
|
+
@staticmethod
def _handle_string_mod_args(ctx: ASTTransformerContext, node):
    """Build the pieces of a ``str % args`` message.

    Returns (format_string, args) where the right-hand side is normalized
    to a sequence and each element wrapped into an expression pointer.
    """
    fmt = build_stmt(ctx, node.left)
    raw_args = build_stmt(ctx, node.right)
    if not isinstance(raw_args, collections.abc.Sequence):
        # `str % single_item` — normalize to a one-element tuple.
        raw_args = (raw_args,)
    return fmt, [expr.Expr(item).ptr for item in raw_args]
|
1714
|
+
|
1715
|
+
@staticmethod
def ti_format_list_to_assert_msg(raw) -> tuple[str, list]:
    """Convert a `__ti_format__` list into a printf-style assert message.

    Returns (format_string, args): literal fragments are concatenated
    verbatim, real-typed expressions become "%f" and integer-typed ones
    "%d", with the expressions collected into `args` in order.

    Raises:
        TaichiSyntaxError: for expression types with no printf specifier
            or entries that are neither str nor expressions.
    """
    # TODO: ignore formats here for now
    entries, _ = impl.ti_format_list_to_content_entries([raw])
    msg = ""
    args = []
    for entry in entries:
        if isinstance(entry, str):
            msg += entry
        elif isinstance(entry, _ti_core.Expr):
            ty = entry.get_rvalue_type()
            if ty in primitive_types.real_types:
                msg += "%f"
            elif ty in primitive_types.integer_types:
                msg += "%d"
            else:
                raise TaichiSyntaxError(f"Unsupported data type: {type(ty)}")
            args.append(entry)
        else:
            raise TaichiSyntaxError(f"Unsupported type: {type(entry)}")
    return msg, args
|
1736
|
+
|
1737
|
+
@staticmethod
def build_Assert(ctx: ASTTransformerContext, node: ast.Assert) -> None:
    """Build ``assert test[, msg]`` into a runtime ti_assert.

    The message may be a constant, a ``str % args`` expression, or a
    `__ti_format__` list (from an f-string); with no message, the
    unparsed test expression itself is used.

    Raises:
        TaichiSyntaxError: for any other message form.
    """
    extra_args = []
    if node.msg is not None:
        if ASTTransformer._is_string_mod_args(node.msg):
            msg, extra_args = ASTTransformer._handle_string_mod_args(ctx, node.msg)
        else:
            msg = build_stmt(ctx, node.msg)
            if isinstance(node.msg, ast.Constant):
                msg = str(msg)
            # Bug fix: ast.Str was removed in Python 3.12, so the bare
            # isinstance check raised AttributeError whenever the message
            # was not a Constant; guard it.  On <=3.11 string literals
            # already match the Constant branch first, so behavior is
            # unchanged.
            elif hasattr(ast, "Str") and isinstance(node.msg, ast.Str):
                pass
            elif isinstance(msg, collections.abc.Sequence) and len(msg) > 0 and msg[0] == "__ti_format__":
                msg, extra_args = ASTTransformer.ti_format_list_to_assert_msg(msg)
            else:
                raise TaichiSyntaxError(f"assert info must be constant or formatted string, not {type(msg)}")
    else:
        msg = unparse(node.test)
    test = build_stmt(ctx, node.test)
    impl.ti_assert(test, msg.strip(), extra_args, _ti_core.DebugInfo(ctx.get_pos_info(node)))
    return None
|
1758
|
+
|
1759
|
+
@staticmethod
def build_Break(ctx: ASTTransformerContext, node: ast.Break) -> None:
    """Build `break`: resolved at compile time inside a static loop,
    emitted as an IR break statement otherwise."""
    if not ctx.is_in_static_for():
        ctx.ast_builder.insert_break_stmt(_ti_core.DebugInfo(ctx.get_pos_info(node)))
        return None
    # Static loop: a `break` under a runtime `if` cannot be resolved at
    # compile time, so reject it with the if's location.
    blocking_if = ctx.current_loop_scope().nearest_non_static_if
    if blocking_if:
        msg = ctx.get_pos_info(blocking_if.test)
        msg += (
            "You are trying to `break` a static `for` loop, "
            "but the `break` statement is inside a non-static `if`. "
        )
        raise TaichiSyntaxError(msg)
    ctx.set_loop_status(LoopStatus.Break)
    return None
|
1774
|
+
|
1775
|
+
@staticmethod
def build_Continue(ctx: ASTTransformerContext, node: ast.Continue) -> None:
    """Build `continue`: resolved at compile time inside a static loop,
    emitted as an IR continue statement otherwise."""
    if not ctx.is_in_static_for():
        ctx.ast_builder.insert_continue_stmt(_ti_core.DebugInfo(ctx.get_pos_info(node)))
        return None
    # Static loop: a `continue` under a runtime `if` cannot be resolved at
    # compile time, so reject it with the if's location.
    blocking_if = ctx.current_loop_scope().nearest_non_static_if
    if blocking_if:
        msg = ctx.get_pos_info(blocking_if.test)
        msg += (
            "You are trying to `continue` a static `for` loop, "
            "but the `continue` statement is inside a non-static `if`. "
        )
        raise TaichiSyntaxError(msg)
    ctx.set_loop_status(LoopStatus.Continue)
    return None
|
1790
|
+
|
1791
|
+
@staticmethod
def build_Pass(ctx: ASTTransformerContext, node: ast.Pass) -> None:
    """A `pass` statement emits no IR at all."""
    return None
|
1794
|
+
|
1795
|
+
|
1796
|
+
# Module-level entry point for transforming a single AST node: calling this
# instance dispatches to the matching `build_<NodeType>` static method
# (dispatch presumably implemented by ASTTransformer's base class — defined
# outside this view).
build_stmt = ASTTransformer()
|
1797
|
+
|
1798
|
+
|
1799
|
+
def build_stmts(ctx: ASTTransformerContext, stmts: list):
    """Transform a list of statements inside a fresh variable scope.

    Stops early once the context has recorded a return, or a surrounding
    (static) loop has broken or continued.  Returns the statement list
    unchanged.
    """
    with ctx.variable_scope_guard():
        for stmt in stmts:
            if ctx.returned != ReturnStatus.NoReturn or ctx.loop_status() != LoopStatus.Normal:
                break
            build_stmt(ctx, stmt)
    return stmts
|