gstaichi 2.1.1rc3__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +4 -0
- gstaichi/__init__.py +40 -0
- gstaichi/_funcs.py +706 -0
- gstaichi/_kernels.py +420 -0
- gstaichi/_lib/__init__.py +3 -0
- gstaichi/_lib/core/__init__.py +0 -0
- gstaichi/_lib/core/gstaichi_python.cpython-313-darwin.so +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
- gstaichi/_lib/core/py.typed +0 -0
- gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
- gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
- gstaichi/_lib/utils.py +243 -0
- gstaichi/_logging.py +131 -0
- gstaichi/_snode/__init__.py +5 -0
- gstaichi/_snode/fields_builder.py +187 -0
- gstaichi/_snode/snode_tree.py +34 -0
- gstaichi/_test_tools/__init__.py +18 -0
- gstaichi/_test_tools/dataclass_test_tools.py +36 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_test_tools/textwrap2.py +6 -0
- gstaichi/_version.py +1 -0
- gstaichi/_version_check.py +100 -0
- gstaichi/ad/__init__.py +3 -0
- gstaichi/ad/_ad.py +530 -0
- gstaichi/algorithms/__init__.py +3 -0
- gstaichi/algorithms/_algorithms.py +117 -0
- gstaichi/assets/.git +1 -0
- gstaichi/assets/Go-Regular.ttf +0 -0
- gstaichi/assets/static/imgs/ti_gallery.png +0 -0
- gstaichi/examples/lcg_python.py +26 -0
- gstaichi/examples/lcg_taichi.py +34 -0
- gstaichi/examples/minimal.py +28 -0
- gstaichi/experimental.py +16 -0
- gstaichi/lang/__init__.py +50 -0
- gstaichi/lang/_dataclass_util.py +31 -0
- gstaichi/lang/_fast_caching/__init__.py +3 -0
- gstaichi/lang/_fast_caching/args_hasher.py +110 -0
- gstaichi/lang/_fast_caching/config_hasher.py +30 -0
- gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
- gstaichi/lang/_fast_caching/function_hasher.py +57 -0
- gstaichi/lang/_fast_caching/hash_utils.py +11 -0
- gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
- gstaichi/lang/_fast_caching/src_hasher.py +75 -0
- gstaichi/lang/_kernel_impl_dataclass.py +212 -0
- gstaichi/lang/_ndarray.py +352 -0
- gstaichi/lang/_ndrange.py +152 -0
- gstaichi/lang/_template_mapper.py +195 -0
- gstaichi/lang/_texture.py +172 -0
- gstaichi/lang/_wrap_inspect.py +215 -0
- gstaichi/lang/any_array.py +99 -0
- gstaichi/lang/ast/__init__.py +5 -0
- gstaichi/lang/ast/ast_transformer.py +1323 -0
- gstaichi/lang/ast/ast_transformer_utils.py +346 -0
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
- gstaichi/lang/ast/checkers.py +106 -0
- gstaichi/lang/ast/symbol_resolver.py +57 -0
- gstaichi/lang/ast/transform.py +9 -0
- gstaichi/lang/common_ops.py +310 -0
- gstaichi/lang/exception.py +80 -0
- gstaichi/lang/expr.py +180 -0
- gstaichi/lang/field.py +428 -0
- gstaichi/lang/impl.py +1243 -0
- gstaichi/lang/kernel_arguments.py +155 -0
- gstaichi/lang/kernel_impl.py +1341 -0
- gstaichi/lang/matrix.py +1835 -0
- gstaichi/lang/matrix_ops.py +341 -0
- gstaichi/lang/matrix_ops_utils.py +190 -0
- gstaichi/lang/mesh.py +687 -0
- gstaichi/lang/misc.py +782 -0
- gstaichi/lang/ops.py +1494 -0
- gstaichi/lang/runtime_ops.py +13 -0
- gstaichi/lang/shell.py +35 -0
- gstaichi/lang/simt/__init__.py +5 -0
- gstaichi/lang/simt/block.py +94 -0
- gstaichi/lang/simt/grid.py +7 -0
- gstaichi/lang/simt/subgroup.py +191 -0
- gstaichi/lang/simt/warp.py +96 -0
- gstaichi/lang/snode.py +489 -0
- gstaichi/lang/source_builder.py +150 -0
- gstaichi/lang/struct.py +810 -0
- gstaichi/lang/util.py +312 -0
- gstaichi/linalg/__init__.py +8 -0
- gstaichi/linalg/matrixfree_cg.py +310 -0
- gstaichi/linalg/sparse_cg.py +59 -0
- gstaichi/linalg/sparse_matrix.py +303 -0
- gstaichi/linalg/sparse_solver.py +123 -0
- gstaichi/math/__init__.py +11 -0
- gstaichi/math/_complex.py +205 -0
- gstaichi/math/mathimpl.py +886 -0
- gstaichi/profiler/__init__.py +6 -0
- gstaichi/profiler/kernel_metrics.py +260 -0
- gstaichi/profiler/kernel_profiler.py +586 -0
- gstaichi/profiler/memory_profiler.py +15 -0
- gstaichi/profiler/scoped_profiler.py +36 -0
- gstaichi/sparse/__init__.py +3 -0
- gstaichi/sparse/_sparse_grid.py +77 -0
- gstaichi/tools/__init__.py +12 -0
- gstaichi/tools/diagnose.py +117 -0
- gstaichi/tools/np2ply.py +364 -0
- gstaichi/tools/vtk.py +38 -0
- gstaichi/types/__init__.py +19 -0
- gstaichi/types/annotations.py +52 -0
- gstaichi/types/compound_types.py +71 -0
- gstaichi/types/enums.py +49 -0
- gstaichi/types/ndarray_type.py +169 -0
- gstaichi/types/primitive_types.py +206 -0
- gstaichi/types/quant.py +88 -0
- gstaichi/types/texture_type.py +85 -0
- gstaichi/types/utils.py +11 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
- gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
- gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
- gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
- gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
- gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
gstaichi/lang/misc.py
ADDED
@@ -0,0 +1,782 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import atexit
|
4
|
+
import os
|
5
|
+
import shutil
|
6
|
+
import tempfile
|
7
|
+
import warnings
|
8
|
+
from copy import deepcopy as _deepcopy
|
9
|
+
|
10
|
+
from gstaichi import _logging, _snode, _version_check
|
11
|
+
from gstaichi._lib import core as _ti_core
|
12
|
+
from gstaichi.lang import impl
|
13
|
+
from gstaichi.lang.expr import Expr
|
14
|
+
from gstaichi.lang.impl import axes, get_runtime
|
15
|
+
from gstaichi.profiler.kernel_profiler import get_default_kernel_profiler
|
16
|
+
from gstaichi.types.primitive_types import f32, f64, i32, i64
|
17
|
+
|
18
|
+
warnings.filterwarnings("once", category=DeprecationWarning, module="gstaichi")

# ----------------------

i = axes(0)
"""Axis 0. For multi-dimensional arrays it's the direction down the rows.
For a 1d array it's the direction along the array.
"""
# ----------------------

j = axes(1)
"""Axis 1. For multi-dimensional arrays it's the direction across the columns.
"""
# ----------------------

k = axes(2)
"""Axis 2. For arrays of dimension `d` >= 3, view each cell as an array of
dimension d-2; this is the first axis of that cell.
"""
# ----------------------

l = axes(3)
"""Axis 3. For arrays of dimension `d` >= 4, view each cell as an array of
dimension d-2; this is the second axis of that cell.
"""
# ----------------------

ij = axes(0, 1)
"""Axes (0, 1).
"""
# ----------------------

ik = axes(0, 2)
"""Axes (0, 2).
"""
# ----------------------

il = axes(0, 3)
"""Axes (0, 3).
"""
# ----------------------

jk = axes(1, 2)
"""Axes (1, 2).
"""
# ----------------------

jl = axes(1, 3)
"""Axes (1, 3).
"""
# ----------------------

kl = axes(2, 3)
"""Axes (2, 3).
"""
# ----------------------

ijk = axes(0, 1, 2)
"""Axes (0, 1, 2).
"""
# ----------------------

ijl = axes(0, 1, 3)
"""Axes (0, 1, 3).
"""
# ----------------------

ikl = axes(0, 2, 3)
"""Axes (0, 2, 3).
"""
# ----------------------

jkl = axes(1, 2, 3)
"""Axes (1, 2, 3).
"""
# ----------------------

ijkl = axes(0, 1, 2, 3)
"""Axes (0, 1, 2, 3).
"""
# ----------------------

# ----------------------
|
100
|
+
|
101
|
+
x86_64 = _ti_core.x64
"""The x64 CPU backend (alias of :const:`x64`).
"""
# ----------------------

x64 = _ti_core.x64
"""The x64 CPU backend.
"""
# ----------------------

arm64 = _ti_core.arm64
"""The ARM CPU backend.
"""
# ----------------------

cuda = _ti_core.cuda
"""The CUDA backend.
"""
# ----------------------

amdgpu = _ti_core.amdgpu
"""The AMDGPU backend.
"""
# ----------------------

metal = _ti_core.metal
"""The Apple Metal backend.
"""
# ----------------------

vulkan = _ti_core.vulkan
"""The Vulkan backend.
"""
# ----------------------

gpu = [cuda, metal, vulkan, amdgpu]
"""A fixed list of the GPU backends GsTaichi knows about: 'cuda', 'metal',
'vulkan', 'amdgpu'. The list itself is not filtered by what the current
system supports.

When this is passed as ``arch``, GsTaichi picks a matching GPU backend that is
available. If no GPU is detected, GsTaichi falls back to the CPU backend.
"""
# ----------------------

cpu = _ti_core.host_arch()
"""The CPU backend matching the host machine, as reported by
``_ti_core.host_arch()``. This is a single arch value, not a list.

When this is passed as ``arch``, GsTaichi runs on the host CPU backend.
"""
# ----------------------
|
152
|
+
|
153
|
+
|
154
|
+
def timeline_clear():
    """Forward to ``timeline_clear()`` on the active runtime's program.

    Returns:
        Whatever the program's ``timeline_clear()`` returns.
    """
    program = impl.get_runtime().prog
    return program.timeline_clear()
|
156
|
+
|
157
|
+
|
158
|
+
def timeline_save(fn):
    """Forward to ``timeline_save(fn)`` on the active runtime's program.

    Args:
        fn: Passed through unchanged; presumably the output file name — confirm
            against the C++ binding.

    Returns:
        Whatever the program's ``timeline_save`` returns.
    """
    program = impl.get_runtime().prog
    return program.timeline_save(fn)
|
160
|
+
|
161
|
+
|
162
|
+
extension = _ti_core.Extension
"""The GsTaichi extension enum type (``_ti_core.Extension``); its members are
the values accepted by :func:`is_extension_supported` as ``ext``.

The list of currently available extensions is ['sparse', 'quant', 'mesh',
'quant_basic', 'data64', 'adstack', 'bls', 'assertion', 'extfunc'].
"""
|
169
|
+
|
170
|
+
|
171
|
+
def is_extension_supported(arch, ext):
    """Report whether extension ``ext`` is available on backend ``arch``.

    Args:
        arch (gstaichi_python.Arch): The backend to query.
        ext (gstaichi_python.Extension): The extension to look up.

    Returns:
        bool: True if ``arch`` supports ``ext``.
    """
    supported = _ti_core.is_extension_supported(arch, ext)
    return supported
|
182
|
+
|
183
|
+
|
184
|
+
def reset():
    """Resets GsTaichi to its initial state.

    This will destroy all the allocated fields and kernels, and restore
    the runtime to its default configuration.

    Example::

        >>> a = ti.field(ti.i32, shape=())
        >>> a[None] = 1
        >>> print("before reset: ", a)
        before reset:  1
        >>>
        >>> ti.reset()
        >>> print("after reset: ", a)
        # will raise error because a is unavailable after reset.
    """
    global runtime
    impl.reset()
    # Rebind the module-level ``runtime`` name so it refers to whatever
    # impl.get_runtime() returns after the reset.
    runtime = impl.get_runtime()
|
203
|
+
|
204
|
+
|
205
|
+
class _EnvironmentConfigurator:
|
206
|
+
def __init__(self, kwargs, _cfg):
|
207
|
+
self.cfg = _cfg
|
208
|
+
self.kwargs = kwargs
|
209
|
+
self.keys = []
|
210
|
+
|
211
|
+
def add(self, key, _cast=None):
|
212
|
+
_cast = _cast or self.bool_int
|
213
|
+
|
214
|
+
self.keys.append(key)
|
215
|
+
|
216
|
+
# TI_OFFLINE_CACHE= : no effect
|
217
|
+
# TI_OFFLINE_CACHE=0 : False
|
218
|
+
# TI_OFFLINE_CACHE=1 : True
|
219
|
+
name = "TI_" + key.upper()
|
220
|
+
value = os.environ.get(name, "")
|
221
|
+
if key in self.kwargs:
|
222
|
+
self[key] = self.kwargs[key]
|
223
|
+
if value:
|
224
|
+
_ti_core.warn(f'Environment variable {name}={value} overridden by ti.init argument "{key}"')
|
225
|
+
del self.kwargs[key] # mark as recognized
|
226
|
+
elif value:
|
227
|
+
self[key] = _cast(value)
|
228
|
+
|
229
|
+
def __getitem__(self, key):
|
230
|
+
return getattr(self.cfg, key)
|
231
|
+
|
232
|
+
def __setitem__(self, key, value):
|
233
|
+
setattr(self.cfg, key, value)
|
234
|
+
|
235
|
+
@staticmethod
|
236
|
+
def bool_int(x):
|
237
|
+
return bool(int(x))
|
238
|
+
|
239
|
+
|
240
|
+
class _SpecialConfig:
|
241
|
+
# like CompileConfig in C++, this is the configurations that belong to other submodules
|
242
|
+
def __init__(self):
|
243
|
+
self.log_level = "info"
|
244
|
+
self.gdb_trigger = False
|
245
|
+
self.short_circuit_operators = True
|
246
|
+
self.print_full_traceback = False
|
247
|
+
self.unrolling_limit = 32
|
248
|
+
|
249
|
+
|
250
|
+
def prepare_sandbox():
    """Create a temporary sandbox directory that is removed at interpreter exit.

    The directory is created with the ``gstaichi-`` prefix and gets an empty
    ``runtime/`` subdirectory. It is intended to hold the gstaichi_python shared
    object or other misc. files (not enforced here).

    Returns:
        str: Path of the new temporary directory.
    """
    tmp_dir = tempfile.mkdtemp(prefix="gstaichi-")
    # ignore_errors: the directory may already have been removed by the time the
    # interpreter exits; that should not produce an atexit traceback.
    atexit.register(shutil.rmtree, tmp_dir, ignore_errors=True)
    print(f"[GsTaichi] preparing sandbox at {tmp_dir}")
    os.mkdir(os.path.join(tmp_dir, "runtime/"))
    return tmp_dir
|
260
|
+
|
261
|
+
|
262
|
+
def check_require_version(require_version):
    """Check that the installed GsTaichi version satisfies ``require_version``.

    ``require_version`` has the form ``<major>.<minor>[.<patch>][.<hash>]``. Everything from
    the first character that is neither a digit nor ``.`` onward (a hash, an ``rc`` suffix, ...)
    is ignored. The check passes when the installed major version equals the required major
    version and the installed (minor, patch) is >= the required (minor, patch). A missing
    patch counts as 0.

    Args:
        require_version (str): The required version string.

    Raises:
        Exception: If ``require_version`` cannot be parsed, or if the installed version
            does not satisfy it.
    """
    # Extract version number part (i.e. toss any revision / hash parts).
    version_number_str = require_version
    for c_idx, c in enumerate(require_version):
        if not (c.isdigit() or c == "."):
            version_number_str = require_version[:c_idx]
            break
    # A suffix such as ".abc123" leaves a dangling separator ("2.1.1."); drop it so the
    # documented <major>.<minor>.<patch>.<hash> form parses instead of being rejected.
    version_number_str = version_number_str.rstrip(".")
    # Get required version.
    try:
        version_number_tuple = tuple(int(n) for n in version_number_str.split("."))
        major = version_number_tuple[0]
        minor = version_number_tuple[1]
    except (ValueError, IndexError):
        # ValueError: empty / non-numeric component; IndexError: fewer than two components.
        raise Exception(
            "The require_version should be formatted following PEP 440, "
            "and includes major, minor, and patch number, "
            "e.g., major.minor.patch."
        ) from None
    patch = version_number_tuple[2] if len(version_number_tuple) > 2 else 0

    # Get installed version.
    installed_major = int(_ti_core.get_version_major())
    installed_minor = int(_ti_core.get_version_minor())
    installed_patch = int(_ti_core.get_version_patch())

    # Same major, and (minor, patch) not newer than what is installed.
    match = major == installed_major and (minor, patch) <= (installed_minor, installed_patch)

    if not match:
        raise Exception(
            f"GsTaichi version mismatch. Required version >= {major}.{minor}.{patch}, installed version = {_ti_core.get_version_string()}."
        )
|
301
|
+
|
302
|
+
|
303
|
+
def init(
    arch=None,
    default_fp=None,
    default_ip=None,
    _test_mode: bool = False,
    enable_fallback: bool = True,
    require_version: str | None = None,
    src_ll_cache: bool = True,
    **kwargs,
):
    """Initializes the GsTaichi runtime.

    This should always be the entry point of your GsTaichi program. Most
    importantly, it sets the backend used throughout the program.

    Args:
        arch: Backend to use. This is usually :const:`~gstaichi.lang.cpu` or :const:`~gstaichi.lang.gpu`.
            If the ``TI_ARCH`` environment variable is set, it takes precedence over this argument.
        default_fp (Optional[type]): Default floating-point type. When not given, ``TI_DEFAULT_FP``
            (``32`` or ``64``) is used if set.
        default_ip (Optional[type]): Default integral type. When not given, ``TI_DEFAULT_IP``
            (``32`` or ``64``) is used if set.
        _test_mode (bool): Internal. When True, stop after configuration and return the
            submodule config object instead of creating a program.
        enable_fallback (bool): Passed to ``adaptive_arch_select`` together with ``arch``;
            presumably allows falling back to another backend when ``arch`` is unavailable.
        require_version: A version string. If given, an exception is raised when the installed
            version does not satisfy it.
        src_ll_cache: enable SRC-LL-CACHE, which will accelerate loading from cache, across all architectures,
            for pure kernels (i.e. kernels declared as @ti.pure)
        **kwargs: GsTaichi provides highly customizable compilation through
            ``kwargs``, which allows for fine grained control of GsTaichi compiler
            behavior. Below we list some of the most frequently used ones. For a
            complete list, please check out
            https://github.com/taichi-dev/gstaichi/blob/master/gstaichi/program/compile_config.h.

            * ``cpu_max_num_threads`` (int): Sets the number of threads used by the CPU thread pool.
            * ``debug`` (bool): Enables the debug mode, under which GsTaichi does a few more things like boundary checks.
            * ``print_ir`` (bool): Prints the CHI IR of the GsTaichi kernels.
            * ``offline_cache`` (bool): Enables offline cache of the compiled kernels. Defaults to True. When this is enabled GsTaichi will cache compiled kernels on your local disk to accelerate future calls.
            * ``random_seed`` (int): Sets the seed of the random generator. The default is 0.

            Python-side options are also accepted: ``log_level`` (str), ``gdb_trigger``,
            ``short_circuit_operators``, ``print_full_traceback`` (bool) and
            ``unrolling_limit`` (int, default 32). Each option may alternatively be given
            through a ``TI_<OPTION>`` environment variable; explicit arguments win.

    Returns:
        The submodule configuration object when ``_test_mode`` is True, otherwise None.

    Raises:
        KeyError: If ``default_up`` or any unrecognized keyword argument is passed.
        ValueError: If ``TI_DEFAULT_FP`` / ``TI_DEFAULT_IP`` is set to something other than 32 or 64
            and the matching argument is not given.
    """
    # Check version for users every 7 days if not disabled by users.
    _version_check.start_version_check_thread()

    # FIXME(https://github.com/taichi-dev/gstaichi/issues/4811): save the current working directory since it may be
    # changed by the Vulkan backend initialization on OS X.
    current_dir = os.getcwd()

    # Check if installed version meets the requirements.
    if require_version is not None:
        check_require_version(require_version)

    if "default_up" in kwargs:
        raise KeyError("'default_up' is always the unsigned type of 'default_ip'. Please set 'default_ip' instead.")
    # Make a deepcopy in case these args reference to items from ti.cfg, which are
    # actually references. If no copy is made and the args are indeed references,
    # ti.reset() could override the args to their default values.
    default_fp = _deepcopy(default_fp)
    default_ip = _deepcopy(default_ip)
    kwargs = _deepcopy(kwargs)
    reset()

    cfg = impl.default_cfg()
    cfg.offline_cache = True  # Enable offline cache in frontend instead of C++ side

    spec_cfg = _SpecialConfig()
    env_comp = _EnvironmentConfigurator(kwargs, cfg)
    env_spec = _EnvironmentConfigurator(kwargs, spec_cfg)

    # configure default_fp/ip:
    # TODO: move these stuff to _SpecialConfig too:
    env_default_fp = os.environ.get("TI_DEFAULT_FP")
    if env_default_fp:
        if default_fp is not None:
            _ti_core.warn(
                f'Environment variable TI_DEFAULT_FP={env_default_fp} overridden by ti.init argument "default_fp"'
            )
        elif env_default_fp == "32":
            default_fp = f32
        elif env_default_fp == "64":
            default_fp = f64
        else:
            raise ValueError(f"Invalid TI_DEFAULT_FP={env_default_fp}, should be 32 or 64")

    env_default_ip = os.environ.get("TI_DEFAULT_IP")
    if env_default_ip:
        if default_ip is not None:
            _ti_core.warn(
                f'Environment variable TI_DEFAULT_IP={env_default_ip} overridden by ti.init argument "default_ip"'
            )
        elif env_default_ip == "32":
            default_ip = i32
        elif env_default_ip == "64":
            default_ip = i64
        else:
            raise ValueError(f"Invalid TI_DEFAULT_IP={env_default_ip}, should be 32 or 64")

    if default_fp is not None:
        impl.get_runtime().set_default_fp(default_fp)
    if default_ip is not None:
        impl.get_runtime().set_default_ip(default_ip)

    # submodule configurations (spec_cfg):
    env_spec.add("log_level", str)
    env_spec.add("gdb_trigger")
    env_spec.add("short_circuit_operators")
    env_spec.add("print_full_traceback")
    # unrolling_limit is an integer count: parse TI_UNROLLING_LIMIT with int. The default
    # bool_int cast would turn e.g. "64" into True (i.e. 1).
    env_spec.add("unrolling_limit", int)

    # compiler configurations (ti.cfg):
    for key in dir(cfg):
        if key in ["arch", "default_fp", "default_ip"]:
            continue
        _cast = type(getattr(cfg, key))
        if _cast is bool:
            _cast = None  # use the default bool_int cast for booleans
        env_comp.add(key, _cast)

    # Every recognized key was deleted from kwargs by the configurators above.
    unexpected_keys = kwargs.keys()

    if unexpected_keys:
        raise KeyError(f'Unrecognized keyword argument(s) for ti.init: {", ".join(unexpected_keys)}')

    # dispatch configurations that are not in ti.cfg:
    if not _test_mode:
        _ti_core.set_core_trigger_gdb_when_crash(spec_cfg.gdb_trigger)
        impl.get_runtime().short_circuit_operators = spec_cfg.short_circuit_operators
        impl.get_runtime().print_full_traceback = spec_cfg.print_full_traceback
        impl.get_runtime().unrolling_limit = spec_cfg.unrolling_limit
        impl.get_runtime().src_ll_cache = src_ll_cache
        _logging.set_logging_level(spec_cfg.log_level.lower())

    # select arch (backend):
    env_arch = os.environ.get("TI_ARCH")
    if env_arch is not None:
        _logging.info(f"Following TI_ARCH setting up for arch={env_arch}")
        arch = _ti_core.arch_from_name(env_arch)
    cfg.arch = adaptive_arch_select(arch, enable_fallback)
    print(f"[GsTaichi] Starting on arch={_ti_core.arch_name(cfg.arch)}")

    if _test_mode:
        return spec_cfg

    get_default_kernel_profiler().set_kernel_profiler_mode(cfg.kernel_profiler)

    # create a new program:
    impl.get_runtime().create_program()

    _logging.trace("Materializing runtime...")
    impl.get_runtime().prog.materialize_runtime()

    impl._root_fb = _snode.FieldsBuilder()

    if cfg.debug:
        impl.get_runtime()._register_signal_handlers()

    # Recover the current working directory (https://github.com/taichi-dev/gstaichi/issues/4811)
    os.chdir(current_dir)
    return None
|
455
|
+
|
456
|
+
|
457
|
+
def no_activate(*args):
    """Register each argument's SNode with ``no_activate`` on the kernel being compiled.

    Must be called while a ``_ti_core.KernelCxx`` is the compiling callable; this is
    checked with an ``assert``.

    Args:
        *args: Objects exposing ``_snode.ptr`` (assumed to be fields — confirm at call sites).
    """
    kernel = get_runtime().compiling_callable
    assert isinstance(kernel, _ti_core.KernelCxx)
    for field in args:
        kernel.no_activate(field._snode.ptr)
|
462
|
+
|
463
|
+
|
464
|
+
def block_local(*args):
    """Hints GsTaichi to cache the fields and to enable the BLS optimization.

    Please visit https://docs.taichi-lang.org/docs/performance
    for how BLS is used.

    If the current ``opt_level`` is 0, it is raised to 1 (with a warning), since the BLS
    analysis needs it. Every field member of every argument is then tagged with
    ``SNodeAccessFlag.block_local`` through the current AST builder.

    Args:
        *args (List[Field]): A list of sparse GsTaichi fields.
    """
    current = impl.current_cfg()
    if current.opt_level == 0:
        _logging.warn("""opt_level = 1 is enforced to enable bls analysis.""")
        current.opt_level = 1
    for field in args:
        for member in field._get_field_members():
            builder = get_runtime().compiling_callable.ast_builder()
            builder.insert_snode_access_flag(_ti_core.SNodeAccessFlag.block_local, member.ptr)
|
481
|
+
|
482
|
+
|
483
|
+
def mesh_local(*args):
    """Hints the compiler to cache the mesh attributes and to enable the mesh BLS optimization.

    Only available for backends supporting `ti.extension.mesh`, and meant to be used with a
    mesh-for loop. Each field member of every argument is tagged with
    ``SNodeAccessFlag.mesh_local`` through the current AST builder.

    Related to https://github.com/taichi-dev/gstaichi/issues/3608

    Args:
        *args (List[Attribute]): A list of mesh attributes or fields accessed as attributes.

    Examples::

        # instantiate model
        mesh_builder = ti.Mesh.tri()
        mesh_builder.verts.place({
            'x' : ti.f32,
            'y' : ti.f32
        })
        model = mesh_builder.build(meta)

        @ti.kernel
        def foo():
            # hint the compiler to cache mesh vertex attribute `x` and `y`.
            ti.mesh_local(model.verts.x, model.verts.y)
            for v0 in model.verts: # mesh-for loop
                for v1 in v0.verts:
                    v0.x += v1.y
    """
    for attribute in args:
        members = attribute._get_field_members()
        for member in members:
            flag = _ti_core.SNodeAccessFlag.mesh_local
            get_runtime().compiling_callable.ast_builder().insert_snode_access_flag(flag, member.ptr)
|
516
|
+
|
517
|
+
|
518
|
+
def cache_read_only(*args):
    """Tag every field member of ``args`` as read-only for the kernel being compiled.

    Each member returned by ``_get_field_members()`` gets ``SNodeAccessFlag.read_only``
    inserted through the current AST builder, presumably so the backend can cache those
    reads — confirm against the compiler passes.

    Args:
        *args: Objects exposing ``_get_field_members()`` (e.g. fields).
    """
    members = (member for field in args for member in field._get_field_members())
    for member in members:
        read_only = _ti_core.SNodeAccessFlag.read_only
        get_runtime().compiling_callable.ast_builder().insert_snode_access_flag(read_only, member.ptr)
|
524
|
+
|
525
|
+
|
526
|
+
def assume_in_range(val, base, low, high):
    """Hints the compiler that a value is within a specified range,
    so that the compiler can perform scratchpad optimization, and returns the
    value untouched.

    The assumed range is `[base + low, base + high)`.

    Args:
        val (Number): The input value.
        base (Number): The base point for the range interval.
        low (Number): The lower offset relative to `base` (included).
        high (Number): The higher offset relative to `base` (excluded).

    Returns:
        The value of `val`, unchanged (as the expression built by
        ``_ti_core.expr_assume_in_range``).

    Example::

        >>> # hint the compiler that x is in range [8, 12).
        >>> x = ti.assume_in_range(x, 10, -2, 2)
        >>> x
        10
    """
    return _ti_core.expr_assume_in_range(
        Expr(val).ptr, Expr(base).ptr, low, high, _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
    )
|
553
|
+
|
554
|
+
|
555
|
+
def loop_unique(val, covers=None):
    """Wrap ``val`` in a loop-unique expression covering the given SNodes.

    Args:
        val: The value to wrap; converted with ``Expr``.
        covers: None, a single object, or a list/tuple of objects. ``Expr`` items contribute
            ``.snode.ptr``; any other item contributes ``.ptr``.

    Returns:
        The result of ``_ti_core.expr_loop_unique``.
    """
    if covers is None:
        cover_items = []
    elif isinstance(covers, (list, tuple)):
        cover_items = covers
    else:
        cover_items = [covers]

    cover_ptrs = []
    for item in cover_items:
        cover_ptrs.append(item.snode.ptr if isinstance(item, Expr) else item.ptr)

    val_ptr = Expr(val).ptr
    debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
    return _ti_core.expr_loop_unique(val_ptr, cover_ptrs, debug_info)
|
564
|
+
|
565
|
+
|
566
|
+
def _parallelize(v):
    """Set the number of CPU threads for the next loop.

    ``v == 1`` additionally asks the AST builder to strictly serialize the loop.
    """
    get_runtime().compiling_callable.ast_builder().parallelize(v)
    if v != 1:
        return
    get_runtime().compiling_callable.ast_builder().strictly_serialize()
|
571
|
+
|
572
|
+
|
573
|
+
def _serialize():
    """Run the next loop on a single thread (shorthand for ``_parallelize(1)``)."""
    _parallelize(v=1)
|
576
|
+
|
577
|
+
|
578
|
+
def _block_dim(dim):
    """Request `dim` threads per block for the next loop."""
    builder = get_runtime().compiling_callable.ast_builder()
    builder.block_dim(dim)
|
581
|
+
|
582
|
+
|
583
|
+
def _block_dim_adaptive(block_dim_adaptive):
    """Enable or disable letting the backend choose block_dim adaptively.

    Only the CPU backend honours this switch. On any other arch a warning is
    logged and the program config is left untouched.
    """
    config = get_runtime().prog.config()
    if config.arch == cpu:
        config.cpu_block_dim_adaptive = block_dim_adaptive
        return
    _logging.warn("Adaptive block_dim is supported on CPU backend only")
|
589
|
+
|
590
|
+
|
591
|
+
def _bit_vectorize():
    """Turn on bit vectorization for struct fors over quant_arrays."""
    builder = get_runtime().compiling_callable.ast_builder()
    builder.bit_vectorize()
|
594
|
+
|
595
|
+
|
596
|
+
def loop_config(
    *,
    block_dim=None,
    serialize=False,
    parallelize=None,
    block_dim_adaptive=True,
    bit_vectorize=False,
):
    """Set compiler directives for the next loop.

    Args:
        block_dim (int): Number of threads per block on GPU backends.
        serialize (bool): Run the loop serially. `serialize=True` is
            equivalent to `parallelize=1` and takes precedence over
            `parallelize` when both are given.
        parallelize (int): Number of threads to use on the CPU backend.
        block_dim_adaptive (bool): Whether backends may set block_dim
            adaptively. Enabled by default; disabling it only takes effect
            on the CPU backend (other backends log a warning).
        bit_vectorize (bool): Whether to bit-vectorize struct fors over
            quant_arrays.

    Examples::

        @ti.kernel
        def break_in_serial_for() -> ti.i32:
            a = 0
            ti.loop_config(serialize=True)
            for i in range(100):  # This loop runs serially
                a += i
                if i == 10:
                    break
            return a

        break_in_serial_for()  # returns 55

        n = 128
        val = ti.field(ti.i32, shape=n)
        @ti.kernel
        def fill():
            ti.loop_config(parallelize=8, block_dim=16)
            # On the CPU backend, 8 threads are used to run this loop.
            # On the CUDA backend, each block has 16 threads.
            for i in range(n):
                val[i] = i

        u1 = ti.types.quant.int(bits=1, signed=False)
        x = ti.field(dtype=u1)
        y = ti.field(dtype=u1)
        cell = ti.root.dense(ti.ij, (128, 4))
        cell.quant_array(ti.j, 32).place(x)
        cell.quant_array(ti.j, 32).place(y)
        @ti.kernel
        def copy():
            ti.loop_config(bit_vectorize=True)
            # 32 bits, instead of 1 bit, are copied at a time
            for i, j in x:
                y[i, j] = x[i, j]
    """
    if block_dim is not None:
        _block_dim(block_dim)

    # `serialize` wins over an explicit thread count.
    thread_count = 1 if serialize else parallelize
    if thread_count is not None:
        _parallelize(thread_count)

    if not block_dim_adaptive:
        _block_dim_adaptive(block_dim_adaptive)

    if bit_vectorize:
        _bit_vectorize()
|
663
|
+
|
664
|
+
|
665
|
+
def global_thread_idx():
    """Returns the global thread id of this running thread,
    only available for cpu and cuda backends.

    For cpu backends this is equal to the cpu thread id.
    For cuda backends this is equal to `block_id * block_dim + thread_id`.

    This must be called while a kernel is being compiled: it inserts the
    thread-index expression through the current callable's AST builder.

    Example::

        >>> f = ti.field(ti.f32, shape=(16, 16))
        >>> @ti.kernel
        ... def test():
        ...     for i in ti.grouped(f):
        ...         print(ti.global_thread_idx())
        ...
        >>> test()
    """
    return impl.get_runtime().compiling_callable.ast_builder().insert_thread_idx_expr()
|
683
|
+
|
684
|
+
|
685
|
+
def mesh_patch_idx():
    """Returns the internal mesh patch id of this running thread,
    only available for backends supporting `ti.extension.mesh` and to use within mesh-for loop.

    Like other kernel-scope helpers, it inserts an expression through the
    current callable's AST builder, so it must be called during kernel
    compilation.

    Related to https://github.com/taichi-dev/taichi/issues/3608
    """
    return (
        impl.get_runtime()
        .compiling_callable.ast_builder()
        .insert_patch_idx_expr(_ti_core.DebugInfo(impl.get_runtime().get_current_src_info()))
    )
|
696
|
+
|
697
|
+
|
698
|
+
def is_arch_supported(arch):
    """Checks whether an arch is supported on the machine.

    Args:
        arch (gstaichi_python.Arch): Specified arch.

    Returns:
        bool: Whether `arch` is supported on the machine. Archs without an
        entry in the detection table are reported as unsupported, and so is
        any arch whose detection probe raises (a warning is emitted).
    """

    arch_table = {
        cuda: _ti_core.with_cuda,
        amdgpu: _ti_core.with_amdgpu,
        metal: _ti_core.with_metal,
        vulkan: _ti_core.with_vulkan,
        cpu: lambda: True,
    }
    with_arch = arch_table.get(arch, lambda: False)
    try:
        return with_arch()
    except Exception as e:
        # Detection is deliberately best-effort: a failing probe must not
        # crash the caller, so report it and treat the arch as unsupported.
        arch_name = _ti_core.arch_name(arch)
        _ti_core.warn(
            f"{e.__class__.__name__}: '{e}' occurred when detecting "
            f"{arch_name}, consider adding `TI_ENABLE_{arch_name.upper()}=0` "
            f"to environment variables to suppress this warning message."
        )
        return False
|
726
|
+
|
727
|
+
|
728
|
+
def adaptive_arch_select(arch, enable_fallback):
    """Pick the first supported arch among the requested candidates.

    Args:
        arch: None, a single arch, or a list/tuple of candidate archs that
            are probed in order with `is_arch_supported`.
        enable_fallback (bool): When no candidate is supported, fall back to
            `cpu` with a warning if True; otherwise raise.

    Returns:
        The first supported candidate, or `cpu` when `arch` is None or when
        falling back.

    Raises:
        RuntimeError: If no candidate is supported and `enable_fallback` is
            False.
    """
    if arch is None:
        return cpu
    candidates = arch if isinstance(arch, (list, tuple)) else [arch]

    # Probes stop at the first supported candidate. None never passes
    # `is_arch_supported` (it is not a key in its table), so it is a safe
    # "nothing found" marker here.
    chosen = next((a for a in candidates if is_arch_supported(a)), None)
    if chosen is not None:
        return chosen

    if not enable_fallback:
        raise RuntimeError(f"Arch={candidates} is not supported")
    _logging.warn(f"Arch={candidates} is not supported, falling back to CPU")
    return cpu
|
740
|
+
|
741
|
+
|
742
|
+
def get_host_arch_list():
    """Return a one-element list holding the arch from `_ti_core.host_arch()`."""
    host = _ti_core.host_arch()
    return [host]
|
744
|
+
|
745
|
+
|
746
|
+
__all__ = [
    # Index-axis names (e.g. `ti.ij` / `ti.j` as used with `ti.root.dense`
    # and `quant_array` in the `loop_config` examples).
    "i", "ij", "ijk", "ijkl", "ijl", "ik", "ikl", "il",
    "j", "jk", "jkl", "jl",
    "k", "kl",
    "l",
    # Architecture names (cpu/cuda/amdgpu/metal/vulkan are the keys probed
    # by `is_arch_supported`).
    "x86_64", "x64", "arm64",
    "cpu", "cuda", "amdgpu", "gpu", "metal", "vulkan",
    "extension",
    # Loop directives and kernel/runtime helpers.
    "loop_config", "global_thread_idx", "assume_in_range",
    "block_local", "cache_read_only", "init", "mesh_local",
    "no_activate", "reset", "mesh_patch_idx",
]
|