gstaichi 0.1.25.dev0__cp313-cp313-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. gstaichi/__init__.py +40 -0
  2. gstaichi/__main__.py +5 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +3 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cpython-313-darwin.so +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2939 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  11. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  12. gstaichi/_lib/utils.py +249 -0
  13. gstaichi/_logging.py +131 -0
  14. gstaichi/_main.py +545 -0
  15. gstaichi/_snode/__init__.py +5 -0
  16. gstaichi/_snode/fields_builder.py +187 -0
  17. gstaichi/_snode/snode_tree.py +34 -0
  18. gstaichi/_test_tools/__init__.py +0 -0
  19. gstaichi/_test_tools/load_kernel_string.py +30 -0
  20. gstaichi/_version.py +1 -0
  21. gstaichi/_version_check.py +103 -0
  22. gstaichi/ad/__init__.py +3 -0
  23. gstaichi/ad/_ad.py +530 -0
  24. gstaichi/algorithms/__init__.py +3 -0
  25. gstaichi/algorithms/_algorithms.py +117 -0
  26. gstaichi/assets/.git +1 -0
  27. gstaichi/assets/Go-Regular.ttf +0 -0
  28. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  29. gstaichi/examples/minimal.py +28 -0
  30. gstaichi/experimental.py +16 -0
  31. gstaichi/lang/__init__.py +50 -0
  32. gstaichi/lang/_ndarray.py +352 -0
  33. gstaichi/lang/_ndrange.py +152 -0
  34. gstaichi/lang/_template_mapper.py +199 -0
  35. gstaichi/lang/_texture.py +172 -0
  36. gstaichi/lang/_wrap_inspect.py +189 -0
  37. gstaichi/lang/any_array.py +99 -0
  38. gstaichi/lang/argpack.py +411 -0
  39. gstaichi/lang/ast/__init__.py +5 -0
  40. gstaichi/lang/ast/ast_transformer.py +1318 -0
  41. gstaichi/lang/ast/ast_transformer_utils.py +341 -0
  42. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  43. gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
  44. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
  45. gstaichi/lang/ast/checkers.py +106 -0
  46. gstaichi/lang/ast/symbol_resolver.py +57 -0
  47. gstaichi/lang/ast/transform.py +9 -0
  48. gstaichi/lang/common_ops.py +310 -0
  49. gstaichi/lang/exception.py +80 -0
  50. gstaichi/lang/expr.py +180 -0
  51. gstaichi/lang/field.py +466 -0
  52. gstaichi/lang/impl.py +1241 -0
  53. gstaichi/lang/kernel_arguments.py +157 -0
  54. gstaichi/lang/kernel_impl.py +1382 -0
  55. gstaichi/lang/matrix.py +1881 -0
  56. gstaichi/lang/matrix_ops.py +341 -0
  57. gstaichi/lang/matrix_ops_utils.py +190 -0
  58. gstaichi/lang/mesh.py +687 -0
  59. gstaichi/lang/misc.py +778 -0
  60. gstaichi/lang/ops.py +1494 -0
  61. gstaichi/lang/runtime_ops.py +13 -0
  62. gstaichi/lang/shell.py +35 -0
  63. gstaichi/lang/simt/__init__.py +5 -0
  64. gstaichi/lang/simt/block.py +94 -0
  65. gstaichi/lang/simt/grid.py +7 -0
  66. gstaichi/lang/simt/subgroup.py +191 -0
  67. gstaichi/lang/simt/warp.py +96 -0
  68. gstaichi/lang/snode.py +489 -0
  69. gstaichi/lang/source_builder.py +150 -0
  70. gstaichi/lang/struct.py +855 -0
  71. gstaichi/lang/util.py +381 -0
  72. gstaichi/linalg/__init__.py +8 -0
  73. gstaichi/linalg/matrixfree_cg.py +310 -0
  74. gstaichi/linalg/sparse_cg.py +59 -0
  75. gstaichi/linalg/sparse_matrix.py +303 -0
  76. gstaichi/linalg/sparse_solver.py +123 -0
  77. gstaichi/math/__init__.py +11 -0
  78. gstaichi/math/_complex.py +205 -0
  79. gstaichi/math/mathimpl.py +886 -0
  80. gstaichi/profiler/__init__.py +6 -0
  81. gstaichi/profiler/kernel_metrics.py +260 -0
  82. gstaichi/profiler/kernel_profiler.py +586 -0
  83. gstaichi/profiler/memory_profiler.py +15 -0
  84. gstaichi/profiler/scoped_profiler.py +36 -0
  85. gstaichi/sparse/__init__.py +3 -0
  86. gstaichi/sparse/_sparse_grid.py +77 -0
  87. gstaichi/tools/__init__.py +12 -0
  88. gstaichi/tools/diagnose.py +117 -0
  89. gstaichi/tools/np2ply.py +364 -0
  90. gstaichi/tools/vtk.py +38 -0
  91. gstaichi/types/__init__.py +19 -0
  92. gstaichi/types/annotations.py +47 -0
  93. gstaichi/types/compound_types.py +90 -0
  94. gstaichi/types/enums.py +49 -0
  95. gstaichi/types/ndarray_type.py +147 -0
  96. gstaichi/types/primitive_types.py +206 -0
  97. gstaichi/types/quant.py +88 -0
  98. gstaichi/types/texture_type.py +85 -0
  99. gstaichi/types/utils.py +13 -0
  100. gstaichi-0.1.25.dev0.data/data/include/GLFW/glfw3.h +6389 -0
  101. gstaichi-0.1.25.dev0.data/data/include/GLFW/glfw3native.h +594 -0
  102. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/instrument.hpp +268 -0
  103. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.h +907 -0
  104. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.hpp +375 -0
  105. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/linker.hpp +97 -0
  106. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/optimizer.hpp +970 -0
  107. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  108. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv.h +2568 -0
  109. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv.hpp +2579 -0
  110. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  111. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  112. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  113. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  114. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  115. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  116. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  117. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  118. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  119. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  120. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  121. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  122. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  123. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  124. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  125. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  126. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  127. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  128. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  129. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  130. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  131. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  132. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  133. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  134. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  135. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  136. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  137. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  138. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  139. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  140. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  141. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  142. gstaichi-0.1.25.dev0.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  143. gstaichi-0.1.25.dev0.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  144. gstaichi-0.1.25.dev0.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  145. gstaichi-0.1.25.dev0.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  146. gstaichi-0.1.25.dev0.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  147. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  148. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  149. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  150. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  151. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  152. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  153. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  154. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  155. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  156. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  157. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  158. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  159. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  160. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  161. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  162. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  163. gstaichi-0.1.25.dev0.dist-info/METADATA +105 -0
  164. gstaichi-0.1.25.dev0.dist-info/RECORD +168 -0
  165. gstaichi-0.1.25.dev0.dist-info/WHEEL +5 -0
  166. gstaichi-0.1.25.dev0.dist-info/entry_points.txt +2 -0
  167. gstaichi-0.1.25.dev0.dist-info/licenses/LICENSE +201 -0
  168. gstaichi-0.1.25.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1318 @@
1
+ # type: ignore
2
+
3
+ import ast
4
+ import collections.abc
5
+ import itertools
6
+ import warnings
7
+ from ast import unparse
8
+ from typing import Any, Iterable, Type
9
+
10
+ import numpy as np
11
+
12
+ from gstaichi._lib import core as _ti_core
13
+ from gstaichi.lang import expr, impl, matrix, mesh
14
+ from gstaichi.lang import ops as ti_ops
15
+ from gstaichi.lang._ndrange import _Ndrange
16
+ from gstaichi.lang.ast.ast_transformer_utils import (
17
+ ASTTransformerContext,
18
+ Builder,
19
+ LoopStatus,
20
+ ReturnStatus,
21
+ get_decorator,
22
+ )
23
+ from gstaichi.lang.ast.ast_transformers.call_transformer import CallTransformer
24
+ from gstaichi.lang.ast.ast_transformers.function_def_transformer import (
25
+ FunctionDefTransformer,
26
+ )
27
+ from gstaichi.lang.exception import (
28
+ GsTaichiIndexError,
29
+ GsTaichiRuntimeTypeError,
30
+ GsTaichiSyntaxError,
31
+ GsTaichiTypeError,
32
+ handle_exception_from_cpp,
33
+ )
34
+ from gstaichi.lang.expr import Expr, make_expr_group
35
+ from gstaichi.lang.field import Field
36
+ from gstaichi.lang.matrix import Matrix, MatrixType
37
+ from gstaichi.lang.snode import append, deactivate, length
38
+ from gstaichi.lang.struct import Struct, StructType
39
+ from gstaichi.types import primitive_types
40
+ from gstaichi.types.utils import is_integral
41
+
42
+
43
def reshape_list(flat_list: list[Any], target_shape: Iterable[int]) -> list[Any]:
    """Nest ``flat_list`` (row-major order) into lists matching ``target_shape``.

    The last dimension varies fastest, e.g. ``reshape_list([1, 2, 3, 4], (2, 2))``
    returns ``[[1, 2], [3, 4]]``. A shape with fewer than two dimensions returns
    ``flat_list`` unchanged. The element count is not validated against the
    product of ``target_shape``.

    Args:
        flat_list: Elements in row-major order.
        target_shape: Any iterable of dimension sizes (tuple, list, iterator, ...).

    Returns:
        The nested list, or ``flat_list`` itself for 0-d/1-d shapes.
    """
    # Materialise once: the annotation promises any iterable, but the recursion
    # below needs ``len`` and slicing, which one-shot iterators do not support.
    shape = tuple(target_shape)
    if len(shape) < 2:
        return flat_list

    grouped = []
    dim = shape[-1]
    for i, elem in enumerate(flat_list):
        if i % dim == 0:
            grouped.append([])
        grouped[-1].append(elem)

    return reshape_list(grouped, shape[:-1])
55
+
56
+
57
def boundary_type_cast_warning(expression: Expr) -> None:
    """Warn when a range-for boundary is implicitly narrowed to ``i32``.

    A ``Warning`` is issued for non-integral boundary dtypes and for the
    integer dtypes ``i64``, ``u64`` and ``u32``; other integer dtypes pass
    silently.
    """
    dtype = expression.ptr.get_rvalue_type()
    wide_integers = (primitive_types.i64, primitive_types.u64, primitive_types.u32)
    lossy = (not is_integral(dtype)) or (dtype in wide_integers)
    if not lossy:
        return
    warnings.warn(
        f"Casting range_for boundary values from {dtype} to i32, which may cause numerical issues",
        Warning,
    )
68
+
69
+
70
+ class ASTTransformer(Builder):
71
@staticmethod
def build_Name(ctx: ASTTransformerContext, node: ast.Name):
    """Resolve a bare name to the value bound to it in the current scope.

    When the resolved value is an ``Expr``, it is tagged with debug info that
    points at the name's source position.
    """
    resolved = ctx.get_var_by_name(node.id)
    node.ptr = resolved
    is_ast_node = isinstance(node, (ast.stmt, ast.expr))
    if is_ast_node and isinstance(resolved, Expr):
        resolved.dbg_info = _ti_core.DebugInfo(ctx.get_pos_info(node))
        resolved.ptr.set_dbg_info(resolved.dbg_info)
    return node.ptr
78
+
79
@staticmethod
def build_AnnAssign(ctx: ASTTransformerContext, node: ast.AnnAssign):
    """Build an annotated assignment ``target: annotation = value``.

    Returns:
        The variable produced by ``build_assign_annotated``.

    Raises:
        GsTaichiSyntaxError: if the annotation carries no value (a bare
            declaration such as ``x: ti.f32``).
    """
    if node.value is None:
        # Without a value there is nothing to cast/assign, and the code below
        # would fail while reading ``node.value.ptr``; report it explicitly.
        raise GsTaichiSyntaxError(
            f"Annotated declaration without a value is not supported: '{unparse(node).strip()}'"
        )
    build_stmt(ctx, node.value)
    build_stmt(ctx, node.annotation)

    # ``x: T = ti.static(...)`` is detected here and rejected downstream.
    is_static_assign = isinstance(node.value, ast.Call) and node.value.func.ptr is impl.static

    node.ptr = ASTTransformer.build_assign_annotated(
        ctx, node.target, node.value.ptr, is_static_assign, node.annotation.ptr
    )
    return node.ptr
90
+
91
@staticmethod
def build_assign_annotated(
    ctx: ASTTransformerContext, target: ast.Name, value, is_static_assign: bool, annotation: Type
):
    """Build an annotated assignment like this: target: annotation = value.

    Args:
        ctx (ast_builder_utils.BuilderContext): The builder context.
        target (ast.Name): A variable name. `target.id` holds the name as
            a string.
        value: A node representing the value.
        is_static_assign: A boolean value indicating whether this is a static assignment
        annotation: A type we hope to assign to the target

    Returns:
        The newly declared variable, or the existing target after assignment.

    Raises:
        GsTaichiSyntaxError: when assigning to a kernel argument, when used
            with ``ti.static``, or when an already-declared target's type
            differs from ``annotation``.
    """
    is_local = isinstance(target, ast.Name)
    if is_local and target.id in ctx.kernel_args:
        raise GsTaichiSyntaxError(
            f'Kernel argument "{target.id}" is immutable in the kernel. '
            "If you want to change its value, please create a new variable."
        )
    # Reject static assigns before doing any work on the annotation.
    if is_static_assign:
        raise GsTaichiSyntaxError("Static assign cannot be used on annotated assignment")
    anno = impl.expr_init(annotation)
    if is_local and not ctx.is_var_declared(target.id):
        # First declaration: cast the value to the annotated type, then bind it.
        var = ti_ops.cast(value, anno)
        var = impl.expr_init(var)
        ctx.create_variable(target.id, var)
    else:
        var = build_stmt(ctx, target)
        if var.ptr.get_rvalue_type() != anno:
            # Not a static assign (rejected above): the existing variable simply
            # has a different type than the new annotation.
            raise GsTaichiSyntaxError(
                f"Annotated assignment to '{unparse(target).strip()}' conflicts with its existing type; "
                "re-annotating a variable with a different type is not supported"
            )
        var._assign(value)
    return var
124
+
125
@staticmethod
def build_Assign(ctx: ASTTransformerContext, node: ast.Assign) -> None:
    """Build ``t1 = t2 = ... = value``.

    The right-hand side is built once. For a non-static assignment it is
    materialised with ``impl.expr_init`` before any target is bound, so every
    target of a chained assignment is assigned from the same value.
    Ref https://github.com/taichi-dev/gstaichi/issues/2659.
    """
    rhs = node.value
    build_stmt(ctx, rhs)
    static = isinstance(rhs, ast.Call) and rhs.func.ptr is impl.static
    shared_value = rhs.ptr if static else impl.expr_init(rhs.ptr)

    for target in node.targets:
        ASTTransformer.build_assign_unpack(ctx, target, shared_value, static)
    return None
138
+
139
@staticmethod
def build_assign_unpack(ctx: ASTTransformerContext, node_target: list | ast.Tuple, values, is_static_assign: bool):
    """Assign ``values`` to ``node_target``, unpacking when the target is a tuple.

    Non-tuple targets are forwarded unchanged to ``build_assign_basic``. For a
    tuple target such as ``a, b = rhs``, the right-hand side is flattened
    first. A single-column ``Matrix`` yields its entries. Tensor- or
    struct-typed ``Expr`` values are expanded into their component expressions.

    Args:
        ctx (ast_builder_utils.BuilderContext): The builder context.
        node_target (ast.Tuple): The assignment target; when it is a tuple,
            ``node_target.elts`` holds the element targets.
        values: A node/list representing the values.
        is_static_assign: Whether this is a static assignment.

    Raises:
        ValueError: if a matrix with more than one column is unpacked.
        GsTaichiSyntaxError: if the value cannot be unpacked, or its length
            differs from the number of targets.
    """
    if not isinstance(node_target, ast.Tuple):
        return ASTTransformer.build_assign_basic(ctx, node_target, values, is_static_assign)

    def expand_components(compound):
        # Expand a compound Expr into its component exprs; one component is unwrapped.
        parts = ctx.ast_builder.expand_exprs([compound.ptr])
        return parts[0] if len(parts) == 1 else parts

    if isinstance(values, matrix.Matrix):
        if values.m != 1:
            raise ValueError("Matrices with more than one columns cannot be unpacked")
        values = values.entries

    # Unpack: a, b, c = ti.Vector([1., 2., 3.])
    if isinstance(values, impl.Expr) and values.ptr.is_tensor():
        if len(values.get_shape()) > 1:
            raise ValueError("Matrices with more than one columns cannot be unpacked")
        values = expand_components(values)

    if isinstance(values, impl.Expr) and values.ptr.is_struct():
        values = expand_components(values)

    if not isinstance(values, collections.abc.Sequence):
        raise GsTaichiSyntaxError(f"Cannot unpack type: {type(values)}")

    targets = node_target.elts
    if len(values) != len(targets):
        raise GsTaichiSyntaxError("The number of targets is not equal to value length")

    for element_target, element_value in zip(targets, values):
        ASTTransformer.build_assign_basic(ctx, element_target, element_value, is_static_assign)
    return None
184
+
185
@staticmethod
def build_assign_basic(ctx: ASTTransformerContext, target: ast.Name, value, is_static_assign: bool):
    """Build basic assignment like this: target = value.

    Args:
        ctx (ast_builder_utils.BuilderContext): The builder context.
        target (ast.Name): A variable name. `target.id` holds the name as
            a string.
        value: A node representing the value.
        is_static_assign: A boolean value indicating whether this is a static assignment

    Returns:
        The object now bound to ``target``. For a static assign this is
        ``value`` itself.

    Raises:
        GsTaichiSyntaxError: when assigning to a kernel argument, when a
            static assign targets something other than a plain name, or when
            ``_assign`` on the existing target raises ``AttributeError``.
    """
    is_local = isinstance(target, ast.Name)
    if is_local and target.id in ctx.kernel_args:
        raise GsTaichiSyntaxError(
            f'Kernel argument "{target.id}" is immutable in the kernel. '
            "If you want to change its value, please create a new variable."
        )
    if is_static_assign:
        if not is_local:
            raise GsTaichiSyntaxError("Static assign cannot be used on elements in arrays")
        # Static assigns bind the value directly, without impl.expr_init.
        ctx.create_variable(target.id, value)
        var = value
    elif is_local and not ctx.is_var_declared(target.id):
        # First assignment to a local name declares a new variable.
        var = impl.expr_init(value)
        ctx.create_variable(target.id, var)
    else:
        var = build_stmt(ctx, target)
        try:
            var._assign(value)
        except AttributeError as err:
            # Chain the cause so an AttributeError raised *inside* _assign is
            # not hidden behind this message.
            raise GsTaichiSyntaxError(
                f"Variable '{unparse(target).strip()}' cannot be assigned. Maybe it is not a GsTaichi object?"
            ) from err
    return var
219
+
220
@staticmethod
def build_NamedExpr(ctx: ASTTransformerContext, node: ast.NamedExpr):
    """Build a walrus expression ``target := value`` and return the bound result."""
    value_node = node.value
    build_stmt(ctx, value_node)
    static = isinstance(value_node, ast.Call) and value_node.func.ptr is impl.static
    bound = ASTTransformer.build_assign_basic(ctx, node.target, value_node.ptr, static)
    node.ptr = bound
    return bound
226
+
227
+ @staticmethod
228
+ def is_tuple(node):
229
+ if isinstance(node, ast.Tuple):
230
+ return True
231
+ if isinstance(node, ast.Index) and isinstance(node.value.ptr, tuple):
232
+ return True
233
+ if isinstance(node.ptr, tuple):
234
+ return True
235
+ return False
236
+
237
@staticmethod
def build_Subscript(ctx: ASTTransformerContext, node: ast.Subscript):
    """Build ``value[index]`` via ``impl.subscript``.

    A non-tuple index is wrapped in a one-element list (stored back on
    ``node.slice.ptr``), so the indices can always be splatted.
    """
    base, index = node.value, node.slice
    build_stmt(ctx, base)
    build_stmt(ctx, index)
    if not ASTTransformer.is_tuple(index):
        index.ptr = [index.ptr]
    node.ptr = impl.subscript(ctx.ast_builder, base.ptr, *index.ptr)
    return node.ptr
245
+
246
@staticmethod
def build_Slice(ctx: ASTTransformerContext, node: ast.Slice):
    """Build ``lower:upper:step`` into a Python ``slice``; omitted parts become None."""
    parts = (node.lower, node.upper, node.step)
    for part in parts:
        if part is not None:
            build_stmt(ctx, part)
    node.ptr = slice(*[None if part is None else part.ptr for part in parts])
    return node.ptr
261
+
262
@staticmethod
def build_ExtSlice(ctx: ASTTransformerContext, node: ast.ExtSlice):
    """Build a legacy extended slice into a tuple of its built dimensions."""
    dims = node.dims
    build_stmts(ctx, dims)
    node.ptr = tuple([d.ptr for d in dims])
    return node.ptr
267
+
268
@staticmethod
def build_Tuple(ctx: ASTTransformerContext, node: ast.Tuple):
    """Build a tuple display into a Python tuple of the built elements."""
    elements = node.elts
    build_stmts(ctx, elements)
    built = tuple([e.ptr for e in elements])
    node.ptr = built
    return built
273
+
274
@staticmethod
def build_List(ctx: ASTTransformerContext, node: ast.List):
    """Build a list display into a Python list of the built elements."""
    elements = node.elts
    build_stmts(ctx, elements)
    built = list(e.ptr for e in elements)
    node.ptr = built
    return built
279
+
280
@staticmethod
def build_Dict(ctx: ASTTransformerContext, node: ast.Dict):
    """Build a dict display; ``**mapping`` entries (key is None) are merged in."""
    result = {}
    for key_node, value_node in zip(node.keys, node.values):
        if key_node is not None:
            # Same statement form as before: Python evaluates the value
            # before the subscript key here, so build order is unchanged.
            result[build_stmt(ctx, key_node)] = build_stmt(ctx, value_node)
        else:
            result.update(build_stmt(ctx, value_node))
    node.ptr = result
    return result
290
+
291
@staticmethod
def process_listcomp(ctx: ASTTransformerContext, node, result) -> None:
    """Build one list-comprehension element and append it to ``result``."""
    element = build_stmt(ctx, node.elt)
    result.append(element)
294
+
295
@staticmethod
def process_dictcomp(ctx: ASTTransformerContext, node, result) -> None:
    """Build one dict-comprehension entry (key first, then value) into ``result``."""
    entry_key = build_stmt(ctx, node.key)
    result.update({entry_key: build_stmt(ctx, node.value)})
300
+
301
@staticmethod
def process_generators(
    ctx: ASTTransformerContext, node: ast.ListComp | ast.DictComp | ast.GeneratorExp, now_comp, func, result
):
    """Statically unroll the ``now_comp``-th ``for`` clause of a comprehension.

    Together with ``process_ifs``, this recursion walks ``node.generators`` in
    order. Once every generator has been consumed, ``func(ctx, node, result)``
    is called to build one element into ``result``.

    Args:
        ctx: The builder context.
        node: The comprehension node. ``build_ListComp``/``build_DictComp`` pass
            their ``ast.ListComp``/``ast.DictComp``.
        now_comp: Index into ``node.generators`` of the clause to unroll.
        func: Callback that builds one element, e.g. ``process_listcomp``.
        result: Accumulator (list or dict) handed through to ``func``.

    Returns:
        ``func``'s return value once all generators are consumed; otherwise None.
    """
    if now_comp >= len(node.generators):
        return func(ctx, node, result)
    # The iterable is built in static scope; the loop below is a plain Python
    # loop, i.e. the comprehension is unrolled while building.
    with ctx.static_scope_guard():
        _iter = build_stmt(ctx, node.generators[now_comp].iter)

    # A tensor-typed Expr is expanded into its element Exprs and regrouped to
    # the tensor's shape, so it can be iterated like nested Python lists.
    if isinstance(_iter, impl.Expr) and _iter.ptr.is_tensor():
        shape = _iter.ptr.get_shape()
        flattened = [Expr(x) for x in ctx.ast_builder.expand_exprs([_iter.ptr])]
        _iter = reshape_list(flattened, shape)

    for value in _iter:
        # Each iteration binds the loop target(s) in a fresh variable scope as a
        # static assignment (is_static_assign=True: the value is bound directly).
        with ctx.variable_scope_guard():
            ASTTransformer.build_assign_unpack(ctx, node.generators[now_comp].target, value, True)
            # The `if` filters are built in static scope; process_ifs then
            # reads their built values.
            with ctx.static_scope_guard():
                build_stmts(ctx, node.generators[now_comp].ifs)
            ASTTransformer.process_ifs(ctx, node, now_comp, 0, func, result)
    return None
320
+
321
@staticmethod
def process_ifs(
    ctx: ASTTransformerContext, node: ast.ListComp | ast.DictComp | ast.GeneratorExp, now_comp, now_if, func, result
):
    """Apply the ``if`` filters of the ``now_comp``-th comprehension clause.

    The filter expressions must already be built; ``process_generators``
    builds them before calling here. Each filter is tested with Python
    truthiness on its ``.ptr``. If every filter from ``now_if`` onward passes,
    recursion continues with the next generator clause. The first failing
    filter drops the current element.

    Returns:
        The result of ``process_generators`` when ``now_if`` is already past the
        last filter; otherwise None.
    """
    if now_if >= len(node.generators[now_comp].ifs):
        # All filters of this clause passed: move on to the next clause.
        return ASTTransformer.process_generators(ctx, node, now_comp + 1, func, result)
    cond = node.generators[now_comp].ifs[now_if].ptr
    if cond:
        ASTTransformer.process_ifs(ctx, node, now_comp, now_if + 1, func, result)

    return None
330
+
331
@staticmethod
def build_ListComp(ctx: ASTTransformerContext, node: ast.ListComp):
    """Statically unroll a list comprehension into a Python list."""
    collected = []
    ASTTransformer.process_generators(ctx, node, 0, ASTTransformer.process_listcomp, collected)
    node.ptr = collected
    return collected
337
+
338
@staticmethod
def build_DictComp(ctx: ASTTransformerContext, node: ast.DictComp):
    """Statically unroll a dict comprehension into a Python dict."""
    collected = {}
    ASTTransformer.process_generators(ctx, node, 0, ASTTransformer.process_dictcomp, collected)
    node.ptr = collected
    return collected
344
+
345
@staticmethod
def build_Index(ctx: ASTTransformerContext, node: ast.Index):
    """Build a legacy ``ast.Index`` wrapper by building the value it wraps."""
    inner = build_stmt(ctx, node.value)
    node.ptr = inner
    return inner
349
+
350
@staticmethod
def build_Constant(ctx: ASTTransformerContext, node: ast.Constant):
    """Return a literal's Python value unchanged."""
    literal = node.value
    node.ptr = literal
    return literal
354
+
355
@staticmethod
def build_Num(ctx: ASTTransformerContext, node: "ast.Num"):
    """Legacy handler for numeric literal nodes; returns the literal's value.

    ``ast.Num`` and ``Constant.n`` are deprecated aliases (Python 3.12 warns on
    access; 3.14 removes them). The annotation is therefore a string, so it is
    not evaluated while the class body runs, and the value is read via
    ``.value``.
    """
    node.ptr = node.value
    return node.ptr
359
+
360
@staticmethod
def build_Str(ctx: ASTTransformerContext, node: "ast.Str"):
    """Legacy handler for string literal nodes; returns the string.

    ``ast.Str`` and ``Constant.s`` are deprecated aliases (Python 3.12 warns on
    access; 3.14 removes them). The annotation is therefore a string, and the
    value is read via ``.value``.
    """
    node.ptr = node.value
    return node.ptr
364
+
365
@staticmethod
def build_Bytes(ctx: ASTTransformerContext, node: "ast.Bytes"):
    """Legacy handler for bytes literal nodes; returns the bytes value.

    ``ast.Bytes`` and ``Constant.s`` are deprecated aliases (Python 3.12 warns
    on access; 3.14 removes them). The annotation is therefore a string, and
    the value is read via ``.value``.
    """
    node.ptr = node.value
    return node.ptr
369
+
370
@staticmethod
def build_NameConstant(ctx: ASTTransformerContext, node: "ast.NameConstant"):
    """Legacy handler for ``True``/``False``/``None`` nodes; returns the value.

    The annotation is a string so the deprecated ``ast.NameConstant`` alias
    (Python 3.12 warns on access; 3.14 removes it) is not evaluated while the
    class body runs.
    """
    node.ptr = node.value
    return node.ptr
374
+
375
@staticmethod
def build_keyword(ctx: ASTTransformerContext, node: ast.keyword):
    """Build a call keyword argument.

    Returns ``{name: value}`` for ``name=value``. For ``**mapping``, where
    ``node.arg`` is None, it returns the built value itself.
    """
    build_stmt(ctx, node.value)
    built = node.value.ptr
    node.ptr = built if node.arg is None else {node.arg: built}
    return node.ptr
383
+
384
@staticmethod
def build_Starred(ctx: ASTTransformerContext, node: ast.Starred):
    """Build ``*value`` as its inner value; the star itself is not applied here."""
    inner = build_stmt(ctx, node.value)
    node.ptr = inner
    return inner
388
+
389
@staticmethod
def build_FormattedValue(ctx: ASTTransformerContext, node: ast.FormattedValue):
    """Build one replacement field ``{value[:spec]}`` of an f-string.

    Returns:
        The built value when there is no format spec. Otherwise, the marker list
        ``["__ti_fmt_value__", value, spec]``, which distinguishes it from a
        normal list.

    Raises:
        GsTaichiSyntaxError: if the format spec is not a single literal string
            (e.g. a nested spec such as ``{x:{width}}``).
    """
    node.ptr = build_stmt(ctx, node.value)
    if node.format_spec is None or len(node.format_spec.values) == 0:
        return node.ptr
    values = node.format_spec.values
    # Explicit check instead of `assert`: it survives `python -O`, and it
    # rejects nested (non-literal) specs instead of crashing on a missing attribute.
    if len(values) != 1 or not isinstance(values[0], ast.Constant) or not isinstance(values[0].value, str):
        raise GsTaichiSyntaxError("Only literal format specs are supported in f-strings")
    # `.value` rather than the deprecated `Constant.s` alias.
    format_str = values[0].value
    # distinguished from normal list
    return ["__ti_fmt_value__", node.ptr, format_str]
400
+
401
@staticmethod
def build_JoinedStr(ctx: ASTTransformerContext, node: ast.JoinedStr):
    """Build an f-string into an ``impl.ti_format`` call.

    Literal parts are concatenated into a spec string, and each replacement
    field contributes ``"{}"`` plus one built argument. The call is
    ``impl.ti_format(spec, *args)``.

    Raises:
        GsTaichiSyntaxError: if a part is neither a literal nor a replacement field.
    """
    spec_parts = []
    args = []
    for sub_node in node.values:
        if isinstance(sub_node, ast.FormattedValue):
            spec_parts.append("{}")
            args.append(build_stmt(ctx, sub_node))
        elif isinstance(sub_node, ast.Constant):
            # On Python >= 3.8, literal parts are always ast.Constant. The former
            # ast.Str branch was dead, and that name is deprecated/removed.
            # NOTE(review): literal braces (from `{{`/`}}`) are appended unescaped;
            # if impl.ti_format applies str.format to the spec they would be
            # misread as fields. Confirm against ti_format.
            spec_parts.append(sub_node.value)
        else:
            raise GsTaichiSyntaxError("Invalid value for fstring.")

    node.ptr = impl.ti_format("".join(spec_parts), *args)
    return node.ptr
419
+
420
@staticmethod
def build_Call(ctx: ASTTransformerContext, node: ast.Call) -> Any | None:
    """Delegate call building to ``CallTransformer``.

    This module's ``build_stmt``/``build_stmts`` are passed as callbacks.
    """
    transformer = CallTransformer
    return transformer.build_Call(ctx, node, build_stmt, build_stmts)
423
+
424
@staticmethod
def build_FunctionDef(ctx: ASTTransformerContext, node: ast.FunctionDef) -> None:
    """Delegate function-definition building to ``FunctionDefTransformer``."""
    FunctionDefTransformer.build_FunctionDef(ctx, node, build_stmts)
    return None
427
+
428
@staticmethod
def build_Return(ctx: ASTTransformerContext, node: ast.Return) -> None:
    """Build a ``return`` statement.

    Kernels and real functions: each returned value is validated against the
    matching entry of ``ctx.func.return_type`` (primitive, matrix or struct),
    cast, and emitted via ``ctx.ast_builder.create_kernel_exprgroup_return``.

    Other functions: the value is stored in ``ctx.return_data``, with
    primitive-typed entries cast.

    Unless this is a real function, ``ctx.returned`` records whether a value
    or nothing was returned.

    Args:
        ctx: The builder context.
        node: The ``return`` node.

    Raises:
        GsTaichiSyntaxError: raised for
            - ``return`` inside non-static control flow of a non-real function;
            - a missing return annotation;
            - a count mismatch between returned values and annotated types;
            - an unsupported return type.
        GsTaichiRuntimeTypeError: a returned value does not match its
            annotated type, ndim or shape.
    """
    if not ctx.is_real_function:
        if ctx.is_in_non_static_control_flow():
            raise GsTaichiSyntaxError("Return inside non-static if/for is not supported")
    if node.value is not None:
        build_stmt(ctx, node.value)
    if node.value is None or node.value.ptr is None:
        if not ctx.is_real_function:
            ctx.returned = ReturnStatus.ReturnedVoid
        return None
    if ctx.is_kernel or ctx.is_real_function:
        # TODO: check if it's at the end of a kernel, throw GsTaichiSyntaxError if not
        if ctx.func.return_type is None:
            raise GsTaichiSyntaxError(
                f'A {"kernel" if ctx.is_kernel else "function"} '
                "with a return value must be annotated "
                "with a return type, e.g. def func() -> ti.f32"
            )
        return_exprs = []
        if len(ctx.func.return_type) == 1:
            node.value.ptr = [node.value.ptr]
        # Explicit error instead of `assert`: under `python -O` the assert
        # vanishes, and the zip below would silently drop surplus values/types.
        if len(ctx.func.return_type) != len(node.value.ptr):
            raise GsTaichiSyntaxError(
                f"Expected {len(ctx.func.return_type)} return value(s), got {len(node.value.ptr)}"
            )
        for return_type, ptr in zip(ctx.func.return_type, node.value.ptr):
            if id(return_type) in primitive_types.type_ids:
                if isinstance(ptr, Expr):
                    if ptr.is_tensor() or ptr.is_struct() or ptr.element_type() not in primitive_types.all_types:
                        raise GsTaichiRuntimeTypeError.get_ret(str(return_type), ptr)
                elif not isinstance(ptr, (float, int, np.floating, np.integer)):
                    raise GsTaichiRuntimeTypeError.get_ret(str(return_type), ptr)
                return_exprs += [ti_ops.cast(expr.Expr(ptr), return_type).ptr]
            elif isinstance(return_type, MatrixType):
                values = ptr
                if isinstance(values, Matrix):
                    # Compare with *this* entry's type. `ctx.func.return_type` is
                    # the whole sequence of return types and has no `ndim`.
                    if values.ndim != return_type.ndim:
                        raise GsTaichiRuntimeTypeError(
                            f"Return matrix ndim mismatch, expecting={return_type.ndim}, got={values.ndim}."
                        )
                    elif return_type.get_shape() != values.get_shape():
                        raise GsTaichiRuntimeTypeError(
                            f"Return matrix shape mismatch, expecting={return_type.get_shape()}, got={values.get_shape()}."
                        )
                    values = (
                        itertools.chain.from_iterable(values.to_list())
                        if values.ndim == 1
                        else iter(values.to_list())
                    )
                elif isinstance(values, Expr):
                    if not values.is_tensor():
                        raise GsTaichiRuntimeTypeError.get_ret(return_type.to_string(), ptr)
                    elif (
                        return_type.dtype in primitive_types.real_types
                        and values.element_type() not in primitive_types.all_types
                    ):
                        raise GsTaichiRuntimeTypeError.get_ret(return_type.dtype.to_string(), values.element_type())
                    elif (
                        return_type.dtype in primitive_types.integer_types
                        and values.element_type() not in primitive_types.integer_types
                    ):
                        raise GsTaichiRuntimeTypeError.get_ret(return_type.dtype.to_string(), values.element_type())
                    elif len(values.get_shape()) != return_type.ndim:
                        raise GsTaichiRuntimeTypeError(
                            f"Return matrix ndim mismatch, expecting={return_type.ndim}, got={len(values.get_shape())}."
                        )
                    elif return_type.get_shape() != values.get_shape():
                        raise GsTaichiRuntimeTypeError(
                            f"Return matrix shape mismatch, expecting={return_type.get_shape()}, got={values.get_shape()}."
                        )
                    values = [values]
                else:
                    np_array = np.array(values)
                    dt, shape, ndim = np_array.dtype, np_array.shape, np_array.ndim
                    # Classify by dtype kind. The former `dt in (float, int,
                    # np.floating, np.integer)` compared a numpy dtype against
                    # Python/abstract types. It only reliably matched the
                    # platform-default float/int dtypes (rejecting e.g. float32).
                    is_int_dt = np.issubdtype(dt, np.integer)
                    is_real_dt = is_int_dt or np.issubdtype(dt, np.floating)
                    if return_type.dtype in primitive_types.real_types and not is_real_dt:
                        raise GsTaichiRuntimeTypeError.get_ret(return_type.dtype.to_string(), dt)
                    elif return_type.dtype in primitive_types.integer_types and not is_int_dt:
                        raise GsTaichiRuntimeTypeError.get_ret(return_type.dtype.to_string(), dt)
                    elif ndim != return_type.ndim:
                        raise GsTaichiRuntimeTypeError(
                            f"Return matrix ndim mismatch, expecting={return_type.ndim}, got={ndim}."
                        )
                    elif return_type.get_shape() != shape:
                        raise GsTaichiRuntimeTypeError(
                            f"Return matrix shape mismatch, expecting={return_type.get_shape()}, got={shape}."
                        )
                    values = [values]
                return_exprs += [ti_ops.cast(exp, return_type.dtype) for exp in values]
            elif isinstance(return_type, StructType):
                if not isinstance(ptr, Struct) or not isinstance(ptr, return_type):
                    raise GsTaichiRuntimeTypeError.get_ret(str(return_type), ptr)
                values = ptr
                assert isinstance(values, Struct)
                return_exprs += expr._get_flattened_ptrs(values)
            else:
                raise GsTaichiSyntaxError("The return type is not supported now!")
        ctx.ast_builder.create_kernel_exprgroup_return(
            expr.make_expr_group(return_exprs), _ti_core.DebugInfo(ctx.get_pos_info(node))
        )
    else:
        ctx.return_data = node.value.ptr
        if ctx.func.return_type is not None:
            if len(ctx.func.return_type) == 1:
                ctx.return_data = [ctx.return_data]
            # `return a, b` builds a tuple, which does not support item
            # assignment. Cast on a list copy, then restore the tuple type.
            was_tuple = isinstance(ctx.return_data, tuple)
            if was_tuple:
                ctx.return_data = list(ctx.return_data)
            for i, return_type in enumerate(ctx.func.return_type):
                if id(return_type) in primitive_types.type_ids:
                    ctx.return_data[i] = ti_ops.cast(ctx.return_data[i], return_type)
            if was_tuple:
                ctx.return_data = tuple(ctx.return_data)
            if len(ctx.func.return_type) == 1:
                ctx.return_data = ctx.return_data[0]
    if not ctx.is_real_function:
        ctx.returned = ReturnStatus.ReturnedValue
    return None
543
+
544
@staticmethod
def build_Module(ctx: ASTTransformerContext, node: ast.Module) -> None:
    """Build every top-level statement of a module inside a single variable scope.

    |build_stmts| is avoided on purpose: per the original note it would insert 'del'
    statements at the end and delete the parameters passed into the module; it also
    opens its own nested scope and stops early once a return/loop status is set.
    Here every statement is built, in order, in the one scope opened below.
    """
    with ctx.variable_scope_guard():
        module_body = node.body
        for top_level_stmt in module_body:
            build_stmt(ctx, top_level_stmt)
    return None
552
+
553
@staticmethod
def build_attribute_if_is_dynamic_snode_method(ctx: ASTTransformerContext, node) -> bool:
    """Try to resolve ``x[...].append/deactivate/length`` on a dynamic SNode field.

    Returns True and stores a callable in ``node.ptr`` when all of these hold:
    - ``node.attr`` is one of "append", "deactivate", "length";
    - the base is a ``Field`` whose parent SNode has type ``dynamic``;
    - the field has exactly one more active index than the number of indices supplied.

    Otherwise returns False and leaves ``node.ptr`` untouched.
    """
    if node.attr not in ("append", "deactivate", "length"):
        return False
    base = node.value
    if isinstance(base, ast.Subscript):
        field, index_list = base.value.ptr, base.slice.ptr
    else:
        # Bare ``x.append`` behaves like ``x[()].append``.
        field, index_list = base.ptr, []
    if not isinstance(field, Field):
        return False
    if not (field.parent().ptr.type == _ti_core.SNodeType.dynamic):
        return False
    # The method applies to the one axis left unindexed, so exactly one index must be missing.
    if field.snode.ptr.num_active_indices() != make_expr_group(*index_list).size() + 1:
        return False
    dynamic_methods = {
        "append": lambda val: append(field.parent(), index_list, val),
        "deactivate": lambda: deactivate(field.parent(), index_list),
        "length": lambda: length(field.parent(), index_list),
    }
    node.ptr = dynamic_methods[node.attr]
    return True
581
+
582
@staticmethod
def build_Attribute(ctx: ASTTransformerContext, node: ast.Attribute):
    """Build an attribute access ``value.attr`` and return the resolved object (also stored in ``node.ptr``).

    Resolution order:
    1. Dynamic-SNode methods (``append``/``deactivate``/``length``), detected either after an
       index-dimension error from building ``node.value`` or after a successful build.
    2. For an ``Expr`` without a Python attribute named ``attr``:
       - swizzles (``v.xy``, ``v.x``, ...) known to ``Matrix._swizzle_to_keygroup`` become subscripts;
       - any other name is looked up in ``matrix_ops``, with ``node.caller`` set to the receiver.
    3. Otherwise a plain ``getattr`` on the built value.

    Exceptions other than a ``GsTaichiIndexError`` that turns out to be a dynamic-SNode method
    call are re-raised (after ``handle_exception_from_cpp`` conversion).
    """
    # There are two valid cases for the methods of Dynamic SNode:
    #
    # 1. x[i, j].append (where the dimension of the field (3 in this case) is equal to one plus the number of the
    #    indices (2 in this case) )
    #
    # 2. x.append (where the dimension of the field is one, equal to x[()].append)
    #
    # For the first case, the AST (simplified) is like node = Attribute(value=Subscript(value=x, slice=[i, j]),
    # attr="append"), when we build_stmt(node.value)(build the expression of the Subscript i.e. x[i, j]),
    # it should build the expression of node.value.value (i.e. x) and node.value.slice (i.e. [i, j]), and raise a
    # GsTaichiIndexError because the dimension of the field is not equal to the number of the indices. Therefore,
    # when we meet the error, we can detect whether it is a method of Dynamic SNode and build the expression if
    # it is by calling build_attribute_if_is_dynamic_snode_method. If we find that it is not a method of Dynamic
    # SNode, we raise the error again.
    #
    # For the second case, the AST (simplified) is like node = Attribute(value=x, attr="append"), and it does not
    # raise error when we build_stmt(node.value). Therefore, when we do not meet the error, we can also detect
    # whether it is a method of Dynamic SNode and build the expression if it is by calling
    # build_attribute_if_is_dynamic_snode_method. If we find that it is not a method of Dynamic SNode,
    # we continue to process it as a normal attribute node.
    try:
        build_stmt(ctx, node.value)
    except Exception as e:
        e = handle_exception_from_cpp(e)
        if isinstance(e, GsTaichiIndexError):
            # The subscript failed to build, so there is no value for the Subscript node itself.
            node.value.ptr = None
            if ASTTransformer.build_attribute_if_is_dynamic_snode_method(ctx, node):
                return node.ptr
        raise e

    if ASTTransformer.build_attribute_if_is_dynamic_snode_method(ctx, node):
        return node.ptr

    if isinstance(node.value.ptr, Expr) and not hasattr(node.value.ptr, node.attr):
        if node.attr in Matrix._swizzle_to_keygroup:
            keygroup = Matrix._swizzle_to_keygroup[node.attr]
            # Validates the swizzle against the operand (the checker's exact rules live in Matrix).
            Matrix._keygroup_to_checker[keygroup](node.value.ptr, node.attr)
            attr_len = len(node.attr)
            if attr_len == 1:
                # Single component, e.g. ``v.x``: a scalar subscript.
                node.ptr = Expr(
                    impl.get_runtime()
                    .compiling_callable.ast_builder()
                    .expr_subscript(
                        node.value.ptr.ptr,
                        make_expr_group(keygroup.index(node.attr)),
                        _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
                    )
                )
            else:
                # Multi-component swizzle, e.g. ``v.xy``: one index group per character,
                # producing a result of shape (attr_len,).
                node.ptr = Expr(
                    _ti_core.subscript_with_multiple_indices(
                        node.value.ptr.ptr,
                        [make_expr_group(keygroup.index(ch)) for ch in node.attr],
                        (attr_len,),
                        _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
                    )
                )
        else:
            from gstaichi.lang import (  # pylint: disable=C0415
                matrix_ops as tensor_ops,
            )

            # Method-style call on a tensor expr, e.g. ``v.norm()``: resolve to the free function
            # in matrix_ops and remember the receiver on the node as ``caller``.
            node.ptr = getattr(tensor_ops, node.attr)
            setattr(node, "caller", node.value.ptr)
    else:
        node.ptr = getattr(node.value.ptr, node.attr)
    return node.ptr
651
+
652
@staticmethod
def build_BinOp(ctx: ASTTransformerContext, node: ast.BinOp):
    """Build a binary arithmetic/bitwise/matmul expression such as ``a + b`` or ``a @ b``.

    The left operand is built first, then the right. The matching Python operator is then
    applied to the built values, so operator overloading on those values produces the result.
    A ``TypeError`` from that application is re-raised as ``GsTaichiTypeError``
    (original context suppressed).
    """
    build_stmt(ctx, node.left)
    build_stmt(ctx, node.right)
    # pylint: disable-msg=C0415
    from gstaichi.lang.matrix_ops import matmul

    binary_handlers = {
        ast.Add: lambda lhs, rhs: lhs + rhs,
        ast.Sub: lambda lhs, rhs: lhs - rhs,
        ast.Mult: lambda lhs, rhs: lhs * rhs,
        ast.Div: lambda lhs, rhs: lhs / rhs,
        ast.FloorDiv: lambda lhs, rhs: lhs // rhs,
        ast.Mod: lambda lhs, rhs: lhs % rhs,
        ast.Pow: lambda lhs, rhs: lhs**rhs,
        ast.LShift: lambda lhs, rhs: lhs << rhs,
        ast.RShift: lambda lhs, rhs: lhs >> rhs,
        ast.BitOr: lambda lhs, rhs: lhs | rhs,
        ast.BitXor: lambda lhs, rhs: lhs ^ rhs,
        ast.BitAnd: lambda lhs, rhs: lhs & rhs,
        ast.MatMult: matmul,
    }
    handler = binary_handlers.get(type(node.op))
    lhs_value, rhs_value = node.left.ptr, node.right.ptr
    try:
        node.ptr = handler(lhs_value, rhs_value)
    except TypeError as err:
        raise GsTaichiTypeError(str(err)) from None
    return node.ptr
679
+
680
@staticmethod
def build_AugAssign(ctx: ASTTransformerContext, node: ast.AugAssign):
    """Build an augmented assignment such as ``x += y``.

    Target and value are built (in that order) before validation. Augmented assignment
    to a kernel argument is rejected because kernel arguments are immutable.
    Otherwise the target's ``_augassign`` is called with the operator class name
    (e.g. "Add").

    Raises:
        GsTaichiSyntaxError: if the target is a kernel argument name.
    """
    build_stmt(ctx, node.target)
    build_stmt(ctx, node.value)
    target = node.target
    assigns_kernel_arg = isinstance(target, ast.Name) and target.id in ctx.kernel_args
    if assigns_kernel_arg:
        raise GsTaichiSyntaxError(
            f'Kernel argument "{target.id}" is immutable in the kernel. '
            f"If you want to change its value, please create a new variable."
        )
    operator_name = type(node.op).__name__
    node.ptr = target.ptr._augassign(node.value.ptr, operator_name)
    return node.ptr
691
+
692
@staticmethod
def build_UnaryOp(ctx: ASTTransformerContext, node: ast.UnaryOp):
    """Build a unary expression: ``+x`` (returned unchanged), ``-x``, ``not x``, or ``~x``.

    ``not`` is routed to ``ti_ops.logical_not``; the others use Python operators on the
    built operand.
    """
    build_stmt(ctx, node.operand)
    unary_handlers = {
        ast.UAdd: lambda value: value,
        ast.USub: lambda value: -value,
        ast.Not: ti_ops.logical_not,
        ast.Invert: lambda value: ~value,
    }
    handler = unary_handlers.get(type(node.op))
    operand_value = node.operand.ptr
    node.ptr = handler(operand_value)
    return node.ptr
703
+
704
@staticmethod
def build_bool_op(op):
    """Return a function that right-folds |op| over the ``.ptr`` values of its operands.

    For operands ``[a, b, c]`` the result is ``op(a.ptr, op(b.ptr, c.ptr))``. The innermost
    application is evaluated first, and a single operand yields its ``.ptr`` unchanged.
    """

    def fold_right(operands):
        accumulated = operands[-1].ptr
        for operand in reversed(operands[:-1]):
            accumulated = op(operand.ptr, accumulated)
        return accumulated

    return fold_right
712
+
713
@staticmethod
def build_static_and(operands):
    """Evaluate Python ``and`` semantics at compile time over the operands' ``.ptr`` values.

    Returns the first falsy value, or the last value if all are truthy.
    """
    for candidate in operands:
        if not candidate.ptr:
            break
    else:
        candidate = operands[-1]
    return candidate.ptr
719
+
720
@staticmethod
def build_static_or(operands):
    """Evaluate Python ``or`` semantics at compile time over the operands' ``.ptr`` values.

    Returns the first truthy value, or the last value if none is truthy.
    """
    for candidate in operands:
        if candidate.ptr:
            break
    else:
        candidate = operands[-1]
    return candidate.ptr
726
+
727
@staticmethod
def build_BoolOp(ctx: ASTTransformerContext, node: ast.BoolOp):
    """Build ``a and b ...`` / ``a or b ...``.

    All operands are built first. The combining strategy then depends on context:
    - inside a static scope: Python truthiness semantics (``build_static_and``/``build_static_or``);
    - with runtime ``short_circuit_operators`` enabled: ``ti_ops.logical_and``/``logical_or``;
    - otherwise: ``ti_ops.bit_and``/``bit_or``.
    """
    build_stmts(ctx, node.values)
    if ctx.is_in_static_scope():
        and_builder = ASTTransformer.build_static_and
        or_builder = ASTTransformer.build_static_or
    elif impl.get_runtime().short_circuit_operators:
        and_builder = ASTTransformer.build_bool_op(ti_ops.logical_and)
        or_builder = ASTTransformer.build_bool_op(ti_ops.logical_or)
    else:
        and_builder = ASTTransformer.build_bool_op(ti_ops.bit_and)
        or_builder = ASTTransformer.build_bool_op(ti_ops.bit_or)
    combine = {ast.And: and_builder, ast.Or: or_builder}.get(type(node.op))
    node.ptr = combine(node.values)
    return node.ptr
748
+
749
@staticmethod
def build_Compare(ctx: ASTTransformerContext, node: ast.Compare):
    """Build a (possibly chained) comparison such as ``a < b <= c``.

    Each adjacent pair is compared, and the results are combined with ``ti_ops.logical_and``.
    ``in``/``not in`` are accepted only inside a static scope, and ``is``/``is not`` are always
    rejected. A non-Python-bool result is cast to ``u1``.

    Raises:
        GsTaichiSyntaxError: for ``is``/``is not``, for membership tests outside ``ti.static``,
            or for any other unsupported comparison operator.
    """
    build_stmt(ctx, node.left)
    build_stmts(ctx, node.comparators)
    comparison_ops = {
        ast.Eq: lambda lhs, rhs: lhs == rhs,
        ast.NotEq: lambda lhs, rhs: lhs != rhs,
        ast.Lt: lambda lhs, rhs: lhs < rhs,
        ast.LtE: lambda lhs, rhs: lhs <= rhs,
        ast.Gt: lambda lhs, rhs: lhs > rhs,
        ast.GtE: lambda lhs, rhs: lhs >= rhs,
    }
    membership_ops = {
        ast.In: lambda lhs, rhs: lhs in rhs,
        ast.NotIn: lambda lhs, rhs: lhs not in rhs,
    }
    if ctx.is_in_static_scope():
        comparison_ops.update(membership_ops)
    values = [node.left.ptr, *(comparator.ptr for comparator in node.comparators)]
    result = True
    for cmp_op, lhs, rhs in zip(node.ops, values, values[1:]):
        if isinstance(cmp_op, (ast.Is, ast.IsNot)):
            name = "is" if isinstance(cmp_op, ast.Is) else "is not"
            raise GsTaichiSyntaxError(f'Operator "{name}" in GsTaichi scope is not supported.')
        compare = comparison_ops.get(type(cmp_op))
        if compare is None:
            if type(cmp_op) in membership_ops:
                raise GsTaichiSyntaxError(f'"{type(cmp_op).__name__}" is only supported inside `ti.static`.')
            raise GsTaichiSyntaxError(f'"{type(cmp_op).__name__}" is not supported in GsTaichi kernels.')
        result = ti_ops.logical_and(result, compare(lhs, rhs))
    if not isinstance(result, (bool, np.bool_)):
        result = ti_ops.cast(result, primitive_types.u1)
    node.ptr = result
    return node.ptr
787
+
788
@staticmethod
def get_for_loop_targets(node: ast.For) -> list[str]:
    """
    Returns the list of target names of the for loop |node|.
    See also: https://docs.python.org/3/library/ast.html#ast.For

    ``for i in ...`` yields ``["i"]``. ``for i, j in ...`` and ``for [i, j] in ...`` yield
    ``["i", "j"]``.

    Raises:
        GsTaichiSyntaxError: if the target is not a plain name or a flat tuple/list of plain
            names (e.g. ``for a, (b, c) in ...`` or ``for x.attr in ...``).
    """
    target = node.target
    if isinstance(target, ast.Name):
        return [target.id]
    # An explicit error instead of an ``assert``: this validates user code, and asserts
    # vanish under ``python -O`` (leaving an opaque AttributeError on ``.elts``/``.id``).
    if isinstance(target, (ast.Tuple, ast.List)) and all(isinstance(elt, ast.Name) for elt in target.elts):
        return [elt.id for elt in target.elts]
    raise GsTaichiSyntaxError(
        "For-loop targets must be a name or a flat tuple of names, "
        f"found {type(target).__name__}"
    )
798
+
799
@staticmethod
def build_static_for(ctx: ASTTransformerContext, node: ast.For, is_grouped: bool) -> None:
    """Unroll a ``for ... in ti.static(...)`` loop at compile time.

    The body is built once per iteration value, each time in a fresh variable scope with the
    loop target(s) bound to that Python-side value. ``break``/``continue`` in the body are
    tracked through ``ctx.loop_status()`` rather than emitted as IR. A SyntaxWarning is emitted
    once when the number of unrolled iterations exceeds the runtime's ``unrolling_limit``.

    Args:
        ctx: transformer context.
        node: the ``ast.For`` whose iterator is ``ti.static(...)``.
        is_grouped: True for ``ti.static(ti.grouped(ti.ndrange(...)))``; the single loop target
            is then bound to each value produced by ``impl.grouped``.

    Raises:
        GsTaichiSyntaxError: if ``ti.grouped`` does not receive exactly one argument, that argument
            is not a ``ti.ndrange``, or a grouped loop does not have exactly one target.
    """
    ti_unroll_limit = impl.get_runtime().unrolling_limit
    if is_grouped:
        grouped_args = node.iter.args[0].args
        # Explicit check rather than ``assert`` so the user-facing error survives ``python -O``.
        if len(grouped_args) != 1:
            raise GsTaichiSyntaxError(
                f"'ti.grouped' inside 'ti.static' expects exactly 1 argument, found {len(grouped_args)}"
            )
        ndrange_arg = build_stmt(ctx, grouped_args[0])
        if not isinstance(ndrange_arg, _Ndrange):
            raise GsTaichiSyntaxError("Only 'ti.ndrange' is allowed in 'ti.static(ti.grouped(...))'.")
        targets = ASTTransformer.get_for_loop_targets(node)
        if len(targets) != 1:
            raise GsTaichiSyntaxError(f"Group for should have 1 loop target, found {len(targets)}")
        target = targets[0]
        iter_time = 0
        alert_already = False

        for value in impl.grouped(ndrange_arg):
            iter_time += 1
            alert_already = ASTTransformer._warn_if_unrolling_limit_exceeded(
                ctx, node, iter_time, ti_unroll_limit, alert_already
            )

            with ctx.variable_scope_guard():
                ctx.create_variable(target, value)
                build_stmts(ctx, node.body)
                status = ctx.loop_status()
                if status == LoopStatus.Break:
                    break
                elif status == LoopStatus.Continue:
                    ctx.set_loop_status(LoopStatus.Normal)
    else:
        build_stmt(ctx, node.iter)
        targets = ASTTransformer.get_for_loop_targets(node)

        iter_time = 0
        alert_already = False
        for target_values in node.iter.ptr:
            # A single target receives the whole item; non-sequence items are wrapped so the
            # zip below works uniformly.
            if not isinstance(target_values, collections.abc.Sequence) or len(targets) == 1:
                target_values = [target_values]

            iter_time += 1
            alert_already = ASTTransformer._warn_if_unrolling_limit_exceeded(
                ctx, node, iter_time, ti_unroll_limit, alert_already
            )

            with ctx.variable_scope_guard():
                for target, target_value in zip(targets, target_values):
                    ctx.create_variable(target, target_value)
                build_stmts(ctx, node.body)
                status = ctx.loop_status()
                if status == LoopStatus.Break:
                    break
                elif status == LoopStatus.Continue:
                    ctx.set_loop_status(LoopStatus.Normal)
    return None

@staticmethod
def _warn_if_unrolling_limit_exceeded(
    ctx: ASTTransformerContext, node: ast.For, iter_time: int, unroll_limit, already_warned: bool
) -> bool:
    """Warn once per static loop when more than |unroll_limit| iterations are unrolled.

    A falsy |unroll_limit| (e.g. ``ti.init(unrolling_limit=0)``) disables the warning.
    Returns the updated "already warned" flag for the caller to carry across iterations.
    """
    if already_warned or not unroll_limit or iter_time <= unroll_limit:
        return already_warned
    warnings.warn_explicit(
        "You are unrolling more than\n"
        f"{unroll_limit} iterations, so the compile time may be extremely long.\n"
        "You can use a non-static for loop if you want to decrease the compile time.\n"
        "You can disable this warning by setting ti.init(unrolling_limit=0).",
        SyntaxWarning,
        ctx.file,
        node.lineno + ctx.lineno_offset,
        module="gstaichi",
    )
    return True
871
+
872
@staticmethod
def build_range_for(ctx: ASTTransformerContext, node: ast.For) -> None:
    """Build a runtime ``for i in range(end)`` / ``for i in range(begin, end)`` loop.

    The loop variable is declared before the bounds are built, and both bounds are cast to
    ``i32`` (``begin`` defaults to 0). The body is built between
    ``begin_frontend_range_for`` and ``end_frontend_range_for``. A step argument is not supported.

    Raises:
        GsTaichiSyntaxError: if ``range`` is given a number of arguments other than 1 or 2.
    """
    with ctx.variable_scope_guard():
        loop_name = node.target.id
        ctx.check_loop_var(loop_name)
        loop_var = expr.Expr(ctx.ast_builder.make_id_expr(""))
        ctx.create_variable(loop_name, loop_var)
        if len(node.iter.args) not in [1, 2]:
            raise GsTaichiSyntaxError(f"Range should have 1 or 2 arguments, found {len(node.iter.args)}")
        if len(node.iter.args) == 2:
            begin_expr = expr.Expr(build_stmt(ctx, node.iter.args[0]))
            end_expr = expr.Expr(build_stmt(ctx, node.iter.args[1]))

            # Warning for implicit dtype conversion
            # (presumably fires when a bound's dtype is not already i32 — see boundary_type_cast_warning).
            boundary_type_cast_warning(begin_expr)
            boundary_type_cast_warning(end_expr)

            begin = ti_ops.cast(begin_expr, primitive_types.i32)
            end = ti_ops.cast(end_expr, primitive_types.i32)

        else:
            end_expr = expr.Expr(build_stmt(ctx, node.iter.args[0]))

            # Warning for implicit dtype conversion
            boundary_type_cast_warning(end_expr)

            begin = ti_ops.cast(expr.Expr(0), primitive_types.i32)
            end = ti_ops.cast(end_expr, primitive_types.i32)

        for_di = _ti_core.DebugInfo(ctx.get_pos_info(node))
        ctx.ast_builder.begin_frontend_range_for(loop_var.ptr, begin.ptr, end.ptr, for_di)
        build_stmts(ctx, node.body)
        ctx.ast_builder.end_frontend_range_for()
    return None
906
+
907
@staticmethod
def build_ndrange_for(ctx: ASTTransformerContext, node: ast.For) -> None:
    """Build ``for i, j, ... in ti.ndrange(...)`` as one flat range-for plus index decoding.

    A single i32 range-for runs from 0 to ``ndrange_var.acc_dimensions[0]``. Inside it, each
    loop target is recovered from the flat index ``I``:
    - integer-divide by ``acc_dimensions[i + 1]`` (except for the last target);
    - add the lower bound ``bounds[i][0]``;
    - subtract the consumed part from ``I``.

    NOTE(review): this presumes ``acc_dimensions`` holds suffix products of the per-axis
    extents (so ``acc_dimensions[0]`` is the total iteration count) — confirm in ``_Ndrange``.

    Raises:
        GsTaichiSyntaxError: if the number of loop targets differs from the ndrange dimension.
    """
    with ctx.variable_scope_guard():
        ndrange_var = impl.expr_init(build_stmt(ctx, node.iter))
        ndrange_begin = ti_ops.cast(expr.Expr(0), primitive_types.i32)
        ndrange_end = ti_ops.cast(
            expr.Expr(impl.subscript(ctx.ast_builder, ndrange_var.acc_dimensions, 0)),
            primitive_types.i32,
        )
        ndrange_loop_var = expr.Expr(ctx.ast_builder.make_id_expr(""))
        for_di = _ti_core.DebugInfo(ctx.get_pos_info(node))
        ctx.ast_builder.begin_frontend_range_for(ndrange_loop_var.ptr, ndrange_begin.ptr, ndrange_end.ptr, for_di)
        # Mutable copy of the flat loop index, reduced as each target is decoded.
        I = impl.expr_init(ndrange_loop_var)
        targets = ASTTransformer.get_for_loop_targets(node)
        if len(targets) != len(ndrange_var.dimensions):
            raise GsTaichiSyntaxError(
                "Ndrange for loop with number of the loop variables not equal to "
                "the dimension of the ndrange is not supported. "
                "Please check if the number of arguments of ti.ndrange() is equal to "
                "the number of the loop variables."
            )
        for i, target in enumerate(targets):
            if i + 1 < len(targets):
                target_tmp = impl.expr_init(I // ndrange_var.acc_dimensions[i + 1])
            else:
                # Last axis: whatever remains of I is the offset along it.
                target_tmp = impl.expr_init(I)
            ctx.create_variable(
                target,
                impl.expr_init(
                    target_tmp
                    + impl.subscript(
                        ctx.ast_builder,
                        impl.subscript(ctx.ast_builder, ndrange_var.bounds, i),
                        0,
                    )
                ),
            )
            if i + 1 < len(targets):
                I._assign(I - target_tmp * ndrange_var.acc_dimensions[i + 1])
        build_stmts(ctx, node.body)
        ctx.ast_builder.end_frontend_range_for()
    return None
949
+
950
@staticmethod
def build_grouped_ndrange_for(ctx: ASTTransformerContext, node: ast.For) -> None:
    """Build ``for I in ti.grouped(ti.ndrange(...))``.

    Like ``build_ndrange_for``, this emits one flat i32 range-for up to
    ``acc_dimensions[0]``. The decoded per-axis indices (plus each axis's lower bound) are
    written into the components of an i32 vector bound to the single target ``I``.

    Raises:
        GsTaichiSyntaxError: if the loop does not have exactly one target.
    """
    with ctx.variable_scope_guard():
        ndrange_var = impl.expr_init(build_stmt(ctx, node.iter.args[0]))
        ndrange_begin = ti_ops.cast(expr.Expr(0), primitive_types.i32)
        ndrange_end = ti_ops.cast(
            expr.Expr(impl.subscript(ctx.ast_builder, ndrange_var.acc_dimensions, 0)),
            primitive_types.i32,
        )
        ndrange_loop_var = expr.Expr(ctx.ast_builder.make_id_expr(""))
        for_di = _ti_core.DebugInfo(ctx.get_pos_info(node))
        ctx.ast_builder.begin_frontend_range_for(ndrange_loop_var.ptr, ndrange_begin.ptr, ndrange_end.ptr, for_di)

        targets = ASTTransformer.get_for_loop_targets(node)
        if len(targets) != 1:
            raise GsTaichiSyntaxError(f"Group for should have 1 loop target, found {len(targets)}")
        target = targets[0]
        # One i32 component per ndrange axis, zero-initialised, then filled in below.
        mat = matrix.make_matrix([0] * len(ndrange_var.dimensions), dt=primitive_types.i32)
        target_var = impl.expr_init(mat)

        ctx.create_variable(target, target_var)
        # Mutable copy of the flat loop index, reduced as each component is decoded.
        I = impl.expr_init(ndrange_loop_var)
        for i in range(len(ndrange_var.dimensions)):
            if i + 1 < len(ndrange_var.dimensions):
                target_tmp = I // ndrange_var.acc_dimensions[i + 1]
            else:
                target_tmp = I
            impl.subscript(ctx.ast_builder, target_var, i)._assign(target_tmp + ndrange_var.bounds[i][0])
            if i + 1 < len(ndrange_var.dimensions):
                I._assign(I - target_tmp * ndrange_var.acc_dimensions[i + 1])
        build_stmts(ctx, node.body)
        ctx.ast_builder.end_frontend_range_for()
    return None
983
+
984
@staticmethod
def build_struct_for(ctx: ASTTransformerContext, node: ast.For, is_grouped: bool) -> None:
    """Build a struct-for over a field-like iterable.

    Two forms are handled:
    - ``for i, j in x``: one index variable per target.
    - ``for I in ti.grouped(x)``: one variable per axis of ``loop_var.shape``, exposed to the
      body as an i32 vector bound to the single target.

    Every target is first checked with ``ctx.check_loop_var``.

    Raises:
        GsTaichiSyntaxError: if a grouped loop does not have exactly one target.
    """
    # for i, j in x
    # for I in ti.grouped(x)
    targets = ASTTransformer.get_for_loop_targets(node)

    for target in targets:
        ctx.check_loop_var(target)

    with ctx.variable_scope_guard():
        if is_grouped:
            if len(targets) != 1:
                raise GsTaichiSyntaxError(f"Group for should have 1 loop target, found {len(targets)}")
            target = targets[0]
            loop_var = build_stmt(ctx, node.iter)
            loop_indices = expr.make_var_list(size=len(loop_var.shape), ast_builder=ctx.ast_builder)
            expr_group = expr.make_expr_group(loop_indices)
            impl.begin_frontend_struct_for(ctx.ast_builder, expr_group, loop_var)
            ctx.create_variable(target, matrix.make_matrix(loop_indices, dt=primitive_types.i32))
            build_stmts(ctx, node.body)
            ctx.ast_builder.end_frontend_struct_for()
        else:
            _vars = []
            for name in targets:
                var = expr.Expr(ctx.ast_builder.make_id_expr(""))
                _vars.append(var)
                ctx.create_variable(name, var)
            # node.iter is not built here: build_For already built it (to test for mesh
            # iterables) before dispatching to this non-grouped path.
            loop_var = node.iter.ptr
            expr_group = expr.make_expr_group(*_vars)
            impl.begin_frontend_struct_for(ctx.ast_builder, expr_group, loop_var)
            build_stmts(ctx, node.body)
            ctx.ast_builder.end_frontend_struct_for()
    return None
1017
+
1018
@staticmethod
def build_mesh_for(ctx: ASTTransformerContext, node: ast.For) -> None:
    """Build a mesh-for loop over the elements of a mesh field (``for e in mesh.<elements>``).

    ``node.iter`` has already been built by ``build_For``. The single loop target is bound to a
    ``MeshElementFieldProxy`` for the iterated element type, and ``ctx.mesh`` points at the
    mesh while the body is built.

    Raises:
        GsTaichiSyntaxError: if the loop does not have exactly one target.
    """
    targets = ASTTransformer.get_for_loop_targets(node)
    if len(targets) != 1:
        # f-string so the actual target count is reported (it was printed literally before).
        raise GsTaichiSyntaxError(f"Mesh for should have 1 loop target, found {len(targets)}")
    target = targets[0]

    with ctx.variable_scope_guard():
        var = expr.Expr(ctx.ast_builder.make_id_expr(""))
        ctx.mesh = node.iter.ptr.mesh
        assert isinstance(ctx.mesh, impl.MeshInstance)
        mesh_idx = mesh.MeshElementFieldProxy(ctx.mesh, node.iter.ptr._type, var.ptr)
        ctx.create_variable(target, mesh_idx)
        ctx.ast_builder.begin_frontend_mesh_for(
            mesh_idx.ptr,
            ctx.mesh.mesh_ptr,
            node.iter.ptr._type,
            _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
        )
        build_stmts(ctx, node.body)
        ctx.mesh = None
        ctx.ast_builder.end_frontend_mesh_for()
    return None
1041
+
1042
@staticmethod
def build_nested_mesh_for(ctx: ASTTransformerContext, node: ast.For) -> None:
    """Build a nested mesh-for over a mesh relation (``for n in e.<relation>``).

    This emits an i32 range-for from 0 to the relation's ``size``. Each iteration binds the
    single target to a ``MeshElementFieldProxy`` for the related element, obtained via
    ``_ti_core.get_relation_access``. A hidden loop variable named ``<target>_index__`` is
    created for the range index.

    Raises:
        GsTaichiSyntaxError: if the loop does not have exactly one target.
    """
    targets = ASTTransformer.get_for_loop_targets(node)
    if len(targets) != 1:
        # f-string so the actual target count is reported (it was printed literally before).
        raise GsTaichiSyntaxError(f"Nested-mesh for should have 1 loop target, found {len(targets)}")
    target = targets[0]

    with ctx.variable_scope_guard():
        ctx.mesh = node.iter.ptr.mesh
        assert isinstance(ctx.mesh, impl.MeshInstance)
        loop_name = node.target.id + "_index__"
        loop_var = expr.Expr(ctx.ast_builder.make_id_expr(""))
        ctx.create_variable(loop_name, loop_var)
        begin = expr.Expr(0)
        end = ti_ops.cast(node.iter.ptr.size, primitive_types.i32)
        for_di = _ti_core.DebugInfo(ctx.get_pos_info(node))
        ctx.ast_builder.begin_frontend_range_for(loop_var.ptr, begin.ptr, end.ptr, for_di)
        entry_expr = _ti_core.get_relation_access(
            ctx.mesh.mesh_ptr,
            node.iter.ptr.from_index.ptr,
            node.iter.ptr.to_element_type,
            loop_var.ptr,
        )
        entry_expr.type_check(impl.get_runtime().prog.config())
        mesh_idx = mesh.MeshElementFieldProxy(ctx.mesh, node.iter.ptr.to_element_type, entry_expr)
        ctx.create_variable(target, mesh_idx)
        build_stmts(ctx, node.body)
        ctx.ast_builder.end_frontend_range_for()

    return None
1072
+
1073
@staticmethod
def build_For(ctx: ASTTransformerContext, node: ast.For) -> None:
    """Dispatch a ``for`` statement to the matching loop builder.

    Dispatch is based on the iterator's decorator (and the decorator of its single argument):
    - ``ti.static(...)`` / ``ti.static(ti.grouped(...))`` -> ``build_static_for`` (unrolled);
    - ``ti.ndrange(...)`` -> ``build_ndrange_for``;
    - ``ti.grouped(ti.ndrange(...))`` -> ``build_grouped_ndrange_for``;
    - ``ti.grouped(x)`` -> grouped ``build_struct_for``;
    - ``range(...)`` -> ``build_range_for``;
    - otherwise the iterator is built and routed to mesh-for, nested-mesh-for, or struct-for.

    Raises:
        GsTaichiSyntaxError: for a ``for ... else``, nested ``ti.static``/``ti.grouped``, or
            disallowed decorator combinations.
        Exception: if a mesh-for is used on a backend without the mesh extension.
    """
    if node.orelse:
        raise GsTaichiSyntaxError("'else' clause for 'for' not supported in GsTaichi kernels")
    decorator = get_decorator(ctx, node.iter)
    double_decorator = ""
    if decorator != "" and len(node.iter.args) == 1:
        double_decorator = get_decorator(ctx, node.iter.args[0])

    if decorator == "static":
        if double_decorator == "static":
            raise GsTaichiSyntaxError("'ti.static' cannot be nested")
        with ctx.loop_scope_guard(is_static=True):
            return ASTTransformer.build_static_for(ctx, node, double_decorator == "grouped")
    with ctx.loop_scope_guard():
        if decorator == "ndrange":
            if double_decorator != "":
                # Closing quote restored around 'ti.ndrange'.
                raise GsTaichiSyntaxError("No decorator is allowed inside 'ti.ndrange'")
            return ASTTransformer.build_ndrange_for(ctx, node)
        if decorator == "grouped":
            if double_decorator == "static":
                raise GsTaichiSyntaxError("'ti.static' is not allowed inside 'ti.grouped'")
            elif double_decorator == "ndrange":
                return ASTTransformer.build_grouped_ndrange_for(ctx, node)
            elif double_decorator == "grouped":
                raise GsTaichiSyntaxError("'ti.grouped' cannot be nested")
            else:
                return ASTTransformer.build_struct_for(ctx, node, is_grouped=True)
        elif (
            isinstance(node.iter, ast.Call)
            and isinstance(node.iter.func, ast.Name)
            and node.iter.func.id == "range"
        ):
            return ASTTransformer.build_range_for(ctx, node)
        else:
            # The iterator must be built to tell mesh iterables apart from ordinary fields;
            # the non-grouped struct-for path relies on node.iter.ptr being set here.
            build_stmt(ctx, node.iter)
            if isinstance(node.iter.ptr, mesh.MeshElementField):
                if not _ti_core.is_extension_supported(impl.default_cfg().arch, _ti_core.Extension.mesh):
                    raise Exception(
                        "Backend " + str(impl.default_cfg().arch) + " doesn't support MeshGsTaichi extension"
                    )
                return ASTTransformer.build_mesh_for(ctx, node)
            if isinstance(node.iter.ptr, mesh.MeshRelationAccessProxy):
                return ASTTransformer.build_nested_mesh_for(ctx, node)
            # Struct for
            return ASTTransformer.build_struct_for(ctx, node, is_grouped=False)
1119
+
1120
@staticmethod
def build_While(ctx: ASTTransformerContext, node: ast.While) -> None:
    """Build a runtime ``while`` loop.

    The loop is lowered as ``while (1) { if (test) {} else { break; } <body> }``, so the
    condition is re-evaluated at the top of every iteration.

    Raises:
        GsTaichiSyntaxError: for a ``while ... else``.
    """
    if node.orelse:
        raise GsTaichiSyntaxError("'else' clause for 'while' not supported in GsTaichi kernels")

    with ctx.loop_scope_guard():
        stmt_dbg_info = _ti_core.DebugInfo(ctx.get_pos_info(node))
        # Unconditional loop; termination comes from the break inserted below.
        ctx.ast_builder.begin_frontend_while(expr.Expr(1, dtype=primitive_types.i32).ptr, stmt_dbg_info)
        while_cond = build_stmt(ctx, node.test)
        impl.begin_frontend_if(ctx.ast_builder, while_cond, stmt_dbg_info)
        # Empty true-branch: fall through to the body.
        ctx.ast_builder.begin_frontend_if_true()
        ctx.ast_builder.pop_scope()
        # False-branch: leave the loop.
        ctx.ast_builder.begin_frontend_if_false()
        ctx.ast_builder.insert_break_stmt(stmt_dbg_info)
        ctx.ast_builder.pop_scope()
        build_stmts(ctx, node.body)
        # Closes the scope opened by begin_frontend_while.
        ctx.ast_builder.pop_scope()
    return None
1138
+
1139
@staticmethod
def build_If(ctx: ASTTransformerContext, node: ast.If) -> ast.If | None:
    """Build an ``if`` statement.

    For ``if ti.static(cond):`` only the branch selected at compile time is built, and the
    node itself is returned. Otherwise a runtime if/else is emitted inside
    ``ctx.non_static_if_guard``, with both branches built into their own frontend scopes,
    and None is returned.
    """
    build_stmt(ctx, node.test)

    if get_decorator(ctx, node.test) == "static":
        chosen_branch = node.body if node.test.ptr else node.orelse
        build_stmts(ctx, chosen_branch)
        return node

    with ctx.non_static_if_guard(node):
        if_dbg_info = _ti_core.DebugInfo(ctx.get_pos_info(node))
        impl.begin_frontend_if(ctx.ast_builder, node.test.ptr, if_dbg_info)
        # then-branch
        ctx.ast_builder.begin_frontend_if_true()
        build_stmts(ctx, node.body)
        ctx.ast_builder.pop_scope()
        # else-branch (may be empty)
        ctx.ast_builder.begin_frontend_if_false()
        build_stmts(ctx, node.orelse)
        ctx.ast_builder.pop_scope()
    return None
1161
+
1162
@staticmethod
def build_Expr(ctx: ASTTransformerContext, node: ast.Expr) -> None:
    """Build an expression statement for its side effects; the resulting value is discarded."""
    build_stmt(ctx, node.value)
1166
+
1167
@staticmethod
def build_IfExp(ctx: ASTTransformerContext, node: ast.IfExp):
    """Build a conditional expression ``body if test else orelse``.

    - ``a if ti.static(cond) else b``: only the branch selected at compile time is built.
      The discarded branch may therefore contain code that is invalid for the current
      instantiation.
    - Otherwise both branches are built. If either branch value is a tensor
      (vector/matrix) ``Expr``, the result is ``ti_ops.select``; otherwise it is
      ``ti_ops.ifte``.

    Raises:
        GsTaichiSyntaxError: if the (non-static) condition is itself a tensor; element-wise
            selection must be written with ``ti.select``.
    """
    build_stmt(ctx, node.test)

    # Decide static-ness before touching the branches. Building both branches first (as
    # before) could fail on the discarded branch and built the chosen one twice,
    # emitting duplicate IR.
    if get_decorator(ctx, node.test) == "static":
        chosen_branch = node.body if node.test.ptr else node.orelse
        node.ptr = build_stmt(ctx, chosen_branch)
        return node.ptr

    build_stmt(ctx, node.body)
    build_stmt(ctx, node.orelse)

    def _is_tensor(value) -> bool:
        return isinstance(value, expr.Expr) and value.is_tensor()

    if _is_tensor(node.test.ptr):
        raise GsTaichiSyntaxError(
            "Using conditional expression for element-wise select operation on "
            "GsTaichi vectors/matrices is deprecated and removed starting from GsTaichi v1.5.0 "
            'Please use "ti.select" instead.'
        )
    if _is_tensor(node.body.ptr) or _is_tensor(node.orelse.ptr):
        node.ptr = ti_ops.select(node.test.ptr, node.body.ptr, node.orelse.ptr)
        return node.ptr

    node.ptr = ti_ops.ifte(node.test.ptr, node.body.ptr, node.orelse.ptr)
    return node.ptr
1202
+
1203
@staticmethod
def _is_string_mod_args(msg) -> bool:
    """Return True if |msg| is a printf-style ``"literal" % args`` AST expression.

    Accepted shapes:
      1. str % (a, b, c, ...)
      2. str % single_item
    Note that |msg.right| may not be a tuple.
    """
    # String literals are parsed as ``ast.Constant`` (Python 3.8+). The deprecated ``ast.Str``
    # alias it used to test for is redundant and is removed in Python 3.14.
    return (
        isinstance(msg, ast.BinOp)
        and isinstance(msg.op, ast.Mod)
        and isinstance(msg.left, ast.Constant)
        and isinstance(msg.left.value, str)
    )
1217
+
1218
@staticmethod
def _handle_string_mod_args(ctx: ASTTransformerContext, node):
    """Split a ``"fmt" % args`` assertion message into the format and a list of expr handles.

    A non-sequence right-hand side is treated as a single argument. Each argument is
    wrapped in ``expr.Expr`` and its ``.ptr`` is returned.
    """
    format_str = build_stmt(ctx, node.left)
    raw_args = build_stmt(ctx, node.right)
    if not isinstance(raw_args, collections.abc.Sequence):
        raw_args = (raw_args,)
    arg_ptrs = [expr.Expr(arg).ptr for arg in raw_args]
    return format_str, arg_ptrs
1226
+
1227
@staticmethod
def ti_format_list_to_assert_msg(raw) -> tuple[str, list]:
    """Convert a ``__ti_format__`` list into a printf-style message plus its argument exprs.

    String entries are copied verbatim. Each expression entry contributes a placeholder
    (``%f`` for real dtypes, ``%d`` for integer dtypes) and is appended to the returned
    argument list. Format specifications carried by the list are currently ignored.

    Raises:
        GsTaichiSyntaxError: for an expression of any other dtype, or an entry that is neither
            a string nor an expression.
    """
    # TODO: ignore formats here for now
    entries, _ = impl.ti_format_list_to_content_entries([raw])
    msg_parts = []
    args = []
    for entry in entries:
        if isinstance(entry, str):
            msg_parts.append(entry)
        elif isinstance(entry, _ti_core.ExprCxx):
            ty = entry.get_rvalue_type()
            if ty in primitive_types.real_types:
                msg_parts.append("%f")
            elif ty in primitive_types.integer_types:
                msg_parts.append("%d")
            else:
                # Report the dtype itself; ``type(ty)`` only ever named the DataType wrapper class.
                raise GsTaichiSyntaxError(f"Unsupported data type: {ty}")
            args.append(entry)
        else:
            raise GsTaichiSyntaxError(f"Unsupported type: {type(entry)}")
    return "".join(msg_parts), args
1248
+
1249
@staticmethod
def build_Assert(ctx: ASTTransformerContext, node: ast.Assert) -> None:
    """Build a kernel-side ``assert test, msg``.

    Supported messages:
    - a ``"fmt" % args`` expression;
    - a constant (converted with ``str``);
    - a ``__ti_format__`` list, such as a formatted string.

    Without a message, the source text of the test is used. The message is stripped of
    surrounding whitespace before being passed to ``impl.ti_assert``.

    Raises:
        GsTaichiSyntaxError: if the message is of any other kind.
    """
    extra_args = []
    if node.msg is not None:
        if ASTTransformer._is_string_mod_args(node.msg):
            msg, extra_args = ASTTransformer._handle_string_mod_args(ctx, node.msg)
        else:
            msg = build_stmt(ctx, node.msg)
            # String literals are ``ast.Constant`` nodes; the former ``ast.Str`` branch was
            # unreachable and raises AttributeError on Python 3.14, where ``ast.Str`` is removed.
            if isinstance(node.msg, ast.Constant):
                msg = str(msg)
            elif isinstance(msg, collections.abc.Sequence) and len(msg) > 0 and msg[0] == "__ti_format__":
                msg, extra_args = ASTTransformer.ti_format_list_to_assert_msg(msg)
            else:
                raise GsTaichiSyntaxError(f"assert info must be constant or formatted string, not {type(msg)}")
    else:
        msg = unparse(node.test)
    test = build_stmt(ctx, node.test)
    impl.ti_assert(test, msg.strip(), extra_args, _ti_core.DebugInfo(ctx.get_pos_info(node)))
    return None
1270
+
1271
@staticmethod
def build_Break(ctx: ASTTransformerContext, node: ast.Break) -> None:
    """Build ``break``.

    In a runtime loop, a break statement is emitted. In a static (unrolled) ``for``, the
    loop status is set to ``LoopStatus.Break`` instead. That is rejected when the ``break``
    sits under a non-static ``if``, because the unrolling decision cannot depend on a
    runtime condition.

    Raises:
        GsTaichiSyntaxError: for a ``break`` under a non-static ``if`` inside a static ``for``.
    """
    if not ctx.is_in_static_for():
        ctx.ast_builder.insert_break_stmt(_ti_core.DebugInfo(ctx.get_pos_info(node)))
        return None
    enclosing_runtime_if = ctx.current_loop_scope().nearest_non_static_if
    if enclosing_runtime_if:
        raise GsTaichiSyntaxError(
            ctx.get_pos_info(enclosing_runtime_if.test)
            + "You are trying to `break` a static `for` loop, "
            "but the `break` statement is inside a non-static `if`. "
        )
    ctx.set_loop_status(LoopStatus.Break)
    return None
1286
+
1287
@staticmethod
def build_Continue(ctx: ASTTransformerContext, node: ast.Continue) -> None:
    """Build ``continue``.

    In a runtime loop, a continue statement is emitted. In a static (unrolled) ``for``, the
    loop status is set to ``LoopStatus.Continue`` instead. That is rejected when the
    ``continue`` sits under a non-static ``if``.

    Raises:
        GsTaichiSyntaxError: for a ``continue`` under a non-static ``if`` inside a static ``for``.
    """
    if not ctx.is_in_static_for():
        ctx.ast_builder.insert_continue_stmt(_ti_core.DebugInfo(ctx.get_pos_info(node)))
        return None
    enclosing_runtime_if = ctx.current_loop_scope().nearest_non_static_if
    if enclosing_runtime_if:
        raise GsTaichiSyntaxError(
            ctx.get_pos_info(enclosing_runtime_if.test)
            + "You are trying to `continue` a static `for` loop, "
            "but the `continue` statement is inside a non-static `if`. "
        )
    ctx.set_loop_status(LoopStatus.Continue)
    return None
1302
+
1303
@staticmethod
def build_Pass(ctx: ASTTransformerContext, node: ast.Pass) -> None:
    """``pass`` emits nothing; both arguments are intentionally unused (implicitly returns None)."""
1306
+
1307
+
1308
# Module-level entry point used throughout this file as ``build_stmt(ctx, node)``.
# NOTE(review): routing each node to the matching ``build_<NodeType>`` staticmethod is
# presumably done by the base class's ``__call__`` (not visible in this chunk) — confirm.
build_stmt = ASTTransformer()
1309
+
1310
+
1311
def build_stmts(ctx: ASTTransformerContext, stmts: list[ast.stmt]):
    """Build a list of statements inside a fresh variable scope.

    Building stops before the next statement once a return has been recorded
    (``ctx.returned`` is no longer ``ReturnStatus.NoReturn``). It also stops once the loop
    status is no longer ``LoopStatus.Normal``, e.g. after a static ``break``/``continue``.
    The remaining statements are skipped. Returns |stmts| unchanged.
    """
    with ctx.variable_scope_guard():
        for statement in stmts:
            halted = ctx.returned != ReturnStatus.NoReturn or ctx.loop_status() != LoopStatus.Normal
            if halted:
                break
            build_stmt(ctx, statement)
    return stmts