gstaichi 2.1.1rc3__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. gstaichi/CHANGELOG.md +4 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +3 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cpython-311-darwin.so +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  11. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  12. gstaichi/_lib/utils.py +243 -0
  13. gstaichi/_logging.py +131 -0
  14. gstaichi/_snode/__init__.py +5 -0
  15. gstaichi/_snode/fields_builder.py +187 -0
  16. gstaichi/_snode/snode_tree.py +34 -0
  17. gstaichi/_test_tools/__init__.py +18 -0
  18. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  19. gstaichi/_test_tools/load_kernel_string.py +30 -0
  20. gstaichi/_test_tools/textwrap2.py +6 -0
  21. gstaichi/_version.py +1 -0
  22. gstaichi/_version_check.py +100 -0
  23. gstaichi/ad/__init__.py +3 -0
  24. gstaichi/ad/_ad.py +530 -0
  25. gstaichi/algorithms/__init__.py +3 -0
  26. gstaichi/algorithms/_algorithms.py +117 -0
  27. gstaichi/assets/.git +1 -0
  28. gstaichi/assets/Go-Regular.ttf +0 -0
  29. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  30. gstaichi/examples/lcg_python.py +26 -0
  31. gstaichi/examples/lcg_taichi.py +34 -0
  32. gstaichi/examples/minimal.py +28 -0
  33. gstaichi/experimental.py +16 -0
  34. gstaichi/lang/__init__.py +50 -0
  35. gstaichi/lang/_dataclass_util.py +31 -0
  36. gstaichi/lang/_fast_caching/__init__.py +3 -0
  37. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  38. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  39. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  40. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  41. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  42. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  43. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  44. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  45. gstaichi/lang/_ndarray.py +352 -0
  46. gstaichi/lang/_ndrange.py +152 -0
  47. gstaichi/lang/_template_mapper.py +195 -0
  48. gstaichi/lang/_texture.py +172 -0
  49. gstaichi/lang/_wrap_inspect.py +215 -0
  50. gstaichi/lang/any_array.py +99 -0
  51. gstaichi/lang/ast/__init__.py +5 -0
  52. gstaichi/lang/ast/ast_transformer.py +1323 -0
  53. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  54. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  55. gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
  56. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  57. gstaichi/lang/ast/checkers.py +106 -0
  58. gstaichi/lang/ast/symbol_resolver.py +57 -0
  59. gstaichi/lang/ast/transform.py +9 -0
  60. gstaichi/lang/common_ops.py +310 -0
  61. gstaichi/lang/exception.py +80 -0
  62. gstaichi/lang/expr.py +180 -0
  63. gstaichi/lang/field.py +428 -0
  64. gstaichi/lang/impl.py +1243 -0
  65. gstaichi/lang/kernel_arguments.py +155 -0
  66. gstaichi/lang/kernel_impl.py +1341 -0
  67. gstaichi/lang/matrix.py +1835 -0
  68. gstaichi/lang/matrix_ops.py +341 -0
  69. gstaichi/lang/matrix_ops_utils.py +190 -0
  70. gstaichi/lang/mesh.py +687 -0
  71. gstaichi/lang/misc.py +782 -0
  72. gstaichi/lang/ops.py +1494 -0
  73. gstaichi/lang/runtime_ops.py +13 -0
  74. gstaichi/lang/shell.py +35 -0
  75. gstaichi/lang/simt/__init__.py +5 -0
  76. gstaichi/lang/simt/block.py +94 -0
  77. gstaichi/lang/simt/grid.py +7 -0
  78. gstaichi/lang/simt/subgroup.py +191 -0
  79. gstaichi/lang/simt/warp.py +96 -0
  80. gstaichi/lang/snode.py +489 -0
  81. gstaichi/lang/source_builder.py +150 -0
  82. gstaichi/lang/struct.py +810 -0
  83. gstaichi/lang/util.py +312 -0
  84. gstaichi/linalg/__init__.py +8 -0
  85. gstaichi/linalg/matrixfree_cg.py +310 -0
  86. gstaichi/linalg/sparse_cg.py +59 -0
  87. gstaichi/linalg/sparse_matrix.py +303 -0
  88. gstaichi/linalg/sparse_solver.py +123 -0
  89. gstaichi/math/__init__.py +11 -0
  90. gstaichi/math/_complex.py +205 -0
  91. gstaichi/math/mathimpl.py +886 -0
  92. gstaichi/profiler/__init__.py +6 -0
  93. gstaichi/profiler/kernel_metrics.py +260 -0
  94. gstaichi/profiler/kernel_profiler.py +586 -0
  95. gstaichi/profiler/memory_profiler.py +15 -0
  96. gstaichi/profiler/scoped_profiler.py +36 -0
  97. gstaichi/sparse/__init__.py +3 -0
  98. gstaichi/sparse/_sparse_grid.py +77 -0
  99. gstaichi/tools/__init__.py +12 -0
  100. gstaichi/tools/diagnose.py +117 -0
  101. gstaichi/tools/np2ply.py +364 -0
  102. gstaichi/tools/vtk.py +38 -0
  103. gstaichi/types/__init__.py +19 -0
  104. gstaichi/types/annotations.py +52 -0
  105. gstaichi/types/compound_types.py +71 -0
  106. gstaichi/types/enums.py +49 -0
  107. gstaichi/types/ndarray_type.py +169 -0
  108. gstaichi/types/primitive_types.py +206 -0
  109. gstaichi/types/quant.py +88 -0
  110. gstaichi/types/texture_type.py +85 -0
  111. gstaichi/types/utils.py +11 -0
  112. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
  113. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
  114. gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
  115. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
  116. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
  117. gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
  118. gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
  119. gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  120. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
  121. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
  122. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  123. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  124. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  125. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  126. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  127. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  128. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  129. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  130. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  131. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  132. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  133. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  134. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  135. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  136. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  137. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  138. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  139. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  140. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  141. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  142. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  143. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  144. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  145. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  146. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  147. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  148. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  149. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  150. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  151. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  152. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  153. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  154. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  155. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  156. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  157. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  158. gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  159. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  160. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  161. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  162. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  163. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  164. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  165. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  166. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  167. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  168. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  169. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  170. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  171. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  172. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  173. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  174. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  175. gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
  176. gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
  177. gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
  178. gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
  179. gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1323 @@
1
+ # type: ignore
2
+
3
+ import ast
4
+ import collections.abc
5
+ import dataclasses
6
+ import itertools
7
+ import warnings
8
+ from ast import unparse
9
+ from typing import Any, Sequence, Type
10
+
11
+ import numpy as np
12
+
13
+ from gstaichi._lib import core as _ti_core
14
+ from gstaichi.lang import expr, impl, matrix, mesh
15
+ from gstaichi.lang import ops as ti_ops
16
+ from gstaichi.lang._ndrange import _Ndrange
17
+ from gstaichi.lang.ast.ast_transformer_utils import (
18
+ ASTTransformerContext,
19
+ Builder,
20
+ LoopStatus,
21
+ ReturnStatus,
22
+ get_decorator,
23
+ )
24
+ from gstaichi.lang.ast.ast_transformers.call_transformer import CallTransformer
25
+ from gstaichi.lang.ast.ast_transformers.function_def_transformer import (
26
+ FunctionDefTransformer,
27
+ )
28
+ from gstaichi.lang.exception import (
29
+ GsTaichiIndexError,
30
+ GsTaichiRuntimeTypeError,
31
+ GsTaichiSyntaxError,
32
+ GsTaichiTypeError,
33
+ handle_exception_from_cpp,
34
+ )
35
+ from gstaichi.lang.expr import Expr, make_expr_group
36
+ from gstaichi.lang.field import Field
37
+ from gstaichi.lang.matrix import Matrix, MatrixType
38
+ from gstaichi.lang.snode import append, deactivate, length
39
+ from gstaichi.lang.struct import Struct, StructType
40
+ from gstaichi.types import primitive_types
41
+ from gstaichi.types.utils import is_integral
42
+
43
+
44
def reshape_list(flat_list: list[Any], target_shape: Sequence[int]) -> list[Any]:
    """Fold a flat list into nested lists following ``target_shape`` (row-major).

    The outermost dimension is implied by the list length; the list is chunked
    by every other dimension, innermost first. A shape with fewer than two
    dimensions leaves the input untouched and returns the very same object.

    Args:
        flat_list: Elements in row-major order.
        target_shape: The desired nested shape.

    Returns:
        A nested list whose nesting depth equals ``len(target_shape)`` (or the
        original ``flat_list`` when that length is below 2).
    """
    nested = flat_list
    # Chunk by the trailing dimensions, innermost first; the leading dimension
    # is simply whatever number of chunks remains at the end.
    for chunk_size in reversed(target_shape[1:]):
        chunks: list[Any] = []
        for position, item in enumerate(nested):
            if position % chunk_size == 0:
                chunks.append([])
            chunks[-1].append(item)
        nested = chunks
    return nested
56
+
57
+
58
def boundary_type_cast_warning(expression: Expr) -> None:
    """Warn when a range-for boundary's dtype may not survive a cast to i32.

    Non-integral dtypes, as well as ``i64``, ``u64`` and ``u32``, trigger a
    ``Warning``; any other integral dtype is accepted silently.

    Args:
        expression: The boundary expression whose rvalue type is inspected.
    """
    expr_dtype = expression.ptr.get_rvalue_type()
    risky_integral_types = [
        primitive_types.i64,
        primitive_types.u64,
        primitive_types.u32,
    ]
    # Integral types other than the wide/unsigned ones above cast losslessly.
    if is_integral(expr_dtype) and expr_dtype not in risky_integral_types:
        return
    warnings.warn(
        f"Casting range_for boundary values from {expr_dtype} to i32, which may cause numerical issues",
        Warning,
    )
69
+
70
+
71
+ class ASTTransformer(Builder):
72
@staticmethod
def build_Name(ctx: ASTTransformerContext, node: ast.Name):
    """Resolve an identifier to the object currently bound to it.

    The resolved object is stored on ``node.ptr`` and returned. When it is an
    ``Expr``, it is additionally tagged with debug info for this node's source
    position, both on the Python wrapper and on the underlying C++ expression.
    """
    resolved = ctx.get_var_by_name(node.id)
    node.ptr = resolved
    is_located_node = isinstance(node, (ast.stmt, ast.expr))
    if not (is_located_node and isinstance(resolved, Expr)):
        return resolved
    resolved.dbg_info = _ti_core.DebugInfo(ctx.get_pos_info(node))
    resolved.ptr.set_dbg_info(resolved.dbg_info)
    return resolved
79
+
80
@staticmethod
def build_AnnAssign(ctx: ASTTransformerContext, node: ast.AnnAssign):
    """Build an annotated assignment ``target: annotation = value``.

    The value and the annotation are built first. The work is then delegated
    to ``ASTTransformer.build_assign_annotated``, which casts the value to the
    annotated type for new locals and type-checks existing ones.

    Returns:
        The assigned variable, also stored on ``node.ptr``.

    Raises:
        GsTaichiSyntaxError: if the statement has no value (``x: ti.i32``),
            because the value is required to initialize the variable.
    """
    if node.value is None:
        # Without this guard, ``node.value.ptr`` below fails with an opaque
        # AttributeError on ``None``.
        raise GsTaichiSyntaxError(
            "Annotated assignment requires a value, e.g. 'x: ti.i32 = 0'"
        )
    build_stmt(ctx, node.value)
    build_stmt(ctx, node.annotation)

    is_static_assign = isinstance(node.value, ast.Call) and node.value.func.ptr is impl.static

    node.ptr = ASTTransformer.build_assign_annotated(
        ctx, node.target, node.value.ptr, is_static_assign, node.annotation.ptr
    )
    return node.ptr
91
+
92
@staticmethod
def build_assign_annotated(
    ctx: ASTTransformerContext, target: ast.Name, value, is_static_assign: bool, annotation: Type
):
    """Build an annotated assignment like this: target: annotation = value.

    A previously undeclared local is created from ``value`` cast to the
    annotated type. Any other target is built and assigned in place, after
    checking that its rvalue type equals the annotation.

    Args:
        ctx (ast_builder_utils.BuilderContext): The builder context.
        target (ast.Name): The target; for a local, ``target.id`` holds the
            name as a string.
        value: The already-built value.
        is_static_assign: Whether the value came from ``ti.static(...)``.
        annotation: The type the target should have.

    Returns:
        The variable that received the value.

    Raises:
        GsTaichiSyntaxError: when assigning to a kernel argument, for a static
            annotated assignment, or on a type mismatch with an existing
            variable.
    """
    is_local = isinstance(target, ast.Name)
    if is_local and target.id in ctx.kernel_args:
        raise GsTaichiSyntaxError(
            f'Kernel argument "{target.id}" is immutable in the kernel. '
            f"If you want to change its value, please create a new variable."
        )
    # Initialized before the static check, exactly as the builder always has.
    annotation_expr = impl.expr_init(annotation)
    if is_static_assign:
        raise GsTaichiSyntaxError("Static assign cannot be used on annotated assignment")

    declares_new_local = is_local and not ctx.is_var_declared(target.id)
    if declares_new_local:
        fresh = impl.expr_init(ti_ops.cast(value, annotation_expr))
        ctx.create_variable(target.id, fresh)
        return fresh

    existing = build_stmt(ctx, target)
    if existing.ptr.get_rvalue_type() != annotation_expr:
        raise GsTaichiSyntaxError("Static assign cannot have type overloading")
    existing._assign(value)
    return existing
125
+
126
@staticmethod
def build_Assign(ctx: ASTTransformerContext, node: ast.Assign) -> None:
    """Build a (possibly chained) assignment ``t1 = t2 = ... = value``.

    The right-hand side is built once. Unless it is a ``ti.static(...)``
    call, it is materialized via ``impl.expr_init`` so that every target of a
    chained assignment receives the same value
    (ref https://github.com/taichi-dev/gstaichi/issues/2659).
    """
    build_stmt(ctx, node.value)
    rhs = node.value.ptr
    is_static_assign = isinstance(node.value, ast.Call) and node.value.func.ptr is impl.static
    if not is_static_assign:
        rhs = impl.expr_init(rhs)

    for node_target in node.targets:
        ASTTransformer.build_assign_unpack(ctx, node_target, rhs, is_static_assign)
    return None
139
+
140
@staticmethod
def build_assign_unpack(ctx: ASTTransformerContext, node_target: list | ast.Tuple, values, is_static_assign: bool):
    """Build an assignment whose target may be a tuple, e.g. ``(a, b) = (x, y)``.

    A non-tuple target is forwarded unchanged to ``build_assign_basic``. For a
    tuple target, the right-hand side is first turned into a sequence:

    * a ``Matrix`` with a single column is replaced by its ``entries``;
    * a tensor ``Expr`` (at most one dimension) and a struct ``Expr`` are
      expanded into their component expressions (a single component is
      unwrapped).

    The sequence is then assigned element-wise to the tuple's targets.

    Args:
        ctx (ast_builder_utils.BuilderContext): The builder context.
        node_target: The assignment target; ``node_target.elts`` holds the
            element targets when it is an ``ast.Tuple``.
        values: A node/list representing the values.
        is_static_assign: A boolean value indicating whether this is a static assignment

    Raises:
        ValueError: if a matrix with more than one column is unpacked.
        GsTaichiSyntaxError: if the value is not unpackable or its length does
            not match the number of targets.
    """
    if not isinstance(node_target, ast.Tuple):
        return ASTTransformer.build_assign_basic(ctx, node_target, values, is_static_assign)

    def expand(compound):
        # Split a tensor/struct expression into components; unwrap singletons.
        components = ctx.ast_builder.expand_exprs([compound.ptr])
        return components[0] if len(components) == 1 else components

    if isinstance(values, matrix.Matrix):
        if values.m != 1:
            raise ValueError("Matrices with more than one columns cannot be unpacked")
        values = values.entries

    # Unpack: a, b, c = ti.Vector([1., 2., 3.])
    if isinstance(values, impl.Expr) and values.ptr.is_tensor():
        if len(values.get_shape()) > 1:
            raise ValueError("Matrices with more than one columns cannot be unpacked")
        values = expand(values)

    if isinstance(values, impl.Expr) and values.ptr.is_struct():
        values = expand(values)

    if not isinstance(values, collections.abc.Sequence):
        raise GsTaichiSyntaxError(f"Cannot unpack type: {type(values)}")

    targets = node_target.elts
    if len(values) != len(targets):
        raise GsTaichiSyntaxError("The number of targets is not equal to value length")

    for target, element in zip(targets, values):
        ASTTransformer.build_assign_basic(ctx, target, element, is_static_assign)
    return None
185
+
186
@staticmethod
def build_assign_basic(ctx: ASTTransformerContext, target: ast.Name, value, is_static_assign: bool):
    """Build basic assignment like this: target = value.

    Args:
        ctx (ast_builder_utils.BuilderContext): The builder context.
        target (ast.Name): The target. For a local, ``target.id`` holds the
            name as a string; any other target (e.g. a subscript) is built and
            assigned in place via its ``_assign`` method.
        value: A node representing the value.
        is_static_assign: Whether the value comes from ``ti.static(...)``. Such
            values are bound to the name directly, without ``impl.expr_init``.

    Returns:
        The variable now holding the value.

    Raises:
        GsTaichiSyntaxError: when assigning to a kernel argument, when a static
            value is assigned to a non-name target, or when the target has no
            ``_assign`` method (i.e. is not a GsTaichi object).
    """
    is_local = isinstance(target, ast.Name)
    if is_local and target.id in ctx.kernel_args:
        raise GsTaichiSyntaxError(
            f'Kernel argument "{target.id}" is immutable in the kernel. '
            f"If you want to change its value, please create a new variable."
        )
    if is_static_assign:
        if not is_local:
            raise GsTaichiSyntaxError("Static assign cannot be used on elements in arrays")
        ctx.create_variable(target.id, value)
        var = value
    elif is_local and not ctx.is_var_declared(target.id):
        var = impl.expr_init(value)
        ctx.create_variable(target.id, var)
    else:
        var = build_stmt(ctx, target)
        try:
            var._assign(value)
        except AttributeError as err:
            # Chain explicitly so the traceback shows the original failure as
            # the direct cause rather than as an unrelated exception.
            raise GsTaichiSyntaxError(
                f"Variable '{unparse(target).strip()}' cannot be assigned. Maybe it is not a GsTaichi object?"
            ) from err
    return var
220
+
221
@staticmethod
def build_NamedExpr(ctx: ASTTransformerContext, node: ast.NamedExpr):
    """Build a walrus expression ``(target := value)``.

    The value is built and then bound through
    ``ASTTransformer.build_assign_basic``. The resulting variable is the
    expression's value and is stored on ``node.ptr``.
    """
    build_stmt(ctx, node.value)
    value_node = node.value
    static_value = isinstance(value_node, ast.Call) and value_node.func.ptr is impl.static
    bound = ASTTransformer.build_assign_basic(ctx, node.target, value_node.ptr, static_value)
    node.ptr = bound
    return bound
227
+
228
+ @staticmethod
229
+ def is_tuple(node):
230
+ if isinstance(node, ast.Tuple):
231
+ return True
232
+ if isinstance(node, ast.Index) and isinstance(node.value.ptr, tuple):
233
+ return True
234
+ if isinstance(node.ptr, tuple):
235
+ return True
236
+ return False
237
+
238
@staticmethod
def build_Subscript(ctx: ASTTransformerContext, node: ast.Subscript):
    """Build ``value[index]`` through ``impl.subscript``.

    A single (non-tuple) index is wrapped in a one-element list so the indices
    can always be splatted. Note that this rewrites ``node.slice.ptr`` in place.
    """
    for child in (node.value, node.slice):
        build_stmt(ctx, child)
    index_node = node.slice
    if not ASTTransformer.is_tuple(index_node):
        index_node.ptr = [index_node.ptr]
    node.ptr = impl.subscript(ctx.ast_builder, node.value.ptr, *index_node.ptr)
    return node.ptr
246
+
247
@staticmethod
def build_Slice(ctx: ASTTransformerContext, node: ast.Slice):
    """Build ``lower:upper:step`` into a Python ``slice`` of built values.

    Absent components stay ``None``. The result is stored on ``node.ptr``.
    """
    bounds = []
    for part in (node.lower, node.upper, node.step):
        if part is None:
            bounds.append(None)
            continue
        build_stmt(ctx, part)
        bounds.append(part.ptr)
    node.ptr = slice(*bounds)
    return node.ptr
262
+
263
@staticmethod
def build_ExtSlice(ctx: ASTTransformerContext, node: ast.ExtSlice):
    """Build a legacy extended slice (``ast.ExtSlice``) into a tuple of its built dims."""
    build_stmts(ctx, node.dims)
    built_dims = tuple(dim_node.ptr for dim_node in node.dims)
    node.ptr = built_dims
    return built_dims
268
+
269
@staticmethod
def build_Tuple(ctx: ASTTransformerContext, node: ast.Tuple):
    """Build a tuple display into a Python tuple of the built elements."""
    build_stmts(ctx, node.elts)
    built = tuple(element.ptr for element in node.elts)
    node.ptr = built
    return built
274
+
275
@staticmethod
def build_List(ctx: ASTTransformerContext, node: ast.List):
    """Build a list display into a Python list of the built elements."""
    build_stmts(ctx, node.elts)
    built = [element.ptr for element in node.elts]
    node.ptr = built
    return built
280
+
281
@staticmethod
def build_Dict(ctx: ASTTransformerContext, node: ast.Dict):
    """Build a dict display ``{k: v, **other}`` into a Python dict.

    Entries are processed left to right. For a ``key: value`` entry, the key is
    built before the value, matching Python's evaluation order for dict
    displays. A ``None`` key marks a ``**mapping`` entry, whose built value is
    merged into the result.

    Returns:
        The resulting dict, also stored on ``node.ptr``.
    """
    dic = {}
    for key, value in zip(node.keys, node.values):
        if key is None:
            dic.update(build_stmt(ctx, value))
        else:
            # Build the key explicitly first: ``dic[build(k)] = build(v)``
            # would evaluate the right-hand side (the value) before the key.
            built_key = build_stmt(ctx, key)
            dic[built_key] = build_stmt(ctx, value)
    node.ptr = dic
    return node.ptr
291
+
292
@staticmethod
def process_listcomp(ctx: ASTTransformerContext, node, result) -> None:
    """Build one list-comprehension element (``node.elt``) and append it to ``result``."""
    element = build_stmt(ctx, node.elt)
    result.append(element)
295
+
296
@staticmethod
def process_dictcomp(ctx: ASTTransformerContext, node, result) -> None:
    """Build one dict-comprehension entry (key first, then value) into ``result``."""
    entry_key = build_stmt(ctx, node.key)
    entry_value = build_stmt(ctx, node.value)
    result[entry_key] = entry_value
301
+
302
@staticmethod
def process_generators(ctx: ASTTransformerContext, node: ast.expr, now_comp, func, result):
    """Expand comprehension clause ``node.generators[now_comp]`` at compile time.

    The clause's iterable is built in static scope and iterated in Python. A
    tensor ``Expr`` is first expanded into a nested list of scalar ``Expr``s
    with the tensor's shape. For every element, a fresh variable scope is
    opened and the clause target is bound as a static assignment. The clause's
    ``if`` conditions are then built in static scope, and control passes to
    ``ASTTransformer.process_ifs``. Once every clause has been consumed,
    ``func(ctx, node, result)`` builds one element into ``result``.

    Args:
        ctx: The builder context.
        node: The comprehension node (e.g. ``ast.ListComp`` / ``ast.DictComp``).
        now_comp: Index of the clause in ``node.generators`` to expand.
        func: Callback that builds a single element into ``result``.
        result: The container being filled.
    """
    clauses = node.generators
    if now_comp >= len(clauses):
        return func(ctx, node, result)
    clause = clauses[now_comp]

    with ctx.static_scope_guard():
        iterable = build_stmt(ctx, clause.iter)

    if isinstance(iterable, impl.Expr) and iterable.ptr.is_tensor():
        tensor_shape = iterable.ptr.get_shape()
        scalars = [Expr(x) for x in ctx.ast_builder.expand_exprs([iterable.ptr])]
        iterable = reshape_list(scalars, tensor_shape)

    for item in iterable:
        with ctx.variable_scope_guard():
            ASTTransformer.build_assign_unpack(ctx, clause.target, item, True)
            with ctx.static_scope_guard():
                build_stmts(ctx, clause.ifs)
            # Still inside the variable scope so the bound target is visible.
            ASTTransformer.process_ifs(ctx, node, now_comp, 0, func, result)
    return None
321
+
322
@staticmethod
def process_ifs(ctx: ASTTransformerContext, node: ast.expr, now_comp, now_if, func, result):
    """Apply the already-built ``if`` filters of clause ``now_comp``, one at a time.

    Each filter condition's built ``ptr`` is tested for truthiness. A false
    condition drops the current element. When every filter passes, expansion
    continues with the next clause via ``ASTTransformer.process_generators``.

    Args:
        ctx: The builder context.
        node: The comprehension node (e.g. ``ast.ListComp`` / ``ast.DictComp``).
        now_comp: Index of the current clause in ``node.generators``.
        now_if: Index of the filter to test within that clause.
        func: Element-building callback forwarded to ``process_generators``.
        result: The container being filled.
    """
    filters = node.generators[now_comp].ifs
    if now_if >= len(filters):
        return ASTTransformer.process_generators(ctx, node, now_comp + 1, func, result)
    if filters[now_if].ptr:
        ASTTransformer.process_ifs(ctx, node, now_comp, now_if + 1, func, result)
    return None
331
+
332
@staticmethod
def build_ListComp(ctx: ASTTransformerContext, node: ast.ListComp):
    """Unroll a list comprehension at compile time into a Python list."""
    collected = []
    ASTTransformer.process_generators(ctx, node, 0, ASTTransformer.process_listcomp, collected)
    node.ptr = collected
    return collected
338
+
339
@staticmethod
def build_DictComp(ctx: ASTTransformerContext, node: ast.DictComp):
    """Unroll a dict comprehension at compile time into a Python dict."""
    collected = {}
    ASTTransformer.process_generators(ctx, node, 0, ASTTransformer.process_dictcomp, collected)
    node.ptr = collected
    return collected
345
+
346
@staticmethod
def build_Index(ctx: ASTTransformerContext, node: ast.Index):
    """Build a legacy ``ast.Index`` subscript wrapper by building its wrapped value."""
    inner = build_stmt(ctx, node.value)
    node.ptr = inner
    return inner
350
+
351
@staticmethod
def build_Constant(ctx: ASTTransformerContext, node: ast.Constant):
    """Return the literal held by an ``ast.Constant`` (also stored on ``node.ptr``)."""
    literal = node.value
    node.ptr = literal
    return literal
355
+
356
+ @staticmethod
357
+ def build_Num(ctx: ASTTransformerContext, node: ast.Num):
358
+ node.ptr = node.n
359
+ return node.ptr
360
+
361
+ @staticmethod
362
+ def build_Str(ctx: ASTTransformerContext, node: ast.Str):
363
+ node.ptr = node.s
364
+ return node.ptr
365
+
366
+ @staticmethod
367
+ def build_Bytes(ctx: ASTTransformerContext, node: ast.Bytes):
368
+ node.ptr = node.s
369
+ return node.ptr
370
+
371
+ @staticmethod
372
+ def build_NameConstant(ctx: ASTTransformerContext, node: ast.NameConstant):
373
+ node.ptr = node.value
374
+ return node.ptr
375
+
376
@staticmethod
def build_keyword(ctx: ASTTransformerContext, node: ast.keyword):
    """Build a call keyword argument.

    ``name=value`` becomes the one-entry dict ``{name: value}``. ``**mapping``
    (``node.arg is None``) becomes the built mapping itself.
    """
    build_stmt(ctx, node.value)
    built_value = node.value.ptr
    node.ptr = built_value if node.arg is None else {node.arg: built_value}
    return node.ptr
384
+
385
@staticmethod
def build_Starred(ctx: ASTTransformerContext, node: ast.Starred):
    """Build ``*value`` as just the built ``value``; callers handle the unpacking."""
    unpacked = build_stmt(ctx, node.value)
    node.ptr = unpacked
    return unpacked
389
+
390
@staticmethod
def build_FormattedValue(ctx: ASTTransformerContext, node: ast.FormattedValue):
    """Build one ``{value:spec}`` replacement field of an f-string.

    Without a format spec, the built value is returned as-is. With a constant
    spec, the triple ``["__ti_fmt_value__", value, spec]`` is returned instead.
    It is a list tagged with a marker so that it can be distinguished from a
    normal list.

    Raises:
        GsTaichiSyntaxError: if the format spec is not a single constant
            string (e.g. a nested replacement field such as ``{x:{width}}``).
    """
    node.ptr = build_stmt(ctx, node.value)
    spec = node.format_spec
    if spec is None or len(spec.values) == 0:
        return node.ptr
    spec_parts = spec.values
    if len(spec_parts) != 1 or not isinstance(spec_parts[0], ast.Constant):
        raise GsTaichiSyntaxError("Only constant format specs are supported in f-strings.")
    # ``.value`` rather than the deprecated ``.s`` alias (removed in Python 3.14).
    format_str = spec_parts[0].value
    # distinguished from normal list
    return ["__ti_fmt_value__", node.ptr, format_str]
401
+
402
@staticmethod
def build_JoinedStr(ctx: ASTTransformerContext, node: ast.JoinedStr):
    """Lower an f-string to an ``impl.ti_format`` call.

    Each replacement field becomes a ``{}`` placeholder in the template, and
    its built value is passed as the matching positional argument. Literal
    text is appended to the template.

    Returns:
        The result of ``impl.ti_format``, also stored on ``node.ptr``.

    Raises:
        GsTaichiSyntaxError: if the f-string contains an unexpected node kind.
    """
    template_parts = []
    args = []
    for sub_node in node.values:
        if isinstance(sub_node, ast.FormattedValue):
            template_parts.append("{}")
            args.append(build_stmt(ctx, sub_node))
        elif isinstance(sub_node, ast.Constant):
            # The parser has already turned ``{{``/``}}`` into single braces.
            # Re-escape them so the literal text cannot be mistaken for a
            # placeholder. NOTE(review): assumes impl.ti_format expands the
            # template with str.format brace rules - confirm in impl.py.
            template_parts.append(sub_node.value.replace("{", "{{").replace("}", "}}"))
        else:
            # Literal parts have been ast.Constant since Python 3.8; the
            # former ast.Str branch was dead and ast.Str is gone in 3.14.
            raise GsTaichiSyntaxError("Invalid value for fstring.")

    node.ptr = impl.ti_format("".join(template_parts), *args)
    return node.ptr
420
+
421
@staticmethod
def build_Call(ctx: ASTTransformerContext, node: ast.Call) -> Any | None:
    """Lower a call by delegating to ``CallTransformer.build_Call``.

    The statement builders are passed along so that the delegate can build
    the callee and its arguments.
    """
    delegate = CallTransformer.build_Call
    return delegate(ctx, node, build_stmt, build_stmts)
424
+
425
@staticmethod
def build_FunctionDef(ctx: ASTTransformerContext, node: ast.FunctionDef) -> None:
    """Lower a function definition by delegating to ``FunctionDefTransformer``."""
    transformer = FunctionDefTransformer
    transformer.build_FunctionDef(ctx, node, build_stmts)
    return None
428
+
429
@staticmethod
def build_Return(ctx: ASTTransformerContext, node: ast.Return) -> None:
    """Build a ``return`` statement.

    Kernels and real functions: each returned value is checked against the
    matching entry of ``ctx.func.return_type`` (primitive, matrix or struct
    type), cast to it, and emitted through
    ``ctx.ast_builder.create_kernel_exprgroup_return``.

    Inline functions: the value is stored on ``ctx.return_data``. Values whose
    declared type is primitive are cast to that type first.

    For non-real functions, ``ctx.returned`` records whether a value or
    nothing was returned.

    Raises:
        GsTaichiSyntaxError: on a return inside non-static control flow of a
            non-real function, on a missing return annotation, or on an
            unsupported return type.
        GsTaichiRuntimeTypeError: when a returned value does not match its
            declared return type.
    """
    if not ctx.is_real_function:
        if ctx.is_in_non_static_control_flow():
            raise GsTaichiSyntaxError("Return inside non-static if/for is not supported")
    if node.value is not None:
        build_stmt(ctx, node.value)
    if node.value is None or node.value.ptr is None:
        if not ctx.is_real_function:
            ctx.returned = ReturnStatus.ReturnedVoid
        return None
    if ctx.is_kernel or ctx.is_real_function:
        # TODO: check if it's at the end of a kernel, throw GsTaichiSyntaxError if not
        if ctx.func.return_type is None:
            raise GsTaichiSyntaxError(
                f'A {"kernel" if ctx.is_kernel else "function"} '
                "with a return value must be annotated "
                "with a return type, e.g. def func() -> ti.f32"
            )
        return_exprs = []
        # A single declared type means a single returned value: wrap it so it
        # can be zipped against the one-element sequence of return types.
        if len(ctx.func.return_type) == 1:
            node.value.ptr = [node.value.ptr]
        assert len(ctx.func.return_type) == len(node.value.ptr)
        for return_type, ptr in zip(ctx.func.return_type, node.value.ptr):
            if id(return_type) in primitive_types.type_ids:
                if isinstance(ptr, Expr):
                    if ptr.is_tensor() or ptr.is_struct() or ptr.element_type() not in primitive_types.all_types:
                        raise GsTaichiRuntimeTypeError.get_ret(str(return_type), ptr)
                elif not isinstance(ptr, (float, int, np.floating, np.integer)):
                    raise GsTaichiRuntimeTypeError.get_ret(str(return_type), ptr)
                return_exprs += [ti_ops.cast(expr.Expr(ptr), return_type).ptr]
            elif isinstance(return_type, MatrixType):
                values = ptr
                if isinstance(values, Matrix):
                    # Compare against this value's own declared type.
                    # ``ctx.func.return_type`` is the sequence of all return
                    # types and has no ``ndim`` attribute.
                    if values.ndim != return_type.ndim:
                        raise GsTaichiRuntimeTypeError(
                            f"Return matrix ndim mismatch, expecting={return_type.ndim}, got={values.ndim}."
                        )
                    elif return_type.get_shape() != values.get_shape():
                        raise GsTaichiRuntimeTypeError(
                            f"Return matrix shape mismatch, expecting={return_type.get_shape()}, got={values.get_shape()}."
                        )
                    # NOTE(review): if Matrix.to_list() returns a flat list for
                    # vectors and nested rows otherwise, these two branches look
                    # swapped - confirm against matrix.py.
                    values = (
                        itertools.chain.from_iterable(values.to_list())
                        if values.ndim == 1
                        else iter(values.to_list())
                    )
                elif isinstance(values, Expr):
                    if not values.is_tensor():
                        raise GsTaichiRuntimeTypeError.get_ret(return_type.to_string(), ptr)
                    elif (
                        return_type.dtype in primitive_types.real_types
                        and values.element_type() not in primitive_types.all_types
                    ):
                        raise GsTaichiRuntimeTypeError.get_ret(return_type.dtype.to_string(), values.element_type())
                    elif (
                        return_type.dtype in primitive_types.integer_types
                        and values.element_type() not in primitive_types.integer_types
                    ):
                        raise GsTaichiRuntimeTypeError.get_ret(return_type.dtype.to_string(), values.element_type())
                    elif len(values.get_shape()) != return_type.ndim:
                        raise GsTaichiRuntimeTypeError(
                            f"Return matrix ndim mismatch, expecting={return_type.ndim}, got={len(values.get_shape())}."
                        )
                    elif return_type.get_shape() != values.get_shape():
                        raise GsTaichiRuntimeTypeError(
                            f"Return matrix shape mismatch, expecting={return_type.get_shape()}, got={values.get_shape()}."
                        )
                    values = [values]
                else:
                    np_array = np.array(values)
                    dt, shape, ndim = np_array.dtype, np_array.shape, np_array.ndim
                    if return_type.dtype in primitive_types.real_types and dt not in (
                        float,
                        int,
                        np.floating,
                        np.integer,
                    ):
                        raise GsTaichiRuntimeTypeError.get_ret(return_type.dtype.to_string(), dt)
                    elif return_type.dtype in primitive_types.integer_types and dt not in (int, np.integer):
                        raise GsTaichiRuntimeTypeError.get_ret(return_type.dtype.to_string(), dt)
                    elif ndim != return_type.ndim:
                        raise GsTaichiRuntimeTypeError(
                            f"Return matrix ndim mismatch, expecting={return_type.ndim}, got={ndim}."
                        )
                    elif return_type.get_shape() != shape:
                        raise GsTaichiRuntimeTypeError(
                            f"Return matrix shape mismatch, expecting={return_type.get_shape()}, got={shape}."
                        )
                    values = [values]
                return_exprs += [ti_ops.cast(exp, return_type.dtype) for exp in values]
            elif isinstance(return_type, StructType):
                if not isinstance(ptr, Struct) or not isinstance(ptr, return_type):
                    raise GsTaichiRuntimeTypeError.get_ret(str(return_type), ptr)
                values = ptr
                assert isinstance(values, Struct)
                return_exprs += expr._get_flattened_ptrs(values)
            else:
                raise GsTaichiSyntaxError("The return type is not supported now!")
        ctx.ast_builder.create_kernel_exprgroup_return(
            expr.make_expr_group(return_exprs), _ti_core.DebugInfo(ctx.get_pos_info(node))
        )
    else:
        ctx.return_data = node.value.ptr
        return_types = ctx.func.return_type
        if return_types is not None:
            single = len(return_types) == 1
            # ``return a, b`` builds a tuple, and tuples do not support item
            # assignment. Cast into a list copy and restore the tuple afterwards.
            was_tuple = not single and isinstance(ctx.return_data, tuple)
            if single:
                ctx.return_data = [ctx.return_data]
            elif was_tuple:
                ctx.return_data = list(ctx.return_data)
            for i, return_type in enumerate(return_types):
                if id(return_type) in primitive_types.type_ids:
                    ctx.return_data[i] = ti_ops.cast(ctx.return_data[i], return_type)
            if single:
                ctx.return_data = ctx.return_data[0]
            elif was_tuple:
                ctx.return_data = tuple(ctx.return_data)
    if not ctx.is_real_function:
        ctx.returned = ReturnStatus.ReturnedValue
    return None
544
+
545
@staticmethod
def build_Module(ctx: ASTTransformerContext, node: ast.Module) -> None:
    """Build every top-level statement of a parsed module, in order.

    All statements share the single variable scope opened here. Per the
    authors' note, ``build_stmts`` is deliberately avoided because it would
    insert 'del' statements at the end and delete the parameters passed into
    the module.
    """
    top_level_statements = node.body
    with ctx.variable_scope_guard():
        for statement in top_level_statements:
            build_stmt(ctx, statement)
    return None
553
+
554
@staticmethod
def build_attribute_if_is_dynamic_snode_method(ctx: ASTTransformerContext, node) -> bool:
    """Try to interpret an attribute node as a dynamic-SNode method.

    Recognizes ``x.append`` / ``x.deactivate`` / ``x.length`` and their
    subscripted forms such as ``x[i, j].append``. On success ``node.ptr`` is
    set to a callable that forwards to ``append``/``deactivate``/``length``
    on ``x.parent()`` and True is returned; otherwise ``node`` is left
    untouched and False is returned.

    Args:
        ctx: Transformer context (unused here).
        node: The ``ast.Attribute`` node. ``node.value`` (or, for a subscript,
            ``node.value.value`` and ``node.value.slice``) must already carry
            a built ``.ptr``.

    Returns:
        Whether ``node`` was bound to a dynamic-SNode method.
    """
    method_name = node.attr
    receiver = node.value
    if method_name not in ("append", "deactivate", "length"):
        return False

    if isinstance(receiver, ast.Subscript):
        snode_field, indices = receiver.value.ptr, receiver.slice.ptr
    else:
        snode_field, indices = receiver.ptr, []

    # Only a Field whose parent SNode is dynamic exposes these methods.
    if not isinstance(snode_field, Field) or not snode_field.parent().ptr.type == _ti_core.SNodeType.dynamic:
        return False

    # The given indices must address every axis except the trailing dynamic one.
    field_dim = snode_field.snode.ptr.num_active_indices()
    if field_dim != make_expr_group(*indices).size() + 1:
        return False

    if method_name == "append":
        node.ptr = lambda val: append(snode_field.parent(), indices, val)
    elif method_name == "deactivate":
        node.ptr = lambda: deactivate(snode_field.parent(), indices)
    else:
        node.ptr = lambda: length(snode_field.parent(), indices)
    return True
582
+
583
@staticmethod
def build_Attribute(ctx: ASTTransformerContext, node: ast.Attribute):
    """Build an attribute access ``<node.value>.<node.attr>``.

    Resolution order:
      1. dynamic-SNode methods (``append`` / ``deactivate`` / ``length``),
      2. on an ``Expr`` lacking the attribute: names found in
         ``Matrix._swizzle_to_keygroup`` become (multi-)subscripts, anything
         else is looked up in ``gstaichi.lang.matrix_ops`` with the receiver
         stored on ``node.caller``,
      3. a dataclass receiver (see NOTE below),
      4. plain ``getattr``.

    Returns:
        The built value, also stored on ``node.ptr``.
    """
    # There are two valid cases for the methods of Dynamic SNode:
    #
    # 1. x[i, j].append (where the dimension of the field (3 in this case) is equal to one plus the number of the
    # indices (2 in this case) )
    #
    # 2. x.append (where the dimension of the field is one, equal to x[()].append)
    #
    # For the first case, the AST (simplified) is like node = Attribute(value=Subscript(value=x, slice=[i, j]),
    # attr="append"). When we build_stmt(node.value) (i.e. the Subscript x[i, j]), it builds node.value.value
    # (i.e. x) and node.value.slice (i.e. [i, j]) and then raises a GsTaichiIndexError, because the dimension of
    # the field is not equal to the number of indices. So when we meet that error, we check whether this is a
    # Dynamic SNode method by calling build_attribute_if_is_dynamic_snode_method; if it is not, we re-raise.
    #
    # For the second case, the AST (simplified) is like node = Attribute(value=x, attr="append"), and
    # build_stmt(node.value) does not raise. So when no error occurs we also call
    # build_attribute_if_is_dynamic_snode_method; if it is not a Dynamic SNode method, we continue to process
    # the node as a normal attribute.
    try:
        build_stmt(ctx, node.value)
    except Exception as e:
        # presumably translates exceptions raised by the C++ core into their Python counterparts
        e = handle_exception_from_cpp(e)
        if isinstance(e, GsTaichiIndexError):
            # The subscript itself failed to build; only its parts (value / slice) carry a ptr.
            node.value.ptr = None
            if ASTTransformer.build_attribute_if_is_dynamic_snode_method(ctx, node):
                return node.ptr
        raise e

    if ASTTransformer.build_attribute_if_is_dynamic_snode_method(ctx, node):
        return node.ptr

    if isinstance(node.value.ptr, Expr) and not hasattr(node.value.ptr, node.attr):
        if node.attr in Matrix._swizzle_to_keygroup:
            # Swizzle-style access (e.g. v.x, v.xy): validate it, then map each letter to its component index.
            keygroup = Matrix._swizzle_to_keygroup[node.attr]
            Matrix._keygroup_to_checker[keygroup](node.value.ptr, node.attr)
            attr_len = len(node.attr)
            if attr_len == 1:
                # Single component -> plain subscript.
                node.ptr = Expr(
                    impl.get_runtime()
                    .compiling_callable.ast_builder()
                    .expr_subscript(
                        node.value.ptr.ptr,
                        make_expr_group(keygroup.index(node.attr)),
                        _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
                    )
                )
            else:
                # Several components -> one multi-index subscript producing a vector of length attr_len.
                node.ptr = Expr(
                    _ti_core.subscript_with_multiple_indices(
                        node.value.ptr.ptr,
                        [make_expr_group(keygroup.index(ch)) for ch in node.attr],
                        (attr_len,),
                        _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
                    )
                )
        else:
            # Method-style call on a tensor (e.g. v.norm()): resolve it from matrix_ops and remember the receiver.
            from gstaichi.lang import (  # pylint: disable=C0415
                matrix_ops as tensor_ops,
            )

            node.ptr = getattr(tensor_ops, node.attr)
            setattr(node, "caller", node.value.ptr)
    elif dataclasses.is_dataclass(node.value.ptr):
        # NOTE(review): this yields the type of the *first* dataclass field and ignores node.attr entirely.
        # Confirm this is intended; a lookup of the field whose name equals node.attr may have been meant.
        node.ptr = next(field.type for field in dataclasses.fields(node.value.ptr))
    else:
        node.ptr = getattr(node.value.ptr, node.attr)
    return node.ptr
654
+
655
@staticmethod
def build_BinOp(ctx: ASTTransformerContext, node: ast.BinOp):
    """Build both operands of a binary operation, then apply the operator.

    ``@`` is routed to ``gstaichi.lang.matrix_ops.matmul``; every other
    operator uses the operands' own Python operator overloads.

    Returns:
        The result, also stored on ``node.ptr``.

    Raises:
        GsTaichiTypeError: when applying the operator raises ``TypeError``
            (the original exception context is suppressed).
    """
    build_stmt(ctx, node.left)
    build_stmt(ctx, node.right)
    # pylint: disable-msg=C0415
    from gstaichi.lang.matrix_ops import matmul

    binary_operators = {
        ast.Add: lambda lhs, rhs: lhs + rhs,
        ast.Sub: lambda lhs, rhs: lhs - rhs,
        ast.Mult: lambda lhs, rhs: lhs * rhs,
        ast.Div: lambda lhs, rhs: lhs / rhs,
        ast.FloorDiv: lambda lhs, rhs: lhs // rhs,
        ast.Mod: lambda lhs, rhs: lhs % rhs,
        ast.Pow: lambda lhs, rhs: lhs**rhs,
        ast.LShift: lambda lhs, rhs: lhs << rhs,
        ast.RShift: lambda lhs, rhs: lhs >> rhs,
        ast.BitOr: lambda lhs, rhs: lhs | rhs,
        ast.BitXor: lambda lhs, rhs: lhs ^ rhs,
        ast.BitAnd: lambda lhs, rhs: lhs & rhs,
        ast.MatMult: matmul,
    }
    apply_operator = binary_operators.get(type(node.op))
    left_value, right_value = node.left.ptr, node.right.ptr
    try:
        result = apply_operator(left_value, right_value)
    except TypeError as err:
        raise GsTaichiTypeError(str(err)) from None
    node.ptr = result
    return node.ptr
682
+
683
@staticmethod
def build_AugAssign(ctx: ASTTransformerContext, node: ast.AugAssign):
    """Build an augmented assignment such as ``x += y``.

    The target and the value are both built first. The in-place update is then
    delegated to ``node.target.ptr._augassign`` together with the operator's
    AST class name (e.g. ``"Add"``).

    Returns:
        The result of ``_augassign``, also stored on ``node.ptr``.

    Raises:
        GsTaichiSyntaxError: if the target is a plain name listed in
            ``ctx.kernel_args``; kernel arguments are immutable.
    """
    build_stmt(ctx, node.target)
    build_stmt(ctx, node.value)
    if isinstance(node.target, ast.Name) and node.target.id in ctx.kernel_args:
        raise GsTaichiSyntaxError(
            f'Kernel argument "{node.target.id}" is immutable in the kernel. '
            # Plain literal: this part has no placeholders, so no f-prefix (ruff F541).
            "If you want to change its value, please create a new variable."
        )
    node.ptr = node.target.ptr._augassign(node.value.ptr, type(node.op).__name__)
    return node.ptr
694
+
695
@staticmethod
def build_UnaryOp(ctx: ASTTransformerContext, node: ast.UnaryOp):
    """Build the operand of a unary operation and apply the operator.

    ``not`` maps to ``ti_ops.logical_not``. ``+``, ``-`` and ``~`` use the
    operand's own Python operators.

    Returns:
        The result, also stored on ``node.ptr``.
    """
    build_stmt(ctx, node.operand)
    operand_value = node.operand.ptr
    unary_operators = {
        ast.UAdd: lambda value: value,
        ast.USub: lambda value: -value,
        ast.Not: ti_ops.logical_not,
        ast.Invert: lambda value: ~value,
    }
    apply_operator = unary_operators.get(type(node.op))
    node.ptr = apply_operator(operand_value)
    return node.ptr
706
+
707
+ @staticmethod
708
+ def build_bool_op(op):
709
+ def inner(operands):
710
+ if len(operands) == 1:
711
+ return operands[0].ptr
712
+ return op(operands[0].ptr, inner(operands[1:]))
713
+
714
+ return inner
715
+
716
+ @staticmethod
717
+ def build_static_and(operands):
718
+ for operand in operands:
719
+ if not operand.ptr:
720
+ return operand.ptr
721
+ return operands[-1].ptr
722
+
723
+ @staticmethod
724
+ def build_static_or(operands):
725
+ for operand in operands:
726
+ if operand.ptr:
727
+ return operand.ptr
728
+ return operands[-1].ptr
729
+
730
@staticmethod
def build_BoolOp(ctx: ASTTransformerContext, node: ast.BoolOp):
    """Build a chain of ``and`` / ``or`` operands and combine them.

    * Inside a static scope, Python's own ``and``/``or`` semantics are applied
      to the built values.
    * Otherwise, when ``impl.get_runtime().short_circuit_operators`` is truthy,
      values are folded with ``ti_ops.logical_and`` / ``ti_ops.logical_or``.
    * Otherwise they are folded with ``ti_ops.bit_and`` / ``ti_ops.bit_or``.

    Returns:
        The combined value, also stored on ``node.ptr``.
    """
    build_stmts(ctx, node.values)
    if ctx.is_in_static_scope():
        and_combiner = ASTTransformer.build_static_and
        or_combiner = ASTTransformer.build_static_or
    elif impl.get_runtime().short_circuit_operators:
        and_combiner = ASTTransformer.build_bool_op(ti_ops.logical_and)
        or_combiner = ASTTransformer.build_bool_op(ti_ops.logical_or)
    else:
        and_combiner = ASTTransformer.build_bool_op(ti_ops.bit_and)
        or_combiner = ASTTransformer.build_bool_op(ti_ops.bit_or)
    combine = {ast.And: and_combiner, ast.Or: or_combiner}.get(type(node.op))
    node.ptr = combine(node.values)
    return node.ptr
751
+
752
@staticmethod
def build_Compare(ctx: ASTTransformerContext, node: ast.Compare):
    """Build a possibly chained comparison such as ``a < b <= c``.

    All operands are built up front. Each adjacent pair is then compared and
    the results are combined with ``ti_ops.logical_and``. ``in`` / ``not in``
    are accepted only inside a static scope. A result that is not a Python or
    NumPy bool is cast to ``u1``.

    Returns:
        The combined comparison, also stored on ``node.ptr``.

    Raises:
        GsTaichiSyntaxError: for ``is`` / ``is not``, for ``in`` / ``not in``
            outside ``ti.static``, or for any other unsupported operator.
    """
    build_stmt(ctx, node.left)
    build_stmts(ctx, node.comparators)
    comparison_ops = {
        ast.Eq: lambda lhs, rhs: lhs == rhs,
        ast.NotEq: lambda lhs, rhs: lhs != rhs,
        ast.Lt: lambda lhs, rhs: lhs < rhs,
        ast.LtE: lambda lhs, rhs: lhs <= rhs,
        ast.Gt: lambda lhs, rhs: lhs > rhs,
        ast.GtE: lambda lhs, rhs: lhs >= rhs,
    }
    membership_ops = {
        ast.In: lambda lhs, rhs: lhs in rhs,
        ast.NotIn: lambda lhs, rhs: lhs not in rhs,
    }
    if ctx.is_in_static_scope():
        comparison_ops.update(membership_ops)

    values = [node.left.ptr, *(comparator.ptr for comparator in node.comparators)]
    combined = True
    for op_node, lhs, rhs in zip(node.ops, values, values[1:]):
        if isinstance(op_node, (ast.Is, ast.IsNot)):
            spelled = "is" if isinstance(op_node, ast.Is) else "is not"
            raise GsTaichiSyntaxError(f'Operator "{spelled}" in GsTaichi scope is not supported.')
        compare = comparison_ops.get(type(op_node))
        if compare is None:
            op_name = type(op_node).__name__
            if type(op_node) in membership_ops:
                raise GsTaichiSyntaxError(f'"{op_name}" is only supported inside `ti.static`.')
            raise GsTaichiSyntaxError(f'"{op_name}" is not supported in GsTaichi kernels.')
        combined = ti_ops.logical_and(combined, compare(lhs, rhs))

    if not isinstance(combined, (bool, np.bool_)):
        combined = ti_ops.cast(combined, primitive_types.u1)
    node.ptr = combined
    return node.ptr
790
+
791
+ @staticmethod
792
+ def get_for_loop_targets(node: ast.Name | ast.Tuple | Any) -> list:
793
+ """
794
+ Returns the list of indices of the for loop |node|.
795
+ See also: https://docs.python.org/3/library/ast.html#ast.For
796
+ """
797
+ if isinstance(node.target, ast.Name):
798
+ return [node.target.id]
799
+ assert isinstance(node.target, ast.Tuple)
800
+ return [name.id for name in node.target.elts]
801
+
802
@staticmethod
def build_static_for(ctx: ASTTransformerContext, node: ast.For, is_grouped: bool) -> None:
    """Unroll a ``for ... in ti.static(...)`` loop at compile time.

    The body is built once per iteration value, each time in a fresh variable
    scope with the loop target(s) bound to that value. A static ``break`` stops
    the unrolling. A static ``continue`` is consumed by resetting the loop
    status to ``LoopStatus.Normal``.

    With ``is_grouped`` the iterable must be ``ti.static(ti.grouped(ti.ndrange(...)))``
    and a single target receives each grouped index. Otherwise ``node.iter`` is
    built and iterated in Python. A non-sequence value, or any value when there
    is only one target, is bound as a whole.

    A ``SyntaxWarning`` is issued once when more than
    ``impl.get_runtime().unrolling_limit`` iterations are unrolled. A falsy
    limit disables the warning.

    Raises:
        GsTaichiSyntaxError: if a grouped static loop is not over ``ti.ndrange``
            or does not have exactly one loop target.
    """
    ti_unroll_limit = impl.get_runtime().unrolling_limit
    warned = False

    def _warn_if_unroll_limit_exceeded(iter_time: int) -> None:
        # Shared by both unroll paths (previously duplicated); warns at most once per loop.
        nonlocal warned
        if not warned and ti_unroll_limit and iter_time > ti_unroll_limit:
            warned = True
            warnings.warn_explicit(
                f"""You are unrolling more than
{ti_unroll_limit} iterations, so the compile time may be extremely long.
You can use a non-static for loop if you want to decrease the compile time.
You can disable this warning by setting ti.init(unrolling_limit=0).""",
                SyntaxWarning,
                ctx.file,
                node.lineno + ctx.lineno_offset,
                module="gstaichi",
            )

    def _build_unrolled_body(bindings) -> bool:
        # Builds one copy of the body with the (name, value) `bindings` in a fresh
        # scope. Returns True when the body requested a static `break`.
        with ctx.variable_scope_guard():
            for name, value in bindings:
                ctx.create_variable(name, value)
            build_stmts(ctx, node.body)
            status = ctx.loop_status()
            if status == LoopStatus.Break:
                return True
            if status == LoopStatus.Continue:
                ctx.set_loop_status(LoopStatus.Normal)
        return False

    if is_grouped:
        assert len(node.iter.args[0].args) == 1
        ndrange_arg = build_stmt(ctx, node.iter.args[0].args[0])
        if not isinstance(ndrange_arg, _Ndrange):
            raise GsTaichiSyntaxError("Only 'ti.ndrange' is allowed in 'ti.static(ti.grouped(...))'.")
        targets = ASTTransformer.get_for_loop_targets(node)
        if len(targets) != 1:
            raise GsTaichiSyntaxError(f"Group for should have 1 loop target, found {len(targets)}")
        target = targets[0]
        for iter_time, value in enumerate(impl.grouped(ndrange_arg), start=1):
            _warn_if_unroll_limit_exceeded(iter_time)
            if _build_unrolled_body(((target, value),)):
                break
    else:
        build_stmt(ctx, node.iter)
        targets = ASTTransformer.get_for_loop_targets(node)
        for iter_time, target_values in enumerate(node.iter.ptr, start=1):
            if not isinstance(target_values, collections.abc.Sequence) or len(targets) == 1:
                target_values = [target_values]
            _warn_if_unroll_limit_exceeded(iter_time)
            if _build_unrolled_body(zip(targets, target_values)):
                break
    return None
874
+
875
@staticmethod
def build_range_for(ctx: ASTTransformerContext, node: ast.For) -> None:
    """Build ``for i in range(end)`` / ``for i in range(begin, end)`` as a frontend range-for.

    ``boundary_type_cast_warning`` is called on every user-supplied bound
    (``begin`` first when present). Both bounds are then cast to ``i32``. A
    missing ``begin`` defaults to 0.

    Raises:
        GsTaichiSyntaxError: if ``range`` gets anything other than 1 or 2 arguments.
    """
    with ctx.variable_scope_guard():
        loop_name = node.target.id
        ctx.check_loop_var(loop_name)
        loop_var = expr.Expr(ctx.ast_builder.make_id_expr(""))
        ctx.create_variable(loop_name, loop_var)

        range_args = node.iter.args
        if len(range_args) not in [1, 2]:
            raise GsTaichiSyntaxError(f"Range should have 1 or 2 arguments, found {len(node.iter.args)}")

        # Build all bounds in source order, then warn about implicit dtype conversion.
        bound_exprs = [expr.Expr(build_stmt(ctx, arg)) for arg in range_args]
        for bound_expr in bound_exprs:
            boundary_type_cast_warning(bound_expr)

        begin_expr = bound_exprs[0] if len(bound_exprs) == 2 else expr.Expr(0)
        begin = ti_ops.cast(begin_expr, primitive_types.i32)
        end = ti_ops.cast(bound_exprs[-1], primitive_types.i32)

        for_di = _ti_core.DebugInfo(ctx.get_pos_info(node))
        ctx.ast_builder.begin_frontend_range_for(loop_var.ptr, begin.ptr, end.ptr, for_di)
        build_stmts(ctx, node.body)
        ctx.ast_builder.end_frontend_range_for()
    return None
909
+
910
@staticmethod
def build_ndrange_for(ctx: ASTTransformerContext, node: ast.For) -> None:
    """Build ``for i, j, ... in ti.ndrange(...)`` as one flattened range-for.

    A single linear index runs over ``[0, ndrange.acc_dimensions[0])``. Inside
    the loop, each target is recovered by integer division by
    ``acc_dimensions[i + 1]`` (the last target takes the remainder) and is
    offset by its lower bound ``ndrange.bounds[i][0]``.

    Raises:
        GsTaichiSyntaxError: if the number of loop targets differs from the
            number of ndrange dimensions.
    """
    with ctx.variable_scope_guard():
        ndrange_var = impl.expr_init(build_stmt(ctx, node.iter))
        ndrange_begin = ti_ops.cast(expr.Expr(0), primitive_types.i32)
        # acc_dimensions[0] is the loop's upper bound, i.e. the total number of linear iterations.
        ndrange_end = ti_ops.cast(
            expr.Expr(impl.subscript(ctx.ast_builder, ndrange_var.acc_dimensions, 0)),
            primitive_types.i32,
        )
        ndrange_loop_var = expr.Expr(ctx.ast_builder.make_id_expr(""))
        for_di = _ti_core.DebugInfo(ctx.get_pos_info(node))
        ctx.ast_builder.begin_frontend_range_for(ndrange_loop_var.ptr, ndrange_begin.ptr, ndrange_end.ptr, for_di)
        # Mutable copy of the linear index; it is reduced as each coordinate is peeled off below.
        I = impl.expr_init(ndrange_loop_var)
        targets = ASTTransformer.get_for_loop_targets(node)
        # NOTE: this check runs after begin_frontend_range_for, so a mismatch aborts with the range-for still open.
        if len(targets) != len(ndrange_var.dimensions):
            raise GsTaichiSyntaxError(
                "Ndrange for loop with number of the loop variables not equal to "
                "the dimension of the ndrange is not supported. "
                "Please check if the number of arguments of ti.ndrange() is equal to "
                "the number of the loop variables."
            )
        for i, target in enumerate(targets):
            # acc_dimensions[i + 1] is used as the stride of axis i
            # (presumably the product of the extents of the later axes — confirm in _Ndrange).
            if i + 1 < len(targets):
                target_tmp = impl.expr_init(I // ndrange_var.acc_dimensions[i + 1])
            else:
                target_tmp = impl.expr_init(I)
            # Bind the user-visible loop variable: zero-based coordinate plus the axis' lower bound.
            ctx.create_variable(
                target,
                impl.expr_init(
                    target_tmp
                    + impl.subscript(
                        ctx.ast_builder,
                        impl.subscript(ctx.ast_builder, ndrange_var.bounds, i),
                        0,
                    )
                ),
            )
            if i + 1 < len(targets):
                # Drop this axis' contribution so the remainder addresses the remaining axes.
                I._assign(I - target_tmp * ndrange_var.acc_dimensions[i + 1])
        build_stmts(ctx, node.body)
        ctx.ast_builder.end_frontend_range_for()
    return None
952
+
953
@staticmethod
def build_grouped_ndrange_for(ctx: ASTTransformerContext, node: ast.For) -> None:
    """Build ``for I in ti.grouped(ti.ndrange(...))`` as one flattened range-for.

    Like ``build_ndrange_for``, but the per-axis coordinates are written into
    the components of a single ``i32`` vector bound to the one loop target.

    Raises:
        GsTaichiSyntaxError: if there is not exactly one loop target.
    """
    with ctx.variable_scope_guard():
        # node.iter is the ti.grouped(...) call; its first argument is the ndrange.
        ndrange_var = impl.expr_init(build_stmt(ctx, node.iter.args[0]))
        ndrange_begin = ti_ops.cast(expr.Expr(0), primitive_types.i32)
        # acc_dimensions[0] is the loop's upper bound, i.e. the total number of linear iterations.
        ndrange_end = ti_ops.cast(
            expr.Expr(impl.subscript(ctx.ast_builder, ndrange_var.acc_dimensions, 0)),
            primitive_types.i32,
        )
        ndrange_loop_var = expr.Expr(ctx.ast_builder.make_id_expr(""))
        for_di = _ti_core.DebugInfo(ctx.get_pos_info(node))
        ctx.ast_builder.begin_frontend_range_for(ndrange_loop_var.ptr, ndrange_begin.ptr, ndrange_end.ptr, for_di)

        # NOTE(review): unlike build_range_for / build_struct_for, ctx.check_loop_var is not called on the
        # target here; confirm whether that is intentional.
        targets = ASTTransformer.get_for_loop_targets(node)
        if len(targets) != 1:
            raise GsTaichiSyntaxError(f"Group for should have 1 loop target, found {len(targets)}")
        target = targets[0]
        # One i32 component per ndrange dimension, filled in below on every iteration.
        mat = matrix.make_matrix([0] * len(ndrange_var.dimensions), dt=primitive_types.i32)
        target_var = impl.expr_init(mat)

        ctx.create_variable(target, target_var)
        # Mutable copy of the linear index; it is reduced as each coordinate is peeled off.
        I = impl.expr_init(ndrange_loop_var)
        for i in range(len(ndrange_var.dimensions)):
            if i + 1 < len(ndrange_var.dimensions):
                target_tmp = I // ndrange_var.acc_dimensions[i + 1]
            else:
                target_tmp = I
            # Component i = zero-based coordinate plus the axis' lower bound.
            impl.subscript(ctx.ast_builder, target_var, i)._assign(target_tmp + ndrange_var.bounds[i][0])
            if i + 1 < len(ndrange_var.dimensions):
                I._assign(I - target_tmp * ndrange_var.acc_dimensions[i + 1])
        build_stmts(ctx, node.body)
        ctx.ast_builder.end_frontend_range_for()
    return None
986
+
987
@staticmethod
def build_struct_for(ctx: ASTTransformerContext, node: ast.For, is_grouped: bool) -> None:
    """Build a struct-for over a field-like iterable.

    * Non-grouped (``for i, j in x``): one fresh id expression per target.
      ``node.iter.ptr`` must already have been built by the caller.
    * Grouped (``for I in ti.grouped(x)``): ``node.iter`` is built here, and
      the single target is bound to an ``i32`` vector of the loop indices.

    Raises:
        GsTaichiSyntaxError: for a grouped loop without exactly one target.
    """
    # for i, j in x
    # for I in ti.grouped(x)
    targets = ASTTransformer.get_for_loop_targets(node)

    for target in targets:
        ctx.check_loop_var(target)

    with ctx.variable_scope_guard():
        if is_grouped:
            if len(targets) != 1:
                raise GsTaichiSyntaxError(f"Group for should have 1 loop target, found {len(targets)}")
            target = targets[0]
            loop_var = build_stmt(ctx, node.iter)
            # One index variable per dimension of the iterated object.
            loop_indices = expr.make_var_list(size=len(loop_var.shape), ast_builder=ctx.ast_builder)
            expr_group = expr.make_expr_group(loop_indices)
            impl.begin_frontend_struct_for(ctx.ast_builder, expr_group, loop_var)
            ctx.create_variable(target, matrix.make_matrix(loop_indices, dt=primitive_types.i32))
            build_stmts(ctx, node.body)
            ctx.ast_builder.end_frontend_struct_for()
        else:
            _vars = []
            for name in targets:
                var = expr.Expr(ctx.ast_builder.make_id_expr(""))
                _vars.append(var)
                ctx.create_variable(name, var)
            # Already built by build_For before dispatching here.
            loop_var = node.iter.ptr
            expr_group = expr.make_expr_group(*_vars)
            impl.begin_frontend_struct_for(ctx.ast_builder, expr_group, loop_var)
            build_stmts(ctx, node.body)
            ctx.ast_builder.end_frontend_struct_for()
    return None
1020
+
1021
@staticmethod
def build_mesh_for(ctx: ASTTransformerContext, node: ast.For) -> None:
    """Build ``for e in <mesh element field>`` as a frontend mesh-for.

    The single loop target is bound to a ``mesh.MeshElementFieldProxy`` over a
    fresh index expression. ``ctx.mesh`` holds the iterated mesh while the
    body is built and is reset to ``None`` afterwards.

    Raises:
        GsTaichiSyntaxError: if the loop does not have exactly one target.
    """
    targets = ASTTransformer.get_for_loop_targets(node)
    if len(targets) != 1:
        # The f-prefix was missing, so the literal text "{len(targets)}" used to be shown.
        raise GsTaichiSyntaxError(f"Mesh for should have 1 loop target, found {len(targets)}")
    target = targets[0]

    with ctx.variable_scope_guard():
        var = expr.Expr(ctx.ast_builder.make_id_expr(""))
        ctx.mesh = node.iter.ptr.mesh
        assert isinstance(ctx.mesh, impl.MeshInstance)
        mesh_idx = mesh.MeshElementFieldProxy(ctx.mesh, node.iter.ptr._type, var.ptr)
        ctx.create_variable(target, mesh_idx)
        ctx.ast_builder.begin_frontend_mesh_for(
            mesh_idx.ptr,
            ctx.mesh.mesh_ptr,
            node.iter.ptr._type,
            _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
        )
        build_stmts(ctx, node.body)
        ctx.mesh = None
        ctx.ast_builder.end_frontend_mesh_for()
    return None
1044
+
1045
@staticmethod
def build_nested_mesh_for(ctx: ASTTransformerContext, node: ast.For) -> None:
    """Build ``for e in <mesh relation access>`` as a range-for over the relation.

    A hidden index variable ``<target>_index__`` runs from 0 to the relation's
    ``size`` (cast to ``i32``). Each iteration binds the target to a
    ``mesh.MeshElementFieldProxy`` over the relation entry at that index.
    ``ctx.mesh`` is set to the relation's mesh and is not reset here.

    Raises:
        GsTaichiSyntaxError: if the loop does not have exactly one target.
    """
    targets = ASTTransformer.get_for_loop_targets(node)
    if len(targets) != 1:
        # The f-prefix was missing, so the literal text "{len(targets)}" used to be shown.
        raise GsTaichiSyntaxError(f"Nested-mesh for should have 1 loop target, found {len(targets)}")
    target = targets[0]

    with ctx.variable_scope_guard():
        ctx.mesh = node.iter.ptr.mesh
        assert isinstance(ctx.mesh, impl.MeshInstance)
        # Derived from the validated target (identical to node.target.id for a plain name,
        # but also valid for a one-element tuple target).
        loop_name = target + "_index__"
        loop_var = expr.Expr(ctx.ast_builder.make_id_expr(""))
        ctx.create_variable(loop_name, loop_var)
        begin = expr.Expr(0)
        end = ti_ops.cast(node.iter.ptr.size, primitive_types.i32)
        for_di = _ti_core.DebugInfo(ctx.get_pos_info(node))
        ctx.ast_builder.begin_frontend_range_for(loop_var.ptr, begin.ptr, end.ptr, for_di)
        entry_expr = _ti_core.get_relation_access(
            ctx.mesh.mesh_ptr,
            node.iter.ptr.from_index.ptr,
            node.iter.ptr.to_element_type,
            loop_var.ptr,
        )
        entry_expr.type_check(impl.get_runtime().prog.config())
        mesh_idx = mesh.MeshElementFieldProxy(ctx.mesh, node.iter.ptr.to_element_type, entry_expr)
        ctx.create_variable(target, mesh_idx)
        build_stmts(ctx, node.body)
        ctx.ast_builder.end_frontend_range_for()

    return None
1075
+
1076
@staticmethod
def build_For(ctx: ASTTransformerContext, node: ast.For) -> None:
    """Dispatch a ``for`` statement to the matching loop builder.

    The decorator of the iterable (and, for a single-argument call, of its
    argument) selects the lowering:

    * ``static``                     -> ``build_static_for`` (grouped when the inner one is ``grouped``)
    * ``ndrange``                    -> ``build_ndrange_for``
    * ``grouped`` over ``ndrange``   -> ``build_grouped_ndrange_for``
    * ``grouped`` over anything else -> ``build_struct_for(is_grouped=True)``
    * a call to ``range``            -> ``build_range_for``
    * mesh element field / relation access proxy -> ``build_mesh_for`` / ``build_nested_mesh_for``
    * anything else                  -> ``build_struct_for(is_grouped=False)``

    Raises:
        GsTaichiSyntaxError: for ``for ... else``, nested ``ti.static`` or
            ``ti.grouped``, ``ti.static`` inside ``ti.grouped``, or any
            decorator inside ``ti.ndrange``.
        Exception: if the backend does not support the mesh extension.
    """
    if node.orelse:
        raise GsTaichiSyntaxError("'else' clause for 'for' not supported in GsTaichi kernels")
    decorator = get_decorator(ctx, node.iter)
    double_decorator = ""
    if decorator != "" and len(node.iter.args) == 1:
        double_decorator = get_decorator(ctx, node.iter.args[0])

    if decorator == "static":
        if double_decorator == "static":
            raise GsTaichiSyntaxError("'ti.static' cannot be nested")
        with ctx.loop_scope_guard(is_static=True):
            return ASTTransformer.build_static_for(ctx, node, double_decorator == "grouped")
    with ctx.loop_scope_guard():
        if decorator == "ndrange":
            if double_decorator != "":
                # Closing quote added; the message previously read "... inside 'ti.ndrange".
                raise GsTaichiSyntaxError("No decorator is allowed inside 'ti.ndrange'")
            return ASTTransformer.build_ndrange_for(ctx, node)
        if decorator == "grouped":
            if double_decorator == "static":
                raise GsTaichiSyntaxError("'ti.static' is not allowed inside 'ti.grouped'")
            elif double_decorator == "ndrange":
                return ASTTransformer.build_grouped_ndrange_for(ctx, node)
            elif double_decorator == "grouped":
                raise GsTaichiSyntaxError("'ti.grouped' cannot be nested")
            else:
                return ASTTransformer.build_struct_for(ctx, node, is_grouped=True)
        elif (
            isinstance(node.iter, ast.Call)
            and isinstance(node.iter.func, ast.Name)
            and node.iter.func.id == "range"
        ):
            return ASTTransformer.build_range_for(ctx, node)
        else:
            # The remaining kinds are distinguished by the type of the built iterable.
            build_stmt(ctx, node.iter)
            if isinstance(node.iter.ptr, mesh.MeshElementField):
                if not _ti_core.is_extension_supported(impl.default_cfg().arch, _ti_core.Extension.mesh):
                    raise Exception(
                        "Backend " + str(impl.default_cfg().arch) + " doesn't support MeshGsTaichi extension"
                    )
                return ASTTransformer.build_mesh_for(ctx, node)
            if isinstance(node.iter.ptr, mesh.MeshRelationAccessProxy):
                return ASTTransformer.build_nested_mesh_for(ctx, node)
            # Struct for
            return ASTTransformer.build_struct_for(ctx, node, is_grouped=False)
1122
+
1123
@staticmethod
def build_While(ctx: ASTTransformerContext, node: ast.While) -> None:
    """Build ``while test: body`` as an unconditional frontend loop with an explicit exit.

    The emitted structure is::

        while 1:
            if test: (nothing) else: break
            body

    The condition is therefore re-evaluated inside the loop on every iteration.

    Raises:
        GsTaichiSyntaxError: for ``while ... else``.
    """
    if node.orelse:
        raise GsTaichiSyntaxError("'else' clause for 'while' not supported in GsTaichi kernels")

    with ctx.loop_scope_guard():
        stmt_dbg_info = _ti_core.DebugInfo(ctx.get_pos_info(node))
        # Loop "forever" (condition is the constant i32 1); the exit is the break emitted below.
        ctx.ast_builder.begin_frontend_while(expr.Expr(1, dtype=primitive_types.i32).ptr, stmt_dbg_info)
        while_cond = build_stmt(ctx, node.test)
        impl.begin_frontend_if(ctx.ast_builder, while_cond, stmt_dbg_info)
        # True branch: intentionally empty.
        ctx.ast_builder.begin_frontend_if_true()
        ctx.ast_builder.pop_scope()
        # False branch: leave the loop.
        ctx.ast_builder.begin_frontend_if_false()
        ctx.ast_builder.insert_break_stmt(stmt_dbg_info)
        ctx.ast_builder.pop_scope()
        build_stmts(ctx, node.body)
        # Closes the scope opened by begin_frontend_while.
        ctx.ast_builder.pop_scope()
    return None
1141
+
1142
@staticmethod
def build_If(ctx: ASTTransformerContext, node: ast.If) -> ast.If | None:
    """Build an ``if`` statement.

    With a ``ti.static(...)`` condition only the selected branch is built, and
    ``node`` is returned. Otherwise a frontend if/else is emitted with both
    branches built, and ``None`` is returned.
    """
    build_stmt(ctx, node.test)
    is_static_if = get_decorator(ctx, node.test) == "static"

    if is_static_if:
        # Compile-time condition: Python truthiness of the built test picks the branch.
        if node.test.ptr:
            build_stmts(ctx, node.body)
        else:
            build_stmts(ctx, node.orelse)
        return node

    # non_static_if_guard lets build_Break / build_Continue detect a static-loop break/continue under this if.
    with ctx.non_static_if_guard(node):
        stmt_dbg_info = _ti_core.DebugInfo(ctx.get_pos_info(node))
        impl.begin_frontend_if(ctx.ast_builder, node.test.ptr, stmt_dbg_info)
        ctx.ast_builder.begin_frontend_if_true()
        build_stmts(ctx, node.body)
        ctx.ast_builder.pop_scope()
        ctx.ast_builder.begin_frontend_if_false()
        build_stmts(ctx, node.orelse)
        ctx.ast_builder.pop_scope()
    return None
1164
+
1165
@staticmethod
def build_Expr(ctx: ASTTransformerContext, node: ast.Expr) -> None:
    """Build an expression statement (e.g. a bare call) for its side effects.

    The built value is discarded, so this always returns ``None``.
    """
    build_stmt(ctx, node.value)
1169
+
1170
@staticmethod
def build_IfExp(ctx: ASTTransformerContext, node: ast.IfExp):
    """Build a conditional expression ``body if test else orelse``.

    * ``ti.static(...)`` test: only the selected branch is built and returned.
    * Tensor-typed ``body`` or ``orelse``: lowered to ``ti_ops.select``.
    * Otherwise: lowered to ``ti_ops.ifte``.

    Returns:
        The built value, also stored on ``node.ptr``.

    Raises:
        GsTaichiSyntaxError: if the (non-static) test itself is a tensor. Use
            ``ti.select`` for element-wise selection instead.
    """
    build_stmt(ctx, node.test)

    # Decide static-ness before building any branch. A compile-time condition
    # must build only the selected branch, like a static ``if`` statement.
    # Building both branches first emitted the discarded branch's code and
    # built the chosen branch twice, duplicating side-effecting calls.
    if get_decorator(ctx, node.test) == "static":
        chosen_branch = node.body if node.test.ptr else node.orelse
        node.ptr = build_stmt(ctx, chosen_branch)
        return node.ptr

    build_stmt(ctx, node.body)
    build_stmt(ctx, node.orelse)

    def _is_tensor(value) -> bool:
        return isinstance(value, expr.Expr) and value.is_tensor()

    if _is_tensor(node.test.ptr):
        raise GsTaichiSyntaxError(
            "Using conditional expression for element-wise select operation on "
            "GsTaichi vectors/matrices is deprecated and removed starting from GsTaichi v1.5.0 "
            'Please use "ti.select" instead.'
        )
    if _is_tensor(node.body.ptr) or _is_tensor(node.orelse.ptr):
        node.ptr = ti_ops.select(node.test.ptr, node.body.ptr, node.orelse.ptr)
        return node.ptr

    node.ptr = ti_ops.ifte(node.test.ptr, node.body.ptr, node.orelse.ptr)
    return node.ptr
1205
+
1206
+ @staticmethod
1207
+ def _is_string_mod_args(msg) -> bool:
1208
+ # 1. str % (a, b, c, ...)
1209
+ # 2. str % single_item
1210
+ # Note that |msg.right| may not be a tuple.
1211
+ if not isinstance(msg, ast.BinOp):
1212
+ return False
1213
+ if not isinstance(msg.op, ast.Mod):
1214
+ return False
1215
+ if isinstance(msg.left, ast.Str):
1216
+ return True
1217
+ if isinstance(msg.left, ast.Constant) and isinstance(msg.left.value, str):
1218
+ return True
1219
+ return False
1220
+
1221
@staticmethod
def _handle_string_mod_args(ctx: ASTTransformerContext, node):
    """Split a ``"fmt" % args`` node into its built format string and argument list.

    Returns:
        ``(msg, args)``. ``msg`` is the built left operand. ``args`` holds the
        ``.ptr`` of an ``expr.Expr`` wrapping each right-hand item. A right
        operand that is not a ``collections.abc.Sequence`` counts as one item.
    """
    msg = build_stmt(ctx, node.left)
    right_value = build_stmt(ctx, node.right)
    items = right_value if isinstance(right_value, collections.abc.Sequence) else (right_value,)
    return msg, [expr.Expr(item).ptr for item in items]
1229
+
1230
@staticmethod
def ti_format_list_to_assert_msg(raw) -> tuple[str, list]:
    """Turn a ``__ti_format__`` list into a printf-style message plus its arguments.

    String entries are copied verbatim. Each ``ExprCxx`` entry contributes a
    ``%f`` (real types) or ``%d`` (integer types) placeholder and is appended
    to the returned argument list. Format specs are currently ignored.

    Raises:
        GsTaichiSyntaxError: for an expression of any other data type, or an
            entry that is neither a string nor an expression.
    """
    # TODO: ignore formats here for now
    entries, _ = impl.ti_format_list_to_content_entries([raw])
    pieces = []
    args = []
    for entry in entries:
        if isinstance(entry, str):
            pieces.append(entry)
            continue
        if not isinstance(entry, _ti_core.ExprCxx):
            raise GsTaichiSyntaxError(f"Unsupported type: {type(entry)}")
        ty = entry.get_rvalue_type()
        if ty in primitive_types.real_types:
            pieces.append("%f")
        elif ty in primitive_types.integer_types:
            pieces.append("%d")
        else:
            raise GsTaichiSyntaxError(f"Unsupported data type: {type(ty)}")
        args.append(entry)
    return "".join(pieces), args
1251
+
1252
@staticmethod
def build_Assert(ctx: ASTTransformerContext, node: ast.Assert) -> None:
    """Lower ``assert test[, msg]`` to ``impl.ti_assert``.

    Accepted messages:
      * a printf-style ``"literal" % args`` expression,
      * a constant (converted with ``str``),
      * an expression that builds to a ``["__ti_format__", ...]`` list
        (a formatted string).

    Without a message, the unparsed source of the test is used. The final
    message is stripped of surrounding whitespace.

    Raises:
        GsTaichiSyntaxError: for any other kind of message.
    """
    extra_args = []
    if node.msg is None:
        msg = unparse(node.test)
    elif ASTTransformer._is_string_mod_args(node.msg):
        msg, extra_args = ASTTransformer._handle_string_mod_args(ctx, node.msg)
    else:
        msg = build_stmt(ctx, node.msg)
        # String literals are ast.Constant, so no separate ast.Str check is needed
        # (ast.Str is deprecated and removed in Python 3.14).
        if isinstance(node.msg, ast.Constant):
            msg = str(msg)
        elif isinstance(msg, collections.abc.Sequence) and len(msg) > 0 and msg[0] == "__ti_format__":
            msg, extra_args = ASTTransformer.ti_format_list_to_assert_msg(msg)
        else:
            raise GsTaichiSyntaxError(f"assert info must be constant or formatted string, not {type(msg)}")
    test = build_stmt(ctx, node.test)
    impl.ti_assert(test, msg.strip(), extra_args, _ti_core.DebugInfo(ctx.get_pos_info(node)))
    return None
1273
+
1274
@staticmethod
def build_Break(ctx: ASTTransformerContext, node: ast.Break) -> None:
    """Build ``break``.

    Outside a static ``for``, a break statement is emitted into the frontend
    IR. Inside a static (unrolled) ``for``, the loop status is set to
    ``LoopStatus.Break`` instead. That is rejected when the ``break`` sits
    under a non-static ``if``.

    Raises:
        GsTaichiSyntaxError: for a static-loop ``break`` nested in a non-static ``if``.
    """
    if not ctx.is_in_static_for():
        ctx.ast_builder.insert_break_stmt(_ti_core.DebugInfo(ctx.get_pos_info(node)))
        return None

    enclosing_dynamic_if = ctx.current_loop_scope().nearest_non_static_if
    if enclosing_dynamic_if:
        raise GsTaichiSyntaxError(
            ctx.get_pos_info(enclosing_dynamic_if.test)
            + "You are trying to `break` a static `for` loop, "
            "but the `break` statement is inside a non-static `if`. "
        )
    ctx.set_loop_status(LoopStatus.Break)
    return None
1289
+
1290
@staticmethod
def build_Continue(ctx: ASTTransformerContext, node: ast.Continue) -> None:
    """Build ``continue``.

    Outside a static ``for``, a continue statement is emitted into the
    frontend IR. Inside a static (unrolled) ``for``, the loop status is set to
    ``LoopStatus.Continue`` instead. That is rejected when the ``continue``
    sits under a non-static ``if``.

    Raises:
        GsTaichiSyntaxError: for a static-loop ``continue`` nested in a non-static ``if``.
    """
    if not ctx.is_in_static_for():
        ctx.ast_builder.insert_continue_stmt(_ti_core.DebugInfo(ctx.get_pos_info(node)))
        return None

    enclosing_dynamic_if = ctx.current_loop_scope().nearest_non_static_if
    if enclosing_dynamic_if:
        raise GsTaichiSyntaxError(
            ctx.get_pos_info(enclosing_dynamic_if.test)
            + "You are trying to `continue` a static `for` loop, "
            "but the `continue` statement is inside a non-static `if`. "
        )
    ctx.set_loop_status(LoopStatus.Continue)
    return None
1305
+
1306
@staticmethod
def build_Pass(ctx: ASTTransformerContext, node: ast.Pass) -> None:
    """Build ``pass``: nothing is emitted and the result is always ``None``."""
    return None
1309
+
1310
+
1311
# Shared transformer instance. Throughout this file it is called as
# build_stmt(ctx, node) to build a single AST node (ASTTransformer instances are callable).
build_stmt: ASTTransformer = ASTTransformer()
1312
+
1313
+
1314
def build_stmts(ctx: ASTTransformerContext, stmts: list[ast.stmt]):
    """Build a sequence of nodes, in order, inside a fresh variable scope.

    Building stops early once the context records a return (``ctx.returned``
    is no longer ``ReturnStatus.NoReturn``) or a pending static
    ``break``/``continue`` (``ctx.loop_status()`` is not ``LoopStatus.Normal``).
    Any remaining statements are skipped.

    Returns:
        ``stmts``, unchanged.
    """
    # TODO: Should we just make this part of ASTTransformer? Then, easier to pass around (just
    # pass the ASTTransformer object around)
    with ctx.variable_scope_guard():
        for statement in stmts:
            must_stop = ctx.returned != ReturnStatus.NoReturn or ctx.loop_status() != LoopStatus.Normal
            if must_stop:
                break
            build_stmt(ctx, statement)
    return stmts