gstaichi 2.1.1__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (178) hide show
  1. gstaichi/__init__.py +40 -0
  2. gstaichi/_funcs.py +706 -0
  3. gstaichi/_kernels.py +420 -0
  4. gstaichi/_lib/__init__.py +3 -0
  5. gstaichi/_lib/core/__init__.py +0 -0
  6. gstaichi/_lib/core/gstaichi_python.cpython-311-darwin.so +0 -0
  7. gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
  8. gstaichi/_lib/core/py.typed +0 -0
  9. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  10. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  11. gstaichi/_lib/utils.py +243 -0
  12. gstaichi/_logging.py +131 -0
  13. gstaichi/_snode/__init__.py +5 -0
  14. gstaichi/_snode/fields_builder.py +187 -0
  15. gstaichi/_snode/snode_tree.py +34 -0
  16. gstaichi/_test_tools/__init__.py +18 -0
  17. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  18. gstaichi/_test_tools/load_kernel_string.py +30 -0
  19. gstaichi/_test_tools/textwrap2.py +6 -0
  20. gstaichi/_version.py +1 -0
  21. gstaichi/_version_check.py +100 -0
  22. gstaichi/ad/__init__.py +3 -0
  23. gstaichi/ad/_ad.py +530 -0
  24. gstaichi/algorithms/__init__.py +3 -0
  25. gstaichi/algorithms/_algorithms.py +117 -0
  26. gstaichi/assets/.git +1 -0
  27. gstaichi/assets/Go-Regular.ttf +0 -0
  28. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  29. gstaichi/examples/lcg_python.py +26 -0
  30. gstaichi/examples/lcg_taichi.py +34 -0
  31. gstaichi/examples/minimal.py +28 -0
  32. gstaichi/experimental.py +16 -0
  33. gstaichi/lang/__init__.py +50 -0
  34. gstaichi/lang/_dataclass_util.py +31 -0
  35. gstaichi/lang/_fast_caching/__init__.py +3 -0
  36. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  37. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  38. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  39. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  40. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  41. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  42. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  43. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  44. gstaichi/lang/_ndarray.py +352 -0
  45. gstaichi/lang/_ndrange.py +152 -0
  46. gstaichi/lang/_template_mapper.py +195 -0
  47. gstaichi/lang/_texture.py +172 -0
  48. gstaichi/lang/_wrap_inspect.py +215 -0
  49. gstaichi/lang/any_array.py +99 -0
  50. gstaichi/lang/ast/__init__.py +5 -0
  51. gstaichi/lang/ast/ast_transformer.py +1323 -0
  52. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  53. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  54. gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
  55. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  56. gstaichi/lang/ast/checkers.py +106 -0
  57. gstaichi/lang/ast/symbol_resolver.py +57 -0
  58. gstaichi/lang/ast/transform.py +9 -0
  59. gstaichi/lang/common_ops.py +310 -0
  60. gstaichi/lang/exception.py +80 -0
  61. gstaichi/lang/expr.py +180 -0
  62. gstaichi/lang/field.py +428 -0
  63. gstaichi/lang/impl.py +1245 -0
  64. gstaichi/lang/kernel_arguments.py +155 -0
  65. gstaichi/lang/kernel_impl.py +1341 -0
  66. gstaichi/lang/matrix.py +1835 -0
  67. gstaichi/lang/matrix_ops.py +341 -0
  68. gstaichi/lang/matrix_ops_utils.py +190 -0
  69. gstaichi/lang/mesh.py +687 -0
  70. gstaichi/lang/misc.py +780 -0
  71. gstaichi/lang/ops.py +1494 -0
  72. gstaichi/lang/runtime_ops.py +13 -0
  73. gstaichi/lang/shell.py +35 -0
  74. gstaichi/lang/simt/__init__.py +5 -0
  75. gstaichi/lang/simt/block.py +94 -0
  76. gstaichi/lang/simt/grid.py +7 -0
  77. gstaichi/lang/simt/subgroup.py +191 -0
  78. gstaichi/lang/simt/warp.py +96 -0
  79. gstaichi/lang/snode.py +489 -0
  80. gstaichi/lang/source_builder.py +150 -0
  81. gstaichi/lang/struct.py +810 -0
  82. gstaichi/lang/util.py +312 -0
  83. gstaichi/linalg/__init__.py +8 -0
  84. gstaichi/linalg/matrixfree_cg.py +310 -0
  85. gstaichi/linalg/sparse_cg.py +59 -0
  86. gstaichi/linalg/sparse_matrix.py +303 -0
  87. gstaichi/linalg/sparse_solver.py +123 -0
  88. gstaichi/math/__init__.py +11 -0
  89. gstaichi/math/_complex.py +205 -0
  90. gstaichi/math/mathimpl.py +886 -0
  91. gstaichi/profiler/__init__.py +6 -0
  92. gstaichi/profiler/kernel_metrics.py +260 -0
  93. gstaichi/profiler/kernel_profiler.py +586 -0
  94. gstaichi/profiler/memory_profiler.py +15 -0
  95. gstaichi/profiler/scoped_profiler.py +36 -0
  96. gstaichi/sparse/__init__.py +3 -0
  97. gstaichi/sparse/_sparse_grid.py +77 -0
  98. gstaichi/tools/__init__.py +12 -0
  99. gstaichi/tools/diagnose.py +117 -0
  100. gstaichi/tools/np2ply.py +364 -0
  101. gstaichi/tools/vtk.py +38 -0
  102. gstaichi/types/__init__.py +19 -0
  103. gstaichi/types/annotations.py +52 -0
  104. gstaichi/types/compound_types.py +71 -0
  105. gstaichi/types/enums.py +49 -0
  106. gstaichi/types/ndarray_type.py +169 -0
  107. gstaichi/types/primitive_types.py +206 -0
  108. gstaichi/types/quant.py +88 -0
  109. gstaichi/types/texture_type.py +85 -0
  110. gstaichi/types/utils.py +11 -0
  111. gstaichi-2.1.1.data/data/include/GLFW/glfw3.h +6389 -0
  112. gstaichi-2.1.1.data/data/include/GLFW/glfw3native.h +594 -0
  113. gstaichi-2.1.1.data/data/include/spirv-tools/instrument.hpp +268 -0
  114. gstaichi-2.1.1.data/data/include/spirv-tools/libspirv.h +907 -0
  115. gstaichi-2.1.1.data/data/include/spirv-tools/libspirv.hpp +375 -0
  116. gstaichi-2.1.1.data/data/include/spirv-tools/linker.hpp +97 -0
  117. gstaichi-2.1.1.data/data/include/spirv-tools/optimizer.hpp +970 -0
  118. gstaichi-2.1.1.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  119. gstaichi-2.1.1.data/data/include/spirv_cross/spirv.h +2568 -0
  120. gstaichi-2.1.1.data/data/include/spirv_cross/spirv.hpp +2579 -0
  121. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  122. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  123. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  124. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  125. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  126. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  127. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  128. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  129. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  130. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  131. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  132. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  133. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  134. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  135. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  136. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  137. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  138. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  139. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  140. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  141. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  142. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  143. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  144. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  145. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  146. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  147. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  148. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  149. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  150. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  151. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  152. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  153. gstaichi-2.1.1.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  154. gstaichi-2.1.1.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  155. gstaichi-2.1.1.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  156. gstaichi-2.1.1.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  157. gstaichi-2.1.1.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  158. gstaichi-2.1.1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  159. gstaichi-2.1.1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  160. gstaichi-2.1.1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  161. gstaichi-2.1.1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  162. gstaichi-2.1.1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  163. gstaichi-2.1.1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  164. gstaichi-2.1.1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  165. gstaichi-2.1.1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  166. gstaichi-2.1.1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  167. gstaichi-2.1.1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  168. gstaichi-2.1.1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  169. gstaichi-2.1.1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  170. gstaichi-2.1.1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  171. gstaichi-2.1.1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  172. gstaichi-2.1.1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  173. gstaichi-2.1.1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  174. gstaichi-2.1.1.dist-info/METADATA +106 -0
  175. gstaichi-2.1.1.dist-info/RECORD +178 -0
  176. gstaichi-2.1.1.dist-info/WHEEL +5 -0
  177. gstaichi-2.1.1.dist-info/licenses/LICENSE +201 -0
  178. gstaichi-2.1.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,346 @@
1
+ # type: ignore
2
+
3
+ import ast
4
+ import builtins
5
+ import traceback
6
+ from enum import Enum
7
+ from textwrap import TextWrapper
8
+ from typing import TYPE_CHECKING, Any, List
9
+
10
+ from gstaichi._lib.core.gstaichi_python import ASTBuilder
11
+ from gstaichi.lang import impl
12
+ from gstaichi.lang._ndrange import ndrange
13
+ from gstaichi.lang.ast.symbol_resolver import ASTResolver
14
+ from gstaichi.lang.exception import (
15
+ GsTaichiCompilationError,
16
+ GsTaichiNameError,
17
+ GsTaichiSyntaxError,
18
+ handle_exception_from_cpp,
19
+ )
20
+
21
+ if TYPE_CHECKING:
22
+ from gstaichi.lang.kernel_impl import (
23
+ Func,
24
+ Kernel,
25
+ )
26
+
27
+
28
class Builder:
    """Dispatch an AST node to the ``build_<NodeClassName>`` method of this builder.

    If building fails, the first (innermost) statement/expression that sees the
    error prefixes the message with its source position; enclosing nodes then
    re-raise it unchanged.
    """

    def __call__(self, ctx: "ASTTransformerContext", node: ast.AST):
        node_kind = node.__class__.__name__
        handler = getattr(self, f"build_{node_kind}", None)
        # Only statements and expressions carry line/column information.
        has_position = isinstance(node, (ast.stmt, ast.expr))
        try:
            if handler is None:
                raise GsTaichiSyntaxError(f'Unsupported node "{node_kind}"')
            src_info = ctx.get_pos_info(node) if has_position else ""
            with impl.get_runtime().src_info_guard(src_info):
                return handler(ctx, node)
        except Exception as err:
            if impl.get_runtime().print_full_traceback:
                raise err
            # Already decorated by an inner node (or no position to add): propagate as-is.
            if ctx.raised or not has_position:
                raise err.with_traceback(None)
            ctx.raised = True
            err = handle_exception_from_cpp(err)
            if isinstance(err, GsTaichiCompilationError):
                raise type(err)(ctx.get_pos_info(node) + str(err)) from None
            # Unknown error kinds are wrapped, keeping the formatted traceback in the message.
            raise GsTaichiCompilationError(ctx.get_pos_info(node) + traceback.format_exc()) from None
51
+
52
+
53
class VariableScopeGuard:
    """Context manager that opens a fresh variable scope and discards it on exit."""

    def __init__(self, scopes: list[dict[str, Any]]):
        # The scope stack is shared with (and owned by) the caller.
        self.scopes = scopes

    def __enter__(self):
        fresh_scope: dict[str, Any] = {}
        self.scopes.append(fresh_scope)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Drop the innermost scope, whether or not the body raised.
        self.scopes.pop()
62
+
63
+
64
class StaticScopeStatus:
    """Mutable flag telling whether code is currently being built in a static scope."""

    def __init__(self):
        # Toggled by StaticScopeGuard; starts outside any static scope.
        self.is_in_static_scope: bool = False
67
+
68
+
69
class StaticScopeGuard:
    """Turn ``status.is_in_static_scope`` on for the duration of a ``with`` block.

    The previous value is saved on entry and restored on exit, so guards nest.
    """

    def __init__(self, status: StaticScopeStatus):
        self.status = status

    def __enter__(self):
        status = self.status
        self.prev, status.is_in_static_scope = status.is_in_static_scope, True

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.status.is_in_static_scope = self.prev
79
+
80
+
81
class NonStaticControlFlowStatus:
    """Mutable flag telling whether code is inside non-static (runtime) control flow."""

    def __init__(self):
        # Toggled by NonStaticControlFlowGuard; starts outside such control flow.
        self.is_in_non_static_control_flow: bool = False
84
+
85
+
86
class NonStaticControlFlowGuard:
    """Mark a block as non-static control flow for the duration of a ``with``.

    The flag's previous value is saved on entry and put back on exit.
    """

    def __init__(self, status: NonStaticControlFlowStatus):
        self.status = status

    def __enter__(self):
        status = self.status
        self.prev, status.is_in_non_static_control_flow = status.is_in_non_static_control_flow, True

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.status.is_in_non_static_control_flow = self.prev
96
+
97
+
98
class LoopStatus(Enum):
    """Control-flow state stored per loop scope (LoopScopeAttribute.status)."""

    # Default state of a freshly pushed loop scope.
    Normal = 0
    Break = 1
    Continue = 2
102
+
103
+
104
class LoopScopeAttribute:
    """Bookkeeping for one loop on the context's loop-scope stack."""

    def __init__(self, is_static: bool):
        # True when the loop scope was pushed without a non-static control-flow guard.
        self.is_static = is_static
        # Control-flow state of this loop; every loop starts out Normal.
        self.status: LoopStatus = LoopStatus.Normal
        # Innermost non-static ``if`` currently open inside this loop, or None.
        self.nearest_non_static_if: ast.If | None = None
109
+
110
+
111
class LoopScopeGuard:
    """Push a LoopScopeAttribute onto ``scopes`` for the duration of a loop body.

    When ``non_static_guard`` is None the loop is recorded as static; otherwise
    the given guard is entered and exited together with the loop scope.
    """

    def __init__(self, scopes: list["LoopScopeAttribute"], non_static_guard=None):
        # The loop-scope stack: it holds LoopScopeAttribute objects (one per open loop).
        self.scopes = scopes
        self.non_static_guard = non_static_guard

    def __enter__(self):
        self.scopes.append(LoopScopeAttribute(self.non_static_guard is None))
        if self.non_static_guard:
            self.non_static_guard.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.scopes.pop()
        if self.non_static_guard:
            self.non_static_guard.__exit__(exc_type, exc_val, exc_tb)
125
+
126
+
127
class NonStaticIfGuard:
    """Context manager for the body of a non-static (runtime) ``if``.

    While active it records ``if_node`` as the nearest non-static ``if`` of the
    enclosing loop (when there is one) and marks the context as being inside
    non-static control flow. Both are restored on exit.
    """

    def __init__(
        self,
        if_node: ast.If,
        loop_attribute: "LoopScopeAttribute | None",
        non_static_status: "NonStaticControlFlowStatus",
    ):
        # None when the ``if`` is not nested in any loop.
        self.loop_attribute = loop_attribute
        self.if_node = if_node
        self.non_static_guard = NonStaticControlFlowGuard(non_static_status)

    def __enter__(self):
        if self.loop_attribute is not None:
            # Save the outer value so nested non-static ifs restore it correctly.
            self.old_non_static_if = self.loop_attribute.nearest_non_static_if
            self.loop_attribute.nearest_non_static_if = self.if_node
        self.non_static_guard.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.loop_attribute is not None:
            self.loop_attribute.nearest_non_static_if = self.old_non_static_if
        self.non_static_guard.__exit__(exc_type, exc_val, exc_tb)
148
+
149
+
150
class ReturnStatus(Enum):
    """Return state tracked on ASTTransformerContext.returned (starts at NoReturn)."""

    NoReturn = 0
    ReturnedVoid = 1
    ReturnedValue = 2
154
+
155
+
156
+ class ASTTransformerContext:
157
+ def __init__(
158
+ self,
159
+ excluded_parameters,
160
+ end_lineno: int,
161
+ is_kernel: bool,
162
+ func: "Func | Kernel",
163
+ arg_features: list[tuple[Any, ...]] | None,
164
+ global_vars: dict[str, Any],
165
+ argument_data,
166
+ file: str,
167
+ src: list[str],
168
+ start_lineno: int,
169
+ ast_builder: ASTBuilder | None,
170
+ is_real_function: bool,
171
+ ):
172
+ self.func = func
173
+ self.local_scopes: list[dict[str, Any]] = []
174
+ self.loop_scopes: List[LoopScopeAttribute] = []
175
+ self.excluded_parameters = excluded_parameters
176
+ self.is_kernel = is_kernel
177
+ self.arg_features = arg_features
178
+ self.returns = None
179
+ self.global_vars = global_vars
180
+ self.argument_data = argument_data
181
+ self.return_data: tuple[Any, ...] | Any | None = None
182
+ self.file = file
183
+ self.src = src
184
+ self.indent = 0
185
+ for c in self.src[0]:
186
+ if c == " ":
187
+ self.indent += 1
188
+ else:
189
+ break
190
+ self.lineno_offset = start_lineno - 1
191
+ self.start_lineno = start_lineno
192
+ self.end_lineno = end_lineno
193
+ self.raised = False
194
+ self.non_static_control_flow_status = NonStaticControlFlowStatus()
195
+ self.static_scope_status = StaticScopeStatus()
196
+ self.returned = ReturnStatus.NoReturn
197
+ self.ast_builder = ast_builder
198
+ self.visited_funcdef = False
199
+ self.is_real_function = is_real_function
200
+ self.kernel_args: list = []
201
+ self.only_parse_function_def: bool = False
202
+
203
+ # e.g.: FunctionDef, Module, Global
204
+ def variable_scope_guard(self):
205
+ return VariableScopeGuard(self.local_scopes)
206
+
207
+ # e.g.: For, While
208
+ def loop_scope_guard(self, is_static=False):
209
+ if is_static:
210
+ return LoopScopeGuard(self.loop_scopes)
211
+ return LoopScopeGuard(self.loop_scopes, self.non_static_control_flow_guard())
212
+
213
+ def non_static_if_guard(self, if_node: ast.If):
214
+ return NonStaticIfGuard(
215
+ if_node,
216
+ self.current_loop_scope() if self.loop_scopes else None,
217
+ self.non_static_control_flow_status,
218
+ )
219
+
220
+ def non_static_control_flow_guard(self) -> NonStaticControlFlowGuard:
221
+ return NonStaticControlFlowGuard(self.non_static_control_flow_status)
222
+
223
+ def static_scope_guard(self) -> StaticScopeGuard:
224
+ return StaticScopeGuard(self.static_scope_status)
225
+
226
+ def current_scope(self) -> dict[str, Any]:
227
+ return self.local_scopes[-1]
228
+
229
+ def current_loop_scope(self) -> dict[str, Any]:
230
+ return self.loop_scopes[-1]
231
+
232
+ def loop_status(self) -> LoopStatus:
233
+ if self.loop_scopes:
234
+ return self.loop_scopes[-1].status
235
+ return LoopStatus.Normal
236
+
237
+ def set_loop_status(self, status: LoopStatus) -> None:
238
+ self.loop_scopes[-1].status = status
239
+
240
+ def is_in_static_for(self) -> bool:
241
+ if self.loop_scopes:
242
+ return self.loop_scopes[-1].is_static
243
+ return False
244
+
245
+ def is_in_non_static_control_flow(self) -> bool:
246
+ return self.non_static_control_flow_status.is_in_non_static_control_flow
247
+
248
+ def is_in_static_scope(self) -> bool:
249
+ return self.static_scope_status.is_in_static_scope
250
+
251
+ def is_var_declared(self, name: str) -> bool:
252
+ for s in self.local_scopes:
253
+ if name in s:
254
+ return True
255
+ return False
256
+
257
+ def create_variable(self, name: str, var: Any) -> None:
258
+ if name in self.current_scope():
259
+ raise GsTaichiSyntaxError("Recreating variables is not allowed")
260
+ self.current_scope()[name] = var
261
+
262
+ def check_loop_var(self, loop_var: str) -> None:
263
+ if self.is_var_declared(loop_var):
264
+ raise GsTaichiSyntaxError(
265
+ f"Variable '{loop_var}' is already declared in the outer scope and cannot be used as loop variable"
266
+ )
267
+
268
+ def get_var_by_name(self, name: str) -> Any:
269
+ for s in reversed(self.local_scopes):
270
+ if name in s:
271
+ return s[name]
272
+ if name in self.global_vars:
273
+ var = self.global_vars[name]
274
+ from gstaichi.lang.matrix import ( # pylint: disable-msg=C0415
275
+ Matrix,
276
+ make_matrix,
277
+ )
278
+
279
+ if isinstance(var, Matrix):
280
+ return make_matrix(var.to_list())
281
+ return var
282
+ try:
283
+ return getattr(builtins, name)
284
+ except AttributeError:
285
+ raise GsTaichiNameError(f'Name "{name}" is not defined')
286
+
287
+ def get_pos_info(self, node: ast.AST) -> str:
288
+ msg = f'File "{self.file}", line {node.lineno + self.lineno_offset}, in {self.func.func.__name__}:\n'
289
+ col_offset = self.indent + node.col_offset
290
+ end_col_offset = self.indent + node.end_col_offset
291
+
292
+ wrapper = TextWrapper(width=80)
293
+
294
+ def gen_line(code: str, hint: str) -> str:
295
+ hint += " " * (len(code) - len(hint))
296
+ code = wrapper.wrap(code)
297
+ hint = wrapper.wrap(hint)
298
+ if not len(code):
299
+ return "\n\n"
300
+ return "".join([c + "\n" + h + "\n" for c, h in zip(code, hint)])
301
+
302
+ if node.lineno == node.end_lineno:
303
+ if node.lineno - 1 < len(self.src):
304
+ hint = " " * col_offset + "^" * (end_col_offset - col_offset)
305
+ msg += gen_line(self.src[node.lineno - 1], hint)
306
+ else:
307
+ node_type = node.__class__.__name__
308
+
309
+ if node_type in ["For", "While", "FunctionDef", "If"]:
310
+ end_lineno = max(node.body[0].lineno - 1, node.lineno)
311
+ else:
312
+ end_lineno = node.end_lineno
313
+
314
+ for i in range(node.lineno - 1, end_lineno):
315
+ last = len(self.src[i])
316
+ while last > 0 and (self.src[i][last - 1].isspace() or not self.src[i][last - 1].isprintable()):
317
+ last -= 1
318
+ first = 0
319
+ while first < len(self.src[i]) and (
320
+ self.src[i][first].isspace() or not self.src[i][first].isprintable()
321
+ ):
322
+ first += 1
323
+ if i == node.lineno - 1:
324
+ hint = " " * col_offset + "^" * (last - col_offset)
325
+ elif i == node.end_lineno - 1:
326
+ hint = " " * first + "^" * (end_col_offset - first)
327
+ elif first < last:
328
+ hint = " " * first + "^" * (last - first)
329
+ else:
330
+ hint = ""
331
+ msg += gen_line(self.src[i], hint)
332
+ return msg
333
+
334
+
335
def get_decorator(ctx: ASTTransformerContext, node) -> str:
    """Name the special gstaichi helper that a call node invokes.

    Returns "static", "static_assert", "grouped" or "ndrange" when ``node`` is an
    ``ast.Call`` whose callee ASTResolver resolves (against ``ctx.global_vars``)
    to ``impl.static``, ``impl.static_assert``, ``impl.grouped`` or ``ndrange``
    respectively (checked in that order); otherwise returns "".
    """
    if isinstance(node, ast.Call):
        known_helpers = (
            ("static", impl.static),
            ("static_assert", impl.static_assert),
            ("grouped", impl.grouped),
            ("ndrange", ndrange),
        )
        for label, target in known_helpers:
            if ASTResolver.resolve_to(node.func, target, ctx.global_vars):
                return label
    return ""
File without changes
@@ -0,0 +1,324 @@
1
+ # type: ignore
2
+
3
+ import ast
4
+ import dataclasses
5
+ import inspect
6
+ import math
7
+ import operator
8
+ import re
9
+ import warnings
10
+ from ast import unparse
11
+ from collections import ChainMap
12
+ from typing import Any
13
+
14
+ import numpy as np
15
+
16
+ from gstaichi.lang import (
17
+ expr,
18
+ impl,
19
+ matrix,
20
+ )
21
+ from gstaichi.lang import ops as ti_ops
22
+ from gstaichi.lang._dataclass_util import create_flat_name
23
+ from gstaichi.lang.ast.ast_transformer_utils import (
24
+ ASTTransformerContext,
25
+ get_decorator,
26
+ )
27
+ from gstaichi.lang.exception import (
28
+ GsTaichiSyntaxError,
29
+ GsTaichiTypeError,
30
+ )
31
+ from gstaichi.lang.expr import Expr
32
+ from gstaichi.lang.matrix import Matrix, Vector
33
+ from gstaichi.lang.util import is_gstaichi_class
34
+ from gstaichi.types import primitive_types
35
+
36
+
37
+ class CallTransformer:
38
+ @staticmethod
39
+ def _build_call_if_is_builtin(ctx: ASTTransformerContext, node, args, keywords):
40
+ from gstaichi.lang import matrix_ops # pylint: disable=C0415
41
+
42
+ func = node.func.ptr
43
+ replace_func = {
44
+ id(print): impl.ti_print,
45
+ id(min): ti_ops.min,
46
+ id(max): ti_ops.max,
47
+ id(int): impl.ti_int,
48
+ id(bool): impl.ti_bool,
49
+ id(float): impl.ti_float,
50
+ id(any): matrix_ops.any,
51
+ id(all): matrix_ops.all,
52
+ id(abs): abs,
53
+ id(pow): pow,
54
+ id(operator.matmul): matrix_ops.matmul,
55
+ }
56
+
57
+ # Builtin 'len' function on Matrix Expr
58
+ if id(func) == id(len) and len(args) == 1:
59
+ if isinstance(args[0], Expr) and args[0].ptr.is_tensor():
60
+ node.ptr = args[0].get_shape()[0]
61
+ return True
62
+
63
+ if id(func) in replace_func:
64
+ node.ptr = replace_func[id(func)](*args, **keywords)
65
+ return True
66
+ return False
67
+
68
+ @staticmethod
69
+ def _build_call_if_is_type(ctx: ASTTransformerContext, node, args, keywords):
70
+ func = node.func.ptr
71
+ if id(func) in primitive_types.type_ids:
72
+ if len(args) != 1 or keywords:
73
+ raise GsTaichiSyntaxError("A primitive type can only decorate a single expression.")
74
+ if is_gstaichi_class(args[0]):
75
+ raise GsTaichiSyntaxError("A primitive type cannot decorate an expression with a compound type.")
76
+
77
+ if isinstance(args[0], expr.Expr):
78
+ if args[0].ptr.is_tensor():
79
+ raise GsTaichiSyntaxError("A primitive type cannot decorate an expression with a compound type.")
80
+ node.ptr = ti_ops.cast(args[0], func)
81
+ else:
82
+ node.ptr = expr.Expr(args[0], dtype=func)
83
+ return True
84
+ return False
85
+
86
+ @staticmethod
87
+ def _is_external_func(ctx: ASTTransformerContext, func) -> bool:
88
+ if ctx.is_in_static_scope(): # allow external function in static scope
89
+ return False
90
+ if hasattr(func, "_is_gstaichi_function") or hasattr(func, "_is_wrapped_kernel"): # gstaichi func/kernel
91
+ return False
92
+ if hasattr(func, "__module__") and func.__module__ and func.__module__.startswith("gstaichi."):
93
+ return False
94
+ return True
95
+
96
+ @staticmethod
97
+ def _warn_if_is_external_func(ctx: ASTTransformerContext, node):
98
+ func = node.func.ptr
99
+ if not CallTransformer._is_external_func(ctx, func):
100
+ return
101
+ name = unparse(node.func).strip()
102
+ warnings.warn_explicit(
103
+ f"\x1b[38;5;226m" # Yellow
104
+ f'Calling non-gstaichi function "{name}". '
105
+ f"Scope inside the function is not processed by the GsTaichi AST transformer. "
106
+ f"The function may not work as expected. Proceed with caution! "
107
+ f"Maybe you can consider turning it into a @ti.func?"
108
+ f"\x1b[0m", # Reset
109
+ SyntaxWarning,
110
+ ctx.file,
111
+ node.lineno + ctx.lineno_offset,
112
+ module="gstaichi",
113
+ )
114
+
115
+ @staticmethod
116
+ # Parses a formatted string and extracts format specifiers from it, along with positional and keyword arguments.
117
+ # This function produces a canonicalized formatted string that includes solely empty replacement fields, e.g. 'qwerty {} {} {} {} {}'.
118
+ # Note that the arguments can be used multiple times in the string.
119
+ # e.g.:
120
+ # origin input: 'qwerty {1} {} {1:.3f} {k:.4f} {k:}'.format(1.0, 2.0, k=k)
121
+ # raw_string: 'qwerty {1} {} {1:.3f} {k:.4f} {k:}'
122
+ # raw_args: [1.0, 2.0]
123
+ # raw_keywords: {'k': <ti.Expr>}
124
+ # return value: ['qwerty {} {} {} {} {}', 2.0, 1.0, ['__ti_fmt_value__', 2.0, '.3f'], ['__ti_fmt_value__', <ti.Expr>, '.4f'], <ti.Expr>]
125
+ def _canonicalize_formatted_string(raw_string: str, *raw_args: list, **raw_keywords: dict):
126
+ raw_brackets = re.findall(r"{(.*?)}", raw_string)
127
+ brackets = []
128
+ unnamed = 0
129
+ for bracket in raw_brackets:
130
+ item, spec = bracket.split(":") if ":" in bracket else (bracket, None)
131
+ if item.isdigit():
132
+ item = int(item)
133
+ # handle unnamed positional args
134
+ if item == "":
135
+ item = unnamed
136
+ unnamed += 1
137
+ # handle empty spec
138
+ if spec == "":
139
+ spec = None
140
+ brackets.append((item, spec))
141
+
142
+ # check for errors in the arguments
143
+ max_args_index = max([t[0] for t in brackets if isinstance(t[0], int)], default=-1)
144
+ if max_args_index + 1 != len(raw_args):
145
+ raise GsTaichiSyntaxError(
146
+ f"Expected {max_args_index + 1} positional argument(s), but received {len(raw_args)} instead."
147
+ )
148
+ brackets_keywords = [t[0] for t in brackets if isinstance(t[0], str)]
149
+ for item in brackets_keywords:
150
+ if item not in raw_keywords:
151
+ raise GsTaichiSyntaxError(f"Keyword '{item}' not found.")
152
+ for item in raw_keywords:
153
+ if item not in brackets_keywords:
154
+ raise GsTaichiSyntaxError(f"Keyword '{item}' not used.")
155
+
156
+ # reorganize the arguments based on their positions, keywords, and format specifiers
157
+ args = []
158
+ for item, spec in brackets:
159
+ new_arg = raw_args[item] if isinstance(item, int) else raw_keywords[item]
160
+ if spec is not None:
161
+ args.append(["__ti_fmt_value__", new_arg, spec])
162
+ else:
163
+ args.append(new_arg)
164
+ # put the formatted string as the first argument to make ti.format() happy
165
+ args.insert(0, re.sub(r"{.*?}", "{}", raw_string))
166
+ return args
167
+
168
+ @staticmethod
169
+ def _expand_Call_dataclass_args(args: tuple[ast.stmt]) -> tuple[ast.stmt]:
170
+ """
171
+ We require that each node has a .ptr attribute added to it, that contains
172
+ the associated Python object
173
+ """
174
+ args_new = []
175
+ for arg in args:
176
+ val = arg.ptr
177
+ if dataclasses.is_dataclass(val):
178
+ dataclass_type = val
179
+ for field in dataclasses.fields(dataclass_type):
180
+ child_name = create_flat_name(arg.id, field.name)
181
+ load_ctx = ast.Load()
182
+ arg_node = ast.Name(
183
+ id=child_name,
184
+ ctx=load_ctx,
185
+ lineno=arg.lineno,
186
+ end_lineno=arg.end_lineno,
187
+ col_offset=arg.col_offset,
188
+ end_col_offset=arg.end_col_offset,
189
+ )
190
+ if dataclasses.is_dataclass(field.type):
191
+ arg_node.ptr = field.type
192
+ args_new.extend(CallTransformer._expand_Call_dataclass_args((arg_node,)))
193
+ else:
194
+ args_new.append(arg_node)
195
+ else:
196
+ args_new.append(arg)
197
+ return tuple(args_new)
198
+
199
+ @staticmethod
200
+ def _expand_Call_dataclass_kwargs(kwargs: list[ast.keyword]) -> list[ast.keyword]:
201
+ """
202
+ We require that each node has a .ptr attribute added to it, that contains
203
+ the associated Python object
204
+ """
205
+ kwargs_new = []
206
+ for i, kwarg in enumerate(kwargs):
207
+ val = kwarg.ptr[kwarg.arg]
208
+ if dataclasses.is_dataclass(val):
209
+ dataclass_type = val
210
+ for field in dataclasses.fields(dataclass_type):
211
+ src_name = create_flat_name(kwarg.value.id, field.name)
212
+ child_name = create_flat_name(kwarg.arg, field.name)
213
+ load_ctx = ast.Load()
214
+ src_node = ast.Name(
215
+ id=src_name,
216
+ ctx=load_ctx,
217
+ lineno=kwarg.lineno,
218
+ end_lineno=kwarg.end_lineno,
219
+ col_offset=kwarg.col_offset,
220
+ end_col_offset=kwarg.end_col_offset,
221
+ )
222
+ kwarg_node = ast.keyword(
223
+ arg=child_name,
224
+ value=src_node,
225
+ ctx=load_ctx,
226
+ lineno=kwarg.lineno,
227
+ end_lineno=kwarg.end_lineno,
228
+ col_offset=kwarg.col_offset,
229
+ end_col_offset=kwarg.end_col_offset,
230
+ )
231
+ if dataclasses.is_dataclass(field.type):
232
+ kwarg_node.ptr = {child_name: field.type}
233
+ kwargs_new.extend(CallTransformer._expand_Call_dataclass_kwargs([kwarg_node]))
234
+ else:
235
+ kwargs_new.append(kwarg_node)
236
+ else:
237
+ kwargs_new.append(kwarg)
238
+ return kwargs_new
239
+
240
+ @staticmethod
241
+ def build_Call(ctx: ASTTransformerContext, node: ast.Call, build_stmt, build_stmts) -> Any | None:
242
+ """
243
+ example ast:
244
+ Call(func=Name(id='f2', ctx=Load()), args=[Name(id='my_struct_ab', ctx=Load())], keywords=[])
245
+ """
246
+ if get_decorator(ctx, node) in ["static", "static_assert"]:
247
+ with ctx.static_scope_guard():
248
+ build_stmt(ctx, node.func)
249
+ build_stmts(ctx, node.args)
250
+ build_stmts(ctx, node.keywords)
251
+ else:
252
+ build_stmt(ctx, node.func)
253
+ # creates variable for the dataclass itself (as well as other variables,
254
+ # not related to dataclasses). Necessary for calling further child functions
255
+ build_stmts(ctx, node.args)
256
+ build_stmts(ctx, node.keywords)
257
+ node.args = CallTransformer._expand_Call_dataclass_args(node.args)
258
+ node.keywords = CallTransformer._expand_Call_dataclass_kwargs(node.keywords)
259
+ # create variables for the now-expanded dataclass members
260
+ build_stmts(ctx, node.args)
261
+ build_stmts(ctx, node.keywords)
262
+
263
+ args = []
264
+ for arg in node.args:
265
+ if isinstance(arg, ast.Starred):
266
+ arg_list = arg.ptr
267
+ if isinstance(arg_list, Expr) and arg_list.is_tensor():
268
+ # Expand Expr with Matrix-type return into list of Exprs
269
+ arg_list = [Expr(x) for x in ctx.ast_builder.expand_exprs([arg_list.ptr])]
270
+
271
+ for i in arg_list:
272
+ args.append(i)
273
+ else:
274
+ args.append(arg.ptr)
275
+ keywords = dict(ChainMap(*[keyword.ptr for keyword in node.keywords]))
276
+ func = node.func.ptr
277
+
278
+ if id(func) in [id(print), id(impl.ti_print)]:
279
+ ctx.func.has_print = True
280
+
281
+ if isinstance(node.func, ast.Attribute) and isinstance(node.func.value.ptr, str) and node.func.attr == "format":
282
+ raw_string = node.func.value.ptr
283
+ args = CallTransformer._canonicalize_formatted_string(raw_string, *args, **keywords)
284
+ node.ptr = impl.ti_format(*args)
285
+ return node.ptr
286
+
287
+ if id(func) == id(Matrix) or id(func) == id(Vector):
288
+ node.ptr = matrix.make_matrix(*args, **keywords)
289
+ return node.ptr
290
+
291
+ if CallTransformer._build_call_if_is_builtin(ctx, node, args, keywords):
292
+ return node.ptr
293
+
294
+ if CallTransformer._build_call_if_is_type(ctx, node, args, keywords):
295
+ return node.ptr
296
+
297
+ if hasattr(node.func, "caller"):
298
+ node.ptr = func(node.func.caller, *args, **keywords)
299
+ return node.ptr
300
+
301
+ CallTransformer._warn_if_is_external_func(ctx, node)
302
+ try:
303
+ node.ptr = func(*args, **keywords)
304
+ except TypeError as e:
305
+ module = inspect.getmodule(func)
306
+ error_msg = re.sub(r"\bExpr\b", "GsTaichi Expression", str(e))
307
+ func_name = getattr(func, "__name__", func.__class__.__name__)
308
+ msg = f"TypeError when calling `{func_name}`: {error_msg}."
309
+ if CallTransformer._is_external_func(ctx, node.func.ptr):
310
+ args_has_expr = any([isinstance(arg, Expr) for arg in args])
311
+ if args_has_expr and (module == math or module == np):
312
+ exec_str = f"from gstaichi import {func.__name__}"
313
+ try:
314
+ exec(exec_str, {})
315
+ except:
316
+ pass
317
+ else:
318
+ msg += f"\nDid you mean to use `ti.{func.__name__}` instead of `{module.__name__}.{func.__name__}`?"
319
+ raise GsTaichiTypeError(msg)
320
+
321
+ if getattr(func, "_is_gstaichi_function", False):
322
+ ctx.func.has_print |= func.wrapper.has_print
323
+
324
+ return node.ptr