gstaichi 2.1.1rc3__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. gstaichi/CHANGELOG.md +4 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +3 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cpython-310-darwin.so +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  11. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  12. gstaichi/_lib/utils.py +243 -0
  13. gstaichi/_logging.py +131 -0
  14. gstaichi/_snode/__init__.py +5 -0
  15. gstaichi/_snode/fields_builder.py +187 -0
  16. gstaichi/_snode/snode_tree.py +34 -0
  17. gstaichi/_test_tools/__init__.py +18 -0
  18. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  19. gstaichi/_test_tools/load_kernel_string.py +30 -0
  20. gstaichi/_test_tools/textwrap2.py +6 -0
  21. gstaichi/_version.py +1 -0
  22. gstaichi/_version_check.py +100 -0
  23. gstaichi/ad/__init__.py +3 -0
  24. gstaichi/ad/_ad.py +530 -0
  25. gstaichi/algorithms/__init__.py +3 -0
  26. gstaichi/algorithms/_algorithms.py +117 -0
  27. gstaichi/assets/.git +1 -0
  28. gstaichi/assets/Go-Regular.ttf +0 -0
  29. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  30. gstaichi/examples/lcg_python.py +26 -0
  31. gstaichi/examples/lcg_taichi.py +34 -0
  32. gstaichi/examples/minimal.py +28 -0
  33. gstaichi/experimental.py +16 -0
  34. gstaichi/lang/__init__.py +50 -0
  35. gstaichi/lang/_dataclass_util.py +31 -0
  36. gstaichi/lang/_fast_caching/__init__.py +3 -0
  37. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  38. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  39. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  40. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  41. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  42. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  43. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  44. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  45. gstaichi/lang/_ndarray.py +352 -0
  46. gstaichi/lang/_ndrange.py +152 -0
  47. gstaichi/lang/_template_mapper.py +195 -0
  48. gstaichi/lang/_texture.py +172 -0
  49. gstaichi/lang/_wrap_inspect.py +215 -0
  50. gstaichi/lang/any_array.py +99 -0
  51. gstaichi/lang/ast/__init__.py +5 -0
  52. gstaichi/lang/ast/ast_transformer.py +1323 -0
  53. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  54. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  55. gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
  56. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  57. gstaichi/lang/ast/checkers.py +106 -0
  58. gstaichi/lang/ast/symbol_resolver.py +57 -0
  59. gstaichi/lang/ast/transform.py +9 -0
  60. gstaichi/lang/common_ops.py +310 -0
  61. gstaichi/lang/exception.py +80 -0
  62. gstaichi/lang/expr.py +180 -0
  63. gstaichi/lang/field.py +428 -0
  64. gstaichi/lang/impl.py +1243 -0
  65. gstaichi/lang/kernel_arguments.py +155 -0
  66. gstaichi/lang/kernel_impl.py +1341 -0
  67. gstaichi/lang/matrix.py +1835 -0
  68. gstaichi/lang/matrix_ops.py +341 -0
  69. gstaichi/lang/matrix_ops_utils.py +190 -0
  70. gstaichi/lang/mesh.py +687 -0
  71. gstaichi/lang/misc.py +782 -0
  72. gstaichi/lang/ops.py +1494 -0
  73. gstaichi/lang/runtime_ops.py +13 -0
  74. gstaichi/lang/shell.py +35 -0
  75. gstaichi/lang/simt/__init__.py +5 -0
  76. gstaichi/lang/simt/block.py +94 -0
  77. gstaichi/lang/simt/grid.py +7 -0
  78. gstaichi/lang/simt/subgroup.py +191 -0
  79. gstaichi/lang/simt/warp.py +96 -0
  80. gstaichi/lang/snode.py +489 -0
  81. gstaichi/lang/source_builder.py +150 -0
  82. gstaichi/lang/struct.py +810 -0
  83. gstaichi/lang/util.py +312 -0
  84. gstaichi/linalg/__init__.py +8 -0
  85. gstaichi/linalg/matrixfree_cg.py +310 -0
  86. gstaichi/linalg/sparse_cg.py +59 -0
  87. gstaichi/linalg/sparse_matrix.py +303 -0
  88. gstaichi/linalg/sparse_solver.py +123 -0
  89. gstaichi/math/__init__.py +11 -0
  90. gstaichi/math/_complex.py +205 -0
  91. gstaichi/math/mathimpl.py +886 -0
  92. gstaichi/profiler/__init__.py +6 -0
  93. gstaichi/profiler/kernel_metrics.py +260 -0
  94. gstaichi/profiler/kernel_profiler.py +586 -0
  95. gstaichi/profiler/memory_profiler.py +15 -0
  96. gstaichi/profiler/scoped_profiler.py +36 -0
  97. gstaichi/sparse/__init__.py +3 -0
  98. gstaichi/sparse/_sparse_grid.py +77 -0
  99. gstaichi/tools/__init__.py +12 -0
  100. gstaichi/tools/diagnose.py +117 -0
  101. gstaichi/tools/np2ply.py +364 -0
  102. gstaichi/tools/vtk.py +38 -0
  103. gstaichi/types/__init__.py +19 -0
  104. gstaichi/types/annotations.py +52 -0
  105. gstaichi/types/compound_types.py +71 -0
  106. gstaichi/types/enums.py +49 -0
  107. gstaichi/types/ndarray_type.py +169 -0
  108. gstaichi/types/primitive_types.py +206 -0
  109. gstaichi/types/quant.py +88 -0
  110. gstaichi/types/texture_type.py +85 -0
  111. gstaichi/types/utils.py +11 -0
  112. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
  113. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
  114. gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
  115. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
  116. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
  117. gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
  118. gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
  119. gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  120. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
  121. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
  122. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  123. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  124. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  125. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  126. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  127. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  128. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  129. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  130. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  131. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  132. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  133. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  134. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  135. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  136. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  137. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  138. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  139. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  140. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  141. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  142. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  143. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  144. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  145. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  146. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  147. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  148. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  149. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  150. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  151. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  152. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  153. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  154. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  155. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  156. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  157. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  158. gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  159. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  160. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  161. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  162. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  163. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  164. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  165. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  166. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  167. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  168. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  169. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  170. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  171. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  172. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  173. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  174. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  175. gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
  176. gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
  177. gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
  178. gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
  179. gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
gstaichi/lang/snode.py ADDED
@@ -0,0 +1,489 @@
1
+ # type: ignore
2
+
3
+ import numbers
4
+
5
+ from gstaichi._lib import core as _ti_core
6
+ from gstaichi._lib.core.gstaichi_python import (
7
+ Axis,
8
+ SNodeCxx,
9
+ )
10
+ from gstaichi.lang import expr, impl, matrix
11
+ from gstaichi.lang.exception import GsTaichiRuntimeError
12
+ from gstaichi.lang.field import BitpackedFields, Field
13
+ from gstaichi.lang.util import get_traceback
14
+
15
+
16
class SNode:
    """A Python-side SNode wrapper.

    For more information on GsTaichi's SNode system, please check out
    these references:

    * https://docs.taichi-lang.org/docs/sparse
    * https://yuanming.taichi.graphics/publication/2019-taichi/taichi-lang.pdf

    Args:
        ptr (SNodeCxx): The C++ side SNode pointer.
    """

    def __init__(self, ptr: SNodeCxx) -> None:
        self.ptr = ptr

    @staticmethod
    def _normalize_dimensions(axes, dimensions):
        """Return `dimensions` with one entry per axis.

        A scalar (any :class:`numbers.Number`) is repeated ``len(axes)`` times;
        anything else (e.g. a list of per-axis sizes) is returned unchanged.
        """
        if isinstance(dimensions, numbers.Number):
            return [dimensions] * len(axes)
        return dimensions

    @staticmethod
    def _check_sparse_supported(kind):
        """Raise if the current arch does not support the ``sparse`` extension.

        Args:
            kind (str): SNode kind used in the error message, e.g. ``"Pointer"``.

        Raises:
            GsTaichiRuntimeError: If sparse SNodes are unsupported on this backend.
        """
        if not _ti_core.is_extension_supported(impl.current_cfg().arch, _ti_core.Extension.sparse):
            raise GsTaichiRuntimeError(f"{kind} SNode is not supported on this backend.")

    def dense(self, axes: list[Axis], dimensions: list[int] | int) -> "SNode":
        """Adds a dense SNode as a child component of `self`.

        Args:
            axes (List[Axis]): Axes to activate.
            dimensions (Union[List[int], int]): Shape of each axis; a scalar is
                used for every axis.

        Returns:
            The added :class:`~gstaichi.lang.SNode` instance.
        """
        dimensions = self._normalize_dimensions(axes, dimensions)
        return SNode(self.ptr.dense(axes, dimensions, _ti_core.DebugInfo(get_traceback())))

    def pointer(self, axes: list[Axis], dimensions: list[int] | int) -> "SNode":
        """Adds a pointer SNode as a child component of `self`.

        Args:
            axes (List[Axis]): Axes to activate.
            dimensions (Union[List[int], int]): Shape of each axis; a scalar is
                used for every axis.

        Returns:
            The added :class:`~gstaichi.lang.SNode` instance.

        Raises:
            GsTaichiRuntimeError: If the current backend does not support sparse SNodes.
        """
        self._check_sparse_supported("Pointer")
        dimensions = self._normalize_dimensions(axes, dimensions)
        return SNode(self.ptr.pointer(axes, dimensions, _ti_core.DebugInfo(get_traceback())))

    @staticmethod
    def _hash(axes, dimensions):
        """Not supported: hash SNodes cannot be created from Python.

        Raises:
            RuntimeError: Always.
        """
        # Kept as a static method (was an instance method before a pylint R0201 fix).
        # A future implementation would broadcast `dimensions` like `dense` does and
        # call `self.ptr.hash(axes, dimensions)`.
        raise RuntimeError("hash not yet supported")

    def dynamic(self, axis: list[Axis], dimension: int, chunk_size: int | None = None) -> "SNode":
        """Adds a dynamic SNode as a child component of `self`.

        Args:
            axis (List[Axis]): Axis to activate; must contain exactly one axis.
            dimension (int): Shape of the axis.
            chunk_size (int, optional): Chunk size. Defaults to ``dimension``.

        Returns:
            The added :class:`~gstaichi.lang.SNode` instance.

        Raises:
            GsTaichiRuntimeError: If the current backend does not support sparse SNodes.
        """
        self._check_sparse_supported("Dynamic")
        assert len(axis) == 1, "A dynamic SNode takes exactly one axis"
        if chunk_size is None:
            chunk_size = dimension
        return SNode(self.ptr.dynamic(axis[0], dimension, chunk_size, _ti_core.DebugInfo(get_traceback())))

    def bitmasked(self, axes: list[Axis], dimensions: list[int] | int) -> "SNode":
        """Adds a bitmasked SNode as a child component of `self`.

        Args:
            axes (List[Axis]): Axes to activate.
            dimensions (Union[List[int], int]): Shape of each axis; a scalar is
                used for every axis.

        Returns:
            The added :class:`~gstaichi.lang.SNode` instance.

        Raises:
            GsTaichiRuntimeError: If the current backend does not support sparse SNodes.
        """
        self._check_sparse_supported("Bitmasked")
        dimensions = self._normalize_dimensions(axes, dimensions)
        return SNode(self.ptr.bitmasked(axes, dimensions, _ti_core.DebugInfo(get_traceback())))

    def quant_array(self, axes: list[Axis], dimensions: list[int] | int, max_num_bits: int) -> "SNode":
        """Adds a quant_array SNode as a child component of `self`.

        Args:
            axes (List[Axis]): Axes to activate.
            dimensions (Union[List[int], int]): Shape of each axis; a scalar is
                used for every axis.
            max_num_bits (int): Maximum number of bits it can hold.

        Returns:
            The added :class:`~gstaichi.lang.SNode` instance.
        """
        dimensions = self._normalize_dimensions(axes, dimensions)
        return SNode(self.ptr.quant_array(axes, dimensions, max_num_bits, _ti_core.DebugInfo(get_traceback())))

    def place(self, *args, offset: numbers.Number | tuple[numbers.Number, ...] | None = None) -> "SNode":
        """Places a list of GsTaichi fields under the `self` container.

        Args:
            *args: Fields to place. Each may be a :class:`BitpackedFields`, a
                :class:`Field`, or a (possibly nested) list of those.
            offset (Union[Number, tuple[Number, ...]], optional): Offset of the field
                domain. A scalar is wrapped into a 1-tuple; ``None`` means no offset.

        Returns:
            The `self` container, so calls can be chained.

        Raises:
            ValueError: If an argument is not one of the placeable types above.
        """
        if offset is None:
            offset = ()
        if isinstance(offset, numbers.Number):
            offset = (offset,)

        for arg in args:
            if isinstance(arg, BitpackedFields):
                # Bit-packed fields share one bit_struct child; each member is placed
                # into it at its own slot.
                bit_struct_type = arg.bit_struct_type_builder.build()
                bit_struct_snode = self.ptr.bit_struct(bit_struct_type, _ti_core.DebugInfo(get_traceback()))
                for field, id_in_bit_struct in arg.fields:
                    bit_struct_snode.place(field, offset, id_in_bit_struct)
            elif isinstance(arg, Field):
                # -1: not part of a bit struct.
                for var in arg._get_field_members():
                    self.ptr.place(var.ptr, offset, -1)
            elif isinstance(arg, list):
                for x in arg:
                    self.place(x, offset=offset)
            else:
                raise ValueError(f"{arg} cannot be placed")
        return self

    def lazy_grad(self):
        """Automatically place the adjoint fields following the layout of their primal fields.

        Users don't need to specify ``needs_grad`` when they define scalar/vector/matrix fields (primal fields) using autodiff.
        When all the primal fields are defined, using ``gstaichi.root.lazy_grad()`` could automatically generate
        their corresponding adjoint fields (gradient field).

        To know more details about primal, adjoint fields and ``lazy_grad()``,
        please see Page 4 and Page 13-14 of the DiffTaichi paper: https://arxiv.org/pdf/1910.00935.pdf
        """
        self.ptr.lazy_grad()

    def lazy_dual(self):
        """Automatically place the dual fields following the layout of their primal fields."""
        self.ptr.lazy_dual()

    def _allocate_adjoint_checkbit(self):
        """Automatically place the adjoint flag fields following the layout of their primal fields for global data access rule checker"""
        self.ptr.allocate_adjoint_checkbit()

    def parent(self, n=1):
        """Gets an ancestor of `self` in the SNode tree.

        Args:
            n (int): the number of levels going up from `self`.

        Returns:
            Union[None, _Root, SNode]: The n-th parent of `self`; ``impl.root`` if
            that ancestor is the root node, or ``None`` if the chain ends first.
        """
        p = self.ptr
        while p and n > 0:
            p = p.parent
            n -= 1
        if p is None:
            return None

        if p.type == _ti_core.SNodeType.root:
            return impl.root

        return SNode(p)

    def _path_from_root(self):
        """Gets the path from root to `self` in the SNode tree.

        Returns:
            List[Union[_Root, SNode]]: The list of SNodes on the path from root to `self`.
        """
        p = self
        res = [p]
        while p != impl.root:
            p = p.parent()
            res.append(p)
        res.reverse()
        return res

    @property
    def _dtype(self):
        """Gets the data type of `self`.

        Returns:
            DataType: The data type of `self`.
        """
        return self.ptr.data_type()

    @property
    def _id(self):
        """Gets the id of `self`.

        Returns:
            int: The id of `self`.
        """
        return self.ptr.id

    @property
    def _snode_tree_id(self):
        """Gets the id of the SNode tree `self` belongs to."""
        return self.ptr.get_snode_tree_id()

    @property
    def shape(self):
        """Gets the number of elements from root in each axis of `self`.

        Returns:
            Tuple[int]: The number of elements from root in each axis of `self`.
        """
        dim = self.ptr.num_active_indices()
        return tuple(self.ptr.get_shape_along_axis(i) for i in range(dim))

    def _loop_range(self):
        """Gets the gstaichi_python.SNode to serve as loop range.

        Returns:
            gstaichi_python.SNode: See above.
        """
        return self.ptr

    @property
    def _name(self):
        """Gets the name of `self`.

        Returns:
            str: The name of `self`.
        """
        return self.ptr.name()

    @property
    def _snode(self):
        """Gets `self`.

        Returns:
            SNode: `self`.
        """
        return self

    def _get_children(self):
        """Gets all children components of `self`.

        Returns:
            List[SNode]: All children components of `self`.
        """
        return [SNode(self.ptr.get_ch(i)) for i in range(self.ptr.get_num_ch())]

    @property
    def _num_dynamically_allocated(self):
        """Number of dynamically allocated cells of `self` (materializes the root tree first)."""
        runtime = impl.get_runtime()
        runtime.materialize_root_fb(False)
        return runtime.prog.get_snode_num_dynamically_allocated(self.ptr)

    @property
    def _cell_size_bytes(self):
        """Cell size of `self` in bytes (materializes the root tree first)."""
        impl.get_runtime().materialize_root_fb(False)
        return self.ptr.cell_size_bytes

    @property
    def _offset_bytes_in_parent_cell(self):
        """Byte offset of `self` inside its parent's cell (materializes the root tree first)."""
        impl.get_runtime().materialize_root_fb(False)
        return self.ptr.offset_bytes_in_parent_cell

    def deactivate_all(self):
        """Recursively deactivate all children components of `self`."""
        for child in self._get_children():
            child.deactivate_all()
        SNodeType = _ti_core.SNodeType
        if self.ptr.type in (SNodeType.pointer, SNodeType.bitmasked):
            from gstaichi._kernels import snode_deactivate  # pylint: disable=C0415

            snode_deactivate(self)
        if self.ptr.type == SNodeType.dynamic:
            # Note that dynamic nodes are different from other sparse nodes:
            # instead of deactivating each element, we only need to deactivate
            # its parent, whose linked list of chunks of elements will be deleted.
            from gstaichi._kernels import (  # pylint: disable=C0415
                snode_deactivate_dynamic,
            )

            snode_deactivate_dynamic(self)

    def __repr__(self):
        type_ = str(self.ptr.type)[len("SNodeType.") :]
        return f"<ti.SNode of type {type_}>"

    def __str__(self):
        # ti.root.dense(ti.i, 3).dense(ti.jk, (4, 5)).place(x)
        # ti.root => dense [3] => dense [3, 4, 5] => place [3, 4, 5]
        type_ = str(self.ptr.type)[len("SNodeType.") :]
        shape = str(list(self.shape))
        parent = str(self.parent())
        return f"{parent} => {type_} {shape}"

    def __eq__(self, other):
        # `other` is typically another SNode or `impl.root` (see `_path_from_root`).
        # Objects without a `ptr` are not comparable here; returning NotImplemented
        # lets Python fall back (e.g. `snode == None` is False instead of raising).
        try:
            other_ptr = other.ptr
        except AttributeError:
            return NotImplemented
        return self.ptr == other_ptr

    # Defining __eq__ removes the inherited __hash__; state that explicitly.
    __hash__ = None

    def _physical_index_position(self):
        """Gets mappings from virtual axes to physical axes.

        Returns:
            Dict[int, int]: Mappings from virtual axes to physical axes; virtual
            axes whose physical position is -1 are omitted.
        """
        return {
            virtual: physical
            for virtual, physical in enumerate(self.ptr.get_physical_index_position())
            if physical != -1
        }
338
+
339
+
340
def rescale_index(a, b, I):
    """Rescale the grouped loop index ``I`` of ``a`` so that it addresses ``b``.

    For every axis shared by ``I``, ``a`` and ``b``:

    * if ``a`` is larger along the axis, the component is floor-divided by
      ``a.shape[i] // b.shape[i]``;
    * if ``a`` is smaller, the component is multiplied by ``b.shape[i] // a.shape[i]``;
    * otherwise it is copied unchanged.

    Components beyond the shorter of the two shapes are copied unchanged.

    Args:
        a (Union[:class:`~gstaichi.Field`, :class:`~gstaichi.lang.SNode`]): Field or SNode that ``I`` indexes.
        b (Union[:class:`~gstaichi.Field`, :class:`~gstaichi.lang.SNode`]): Field or SNode whose shape to match.
        I (Union[list, :class:`~gstaichi.Vector`, Expr]): Grouped loop index.

    Returns:
        :class:`~gstaichi.Vector`: The rescaled grouped loop index.
    """
    assert isinstance(a, (Field, SNode)), "The first argument must be a field or an SNode"
    assert isinstance(b, (Field, SNode)), "The second argument must be a field or an SNode"
    index_is_list = isinstance(I, list)
    if not index_is_list:
        assert isinstance(
            I, (expr.Expr, matrix.Matrix)
        ), "The third argument must be an index (list, ti.Vector, or Expr with TensorType)"
    # Number of index components; captured by the pyfunc below.
    n = len(I) if index_is_list else I.n

    from gstaichi.lang.kernel_impl import pyfunc  # pylint: disable=C0415

    # NOTE(review): as a @pyfunc this body is presumably compiled by the GsTaichi
    # frontend when called from a kernel, so it is intentionally left verbatim.
    @pyfunc
    def _rescale_index():
        result = matrix.Vector([I[i] for i in range(n)])
        for i in impl.static(range(min(n, min(len(a.shape), len(b.shape))))):
            if a.shape[i] > b.shape[i]:
                result[i] = I[i] // (a.shape[i] // b.shape[i])
            if a.shape[i] < b.shape[i]:
                result[i] = I[i] * (b.shape[i] // a.shape[i])
        return result

    return _rescale_index()
375
+
376
+
377
def append(node, indices, val):
    """Append a value `val` to a SNode `node` at index `indices`.

    The append expression is bound to a new kernel-local variable via
    ``impl.expr_init`` and that variable is returned.

    Args:
        node (:class:`~gstaichi.SNode`): Input SNode.
        indices (Union[int, :class:`~gstaichi.Vector`]): the indices to visit.
        val (Union[:mod:`~gstaichi.types.primitive_types`, :mod:`~gstaichi.types.compound_types`]): the data to be appended.

    Returns:
        The result of ``impl.expr_init`` on the append expression.
    """
    flattened_ptrs = expr._get_flattened_ptrs(val)
    runtime = impl.get_runtime()
    builder = runtime.compiling_callable.ast_builder()
    snode_append = builder.expr_snode_append(node._snode.ptr, expr.make_expr_group(indices), flattened_ptrs)
    debug_info = _ti_core.DebugInfo(runtime.get_current_src_info())
    return impl.expr_init(expr.Expr(snode_append, dbg_info=debug_info))
394
+
395
+
396
def is_active(node, indices):
    """Explicitly query whether a cell in a SNode `node` at location
    `indices` is active or not.

    Args:
        node (:class:`~gstaichi.SNode`): Must be a pointer, hash or bitmasked node.
        indices (Union[int, list, :class:`~gstaichi.Vector`]): the indices to visit.

    Returns:
        bool: the cell `node[indices]` is active or not.
    """
    builder = impl.get_runtime().compiling_callable.ast_builder()
    activity_query = builder.expr_snode_is_active(node._snode.ptr, expr.make_expr_group(indices))
    debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
    return expr.Expr(activity_query, dbg_info=debug_info)
413
+
414
+
415
def activate(node, indices):
    """Explicitly activate a cell of `node` at location `indices`.

    Inserts an activate statement into the callable currently being compiled.

    Args:
        node (:class:`~gstaichi.SNode`): Must be a pointer, hash or bitmasked node.
        indices (Union[int, :class:`~gstaichi.Vector`]): the indices to activate.
    """
    builder = impl.get_runtime().compiling_callable.ast_builder()
    snode_ptr = node._snode.ptr
    index_group = expr.make_expr_group(indices)
    debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
    builder.insert_activate(snode_ptr, index_group, debug_info)
425
+
426
+
427
def deactivate(node, indices):
    """Explicitly deactivate a cell of `node` at location `indices`.

    After deactivation, the GsTaichi runtime automatically recycles and zero-fills
    the memory of the deactivated cell.

    Args:
        node (:class:`~gstaichi.SNode`): Must be a pointer, hash or bitmasked node.
        indices (Union[int, :class:`~gstaichi.Vector`]): the indices to deactivate.
    """
    builder = impl.get_runtime().compiling_callable.ast_builder()
    snode_ptr = node._snode.ptr
    index_group = expr.make_expr_group(indices)
    debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
    builder.insert_deactivate(snode_ptr, index_group, debug_info)
440
+
441
+
442
def length(node, indices):
    """Return the length of the dynamic SNode `node` at index `indices`.

    Args:
        node (:class:`~gstaichi.SNode`): a dynamic SNode.
        indices (Union[int, :class:`~gstaichi.Vector`]): the indices to query.

    Returns:
        int: the length of cell `node[indices]`.
    """
    builder = impl.get_runtime().compiling_callable.ast_builder()
    length_query = builder.expr_snode_length(node._snode.ptr, expr.make_expr_group(indices))
    debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
    return expr.Expr(length_query, dbg_info=debug_info)
458
+
459
+
460
def get_addr(f, indices):
    """Query the memory address (on CUDA/x64) of field `f` at index `indices`.

    Currently, this function can only be called inside a gstaichi kernel.

    Args:
        f (Union[:class:`~gstaichi.Field`, :class:`~gstaichi.MatrixField`]): Input gstaichi field for memory address query.
        indices (Union[int, :class:`~gstaichi.Vector`]): The specified field indices of the query.

    Returns:
        ti.u64: The memory address of `f[indices]`.
    """
    builder = impl.get_runtime().compiling_callable.ast_builder()
    address_query = builder.expr_snode_get_addr(f._snode.ptr, expr.make_expr_group(indices))
    debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
    return expr.Expr(address_query, dbg_info=debug_info)
478
+
479
+
480
# Public names exported by `from <this module> import *`.
__all__ = [
    "activate",
    "append",
    "deactivate",
    "get_addr",
    "is_active",
    "length",
    "rescale_index",
    "SNode",
]
gstaichi/lang/source_builder.py ADDED
@@ -0,0 +1,150 @@
1
+ # type: ignore
2
+
3
+ import atexit
4
+ import ctypes
5
+ import os
6
+ import shutil
7
+ import subprocess
8
+ import tempfile
9
+
10
+ from gstaichi._lib import core as _ti_core
11
+ from gstaichi.lang import impl
12
+ from gstaichi.lang.exception import GsTaichiSyntaxError
13
+ from gstaichi.lang.expr import make_expr_group
14
+ from gstaichi.lang.util import get_clangpp
15
+
16
+
17
class SourceBuilder:
    """Loads external code so that its functions can be called from GsTaichi kernels.

    Instances are created with :meth:`from_file` or :meth:`from_source` and end up
    in one of two modes:

    * ``"bc"``: ``self.bc`` is the path of an LLVM bitcode file (compiled from
      C/C++/CUDA source with clang++, assembled from ``.ll`` with ``llvm-as``, or a
      ``.bc`` file given directly).
    * ``"so"``: ``self.so`` is a :class:`ctypes.CDLL` for a shared library.

    Accessing any other (non-dunder) attribute, e.g. ``builder.my_func``, returns a
    wrapper that, when called, inserts an external function call into the AST
    builder of the callable currently being compiled.
    """

    def __init__(self):
        self.bc = None  # bitcode path ("bc" mode)
        self.so = None  # ctypes.CDLL handle ("so" mode)
        self.mode = None  # "bc", "so", or None until from_file() sets it
        self.td = None  # working directory, removed at interpreter exit

        def cleanup():
            if self.td is not None:
                # The directory may already be gone; never let cleanup raise
                # during interpreter shutdown.
                shutil.rmtree(self.td, ignore_errors=True)

        atexit.register(cleanup)

    @classmethod
    def from_file(cls, filename, compile_fn=None, _temp_dir=None):
        """Create a builder from a file on disk.

        Args:
            filename (str): ``.cpp``/``.c``/``.cc`` or ``.cu`` source, ``.ll`` LLVM IR,
                ``.bc`` bitcode, or a ``.so``/``.dylib``/``.dll`` shared library.
            compile_fn (Callable[[str], str], optional): Custom compiler for source
                files; receives ``filename`` and returns the produced bitcode path.
                Defaults to invoking clang++.
            _temp_dir (str, optional): Working directory to use instead of a fresh
                :func:`tempfile.mkdtemp` directory.

        Returns:
            SourceBuilder: The configured builder.

        Raises:
            GsTaichiSyntaxError: If the file type or the current arch is unsupported.
            subprocess.CalledProcessError: If clang++ or llvm-as exits with an error.
        """
        builder = cls()
        builder.td = _temp_dir if _temp_dir is not None else tempfile.mkdtemp()

        if filename.endswith((".cpp", ".c", ".cc")):
            if impl.current_cfg().arch not in [_ti_core.Arch.x64, _ti_core.Arch.cuda]:
                raise GsTaichiSyntaxError("Unsupported arch for external function call")
            if compile_fn is None:

                def compile_fn_impl(filename):
                    output = os.path.join(builder.td, "source.bc")
                    # Argument list (no shell): paths with spaces or shell
                    # metacharacters are passed through verbatim.
                    cmd = [get_clangpp(), "-flto", "-c", filename, "-o", output]
                    if impl.current_cfg().arch != _ti_core.Arch.x64:
                        # Only x64 and cuda reach here; cuda needs NVPTX bitcode.
                        cmd += ["-target", "nvptx64-nvidia-cuda"]
                    subprocess.check_call(cmd)
                    return output

                compile_fn = compile_fn_impl
            builder.bc = compile_fn(filename)
            builder.mode = "bc"
        elif filename.endswith(".cu"):
            if impl.current_cfg().arch not in [_ti_core.Arch.cuda]:
                raise GsTaichiSyntaxError("Unsupported arch for external function call")
            if compile_fn is None:
                shutil.copy(filename, os.path.join(builder.td, "source.cu"))

                def compile_fn_impl(filename):
                    # Cannot use -o to specify multiple output files; run inside the
                    # working directory so the per-target outputs land there.
                    subprocess.check_call(
                        [
                            get_clangpp(),
                            os.path.join(builder.td, "source.cu"),
                            "-c",
                            "-emit-llvm",
                            "-std=c++17",
                            "--cuda-gpu-arch=sm_50",
                            "-nocudalib",
                        ],
                        cwd=builder.td,
                    )
                    return os.path.join(builder.td, "source-cuda-nvptx64-nvidia-cuda-sm_50.bc")

                compile_fn = compile_fn_impl
            builder.bc = compile_fn(filename)
            builder.mode = "bc"
        elif filename.endswith((".so", ".dylib", ".dll")):
            if impl.current_cfg().arch not in [_ti_core.Arch.x64]:
                raise GsTaichiSyntaxError("Unsupported arch for external function call")
            builder.so = ctypes.CDLL(filename)
            builder.mode = "so"
        elif filename.endswith(".ll"):
            if impl.current_cfg().arch not in [_ti_core.Arch.x64, _ti_core.Arch.cuda]:
                raise GsTaichiSyntaxError("Unsupported arch for external function call")
            output = os.path.join(builder.td, "source.bc")
            subprocess.check_call(["llvm-as", filename, "-o", output])
            builder.bc = output
            builder.mode = "bc"
        elif filename.endswith(".bc"):
            if impl.current_cfg().arch not in [_ti_core.Arch.x64, _ti_core.Arch.cuda]:
                raise GsTaichiSyntaxError("Unsupported arch for external function call")
            builder.bc = filename
            builder.mode = "bc"
        else:
            raise GsTaichiSyntaxError("Unsupported file type for external function call.")
        return builder

    @classmethod
    def from_source(cls, source_code, compile_fn=None):
        """Create a builder from C++ source code given as a string.

        The code is written to ``_temp_source.cpp`` in a fresh temporary directory
        and handed to :meth:`from_file`.

        Args:
            source_code (str): C++ source text.
            compile_fn (Callable[[str], str], optional): See :meth:`from_file`.

        Raises:
            GsTaichiSyntaxError: If the current arch is neither x64 nor cuda.
        """
        if impl.current_cfg().arch not in [_ti_core.Arch.x64, _ti_core.Arch.cuda]:
            raise GsTaichiSyntaxError("Unsupported arch for external function call")
        _temp_dir = tempfile.mkdtemp()
        _temp_source = os.path.join(_temp_dir, "_temp_source.cpp")
        with open(_temp_source, "w", encoding="utf-8") as f:
            f.write(source_code)
        return cls.from_file(_temp_source, compile_fn, _temp_dir)

    def __getattr__(self, item):
        # Special names are never external symbols. Protocols such as copy/pickle
        # probe for e.g. ``__setstate__`` on instances whose __init__ has not run;
        # resolving those through ``self.mode`` would recurse forever.
        if item.startswith("__") and item.endswith("__"):
            raise AttributeError(item)
        # Read via __dict__ so a missing "mode" cannot re-enter __getattr__.
        mode = self.__dict__.get("mode")

        if mode == "bc":

            def bitcode_func_call_wrapper(*args):
                impl.get_runtime().compiling_callable.ast_builder().insert_external_func_call(
                    0,
                    "",
                    self.bc,
                    item,
                    make_expr_group(args),
                    make_expr_group([]),
                    _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
                )

            return bitcode_func_call_wrapper

        if mode == "so":

            def external_func_call_wrapper(args=None, outputs=None):
                args = [] if args is None else args
                outputs = [] if outputs is None else outputs
                # Index access looks the symbol up in the library without touching
                # CDLL's own attributes.
                func_addr = ctypes.cast(self.so[item], ctypes.c_void_p).value
                impl.get_runtime().compiling_callable.ast_builder().insert_external_func_call(
                    func_addr,
                    "",
                    "",
                    "",
                    make_expr_group(args),
                    make_expr_group(outputs),
                    _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
                )

            return external_func_call_wrapper

        raise GsTaichiSyntaxError("Error occurs when calling external function.")
148
+
149
+
150
# Intentionally empty: `from <this module> import *` exports nothing.
__all__ = []