gstaichi 0.1.18.dev1__cp310-cp310-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3.h +6389 -0
  2. gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3native.h +594 -0
  3. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/instrument.hpp +268 -0
  4. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.h +907 -0
  5. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.hpp +375 -0
  6. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/linker.hpp +97 -0
  7. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/optimizer.hpp +970 -0
  8. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  9. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv.h +2568 -0
  10. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv.hpp +2579 -0
  11. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  12. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  13. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  14. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  15. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  16. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  17. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  18. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  19. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  20. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  21. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  22. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  23. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  24. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  25. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  26. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  27. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  28. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  29. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  30. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  31. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  32. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  33. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  34. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  35. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  36. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  37. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  38. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  39. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  40. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  41. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  42. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  43. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  44. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  45. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  46. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  47. gstaichi-0.1.18.dev1.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  48. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  49. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  50. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  51. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  52. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  53. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  54. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  55. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  56. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  57. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  58. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  59. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  60. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  61. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  62. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  63. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  64. gstaichi-0.1.18.dev1.dist-info/METADATA +108 -0
  65. gstaichi-0.1.18.dev1.dist-info/RECORD +219 -0
  66. gstaichi-0.1.18.dev1.dist-info/WHEEL +5 -0
  67. gstaichi-0.1.18.dev1.dist-info/entry_points.txt +2 -0
  68. gstaichi-0.1.18.dev1.dist-info/licenses/LICENSE +201 -0
  69. gstaichi-0.1.18.dev1.dist-info/top_level.txt +1 -0
  70. taichi/__init__.py +44 -0
  71. taichi/__main__.py +5 -0
  72. taichi/_funcs.py +706 -0
  73. taichi/_kernels.py +420 -0
  74. taichi/_lib/__init__.py +3 -0
  75. taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +1401 -0
  76. taichi/_lib/c_api/include/taichi/taichi.h +29 -0
  77. taichi/_lib/c_api/include/taichi/taichi_core.h +1111 -0
  78. taichi/_lib/c_api/include/taichi/taichi_cpu.h +29 -0
  79. taichi/_lib/c_api/include/taichi/taichi_metal.h +72 -0
  80. taichi/_lib/c_api/include/taichi/taichi_platform.h +55 -0
  81. taichi/_lib/c_api/include/taichi/taichi_unity.h +64 -0
  82. taichi/_lib/c_api/include/taichi/taichi_vulkan.h +151 -0
  83. taichi/_lib/c_api/lib/libtaichi_c_api.dylib +0 -0
  84. taichi/_lib/c_api/runtime/libMoltenVK.dylib +0 -0
  85. taichi/_lib/c_api/runtime/runtime_arm64.bc +0 -0
  86. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +29 -0
  87. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +65 -0
  88. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +121 -0
  89. taichi/_lib/core/__init__.py +0 -0
  90. taichi/_lib/core/py.typed +0 -0
  91. taichi/_lib/core/taichi_python.cpython-310-darwin.so +0 -0
  92. taichi/_lib/core/taichi_python.pyi +3077 -0
  93. taichi/_lib/runtime/libMoltenVK.dylib +0 -0
  94. taichi/_lib/runtime/runtime_arm64.bc +0 -0
  95. taichi/_lib/utils.py +249 -0
  96. taichi/_logging.py +131 -0
  97. taichi/_main.py +552 -0
  98. taichi/_snode/__init__.py +5 -0
  99. taichi/_snode/fields_builder.py +189 -0
  100. taichi/_snode/snode_tree.py +34 -0
  101. taichi/_ti_module/__init__.py +3 -0
  102. taichi/_ti_module/cppgen.py +309 -0
  103. taichi/_ti_module/module.py +145 -0
  104. taichi/_version.py +1 -0
  105. taichi/_version_check.py +100 -0
  106. taichi/ad/__init__.py +3 -0
  107. taichi/ad/_ad.py +530 -0
  108. taichi/algorithms/__init__.py +3 -0
  109. taichi/algorithms/_algorithms.py +117 -0
  110. taichi/aot/__init__.py +12 -0
  111. taichi/aot/_export.py +28 -0
  112. taichi/aot/conventions/__init__.py +3 -0
  113. taichi/aot/conventions/gfxruntime140/__init__.py +38 -0
  114. taichi/aot/conventions/gfxruntime140/dr.py +244 -0
  115. taichi/aot/conventions/gfxruntime140/sr.py +613 -0
  116. taichi/aot/module.py +253 -0
  117. taichi/aot/utils.py +151 -0
  118. taichi/assets/.git +1 -0
  119. taichi/assets/Go-Regular.ttf +0 -0
  120. taichi/assets/static/imgs/ti_gallery.png +0 -0
  121. taichi/examples/minimal.py +28 -0
  122. taichi/experimental.py +16 -0
  123. taichi/graph/__init__.py +3 -0
  124. taichi/graph/_graph.py +292 -0
  125. taichi/lang/__init__.py +50 -0
  126. taichi/lang/_ndarray.py +348 -0
  127. taichi/lang/_ndrange.py +152 -0
  128. taichi/lang/_texture.py +172 -0
  129. taichi/lang/_wrap_inspect.py +189 -0
  130. taichi/lang/any_array.py +99 -0
  131. taichi/lang/argpack.py +411 -0
  132. taichi/lang/ast/__init__.py +5 -0
  133. taichi/lang/ast/ast_transformer.py +1806 -0
  134. taichi/lang/ast/ast_transformer_utils.py +328 -0
  135. taichi/lang/ast/checkers.py +106 -0
  136. taichi/lang/ast/symbol_resolver.py +57 -0
  137. taichi/lang/ast/transform.py +9 -0
  138. taichi/lang/common_ops.py +310 -0
  139. taichi/lang/exception.py +80 -0
  140. taichi/lang/expr.py +180 -0
  141. taichi/lang/field.py +464 -0
  142. taichi/lang/impl.py +1246 -0
  143. taichi/lang/kernel_arguments.py +157 -0
  144. taichi/lang/kernel_impl.py +1415 -0
  145. taichi/lang/matrix.py +1877 -0
  146. taichi/lang/matrix_ops.py +341 -0
  147. taichi/lang/matrix_ops_utils.py +190 -0
  148. taichi/lang/mesh.py +687 -0
  149. taichi/lang/misc.py +807 -0
  150. taichi/lang/ops.py +1489 -0
  151. taichi/lang/runtime_ops.py +13 -0
  152. taichi/lang/shell.py +35 -0
  153. taichi/lang/simt/__init__.py +5 -0
  154. taichi/lang/simt/block.py +94 -0
  155. taichi/lang/simt/grid.py +7 -0
  156. taichi/lang/simt/subgroup.py +191 -0
  157. taichi/lang/simt/warp.py +96 -0
  158. taichi/lang/snode.py +487 -0
  159. taichi/lang/source_builder.py +150 -0
  160. taichi/lang/struct.py +855 -0
  161. taichi/lang/util.py +381 -0
  162. taichi/linalg/__init__.py +8 -0
  163. taichi/linalg/matrixfree_cg.py +310 -0
  164. taichi/linalg/sparse_cg.py +59 -0
  165. taichi/linalg/sparse_matrix.py +303 -0
  166. taichi/linalg/sparse_solver.py +123 -0
  167. taichi/math/__init__.py +11 -0
  168. taichi/math/_complex.py +204 -0
  169. taichi/math/mathimpl.py +886 -0
  170. taichi/profiler/__init__.py +6 -0
  171. taichi/profiler/kernel_metrics.py +260 -0
  172. taichi/profiler/kernel_profiler.py +592 -0
  173. taichi/profiler/memory_profiler.py +15 -0
  174. taichi/profiler/scoped_profiler.py +36 -0
  175. taichi/shaders/Circles_vk.frag +29 -0
  176. taichi/shaders/Circles_vk.vert +45 -0
  177. taichi/shaders/Circles_vk_frag.spv +0 -0
  178. taichi/shaders/Circles_vk_vert.spv +0 -0
  179. taichi/shaders/Lines_vk.frag +9 -0
  180. taichi/shaders/Lines_vk.vert +11 -0
  181. taichi/shaders/Lines_vk_frag.spv +0 -0
  182. taichi/shaders/Lines_vk_vert.spv +0 -0
  183. taichi/shaders/Mesh_vk.frag +71 -0
  184. taichi/shaders/Mesh_vk.vert +68 -0
  185. taichi/shaders/Mesh_vk_frag.spv +0 -0
  186. taichi/shaders/Mesh_vk_vert.spv +0 -0
  187. taichi/shaders/Particles_vk.frag +95 -0
  188. taichi/shaders/Particles_vk.vert +73 -0
  189. taichi/shaders/Particles_vk_frag.spv +0 -0
  190. taichi/shaders/Particles_vk_vert.spv +0 -0
  191. taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
  192. taichi/shaders/SceneLines_vk.frag +9 -0
  193. taichi/shaders/SceneLines_vk.vert +12 -0
  194. taichi/shaders/SceneLines_vk_frag.spv +0 -0
  195. taichi/shaders/SceneLines_vk_vert.spv +0 -0
  196. taichi/shaders/SetImage_vk.frag +21 -0
  197. taichi/shaders/SetImage_vk.vert +15 -0
  198. taichi/shaders/SetImage_vk_frag.spv +0 -0
  199. taichi/shaders/SetImage_vk_vert.spv +0 -0
  200. taichi/shaders/Triangles_vk.frag +16 -0
  201. taichi/shaders/Triangles_vk.vert +29 -0
  202. taichi/shaders/Triangles_vk_frag.spv +0 -0
  203. taichi/shaders/Triangles_vk_vert.spv +0 -0
  204. taichi/shaders/lines2quad_vk_comp.spv +0 -0
  205. taichi/sparse/__init__.py +3 -0
  206. taichi/sparse/_sparse_grid.py +77 -0
  207. taichi/tools/__init__.py +12 -0
  208. taichi/tools/diagnose.py +124 -0
  209. taichi/tools/np2ply.py +364 -0
  210. taichi/tools/vtk.py +38 -0
  211. taichi/types/__init__.py +19 -0
  212. taichi/types/annotations.py +47 -0
  213. taichi/types/compound_types.py +90 -0
  214. taichi/types/enums.py +49 -0
  215. taichi/types/ndarray_type.py +147 -0
  216. taichi/types/primitive_types.py +203 -0
  217. taichi/types/quant.py +88 -0
  218. taichi/types/texture_type.py +85 -0
  219. taichi/types/utils.py +13 -0
@@ -0,0 +1,59 @@
+ # type: ignore
+
+ import numpy as np
+
+ from taichi._lib import core as _ti_core
+ from taichi.lang._ndarray import Ndarray, ScalarNdarray
+ from taichi.lang.exception import TaichiRuntimeError
+ from taichi.lang.impl import get_runtime
+ from taichi.types import f32, f64
+
+
+ class SparseCG:
+     """Conjugate-gradient solver built for SparseMatrix.
+
+     Uses the conjugate-gradient method to solve the linear system Ax = b, where A is a SparseMatrix.
+
+     Args:
+         A (SparseMatrix): The coefficient matrix A of the linear system.
+         b (numpy ndarray, taichi Ndarray): The right-hand side of the linear system.
+         x0 (numpy ndarray, taichi Ndarray): The initial guess for the solution.
+         max_iter (int): Maximum number of iterations.
+         atol (float): Absolute tolerance for convergence.
+     """
+
+     def __init__(self, A, b, x0=None, max_iter=50, atol=1e-6):
+         self.dtype = A.dtype
+         self.ti_arch = get_runtime().prog.config().arch
+         self.matrix = A
+         self.b = b
+         if self.ti_arch == _ti_core.Arch.cuda:
+             self.cg_solver = _ti_core.make_cucg_solver(A.matrix, max_iter, atol, True)
+         elif self.ti_arch == _ti_core.Arch.x64 or self.ti_arch == _ti_core.Arch.arm64:
+             if self.dtype == f32:
+                 self.cg_solver = _ti_core.make_float_cg_solver(A.matrix, max_iter, atol, True)
+             elif self.dtype == f64:
+                 self.cg_solver = _ti_core.make_double_cg_solver(A.matrix, max_iter, atol, True)
+             else:
+                 raise TaichiRuntimeError(f"Unsupported CG dtype: {self.dtype}")
+             if isinstance(b, Ndarray):
+                 self.cg_solver.set_b_ndarray(get_runtime().prog, b.arr)
+             elif isinstance(b, np.ndarray):
+                 self.cg_solver.set_b(b)
+             if isinstance(x0, Ndarray):
+                 self.cg_solver.set_x_ndarray(get_runtime().prog, x0.arr)
+             elif isinstance(x0, np.ndarray):
+                 self.cg_solver.set_x(x0)
+         else:
+             raise TaichiRuntimeError(f"Unsupported CG arch: {self.ti_arch}")
+
+     def solve(self):
+         if self.ti_arch == _ti_core.Arch.cuda:
+             if isinstance(self.b, Ndarray):
+                 x = ScalarNdarray(self.b.dtype, [self.matrix.m])
+                 self.cg_solver.solve(get_runtime().prog, x.arr, self.b.arr)
+                 return x, True
+             raise TaichiRuntimeError(f"Unsupported CG RHS type: {type(self.b)}")
+         else:
+             self.cg_solver.solve()
+             return self.cg_solver.get_x(), self.cg_solver.is_success()
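This hunk appears to correspond to taichi/linalg/sparse_cg.py (+59 in the file list above). Based only on the constructor and solve() shown there, a minimal usage sketch might look like the following; the builder-fill kernel, the ti.types.sparse_matrix_builder() annotation, and the ti.linalg re-exports follow the upstream Taichi idiom and are assumptions, not verified against this wheel.

# Minimal sketch, assuming SparseCG/SparseMatrixBuilder are re-exported under ti.linalg
# and that ti.types.sparse_matrix_builder() is available in this build.
import numpy as np
import taichi as ti

ti.init(arch=ti.cpu)

n = 4
builder = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=16, dtype=ti.f32)

@ti.kernel
def fill(mat: ti.types.sparse_matrix_builder()):
    for i in range(n):
        mat[i, i] += 2.0  # simple diagonal system

fill(builder)
A = builder.build()

b = np.ones(n, dtype=np.float32)
cg = ti.linalg.SparseCG(A, b, max_iter=50, atol=1e-6)
x, succeeded = cg.solve()  # (solution, success flag), matching SparseCG.solve() above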
@@ -0,0 +1,303 @@
+ # type: ignore
+
+ from functools import reduce
+
+ import numpy as np
+
+ from taichi._lib import core as _ti_core
+ from taichi.lang._ndarray import Ndarray, ScalarNdarray
+ from taichi.lang.exception import TaichiRuntimeError
+ from taichi.lang.field import Field
+ from taichi.lang.impl import get_runtime
+ from taichi.types import f32
+
+
+ class SparseMatrix:
+     """Taichi's Sparse Matrix class
+
+     A sparse matrix allows the programmer to solve a large linear system.
+
+     Args:
+         n (int): the first dimension of a sparse matrix.
+         m (int): the second dimension of a sparse matrix.
+         sm (SparseMatrix): an existing sparse matrix to build this one from.
+     """
+
+     def __init__(self, n=None, m=None, sm=None, dtype=f32, storage_format="col_major"):
+         self.dtype = dtype
+         if sm is None:
+             self.n = n
+             self.m = m if m else n
+             self.matrix = get_runtime().prog.create_sparse_matrix(n, m, dtype, storage_format)
+         else:
+             self.n = sm.num_rows()
+             self.m = sm.num_cols()
+             self.matrix = sm
+
+     def __iadd__(self, other):
+         """Addition operation for sparse matrix.
+
+         Returns:
+             The result sparse matrix of the addition.
+         """
+         assert (
+             self.n == other.n and self.m == other.m
+         ), f"Dimension mismatch between sparse matrices ({self.n}, {self.m}) and ({other.n}, {other.m})"
+         self.matrix += other.matrix
+         return self
+
+     def __add__(self, other):
+         """Addition operation for sparse matrix.
+
+         Returns:
+             The result sparse matrix of the addition.
+         """
+         assert (
+             self.n == other.n and self.m == other.m
+         ), f"Dimension mismatch between sparse matrices ({self.n}, {self.m}) and ({other.n}, {other.m})"
+         sm = self.matrix + other.matrix
+         return SparseMatrix(sm=sm)
+
+     def __isub__(self, other):
+         """Subtraction operation for sparse matrix.
+
+         Returns:
+             The result sparse matrix of the subtraction.
+         """
+         assert (
+             self.n == other.n and self.m == other.m
+         ), f"Dimension mismatch between sparse matrices ({self.n}, {self.m}) and ({other.n}, {other.m})"
+         self.matrix -= other.matrix
+         return self
+
+     def __sub__(self, other):
+         """Subtraction operation for sparse matrix.
+
+         Returns:
+             The result sparse matrix of the subtraction.
+         """
+         assert (
+             self.n == other.n and self.m == other.m
+         ), f"Dimension mismatch between sparse matrices ({self.n}, {self.m}) and ({other.n}, {other.m})"
+         sm = self.matrix - other.matrix
+         return SparseMatrix(sm=sm)
+
+     def __mul__(self, other):
+         """Scalar multiplication by a real number, or the Hadamard (element-wise) product with another sparse matrix.
+
+         Args:
+             other (float or SparseMatrix): the other operand of multiplication.
+         Returns:
+             The result of multiplication.
+         """
+         if isinstance(other, float):
+             sm = other * self.matrix
+             return SparseMatrix(sm=sm)
+         if isinstance(other, SparseMatrix):
+             assert (
+                 self.n == other.n and self.m == other.m
+             ), f"Dimension mismatch between sparse matrices ({self.n}, {self.m}) and ({other.n}, {other.m})"
+             sm = self.matrix * other.matrix
+             return SparseMatrix(sm=sm)
+
+         return None
+
+     def __rmul__(self, other):
+         """Right scalar multiplication for sparse matrix.
+
+         Args:
+             other (float): the other operand of scalar multiplication.
+         Returns:
+             The result of multiplication.
+         """
+         if isinstance(other, float):
+             sm = self.matrix * other
+             return SparseMatrix(sm=sm)
+
+         return None
+
+     def transpose(self):
+         """Sparse Matrix transpose.
+
+         Returns:
+             The transposed sparse matrix.
+         """
+         sm = self.matrix.transpose()
+         return SparseMatrix(sm=sm)
+
+     def __matmul__(self, other):
+         """Matrix multiplication.
+
+         Args:
+             other (SparseMatrix, Field, Ndarray, or numpy.ndarray): the right-hand operand of the multiplication.
+         Returns:
+             The result of matrix multiplication.
+         """
+         if isinstance(other, SparseMatrix):
+             assert (
+                 self.m == other.n
+             ), f"Dimension mismatch between sparse matrices ({self.n}, {self.m}) and ({other.n}, {other.m})"
+             sm = self.matrix.matmul(other.matrix)
+             return SparseMatrix(sm=sm)
+         if isinstance(other, Field):
+             assert (
+                 self.m == other.shape[0]
+             ), f"Dimension mismatch between sparse matrix ({self.n}, {self.m}) and vector ({other.shape})"
+             return self.matrix.mat_vec_mul(other.to_numpy())
+         if isinstance(other, np.ndarray):
+             assert (
+                 self.m == other.shape[0]
+             ), f"Dimension mismatch between sparse matrix ({self.n}, {self.m}) and vector ({other.shape})"
+             return self.matrix.mat_vec_mul(other)
+         if isinstance(other, Ndarray):
+             if self.m != other.shape[0]:
+                 raise TaichiRuntimeError(
+                     f"Dimension mismatch between sparse matrix ({self.n}, {self.m}) and vector ({other.shape})"
+                 )
+             res = ScalarNdarray(dtype=other.dtype, arr_shape=(self.n,))
+             self.matrix.spmv(get_runtime().prog, other.arr, res.arr)
+             return res
+         raise TaichiRuntimeError(
+             f"Sparse matrix-matrix/vector multiplication does not support {type(other)} for now. Supported types are SparseMatrix, ti.field, and numpy ndarray."
+         )
+
+     def __getitem__(self, indices):
+         return self.matrix.get_element(indices[0], indices[1])
+
+     def __setitem__(self, indices, value):
+         self.matrix.set_element(indices[0], indices[1], value)
+
+     def __str__(self):
+         """Python scope matrix print support."""
+         return self.matrix.to_string()
+
+     def __repr__(self):
+         return self.matrix.to_string()
+
+     @property
+     def shape(self):
+         """The shape of the sparse matrix."""
+         return (self.n, self.m)
+
+     def build_from_ndarray(self, ndarray):
+         """Build the sparse matrix from an ndarray.
+
+         Args:
+             ndarray (Union[ti.ndarray, ti.Vector.ndarray, ti.Matrix.ndarray]): the ndarray to build the sparse matrix from.
+
+         Raises:
+             TaichiRuntimeError: If the input is not an ndarray or its number of scalars is not divisible by 3.
+
+         Example::
+             >>> N = 5
+             >>> triplets = ti.Vector.ndarray(n=3, dtype=ti.f32, shape=10, layout=ti.Layout.AOS)
+             >>> @ti.kernel
+             >>> def fill(triplets: ti.types.ndarray()):
+             >>>     for i in range(N):
+             >>>         triplets[i] = ti.Vector([i, (i + 1) % N, i+1], dt=ti.f32)
+             >>> fill(triplets)
+             >>> A = ti.linalg.SparseMatrix(n=N, m=N, dtype=ti.f32)
+             >>> A.build_from_ndarray(triplets)
+             >>> print(A)
+             [0, 1, 0, 0, 0]
+             [0, 0, 2, 0, 0]
+             [0, 0, 0, 3, 0]
+             [0, 0, 0, 0, 4]
+             [5, 0, 0, 0, 0]
+         """
+         if isinstance(ndarray, Ndarray):
+             num_scalars = reduce(lambda x, y: x * y, ndarray.shape + ndarray.element_shape)
+             if num_scalars % 3 != 0:
+                 raise TaichiRuntimeError("The number of scalars in the ndarray must be divisible by 3.")
+             get_runtime().prog.make_sparse_matrix_from_ndarray(self.matrix, ndarray.arr)
+         else:
+             raise TaichiRuntimeError(
+                 "Sparse matrix only supports building from [ti.ndarray, ti.Vector.ndarray, ti.Matrix.ndarray]"
+             )
+
+     def mmwrite(self, filename):
+         """Writes the sparse matrix to a Matrix Market file.
+
+         Args:
+             filename (str): the file name to write the sparse matrix to.
+         """
+         self.matrix.mmwrite(filename)
+
+
+ class SparseMatrixBuilder:
+     """A Python wrapper around the sparse matrix builder.
+
+     Use this builder to fill the sparse matrix.
+
+     Args:
+         num_rows (int): the first dimension of a sparse matrix.
+         num_cols (int): the second dimension of a sparse matrix.
+         max_num_triplets (int): the maximum number of triplets.
+         dtype (ti.dtype): the data type of the sparse matrix.
+         storage_format (str): the storage format of the sparse matrix.
+     """
+
+     def __init__(
+         self,
+         num_rows=None,
+         num_cols=None,
+         max_num_triplets=0,
+         dtype=f32,
+         storage_format="col_major",
+     ):
+         self.num_rows = num_rows
+         self.num_cols = num_cols if num_cols else num_rows
+         self.dtype = dtype
+         if num_rows is not None:
+             taichi_arch = get_runtime().prog.config().arch
+             if taichi_arch in [
+                 _ti_core.Arch.x64,
+                 _ti_core.Arch.arm64,
+                 _ti_core.Arch.cuda,
+             ]:
+                 self.ptr = _ti_core.SparseMatrixBuilder(
+                     num_rows,
+                     num_cols,
+                     max_num_triplets,
+                     dtype,
+                     storage_format,
+                 )
+                 self.ptr.create_ndarray(get_runtime().prog)
+             else:
+                 raise TaichiRuntimeError("SparseMatrix only supports CPU and CUDA for now.")
+
+     def _get_addr(self):
+         """Get the address of the sparse matrix"""
+         return self.ptr.get_addr()
+
+     def _get_ndarray_addr(self):
+         """Get the address of the ndarray"""
+         return self.ptr.get_ndarray_data_ptr()
+
+     def print_triplets(self):
+         """Print the triplets stored in the builder"""
+         taichi_arch = get_runtime().prog.config().arch
+         if taichi_arch in [_ti_core.Arch.x64, _ti_core.Arch.arm64]:
+             self.ptr.print_triplets_eigen()
+         elif taichi_arch == _ti_core.Arch.cuda:
+             self.ptr.print_triplets_cuda()
+
+     def build(self, dtype=f32, _format="CSR"):
+         """Create a sparse matrix using the triplets"""
+         taichi_arch = get_runtime().prog.config().arch
+         if taichi_arch in [_ti_core.Arch.x64, _ti_core.Arch.arm64]:
+             sm = self.ptr.build()
+             return SparseMatrix(sm=sm, dtype=self.dtype)
+         if taichi_arch == _ti_core.Arch.cuda:
+             if self.dtype != f32:
+                 raise TaichiRuntimeError("CUDA sparse matrix only supports f32.")
+             sm = self.ptr.build_cuda()
+             return SparseMatrix(sm=sm, dtype=self.dtype)
+         raise TaichiRuntimeError("Sparse matrix only supports CPU and CUDA backends.")
+
+     def __del__(self):
+         if get_runtime() is not None and get_runtime().prog is not None:
+             self.ptr.delete_ndarray(get_runtime().prog)
+
+
+ __all__ = ["SparseMatrix", "SparseMatrixBuilder"]
@@ -0,0 +1,123 @@
+ # type: ignore
+
+ import numpy as np
+
+ import taichi.lang
+ from taichi._lib import core as _ti_core
+ from taichi.lang._ndarray import Ndarray, ScalarNdarray
+ from taichi.lang.exception import TaichiRuntimeError
+ from taichi.lang.field import Field
+ from taichi.lang.impl import get_runtime
+ from taichi.linalg.sparse_matrix import SparseMatrix
+ from taichi.types.primitive_types import f32
+
+
+ class SparseSolver:
+     """Sparse linear system solver.
+
+     Use this class to solve linear systems represented by sparse matrices.
+
+     Args:
+         solver_type (str): The factorization type.
+         ordering (str): The matrix reordering method.
+     """
+
+     def __init__(self, dtype=f32, solver_type="LLT", ordering="AMD"):
+         self.matrix = None
+         self.dtype = dtype
+         solver_type_list = ["LLT", "LDLT", "LU"]
+         solver_ordering = ["AMD", "COLAMD"]
+         if solver_type in solver_type_list and ordering in solver_ordering:
+             taichi_arch = taichi.lang.impl.get_runtime().prog.config().arch
+             assert (
+                 taichi_arch == _ti_core.Arch.x64
+                 or taichi_arch == _ti_core.Arch.arm64
+                 or taichi_arch == _ti_core.Arch.cuda
+             ), "SparseSolver only supports CPU and CUDA for now."
+             if taichi_arch == _ti_core.Arch.cuda:
+                 self.solver = _ti_core.make_cusparse_solver(dtype, solver_type, ordering)
+             else:
+                 self.solver = _ti_core.make_sparse_solver(dtype, solver_type, ordering)
+         else:
+             raise TaichiRuntimeError(
+                 f"The solver type {solver_type} with {ordering} is not supported for now. Only {solver_type_list} with {solver_ordering} are supported."
+             )
+
+     @staticmethod
+     def _type_assert(sparse_matrix):
+         raise TaichiRuntimeError(
+             f"The parameter type: {type(sparse_matrix)} is not supported in linear solvers for now."
+         )
+
+     def compute(self, sparse_matrix):
+         """This method is equivalent to calling `analyze_pattern` followed by `factorize`.
+
+         Args:
+             sparse_matrix (SparseMatrix): The sparse matrix to be computed.
+         """
+         if isinstance(sparse_matrix, SparseMatrix):
+             self.matrix = sparse_matrix
+             taichi_arch = taichi.lang.impl.get_runtime().prog.config().arch
+             if taichi_arch == _ti_core.Arch.x64 or taichi_arch == _ti_core.Arch.arm64:
+                 self.solver.compute(sparse_matrix.matrix)
+             elif taichi_arch == _ti_core.Arch.cuda:
+                 self.analyze_pattern(self.matrix)
+                 self.factorize(self.matrix)
+         else:
+             self._type_assert(sparse_matrix)
+
+     def analyze_pattern(self, sparse_matrix):
+         """Reorders the nonzero elements of the matrix so that the factorization step creates less fill-in.
+
+         Args:
+             sparse_matrix (SparseMatrix): The sparse matrix to be analyzed.
+         """
+         if isinstance(sparse_matrix, SparseMatrix):
+             self.matrix = sparse_matrix
+             if self.matrix.dtype != self.dtype:
+                 raise TaichiRuntimeError(
+                     f"The SparseSolver's dtype {self.dtype} is not consistent with the SparseMatrix's dtype {self.matrix.dtype}."
+                 )
+             self.solver.analyze_pattern(sparse_matrix.matrix)
+         else:
+             self._type_assert(sparse_matrix)
+
+     def factorize(self, sparse_matrix):
+         """Performs the factorization step.
+
+         Args:
+             sparse_matrix (SparseMatrix): The sparse matrix to be factorized.
+         """
+         if isinstance(sparse_matrix, SparseMatrix):
+             self.matrix = sparse_matrix
+             self.solver.factorize(sparse_matrix.matrix)
+         else:
+             self._type_assert(sparse_matrix)
+
+     def solve(self, b):  # pylint: disable=R1710
+         """Computes the solution of the linear system.
+         Args:
+             b (numpy.ndarray, Field, or Ndarray): The right-hand side of the linear system.
+
+         Returns:
+             numpy.ndarray: The solution of the linear system.
+         """
+         if self.matrix is None:
+             raise TaichiRuntimeError("Please call compute() before calling solve().")
+         if isinstance(b, Field):
+             return self.solver.solve(b.to_numpy())
+         if isinstance(b, np.ndarray):
+             return self.solver.solve(b)
+         if isinstance(b, Ndarray):
+             x = ScalarNdarray(b.dtype, [self.matrix.m])
+             self.solver.solve_rf(get_runtime().prog, self.matrix.matrix, b.arr, x.arr)
+             return x
+         raise TaichiRuntimeError(f"The parameter type: {type(b)} is not supported in linear solvers for now.")
+
+     def info(self):
+         """Checks whether the linear system was solved successfully.
+
+         Returns:
+             bool: True if the solving process succeeded, False otherwise.
+         """
+         return self.solver.info()
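This hunk appears to correspond to taichi/linalg/sparse_solver.py (+123). A hedged sketch of the compute/solve flow defined above, using an SPD system so the default LLT factorization applies; the builder details are the same assumptions as in the earlier examples.

# Sketch of SparseSolver usage on CPU; names outside this diff are assumptions.
import numpy as np
import taichi as ti

ti.init(arch=ti.cpu)

n = 4
builder = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=16, dtype=ti.f32)

@ti.kernel
def fill(mat: ti.types.sparse_matrix_builder()):
    for i in range(n):
        mat[i, i] += 4.0  # symmetric positive definite, so LLT is valid

fill(builder)
A = builder.build()

solver = ti.linalg.SparseSolver(dtype=ti.f32, solver_type="LLT", ordering="AMD")
solver.compute(A)                               # analyze_pattern + factorize in one call
x = solver.solve(np.ones(n, dtype=np.float32))  # NumPy RHS -> NumPy solution
assert solver.info()                            # True when the solve succeeded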
@@ -0,0 +1,11 @@
+ # type: ignore
+
+ """Taichi math module.
+
+ The math module supports GLSL-style vectors, matrices, and functions.
+ """
+
+ from ._complex import *
+ from .mathimpl import *  # pylint: disable=W0622
+
+ del mathimpl
@@ -0,0 +1,204 @@
+ # type: ignore
+
+ from .mathimpl import dot, vec2
+ from taichi.lang import ops
+ from taichi.lang.kernel_impl import func
+
+
+ @func
+ def cmul(z1, z2):
+     """Performs complex multiplication between two 2d vectors.
+
+     This is equivalent to the multiplication in the complex number field
+     when `z1` and `z2` are treated as complex numbers.
+
+     Args:
+         z1 (:class:`~taichi.math.vec2`): The first input.
+         z2 (:class:`~taichi.math.vec2`): The second input.
+
+     Example::
+
+         >>> @ti.kernel
+         >>> def test():
+         >>>     z1 = ti.math.vec2(1, 1)
+         >>>     z2 = ti.math.vec2(0, 1)
+         >>>     ti.math.cmul(z1, z2)  # [-1, 1]
+
+     Returns:
+         :class:`~taichi.math.vec2`: the complex multiplication `z1 * z2`.
+     """
+     x1, y1 = z1[0], z1[1]
+     x2, y2 = z2[0], z2[1]
+     return vec2(x1 * x2 - y1 * y2, x1 * y2 + x2 * y1)
+
+
+ @func
+ def cconj(z):
+     """Returns the complex conjugate of a 2d vector.
+
+     If `z = (x, y)`, then the conjugate of `z` is `(x, -y)`.
+
+     Args:
+         z (:class:`~taichi.math.vec2`): The input.
+
+     Returns:
+         :class:`~taichi.math.vec2`: The complex conjugate of `z`.
+     """
+     return vec2(z[0], -z[1])
+
+
+ @func
+ def cdiv(z1, z2):
+     """Performs complex division between two 2d vectors.
+
+     This is equivalent to the division in the complex number field
+     when `z1` and `z2` are treated as complex numbers.
+
+     Args:
+         z1 (:class:`~taichi.math.vec2`): The first input.
+         z2 (:class:`~taichi.math.vec2`): The second input.
+
+     Example::
+
+         >>> @ti.kernel
+         >>> def test():
+         >>>     z1 = ti.math.vec2(1, 1)
+         >>>     z2 = ti.math.vec2(0, 1)
+         >>>     ti.math.cdiv(z1, z2)  # [1, -1]
+
+     Returns:
+         :class:`~taichi.math.vec2`: the complex division `z1 / z2`.
+     """
+     x1, y1 = z1[0], z1[1]
+     x2, y2 = z2[0], z2[1]
+     return vec2(x1 * x2 + y1 * y2, -x1 * y2 + x2 * y1) / dot(z2, z2)
+
+
+ @func
+ def csqrt(z):
+     """Returns the complex square root of a 2d vector `z`, so that
+     if `w^2 = z`, then `w = csqrt(z)`.
+
+     Among the two square roots of `z`, if their real parts are non-zero,
+     the one with positive real part is returned. If both their real parts
+     are zero, the one with non-negative imaginary part is returned.
+
+     Args:
+         z (:class:`~taichi.math.vec2`): The input.
+
+     Example::
+
+         >>> @ti.kernel
+         >>> def test():
+         >>>     z = ti.math.vec2(-1, 0)
+         >>>     w = ti.math.csqrt(z)  # [0, 1]
+
+     Returns:
+         :class:`~taichi.math.vec2`: The complex square root.
+     """
+     result = vec2(0.0)
+     if any(z):
+         r = ops.sqrt(z.norm())
+         a = ops.atan2(z[1], z[0])
+         result = r * vec2(ops.cos(a / 2.0), ops.sin(a / 2.0))
+
+     return result
+
+
+ @func
+ def cinv(z):
+     """Computes the reciprocal of a complex number `z`.
+
+     Args:
+         z (:class:`~taichi.math.vec2`): The input.
+
+     Example::
+
+         >>> @ti.kernel
+         >>> def test():
+         >>>     z = ti.math.vec2(1, 1)
+         >>>     w = ti.math.cinv(z)  # [0.5, -0.5]
+
+     Returns:
+         :class:`~taichi.math.vec2`: The reciprocal of `z`.
+     """
+     return cconj(z) / dot(z, z)
+
+
+ @func
+ def cpow(z, n):
+     """Computes the power of a complex number `z`: :math:`z^n`.
+
+     Args:
+         z (:class:`~taichi.math.vec2`): The base.
+         n (float): The exponent.
+
+     Example::
+
+         >>> @ti.kernel
+         >>> def test():
+         >>>     z = ti.math.vec2(1, 1)
+         >>>     w = ti.math.cpow(z, 3)  # [-2, 2]
+
+     Returns:
+         :class:`~taichi.math.vec2`: The power :math:`z^n`.
+     """
+     result = vec2(0.0)
+     if any(z):
+         r2 = dot(z, z)
+         a = ops.atan2(z[1], z[0]) * n
+         result = ops.pow(r2, n / 2.0) * vec2(ops.cos(a), ops.sin(a))
+
+     return result
+
+
+ @func
+ def cexp(z):
+     """Returns the complex exponential :math:`e^z`.
+
+     `z` is a 2d vector treated as a complex number.
+
+     Args:
+         z (:class:`~taichi.math.vec2`): The exponent.
+
+     Example::
+
+         >>> @ti.kernel
+         >>> def test():
+         >>>     z = ti.math.vec2(1, 1)
+         >>>     w = ti.math.cexp(z)  # [1.468694, 2.287355]
+
+     Returns:
+         :class:`~taichi.math.vec2`: The exponential :math:`e^z`.
+     """
+     r = ops.exp(z[0])
+     return vec2(r * ops.cos(z[1]), r * ops.sin(z[1]))
+
+
+ @func
+ def clog(z):
+     """Returns the complex logarithm of `z`, so that if :math:`e^w = z`,
+     then :math:`log(z) = w`.
+
+     `z` is a 2d vector treated as a complex number. The imaginary part of :math:`w`
+     lies in the range (-pi, pi].
+
+     Args:
+         z (:class:`~taichi.math.vec2`): The input.
+
+     Example::
+
+         >>> @ti.kernel
+         >>> def test():
+         >>>     z = ti.math.vec2(1, 1)
+         >>>     w = ti.math.clog(z)  # [0.346574, 0.785398]
+
+     Returns:
+         :class:`~taichi.math.vec2`: The logarithm of `z`.
+     """
+     ang = ops.atan2(z[1], z[0])
+     r2 = dot(z, z)
+     return vec2(ops.log(r2) / 2.0, ang)
+
+
+ __all__ = ["cconj", "cdiv", "cexp", "cinv", "clog", "cmul", "cpow", "csqrt"]