gstaichi 2.1.1rc3__cp310-cp310-macosx_11_0_arm64.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries. The information is provided for informational purposes only and reflects the changes between package versions.
- gstaichi/CHANGELOG.md +4 -0
- gstaichi/__init__.py +40 -0
- gstaichi/_funcs.py +706 -0
- gstaichi/_kernels.py +420 -0
- gstaichi/_lib/__init__.py +3 -0
- gstaichi/_lib/core/__init__.py +0 -0
- gstaichi/_lib/core/gstaichi_python.cpython-310-darwin.so +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
- gstaichi/_lib/core/py.typed +0 -0
- gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
- gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
- gstaichi/_lib/utils.py +243 -0
- gstaichi/_logging.py +131 -0
- gstaichi/_snode/__init__.py +5 -0
- gstaichi/_snode/fields_builder.py +187 -0
- gstaichi/_snode/snode_tree.py +34 -0
- gstaichi/_test_tools/__init__.py +18 -0
- gstaichi/_test_tools/dataclass_test_tools.py +36 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_test_tools/textwrap2.py +6 -0
- gstaichi/_version.py +1 -0
- gstaichi/_version_check.py +100 -0
- gstaichi/ad/__init__.py +3 -0
- gstaichi/ad/_ad.py +530 -0
- gstaichi/algorithms/__init__.py +3 -0
- gstaichi/algorithms/_algorithms.py +117 -0
- gstaichi/assets/.git +1 -0
- gstaichi/assets/Go-Regular.ttf +0 -0
- gstaichi/assets/static/imgs/ti_gallery.png +0 -0
- gstaichi/examples/lcg_python.py +26 -0
- gstaichi/examples/lcg_taichi.py +34 -0
- gstaichi/examples/minimal.py +28 -0
- gstaichi/experimental.py +16 -0
- gstaichi/lang/__init__.py +50 -0
- gstaichi/lang/_dataclass_util.py +31 -0
- gstaichi/lang/_fast_caching/__init__.py +3 -0
- gstaichi/lang/_fast_caching/args_hasher.py +110 -0
- gstaichi/lang/_fast_caching/config_hasher.py +30 -0
- gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
- gstaichi/lang/_fast_caching/function_hasher.py +57 -0
- gstaichi/lang/_fast_caching/hash_utils.py +11 -0
- gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
- gstaichi/lang/_fast_caching/src_hasher.py +75 -0
- gstaichi/lang/_kernel_impl_dataclass.py +212 -0
- gstaichi/lang/_ndarray.py +352 -0
- gstaichi/lang/_ndrange.py +152 -0
- gstaichi/lang/_template_mapper.py +195 -0
- gstaichi/lang/_texture.py +172 -0
- gstaichi/lang/_wrap_inspect.py +215 -0
- gstaichi/lang/any_array.py +99 -0
- gstaichi/lang/ast/__init__.py +5 -0
- gstaichi/lang/ast/ast_transformer.py +1323 -0
- gstaichi/lang/ast/ast_transformer_utils.py +346 -0
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
- gstaichi/lang/ast/checkers.py +106 -0
- gstaichi/lang/ast/symbol_resolver.py +57 -0
- gstaichi/lang/ast/transform.py +9 -0
- gstaichi/lang/common_ops.py +310 -0
- gstaichi/lang/exception.py +80 -0
- gstaichi/lang/expr.py +180 -0
- gstaichi/lang/field.py +428 -0
- gstaichi/lang/impl.py +1243 -0
- gstaichi/lang/kernel_arguments.py +155 -0
- gstaichi/lang/kernel_impl.py +1341 -0
- gstaichi/lang/matrix.py +1835 -0
- gstaichi/lang/matrix_ops.py +341 -0
- gstaichi/lang/matrix_ops_utils.py +190 -0
- gstaichi/lang/mesh.py +687 -0
- gstaichi/lang/misc.py +782 -0
- gstaichi/lang/ops.py +1494 -0
- gstaichi/lang/runtime_ops.py +13 -0
- gstaichi/lang/shell.py +35 -0
- gstaichi/lang/simt/__init__.py +5 -0
- gstaichi/lang/simt/block.py +94 -0
- gstaichi/lang/simt/grid.py +7 -0
- gstaichi/lang/simt/subgroup.py +191 -0
- gstaichi/lang/simt/warp.py +96 -0
- gstaichi/lang/snode.py +489 -0
- gstaichi/lang/source_builder.py +150 -0
- gstaichi/lang/struct.py +810 -0
- gstaichi/lang/util.py +312 -0
- gstaichi/linalg/__init__.py +8 -0
- gstaichi/linalg/matrixfree_cg.py +310 -0
- gstaichi/linalg/sparse_cg.py +59 -0
- gstaichi/linalg/sparse_matrix.py +303 -0
- gstaichi/linalg/sparse_solver.py +123 -0
- gstaichi/math/__init__.py +11 -0
- gstaichi/math/_complex.py +205 -0
- gstaichi/math/mathimpl.py +886 -0
- gstaichi/profiler/__init__.py +6 -0
- gstaichi/profiler/kernel_metrics.py +260 -0
- gstaichi/profiler/kernel_profiler.py +586 -0
- gstaichi/profiler/memory_profiler.py +15 -0
- gstaichi/profiler/scoped_profiler.py +36 -0
- gstaichi/sparse/__init__.py +3 -0
- gstaichi/sparse/_sparse_grid.py +77 -0
- gstaichi/tools/__init__.py +12 -0
- gstaichi/tools/diagnose.py +117 -0
- gstaichi/tools/np2ply.py +364 -0
- gstaichi/tools/vtk.py +38 -0
- gstaichi/types/__init__.py +19 -0
- gstaichi/types/annotations.py +52 -0
- gstaichi/types/compound_types.py +71 -0
- gstaichi/types/enums.py +49 -0
- gstaichi/types/ndarray_type.py +169 -0
- gstaichi/types/primitive_types.py +206 -0
- gstaichi/types/quant.py +88 -0
- gstaichi/types/texture_type.py +85 -0
- gstaichi/types/utils.py +11 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
- gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
- gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
- gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
- gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
- gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
gstaichi/lang/util.py
ADDED
@@ -0,0 +1,312 @@
+import functools
+import os
+import traceback
+import warnings
+from typing import Any
+
+import numpy as np
+from colorama import Fore, Style
+
+from gstaichi._lib import core as _ti_core
+from gstaichi._logging import is_logging_effective
+from gstaichi.lang import impl
+from gstaichi.types import Template
+from gstaichi.types.primitive_types import (
+    f16,
+    f32,
+    f64,
+    i8,
+    i16,
+    i32,
+    i64,
+    u1,
+    u8,
+    u16,
+    u32,
+    u64,
+)
+
+
+def has_pytorch():
+    """Whether has pytorch in the current Python environment.
+
+    Returns:
+        bool: True if has pytorch else False.
+
+    """
+    _has_pytorch = False
+    _env_torch = os.environ.get("TI_ENABLE_TORCH", "1")
+    if not _env_torch or int(_env_torch):
+        try:
+            import torch  # pylint: disable=C0415
+
+            _has_pytorch = True
+        except:
+            pass
+    return _has_pytorch
+
+
+def get_clangpp():
+    from distutils.spawn import find_executable  # pylint: disable=C0415
+
+    # GsTaichi itself uses llvm-10.0.0 to compile.
+    # There will be some issues compiling CUDA with other clang++ version.
+    _clangpp_candidates = ["clang++-10"]
+    for c in _clangpp_candidates:
+        if find_executable(c) is not None:
+            _clangpp_presence = find_executable(c)
+            return _clangpp_presence
+    return None
+
+
+def has_clangpp():
+    return get_clangpp() is not None
+
+
+def is_matrix_class(rhs):
+    matrix_class = False
+    try:
+        if rhs._is_matrix_class:
+            matrix_class = True
+    except:
+        pass
+    return matrix_class
+
+
+def is_gstaichi_class(rhs):
+    gstaichi_class = False
+    try:
+        if rhs._is_gstaichi_class:
+            gstaichi_class = True
+    except:
+        pass
+    return gstaichi_class
+
+
+def to_numpy_type(dt):
+    """Convert gstaichi data type to its counterpart in numpy.
+
+    Args:
+        dt (DataType): The desired data type to convert.
+
+    Returns:
+        DataType: The counterpart data type in numpy.
+
+    """
+    if dt == f32:
+        return np.float32
+    if dt == f64:
+        return np.float64
+    if dt == i32:
+        return np.int32
+    if dt == i64:
+        return np.int64
+    if dt == i8:
+        return np.int8
+    if dt == i16:
+        return np.int16
+    if dt == u1:
+        return np.bool_
+    if dt == u8:
+        return np.uint8
+    if dt == u16:
+        return np.uint16
+    if dt == u32:
+        return np.uint32
+    if dt == u64:
+        return np.uint64
+    if dt == f16:
+        return np.half
+    assert False
+
+
+def to_pytorch_type(dt):
+    """Convert gstaichi data type to its counterpart in torch.
+
+    Args:
+        dt (DataType): The desired data type to convert.
+
+    Returns:
+        DataType: The counterpart data type in torch.
+
+    """
+    import torch  # pylint: disable=C0415
+
+    # pylint: disable=E1101
+    if dt == f32:
+        return torch.float32
+    if dt == f64:
+        return torch.float64
+    if dt == i32:
+        return torch.int32
+    if dt == i64:
+        return torch.int64
+    if dt == i8:
+        return torch.int8
+    if dt == i16:
+        return torch.int16
+    if dt == u1:
+        return torch.bool
+    if dt == u8:
+        return torch.uint8
+    if dt == f16:
+        return torch.float16
+
+    if dt in (u16, u32, u64):
+        if hasattr(torch, "uint16"):
+            if dt == u16:
+                return torch.uint16
+            if dt == u32:
+                return torch.uint32
+            if dt == u64:
+                return torch.uint64
+        raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.")
+
+    raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type.")
+
+
+def to_gstaichi_type(dt):
+    """Convert numpy or torch data type to its counterpart in gstaichi.
+
+    Args:
+        dt (DataType): The desired data type to convert.
+
+    Returns:
+        DataType: The counterpart data type in gstaichi.
+
+    """
+    if type(dt) == _ti_core.DataTypeCxx:
+        return dt
+
+    if dt == np.float32:
+        return f32
+    if dt == np.float64:
+        return f64
+    if dt == np.int32:
+        return i32
+    if dt == np.int64:
+        return i64
+    if dt == np.int8:
+        return i8
+    if dt == np.int16:
+        return i16
+    if dt == np.bool_:
+        return u1
+    if dt == np.uint8:
+        return u8
+    if dt == np.uint16:
+        return u16
+    if dt == np.uint32:
+        return u32
+    if dt == np.uint64:
+        return u64
+    if dt == np.half:
+        return f16
+
+    if has_pytorch():
+        import torch  # pylint: disable=C0415
+
+        # pylint: disable=E1101
+        if dt == torch.float32:
+            return f32
+        if dt == torch.float64:
+            return f64
+        if dt == torch.int32:
+            return i32
+        if dt == torch.int64:
+            return i64
+        if dt == torch.int8:
+            return i8
+        if dt == torch.int16:
+            return i16
+        if dt == torch.bool:
+            return u1
+        if dt == torch.uint8:
+            return u8
+        if dt == torch.float16:
+            return f16
+
+        if hasattr(torch, "uint16"):
+            if dt == torch.uint16:
+                return u16
+            if dt == torch.uint32:
+                return u32
+            if dt == torch.uint64:
+                return u64
+
+        raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.")
+
+    raise AssertionError(f"Unknown type {dt}")
+
+
+def cook_dtype(dtype):
+    if isinstance(dtype, _ti_core.DataTypeCxx):
+        return dtype
+    if isinstance(dtype, _ti_core.Type):
+        return _ti_core.DataTypeCxx(dtype)
+    if dtype is float:
+        return impl.get_runtime().default_fp
+    if dtype is int:
+        return impl.get_runtime().default_ip
+    if dtype is bool:
+        return u1
+    raise ValueError(f"Invalid data type {dtype}")
+
+
+def in_gstaichi_scope():
+    return impl.inside_kernel()
+
+
+def in_python_scope():
+    return not in_gstaichi_scope()
+
+
+def gstaichi_scope(func):
+    @functools.wraps(func)
+    def wrapped(*args, **kwargs):
+        assert in_gstaichi_scope(), f"{func.__name__} cannot be called in Python-scope"
+        return func(*args, **kwargs)
+
+    return wrapped
+
+
+def python_scope(func):
+    @functools.wraps(func)
+    def wrapped(*args, **kwargs):
+        assert in_python_scope(), f"{func.__name__} cannot be called in GsTaichi-scope"
+        return func(*args, **kwargs)
+
+    return wrapped
+
+
+def warning(msg, warning_type=UserWarning, stacklevel=1, print_stack=True):
+    """Print a warning message. Note that the builtin `warnings` module is
+    unreliable since it may be suppressed by other packages such as IPython.
+
+    Args:
+        msg (str): message to print.
+        warning_type (Type[Warning]): type of warning.
+        stacklevel (int): warning stack level from the caller.
+        print_stack (bool): whether to print the stack
+    """
+    if not is_logging_effective("warn"):
+        return
+    if print_stack:
+        msg += f"\n{get_traceback(stacklevel)}"
+    warnings.warn(Fore.YELLOW + Style.BRIGHT + msg + Style.RESET_ALL, warning_type)
+
+
+def get_traceback(stacklevel=1):
+    s = traceback.extract_stack()[: -1 - stacklevel]
+    return "".join(traceback.format_list(s))
+
+
+def is_data_oriented(obj: Any) -> bool:
+    return getattr(obj, "_data_oriented", False)
+
+
+def is_ti_template(annotation: Any) -> bool:
+    return annotation == Template or isinstance(annotation, Template)
+
+
+__all__ = []
gstaichi/linalg/__init__.py
ADDED
@@ -0,0 +1,8 @@
+# type: ignore
+
+"""GsTaichi support module for sparse matrix operations."""
+
+from gstaichi.linalg.matrixfree_cg import *
+from gstaichi.linalg.sparse_cg import SparseCG
+from gstaichi.linalg.sparse_matrix import *
+from gstaichi.linalg.sparse_solver import SparseSolver
gstaichi/linalg/matrixfree_cg.py
ADDED
@@ -0,0 +1,310 @@
+# type: ignore
+
+from math import sqrt
+
+from gstaichi.lang import misc
+from gstaichi.lang.exception import GsTaichiRuntimeError, GsTaichiTypeError
+from gstaichi.lang.impl import FieldsBuilder, field, grouped
+from gstaichi.lang.kernel_impl import data_oriented, kernel
+from gstaichi.types import primitive_types, template
+
+
+@data_oriented
+class LinearOperator:
+    def __init__(self, matvec_kernel):
+        self._matvec = matvec_kernel
+
+    def matvec(self, x, Ax):
+        if x.shape != Ax.shape:
+            raise GsTaichiRuntimeError(f"Dimension mismatch x.shape{x.shape} != Ax.shape{Ax.shape}.")
+        self._matvec(x, Ax)
+
+
+def MatrixFreeCG(A, b, x, tol=1e-6, maxiter=5000, quiet=True):
+    """Matrix-free conjugate-gradient solver.
+
+    Use conjugate-gradient method to solve the linear system Ax = b, where A is implicitly
+    represented as a LinearOperator.
+
+    Args:
+        A (LinearOperator): The coefficient matrix A of the linear system.
+        b (Field): The right-hand side of the linear system.
+        x (Field): The initial guess for the solution.
+        maxiter (int): Maximum number of iterations.
+        atol: Tolerance(absolute) for convergence.
+        quiet (bool): Switch to turn on/off iteration log.
+    """
+
+    if b.dtype != x.dtype:
+        raise GsTaichiTypeError(f"Dtype mismatch b.dtype({b.dtype}) != x.dtype({x.dtype}).")
+    if str(b.dtype) == "f32":
+        solver_dtype = primitive_types.f32
+    elif str(b.dtype) == "f64":
+        solver_dtype = primitive_types.f64
+    else:
+        raise GsTaichiTypeError(f"Not supported dtype: {b.dtype}")
+    if b.shape != x.shape:
+        raise GsTaichiRuntimeError(f"Dimension mismatch b.shape{b.shape} != x.shape{x.shape}.")
+
+    size = b.shape
+    vector_fields_builder = FieldsBuilder()
+    p = field(dtype=solver_dtype)
+    r = field(dtype=solver_dtype)
+    Ap = field(dtype=solver_dtype)
+    Ax = field(dtype=solver_dtype)
+    if len(size) == 1:
+        axes = misc.i
+    elif len(size) == 2:
+        axes = misc.ij
+    elif len(size) == 3:
+        axes = misc.ijk
+    else:
+        raise GsTaichiRuntimeError(f"MatrixFreeCG only support 1D, 2D, 3D inputs; your inputs is {len(size)}-D.")
+    vector_fields_builder.dense(axes, size).place(p, r, Ap, Ax)
+    vector_fields_snode_tree = vector_fields_builder.finalize()
+
+    scalar_builder = FieldsBuilder()
+    alpha = field(dtype=solver_dtype)
+    beta = field(dtype=solver_dtype)
+    scalar_builder.place(alpha, beta)
+    scalar_snode_tree = scalar_builder.finalize()
+
+    @kernel
+    def init():
+        for I in grouped(x):
+            r[I] = b[I] - Ax[I]
+            p[I] = 0.0
+            Ap[I] = 0.0
+
+    @kernel
+    def reduce(p: template(), q: template()) -> solver_dtype:
+        result = solver_dtype(0.0)
+        for I in grouped(p):
+            result += p[I] * q[I]
+        return result
+
+    @kernel
+    def update_x():
+        for I in grouped(x):
+            x[I] += alpha[None] * p[I]
+
+    @kernel
+    def update_r():
+        for I in grouped(r):
+            r[I] -= alpha[None] * Ap[I]
+
+    @kernel
+    def update_p():
+        for I in grouped(p):
+            p[I] = r[I] + beta[None] * p[I]
+
+    def solve():
+        succeeded = True
+        A._matvec(x, Ax)
+        init()
+        initial_rTr = reduce(r, r)
+        if not quiet:
+            print(f">>> Initial residual = {initial_rTr:e}")
+        old_rTr = initial_rTr
+        new_rTr = initial_rTr
+        update_p()
+        if sqrt(initial_rTr) >= tol:  # Do nothing if the initial residual is small enough
+            # -- Main loop --
+            for i in range(maxiter):
+                A._matvec(p, Ap)  # compute Ap = A x p
+                pAp = reduce(p, Ap)
+                alpha[None] = old_rTr / pAp
+                update_x()
+                update_r()
+                new_rTr = reduce(r, r)
+                if sqrt(new_rTr) < tol:
+                    if not quiet:
+                        print(">>> Conjugate Gradient method converged.")
+                        print(f">>> #iterations {i}")
+                    break
+                beta[None] = new_rTr / old_rTr
+                update_p()
+                old_rTr = new_rTr
+                if not quiet:
+                    print(f">>> Iter = {i+1:4}, Residual = {sqrt(new_rTr):e}")
+            if new_rTr >= tol:
+                if not quiet:
+                    print(
+                        f">>> Conjugate Gradient method failed to converge in {maxiter} iterations: Residual = {sqrt(new_rTr):e}"
+                    )
+                succeeded = False
+        return succeeded
+
+    succeeded = solve()
+    vector_fields_snode_tree.destroy()
+    scalar_snode_tree.destroy()
+    return succeeded
+
+
+def MatrixFreeBICGSTAB(A, b, x, tol=1e-6, maxiter=5000, quiet=True):
+    """Matrix-free biconjugate-gradient stabilized solver (BiCGSTAB).
+
+    Use BiCGSTAB method to solve the linear system Ax = b, where A is implicitly
+    represented as a LinearOperator.
+
+    Args:
+        A (LinearOperator): The coefficient matrix A of the linear system.
+        b (Field): The right-hand side of the linear system.
+        x (Field): The initial guess for the solution.
+        maxiter (int): Maximum number of iterations.
+        atol: Tolerance(absolute) for convergence.
+        quiet (bool): Switch to turn on/off iteration log.
+    """
+
+    if b.dtype != x.dtype:
+        raise GsTaichiTypeError(f"Dtype mismatch b.dtype({b.dtype}) != x.dtype({x.dtype}).")
+    if str(b.dtype) == "f32":
+        solver_dtype = primitive_types.f32
+    elif str(b.dtype) == "f64":
+        solver_dtype = primitive_types.f64
+    else:
+        raise GsTaichiTypeError(f"Not supported dtype: {b.dtype}")
+    if b.shape != x.shape:
+        raise GsTaichiRuntimeError(f"Dimension mismatch b.shape{b.shape} != x.shape{x.shape}.")
+
+    size = b.shape
+    vector_fields_builder = FieldsBuilder()
+    p = field(dtype=solver_dtype)
+    p_hat = field(dtype=solver_dtype)
+    r = field(dtype=solver_dtype)
+    r_tld = field(dtype=solver_dtype)
+    s = field(dtype=solver_dtype)
+    s_hat = field(dtype=solver_dtype)
+    t = field(dtype=solver_dtype)
+    Ap = field(dtype=solver_dtype)
+    Ax = field(dtype=solver_dtype)
+    Ashat = field(dtype=solver_dtype)
+    if len(size) == 1:
+        axes = misc.i
+    elif len(size) == 2:
+        axes = misc.ij
+    elif len(size) == 3:
+        axes = misc.ijk
+    else:
+        raise GsTaichiRuntimeError(f"MatrixFreeBICGSTAB only support 1D, 2D, 3D inputs; your inputs is {len(size)}-D.")
+    vector_fields_builder.dense(axes, size).place(p, p_hat, r, r_tld, s, s_hat, t, Ap, Ax, Ashat)
+    vector_fields_snode_tree = vector_fields_builder.finalize()
+
+    scalar_builder = FieldsBuilder()
+    alpha = field(dtype=solver_dtype)
+    beta = field(dtype=solver_dtype)
+    omega = field(dtype=solver_dtype)
+    rho = field(dtype=solver_dtype)
+    rho_1 = field(dtype=solver_dtype)
+    scalar_builder.place(alpha, beta, omega, rho, rho_1)
+    scalar_snode_tree = scalar_builder.finalize()
+    succeeded = True
+
+    @kernel
+    def init():
+        for I in grouped(x):
+            r[I] = b[I] - Ax[I]
+            r_tld[I] = b[I]
+            p[I] = 0.0
+            Ap[I] = 0.0
+            Ashat[I] = 0.0
+        rho[None] = 0.0
+        rho_1[None] = 1.0
+        alpha[None] = 1.0
+        beta[None] = 1.0
+        omega[None] = 1.0
+
+    @kernel
+    def reduce(p: template(), q: template()) -> solver_dtype:
+        result = solver_dtype(0.0)
+        for I in grouped(p):
+            result += p[I] * q[I]
+        return result
+
+    @kernel
+    def copy(orig: template(), dest: template()):
+        for I in grouped(orig):
+            dest[I] = orig[I]
+
+    @kernel
+    def update_p():
+        for I in grouped(p):
+            p[I] = r[I] + beta[None] * (p[I] - omega[None] * Ap[I])
+
+    @kernel
+    def update_phat():
+        for I in grouped(p_hat):
+            p_hat[I] = p[I]
+
+    @kernel
+    def update_s():
+        for I in grouped(s):
+            s[I] = r[I] - alpha[None] * Ap[I]
+
+    @kernel
+    def update_shat():
+        for I in grouped(s_hat):
+            s_hat[I] = s[I]
+
+    @kernel
+    def update_x():
+        for I in grouped(x):
+            x[I] += alpha[None] * p_hat[I] + omega[None] * s_hat[I]
+
+    @kernel
+    def update_r():
+        for I in grouped(r):
+            r[I] = s[I] - omega[None] * t[I]
+
+    def solve():
+        succeeded = True
+        A._matvec(x, Ax)
+        init()
+        initial_rTr = reduce(r, r)
+        rTr = initial_rTr
+        if not quiet:
+            print(f">>> Initial residual = {initial_rTr:e}")
+        if sqrt(initial_rTr) >= tol:  # Do nothing if the initial residual is small enough
+            for i in range(maxiter):
+                rho[None] = reduce(r, r_tld)
+                if rho[None] == 0.0:
+                    if not quiet:
+                        print(">>> BICGSTAB failed because r@r_tld = 0.")
+                    succeeded = False
+                    break
+                if i == 0:
+                    copy(orig=r, dest=p)
+                else:
+                    beta[None] = (rho[None] / rho_1[None]) * (alpha[None] / omega[None])
+                    update_p()
+                update_phat()
+                A._matvec(p, Ap)
+                alpha_lower = reduce(r_tld, Ap)
+                alpha[None] = rho[None] / alpha_lower
+                update_s()
+                update_shat()
+                A._matvec(s_hat, Ashat)
+                copy(orig=Ashat, dest=t)
+                omega_upper = reduce(t, s)
+                omega_lower = reduce(t, t)
+                omega[None] = omega_upper / (omega_lower + 1e-16) if omega_lower == 0.0 else omega_upper / omega_lower
+                update_x()
+                update_r()
+                rTr = reduce(r, r)
+                if not quiet:
+                    print(f">>> Iter = {i+1:4}, Residual = {sqrt(rTr):e}")
+                if sqrt(rTr) < tol:
+                    if not quiet:
+                        print(f">>> BICGSTAB method converged at #iterations {i}")
+                    break
+                rho_1[None] = rho[None]
+            if rTr >= tol:
+                if not quiet:
+                    print(f">>> BICGSTAB failed to converge in {maxiter} iterations: Residual = {sqrt(rTr):e}")
+                succeeded = False
+        return succeeded
+
+    succeeded = solve()
+    vector_fields_snode_tree.destroy()
+    scalar_snode_tree.destroy()
+    return succeeded
gstaichi/linalg/sparse_cg.py
ADDED
@@ -0,0 +1,59 @@
+# type: ignore
+
+import numpy as np
+
+from gstaichi._lib import core as _ti_core
+from gstaichi.lang._ndarray import Ndarray, ScalarNdarray
+from gstaichi.lang.exception import GsTaichiRuntimeError
+from gstaichi.lang.impl import get_runtime
+from gstaichi.types import f32, f64
+
+
+class SparseCG:
+    """Conjugate-gradient solver built for SparseMatrix.
+
+    Use conjugate-gradient method to solve the linear system Ax = b, where A is SparseMatrix.
+
+    Args:
+        A (SparseMatrix): The coefficient matrix A of the linear system.
+        b (numpy ndarray, gstaichi Ndarray): The right-hand side of the linear system.
+        x0 (numpy ndarray, gstaichi Ndarray): The initial guess for the solution.
+        max_iter (int): Maximum number of iterations.
+        atol: Tolerance(absolute) for convergence.
+    """
+
+    def __init__(self, A, b, x0=None, max_iter=50, atol=1e-6):
+        self.dtype = A.dtype
+        self.ti_arch = get_runtime().prog.config().arch
+        self.matrix = A
+        self.b = b
+        if self.ti_arch == _ti_core.Arch.cuda:
+            self.cg_solver = _ti_core.make_cucg_solver(A.matrix, max_iter, atol, True)
+        elif self.ti_arch == _ti_core.Arch.x64 or self.ti_arch == _ti_core.Arch.arm64:
+            if self.dtype == f32:
+                self.cg_solver = _ti_core.make_float_cg_solver(A.matrix, max_iter, atol, True)
+            elif self.dtype == f64:
+                self.cg_solver = _ti_core.make_double_cg_solver(A.matrix, max_iter, atol, True)
+            else:
+                raise GsTaichiRuntimeError(f"Unsupported CG dtype: {self.dtype}")
+            if isinstance(b, Ndarray):
+                self.cg_solver.set_b_ndarray(get_runtime().prog, b.arr)
+            elif isinstance(b, np.ndarray):
+                self.cg_solver.set_b(b)
+            if isinstance(x0, Ndarray):
+                self.cg_solver.set_x_ndarray(get_runtime().prog, x0.arr)
+            elif isinstance(x0, np.ndarray):
+                self.cg_solver.set_x(x0)
+        else:
+            raise GsTaichiRuntimeError(f"Unsupported CG arch: {self.ti_arch}")
+
+    def solve(self):
+        if self.ti_arch == _ti_core.Arch.cuda:
+            if isinstance(self.b, Ndarray):
+                x = ScalarNdarray(self.b.dtype, [self.matrix.m])
+                self.cg_solver.solve(get_runtime().prog, x.arr, self.b.arr)
+                return x, True
+            raise GsTaichiRuntimeError(f"Unsupported CG RHS type: {type(self.b)}")
+        else:
+            self.cg_solver.solve()
+            return self.cg_solver.get_x(), self.cg_solver.is_success()