gstaichi 0.1.25.dev0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. gstaichi/CHANGELOG.md +9 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/__main__.py +5 -0
  4. gstaichi/_funcs.py +706 -0
  5. gstaichi/_kernels.py +420 -0
  6. gstaichi/_lib/__init__.py +3 -0
  7. gstaichi/_lib/core/__init__.py +0 -0
  8. gstaichi/_lib/core/gstaichi_python.cp311-win_amd64.pyd +0 -0
  9. gstaichi/_lib/core/gstaichi_python.pyi +2937 -0
  10. gstaichi/_lib/core/py.typed +0 -0
  11. gstaichi/_lib/runtime/runtime_cuda.bc +0 -0
  12. gstaichi/_lib/runtime/runtime_x64.bc +0 -0
  13. gstaichi/_lib/runtime/slim_libdevice.10.bc +0 -0
  14. gstaichi/_lib/utils.py +249 -0
  15. gstaichi/_logging.py +131 -0
  16. gstaichi/_main.py +545 -0
  17. gstaichi/_snode/__init__.py +5 -0
  18. gstaichi/_snode/fields_builder.py +187 -0
  19. gstaichi/_snode/snode_tree.py +34 -0
  20. gstaichi/_test_tools/__init__.py +0 -0
  21. gstaichi/_test_tools/load_kernel_string.py +30 -0
  22. gstaichi/_version.py +1 -0
  23. gstaichi/_version_check.py +103 -0
  24. gstaichi/ad/__init__.py +3 -0
  25. gstaichi/ad/_ad.py +530 -0
  26. gstaichi/algorithms/__init__.py +3 -0
  27. gstaichi/algorithms/_algorithms.py +117 -0
  28. gstaichi/assets/.git +1 -0
  29. gstaichi/assets/Go-Regular.ttf +0 -0
  30. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  31. gstaichi/examples/minimal.py +28 -0
  32. gstaichi/experimental.py +16 -0
  33. gstaichi/lang/__init__.py +50 -0
  34. gstaichi/lang/_ndarray.py +352 -0
  35. gstaichi/lang/_ndrange.py +152 -0
  36. gstaichi/lang/_template_mapper.py +199 -0
  37. gstaichi/lang/_texture.py +172 -0
  38. gstaichi/lang/_wrap_inspect.py +189 -0
  39. gstaichi/lang/any_array.py +99 -0
  40. gstaichi/lang/argpack.py +411 -0
  41. gstaichi/lang/ast/__init__.py +5 -0
  42. gstaichi/lang/ast/ast_transformer.py +1318 -0
  43. gstaichi/lang/ast/ast_transformer_utils.py +341 -0
  44. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  45. gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
  46. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
  47. gstaichi/lang/ast/checkers.py +106 -0
  48. gstaichi/lang/ast/symbol_resolver.py +57 -0
  49. gstaichi/lang/ast/transform.py +9 -0
  50. gstaichi/lang/common_ops.py +310 -0
  51. gstaichi/lang/exception.py +80 -0
  52. gstaichi/lang/expr.py +180 -0
  53. gstaichi/lang/field.py +466 -0
  54. gstaichi/lang/impl.py +1241 -0
  55. gstaichi/lang/kernel_arguments.py +157 -0
  56. gstaichi/lang/kernel_impl.py +1382 -0
  57. gstaichi/lang/matrix.py +1881 -0
  58. gstaichi/lang/matrix_ops.py +341 -0
  59. gstaichi/lang/matrix_ops_utils.py +190 -0
  60. gstaichi/lang/mesh.py +687 -0
  61. gstaichi/lang/misc.py +778 -0
  62. gstaichi/lang/ops.py +1494 -0
  63. gstaichi/lang/runtime_ops.py +13 -0
  64. gstaichi/lang/shell.py +35 -0
  65. gstaichi/lang/simt/__init__.py +5 -0
  66. gstaichi/lang/simt/block.py +94 -0
  67. gstaichi/lang/simt/grid.py +7 -0
  68. gstaichi/lang/simt/subgroup.py +191 -0
  69. gstaichi/lang/simt/warp.py +96 -0
  70. gstaichi/lang/snode.py +489 -0
  71. gstaichi/lang/source_builder.py +150 -0
  72. gstaichi/lang/struct.py +855 -0
  73. gstaichi/lang/util.py +381 -0
  74. gstaichi/linalg/__init__.py +8 -0
  75. gstaichi/linalg/matrixfree_cg.py +310 -0
  76. gstaichi/linalg/sparse_cg.py +59 -0
  77. gstaichi/linalg/sparse_matrix.py +303 -0
  78. gstaichi/linalg/sparse_solver.py +123 -0
  79. gstaichi/math/__init__.py +11 -0
  80. gstaichi/math/_complex.py +205 -0
  81. gstaichi/math/mathimpl.py +886 -0
  82. gstaichi/profiler/__init__.py +6 -0
  83. gstaichi/profiler/kernel_metrics.py +260 -0
  84. gstaichi/profiler/kernel_profiler.py +586 -0
  85. gstaichi/profiler/memory_profiler.py +15 -0
  86. gstaichi/profiler/scoped_profiler.py +36 -0
  87. gstaichi/sparse/__init__.py +3 -0
  88. gstaichi/sparse/_sparse_grid.py +77 -0
  89. gstaichi/tools/__init__.py +12 -0
  90. gstaichi/tools/diagnose.py +117 -0
  91. gstaichi/tools/np2ply.py +364 -0
  92. gstaichi/tools/vtk.py +38 -0
  93. gstaichi/types/__init__.py +19 -0
  94. gstaichi/types/annotations.py +47 -0
  95. gstaichi/types/compound_types.py +90 -0
  96. gstaichi/types/enums.py +49 -0
  97. gstaichi/types/ndarray_type.py +147 -0
  98. gstaichi/types/primitive_types.py +206 -0
  99. gstaichi/types/quant.py +88 -0
  100. gstaichi/types/texture_type.py +85 -0
  101. gstaichi/types/utils.py +13 -0
  102. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +5 -0
  103. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +29 -0
  104. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +113 -0
  105. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +5 -0
  106. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +19 -0
  107. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +122 -0
  108. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +5 -0
  109. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +19 -0
  110. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +122 -0
  111. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +5 -0
  112. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +19 -0
  113. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +122 -0
  114. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +5 -0
  115. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +19 -0
  116. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +122 -0
  117. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +5 -0
  118. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  119. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +122 -0
  120. gstaichi-0.1.25.dev0.data/data/bin/SPIRV-Tools-shared.dll +0 -0
  121. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/instrument.hpp +268 -0
  122. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.h +907 -0
  123. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.hpp +375 -0
  124. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/linker.hpp +97 -0
  125. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/optimizer.hpp +970 -0
  126. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-diff.lib +0 -0
  127. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-link.lib +0 -0
  128. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-lint.lib +0 -0
  129. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-opt.lib +0 -0
  130. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-reduce.lib +0 -0
  131. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-shared.lib +0 -0
  132. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools.lib +0 -0
  133. gstaichi-0.1.25.dev0.dist-info/METADATA +105 -0
  134. gstaichi-0.1.25.dev0.dist-info/RECORD +138 -0
  135. gstaichi-0.1.25.dev0.dist-info/WHEEL +5 -0
  136. gstaichi-0.1.25.dev0.dist-info/entry_points.txt +2 -0
  137. gstaichi-0.1.25.dev0.dist-info/licenses/LICENSE +201 -0
  138. gstaichi-0.1.25.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,320 @@
1
+ # type: ignore
2
+
3
+ import ast
4
+ import dataclasses
5
+ from typing import Any, Callable
6
+
7
+ from gstaichi.lang import (
8
+ _ndarray,
9
+ any_array,
10
+ expr,
11
+ impl,
12
+ kernel_arguments,
13
+ matrix,
14
+ )
15
+ from gstaichi.lang import ops as ti_ops
16
+ from gstaichi.lang.argpack import ArgPackType
17
+ from gstaichi.lang.ast.ast_transformer_utils import (
18
+ ASTTransformerContext,
19
+ )
20
+ from gstaichi.lang.exception import (
21
+ GsTaichiSyntaxError,
22
+ )
23
+ from gstaichi.lang.matrix import MatrixType
24
+ from gstaichi.lang.struct import StructType
25
+ from gstaichi.lang.util import to_gstaichi_type
26
+ from gstaichi.types import annotations, ndarray_type, primitive_types, texture_type
27
+
28
+
29
class FunctionDefTransformer:
    """Lowers the ``FunctionDef`` node of a ``ti.kernel`` / ``ti.func`` body.

    Kernels and real functions have their parameters declared as kernel arguments
    through ``kernel_arguments``. Inline ``ti.func`` parameters are instead bound
    to the values found in ``ctx.argument_data``. In both cases the function body
    is then built inside a fresh variable scope.
    """

    @staticmethod
    def _decl_and_create_variable(
        ctx: ASTTransformerContext, annotation, name, arg_features, invoke_later_dict, prefix_name, arg_depth
    ) -> tuple[bool, Any]:
        """Declare a single kernel argument, possibly nested inside an argpack.

        Args:
            ctx: Transformer context. Non-``RefType`` argument names are appended
                to ``ctx.kernel_args``.
            annotation: The argument's type annotation.
            name: The argument's name.
            arg_features: Features used to specialize the declaration. For argpacks
                this is indexed per member; for ndarrays it is indexed 0..3; for
                textures it is indexed 0..2.
            invoke_later_dict: Receives deferred declarations for argpack members,
                keyed by flattened member name.
            prefix_name: Prefix used to build this argument's flattened name.
            arg_depth: Nesting depth, forwarded to the ``decl_*_arg`` helpers.

        Returns:
            ``(True, value)`` when the argument was declared immediately.
            ``(False, (decl_func, decl_args))`` when the declaration is deferred;
            the caller must later call ``decl_func(*decl_args)``.
        """
        full_name = prefix_name + "_" + name
        if not isinstance(annotation, primitive_types.RefType):
            ctx.kernel_args.append(name)

        if isinstance(annotation, ArgPackType):
            kernel_arguments.push_argpack_arg(name)
            members = {}
            deferred: list[tuple[str, str, Any]] = []
            for j, (member_name, member_anno) in enumerate(annotation.members.items()):
                declared, obj = FunctionDefTransformer._decl_and_create_variable(
                    ctx, member_anno, member_name, arg_features[j], invoke_later_dict, full_name, arg_depth + 1
                )
                if declared:
                    members[member_name] = obj
                else:
                    # Placeholder now; filled in once the enclosing argpack object exists.
                    members[member_name] = None
                    deferred.append((full_name + "_" + member_name, member_name, obj))
            argpack = kernel_arguments.decl_argpack_arg(annotation, members)
            for key, member_name, obj in deferred:
                invoke_later_dict[key] = argpack, member_name, *obj
            return True, argpack

        if annotation == annotations.template or isinstance(annotation, annotations.template):
            return True, ctx.global_vars[name]
        if isinstance(annotation, annotations.sparse_matrix_builder):
            return False, (
                kernel_arguments.decl_sparse_matrix,
                (
                    to_gstaichi_type(arg_features),
                    full_name,
                ),
            )
        if isinstance(annotation, ndarray_type.NdarrayType):
            return False, (
                kernel_arguments.decl_ndarray_arg,
                (
                    to_gstaichi_type(arg_features[0]),
                    arg_features[1],
                    full_name,
                    arg_features[2],
                    arg_features[3],
                ),
            )
        if isinstance(annotation, texture_type.TextureType):
            return False, (kernel_arguments.decl_texture_arg, (arg_features[0], full_name))
        if isinstance(annotation, texture_type.RWTextureType):
            return False, (
                kernel_arguments.decl_rw_texture_arg,
                (arg_features[0], arg_features[1], arg_features[2], full_name),
            )
        if isinstance(annotation, MatrixType):
            return True, kernel_arguments.decl_matrix_arg(annotation, name, arg_depth)
        if isinstance(annotation, StructType):
            return True, kernel_arguments.decl_struct_arg(annotation, name, arg_depth)
        return True, kernel_arguments.decl_scalar_arg(annotation, name, arg_depth)

    @staticmethod
    def _materialize(declared: bool, obj: Any) -> Any:
        """Return ``obj`` as-is if it was already declared.

        Otherwise ``obj`` is a deferred ``(decl_func, decl_args)`` pair, and the
        result of ``decl_func(*decl_args)`` is returned.
        """
        if declared:
            return obj
        decl_type_func, type_args = obj
        return decl_type_func(*type_args)

    @staticmethod
    def _transform_kernel_arg(
        ctx: ASTTransformerContext,
        invoke_later_dict: dict[str, tuple[Any, str, Callable, list[Any]]],
        create_variable_later: dict[str, Any],
        argument_name: str,
        argument_type: Any,
        this_arg_features: tuple[Any, ...],
    ) -> None:
        """Declare one top-level kernel parameter and bind it in ``ctx``.

        Argpack parameters are not bound immediately. They are stored in
        ``create_variable_later``, and their deferred members go into
        ``invoke_later_dict``; the caller finishes both afterwards.

        Dataclass parameters are flattened: each field is declared as
        ``__ti_<argument_name>_<field>``, and the dataclass type itself is bound
        under ``argument_name``.
        """
        if isinstance(argument_type, ArgPackType):
            kernel_arguments.push_argpack_arg(argument_name)
            members = {}
            deferred: list[tuple[str, str, Any]] = []
            for j, (name, anno) in enumerate(argument_type.members.items()):
                declared, obj = FunctionDefTransformer._decl_and_create_variable(
                    ctx, anno, name, this_arg_features[j], invoke_later_dict, "__argpack_" + name, 1
                )
                if declared:
                    members[name] = obj
                else:
                    members[name] = None
                    deferred.append(("__argpack_" + name, name, obj))
            argpack = kernel_arguments.decl_argpack_arg(argument_type, members)
            for key, name, obj in deferred:
                invoke_later_dict[key] = argpack, name, *obj
            create_variable_later[argument_name] = argpack
        elif dataclasses.is_dataclass(argument_type):
            ctx.create_variable(argument_name, argument_type)
            for field_idx, field in enumerate(dataclasses.fields(argument_type)):
                flat_name = f"__ti_{argument_name}_{field.name}"
                declared, obj = FunctionDefTransformer._decl_and_create_variable(
                    ctx,
                    field.type,
                    flat_name,
                    this_arg_features[field_idx],
                    invoke_later_dict,
                    "",
                    0,
                )
                ctx.create_variable(flat_name, FunctionDefTransformer._materialize(declared, obj))
        else:
            declared, obj = FunctionDefTransformer._decl_and_create_variable(
                ctx,
                argument_type,
                argument_name,
                this_arg_features if ctx.arg_features is not None else None,
                invoke_later_dict,
                "",
                0,
            )
            ctx.create_variable(argument_name, FunctionDefTransformer._materialize(declared, obj))

    @staticmethod
    def _transform_as_kernel(ctx: ASTTransformerContext, node: ast.FunctionDef, args: ast.arguments) -> None:
        """Declare the return values and parameters of a kernel or real function.

        The return-type annotation is skipped when it is an ``ast.Constant``
        (e.g. ``-> None``). After all parameters are declared, the deferred argpack
        members and argpacks are completed and the parameter list is finalized.
        """
        if node.returns is not None and not isinstance(node.returns, ast.Constant):
            for return_type in ctx.func.return_type:
                kernel_arguments.decl_ret(return_type)
        impl.get_runtime().compiling_callable.finalize_rets()

        invoke_later_dict: dict[str, tuple[Any, str, Any]] = {}
        create_variable_later: dict[str, Any] = {}
        for i, _arg in enumerate(args.args):
            argument = ctx.func.arguments[i]
            FunctionDefTransformer._transform_kernel_arg(
                ctx,
                invoke_later_dict,
                create_variable_later,
                argument.name,
                argument.annotation,
                ctx.arg_features[i] if ctx.arg_features is not None else (),
            )

        # Fill argpack members whose declaration was deferred until their argpack existed.
        for argpack, name, func, params in invoke_later_dict.values():
            argpack[name] = func(*params)
        for var_name, value in create_variable_later.items():
            ctx.create_variable(var_name, value)

        impl.get_runtime().compiling_callable.finalize_params()
        # Remove the original Python parameters from the node.
        # NOTE(review): presumably so later passes do not treat them as ordinary
        # Python arguments — confirm.
        node.args.args = []

    @staticmethod
    def _is_ndarray_like(data: Any) -> bool:
        """Return True if ``data`` is one of the ndarray kinds accepted for ndarray parameters."""
        return isinstance(
            data,
            (
                _ndarray.ScalarNdarray,
                matrix.VectorNdarray,
                matrix.MatrixNdarray,
                any_array.AnyArray,
            ),
        )

    @staticmethod
    def _transform_func_arg(
        ctx: ASTTransformerContext,
        argument_name: str,
        argument_type: Any,
        data: Any,
    ) -> None:
        """Bind one inline ``ti.func`` parameter to the value passed at the call site.

        Template arguments and ndarrays are bound by reference. Matrices, primitive
        scalars and other values are copied via ``impl.expr_init_func``, so the
        callee gets its own local value.

        Raises:
            GsTaichiSyntaxError: An ndarray or dataclass-field argument is not an
                ndarray, or a matrix argument is not a tensor expression of the
                annotated shape.
        """
        if isinstance(argument_type, annotations.template):
            ctx.create_variable(argument_name, data)
            return None

        if dataclasses.is_dataclass(argument_type):
            for field in dataclasses.fields(argument_type):
                data_child = getattr(data, field.name)
                if not FunctionDefTransformer._is_ndarray_like(data_child):
                    raise GsTaichiSyntaxError(
                        f"Argument {argument_name} of type {argument_type} {field.type} is not recognized."
                    )
                field.type.check_matched(data_child.get_type(), field.name)
                ctx.create_variable(f"__ti_{argument_name}_{field.name}", data_child)
            return None

        # Ndarray arguments are passed by reference.
        if isinstance(argument_type, ndarray_type.NdarrayType):
            if not FunctionDefTransformer._is_ndarray_like(data):
                # Previously formatted ``arg.arg``, an undefined name, which raised NameError instead.
                raise GsTaichiSyntaxError(f"Argument {argument_name} of type {argument_type} is not recognized.")
            argument_type.check_matched(data.get_type(), argument_name)
            ctx.create_variable(argument_name, data)
            return None

        # Matrix arguments are passed by value.
        if isinstance(argument_type, MatrixType):
            # "data" is expected to be an Expr here, so "impl.expr_init_func(data)" performs:
            #
            #   TensorType* t = alloca()
            #   assign(t, data)
            #
            # i.e. it creates a local variable "t" that copies the passed-in "data".
            if not isinstance(data, expr.Expr) or not data.ptr.is_tensor():
                raise GsTaichiSyntaxError(
                    f"Argument {argument_name} of type {argument_type} is expected to be a Matrix, but got {type(data)}."
                )

            element_shape = data.ptr.get_rvalue_type().shape()
            if len(element_shape) != argument_type.ndim:
                raise GsTaichiSyntaxError(
                    f"Argument {argument_name} of type {argument_type} is expected to be a Matrix with ndim {argument_type.ndim}, but got {len(element_shape)}."
                )

            assert argument_type.ndim > 0
            if element_shape[0] != argument_type.n:
                raise GsTaichiSyntaxError(
                    f"Argument {argument_name} of type {argument_type} is expected to be a Matrix with n {argument_type.n}, but got {element_shape[0]}."
                )

            if argument_type.ndim == 2 and element_shape[1] != argument_type.m:
                # Report the mismatching column count (previously reported element_shape[0]).
                raise GsTaichiSyntaxError(
                    f"Argument {argument_name} of type {argument_type} is expected to be a Matrix with m {argument_type.m}, but got {element_shape[1]}."
                )

            ctx.create_variable(argument_name, impl.expr_init_func(data))
            return None

        if id(argument_type) in primitive_types.type_ids:
            # Primitive scalars are cast to the annotated type, then copied.
            ctx.create_variable(argument_name, impl.expr_init_func(ti_ops.cast(data, argument_type)))
            return None

        # Create a copy for non-template arguments, so that they are passed by value.
        ctx.create_variable(argument_name, impl.expr_init_func(data))
        return None

    @staticmethod
    def _transform_as_func(ctx: ASTTransformerContext, node: ast.FunctionDef, args: ast.arguments) -> None:
        """Bind every inline ``ti.func`` parameter from ``ctx.argument_data``.

        Dataclass-annotated parameters from ``ctx.func.orig_arguments`` are also
        bound to their dataclass type under their original name.
        """
        for data_i, data in enumerate(ctx.argument_data):
            argument = ctx.func.arguments[data_i]
            FunctionDefTransformer._transform_func_arg(
                ctx,
                argument.name,
                argument.annotation,
                data,
            )

        for v in ctx.func.orig_arguments:
            if dataclasses.is_dataclass(v.annotation):
                ctx.create_variable(v.name, v.annotation)

    @staticmethod
    def build_FunctionDef(
        ctx: ASTTransformerContext,
        node: ast.FunctionDef,
        build_stmts: Callable[[ASTTransformerContext, list[ast.stmt]], None],
    ) -> None:
        """Transform the (single) function definition of a kernel or func.

        Raises:
            GsTaichiSyntaxError: A second function definition is encountered,
                i.e. a nested ``def`` inside the kernel/func body.
        """
        if ctx.visited_funcdef:
            raise GsTaichiSyntaxError(
                f"Function definition is not allowed in 'ti.{'kernel' if ctx.is_kernel else 'func'}'."
            )
        ctx.visited_funcdef = True

        args = node.args
        # Only plain positional parameters are supported.
        assert args.vararg is None
        assert args.kwonlyargs == []
        assert args.kw_defaults == []
        assert args.kwarg is None

        if ctx.is_kernel or ctx.is_real_function:  # ti.kernel, or ti.func compiled as a real function
            FunctionDefTransformer._transform_as_kernel(ctx, node, args)
        else:  # inline ti.func
            FunctionDefTransformer._transform_as_func(ctx, node, args)

        with ctx.variable_scope_guard():
            build_stmts(ctx, node.body)

        return None
@@ -0,0 +1,106 @@
1
+ # type: ignore
2
+
3
+ import ast
4
+
5
+ from gstaichi.lang._wrap_inspect import getsourcefile, getsourcelines
6
+ from gstaichi.lang.exception import GsTaichiSyntaxError
7
+
8
+
9
class KernelSimplicityASTChecker(ast.NodeVisitor):
    """AST visitor that checks statement placement inside a kernel.

    The checker keeps a stack of :class:`ScopeGuard` objects, one per open scope.
    Visiting a statement in a scope that no longer allows statements raises
    ``GsTaichiSyntaxError``.
    """

    class ScopeGuard:
        """Context manager recording what is still allowed inside one scope.

        Entering pushes the guard onto its checker's scope stack; exiting pops it.
        """

        def __init__(self, checker):
            self.c = checker
            self._allows_for_loop = True
            self._allows_more_stmt = True

        @property
        def allows_for_loop(self):
            return self._allows_for_loop

        @property
        def allows_more_stmt(self):
            return self._allows_more_stmt

        def mark_no_more_for_loop(self):
            self._allows_for_loop = False

        def mark_no_more_stmt(self):
            # Disallowing further statements also disallows further for-loops.
            self._allows_for_loop = False
            self._allows_more_stmt = False

        def __enter__(self):
            self.c._scope_guards.append(self)
            # Return the guard so ``with checker.new_scope() as scope:`` binds it.
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            self.c._scope_guards.pop()

    def __init__(self, func):
        super().__init__()
        self._func_file = getsourcefile(func)
        self._func_lineno = getsourcelines(func)[1]
        self._func_name = func.__name__
        self._scope_guards = []

    def new_scope(self):
        return KernelSimplicityASTChecker.ScopeGuard(self)

    @property
    def current_scope(self):
        return self._scope_guards[-1]

    @property
    def top_level(self):
        return len(self._scope_guards) == 0

    def get_error_location(self, node):
        # -1 because ast's lineno is 1-based.
        lineno = self._func_lineno + node.lineno - 1
        return f"file={self._func_file} kernel={self._func_name} line={lineno}"

    @staticmethod
    def should_check(node):
        """Return True for statement nodes other than function/class definitions."""
        if not isinstance(node, ast.stmt):
            return False
        # TODO(#536): Frontend pass should help make sure |func| is a valid AST for
        # GsTaichi.
        return not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef))

    def generic_visit(self, node):
        if not self.should_check(node):
            super().generic_visit(node)
            return

        if not (self.top_level or self.current_scope.allows_more_stmt):
            raise GsTaichiSyntaxError(f"No more statements allowed, at {self.get_error_location(node)}")

        if self.top_level:
            # Each top-level statement gets its own scope. The guard pops it even when
            # a nested statement raises, so the scope stack is not left stale.
            with self.new_scope():
                self._visit_checked_stmt(node)
        else:
            self._visit_checked_stmt(node)

    def _visit_checked_stmt(self, node):
        """Mark the current scope and visit ``node``'s children."""
        # Marking here, before the visit, disallows for-loops in nested blocks:
        # e.g. if |node| is an If statement, the checker disallows for-loops inside it.
        self.current_scope.mark_no_more_for_loop()
        super().generic_visit(node)

    @staticmethod
    def visit_for(node):
        # NOTE: ast.NodeVisitor dispatches on "visit_" + class name (i.e. visit_For),
        # so this lowercase method is never called automatically; For nodes go
        # through generic_visit instead.
        # TODO: since autodiff is enhanced, AST checker rules should be relaxed. This part should be updated.
        # original code is #def visit_For(self, node) without #@staticmethod before fix pylint R0201
        return
        # is_static = (isinstance(node.iter, ast.Call)
        #              and isinstance(node.iter.func, ast.Attribute)
        #              and isinstance(node.iter.func.value, ast.Name)
        #              and node.iter.func.value.id == 'ti'
        #              and node.iter.func.attr == 'static')
        # if not (self.top_level or self.current_scope.allows_for_loop
        #         or is_static):
        #     raise GsTaichiSyntaxError(
        #         f'No more for loops allowed, at {self.get_error_location(node)}'
        #     )
        # with self.new_scope():
        #     super().generic_visit(node)
        #
        # if not (self.top_level or is_static):
        #     self.current_scope.mark_no_more_stmt()
@@ -0,0 +1,57 @@
1
+ # type: ignore
2
+
3
+ """Provides helpers to resolve AST nodes."""
4
+
5
+ import ast
6
+
7
+
8
class ASTResolver:
    """Provides helper methods to resolve AST nodes."""

    @staticmethod
    def resolve_to(node, wanted, scope):
        """Check if symbol ``node`` resolves to ``wanted`` object.

        This is only intended to check if a given AST node resolves to a symbol
        under some namespaces, e.g. the ``a.b.c.foo`` pattern, but not meant for
        more complicated expressions like ``(a + b).foo``.

        Args:
            node (Union[ast.Attribute, ast.Name]): an AST node to be resolved.
            wanted (Any): The expected python object.
            scope (Dict[str, Any]): Maps from symbol names to objects, for
                example, globals()

        Returns:
            bool: True only if every name on the path is bound and the final
            object is ``wanted`` (identity comparison). An unbound name never
            resolves, even when ``wanted`` is None.
        """
        if isinstance(node, ast.Name):
            # Test membership first so a missing name is not treated as bound to None.
            return node.id in scope and scope[node.id] is wanted

        if not isinstance(node, ast.Attribute):
            return False

        # Walk ``a.b.c`` from the outside in, collecting ["c", "b", "a"].
        chain = [node.attr]
        base = node.value
        while isinstance(base, ast.Attribute):
            chain.append(base.attr)
            base = base.value
        if not isinstance(base, ast.Name):
            # Example cases that fall under this branch:
            #
            # x[i].attr: ast.Subscript
            # (a + b).attr: ast.BinOp
            # ...
            return False
        chain.append(base.id)

        # Resolve from the root name outwards: dict lookup at dict levels, attribute access elsewhere.
        target = scope
        for attr in reversed(chain):
            try:
                target = target[attr] if isinstance(target, dict) else getattr(target, attr)
            except (KeyError, AttributeError):
                return False
        return target is wanted
@@ -0,0 +1,9 @@
1
+ # type: ignore
2
+
3
+ from gstaichi.lang.ast.ast_transformer import ASTTransformer
4
+ from gstaichi.lang.ast.ast_transformer_utils import ASTTransformerContext
5
+
6
+
7
def transform_tree(tree, ctx: ASTTransformerContext):
    """Run a fresh ``ASTTransformer`` over ``tree`` using ``ctx``.

    Args:
        tree: The Python AST to transform.
        ctx: The transformer context passed to the transformer.

    Returns:
        ``ctx.return_data``, read after the transformer has run.
    """
    transformer = ASTTransformer()
    transformer(ctx, tree)
    return ctx.return_data
+ return ctx.return_data