gstaichi 2.1.1rc3__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +4 -0
- gstaichi/__init__.py +40 -0
- gstaichi/_funcs.py +706 -0
- gstaichi/_kernels.py +420 -0
- gstaichi/_lib/__init__.py +3 -0
- gstaichi/_lib/core/__init__.py +0 -0
- gstaichi/_lib/core/gstaichi_python.cpython-310-darwin.so +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
- gstaichi/_lib/core/py.typed +0 -0
- gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
- gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
- gstaichi/_lib/utils.py +243 -0
- gstaichi/_logging.py +131 -0
- gstaichi/_snode/__init__.py +5 -0
- gstaichi/_snode/fields_builder.py +187 -0
- gstaichi/_snode/snode_tree.py +34 -0
- gstaichi/_test_tools/__init__.py +18 -0
- gstaichi/_test_tools/dataclass_test_tools.py +36 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_test_tools/textwrap2.py +6 -0
- gstaichi/_version.py +1 -0
- gstaichi/_version_check.py +100 -0
- gstaichi/ad/__init__.py +3 -0
- gstaichi/ad/_ad.py +530 -0
- gstaichi/algorithms/__init__.py +3 -0
- gstaichi/algorithms/_algorithms.py +117 -0
- gstaichi/assets/.git +1 -0
- gstaichi/assets/Go-Regular.ttf +0 -0
- gstaichi/assets/static/imgs/ti_gallery.png +0 -0
- gstaichi/examples/lcg_python.py +26 -0
- gstaichi/examples/lcg_taichi.py +34 -0
- gstaichi/examples/minimal.py +28 -0
- gstaichi/experimental.py +16 -0
- gstaichi/lang/__init__.py +50 -0
- gstaichi/lang/_dataclass_util.py +31 -0
- gstaichi/lang/_fast_caching/__init__.py +3 -0
- gstaichi/lang/_fast_caching/args_hasher.py +110 -0
- gstaichi/lang/_fast_caching/config_hasher.py +30 -0
- gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
- gstaichi/lang/_fast_caching/function_hasher.py +57 -0
- gstaichi/lang/_fast_caching/hash_utils.py +11 -0
- gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
- gstaichi/lang/_fast_caching/src_hasher.py +75 -0
- gstaichi/lang/_kernel_impl_dataclass.py +212 -0
- gstaichi/lang/_ndarray.py +352 -0
- gstaichi/lang/_ndrange.py +152 -0
- gstaichi/lang/_template_mapper.py +195 -0
- gstaichi/lang/_texture.py +172 -0
- gstaichi/lang/_wrap_inspect.py +215 -0
- gstaichi/lang/any_array.py +99 -0
- gstaichi/lang/ast/__init__.py +5 -0
- gstaichi/lang/ast/ast_transformer.py +1323 -0
- gstaichi/lang/ast/ast_transformer_utils.py +346 -0
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
- gstaichi/lang/ast/checkers.py +106 -0
- gstaichi/lang/ast/symbol_resolver.py +57 -0
- gstaichi/lang/ast/transform.py +9 -0
- gstaichi/lang/common_ops.py +310 -0
- gstaichi/lang/exception.py +80 -0
- gstaichi/lang/expr.py +180 -0
- gstaichi/lang/field.py +428 -0
- gstaichi/lang/impl.py +1243 -0
- gstaichi/lang/kernel_arguments.py +155 -0
- gstaichi/lang/kernel_impl.py +1341 -0
- gstaichi/lang/matrix.py +1835 -0
- gstaichi/lang/matrix_ops.py +341 -0
- gstaichi/lang/matrix_ops_utils.py +190 -0
- gstaichi/lang/mesh.py +687 -0
- gstaichi/lang/misc.py +782 -0
- gstaichi/lang/ops.py +1494 -0
- gstaichi/lang/runtime_ops.py +13 -0
- gstaichi/lang/shell.py +35 -0
- gstaichi/lang/simt/__init__.py +5 -0
- gstaichi/lang/simt/block.py +94 -0
- gstaichi/lang/simt/grid.py +7 -0
- gstaichi/lang/simt/subgroup.py +191 -0
- gstaichi/lang/simt/warp.py +96 -0
- gstaichi/lang/snode.py +489 -0
- gstaichi/lang/source_builder.py +150 -0
- gstaichi/lang/struct.py +810 -0
- gstaichi/lang/util.py +312 -0
- gstaichi/linalg/__init__.py +8 -0
- gstaichi/linalg/matrixfree_cg.py +310 -0
- gstaichi/linalg/sparse_cg.py +59 -0
- gstaichi/linalg/sparse_matrix.py +303 -0
- gstaichi/linalg/sparse_solver.py +123 -0
- gstaichi/math/__init__.py +11 -0
- gstaichi/math/_complex.py +205 -0
- gstaichi/math/mathimpl.py +886 -0
- gstaichi/profiler/__init__.py +6 -0
- gstaichi/profiler/kernel_metrics.py +260 -0
- gstaichi/profiler/kernel_profiler.py +586 -0
- gstaichi/profiler/memory_profiler.py +15 -0
- gstaichi/profiler/scoped_profiler.py +36 -0
- gstaichi/sparse/__init__.py +3 -0
- gstaichi/sparse/_sparse_grid.py +77 -0
- gstaichi/tools/__init__.py +12 -0
- gstaichi/tools/diagnose.py +117 -0
- gstaichi/tools/np2ply.py +364 -0
- gstaichi/tools/vtk.py +38 -0
- gstaichi/types/__init__.py +19 -0
- gstaichi/types/annotations.py +52 -0
- gstaichi/types/compound_types.py +71 -0
- gstaichi/types/enums.py +49 -0
- gstaichi/types/ndarray_type.py +169 -0
- gstaichi/types/primitive_types.py +206 -0
- gstaichi/types/quant.py +88 -0
- gstaichi/types/texture_type.py +85 -0
- gstaichi/types/utils.py +11 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
- gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
- gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
- gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
- gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
- gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,303 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
from functools import reduce
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
from gstaichi._lib import core as _ti_core
|
8
|
+
from gstaichi.lang._ndarray import Ndarray, ScalarNdarray
|
9
|
+
from gstaichi.lang.exception import GsTaichiRuntimeError
|
10
|
+
from gstaichi.lang.field import Field
|
11
|
+
from gstaichi.lang.impl import get_runtime
|
12
|
+
from gstaichi.types import f32
|
13
|
+
|
14
|
+
|
15
|
+
class SparseMatrix:
    """GsTaichi's Sparse Matrix class

    A sparse matrix allows the programmer to solve a large linear system.

    Args:
        n (int): the first dimension (number of rows) of a sparse matrix.
        m (int): the second dimension (number of columns) of a sparse matrix.
            When omitted (or falsy) it defaults to ``n``, i.e. a square matrix.
        sm: an existing native sparse matrix object to wrap (for example the
            ``.matrix`` attribute of another :class:`SparseMatrix`). When
            given, ``n``, ``m`` and ``storage_format`` are ignored and the
            shape is read from ``sm`` via ``num_rows()`` / ``num_cols()``.
        dtype: the element data type. Defaults to ``f32``.
        storage_format (str): storage format used when a new matrix is
            created. Defaults to ``"col_major"``.
    """

    # Python/numpy scalar types accepted by scalar multiplication; they are
    # converted with ``float()`` before reaching the native matrix.
    _SCALAR_TYPES = (int, float, np.integer, np.floating)

    def __init__(self, n=None, m=None, sm=None, dtype=f32, storage_format="col_major"):
        self.dtype = dtype
        if sm is None:
            self.n = n
            self.m = m if m else n
            # Use the resolved column count: passing the raw ``m`` would hand
            # ``None`` to the runtime when only ``n`` was given.
            self.matrix = get_runtime().prog.create_sparse_matrix(n, self.m, dtype, storage_format)
        else:
            self.n = sm.num_rows()
            self.m = sm.num_cols()
            self.matrix = sm

    def _wrap(self, sm):
        """Wrap a native result matrix derived from ``self``, keeping ``self.dtype``.

        Without forwarding ``dtype`` every derived matrix would be labelled
        ``f32`` regardless of its real element type, which defeats dtype
        checks that read ``SparseMatrix.dtype`` (e.g. in ``SparseSolver``).
        """
        return SparseMatrix(sm=sm, dtype=self.dtype)

    def _check_same_shape(self, other):
        """Assert that ``other`` has exactly the same (rows, cols) shape as ``self``."""
        assert (
            self.n == other.n and self.m == other.m
        ), f"Dimension mismatch between sparse matrices ({self.n}, {self.m}) and ({other.n}, {other.m})"

    def __iadd__(self, other):
        """In-place addition of a sparse matrix of the same shape.

        Returns:
            This sparse matrix, holding the result of the addition.
        """
        self._check_same_shape(other)
        self.matrix += other.matrix
        return self

    def __add__(self, other):
        """Addition operation for sparse matrix.

        Returns:
            A new sparse matrix holding the result of the addition.
        """
        self._check_same_shape(other)
        return self._wrap(self.matrix + other.matrix)

    def __isub__(self, other):
        """In-place subtraction of a sparse matrix of the same shape.

        Returns:
            This sparse matrix, holding the result of the subtraction.
        """
        self._check_same_shape(other)
        self.matrix -= other.matrix
        return self

    def __sub__(self, other):
        """Subtraction operation for sparse matrix.

        Returns:
            A new sparse matrix holding the result of the subtraction.
        """
        self._check_same_shape(other)
        return self._wrap(self.matrix - other.matrix)

    def __mul__(self, other):
        """Scalar multiplication, or the Hadamard (element-wise) product with another sparse matrix.

        Args:
            other (int, float or SparseMatrix): the other operand of multiplication.
        Returns:
            The resulting sparse matrix, or ``NotImplemented`` for unsupported
            operand types so that Python can try the reflected operation and
            raise ``TypeError`` if nothing applies.
        """
        if isinstance(other, self._SCALAR_TYPES):
            return self._wrap(float(other) * self.matrix)
        if isinstance(other, SparseMatrix):
            self._check_same_shape(other)
            return self._wrap(self.matrix * other.matrix)
        return NotImplemented

    def __rmul__(self, other):
        """Right scalar multiplication for sparse matrix (``scalar * matrix``).

        Args:
            other (int or float): the scalar operand of multiplication.
        Returns:
            The resulting sparse matrix, or ``NotImplemented`` for unsupported
            operand types.
        """
        if isinstance(other, self._SCALAR_TYPES):
            return self._wrap(self.matrix * float(other))
        return NotImplemented

    def transpose(self):
        """Sparse Matrix transpose.

        Returns:
            The transposed sparse matrix.
        """
        return self._wrap(self.matrix.transpose())

    def __matmul__(self, other):
        """Matrix multiplication.

        Args:
            other (SparseMatrix, Field, Ndarray, or numpy.ndarray): the right-hand operand.
        Returns:
            A SparseMatrix for a SparseMatrix operand, a ScalarNdarray of length
            ``n`` for an Ndarray operand, otherwise the result of the native
            ``mat_vec_mul``.
        Raises:
            GsTaichiRuntimeError: if the operand type is unsupported, or an
                Ndarray operand has the wrong length.
        """
        if isinstance(other, SparseMatrix):
            assert (
                self.m == other.n
            ), f"Dimension mismatch between sparse matrices ({self.n}, {self.m}) and ({other.n}, {other.m})"
            return self._wrap(self.matrix.matmul(other.matrix))
        if isinstance(other, Field):
            assert (
                self.m == other.shape[0]
            ), f"Dimension mismatch between sparse matrix ({self.n}, {self.m}) and vector ({other.shape})"
            return self.matrix.mat_vec_mul(other.to_numpy())
        if isinstance(other, np.ndarray):
            assert (
                self.m == other.shape[0]
            ), f"Dimension mismatch between sparse matrix ({self.n}, {self.m}) and vector ({other.shape})"
            return self.matrix.mat_vec_mul(other)
        if isinstance(other, Ndarray):
            if self.m != other.shape[0]:
                raise GsTaichiRuntimeError(
                    f"Dimension mismatch between sparse matrix ({self.n}, {self.m}) and vector ({other.shape})"
                )
            res = ScalarNdarray(dtype=other.dtype, arr_shape=(self.n,))
            self.matrix.spmv(get_runtime().prog, other.arr, res.arr)
            return res
        raise GsTaichiRuntimeError(
            f"Sparse matrix-matrix/vector multiplication does not support {type(other)} for now. Supported types are SparseMatrix, ti.field, ti.ndarray, and numpy ndarray."
        )

    def __getitem__(self, indices):
        """Return the element at ``indices == (row, col)``."""
        return self.matrix.get_element(indices[0], indices[1])

    def __setitem__(self, indices, value):
        """Set the element at ``indices == (row, col)`` to ``value``."""
        self.matrix.set_element(indices[0], indices[1], value)

    def __str__(self):
        """Python scope matrix print support."""
        return self.matrix.to_string()

    def __repr__(self):
        return self.matrix.to_string()

    @property
    def shape(self):
        """The shape of the sparse matrix."""
        return (self.n, self.m)

    def build_from_ndarray(self, ndarray):
        """Build the sparse matrix from a ndarray.

        Args:
            ndarray (Union[ti.ndarray, ti.Vector.ndarray, ti.Matrix.ndarray]): the ndarray to build the sparse matrix from.

        Raises:
            GsTaichiRuntimeError: If the input is not a ndarray or the length is not divisible by 3.

        Example::
            >>> N = 5
            >>> triplets = ti.Vector.ndarray(n=3, dtype=ti.f32, shape=10, layout=ti.Layout.AOS)
            >>> @ti.kernel
            >>> def fill(triplets: ti.types.ndarray()):
            >>>     for i in range(N):
            >>>        triplets[i] = ti.Vector([i, (i + 1) % N, i+1], dt=ti.f32)
            >>> fill(triplets)
            >>> A = ti.linalg.SparseMatrix(n=N, m=N, dtype=ti.f32)
            >>> A.build_from_ndarray(triplets)
            >>> print(A)
            [0, 1, 0, 0, 0]
            [0, 0, 2, 0, 0]
            [0, 0, 0, 3, 0]
            [0, 0, 0, 0, 4]
            [5, 0, 0, 0, 0]
        """
        if isinstance(ndarray, Ndarray):
            # Total scalar count over outer and element dims; the initial value
            # of 1 keeps an empty shape from crashing ``reduce``.
            num_scalars = reduce(lambda x, y: x * y, ndarray.shape + ndarray.element_shape, 1)
            if num_scalars % 3 != 0:
                raise GsTaichiRuntimeError("The number of ndarray elements must have a length that is divisible by 3.")
            get_runtime().prog.make_sparse_matrix_from_ndarray(self.matrix, ndarray.arr)
        else:
            raise GsTaichiRuntimeError(
                "Sparse matrix only supports building from [ti.ndarray, ti.Vector.ndarray, ti.Matrix.ndarray]"
            )

    def mmwrite(self, filename):
        """Writes the sparse matrix to Matrix Market file-like target.

        Args:
            filename (str): the file name to write the sparse matrix to.
        """
        self.matrix.mmwrite(filename)
|
225
|
+
|
226
|
+
|
227
|
+
class SparseMatrixBuilder:
    """A python wrap around sparse matrix builder.

    Use this builder to fill the sparse matrix.

    Args:
        num_rows (int): the first dimension of a sparse matrix. When ``None``,
            no native builder is created.
        num_cols (int): the second dimension of a sparse matrix. Defaults to
            ``num_rows`` when omitted (or falsy).
        max_num_triplets (int): the maximum number of triplets.
        dtype (ti.dtype): the data type of the sparse matrix.
        storage_format (str): the storage format of the sparse matrix.

    Raises:
        GsTaichiRuntimeError: if a shape is given and the current arch is not
            x64, arm64 or CUDA.
    """

    def __init__(
        self,
        num_rows=None,
        num_cols=None,
        max_num_triplets=0,
        dtype=f32,
        storage_format="col_major",
    ):
        self.num_rows = num_rows
        self.num_cols = num_cols if num_cols else num_rows
        self.dtype = dtype
        # Native builder handle. It stays ``None`` when no shape was given or the
        # native builder could not be constructed, so ``__del__`` can skip cleanup.
        self.ptr = None
        if num_rows is not None:
            gstaichi_arch = get_runtime().prog.config().arch
            if gstaichi_arch not in [
                _ti_core.Arch.x64,
                _ti_core.Arch.arm64,
                _ti_core.Arch.cuda,
            ]:
                raise GsTaichiRuntimeError("SparseMatrix only supports CPU and CUDA for now.")
            # Pass the resolved column count; the raw ``num_cols`` may be None.
            self.ptr = _ti_core.SparseMatrixBuilder(
                num_rows,
                self.num_cols,
                max_num_triplets,
                dtype,
                storage_format,
            )
            self.ptr.create_ndarray(get_runtime().prog)

    def _get_addr(self):
        """Get the address of the sparse matrix"""
        return self.ptr.get_addr()

    def _get_ndarray_addr(self):
        """Get the address of the ndarray"""
        return self.ptr.get_ndarray_data_ptr()

    def print_triplets(self):
        """Print the triplets stored in the builder.

        Does nothing on archs other than x64, arm64 and CUDA.
        """
        gstaichi_arch = get_runtime().prog.config().arch
        if gstaichi_arch in [_ti_core.Arch.x64, _ti_core.Arch.arm64]:
            self.ptr.print_triplets_eigen()
        elif gstaichi_arch == _ti_core.Arch.cuda:
            self.ptr.print_triplets_cuda()

    def build(self, dtype=f32, _format="CSR"):
        """Create a sparse matrix using the triplets.

        Args:
            dtype: unused; kept for backward compatibility. The builder's own
                ``dtype`` (given at construction) is used instead.
            _format: unused; kept for backward compatibility.

        Returns:
            SparseMatrix: the built matrix, labelled with the builder's dtype.

        Raises:
            GsTaichiRuntimeError: on CUDA when the builder dtype is not f32, or
                when the current arch is neither CPU nor CUDA.
        """
        gstaichi_arch = get_runtime().prog.config().arch
        if gstaichi_arch in [_ti_core.Arch.x64, _ti_core.Arch.arm64]:
            sm = self.ptr.build()
            return SparseMatrix(sm=sm, dtype=self.dtype)
        if gstaichi_arch == _ti_core.Arch.cuda:
            if self.dtype != f32:
                raise GsTaichiRuntimeError("CUDA sparse matrix only supports f32.")
            sm = self.ptr.build_cuda()
            return SparseMatrix(sm=sm, dtype=self.dtype)
        raise GsTaichiRuntimeError("Sparse matrix only supports CPU and CUDA backends.")

    def __del__(self):
        # ``getattr`` guards against a partially-initialised instance; a
        # ``None`` handle means there is no native ndarray to release.
        ptr = getattr(self, "ptr", None)
        if ptr is None:
            return
        runtime = get_runtime()
        if runtime is not None and runtime.prog is not None:
            ptr.delete_ndarray(runtime.prog)
|
301
|
+
|
302
|
+
|
303
|
+
# Public names exported by ``from ... import *``.
__all__ = [
    "SparseMatrix",
    "SparseMatrixBuilder",
]
|
@@ -0,0 +1,123 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
|
5
|
+
import gstaichi.lang
|
6
|
+
from gstaichi._lib import core as _ti_core
|
7
|
+
from gstaichi.lang._ndarray import Ndarray, ScalarNdarray
|
8
|
+
from gstaichi.lang.exception import GsTaichiRuntimeError
|
9
|
+
from gstaichi.lang.field import Field
|
10
|
+
from gstaichi.lang.impl import get_runtime
|
11
|
+
from gstaichi.linalg.sparse_matrix import SparseMatrix
|
12
|
+
from gstaichi.types.primitive_types import f32
|
13
|
+
|
14
|
+
|
15
|
+
class SparseSolver:
    """Sparse linear system solver.

    Factorizes a :class:`SparseMatrix` with a native solver and then solves
    linear systems against it. Only the CPU (x64/arm64) and CUDA archs are
    accepted.

    Args:
        dtype: element type expected of matrices given to
            :meth:`analyze_pattern`. Defaults to ``f32``.
        solver_type (str): The factorization type: ``"LLT"``, ``"LDLT"`` or ``"LU"``.
        ordering (str): The matrix re-ordering method: ``"AMD"`` or ``"COLAMD"``.

    Raises:
        GsTaichiRuntimeError: if ``solver_type`` or ``ordering`` is unsupported.
    """

    def __init__(self, dtype=f32, solver_type="LLT", ordering="AMD"):
        self.matrix = None
        self.dtype = dtype
        known_types = ["LLT", "LDLT", "LU"]
        known_orderings = ["AMD", "COLAMD"]
        if solver_type not in known_types or ordering not in known_orderings:
            raise GsTaichiRuntimeError(
                f"The solver type {solver_type} with {ordering} is not supported for now. Only {known_types} with {known_orderings} are supported."
            )
        arch = gstaichi.lang.impl.get_runtime().prog.config().arch
        assert arch in (
            _ti_core.Arch.x64,
            _ti_core.Arch.arm64,
            _ti_core.Arch.cuda,
        ), "SparseSolver only supports CPU and CUDA for now."
        # CUDA uses the cuSPARSE-backed factory; the CPU archs use the default one.
        make_solver = _ti_core.make_cusparse_solver if arch == _ti_core.Arch.cuda else _ti_core.make_sparse_solver
        self.solver = make_solver(dtype, solver_type, ordering)

    @staticmethod
    def _type_assert(sparse_matrix):
        """Raise for an argument that is not a supported matrix type."""
        raise GsTaichiRuntimeError(
            f"The parameter type: {type(sparse_matrix)} is not supported in linear solvers for now."
        )

    def compute(self, sparse_matrix):
        """Run `analyze_pattern` followed by `factorize` on ``sparse_matrix``.

        On CPU archs this is a single native ``compute`` call; on CUDA the two
        steps are invoked separately.

        Args:
            sparse_matrix (SparseMatrix): The sparse matrix to be computed.
        """
        if not isinstance(sparse_matrix, SparseMatrix):
            self._type_assert(sparse_matrix)
            return
        self.matrix = sparse_matrix
        arch = gstaichi.lang.impl.get_runtime().prog.config().arch
        if arch in (_ti_core.Arch.x64, _ti_core.Arch.arm64):
            self.solver.compute(sparse_matrix.matrix)
        elif arch == _ti_core.Arch.cuda:
            self.analyze_pattern(self.matrix)
            self.factorize(self.matrix)

    def analyze_pattern(self, sparse_matrix):
        """Reorder the nonzeros so that the factorization step creates less fill-in.

        Args:
            sparse_matrix (SparseMatrix): The sparse matrix to be analyzed.

        Raises:
            GsTaichiRuntimeError: if the matrix dtype differs from the solver dtype.
        """
        if not isinstance(sparse_matrix, SparseMatrix):
            self._type_assert(sparse_matrix)
            return
        self.matrix = sparse_matrix
        if sparse_matrix.dtype != self.dtype:
            raise GsTaichiRuntimeError(
                f"The SparseSolver's dtype {self.dtype} is not consistent with the SparseMatrix's dtype {self.matrix.dtype}."
            )
        self.solver.analyze_pattern(sparse_matrix.matrix)

    def factorize(self, sparse_matrix):
        """Do the numerical factorization step.

        Args:
            sparse_matrix (SparseMatrix): The sparse matrix to be factorized.
        """
        if not isinstance(sparse_matrix, SparseMatrix):
            self._type_assert(sparse_matrix)
            return
        self.matrix = sparse_matrix
        self.solver.factorize(sparse_matrix.matrix)

    def solve(self, b):  # pylint: disable=R1710
        """Computes the solution of the linear systems.

        Args:
            b (numpy.ndarray, Field or Ndarray): The right-hand side of the linear systems.

        Returns:
            The solution: the native solver's result for numpy/Field input, or a
            ScalarNdarray of length ``m`` for Ndarray input.

        Raises:
            GsTaichiRuntimeError: if no matrix has been set yet, or ``b`` has an
                unsupported type.
        """
        if self.matrix is None:
            raise GsTaichiRuntimeError("Please call compute() before calling solve().")
        if isinstance(b, Field):
            rhs = b.to_numpy()
            return self.solver.solve(rhs)
        if isinstance(b, np.ndarray):
            return self.solver.solve(b)
        if isinstance(b, Ndarray):
            solution = ScalarNdarray(b.dtype, [self.matrix.m])
            self.solver.solve_rf(get_runtime().prog, self.matrix.matrix, b.arr, solution.arr)
            return solution
        raise GsTaichiRuntimeError(f"The parameter type: {type(b)} is not supported in linear solvers for now.")

    def info(self):
        """Report whether the linear system was solved successfully.

        Returns:
            bool: True if the solving process succeeded, False otherwise.
        """
        return self.solver.info()
|
@@ -0,0 +1,205 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
from gstaichi.lang import ops
|
4
|
+
from gstaichi.lang.kernel_impl import func
|
5
|
+
|
6
|
+
from .mathimpl import dot, vec2
|
7
|
+
|
8
|
+
|
9
|
+
@func
def cmul(z1, z2):
    """Multiply two 2d vectors as complex numbers.

    Each vector is read as ``(real, imag)``; the result is their product in
    the complex number field.

    Args:
        z1 (:class:`~gstaichi.math.vec2`): The first factor.
        z2 (:class:`~gstaichi.math.vec2`): The second factor.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     z1 = ti.math.vec2(1, 1)
        >>>     z2 = ti.math.vec2(0, 1)
        >>>     ti.math.cmul(z1, z2)  # [-1, 1]

    Returns:
        :class:`~gstaichi.math.vec2`: the complex product `z1 * z2`.
    """
    re_a, im_a = z1[0], z1[1]
    re_b, im_b = z2[0], z2[1]
    # (a + bi)(c + di) = (ac - bd) + (ad + cb)i
    return vec2(re_a * re_b - im_a * im_b, re_a * im_b + re_b * im_a)
|
34
|
+
|
35
|
+
|
36
|
+
@func
def cconj(z):
    """Return the complex conjugate of a 2d vector.

    For ``z = (x, y)`` the conjugate is ``(x, -y)``.

    Args:
        z (:class:`~gstaichi.math.vec2`): The input.

    Returns:
        :class:`~gstaichi.math.vec2`: The conjugate of `z`.
    """
    real, imag = z[0], z[1]
    return vec2(real, -imag)
|
49
|
+
|
50
|
+
|
51
|
+
@func
def cdiv(z1, z2):
    """Divide two 2d vectors as complex numbers.

    Each vector is read as ``(real, imag)``; the result is the quotient
    ``z1 / z2`` in the complex number field. No guard is applied when `z2`
    is zero.

    Args:
        z1 (:class:`~gstaichi.math.vec2`): The dividend.
        z2 (:class:`~gstaichi.math.vec2`): The divisor.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     z1 = ti.math.vec2(1, 1)
        >>>     z2 = ti.math.vec2(0, 1)
        >>>     ti.math.cdiv(z1, z2)  # [1, -1]

    Returns:
        :class:`~gstaichi.math.vec2`: the complex quotient `z1 / z2`.
    """
    re_a, im_a = z1[0], z1[1]
    re_b, im_b = z2[0], z2[1]
    # z1 / z2 = z1 * conj(z2) / |z2|^2
    return vec2(re_a * re_b + im_a * im_b, -re_a * im_b + re_b * im_a) / dot(z2, z2)
|
76
|
+
|
77
|
+
|
78
|
+
@func
def csqrt(z):
    """Return the principal complex square root of a 2d vector `z`.

    The result `w` satisfies `w^2 = z`. It is built from half of the angle
    returned by ``ops.atan2``, so the root with non-negative real part is
    chosen. When the real part is zero, the sign of the imaginary part
    follows that angle.

    Args:
        z (:class:`~gstaichi.math.vec2`): The input.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     z = ti.math.vec2(-1, 0)
        >>>     w = ti.math.csqrt(z)  # [0, 1]

    Returns:
        :class:`~gstaichi.math.vec2`: The complex square root (``(0, 0)`` when `z` is zero).
    """
    # Defined before the branch so it exists on both paths; z == 0 keeps it at zero.
    root = vec2(0.0)
    if any(z):
        scale = ops.sqrt(z.norm())
        half_angle = ops.atan2(z[1], z[0]) / 2.0
        root = scale * vec2(ops.cos(half_angle), ops.sin(half_angle))

    return root
|
107
|
+
|
108
|
+
|
109
|
+
@func
def cinv(z):
    """Return the reciprocal `1 / z` of a complex number stored as a 2d vector.

    Computed as ``conj(z) / |z|^2``; no guard is applied when `z` is zero.

    Args:
        z (:class:`~gstaichi.math.vec2`): The input.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     z = ti.math.vec2(1, 1)
        >>>     w = ti.math.cinv(z)  # [0.5, -0.5]

    Returns:
        :class:`~gstaichi.math.vec2`: The reciprocal of `z`.
    """
    modulus_sq = dot(z, z)
    return cconj(z) / modulus_sq
|
127
|
+
|
128
|
+
|
129
|
+
@func
def cpow(z, n):
    """Computes the power of a complex `z`: :math:`z^n`.

    The power is evaluated in polar form, using the angle from ``ops.atan2``.
    If `z` is zero, the result is ``(0, 0)`` for every exponent, including
    ``n <= 0``.

    Args:
        z (:class:`~gstaichi.math.vec2`): The base.
        n (float): The exponent.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     z = ti.math.vec2(1, 1)
        >>>     w = ti.math.cpow(z, 3)  # [-2, 2]

    Returns:
        :class:`~gstaichi.math.vec2`: The power :math:`z^n`.
    """
    # Defined before the branch so it exists on both paths; z == 0 keeps it at zero.
    result = vec2(0.0)
    if any(z):
        r2 = dot(z, z)
        a = ops.atan2(z[1], z[0]) * n
        # |z|^n == (|z|^2)^(n/2), which avoids an explicit square root.
        result = ops.pow(r2, n / 2.0) * vec2(ops.cos(a), ops.sin(a))

    return result
|
154
|
+
|
155
|
+
|
156
|
+
@func
def cexp(z):
    """Return the complex exponential :math:`e^z`.

    `z` is a 2d vector read as ``(real, imag)``; the result is
    ``exp(real) * (cos(imag), sin(imag))``.

    Args:
        z (:class:`~gstaichi.math.vec2`): The exponent.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     z = ti.math.vec2(1, 1)
        >>>     w = ti.math.cexp(z)  # [1.468694, 2.287355]

    Returns:
        :class:`~gstaichi.math.vec2`: The exponential :math:`exp(z)`.
    """
    magnitude = ops.exp(z[0])
    phase = z[1]
    return vec2(magnitude * ops.cos(phase), magnitude * ops.sin(phase))
|
177
|
+
|
178
|
+
|
179
|
+
@func
def clog(z):
    """Return the complex logarithm of `z`, i.e. `w` with :math:`e^w = z`.

    `z` is a 2d vector read as ``(real, imag)``. The imaginary part of the
    result is the angle given by ``ops.atan2``, in the range (-pi, pi].

    Args:
        z (:class:`~gstaichi.math.vec2`): The input.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     z = ti.math.vec2(1, 1)
        >>>     w = ti.math.clog(z)  # [0.346574, 0.785398]

    Returns:
        :class:`~gstaichi.math.vec2`: The logarithm of `z`.
    """
    phase = ops.atan2(z[1], z[0])
    modulus_sq = dot(z, z)
    # log|z| == log(|z|^2) / 2, which avoids an explicit square root.
    return vec2(ops.log(modulus_sq) / 2.0, phase)
|
203
|
+
|
204
|
+
|
205
|
+
# Public names exported by ``from ... import *``.
__all__ = [
    "cconj",
    "cdiv",
    "cexp",
    "cinv",
    "clog",
    "cmul",
    "cpow",
    "csqrt",
]
|