gstaichi 0.1.18.dev1__cp310-cp310-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3.h +6389 -0
  2. gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3native.h +594 -0
  3. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/instrument.hpp +268 -0
  4. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.h +907 -0
  5. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.hpp +375 -0
  6. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/linker.hpp +97 -0
  7. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/optimizer.hpp +970 -0
  8. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  9. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv.h +2568 -0
  10. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv.hpp +2579 -0
  11. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  12. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  13. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  14. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  15. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  16. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  17. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  18. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  19. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  20. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  21. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  22. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  23. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  24. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  25. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  26. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  27. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  28. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  29. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  30. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  31. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  32. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  33. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  34. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  35. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  36. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  37. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  38. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  39. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  40. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  41. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  42. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  43. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  44. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  45. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  46. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  47. gstaichi-0.1.18.dev1.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  48. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  49. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  50. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  51. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  52. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  53. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  54. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  55. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  56. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  57. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  58. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  59. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  60. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  61. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  62. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  63. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  64. gstaichi-0.1.18.dev1.dist-info/METADATA +108 -0
  65. gstaichi-0.1.18.dev1.dist-info/RECORD +219 -0
  66. gstaichi-0.1.18.dev1.dist-info/WHEEL +5 -0
  67. gstaichi-0.1.18.dev1.dist-info/entry_points.txt +2 -0
  68. gstaichi-0.1.18.dev1.dist-info/licenses/LICENSE +201 -0
  69. gstaichi-0.1.18.dev1.dist-info/top_level.txt +1 -0
  70. taichi/__init__.py +44 -0
  71. taichi/__main__.py +5 -0
  72. taichi/_funcs.py +706 -0
  73. taichi/_kernels.py +420 -0
  74. taichi/_lib/__init__.py +3 -0
  75. taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +1401 -0
  76. taichi/_lib/c_api/include/taichi/taichi.h +29 -0
  77. taichi/_lib/c_api/include/taichi/taichi_core.h +1111 -0
  78. taichi/_lib/c_api/include/taichi/taichi_cpu.h +29 -0
  79. taichi/_lib/c_api/include/taichi/taichi_metal.h +72 -0
  80. taichi/_lib/c_api/include/taichi/taichi_platform.h +55 -0
  81. taichi/_lib/c_api/include/taichi/taichi_unity.h +64 -0
  82. taichi/_lib/c_api/include/taichi/taichi_vulkan.h +151 -0
  83. taichi/_lib/c_api/lib/libtaichi_c_api.dylib +0 -0
  84. taichi/_lib/c_api/runtime/libMoltenVK.dylib +0 -0
  85. taichi/_lib/c_api/runtime/runtime_arm64.bc +0 -0
  86. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +29 -0
  87. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +65 -0
  88. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +121 -0
  89. taichi/_lib/core/__init__.py +0 -0
  90. taichi/_lib/core/py.typed +0 -0
  91. taichi/_lib/core/taichi_python.cpython-310-darwin.so +0 -0
  92. taichi/_lib/core/taichi_python.pyi +3077 -0
  93. taichi/_lib/runtime/libMoltenVK.dylib +0 -0
  94. taichi/_lib/runtime/runtime_arm64.bc +0 -0
  95. taichi/_lib/utils.py +249 -0
  96. taichi/_logging.py +131 -0
  97. taichi/_main.py +552 -0
  98. taichi/_snode/__init__.py +5 -0
  99. taichi/_snode/fields_builder.py +189 -0
  100. taichi/_snode/snode_tree.py +34 -0
  101. taichi/_ti_module/__init__.py +3 -0
  102. taichi/_ti_module/cppgen.py +309 -0
  103. taichi/_ti_module/module.py +145 -0
  104. taichi/_version.py +1 -0
  105. taichi/_version_check.py +100 -0
  106. taichi/ad/__init__.py +3 -0
  107. taichi/ad/_ad.py +530 -0
  108. taichi/algorithms/__init__.py +3 -0
  109. taichi/algorithms/_algorithms.py +117 -0
  110. taichi/aot/__init__.py +12 -0
  111. taichi/aot/_export.py +28 -0
  112. taichi/aot/conventions/__init__.py +3 -0
  113. taichi/aot/conventions/gfxruntime140/__init__.py +38 -0
  114. taichi/aot/conventions/gfxruntime140/dr.py +244 -0
  115. taichi/aot/conventions/gfxruntime140/sr.py +613 -0
  116. taichi/aot/module.py +253 -0
  117. taichi/aot/utils.py +151 -0
  118. taichi/assets/.git +1 -0
  119. taichi/assets/Go-Regular.ttf +0 -0
  120. taichi/assets/static/imgs/ti_gallery.png +0 -0
  121. taichi/examples/minimal.py +28 -0
  122. taichi/experimental.py +16 -0
  123. taichi/graph/__init__.py +3 -0
  124. taichi/graph/_graph.py +292 -0
  125. taichi/lang/__init__.py +50 -0
  126. taichi/lang/_ndarray.py +348 -0
  127. taichi/lang/_ndrange.py +152 -0
  128. taichi/lang/_texture.py +172 -0
  129. taichi/lang/_wrap_inspect.py +189 -0
  130. taichi/lang/any_array.py +99 -0
  131. taichi/lang/argpack.py +411 -0
  132. taichi/lang/ast/__init__.py +5 -0
  133. taichi/lang/ast/ast_transformer.py +1806 -0
  134. taichi/lang/ast/ast_transformer_utils.py +328 -0
  135. taichi/lang/ast/checkers.py +106 -0
  136. taichi/lang/ast/symbol_resolver.py +57 -0
  137. taichi/lang/ast/transform.py +9 -0
  138. taichi/lang/common_ops.py +310 -0
  139. taichi/lang/exception.py +80 -0
  140. taichi/lang/expr.py +180 -0
  141. taichi/lang/field.py +464 -0
  142. taichi/lang/impl.py +1246 -0
  143. taichi/lang/kernel_arguments.py +157 -0
  144. taichi/lang/kernel_impl.py +1415 -0
  145. taichi/lang/matrix.py +1877 -0
  146. taichi/lang/matrix_ops.py +341 -0
  147. taichi/lang/matrix_ops_utils.py +190 -0
  148. taichi/lang/mesh.py +687 -0
  149. taichi/lang/misc.py +807 -0
  150. taichi/lang/ops.py +1489 -0
  151. taichi/lang/runtime_ops.py +13 -0
  152. taichi/lang/shell.py +35 -0
  153. taichi/lang/simt/__init__.py +5 -0
  154. taichi/lang/simt/block.py +94 -0
  155. taichi/lang/simt/grid.py +7 -0
  156. taichi/lang/simt/subgroup.py +191 -0
  157. taichi/lang/simt/warp.py +96 -0
  158. taichi/lang/snode.py +487 -0
  159. taichi/lang/source_builder.py +150 -0
  160. taichi/lang/struct.py +855 -0
  161. taichi/lang/util.py +381 -0
  162. taichi/linalg/__init__.py +8 -0
  163. taichi/linalg/matrixfree_cg.py +310 -0
  164. taichi/linalg/sparse_cg.py +59 -0
  165. taichi/linalg/sparse_matrix.py +303 -0
  166. taichi/linalg/sparse_solver.py +123 -0
  167. taichi/math/__init__.py +11 -0
  168. taichi/math/_complex.py +204 -0
  169. taichi/math/mathimpl.py +886 -0
  170. taichi/profiler/__init__.py +6 -0
  171. taichi/profiler/kernel_metrics.py +260 -0
  172. taichi/profiler/kernel_profiler.py +592 -0
  173. taichi/profiler/memory_profiler.py +15 -0
  174. taichi/profiler/scoped_profiler.py +36 -0
  175. taichi/shaders/Circles_vk.frag +29 -0
  176. taichi/shaders/Circles_vk.vert +45 -0
  177. taichi/shaders/Circles_vk_frag.spv +0 -0
  178. taichi/shaders/Circles_vk_vert.spv +0 -0
  179. taichi/shaders/Lines_vk.frag +9 -0
  180. taichi/shaders/Lines_vk.vert +11 -0
  181. taichi/shaders/Lines_vk_frag.spv +0 -0
  182. taichi/shaders/Lines_vk_vert.spv +0 -0
  183. taichi/shaders/Mesh_vk.frag +71 -0
  184. taichi/shaders/Mesh_vk.vert +68 -0
  185. taichi/shaders/Mesh_vk_frag.spv +0 -0
  186. taichi/shaders/Mesh_vk_vert.spv +0 -0
  187. taichi/shaders/Particles_vk.frag +95 -0
  188. taichi/shaders/Particles_vk.vert +73 -0
  189. taichi/shaders/Particles_vk_frag.spv +0 -0
  190. taichi/shaders/Particles_vk_vert.spv +0 -0
  191. taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
  192. taichi/shaders/SceneLines_vk.frag +9 -0
  193. taichi/shaders/SceneLines_vk.vert +12 -0
  194. taichi/shaders/SceneLines_vk_frag.spv +0 -0
  195. taichi/shaders/SceneLines_vk_vert.spv +0 -0
  196. taichi/shaders/SetImage_vk.frag +21 -0
  197. taichi/shaders/SetImage_vk.vert +15 -0
  198. taichi/shaders/SetImage_vk_frag.spv +0 -0
  199. taichi/shaders/SetImage_vk_vert.spv +0 -0
  200. taichi/shaders/Triangles_vk.frag +16 -0
  201. taichi/shaders/Triangles_vk.vert +29 -0
  202. taichi/shaders/Triangles_vk_frag.spv +0 -0
  203. taichi/shaders/Triangles_vk_vert.spv +0 -0
  204. taichi/shaders/lines2quad_vk_comp.spv +0 -0
  205. taichi/sparse/__init__.py +3 -0
  206. taichi/sparse/_sparse_grid.py +77 -0
  207. taichi/tools/__init__.py +12 -0
  208. taichi/tools/diagnose.py +124 -0
  209. taichi/tools/np2ply.py +364 -0
  210. taichi/tools/vtk.py +38 -0
  211. taichi/types/__init__.py +19 -0
  212. taichi/types/annotations.py +47 -0
  213. taichi/types/compound_types.py +90 -0
  214. taichi/types/enums.py +49 -0
  215. taichi/types/ndarray_type.py +147 -0
  216. taichi/types/primitive_types.py +203 -0
  217. taichi/types/quant.py +88 -0
  218. taichi/types/texture_type.py +85 -0
  219. taichi/types/utils.py +13 -0
@@ -0,0 +1,1806 @@
1
+ # type: ignore
2
+
3
+ import ast
4
+ import collections.abc
5
+ import dataclasses
6
+ import inspect
7
+ import itertools
8
+ import math
9
+ import operator
10
+ import re
11
+ import warnings
12
+ from ast import unparse
13
+ from collections import ChainMap
14
+ from typing import Any, Iterable, Type
15
+
16
+ import numpy as np
17
+
18
+ from taichi._lib import core as _ti_core
19
+ from taichi.lang import _ndarray, any_array, expr, impl, kernel_arguments, matrix, mesh
20
+ from taichi.lang import ops as ti_ops
21
+ from taichi.lang._ndrange import _Ndrange, ndrange
22
+ from taichi.lang.argpack import ArgPackType
23
+ from taichi.lang.ast.ast_transformer_utils import (
24
+ ASTTransformerContext,
25
+ Builder,
26
+ LoopStatus,
27
+ ReturnStatus,
28
+ )
29
+ from taichi.lang.ast.symbol_resolver import ASTResolver
30
+ from taichi.lang.exception import (
31
+ TaichiIndexError,
32
+ TaichiRuntimeTypeError,
33
+ TaichiSyntaxError,
34
+ TaichiTypeError,
35
+ handle_exception_from_cpp,
36
+ )
37
+ from taichi.lang.expr import Expr, make_expr_group
38
+ from taichi.lang.field import Field
39
+ from taichi.lang.matrix import Matrix, MatrixType, Vector
40
+ from taichi.lang.snode import append, deactivate, length
41
+ from taichi.lang.struct import Struct, StructType
42
+ from taichi.lang.util import is_taichi_class, to_taichi_type
43
+ from taichi.types import annotations, ndarray_type, primitive_types, texture_type
44
+ from taichi.types.utils import is_integral
45
+
46
+
47
def reshape_list(flat_list: list[Any], target_shape: collections.abc.Sequence[int]) -> list[Any]:
    """Nest a flat, row-major list into ``target_shape``.

    Every dimension except the first becomes a level of list nesting; the
    first dimension is implied by how many chunks remain. For example
    ``reshape_list([1, 2, 3, 4, 5, 6], (2, 3))`` returns
    ``[[1, 2, 3], [4, 5, 6]]``.

    Args:
        flat_list: Elements in row-major order.
        target_shape: Desired shape. It must support ``len()``, indexing and
            slicing (e.g. a tuple or list); a one-shot iterable will not work.

    Returns:
        ``flat_list`` itself (not a copy) when ``target_shape`` has fewer than
        two dimensions, otherwise a newly built nested list.
    """
    if len(target_shape) < 2:
        return flat_list

    # Chunk consecutive elements into rows of the innermost extent, then
    # recurse on the remaining outer dimensions.
    dim = target_shape[-1]
    rows: list[Any] = []
    for i, elem in enumerate(flat_list):
        if i % dim == 0:
            rows.append([])
        rows[-1].append(elem)

    return reshape_list(rows, target_shape[:-1])
59
+
60
+
61
def boundary_type_cast_warning(expression: Expr) -> None:
    """Warn when a range-for bound is about to be cast to ``i32``.

    The warning fires when the bound's rvalue type is not integral, or is one
    of ``i64``, ``u64`` or ``u32``. Any other integral type is cast silently.

    Args:
        expression: The range-for boundary expression whose type is inspected.
    """
    dtype = expression.ptr.get_rvalue_type()
    types_that_may_not_fit = [
        primitive_types.i64,
        primitive_types.u64,
        primitive_types.u32,
    ]
    if is_integral(dtype) and dtype not in types_that_may_not_fit:
        return
    warnings.warn(
        f"Casting range_for boundary values from {dtype} to i32, which may cause numerical issues",
        Warning,
    )
72
+
73
+
74
+ class ASTTransformer(Builder):
75
@staticmethod
def build_Name(ctx: ASTTransformerContext, node: ast.Name):
    """Resolve ``node.id`` in ``ctx`` and cache the result on ``node.ptr``.

    When the resolved value is a Taichi ``Expr``, it is tagged with debug info
    built from this node's source position. The same debug info is also
    forwarded to the underlying ``ptr`` object.

    Returns:
        Whatever ``ctx.get_var_by_name`` returned for the name.
    """
    resolved = ctx.get_var_by_name(node.id)
    node.ptr = resolved
    has_position = isinstance(node, (ast.stmt, ast.expr))
    if has_position and isinstance(resolved, Expr):
        resolved.dbg_info = _ti_core.DebugInfo(ctx.get_pos_info(node))
        resolved.ptr.set_dbg_info(resolved.dbg_info)
    return node.ptr
82
+
83
@staticmethod
def build_AnnAssign(ctx: ASTTransformerContext, node: ast.AnnAssign):
    """Lower ``target: annotation = value``.

    The value is built first, then the annotation. The work is delegated to
    ``build_assign_annotated``, whose result is stored on ``node.ptr`` and
    returned. A value of the form ``ti.static(...)`` is flagged as a static
    assignment.

    NOTE(review): for a bare annotation (``x: T`` with no value),
    ``node.value`` is ``None``. Whether ``build_stmt`` accepts ``None`` is not
    visible here — confirm.
    """
    value_node = node.value
    build_stmt(ctx, value_node)
    build_stmt(ctx, node.annotation)

    is_static_assign = isinstance(value_node, ast.Call) and value_node.func.ptr is impl.static

    node.ptr = ASTTransformer.build_assign_annotated(
        ctx,
        node.target,
        value_node.ptr,
        is_static_assign,
        node.annotation.ptr,
    )
    return node.ptr
94
+
95
@staticmethod
def build_assign_annotated(
    ctx: ASTTransformerContext, target: ast.Name, value, is_static_assign: bool, annotation: Type
):
    """Build an annotated assignment like this: target: annotation = value.

    Args:
        ctx (ast_builder_utils.BuilderContext): The builder context.
        target (ast.Name): A variable name. `target.id` holds the name as
            a string.
        value: A node representing the value.
        is_static_assign: A boolean value indicating whether this is a static assignment
        annotation: A type we hope to assign to the target

    Returns:
        The newly declared or the updated variable.

    Raises:
        TaichiSyntaxError: if ``target`` names a kernel argument, if the
            assignment is a ``ti.static`` assignment, or if ``target`` already
            exists with a type different from ``annotation``.
    """
    is_local = isinstance(target, ast.Name)
    if is_local and target.id in ctx.kernel_args:
        raise TaichiSyntaxError(
            f'Kernel argument "{target.id}" is immutable in the kernel. '
            "If you want to change its value, please create a new variable."
        )
    # Reject static assignment before lowering the annotation, so no work is
    # done for an assignment that is going to be rejected anyway.
    if is_static_assign:
        raise TaichiSyntaxError("Static assign cannot be used on annotated assignment")
    anno = impl.expr_init(annotation)

    if is_local and not ctx.is_var_declared(target.id):
        # First declaration: cast the value to the annotated type.
        var = impl.expr_init(ti_ops.cast(value, anno))
        ctx.create_variable(target.id, var)
        return var

    # Existing variable (or a non-Name target): the annotation must agree with
    # the current type. Re-typing an existing variable is not allowed.
    var = build_stmt(ctx, target)
    existing_type = var.ptr.get_rvalue_type()
    if existing_type != anno:
        raise TaichiSyntaxError(
            f"Annotated assignment cannot change the type of '{unparse(target).strip()}' "
            f"from {existing_type} to {anno}"
        )
    var._assign(value)
    return var
128
+
129
@staticmethod
def build_Assign(ctx: ASTTransformerContext, node: ast.Assign) -> None:
    """Lower ``t1 = t2 = ... = value``, where each target may be a tuple.

    The right-hand side is built exactly once. Unless it is a
    ``ti.static(...)`` call, it goes through ``impl.expr_init`` a single time.
    That same result is then handed to every target in source order, which is
    what makes chained assignment work
    (ref https://github.com/taichi-dev/taichi/issues/2659).
    """
    rhs = node.value
    build_stmt(ctx, rhs)
    is_static_assign = isinstance(rhs, ast.Call) and rhs.func.ptr is impl.static

    if is_static_assign:
        values = rhs.ptr
    else:
        values = impl.expr_init(rhs.ptr)

    for target in node.targets:
        ASTTransformer.build_assign_unpack(ctx, target, values, is_static_assign)
    return None
142
+
143
@staticmethod
def build_assign_unpack(ctx: ASTTransformerContext, node_target: list | ast.Tuple, values, is_static_assign: bool):
    """Build the unpack assignments like this: (target1, target2) = (value1, value2).

    A target that is not a tuple or list is forwarded unchanged to
    ``build_assign_basic``. Tuple targets (``a, b = ...``) and list targets
    (``[a, b] = ...``) are unpacked element by element. Each element target is
    handled by a recursive call, so nested patterns such as
    ``(a, b), c = ...`` are supported as in Python.

    Args:
        ctx (ast_builder_utils.BuilderContext): The builder context.
        node_target: The assignment target node. When it is an ``ast.Tuple``
            or ``ast.List``, ``node_target.elts`` holds the element targets.
        values: A node/list representing the values.
        is_static_assign: A boolean value indicating whether this is a static assignment

    Raises:
        ValueError: if ``values`` is a matrix with more than one column.
        TaichiSyntaxError: if ``values`` cannot be unpacked, or its length does
            not match the number of targets.
    """
    if not isinstance(node_target, (ast.Tuple, ast.List)):
        return ASTTransformer.build_assign_basic(ctx, node_target, values, is_static_assign)
    targets = node_target.elts

    if isinstance(values, matrix.Matrix):
        if values.m != 1:
            raise ValueError("Matrices with more than one columns cannot be unpacked")
        values = values.entries

    # Unpack: a, b, c = ti.Vector([1., 2., 3.])
    if isinstance(values, impl.Expr) and values.ptr.is_tensor():
        if len(values.get_shape()) > 1:
            raise ValueError("Matrices with more than one columns cannot be unpacked")

        values = ctx.ast_builder.expand_exprs([values.ptr])
        # NOTE(review): a single-element expansion is collapsed to a scalar,
        # which then fails the Sequence check below — confirm this is intended.
        if len(values) == 1:
            values = values[0]

    if isinstance(values, impl.Expr) and values.ptr.is_struct():
        values = ctx.ast_builder.expand_exprs([values.ptr])
        if len(values) == 1:
            values = values[0]

    if not isinstance(values, collections.abc.Sequence):
        raise TaichiSyntaxError(f"Cannot unpack type: {type(values)}")

    if len(values) != len(targets):
        raise TaichiSyntaxError("The number of targets is not equal to value length")

    # Recurse so nested tuple/list targets are unpacked too. Plain targets go
    # straight to build_assign_basic, exactly as before.
    for target, value in zip(targets, values):
        ASTTransformer.build_assign_unpack(ctx, target, value, is_static_assign)

    return None
188
+
189
@staticmethod
def build_assign_basic(ctx: ASTTransformerContext, target: ast.Name, value, is_static_assign: bool):
    """Build basic assignment like this: target = value.

    There are three cases:
      * static assignment: the value is bound directly to the name;
      * first assignment to a plain name: a new variable is declared via
        ``impl.expr_init``;
      * anything else: the target is built and assigned through ``_assign``.

    Args:
        ctx (ast_builder_utils.BuilderContext): The builder context.
        target (ast.Name): A variable name. `target.id` holds the name as
            a string.
        value: A node representing the value.
        is_static_assign: A boolean value indicating whether this is a static assignment

    Returns:
        The bound value, or the declared/updated variable.

    Raises:
        TaichiSyntaxError: when assigning to a kernel argument, when
            statically assigning to a non-name target, or when the built target
            has no ``_assign`` method.
    """
    is_local = isinstance(target, ast.Name)
    if is_local and target.id in ctx.kernel_args:
        raise TaichiSyntaxError(
            f'Kernel argument "{target.id}" is immutable in the kernel. '
            f"If you want to change its value, please create a new variable."
        )

    if is_static_assign:
        if not is_local:
            raise TaichiSyntaxError("Static assign cannot be used on elements in arrays")
        ctx.create_variable(target.id, value)
        return value

    if is_local and not ctx.is_var_declared(target.id):
        declared = impl.expr_init(value)
        ctx.create_variable(target.id, declared)
        return declared

    existing = build_stmt(ctx, target)
    try:
        existing._assign(value)
    except AttributeError:
        raise TaichiSyntaxError(
            f"Variable '{unparse(target).strip()}' cannot be assigned. Maybe it is not a Taichi object?"
        )
    return existing
223
+
224
@staticmethod
def build_NamedExpr(ctx: ASTTransformerContext, node: ast.NamedExpr):
    """Lower a walrus expression ``(target := value)``.

    The value is built first. The binding is done by ``build_assign_basic``
    (static if the value is a ``ti.static(...)`` call), and its result is both
    stored on ``node.ptr`` and returned.
    """
    value_node = node.value
    build_stmt(ctx, value_node)
    static = isinstance(value_node, ast.Call) and value_node.func.ptr is impl.static
    node.ptr = ASTTransformer.build_assign_basic(ctx, node.target, value_node.ptr, static)
    return node.ptr
230
+
231
+ @staticmethod
232
+ def is_tuple(node):
233
+ if isinstance(node, ast.Tuple):
234
+ return True
235
+ if isinstance(node, ast.Index) and isinstance(node.value.ptr, tuple):
236
+ return True
237
+ if isinstance(node.ptr, tuple):
238
+ return True
239
+ return False
240
+
241
@staticmethod
def build_Subscript(ctx: ASTTransformerContext, node: ast.Subscript):
    """Lower ``value[index]`` through ``impl.subscript``.

    A non-tuple index is wrapped in a one-element list, and that list is
    written back to ``node.slice.ptr``. This way the indices can always be
    splatted positionally into ``impl.subscript``.
    """
    build_stmt(ctx, node.value)
    build_stmt(ctx, node.slice)
    index_node = node.slice
    if not ASTTransformer.is_tuple(index_node):
        index_node.ptr = [index_node.ptr]
    node.ptr = impl.subscript(ctx.ast_builder, node.value.ptr, *index_node.ptr)
    return node.ptr
249
+
250
@staticmethod
def build_Slice(ctx: ASTTransformerContext, node: ast.Slice):
    """Lower ``lower:upper:step`` into a Python ``slice`` of the built bounds.

    Each present bound is built. Omitted bounds become ``None``.
    """
    bounds = []
    for part in (node.lower, node.upper, node.step):
        if part is None:
            bounds.append(None)
        else:
            build_stmt(ctx, part)
            bounds.append(part.ptr)
    node.ptr = slice(*bounds)
    return node.ptr
265
+
266
@staticmethod
def build_ExtSlice(ctx: ASTTransformerContext, node: ast.ExtSlice):
    """Lower a legacy multi-dimensional slice into a tuple of its built dimensions."""
    build_stmts(ctx, node.dims)
    built_dims = [dim.ptr for dim in node.dims]
    node.ptr = tuple(built_dims)
    return node.ptr
271
+
272
@staticmethod
def build_Tuple(ctx: ASTTransformerContext, node: ast.Tuple):
    """Build every element and collect their values, in order, into a Python tuple."""
    build_stmts(ctx, node.elts)
    built = [element.ptr for element in node.elts]
    node.ptr = tuple(built)
    return node.ptr
277
+
278
@staticmethod
def build_List(ctx: ASTTransformerContext, node: ast.List):
    """Build every element and collect their values, in order, into a new Python list."""
    build_stmts(ctx, node.elts)
    node.ptr = list(map(operator.attrgetter("ptr"), node.elts))
    return node.ptr
283
+
284
@staticmethod
def build_Dict(ctx: ASTTransformerContext, node: ast.Dict):
    """Lower a dict display ``{k: v, **other}`` into a Python ``dict``.

    Entries are processed in source order. A ``None`` key marks a
    ``**mapping`` entry, whose built value is merged in with ``update``.
    """
    result = {}
    for key_node, value_node in zip(node.keys, node.values):
        if key_node is not None:
            # As in the original form, the value is built before the key.
            result[build_stmt(ctx, key_node)] = build_stmt(ctx, value_node)
            continue
        result.update(build_stmt(ctx, value_node))
    node.ptr = result
    return node.ptr
294
+
295
@staticmethod
def process_listcomp(ctx: ASTTransformerContext, node, result) -> None:
    """Build the element expression of a list comprehension and append it to ``result``."""
    element = build_stmt(ctx, node.elt)
    result.append(element)
298
+
299
@staticmethod
def process_dictcomp(ctx: ASTTransformerContext, node, result) -> None:
    """Build one ``key: value`` pair of a dict comprehension and store it in ``result``.

    The key is built before the value.
    """
    built_key = build_stmt(ctx, node.key)
    built_value = build_stmt(ctx, node.value)
    result[built_key] = built_value
304
+
305
@staticmethod
def process_generators(ctx: ASTTransformerContext, node: ast.GeneratorExp, now_comp, func, result):
    """Unroll the ``now_comp``-th ``for`` clause of a comprehension at compile time.

    The clause's iterable is built in static scope. A Taichi tensor is first
    expanded into (nested) Python lists of ``Expr``. For each item, a fresh
    variable scope is opened, the loop target is bound statically, the
    clause's ``if`` filters are built in static scope, and ``process_ifs``
    decides whether to go deeper. Once every clause has been consumed,
    ``func`` is called to emit one element into ``result``.
    """
    generators = node.generators
    if now_comp >= len(generators):
        return func(ctx, node, result)
    clause = generators[now_comp]

    with ctx.static_scope_guard():
        iterable = build_stmt(ctx, clause.iter)

    if isinstance(iterable, impl.Expr) and iterable.ptr.is_tensor():
        tensor_shape = iterable.ptr.get_shape()
        elements = [Expr(x) for x in ctx.ast_builder.expand_exprs([iterable.ptr])]
        iterable = reshape_list(elements, tensor_shape)

    for item in iterable:
        with ctx.variable_scope_guard():
            ASTTransformer.build_assign_unpack(ctx, clause.target, item, True)
            with ctx.static_scope_guard():
                build_stmts(ctx, clause.ifs)
            ASTTransformer.process_ifs(ctx, node, now_comp, 0, func, result)
    return None
324
+
325
@staticmethod
def process_ifs(
    ctx: ASTTransformerContext,
    node: ast.ListComp | ast.DictComp | ast.GeneratorExp,
    now_comp,
    now_if,
    func,
    result,
):
    """Apply the ``if`` filters of the ``now_comp``-th comprehension clause.

    The filter conditions must already have been built; their values are read
    from ``.ptr``. Starting at index ``now_if``, they are tested in order with
    Python truthiness, stopping at the first falsy one. If all remaining
    conditions hold, unrolling continues with the next ``for`` clause via
    ``process_generators``.

    Args:
        ctx: The builder context.
        node: The comprehension node; its ``generators`` list holds the clauses.
        now_comp: Index of the current ``for`` clause.
        now_if: Index of the first filter still to test.
        func: Callback that emits one element into ``result``.
        result: The container being filled.
    """
    remaining = node.generators[now_comp].ifs[now_if:]
    if all(cond.ptr for cond in remaining):
        return ASTTransformer.process_generators(ctx, node, now_comp + 1, func, result)
    return None
334
+
335
@staticmethod
def build_ListComp(ctx: ASTTransformerContext, node: ast.ListComp):
    """Unroll a list comprehension at compile time into a Python list (also cached on ``node.ptr``)."""
    elements = []
    ASTTransformer.process_generators(ctx, node, 0, ASTTransformer.process_listcomp, elements)
    node.ptr = elements
    return elements
341
+
342
@staticmethod
def build_DictComp(ctx: ASTTransformerContext, node: ast.DictComp):
    """Unroll a dict comprehension at compile time into a Python dict (also cached on ``node.ptr``)."""
    mapping = {}
    ASTTransformer.process_generators(ctx, node, 0, ASTTransformer.process_dictcomp, mapping)
    node.ptr = mapping
    return mapping
348
+
349
@staticmethod
def build_Index(ctx: ASTTransformerContext, node: ast.Index):
    """Unwrap a legacy ``ast.Index`` subscript wrapper by building its inner value."""
    inner = build_stmt(ctx, node.value)
    node.ptr = inner
    return inner
353
+
354
@staticmethod
def build_Constant(ctx: ASTTransformerContext, node: ast.Constant):
    """Return the literal value of a constant node, also caching it on ``node.ptr``."""
    literal = node.value
    node.ptr = literal
    return literal
358
+
359
@staticmethod
def build_Num(ctx: ASTTransformerContext, node: "ast.Num"):
    """Return the value of a legacy numeric-literal node.

    ``ast.Num`` is a deprecated alias of ``ast.Constant`` and was removed in
    Python 3.14. The annotation is therefore a string, so the class is not
    looked up when the method is defined. The value is read through
    ``.value`` instead of the deprecated ``.n`` alias.
    """
    node.ptr = node.value
    return node.ptr
363
+
364
@staticmethod
def build_Str(ctx: ASTTransformerContext, node: "ast.Str"):
    """Return the value of a legacy string-literal node.

    ``ast.Str`` is a deprecated alias of ``ast.Constant`` and was removed in
    Python 3.14. The annotation is therefore a string, and the value is read
    through ``.value`` instead of the deprecated ``.s`` alias.
    """
    node.ptr = node.value
    return node.ptr
368
+
369
@staticmethod
def build_Bytes(ctx: ASTTransformerContext, node: "ast.Bytes"):
    """Return the value of a legacy bytes-literal node.

    ``ast.Bytes`` is a deprecated alias of ``ast.Constant`` and was removed in
    Python 3.14. The annotation is therefore a string, and the value is read
    through ``.value`` instead of the deprecated ``.s`` alias.
    """
    node.ptr = node.value
    return node.ptr
373
+
374
@staticmethod
def build_NameConstant(ctx: ASTTransformerContext, node: "ast.NameConstant"):
    """Return the value (``True``/``False``/``None``) of a legacy name-constant node.

    ``ast.NameConstant`` is a deprecated alias of ``ast.Constant`` and was
    removed in Python 3.14. The annotation is therefore a string, so the class
    is not looked up when the method is defined.
    """
    node.ptr = node.value
    return node.ptr
378
+
379
@staticmethod
def build_keyword(ctx: ASTTransformerContext, node: ast.keyword):
    """Lower one keyword argument of a call.

    ``name=value`` becomes the one-entry dict ``{name: value}``. A
    ``**mapping`` entry (``node.arg is None``) yields the built mapping itself.
    """
    build_stmt(ctx, node.value)
    built = node.value.ptr
    node.ptr = built if node.arg is None else {node.arg: built}
    return node.ptr
387
+
388
@staticmethod
def build_Starred(ctx: ASTTransformerContext, node: ast.Starred):
    """Build the expression wrapped by ``*``; no unpacking is performed here."""
    inner = build_stmt(ctx, node.value)
    node.ptr = inner
    return inner
392
+
393
@staticmethod
def build_FormattedValue(ctx: ASTTransformerContext, node: ast.FormattedValue):
    """Lower one ``{value[:spec]}`` replacement field of an f-string.

    ``node.ptr`` is always set to the built value.

    Returns:
        The built value when there is no format spec. Otherwise the marker
        list ``["__ti_fmt_value__", value, spec]``, which is distinguished from
        a normal list.

    Raises:
        TaichiSyntaxError: if the format spec is not a single literal string,
            e.g. ``f"{x:{width}}"`` or ``f"{x:.{p}f}"``.
    """
    node.ptr = build_stmt(ctx, node.value)
    spec_node = node.format_spec
    if spec_node is None or len(spec_node.values) == 0:
        return node.ptr
    parts = spec_node.values
    # Only a literal spec such as ".3f" can be handed over. Nested replacement
    # fields would require evaluating the spec itself. Raise instead of
    # asserting, because asserts are stripped under -O.
    if len(parts) != 1 or not isinstance(parts[0], ast.Constant) or not isinstance(parts[0].value, str):
        raise TaichiSyntaxError("Format specifiers with nested replacement fields are not supported in f-strings.")
    # Read `.value` rather than the deprecated `Constant.s` alias (removed in 3.14).
    format_str = parts[0].value
    # distinguished from normal list
    return ["__ti_fmt_value__", node.ptr, format_str]
404
+
405
@staticmethod
def build_JoinedStr(ctx: ASTTransformerContext, node: ast.JoinedStr):
    """Lower an f-string into an ``impl.ti_format`` call.

    Literal pieces are concatenated into the format string. Each replacement
    field adds a ``{}`` placeholder plus one positional argument, produced by
    ``build_FormattedValue``.

    Raises:
        TaichiSyntaxError: if a piece is neither a literal nor a replacement field.
    """
    spec_parts = []
    args = []
    for sub_node in node.values:
        if isinstance(sub_node, ast.FormattedValue):
            spec_parts.append("{}")
            args.append(build_stmt(ctx, sub_node))
        elif isinstance(sub_node, ast.Constant):
            # Since Python 3.8 literal pieces are always ast.Constant. The former
            # ast.Str branch was unreachable, and ast.Str no longer exists in 3.14.
            # NOTE(review): literal "{{"/"}}" reach this point as single braces;
            # whether ti_format treats them as placeholders is not visible here.
            spec_parts.append(sub_node.value)
        else:
            raise TaichiSyntaxError("Invalid value for fstring.")

    node.ptr = impl.ti_format("".join(spec_parts), *args)
    return node.ptr
423
+
424
@staticmethod
def build_call_if_is_builtin(ctx: ASTTransformerContext, node, args, keywords):
    """Lower a call to a Python builtin that Taichi overrides.

    ``len(x)`` on a tensor ``Expr`` is resolved to the first extent of its
    shape. ``print``, ``min``, ``max``, ``int``, ``bool``, ``float``, ``any``,
    ``all``, ``abs``, ``pow`` and ``operator.matmul`` are dispatched to their
    Taichi counterparts (``abs``/``pow`` map to themselves).

    Returns:
        True if the call was handled (the result is on ``node.ptr``), else False.
    """
    from taichi.lang import matrix_ops  # pylint: disable=C0415

    func = node.func.ptr
    overrides = (
        (print, impl.ti_print),
        (min, ti_ops.min),
        (max, ti_ops.max),
        (int, impl.ti_int),
        (bool, impl.ti_bool),
        (float, impl.ti_float),
        (any, matrix_ops.any),
        (all, matrix_ops.all),
        (abs, abs),
        (pow, pow),
        (operator.matmul, matrix_ops.matmul),
    )
    replacements = {id(builtin): replacement for builtin, replacement in overrides}

    # Builtin 'len' function on Matrix Expr
    if func is len and len(args) == 1 and isinstance(args[0], Expr) and args[0].ptr.is_tensor():
        node.ptr = args[0].get_shape()[0]
        return True

    replacement = replacements.get(id(func))
    if replacement is None:
        return False
    node.ptr = replacement(*args, **keywords)
    return True
453
+
454
@staticmethod
def build_call_if_is_type(ctx: ASTTransformerContext, node, args, keywords):
    """Lower ``dtype(x)`` where ``dtype`` is a Taichi primitive type.

    An ``Expr`` operand is cast with ``ti_ops.cast``. Any other operand is
    wrapped as ``expr.Expr(x, dtype=dtype)``.

    Returns:
        True if the call was a primitive-type call (the result is on
        ``node.ptr``), else False.

    Raises:
        TaichiSyntaxError: if there is not exactly one positional argument and
            no keywords, or if the operand is a compound (Taichi-class or
            tensor) value.
    """
    func = node.func.ptr
    if id(func) not in primitive_types.type_ids:
        return False

    if len(args) != 1 or keywords:
        raise TaichiSyntaxError("A primitive type can only decorate a single expression.")
    operand = args[0]
    if is_taichi_class(operand):
        raise TaichiSyntaxError("A primitive type cannot decorate an expression with a compound type.")

    if not isinstance(operand, expr.Expr):
        node.ptr = expr.Expr(operand, dtype=func)
        return True

    if operand.ptr.is_tensor():
        raise TaichiSyntaxError("A primitive type cannot decorate an expression with a compound type.")
    node.ptr = ti_ops.cast(operand, func)
    return True
471
+
472
@staticmethod
def is_external_func(ctx: ASTTransformerContext, func) -> bool:
    """Return True if calling ``func`` from Taichi scope deserves a warning.

    The function returns False in any of these cases:
      * the call is inside a static scope;
      * ``func`` is a Taichi func or kernel (it carries a
        ``_is_taichi_function`` or ``_is_wrapped_kernel`` attribute);
      * ``func`` is defined in the ``taichi`` package or one of its submodules.
    """
    if ctx.is_in_static_scope():  # allow external function in static scope
        return False
    if hasattr(func, "_is_taichi_function") or hasattr(func, "_is_wrapped_kernel"):  # taichi func/kernel
        return False
    module = getattr(func, "__module__", None)
    # Functions defined directly in the top-level package report
    # __module__ == "taichi", which startswith("taichi.") alone would miss.
    if module and (module == "taichi" or module.startswith("taichi.")):
        return False
    return True
481
+
482
@staticmethod
def warn_if_is_external_func(ctx: ASTTransformerContext, node):
    """Emit a (yellow, ANSI-colored) SyntaxWarning for a call to a non-Taichi function.

    Nothing happens unless ``ASTTransformer.is_external_func`` says the callee
    is external. The warning is attributed to ``ctx.file`` at the call's line,
    shifted by ``ctx.lineno_offset``.
    """
    func = node.func.ptr
    if ASTTransformer.is_external_func(ctx, func):
        name = unparse(node.func).strip()
        message = (
            f"\x1b[38;5;226m"  # Yellow
            f'Calling non-taichi function "{name}". '
            f"Scope inside the function is not processed by the Taichi AST transformer. "
            f"The function may not work as expected. Proceed with caution! "
            f"Maybe you can consider turning it into a @ti.func?"
            f"\x1b[0m"  # Reset
        )
        warnings.warn_explicit(
            message,
            SyntaxWarning,
            ctx.file,
            node.lineno + ctx.lineno_offset,
            module="taichi",
        )
500
+
501
@staticmethod
def canonicalize_formatted_string(raw_string: str, *raw_args: list, **raw_keywords: dict):
    """Canonicalize a ``str.format``-style template for ``ti.format``.

    Every replacement field is reduced to an empty ``{}``. The matching
    argument is listed in field order. A field with a format spec becomes
    ``["__ti_fmt_value__", value, spec]``. Arguments may be referenced more
    than once, and automatic ``{}`` fields count positional indices from 0.

    Example::

        origin input: 'qwerty {1} {} {1:.3f} {k:.4f} {k:}'.format(1.0, 2.0, k=k)
        returns:      ['qwerty {} {} {} {} {}', 2.0, 1.0,
                       ['__ti_fmt_value__', 2.0, '.3f'],
                       ['__ti_fmt_value__', k, '.4f'], k]

    Args:
        raw_string: the template, e.g. ``'qwerty {1} {} {k:.4f}'``.
        *raw_args: positional arguments passed to ``.format``.
        **raw_keywords: keyword arguments passed to ``.format``.

    Returns:
        A list whose first item is the canonical template, followed by one
        entry per replacement field.

    Raises:
        TaichiSyntaxError: if the number of positional arguments does not match
            the highest referenced index, or a keyword is missing or unused.
    """
    brackets = []
    unnamed = 0
    for bracket in re.findall(r"{(.*?)}", raw_string):
        # Split at the FIRST colon only, like str.format: the spec itself may contain
        # colons (e.g. "{t:%H:%M}"), which a plain split(":") failed to unpack.
        item, _, spec = bracket.partition(":")
        if item.isdigit():
            item = int(item)
        elif item == "":
            # handle unnamed positional args
            item = unnamed
            unnamed += 1
        # an empty spec ("{x:}") or no spec at all both mean "no format spec"
        brackets.append((item, spec or None))

    # check for errors in the arguments
    max_args_index = max((t[0] for t in brackets if isinstance(t[0], int)), default=-1)
    if max_args_index + 1 != len(raw_args):
        raise TaichiSyntaxError(
            f"Expected {max_args_index + 1} positional argument(s), but received {len(raw_args)} instead."
        )
    brackets_keywords = [t[0] for t in brackets if isinstance(t[0], str)]
    for item in brackets_keywords:
        if item not in raw_keywords:
            raise TaichiSyntaxError(f"Keyword '{item}' not found.")
    for item in raw_keywords:
        if item not in brackets_keywords:
            raise TaichiSyntaxError(f"Keyword '{item}' not used.")

    # reorganize the arguments based on their positions, keywords, and format specifiers
    args = []
    for item, spec in brackets:
        new_arg = raw_args[item] if isinstance(item, int) else raw_keywords[item]
        args.append(new_arg if spec is None else ["__ti_fmt_value__", new_arg, spec])
    # put the formatted string as the first argument to make ti.format() happy
    args.insert(0, re.sub(r"{.*?}", "{}", raw_string))
    return args
553
+
554
@staticmethod
def expand_node_args_dataclasses(args: tuple[ast.AST, ...]) -> tuple[ast.AST, ...]:
    """Flatten dataclass-valued call arguments into one ``ast.Name`` per field.

    An argument whose already-built ``ptr`` is a dataclass is replaced by
    ``Name`` loads named ``__ti_<arg id>_<field name>``. The new nodes reuse the
    argument's source position. All other arguments are kept unchanged and in
    order.
    """
    expanded: list[ast.AST] = []
    for node_arg in args:
        value = node_arg.ptr
        if not dataclasses.is_dataclass(value):
            expanded.append(node_arg)
            continue
        position = {
            "lineno": node_arg.lineno,
            "end_lineno": node_arg.end_lineno,
            "col_offset": node_arg.col_offset,
            "end_col_offset": node_arg.end_col_offset,
        }
        expanded.extend(
            ast.Name(id=f"__ti_{node_arg.id}_{member.name}", ctx=ast.Load(), **position)
            for member in dataclasses.fields(value)
        )
    return tuple(expanded)
576
+
577
@staticmethod
def build_Call(ctx: ASTTransformerContext, node: ast.Call):
    """Lower a call expression and store its value in ``node.ptr``.

    The callee and arguments are built first, inside a static scope for
    ``ti.static`` / ``ti.static_assert``. Dataclass arguments are then expanded
    into their per-field variables. The call is dispatched to the first
    matching handler:

    1. ``"<literal>".format(...)``, lowered to ``impl.ti_format``;
    2. ``Matrix`` / ``Vector`` construction;
    3. builtins;
    4. primitive-type casts;
    5. tensor methods bound through ``node.func.caller``;
    6. a plain Python call.

    Raises:
        TaichiTypeError: when the final Python call raises ``TypeError``.
    """
    if ASTTransformer.get_decorator(ctx, node) in ["static", "static_assert"]:
        with ctx.static_scope_guard():
            build_stmt(ctx, node.func)
            build_stmts(ctx, node.args)
            build_stmts(ctx, node.keywords)
    else:
        build_stmt(ctx, node.func)
        # creates variable for the dataclass itself (as well as other variables,
        # not related to dataclasses). Necessary for calling further child functions
        original_args = node.args
        build_stmts(ctx, original_args)
        node.args = ASTTransformer.expand_node_args_dataclasses(original_args)
        # Build only the Name nodes synthesized for dataclass members. The original argument
        # nodes already carry their ``ptr``; building them again would emit their IR (and any
        # side effect such as a ti.func call or atomic op inside an argument) a second time.
        original_ids = {id(arg) for arg in original_args}
        build_stmts(ctx, [arg for arg in node.args if id(arg) not in original_ids])
        build_stmts(ctx, node.keywords)

    args = []
    for arg in node.args:
        if isinstance(arg, ast.Starred):
            arg_list = arg.ptr
            if isinstance(arg_list, Expr) and arg_list.is_tensor():
                # Expand Expr with Matrix-type return into list of Exprs
                arg_list = [Expr(x) for x in ctx.ast_builder.expand_exprs([arg_list.ptr])]
            args.extend(arg_list)
        else:
            args.append(arg.ptr)
    # Each built keyword node's ptr is a mapping; merge them into one dict.
    keywords = dict(ChainMap(*[keyword.ptr for keyword in node.keywords]))
    func = node.func.ptr

    if id(func) in [id(print), id(impl.ti_print)]:
        ctx.func.has_print = True

    if isinstance(node.func, ast.Attribute) and isinstance(node.func.value.ptr, str) and node.func.attr == "format":
        raw_string = node.func.value.ptr
        args = ASTTransformer.canonicalize_formatted_string(raw_string, *args, **keywords)
        node.ptr = impl.ti_format(*args)
        return node.ptr

    if id(func) == id(Matrix) or id(func) == id(Vector):
        node.ptr = matrix.make_matrix(*args, **keywords)
        return node.ptr

    if ASTTransformer.build_call_if_is_builtin(ctx, node, args, keywords):
        return node.ptr

    if ASTTransformer.build_call_if_is_type(ctx, node, args, keywords):
        return node.ptr

    if hasattr(node.func, "caller"):
        node.ptr = func(node.func.caller, *args, **keywords)
        return node.ptr
    ASTTransformer.warn_if_is_external_func(ctx, node)
    try:
        node.ptr = func(*args, **keywords)
    except TypeError as e:
        # Callables such as functools.partial objects have no __name__; building the
        # error report must not itself fail with AttributeError.
        func_name = getattr(func, "__name__", repr(func))
        module = inspect.getmodule(func)
        error_msg = re.sub(r"\bExpr\b", "Taichi Expression", str(e))
        msg = f"TypeError when calling `{func_name}`: {error_msg}."
        if ASTTransformer.is_external_func(ctx, func):
            args_has_expr = any(isinstance(arg, Expr) for arg in args)
            if args_has_expr and module in (math, np):
                # Suggest ti.<name> only if taichi exports it. This is a best-effort probe:
                # any failure (ImportError, SyntaxError for odd names, ...) means "no hint",
                # but KeyboardInterrupt/SystemExit are no longer swallowed.
                try:
                    exec(f"from taichi import {func_name}", {})
                except Exception:  # pylint: disable=broad-except
                    pass
                else:
                    msg += f"\nDid you mean to use `ti.{func_name}` instead of `{module.__name__}.{func_name}`?"
        raise TaichiTypeError(msg)

    if getattr(func, "_is_taichi_function", False):
        ctx.func.has_print |= func.func.has_print

    return node.ptr
654
+
655
@staticmethod
def build_FunctionDef(ctx: ASTTransformerContext, node: ast.FunctionDef):
    """Declare the parameters of a kernel or ti.func and build its body.

    Kernels and real functions declare return types and parameters through
    ``kernel_arguments`` (see ``transform_as_kernel``). Inlined ti.funcs bind
    ``ctx.argument_data`` directly:
    - templates and ndarrays are bound by reference;
    - dataclasses are bound as per-field variables;
    - matrices and other values are bound through a local copy
      (``impl.expr_init_func``).

    Raises:
        TaichiSyntaxError: on a nested function definition, or when an inlined
            ti.func receives an argument that does not match its annotation.
    """
    if ctx.visited_funcdef:
        raise TaichiSyntaxError(
            f"Function definition is not allowed in 'ti.{'kernel' if ctx.is_kernel else 'func'}'."
        )
    ctx.visited_funcdef = True

    args = node.args
    # Only plain positional parameters are supported.
    assert args.vararg is None
    assert args.kwonlyargs == []
    assert args.kw_defaults == []
    assert args.kwarg is None

    def decl_and_create_variable(
        annotation, name, arg_features, invoke_later_dict, prefix_name, arg_depth
    ) -> tuple[bool, Any]:
        """Declare one parameter described by ``annotation``.

        Returns ``(True, value)`` when the declared value is available right
        away. Returns ``(False, (decl_func, decl_args))`` when the caller must
        invoke ``decl_func(*decl_args)`` itself; this applies to sparse matrix
        builders, ndarrays and textures.
        """
        full_name = prefix_name + "_" + name
        if not isinstance(annotation, primitive_types.RefType):
            ctx.kernel_args.append(name)
        if isinstance(annotation, ArgPackType):
            kernel_arguments.push_argpack_arg(name)
            d = {}
            items_to_put_in_dict = []
            for j, (_name, anno) in enumerate(annotation.members.items()):
                result, obj = decl_and_create_variable(
                    anno, _name, arg_features[j], invoke_later_dict, full_name, arg_depth + 1
                )
                if not result:
                    # Deferred member: placeholder now, filled in once the argpack is declared.
                    d[_name] = None
                    items_to_put_in_dict.append((full_name + "_" + _name, _name, obj))
                else:
                    d[_name] = obj
            argpack = kernel_arguments.decl_argpack_arg(annotation, d)
            for item in items_to_put_in_dict:
                invoke_later_dict[item[0]] = argpack, item[1], *item[2]
            return True, argpack
        if annotation == annotations.template or isinstance(annotation, annotations.template):
            return True, ctx.global_vars[name]
        if isinstance(annotation, annotations.sparse_matrix_builder):
            return False, (
                kernel_arguments.decl_sparse_matrix,
                (
                    to_taichi_type(arg_features),
                    full_name,
                ),
            )
        if isinstance(annotation, ndarray_type.NdarrayType):
            return False, (
                kernel_arguments.decl_ndarray_arg,
                (
                    to_taichi_type(arg_features[0]),
                    arg_features[1],
                    full_name,
                    arg_features[2],
                    arg_features[3],
                ),
            )
        if isinstance(annotation, texture_type.TextureType):
            return False, (kernel_arguments.decl_texture_arg, (arg_features[0], full_name))
        if isinstance(annotation, texture_type.RWTextureType):
            return False, (
                kernel_arguments.decl_rw_texture_arg,
                (arg_features[0], arg_features[1], arg_features[2], full_name),
            )
        if isinstance(annotation, MatrixType):
            return True, kernel_arguments.decl_matrix_arg(annotation, name, arg_depth)
        if isinstance(annotation, StructType):
            return True, kernel_arguments.decl_struct_arg(annotation, name, arg_depth)
        return True, kernel_arguments.decl_scalar_arg(annotation, name, arg_depth)

    def transform_as_kernel() -> None:
        """Declare return values and parameters through ``kernel_arguments``."""
        if node.returns is not None:
            # A constant annotation (e.g. ``-> None``) declares no return values.
            if not isinstance(node.returns, ast.Constant):
                for return_type in ctx.func.return_type:
                    kernel_arguments.decl_ret(return_type)
        impl.get_runtime().compiling_callable.finalize_rets()

        # Deferred argpack members: flat name -> (argpack, member name, decl_func, decl_args).
        invoke_later_dict: dict[str, tuple[Any, str, Any, Any]] = dict()
        create_variable_later = dict()
        for i, arg in enumerate(args.args):
            argument = ctx.func.arguments[i]
            if isinstance(argument.annotation, ArgPackType):
                kernel_arguments.push_argpack_arg(argument.name)
                d = {}
                items_to_put_in_dict: list[tuple[str, str, Any]] = []
                for j, (name, anno) in enumerate(argument.annotation.members.items()):
                    result, obj = decl_and_create_variable(
                        anno, name, ctx.arg_features[i][j], invoke_later_dict, "__argpack_" + name, 1
                    )
                    if not result:
                        d[name] = None
                        items_to_put_in_dict.append(("__argpack_" + name, name, obj))
                    else:
                        d[name] = obj
                argpack = kernel_arguments.decl_argpack_arg(ctx.func.arguments[i].annotation, d)
                for item in items_to_put_in_dict:
                    invoke_later_dict[item[0]] = argpack, item[1], *item[2]
                create_variable_later[arg.arg] = argpack
            elif dataclasses.is_dataclass(argument.annotation):
                # The dataclass itself is bound to its type; each field becomes a flat
                # variable named "__ti_<arg>_<field>".
                arg_features = ctx.arg_features[i]
                ctx.create_variable(argument.name, argument.annotation)
                for field_idx, field in enumerate(dataclasses.fields(argument.annotation)):
                    flat_name = f"__ti_{argument.name}_{field.name}"
                    result, obj = decl_and_create_variable(
                        field.type,
                        flat_name,
                        arg_features[field_idx],
                        invoke_later_dict,
                        "",
                        0,
                    )
                    if result:
                        ctx.create_variable(flat_name, obj)
                    else:
                        decl_type_func, type_args = obj
                        obj = decl_type_func(*type_args)
                        ctx.create_variable(flat_name, obj)
            else:
                result, obj = decl_and_create_variable(
                    argument.annotation,
                    argument.name,
                    ctx.arg_features[i] if ctx.arg_features is not None else None,
                    invoke_later_dict,
                    "",
                    0,
                )
                if result:
                    ctx.create_variable(arg.arg, obj)
                else:
                    decl_type_func, type_args = obj
                    obj = decl_type_func(*type_args)
                    ctx.create_variable(arg.arg, obj)
        for argpack, name, func, params in invoke_later_dict.values():
            argpack[name] = func(*params)
        for k, v in create_variable_later.items():
            ctx.create_variable(k, v)

        impl.get_runtime().compiling_callable.finalize_params()
        # remove original args
        node.args.args = []

    if ctx.is_kernel:  # ti.kernel
        transform_as_kernel()

    else:  # ti.func
        if ctx.is_real_function:
            transform_as_kernel()
        else:
            for data_i, data in enumerate(ctx.argument_data):
                argument = ctx.func.arguments[data_i]
                if isinstance(argument.annotation, annotations.template):
                    ctx.create_variable(argument.name, data)
                    continue

                elif dataclasses.is_dataclass(argument.annotation):
                    dataclass_type = argument.annotation
                    for field in dataclasses.fields(dataclass_type):
                        data_child = getattr(data, field.name)
                        if not isinstance(
                            data_child,
                            (
                                _ndarray.ScalarNdarray,
                                matrix.VectorNdarray,
                                matrix.MatrixNdarray,
                                any_array.AnyArray,
                            ),
                        ):
                            raise TaichiSyntaxError(
                                f"Argument {argument.name} of type {dataclass_type} {field.type} is not recognized."
                            )
                        field.type.check_matched(data_child.get_type(), field.name)
                        var_name = f"__ti_{argument.name}_{field.name}"
                        ctx.create_variable(var_name, data_child)
                    continue

                # Ndarray arguments are passed by reference.
                if isinstance(argument.annotation, (ndarray_type.NdarrayType)):
                    if not isinstance(
                        data,
                        (
                            _ndarray.ScalarNdarray,
                            matrix.VectorNdarray,
                            matrix.MatrixNdarray,
                            any_array.AnyArray,
                        ),
                    ):
                        # ``argument.name`` (not ``arg.arg``): ``arg`` only exists inside
                        # transform_as_kernel, so the old message raised NameError instead.
                        raise TaichiSyntaxError(
                            f"Argument {argument.name} of type {argument.annotation} is not recognized."
                        )
                    argument.annotation.check_matched(data.get_type(), argument.name)
                    ctx.create_variable(argument.name, data)
                    continue

                # Matrix arguments are passed by value.
                if isinstance(argument.annotation, (MatrixType)):
                    var_name = argument.name
                    # "data" is expected to be an Expr here,
                    # so we simply call "impl.expr_init_func(data)" to perform:
                    #
                    # TensorType* t = alloca()
                    # assign(t, data)
                    #
                    # We created local variable "t" - a copy of the passed-in argument "data"
                    if not isinstance(data, expr.Expr) or not data.ptr.is_tensor():
                        raise TaichiSyntaxError(
                            f"Argument {var_name} of type {argument.annotation} is expected to be a Matrix, but got {type(data)}."
                        )

                    element_shape = data.ptr.get_rvalue_type().shape()
                    if len(element_shape) != argument.annotation.ndim:
                        raise TaichiSyntaxError(
                            f"Argument {var_name} of type {argument.annotation} is expected to be a Matrix with ndim {argument.annotation.ndim}, but got {len(element_shape)}."
                        )

                    assert argument.annotation.ndim > 0
                    if element_shape[0] != argument.annotation.n:
                        raise TaichiSyntaxError(
                            f"Argument {var_name} of type {argument.annotation} is expected to be a Matrix with n {argument.annotation.n}, but got {element_shape[0]}."
                        )

                    if argument.annotation.ndim == 2 and element_shape[1] != argument.annotation.m:
                        # Report the mismatching second dimension (was element_shape[0]).
                        raise TaichiSyntaxError(
                            f"Argument {var_name} of type {argument.annotation} is expected to be a Matrix with m {argument.annotation.m}, but got {element_shape[1]}."
                        )

                    ctx.create_variable(var_name, impl.expr_init_func(data))
                    continue

                if id(argument.annotation) in primitive_types.type_ids:
                    var_name = argument.name
                    ctx.create_variable(var_name, impl.expr_init_func(ti_ops.cast(data, argument.annotation)))
                    continue
                # Create a copy for non-template arguments,
                # so that they are passed by value.
                var_name = argument.name
                ctx.create_variable(var_name, impl.expr_init_func(data))
            # Bind each dataclass parameter name to its type; its fields were bound above.
            for v in ctx.func.orig_arguments:
                if dataclasses.is_dataclass(v.annotation):
                    ctx.create_variable(v.name, v.annotation)

    with ctx.variable_scope_guard():
        build_stmts(ctx, node.body)

    return None
901
+
902
@staticmethod
def build_Return(ctx: ASTTransformerContext, node: ast.Return) -> None:
    """Lower a ``return`` statement.

    Kernels and real functions check each returned value against
    ``ctx.func.return_type``, then emit a kernel return of the cast,
    flattened values. Inlined ti.funcs store the (cast) value in
    ``ctx.return_data`` and mark ``ctx.returned``.

    Raises:
        TaichiSyntaxError: for a return inside non-static control flow of an
            inlined function, a missing return-type annotation, or an
            unsupported return type.
        TaichiRuntimeTypeError: when a returned value does not match its
            declared type.
    """
    if not ctx.is_real_function:
        if ctx.is_in_non_static_control_flow():
            raise TaichiSyntaxError("Return inside non-static if/for is not supported")
    if node.value is not None:
        build_stmt(ctx, node.value)
    if node.value is None or node.value.ptr is None:
        if not ctx.is_real_function:
            ctx.returned = ReturnStatus.ReturnedVoid
        return None
    if ctx.is_kernel or ctx.is_real_function:
        # TODO: check if it's at the end of a kernel, throw TaichiSyntaxError if not
        if ctx.func.return_type is None:
            raise TaichiSyntaxError(
                f'A {"kernel" if ctx.is_kernel else "function"} '
                "with a return value must be annotated "
                "with a return type, e.g. def func() -> ti.f32"
            )
        return_exprs = []
        if len(ctx.func.return_type) == 1:
            node.value.ptr = [node.value.ptr]
        assert len(ctx.func.return_type) == len(node.value.ptr)
        for return_type, ptr in zip(ctx.func.return_type, node.value.ptr):
            if id(return_type) in primitive_types.type_ids:
                if isinstance(ptr, Expr):
                    if ptr.is_tensor() or ptr.is_struct() or ptr.element_type() not in primitive_types.all_types:
                        raise TaichiRuntimeTypeError.get_ret(str(return_type), ptr)
                elif not isinstance(ptr, (float, int, np.floating, np.integer)):
                    raise TaichiRuntimeTypeError.get_ret(str(return_type), ptr)
                return_exprs += [ti_ops.cast(expr.Expr(ptr), return_type).ptr]
            elif isinstance(return_type, MatrixType):
                values = ptr
                if isinstance(values, Matrix):
                    # Compare against this value's own declared type: ``ctx.func.return_type``
                    # is the list of all return types and has no ``ndim``.
                    if values.ndim != return_type.ndim:
                        raise TaichiRuntimeTypeError(
                            f"Return matrix ndim mismatch, expecting={return_type.ndim}, got={values.ndim}."
                        )
                    elif return_type.get_shape() != values.get_shape():
                        raise TaichiRuntimeTypeError(
                            f"Return matrix shape mismatch, expecting={return_type.get_shape()}, got={values.get_shape()}."
                        )
                    values = (
                        itertools.chain.from_iterable(values.to_list())
                        if values.ndim == 1
                        else iter(values.to_list())
                    )
                elif isinstance(values, Expr):
                    if not values.is_tensor():
                        raise TaichiRuntimeTypeError.get_ret(return_type.to_string(), ptr)
                    elif (
                        return_type.dtype in primitive_types.real_types
                        and values.element_type() not in primitive_types.all_types
                    ):
                        raise TaichiRuntimeTypeError.get_ret(return_type.dtype.to_string(), values.element_type())
                    elif (
                        return_type.dtype in primitive_types.integer_types
                        and values.element_type() not in primitive_types.integer_types
                    ):
                        raise TaichiRuntimeTypeError.get_ret(return_type.dtype.to_string(), values.element_type())
                    elif len(values.get_shape()) != return_type.ndim:
                        raise TaichiRuntimeTypeError(
                            f"Return matrix ndim mismatch, expecting={return_type.ndim}, got={len(values.get_shape())}."
                        )
                    elif return_type.get_shape() != values.get_shape():
                        raise TaichiRuntimeTypeError(
                            f"Return matrix shape mismatch, expecting={return_type.get_shape()}, got={values.get_shape()}."
                        )
                    values = [values]
                else:
                    np_array = np.array(values)
                    dt, shape, ndim = np_array.dtype, np_array.shape, np_array.ndim
                    if return_type.dtype in primitive_types.real_types and dt not in (
                        float,
                        int,
                        np.floating,
                        np.integer,
                    ):
                        raise TaichiRuntimeTypeError.get_ret(return_type.dtype.to_string(), dt)
                    elif return_type.dtype in primitive_types.integer_types and dt not in (int, np.integer):
                        raise TaichiRuntimeTypeError.get_ret(return_type.dtype.to_string(), dt)
                    elif ndim != return_type.ndim:
                        raise TaichiRuntimeTypeError(
                            f"Return matrix ndim mismatch, expecting={return_type.ndim}, got={ndim}."
                        )
                    elif return_type.get_shape() != shape:
                        raise TaichiRuntimeTypeError(
                            f"Return matrix shape mismatch, expecting={return_type.get_shape()}, got={shape}."
                        )
                    values = [values]
                return_exprs += [ti_ops.cast(exp, return_type.dtype) for exp in values]
            elif isinstance(return_type, StructType):
                if not isinstance(ptr, Struct) or not isinstance(ptr, return_type):
                    raise TaichiRuntimeTypeError.get_ret(str(return_type), ptr)
                values = ptr
                assert isinstance(values, Struct)
                return_exprs += expr._get_flattened_ptrs(values)
            else:
                raise TaichiSyntaxError("The return type is not supported now!")
        ctx.ast_builder.create_kernel_exprgroup_return(
            expr.make_expr_group(return_exprs), _ti_core.DebugInfo(ctx.get_pos_info(node))
        )
    else:
        ctx.return_data = node.value.ptr
        return_types = ctx.func.return_type
        if return_types is not None:
            single = len(return_types) == 1
            # ``return a, b`` builds a tuple, which the per-item casts below cannot update in
            # place; work on a list copy and restore the tuple afterwards.
            returned_tuple = not single and isinstance(ctx.return_data, tuple)
            if single:
                ctx.return_data = [ctx.return_data]
            elif returned_tuple:
                ctx.return_data = list(ctx.return_data)
            for i, return_type in enumerate(return_types):
                if id(return_type) in primitive_types.type_ids:
                    ctx.return_data[i] = ti_ops.cast(ctx.return_data[i], return_type)
            if single:
                ctx.return_data = ctx.return_data[0]
            elif returned_tuple:
                ctx.return_data = tuple(ctx.return_data)
    if not ctx.is_real_function:
        ctx.returned = ReturnStatus.ReturnedValue
    return None
1017
+
1018
@staticmethod
def build_Module(ctx: ASTTransformerContext, node: ast.Module) -> None:
    """Build every top-level statement of ``node`` inside a single variable scope."""
    with ctx.variable_scope_guard():
        # Statements are built one by one on purpose: |build_stmts| inserts 'del'
        # statements at the end, which would delete parameters passed into the module.
        for statement in node.body:
            build_stmt(ctx, statement)
    return None
1026
+
1027
@staticmethod
def build_attribute_if_is_dynamic_snode_method(ctx: ASTTransformerContext, node) -> bool:
    """Bind ``append`` / ``deactivate`` / ``length`` of a dynamic SNode to ``node.ptr``.

    ``node`` is either ``x.attr`` (no indices) or ``x[i, ...].attr``. The
    binding applies only when ``x`` is a ``Field`` whose parent SNode is
    dynamic and the field has exactly one more dimension than the given
    indices. Returns True when ``node.ptr`` was set, False otherwise.
    """
    method = node.attr
    if method not in ("append", "deactivate", "length"):
        return False
    if isinstance(node.value, ast.Subscript):
        snode_field, indices = node.value.value.ptr, node.value.slice.ptr
    else:
        snode_field, indices = node.value.ptr, []
    if not isinstance(snode_field, Field) or not snode_field.parent().ptr.type == _ti_core.SNodeType.dynamic:
        return False
    if snode_field.snode.ptr.num_active_indices() != make_expr_group(*indices).size() + 1:
        return False
    # The parent is looked up lazily, when the bound method is actually called.
    bound_methods = {
        "append": lambda val: append(snode_field.parent(), indices, val),
        "deactivate": lambda: deactivate(snode_field.parent(), indices),
        "length": lambda: length(snode_field.parent(), indices),
    }
    node.ptr = bound_methods[method]
    return True
1055
+
1056
@staticmethod
def build_Attribute(ctx: ASTTransformerContext, node: ast.Attribute):
    """Lower ``value.attr`` and store the result in ``node.ptr``.

    Dynamic-SNode methods (``append`` / ``deactivate`` / ``length``) come in
    two forms:

    1. ``x[i, j].append``: the field has one more dimension than the indices.
       Building ``x[i, j]`` then raises a TaichiIndexError because the index
       count is wrong. The error is caught, the node is retried as a
       dynamic-SNode method, and the error is re-raised if that fails.
    2. ``x.append`` on a 1-D field. This builds without error, so the node is
       still checked for this form before being treated as a normal attribute.

    On a Taichi ``Expr`` that lacks ``attr``:
    - swizzles (e.g. ``v.xy``) become subscripts;
    - any other name is resolved from ``taichi.lang.matrix_ops``, with the
      ``Expr`` recorded as ``node.caller``.
    """
    try:
        build_stmt(ctx, node.value)
    except Exception as exc:
        error = handle_exception_from_cpp(exc)
        if isinstance(error, TaichiIndexError):
            node.value.ptr = None
            if ASTTransformer.build_attribute_if_is_dynamic_snode_method(ctx, node):
                return node.ptr
        raise error

    if ASTTransformer.build_attribute_if_is_dynamic_snode_method(ctx, node):
        return node.ptr

    base = node.value.ptr
    attr = node.attr
    if not isinstance(base, Expr) or hasattr(base, attr):
        # Ordinary Python attribute access.
        node.ptr = getattr(base, attr)
        return node.ptr

    if attr not in Matrix._swizzle_to_keygroup:
        from taichi.lang import (  # pylint: disable=C0415
            matrix_ops as tensor_ops,
        )

        # Tensor method such as ``v.norm``: the Expr is passed as the first argument later.
        node.ptr = getattr(tensor_ops, attr)
        setattr(node, "caller", base)
        return node.ptr

    keygroup = Matrix._swizzle_to_keygroup[attr]
    Matrix._keygroup_to_checker[keygroup](base, attr)
    if len(attr) == 1:
        builder = impl.get_runtime().compiling_callable.ast_builder()
        subscript = builder.expr_subscript(
            base.ptr,
            make_expr_group(keygroup.index(attr)),
            _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
        )
    else:
        subscript = _ti_core.subscript_with_multiple_indices(
            base.ptr,
            [make_expr_group(keygroup.index(ch)) for ch in attr],
            (len(attr),),
            _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
        )
    node.ptr = Expr(subscript)
    return node.ptr
1125
+
1126
@staticmethod
def build_BinOp(ctx: ASTTransformerContext, node: ast.BinOp):
    """Lower a binary arithmetic/bitwise operator by applying it to both built operands.

    A ``TypeError`` raised by the operation is re-raised as ``TaichiTypeError``
    without the original traceback context.
    """
    build_stmt(ctx, node.left)
    build_stmt(ctx, node.right)
    # pylint: disable-msg=C0415
    from taichi.lang.matrix_ops import matmul

    binary_ops = {
        ast.Add: lambda lhs, rhs: lhs + rhs,
        ast.Sub: lambda lhs, rhs: lhs - rhs,
        ast.Mult: lambda lhs, rhs: lhs * rhs,
        ast.Div: lambda lhs, rhs: lhs / rhs,
        ast.FloorDiv: lambda lhs, rhs: lhs // rhs,
        ast.Mod: lambda lhs, rhs: lhs % rhs,
        ast.Pow: lambda lhs, rhs: lhs**rhs,
        ast.LShift: lambda lhs, rhs: lhs << rhs,
        ast.RShift: lambda lhs, rhs: lhs >> rhs,
        ast.BitOr: lambda lhs, rhs: lhs | rhs,
        ast.BitXor: lambda lhs, rhs: lhs ^ rhs,
        ast.BitAnd: lambda lhs, rhs: lhs & rhs,
        ast.MatMult: matmul,
    }
    apply_op = binary_ops.get(type(node.op))
    try:
        node.ptr = apply_op(node.left.ptr, node.right.ptr)
    except TypeError as err:
        raise TaichiTypeError(str(err)) from None
    return node.ptr
1153
+
1154
@staticmethod
def build_AugAssign(ctx: ASTTransformerContext, node: ast.AugAssign):
    """Lower ``target op= value`` through the target's ``_augassign`` method.

    Raises:
        TaichiSyntaxError: if the target is a plain name listed in ``ctx.kernel_args``.
    """
    build_stmt(ctx, node.target)
    build_stmt(ctx, node.value)
    target = node.target
    is_kernel_arg = isinstance(target, ast.Name) and target.id in ctx.kernel_args
    if is_kernel_arg:
        raise TaichiSyntaxError(
            f'Kernel argument "{target.id}" is immutable in the kernel. '
            "If you want to change its value, please create a new variable."
        )
    op_name = type(node.op).__name__
    node.ptr = target.ptr._augassign(node.value.ptr, op_name)
    return node.ptr
1165
+
1166
@staticmethod
def build_UnaryOp(ctx: ASTTransformerContext, node: ast.UnaryOp):
    """Lower ``+x``, ``-x``, ``not x`` (via ``ti_ops.logical_not``) and ``~x``."""
    build_stmt(ctx, node.operand)
    unary_ops = {
        ast.UAdd: lambda operand: operand,
        ast.USub: lambda operand: -operand,
        ast.Not: ti_ops.logical_not,
        ast.Invert: lambda operand: ~operand,
    }
    apply_op = unary_ops.get(type(node.op))
    node.ptr = apply_op(node.operand.ptr)
    return node.ptr
1177
+
1178
@staticmethod
def build_bool_op(op):
    """Return a folder combining operand ``ptr`` values right-associatively with ``op``.

    ``[a, b, c]`` becomes ``op(a, op(b, c))``. A single operand is returned as is.
    """

    def inner(operands):
        folded = operands[-1].ptr
        for operand in reversed(operands[:-1]):
            folded = op(operand.ptr, folded)
        return folded

    return inner
1186
+
1187
@staticmethod
def build_static_and(operands):
    """Python ``and`` semantics over operand ``ptr`` values: the first falsy one, else the last."""
    falsy_values = (operand.ptr for operand in operands if not operand.ptr)
    return next(falsy_values, operands[-1].ptr)
1193
+
1194
@staticmethod
def build_static_or(operands):
    """Python ``or`` semantics over operand ``ptr`` values: the first truthy one, else the last."""
    truthy_values = (operand.ptr for operand in operands if operand.ptr)
    return next(truthy_values, operands[-1].ptr)
1200
+
1201
@staticmethod
def build_BoolOp(ctx: ASTTransformerContext, node: ast.BoolOp):
    """Lower ``and`` / ``or`` chains.

    In a static scope, Python semantics are used. Otherwise the operands are
    folded with ``logical_and`` / ``logical_or`` when the runtime enables
    ``short_circuit_operators``, and with ``bit_and`` / ``bit_or`` when it
    does not.
    """
    build_stmts(ctx, node.values)
    if ctx.is_in_static_scope():
        and_impl = ASTTransformer.build_static_and
        or_impl = ASTTransformer.build_static_or
    elif impl.get_runtime().short_circuit_operators:
        and_impl = ASTTransformer.build_bool_op(ti_ops.logical_and)
        or_impl = ASTTransformer.build_bool_op(ti_ops.logical_or)
    else:
        and_impl = ASTTransformer.build_bool_op(ti_ops.bit_and)
        or_impl = ASTTransformer.build_bool_op(ti_ops.bit_or)
    combine = {ast.And: and_impl, ast.Or: or_impl}.get(type(node.op))
    node.ptr = combine(node.values)
    return node.ptr
1222
+
1223
@staticmethod
def build_Compare(ctx: ASTTransformerContext, node: ast.Compare):
    """Lower a (possibly chained) comparison.

    ``a < b < c`` is folded as ``logical_and(logical_and(True, a < b), b < c)``.
    ``in`` / ``not in`` are accepted only inside ``ti.static``, and ``is`` /
    ``is not`` are always rejected. A result that is not a Python/NumPy bool
    is cast to ``u1``.
    """
    build_stmt(ctx, node.left)
    build_stmts(ctx, node.comparators)
    comparison_ops = {
        ast.Eq: lambda lhs, rhs: lhs == rhs,
        ast.NotEq: lambda lhs, rhs: lhs != rhs,
        ast.Lt: lambda lhs, rhs: lhs < rhs,
        ast.LtE: lambda lhs, rhs: lhs <= rhs,
        ast.Gt: lambda lhs, rhs: lhs > rhs,
        ast.GtE: lambda lhs, rhs: lhs >= rhs,
    }
    membership_ops = {
        ast.In: lambda lhs, rhs: lhs in rhs,
        ast.NotIn: lambda lhs, rhs: lhs not in rhs,
    }
    if ctx.is_in_static_scope():
        comparison_ops.update(membership_ops)
    operands = [node.left.ptr, *(comparator.ptr for comparator in node.comparators)]
    result = True
    for position, node_op in enumerate(node.ops):
        if isinstance(node_op, (ast.Is, ast.IsNot)):
            keyword = "is" if isinstance(node_op, ast.Is) else "is not"
            raise TaichiSyntaxError(f'Operator "{keyword}" in Taichi scope is not supported.')
        compare = comparison_ops.get(type(node_op))
        if compare is None:
            if type(node_op) in membership_ops:
                raise TaichiSyntaxError(f'"{type(node_op).__name__}" is only supported inside `ti.static`.')
            raise TaichiSyntaxError(f'"{type(node_op).__name__}" is not supported in Taichi kernels.')
        lhs, rhs = operands[position], operands[position + 1]
        result = ti_ops.logical_and(result, compare(lhs, rhs))
    if not isinstance(result, (bool, np.bool_)):
        result = ti_ops.cast(result, primitive_types.u1)
    node.ptr = result
    return node.ptr
1261
+
1262
@staticmethod
def get_decorator(ctx: ASTTransformerContext, node) -> str:
    """Name the special Taichi helper a call resolves to, or ``""``.

    Recognized callees are ``ti.static``, ``ti.static_assert``, ``ti.grouped``
    and ``ti.ndrange``. They are resolved through ``ctx.global_vars`` and
    checked in that order.
    """
    if not isinstance(node, ast.Call):
        return ""
    candidates = (
        (impl.static, "static"),
        (impl.static_assert, "static_assert"),
        (impl.grouped, "grouped"),
        (ndrange, "ndrange"),
    )
    return next(
        (label for target, label in candidates if ASTResolver.resolve_to(node.func, target, ctx.global_vars)),
        "",
    )
1275
+
1276
+ @staticmethod
1277
+ def get_for_loop_targets(node: ast.Name | ast.Tuple | Any) -> list:
1278
+ """
1279
+ Returns the list of indices of the for loop |node|.
1280
+ See also: https://docs.python.org/3/library/ast.html#ast.For
1281
+ """
1282
+ if isinstance(node.target, ast.Name):
1283
+ return [node.target.id]
1284
+ assert isinstance(node.target, ast.Tuple)
1285
+ return [name.id for name in node.target.elts]
1286
+
1287
@staticmethod
def build_static_for(ctx: ASTTransformerContext, node: ast.For, is_grouped: bool) -> None:
    """Unroll a ``ti.static`` for loop at compile time.

    The body is built once per iteration. Each build happens in a fresh
    variable scope, with the loop targets bound to the current values.
    ``break`` / ``continue`` are honoured through the context's loop status.
    A SyntaxWarning is emitted once when the iteration count exceeds the
    runtime's ``unrolling_limit``; a falsy limit disables the warning. With
    ``is_grouped`` the iterable must be ``ti.grouped(ti.ndrange(...))`` and
    exactly one target is allowed.
    """
    unroll_limit = impl.get_runtime().unrolling_limit
    warned = False

    def warn_if_unrolling_too_much(iteration: int) -> None:
        # Warn only once per loop, and only when a limit is set and exceeded.
        nonlocal warned
        if warned or not unroll_limit or not (iteration > unroll_limit):
            return
        warned = True
        warnings.warn_explicit(
            f"""You are unrolling more than
{unroll_limit} iterations, so the compile time may be extremely long.
You can use a non-static for loop if you want to decrease the compile time.
You can disable this warning by setting ti.init(unrolling_limit=0).""",
            SyntaxWarning,
            ctx.file,
            node.lineno + ctx.lineno_offset,
            module="taichi",
        )

    def unroll_one_iteration(bindings) -> bool:
        # Build the body once with (name, value) ``bindings`` in scope; True means ``break``.
        with ctx.variable_scope_guard():
            for target_name, target_value in bindings:
                ctx.create_variable(target_name, target_value)
            build_stmts(ctx, node.body)
            status = ctx.loop_status()
            if status == LoopStatus.Break:
                return True
            if status == LoopStatus.Continue:
                ctx.set_loop_status(LoopStatus.Normal)
        return False

    if is_grouped:
        assert len(node.iter.args[0].args) == 1
        ndrange_arg = build_stmt(ctx, node.iter.args[0].args[0])
        if not isinstance(ndrange_arg, _Ndrange):
            raise TaichiSyntaxError("Only 'ti.ndrange' is allowed in 'ti.static(ti.grouped(...))'.")
        targets = ASTTransformer.get_for_loop_targets(node)
        if len(targets) != 1:
            raise TaichiSyntaxError(f"Group for should have 1 loop target, found {len(targets)}")
        (target,) = targets
        for iteration, value in enumerate(impl.grouped(ndrange_arg), start=1):
            warn_if_unrolling_too_much(iteration)
            if unroll_one_iteration([(target, value)]):
                break
    else:
        build_stmt(ctx, node.iter)
        targets = ASTTransformer.get_for_loop_targets(node)
        for iteration, target_values in enumerate(node.iter.ptr, start=1):
            # A single target (or a non-sequence item) binds the whole item.
            if not isinstance(target_values, collections.abc.Sequence) or len(targets) == 1:
                target_values = [target_values]
            warn_if_unrolling_too_much(iteration)
            if unroll_one_iteration(zip(targets, target_values)):
                break
    return None
1359
+
1360
@staticmethod
def build_range_for(ctx: ASTTransformerContext, node: ast.For) -> None:
    """
    Lowers ``for i in range(...)`` into a frontend range-for.

    Accepts ``range(end)`` and ``range(begin, end)``. Each user-supplied bound first goes through
    the implicit-dtype-conversion warning check, and then both bounds are cast to i32.

    Raises:
        TaichiSyntaxError: if ``range`` is given a number of arguments other than 1 or 2.
    """
    range_args = node.iter.args
    with ctx.variable_scope_guard():
        index_name = node.target.id
        ctx.check_loop_var(index_name)
        index_var = expr.Expr(ctx.ast_builder.make_id_expr(""))
        ctx.create_variable(index_name, index_var)

        arg_count = len(range_args)
        if arg_count not in [1, 2]:
            raise TaichiSyntaxError(f"Range should have 1 or 2 arguments, found {len(node.iter.args)}")

        if arg_count == 1:
            upper = expr.Expr(build_stmt(ctx, range_args[0]))
            # Warn about implicit dtype conversion before forcing i32.
            boundary_type_cast_warning(upper)
            lower_i32 = ti_ops.cast(expr.Expr(0), primitive_types.i32)
            upper_i32 = ti_ops.cast(upper, primitive_types.i32)
        else:
            lower = expr.Expr(build_stmt(ctx, range_args[0]))
            upper = expr.Expr(build_stmt(ctx, range_args[1]))
            # Warn about implicit dtype conversion before forcing i32.
            boundary_type_cast_warning(lower)
            boundary_type_cast_warning(upper)
            lower_i32 = ti_ops.cast(lower, primitive_types.i32)
            upper_i32 = ti_ops.cast(upper, primitive_types.i32)

        debug_info = _ti_core.DebugInfo(ctx.get_pos_info(node))
        ctx.ast_builder.begin_frontend_range_for(index_var.ptr, lower_i32.ptr, upper_i32.ptr, debug_info)
        build_stmts(ctx, node.body)
        ctx.ast_builder.end_frontend_range_for()
    return None
1394
+
1395
@staticmethod
def build_ndrange_for(ctx: ASTTransformerContext, node: ast.For) -> None:
    """
    Lowers ``for i, j, ... in ti.ndrange(...)`` into one flat frontend range-for.

    A single i32 counter runs from 0 to ``acc_dimensions[0]``. The counter is decoded into one
    index per loop target: for each target except the last, the index is the floor division by
    ``acc_dimensions[i + 1]``, and that part is then subtracted from the counter. The last target
    takes what remains. Each decoded value is offset by ``bounds[i][0]`` before it is bound to
    the target.
    NOTE(review): acc_dimensions presumably holds suffix products of the per-axis extents, and
    bounds[i][0] the lower bound of axis i — confirm against _Ndrange.

    Raises:
        TaichiSyntaxError: if the number of loop targets differs from the ndrange's dimension count.
    """
    with ctx.variable_scope_guard():
        ndrange_var = impl.expr_init(build_stmt(ctx, node.iter))
        # Flat iteration space: [0, acc_dimensions[0]).
        ndrange_begin = ti_ops.cast(expr.Expr(0), primitive_types.i32)
        ndrange_end = ti_ops.cast(
            expr.Expr(impl.subscript(ctx.ast_builder, ndrange_var.acc_dimensions, 0)),
            primitive_types.i32,
        )
        ndrange_loop_var = expr.Expr(ctx.ast_builder.make_id_expr(""))
        for_di = _ti_core.DebugInfo(ctx.get_pos_info(node))
        ctx.ast_builder.begin_frontend_range_for(ndrange_loop_var.ptr, ndrange_begin.ptr, ndrange_end.ptr, for_di)
        # Mutable copy of the flat counter; it shrinks as each index is peeled off below.
        I = impl.expr_init(ndrange_loop_var)
        targets = ASTTransformer.get_for_loop_targets(node)
        if len(targets) != len(ndrange_var.dimensions):
            raise TaichiSyntaxError(
                "Ndrange for loop with number of the loop variables not equal to "
                "the dimension of the ndrange is not supported. "
                "Please check if the number of arguments of ti.ndrange() is equal to "
                "the number of the loop variables."
            )
        for i, target in enumerate(targets):
            if i + 1 < len(targets):
                # Decode axis i from the remaining counter.
                target_tmp = impl.expr_init(I // ndrange_var.acc_dimensions[i + 1])
            else:
                # Last axis: whatever remains of the counter.
                target_tmp = impl.expr_init(I)
            ctx.create_variable(
                target,
                impl.expr_init(
                    target_tmp
                    + impl.subscript(
                        ctx.ast_builder,
                        impl.subscript(ctx.ast_builder, ndrange_var.bounds, i),
                        0,
                    )
                ),
            )
            if i + 1 < len(targets):
                # Remove the part of the counter consumed by axis i.
                I._assign(I - target_tmp * ndrange_var.acc_dimensions[i + 1])
        build_stmts(ctx, node.body)
        ctx.ast_builder.end_frontend_range_for()
    return None
1437
+
1438
@staticmethod
def build_grouped_ndrange_for(ctx: ASTTransformerContext, node: ast.For) -> None:
    """
    Lowers ``for I in ti.grouped(ti.ndrange(...))`` into one flat frontend range-for.

    The decoding matches build_ndrange_for: a flat i32 counter over ``[0, acc_dimensions[0])`` is
    split per axis using ``acc_dimensions[i + 1]``. Here, though, the results are written into
    the components of one i32 vector bound to the single loop target. Component ``i`` is offset
    by ``bounds[i][0]``.

    Raises:
        TaichiSyntaxError: if the loop does not have exactly one target.
    """
    with ctx.variable_scope_guard():
        # node.iter is `ti.grouped(<ndrange>)`; build the inner ndrange argument.
        ndrange_var = impl.expr_init(build_stmt(ctx, node.iter.args[0]))
        ndrange_begin = ti_ops.cast(expr.Expr(0), primitive_types.i32)
        ndrange_end = ti_ops.cast(
            expr.Expr(impl.subscript(ctx.ast_builder, ndrange_var.acc_dimensions, 0)),
            primitive_types.i32,
        )
        ndrange_loop_var = expr.Expr(ctx.ast_builder.make_id_expr(""))
        for_di = _ti_core.DebugInfo(ctx.get_pos_info(node))
        ctx.ast_builder.begin_frontend_range_for(ndrange_loop_var.ptr, ndrange_begin.ptr, ndrange_end.ptr, for_di)

        targets = ASTTransformer.get_for_loop_targets(node)
        if len(targets) != 1:
            raise TaichiSyntaxError(f"Group for should have 1 loop target, found {len(targets)}")
        target = targets[0]
        # The loop target is an i32 vector with one component per ndrange dimension.
        mat = matrix.make_matrix([0] * len(ndrange_var.dimensions), dt=primitive_types.i32)
        target_var = impl.expr_init(mat)

        ctx.create_variable(target, target_var)
        # Mutable copy of the flat counter; it shrinks as each component is peeled off below.
        I = impl.expr_init(ndrange_loop_var)
        for i in range(len(ndrange_var.dimensions)):
            if i + 1 < len(ndrange_var.dimensions):
                target_tmp = I // ndrange_var.acc_dimensions[i + 1]
            else:
                # Last axis: whatever remains of the counter.
                target_tmp = I
            impl.subscript(ctx.ast_builder, target_var, i)._assign(target_tmp + ndrange_var.bounds[i][0])
            if i + 1 < len(ndrange_var.dimensions):
                I._assign(I - target_tmp * ndrange_var.acc_dimensions[i + 1])
        build_stmts(ctx, node.body)
        ctx.ast_builder.end_frontend_range_for()
    return None
1471
+
1472
@staticmethod
def build_struct_for(ctx: ASTTransformerContext, node: ast.For, is_grouped: bool) -> None:
    """
    Lowers a struct-for over the iterable of |node|.

    Two surface forms are supported:
        for i, j in x            -> one scalar index variable per loop target
        for I in ti.grouped(x)   -> a single i32 vector index sized by ``len(x.shape)``

    Raises:
        TaichiSyntaxError: if a grouped loop does not have exactly one target.
    """
    loop_targets = ASTTransformer.get_for_loop_targets(node)
    for loop_target in loop_targets:
        ctx.check_loop_var(loop_target)

    with ctx.variable_scope_guard():
        if not is_grouped:
            # One fresh index variable per target, registered as it is created.
            index_vars = []
            for target_name in loop_targets:
                index_var = expr.Expr(ctx.ast_builder.make_id_expr(""))
                index_vars.append(index_var)
                ctx.create_variable(target_name, index_var)
            iterable = node.iter.ptr
            indices_group = expr.make_expr_group(*index_vars)
            impl.begin_frontend_struct_for(ctx.ast_builder, indices_group, iterable)
        else:
            if len(loop_targets) != 1:
                raise TaichiSyntaxError(f"Group for should have 1 loop target, found {len(loop_targets)}")
            iterable = build_stmt(ctx, node.iter)
            index_list = expr.make_var_list(size=len(iterable.shape), ast_builder=ctx.ast_builder)
            indices_group = expr.make_expr_group(index_list)
            impl.begin_frontend_struct_for(ctx.ast_builder, indices_group, iterable)
            ctx.create_variable(loop_targets[0], matrix.make_matrix(index_list, dt=primitive_types.i32))
        # The body and the closing of the struct-for are the same for both forms.
        build_stmts(ctx, node.body)
        ctx.ast_builder.end_frontend_struct_for()
    return None
1505
+
1506
@staticmethod
def build_mesh_for(ctx: ASTTransformerContext, node: ast.For) -> None:
    """
    Lowers a loop over a mesh element field (a ``mesh.MeshElementField`` iterator).

    The single loop target is bound to a ``MeshElementFieldProxy`` built over a fresh index
    variable. While the body is built, ``ctx.mesh`` is set to the field's mesh; it is reset to
    None once the body has been built.

    Raises:
        TaichiSyntaxError: if the loop does not have exactly one target.
    """
    targets = ASTTransformer.get_for_loop_targets(node)
    if len(targets) != 1:
        # Must be an f-string so the actual count is reported.
        raise TaichiSyntaxError(f"Mesh for should have 1 loop target, found {len(targets)}")
    target = targets[0]

    with ctx.variable_scope_guard():
        var = expr.Expr(ctx.ast_builder.make_id_expr(""))
        ctx.mesh = node.iter.ptr.mesh
        assert isinstance(ctx.mesh, impl.MeshInstance)
        mesh_idx = mesh.MeshElementFieldProxy(ctx.mesh, node.iter.ptr._type, var.ptr)
        ctx.create_variable(target, mesh_idx)
        ctx.ast_builder.begin_frontend_mesh_for(
            mesh_idx.ptr,
            ctx.mesh.mesh_ptr,
            node.iter.ptr._type,
            _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
        )
        build_stmts(ctx, node.body)
        ctx.mesh = None
        ctx.ast_builder.end_frontend_mesh_for()
    return None
1529
+
1530
@staticmethod
def build_nested_mesh_for(ctx: ASTTransformerContext, node: ast.For) -> None:
    """
    Lowers a loop over a mesh relation access (a ``mesh.MeshRelationAccessProxy`` iterator).

    A frontend range-for runs over ``[0, node.iter.ptr.size)``, cast to i32. Inside it, the
    single loop target is bound to a ``MeshElementFieldProxy`` for the element returned by
    ``_ti_core.get_relation_access`` at the current index.

    Raises:
        TaichiSyntaxError: if the loop does not have exactly one target.
    """
    targets = ASTTransformer.get_for_loop_targets(node)
    if len(targets) != 1:
        # Must be an f-string so the actual count is reported.
        raise TaichiSyntaxError(f"Nested-mesh for should have 1 loop target, found {len(targets)}")
    target = targets[0]

    with ctx.variable_scope_guard():
        # NOTE(review): unlike build_mesh_for, ctx.mesh is not reset afterwards. This is presumably
        # because this loop runs inside an enclosing mesh-for that owns ctx.mesh — confirm.
        ctx.mesh = node.iter.ptr.mesh
        assert isinstance(ctx.mesh, impl.MeshInstance)
        # Hidden index variable named after the validated target. This keeps it well-defined
        # even for a one-element tuple target such as `for (e,) in ...`.
        loop_name = target + "_index__"
        loop_var = expr.Expr(ctx.ast_builder.make_id_expr(""))
        ctx.create_variable(loop_name, loop_var)
        begin = expr.Expr(0)
        end = ti_ops.cast(node.iter.ptr.size, primitive_types.i32)
        for_di = _ti_core.DebugInfo(ctx.get_pos_info(node))
        ctx.ast_builder.begin_frontend_range_for(loop_var.ptr, begin.ptr, end.ptr, for_di)
        entry_expr = _ti_core.get_relation_access(
            ctx.mesh.mesh_ptr,
            node.iter.ptr.from_index.ptr,
            node.iter.ptr.to_element_type,
            loop_var.ptr,
        )
        entry_expr.type_check(impl.get_runtime().prog.config())
        mesh_idx = mesh.MeshElementFieldProxy(ctx.mesh, node.iter.ptr.to_element_type, entry_expr)
        ctx.create_variable(target, mesh_idx)
        build_stmts(ctx, node.body)
        ctx.ast_builder.end_frontend_range_for()

    return None
1560
+
1561
@staticmethod
def build_For(ctx: ASTTransformerContext, node: ast.For) -> None:
    """
    Dispatches a ``for`` statement to the matching loop lowering.

    The iterator expression (and the decorator directly inside it) selects the loop kind:
        ti.static(...)               -> build_static_for (compile-time unrolling)
        ti.ndrange(...)              -> build_ndrange_for
        ti.grouped(ti.ndrange(...))  -> build_grouped_ndrange_for
        ti.grouped(x)                -> build_struct_for(is_grouped=True)
        range(...)                   -> build_range_for
        mesh element field           -> build_mesh_for
        mesh relation access         -> build_nested_mesh_for
        anything else                -> build_struct_for(is_grouped=False)

    Raises:
        TaichiSyntaxError: for a ``for ... else`` clause, or an unsupported nesting of
            ``ti.static`` / ``ti.grouped`` / ``ti.ndrange``.
        Exception: if the iterator is a mesh field but the backend lacks the mesh extension.
    """
    if node.orelse:
        raise TaichiSyntaxError("'else' clause for 'for' not supported in Taichi kernels")
    decorator = ASTTransformer.get_decorator(ctx, node.iter)
    double_decorator = ""
    # A non-empty decorator implies node.iter is a call, so `.args` exists.
    if decorator != "" and len(node.iter.args) == 1:
        double_decorator = ASTTransformer.get_decorator(ctx, node.iter.args[0])

    if decorator == "static":
        if double_decorator == "static":
            raise TaichiSyntaxError("'ti.static' cannot be nested")
        with ctx.loop_scope_guard(is_static=True):
            return ASTTransformer.build_static_for(ctx, node, double_decorator == "grouped")
    with ctx.loop_scope_guard():
        if decorator == "ndrange":
            if double_decorator != "":
                raise TaichiSyntaxError("No decorator is allowed inside 'ti.ndrange'")
            return ASTTransformer.build_ndrange_for(ctx, node)
        if decorator == "grouped":
            if double_decorator == "static":
                raise TaichiSyntaxError("'ti.static' is not allowed inside 'ti.grouped'")
            elif double_decorator == "ndrange":
                return ASTTransformer.build_grouped_ndrange_for(ctx, node)
            elif double_decorator == "grouped":
                raise TaichiSyntaxError("'ti.grouped' cannot be nested")
            else:
                return ASTTransformer.build_struct_for(ctx, node, is_grouped=True)
        elif (
            isinstance(node.iter, ast.Call)
            and isinstance(node.iter.func, ast.Name)
            and node.iter.func.id == "range"
        ):
            return ASTTransformer.build_range_for(ctx, node)
        else:
            build_stmt(ctx, node.iter)
            if isinstance(node.iter.ptr, mesh.MeshElementField):
                if not _ti_core.is_extension_supported(impl.default_cfg().arch, _ti_core.Extension.mesh):
                    raise Exception(
                        "Backend " + str(impl.default_cfg().arch) + " doesn't support MeshTaichi extension"
                    )
                return ASTTransformer.build_mesh_for(ctx, node)
            if isinstance(node.iter.ptr, mesh.MeshRelationAccessProxy):
                return ASTTransformer.build_nested_mesh_for(ctx, node)
            # Struct for
            return ASTTransformer.build_struct_for(ctx, node, is_grouped=False)
1607
+
1608
@staticmethod
def build_While(ctx: ASTTransformerContext, node: ast.While) -> None:
    """
    Lowers ``while cond: body`` into the frontend IR.

    The emitted shape is a frontend while over a constant i32 ``1``. Its first statement is an
    ``if`` on ``cond`` with an empty true branch and a ``break`` in the false branch, followed by
    the body.

    Raises:
        TaichiSyntaxError: for a ``while ... else`` clause.
    """
    if node.orelse:
        raise TaichiSyntaxError("'else' clause for 'while' not supported in Taichi kernels")

    with ctx.loop_scope_guard():
        dbg = _ti_core.DebugInfo(ctx.get_pos_info(node))
        always_true = expr.Expr(1, dtype=primitive_types.i32)
        ctx.ast_builder.begin_frontend_while(always_true.ptr, dbg)

        # Guard at the top of each iteration: leave the loop when the condition is false.
        condition = build_stmt(ctx, node.test)
        impl.begin_frontend_if(ctx.ast_builder, condition, dbg)
        ctx.ast_builder.begin_frontend_if_true()
        ctx.ast_builder.pop_scope()
        ctx.ast_builder.begin_frontend_if_false()
        ctx.ast_builder.insert_break_stmt(dbg)
        ctx.ast_builder.pop_scope()

        build_stmts(ctx, node.body)
        ctx.ast_builder.pop_scope()
    return None
1626
+
1627
@staticmethod
def build_If(ctx: ASTTransformerContext, node: ast.If) -> ast.If | None:
    """
    Transforms an ``if`` statement.

    For ``if ti.static(cond):``, only the branch selected by the compile-time value is built,
    and |node| is returned. Otherwise a frontend if/else is emitted with both branches, and the
    function returns None.
    """
    build_stmt(ctx, node.test)

    if ASTTransformer.get_decorator(ctx, node.test) == "static":
        chosen_branch = node.body if node.test.ptr else node.orelse
        build_stmts(ctx, chosen_branch)
        return node

    with ctx.non_static_if_guard(node):
        dbg = _ti_core.DebugInfo(ctx.get_pos_info(node))
        impl.begin_frontend_if(ctx.ast_builder, node.test.ptr, dbg)
        # True branch.
        ctx.ast_builder.begin_frontend_if_true()
        build_stmts(ctx, node.body)
        ctx.ast_builder.pop_scope()
        # False branch.
        ctx.ast_builder.begin_frontend_if_false()
        build_stmts(ctx, node.orelse)
        ctx.ast_builder.pop_scope()
    return None
1649
+
1650
@staticmethod
def build_Expr(ctx: ASTTransformerContext, node: ast.Expr) -> None:
    """Transforms an expression statement; the value it produces is discarded."""
    build_stmt(ctx, node.value)
1654
+
1655
@staticmethod
def build_IfExp(ctx: ASTTransformerContext, node: ast.IfExp):
    """
    Transforms a conditional expression ``body if test else orelse``.

    All three operands are built first. Then:
      * if any operand is a tensor expression, the result is ``ti_ops.select(...)``; a tensor
        *condition* is rejected, because element-wise selection must use ``ti.select``;
      * if the test is ``ti.static(...)``, the chosen operand is built again and returned;
      * otherwise the result is ``ti_ops.ifte(...)``.

    The result is stored on ``node.ptr`` and returned.
    """
    for operand in (node.test, node.body, node.orelse):
        build_stmt(ctx, operand)

    def _is_tensor_expr(value):
        return isinstance(value, expr.Expr) and value.is_tensor()

    test_is_tensor = _is_tensor_expr(node.test.ptr)
    tensor_flags = [test_is_tensor, _is_tensor_expr(node.body.ptr), _is_tensor_expr(node.orelse.ptr)]

    if any(tensor_flags):
        if test_is_tensor:
            raise TaichiSyntaxError(
                "Using conditional expression for element-wise select operation on "
                "Taichi vectors/matrices is deprecated and removed starting from Taichi v1.5.0 "
                'Please use "ti.select" instead.'
            )
        node.ptr = ti_ops.select(node.test.ptr, node.body.ptr, node.orelse.ptr)
        return node.ptr

    if ASTTransformer.get_decorator(ctx, node.test) == "static":
        chosen = node.body if node.test.ptr else node.orelse
        node.ptr = build_stmt(ctx, chosen)
        return node.ptr

    node.ptr = ti_ops.ifte(node.test.ptr, node.body.ptr, node.orelse.ptr)
    return node.ptr
1690
+
1691
+ @staticmethod
1692
+ def _is_string_mod_args(msg) -> bool:
1693
+ # 1. str % (a, b, c, ...)
1694
+ # 2. str % single_item
1695
+ # Note that |msg.right| may not be a tuple.
1696
+ if not isinstance(msg, ast.BinOp):
1697
+ return False
1698
+ if not isinstance(msg.op, ast.Mod):
1699
+ return False
1700
+ if isinstance(msg.left, ast.Str):
1701
+ return True
1702
+ if isinstance(msg.left, ast.Constant) and isinstance(msg.left.value, str):
1703
+ return True
1704
+ return False
1705
+
1706
@staticmethod
def _handle_string_mod_args(ctx: ASTTransformerContext, node):
    """
    Splits a ``"fmt" % args`` node into the format string and a list of argument pointers.

    A right operand that is not a sequence is treated as a single argument. Every argument is
    wrapped with ``expr.Expr(...)``, and its ``.ptr`` is collected.

    Returns:
        ``(fmt, [arg.ptr, ...])``.
    """
    fmt = build_stmt(ctx, node.left)
    values = build_stmt(ctx, node.right)
    if isinstance(values, collections.abc.Sequence):
        items = values
    else:
        # `"fmt" % x` with a single, non-sequence operand.
        items = (values,)
    return fmt, [expr.Expr(item).ptr for item in items]
1714
+
1715
@staticmethod
def ti_format_list_to_assert_msg(raw) -> tuple[str, list]:
    """
    Converts a ``"__ti_format__"`` list into a printf-style assertion message and its arguments.

    String entries are copied into the message verbatim. Each expression entry contributes
    ``%f`` (real types) or ``%d`` (integer types) to the message and is appended to the
    argument list.

    Returns:
        ``(msg, args)``.

    Raises:
        TaichiSyntaxError: for an expression whose type is neither real nor integer, or for an
            entry that is neither a string nor an expression.
    """
    # TODO: ignore formats here for now
    entries, _ = impl.ti_format_list_to_content_entries([raw])
    parts = []
    args = []
    for entry in entries:
        if isinstance(entry, str):
            parts.append(entry)
        elif isinstance(entry, _ti_core.Expr):
            ty = entry.get_rvalue_type()
            if ty in primitive_types.real_types:
                parts.append("%f")
            elif ty in primitive_types.integer_types:
                parts.append("%d")
            else:
                # Report the offending type itself; type(ty) only named the class of the type object.
                raise TaichiSyntaxError(f"Unsupported data type: {ty}")
            args.append(entry)
        else:
            raise TaichiSyntaxError(f"Unsupported type: {type(entry)}")
    return "".join(parts), args
1736
+
1737
@staticmethod
def build_Assert(ctx: ASTTransformerContext, node: ast.Assert) -> None:
    """
    Lowers ``assert test[, msg]`` into a call to ``impl.ti_assert``.

    The message may be:
      * absent: the unparsed source of the test expression is used;
      * a ``"fmt" % args`` expression: the args become extra runtime arguments;
      * a constant: converted with ``str``;
      * a list whose first element is ``"__ti_format__"`` (presumably produced by ``ti.format``):
        converted by ``ti_format_list_to_assert_msg`` into a message plus extra arguments.
    The message is whitespace-stripped before being passed on.

    Raises:
        TaichiSyntaxError: for any other kind of message expression.
    """
    extra_args = []
    if node.msg is not None:
        if ASTTransformer._is_string_mod_args(node.msg):
            msg, extra_args = ASTTransformer._handle_string_mod_args(ctx, node.msg)
        else:
            msg = build_stmt(ctx, node.msg)
            # String literals are ast.Constant (Python 3.8+) and are handled here. ast.Str must not
            # be referenced: it warns on access in 3.12+ and no longer exists in 3.14.
            if isinstance(node.msg, ast.Constant):
                msg = str(msg)
            elif isinstance(msg, collections.abc.Sequence) and len(msg) > 0 and msg[0] == "__ti_format__":
                msg, extra_args = ASTTransformer.ti_format_list_to_assert_msg(msg)
            else:
                raise TaichiSyntaxError(f"assert info must be constant or formatted string, not {type(msg)}")
    else:
        msg = unparse(node.test)
    test = build_stmt(ctx, node.test)
    impl.ti_assert(test, msg.strip(), extra_args, _ti_core.DebugInfo(ctx.get_pos_info(node)))
    return None
1758
+
1759
@staticmethod
def build_Break(ctx: ASTTransformerContext, node: ast.Break) -> None:
    """
    Handles ``break``.

    Outside a static for loop, a frontend break statement is emitted. Inside a static for loop,
    the loop status is set to ``LoopStatus.Break`` so that the unroller stops. That is rejected
    when the ``break`` sits under a non-static ``if``, because the condition is not known at
    compile time.

    Raises:
        TaichiSyntaxError: for a ``break`` out of a static loop from inside a non-static ``if``.
    """
    if not ctx.is_in_static_for():
        ctx.ast_builder.insert_break_stmt(_ti_core.DebugInfo(ctx.get_pos_info(node)))
        return None

    enclosing_if = ctx.current_loop_scope().nearest_non_static_if
    if enclosing_if:
        raise TaichiSyntaxError(
            ctx.get_pos_info(enclosing_if.test)
            + "You are trying to `break` a static `for` loop, "
            "but the `break` statement is inside a non-static `if`. "
        )
    ctx.set_loop_status(LoopStatus.Break)
    return None
1774
+
1775
@staticmethod
def build_Continue(ctx: ASTTransformerContext, node: ast.Continue) -> None:
    """
    Handles ``continue``.

    Outside a static for loop, a frontend continue statement is emitted. Inside a static for
    loop, the loop status is set to ``LoopStatus.Continue`` so that the unroller skips the rest
    of the body. That is rejected when the ``continue`` sits under a non-static ``if``.

    Raises:
        TaichiSyntaxError: for a ``continue`` in a static loop from inside a non-static ``if``.
    """
    if not ctx.is_in_static_for():
        ctx.ast_builder.insert_continue_stmt(_ti_core.DebugInfo(ctx.get_pos_info(node)))
        return None

    enclosing_if = ctx.current_loop_scope().nearest_non_static_if
    if enclosing_if:
        raise TaichiSyntaxError(
            ctx.get_pos_info(enclosing_if.test)
            + "You are trying to `continue` a static `for` loop, "
            "but the `continue` statement is inside a non-static `if`. "
        )
    ctx.set_loop_status(LoopStatus.Continue)
    return None
1790
+
1791
@staticmethod
def build_Pass(ctx: ASTTransformerContext, node: ast.Pass) -> None:
    """Handles ``pass``: nothing is emitted."""
1794
+
1795
+
1796
# Shared transformer instance. The builders in this module call it as `build_stmt(ctx, node)`
# to transform a single AST node.
build_stmt = ASTTransformer()
1797
+
1798
+
1799
def build_stmts(ctx: ASTTransformerContext, stmts: list):
    """
    Builds a statement list inside its own variable scope.

    Building stops early once the context records a return (``ctx.returned`` is no longer
    ``ReturnStatus.NoReturn``) or a pending break/continue (``ctx.loop_status()`` is no longer
    ``LoopStatus.Normal``). Returns |stmts| unchanged.
    """
    with ctx.variable_scope_guard():
        for statement in stmts:
            should_stop = ctx.returned != ReturnStatus.NoReturn or ctx.loop_status() != LoopStatus.Normal
            if should_stop:
                break
            build_stmt(ctx, statement)
    return stmts