gstaichi 0.1.18.dev1__cp310-cp310-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv.h +2568 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv.hpp +2579 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-0.1.18.dev1.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.dist-info/METADATA +108 -0
- gstaichi-0.1.18.dev1.dist-info/RECORD +219 -0
- gstaichi-0.1.18.dev1.dist-info/WHEEL +5 -0
- gstaichi-0.1.18.dev1.dist-info/entry_points.txt +2 -0
- gstaichi-0.1.18.dev1.dist-info/licenses/LICENSE +201 -0
- gstaichi-0.1.18.dev1.dist-info/top_level.txt +1 -0
- taichi/__init__.py +44 -0
- taichi/__main__.py +5 -0
- taichi/_funcs.py +706 -0
- taichi/_kernels.py +420 -0
- taichi/_lib/__init__.py +3 -0
- taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +1401 -0
- taichi/_lib/c_api/include/taichi/taichi.h +29 -0
- taichi/_lib/c_api/include/taichi/taichi_core.h +1111 -0
- taichi/_lib/c_api/include/taichi/taichi_cpu.h +29 -0
- taichi/_lib/c_api/include/taichi/taichi_metal.h +72 -0
- taichi/_lib/c_api/include/taichi/taichi_platform.h +55 -0
- taichi/_lib/c_api/include/taichi/taichi_unity.h +64 -0
- taichi/_lib/c_api/include/taichi/taichi_vulkan.h +151 -0
- taichi/_lib/c_api/lib/libtaichi_c_api.dylib +0 -0
- taichi/_lib/c_api/runtime/libMoltenVK.dylib +0 -0
- taichi/_lib/c_api/runtime/runtime_arm64.bc +0 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +29 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +65 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +121 -0
- taichi/_lib/core/__init__.py +0 -0
- taichi/_lib/core/py.typed +0 -0
- taichi/_lib/core/taichi_python.cpython-310-darwin.so +0 -0
- taichi/_lib/core/taichi_python.pyi +3077 -0
- taichi/_lib/runtime/libMoltenVK.dylib +0 -0
- taichi/_lib/runtime/runtime_arm64.bc +0 -0
- taichi/_lib/utils.py +249 -0
- taichi/_logging.py +131 -0
- taichi/_main.py +552 -0
- taichi/_snode/__init__.py +5 -0
- taichi/_snode/fields_builder.py +189 -0
- taichi/_snode/snode_tree.py +34 -0
- taichi/_ti_module/__init__.py +3 -0
- taichi/_ti_module/cppgen.py +309 -0
- taichi/_ti_module/module.py +145 -0
- taichi/_version.py +1 -0
- taichi/_version_check.py +100 -0
- taichi/ad/__init__.py +3 -0
- taichi/ad/_ad.py +530 -0
- taichi/algorithms/__init__.py +3 -0
- taichi/algorithms/_algorithms.py +117 -0
- taichi/aot/__init__.py +12 -0
- taichi/aot/_export.py +28 -0
- taichi/aot/conventions/__init__.py +3 -0
- taichi/aot/conventions/gfxruntime140/__init__.py +38 -0
- taichi/aot/conventions/gfxruntime140/dr.py +244 -0
- taichi/aot/conventions/gfxruntime140/sr.py +613 -0
- taichi/aot/module.py +253 -0
- taichi/aot/utils.py +151 -0
- taichi/assets/.git +1 -0
- taichi/assets/Go-Regular.ttf +0 -0
- taichi/assets/static/imgs/ti_gallery.png +0 -0
- taichi/examples/minimal.py +28 -0
- taichi/experimental.py +16 -0
- taichi/graph/__init__.py +3 -0
- taichi/graph/_graph.py +292 -0
- taichi/lang/__init__.py +50 -0
- taichi/lang/_ndarray.py +348 -0
- taichi/lang/_ndrange.py +152 -0
- taichi/lang/_texture.py +172 -0
- taichi/lang/_wrap_inspect.py +189 -0
- taichi/lang/any_array.py +99 -0
- taichi/lang/argpack.py +411 -0
- taichi/lang/ast/__init__.py +5 -0
- taichi/lang/ast/ast_transformer.py +1806 -0
- taichi/lang/ast/ast_transformer_utils.py +328 -0
- taichi/lang/ast/checkers.py +106 -0
- taichi/lang/ast/symbol_resolver.py +57 -0
- taichi/lang/ast/transform.py +9 -0
- taichi/lang/common_ops.py +310 -0
- taichi/lang/exception.py +80 -0
- taichi/lang/expr.py +180 -0
- taichi/lang/field.py +464 -0
- taichi/lang/impl.py +1246 -0
- taichi/lang/kernel_arguments.py +157 -0
- taichi/lang/kernel_impl.py +1415 -0
- taichi/lang/matrix.py +1877 -0
- taichi/lang/matrix_ops.py +341 -0
- taichi/lang/matrix_ops_utils.py +190 -0
- taichi/lang/mesh.py +687 -0
- taichi/lang/misc.py +807 -0
- taichi/lang/ops.py +1489 -0
- taichi/lang/runtime_ops.py +13 -0
- taichi/lang/shell.py +35 -0
- taichi/lang/simt/__init__.py +5 -0
- taichi/lang/simt/block.py +94 -0
- taichi/lang/simt/grid.py +7 -0
- taichi/lang/simt/subgroup.py +191 -0
- taichi/lang/simt/warp.py +96 -0
- taichi/lang/snode.py +487 -0
- taichi/lang/source_builder.py +150 -0
- taichi/lang/struct.py +855 -0
- taichi/lang/util.py +381 -0
- taichi/linalg/__init__.py +8 -0
- taichi/linalg/matrixfree_cg.py +310 -0
- taichi/linalg/sparse_cg.py +59 -0
- taichi/linalg/sparse_matrix.py +303 -0
- taichi/linalg/sparse_solver.py +123 -0
- taichi/math/__init__.py +11 -0
- taichi/math/_complex.py +204 -0
- taichi/math/mathimpl.py +886 -0
- taichi/profiler/__init__.py +6 -0
- taichi/profiler/kernel_metrics.py +260 -0
- taichi/profiler/kernel_profiler.py +592 -0
- taichi/profiler/memory_profiler.py +15 -0
- taichi/profiler/scoped_profiler.py +36 -0
- taichi/shaders/Circles_vk.frag +29 -0
- taichi/shaders/Circles_vk.vert +45 -0
- taichi/shaders/Circles_vk_frag.spv +0 -0
- taichi/shaders/Circles_vk_vert.spv +0 -0
- taichi/shaders/Lines_vk.frag +9 -0
- taichi/shaders/Lines_vk.vert +11 -0
- taichi/shaders/Lines_vk_frag.spv +0 -0
- taichi/shaders/Lines_vk_vert.spv +0 -0
- taichi/shaders/Mesh_vk.frag +71 -0
- taichi/shaders/Mesh_vk.vert +68 -0
- taichi/shaders/Mesh_vk_frag.spv +0 -0
- taichi/shaders/Mesh_vk_vert.spv +0 -0
- taichi/shaders/Particles_vk.frag +95 -0
- taichi/shaders/Particles_vk.vert +73 -0
- taichi/shaders/Particles_vk_frag.spv +0 -0
- taichi/shaders/Particles_vk_vert.spv +0 -0
- taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
- taichi/shaders/SceneLines_vk.frag +9 -0
- taichi/shaders/SceneLines_vk.vert +12 -0
- taichi/shaders/SceneLines_vk_frag.spv +0 -0
- taichi/shaders/SceneLines_vk_vert.spv +0 -0
- taichi/shaders/SetImage_vk.frag +21 -0
- taichi/shaders/SetImage_vk.vert +15 -0
- taichi/shaders/SetImage_vk_frag.spv +0 -0
- taichi/shaders/SetImage_vk_vert.spv +0 -0
- taichi/shaders/Triangles_vk.frag +16 -0
- taichi/shaders/Triangles_vk.vert +29 -0
- taichi/shaders/Triangles_vk_frag.spv +0 -0
- taichi/shaders/Triangles_vk_vert.spv +0 -0
- taichi/shaders/lines2quad_vk_comp.spv +0 -0
- taichi/sparse/__init__.py +3 -0
- taichi/sparse/_sparse_grid.py +77 -0
- taichi/tools/__init__.py +12 -0
- taichi/tools/diagnose.py +124 -0
- taichi/tools/np2ply.py +364 -0
- taichi/tools/vtk.py +38 -0
- taichi/types/__init__.py +19 -0
- taichi/types/annotations.py +47 -0
- taichi/types/compound_types.py +90 -0
- taichi/types/enums.py +49 -0
- taichi/types/ndarray_type.py +147 -0
- taichi/types/primitive_types.py +203 -0
- taichi/types/quant.py +88 -0
- taichi/types/texture_type.py +85 -0
- taichi/types/utils.py +13 -0
taichi/lang/impl.py
ADDED
@@ -0,0 +1,1246 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import numbers
|
4
|
+
from types import FunctionType, MethodType
|
5
|
+
from typing import Any, Iterable, Sequence
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
|
9
|
+
from taichi._lib import core as _ti_core
|
10
|
+
from taichi._lib.core.taichi_python import (
|
11
|
+
DataType,
|
12
|
+
Function,
|
13
|
+
Program,
|
14
|
+
)
|
15
|
+
from taichi._snode.fields_builder import FieldsBuilder
|
16
|
+
from taichi.lang._ndarray import ScalarNdarray
|
17
|
+
from taichi.lang._ndrange import GroupedNDRange, _Ndrange
|
18
|
+
from taichi.lang._texture import RWTextureAccessor
|
19
|
+
from taichi.lang.any_array import AnyArray
|
20
|
+
from taichi.lang.exception import (
|
21
|
+
TaichiCompilationError,
|
22
|
+
TaichiRuntimeError,
|
23
|
+
TaichiSyntaxError,
|
24
|
+
TaichiTypeError,
|
25
|
+
)
|
26
|
+
from taichi.lang.expr import Expr, make_expr_group
|
27
|
+
from taichi.lang.field import Field, ScalarField
|
28
|
+
from taichi.lang.kernel_arguments import SparseMatrixProxy
|
29
|
+
from taichi.lang.kernel_impl import Kernel
|
30
|
+
from taichi.lang.matrix import (
|
31
|
+
Matrix,
|
32
|
+
MatrixField,
|
33
|
+
MatrixNdarray,
|
34
|
+
MatrixType,
|
35
|
+
Vector,
|
36
|
+
VectorNdarray,
|
37
|
+
make_matrix,
|
38
|
+
)
|
39
|
+
from taichi.lang.mesh import (
|
40
|
+
ConvType,
|
41
|
+
MeshElementFieldProxy,
|
42
|
+
MeshInstance,
|
43
|
+
MeshRelationAccessProxy,
|
44
|
+
MeshReorderedMatrixFieldProxy,
|
45
|
+
MeshReorderedScalarFieldProxy,
|
46
|
+
element_type_name,
|
47
|
+
)
|
48
|
+
from taichi.lang.simt.block import SharedArray
|
49
|
+
from taichi.lang.snode import SNode
|
50
|
+
from taichi.lang.struct import Struct, StructField, _IntermediateStruct
|
51
|
+
from taichi.lang.util import (
|
52
|
+
cook_dtype,
|
53
|
+
get_traceback,
|
54
|
+
is_taichi_class,
|
55
|
+
python_scope,
|
56
|
+
taichi_scope,
|
57
|
+
warning,
|
58
|
+
)
|
59
|
+
from taichi.types.enums import SNodeGradType
|
60
|
+
from taichi.types.primitive_types import (
|
61
|
+
all_types,
|
62
|
+
f16,
|
63
|
+
f32,
|
64
|
+
f64,
|
65
|
+
i32,
|
66
|
+
i64,
|
67
|
+
u8,
|
68
|
+
u32,
|
69
|
+
u64,
|
70
|
+
)
|
71
|
+
|
72
|
+
|
73
|
+
@taichi_scope
def expr_init_shared_array(shape, element_type):
    """Allocate a shared array through the AST builder of the active callable.

    Args:
        shape: Forwarded unchanged to ``expr_alloca_shared_array``.
        element_type: Forwarded unchanged to ``expr_alloca_shared_array``.

    Returns:
        Whatever ``expr_alloca_shared_array`` returns.
    """
    runtime = get_runtime()
    active_callable = runtime.compiling_callable
    assert active_callable is not None
    builder = active_callable.ast_builder()
    # Debug info is derived from the innermost entry of the source-info stack.
    dbg_info = _ti_core.DebugInfo(runtime.get_current_src_info())
    return builder.expr_alloca_shared_array(shape, element_type, dbg_info)
|
80
|
+
|
81
|
+
|
82
|
+
@taichi_scope
def expr_init(rhs):
    """Build the value bound to a freshly initialized Taichi-scope variable.

    Dispatches on the type of ``rhs``:

    * ``None`` -> a new ``Expr`` wrapping ``expr_alloca``.
    * ``Matrix`` -> a rebuilt matrix (via ``Matrix(...)`` when the instance has
      a ``_DIM`` attribute, otherwise via ``make_matrix``).
    * ``Struct`` -> a new ``Struct`` built from ``rhs.to_dict(...)``.
    * ``list`` / ``tuple`` / ``dict`` -> same container type with every
      element (or dict value) passed through ``expr_init`` recursively.
    * ``SharedArray``, ``DataType``, ``Arch``, ``_Ndrange``, mesh proxies and
      objects carrying a ``_data_oriented`` attribute -> returned as-is.
    * anything else -> an ``Expr`` wrapping ``expr_var`` of ``Expr(rhs)``.

    Args:
        rhs: The right-hand-side value of the initialization.

    Returns:
        The initialized value as described above.
    """
    runtime = get_runtime()
    active_callable = runtime.compiling_callable
    assert active_callable is not None

    if rhs is None:
        builder = active_callable.ast_builder()
        return Expr(builder.expr_alloca(_ti_core.DebugInfo(runtime.get_current_src_info())))

    if isinstance(rhs, Matrix):
        if hasattr(rhs, "_DIM"):
            return Matrix(*rhs.to_list(), ndim=rhs.ndim)
        return make_matrix(rhs.to_list())

    if isinstance(rhs, SharedArray):
        return rhs

    if isinstance(rhs, Struct):
        return Struct(rhs.to_dict(include_methods=True, include_ndim=True))

    # Plain Python containers are rebuilt element by element.
    if isinstance(rhs, list):
        return [expr_init(item) for item in rhs]
    if isinstance(rhs, tuple):
        return tuple(expr_init(item) for item in rhs)
    if isinstance(rhs, dict):
        return {key: expr_init(val) for key, val in rhs.items()}

    # These values are handed back untouched.
    passthrough_types = (
        _ti_core.DataType,
        _ti_core.Arch,
        _Ndrange,
        MeshElementFieldProxy,
        MeshRelationAccessProxy,
    )
    if isinstance(rhs, passthrough_types) or hasattr(rhs, "_data_oriented"):
        return rhs

    builder = active_callable.ast_builder()
    return Expr(builder.expr_var(Expr(rhs).ptr, _ti_core.DebugInfo(runtime.get_current_src_info())))
|
121
|
+
|
122
|
+
|
123
|
+
@taichi_scope
def expr_init_func(rhs):  # temporary solution to allow passing in fields as arguments
    """Initialize a function argument, letting ``Field`` objects through as-is.

    Args:
        rhs: The argument value.

    Returns:
        ``rhs`` itself when it is a ``Field``; otherwise ``expr_init(rhs)``.
    """
    return rhs if isinstance(rhs, Field) else expr_init(rhs)
|
128
|
+
|
129
|
+
|
130
|
+
def begin_frontend_struct_for(ast_builder, group, loop_range):
    """Open a struct-for loop over ``loop_range`` on ``ast_builder``.

    Args:
        ast_builder: Builder whose ``begin_frontend_struct_for_on_*`` method
            is invoked.
        group: Loop index group; ``group.size()`` must equal the number of
            dimensions of ``loop_range``.
        loop_range: The object being looped over.

    Raises:
        TypeError: If ``loop_range`` is not one of the supported types.
        IndexError: If the number of loop indices differs from
            ``len(loop_range.shape)``.
    """
    external_types = (AnyArray, RWTextureAccessor)
    if not isinstance(loop_range, external_types + (Field, SNode, _Root)):
        raise TypeError(
            f"Cannot loop over the object {type(loop_range)} in Taichi scope. Only Taichi fields (via template) or dense arrays (via types.ndarray) are supported."
        )

    num_indices = group.size()
    num_dims = len(loop_range.shape)
    if num_indices != num_dims:
        raise IndexError(
            "Number of struct-for indices does not match loop variable dimensionality "
            f"({num_indices} != {num_dims}). Maybe you wanted to "
            'use "for I in ti.grouped(x)" to group all indices into a single vector I?'
        )

    dbg_info = _ti_core.DebugInfo(get_runtime().get_current_src_info())
    # AnyArray / RWTextureAccessor take the external-tensor entry point; every
    # other accepted type takes the SNode entry point.
    if isinstance(loop_range, external_types):
        begin_loop = ast_builder.begin_frontend_struct_for_on_external_tensor
    else:
        begin_loop = ast_builder.begin_frontend_struct_for_on_snode
    begin_loop(group, loop_range._loop_range(), dbg_info)
|
146
|
+
|
147
|
+
|
148
|
+
def begin_frontend_if(ast_builder, cond, stmt_dbg_info):
    """Open a frontend ``if`` statement on ``ast_builder``.

    Args:
        ast_builder: Builder receiving the ``begin_frontend_if`` call; must
            not be ``None``.
        cond: The condition; wrapped in ``Expr`` before being handed over.
        stmt_dbg_info: Debug info forwarded unchanged to the builder.

    Raises:
        ValueError: If ``is_taichi_class(cond)`` is true, since such a value
            has no single truth value.
    """
    assert ast_builder is not None
    if not is_taichi_class(cond):
        ast_builder.begin_frontend_if(Expr(cond).ptr, stmt_dbg_info)
        return
    raise ValueError(
        "The truth value of vectors/matrices is ambiguous.\n"
        "Consider using `any` or `all` when comparing vectors/matrices:\n"
        " if all(x == y):\n"
        "or\n"
        " if any(x != y):\n"
    )
|
159
|
+
|
160
|
+
|
161
|
+
@taichi_scope
def _calc_slice(index, default_stop):
    """Expand a constant ``slice`` into the explicit list of indices it covers.

    Args:
        index: A ``slice`` whose ``start``/``stop``/``step`` are compile-time
            constants (or ``None``).
        default_stop: Stop value used when ``index.stop`` is ``None``.

    Returns:
        ``list(range(start, stop, step))`` where a missing start is ``0``, a
        missing stop is ``default_stop`` and a missing step is ``1``.

    Raises:
        TaichiCompilationError: If any slice component is an ``Expr``.
    """

    def check_validity(x):
        # TODO(mzmzm): support variable in slice
        if isinstance(x, Expr):
            raise TaichiCompilationError(
                "Taichi does not support variables in slice now, please use constant instead of it."
            )

    # Only ``None`` means "not given". Testing truthiness instead would turn an
    # explicit stop of 0 (e.g. ``x[:0]``) into ``default_stop`` and silently
    # select every element.
    start = 0 if index.start is None else index.start
    stop = default_stop if index.stop is None else index.stop

    check_validity(start)
    check_validity(stop)
    check_validity(index.step)
    # A step of 0 keeps being treated like a missing step, as before.
    step = index.step or 1
    return list(range(start, stop, step))
|
174
|
+
|
175
|
+
|
176
|
+
def validate_subscript_index(value, index):
    """Reject negative constant integer indices for non-field values.

    Nothing is checked when ``value`` is a ``Field`` or when ``index`` is an
    ``Expr``. Iterables are checked element by element, and slices are checked
    on their ``start`` and ``stop``.

    Args:
        value: The object being subscripted.
        index: An index, a slice, or an iterable of those.

    Raises:
        TaichiSyntaxError: If a negative ``int`` index is found.
    """
    # Fields support negative indices; Expr indices are not inspected here.
    if isinstance(value, Field) or isinstance(index, Expr):
        return

    if isinstance(index, Iterable):
        for item in index:
            validate_subscript_index(value, item)

    if isinstance(index, slice):
        for bound in (index.start, index.stop):
            validate_subscript_index(value, bound)

    if isinstance(index, int) and index < 0:
        raise TaichiSyntaxError("Negative indices are not supported in Taichi kernels.")
|
194
|
+
|
195
|
+
|
196
|
+
@taichi_scope
def subscript(ast_builder, value, *_indices, skip_reordered=False):
    """Lower ``value[_indices]`` written in Taichi scope.

    Args:
        ast_builder: Ignored on entry; it is immediately rebound to the AST
            builder of the callable currently being compiled.
        value: The object being subscripted.
        *_indices: The indices. ``Matrix`` indices are flattened into their
            elements; ``slice`` indices are only accepted for tensor ``Expr``
            values.
        skip_reordered: When true, the mesh index conversion applied to
            reordered mesh field proxies is skipped (used for the recursive
            call made after the conversion has been done).

    Returns:
        For non-Taichi values, the result of ``value.__getitem__``; for
        ``StructField`` an ``_IntermediateStruct``; for the proxy types the
        result of their own ``subscript``; otherwise an ``Expr``.

    Raises:
        TaichiSyntaxError: If a slice is used on anything but a tensor
            ``Expr`` (or a negative constant index is rejected by
            ``validate_subscript_index``).
        RuntimeError: If a field (or its gradient) has not been placed.
    """
    dbg_info = _ti_core.DebugInfo(get_runtime().get_current_src_info())
    compiling_callable = get_runtime().compiling_callable
    assert compiling_callable is not None
    # The builder passed in by the caller is replaced by the active one.
    ast_builder = compiling_callable.ast_builder()
    # Directly evaluate in Python for non-Taichi types
    if not isinstance(
        value,
        (
            Expr,
            Field,
            AnyArray,
            SparseMatrixProxy,
            MeshElementFieldProxy,
            MeshRelationAccessProxy,
            SharedArray,
        ),
    ):
        # A single index is passed as-is rather than as a 1-tuple.
        if len(_indices) == 1:
            _indices = _indices[0]
        return value.__getitem__(_indices)

    has_slice = False

    # Flatten Matrix indices into their elements and note whether any slice
    # is present.
    flattened_indices = []
    for _index in _indices:
        if isinstance(_index, Matrix):
            ind = _index.to_list()
        elif isinstance(_index, slice):
            ind = [_index]
            has_slice = True
        else:
            ind = [_index]
        flattened_indices += ind
    indices = tuple(flattened_indices)
    validate_subscript_index(value, indices)

    # A lone ``None`` index is treated as "no indices".
    if len(indices) == 1 and indices[0] is None:
        indices = ()

    # The expression group is only built when there is no slice; slices are
    # expanded further below and are limited to tensor expressions.
    indices_expr_group = None
    if has_slice:
        if not (isinstance(value, Expr) and value.is_tensor()):
            raise TaichiSyntaxError(f"The type {type(value)} do not support index of slice type")
    else:
        indices_expr_group = make_expr_group(*indices)

    # These types implement their own subscript logic.
    if isinstance(value, SharedArray):
        return value.subscript(*indices)
    if isinstance(value, MeshElementFieldProxy):
        return value.subscript(*indices)
    if isinstance(value, MeshRelationAccessProxy):
        return value.subscript(*indices)
    # NOTE(review): these proxy types are not in the isinstance tuple at the
    # top, so they presumably derive from Field -- confirm in taichi.lang.mesh.
    if isinstance(value, (MeshReorderedScalarFieldProxy, MeshReorderedMatrixFieldProxy)) and not skip_reordered:
        assert len(indices) > 0
        # Convert the first index with ConvType.g2r, then subscript again with
        # the conversion disabled.
        reordered_index = tuple(
            [
                Expr(
                    ast_builder.mesh_index_conversion(
                        value.mesh_ptr, value.element_type, Expr(indices[0]).ptr, ConvType.g2r, dbg_info
                    )
                )
            ]
        )
        return subscript(ast_builder, value, *reordered_index, skip_reordered=True)
    if isinstance(value, SparseMatrixProxy):
        return value.subscript(*indices)
    if isinstance(value, Field):
        # The first field member is used to check placement.
        _var = value._get_field_members()[0].ptr
        snode = _var.snode()
        if snode is None:
            if _var.is_primal():
                raise RuntimeError(f"{_var.get_expr_name()} has not been placed.")
            else:
                raise RuntimeError(
                    f"Gradient {_var.get_expr_name()} has not been placed, check whether `needs_grad=True`"
                )

        # Slices were rejected above for non-Expr values, so the group exists.
        assert indices_expr_group is not None
        if isinstance(value, MatrixField):
            return Expr(ast_builder.expr_subscript(value.ptr, indices_expr_group, dbg_info))
        if isinstance(value, StructField):
            # Subscript every member and carry the struct methods along.
            entries = {k: subscript(ast_builder, v, *indices) for k, v in value._items}
            entries["__struct_methods"] = value.struct_methods
            return _IntermediateStruct(entries)
        return Expr(ast_builder.expr_subscript(_var, indices_expr_group, dbg_info))
    if isinstance(value, AnyArray):
        assert indices_expr_group is not None
        return Expr(ast_builder.expr_subscript(value.ptr, indices_expr_group, dbg_info))
    assert isinstance(value, Expr)
    # Index into TensorType
    # value: IndexExpression with ret_type = TensorType
    assert value.is_tensor()

    if has_slice:
        shape = value.get_shape()
        dim = len(shape)
        assert dim == len(indices)
        # Replace every slice by the explicit list of indices it covers.
        indices = [
            _calc_slice(index, shape[i]) if isinstance(index, slice) else index for i, index in enumerate(indices)
        ]
        if dim == 1:
            assert isinstance(indices[0], list)
            multiple_indices = [make_expr_group(i) for i in indices[0]]
            return_shape = (len(indices[0]),)
        else:
            # Only 1-D and 2-D tensors are handled here.
            assert dim == 2
            if isinstance(indices[0], list) and isinstance(indices[1], list):
                multiple_indices = [make_expr_group(i, j) for i in indices[0] for j in indices[1]]
                return_shape = (len(indices[0]), len(indices[1]))
            elif isinstance(indices[0], list):  # indices[1] is not list
                multiple_indices = [make_expr_group(i, indices[1]) for i in indices[0]]
                return_shape = (len(indices[0]),)
            else:  # indices[0] is not list while indices[1] is list
                multiple_indices = [make_expr_group(indices[0], j) for j in indices[1]]
                return_shape = (len(indices[1]),)
        return Expr(
            _ti_core.subscript_with_multiple_indices(
                value.ptr,
                multiple_indices,
                return_shape,
                dbg_info,
            )
        )
    return Expr(ast_builder.expr_subscript(value.ptr, indices_expr_group, dbg_info))
|
322
|
+
|
323
|
+
|
324
|
+
class SrcInfoGuard:
    """Context manager that keeps ``info`` on ``info_stack`` while active."""

    def __init__(self, info_stack, info):
        self.info_stack, self.info = info_stack, info

    def __enter__(self):
        # Push on entry; nothing is bound by ``with ... as``.
        self.info_stack.append(self.info)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Pop the top entry on exit; exceptions are not suppressed.
        self.info_stack.pop()
|
334
|
+
|
335
|
+
|
336
|
+
class PyTaichi:
|
337
|
+
def __init__(self, kernels=None):
|
338
|
+
self.materialized = False
|
339
|
+
self._prog: Program | None = None
|
340
|
+
self.src_info_stack = []
|
341
|
+
self.inside_kernel: bool = False
|
342
|
+
self.compiling_callable: Kernel | Function | None = None # pointer to instance of lang::Kernel/Function
|
343
|
+
self._current_kernel: Kernel | None = None
|
344
|
+
self.global_vars = []
|
345
|
+
self.grad_vars = []
|
346
|
+
self.dual_vars = []
|
347
|
+
self.matrix_fields = []
|
348
|
+
self.default_fp = f32
|
349
|
+
self.default_ip = i32
|
350
|
+
self.default_up = u32
|
351
|
+
self.print_full_traceback: bool = False
|
352
|
+
self.target_tape = None
|
353
|
+
self.fwd_mode_manager = None
|
354
|
+
self.grad_replaced = False
|
355
|
+
self.kernels = kernels or []
|
356
|
+
self._signal_handler_registry = None
|
357
|
+
self.unfinalized_fields_builder = {}
|
358
|
+
|
359
|
+
@property
|
360
|
+
def prog(self) -> Program:
|
361
|
+
if self._prog is None:
|
362
|
+
raise TaichiRuntimeError("_prog attribute not initialized. Maybe you forgot to call `ti.init()` first?")
|
363
|
+
return self._prog
|
364
|
+
|
365
|
+
@property
|
366
|
+
def current_kernel(self) -> Kernel:
|
367
|
+
if self._current_kernel is None:
|
368
|
+
raise TaichiRuntimeError(
|
369
|
+
"_pr_current_kernelog attribute not initialized. Maybe you forgot to call `ti.init()` first?"
|
370
|
+
)
|
371
|
+
return self._current_kernel
|
372
|
+
|
373
|
+
def initialize_fields_builder(self, builder):
|
374
|
+
self.unfinalized_fields_builder[builder] = get_traceback(2)
|
375
|
+
|
376
|
+
def clear_compiled_functions(self):
|
377
|
+
for k in self.kernels:
|
378
|
+
k.compiled_kernels.clear()
|
379
|
+
|
380
|
+
def finalize_fields_builder(self, builder):
|
381
|
+
self.unfinalized_fields_builder.pop(builder)
|
382
|
+
|
383
|
+
def validate_fields_builder(self):
|
384
|
+
for builder, tb in self.unfinalized_fields_builder.items():
|
385
|
+
if builder == _root_fb:
|
386
|
+
continue
|
387
|
+
|
388
|
+
raise TaichiRuntimeError(
|
389
|
+
f"Field builder {builder} is not finalized. " f"Please call finalize() on it. Traceback:\n{tb}"
|
390
|
+
)
|
391
|
+
|
392
|
+
def get_num_compiled_functions(self):
|
393
|
+
count = 0
|
394
|
+
for k in self.kernels:
|
395
|
+
count += len(k.compiled_kernels)
|
396
|
+
return count
|
397
|
+
|
398
|
+
def src_info_guard(self, info):
|
399
|
+
return SrcInfoGuard(self.src_info_stack, info)
|
400
|
+
|
401
|
+
def get_current_src_info(self):
|
402
|
+
return self.src_info_stack[-1]
|
403
|
+
|
404
|
+
def set_default_fp(self, fp):
|
405
|
+
assert fp in [f16, f32, f64]
|
406
|
+
self.default_fp = fp
|
407
|
+
default_cfg().default_fp = self.default_fp
|
408
|
+
|
409
|
+
def set_default_ip(self, ip):
|
410
|
+
assert ip in [i32, i64]
|
411
|
+
self.default_ip = ip
|
412
|
+
self.default_up = u32 if ip == i32 else u64
|
413
|
+
default_cfg().default_ip = self.default_ip
|
414
|
+
default_cfg().default_up = self.default_up
|
415
|
+
|
416
|
+
def create_program(self):
|
417
|
+
if self._prog is None:
|
418
|
+
self._prog = _ti_core.Program()
|
419
|
+
|
420
|
+
@staticmethod
|
421
|
+
def materialize_root_fb(is_first_call):
|
422
|
+
if root.finalized:
|
423
|
+
return
|
424
|
+
if not is_first_call and root.empty:
|
425
|
+
# We have to forcefully finalize when `is_first_call` is True (even
|
426
|
+
# if the root itself is empty), so that there is a valid struct
|
427
|
+
# llvm::Module, if no field has been declared before the first kernel
|
428
|
+
# invocation. Example case:
|
429
|
+
# https://github.com/taichi-dev/taichi/blob/27bb1dc3227d9273a79fcb318fdb06fd053068f5/tests/python/test_ad_basics.py#L260-L266
|
430
|
+
return
|
431
|
+
|
432
|
+
if get_runtime().prog.config().debug:
|
433
|
+
if not root.finalized:
|
434
|
+
root._allocate_adjoint_checkbit()
|
435
|
+
|
436
|
+
root.finalize(raise_warning=not is_first_call)
|
437
|
+
global _root_fb
|
438
|
+
_root_fb = FieldsBuilder()
|
439
|
+
|
440
|
+
@staticmethod
|
441
|
+
def _finalize_root_fb_for_aot():
|
442
|
+
if _root_fb.finalized:
|
443
|
+
raise RuntimeError("AOT: can only finalize the root FieldsBuilder once")
|
444
|
+
assert isinstance(_root_fb, FieldsBuilder)
|
445
|
+
_root_fb._finalize_for_aot()
|
446
|
+
|
447
|
+
@staticmethod
|
448
|
+
def _get_tb(_var):
|
449
|
+
return getattr(_var, "declaration_tb", str(_var.ptr))
|
450
|
+
|
451
|
+
def _check_field_not_placed(self):
|
452
|
+
not_placed = []
|
453
|
+
for _var in self.global_vars:
|
454
|
+
if _var.ptr.snode() is None:
|
455
|
+
not_placed.append(self._get_tb(_var))
|
456
|
+
|
457
|
+
if len(not_placed):
|
458
|
+
bar = "=" * 44 + "\n"
|
459
|
+
raise RuntimeError(
|
460
|
+
f"These field(s) are not placed:\n{bar}"
|
461
|
+
+ f"{bar}".join(not_placed)
|
462
|
+
+ f"{bar}Please consider specifying a shape for them. E.g.,"
|
463
|
+
+ "\n\n x = ti.field(float, shape=(2, 3))"
|
464
|
+
)
|
465
|
+
|
466
|
+
def _check_gradient_field_not_placed(self, gradient_type):
|
467
|
+
not_placed = set()
|
468
|
+
gradient_vars = []
|
469
|
+
if gradient_type == "grad":
|
470
|
+
gradient_vars = self.grad_vars
|
471
|
+
elif gradient_type == "dual":
|
472
|
+
gradient_vars = self.dual_vars
|
473
|
+
for _var in gradient_vars:
|
474
|
+
if _var.ptr.snode() is None:
|
475
|
+
not_placed.add(self._get_tb(_var))
|
476
|
+
|
477
|
+
if len(not_placed):
|
478
|
+
bar = "=" * 44 + "\n"
|
479
|
+
raise RuntimeError(
|
480
|
+
f"These field(s) requrie `needs_{gradient_type}=True`, however their {gradient_type} field(s) are not placed:\n{bar}"
|
481
|
+
+ f"{bar}".join(not_placed)
|
482
|
+
+ f"{bar}Please consider place the {gradient_type} field(s). E.g.,"
|
483
|
+
+ "\n\n ti.root.dense(ti.i, 1).place(x.{gradient_type})"
|
484
|
+
+ "\n\n Or specify a shape for the field(s). E.g.,"
|
485
|
+
+ "\n\n x = ti.field(float, shape=(2, 3), needs_{gradient_type}=True)"
|
486
|
+
)
|
487
|
+
|
488
|
+
def _check_matrix_field_member_shape(self):
|
489
|
+
for _field in self.matrix_fields:
|
490
|
+
shapes = [_field.get_scalar_field(i, j).shape for i in range(_field.n) for j in range(_field.m)]
|
491
|
+
if any(shape != shapes[0] for shape in shapes):
|
492
|
+
raise RuntimeError(
|
493
|
+
"Members of the following field have different shapes "
|
494
|
+
+ f"{shapes}:\n{self._get_tb(_field._get_field_members()[0])}"
|
495
|
+
)
|
496
|
+
|
497
|
+
def _calc_matrix_field_dynamic_index_stride(self):
|
498
|
+
for _field in self.matrix_fields:
|
499
|
+
_field._calc_dynamic_index_stride()
|
500
|
+
|
501
|
+
def materialize(self):
|
502
|
+
self.materialize_root_fb(not self.materialized)
|
503
|
+
self.materialized = True
|
504
|
+
|
505
|
+
self.validate_fields_builder()
|
506
|
+
|
507
|
+
self._check_field_not_placed()
|
508
|
+
self._check_gradient_field_not_placed("grad")
|
509
|
+
self._check_gradient_field_not_placed("dual")
|
510
|
+
self._check_matrix_field_member_shape()
|
511
|
+
self._calc_matrix_field_dynamic_index_stride()
|
512
|
+
self.global_vars = []
|
513
|
+
self.grad_vars = []
|
514
|
+
self.dual_vars = []
|
515
|
+
self.matrix_fields = []
|
516
|
+
|
517
|
+
def _register_signal_handlers(self):
|
518
|
+
if self._signal_handler_registry is None:
|
519
|
+
self._signal_handler_registry = _ti_core.HackedSignalRegister()
|
520
|
+
|
521
|
+
def clear(self):
    """Finalize the current program (if any) and reset the runtime's state flags."""
    program = self._prog
    if program:
        program.finalize()
        self._prog = None
    self._signal_handler_registry = None
    self.materialized = False
|
527
|
+
|
528
|
+
def sync(self):
    """Materialize pending state, then synchronize the underlying program."""
    self.materialize()
    program = self._prog
    assert program is not None
    program.synchronize()
|
532
|
+
|
533
|
+
|
534
|
+
pytaichi = PyTaichi()
|
535
|
+
|
536
|
+
|
537
|
+
def get_runtime() -> PyTaichi:
    """Return the module-level ``pytaichi`` runtime instance."""
    return pytaichi
|
539
|
+
|
540
|
+
|
541
|
+
def reset():
    """Replace the global runtime with a fresh one that keeps the known kernels.

    The current runtime is cleared, a new ``PyTaichi`` is built from the
    previous kernel list, each kernel is reset, and the default compile
    config is restored.
    """
    global pytaichi
    kept_kernels = pytaichi.kernels
    pytaichi.clear()
    pytaichi = PyTaichi(kept_kernels)
    for kernel in kept_kernels:
        kernel.reset()
    _ti_core.reset_default_compile_config()
|
549
|
+
|
550
|
+
|
551
|
+
@taichi_scope
def static_print(*args, __p=print, **kwargs):
    """The print function in Taichi scope.

    This function is called at compile time and has no runtime overhead.
    All positional and keyword arguments are forwarded unchanged to the
    printer bound to ``__p`` (the builtin ``print`` by default).
    """
    printer = __p
    printer(*args, **kwargs)
|
558
|
+
|
559
|
+
|
560
|
+
# we don't add @taichi_scope decorator for @ti.pyfunc to work
def static_assert(cond, msg=None):
    """Throw AssertionError when `cond` is False.

    This function is called at compile time and has no runtime overhead.
    The bool value of `cond` must be determinable at compile time.

    Args:
        cond (bool): an expression with a bool value.
        msg (str): assertion message.

    Raises:
        TaichiTypeError: if `cond` is a Taichi expression (not static).
        AssertionError: if `cond` is falsy; carries `msg` when one is given.

    Example::

        >>> year = 2001
        >>> @ti.kernel
        >>> def test():
        >>>     ti.static_assert(year % 4 == 0, "the year must be a leap year")
        AssertionError: the year must be a leap year
    """
    if isinstance(cond, Expr):
        raise TaichiTypeError("Static assert with non-static condition")
    # Raise explicitly instead of using the `assert` statement, which is
    # stripped when Python runs with -O and would make this check a no-op.
    if not cond:
        if msg is None:
            raise AssertionError
        raise AssertionError(msg)
|
585
|
+
|
586
|
+
|
587
|
+
def inside_kernel():
    """Return the ``inside_kernel`` attribute of the global runtime."""
    runtime = pytaichi
    return runtime.inside_kernel
|
589
|
+
|
590
|
+
|
591
|
+
def index_nd(dim):
    """Return the axes ``0 .. dim - 1`` as built by :func:`axes`."""
    axis_indices = range(dim)
    return axes(*axis_indices)
|
593
|
+
|
594
|
+
|
595
|
+
class _UninitializedRootFieldsBuilder:
|
596
|
+
def __getattr__(self, item):
|
597
|
+
if item == "__qualname__":
|
598
|
+
# For sphinx docstring extraction.
|
599
|
+
return "_UninitializedRootFieldsBuilder"
|
600
|
+
raise TaichiRuntimeError("Please call init() first")
|
601
|
+
|
602
|
+
|
603
|
+
# `root` initialization must be delayed until after the program is
|
604
|
+
# created. Unfortunately, `root` exists in both taichi.lang.impl module and
|
605
|
+
# the top-level taichi module at this point; so if `root` itself is written, we
|
606
|
+
# would have to make sure that `root` in all the modules get updated to the same
|
607
|
+
# instance. This is an error-prone process.
|
608
|
+
#
|
609
|
+
# To avoid this situation, we create `root` once during the import time, and
|
610
|
+
# never write to it. The core part, `_root_fb`, is the one whose initialization
|
611
|
+
# gets delayed. `_root_fb` will only exist in the taichi.lang.impl module, so
|
612
|
+
# writing to it incurs a lower maintenance cost.
|
613
|
+
#
|
614
|
+
# `_root_fb` will be overridden inside :func:`taichi.lang.init`.
|
615
|
+
_root_fb = _UninitializedRootFieldsBuilder()
|
616
|
+
|
617
|
+
|
618
|
+
def deactivate_all_snodes():
    """Recursively deactivate all SNodes."""
    finalized_roots = FieldsBuilder._finalized_roots()
    for finalized_root in finalized_roots:
        finalized_root.deactivate_all()
|
622
|
+
|
623
|
+
|
624
|
+
class _Root:
|
625
|
+
"""Wrapper around the default root FieldsBuilder instance."""
|
626
|
+
|
627
|
+
@staticmethod
|
628
|
+
def parent(n=1):
|
629
|
+
"""Same as :func:`taichi.SNode.parent`"""
|
630
|
+
assert isinstance(_root_fb, FieldsBuilder)
|
631
|
+
return _root_fb.root.parent(n)
|
632
|
+
|
633
|
+
@staticmethod
|
634
|
+
def _loop_range():
|
635
|
+
"""Same as :func:`taichi.SNode.loop_range`"""
|
636
|
+
assert isinstance(_root_fb, FieldsBuilder)
|
637
|
+
return _root_fb.root._loop_range()
|
638
|
+
|
639
|
+
@staticmethod
|
640
|
+
def _get_children():
|
641
|
+
"""Same as :func:`taichi.SNode.get_children`"""
|
642
|
+
assert isinstance(_root_fb, FieldsBuilder)
|
643
|
+
return _root_fb.root._get_children()
|
644
|
+
|
645
|
+
# TODO: Record all of the SNodeTrees that finalized under 'ti.root'
|
646
|
+
@staticmethod
|
647
|
+
def deactivate_all():
|
648
|
+
warning("""'ti.root.deactivate_all()' would deactivate all finalized snodes.""")
|
649
|
+
deactivate_all_snodes()
|
650
|
+
|
651
|
+
@property
|
652
|
+
def shape(self):
|
653
|
+
"""Same as :func:`taichi.SNode.shape`"""
|
654
|
+
assert isinstance(_root_fb, FieldsBuilder)
|
655
|
+
return _root_fb.root.shape
|
656
|
+
|
657
|
+
@property
|
658
|
+
def _id(self):
|
659
|
+
assert isinstance(_root_fb, FieldsBuilder)
|
660
|
+
return _root_fb.root._id
|
661
|
+
|
662
|
+
def __getattr__(self, item):
|
663
|
+
return getattr(_root_fb, item)
|
664
|
+
|
665
|
+
def __repr__(self):
|
666
|
+
return "ti.root"
|
667
|
+
|
668
|
+
|
669
|
+
root = _Root()
|
670
|
+
"""Root of the declared Taichi :func:`~taichi.lang.impl.field`s.
|
671
|
+
|
672
|
+
See also https://docs.taichi-lang.org/docs/layout
|
673
|
+
|
674
|
+
Example::
|
675
|
+
|
676
|
+
>>> x = ti.field(ti.f32)
|
677
|
+
>>> ti.root.pointer(ti.ij, 4).dense(ti.ij, 8).place(x)
|
678
|
+
"""
|
679
|
+
|
680
|
+
|
681
|
+
def _create_snode(axis_seq: Sequence[int], shape_seq: Sequence[numbers.Number], same_level: bool):
|
682
|
+
dim = len(axis_seq)
|
683
|
+
assert dim == len(shape_seq)
|
684
|
+
snode = root
|
685
|
+
if same_level:
|
686
|
+
snode = snode.dense(axes(*axis_seq), shape_seq)
|
687
|
+
else:
|
688
|
+
for i in range(dim):
|
689
|
+
snode = snode.dense(axes(axis_seq[i]), (shape_seq[i],))
|
690
|
+
return snode
|
691
|
+
|
692
|
+
|
693
|
+
@python_scope
def create_field_member(dtype, name, needs_grad, needs_dual):
    """Create the primal expression of a field plus its adjoint and dual companions.

    The primal is always registered in ``pytaichi.global_vars``. For real
    dtypes an adjoint and a dual expression are created and linked to the
    primal; they are registered in ``pytaichi.grad_vars`` /
    ``pytaichi.dual_vars`` only when ``needs_grad`` / ``needs_dual`` is set.
    When the program config has ``debug`` enabled, an adjoint checkbit field
    is created and attached to the primal as well.

    Args:
        dtype: element data type; passed through ``cook_dtype``.
        name (str): field name; companions get ``.grad``, ``.grad_checkbit``
            and ``.dual`` suffixes.
        needs_grad (bool): register the adjoint for placement checking.
        needs_dual (bool): register the dual for placement checking.

    Returns:
        Tuple ``(x, x_grad, x_dual)``; the last two are ``None`` for
        non-real dtypes.

    Raises:
        TaichiRuntimeError: if the dtype is not real but ``needs_grad`` or
            ``needs_dual`` was requested.
    """
    dtype = cook_dtype(dtype)

    # primal
    prog = get_runtime().prog

    x = Expr(prog.make_id_expr(""))
    x.declaration_tb = get_traceback(stacklevel=4)
    x.ptr = _ti_core.expr_field(x.ptr, dtype)
    x.ptr.set_name(name)
    x.ptr.set_grad_type(SNodeGradType.PRIMAL)
    pytaichi.global_vars.append(x)

    x_grad = None
    x_dual = None
    # The x_grad_checkbit is used for global data access rule checker
    x_grad_checkbit = None
    if _ti_core.is_real(dtype):
        # adjoint
        x_grad = Expr(prog.make_id_expr(""))
        x_grad.declaration_tb = get_traceback(stacklevel=4)
        x_grad.ptr = _ti_core.expr_field(x_grad.ptr, dtype)
        x_grad.ptr.set_name(name + ".grad")
        x_grad.ptr.set_grad_type(SNodeGradType.ADJOINT)
        x.ptr.set_adjoint(x_grad.ptr)
        if needs_grad:
            pytaichi.grad_vars.append(x_grad)

        if prog.config().debug:
            # adjoint checkbit
            x_grad_checkbit = Expr(prog.make_id_expr(""))
            # Keep the checkbit type in its own variable: reusing `dtype`
            # here would make the dual field below an integer field.
            checkbit_dtype = u8
            if prog.config().arch in (_ti_core.opengl, _ti_core.vulkan, _ti_core.gles):
                checkbit_dtype = i32
            x_grad_checkbit.ptr = _ti_core.expr_field(x_grad_checkbit.ptr, cook_dtype(checkbit_dtype))
            x_grad_checkbit.ptr.set_name(name + ".grad_checkbit")
            x_grad_checkbit.ptr.set_grad_type(SNodeGradType.ADJOINT_CHECKBIT)
            x.ptr.set_adjoint_checkbit(x_grad_checkbit.ptr)

        # dual
        x_dual = Expr(prog.make_id_expr(""))
        x_dual.ptr = _ti_core.expr_field(x_dual.ptr, dtype)
        x_dual.ptr.set_name(name + ".dual")
        x_dual.ptr.set_grad_type(SNodeGradType.DUAL)
        x.ptr.set_dual(x_dual.ptr)
        if needs_dual:
            pytaichi.dual_vars.append(x_dual)
    elif needs_grad or needs_dual:
        raise TaichiRuntimeError(f"{dtype} is not supported for field with `needs_grad=True` or `needs_dual=True`.")

    return x, x_grad, x_dual
|
745
|
+
|
746
|
+
|
747
|
+
@python_scope
def _field(
    dtype,
    shape=None,
    order=None,
    name="",
    offset=None,
    needs_grad=False,
    needs_dual=False,
):
    """Create a scalar field and, when ``shape`` is given, place it under ``root``.

    Args:
        dtype: element data type.
        shape: a number or a sequence of extents; ``None`` leaves the field unplaced.
        order: optional string of axis letters starting at ``"i"`` giving the
            nesting order; requires ``shape``.
        name (str): field name.
        offset: a number or a sequence with one entry per dimension; requires ``shape``.
        needs_grad (bool): also place the adjoint field.
        needs_dual (bool): also place the dual field.

    Returns:
        The primal ``ScalarField``.

    Raises:
        TaichiSyntaxError: if ``offset`` or ``order`` is given without
            ``shape``, if their lengths differ from the dimensionality, or
            if ``order`` repeats an axis or names an axis out of range.
    """
    x, x_grad, x_dual = create_field_member(dtype, name, needs_grad, needs_dual)
    x = ScalarField(x)
    if x_grad:
        x_grad = ScalarField(x_grad)
        x._set_grad(x_grad)
    if x_dual:
        x_dual = ScalarField(x_dual)
        x._set_dual(x_dual)

    # Without a shape nothing is placed; offset/order make no sense then.
    if shape is None:
        if offset is not None:
            raise TaichiSyntaxError("shape cannot be None when offset is set")
        if order is not None:
            raise TaichiSyntaxError("shape cannot be None when order is set")
        return x

    if isinstance(shape, numbers.Number):
        shape = (shape,)
    if isinstance(offset, numbers.Number):
        offset = (offset,)
    dim = len(shape)
    if offset is not None and dim != len(offset):
        raise TaichiSyntaxError(f"The dimensionality of shape and offset must be the same ({dim} != {len(offset)})")

    same_level = order is None
    if same_level:
        axis_seq = list(range(dim))
        shape_seq = list(shape)
    else:
        if dim != len(order):
            raise TaichiSyntaxError(
                f"The dimensionality of shape and order must be the same ({dim} != {len(order)})"
            )
        if dim != len(set(order)):
            raise TaichiSyntaxError("The axes in order must be different")
        axis_seq = []
        shape_seq = []
        for ch in order:
            # Axis letters are consecutive starting at "i".
            axis = ord(ch) - ord("i")
            if not 0 <= axis < dim:
                raise TaichiSyntaxError(f"Invalid axis {ch}")
            axis_seq.append(axis)
            shape_seq.append(shape[axis])

    _create_snode(axis_seq, shape_seq, same_level).place(x, offset=offset)
    if needs_grad:
        _create_snode(axis_seq, shape_seq, same_level).place(x_grad, offset=offset)
    if needs_dual:
        _create_snode(axis_seq, shape_seq, same_level).place(x_dual, offset=offset)
    return x
|
804
|
+
|
805
|
+
|
806
|
+
@python_scope
def field(dtype, *args, **kwargs):
    """Defines a Taichi field.

    A Taichi field can be viewed as an abstract N-dimensional array, hiding away
    the complexity of how its underlying :class:`~taichi.lang.snode.SNode` are
    actually defined. The data in a Taichi field can be directly accessed by
    a Taichi :func:`~taichi.lang.kernel_impl.kernel`.

    See also https://docs.taichi-lang.org/docs/field

    Args:
        dtype (DataType): data type of the field. Note it can be vector or matrix types as well.
        shape (Union[int, tuple[int]], optional): shape of the field.
        order (str, optional): order of the shape laid out in memory.
        name (str, optional): name of the field.
        offset (Union[int, tuple[int]], optional): offset of the field domain.
        needs_grad (bool, optional): whether this field participates in autodiff (reverse mode)
            and thus needs an adjoint field to store the gradients.
        needs_dual (bool, optional): whether this field participates in autodiff (forward mode)
            and thus needs a dual field to store the gradients.

    Example::

        The code below shows how a Taichi field can be declared and defined::

            >>> x1 = ti.field(ti.f32, shape=(16, 8))
            >>> # Equivalently
            >>> x2 = ti.field(ti.f32)
            >>> ti.root.dense(ti.ij, shape=(16, 8)).place(x2)
            >>>
            >>> x3 = ti.field(ti.f32, shape=(16, 8), order='ji')
            >>> # Equivalently
            >>> x4 = ti.field(ti.f32)
            >>> ti.root.dense(ti.j, shape=8).dense(ti.i, shape=16).place(x4)
            >>>
            >>> x5 = ti.field(ti.math.vec3, shape=(16, 8))

    """
    # Scalar dtypes take the plain path; compound dtypes dispatch on rank.
    if not isinstance(dtype, MatrixType):
        return _field(dtype, *args, **kwargs)
    if dtype.ndim == 1:
        return Vector.field(dtype.n, dtype.dtype, *args, **kwargs)
    return Matrix.field(dtype.n, dtype.m, dtype.dtype, *args, **kwargs)
|
850
|
+
|
851
|
+
|
852
|
+
@python_scope
def ndarray(dtype, shape, needs_grad=False):
    """Defines a Taichi ndarray with scalar, vector or matrix elements.

    Args:
        dtype (Union[DataType, MatrixType]): Data type of each element. This can be either a scalar type like ti.f32 or a compound type like ti.types.vector(3, ti.i32).
        shape (Union[int, tuple[int]]): Shape of the ndarray. Every extent must be an
            integer in the range ``1 .. 2**31 - 1``.
        needs_grad (bool, optional): when true, a second ndarray of the same dtype and
            shape is created and attached as the gradient; requires a real element type.

    Raises:
        TaichiRuntimeError: if the shape is invalid, the element type is not
            supported, or ``needs_grad`` is requested for a non-real element type.

    Example:
        The code below shows how a Taichi ndarray can be declared and defined::

            >>> x = ti.ndarray(ti.f32, shape=(16, 8))  # ndarray of shape (16, 8), each element is ti.f32 scalar.
            >>> vec3 = ti.types.vector(3, ti.i32)
            >>> y = ti.ndarray(vec3, shape=(10, 2))  # ndarray of shape (10, 2), each element is a vector of 3 ti.i32 scalars.
            >>> matrix_ty = ti.types.matrix(3, 4, float)
            >>> z = ti.ndarray(matrix_ty, shape=(4, 5))  # ndarray of shape (4, 5), each element is a matrix of (3, 4) ti.float scalars.
    """
    # primal
    if isinstance(shape, numbers.Number):
        shape = (shape,)
    for extent in shape:
        is_integer = isinstance(extent, (int, np.integer))
        if not (is_integer and 0 < extent <= 2**31 - 1):
            raise TaichiRuntimeError(f"{shape} is not a valid shape for ndarray")

    if dtype in all_types:
        element_dtype = cook_dtype(dtype)
        array = ScalarNdarray(element_dtype, shape)
    elif isinstance(dtype, MatrixType):
        element_dtype = dtype.dtype
        if dtype.ndim == 1:
            array = VectorNdarray(dtype.n, element_dtype, shape)
        else:
            array = MatrixNdarray(dtype.n, dtype.m, element_dtype, shape)
    else:
        raise TaichiRuntimeError(f"{dtype} is not supported as ndarray element type")

    if not needs_grad:
        return array

    assert isinstance(element_dtype, DataType)
    if not _ti_core.is_real(element_dtype):
        raise TaichiRuntimeError(
            f"{element_dtype} is not supported for ndarray with `needs_grad=True` or `needs_dual=True`."
        )
    # The gradient array itself never carries a gradient.
    array._set_grad(ndarray(dtype, shape, needs_grad=False))
    return array
|
892
|
+
|
893
|
+
|
894
|
+
@taichi_scope
def ti_format_list_to_content_entries(raw):
    """Flatten print arguments into parallel lists of contents and format specs.

    Nested lists/tuples, objects exposing ``__ti_repr__`` and the results of
    ``ti_format`` are expanded recursively; adjacent strings are fused into
    one; every remaining entry is converted to an expression pointer.

    Args:
        raw: iterable of items to print.

    Returns:
        A pair ``(contents, formats)`` of equally long lists. ``formats[i]``
        is the format string for ``contents[i]`` or ``None``. Both lists are
        empty when ``raw`` produces no entries.
    """

    # return a pair of [content, format]
    def entry2content(_var):
        if isinstance(_var, str):
            return [_var, None]
        if isinstance(_var, list):
            assert len(_var) == 2 and (isinstance(_var[1], str) or _var[1] is None)
            _var[0] = Expr(_var[0]).ptr
            return _var
        return [Expr(_var).ptr, None]

    def list_ti_repr(_var):
        yield "["  # distinguishing tuple & list will increase maintenance cost
        for i, v in enumerate(_var):
            if i:
                yield ", "
            yield v
        yield "]"

    def vars2entries(_vars):
        for _var in _vars:
            # If the first element is '__ti_fmt_value__', this list is an Expr and its format.
            if isinstance(_var, list) and len(_var) == 3 and isinstance(_var[0], str) and _var[0] == "__ti_fmt_value__":
                # yield [Expr, format] as a whole and don't pass it to vars2entries() again
                yield _var[1:]
                continue
            elif hasattr(_var, "__ti_repr__"):
                res = _var.__ti_repr__()
            elif isinstance(_var, (list, tuple)):
                # If the first element is '__ti_format__', this list is the result of ti_format.
                if len(_var) > 0 and isinstance(_var[0], str) and _var[0] == "__ti_format__":
                    res = _var[1:]
                else:
                    res = list_ti_repr(_var)
            else:
                yield _var
                continue

            yield from vars2entries(res)

    def fused_string(entries):
        accumulated = ""
        for entry in entries:
            if isinstance(entry, str):
                accumulated += entry
            else:
                if accumulated:
                    yield accumulated
                    accumulated = ""
                yield entry
        if accumulated:
            yield accumulated

    def extract_formats(entries):
        # zip(*[]) yields nothing, so unpacking it into two names would
        # raise ValueError; an empty print has no contents and no formats.
        if not entries:
            return [], []
        contents, formats = zip(*entries)
        return list(contents), list(formats)

    entries = vars2entries(raw)
    entries = fused_string(entries)
    entries = [entry2content(entry) for entry in entries]
    return extract_formats(entries)
|
957
|
+
|
958
|
+
|
959
|
+
@taichi_scope
def ti_print(*_vars, sep=" ", end="\n"):
    """Emit a print statement into the callable currently being compiled.

    Args:
        *_vars: items to print.
        sep: value inserted between consecutive items.
        end: value appended after the last item.
    """

    def interleave(items):
        # Yield the items with `sep` between neighbours and `end` last.
        is_first = True
        for item in items:
            if not is_first:
                yield sep
            is_first = False
            yield item
        yield end

    contents, formats = ti_format_list_to_content_entries(interleave(_vars))
    compiling_callable = get_runtime().compiling_callable
    assert compiling_callable is not None
    builder = compiling_callable.ast_builder()
    debug_info = _ti_core.DebugInfo(get_runtime().get_current_src_info())
    builder.create_print(contents, formats, debug_info)
|
975
|
+
|
976
|
+
|
977
|
+
@taichi_scope
def ti_format(*args):
    """Split a format string into literal pieces interleaved with Taichi expressions.

    ``args[0]`` is the format string. Arguments that are ``Expr`` instances
    or ``["__ti_fmt_value__", ...]`` triples are kept as values; all other
    arguments are substituted into the string via ``str.format``.

    Returns:
        A list starting with the marker ``"__ti_format__"``, followed by
        string pieces alternating with the kept expressions.
    """
    template = args[0]
    substitutions = []
    deferred = []
    for item in args[1:]:
        # item is a (formatted) Expr
        if isinstance(item, Expr) or (isinstance(item, list) and len(item) == 3 and item[0] == "__ti_fmt_value__"):
            # Leave a "{}" hole in the rendered text and fill it in below.
            substitutions.append("{}")
            deferred.append(item)
        else:
            substitutions.append(item)
    rendered = template.format(*substitutions)
    pieces = rendered.split("{}")
    assert len(pieces) == len(deferred) + 1, "Number of args is different from number of positions provided in string"

    for position, expr in enumerate(deferred):
        pieces.insert(2 * position + 1, expr)
    return ["__ti_format__"] + pieces
|
998
|
+
|
999
|
+
|
1000
|
+
@taichi_scope
def ti_assert(cond, msg, extra_args, dbg_info):
    """Insert an assert statement into the callable currently being compiled.

    Mostly a wrapper that converts ``cond`` from the Python-side ``Expr`` to
    the ``_ti_core`` expression (defined in C++) before handing it over.
    """
    compiling_callable = get_runtime().compiling_callable
    assert compiling_callable is not None
    builder = compiling_callable.ast_builder()
    cond_ptr = Expr(cond).ptr
    builder.create_assert_stmt(cond_ptr, msg, extra_args, dbg_info)
|
1007
|
+
|
1008
|
+
|
1009
|
+
@taichi_scope
def ti_int(_var):
    """Convert via the object's ``__ti_int__`` hook when present, else via ``int``."""
    if not hasattr(_var, "__ti_int__"):
        return int(_var)
    return _var.__ti_int__()
|
1014
|
+
|
1015
|
+
|
1016
|
+
@taichi_scope
def ti_bool(_var):
    """Convert via the object's ``__ti_bool__`` hook when present, else via ``bool``."""
    if not hasattr(_var, "__ti_bool__"):
        return bool(_var)
    return _var.__ti_bool__()
|
1021
|
+
|
1022
|
+
|
1023
|
+
@taichi_scope
def ti_float(_var):
    """Convert via the object's ``__ti_float__`` hook when present, else via ``float``."""
    if not hasattr(_var, "__ti_float__"):
        return float(_var)
    return _var.__ti_float__()
|
1028
|
+
|
1029
|
+
|
1030
|
+
@taichi_scope
def zero(x):
    # TODO: get dtype from Expr and Matrix:
    """Returns zeros with the same shape and type as the input, computed as ``x * 0``.

    The result is a scalar if the input is a scalar.

    Args:
        x (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): The input.

    Returns:
        A new copy of the input but filled with zeros.

    Example::

        >>> x = ti.Vector([1, 1])
        >>> @ti.kernel
        >>> def test():
        >>>     y = ti.zero(x)
        >>>     print(y)
        [0, 0]
    """
    zeros = x * 0
    return zeros
|
1052
|
+
|
1053
|
+
|
1054
|
+
@taichi_scope
def one(x):
    """Returns ones with the same shape and type as the input, computed as ``zero(x) + 1``.

    The result is a scalar if the input is a scalar.

    Args:
        x (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): The input.

    Returns:
        A new copy of the input but filled with ones.

    Example::

        >>> x = ti.Vector([0, 0])
        >>> @ti.kernel
        >>> def test():
        >>>     y = ti.one(x)
        >>>     print(y)
        [1, 1]
    """
    zeros = zero(x)
    return zeros + 1
|
1075
|
+
|
1076
|
+
|
1077
|
+
def axes(*x: int):
    """Defines a list of axes to be used by a field.

    Args:
        *x: A list of axes to be activated

    Returns:
        A list with one ``_ti_core.Axis`` per given index, in the same order.

    Note that Taichi has already provided a set of commonly used axes. For example,
    `ti.ij` is just `axes(0, 1)` under the hood.
    """
    axis_list = []
    for index in x:
        axis_list.append(_ti_core.Axis(index))
    return axis_list
|
1087
|
+
|
1088
|
+
|
1089
|
+
Axis = _ti_core.Axis
|
1090
|
+
|
1091
|
+
|
1092
|
+
def static(x, *xs) -> Any:
    """Evaluates a Taichi-scope expression at compile time.

    `static()` is what enables the so-called metaprogramming in Taichi. It is
    in many ways similar to ``constexpr`` in C++.

    See also https://docs.taichi-lang.org/docs/meta.

    Args:
        x (Any): an expression to be evaluated
        *xs (Any): for Python-ish swapping assignment

    Returns:
        ``x`` itself when a single argument is given; a list with each
        argument passed through ``static`` when several are given.

    Raises:
        ValueError: if an argument is not one of the accepted compile-time types.

    Example:
        The most common usage of `static()` is for compile-time evaluation::

            >>> cond = False
            >>>
            >>> @ti.kernel
            >>> def run():
            >>>     if ti.static(cond):
            >>>         do_a()
            >>>     else:
            >>>         do_b()

        Depending on the value of ``cond``, ``run()`` will be directly compiled
        into either ``do_a()`` or ``do_b()``. Thus there won't be a runtime
        condition check.

        Another common usage is for compile-time loop unrolling::

            >>> @ti.kernel
            >>> def run():
            >>>     for i in ti.static(range(3)):
            >>>         print(i)
            >>>
            >>> # The above will be unrolled to:
            >>> @ti.kernel
            >>> def run():
            >>>     print(0)
            >>>     print(1)
            >>>     print(2)
    """
    if xs:  # for python-ish pointer assign: x, y = ti.static(y, x)
        return [static(item) for item in (x, *xs)]

    # Groups of types that are returned unchanged.
    accepted_groups = (
        (
            bool,
            int,
            float,
            range,
            list,
            tuple,
            enumerate,
            GroupedNDRange,
            _Ndrange,
            zip,
            filter,
            map,
        ),
        (np.bool_, np.integer, np.floating),
        (AnyArray,),
        (Field,),
        (FunctionType, MethodType),
    )
    if x is None:
        return x
    for group in accepted_groups:
        if isinstance(x, group):
            return x
    raise ValueError(f"Input to ti.static must be compile-time constants or global pointers, instead of {type(x)}")
|
1168
|
+
|
1169
|
+
|
1170
|
+
@taichi_scope
def grouped(x):
    """Groups the indices in the iterator returned by `ndrange()` into a 1-D vector.

    This is often used when you want to iterate over all indices returned by `ndrange()`
    in one `for` loop and a single index. Inputs that are not an ndrange object
    are returned unchanged.

    Args:
        x (:func:`~taichi.ndrange`): an iterator object returned by `ti.ndrange`.

    Example::
        >>> # without ti.grouped
        >>> for I in ti.ndrange(2, 3):
        >>>     print(I)
        prints 0, 1, 2, 3, 4, 5

        >>> # with ti.grouped
        >>> for I in ti.grouped(ti.ndrange(2, 3)):
        >>>     print(I)
        prints [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]
    """
    return x.grouped() if isinstance(x, _Ndrange) else x
|
1194
|
+
|
1195
|
+
|
1196
|
+
def stop_grad(x):
    """Stops computing gradients during back propagation.

    Args:
        x (:class:`~taichi.Field`): A field.
    """
    compiling_callable = get_runtime().compiling_callable
    assert compiling_callable is not None
    builder = compiling_callable.ast_builder()
    builder.stop_grad(x.snode.ptr)
|
1205
|
+
|
1206
|
+
|
1207
|
+
def current_cfg():
    """Return the config of the current runtime's program."""
    program = get_runtime().prog
    return program.config()
|
1209
|
+
|
1210
|
+
|
1211
|
+
def default_cfg():
    """Return the result of ``_ti_core.default_compile_config()``."""
    config = _ti_core.default_compile_config()
    return config
|
1213
|
+
|
1214
|
+
|
1215
|
+
def call_internal(name, *args, with_runtime_context=True):
    """Insert a call to the internal op named ``name`` and wrap the result with ``expr_init``.

    Args:
        name (str): attribute name looked up on ``_ti_core.InternalOp``.
        *args: arguments of the call, passed as one group to ``make_expr_group``.
        with_runtime_context: accepted but not used by this function.
    """
    internal_op = getattr(_ti_core.InternalOp, name)
    call_expr = _ti_core.insert_internal_func_call(internal_op, make_expr_group(args))
    return expr_init(call_expr)
|
1217
|
+
|
1218
|
+
|
1219
|
+
def get_cuda_compute_capability():
    """Return ``_ti_core.query_int64("cuda_compute_capability")``."""
    capability = _ti_core.query_int64("cuda_compute_capability")
    return capability
|
1221
|
+
|
1222
|
+
|
1223
|
+
@taichi_scope
def mesh_relation_access(mesh, from_index, to_element_type):
    """Resolve a mesh relation access.

    When ``from_index`` is itself a ``MeshInstance``, the attribute named
    after ``to_element_type`` is returned from it (this supports
    ti.mesh_local and accessing mesh attributes as fields). Otherwise
    ``mesh`` must be a ``MeshInstance`` and a ``MeshRelationAccessProxy``
    is returned.

    Raises:
        RuntimeError: if neither ``from_index`` nor ``mesh`` is a ``MeshInstance``.
    """
    # to support ti.mesh_local and access mesh attribute as field
    if isinstance(from_index, MeshInstance):
        attribute_name = element_type_name(to_element_type)
        return getattr(from_index, attribute_name)
    if not isinstance(mesh, MeshInstance):
        raise RuntimeError("Relation access should be with a mesh instance!")
    return MeshRelationAccessProxy(mesh, from_index, to_element_type)
|
1231
|
+
|
1232
|
+
|
1233
|
+
__all__ = [
|
1234
|
+
"axes",
|
1235
|
+
"deactivate_all_snodes",
|
1236
|
+
"field",
|
1237
|
+
"grouped",
|
1238
|
+
"ndarray",
|
1239
|
+
"one",
|
1240
|
+
"root",
|
1241
|
+
"static",
|
1242
|
+
"static_assert",
|
1243
|
+
"static_print",
|
1244
|
+
"stop_grad",
|
1245
|
+
"zero",
|
1246
|
+
]
|