gstaichi 0.1.20.dev0__cp310-cp310-macosx_15_0_arm64.whl → 0.1.25.dev0__cp310-cp310-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {taichi → gstaichi}/__init__.py +9 -13
- {taichi → gstaichi}/_funcs.py +8 -8
- {taichi → gstaichi}/_kernels.py +19 -19
- gstaichi/_lib/__init__.py +3 -0
- taichi/_lib/core/taichi_python.cpython-310-darwin.so → gstaichi/_lib/core/gstaichi_python.cpython-310-darwin.so +0 -0
- taichi/_lib/core/taichi_python.pyi → gstaichi/_lib/core/gstaichi_python.pyi +382 -520
- {taichi → gstaichi}/_lib/runtime/runtime_arm64.bc +0 -0
- {taichi → gstaichi}/_lib/utils.py +15 -15
- {taichi → gstaichi}/_logging.py +1 -1
- {taichi → gstaichi}/_main.py +24 -31
- gstaichi/_snode/__init__.py +5 -0
- {taichi → gstaichi}/_snode/fields_builder.py +27 -29
- {taichi → gstaichi}/_snode/snode_tree.py +5 -5
- gstaichi/_test_tools/__init__.py +0 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_version.py +1 -0
- {taichi → gstaichi}/_version_check.py +8 -5
- gstaichi/ad/__init__.py +3 -0
- {taichi → gstaichi}/ad/_ad.py +26 -26
- {taichi → gstaichi}/algorithms/_algorithms.py +7 -7
- {taichi → gstaichi}/examples/minimal.py +1 -1
- {taichi → gstaichi}/experimental.py +1 -1
- gstaichi/lang/__init__.py +50 -0
- {taichi → gstaichi}/lang/_ndarray.py +30 -26
- {taichi → gstaichi}/lang/_ndrange.py +8 -8
- gstaichi/lang/_template_mapper.py +199 -0
- {taichi → gstaichi}/lang/_texture.py +19 -19
- {taichi → gstaichi}/lang/_wrap_inspect.py +7 -7
- {taichi → gstaichi}/lang/any_array.py +13 -13
- {taichi → gstaichi}/lang/argpack.py +29 -29
- gstaichi/lang/ast/__init__.py +5 -0
- {taichi → gstaichi}/lang/ast/ast_transformer.py +94 -582
- {taichi → gstaichi}/lang/ast/ast_transformer_utils.py +54 -41
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
- {taichi → gstaichi}/lang/ast/checkers.py +5 -5
- gstaichi/lang/ast/transform.py +9 -0
- {taichi → gstaichi}/lang/common_ops.py +12 -12
- gstaichi/lang/exception.py +80 -0
- {taichi → gstaichi}/lang/expr.py +22 -22
- {taichi → gstaichi}/lang/field.py +29 -27
- {taichi → gstaichi}/lang/impl.py +116 -121
- {taichi → gstaichi}/lang/kernel_arguments.py +16 -16
- {taichi → gstaichi}/lang/kernel_impl.py +330 -363
- {taichi → gstaichi}/lang/matrix.py +119 -115
- {taichi → gstaichi}/lang/matrix_ops.py +6 -6
- {taichi → gstaichi}/lang/matrix_ops_utils.py +4 -4
- {taichi → gstaichi}/lang/mesh.py +22 -22
- {taichi → gstaichi}/lang/misc.py +39 -68
- {taichi → gstaichi}/lang/ops.py +146 -141
- {taichi → gstaichi}/lang/runtime_ops.py +2 -2
- {taichi → gstaichi}/lang/shell.py +3 -3
- {taichi → gstaichi}/lang/simt/__init__.py +1 -1
- {taichi → gstaichi}/lang/simt/block.py +7 -7
- {taichi → gstaichi}/lang/simt/grid.py +1 -1
- {taichi → gstaichi}/lang/simt/subgroup.py +1 -1
- {taichi → gstaichi}/lang/simt/warp.py +1 -1
- {taichi → gstaichi}/lang/snode.py +46 -44
- {taichi → gstaichi}/lang/source_builder.py +13 -13
- {taichi → gstaichi}/lang/struct.py +33 -33
- {taichi → gstaichi}/lang/util.py +24 -24
- gstaichi/linalg/__init__.py +8 -0
- {taichi → gstaichi}/linalg/matrixfree_cg.py +14 -14
- {taichi → gstaichi}/linalg/sparse_cg.py +10 -10
- {taichi → gstaichi}/linalg/sparse_matrix.py +23 -23
- {taichi → gstaichi}/linalg/sparse_solver.py +21 -21
- {taichi → gstaichi}/math/__init__.py +1 -1
- {taichi → gstaichi}/math/_complex.py +21 -20
- {taichi → gstaichi}/math/mathimpl.py +56 -56
- gstaichi/profiler/__init__.py +6 -0
- {taichi → gstaichi}/profiler/kernel_metrics.py +11 -11
- {taichi → gstaichi}/profiler/kernel_profiler.py +30 -36
- {taichi → gstaichi}/profiler/memory_profiler.py +1 -1
- {taichi → gstaichi}/profiler/scoped_profiler.py +2 -2
- {taichi → gstaichi}/sparse/_sparse_grid.py +7 -7
- {taichi → gstaichi}/tools/__init__.py +4 -4
- {taichi → gstaichi}/tools/diagnose.py +10 -17
- gstaichi/types/__init__.py +19 -0
- {taichi → gstaichi}/types/annotations.py +1 -1
- {taichi → gstaichi}/types/compound_types.py +8 -8
- {taichi → gstaichi}/types/enums.py +1 -1
- {taichi → gstaichi}/types/ndarray_type.py +7 -7
- {taichi → gstaichi}/types/primitive_types.py +17 -14
- {taichi → gstaichi}/types/quant.py +9 -9
- {taichi → gstaichi}/types/texture_type.py +5 -5
- {taichi → gstaichi}/types/utils.py +1 -1
- {gstaichi-0.1.20.dev0.dist-info → gstaichi-0.1.25.dev0.dist-info}/METADATA +13 -16
- gstaichi-0.1.25.dev0.dist-info/RECORD +168 -0
- gstaichi-0.1.25.dev0.dist-info/entry_points.txt +2 -0
- gstaichi-0.1.25.dev0.dist-info/top_level.txt +1 -0
- gstaichi-0.1.20.dev0.dist-info/RECORD +0 -219
- gstaichi-0.1.20.dev0.dist-info/entry_points.txt +0 -2
- gstaichi-0.1.20.dev0.dist-info/top_level.txt +0 -1
- taichi/_lib/__init__.py +0 -3
- taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +0 -1401
- taichi/_lib/c_api/include/taichi/taichi.h +0 -29
- taichi/_lib/c_api/include/taichi/taichi_core.h +0 -1111
- taichi/_lib/c_api/include/taichi/taichi_cpu.h +0 -29
- taichi/_lib/c_api/include/taichi/taichi_metal.h +0 -72
- taichi/_lib/c_api/include/taichi/taichi_platform.h +0 -55
- taichi/_lib/c_api/include/taichi/taichi_unity.h +0 -64
- taichi/_lib/c_api/include/taichi/taichi_vulkan.h +0 -151
- taichi/_lib/c_api/lib/libtaichi_c_api.dylib +0 -0
- taichi/_lib/c_api/runtime/runtime_arm64.bc +0 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +0 -29
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +0 -65
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +0 -121
- taichi/_lib/runtime/libMoltenVK.dylib +0 -0
- taichi/_snode/__init__.py +0 -5
- taichi/_ti_module/__init__.py +0 -3
- taichi/_ti_module/cppgen.py +0 -309
- taichi/_ti_module/module.py +0 -145
- taichi/_version.py +0 -1
- taichi/ad/__init__.py +0 -3
- taichi/aot/__init__.py +0 -12
- taichi/aot/_export.py +0 -28
- taichi/aot/conventions/__init__.py +0 -3
- taichi/aot/conventions/gfxruntime140/__init__.py +0 -38
- taichi/aot/conventions/gfxruntime140/dr.py +0 -244
- taichi/aot/conventions/gfxruntime140/sr.py +0 -613
- taichi/aot/module.py +0 -253
- taichi/aot/utils.py +0 -151
- taichi/graph/__init__.py +0 -3
- taichi/graph/_graph.py +0 -292
- taichi/lang/__init__.py +0 -50
- taichi/lang/ast/__init__.py +0 -5
- taichi/lang/ast/transform.py +0 -9
- taichi/lang/exception.py +0 -80
- taichi/linalg/__init__.py +0 -8
- taichi/profiler/__init__.py +0 -6
- taichi/shaders/Circles_vk.frag +0 -29
- taichi/shaders/Circles_vk.vert +0 -45
- taichi/shaders/Circles_vk_frag.spv +0 -0
- taichi/shaders/Circles_vk_vert.spv +0 -0
- taichi/shaders/Lines_vk.frag +0 -9
- taichi/shaders/Lines_vk.vert +0 -11
- taichi/shaders/Lines_vk_frag.spv +0 -0
- taichi/shaders/Lines_vk_vert.spv +0 -0
- taichi/shaders/Mesh_vk.frag +0 -71
- taichi/shaders/Mesh_vk.vert +0 -68
- taichi/shaders/Mesh_vk_frag.spv +0 -0
- taichi/shaders/Mesh_vk_vert.spv +0 -0
- taichi/shaders/Particles_vk.frag +0 -95
- taichi/shaders/Particles_vk.vert +0 -73
- taichi/shaders/Particles_vk_frag.spv +0 -0
- taichi/shaders/Particles_vk_vert.spv +0 -0
- taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
- taichi/shaders/SceneLines_vk.frag +0 -9
- taichi/shaders/SceneLines_vk.vert +0 -12
- taichi/shaders/SceneLines_vk_frag.spv +0 -0
- taichi/shaders/SceneLines_vk_vert.spv +0 -0
- taichi/shaders/SetImage_vk.frag +0 -21
- taichi/shaders/SetImage_vk.vert +0 -15
- taichi/shaders/SetImage_vk_frag.spv +0 -0
- taichi/shaders/SetImage_vk_vert.spv +0 -0
- taichi/shaders/Triangles_vk.frag +0 -16
- taichi/shaders/Triangles_vk.vert +0 -29
- taichi/shaders/Triangles_vk_frag.spv +0 -0
- taichi/shaders/Triangles_vk_vert.spv +0 -0
- taichi/shaders/lines2quad_vk_comp.spv +0 -0
- taichi/types/__init__.py +0 -19
- {taichi → gstaichi}/__main__.py +0 -0
- {taichi → gstaichi}/_lib/core/__init__.py +0 -0
- {taichi → gstaichi}/_lib/core/py.typed +0 -0
- {taichi/_lib/c_api → gstaichi/_lib}/runtime/libMoltenVK.dylib +0 -0
- {taichi → gstaichi}/algorithms/__init__.py +0 -0
- {taichi → gstaichi}/assets/.git +0 -0
- {taichi → gstaichi}/assets/Go-Regular.ttf +0 -0
- {taichi → gstaichi}/assets/static/imgs/ti_gallery.png +0 -0
- {taichi → gstaichi}/lang/ast/symbol_resolver.py +0 -0
- {taichi → gstaichi}/sparse/__init__.py +0 -0
- {taichi → gstaichi}/tools/np2ply.py +0 -0
- {taichi → gstaichi}/tools/vtk.py +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/GLFW/glfw3.h +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/GLFW/glfw3native.h +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/instrument.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/libspirv.h +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/libspirv.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/linker.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/optimizer.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/GLSL.std.450.h +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv.h +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_cfg.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_common.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_cpp.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_cross.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_cross_c.h +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_cross_containers.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_cross_error_handling.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_cross_util.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_glsl.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_hlsl.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_msl.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_parser.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_reflect.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/glfw3/glfw3Config.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/glfw3/glfw3Targets-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/glfw3/glfw3Targets.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.dist-info → gstaichi-0.1.25.dev0.dist-info}/WHEEL +0 -0
- {gstaichi-0.1.20.dev0.dist-info → gstaichi-0.1.25.dev0.dist-info}/licenses/LICENSE +0 -0
@@ -1,11 +1,11 @@
|
|
1
1
|
# type: ignore
|
2
2
|
|
3
|
-
from
|
3
|
+
from gstaichi.lang import impl
|
4
4
|
|
5
5
|
|
6
6
|
def sync():
|
7
7
|
"""Blocks the calling thread until all the previously
|
8
|
-
launched
|
8
|
+
launched GsTaichi kernels have completed.
|
9
9
|
"""
|
10
10
|
impl.get_runtime().sync()
|
11
11
|
|
@@ -4,8 +4,8 @@ import functools
|
|
4
4
|
import os
|
5
5
|
import sys
|
6
6
|
|
7
|
-
from
|
8
|
-
from
|
7
|
+
from gstaichi._lib import core as _ti_core
|
8
|
+
from gstaichi._logging import info
|
9
9
|
|
10
10
|
pybuf_enabled = False
|
11
11
|
_env_enable_pybuf = os.environ.get("TI_ENABLE_PYBUF", "1")
|
@@ -28,7 +28,7 @@ def _shell_pop_print(old_call):
|
|
28
28
|
def new_call(*args, **kwargs):
|
29
29
|
ret = old_call(*args, **kwargs)
|
30
30
|
# print's in kernel won't take effect until ti.sync(), discussion:
|
31
|
-
# https://github.com/taichi-dev/
|
31
|
+
# https://github.com/taichi-dev/gstaichi/pull/1303#discussion_r444897102
|
32
32
|
print(_ti_core.pop_python_print_buffer(), end="")
|
33
33
|
return ret
|
34
34
|
|
@@ -1,13 +1,13 @@
|
|
1
1
|
# type: ignore
|
2
2
|
|
3
|
-
from
|
4
|
-
from
|
5
|
-
from
|
6
|
-
from
|
3
|
+
from gstaichi._lib import core as _ti_core
|
4
|
+
from gstaichi.lang import impl
|
5
|
+
from gstaichi.lang.expr import make_expr_group
|
6
|
+
from gstaichi.lang.util import gstaichi_scope
|
7
7
|
|
8
8
|
|
9
9
|
def arch_uses_spv(arch):
|
10
|
-
return arch == _ti_core.vulkan or arch == _ti_core.metal
|
10
|
+
return arch == _ti_core.vulkan or arch == _ti_core.metal
|
11
11
|
|
12
12
|
|
13
13
|
def sync():
|
@@ -66,7 +66,7 @@ def global_thread_idx():
|
|
66
66
|
|
67
67
|
|
68
68
|
class SharedArray:
|
69
|
-
|
69
|
+
_is_gstaichi_class = True
|
70
70
|
|
71
71
|
def __init__(self, shape, dtype):
|
72
72
|
if isinstance(shape, int):
|
@@ -82,7 +82,7 @@ class SharedArray:
|
|
82
82
|
self.dtype = dtype
|
83
83
|
self.shared_array_proxy = impl.expr_init_shared_array(self.shape, dtype)
|
84
84
|
|
85
|
-
@
|
85
|
+
@gstaichi_scope
|
86
86
|
def subscript(self, *indices):
|
87
87
|
ast_builder = impl.get_runtime().compiling_callable.ast_builder()
|
88
88
|
return impl.Expr(
|
@@ -2,23 +2,25 @@
|
|
2
2
|
|
3
3
|
import numbers
|
4
4
|
|
5
|
-
from
|
6
|
-
from
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
from
|
11
|
-
from
|
5
|
+
from gstaichi._lib import core as _ti_core
|
6
|
+
from gstaichi._lib.core.gstaichi_python import (
|
7
|
+
Axis,
|
8
|
+
SNodeCxx,
|
9
|
+
)
|
10
|
+
from gstaichi.lang import expr, impl, matrix
|
11
|
+
from gstaichi.lang.exception import GsTaichiRuntimeError
|
12
|
+
from gstaichi.lang.field import BitpackedFields, Field
|
13
|
+
from gstaichi.lang.util import get_traceback
|
12
14
|
|
13
15
|
|
14
16
|
class SNode:
|
15
17
|
"""A Python-side SNode wrapper.
|
16
18
|
|
17
|
-
For more information on
|
19
|
+
For more information on GsTaichi's SNode system, please check out
|
18
20
|
these references:
|
19
21
|
|
20
22
|
* https://docs.taichi-lang.org/docs/sparse
|
21
|
-
* https://yuanming.
|
23
|
+
* https://yuanming.gstaichi.graphics/publication/2019-gstaichi/gstaichi-lang.pdf
|
22
24
|
|
23
25
|
Arg:
|
24
26
|
ptr (pointer): The C++ side SNode pointer.
|
@@ -35,7 +37,7 @@ class SNode:
|
|
35
37
|
dimensions (Union[List[int], int]): Shape of each axis.
|
36
38
|
|
37
39
|
Returns:
|
38
|
-
The added :class:`~
|
40
|
+
The added :class:`~gstaichi.lang.SNode` instance.
|
39
41
|
"""
|
40
42
|
if isinstance(dimensions, numbers.Number):
|
41
43
|
dimensions = [dimensions] * len(axes)
|
@@ -49,10 +51,10 @@ class SNode:
|
|
49
51
|
dimensions (Union[List[int], int]): Shape of each axis.
|
50
52
|
|
51
53
|
Returns:
|
52
|
-
The added :class:`~
|
54
|
+
The added :class:`~gstaichi.lang.SNode` instance.
|
53
55
|
"""
|
54
56
|
if not _ti_core.is_extension_supported(impl.current_cfg().arch, _ti_core.Extension.sparse):
|
55
|
-
raise
|
57
|
+
raise GsTaichiRuntimeError("Pointer SNode is not supported on this backend.")
|
56
58
|
if isinstance(dimensions, numbers.Number):
|
57
59
|
dimensions = [dimensions] * len(axes)
|
58
60
|
return SNode(self.ptr.pointer(axes, dimensions, _ti_core.DebugInfo(get_traceback())))
|
@@ -75,10 +77,10 @@ class SNode:
|
|
75
77
|
chunk_size (int): Chunk size.
|
76
78
|
|
77
79
|
Returns:
|
78
|
-
The added :class:`~
|
80
|
+
The added :class:`~gstaichi.lang.SNode` instance.
|
79
81
|
"""
|
80
82
|
if not _ti_core.is_extension_supported(impl.current_cfg().arch, _ti_core.Extension.sparse):
|
81
|
-
raise
|
83
|
+
raise GsTaichiRuntimeError("Dynamic SNode is not supported on this backend.")
|
82
84
|
assert len(axis) == 1
|
83
85
|
if chunk_size is None:
|
84
86
|
chunk_size = dimension
|
@@ -92,10 +94,10 @@ class SNode:
|
|
92
94
|
dimensions (Union[List[int], int]): Shape of each axis.
|
93
95
|
|
94
96
|
Returns:
|
95
|
-
The added :class:`~
|
97
|
+
The added :class:`~gstaichi.lang.SNode` instance.
|
96
98
|
"""
|
97
99
|
if not _ti_core.is_extension_supported(impl.current_cfg().arch, _ti_core.Extension.sparse):
|
98
|
-
raise
|
100
|
+
raise GsTaichiRuntimeError("Bitmasked SNode is not supported on this backend.")
|
99
101
|
if isinstance(dimensions, numbers.Number):
|
100
102
|
dimensions = [dimensions] * len(axes)
|
101
103
|
return SNode(self.ptr.bitmasked(axes, dimensions, _ti_core.DebugInfo(get_traceback())))
|
@@ -109,17 +111,17 @@ class SNode:
|
|
109
111
|
max_num_bits (int): Maximum number of bits it can hold.
|
110
112
|
|
111
113
|
Returns:
|
112
|
-
The added :class:`~
|
114
|
+
The added :class:`~gstaichi.lang.SNode` instance.
|
113
115
|
"""
|
114
116
|
if isinstance(dimensions, numbers.Number):
|
115
117
|
dimensions = [dimensions] * len(axes)
|
116
118
|
return SNode(self.ptr.quant_array(axes, dimensions, max_num_bits, _ti_core.DebugInfo(get_traceback())))
|
117
119
|
|
118
120
|
def place(self, *args, offset: numbers.Number | tuple[numbers.Number] | None = None) -> "SNode":
|
119
|
-
"""Places a list of
|
121
|
+
"""Places a list of GsTaichi fields under the `self` container.
|
120
122
|
|
121
123
|
Args:
|
122
|
-
*args (List[ti.field]): A list of
|
124
|
+
*args (List[ti.field]): A list of GsTaichi fields to place.
|
123
125
|
offset (Union[Number, tuple[Number]]): Offset of the field domain.
|
124
126
|
|
125
127
|
Returns:
|
@@ -150,11 +152,11 @@ class SNode:
|
|
150
152
|
"""Automatically place the adjoint fields following the layout of their primal fields.
|
151
153
|
|
152
154
|
Users don't need to specify ``needs_grad`` when they define scalar/vector/matrix fields (primal fields) using autodiff.
|
153
|
-
When all the primal fields are defined, using ``
|
155
|
+
When all the primal fields are defined, using ``gstaichi.root.lazy_grad()`` could automatically generate
|
154
156
|
their corresponding adjoint fields (gradient field).
|
155
157
|
|
156
158
|
To know more details about primal, adjoint fields and ``lazy_grad()``,
|
157
|
-
please see Page 4 and Page 13-14 of
|
159
|
+
please see Page 4 and Page 13-14 of DiffGsTaichi Paper: https://arxiv.org/pdf/1910.00935.pdf
|
158
160
|
"""
|
159
161
|
self.ptr.lazy_grad()
|
160
162
|
|
@@ -236,10 +238,10 @@ class SNode:
|
|
236
238
|
return ret
|
237
239
|
|
238
240
|
def _loop_range(self):
|
239
|
-
"""Gets the
|
241
|
+
"""Gets the gstaichi_python.SNode to serve as loop range.
|
240
242
|
|
241
243
|
Returns:
|
242
|
-
|
244
|
+
gstaichi_python.SNode: See above.
|
243
245
|
"""
|
244
246
|
return self.ptr
|
245
247
|
|
@@ -294,14 +296,14 @@ class SNode:
|
|
294
296
|
c.deactivate_all()
|
295
297
|
SNodeType = _ti_core.SNodeType
|
296
298
|
if self.ptr.type == SNodeType.pointer or self.ptr.type == SNodeType.bitmasked:
|
297
|
-
from
|
299
|
+
from gstaichi._kernels import snode_deactivate # pylint: disable=C0415
|
298
300
|
|
299
301
|
snode_deactivate(self)
|
300
302
|
if self.ptr.type == SNodeType.dynamic:
|
301
303
|
# Note that dynamic nodes are different from other sparse nodes:
|
302
304
|
# instead of deactivating each element, we only need to deactivate
|
303
305
|
# its parent, whose linked list of chunks of elements will be deleted.
|
304
|
-
from
|
306
|
+
from gstaichi._kernels import ( # pylint: disable=C0415
|
305
307
|
snode_deactivate_dynamic,
|
306
308
|
)
|
307
309
|
|
@@ -340,11 +342,11 @@ def rescale_index(a, b, I):
|
|
340
342
|
|
341
343
|
Args:
|
342
344
|
|
343
|
-
a, b (Union[:class:`~
|
344
|
-
I (Union[list, :class:`~
|
345
|
+
a, b (Union[:class:`~gstaichi.Field`, :class:`~gstaichi.MatrixField`): Input gstaichi fields or snodes.
|
346
|
+
I (Union[list, :class:`~gstaichi.Vector`]): grouped loop index.
|
345
347
|
|
346
348
|
Returns:
|
347
|
-
Ib (:class:`~
|
349
|
+
Ib (:class:`~gstaichi.Vector`): rescaled grouped loop index
|
348
350
|
"""
|
349
351
|
|
350
352
|
assert isinstance(a, (Field, SNode)), "The first argument must be a field or an SNode"
|
@@ -357,7 +359,7 @@ def rescale_index(a, b, I):
|
|
357
359
|
), "The third argument must be an index (list, ti.Vector, or Expr with TensorType)"
|
358
360
|
n = I.n
|
359
361
|
|
360
|
-
from
|
362
|
+
from gstaichi.lang.kernel_impl import pyfunc # pylint: disable=C0415
|
361
363
|
|
362
364
|
@pyfunc
|
363
365
|
def _rescale_index():
|
@@ -376,9 +378,9 @@ def append(node, indices, val):
|
|
376
378
|
"""Append a value `val` to a SNode `node` at index `indices`.
|
377
379
|
|
378
380
|
Args:
|
379
|
-
node (:class:`~
|
380
|
-
indices (Union[int, :class:`~
|
381
|
-
val (Union[:mod:`~
|
381
|
+
node (:class:`~gstaichi.SNode`): Input SNode.
|
382
|
+
indices (Union[int, :class:`~gstaichi.Vector`]): the indices to visit.
|
383
|
+
val (Union[:mod:`~gstaichi.types.primitive_types`, :mod:`~gstaichi.types.compound_types`]): the data to be appended.
|
382
384
|
"""
|
383
385
|
ptrs = expr._get_flattened_ptrs(val)
|
384
386
|
append_expr = expr.Expr(
|
@@ -396,8 +398,8 @@ def is_active(node, indices):
|
|
396
398
|
`indices` is active or not.
|
397
399
|
|
398
400
|
Args:
|
399
|
-
node (:class:`~
|
400
|
-
indices (Union[int, list, :class:`~
|
401
|
+
node (:class:`~gstaichi.SNode`): Must be a pointer, hash or bitmasked node.
|
402
|
+
indices (Union[int, list, :class:`~gstaichi.Vector`]): the indices to visit.
|
401
403
|
|
402
404
|
Returns:
|
403
405
|
bool: the cell `node[indices]` is active or not.
|
@@ -414,8 +416,8 @@ def activate(node, indices):
|
|
414
416
|
"""Explicitly activate a cell of `node` at location `indices`.
|
415
417
|
|
416
418
|
Args:
|
417
|
-
node (:class:`~
|
418
|
-
indices (Union[int, :class:`~
|
419
|
+
node (:class:`~gstaichi.SNode`): Must be a pointer, hash or bitmasked node.
|
420
|
+
indices (Union[int, :class:`~gstaichi.Vector`]): the indices to activate.
|
419
421
|
"""
|
420
422
|
impl.get_runtime().compiling_callable.ast_builder().insert_activate(
|
421
423
|
node._snode.ptr, expr.make_expr_group(indices), _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
|
@@ -425,12 +427,12 @@ def activate(node, indices):
|
|
425
427
|
def deactivate(node, indices):
|
426
428
|
"""Explicitly deactivate a cell of `node` at location `indices`.
|
427
429
|
|
428
|
-
After deactivation, the
|
430
|
+
After deactivation, the GsTaichi runtime automatically recycles and zero-fills
|
429
431
|
the memory of the deactivated cell.
|
430
432
|
|
431
433
|
Args:
|
432
|
-
node (:class:`~
|
433
|
-
indices (Union[int, :class:`~
|
434
|
+
node (:class:`~gstaichi.SNode`): Must be a pointer, hash or bitmasked node.
|
435
|
+
indices (Union[int, :class:`~gstaichi.Vector`]): the indices to deactivate.
|
434
436
|
"""
|
435
437
|
impl.get_runtime().compiling_callable.ast_builder().insert_deactivate(
|
436
438
|
node._snode.ptr, expr.make_expr_group(indices), _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
|
@@ -441,8 +443,8 @@ def length(node, indices):
|
|
441
443
|
"""Return the length of the dynamic SNode `node` at index `indices`.
|
442
444
|
|
443
445
|
Args:
|
444
|
-
node (:class:`~
|
445
|
-
indices (Union[int, :class:`~
|
446
|
+
node (:class:`~gstaichi.SNode`): a dynamic SNode.
|
447
|
+
indices (Union[int, :class:`~gstaichi.Vector`]): the indices to query.
|
446
448
|
|
447
449
|
Returns:
|
448
450
|
int: the length of cell `node[indices]`.
|
@@ -458,11 +460,11 @@ def length(node, indices):
|
|
458
460
|
def get_addr(f, indices):
|
459
461
|
"""Query the memory address (on CUDA/x64) of field `f` at index `indices`.
|
460
462
|
|
461
|
-
Currently, this function can only be called inside a
|
463
|
+
Currently, this function can only be called inside a gstaichi kernel.
|
462
464
|
|
463
465
|
Args:
|
464
|
-
f (Union[:class:`~
|
465
|
-
indices (Union[int, :class:`~
|
466
|
+
f (Union[:class:`~gstaichi.Field`, :class:`~gstaichi.MatrixField`]): Input gstaichi field for memory address query.
|
467
|
+
indices (Union[int, :class:`~gstaichi.Vector`]): The specified field indices of the query.
|
466
468
|
|
467
469
|
Returns:
|
468
470
|
ti.u64: The memory address of `f[indices]`.
|
@@ -7,11 +7,11 @@ import shutil
|
|
7
7
|
import subprocess
|
8
8
|
import tempfile
|
9
9
|
|
10
|
-
from
|
11
|
-
from
|
12
|
-
from
|
13
|
-
from
|
14
|
-
from
|
10
|
+
from gstaichi._lib import core as _ti_core
|
11
|
+
from gstaichi.lang import impl
|
12
|
+
from gstaichi.lang.exception import GsTaichiSyntaxError
|
13
|
+
from gstaichi.lang.expr import make_expr_group
|
14
|
+
from gstaichi.lang.util import get_clangpp
|
15
15
|
|
16
16
|
|
17
17
|
class SourceBuilder:
|
@@ -36,7 +36,7 @@ class SourceBuilder:
|
|
36
36
|
|
37
37
|
if filename.endswith((".cpp", ".c", ".cc")):
|
38
38
|
if impl.current_cfg().arch not in [_ti_core.Arch.x64, _ti_core.Arch.cuda]:
|
39
|
-
raise
|
39
|
+
raise GsTaichiSyntaxError("Unsupported arch for external function call")
|
40
40
|
if compile_fn is None:
|
41
41
|
|
42
42
|
def compile_fn_impl(filename):
|
@@ -62,7 +62,7 @@ class SourceBuilder:
|
|
62
62
|
self.mode = "bc"
|
63
63
|
elif filename.endswith(".cu"):
|
64
64
|
if impl.current_cfg().arch not in [_ti_core.Arch.cuda]:
|
65
|
-
raise
|
65
|
+
raise GsTaichiSyntaxError("Unsupported arch for external function call")
|
66
66
|
if compile_fn is None:
|
67
67
|
shutil.copy(filename, os.path.join(self.td, "source.cu"))
|
68
68
|
|
@@ -83,12 +83,12 @@ class SourceBuilder:
|
|
83
83
|
self.mode = "bc"
|
84
84
|
elif filename.endswith((".so", ".dylib", ".dll")):
|
85
85
|
if impl.current_cfg().arch not in [_ti_core.Arch.x64]:
|
86
|
-
raise
|
86
|
+
raise GsTaichiSyntaxError("Unsupported arch for external function call")
|
87
87
|
self.so = ctypes.CDLL(filename)
|
88
88
|
self.mode = "so"
|
89
89
|
elif filename.endswith(".ll"):
|
90
90
|
if impl.current_cfg().arch not in [_ti_core.Arch.x64, _ti_core.Arch.cuda]:
|
91
|
-
raise
|
91
|
+
raise GsTaichiSyntaxError("Unsupported arch for external function call")
|
92
92
|
subprocess.call(
|
93
93
|
"llvm-as " + filename + " -o " + os.path.join(self.td, "source.bc"),
|
94
94
|
shell=True,
|
@@ -97,17 +97,17 @@ class SourceBuilder:
|
|
97
97
|
self.mode = "bc"
|
98
98
|
elif filename.endswith(".bc"):
|
99
99
|
if impl.current_cfg().arch not in [_ti_core.Arch.x64, _ti_core.Arch.cuda]:
|
100
|
-
raise
|
100
|
+
raise GsTaichiSyntaxError("Unsupported arch for external function call")
|
101
101
|
self.bc = filename
|
102
102
|
self.mode = "bc"
|
103
103
|
else:
|
104
|
-
raise
|
104
|
+
raise GsTaichiSyntaxError("Unsupported file type for external function call.")
|
105
105
|
return self
|
106
106
|
|
107
107
|
@classmethod
|
108
108
|
def from_source(cls, source_code, compile_fn=None):
|
109
109
|
if impl.current_cfg().arch not in [_ti_core.Arch.x64, _ti_core.Arch.cuda]:
|
110
|
-
raise
|
110
|
+
raise GsTaichiSyntaxError("Unsupported arch for external function call")
|
111
111
|
_temp_dir = tempfile.mkdtemp()
|
112
112
|
_temp_source = os.path.join(_temp_dir, "_temp_source.cpp")
|
113
113
|
with open(_temp_source, "w") as f:
|
@@ -144,7 +144,7 @@ class SourceBuilder:
|
|
144
144
|
if self.mode == "so":
|
145
145
|
return external_func_call_wrapper
|
146
146
|
|
147
|
-
raise
|
147
|
+
raise GsTaichiSyntaxError("Error occurs when calling external function.")
|
148
148
|
|
149
149
|
|
150
150
|
__all__ = []
|
@@ -5,21 +5,21 @@ from types import MethodType
|
|
5
5
|
|
6
6
|
import numpy as np
|
7
7
|
|
8
|
-
from
|
9
|
-
from
|
10
|
-
from
|
11
|
-
|
12
|
-
|
13
|
-
|
8
|
+
from gstaichi._lib import core as _ti_core
|
9
|
+
from gstaichi.lang import expr, impl, ops
|
10
|
+
from gstaichi.lang.exception import (
|
11
|
+
GsTaichiRuntimeTypeError,
|
12
|
+
GsTaichiSyntaxError,
|
13
|
+
GsTaichiTypeError,
|
14
14
|
)
|
15
|
-
from
|
16
|
-
from
|
17
|
-
from
|
18
|
-
from
|
19
|
-
from
|
20
|
-
from
|
21
|
-
from
|
22
|
-
from
|
15
|
+
from gstaichi.lang.expr import Expr
|
16
|
+
from gstaichi.lang.field import Field, ScalarField, SNodeHostAccess
|
17
|
+
from gstaichi.lang.matrix import Matrix, MatrixType
|
18
|
+
from gstaichi.lang.util import cook_dtype, gstaichi_scope, in_python_scope, python_scope
|
19
|
+
from gstaichi.types import primitive_types
|
20
|
+
from gstaichi.types.compound_types import CompoundType
|
21
|
+
from gstaichi.types.enums import Layout
|
22
|
+
from gstaichi.types.utils import is_signed
|
23
23
|
|
24
24
|
|
25
25
|
class Struct:
|
@@ -50,7 +50,7 @@ _
|
|
50
50
|
dict_items([('v', [0. 0. 0.]), ('t', 1.0), ('A', {'v': [[0.], [0.], [0.]], 't': 1.0})])
|
51
51
|
"""
|
52
52
|
|
53
|
-
|
53
|
+
_is_gstaichi_class = True
|
54
54
|
_instance_count = 0
|
55
55
|
|
56
56
|
def __init__(self, *args, **kwargs):
|
@@ -60,7 +60,7 @@ _
|
|
60
60
|
elif len(args) == 0:
|
61
61
|
self.__entries = kwargs
|
62
62
|
else:
|
63
|
-
raise
|
63
|
+
raise GsTaichiSyntaxError(
|
64
64
|
"Custom structs need to be initialized using either dictionary or keyword arguments"
|
65
65
|
)
|
66
66
|
self.__methods = self.__entries.pop("__struct_methods", {})
|
@@ -171,14 +171,14 @@ _
|
|
171
171
|
|
172
172
|
return setter
|
173
173
|
|
174
|
-
@
|
174
|
+
@gstaichi_scope
|
175
175
|
def _assign(self, other):
|
176
176
|
if not isinstance(other, (dict, Struct)):
|
177
|
-
raise
|
177
|
+
raise GsTaichiTypeError("Only dict or Struct can be assigned to a Struct")
|
178
178
|
if isinstance(other, dict):
|
179
179
|
other = Struct(other)
|
180
180
|
if self.__entries.keys() != other.__entries.keys():
|
181
|
-
raise
|
181
|
+
raise GsTaichiTypeError(f"Member mismatch between structs {self.keys}, {other.keys}")
|
182
182
|
for k, v in self.items:
|
183
183
|
v._assign(other.__entries[k])
|
184
184
|
self.__dtype = other.__dtype
|
@@ -242,7 +242,7 @@ _
|
|
242
242
|
needs_dual=False,
|
243
243
|
layout=Layout.AOS,
|
244
244
|
):
|
245
|
-
"""Creates a :class:`~
|
245
|
+
"""Creates a :class:`~gstaichi.StructField` with each element
|
246
246
|
has this struct as its type.
|
247
247
|
|
248
248
|
Args:
|
@@ -280,7 +280,7 @@ _
|
|
280
280
|
"""
|
281
281
|
|
282
282
|
if shape is None and offset is not None:
|
283
|
-
raise
|
283
|
+
raise GsTaichiSyntaxError("shape cannot be None when offset is being set")
|
284
284
|
|
285
285
|
field_dict = {}
|
286
286
|
|
@@ -321,7 +321,7 @@ _
|
|
321
321
|
offset = (offset,)
|
322
322
|
|
323
323
|
if offset is not None and len(shape) != len(offset):
|
324
|
-
raise
|
324
|
+
raise GsTaichiSyntaxError(
|
325
325
|
f"The dimensionality of shape and offset must be the same ({len(shape)} != {len(offset)})"
|
326
326
|
)
|
327
327
|
dim = len(shape)
|
@@ -365,7 +365,7 @@ class _IntermediateStruct(Struct):
|
|
365
365
|
|
366
366
|
|
367
367
|
class StructField(Field):
|
368
|
-
"""
|
368
|
+
"""GsTaichi struct field with SNode implementation.
|
369
369
|
|
370
370
|
Instead of directly constraining Expr entries, the StructField object
|
371
371
|
directly hosts members as `Field` instances to support nested structs.
|
@@ -463,7 +463,7 @@ class StructField(Field):
|
|
463
463
|
"""Gets SNode of representative field member for loop range info.
|
464
464
|
|
465
465
|
Returns:
|
466
|
-
|
466
|
+
gstaichi_python.SNode: SNode of representative (first) field member.
|
467
467
|
"""
|
468
468
|
return self._members[0]._loop_range()
|
469
469
|
|
@@ -680,13 +680,13 @@ class StructType(CompoundType):
|
|
680
680
|
return False
|
681
681
|
return True
|
682
682
|
|
683
|
-
def
|
683
|
+
def from_gstaichi_object(self, func_ret, ret_index=()):
|
684
684
|
d = {}
|
685
685
|
items = self.members.items()
|
686
686
|
for index, pair in enumerate(items):
|
687
687
|
name, dtype = pair
|
688
688
|
if isinstance(dtype, CompoundType):
|
689
|
-
d[name] = dtype.
|
689
|
+
d[name] = dtype.from_gstaichi_object(func_ret, ret_index + (index,))
|
690
690
|
else:
|
691
691
|
d[name] = expr.Expr(
|
692
692
|
_ti_core.make_get_element_expr(
|
@@ -717,7 +717,7 @@ class StructType(CompoundType):
|
|
717
717
|
elif dtype in primitive_types.real_types:
|
718
718
|
d[name] = launch_ctx.get_struct_ret_float(ret_index + (index,))
|
719
719
|
else:
|
720
|
-
raise
|
720
|
+
raise GsTaichiRuntimeTypeError(f"Invalid return type on index={ret_index + (index, )}")
|
721
721
|
d["__struct_methods"] = self.methods
|
722
722
|
|
723
723
|
struct = Struct(d)
|
@@ -740,7 +740,7 @@ class StructType(CompoundType):
|
|
740
740
|
elif dtype in primitive_types.real_types:
|
741
741
|
launch_ctx.set_struct_arg_float(ret_index + (index,), struct[name])
|
742
742
|
else:
|
743
|
-
raise
|
743
|
+
raise GsTaichiRuntimeTypeError(f"Invalid argument type on index={ret_index + (index, )}")
|
744
744
|
|
745
745
|
def set_argpack_struct_args(self, struct, argpack, ret_index=()):
|
746
746
|
# TODO: move this to class Struct after we add dtype to Struct
|
@@ -758,12 +758,12 @@ class StructType(CompoundType):
|
|
758
758
|
elif dtype in primitive_types.real_types:
|
759
759
|
argpack.set_arg_float(ret_index + (index,), struct[name])
|
760
760
|
else:
|
761
|
-
raise
|
761
|
+
raise GsTaichiRuntimeTypeError(f"Invalid argument type on index={ret_index + (index, )}")
|
762
762
|
|
763
763
|
def cast(self, struct):
|
764
764
|
# sanity check members
|
765
765
|
if self.members.keys() != struct._Struct__entries.keys():
|
766
|
-
raise
|
766
|
+
raise GsTaichiSyntaxError("Incompatible arguments for custom struct members!")
|
767
767
|
entries = {}
|
768
768
|
for k, dtype in self.members.items():
|
769
769
|
if isinstance(dtype, MatrixType):
|
@@ -806,7 +806,7 @@ class StructType(CompoundType):
|
|
806
806
|
|
807
807
|
|
808
808
|
def dataclass(cls):
|
809
|
-
"""Converts a class with field annotations and methods into a
|
809
|
+
"""Converts a class with field annotations and methods into a gstaichi struct type.
|
810
810
|
|
811
811
|
This will return a normal custom struct type, with the functions added to it.
|
812
812
|
Struct fields can be generated in the normal way from the struct type.
|
@@ -834,7 +834,7 @@ def dataclass(cls):
|
|
834
834
|
cls (Class): the class with annotations and methods to convert to a struct
|
835
835
|
|
836
836
|
Returns:
|
837
|
-
A
|
837
|
+
A gstaichi struct with the annotations as fields
|
838
838
|
and methods from the class attached.
|
839
839
|
"""
|
840
840
|
# save the annotation fields for the struct
|
@@ -842,7 +842,7 @@ def dataclass(cls):
|
|
842
842
|
# raise error if there are default values
|
843
843
|
for k in fields.keys():
|
844
844
|
if hasattr(cls, k):
|
845
|
-
raise
|
845
|
+
raise GsTaichiSyntaxError("Default value in @dataclass is not supported.")
|
846
846
|
# get the class methods to be attached to the struct types
|
847
847
|
fields["__struct_methods"] = {
|
848
848
|
attribute: getattr(cls, attribute)
|