gstaichi 0.1.18.dev1__cp310-cp310-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv.h +2568 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv.hpp +2579 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-0.1.18.dev1.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.dist-info/METADATA +108 -0
- gstaichi-0.1.18.dev1.dist-info/RECORD +219 -0
- gstaichi-0.1.18.dev1.dist-info/WHEEL +5 -0
- gstaichi-0.1.18.dev1.dist-info/entry_points.txt +2 -0
- gstaichi-0.1.18.dev1.dist-info/licenses/LICENSE +201 -0
- gstaichi-0.1.18.dev1.dist-info/top_level.txt +1 -0
- taichi/__init__.py +44 -0
- taichi/__main__.py +5 -0
- taichi/_funcs.py +706 -0
- taichi/_kernels.py +420 -0
- taichi/_lib/__init__.py +3 -0
- taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +1401 -0
- taichi/_lib/c_api/include/taichi/taichi.h +29 -0
- taichi/_lib/c_api/include/taichi/taichi_core.h +1111 -0
- taichi/_lib/c_api/include/taichi/taichi_cpu.h +29 -0
- taichi/_lib/c_api/include/taichi/taichi_metal.h +72 -0
- taichi/_lib/c_api/include/taichi/taichi_platform.h +55 -0
- taichi/_lib/c_api/include/taichi/taichi_unity.h +64 -0
- taichi/_lib/c_api/include/taichi/taichi_vulkan.h +151 -0
- taichi/_lib/c_api/lib/libtaichi_c_api.dylib +0 -0
- taichi/_lib/c_api/runtime/libMoltenVK.dylib +0 -0
- taichi/_lib/c_api/runtime/runtime_arm64.bc +0 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +29 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +65 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +121 -0
- taichi/_lib/core/__init__.py +0 -0
- taichi/_lib/core/py.typed +0 -0
- taichi/_lib/core/taichi_python.cpython-310-darwin.so +0 -0
- taichi/_lib/core/taichi_python.pyi +3077 -0
- taichi/_lib/runtime/libMoltenVK.dylib +0 -0
- taichi/_lib/runtime/runtime_arm64.bc +0 -0
- taichi/_lib/utils.py +249 -0
- taichi/_logging.py +131 -0
- taichi/_main.py +552 -0
- taichi/_snode/__init__.py +5 -0
- taichi/_snode/fields_builder.py +189 -0
- taichi/_snode/snode_tree.py +34 -0
- taichi/_ti_module/__init__.py +3 -0
- taichi/_ti_module/cppgen.py +309 -0
- taichi/_ti_module/module.py +145 -0
- taichi/_version.py +1 -0
- taichi/_version_check.py +100 -0
- taichi/ad/__init__.py +3 -0
- taichi/ad/_ad.py +530 -0
- taichi/algorithms/__init__.py +3 -0
- taichi/algorithms/_algorithms.py +117 -0
- taichi/aot/__init__.py +12 -0
- taichi/aot/_export.py +28 -0
- taichi/aot/conventions/__init__.py +3 -0
- taichi/aot/conventions/gfxruntime140/__init__.py +38 -0
- taichi/aot/conventions/gfxruntime140/dr.py +244 -0
- taichi/aot/conventions/gfxruntime140/sr.py +613 -0
- taichi/aot/module.py +253 -0
- taichi/aot/utils.py +151 -0
- taichi/assets/.git +1 -0
- taichi/assets/Go-Regular.ttf +0 -0
- taichi/assets/static/imgs/ti_gallery.png +0 -0
- taichi/examples/minimal.py +28 -0
- taichi/experimental.py +16 -0
- taichi/graph/__init__.py +3 -0
- taichi/graph/_graph.py +292 -0
- taichi/lang/__init__.py +50 -0
- taichi/lang/_ndarray.py +348 -0
- taichi/lang/_ndrange.py +152 -0
- taichi/lang/_texture.py +172 -0
- taichi/lang/_wrap_inspect.py +189 -0
- taichi/lang/any_array.py +99 -0
- taichi/lang/argpack.py +411 -0
- taichi/lang/ast/__init__.py +5 -0
- taichi/lang/ast/ast_transformer.py +1806 -0
- taichi/lang/ast/ast_transformer_utils.py +328 -0
- taichi/lang/ast/checkers.py +106 -0
- taichi/lang/ast/symbol_resolver.py +57 -0
- taichi/lang/ast/transform.py +9 -0
- taichi/lang/common_ops.py +310 -0
- taichi/lang/exception.py +80 -0
- taichi/lang/expr.py +180 -0
- taichi/lang/field.py +464 -0
- taichi/lang/impl.py +1246 -0
- taichi/lang/kernel_arguments.py +157 -0
- taichi/lang/kernel_impl.py +1415 -0
- taichi/lang/matrix.py +1877 -0
- taichi/lang/matrix_ops.py +341 -0
- taichi/lang/matrix_ops_utils.py +190 -0
- taichi/lang/mesh.py +687 -0
- taichi/lang/misc.py +807 -0
- taichi/lang/ops.py +1489 -0
- taichi/lang/runtime_ops.py +13 -0
- taichi/lang/shell.py +35 -0
- taichi/lang/simt/__init__.py +5 -0
- taichi/lang/simt/block.py +94 -0
- taichi/lang/simt/grid.py +7 -0
- taichi/lang/simt/subgroup.py +191 -0
- taichi/lang/simt/warp.py +96 -0
- taichi/lang/snode.py +487 -0
- taichi/lang/source_builder.py +150 -0
- taichi/lang/struct.py +855 -0
- taichi/lang/util.py +381 -0
- taichi/linalg/__init__.py +8 -0
- taichi/linalg/matrixfree_cg.py +310 -0
- taichi/linalg/sparse_cg.py +59 -0
- taichi/linalg/sparse_matrix.py +303 -0
- taichi/linalg/sparse_solver.py +123 -0
- taichi/math/__init__.py +11 -0
- taichi/math/_complex.py +204 -0
- taichi/math/mathimpl.py +886 -0
- taichi/profiler/__init__.py +6 -0
- taichi/profiler/kernel_metrics.py +260 -0
- taichi/profiler/kernel_profiler.py +592 -0
- taichi/profiler/memory_profiler.py +15 -0
- taichi/profiler/scoped_profiler.py +36 -0
- taichi/shaders/Circles_vk.frag +29 -0
- taichi/shaders/Circles_vk.vert +45 -0
- taichi/shaders/Circles_vk_frag.spv +0 -0
- taichi/shaders/Circles_vk_vert.spv +0 -0
- taichi/shaders/Lines_vk.frag +9 -0
- taichi/shaders/Lines_vk.vert +11 -0
- taichi/shaders/Lines_vk_frag.spv +0 -0
- taichi/shaders/Lines_vk_vert.spv +0 -0
- taichi/shaders/Mesh_vk.frag +71 -0
- taichi/shaders/Mesh_vk.vert +68 -0
- taichi/shaders/Mesh_vk_frag.spv +0 -0
- taichi/shaders/Mesh_vk_vert.spv +0 -0
- taichi/shaders/Particles_vk.frag +95 -0
- taichi/shaders/Particles_vk.vert +73 -0
- taichi/shaders/Particles_vk_frag.spv +0 -0
- taichi/shaders/Particles_vk_vert.spv +0 -0
- taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
- taichi/shaders/SceneLines_vk.frag +9 -0
- taichi/shaders/SceneLines_vk.vert +12 -0
- taichi/shaders/SceneLines_vk_frag.spv +0 -0
- taichi/shaders/SceneLines_vk_vert.spv +0 -0
- taichi/shaders/SetImage_vk.frag +21 -0
- taichi/shaders/SetImage_vk.vert +15 -0
- taichi/shaders/SetImage_vk_frag.spv +0 -0
- taichi/shaders/SetImage_vk_vert.spv +0 -0
- taichi/shaders/Triangles_vk.frag +16 -0
- taichi/shaders/Triangles_vk.vert +29 -0
- taichi/shaders/Triangles_vk_frag.spv +0 -0
- taichi/shaders/Triangles_vk_vert.spv +0 -0
- taichi/shaders/lines2quad_vk_comp.spv +0 -0
- taichi/sparse/__init__.py +3 -0
- taichi/sparse/_sparse_grid.py +77 -0
- taichi/tools/__init__.py +12 -0
- taichi/tools/diagnose.py +124 -0
- taichi/tools/np2ply.py +364 -0
- taichi/tools/vtk.py +38 -0
- taichi/types/__init__.py +19 -0
- taichi/types/annotations.py +47 -0
- taichi/types/compound_types.py +90 -0
- taichi/types/enums.py +49 -0
- taichi/types/ndarray_type.py +147 -0
- taichi/types/primitive_types.py +203 -0
- taichi/types/quant.py +88 -0
- taichi/types/texture_type.py +85 -0
- taichi/types/utils.py +13 -0
@@ -0,0 +1,1415 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import ast
|
4
|
+
import dataclasses
|
5
|
+
import functools
|
6
|
+
import inspect
|
7
|
+
import json
|
8
|
+
import operator
|
9
|
+
import os
|
10
|
+
import pathlib
|
11
|
+
import re
|
12
|
+
import sys
|
13
|
+
import textwrap
|
14
|
+
import time
|
15
|
+
import types
|
16
|
+
import typing
|
17
|
+
import warnings
|
18
|
+
import weakref
|
19
|
+
from typing import Any, Callable, Type, Union
|
20
|
+
|
21
|
+
import numpy as np
|
22
|
+
|
23
|
+
import taichi.lang
|
24
|
+
import taichi.lang._ndarray
|
25
|
+
import taichi.lang._texture
|
26
|
+
import taichi.lang.expr
|
27
|
+
import taichi.lang.snode
|
28
|
+
import taichi.types.annotations
|
29
|
+
from taichi import _logging
|
30
|
+
from taichi._lib import core as _ti_core
|
31
|
+
from taichi._lib.core.taichi_python import ASTBuilder
|
32
|
+
from taichi.lang import impl, ops, runtime_ops
|
33
|
+
from taichi.lang._wrap_inspect import getsourcefile, getsourcelines
|
34
|
+
from taichi.lang.any_array import AnyArray
|
35
|
+
from taichi.lang.argpack import ArgPack, ArgPackType
|
36
|
+
from taichi.lang.ast import (
|
37
|
+
ASTTransformerContext,
|
38
|
+
KernelSimplicityASTChecker,
|
39
|
+
transform_tree,
|
40
|
+
)
|
41
|
+
from taichi.lang.ast.ast_transformer_utils import ReturnStatus
|
42
|
+
from taichi.lang.exception import (
|
43
|
+
TaichiCompilationError,
|
44
|
+
TaichiRuntimeError,
|
45
|
+
TaichiRuntimeTypeError,
|
46
|
+
TaichiSyntaxError,
|
47
|
+
TaichiTypeError,
|
48
|
+
handle_exception_from_cpp,
|
49
|
+
)
|
50
|
+
from taichi.lang.expr import Expr
|
51
|
+
from taichi.lang.kernel_arguments import KernelArgument
|
52
|
+
from taichi.lang.matrix import MatrixType
|
53
|
+
from taichi.lang.shell import _shell_pop_print
|
54
|
+
from taichi.lang.struct import StructType
|
55
|
+
from taichi.lang.util import cook_dtype, has_paddle, has_pytorch, to_taichi_type
|
56
|
+
from taichi.types import (
|
57
|
+
ndarray_type,
|
58
|
+
primitive_types,
|
59
|
+
sparse_matrix_builder,
|
60
|
+
template,
|
61
|
+
texture_type,
|
62
|
+
)
|
63
|
+
from taichi.types.compound_types import CompoundType
|
64
|
+
from taichi.types.enums import AutodiffMode, Layout
|
65
|
+
from taichi.types.utils import is_signed
|
66
|
+
|
67
|
+
|
68
|
+
def func(fn: Callable, is_real_function: bool = False):
    """Marks a function as callable in Taichi-scope.

    This decorator transforms a Python function into a Taichi one. Taichi
    will JIT compile it into native instructions.

    Args:
        fn (Callable): The Python function to be decorated
        is_real_function (bool): Whether the function is a real function

    Returns:
        Callable: A wrapper (carrying ``fn``'s metadata via ``functools.wraps``)
        that forwards every call to a :class:`Func` built from ``fn``. The
        wrapper exposes that ``Func`` as ``.func`` and is tagged with
        ``_is_taichi_function = True`` and ``_is_real_function``.

    Example::

        >>> @ti.func
        >>> def foo(x):
        >>>     return x + 2
        >>>
        >>> @ti.kernel
        >>> def run():
        >>>     print(foo(40))  # 42
    """
    # NOTE(review): `level_of_class_stackframe` looks like a stack-frame depth, so this
    # call stays directly in this function's body. `real_func` adds one more frame
    # before reaching here, hence the `+ is_real_function`.
    inside_class = _inside_class(level_of_class_stackframe=3 + is_real_function)
    taichi_callable = Func(fn, _classfunc=inside_class, is_real_function=is_real_function)

    @functools.wraps(fn)
    def decorated(*args, **kwargs):
        return taichi_callable(*args, **kwargs)

    decorated.func = taichi_callable
    decorated._is_real_function = is_real_function
    decorated._is_taichi_function = True
    return decorated
|
103
|
+
|
104
|
+
|
105
|
+
def real_func(fn: Callable):
    """Marks a function as a Taichi *real* function.

    Equivalent to ``func(fn, is_real_function=True)``. A real function is
    compiled into its own callable (once per template instantiation) and
    invoked through a function call from the kernel, instead of having its
    body transformed directly into the caller.

    Args:
        fn (Callable): The Python function to be decorated

    Returns:
        Callable: The decorated function (see :func:`func`).
    """
    # Must call `func` directly from this frame: `func` assumes exactly one extra
    # stack frame when `is_real_function` is True.
    return func(fn=fn, is_real_function=True)
|
107
|
+
|
108
|
+
|
109
|
+
def pyfunc(fn: Callable):
    """Marks a function as callable in both Taichi and Python scopes.

    When called inside the Taichi scope, Taichi will JIT compile it into
    native instructions. Otherwise it will be invoked directly as a
    Python function.

    See also :func:`~taichi.lang.kernel_impl.func`.

    Args:
        fn (Callable): The Python function to be decorated

    Returns:
        Callable: A wrapper that forwards every call to a :class:`Func` built
        with ``_pyfunc=True``. The wrapper exposes that ``Func`` as ``.func``
        and is tagged ``_is_taichi_function = True`` and
        ``_is_real_function = False``.
    """
    # NOTE(review): `level_of_class_stackframe` looks like a stack-frame depth, so this
    # call stays directly in this function's body; moving it would change the depth.
    inside_class = _inside_class(level_of_class_stackframe=3)
    taichi_callable = Func(fn, _classfunc=inside_class, _pyfunc=True)

    @functools.wraps(fn)
    def decorated(*args, **kwargs):
        return taichi_callable(*args, **kwargs)

    decorated.func = taichi_callable
    decorated._is_real_function = False
    decorated._is_taichi_function = True
    return decorated
|
135
|
+
|
136
|
+
|
137
|
+
def _get_tree_and_ctx(
    self: "Func | Kernel",
    excluded_parameters=(),
    is_kernel: bool = True,
    arg_features=None,
    args=None,
    ast_builder: ASTBuilder | None = None,
    is_real_function: bool = False,
):
    """Parse the source of ``self.func`` and build the context used to transform it.

    Args:
        self: The ``Func`` or ``Kernel`` whose ``.func`` is parsed.
        excluded_parameters: Forwarded to ``ASTTransformerContext``.
        is_kernel: Whether ``self`` is a kernel.
        arg_features: Forwarded to ``ASTTransformerContext``.
        args: Bound call arguments. They must be indexable when ``is_kernel`` or
            ``is_real_function`` is true, because template and dataclass values
            are read from them by position.
        ast_builder: Forwarded to ``ASTTransformerContext``.
        is_real_function: Whether ``self`` is a real function.

    Returns:
        tuple: ``(tree, ctx)``. ``tree`` is the ``ast.Module`` parsed from the
        dedented source, with the function's decorator list cleared. ``ctx`` is
        the matching ``ASTTransformerContext``.
    """
    source_file = getsourcefile(self.func)
    raw_lines, start_lineno = getsourcelines(self.func)
    # textwrap.fill expands tabs to 4 columns and normalizes per-line whitespace
    # (e.g. the trailing newline); the very large width is meant to avoid wrapping.
    src = [textwrap.fill(raw, tabsize=4, width=9999) for raw in raw_lines]
    tree = ast.parse(textwrap.dedent("\n".join(src)))

    # The first statement is the function definition; its decorators are removed.
    function_def = tree.body[0]
    function_def.decorator_list = []

    global_vars = _get_global_vars(self.func)

    if is_kernel or is_real_function:
        # Inject template argument values into the globals, keyed by parameter name.
        for slot in self.template_slot_locations:
            global_vars[self.arguments[slot].name] = args[slot]
        # Expose each field of a dataclass-annotated parameter under its flattened
        # name `__ti_<param>_<field>`.
        signature_params = inspect.signature(self.func).parameters
        for position, (param_name, param) in enumerate(signature_params.items()):
            if not dataclasses.is_dataclass(param.annotation):
                continue
            for member in dataclasses.fields(param.annotation):
                global_vars[f"__ti_{param_name}_{member.name}"] = getattr(args[position], member.name)

    ctx = ASTTransformerContext(
        excluded_parameters=excluded_parameters,
        is_kernel=is_kernel,
        func=self,
        arg_features=arg_features,
        global_vars=global_vars,
        argument_data=args,
        src=src,
        start_lineno=start_lineno,
        file=source_file,
        ast_builder=ast_builder,
        is_real_function=is_real_function,
    )
    return tree, ctx
|
182
|
+
|
183
|
+
|
184
|
+
def expand_func_arguments(arguments: list[KernelArgument]) -> list[KernelArgument]:
    """Flatten dataclass-annotated arguments into one argument per dataclass field.

    An argument ``s`` annotated with a dataclass that has fields ``a`` and ``b``
    is replaced, in place of ``s``, by arguments named ``__ti_s_a`` and
    ``__ti_s_b`` annotated with the corresponding field types. All other
    arguments are kept as-is. Only one level is flattened: a field whose type is
    itself a dataclass is not expanded further here.

    Args:
        arguments: The arguments to expand. The list is not modified.

    Returns:
        A new list of arguments.
    """

    def _flatten(argument: KernelArgument) -> list[KernelArgument]:
        # One argument in, one or more arguments out.
        if not dataclasses.is_dataclass(argument.annotation):
            return [argument]
        return [
            KernelArgument(
                _annotation=member.type,
                _name=f"__ti_{argument.name}_{member.name}",
            )
            for member in dataclasses.fields(argument.annotation)
        ]

    return [expanded for argument in arguments for expanded in _flatten(argument)]
|
197
|
+
|
198
|
+
|
199
|
+
def _process_args(self: "Func | Kernel", is_func: bool, args, kwargs):
    """Bind positional and keyword call arguments to ``self.arguments``.

    When ``is_func`` is true, ``self.arguments`` is first replaced by its
    dataclass-expanded form (see ``expand_func_arguments``).

    Args:
        self: Object exposing ``arguments`` (items with ``name``,
            ``annotation`` and ``default``).
        is_func: Whether ``self`` is a Taichi function (enables expansion).
        args: Positional call arguments.
        kwargs: Keyword call arguments.

    Returns:
        tuple: One value per entry of ``self.arguments``, in order. Parameters not
        supplied by the call take their ``default``.

    Raises:
        TaichiSyntaxError: On too many positional arguments, a keyword that
            duplicates a positional argument, an unknown keyword, or a parameter
            that is neither supplied nor defaulted.
    """
    if is_func:
        self.arguments = expand_func_arguments(self.arguments)
    fused_args = [argument.default for argument in self.arguments]
    len_args = len(args)

    if len_args > len(fused_args):
        arg_str = ", ".join(str(arg) for arg in args)
        expected_str = ", ".join(f"{arg.name} : {arg.annotation}" for arg in self.arguments)
        msg = f"Too many arguments. Expected ({expected_str}), got ({arg_str})."
        raise TaichiSyntaxError(msg)

    # Positional arguments fill the leading parameter slots.
    fused_args[:len_args] = args

    # Name -> first parameter index with that name (same result as a linear scan).
    positions = {}
    for i, arg in enumerate(self.arguments):
        positions.setdefault(arg.name, i)

    for key, value in kwargs.items():
        i = positions.get(key)
        if i is None:
            raise TaichiSyntaxError(f"Unexpected argument '{key}'.")
        if i < len_args:
            raise TaichiSyntaxError(f"Multiple values for argument '{key}'.")
        fused_args[i] = value

    for i, arg in enumerate(fused_args):
        if arg is inspect.Parameter.empty:
            if self.arguments[i].annotation is inspect.Parameter.empty:
                raise TaichiSyntaxError(f"Parameter `{self.arguments[i].name}` missing.")
            else:
                raise TaichiSyntaxError(
                    f"Parameter `{self.arguments[i].name} : {self.arguments[i].annotation}` missing."
                )

    return tuple(fused_args)
|
236
|
+
|
237
|
+
|
238
|
+
def unpack_ndarray_struct(tree: ast.Module, struct_locals: set[str]) -> ast.Module:
    """Rewrite ``name.attr`` accesses into flattened ``__ti_name_attr`` names.

    Every ``ast.Attribute`` whose value is a plain ``ast.Name`` is replaced by
    ``ast.Name(id=f"__ti_{name}_{attr}")``, but only when that flattened name
    is in ``struct_locals``. Such accesses are rewritten wherever they appear,
    including inside larger attribute chains or expressions (for example
    ``s.a.shape[0]``, ``s.b[i].x`` or ``(s.a + s.c).norm()``). The load/store
    context of the original node is preserved.

    Args:
        tree: Module to transform. It is modified in place.
        struct_locals: Flattened names that may be substituted.

    Returns:
        The transformed module, with missing source locations filled in.
    """

    class AttributeToNameTransformer(ast.NodeTransformer):
        def visit_Attribute(self, node: ast.Attribute):
            # Transform children first. Returning early without this would leave
            # struct accesses nested under another Attribute, Subscript, Call,
            # BinOp, ... untouched.
            self.generic_visit(node)
            if not isinstance(node.value, ast.Name):
                return node
            new_id = f"__ti_{node.value.id}_{node.attr}"
            if new_id not in struct_locals:
                return node
            return ast.copy_location(ast.Name(id=new_id, ctx=node.ctx), node)

    new_tree = AttributeToNameTransformer().visit(tree)
    ast.fix_missing_locations(new_tree)
    return new_tree
|
256
|
+
|
257
|
+
|
258
|
+
def extract_struct_locals_from_context(ctx: ASTTransformerContext):
    """Return the flattened local names of all dataclass parameters of ``ctx.func``.

    - Uses ``ctx.func.func`` to get the function signature.
    - For every parameter annotated with a dataclass, produces one name per field.
    - E.g. ``my_struct: MyStruct``, where ``MyStruct`` contains ``a``, ``b``, ``c``,
      becomes ``{"__ti_my_struct_a", "__ti_my_struct_b", "__ti_my_struct_c"}``.

    Args:
        ctx: Transformer context; ``ctx.func`` must not be ``None``.

    Returns:
        set[str]: The flattened names (empty if no parameter is a dataclass).
    """
    assert ctx.func is not None
    parameters = inspect.signature(ctx.func.func).parameters
    return {
        f"__ti_{param_name}_{member.name}"
        for param_name, parameter in parameters.items()
        if dataclasses.is_dataclass(parameter.annotation)
        for member in dataclasses.fields(parameter.annotation)
    }
|
276
|
+
|
277
|
+
|
278
|
+
class Func:
    """Python-side handle for a function decorated with ``ti.func`` / ``ti.pyfunc`` / ``ti.real_func``.

    Calling an instance inside a kernel takes one of two paths:

    - Default: the function's AST is transformed using the current kernel's AST
      builder.
    - Real functions: the function is compiled once per template instantiation
      into a separate callable, and a call to it is emitted.

    Calling an instance outside a kernel is only allowed for ``pyfunc``s, which
    then run as plain Python.
    """

    # Incremented per instance, so each Func gets a distinct `func_id` (used in FunctionKey).
    function_counter = 0

    def __init__(self, _func: Callable, _classfunc=False, _pyfunc=False, is_real_function=False):
        self.func = _func
        self.func_id = Func.function_counter
        Func.function_counter += 1
        # instance_id -> body-builder closure registered by `do_compile` (real functions only).
        self.compiled = {}
        self.classfunc = _classfunc
        self.pyfunc = _pyfunc
        self.is_real_function = is_real_function
        self.arguments: list[KernelArgument] = []
        self.orig_arguments: list[KernelArgument] = []
        self.return_type: tuple[Type, ...] | None = None
        self.extract_arguments()
        self.template_slot_locations: list[int] = [
            i for i, arg in enumerate(self.arguments) if self._is_template_annotation(arg.annotation)
        ]
        self.mapper = TaichiCallableTemplateMapper(self.arguments, self.template_slot_locations)
        self.taichi_functions = {}  # The |Function| class in C++
        self.has_print = False

    @staticmethod
    def _is_template_annotation(annotation) -> bool:
        """True for both the bare ``template`` class and ``template()`` instances."""
        return annotation == template or isinstance(annotation, template)

    def __call__(self, *args, **kwargs):
        """Call the function from Python scope (pyfunc only) or from inside a kernel.

        Raises:
            TaichiSyntaxError: When a non-pyfunc is called outside a kernel, when
                a real function is used in a kernel whose autodiff mode is not
                ``NONE``, or when a function with a return type has no return
                statement.
        """
        args = _process_args(self, is_func=True, args=args, kwargs=kwargs)

        if not impl.inside_kernel():
            if not self.pyfunc:
                raise TaichiSyntaxError("Taichi functions cannot be called from Python-scope.")
            return self.func(*args)

        current_kernel = impl.get_runtime().current_kernel
        if self.is_real_function:
            if current_kernel.autodiff_mode != AutodiffMode.NONE:
                raise TaichiSyntaxError("Real function in gradient kernels unsupported.")
            instance_id, arg_features = self.mapper.lookup(args)
            key = _ti_core.FunctionKey(self.func.__name__, self.func_id, instance_id)
            if key.instance_id not in self.compiled:
                self.do_compile(key=key, args=args, arg_features=arg_features)
            return self.func_call_rvalue(key=key, args=args)

        # Non-real function: transform its AST with the calling kernel's builder.
        tree, ctx = _get_tree_and_ctx(
            self,
            is_kernel=False,
            args=args,
            ast_builder=current_kernel.ast_builder(),
            is_real_function=self.is_real_function,
        )

        struct_locals = extract_struct_locals_from_context(ctx)
        tree = unpack_ndarray_struct(tree, struct_locals=struct_locals)
        ret = transform_tree(tree, ctx)
        # Real functions have already returned above, so this check always applies here.
        if self.return_type and ctx.returned != ReturnStatus.ReturnedValue:
            raise TaichiSyntaxError("Function has a return type but does not have a return statement")
        return ret

    def func_call_rvalue(self, key, args):
        """Emit a call to the compiled real function ``key``; return its result expression(s).

        Returns:
            ``None`` if there is no return type, a single value for a single
            return type, otherwise a tuple with one value per return type.

        Raises:
            TaichiTypeError: On a non-ndarray argument for an ndarray parameter,
                or an unsupported return type.
        """
        assert self.is_real_function
        non_template_args = []
        dbg_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
        for i, kernel_arg in enumerate(self.arguments):
            anno = kernel_arg.annotation
            # Skip the template args, e.g., |self|. Their values were injected as
            # globals at compile time. Use the same test as `__init__` so a bare
            # `ti.template` annotation is skipped as well.
            if self._is_template_annotation(anno):
                continue
            if id(anno) in primitive_types.type_ids:
                non_template_args.append(ops.cast(args[i], anno))
            elif isinstance(anno, primitive_types.RefType):
                non_template_args.append(_ti_core.make_reference(args[i].ptr, dbg_info))
            elif isinstance(anno, ndarray_type.NdarrayType):
                if not isinstance(args[i], AnyArray):
                    raise TaichiTypeError(
                        f"Expected ndarray in the kernel argument for argument {kernel_arg.name}, got {args[i]}"
                    )
                non_template_args += _ti_core.get_external_tensor_real_func_args(args[i].ptr, dbg_info)
            else:
                non_template_args.append(args[i])
        non_template_args = impl.make_expr_group(non_template_args)
        compiling_callable = impl.get_runtime().compiling_callable
        assert compiling_callable is not None
        func_call = compiling_callable.ast_builder().insert_func_call(
            self.taichi_functions[key.instance_id], non_template_args, dbg_info
        )
        if self.return_type is None:
            return None
        func_call = Expr(func_call)
        ret = []

        for i, return_type in enumerate(self.return_type):
            if id(return_type) in primitive_types.type_ids:
                ret.append(
                    Expr(
                        _ti_core.make_get_element_expr(
                            func_call.ptr, (i,), _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
                        )
                    )
                )
            elif isinstance(return_type, (StructType, MatrixType)):
                ret.append(return_type.from_taichi_object(func_call, (i,)))
            else:
                raise TaichiTypeError(f"Unsupported return type for return value {i}: {return_type}")
        if len(ret) == 1:
            return ret[0]
        return tuple(ret)

    def do_compile(self, key, args, arg_features):
        """Create the program-level function for ``key`` and register its body builder."""
        tree, ctx = _get_tree_and_ctx(
            self, is_kernel=False, args=args, arg_features=arg_features, is_real_function=self.is_real_function
        )
        fn = impl.get_runtime().prog.create_function(key)

        def func_body():
            # Temporarily make `fn` the compiling callable while its AST is transformed.
            old_callable = impl.get_runtime().compiling_callable
            impl.get_runtime().compiling_callable = fn
            try:
                ctx.ast_builder = fn.ast_builder()
                transform_tree(tree, ctx)
            finally:
                # Restore even if the transformation raises, so the runtime is not
                # left pointing at a half-built function.
                impl.get_runtime().compiling_callable = old_callable

        self.taichi_functions[key.instance_id] = fn
        self.compiled[key.instance_id] = func_body
        self.taichi_functions[key.instance_id].set_function_body(func_body)

    @staticmethod
    def _is_supported_annotation(annotation) -> bool:
        """Whether ``annotation`` is an accepted parameter annotation for a Taichi function."""
        return (
            isinstance(annotation, (ndarray_type.NdarrayType, MatrixType, StructType))
            or id(annotation) in primitive_types.type_ids
            or type(annotation) == taichi.types.annotations.Template
            or isinstance(annotation, template)
            or annotation == taichi.types.annotations.Template
            or isinstance(annotation, primitive_types.RefType)
            or (isinstance(annotation, type) and dataclasses.is_dataclass(annotation))
        )

    def extract_arguments(self) -> None:
        """Populate ``return_type``, ``arguments`` and ``orig_arguments`` from ``self.func``'s signature.

        Raises:
            TaichiSyntaxError: On an Ellipsis in the return annotation, on
                ``*args`` / ``**kwargs`` / keyword-only / positional-only
                parameters, on a missing annotation for a (non-pyfunc) real
                function, or on an unsupported annotation.
        """
        sig = inspect.signature(self.func)
        if sig.return_annotation not in (inspect.Signature.empty, None):
            self.return_type = sig.return_annotation
            # `tuple[a, b]` / `typing.Tuple[a, b]` -> `(a, b)`
            if (
                isinstance(self.return_type, (types.GenericAlias, typing._GenericAlias))
                and self.return_type.__origin__ is tuple
            ):
                self.return_type = self.return_type.__args__
            if not isinstance(self.return_type, (list, tuple)):
                self.return_type = (self.return_type,)
            if any(return_type is Ellipsis for return_type in self.return_type):
                raise TaichiSyntaxError("Ellipsis is not supported in return type annotations")
        for i, (arg_name, param) in enumerate(sig.parameters.items()):
            if param.kind == inspect.Parameter.VAR_KEYWORD:
                raise TaichiSyntaxError("Taichi functions do not support variable keyword parameters (i.e., **kwargs)")
            if param.kind == inspect.Parameter.VAR_POSITIONAL:
                raise TaichiSyntaxError("Taichi functions do not support variable positional parameters (i.e., *args)")
            if param.kind == inspect.Parameter.KEYWORD_ONLY:
                raise TaichiSyntaxError("Taichi functions do not support keyword parameters")
            if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
                raise TaichiSyntaxError('Taichi functions only support "positional or keyword" parameters')
            annotation = param.annotation
            if annotation is inspect.Parameter.empty:
                if i == 0 and self.classfunc:
                    # Unannotated first parameter of a method (e.g. `self`) is a template.
                    annotation = template()
                # TODO: pyfunc also need type annotation check when real function is enabled,
                # but that has to happen at runtime when we know which scope it's called from.
                elif not self.pyfunc and self.is_real_function:
                    raise TaichiSyntaxError(
                        f"Taichi function `{self.func.__name__}` parameter `{arg_name}` must be type annotated"
                    )
            elif not self._is_supported_annotation(annotation):
                raise TaichiSyntaxError(f"Invalid type annotation (argument {i}) of Taichi function: {annotation}")
            self.arguments.append(KernelArgument(annotation, param.name, param.default))
            self.orig_arguments.append(KernelArgument(annotation, param.name, param.default))
|
456
|
+
|
457
|
+
|
458
|
+
# Union of the argument annotation kinds accepted by
# `TaichiCallableTemplateMapper.extract_arg`. The two texture types are
# written as string forward references.
AnnotationType = typing.Union[
    template,
    ArgPackType,
    "texture_type.TextureType",
    "texture_type.RWTextureType",
    ndarray_type.NdarrayType,
    sparse_matrix_builder,
    Any,
]
|
467
|
+
|
468
|
+
|
469
|
+
class TaichiCallableTemplateMapper:
    """
    Maps the runtime arguments of a kernel call to a specialization ("instance") id.

    This should probably be renamed to something like FeatureMapper, or
    FeatureExtractor, since:
    - it's not specific to templates
    - it extracts what are later called 'features', for example for ndarray this includes:
        - element type
        - number of dimensions
        - needs grad (or not)
    - these are returned as a heterogeneous tuple, whose contents depends on the type
    """

    def __init__(self, arguments: list[KernelArgument], template_slot_locations: list[int]) -> None:
        self.arguments = arguments
        self.num_args = len(arguments)
        self.template_slot_locations = template_slot_locations
        # Maps an extracted feature tuple to the instance id assigned when it was first seen.
        self.mapping = {}

    @staticmethod
    def extract_arg(arg, annotation: AnnotationType, arg_name: str):
        """Return the feature of a single argument that participates in specialization.

        Args:
            arg: The runtime value passed for this parameter.
            annotation: The parameter's annotation.
            arg_name: The parameter name. It is used in error messages and passed to
                ``NdarrayType.check_matched``.

        Returns:
            A value that becomes part of the lookup key. Its form depends on
            ``annotation``. Annotations not handled below map to the placeholder ``"#"``.

        Raises:
            TaichiRuntimeTypeError, ValueError: When ``arg`` does not match ``annotation``.
        """
        if annotation == template or isinstance(annotation, template):
            if isinstance(arg, taichi.lang.snode.SNode):
                return arg.ptr
            if isinstance(arg, taichi.lang.expr.Expr):
                return arg.ptr.get_underlying_ptr_address()
            if isinstance(arg, _ti_core.Expr):
                return arg.get_underlying_ptr_address()
            if isinstance(arg, tuple):
                return tuple(TaichiCallableTemplateMapper.extract_arg(item, annotation, arg_name) for item in arg)
            if isinstance(arg, taichi.lang._ndarray.Ndarray):
                raise TaichiRuntimeTypeError(
                    "Ndarray shouldn't be passed in via `ti.template()`, please annotate your kernel using `ti.types.ndarray(...)` instead"
                )

            if isinstance(arg, (list, tuple, dict, set)) or hasattr(arg, "_data_oriented"):
                # [Composite arguments] Return weak reference to the object
                # Taichi kernel will cache the extracted arguments, thus we can't simply return the original argument.
                # Instead, a weak reference to the original value is returned to avoid memory leak.

                # NOTE(review): plain built-in list/dict/set instances do not support weak
                # references, so weakref.ref() raises TypeError for them (tuples are already
                # handled above). Confirm whether such arguments are expected here.

                # TODO(zhanlue): replacing "tuple(args)" with "hash of argument values"
                # This can resolve the following issues:
                # 1. Invalid weak-ref will leave a dead(dangling) entry in both caches: "self.mapping" and "self.compiled_functions"
                # 2. Different argument instances with same type and same value, will get templatized into separate kernels.
                return weakref.ref(arg)

            # [Primitive arguments] Return the value
            return arg
        if isinstance(annotation, ArgPackType):
            if not isinstance(arg, ArgPack):
                raise TaichiRuntimeTypeError(f"Argument {arg_name} must be a argument pack, got {type(arg)}")
            return tuple(
                TaichiCallableTemplateMapper.extract_arg(arg[name], dtype, arg_name)
                for name, dtype in annotation.members.items()
            )
        if dataclasses.is_dataclass(annotation):
            _res_l = []
            for field in dataclasses.fields(annotation):
                field_value = getattr(arg, field.name)
                # Build a fresh name for each field. Reassigning ``arg_name`` here would
                # make every later field's name nest the previous field's name.
                field_arg_name = f"__ti_{arg_name}_{field.name}"
                field_extracted = TaichiCallableTemplateMapper.extract_arg(field_value, field.type, field_arg_name)
                _res_l.append(field_extracted)
            return tuple(_res_l)
        if isinstance(annotation, texture_type.TextureType):
            if not isinstance(arg, taichi.lang._texture.Texture):
                raise TaichiRuntimeTypeError(f"Argument {arg_name} must be a texture, got {type(arg)}")
            if arg.num_dims != annotation.num_dimensions:
                raise TaichiRuntimeTypeError(
                    f"TextureType dimension mismatch for argument {arg_name}: expected {annotation.num_dimensions}, got {arg.num_dims}"
                )
            return (arg.num_dims,)
        if isinstance(annotation, texture_type.RWTextureType):
            if not isinstance(arg, taichi.lang._texture.Texture):
                raise TaichiRuntimeTypeError(f"Argument {arg_name} must be a texture, got {type(arg)}")
            if arg.num_dims != annotation.num_dimensions:
                raise TaichiRuntimeTypeError(
                    f"RWTextureType dimension mismatch for argument {arg_name}: expected {annotation.num_dimensions}, got {arg.num_dims}"
                )
            if arg.fmt != annotation.fmt:
                raise TaichiRuntimeTypeError(
                    f"RWTextureType format mismatch for argument {arg_name}: expected {annotation.fmt}, got {arg.fmt}"
                )
            # (penguinliong) '0' is the assumed LOD level. We currently don't
            # support mip-mapping.
            return arg.num_dims, arg.fmt, 0
        if isinstance(annotation, ndarray_type.NdarrayType):
            if isinstance(arg, taichi.lang._ndarray.Ndarray):
                annotation.check_matched(arg.get_type(), arg_name)
                needs_grad = (arg.grad is not None) if annotation.needs_grad is None else annotation.needs_grad
                assert arg.shape is not None
                return arg.element_type, len(arg.shape), needs_grad, annotation.boundary
            if isinstance(arg, AnyArray):
                ty = arg.get_type()
                annotation.check_matched(ty, arg_name)
                return ty.element_type, len(arg.shape), ty.needs_grad, annotation.boundary
            # external arrays
            shape = getattr(arg, "shape", None)
            if shape is None:
                raise TaichiRuntimeTypeError(f"Invalid type for argument {arg_name}, got {arg}")
            shape = tuple(shape)
            element_shape: tuple[int, ...] = ()
            dtype = to_taichi_type(arg.dtype)
            if isinstance(annotation.dtype, MatrixType):
                if annotation.ndim is not None:
                    if len(shape) != annotation.dtype.ndim + annotation.ndim:
                        raise ValueError(
                            f"Invalid value for argument {arg_name} - required array has ndim={annotation.ndim} element_dim={annotation.dtype.ndim}, "
                            f"array with {len(shape)} dimensions is provided"
                        )
                else:
                    if len(shape) < annotation.dtype.ndim:
                        raise ValueError(
                            f"Invalid value for argument {arg_name} - required element_dim={annotation.dtype.ndim}, "
                            f"array with {len(shape)} dimensions is provided"
                        )
                # Trailing dims hold the matrix element shape.
                element_shape = shape[-annotation.dtype.ndim :]
                anno_element_shape = annotation.dtype.get_shape()
                if None not in anno_element_shape and element_shape != anno_element_shape:
                    raise ValueError(
                        f"Invalid value for argument {arg_name} - required element_shape={anno_element_shape}, "
                        f"array with element shape of {element_shape} is provided"
                    )
            elif annotation.dtype is not None:
                # User specified scalar dtype
                if annotation.dtype != dtype:
                    raise ValueError(
                        f"Invalid value for argument {arg_name} - required array has dtype={annotation.dtype.to_string()}, "
                        f"array with dtype={dtype.to_string()} is provided"
                    )

                if annotation.ndim is not None and len(shape) != annotation.ndim:
                    raise ValueError(
                        f"Invalid value for argument {arg_name} - required array has ndim={annotation.ndim}, "
                        f"array with {len(shape)} dimensions is provided"
                    )
            needs_grad = (
                getattr(arg, "requires_grad", False) if annotation.needs_grad is None else annotation.needs_grad
            )
            element_type = (
                _ti_core.get_type_factory_instance().get_tensor_type(element_shape, dtype)
                if len(element_shape) != 0
                else arg.dtype
            )
            return element_type, len(shape) - len(element_shape), needs_grad, annotation.boundary
        if isinstance(annotation, sparse_matrix_builder):
            return arg.dtype
        # Use '#' as a placeholder because other kinds of arguments are not involved in template instantiation
        return "#"

    def extract(self, args):
        """Return the tuple of per-argument features for ``args`` (zipped with ``self.arguments``)."""
        return tuple(
            self.extract_arg(arg, kernel_arg.annotation, kernel_arg.name)
            for arg, kernel_arg in zip(args, self.arguments)
        )

    def lookup(self, args):
        """Return ``(instance_id, features)`` for ``args``.

        A feature tuple that has not been seen before is assigned the next dense id.

        Raises:
            TypeError: If the number of arguments differs from the number of parameters.
        """
        if len(args) != self.num_args:
            raise TypeError(f"{self.num_args} argument(s) needed but {len(args)} provided.")

        key = self.extract(args)
        if key not in self.mapping:
            self.mapping[key] = len(self.mapping)
        return self.mapping[key], key
|
632
|
+
|
633
|
+
|
634
|
+
def _get_global_vars(_func):
|
635
|
+
# Discussions: https://github.com/taichi-dev/taichi/issues/282
|
636
|
+
global_vars = _func.__globals__.copy()
|
637
|
+
|
638
|
+
freevar_names = _func.__code__.co_freevars
|
639
|
+
closure = _func.__closure__
|
640
|
+
if closure:
|
641
|
+
freevar_values = list(map(lambda x: x.cell_contents, closure))
|
642
|
+
for name, value in zip(freevar_names, freevar_values):
|
643
|
+
global_vars[name] = value
|
644
|
+
|
645
|
+
return global_vars
|
646
|
+
|
647
|
+
|
648
|
+
class Kernel:
    """Python-side wrapper around a Taichi kernel function.

    Holds the parsed argument annotations, the feature mapper used to select a
    specialization, and the cache of compiled kernels keyed by
    ``(func, instance_id, autodiff_mode)``. It also marshals Python arguments
    into a launch context.
    """

    # Class-wide counter; each instance stores its value to make compiled kernel names distinct.
    counter = 0

    def __init__(self, _func: Callable, autodiff_mode, _classkernel=False):
        """Wrap ``_func`` as a kernel compiled in ``autodiff_mode``.

        Args:
            _func: The Python function to compile.
            autodiff_mode: One of ``AutodiffMode.NONE/VALIDATION/FORWARD/REVERSE``
                (checked with an assert).
            _classkernel: True when ``_func`` is a method. Its unannotated first
                parameter is then treated as ``template()``.
        """
        self.func = _func
        self.kernel_counter = Kernel.counter
        Kernel.counter += 1
        assert autodiff_mode in (
            AutodiffMode.NONE,
            AutodiffMode.VALIDATION,
            AutodiffMode.FORWARD,
            AutodiffMode.REVERSE,
        )
        self.autodiff_mode = autodiff_mode
        self.grad: Kernel | None = None
        self.arguments: list[KernelArgument] = []
        self.return_type = None
        self.classkernel = _classkernel
        self.extract_arguments()
        self.template_slot_locations = []
        for i, arg in enumerate(self.arguments):
            if arg.annotation == template or isinstance(arg.annotation, template):
                self.template_slot_locations.append(i)
        self.mapper = TaichiCallableTemplateMapper(self.arguments, self.template_slot_locations)
        impl.get_runtime().kernels.append(self)
        self.reset()
        self.kernel_cpp = None
        self.compiled_kernels = {}
        self.has_print = False

    def ast_builder(self) -> ASTBuilder:
        """Return the AST builder of the kernel currently being compiled."""
        assert self.kernel_cpp is not None
        return self.kernel_cpp.ast_builder()

    def reset(self):
        """Re-bind to the current runtime and drop every compiled variant."""
        self.runtime = impl.get_runtime()
        self.compiled_kernels = {}

    def extract_arguments(self):
        """Validate the signature of ``self.func`` and fill ``return_type`` and ``arguments``.

        Raises:
            TaichiSyntaxError: On ``*args``/``**kwargs``, default values,
                keyword-only or positional-only parameters, missing or invalid
                annotations, or ``...`` in the return annotation.
        """
        sig = inspect.signature(self.func)
        if sig.return_annotation not in (inspect.Signature.empty, None):
            self.return_type = sig.return_annotation
            # ``tuple[a, b]`` return annotations become the tuple of their element types.
            if (
                isinstance(self.return_type, (types.GenericAlias, typing._GenericAlias))
                and self.return_type.__origin__ is tuple
            ):
                self.return_type = self.return_type.__args__
            if not isinstance(self.return_type, (list, tuple)):
                self.return_type = (self.return_type,)
            for return_type in self.return_type:
                if return_type is Ellipsis:
                    raise TaichiSyntaxError("Ellipsis is not supported in return type annotations")
        params = sig.parameters
        arg_names = params.keys()
        for i, arg_name in enumerate(arg_names):
            param = params[arg_name]
            if param.kind == inspect.Parameter.VAR_KEYWORD:
                raise TaichiSyntaxError("Taichi kernels do not support variable keyword parameters (i.e., **kwargs)")
            if param.kind == inspect.Parameter.VAR_POSITIONAL:
                raise TaichiSyntaxError("Taichi kernels do not support variable positional parameters (i.e., *args)")
            if param.default is not inspect.Parameter.empty:
                raise TaichiSyntaxError("Taichi kernels do not support default values for arguments")
            if param.kind == inspect.Parameter.KEYWORD_ONLY:
                raise TaichiSyntaxError("Taichi kernels do not support keyword parameters")
            if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
                raise TaichiSyntaxError('Taichi kernels only support "positional or keyword" parameters')
            annotation = param.annotation
            if param.annotation is inspect.Parameter.empty:
                if i == 0 and self.classkernel:  # The |self| parameter
                    annotation = template()
                else:
                    raise TaichiSyntaxError("Taichi kernels parameters must be type annotated")
            else:
                if isinstance(
                    annotation,
                    (
                        template,
                        ndarray_type.NdarrayType,
                        texture_type.TextureType,
                        texture_type.RWTextureType,
                    ),
                ):
                    pass
                elif id(annotation) in primitive_types.type_ids:
                    pass
                elif isinstance(annotation, sparse_matrix_builder):
                    pass
                elif isinstance(annotation, MatrixType):
                    pass
                elif isinstance(annotation, StructType):
                    pass
                elif isinstance(annotation, ArgPackType):
                    pass
                elif annotation == template:
                    pass
                elif isinstance(annotation, type) and dataclasses.is_dataclass(annotation):
                    pass
                else:
                    raise TaichiSyntaxError(f"Invalid type annotation (argument {i}) of Taichi kernel: {annotation}")
            self.arguments.append(KernelArgument(annotation, param.name, param.default))

    def materialize(self, key, args: list[Any], arg_features):
        """Compile the variant identified by ``key`` unless it is already cached.

        Args:
            key: Cache key ``(func, instance_id, autodiff_mode)``. ``None`` means instance 0.
            args: The call arguments, forwarded to ``_get_tree_and_ctx``.
            arg_features: The feature tuple produced by ``self.mapper``.
        """
        if key is None:
            key = (self.func, 0, self.autodiff_mode)
        self.runtime.materialize()

        if key in self.compiled_kernels:
            return

        kernel_name = f"{self.func.__name__}_c{self.kernel_counter}_{key[1]}"
        _logging.trace(f"Compiling kernel {kernel_name} in {self.autodiff_mode}...")

        tree, ctx = _get_tree_and_ctx(
            self,
            args=args,
            excluded_parameters=self.template_slot_locations,
            arg_features=arg_features,
        )

        if self.autodiff_mode != AutodiffMode.NONE:
            KernelSimplicityASTChecker(self.func).visit(tree)

        # Do not change the name of 'taichi_ast_generator'
        # The warning system needs this identifier to remove unnecessary messages
        # NOTE(review): ``kernel_cxx`` is supplied by ``prog.create_kernel`` and the
        # ``Kernel`` annotation here is doubtful; TODO confirm its actual type.
        def taichi_ast_generator(kernel_cxx: Kernel):
            nonlocal tree
            if self.runtime.inside_kernel:
                raise TaichiSyntaxError(
                    "Kernels cannot call other kernels. I.e., nested kernels are not allowed. "
                    "Please check if you have direct/indirect invocation of kernels within kernels. "
                    "Note that some methods provided by the Taichi standard library may invoke kernels, "
                    "and please move their invocations to Python-scope."
                )
            self.kernel_cpp = kernel_cxx
            self.runtime.inside_kernel = True
            self.runtime._current_kernel = self
            assert self.runtime.compiling_callable is None
            self.runtime.compiling_callable = kernel_cxx
            try:
                ctx.ast_builder = kernel_cxx.ast_builder()

                def ast_to_dict(node):
                    if isinstance(node, ast.AST):
                        fields = {k: ast_to_dict(v) for k, v in ast.iter_fields(node)}
                        return {
                            "type": node.__class__.__name__,
                            "fields": fields,
                            "lineno": getattr(node, "lineno", None),
                            "col_offset": getattr(node, "col_offset", None),
                        }
                    if isinstance(node, list):
                        return [ast_to_dict(x) for x in node]
                    return node  # Basic types (str, int, None, etc.)

                # Debug aid: with TI_DUMP_AST=1 the Python AST is written to /tmp/ast as text and JSON.
                if os.environ.get("TI_DUMP_AST", "") == "1":
                    target_dir = pathlib.Path("/tmp/ast")
                    target_dir.mkdir(parents=True, exist_ok=True)

                    start = time.time()
                    ast_str = ast.dump(tree, indent=2)
                    output_file = target_dir / f"{kernel_name}_ast.txt"
                    output_file.write_text(ast_str)
                    elapsed_txt = time.time() - start

                    start = time.time()
                    json_str = json.dumps(ast_to_dict(tree), indent=2)
                    output_file = target_dir / f"{kernel_name}_ast.json"
                    output_file.write_text(json_str)
                    elapsed_json = time.time() - start

                    output_file = target_dir / f"{kernel_name}_gen_time.json"
                    output_file.write_text(
                        json.dumps({"elapsed_txt": elapsed_txt, "elapsed_json": elapsed_json}, indent=2)
                    )
                struct_locals = extract_struct_locals_from_context(ctx)
                tree = unpack_ndarray_struct(tree, struct_locals=struct_locals)
                transform_tree(tree, ctx)
                if not ctx.is_real_function:
                    if self.return_type and ctx.returned != ReturnStatus.ReturnedValue:
                        raise TaichiSyntaxError("Kernel has a return type but does not have a return statement")
            finally:
                self.runtime.inside_kernel = False
                self.runtime._current_kernel = None
                self.runtime.compiling_callable = None

        taichi_kernel = impl.get_runtime().prog.create_kernel(taichi_ast_generator, kernel_name, self.autodiff_mode)
        assert key not in self.compiled_kernels
        self.compiled_kernels[key] = taichi_kernel

    def launch_kernel(self, t_kernel, *args):
        """Bind ``args`` to ``t_kernel``'s launch context, compile, launch, and collect results.

        Returns:
            ``None`` if the kernel has no return type. Otherwise the single return
            value, or a list of values when several return types are declared.

        Raises:
            TaichiRuntimeError: If the arguments need more than ``max_arg_num`` slots.
            TaichiRuntimeTypeError, ValueError: If an argument cannot be converted.
        """
        assert len(args) == len(self.arguments), f"{len(self.arguments)} arguments needed but {len(args)} provided"

        # Temporaries that must outlive the launch (e.g. contiguous copies of numpy arrays).
        tmps = []
        # Copy-back actions run after the launch (e.g. host copies -> original arrays).
        callbacks = []

        actual_argument_slot = 0
        launch_ctx = t_kernel.make_launch_context()
        max_arg_num = 64
        exceed_max_arg_num = False

        def set_arg_ndarray(indices, v):
            v_primal = v.arr
            v_grad = v.grad.arr if v.grad is not None else None
            if v_grad is None:
                launch_ctx.set_arg_ndarray(indices, v_primal)
            else:
                launch_ctx.set_arg_ndarray_with_grad(indices, v_primal, v_grad)

        def set_arg_texture(indices, v):
            launch_ctx.set_arg_texture(indices, v.tex)

        def set_arg_rw_texture(indices, v):
            launch_ctx.set_arg_rw_texture(indices, v.tex)

        def set_arg_ext_array(indices, v, needed):
            """Bind an external (numpy, torch or paddle) array ``v`` at ``indices``."""
            # Element shapes are already specialized in Taichi codegen.
            # The shape information for element dims are no longer needed.
            # Therefore we strip the element shapes from the shape vector,
            # so that it only holds "real" array shapes.
            is_soa = needed.layout == Layout.SOA
            array_shape = v.shape
            if functools.reduce(operator.mul, array_shape, 1) > np.iinfo(np.int32).max:
                warnings.warn("Ndarray index might be out of int32 boundary but int64 indexing is not supported yet.")
            if needed.dtype is None or id(needed.dtype) in primitive_types.type_ids:
                element_dim = 0
            else:
                element_dim = needed.dtype.ndim
                array_shape = v.shape[element_dim:] if is_soa else v.shape[:-element_dim]

            if isinstance(v, np.ndarray):
                # numpy
                if v.flags.c_contiguous:
                    launch_ctx.set_arg_external_array_with_shape(indices, int(v.ctypes.data), v.nbytes, array_shape, 0)
                elif v.flags.f_contiguous:
                    # TODO: A better way that avoids copying is saving strides info.
                    tmp = np.ascontiguousarray(v)
                    # Purpose: DO NOT GC |tmp|!
                    tmps.append(tmp)

                    def callback(original, updated):
                        np.copyto(original, np.asfortranarray(updated))

                    callbacks.append(functools.partial(callback, v, tmp))
                    launch_ctx.set_arg_external_array_with_shape(
                        indices, int(tmp.ctypes.data), tmp.nbytes, array_shape, 0
                    )
                else:
                    raise ValueError(
                        "Non contiguous numpy arrays are not supported, please call np.ascontiguousarray(arr) "
                        "before passing it into taichi kernel."
                    )
                return

            # Check each framework in turn, so that a paddle tensor is still accepted
            # when torch is also installed.
            if has_pytorch():
                import torch  # pylint: disable=C0415

                if isinstance(v, torch.Tensor):
                    if not v.is_contiguous():
                        raise ValueError(
                            "Non contiguous tensors are not supported, please call tensor.contiguous() before "
                            "passing it into taichi kernel."
                        )
                    taichi_arch = self.runtime.prog.config().arch

                    def get_torch_call_back(u, w):
                        def call_back():
                            u.copy_(w)

                        return call_back

                    # FIXME: only allocate when launching grad kernel
                    if v.requires_grad and v.grad is None:
                        v.grad = torch.zeros_like(v)

                    if v.requires_grad:
                        if not isinstance(v.grad, torch.Tensor):
                            raise ValueError(
                                f"Expecting torch.Tensor for gradient tensor, but getting {v.grad.__class__.__name__} instead"
                            )
                        if not v.grad.is_contiguous():
                            raise ValueError(
                                "Non contiguous gradient tensors are not supported, please call tensor.grad.contiguous() before passing it into taichi kernel."
                            )

                    tmp = v
                    if (str(v.device) != "cpu") and not (
                        str(v.device).startswith("cuda") and taichi_arch == _ti_core.Arch.cuda
                    ):
                        # Getting a torch CUDA tensor on Taichi non-cuda arch:
                        # We just replace it with a CPU tensor and by the end of kernel execution we'll use the
                        # callback to copy the values back to the original CUDA tensor.
                        host_v = v.to(device="cpu", copy=True)
                        tmp = host_v
                        callbacks.append(get_torch_call_back(v, host_v))

                    # NOTE(review): when ``tmp`` is a host copy, the gradient pointer below still
                    # refers to ``v.grad`` on the original device. Confirm this is intended.
                    launch_ctx.set_arg_external_array_with_shape(
                        indices,
                        int(tmp.data_ptr()),
                        tmp.element_size() * tmp.nelement(),
                        array_shape,
                        int(v.grad.data_ptr()) if v.grad is not None else 0,
                    )
                    return

            if has_paddle():
                import paddle  # pylint: disable=C0415 # type: ignore

                if isinstance(v, paddle.Tensor):
                    # For now, paddle.fluid.core.Tensor._ptr() is only available on develop branch
                    def get_paddle_call_back(u, w):
                        def call_back():
                            u.copy_(w, False)

                        return call_back

                    tmp = v.value().get_tensor()
                    taichi_arch = self.runtime.prog.config().arch
                    if v.place.is_gpu_place():
                        if taichi_arch != _ti_core.Arch.cuda:
                            # Paddle cuda tensor on Taichi non-cuda arch
                            host_v = v.cpu()
                            tmp = host_v.value().get_tensor()
                            callbacks.append(get_paddle_call_back(v, host_v))
                    elif v.place.is_cpu_place():
                        if taichi_arch == _ti_core.Arch.cuda:
                            # Paddle cpu tensor on Taichi cuda arch
                            gpu_v = v.cuda()
                            tmp = gpu_v.value().get_tensor()
                            callbacks.append(get_paddle_call_back(v, gpu_v))
                    else:
                        # Paddle do support many other backends like XPU, NPU, MLU, IPU
                        raise TaichiRuntimeTypeError(f"Taichi do not support backend {v.place} that Paddle support")
                    launch_ctx.set_arg_external_array_with_shape(
                        indices, int(tmp._ptr()), v.element_size() * v.size, array_shape, 0
                    )
                    return

            raise TaichiRuntimeTypeError(f"Argument {type(v)} cannot be converted into required type {needed}")

        def set_arg_matrix(indices, v, needed):
            def cast_float(x):
                if not isinstance(x, (int, float, np.integer, np.floating)):
                    raise TaichiRuntimeTypeError(
                        f"Argument {type(x)} cannot be converted into required type {needed.dtype}"
                    )
                return float(x)

            def cast_int(x):
                if not isinstance(x, (int, np.integer)):
                    raise TaichiRuntimeTypeError(
                        f"Argument {type(x)} cannot be converted into required type {needed.dtype}"
                    )
                return int(x)

            if needed.dtype in primitive_types.real_types:
                cast_func = cast_float
            elif needed.dtype in primitive_types.integer_types:
                cast_func = cast_int
            else:
                raise ValueError(f"Matrix dtype {needed.dtype} is not integer type or real type.")

            if needed.ndim == 2:
                v = [cast_func(v[i, j]) for i in range(needed.n) for j in range(needed.m)]
            else:
                v = [cast_func(v[i]) for i in range(needed.n)]
            v = needed(*v)
            needed.set_kernel_struct_args(v, launch_ctx, indices)

        def set_arg_sparse_matrix_builder(indices, v):
            # Pass only the base pointer of the ti.types.sparse_matrix_builder() argument
            launch_ctx.set_arg_uint(indices, v._get_ndarray_addr())

        # Binders for values nested in argpacks, deferred until all regular params are bound.
        set_later_list = []

        def recursive_set_args(needed_arg_type, provided_arg_type, v, indices):
            """
            Returns the number of kernel args set
            e.g. templates don't set kernel args, so returns 0
            a single ndarray is 1 kernel arg, so returns 1
            a dataclass with 3 ndarray fields would set 3 kernel args, so return 3
            """
            nonlocal actual_argument_slot, exceed_max_arg_num, set_later_list
            in_argpack = len(indices) > 1
            if actual_argument_slot >= max_arg_num:
                exceed_max_arg_num = True
                return 0
            actual_argument_slot += 1
            if isinstance(needed_arg_type, ArgPackType):
                if not isinstance(v, ArgPack):
                    raise TaichiRuntimeTypeError.get(indices, str(needed_arg_type), str(provided_arg_type))
                idx_new = 0
                for name, anno in needed_arg_type.members.items():
                    idx_new += recursive_set_args(anno, type(v[name]), v[name], indices + (idx_new,))
                launch_ctx.set_arg_argpack(indices, v._ArgPack__argpack)  # type: ignore
                return 1
            # Note: do not use sth like "needed == f32". That would be slow.
            if id(needed_arg_type) in primitive_types.real_type_ids:
                if not isinstance(v, (float, int, np.floating, np.integer)):
                    raise TaichiRuntimeTypeError.get(indices, needed_arg_type.to_string(), provided_arg_type)
                if in_argpack:
                    return 1
                launch_ctx.set_arg_float(indices, float(v))
                return 1
            if id(needed_arg_type) in primitive_types.integer_type_ids:
                if not isinstance(v, (int, np.integer)):
                    raise TaichiRuntimeTypeError.get(indices, needed_arg_type.to_string(), provided_arg_type)
                if in_argpack:
                    return 1
                if is_signed(cook_dtype(needed_arg_type)):
                    launch_ctx.set_arg_int(indices, int(v))
                else:
                    launch_ctx.set_arg_uint(indices, int(v))
                return 1
            if isinstance(needed_arg_type, sparse_matrix_builder):
                if in_argpack:
                    set_later_list.append((set_arg_sparse_matrix_builder, (v,)))
                    return 0
                set_arg_sparse_matrix_builder(indices, v)
                return 1
            if dataclasses.is_dataclass(needed_arg_type):
                assert provided_arg_type == needed_arg_type
                # Each field is bound as its own consecutive top-level parameter.
                idx = 0
                for field in dataclasses.fields(needed_arg_type):
                    assert not isinstance(field.type, str)
                    field_value = getattr(v, field.name)
                    idx += recursive_set_args(field.type, field.type, field_value, (indices[0] + idx,))
                return idx
            if isinstance(needed_arg_type, ndarray_type.NdarrayType) and isinstance(v, taichi.lang._ndarray.Ndarray):
                if in_argpack:
                    set_later_list.append((set_arg_ndarray, (v,)))
                    return 0
                set_arg_ndarray(indices, v)
                return 1
            if isinstance(needed_arg_type, texture_type.TextureType) and isinstance(v, taichi.lang._texture.Texture):
                if in_argpack:
                    set_later_list.append((set_arg_texture, (v,)))
                    return 0
                set_arg_texture(indices, v)
                return 1
            if isinstance(needed_arg_type, texture_type.RWTextureType) and isinstance(v, taichi.lang._texture.Texture):
                if in_argpack:
                    set_later_list.append((set_arg_rw_texture, (v,)))
                    return 0
                set_arg_rw_texture(indices, v)
                return 1
            if isinstance(needed_arg_type, ndarray_type.NdarrayType):
                if in_argpack:
                    set_later_list.append((set_arg_ext_array, (v, needed_arg_type)))
                    return 0
                set_arg_ext_array(indices, v, needed_arg_type)
                return 1
            if isinstance(needed_arg_type, MatrixType):
                if in_argpack:
                    return 1
                set_arg_matrix(indices, v, needed_arg_type)
                return 1
            if isinstance(needed_arg_type, StructType):
                if in_argpack:
                    return 1
                if not isinstance(v, needed_arg_type):
                    raise TaichiRuntimeTypeError(
                        f"Argument {provided_arg_type} cannot be converted into required type {needed_arg_type}"
                    )
                needed_arg_type.set_kernel_struct_args(v, launch_ctx, indices)
                return 1
            if needed_arg_type == template or isinstance(needed_arg_type, template):
                return 0
            raise ValueError(f"Argument type mismatch. Expecting {needed_arg_type}, got {type(v)}.")

        template_num = 0
        i_out = 0
        for i_in, val in enumerate(args):
            needed_ = self.arguments[i_in].annotation
            if needed_ == template or isinstance(needed_, template):
                template_num += 1
                i_out += 1
                continue
            i_out += recursive_set_args(needed_, type(val), val, (i_out - template_num,))

        # Report the slot overflow before binding deferred arguments, whose indices
        # are not meaningful once some arguments were skipped.
        if exceed_max_arg_num:
            raise TaichiRuntimeError(
                f"The number of elements in kernel arguments is too big! Do not exceed {max_arg_num} on {_ti_core.arch_name(impl.current_cfg().arch)} backend."
            )

        # Deferred (argpack-nested) values are bound after all regular parameters.
        # ``i_out - template_num`` counts the regular parameters actually bound,
        # including the extra ones produced by dataclass arguments. The old
        # ``len(args) - template_num`` undercounted them.
        first_deferred_index = i_out - template_num
        for offset, (set_arg_func, params) in enumerate(set_later_list):
            set_arg_func((first_deferred_index + offset,), *params)

        try:
            prog = impl.get_runtime().prog
            # Compile kernel (& Online Cache & Offline Cache)
            compiled_kernel_data = prog.compile_kernel(prog.config(), prog.get_device_caps(), t_kernel)
            # Launch kernel
            prog.launch_kernel(compiled_kernel_data, launch_ctx)
        except Exception as e:
            e = handle_exception_from_cpp(e)
            if impl.get_runtime().print_full_traceback:
                raise e
            raise e from None

        ret = None
        ret_dt = self.return_type
        has_ret = ret_dt is not None

        if has_ret or self.has_print:
            runtime_ops.sync()

        if has_ret:
            ret = [self.construct_kernel_ret(launch_ctx, ret_type, (i,)) for i, ret_type in enumerate(ret_dt)]
            if len(ret_dt) == 1:
                ret = ret[0]
        for c in callbacks:
            c()

        return ret

    def construct_kernel_ret(self, launch_ctx, ret_type, index=()):
        """Read the return value at ``index`` from ``launch_ctx`` as ``ret_type``.

        Raises:
            TaichiRuntimeTypeError: If ``ret_type`` is neither compound, integer, nor real.
        """
        if isinstance(ret_type, CompoundType):
            return ret_type.from_kernel_struct_ret(launch_ctx, index)
        if ret_type in primitive_types.integer_types:
            if is_signed(cook_dtype(ret_type)):
                return launch_ctx.get_struct_ret_int(index)
            return launch_ctx.get_struct_ret_uint(index)
        if ret_type in primitive_types.real_types:
            return launch_ctx.get_struct_ret_float(index)
        raise TaichiRuntimeTypeError(f"Invalid return type on index={index}")

    def ensure_compiled(self, *args):
        """Compile (if needed) the variant selected by ``args`` and return its cache key."""
        instance_id, arg_features = self.mapper.lookup(args)
        key = (self.func, instance_id, self.autodiff_mode)
        self.materialize(key=key, args=args, arg_features=arg_features)
        return key

    # For small kernels (< 3us), the performance can be pretty sensitive to overhead in __call__
    # Thus this part needs to be fast. (i.e. < 3us on a 4 GHz x64 CPU)
    @_shell_pop_print
    def __call__(self, *args, **kwargs):
        args = _process_args(self, is_func=False, args=args, kwargs=kwargs)

        # Transform the primal kernel to forward mode grad kernel
        # then recover to primal when exiting the forward mode manager
        if self.runtime.fwd_mode_manager and not self.runtime.grad_replaced:
            # TODO: if we would like to compute 2nd-order derivatives by forward-on-reverse in a nested context manager fashion,
            # i.e., a `Tape` nested in the `FwdMode`, we can transform the kernels with `mode_original == AutodiffMode.REVERSE` only,
            # to avoid duplicate computation for 1st-order derivatives
            self.runtime.fwd_mode_manager.insert(self)

        # Both the class kernels and the plain-function kernels are unified now.
        # In both cases, |self.grad| is another Kernel instance that computes the
        # gradient. For class kernels, args[0] is always the kernel owner.

        # No need to capture grad kernels because they are already bound with their primal kernels
        if (
            self.autodiff_mode in (AutodiffMode.NONE, AutodiffMode.VALIDATION)
            and self.runtime.target_tape
            and not self.runtime.grad_replaced
        ):
            self.runtime.target_tape.insert(self, args)

        if self.autodiff_mode != AutodiffMode.NONE and impl.current_cfg().opt_level == 0:
            _logging.warn("""opt_level = 1 is enforced to enable gradient computation.""")
            impl.current_cfg().opt_level = 1
        key = self.ensure_compiled(*args)
        kernel_cpp = self.compiled_kernels[key]
        return self.launch_kernel(kernel_cpp, *args)
|
1214
|
+
|
1215
|
+
|
1216
|
+
# Regexes used by _inside_class to recognise that a kernel is being defined
# inside a class body. Consider:
#
#   @ti.data_oriented
#   class X:
#       @ti.kernel
#       def foo(self):
#           ...
#
# While ti.kernel runs, the |code_context| of the frame that executes the
# class statement is 'class X:' on Python 3.8+, but '@ti.data_oriented' on
# Python 3.7 and below. If the class lists bases (e.g. 'class X(object):'),
# every version reports the 'class X(object):' line. Both shapes are accepted.
_KERNEL_CLASS_STACKFRAME_STMT_RES = [
    re.compile(pattern)
    for pattern in (
        r"@(\w+\.)?data_oriented",  # decorator line, optional module prefix such as 'ti.'
        r"class ",  # the class statement itself
    )
]
|
1233
|
+
|
1234
|
+
|
1235
|
+
def _inside_class(level_of_class_stackframe):
    """Heuristically tell whether a decorator is being applied inside a class body.

    Looks at the source line currently executing in the frame
    ``level_of_class_stackframe`` levels up the stack, counted from this
    function's own frame (``sys._getframe`` semantics). It reports whether
    the stripped line starts with one of the patterns in
    ``_KERNEL_CLASS_STACKFRAME_STMT_RES``.

    Args:
        level_of_class_stackframe (int): stack depth handed to ``sys._getframe``.
            Callers must compute it for the exact call chain that reaches here.

    Returns:
        bool: True if the frame's current line matches a pattern. Returns
        False otherwise, including when the frame or its source line cannot
        be obtained. This is a best-effort check and never raises an ordinary
        ``Exception``.
    """
    # Keep sys._getframe directly in this function: wrapping it in another
    # helper would shift the depth being inspected.
    try:
        maybe_class_frame = sys._getframe(level_of_class_stackframe)
        statement_list = inspect.getframeinfo(maybe_class_frame).code_context
    except Exception:
        # Best effort only (e.g. ValueError when the stack is not deep enough).
        # Unlike a bare ``except:``, this lets KeyboardInterrupt/SystemExit propagate.
        return False
    if not statement_list:
        # code_context is None when the source is unavailable (REPL, exec'd code).
        return False
    first_statement = statement_list[0].strip()
    return any(pat.match(first_statement) for pat in _KERNEL_CLASS_STACKFRAME_STMT_RES)
|
1248
|
+
|
1249
|
+
|
1250
|
+
def _kernel_impl(_func: Callable, level_of_class_stackframe: int, verbose: bool = False):
    """Build the primal/adjoint ``Kernel`` pair for ``_func`` and wrap it for callers.

    Args:
        _func: the Python function being turned into a kernel.
        level_of_class_stackframe: stack depth of the frame to inspect. It is
            forwarded, plus one for this frame, to ``_inside_class`` to detect
            definitions inside a class body.
        verbose: if True, print whether ``_func`` was detected as a class kernel.

    Returns:
        A ``functools.wraps`` wrapper around ``_func`` carrying the attributes
        ``_is_wrapped_kernel``, ``_is_classkernel``, ``_primal`` and
        ``_adjoint``. Free-function kernels also carry ``grad``.
    """
    # Can decorators determine if a function is being defined inside a class?
    # https://stackoverflow.com/a/8793684/12003165
    # This call must stay directly in this function: _inside_class inspects
    # the stack by depth.
    is_classkernel = _inside_class(level_of_class_stackframe + 1)

    if verbose:
        print(f"kernel={_func.__name__} is_classkernel={is_classkernel}")
    primal = Kernel(_func, autodiff_mode=AutodiffMode.NONE, _classkernel=is_classkernel)
    adjoint = Kernel(_func, autodiff_mode=AutodiffMode.REVERSE, _classkernel=is_classkernel)
    # The tape reaches the gradient kernel through the primal's |grad| attribute.
    primal.grad = adjoint

    if not is_classkernel:

        @functools.wraps(_func)
        def wrapped(*args, **kwargs):
            try:
                return primal(*args, **kwargs)
            except (TaichiCompilationError, TaichiRuntimeError) as e:
                if impl.get_runtime().print_full_traceback:
                    raise e
                raise type(e)("\n" + str(e)) from None

        wrapped.grad = adjoint
    else:
        # Class kernels are bound to their owning instance only when looked
        # up through an instance; data_oriented's attribute hook then returns
        # a _BoundedDifferentiableMethod. Calling this wrapper directly means
        # that hook did not intercept the lookup.
        @functools.wraps(_func)
        def wrapped(*args, **kwargs):
            owner_cls = type(args[0])
            assert not hasattr(owner_cls, "_data_oriented")
            raise TaichiSyntaxError(f"Please decorate class {owner_cls.__name__} with @ti.data_oriented")

    for attr_name, attr_value in (
        ("_is_wrapped_kernel", True),
        ("_is_classkernel", is_classkernel),
        ("_primal", primal),
        ("_adjoint", adjoint),
    ):
        setattr(wrapped, attr_name, attr_value)
    return wrapped
|
1296
|
+
|
1297
|
+
|
1298
|
+
def kernel(fn: Callable):
    """Marks a function as a Taichi kernel.

    A Taichi kernel is a function written in Python that Taichi JIT-compiles
    into native CPU/GPU instructions (e.g. a series of CUDA kernels).
    Top-level ``for`` loops are automatically parallelized and distributed
    to either a CPU thread pool or massively parallel GPUs.

    The kernel's gradient kernel is generated automatically by the AutoDiff
    system.

    See also https://docs.taichi-lang.org/docs/syntax#kernel.

    Args:
        fn (Callable): the Python function to be decorated

    Returns:
        Callable: The decorated function

    Example::

        >>> x = ti.field(ti.i32, shape=(4, 8))
        >>>
        >>> @ti.kernel
        >>> def run():
        >>>     # Assigns all the elements of `x` in parallel.
        >>>     # A struct-for over a 2D field needs one loop index per dimension.
        >>>     for i, j in x:
        >>>         x[i, j] = i + j
    """
    # The depth 3 matches the direct call chain kernel -> _kernel_impl ->
    # _inside_class. Do not route this call through an extra wrapper, or the
    # class-body detection will inspect the wrong frame.
    return _kernel_impl(fn, level_of_class_stackframe=3)
|
1327
|
+
|
1328
|
+
|
1329
|
+
class _BoundedDifferentiableMethod:
    """A class kernel bound to the instance that owns it.

    Calling the object runs the primal kernel. :meth:`grad` runs the adjoint
    (reverse-mode) kernel. Both receive the same argument list: the owner is
    prepended unless the kernel was declared as a ``staticmethod``.
    """

    def __init__(self, kernel_owner, wrapped_kernel_func):
        """Bind ``wrapped_kernel_func`` (a ``@ti.kernel`` wrapper) to ``kernel_owner``.

        Raises:
            TaichiSyntaxError: if the owner's class is not marked ``_data_oriented``.
        """
        clsobj = type(kernel_owner)
        if not getattr(clsobj, "_data_oriented", False):
            raise TaichiSyntaxError(f"Please decorate class {clsobj.__name__} with @ti.data_oriented")
        self._kernel_owner = kernel_owner
        self._primal = wrapped_kernel_func._primal
        self._adjoint = wrapped_kernel_func._adjoint
        self._is_staticmethod = wrapped_kernel_func._is_staticmethod
        # Left as None here; data_oriented's attribute hook assigns the kernel name.
        self.__name__: str | None = None

    def _bind_args(self, args):
        """Return positional args, prefixed with the owner unless this is a staticmethod kernel."""
        if self._is_staticmethod:
            return args
        return (self._kernel_owner, *args)

    def __call__(self, *args, **kwargs):
        """Run the primal kernel, re-raising Taichi errors with a trimmed traceback."""
        try:
            return self._primal(*self._bind_args(args), **kwargs)
        except (TaichiCompilationError, TaichiRuntimeError) as e:
            if impl.get_runtime().print_full_traceback:
                raise e
            raise type(e)("\n" + str(e)) from None

    def grad(self, *args, **kwargs):
        """Run the adjoint kernel with the same argument binding as ``__call__``.

        Primal and adjoint are built from the same function with the same
        class-kernel flag, so a staticmethod kernel must not receive the
        owner here either.
        """
        return self._adjoint(*self._bind_args(args), **kwargs)
|
1352
|
+
|
1353
|
+
|
1354
|
+
def data_oriented(cls):
    """Marks a class as Taichi compatible.

    To allow for modularized code, Taichi provides this decorator so that
    Taichi kernels can be defined inside a class.

    See also https://docs.taichi-lang.org/docs/odop

    Example::

        >>> @ti.data_oriented
        >>> class TiArray:
        >>>     def __init__(self, n):
        >>>         self.x = ti.field(ti.f32, shape=n)
        >>>
        >>>     @ti.kernel
        >>>     def inc(self):
        >>>         for i in self.x:
        >>>             self.x[i] += 1.0
        >>>
        >>> a = TiArray(32)
        >>> a.inc()

    The decorator replaces ``cls.__getattribute__`` with a hook and sets
    ``cls._data_oriented = True``. The hook behaves as follows:

    - Properties defined directly on ``cls`` are resolved by calling their
      getter; every other attribute goes through normal lookup.
    - If the resulting object is a ``@ti.kernel`` wrapper for a class kernel,
      the hook returns a ``_BoundedDifferentiableMethod`` bound to the
      instance. For a property getter, it returns the result of calling it.

    Args:
        cls (Class): the class to be decorated

    Returns:
        The decorated class (the same class object, modified in place).
    """

    def _getattr(self, item):
        own_entry = cls.__dict__.get(item, None)
        entry_kind = own_entry.__class__
        is_property = entry_kind == property
        is_staticmethod = entry_kind == staticmethod

        attr = own_entry.fget if is_property else super(cls, self).__getattribute__(item)

        if hasattr(attr, "_is_wrapped_kernel"):
            kernel_func = attr.__func__ if inspect.ismethod(attr) else attr
            # Stored on the shared wrapper; _BoundedDifferentiableMethod reads it
            # to decide whether to pass the owner instance.
            kernel_func._is_staticmethod = is_staticmethod
            assert inspect.isfunction(kernel_func)
            if kernel_func._is_classkernel:
                bound = _BoundedDifferentiableMethod(self, kernel_func)
                bound.__name__ = kernel_func.__name__
                return bound() if is_property else bound

        return attr(self) if is_property else attr

    cls.__getattribute__ = _getattr
    cls._data_oriented = True

    return cls
|
1413
|
+
|
1414
|
+
|
1415
|
+
__all__ = [
    "data_oriented",
    "func",
    "kernel",
    "pyfunc",
    "real_func",
]
|