gstaichi 0.1.25.dev0__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138) hide show
  1. gstaichi/CHANGELOG.md +9 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/__main__.py +5 -0
  4. gstaichi/_funcs.py +706 -0
  5. gstaichi/_kernels.py +420 -0
  6. gstaichi/_lib/__init__.py +3 -0
  7. gstaichi/_lib/core/__init__.py +0 -0
  8. gstaichi/_lib/core/gstaichi_python.cp312-win_amd64.pyd +0 -0
  9. gstaichi/_lib/core/gstaichi_python.pyi +2937 -0
  10. gstaichi/_lib/core/py.typed +0 -0
  11. gstaichi/_lib/runtime/runtime_cuda.bc +0 -0
  12. gstaichi/_lib/runtime/runtime_x64.bc +0 -0
  13. gstaichi/_lib/runtime/slim_libdevice.10.bc +0 -0
  14. gstaichi/_lib/utils.py +249 -0
  15. gstaichi/_logging.py +131 -0
  16. gstaichi/_main.py +545 -0
  17. gstaichi/_snode/__init__.py +5 -0
  18. gstaichi/_snode/fields_builder.py +187 -0
  19. gstaichi/_snode/snode_tree.py +34 -0
  20. gstaichi/_test_tools/__init__.py +0 -0
  21. gstaichi/_test_tools/load_kernel_string.py +30 -0
  22. gstaichi/_version.py +1 -0
  23. gstaichi/_version_check.py +103 -0
  24. gstaichi/ad/__init__.py +3 -0
  25. gstaichi/ad/_ad.py +530 -0
  26. gstaichi/algorithms/__init__.py +3 -0
  27. gstaichi/algorithms/_algorithms.py +117 -0
  28. gstaichi/assets/.git +1 -0
  29. gstaichi/assets/Go-Regular.ttf +0 -0
  30. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  31. gstaichi/examples/minimal.py +28 -0
  32. gstaichi/experimental.py +16 -0
  33. gstaichi/lang/__init__.py +50 -0
  34. gstaichi/lang/_ndarray.py +352 -0
  35. gstaichi/lang/_ndrange.py +152 -0
  36. gstaichi/lang/_template_mapper.py +199 -0
  37. gstaichi/lang/_texture.py +172 -0
  38. gstaichi/lang/_wrap_inspect.py +189 -0
  39. gstaichi/lang/any_array.py +99 -0
  40. gstaichi/lang/argpack.py +411 -0
  41. gstaichi/lang/ast/__init__.py +5 -0
  42. gstaichi/lang/ast/ast_transformer.py +1318 -0
  43. gstaichi/lang/ast/ast_transformer_utils.py +341 -0
  44. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  45. gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
  46. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
  47. gstaichi/lang/ast/checkers.py +106 -0
  48. gstaichi/lang/ast/symbol_resolver.py +57 -0
  49. gstaichi/lang/ast/transform.py +9 -0
  50. gstaichi/lang/common_ops.py +310 -0
  51. gstaichi/lang/exception.py +80 -0
  52. gstaichi/lang/expr.py +180 -0
  53. gstaichi/lang/field.py +466 -0
  54. gstaichi/lang/impl.py +1241 -0
  55. gstaichi/lang/kernel_arguments.py +157 -0
  56. gstaichi/lang/kernel_impl.py +1382 -0
  57. gstaichi/lang/matrix.py +1881 -0
  58. gstaichi/lang/matrix_ops.py +341 -0
  59. gstaichi/lang/matrix_ops_utils.py +190 -0
  60. gstaichi/lang/mesh.py +687 -0
  61. gstaichi/lang/misc.py +778 -0
  62. gstaichi/lang/ops.py +1494 -0
  63. gstaichi/lang/runtime_ops.py +13 -0
  64. gstaichi/lang/shell.py +35 -0
  65. gstaichi/lang/simt/__init__.py +5 -0
  66. gstaichi/lang/simt/block.py +94 -0
  67. gstaichi/lang/simt/grid.py +7 -0
  68. gstaichi/lang/simt/subgroup.py +191 -0
  69. gstaichi/lang/simt/warp.py +96 -0
  70. gstaichi/lang/snode.py +489 -0
  71. gstaichi/lang/source_builder.py +150 -0
  72. gstaichi/lang/struct.py +855 -0
  73. gstaichi/lang/util.py +381 -0
  74. gstaichi/linalg/__init__.py +8 -0
  75. gstaichi/linalg/matrixfree_cg.py +310 -0
  76. gstaichi/linalg/sparse_cg.py +59 -0
  77. gstaichi/linalg/sparse_matrix.py +303 -0
  78. gstaichi/linalg/sparse_solver.py +123 -0
  79. gstaichi/math/__init__.py +11 -0
  80. gstaichi/math/_complex.py +205 -0
  81. gstaichi/math/mathimpl.py +886 -0
  82. gstaichi/profiler/__init__.py +6 -0
  83. gstaichi/profiler/kernel_metrics.py +260 -0
  84. gstaichi/profiler/kernel_profiler.py +586 -0
  85. gstaichi/profiler/memory_profiler.py +15 -0
  86. gstaichi/profiler/scoped_profiler.py +36 -0
  87. gstaichi/sparse/__init__.py +3 -0
  88. gstaichi/sparse/_sparse_grid.py +77 -0
  89. gstaichi/tools/__init__.py +12 -0
  90. gstaichi/tools/diagnose.py +117 -0
  91. gstaichi/tools/np2ply.py +364 -0
  92. gstaichi/tools/vtk.py +38 -0
  93. gstaichi/types/__init__.py +19 -0
  94. gstaichi/types/annotations.py +47 -0
  95. gstaichi/types/compound_types.py +90 -0
  96. gstaichi/types/enums.py +49 -0
  97. gstaichi/types/ndarray_type.py +147 -0
  98. gstaichi/types/primitive_types.py +206 -0
  99. gstaichi/types/quant.py +88 -0
  100. gstaichi/types/texture_type.py +85 -0
  101. gstaichi/types/utils.py +13 -0
  102. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +5 -0
  103. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +29 -0
  104. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +113 -0
  105. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +5 -0
  106. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +19 -0
  107. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +122 -0
  108. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +5 -0
  109. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +19 -0
  110. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +122 -0
  111. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +5 -0
  112. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +19 -0
  113. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +122 -0
  114. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +5 -0
  115. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +19 -0
  116. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +122 -0
  117. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +5 -0
  118. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  119. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +122 -0
  120. gstaichi-0.1.25.dev0.data/data/bin/SPIRV-Tools-shared.dll +0 -0
  121. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/instrument.hpp +268 -0
  122. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.h +907 -0
  123. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.hpp +375 -0
  124. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/linker.hpp +97 -0
  125. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/optimizer.hpp +970 -0
  126. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-diff.lib +0 -0
  127. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-link.lib +0 -0
  128. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-lint.lib +0 -0
  129. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-opt.lib +0 -0
  130. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-reduce.lib +0 -0
  131. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-shared.lib +0 -0
  132. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools.lib +0 -0
  133. gstaichi-0.1.25.dev0.dist-info/METADATA +105 -0
  134. gstaichi-0.1.25.dev0.dist-info/RECORD +138 -0
  135. gstaichi-0.1.25.dev0.dist-info/WHEEL +5 -0
  136. gstaichi-0.1.25.dev0.dist-info/entry_points.txt +2 -0
  137. gstaichi-0.1.25.dev0.dist-info/licenses/LICENSE +201 -0
  138. gstaichi-0.1.25.dev0.dist-info/top_level.txt +1 -0
gstaichi/lang/util.py ADDED
@@ -0,0 +1,381 @@
1
+ # type: ignore
2
+
3
+ import functools
4
+ import os
5
+ import traceback
6
+ import warnings
7
+
8
+ import numpy as np
9
+ from colorama import Fore, Style
10
+
11
+ from gstaichi._lib import core as _ti_core
12
+ from gstaichi._logging import is_logging_effective
13
+ from gstaichi.lang import impl
14
+ from gstaichi.types.primitive_types import (
15
+ f16,
16
+ f32,
17
+ f64,
18
+ i8,
19
+ i16,
20
+ i32,
21
+ i64,
22
+ u1,
23
+ u8,
24
+ u16,
25
+ u32,
26
+ u64,
27
+ )
28
+
29
+
30
def has_pytorch():
    """Report whether PyTorch can be imported in the current environment.

    The import probe is skipped when the ``TI_ENABLE_TORCH`` environment
    variable holds an integer equal to zero (for example ``"0"``). An unset
    or empty variable leaves the probe enabled.

    Returns:
        bool: True if the probe is enabled and ``import torch`` succeeds,
        otherwise False.

    Raises:
        ValueError: If ``TI_ENABLE_TORCH`` is non-empty and not an integer.
    """
    env_flag = os.environ.get("TI_ENABLE_TORCH", "1")
    if env_flag and not int(env_flag):
        return False
    try:
        import torch  # pylint: disable=C0415
    except Exception:  # pylint: disable=W0703
        # A broken installation can fail with more than ImportError, so any
        # ordinary exception means "not usable". KeyboardInterrupt and
        # SystemExit are deliberately left to propagate.
        return False
    return True
47
+
48
+
49
def has_paddle():
    """Report whether Paddle can be imported in the current environment.

    The import probe is skipped when the ``TI_ENABLE_PADDLE`` environment
    variable holds an integer equal to zero (for example ``"0"``). An unset
    or empty variable leaves the probe enabled.

    Returns:
        bool: True if the probe is enabled and ``import paddle`` succeeds,
        otherwise False.

    Raises:
        ValueError: If ``TI_ENABLE_PADDLE`` is non-empty and not an integer.
    """
    env_flag = os.environ.get("TI_ENABLE_PADDLE", "1")
    if env_flag and not int(env_flag):
        return False
    try:
        import paddle  # pylint: disable=C0415
    except Exception:  # pylint: disable=W0703
        # Any ordinary import-time failure means "not usable"; interpreter
        # exits and Ctrl-C are deliberately left to propagate.
        return False
    return True
65
+
66
+
67
def get_clangpp():
    """Locate a supported ``clang++`` executable on ``PATH``.

    Returns:
        str | None: Path of the first candidate found, or None when no
        candidate is available.
    """
    # ``distutils`` was removed from the standard library in Python 3.12, so
    # the lookup uses ``shutil.which`` instead of ``find_executable``.
    import shutil  # pylint: disable=C0415

    # GsTaichi itself uses llvm-10.0.0 to compile.
    # There will be some issues compiling CUDA with other clang++ version.
    candidates = ("clang++-10",)
    for candidate in candidates:
        found = shutil.which(candidate)
        if found is not None:
            return found
    return None
78
+
79
+
80
def has_clangpp():
    """Return True when ``get_clangpp()`` finds an executable, else False."""
    clangpp_path = get_clangpp()
    return clangpp_path is not None
82
+
83
+
84
def is_matrix_class(rhs):
    """Tell whether ``rhs`` carries a truthy ``_is_matrix_class`` marker.

    Args:
        rhs: Any object.

    Returns:
        bool: True if ``rhs._is_matrix_class`` exists and is truthy; False if
        it is falsy, missing, or reading it raises an ordinary exception.
    """
    try:
        return bool(rhs._is_matrix_class)
    except Exception:  # pylint: disable=W0703
        # Attribute access may run arbitrary __getattr__/property code; treat
        # ordinary failures as "no marker" but let KeyboardInterrupt and
        # SystemExit propagate.
        return False
92
+
93
+
94
def is_gstaichi_class(rhs):
    """Tell whether ``rhs`` carries a truthy ``_is_gstaichi_class`` marker.

    Args:
        rhs: Any object.

    Returns:
        bool: True if ``rhs._is_gstaichi_class`` exists and is truthy; False
        if it is falsy, missing, or reading it raises an ordinary exception.
    """
    try:
        return bool(rhs._is_gstaichi_class)
    except Exception:  # pylint: disable=W0703
        # Attribute access may run arbitrary __getattr__/property code; treat
        # ordinary failures as "no marker" but let KeyboardInterrupt and
        # SystemExit propagate.
        return False
102
+
103
+
104
def to_numpy_type(dt):
    """Convert gstaichi data type to its counterpart in numpy.

    Args:
        dt (DataType): The desired data type to convert.

    Returns:
        DataType: The counterpart data type in numpy.

    Raises:
        AssertionError: If ``dt`` matches none of the supported primitive types.
    """
    # Pairs are compared with ``==`` in the same order as the original
    # if-chain, so no hashing of the gstaichi types is required.
    numpy_counterparts = (
        (f32, np.float32),
        (f64, np.float64),
        (i32, np.int32),
        (i64, np.int64),
        (i8, np.int8),
        (i16, np.int16),
        (u1, np.bool_),
        (u8, np.uint8),
        (u16, np.uint16),
        (u32, np.uint32),
        (u64, np.uint64),
        (f16, np.half),
    )
    for gs_type, np_type in numpy_counterparts:
        if dt == gs_type:
            return np_type
    # Raised explicitly: a bare ``assert False`` is stripped under ``python -O``
    # and the function would then silently return None.
    raise AssertionError(f"Unknown type {dt}")
139
+
140
+
141
def to_pytorch_type(dt):
    """Convert gstaichi data type to its counterpart in torch.

    Args:
        dt (DataType): The desired data type to convert.

    Returns:
        DataType: The counterpart data type in torch.

    Raises:
        RuntimeError: If ``dt`` has no torch counterpart, or is u16/u32/u64
            and the installed torch has no ``uint16`` attribute.
    """
    import torch  # pylint: disable=C0415

    # Torch attributes are looked up by name only when their gstaichi type is
    # actually compared, mirroring a sequential if-chain.
    always_available = (
        (f32, "float32"),
        (f64, "float64"),
        (i32, "int32"),
        (i64, "int64"),
        (i8, "int8"),
        (i16, "int16"),
        (u1, "bool"),
        (u8, "uint8"),
        (f16, "float16"),
    )
    for gs_type, torch_name in always_available:
        if dt == gs_type:
            return getattr(torch, torch_name)

    if dt not in (u16, u32, u64):
        raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type.")

    # Wider unsigned integers are used only when torch exposes ``uint16``.
    if hasattr(torch, "uint16"):
        for gs_type, torch_name in ((u16, "uint16"), (u32, "uint32"), (u64, "uint64")):
            if dt == gs_type:
                return getattr(torch, torch_name)
    raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.")
184
+
185
+
186
def to_paddle_type(dt):
    """Convert gstaichi data type to its counterpart in paddle.

    Args:
        dt (DataType): The desired data type to convert.

    Returns:
        DataType: The counterpart data type in paddle.

    Raises:
        RuntimeError: If ``dt`` is u16, u32 or u64.
        AssertionError: If ``dt`` matches none of the known primitive types.
    """
    import paddle  # pylint: disable=C0415

    # Paddle attributes are resolved lazily by name, in the same order as the
    # original if-chain.
    paddle_counterparts = (
        (f32, "float32"),
        (f64, "float64"),
        (i32, "int32"),
        (i64, "int64"),
        (i8, "int8"),
        (i16, "int16"),
        (u1, "bool"),
        (u8, "uint8"),
        (f16, "float16"),
    )
    for gs_type, paddle_name in paddle_counterparts:
        if dt == gs_type:
            return getattr(paddle, paddle_name)
    if dt in (u16, u32, u64):
        raise RuntimeError(f"Paddle doesn't support {dt.to_string()} data type.")
    # Raised explicitly: a bare ``assert False`` is stripped under ``python -O``
    # and the function would then silently return None.
    raise AssertionError(f"Unknown type {dt}")
219
+
220
+
221
def to_gstaichi_type(dt):
    """Convert numpy or torch or paddle data type to its counterpart in gstaichi.

    Args:
        dt (DataType): The desired data type to convert. A gstaichi
            ``DataTypeCxx`` is returned unchanged.

    Returns:
        DataType: The counterpart data type in gstaichi.

    Raises:
        AssertionError: If ``dt`` matches no numpy type and no type of an
            importable torch/paddle installation.
    """
    if isinstance(dt, _ti_core.DataTypeCxx):
        return dt

    numpy_counterparts = (
        (np.float32, f32),
        (np.float64, f64),
        (np.int32, i32),
        (np.int64, i64),
        (np.int8, i8),
        (np.int16, i16),
        (np.bool_, u1),
        (np.uint8, u8),
        (np.uint16, u16),
        (np.uint32, u32),
        (np.uint64, u64),
        (np.half, f16),
    )
    for np_type, gs_type in numpy_counterparts:
        if dt == np_type:
            return gs_type

    if has_pytorch():
        import torch  # pylint: disable=C0415

        torch_counterparts = [
            ("float32", f32),
            ("float64", f64),
            ("int32", i32),
            ("int64", i64),
            ("int8", i8),
            ("int16", i16),
            ("bool", u1),
            ("uint8", u8),
            ("float16", f16),
        ]
        # Wider unsigned integer dtypes are consulted only when torch has them.
        if hasattr(torch, "uint16"):
            torch_counterparts += [("uint16", u16), ("uint32", u32), ("uint64", u64)]
        for torch_name, gs_type in torch_counterparts:
            if dt == getattr(torch, torch_name):
                return gs_type
        # No raise here: a dtype torch does not recognise may still be a
        # paddle dtype, and ``dt`` is not a gstaichi type at this point, so it
        # has no ``to_string()`` to build an error message from.

    if has_paddle():
        import paddle  # pylint: disable=C0415

        paddle_counterparts = (
            ("float32", f32),
            ("float64", f64),
            ("int32", i32),
            ("int64", i64),
            ("int8", i8),
            ("int16", i16),
            ("bool", u1),
            ("uint8", u8),
            ("float16", f16),
        )
        for paddle_name, gs_type in paddle_counterparts:
            if dt == getattr(paddle, paddle_name):
                return gs_type

    raise AssertionError(f"Unknown type {dt}")
317
+
318
+
319
def cook_dtype(dtype):
    """Normalise a dtype specification to a gstaichi data type.

    Args:
        dtype: A ``DataTypeCxx``, a ``_ti_core.Type``, or one of the Python
            builtins ``float``, ``int`` or ``bool``.

    Returns:
        ``dtype`` itself when it already is a ``DataTypeCxx``; a
        ``DataTypeCxx`` wrapping it when it is a ``_ti_core.Type``; the
        runtime's ``default_fp`` / ``default_ip`` for ``float`` / ``int``;
        ``u1`` for ``bool``.

    Raises:
        ValueError: For any other input.
    """
    if isinstance(dtype, _ti_core.DataTypeCxx):
        return dtype
    if isinstance(dtype, _ti_core.Type):
        return _ti_core.DataTypeCxx(dtype)
    if dtype is bool:
        return u1
    if dtype is float or dtype is int:
        # The runtime is consulted only for the two builtins that need it.
        runtime = impl.get_runtime()
        return runtime.default_fp if dtype is float else runtime.default_ip
    raise ValueError(f"Invalid data type {dtype}")
331
+
332
+
333
def in_gstaichi_scope():
    """Return whatever ``impl.inside_kernel()`` reports."""
    inside = impl.inside_kernel()
    return inside
335
+
336
+
337
def in_python_scope():
    """Return the logical negation of ``in_gstaichi_scope()``."""
    inside_kernel = in_gstaichi_scope()
    return not inside_kernel
339
+
340
+
341
def gstaichi_scope(func):
    """Decorator asserting that ``func`` is only called in GsTaichi-scope.

    Args:
        func (Callable): The function to guard.

    Returns:
        Callable: A wrapper carrying ``func``'s metadata that asserts
        ``in_gstaichi_scope()`` before delegating to ``func``.
    """

    def _guarded(*args, **kwargs):
        assert in_gstaichi_scope(), f"{func.__name__} cannot be called in Python-scope"
        return func(*args, **kwargs)

    return functools.wraps(func)(_guarded)
348
+
349
+
350
def python_scope(func):
    """Decorator asserting that ``func`` is only called in Python-scope.

    Args:
        func (Callable): The function to guard.

    Returns:
        Callable: A wrapper carrying ``func``'s metadata that asserts
        ``in_python_scope()`` before delegating to ``func``.
    """

    def _guarded(*args, **kwargs):
        assert in_python_scope(), f"{func.__name__} cannot be called in GsTaichi-scope"
        return func(*args, **kwargs)

    return functools.wraps(func)(_guarded)
357
+
358
+
359
def warning(msg, warning_type=UserWarning, stacklevel=1, print_stack=True):
    """Emit a highlighted warning, optionally followed by a stack trace.

    Nothing is emitted unless ``is_logging_effective("warn")`` is true.

    Args:
        msg (str): message to print.
        warning_type (Type[Warning]): type of warning.
        stacklevel (int): forwarded to ``get_traceback`` to choose how many
            innermost frames are left out of the appended stack.
        print_stack (bool): whether to append the stack to the message.
    """
    if is_logging_effective("warn"):
        # get_traceback must be called from this frame so the number of
        # trimmed frames stays the same.
        text = msg + f"\n{get_traceback(stacklevel)}" if print_stack else msg
        warnings.warn(Fore.YELLOW + Style.BRIGHT + text + Style.RESET_ALL, warning_type)
374
+
375
+
376
def get_traceback(stacklevel=1):
    """Format the current call stack, leaving out the innermost frames.

    Args:
        stacklevel (int): Number of frames to drop in addition to this
            function's own frame, counted from the innermost frame outward.

    Returns:
        str: The remaining frames formatted by ``traceback.format_list`` and
        concatenated into one string.
    """
    frames = traceback.extract_stack()
    kept = frames[: -1 - stacklevel]
    formatted = traceback.format_list(kept)
    return "".join(formatted)
379
+
380
+
381
+ __all__ = []
@@ -0,0 +1,8 @@
1
+ # type: ignore
2
+
3
+ """GsTaichi support module for sparse matrix operations."""
4
+
5
+ from gstaichi.linalg.matrixfree_cg import *
6
+ from gstaichi.linalg.sparse_cg import SparseCG
7
+ from gstaichi.linalg.sparse_matrix import *
8
+ from gstaichi.linalg.sparse_solver import SparseSolver
@@ -0,0 +1,310 @@
1
+ # type: ignore
2
+
3
+ from math import sqrt
4
+
5
+ from gstaichi.lang import misc
6
+ from gstaichi.lang.exception import GsTaichiRuntimeError, GsTaichiTypeError
7
+ from gstaichi.lang.impl import FieldsBuilder, field, grouped
8
+ from gstaichi.lang.kernel_impl import data_oriented, kernel
9
+ from gstaichi.types import primitive_types, template
10
+
11
+
12
@data_oriented
class LinearOperator:
    """Wrap a matrix-vector product callable for the matrix-free solvers."""

    def __init__(self, matvec_kernel):
        """Store the callable that will be invoked as ``matvec_kernel(x, Ax)``."""
        self._matvec = matvec_kernel

    def matvec(self, x, Ax):
        """Apply the wrapped callable after checking that both shapes agree.

        Args:
            x: Input operand; must expose ``shape``.
            Ax: Second operand passed to the wrapped callable; must expose
                ``shape`` (presumably the output buffer — confirm against
                the kernel supplied by the caller).

        Raises:
            GsTaichiRuntimeError: If ``x.shape`` differs from ``Ax.shape``.
        """
        in_shape, out_shape = x.shape, Ax.shape
        if in_shape != out_shape:
            raise GsTaichiRuntimeError(f"Dimension mismatch x.shape{x.shape} != Ax.shape{Ax.shape}.")
        self._matvec(x, Ax)
21
+
22
+
23
def MatrixFreeCG(A, b, x, tol=1e-6, maxiter=5000, quiet=True):
    """Matrix-free conjugate-gradient solver.

    Use conjugate-gradient method to solve the linear system Ax = b, where A is implicitly
    represented as a LinearOperator.

    Args:
        A (LinearOperator): The coefficient matrix A of the linear system.
        b (Field): The right-hand side of the linear system.
        x (Field): The initial guess for the solution; updated in place.
        tol (float): Absolute tolerance on the residual 2-norm ``sqrt(r.r)``.
        maxiter (int): Maximum number of iterations.
        quiet (bool): Switch to turn on/off iteration log.

    Returns:
        bool: True if the final residual norm is below ``tol``, else False.

    Raises:
        GsTaichiTypeError: If ``b`` and ``x`` differ in dtype, or the dtype is
            neither f32 nor f64.
        GsTaichiRuntimeError: If ``b`` and ``x`` differ in shape, or the
            inputs are not 1D, 2D or 3D.
    """

    if b.dtype != x.dtype:
        raise GsTaichiTypeError(f"Dtype mismatch b.dtype({b.dtype}) != x.dtype({x.dtype}).")
    if str(b.dtype) == "f32":
        solver_dtype = primitive_types.f32
    elif str(b.dtype) == "f64":
        solver_dtype = primitive_types.f64
    else:
        raise GsTaichiTypeError(f"Not supported dtype: {b.dtype}")
    if b.shape != x.shape:
        raise GsTaichiRuntimeError(f"Dimension mismatch b.shape{b.shape} != x.shape{x.shape}.")

    # Validate the dimensionality before any field is created so that a
    # rejected input does not leave unplaced fields behind.
    size = b.shape
    if len(size) == 1:
        axes = misc.i
    elif len(size) == 2:
        axes = misc.ij
    elif len(size) == 3:
        axes = misc.ijk
    else:
        raise GsTaichiRuntimeError(f"MatrixFreeCG only support 1D, 2D, 3D inputs; your inputs is {len(size)}-D.")

    # Work vectors, laid out with the same shape as b and x.
    vector_fields_builder = FieldsBuilder()
    p = field(dtype=solver_dtype)
    r = field(dtype=solver_dtype)
    Ap = field(dtype=solver_dtype)
    Ax = field(dtype=solver_dtype)
    vector_fields_builder.dense(axes, size).place(p, r, Ap, Ax)
    vector_fields_snode_tree = vector_fields_builder.finalize()

    # Scalar step sizes shared between the host loop and the kernels.
    scalar_builder = FieldsBuilder()
    alpha = field(dtype=solver_dtype)
    beta = field(dtype=solver_dtype)
    scalar_builder.place(alpha, beta)
    scalar_snode_tree = scalar_builder.finalize()

    # r = b - Ax, with p and Ap cleared.
    @kernel
    def init():
        for I in grouped(x):
            r[I] = b[I] - Ax[I]
            p[I] = 0.0
            Ap[I] = 0.0

    # Dot product of two fields.
    @kernel
    def reduce(p: template(), q: template()) -> solver_dtype:
        result = solver_dtype(0.0)
        for I in grouped(p):
            result += p[I] * q[I]
        return result

    @kernel
    def update_x():
        for I in grouped(x):
            x[I] += alpha[None] * p[I]

    @kernel
    def update_r():
        for I in grouped(r):
            r[I] -= alpha[None] * Ap[I]

    @kernel
    def update_p():
        for I in grouped(p):
            p[I] = r[I] + beta[None] * p[I]

    def solve():
        succeeded = True
        A._matvec(x, Ax)
        init()
        initial_rTr = reduce(r, r)
        if not quiet:
            print(f">>> Initial residual = {initial_rTr:e}")
        old_rTr = initial_rTr
        new_rTr = initial_rTr
        update_p()
        if sqrt(initial_rTr) >= tol:  # Do nothing if the initial residual is small enough
            # -- Main loop --
            for i in range(maxiter):
                A._matvec(p, Ap)  # compute Ap = A x p
                pAp = reduce(p, Ap)
                alpha[None] = old_rTr / pAp
                update_x()
                update_r()
                new_rTr = reduce(r, r)
                if sqrt(new_rTr) < tol:
                    if not quiet:
                        print(">>> Conjugate Gradient method converged.")
                        print(f">>> #iterations {i}")
                    break
                beta[None] = new_rTr / old_rTr
                update_p()
                old_rTr = new_rTr
                if not quiet:
                    print(f">>> Iter = {i+1:4}, Residual = {sqrt(new_rTr):e}")
        # Use the same quantity as the convergence test above: the residual
        # norm, not its square.
        if sqrt(new_rTr) >= tol:
            if not quiet:
                print(
                    f">>> Conjugate Gradient method failed to converge in {maxiter} iterations: Residual = {sqrt(new_rTr):e}"
                )
            succeeded = False
        return succeeded

    try:
        return solve()
    finally:
        # Release the temporary fields even when the solve raises.
        vector_fields_snode_tree.destroy()
        scalar_snode_tree.destroy()
142
+
143
+
144
def MatrixFreeBICGSTAB(A, b, x, tol=1e-6, maxiter=5000, quiet=True):
    """Matrix-free biconjugate-gradient stabilized solver (BiCGSTAB).

    Use BiCGSTAB method to solve the linear system Ax = b, where A is implicitly
    represented as a LinearOperator.

    Args:
        A (LinearOperator): The coefficient matrix A of the linear system.
        b (Field): The right-hand side of the linear system.
        x (Field): The initial guess for the solution; updated in place.
        tol (float): Absolute tolerance on the residual 2-norm ``sqrt(r.r)``.
        maxiter (int): Maximum number of iterations.
        quiet (bool): Switch to turn on/off iteration log.

    Returns:
        bool: True if the final residual norm is below ``tol`` and the
        iteration did not break down, else False.

    Raises:
        GsTaichiTypeError: If ``b`` and ``x`` differ in dtype, or the dtype is
            neither f32 nor f64.
        GsTaichiRuntimeError: If ``b`` and ``x`` differ in shape, or the
            inputs are not 1D, 2D or 3D.
    """

    if b.dtype != x.dtype:
        raise GsTaichiTypeError(f"Dtype mismatch b.dtype({b.dtype}) != x.dtype({x.dtype}).")
    if str(b.dtype) == "f32":
        solver_dtype = primitive_types.f32
    elif str(b.dtype) == "f64":
        solver_dtype = primitive_types.f64
    else:
        raise GsTaichiTypeError(f"Not supported dtype: {b.dtype}")
    if b.shape != x.shape:
        raise GsTaichiRuntimeError(f"Dimension mismatch b.shape{b.shape} != x.shape{x.shape}.")

    # Validate the dimensionality before any field is created so that a
    # rejected input does not leave unplaced fields behind.
    size = b.shape
    if len(size) == 1:
        axes = misc.i
    elif len(size) == 2:
        axes = misc.ij
    elif len(size) == 3:
        axes = misc.ijk
    else:
        raise GsTaichiRuntimeError(f"MatrixFreeBICGSTAB only support 1D, 2D, 3D inputs; your inputs is {len(size)}-D.")

    # Work vectors, laid out with the same shape as b and x.
    vector_fields_builder = FieldsBuilder()
    p = field(dtype=solver_dtype)
    p_hat = field(dtype=solver_dtype)
    r = field(dtype=solver_dtype)
    r_tld = field(dtype=solver_dtype)
    s = field(dtype=solver_dtype)
    s_hat = field(dtype=solver_dtype)
    t = field(dtype=solver_dtype)
    Ap = field(dtype=solver_dtype)
    Ax = field(dtype=solver_dtype)
    Ashat = field(dtype=solver_dtype)
    vector_fields_builder.dense(axes, size).place(p, p_hat, r, r_tld, s, s_hat, t, Ap, Ax, Ashat)
    vector_fields_snode_tree = vector_fields_builder.finalize()

    # Scalar coefficients shared between the host loop and the kernels.
    scalar_builder = FieldsBuilder()
    alpha = field(dtype=solver_dtype)
    beta = field(dtype=solver_dtype)
    omega = field(dtype=solver_dtype)
    rho = field(dtype=solver_dtype)
    rho_1 = field(dtype=solver_dtype)
    scalar_builder.place(alpha, beta, omega, rho, rho_1)
    scalar_snode_tree = scalar_builder.finalize()

    # r = b - Ax, r_tld = b, work vectors cleared, scalars reset.
    @kernel
    def init():
        for I in grouped(x):
            r[I] = b[I] - Ax[I]
            r_tld[I] = b[I]
            p[I] = 0.0
            Ap[I] = 0.0
            Ashat[I] = 0.0
        rho[None] = 0.0
        rho_1[None] = 1.0
        alpha[None] = 1.0
        beta[None] = 1.0
        omega[None] = 1.0

    # Dot product of two fields.
    @kernel
    def reduce(p: template(), q: template()) -> solver_dtype:
        result = solver_dtype(0.0)
        for I in grouped(p):
            result += p[I] * q[I]
        return result

    @kernel
    def copy(orig: template(), dest: template()):
        for I in grouped(orig):
            dest[I] = orig[I]

    @kernel
    def update_p():
        for I in grouped(p):
            p[I] = r[I] + beta[None] * (p[I] - omega[None] * Ap[I])

    @kernel
    def update_phat():
        for I in grouped(p_hat):
            p_hat[I] = p[I]

    @kernel
    def update_s():
        for I in grouped(s):
            s[I] = r[I] - alpha[None] * Ap[I]

    @kernel
    def update_shat():
        for I in grouped(s_hat):
            s_hat[I] = s[I]

    @kernel
    def update_x():
        for I in grouped(x):
            x[I] += alpha[None] * p_hat[I] + omega[None] * s_hat[I]

    @kernel
    def update_r():
        for I in grouped(r):
            r[I] = s[I] - omega[None] * t[I]

    def solve():
        succeeded = True
        A._matvec(x, Ax)
        init()
        initial_rTr = reduce(r, r)
        rTr = initial_rTr
        if not quiet:
            print(f">>> Initial residual = {initial_rTr:e}")
        if sqrt(initial_rTr) >= tol:  # Do nothing if the initial residual is small enough
            for i in range(maxiter):
                rho[None] = reduce(r, r_tld)
                if rho[None] == 0.0:
                    if not quiet:
                        print(">>> BICGSTAB failed because r@r_tld = 0.")
                    succeeded = False
                    break
                if i == 0:
                    copy(orig=r, dest=p)
                else:
                    beta[None] = (rho[None] / rho_1[None]) * (alpha[None] / omega[None])
                    update_p()
                update_phat()
                A._matvec(p, Ap)
                alpha_lower = reduce(r_tld, Ap)
                alpha[None] = rho[None] / alpha_lower
                update_s()
                update_shat()
                A._matvec(s_hat, Ashat)
                copy(orig=Ashat, dest=t)
                omega_upper = reduce(t, s)
                omega_lower = reduce(t, t)
                # Guard the division when t is identically zero.
                omega[None] = omega_upper / (omega_lower + 1e-16) if omega_lower == 0.0 else omega_upper / omega_lower
                update_x()
                update_r()
                rTr = reduce(r, r)
                if not quiet:
                    print(f">>> Iter = {i+1:4}, Residual = {sqrt(rTr):e}")
                if sqrt(rTr) < tol:
                    if not quiet:
                        print(f">>> BICGSTAB method converged at #iterations {i}")
                    break
                rho_1[None] = rho[None]
        # Use the same quantity as the convergence test above: the residual
        # norm, not its square.
        if sqrt(rTr) >= tol:
            if not quiet:
                print(f">>> BICGSTAB failed to converge in {maxiter} iterations: Residual = {sqrt(rTr):e}")
            succeeded = False
        return succeeded

    try:
        return solve()
    finally:
        # Release the temporary fields even when the solve raises.
        vector_fields_snode_tree.destroy()
        scalar_snode_tree.destroy()