gstaichi 2.1.1rc3__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. gstaichi/CHANGELOG.md +4 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +3 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cpython-313-darwin.so +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  11. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  12. gstaichi/_lib/utils.py +243 -0
  13. gstaichi/_logging.py +131 -0
  14. gstaichi/_snode/__init__.py +5 -0
  15. gstaichi/_snode/fields_builder.py +187 -0
  16. gstaichi/_snode/snode_tree.py +34 -0
  17. gstaichi/_test_tools/__init__.py +18 -0
  18. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  19. gstaichi/_test_tools/load_kernel_string.py +30 -0
  20. gstaichi/_test_tools/textwrap2.py +6 -0
  21. gstaichi/_version.py +1 -0
  22. gstaichi/_version_check.py +100 -0
  23. gstaichi/ad/__init__.py +3 -0
  24. gstaichi/ad/_ad.py +530 -0
  25. gstaichi/algorithms/__init__.py +3 -0
  26. gstaichi/algorithms/_algorithms.py +117 -0
  27. gstaichi/assets/.git +1 -0
  28. gstaichi/assets/Go-Regular.ttf +0 -0
  29. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  30. gstaichi/examples/lcg_python.py +26 -0
  31. gstaichi/examples/lcg_taichi.py +34 -0
  32. gstaichi/examples/minimal.py +28 -0
  33. gstaichi/experimental.py +16 -0
  34. gstaichi/lang/__init__.py +50 -0
  35. gstaichi/lang/_dataclass_util.py +31 -0
  36. gstaichi/lang/_fast_caching/__init__.py +3 -0
  37. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  38. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  39. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  40. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  41. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  42. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  43. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  44. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  45. gstaichi/lang/_ndarray.py +352 -0
  46. gstaichi/lang/_ndrange.py +152 -0
  47. gstaichi/lang/_template_mapper.py +195 -0
  48. gstaichi/lang/_texture.py +172 -0
  49. gstaichi/lang/_wrap_inspect.py +215 -0
  50. gstaichi/lang/any_array.py +99 -0
  51. gstaichi/lang/ast/__init__.py +5 -0
  52. gstaichi/lang/ast/ast_transformer.py +1323 -0
  53. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  54. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  55. gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
  56. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  57. gstaichi/lang/ast/checkers.py +106 -0
  58. gstaichi/lang/ast/symbol_resolver.py +57 -0
  59. gstaichi/lang/ast/transform.py +9 -0
  60. gstaichi/lang/common_ops.py +310 -0
  61. gstaichi/lang/exception.py +80 -0
  62. gstaichi/lang/expr.py +180 -0
  63. gstaichi/lang/field.py +428 -0
  64. gstaichi/lang/impl.py +1243 -0
  65. gstaichi/lang/kernel_arguments.py +155 -0
  66. gstaichi/lang/kernel_impl.py +1341 -0
  67. gstaichi/lang/matrix.py +1835 -0
  68. gstaichi/lang/matrix_ops.py +341 -0
  69. gstaichi/lang/matrix_ops_utils.py +190 -0
  70. gstaichi/lang/mesh.py +687 -0
  71. gstaichi/lang/misc.py +782 -0
  72. gstaichi/lang/ops.py +1494 -0
  73. gstaichi/lang/runtime_ops.py +13 -0
  74. gstaichi/lang/shell.py +35 -0
  75. gstaichi/lang/simt/__init__.py +5 -0
  76. gstaichi/lang/simt/block.py +94 -0
  77. gstaichi/lang/simt/grid.py +7 -0
  78. gstaichi/lang/simt/subgroup.py +191 -0
  79. gstaichi/lang/simt/warp.py +96 -0
  80. gstaichi/lang/snode.py +489 -0
  81. gstaichi/lang/source_builder.py +150 -0
  82. gstaichi/lang/struct.py +810 -0
  83. gstaichi/lang/util.py +312 -0
  84. gstaichi/linalg/__init__.py +8 -0
  85. gstaichi/linalg/matrixfree_cg.py +310 -0
  86. gstaichi/linalg/sparse_cg.py +59 -0
  87. gstaichi/linalg/sparse_matrix.py +303 -0
  88. gstaichi/linalg/sparse_solver.py +123 -0
  89. gstaichi/math/__init__.py +11 -0
  90. gstaichi/math/_complex.py +205 -0
  91. gstaichi/math/mathimpl.py +886 -0
  92. gstaichi/profiler/__init__.py +6 -0
  93. gstaichi/profiler/kernel_metrics.py +260 -0
  94. gstaichi/profiler/kernel_profiler.py +586 -0
  95. gstaichi/profiler/memory_profiler.py +15 -0
  96. gstaichi/profiler/scoped_profiler.py +36 -0
  97. gstaichi/sparse/__init__.py +3 -0
  98. gstaichi/sparse/_sparse_grid.py +77 -0
  99. gstaichi/tools/__init__.py +12 -0
  100. gstaichi/tools/diagnose.py +117 -0
  101. gstaichi/tools/np2ply.py +364 -0
  102. gstaichi/tools/vtk.py +38 -0
  103. gstaichi/types/__init__.py +19 -0
  104. gstaichi/types/annotations.py +52 -0
  105. gstaichi/types/compound_types.py +71 -0
  106. gstaichi/types/enums.py +49 -0
  107. gstaichi/types/ndarray_type.py +169 -0
  108. gstaichi/types/primitive_types.py +206 -0
  109. gstaichi/types/quant.py +88 -0
  110. gstaichi/types/texture_type.py +85 -0
  111. gstaichi/types/utils.py +11 -0
  112. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
  113. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
  114. gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
  115. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
  116. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
  117. gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
  118. gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
  119. gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  120. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
  121. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
  122. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  123. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  124. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  125. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  126. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  127. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  128. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  129. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  130. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  131. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  132. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  133. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  134. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  135. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  136. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  137. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  138. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  139. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  140. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  141. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  142. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  143. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  144. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  145. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  146. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  147. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  148. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  149. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  150. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  151. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  152. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  153. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  154. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  155. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  156. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  157. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  158. gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  159. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  160. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  161. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  162. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  163. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  164. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  165. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  166. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  167. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  168. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  169. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  170. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  171. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  172. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  173. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  174. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  175. gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
  176. gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
  177. gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
  178. gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
  179. gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,215 @@
1
+ # type: ignore
2
+
3
+ # GsTaichi's custom inspect module.
4
+ # This module is used by GsTaichi's ast transformer to parse the source code.
5
+ # Currently this module is aimed at working in the following modes:
6
+ # 1. Usual Python/IPython mode, e.g. python script.py
7
+ # In this case we mainly rely on the built-in `inspect` module, except
8
+ # we need some hacks when we are in IPython mode and there is a cell magic.
9
+ # 2. Blender's scripting mode, e.g. Users write GsTaichi code in the scripting
10
+ # window in Blender and press the run button. In this case we need to
11
+ # retrieve the source using Blender's `bpy.data.texts` and write it to a temp
12
+ # file so that the inspect module can parse.
13
+ # 3. The interactive shell mode, e.g. Users directly type their code in the
14
+ # interactive shell. In this case we use `dill` to get the source.
15
+ #
16
+ # NB: Running GsTaichi in other modes is likely not supported.
17
+
18
+ import atexit
19
+ import inspect
20
+ import os
21
+ import tempfile
22
+ from typing import Callable
23
+
24
+ import dill
25
+ from pydantic import BaseModel
26
+
27
# Remember the stock `inspect` implementations. They are called directly by the
# finders below and are put back after `inspect` is temporarily monkey-patched.
_builtin_getfile, _builtin_findsource = inspect.getfile, inspect.findsource
29
+
30
+
31
def _find_source_with_custom_getfile_func(func, obj):
    """Run `inspect.findsource(obj)` with `inspect.getfile` temporarily replaced by `func`.

    The `inspect.getfile` that was active before the call is restored even if
    `findsource` raises. A failed lookup therefore cannot leave the global
    `inspect` module patched.

    Args:
        func: Callable used in place of `inspect.getfile` during the lookup.
        obj: The object whose source is looked up.

    Returns:
        Whatever `inspect.findsource(obj)` returns.
    """
    saved_getfile = inspect.getfile
    inspect.getfile = func  # replace with our custom func
    try:
        return inspect.findsource(obj)
    finally:
        inspect.getfile = saved_getfile  # restore, also on error
39
+
40
+
41
def _blender_get_text_name(filename: str):
    """Extract the Blender text-block name from a filename reported by `inspect`.

    In Blender's scripting mode, unsaved texts are reported with names like
    `/Text`, `/Text.001` or `/test.py`. Texts saved in a .blend file are
    reported like `some-path/xxx.blend/Text` or `some-path/xxx.blend/test.py`.

    Args:
        filename: The filename returned by `inspect.getfile`.

    Returns:
        The text name (e.g. "Text.001" or "test.py"), or None when `filename`
        matches neither form.
    """
    sep = os.path.sep
    # Unsaved text: drop the single leading path separator.
    # "/Text.001" --> "Text.001"
    if filename.startswith(sep) and filename.count(sep) == 1:
        return filename[len(sep) :]

    # Saved text: drop the path up to and including "xxx.blend/".
    # "xxx.blend/test.py" --> "test.py"
    marker = ".blend" + sep
    index = filename.rfind(marker)
    if index != -1:
        return filename[index + len(marker) :]

    return None
57
+
58
+
59
def _blender_findsource(obj):
    """Find the source of `obj` when it was defined in a Blender text block.

    The text is read from ``bpy.data.texts`` and written to a temporary ``.py``
    file. That file is cached per distinct text content and deleted at program
    exit. `inspect.findsource` is then run with `getfile` redirected to it for
    `obj` only.

    Raises:
        ImportError: If ``bpy`` cannot be imported (not running in Blender).
        IOError: If `obj` does not come from a Blender text block.
    """
    try:
        import bpy  # pylint: disable=import-outside-toplevel
    except ImportError as err:
        raise ImportError("Not in Blender environment!") from err

    # Inspect's built-in `getfile` returns filenames like
    # `/Text`, `/Text.001`, `some-path/xxx.blend/test.py`.
    # Such a filename may not be a full valid path.
    filename = _builtin_getfile(obj)
    # Extract the text name without path.
    text_name = _blender_get_text_name(filename)
    if text_name is None:
        obj_name = getattr(obj, "__name__", obj)
        raise IOError(f"Object `{obj_name}` is not defined in a .blend file!")
    # Get the lines of code via text_name.
    lines = bpy.data.texts[text_name].as_string()

    # Temp files are cached by text content to avoid file I/O on each query.
    cache = _blender_findsource._saved_inspect_cache  # pylint: disable=no-member
    filename = cache.get(lines)
    if filename is None:
        # Save the code to a valid path. Write through the mkstemp descriptor
        # with an explicit UTF-8 encoding, because Blender texts may contain
        # non-ASCII characters.
        fd, filename = tempfile.mkstemp(prefix="_Blender_", suffix=f"_{text_name}.py")
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(lines)
        cache[lines] = filename
        atexit.register(os.unlink, filename)  # Remove file when program exits

    def wrapped_getfile(ob):
        # Redirect only `obj` to the temp file; everything else uses the builtin.
        if ob is obj:
            return filename
        return _builtin_getfile(ob)

    return _find_source_with_custom_getfile_func(wrapped_getfile, obj)


# Maps Blender text content -> path of the temp file holding that content.
_blender_findsource._saved_inspect_cache = {}
101
+
102
+
103
def _strip_time_magic(cell):
    """Return `cell` with its leading `%%time`/`%%timeit` (or `%time`/`%timeit`) magic removed.

    For a cell magic, the entire first line is dropped, including any magic
    options such as ``-n 10``. Keeping those options would leave invalid
    Python on line 1 and push the cell body down one line. For a line magic,
    only the magic token and the whitespace following it are removed.

    Returns:
        The remaining code, or None if no such magic is present.
    """
    stripped = cell.lstrip()
    if stripped.startswith("%%time"):
        return stripped.partition("\n")[2]
    index = cell.find("%time")
    if index == -1:
        return None
    parts = cell[index:].split(maxsplit=1)
    return parts[1] if len(parts) == 2 else None


def _findsource_in_ipython_cell(obj):
    """Find the source of `obj` in the latest IPython cell, or return None if not possible."""
    try:
        ip = get_ipython()  # noqa: F821 -- name provided by IPython at runtime
    except (NameError, ImportError):
        return None  # not running under IPython
    if ip is None:
        return None
    # The latest lines of code (still containing the magic command) live here.
    code = _strip_time_magic(ip.history_manager._i00)
    if code is None:
        return None
    session_id = ip.history_manager.get_last_session_id()
    fd, filename = tempfile.mkstemp(prefix="_IPython_", suffix=f"_{session_id}.py")
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        f.write(code)
    atexit.register(os.unlink, filename)  # Remove the file after the program exits
    return _find_source_with_custom_getfile_func(lambda _ob: filename, obj)


def _Python_IPython_findsource(obj):
    """Find the source of `obj` in plain Python or IPython.

    This also covers IPython cells that start with a `%%time`/`%%timeit`
    magic.

    Raises:
        IOError: If the source cannot be located.
    """
    try:
        # In Python and IPython the builtin inspect suffices in most cases.
        return _builtin_findsource(obj)
    except IOError as err:
        builtin_error = err
    # It fails when the cell has a magic command like %%time or %%timeit: the
    # filename from the built-in `getfile` is then `<timed exec>` or
    # `<magic-timeit>` instead of a real file.
    filename = _builtin_getfile(obj)
    if filename in {"<timed exec>", "<magic-timeit>"}:
        source = _findsource_in_ipython_cell(obj)
        if source is not None:
            return source
    raise IOError(
        f"Cannot find source code for Object: {obj}, it's likely "
        "you are not running GsTaichi from command line or IPython."
    ) from builtin_error
144
+
145
+
146
def _REPL_findsource(obj):
    """Locate the source of `obj` in interactive-shell (REPL) mode using `dill`."""
    finder = dill.source.findsource
    return finder(obj)
149
+
150
+
151
def _custom_findsource(obj):
    """Drop-in replacement for `inspect.findsource` covering GsTaichi's supported modes.

    The strategies are tried in order: plain Python / IPython, the interactive
    shell (via `dill`), then Blender.

    Returns:
        The result of the first strategy that succeeds.

    Raises:
        IOError: If every strategy fails.
    """
    try:
        return _Python_IPython_findsource(obj)
    except IOError:
        pass
    try:
        return _REPL_findsource(obj)
    except Exception:  # best effort: fall through to the Blender strategy
        pass
    try:
        return _blender_findsource(obj)
    except Exception as err:
        raise IOError(
            f"Cannot find source code for Object: {obj}, this is possibly because you are running "
            "GsTaichi in an environment where GsTaichi's own inspect module cannot find the source. "
            "Please report an issue to help us fix: https://github.com/taichi-dev/gstaichi/issues"
        ) from err
167
+
168
+
169
class _InspectContextManager:
    """Context manager that temporarily installs `_custom_findsource` as `inspect.findsource`.

    On exit, the `inspect.findsource` that was active on entry is restored,
    not the builtin unconditionally. Nested uses, or other code that patched
    `inspect.findsource`, are therefore not clobbered.
    """

    def __enter__(self):
        self._saved_findsource = inspect.findsource
        inspect.findsource = _custom_findsource
        return self

    def __exit__(self, *_):
        inspect.findsource = self._saved_findsource
176
+
177
+
178
def getsourcelines(obj):
    """Like `inspect.getsourcelines`, but using GsTaichi's custom source lookup."""
    patch = _InspectContextManager()
    with patch:
        source_lines_and_start = inspect.getsourcelines(obj)
    return source_lines_and_start
181
+
182
+
183
def getsourcefile(obj):
    """Return the source file of `obj` under GsTaichi's custom source lookup.

    Falls back to `inspect.getfile` when `inspect.getsourcefile` returns None.
    """
    with _InspectContextManager():
        path = inspect.getsourcefile(obj)
        return inspect.getfile(obj) if path is None else path
189
+
190
+
191
class FunctionSourceInfo(BaseModel):
    """Location of a function's source code, as built by `get_source_info_and_src`."""

    function_name: str  # the function's `__name__`
    filepath: str  # file returned by `getsourcefile`
    start_lineno: int  # first line number of the function's source
    end_lineno: int  # last line number of the function's source (inclusive)

    class Config:
        # pydantic option: instances are immutable (and hashable).
        frozen = True
199
+
200
+
201
def get_source_info_and_src(func: Callable) -> tuple[FunctionSourceInfo, list[str]]:
    """Return the source location of `func` together with its source lines.

    Returns:
        A pair ``(info, lines)``. `info` records the function name, source file
        and first/last line numbers; the last line number is inclusive.
        `lines` is the list of source lines.
    """
    source_file = getsourcefile(func)
    func_name = func.__name__
    source_lines, first_line = getsourcelines(func)
    info = FunctionSourceInfo(
        function_name=func_name,
        filepath=source_file,
        start_lineno=first_line,
        end_lineno=first_line + len(source_lines) - 1,
    )
    return info, source_lines
213
+
214
+
215
# Public API of this module: the names exported by `from ... import *`.
__all__ = ["getsourcelines", "getsourcefile", "get_source_info_and_src"]
@@ -0,0 +1,99 @@
1
+ # type: ignore
2
+
3
+ from gstaichi._lib import core as _ti_core
4
+ from gstaichi.lang import impl
5
+ from gstaichi.lang.expr import Expr, make_expr_group
6
+ from gstaichi.lang.util import gstaichi_scope
7
+ from gstaichi.types.enums import Layout
8
+ from gstaichi.types.ndarray_type import NdarrayTypeMetadata
9
+
10
+
11
class AnyArray:
    """Class for arbitrary arrays in Python AST.

    Wraps an external-tensor expression and exposes its element shape, memory
    layout, type metadata, gradient and shape to kernel code.

    Args:
        ptr (gstaichi_python.Expr): A gstaichi_python.Expr wrapping a gstaichi_python.ExternalTensorExpression.
    """

    def __init__(self, ptr):
        assert ptr.is_external_tensor_expr()
        self.ptr = ptr
        # Run the core type check against the current program's config.
        self.ptr.type_check(impl.get_runtime().prog.config())

    def element_shape(self):
        """Returns the element shape.

        Returns:
            Tuple[Int]: () if scalar elements, (n,) if vector elements, and
            (n, m) if matrix elements.
        """
        return _ti_core.get_external_tensor_element_shape(self.ptr)

    def layout(self):
        """Returns the memory layout of the array's elements.

        Returns:
            Layout: ``Layout.SOA`` for SOA vector/matrix elements, otherwise
            ``Layout.AOS`` (scalar elements included).
        """
        # Element dim encoding: 0: scalar; 1: vector (SOA); 2: matrix (SOA);
        # -1: vector (AOS); -2: matrix (AOS)
        element_dim = _ti_core.get_external_tensor_element_dim(self.ptr)
        if element_dim in (1, 2):
            return Layout.SOA
        return Layout.AOS

    def get_type(self):
        """Returns NdarrayTypeMetadata built from the element type and needs-grad flag."""
        # AnyArray can take any shape, hence the `None` argument.
        # NOTE(review): assumes the 2nd positional parameter of
        # NdarrayTypeMetadata is the shape -- confirm.
        return NdarrayTypeMetadata(
            _ti_core.get_external_tensor_element_type(self.ptr),
            None,
            _ti_core.get_external_tensor_needs_grad(self.ptr),
        )

    @property
    @gstaichi_scope
    def grad(self):
        """Returns the gradient of this array."""
        return AnyArray(_ti_core.make_external_tensor_grad_expr(self.ptr))

    @property
    @gstaichi_scope
    def shape(self):
        """A list containing sizes for each dimension. Note that element shape will be excluded.

        Returns:
            List[Int]: The result list.
        """
        dim = _ti_core.get_external_tensor_dim(self.ptr)
        dbg_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
        return [Expr(_ti_core.get_external_tensor_shape_along_axis(self.ptr, i, dbg_info)) for i in range(dim)]

    @gstaichi_scope
    def _loop_range(self):
        """Gets the corresponding gstaichi_python.Expr to serve as loop range.

        Returns:
            gstaichi_python.Expr: See above.
        """
        return self.ptr
67
+
68
+
69
class AnyArrayAccess:
    """First-level access into an AnyArray whose elements are vectors or matrices.

    Args:
        arr (AnyArray): The array being accessed.
        indices_first (Tuple[Int]): Indices of the first-level (outer) access.
    """

    def __init__(self, arr, indices_first):
        self.arr = arr
        self.indices_first = indices_first

    @gstaichi_scope
    def subscript(self, i, j):
        """Builds the subscript expression for element ``(i, j)`` of the accessed entry.

        For vector elements only ``i`` is used. The element indices are placed
        before the outer indices for SOA layout and after them otherwise.
        """
        builder = impl.get_runtime().compiling_callable.ast_builder()

        if len(self.arr.element_shape()) == 1:
            inner = (i,)
        else:
            inner = (i, j)

        if self.arr.layout() == Layout.SOA:
            full_index = inner + self.indices_first
        else:
            full_index = self.indices_first + inner

        target = self.arr.ptr
        index_group = make_expr_group(*full_index)
        debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
        return Expr(builder.expr_subscript(target, index_group, debug_info))
97
+
98
+
99
# An empty `__all__` means `from ... import *` exports nothing from this module.
__all__ = []
@@ -0,0 +1,5 @@
1
+ # type: ignore
2
+
3
+ from gstaichi.lang.ast.ast_transformer_utils import ASTTransformerContext
4
+ from gstaichi.lang.ast.checkers import KernelSimplicityASTChecker
5
+ from gstaichi.lang.ast.transform import transform_tree