gstaichi 0.0.0__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (178) hide show
  1. gstaichi/CHANGELOG.md +4 -0
  2. gstaichi/__init__.py +51 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +5 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cpython-311-darwin.so +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2917 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  11. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  12. gstaichi/_lib/utils.py +243 -0
  13. gstaichi/_logging.py +131 -0
  14. gstaichi/_snode/__init__.py +5 -0
  15. gstaichi/_snode/fields_builder.py +187 -0
  16. gstaichi/_snode/snode_tree.py +34 -0
  17. gstaichi/_test_tools/__init__.py +18 -0
  18. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  19. gstaichi/_test_tools/load_kernel_string.py +30 -0
  20. gstaichi/_test_tools/textwrap2.py +6 -0
  21. gstaichi/_version_check.py +100 -0
  22. gstaichi/ad/__init__.py +3 -0
  23. gstaichi/ad/_ad.py +530 -0
  24. gstaichi/algorithms/__init__.py +3 -0
  25. gstaichi/algorithms/_algorithms.py +117 -0
  26. gstaichi/assets/.git +1 -0
  27. gstaichi/assets/Go-Regular.ttf +0 -0
  28. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  29. gstaichi/examples/lcg_python.py +26 -0
  30. gstaichi/examples/lcg_taichi.py +34 -0
  31. gstaichi/examples/minimal.py +28 -0
  32. gstaichi/experimental.py +16 -0
  33. gstaichi/lang/__init__.py +50 -0
  34. gstaichi/lang/_dataclass_util.py +31 -0
  35. gstaichi/lang/_fast_caching/__init__.py +3 -0
  36. gstaichi/lang/_fast_caching/args_hasher.py +122 -0
  37. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  38. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  39. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  40. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  41. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  42. gstaichi/lang/_fast_caching/src_hasher.py +83 -0
  43. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  44. gstaichi/lang/_ndarray.py +366 -0
  45. gstaichi/lang/_ndrange.py +152 -0
  46. gstaichi/lang/_template_mapper.py +195 -0
  47. gstaichi/lang/_texture.py +172 -0
  48. gstaichi/lang/_wrap_inspect.py +215 -0
  49. gstaichi/lang/any_array.py +99 -0
  50. gstaichi/lang/ast/__init__.py +7 -0
  51. gstaichi/lang/ast/ast_transformer.py +1351 -0
  52. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  53. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  54. gstaichi/lang/ast/ast_transformers/call_transformer.py +327 -0
  55. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  56. gstaichi/lang/ast/checkers.py +106 -0
  57. gstaichi/lang/ast/symbol_resolver.py +57 -0
  58. gstaichi/lang/ast/transform.py +9 -0
  59. gstaichi/lang/common_ops.py +310 -0
  60. gstaichi/lang/exception.py +80 -0
  61. gstaichi/lang/expr.py +180 -0
  62. gstaichi/lang/field.py +428 -0
  63. gstaichi/lang/impl.py +1259 -0
  64. gstaichi/lang/kernel_arguments.py +155 -0
  65. gstaichi/lang/kernel_impl.py +1386 -0
  66. gstaichi/lang/matrix.py +1835 -0
  67. gstaichi/lang/matrix_ops.py +341 -0
  68. gstaichi/lang/matrix_ops_utils.py +190 -0
  69. gstaichi/lang/mesh.py +687 -0
  70. gstaichi/lang/misc.py +784 -0
  71. gstaichi/lang/ops.py +1494 -0
  72. gstaichi/lang/runtime_ops.py +13 -0
  73. gstaichi/lang/shell.py +35 -0
  74. gstaichi/lang/simt/__init__.py +5 -0
  75. gstaichi/lang/simt/block.py +94 -0
  76. gstaichi/lang/simt/grid.py +7 -0
  77. gstaichi/lang/simt/subgroup.py +191 -0
  78. gstaichi/lang/simt/warp.py +96 -0
  79. gstaichi/lang/snode.py +489 -0
  80. gstaichi/lang/source_builder.py +150 -0
  81. gstaichi/lang/struct.py +810 -0
  82. gstaichi/lang/util.py +312 -0
  83. gstaichi/linalg/__init__.py +10 -0
  84. gstaichi/linalg/matrixfree_cg.py +310 -0
  85. gstaichi/linalg/sparse_cg.py +59 -0
  86. gstaichi/linalg/sparse_matrix.py +303 -0
  87. gstaichi/linalg/sparse_solver.py +123 -0
  88. gstaichi/math/__init__.py +11 -0
  89. gstaichi/math/_complex.py +205 -0
  90. gstaichi/math/mathimpl.py +886 -0
  91. gstaichi/profiler/__init__.py +6 -0
  92. gstaichi/profiler/kernel_metrics.py +260 -0
  93. gstaichi/profiler/kernel_profiler.py +586 -0
  94. gstaichi/profiler/memory_profiler.py +15 -0
  95. gstaichi/profiler/scoped_profiler.py +36 -0
  96. gstaichi/sparse/__init__.py +3 -0
  97. gstaichi/sparse/_sparse_grid.py +77 -0
  98. gstaichi/tools/__init__.py +12 -0
  99. gstaichi/tools/diagnose.py +117 -0
  100. gstaichi/tools/np2ply.py +364 -0
  101. gstaichi/tools/vtk.py +38 -0
  102. gstaichi/types/__init__.py +21 -0
  103. gstaichi/types/annotations.py +52 -0
  104. gstaichi/types/compound_types.py +71 -0
  105. gstaichi/types/enums.py +49 -0
  106. gstaichi/types/ndarray_type.py +169 -0
  107. gstaichi/types/primitive_types.py +206 -0
  108. gstaichi/types/quant.py +88 -0
  109. gstaichi/types/texture_type.py +85 -0
  110. gstaichi/types/utils.py +11 -0
  111. gstaichi-0.0.0.data/data/include/GLFW/glfw3.h +6389 -0
  112. gstaichi-0.0.0.data/data/include/GLFW/glfw3native.h +594 -0
  113. gstaichi-0.0.0.data/data/include/spirv-tools/instrument.hpp +268 -0
  114. gstaichi-0.0.0.data/data/include/spirv-tools/libspirv.h +907 -0
  115. gstaichi-0.0.0.data/data/include/spirv-tools/libspirv.hpp +375 -0
  116. gstaichi-0.0.0.data/data/include/spirv-tools/linker.hpp +97 -0
  117. gstaichi-0.0.0.data/data/include/spirv-tools/optimizer.hpp +970 -0
  118. gstaichi-0.0.0.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  119. gstaichi-0.0.0.data/data/include/spirv_cross/spirv.h +2568 -0
  120. gstaichi-0.0.0.data/data/include/spirv_cross/spirv.hpp +2579 -0
  121. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  122. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  123. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  124. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  125. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  126. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  127. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  128. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  129. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  130. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  131. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  132. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  133. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  134. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  135. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  136. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  137. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  138. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  139. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  140. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  141. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  142. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  143. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  144. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  145. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  146. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  147. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  148. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  149. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  150. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  151. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  152. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  153. gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  154. gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  155. gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  156. gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  157. gstaichi-0.0.0.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  158. gstaichi-0.0.0.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  159. gstaichi-0.0.0.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  160. gstaichi-0.0.0.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  161. gstaichi-0.0.0.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  162. gstaichi-0.0.0.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  163. gstaichi-0.0.0.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  164. gstaichi-0.0.0.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  165. gstaichi-0.0.0.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  166. gstaichi-0.0.0.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  167. gstaichi-0.0.0.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  168. gstaichi-0.0.0.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  169. gstaichi-0.0.0.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  170. gstaichi-0.0.0.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  171. gstaichi-0.0.0.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  172. gstaichi-0.0.0.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  173. gstaichi-0.0.0.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  174. gstaichi-0.0.0.dist-info/METADATA +97 -0
  175. gstaichi-0.0.0.dist-info/RECORD +178 -0
  176. gstaichi-0.0.0.dist-info/WHEEL +5 -0
  177. gstaichi-0.0.0.dist-info/licenses/LICENSE +201 -0
  178. gstaichi-0.0.0.dist-info/top_level.txt +1 -0
File without changes
Binary file
Binary file
gstaichi/_lib/utils.py ADDED
@@ -0,0 +1,243 @@
1
+ # type: ignore
2
+
3
+ import os
4
+ import platform
5
+ import re
6
+ import sys
7
+ import warnings
8
+
9
+ from colorama import Fore, Style
10
+
11
# Refuse to run on interpreters older than Python 3.6.
# Comparing the version tuple as a whole avoids rejecting a future major
# version whose minor number happens to be <= 5.
if sys.version_info < (3, 6):
    raise RuntimeError(
        "\nPlease restart with Python 3.6+\n"
        + "Current Python version: "
        + ".".join(str(part) for part in sys.version_info[:3])
    )
16
+
17
+
18
def in_docker():
    """Return True when the ``TI_IN_DOCKER`` environment variable is set to a non-empty value."""
    return os.environ.get("TI_IN_DOCKER", "") != ""
22
+
23
+
24
def get_os_name():
    """Map ``platform.platform()`` to a short OS identifier.

    Returns:
        str: ``"osx"``, ``"win"``, ``"linux"`` or ``"unix"`` (any platform
        string containing ``"bsd"``).

    Raises:
        AssertionError: If the platform string matches none of the above.
    """
    name = platform.platform()
    lowered = name.lower()
    # in python 3.8, platform.platform() uses mac_ver() on macOS
    # it will return 'macOS-XXXX' instead of 'Darwin-XXXX'
    if lowered.startswith(("darwin", "macos")):
        return "osx"
    if lowered.startswith("windows"):
        return "win"
    if lowered.startswith("linux"):
        return "linux"
    if "bsd" in lowered:
        return "unix"
    # Raised explicitly (rather than `assert False`) so the check is not
    # stripped under `python -O`, which would make this return None.
    raise AssertionError(f"Unknown platform name {name}")
37
+
38
+
39
def import_ti_python_core():
    """Import the compiled ``gstaichi_python`` extension and configure its runtime library directory.

    On non-Windows platforms the process dlopen flags are changed for the
    duration of the import and restored afterwards, including when the import
    fails. On Windows the directory containing this file is appended to
    ``PATH`` before importing.

    Returns:
        The imported ``gstaichi_python`` module, after ``set_lib_dir`` has been
        called with ``<package_root>/_lib/runtime``.

    Raises:
        Exception: Whatever the import raised, re-raised with its context
            suppressed. For ``ImportError`` a hint is printed first.
    """
    on_windows = get_os_name() == "win"
    old_flags = None
    if on_windows:
        pyddir = os.path.dirname(os.path.realpath(__file__))
        os.environ["PATH"] += os.pathsep + pyddir
    else:
        # pylint: disable=E1101
        old_flags = sys.getdlopenflags()
        # RTLD_NOW | RTLD_DEEPBIND
        # NOTE(review): these integers match the Linux constants; confirm they
        # mean the same thing on the other non-Windows platforms supported.
        sys.setdlopenflags(2 | 8)
    try:
        from gstaichi._lib.core import gstaichi_python as core  # pylint: disable=C0415
    except Exception as e:
        if isinstance(e, ImportError):
            print(
                Fore.YELLOW + "Share object gstaichi_python import failed, "
                "check this page for possible solutions:\n"
                "https://docs.taichi-lang.org/docs/install" + Fore.RESET
            )
            if on_windows:
                # pylint: disable=E1101
                e.msg += "\nConsider installing Microsoft Visual C++ Redistributable: https://aka.ms/vs/16/release/vc_redist.x64.exe"
        raise e from None
    finally:
        # Restore the flags on both success and failure so a failed import
        # does not leave the interpreter with altered dlopen behaviour.
        if not on_windows:
            sys.setdlopenflags(old_flags)  # pylint: disable=E1101

    lib_dir = os.path.join(package_root, "_lib", "runtime")
    core.set_lib_dir(locale_encode(lib_dir))
    return core
66
+
67
+
68
def locale_encode(path):
    """Encode ``path`` to bytes, trying several encodings in order.

    The candidates are the default-locale encoding, the filesystem encoding
    and finally UTF-8. The first one that can encode ``path`` wins.

    Args:
        path (str): The text to encode.

    Returns:
        bytes: The encoded path, or ``path`` unchanged (still ``str``) if no
        candidate encoding could represent it.
    """
    import locale  # pylint: disable=C0415

    try:
        # getdefaultlocale() is deprecated since Python 3.11; silence the
        # warning and tolerate its absence or a malformed locale setting.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            locale_encoding = locale.getdefaultlocale()[1]
    except (AttributeError, ValueError):
        locale_encoding = None

    for encoding in (locale_encoding, sys.getfilesystemencoding(), "utf-8"):
        if not encoding:
            continue
        try:
            return path.encode(encoding)
        except (UnicodeEncodeError, LookupError):
            # Unencodable characters or an unknown codec name: try the next one.
            continue
    return path
81
+
82
+
83
def is_ci():
    """Return True when the ``TI_CI`` environment variable equals ``"1"``."""
    flag = os.environ.get("TI_CI", "")
    return flag == "1"
85
+
86
+
87
# Directory two levels above this file (symlinks resolved first).
_this_dir = os.path.dirname(os.path.realpath(__file__))
package_root = os.path.dirname(_this_dir)
88
+
89
+
90
def get_core_shared_object():
    """Return the path ``<package_root>/_lib/libgstaichi_python.so`` (existence is not checked)."""
    return os.path.join(os.path.join(package_root, "_lib"), "libgstaichi_python.so")
93
+
94
+
95
def print_red_bold(*args, **kwargs):
    """Print ``args`` like :func:`print`, wrapped in bright-red colour codes.

    The colour and reset codes are always written with a plain ``print``
    (no newline); ``kwargs`` only apply to the message itself.
    """
    colour_on = Fore.RED + Style.BRIGHT
    print(colour_on, end="")
    print(*args, **kwargs)
    print(Style.RESET_ALL, end="")
99
+
100
+
101
def print_yellow_bold(*args, **kwargs):
    """Print ``args`` like :func:`print`, wrapped in bright-yellow colour codes.

    The colour and reset codes are always written with a plain ``print``
    (no newline); ``kwargs`` only apply to the message itself.
    """
    colour_on = Fore.YELLOW + Style.BRIGHT
    print(colour_on, end="")
    print(*args, **kwargs)
    print(Style.RESET_ALL, end="")
105
+
106
+
107
def check_exists(src):
    """Raise ``FileNotFoundError`` unless ``src`` exists on disk.

    Args:
        src: Path to check with ``os.path.exists``.
    """
    if os.path.exists(src):
        return
    raise FileNotFoundError(f'File "{src}" not exist. Installation corrupted or build incomplete?')
110
+
111
+
112
# Load the native extension once, at import time, and tell it where the
# Python package lives.
ti_python_core = import_ti_python_core()
ti_python_core.set_python_package_dir(package_root)

# An explicit TI_LOG_LEVEL in the environment overrides the core's default
# logging level; an unset or empty variable leaves it untouched.
log_level = os.environ.get("TI_LOG_LEVEL", "")
if log_level != "":
    ti_python_core.set_logging_level(log_level)
119
+
120
+
121
def get_dll_name(name):
    """Return the platform-specific shared-library file name for component ``name``.

    Args:
        name: Component name inserted into the file name.

    Returns:
        str: ``libgstaichi_<name>.so`` on Linux, ``libgstaichi_<name>.dylib``
        on macOS, ``gstaichi_<name>.dll`` on Windows.

    Raises:
        Exception: If ``get_os_name()`` reports any other OS.
    """
    os_name = get_os_name()
    if os_name == "linux":
        return f"libgstaichi_{name}.so"
    elif os_name == "osx":
        return f"libgstaichi_{name}.dylib"
    elif os_name == "win":
        return f"gstaichi_{name}.dll"
    raise Exception(f"Unknown OS: {os_name}")
129
+
130
+
131
def at_startup():
    """Flag the native core as having been imported from Python."""
    python_imported = True
    ti_python_core.set_core_state_python_imported(python_imported)


# Executed once, as a side effect of importing this module.
at_startup()
136
+
137
+
138
def compare_version(latest, current):
    """Return True when dotted version ``latest`` is strictly newer than ``current``.

    Components are compared numerically, and the shorter version is padded
    with zeros so that ``"1.0"`` and ``"1.0.0"`` compare as equal.

    Args:
        latest (str): Dot-separated integer version, e.g. ``"1.2.3"``.
        current (str): Dot-separated integer version to compare against.

    Returns:
        bool: ``latest > current``.

    Raises:
        ValueError: If any component is not an integer.
    """
    latest_num = [int(part) for part in latest.split(".")]
    current_num = [int(part) for part in current.split(".")]
    # Pad to a common length; otherwise (1, 0, 0) > (1, 0) would report a
    # newer version where there is none.
    width = max(len(latest_num), len(current_num))
    latest_num += [0] * (width - len(latest_num))
    current_num += [0] * (width - len(current_num))
    return latest_num > current_num
142
+
143
+
144
def _print_gstaichi_header():
    """Print a one-line banner with version, LLVM target, commit, OS and Python version.

    If a ``timestamp`` file exists in the core's repo directory, its second
    line is read as the latest known version and added to the banner when it
    is newer than the running version. Any failure in that lookup is ignored.
    """
    header = "[GsTaichi] "
    header += f"version {ti_python_core.get_version_string()}, "

    try:
        timestamp_path = os.path.join(ti_python_core.get_repo_dir(), "timestamp")
        if os.path.exists(timestamp_path):
            with open(timestamp_path, "r") as f:
                latest_version = f.readlines()[1].rstrip()
            if compare_version(latest_version, ti_python_core.get_version_string()):
                header += f"latest version {latest_version}, "
    except Exception:
        # Best effort only: a missing or malformed timestamp file must never
        # prevent the import. `Exception` (not a bare except) lets
        # KeyboardInterrupt and SystemExit propagate.
        pass

    llvm_target_support = ti_python_core.get_llvm_target_support()
    header += f"llvm {llvm_target_support}, "

    commit_hash = ti_python_core.get_commit_hash()
    commit_hash = commit_hash[:8]
    header += f"commit {commit_hash}, "

    header += f"{get_os_name()}, "

    py_ver = ".".join(str(x) for x in sys.version_info[:3])
    header += f"python {py_ver}"

    print(header)


# The banner is printed at import time unless ENABLE_GSTAICHI_HEADER_PRINT is
# set to "false", "0" or "f" (case-insensitive).
if os.getenv("ENABLE_GSTAICHI_HEADER_PRINT", "True").lower() not in ("false", "0", "f"):
    _print_gstaichi_header()
176
+
177
+
178
def try_get_wheel_tag(module):
    """Read the ``Tag`` header from the ``WHEEL`` file of an installed package.

    The file is looked up at ``<module.__path__[0]>-<version>.dist-info/WHEEL``,
    where ``<version>`` is ``module.__version__`` joined with dots.

    Args:
        module: Object exposing ``__path__`` and an iterable ``__version__``.

    Returns:
        The value of the ``Tag`` header, or None if anything goes wrong.
    """
    try:
        from email.parser import Parser  # pylint: disable=import-outside-toplevel

        package_dir = module.__path__[0]
        version_str = ".".join(map(str, module.__version__))
        wheel_path = f"{package_dir}-{version_str}.dist-info/WHEEL"
        with open(wheel_path, "r") as f:
            return Parser().parse(f).get("Tag")
    except Exception:
        return None
188
+
189
+
190
def try_get_loaded_libc_version():
    """Best-effort detection of the libc version mapped into this process.

    Scans ``/proc/self/maps`` for a mapped file whose name looks like
    ``libc-<major>.<minor>...``.

    Returns:
        tuple[int, int] | None: ``(major, minor)``, or None when no such file
        name is found.

    Raises:
        AssertionError: If called on a platform other than Linux.
    """
    assert platform.system() == "Linux"
    with open("/proc/self/maps") as f:
        content = f.read()

    for token in content.split():
        if "libc-" not in token:
            continue
        # Match on the file name only, so digits in a directory name such as
        # "/opt/app-1.2/lib/libc-2.31.so" are not mistaken for the version.
        found = re.search(r"libc-(\d+)\.(\d+)", os.path.basename(token))
        if found:
            return (int(found.group(1)), int(found.group(2)))
    return None
203
+
204
+
205
def try_get_pip_version():
    """Return the installed pip version as a tuple of ints, or None.

    Only the leading dotted-integer part of ``pip.__version__`` is used, so
    suffixes such as ``rc1`` or ``.dev0`` no longer raise ``ValueError``.

    Returns:
        tuple[int, ...] | None: e.g. ``(23, 1, 2)``. None when pip cannot be
        imported or its version string does not start with a number.
    """
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            import pip  # pylint: disable=import-outside-toplevel
    except ImportError:
        return None

    numeric_prefix = re.match(r"\d+(?:\.\d+)*", str(pip.__version__))
    if numeric_prefix is None:
        return None
    return tuple(int(part) for part in numeric_prefix.group().split("."))
213
+
214
+
215
def warn_restricted_version():
    """Print warnings about an outdated glibc or pip for manylinux wheel installs.

    Does nothing when ``TI_MANYLINUX2014_OK`` is set to a non-empty value, on
    non-Linux systems, or when the installed wheel tag does not contain
    ``"manylinux"``. Any exception raised while checking is swallowed.
    """
    if os.environ.get("TI_MANYLINUX2014_OK", ""):
        return
    if get_os_name() != "linux":
        return

    try:
        import gstaichi as ti  # pylint: disable=import-outside-toplevel

        wheel_tag = try_get_wheel_tag(ti)
        if not (wheel_tag and "manylinux" in wheel_tag):
            return

        libc_ver = try_get_loaded_libc_version()
        if libc_ver and libc_ver < (2, 27):
            print_yellow_bold(
                "!! GsTaichi requires glibc >= 2.27 to run, please try upgrading your OS to a recent one (e.g. Ubuntu 18.04 or later) if possible."
            )

        pip_ver = try_get_pip_version()
        if pip_ver and pip_ver < (20, 3, 0):
            print_yellow_bold(
                f"!! Your pip (version {'.'.join(map(str, pip_ver))}) is outdated (20.3.0 or later required), "
                "try upgrading pip and install gstaichi again."
            )
            print()
            print_yellow_bold(" $ python3 -m pip install --upgrade pip")
            print_yellow_bold(" $ python3 -m pip install --force-reinstall gstaichi")
            print()

    except Exception:
        # Purely advisory: never let the check itself break the import.
        pass
gstaichi/_logging.py ADDED
@@ -0,0 +1,131 @@
1
+ # type: ignore
2
+
3
+ import inspect
4
+ import os
5
+
6
+ from gstaichi._lib import core as ti_python_core
7
+
8
+
9
+ def _get_logging(name):
10
+ """Generates a decorator to decorate a specific logger function.
11
+
12
+ Args:
13
+ name (str): The string represents logging level.
14
+ Effective levels include: 'trace', 'debug', 'info', 'warn', 'error', 'critical'.
15
+
16
+ Returns:
17
+ Callabe: The decorated function.
18
+ """
19
+
20
+ def logger(msg, *args, **kwargs):
21
+ # Python inspection takes time (~0.1ms) so avoid it as much as possible
22
+ if ti_python_core.logging_effective(name):
23
+ msg_formatted = msg.format(*args, **kwargs)
24
+ func = getattr(ti_python_core, name)
25
+ frame = inspect.currentframe().f_back
26
+ file_name, lineno, func_name, _, _ = inspect.getframeinfo(frame)
27
+ file_name = os.path.basename(file_name)
28
+ msg = f"[{file_name}:{func_name}@{lineno}] {msg_formatted}"
29
+ func(msg)
30
+
31
+ return logger
32
+
33
+
34
def set_logging_level(level):
    """Set the logging level of the native core.

    Available levels are: 'trace', 'debug', 'info', 'warn', 'error', 'critical'.

    After this call, the levels listed after ``level`` above are effective
    too. For example, setting ``level`` to 'warn' also makes 'error' and
    'critical' effective.

    See also https://docs.taichi-lang.org/docs/developer_utilities#logging.

    Args:
        level (str): Logging level.

    Example::

        >>> set_logging_level('debug')
    """
    core = ti_python_core
    core.set_logging_level(level)
52
+
53
+
54
def is_logging_effective(level):
    """Report whether messages at ``level`` are currently emitted.

    A level is effective when it is the current level or a more severe one.
    The default level is 'info'.

    See also https://docs.taichi-lang.org/docs/developer_utilities#logging.

    Args:
        level (str): The logging level name.
            Effective levels include: 'trace', 'debug', 'info', 'warn', 'error', 'critical'.

    Returns:
        Bool: Whether the logging level is effective.

    Example::

        >>> # assume current level is 'info'
        >>> print(ti.is_logging_effective("trace"))     # False
        >>> print(ti.is_logging_effective("debug"))     # False
        >>> print(ti.is_logging_effective("info"))      # True
        >>> print(ti.is_logging_effective("warn"))      # True
        >>> print(ti.is_logging_effective("error"))     # True
        >>> print(ti.is_logging_effective("critical"))  # True
    """
    effective = ti_python_core.logging_effective(level)
    return effective
79
+
80
+
81
+ # ------------------------
82
+
83
DEBUG = "debug"
"""The `str` 'debug', used for the `debug` logging level.
"""
# ------------------------

TRACE = "trace"
"""The `str` 'trace', used for the `trace` logging level.
"""
# ------------------------

INFO = "info"
"""The `str` 'info', used for the `info` logging level.
"""
# ------------------------

WARN = "warn"
"""The `str` 'warn', used for the `warn` logging level.
"""
# ------------------------

ERROR = "error"
"""The `str` 'error', used for the `error` logging level.
"""
# ------------------------

CRITICAL = "critical"
"""The `str` 'critical', used for the `critical` logging level.
"""
# ------------------------

# All level names accepted by the logging helpers in this module.
supported_log_levels = [DEBUG, TRACE, INFO, WARN, ERROR, CRITICAL]

# One module-level logging function per level, e.g. `info("x = {}", x)`.
debug = _get_logging(DEBUG)
trace = _get_logging(TRACE)
info = _get_logging(INFO)
warn = _get_logging(WARN)
error = _get_logging(ERROR)
critical = _get_logging(CRITICAL)

__all__ = [
    "DEBUG",
    "TRACE",
    "INFO",
    "WARN",
    "ERROR",
    "CRITICAL",
    "set_logging_level",
    "is_logging_effective",
]
@@ -0,0 +1,5 @@
1
+ # type: ignore
2
+
3
+ from gstaichi._snode.fields_builder import FieldsBuilder
4
+
5
+ __all__ = ["FieldsBuilder"]
@@ -0,0 +1,187 @@
1
+ # type: ignore
2
+
3
+ from typing import Any, Optional, Sequence, Union
4
+
5
+ from gstaichi._lib import core as _ti_core
6
+ from gstaichi._lib.core.gstaichi_python import SNodeCxx
7
+ from gstaichi._snode.snode_tree import SNodeTree
8
+ from gstaichi.lang import impl, snode
9
+ from gstaichi.lang.exception import GsTaichiRuntimeError
10
+ from gstaichi.lang.util import warning
11
+
12
+ _snode_registry = _ti_core.SNodeRegistry()
13
+
14
+ _Axis = _ti_core.Axis
15
+
16
+
17
class FieldsBuilder:
    """A builder that constructs a SNodeTree instance.

    Example::

        x = ti.field(ti.i32)
        y = ti.field(ti.f32)
        fb = ti.FieldsBuilder()
        fb.dense(ti.ij, 8).place(x)
        fb.pointer(ti.ij, 8).dense(ti.ij, 4).place(y)

        # After this line, `x` and `y` are placed. No more fields can be placed
        # into `fb`.
        #
        # The tree looks like the following:
        # (implicit root)
        #  |
        #  +-- dense +-- place(x)
        #  |
        #  +-- pointer +-- dense +-- place(y)
        fb.finalize()
    """

    def __init__(self):
        self.ptr: SNodeCxx = _snode_registry.create_root(impl.get_runtime().prog)
        self.root = snode.SNode(self.ptr)
        self.finalized = False
        self.empty = True
        impl.get_runtime().initialize_fields_builder(self)

    # TODO: move this into SNodeTree
    @classmethod
    def _finalized_roots(cls):
        """Gets all the roots of the finalized SNodeTree.

        Returns:
            A list of the roots of the finalized SNodeTree.
        """
        size = impl.get_runtime().prog.get_snode_tree_size()
        return [snode.SNode(impl.get_runtime().prog.get_snode_root(i)) for i in range(size)]

    @staticmethod
    def _require_sparse(message):
        """Raise ``GsTaichiRuntimeError(message)`` unless the current arch supports the sparse extension."""
        if not _ti_core.is_extension_supported(impl.current_cfg().arch, _ti_core.Extension.sparse):
            raise GsTaichiRuntimeError(message)

    def _mark_used(self):
        """Reject use after finalization, then record that the builder is no longer empty."""
        self._check_not_finalized()
        self.empty = False

    # TODO: move this to SNodeTree class.
    def deactivate_all(self):
        """Same as :func:`gstaichi.lang.snode.SNode.deactivate_all`"""
        if not self.finalized:
            warning("'deactivate_all()' would do nothing if FieldsBuilder is not finalized")
            return
        self.root.deactivate_all()

    def dense(
        self,
        indices: Union[Sequence[_Axis], _Axis],
        dimensions: Union[Sequence[int], int],
    ):
        """Same as :func:`gstaichi.lang.snode.SNode.dense`"""
        self._mark_used()
        return self.root.dense(indices, dimensions)

    def pointer(
        self,
        indices: Union[Sequence[_Axis], _Axis],
        dimensions: Union[Sequence[int], int],
    ):
        """Same as :func:`gstaichi.lang.snode.SNode.pointer`"""
        self._require_sparse("Pointer SNode is not supported on this backend.")
        self._mark_used()
        return self.root.pointer(indices, dimensions)

    def _hash(self, indices, dimensions):
        """Same as :func:`gstaichi.lang.snode.SNode.hash`"""
        raise NotImplementedError()

    def dynamic(
        self,
        index: Union[Sequence[_Axis], _Axis],
        dimension: Union[Sequence[int], int],
        chunk_size: Optional[int] = None,
    ):
        """Same as :func:`gstaichi.lang.snode.SNode.dynamic`"""
        self._require_sparse("Dynamic SNode is not supported on this backend.")

        # Both sizes must fit in a signed 32-bit integer.
        int32_limit = 2**31
        if dimension >= int32_limit:
            raise GsTaichiRuntimeError(
                f"The maximum dimension of a dynamic SNode cannot exceed the maximum value of a 32-bit signed integer: Got {dimension} > 2**31-1"
            )
        if chunk_size is not None and chunk_size >= int32_limit:
            raise GsTaichiRuntimeError(
                f"Chunk size cannot exceed the maximum value of a 32-bit signed integer: Got {chunk_size} > 2**31-1"
            )

        self._mark_used()
        return self.root.dynamic(index, dimension, chunk_size)

    def bitmasked(
        self,
        indices: Union[Sequence[_Axis], _Axis],
        dimensions: Union[Sequence[int], int],
    ):
        """Same as :func:`gstaichi.lang.snode.SNode.bitmasked`"""
        self._require_sparse("Bitmasked SNode is not supported on this backend.")
        self._mark_used()
        return self.root.bitmasked(indices, dimensions)

    def quant_array(
        self,
        indices: Union[Sequence[_Axis], _Axis],
        dimensions: Union[Sequence[int], int],
        max_num_bits: int,
    ):
        """Same as :func:`gstaichi.lang.snode.SNode.quant_array`"""
        self._mark_used()
        return self.root.quant_array(indices, dimensions, max_num_bits)

    def place(self, *args: Any, offset: Optional[Union[Sequence[int], int]] = None):
        """Same as :func:`gstaichi.lang.snode.SNode.place`"""
        self._mark_used()
        self.root.place(*args, offset=offset)

    def lazy_grad(self):
        """Same as :func:`gstaichi.lang.snode.SNode.lazy_grad`"""
        # TODO: This complicates the implementation. Figure out why we need this
        self._mark_used()
        self.root.lazy_grad()

    def _allocate_adjoint_checkbit(self):
        """Same as :func:`gstaichi.lang.snode.SNode._allocate_adjoint_checkbit`"""
        self._mark_used()
        self.root._allocate_adjoint_checkbit()

    def lazy_dual(self):
        """Same as :func:`gstaichi.lang.snode.SNode.lazy_dual`"""
        # TODO: This complicates the implementation. Figure out why we need this
        self._mark_used()
        self.root.lazy_dual()

    def finalize(self, raise_warning=True):
        """Constructs the SNodeTree and finalizes this builder.

        Args:
            raise_warning (bool): Raise warning or not."""
        return self._finalize(raise_warning, compile_only=False)

    def _finalize(self, raise_warning, compile_only) -> SNodeTree:
        """Mark the builder finalized, notify the runtime and wrap the finalized tree in a SNodeTree."""
        self._check_not_finalized()
        if self.empty and raise_warning:
            warning("Finalizing an empty FieldsBuilder!")
        self.finalized = True
        impl.get_runtime().finalize_fields_builder(self)
        tree_ptr = _ti_core.finalize_snode_tree(_snode_registry, self.ptr, impl.get_runtime()._prog, compile_only)
        return SNodeTree(tree_ptr)

    def _check_not_finalized(self):
        """Raise ``GsTaichiRuntimeError`` if this builder has already been finalized."""
        if self.finalized:
            raise GsTaichiRuntimeError("FieldsBuilder finalized")
@@ -0,0 +1,34 @@
1
+ # type: ignore
2
+
3
+ # The reason we import just the gstaichi.core.util module, instead of the ti_python_core
4
+ # object within it, is that ti_python_core is stateful. While in practice ti_python_core is
5
+ # loaded during the import procedure, it's probably still good to delay the
6
+ # access to it.
7
+
8
+ from gstaichi.lang import impl
9
+ from gstaichi.lang.exception import GsTaichiRuntimeError
10
+
11
+
12
class SNodeTree:
    """Handle to a finalized SNode tree, remembering the program that was current at construction."""

    def __init__(self, ptr):
        self.ptr = ptr
        self.prog = impl.get_runtime().prog
        self.destroyed = False

    def _ensure_alive(self):
        """Raise ``GsTaichiRuntimeError`` if this tree has already been destroyed."""
        if self.destroyed:
            raise GsTaichiRuntimeError("SNode tree has been destroyed")

    def destroy(self):
        """Destroy the underlying tree and clear the runtime's compiled functions.

        Returns without doing anything (and without marking the tree as
        destroyed) when the runtime's current program is not the one this
        tree was created with.

        Raises:
            GsTaichiRuntimeError: If the tree has already been destroyed.
        """
        self._ensure_alive()
        if self.prog != impl.get_runtime().prog:
            return
        self.ptr.destroy_snode_tree(impl.get_runtime().prog)

        # FieldExpression holds a SNode* to the place-SNode associated with a SNodeTree
        # Therefore, we have to recompile all the kernels after destroying a SNodeTree
        impl.get_runtime().clear_compiled_functions()
        self.destroyed = True

    @property
    def id(self):
        """Identifier reported by the underlying tree; raises if the tree was destroyed."""
        self._ensure_alive()
        return self.ptr.id()
@@ -0,0 +1,18 @@
1
+ import gstaichi as ti
2
+
3
+ from . import textwrap2
4
+
5
+
6
def ti_init_same_arch(**options) -> None:
    """
    Re-run ``ti.init`` with the currently configured arch.

    Test helper: looks up the attribute of ``ti`` named after
    ``ti.cfg.arch.name`` and passes it as ``arch`` together with ``options``
    (any ``arch`` given in ``options`` is overridden).
    """
    assert ti.cfg is not None
    current_arch = getattr(ti, ti.cfg.arch.name)
    init_kwargs = {**options, "arch": current_arch}
    ti.init(**init_kwargs)
16
+
17
+
18
+ __all__ = ["textwrap2"]
@@ -0,0 +1,36 @@
1
+ import dataclasses
2
+ from typing import Any, cast
3
+
4
+ import gstaichi as ti
5
+
6
+
7
def _make_child_obj(obj_type: Any) -> Any:
    """Create a placeholder object for a parameter or field of type ``obj_type``.

    * ``ti.types.NDArray`` instance: a ``ti.ndarray`` of the annotated dtype
      with every one of its ``ndim`` dimensions of size 10.
    * dataclass: an instance built recursively via ``build_struct``.
    * ``ti.Template`` (the class or an instance): a ``ti.i32`` field of shape ``(10,)``.

    Raises:
        Exception: For any other type.
    """
    if isinstance(obj_type, ti.types.NDArray):
        ndarray_type = cast(ti.types.ndarray, obj_type)
        assert ndarray_type.ndim is not None
        shape = tuple([10] * ndarray_type.ndim)
        return ti.ndarray(ndarray_type.dtype, shape=shape)
    if dataclasses.is_dataclass(obj_type):
        return build_struct(obj_type)
    if isinstance(obj_type, ti.Template) or obj_type == ti.Template:
        return ti.field(ti.i32, (10,))
    raise Exception("unknown type ", obj_type)
20
+
21
+
22
def build_struct(struct_type: Any) -> Any:
    """Instantiate dataclass ``struct_type`` with a placeholder object for every field.

    Each field value is produced by ``_make_child_obj`` from the field's
    declared type and passed to the constructor by keyword.
    """
    members = {
        member.name: _make_child_obj(member.type)
        for member in dataclasses.fields(struct_type)
    }
    return struct_type(**members)
29
+
30
+
31
def build_obj_tuple_from_type_dict(name_to_type: dict[str, Any]) -> tuple[Any, ...]:
    """Build one placeholder object per entry of ``name_to_type``, in dict order.

    Only the types (the dict values) are used; the names are ignored.
    """
    return tuple(_make_child_obj(param_type) for param_type in name_to_type.values())
@@ -0,0 +1,30 @@
1
+ import importlib.util
2
+ import sys
3
+ import tempfile
4
+ from contextlib import contextmanager
5
+ from pathlib import Path
6
+
7
+
8
def import_kernel_from_file(kernel_filepath: Path, kernel_name: str):
    """Import ``kernel_filepath`` as module ``kernel_name`` and return its attribute of the same name.

    The module is registered in ``sys.modules`` under ``kernel_name`` before
    it is executed. If executing it, or looking up the attribute, fails, the
    registration is removed again so no half-initialised module is left
    behind.

    Args:
        kernel_filepath: Path of the Python source file to load.
        kernel_name: Module name to register, and attribute name to return.

    Returns:
        The object named ``kernel_name`` defined by the module.
    """
    spec = importlib.util.spec_from_file_location(kernel_name, kernel_filepath)
    assert spec is not None
    module = importlib.util.module_from_spec(spec)
    sys.modules[kernel_name] = module
    loader = spec.loader
    assert loader is not None
    try:
        loader.exec_module(module)
        return getattr(module, kernel_name)
    except BaseException:
        # Only remove the entry if it is still the module registered above.
        if sys.modules.get(kernel_name) is module:
            del sys.modules[kernel_name]
        raise
17
+
18
+
19
@contextmanager
def load_kernel_from_string(kernel_str: str, kernel_name: str):
    """Context manager yielding the object ``kernel_name`` defined by source text ``kernel_str``.

    The source is written to ``<kernel_name>.py`` in a temporary directory
    and imported with ``import_kernel_from_file``. On exit the module is
    removed from ``sys.modules`` and the temporary directory is deleted.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        filepath = Path(temp_dir) / f"{kernel_name}.py"
        filepath.write_text(kernel_str)
        try:
            yield import_kernel_from_file(kernel_filepath=filepath, kernel_name=kernel_name)
        finally:
            sys.modules.pop(kernel_name, None)