gstaichi 2.1.1rc3__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. gstaichi/CHANGELOG.md +4 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +3 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cpython-313-darwin.so +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  11. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  12. gstaichi/_lib/utils.py +243 -0
  13. gstaichi/_logging.py +131 -0
  14. gstaichi/_snode/__init__.py +5 -0
  15. gstaichi/_snode/fields_builder.py +187 -0
  16. gstaichi/_snode/snode_tree.py +34 -0
  17. gstaichi/_test_tools/__init__.py +18 -0
  18. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  19. gstaichi/_test_tools/load_kernel_string.py +30 -0
  20. gstaichi/_test_tools/textwrap2.py +6 -0
  21. gstaichi/_version.py +1 -0
  22. gstaichi/_version_check.py +100 -0
  23. gstaichi/ad/__init__.py +3 -0
  24. gstaichi/ad/_ad.py +530 -0
  25. gstaichi/algorithms/__init__.py +3 -0
  26. gstaichi/algorithms/_algorithms.py +117 -0
  27. gstaichi/assets/.git +1 -0
  28. gstaichi/assets/Go-Regular.ttf +0 -0
  29. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  30. gstaichi/examples/lcg_python.py +26 -0
  31. gstaichi/examples/lcg_taichi.py +34 -0
  32. gstaichi/examples/minimal.py +28 -0
  33. gstaichi/experimental.py +16 -0
  34. gstaichi/lang/__init__.py +50 -0
  35. gstaichi/lang/_dataclass_util.py +31 -0
  36. gstaichi/lang/_fast_caching/__init__.py +3 -0
  37. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  38. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  39. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  40. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  41. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  42. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  43. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  44. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  45. gstaichi/lang/_ndarray.py +352 -0
  46. gstaichi/lang/_ndrange.py +152 -0
  47. gstaichi/lang/_template_mapper.py +195 -0
  48. gstaichi/lang/_texture.py +172 -0
  49. gstaichi/lang/_wrap_inspect.py +215 -0
  50. gstaichi/lang/any_array.py +99 -0
  51. gstaichi/lang/ast/__init__.py +5 -0
  52. gstaichi/lang/ast/ast_transformer.py +1323 -0
  53. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  54. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  55. gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
  56. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  57. gstaichi/lang/ast/checkers.py +106 -0
  58. gstaichi/lang/ast/symbol_resolver.py +57 -0
  59. gstaichi/lang/ast/transform.py +9 -0
  60. gstaichi/lang/common_ops.py +310 -0
  61. gstaichi/lang/exception.py +80 -0
  62. gstaichi/lang/expr.py +180 -0
  63. gstaichi/lang/field.py +428 -0
  64. gstaichi/lang/impl.py +1243 -0
  65. gstaichi/lang/kernel_arguments.py +155 -0
  66. gstaichi/lang/kernel_impl.py +1341 -0
  67. gstaichi/lang/matrix.py +1835 -0
  68. gstaichi/lang/matrix_ops.py +341 -0
  69. gstaichi/lang/matrix_ops_utils.py +190 -0
  70. gstaichi/lang/mesh.py +687 -0
  71. gstaichi/lang/misc.py +782 -0
  72. gstaichi/lang/ops.py +1494 -0
  73. gstaichi/lang/runtime_ops.py +13 -0
  74. gstaichi/lang/shell.py +35 -0
  75. gstaichi/lang/simt/__init__.py +5 -0
  76. gstaichi/lang/simt/block.py +94 -0
  77. gstaichi/lang/simt/grid.py +7 -0
  78. gstaichi/lang/simt/subgroup.py +191 -0
  79. gstaichi/lang/simt/warp.py +96 -0
  80. gstaichi/lang/snode.py +489 -0
  81. gstaichi/lang/source_builder.py +150 -0
  82. gstaichi/lang/struct.py +810 -0
  83. gstaichi/lang/util.py +312 -0
  84. gstaichi/linalg/__init__.py +8 -0
  85. gstaichi/linalg/matrixfree_cg.py +310 -0
  86. gstaichi/linalg/sparse_cg.py +59 -0
  87. gstaichi/linalg/sparse_matrix.py +303 -0
  88. gstaichi/linalg/sparse_solver.py +123 -0
  89. gstaichi/math/__init__.py +11 -0
  90. gstaichi/math/_complex.py +205 -0
  91. gstaichi/math/mathimpl.py +886 -0
  92. gstaichi/profiler/__init__.py +6 -0
  93. gstaichi/profiler/kernel_metrics.py +260 -0
  94. gstaichi/profiler/kernel_profiler.py +586 -0
  95. gstaichi/profiler/memory_profiler.py +15 -0
  96. gstaichi/profiler/scoped_profiler.py +36 -0
  97. gstaichi/sparse/__init__.py +3 -0
  98. gstaichi/sparse/_sparse_grid.py +77 -0
  99. gstaichi/tools/__init__.py +12 -0
  100. gstaichi/tools/diagnose.py +117 -0
  101. gstaichi/tools/np2ply.py +364 -0
  102. gstaichi/tools/vtk.py +38 -0
  103. gstaichi/types/__init__.py +19 -0
  104. gstaichi/types/annotations.py +52 -0
  105. gstaichi/types/compound_types.py +71 -0
  106. gstaichi/types/enums.py +49 -0
  107. gstaichi/types/ndarray_type.py +169 -0
  108. gstaichi/types/primitive_types.py +206 -0
  109. gstaichi/types/quant.py +88 -0
  110. gstaichi/types/texture_type.py +85 -0
  111. gstaichi/types/utils.py +11 -0
  112. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
  113. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
  114. gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
  115. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
  116. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
  117. gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
  118. gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
  119. gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  120. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
  121. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
  122. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  123. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  124. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  125. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  126. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  127. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  128. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  129. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  130. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  131. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  132. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  133. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  134. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  135. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  136. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  137. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  138. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  139. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  140. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  141. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  142. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  143. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  144. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  145. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  146. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  147. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  148. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  149. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  150. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  151. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  152. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  153. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  154. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  155. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  156. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  157. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  158. gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  159. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  160. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  161. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  162. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  163. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  164. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  165. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  166. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  167. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  168. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  169. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  170. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  171. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  172. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  173. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  174. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  175. gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
  176. gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
  177. gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
  178. gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
  179. gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,304 @@
1
+ # type: ignore
2
+
3
+ import ast
4
+ import dataclasses
5
+ from typing import Any, Callable
6
+
7
+ from gstaichi._lib.core.gstaichi_python import (
8
+ BoundaryMode,
9
+ DataTypeCxx,
10
+ )
11
+ from gstaichi.lang import (
12
+ _ndarray,
13
+ any_array,
14
+ expr,
15
+ impl,
16
+ kernel_arguments,
17
+ matrix,
18
+ )
19
+ from gstaichi.lang import ops as ti_ops
20
+ from gstaichi.lang._dataclass_util import create_flat_name
21
+ from gstaichi.lang.ast.ast_transformer_utils import (
22
+ ASTTransformerContext,
23
+ )
24
+ from gstaichi.lang.exception import (
25
+ GsTaichiSyntaxError,
26
+ )
27
+ from gstaichi.lang.matrix import MatrixType
28
+ from gstaichi.lang.struct import StructType
29
+ from gstaichi.lang.util import to_gstaichi_type
30
+ from gstaichi.types import annotations, ndarray_type, primitive_types, texture_type
31
+
32
+
33
class FunctionDefTransformer:
    """Builds the top-level ``FunctionDef`` of a ``ti.kernel`` / ``ti.func``.

    :meth:`build_FunctionDef` is the entry point. Kernels and real functions declare
    their return values and parameters through ``kernel_arguments``. Other ``ti.func``
    calls instead bind the already-evaluated call arguments found in
    ``ctx.argument_data``. Either way, every parameter name is bound through
    ``ctx.create_variable`` before the body statements are built.
    """

    @staticmethod
    def _decl_and_create_variable(
        ctx: ASTTransformerContext,
        annotation: Any,
        name: str,
        this_arg_features: tuple[tuple[Any, ...], ...] | None,
        prefix_name: str,
    ) -> tuple[bool, Any]:
        """Decide how a single (non-dataclass) kernel parameter is declared.

        Args:
            ctx: The transformer context. ``name`` is appended to ``ctx.kernel_args``
                unless ``annotation`` is a ``primitive_types.RefType``.
            annotation: The parameter's type annotation.
            name: The (possibly flattened) parameter name.
            this_arg_features: The features recorded for this parameter. Their layout
                depends on the annotation kind: an ndarray unpacks 4 values, and a
                rw-texture reads 3.
            prefix_name: Joined to ``name`` with ``"_"`` to form the name passed to
                the deferred declaration functions.

        Returns:
            Either ``(True, value)``, when ``value`` is already the object to bind to
            ``name``, or ``(False, (decl_func, decl_args))``, when the caller must call
            ``decl_func(*decl_args)`` to obtain it.
        """
        full_name = prefix_name + "_" + name
        if not isinstance(annotation, primitive_types.RefType):
            ctx.kernel_args.append(name)
        if annotation == annotations.template or isinstance(annotation, annotations.template):
            # Template parameters are resolved from the captured global variables.
            assert ctx.global_vars is not None
            return True, ctx.global_vars[name]
        if isinstance(annotation, annotations.sparse_matrix_builder):
            return False, (
                kernel_arguments.decl_sparse_matrix,
                (to_gstaichi_type(this_arg_features), full_name),
            )
        if isinstance(annotation, ndarray_type.NdarrayType):
            assert this_arg_features is not None
            raw_element_type: DataTypeCxx
            ndim: int
            needs_grad: bool
            boundary: BoundaryMode
            raw_element_type, ndim, needs_grad, boundary = this_arg_features
            return False, (
                kernel_arguments.decl_ndarray_arg,
                (to_gstaichi_type(raw_element_type), ndim, full_name, needs_grad, boundary),
            )
        if isinstance(annotation, texture_type.TextureType):
            assert this_arg_features is not None
            return False, (kernel_arguments.decl_texture_arg, (this_arg_features[0], full_name))
        if isinstance(annotation, texture_type.RWTextureType):
            assert this_arg_features is not None
            return False, (
                kernel_arguments.decl_rw_texture_arg,
                (this_arg_features[0], this_arg_features[1], this_arg_features[2], full_name),
            )
        # Matrix, struct and scalar parameters are declared immediately, under the
        # unprefixed ``name``.
        if isinstance(annotation, MatrixType):
            return True, kernel_arguments.decl_matrix_arg(annotation, name)
        if isinstance(annotation, StructType):
            return True, kernel_arguments.decl_struct_arg(annotation, name)
        return True, kernel_arguments.decl_scalar_arg(annotation, name)

    @staticmethod
    def _bind_declared_arg(
        ctx: ASTTransformerContext,
        name: str,
        annotation: Any,
        this_arg_features: Any,
    ) -> None:
        """Declare one non-dataclass kernel parameter and bind the result to ``name``."""
        is_ready, obj = FunctionDefTransformer._decl_and_create_variable(
            ctx,
            annotation,
            name,
            this_arg_features,
            "",
        )
        if not is_ready:
            # Deferred declaration: invoke decl_func(*decl_args) to get the value.
            decl_type_func, type_args = obj
            obj = decl_type_func(*type_args)
        ctx.create_variable(name, obj)

    @staticmethod
    def _transform_kernel_arg(
        ctx: ASTTransformerContext,
        argument_name: str,
        argument_type: Any,
        this_arg_features: tuple[Any, ...],
    ) -> None:
        """Declare one kernel parameter and bind it, or its flattened fields, in ``ctx``.

        A dataclass parameter is flattened. The dataclass type itself is bound under
        ``argument_name``, and each field is declared under
        ``create_flat_name(argument_name, field.name)``. Nested dataclasses are
        handled recursively. For dataclasses, ``this_arg_features`` is indexed by
        field position.
        """
        if dataclasses.is_dataclass(argument_type):
            ctx.create_variable(argument_name, argument_type)
            for field_idx, field in enumerate(dataclasses.fields(argument_type)):
                flat_name = create_flat_name(argument_name, field.name)
                field_features = this_arg_features[field_idx]
                if dataclasses.is_dataclass(field.type):
                    # A field that is itself a dataclass is flattened recursively.
                    FunctionDefTransformer._transform_kernel_arg(ctx, flat_name, field.type, field_features)
                else:
                    FunctionDefTransformer._bind_declared_arg(ctx, flat_name, field.type, field_features)
        else:
            FunctionDefTransformer._bind_declared_arg(
                ctx,
                argument_name,
                argument_type,
                this_arg_features if ctx.arg_features is not None else None,
            )

    @staticmethod
    def _transform_as_kernel(ctx: ASTTransformerContext, node: ast.FunctionDef, args: ast.arguments) -> None:
        """Declare the return values and parameters of the callable being compiled.

        Return types are declared unless ``node.returns`` is absent or is an
        ``ast.Constant`` (e.g. ``-> None``). After the parameters are declared,
        ``node.args.args`` is cleared.
        """
        assert ctx.func is not None
        assert ctx.arg_features is not None
        if node.returns is not None:
            if not isinstance(node.returns, ast.Constant):
                assert ctx.func.return_type is not None
                for return_type in ctx.func.return_type:
                    kernel_arguments.decl_ret(return_type)
        compiling_callable = impl.get_runtime().compiling_callable
        assert compiling_callable is not None
        compiling_callable.finalize_rets()

        for i in range(len(args.args)):
            arg_meta = ctx.func.arg_metas[i]
            FunctionDefTransformer._transform_kernel_arg(
                ctx,
                arg_meta.name,
                arg_meta.annotation,
                ctx.arg_features[i] if ctx.arg_features is not None else (),
            )

        compiling_callable.finalize_params()
        # remove original args
        node.args.args = []

    @staticmethod
    def _is_ndarray_like(data: Any) -> bool:
        """Return True if ``data`` is one of the ndarray-style values passed by reference."""
        # The tuple is built at call time rather than at class-definition time.
        return isinstance(
            data,
            (_ndarray.ScalarNdarray, matrix.VectorNdarray, matrix.MatrixNdarray, any_array.AnyArray),
        )

    @staticmethod
    def _transform_func_arg(
        ctx: ASTTransformerContext,
        argument_name: str,
        argument_type: Any,
        data: Any,
    ) -> None:
        """Bind the ``ti.func`` argument value ``data`` to ``argument_name`` in ``ctx``.

        - Templates and ndarrays are bound by reference.
        - Dataclass arguments are flattened field by field.
        - Matrices, primitive-typed values and any other value are copied into a new
          local via ``impl.expr_init_func``. Primitive-typed values are first cast
          with ``ti_ops.cast``.

        Raises:
            GsTaichiSyntaxError: A dataclass field or ndarray argument holds an
                unsupported value, or a matrix argument's shape does not match its
                annotation.
        """
        # Template arguments are passed by reference.
        if isinstance(argument_type, annotations.template):
            ctx.create_variable(argument_name, data)
            return None

        if dataclasses.is_dataclass(argument_type):
            for field in dataclasses.fields(argument_type):
                flat_name = create_flat_name(argument_name, field.name)
                data_child = getattr(data, field.name)
                if FunctionDefTransformer._is_ndarray_like(data_child):
                    field.type.check_matched(data_child.get_type(), field.name)
                    ctx.create_variable(flat_name, data_child)
                elif dataclasses.is_dataclass(data_child):
                    FunctionDefTransformer._transform_func_arg(ctx, flat_name, field.type, data_child)
                else:
                    raise GsTaichiSyntaxError(
                        f"Argument {field.name} of type {argument_type} {field.type} is not recognized."
                    )
            return None

        # Ndarray arguments are passed by reference.
        if isinstance(argument_type, ndarray_type.NdarrayType):
            if not FunctionDefTransformer._is_ndarray_like(data):
                raise GsTaichiSyntaxError(f"Argument {argument_name} of type {argument_type} is not recognized.")
            argument_type.check_matched(data.get_type(), argument_name)
            ctx.create_variable(argument_name, data)
            return None

        # Matrix arguments are passed by value.
        if isinstance(argument_type, MatrixType):
            # "data" is expected to be an Expr here,
            # so we simply call "impl.expr_init_func(data)" to perform:
            #
            # TensorType* t = alloca()
            # assign(t, data)
            #
            # We created local variable "t" - a copy of the passed-in argument "data"
            if not isinstance(data, expr.Expr) or not data.ptr.is_tensor():
                raise GsTaichiSyntaxError(
                    f"Argument {argument_name} of type {argument_type} is expected to be a Matrix, but got {type(data)}."
                )

            element_shape = data.ptr.get_rvalue_type().shape()
            if len(element_shape) != argument_type.ndim:
                raise GsTaichiSyntaxError(
                    f"Argument {argument_name} of type {argument_type} is expected to be a Matrix with ndim {argument_type.ndim}, but got {len(element_shape)}."
                )

            assert argument_type.ndim > 0
            if element_shape[0] != argument_type.n:
                raise GsTaichiSyntaxError(
                    f"Argument {argument_name} of type {argument_type} is expected to be a Matrix with n {argument_type.n}, but got {element_shape[0]}."
                )

            # The column count is element_shape[1]; report that value on mismatch.
            if argument_type.ndim == 2 and element_shape[1] != argument_type.m:
                raise GsTaichiSyntaxError(
                    f"Argument {argument_name} of type {argument_type} is expected to be a Matrix with m {argument_type.m}, but got {element_shape[1]}."
                )

            ctx.create_variable(argument_name, impl.expr_init_func(data))
            return None

        # Primitive-typed arguments are cast to the annotated type before being copied.
        if id(argument_type) in primitive_types.type_ids:
            ctx.create_variable(argument_name, impl.expr_init_func(ti_ops.cast(data, argument_type)))
            return None
        # Create a copy for non-template arguments,
        # so that they are passed by value.
        ctx.create_variable(argument_name, impl.expr_init_func(data))
        return None

    @staticmethod
    def _transform_as_func(ctx: ASTTransformerContext, node: ast.FunctionDef, args: ast.arguments) -> None:
        """Bind each value in ``ctx.argument_data`` to its ``ti.func`` parameter.

        The parameter is taken from the matching ``ctx.func.arg_metas`` entry. Every
        original argument annotated with a dataclass also has its dataclass type
        bound under the argument's name.
        """
        # pylint: disable=import-outside-toplevel
        from gstaichi.lang.kernel_impl import Func

        assert isinstance(ctx.func, Func)
        assert ctx.argument_data is not None
        for data_i, data in enumerate(ctx.argument_data):
            argument = ctx.func.arg_metas[data_i]
            FunctionDefTransformer._transform_func_arg(ctx, argument.name, argument.annotation, data)

        # deal with dataclasses
        for v in ctx.func.orig_arguments:
            if dataclasses.is_dataclass(v.annotation):
                ctx.create_variable(v.name, v.annotation)

    @staticmethod
    def build_FunctionDef(
        ctx: ASTTransformerContext,
        node: ast.FunctionDef,
        build_stmts: Callable[[ASTTransformerContext, list[ast.stmt]], None],
    ) -> None:
        """Handle the callable's ``FunctionDef``: bind its parameters, then build its body.

        Kernels always declare their returns and parameters. The method then returns
        early if ``ctx.only_parse_function_def`` is set. A ``ti.func`` is handled like
        a kernel when ``ctx.is_real_function`` is true; otherwise its argument values
        are bound directly. The body is built inside ``ctx.variable_scope_guard()``.

        Raises:
            GsTaichiSyntaxError: A function definition was already visited in this
                context, i.e. a nested function definition was found.
        """
        if ctx.visited_funcdef:
            raise GsTaichiSyntaxError(
                f"Function definition is not allowed in 'ti.{'kernel' if ctx.is_kernel else 'func'}'."
            )
        ctx.visited_funcdef = True

        args = node.args
        # *args, keyword-only arguments and **kwargs are not supported.
        assert args.vararg is None
        assert args.kwonlyargs == []
        assert args.kw_defaults == []
        assert args.kwarg is None

        if ctx.is_kernel:  # ti.kernel
            FunctionDefTransformer._transform_as_kernel(ctx, node, args)

        if ctx.only_parse_function_def:
            return None

        if not ctx.is_kernel:  # ti.func
            assert ctx.argument_data is not None
            assert ctx.func is not None
            if ctx.is_real_function:
                FunctionDefTransformer._transform_as_kernel(ctx, node, args)
            else:
                FunctionDefTransformer._transform_as_func(ctx, node, args)

        with ctx.variable_scope_guard():
            build_stmts(ctx, node.body)

        return None
@@ -0,0 +1,106 @@
1
+ # type: ignore
2
+
3
+ import ast
4
+
5
+ from gstaichi.lang._wrap_inspect import getsourcefile, getsourcelines
6
+ from gstaichi.lang.exception import GsTaichiSyntaxError
7
+
8
+
9
class KernelSimplicityASTChecker(ast.NodeVisitor):
    """AST visitor that enforces structural restrictions on a kernel's statements.

    Each top-level statement is visited inside a fresh :class:`ScopeGuard`. A nested
    statement visited while the innermost scope disallows further statements raises
    :class:`GsTaichiSyntaxError`.
    """

    class ScopeGuard:
        """Tracks what is still allowed inside one checker scope.

        When used as a context manager, it pushes itself onto the owning checker's
        ``_scope_guards`` stack on entry and pops it on exit, including when an
        exception propagates.
        """

        def __init__(self, checker):
            self.c = checker
            self._allows_for_loop = True
            self._allows_more_stmt = True

        @property
        def allows_for_loop(self):
            return self._allows_for_loop

        @property
        def allows_more_stmt(self):
            return self._allows_more_stmt

        def mark_no_more_for_loop(self):
            self._allows_for_loop = False

        def mark_no_more_stmt(self):
            # Disallowing further statements also disallows further for-loops.
            self._allows_for_loop = False
            self._allows_more_stmt = False

        def __enter__(self):
            self.c._scope_guards.append(self)
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            self.c._scope_guards.pop()

    def __init__(self, func):
        """Record ``func``'s source file, starting line and name for error messages."""
        super().__init__()
        self._func_file = getsourcefile(func)
        self._func_lineno = getsourcelines(func)[1]
        self._func_name = func.__name__
        self._scope_guards = []

    def new_scope(self):
        return KernelSimplicityASTChecker.ScopeGuard(self)

    @property
    def current_scope(self):
        return self._scope_guards[-1]

    @property
    def top_level(self):
        return len(self._scope_guards) == 0

    def get_error_location(self, node):
        # -1 because ast's lineno is 1-based.
        lineno = self._func_lineno + node.lineno - 1
        return f"file={self._func_file} kernel={self._func_name} line={lineno}"

    @staticmethod
    def should_check(node):
        """Return True if ``node`` is a statement other than a function/class definition."""
        if not isinstance(node, ast.stmt):
            return False
        # TODO(#536): Frontend pass should help make sure |func| is a valid AST for
        # GsTaichi.
        return not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef))

    def generic_visit(self, node):
        """Check ``node`` against the current scope, then visit its children.

        Raises:
            GsTaichiSyntaxError: ``node`` is a checked statement, and the current
                (non-top-level) scope no longer allows statements.
        """
        if not self.should_check(node):
            super().generic_visit(node)
            return

        if not (self.top_level or self.current_scope.allows_more_stmt):
            raise GsTaichiSyntaxError(f"No more statements allowed, at {self.get_error_location(node)}")
        if self.top_level:
            # Each top-level statement gets its own scope. The `with` block ensures
            # the scope is popped even if visiting a nested node raises.
            with self.new_scope():
                self._visit_stmt_children(node)
        else:
            self._visit_stmt_children(node)

    def _visit_stmt_children(self, node):
        """Mark the current scope as for-loop-free, then visit ``node``'s children."""
        # Marking here before the visit has the effect of disallow for-loops in
        # nested blocks. E.g. if |node| is a IfStmt, then the checker would disallow
        # for-loops inside it.
        self.current_scope.mark_no_more_for_loop()
        super().generic_visit(node)

    @staticmethod
    def visit_for(node):
        # NOTE: ast.NodeVisitor dispatches on ``visit_`` + the node class name
        # (``visit_For``). This lowercase method is therefore never called by
        # ``visit``, and For statements go through ``generic_visit`` instead.
        # TODO: since autodiff is enhanced, AST checker rules should be relaxed. This part should be updated.
        # original code is #def visit_For(self, node) without #@staticmethod before fix pylint R0201
        return
        # is_static = (isinstance(node.iter, ast.Call)
        #              and isinstance(node.iter.func, ast.Attribute)
        #              and isinstance(node.iter.func.value, ast.Name)
        #              and node.iter.func.value.id == 'ti'
        #              and node.iter.func.attr == 'static')
        # if not (self.top_level or self.current_scope.allows_for_loop
        #         or is_static):
        #     raise GsTaichiSyntaxError(
        #         f'No more for loops allowed, at {self.get_error_location(node)}'
        #     )
        # with self.new_scope():
        #     super().generic_visit(node)
        #
        # if not (self.top_level or is_static):
        #     self.current_scope.mark_no_more_stmt()
@@ -0,0 +1,57 @@
1
+ # type: ignore
2
+
3
+ """Provides helpers to resolve AST nodes."""
4
+
5
+ import ast
6
+
7
+
8
class ASTResolver:
    """Static helpers for mapping simple AST name/attribute nodes back to Python objects."""

    @staticmethod
    def resolve_to(node, wanted, scope):
        """Return whether the symbol expression ``node`` evaluates to exactly ``wanted``.

        Only plain names (``foo``) and dotted attribute chains rooted at a name
        (``a.b.c.foo``) are understood. Any other expression, such as
        ``(a + b).foo`` or ``x[i].foo``, yields ``False``.

        Args:
            node (Union[ast.Attribute, ast.Name]): The AST node to resolve.
            wanted (Any): The object the symbol is expected to be, compared with ``is``.
            scope (Dict[str, Any]): Maps symbol names to objects, for example
                ``globals()``.

        Returns:
            bool: ``True`` only if every step of the chain resolves and the final
            object is ``wanted``.
        """
        if isinstance(node, ast.Name):
            return scope.get(node.id) is wanted
        if not isinstance(node, ast.Attribute):
            return False

        # Walk down the attribute chain, collecting names from the outermost
        # attribute inwards. For example, ``a.b.c`` collects ``c``, then ``b``, then ``a``.
        parts = []
        cursor = node
        while isinstance(cursor, ast.Attribute):
            parts.append(cursor.attr)
            cursor = cursor.value
        if not isinstance(cursor, ast.Name):
            # The chain is rooted at something other than a bare name,
            # e.g. ast.Subscript for x[i].attr or ast.BinOp for (a + b).attr.
            return False
        parts.append(cursor.id)
        parts.reverse()

        # Resolve root-first. Each step indexes if the current object is a dict,
        # and uses getattr otherwise.
        resolved = scope
        for part in parts:
            try:
                resolved = resolved[part] if isinstance(resolved, dict) else getattr(resolved, part)
            except (KeyError, AttributeError):
                return False
        return resolved is wanted
@@ -0,0 +1,9 @@
1
+ # type: ignore
2
+
3
+ from gstaichi.lang.ast.ast_transformer import ASTTransformer
4
+ from gstaichi.lang.ast.ast_transformer_utils import ASTTransformerContext
5
+
6
+
7
def transform_tree(tree, ctx: ASTTransformerContext):
    """Run a fresh :class:`ASTTransformer` over ``tree`` using ``ctx``.

    Returns:
        ``ctx.return_data`` as it stands after the transformer has run.
    """
    transformer = ASTTransformer()
    transformer(ctx, tree)
    return ctx.return_data