gstaichi 2.1.1rc3__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. gstaichi/CHANGELOG.md +4 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +3 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cpython-311-darwin.so +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  11. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  12. gstaichi/_lib/utils.py +243 -0
  13. gstaichi/_logging.py +131 -0
  14. gstaichi/_snode/__init__.py +5 -0
  15. gstaichi/_snode/fields_builder.py +187 -0
  16. gstaichi/_snode/snode_tree.py +34 -0
  17. gstaichi/_test_tools/__init__.py +18 -0
  18. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  19. gstaichi/_test_tools/load_kernel_string.py +30 -0
  20. gstaichi/_test_tools/textwrap2.py +6 -0
  21. gstaichi/_version.py +1 -0
  22. gstaichi/_version_check.py +100 -0
  23. gstaichi/ad/__init__.py +3 -0
  24. gstaichi/ad/_ad.py +530 -0
  25. gstaichi/algorithms/__init__.py +3 -0
  26. gstaichi/algorithms/_algorithms.py +117 -0
  27. gstaichi/assets/.git +1 -0
  28. gstaichi/assets/Go-Regular.ttf +0 -0
  29. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  30. gstaichi/examples/lcg_python.py +26 -0
  31. gstaichi/examples/lcg_taichi.py +34 -0
  32. gstaichi/examples/minimal.py +28 -0
  33. gstaichi/experimental.py +16 -0
  34. gstaichi/lang/__init__.py +50 -0
  35. gstaichi/lang/_dataclass_util.py +31 -0
  36. gstaichi/lang/_fast_caching/__init__.py +3 -0
  37. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  38. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  39. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  40. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  41. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  42. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  43. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  44. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  45. gstaichi/lang/_ndarray.py +352 -0
  46. gstaichi/lang/_ndrange.py +152 -0
  47. gstaichi/lang/_template_mapper.py +195 -0
  48. gstaichi/lang/_texture.py +172 -0
  49. gstaichi/lang/_wrap_inspect.py +215 -0
  50. gstaichi/lang/any_array.py +99 -0
  51. gstaichi/lang/ast/__init__.py +5 -0
  52. gstaichi/lang/ast/ast_transformer.py +1323 -0
  53. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  54. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  55. gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
  56. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  57. gstaichi/lang/ast/checkers.py +106 -0
  58. gstaichi/lang/ast/symbol_resolver.py +57 -0
  59. gstaichi/lang/ast/transform.py +9 -0
  60. gstaichi/lang/common_ops.py +310 -0
  61. gstaichi/lang/exception.py +80 -0
  62. gstaichi/lang/expr.py +180 -0
  63. gstaichi/lang/field.py +428 -0
  64. gstaichi/lang/impl.py +1243 -0
  65. gstaichi/lang/kernel_arguments.py +155 -0
  66. gstaichi/lang/kernel_impl.py +1341 -0
  67. gstaichi/lang/matrix.py +1835 -0
  68. gstaichi/lang/matrix_ops.py +341 -0
  69. gstaichi/lang/matrix_ops_utils.py +190 -0
  70. gstaichi/lang/mesh.py +687 -0
  71. gstaichi/lang/misc.py +782 -0
  72. gstaichi/lang/ops.py +1494 -0
  73. gstaichi/lang/runtime_ops.py +13 -0
  74. gstaichi/lang/shell.py +35 -0
  75. gstaichi/lang/simt/__init__.py +5 -0
  76. gstaichi/lang/simt/block.py +94 -0
  77. gstaichi/lang/simt/grid.py +7 -0
  78. gstaichi/lang/simt/subgroup.py +191 -0
  79. gstaichi/lang/simt/warp.py +96 -0
  80. gstaichi/lang/snode.py +489 -0
  81. gstaichi/lang/source_builder.py +150 -0
  82. gstaichi/lang/struct.py +810 -0
  83. gstaichi/lang/util.py +312 -0
  84. gstaichi/linalg/__init__.py +8 -0
  85. gstaichi/linalg/matrixfree_cg.py +310 -0
  86. gstaichi/linalg/sparse_cg.py +59 -0
  87. gstaichi/linalg/sparse_matrix.py +303 -0
  88. gstaichi/linalg/sparse_solver.py +123 -0
  89. gstaichi/math/__init__.py +11 -0
  90. gstaichi/math/_complex.py +205 -0
  91. gstaichi/math/mathimpl.py +886 -0
  92. gstaichi/profiler/__init__.py +6 -0
  93. gstaichi/profiler/kernel_metrics.py +260 -0
  94. gstaichi/profiler/kernel_profiler.py +586 -0
  95. gstaichi/profiler/memory_profiler.py +15 -0
  96. gstaichi/profiler/scoped_profiler.py +36 -0
  97. gstaichi/sparse/__init__.py +3 -0
  98. gstaichi/sparse/_sparse_grid.py +77 -0
  99. gstaichi/tools/__init__.py +12 -0
  100. gstaichi/tools/diagnose.py +117 -0
  101. gstaichi/tools/np2ply.py +364 -0
  102. gstaichi/tools/vtk.py +38 -0
  103. gstaichi/types/__init__.py +19 -0
  104. gstaichi/types/annotations.py +52 -0
  105. gstaichi/types/compound_types.py +71 -0
  106. gstaichi/types/enums.py +49 -0
  107. gstaichi/types/ndarray_type.py +169 -0
  108. gstaichi/types/primitive_types.py +206 -0
  109. gstaichi/types/quant.py +88 -0
  110. gstaichi/types/texture_type.py +85 -0
  111. gstaichi/types/utils.py +11 -0
  112. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
  113. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
  114. gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
  115. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
  116. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
  117. gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
  118. gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
  119. gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  120. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
  121. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
  122. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  123. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  124. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  125. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  126. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  127. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  128. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  129. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  130. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  131. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  132. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  133. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  134. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  135. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  136. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  137. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  138. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  139. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  140. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  141. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  142. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  143. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  144. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  145. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  146. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  147. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  148. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  149. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  150. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  151. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  152. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  153. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  154. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  155. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  156. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  157. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  158. gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  159. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  160. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  161. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  162. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  163. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  164. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  165. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  166. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  167. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  168. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  169. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  170. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  171. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  172. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  173. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  174. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  175. gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
  176. gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
  177. gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
  178. gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
  179. gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
gstaichi/lang/misc.py ADDED
@@ -0,0 +1,782 @@
1
+ # type: ignore
2
+
3
+ import atexit
4
+ import os
5
+ import shutil
6
+ import tempfile
7
+ import warnings
8
+ from copy import deepcopy as _deepcopy
9
+
10
+ from gstaichi import _logging, _snode, _version_check
11
+ from gstaichi._lib import core as _ti_core
12
+ from gstaichi.lang import impl
13
+ from gstaichi.lang.expr import Expr
14
+ from gstaichi.lang.impl import axes, get_runtime
15
+ from gstaichi.profiler.kernel_profiler import get_default_kernel_profiler
16
+ from gstaichi.types.primitive_types import f32, f64, i32, i64
17
+
18
# Show each matching DeprecationWarning coming from gstaichi modules only the
# first time it occurs (the "once" action), instead of on every call.
warnings.filterwarnings("once", category=DeprecationWarning, module="gstaichi")

# Predefined axis selectors built with ``impl.axes``; they are used when
# describing SNode layouts, e.g. ``ti.root.dense(ti.ij, (128, 4))``.
# ----------------------
i = axes(0)
"""Axis 0. For multi-dimensional arrays it's the direction downward the rows.
For a 1d array it's the direction along this array.
"""
# ----------------------

j = axes(1)
"""Axis 1. For multi-dimensional arrays it's the direction across the columns.
"""
# ----------------------

k = axes(2)
"""Axis 2. For arrays of dimension `d` >= 3, view each cell as an array of
lower dimension d-2, it's the first axis of this cell.
"""
# ----------------------

l = axes(3)
"""Axis 3. For arrays of dimension `d` >= 4, view each cell as an array of
lower dimension d-2, it's the second axis of this cell.
"""
# ----------------------

ij = axes(0, 1)
"""Axes (0, 1).
"""
# ----------------------

ik = axes(0, 2)
"""Axes (0, 2).
"""
# ----------------------

il = axes(0, 3)
"""Axes (0, 3).
"""
# ----------------------

jk = axes(1, 2)
"""Axes (1, 2).
"""
# ----------------------

jl = axes(1, 3)
"""Axes (1, 3).
"""
# ----------------------

kl = axes(2, 3)
"""Axes (2, 3).
"""
# ----------------------

ijk = axes(0, 1, 2)
"""Axes (0, 1, 2).
"""
# ----------------------

ijl = axes(0, 1, 3)
"""Axes (0, 1, 3).
"""
# ----------------------

ikl = axes(0, 2, 3)
"""Axes (0, 2, 3).
"""
# ----------------------

jkl = axes(1, 2, 3)
"""Axes (1, 2, 3).
"""
# ----------------------

ijkl = axes(0, 1, 2, 3)
"""Axes (0, 1, 2, 3).
"""
97
+ # ----------------------
98
+
99
+ # ----------------------
100
+
101
# Backend (arch) identifiers re-exported from the native core so users can
# write e.g. ``ti.init(arch=ti.cuda)``.
x86_64 = _ti_core.x64
"""The x64 CPU backend. An alias of :const:`x64`: both names refer to
``_ti_core.x64``.
"""
# ----------------------

x64 = _ti_core.x64
"""The X64 CPU backend.
"""
# ----------------------

arm64 = _ti_core.arm64
"""The ARM CPU backend.
"""
# ----------------------

cuda = _ti_core.cuda
"""The CUDA backend.
"""
# ----------------------

amdgpu = _ti_core.amdgpu
"""The AMDGPU backend.
"""
# ----------------------

metal = _ti_core.metal
"""The Apple Metal backend.
"""
# ----------------------

vulkan = _ti_core.vulkan
"""The Vulkan backend.
"""
# ----------------------

gpu = [cuda, metal, vulkan, amdgpu]
"""A fixed list of the GPU backends GsTaichi knows about: 'cuda', 'metal',
'vulkan', 'amdgpu'. The list itself does not depend on the current system.

When this is used, GsTaichi automatically picks the matching GPU backend. If no
GPU is detected, GsTaichi falls back to the CPU backend.
"""
# ----------------------

cpu = _ti_core.host_arch()
"""The CPU backend of the machine GsTaichi is running on, as reported by
``_ti_core.host_arch()`` when this module is imported (presumably 'x64' or
'arm64'). This is a single arch value that is compared directly against the
configured arch elsewhere in this module, not a list.

When this is used, GsTaichi automatically picks the matching CPU backend.
"""
151
+ # ----------------------
152
+
153
+
154
def timeline_clear():
    """Forward to ``timeline_clear()`` on the active runtime's program.

    Returns:
        Whatever ``prog.timeline_clear()`` returns.
    """
    program = impl.get_runtime().prog
    return program.timeline_clear()
156
+
157
+
158
def timeline_save(fn):
    """Forward to ``timeline_save(fn)`` on the active runtime's program.

    Args:
        fn: Passed through unchanged to ``prog.timeline_save``.

    Returns:
        Whatever ``prog.timeline_save(fn)`` returns.
    """
    program = impl.get_runtime().prog
    return program.timeline_save(fn)
160
+
161
+
162
# Re-export of the native ``Extension`` type; its members (e.g. ``extension.mesh``)
# are the values accepted by ``is_extension_supported``.
extension = _ti_core.Extension
"""The GsTaichi ``Extension`` type from the native core; use its members to
name a specific extension.

The list of currently available extensions is ['sparse', 'quant', \
'mesh', 'quant_basic', 'data64', 'adstack', 'bls', 'assertion', \
'extfunc'].
"""
169
+
170
+
171
def is_extension_supported(arch, ext):
    """Report whether backend ``arch`` supports extension ``ext``.

    Args:
        arch (gstaichi_python.Arch): The backend to query.
        ext (gstaichi_python.Extension): The extension to look up.

    Returns:
        bool: True when ``ext`` is available on ``arch``.
    """
    supported = _ti_core.is_extension_supported(arch, ext)
    return supported
182
+
183
+
184
def reset():
    """Resets GsTaichi to its initial state.

    This will destroy all the allocated fields and kernels, and restore
    the runtime to its default configuration.

    Example::

        >>> a = ti.field(ti.i32, shape=())
        >>> a[None] = 1
        >>> print("before reset: ", a)
        before reset:  1
        >>>
        >>> ti.reset()
        >>> print("after reset: ", a)
        # will raise error because a is unavailable after reset.
    """
    impl.reset()
    # Rebind the module-level ``runtime`` name to the runtime object that
    # impl reports after the reset.
    global runtime
    runtime = impl.get_runtime()
203
+
204
+
205
+ class _EnvironmentConfigurator:
206
+ def __init__(self, kwargs, _cfg):
207
+ self.cfg = _cfg
208
+ self.kwargs = kwargs
209
+ self.keys = []
210
+
211
+ def add(self, key, _cast=None):
212
+ _cast = _cast or self.bool_int
213
+
214
+ self.keys.append(key)
215
+
216
+ # TI_OFFLINE_CACHE= : no effect
217
+ # TI_OFFLINE_CACHE=0 : False
218
+ # TI_OFFLINE_CACHE=1 : True
219
+ name = "TI_" + key.upper()
220
+ value = os.environ.get(name, "")
221
+ if key in self.kwargs:
222
+ self[key] = self.kwargs[key]
223
+ if value:
224
+ _ti_core.warn(f'Environment variable {name}={value} overridden by ti.init argument "{key}"')
225
+ del self.kwargs[key] # mark as recognized
226
+ elif value:
227
+ self[key] = _cast(value)
228
+
229
+ def __getitem__(self, key):
230
+ return getattr(self.cfg, key)
231
+
232
+ def __setitem__(self, key, value):
233
+ setattr(self.cfg, key, value)
234
+
235
+ @staticmethod
236
+ def bool_int(x):
237
+ return bool(int(x))
238
+
239
+
240
+ class _SpecialConfig:
241
+ # like CompileConfig in C++, this is the configurations that belong to other submodules
242
+ def __init__(self):
243
+ self.log_level = "info"
244
+ self.gdb_trigger = False
245
+ self.short_circuit_operators = True
246
+ self.print_full_traceback = False
247
+ self.unrolling_limit = 32
248
+
249
+
250
def prepare_sandbox():
    """Create a temporary sandbox directory that is removed at interpreter exit.

    The directory name starts with ``gstaichi-`` and it contains an empty
    ``runtime/`` subdirectory. It may hold the gstaichi_python shared object
    or other miscellaneous files.

    Returns:
        str: Path of the newly created sandbox directory.
    """
    tmp_dir = tempfile.mkdtemp(prefix="gstaichi-")
    # Cleanup is best-effort: if the sandbox was already removed (or part of it
    # cannot be deleted), don't emit a traceback during interpreter shutdown.
    atexit.register(shutil.rmtree, tmp_dir, ignore_errors=True)
    print(f"[GsTaichi] preparing sandbox at {tmp_dir}")
    os.mkdir(os.path.join(tmp_dir, "runtime/"))
    return tmp_dir
260
+
261
+
262
def check_require_version(require_version):
    """Raise if the installed GsTaichi version does not satisfy ``require_version``.

    ``require_version`` has the form ``<major>.<minor>[.<patch>][.<hash>]``.
    Everything from the first character that is neither a digit nor ``.``
    onwards (e.g. a revision hash) is ignored, and ``patch`` defaults to 0.
    The requirement is met when the installed major version equals ``major``
    and the installed ``(minor, patch)`` is greater than or equal to the
    required one.

    Raises:
        Exception: if ``require_version`` cannot be parsed, or if the installed
            version does not satisfy it.
    """
    # Extract version number part (i.e. toss any revision / hash parts).
    version_number_str = require_version
    for c_idx, c in enumerate(require_version):
        if not (c.isdigit() or c == "."):
            version_number_str = require_version[:c_idx]
            break
    # A hash suffix such as "1.2.3.abc" leaves a dangling separator ("1.2.3.");
    # drop it so the numeric part still parses.
    version_number_str = version_number_str.rstrip(".")

    # Get required version.
    try:
        version_number_tuple = tuple(int(n) for n in version_number_str.split("."))
        major = version_number_tuple[0]
        minor = version_number_tuple[1]
        patch = version_number_tuple[2] if len(version_number_tuple) > 2 else 0
    except (ValueError, IndexError):
        # ValueError: a non-numeric/empty component; IndexError: minor missing.
        raise Exception(
            "The require_version should be formatted following PEP 440, "
            "and includes major, minor, and patch number, "
            "e.g., major.minor.patch."
        ) from None

    # Get installed version
    installed = [
        int(_ti_core.get_version_major()),
        int(_ti_core.get_version_minor()),
        int(_ti_core.get_version_patch()),
    ]
    # Same major version required; (minor, patch) compared lexicographically.
    match = major == installed[0] and (minor, patch) <= (installed[1], installed[2])

    if not match:
        raise Exception(
            f"GsTaichi version mismatch. Required version >= {major}.{minor}.{patch}, installed version = {_ti_core.get_version_string()}."
        )
301
+
302
+
303
def init(
    arch=None,
    default_fp=None,
    default_ip=None,
    _test_mode: bool = False,
    enable_fallback: bool = True,
    require_version: str | None = None,
    src_ll_cache: bool = True,
    **kwargs,
):
    """Initializes the GsTaichi runtime.

    This should always be the entry point of your GsTaichi program. Most
    importantly, it sets the backend used throughout the program.

    Args:
        arch: Backend to use. This is usually :const:`~gstaichi.lang.cpu` or :const:`~gstaichi.lang.gpu`.
            When the ``TI_ARCH`` environment variable is set, it takes precedence over this argument.
        default_fp (Optional[type]): Default floating-point type. Can also be set with
            ``TI_DEFAULT_FP=32`` or ``64``; the argument wins when both are given.
        default_ip (Optional[type]): Default integral type. Can also be set with
            ``TI_DEFAULT_IP=32`` or ``64``; the argument wins when both are given.
        _test_mode (bool): If True, only resolve the configuration (including the arch)
            and return the special config without creating a program.
        enable_fallback (bool): Forwarded to ``adaptive_arch_select`` together with ``arch``.
        require_version: A version string, checked by ``check_require_version``.
        src_ll_cache: enable SRC-LL-CACHE, which will accelerate loading from cache, across all architectures,
            for pure kernels (i.e. kernels declared as @ti.pure)
        **kwargs: GsTaichi provides highly customizable compilation through
            ``kwargs``, which allows for fine grained control of GsTaichi compiler
            behavior. Below we list some of the most frequently used ones. For a
            complete list, please check out
            https://github.com/taichi-dev/gstaichi/blob/master/gstaichi/program/compile_config.h.

            * ``cpu_max_num_threads`` (int): Sets the number of threads used by the CPU thread pool.
            * ``debug`` (bool): Enables the debug mode, under which GsTaichi does a few more things like boundary checks.
            * ``print_ir`` (bool): Prints the CHI IR of the GsTaichi kernels.
            * ``offline_cache`` (bool): Enables offline cache of the compiled kernels. Default to True. When this is enabled GsTaichi will cache compiled kernel on your local disk to accelerate future calls.
            * ``random_seed`` (int): Sets the seed of the random generator. The default is 0.

    Returns:
        The special (submodule) configuration object when ``_test_mode`` is True, otherwise None.

    Raises:
        KeyError: if ``default_up`` or any unrecognized keyword argument is passed.
        ValueError: if ``TI_DEFAULT_FP`` / ``TI_DEFAULT_IP`` is set to something other than 32 or 64.
    """
    # Check version for users every 7 days if not disabled by users.
    _version_check.start_version_check_thread()

    # FIXME(https://github.com/taichi-dev/gstaichi/issues/4811): save the current working directory since it may be
    # changed by the Vulkan backend initialization on OS X.
    current_dir = os.getcwd()

    # Check if installed version meets the requirements.
    if require_version is not None:
        check_require_version(require_version)

    if "default_up" in kwargs:
        raise KeyError("'default_up' is always the unsigned type of 'default_ip'. Please set 'default_ip' instead.")
    # Make a deepcopy in case these args reference to items from ti.cfg, which are
    # actually references. If no copy is made and the args are indeed references,
    # ti.reset() could override the args to their default values.
    default_fp = _deepcopy(default_fp)
    default_ip = _deepcopy(default_ip)
    kwargs = _deepcopy(kwargs)
    reset()

    cfg = impl.default_cfg()
    cfg.offline_cache = True  # Enable offline cache in frontend instead of C++ side

    spec_cfg = _SpecialConfig()
    env_comp = _EnvironmentConfigurator(kwargs, cfg)
    env_spec = _EnvironmentConfigurator(kwargs, spec_cfg)

    # configure default_fp/ip:
    # TODO: move these stuff to _SpecialConfig too:
    env_default_fp = os.environ.get("TI_DEFAULT_FP")
    if env_default_fp:
        if default_fp is not None:
            _ti_core.warn(
                f'Environment variable TI_DEFAULT_FP={env_default_fp} overridden by ti.init argument "default_fp"'
            )
        elif env_default_fp == "32":
            default_fp = f32
        elif env_default_fp == "64":
            default_fp = f64
        else:
            raise ValueError(f"Invalid TI_DEFAULT_FP={env_default_fp}, should be 32 or 64")

    env_default_ip = os.environ.get("TI_DEFAULT_IP")
    if env_default_ip:
        if default_ip is not None:
            _ti_core.warn(
                f'Environment variable TI_DEFAULT_IP={env_default_ip} overridden by ti.init argument "default_ip"'
            )
        elif env_default_ip == "32":
            default_ip = i32
        elif env_default_ip == "64":
            default_ip = i64
        else:
            raise ValueError(f"Invalid TI_DEFAULT_IP={env_default_ip}, should be 32 or 64")

    if default_fp is not None:
        impl.get_runtime().set_default_fp(default_fp)
    if default_ip is not None:
        impl.get_runtime().set_default_ip(default_ip)

    # submodule configurations (spec_cfg):
    env_spec.add("log_level", str)
    env_spec.add("gdb_trigger")
    env_spec.add("short_circuit_operators")
    env_spec.add("print_full_traceback")
    # unrolling_limit is an integer; with the default bool cast,
    # TI_UNROLLING_LIMIT=64 would have been stored as True.
    env_spec.add("unrolling_limit", int)

    # compiler configurations (ti.cfg):
    for key in dir(cfg):
        if key in ["arch", "default_fp", "default_ip"]:
            continue
        _cast = type(getattr(cfg, key))
        if _cast is bool:
            _cast = None  # use the configurator's "0"/"1" parser for booleans
        env_comp.add(key, _cast)

    # Every recognized key was deleted from kwargs by the configurators above.
    unexpected_keys = kwargs.keys()

    if len(unexpected_keys):
        raise KeyError(f'Unrecognized keyword argument(s) for ti.init: {", ".join(unexpected_keys)}')

    # dispatch configurations that are not in ti.cfg:
    if not _test_mode:
        _ti_core.set_core_trigger_gdb_when_crash(spec_cfg.gdb_trigger)
        impl.get_runtime().short_circuit_operators = spec_cfg.short_circuit_operators
        impl.get_runtime().print_full_traceback = spec_cfg.print_full_traceback
        impl.get_runtime().unrolling_limit = spec_cfg.unrolling_limit
        impl.get_runtime().src_ll_cache = src_ll_cache
        _logging.set_logging_level(spec_cfg.log_level.lower())

    # select arch (backend):
    env_arch = os.environ.get("TI_ARCH")
    if env_arch is not None:
        _logging.info(f"Following TI_ARCH setting up for arch={env_arch}")
        arch = _ti_core.arch_from_name(env_arch)
    cfg.arch = adaptive_arch_select(arch, enable_fallback)
    print(f"[GsTaichi] Starting on arch={_ti_core.arch_name(cfg.arch)}")

    if _test_mode:
        return spec_cfg

    get_default_kernel_profiler().set_kernel_profiler_mode(cfg.kernel_profiler)

    # create a new program:
    impl.get_runtime().create_program()

    _logging.trace("Materializing runtime...")
    impl.get_runtime().prog.materialize_runtime()

    impl._root_fb = _snode.FieldsBuilder()

    if cfg.debug:
        impl.get_runtime()._register_signal_handlers()

    # Recover the current working directory (https://github.com/taichi-dev/gstaichi/issues/4811)
    os.chdir(current_dir)
    return None
455
+
456
+
457
def no_activate(*args):
    """Pass each argument's SNode to ``no_activate`` on the kernel being compiled.

    Must be called while a ``KernelCxx`` is being compiled; this is checked
    with an assertion.

    Args:
        *args: Objects exposing ``._snode.ptr`` (e.g. fields).
    """
    kernel = get_runtime().compiling_callable
    assert isinstance(kernel, _ti_core.KernelCxx)
    for field in args:
        kernel.no_activate(field._snode.ptr)
462
+
463
+
464
def block_local(*args):
    """Hint GsTaichi to cache the given fields and enable the BLS optimization.

    If the current ``opt_level`` is 0 it is raised to 1 (with a warning),
    because BLS analysis needs it.

    Please visit https://docs.taichi-lang.org/docs/performance
    for how BLS is used.

    Args:
        *args (List[Field]): A list of sparse GsTaichi fields.
    """
    cfg = impl.current_cfg()
    if cfg.opt_level == 0:
        _logging.warn("opt_level = 1 is enforced to enable bls analysis.")
        cfg.opt_level = 1

    flag = _ti_core.SNodeAccessFlag.block_local
    for field in args:
        for member in field._get_field_members():
            get_runtime().compiling_callable.ast_builder().insert_snode_access_flag(flag, member.ptr)
481
+
482
+
483
def mesh_local(*args):
    """Hint the compiler to cache the given mesh attributes and enable the
    mesh BLS optimization.

    Only available for backends supporting `ti.extension.mesh`, and meant to
    be used together with a mesh-for loop.

    Related to https://github.com/taichi-dev/gstaichi/issues/3608

    Args:
        *args (List[Attribute]): Mesh attributes, or fields accessed as attributes.

    Examples::

        # instantiate model (``meta`` is mesh metadata prepared beforehand)
        mesh_builder = ti.Mesh.tri()
        mesh_builder.verts.place({
            'x' : ti.f32,
            'y' : ti.f32
        })
        model = mesh_builder.build(meta)

        @ti.kernel
        def foo():
            # hint the compiler to cache mesh vertex attribute `x` and `y`.
            ti.mesh_local(model.verts.x, model.verts.y)
            for v0 in model.verts: # mesh-for loop
                for v1 in v0.verts:
                    v0.x += v1.y
    """
    flag = _ti_core.SNodeAccessFlag.mesh_local
    for attribute in args:
        for member in attribute._get_field_members():
            get_runtime().compiling_callable.ast_builder().insert_snode_access_flag(flag, member.ptr)
516
+
517
+
518
def cache_read_only(*args):
    """Attach the ``read_only`` SNode access flag to every member of the given fields.

    The flag is inserted into the AST builder of the kernel being compiled.
    Presumably it tells the compiler the fields are only read there, similar
    to ``block_local`` / ``mesh_local`` — confirm against the C++ side.

    Args:
        *args: Fields (objects providing ``_get_field_members()``).
    """
    flag = _ti_core.SNodeAccessFlag.read_only
    for field in args:
        for member in field._get_field_members():
            get_runtime().compiling_callable.ast_builder().insert_snode_access_flag(flag, member.ptr)
524
+
525
+
526
def assume_in_range(val, base, low, high):
    """Hint the compiler that ``val`` lies in ``[base + low, base + high)``.

    The hint lets the compiler perform scratchpad optimization; the value
    itself is returned untouched.

    Args:
        val (Number): The input value.
        base (Number): The base point for the range interval.
        low (Number): The lower offset relative to `base` (included).
        high (Number): The higher offset relative to `base` (excluded).

    Returns:
        The input ``val``, untouched.

    Example::

        >>> # hint the compiler that x is in range [8, 12).
        >>> x = ti.assume_in_range(x, 10, -2, 2)
        >>> x
        10
    """
    val_ptr = Expr(val).ptr
    base_ptr = Expr(base).ptr
    debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
    return _ti_core.expr_assume_in_range(val_ptr, base_ptr, low, high, debug_info)
553
+
554
+
555
def loop_unique(val, covers=None):
    """Build a ``loop_unique`` expression for ``val`` in the kernel being compiled.

    Args:
        val: Value converted with ``Expr``.
        covers: ``None``, a single item, or a list/tuple of items. ``Expr``
            items contribute ``item.snode.ptr``; any other item contributes
            ``item.ptr``.

    Returns:
        The result of ``_ti_core.expr_loop_unique``.
    """
    if covers is None:
        cover_list = []
    elif isinstance(covers, (list, tuple)):
        cover_list = covers
    else:
        cover_list = [covers]
    snode_ptrs = [c.snode.ptr if isinstance(c, Expr) else c.ptr for c in cover_list]
    debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
    return _ti_core.expr_loop_unique(Expr(val).ptr, snode_ptrs, debug_info)
564
+
565
+
566
def _parallelize(v):
    """Sets the number of threads to use on CPU for the next loop.

    A value of 1 additionally requests strict serialization.
    """
    get_runtime().compiling_callable.ast_builder().parallelize(v)
    if v != 1:
        return
    get_runtime().compiling_callable.ast_builder().strictly_serialize()
571
+
572
+
573
def _serialize():
    """Run the next loop on a single thread; shorthand for ``_parallelize(1)``."""
    return _parallelize(1)
576
+
577
+
578
def _block_dim(dim):
    """Request ``dim`` threads per block for the next loop."""
    builder = get_runtime().compiling_callable.ast_builder()
    builder.block_dim(dim)
581
+
582
+
583
def _block_dim_adaptive(block_dim_adaptive):
    """Enable or disable adaptive block_dim selection.

    Only honoured when the program's arch equals the host CPU arch; otherwise
    a warning is logged and nothing changes.
    """
    config = get_runtime().prog.config()
    if config.arch == cpu:
        config.cpu_block_dim_adaptive = block_dim_adaptive
        return
    _logging.warn("Adaptive block_dim is supported on CPU backend only")
589
+
590
+
591
def _bit_vectorize():
    """Request bit vectorization of struct-fors over quant_arrays in the next loop."""
    builder = get_runtime().compiling_callable.ast_builder()
    builder.bit_vectorize()
594
+
595
+
596
def loop_config(
    *,
    block_dim=None,
    serialize=False,
    parallelize=None,
    block_dim_adaptive=True,
    bit_vectorize=False,
):
    """Sets directives for the next loop.

    Args:
        block_dim (int): The number of threads in a block on GPU
        serialize (bool): Whether to let the for loop execute serially, `serialize=True` equals to `parallelize=1`
            (and takes precedence over ``parallelize``)
        parallelize (int): The number of threads to use on CPU
        block_dim_adaptive (bool): Whether to allow backends set block_dim adaptively, enabled by default
        bit_vectorize (bool): Whether to enable bit vectorization of struct fors on quant_arrays.

    Examples::

        @ti.kernel
        def break_in_serial_for() -> ti.i32:
            a = 0
            ti.loop_config(serialize=True)
            for i in range(100):  # This loop runs serially
                a += i
                if i == 10:
                    break
            return a

        break_in_serial_for()  # returns 55

        n = 128
        val = ti.field(ti.i32, shape=n)
        @ti.kernel
        def fill():
            ti.loop_config(parallelize=8, block_dim=16)
            # If the kernel is run on the CPU backend, 8 threads will be used to run it
            # If the kernel is run on the CUDA backend, each block will have 16 threads.
            for i in range(n):
                val[i] = i

        u1 = ti.types.quant.int(bits=1, signed=False)
        x = ti.field(dtype=u1)
        y = ti.field(dtype=u1)
        cell = ti.root.dense(ti.ij, (128, 4))
        cell.quant_array(ti.j, 32).place(x)
        cell.quant_array(ti.j, 32).place(y)
        @ti.kernel
        def copy():
            ti.loop_config(bit_vectorize=True)
            # 32 bits, instead of 1 bit, will be copied at a time
            for i, j in x:
                y[i, j] = x[i, j]
    """
    if block_dim is not None:
        _block_dim(block_dim)

    # serialize=True wins over any explicit thread count.
    thread_count = 1 if serialize else parallelize
    if thread_count is not None:
        _parallelize(thread_count)

    # Adaptive block_dim is the default, so only a request to turn it off needs dispatching.
    if not block_dim_adaptive:
        _block_dim_adaptive(block_dim_adaptive)

    if bit_vectorize:
        _bit_vectorize()
663
+
664
+
665
def global_thread_idx():
    """Return the global thread id of the currently running thread.

    Only available on the CPU and CUDA backends:

    - On CPU backends this equals the CPU thread id.
    - On CUDA backends this equals ``block_id * block_dim + thread_id``.

    The expression is inserted through the AST builder of the callable that
    is currently being compiled, so presumably this is only meaningful inside
    a kernel body (verify against callers).

    Example::

        >>> f = ti.field(ti.f32, shape=(16, 16))
        >>> @ti.kernel
        ... def test():
        ...     for i in ti.grouped(f):
        ...         print(ti.global_thread_idx())
        ...
        >>> test()
    """
    return impl.get_runtime().compiling_callable.ast_builder().insert_thread_idx_expr()
683
+
684
+
685
def mesh_patch_idx():
    """Return the internal mesh patch id of the currently running thread.

    Only available on backends supporting ``ti.extension.mesh``, and only
    inside a mesh-for loop.

    Related to https://github.com/taichi-dev/gstaichi/issues/3608
    (NOTE(review): this may have been meant to point at the upstream
    taichi-dev/taichi repository; verify the link).
    """
    builder = impl.get_runtime().compiling_callable.ast_builder()
    # Attach the current source location so diagnostics point at user code.
    debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
    return builder.insert_patch_idx_expr(debug_info)
696
+
697
+
698
def is_arch_supported(arch):
    """Check whether an arch is supported on this machine.

    Args:
        arch (gstaichi_python.Arch): Arch to query.

    Returns:
        bool: Whether `arch` is supported on the machine. Arches not listed in
        the lookup table yield False. If the backend probe raises, a warning
        is emitted and False is returned.
    """
    arch_table = {
        cuda: _ti_core.with_cuda,
        amdgpu: _ti_core.with_amdgpu,
        metal: _ti_core.with_metal,
        vulkan: _ti_core.with_vulkan,
        cpu: lambda: True,
    }
    with_arch = arch_table.get(arch, lambda: False)
    try:
        return with_arch()
    except Exception as e:
        # Probing is best-effort: report the failure and treat the arch as
        # unsupported instead of crashing. Use a separate name so the `arch`
        # argument is not shadowed.
        arch_name = _ti_core.arch_name(arch)
        _ti_core.warn(
            f"{e.__class__.__name__}: '{e}' occurred when detecting "
            f"{arch_name}, consider adding `TI_ENABLE_{arch_name.upper()}=0` "
            f"to environment variables to suppress this warning message."
        )
        return False
726
+
727
+
728
def adaptive_arch_select(arch, enable_fallback):
    """Pick the first supported arch among the requested candidates.

    Args:
        arch: None (selects ``cpu``), a single arch, or a list/tuple of
            candidate arches in order of preference.
        enable_fallback (bool): When no candidate is supported, fall back to
            ``cpu`` with a warning if truthy; otherwise raise.

    Returns:
        The first candidate for which ``is_arch_supported`` is true, or
        ``cpu``.

    Raises:
        RuntimeError: If no candidate is supported and ``enable_fallback``
            is falsy.
    """
    if arch is None:
        return cpu
    # Keep the original container (list or tuple) so messages show it as-is.
    candidates = arch if isinstance(arch, (list, tuple)) else [arch]
    not_found = object()
    chosen = next((a for a in candidates if is_arch_supported(a)), not_found)
    if chosen is not not_found:
        return chosen
    if enable_fallback:
        _logging.warn(f"Arch={candidates} is not supported, falling back to CPU")
        return cpu
    raise RuntimeError(f"Arch={candidates} is not supported")
740
+
741
+
742
def get_host_arch_list():
    """Return a one-element list containing ``_ti_core.host_arch()``."""
    host_arch = _ti_core.host_arch()
    return [host_arch]
744
+
745
+
746
# Names exported by ``from <this module> import *``.
# NOTE(review): several public helpers defined in this module
# (``is_arch_supported``, ``adaptive_arch_select``, ``get_host_arch_list``)
# are not listed here; confirm whether that omission is intentional.
__all__ = [
    "i",
    "ij",
    "ijk",
    "ijkl",
    "ijl",
    "ik",
    "ikl",
    "il",
    "j",
    "jk",
    "jkl",
    "jl",
    "k",
    "kl",
    "l",
    "x86_64",
    "x64",
    "arm64",
    "cpu",
    "cuda",
    "amdgpu",
    "gpu",
    "metal",
    "vulkan",
    "extension",
    "loop_config",
    "global_thread_idx",
    "assume_in_range",
    "block_local",
    "cache_read_only",
    "init",
    "mesh_local",
    "no_activate",
    "reset",
    "mesh_patch_idx",
]