gstaichi 0.1.23.dev0__cp310-cp310-win_amd64.whl → 0.1.25.dev0__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +9 -0
- {taichi → gstaichi}/__init__.py +9 -13
- {taichi → gstaichi}/_funcs.py +8 -8
- {taichi → gstaichi}/_kernels.py +19 -19
- gstaichi/_lib/__init__.py +3 -0
- taichi/_lib/core/taichi_python.cp310-win_amd64.pyd → gstaichi/_lib/core/gstaichi_python.cp310-win_amd64.pyd +0 -0
- taichi/_lib/core/taichi_python.pyi → gstaichi/_lib/core/gstaichi_python.pyi +382 -522
- {taichi → gstaichi}/_lib/runtime/runtime_cuda.bc +0 -0
- {taichi → gstaichi}/_lib/runtime/runtime_x64.bc +0 -0
- {taichi → gstaichi}/_lib/utils.py +15 -15
- {taichi → gstaichi}/_logging.py +1 -1
- {taichi → gstaichi}/_main.py +24 -31
- gstaichi/_snode/__init__.py +5 -0
- {taichi → gstaichi}/_snode/fields_builder.py +27 -29
- {taichi → gstaichi}/_snode/snode_tree.py +5 -5
- gstaichi/_test_tools/__init__.py +0 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_version.py +1 -0
- {taichi → gstaichi}/_version_check.py +8 -5
- gstaichi/ad/__init__.py +3 -0
- {taichi → gstaichi}/ad/_ad.py +26 -26
- {taichi → gstaichi}/algorithms/_algorithms.py +7 -7
- {taichi → gstaichi}/examples/minimal.py +1 -1
- {taichi → gstaichi}/experimental.py +1 -1
- gstaichi/lang/__init__.py +50 -0
- {taichi → gstaichi}/lang/_ndarray.py +30 -26
- {taichi → gstaichi}/lang/_ndrange.py +8 -8
- gstaichi/lang/_template_mapper.py +199 -0
- {taichi → gstaichi}/lang/_texture.py +19 -19
- {taichi → gstaichi}/lang/_wrap_inspect.py +7 -7
- {taichi → gstaichi}/lang/any_array.py +13 -13
- {taichi → gstaichi}/lang/argpack.py +29 -29
- gstaichi/lang/ast/__init__.py +5 -0
- {taichi → gstaichi}/lang/ast/ast_transformer.py +94 -582
- {taichi → gstaichi}/lang/ast/ast_transformer_utils.py +54 -41
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
- {taichi → gstaichi}/lang/ast/checkers.py +5 -5
- gstaichi/lang/ast/transform.py +9 -0
- {taichi → gstaichi}/lang/common_ops.py +12 -12
- gstaichi/lang/exception.py +80 -0
- {taichi → gstaichi}/lang/expr.py +22 -22
- {taichi → gstaichi}/lang/field.py +29 -27
- {taichi → gstaichi}/lang/impl.py +116 -121
- {taichi → gstaichi}/lang/kernel_arguments.py +16 -16
- {taichi → gstaichi}/lang/kernel_impl.py +330 -363
- {taichi → gstaichi}/lang/matrix.py +119 -115
- {taichi → gstaichi}/lang/matrix_ops.py +6 -6
- {taichi → gstaichi}/lang/matrix_ops_utils.py +4 -4
- {taichi → gstaichi}/lang/mesh.py +22 -22
- {taichi → gstaichi}/lang/misc.py +39 -68
- {taichi → gstaichi}/lang/ops.py +146 -141
- {taichi → gstaichi}/lang/runtime_ops.py +2 -2
- {taichi → gstaichi}/lang/shell.py +3 -3
- {taichi → gstaichi}/lang/simt/__init__.py +1 -1
- {taichi → gstaichi}/lang/simt/block.py +7 -7
- {taichi → gstaichi}/lang/simt/grid.py +1 -1
- {taichi → gstaichi}/lang/simt/subgroup.py +1 -1
- {taichi → gstaichi}/lang/simt/warp.py +1 -1
- {taichi → gstaichi}/lang/snode.py +46 -44
- {taichi → gstaichi}/lang/source_builder.py +13 -13
- {taichi → gstaichi}/lang/struct.py +33 -33
- {taichi → gstaichi}/lang/util.py +24 -24
- gstaichi/linalg/__init__.py +8 -0
- {taichi → gstaichi}/linalg/matrixfree_cg.py +14 -14
- {taichi → gstaichi}/linalg/sparse_cg.py +10 -10
- {taichi → gstaichi}/linalg/sparse_matrix.py +23 -23
- {taichi → gstaichi}/linalg/sparse_solver.py +21 -21
- {taichi → gstaichi}/math/__init__.py +1 -1
- {taichi → gstaichi}/math/_complex.py +21 -20
- {taichi → gstaichi}/math/mathimpl.py +56 -56
- gstaichi/profiler/__init__.py +6 -0
- {taichi → gstaichi}/profiler/kernel_metrics.py +11 -11
- {taichi → gstaichi}/profiler/kernel_profiler.py +30 -36
- {taichi → gstaichi}/profiler/memory_profiler.py +1 -1
- {taichi → gstaichi}/profiler/scoped_profiler.py +2 -2
- {taichi → gstaichi}/sparse/_sparse_grid.py +7 -7
- {taichi → gstaichi}/tools/__init__.py +4 -4
- {taichi → gstaichi}/tools/diagnose.py +10 -17
- gstaichi/types/__init__.py +19 -0
- {taichi → gstaichi}/types/annotations.py +1 -1
- {taichi → gstaichi}/types/compound_types.py +8 -8
- {taichi → gstaichi}/types/enums.py +1 -1
- {taichi → gstaichi}/types/ndarray_type.py +7 -7
- {taichi → gstaichi}/types/primitive_types.py +17 -14
- {taichi → gstaichi}/types/quant.py +9 -9
- {taichi → gstaichi}/types/texture_type.py +5 -5
- {taichi → gstaichi}/types/utils.py +1 -1
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/bin/SPIRV-Tools-shared.dll +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools-diff.lib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools-link.lib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools-lint.lib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools-opt.lib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools-reduce.lib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools-shared.lib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools.lib +0 -0
- {gstaichi-0.1.23.dev0.dist-info → gstaichi-0.1.25.dev0.dist-info}/METADATA +13 -16
- gstaichi-0.1.25.dev0.dist-info/RECORD +138 -0
- gstaichi-0.1.25.dev0.dist-info/entry_points.txt +2 -0
- gstaichi-0.1.25.dev0.dist-info/top_level.txt +1 -0
- gstaichi-0.1.23.dev0.data/data/include/GLFW/glfw3.h +0 -6389
- gstaichi-0.1.23.dev0.data/data/include/GLFW/glfw3native.h +0 -594
- gstaichi-0.1.23.dev0.data/data/lib/cmake/glfw3/glfw3Config.cmake +0 -3
- gstaichi-0.1.23.dev0.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +0 -65
- gstaichi-0.1.23.dev0.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +0 -19
- gstaichi-0.1.23.dev0.data/data/lib/cmake/glfw3/glfw3Targets.cmake +0 -107
- gstaichi-0.1.23.dev0.data/data/lib/glfw3.lib +0 -0
- gstaichi-0.1.23.dev0.dist-info/RECORD +0 -198
- gstaichi-0.1.23.dev0.dist-info/entry_points.txt +0 -2
- gstaichi-0.1.23.dev0.dist-info/top_level.txt +0 -1
- taichi/CHANGELOG.md +0 -20
- taichi/_lib/__init__.py +0 -3
- taichi/_lib/c_api/bin/taichi_c_api.dll +0 -0
- taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +0 -1401
- taichi/_lib/c_api/include/taichi/taichi.h +0 -29
- taichi/_lib/c_api/include/taichi/taichi_core.h +0 -1111
- taichi/_lib/c_api/include/taichi/taichi_cpu.h +0 -29
- taichi/_lib/c_api/include/taichi/taichi_cuda.h +0 -36
- taichi/_lib/c_api/include/taichi/taichi_platform.h +0 -55
- taichi/_lib/c_api/include/taichi/taichi_unity.h +0 -64
- taichi/_lib/c_api/include/taichi/taichi_vulkan.h +0 -151
- taichi/_lib/c_api/lib/taichi_c_api.lib +0 -0
- taichi/_lib/c_api/runtime/runtime_cuda.bc +0 -0
- taichi/_lib/c_api/runtime/runtime_x64.bc +0 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +0 -29
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +0 -65
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +0 -121
- taichi/_lib/runtime/slim_libdevice.10.bc +0 -0
- taichi/_snode/__init__.py +0 -5
- taichi/_ti_module/__init__.py +0 -3
- taichi/_ti_module/cppgen.py +0 -309
- taichi/_ti_module/module.py +0 -145
- taichi/_version.py +0 -1
- taichi/ad/__init__.py +0 -3
- taichi/aot/__init__.py +0 -12
- taichi/aot/_export.py +0 -28
- taichi/aot/conventions/__init__.py +0 -3
- taichi/aot/conventions/gfxruntime140/__init__.py +0 -38
- taichi/aot/conventions/gfxruntime140/dr.py +0 -244
- taichi/aot/conventions/gfxruntime140/sr.py +0 -613
- taichi/aot/module.py +0 -253
- taichi/aot/utils.py +0 -151
- taichi/graph/__init__.py +0 -3
- taichi/graph/_graph.py +0 -292
- taichi/lang/__init__.py +0 -50
- taichi/lang/ast/__init__.py +0 -5
- taichi/lang/ast/transform.py +0 -9
- taichi/lang/exception.py +0 -80
- taichi/linalg/__init__.py +0 -8
- taichi/profiler/__init__.py +0 -6
- taichi/shaders/Circles_vk.frag +0 -29
- taichi/shaders/Circles_vk.vert +0 -45
- taichi/shaders/Circles_vk_frag.spv +0 -0
- taichi/shaders/Circles_vk_vert.spv +0 -0
- taichi/shaders/Lines_vk.frag +0 -9
- taichi/shaders/Lines_vk.vert +0 -11
- taichi/shaders/Lines_vk_frag.spv +0 -0
- taichi/shaders/Lines_vk_vert.spv +0 -0
- taichi/shaders/Mesh_vk.frag +0 -71
- taichi/shaders/Mesh_vk.vert +0 -68
- taichi/shaders/Mesh_vk_frag.spv +0 -0
- taichi/shaders/Mesh_vk_vert.spv +0 -0
- taichi/shaders/Particles_vk.frag +0 -95
- taichi/shaders/Particles_vk.vert +0 -73
- taichi/shaders/Particles_vk_frag.spv +0 -0
- taichi/shaders/Particles_vk_vert.spv +0 -0
- taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
- taichi/shaders/SceneLines_vk.frag +0 -9
- taichi/shaders/SceneLines_vk.vert +0 -12
- taichi/shaders/SceneLines_vk_frag.spv +0 -0
- taichi/shaders/SceneLines_vk_vert.spv +0 -0
- taichi/shaders/SetImage_vk.frag +0 -21
- taichi/shaders/SetImage_vk.vert +0 -15
- taichi/shaders/SetImage_vk_frag.spv +0 -0
- taichi/shaders/SetImage_vk_vert.spv +0 -0
- taichi/shaders/Triangles_vk.frag +0 -16
- taichi/shaders/Triangles_vk.vert +0 -29
- taichi/shaders/Triangles_vk_frag.spv +0 -0
- taichi/shaders/Triangles_vk_vert.spv +0 -0
- taichi/shaders/lines2quad_vk_comp.spv +0 -0
- taichi/types/__init__.py +0 -19
- {taichi → gstaichi}/__main__.py +0 -0
- {taichi → gstaichi}/_lib/core/__init__.py +0 -0
- {taichi → gstaichi}/_lib/core/py.typed +0 -0
- {taichi/_lib/c_api → gstaichi/_lib}/runtime/slim_libdevice.10.bc +0 -0
- {taichi → gstaichi}/algorithms/__init__.py +0 -0
- {taichi → gstaichi}/assets/.git +0 -0
- {taichi → gstaichi}/assets/Go-Regular.ttf +0 -0
- {taichi → gstaichi}/assets/static/imgs/ti_gallery.png +0 -0
- {taichi → gstaichi}/lang/ast/symbol_resolver.py +0 -0
- {taichi → gstaichi}/sparse/__init__.py +0 -0
- {taichi → gstaichi}/tools/np2ply.py +0 -0
- {taichi → gstaichi}/tools/vtk.py +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/instrument.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/libspirv.h +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/libspirv.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/linker.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/optimizer.hpp +0 -0
- {gstaichi-0.1.23.dev0.dist-info → gstaichi-0.1.25.dev0.dist-info}/WHEEL +0 -0
- {gstaichi-0.1.23.dev0.dist-info → gstaichi-0.1.25.dev0.dist-info}/licenses/LICENSE +0 -0
@@ -4,33 +4,34 @@ import ast
|
|
4
4
|
import builtins
|
5
5
|
import traceback
|
6
6
|
from enum import Enum
|
7
|
-
from sys import version_info
|
8
7
|
from textwrap import TextWrapper
|
9
8
|
from typing import TYPE_CHECKING, Any, List
|
10
9
|
|
11
|
-
from
|
12
|
-
from
|
13
|
-
from
|
14
|
-
|
15
|
-
|
16
|
-
|
10
|
+
from gstaichi._lib.core.gstaichi_python import ASTBuilder
|
11
|
+
from gstaichi.lang import impl
|
12
|
+
from gstaichi.lang._ndrange import ndrange
|
13
|
+
from gstaichi.lang.ast.symbol_resolver import ASTResolver
|
14
|
+
from gstaichi.lang.exception import (
|
15
|
+
GsTaichiCompilationError,
|
16
|
+
GsTaichiNameError,
|
17
|
+
GsTaichiSyntaxError,
|
17
18
|
handle_exception_from_cpp,
|
18
19
|
)
|
19
20
|
|
20
21
|
if TYPE_CHECKING:
|
21
|
-
from
|
22
|
+
from gstaichi.lang.kernel_impl import (
|
22
23
|
Func,
|
23
24
|
Kernel,
|
24
25
|
)
|
25
26
|
|
26
27
|
|
27
28
|
class Builder:
|
28
|
-
def __call__(self, ctx, node):
|
29
|
+
def __call__(self, ctx: "ASTTransformerContext", node: ast.AST):
|
29
30
|
method = getattr(self, "build_" + node.__class__.__name__, None)
|
30
31
|
try:
|
31
32
|
if method is None:
|
32
33
|
error_msg = f'Unsupported node "{node.__class__.__name__}"'
|
33
|
-
raise
|
34
|
+
raise GsTaichiSyntaxError(error_msg)
|
34
35
|
info = ctx.get_pos_info(node) if isinstance(node, (ast.stmt, ast.expr)) else ""
|
35
36
|
with impl.get_runtime().src_info_guard(info):
|
36
37
|
return method(ctx, node)
|
@@ -41,15 +42,15 @@ class Builder:
|
|
41
42
|
raise e.with_traceback(None)
|
42
43
|
ctx.raised = True
|
43
44
|
e = handle_exception_from_cpp(e)
|
44
|
-
if not isinstance(e,
|
45
|
+
if not isinstance(e, GsTaichiCompilationError):
|
45
46
|
msg = ctx.get_pos_info(node) + traceback.format_exc()
|
46
|
-
raise
|
47
|
+
raise GsTaichiCompilationError(msg) from None
|
47
48
|
msg = ctx.get_pos_info(node) + str(e)
|
48
49
|
raise type(e)(msg) from None
|
49
50
|
|
50
51
|
|
51
52
|
class VariableScopeGuard:
|
52
|
-
def __init__(self, scopes):
|
53
|
+
def __init__(self, scopes: list[dict[str, Any]]):
|
53
54
|
self.scopes = scopes
|
54
55
|
|
55
56
|
def __enter__(self):
|
@@ -65,7 +66,7 @@ class StaticScopeStatus:
|
|
65
66
|
|
66
67
|
|
67
68
|
class StaticScopeGuard:
|
68
|
-
def __init__(self, status):
|
69
|
+
def __init__(self, status: StaticScopeStatus):
|
69
70
|
self.status = status
|
70
71
|
|
71
72
|
def __enter__(self):
|
@@ -107,7 +108,7 @@ class LoopScopeAttribute:
|
|
107
108
|
|
108
109
|
|
109
110
|
class LoopScopeGuard:
|
110
|
-
def __init__(self, scopes, non_static_guard=None):
|
111
|
+
def __init__(self, scopes: list[dict[str, Any]], non_static_guard=None):
|
111
112
|
self.scopes = scopes
|
112
113
|
self.non_static_guard = non_static_guard
|
113
114
|
|
@@ -167,7 +168,7 @@ class ASTTransformerContext:
|
|
167
168
|
is_real_function: bool = False,
|
168
169
|
):
|
169
170
|
self.func = func
|
170
|
-
self.local_scopes = []
|
171
|
+
self.local_scopes: list[dict[str, Any]] = []
|
171
172
|
self.loop_scopes: List[LoopScopeAttribute] = []
|
172
173
|
self.excluded_parameters = excluded_parameters
|
173
174
|
self.is_kernel = is_kernel
|
@@ -192,7 +193,7 @@ class ASTTransformerContext:
|
|
192
193
|
self.ast_builder = ast_builder
|
193
194
|
self.visited_funcdef = False
|
194
195
|
self.is_real_function = is_real_function
|
195
|
-
self.kernel_args = []
|
196
|
+
self.kernel_args: list = []
|
196
197
|
|
197
198
|
# e.g.: FunctionDef, Module, Global
|
198
199
|
def variable_scope_guard(self):
|
@@ -211,61 +212,61 @@ class ASTTransformerContext:
|
|
211
212
|
self.non_static_control_flow_status,
|
212
213
|
)
|
213
214
|
|
214
|
-
def non_static_control_flow_guard(self):
|
215
|
+
def non_static_control_flow_guard(self) -> NonStaticControlFlowGuard:
|
215
216
|
return NonStaticControlFlowGuard(self.non_static_control_flow_status)
|
216
217
|
|
217
|
-
def static_scope_guard(self):
|
218
|
+
def static_scope_guard(self) -> StaticScopeGuard:
|
218
219
|
return StaticScopeGuard(self.static_scope_status)
|
219
220
|
|
220
|
-
def current_scope(self):
|
221
|
+
def current_scope(self) -> dict[str, Any]:
|
221
222
|
return self.local_scopes[-1]
|
222
223
|
|
223
|
-
def current_loop_scope(self):
|
224
|
+
def current_loop_scope(self) -> dict[str, Any]:
|
224
225
|
return self.loop_scopes[-1]
|
225
226
|
|
226
|
-
def loop_status(self):
|
227
|
+
def loop_status(self) -> LoopStatus:
|
227
228
|
if self.loop_scopes:
|
228
229
|
return self.loop_scopes[-1].status
|
229
230
|
return LoopStatus.Normal
|
230
231
|
|
231
|
-
def set_loop_status(self, status):
|
232
|
+
def set_loop_status(self, status: LoopStatus) -> None:
|
232
233
|
self.loop_scopes[-1].status = status
|
233
234
|
|
234
|
-
def is_in_static_for(self):
|
235
|
+
def is_in_static_for(self) -> bool:
|
235
236
|
if self.loop_scopes:
|
236
237
|
return self.loop_scopes[-1].is_static
|
237
238
|
return False
|
238
239
|
|
239
|
-
def is_in_non_static_control_flow(self):
|
240
|
+
def is_in_non_static_control_flow(self) -> bool:
|
240
241
|
return self.non_static_control_flow_status.is_in_non_static_control_flow
|
241
242
|
|
242
|
-
def is_in_static_scope(self):
|
243
|
+
def is_in_static_scope(self) -> bool:
|
243
244
|
return self.static_scope_status.is_in_static_scope
|
244
245
|
|
245
|
-
def is_var_declared(self, name):
|
246
|
+
def is_var_declared(self, name: str) -> bool:
|
246
247
|
for s in self.local_scopes:
|
247
248
|
if name in s:
|
248
249
|
return True
|
249
250
|
return False
|
250
251
|
|
251
|
-
def create_variable(self, name, var):
|
252
|
+
def create_variable(self, name: str, var: Any) -> None:
|
252
253
|
if name in self.current_scope():
|
253
|
-
raise
|
254
|
+
raise GsTaichiSyntaxError("Recreating variables is not allowed")
|
254
255
|
self.current_scope()[name] = var
|
255
256
|
|
256
|
-
def check_loop_var(self, loop_var):
|
257
|
+
def check_loop_var(self, loop_var: str) -> None:
|
257
258
|
if self.is_var_declared(loop_var):
|
258
|
-
raise
|
259
|
+
raise GsTaichiSyntaxError(
|
259
260
|
f"Variable '{loop_var}' is already declared in the outer scope and cannot be used as loop variable"
|
260
261
|
)
|
261
262
|
|
262
|
-
def get_var_by_name(self, name: str):
|
263
|
+
def get_var_by_name(self, name: str) -> Any:
|
263
264
|
for s in reversed(self.local_scopes):
|
264
265
|
if name in s:
|
265
266
|
return s[name]
|
266
267
|
if name in self.global_vars:
|
267
268
|
var = self.global_vars[name]
|
268
|
-
from
|
269
|
+
from gstaichi.lang.matrix import ( # pylint: disable-msg=C0415
|
269
270
|
Matrix,
|
270
271
|
make_matrix,
|
271
272
|
)
|
@@ -276,19 +277,16 @@ class ASTTransformerContext:
|
|
276
277
|
try:
|
277
278
|
return getattr(builtins, name)
|
278
279
|
except AttributeError:
|
279
|
-
raise
|
280
|
+
raise GsTaichiNameError(f'Name "{name}" is not defined')
|
280
281
|
|
281
|
-
def get_pos_info(self, node) -> str:
|
282
|
+
def get_pos_info(self, node: ast.AST) -> str:
|
282
283
|
msg = f'File "{self.file}", line {node.lineno + self.lineno_offset}, in {self.func.func.__name__}:\n'
|
283
|
-
if version_info < (3, 8):
|
284
|
-
msg += self.src[node.lineno - 1] + "\n"
|
285
|
-
return msg
|
286
284
|
col_offset = self.indent + node.col_offset
|
287
285
|
end_col_offset = self.indent + node.end_col_offset
|
288
286
|
|
289
287
|
wrapper = TextWrapper(width=80)
|
290
288
|
|
291
|
-
def gen_line(code, hint):
|
289
|
+
def gen_line(code: str, hint: str) -> str:
|
292
290
|
hint += " " * (len(code) - len(hint))
|
293
291
|
code = wrapper.wrap(code)
|
294
292
|
hint = wrapper.wrap(hint)
|
@@ -297,8 +295,9 @@ class ASTTransformerContext:
|
|
297
295
|
return "".join([c + "\n" + h + "\n" for c, h in zip(code, hint)])
|
298
296
|
|
299
297
|
if node.lineno == node.end_lineno:
|
300
|
-
|
301
|
-
|
298
|
+
if node.lineno - 1 < len(self.src):
|
299
|
+
hint = " " * col_offset + "^" * (end_col_offset - col_offset)
|
300
|
+
msg += gen_line(self.src[node.lineno - 1], hint)
|
302
301
|
else:
|
303
302
|
node_type = node.__class__.__name__
|
304
303
|
|
@@ -326,3 +325,17 @@ class ASTTransformerContext:
|
|
326
325
|
hint = ""
|
327
326
|
msg += gen_line(self.src[i], hint)
|
328
327
|
return msg
|
328
|
+
|
329
|
+
|
330
|
+
def get_decorator(ctx: ASTTransformerContext, node) -> str:
    """Name the GsTaichi directive that a call node invokes, if any.

    Only ``ast.Call`` nodes are inspected. The callee expression ``node.func`` is
    resolved against ``ctx.global_vars`` with ``ASTResolver.resolve_to``. It is
    checked against ``impl.static``, ``impl.static_assert``, ``impl.grouped`` and
    ``ndrange``, in that order.

    Returns:
        ``"static"``, ``"static_assert"``, ``"grouped"`` or ``"ndrange"`` for the
        first match. Returns ``""`` when ``node`` is not a call or nothing matches.
    """
    if isinstance(node, ast.Call):
        candidates = (
            ("static", impl.static),
            ("static_assert", impl.static_assert),
            ("grouped", impl.grouped),
            ("ndrange", ndrange),
        )
        for label, target in candidates:
            if ASTResolver.resolve_to(node.func, target, ctx.global_vars):
                return label
    return ""
|
File without changes
|
@@ -0,0 +1,267 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import ast
|
4
|
+
import dataclasses
|
5
|
+
import inspect
|
6
|
+
import math
|
7
|
+
import operator
|
8
|
+
import re
|
9
|
+
import warnings
|
10
|
+
from ast import unparse
|
11
|
+
from collections import ChainMap
|
12
|
+
|
13
|
+
import numpy as np
|
14
|
+
|
15
|
+
from gstaichi.lang import (
|
16
|
+
expr,
|
17
|
+
impl,
|
18
|
+
matrix,
|
19
|
+
)
|
20
|
+
from gstaichi.lang import ops as ti_ops
|
21
|
+
from gstaichi.lang.ast.ast_transformer_utils import (
|
22
|
+
ASTTransformerContext,
|
23
|
+
get_decorator,
|
24
|
+
)
|
25
|
+
from gstaichi.lang.exception import (
|
26
|
+
GsTaichiSyntaxError,
|
27
|
+
GsTaichiTypeError,
|
28
|
+
)
|
29
|
+
from gstaichi.lang.expr import Expr
|
30
|
+
from gstaichi.lang.matrix import Matrix, Vector
|
31
|
+
from gstaichi.lang.util import is_gstaichi_class
|
32
|
+
from gstaichi.types import primitive_types
|
33
|
+
|
34
|
+
|
35
|
+
class CallTransformer:
    """Lowers Python ``ast.Call`` nodes while transforming GsTaichi kernels/functions.

    All members are static helpers. ``build_Call`` is the entry point; the other
    methods implement the special cases it dispatches to. Each ``build_*`` helper
    stores the lowered value on ``node.ptr``.
    """

    @staticmethod
    def build_call_if_is_builtin(ctx: ASTTransformerContext, node, args, keywords):
        """Replace a call to a supported Python builtin with its GsTaichi counterpart.

        Handles ``len(x)`` on a tensor-typed ``Expr``, which yields
        ``x.get_shape()[0]``. Also handles the callables listed in
        ``replace_func``: ``print``, ``min``, ``max``, ``int``, ``bool``,
        ``float``, ``any``, ``all``, ``abs``, ``pow`` and ``operator.matmul``.

        Returns:
            True if the call was handled and ``node.ptr`` was set, else False.
        """
        from gstaichi.lang import matrix_ops  # pylint: disable=C0415

        func = node.func.ptr
        # Keyed by id() so the lookup is by identity and never invokes the
        # callee's __hash__/__eq__.
        replace_func = {
            id(print): impl.ti_print,
            id(min): ti_ops.min,
            id(max): ti_ops.max,
            id(int): impl.ti_int,
            id(bool): impl.ti_bool,
            id(float): impl.ti_float,
            id(any): matrix_ops.any,
            id(all): matrix_ops.all,
            id(abs): abs,
            id(pow): pow,
            id(operator.matmul): matrix_ops.matmul,
        }

        # Builtin 'len' function on Matrix Expr
        if func is len and len(args) == 1:
            if isinstance(args[0], Expr) and args[0].ptr.is_tensor():
                node.ptr = args[0].get_shape()[0]
                return True

        replacement = replace_func.get(id(func))
        if replacement is None:
            return False
        node.ptr = replacement(*args, **keywords)
        return True

    @staticmethod
    def build_call_if_is_type(ctx: ASTTransformerContext, node, args, keywords):
        """Lower a call whose callee is a primitive type into a cast or typed Expr.

        Raises:
            GsTaichiSyntaxError: if the call does not have exactly one positional
                argument, has keyword arguments, or its argument is compound-typed
                (a GsTaichi class or a tensor ``Expr``).

        Returns:
            True if the callee is a primitive type and ``node.ptr`` was set,
            else False.
        """
        func = node.func.ptr
        if id(func) not in primitive_types.type_ids:
            return False
        if len(args) != 1 or keywords:
            raise GsTaichiSyntaxError("A primitive type can only decorate a single expression.")
        if is_gstaichi_class(args[0]):
            raise GsTaichiSyntaxError("A primitive type cannot decorate an expression with a compound type.")

        if isinstance(args[0], expr.Expr):
            if args[0].ptr.is_tensor():
                raise GsTaichiSyntaxError("A primitive type cannot decorate an expression with a compound type.")
            # Existing expression: emit a cast to the requested type.
            node.ptr = ti_ops.cast(args[0], func)
        else:
            # Anything else: wrap it in an Expr with dtype ``func``.
            node.ptr = expr.Expr(args[0], dtype=func)
        return True

    @staticmethod
    def is_external_func(ctx: ASTTransformerContext, func) -> bool:
        """Return True if calling ``func`` would run code the AST transformer never sees.

        A call is not considered external in any of these cases:
        - it is made inside a static scope;
        - the callee is marked as a GsTaichi function or wrapped kernel;
        - the callee's ``__module__`` starts with ``"gstaichi."``.
        """
        if ctx.is_in_static_scope():  # allow external function in static scope
            return False
        if hasattr(func, "_is_gstaichi_function") or hasattr(func, "_is_wrapped_kernel"):  # gstaichi func/kernel
            return False
        module_name = getattr(func, "__module__", None)
        if module_name and module_name.startswith("gstaichi."):
            return False
        return True

    @staticmethod
    def warn_if_is_external_func(ctx: ASTTransformerContext, node):
        """Emit a ``SyntaxWarning`` at the call site when the callee is external.

        The warning is attributed to ``ctx.file`` at the call's line, offset by
        ``ctx.lineno_offset``.
        """
        func = node.func.ptr
        if not CallTransformer.is_external_func(ctx, func):
            return
        name = unparse(node.func).strip()
        warnings.warn_explicit(
            "\x1b[38;5;226m"  # Yellow
            f'Calling non-gstaichi function "{name}". '
            "Scope inside the function is not processed by the GsTaichi AST transformer. "
            "The function may not work as expected. Proceed with caution! "
            "Maybe you can consider turning it into a @ti.func?"
            "\x1b[0m",  # Reset
            SyntaxWarning,
            ctx.file,
            node.lineno + ctx.lineno_offset,
            module="gstaichi",
        )

    @staticmethod
    def canonicalize_formatted_string(raw_string: str, *raw_args: list, **raw_keywords: dict):
        """Canonicalize a ``str.format`` call into the argument list ``ti.format`` expects.

        Every replacement field in ``raw_string`` becomes an empty ``{}``, and the
        referenced arguments are emitted in field order. Arguments may be
        referenced several times. A field with a non-empty format spec is emitted
        as ``["__ti_fmt_value__", value, spec]``. The spec is everything after the
        first ``:``, so specs that themselves contain ``:`` are kept intact.

        Example::

            origin input: 'qwerty {1} {} {1:.3f} {k:.4f} {k:}'.format(1.0, 2.0, k=k)
            raw_string:   'qwerty {1} {} {1:.3f} {k:.4f} {k:}'
            raw_args:     [1.0, 2.0]
            raw_keywords: {'k': <ti.Expr>}
            return value: ['qwerty {} {} {} {} {}', 2.0, 1.0,
                           ['__ti_fmt_value__', 2.0, '.3f'],
                           ['__ti_fmt_value__', <ti.Expr>, '.4f'], <ti.Expr>]

        Note: fields are found with a non-greedy ``{...}`` regex, so literal
        ``{{`` / ``}}`` escapes are not recognised.

        Raises:
            GsTaichiSyntaxError: if the highest positional index referenced does
                not match the number of positional arguments, if a referenced
                keyword is missing, or if a supplied keyword is never used.
        """
        brackets = []
        unnamed = 0
        for bracket in re.findall(r"{(.*?)}", raw_string):
            # Split on the first ':' only: the spec may itself contain ':'.
            item, _, spec = bracket.partition(":")
            if item.isdigit():
                item = int(item)
            # handle unnamed positional args
            if item == "":
                item = unnamed
                unnamed += 1
            # an absent or empty spec means "no format spec"
            brackets.append((item, spec or None))

        # check for errors in the arguments
        max_args_index = max((item for item, _ in brackets if isinstance(item, int)), default=-1)
        if max_args_index + 1 != len(raw_args):
            raise GsTaichiSyntaxError(
                f"Expected {max_args_index + 1} positional argument(s), but received {len(raw_args)} instead."
            )
        brackets_keywords = [item for item, _ in brackets if isinstance(item, str)]
        for item in brackets_keywords:
            if item not in raw_keywords:
                raise GsTaichiSyntaxError(f"Keyword '{item}' not found.")
        for item in raw_keywords:
            if item not in brackets_keywords:
                raise GsTaichiSyntaxError(f"Keyword '{item}' not used.")

        # The canonical format string goes first to make ti.format() happy, followed
        # by the arguments reorganized by position, keyword and format spec.
        args = [re.sub(r"{.*?}", "{}", raw_string)]
        for item, spec in brackets:
            new_arg = raw_args[item] if isinstance(item, int) else raw_keywords[item]
            args.append(new_arg if spec is None else ["__ti_fmt_value__", new_arg, spec])
        return args

    @staticmethod
    def expand_node_args_dataclasses(args: tuple[ast.AST, ...]) -> tuple[ast.AST, ...]:
        """Replace each dataclass-valued argument node with one ``ast.Name`` per field.

        For an argument whose ``.ptr`` is a dataclass, one ``ast.Name`` is
        generated per field. Each is named ``__ti_<arg.id>_<field>``, loaded in
        ``Load`` context and given the argument's source position. All other
        arguments are passed through unchanged.

        NOTE(review): assumes dataclass-valued arguments are ``ast.Name`` nodes
        (``arg.id`` is read) — confirm against callers.
        """
        args_new = []
        for arg in args:
            val = arg.ptr
            if not dataclasses.is_dataclass(val):
                args_new.append(arg)
                continue
            for field in dataclasses.fields(val):
                args_new.append(
                    ast.Name(
                        id=f"__ti_{arg.id}_{field.name}",
                        ctx=ast.Load(),
                        lineno=arg.lineno,
                        end_lineno=arg.end_lineno,
                        col_offset=arg.col_offset,
                        end_col_offset=arg.end_col_offset,
                    )
                )
        return tuple(args_new)

    @staticmethod
    def build_Call(ctx: ASTTransformerContext, node: ast.Call, build_stmt, build_stmts):
        """Lower a call expression; the resulting value is stored on and returned as ``node.ptr``.

        The callee and arguments are built first. If the callee resolves to
        ``ti.static`` or ``ti.static_assert``, this happens inside a static scope.
        Otherwise, dataclass arguments are expanded into per-field names. Starred
        tensor ``Expr`` arguments are then expanded into element Exprs.

        The call is dispatched to the first applicable case:
        - ``"<literal>".format(...)`` → ``ti.format``;
        - ``Matrix``/``Vector`` construction;
        - builtin replacement;
        - primitive-type cast;
        - a callee node carrying a ``caller`` attribute, which is passed as the
          first argument;
        - a plain call, which warns first if the callee is external.

        Args:
            build_stmt: callback that lowers a single AST node.
            build_stmts: callback that lowers a sequence of AST nodes.

        Raises:
            GsTaichiTypeError: wraps any ``TypeError`` raised by the plain call.
            GsTaichiSyntaxError: from format-string canonicalization or
                primitive-type casts.
        """
        if get_decorator(ctx, node) in ["static", "static_assert"]:
            with ctx.static_scope_guard():
                build_stmt(ctx, node.func)
                build_stmts(ctx, node.args)
                build_stmts(ctx, node.keywords)
        else:
            build_stmt(ctx, node.func)
            # creates variable for the dataclass itself (as well as other variables,
            # not related to dataclasses). Necessary for calling further child functions
            build_stmts(ctx, node.args)
            node.args = CallTransformer.expand_node_args_dataclasses(node.args)
            # create variables for the now-expanded dataclass members
            build_stmts(ctx, node.args)
            build_stmts(ctx, node.keywords)

        args = []
        for arg in node.args:
            if isinstance(arg, ast.Starred):
                arg_list = arg.ptr
                if isinstance(arg_list, Expr) and arg_list.is_tensor():
                    # Expand Expr with Matrix-type return into list of Exprs
                    arg_list = [Expr(x) for x in ctx.ast_builder.expand_exprs([arg_list.ptr])]
                args.extend(arg_list)
            else:
                args.append(arg.ptr)
        keywords = dict(ChainMap(*[keyword.ptr for keyword in node.keywords]))
        func = node.func.ptr

        if func is print or func is impl.ti_print:
            ctx.func.has_print = True

        if isinstance(node.func, ast.Attribute) and isinstance(node.func.value.ptr, str) and node.func.attr == "format":
            raw_string = node.func.value.ptr
            args = CallTransformer.canonicalize_formatted_string(raw_string, *args, **keywords)
            node.ptr = impl.ti_format(*args)
            return node.ptr

        if func is Matrix or func is Vector:
            node.ptr = matrix.make_matrix(*args, **keywords)
            return node.ptr

        if CallTransformer.build_call_if_is_builtin(ctx, node, args, keywords):
            return node.ptr

        if CallTransformer.build_call_if_is_type(ctx, node, args, keywords):
            return node.ptr

        if hasattr(node.func, "caller"):
            node.ptr = func(node.func.caller, *args, **keywords)
            return node.ptr

        CallTransformer.warn_if_is_external_func(ctx, node)
        try:
            node.ptr = func(*args, **keywords)
        except TypeError as e:
            module = inspect.getmodule(func)
            error_msg = re.sub(r"\bExpr\b", "GsTaichi Expression", str(e))
            # Callables without __name__ fall back to their class name.
            func_name = getattr(func, "__name__", func.__class__.__name__)
            msg = f"TypeError when calling `{func_name}`: {error_msg}."
            if CallTransformer.is_external_func(ctx, node.func.ptr):
                args_has_expr = any(isinstance(arg, Expr) for arg in args)
                if args_has_expr and (module == math or module == np):
                    # Best-effort hint: only suggest ti.<name> if gstaichi exports it.
                    try:
                        exec(f"from gstaichi import {func_name}", {})
                    except (ImportError, SyntaxError):
                        pass
                    else:
                        msg += f"\nDid you mean to use `ti.{func_name}` instead of `{module.__name__}.{func_name}`?"
            raise GsTaichiTypeError(msg) from e

        if getattr(func, "_is_gstaichi_function", False):
            ctx.func.has_print |= func.wrapper.has_print

        return node.ptr
|