gstaichi 0.1.25.dev0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. gstaichi/CHANGELOG.md +9 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/__main__.py +5 -0
  4. gstaichi/_funcs.py +706 -0
  5. gstaichi/_kernels.py +420 -0
  6. gstaichi/_lib/__init__.py +3 -0
  7. gstaichi/_lib/core/__init__.py +0 -0
  8. gstaichi/_lib/core/gstaichi_python.cp313-win_amd64.pyd +0 -0
  9. gstaichi/_lib/core/gstaichi_python.pyi +2937 -0
  10. gstaichi/_lib/core/py.typed +0 -0
  11. gstaichi/_lib/runtime/runtime_cuda.bc +0 -0
  12. gstaichi/_lib/runtime/runtime_x64.bc +0 -0
  13. gstaichi/_lib/runtime/slim_libdevice.10.bc +0 -0
  14. gstaichi/_lib/utils.py +249 -0
  15. gstaichi/_logging.py +131 -0
  16. gstaichi/_main.py +545 -0
  17. gstaichi/_snode/__init__.py +5 -0
  18. gstaichi/_snode/fields_builder.py +187 -0
  19. gstaichi/_snode/snode_tree.py +34 -0
  20. gstaichi/_test_tools/__init__.py +0 -0
  21. gstaichi/_test_tools/load_kernel_string.py +30 -0
  22. gstaichi/_version.py +1 -0
  23. gstaichi/_version_check.py +103 -0
  24. gstaichi/ad/__init__.py +3 -0
  25. gstaichi/ad/_ad.py +530 -0
  26. gstaichi/algorithms/__init__.py +3 -0
  27. gstaichi/algorithms/_algorithms.py +117 -0
  28. gstaichi/assets/.git +1 -0
  29. gstaichi/assets/Go-Regular.ttf +0 -0
  30. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  31. gstaichi/examples/minimal.py +28 -0
  32. gstaichi/experimental.py +16 -0
  33. gstaichi/lang/__init__.py +50 -0
  34. gstaichi/lang/_ndarray.py +352 -0
  35. gstaichi/lang/_ndrange.py +152 -0
  36. gstaichi/lang/_template_mapper.py +199 -0
  37. gstaichi/lang/_texture.py +172 -0
  38. gstaichi/lang/_wrap_inspect.py +189 -0
  39. gstaichi/lang/any_array.py +99 -0
  40. gstaichi/lang/argpack.py +411 -0
  41. gstaichi/lang/ast/__init__.py +5 -0
  42. gstaichi/lang/ast/ast_transformer.py +1318 -0
  43. gstaichi/lang/ast/ast_transformer_utils.py +341 -0
  44. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  45. gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
  46. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
  47. gstaichi/lang/ast/checkers.py +106 -0
  48. gstaichi/lang/ast/symbol_resolver.py +57 -0
  49. gstaichi/lang/ast/transform.py +9 -0
  50. gstaichi/lang/common_ops.py +310 -0
  51. gstaichi/lang/exception.py +80 -0
  52. gstaichi/lang/expr.py +180 -0
  53. gstaichi/lang/field.py +466 -0
  54. gstaichi/lang/impl.py +1241 -0
  55. gstaichi/lang/kernel_arguments.py +157 -0
  56. gstaichi/lang/kernel_impl.py +1382 -0
  57. gstaichi/lang/matrix.py +1881 -0
  58. gstaichi/lang/matrix_ops.py +341 -0
  59. gstaichi/lang/matrix_ops_utils.py +190 -0
  60. gstaichi/lang/mesh.py +687 -0
  61. gstaichi/lang/misc.py +778 -0
  62. gstaichi/lang/ops.py +1494 -0
  63. gstaichi/lang/runtime_ops.py +13 -0
  64. gstaichi/lang/shell.py +35 -0
  65. gstaichi/lang/simt/__init__.py +5 -0
  66. gstaichi/lang/simt/block.py +94 -0
  67. gstaichi/lang/simt/grid.py +7 -0
  68. gstaichi/lang/simt/subgroup.py +191 -0
  69. gstaichi/lang/simt/warp.py +96 -0
  70. gstaichi/lang/snode.py +489 -0
  71. gstaichi/lang/source_builder.py +150 -0
  72. gstaichi/lang/struct.py +855 -0
  73. gstaichi/lang/util.py +381 -0
  74. gstaichi/linalg/__init__.py +8 -0
  75. gstaichi/linalg/matrixfree_cg.py +310 -0
  76. gstaichi/linalg/sparse_cg.py +59 -0
  77. gstaichi/linalg/sparse_matrix.py +303 -0
  78. gstaichi/linalg/sparse_solver.py +123 -0
  79. gstaichi/math/__init__.py +11 -0
  80. gstaichi/math/_complex.py +205 -0
  81. gstaichi/math/mathimpl.py +886 -0
  82. gstaichi/profiler/__init__.py +6 -0
  83. gstaichi/profiler/kernel_metrics.py +260 -0
  84. gstaichi/profiler/kernel_profiler.py +586 -0
  85. gstaichi/profiler/memory_profiler.py +15 -0
  86. gstaichi/profiler/scoped_profiler.py +36 -0
  87. gstaichi/sparse/__init__.py +3 -0
  88. gstaichi/sparse/_sparse_grid.py +77 -0
  89. gstaichi/tools/__init__.py +12 -0
  90. gstaichi/tools/diagnose.py +117 -0
  91. gstaichi/tools/np2ply.py +364 -0
  92. gstaichi/tools/vtk.py +38 -0
  93. gstaichi/types/__init__.py +19 -0
  94. gstaichi/types/annotations.py +47 -0
  95. gstaichi/types/compound_types.py +90 -0
  96. gstaichi/types/enums.py +49 -0
  97. gstaichi/types/ndarray_type.py +147 -0
  98. gstaichi/types/primitive_types.py +206 -0
  99. gstaichi/types/quant.py +88 -0
  100. gstaichi/types/texture_type.py +85 -0
  101. gstaichi/types/utils.py +13 -0
  102. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +5 -0
  103. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +29 -0
  104. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +113 -0
  105. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +5 -0
  106. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +19 -0
  107. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +122 -0
  108. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +5 -0
  109. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +19 -0
  110. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +122 -0
  111. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +5 -0
  112. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +19 -0
  113. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +122 -0
  114. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +5 -0
  115. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +19 -0
  116. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +122 -0
  117. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +5 -0
  118. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  119. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +122 -0
  120. gstaichi-0.1.25.dev0.data/data/bin/SPIRV-Tools-shared.dll +0 -0
  121. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/instrument.hpp +268 -0
  122. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.h +907 -0
  123. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.hpp +375 -0
  124. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/linker.hpp +97 -0
  125. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/optimizer.hpp +970 -0
  126. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-diff.lib +0 -0
  127. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-link.lib +0 -0
  128. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-lint.lib +0 -0
  129. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-opt.lib +0 -0
  130. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-reduce.lib +0 -0
  131. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-shared.lib +0 -0
  132. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools.lib +0 -0
  133. gstaichi-0.1.25.dev0.dist-info/METADATA +105 -0
  134. gstaichi-0.1.25.dev0.dist-info/RECORD +138 -0
  135. gstaichi-0.1.25.dev0.dist-info/WHEEL +5 -0
  136. gstaichi-0.1.25.dev0.dist-info/entry_points.txt +2 -0
  137. gstaichi-0.1.25.dev0.dist-info/licenses/LICENSE +201 -0
  138. gstaichi-0.1.25.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,341 @@
1
+ # type: ignore
2
+
3
+ import ast
4
+ import builtins
5
+ import traceback
6
+ from enum import Enum
7
+ from textwrap import TextWrapper
8
+ from typing import TYPE_CHECKING, Any, List
9
+
10
+ from gstaichi._lib.core.gstaichi_python import ASTBuilder
11
+ from gstaichi.lang import impl
12
+ from gstaichi.lang._ndrange import ndrange
13
+ from gstaichi.lang.ast.symbol_resolver import ASTResolver
14
+ from gstaichi.lang.exception import (
15
+ GsTaichiCompilationError,
16
+ GsTaichiNameError,
17
+ GsTaichiSyntaxError,
18
+ handle_exception_from_cpp,
19
+ )
20
+
21
+ if TYPE_CHECKING:
22
+ from gstaichi.lang.kernel_impl import (
23
+ Func,
24
+ Kernel,
25
+ )
26
+
27
+
28
class Builder:
    """Base visitor that dispatches an AST node to a ``build_<NodeClassName>`` method.

    Errors raised while building are rewritten so that the first failing statement or
    expression gets its source position prepended exactly once (tracked via ``ctx.raised``).
    Exceptions that are not ``GsTaichiCompilationError`` are wrapped in one, carrying the
    formatted traceback text.
    """

    def __call__(self, ctx: "ASTTransformerContext", node: ast.AST):
        node_kind = type(node).__name__
        handler = getattr(self, f"build_{node_kind}", None)
        # Only statements and expressions carry line/column info usable for error positions.
        is_located = isinstance(node, (ast.stmt, ast.expr))
        try:
            if handler is None:
                raise GsTaichiSyntaxError(f'Unsupported node "{node_kind}"')
            pos_info = ctx.get_pos_info(node) if is_located else ""
            with impl.get_runtime().src_info_guard(pos_info):
                return handler(ctx, node)
        except Exception as err:
            if impl.get_runtime().print_full_traceback:
                raise err
            # Position info was already attached further down, or this node has none:
            # propagate without the (long) Python traceback.
            if ctx.raised or not is_located:
                raise err.with_traceback(None)
            ctx.raised = True
            err = handle_exception_from_cpp(err)
            prefix = ctx.get_pos_info(node)
            if isinstance(err, GsTaichiCompilationError):
                raise type(err)(prefix + str(err)) from None
            raise GsTaichiCompilationError(prefix + traceback.format_exc()) from None
50
+
51
+
52
class VariableScopeGuard:
    """Context manager that pushes a fresh, empty variable scope and pops it on exit."""

    def __init__(self, scopes: list[dict[str, Any]]):
        # The caller's scope stack; it is mutated in place.
        self.scopes = scopes

    def __enter__(self):
        fresh_scope: dict[str, Any] = {}
        self.scopes.append(fresh_scope)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Discard the innermost scope; returning None lets exceptions propagate.
        self.scopes.pop()
61
+
62
+
63
class StaticScopeStatus:
    """Mutable flag telling whether the transformer is currently inside a static scope.

    Toggled by ``StaticScopeGuard``; starts out False.
    """

    def __init__(self):
        self.is_in_static_scope: bool = False
66
+
67
+
68
class StaticScopeGuard:
    """Context manager that sets ``status.is_in_static_scope`` to True while active.

    The value seen on entry is restored on exit, so nested guards unwind correctly.
    """

    def __init__(self, status: StaticScopeStatus):
        self.status = status

    def __enter__(self):
        status = self.status
        # Save the outer value, then switch the flag on.
        self.prev, status.is_in_static_scope = status.is_in_static_scope, True

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore the outer value; exceptions are not suppressed.
        self.status.is_in_static_scope = self.prev
78
+
79
+
80
class NonStaticControlFlowStatus:
    """Mutable flag telling whether the transformer is inside non-static control flow.

    Toggled by ``NonStaticControlFlowGuard``; starts out False.
    """

    def __init__(self):
        self.is_in_non_static_control_flow: bool = False
83
+
84
+
85
class NonStaticControlFlowGuard:
    """Context manager marking the enclosed region as non-static control flow.

    Sets ``status.is_in_non_static_control_flow`` to True on entry and restores the value it
    had before entry on exit, so nested guards unwind correctly.
    """

    def __init__(self, status: NonStaticControlFlowStatus):
        self.status = status

    def __enter__(self):
        status = self.status
        # Save the outer value, then switch the flag on.
        self.prev, status.is_in_non_static_control_flow = status.is_in_non_static_control_flow, True

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore the outer value; exceptions are not suppressed.
        self.status.is_in_non_static_control_flow = self.prev
95
+
96
+
97
class LoopStatus(Enum):
    """Status recorded on the innermost loop scope (see ``LoopScopeAttribute.status``)."""

    # Initial status of every newly opened loop scope.
    Normal = 0
    Break = 1
    Continue = 2
101
+
102
+
103
class LoopScopeAttribute:
    """Bookkeeping for one loop, stored on the loop-scope stack by ``LoopScopeGuard``."""

    def __init__(self, is_static: bool):
        # Innermost non-static ``if`` inside this loop; maintained by ``NonStaticIfGuard``.
        self.nearest_non_static_if: ast.If | None = None
        # Whether the loop is unrolled statically rather than compiled as a runtime loop.
        self.is_static = is_static
        self.status: LoopStatus = LoopStatus.Normal
108
+
109
+
110
class LoopScopeGuard:
    """Context manager that opens a loop scope.

    On entry a ``LoopScopeAttribute`` is pushed onto ``scopes``. The loop counts as static
    exactly when no ``non_static_guard`` is supplied. If a guard is supplied, it is entered
    after the push and exited after the pop.

    Args:
        scopes: The loop-scope stack (a list of ``LoopScopeAttribute``), mutated in place.
        non_static_guard: Optional guard that flags non-static control flow for the loop body.
    """

    def __init__(
        self,
        scopes: list["LoopScopeAttribute"],
        non_static_guard: "NonStaticControlFlowGuard | None" = None,
    ):
        self.scopes = scopes
        self.non_static_guard = non_static_guard

    def __enter__(self):
        self.scopes.append(LoopScopeAttribute(self.non_static_guard is None))
        if self.non_static_guard is not None:
            self.non_static_guard.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.scopes.pop()
        if self.non_static_guard is not None:
            self.non_static_guard.__exit__(exc_type, exc_val, exc_tb)
124
+
125
+
126
class NonStaticIfGuard:
    """Context manager entered while compiling a non-static ``if`` statement.

    When there is an enclosing loop (``loop_attribute`` is not None), this guard records
    ``if_node`` as that loop's ``nearest_non_static_if`` and restores the previous value on
    exit. In all cases it also flags the region as non-static control flow.

    Args:
        if_node: The ``ast.If`` node being compiled.
        loop_attribute: Innermost loop scope, or None when not inside any loop.
        non_static_status: Shared status toggled by the inner ``NonStaticControlFlowGuard``.
    """

    def __init__(
        self,
        if_node: ast.If,
        loop_attribute: "LoopScopeAttribute | None",
        non_static_status: NonStaticControlFlowStatus,
    ):
        self.loop_attribute = loop_attribute
        self.if_node = if_node
        self.non_static_guard = NonStaticControlFlowGuard(non_static_status)

    def __enter__(self):
        if self.loop_attribute is not None:
            self.old_non_static_if = self.loop_attribute.nearest_non_static_if
            self.loop_attribute.nearest_non_static_if = self.if_node
        self.non_static_guard.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.loop_attribute is not None:
            self.loop_attribute.nearest_non_static_if = self.old_non_static_if
        self.non_static_guard.__exit__(exc_type, exc_val, exc_tb)
147
+
148
+
149
class ReturnStatus(Enum):
    """Return state tracked in ``ASTTransformerContext.returned``."""

    # Initial value for a freshly created context.
    NoReturn = 0
    ReturnedVoid = 1
    ReturnedValue = 2
153
+
154
+
155
+ class ASTTransformerContext:
156
+ def __init__(
157
+ self,
158
+ excluded_parameters=(),
159
+ is_kernel: bool = True,
160
+ func: "Func | Kernel | None" = None,
161
+ arg_features=None,
162
+ global_vars: dict[str, Any] | None = None,
163
+ argument_data=None,
164
+ file: str | None = None,
165
+ src: list[str] | None = None,
166
+ start_lineno: int | None = None,
167
+ ast_builder: ASTBuilder | None = None,
168
+ is_real_function: bool = False,
169
+ ):
170
+ self.func = func
171
+ self.local_scopes: list[dict[str, Any]] = []
172
+ self.loop_scopes: List[LoopScopeAttribute] = []
173
+ self.excluded_parameters = excluded_parameters
174
+ self.is_kernel = is_kernel
175
+ self.arg_features = arg_features
176
+ self.returns = None
177
+ self.global_vars = global_vars
178
+ self.argument_data = argument_data
179
+ self.return_data = None
180
+ self.file = file
181
+ self.src = src
182
+ self.indent = 0
183
+ for c in self.src[0]:
184
+ if c == " ":
185
+ self.indent += 1
186
+ else:
187
+ break
188
+ self.lineno_offset = start_lineno - 1
189
+ self.raised = False
190
+ self.non_static_control_flow_status = NonStaticControlFlowStatus()
191
+ self.static_scope_status = StaticScopeStatus()
192
+ self.returned = ReturnStatus.NoReturn
193
+ self.ast_builder = ast_builder
194
+ self.visited_funcdef = False
195
+ self.is_real_function = is_real_function
196
+ self.kernel_args: list = []
197
+
198
+ # e.g.: FunctionDef, Module, Global
199
+ def variable_scope_guard(self):
200
+ return VariableScopeGuard(self.local_scopes)
201
+
202
+ # e.g.: For, While
203
+ def loop_scope_guard(self, is_static=False):
204
+ if is_static:
205
+ return LoopScopeGuard(self.loop_scopes)
206
+ return LoopScopeGuard(self.loop_scopes, self.non_static_control_flow_guard())
207
+
208
+ def non_static_if_guard(self, if_node: ast.If):
209
+ return NonStaticIfGuard(
210
+ if_node,
211
+ self.current_loop_scope() if self.loop_scopes else None,
212
+ self.non_static_control_flow_status,
213
+ )
214
+
215
+ def non_static_control_flow_guard(self) -> NonStaticControlFlowGuard:
216
+ return NonStaticControlFlowGuard(self.non_static_control_flow_status)
217
+
218
+ def static_scope_guard(self) -> StaticScopeGuard:
219
+ return StaticScopeGuard(self.static_scope_status)
220
+
221
+ def current_scope(self) -> dict[str, Any]:
222
+ return self.local_scopes[-1]
223
+
224
+ def current_loop_scope(self) -> dict[str, Any]:
225
+ return self.loop_scopes[-1]
226
+
227
+ def loop_status(self) -> LoopStatus:
228
+ if self.loop_scopes:
229
+ return self.loop_scopes[-1].status
230
+ return LoopStatus.Normal
231
+
232
+ def set_loop_status(self, status: LoopStatus) -> None:
233
+ self.loop_scopes[-1].status = status
234
+
235
+ def is_in_static_for(self) -> bool:
236
+ if self.loop_scopes:
237
+ return self.loop_scopes[-1].is_static
238
+ return False
239
+
240
+ def is_in_non_static_control_flow(self) -> bool:
241
+ return self.non_static_control_flow_status.is_in_non_static_control_flow
242
+
243
+ def is_in_static_scope(self) -> bool:
244
+ return self.static_scope_status.is_in_static_scope
245
+
246
+ def is_var_declared(self, name: str) -> bool:
247
+ for s in self.local_scopes:
248
+ if name in s:
249
+ return True
250
+ return False
251
+
252
+ def create_variable(self, name: str, var: Any) -> None:
253
+ if name in self.current_scope():
254
+ raise GsTaichiSyntaxError("Recreating variables is not allowed")
255
+ self.current_scope()[name] = var
256
+
257
+ def check_loop_var(self, loop_var: str) -> None:
258
+ if self.is_var_declared(loop_var):
259
+ raise GsTaichiSyntaxError(
260
+ f"Variable '{loop_var}' is already declared in the outer scope and cannot be used as loop variable"
261
+ )
262
+
263
+ def get_var_by_name(self, name: str) -> Any:
264
+ for s in reversed(self.local_scopes):
265
+ if name in s:
266
+ return s[name]
267
+ if name in self.global_vars:
268
+ var = self.global_vars[name]
269
+ from gstaichi.lang.matrix import ( # pylint: disable-msg=C0415
270
+ Matrix,
271
+ make_matrix,
272
+ )
273
+
274
+ if isinstance(var, Matrix):
275
+ return make_matrix(var.to_list())
276
+ return var
277
+ try:
278
+ return getattr(builtins, name)
279
+ except AttributeError:
280
+ raise GsTaichiNameError(f'Name "{name}" is not defined')
281
+
282
+ def get_pos_info(self, node: ast.AST) -> str:
283
+ msg = f'File "{self.file}", line {node.lineno + self.lineno_offset}, in {self.func.func.__name__}:\n'
284
+ col_offset = self.indent + node.col_offset
285
+ end_col_offset = self.indent + node.end_col_offset
286
+
287
+ wrapper = TextWrapper(width=80)
288
+
289
+ def gen_line(code: str, hint: str) -> str:
290
+ hint += " " * (len(code) - len(hint))
291
+ code = wrapper.wrap(code)
292
+ hint = wrapper.wrap(hint)
293
+ if not len(code):
294
+ return "\n\n"
295
+ return "".join([c + "\n" + h + "\n" for c, h in zip(code, hint)])
296
+
297
+ if node.lineno == node.end_lineno:
298
+ if node.lineno - 1 < len(self.src):
299
+ hint = " " * col_offset + "^" * (end_col_offset - col_offset)
300
+ msg += gen_line(self.src[node.lineno - 1], hint)
301
+ else:
302
+ node_type = node.__class__.__name__
303
+
304
+ if node_type in ["For", "While", "FunctionDef", "If"]:
305
+ end_lineno = max(node.body[0].lineno - 1, node.lineno)
306
+ else:
307
+ end_lineno = node.end_lineno
308
+
309
+ for i in range(node.lineno - 1, end_lineno):
310
+ last = len(self.src[i])
311
+ while last > 0 and (self.src[i][last - 1].isspace() or not self.src[i][last - 1].isprintable()):
312
+ last -= 1
313
+ first = 0
314
+ while first < len(self.src[i]) and (
315
+ self.src[i][first].isspace() or not self.src[i][first].isprintable()
316
+ ):
317
+ first += 1
318
+ if i == node.lineno - 1:
319
+ hint = " " * col_offset + "^" * (last - col_offset)
320
+ elif i == node.end_lineno - 1:
321
+ hint = " " * first + "^" * (end_col_offset - first)
322
+ elif first < last:
323
+ hint = " " * first + "^" * (last - first)
324
+ else:
325
+ hint = ""
326
+ msg += gen_line(self.src[i], hint)
327
+ return msg
328
+
329
+
330
def get_decorator(ctx: ASTTransformerContext, node) -> str:
    """Return the name of the special GsTaichi callable that ``node`` invokes, or ``""``.

    Only ``ast.Call`` nodes are considered. The callee is resolved against ``ctx.global_vars``
    with ``ASTResolver.resolve_to``. The candidates are checked in the order ``static``,
    ``static_assert``, ``grouped``, ``ndrange``, and the first match wins.
    """
    if isinstance(node, ast.Call):
        candidates = (
            (impl.static, "static"),
            (impl.static_assert, "static_assert"),
            (impl.grouped, "grouped"),
            (ndrange, "ndrange"),
        )
        matches = (
            label for target, label in candidates if ASTResolver.resolve_to(node.func, target, ctx.global_vars)
        )
        return next(matches, "")
    return ""
File without changes
@@ -0,0 +1,267 @@
1
+ # type: ignore
2
+
3
+ import ast
4
+ import dataclasses
5
+ import inspect
6
+ import math
7
+ import operator
8
+ import re
9
+ import warnings
10
+ from ast import unparse
11
+ from collections import ChainMap
12
+
13
+ import numpy as np
14
+
15
+ from gstaichi.lang import (
16
+ expr,
17
+ impl,
18
+ matrix,
19
+ )
20
+ from gstaichi.lang import ops as ti_ops
21
+ from gstaichi.lang.ast.ast_transformer_utils import (
22
+ ASTTransformerContext,
23
+ get_decorator,
24
+ )
25
+ from gstaichi.lang.exception import (
26
+ GsTaichiSyntaxError,
27
+ GsTaichiTypeError,
28
+ )
29
+ from gstaichi.lang.expr import Expr
30
+ from gstaichi.lang.matrix import Matrix, Vector
31
+ from gstaichi.lang.util import is_gstaichi_class
32
+ from gstaichi.types import primitive_types
33
+
34
+
35
class CallTransformer:
    """Static helpers that lower an ``ast.Call`` node into a GsTaichi expression.

    ``build_Call`` is the entry point. The other helpers handle Python builtins, primitive-type
    casts, ``str.format`` calls, external-function warnings, and dataclass argument expansion.
    """

    @staticmethod
    def build_call_if_is_builtin(ctx: "ASTTransformerContext", node, args, keywords):
        """Lower calls to selected Python builtins onto their GsTaichi counterparts.

        Sets ``node.ptr`` and returns True in two cases: ``node.func.ptr`` is ``len`` applied to a
        single tensor ``Expr``, or it is one of the builtins in ``replace_func``. Otherwise returns
        False.
        """
        from gstaichi.lang import matrix_ops  # pylint: disable=C0415

        func = node.func.ptr
        replace_func = {
            id(print): impl.ti_print,
            id(min): ti_ops.min,
            id(max): ti_ops.max,
            id(int): impl.ti_int,
            id(bool): impl.ti_bool,
            id(float): impl.ti_float,
            id(any): matrix_ops.any,
            id(all): matrix_ops.all,
            id(abs): abs,
            id(pow): pow,
            id(operator.matmul): matrix_ops.matmul,
        }

        # Builtin 'len' function on Matrix Expr: result is the first dimension of its shape.
        if id(func) == id(len) and len(args) == 1:
            if isinstance(args[0], Expr) and args[0].ptr.is_tensor():
                node.ptr = args[0].get_shape()[0]
                return True

        if id(func) in replace_func:
            node.ptr = replace_func[id(func)](*args, **keywords)
            return True
        return False

    @staticmethod
    def build_call_if_is_type(ctx: "ASTTransformerContext", node, args, keywords):
        """Lower ``<primitive type>(x)`` into a cast (for Exprs) or a typed Expr (for constants).

        Returns True and sets ``node.ptr`` when ``node.func.ptr`` is a primitive type; else False.

        Raises:
            GsTaichiSyntaxError: if not exactly one positional argument is given, keywords are
                passed, or the argument is a compound / tensor value.
        """
        func = node.func.ptr
        if id(func) in primitive_types.type_ids:
            if len(args) != 1 or keywords:
                raise GsTaichiSyntaxError("A primitive type can only decorate a single expression.")
            if is_gstaichi_class(args[0]):
                raise GsTaichiSyntaxError("A primitive type cannot decorate an expression with a compound type.")

            if isinstance(args[0], expr.Expr):
                if args[0].ptr.is_tensor():
                    raise GsTaichiSyntaxError("A primitive type cannot decorate an expression with a compound type.")
                node.ptr = ti_ops.cast(args[0], func)
            else:
                node.ptr = expr.Expr(args[0], dtype=func)
            return True
        return False

    @staticmethod
    def is_external_func(ctx: "ASTTransformerContext", func) -> bool:
        """Return True if ``func`` is neither a GsTaichi func/kernel nor defined in a ``gstaichi.`` module.

        Always False inside a static scope, where external functions are allowed.
        """
        if ctx.is_in_static_scope():  # allow external function in static scope
            return False
        if hasattr(func, "_is_gstaichi_function") or hasattr(func, "_is_wrapped_kernel"):  # gstaichi func/kernel
            return False
        if hasattr(func, "__module__") and func.__module__ and func.__module__.startswith("gstaichi."):
            return False
        return True

    @staticmethod
    def warn_if_is_external_func(ctx: "ASTTransformerContext", node):
        """Emit a SyntaxWarning at the call site when the callee is an external (non-GsTaichi) function."""
        func = node.func.ptr
        if not CallTransformer.is_external_func(ctx, func):
            return
        name = unparse(node.func).strip()
        warnings.warn_explicit(
            f"\x1b[38;5;226m"  # Yellow
            f'Calling non-gstaichi function "{name}". '
            f"Scope inside the function is not processed by the GsTaichi AST transformer. "
            f"The function may not work as expected. Proceed with caution! "
            f"Maybe you can consider turning it into a @ti.func?"
            f"\x1b[0m",  # Reset
            SyntaxWarning,
            ctx.file,
            node.lineno + ctx.lineno_offset,
            module="gstaichi",
        )

    @staticmethod
    def canonicalize_formatted_string(raw_string: str, *raw_args: list, **raw_keywords: dict):
        """Canonicalize a ``str.format`` call for ``ti.format``.

        Extracts format specifiers along with positional and keyword arguments, and produces a
        string containing only empty replacement fields, e.g. ``'qwerty {} {} {} {} {}'``.
        Arguments may be referenced several times.

        Example::

            origin input: 'qwerty {1} {} {1:.3f} {k:.4f} {k:}'.format(1.0, 2.0, k=k)
            raw_string:   'qwerty {1} {} {1:.3f} {k:.4f} {k:}'
            raw_args:     [1.0, 2.0]
            raw_keywords: {'k': <ti.Expr>}
            return value: ['qwerty {} {} {} {} {}', 2.0, 1.0, ['__ti_fmt_value__', 2.0, '.3f'],
                           ['__ti_fmt_value__', <ti.Expr>, '.4f'], <ti.Expr>]

        Raises:
            GsTaichiSyntaxError: on a positional-argument count mismatch, a missing keyword, or an
                unused keyword.
        """
        raw_brackets = re.findall(r"{(.*?)}", raw_string)
        brackets = []
        unnamed = 0
        for bracket in raw_brackets:
            # Split on the first colon only: the format spec itself may contain colons
            # (e.g. "{0:%H:%M}"), which previously raised ValueError on unpacking.
            item, spec = bracket.split(":", 1) if ":" in bracket else (bracket, None)
            if item.isdigit():
                item = int(item)
            # handle unnamed positional args
            if item == "":
                item = unnamed
                unnamed += 1
            # handle empty spec
            if spec == "":
                spec = None
            brackets.append((item, spec))

        # check for errors in the arguments
        max_args_index = max([t[0] for t in brackets if isinstance(t[0], int)], default=-1)
        if max_args_index + 1 != len(raw_args):
            raise GsTaichiSyntaxError(
                f"Expected {max_args_index + 1} positional argument(s), but received {len(raw_args)} instead."
            )
        brackets_keywords = [t[0] for t in brackets if isinstance(t[0], str)]
        for item in brackets_keywords:
            if item not in raw_keywords:
                raise GsTaichiSyntaxError(f"Keyword '{item}' not found.")
        for item in raw_keywords:
            if item not in brackets_keywords:
                raise GsTaichiSyntaxError(f"Keyword '{item}' not used.")

        # reorganize the arguments based on their positions, keywords, and format specifiers
        args = []
        for item, spec in brackets:
            new_arg = raw_args[item] if isinstance(item, int) else raw_keywords[item]
            if spec is not None:
                args.append(["__ti_fmt_value__", new_arg, spec])
            else:
                args.append(new_arg)
        # put the formatted string as the first argument to make ti.format() happy
        args.insert(0, re.sub(r"{.*?}", "{}", raw_string))
        return args

    @staticmethod
    def expand_node_args_dataclasses(args: tuple[ast.AST, ...]) -> tuple[ast.AST, ...]:
        """Replace each argument whose built value is a dataclass with one Name node per field.

        Each field's Name node has id ``__ti_<arg.id>_<field>`` and copies the original
        argument's source position. Other arguments are kept unchanged.
        """
        args_new = []
        for arg in args:
            val = arg.ptr
            if dataclasses.is_dataclass(val):
                dataclass_type = val
                for field in dataclasses.fields(dataclass_type):
                    child_name = f"__ti_{arg.id}_{field.name}"
                    load_ctx = ast.Load()
                    arg_node = ast.Name(
                        id=child_name,
                        ctx=load_ctx,
                        lineno=arg.lineno,
                        end_lineno=arg.end_lineno,
                        col_offset=arg.col_offset,
                        end_col_offset=arg.end_col_offset,
                    )
                    args_new.append(arg_node)
            else:
                args_new.append(arg)
        return tuple(args_new)

    @staticmethod
    def build_Call(ctx: "ASTTransformerContext", node: ast.Call, build_stmt, build_stmts):
        """Build ``node`` and return the resulting value (also stored in ``node.ptr``).

        Handling is tried in this order:

        1. ``str.format`` calls.
        2. ``Matrix``/``Vector`` construction.
        3. Selected builtins.
        4. Primitive-type casts.
        5. Bound callers (``node.func.caller``).
        6. Otherwise, a direct call, with a warning for external functions.

        A ``TypeError`` from the direct call is re-raised as ``GsTaichiTypeError``.
        """
        if get_decorator(ctx, node) in ["static", "static_assert"]:
            with ctx.static_scope_guard():
                build_stmt(ctx, node.func)
                build_stmts(ctx, node.args)
                build_stmts(ctx, node.keywords)
        else:
            build_stmt(ctx, node.func)
            # creates variable for the dataclass itself (as well as other variables,
            # not related to dataclasses). Necessary for calling further child functions
            build_stmts(ctx, node.args)
            node.args = CallTransformer.expand_node_args_dataclasses(node.args)
            # create variables for the now-expanded dataclass members
            build_stmts(ctx, node.args)
            build_stmts(ctx, node.keywords)

        args = []
        for arg in node.args:
            if isinstance(arg, ast.Starred):
                arg_list = arg.ptr
                if isinstance(arg_list, Expr) and arg_list.is_tensor():
                    # Expand Expr with Matrix-type return into list of Exprs
                    arg_list = [Expr(x) for x in ctx.ast_builder.expand_exprs([arg_list.ptr])]

                for i in arg_list:
                    args.append(i)
            else:
                args.append(arg.ptr)
        keywords = dict(ChainMap(*[keyword.ptr for keyword in node.keywords]))
        func = node.func.ptr

        if id(func) in [id(print), id(impl.ti_print)]:
            ctx.func.has_print = True

        if isinstance(node.func, ast.Attribute) and isinstance(node.func.value.ptr, str) and node.func.attr == "format":
            raw_string = node.func.value.ptr
            args = CallTransformer.canonicalize_formatted_string(raw_string, *args, **keywords)
            node.ptr = impl.ti_format(*args)
            return node.ptr

        if id(func) == id(Matrix) or id(func) == id(Vector):
            node.ptr = matrix.make_matrix(*args, **keywords)
            return node.ptr

        if CallTransformer.build_call_if_is_builtin(ctx, node, args, keywords):
            return node.ptr

        if CallTransformer.build_call_if_is_type(ctx, node, args, keywords):
            return node.ptr

        if hasattr(node.func, "caller"):
            node.ptr = func(node.func.caller, *args, **keywords)
            return node.ptr

        CallTransformer.warn_if_is_external_func(ctx, node)
        try:
            node.ptr = func(*args, **keywords)
        except TypeError as e:
            module = inspect.getmodule(func)
            error_msg = re.sub(r"\bExpr\b", "GsTaichi Expression", str(e))
            func_name = getattr(func, "__name__", func.__class__.__name__)
            msg = f"TypeError when calling `{func_name}`: {error_msg}."
            if CallTransformer.is_external_func(ctx, node.func.ptr):
                args_has_expr = any(isinstance(arg, Expr) for arg in args)
                if args_has_expr and (module == math or module == np):
                    # Suggest the GsTaichi equivalent only if the package exposes one. This used to
                    # exec() a generated import statement under a bare ``except:``.
                    import gstaichi  # pylint: disable=C0415

                    if hasattr(gstaichi, func_name):
                        msg += f"\nDid you mean to use `ti.{func_name}` instead of `{module.__name__}.{func_name}`?"
            raise GsTaichiTypeError(msg)

        if getattr(func, "_is_gstaichi_function", False):
            ctx.func.has_print |= func.wrapper.has_print

        return node.ptr