gstaichi 2.1.1rc3__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. gstaichi/CHANGELOG.md +4 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +3 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cpython-310-darwin.so +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  11. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  12. gstaichi/_lib/utils.py +243 -0
  13. gstaichi/_logging.py +131 -0
  14. gstaichi/_snode/__init__.py +5 -0
  15. gstaichi/_snode/fields_builder.py +187 -0
  16. gstaichi/_snode/snode_tree.py +34 -0
  17. gstaichi/_test_tools/__init__.py +18 -0
  18. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  19. gstaichi/_test_tools/load_kernel_string.py +30 -0
  20. gstaichi/_test_tools/textwrap2.py +6 -0
  21. gstaichi/_version.py +1 -0
  22. gstaichi/_version_check.py +100 -0
  23. gstaichi/ad/__init__.py +3 -0
  24. gstaichi/ad/_ad.py +530 -0
  25. gstaichi/algorithms/__init__.py +3 -0
  26. gstaichi/algorithms/_algorithms.py +117 -0
  27. gstaichi/assets/.git +1 -0
  28. gstaichi/assets/Go-Regular.ttf +0 -0
  29. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  30. gstaichi/examples/lcg_python.py +26 -0
  31. gstaichi/examples/lcg_taichi.py +34 -0
  32. gstaichi/examples/minimal.py +28 -0
  33. gstaichi/experimental.py +16 -0
  34. gstaichi/lang/__init__.py +50 -0
  35. gstaichi/lang/_dataclass_util.py +31 -0
  36. gstaichi/lang/_fast_caching/__init__.py +3 -0
  37. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  38. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  39. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  40. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  41. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  42. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  43. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  44. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  45. gstaichi/lang/_ndarray.py +352 -0
  46. gstaichi/lang/_ndrange.py +152 -0
  47. gstaichi/lang/_template_mapper.py +195 -0
  48. gstaichi/lang/_texture.py +172 -0
  49. gstaichi/lang/_wrap_inspect.py +215 -0
  50. gstaichi/lang/any_array.py +99 -0
  51. gstaichi/lang/ast/__init__.py +5 -0
  52. gstaichi/lang/ast/ast_transformer.py +1323 -0
  53. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  54. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  55. gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
  56. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  57. gstaichi/lang/ast/checkers.py +106 -0
  58. gstaichi/lang/ast/symbol_resolver.py +57 -0
  59. gstaichi/lang/ast/transform.py +9 -0
  60. gstaichi/lang/common_ops.py +310 -0
  61. gstaichi/lang/exception.py +80 -0
  62. gstaichi/lang/expr.py +180 -0
  63. gstaichi/lang/field.py +428 -0
  64. gstaichi/lang/impl.py +1243 -0
  65. gstaichi/lang/kernel_arguments.py +155 -0
  66. gstaichi/lang/kernel_impl.py +1341 -0
  67. gstaichi/lang/matrix.py +1835 -0
  68. gstaichi/lang/matrix_ops.py +341 -0
  69. gstaichi/lang/matrix_ops_utils.py +190 -0
  70. gstaichi/lang/mesh.py +687 -0
  71. gstaichi/lang/misc.py +782 -0
  72. gstaichi/lang/ops.py +1494 -0
  73. gstaichi/lang/runtime_ops.py +13 -0
  74. gstaichi/lang/shell.py +35 -0
  75. gstaichi/lang/simt/__init__.py +5 -0
  76. gstaichi/lang/simt/block.py +94 -0
  77. gstaichi/lang/simt/grid.py +7 -0
  78. gstaichi/lang/simt/subgroup.py +191 -0
  79. gstaichi/lang/simt/warp.py +96 -0
  80. gstaichi/lang/snode.py +489 -0
  81. gstaichi/lang/source_builder.py +150 -0
  82. gstaichi/lang/struct.py +810 -0
  83. gstaichi/lang/util.py +312 -0
  84. gstaichi/linalg/__init__.py +8 -0
  85. gstaichi/linalg/matrixfree_cg.py +310 -0
  86. gstaichi/linalg/sparse_cg.py +59 -0
  87. gstaichi/linalg/sparse_matrix.py +303 -0
  88. gstaichi/linalg/sparse_solver.py +123 -0
  89. gstaichi/math/__init__.py +11 -0
  90. gstaichi/math/_complex.py +205 -0
  91. gstaichi/math/mathimpl.py +886 -0
  92. gstaichi/profiler/__init__.py +6 -0
  93. gstaichi/profiler/kernel_metrics.py +260 -0
  94. gstaichi/profiler/kernel_profiler.py +586 -0
  95. gstaichi/profiler/memory_profiler.py +15 -0
  96. gstaichi/profiler/scoped_profiler.py +36 -0
  97. gstaichi/sparse/__init__.py +3 -0
  98. gstaichi/sparse/_sparse_grid.py +77 -0
  99. gstaichi/tools/__init__.py +12 -0
  100. gstaichi/tools/diagnose.py +117 -0
  101. gstaichi/tools/np2ply.py +364 -0
  102. gstaichi/tools/vtk.py +38 -0
  103. gstaichi/types/__init__.py +19 -0
  104. gstaichi/types/annotations.py +52 -0
  105. gstaichi/types/compound_types.py +71 -0
  106. gstaichi/types/enums.py +49 -0
  107. gstaichi/types/ndarray_type.py +169 -0
  108. gstaichi/types/primitive_types.py +206 -0
  109. gstaichi/types/quant.py +88 -0
  110. gstaichi/types/texture_type.py +85 -0
  111. gstaichi/types/utils.py +11 -0
  112. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
  113. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
  114. gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
  115. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
  116. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
  117. gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
  118. gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
  119. gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  120. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
  121. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
  122. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  123. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  124. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  125. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  126. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  127. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  128. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  129. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  130. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  131. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  132. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  133. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  134. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  135. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  136. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  137. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  138. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  139. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  140. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  141. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  142. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  143. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  144. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  145. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  146. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  147. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  148. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  149. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  150. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  151. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  152. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  153. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  154. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  155. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  156. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  157. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  158. gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  159. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  160. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  161. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  162. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  163. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  164. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  165. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  166. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  167. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  168. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  169. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  170. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  171. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  172. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  173. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  174. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  175. gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
  176. gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
  177. gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
  178. gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
  179. gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
gstaichi/lang/util.py ADDED
@@ -0,0 +1,312 @@
1
+ import functools
2
+ import os
3
+ import traceback
4
+ import warnings
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+ from colorama import Fore, Style
9
+
10
+ from gstaichi._lib import core as _ti_core
11
+ from gstaichi._logging import is_logging_effective
12
+ from gstaichi.lang import impl
13
+ from gstaichi.types import Template
14
+ from gstaichi.types.primitive_types import (
15
+ f16,
16
+ f32,
17
+ f64,
18
+ i8,
19
+ i16,
20
+ i32,
21
+ i64,
22
+ u1,
23
+ u8,
24
+ u16,
25
+ u32,
26
+ u64,
27
+ )
28
+
29
+
30
def has_pytorch():
    """Return whether PyTorch can be imported in the current Python environment.

    Detection can be switched off by setting the environment variable
    ``TI_ENABLE_TORCH`` to an integer string equal to zero (e.g. ``"0"``).
    An unset or empty variable leaves detection enabled.

    Returns:
        bool: True if torch was imported successfully, else False.

    """
    env_torch = os.environ.get("TI_ENABLE_TORCH", "1")
    if env_torch and not int(env_torch):
        return False
    try:
        import torch  # pylint: disable=C0415,W0611
    except Exception:  # pylint: disable=broad-except
        # A broken torch install may raise more than ImportError (e.g. OSError
        # from missing shared libraries). Treat any such failure as "no torch",
        # but let KeyboardInterrupt / SystemExit propagate.
        return False
    return True
47
+
48
+
49
def get_clangpp():
    """Locate a clang++ executable suitable for compiling GsTaichi CUDA code.

    Uses ``shutil.which`` because ``distutils`` (and its ``find_executable``)
    was removed from the standard library in Python 3.12.

    Returns:
        str or None: Path of the first candidate found on ``PATH``, or None if
        no candidate is available.
    """
    import shutil  # pylint: disable=C0415

    # GsTaichi itself uses llvm-10.0.0 to compile.
    # There will be some issues compiling CUDA with other clang++ version.
    _clangpp_candidates = ["clang++-10"]
    for candidate in _clangpp_candidates:
        found = shutil.which(candidate)
        if found is not None:
            return found
    return None
60
+
61
+
62
def has_clangpp():
    """Report whether ``get_clangpp`` was able to locate a clang++ executable."""
    clangpp_path = get_clangpp()
    return clangpp_path is not None
64
+
65
+
66
def is_matrix_class(rhs):
    """Return True if ``rhs`` advertises itself as a GsTaichi matrix class.

    Objects opt in by exposing a truthy ``_is_matrix_class`` attribute. Any
    ordinary error while reading or testing that attribute means "not a
    matrix class". KeyboardInterrupt / SystemExit are no longer swallowed.
    """
    try:
        return bool(rhs._is_matrix_class)
    except Exception:  # pylint: disable=broad-except
        return False
74
+
75
+
76
def is_gstaichi_class(rhs):
    """Return True if ``rhs`` advertises itself as a GsTaichi class.

    Objects opt in by exposing a truthy ``_is_gstaichi_class`` attribute. Any
    ordinary error while reading or testing that attribute means "not a
    GsTaichi class". KeyboardInterrupt / SystemExit are no longer swallowed.
    """
    try:
        return bool(rhs._is_gstaichi_class)
    except Exception:  # pylint: disable=broad-except
        return False
84
+
85
+
86
def to_numpy_type(dt):
    """Convert gstaichi data type to its counterpart in numpy.

    Args:
        dt (DataType): The desired data type to convert.

    Returns:
        DataType: The counterpart data type in numpy.

    Raises:
        AssertionError: If ``dt`` has no numpy counterpart.

    """
    # Compared with ``==`` (not a dict lookup) so any object that compares
    # equal to a gstaichi primitive type is accepted, exactly as before.
    for ti_type, np_type in (
        (f32, np.float32),
        (f64, np.float64),
        (i32, np.int32),
        (i64, np.int64),
        (i8, np.int8),
        (i16, np.int16),
        (u1, np.bool_),
        (u8, np.uint8),
        (u16, np.uint16),
        (u32, np.uint32),
        (u64, np.uint64),
        (f16, np.half),
    ):
        if dt == ti_type:
            return np_type
    # An explicit raise (rather than ``assert False``) still fires under
    # ``python -O`` instead of silently returning None.
    raise AssertionError(f"Unknown type {dt}")
121
+
122
+
123
def to_pytorch_type(dt):
    """Map a gstaichi primitive data type onto the equivalent ``torch.dtype``.

    Args:
        dt (DataType): The gstaichi data type to translate.

    Returns:
        DataType: The matching torch data type.

    Raises:
        RuntimeError: If torch offers no equivalent for ``dt``. For u16/u32/u64
            on a torch without ``torch.uint16``, the message says these need
            PyTorch 2.3.0 or newer.

    """
    import torch  # pylint: disable=C0415

    # pylint: disable=E1101
    common_pairs = (
        (f32, torch.float32),
        (f64, torch.float64),
        (i32, torch.int32),
        (i64, torch.int64),
        (i8, torch.int8),
        (i16, torch.int16),
        (u1, torch.bool),
        (u8, torch.uint8),
        (f16, torch.float16),
    )
    for ti_dtype, torch_dtype in common_pairs:
        if dt == ti_dtype:
            return torch_dtype

    if dt in (u16, u32, u64):
        # The wide unsigned torch dtypes may be missing on older torch builds,
        # so they are fetched lazily by attribute name.
        if hasattr(torch, "uint16"):
            for ti_dtype, attr_name in ((u16, "uint16"), (u32, "uint32"), (u64, "uint64")):
                if dt == ti_dtype:
                    return getattr(torch, attr_name)
        raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.")

    raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type.")
166
+
167
+
168
def to_gstaichi_type(dt):
    """Convert numpy or torch data type to its counterpart in gstaichi.

    Args:
        dt (DataType): The desired data type to convert. A gstaichi
            ``DataTypeCxx`` is returned unchanged.

    Returns:
        DataType: The counterpart data type in gstaichi.

    Raises:
        AssertionError: If ``dt`` is not a recognised numpy or torch data type.

    """
    if isinstance(dt, _ti_core.DataTypeCxx):
        return dt

    for np_type, ti_type in (
        (np.float32, f32),
        (np.float64, f64),
        (np.int32, i32),
        (np.int64, i64),
        (np.int8, i8),
        (np.int16, i16),
        (np.bool_, u1),
        (np.uint8, u8),
        (np.uint16, u16),
        (np.uint32, u32),
        (np.uint64, u64),
        (np.half, f16),
    ):
        if dt == np_type:
            return ti_type

    if has_pytorch():
        import torch  # pylint: disable=C0415

        # pylint: disable=E1101
        torch_pairs = [
            (torch.float32, f32),
            (torch.float64, f64),
            (torch.int32, i32),
            (torch.int64, i64),
            (torch.int8, i8),
            (torch.int16, i16),
            (torch.bool, u1),
            (torch.uint8, u8),
            (torch.float16, f16),
        ]
        # Wide unsigned dtypes only exist on newer torch (2.3.0+). On older
        # torch such a dtype cannot be passed in at all, so no special error
        # is needed here. Previously an unconditional "before version 2.3.0"
        # raise fired for EVERY unknown dtype (and ``dt.to_string()`` itself
        # failed for numpy/torch dtypes), making the final raise unreachable.
        if hasattr(torch, "uint16"):
            torch_pairs += [
                (torch.uint16, u16),
                (torch.uint32, u32),
                (torch.uint64, u64),
            ]
        for torch_type, ti_type in torch_pairs:
            if dt == torch_type:
                return ti_type

    raise AssertionError(f"Unknown type {dt}")
240
+
241
+
242
def cook_dtype(dtype):
    """Normalise a user-supplied dtype into a gstaichi data type.

    Accepted inputs:
      * a ``DataTypeCxx``: returned as-is;
      * a core ``Type``: wrapped in ``DataTypeCxx``;
      * the Python builtins ``bool`` / ``float`` / ``int``: mapped to ``u1`` or
        to the runtime's default float / int type.

    Raises:
        ValueError: For any other input.
    """
    if isinstance(dtype, _ti_core.DataTypeCxx):
        return dtype
    if isinstance(dtype, _ti_core.Type):
        return _ti_core.DataTypeCxx(dtype)
    if dtype is bool:
        return u1
    if dtype is float or dtype is int:
        # The runtime is only consulted for the builtins that need its defaults.
        runtime = impl.get_runtime()
        return runtime.default_fp if dtype is float else runtime.default_ip
    raise ValueError(f"Invalid data type {dtype}")
254
+
255
+
256
def in_gstaichi_scope():
    """Report whether ``impl.inside_kernel()`` says a kernel is currently being processed."""
    inside = impl.inside_kernel()
    return inside
258
+
259
+
260
def in_python_scope():
    """Report the logical negation of ``in_gstaichi_scope()``."""
    gstaichi_active = in_gstaichi_scope()
    return not gstaichi_active
262
+
263
+
264
def gstaichi_scope(func):
    """Decorate ``func`` so each call first asserts ``in_gstaichi_scope()``.

    The returned wrapper carries ``func``'s metadata (``__name__``, ``__doc__``,
    ``__wrapped__``) via ``functools.wraps``.
    """

    def guarded(*args, **kwargs):
        assert in_gstaichi_scope(), f"{func.__name__} cannot be called in Python-scope"
        return func(*args, **kwargs)

    return functools.wraps(func)(guarded)
271
+
272
+
273
def python_scope(func):
    """Decorate ``func`` so each call first asserts ``in_python_scope()``.

    The returned wrapper carries ``func``'s metadata (``__name__``, ``__doc__``,
    ``__wrapped__``) via ``functools.wraps``.
    """

    def guarded(*args, **kwargs):
        assert in_python_scope(), f"{func.__name__} cannot be called in GsTaichi-scope"
        return func(*args, **kwargs)

    return functools.wraps(func)(guarded)
280
+
281
+
282
def warning(msg, warning_type=UserWarning, stacklevel=1, print_stack=True):
    """Print a warning message. Note that the builtin `warnings` module is
    unreliable since it may be suppressed by other packages such as IPython.

    Nothing is emitted unless ``is_logging_effective("warn")`` is true.

    Args:
        msg (str): message to print.
        warning_type (Type[Warning]): type of warning.
        stacklevel (int): warning stack level from the caller. It controls both
            how many frames are trimmed from the appended stack trace and which
            frame the warning is attributed to.
        print_stack (bool): whether to append the stack trace to the message.
    """
    if not is_logging_effective("warn"):
        return
    if print_stack:
        msg += f"\n{get_traceback(stacklevel)}"
    # ``+ 1`` skips this helper's own frame. With the default stacklevel=1 the
    # warning is then attributed to whoever called ``warning()`` instead of to
    # this module, matching the frame where the appended trace ends.
    warnings.warn(
        Fore.YELLOW + Style.BRIGHT + msg + Style.RESET_ALL,
        warning_type,
        stacklevel=stacklevel + 1,
    )
297
+
298
+
299
def get_traceback(stacklevel=1):
    """Format the current call stack as text, dropping the innermost frames.

    ``traceback.extract_stack()`` lists this function's own frame last. That
    frame plus ``stacklevel`` additional caller frames are removed before
    formatting.

    Args:
        stacklevel (int): Number of caller frames to drop beyond this one.

    Returns:
        str: The concatenated ``traceback.format_list`` output ("" if no frames remain).
    """
    frames = traceback.extract_stack()
    kept = frames[: -(stacklevel + 1)]
    return "".join(traceback.format_list(kept))
302
+
303
+
304
def is_data_oriented(obj: Any) -> bool:
    """Return ``obj._data_oriented`` if that attribute exists, otherwise False.

    The attribute value is returned as-is (not coerced to ``bool``).
    """
    try:
        return obj._data_oriented
    except AttributeError:
        return False
306
+
307
+
308
def is_ti_template(annotation: Any) -> bool:
    """Return whether ``annotation`` denotes a gstaichi template parameter.

    Both the bare ``Template`` class itself and an instance of it qualify.
    """
    is_template_class = annotation == Template
    return is_template_class or isinstance(annotation, Template)
310
+
311
+
312
# An empty ``__all__`` means ``from gstaichi.lang.util import *`` exports no names.
__all__: list[str] = []
@@ -0,0 +1,8 @@
1
+ # type: ignore
2
+
3
+ """GsTaichi support module for sparse matrix operations."""
4
+
5
+ from gstaichi.linalg.matrixfree_cg import *
6
+ from gstaichi.linalg.sparse_cg import SparseCG
7
+ from gstaichi.linalg.sparse_matrix import *
8
+ from gstaichi.linalg.sparse_solver import SparseSolver
@@ -0,0 +1,310 @@
1
+ # type: ignore
2
+
3
+ from math import sqrt
4
+
5
+ from gstaichi.lang import misc
6
+ from gstaichi.lang.exception import GsTaichiRuntimeError, GsTaichiTypeError
7
+ from gstaichi.lang.impl import FieldsBuilder, field, grouped
8
+ from gstaichi.lang.kernel_impl import data_oriented, kernel
9
+ from gstaichi.types import primitive_types, template
10
+
11
+
12
@data_oriented
class LinearOperator:
    """Implicit linear operator backed by a user-supplied mat-vec kernel.

    Args:
        matvec_kernel: Callable invoked as ``matvec_kernel(x, Ax)``. The CG
            solvers call it to compute the product of A with ``x`` into ``Ax``.
    """

    def __init__(self, matvec_kernel):
        self._matvec = matvec_kernel

    def matvec(self, x, Ax):
        """Apply the operator to ``x``, writing the result into ``Ax``.

        Raises:
            GsTaichiRuntimeError: If ``x`` and ``Ax`` have different shapes.
        """
        if x.shape != Ax.shape:
            raise GsTaichiRuntimeError(f"Dimension mismatch x.shape{x.shape} != Ax.shape{Ax.shape}.")
        kernel_fn = self._matvec
        kernel_fn(x, Ax)
21
+
22
+
23
def MatrixFreeCG(A, b, x, tol=1e-6, maxiter=5000, quiet=True):
    """Matrix-free conjugate-gradient solver.

    Use conjugate-gradient method to solve the linear system Ax = b, where A is implicitly
    represented as a LinearOperator. The solution is accumulated into ``x`` in place.

    Args:
        A (LinearOperator): The coefficient matrix A of the linear system.
        b (Field): The right-hand side of the linear system.
        x (Field): The initial guess for the solution; updated in place.
        tol (float): Absolute tolerance on the residual norm ``sqrt(r . r)``, with ``r = b - Ax``.
        maxiter (int): Maximum number of iterations.
        quiet (bool): Switch to turn on/off iteration log.

    Returns:
        bool: True if the residual norm is below ``tol`` on exit, False otherwise.

    Raises:
        GsTaichiTypeError: If ``b`` and ``x`` differ in dtype, or the dtype is not f32/f64.
        GsTaichiRuntimeError: If ``b`` and ``x`` differ in shape, or the inputs are not 1D/2D/3D.
    """

    if b.dtype != x.dtype:
        raise GsTaichiTypeError(f"Dtype mismatch b.dtype({b.dtype}) != x.dtype({x.dtype}).")
    if str(b.dtype) == "f32":
        solver_dtype = primitive_types.f32
    elif str(b.dtype) == "f64":
        solver_dtype = primitive_types.f64
    else:
        raise GsTaichiTypeError(f"Not supported dtype: {b.dtype}")
    if b.shape != x.shape:
        raise GsTaichiRuntimeError(f"Dimension mismatch b.shape{b.shape} != x.shape{x.shape}.")

    size = b.shape
    vector_fields_builder = FieldsBuilder()
    p = field(dtype=solver_dtype)
    r = field(dtype=solver_dtype)
    Ap = field(dtype=solver_dtype)
    Ax = field(dtype=solver_dtype)
    if len(size) == 1:
        axes = misc.i
    elif len(size) == 2:
        axes = misc.ij
    elif len(size) == 3:
        axes = misc.ijk
    else:
        raise GsTaichiRuntimeError(f"MatrixFreeCG only support 1D, 2D, 3D inputs; your inputs is {len(size)}-D.")
    vector_fields_builder.dense(axes, size).place(p, r, Ap, Ax)
    vector_fields_snode_tree = vector_fields_builder.finalize()

    scalar_builder = FieldsBuilder()
    alpha = field(dtype=solver_dtype)
    beta = field(dtype=solver_dtype)
    scalar_builder.place(alpha, beta)
    scalar_snode_tree = scalar_builder.finalize()

    @kernel
    def init():
        for I in grouped(x):
            r[I] = b[I] - Ax[I]
            p[I] = 0.0
            Ap[I] = 0.0

    @kernel
    def reduce(p: template(), q: template()) -> solver_dtype:
        result = solver_dtype(0.0)
        for I in grouped(p):
            result += p[I] * q[I]
        return result

    @kernel
    def update_x():
        for I in grouped(x):
            x[I] += alpha[None] * p[I]

    @kernel
    def update_r():
        for I in grouped(r):
            r[I] -= alpha[None] * Ap[I]

    @kernel
    def update_p():
        for I in grouped(p):
            p[I] = r[I] + beta[None] * p[I]

    def solve():
        succeeded = True
        A._matvec(x, Ax)
        init()
        initial_rTr = reduce(r, r)
        if not quiet:
            print(f">>> Initial residual = {initial_rTr:e}")
        old_rTr = initial_rTr
        new_rTr = initial_rTr
        # init() zeroed p, so this sets the first search direction p = r.
        update_p()
        if sqrt(initial_rTr) >= tol:  # Do nothing if the initial residual is small enough
            # -- Main loop --
            for i in range(maxiter):
                A._matvec(p, Ap)  # compute Ap = A x p
                pAp = reduce(p, Ap)
                alpha[None] = old_rTr / pAp
                update_x()
                update_r()
                new_rTr = reduce(r, r)
                if sqrt(new_rTr) < tol:
                    if not quiet:
                        print(">>> Conjugate Gradient method converged.")
                        print(f">>> #iterations {i + 1}")
                    break
                beta[None] = new_rTr / old_rTr
                update_p()
                old_rTr = new_rTr
                if not quiet:
                    print(f">>> Iter = {i+1:4}, Residual = {sqrt(new_rTr):e}")
        # new_rTr is the *squared* residual norm. Test its square root, exactly
        # as the convergence check above does, so success is reported consistently.
        if sqrt(new_rTr) >= tol:
            if not quiet:
                print(
                    f">>> Conjugate Gradient method failed to converge in {maxiter} iterations: Residual = {sqrt(new_rTr):e}"
                )
            succeeded = False
        return succeeded

    # Release the temporary SNode trees even if a kernel or the user's matvec raises.
    try:
        return solve()
    finally:
        vector_fields_snode_tree.destroy()
        scalar_snode_tree.destroy()
142
+
143
+
144
def MatrixFreeBICGSTAB(A, b, x, tol=1e-6, maxiter=5000, quiet=True):
    """Matrix-free biconjugate-gradient stabilized solver (BiCGSTAB).

    Use BiCGSTAB method to solve the linear system Ax = b, where A is implicitly
    represented as a LinearOperator. The solution is accumulated into ``x`` in place.

    Args:
        A (LinearOperator): The coefficient matrix A of the linear system.
        b (Field): The right-hand side of the linear system.
        x (Field): The initial guess for the solution; updated in place.
        tol (float): Absolute tolerance on the residual norm ``sqrt(r . r)``, with ``r = b - Ax``.
        maxiter (int): Maximum number of iterations.
        quiet (bool): Switch to turn on/off iteration log.

    Returns:
        bool: True if the residual norm is below ``tol`` on exit, False on
        breakdown (``r . r_tld == 0``) or non-convergence.

    Raises:
        GsTaichiTypeError: If ``b`` and ``x`` differ in dtype, or the dtype is not f32/f64.
        GsTaichiRuntimeError: If ``b`` and ``x`` differ in shape, or the inputs are not 1D/2D/3D.
    """

    if b.dtype != x.dtype:
        raise GsTaichiTypeError(f"Dtype mismatch b.dtype({b.dtype}) != x.dtype({x.dtype}).")
    if str(b.dtype) == "f32":
        solver_dtype = primitive_types.f32
    elif str(b.dtype) == "f64":
        solver_dtype = primitive_types.f64
    else:
        raise GsTaichiTypeError(f"Not supported dtype: {b.dtype}")
    if b.shape != x.shape:
        raise GsTaichiRuntimeError(f"Dimension mismatch b.shape{b.shape} != x.shape{x.shape}.")

    size = b.shape
    vector_fields_builder = FieldsBuilder()
    p = field(dtype=solver_dtype)
    p_hat = field(dtype=solver_dtype)
    r = field(dtype=solver_dtype)
    r_tld = field(dtype=solver_dtype)
    s = field(dtype=solver_dtype)
    s_hat = field(dtype=solver_dtype)
    t = field(dtype=solver_dtype)
    Ap = field(dtype=solver_dtype)
    Ax = field(dtype=solver_dtype)
    Ashat = field(dtype=solver_dtype)
    if len(size) == 1:
        axes = misc.i
    elif len(size) == 2:
        axes = misc.ij
    elif len(size) == 3:
        axes = misc.ijk
    else:
        raise GsTaichiRuntimeError(f"MatrixFreeBICGSTAB only support 1D, 2D, 3D inputs; your inputs is {len(size)}-D.")
    vector_fields_builder.dense(axes, size).place(p, p_hat, r, r_tld, s, s_hat, t, Ap, Ax, Ashat)
    vector_fields_snode_tree = vector_fields_builder.finalize()

    scalar_builder = FieldsBuilder()
    alpha = field(dtype=solver_dtype)
    beta = field(dtype=solver_dtype)
    omega = field(dtype=solver_dtype)
    rho = field(dtype=solver_dtype)
    rho_1 = field(dtype=solver_dtype)
    scalar_builder.place(alpha, beta, omega, rho, rho_1)
    scalar_snode_tree = scalar_builder.finalize()

    @kernel
    def init():
        for I in grouped(x):
            r[I] = b[I] - Ax[I]
            r_tld[I] = b[I]
            p[I] = 0.0
            Ap[I] = 0.0
            Ashat[I] = 0.0
        rho[None] = 0.0
        rho_1[None] = 1.0
        alpha[None] = 1.0
        beta[None] = 1.0
        omega[None] = 1.0

    @kernel
    def reduce(p: template(), q: template()) -> solver_dtype:
        result = solver_dtype(0.0)
        for I in grouped(p):
            result += p[I] * q[I]
        return result

    @kernel
    def copy(orig: template(), dest: template()):
        for I in grouped(orig):
            dest[I] = orig[I]

    @kernel
    def update_p():
        for I in grouped(p):
            p[I] = r[I] + beta[None] * (p[I] - omega[None] * Ap[I])

    @kernel
    def update_phat():
        for I in grouped(p_hat):
            p_hat[I] = p[I]

    @kernel
    def update_s():
        for I in grouped(s):
            s[I] = r[I] - alpha[None] * Ap[I]

    @kernel
    def update_shat():
        for I in grouped(s_hat):
            s_hat[I] = s[I]

    @kernel
    def update_x():
        for I in grouped(x):
            x[I] += alpha[None] * p_hat[I] + omega[None] * s_hat[I]

    @kernel
    def update_r():
        for I in grouped(r):
            r[I] = s[I] - omega[None] * t[I]

    def solve():
        succeeded = True
        A._matvec(x, Ax)
        init()
        initial_rTr = reduce(r, r)
        rTr = initial_rTr
        if not quiet:
            print(f">>> Initial residual = {initial_rTr:e}")
        if sqrt(initial_rTr) >= tol:  # Do nothing if the initial residual is small enough
            for i in range(maxiter):
                rho[None] = reduce(r, r_tld)
                if rho[None] == 0.0:
                    if not quiet:
                        print(">>> BICGSTAB failed because r@r_tld = 0.")
                    succeeded = False
                    break
                if i == 0:
                    copy(orig=r, dest=p)
                else:
                    beta[None] = (rho[None] / rho_1[None]) * (alpha[None] / omega[None])
                    update_p()
                update_phat()
                A._matvec(p, Ap)
                alpha_lower = reduce(r_tld, Ap)
                alpha[None] = rho[None] / alpha_lower
                update_s()
                update_shat()
                A._matvec(s_hat, Ashat)
                copy(orig=Ashat, dest=t)
                omega_upper = reduce(t, s)
                omega_lower = reduce(t, t)
                # Guard against division by zero when t == 0.
                omega[None] = omega_upper / (omega_lower + 1e-16) if omega_lower == 0.0 else omega_upper / omega_lower
                update_x()
                update_r()
                rTr = reduce(r, r)
                if not quiet:
                    print(f">>> Iter = {i+1:4}, Residual = {sqrt(rTr):e}")
                if sqrt(rTr) < tol:
                    if not quiet:
                        print(f">>> BICGSTAB method converged at #iterations {i + 1}")
                    break
                rho_1[None] = rho[None]
        # rTr is the *squared* residual norm. Compare its square root with tol,
        # consistent with the in-loop test. Skip when a breakdown already marked
        # the solve as failed, so only one failure message is printed.
        if succeeded and sqrt(rTr) >= tol:
            if not quiet:
                print(f">>> BICGSTAB failed to converge in {maxiter} iterations: Residual = {sqrt(rTr):e}")
            succeeded = False
        return succeeded

    # Release the temporary SNode trees even if a kernel or the user's matvec raises.
    try:
        return solve()
    finally:
        vector_fields_snode_tree.destroy()
        scalar_snode_tree.destroy()
@@ -0,0 +1,59 @@
1
+ # type: ignore
2
+
3
+ import numpy as np
4
+
5
+ from gstaichi._lib import core as _ti_core
6
+ from gstaichi.lang._ndarray import Ndarray, ScalarNdarray
7
+ from gstaichi.lang.exception import GsTaichiRuntimeError
8
+ from gstaichi.lang.impl import get_runtime
9
+ from gstaichi.types import f32, f64
10
+
11
+
12
class SparseCG:
    """Conjugate-gradient solver built for SparseMatrix.

    Use conjugate-gradient method to solve the linear system Ax = b, where A is SparseMatrix.

    Args:
        A (SparseMatrix): The coefficient matrix A of the linear system.
        b (numpy ndarray, gstaichi Ndarray): The right-hand side of the linear system.
            On CUDA it must be a gstaichi Ndarray (checked in ``solve``).
        x0 (numpy ndarray, gstaichi Ndarray, optional): The initial guess for the
            solution. Only used on CPU archs (x64 / arm64).
        max_iter (int): Maximum number of iterations.
        atol: Tolerance(absolute) for convergence.

    Raises:
        GsTaichiRuntimeError: For an unsupported arch, an unsupported CPU dtype,
            or (on CPU) an unsupported ``b`` / ``x0`` type.
    """

    def __init__(self, A, b, x0=None, max_iter=50, atol=1e-6):
        self.dtype = A.dtype
        self.ti_arch = get_runtime().prog.config().arch
        self.matrix = A
        self.b = b
        if self.ti_arch == _ti_core.Arch.cuda:
            # On CUDA, b is consumed in solve() and x0 is not used.
            self.cg_solver = _ti_core.make_cucg_solver(A.matrix, max_iter, atol, True)
        elif self.ti_arch in (_ti_core.Arch.x64, _ti_core.Arch.arm64):
            if self.dtype == f32:
                self.cg_solver = _ti_core.make_float_cg_solver(A.matrix, max_iter, atol, True)
            elif self.dtype == f64:
                self.cg_solver = _ti_core.make_double_cg_solver(A.matrix, max_iter, atol, True)
            else:
                raise GsTaichiRuntimeError(f"Unsupported CG dtype: {self.dtype}")
            if isinstance(b, Ndarray):
                self.cg_solver.set_b_ndarray(get_runtime().prog, b.arr)
            elif isinstance(b, np.ndarray):
                self.cg_solver.set_b(b)
            else:
                # Previously ignored silently, leaving the solver without a RHS.
                raise GsTaichiRuntimeError(f"Unsupported CG RHS type: {type(b)}")
            if isinstance(x0, Ndarray):
                self.cg_solver.set_x_ndarray(get_runtime().prog, x0.arr)
            elif isinstance(x0, np.ndarray):
                self.cg_solver.set_x(x0)
            elif x0 is not None:
                raise GsTaichiRuntimeError(f"Unsupported CG initial guess type: {type(x0)}")
        else:
            raise GsTaichiRuntimeError(f"Unsupported CG arch: {self.ti_arch}")

    def solve(self):
        """Run the solver.

        Returns:
            tuple: ``(x, success)``. On CUDA, ``x`` is a new ScalarNdarray of
            length ``A.m`` and ``success`` is always True. On CPU both come from
            the underlying solver (``get_x()`` / ``is_success()``).

        Raises:
            GsTaichiRuntimeError: On CUDA, if ``b`` is not a gstaichi Ndarray.
        """
        if self.ti_arch == _ti_core.Arch.cuda:
            if isinstance(self.b, Ndarray):
                x = ScalarNdarray(self.b.dtype, [self.matrix.m])
                self.cg_solver.solve(get_runtime().prog, x.arr, self.b.arr)
                return x, True
            raise GsTaichiRuntimeError(f"Unsupported CG RHS type: {type(self.b)}")
        self.cg_solver.solve()
        return self.cg_solver.get_x(), self.cg_solver.is_success()