gstaichi 2.1.1rc3__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. gstaichi/CHANGELOG.md +4 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +3 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cpython-313-darwin.so +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  11. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  12. gstaichi/_lib/utils.py +243 -0
  13. gstaichi/_logging.py +131 -0
  14. gstaichi/_snode/__init__.py +5 -0
  15. gstaichi/_snode/fields_builder.py +187 -0
  16. gstaichi/_snode/snode_tree.py +34 -0
  17. gstaichi/_test_tools/__init__.py +18 -0
  18. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  19. gstaichi/_test_tools/load_kernel_string.py +30 -0
  20. gstaichi/_test_tools/textwrap2.py +6 -0
  21. gstaichi/_version.py +1 -0
  22. gstaichi/_version_check.py +100 -0
  23. gstaichi/ad/__init__.py +3 -0
  24. gstaichi/ad/_ad.py +530 -0
  25. gstaichi/algorithms/__init__.py +3 -0
  26. gstaichi/algorithms/_algorithms.py +117 -0
  27. gstaichi/assets/.git +1 -0
  28. gstaichi/assets/Go-Regular.ttf +0 -0
  29. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  30. gstaichi/examples/lcg_python.py +26 -0
  31. gstaichi/examples/lcg_taichi.py +34 -0
  32. gstaichi/examples/minimal.py +28 -0
  33. gstaichi/experimental.py +16 -0
  34. gstaichi/lang/__init__.py +50 -0
  35. gstaichi/lang/_dataclass_util.py +31 -0
  36. gstaichi/lang/_fast_caching/__init__.py +3 -0
  37. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  38. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  39. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  40. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  41. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  42. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  43. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  44. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  45. gstaichi/lang/_ndarray.py +352 -0
  46. gstaichi/lang/_ndrange.py +152 -0
  47. gstaichi/lang/_template_mapper.py +195 -0
  48. gstaichi/lang/_texture.py +172 -0
  49. gstaichi/lang/_wrap_inspect.py +215 -0
  50. gstaichi/lang/any_array.py +99 -0
  51. gstaichi/lang/ast/__init__.py +5 -0
  52. gstaichi/lang/ast/ast_transformer.py +1323 -0
  53. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  54. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  55. gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
  56. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  57. gstaichi/lang/ast/checkers.py +106 -0
  58. gstaichi/lang/ast/symbol_resolver.py +57 -0
  59. gstaichi/lang/ast/transform.py +9 -0
  60. gstaichi/lang/common_ops.py +310 -0
  61. gstaichi/lang/exception.py +80 -0
  62. gstaichi/lang/expr.py +180 -0
  63. gstaichi/lang/field.py +428 -0
  64. gstaichi/lang/impl.py +1243 -0
  65. gstaichi/lang/kernel_arguments.py +155 -0
  66. gstaichi/lang/kernel_impl.py +1341 -0
  67. gstaichi/lang/matrix.py +1835 -0
  68. gstaichi/lang/matrix_ops.py +341 -0
  69. gstaichi/lang/matrix_ops_utils.py +190 -0
  70. gstaichi/lang/mesh.py +687 -0
  71. gstaichi/lang/misc.py +782 -0
  72. gstaichi/lang/ops.py +1494 -0
  73. gstaichi/lang/runtime_ops.py +13 -0
  74. gstaichi/lang/shell.py +35 -0
  75. gstaichi/lang/simt/__init__.py +5 -0
  76. gstaichi/lang/simt/block.py +94 -0
  77. gstaichi/lang/simt/grid.py +7 -0
  78. gstaichi/lang/simt/subgroup.py +191 -0
  79. gstaichi/lang/simt/warp.py +96 -0
  80. gstaichi/lang/snode.py +489 -0
  81. gstaichi/lang/source_builder.py +150 -0
  82. gstaichi/lang/struct.py +810 -0
  83. gstaichi/lang/util.py +312 -0
  84. gstaichi/linalg/__init__.py +8 -0
  85. gstaichi/linalg/matrixfree_cg.py +310 -0
  86. gstaichi/linalg/sparse_cg.py +59 -0
  87. gstaichi/linalg/sparse_matrix.py +303 -0
  88. gstaichi/linalg/sparse_solver.py +123 -0
  89. gstaichi/math/__init__.py +11 -0
  90. gstaichi/math/_complex.py +205 -0
  91. gstaichi/math/mathimpl.py +886 -0
  92. gstaichi/profiler/__init__.py +6 -0
  93. gstaichi/profiler/kernel_metrics.py +260 -0
  94. gstaichi/profiler/kernel_profiler.py +586 -0
  95. gstaichi/profiler/memory_profiler.py +15 -0
  96. gstaichi/profiler/scoped_profiler.py +36 -0
  97. gstaichi/sparse/__init__.py +3 -0
  98. gstaichi/sparse/_sparse_grid.py +77 -0
  99. gstaichi/tools/__init__.py +12 -0
  100. gstaichi/tools/diagnose.py +117 -0
  101. gstaichi/tools/np2ply.py +364 -0
  102. gstaichi/tools/vtk.py +38 -0
  103. gstaichi/types/__init__.py +19 -0
  104. gstaichi/types/annotations.py +52 -0
  105. gstaichi/types/compound_types.py +71 -0
  106. gstaichi/types/enums.py +49 -0
  107. gstaichi/types/ndarray_type.py +169 -0
  108. gstaichi/types/primitive_types.py +206 -0
  109. gstaichi/types/quant.py +88 -0
  110. gstaichi/types/texture_type.py +85 -0
  111. gstaichi/types/utils.py +11 -0
  112. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
  113. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
  114. gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
  115. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
  116. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
  117. gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
  118. gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
  119. gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  120. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
  121. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
  122. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  123. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  124. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  125. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  126. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  127. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  128. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  129. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  130. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  131. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  132. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  133. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  134. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  135. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  136. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  137. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  138. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  139. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  140. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  141. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  142. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  143. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  144. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  145. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  146. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  147. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  148. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  149. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  150. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  151. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  152. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  153. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  154. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  155. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  156. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  157. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  158. gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  159. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  160. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  161. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  162. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  163. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  164. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  165. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  166. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  167. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  168. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  169. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  170. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  171. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  172. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  173. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  174. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  175. gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
  176. gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
  177. gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
  178. gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
  179. gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
gstaichi/lang/impl.py ADDED
@@ -0,0 +1,1243 @@
1
+ import numbers
2
+ from types import FunctionType, MethodType
3
+ from typing import Any, Iterable, Sequence
4
+
5
+ import numpy as np
6
+
7
+ from gstaichi._lib import core as _ti_core
8
+ from gstaichi._lib.core.gstaichi_python import (
9
+ DataTypeCxx,
10
+ Function,
11
+ KernelCxx,
12
+ Program,
13
+ )
14
+ from gstaichi._snode.fields_builder import FieldsBuilder
15
+ from gstaichi.lang._ndarray import ScalarNdarray
16
+ from gstaichi.lang._ndrange import GroupedNDRange, _Ndrange
17
+ from gstaichi.lang._texture import RWTextureAccessor
18
+ from gstaichi.lang.any_array import AnyArray
19
+ from gstaichi.lang.exception import (
20
+ GsTaichiCompilationError,
21
+ GsTaichiRuntimeError,
22
+ GsTaichiSyntaxError,
23
+ GsTaichiTypeError,
24
+ )
25
+ from gstaichi.lang.expr import Expr, make_expr_group
26
+ from gstaichi.lang.field import Field, ScalarField
27
+ from gstaichi.lang.kernel_arguments import SparseMatrixProxy
28
+ from gstaichi.lang.kernel_impl import BoundGsTaichiCallable, GsTaichiCallable, Kernel
29
+ from gstaichi.lang.matrix import (
30
+ Matrix,
31
+ MatrixField,
32
+ MatrixNdarray,
33
+ MatrixType,
34
+ Vector,
35
+ VectorNdarray,
36
+ make_matrix,
37
+ )
38
+ from gstaichi.lang.mesh import (
39
+ ConvType,
40
+ MeshElementFieldProxy,
41
+ MeshInstance,
42
+ MeshRelationAccessProxy,
43
+ MeshReorderedMatrixFieldProxy,
44
+ MeshReorderedScalarFieldProxy,
45
+ element_type_name,
46
+ )
47
+ from gstaichi.lang.simt.block import SharedArray
48
+ from gstaichi.lang.snode import SNode
49
+ from gstaichi.lang.struct import Struct, StructField, _IntermediateStruct
50
+ from gstaichi.lang.util import (
51
+ cook_dtype,
52
+ get_traceback,
53
+ gstaichi_scope,
54
+ is_gstaichi_class,
55
+ python_scope,
56
+ warning,
57
+ )
58
+ from gstaichi.types.enums import SNodeGradType
59
+ from gstaichi.types.primitive_types import (
60
+ all_types,
61
+ f16,
62
+ f32,
63
+ f64,
64
+ i32,
65
+ i64,
66
+ u8,
67
+ u32,
68
+ u64,
69
+ )
70
+
71
+
72
@gstaichi_scope
def expr_init_shared_array(shape, element_type):
    """Allocate a shared array in the callable currently being compiled.

    Args:
        shape: the shape forwarded to the AST builder.
        element_type: the element type forwarded to the AST builder.

    Returns:
        Whatever ``expr_alloca_shared_array`` returns, tagged with debug info
        for the current source location.
    """
    runtime = get_runtime()
    builder = runtime.compiling_callable.ast_builder()
    src_info = _ti_core.DebugInfo(runtime.get_current_src_info())
    return builder.expr_alloca_shared_array(shape, element_type, src_info)
77
+
78
+
79
@gstaichi_scope
def expr_init(rhs):
    """Create a new GsTaichi-scope local initialised from ``rhs``.

    - ``None`` allocates a fresh uninitialised local.
    - A ``Matrix`` is rebuilt from ``to_list()`` (keeping ``ndim`` when the
      instance has a ``_DIM`` attribute); a ``Struct`` is rebuilt from
      ``to_dict(include_methods=True, include_ndim=True)``.
    - Lists, tuples and dicts are initialised recursively, item by item.
    - Shared arrays, data types, arches, ndranges, mesh proxies and objects
      carrying a ``_data_oriented`` attribute are returned unchanged.
    - Anything else is wrapped in an ``Expr`` and bound to a new variable.
    """
    ctx = get_runtime().compiling_callable
    if rhs is None:
        builder = ctx.ast_builder()
        return Expr(builder.expr_alloca(_ti_core.DebugInfo(get_runtime().get_current_src_info())))

    if isinstance(rhs, Matrix):
        if hasattr(rhs, "_DIM"):
            return Matrix(*rhs.to_list(), ndim=rhs.ndim)  # type: ignore
        return make_matrix(rhs.to_list())
    if isinstance(rhs, SharedArray):
        return rhs
    if isinstance(rhs, Struct):
        return Struct(rhs.to_dict(include_methods=True, include_ndim=True))

    # Containers: initialise each element independently.
    if isinstance(rhs, list):
        return [expr_init(item) for item in rhs]
    if isinstance(rhs, tuple):
        return tuple(expr_init(item) for item in rhs)
    if isinstance(rhs, dict):
        return {key: expr_init(val) for key, val in rhs.items()}

    # Compile-time objects that must not be turned into variables.
    passthrough_types = (
        _ti_core.DataTypeCxx,
        _ti_core.Arch,
        _Ndrange,
        MeshElementFieldProxy,
        MeshRelationAccessProxy,
    )
    if isinstance(rhs, passthrough_types) or hasattr(rhs, "_data_oriented"):
        return rhs

    builder = ctx.ast_builder()
    return Expr(builder.expr_var(Expr(rhs).ptr, _ti_core.DebugInfo(get_runtime().get_current_src_info())))
117
+
118
+
119
@gstaichi_scope
def expr_init_func(rhs):  # temporary solution to allow passing in fields as arguments
    """Like :func:`expr_init`, but return ``Field`` objects unchanged.

    Lets fields be passed as function arguments without being copied into a
    local variable.
    """
    return rhs if isinstance(rhs, Field) else expr_init(rhs)
124
+
125
+
126
def begin_frontend_struct_for(ast_builder, group, loop_range):
    """Open a struct-for loop over ``loop_range`` on ``ast_builder``.

    Ndarrays and RW texture accessors use the external-tensor entry point;
    fields, SNodes and the root wrapper use the SNode entry point.

    Raises:
        TypeError: ``loop_range`` is not a supported loop target.
        IndexError: the number of loop indices in ``group`` differs from the
            dimensionality of ``loop_range``.
    """
    supported = (AnyArray, Field, SNode, RWTextureAccessor, _Root)
    if not isinstance(loop_range, supported):
        raise TypeError(
            f"Cannot loop over the object {type(loop_range)} in GsTaichi scope. Only GsTaichi fields (via template) or dense arrays (via types.ndarray) are supported."
        )

    num_indices = group.size()
    num_dims = len(loop_range.shape)
    if num_indices != num_dims:
        raise IndexError(
            "Number of struct-for indices does not match loop variable dimensionality "
            f"({num_indices} != {num_dims}). Maybe you wanted to "
            'use "for I in ti.grouped(x)" to group all indices into a single vector I?'
        )

    dbg_info = _ti_core.DebugInfo(get_runtime().get_current_src_info())
    is_external = isinstance(loop_range, (AnyArray, RWTextureAccessor))
    begin_loop = (
        ast_builder.begin_frontend_struct_for_on_external_tensor
        if is_external
        else ast_builder.begin_frontend_struct_for_on_snode
    )
    begin_loop(group, loop_range._loop_range(), dbg_info)
142
+
143
+
144
def begin_frontend_if(ast_builder, cond, stmt_dbg_info):
    """Open an ``if`` statement on ``ast_builder`` guarded by ``cond``.

    Raises:
        ValueError: ``cond`` is a GsTaichi class instance (such as a vector or
            matrix), whose truth value is ambiguous.
    """
    assert ast_builder is not None
    if not is_gstaichi_class(cond):
        ast_builder.begin_frontend_if(Expr(cond).ptr, stmt_dbg_info)
        return

    hint = (
        "The truth value of vectors/matrices is ambiguous.\n"
        "Consider using `any` or `all` when comparing vectors/matrices:\n"
        " if all(x == y):\n"
        "or\n"
        " if any(x != y):\n"
    )
    raise ValueError(hint)
155
+
156
+
157
@gstaichi_scope
def _calc_slice(index, default_stop):
    """Expand a compile-time ``slice`` into the list of indices it selects.

    Uses Python slicing semantics for a dimension of length ``default_stop``
    (via :meth:`slice.indices`). An explicit ``stop=0`` therefore selects
    nothing, a negative step walks backwards, and out-of-range bounds are
    clamped to the dimension.

    Args:
        index (slice): the slice to expand; its bounds must be constants.
        default_stop (int): the extent of the sliced dimension.

    Returns:
        list[int]: the selected indices, in slice order.

    Raises:
        GsTaichiCompilationError: a bound of the slice is a runtime ``Expr``.
        ValueError: the slice step is zero.
    """
    # TODO(mzmzm): support variable in slice
    # Validate the raw bounds before using them in any arithmetic or truth test.
    for bound in (index.start, index.stop, index.step):
        if isinstance(bound, Expr):
            raise GsTaichiCompilationError(
                "GsTaichi does not support variables in slice now, please use constant instead of it."
            )
    # `index.stop or default_stop` would wrongly map an explicit stop of 0 to the
    # full extent; slice.indices resolves None bounds and negative steps correctly.
    return list(range(*index.indices(default_stop)))
170
+
171
+
172
def validate_subscript_index(value, index):
    """Reject negative compile-time indices when subscripting ``value``.

    Fields accept negative indices, and runtime ``Expr`` indices cannot be
    checked here, so both are skipped. Slice bounds and iterables of indices
    are checked recursively.

    Args:
        value: the object being subscripted.
        index: a single index, a slice, or an iterable of indices.

    Raises:
        GsTaichiSyntaxError: a constant integer index is negative.
    """
    if isinstance(value, Field):
        # field supports negative indices
        return

    if isinstance(index, Expr):
        return

    if isinstance(index, (str, bytes)):
        # Not an index we can validate. Recursing would never terminate,
        # because every 1-char string is itself an iterable of 1-char strings.
        return

    if isinstance(index, slice):
        validate_subscript_index(value, index.start)
        validate_subscript_index(value, index.stop)
    elif isinstance(index, Iterable):
        for ind in index:
            validate_subscript_index(value, ind)
    # numbers.Integral also covers NumPy integer scalars such as np.int64(-1),
    # which a plain `int` check would let through.
    elif isinstance(index, numbers.Integral) and index < 0:
        raise GsTaichiSyntaxError("Negative indices are not supported in GsTaichi kernels.")
190
+
191
+
192
@gstaichi_scope
def subscript(ast_builder, value, *_indices, skip_reordered=False):
    """Build the subscript ``value[_indices]`` in GsTaichi scope.

    Values that are not GsTaichi objects are indexed directly in Python via
    ``__getitem__``. Otherwise, ``Matrix`` indices are flattened into their
    components, constant indices are validated, and the access is dispatched
    on the type of ``value``. The supported types are: shared array, mesh
    proxies, sparse matrix proxy, field, ndarray, or tensor-typed ``Expr``.

    Args:
        ast_builder: builder supplied by the caller. NOTE(review): this
            argument is overwritten on entry with the current compiling
            callable's builder, so the passed value is never used. Confirm
            whether that is intended.
        value: the object being indexed.
        *_indices: the indices. ``Matrix`` indices are expanded into their
            components; ``slice`` indices are only accepted on tensor ``Expr``s.
        skip_reordered: when True, mesh reordered-field proxies are indexed
            directly instead of first converting the index with ``ConvType.g2r``.

    Returns:
        An ``Expr``, an ``_IntermediateStruct`` (for struct fields), or
        whatever the type-specific ``subscript``/``__getitem__`` returns.

    Raises:
        GsTaichiSyntaxError: a slice index is used on something other than a
            tensor ``Expr``, or a constant index is negative.
        RuntimeError: the indexed field (or its gradient) has not been placed.
    """
    dbg_info = _ti_core.DebugInfo(get_runtime().get_current_src_info())
    # NOTE(review): replaces the `ast_builder` argument (see docstring).
    ast_builder = get_runtime().compiling_callable.ast_builder()
    # Directly evaluate in Python for non-GsTaichi types
    if not isinstance(
        value,
        (
            Expr,
            Field,
            AnyArray,
            SparseMatrixProxy,
            MeshElementFieldProxy,
            MeshRelationAccessProxy,
            SharedArray,
        ),
    ):
        if len(_indices) == 1:
            _indices = _indices[0]
        return value.__getitem__(_indices)

    has_slice = False

    # Expand Matrix indices into their components; remember whether any slice is present.
    flattened_indices = []
    for _index in _indices:
        if isinstance(_index, Matrix):
            ind = _index.to_list()
        elif isinstance(_index, slice):
            ind = [_index]
            has_slice = True
        else:
            ind = [_index]
        flattened_indices += ind
    indices = tuple(flattened_indices)
    validate_subscript_index(value, indices)

    # A single `None` index means "no indices".
    if len(indices) == 1 and indices[0] is None:
        indices = ()

    # Slices are only supported on tensor Exprs and are expanded further below;
    # every other access needs the indices as an expression group.
    indices_expr_group = None
    if has_slice:
        if not (isinstance(value, Expr) and value.is_tensor()):
            raise GsTaichiSyntaxError(f"The type {type(value)} do not support index of slice type")
    else:
        indices_expr_group = make_expr_group(*indices)

    if isinstance(value, SharedArray):
        return value.subscript(*indices)
    if isinstance(value, MeshElementFieldProxy):
        return value.subscript(*indices)  # type: ignore
    if isinstance(value, MeshRelationAccessProxy):
        return value.subscript(*indices)
    if isinstance(value, (MeshReorderedScalarFieldProxy, MeshReorderedMatrixFieldProxy)) and not skip_reordered:
        # Convert the first index with ConvType.g2r, then re-enter with
        # skip_reordered=True so this branch is not taken again.
        assert len(indices) > 0
        reordered_index = tuple(
            [
                Expr(
                    ast_builder.mesh_index_conversion(
                        value.mesh_ptr, value.element_type, Expr(indices[0]).ptr, ConvType.g2r, dbg_info
                    )
                )
            ]
        )
        return subscript(ast_builder, value, *reordered_index, skip_reordered=True)
    if isinstance(value, SparseMatrixProxy):
        return value.subscript(*indices)
    if isinstance(value, Field):
        # The first member's SNode tells whether the field has been placed.
        _var = value._get_field_members()[0].ptr
        snode = _var.snode()
        if snode is None:
            if _var.is_primal():
                raise RuntimeError(f"{_var.get_expr_name()} has not been placed.")
            else:
                raise RuntimeError(
                    f"Gradient {_var.get_expr_name()} has not been placed, check whether `needs_grad=True`"
                )

        assert indices_expr_group is not None
        if isinstance(value, MatrixField):
            return Expr(ast_builder.expr_subscript(value.ptr, indices_expr_group, dbg_info))
        if isinstance(value, StructField):
            # Subscript every member and bundle the results (plus methods) into a struct.
            entries = {k: subscript(ast_builder, v, *indices) for k, v in value._items}
            entries["__struct_methods"] = value.struct_methods
            return _IntermediateStruct(entries)
        return Expr(ast_builder.expr_subscript(_var, indices_expr_group, dbg_info))
    if isinstance(value, AnyArray):
        assert indices_expr_group is not None
        return Expr(ast_builder.expr_subscript(value.ptr, indices_expr_group, dbg_info))
    assert isinstance(value, Expr)
    # Index into TensorType
    # value: IndexExpression with ret_type = TensorType
    assert value.is_tensor()

    if has_slice:
        # Expand each slice into an explicit index list and gather all selected
        # elements with a single subscript_with_multiple_indices call.
        # Only 1-D and 2-D tensors are handled (see the asserts below).
        shape = value.get_shape()
        dim = len(shape)
        assert dim == len(indices)
        indices = [
            _calc_slice(index, shape[i]) if isinstance(index, slice) else index for i, index in enumerate(indices)
        ]
        if dim == 1:
            assert isinstance(indices[0], list)
            multiple_indices = [make_expr_group(i) for i in indices[0]]
            return_shape = (len(indices[0]),)
        else:
            assert dim == 2
            if isinstance(indices[0], list) and isinstance(indices[1], list):
                multiple_indices = [make_expr_group(i, j) for i in indices[0] for j in indices[1]]
                return_shape = (len(indices[0]), len(indices[1]))
            elif isinstance(indices[0], list):  # indices[1] is not list
                multiple_indices = [make_expr_group(i, indices[1]) for i in indices[0]]
                return_shape = (len(indices[0]),)
            else:  # indices[0] is not list while indices[1] is list
                multiple_indices = [make_expr_group(indices[0], j) for j in indices[1]]
                return_shape = (len(indices[1]),)
        return Expr(
            _ti_core.subscript_with_multiple_indices(
                value.ptr,
                multiple_indices,
                return_shape,
                dbg_info,
            )
        )
    return Expr(ast_builder.expr_subscript(value.ptr, indices_expr_group, dbg_info))
316
+
317
+
318
class SrcInfoGuard:
    """Context manager that keeps ``info`` on top of ``info_stack`` while active.

    ``info`` is appended on entry and the top of the stack is popped on exit.
    The ``with`` target is ``None``, and exceptions raised in the body are
    not suppressed.
    """

    def __init__(self, info_stack, info):
        self.info_stack, self.info = info_stack, info

    def __enter__(self):
        push = self.info_stack.append
        push(self.info)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always pop, whether or not the body raised.
        self.info_stack.pop()
328
+
329
+
330
class PyGsTaichi:
    """Python-side runtime state for GsTaichi.

    Tracks:
    - the native ``Program``, created lazily by :meth:`create_program`;
    - the callable currently being compiled;
    - a stack of source locations used to build debug info;
    - the fields declared since the last :meth:`materialize`;
    - the default float/int/uint types;
    - the kernels registered so far.

    Args:
        kernels (list[Kernel] | None): kernels to carry over, e.g. across
            :func:`reset`. Defaults to an empty list.
    """

    def __init__(self, kernels=None):
        self.materialized = False
        self._prog: Program | None = None
        self.src_info_stack = []
        self.inside_kernel: bool = False
        self._compiling_callable: KernelCxx | Kernel | Function | None = None
        self._current_kernel: "Kernel | None" = None
        # Fields (plus grad/dual companions and matrix fields) declared since the
        # last materialize(); they are checked and then cleared there.
        self.global_vars = []
        self.grad_vars = []
        self.dual_vars = []
        self.matrix_fields = []
        self.default_fp = f32
        self.default_ip = i32
        self.default_up = u32
        self.print_full_traceback: bool = False
        self.target_tape = None
        self.fwd_mode_manager = None
        self.grad_replaced = False
        self.kernels: list[Kernel] = kernels or []
        self._signal_handler_registry = None
        # Unfinalized FieldsBuilder -> traceback captured when it was registered.
        self.unfinalized_fields_builder = {}
        self.src_ll_cache: bool = True

    @property
    def compiling_callable(self) -> KernelCxx | Kernel | Function:
        """The callable currently being compiled.

        Raises:
            GsTaichiRuntimeError: no callable has been set.
        """
        if self._compiling_callable is None:
            raise GsTaichiRuntimeError(
                "_compiling_callable attribute not initialized. Maybe you forgot to call `ti.init()` first?"
            )
        return self._compiling_callable

    @property
    def prog(self) -> Program:
        """The native program.

        Raises:
            GsTaichiRuntimeError: the program has not been created.
        """
        if self._prog is None:
            raise GsTaichiRuntimeError("_prog attribute not initialized. Maybe you forgot to call `ti.init()` first?")
        return self._prog

    @property
    def current_kernel(self) -> Kernel:
        """The kernel currently recorded as active.

        Raises:
            GsTaichiRuntimeError: no kernel has been set.
        """
        if self._current_kernel is None:
            raise GsTaichiRuntimeError(
                "_current_kernel attribute not initialized. Maybe you forgot to call `ti.init()` first?"
            )
        return self._current_kernel

    def initialize_fields_builder(self, builder):
        """Register ``builder`` as unfinalized, remembering where it was created."""
        self.unfinalized_fields_builder[builder] = get_traceback(2)

    def clear_compiled_functions(self):
        """Clear ``materialized_kernels`` of every registered kernel."""
        for k in self.kernels:
            k.materialized_kernels.clear()

    def finalize_fields_builder(self, builder):
        """Unregister ``builder``; raises ``KeyError`` if it was never registered."""
        self.unfinalized_fields_builder.pop(builder)

    def validate_fields_builder(self):
        """Raise if any registered builder other than the root one is still unfinalized.

        Raises:
            GsTaichiRuntimeError: naming the first such builder and its creation traceback.
        """
        for builder, tb in self.unfinalized_fields_builder.items():
            if builder == _root_fb:
                continue

            raise GsTaichiRuntimeError(
                f"Field builder {builder} is not finalized. " f"Please call finalize() on it. Traceback:\n{tb}"
            )

    def get_num_compiled_functions(self):
        """Total number of materialized kernels across all registered kernels."""
        count = 0
        for k in self.kernels:
            count += len(k.materialized_kernels)
        return count

    def src_info_guard(self, info):
        """Return a context manager that pushes ``info`` on the source-info stack while active."""
        return SrcInfoGuard(self.src_info_stack, info)

    def get_current_src_info(self):
        """Return the innermost pushed source info (``IndexError`` if none)."""
        return self.src_info_stack[-1]

    def set_default_fp(self, fp):
        """Set the default float type (f16, f32 or f64) here and in the default compile config."""
        assert fp in [f16, f32, f64]
        self.default_fp = fp
        default_cfg().default_fp = self.default_fp

    def set_default_ip(self, ip):
        """Set the default int type (i32 or i64) and the matching unsigned type (u32 or u64)."""
        assert ip in [i32, i64]
        self.default_ip = ip
        self.default_up = u32 if ip == i32 else u64
        default_cfg().default_ip = self.default_ip
        default_cfg().default_up = self.default_up

    def create_program(self):
        """Create the native ``Program`` if it does not exist yet."""
        if self._prog is None:
            self._prog = _ti_core.Program()

    @staticmethod
    def materialize_root_fb(is_first_call):
        """Finalize the default root FieldsBuilder and install a fresh one.

        Does nothing if the root is already finalized, or if it is empty and
        this is not the first call.

        Args:
            is_first_call (bool): True on the first materialization; forces
                finalization even of an empty root and suppresses the warning.
        """
        global _root_fb
        if root.finalized:
            return
        if not is_first_call and root.empty:
            # We have to forcefully finalize when `is_first_call` is True (even
            # if the root itself is empty), so that there is a valid struct
            # llvm::Module, if no field has been declared before the first kernel
            # invocation. Example case:
            # https://github.com/taichi-dev/gstaichi/blob/27bb1dc3227d9273a79fcb318fdb06fd053068f5/tests/python/test_ad_basics.py#L260-L266
            return

        if get_runtime().prog.config().debug:
            # `root.finalized` is known to be False here (checked above).
            root._allocate_adjoint_checkbit()

        root.finalize(raise_warning=not is_first_call)
        _root_fb = FieldsBuilder()

    @staticmethod
    def _get_tb(_var):
        """Declaration traceback of ``_var``, falling back to ``str(_var.ptr)``."""
        return getattr(_var, "declaration_tb", str(_var.ptr))

    def _check_field_not_placed(self):
        """Raise ``RuntimeError`` listing every declared field that has no SNode."""
        not_placed = []
        for _var in self.global_vars:
            if _var.ptr.snode() is None:
                not_placed.append(self._get_tb(_var))

        if len(not_placed):
            bar = "=" * 44 + "\n"
            raise RuntimeError(
                f"These field(s) are not placed:\n{bar}"
                + f"{bar}".join(not_placed)
                + f"{bar}Please consider specifying a shape for them. E.g.,"
                + "\n\n x = ti.field(float, shape=(2, 3))"
            )

    def _check_gradient_field_not_placed(self, gradient_type):
        """Raise if a field needing ``gradient_type`` companions has an unplaced one.

        Args:
            gradient_type (str): ``"grad"`` or ``"dual"``; any other value checks nothing.

        Raises:
            RuntimeError: listing the declaration tracebacks of the offending fields.
        """
        gradient_vars = []
        if gradient_type == "grad":
            gradient_vars = self.grad_vars
        elif gradient_type == "dual":
            gradient_vars = self.dual_vars
        # De-duplicate, then sort so the report is deterministic.
        not_placed = sorted({self._get_tb(_var) for _var in gradient_vars if _var.ptr.snode() is None})

        if not_placed:
            bar = "=" * 44 + "\n"
            # Every hint line is an f-string so `gradient_type` is actually substituted.
            raise RuntimeError(
                f"These field(s) require `needs_{gradient_type}=True`, however their {gradient_type} field(s) are not placed:\n{bar}"
                + f"{bar}".join(not_placed)
                + f"{bar}Please consider placing the {gradient_type} field(s). E.g.,"
                + f"\n\n ti.root.dense(ti.i, 1).place(x.{gradient_type})"
                + "\n\n Or specify a shape for the field(s). E.g.,"
                + f"\n\n x = ti.field(float, shape=(2, 3), needs_{gradient_type}=True)"
            )

    def _check_matrix_field_member_shape(self):
        """Raise ``RuntimeError`` if any matrix field has members of differing shapes."""
        for _field in self.matrix_fields:
            shapes = [_field.get_scalar_field(i, j).shape for i in range(_field.n) for j in range(_field.m)]
            if any(shape != shapes[0] for shape in shapes):
                raise RuntimeError(
                    "Members of the following field have different shapes "
                    + f"{shapes}:\n{self._get_tb(_field._get_field_members()[0])}"
                )

    def _calc_matrix_field_dynamic_index_stride(self):
        """Let each pending matrix field compute its dynamic-index stride."""
        for _field in self.matrix_fields:
            _field._calc_dynamic_index_stride()

    def materialize(self):
        """Finalize the root builder, validate pending fields, then clear the pending lists."""
        self.materialize_root_fb(not self.materialized)
        self.materialized = True

        self.validate_fields_builder()

        self._check_field_not_placed()
        self._check_gradient_field_not_placed("grad")
        self._check_gradient_field_not_placed("dual")
        self._check_matrix_field_member_shape()
        self._calc_matrix_field_dynamic_index_stride()
        self.global_vars = []
        self.grad_vars = []
        self.dual_vars = []
        self.matrix_fields = []

    def _register_signal_handlers(self):
        """Create the native signal-handler registry once."""
        if self._signal_handler_registry is None:
            self._signal_handler_registry = _ti_core.HackedSignalRegister()

    def clear(self):
        """Finalize and drop the program, drop the signal registry and reset ``materialized``."""
        if self._prog:
            self._prog.finalize()
            self._prog = None
        self._signal_handler_registry = None
        self.materialized = False

    def sync(self):
        """Materialize pending fields, then synchronize the program."""
        self.materialize()
        assert self._prog is not None
        self._prog.synchronize()
528
+
529
+
530
# Module-level runtime state. reset() rebinds this global, so code that imported
# the name directly would keep the previous instance; get_runtime() does not.
pygstaichi = PyGsTaichi(kernels=None)
531
+
532
+
533
def get_runtime() -> PyGsTaichi:
    """Return the current module-level :class:`PyGsTaichi` instance.

    The global is looked up at call time, so this returns the instance
    installed by the most recent :func:`reset`.
    """
    current = pygstaichi
    return current
535
+
536
+
537
def reset():
    """Tear down the current runtime and install a fresh :class:`PyGsTaichi`.

    The steps are:
    1. Clear the current runtime.
    2. Create a new runtime that keeps the previously registered kernels.
    3. Reset each of those kernels.
    4. Restore the native default compile config.
    """
    global pygstaichi
    previous = pygstaichi
    kernels = previous.kernels
    previous.clear()
    pygstaichi = PyGsTaichi(kernels)
    for kernel in kernels:
        kernel.reset()
    _ti_core.reset_default_compile_config()
545
+
546
+
547
@gstaichi_scope
def static_print(*args, __p=print, **kwargs):
    """Print from GsTaichi scope at compile time.

    All arguments are forwarded to ``__p`` (``print`` by default) when the
    call is evaluated during compilation, so it has no runtime overhead.
    """
    printer = __p
    printer(*args, **kwargs)
554
+
555
+
556
# we don't add @gstaichi_scope decorator for @ti.pyfunc to work
def static_assert(cond, msg=None):
    """Throw AssertionError when `cond` is False.

    This function is called at compile time and has no runtime overhead.
    The bool value of `cond` must be determinable at compile time. The check
    uses an explicit ``raise`` rather than an ``assert`` statement, so it still
    fires when Python runs with ``-O``.

    Args:
        cond (bool): an expression with a bool value.
        msg (str): assertion message.

    Raises:
        GsTaichiTypeError: `cond` is a runtime ``Expr`` rather than a static value.
        AssertionError: `cond` is falsy.

    Example::

        >>> year = 2001
        >>> @ti.kernel
        >>> def test():
        >>>     ti.static_assert(year % 4 == 0, "the year must be a leap year")
        AssertionError: the year must be a leap year
    """
    if isinstance(cond, Expr):
        raise GsTaichiTypeError("Static assert with non-static condition")
    if not cond:
        # An `assert` statement would be stripped under `python -O`.
        if msg is not None:
            raise AssertionError(msg)
        raise AssertionError
581
+
582
+
583
def inside_kernel():
    """Return the ``inside_kernel`` flag of the module-level runtime."""
    runtime = pygstaichi
    return runtime.inside_kernel
585
+
586
+
587
def index_nd(dim):
    """Return ``axes(0, 1, ..., dim - 1)``."""
    axis_ids = range(dim)
    return axes(*axis_ids)
589
+
590
+
591
+ class _UninitializedRootFieldsBuilder:
592
+ def __getattr__(self, item):
593
+ if item == "__qualname__":
594
+ # For sphinx docstring extraction.
595
+ return "_UninitializedRootFieldsBuilder"
596
+ raise GsTaichiRuntimeError("Please call init() first")
597
+
598
+
599
# `root` initialization must be delayed until after the program is
# created. Unfortunately, `root` exists in both the gstaichi.lang.impl module
# and the top-level gstaichi module at this point, so if `root` itself were
# rebound, we would have to make sure that `root` in every module is updated
# to the same instance. That is an error-prone process.
#
# To avoid this, we create `root` once at import time and never rebind it.
# The core part, `_root_fb`, is the one whose initialization gets delayed.
# `_root_fb` only exists in the gstaichi.lang.impl module, so rebinding it
# carries a much lower maintenance cost.
#
# `_root_fb` will be overridden inside :func:`gstaichi.lang.init`. Until then it
# is a placeholder whose attribute access raises "Please call init() first".
_root_fb = _UninitializedRootFieldsBuilder()
612
+
613
+
614
def deactivate_all_snodes():
    """Recursively deactivate all SNodes of every finalized root FieldsBuilder."""
    finalized = FieldsBuilder._finalized_roots()
    for fb in finalized:
        fb.deactivate_all()
618
+
619
+
620
+ class _Root:
621
+ """Wrapper around the default root FieldsBuilder instance."""
622
+
623
+ @staticmethod
624
+ def parent(n=1):
625
+ """Same as :func:`gstaichi.SNode.parent`"""
626
+ assert isinstance(_root_fb, FieldsBuilder)
627
+ return _root_fb.root.parent(n)
628
+
629
+ @staticmethod
630
+ def _loop_range():
631
+ """Same as :func:`gstaichi.SNode.loop_range`"""
632
+ assert isinstance(_root_fb, FieldsBuilder)
633
+ return _root_fb.root._loop_range()
634
+
635
+ @staticmethod
636
+ def _get_children():
637
+ """Same as :func:`gstaichi.SNode.get_children`"""
638
+ assert isinstance(_root_fb, FieldsBuilder)
639
+ return _root_fb.root._get_children()
640
+
641
+ # TODO: Record all of the SNodeTrees that finalized under 'ti.root'
642
+ @staticmethod
643
+ def deactivate_all():
644
+ warning("""'ti.root.deactivate_all()' would deactivate all finalized snodes.""")
645
+ deactivate_all_snodes()
646
+
647
+ @property
648
+ def shape(self):
649
+ """Same as :func:`gstaichi.SNode.shape`"""
650
+ assert isinstance(_root_fb, FieldsBuilder)
651
+ return _root_fb.root.shape
652
+
653
+ @property
654
+ def _id(self):
655
+ assert isinstance(_root_fb, FieldsBuilder)
656
+ return _root_fb.root._id
657
+
658
+ def __getattr__(self, item):
659
+ return getattr(_root_fb, item)
660
+
661
+ def __repr__(self):
662
+ return "ti.root"
663
+
664
+
665
# The single shared `ti.root` wrapper. It is never rebound; its methods read
# the module-level `_root_fb`, which is the object `init()` overrides.
root = _Root()
"""Root of the declared GsTaichi :func:`~gstaichi.lang.impl.field`s.

See also https://docs.taichi-lang.org/docs/layout

Example::

    >>> x = ti.field(ti.f32)
    >>> ti.root.pointer(ti.ij, 4).dense(ti.ij, 8).place(x)
"""
675
+
676
+
677
def _create_snode(axis_seq: Sequence[int], shape_seq: Sequence[numbers.Number], same_level: bool):
    """Build dense SNodes under ``root`` for the given axes and extents.

    Args:
        axis_seq: axis indices, in layout order.
        shape_seq: extent of each axis in ``axis_seq``; must have the same length.
        same_level: if true, one ``dense`` node covers all axes at once. Otherwise
            one single-axis ``dense`` node is nested per entry of ``axis_seq``.

    Returns:
        The innermost SNode, ready for ``place``.
    """
    assert len(axis_seq) == len(shape_seq)
    node = root
    if not same_level:
        for idx, axis in enumerate(axis_seq):
            node = node.dense(axes(axis), (shape_seq[idx],))
        return node
    return node.dense(axes(*axis_seq), shape_seq)
687
+
688
+
689
@python_scope
def create_field_member(dtype, name, needs_grad, needs_dual):
    """Create the primal, adjoint and dual field expressions for a scalar field.

    The primal expression is always created and appended to
    ``pygstaichi.global_vars``. For real ``dtype``s an adjoint (``.grad``) and a
    dual (``.dual``) expression are also created and linked to the primal. They
    are appended to ``pygstaichi.grad_vars`` / ``pygstaichi.dual_vars`` only when
    ``needs_grad`` / ``needs_dual`` is set. When the program config has
    ``debug`` enabled, an adjoint checkbit expression is linked as well.

    Args:
        dtype: element type of the field (passed through ``cook_dtype``).
        name (str): name of the primal field. Derived expressions get the
            suffixes ``.grad``, ``.grad_checkbit`` and ``.dual``.
        needs_grad (bool): register the adjoint for reverse-mode autodiff.
        needs_dual (bool): register the dual for forward-mode autodiff.

    Returns:
        tuple: ``(x, x_grad, x_dual)``. ``x_grad`` and ``x_dual`` are ``None``
        when ``dtype`` is not real.

    Raises:
        GsTaichiRuntimeError: if ``needs_grad`` or ``needs_dual`` is requested
            for a non-real ``dtype``.
    """
    dtype = cook_dtype(dtype)

    # primal
    prog = get_runtime().prog

    x = Expr(prog.make_id_expr(""))
    x.declaration_tb = get_traceback(stacklevel=4)
    x.ptr = _ti_core.expr_field(x.ptr, dtype)
    x.ptr.set_name(name)
    x.ptr.set_grad_type(SNodeGradType.PRIMAL)
    pygstaichi.global_vars.append(x)

    x_grad = None
    x_dual = None
    # The x_grad_checkbit is used for global data access rule checker
    x_grad_checkbit = None
    if _ti_core.is_real(dtype):
        # adjoint
        x_grad = Expr(prog.make_id_expr(""))
        x_grad.declaration_tb = get_traceback(stacklevel=4)
        x_grad.ptr = _ti_core.expr_field(x_grad.ptr, dtype)
        x_grad.ptr.set_name(name + ".grad")
        x_grad.ptr.set_grad_type(SNodeGradType.ADJOINT)
        x.ptr.set_adjoint(x_grad.ptr)
        if needs_grad:
            pygstaichi.grad_vars.append(x_grad)

        if prog.config().debug:
            # adjoint checkbit
            # Use a dedicated variable for the checkbit element type: rebinding
            # `dtype` here would make the dual field below be created with the
            # checkbit type instead of the field's real element type.
            x_grad_checkbit = Expr(prog.make_id_expr(""))
            checkbit_dtype = u8
            if prog.config().arch == _ti_core.vulkan:
                checkbit_dtype = i32
            x_grad_checkbit.ptr = _ti_core.expr_field(x_grad_checkbit.ptr, cook_dtype(checkbit_dtype))
            x_grad_checkbit.ptr.set_name(name + ".grad_checkbit")
            x_grad_checkbit.ptr.set_grad_type(SNodeGradType.ADJOINT_CHECKBIT)
            x.ptr.set_adjoint_checkbit(x_grad_checkbit.ptr)

        # dual (same element type as the primal)
        x_dual = Expr(prog.make_id_expr(""))
        x_dual.ptr = _ti_core.expr_field(x_dual.ptr, dtype)
        x_dual.ptr.set_name(name + ".dual")
        x_dual.ptr.set_grad_type(SNodeGradType.DUAL)
        x.ptr.set_dual(x_dual.ptr)
        if needs_dual:
            pygstaichi.dual_vars.append(x_dual)
    elif needs_grad or needs_dual:
        raise GsTaichiRuntimeError(f"{dtype} is not supported for field with `needs_grad=True` or `needs_dual=True`.")

    return x, x_grad, x_dual
741
+
742
+
743
@python_scope
def _field(
    dtype,
    shape=None,
    order=None,
    name="",
    offset=None,
    needs_grad=False,
    needs_dual=False,
):
    """Create a ScalarField (with optional grad/dual fields) and place it under ``root``.

    Args:
        dtype: scalar element type.
        shape (Union[int, tuple[int]], optional): field shape. When ``None`` the
            field is created but not placed, and ``offset``/``order`` must be ``None``.
        order (str, optional): layout as axis letters starting at ``"i"``
            (e.g. ``"ji"``). Each axis then gets its own nested dense SNode.
        name (str, optional): name of the field.
        offset (Union[int, tuple[int]], optional): index offset; must match ``shape`` in length.
        needs_grad (bool, optional): also place the adjoint field.
        needs_dual (bool, optional): also place the dual field.

    Returns:
        ScalarField: the primal field.

    Raises:
        GsTaichiSyntaxError: on inconsistent ``shape`` / ``offset`` / ``order``.
    """
    primal, adjoint, dual = create_field_member(dtype, name, needs_grad, needs_dual)
    field_obj = ScalarField(primal)
    if adjoint:
        adjoint = ScalarField(adjoint)
        field_obj._set_grad(adjoint)
    if dual:
        dual = ScalarField(dual)
        field_obj._set_dual(dual)

    if shape is None:
        if offset is not None:
            raise GsTaichiSyntaxError("shape cannot be None when offset is set")
        if order is not None:
            raise GsTaichiSyntaxError("shape cannot be None when order is set")
        return field_obj

    if isinstance(shape, numbers.Number):
        shape = (shape,)
    if isinstance(offset, numbers.Number):
        offset = (offset,)
    ndim = len(shape)
    if offset is not None and ndim != len(offset):
        raise GsTaichiSyntaxError(
            f"The dimensionality of shape and offset must be the same ({ndim} != {len(offset)})"
        )

    if order is None:
        axis_seq = list(range(ndim))
        shape_seq = list(shape)
    else:
        if ndim != len(order):
            raise GsTaichiSyntaxError(
                f"The dimensionality of shape and order must be the same ({ndim} != {len(order)})"
            )
        if ndim != len(set(order)):
            raise GsTaichiSyntaxError("The axes in order must be different")
        axis_seq = []
        for letter in order:
            axis = ord(letter) - ord("i")
            if not 0 <= axis < ndim:
                raise GsTaichiSyntaxError(f"Invalid axis {letter}")
            axis_seq.append(axis)
        shape_seq = [shape[axis] for axis in axis_seq]

    same_level = order is None
    # Primal first, then adjoint, then dual: each gets its own SNode chain.
    for target, wanted in ((field_obj, True), (adjoint, needs_grad), (dual, needs_dual)):
        if wanted:
            _create_snode(axis_seq, shape_seq, same_level).place(target, offset=offset)
    return field_obj
802
+
803
+
804
@python_scope
def field(dtype, *args, **kwargs):
    """Defines a GsTaichi field.

    A GsTaichi field can be viewed as an abstract N-dimensional array, hiding away
    the complexity of how its underlying :class:`~gstaichi.lang.snode.SNode` are
    actually defined. The data in a GsTaichi field can be directly accessed by
    a GsTaichi :func:`~gstaichi.lang.kernel_impl.kernel`.

    See also https://docs.taichi-lang.org/docs/field

    Args:
        dtype (DataType): data type of the field. Note it can be vector or matrix types as well.
        shape (Union[int, tuple[int]], optional): shape of the field.
        order (str, optional): order of the shape laid out in memory.
        name (str, optional): name of the field.
        offset (Union[int, tuple[int]], optional): offset of the field domain.
        needs_grad (bool, optional): whether this field participates in autodiff (reverse mode)
            and thus needs an adjoint field to store the gradients.
        needs_dual (bool, optional): whether this field participates in autodiff (forward mode)
            and thus needs a dual field to store the dual values.

    Example::

        The code below shows how a GsTaichi field can be declared and defined::

            >>> x1 = ti.field(ti.f32, shape=(16, 8))
            >>> # Equivalently
            >>> x2 = ti.field(ti.f32)
            >>> ti.root.dense(ti.ij, shape=(16, 8)).place(x2)
            >>>
            >>> x3 = ti.field(ti.f32, shape=(16, 8), order='ji')
            >>> # Equivalently
            >>> x4 = ti.field(ti.f32)
            >>> ti.root.dense(ti.j, shape=8).dense(ti.i, shape=16).place(x4)
            >>>
            >>> x5 = ti.field(ti.math.vec3, shape=(16, 8))

    """
    if not isinstance(dtype, MatrixType):
        return _field(dtype, *args, **kwargs)
    # Vector/matrix element types are delegated to the compound field builders.
    if dtype.ndim == 1:
        return Vector.field(dtype.n, dtype.dtype, *args, **kwargs)
    return Matrix.field(dtype.n, dtype.m, dtype.dtype, *args, **kwargs)
848
+
849
+
850
@python_scope
def ndarray(dtype, shape, needs_grad=False):
    """Defines a GsTaichi ndarray with scalar, vector or matrix elements.

    Args:
        dtype (Union[DataType, MatrixType]): Data type of each element. This can be either a scalar type like ti.f32 or a compound type like ti.types.vector(3, ti.i32).
        shape (Union[int, tuple[int]]): Shape of the ndarray. Every extent must be an
            integer (Python or NumPy) in the range ``1 .. 2**31 - 1``.
        needs_grad (bool, optional): also create a gradient ndarray with the same
            dtype and shape and attach it via ``_set_grad``. Only real element types
            are supported.

    Raises:
        GsTaichiRuntimeError: for an invalid shape, an unsupported element type, or
            ``needs_grad=True`` with a non-real element type.

    Example:
        The code below shows how a GsTaichi ndarray with scalar elements can be declared and defined::

            >>> x = ti.ndarray(ti.f32, shape=(16, 8))  # ndarray of shape (16, 8), each element is ti.f32 scalar.
            >>> vec3 = ti.types.vector(3, ti.i32)
            >>> y = ti.ndarray(vec3, shape=(10, 2))  # ndarray of shape (10, 2), each element is a vector of 3 ti.i32 scalars.
            >>> matrix_ty = ti.types.matrix(3, 4, float)
            >>> z = ti.ndarray(matrix_ty, shape=(4, 5))  # ndarray of shape (4, 5), each element is a matrix of (3, 4) ti.float scalars.
    """
    # primal
    if isinstance(shape, numbers.Number):
        shape = (shape,)

    def _is_valid_extent(extent):
        return isinstance(extent, (int, np.integer)) and 0 < extent <= 2**31 - 1

    if not all(_is_valid_extent(extent) for extent in shape):
        raise GsTaichiRuntimeError(f"{shape} is not a valid shape for ndarray")

    if dtype in all_types:
        elem_dt = cook_dtype(dtype)
        arr = ScalarNdarray(elem_dt, shape)
    elif isinstance(dtype, MatrixType):
        if dtype.ndim == 1:
            arr = VectorNdarray(dtype.n, dtype.dtype, shape)
        else:
            arr = MatrixNdarray(dtype.n, dtype.m, dtype.dtype, shape)
        elem_dt = dtype.dtype
    else:
        raise GsTaichiRuntimeError(f"{dtype} is not supported as ndarray element type")

    if not needs_grad:
        return arr

    assert isinstance(elem_dt, DataTypeCxx)
    if not _ti_core.is_real(elem_dt):
        raise GsTaichiRuntimeError(
            f"{elem_dt} is not supported for ndarray with `needs_grad=True` or `needs_dual=True`."
        )
    arr._set_grad(ndarray(dtype, shape, needs_grad=False))
    return arr
892
+
893
+
894
@gstaichi_scope
def ti_format_list_to_content_entries(raw):
    """Flatten print arguments into parallel ``(contents, formats)`` lists.

    ``raw`` is an iterable of print arguments. Accepted items are:

    - strings and Exprs;
    - lists/tuples, expanded recursively and rendered as ``[a, b, ...]``;
    - objects with ``__ti_repr__``;
    - ``ti_format`` results (lists starting with ``"__ti_format__"``);
    - ``["__ti_fmt_value__", expr, fmt]`` triples.

    Adjacent strings are merged, and every non-string entry is converted to its
    core expression pointer.

    Returns:
        tuple[list, list]: ``contents`` (strings or core expression pointers) and
        ``formats`` (a format string or ``None`` per content entry). Both lists
        are empty when ``raw`` yields no printable content, e.g. ``ti_print()``
        called with no arguments and ``end=""``.
    """

    # return a pair of [content, format]
    def entry2content(_var):
        if isinstance(_var, str):
            return [_var, None]
        if isinstance(_var, list):
            assert len(_var) == 2 and (isinstance(_var[1], str) or _var[1] is None)
            _var[0] = Expr(_var[0]).ptr
            return _var
        return [Expr(_var).ptr, None]

    def list_ti_repr(_var):
        yield "["  # distinguishing tuple & list will increase maintenance cost
        for i, v in enumerate(_var):
            if i:
                yield ", "
            yield v
        yield "]"

    def vars2entries(_vars):
        for _var in _vars:
            # If the first element is '__ti_fmt_value__', this list is an Expr and its format.
            if isinstance(_var, list) and len(_var) == 3 and isinstance(_var[0], str) and _var[0] == "__ti_fmt_value__":
                # yield [Expr, format] as a whole and don't pass it to vars2entries() again
                yield _var[1:]
                continue
            if hasattr(_var, "__ti_repr__"):
                res = _var.__ti_repr__()  # type: ignore
            elif isinstance(_var, (list, tuple)):
                # If the first element is '__ti_format__', this list is the result of ti_format.
                if len(_var) > 0 and isinstance(_var[0], str) and _var[0] == "__ti_format__":
                    res = _var[1:]
                else:
                    res = list_ti_repr(_var)
            else:
                yield _var
                continue

            yield from vars2entries(res)

    def fused_string(entries):
        # Merge runs of adjacent strings into a single string entry.
        accumulated = ""
        for entry in entries:
            if isinstance(entry, str):
                accumulated += entry
            else:
                if accumulated:
                    yield accumulated
                    accumulated = ""
                yield entry
        if accumulated:
            yield accumulated

    def extract_formats(entries):
        if not entries:
            # `zip(*[])` produces nothing, so unpacking it into two names would
            # raise ValueError; with nothing to print, return two empty lists.
            return [], []
        contents, formats = zip(*entries)
        return list(contents), list(formats)

    entries = vars2entries(raw)
    entries = fused_string(entries)
    entries = [entry2content(entry) for entry in entries]
    return extract_formats(entries)
957
+
958
+
959
@gstaichi_scope
def ti_print(*_vars, sep=" ", end="\n"):
    """Emit a runtime print statement into the kernel being compiled.

    The arguments are joined with ``sep`` and terminated by ``end``, like the
    builtin ``print``. They are then flattened by
    ``ti_format_list_to_content_entries`` and handed to the AST builder.
    """

    def interleave(items):
        first = True
        for item in items:
            if not first:
                yield sep
            first = False
            yield item
        yield end

    contents, formats = ti_format_list_to_content_entries(interleave(_vars))
    builder = get_runtime().compiling_callable.ast_builder()
    src_info = _ti_core.DebugInfo(get_runtime().get_current_src_info())
    builder.create_print(contents, formats, src_info)
973
+
974
+
975
@gstaichi_scope
def ti_format(*args):
    """Split a format string into literal pieces interleaved with runtime values.

    ``args[0]`` is a ``str.format``-style template and the remaining arguments
    fill its positional ``{}`` slots. Plain values are substituted immediately.
    Exprs and ``["__ti_fmt_value__", expr, fmt]`` triples are kept as separate
    entries.

    Returns:
        list: ``["__ti_format__", piece0, arg0, piece1, arg1, ..., pieceN]``.
    """
    template = args[0]
    fillers = []
    runtime_args = []
    for item in args[1:]:
        # item is a (formatted) Expr: leave a "{}" hole and keep the item aside
        if isinstance(item, Expr) or (isinstance(item, list) and len(item) == 3 and item[0] == "__ti_fmt_value__"):
            fillers.append("{}")
            runtime_args.append(item)
        else:
            fillers.append(item)
    pieces = template.format(*fillers).split("{}")
    assert len(pieces) == len(runtime_args) + 1, "Number of args is different from number of positions provided in string"

    result = ["__ti_format__", pieces[0]]
    for runtime_arg, piece in zip(runtime_args, pieces[1:]):
        result.append(runtime_arg)
        result.append(piece)
    return result
996
+
997
+
998
@gstaichi_scope
def ti_assert(cond, msg, extra_args, dbg_info):
    """Emit a runtime assert statement into the kernel being compiled.

    Mostly a wrapper that converts the Python-side ``Expr`` for ``cond`` into the
    C++ ``_ti_core.Expr`` expected by the AST builder.
    """
    builder = get_runtime().compiling_callable.ast_builder()
    cond_ptr = Expr(cond).ptr
    builder.create_assert_stmt(cond_ptr, msg, extra_args, dbg_info)
1004
+
1005
+
1006
@gstaichi_scope
def ti_int(_var):
    """Convert to int, preferring the object's own ``__ti_int__`` hook when present."""
    if not hasattr(_var, "__ti_int__"):
        return int(_var)
    return _var.__ti_int__()
1011
+
1012
+
1013
@gstaichi_scope
def ti_bool(_var):
    """Convert to bool, preferring the object's own ``__ti_bool__`` hook when present."""
    if not hasattr(_var, "__ti_bool__"):
        return bool(_var)
    return _var.__ti_bool__()
1018
+
1019
+
1020
@gstaichi_scope
def ti_float(_var):
    """Convert to float, preferring the object's own ``__ti_float__`` hook when present."""
    if not hasattr(_var, "__ti_float__"):
        return float(_var)
    return _var.__ti_float__()
1025
+
1026
+
1027
@gstaichi_scope
def zero(x):
    """Returns an array of zeros with the same shape and type as the input. It's also a scalar
    if the input is a scalar.

    Computed as ``x * 0``.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]): The input.

    Returns:
        A new copy of the input but filled with zeros.

    Example::

        >>> x = ti.Vector([1, 1])
        >>> @ti.kernel
        >>> def test():
        >>>     y = ti.zero(x)
        >>>     print(y)
        [0, 0]
    """
    # TODO: get dtype from Expr and Matrix
    zeros = x * 0
    return zeros
1049
+
1050
+
1051
@gstaichi_scope
def one(x):
    """Returns an array of ones with the same shape and type as the input. It's also a scalar
    if the input is a scalar.

    Computed as ``zero(x) + 1``.

    Args:
        x (Union[:mod:`~gstaichi.types.primitive_types`, :class:`~gstaichi.Matrix`]): The input.

    Returns:
        A new copy of the input but filled with ones.

    Example::

        >>> x = ti.Vector([0, 0])
        >>> @ti.kernel
        >>> def test():
        >>>     y = ti.one(x)
        >>>     print(y)
        [1, 1]
    """
    zeros = zero(x)
    return zeros + 1
1072
+
1073
+
1074
def axes(*x: int):
    """Defines a list of axes to be used by a field.

    Args:
        *x: A list of axes to be activated

    Returns:
        list: one ``_ti_core.Axis`` per argument, in argument order.

    Note that GsTaichi has already provided a set of commonly used axes. For example,
    `ti.ij` is just `axes(0, 1)` under the hood.
    """
    return list(map(_ti_core.Axis, x))
1084
+
1085
+
1086
# Alias for the core axis type that :func:`axes` instantiates.
Axis = _ti_core.Axis
1087
+
1088
+
1089
def static(x, *xs) -> Any:
    """Evaluates a GsTaichi-scope expression at compile time.

    `static()` is what enables the so-called metaprogramming in GsTaichi. It is
    in many ways similar to ``constexpr`` in C++.

    See also https://docs.taichi-lang.org/docs/meta.

    Args:
        x (Any): an expression to be evaluated
        *xs (Any): for Python-ish swapping assignment

    Returns:
        ``x`` unchanged if it is a compile-time value (None, a Python/NumPy scalar,
        range/list/tuple/enumerate/zip/filter/map, an ndrange, an array, a field,
        or a function/callable). With extra ``xs``, returns a list of each argument
        passed through ``static``.

    Raises:
        ValueError: for any other input type.

    Example:
        The most common usage of `static()` is for compile-time evaluation::

            >>> cond = False
            >>>
            >>> @ti.kernel
            >>> def run():
            >>>     if ti.static(cond):
            >>>         do_a()
            >>>     else:
            >>>         do_b()

        Depending on the value of ``cond``, ``run()`` will be directly compiled
        into either ``do_a()`` or ``do_b()``. Thus there won't be a runtime
        condition check.

        Another common usage is for compile-time loop unrolling::

            >>> @ti.kernel
            >>> def run():
            >>>     for i in ti.static(range(3)):
            >>>         print(i)
            >>>
            >>> # The above will be unrolled to:
            >>> @ti.kernel
            >>> def run():
            >>>     print(0)
            >>>     print(1)
            >>>     print(2)
    """
    if xs:  # for python-ish pointer assign: x, y = ti.static(y, x)
        return [static(item) for item in (x, *xs)]

    if x is None:
        return x

    compile_time_types = (
        bool,
        int,
        float,
        range,
        list,
        tuple,
        enumerate,
        GroupedNDRange,
        _Ndrange,
        zip,
        filter,
        map,
        np.bool_,
        np.integer,
        np.floating,
        AnyArray,
        Field,
        FunctionType,
        MethodType,
        BoundGsTaichiCallable,
        GsTaichiCallable,
    )
    if isinstance(x, compile_time_types):
        return x
    raise ValueError(f"Input to ti.static must be compile-time constants or global pointers, instead of {type(x)}")
1165
+
1166
+
1167
@gstaichi_scope
def grouped(x):
    """Groups the indices in the iterator returned by `ndrange()` into a 1-D vector.

    This is often used when you want to iterate over all indices returned by `ndrange()`
    in one `for` loop and a single index. Inputs that are not ndrange objects are
    returned unchanged.

    Args:
        x (:func:`~gstaichi.ndrange`): an iterator object returned by `ti.ndrange`.

    Example::
        >>> # without ti.grouped
        >>> for I in ti.ndrange(2, 3):
        >>>     print(I)
        prints 0, 1, 2, 3, 4, 5

        >>> # with ti.grouped
        >>> for I in ti.grouped(ti.ndrange(2, 3)):
        >>>     print(I)
        prints [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]
    """
    return x.grouped() if isinstance(x, _Ndrange) else x
1191
+
1192
+
1193
def stop_grad(x):
    """Stops computing gradients during back propagation.

    Must be called while a callable is being compiled; the request is recorded
    on that callable's AST builder for ``x.snode``.

    Args:
        x (:class:`~gstaichi.Field`): A field.
    """
    active = get_runtime().compiling_callable
    assert active is not None
    builder = active.ast_builder()
    builder.stop_grad(x.snode.ptr)
1202
+
1203
+
1204
def current_cfg():
    """Return the compile config of the active runtime's program."""
    prog = get_runtime().prog
    return prog.config()
1206
+
1207
+
1208
def default_cfg():
    """Return the core's default compile config."""
    cfg = _ti_core.default_compile_config()
    return cfg
1210
+
1211
+
1212
def call_internal(name, *args, with_runtime_context=True):
    """Insert a call to the internal op ``_ti_core.InternalOp.<name>`` and bind its result.

    Args:
        name (str): attribute name of the op on ``_ti_core.InternalOp``.
        *args: operands, packed with ``make_expr_group``.
        with_runtime_context (bool): accepted for compatibility; not used here.

    Returns:
        The result of ``expr_init`` on the inserted call expression.
    """
    op = getattr(_ti_core.InternalOp, name)
    operands = make_expr_group(args)
    call_expr = _ti_core.insert_internal_func_call(op, operands)
    return expr_init(call_expr)
1214
+
1215
+
1216
def get_cuda_compute_capability():
    """Return the core's ``cuda_compute_capability`` value (queried as an int64)."""
    capability = _ti_core.query_int64("cuda_compute_capability")
    return capability
1218
+
1219
+
1220
@gstaichi_scope
def mesh_relation_access(mesh, from_index, to_element_type):
    """Resolve a mesh relation access such as ``v.edges``.

    If ``from_index`` is itself a MeshInstance (to support ``ti.mesh_local`` and
    mesh attributes accessed as fields), the attribute named after
    ``to_element_type`` is returned. Otherwise ``mesh`` must be a MeshInstance,
    and a ``MeshRelationAccessProxy`` is returned.

    Raises:
        RuntimeError: if neither argument is a MeshInstance.
    """
    if isinstance(from_index, MeshInstance):
        attr_name = element_type_name(to_element_type)
        return getattr(from_index, attr_name)
    if not isinstance(mesh, MeshInstance):
        raise RuntimeError("Relation access should be with a mesh instance!")
    return MeshRelationAccessProxy(mesh, from_index, to_element_type)
1228
+
1229
+
1230
# Public names exported by `from ... import *` on this module.
__all__ = [
    "axes",
    "deactivate_all_snodes",
    "field",
    "grouped",
    "ndarray",
    "one",
    "root",
    "static",
    "static_assert",
    "static_print",
    "stop_grad",
    "zero",
]