gstaichi 0.1.23.dev0__cp310-cp310-macosx_15_0_arm64.whl → 1.0.1__cp310-cp310-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +6 -0
- gstaichi/__init__.py +40 -0
- {taichi → gstaichi}/_funcs.py +8 -8
- {taichi → gstaichi}/_kernels.py +19 -19
- gstaichi/_lib/__init__.py +3 -0
- taichi/_lib/core/taichi_python.cpython-310-darwin.so → gstaichi/_lib/core/gstaichi_python.cpython-310-darwin.so +0 -0
- taichi/_lib/core/taichi_python.pyi → gstaichi/_lib/core/gstaichi_python.pyi +382 -520
- {taichi → gstaichi}/_lib/runtime/runtime_arm64.bc +0 -0
- {taichi → gstaichi}/_lib/utils.py +15 -15
- {taichi → gstaichi}/_logging.py +1 -1
- gstaichi/_snode/__init__.py +5 -0
- {taichi → gstaichi}/_snode/fields_builder.py +27 -29
- {taichi → gstaichi}/_snode/snode_tree.py +5 -5
- gstaichi/_test_tools/__init__.py +0 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_version.py +1 -0
- {taichi → gstaichi}/_version_check.py +8 -5
- gstaichi/ad/__init__.py +3 -0
- {taichi → gstaichi}/ad/_ad.py +26 -26
- {taichi → gstaichi}/algorithms/_algorithms.py +7 -7
- {taichi → gstaichi}/examples/minimal.py +1 -1
- {taichi → gstaichi}/experimental.py +1 -1
- gstaichi/lang/__init__.py +50 -0
- {taichi → gstaichi}/lang/_ndarray.py +30 -26
- {taichi → gstaichi}/lang/_ndrange.py +8 -8
- gstaichi/lang/_template_mapper.py +199 -0
- {taichi → gstaichi}/lang/_texture.py +19 -19
- {taichi → gstaichi}/lang/_wrap_inspect.py +7 -7
- {taichi → gstaichi}/lang/any_array.py +13 -13
- {taichi → gstaichi}/lang/argpack.py +29 -29
- gstaichi/lang/ast/__init__.py +5 -0
- {taichi → gstaichi}/lang/ast/ast_transformer.py +94 -582
- {taichi → gstaichi}/lang/ast/ast_transformer_utils.py +54 -41
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
- {taichi → gstaichi}/lang/ast/checkers.py +5 -5
- gstaichi/lang/ast/transform.py +9 -0
- {taichi → gstaichi}/lang/common_ops.py +12 -12
- gstaichi/lang/exception.py +80 -0
- {taichi → gstaichi}/lang/expr.py +22 -22
- {taichi → gstaichi}/lang/field.py +29 -27
- {taichi → gstaichi}/lang/impl.py +116 -121
- {taichi → gstaichi}/lang/kernel_arguments.py +16 -16
- {taichi → gstaichi}/lang/kernel_impl.py +330 -363
- {taichi → gstaichi}/lang/matrix.py +119 -115
- {taichi → gstaichi}/lang/matrix_ops.py +6 -6
- {taichi → gstaichi}/lang/matrix_ops_utils.py +4 -4
- {taichi → gstaichi}/lang/mesh.py +22 -22
- {taichi → gstaichi}/lang/misc.py +39 -68
- {taichi → gstaichi}/lang/ops.py +146 -141
- {taichi → gstaichi}/lang/runtime_ops.py +2 -2
- {taichi → gstaichi}/lang/shell.py +3 -3
- {taichi → gstaichi}/lang/simt/__init__.py +1 -1
- {taichi → gstaichi}/lang/simt/block.py +7 -7
- {taichi → gstaichi}/lang/simt/grid.py +1 -1
- {taichi → gstaichi}/lang/simt/subgroup.py +1 -1
- {taichi → gstaichi}/lang/simt/warp.py +1 -1
- {taichi → gstaichi}/lang/snode.py +46 -44
- {taichi → gstaichi}/lang/source_builder.py +13 -13
- {taichi → gstaichi}/lang/struct.py +33 -33
- {taichi → gstaichi}/lang/util.py +24 -24
- gstaichi/linalg/__init__.py +8 -0
- {taichi → gstaichi}/linalg/matrixfree_cg.py +14 -14
- {taichi → gstaichi}/linalg/sparse_cg.py +10 -10
- {taichi → gstaichi}/linalg/sparse_matrix.py +23 -23
- {taichi → gstaichi}/linalg/sparse_solver.py +21 -21
- {taichi → gstaichi}/math/__init__.py +1 -1
- {taichi → gstaichi}/math/_complex.py +21 -20
- {taichi → gstaichi}/math/mathimpl.py +56 -56
- gstaichi/profiler/__init__.py +6 -0
- {taichi → gstaichi}/profiler/kernel_metrics.py +11 -11
- {taichi → gstaichi}/profiler/kernel_profiler.py +30 -36
- {taichi → gstaichi}/profiler/memory_profiler.py +1 -1
- {taichi → gstaichi}/profiler/scoped_profiler.py +2 -2
- {taichi → gstaichi}/sparse/_sparse_grid.py +7 -7
- {taichi → gstaichi}/tools/__init__.py +4 -4
- {taichi → gstaichi}/tools/diagnose.py +10 -17
- gstaichi/types/__init__.py +19 -0
- {taichi → gstaichi}/types/annotations.py +1 -1
- {taichi → gstaichi}/types/compound_types.py +8 -8
- {taichi → gstaichi}/types/enums.py +1 -1
- {taichi → gstaichi}/types/ndarray_type.py +7 -7
- {taichi → gstaichi}/types/primitive_types.py +17 -14
- {taichi → gstaichi}/types/quant.py +9 -9
- {taichi → gstaichi}/types/texture_type.py +5 -5
- {taichi → gstaichi}/types/utils.py +1 -1
- {gstaichi-0.1.23.dev0.dist-info → gstaichi-1.0.1.dist-info}/METADATA +13 -16
- gstaichi-1.0.1.dist-info/RECORD +166 -0
- gstaichi-1.0.1.dist-info/top_level.txt +1 -0
- gstaichi-0.1.23.dev0.dist-info/RECORD +0 -219
- gstaichi-0.1.23.dev0.dist-info/entry_points.txt +0 -2
- gstaichi-0.1.23.dev0.dist-info/top_level.txt +0 -1
- taichi/__init__.py +0 -44
- taichi/__main__.py +0 -5
- taichi/_lib/__init__.py +0 -3
- taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +0 -1401
- taichi/_lib/c_api/include/taichi/taichi.h +0 -29
- taichi/_lib/c_api/include/taichi/taichi_core.h +0 -1111
- taichi/_lib/c_api/include/taichi/taichi_cpu.h +0 -29
- taichi/_lib/c_api/include/taichi/taichi_metal.h +0 -72
- taichi/_lib/c_api/include/taichi/taichi_platform.h +0 -55
- taichi/_lib/c_api/include/taichi/taichi_unity.h +0 -64
- taichi/_lib/c_api/include/taichi/taichi_vulkan.h +0 -151
- taichi/_lib/c_api/lib/libtaichi_c_api.dylib +0 -0
- taichi/_lib/c_api/runtime/runtime_arm64.bc +0 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +0 -29
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +0 -65
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +0 -121
- taichi/_lib/runtime/libMoltenVK.dylib +0 -0
- taichi/_main.py +0 -552
- taichi/_snode/__init__.py +0 -5
- taichi/_ti_module/__init__.py +0 -3
- taichi/_ti_module/cppgen.py +0 -309
- taichi/_ti_module/module.py +0 -145
- taichi/_version.py +0 -1
- taichi/ad/__init__.py +0 -3
- taichi/aot/__init__.py +0 -12
- taichi/aot/_export.py +0 -28
- taichi/aot/conventions/__init__.py +0 -3
- taichi/aot/conventions/gfxruntime140/__init__.py +0 -38
- taichi/aot/conventions/gfxruntime140/dr.py +0 -244
- taichi/aot/conventions/gfxruntime140/sr.py +0 -613
- taichi/aot/module.py +0 -253
- taichi/aot/utils.py +0 -151
- taichi/graph/__init__.py +0 -3
- taichi/graph/_graph.py +0 -292
- taichi/lang/__init__.py +0 -50
- taichi/lang/ast/__init__.py +0 -5
- taichi/lang/ast/transform.py +0 -9
- taichi/lang/exception.py +0 -80
- taichi/linalg/__init__.py +0 -8
- taichi/profiler/__init__.py +0 -6
- taichi/shaders/Circles_vk.frag +0 -29
- taichi/shaders/Circles_vk.vert +0 -45
- taichi/shaders/Circles_vk_frag.spv +0 -0
- taichi/shaders/Circles_vk_vert.spv +0 -0
- taichi/shaders/Lines_vk.frag +0 -9
- taichi/shaders/Lines_vk.vert +0 -11
- taichi/shaders/Lines_vk_frag.spv +0 -0
- taichi/shaders/Lines_vk_vert.spv +0 -0
- taichi/shaders/Mesh_vk.frag +0 -71
- taichi/shaders/Mesh_vk.vert +0 -68
- taichi/shaders/Mesh_vk_frag.spv +0 -0
- taichi/shaders/Mesh_vk_vert.spv +0 -0
- taichi/shaders/Particles_vk.frag +0 -95
- taichi/shaders/Particles_vk.vert +0 -73
- taichi/shaders/Particles_vk_frag.spv +0 -0
- taichi/shaders/Particles_vk_vert.spv +0 -0
- taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
- taichi/shaders/SceneLines_vk.frag +0 -9
- taichi/shaders/SceneLines_vk.vert +0 -12
- taichi/shaders/SceneLines_vk_frag.spv +0 -0
- taichi/shaders/SceneLines_vk_vert.spv +0 -0
- taichi/shaders/SetImage_vk.frag +0 -21
- taichi/shaders/SetImage_vk.vert +0 -15
- taichi/shaders/SetImage_vk_frag.spv +0 -0
- taichi/shaders/SetImage_vk_vert.spv +0 -0
- taichi/shaders/Triangles_vk.frag +0 -16
- taichi/shaders/Triangles_vk.vert +0 -29
- taichi/shaders/Triangles_vk_frag.spv +0 -0
- taichi/shaders/Triangles_vk_vert.spv +0 -0
- taichi/shaders/lines2quad_vk_comp.spv +0 -0
- taichi/types/__init__.py +0 -19
- {taichi → gstaichi}/_lib/core/__init__.py +0 -0
- {taichi → gstaichi}/_lib/core/py.typed +0 -0
- {taichi/_lib/c_api → gstaichi/_lib}/runtime/libMoltenVK.dylib +0 -0
- {taichi → gstaichi}/algorithms/__init__.py +0 -0
- {taichi → gstaichi}/assets/.git +0 -0
- {taichi → gstaichi}/assets/Go-Regular.ttf +0 -0
- {taichi → gstaichi}/assets/static/imgs/ti_gallery.png +0 -0
- {taichi → gstaichi}/lang/ast/symbol_resolver.py +0 -0
- {taichi → gstaichi}/sparse/__init__.py +0 -0
- {taichi → gstaichi}/tools/np2ply.py +0 -0
- {taichi → gstaichi}/tools/vtk.py +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/GLFW/glfw3.h +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/GLFW/glfw3native.h +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/instrument.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/libspirv.h +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/libspirv.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/linker.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/optimizer.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/GLSL.std.450.h +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv.h +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_cfg.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_common.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_cpp.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_cross.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_cross_c.h +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_cross_containers.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_cross_error_handling.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_cross_util.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_glsl.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_hlsl.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_msl.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_parser.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_reflect.hpp +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/glfw3/glfw3Config.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/glfw3/glfw3Targets-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/glfw3/glfw3Targets.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +0 -0
- {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +0 -0
- {gstaichi-0.1.23.dev0.dist-info → gstaichi-1.0.1.dist-info}/WHEEL +0 -0
- {gstaichi-0.1.23.dev0.dist-info → gstaichi-1.0.1.dist-info}/licenses/LICENSE +0 -0
{taichi → gstaichi}/lang/kernel_impl.py

@@ -1,5 +1,3 @@
-# type: ignore
-
 import ast
 import dataclasses
 import functools
@@ -15,60 +13,181 @@ import time
 import types
 import typing
 import warnings
-import
-from typing import Any, Callable, Type, Union
+from typing import Any, Callable, Type
 
 import numpy as np
 
-import
-import
-import
-import
-import
-import
-from
-
-
-
-
-
-from
-from
+import gstaichi.lang
+import gstaichi.lang._ndarray
+import gstaichi.lang._texture
+import gstaichi.types.annotations
+from gstaichi import _logging
+from gstaichi._lib import core as _ti_core
+from gstaichi._lib.core.gstaichi_python import (
+    ASTBuilder,
+    FunctionKey,
+    KernelCxx,
+    KernelLaunchContext,
+)
+from gstaichi.lang import impl, ops, runtime_ops
+from gstaichi.lang._template_mapper import GsTaichiCallableTemplateMapper
+from gstaichi.lang._wrap_inspect import getsourcefile, getsourcelines
+from gstaichi.lang.any_array import AnyArray
+from gstaichi.lang.argpack import ArgPack, ArgPackType
+from gstaichi.lang.ast import (
     ASTTransformerContext,
     KernelSimplicityASTChecker,
     transform_tree,
 )
-from
-from
-
-
-
-
-
+from gstaichi.lang.ast.ast_transformer_utils import ReturnStatus
+from gstaichi.lang.exception import (
+    GsTaichiCompilationError,
+    GsTaichiRuntimeError,
+    GsTaichiRuntimeTypeError,
+    GsTaichiSyntaxError,
+    GsTaichiTypeError,
     handle_exception_from_cpp,
 )
-from
-from
-from
-from
-from
-from
-from
+from gstaichi.lang.expr import Expr
+from gstaichi.lang.kernel_arguments import KernelArgument
+from gstaichi.lang.matrix import MatrixType
+from gstaichi.lang.shell import _shell_pop_print
+from gstaichi.lang.struct import StructType
+from gstaichi.lang.util import cook_dtype, has_paddle, has_pytorch
+from gstaichi.types import (
     ndarray_type,
     primitive_types,
     sparse_matrix_builder,
     template,
     texture_type,
 )
-from
-from
-from
+from gstaichi.types.compound_types import CompoundType
+from gstaichi.types.enums import AutodiffMode, Layout
+from gstaichi.types.utils import is_signed
+
+CompiledKernelKeyType = tuple[Callable, int, AutodiffMode]
 
 
-
-"""
+class GsTaichiCallable:
+    """
+    BoundGsTaichiCallable is used to enable wrapping a bindable function with a class.
+
+    Design requirements for GsTaichiCallable:
+    - wrap/contain a reference to a class Func instance, and allow (the GsTaichiCallable) being passed around
+      like normal function pointer
+    - expose attributes of the wrapped class Func, such as `_if_real_function`, `_primal`, etc
+    - allow for (now limited) strong typing, and enable type checkers, such as pyright/mypy
+    - currently GsTaichiCallable is a shared type used for all functions marked with @ti.func, @ti.kernel,
+      python functions (?)
+    - note: current type-checking implementation does not distinguish between different type flavors of
+      GsTaichiCallable, with different values of `_if_real_function`, `_primal`, etc
+    - handle not only class-less functions, but also class-instance methods (where determining the `self`
+      reference is a challenge)
+
+    Let's take the following example:
+
+    def test_ptr_class_func():
+        @ti.data_oriented
+        class MyClass:
+            def __init__(self):
+                self.a = ti.field(dtype=ti.f32, shape=(3))
+
+            def add2numbers_py(self, x, y):
+                return x + y
+
+            @ti.func
+            def add2numbers_func(self, x, y):
+                return x + y
+
+            @ti.kernel
+            def func(self):
+                a, add_py, add_func = ti.static(self.a, self.add2numbers_py, self.add2numbers_func)
+                a[0] = add_py(2, 3)
+                a[1] = add_func(3, 7)
+
+    (taken from test_ptr_assign.py).
+
+    When the @ti.func decorator is parsed, the function `add2numbers_func` exists, but there is not yet any `self`
+    - it is not possible for the method to be bound, to a `self` instance
+    - however, the @ti.func annotation, runs the kernel_imp.py::func function --- it is at this point
+      that GsTaichi's original code creates a class Func instance (that wraps the add2numbers_func)
+      and immediately we create a GsTaichiCallable instance that wraps the Func instance.
+    - effectively, we have two layers of wrapping GsTaichiCallable->Func->function pointer
+      (actual function definition)
+    - later on, when we call self.add2numbers_py, here:
+
+        a, add_py, add_func = ti.static(self.a, self.add2numbers_py, self.add2numbers_func)
+
+      ... we want to call the bound method, `self.add2numbers_py`.
+    - an actual python function reference, created by doing somevar = MyClass.add2numbers, can automatically
+      binds to self, when called from self in this way (however, add2numbers_py is actually a class
+      Func instance, wrapping python function reference -- now also all wrapped by a GsTaichiCallable
+      instance -- returned by the kernel_impl.py::func function, run by @ti.func)
+    - however, in order to be able to add strongly typed attributes to the wrapped python function, we need
+      to wrap the wrapped python function in a class
+    - the wrapped python function, wrapped in a GsTaichiCallable class (which is callable, and will
+      execute the underlying double-wrapped python function), will NOT automatically bind
+    - when we invoke GsTaichiCallable, the wrapped function is invoked. The wrapped function is unbound, and
+      so `self` is not automatically passed in, as an argument, and things break
+
+    To address this we need to use the `__get__` method, in our function wrapper, ie GsTaichiCallable,
+    and have the `__get__` method return the `BoundGsTaichiCallable` object. The `__get__` method handles
+    running the binding for us, and effectively binds `BoundFunc` object to `self` object, by passing
+    in the instance, as an argument into `BoundGsTaichiCallable.__init__`.
+
+    `BoundFunc` can then be used as a normal bound func - even though it's just an object instance -
+    using its `__call__` method. Effectively, at the time of actually invoking the underlying python
+    function, we have 3 layers of wrapper instances:
+    BoundGsTaichiCallabe -> GsTaichiCallable -> Func -> python function reference/definition
+    """
+
+    def __init__(self, fn: Callable, wrapper: Callable) -> None:
+        self.fn: Callable = fn
+        self.wrapper: Callable = wrapper
+        self._is_real_function: bool = False
+        self._is_gstaichi_function: bool = False
+        self._is_wrapped_kernel: bool = False
+        self._is_classkernel: bool = False
+        self._primal: Kernel | None = None
+        self._adjoint: Kernel | None = None
+        self.grad: Kernel | None = None
+        self._is_staticmethod: bool = False
+        functools.update_wrapper(self, fn)
+
+    def __call__(self, *args, **kwargs):
+        return self.wrapper.__call__(*args, **kwargs)
+
+    def __get__(self, instance, owner):
+        if instance is None:
+            return self
+        return BoundGsTaichiCallable(instance, self)
+
+
+class BoundGsTaichiCallable:
+    def __init__(self, instance: Any, gstaichi_callable: "GsTaichiCallable"):
+        self.wrapper = gstaichi_callable.wrapper
+        self.instance = instance
+        self.gstaichi_callable = gstaichi_callable
+
+    def __call__(self, *args, **kwargs):
+        return self.wrapper(self.instance, *args, **kwargs)
+
+    def __getattr__(self, k: str) -> Any:
+        res = getattr(self.gstaichi_callable, k)
+        return res
+
+    def __setattr__(self, k: str, v: Any) -> None:
+        # Note: these have to match the name of any attributes on this class.
+        if k in ("wrapper", "instance", "gstaichi_callable"):
+            object.__setattr__(self, k, v)
+        else:
+            setattr(self.gstaichi_callable, k, v)
 
-
+
+def func(fn: Callable, is_real_function: bool = False) -> GsTaichiCallable:
+    """Marks a function as callable in GsTaichi-scope.
+
+    This decorator transforms a Python function into a GsTaichi one. GsTaichi
     will JIT compile it into native instructions.
 
     Args:
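The GsTaichiCallable docstring above reasons through why a plain callable wrapper object loses implicit `self` binding and why implementing `__get__` restores it. A minimal, self-contained sketch of that descriptor pattern (illustrative only; `Wrapper` and `BoundWrapper` are hypothetical names, not the package's classes):

```python
import functools


class Wrapper:
    def __init__(self, fn):
        self.fn = fn
        functools.update_wrapper(self, fn)  # expose __name__, __doc__, etc. of fn

    def __call__(self, *args, **kwargs):
        # Unbound call: no owning instance is injected here.
        return self.fn(*args, **kwargs)

    def __get__(self, instance, owner):
        # Descriptor protocol: attribute access on an instance hands back a
        # bound proxy, mirroring GsTaichiCallable -> BoundGsTaichiCallable.
        if instance is None:
            return self
        return BoundWrapper(instance, self)


class BoundWrapper:
    def __init__(self, instance, wrapper):
        self.instance = instance
        self.wrapper = wrapper

    def __call__(self, *args, **kwargs):
        # Pass the captured instance explicitly, emulating method binding.
        return self.wrapper.fn(self.instance, *args, **kwargs)


class MyClass:
    def add(self, x, y):
        return x + y

    add = Wrapper(add)  # roughly what a decorator such as @ti.func does


obj = MyClass()
assert obj.add(2, 3) == 5  # __get__ returned a BoundWrapper, so `self` is supplied
```

Accessing `obj.add` triggers `Wrapper.__get__`, which supplies the instance explicitly; without `__get__`, the call would reach the underlying function with `self` missing, which is the failure mode the docstring describes.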
@@ -91,29 +210,24 @@ def func(fn: Callable, is_real_function: bool = False):
     is_classfunc = _inside_class(level_of_class_stackframe=3 + is_real_function)
 
     fun = Func(fn, _classfunc=is_classfunc, is_real_function=is_real_function)
-
-
-
-
-
-    decorated._is_taichi_function = True
-    decorated._is_real_function = is_real_function
-    decorated.func = fun
-    return decorated
+    gstaichi_callable = GsTaichiCallable(fn, fun)
+    gstaichi_callable._is_gstaichi_function = True
+    gstaichi_callable._is_real_function = is_real_function
+    return gstaichi_callable
 
 
-def real_func(fn: Callable):
+def real_func(fn: Callable) -> GsTaichiCallable:
     return func(fn, is_real_function=True)
 
 
-def pyfunc(fn: Callable):
-    """Marks a function as callable in both
+def pyfunc(fn: Callable) -> GsTaichiCallable:
+    """Marks a function as callable in both GsTaichi and Python scopes.
 
-    When called inside the
+    When called inside the GsTaichi scope, GsTaichi will JIT compile it into
     native instructions. Otherwise it will be invoked directly as a
     Python function.
 
-    See also :func:`~
+    See also :func:`~gstaichi.lang.kernel_impl.func`.
 
     Args:
         fn (Callable): The Python function to be decorated
@@ -123,33 +237,28 @@ def pyfunc(fn: Callable):
     """
     is_classfunc = _inside_class(level_of_class_stackframe=3)
     fun = Func(fn, _classfunc=is_classfunc, _pyfunc=True)
-
-
-
-
-
-    decorated._is_taichi_function = True
-    decorated._is_real_function = False
-    decorated.func = fun
-    return decorated
+    gstaichi_callable = GsTaichiCallable(fn, fun)
+    gstaichi_callable._is_gstaichi_function = True
+    gstaichi_callable._is_real_function = False
+    return gstaichi_callable
 
 
 def _get_tree_and_ctx(
     self: "Func | Kernel",
+    args: tuple[Any, ...],
     excluded_parameters=(),
     is_kernel: bool = True,
     arg_features=None,
-    args=None,
     ast_builder: ASTBuilder | None = None,
     is_real_function: bool = False,
-):
+) -> tuple[ast.Module, ASTTransformerContext]:
     file = getsourcefile(self.func)
     src, start_lineno = getsourcelines(self.func)
     src = [textwrap.fill(line, tabsize=4, width=9999) for line in src]
     tree = ast.parse(textwrap.dedent("\n".join(src)))
 
     func_body = tree.body[0]
-    func_body.decorator_list = []
+    func_body.decorator_list = [] # type: ignore , kick that can down the road...
 
     global_vars = _get_global_vars(self.func)
 
@@ -196,17 +305,18 @@ def expand_func_arguments(arguments: list[KernelArgument]) -> list[KernelArgumen
     return new_arguments
 
 
-def _process_args(self: "Func | Kernel", is_func: bool, args, kwargs):
+def _process_args(self: "Func | Kernel", is_func: bool, args: tuple[Any, ...], kwargs) -> tuple[Any, ...]:
     if is_func:
         self.arguments = expand_func_arguments(self.arguments)
     fused_args = [argument.default for argument in self.arguments]
+    ret: list[Any] = [argument.default for argument in self.arguments]
     len_args = len(args)
 
     if len_args > len(fused_args):
         arg_str = ", ".join([str(arg) for arg in args])
         expected_str = ", ".join([f"{arg.name} : {arg.annotation}" for arg in self.arguments])
         msg = f"Too many arguments. Expected ({expected_str}), got ({arg_str})."
-        raise
+        raise GsTaichiSyntaxError(msg)
 
     for i, arg in enumerate(args):
         fused_args[i] = arg
@@ -216,19 +326,19 @@ def _process_args(self: "Func | Kernel", is_func: bool, args, kwargs):
         for i, arg in enumerate(self.arguments):
             if key == arg.name:
                 if i < len_args:
-                    raise
+                    raise GsTaichiSyntaxError(f"Multiple values for argument '{key}'.")
                 fused_args[i] = value
                 found = True
                 break
         if not found:
-            raise
+            raise GsTaichiSyntaxError(f"Unexpected argument '{key}'.")
 
     for i, arg in enumerate(fused_args):
         if arg is inspect.Parameter.empty:
             if self.arguments[i].annotation is inspect._empty:
-                raise
+                raise GsTaichiSyntaxError(f"Parameter `{self.arguments[i].name}` missing.")
             else:
-                raise
+                raise GsTaichiSyntaxError(
                     f"Parameter `{self.arguments[i].name} : {self.arguments[i].annotation}` missing."
                 )
 
@@ -237,7 +347,7 @@ def _process_args(self: "Func | Kernel", is_func: bool, args, kwargs):
 
 def unpack_ndarray_struct(tree: ast.Module, struct_locals: set[str]) -> ast.Module:
     class AttributeToNameTransformer(ast.NodeTransformer):
-        def visit_Attribute(self, node: ast.
+        def visit_Attribute(self, node: ast.Attribute):
             if isinstance(node.value, ast.Attribute):
                 return node
             if not isinstance(node.value, ast.Name):
@@ -278,7 +388,7 @@ def extract_struct_locals_from_context(ctx: ASTTransformerContext):
 class Func:
     function_counter = 0
 
-    def __init__(self, _func: Callable, _classfunc=False, _pyfunc=False, is_real_function=False):
+    def __init__(self, _func: Callable, _classfunc=False, _pyfunc=False, is_real_function=False) -> None:
         self.func = _func
         self.func_id = Func.function_counter
         Func.function_counter += 1
@@ -294,22 +404,22 @@ class Func:
         for i, arg in enumerate(self.arguments):
             if arg.annotation == template or isinstance(arg.annotation, template):
                 self.template_slot_locations.append(i)
-        self.mapper =
-        self.
+        self.mapper = GsTaichiCallableTemplateMapper(self.arguments, self.template_slot_locations)
+        self.gstaichi_functions = {} # The |Function| class in C++
         self.has_print = False
 
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args, **kwargs) -> Any:
         args = _process_args(self, is_func=True, args=args, kwargs=kwargs)
 
         if not impl.inside_kernel():
             if not self.pyfunc:
-                raise
+                raise GsTaichiSyntaxError("GsTaichi functions cannot be called from Python-scope.")
             return self.func(*args)
 
         current_kernel = impl.get_runtime().current_kernel
         if self.is_real_function:
             if current_kernel.autodiff_mode != AutodiffMode.NONE:
-                raise
+                raise GsTaichiSyntaxError("Real function in gradient kernels unsupported.")
         instance_id, arg_features = self.mapper.lookup(args)
         key = _ti_core.FunctionKey(self.func.__name__, self.func_id, instance_id)
         if key.instance_id not in self.compiled:
@@ -328,10 +438,10 @@ class Func:
         ret = transform_tree(tree, ctx)
         if not self.is_real_function:
             if self.return_type and ctx.returned != ReturnStatus.ReturnedValue:
-                raise
+                raise GsTaichiSyntaxError("Function has a return type but does not have a return statement")
         return ret
 
-    def func_call_rvalue(self, key, args):
+    def func_call_rvalue(self, key: FunctionKey, args: tuple[Any, ...]) -> Any:
         # Skip the template args, e.g., |self|
         assert self.is_real_function
         non_template_args = []
@@ -345,7 +455,7 @@ class Func:
                 non_template_args.append(_ti_core.make_reference(args[i].ptr, dbg_info))
             elif isinstance(anno, ndarray_type.NdarrayType):
                 if not isinstance(args[i], AnyArray):
-                    raise
+                    raise GsTaichiTypeError(
                         f"Expected ndarray in the kernel argument for argument {kernel_arg.name}, got {args[i]}"
                     )
                 non_template_args += _ti_core.get_external_tensor_real_func_args(args[i].ptr, dbg_info)
@@ -355,7 +465,7 @@ class Func:
         compiling_callable = impl.get_runtime().compiling_callable
         assert compiling_callable is not None
         func_call = compiling_callable.ast_builder().insert_func_call(
-            self.
+            self.gstaichi_functions[key.instance_id], non_template_args, dbg_info
         )
         if self.return_type is None:
             return None
@@ -372,14 +482,14 @@ class Func:
                     )
                 )
             elif isinstance(return_type, (StructType, MatrixType)):
-                ret.append(return_type.
+                ret.append(return_type.from_gstaichi_object(func_call, (i,)))
             else:
-                raise
+                raise GsTaichiTypeError(f"Unsupported return type for return value {i}: {return_type}")
         if len(ret) == 1:
             return ret[0]
         return tuple(ret)
 
-    def do_compile(self, key, args, arg_features):
+    def do_compile(self, key: FunctionKey, args: tuple[Any, ...], arg_features: tuple[Any, ...]) -> None:
         tree, ctx = _get_tree_and_ctx(
             self, is_kernel=False, args=args, arg_features=arg_features, is_real_function=self.is_real_function
         )
@@ -392,36 +502,42 @@ class Func:
             transform_tree(tree, ctx)
             impl.get_runtime().compiling_callable = old_callable
 
-        self.
+        self.gstaichi_functions[key.instance_id] = fn
         self.compiled[key.instance_id] = func_body
-        self.
+        self.gstaichi_functions[key.instance_id].set_function_body(func_body)
 
     def extract_arguments(self) -> None:
         sig = inspect.signature(self.func)
        if sig.return_annotation not in (inspect.Signature.empty, None):
             self.return_type = sig.return_annotation
             if (
-                isinstance(self.return_type, (types.GenericAlias, typing._GenericAlias))
-                and self.return_type.__origin__ is tuple
+                isinstance(self.return_type, (types.GenericAlias, typing._GenericAlias)) # type: ignore
+                and self.return_type.__origin__ is tuple # type: ignore
             ):
-                self.return_type = self.return_type.__args__
+                self.return_type = self.return_type.__args__ # type: ignore
+        if self.return_type is None:
+            return
         if not isinstance(self.return_type, (list, tuple)):
             self.return_type = (self.return_type,)
         for i, return_type in enumerate(self.return_type):
             if return_type is Ellipsis:
-                raise
+                raise GsTaichiSyntaxError("Ellipsis is not supported in return type annotations")
         params = sig.parameters
         arg_names = params.keys()
         for i, arg_name in enumerate(arg_names):
             param = params[arg_name]
             if param.kind == inspect.Parameter.VAR_KEYWORD:
-                raise
+                raise GsTaichiSyntaxError(
+                    "GsTaichi functions do not support variable keyword parameters (i.e., **kwargs)"
+                )
             if param.kind == inspect.Parameter.VAR_POSITIONAL:
-                raise
+                raise GsTaichiSyntaxError(
+                    "GsTaichi functions do not support variable positional parameters (i.e., *args)"
+                )
             if param.kind == inspect.Parameter.KEYWORD_ONLY:
-                raise
+                raise GsTaichiSyntaxError("GsTaichi functions do not support keyword parameters")
             if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
-                raise
+                raise GsTaichiSyntaxError('GsTaichi functions only support "positional or keyword" parameters')
             annotation = param.annotation
             if annotation is inspect.Parameter.empty:
                 if i == 0 and self.classfunc:
@@ -429,8 +545,8 @@ class Func:
                 # TODO: pyfunc also need type annotation check when real function is enabled,
                 # but that has to happen at runtime when we know which scope it's called from.
                 elif not self.pyfunc and self.is_real_function:
-                    raise
-                        f"
+                    raise GsTaichiSyntaxError(
+                        f"GsTaichi function `{self.func.__name__}` parameter `{arg_name}` must be type annotated"
                     )
             else:
                 if isinstance(annotation, ndarray_type.NdarrayType):
@@ -441,198 +557,24 @@ class Func:
                     pass
                 elif id(annotation) in primitive_types.type_ids:
                     pass
-                elif type(annotation) ==
+                elif type(annotation) == gstaichi.types.annotations.Template:
                     pass
-                elif isinstance(annotation, template) or annotation ==
+                elif isinstance(annotation, template) or annotation == gstaichi.types.annotations.Template:
                     pass
                 elif isinstance(annotation, primitive_types.RefType):
                     pass
                 elif isinstance(annotation, type) and dataclasses.is_dataclass(annotation):
                     pass
                 else:
-                    raise
+                    raise GsTaichiSyntaxError(
+                        f"Invalid type annotation (argument {i}) of GsTaichi function: {annotation}"
+                    )
             self.arguments.append(KernelArgument(annotation, param.name, param.default))
             self.orig_arguments.append(KernelArgument(annotation, param.name, param.default))
 
 
-
-
-    ArgPackType,
-    "texture_type.TextureType",
-    "texture_type.RWTextureType",
-    ndarray_type.NdarrayType,
-    sparse_matrix_builder,
-    Any,
-]
-
-
-class TaichiCallableTemplateMapper:
-    """
-    This should probably be renamed to sometihng like FeatureMapper, or
-    FeatureExtractor, since:
-    - it's not specific to templates
-    - it extracts what are later called 'features', for example for ndarray this includes:
-      - element type
-      - number dimensions
-      - needs grad (or not)
-    - these are returned as a heterogeneous tuple, whose contents depends on the type
-    """
-
-    def __init__(self, arguments: list[KernelArgument], template_slot_locations: list[int]) -> None:
-        self.arguments = arguments
-        self.num_args = len(arguments)
-        self.template_slot_locations = template_slot_locations
-        self.mapping = {}
-
-    @staticmethod
-    def extract_arg(arg, annotation: AnnotationType, arg_name: str):
-        if annotation == template or isinstance(annotation, template):
-            if isinstance(arg, taichi.lang.snode.SNode):
-                return arg.ptr
-            if isinstance(arg, taichi.lang.expr.Expr):
-                return arg.ptr.get_underlying_ptr_address()
-            if isinstance(arg, _ti_core.Expr):
-                return arg.get_underlying_ptr_address()
-            if isinstance(arg, tuple):
-                return tuple(TaichiCallableTemplateMapper.extract_arg(item, annotation, arg_name) for item in arg)
-            if isinstance(arg, taichi.lang._ndarray.Ndarray):
-                raise TaichiRuntimeTypeError(
-                    "Ndarray shouldn't be passed in via `ti.template()`, please annotate your kernel using `ti.types.ndarray(...)` instead"
-                )
-
-            if isinstance(arg, (list, tuple, dict, set)) or hasattr(arg, "_data_oriented"):
-                # [Composite arguments] Return weak reference to the object
-                # Taichi kernel will cache the extracted arguments, thus we can't simply return the original argument.
-                # Instead, a weak reference to the original value is returned to avoid memory leak.
-
-                # TODO(zhanlue): replacing "tuple(args)" with "hash of argument values"
-                # This can resolve the following issues:
-                # 1. Invalid weak-ref will leave a dead(dangling) entry in both caches: "self.mapping" and "self.compiled_functions"
-                # 2. Different argument instances with same type and same value, will get templatized into seperate kernels.
-                return weakref.ref(arg)
-
-            # [Primitive arguments] Return the value
-            return arg
-        if isinstance(annotation, ArgPackType):
-            if not isinstance(arg, ArgPack):
-                raise TaichiRuntimeTypeError(f"Argument {arg_name} must be a argument pack, got {type(arg)}")
-            return tuple(
-                TaichiCallableTemplateMapper.extract_arg(arg[name], dtype, arg_name)
-                for index, (name, dtype) in enumerate(annotation.members.items())
-            )
-        if dataclasses.is_dataclass(annotation):
-            _res_l = []
-            for field in dataclasses.fields(annotation):
-                field_value = getattr(arg, field.name)
-                arg_name = f"__ti_{arg_name}_{field.name}"
-                field_extracted = TaichiCallableTemplateMapper.extract_arg(field_value, field.type, arg_name)
-                _res_l.append(field_extracted)
-            return tuple(_res_l)
-        if isinstance(annotation, texture_type.TextureType):
-            if not isinstance(arg, taichi.lang._texture.Texture):
-                raise TaichiRuntimeTypeError(f"Argument {arg_name} must be a texture, got {type(arg)}")
-            if arg.num_dims != annotation.num_dimensions:
-                raise TaichiRuntimeTypeError(
-                    f"TextureType dimension mismatch for argument {arg_name}: expected {annotation.num_dimensions}, got {arg.num_dims}"
-                )
-            return (arg.num_dims,)
-        if isinstance(annotation, texture_type.RWTextureType):
-            if not isinstance(arg, taichi.lang._texture.Texture):
-                raise TaichiRuntimeTypeError(f"Argument {arg_name} must be a texture, got {type(arg)}")
-            if arg.num_dims != annotation.num_dimensions:
-                raise TaichiRuntimeTypeError(
-                    f"RWTextureType dimension mismatch for argument {arg_name}: expected {annotation.num_dimensions}, got {arg.num_dims}"
-                )
-            if arg.fmt != annotation.fmt:
-                raise TaichiRuntimeTypeError(
-                    f"RWTextureType format mismatch for argument {arg_name}: expected {annotation.fmt}, got {arg.fmt}"
-                )
-            # (penguinliong) '0' is the assumed LOD level. We currently don't
-            # support mip-mapping.
-            return arg.num_dims, arg.fmt, 0
-        if isinstance(annotation, ndarray_type.NdarrayType):
-            if isinstance(arg, taichi.lang._ndarray.Ndarray):
-                annotation.check_matched(arg.get_type(), arg_name)
-                needs_grad = (arg.grad is not None) if annotation.needs_grad is None else annotation.needs_grad
-                assert arg.shape is not None
-                return arg.element_type, len(arg.shape), needs_grad, annotation.boundary
-            if isinstance(arg, AnyArray):
-                ty = arg.get_type()
-                annotation.check_matched(arg.get_type(), arg_name)
-                return ty.element_type, len(arg.shape), ty.needs_grad, annotation.boundary
-            # external arrays
-            shape = getattr(arg, "shape", None)
-            if shape is None:
-                raise TaichiRuntimeTypeError(f"Invalid type for argument {arg_name}, got {arg}")
-            shape = tuple(shape)
-            element_shape: tuple[int, ...] = ()
-            dtype = to_taichi_type(arg.dtype)
-            if isinstance(annotation.dtype, MatrixType):
-                if annotation.ndim is not None:
-                    if len(shape) != annotation.dtype.ndim + annotation.ndim:
-                        raise ValueError(
-                            f"Invalid value for argument {arg_name} - required array has ndim={annotation.ndim} element_dim={annotation.dtype.ndim}, "
-                            f"array with {len(shape)} dimensions is provided"
-                        )
-                else:
-                    if len(shape) < annotation.dtype.ndim:
-                        raise ValueError(
-                            f"Invalid value for argument {arg_name} - required element_dim={annotation.dtype.ndim}, "
-                            f"array with {len(shape)} dimensions is provided"
-                        )
-                element_shape = shape[-annotation.dtype.ndim :]
-                anno_element_shape = annotation.dtype.get_shape()
-                if None not in anno_element_shape and element_shape != anno_element_shape:
-                    raise ValueError(
-                        f"Invalid value for argument {arg_name} - required element_shape={anno_element_shape}, "
-                        f"array with element shape of {element_shape} is provided"
-                    )
-            elif annotation.dtype is not None:
-                # User specified scalar dtype
-                if annotation.dtype != dtype:
-                    raise ValueError(
-                        f"Invalid value for argument {arg_name} - required array has dtype={annotation.dtype.to_string()}, "
-                        f"array with dtype={dtype.to_string()} is provided"
-                    )
-
-            if annotation.ndim is not None and len(shape) != annotation.ndim:
-                raise ValueError(
-                    f"Invalid value for argument {arg_name} - required array has ndim={annotation.ndim}, "
-                    f"array with {len(shape)} dimensions is provided"
-                )
-            needs_grad = (
-                getattr(arg, "requires_grad", False) if annotation.needs_grad is None else annotation.needs_grad
-            )
-            element_type = (
-                _ti_core.get_type_factory_instance().get_tensor_type(element_shape, dtype)
-                if len(element_shape) != 0
-                else arg.dtype
-            )
-            return element_type, len(shape) - len(element_shape), needs_grad, annotation.boundary
-        if isinstance(annotation, sparse_matrix_builder):
-            return arg.dtype
-        # Use '#' as a placeholder because other kinds of arguments are not involved in template instantiation
-        return "#"
-
-    def extract(self, args):
-        extracted = []
-        for arg, kernel_arg in zip(args, self.arguments):
-            extracted.append(self.extract_arg(arg, kernel_arg.annotation, kernel_arg.name))
-        return tuple(extracted)
-
-    def lookup(self, args):
-        if len(args) != self.num_args:
-            raise TypeError(f"{self.num_args} argument(s) needed but {len(args)} provided.")
-
-        key = self.extract(args)
-        if key not in self.mapping:
-            count = len(self.mapping)
-            self.mapping[key] = count
-        return self.mapping[key], key
-
-
-def _get_global_vars(_func):
-    # Discussions: https://github.com/taichi-dev/taichi/issues/282
+def _get_global_vars(_func: Callable) -> dict[str, Any]:
+    # Discussions: https://github.com/taichi-dev/gstaichi/issues/282
     global_vars = _func.__globals__.copy()
 
     freevar_names = _func.__code__.co_freevars
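The removed TaichiCallableTemplateMapper docstring above describes reducing each argument to a tuple of "features" (element type, number of dimensions, needs-grad, ...) and using that tuple to decide whether an argument set maps onto an existing kernel instantiation. A rough, dependency-light sketch of that caching idea (illustrative only; `FeatureMapper` and `extract_features` are hypothetical names, and the feature extraction here covers only numpy ndarrays plus a placeholder case):

```python
import numpy as np


def extract_features(arg):
    # Hypothetical reduction: only properties that affect codegen go into the key.
    if isinstance(arg, np.ndarray):
        return (str(arg.dtype), arg.ndim)
    # Anything else: a placeholder, mirroring the "#" sentinel in the removed code.
    return "#"


class FeatureMapper:
    def __init__(self):
        self.mapping = {}  # feature key -> instantiation id

    def lookup(self, args):
        key = tuple(extract_features(a) for a in args)
        if key not in self.mapping:
            self.mapping[key] = len(self.mapping)  # first time seen: new instantiation
        return self.mapping[key], key


mapper = FeatureMapper()
i0, _ = mapper.lookup((np.zeros((4, 4), np.float32), 1))
i1, _ = mapper.lookup((np.ones((8, 8), np.float32), 2))   # same features -> same id
i2, _ = mapper.lookup((np.zeros((4, 4), np.float64), 1))  # dtype differs -> new id
assert i0 == i1 and i0 != i2
```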
@@ -648,7 +590,7 @@ def _get_global_vars(_func):
|
|
648
590
|
class Kernel:
|
649
591
|
counter = 0
|
650
592
|
|
651
|
-
def __init__(self, _func: Callable, autodiff_mode, _classkernel=False):
|
593
|
+
def __init__(self, _func: Callable, autodiff_mode: AutodiffMode, _classkernel=False) -> None:
|
652
594
|
self.func = _func
|
653
595
|
self.kernel_counter = Kernel.counter
|
654
596
|
Kernel.counter += 1
|
@@ -668,27 +610,27 @@ class Kernel:
|
|
668
610
|
for i, arg in enumerate(self.arguments):
|
669
611
|
if arg.annotation == template or isinstance(arg.annotation, template):
|
670
612
|
self.template_slot_locations.append(i)
|
671
|
-
self.mapper =
|
613
|
+
self.mapper = GsTaichiCallableTemplateMapper(self.arguments, self.template_slot_locations)
|
672
614
|
impl.get_runtime().kernels.append(self)
|
673
615
|
self.reset()
|
674
616
|
self.kernel_cpp = None
|
675
|
-
self.compiled_kernels = {}
|
617
|
+
self.compiled_kernels: dict[CompiledKernelKeyType, KernelCxx] = {}
|
676
618
|
self.has_print = False
|
677
619
|
|
678
620
|
def ast_builder(self) -> ASTBuilder:
|
679
621
|
assert self.kernel_cpp is not None
|
680
622
|
return self.kernel_cpp.ast_builder()
|
681
623
|
|
682
|
-
def reset(self):
|
624
|
+
def reset(self) -> None:
|
683
625
|
self.runtime = impl.get_runtime()
|
684
626
|
self.compiled_kernels = {}
|
685
627
|
|
686
|
-
def extract_arguments(self):
|
628
|
+
def extract_arguments(self) -> None:
|
687
629
|
sig = inspect.signature(self.func)
|
688
630
|
if sig.return_annotation not in (inspect._empty, None):
|
689
631
|
self.return_type = sig.return_annotation
|
690
632
|
if (
|
691
|
-
isinstance(self.return_type, (types.GenericAlias, typing._GenericAlias))
|
633
|
+
isinstance(self.return_type, (types.GenericAlias, typing._GenericAlias)) # type: ignore
|
692
634
|
and self.return_type.__origin__ is tuple
|
693
635
|
):
|
694
636
|
self.return_type = self.return_type.__args__
|
@@ -696,27 +638,31 @@ class Kernel:
|
|
696
638
|
self.return_type = (self.return_type,)
|
697
639
|
for return_type in self.return_type:
|
698
640
|
if return_type is Ellipsis:
|
699
|
-
raise
|
641
|
+
raise GsTaichiSyntaxError("Ellipsis is not supported in return type annotations")
|
700
642
|
params = sig.parameters
|
701
643
|
arg_names = params.keys()
|
702
644
|
for i, arg_name in enumerate(arg_names):
|
703
645
|
param = params[arg_name]
|
704
646
|
if param.kind == inspect.Parameter.VAR_KEYWORD:
|
705
|
-
raise
|
647
|
+
raise GsTaichiSyntaxError(
|
648
|
+
"GsTaichi kernels do not support variable keyword parameters (i.e., **kwargs)"
|
649
|
+
)
|
706
650
|
if param.kind == inspect.Parameter.VAR_POSITIONAL:
|
707
|
-
raise
|
651
|
+
raise GsTaichiSyntaxError(
|
652
|
+
"GsTaichi kernels do not support variable positional parameters (i.e., *args)"
|
653
|
+
)
|
708
654
|
if param.default is not inspect.Parameter.empty:
|
709
|
-
raise
|
655
|
+
raise GsTaichiSyntaxError("GsTaichi kernels do not support default values for arguments")
|
710
656
|
if param.kind == inspect.Parameter.KEYWORD_ONLY:
|
711
|
-
raise
|
657
|
+
raise GsTaichiSyntaxError("GsTaichi kernels do not support keyword parameters")
|
712
658
|
if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
|
713
|
-
raise
|
659
|
+
raise GsTaichiSyntaxError('GsTaichi kernels only support "positional or keyword" parameters')
|
714
660
|
annotation = param.annotation
|
715
661
|
if param.annotation is inspect.Parameter.empty:
|
716
662
|
if i == 0 and self.classkernel: # The |self| parameter
|
717
663
|
annotation = template()
|
718
664
|
else:
|
719
|
-
raise
|
665
|
+
raise GsTaichiSyntaxError("GsTaichi kernels parameters must be type annotated")
|
720
666
|
else:
|
721
667
|
if isinstance(
|
722
668
|
annotation,
|
@@ -743,10 +689,12 @@ class Kernel:
|
|
743
689
|
elif isinstance(annotation, type) and dataclasses.is_dataclass(annotation):
|
744
690
|
pass
|
745
691
|
else:
|
746
|
-
raise
|
692
|
+
raise GsTaichiSyntaxError(
|
693
|
+
f"Invalid type annotation (argument {i}) of GsTaichi kernel: {annotation}"
|
694
|
+
)
|
747
695
|
self.arguments.append(KernelArgument(annotation, param.name, param.default))
|
748
696
|
|
749
|
-
def materialize(self, key, args:
|
697
|
+
def materialize(self, key: CompiledKernelKeyType | None, args: tuple[Any, ...], arg_features):
|
750
698
|
if key is None:
|
751
699
|
key = (self.func, 0, self.autodiff_mode)
|
752
700
|
self.runtime.materialize()
|
@@ -767,15 +715,15 @@ class Kernel:
|
|
767
715
|
if self.autodiff_mode != AutodiffMode.NONE:
|
768
716
|
KernelSimplicityASTChecker(self.func).visit(tree)
|
769
717
|
|
770
|
-
# Do not change the name of '
|
718
|
+
# Do not change the name of 'gstaichi_ast_generator'
|
771
719
|
# The warning system needs this identifier to remove unnecessary messages
|
772
|
-
def
|
720
|
+
def gstaichi_ast_generator(kernel_cxx: Kernel): # not sure if this type is correct, seems doubtful
|
773
721
|
nonlocal tree
|
774
722
|
if self.runtime.inside_kernel:
|
775
|
-
raise
|
723
|
+
raise GsTaichiSyntaxError(
|
776
724
|
"Kernels cannot call other kernels. I.e., nested kernels are not allowed. "
|
777
725
|
"Please check if you have direct/indirect invocation of kernels within kernels. "
|
778
|
-
"Note that some methods provided by the
|
726
|
+
"Note that some methods provided by the GsTaichi standard library may invoke kernels, "
|
779
727
|
"and please move their invocations to Python-scope."
|
780
728
|
)
|
781
729
|
self.kernel_cpp = kernel_cxx
|
@@ -786,7 +734,7 @@ class Kernel:
|
|
786
734
|
try:
|
787
735
|
ctx.ast_builder = kernel_cxx.ast_builder()
|
788
736
|
|
789
|
-
def ast_to_dict(node):
|
737
|
+
def ast_to_dict(node: ast.AST | list | primitive_types._python_primitive_types):
|
790
738
|
if isinstance(node, ast.AST):
|
791
739
|
fields = {k: ast_to_dict(v) for k, v in ast.iter_fields(node)}
|
792
740
|
return {
|
@@ -824,17 +772,17 @@ class Kernel:
|
|
824
772
|
transform_tree(tree, ctx)
|
825
773
|
if not ctx.is_real_function:
|
826
774
|
if self.return_type and ctx.returned != ReturnStatus.ReturnedValue:
|
827
|
-
raise
|
775
|
+
raise GsTaichiSyntaxError("Kernel has a return type but does not have a return statement")
|
828
776
|
finally:
|
829
777
|
self.runtime.inside_kernel = False
|
830
778
|
self.runtime._current_kernel = None
|
831
779
|
self.runtime.compiling_callable = None
|
832
780
|
|
833
|
-
|
781
|
+
gstaichi_kernel = impl.get_runtime().prog.create_kernel(gstaichi_ast_generator, kernel_name, self.autodiff_mode)
|
834
782
|
assert key not in self.compiled_kernels
|
835
|
-
self.compiled_kernels[key] =
|
783
|
+
self.compiled_kernels[key] = gstaichi_kernel
|
836
784
|
|
837
|
-
def launch_kernel(self, t_kernel, *args):
|
785
|
+
def launch_kernel(self, t_kernel: KernelCxx, *args) -> Any:
|
838
786
|
assert len(args) == len(self.arguments), f"{len(self.arguments)} arguments needed but {len(args)} provided"
|
839
787
|
|
840
788
|
tmps = []
|
@@ -842,25 +790,28 @@ class Kernel:
|
|
842
790
|
|
843
791
|
actual_argument_slot = 0
|
844
792
|
launch_ctx = t_kernel.make_launch_context()
|
845
|
-
max_arg_num =
|
793
|
+
max_arg_num = 512
|
846
794
|
exceed_max_arg_num = False
|
847
795
|
|
848
|
-
def set_arg_ndarray(indices, v):
|
796
|
+
def set_arg_ndarray(indices: tuple[int, ...], v: gstaichi.lang._ndarray.Ndarray) -> None:
|
849
797
|
v_primal = v.arr
|
850
798
|
v_grad = v.grad.arr if v.grad else None
|
851
799
|
if v_grad is None:
|
852
|
-
launch_ctx.set_arg_ndarray(indices, v_primal)
|
800
|
+
launch_ctx.set_arg_ndarray(indices, v_primal) # type: ignore , solvable probably, just not today
|
853
801
|
else:
|
854
|
-
launch_ctx.set_arg_ndarray_with_grad(indices, v_primal, v_grad)
|
802
|
+
launch_ctx.set_arg_ndarray_with_grad(indices, v_primal, v_grad) # type: ignore
|
855
803
|
|
856
|
-
def set_arg_texture(indices, v):
|
804
|
+
def set_arg_texture(indices: tuple[int, ...], v: gstaichi.lang._texture.Texture) -> None:
|
857
805
|
launch_ctx.set_arg_texture(indices, v.tex)
|
858
806
|
|
859
|
-
def set_arg_rw_texture(indices, v):
|
807
|
+
def set_arg_rw_texture(indices: tuple[int, ...], v: gstaichi.lang._texture.Texture) -> None:
|
860
808
|
launch_ctx.set_arg_rw_texture(indices, v.tex)
|
861
809
|
|
862
|
-
def set_arg_ext_array(indices, v, needed):
|
863
|
-
#
|
810
|
+
def set_arg_ext_array(indices: tuple[int, ...], v: Any, needed: ndarray_type.NdarrayType) -> None:
|
811
|
+
# v is things like torch Tensor and numpy array
|
812
|
+
# Not adding type for this, since adds additional dependencies
|
813
|
+
#
|
814
|
+
# Element shapes are already specialized in GsTaichi codegen.
|
864
815
|
# The shape information for element dims are no longer needed.
|
865
816
|
# Therefore we strip the element shapes from the shape vector,
|
866
817
|
# so that it only holds "real" array shapes.
|
```diff
@@ -893,7 +844,7 @@ class Kernel:
                 else:
                     raise ValueError(
                         "Non contiguous numpy arrays are not supported, please call np.ascontiguousarray(arr) "
-                        "before passing it into
+                        "before passing it into gstaichi kernel."
                     )
             elif has_pytorch():
                 import torch  # pylint: disable=C0415
```
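The message above now names the gstaichi kernel explicitly. For context, a hedged sketch of what it guards against: strided NumPy views are not C-contiguous and must be copied first (the `ti.types.ndarray()` annotation is assumed to match Taichi's external-array API):

```python
import numpy as np
import gstaichi as ti  # assumed import alias

ti.init(arch=ti.cpu)

@ti.kernel
def double(a: ti.types.ndarray(dtype=ti.f32, ndim=1)):
    for i in range(a.shape[0]):
        a[i] *= 2.0

base = np.arange(10, dtype=np.float32)
view = base[::2]                    # non-contiguous view: would raise the ValueError above
double(np.ascontiguousarray(view))  # copying first makes it C-contiguous, as the message suggests
```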
```diff
@@ -902,9 +853,9 @@ class Kernel:
                 if not v.is_contiguous():
                     raise ValueError(
                         "Non contiguous tensors are not supported, please call tensor.contiguous() before "
-                        "passing it into
+                        "passing it into gstaichi kernel."
                     )
-
+                gstaichi_arch = self.runtime.prog.config().arch

                 def get_call_back(u, v):
                     def call_back():
```
```diff
@@ -923,14 +874,14 @@ class Kernel:
                         )
                     if not v.grad.is_contiguous():
                         raise ValueError(
-                            "Non contiguous gradient tensors are not supported, please call tensor.grad.contiguous() before passing it into
+                            "Non contiguous gradient tensors are not supported, please call tensor.grad.contiguous() before passing it into gstaichi kernel."
                         )

                 tmp = v
                 if (str(v.device) != "cpu") and not (
-                    str(v.device).startswith("cuda") and
+                    str(v.device).startswith("cuda") and gstaichi_arch == _ti_core.Arch.cuda
                 ):
-                    # Getting a torch CUDA tensor on
+                    # Getting a torch CUDA tensor on GsTaichi non-cuda arch:
                     # We just replace it with a CPU tensor and by the end of kernel execution we'll use the
                     # callback to copy the values back to the original CUDA tensor.
                     host_v = v.to(device="cpu", copy=True)
```
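The comments above describe the copy-back strategy for torch CUDA tensors when the kernel runs on a non-CUDA backend. A standalone sketch of that pattern, independent of gstaichi internals (`run_on_cpu` is a hypothetical stand-in for the kernel launch):

```python
import torch

def launch_with_copy_back(v: torch.Tensor, run_on_cpu) -> None:
    callbacks = []
    tmp = v
    if v.is_cuda:
        # The kernel will run on a CPU backend, so hand it a host copy ...
        host_v = v.to(device="cpu", copy=True)
        tmp = host_v
        # ... and remember to copy the results back into the original CUDA tensor.
        callbacks.append(lambda: v.copy_(host_v))
    run_on_cpu(tmp)
    for callback in callbacks:
        callback()
```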
```diff
@@ -945,8 +896,12 @@ class Kernel:
                     int(v.grad.data_ptr()) if v.grad is not None else 0,
                 )
             else:
-                raise
+                raise GsTaichiRuntimeTypeError(
+                    f"Argument {needed} cannot be converted into required type {type(v)}"
+                )
         elif has_paddle():
+            # Do we want to continue to support paddle? :thinking_face:
+            # #maybeprunable
             import paddle  # pylint: disable=C0415 # type: ignore

             if isinstance(v, paddle.Tensor):
```
```diff
@@ -958,41 +913,41 @@ class Kernel:
                         return call_back

                 tmp = v.value().get_tensor()
-
+                gstaichi_arch = self.runtime.prog.config().arch
                 if v.place.is_gpu_place():
-                    if
-                        # Paddle cuda tensor on
+                    if gstaichi_arch != _ti_core.Arch.cuda:
+                        # Paddle cuda tensor on GsTaichi non-cuda arch
                         host_v = v.cpu()
                         tmp = host_v.value().get_tensor()
                         callbacks.append(get_call_back(v, host_v))
                 elif v.place.is_cpu_place():
-                    if
-                        # Paddle cpu tensor on
+                    if gstaichi_arch == _ti_core.Arch.cuda:
+                        # Paddle cpu tensor on GsTaichi cuda arch
                         gpu_v = v.cuda()
                         tmp = gpu_v.value().get_tensor()
                         callbacks.append(get_call_back(v, gpu_v))
                 else:
                     # Paddle do support many other backends like XPU, NPU, MLU, IPU
-                    raise
+                    raise GsTaichiRuntimeTypeError(f"GsTaichi do not support backend {v.place} that Paddle support")
                 launch_ctx.set_arg_external_array_with_shape(
                     indices, int(tmp._ptr()), v.element_size() * v.size, array_shape, 0
                 )
             else:
-                raise
+                raise GsTaichiRuntimeTypeError(f"Argument {needed} cannot be converted into required type {v}")
         else:
-            raise
+            raise GsTaichiRuntimeTypeError(f"Argument {needed} cannot be converted into required type {v}")

-        def set_arg_matrix(indices, v, needed):
-            def cast_float(x):
+        def set_arg_matrix(indices: tuple[int, ...], v, needed) -> None:
+            def cast_float(x: float | np.floating | np.integer | int) -> float:
                 if not isinstance(x, (int, float, np.integer, np.floating)):
-                    raise
+                    raise GsTaichiRuntimeTypeError(
                         f"Argument {needed.dtype} cannot be converted into required type {type(x)}"
                     )
                 return float(x)

-            def cast_int(x):
+            def cast_int(x: int | np.integer) -> int:
                 if not isinstance(x, (int, np.integer)):
-                    raise
+                    raise GsTaichiRuntimeTypeError(
                         f"Argument {needed.dtype} cannot be converted into required type {type(x)}"
                     )
                 return int(x)
```
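The newly typed `cast_float`/`cast_int` helpers accept Python and NumPy scalars and reject everything else. A standalone sketch of the same validation (a plain `TypeError` stands in for the internal `GsTaichiRuntimeTypeError`):

```python
import numpy as np

def cast_float(x) -> float:
    if not isinstance(x, (int, float, np.integer, np.floating)):
        raise TypeError(f"{type(x)} cannot be converted into a float matrix element")
    return float(x)

def cast_int(x) -> int:
    if not isinstance(x, (int, np.integer)):
        raise TypeError(f"{type(x)} cannot be converted into an int matrix element")
    return int(x)

cast_float(np.float32(1.5))  # -> 1.5
cast_int(np.int64(7))        # -> 7
# cast_int(2.5)              # would raise TypeError: floats are not accepted as ints
```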
```diff
@@ -1012,13 +967,13 @@ class Kernel:
                 v = needed(*v)
             needed.set_kernel_struct_args(v, launch_ctx, indices)

-        def set_arg_sparse_matrix_builder(indices, v):
+        def set_arg_sparse_matrix_builder(indices: tuple[int, ...], v) -> None:
             # Pass only the base pointer of the ti.types.sparse_matrix_builder() argument
             launch_ctx.set_arg_uint(indices, v._get_ndarray_addr())

         set_later_list = []

-        def recursive_set_args(needed_arg_type, provided_arg_type, v, indices):
+        def recursive_set_args(needed_arg_type: Type, provided_arg_type: Type, v: Any, indices: tuple[int, ...]) -> int:
             """
             Returns the number of kernel args set
             e.g. templates don't set kernel args, so returns 0
```
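The docstring above states that `recursive_set_args` returns how many flat kernel-argument slots it consumed (templates consume none), which is what lets nested structures advance the argument index correctly. A simplified, self-contained sketch of that counting contract, using plain Python types as stand-ins for the real annotation objects:

```python
from dataclasses import dataclass, fields

@dataclass
class Particle:  # stand-in for a dataclass-style kernel argument
    x: float
    y: float
    mass: float

def count_slots(annotation, value) -> int:
    """Return how many flat argument slots this (annotation, value) pair occupies."""
    if annotation is None:          # stand-in for a template argument: no slot consumed
        return 0
    if annotation in (int, float):  # scalar: exactly one slot
        return 1
    if hasattr(annotation, "__dataclass_fields__"):
        # Dataclass arguments are flattened field by field, so the counts add up.
        return sum(count_slots(f.type, getattr(value, f.name)) for f in fields(annotation))
    raise TypeError(f"unsupported annotation {annotation!r}")

print(count_slots(Particle, Particle(0.0, 1.0, 2.0)))  # 3: one slot per scalar field
print(count_slots(None, object()))                     # 0: template-style args set no slots
```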
```diff
@@ -1033,7 +988,7 @@ class Kernel:
             actual_argument_slot += 1
             if isinstance(needed_arg_type, ArgPackType):
                 if not isinstance(v, ArgPack):
-                    raise
+                    raise GsTaichiRuntimeTypeError.get(indices, str(needed_arg_type), str(provided_arg_type))
                 idx_new = 0
                 for j, (name, anno) in enumerate(needed_arg_type.members.items()):
                     idx_new += recursive_set_args(anno, type(v[name]), v[name], indices + (idx_new,))
```
```diff
@@ -1042,14 +997,14 @@ class Kernel:
             # Note: do not use sth like "needed == f32". That would be slow.
             if id(needed_arg_type) in primitive_types.real_type_ids:
                 if not isinstance(v, (float, int, np.floating, np.integer)):
-                    raise
+                    raise GsTaichiRuntimeTypeError.get(indices, needed_arg_type.to_string(), provided_arg_type)
                 if in_argpack:
                     return 1
                 launch_ctx.set_arg_float(indices, float(v))
                 return 1
             if id(needed_arg_type) in primitive_types.integer_type_ids:
                 if not isinstance(v, (int, np.integer)):
-                    raise
+                    raise GsTaichiRuntimeTypeError.get(indices, needed_arg_type.to_string(), provided_arg_type)
                 if in_argpack:
                     return 1
                 if is_signed(cook_dtype(needed_arg_type)):
```
```diff
@@ -1071,19 +1026,21 @@ class Kernel:
                     field_value = getattr(v, field.name)
                     idx += recursive_set_args(field.type, field.type, field_value, (indices[0] + idx,))
                 return idx
-            if isinstance(needed_arg_type, ndarray_type.NdarrayType) and isinstance(v,
+            if isinstance(needed_arg_type, ndarray_type.NdarrayType) and isinstance(v, gstaichi.lang._ndarray.Ndarray):
                 if in_argpack:
                     set_later_list.append((set_arg_ndarray, (v,)))
                     return 0
                 set_arg_ndarray(indices, v)
                 return 1
-            if isinstance(needed_arg_type, texture_type.TextureType) and isinstance(v,
+            if isinstance(needed_arg_type, texture_type.TextureType) and isinstance(v, gstaichi.lang._texture.Texture):
                 if in_argpack:
                     set_later_list.append((set_arg_texture, (v,)))
                     return 0
                 set_arg_texture(indices, v)
                 return 1
-            if isinstance(needed_arg_type, texture_type.RWTextureType) and isinstance(
+            if isinstance(needed_arg_type, texture_type.RWTextureType) and isinstance(
+                v, gstaichi.lang._texture.Texture
+            ):
                 if in_argpack:
                     set_later_list.append((set_arg_rw_texture, (v,)))
                     return 0
```
```diff
@@ -1103,8 +1060,12 @@ class Kernel:
             if isinstance(needed_arg_type, StructType):
                 if in_argpack:
                     return 1
-
-
+                # Unclear how to make the following pass typing checks
+                # StructType implements __instancecheck__, which should be a classmethod, but
+                # is currently an instance method
+                # TODO: look into this more deeply at some point
+                if not isinstance(v, needed_arg_type):  # type: ignore
+                    raise GsTaichiRuntimeTypeError(
                         f"Argument {provided_arg_type} cannot be converted into required type {needed_arg_type}"
                     )
                 needed_arg_type.set_kernel_struct_args(v, launch_ctx, indices)
```
```diff
@@ -1127,7 +1088,7 @@ class Kernel:
             set_arg_func((len(args) - template_num + i,), *params)

         if exceed_max_arg_num:
-            raise
+            raise GsTaichiRuntimeError(
                 f"The number of elements in kernel arguments is too big! Do not exceed {max_arg_num} on {_ti_core.arch_name(impl.current_cfg().arch)} backend."
             )

```
```diff
@@ -1162,7 +1123,7 @@ class Kernel:

         return ret

-    def construct_kernel_ret(self, launch_ctx, ret_type, index=()):
+    def construct_kernel_ret(self, launch_ctx: KernelLaunchContext, ret_type: Any, index: tuple[int, ...] = ()):
         if isinstance(ret_type, CompoundType):
             return ret_type.from_kernel_struct_ret(launch_ctx, index)
         if ret_type in primitive_types.integer_types:
```
```diff
@@ -1171,9 +1132,9 @@ class Kernel:
             return launch_ctx.get_struct_ret_uint(index)
         if ret_type in primitive_types.real_types:
             return launch_ctx.get_struct_ret_float(index)
-        raise
+        raise GsTaichiRuntimeTypeError(f"Invalid return type on index={index}")

-    def ensure_compiled(self, *args):
+    def ensure_compiled(self, *args: tuple[Any, ...]) -> tuple[Callable, int, AutodiffMode]:
         instance_id, arg_features = self.mapper.lookup(args)
         key = (self.func, instance_id, self.autodiff_mode)
         self.materialize(key=key, args=args, arg_features=arg_features)
```
```diff
@@ -1182,7 +1143,7 @@ class Kernel:
     # For small kernels (< 3us), the performance can be pretty sensitive to overhead in __call__
     # Thus this part needs to be fast. (i.e. < 3us on a 4 GHz x64 CPU)
     @_shell_pop_print
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args, **kwargs) -> Any:
         args = _process_args(self, is_func=False, args=args, kwargs=kwargs)

         # Transform the primal kernel to forward mode grad kernel
```
```diff
@@ -1213,7 +1174,7 @@ class Kernel:
         return self.launch_kernel(kernel_cpp, *args)


-# For a
+# For a GsTaichi class definition like below:
 #
 # @ti.data_oriented
 # class X:
```
```diff
@@ -1232,7 +1193,7 @@ _KERNEL_CLASS_STACKFRAME_STMT_RES = [
 ]


-def _inside_class(level_of_class_stackframe):
+def _inside_class(level_of_class_stackframe: int) -> bool:
     try:
         maybe_class_frame = sys._getframe(level_of_class_stackframe)
         statement_list = inspect.getframeinfo(maybe_class_frame)[3]
```
```diff
@@ -1247,7 +1208,7 @@ def _inside_class(level_of_class_stackframe):
     return False


-def _kernel_impl(_func: Callable, level_of_class_stackframe: int, verbose: bool = False):
+def _kernel_impl(_func: Callable, level_of_class_stackframe: int, verbose: bool = False) -> GsTaichiCallable:
     # Can decorators determine if a function is being defined inside a class?
     # https://stackoverflow.com/a/8793684/12003165
     is_classkernel = _inside_class(level_of_class_stackframe + 1)
```
```diff
@@ -1259,6 +1220,7 @@ def _kernel_impl(_func: Callable, level_of_class_stackframe: int, verbose: bool
     # Having |primal| contains |grad| makes the tape work.
     primal.grad = adjoint

+    wrapped: GsTaichiCallable
     if is_classkernel:
         # For class kernels, their primal/adjoint callables are constructed
         # when the kernel is accessed via the instance inside
```
```diff
@@ -1268,24 +1230,26 @@ def _kernel_impl(_func: Callable, level_of_class_stackframe: int, verbose: bool
         #
         # See also: _BoundedDifferentiableMethod, data_oriented.
         @functools.wraps(_func)
-        def
+        def wrapped_classkernel(*args, **kwargs):
             # If we reach here (we should never), it means the class is not decorated
             # with @ti.data_oriented, otherwise getattr would have intercepted the call.
             clsobj = type(args[0])
             assert not hasattr(clsobj, "_data_oriented")
-            raise
+            raise GsTaichiSyntaxError(f"Please decorate class {clsobj.__name__} with @ti.data_oriented")

+        wrapped = GsTaichiCallable(_func, wrapped_classkernel)
     else:

         @functools.wraps(_func)
-        def
+        def wrapped_func(*args, **kwargs):
             try:
                 return primal(*args, **kwargs)
-            except (
+            except (GsTaichiCompilationError, GsTaichiRuntimeError) as e:
                 if impl.get_runtime().print_full_traceback:
                     raise e
                 raise type(e)("\n" + str(e)) from None

+        wrapped = GsTaichiCallable(_func, wrapped_func)
         wrapped.grad = adjoint

     wrapped._is_wrapped_kernel = True
```
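Both branches now wrap the user's function in a `GsTaichiCallable`, with the adjoint attached as `.grad`. From the caller's side (assuming gstaichi keeps Taichi's decorator surface under `import gstaichi as ti`), the decorated name is that wrapper object rather than the raw Python function:

```python
import gstaichi as ti  # assumed import alias

ti.init(arch=ti.cpu)

x = ti.field(dtype=ti.f32, shape=8, needs_grad=True)
loss = ti.field(dtype=ti.f32, shape=(), needs_grad=True)

@ti.kernel
def compute_loss():
    for i in x:
        loss[None] += x[i] * x[i]

# compute_loss is now a wrapper (GsTaichiCallable), not the plain function:
compute_loss()       # launches the primal kernel
compute_loss.grad()  # assumed to mirror taichi: launches the adjoint kernel
```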
```diff
@@ -1296,10 +1260,10 @@ def _kernel_impl(_func: Callable, level_of_class_stackframe: int, verbose: bool


 def kernel(fn: Callable):
-    """Marks a function as a
+    """Marks a function as a GsTaichi kernel.

-    A
-
+    A GsTaichi kernel is a function written in Python, and gets JIT compiled by
+    GsTaichi into native CPU/GPU instructions (e.g. a series of CUDA kernels).
     The top-level ``for`` loops are automatically parallelized, and distributed
     to either a CPU thread pool or massively parallel GPUs.

```
```diff
@@ -1327,10 +1291,10 @@ def kernel(fn: Callable):


 class _BoundedDifferentiableMethod:
-    def __init__(self, kernel_owner, wrapped_kernel_func):
+    def __init__(self, kernel_owner: Any, wrapped_kernel_func: GsTaichiCallable | BoundGsTaichiCallable):
         clsobj = type(kernel_owner)
         if not getattr(clsobj, "_data_oriented", False):
-            raise
+            raise GsTaichiSyntaxError(f"Please decorate class {clsobj.__name__} with @ti.data_oriented")
         self._kernel_owner = kernel_owner
         self._primal = wrapped_kernel_func._primal
         self._adjoint = wrapped_kernel_func._adjoint
```
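`_BoundedDifferentiableMethod` keeps the same guidance in its new `GsTaichiSyntaxError`: kernels defined on a class only work when the class carries the decorator. A usage sketch, assuming the Taichi-style API under `import gstaichi as ti`:

```python
import gstaichi as ti  # assumed import alias

ti.init(arch=ti.cpu)

@ti.data_oriented  # without this decorator, calling g.clear() below would raise
class Grid:        # GsTaichiSyntaxError("Please decorate class Grid with @ti.data_oriented")
    def __init__(self, n: int):
        self.values = ti.field(dtype=ti.f32, shape=n)

    @ti.kernel
    def clear(self):
        for i in self.values:
            self.values[i] = 0.0

g = Grid(16)
g.clear()
```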
```diff
@@ -1339,23 +1303,26 @@ class _BoundedDifferentiableMethod:

     def __call__(self, *args, **kwargs):
         try:
+            assert self._primal is not None
             if self._is_staticmethod:
                 return self._primal(*args, **kwargs)
             return self._primal(self._kernel_owner, *args, **kwargs)
-
+
+        except (GsTaichiCompilationError, GsTaichiRuntimeError) as e:
             if impl.get_runtime().print_full_traceback:
                 raise e
             raise type(e)("\n" + str(e)) from None

-    def grad(self, *args, **kwargs):
+    def grad(self, *args, **kwargs) -> Kernel:
+        assert self._adjoint is not None
         return self._adjoint(self._kernel_owner, *args, **kwargs)


 def data_oriented(cls):
-    """Marks a class as
+    """Marks a class as GsTaichi compatible.

-    To allow for modularized code,
-
+    To allow for modularized code, GsTaichi provides this decorator so that
+    GsTaichi kernels can be defined inside a class.

     See also https://docs.taichi-lang.org/docs/odop

```
```diff
@@ -1394,11 +1361,11 @@ def data_oriented(cls):
             wrapped = x.__func__
         else:
             wrapped = x
+        assert isinstance(wrapped, (BoundGsTaichiCallable, GsTaichiCallable))
         wrapped._is_staticmethod = is_staticmethod
-        assert inspect.isfunction(wrapped)
         if wrapped._is_classkernel:
             ret = _BoundedDifferentiableMethod(self, wrapped)
-            ret.__name__ = wrapped.__name__
+            ret.__name__ = wrapped.__name__  # type: ignore
             if is_property:
                 return ret()
             return ret
```