gstaichi 0.1.23.dev0__cp310-cp310-win_amd64.whl → 1.0.1__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +6 -0
- gstaichi/__init__.py +40 -0
- {taichi → gstaichi}/_funcs.py +8 -8
- {taichi → gstaichi}/_kernels.py +19 -19
- gstaichi/_lib/__init__.py +3 -0
- taichi/_lib/core/taichi_python.cp310-win_amd64.pyd → gstaichi/_lib/core/gstaichi_python.cp310-win_amd64.pyd +0 -0
- taichi/_lib/core/taichi_python.pyi → gstaichi/_lib/core/gstaichi_python.pyi +382 -522
- {taichi → gstaichi}/_lib/runtime/runtime_cuda.bc +0 -0
- {taichi → gstaichi}/_lib/runtime/runtime_x64.bc +0 -0
- {taichi → gstaichi}/_lib/utils.py +15 -15
- {taichi → gstaichi}/_logging.py +1 -1
- gstaichi/_snode/__init__.py +5 -0
- {taichi → gstaichi}/_snode/fields_builder.py +27 -29
- {taichi → gstaichi}/_snode/snode_tree.py +5 -5
- gstaichi/_test_tools/__init__.py +0 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_version.py +1 -0
- {taichi → gstaichi}/_version_check.py +8 -5
- gstaichi/ad/__init__.py +3 -0
- {taichi → gstaichi}/ad/_ad.py +26 -26
- {taichi → gstaichi}/algorithms/_algorithms.py +7 -7
- {taichi → gstaichi}/examples/minimal.py +1 -1
- {taichi → gstaichi}/experimental.py +1 -1
- gstaichi/lang/__init__.py +50 -0
- {taichi → gstaichi}/lang/_ndarray.py +30 -26
- {taichi → gstaichi}/lang/_ndrange.py +8 -8
- gstaichi/lang/_template_mapper.py +199 -0
- {taichi → gstaichi}/lang/_texture.py +19 -19
- {taichi → gstaichi}/lang/_wrap_inspect.py +7 -7
- {taichi → gstaichi}/lang/any_array.py +13 -13
- {taichi → gstaichi}/lang/argpack.py +29 -29
- gstaichi/lang/ast/__init__.py +5 -0
- {taichi → gstaichi}/lang/ast/ast_transformer.py +94 -582
- {taichi → gstaichi}/lang/ast/ast_transformer_utils.py +54 -41
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
- {taichi → gstaichi}/lang/ast/checkers.py +5 -5
- gstaichi/lang/ast/transform.py +9 -0
- {taichi → gstaichi}/lang/common_ops.py +12 -12
- gstaichi/lang/exception.py +80 -0
- {taichi → gstaichi}/lang/expr.py +22 -22
- {taichi → gstaichi}/lang/field.py +29 -27
- {taichi → gstaichi}/lang/impl.py +116 -121
- {taichi → gstaichi}/lang/kernel_arguments.py +16 -16
- {taichi → gstaichi}/lang/kernel_impl.py +330 -363
- {taichi → gstaichi}/lang/matrix.py +119 -115
- {taichi → gstaichi}/lang/matrix_ops.py +6 -6
- {taichi → gstaichi}/lang/matrix_ops_utils.py +4 -4
- {taichi → gstaichi}/lang/mesh.py +22 -22
- {taichi → gstaichi}/lang/misc.py +39 -68
- {taichi → gstaichi}/lang/ops.py +146 -141
- {taichi → gstaichi}/lang/runtime_ops.py +2 -2
- {taichi → gstaichi}/lang/shell.py +3 -3
- {taichi → gstaichi}/lang/simt/__init__.py +1 -1
- {taichi → gstaichi}/lang/simt/block.py +7 -7
- {taichi → gstaichi}/lang/simt/grid.py +1 -1
- {taichi → gstaichi}/lang/simt/subgroup.py +1 -1
- {taichi → gstaichi}/lang/simt/warp.py +1 -1
- {taichi → gstaichi}/lang/snode.py +46 -44
- {taichi → gstaichi}/lang/source_builder.py +13 -13
- {taichi → gstaichi}/lang/struct.py +33 -33
- {taichi → gstaichi}/lang/util.py +24 -24
- gstaichi/linalg/__init__.py +8 -0
- {taichi → gstaichi}/linalg/matrixfree_cg.py +14 -14
- {taichi → gstaichi}/linalg/sparse_cg.py +10 -10
- {taichi → gstaichi}/linalg/sparse_matrix.py +23 -23
- {taichi → gstaichi}/linalg/sparse_solver.py +21 -21
- {taichi → gstaichi}/math/__init__.py +1 -1
- {taichi → gstaichi}/math/_complex.py +21 -20
- {taichi → gstaichi}/math/mathimpl.py +56 -56
- gstaichi/profiler/__init__.py +6 -0
- {taichi → gstaichi}/profiler/kernel_metrics.py +11 -11
- {taichi → gstaichi}/profiler/kernel_profiler.py +30 -36
- {taichi → gstaichi}/profiler/memory_profiler.py +1 -1
- {taichi → gstaichi}/profiler/scoped_profiler.py +2 -2
- {taichi → gstaichi}/sparse/_sparse_grid.py +7 -7
- {taichi → gstaichi}/tools/__init__.py +4 -4
- {taichi → gstaichi}/tools/diagnose.py +10 -17
- gstaichi/types/__init__.py +19 -0
- {taichi → gstaichi}/types/annotations.py +1 -1
- {taichi → gstaichi}/types/compound_types.py +8 -8
- {taichi → gstaichi}/types/enums.py +1 -1
- {taichi → gstaichi}/types/ndarray_type.py +7 -7
- {taichi → gstaichi}/types/primitive_types.py +17 -14
- {taichi → gstaichi}/types/quant.py +9 -9
- {taichi → gstaichi}/types/texture_type.py +5 -5
- {taichi → gstaichi}/types/utils.py +1 -1
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/bin/SPIRV-Tools-shared.dll +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools-diff.lib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools-link.lib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools-lint.lib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools-opt.lib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools-reduce.lib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools-shared.lib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/SPIRV-Tools.lib +0 -0
- {gstaichi-0.1.23.dev0.dist-info → gstaichi-1.0.1.dist-info}/METADATA +13 -16
- gstaichi-1.0.1.dist-info/RECORD +135 -0
- gstaichi-1.0.1.dist-info/top_level.txt +1 -0
- gstaichi-0.1.23.dev0.data/data/include/GLFW/glfw3.h +0 -6389
- gstaichi-0.1.23.dev0.data/data/include/GLFW/glfw3native.h +0 -594
- gstaichi-0.1.23.dev0.data/data/lib/cmake/glfw3/glfw3Config.cmake +0 -3
- gstaichi-0.1.23.dev0.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +0 -65
- gstaichi-0.1.23.dev0.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +0 -19
- gstaichi-0.1.23.dev0.data/data/lib/cmake/glfw3/glfw3Targets.cmake +0 -107
- gstaichi-0.1.23.dev0.data/data/lib/glfw3.lib +0 -0
- gstaichi-0.1.23.dev0.dist-info/RECORD +0 -198
- gstaichi-0.1.23.dev0.dist-info/entry_points.txt +0 -2
- gstaichi-0.1.23.dev0.dist-info/top_level.txt +0 -1
- taichi/CHANGELOG.md +0 -20
- taichi/__init__.py +0 -44
- taichi/__main__.py +0 -5
- taichi/_lib/__init__.py +0 -3
- taichi/_lib/c_api/bin/taichi_c_api.dll +0 -0
- taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +0 -1401
- taichi/_lib/c_api/include/taichi/taichi.h +0 -29
- taichi/_lib/c_api/include/taichi/taichi_core.h +0 -1111
- taichi/_lib/c_api/include/taichi/taichi_cpu.h +0 -29
- taichi/_lib/c_api/include/taichi/taichi_cuda.h +0 -36
- taichi/_lib/c_api/include/taichi/taichi_platform.h +0 -55
- taichi/_lib/c_api/include/taichi/taichi_unity.h +0 -64
- taichi/_lib/c_api/include/taichi/taichi_vulkan.h +0 -151
- taichi/_lib/c_api/lib/taichi_c_api.lib +0 -0
- taichi/_lib/c_api/runtime/runtime_cuda.bc +0 -0
- taichi/_lib/c_api/runtime/runtime_x64.bc +0 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +0 -29
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +0 -65
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +0 -121
- taichi/_lib/runtime/slim_libdevice.10.bc +0 -0
- taichi/_main.py +0 -552
- taichi/_snode/__init__.py +0 -5
- taichi/_ti_module/__init__.py +0 -3
- taichi/_ti_module/cppgen.py +0 -309
- taichi/_ti_module/module.py +0 -145
- taichi/_version.py +0 -1
- taichi/ad/__init__.py +0 -3
- taichi/aot/__init__.py +0 -12
- taichi/aot/_export.py +0 -28
- taichi/aot/conventions/__init__.py +0 -3
- taichi/aot/conventions/gfxruntime140/__init__.py +0 -38
- taichi/aot/conventions/gfxruntime140/dr.py +0 -244
- taichi/aot/conventions/gfxruntime140/sr.py +0 -613
- taichi/aot/module.py +0 -253
- taichi/aot/utils.py +0 -151
- taichi/graph/__init__.py +0 -3
- taichi/graph/_graph.py +0 -292
- taichi/lang/__init__.py +0 -50
- taichi/lang/ast/__init__.py +0 -5
- taichi/lang/ast/transform.py +0 -9
- taichi/lang/exception.py +0 -80
- taichi/linalg/__init__.py +0 -8
- taichi/profiler/__init__.py +0 -6
- taichi/shaders/Circles_vk.frag +0 -29
- taichi/shaders/Circles_vk.vert +0 -45
- taichi/shaders/Circles_vk_frag.spv +0 -0
- taichi/shaders/Circles_vk_vert.spv +0 -0
- taichi/shaders/Lines_vk.frag +0 -9
- taichi/shaders/Lines_vk.vert +0 -11
- taichi/shaders/Lines_vk_frag.spv +0 -0
- taichi/shaders/Lines_vk_vert.spv +0 -0
- taichi/shaders/Mesh_vk.frag +0 -71
- taichi/shaders/Mesh_vk.vert +0 -68
- taichi/shaders/Mesh_vk_frag.spv +0 -0
- taichi/shaders/Mesh_vk_vert.spv +0 -0
- taichi/shaders/Particles_vk.frag +0 -95
- taichi/shaders/Particles_vk.vert +0 -73
- taichi/shaders/Particles_vk_frag.spv +0 -0
- taichi/shaders/Particles_vk_vert.spv +0 -0
- taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
- taichi/shaders/SceneLines_vk.frag +0 -9
- taichi/shaders/SceneLines_vk.vert +0 -12
- taichi/shaders/SceneLines_vk_frag.spv +0 -0
- taichi/shaders/SceneLines_vk_vert.spv +0 -0
- taichi/shaders/SetImage_vk.frag +0 -21
- taichi/shaders/SetImage_vk.vert +0 -15
- taichi/shaders/SetImage_vk_frag.spv +0 -0
- taichi/shaders/SetImage_vk_vert.spv +0 -0
- taichi/shaders/Triangles_vk.frag +0 -16
- taichi/shaders/Triangles_vk.vert +0 -29
- taichi/shaders/Triangles_vk_frag.spv +0 -0
- taichi/shaders/Triangles_vk_vert.spv +0 -0
- taichi/shaders/lines2quad_vk_comp.spv +0 -0
- taichi/types/__init__.py +0 -19
- {taichi → gstaichi}/_lib/core/__init__.py +0 -0
- {taichi → gstaichi}/_lib/core/py.typed +0 -0
- {taichi/_lib/c_api → gstaichi/_lib}/runtime/slim_libdevice.10.bc +0 -0
- {taichi → gstaichi}/algorithms/__init__.py +0 -0
- {taichi → gstaichi}/assets/.git +0 -0
- {taichi → gstaichi}/assets/Go-Regular.ttf +0 -0
- {taichi → gstaichi}/assets/static/imgs/ti_gallery.png +0 -0
- {taichi → gstaichi}/lang/ast/symbol_resolver.py +0 -0
- {taichi → gstaichi}/sparse/__init__.py +0 -0
- {taichi → gstaichi}/tools/np2ply.py +0 -0
- {taichi → gstaichi}/tools/vtk.py +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/instrument.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/libspirv.h +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/libspirv.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/linker.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/optimizer.hpp +0 -0
- {gstaichi-0.1.23.dev0.dist-info → gstaichi-1.0.1.dist-info}/WHEEL +0 -0
- {gstaichi-0.1.23.dev0.dist-info → gstaichi-1.0.1.dist-info}/licenses/LICENSE +0 -0
@@ -2,46 +2,42 @@
|
|
2
2
|
|
3
3
|
import ast
|
4
4
|
import collections.abc
|
5
|
-
import dataclasses
|
6
|
-
import inspect
|
7
5
|
import itertools
|
8
|
-
import math
|
9
|
-
import operator
|
10
|
-
import re
|
11
6
|
import warnings
|
12
7
|
from ast import unparse
|
13
|
-
from collections import ChainMap
|
14
8
|
from typing import Any, Iterable, Type
|
15
9
|
|
16
10
|
import numpy as np
|
17
11
|
|
18
|
-
from
|
19
|
-
from
|
20
|
-
from
|
21
|
-
from
|
22
|
-
from
|
23
|
-
from taichi.lang.ast.ast_transformer_utils import (
|
12
|
+
from gstaichi._lib import core as _ti_core
|
13
|
+
from gstaichi.lang import expr, impl, matrix, mesh
|
14
|
+
from gstaichi.lang import ops as ti_ops
|
15
|
+
from gstaichi.lang._ndrange import _Ndrange
|
16
|
+
from gstaichi.lang.ast.ast_transformer_utils import (
|
24
17
|
ASTTransformerContext,
|
25
18
|
Builder,
|
26
19
|
LoopStatus,
|
27
20
|
ReturnStatus,
|
21
|
+
get_decorator,
|
28
22
|
)
|
29
|
-
from
|
30
|
-
from
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
23
|
+
from gstaichi.lang.ast.ast_transformers.call_transformer import CallTransformer
|
24
|
+
from gstaichi.lang.ast.ast_transformers.function_def_transformer import (
|
25
|
+
FunctionDefTransformer,
|
26
|
+
)
|
27
|
+
from gstaichi.lang.exception import (
|
28
|
+
GsTaichiIndexError,
|
29
|
+
GsTaichiRuntimeTypeError,
|
30
|
+
GsTaichiSyntaxError,
|
31
|
+
GsTaichiTypeError,
|
35
32
|
handle_exception_from_cpp,
|
36
33
|
)
|
37
|
-
from
|
38
|
-
from
|
39
|
-
from
|
40
|
-
from
|
41
|
-
from
|
42
|
-
from
|
43
|
-
from
|
44
|
-
from taichi.types.utils import is_integral
|
34
|
+
from gstaichi.lang.expr import Expr, make_expr_group
|
35
|
+
from gstaichi.lang.field import Field
|
36
|
+
from gstaichi.lang.matrix import Matrix, MatrixType
|
37
|
+
from gstaichi.lang.snode import append, deactivate, length
|
38
|
+
from gstaichi.lang.struct import Struct, StructType
|
39
|
+
from gstaichi.types import primitive_types
|
40
|
+
from gstaichi.types.utils import is_integral
|
45
41
|
|
46
42
|
|
47
43
|
def reshape_list(flat_list: list[Any], target_shape: Iterable[int]) -> list[Any]:
|
@@ -108,13 +104,13 @@ class ASTTransformer(Builder):
|
|
108
104
|
"""
|
109
105
|
is_local = isinstance(target, ast.Name)
|
110
106
|
if is_local and target.id in ctx.kernel_args:
|
111
|
-
raise
|
107
|
+
raise GsTaichiSyntaxError(
|
112
108
|
f'Kernel argument "{target.id}" is immutable in the kernel. '
|
113
109
|
f"If you want to change its value, please create a new variable."
|
114
110
|
)
|
115
111
|
anno = impl.expr_init(annotation)
|
116
112
|
if is_static_assign:
|
117
|
-
raise
|
113
|
+
raise GsTaichiSyntaxError("Static assign cannot be used on annotated assignment")
|
118
114
|
if is_local and not ctx.is_var_declared(target.id):
|
119
115
|
var = ti_ops.cast(value, anno)
|
120
116
|
var = impl.expr_init(var)
|
@@ -122,7 +118,7 @@ class ASTTransformer(Builder):
|
|
122
118
|
else:
|
123
119
|
var = build_stmt(ctx, target)
|
124
120
|
if var.ptr.get_rvalue_type() != anno:
|
125
|
-
raise
|
121
|
+
raise GsTaichiSyntaxError("Static assign cannot have type overloading")
|
126
122
|
var._assign(value)
|
127
123
|
return var
|
128
124
|
|
@@ -133,7 +129,7 @@ class ASTTransformer(Builder):
|
|
133
129
|
|
134
130
|
# Keep all generated assign statements and compose single one at last.
|
135
131
|
# The variable is introduced to support chained assignments.
|
136
|
-
# Ref https://github.com/taichi-dev/
|
132
|
+
# Ref https://github.com/taichi-dev/gstaichi/issues/2659.
|
137
133
|
values = node.value.ptr if is_static_assign else impl.expr_init(node.value.ptr)
|
138
134
|
|
139
135
|
for node_target in node.targets:
|
@@ -176,10 +172,10 @@ class ASTTransformer(Builder):
|
|
176
172
|
values = values[0]
|
177
173
|
|
178
174
|
if not isinstance(values, collections.abc.Sequence):
|
179
|
-
raise
|
175
|
+
raise GsTaichiSyntaxError(f"Cannot unpack type: {type(values)}")
|
180
176
|
|
181
177
|
if len(values) != len(targets):
|
182
|
-
raise
|
178
|
+
raise GsTaichiSyntaxError("The number of targets is not equal to value length")
|
183
179
|
|
184
180
|
for i, target in enumerate(targets):
|
185
181
|
ASTTransformer.build_assign_basic(ctx, target, values[i], is_static_assign)
|
@@ -199,13 +195,13 @@ class ASTTransformer(Builder):
|
|
199
195
|
"""
|
200
196
|
is_local = isinstance(target, ast.Name)
|
201
197
|
if is_local and target.id in ctx.kernel_args:
|
202
|
-
raise
|
198
|
+
raise GsTaichiSyntaxError(
|
203
199
|
f'Kernel argument "{target.id}" is immutable in the kernel. '
|
204
200
|
f"If you want to change its value, please create a new variable."
|
205
201
|
)
|
206
202
|
if is_static_assign:
|
207
203
|
if not is_local:
|
208
|
-
raise
|
204
|
+
raise GsTaichiSyntaxError("Static assign cannot be used on elements in arrays")
|
209
205
|
ctx.create_variable(target.id, value)
|
210
206
|
var = value
|
211
207
|
elif is_local and not ctx.is_var_declared(target.id):
|
@@ -216,8 +212,8 @@ class ASTTransformer(Builder):
|
|
216
212
|
try:
|
217
213
|
var._assign(value)
|
218
214
|
except AttributeError:
|
219
|
-
raise
|
220
|
-
f"Variable '{unparse(target).strip()}' cannot be assigned. Maybe it is not a
|
215
|
+
raise GsTaichiSyntaxError(
|
216
|
+
f"Variable '{unparse(target).strip()}' cannot be assigned. Maybe it is not a GsTaichi object?"
|
221
217
|
)
|
222
218
|
return var
|
223
219
|
|
@@ -415,495 +411,25 @@ class ASTTransformer(Builder):
|
|
415
411
|
elif isinstance(sub_node, ast.Str):
|
416
412
|
str_spec += sub_node.s
|
417
413
|
else:
|
418
|
-
raise
|
414
|
+
raise GsTaichiSyntaxError("Invalid value for fstring.")
|
419
415
|
|
420
416
|
args.insert(0, str_spec)
|
421
417
|
node.ptr = impl.ti_format(*args)
|
422
418
|
return node.ptr
|
423
419
|
|
424
420
|
@staticmethod
|
425
|
-
def
|
426
|
-
|
427
|
-
|
428
|
-
func = node.func.ptr
|
429
|
-
replace_func = {
|
430
|
-
id(print): impl.ti_print,
|
431
|
-
id(min): ti_ops.min,
|
432
|
-
id(max): ti_ops.max,
|
433
|
-
id(int): impl.ti_int,
|
434
|
-
id(bool): impl.ti_bool,
|
435
|
-
id(float): impl.ti_float,
|
436
|
-
id(any): matrix_ops.any,
|
437
|
-
id(all): matrix_ops.all,
|
438
|
-
id(abs): abs,
|
439
|
-
id(pow): pow,
|
440
|
-
id(operator.matmul): matrix_ops.matmul,
|
441
|
-
}
|
442
|
-
|
443
|
-
# Builtin 'len' function on Matrix Expr
|
444
|
-
if id(func) == id(len) and len(args) == 1:
|
445
|
-
if isinstance(args[0], Expr) and args[0].ptr.is_tensor():
|
446
|
-
node.ptr = args[0].get_shape()[0]
|
447
|
-
return True
|
448
|
-
|
449
|
-
if id(func) in replace_func:
|
450
|
-
node.ptr = replace_func[id(func)](*args, **keywords)
|
451
|
-
return True
|
452
|
-
return False
|
453
|
-
|
454
|
-
@staticmethod
|
455
|
-
def build_call_if_is_type(ctx: ASTTransformerContext, node, args, keywords):
|
456
|
-
func = node.func.ptr
|
457
|
-
if id(func) in primitive_types.type_ids:
|
458
|
-
if len(args) != 1 or keywords:
|
459
|
-
raise TaichiSyntaxError("A primitive type can only decorate a single expression.")
|
460
|
-
if is_taichi_class(args[0]):
|
461
|
-
raise TaichiSyntaxError("A primitive type cannot decorate an expression with a compound type.")
|
462
|
-
|
463
|
-
if isinstance(args[0], expr.Expr):
|
464
|
-
if args[0].ptr.is_tensor():
|
465
|
-
raise TaichiSyntaxError("A primitive type cannot decorate an expression with a compound type.")
|
466
|
-
node.ptr = ti_ops.cast(args[0], func)
|
467
|
-
else:
|
468
|
-
node.ptr = expr.Expr(args[0], dtype=func)
|
469
|
-
return True
|
470
|
-
return False
|
471
|
-
|
472
|
-
@staticmethod
|
473
|
-
def is_external_func(ctx: ASTTransformerContext, func) -> bool:
|
474
|
-
if ctx.is_in_static_scope(): # allow external function in static scope
|
475
|
-
return False
|
476
|
-
if hasattr(func, "_is_taichi_function") or hasattr(func, "_is_wrapped_kernel"): # taichi func/kernel
|
477
|
-
return False
|
478
|
-
if hasattr(func, "__module__") and func.__module__ and func.__module__.startswith("taichi."):
|
479
|
-
return False
|
480
|
-
return True
|
481
|
-
|
482
|
-
@staticmethod
|
483
|
-
def warn_if_is_external_func(ctx: ASTTransformerContext, node):
|
484
|
-
func = node.func.ptr
|
485
|
-
if not ASTTransformer.is_external_func(ctx, func):
|
486
|
-
return
|
487
|
-
name = unparse(node.func).strip()
|
488
|
-
warnings.warn_explicit(
|
489
|
-
f"\x1b[38;5;226m" # Yellow
|
490
|
-
f'Calling non-taichi function "{name}". '
|
491
|
-
f"Scope inside the function is not processed by the Taichi AST transformer. "
|
492
|
-
f"The function may not work as expected. Proceed with caution! "
|
493
|
-
f"Maybe you can consider turning it into a @ti.func?"
|
494
|
-
f"\x1b[0m", # Reset
|
495
|
-
SyntaxWarning,
|
496
|
-
ctx.file,
|
497
|
-
node.lineno + ctx.lineno_offset,
|
498
|
-
module="taichi",
|
499
|
-
)
|
421
|
+
def build_Call(ctx: ASTTransformerContext, node: ast.Call) -> Any | None:
|
422
|
+
return CallTransformer.build_Call(ctx, node, build_stmt, build_stmts)
|
500
423
|
|
501
424
|
@staticmethod
|
502
|
-
|
503
|
-
|
504
|
-
# Note that the arguments can be used multiple times in the string.
|
505
|
-
# e.g.:
|
506
|
-
# origin input: 'qwerty {1} {} {1:.3f} {k:.4f} {k:}'.format(1.0, 2.0, k=k)
|
507
|
-
# raw_string: 'qwerty {1} {} {1:.3f} {k:.4f} {k:}'
|
508
|
-
# raw_args: [1.0, 2.0]
|
509
|
-
# raw_keywords: {'k': <ti.Expr>}
|
510
|
-
# return value: ['qwerty {} {} {} {} {}', 2.0, 1.0, ['__ti_fmt_value__', 2.0, '.3f'], ['__ti_fmt_value__', <ti.Expr>, '.4f'], <ti.Expr>]
|
511
|
-
def canonicalize_formatted_string(raw_string: str, *raw_args: list, **raw_keywords: dict):
|
512
|
-
raw_brackets = re.findall(r"{(.*?)}", raw_string)
|
513
|
-
brackets = []
|
514
|
-
unnamed = 0
|
515
|
-
for bracket in raw_brackets:
|
516
|
-
item, spec = bracket.split(":") if ":" in bracket else (bracket, None)
|
517
|
-
if item.isdigit():
|
518
|
-
item = int(item)
|
519
|
-
# handle unnamed positional args
|
520
|
-
if item == "":
|
521
|
-
item = unnamed
|
522
|
-
unnamed += 1
|
523
|
-
# handle empty spec
|
524
|
-
if spec == "":
|
525
|
-
spec = None
|
526
|
-
brackets.append((item, spec))
|
527
|
-
|
528
|
-
# check for errors in the arguments
|
529
|
-
max_args_index = max([t[0] for t in brackets if isinstance(t[0], int)], default=-1)
|
530
|
-
if max_args_index + 1 != len(raw_args):
|
531
|
-
raise TaichiSyntaxError(
|
532
|
-
f"Expected {max_args_index + 1} positional argument(s), but received {len(raw_args)} instead."
|
533
|
-
)
|
534
|
-
brackets_keywords = [t[0] for t in brackets if isinstance(t[0], str)]
|
535
|
-
for item in brackets_keywords:
|
536
|
-
if item not in raw_keywords:
|
537
|
-
raise TaichiSyntaxError(f"Keyword '{item}' not found.")
|
538
|
-
for item in raw_keywords:
|
539
|
-
if item not in brackets_keywords:
|
540
|
-
raise TaichiSyntaxError(f"Keyword '{item}' not used.")
|
541
|
-
|
542
|
-
# reorganize the arguments based on their positions, keywords, and format specifiers
|
543
|
-
args = []
|
544
|
-
for item, spec in brackets:
|
545
|
-
new_arg = raw_args[item] if isinstance(item, int) else raw_keywords[item]
|
546
|
-
if spec is not None:
|
547
|
-
args.append(["__ti_fmt_value__", new_arg, spec])
|
548
|
-
else:
|
549
|
-
args.append(new_arg)
|
550
|
-
# put the formatted string as the first argument to make ti.format() happy
|
551
|
-
args.insert(0, re.sub(r"{.*?}", "{}", raw_string))
|
552
|
-
return args
|
553
|
-
|
554
|
-
@staticmethod
|
555
|
-
def expand_node_args_dataclasses(args: tuple[ast.AST, ...]) -> tuple[ast.AST, ...]:
|
556
|
-
args_new = []
|
557
|
-
for arg in args:
|
558
|
-
val = arg.ptr
|
559
|
-
if dataclasses.is_dataclass(val):
|
560
|
-
dataclass_type = val
|
561
|
-
for field in dataclasses.fields(dataclass_type):
|
562
|
-
child_name = f"__ti_{arg.id}_{field.name}"
|
563
|
-
load_ctx = ast.Load()
|
564
|
-
arg_node = ast.Name(
|
565
|
-
id=child_name,
|
566
|
-
ctx=load_ctx,
|
567
|
-
lineno=arg.lineno,
|
568
|
-
end_lineno=arg.end_lineno,
|
569
|
-
col_offset=arg.col_offset,
|
570
|
-
end_col_offset=arg.end_col_offset,
|
571
|
-
)
|
572
|
-
args_new.append(arg_node)
|
573
|
-
else:
|
574
|
-
args_new.append(arg)
|
575
|
-
return tuple(args_new)
|
576
|
-
|
577
|
-
@staticmethod
|
578
|
-
def build_Call(ctx: ASTTransformerContext, node: ast.Call):
|
579
|
-
if ASTTransformer.get_decorator(ctx, node) in ["static", "static_assert"]:
|
580
|
-
with ctx.static_scope_guard():
|
581
|
-
build_stmt(ctx, node.func)
|
582
|
-
build_stmts(ctx, node.args)
|
583
|
-
build_stmts(ctx, node.keywords)
|
584
|
-
else:
|
585
|
-
build_stmt(ctx, node.func)
|
586
|
-
# creates variable for the dataclass itself (as well as other variables,
|
587
|
-
# not related to dataclasses). Necessary for calling further child functions
|
588
|
-
build_stmts(ctx, node.args)
|
589
|
-
node.args = ASTTransformer.expand_node_args_dataclasses(node.args)
|
590
|
-
# create variables for the now-expanded dataclass members
|
591
|
-
build_stmts(ctx, node.args)
|
592
|
-
build_stmts(ctx, node.keywords)
|
593
|
-
|
594
|
-
args = []
|
595
|
-
for arg in node.args:
|
596
|
-
if isinstance(arg, ast.Starred):
|
597
|
-
arg_list = arg.ptr
|
598
|
-
if isinstance(arg_list, Expr) and arg_list.is_tensor():
|
599
|
-
# Expand Expr with Matrix-type return into list of Exprs
|
600
|
-
arg_list = [Expr(x) for x in ctx.ast_builder.expand_exprs([arg_list.ptr])]
|
601
|
-
|
602
|
-
for i in arg_list:
|
603
|
-
args.append(i)
|
604
|
-
else:
|
605
|
-
args.append(arg.ptr)
|
606
|
-
keywords = dict(ChainMap(*[keyword.ptr for keyword in node.keywords]))
|
607
|
-
func = node.func.ptr
|
608
|
-
|
609
|
-
if id(func) in [id(print), id(impl.ti_print)]:
|
610
|
-
ctx.func.has_print = True
|
611
|
-
|
612
|
-
if isinstance(node.func, ast.Attribute) and isinstance(node.func.value.ptr, str) and node.func.attr == "format":
|
613
|
-
raw_string = node.func.value.ptr
|
614
|
-
args = ASTTransformer.canonicalize_formatted_string(raw_string, *args, **keywords)
|
615
|
-
node.ptr = impl.ti_format(*args)
|
616
|
-
return node.ptr
|
617
|
-
|
618
|
-
if id(func) == id(Matrix) or id(func) == id(Vector):
|
619
|
-
node.ptr = matrix.make_matrix(*args, **keywords)
|
620
|
-
return node.ptr
|
621
|
-
|
622
|
-
if ASTTransformer.build_call_if_is_builtin(ctx, node, args, keywords):
|
623
|
-
return node.ptr
|
624
|
-
|
625
|
-
if ASTTransformer.build_call_if_is_type(ctx, node, args, keywords):
|
626
|
-
return node.ptr
|
627
|
-
|
628
|
-
if hasattr(node.func, "caller"):
|
629
|
-
node.ptr = func(node.func.caller, *args, **keywords)
|
630
|
-
return node.ptr
|
631
|
-
ASTTransformer.warn_if_is_external_func(ctx, node)
|
632
|
-
try:
|
633
|
-
node.ptr = func(*args, **keywords)
|
634
|
-
except TypeError as e:
|
635
|
-
module = inspect.getmodule(func)
|
636
|
-
error_msg = re.sub(r"\bExpr\b", "Taichi Expression", str(e))
|
637
|
-
msg = f"TypeError when calling `{func.__name__}`: {error_msg}."
|
638
|
-
if ASTTransformer.is_external_func(ctx, node.func.ptr):
|
639
|
-
args_has_expr = any([isinstance(arg, Expr) for arg in args])
|
640
|
-
if args_has_expr and (module == math or module == np):
|
641
|
-
exec_str = f"from taichi import {func.__name__}"
|
642
|
-
try:
|
643
|
-
exec(exec_str, {})
|
644
|
-
except:
|
645
|
-
pass
|
646
|
-
else:
|
647
|
-
msg += f"\nDid you mean to use `ti.{func.__name__}` instead of `{module.__name__}.{func.__name__}`?"
|
648
|
-
raise TaichiTypeError(msg)
|
649
|
-
|
650
|
-
if getattr(func, "_is_taichi_function", False):
|
651
|
-
ctx.func.has_print |= func.func.has_print
|
652
|
-
|
653
|
-
return node.ptr
|
654
|
-
|
655
|
-
@staticmethod
|
656
|
-
def build_FunctionDef(ctx: ASTTransformerContext, node: ast.FunctionDef):
|
657
|
-
if ctx.visited_funcdef:
|
658
|
-
raise TaichiSyntaxError(
|
659
|
-
f"Function definition is not allowed in 'ti.{'kernel' if ctx.is_kernel else 'func'}'."
|
660
|
-
)
|
661
|
-
ctx.visited_funcdef = True
|
662
|
-
|
663
|
-
args = node.args
|
664
|
-
assert args.vararg is None
|
665
|
-
assert args.kwonlyargs == []
|
666
|
-
assert args.kw_defaults == []
|
667
|
-
assert args.kwarg is None
|
668
|
-
|
669
|
-
def decl_and_create_variable(
|
670
|
-
annotation, name, arg_features, invoke_later_dict, prefix_name, arg_depth
|
671
|
-
) -> tuple[bool, Any]:
|
672
|
-
full_name = prefix_name + "_" + name
|
673
|
-
if not isinstance(annotation, primitive_types.RefType):
|
674
|
-
ctx.kernel_args.append(name)
|
675
|
-
if isinstance(annotation, ArgPackType):
|
676
|
-
kernel_arguments.push_argpack_arg(name)
|
677
|
-
d = {}
|
678
|
-
items_to_put_in_dict = []
|
679
|
-
for j, (_name, anno) in enumerate(annotation.members.items()):
|
680
|
-
result, obj = decl_and_create_variable(
|
681
|
-
anno, _name, arg_features[j], invoke_later_dict, full_name, arg_depth + 1
|
682
|
-
)
|
683
|
-
if not result:
|
684
|
-
d[_name] = None
|
685
|
-
items_to_put_in_dict.append((full_name + "_" + _name, _name, obj))
|
686
|
-
else:
|
687
|
-
d[_name] = obj
|
688
|
-
argpack = kernel_arguments.decl_argpack_arg(annotation, d)
|
689
|
-
for item in items_to_put_in_dict:
|
690
|
-
invoke_later_dict[item[0]] = argpack, item[1], *item[2]
|
691
|
-
return True, argpack
|
692
|
-
if annotation == annotations.template or isinstance(annotation, annotations.template):
|
693
|
-
return True, ctx.global_vars[name]
|
694
|
-
if isinstance(annotation, annotations.sparse_matrix_builder):
|
695
|
-
return False, (
|
696
|
-
kernel_arguments.decl_sparse_matrix,
|
697
|
-
(
|
698
|
-
to_taichi_type(arg_features),
|
699
|
-
full_name,
|
700
|
-
),
|
701
|
-
)
|
702
|
-
if isinstance(annotation, ndarray_type.NdarrayType):
|
703
|
-
return False, (
|
704
|
-
kernel_arguments.decl_ndarray_arg,
|
705
|
-
(
|
706
|
-
to_taichi_type(arg_features[0]),
|
707
|
-
arg_features[1],
|
708
|
-
full_name,
|
709
|
-
arg_features[2],
|
710
|
-
arg_features[3],
|
711
|
-
),
|
712
|
-
)
|
713
|
-
if isinstance(annotation, texture_type.TextureType):
|
714
|
-
return False, (kernel_arguments.decl_texture_arg, (arg_features[0], full_name))
|
715
|
-
if isinstance(annotation, texture_type.RWTextureType):
|
716
|
-
return False, (
|
717
|
-
kernel_arguments.decl_rw_texture_arg,
|
718
|
-
(arg_features[0], arg_features[1], arg_features[2], full_name),
|
719
|
-
)
|
720
|
-
if isinstance(annotation, MatrixType):
|
721
|
-
return True, kernel_arguments.decl_matrix_arg(annotation, name, arg_depth)
|
722
|
-
if isinstance(annotation, StructType):
|
723
|
-
return True, kernel_arguments.decl_struct_arg(annotation, name, arg_depth)
|
724
|
-
return True, kernel_arguments.decl_scalar_arg(annotation, name, arg_depth)
|
725
|
-
|
726
|
-
def transform_as_kernel() -> None:
|
727
|
-
if node.returns is not None:
|
728
|
-
if not isinstance(node.returns, ast.Constant):
|
729
|
-
for return_type in ctx.func.return_type:
|
730
|
-
kernel_arguments.decl_ret(return_type)
|
731
|
-
impl.get_runtime().compiling_callable.finalize_rets()
|
732
|
-
|
733
|
-
invoke_later_dict: dict[str, tuple[Any, str, Any]] = dict()
|
734
|
-
create_variable_later = dict()
|
735
|
-
for i, arg in enumerate(args.args):
|
736
|
-
argument = ctx.func.arguments[i]
|
737
|
-
if isinstance(argument.annotation, ArgPackType):
|
738
|
-
kernel_arguments.push_argpack_arg(argument.name)
|
739
|
-
d = {}
|
740
|
-
items_to_put_in_dict: list[tuple[str, str, Any]] = []
|
741
|
-
for j, (name, anno) in enumerate(argument.annotation.members.items()):
|
742
|
-
result, obj = decl_and_create_variable(
|
743
|
-
anno, name, ctx.arg_features[i][j], invoke_later_dict, "__argpack_" + name, 1
|
744
|
-
)
|
745
|
-
if not result:
|
746
|
-
d[name] = None
|
747
|
-
items_to_put_in_dict.append(("__argpack_" + name, name, obj))
|
748
|
-
else:
|
749
|
-
d[name] = obj
|
750
|
-
argpack = kernel_arguments.decl_argpack_arg(ctx.func.arguments[i].annotation, d)
|
751
|
-
for item in items_to_put_in_dict:
|
752
|
-
invoke_later_dict[item[0]] = argpack, item[1], *item[2]
|
753
|
-
create_variable_later[arg.arg] = argpack
|
754
|
-
elif dataclasses.is_dataclass(argument.annotation):
|
755
|
-
arg_features = ctx.arg_features[i]
|
756
|
-
ctx.create_variable(argument.name, argument.annotation)
|
757
|
-
for field_idx, field in enumerate(dataclasses.fields(argument.annotation)):
|
758
|
-
flat_name = f"__ti_{argument.name}_{field.name}"
|
759
|
-
result, obj = decl_and_create_variable(
|
760
|
-
field.type,
|
761
|
-
flat_name,
|
762
|
-
arg_features[field_idx],
|
763
|
-
invoke_later_dict,
|
764
|
-
"",
|
765
|
-
0,
|
766
|
-
)
|
767
|
-
if result:
|
768
|
-
ctx.create_variable(flat_name, obj)
|
769
|
-
else:
|
770
|
-
decl_type_func, type_args = obj
|
771
|
-
obj = decl_type_func(*type_args)
|
772
|
-
ctx.create_variable(flat_name, obj)
|
773
|
-
else:
|
774
|
-
result, obj = decl_and_create_variable(
|
775
|
-
argument.annotation,
|
776
|
-
argument.name,
|
777
|
-
ctx.arg_features[i] if ctx.arg_features is not None else None,
|
778
|
-
invoke_later_dict,
|
779
|
-
"",
|
780
|
-
0,
|
781
|
-
)
|
782
|
-
if result:
|
783
|
-
ctx.create_variable(arg.arg, obj)
|
784
|
-
else:
|
785
|
-
decl_type_func, type_args = obj
|
786
|
-
obj = decl_type_func(*type_args)
|
787
|
-
ctx.create_variable(arg.arg, obj)
|
788
|
-
for k, v in invoke_later_dict.items():
|
789
|
-
argpack, name, func, params = v
|
790
|
-
argpack[name] = func(*params)
|
791
|
-
for k, v in create_variable_later.items():
|
792
|
-
ctx.create_variable(k, v)
|
793
|
-
|
794
|
-
impl.get_runtime().compiling_callable.finalize_params()
|
795
|
-
# remove original args
|
796
|
-
node.args.args = []
|
797
|
-
|
798
|
-
if ctx.is_kernel: # ti.kernel
|
799
|
-
transform_as_kernel()
|
800
|
-
|
801
|
-
else: # ti.func
|
802
|
-
if ctx.is_real_function:
|
803
|
-
transform_as_kernel()
|
804
|
-
else:
|
805
|
-
for data_i, data in enumerate(ctx.argument_data):
|
806
|
-
argument = ctx.func.arguments[data_i]
|
807
|
-
if isinstance(argument.annotation, annotations.template):
|
808
|
-
ctx.create_variable(argument.name, data)
|
809
|
-
continue
|
810
|
-
|
811
|
-
elif dataclasses.is_dataclass(argument.annotation):
|
812
|
-
dataclass_type = argument.annotation
|
813
|
-
for field in dataclasses.fields(dataclass_type):
|
814
|
-
data_child = getattr(data, field.name)
|
815
|
-
if not isinstance(
|
816
|
-
data_child,
|
817
|
-
(
|
818
|
-
_ndarray.ScalarNdarray,
|
819
|
-
matrix.VectorNdarray,
|
820
|
-
matrix.MatrixNdarray,
|
821
|
-
any_array.AnyArray,
|
822
|
-
),
|
823
|
-
):
|
824
|
-
raise TaichiSyntaxError(
|
825
|
-
f"Argument {argument.name} of type {dataclass_type} {field.type} is not recognized."
|
826
|
-
)
|
827
|
-
field.type.check_matched(data_child.get_type(), field.name)
|
828
|
-
var_name = f"__ti_{argument.name}_{field.name}"
|
829
|
-
ctx.create_variable(var_name, data_child)
|
830
|
-
continue
|
831
|
-
|
832
|
-
# Ndarray arguments are passed by reference.
|
833
|
-
if isinstance(argument.annotation, (ndarray_type.NdarrayType)):
|
834
|
-
if not isinstance(
|
835
|
-
data,
|
836
|
-
(
|
837
|
-
_ndarray.ScalarNdarray,
|
838
|
-
matrix.VectorNdarray,
|
839
|
-
matrix.MatrixNdarray,
|
840
|
-
any_array.AnyArray,
|
841
|
-
),
|
842
|
-
):
|
843
|
-
raise TaichiSyntaxError(
|
844
|
-
f"Argument {arg.arg} of type {argument.annotation} is not recognized."
|
845
|
-
)
|
846
|
-
argument.annotation.check_matched(data.get_type(), argument.name)
|
847
|
-
ctx.create_variable(argument.name, data)
|
848
|
-
continue
|
849
|
-
|
850
|
-
# Matrix arguments are passed by value.
|
851
|
-
if isinstance(argument.annotation, (MatrixType)):
|
852
|
-
var_name = argument.name
|
853
|
-
# "data" is expected to be an Expr here,
|
854
|
-
# so we simply call "impl.expr_init_func(data)" to perform:
|
855
|
-
#
|
856
|
-
# TensorType* t = alloca()
|
857
|
-
# assign(t, data)
|
858
|
-
#
|
859
|
-
# We created local variable "t" - a copy of the passed-in argument "data"
|
860
|
-
if not isinstance(data, expr.Expr) or not data.ptr.is_tensor():
|
861
|
-
raise TaichiSyntaxError(
|
862
|
-
f"Argument {var_name} of type {argument.annotation} is expected to be a Matrix, but got {type(data)}."
|
863
|
-
)
|
864
|
-
|
865
|
-
element_shape = data.ptr.get_rvalue_type().shape()
|
866
|
-
if len(element_shape) != argument.annotation.ndim:
|
867
|
-
raise TaichiSyntaxError(
|
868
|
-
f"Argument {var_name} of type {argument.annotation} is expected to be a Matrix with ndim {argument.annotation.ndim}, but got {len(element_shape)}."
|
869
|
-
)
|
870
|
-
|
871
|
-
assert argument.annotation.ndim > 0
|
872
|
-
if element_shape[0] != argument.annotation.n:
|
873
|
-
raise TaichiSyntaxError(
|
874
|
-
f"Argument {var_name} of type {argument.annotation} is expected to be a Matrix with n {argument.annotation.n}, but got {element_shape[0]}."
|
875
|
-
)
|
876
|
-
|
877
|
-
if argument.annotation.ndim == 2 and element_shape[1] != argument.annotation.m:
|
878
|
-
raise TaichiSyntaxError(
|
879
|
-
f"Argument {var_name} of type {argument.annotation} is expected to be a Matrix with m {argument.annotation.m}, but got {element_shape[0]}."
|
880
|
-
)
|
881
|
-
|
882
|
-
ctx.create_variable(var_name, impl.expr_init_func(data))
|
883
|
-
continue
|
884
|
-
|
885
|
-
if id(argument.annotation) in primitive_types.type_ids:
|
886
|
-
var_name = argument.name
|
887
|
-
ctx.create_variable(var_name, impl.expr_init_func(ti_ops.cast(data, argument.annotation)))
|
888
|
-
continue
|
889
|
-
# Create a copy for non-template arguments,
|
890
|
-
# so that they are passed by value.
|
891
|
-
var_name = argument.name
|
892
|
-
ctx.create_variable(var_name, impl.expr_init_func(data))
|
893
|
-
for v in ctx.func.orig_arguments:
|
894
|
-
if dataclasses.is_dataclass(v.annotation):
|
895
|
-
ctx.create_variable(v.name, v.annotation)
|
896
|
-
|
897
|
-
with ctx.variable_scope_guard():
|
898
|
-
build_stmts(ctx, node.body)
|
899
|
-
|
900
|
-
return None
|
425
|
+
def build_FunctionDef(ctx: ASTTransformerContext, node: ast.FunctionDef) -> None:
|
426
|
+
FunctionDefTransformer.build_FunctionDef(ctx, node, build_stmts)
|
901
427
|
|
902
428
|
@staticmethod
|
903
429
|
def build_Return(ctx: ASTTransformerContext, node: ast.Return) -> None:
|
904
430
|
if not ctx.is_real_function:
|
905
431
|
if ctx.is_in_non_static_control_flow():
|
906
|
-
raise
|
432
|
+
raise GsTaichiSyntaxError("Return inside non-static if/for is not supported")
|
907
433
|
if node.value is not None:
|
908
434
|
build_stmt(ctx, node.value)
|
909
435
|
if node.value is None or node.value.ptr is None:
|
@@ -911,9 +437,9 @@ class ASTTransformer(Builder):
|
|
911
437
|
ctx.returned = ReturnStatus.ReturnedVoid
|
912
438
|
return None
|
913
439
|
if ctx.is_kernel or ctx.is_real_function:
|
914
|
-
# TODO: check if it's at the end of a kernel, throw
|
440
|
+
# TODO: check if it's at the end of a kernel, throw GsTaichiSyntaxError if not
|
915
441
|
if ctx.func.return_type is None:
|
916
|
-
raise
|
442
|
+
raise GsTaichiSyntaxError(
|
917
443
|
f'A {"kernel" if ctx.is_kernel else "function"} '
|
918
444
|
"with a return value must be annotated "
|
919
445
|
"with a return type, e.g. def func() -> ti.f32"
|
@@ -926,19 +452,19 @@ class ASTTransformer(Builder):
|
|
926
452
|
if id(return_type) in primitive_types.type_ids:
|
927
453
|
if isinstance(ptr, Expr):
|
928
454
|
if ptr.is_tensor() or ptr.is_struct() or ptr.element_type() not in primitive_types.all_types:
|
929
|
-
raise
|
455
|
+
raise GsTaichiRuntimeTypeError.get_ret(str(return_type), ptr)
|
930
456
|
elif not isinstance(ptr, (float, int, np.floating, np.integer)):
|
931
|
-
raise
|
457
|
+
raise GsTaichiRuntimeTypeError.get_ret(str(return_type), ptr)
|
932
458
|
return_exprs += [ti_ops.cast(expr.Expr(ptr), return_type).ptr]
|
933
459
|
elif isinstance(return_type, MatrixType):
|
934
460
|
values = ptr
|
935
461
|
if isinstance(values, Matrix):
|
936
462
|
if values.ndim != ctx.func.return_type.ndim:
|
937
|
-
raise
|
463
|
+
raise GsTaichiRuntimeTypeError(
|
938
464
|
f"Return matrix ndim mismatch, expecting={return_type.ndim}, got={values.ndim}."
|
939
465
|
)
|
940
466
|
elif return_type.get_shape() != values.get_shape():
|
941
|
-
raise
|
467
|
+
raise GsTaichiRuntimeTypeError(
|
942
468
|
f"Return matrix shape mismatch, expecting={return_type.get_shape()}, got={values.get_shape()}."
|
943
469
|
)
|
944
470
|
values = (
|
@@ -948,23 +474,23 @@ class ASTTransformer(Builder):
|
|
948
474
|
)
|
949
475
|
elif isinstance(values, Expr):
|
950
476
|
if not values.is_tensor():
|
951
|
-
raise
|
477
|
+
raise GsTaichiRuntimeTypeError.get_ret(return_type.to_string(), ptr)
|
952
478
|
elif (
|
953
479
|
return_type.dtype in primitive_types.real_types
|
954
480
|
and not values.element_type() in primitive_types.all_types
|
955
481
|
):
|
956
|
-
raise
|
482
|
+
raise GsTaichiRuntimeTypeError.get_ret(return_type.dtype.to_string(), values.element_type())
|
957
483
|
elif (
|
958
484
|
return_type.dtype in primitive_types.integer_types
|
959
485
|
and not values.element_type() in primitive_types.integer_types
|
960
486
|
):
|
961
|
-
raise
|
487
|
+
raise GsTaichiRuntimeTypeError.get_ret(return_type.dtype.to_string(), values.element_type())
|
962
488
|
elif len(values.get_shape()) != return_type.ndim:
|
963
|
-
raise
|
489
|
+
raise GsTaichiRuntimeTypeError(
|
964
490
|
f"Return matrix ndim mismatch, expecting={return_type.ndim}, got={len(values.get_shape())}."
|
965
491
|
)
|
966
492
|
elif return_type.get_shape() != values.get_shape():
|
967
|
-
raise
|
493
|
+
raise GsTaichiRuntimeTypeError(
|
968
494
|
f"Return matrix shape mismatch, expecting={return_type.get_shape()}, got={values.get_shape()}."
|
969
495
|
)
|
970
496
|
values = [values]
|
@@ -977,27 +503,27 @@ class ASTTransformer(Builder):
|
|
977
503
|
np.floating,
|
978
504
|
np.integer,
|
979
505
|
):
|
980
|
-
raise
|
506
|
+
raise GsTaichiRuntimeTypeError.get_ret(return_type.dtype.to_string(), dt)
|
981
507
|
elif return_type.dtype in primitive_types.integer_types and dt not in (int, np.integer):
|
982
|
-
raise
|
508
|
+
raise GsTaichiRuntimeTypeError.get_ret(return_type.dtype.to_string(), dt)
|
983
509
|
elif ndim != return_type.ndim:
|
984
|
-
raise
|
510
|
+
raise GsTaichiRuntimeTypeError(
|
985
511
|
f"Return matrix ndim mismatch, expecting={return_type.ndim}, got={ndim}."
|
986
512
|
)
|
987
513
|
elif return_type.get_shape() != shape:
|
988
|
-
raise
|
514
|
+
raise GsTaichiRuntimeTypeError(
|
989
515
|
f"Return matrix shape mismatch, expecting={return_type.get_shape()}, got={shape}."
|
990
516
|
)
|
991
517
|
values = [values]
|
992
518
|
return_exprs += [ti_ops.cast(exp, return_type.dtype) for exp in values]
|
993
519
|
elif isinstance(return_type, StructType):
|
994
520
|
if not isinstance(ptr, Struct) or not isinstance(ptr, return_type):
|
995
|
-
raise
|
521
|
+
raise GsTaichiRuntimeTypeError.get_ret(str(return_type), ptr)
|
996
522
|
values = ptr
|
997
523
|
assert isinstance(values, Struct)
|
998
524
|
return_exprs += expr._get_flattened_ptrs(values)
|
999
525
|
else:
|
1000
|
-
raise
|
526
|
+
raise GsTaichiSyntaxError("The return type is not supported now!")
|
1001
527
|
ctx.ast_builder.create_kernel_exprgroup_return(
|
1002
528
|
expr.make_expr_group(return_exprs), _ti_core.DebugInfo(ctx.get_pos_info(node))
|
1003
529
|
)
|
@@ -1065,7 +591,7 @@ class ASTTransformer(Builder):
|
|
1065
591
|
# For the first case, the AST (simplified) is like node = Attribute(value=Subscript(value=x, slice=[i, j]),
|
1066
592
|
# attr="append"), when we build_stmt(node.value)(build the expression of the Subscript i.e. x[i, j]),
|
1067
593
|
# it should build the expression of node.value.value (i.e. x) and node.value.slice (i.e. [i, j]), and raise a
|
1068
|
-
#
|
594
|
+
# GsTaichiIndexError because the dimension of the field is not equal to the number of the indices. Therefore,
|
1069
595
|
# when we meet the error, we can detect whether it is a method of Dynamic SNode and build the expression if
|
1070
596
|
# it is by calling build_attribute_if_is_dynamic_snode_method. If we find that it is not a method of Dynamic
|
1071
597
|
# SNode, we raise the error again.
|
@@ -1079,7 +605,7 @@ class ASTTransformer(Builder):
|
|
1079
605
|
build_stmt(ctx, node.value)
|
1080
606
|
except Exception as e:
|
1081
607
|
e = handle_exception_from_cpp(e)
|
1082
|
-
if isinstance(e,
|
608
|
+
if isinstance(e, GsTaichiIndexError):
|
1083
609
|
node.value.ptr = None
|
1084
610
|
if ASTTransformer.build_attribute_if_is_dynamic_snode_method(ctx, node):
|
1085
611
|
return node.ptr
|
@@ -1113,7 +639,7 @@ class ASTTransformer(Builder):
|
|
1113
639
|
)
|
1114
640
|
)
|
1115
641
|
else:
|
1116
|
-
from
|
642
|
+
from gstaichi.lang import ( # pylint: disable=C0415
|
1117
643
|
matrix_ops as tensor_ops,
|
1118
644
|
)
|
1119
645
|
|
@@ -1128,7 +654,7 @@ class ASTTransformer(Builder):
|
|
1128
654
|
build_stmt(ctx, node.left)
|
1129
655
|
build_stmt(ctx, node.right)
|
1130
656
|
# pylint: disable-msg=C0415
|
1131
|
-
from
|
657
|
+
from gstaichi.lang.matrix_ops import matmul
|
1132
658
|
|
1133
659
|
op = {
|
1134
660
|
ast.Add: lambda l, r: l + r,
|
@@ -1148,7 +674,7 @@ class ASTTransformer(Builder):
|
|
1148
674
|
try:
|
1149
675
|
node.ptr = op(node.left.ptr, node.right.ptr)
|
1150
676
|
except TypeError as e:
|
1151
|
-
raise
|
677
|
+
raise GsTaichiTypeError(str(e)) from None
|
1152
678
|
return node.ptr
|
1153
679
|
|
1154
680
|
@staticmethod
|
@@ -1156,7 +682,7 @@ class ASTTransformer(Builder):
|
|
1156
682
|
build_stmt(ctx, node.target)
|
1157
683
|
build_stmt(ctx, node.value)
|
1158
684
|
if isinstance(node.target, ast.Name) and node.target.id in ctx.kernel_args:
|
1159
|
-
raise
|
685
|
+
raise GsTaichiSyntaxError(
|
1160
686
|
f'Kernel argument "{node.target.id}" is immutable in the kernel. '
|
1161
687
|
f"If you want to change its value, please create a new variable."
|
1162
688
|
)
|
@@ -1243,36 +769,22 @@ class ASTTransformer(Builder):
|
|
1243
769
|
for i, node_op in enumerate(node.ops):
|
1244
770
|
if isinstance(node_op, (ast.Is, ast.IsNot)):
|
1245
771
|
name = "is" if isinstance(node_op, ast.Is) else "is not"
|
1246
|
-
raise
|
772
|
+
raise GsTaichiSyntaxError(f'Operator "{name}" in GsTaichi scope is not supported.')
|
1247
773
|
l = operands[i]
|
1248
774
|
r = operands[i + 1]
|
1249
775
|
op = ops.get(type(node_op))
|
1250
776
|
|
1251
777
|
if op is None:
|
1252
778
|
if type(node_op) in ops_static:
|
1253
|
-
raise
|
779
|
+
raise GsTaichiSyntaxError(f'"{type(node_op).__name__}" is only supported inside `ti.static`.')
|
1254
780
|
else:
|
1255
|
-
raise
|
781
|
+
raise GsTaichiSyntaxError(f'"{type(node_op).__name__}" is not supported in GsTaichi kernels.')
|
1256
782
|
val = ti_ops.logical_and(val, op(l, r))
|
1257
783
|
if not isinstance(val, (bool, np.bool_)):
|
1258
784
|
val = ti_ops.cast(val, primitive_types.u1)
|
1259
785
|
node.ptr = val
|
1260
786
|
return node.ptr
|
1261
787
|
|
1262
|
-
@staticmethod
|
1263
|
-
def get_decorator(ctx: ASTTransformerContext, node) -> str:
|
1264
|
-
if not isinstance(node, ast.Call):
|
1265
|
-
return ""
|
1266
|
-
for wanted, name in [
|
1267
|
-
(impl.static, "static"),
|
1268
|
-
(impl.static_assert, "static_assert"),
|
1269
|
-
(impl.grouped, "grouped"),
|
1270
|
-
(ndrange, "ndrange"),
|
1271
|
-
]:
|
1272
|
-
if ASTResolver.resolve_to(node.func, wanted, ctx.global_vars):
|
1273
|
-
return name
|
1274
|
-
return ""
|
1275
|
-
|
1276
788
|
@staticmethod
|
1277
789
|
def get_for_loop_targets(node: ast.Name | ast.Tuple | Any) -> list:
|
1278
790
|
"""
|
@@ -1291,10 +803,10 @@ class ASTTransformer(Builder):
|
|
1291
803
|
assert len(node.iter.args[0].args) == 1
|
1292
804
|
ndrange_arg = build_stmt(ctx, node.iter.args[0].args[0])
|
1293
805
|
if not isinstance(ndrange_arg, _Ndrange):
|
1294
|
-
raise
|
806
|
+
raise GsTaichiSyntaxError("Only 'ti.ndrange' is allowed in 'ti.static(ti.grouped(...))'.")
|
1295
807
|
targets = ASTTransformer.get_for_loop_targets(node)
|
1296
808
|
if len(targets) != 1:
|
1297
|
-
raise
|
809
|
+
raise GsTaichiSyntaxError(f"Group for should have 1 loop target, found {len(targets)}")
|
1298
810
|
target = targets[0]
|
1299
811
|
iter_time = 0
|
1300
812
|
alert_already = False
|
@@ -1311,7 +823,7 @@ class ASTTransformer(Builder):
|
|
1311
823
|
SyntaxWarning,
|
1312
824
|
ctx.file,
|
1313
825
|
node.lineno + ctx.lineno_offset,
|
1314
|
-
module="
|
826
|
+
module="gstaichi",
|
1315
827
|
)
|
1316
828
|
|
1317
829
|
with ctx.variable_scope_guard():
|
@@ -1343,7 +855,7 @@ class ASTTransformer(Builder):
|
|
1343
855
|
SyntaxWarning,
|
1344
856
|
ctx.file,
|
1345
857
|
node.lineno + ctx.lineno_offset,
|
1346
|
-
module="
|
858
|
+
module="gstaichi",
|
1347
859
|
)
|
1348
860
|
|
1349
861
|
with ctx.variable_scope_guard():
|
@@ -1365,7 +877,7 @@ class ASTTransformer(Builder):
|
|
1365
877
|
loop_var = expr.Expr(ctx.ast_builder.make_id_expr(""))
|
1366
878
|
ctx.create_variable(loop_name, loop_var)
|
1367
879
|
if len(node.iter.args) not in [1, 2]:
|
1368
|
-
raise
|
880
|
+
raise GsTaichiSyntaxError(f"Range should have 1 or 2 arguments, found {len(node.iter.args)}")
|
1369
881
|
if len(node.iter.args) == 2:
|
1370
882
|
begin_expr = expr.Expr(build_stmt(ctx, node.iter.args[0]))
|
1371
883
|
end_expr = expr.Expr(build_stmt(ctx, node.iter.args[1]))
|
@@ -1407,7 +919,7 @@ class ASTTransformer(Builder):
|
|
1407
919
|
I = impl.expr_init(ndrange_loop_var)
|
1408
920
|
targets = ASTTransformer.get_for_loop_targets(node)
|
1409
921
|
if len(targets) != len(ndrange_var.dimensions):
|
1410
|
-
raise
|
922
|
+
raise GsTaichiSyntaxError(
|
1411
923
|
"Ndrange for loop with number of the loop variables not equal to "
|
1412
924
|
"the dimension of the ndrange is not supported. "
|
1413
925
|
"Please check if the number of arguments of ti.ndrange() is equal to "
|
@@ -1450,7 +962,7 @@ class ASTTransformer(Builder):
|
|
1450
962
|
|
1451
963
|
targets = ASTTransformer.get_for_loop_targets(node)
|
1452
964
|
if len(targets) != 1:
|
1453
|
-
raise
|
965
|
+
raise GsTaichiSyntaxError(f"Group for should have 1 loop target, found {len(targets)}")
|
1454
966
|
target = targets[0]
|
1455
967
|
mat = matrix.make_matrix([0] * len(ndrange_var.dimensions), dt=primitive_types.i32)
|
1456
968
|
target_var = impl.expr_init(mat)
|
@@ -1481,7 +993,7 @@ class ASTTransformer(Builder):
|
|
1481
993
|
with ctx.variable_scope_guard():
|
1482
994
|
if is_grouped:
|
1483
995
|
if len(targets) != 1:
|
1484
|
-
raise
|
996
|
+
raise GsTaichiSyntaxError(f"Group for should have 1 loop target, found {len(targets)}")
|
1485
997
|
target = targets[0]
|
1486
998
|
loop_var = build_stmt(ctx, node.iter)
|
1487
999
|
loop_indices = expr.make_var_list(size=len(loop_var.shape), ast_builder=ctx.ast_builder)
|
@@ -1507,7 +1019,7 @@ class ASTTransformer(Builder):
|
|
1507
1019
|
def build_mesh_for(ctx: ASTTransformerContext, node: ast.For) -> None:
|
1508
1020
|
targets = ASTTransformer.get_for_loop_targets(node)
|
1509
1021
|
if len(targets) != 1:
|
1510
|
-
raise
|
1022
|
+
raise GsTaichiSyntaxError("Mesh for should have 1 loop target, found {len(targets)}")
|
1511
1023
|
target = targets[0]
|
1512
1024
|
|
1513
1025
|
with ctx.variable_scope_guard():
|
@@ -1531,7 +1043,7 @@ class ASTTransformer(Builder):
|
|
1531
1043
|
def build_nested_mesh_for(ctx: ASTTransformerContext, node: ast.For) -> None:
|
1532
1044
|
targets = ASTTransformer.get_for_loop_targets(node)
|
1533
1045
|
if len(targets) != 1:
|
1534
|
-
raise
|
1046
|
+
raise GsTaichiSyntaxError("Nested-mesh for should have 1 loop target, found {len(targets)}")
|
1535
1047
|
target = targets[0]
|
1536
1048
|
|
1537
1049
|
with ctx.variable_scope_guard():
|
@@ -1561,29 +1073,29 @@ class ASTTransformer(Builder):
|
|
1561
1073
|
@staticmethod
|
1562
1074
|
def build_For(ctx: ASTTransformerContext, node: ast.For) -> None:
|
1563
1075
|
if node.orelse:
|
1564
|
-
raise
|
1565
|
-
decorator =
|
1076
|
+
raise GsTaichiSyntaxError("'else' clause for 'for' not supported in GsTaichi kernels")
|
1077
|
+
decorator = get_decorator(ctx, node.iter)
|
1566
1078
|
double_decorator = ""
|
1567
1079
|
if decorator != "" and len(node.iter.args) == 1:
|
1568
|
-
double_decorator =
|
1080
|
+
double_decorator = get_decorator(ctx, node.iter.args[0])
|
1569
1081
|
|
1570
1082
|
if decorator == "static":
|
1571
1083
|
if double_decorator == "static":
|
1572
|
-
raise
|
1084
|
+
raise GsTaichiSyntaxError("'ti.static' cannot be nested")
|
1573
1085
|
with ctx.loop_scope_guard(is_static=True):
|
1574
1086
|
return ASTTransformer.build_static_for(ctx, node, double_decorator == "grouped")
|
1575
1087
|
with ctx.loop_scope_guard():
|
1576
1088
|
if decorator == "ndrange":
|
1577
1089
|
if double_decorator != "":
|
1578
|
-
raise
|
1090
|
+
raise GsTaichiSyntaxError("No decorator is allowed inside 'ti.ndrange")
|
1579
1091
|
return ASTTransformer.build_ndrange_for(ctx, node)
|
1580
1092
|
if decorator == "grouped":
|
1581
1093
|
if double_decorator == "static":
|
1582
|
-
raise
|
1094
|
+
raise GsTaichiSyntaxError("'ti.static' is not allowed inside 'ti.grouped'")
|
1583
1095
|
elif double_decorator == "ndrange":
|
1584
1096
|
return ASTTransformer.build_grouped_ndrange_for(ctx, node)
|
1585
1097
|
elif double_decorator == "grouped":
|
1586
|
-
raise
|
1098
|
+
raise GsTaichiSyntaxError("'ti.grouped' cannot be nested")
|
1587
1099
|
else:
|
1588
1100
|
return ASTTransformer.build_struct_for(ctx, node, is_grouped=True)
|
1589
1101
|
elif (
|
@@ -1597,7 +1109,7 @@ class ASTTransformer(Builder):
|
|
1597
1109
|
if isinstance(node.iter.ptr, mesh.MeshElementField):
|
1598
1110
|
if not _ti_core.is_extension_supported(impl.default_cfg().arch, _ti_core.Extension.mesh):
|
1599
1111
|
raise Exception(
|
1600
|
-
"Backend " + str(impl.default_cfg().arch) + " doesn't support
|
1112
|
+
"Backend " + str(impl.default_cfg().arch) + " doesn't support MeshGsTaichi extension"
|
1601
1113
|
)
|
1602
1114
|
return ASTTransformer.build_mesh_for(ctx, node)
|
1603
1115
|
if isinstance(node.iter.ptr, mesh.MeshRelationAccessProxy):
|
@@ -1608,7 +1120,7 @@ class ASTTransformer(Builder):
|
|
1608
1120
|
@staticmethod
|
1609
1121
|
def build_While(ctx: ASTTransformerContext, node: ast.While) -> None:
|
1610
1122
|
if node.orelse:
|
1611
|
-
raise
|
1123
|
+
raise GsTaichiSyntaxError("'else' clause for 'while' not supported in GsTaichi kernels")
|
1612
1124
|
|
1613
1125
|
with ctx.loop_scope_guard():
|
1614
1126
|
stmt_dbg_info = _ti_core.DebugInfo(ctx.get_pos_info(node))
|
@@ -1627,7 +1139,7 @@ class ASTTransformer(Builder):
|
|
1627
1139
|
@staticmethod
|
1628
1140
|
def build_If(ctx: ASTTransformerContext, node: ast.If) -> ast.If | None:
|
1629
1141
|
build_stmt(ctx, node.test)
|
1630
|
-
is_static_if =
|
1142
|
+
is_static_if = get_decorator(ctx, node.test) == "static"
|
1631
1143
|
|
1632
1144
|
if is_static_if:
|
1633
1145
|
if node.test.ptr:
|
@@ -1668,15 +1180,15 @@ class ASTTransformer(Builder):
|
|
1668
1180
|
|
1669
1181
|
if has_tensor_type:
|
1670
1182
|
if isinstance(node.test.ptr, expr.Expr) and node.test.ptr.is_tensor():
|
1671
|
-
raise
|
1183
|
+
raise GsTaichiSyntaxError(
|
1672
1184
|
"Using conditional expression for element-wise select operation on "
|
1673
|
-
"
|
1185
|
+
"GsTaichi vectors/matrices is deprecated and removed starting from GsTaichi v1.5.0 "
|
1674
1186
|
'Please use "ti.select" instead.'
|
1675
1187
|
)
|
1676
1188
|
node.ptr = ti_ops.select(node.test.ptr, node.body.ptr, node.orelse.ptr)
|
1677
1189
|
return node.ptr
|
1678
1190
|
|
1679
|
-
is_static_if =
|
1191
|
+
is_static_if = get_decorator(ctx, node.test) == "static"
|
1680
1192
|
|
1681
1193
|
if is_static_if:
|
1682
1194
|
if node.test.ptr:
|
@@ -1721,17 +1233,17 @@ class ASTTransformer(Builder):
|
|
1721
1233
|
for entry in entries:
|
1722
1234
|
if isinstance(entry, str):
|
1723
1235
|
msg += entry
|
1724
|
-
elif isinstance(entry, _ti_core.
|
1236
|
+
elif isinstance(entry, _ti_core.ExprCxx):
|
1725
1237
|
ty = entry.get_rvalue_type()
|
1726
1238
|
if ty in primitive_types.real_types:
|
1727
1239
|
msg += "%f"
|
1728
1240
|
elif ty in primitive_types.integer_types:
|
1729
1241
|
msg += "%d"
|
1730
1242
|
else:
|
1731
|
-
raise
|
1243
|
+
raise GsTaichiSyntaxError(f"Unsupported data type: {type(ty)}")
|
1732
1244
|
args.append(entry)
|
1733
1245
|
else:
|
1734
|
-
raise
|
1246
|
+
raise GsTaichiSyntaxError(f"Unsupported type: {type(entry)}")
|
1735
1247
|
return msg, args
|
1736
1248
|
|
1737
1249
|
@staticmethod
|
@@ -1749,7 +1261,7 @@ class ASTTransformer(Builder):
|
|
1749
1261
|
elif isinstance(msg, collections.abc.Sequence) and len(msg) > 0 and msg[0] == "__ti_format__":
|
1750
1262
|
msg, extra_args = ASTTransformer.ti_format_list_to_assert_msg(msg)
|
1751
1263
|
else:
|
1752
|
-
raise
|
1264
|
+
raise GsTaichiSyntaxError(f"assert info must be constant or formatted string, not {type(msg)}")
|
1753
1265
|
else:
|
1754
1266
|
msg = unparse(node.test)
|
1755
1267
|
test = build_stmt(ctx, node.test)
|
@@ -1766,7 +1278,7 @@ class ASTTransformer(Builder):
|
|
1766
1278
|
"You are trying to `break` a static `for` loop, "
|
1767
1279
|
"but the `break` statement is inside a non-static `if`. "
|
1768
1280
|
)
|
1769
|
-
raise
|
1281
|
+
raise GsTaichiSyntaxError(msg)
|
1770
1282
|
ctx.set_loop_status(LoopStatus.Break)
|
1771
1283
|
else:
|
1772
1284
|
ctx.ast_builder.insert_break_stmt(_ti_core.DebugInfo(ctx.get_pos_info(node)))
|
@@ -1782,7 +1294,7 @@ class ASTTransformer(Builder):
|
|
1782
1294
|
"You are trying to `continue` a static `for` loop, "
|
1783
1295
|
"but the `continue` statement is inside a non-static `if`. "
|
1784
1296
|
)
|
1785
|
-
raise
|
1297
|
+
raise GsTaichiSyntaxError(msg)
|
1786
1298
|
ctx.set_loop_status(LoopStatus.Continue)
|
1787
1299
|
else:
|
1788
1300
|
ctx.ast_builder.insert_continue_stmt(_ti_core.DebugInfo(ctx.get_pos_info(node)))
|
@@ -1796,7 +1308,7 @@ class ASTTransformer(Builder):
|
|
1796
1308
|
build_stmt = ASTTransformer()
|
1797
1309
|
|
1798
1310
|
|
1799
|
-
def build_stmts(ctx: ASTTransformerContext, stmts: list):
|
1311
|
+
def build_stmts(ctx: ASTTransformerContext, stmts: list[ast.stmt]):
|
1800
1312
|
with ctx.variable_scope_guard():
|
1801
1313
|
for stmt in stmts:
|
1802
1314
|
if ctx.returned != ReturnStatus.NoReturn or ctx.loop_status() != LoopStatus.Normal:
|