gstaichi 2.1.1__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/__init__.py +40 -0
- gstaichi/_funcs.py +706 -0
- gstaichi/_kernels.py +420 -0
- gstaichi/_lib/__init__.py +3 -0
- gstaichi/_lib/core/__init__.py +0 -0
- gstaichi/_lib/core/gstaichi_python.cpython-311-darwin.so +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
- gstaichi/_lib/core/py.typed +0 -0
- gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
- gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
- gstaichi/_lib/utils.py +243 -0
- gstaichi/_logging.py +131 -0
- gstaichi/_snode/__init__.py +5 -0
- gstaichi/_snode/fields_builder.py +187 -0
- gstaichi/_snode/snode_tree.py +34 -0
- gstaichi/_test_tools/__init__.py +18 -0
- gstaichi/_test_tools/dataclass_test_tools.py +36 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_test_tools/textwrap2.py +6 -0
- gstaichi/_version.py +1 -0
- gstaichi/_version_check.py +100 -0
- gstaichi/ad/__init__.py +3 -0
- gstaichi/ad/_ad.py +530 -0
- gstaichi/algorithms/__init__.py +3 -0
- gstaichi/algorithms/_algorithms.py +117 -0
- gstaichi/assets/.git +1 -0
- gstaichi/assets/Go-Regular.ttf +0 -0
- gstaichi/assets/static/imgs/ti_gallery.png +0 -0
- gstaichi/examples/lcg_python.py +26 -0
- gstaichi/examples/lcg_taichi.py +34 -0
- gstaichi/examples/minimal.py +28 -0
- gstaichi/experimental.py +16 -0
- gstaichi/lang/__init__.py +50 -0
- gstaichi/lang/_dataclass_util.py +31 -0
- gstaichi/lang/_fast_caching/__init__.py +3 -0
- gstaichi/lang/_fast_caching/args_hasher.py +110 -0
- gstaichi/lang/_fast_caching/config_hasher.py +30 -0
- gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
- gstaichi/lang/_fast_caching/function_hasher.py +57 -0
- gstaichi/lang/_fast_caching/hash_utils.py +11 -0
- gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
- gstaichi/lang/_fast_caching/src_hasher.py +75 -0
- gstaichi/lang/_kernel_impl_dataclass.py +212 -0
- gstaichi/lang/_ndarray.py +352 -0
- gstaichi/lang/_ndrange.py +152 -0
- gstaichi/lang/_template_mapper.py +195 -0
- gstaichi/lang/_texture.py +172 -0
- gstaichi/lang/_wrap_inspect.py +215 -0
- gstaichi/lang/any_array.py +99 -0
- gstaichi/lang/ast/__init__.py +5 -0
- gstaichi/lang/ast/ast_transformer.py +1323 -0
- gstaichi/lang/ast/ast_transformer_utils.py +346 -0
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
- gstaichi/lang/ast/checkers.py +106 -0
- gstaichi/lang/ast/symbol_resolver.py +57 -0
- gstaichi/lang/ast/transform.py +9 -0
- gstaichi/lang/common_ops.py +310 -0
- gstaichi/lang/exception.py +80 -0
- gstaichi/lang/expr.py +180 -0
- gstaichi/lang/field.py +428 -0
- gstaichi/lang/impl.py +1245 -0
- gstaichi/lang/kernel_arguments.py +155 -0
- gstaichi/lang/kernel_impl.py +1341 -0
- gstaichi/lang/matrix.py +1835 -0
- gstaichi/lang/matrix_ops.py +341 -0
- gstaichi/lang/matrix_ops_utils.py +190 -0
- gstaichi/lang/mesh.py +687 -0
- gstaichi/lang/misc.py +780 -0
- gstaichi/lang/ops.py +1494 -0
- gstaichi/lang/runtime_ops.py +13 -0
- gstaichi/lang/shell.py +35 -0
- gstaichi/lang/simt/__init__.py +5 -0
- gstaichi/lang/simt/block.py +94 -0
- gstaichi/lang/simt/grid.py +7 -0
- gstaichi/lang/simt/subgroup.py +191 -0
- gstaichi/lang/simt/warp.py +96 -0
- gstaichi/lang/snode.py +489 -0
- gstaichi/lang/source_builder.py +150 -0
- gstaichi/lang/struct.py +810 -0
- gstaichi/lang/util.py +312 -0
- gstaichi/linalg/__init__.py +8 -0
- gstaichi/linalg/matrixfree_cg.py +310 -0
- gstaichi/linalg/sparse_cg.py +59 -0
- gstaichi/linalg/sparse_matrix.py +303 -0
- gstaichi/linalg/sparse_solver.py +123 -0
- gstaichi/math/__init__.py +11 -0
- gstaichi/math/_complex.py +205 -0
- gstaichi/math/mathimpl.py +886 -0
- gstaichi/profiler/__init__.py +6 -0
- gstaichi/profiler/kernel_metrics.py +260 -0
- gstaichi/profiler/kernel_profiler.py +586 -0
- gstaichi/profiler/memory_profiler.py +15 -0
- gstaichi/profiler/scoped_profiler.py +36 -0
- gstaichi/sparse/__init__.py +3 -0
- gstaichi/sparse/_sparse_grid.py +77 -0
- gstaichi/tools/__init__.py +12 -0
- gstaichi/tools/diagnose.py +117 -0
- gstaichi/tools/np2ply.py +364 -0
- gstaichi/tools/vtk.py +38 -0
- gstaichi/types/__init__.py +19 -0
- gstaichi/types/annotations.py +52 -0
- gstaichi/types/compound_types.py +71 -0
- gstaichi/types/enums.py +49 -0
- gstaichi/types/ndarray_type.py +169 -0
- gstaichi/types/primitive_types.py +206 -0
- gstaichi/types/quant.py +88 -0
- gstaichi/types/texture_type.py +85 -0
- gstaichi/types/utils.py +11 -0
- gstaichi-2.1.1.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-2.1.1.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-2.1.1.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-2.1.1.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-2.1.1.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-2.1.1.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-2.1.1.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv.h +2568 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv.hpp +2579 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
- gstaichi-2.1.1.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-2.1.1.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-2.1.1.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-2.1.1.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-2.1.1.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
- gstaichi-2.1.1.dist-info/METADATA +106 -0
- gstaichi-2.1.1.dist-info/RECORD +178 -0
- gstaichi-2.1.1.dist-info/WHEEL +5 -0
- gstaichi-2.1.1.dist-info/licenses/LICENSE +201 -0
- gstaichi-2.1.1.dist-info/top_level.txt +1 -0
File without changes
|
Binary file
|
Binary file
|
gstaichi/_lib/utils.py
ADDED
@@ -0,0 +1,243 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import os
|
4
|
+
import platform
|
5
|
+
import re
|
6
|
+
import sys
|
7
|
+
import warnings
|
8
|
+
|
9
|
+
from colorama import Fore, Style
|
10
|
+
|
11
|
+
# Refuse to run on interpreters older than Python 3.6. Comparing the whole
# version tuple (rather than major and minor separately) keeps the check
# correct for any future major version.
# NOTE: this module also uses f-strings, which need 3.6+ just to compile, so
# the message below deliberately avoids them.
if sys.version_info < (3, 6):
    raise RuntimeError(
        "\nPlease restart with Python 3.6+\n"
        "Current Python version: {}".format(".".join(str(part) for part in sys.version_info[:3]))
    )
|
16
|
+
|
17
|
+
|
18
|
+
def in_docker():
    """Tell whether the ``TI_IN_DOCKER`` environment variable is set to a non-empty value."""
    return bool(os.getenv("TI_IN_DOCKER"))
|
22
|
+
|
23
|
+
|
24
|
+
def get_os_name():
    """Return a short tag for the host operating system.

    Returns:
        str: ``"osx"``, ``"win"``, ``"linux"``, or ``"unix"`` for BSD variants.

    Raises:
        AssertionError: If ``platform.platform()`` matches none of the known
            prefixes. The error is raised explicitly (not via ``assert``) so
            the check still fires under ``python -O``.
    """
    name = platform.platform()
    lowered = name.lower()
    # In python 3.8, platform.platform() uses mac_ver() on macOS and
    # returns 'macOS-XXXX' instead of 'Darwin-XXXX'.
    if lowered.startswith(("darwin", "macos")):
        return "osx"
    if lowered.startswith("windows"):
        return "win"
    if lowered.startswith("linux"):
        return "linux"
    if "bsd" in lowered:
        return "unix"
    raise AssertionError(f"Unknown platform name {name}")
|
37
|
+
|
38
|
+
|
39
|
+
def import_ti_python_core():
    """Import the compiled ``gstaichi_python`` extension and point it at its runtime files.

    On non-Windows hosts the interpreter's dlopen flags are switched to
    ``2 | 8`` for the import and restored afterwards. They are restored even
    when the import fails, so a failed import no longer leaves the process
    with modified flags. On Windows the directory containing this file is
    appended to ``PATH`` first.

    Returns:
        module: The ``gstaichi_python`` extension, after ``set_lib_dir`` has
        been called with ``<package_root>/_lib/runtime``.

    Raises:
        Exception: Anything raised by the import is re-raised with its context
            suppressed. For ``ImportError`` a hint is printed first. On
            Windows, a Visual C++ Redistributable pointer is also appended to
            the message.
    """
    is_windows = get_os_name() == "win"
    if not is_windows:
        # pylint: disable=E1101
        old_flags = sys.getdlopenflags()
        # RTLD_NOW | RTLD_DEEPBIND (Linux values).
        # NOTE(review): on macOS <dlfcn.h> defines 8 as RTLD_GLOBAL (there is
        # no RTLD_DEEPBIND); confirm that is the intended flag there.
        sys.setdlopenflags(2 | 8)
    else:
        pyddir = os.path.dirname(os.path.realpath(__file__))
        os.environ["PATH"] += os.pathsep + pyddir
    try:
        from gstaichi._lib.core import gstaichi_python as core  # pylint: disable=C0415
    except Exception as e:
        if isinstance(e, ImportError):
            print(
                Fore.YELLOW + "Share object gstaichi_python import failed, "
                "check this page for possible solutions:\n"
                "https://docs.taichi-lang.org/docs/install" + Fore.RESET
            )
            if is_windows:
                # pylint: disable=E1101
                e.msg += "\nConsider installing Microsoft Visual C++ Redistributable: https://aka.ms/vs/16/release/vc_redist.x64.exe"
        raise e from None
    finally:
        # Restore the caller's flags whether or not the import succeeded.
        if not is_windows:
            sys.setdlopenflags(old_flags)  # pylint: disable=E1101

    lib_dir = os.path.join(package_root, "_lib", "runtime")
    core.set_lib_dir(locale_encode(lib_dir))
    return core
|
66
|
+
|
67
|
+
|
68
|
+
def locale_encode(path):
    """Encode ``path`` to bytes using the first encoding that works.

    Encodings are tried in this order:

    1. the encoding reported by ``locale.getdefaultlocale()``;
    2. ``sys.getfilesystemencoding()``;
    3. ``str.encode()``'s default (UTF-8).

    Returns:
        bytes | str: The encoded path, or ``path`` unchanged if every
        attempt raised ``UnicodeEncodeError``.
    """
    import locale  # pylint: disable=C0415

    try:
        return path.encode(locale.getdefaultlocale()[1])
    except (TypeError, ValueError, LookupError):
        # TypeError: the locale reports no encoding (None).
        # ValueError: getdefaultlocale() cannot parse the locale, or the
        #   encode fails (UnicodeEncodeError is a ValueError subclass).
        # LookupError: the reported codec name is unknown to Python.
        pass
    try:
        return path.encode(sys.getfilesystemencoding())
    except UnicodeEncodeError:
        pass
    try:
        return path.encode()
    except UnicodeEncodeError:
        return path
|
81
|
+
|
82
|
+
|
83
|
+
def is_ci():
    """Report whether ``TI_CI`` is exactly ``"1"`` in the environment."""
    ci_flag = os.getenv("TI_CI")
    return ci_flag == "1"
|
85
|
+
|
86
|
+
|
87
|
+
# Root directory of the installed package: the parent of the directory that
# contains this file, with symlinks resolved. (The previous single-argument
# os.path.join() around this expression was a no-op and has been dropped.)
package_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
88
|
+
|
89
|
+
|
90
|
+
def get_core_shared_object():
    """Return the path ``<package_root>/_lib/libgstaichi_python.so``.

    NOTE(review): the wheel's file list places the extension at
    ``_lib/core/gstaichi_python.cpython-311-darwin.so``; confirm whether this
    path is still used by any caller.
    """
    return os.path.join(package_root, "_lib", "libgstaichi_python.so")
|
93
|
+
|
94
|
+
|
95
|
+
def _print_with_style(_style, *args, **kwargs):
    """Print ``args`` wrapped between ``_style`` and ``Style.RESET_ALL``.

    The style prefix and the reset sequence go to the same stream as the
    message itself (``kwargs["file"]`` if given, otherwise ``sys.stdout``).
    A message redirected with ``file=`` therefore stays coloured, and stdout
    does not receive stray escape codes.
    """
    stream = kwargs.get("file")
    flush = kwargs.get("flush", False)
    print(_style, end="", file=stream)
    print(*args, **kwargs)
    print(Style.RESET_ALL, end="", file=stream, flush=flush)


def print_red_bold(*args, **kwargs):
    """Print ``args`` in bright red; accepts the same keywords as ``print``."""
    _print_with_style(Fore.RED + Style.BRIGHT, *args, **kwargs)


def print_yellow_bold(*args, **kwargs):
    """Print ``args`` in bright yellow; accepts the same keywords as ``print``."""
    _print_with_style(Fore.YELLOW + Style.BRIGHT, *args, **kwargs)
|
105
|
+
|
106
|
+
|
107
|
+
def check_exists(src):
    """Raise ``FileNotFoundError`` unless the path ``src`` exists."""
    if os.path.exists(src):
        return
    message = f'File "{src}" not exist. Installation corrupted or build incomplete?'
    raise FileNotFoundError(message)
|
110
|
+
|
111
|
+
|
112
|
+
# Import the native extension once, at module import time. The rest of this
# module reaches it through the ``ti_python_core`` name.
ti_python_core = import_ti_python_core()

# Tell the core where the Python package is installed.
ti_python_core.set_python_package_dir(package_root)

# Optional log-level override from the environment. When TI_LOG_LEVEL is
# unset or empty, set_logging_level() is not called at all.
log_level = os.environ.get("TI_LOG_LEVEL", "")
if log_level:
    ti_python_core.set_logging_level(log_level)
|
119
|
+
|
120
|
+
|
121
|
+
def get_dll_name(name):
    """Return the platform-specific file name of the ``gstaichi_<name>`` shared library.

    Raises:
        Exception: If the host OS tag has no known library naming scheme.
    """
    os_name = get_os_name()
    name_for_os = {
        "linux": f"libgstaichi_{name}.so",
        "osx": f"libgstaichi_{name}.dylib",
        "win": f"gstaichi_{name}.dll",
    }
    dll_name = name_for_os.get(os_name)
    if dll_name is None:
        raise Exception(f"Unknown OS: {os_name}")
    return dll_name
|
129
|
+
|
130
|
+
|
131
|
+
def at_startup():
    """Set the core's "python imported" state flag to True."""
    ti_python_core.set_core_state_python_imported(True)


# Executed once, when this module is imported.
at_startup()
|
136
|
+
|
137
|
+
|
138
|
+
def compare_version(latest, current):
    """Return True if dotted version ``latest`` is strictly newer than ``current``.

    Components are compared numerically. The shorter version is padded with
    zeros, so ``"1.2"`` and ``"1.2.0"`` compare equal instead of the longer
    one looking newer.

    Raises:
        ValueError: If a component is not a plain integer (e.g. ``"1.0rc1"``).
    """
    latest_parts = [int(part) for part in latest.split(".")]
    current_parts = [int(part) for part in current.split(".")]
    width = max(len(latest_parts), len(current_parts))
    latest_parts.extend([0] * (width - len(latest_parts)))
    current_parts.extend([0] * (width - len(current_parts)))
    return latest_parts > current_parts
|
142
|
+
|
143
|
+
|
144
|
+
def _print_gstaichi_header():
|
145
|
+
header = "[GsTaichi] "
|
146
|
+
header += f"version {ti_python_core.get_version_string()}, "
|
147
|
+
|
148
|
+
try:
|
149
|
+
timestamp_path = os.path.join(ti_python_core.get_repo_dir(), "timestamp")
|
150
|
+
if os.path.exists(timestamp_path):
|
151
|
+
latest_version = ""
|
152
|
+
with open(timestamp_path, "r") as f:
|
153
|
+
latest_version = f.readlines()[1].rstrip()
|
154
|
+
if compare_version(latest_version, ti_python_core.get_version_string()):
|
155
|
+
header += f"latest version {latest_version}, "
|
156
|
+
except:
|
157
|
+
pass
|
158
|
+
|
159
|
+
llvm_target_support = ti_python_core.get_llvm_target_support()
|
160
|
+
header += f"llvm {llvm_target_support}, "
|
161
|
+
|
162
|
+
commit_hash = ti_python_core.get_commit_hash()
|
163
|
+
commit_hash = commit_hash[:8]
|
164
|
+
header += f"commit {commit_hash}, "
|
165
|
+
|
166
|
+
header += f"{get_os_name()}, "
|
167
|
+
|
168
|
+
py_ver = ".".join(str(x) for x in sys.version_info[:3])
|
169
|
+
header += f"python {py_ver}"
|
170
|
+
|
171
|
+
print(header)
|
172
|
+
|
173
|
+
|
174
|
+
# The banner is on by default. Setting ENABLE_GSTAICHI_HEADER_PRINT to
# "false", "0" or "f" (case-insensitive) suppresses it.
if os.getenv("ENABLE_GSTAICHI_HEADER_PRINT", "True").lower() not in {"false", "0", "f"}:
    _print_gstaichi_header()
|
176
|
+
|
177
|
+
|
178
|
+
def try_get_wheel_tag(module):
    """Return the ``Tag`` field from ``module``'s installed wheel metadata, or None.

    The metadata is looked up at ``<package dir>-<version>.dist-info/WHEEL``.
    ``module.__version__`` may be a version string (``"2.1.1"``) or a
    sequence of components (``(2, 1, 1)``). Previously a string was joined
    character by character, which produced ``"2...1...1"``.

    Returns:
        str | None: The first ``Tag`` header, or None when the metadata cannot
        be found or parsed. This is best-effort: all errors are swallowed.
    """
    try:
        from email.parser import Parser  # pylint: disable=import-outside-toplevel

        version = module.__version__
        if not isinstance(version, str):
            version = ".".join(map(str, version))
        wheel_path = f"{module.__path__[0]}-{version}.dist-info/WHEEL"
        with open(wheel_path, "r") as f:
            meta = Parser().parse(f)
        return meta.get("Tag")
    except Exception:
        return None
|
188
|
+
|
189
|
+
|
190
|
+
def try_get_loaded_libc_version(maps_path="/proc/self/maps"):
    """Return ``(major, minor)`` of the glibc mapped into this process, or None.

    Scans the memory map for a mapped file whose *name* starts with
    ``libc-X.Y`` (e.g. ``libc-2.31.so``). Only the basename is examined, so
    digits elsewhere in the path (``/opt/py3.9/lib/libc-2.17.so``) are not
    mistaken for the version. Non-matching ``libc-`` entries such as
    ``libc-client.so`` are skipped rather than ending the search.

    Args:
        maps_path (str): Memory-map file to read. Defaults to the current
            process's ``/proc/self/maps``.

    Returns:
        tuple[int, int] | None: None when no versioned ``libc-X.Y`` mapping
        is present (e.g. when libc is mapped under a name like ``libc.so.6``).
    """
    assert platform.system() == "Linux"
    with open(maps_path) as f:
        content = f.read()

    for token in content.split():
        match = re.match(r"libc-(\d+)\.(\d+)", os.path.basename(token))
        if match:
            return int(match.group(1)), int(match.group(2))
    return None
|
203
|
+
|
204
|
+
|
205
|
+
def try_get_pip_version():
    """Return the installed pip version as a tuple of ints, or None.

    Only the leading numeric release components are kept. Pre- and
    dev-release strings such as ``"24.1b1"`` or ``"21.0.dev0"`` therefore
    yield ``(24, 1)`` and ``(21, 0)`` instead of raising ``ValueError``.

    Returns:
        tuple[int, ...] | None: None when pip cannot be imported or its
        version has no leading numeric component.
    """
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            import pip  # pylint: disable=import-outside-toplevel
    except ImportError:
        return None

    components = []
    for piece in pip.__version__.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        components.append(int(match.group()))
        if match.end() != len(piece):
            # e.g. "1b1": keep the release number, stop at the suffix.
            break
    return tuple(components) or None
|
213
|
+
|
214
|
+
|
215
|
+
def warn_restricted_version():
    """On Linux with a manylinux wheel, warn about a too-old glibc or pip.

    Setting ``TI_MANYLINUX2014_OK`` to any non-empty value skips the check.
    Any exception raised while checking is silently ignored.
    """
    if os.environ.get("TI_MANYLINUX2014_OK", ""):
        return
    if get_os_name() != "linux":
        return

    try:
        import gstaichi as ti  # pylint: disable=import-outside-toplevel

        tag = try_get_wheel_tag(ti)
        if tag and "manylinux" in tag:
            glibc = try_get_loaded_libc_version()
            if glibc and glibc < (2, 27):
                print_yellow_bold(
                    "!! GsTaichi requires glibc >= 2.27 to run, please try upgrading your OS to a recent one (e.g. Ubuntu 18.04 or later) if possible."
                )

            installed_pip = try_get_pip_version()
            if installed_pip and installed_pip < (20, 3, 0):
                print_yellow_bold(
                    f"!! Your pip (version {'.'.join(map(str, installed_pip))}) is outdated (20.3.0 or later required), "
                    "try upgrading pip and install gstaichi again."
                )
                print()
                print_yellow_bold(" $ python3 -m pip install --upgrade pip")
                print_yellow_bold(" $ python3 -m pip install --force-reinstall gstaichi")
                print()
    except Exception:
        pass
|
gstaichi/_logging.py
ADDED
@@ -0,0 +1,131 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import inspect
|
4
|
+
import os
|
5
|
+
|
6
|
+
from gstaichi._lib import core as ti_python_core
|
7
|
+
|
8
|
+
|
9
|
+
def _get_logging(name):
|
10
|
+
"""Generates a decorator to decorate a specific logger function.
|
11
|
+
|
12
|
+
Args:
|
13
|
+
name (str): The string represents logging level.
|
14
|
+
Effective levels include: 'trace', 'debug', 'info', 'warn', 'error', 'critical'.
|
15
|
+
|
16
|
+
Returns:
|
17
|
+
Callabe: The decorated function.
|
18
|
+
"""
|
19
|
+
|
20
|
+
def logger(msg, *args, **kwargs):
|
21
|
+
# Python inspection takes time (~0.1ms) so avoid it as much as possible
|
22
|
+
if ti_python_core.logging_effective(name):
|
23
|
+
msg_formatted = msg.format(*args, **kwargs)
|
24
|
+
func = getattr(ti_python_core, name)
|
25
|
+
frame = inspect.currentframe().f_back
|
26
|
+
file_name, lineno, func_name, _, _ = inspect.getframeinfo(frame)
|
27
|
+
file_name = os.path.basename(file_name)
|
28
|
+
msg = f"[{file_name}:{func_name}@{lineno}] {msg_formatted}"
|
29
|
+
func(msg)
|
30
|
+
|
31
|
+
return logger
|
32
|
+
|
33
|
+
|
34
|
+
def set_logging_level(level):
    """Set the core logging level to ``level``.

    Available levels, from most to least verbose, are: 'trace', 'debug',
    'info', 'warn', 'error', 'critical'.

    After this call, ``level`` and every more severe level are effective.
    For example, setting 'warn' also makes 'error' and 'critical' effective.

    See also https://docs.taichi-lang.org/docs/developer_utilities#logging.

    Args:
        level (str): Logging level.

    Example::

        >>> set_logging_level('debug')
    """
    ti_python_core.set_logging_level(level)
|
52
|
+
|
53
|
+
|
54
|
+
def is_logging_effective(level):
    """Check if the specified logging level is effective.

    A level is effective when it is at least as severe as the current level.
    The default level is 'info'.

    See also https://docs.taichi-lang.org/docs/developer_utilities#logging.

    Args:
        level (str): The logging level to query.
            Effective levels include: 'trace', 'debug', 'info', 'warn', 'error', 'critical'.

    Returns:
        Bool: Indicate whether the logging level is effective.

    Example::

        >>> # assume current level is 'info'
        >>> print(ti.is_logging_effective("trace"))  # False
        >>> print(ti.is_logging_effective("debug"))  # False
        >>> print(ti.is_logging_effective("info"))  # True
        >>> print(ti.is_logging_effective("warn"))  # True
        >>> print(ti.is_logging_effective("error"))  # True
        >>> print(ti.is_logging_effective("critical"))  # True
    """
    return ti_python_core.logging_effective(level)
|
79
|
+
|
80
|
+
|
81
|
+
# ------------------------
|
82
|
+
|
83
|
+
DEBUG = "debug"
"""The `str` 'debug', used for the `debug` logging level.
"""
# ------------------------

TRACE = "trace"
"""The `str` 'trace', used for the `trace` logging level.
"""
# ------------------------

INFO = "info"
"""The `str` 'info', used for the `info` logging level.
"""
# ------------------------

WARN = "warn"
"""The `str` 'warn', used for the `warn` logging level.
"""
# ------------------------

ERROR = "error"
"""The `str` 'error', used for the `error` logging level.
"""
# ------------------------

CRITICAL = "critical"
"""The `str` 'critical', used for the `critical` logging level.
"""
# ------------------------

supported_log_levels = [DEBUG, TRACE, INFO, WARN, ERROR, CRITICAL]
|
114
|
+
|
115
|
+
# One module-level logging function per level, each built by _get_logging
# (created in the same order as before: debug, trace, info, warn, error, critical).
debug, trace, info, warn, error, critical = (
    _get_logging(level) for level in (DEBUG, TRACE, INFO, WARN, ERROR, CRITICAL)
)

__all__ = [
    "DEBUG",
    "TRACE",
    "INFO",
    "WARN",
    "ERROR",
    "CRITICAL",
    "set_logging_level",
    "is_logging_effective",
]
|
@@ -0,0 +1,187 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
from typing import Any, Optional, Sequence, Union
|
4
|
+
|
5
|
+
from gstaichi._lib import core as _ti_core
|
6
|
+
from gstaichi._lib.core.gstaichi_python import SNodeCxx
|
7
|
+
from gstaichi._snode.snode_tree import SNodeTree
|
8
|
+
from gstaichi.lang import impl, snode
|
9
|
+
from gstaichi.lang.exception import GsTaichiRuntimeError
|
10
|
+
from gstaichi.lang.util import warning
|
11
|
+
|
12
|
+
# Module-level SNode registry used by every FieldsBuilder here: each builder
# creates its root from it, and it is passed to finalize_snode_tree().
_snode_registry = _ti_core.SNodeRegistry()

# Alias for the core axis type, used in FieldsBuilder's type hints.
_Axis = _ti_core.Axis
|
15
|
+
|
16
|
+
|
17
|
+
class FieldsBuilder:
    """A builder that constructs a SNodeTree instance.

    Example::

        x = ti.field(ti.i32)
        y = ti.field(ti.f32)
        fb = ti.FieldsBuilder()
        fb.dense(ti.ij, 8).place(x)
        fb.pointer(ti.ij, 8).dense(ti.ij, 4).place(y)

        # After this line, `x` and `y` are placed. No more fields can be placed
        # into `fb`.
        #
        # The tree looks like the following:
        # (implicit root)
        #  |
        #  +-- dense +-- place(x)
        #  |
        #  +-- pointer +-- dense +-- place(y)
        fb.finalize()
    """

    def __init__(self):
        # Fresh root from the module-level registry; layout calls below build
        # their subtrees under self.root.
        self.ptr: SNodeCxx = _snode_registry.create_root(impl.get_runtime().prog)
        self.root = snode.SNode(self.ptr)
        self.finalized = False
        # Cleared by every layout call; _finalize() warns if still True.
        self.empty = True
        impl.get_runtime().initialize_fields_builder(self)

    # TODO: move this into SNodeTree
    @classmethod
    def _finalized_roots(cls):
        """Gets all the roots of the finalized SNodeTree.

        Returns:
            A list of the roots of the finalized SNodeTree.
        """
        # Fetch the program once instead of on every iteration.
        prog = impl.get_runtime().prog
        return [snode.SNode(prog.get_snode_root(i)) for i in range(prog.get_snode_tree_size())]

    # TODO: move this to SNodeTree class.
    def deactivate_all(self):
        """Same as :func:`gstaichi.lang.snode.SNode.deactivate_all`"""
        if self.finalized:
            self.root.deactivate_all()
        else:
            warning("""'deactivate_all()' would do nothing if FieldsBuilder is not finalized""")

    def dense(
        self,
        indices: Union[Sequence[_Axis], _Axis],
        dimensions: Union[Sequence[int], int],
    ):
        """Same as :func:`gstaichi.lang.snode.SNode.dense`"""
        self._check_not_finalized()
        self.empty = False
        return self.root.dense(indices, dimensions)

    def pointer(
        self,
        indices: Union[Sequence[_Axis], _Axis],
        dimensions: Union[Sequence[int], int],
    ):
        """Same as :func:`gstaichi.lang.snode.SNode.pointer`"""
        if not _ti_core.is_extension_supported(impl.current_cfg().arch, _ti_core.Extension.sparse):
            raise GsTaichiRuntimeError("Pointer SNode is not supported on this backend.")
        self._check_not_finalized()
        self.empty = False
        return self.root.pointer(indices, dimensions)

    def _hash(self, indices, dimensions):
        """Same as :func:`gstaichi.lang.snode.SNode.hash`"""
        raise NotImplementedError()

    def dynamic(
        self,
        index: Union[Sequence[_Axis], _Axis],
        dimension: Union[Sequence[int], int],
        chunk_size: Optional[int] = None,
    ):
        """Same as :func:`gstaichi.lang.snode.SNode.dynamic`

        Raises:
            GsTaichiRuntimeError: If the backend lacks sparse support, or if
                any dimension component or ``chunk_size`` is ``>= 2**31``.
        """
        if not _ti_core.is_extension_supported(impl.current_cfg().arch, _ti_core.Extension.sparse):
            raise GsTaichiRuntimeError("Dynamic SNode is not supported on this backend.")

        # `dimension` may be an int or a sequence of ints (see the annotation).
        # Validate each component instead of comparing a sequence with an int,
        # which raised TypeError.
        dims = dimension if isinstance(dimension, (list, tuple)) else (dimension,)
        for dim in dims:
            if dim >= 2**31:
                raise GsTaichiRuntimeError(
                    f"The maximum dimension of a dynamic SNode cannot exceed the maximum value of a 32-bit signed integer: Got {dim} > 2**31-1"
                )
        if chunk_size is not None and chunk_size >= 2**31:
            raise GsTaichiRuntimeError(
                f"Chunk size cannot exceed the maximum value of a 32-bit signed integer: Got {chunk_size} > 2**31-1"
            )

        self._check_not_finalized()
        self.empty = False
        return self.root.dynamic(index, dimension, chunk_size)

    def bitmasked(
        self,
        indices: Union[Sequence[_Axis], _Axis],
        dimensions: Union[Sequence[int], int],
    ):
        """Same as :func:`gstaichi.lang.snode.SNode.bitmasked`"""
        if not _ti_core.is_extension_supported(impl.current_cfg().arch, _ti_core.Extension.sparse):
            raise GsTaichiRuntimeError("Bitmasked SNode is not supported on this backend.")
        self._check_not_finalized()
        self.empty = False
        return self.root.bitmasked(indices, dimensions)

    def quant_array(
        self,
        indices: Union[Sequence[_Axis], _Axis],
        dimensions: Union[Sequence[int], int],
        max_num_bits: int,
    ):
        """Same as :func:`gstaichi.lang.snode.SNode.quant_array`"""
        self._check_not_finalized()
        self.empty = False
        return self.root.quant_array(indices, dimensions, max_num_bits)

    def place(self, *args: Any, offset: Optional[Union[Sequence[int], int]] = None):
        """Same as :func:`gstaichi.lang.snode.SNode.place`"""
        self._check_not_finalized()
        self.empty = False
        self.root.place(*args, offset=offset)

    def lazy_grad(self):
        """Same as :func:`gstaichi.lang.snode.SNode.lazy_grad`"""
        # TODO: This complicates the implementation. Figure out why we need this
        self._check_not_finalized()
        self.empty = False
        self.root.lazy_grad()

    def _allocate_adjoint_checkbit(self):
        """Same as :func:`gstaichi.lang.snode.SNode._allocate_adjoint_checkbit`"""
        self._check_not_finalized()
        self.empty = False
        self.root._allocate_adjoint_checkbit()

    def lazy_dual(self):
        """Same as :func:`gstaichi.lang.snode.SNode.lazy_dual`"""
        # TODO: This complicates the implementation. Figure out why we need this
        self._check_not_finalized()
        self.empty = False
        self.root.lazy_dual()

    def finalize(self, raise_warning=True):
        """Constructs the SNodeTree and finalizes this builder.

        Args:
            raise_warning (bool): Raise warning or not.

        Returns:
            SNodeTree: The constructed tree.

        Raises:
            GsTaichiRuntimeError: If the builder was already finalized.
        """
        return self._finalize(raise_warning, compile_only=False)

    def _finalize(self, raise_warning, compile_only) -> SNodeTree:
        self._check_not_finalized()
        if self.empty and raise_warning:
            warning("Finalizing an empty FieldsBuilder!")
        self.finalized = True
        impl.get_runtime().finalize_fields_builder(self)
        return SNodeTree(
            _ti_core.finalize_snode_tree(_snode_registry, self.ptr, impl.get_runtime()._prog, compile_only)
        )

    def _check_not_finalized(self):
        # Guard for every mutating call: a finalized builder is frozen.
        if self.finalized:
            raise GsTaichiRuntimeError("FieldsBuilder finalized")
|
@@ -0,0 +1,34 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
# We import the `impl` module rather than binding the runtime object itself,
# because the runtime is stateful. The methods below call
# `impl.get_runtime()` at use time, so they see the current runtime rather
# than one captured when this module was imported.
|
7
|
+
|
8
|
+
from gstaichi.lang import impl
|
9
|
+
from gstaichi.lang.exception import GsTaichiRuntimeError
|
10
|
+
|
11
|
+
|
12
|
+
class SNodeTree:
    """Python handle for a finalized SNode tree.

    Remembers the program that was current when it was created. ``destroy()``
    only acts while that program is still the current one.
    """

    def __init__(self, ptr):
        self.prog = impl.get_runtime().prog
        self.ptr = ptr
        self.destroyed = False

    def _ensure_alive(self):
        # Shared guard: refuse to touch a tree that was already destroyed.
        if self.destroyed:
            raise GsTaichiRuntimeError("SNode tree has been destroyed")

    def destroy(self):
        """Destroy the tree and clear all compiled functions.

        Does nothing (and leaves ``destroyed`` False) if the current program
        is not the one this tree was created under.

        Raises:
            GsTaichiRuntimeError: If the tree was already destroyed.
        """
        self._ensure_alive()
        runtime = impl.get_runtime()
        current_prog = runtime.prog
        if self.prog != current_prog:
            return
        self.ptr.destroy_snode_tree(current_prog)

        # FieldExpression holds a SNode* to the place-SNode associated with a SNodeTree
        # Therefore, we have to recompile all the kernels after destroying a SNodeTree
        runtime.clear_compiled_functions()
        self.destroyed = True

    @property
    def id(self):
        """The identifier reported by the underlying tree (``ptr.id()``)."""
        self._ensure_alive()
        return self.ptr.id()
|
@@ -0,0 +1,18 @@
|
|
1
|
+
import gstaichi as ti
|
2
|
+
|
3
|
+
from . import textwrap2
|
4
|
+
|
5
|
+
|
6
|
+
def ti_init_same_arch(**options) -> None:
|
7
|
+
"""
|
8
|
+
Used in tests to call ti.init, passing in the same arch as currently
|
9
|
+
configured. Since it's fairly fiddly to do that, extracting this out
|
10
|
+
to this helper function.
|
11
|
+
"""
|
12
|
+
assert ti.cfg is not None
|
13
|
+
options = dict(options)
|
14
|
+
options["arch"] = getattr(ti, ti.cfg.arch.name)
|
15
|
+
ti.init(**options)
|
16
|
+
|
17
|
+
|
18
|
+
__all__ = ["textwrap2"]
|
@@ -0,0 +1,36 @@
|
|
1
|
+
import dataclasses
|
2
|
+
from typing import Any, cast
|
3
|
+
|
4
|
+
import gstaichi as ti
|
5
|
+
|
6
|
+
|
7
|
+
def _make_child_obj(obj_type: Any) -> Any:
|
8
|
+
if isinstance(obj_type, ti.types.NDArray):
|
9
|
+
ndarray_type = cast(ti.types.ndarray, obj_type)
|
10
|
+
assert ndarray_type.ndim is not None
|
11
|
+
shape = tuple([10] * ndarray_type.ndim)
|
12
|
+
child_obj = ti.ndarray(ndarray_type.dtype, shape=shape)
|
13
|
+
elif dataclasses.is_dataclass(obj_type):
|
14
|
+
child_obj = build_struct(obj_type)
|
15
|
+
elif isinstance(obj_type, ti.Template) or obj_type == ti.Template:
|
16
|
+
child_obj = ti.field(ti.i32, (10,))
|
17
|
+
else:
|
18
|
+
raise Exception("unknown type ", obj_type)
|
19
|
+
return child_obj
|
20
|
+
|
21
|
+
|
22
|
+
def build_struct(struct_type: Any) -> Any:
    """Instantiate dataclass ``struct_type`` with a placeholder for every field."""
    kwargs = {f.name: _make_child_obj(f.type) for f in dataclasses.fields(struct_type)}
    return struct_type(**kwargs)
|
29
|
+
|
30
|
+
|
31
|
+
def build_obj_tuple_from_type_dict(name_to_type: dict[str, Any]) -> tuple[Any, ...]:
    """Return one placeholder object per annotation, in the dict's iteration order."""
    return tuple(_make_child_obj(param_type) for param_type in name_to_type.values())
|
@@ -0,0 +1,30 @@
|
|
1
|
+
import importlib.util
|
2
|
+
import sys
|
3
|
+
import tempfile
|
4
|
+
from contextlib import contextmanager
|
5
|
+
from pathlib import Path
|
6
|
+
|
7
|
+
|
8
|
+
def import_kernel_from_file(kernel_filepath: Path, kernel_name: str):
    """Import ``kernel_filepath`` as module ``kernel_name`` and return its attribute ``kernel_name``.

    The module is registered in ``sys.modules`` under ``kernel_name`` before
    it is executed. If executing it raises, the previous ``sys.modules``
    entry is restored (or the new one removed) before the exception
    propagates. A failed import therefore does not leave a half-initialised
    module behind.
    """
    spec = importlib.util.spec_from_file_location(kernel_name, kernel_filepath)
    assert spec is not None
    module = importlib.util.module_from_spec(spec)
    loader = spec.loader
    assert loader is not None
    previous = sys.modules.get(kernel_name)
    sys.modules[kernel_name] = module
    try:
        loader.exec_module(module)
    except BaseException:
        if previous is None:
            sys.modules.pop(kernel_name, None)
        else:
            sys.modules[kernel_name] = previous
        raise
    return getattr(module, kernel_name)


@contextmanager
def load_kernel_from_string(kernel_str: str, kernel_name: str):
    """Context manager yielding the kernel ``kernel_name`` defined in source ``kernel_str``.

    The source is written to ``<tmpdir>/<kernel_name>.py``, imported, and the
    object of the same name is yielded. On exit the module is removed from
    ``sys.modules`` and the temporary directory is deleted.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        filepath = Path(temp_dir) / f"{kernel_name}.py"
        # Write UTF-8 explicitly: Python decodes source files without a coding
        # cookie as UTF-8, whereas open() would use the locale encoding.
        filepath.write_text(kernel_str, encoding="utf-8")
        try:
            kernel = import_kernel_from_file(kernel_filepath=filepath, kernel_name=kernel_name)
            yield kernel
        finally:
            if kernel_name in sys.modules:
                del sys.modules[kernel_name]
|