gstaichi 0.1.23.dev0__cp310-cp310-macosx_15_0_arm64.whl → 1.0.1__cp310-cp310-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (240)
  1. gstaichi/CHANGELOG.md +6 -0
  2. gstaichi/__init__.py +40 -0
  3. {taichi → gstaichi}/_funcs.py +8 -8
  4. {taichi → gstaichi}/_kernels.py +19 -19
  5. gstaichi/_lib/__init__.py +3 -0
  6. taichi/_lib/core/taichi_python.cpython-310-darwin.so → gstaichi/_lib/core/gstaichi_python.cpython-310-darwin.so +0 -0
  7. taichi/_lib/core/taichi_python.pyi → gstaichi/_lib/core/gstaichi_python.pyi +382 -520
  8. {taichi → gstaichi}/_lib/runtime/runtime_arm64.bc +0 -0
  9. {taichi → gstaichi}/_lib/utils.py +15 -15
  10. {taichi → gstaichi}/_logging.py +1 -1
  11. gstaichi/_snode/__init__.py +5 -0
  12. {taichi → gstaichi}/_snode/fields_builder.py +27 -29
  13. {taichi → gstaichi}/_snode/snode_tree.py +5 -5
  14. gstaichi/_test_tools/__init__.py +0 -0
  15. gstaichi/_test_tools/load_kernel_string.py +30 -0
  16. gstaichi/_version.py +1 -0
  17. {taichi → gstaichi}/_version_check.py +8 -5
  18. gstaichi/ad/__init__.py +3 -0
  19. {taichi → gstaichi}/ad/_ad.py +26 -26
  20. {taichi → gstaichi}/algorithms/_algorithms.py +7 -7
  21. {taichi → gstaichi}/examples/minimal.py +1 -1
  22. {taichi → gstaichi}/experimental.py +1 -1
  23. gstaichi/lang/__init__.py +50 -0
  24. {taichi → gstaichi}/lang/_ndarray.py +30 -26
  25. {taichi → gstaichi}/lang/_ndrange.py +8 -8
  26. gstaichi/lang/_template_mapper.py +199 -0
  27. {taichi → gstaichi}/lang/_texture.py +19 -19
  28. {taichi → gstaichi}/lang/_wrap_inspect.py +7 -7
  29. {taichi → gstaichi}/lang/any_array.py +13 -13
  30. {taichi → gstaichi}/lang/argpack.py +29 -29
  31. gstaichi/lang/ast/__init__.py +5 -0
  32. {taichi → gstaichi}/lang/ast/ast_transformer.py +94 -582
  33. {taichi → gstaichi}/lang/ast/ast_transformer_utils.py +54 -41
  34. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  35. gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
  36. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
  37. {taichi → gstaichi}/lang/ast/checkers.py +5 -5
  38. gstaichi/lang/ast/transform.py +9 -0
  39. {taichi → gstaichi}/lang/common_ops.py +12 -12
  40. gstaichi/lang/exception.py +80 -0
  41. {taichi → gstaichi}/lang/expr.py +22 -22
  42. {taichi → gstaichi}/lang/field.py +29 -27
  43. {taichi → gstaichi}/lang/impl.py +116 -121
  44. {taichi → gstaichi}/lang/kernel_arguments.py +16 -16
  45. {taichi → gstaichi}/lang/kernel_impl.py +330 -363
  46. {taichi → gstaichi}/lang/matrix.py +119 -115
  47. {taichi → gstaichi}/lang/matrix_ops.py +6 -6
  48. {taichi → gstaichi}/lang/matrix_ops_utils.py +4 -4
  49. {taichi → gstaichi}/lang/mesh.py +22 -22
  50. {taichi → gstaichi}/lang/misc.py +39 -68
  51. {taichi → gstaichi}/lang/ops.py +146 -141
  52. {taichi → gstaichi}/lang/runtime_ops.py +2 -2
  53. {taichi → gstaichi}/lang/shell.py +3 -3
  54. {taichi → gstaichi}/lang/simt/__init__.py +1 -1
  55. {taichi → gstaichi}/lang/simt/block.py +7 -7
  56. {taichi → gstaichi}/lang/simt/grid.py +1 -1
  57. {taichi → gstaichi}/lang/simt/subgroup.py +1 -1
  58. {taichi → gstaichi}/lang/simt/warp.py +1 -1
  59. {taichi → gstaichi}/lang/snode.py +46 -44
  60. {taichi → gstaichi}/lang/source_builder.py +13 -13
  61. {taichi → gstaichi}/lang/struct.py +33 -33
  62. {taichi → gstaichi}/lang/util.py +24 -24
  63. gstaichi/linalg/__init__.py +8 -0
  64. {taichi → gstaichi}/linalg/matrixfree_cg.py +14 -14
  65. {taichi → gstaichi}/linalg/sparse_cg.py +10 -10
  66. {taichi → gstaichi}/linalg/sparse_matrix.py +23 -23
  67. {taichi → gstaichi}/linalg/sparse_solver.py +21 -21
  68. {taichi → gstaichi}/math/__init__.py +1 -1
  69. {taichi → gstaichi}/math/_complex.py +21 -20
  70. {taichi → gstaichi}/math/mathimpl.py +56 -56
  71. gstaichi/profiler/__init__.py +6 -0
  72. {taichi → gstaichi}/profiler/kernel_metrics.py +11 -11
  73. {taichi → gstaichi}/profiler/kernel_profiler.py +30 -36
  74. {taichi → gstaichi}/profiler/memory_profiler.py +1 -1
  75. {taichi → gstaichi}/profiler/scoped_profiler.py +2 -2
  76. {taichi → gstaichi}/sparse/_sparse_grid.py +7 -7
  77. {taichi → gstaichi}/tools/__init__.py +4 -4
  78. {taichi → gstaichi}/tools/diagnose.py +10 -17
  79. gstaichi/types/__init__.py +19 -0
  80. {taichi → gstaichi}/types/annotations.py +1 -1
  81. {taichi → gstaichi}/types/compound_types.py +8 -8
  82. {taichi → gstaichi}/types/enums.py +1 -1
  83. {taichi → gstaichi}/types/ndarray_type.py +7 -7
  84. {taichi → gstaichi}/types/primitive_types.py +17 -14
  85. {taichi → gstaichi}/types/quant.py +9 -9
  86. {taichi → gstaichi}/types/texture_type.py +5 -5
  87. {taichi → gstaichi}/types/utils.py +1 -1
  88. {gstaichi-0.1.23.dev0.dist-info → gstaichi-1.0.1.dist-info}/METADATA +13 -16
  89. gstaichi-1.0.1.dist-info/RECORD +166 -0
  90. gstaichi-1.0.1.dist-info/top_level.txt +1 -0
  91. gstaichi-0.1.23.dev0.dist-info/RECORD +0 -219
  92. gstaichi-0.1.23.dev0.dist-info/entry_points.txt +0 -2
  93. gstaichi-0.1.23.dev0.dist-info/top_level.txt +0 -1
  94. taichi/__init__.py +0 -44
  95. taichi/__main__.py +0 -5
  96. taichi/_lib/__init__.py +0 -3
  97. taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +0 -1401
  98. taichi/_lib/c_api/include/taichi/taichi.h +0 -29
  99. taichi/_lib/c_api/include/taichi/taichi_core.h +0 -1111
  100. taichi/_lib/c_api/include/taichi/taichi_cpu.h +0 -29
  101. taichi/_lib/c_api/include/taichi/taichi_metal.h +0 -72
  102. taichi/_lib/c_api/include/taichi/taichi_platform.h +0 -55
  103. taichi/_lib/c_api/include/taichi/taichi_unity.h +0 -64
  104. taichi/_lib/c_api/include/taichi/taichi_vulkan.h +0 -151
  105. taichi/_lib/c_api/lib/libtaichi_c_api.dylib +0 -0
  106. taichi/_lib/c_api/runtime/runtime_arm64.bc +0 -0
  107. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +0 -29
  108. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +0 -65
  109. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +0 -121
  110. taichi/_lib/runtime/libMoltenVK.dylib +0 -0
  111. taichi/_main.py +0 -552
  112. taichi/_snode/__init__.py +0 -5
  113. taichi/_ti_module/__init__.py +0 -3
  114. taichi/_ti_module/cppgen.py +0 -309
  115. taichi/_ti_module/module.py +0 -145
  116. taichi/_version.py +0 -1
  117. taichi/ad/__init__.py +0 -3
  118. taichi/aot/__init__.py +0 -12
  119. taichi/aot/_export.py +0 -28
  120. taichi/aot/conventions/__init__.py +0 -3
  121. taichi/aot/conventions/gfxruntime140/__init__.py +0 -38
  122. taichi/aot/conventions/gfxruntime140/dr.py +0 -244
  123. taichi/aot/conventions/gfxruntime140/sr.py +0 -613
  124. taichi/aot/module.py +0 -253
  125. taichi/aot/utils.py +0 -151
  126. taichi/graph/__init__.py +0 -3
  127. taichi/graph/_graph.py +0 -292
  128. taichi/lang/__init__.py +0 -50
  129. taichi/lang/ast/__init__.py +0 -5
  130. taichi/lang/ast/transform.py +0 -9
  131. taichi/lang/exception.py +0 -80
  132. taichi/linalg/__init__.py +0 -8
  133. taichi/profiler/__init__.py +0 -6
  134. taichi/shaders/Circles_vk.frag +0 -29
  135. taichi/shaders/Circles_vk.vert +0 -45
  136. taichi/shaders/Circles_vk_frag.spv +0 -0
  137. taichi/shaders/Circles_vk_vert.spv +0 -0
  138. taichi/shaders/Lines_vk.frag +0 -9
  139. taichi/shaders/Lines_vk.vert +0 -11
  140. taichi/shaders/Lines_vk_frag.spv +0 -0
  141. taichi/shaders/Lines_vk_vert.spv +0 -0
  142. taichi/shaders/Mesh_vk.frag +0 -71
  143. taichi/shaders/Mesh_vk.vert +0 -68
  144. taichi/shaders/Mesh_vk_frag.spv +0 -0
  145. taichi/shaders/Mesh_vk_vert.spv +0 -0
  146. taichi/shaders/Particles_vk.frag +0 -95
  147. taichi/shaders/Particles_vk.vert +0 -73
  148. taichi/shaders/Particles_vk_frag.spv +0 -0
  149. taichi/shaders/Particles_vk_vert.spv +0 -0
  150. taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
  151. taichi/shaders/SceneLines_vk.frag +0 -9
  152. taichi/shaders/SceneLines_vk.vert +0 -12
  153. taichi/shaders/SceneLines_vk_frag.spv +0 -0
  154. taichi/shaders/SceneLines_vk_vert.spv +0 -0
  155. taichi/shaders/SetImage_vk.frag +0 -21
  156. taichi/shaders/SetImage_vk.vert +0 -15
  157. taichi/shaders/SetImage_vk_frag.spv +0 -0
  158. taichi/shaders/SetImage_vk_vert.spv +0 -0
  159. taichi/shaders/Triangles_vk.frag +0 -16
  160. taichi/shaders/Triangles_vk.vert +0 -29
  161. taichi/shaders/Triangles_vk_frag.spv +0 -0
  162. taichi/shaders/Triangles_vk_vert.spv +0 -0
  163. taichi/shaders/lines2quad_vk_comp.spv +0 -0
  164. taichi/types/__init__.py +0 -19
  165. {taichi → gstaichi}/_lib/core/__init__.py +0 -0
  166. {taichi → gstaichi}/_lib/core/py.typed +0 -0
  167. {taichi/_lib/c_api → gstaichi/_lib}/runtime/libMoltenVK.dylib +0 -0
  168. {taichi → gstaichi}/algorithms/__init__.py +0 -0
  169. {taichi → gstaichi}/assets/.git +0 -0
  170. {taichi → gstaichi}/assets/Go-Regular.ttf +0 -0
  171. {taichi → gstaichi}/assets/static/imgs/ti_gallery.png +0 -0
  172. {taichi → gstaichi}/lang/ast/symbol_resolver.py +0 -0
  173. {taichi → gstaichi}/sparse/__init__.py +0 -0
  174. {taichi → gstaichi}/tools/np2ply.py +0 -0
  175. {taichi → gstaichi}/tools/vtk.py +0 -0
  176. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/GLFW/glfw3.h +0 -0
  177. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/GLFW/glfw3native.h +0 -0
  178. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/instrument.hpp +0 -0
  179. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/libspirv.h +0 -0
  180. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/libspirv.hpp +0 -0
  181. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/linker.hpp +0 -0
  182. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv-tools/optimizer.hpp +0 -0
  183. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/GLSL.std.450.h +0 -0
  184. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv.h +0 -0
  185. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv.hpp +0 -0
  186. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_cfg.hpp +0 -0
  187. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_common.hpp +0 -0
  188. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_cpp.hpp +0 -0
  189. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_cross.hpp +0 -0
  190. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_cross_c.h +0 -0
  191. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_cross_containers.hpp +0 -0
  192. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_cross_error_handling.hpp +0 -0
  193. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +0 -0
  194. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_cross_util.hpp +0 -0
  195. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_glsl.hpp +0 -0
  196. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_hlsl.hpp +0 -0
  197. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_msl.hpp +0 -0
  198. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_parser.hpp +0 -0
  199. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/include/spirv_cross/spirv_reflect.hpp +0 -0
  200. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +0 -0
  201. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +0 -0
  202. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +0 -0
  203. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +0 -0
  204. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +0 -0
  205. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +0 -0
  206. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +0 -0
  207. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +0 -0
  208. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +0 -0
  209. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +0 -0
  210. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +0 -0
  211. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +0 -0
  212. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +0 -0
  213. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +0 -0
  214. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +0 -0
  215. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +0 -0
  216. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +0 -0
  217. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +0 -0
  218. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/glfw3/glfw3Config.cmake +0 -0
  219. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +0 -0
  220. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/glfw3/glfw3Targets-release.cmake +0 -0
  221. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/cmake/glfw3/glfw3Targets.cmake +0 -0
  222. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  223. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +0 -0
  224. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +0 -0
  225. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +0 -0
  226. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +0 -0
  227. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +0 -0
  228. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +0 -0
  229. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +0 -0
  230. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +0 -0
  231. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +0 -0
  232. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +0 -0
  233. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +0 -0
  234. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +0 -0
  235. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +0 -0
  236. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +0 -0
  237. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +0 -0
  238. {gstaichi-0.1.23.dev0.data → gstaichi-1.0.1.data}/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +0 -0
  239. {gstaichi-0.1.23.dev0.dist-info → gstaichi-1.0.1.dist-info}/WHEEL +0 -0
  240. {gstaichi-0.1.23.dev0.dist-info → gstaichi-1.0.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,320 @@
1
+ # type: ignore
2
+
3
+ import ast
4
+ import dataclasses
5
+ from typing import Any, Callable
6
+
7
+ from gstaichi.lang import (
8
+ _ndarray,
9
+ any_array,
10
+ expr,
11
+ impl,
12
+ kernel_arguments,
13
+ matrix,
14
+ )
15
+ from gstaichi.lang import ops as ti_ops
16
+ from gstaichi.lang.argpack import ArgPackType
17
+ from gstaichi.lang.ast.ast_transformer_utils import (
18
+ ASTTransformerContext,
19
+ )
20
+ from gstaichi.lang.exception import (
21
+ GsTaichiSyntaxError,
22
+ )
23
+ from gstaichi.lang.matrix import MatrixType
24
+ from gstaichi.lang.struct import StructType
25
+ from gstaichi.lang.util import to_gstaichi_type
26
+ from gstaichi.types import annotations, ndarray_type, primitive_types, texture_type
27
+
28
+
29
class FunctionDefTransformer:
    """Lower the top-level ``FunctionDef`` of a ``ti.kernel`` / ``ti.func``.

    Parameters are declared kernel-style (for kernels and real functions) or
    bound as local variables (for inlined funcs), after which the function body
    is built inside a fresh variable scope.
    """

    @staticmethod
    def _decl_and_create_variable(
        ctx: ASTTransformerContext, annotation, name, arg_features, invoke_later_dict, prefix_name, arg_depth
    ) -> tuple[bool, Any]:
        """Declare one kernel parameter (or argpack member) described by ``annotation``.

        Args:
            ctx: The transformer context being populated.
            annotation: The parameter's type annotation.
            name: The parameter / member name.
            arg_features: Per-call type features for this parameter (indexed per
                member for argpacks, per field for ndarrays / textures).
            invoke_later_dict: Collects deferred declarations of nested argpack
                members as ``(argpack, member_name, decl_func, decl_args)``.
            prefix_name: Prefix used to build the flattened name of the parameter.
            arg_depth: Nesting depth passed to the ``decl_*_arg`` helpers.

        Returns:
            ``(True, obj)`` when the parameter object was created immediately, or
            ``(False, (decl_func, decl_args))`` when declaration is deferred and the
            caller must call ``decl_func(*decl_args)`` itself (sparse matrix
            builders, ndarrays and textures).
        """
        full_name = prefix_name + "_" + name
        # Reference-typed parameters are not recorded in ctx.kernel_args.
        if not isinstance(annotation, primitive_types.RefType):
            ctx.kernel_args.append(name)
        if isinstance(annotation, ArgPackType):
            kernel_arguments.push_argpack_arg(name)
            d = {}
            items_to_put_in_dict = []
            for j, (_name, anno) in enumerate(annotation.members.items()):
                result, obj = FunctionDefTransformer._decl_and_create_variable(
                    ctx, anno, _name, arg_features[j], invoke_later_dict, full_name, arg_depth + 1
                )
                if not result:
                    # Placeholder; the member is assigned later from invoke_later_dict.
                    d[_name] = None
                    items_to_put_in_dict.append((full_name + "_" + _name, _name, obj))
                else:
                    d[_name] = obj
            argpack = kernel_arguments.decl_argpack_arg(annotation, d)
            for key, member_name, deferred in items_to_put_in_dict:
                invoke_later_dict[key] = argpack, member_name, *deferred
            return True, argpack
        if annotation == annotations.template or isinstance(annotation, annotations.template):
            return True, ctx.global_vars[name]
        if isinstance(annotation, annotations.sparse_matrix_builder):
            return False, (
                kernel_arguments.decl_sparse_matrix,
                (
                    to_gstaichi_type(arg_features),
                    full_name,
                ),
            )
        if isinstance(annotation, ndarray_type.NdarrayType):
            return False, (
                kernel_arguments.decl_ndarray_arg,
                (
                    to_gstaichi_type(arg_features[0]),
                    arg_features[1],
                    full_name,
                    arg_features[2],
                    arg_features[3],
                ),
            )
        if isinstance(annotation, texture_type.TextureType):
            return False, (kernel_arguments.decl_texture_arg, (arg_features[0], full_name))
        if isinstance(annotation, texture_type.RWTextureType):
            return False, (
                kernel_arguments.decl_rw_texture_arg,
                (arg_features[0], arg_features[1], arg_features[2], full_name),
            )
        if isinstance(annotation, MatrixType):
            return True, kernel_arguments.decl_matrix_arg(annotation, name, arg_depth)
        if isinstance(annotation, StructType):
            return True, kernel_arguments.decl_struct_arg(annotation, name, arg_depth)
        return True, kernel_arguments.decl_scalar_arg(annotation, name, arg_depth)

    @staticmethod
    def _transform_kernel_arg(
        ctx: ASTTransformerContext,
        invoke_later_dict: dict[str, tuple[Any, str, Callable, Any]],
        create_variable_later: dict[str, Any],
        argument_name: str,
        argument_type: Any,
        this_arg_features: tuple[Any, ...],
    ) -> None:
        """Declare a single top-level kernel parameter and bind it in ``ctx``.

        - ``ArgPackType``: the argpack is declared now, but binding its variable is
          deferred through ``create_variable_later``; deferred members go into
          ``invoke_later_dict``.
        - dataclass: the dataclass type itself is bound under ``argument_name`` and
          each field is declared as a flattened variable
          ``__ti_{argument_name}_{field}``.
        - anything else: declared via ``_decl_and_create_variable`` and bound
          immediately.
        """
        if isinstance(argument_type, ArgPackType):
            kernel_arguments.push_argpack_arg(argument_name)
            d = {}
            items_to_put_in_dict: list[tuple[str, str, Any]] = []
            for j, (name, anno) in enumerate(argument_type.members.items()):
                result, obj = FunctionDefTransformer._decl_and_create_variable(
                    ctx, anno, name, this_arg_features[j], invoke_later_dict, "__argpack_" + name, 1
                )
                if not result:
                    # Placeholder; the member is assigned later from invoke_later_dict.
                    d[name] = None
                    items_to_put_in_dict.append(("__argpack_" + name, name, obj))
                else:
                    d[name] = obj
            argpack = kernel_arguments.decl_argpack_arg(argument_type, d)
            for key, member_name, deferred in items_to_put_in_dict:
                invoke_later_dict[key] = argpack, member_name, *deferred
            create_variable_later[argument_name] = argpack
        elif dataclasses.is_dataclass(argument_type):
            arg_features = this_arg_features
            ctx.create_variable(argument_name, argument_type)
            for field_idx, field in enumerate(dataclasses.fields(argument_type)):
                flat_name = f"__ti_{argument_name}_{field.name}"
                result, obj = FunctionDefTransformer._decl_and_create_variable(
                    ctx,
                    field.type,
                    flat_name,
                    arg_features[field_idx],
                    invoke_later_dict,
                    "",
                    0,
                )
                if not result:
                    decl_type_func, type_args = obj
                    obj = decl_type_func(*type_args)
                ctx.create_variable(flat_name, obj)
        else:
            result, obj = FunctionDefTransformer._decl_and_create_variable(
                ctx,
                argument_type,
                argument_name,
                this_arg_features if ctx.arg_features is not None else None,
                invoke_later_dict,
                "",
                0,
            )
            if not result:
                decl_type_func, type_args = obj
                obj = decl_type_func(*type_args)
            ctx.create_variable(argument_name, obj)

    @staticmethod
    def _transform_as_kernel(ctx: ASTTransformerContext, node: ast.FunctionDef, args: ast.arguments) -> None:
        """Declare return values and parameters for a kernel (or real function).

        Return types are declared unless the return annotation is an
        ``ast.Constant`` (e.g. ``-> None``). After all parameters are declared, the
        original Python parameters are removed from ``node``.
        """
        if node.returns is not None:
            if not isinstance(node.returns, ast.Constant):
                for return_type in ctx.func.return_type:
                    kernel_arguments.decl_ret(return_type)
        impl.get_runtime().compiling_callable.finalize_rets()

        invoke_later_dict: dict[str, tuple[Any, str, Callable, Any]] = dict()
        create_variable_later = dict()
        for i, arg in enumerate(args.args):
            argument = ctx.func.arguments[i]
            FunctionDefTransformer._transform_kernel_arg(
                ctx,
                invoke_later_dict,
                create_variable_later,
                argument.name,
                argument.annotation,
                ctx.arg_features[i] if ctx.arg_features is not None else (),
            )

        # Resolve deferred argpack members now that their argpacks exist.
        for argpack, name, func, params in invoke_later_dict.values():
            argpack[name] = func(*params)
        for var_name, value in create_variable_later.items():
            ctx.create_variable(var_name, value)

        impl.get_runtime().compiling_callable.finalize_params()
        # remove original args
        node.args.args = []

    @staticmethod
    def _transform_func_arg(
        ctx: ASTTransformerContext,
        argument_name: str,
        argument_type: Any,
        data: Any,
    ) -> None:
        """Bind one argument of an inlined ``ti.func`` call as a local variable.

        Templates and ndarrays are bound by reference; dataclass fields are bound as
        flattened ``__ti_{argument_name}_{field}`` variables; matrices, primitives
        and everything else are copied with ``impl.expr_init_func``.

        Raises:
            GsTaichiSyntaxError: if ``data`` does not match ``argument_type``.
        """
        if isinstance(argument_type, annotations.template):
            ctx.create_variable(argument_name, data)
            return None

        if dataclasses.is_dataclass(argument_type):
            dataclass_type = argument_type
            for field in dataclasses.fields(dataclass_type):
                data_child = getattr(data, field.name)
                if not isinstance(
                    data_child,
                    (
                        _ndarray.ScalarNdarray,
                        matrix.VectorNdarray,
                        matrix.MatrixNdarray,
                        any_array.AnyArray,
                    ),
                ):
                    raise GsTaichiSyntaxError(
                        f"Argument {argument_name} of type {dataclass_type} {field.type} is not recognized."
                    )
                field.type.check_matched(data_child.get_type(), field.name)
                var_name = f"__ti_{argument_name}_{field.name}"
                ctx.create_variable(var_name, data_child)
            return None

        # Ndarray arguments are passed by reference.
        if isinstance(argument_type, (ndarray_type.NdarrayType)):
            if not isinstance(
                data,
                (
                    _ndarray.ScalarNdarray,
                    matrix.VectorNdarray,
                    matrix.MatrixNdarray,
                    any_array.AnyArray,
                ),
            ):
                # Previously referenced an undefined name `arg`, which raised NameError.
                raise GsTaichiSyntaxError(f"Argument {argument_name} of type {argument_type} is not recognized.")
            argument_type.check_matched(data.get_type(), argument_name)
            ctx.create_variable(argument_name, data)
            return None

        # Matrix arguments are passed by value.
        if isinstance(argument_type, (MatrixType)):
            var_name = argument_name
            # "data" is expected to be an Expr here,
            # so we simply call "impl.expr_init_func(data)" to perform:
            #
            # TensorType* t = alloca()
            # assign(t, data)
            #
            # We created local variable "t" - a copy of the passed-in argument "data"
            if not isinstance(data, expr.Expr) or not data.ptr.is_tensor():
                raise GsTaichiSyntaxError(
                    f"Argument {var_name} of type {argument_type} is expected to be a Matrix, but got {type(data)}."
                )

            element_shape = data.ptr.get_rvalue_type().shape()
            if len(element_shape) != argument_type.ndim:
                raise GsTaichiSyntaxError(
                    f"Argument {var_name} of type {argument_type} is expected to be a Matrix with ndim {argument_type.ndim}, but got {len(element_shape)}."
                )

            assert argument_type.ndim > 0
            if element_shape[0] != argument_type.n:
                raise GsTaichiSyntaxError(
                    f"Argument {var_name} of type {argument_type} is expected to be a Matrix with n {argument_type.n}, but got {element_shape[0]}."
                )

            # The second dimension is the one compared against m, so report it.
            if argument_type.ndim == 2 and element_shape[1] != argument_type.m:
                raise GsTaichiSyntaxError(
                    f"Argument {var_name} of type {argument_type} is expected to be a Matrix with m {argument_type.m}, but got {element_shape[1]}."
                )

            ctx.create_variable(var_name, impl.expr_init_func(data))
            return None

        if id(argument_type) in primitive_types.type_ids:
            var_name = argument_name
            ctx.create_variable(var_name, impl.expr_init_func(ti_ops.cast(data, argument_type)))
            return None
        # Create a copy for non-template arguments,
        # so that they are passed by value.
        var_name = argument_name
        ctx.create_variable(var_name, impl.expr_init_func(data))
        return None

    @staticmethod
    def _transform_as_func(ctx: ASTTransformerContext, node: ast.FunctionDef, args: ast.arguments) -> None:
        """Bind each value in ``ctx.argument_data`` to its ``ti.func`` parameter.

        Dataclass-annotated original parameters are additionally bound to their
        dataclass type under the parameter name.
        """
        for data_i, data in enumerate(ctx.argument_data):
            argument = ctx.func.arguments[data_i]
            FunctionDefTransformer._transform_func_arg(
                ctx,
                argument.name,
                argument.annotation,
                data,
            )

        for v in ctx.func.orig_arguments:
            if dataclasses.is_dataclass(v.annotation):
                ctx.create_variable(v.name, v.annotation)

    @staticmethod
    def build_FunctionDef(
        ctx: ASTTransformerContext,
        node: ast.FunctionDef,
        build_stmts: Callable[[ASTTransformerContext, list[ast.stmt]], None],
    ) -> None:
        """Entry point for the function's ``FunctionDef`` node.

        Args:
            ctx: The transformer context.
            node: The function definition node being compiled.
            build_stmts: Callback used to build the statements of ``node.body``.

        Raises:
            GsTaichiSyntaxError: if a nested function definition is encountered
                (a second ``FunctionDef`` visit for the same context).
        """
        if ctx.visited_funcdef:
            raise GsTaichiSyntaxError(
                f"Function definition is not allowed in 'ti.{'kernel' if ctx.is_kernel else 'func'}'."
            )
        ctx.visited_funcdef = True

        args = node.args
        assert args.vararg is None
        assert args.kwonlyargs == []
        assert args.kw_defaults == []
        assert args.kwarg is None

        # ti.kernel and real ti.func take kernel-style parameter declarations;
        # inlined ti.func binds the call-site values directly.
        if ctx.is_kernel or ctx.is_real_function:
            FunctionDefTransformer._transform_as_kernel(ctx, node, args)
        else:
            FunctionDefTransformer._transform_as_func(ctx, node, args)

        with ctx.variable_scope_guard():
            build_stmts(ctx, node.body)

        return None
@@ -2,8 +2,8 @@
2
2
 
3
3
  import ast
4
4
 
5
- from taichi.lang._wrap_inspect import getsourcefile, getsourcelines
6
- from taichi.lang.exception import TaichiSyntaxError
5
+ from gstaichi.lang._wrap_inspect import getsourcefile, getsourcelines
6
+ from gstaichi.lang.exception import GsTaichiSyntaxError
7
7
 
8
8
 
9
9
  class KernelSimplicityASTChecker(ast.NodeVisitor):
@@ -62,7 +62,7 @@ class KernelSimplicityASTChecker(ast.NodeVisitor):
62
62
  if not isinstance(node, ast.stmt):
63
63
  return False
64
64
  # TODO(#536): Frontend pass should help make sure |func| is a valid AST for
65
- # Taichi.
65
+ # GsTaichi.
66
66
  ignored = [ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef]
67
67
  return not any(map(lambda t: isinstance(node, t), ignored))
68
68
 
@@ -72,7 +72,7 @@ class KernelSimplicityASTChecker(ast.NodeVisitor):
72
72
  return
73
73
 
74
74
  if not (self.top_level or self.current_scope.allows_more_stmt):
75
- raise TaichiSyntaxError(f"No more statements allowed, at {self.get_error_location(node)}")
75
+ raise GsTaichiSyntaxError(f"No more statements allowed, at {self.get_error_location(node)}")
76
76
  old_top_level = self.top_level
77
77
  if old_top_level:
78
78
  self._scope_guards.append(self.new_scope())
@@ -96,7 +96,7 @@ class KernelSimplicityASTChecker(ast.NodeVisitor):
96
96
  # and node.iter.func.attr == 'static')
97
97
  # if not (self.top_level or self.current_scope.allows_for_loop
98
98
  # or is_static):
99
- # raise TaichiSyntaxError(
99
+ # raise GsTaichiSyntaxError(
100
100
  # f'No more for loops allowed, at {self.get_error_location(node)}'
101
101
  # )
102
102
  # with self.new_scope():
@@ -0,0 +1,9 @@
1
+ # type: ignore
2
+
3
+ from gstaichi.lang.ast.ast_transformer import ASTTransformer
4
+ from gstaichi.lang.ast.ast_transformer_utils import ASTTransformerContext
5
+
6
+
7
def transform_tree(tree, ctx: ASTTransformerContext):
    """Run a fresh ``ASTTransformer`` over ``tree`` with ``ctx``.

    Returns:
        ``ctx.return_data`` as it stands after the transformation.
    """
    transformer = ASTTransformer()
    transformer(ctx, tree)
    result = ctx.return_data
    return result
@@ -2,13 +2,13 @@
2
2
 
3
3
  from typing import TYPE_CHECKING
4
4
 
5
- from taichi.lang import ops
6
- from taichi.lang.util import in_python_scope
7
- from taichi.types import primitive_types
5
+ from gstaichi.lang import ops
6
+ from gstaichi.lang.util import in_python_scope
7
+ from gstaichi.types import primitive_types
8
8
 
9
9
 
10
- class TaichiOperations:
11
- """The base class of taichi operations of expressions. Subclasses: :class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`"""
10
+ class GsTaichiOperations:
11
+ """The base class of gstaichi operations of expressions. Subclasses: :class:`~gstaichi.lang.expr.Expr`, :class:`~gstaichi.lang.matrix.Matrix`"""
12
12
 
13
13
  if TYPE_CHECKING:
14
14
  # Make pylint happy
@@ -124,7 +124,7 @@ class TaichiOperations:
124
124
  other (Any): Given operand.
125
125
 
126
126
  Returns:
127
- :class:`~taichi.lang.expr.Expr`: The computing expression of atomic add."""
127
+ :class:`~gstaichi.lang.expr.Expr`: The computing expression of atomic add."""
128
128
  return ops.atomic_add(self, other)
129
129
 
130
130
  def _atomic_mul(self, other):
@@ -134,7 +134,7 @@ class TaichiOperations:
134
134
  other (Any): Given operand.
135
135
 
136
136
  Returns:
137
- :class:`~taichi.lang.expr.Expr`: The computing expression of atomic mul."""
137
+ :class:`~gstaichi.lang.expr.Expr`: The computing expression of atomic mul."""
138
138
  return ops.atomic_mul(self, other)
139
139
 
140
140
  def _atomic_sub(self, other):
@@ -144,7 +144,7 @@ class TaichiOperations:
144
144
  other (Any): Given operand.
145
145
 
146
146
  Returns:
147
- :class:`~taichi.lang.expr.Expr`: The computing expression of atomic sub."""
147
+ :class:`~gstaichi.lang.expr.Expr`: The computing expression of atomic sub."""
148
148
  return ops.atomic_sub(self, other)
149
149
 
150
150
  def _atomic_and(self, other):
@@ -154,7 +154,7 @@ class TaichiOperations:
154
154
  other (Any): Given operand.
155
155
 
156
156
  Returns:
157
- :class:`~taichi.lang.expr.Expr`: The computing expression of atomic and."""
157
+ :class:`~gstaichi.lang.expr.Expr`: The computing expression of atomic and."""
158
158
  return ops.atomic_and(self, other)
159
159
 
160
160
  def _atomic_xor(self, other):
@@ -164,7 +164,7 @@ class TaichiOperations:
164
164
  other (Any): Given operand.
165
165
 
166
166
  Returns:
167
- :class:`~taichi.lang.expr.Expr`: The computing expression of atomic xor."""
167
+ :class:`~gstaichi.lang.expr.Expr`: The computing expression of atomic xor."""
168
168
  return ops.atomic_xor(self, other)
169
169
 
170
170
  def _atomic_or(self, other):
@@ -174,7 +174,7 @@ class TaichiOperations:
174
174
  other (Any): Given operand.
175
175
 
176
176
  Returns:
177
- :class:`~taichi.lang.expr.Expr`: The computing expression of atomic or."""
177
+ :class:`~gstaichi.lang.expr.Expr`: The computing expression of atomic or."""
178
178
  return ops.atomic_or(self, other)
179
179
 
180
180
  # In-place operators in python scope returns NotImplemented to fall back to normal operators
@@ -264,7 +264,7 @@ class TaichiOperations:
264
264
  other (Any): Given operand.
265
265
 
266
266
  Returns:
267
- :class:`~taichi.lang.expr.Expr`: The expression after assigning."""
267
+ :class:`~gstaichi.lang.expr.Expr`: The expression after assigning."""
268
268
  return ops.assign(self, other)
269
269
 
270
270
  def _augassign(self, x, op):
@@ -0,0 +1,80 @@
1
+ # type: ignore
2
+
3
+ from gstaichi._lib import core
4
+
5
+
6
+ class GsTaichiCompilationError(Exception):
7
+ """Base class for all compilation exceptions."""
8
+
9
+ pass
10
+
11
+
12
+ class GsTaichiSyntaxError(GsTaichiCompilationError, SyntaxError):
13
+ """Thrown when a syntax error is found during compilation."""
14
+
15
+ pass
16
+
17
+
18
+ class GsTaichiNameError(GsTaichiCompilationError, NameError):
19
+ """Thrown when an undefine name is found during compilation."""
20
+
21
+ pass
22
+
23
+
24
+ class GsTaichiIndexError(GsTaichiCompilationError, IndexError):
25
+ """Thrown when an index error is found during compilation."""
26
+
27
+ pass
28
+
29
+
30
+ class GsTaichiTypeError(GsTaichiCompilationError, TypeError):
31
+ """Thrown when a type mismatch is found during compilation."""
32
+
33
+ pass
34
+
35
+
36
+ class GsTaichiRuntimeError(RuntimeError):
37
+ """Thrown when the compiled program cannot be executed due to unspecified reasons."""
38
+
39
+ pass
40
+
41
+
42
+ class GsTaichiAssertionError(GsTaichiRuntimeError, AssertionError):
43
+ """Thrown when assertion fails at runtime."""
44
+
45
+ pass
46
+
47
+
48
+ class GsTaichiRuntimeTypeError(GsTaichiRuntimeError, TypeError):
49
+ @staticmethod
50
+ def get(pos, needed, provided):
51
+ return GsTaichiRuntimeTypeError(
52
+ f"Argument {pos} (type={provided}) cannot be converted into required type {needed}"
53
+ )
54
+
55
+ @staticmethod
56
+ def get_ret(needed, provided):
57
+ return GsTaichiRuntimeTypeError(f"Return (type={provided}) cannot be converted into required type {needed}")
58
+
59
+
60
+ def handle_exception_from_cpp(exc):
61
+ if isinstance(exc, core.GsTaichiTypeError):
62
+ return GsTaichiTypeError(str(exc))
63
+ if isinstance(exc, core.GsTaichiSyntaxError):
64
+ return GsTaichiSyntaxError(str(exc))
65
+ if isinstance(exc, core.GsTaichiIndexError):
66
+ return GsTaichiIndexError(str(exc))
67
+ if isinstance(exc, core.GsTaichiAssertionError):
68
+ return GsTaichiAssertionError(str(exc))
69
+ return exc
70
+
71
+
72
+ __all__ = [
73
+ "GsTaichiSyntaxError",
74
+ "GsTaichiTypeError",
75
+ "GsTaichiCompilationError",
76
+ "GsTaichiNameError",
77
+ "GsTaichiRuntimeError",
78
+ "GsTaichiRuntimeTypeError",
79
+ "GsTaichiAssertionError",
80
+ ]
@@ -2,21 +2,21 @@ from typing import TYPE_CHECKING
2
2
 
3
3
  import numpy as np
4
4
 
5
- from taichi._lib import core as _ti_core
6
- from taichi.lang import impl
7
- from taichi.lang.common_ops import TaichiOperations
8
- from taichi.lang.exception import TaichiCompilationError, TaichiTypeError
9
- from taichi.lang.matrix import make_matrix
10
- from taichi.lang.util import is_matrix_class, is_taichi_class, to_numpy_type
11
- from taichi.types import primitive_types
12
- from taichi.types.primitive_types import integer_types, real_types
5
+ from gstaichi._lib import core as _ti_core
6
+ from gstaichi.lang import impl
7
+ from gstaichi.lang.common_ops import GsTaichiOperations
8
+ from gstaichi.lang.exception import GsTaichiCompilationError, GsTaichiTypeError
9
+ from gstaichi.lang.matrix import make_matrix
10
+ from gstaichi.lang.util import is_gstaichi_class, is_matrix_class, to_numpy_type
11
+ from gstaichi.types import primitive_types
12
+ from gstaichi.types.primitive_types import integer_types, real_types
13
13
 
14
14
  if TYPE_CHECKING:
15
- from taichi.lang.ast.ast_transformer_utils import ASTBuilder
15
+ from gstaichi.lang.ast.ast_transformer_utils import ASTBuilder
16
16
 
17
17
 
18
18
  # Scalar, basic data type
19
- class Expr(TaichiOperations):
19
+ class Expr(GsTaichiOperations):
20
20
  """A Python-side Expr wrapper, whose member variable `ptr` is an instance of C++ Expr class. A C++ Expr object contains member variable `expr` which holds an instance of C++ Expression class."""
21
21
 
22
22
  def __init__(self, *args, dbg_info=None, dtype=None):
@@ -24,7 +24,7 @@ class Expr(TaichiOperations):
24
24
  self.ptr_type_checked = False
25
25
  self.declaration_tb: str = ""
26
26
  if len(args) == 1:
27
- if isinstance(args[0], _ti_core.Expr):
27
+ if isinstance(args[0], _ti_core.ExprCxx):
28
28
  self.ptr = args[0]
29
29
  elif isinstance(args[0], Expr):
30
30
  self.ptr = args[0].ptr
@@ -39,7 +39,7 @@ class Expr(TaichiOperations):
39
39
  arg = args[0]
40
40
  if isinstance(arg, np.ndarray):
41
41
  if arg.shape:
42
- raise TaichiTypeError(
42
+ raise GsTaichiTypeError(
43
43
  "Only 0-dimensional numpy array can be used to initialize a scalar expression"
44
44
  )
45
45
  arg = arg.dtype.type(arg)
@@ -63,7 +63,7 @@ class Expr(TaichiOperations):
63
63
 
64
64
  def get_shape(self):
65
65
  if not self.is_tensor():
66
- raise TaichiCompilationError(f"Getting shape of non-tensor type: {self.ptr.get_rvalue_type()}")
66
+ raise GsTaichiCompilationError(f"Getting shape of non-tensor type: {self.ptr.get_rvalue_type()}")
67
67
  shape = self.ptr.get_shape()
68
68
  assert shape is not None
69
69
  return tuple(shape)
@@ -72,14 +72,14 @@ class Expr(TaichiOperations):
72
72
  def n(self):
73
73
  shape = self.get_shape()
74
74
  if len(shape) < 1:
75
- raise TaichiCompilationError(f"Getting n of tensor type < 1D: {self.ptr.get_rvalue_type()}")
75
+ raise GsTaichiCompilationError(f"Getting n of tensor type < 1D: {self.ptr.get_rvalue_type()}")
76
76
  return shape[0]
77
77
 
78
78
  @property
79
79
  def m(self):
80
80
  shape = self.get_shape()
81
81
  if len(shape) < 2:
82
- raise TaichiCompilationError(f"Getting m of tensor type < 2D: {self.ptr.get_rvalue_type()}")
82
+ raise GsTaichiCompilationError(f"Getting m of tensor type < 2D: {self.ptr.get_rvalue_type()}")
83
83
  return shape[1]
84
84
 
85
85
  def __hash__(self):
@@ -116,7 +116,7 @@ def make_constant_expr(val, dtype):
116
116
  if isinstance(val, (float, np.floating)):
117
117
  constant_dtype = impl.get_runtime().default_fp if dtype is None else dtype
118
118
  if constant_dtype not in real_types:
119
- raise TaichiTypeError(
119
+ raise GsTaichiTypeError(
120
120
  "Floating-point literals must be annotated with a floating-point type. For type casting, use `ti.cast`."
121
121
  )
122
122
  return Expr(_ti_core.make_const_expr_fp(constant_dtype, val))
@@ -124,19 +124,19 @@ def make_constant_expr(val, dtype):
124
124
  if isinstance(val, (int, np.integer)):
125
125
  constant_dtype = impl.get_runtime().default_ip if dtype is None else dtype
126
126
  if constant_dtype not in integer_types:
127
- raise TaichiTypeError(
127
+ raise GsTaichiTypeError(
128
128
  "Integer literals must be annotated with a integer type. For type casting, use `ti.cast`."
129
129
  )
130
130
  if _check_in_range(to_numpy_type(constant_dtype), val):
131
131
  return Expr(_ti_core.make_const_expr_int(constant_dtype, _clamp_unsigned_to_range(np.int64, val)))
132
132
  if dtype is None:
133
- raise TaichiTypeError(
133
+ raise GsTaichiTypeError(
134
134
  f"Integer literal {val} exceeded the range of default_ip: {impl.get_runtime().default_ip}, please specify the dtype via e.g. `ti.u64({val})` or set a different `default_ip` in `ti.init()`"
135
135
  )
136
136
  else:
137
- raise TaichiTypeError(f"Integer literal {val} exceeded the range of specified dtype: {dtype}")
137
+ raise GsTaichiTypeError(f"Integer literal {val} exceeded the range of specified dtype: {dtype}")
138
138
 
139
- raise TaichiTypeError(f"Invalid constant scalar data type: {type(val)}")
139
+ raise GsTaichiTypeError(f"Invalid constant scalar data type: {type(val)}")
140
140
 
141
141
 
142
142
  def make_var_list(size: int, ast_builder: "ASTBuilder | None" = None):
@@ -151,7 +151,7 @@ def make_var_list(size: int, ast_builder: "ASTBuilder | None" = None):
151
151
 
152
152
 
153
153
  def make_expr_group(*exprs):
154
- from taichi.lang.matrix import Matrix # pylint: disable=C0415
154
+ from gstaichi.lang.matrix import Matrix # pylint: disable=C0415
155
155
 
156
156
  if len(exprs) == 1:
157
157
  if isinstance(exprs[0], (list, tuple)):
@@ -169,7 +169,7 @@ def make_expr_group(*exprs):
169
169
 
170
170
 
171
171
  def _get_flattened_ptrs(val):
172
- if is_taichi_class(val):
172
+ if is_gstaichi_class(val):
173
173
  ptrs = []
174
174
  for item in val._members:
175
175
  ptrs.extend(_get_flattened_ptrs(item))