gstaichi 0.1.21.dev0__cp310-cp310-win_amd64.whl → 0.1.25.dev0__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219) hide show
  1. gstaichi/CHANGELOG.md +9 -0
  2. {taichi → gstaichi}/__init__.py +9 -13
  3. {taichi → gstaichi}/_funcs.py +8 -8
  4. {taichi → gstaichi}/_kernels.py +19 -19
  5. gstaichi/_lib/__init__.py +3 -0
  6. taichi/_lib/core/taichi_python.cp310-win_amd64.pyd → gstaichi/_lib/core/gstaichi_python.cp310-win_amd64.pyd +0 -0
  7. taichi/_lib/core/taichi_python.pyi → gstaichi/_lib/core/gstaichi_python.pyi +382 -522
  8. {taichi → gstaichi}/_lib/runtime/runtime_cuda.bc +0 -0
  9. {taichi → gstaichi}/_lib/runtime/runtime_x64.bc +0 -0
  10. {taichi → gstaichi}/_lib/utils.py +15 -15
  11. {taichi → gstaichi}/_logging.py +1 -1
  12. {taichi → gstaichi}/_main.py +24 -31
  13. gstaichi/_snode/__init__.py +5 -0
  14. {taichi → gstaichi}/_snode/fields_builder.py +27 -29
  15. {taichi → gstaichi}/_snode/snode_tree.py +5 -5
  16. gstaichi/_test_tools/__init__.py +0 -0
  17. gstaichi/_test_tools/load_kernel_string.py +30 -0
  18. gstaichi/_version.py +1 -0
  19. {taichi → gstaichi}/_version_check.py +8 -5
  20. gstaichi/ad/__init__.py +3 -0
  21. {taichi → gstaichi}/ad/_ad.py +26 -26
  22. {taichi → gstaichi}/algorithms/_algorithms.py +7 -7
  23. {taichi → gstaichi}/examples/minimal.py +1 -1
  24. {taichi → gstaichi}/experimental.py +1 -1
  25. gstaichi/lang/__init__.py +50 -0
  26. {taichi → gstaichi}/lang/_ndarray.py +30 -26
  27. {taichi → gstaichi}/lang/_ndrange.py +8 -8
  28. gstaichi/lang/_template_mapper.py +199 -0
  29. {taichi → gstaichi}/lang/_texture.py +19 -19
  30. {taichi → gstaichi}/lang/_wrap_inspect.py +7 -7
  31. {taichi → gstaichi}/lang/any_array.py +13 -13
  32. {taichi → gstaichi}/lang/argpack.py +29 -29
  33. gstaichi/lang/ast/__init__.py +5 -0
  34. {taichi → gstaichi}/lang/ast/ast_transformer.py +94 -582
  35. {taichi → gstaichi}/lang/ast/ast_transformer_utils.py +54 -41
  36. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  37. gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
  38. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
  39. {taichi → gstaichi}/lang/ast/checkers.py +5 -5
  40. gstaichi/lang/ast/transform.py +9 -0
  41. {taichi → gstaichi}/lang/common_ops.py +12 -12
  42. gstaichi/lang/exception.py +80 -0
  43. {taichi → gstaichi}/lang/expr.py +22 -22
  44. {taichi → gstaichi}/lang/field.py +29 -27
  45. {taichi → gstaichi}/lang/impl.py +116 -121
  46. {taichi → gstaichi}/lang/kernel_arguments.py +16 -16
  47. {taichi → gstaichi}/lang/kernel_impl.py +330 -363
  48. {taichi → gstaichi}/lang/matrix.py +119 -115
  49. {taichi → gstaichi}/lang/matrix_ops.py +6 -6
  50. {taichi → gstaichi}/lang/matrix_ops_utils.py +4 -4
  51. {taichi → gstaichi}/lang/mesh.py +22 -22
  52. {taichi → gstaichi}/lang/misc.py +39 -68
  53. {taichi → gstaichi}/lang/ops.py +146 -141
  54. {taichi → gstaichi}/lang/runtime_ops.py +2 -2
  55. {taichi → gstaichi}/lang/shell.py +3 -3
  56. {taichi → gstaichi}/lang/simt/__init__.py +1 -1
  57. {taichi → gstaichi}/lang/simt/block.py +7 -7
  58. {taichi → gstaichi}/lang/simt/grid.py +1 -1
  59. {taichi → gstaichi}/lang/simt/subgroup.py +1 -1
  60. {taichi → gstaichi}/lang/simt/warp.py +1 -1
  61. {taichi → gstaichi}/lang/snode.py +46 -44
  62. {taichi → gstaichi}/lang/source_builder.py +13 -13
  63. {taichi → gstaichi}/lang/struct.py +33 -33
  64. {taichi → gstaichi}/lang/util.py +24 -24
  65. gstaichi/linalg/__init__.py +8 -0
  66. {taichi → gstaichi}/linalg/matrixfree_cg.py +14 -14
  67. {taichi → gstaichi}/linalg/sparse_cg.py +10 -10
  68. {taichi → gstaichi}/linalg/sparse_matrix.py +23 -23
  69. {taichi → gstaichi}/linalg/sparse_solver.py +21 -21
  70. {taichi → gstaichi}/math/__init__.py +1 -1
  71. {taichi → gstaichi}/math/_complex.py +21 -20
  72. {taichi → gstaichi}/math/mathimpl.py +56 -56
  73. gstaichi/profiler/__init__.py +6 -0
  74. {taichi → gstaichi}/profiler/kernel_metrics.py +11 -11
  75. {taichi → gstaichi}/profiler/kernel_profiler.py +30 -36
  76. {taichi → gstaichi}/profiler/memory_profiler.py +1 -1
  77. {taichi → gstaichi}/profiler/scoped_profiler.py +2 -2
  78. {taichi → gstaichi}/sparse/_sparse_grid.py +7 -7
  79. {taichi → gstaichi}/tools/__init__.py +4 -4
  80. {taichi → gstaichi}/tools/diagnose.py +10 -17
  81. gstaichi/types/__init__.py +19 -0
  82. {taichi → gstaichi}/types/annotations.py +1 -1
  83. {taichi → gstaichi}/types/compound_types.py +8 -8
  84. {taichi → gstaichi}/types/enums.py +1 -1
  85. {taichi → gstaichi}/types/ndarray_type.py +7 -7
  86. {taichi → gstaichi}/types/primitive_types.py +17 -14
  87. {taichi → gstaichi}/types/quant.py +9 -9
  88. {taichi → gstaichi}/types/texture_type.py +5 -5
  89. {taichi → gstaichi}/types/utils.py +1 -1
  90. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/bin/SPIRV-Tools-shared.dll +0 -0
  91. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools-diff.lib +0 -0
  92. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools-link.lib +0 -0
  93. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools-lint.lib +0 -0
  94. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools-opt.lib +0 -0
  95. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools-reduce.lib +0 -0
  96. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools-shared.lib +0 -0
  97. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools.lib +0 -0
  98. {gstaichi-0.1.21.dev0.dist-info → gstaichi-0.1.25.dev0.dist-info}/METADATA +13 -16
  99. gstaichi-0.1.25.dev0.dist-info/RECORD +138 -0
  100. gstaichi-0.1.25.dev0.dist-info/entry_points.txt +2 -0
  101. gstaichi-0.1.25.dev0.dist-info/top_level.txt +1 -0
  102. gstaichi-0.1.21.dev0.data/data/include/GLFW/glfw3.h +0 -6389
  103. gstaichi-0.1.21.dev0.data/data/include/GLFW/glfw3native.h +0 -594
  104. gstaichi-0.1.21.dev0.data/data/lib/cmake/glfw3/glfw3Config.cmake +0 -3
  105. gstaichi-0.1.21.dev0.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +0 -65
  106. gstaichi-0.1.21.dev0.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +0 -19
  107. gstaichi-0.1.21.dev0.data/data/lib/cmake/glfw3/glfw3Targets.cmake +0 -107
  108. gstaichi-0.1.21.dev0.data/data/lib/glfw3.lib +0 -0
  109. gstaichi-0.1.21.dev0.dist-info/RECORD +0 -198
  110. gstaichi-0.1.21.dev0.dist-info/entry_points.txt +0 -2
  111. gstaichi-0.1.21.dev0.dist-info/top_level.txt +0 -1
  112. taichi/CHANGELOG.md +0 -17
  113. taichi/_lib/__init__.py +0 -3
  114. taichi/_lib/c_api/bin/taichi_c_api.dll +0 -0
  115. taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +0 -1401
  116. taichi/_lib/c_api/include/taichi/taichi.h +0 -29
  117. taichi/_lib/c_api/include/taichi/taichi_core.h +0 -1111
  118. taichi/_lib/c_api/include/taichi/taichi_cpu.h +0 -29
  119. taichi/_lib/c_api/include/taichi/taichi_cuda.h +0 -36
  120. taichi/_lib/c_api/include/taichi/taichi_platform.h +0 -55
  121. taichi/_lib/c_api/include/taichi/taichi_unity.h +0 -64
  122. taichi/_lib/c_api/include/taichi/taichi_vulkan.h +0 -151
  123. taichi/_lib/c_api/lib/taichi_c_api.lib +0 -0
  124. taichi/_lib/c_api/runtime/runtime_cuda.bc +0 -0
  125. taichi/_lib/c_api/runtime/runtime_x64.bc +0 -0
  126. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +0 -29
  127. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +0 -65
  128. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +0 -121
  129. taichi/_lib/runtime/slim_libdevice.10.bc +0 -0
  130. taichi/_snode/__init__.py +0 -5
  131. taichi/_ti_module/__init__.py +0 -3
  132. taichi/_ti_module/cppgen.py +0 -309
  133. taichi/_ti_module/module.py +0 -145
  134. taichi/_version.py +0 -1
  135. taichi/ad/__init__.py +0 -3
  136. taichi/aot/__init__.py +0 -12
  137. taichi/aot/_export.py +0 -28
  138. taichi/aot/conventions/__init__.py +0 -3
  139. taichi/aot/conventions/gfxruntime140/__init__.py +0 -38
  140. taichi/aot/conventions/gfxruntime140/dr.py +0 -244
  141. taichi/aot/conventions/gfxruntime140/sr.py +0 -613
  142. taichi/aot/module.py +0 -253
  143. taichi/aot/utils.py +0 -151
  144. taichi/graph/__init__.py +0 -3
  145. taichi/graph/_graph.py +0 -292
  146. taichi/lang/__init__.py +0 -50
  147. taichi/lang/ast/__init__.py +0 -5
  148. taichi/lang/ast/transform.py +0 -9
  149. taichi/lang/exception.py +0 -80
  150. taichi/linalg/__init__.py +0 -8
  151. taichi/profiler/__init__.py +0 -6
  152. taichi/shaders/Circles_vk.frag +0 -29
  153. taichi/shaders/Circles_vk.vert +0 -45
  154. taichi/shaders/Circles_vk_frag.spv +0 -0
  155. taichi/shaders/Circles_vk_vert.spv +0 -0
  156. taichi/shaders/Lines_vk.frag +0 -9
  157. taichi/shaders/Lines_vk.vert +0 -11
  158. taichi/shaders/Lines_vk_frag.spv +0 -0
  159. taichi/shaders/Lines_vk_vert.spv +0 -0
  160. taichi/shaders/Mesh_vk.frag +0 -71
  161. taichi/shaders/Mesh_vk.vert +0 -68
  162. taichi/shaders/Mesh_vk_frag.spv +0 -0
  163. taichi/shaders/Mesh_vk_vert.spv +0 -0
  164. taichi/shaders/Particles_vk.frag +0 -95
  165. taichi/shaders/Particles_vk.vert +0 -73
  166. taichi/shaders/Particles_vk_frag.spv +0 -0
  167. taichi/shaders/Particles_vk_vert.spv +0 -0
  168. taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
  169. taichi/shaders/SceneLines_vk.frag +0 -9
  170. taichi/shaders/SceneLines_vk.vert +0 -12
  171. taichi/shaders/SceneLines_vk_frag.spv +0 -0
  172. taichi/shaders/SceneLines_vk_vert.spv +0 -0
  173. taichi/shaders/SetImage_vk.frag +0 -21
  174. taichi/shaders/SetImage_vk.vert +0 -15
  175. taichi/shaders/SetImage_vk_frag.spv +0 -0
  176. taichi/shaders/SetImage_vk_vert.spv +0 -0
  177. taichi/shaders/Triangles_vk.frag +0 -16
  178. taichi/shaders/Triangles_vk.vert +0 -29
  179. taichi/shaders/Triangles_vk_frag.spv +0 -0
  180. taichi/shaders/Triangles_vk_vert.spv +0 -0
  181. taichi/shaders/lines2quad_vk_comp.spv +0 -0
  182. taichi/types/__init__.py +0 -19
  183. {taichi → gstaichi}/__main__.py +0 -0
  184. {taichi → gstaichi}/_lib/core/__init__.py +0 -0
  185. {taichi → gstaichi}/_lib/core/py.typed +0 -0
  186. {taichi/_lib/c_api → gstaichi/_lib}/runtime/slim_libdevice.10.bc +0 -0
  187. {taichi → gstaichi}/algorithms/__init__.py +0 -0
  188. {taichi → gstaichi}/assets/.git +0 -0
  189. {taichi → gstaichi}/assets/Go-Regular.ttf +0 -0
  190. {taichi → gstaichi}/assets/static/imgs/ti_gallery.png +0 -0
  191. {taichi → gstaichi}/lang/ast/symbol_resolver.py +0 -0
  192. {taichi → gstaichi}/sparse/__init__.py +0 -0
  193. {taichi → gstaichi}/tools/np2ply.py +0 -0
  194. {taichi → gstaichi}/tools/vtk.py +0 -0
  195. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +0 -0
  196. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +0 -0
  197. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +0 -0
  198. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +0 -0
  199. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +0 -0
  200. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +0 -0
  201. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +0 -0
  202. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +0 -0
  203. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +0 -0
  204. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +0 -0
  205. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +0 -0
  206. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +0 -0
  207. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +0 -0
  208. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +0 -0
  209. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +0 -0
  210. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +0 -0
  211. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +0 -0
  212. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +0 -0
  213. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/instrument.hpp +0 -0
  214. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/libspirv.h +0 -0
  215. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/libspirv.hpp +0 -0
  216. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/linker.hpp +0 -0
  217. {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/optimizer.hpp +0 -0
  218. {gstaichi-0.1.21.dev0.dist-info → gstaichi-0.1.25.dev0.dist-info}/WHEEL +0 -0
  219. {gstaichi-0.1.21.dev0.dist-info → gstaichi-0.1.25.dev0.dist-info}/licenses/LICENSE +0 -0
@@ -4,33 +4,34 @@ import ast
4
4
  import builtins
5
5
  import traceback
6
6
  from enum import Enum
7
- from sys import version_info
8
7
  from textwrap import TextWrapper
9
8
  from typing import TYPE_CHECKING, Any, List
10
9
 
11
- from taichi._lib.core.taichi_python import ASTBuilder
12
- from taichi.lang import impl
13
- from taichi.lang.exception import (
14
- TaichiCompilationError,
15
- TaichiNameError,
16
- TaichiSyntaxError,
10
+ from gstaichi._lib.core.gstaichi_python import ASTBuilder
11
+ from gstaichi.lang import impl
12
+ from gstaichi.lang._ndrange import ndrange
13
+ from gstaichi.lang.ast.symbol_resolver import ASTResolver
14
+ from gstaichi.lang.exception import (
15
+ GsTaichiCompilationError,
16
+ GsTaichiNameError,
17
+ GsTaichiSyntaxError,
17
18
  handle_exception_from_cpp,
18
19
  )
19
20
 
20
21
  if TYPE_CHECKING:
21
- from taichi.lang.kernel_impl import (
22
+ from gstaichi.lang.kernel_impl import (
22
23
  Func,
23
24
  Kernel,
24
25
  )
25
26
 
26
27
 
27
28
  class Builder:
28
- def __call__(self, ctx, node):
29
+ def __call__(self, ctx: "ASTTransformerContext", node: ast.AST):
29
30
  method = getattr(self, "build_" + node.__class__.__name__, None)
30
31
  try:
31
32
  if method is None:
32
33
  error_msg = f'Unsupported node "{node.__class__.__name__}"'
33
- raise TaichiSyntaxError(error_msg)
34
+ raise GsTaichiSyntaxError(error_msg)
34
35
  info = ctx.get_pos_info(node) if isinstance(node, (ast.stmt, ast.expr)) else ""
35
36
  with impl.get_runtime().src_info_guard(info):
36
37
  return method(ctx, node)
@@ -41,15 +42,15 @@ class Builder:
41
42
  raise e.with_traceback(None)
42
43
  ctx.raised = True
43
44
  e = handle_exception_from_cpp(e)
44
- if not isinstance(e, TaichiCompilationError):
45
+ if not isinstance(e, GsTaichiCompilationError):
45
46
  msg = ctx.get_pos_info(node) + traceback.format_exc()
46
- raise TaichiCompilationError(msg) from None
47
+ raise GsTaichiCompilationError(msg) from None
47
48
  msg = ctx.get_pos_info(node) + str(e)
48
49
  raise type(e)(msg) from None
49
50
 
50
51
 
51
52
  class VariableScopeGuard:
52
- def __init__(self, scopes):
53
+ def __init__(self, scopes: list[dict[str, Any]]):
53
54
  self.scopes = scopes
54
55
 
55
56
  def __enter__(self):
@@ -65,7 +66,7 @@ class StaticScopeStatus:
65
66
 
66
67
 
67
68
  class StaticScopeGuard:
68
- def __init__(self, status):
69
+ def __init__(self, status: StaticScopeStatus):
69
70
  self.status = status
70
71
 
71
72
  def __enter__(self):
@@ -107,7 +108,7 @@ class LoopScopeAttribute:
107
108
 
108
109
 
109
110
  class LoopScopeGuard:
110
- def __init__(self, scopes, non_static_guard=None):
111
+ def __init__(self, scopes: list[dict[str, Any]], non_static_guard=None):
111
112
  self.scopes = scopes
112
113
  self.non_static_guard = non_static_guard
113
114
 
@@ -167,7 +168,7 @@ class ASTTransformerContext:
167
168
  is_real_function: bool = False,
168
169
  ):
169
170
  self.func = func
170
- self.local_scopes = []
171
+ self.local_scopes: list[dict[str, Any]] = []
171
172
  self.loop_scopes: List[LoopScopeAttribute] = []
172
173
  self.excluded_parameters = excluded_parameters
173
174
  self.is_kernel = is_kernel
@@ -192,7 +193,7 @@ class ASTTransformerContext:
192
193
  self.ast_builder = ast_builder
193
194
  self.visited_funcdef = False
194
195
  self.is_real_function = is_real_function
195
- self.kernel_args = []
196
+ self.kernel_args: list = []
196
197
 
197
198
  # e.g.: FunctionDef, Module, Global
198
199
  def variable_scope_guard(self):
@@ -211,61 +212,61 @@ class ASTTransformerContext:
211
212
  self.non_static_control_flow_status,
212
213
  )
213
214
 
214
- def non_static_control_flow_guard(self):
215
+ def non_static_control_flow_guard(self) -> NonStaticControlFlowGuard:
215
216
  return NonStaticControlFlowGuard(self.non_static_control_flow_status)
216
217
 
217
- def static_scope_guard(self):
218
+ def static_scope_guard(self) -> StaticScopeGuard:
218
219
  return StaticScopeGuard(self.static_scope_status)
219
220
 
220
- def current_scope(self):
221
+ def current_scope(self) -> dict[str, Any]:
221
222
  return self.local_scopes[-1]
222
223
 
223
- def current_loop_scope(self):
224
+ def current_loop_scope(self) -> dict[str, Any]:
224
225
  return self.loop_scopes[-1]
225
226
 
226
- def loop_status(self):
227
+ def loop_status(self) -> LoopStatus:
227
228
  if self.loop_scopes:
228
229
  return self.loop_scopes[-1].status
229
230
  return LoopStatus.Normal
230
231
 
231
- def set_loop_status(self, status):
232
+ def set_loop_status(self, status: LoopStatus) -> None:
232
233
  self.loop_scopes[-1].status = status
233
234
 
234
- def is_in_static_for(self):
235
+ def is_in_static_for(self) -> bool:
235
236
  if self.loop_scopes:
236
237
  return self.loop_scopes[-1].is_static
237
238
  return False
238
239
 
239
- def is_in_non_static_control_flow(self):
240
+ def is_in_non_static_control_flow(self) -> bool:
240
241
  return self.non_static_control_flow_status.is_in_non_static_control_flow
241
242
 
242
- def is_in_static_scope(self):
243
+ def is_in_static_scope(self) -> bool:
243
244
  return self.static_scope_status.is_in_static_scope
244
245
 
245
- def is_var_declared(self, name):
246
+ def is_var_declared(self, name: str) -> bool:
246
247
  for s in self.local_scopes:
247
248
  if name in s:
248
249
  return True
249
250
  return False
250
251
 
251
- def create_variable(self, name, var):
252
+ def create_variable(self, name: str, var: Any) -> None:
252
253
  if name in self.current_scope():
253
- raise TaichiSyntaxError("Recreating variables is not allowed")
254
+ raise GsTaichiSyntaxError("Recreating variables is not allowed")
254
255
  self.current_scope()[name] = var
255
256
 
256
- def check_loop_var(self, loop_var):
257
+ def check_loop_var(self, loop_var: str) -> None:
257
258
  if self.is_var_declared(loop_var):
258
- raise TaichiSyntaxError(
259
+ raise GsTaichiSyntaxError(
259
260
  f"Variable '{loop_var}' is already declared in the outer scope and cannot be used as loop variable"
260
261
  )
261
262
 
262
- def get_var_by_name(self, name: str):
263
+ def get_var_by_name(self, name: str) -> Any:
263
264
  for s in reversed(self.local_scopes):
264
265
  if name in s:
265
266
  return s[name]
266
267
  if name in self.global_vars:
267
268
  var = self.global_vars[name]
268
- from taichi.lang.matrix import ( # pylint: disable-msg=C0415
269
+ from gstaichi.lang.matrix import ( # pylint: disable-msg=C0415
269
270
  Matrix,
270
271
  make_matrix,
271
272
  )
@@ -276,19 +277,16 @@ class ASTTransformerContext:
276
277
  try:
277
278
  return getattr(builtins, name)
278
279
  except AttributeError:
279
- raise TaichiNameError(f'Name "{name}" is not defined')
280
+ raise GsTaichiNameError(f'Name "{name}" is not defined')
280
281
 
281
- def get_pos_info(self, node) -> str:
282
+ def get_pos_info(self, node: ast.AST) -> str:
282
283
  msg = f'File "{self.file}", line {node.lineno + self.lineno_offset}, in {self.func.func.__name__}:\n'
283
- if version_info < (3, 8):
284
- msg += self.src[node.lineno - 1] + "\n"
285
- return msg
286
284
  col_offset = self.indent + node.col_offset
287
285
  end_col_offset = self.indent + node.end_col_offset
288
286
 
289
287
  wrapper = TextWrapper(width=80)
290
288
 
291
- def gen_line(code, hint):
289
+ def gen_line(code: str, hint: str) -> str:
292
290
  hint += " " * (len(code) - len(hint))
293
291
  code = wrapper.wrap(code)
294
292
  hint = wrapper.wrap(hint)
@@ -297,8 +295,9 @@ class ASTTransformerContext:
297
295
  return "".join([c + "\n" + h + "\n" for c, h in zip(code, hint)])
298
296
 
299
297
  if node.lineno == node.end_lineno:
300
- hint = " " * col_offset + "^" * (end_col_offset - col_offset)
301
- msg += gen_line(self.src[node.lineno - 1], hint)
298
+ if node.lineno - 1 < len(self.src):
299
+ hint = " " * col_offset + "^" * (end_col_offset - col_offset)
300
+ msg += gen_line(self.src[node.lineno - 1], hint)
302
301
  else:
303
302
  node_type = node.__class__.__name__
304
303
 
@@ -326,3 +325,17 @@ class ASTTransformerContext:
326
325
  hint = ""
327
326
  msg += gen_line(self.src[i], hint)
328
327
  return msg
328
+
329
+
330
def get_decorator(ctx: ASTTransformerContext, node) -> str:
    """Identify which special GsTaichi helper, if any, a call node invokes.

    Args:
        ctx: Transformer context; only ``ctx.global_vars`` is consulted.
        node: Any AST node. Only ``ast.Call`` nodes can match.

    Returns:
        ``"static"``, ``"static_assert"``, ``"grouped"`` or ``"ndrange"`` when
        ``node.func`` resolves (via ``ASTResolver.resolve_to``) to
        ``impl.static``, ``impl.static_assert``, ``impl.grouped`` or
        ``ndrange`` respectively, and ``""`` otherwise. Candidates are tried
        in that order and the first match wins.
    """
    if isinstance(node, ast.Call):
        candidates = (
            ("static", impl.static),
            ("static_assert", impl.static_assert),
            ("grouped", impl.grouped),
            ("ndrange", ndrange),
        )
        for label, target in candidates:
            if ASTResolver.resolve_to(node.func, target, ctx.global_vars):
                return label
    return ""
File without changes
@@ -0,0 +1,267 @@
1
+ # type: ignore
2
+
3
+ import ast
4
+ import dataclasses
5
+ import inspect
6
+ import math
7
+ import operator
8
+ import re
9
+ import warnings
10
+ from ast import unparse
11
+ from collections import ChainMap
12
+
13
+ import numpy as np
14
+
15
+ from gstaichi.lang import (
16
+ expr,
17
+ impl,
18
+ matrix,
19
+ )
20
+ from gstaichi.lang import ops as ti_ops
21
+ from gstaichi.lang.ast.ast_transformer_utils import (
22
+ ASTTransformerContext,
23
+ get_decorator,
24
+ )
25
+ from gstaichi.lang.exception import (
26
+ GsTaichiSyntaxError,
27
+ GsTaichiTypeError,
28
+ )
29
+ from gstaichi.lang.expr import Expr
30
+ from gstaichi.lang.matrix import Matrix, Vector
31
+ from gstaichi.lang.util import is_gstaichi_class
32
+ from gstaichi.types import primitive_types
33
+
34
+
35
class CallTransformer:
    """Builds ``ast.Call`` nodes for the GsTaichi AST transformer.

    All methods are static; the class only groups them. Following the
    convention used by the rest of the transformer, the value produced for a
    node is stored on ``node.ptr``.
    """

    @staticmethod
    def build_call_if_is_builtin(ctx: ASTTransformerContext, node, args, keywords):
        """Lower a call to a supported Python builtin.

        ``len(x)`` on a tensor-valued ``Expr`` evaluates to the first entry of
        ``x.get_shape()``. The builtins listed in ``replace_func`` are
        redirected to their GsTaichi implementations.

        Returns:
            True if ``node.ptr`` was set, False if the callee is not handled here.
        """
        from gstaichi.lang import matrix_ops  # pylint: disable=C0415

        func = node.func.ptr
        replace_func = {
            id(print): impl.ti_print,
            id(min): ti_ops.min,
            id(max): ti_ops.max,
            id(int): impl.ti_int,
            id(bool): impl.ti_bool,
            id(float): impl.ti_float,
            id(any): matrix_ops.any,
            id(all): matrix_ops.all,
            id(abs): abs,
            id(pow): pow,
            id(operator.matmul): matrix_ops.matmul,
        }

        # Builtin 'len' function on Matrix Expr
        if id(func) == id(len) and len(args) == 1:
            if isinstance(args[0], Expr) and args[0].ptr.is_tensor():
                node.ptr = args[0].get_shape()[0]
                return True

        if id(func) in replace_func:
            node.ptr = replace_func[id(func)](*args, **keywords)
            return True
        return False

    @staticmethod
    def build_call_if_is_type(ctx: ASTTransformerContext, node, args, keywords):
        """Lower a primitive type called like a function on one expression.

        An ``Expr`` argument is cast with ``ti_ops.cast``. Any other value is
        wrapped in an ``Expr`` of that dtype.

        Returns:
            True if ``node.func`` is a primitive type (``node.ptr`` is set),
            otherwise False.

        Raises:
            GsTaichiSyntaxError: If there is not exactly one positional
                argument, if keywords are given, or if the argument has a
                compound (class or tensor) type.
        """
        func = node.func.ptr
        if id(func) in primitive_types.type_ids:
            if len(args) != 1 or keywords:
                raise GsTaichiSyntaxError("A primitive type can only decorate a single expression.")
            if is_gstaichi_class(args[0]):
                raise GsTaichiSyntaxError("A primitive type cannot decorate an expression with a compound type.")

            if isinstance(args[0], expr.Expr):
                if args[0].ptr.is_tensor():
                    raise GsTaichiSyntaxError("A primitive type cannot decorate an expression with a compound type.")
                node.ptr = ti_ops.cast(args[0], func)
            else:
                node.ptr = expr.Expr(args[0], dtype=func)
            return True
        return False

    @staticmethod
    def is_external_func(ctx: ASTTransformerContext, func) -> bool:
        """Return True if ``func`` is a callable the transformer will not process.

        Returns False in these cases:
        - inside a static scope;
        - for objects marked ``_is_gstaichi_function`` or ``_is_wrapped_kernel``;
        - for objects whose ``__module__`` starts with ``"gstaichi."``.

        Returns True otherwise.
        """
        if ctx.is_in_static_scope():  # allow external function in static scope
            return False
        if hasattr(func, "_is_gstaichi_function") or hasattr(func, "_is_wrapped_kernel"):  # gstaichi func/kernel
            return False
        if hasattr(func, "__module__") and func.__module__ and func.__module__.startswith("gstaichi."):
            return False
        return True

    @staticmethod
    def warn_if_is_external_func(ctx: ASTTransformerContext, node):
        """Emit a ``SyntaxWarning`` when ``node`` calls an external function.

        The warning points at ``ctx.file`` and the node's line number, shifted
        by ``ctx.lineno_offset``.
        """
        func = node.func.ptr
        if not CallTransformer.is_external_func(ctx, func):
            return
        name = unparse(node.func).strip()
        warnings.warn_explicit(
            f"\x1b[38;5;226m"  # Yellow
            f'Calling non-gstaichi function "{name}". '
            f"Scope inside the function is not processed by the GsTaichi AST transformer. "
            f"The function may not work as expected. Proceed with caution! "
            f"Maybe you can consider turning it into a @ti.func?"
            f"\x1b[0m",  # Reset
            SyntaxWarning,
            ctx.file,
            node.lineno + ctx.lineno_offset,
            module="gstaichi",
        )

    @staticmethod
    def canonicalize_formatted_string(raw_string: str, *raw_args: list, **raw_keywords: dict):
        """Canonicalize a ``str.format``-style call for ``ti.format``.

        Every replacement field is rewritten to an empty ``{}``, and the
        referenced arguments are emitted in field order. Arguments may be
        referenced more than once. A field with a format spec becomes
        ``["__ti_fmt_value__", value, spec]``.

        Example::

            origin input: 'qwerty {1} {} {1:.3f} {k:.4f} {k:}'.format(1.0, 2.0, k=k)
            return value: ['qwerty {} {} {} {} {}', 2.0, 1.0,
                           ['__ti_fmt_value__', 2.0, '.3f'],
                           ['__ti_fmt_value__', <ti.Expr>, '.4f'], <ti.Expr>]

        Args:
            raw_string: The format string.
            *raw_args: Positional arguments passed to ``format``.
            **raw_keywords: Keyword arguments passed to ``format``.

        Returns:
            A list whose first element is the canonicalized string, followed
            by one entry per replacement field.

        Raises:
            GsTaichiSyntaxError: If the number of positional arguments does not
                match the highest referenced index, or if a keyword is missing
                or unused.
        """
        brackets = []
        unnamed = 0
        for bracket in re.findall(r"{(.*?)}", raw_string):
            # Split on the first ':' only, as str.format does. A spec that
            # itself contains ':' (e.g. '%H:%M') must not break the unpacking.
            item, _, spec = bracket.partition(":")
            # handle empty spec (both "{k}" and "{k:}")
            spec = spec or None
            if item.isdigit():
                item = int(item)
            elif item == "":
                # handle unnamed positional args
                item = unnamed
                unnamed += 1
            brackets.append((item, spec))

        # check for errors in the arguments
        max_args_index = max((t[0] for t in brackets if isinstance(t[0], int)), default=-1)
        if max_args_index + 1 != len(raw_args):
            raise GsTaichiSyntaxError(
                f"Expected {max_args_index + 1} positional argument(s), but received {len(raw_args)} instead."
            )
        brackets_keywords = [t[0] for t in brackets if isinstance(t[0], str)]
        for item in brackets_keywords:
            if item not in raw_keywords:
                raise GsTaichiSyntaxError(f"Keyword '{item}' not found.")
        for item in raw_keywords:
            if item not in brackets_keywords:
                raise GsTaichiSyntaxError(f"Keyword '{item}' not used.")

        # reorganize the arguments based on their positions, keywords, and format specifiers
        args = []
        for item, spec in brackets:
            new_arg = raw_args[item] if isinstance(item, int) else raw_keywords[item]
            if spec is not None:
                args.append(["__ti_fmt_value__", new_arg, spec])
            else:
                args.append(new_arg)
        # put the formatted string as the first argument to make ti.format() happy
        args.insert(0, re.sub(r"{.*?}", "{}", raw_string))
        return args

    @staticmethod
    def expand_node_args_dataclasses(args: tuple[ast.AST, ...]) -> tuple[ast.AST, ...]:
        """Replace dataclass-valued arguments with one ``ast.Name`` per field.

        Each argument whose ``.ptr`` is a dataclass becomes a sequence of
        loads of ``__ti_{arg.id}_{field.name}``, in ``dataclasses.fields``
        order. The new names copy the argument's source location. Other
        arguments pass through unchanged.

        NOTE(review): uses ``arg.id``, so it assumes dataclass-valued
        arguments are plain ``ast.Name`` nodes -- confirm with callers.
        """
        args_new = []
        for arg in args:
            val = arg.ptr
            if not dataclasses.is_dataclass(val):
                args_new.append(arg)
                continue
            for field in dataclasses.fields(val):
                args_new.append(
                    ast.Name(
                        id=f"__ti_{arg.id}_{field.name}",
                        ctx=ast.Load(),
                        lineno=arg.lineno,
                        end_lineno=arg.end_lineno,
                        col_offset=arg.col_offset,
                        end_col_offset=arg.end_col_offset,
                    )
                )
        return tuple(args_new)

    @staticmethod
    def build_Call(ctx: ASTTransformerContext, node: ast.Call, build_stmt, build_stmts):
        """Build a call expression and store its value on ``node.ptr``.

        The callee and its arguments are built first. For calls that
        ``get_decorator`` reports as ``"static"`` or ``"static_assert"``, this
        happens inside a static scope. Otherwise, dataclass arguments are then
        expanded into one name per field.

        Handling is tried in this order:
        1. ``.format`` on a constant string;
        2. ``Matrix``/``Vector`` construction;
        3. builtins;
        4. primitive-type casts;
        5. callees whose AST node carries a ``caller`` (passed as the first
           argument);
        6. a plain call.

        Args:
            ctx: Transformer context.
            node: The call node being built.
            build_stmt: Callback that builds a single AST node.
            build_stmts: Callback that builds a sequence of AST nodes.

        Returns:
            The value stored on ``node.ptr``.

        Raises:
            GsTaichiTypeError: If the plain call raises ``TypeError``. For
                external ``math``/``numpy`` functions applied to GsTaichi
                expressions, the message suggests the ``ti.`` equivalent when
                gstaichi exports one.
        """
        if get_decorator(ctx, node) in ["static", "static_assert"]:
            with ctx.static_scope_guard():
                build_stmt(ctx, node.func)
                build_stmts(ctx, node.args)
                build_stmts(ctx, node.keywords)
        else:
            build_stmt(ctx, node.func)
            # creates variable for the dataclass itself (as well as other variables,
            # not related to dataclasses). Necessary for calling further child functions
            build_stmts(ctx, node.args)
            node.args = CallTransformer.expand_node_args_dataclasses(node.args)
            # create variables for the now-expanded dataclass members
            build_stmts(ctx, node.args)
            build_stmts(ctx, node.keywords)

        args = []
        for arg in node.args:
            if isinstance(arg, ast.Starred):
                arg_list = arg.ptr
                if isinstance(arg_list, Expr) and arg_list.is_tensor():
                    # Expand Expr with Matrix-type return into list of Exprs
                    arg_list = [Expr(x) for x in ctx.ast_builder.expand_exprs([arg_list.ptr])]
                args.extend(arg_list)
            else:
                args.append(arg.ptr)
        keywords = dict(ChainMap(*[keyword.ptr for keyword in node.keywords]))
        func = node.func.ptr

        if id(func) in [id(print), id(impl.ti_print)]:
            ctx.func.has_print = True

        if isinstance(node.func, ast.Attribute) and isinstance(node.func.value.ptr, str) and node.func.attr == "format":
            raw_string = node.func.value.ptr
            args = CallTransformer.canonicalize_formatted_string(raw_string, *args, **keywords)
            node.ptr = impl.ti_format(*args)
            return node.ptr

        if id(func) == id(Matrix) or id(func) == id(Vector):
            node.ptr = matrix.make_matrix(*args, **keywords)
            return node.ptr

        if CallTransformer.build_call_if_is_builtin(ctx, node, args, keywords):
            return node.ptr

        if CallTransformer.build_call_if_is_type(ctx, node, args, keywords):
            return node.ptr

        if hasattr(node.func, "caller"):
            node.ptr = func(node.func.caller, *args, **keywords)
            return node.ptr

        CallTransformer.warn_if_is_external_func(ctx, node)
        try:
            node.ptr = func(*args, **keywords)
        except TypeError as e:
            module = inspect.getmodule(func)
            error_msg = re.sub(r"\bExpr\b", "GsTaichi Expression", str(e))
            func_name = getattr(func, "__name__", func.__class__.__name__)
            msg = f"TypeError when calling `{func_name}`: {error_msg}."
            if CallTransformer.is_external_func(ctx, func):
                args_has_expr = any(isinstance(a, Expr) for a in args)
                if args_has_expr and (module == math or module == np):
                    # Suggest the `ti.` variant only if gstaichi actually exports
                    # that name. An attribute lookup replaces the previous
                    # exec()-based probe, which sat inside a bare except.
                    import gstaichi  # pylint: disable=C0415

                    if hasattr(gstaichi, func_name):
                        msg += f"\nDid you mean to use `ti.{func_name}` instead of `{module.__name__}.{func_name}`?"
            raise GsTaichiTypeError(msg)

        if getattr(func, "_is_gstaichi_function", False):
            ctx.func.has_print |= func.wrapper.has_print

        return node.ptr