gstaichi 0.1.23.dev0__cp310-cp310-macosx_15_0_arm64.whl → 0.1.25.dev0__cp310-cp310-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {taichi → gstaichi}/__init__.py +9 -13
- {taichi → gstaichi}/_funcs.py +8 -8
- {taichi → gstaichi}/_kernels.py +19 -19
- gstaichi/_lib/__init__.py +3 -0
- taichi/_lib/core/taichi_python.cpython-310-darwin.so → gstaichi/_lib/core/gstaichi_python.cpython-310-darwin.so +0 -0
- taichi/_lib/core/taichi_python.pyi → gstaichi/_lib/core/gstaichi_python.pyi +382 -520
- {taichi → gstaichi}/_lib/runtime/runtime_arm64.bc +0 -0
- {taichi → gstaichi}/_lib/utils.py +15 -15
- {taichi → gstaichi}/_logging.py +1 -1
- {taichi → gstaichi}/_main.py +24 -31
- gstaichi/_snode/__init__.py +5 -0
- {taichi → gstaichi}/_snode/fields_builder.py +27 -29
- {taichi → gstaichi}/_snode/snode_tree.py +5 -5
- gstaichi/_test_tools/__init__.py +0 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_version.py +1 -0
- {taichi → gstaichi}/_version_check.py +8 -5
- gstaichi/ad/__init__.py +3 -0
- {taichi → gstaichi}/ad/_ad.py +26 -26
- {taichi → gstaichi}/algorithms/_algorithms.py +7 -7
- {taichi → gstaichi}/examples/minimal.py +1 -1
- {taichi → gstaichi}/experimental.py +1 -1
- gstaichi/lang/__init__.py +50 -0
- {taichi → gstaichi}/lang/_ndarray.py +30 -26
- {taichi → gstaichi}/lang/_ndrange.py +8 -8
- gstaichi/lang/_template_mapper.py +199 -0
- {taichi → gstaichi}/lang/_texture.py +19 -19
- {taichi → gstaichi}/lang/_wrap_inspect.py +7 -7
- {taichi → gstaichi}/lang/any_array.py +13 -13
- {taichi → gstaichi}/lang/argpack.py +29 -29
- gstaichi/lang/ast/__init__.py +5 -0
- {taichi → gstaichi}/lang/ast/ast_transformer.py +94 -582
- {taichi → gstaichi}/lang/ast/ast_transformer_utils.py +54 -41
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
- {taichi → gstaichi}/lang/ast/checkers.py +5 -5
- gstaichi/lang/ast/transform.py +9 -0
- {taichi → gstaichi}/lang/common_ops.py +12 -12
- gstaichi/lang/exception.py +80 -0
- {taichi → gstaichi}/lang/expr.py +22 -22
- {taichi → gstaichi}/lang/field.py +29 -27
- {taichi → gstaichi}/lang/impl.py +116 -121
- {taichi → gstaichi}/lang/kernel_arguments.py +16 -16
- {taichi → gstaichi}/lang/kernel_impl.py +330 -363
- {taichi → gstaichi}/lang/matrix.py +119 -115
- {taichi → gstaichi}/lang/matrix_ops.py +6 -6
- {taichi → gstaichi}/lang/matrix_ops_utils.py +4 -4
- {taichi → gstaichi}/lang/mesh.py +22 -22
- {taichi → gstaichi}/lang/misc.py +39 -68
- {taichi → gstaichi}/lang/ops.py +146 -141
- {taichi → gstaichi}/lang/runtime_ops.py +2 -2
- {taichi → gstaichi}/lang/shell.py +3 -3
- {taichi → gstaichi}/lang/simt/__init__.py +1 -1
- {taichi → gstaichi}/lang/simt/block.py +7 -7
- {taichi → gstaichi}/lang/simt/grid.py +1 -1
- {taichi → gstaichi}/lang/simt/subgroup.py +1 -1
- {taichi → gstaichi}/lang/simt/warp.py +1 -1
- {taichi → gstaichi}/lang/snode.py +46 -44
- {taichi → gstaichi}/lang/source_builder.py +13 -13
- {taichi → gstaichi}/lang/struct.py +33 -33
- {taichi → gstaichi}/lang/util.py +24 -24
- gstaichi/linalg/__init__.py +8 -0
- {taichi → gstaichi}/linalg/matrixfree_cg.py +14 -14
- {taichi → gstaichi}/linalg/sparse_cg.py +10 -10
- {taichi → gstaichi}/linalg/sparse_matrix.py +23 -23
- {taichi → gstaichi}/linalg/sparse_solver.py +21 -21
- {taichi → gstaichi}/math/__init__.py +1 -1
- {taichi → gstaichi}/math/_complex.py +21 -20
- {taichi → gstaichi}/math/mathimpl.py +56 -56
- gstaichi/profiler/__init__.py +6 -0
- {taichi → gstaichi}/profiler/kernel_metrics.py +11 -11
- {taichi → gstaichi}/profiler/kernel_profiler.py +30 -36
- {taichi → gstaichi}/profiler/memory_profiler.py +1 -1
- {taichi → gstaichi}/profiler/scoped_profiler.py +2 -2
- {taichi → gstaichi}/sparse/_sparse_grid.py +7 -7
- {taichi → gstaichi}/tools/__init__.py +4 -4
- {taichi → gstaichi}/tools/diagnose.py +10 -17
- gstaichi/types/__init__.py +19 -0
- {taichi → gstaichi}/types/annotations.py +1 -1
- {taichi → gstaichi}/types/compound_types.py +8 -8
- {taichi → gstaichi}/types/enums.py +1 -1
- {taichi → gstaichi}/types/ndarray_type.py +7 -7
- {taichi → gstaichi}/types/primitive_types.py +17 -14
- {taichi → gstaichi}/types/quant.py +9 -9
- {taichi → gstaichi}/types/texture_type.py +5 -5
- {taichi → gstaichi}/types/utils.py +1 -1
- {gstaichi-0.1.23.dev0.dist-info → gstaichi-0.1.25.dev0.dist-info}/METADATA +13 -16
- gstaichi-0.1.25.dev0.dist-info/RECORD +168 -0
- gstaichi-0.1.25.dev0.dist-info/entry_points.txt +2 -0
- gstaichi-0.1.25.dev0.dist-info/top_level.txt +1 -0
- gstaichi-0.1.23.dev0.dist-info/RECORD +0 -219
- gstaichi-0.1.23.dev0.dist-info/entry_points.txt +0 -2
- gstaichi-0.1.23.dev0.dist-info/top_level.txt +0 -1
- taichi/_lib/__init__.py +0 -3
- taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +0 -1401
- taichi/_lib/c_api/include/taichi/taichi.h +0 -29
- taichi/_lib/c_api/include/taichi/taichi_core.h +0 -1111
- taichi/_lib/c_api/include/taichi/taichi_cpu.h +0 -29
- taichi/_lib/c_api/include/taichi/taichi_metal.h +0 -72
- taichi/_lib/c_api/include/taichi/taichi_platform.h +0 -55
- taichi/_lib/c_api/include/taichi/taichi_unity.h +0 -64
- taichi/_lib/c_api/include/taichi/taichi_vulkan.h +0 -151
- taichi/_lib/c_api/lib/libtaichi_c_api.dylib +0 -0
- taichi/_lib/c_api/runtime/runtime_arm64.bc +0 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +0 -29
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +0 -65
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +0 -121
- taichi/_lib/runtime/libMoltenVK.dylib +0 -0
- taichi/_snode/__init__.py +0 -5
- taichi/_ti_module/__init__.py +0 -3
- taichi/_ti_module/cppgen.py +0 -309
- taichi/_ti_module/module.py +0 -145
- taichi/_version.py +0 -1
- taichi/ad/__init__.py +0 -3
- taichi/aot/__init__.py +0 -12
- taichi/aot/_export.py +0 -28
- taichi/aot/conventions/__init__.py +0 -3
- taichi/aot/conventions/gfxruntime140/__init__.py +0 -38
- taichi/aot/conventions/gfxruntime140/dr.py +0 -244
- taichi/aot/conventions/gfxruntime140/sr.py +0 -613
- taichi/aot/module.py +0 -253
- taichi/aot/utils.py +0 -151
- taichi/graph/__init__.py +0 -3
- taichi/graph/_graph.py +0 -292
- taichi/lang/__init__.py +0 -50
- taichi/lang/ast/__init__.py +0 -5
- taichi/lang/ast/transform.py +0 -9
- taichi/lang/exception.py +0 -80
- taichi/linalg/__init__.py +0 -8
- taichi/profiler/__init__.py +0 -6
- taichi/shaders/Circles_vk.frag +0 -29
- taichi/shaders/Circles_vk.vert +0 -45
- taichi/shaders/Circles_vk_frag.spv +0 -0
- taichi/shaders/Circles_vk_vert.spv +0 -0
- taichi/shaders/Lines_vk.frag +0 -9
- taichi/shaders/Lines_vk.vert +0 -11
- taichi/shaders/Lines_vk_frag.spv +0 -0
- taichi/shaders/Lines_vk_vert.spv +0 -0
- taichi/shaders/Mesh_vk.frag +0 -71
- taichi/shaders/Mesh_vk.vert +0 -68
- taichi/shaders/Mesh_vk_frag.spv +0 -0
- taichi/shaders/Mesh_vk_vert.spv +0 -0
- taichi/shaders/Particles_vk.frag +0 -95
- taichi/shaders/Particles_vk.vert +0 -73
- taichi/shaders/Particles_vk_frag.spv +0 -0
- taichi/shaders/Particles_vk_vert.spv +0 -0
- taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
- taichi/shaders/SceneLines_vk.frag +0 -9
- taichi/shaders/SceneLines_vk.vert +0 -12
- taichi/shaders/SceneLines_vk_frag.spv +0 -0
- taichi/shaders/SceneLines_vk_vert.spv +0 -0
- taichi/shaders/SetImage_vk.frag +0 -21
- taichi/shaders/SetImage_vk.vert +0 -15
- taichi/shaders/SetImage_vk_frag.spv +0 -0
- taichi/shaders/SetImage_vk_vert.spv +0 -0
- taichi/shaders/Triangles_vk.frag +0 -16
- taichi/shaders/Triangles_vk.vert +0 -29
- taichi/shaders/Triangles_vk_frag.spv +0 -0
- taichi/shaders/Triangles_vk_vert.spv +0 -0
- taichi/shaders/lines2quad_vk_comp.spv +0 -0
- taichi/types/__init__.py +0 -19
- {taichi → gstaichi}/__main__.py +0 -0
- {taichi → gstaichi}/_lib/core/__init__.py +0 -0
- {taichi → gstaichi}/_lib/core/py.typed +0 -0
- {taichi/_lib/c_api → gstaichi/_lib}/runtime/libMoltenVK.dylib +0 -0
- {taichi → gstaichi}/algorithms/__init__.py +0 -0
- {taichi → gstaichi}/assets/.git +0 -0
- {taichi → gstaichi}/assets/Go-Regular.ttf +0 -0
- {taichi → gstaichi}/assets/static/imgs/ti_gallery.png +0 -0
- {taichi → gstaichi}/lang/ast/symbol_resolver.py +0 -0
- {taichi → gstaichi}/sparse/__init__.py +0 -0
- {taichi → gstaichi}/tools/np2ply.py +0 -0
- {taichi → gstaichi}/tools/vtk.py +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/GLFW/glfw3.h +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/GLFW/glfw3native.h +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/instrument.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/libspirv.h +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/libspirv.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/linker.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/optimizer.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/GLSL.std.450.h +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv.h +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_cfg.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_common.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_cpp.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_cross.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_cross_c.h +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_cross_containers.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_cross_error_handling.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_cross_util.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_glsl.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_hlsl.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_msl.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_parser.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_reflect.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/glfw3/glfw3Config.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/glfw3/glfw3Targets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/glfw3/glfw3Targets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.dist-info → gstaichi-0.1.25.dev0.dist-info}/WHEEL +0 -0
- {gstaichi-0.1.23.dev0.dist-info → gstaichi-0.1.25.dev0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,320 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import ast
|
4
|
+
import dataclasses
|
5
|
+
from typing import Any, Callable
|
6
|
+
|
7
|
+
from gstaichi.lang import (
|
8
|
+
_ndarray,
|
9
|
+
any_array,
|
10
|
+
expr,
|
11
|
+
impl,
|
12
|
+
kernel_arguments,
|
13
|
+
matrix,
|
14
|
+
)
|
15
|
+
from gstaichi.lang import ops as ti_ops
|
16
|
+
from gstaichi.lang.argpack import ArgPackType
|
17
|
+
from gstaichi.lang.ast.ast_transformer_utils import (
|
18
|
+
ASTTransformerContext,
|
19
|
+
)
|
20
|
+
from gstaichi.lang.exception import (
|
21
|
+
GsTaichiSyntaxError,
|
22
|
+
)
|
23
|
+
from gstaichi.lang.matrix import MatrixType
|
24
|
+
from gstaichi.lang.struct import StructType
|
25
|
+
from gstaichi.lang.util import to_gstaichi_type
|
26
|
+
from gstaichi.types import annotations, ndarray_type, primitive_types, texture_type
|
27
|
+
|
28
|
+
|
29
|
+
class FunctionDefTransformer:
    """Static helpers that process an ``ast.FunctionDef`` node for the AST transformer.

    The single entry point is :meth:`build_FunctionDef`. It declares or binds
    the arguments of the function being compiled (kernel-style or func-style,
    depending on flags read from ``ctx``) and then builds the function body.
    """

    @staticmethod
    def _decl_and_create_variable(
        ctx: ASTTransformerContext, annotation, name, arg_features, invoke_later_dict, prefix_name, arg_depth
    ) -> tuple[bool, Any]:
        """Declare one kernel argument (or one ArgPack member) according to its annotation.

        Args:
            ctx: Transformer context; ``ctx.kernel_args`` and ``ctx.global_vars`` are used here.
            annotation: The type annotation of the argument.
            name: The argument (or member) name.
            arg_features: Per-argument features, indexed differently depending on ``annotation``.
            invoke_later_dict: Receives deferred declarations for ArgPack members, stored as
                ``(argpack, member_name, decl_func, decl_args)``.
            prefix_name: Prefix joined to ``name`` with ``"_"`` to form the full name.
            arg_depth: Nesting depth forwarded to the matrix/struct/scalar declaration helpers.

        Returns:
            ``(True, obj)`` when the declared object is available immediately, or
            ``(False, (decl_func, decl_args))`` when the caller is expected to call
            ``decl_func(*decl_args)`` itself.
        """
        full_name = prefix_name + "_" + name
        if not isinstance(annotation, primitive_types.RefType):
            ctx.kernel_args.append(name)
        if isinstance(annotation, ArgPackType):
            kernel_arguments.push_argpack_arg(name)
            d = {}
            items_to_put_in_dict = []
            for j, (_name, anno) in enumerate(annotation.members.items()):
                result, obj = FunctionDefTransformer._decl_and_create_variable(
                    ctx, anno, _name, arg_features[j], invoke_later_dict, full_name, arg_depth + 1
                )
                if not result:
                    # Deferred member: leave a placeholder now, remember how to fill it in later.
                    d[_name] = None
                    items_to_put_in_dict.append((full_name + "_" + _name, _name, obj))
                else:
                    d[_name] = obj
            argpack = kernel_arguments.decl_argpack_arg(annotation, d)
            for item in items_to_put_in_dict:
                # item[2] is (decl_func, decl_args); flatten it into the stored 4-tuple.
                invoke_later_dict[item[0]] = argpack, item[1], *item[2]
            return True, argpack
        if annotation == annotations.template or isinstance(annotation, annotations.template):
            return True, ctx.global_vars[name]
        if isinstance(annotation, annotations.sparse_matrix_builder):
            return False, (
                kernel_arguments.decl_sparse_matrix,
                (
                    to_gstaichi_type(arg_features),
                    full_name,
                ),
            )
        if isinstance(annotation, ndarray_type.NdarrayType):
            return False, (
                kernel_arguments.decl_ndarray_arg,
                (
                    to_gstaichi_type(arg_features[0]),
                    arg_features[1],
                    full_name,
                    arg_features[2],
                    arg_features[3],
                ),
            )
        if isinstance(annotation, texture_type.TextureType):
            return False, (kernel_arguments.decl_texture_arg, (arg_features[0], full_name))
        if isinstance(annotation, texture_type.RWTextureType):
            return False, (
                kernel_arguments.decl_rw_texture_arg,
                (arg_features[0], arg_features[1], arg_features[2], full_name),
            )
        if isinstance(annotation, MatrixType):
            return True, kernel_arguments.decl_matrix_arg(annotation, name, arg_depth)
        if isinstance(annotation, StructType):
            return True, kernel_arguments.decl_struct_arg(annotation, name, arg_depth)
        return True, kernel_arguments.decl_scalar_arg(annotation, name, arg_depth)

    @staticmethod
    def _transform_kernel_arg(
        ctx: ASTTransformerContext,
        invoke_later_dict: dict[str, tuple[Any, str, Callable, list[Any]]],
        create_variable_later: dict[str, Any],
        argument_name: str,
        argument_type: Any,
        this_arg_features: tuple[Any, ...],
    ) -> None:
        """Declare a single kernel argument and bind it as a variable in ``ctx``.

        Three cases are handled:

        * ``ArgPackType``: members are declared one by one; the resulting argpack is
          stored in ``create_variable_later`` rather than bound immediately.
        * dataclass types: the type itself is bound under ``argument_name`` and each
          field is declared and bound under ``__ti_{argument_name}_{field.name}``.
        * anything else: declared via :meth:`_decl_and_create_variable` and bound
          under ``argument_name``.

        Args:
            ctx: Transformer context used to create variables.
            invoke_later_dict: Collects deferred ArgPack member declarations.
            create_variable_later: Collects argpacks to be bound by the caller.
            argument_name: Name of the argument.
            argument_type: Annotation of the argument.
            this_arg_features: Features for this argument (indexed per member/field
                for ArgPack and dataclass arguments).
        """
        if isinstance(argument_type, ArgPackType):
            kernel_arguments.push_argpack_arg(argument_name)
            d = {}
            items_to_put_in_dict: list[tuple[str, str, Any]] = []
            for j, (name, anno) in enumerate(argument_type.members.items()):
                result, obj = FunctionDefTransformer._decl_and_create_variable(
                    ctx, anno, name, this_arg_features[j], invoke_later_dict, "__argpack_" + name, 1
                )
                if not result:
                    d[name] = None
                    items_to_put_in_dict.append(("__argpack_" + name, name, obj))
                else:
                    d[name] = obj
            argpack = kernel_arguments.decl_argpack_arg(argument_type, d)
            for item in items_to_put_in_dict:
                invoke_later_dict[item[0]] = argpack, item[1], *item[2]
            create_variable_later[argument_name] = argpack
        elif dataclasses.is_dataclass(argument_type):
            arg_features = this_arg_features
            ctx.create_variable(argument_name, argument_type)
            for field_idx, field in enumerate(dataclasses.fields(argument_type)):
                flat_name = f"__ti_{argument_name}_{field.name}"
                result, obj = FunctionDefTransformer._decl_and_create_variable(
                    ctx,
                    field.type,
                    flat_name,
                    arg_features[field_idx],
                    invoke_later_dict,
                    "",
                    0,
                )
                if not result:
                    # Deferred declaration: perform it now to obtain the object to bind.
                    decl_type_func, type_args = obj
                    obj = decl_type_func(*type_args)
                ctx.create_variable(flat_name, obj)
        else:
            result, obj = FunctionDefTransformer._decl_and_create_variable(
                ctx,
                argument_type,
                argument_name,
                this_arg_features if ctx.arg_features is not None else None,
                invoke_later_dict,
                "",
                0,
            )
            if not result:
                # Deferred declaration: perform it now to obtain the object to bind.
                decl_type_func, type_args = obj
                obj = decl_type_func(*type_args)
            ctx.create_variable(argument_name, obj)

    @staticmethod
    def _transform_as_kernel(ctx: ASTTransformerContext, node: ast.FunctionDef, args: ast.arguments) -> None:
        """Declare return values and all arguments for a kernel-style callable.

        After every argument has been processed, the deferred ArgPack member
        declarations are executed, the argpacks are bound as variables, the
        parameters are finalized on the compiling callable, and the original
        positional arguments are removed from ``node``.
        """
        if node.returns is not None and not isinstance(node.returns, ast.Constant):
            for return_type in ctx.func.return_type:
                kernel_arguments.decl_ret(return_type)
        impl.get_runtime().compiling_callable.finalize_rets()

        # Values are (argpack, member_name, decl_func, decl_args), matching the
        # annotation used by _transform_kernel_arg.
        invoke_later_dict: dict[str, tuple[Any, str, Callable, list[Any]]] = dict()
        create_variable_later = dict()
        for i, _ in enumerate(args.args):
            argument = ctx.func.arguments[i]
            FunctionDefTransformer._transform_kernel_arg(
                ctx,
                invoke_later_dict,
                create_variable_later,
                argument.name,
                argument.annotation,
                ctx.arg_features[i] if ctx.arg_features is not None else (),
            )

        for argpack, name, func, params in invoke_later_dict.values():
            argpack[name] = func(*params)
        for var_name, var_value in create_variable_later.items():
            ctx.create_variable(var_name, var_value)

        impl.get_runtime().compiling_callable.finalize_params()
        # remove original args
        node.args.args = []

    @staticmethod
    def _transform_func_arg(
        ctx: ASTTransformerContext,
        argument_name: str,
        argument_type: Any,
        data: Any,
    ) -> None:
        """Bind one argument of a func-style callable to the value passed at the call site.

        Args:
            ctx: Transformer context used to create variables.
            argument_name: Name of the argument.
            argument_type: Annotation of the argument.
            data: The value supplied for the argument.

        Raises:
            GsTaichiSyntaxError: If ``data`` (or a dataclass field of it) is not one of the
                accepted ndarray types for an ndarray annotation, or if a matrix argument
                is not a tensor expression of the annotated ndim / n / m.
        """
        if isinstance(argument_type, annotations.template):
            ctx.create_variable(argument_name, data)
            return None

        if dataclasses.is_dataclass(argument_type):
            dataclass_type = argument_type
            for field in dataclasses.fields(dataclass_type):
                data_child = getattr(data, field.name)
                if not isinstance(
                    data_child,
                    (
                        _ndarray.ScalarNdarray,
                        matrix.VectorNdarray,
                        matrix.MatrixNdarray,
                        any_array.AnyArray,
                    ),
                ):
                    raise GsTaichiSyntaxError(
                        f"Argument {argument_name} of type {dataclass_type} {field.type} is not recognized."
                    )
                field.type.check_matched(data_child.get_type(), field.name)
                var_name = f"__ti_{argument_name}_{field.name}"
                ctx.create_variable(var_name, data_child)
            return None

        # Ndarray arguments are passed by reference.
        if isinstance(argument_type, (ndarray_type.NdarrayType)):
            if not isinstance(
                data,
                (
                    _ndarray.ScalarNdarray,
                    matrix.VectorNdarray,
                    matrix.MatrixNdarray,
                    any_array.AnyArray,
                ),
            ):
                # Use argument_name: there is no `arg` in this scope, so referencing
                # `arg.arg` here would raise NameError instead of the intended error.
                raise GsTaichiSyntaxError(f"Argument {argument_name} of type {argument_type} is not recognized.")
            argument_type.check_matched(data.get_type(), argument_name)
            ctx.create_variable(argument_name, data)
            return None

        # Matrix arguments are passed by value.
        if isinstance(argument_type, (MatrixType)):
            var_name = argument_name
            # "data" is expected to be an Expr here,
            # so we simply call "impl.expr_init_func(data)" to perform:
            #
            # TensorType* t = alloca()
            # assign(t, data)
            #
            # We created local variable "t" - a copy of the passed-in argument "data"
            if not isinstance(data, expr.Expr) or not data.ptr.is_tensor():
                raise GsTaichiSyntaxError(
                    f"Argument {var_name} of type {argument_type} is expected to be a Matrix, but got {type(data)}."
                )

            element_shape = data.ptr.get_rvalue_type().shape()
            if len(element_shape) != argument_type.ndim:
                raise GsTaichiSyntaxError(
                    f"Argument {var_name} of type {argument_type} is expected to be a Matrix with ndim {argument_type.ndim}, but got {len(element_shape)}."
                )

            assert argument_type.ndim > 0
            if element_shape[0] != argument_type.n:
                raise GsTaichiSyntaxError(
                    f"Argument {var_name} of type {argument_type} is expected to be a Matrix with n {argument_type.n}, but got {element_shape[0]}."
                )

            if argument_type.ndim == 2 and element_shape[1] != argument_type.m:
                # Report the second dimension, which is the one that was compared against m.
                raise GsTaichiSyntaxError(
                    f"Argument {var_name} of type {argument_type} is expected to be a Matrix with m {argument_type.m}, but got {element_shape[1]}."
                )

            ctx.create_variable(var_name, impl.expr_init_func(data))
            return None

        if id(argument_type) in primitive_types.type_ids:
            var_name = argument_name
            ctx.create_variable(var_name, impl.expr_init_func(ti_ops.cast(data, argument_type)))
            return None
        # Create a copy for non-template arguments,
        # so that they are passed by value.
        var_name = argument_name
        ctx.create_variable(var_name, impl.expr_init_func(data))
        return None

    @staticmethod
    def _transform_as_func(ctx: ASTTransformerContext, node: ast.FunctionDef, args: ast.arguments) -> None:
        """Bind every entry of ``ctx.argument_data`` for a func-style callable.

        Each value is paired by position with ``ctx.func.arguments``. Afterwards,
        every dataclass-annotated entry of ``ctx.func.orig_arguments`` is bound
        under its own name to the dataclass type.
        """
        for data_i, data in enumerate(ctx.argument_data):
            argument = ctx.func.arguments[data_i]
            FunctionDefTransformer._transform_func_arg(
                ctx,
                argument.name,
                argument.annotation,
                data,
            )

        for v in ctx.func.orig_arguments:
            if dataclasses.is_dataclass(v.annotation):
                ctx.create_variable(v.name, v.annotation)

    @staticmethod
    def build_FunctionDef(
        ctx: ASTTransformerContext,
        node: ast.FunctionDef,
        build_stmts: Callable[[ASTTransformerContext, list[ast.stmt]], None],
    ) -> None:
        """Process a function definition node and build its body.

        Args:
            ctx: Transformer context for the callable being compiled.
            node: The function definition to process.
            build_stmts: Callback invoked with ``ctx`` and ``node.body`` to build the statements.

        Raises:
            GsTaichiSyntaxError: If a function definition has already been visited in
                this context (i.e. a nested definition was encountered).
        """
        if ctx.visited_funcdef:
            raise GsTaichiSyntaxError(
                f"Function definition is not allowed in 'ti.{'kernel' if ctx.is_kernel else 'func'}'."
            )
        ctx.visited_funcdef = True

        args = node.args
        assert args.vararg is None
        assert args.kwonlyargs == []
        assert args.kw_defaults == []
        assert args.kwarg is None

        # Kernels and real functions share the kernel-style argument handling;
        # all other funcs bind the values passed at the call site.
        if ctx.is_kernel or ctx.is_real_function:
            FunctionDefTransformer._transform_as_kernel(ctx, node, args)
        else:
            FunctionDefTransformer._transform_as_func(ctx, node, args)

        with ctx.variable_scope_guard():
            build_stmts(ctx, node.body)

        return None
|
@@ -2,8 +2,8 @@
|
|
2
2
|
|
3
3
|
import ast
|
4
4
|
|
5
|
-
from
|
6
|
-
from
|
5
|
+
from gstaichi.lang._wrap_inspect import getsourcefile, getsourcelines
|
6
|
+
from gstaichi.lang.exception import GsTaichiSyntaxError
|
7
7
|
|
8
8
|
|
9
9
|
class KernelSimplicityASTChecker(ast.NodeVisitor):
|
@@ -62,7 +62,7 @@ class KernelSimplicityASTChecker(ast.NodeVisitor):
|
|
62
62
|
if not isinstance(node, ast.stmt):
|
63
63
|
return False
|
64
64
|
# TODO(#536): Frontend pass should help make sure |func| is a valid AST for
|
65
|
-
#
|
65
|
+
# GsTaichi.
|
66
66
|
ignored = [ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef]
|
67
67
|
return not any(map(lambda t: isinstance(node, t), ignored))
|
68
68
|
|
@@ -72,7 +72,7 @@ class KernelSimplicityASTChecker(ast.NodeVisitor):
|
|
72
72
|
return
|
73
73
|
|
74
74
|
if not (self.top_level or self.current_scope.allows_more_stmt):
|
75
|
-
raise
|
75
|
+
raise GsTaichiSyntaxError(f"No more statements allowed, at {self.get_error_location(node)}")
|
76
76
|
old_top_level = self.top_level
|
77
77
|
if old_top_level:
|
78
78
|
self._scope_guards.append(self.new_scope())
|
@@ -96,7 +96,7 @@ class KernelSimplicityASTChecker(ast.NodeVisitor):
|
|
96
96
|
# and node.iter.func.attr == 'static')
|
97
97
|
# if not (self.top_level or self.current_scope.allows_for_loop
|
98
98
|
# or is_static):
|
99
|
-
# raise
|
99
|
+
# raise GsTaichiSyntaxError(
|
100
100
|
# f'No more for loops allowed, at {self.get_error_location(node)}'
|
101
101
|
# )
|
102
102
|
# with self.new_scope():
|
@@ -0,0 +1,9 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
from gstaichi.lang.ast.ast_transformer import ASTTransformer
|
4
|
+
from gstaichi.lang.ast.ast_transformer_utils import ASTTransformerContext
|
5
|
+
|
6
|
+
|
7
|
+
def transform_tree(tree, ctx: ASTTransformerContext):
    """Apply a fresh ``ASTTransformer`` to ``tree`` using ``ctx`` and return ``ctx.return_data``."""
    transformer = ASTTransformer()
    transformer(ctx, tree)
    return ctx.return_data
|
@@ -2,13 +2,13 @@
|
|
2
2
|
|
3
3
|
from typing import TYPE_CHECKING
|
4
4
|
|
5
|
-
from
|
6
|
-
from
|
7
|
-
from
|
5
|
+
from gstaichi.lang import ops
|
6
|
+
from gstaichi.lang.util import in_python_scope
|
7
|
+
from gstaichi.types import primitive_types
|
8
8
|
|
9
9
|
|
10
|
-
class
|
11
|
-
"""The base class of
|
10
|
+
class GsTaichiOperations:
|
11
|
+
"""The base class of gstaichi operations of expressions. Subclasses: :class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`"""
|
12
12
|
|
13
13
|
if TYPE_CHECKING:
|
14
14
|
# Make pylint happy
|
@@ -124,7 +124,7 @@ class TaichiOperations:
|
|
124
124
|
other (Any): Given operand.
|
125
125
|
|
126
126
|
Returns:
|
127
|
-
:class:`~
|
127
|
+
:class:`~gstaichi.lang.expr.Expr`: The computing expression of atomic add."""
|
128
128
|
return ops.atomic_add(self, other)
|
129
129
|
|
130
130
|
def _atomic_mul(self, other):
|
@@ -134,7 +134,7 @@ class TaichiOperations:
|
|
134
134
|
other (Any): Given operand.
|
135
135
|
|
136
136
|
Returns:
|
137
|
-
:class:`~
|
137
|
+
:class:`~gstaichi.lang.expr.Expr`: The computing expression of atomic mul."""
|
138
138
|
return ops.atomic_mul(self, other)
|
139
139
|
|
140
140
|
def _atomic_sub(self, other):
|
@@ -144,7 +144,7 @@ class TaichiOperations:
|
|
144
144
|
other (Any): Given operand.
|
145
145
|
|
146
146
|
Returns:
|
147
|
-
:class:`~
|
147
|
+
:class:`~gstaichi.lang.expr.Expr`: The computing expression of atomic sub."""
|
148
148
|
return ops.atomic_sub(self, other)
|
149
149
|
|
150
150
|
def _atomic_and(self, other):
|
@@ -154,7 +154,7 @@ class TaichiOperations:
|
|
154
154
|
other (Any): Given operand.
|
155
155
|
|
156
156
|
Returns:
|
157
|
-
:class:`~
|
157
|
+
:class:`~gstaichi.lang.expr.Expr`: The computing expression of atomic and."""
|
158
158
|
return ops.atomic_and(self, other)
|
159
159
|
|
160
160
|
def _atomic_xor(self, other):
|
@@ -164,7 +164,7 @@ class TaichiOperations:
|
|
164
164
|
other (Any): Given operand.
|
165
165
|
|
166
166
|
Returns:
|
167
|
-
:class:`~
|
167
|
+
:class:`~gstaichi.lang.expr.Expr`: The computing expression of atomic xor."""
|
168
168
|
return ops.atomic_xor(self, other)
|
169
169
|
|
170
170
|
def _atomic_or(self, other):
|
@@ -174,7 +174,7 @@ class TaichiOperations:
|
|
174
174
|
other (Any): Given operand.
|
175
175
|
|
176
176
|
Returns:
|
177
|
-
:class:`~
|
177
|
+
:class:`~gstaichi.lang.expr.Expr`: The computing expression of atomic or."""
|
178
178
|
return ops.atomic_or(self, other)
|
179
179
|
|
180
180
|
# In-place operators in python scope returns NotImplemented to fall back to normal operators
|
@@ -264,7 +264,7 @@ class TaichiOperations:
|
|
264
264
|
other (Any): Given operand.
|
265
265
|
|
266
266
|
Returns:
|
267
|
-
:class:`~
|
267
|
+
:class:`~gstaichi.lang.expr.Expr`: The expression after assigning."""
|
268
268
|
return ops.assign(self, other)
|
269
269
|
|
270
270
|
def _augassign(self, x, op):
|
@@ -0,0 +1,80 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
from gstaichi._lib import core
|
4
|
+
|
5
|
+
|
6
|
+
class GsTaichiCompilationError(Exception):
|
7
|
+
"""Base class for all compilation exceptions."""
|
8
|
+
|
9
|
+
pass
|
10
|
+
|
11
|
+
|
12
|
+
class GsTaichiSyntaxError(GsTaichiCompilationError, SyntaxError):
|
13
|
+
"""Thrown when a syntax error is found during compilation."""
|
14
|
+
|
15
|
+
pass
|
16
|
+
|
17
|
+
|
18
|
+
class GsTaichiNameError(GsTaichiCompilationError, NameError):
|
19
|
+
"""Thrown when an undefine name is found during compilation."""
|
20
|
+
|
21
|
+
pass
|
22
|
+
|
23
|
+
|
24
|
+
class GsTaichiIndexError(GsTaichiCompilationError, IndexError):
|
25
|
+
"""Thrown when an index error is found during compilation."""
|
26
|
+
|
27
|
+
pass
|
28
|
+
|
29
|
+
|
30
|
+
class GsTaichiTypeError(GsTaichiCompilationError, TypeError):
|
31
|
+
"""Thrown when a type mismatch is found during compilation."""
|
32
|
+
|
33
|
+
pass
|
34
|
+
|
35
|
+
|
36
|
+
class GsTaichiRuntimeError(RuntimeError):
|
37
|
+
"""Thrown when the compiled program cannot be executed due to unspecified reasons."""
|
38
|
+
|
39
|
+
pass
|
40
|
+
|
41
|
+
|
42
|
+
class GsTaichiAssertionError(GsTaichiRuntimeError, AssertionError):
|
43
|
+
"""Thrown when assertion fails at runtime."""
|
44
|
+
|
45
|
+
pass
|
46
|
+
|
47
|
+
|
48
|
+
class GsTaichiRuntimeTypeError(GsTaichiRuntimeError, TypeError):
|
49
|
+
@staticmethod
|
50
|
+
def get(pos, needed, provided):
|
51
|
+
return GsTaichiRuntimeTypeError(
|
52
|
+
f"Argument {pos} (type={provided}) cannot be converted into required type {needed}"
|
53
|
+
)
|
54
|
+
|
55
|
+
@staticmethod
|
56
|
+
def get_ret(needed, provided):
|
57
|
+
return GsTaichiRuntimeTypeError(f"Return (type={provided}) cannot be converted into required type {needed}")
|
58
|
+
|
59
|
+
|
60
|
+
def handle_exception_from_cpp(exc):
|
61
|
+
if isinstance(exc, core.GsTaichiTypeError):
|
62
|
+
return GsTaichiTypeError(str(exc))
|
63
|
+
if isinstance(exc, core.GsTaichiSyntaxError):
|
64
|
+
return GsTaichiSyntaxError(str(exc))
|
65
|
+
if isinstance(exc, core.GsTaichiIndexError):
|
66
|
+
return GsTaichiIndexError(str(exc))
|
67
|
+
if isinstance(exc, core.GsTaichiAssertionError):
|
68
|
+
return GsTaichiAssertionError(str(exc))
|
69
|
+
return exc
|
70
|
+
|
71
|
+
|
72
|
+
__all__ = [
|
73
|
+
"GsTaichiSyntaxError",
|
74
|
+
"GsTaichiTypeError",
|
75
|
+
"GsTaichiCompilationError",
|
76
|
+
"GsTaichiNameError",
|
77
|
+
"GsTaichiRuntimeError",
|
78
|
+
"GsTaichiRuntimeTypeError",
|
79
|
+
"GsTaichiAssertionError",
|
80
|
+
]
|
{taichi → gstaichi}/lang/expr.py
RENAMED
@@ -2,21 +2,21 @@ from typing import TYPE_CHECKING
|
|
2
2
|
|
3
3
|
import numpy as np
|
4
4
|
|
5
|
-
from
|
6
|
-
from
|
7
|
-
from
|
8
|
-
from
|
9
|
-
from
|
10
|
-
from
|
11
|
-
from
|
12
|
-
from
|
5
|
+
from gstaichi._lib import core as _ti_core
|
6
|
+
from gstaichi.lang import impl
|
7
|
+
from gstaichi.lang.common_ops import GsTaichiOperations
|
8
|
+
from gstaichi.lang.exception import GsTaichiCompilationError, GsTaichiTypeError
|
9
|
+
from gstaichi.lang.matrix import make_matrix
|
10
|
+
from gstaichi.lang.util import is_gstaichi_class, is_matrix_class, to_numpy_type
|
11
|
+
from gstaichi.types import primitive_types
|
12
|
+
from gstaichi.types.primitive_types import integer_types, real_types
|
13
13
|
|
14
14
|
if TYPE_CHECKING:
|
15
|
-
from
|
15
|
+
from gstaichi.lang.ast.ast_transformer_utils import ASTBuilder
|
16
16
|
|
17
17
|
|
18
18
|
# Scalar, basic data type
|
19
|
-
class Expr(
|
19
|
+
class Expr(GsTaichiOperations):
|
20
20
|
"""A Python-side Expr wrapper, whose member variable `ptr` is an instance of C++ Expr class. A C++ Expr object contains member variable `expr` which holds an instance of C++ Expression class."""
|
21
21
|
|
22
22
|
def __init__(self, *args, dbg_info=None, dtype=None):
|
@@ -24,7 +24,7 @@ class Expr(TaichiOperations):
|
|
24
24
|
self.ptr_type_checked = False
|
25
25
|
self.declaration_tb: str = ""
|
26
26
|
if len(args) == 1:
|
27
|
-
if isinstance(args[0], _ti_core.
|
27
|
+
if isinstance(args[0], _ti_core.ExprCxx):
|
28
28
|
self.ptr = args[0]
|
29
29
|
elif isinstance(args[0], Expr):
|
30
30
|
self.ptr = args[0].ptr
|
@@ -39,7 +39,7 @@ class Expr(TaichiOperations):
|
|
39
39
|
arg = args[0]
|
40
40
|
if isinstance(arg, np.ndarray):
|
41
41
|
if arg.shape:
|
42
|
-
raise
|
42
|
+
raise GsTaichiTypeError(
|
43
43
|
"Only 0-dimensional numpy array can be used to initialize a scalar expression"
|
44
44
|
)
|
45
45
|
arg = arg.dtype.type(arg)
|
@@ -63,7 +63,7 @@ class Expr(TaichiOperations):
|
|
63
63
|
|
64
64
|
def get_shape(self):
|
65
65
|
if not self.is_tensor():
|
66
|
-
raise
|
66
|
+
raise GsTaichiCompilationError(f"Getting shape of non-tensor type: {self.ptr.get_rvalue_type()}")
|
67
67
|
shape = self.ptr.get_shape()
|
68
68
|
assert shape is not None
|
69
69
|
return tuple(shape)
|
@@ -72,14 +72,14 @@ class Expr(TaichiOperations):
|
|
72
72
|
def n(self):
|
73
73
|
shape = self.get_shape()
|
74
74
|
if len(shape) < 1:
|
75
|
-
raise
|
75
|
+
raise GsTaichiCompilationError(f"Getting n of tensor type < 1D: {self.ptr.get_rvalue_type()}")
|
76
76
|
return shape[0]
|
77
77
|
|
78
78
|
@property
|
79
79
|
def m(self):
|
80
80
|
shape = self.get_shape()
|
81
81
|
if len(shape) < 2:
|
82
|
-
raise
|
82
|
+
raise GsTaichiCompilationError(f"Getting m of tensor type < 2D: {self.ptr.get_rvalue_type()}")
|
83
83
|
return shape[1]
|
84
84
|
|
85
85
|
def __hash__(self):
|
@@ -116,7 +116,7 @@ def make_constant_expr(val, dtype):
|
|
116
116
|
if isinstance(val, (float, np.floating)):
|
117
117
|
constant_dtype = impl.get_runtime().default_fp if dtype is None else dtype
|
118
118
|
if constant_dtype not in real_types:
|
119
|
-
raise
|
119
|
+
raise GsTaichiTypeError(
|
120
120
|
"Floating-point literals must be annotated with a floating-point type. For type casting, use `ti.cast`."
|
121
121
|
)
|
122
122
|
return Expr(_ti_core.make_const_expr_fp(constant_dtype, val))
|
@@ -124,19 +124,19 @@ def make_constant_expr(val, dtype):
|
|
124
124
|
if isinstance(val, (int, np.integer)):
|
125
125
|
constant_dtype = impl.get_runtime().default_ip if dtype is None else dtype
|
126
126
|
if constant_dtype not in integer_types:
|
127
|
-
raise
|
127
|
+
raise GsTaichiTypeError(
|
128
128
|
"Integer literals must be annotated with a integer type. For type casting, use `ti.cast`."
|
129
129
|
)
|
130
130
|
if _check_in_range(to_numpy_type(constant_dtype), val):
|
131
131
|
return Expr(_ti_core.make_const_expr_int(constant_dtype, _clamp_unsigned_to_range(np.int64, val)))
|
132
132
|
if dtype is None:
|
133
|
-
raise
|
133
|
+
raise GsTaichiTypeError(
|
134
134
|
f"Integer literal {val} exceeded the range of default_ip: {impl.get_runtime().default_ip}, please specify the dtype via e.g. `ti.u64({val})` or set a different `default_ip` in `ti.init()`"
|
135
135
|
)
|
136
136
|
else:
|
137
|
-
raise
|
137
|
+
raise GsTaichiTypeError(f"Integer literal {val} exceeded the range of specified dtype: {dtype}")
|
138
138
|
|
139
|
-
raise
|
139
|
+
raise GsTaichiTypeError(f"Invalid constant scalar data type: {type(val)}")
|
140
140
|
|
141
141
|
|
142
142
|
def make_var_list(size: int, ast_builder: "ASTBuilder | None" = None):
|
@@ -151,7 +151,7 @@ def make_var_list(size: int, ast_builder: "ASTBuilder | None" = None):
|
|
151
151
|
|
152
152
|
|
153
153
|
def make_expr_group(*exprs):
|
154
|
-
from
|
154
|
+
from gstaichi.lang.matrix import Matrix # pylint: disable=C0415
|
155
155
|
|
156
156
|
if len(exprs) == 1:
|
157
157
|
if isinstance(exprs[0], (list, tuple)):
|
@@ -169,7 +169,7 @@ def make_expr_group(*exprs):
|
|
169
169
|
|
170
170
|
|
171
171
|
def _get_flattened_ptrs(val):
|
172
|
-
if
|
172
|
+
if is_gstaichi_class(val):
|
173
173
|
ptrs = []
|
174
174
|
for item in val._members:
|
175
175
|
ptrs.extend(_get_flattened_ptrs(item))
|