gstaichi 0.1.20.dev0__cp310-cp310-macosx_15_0_arm64.whl → 0.1.25.dev0__cp310-cp310-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {taichi → gstaichi}/__init__.py +9 -13
- {taichi → gstaichi}/_funcs.py +8 -8
- {taichi → gstaichi}/_kernels.py +19 -19
- gstaichi/_lib/__init__.py +3 -0
- taichi/_lib/core/taichi_python.cpython-310-darwin.so → gstaichi/_lib/core/gstaichi_python.cpython-310-darwin.so +0 -0
- taichi/_lib/core/taichi_python.pyi → gstaichi/_lib/core/gstaichi_python.pyi +382 -520
- {taichi → gstaichi}/_lib/runtime/runtime_arm64.bc +0 -0
- {taichi → gstaichi}/_lib/utils.py +15 -15
- {taichi → gstaichi}/_logging.py +1 -1
- {taichi → gstaichi}/_main.py +24 -31
- gstaichi/_snode/__init__.py +5 -0
- {taichi → gstaichi}/_snode/fields_builder.py +27 -29
- {taichi → gstaichi}/_snode/snode_tree.py +5 -5
- gstaichi/_test_tools/__init__.py +0 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_version.py +1 -0
- {taichi → gstaichi}/_version_check.py +8 -5
- gstaichi/ad/__init__.py +3 -0
- {taichi → gstaichi}/ad/_ad.py +26 -26
- {taichi → gstaichi}/algorithms/_algorithms.py +7 -7
- {taichi → gstaichi}/examples/minimal.py +1 -1
- {taichi → gstaichi}/experimental.py +1 -1
- gstaichi/lang/__init__.py +50 -0
- {taichi → gstaichi}/lang/_ndarray.py +30 -26
- {taichi → gstaichi}/lang/_ndrange.py +8 -8
- gstaichi/lang/_template_mapper.py +199 -0
- {taichi → gstaichi}/lang/_texture.py +19 -19
- {taichi → gstaichi}/lang/_wrap_inspect.py +7 -7
- {taichi → gstaichi}/lang/any_array.py +13 -13
- {taichi → gstaichi}/lang/argpack.py +29 -29
- gstaichi/lang/ast/__init__.py +5 -0
- {taichi → gstaichi}/lang/ast/ast_transformer.py +94 -582
- {taichi → gstaichi}/lang/ast/ast_transformer_utils.py +54 -41
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
- {taichi → gstaichi}/lang/ast/checkers.py +5 -5
- gstaichi/lang/ast/transform.py +9 -0
- {taichi → gstaichi}/lang/common_ops.py +12 -12
- gstaichi/lang/exception.py +80 -0
- {taichi → gstaichi}/lang/expr.py +22 -22
- {taichi → gstaichi}/lang/field.py +29 -27
- {taichi → gstaichi}/lang/impl.py +116 -121
- {taichi → gstaichi}/lang/kernel_arguments.py +16 -16
- {taichi → gstaichi}/lang/kernel_impl.py +330 -363
- {taichi → gstaichi}/lang/matrix.py +119 -115
- {taichi → gstaichi}/lang/matrix_ops.py +6 -6
- {taichi → gstaichi}/lang/matrix_ops_utils.py +4 -4
- {taichi → gstaichi}/lang/mesh.py +22 -22
- {taichi → gstaichi}/lang/misc.py +39 -68
- {taichi → gstaichi}/lang/ops.py +146 -141
- {taichi → gstaichi}/lang/runtime_ops.py +2 -2
- {taichi → gstaichi}/lang/shell.py +3 -3
- {taichi → gstaichi}/lang/simt/__init__.py +1 -1
- {taichi → gstaichi}/lang/simt/block.py +7 -7
- {taichi → gstaichi}/lang/simt/grid.py +1 -1
- {taichi → gstaichi}/lang/simt/subgroup.py +1 -1
- {taichi → gstaichi}/lang/simt/warp.py +1 -1
- {taichi → gstaichi}/lang/snode.py +46 -44
- {taichi → gstaichi}/lang/source_builder.py +13 -13
- {taichi → gstaichi}/lang/struct.py +33 -33
- {taichi → gstaichi}/lang/util.py +24 -24
- gstaichi/linalg/__init__.py +8 -0
- {taichi → gstaichi}/linalg/matrixfree_cg.py +14 -14
- {taichi → gstaichi}/linalg/sparse_cg.py +10 -10
- {taichi → gstaichi}/linalg/sparse_matrix.py +23 -23
- {taichi → gstaichi}/linalg/sparse_solver.py +21 -21
- {taichi → gstaichi}/math/__init__.py +1 -1
- {taichi → gstaichi}/math/_complex.py +21 -20
- {taichi → gstaichi}/math/mathimpl.py +56 -56
- gstaichi/profiler/__init__.py +6 -0
- {taichi → gstaichi}/profiler/kernel_metrics.py +11 -11
- {taichi → gstaichi}/profiler/kernel_profiler.py +30 -36
- {taichi → gstaichi}/profiler/memory_profiler.py +1 -1
- {taichi → gstaichi}/profiler/scoped_profiler.py +2 -2
- {taichi → gstaichi}/sparse/_sparse_grid.py +7 -7
- {taichi → gstaichi}/tools/__init__.py +4 -4
- {taichi → gstaichi}/tools/diagnose.py +10 -17
- gstaichi/types/__init__.py +19 -0
- {taichi → gstaichi}/types/annotations.py +1 -1
- {taichi → gstaichi}/types/compound_types.py +8 -8
- {taichi → gstaichi}/types/enums.py +1 -1
- {taichi → gstaichi}/types/ndarray_type.py +7 -7
- {taichi → gstaichi}/types/primitive_types.py +17 -14
- {taichi → gstaichi}/types/quant.py +9 -9
- {taichi → gstaichi}/types/texture_type.py +5 -5
- {taichi → gstaichi}/types/utils.py +1 -1
- {gstaichi-0.1.20.dev0.dist-info → gstaichi-0.1.25.dev0.dist-info}/METADATA +13 -16
- gstaichi-0.1.25.dev0.dist-info/RECORD +168 -0
- gstaichi-0.1.25.dev0.dist-info/entry_points.txt +2 -0
- gstaichi-0.1.25.dev0.dist-info/top_level.txt +1 -0
- gstaichi-0.1.20.dev0.dist-info/RECORD +0 -219
- gstaichi-0.1.20.dev0.dist-info/entry_points.txt +0 -2
- gstaichi-0.1.20.dev0.dist-info/top_level.txt +0 -1
- taichi/_lib/__init__.py +0 -3
- taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +0 -1401
- taichi/_lib/c_api/include/taichi/taichi.h +0 -29
- taichi/_lib/c_api/include/taichi/taichi_core.h +0 -1111
- taichi/_lib/c_api/include/taichi/taichi_cpu.h +0 -29
- taichi/_lib/c_api/include/taichi/taichi_metal.h +0 -72
- taichi/_lib/c_api/include/taichi/taichi_platform.h +0 -55
- taichi/_lib/c_api/include/taichi/taichi_unity.h +0 -64
- taichi/_lib/c_api/include/taichi/taichi_vulkan.h +0 -151
- taichi/_lib/c_api/lib/libtaichi_c_api.dylib +0 -0
- taichi/_lib/c_api/runtime/runtime_arm64.bc +0 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +0 -29
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +0 -65
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +0 -121
- taichi/_lib/runtime/libMoltenVK.dylib +0 -0
- taichi/_snode/__init__.py +0 -5
- taichi/_ti_module/__init__.py +0 -3
- taichi/_ti_module/cppgen.py +0 -309
- taichi/_ti_module/module.py +0 -145
- taichi/_version.py +0 -1
- taichi/ad/__init__.py +0 -3
- taichi/aot/__init__.py +0 -12
- taichi/aot/_export.py +0 -28
- taichi/aot/conventions/__init__.py +0 -3
- taichi/aot/conventions/gfxruntime140/__init__.py +0 -38
- taichi/aot/conventions/gfxruntime140/dr.py +0 -244
- taichi/aot/conventions/gfxruntime140/sr.py +0 -613
- taichi/aot/module.py +0 -253
- taichi/aot/utils.py +0 -151
- taichi/graph/__init__.py +0 -3
- taichi/graph/_graph.py +0 -292
- taichi/lang/__init__.py +0 -50
- taichi/lang/ast/__init__.py +0 -5
- taichi/lang/ast/transform.py +0 -9
- taichi/lang/exception.py +0 -80
- taichi/linalg/__init__.py +0 -8
- taichi/profiler/__init__.py +0 -6
- taichi/shaders/Circles_vk.frag +0 -29
- taichi/shaders/Circles_vk.vert +0 -45
- taichi/shaders/Circles_vk_frag.spv +0 -0
- taichi/shaders/Circles_vk_vert.spv +0 -0
- taichi/shaders/Lines_vk.frag +0 -9
- taichi/shaders/Lines_vk.vert +0 -11
- taichi/shaders/Lines_vk_frag.spv +0 -0
- taichi/shaders/Lines_vk_vert.spv +0 -0
- taichi/shaders/Mesh_vk.frag +0 -71
- taichi/shaders/Mesh_vk.vert +0 -68
- taichi/shaders/Mesh_vk_frag.spv +0 -0
- taichi/shaders/Mesh_vk_vert.spv +0 -0
- taichi/shaders/Particles_vk.frag +0 -95
- taichi/shaders/Particles_vk.vert +0 -73
- taichi/shaders/Particles_vk_frag.spv +0 -0
- taichi/shaders/Particles_vk_vert.spv +0 -0
- taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
- taichi/shaders/SceneLines_vk.frag +0 -9
- taichi/shaders/SceneLines_vk.vert +0 -12
- taichi/shaders/SceneLines_vk_frag.spv +0 -0
- taichi/shaders/SceneLines_vk_vert.spv +0 -0
- taichi/shaders/SetImage_vk.frag +0 -21
- taichi/shaders/SetImage_vk.vert +0 -15
- taichi/shaders/SetImage_vk_frag.spv +0 -0
- taichi/shaders/SetImage_vk_vert.spv +0 -0
- taichi/shaders/Triangles_vk.frag +0 -16
- taichi/shaders/Triangles_vk.vert +0 -29
- taichi/shaders/Triangles_vk_frag.spv +0 -0
- taichi/shaders/Triangles_vk_vert.spv +0 -0
- taichi/shaders/lines2quad_vk_comp.spv +0 -0
- taichi/types/__init__.py +0 -19
- {taichi → gstaichi}/__main__.py +0 -0
- {taichi → gstaichi}/_lib/core/__init__.py +0 -0
- {taichi → gstaichi}/_lib/core/py.typed +0 -0
- {taichi/_lib/c_api → gstaichi/_lib}/runtime/libMoltenVK.dylib +0 -0
- {taichi → gstaichi}/algorithms/__init__.py +0 -0
- {taichi → gstaichi}/assets/.git +0 -0
- {taichi → gstaichi}/assets/Go-Regular.ttf +0 -0
- {taichi → gstaichi}/assets/static/imgs/ti_gallery.png +0 -0
- {taichi → gstaichi}/lang/ast/symbol_resolver.py +0 -0
- {taichi → gstaichi}/sparse/__init__.py +0 -0
- {taichi → gstaichi}/tools/np2ply.py +0 -0
- {taichi → gstaichi}/tools/vtk.py +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/GLFW/glfw3.h +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/GLFW/glfw3native.h +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/instrument.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/libspirv.h +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/libspirv.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/linker.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/optimizer.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/GLSL.std.450.h +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv.h +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_cfg.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_common.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_cpp.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_cross.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_cross_c.h +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_cross_containers.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_cross_error_handling.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_cross_util.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_glsl.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_hlsl.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_msl.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_parser.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv_cross/spirv_reflect.hpp +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/glfw3/glfw3Config.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/glfw3/glfw3Targets-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/cmake/glfw3/glfw3Targets.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +0 -0
- {gstaichi-0.1.20.dev0.data → gstaichi-0.1.25.dev0.data}/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +0 -0
- {gstaichi-0.1.20.dev0.dist-info → gstaichi-0.1.25.dev0.dist-info}/WHEEL +0 -0
- {gstaichi-0.1.20.dev0.dist-info → gstaichi-0.1.25.dev0.dist-info}/licenses/LICENSE +0 -0
{taichi → gstaichi}/lang/impl.py
RENAMED
@@ -1,33 +1,31 @@
|
|
1
|
-
# type: ignore
|
2
|
-
|
3
1
|
import numbers
|
4
2
|
from types import FunctionType, MethodType
|
5
3
|
from typing import Any, Iterable, Sequence
|
6
4
|
|
7
5
|
import numpy as np
|
8
6
|
|
9
|
-
from
|
10
|
-
from
|
11
|
-
|
7
|
+
from gstaichi._lib import core as _ti_core
|
8
|
+
from gstaichi._lib.core.gstaichi_python import (
|
9
|
+
DataTypeCxx,
|
12
10
|
Function,
|
13
11
|
Program,
|
14
12
|
)
|
15
|
-
from
|
16
|
-
from
|
17
|
-
from
|
18
|
-
from
|
19
|
-
from
|
20
|
-
from
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
13
|
+
from gstaichi._snode.fields_builder import FieldsBuilder
|
14
|
+
from gstaichi.lang._ndarray import ScalarNdarray
|
15
|
+
from gstaichi.lang._ndrange import GroupedNDRange, _Ndrange
|
16
|
+
from gstaichi.lang._texture import RWTextureAccessor
|
17
|
+
from gstaichi.lang.any_array import AnyArray
|
18
|
+
from gstaichi.lang.exception import (
|
19
|
+
GsTaichiCompilationError,
|
20
|
+
GsTaichiRuntimeError,
|
21
|
+
GsTaichiSyntaxError,
|
22
|
+
GsTaichiTypeError,
|
25
23
|
)
|
26
|
-
from
|
27
|
-
from
|
28
|
-
from
|
29
|
-
from
|
30
|
-
from
|
24
|
+
from gstaichi.lang.expr import Expr, make_expr_group
|
25
|
+
from gstaichi.lang.field import Field, ScalarField
|
26
|
+
from gstaichi.lang.kernel_arguments import SparseMatrixProxy
|
27
|
+
from gstaichi.lang.kernel_impl import BoundGsTaichiCallable, GsTaichiCallable, Kernel
|
28
|
+
from gstaichi.lang.matrix import (
|
31
29
|
Matrix,
|
32
30
|
MatrixField,
|
33
31
|
MatrixNdarray,
|
@@ -36,7 +34,7 @@ from taichi.lang.matrix import (
|
|
36
34
|
VectorNdarray,
|
37
35
|
make_matrix,
|
38
36
|
)
|
39
|
-
from
|
37
|
+
from gstaichi.lang.mesh import (
|
40
38
|
ConvType,
|
41
39
|
MeshElementFieldProxy,
|
42
40
|
MeshInstance,
|
@@ -45,19 +43,19 @@ from taichi.lang.mesh import (
|
|
45
43
|
MeshReorderedScalarFieldProxy,
|
46
44
|
element_type_name,
|
47
45
|
)
|
48
|
-
from
|
49
|
-
from
|
50
|
-
from
|
51
|
-
from
|
46
|
+
from gstaichi.lang.simt.block import SharedArray
|
47
|
+
from gstaichi.lang.snode import SNode
|
48
|
+
from gstaichi.lang.struct import Struct, StructField, _IntermediateStruct
|
49
|
+
from gstaichi.lang.util import (
|
52
50
|
cook_dtype,
|
53
51
|
get_traceback,
|
54
|
-
|
52
|
+
gstaichi_scope,
|
53
|
+
is_gstaichi_class,
|
55
54
|
python_scope,
|
56
|
-
taichi_scope,
|
57
55
|
warning,
|
58
56
|
)
|
59
|
-
from
|
60
|
-
from
|
57
|
+
from gstaichi.types.enums import SNodeGradType
|
58
|
+
from gstaichi.types.primitive_types import (
|
61
59
|
all_types,
|
62
60
|
f16,
|
63
61
|
f32,
|
@@ -70,7 +68,7 @@ from taichi.types.primitive_types import (
|
|
70
68
|
)
|
71
69
|
|
72
70
|
|
73
|
-
@
|
71
|
+
@gstaichi_scope
|
74
72
|
def expr_init_shared_array(shape, element_type):
|
75
73
|
compiling_callable = get_runtime().compiling_callable
|
76
74
|
assert compiling_callable is not None
|
@@ -79,7 +77,7 @@ def expr_init_shared_array(shape, element_type):
|
|
79
77
|
)
|
80
78
|
|
81
79
|
|
82
|
-
@
|
80
|
+
@gstaichi_scope
|
83
81
|
def expr_init(rhs):
|
84
82
|
compiling_callable = get_runtime().compiling_callable
|
85
83
|
assert compiling_callable is not None
|
@@ -88,7 +86,7 @@ def expr_init(rhs):
|
|
88
86
|
compiling_callable.ast_builder().expr_alloca(_ti_core.DebugInfo(get_runtime().get_current_src_info()))
|
89
87
|
)
|
90
88
|
if isinstance(rhs, Matrix) and (hasattr(rhs, "_DIM")):
|
91
|
-
return Matrix(*rhs.to_list(), ndim=rhs.ndim)
|
89
|
+
return Matrix(*rhs.to_list(), ndim=rhs.ndim) # type: ignore
|
92
90
|
if isinstance(rhs, Matrix):
|
93
91
|
return make_matrix(rhs.to_list())
|
94
92
|
if isinstance(rhs, SharedArray):
|
@@ -101,7 +99,7 @@ def expr_init(rhs):
|
|
101
99
|
return tuple(expr_init(e) for e in rhs)
|
102
100
|
if isinstance(rhs, dict):
|
103
101
|
return dict((key, expr_init(val)) for key, val in rhs.items())
|
104
|
-
if isinstance(rhs, _ti_core.
|
102
|
+
if isinstance(rhs, _ti_core.DataTypeCxx):
|
105
103
|
return rhs
|
106
104
|
if isinstance(rhs, _ti_core.Arch):
|
107
105
|
return rhs
|
@@ -120,7 +118,7 @@ def expr_init(rhs):
|
|
120
118
|
)
|
121
119
|
|
122
120
|
|
123
|
-
@
|
121
|
+
@gstaichi_scope
|
124
122
|
def expr_init_func(rhs): # temporary solution to allow passing in fields as arguments
|
125
123
|
if isinstance(rhs, Field):
|
126
124
|
return rhs
|
@@ -130,7 +128,7 @@ def expr_init_func(rhs): # temporary solution to allow passing in fields as arg
|
|
130
128
|
def begin_frontend_struct_for(ast_builder, group, loop_range):
|
131
129
|
if not isinstance(loop_range, (AnyArray, Field, SNode, RWTextureAccessor, _Root)):
|
132
130
|
raise TypeError(
|
133
|
-
f"Cannot loop over the object {type(loop_range)} in
|
131
|
+
f"Cannot loop over the object {type(loop_range)} in GsTaichi scope. Only GsTaichi fields (via template) or dense arrays (via types.ndarray) are supported."
|
134
132
|
)
|
135
133
|
if group.size() != len(loop_range.shape):
|
136
134
|
raise IndexError(
|
@@ -147,7 +145,7 @@ def begin_frontend_struct_for(ast_builder, group, loop_range):
|
|
147
145
|
|
148
146
|
def begin_frontend_if(ast_builder, cond, stmt_dbg_info):
|
149
147
|
assert ast_builder is not None
|
150
|
-
if
|
148
|
+
if is_gstaichi_class(cond):
|
151
149
|
raise ValueError(
|
152
150
|
"The truth value of vectors/matrices is ambiguous.\n"
|
153
151
|
"Consider using `any` or `all` when comparing vectors/matrices:\n"
|
@@ -158,15 +156,15 @@ def begin_frontend_if(ast_builder, cond, stmt_dbg_info):
|
|
158
156
|
ast_builder.begin_frontend_if(Expr(cond).ptr, stmt_dbg_info)
|
159
157
|
|
160
158
|
|
161
|
-
@
|
159
|
+
@gstaichi_scope
|
162
160
|
def _calc_slice(index, default_stop):
|
163
161
|
start, stop, step = index.start or 0, index.stop or default_stop, index.step or 1
|
164
162
|
|
165
163
|
def check_validity(x):
|
166
164
|
# TODO(mzmzm): support variable in slice
|
167
165
|
if isinstance(x, Expr):
|
168
|
-
raise
|
169
|
-
"
|
166
|
+
raise GsTaichiCompilationError(
|
167
|
+
"GsTaichi does not support variables in slice now, please use constant instead of it."
|
170
168
|
)
|
171
169
|
|
172
170
|
check_validity(start), check_validity(stop), check_validity(step)
|
@@ -190,16 +188,16 @@ def validate_subscript_index(value, index):
|
|
190
188
|
validate_subscript_index(value, index.stop)
|
191
189
|
|
192
190
|
if isinstance(index, int) and index < 0:
|
193
|
-
raise
|
191
|
+
raise GsTaichiSyntaxError("Negative indices are not supported in GsTaichi kernels.")
|
194
192
|
|
195
193
|
|
196
|
-
@
|
194
|
+
@gstaichi_scope
|
197
195
|
def subscript(ast_builder, value, *_indices, skip_reordered=False):
|
198
196
|
dbg_info = _ti_core.DebugInfo(get_runtime().get_current_src_info())
|
199
197
|
compiling_callable = get_runtime().compiling_callable
|
200
198
|
assert compiling_callable is not None
|
201
199
|
ast_builder = compiling_callable.ast_builder()
|
202
|
-
# Directly evaluate in Python for non-
|
200
|
+
# Directly evaluate in Python for non-GsTaichi types
|
203
201
|
if not isinstance(
|
204
202
|
value,
|
205
203
|
(
|
@@ -237,14 +235,14 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False):
|
|
237
235
|
indices_expr_group = None
|
238
236
|
if has_slice:
|
239
237
|
if not (isinstance(value, Expr) and value.is_tensor()):
|
240
|
-
raise
|
238
|
+
raise GsTaichiSyntaxError(f"The type {type(value)} do not support index of slice type")
|
241
239
|
else:
|
242
240
|
indices_expr_group = make_expr_group(*indices)
|
243
241
|
|
244
242
|
if isinstance(value, SharedArray):
|
245
243
|
return value.subscript(*indices)
|
246
244
|
if isinstance(value, MeshElementFieldProxy):
|
247
|
-
return value.subscript(*indices)
|
245
|
+
return value.subscript(*indices) # type: ignore
|
248
246
|
if isinstance(value, MeshRelationAccessProxy):
|
249
247
|
return value.subscript(*indices)
|
250
248
|
if isinstance(value, (MeshReorderedScalarFieldProxy, MeshReorderedMatrixFieldProxy)) and not skip_reordered:
|
@@ -333,7 +331,7 @@ class SrcInfoGuard:
|
|
333
331
|
self.info_stack.pop()
|
334
332
|
|
335
333
|
|
336
|
-
class
|
334
|
+
class PyGsTaichi:
|
337
335
|
def __init__(self, kernels=None):
|
338
336
|
self.materialized = False
|
339
337
|
self._prog: Program | None = None
|
@@ -359,13 +357,13 @@ class PyTaichi:
|
|
359
357
|
@property
|
360
358
|
def prog(self) -> Program:
|
361
359
|
if self._prog is None:
|
362
|
-
raise
|
360
|
+
raise GsTaichiRuntimeError("_prog attribute not initialized. Maybe you forgot to call `ti.init()` first?")
|
363
361
|
return self._prog
|
364
362
|
|
365
363
|
@property
|
366
364
|
def current_kernel(self) -> Kernel:
|
367
365
|
if self._current_kernel is None:
|
368
|
-
raise
|
366
|
+
raise GsTaichiRuntimeError(
|
369
367
|
"_pr_current_kernelog attribute not initialized. Maybe you forgot to call `ti.init()` first?"
|
370
368
|
)
|
371
369
|
return self._current_kernel
|
@@ -385,7 +383,7 @@ class PyTaichi:
|
|
385
383
|
if builder == _root_fb:
|
386
384
|
continue
|
387
385
|
|
388
|
-
raise
|
386
|
+
raise GsTaichiRuntimeError(
|
389
387
|
f"Field builder {builder} is not finalized. " f"Please call finalize() on it. Traceback:\n{tb}"
|
390
388
|
)
|
391
389
|
|
@@ -426,7 +424,7 @@ class PyTaichi:
|
|
426
424
|
# if the root itself is empty), so that there is a valid struct
|
427
425
|
# llvm::Module, if no field has been declared before the first kernel
|
428
426
|
# invocation. Example case:
|
429
|
-
# https://github.com/taichi-dev/
|
427
|
+
# https://github.com/taichi-dev/gstaichi/blob/27bb1dc3227d9273a79fcb318fdb06fd053068f5/tests/python/test_ad_basics.py#L260-L266
|
430
428
|
return
|
431
429
|
|
432
430
|
if get_runtime().prog.config().debug:
|
@@ -437,13 +435,6 @@ class PyTaichi:
|
|
437
435
|
global _root_fb
|
438
436
|
_root_fb = FieldsBuilder()
|
439
437
|
|
440
|
-
@staticmethod
|
441
|
-
def _finalize_root_fb_for_aot():
|
442
|
-
if _root_fb.finalized:
|
443
|
-
raise RuntimeError("AOT: can only finalize the root FieldsBuilder once")
|
444
|
-
assert isinstance(_root_fb, FieldsBuilder)
|
445
|
-
_root_fb._finalize_for_aot()
|
446
|
-
|
447
438
|
@staticmethod
|
448
439
|
def _get_tb(_var):
|
449
440
|
return getattr(_var, "declaration_tb", str(_var.ptr))
|
@@ -531,33 +522,33 @@ class PyTaichi:
|
|
531
522
|
self._prog.synchronize()
|
532
523
|
|
533
524
|
|
534
|
-
|
525
|
+
pygstaichi = PyGsTaichi()
|
535
526
|
|
536
527
|
|
537
|
-
def get_runtime() ->
|
538
|
-
return
|
528
|
+
def get_runtime() -> PyGsTaichi:
|
529
|
+
return pygstaichi
|
539
530
|
|
540
531
|
|
541
532
|
def reset():
|
542
|
-
global
|
543
|
-
old_kernels =
|
544
|
-
|
545
|
-
|
533
|
+
global pygstaichi
|
534
|
+
old_kernels = pygstaichi.kernels
|
535
|
+
pygstaichi.clear()
|
536
|
+
pygstaichi = PyGsTaichi(old_kernels)
|
546
537
|
for k in old_kernels:
|
547
538
|
k.reset()
|
548
539
|
_ti_core.reset_default_compile_config()
|
549
540
|
|
550
541
|
|
551
|
-
@
|
542
|
+
@gstaichi_scope
|
552
543
|
def static_print(*args, __p=print, **kwargs):
|
553
|
-
"""The print function in
|
544
|
+
"""The print function in GsTaichi scope.
|
554
545
|
|
555
546
|
This function is called at compile time and has no runtime overhead.
|
556
547
|
"""
|
557
548
|
__p(*args, **kwargs)
|
558
549
|
|
559
550
|
|
560
|
-
# we don't add @
|
551
|
+
# we don't add @gstaichi_scope decorator for @ti.pyfunc to work
|
561
552
|
def static_assert(cond, msg=None):
|
562
553
|
"""Throw AssertionError when `cond` is False.
|
563
554
|
|
@@ -577,7 +568,7 @@ def static_assert(cond, msg=None):
|
|
577
568
|
AssertionError: the year must be a lunar year
|
578
569
|
"""
|
579
570
|
if isinstance(cond, Expr):
|
580
|
-
raise
|
571
|
+
raise GsTaichiTypeError("Static assert with non-static condition")
|
581
572
|
if msg is not None:
|
582
573
|
assert cond, msg
|
583
574
|
else:
|
@@ -585,7 +576,7 @@ def static_assert(cond, msg=None):
|
|
585
576
|
|
586
577
|
|
587
578
|
def inside_kernel():
|
588
|
-
return
|
579
|
+
return pygstaichi.inside_kernel
|
589
580
|
|
590
581
|
|
591
582
|
def index_nd(dim):
|
@@ -597,21 +588,21 @@ class _UninitializedRootFieldsBuilder:
|
|
597
588
|
if item == "__qualname__":
|
598
589
|
# For sphinx docstring extraction.
|
599
590
|
return "_UninitializedRootFieldsBuilder"
|
600
|
-
raise
|
591
|
+
raise GsTaichiRuntimeError("Please call init() first")
|
601
592
|
|
602
593
|
|
603
594
|
# `root` initialization must be delayed until after the program is
|
604
|
-
# created. Unfortunately, `root` exists in both
|
605
|
-
# the top-level
|
595
|
+
# created. Unfortunately, `root` exists in both gstaichi.lang.impl module and
|
596
|
+
# the top-level gstaichi module at this point; so if `root` itself is written, we
|
606
597
|
# would have to make sure that `root` in all the modules get updated to the same
|
607
598
|
# instance. This is an error-prone process.
|
608
599
|
#
|
609
600
|
# To avoid this situation, we create `root` once during the import time, and
|
610
601
|
# never write to it. The core part, `_root_fb`, is the one whose initialization
|
611
|
-
# gets delayed. `_root_fb` will only exist in the
|
602
|
+
# gets delayed. `_root_fb` will only exist in the gstaichi.lang.impl module, so
|
612
603
|
# writing to it is would result in less for maintenance cost.
|
613
604
|
#
|
614
|
-
# `_root_fb` will be overridden inside :func:`
|
605
|
+
# `_root_fb` will be overridden inside :func:`gstaichi.lang.init`.
|
615
606
|
_root_fb = _UninitializedRootFieldsBuilder()
|
616
607
|
|
617
608
|
|
@@ -626,19 +617,19 @@ class _Root:
|
|
626
617
|
|
627
618
|
@staticmethod
|
628
619
|
def parent(n=1):
|
629
|
-
"""Same as :func:`
|
620
|
+
"""Same as :func:`gstaichi.SNode.parent`"""
|
630
621
|
assert isinstance(_root_fb, FieldsBuilder)
|
631
622
|
return _root_fb.root.parent(n)
|
632
623
|
|
633
624
|
@staticmethod
|
634
625
|
def _loop_range():
|
635
|
-
"""Same as :func:`
|
626
|
+
"""Same as :func:`gstaichi.SNode.loop_range`"""
|
636
627
|
assert isinstance(_root_fb, FieldsBuilder)
|
637
628
|
return _root_fb.root._loop_range()
|
638
629
|
|
639
630
|
@staticmethod
|
640
631
|
def _get_children():
|
641
|
-
"""Same as :func:`
|
632
|
+
"""Same as :func:`gstaichi.SNode.get_children`"""
|
642
633
|
assert isinstance(_root_fb, FieldsBuilder)
|
643
634
|
return _root_fb.root._get_children()
|
644
635
|
|
@@ -650,7 +641,7 @@ class _Root:
|
|
650
641
|
|
651
642
|
@property
|
652
643
|
def shape(self):
|
653
|
-
"""Same as :func:`
|
644
|
+
"""Same as :func:`gstaichi.SNode.shape`"""
|
654
645
|
assert isinstance(_root_fb, FieldsBuilder)
|
655
646
|
return _root_fb.root.shape
|
656
647
|
|
@@ -667,7 +658,7 @@ class _Root:
|
|
667
658
|
|
668
659
|
|
669
660
|
root = _Root()
|
670
|
-
"""Root of the declared
|
661
|
+
"""Root of the declared GsTaichi :func:`~gstaichi.lang.impl.field`s.
|
671
662
|
|
672
663
|
See also https://docs.taichi-lang.org/docs/layout
|
673
664
|
|
@@ -702,7 +693,7 @@ def create_field_member(dtype, name, needs_grad, needs_dual):
|
|
702
693
|
x.ptr = _ti_core.expr_field(x.ptr, dtype)
|
703
694
|
x.ptr.set_name(name)
|
704
695
|
x.ptr.set_grad_type(SNodeGradType.PRIMAL)
|
705
|
-
|
696
|
+
pygstaichi.global_vars.append(x)
|
706
697
|
|
707
698
|
x_grad = None
|
708
699
|
x_dual = None
|
@@ -717,13 +708,13 @@ def create_field_member(dtype, name, needs_grad, needs_dual):
|
|
717
708
|
x_grad.ptr.set_grad_type(SNodeGradType.ADJOINT)
|
718
709
|
x.ptr.set_adjoint(x_grad.ptr)
|
719
710
|
if needs_grad:
|
720
|
-
|
711
|
+
pygstaichi.grad_vars.append(x_grad)
|
721
712
|
|
722
713
|
if prog.config().debug:
|
723
714
|
# adjoint checkbit
|
724
715
|
x_grad_checkbit = Expr(prog.make_id_expr(""))
|
725
716
|
dtype = u8
|
726
|
-
if prog.config().arch
|
717
|
+
if prog.config().arch == _ti_core.vulkan:
|
727
718
|
dtype = i32
|
728
719
|
x_grad_checkbit.ptr = _ti_core.expr_field(x_grad_checkbit.ptr, cook_dtype(dtype))
|
729
720
|
x_grad_checkbit.ptr.set_name(name + ".grad_checkbit")
|
@@ -737,9 +728,9 @@ def create_field_member(dtype, name, needs_grad, needs_dual):
|
|
737
728
|
x_dual.ptr.set_grad_type(SNodeGradType.DUAL)
|
738
729
|
x.ptr.set_dual(x_dual.ptr)
|
739
730
|
if needs_dual:
|
740
|
-
|
731
|
+
pygstaichi.dual_vars.append(x_dual)
|
741
732
|
elif needs_grad or needs_dual:
|
742
|
-
raise
|
733
|
+
raise GsTaichiRuntimeError(f"{dtype} is not supported for field with `needs_grad=True` or `needs_dual=True`.")
|
743
734
|
|
744
735
|
return x, x_grad, x_dual
|
745
736
|
|
@@ -765,9 +756,9 @@ def _field(
|
|
765
756
|
|
766
757
|
if shape is None:
|
767
758
|
if offset is not None:
|
768
|
-
raise
|
759
|
+
raise GsTaichiSyntaxError("shape cannot be None when offset is set")
|
769
760
|
if order is not None:
|
770
|
-
raise
|
761
|
+
raise GsTaichiSyntaxError("shape cannot be None when order is set")
|
771
762
|
else:
|
772
763
|
if isinstance(shape, numbers.Number):
|
773
764
|
shape = (shape,)
|
@@ -775,20 +766,22 @@ def _field(
|
|
775
766
|
offset = (offset,)
|
776
767
|
dim = len(shape)
|
777
768
|
if offset is not None and dim != len(offset):
|
778
|
-
raise
|
769
|
+
raise GsTaichiSyntaxError(
|
770
|
+
f"The dimensionality of shape and offset must be the same ({dim} != {len(offset)})"
|
771
|
+
)
|
779
772
|
axis_seq = []
|
780
773
|
shape_seq = []
|
781
774
|
if order is not None:
|
782
775
|
if dim != len(order):
|
783
|
-
raise
|
776
|
+
raise GsTaichiSyntaxError(
|
784
777
|
f"The dimensionality of shape and order must be the same ({dim} != {len(order)})"
|
785
778
|
)
|
786
779
|
if dim != len(set(order)):
|
787
|
-
raise
|
780
|
+
raise GsTaichiSyntaxError("The axes in order must be different")
|
788
781
|
for ch in order:
|
789
782
|
axis = ord(ch) - ord("i")
|
790
783
|
if axis < 0 or axis >= dim:
|
791
|
-
raise
|
784
|
+
raise GsTaichiSyntaxError(f"Invalid axis {ch}")
|
792
785
|
axis_seq.append(axis)
|
793
786
|
shape_seq.append(shape[axis])
|
794
787
|
else:
|
@@ -805,12 +798,12 @@ def _field(
|
|
805
798
|
|
806
799
|
@python_scope
|
807
800
|
def field(dtype, *args, **kwargs):
|
808
|
-
"""Defines a
|
801
|
+
"""Defines a GsTaichi field.
|
809
802
|
|
810
|
-
A
|
811
|
-
the complexity of how its underlying :class:`~
|
812
|
-
actually defined. The data in a
|
813
|
-
a
|
803
|
+
A GsTaichi field can be viewed as an abstract N-dimensional array, hiding away
|
804
|
+
the complexity of how its underlying :class:`~gstaichi.lang.snode.SNode` are
|
805
|
+
actually defined. The data in a GsTaichi field can be directly accessed by
|
806
|
+
a GsTaichi :func:`~gstaichi.lang.kernel_impl.kernel`.
|
814
807
|
|
815
808
|
See also https://docs.taichi-lang.org/docs/field
|
816
809
|
|
@@ -827,7 +820,7 @@ def field(dtype, *args, **kwargs):
|
|
827
820
|
|
828
821
|
Example::
|
829
822
|
|
830
|
-
The code below shows how a
|
823
|
+
The code below shows how a GsTaichi field can be declared and defined::
|
831
824
|
|
832
825
|
>>> x1 = ti.field(ti.f32, shape=(16, 8))
|
833
826
|
>>> # Equivalently
|
@@ -851,14 +844,14 @@ def field(dtype, *args, **kwargs):
|
|
851
844
|
|
852
845
|
@python_scope
|
853
846
|
def ndarray(dtype, shape, needs_grad=False):
|
854
|
-
"""Defines a
|
847
|
+
"""Defines a GsTaichi ndarray with scalar elements.
|
855
848
|
|
856
849
|
Args:
|
857
850
|
dtype (Union[DataType, MatrixType]): Data type of each element. This can be either a scalar type like ti.f32 or a compound type like ti.types.vector(3, ti.i32).
|
858
851
|
shape (Union[int, tuple[int]]): Shape of the ndarray.
|
859
852
|
|
860
853
|
Example:
|
861
|
-
The code below shows how a
|
854
|
+
The code below shows how a GsTaichi ndarray with scalar elements can be declared and defined::
|
862
855
|
|
863
856
|
>>> x = ti.ndarray(ti.f32, shape=(16, 8)) # ndarray of shape (16, 8), each element is ti.f32 scalar.
|
864
857
|
>>> vec3 = ti.types.vector(3, ti.i32)
|
@@ -870,7 +863,7 @@ def ndarray(dtype, shape, needs_grad=False):
|
|
870
863
|
if isinstance(shape, numbers.Number):
|
871
864
|
shape = (shape,)
|
872
865
|
if not all((isinstance(x, int) or isinstance(x, np.integer)) and x > 0 and x <= 2**31 - 1 for x in shape):
|
873
|
-
raise
|
866
|
+
raise GsTaichiRuntimeError(f"{shape} is not a valid shape for ndarray")
|
874
867
|
if dtype in all_types:
|
875
868
|
dt = cook_dtype(dtype)
|
876
869
|
x = ScalarNdarray(dt, shape)
|
@@ -881,17 +874,19 @@ def ndarray(dtype, shape, needs_grad=False):
|
|
881
874
|
x = MatrixNdarray(dtype.n, dtype.m, dtype.dtype, shape)
|
882
875
|
dt = dtype.dtype
|
883
876
|
else:
|
884
|
-
raise
|
877
|
+
raise GsTaichiRuntimeError(f"{dtype} is not supported as ndarray element type")
|
885
878
|
if needs_grad:
|
886
|
-
assert isinstance(dt,
|
879
|
+
assert isinstance(dt, DataTypeCxx)
|
887
880
|
if not _ti_core.is_real(dt):
|
888
|
-
raise
|
881
|
+
raise GsTaichiRuntimeError(
|
882
|
+
f"{dt} is not supported for ndarray with `needs_grad=True` or `needs_dual=True`."
|
883
|
+
)
|
889
884
|
x_grad = ndarray(dtype, shape, needs_grad=False)
|
890
885
|
x._set_grad(x_grad)
|
891
886
|
return x
|
892
887
|
|
893
888
|
|
894
|
-
@
|
889
|
+
@gstaichi_scope
|
895
890
|
def ti_format_list_to_content_entries(raw):
|
896
891
|
# return a pair of [content, format]
|
897
892
|
def entry2content(_var):
|
@@ -919,7 +914,7 @@ def ti_format_list_to_content_entries(raw):
|
|
919
914
|
yield _var[1:]
|
920
915
|
continue
|
921
916
|
elif hasattr(_var, "__ti_repr__"):
|
922
|
-
res = _var.__ti_repr__()
|
917
|
+
res = _var.__ti_repr__() # type: ignore
|
923
918
|
elif isinstance(_var, (list, tuple)):
|
924
919
|
# If the first element is '__ti_format__', this list is the result of ti_format.
|
925
920
|
if len(_var) > 0 and isinstance(_var[0], str) and _var[0] == "__ti_format__":
|
@@ -956,7 +951,7 @@ def ti_format_list_to_content_entries(raw):
|
|
956
951
|
return extract_formats(entries)
|
957
952
|
|
958
953
|
|
959
|
-
@
|
954
|
+
@gstaichi_scope
|
960
955
|
def ti_print(*_vars, sep=" ", end="\n"):
|
961
956
|
def add_separators(_vars):
|
962
957
|
for i, _var in enumerate(_vars):
|
@@ -974,7 +969,7 @@ def ti_print(*_vars, sep=" ", end="\n"):
|
|
974
969
|
)
|
975
970
|
|
976
971
|
|
977
|
-
@
|
972
|
+
@gstaichi_scope
|
978
973
|
def ti_format(*args):
|
979
974
|
content = args[0]
|
980
975
|
mixed = args[1:]
|
@@ -997,7 +992,7 @@ def ti_format(*args):
|
|
997
992
|
return res
|
998
993
|
|
999
994
|
|
1000
|
-
@
|
995
|
+
@gstaichi_scope
|
1001
996
|
def ti_assert(cond, msg, extra_args, dbg_info):
|
1002
997
|
# Mostly a wrapper to help us convert from Expr (defined in Python) to
|
1003
998
|
# _ti_core.Expr (defined in C++)
|
@@ -1006,35 +1001,35 @@ def ti_assert(cond, msg, extra_args, dbg_info):
|
|
1006
1001
|
compiling_callable.ast_builder().create_assert_stmt(Expr(cond).ptr, msg, extra_args, dbg_info)
|
1007
1002
|
|
1008
1003
|
|
1009
|
-
@
|
1004
|
+
@gstaichi_scope
|
1010
1005
|
def ti_int(_var):
|
1011
1006
|
if hasattr(_var, "__ti_int__"):
|
1012
1007
|
return _var.__ti_int__()
|
1013
1008
|
return int(_var)
|
1014
1009
|
|
1015
1010
|
|
1016
|
-
@
|
1011
|
+
@gstaichi_scope
|
1017
1012
|
def ti_bool(_var):
|
1018
1013
|
if hasattr(_var, "__ti_bool__"):
|
1019
1014
|
return _var.__ti_bool__()
|
1020
1015
|
return bool(_var)
|
1021
1016
|
|
1022
1017
|
|
1023
|
-
@
|
1018
|
+
@gstaichi_scope
|
1024
1019
|
def ti_float(_var):
|
1025
1020
|
if hasattr(_var, "__ti_float__"):
|
1026
1021
|
return _var.__ti_float__()
|
1027
1022
|
return float(_var)
|
1028
1023
|
|
1029
1024
|
|
1030
|
-
@
|
1025
|
+
@gstaichi_scope
|
1031
1026
|
def zero(x):
|
1032
1027
|
# TODO: get dtype from Expr and Matrix:
|
1033
1028
|
"""Returns an array of zeros with the same shape and type as the input. It's also a scalar
|
1034
1029
|
if the input is a scalar.
|
1035
1030
|
|
1036
1031
|
Args:
|
1037
|
-
x (Union[:mod:`~
|
1032
|
+
x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]): The input.
|
1038
1033
|
|
1039
1034
|
Returns:
|
1040
1035
|
A new copy of the input but filled with zeros.
|
@@ -1051,13 +1046,13 @@ def zero(x):
|
|
1051
1046
|
return x * 0
|
1052
1047
|
|
1053
1048
|
|
1054
|
-
@
|
1049
|
+
@gstaichi_scope
|
1055
1050
|
def one(x):
|
1056
1051
|
"""Returns an array of ones with the same shape and type as the input. It's also a scalar
|
1057
1052
|
if the input is a scalar.
|
1058
1053
|
|
1059
1054
|
Args:
|
1060
|
-
x (Union[:mod:`~
|
1055
|
+
x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]): The input.
|
1061
1056
|
|
1062
1057
|
Returns:
|
1063
1058
|
A new copy of the input but filled with ones.
|
@@ -1080,7 +1075,7 @@ def axes(*x: int):
|
|
1080
1075
|
Args:
|
1081
1076
|
*x: A list of axes to be activated
|
1082
1077
|
|
1083
|
-
Note that
|
1078
|
+
Note that GsTaichi has already provided a set of commonly used axes. For example,
|
1084
1079
|
`ti.ij` is just `axes(0, 1)` under the hood.
|
1085
1080
|
"""
|
1086
1081
|
return [_ti_core.Axis(i) for i in x]
|
@@ -1090,9 +1085,9 @@ Axis = _ti_core.Axis
|
|
1090
1085
|
|
1091
1086
|
|
1092
1087
|
def static(x, *xs) -> Any:
|
1093
|
-
"""Evaluates a
|
1088
|
+
"""Evaluates a GsTaichi-scope expression at compile time.
|
1094
1089
|
|
1095
|
-
`static()` is what enables the so-called metaprogramming in
|
1090
|
+
`static()` is what enables the so-called metaprogramming in GsTaichi. It is
|
1096
1091
|
in many ways similar to ``constexpr`` in C++.
|
1097
1092
|
|
1098
1093
|
See also https://docs.taichi-lang.org/docs/meta.
|
@@ -1162,12 +1157,12 @@ def static(x, *xs) -> Any:
|
|
1162
1157
|
return x
|
1163
1158
|
if isinstance(x, Field):
|
1164
1159
|
return x
|
1165
|
-
if isinstance(x, (FunctionType, MethodType)):
|
1160
|
+
if isinstance(x, (FunctionType, MethodType, BoundGsTaichiCallable, GsTaichiCallable)):
|
1166
1161
|
return x
|
1167
1162
|
raise ValueError(f"Input to ti.static must be compile-time constants or global pointers, instead of {type(x)}")
|
1168
1163
|
|
1169
1164
|
|
1170
|
-
@
|
1165
|
+
@gstaichi_scope
|
1171
1166
|
def grouped(x):
|
1172
1167
|
"""Groups the indices in the iterator returned by `ndrange()` into a 1-D vector.
|
1173
1168
|
|
@@ -1175,7 +1170,7 @@ def grouped(x):
|
|
1175
1170
|
in one `for` loop and a single index.
|
1176
1171
|
|
1177
1172
|
Args:
|
1178
|
-
x (:func:`~
|
1173
|
+
x (:func:`~gstaichi.ndrange`): an iterator object returned by `ti.ndrange`.
|
1179
1174
|
|
1180
1175
|
Example::
|
1181
1176
|
>>> # without ti.grouped
|
@@ -1197,7 +1192,7 @@ def stop_grad(x):
|
|
1197
1192
|
"""Stops computing gradients during back propagation.
|
1198
1193
|
|
1199
1194
|
Args:
|
1200
|
-
x (:class:`~
|
1195
|
+
x (:class:`~gstaichi.Field`): A field.
|
1201
1196
|
"""
|
1202
1197
|
compiling_callable = get_runtime().compiling_callable
|
1203
1198
|
assert compiling_callable is not None
|
@@ -1220,7 +1215,7 @@ def get_cuda_compute_capability():
|
|
1220
1215
|
return _ti_core.query_int64("cuda_compute_capability")
|
1221
1216
|
|
1222
1217
|
|
1223
|
-
@
|
1218
|
+
@gstaichi_scope
|
1224
1219
|
def mesh_relation_access(mesh, from_index, to_element_type):
|
1225
1220
|
# to support ti.mesh_local and access mesh attribute as field
|
1226
1221
|
if isinstance(from_index, MeshInstance):
|