gstaichi 2.1.1rc3__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. gstaichi/CHANGELOG.md +4 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +3 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cpython-310-darwin.so +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  11. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  12. gstaichi/_lib/utils.py +243 -0
  13. gstaichi/_logging.py +131 -0
  14. gstaichi/_snode/__init__.py +5 -0
  15. gstaichi/_snode/fields_builder.py +187 -0
  16. gstaichi/_snode/snode_tree.py +34 -0
  17. gstaichi/_test_tools/__init__.py +18 -0
  18. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  19. gstaichi/_test_tools/load_kernel_string.py +30 -0
  20. gstaichi/_test_tools/textwrap2.py +6 -0
  21. gstaichi/_version.py +1 -0
  22. gstaichi/_version_check.py +100 -0
  23. gstaichi/ad/__init__.py +3 -0
  24. gstaichi/ad/_ad.py +530 -0
  25. gstaichi/algorithms/__init__.py +3 -0
  26. gstaichi/algorithms/_algorithms.py +117 -0
  27. gstaichi/assets/.git +1 -0
  28. gstaichi/assets/Go-Regular.ttf +0 -0
  29. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  30. gstaichi/examples/lcg_python.py +26 -0
  31. gstaichi/examples/lcg_taichi.py +34 -0
  32. gstaichi/examples/minimal.py +28 -0
  33. gstaichi/experimental.py +16 -0
  34. gstaichi/lang/__init__.py +50 -0
  35. gstaichi/lang/_dataclass_util.py +31 -0
  36. gstaichi/lang/_fast_caching/__init__.py +3 -0
  37. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  38. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  39. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  40. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  41. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  42. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  43. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  44. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  45. gstaichi/lang/_ndarray.py +352 -0
  46. gstaichi/lang/_ndrange.py +152 -0
  47. gstaichi/lang/_template_mapper.py +195 -0
  48. gstaichi/lang/_texture.py +172 -0
  49. gstaichi/lang/_wrap_inspect.py +215 -0
  50. gstaichi/lang/any_array.py +99 -0
  51. gstaichi/lang/ast/__init__.py +5 -0
  52. gstaichi/lang/ast/ast_transformer.py +1323 -0
  53. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  54. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  55. gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
  56. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  57. gstaichi/lang/ast/checkers.py +106 -0
  58. gstaichi/lang/ast/symbol_resolver.py +57 -0
  59. gstaichi/lang/ast/transform.py +9 -0
  60. gstaichi/lang/common_ops.py +310 -0
  61. gstaichi/lang/exception.py +80 -0
  62. gstaichi/lang/expr.py +180 -0
  63. gstaichi/lang/field.py +428 -0
  64. gstaichi/lang/impl.py +1243 -0
  65. gstaichi/lang/kernel_arguments.py +155 -0
  66. gstaichi/lang/kernel_impl.py +1341 -0
  67. gstaichi/lang/matrix.py +1835 -0
  68. gstaichi/lang/matrix_ops.py +341 -0
  69. gstaichi/lang/matrix_ops_utils.py +190 -0
  70. gstaichi/lang/mesh.py +687 -0
  71. gstaichi/lang/misc.py +782 -0
  72. gstaichi/lang/ops.py +1494 -0
  73. gstaichi/lang/runtime_ops.py +13 -0
  74. gstaichi/lang/shell.py +35 -0
  75. gstaichi/lang/simt/__init__.py +5 -0
  76. gstaichi/lang/simt/block.py +94 -0
  77. gstaichi/lang/simt/grid.py +7 -0
  78. gstaichi/lang/simt/subgroup.py +191 -0
  79. gstaichi/lang/simt/warp.py +96 -0
  80. gstaichi/lang/snode.py +489 -0
  81. gstaichi/lang/source_builder.py +150 -0
  82. gstaichi/lang/struct.py +810 -0
  83. gstaichi/lang/util.py +312 -0
  84. gstaichi/linalg/__init__.py +8 -0
  85. gstaichi/linalg/matrixfree_cg.py +310 -0
  86. gstaichi/linalg/sparse_cg.py +59 -0
  87. gstaichi/linalg/sparse_matrix.py +303 -0
  88. gstaichi/linalg/sparse_solver.py +123 -0
  89. gstaichi/math/__init__.py +11 -0
  90. gstaichi/math/_complex.py +205 -0
  91. gstaichi/math/mathimpl.py +886 -0
  92. gstaichi/profiler/__init__.py +6 -0
  93. gstaichi/profiler/kernel_metrics.py +260 -0
  94. gstaichi/profiler/kernel_profiler.py +586 -0
  95. gstaichi/profiler/memory_profiler.py +15 -0
  96. gstaichi/profiler/scoped_profiler.py +36 -0
  97. gstaichi/sparse/__init__.py +3 -0
  98. gstaichi/sparse/_sparse_grid.py +77 -0
  99. gstaichi/tools/__init__.py +12 -0
  100. gstaichi/tools/diagnose.py +117 -0
  101. gstaichi/tools/np2ply.py +364 -0
  102. gstaichi/tools/vtk.py +38 -0
  103. gstaichi/types/__init__.py +19 -0
  104. gstaichi/types/annotations.py +52 -0
  105. gstaichi/types/compound_types.py +71 -0
  106. gstaichi/types/enums.py +49 -0
  107. gstaichi/types/ndarray_type.py +169 -0
  108. gstaichi/types/primitive_types.py +206 -0
  109. gstaichi/types/quant.py +88 -0
  110. gstaichi/types/texture_type.py +85 -0
  111. gstaichi/types/utils.py +11 -0
  112. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
  113. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
  114. gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
  115. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
  116. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
  117. gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
  118. gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
  119. gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  120. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
  121. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
  122. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  123. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  124. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  125. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  126. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  127. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  128. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  129. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  130. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  131. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  132. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  133. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  134. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  135. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  136. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  137. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  138. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  139. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  140. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  141. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  142. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  143. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  144. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  145. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  146. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  147. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  148. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  149. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  150. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  151. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  152. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  153. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  154. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  155. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  156. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  157. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  158. gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  159. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  160. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  161. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  162. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  163. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  164. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  165. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  166. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  167. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  168. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  169. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  170. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  171. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  172. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  173. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  174. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  175. gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
  176. gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
  177. gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
  178. gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
  179. gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,303 @@
1
+ # type: ignore
2
+
3
+ from functools import reduce
4
+
5
+ import numpy as np
6
+
7
+ from gstaichi._lib import core as _ti_core
8
+ from gstaichi.lang._ndarray import Ndarray, ScalarNdarray
9
+ from gstaichi.lang.exception import GsTaichiRuntimeError
10
+ from gstaichi.lang.field import Field
11
+ from gstaichi.lang.impl import get_runtime
12
+ from gstaichi.types import f32
13
+
14
+
15
class SparseMatrix:
    """GsTaichi's Sparse Matrix class

    A sparse matrix allows the programmer to solve a large linear system.

    Args:
        n (int): the first dimension (number of rows) of a sparse matrix.
        m (int): the second dimension (number of columns) of a sparse matrix.
            When falsy (e.g. ``None``), ``n`` is used, giving a square matrix.
        sm: an existing backend sparse matrix object (for example the ``.matrix``
            attribute of another ``SparseMatrix``, or the result of an operation on it)
            to wrap instead of allocating a new one. Its dimensions are read from
            ``sm.num_rows()`` and ``sm.num_cols()``.
        dtype (ti.dtype): the element data type. Defaults to ``f32``.
        storage_format (str): the storage format passed to the backend when a new
            matrix is created. Defaults to ``"col_major"``.
    """

    def __init__(self, n=None, m=None, sm=None, dtype=f32, storage_format="col_major"):
        self.dtype = dtype
        if sm is None:
            self.n = n
            self.m = m if m else n
            # Forward the *resolved* column count: the raw ``m`` is None for square
            # matrices created as SparseMatrix(n=N), which the backend cannot accept.
            self.matrix = get_runtime().prog.create_sparse_matrix(self.n, self.m, dtype, storage_format)
        else:
            self.n = sm.num_rows()
            self.m = sm.num_cols()
            self.matrix = sm

    def _wrap(self, sm):
        """Wrap backend result ``sm`` in a new SparseMatrix that keeps this matrix's dtype."""
        return SparseMatrix(sm=sm, dtype=self.dtype)

    def _check_same_shape(self, other):
        """Assert that ``other`` has exactly the same dimensions as this matrix."""
        assert (
            self.n == other.n and self.m == other.m
        ), f"Dimension mismatch between sparse matrices ({self.n}, {self.m}) and ({other.n}, {other.m})"

    @staticmethod
    def _is_scalar(value):
        """Return True for real Python/NumPy scalars usable as a scaling factor (bools excluded)."""
        if isinstance(value, bool):
            return False
        return isinstance(value, (int, float, np.integer, np.floating))

    def __iadd__(self, other):
        """In-place addition of another sparse matrix of the same shape.

        Returns:
            This sparse matrix, updated with the sum.
        """
        self._check_same_shape(other)
        self.matrix += other.matrix
        return self

    def __add__(self, other):
        """Addition operation for sparse matrix.

        Returns:
            The result sparse matrix of the addition (same dtype as ``self``).
        """
        self._check_same_shape(other)
        return self._wrap(self.matrix + other.matrix)

    def __isub__(self, other):
        """In-place subtraction of another sparse matrix of the same shape.

        Returns:
            This sparse matrix, updated with the difference.
        """
        self._check_same_shape(other)
        self.matrix -= other.matrix
        return self

    def __sub__(self, other):
        """Subtraction operation for sparse matrix.

        Returns:
            The result sparse matrix of the subtraction (same dtype as ``self``).
        """
        self._check_same_shape(other)
        return self._wrap(self.matrix - other.matrix)

    def __mul__(self, other):
        """Sparse matrix's multiplication against real numbers or the hadamard product against another matrix

        Args:
            other (int, float, NumPy scalar or SparseMatrix): the other operand of multiplication.
        Returns:
            The result of multiplication, or ``NotImplemented`` for unsupported operand types
            (so Python raises ``TypeError`` instead of silently producing ``None``).
        """
        if self._is_scalar(other):
            return self._wrap(float(other) * self.matrix)
        if isinstance(other, SparseMatrix):
            self._check_same_shape(other)
            return self._wrap(self.matrix * other.matrix)
        return NotImplemented

    def __rmul__(self, other):
        """Right scalar multiplication for sparse matrix.

        Args:
            other (int, float or NumPy scalar): the other operand of scalar multiplication.
        Returns:
            The result of multiplication, or ``NotImplemented`` for unsupported operand types.
        """
        if self._is_scalar(other):
            return self._wrap(self.matrix * float(other))
        return NotImplemented

    def transpose(self):
        """Sparse Matrix transpose.

        Returns:
            The transposed sparse matrix (same dtype as ``self``).
        """
        return self._wrap(self.matrix.transpose())

    def __matmul__(self, other):
        """Matrix multiplication.

        Args:
            other (SparseMatrix, Field, numpy.ndarray or ti.ndarray): the right-hand operand.
        Returns:
            A SparseMatrix for a SparseMatrix operand; the backend ``mat_vec_mul`` result for
            a Field or numpy operand; a new ScalarNdarray of length ``n`` for a ti.ndarray operand.
        Raises:
            GsTaichiRuntimeError: on a dimension mismatch with a ti.ndarray, or an unsupported type.
        """
        if isinstance(other, SparseMatrix):
            assert (
                self.m == other.n
            ), f"Dimension mismatch between sparse matrices ({self.n}, {self.m}) and ({other.n}, {other.m})"
            return self._wrap(self.matrix.matmul(other.matrix))
        if isinstance(other, Field):
            assert (
                self.m == other.shape[0]
            ), f"Dimension mismatch between sparse matrix ({self.n}, {self.m}) and vector ({other.shape})"
            return self.matrix.mat_vec_mul(other.to_numpy())
        if isinstance(other, np.ndarray):
            assert (
                self.m == other.shape[0]
            ), f"Dimension mismatch between sparse matrix ({self.n}, {self.m}) and vector ({other.shape})"
            return self.matrix.mat_vec_mul(other)
        if isinstance(other, Ndarray):
            if self.m != other.shape[0]:
                raise GsTaichiRuntimeError(
                    f"Dimension mismatch between sparse matrix ({self.n}, {self.m}) and vector ({other.shape})"
                )
            res = ScalarNdarray(dtype=other.dtype, arr_shape=(self.n,))
            self.matrix.spmv(get_runtime().prog, other.arr, res.arr)
            return res
        raise GsTaichiRuntimeError(
            f"Sparse matrix-matrix/vector multiplication does not support {type(other)} for now. Supported types are SparseMatrix, ti.field, ti.ndarray, and numpy ndarray."
        )

    def __getitem__(self, indices):
        """Return the element at ``indices = (row, col)``."""
        return self.matrix.get_element(indices[0], indices[1])

    def __setitem__(self, indices, value):
        """Set the element at ``indices = (row, col)`` to ``value``."""
        self.matrix.set_element(indices[0], indices[1], value)

    def __str__(self):
        """Python scope matrix print support."""
        return self.matrix.to_string()

    def __repr__(self):
        return self.matrix.to_string()

    @property
    def shape(self):
        """The shape of the sparse matrix as ``(n, m)``."""
        return (self.n, self.m)

    def build_from_ndarray(self, ndarray):
        """Build the sparse matrix from a ndarray.

        Args:
            ndarray (Union[ti.ndarray, ti.Vector.ndarray, ti.Matrix.ndarray]): the ndarray to build the sparse matrix from.

        Raises:
            GsTaichiRuntimeError: If the input is not a ndarray or the length is not divisible by 3.

        Example::

            >>> N = 5
            >>> triplets = ti.Vector.ndarray(n=3, dtype=ti.f32, shape=10, layout=ti.Layout.AOS)
            >>> @ti.kernel
            >>> def fill(triplets: ti.types.ndarray()):
            >>>     for i in range(N):
            >>>         triplets[i] = ti.Vector([i, (i + 1) % N, i+1], dt=ti.f32)
            >>> fill(triplets)
            >>> A = ti.linalg.SparseMatrix(n=N, m=N, dtype=ti.f32)
            >>> A.build_from_ndarray(triplets)
            >>> print(A)
            [0, 1, 0, 0, 0]
            [0, 0, 2, 0, 0]
            [0, 0, 0, 3, 0]
            [0, 0, 0, 0, 4]
            [5, 0, 0, 0, 0]
        """
        if not isinstance(ndarray, Ndarray):
            raise GsTaichiRuntimeError(
                "Sparse matrix only supports building from [ti.ndarray, ti.Vector.ndarray, ti.Matrix.ndarray]"
            )
        # Total scalar count = product of the outer shape and the per-element shape.
        # The initializer keeps this well-defined when both shapes are empty.
        dims = tuple(ndarray.shape) + tuple(ndarray.element_shape)
        num_scalars = reduce(lambda x, y: x * y, dims, 1)
        if num_scalars % 3 != 0:
            raise GsTaichiRuntimeError("The number of ndarray elements must have a length that is divisible by 3.")
        get_runtime().prog.make_sparse_matrix_from_ndarray(self.matrix, ndarray.arr)

    def mmwrite(self, filename):
        """Writes the sparse matrix to Matrix Market file-like target.

        Args:
            filename (str): the file name to write the sparse matrix to.
        """
        self.matrix.mmwrite(filename)
225
+
226
+
227
class SparseMatrixBuilder:
    """A python wrap around sparse matrix builder.

    Use this builder to fill the sparse matrix.

    Args:
        num_rows (int): the first dimension of a sparse matrix. When ``None``, no
            backend builder is allocated.
        num_cols (int): the second dimension of a sparse matrix. When falsy
            (e.g. ``None``), ``num_rows`` is used, giving a square matrix.
        max_num_triplets (int): the maximum number of triplets.
        dtype (ti.dtype): the data type of the sparse matrix.
        storage_format (str): the storage format of the sparse matrix.

    Raises:
        GsTaichiRuntimeError: if ``num_rows`` is given and the current arch is not
            x64, arm64 or CUDA.
    """

    def __init__(
        self,
        num_rows=None,
        num_cols=None,
        max_num_triplets=0,
        dtype=f32,
        storage_format="col_major",
    ):
        self.num_rows = num_rows
        self.num_cols = num_cols if num_cols else num_rows
        self.dtype = dtype
        # Backend handle. It stays None when no backend builder is created, so
        # __del__ can tell there is nothing to release.
        self.ptr = None
        if num_rows is not None:
            gstaichi_arch = get_runtime().prog.config().arch
            if gstaichi_arch in [
                _ti_core.Arch.x64,
                _ti_core.Arch.arm64,
                _ti_core.Arch.cuda,
            ]:
                self.ptr = _ti_core.SparseMatrixBuilder(
                    num_rows,
                    # Resolved column count: the raw ``num_cols`` may be None for square matrices.
                    self.num_cols,
                    max_num_triplets,
                    dtype,
                    storage_format,
                )
                self.ptr.create_ndarray(get_runtime().prog)
            else:
                raise GsTaichiRuntimeError("SparseMatrix only supports CPU and CUDA for now.")

    def _get_addr(self):
        """Get the address of the sparse matrix"""
        return self.ptr.get_addr()

    def _get_ndarray_addr(self):
        """Get the address of the ndarray"""
        return self.ptr.get_ndarray_data_ptr()

    def print_triplets(self):
        """Print the triplets stored in the builder (no-op on arches other than x64/arm64/CUDA)."""
        gstaichi_arch = get_runtime().prog.config().arch
        if gstaichi_arch in [_ti_core.Arch.x64, _ti_core.Arch.arm64]:
            self.ptr.print_triplets_eigen()
        elif gstaichi_arch == _ti_core.Arch.cuda:
            self.ptr.print_triplets_cuda()

    def build(self, dtype=f32, _format="CSR"):
        """Create a sparse matrix using the triplets.

        Args:
            dtype: unused; the result always uses the builder's own ``dtype``.
            _format: unused.

        Returns:
            SparseMatrix: the assembled matrix.

        Raises:
            GsTaichiRuntimeError: on CUDA when the builder's dtype is not f32, or on
                any backend other than x64, arm64 or CUDA.
        """
        gstaichi_arch = get_runtime().prog.config().arch
        if gstaichi_arch in [_ti_core.Arch.x64, _ti_core.Arch.arm64]:
            sm = self.ptr.build()
            return SparseMatrix(sm=sm, dtype=self.dtype)
        if gstaichi_arch == _ti_core.Arch.cuda:
            if self.dtype != f32:
                raise GsTaichiRuntimeError("CUDA sparse matrix only supports f32.")
            sm = self.ptr.build_cuda()
            return SparseMatrix(sm=sm, dtype=self.dtype)
        raise GsTaichiRuntimeError("Sparse matrix only supports CPU and CUDA backends.")

    def __del__(self):
        # __init__ may not have created a backend builder (num_rows=None, or it raised
        # part-way); in that case there is nothing to release.
        ptr = getattr(self, "ptr", None)
        if ptr is None:
            return
        runtime = get_runtime()
        if runtime is not None and runtime.prog is not None:
            ptr.delete_ndarray(runtime.prog)
301
+
302
+
303
# Public names exported by ``from ... import *`` on this module.
__all__ = ["SparseMatrix", "SparseMatrixBuilder"]
@@ -0,0 +1,123 @@
1
+ # type: ignore
2
+
3
+ import numpy as np
4
+
5
+ import gstaichi.lang
6
+ from gstaichi._lib import core as _ti_core
7
+ from gstaichi.lang._ndarray import Ndarray, ScalarNdarray
8
+ from gstaichi.lang.exception import GsTaichiRuntimeError
9
+ from gstaichi.lang.field import Field
10
+ from gstaichi.lang.impl import get_runtime
11
+ from gstaichi.linalg.sparse_matrix import SparseMatrix
12
+ from gstaichi.types.primitive_types import f32
13
+
14
+
15
class SparseSolver:
    """Sparse linear system solver.

    Use this class to solve linear systems represented by sparse matrices.

    Args:
        dtype (ti.dtype): The data type the solver is created with. ``analyze_pattern``
            rejects matrices whose ``dtype`` differs. Defaults to ``f32``.
        solver_type (str): The factorization type: ``"LLT"``, ``"LDLT"`` or ``"LU"``.
        ordering (str): The method for matrices re-ordering: ``"AMD"`` or ``"COLAMD"``.

    Raises:
        GsTaichiRuntimeError: If ``solver_type`` or ``ordering`` is not a supported value.
    """

    def __init__(self, dtype=f32, solver_type="LLT", ordering="AMD"):
        self.matrix = None
        self.dtype = dtype
        solver_type_list = ["LLT", "LDLT", "LU"]
        solver_ordering = ["AMD", "COLAMD"]
        supported = solver_type in solver_type_list and ordering in solver_ordering
        if not supported:
            raise GsTaichiRuntimeError(
                f"The solver type {solver_type} with {ordering} is not supported for now. Only {solver_type_list} with {solver_ordering} are supported."
            )
        arch = gstaichi.lang.impl.get_runtime().prog.config().arch
        assert arch in (
            _ti_core.Arch.x64,
            _ti_core.Arch.arm64,
            _ti_core.Arch.cuda,
        ), "SparseSolver only supports CPU and CUDA for now."
        # Pick the backend factory: cuSPARSE-based on CUDA, the generic one on CPU.
        if arch == _ti_core.Arch.cuda:
            make_solver = _ti_core.make_cusparse_solver
        else:
            make_solver = _ti_core.make_sparse_solver
        self.solver = make_solver(dtype, solver_type, ordering)

    @staticmethod
    def _type_assert(sparse_matrix):
        """Raise a GsTaichiRuntimeError reporting an unsupported operand type."""
        message = f"The parameter type: {type(sparse_matrix)} is not supported in linear solvers for now."
        raise GsTaichiRuntimeError(message)

    def compute(self, sparse_matrix):
        """This method is equivalent to calling both `analyze_pattern` and then `factorize`.

        On x64/arm64 the backend ``compute`` is called directly; on CUDA the two steps
        are run explicitly.

        Args:
            sparse_matrix (SparseMatrix): The sparse matrix to be computed.
        """
        if not isinstance(sparse_matrix, SparseMatrix):
            return self._type_assert(sparse_matrix)
        self.matrix = sparse_matrix
        arch = gstaichi.lang.impl.get_runtime().prog.config().arch
        if arch in (_ti_core.Arch.x64, _ti_core.Arch.arm64):
            self.solver.compute(sparse_matrix.matrix)
        elif arch == _ti_core.Arch.cuda:
            self.analyze_pattern(self.matrix)
            self.factorize(self.matrix)

    def analyze_pattern(self, sparse_matrix):
        """Reorder the nonzero elements of the matrix, such that the factorization step creates less fill-in.

        Args:
            sparse_matrix (SparseMatrix): The sparse matrix to be analyzed.

        Raises:
            GsTaichiRuntimeError: If the matrix dtype differs from the solver dtype,
                or the argument is not a SparseMatrix.
        """
        if not isinstance(sparse_matrix, SparseMatrix):
            return self._type_assert(sparse_matrix)
        self.matrix = sparse_matrix
        if self.matrix.dtype != self.dtype:
            raise GsTaichiRuntimeError(
                f"The SparseSolver's dtype {self.dtype} is not consistent with the SparseMatrix's dtype {self.matrix.dtype}."
            )
        self.solver.analyze_pattern(sparse_matrix.matrix)

    def factorize(self, sparse_matrix):
        """Do the factorization step.

        Args:
            sparse_matrix (SparseMatrix): The sparse matrix to be factorized.
        """
        if not isinstance(sparse_matrix, SparseMatrix):
            return self._type_assert(sparse_matrix)
        self.matrix = sparse_matrix
        self.solver.factorize(sparse_matrix.matrix)

    def solve(self, b):  # pylint: disable=R1710
        """Computes the solution of the linear systems.

        Args:
            b (numpy.array, Field or ti.ndarray): The right-hand side of the linear systems.

        Returns:
            numpy.array for a numpy/Field right-hand side; a new ScalarNdarray of length
            ``m`` for a ti.ndarray right-hand side.

        Raises:
            GsTaichiRuntimeError: If no matrix has been set yet, or ``b`` has an unsupported type.
        """
        if self.matrix is None:
            raise GsTaichiRuntimeError("Please call compute() before calling solve().")
        if isinstance(b, Field):
            return self.solver.solve(b.to_numpy())
        if isinstance(b, np.ndarray):
            return self.solver.solve(b)
        if isinstance(b, Ndarray):
            solution = ScalarNdarray(b.dtype, [self.matrix.m])
            self.solver.solve_rf(get_runtime().prog, self.matrix.matrix, b.arr, solution.arr)
            return solution
        raise GsTaichiRuntimeError(f"The parameter type: {type(b)} is not supported in linear solvers for now.")

    def info(self):
        """Check if the linear systems are solved successfully.

        Returns:
            bool: True if the solving process succeeded, False otherwise.
        """
        return self.solver.info()
@@ -0,0 +1,11 @@
1
+ # type: ignore
2
+
3
+ """GsTaichi math module.
4
+
5
+ The math module supports glsl-style vectors, matrices and functions.
6
+ """
7
+
8
+ from ._complex import *
9
+ from .mathimpl import * # pylint: disable=W0622
10
+
11
# The relative star-import above also binds the submodule name ``mathimpl`` in
# this package's namespace; remove it so it is not exposed as a package attribute.
del mathimpl
@@ -0,0 +1,205 @@
1
+ # type: ignore
2
+
3
+ from gstaichi.lang import ops
4
+ from gstaichi.lang.kernel_impl import func
5
+
6
+ from .mathimpl import dot, vec2
7
+
8
+
9
@func
def cmul(z1, z2):
    """Multiplies two complex numbers stored as 2d vectors.

    With ``z1 = (x1, y1)`` read as ``x1 + i*y1`` and ``z2 = (x2, y2)`` read as
    ``x2 + i*y2``, this returns their product in the complex number field.

    Args:
        z1 (:class:`~gstaichi.math.vec2`): The first input.
        z2 (:class:`~gstaichi.math.vec2`): The second input.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     z1 = ti.math.vec2(1, 1)
        >>>     z2 = ti.math.vec2(0, 1)
        >>>     ti.math.cmul(z1, z2)  # [-1, 1]

    Returns:
        :class:`~gstaichi.math.vec2`: the complex product `z1 * z2`.
    """
    re1 = z1[0]
    im1 = z1[1]
    re2 = z2[0]
    im2 = z2[1]
    real_part = re1 * re2 - im1 * im2
    imag_part = re1 * im2 + re2 * im1
    return vec2(real_part, imag_part)
34
+
35
+
36
@func
def cconj(z):
    """Returns the complex conjugate of a 2d vector.

    The real part is kept and the imaginary part is negated: `(x, y)` becomes `(x, -y)`.

    Args:
        z (:class:`~gstaichi.math.vec2`): The input.

    Returns:
        :class:`~gstaichi.math.vec2`: The complex conjugate of `z`.
    """
    real_part = z[0]
    flipped_imag = -z[1]
    return vec2(real_part, flipped_imag)
49
+
50
+
51
@func
def cdiv(z1, z2):
    """Divides two complex numbers stored as 2d vectors.

    Computes `z1 / z2` in the complex number field as
    `z1 * conj(z2) / |z2|^2`.

    Args:
        z1 (:class:`~gstaichi.math.vec2`): The numerator.
        z2 (:class:`~gstaichi.math.vec2`): The denominator.

    Example::

        >>> @ti.kernel
        >>> def test():
        >>>     z1 = ti.math.vec2(1, 1)
        >>>     z2 = ti.math.vec2(0, 1)
        >>>     ti.math.cdiv(z1, z2)  # [1, -1]

    Returns:
        :class:`~gstaichi.math.vec2`: the complex quotient `z1 / z2`.
    """
    re1 = z1[0]
    im1 = z1[1]
    re2 = z2[0]
    im2 = z2[1]
    # z1 * conj(z2), then scale by the squared modulus of z2.
    numerator = vec2(re1 * re2 + im1 * im2, -re1 * im2 + re2 * im1)
    return numerator / dot(z2, z2)
76
+
77
+
78
+ @func
79
+ def csqrt(z):
80
+ """Returns the complex square root of a 2d vector `z`, so that
81
+ if `w^2=z`, then `w = csqrt(z)`.
82
+
83
+ Among the two square roots of `z`, if their real parts are non-zero,
84
+ the one with positive real part is returned. If both their real parts
85
+ are zero, the one with non-negative imaginary part is returned.
86
+
87
+ Args:
88
+ z (:class:`~gstaichi.math.vec2`): The input.
89
+
90
+ Example::
91
+
92
+ >>> @ti.kernel
93
+ >>> def test():
94
+ >>> z = ti.math.vec2(-1, 0)
95
+ >>> w = ti.math.csqrt(z) # [0, 1]
96
+
97
+ Returns:
98
+ :class:`~gstaichi.math.vec2`: The complex square root.
99
+ """
100
+ result = vec2(0.0)
101
+ if any(z):
102
+ r = ops.sqrt(z.norm())
103
+ a = ops.atan2(z[1], z[0])
104
+ result = r * vec2(ops.cos(a / 2.0), ops.sin(a / 2.0))
105
+
106
+ return result
107
+
108
+
109
+ @func
110
+ def cinv(z):
111
+ """Computes the reciprocal of a complex `z`.
112
+
113
+ Args:
114
+ z (:class:`~gstaichi.math.vec2`): The input.
115
+
116
+ Example::
117
+
118
+ >>> @ti.kernel
119
+ >>> def test():
120
+ >>> z = ti.math.vec2(1, 1)
121
+ >>> w = ti.math.cinv(z) # [0.5, -0.5]
122
+
123
+ Returns:
124
+ :class:`~gstaichi.math.vec2`: The reciprocal of `z`.
125
+ """
126
+ return cconj(z) / dot(z, z)
127
+
128
+
129
+ @func
130
+ def cpow(z, n):
131
+ """Computes the power of a complex `z`: :math:`z^a`.
132
+
133
+ Args:
134
+ z (:class:`~gstaichi.math.vec2`): The base.
135
+ a (float): The exponent.
136
+
137
+ Example::
138
+
139
+ >>> @ti.kernel
140
+ >>> def test():
141
+ >>> z = ti.math.vec2(1, 1)
142
+ >>> w = ti.math.cpow(z) # [-2, 2]
143
+
144
+ Returns:
145
+ :class:`~gstaichi.math.vec2`: The power :math:`z^a`.
146
+ """
147
+ result = vec2(0.0)
148
+ if any(z):
149
+ r2 = dot(z, z)
150
+ a = ops.atan2(z[1], z[0]) * n
151
+ result = ops.pow(r2, n / 2.0) * vec2(ops.cos(a), ops.sin(a))
152
+
153
+ return result
154
+
155
+
156
+ @func
157
+ def cexp(z):
158
+ """Returns the complex exponential :math:`e^z`.
159
+
160
+ `z` is a 2d vector treated as a complex number.
161
+
162
+ Args:
163
+ z (:class:`~gstaichi.math.vec2`): The exponent.
164
+
165
+ Example::
166
+
167
+ >>> @ti.kernel
168
+ >>> def test():
169
+ >>> z = ti.math.vec2(1, 1)
170
+ >>> w = ti.math.cexp(z) # [1.468694, 2.287355]
171
+
172
+ Returns:
173
+ :class:`~gstaichi.math.vec2`: The power :math:`exp(z)`
174
+ """
175
+ r = ops.exp(z[0])
176
+ return vec2(r * ops.cos(z[1]), r * ops.sin(z[1]))
177
+
178
+
179
+ @func
180
+ def clog(z):
181
+ """Returns the complex logarithm of `z`, so that if :math:`e^w = z`,
182
+ then :math:`log(z) = w`.
183
+
184
+ `z` is a 2d vector treated as a complex number. The argument of :math:`w`
185
+ lies in the range (-pi, pi].
186
+
187
+ Args:
188
+ z (:class:`~gstaichi.math.vec2`): The input.
189
+
190
+ Example::
191
+
192
+ >>> @ti.kernel
193
+ >>> def test():
194
+ >>> z = ti.math.vec2(1, 1)
195
+ >>> w = ti.math.clog(z) # [0.346574, 0.785398]
196
+
197
+ Returns:
198
+ :class:`~gstaichi.math.vec2`: The logarithm of `z`.
199
+ """
200
+ ang = ops.atan2(z[1], z[0])
201
+ r2 = dot(z, z)
202
+ return vec2(ops.log(r2) / 2.0, ang)
203
+
204
+
205
+ __all__ = ["cconj", "cdiv", "cexp", "cinv", "clog", "cmul", "cpow", "csqrt"]