gstaichi 2.1.1rc3__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +4 -0
- gstaichi/__init__.py +40 -0
- gstaichi/_funcs.py +706 -0
- gstaichi/_kernels.py +420 -0
- gstaichi/_lib/__init__.py +3 -0
- gstaichi/_lib/core/__init__.py +0 -0
- gstaichi/_lib/core/gstaichi_python.cpython-310-darwin.so +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
- gstaichi/_lib/core/py.typed +0 -0
- gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
- gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
- gstaichi/_lib/utils.py +243 -0
- gstaichi/_logging.py +131 -0
- gstaichi/_snode/__init__.py +5 -0
- gstaichi/_snode/fields_builder.py +187 -0
- gstaichi/_snode/snode_tree.py +34 -0
- gstaichi/_test_tools/__init__.py +18 -0
- gstaichi/_test_tools/dataclass_test_tools.py +36 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_test_tools/textwrap2.py +6 -0
- gstaichi/_version.py +1 -0
- gstaichi/_version_check.py +100 -0
- gstaichi/ad/__init__.py +3 -0
- gstaichi/ad/_ad.py +530 -0
- gstaichi/algorithms/__init__.py +3 -0
- gstaichi/algorithms/_algorithms.py +117 -0
- gstaichi/assets/.git +1 -0
- gstaichi/assets/Go-Regular.ttf +0 -0
- gstaichi/assets/static/imgs/ti_gallery.png +0 -0
- gstaichi/examples/lcg_python.py +26 -0
- gstaichi/examples/lcg_taichi.py +34 -0
- gstaichi/examples/minimal.py +28 -0
- gstaichi/experimental.py +16 -0
- gstaichi/lang/__init__.py +50 -0
- gstaichi/lang/_dataclass_util.py +31 -0
- gstaichi/lang/_fast_caching/__init__.py +3 -0
- gstaichi/lang/_fast_caching/args_hasher.py +110 -0
- gstaichi/lang/_fast_caching/config_hasher.py +30 -0
- gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
- gstaichi/lang/_fast_caching/function_hasher.py +57 -0
- gstaichi/lang/_fast_caching/hash_utils.py +11 -0
- gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
- gstaichi/lang/_fast_caching/src_hasher.py +75 -0
- gstaichi/lang/_kernel_impl_dataclass.py +212 -0
- gstaichi/lang/_ndarray.py +352 -0
- gstaichi/lang/_ndrange.py +152 -0
- gstaichi/lang/_template_mapper.py +195 -0
- gstaichi/lang/_texture.py +172 -0
- gstaichi/lang/_wrap_inspect.py +215 -0
- gstaichi/lang/any_array.py +99 -0
- gstaichi/lang/ast/__init__.py +5 -0
- gstaichi/lang/ast/ast_transformer.py +1323 -0
- gstaichi/lang/ast/ast_transformer_utils.py +346 -0
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
- gstaichi/lang/ast/checkers.py +106 -0
- gstaichi/lang/ast/symbol_resolver.py +57 -0
- gstaichi/lang/ast/transform.py +9 -0
- gstaichi/lang/common_ops.py +310 -0
- gstaichi/lang/exception.py +80 -0
- gstaichi/lang/expr.py +180 -0
- gstaichi/lang/field.py +428 -0
- gstaichi/lang/impl.py +1243 -0
- gstaichi/lang/kernel_arguments.py +155 -0
- gstaichi/lang/kernel_impl.py +1341 -0
- gstaichi/lang/matrix.py +1835 -0
- gstaichi/lang/matrix_ops.py +341 -0
- gstaichi/lang/matrix_ops_utils.py +190 -0
- gstaichi/lang/mesh.py +687 -0
- gstaichi/lang/misc.py +782 -0
- gstaichi/lang/ops.py +1494 -0
- gstaichi/lang/runtime_ops.py +13 -0
- gstaichi/lang/shell.py +35 -0
- gstaichi/lang/simt/__init__.py +5 -0
- gstaichi/lang/simt/block.py +94 -0
- gstaichi/lang/simt/grid.py +7 -0
- gstaichi/lang/simt/subgroup.py +191 -0
- gstaichi/lang/simt/warp.py +96 -0
- gstaichi/lang/snode.py +489 -0
- gstaichi/lang/source_builder.py +150 -0
- gstaichi/lang/struct.py +810 -0
- gstaichi/lang/util.py +312 -0
- gstaichi/linalg/__init__.py +8 -0
- gstaichi/linalg/matrixfree_cg.py +310 -0
- gstaichi/linalg/sparse_cg.py +59 -0
- gstaichi/linalg/sparse_matrix.py +303 -0
- gstaichi/linalg/sparse_solver.py +123 -0
- gstaichi/math/__init__.py +11 -0
- gstaichi/math/_complex.py +205 -0
- gstaichi/math/mathimpl.py +886 -0
- gstaichi/profiler/__init__.py +6 -0
- gstaichi/profiler/kernel_metrics.py +260 -0
- gstaichi/profiler/kernel_profiler.py +586 -0
- gstaichi/profiler/memory_profiler.py +15 -0
- gstaichi/profiler/scoped_profiler.py +36 -0
- gstaichi/sparse/__init__.py +3 -0
- gstaichi/sparse/_sparse_grid.py +77 -0
- gstaichi/tools/__init__.py +12 -0
- gstaichi/tools/diagnose.py +117 -0
- gstaichi/tools/np2ply.py +364 -0
- gstaichi/tools/vtk.py +38 -0
- gstaichi/types/__init__.py +19 -0
- gstaichi/types/annotations.py +52 -0
- gstaichi/types/compound_types.py +71 -0
- gstaichi/types/enums.py +49 -0
- gstaichi/types/ndarray_type.py +169 -0
- gstaichi/types/primitive_types.py +206 -0
- gstaichi/types/quant.py +88 -0
- gstaichi/types/texture_type.py +85 -0
- gstaichi/types/utils.py +11 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
- gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
- gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
- gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
- gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
- gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
gstaichi/lang/snode.py
ADDED
@@ -0,0 +1,489 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import numbers
|
4
|
+
|
5
|
+
from gstaichi._lib import core as _ti_core
|
6
|
+
from gstaichi._lib.core.gstaichi_python import (
|
7
|
+
Axis,
|
8
|
+
SNodeCxx,
|
9
|
+
)
|
10
|
+
from gstaichi.lang import expr, impl, matrix
|
11
|
+
from gstaichi.lang.exception import GsTaichiRuntimeError
|
12
|
+
from gstaichi.lang.field import BitpackedFields, Field
|
13
|
+
from gstaichi.lang.util import get_traceback
|
14
|
+
|
15
|
+
|
16
|
+
class SNode:
    """A Python-side SNode wrapper.

    For more information on GsTaichi's SNode system, please check out
    these references:

    * https://docs.taichi-lang.org/docs/sparse
    * https://yuanming.gstaichi.graphics/publication/2019-gstaichi/gstaichi-lang.pdf

    Args:
        ptr (SNodeCxx): The C++ side SNode pointer.
    """

    def __init__(self, ptr: SNodeCxx) -> None:
        self.ptr = ptr

    @staticmethod
    def _broadcast_dimensions(axes, dimensions):
        """Expands a scalar `dimensions` to one entry per axis.

        Anything that is not a `numbers.Number` is returned unchanged.
        """
        if isinstance(dimensions, numbers.Number):
            return [dimensions] * len(axes)
        return dimensions

    @staticmethod
    def _require_sparse_extension(kind):
        """Raises GsTaichiRuntimeError if the current arch lacks the sparse extension.

        Args:
            kind (str): Capitalized SNode kind used in the error message.
        """
        if not _ti_core.is_extension_supported(impl.current_cfg().arch, _ti_core.Extension.sparse):
            raise GsTaichiRuntimeError(f"{kind} SNode is not supported on this backend.")

    def dense(self, axes: list[Axis], dimensions: list[int] | int) -> "SNode":
        """Adds a dense SNode as a child component of `self`.

        Args:
            axes (List[Axis]): Axes to activate.
            dimensions (Union[List[int], int]): Shape of each axis.

        Returns:
            The added :class:`~gstaichi.lang.SNode` instance.
        """
        dimensions = self._broadcast_dimensions(axes, dimensions)
        # get_traceback() is called here (not in a helper) so the recorded stack is the caller's.
        return SNode(self.ptr.dense(axes, dimensions, _ti_core.DebugInfo(get_traceback())))

    def pointer(self, axes: list[Axis], dimensions: list[int] | int) -> "SNode":
        """Adds a pointer SNode as a child component of `self`.

        Args:
            axes (List[Axis]): Axes to activate.
            dimensions (Union[List[int], int]): Shape of each axis.

        Returns:
            The added :class:`~gstaichi.lang.SNode` instance.

        Raises:
            GsTaichiRuntimeError: If the current backend does not support the sparse extension.
        """
        self._require_sparse_extension("Pointer")
        dimensions = self._broadcast_dimensions(axes, dimensions)
        return SNode(self.ptr.pointer(axes, dimensions, _ti_core.DebugInfo(get_traceback())))

    @staticmethod
    def _hash(axes, dimensions):
        # original code is #def hash(self,axes, dimensions) without #@staticmethod before fix pylint R0201
        """Not supported."""
        raise RuntimeError("hash not yet supported")
        # if isinstance(dimensions, int):
        #     dimensions = [dimensions] * len(axes)
        # return SNode(self.ptr.hash(axes, dimensions))

    def dynamic(self, axis: list[Axis], dimension: int, chunk_size: int | None = None) -> "SNode":
        """Adds a dynamic SNode as a child component of `self`.

        Args:
            axis (List[Axis]): Axis to activate, must contain exactly one axis.
            dimension (int): Shape of the axis.
            chunk_size (int): Chunk size. Defaults to `dimension` when omitted.

        Returns:
            The added :class:`~gstaichi.lang.SNode` instance.

        Raises:
            GsTaichiRuntimeError: If the current backend does not support the sparse extension.
        """
        self._require_sparse_extension("Dynamic")
        assert len(axis) == 1
        if chunk_size is None:
            chunk_size = dimension
        return SNode(self.ptr.dynamic(axis[0], dimension, chunk_size, _ti_core.DebugInfo(get_traceback())))

    def bitmasked(self, axes: list[Axis], dimensions: list[int] | int) -> "SNode":
        """Adds a bitmasked SNode as a child component of `self`.

        Args:
            axes (List[Axis]): Axes to activate.
            dimensions (Union[List[int], int]): Shape of each axis.

        Returns:
            The added :class:`~gstaichi.lang.SNode` instance.

        Raises:
            GsTaichiRuntimeError: If the current backend does not support the sparse extension.
        """
        self._require_sparse_extension("Bitmasked")
        dimensions = self._broadcast_dimensions(axes, dimensions)
        return SNode(self.ptr.bitmasked(axes, dimensions, _ti_core.DebugInfo(get_traceback())))

    def quant_array(self, axes: list[Axis], dimensions: list[int] | int, max_num_bits: int) -> "SNode":
        """Adds a quant_array SNode as a child component of `self`.

        Args:
            axes (List[Axis]): Axes to activate.
            dimensions (Union[List[int], int]): Shape of each axis.
            max_num_bits (int): Maximum number of bits it can hold.

        Returns:
            The added :class:`~gstaichi.lang.SNode` instance.
        """
        dimensions = self._broadcast_dimensions(axes, dimensions)
        return SNode(self.ptr.quant_array(axes, dimensions, max_num_bits, _ti_core.DebugInfo(get_traceback())))

    def place(self, *args, offset: numbers.Number | tuple[numbers.Number] | None = None) -> "SNode":
        """Places a list of GsTaichi fields under the `self` container.

        Args:
            *args (List[ti.field]): A list of GsTaichi fields to place. Each item may be a
                `BitpackedFields`, a `Field`, or a list of such items (placed recursively).
            offset (Union[Number, tuple[Number]]): Offset of the field domain.

        Returns:
            The `self` container.

        Raises:
            ValueError: If an argument is none of the supported kinds.
        """
        # Normalize `offset` to a tuple: None -> (), scalar -> 1-tuple.
        if offset is None:
            offset = ()
        if isinstance(offset, numbers.Number):
            offset = (offset,)

        for arg in args:
            if isinstance(arg, BitpackedFields):
                bit_struct_type = arg.bit_struct_type_builder.build()
                bit_struct_snode = self.ptr.bit_struct(bit_struct_type, _ti_core.DebugInfo(get_traceback()))
                for field, id_in_bit_struct in arg.fields:
                    bit_struct_snode.place(field, offset, id_in_bit_struct)
            elif isinstance(arg, Field):
                for var in arg._get_field_members():
                    self.ptr.place(var.ptr, offset, -1)
            elif isinstance(arg, list):
                for x in arg:
                    self.place(x, offset=offset)
            else:
                raise ValueError(f"{arg} cannot be placed")
        return self

    def lazy_grad(self):
        """Automatically place the adjoint fields following the layout of their primal fields.

        Users don't need to specify ``needs_grad`` when they define scalar/vector/matrix fields (primal fields) using autodiff.
        When all the primal fields are defined, using ``gstaichi.root.lazy_grad()`` could automatically generate
        their corresponding adjoint fields (gradient field).

        To know more details about primal, adjoint fields and ``lazy_grad()``,
        please see Page 4 and Page 13-14 of DiffGsTaichi Paper: https://arxiv.org/pdf/1910.00935.pdf
        """
        self.ptr.lazy_grad()

    def lazy_dual(self):
        """Automatically place the dual fields following the layout of their primal fields."""
        self.ptr.lazy_dual()

    def _allocate_adjoint_checkbit(self):
        """Automatically place the adjoint flag fields following the layout of their primal fields for global data access rule checker"""
        self.ptr.allocate_adjoint_checkbit()

    def parent(self, n=1):
        """Gets an ancestor of `self` in the SNode tree.

        Args:
            n (int): the number of levels going up from `self`.

        Returns:
            Union[None, _Root, SNode]: The n-th parent of `self`.
        """
        p = self.ptr
        while p and n > 0:
            p = p.parent
            n -= 1
        if p is None:
            return None

        # The root node is represented by the module-level `impl.root` wrapper.
        if p.type == _ti_core.SNodeType.root:
            return impl.root

        return SNode(p)

    def _path_from_root(self):
        """Gets the path from root to `self` in the SNode tree.

        Returns:
            List[Union[_Root, SNode]]: The list of SNodes on the path from root to `self`.
        """
        p = self
        res = [p]
        while p != impl.root:
            p = p.parent()
            res.append(p)
        res.reverse()
        return res

    @property
    def _dtype(self):
        """Gets the data type of `self`.

        Returns:
            DataType: The data type of `self`.
        """
        return self.ptr.data_type()

    @property
    def _id(self):
        """Gets the id of `self`.

        Returns:
            int: The id of `self`.
        """
        return self.ptr.id

    @property
    def _snode_tree_id(self):
        """Gets the id of the SNode tree that `self` belongs to, as reported by the C++ side."""
        return self.ptr.get_snode_tree_id()

    @property
    def shape(self):
        """Gets the number of elements from root in each axis of `self`.

        Returns:
            Tuple[int]: The number of elements from root in each axis of `self`.
        """
        dim = self.ptr.num_active_indices()
        return tuple(self.ptr.get_shape_along_axis(i) for i in range(dim))

    def _loop_range(self):
        """Gets the gstaichi_python.SNode to serve as loop range.

        Returns:
            gstaichi_python.SNode: See above.
        """
        return self.ptr

    @property
    def _name(self):
        """Gets the name of `self`.

        Returns:
            str: The name of `self`.
        """
        return self.ptr.name()

    @property
    def _snode(self):
        """Gets `self`.

        Returns:
            SNode: `self`.
        """
        return self

    def _get_children(self):
        """Gets all children components of `self`.

        Returns:
            List[SNode]: All children components of `self`.
        """
        return [SNode(self.ptr.get_ch(i)) for i in range(self.ptr.get_num_ch())]

    @property
    def _num_dynamically_allocated(self):
        """Queries the program for the number of dynamically allocated cells of `self`.

        The root fields builder is materialized first.
        """
        runtime = impl.get_runtime()
        runtime.materialize_root_fb(False)
        return runtime.prog.get_snode_num_dynamically_allocated(self.ptr)

    @property
    def _cell_size_bytes(self):
        """Gets `cell_size_bytes` of the C++ SNode after materializing the root fields builder."""
        impl.get_runtime().materialize_root_fb(False)
        return self.ptr.cell_size_bytes

    @property
    def _offset_bytes_in_parent_cell(self):
        """Gets `offset_bytes_in_parent_cell` of the C++ SNode after materializing the root fields builder."""
        impl.get_runtime().materialize_root_fb(False)
        return self.ptr.offset_bytes_in_parent_cell

    def deactivate_all(self):
        """Recursively deactivate all children components of `self`."""
        # Children are handled first, then `self`.
        for c in self._get_children():
            c.deactivate_all()
        SNodeType = _ti_core.SNodeType
        if self.ptr.type == SNodeType.pointer or self.ptr.type == SNodeType.bitmasked:
            from gstaichi._kernels import snode_deactivate  # pylint: disable=C0415

            snode_deactivate(self)
        if self.ptr.type == SNodeType.dynamic:
            # Note that dynamic nodes are different from other sparse nodes:
            # instead of deactivating each element, we only need to deactivate
            # its parent, whose linked list of chunks of elements will be deleted.
            from gstaichi._kernels import (  # pylint: disable=C0415
                snode_deactivate_dynamic,
            )

            snode_deactivate_dynamic(self)

    def _type_name(self):
        """Returns the string form of `self.ptr.type` with the leading ``SNodeType.`` characters sliced off."""
        return str(self.ptr.type)[len("SNodeType.") :]

    def __repr__(self):
        return f"<ti.SNode of type {self._type_name()}>"

    def __str__(self):
        # ti.root.dense(ti.i, 3).dense(ti.jk, (4, 5)).place(x)
        # ti.root => dense [3] => dense [3, 4, 5] => place [3, 4, 5]
        type_ = self._type_name()
        shape = str(list(self.shape))
        parent = str(self.parent())
        return f"{parent} => {type_} {shape}"

    def __eq__(self, other):
        """Compares the wrapped C++ pointers.

        Any object exposing a `ptr` attribute is comparable (the check is duck-typed, so
        wrappers that forward `ptr` keep working). Objects without `ptr`, such as `None`,
        yield `NotImplemented` so Python falls back to its default comparison instead of
        raising AttributeError.

        NOTE: the class defines `__eq__` without `__hash__`, so instances are unhashable.
        """
        try:
            other_ptr = other.ptr
        except AttributeError:
            return NotImplemented
        return self.ptr == other_ptr

    def _physical_index_position(self):
        """Gets mappings from virtual axes to physical axes.

        Returns:
            Dict[int, int]: Mappings from virtual axes to physical axes.
        """
        # Entries equal to -1 are skipped.
        return {
            virtual: physical
            for virtual, physical in enumerate(self.ptr.get_physical_index_position())
            if physical != -1
        }
|
338
|
+
|
339
|
+
|
340
|
+
def rescale_index(a, b, I):
    """Rescales the index 'I' of field (or SNode) 'a' to match the shape of SNode 'b'.

    For every leading axis shared by `I`, `a.shape` and `b.shape`, the index is
    floor-divided by ``a.shape[i] // b.shape[i]`` when `a` is larger along that axis,
    or multiplied by ``b.shape[i] // a.shape[i]`` when `a` is smaller. Other
    components of `I` are copied unchanged.

    Args:

        a, b (Union[:class:`~gstaichi.Field`, :class:`~gstaichi.MatrixField`]): Input gstaichi fields or snodes.
        I (Union[list, :class:`~gstaichi.Vector`]): grouped loop index.

    Returns:
        Ib (:class:`~gstaichi.Vector`): rescaled grouped loop index
    """
    assert isinstance(a, (Field, SNode)), "The first argument must be a field or an SNode"
    assert isinstance(b, (Field, SNode)), "The second argument must be a field or an SNode"

    # Number of index components: list length, or the `n` attribute of an Expr/Matrix.
    if not isinstance(I, list):
        assert isinstance(
            I, (expr.Expr, matrix.Matrix)
        ), "The third argument must be an index (list, ti.Vector, or Expr with TensorType)"
        n = I.n
    else:
        n = len(I)

    from gstaichi.lang.kernel_impl import pyfunc  # pylint: disable=C0415

    # The body below is processed by `pyfunc`; it is kept exactly as written.
    @pyfunc
    def _rescale_index():
        result = matrix.Vector([I[i] for i in range(n)])
        for i in impl.static(range(min(n, min(len(a.shape), len(b.shape))))):
            if a.shape[i] > b.shape[i]:
                result[i] = I[i] // (a.shape[i] // b.shape[i])
            if a.shape[i] < b.shape[i]:
                result[i] = I[i] * (b.shape[i] // a.shape[i])
        return result

    return _rescale_index()
|
375
|
+
|
376
|
+
|
377
|
+
def append(node, indices, val):
    """Append a value `val` to a SNode `node` at index `indices`.

    Args:
        node (:class:`~gstaichi.SNode`): Input SNode.
        indices (Union[int, :class:`~gstaichi.Vector`]): the indices to visit.
        val (Union[:mod:`~gstaichi.types.primitive_types`, :mod:`~gstaichi.types.compound_types`]): the data to be appended.

    Returns:
        The result of `impl.expr_init` applied to the append expression.
    """
    flattened = expr._get_flattened_ptrs(val)
    builder = impl.get_runtime().compiling_callable.ast_builder()
    raw_expr = builder.expr_snode_append(node._snode.ptr, expr.make_expr_group(indices), flattened)
    debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
    return impl.expr_init(expr.Expr(raw_expr, dbg_info=debug_info))
|
394
|
+
|
395
|
+
|
396
|
+
def is_active(node, indices):
    """Explicitly query whether a cell in a SNode `node` at location
    `indices` is active or not.

    Args:
        node (:class:`~gstaichi.SNode`): Must be a pointer, hash or bitmasked node.
        indices (Union[int, list, :class:`~gstaichi.Vector`]): the indices to visit.

    Returns:
        bool: the cell `node[indices]` is active or not.
    """
    builder = impl.get_runtime().compiling_callable.ast_builder()
    raw_expr = builder.expr_snode_is_active(node._snode.ptr, expr.make_expr_group(indices))
    debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
    return expr.Expr(raw_expr, dbg_info=debug_info)
|
413
|
+
|
414
|
+
|
415
|
+
def activate(node, indices):
    """Explicitly activate a cell of `node` at location `indices`.

    Args:
        node (:class:`~gstaichi.SNode`): Must be a pointer, hash or bitmasked node.
        indices (Union[int, :class:`~gstaichi.Vector`]): the indices to activate.
    """
    builder = impl.get_runtime().compiling_callable.ast_builder()
    snode_ptr = node._snode.ptr
    index_group = expr.make_expr_group(indices)
    debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
    builder.insert_activate(snode_ptr, index_group, debug_info)
|
425
|
+
|
426
|
+
|
427
|
+
def deactivate(node, indices):
    """Explicitly deactivate a cell of `node` at location `indices`.

    After deactivation, the GsTaichi runtime automatically recycles and zero-fills
    the memory of the deactivated cell.

    Args:
        node (:class:`~gstaichi.SNode`): Must be a pointer, hash or bitmasked node.
        indices (Union[int, :class:`~gstaichi.Vector`]): the indices to deactivate.
    """
    builder = impl.get_runtime().compiling_callable.ast_builder()
    snode_ptr = node._snode.ptr
    index_group = expr.make_expr_group(indices)
    debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
    builder.insert_deactivate(snode_ptr, index_group, debug_info)
|
440
|
+
|
441
|
+
|
442
|
+
def length(node, indices):
    """Return the length of the dynamic SNode `node` at index `indices`.

    Args:
        node (:class:`~gstaichi.SNode`): a dynamic SNode.
        indices (Union[int, :class:`~gstaichi.Vector`]): the indices to query.

    Returns:
        int: the length of cell `node[indices]`.
    """
    builder = impl.get_runtime().compiling_callable.ast_builder()
    raw_expr = builder.expr_snode_length(node._snode.ptr, expr.make_expr_group(indices))
    debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
    return expr.Expr(raw_expr, dbg_info=debug_info)
|
458
|
+
|
459
|
+
|
460
|
+
def get_addr(f, indices):
    """Query the memory address (on CUDA/x64) of field `f` at index `indices`.

    Currently, this function can only be called inside a gstaichi kernel.

    Args:
        f (Union[:class:`~gstaichi.Field`, :class:`~gstaichi.MatrixField`]): Input gstaichi field for memory address query.
        indices (Union[int, :class:`~gstaichi.Vector`]): The specified field indices of the query.

    Returns:
        ti.u64: The memory address of `f[indices]`.
    """
    builder = impl.get_runtime().compiling_callable.ast_builder()
    raw_expr = builder.expr_snode_get_addr(f._snode.ptr, expr.make_expr_group(indices))
    debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
    return expr.Expr(raw_expr, dbg_info=debug_info)
|
478
|
+
|
479
|
+
|
480
|
+
# Names exported by a star-import of this module.
__all__ = [
    "activate",
    "append",
    "deactivate",
    "get_addr",
    "is_active",
    "length",
    "rescale_index",
    "SNode",
]
|
@@ -0,0 +1,150 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import atexit
|
4
|
+
import ctypes
|
5
|
+
import os
|
6
|
+
import shutil
|
7
|
+
import subprocess
|
8
|
+
import tempfile
|
9
|
+
|
10
|
+
from gstaichi._lib import core as _ti_core
|
11
|
+
from gstaichi.lang import impl
|
12
|
+
from gstaichi.lang.exception import GsTaichiSyntaxError
|
13
|
+
from gstaichi.lang.expr import make_expr_group
|
14
|
+
from gstaichi.lang.util import get_clangpp
|
15
|
+
|
16
|
+
|
17
|
+
def _require_arch(*supported):
    """Raises GsTaichiSyntaxError unless the current arch is one of `supported`."""
    if impl.current_cfg().arch not in supported:
        raise GsTaichiSyntaxError("Unsupported arch for external function call")


class SourceBuilder:
    """Wraps external code so that its functions can be inserted as external calls.

    Instances are normally created through :meth:`from_file` or :meth:`from_source`.
    Afterwards, accessing any non-dunder attribute name (e.g. ``builder.my_func``)
    returns a wrapper that inserts an external function call with that name into the
    AST of the callable currently being compiled.

    Attributes are documented next to their initialization in ``__init__``.
    """

    def __init__(self):
        # Path of the LLVM bitcode file used when `mode == "bc"`.
        self.bc = None
        # `ctypes.CDLL` handle used when `mode == "so"`.
        self.so = None
        # Either "bc", "so", or None while nothing has been loaded.
        self.mode = None
        # Temporary working directory; removed at interpreter exit.
        self.td = None

        def cleanup():
            if self.td is not None:
                # ignore_errors: the directory may already be gone at exit time.
                shutil.rmtree(self.td, ignore_errors=True)

        atexit.register(cleanup)

    @classmethod
    def from_file(cls, filename, compile_fn=None, _temp_dir=None):
        """Builds a SourceBuilder from a file, dispatching on its extension.

        * ``.cpp`` / ``.c`` / ``.cc``: compiled to bitcode (x64 or cuda arch).
        * ``.cu``: compiled to bitcode (cuda arch only).
        * ``.so`` / ``.dylib`` / ``.dll``: loaded with ``ctypes.CDLL`` (x64 arch only).
        * ``.ll``: assembled with ``llvm-as`` (x64 or cuda arch).
        * ``.bc``: used as is (x64 or cuda arch).

        Args:
            filename (str): Path of the input file.
            compile_fn (Callable[[str], str] | None): Optional replacement for the
                built-in compile step of C/C++/CUDA sources. It receives `filename`
                and must return the path of the produced bitcode file.
            _temp_dir (str | None): Existing working directory to use; a fresh one
                is created with ``tempfile.mkdtemp()`` when omitted.

        Returns:
            SourceBuilder: The configured instance.

        Raises:
            GsTaichiSyntaxError: On an unsupported arch or file type.

        Note:
            External tools are launched with an argument list (no shell), so paths
            containing spaces or shell metacharacters are passed through verbatim.
            Their exit status is not checked. A missing executable surfaces as the
            ``OSError`` raised by :mod:`subprocess`.
        """
        self = cls()
        self.td = _temp_dir
        if self.td is None:
            self.td = tempfile.mkdtemp()

        if filename.endswith((".cpp", ".c", ".cc")):
            _require_arch(_ti_core.Arch.x64, _ti_core.Arch.cuda)
            if compile_fn is None:

                def compile_cpp(source_path):
                    bitcode_path = os.path.join(self.td, "source.bc")
                    command = [get_clangpp(), "-flto", "-c", source_path, "-o", bitcode_path]
                    if impl.current_cfg().arch != _ti_core.Arch.x64:
                        command += ["-target", "nvptx64-nvidia-cuda"]
                    subprocess.call(command)
                    return bitcode_path

                compile_fn = compile_cpp
            self.bc = compile_fn(filename)
            self.mode = "bc"
        elif filename.endswith(".cu"):
            _require_arch(_ti_core.Arch.cuda)
            if compile_fn is None:
                cuda_source = os.path.join(self.td, "source.cu")
                shutil.copy(filename, cuda_source)

                def compile_cuda(_source_path):
                    # Cannot use -o to specify multiple output files
                    subprocess.call(
                        [
                            get_clangpp(),
                            cuda_source,
                            "-c",
                            "-emit-llvm",
                            "-std=c++17",
                            "--cuda-gpu-arch=sm_50",
                            "-nocudalib",
                        ],
                        cwd=self.td,
                    )
                    return os.path.join(self.td, "source-cuda-nvptx64-nvidia-cuda-sm_50.bc")

                compile_fn = compile_cuda
            self.bc = compile_fn(filename)
            self.mode = "bc"
        elif filename.endswith((".so", ".dylib", ".dll")):
            _require_arch(_ti_core.Arch.x64)
            self.so = ctypes.CDLL(filename)
            self.mode = "so"
        elif filename.endswith(".ll"):
            _require_arch(_ti_core.Arch.x64, _ti_core.Arch.cuda)
            bitcode_path = os.path.join(self.td, "source.bc")
            subprocess.call(["llvm-as", filename, "-o", bitcode_path])
            self.bc = bitcode_path
            self.mode = "bc"
        elif filename.endswith(".bc"):
            _require_arch(_ti_core.Arch.x64, _ti_core.Arch.cuda)
            self.bc = filename
            self.mode = "bc"
        else:
            raise GsTaichiSyntaxError("Unsupported file type for external function call.")
        return self

    @classmethod
    def from_source(cls, source_code, compile_fn=None):
        """Builds a SourceBuilder from C++ source text.

        The text is written to ``_temp_source.cpp`` inside a fresh temporary
        directory and then handed to :meth:`from_file`.

        Args:
            source_code (str): The source text.
            compile_fn (Callable[[str], str] | None): Forwarded to :meth:`from_file`.

        Returns:
            SourceBuilder: The configured instance.

        Raises:
            GsTaichiSyntaxError: If the current arch is neither x64 nor cuda.
        """
        _require_arch(_ti_core.Arch.x64, _ti_core.Arch.cuda)
        temp_dir = tempfile.mkdtemp()
        temp_source = os.path.join(temp_dir, "_temp_source.cpp")
        # Explicit encoding so the written bytes do not depend on the platform locale.
        with open(temp_source, "w", encoding="utf-8") as f:
            f.write(source_code)
        return cls.from_file(temp_source, compile_fn, temp_dir)

    def __getattr__(self, item):
        """Returns a wrapper that inserts an external call to the function named `item`.

        In "bc" mode the wrapper takes positional arguments; in "so" mode it takes
        `args` and `outputs` sequences (both default to empty).

        Raises:
            AttributeError: For dunder names, so protocol probes (e.g. by `copy`,
                `pickle` or `hasattr`) behave as for an ordinary object.
            GsTaichiSyntaxError: If nothing has been loaded yet.
        """
        if item.startswith("__") and item.endswith("__"):
            raise AttributeError(item)

        # Read via __dict__: `self.mode` would re-enter __getattr__ (and recurse
        # forever) on an instance whose __init__ has not run.
        mode = self.__dict__.get("mode")

        if mode == "bc":

            def bitcode_func_call_wrapper(*args):
                impl.get_runtime().compiling_callable.ast_builder().insert_external_func_call(
                    0,
                    "",
                    self.bc,
                    item,
                    make_expr_group(args),
                    make_expr_group([]),
                    _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
                )

            return bitcode_func_call_wrapper

        if mode == "so":

            def external_func_call_wrapper(args=None, outputs=None):
                args = [] if args is None else args
                outputs = [] if outputs is None else outputs
                func_addr = ctypes.cast(getattr(self.so, item), ctypes.c_void_p).value
                impl.get_runtime().compiling_callable.ast_builder().insert_external_func_call(
                    func_addr,
                    "",
                    "",
                    "",
                    make_expr_group(args),
                    make_expr_group(outputs),
                    _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
                )

            return external_func_call_wrapper

        raise GsTaichiSyntaxError("Error occurs when calling external function.")
|
148
|
+
|
149
|
+
|
150
|
+
# Empty list: a star-import of this module exports no names.
__all__ = []
|