gstaichi 0.1.18.dev1__cp310-cp310-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3.h +6389 -0
  2. gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3native.h +594 -0
  3. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/instrument.hpp +268 -0
  4. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.h +907 -0
  5. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.hpp +375 -0
  6. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/linker.hpp +97 -0
  7. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/optimizer.hpp +970 -0
  8. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  9. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv.h +2568 -0
  10. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv.hpp +2579 -0
  11. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  12. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  13. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  14. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  15. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  16. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  17. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  18. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  19. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  20. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  21. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  22. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  23. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  24. gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  25. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  26. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  27. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  28. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  29. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  30. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  31. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  32. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  33. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  34. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  35. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  36. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  37. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  38. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  39. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  40. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  41. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  42. gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  43. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  44. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  45. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  46. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  47. gstaichi-0.1.18.dev1.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  48. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  49. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  50. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  51. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  52. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  53. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  54. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  55. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  56. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  57. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  58. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  59. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  60. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  61. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  62. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  63. gstaichi-0.1.18.dev1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  64. gstaichi-0.1.18.dev1.dist-info/METADATA +108 -0
  65. gstaichi-0.1.18.dev1.dist-info/RECORD +219 -0
  66. gstaichi-0.1.18.dev1.dist-info/WHEEL +5 -0
  67. gstaichi-0.1.18.dev1.dist-info/entry_points.txt +2 -0
  68. gstaichi-0.1.18.dev1.dist-info/licenses/LICENSE +201 -0
  69. gstaichi-0.1.18.dev1.dist-info/top_level.txt +1 -0
  70. taichi/__init__.py +44 -0
  71. taichi/__main__.py +5 -0
  72. taichi/_funcs.py +706 -0
  73. taichi/_kernels.py +420 -0
  74. taichi/_lib/__init__.py +3 -0
  75. taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +1401 -0
  76. taichi/_lib/c_api/include/taichi/taichi.h +29 -0
  77. taichi/_lib/c_api/include/taichi/taichi_core.h +1111 -0
  78. taichi/_lib/c_api/include/taichi/taichi_cpu.h +29 -0
  79. taichi/_lib/c_api/include/taichi/taichi_metal.h +72 -0
  80. taichi/_lib/c_api/include/taichi/taichi_platform.h +55 -0
  81. taichi/_lib/c_api/include/taichi/taichi_unity.h +64 -0
  82. taichi/_lib/c_api/include/taichi/taichi_vulkan.h +151 -0
  83. taichi/_lib/c_api/lib/libtaichi_c_api.dylib +0 -0
  84. taichi/_lib/c_api/runtime/libMoltenVK.dylib +0 -0
  85. taichi/_lib/c_api/runtime/runtime_arm64.bc +0 -0
  86. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +29 -0
  87. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +65 -0
  88. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +121 -0
  89. taichi/_lib/core/__init__.py +0 -0
  90. taichi/_lib/core/py.typed +0 -0
  91. taichi/_lib/core/taichi_python.cpython-310-darwin.so +0 -0
  92. taichi/_lib/core/taichi_python.pyi +3077 -0
  93. taichi/_lib/runtime/libMoltenVK.dylib +0 -0
  94. taichi/_lib/runtime/runtime_arm64.bc +0 -0
  95. taichi/_lib/utils.py +249 -0
  96. taichi/_logging.py +131 -0
  97. taichi/_main.py +552 -0
  98. taichi/_snode/__init__.py +5 -0
  99. taichi/_snode/fields_builder.py +189 -0
  100. taichi/_snode/snode_tree.py +34 -0
  101. taichi/_ti_module/__init__.py +3 -0
  102. taichi/_ti_module/cppgen.py +309 -0
  103. taichi/_ti_module/module.py +145 -0
  104. taichi/_version.py +1 -0
  105. taichi/_version_check.py +100 -0
  106. taichi/ad/__init__.py +3 -0
  107. taichi/ad/_ad.py +530 -0
  108. taichi/algorithms/__init__.py +3 -0
  109. taichi/algorithms/_algorithms.py +117 -0
  110. taichi/aot/__init__.py +12 -0
  111. taichi/aot/_export.py +28 -0
  112. taichi/aot/conventions/__init__.py +3 -0
  113. taichi/aot/conventions/gfxruntime140/__init__.py +38 -0
  114. taichi/aot/conventions/gfxruntime140/dr.py +244 -0
  115. taichi/aot/conventions/gfxruntime140/sr.py +613 -0
  116. taichi/aot/module.py +253 -0
  117. taichi/aot/utils.py +151 -0
  118. taichi/assets/.git +1 -0
  119. taichi/assets/Go-Regular.ttf +0 -0
  120. taichi/assets/static/imgs/ti_gallery.png +0 -0
  121. taichi/examples/minimal.py +28 -0
  122. taichi/experimental.py +16 -0
  123. taichi/graph/__init__.py +3 -0
  124. taichi/graph/_graph.py +292 -0
  125. taichi/lang/__init__.py +50 -0
  126. taichi/lang/_ndarray.py +348 -0
  127. taichi/lang/_ndrange.py +152 -0
  128. taichi/lang/_texture.py +172 -0
  129. taichi/lang/_wrap_inspect.py +189 -0
  130. taichi/lang/any_array.py +99 -0
  131. taichi/lang/argpack.py +411 -0
  132. taichi/lang/ast/__init__.py +5 -0
  133. taichi/lang/ast/ast_transformer.py +1806 -0
  134. taichi/lang/ast/ast_transformer_utils.py +328 -0
  135. taichi/lang/ast/checkers.py +106 -0
  136. taichi/lang/ast/symbol_resolver.py +57 -0
  137. taichi/lang/ast/transform.py +9 -0
  138. taichi/lang/common_ops.py +310 -0
  139. taichi/lang/exception.py +80 -0
  140. taichi/lang/expr.py +180 -0
  141. taichi/lang/field.py +464 -0
  142. taichi/lang/impl.py +1246 -0
  143. taichi/lang/kernel_arguments.py +157 -0
  144. taichi/lang/kernel_impl.py +1415 -0
  145. taichi/lang/matrix.py +1877 -0
  146. taichi/lang/matrix_ops.py +341 -0
  147. taichi/lang/matrix_ops_utils.py +190 -0
  148. taichi/lang/mesh.py +687 -0
  149. taichi/lang/misc.py +807 -0
  150. taichi/lang/ops.py +1489 -0
  151. taichi/lang/runtime_ops.py +13 -0
  152. taichi/lang/shell.py +35 -0
  153. taichi/lang/simt/__init__.py +5 -0
  154. taichi/lang/simt/block.py +94 -0
  155. taichi/lang/simt/grid.py +7 -0
  156. taichi/lang/simt/subgroup.py +191 -0
  157. taichi/lang/simt/warp.py +96 -0
  158. taichi/lang/snode.py +487 -0
  159. taichi/lang/source_builder.py +150 -0
  160. taichi/lang/struct.py +855 -0
  161. taichi/lang/util.py +381 -0
  162. taichi/linalg/__init__.py +8 -0
  163. taichi/linalg/matrixfree_cg.py +310 -0
  164. taichi/linalg/sparse_cg.py +59 -0
  165. taichi/linalg/sparse_matrix.py +303 -0
  166. taichi/linalg/sparse_solver.py +123 -0
  167. taichi/math/__init__.py +11 -0
  168. taichi/math/_complex.py +204 -0
  169. taichi/math/mathimpl.py +886 -0
  170. taichi/profiler/__init__.py +6 -0
  171. taichi/profiler/kernel_metrics.py +260 -0
  172. taichi/profiler/kernel_profiler.py +592 -0
  173. taichi/profiler/memory_profiler.py +15 -0
  174. taichi/profiler/scoped_profiler.py +36 -0
  175. taichi/shaders/Circles_vk.frag +29 -0
  176. taichi/shaders/Circles_vk.vert +45 -0
  177. taichi/shaders/Circles_vk_frag.spv +0 -0
  178. taichi/shaders/Circles_vk_vert.spv +0 -0
  179. taichi/shaders/Lines_vk.frag +9 -0
  180. taichi/shaders/Lines_vk.vert +11 -0
  181. taichi/shaders/Lines_vk_frag.spv +0 -0
  182. taichi/shaders/Lines_vk_vert.spv +0 -0
  183. taichi/shaders/Mesh_vk.frag +71 -0
  184. taichi/shaders/Mesh_vk.vert +68 -0
  185. taichi/shaders/Mesh_vk_frag.spv +0 -0
  186. taichi/shaders/Mesh_vk_vert.spv +0 -0
  187. taichi/shaders/Particles_vk.frag +95 -0
  188. taichi/shaders/Particles_vk.vert +73 -0
  189. taichi/shaders/Particles_vk_frag.spv +0 -0
  190. taichi/shaders/Particles_vk_vert.spv +0 -0
  191. taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
  192. taichi/shaders/SceneLines_vk.frag +9 -0
  193. taichi/shaders/SceneLines_vk.vert +12 -0
  194. taichi/shaders/SceneLines_vk_frag.spv +0 -0
  195. taichi/shaders/SceneLines_vk_vert.spv +0 -0
  196. taichi/shaders/SetImage_vk.frag +21 -0
  197. taichi/shaders/SetImage_vk.vert +15 -0
  198. taichi/shaders/SetImage_vk_frag.spv +0 -0
  199. taichi/shaders/SetImage_vk_vert.spv +0 -0
  200. taichi/shaders/Triangles_vk.frag +16 -0
  201. taichi/shaders/Triangles_vk.vert +29 -0
  202. taichi/shaders/Triangles_vk_frag.spv +0 -0
  203. taichi/shaders/Triangles_vk_vert.spv +0 -0
  204. taichi/shaders/lines2quad_vk_comp.spv +0 -0
  205. taichi/sparse/__init__.py +3 -0
  206. taichi/sparse/_sparse_grid.py +77 -0
  207. taichi/tools/__init__.py +12 -0
  208. taichi/tools/diagnose.py +124 -0
  209. taichi/tools/np2ply.py +364 -0
  210. taichi/tools/vtk.py +38 -0
  211. taichi/types/__init__.py +19 -0
  212. taichi/types/annotations.py +47 -0
  213. taichi/types/compound_types.py +90 -0
  214. taichi/types/enums.py +49 -0
  215. taichi/types/ndarray_type.py +147 -0
  216. taichi/types/primitive_types.py +203 -0
  217. taichi/types/quant.py +88 -0
  218. taichi/types/texture_type.py +85 -0
  219. taichi/types/utils.py +13 -0
taichi/_funcs.py ADDED
@@ -0,0 +1,706 @@
1
+ # type: ignore
2
+
3
+ import math
4
+
5
+ from taichi.lang import impl, ops
6
+ from taichi.lang.impl import get_runtime, grouped, static
7
+ from taichi.lang.kernel_impl import func
8
+ from taichi.lang.matrix import Matrix, Vector
9
+ from taichi.types import f32, f64
10
+ from taichi.types.annotations import template
11
+
12
+
13
@func
def _randn(dt):
    """
    Draw one sample from the standard normal distribution (mean 0,
    variance 1) using the Box-Muller transform.

    The first uniform draw is used as ``1 - random`` so that, assuming
    ``ops.random`` yields values in [0, 1), the logarithm never receives 0.

    Args:
        dt (DataType): ti.f32 or ti.f64.

    Returns:
        A scalar of type ``dt``.
    """
    assert dt == f32 or dt == f64
    one = ops.cast(1.0, dt)
    # Radius from the first uniform draw, angle from the second.
    radius = ops.sqrt(-2 * ops.log(one - ops.random(dt)))
    angle = math.tau * ops.random(dt)
    return radius * ops.cos(angle)
26
+
27
+
28
def randn(dt=None):
    """Draw a standard-normal (mean 0, variance 1) random float.

    The sample is produced with the Box-Muller transformation and this
    function must be called from Taichi scope.

    Args:
        dt (DataType): Data type of the generated number. Defaults to `None`,
            in which case the runtime's default floating-point type is used.

    Returns:
        The generated random float.

    Example::

        >>> @ti.kernel
        >>> def main():
        >>>     print(ti.randn())
        >>>
        >>> main()
        -0.463608
    """
    resolved_dt = impl.get_runtime().default_fp if dt is None else dt
    return _randn(resolved_dt)
52
+
53
+
54
@func
def _polar_decompose2d(A, dt):
    """Polar-decompose a 2x2 matrix as A = U P with a closed-form formula.

    Mathematical concept refers to https://en.wikipedia.org/wiki/Polar_decomposition.

    Args:
        A (ti.Matrix(2, 2)): input 2x2 matrix `A`.
        dt (DataType): data type of the results, typically ti.f32 or ti.f64.

    Returns:
        Decomposed 2x2 matrices `U` and `P`. `U` is a 2x2 orthogonal matrix
        and `P` is a 2x2 positive or semi-positive definite matrix. For the
        zero matrix the pair (I, A) is returned.
    """
    zero = ops.cast(0.0, dt)
    U = Matrix.identity(dt, 2)
    P = ops.cast(A, dt)
    # Only a non-zero A needs the closed-form construction; the zero matrix
    # keeps the trivial (I, A) pair initialised above.
    if A[0, 0] != zero or A[0, 1] != zero or A[1, 0] != zero or A[1, 1] != zero:
        det_a = A[0, 0] * A[1, 1] - A[1, 0] * A[0, 1]
        # Unnormalised U for det(A) >= 0 ...
        B = Matrix(
            [
                [A[0, 0] + A[1, 1], A[0, 1] - A[1, 0]],
                [A[1, 0] - A[0, 1], A[1, 1] + A[0, 0]],
            ],
            dt,
        )
        # ... and the reflection variant when det(A) < 0.
        if det_a < zero:
            B = Matrix(
                [
                    [A[0, 0] - A[1, 1], A[0, 1] + A[1, 0]],
                    [A[1, 0] + A[0, 1], A[1, 1] - A[0, 0]],
                ],
                dt,
            )
        # det(B) is non-zero for any non-zero A.
        det_b = B[0, 0] * B[1, 1] - B[1, 0] * B[0, 1]
        k = ops.cast(1.0, dt) / ops.sqrt(abs(det_b))
        U = B * k
        P = (A.transpose() @ A + abs(det_a) * Matrix.identity(dt, 2)) * k
    return U, P
99
+
100
+
101
@func
def _polar_decompose3d(A, dt):
    """Polar-decompose a 3x3 matrix (A = U P) through its SVD.

    Mathematical concept refers to https://en.wikipedia.org/wiki/Polar_decomposition.

    With A = W S V^T, the factors are U = W V^T and P = V S V^T.

    Args:
        A (ti.Matrix(3, 3)): input 3x3 matrix `A`.
        dt (DataType): data type of the results, typically ti.f32 or ti.f64.

    Returns:
        Decomposed 3x3 matrices `U` and `P`.
    """
    left, sigma, right = _svd3d(A, dt)
    right_t = right.transpose()
    return left @ right_t, right @ sigma @ right_t
116
+
117
+
118
# https://www.seas.upenn.edu/~cffjiang/research/svd/svd.pdf
@func
def _svd2d(A, dt):
    """Perform singular value decomposition (A=USV^T) for 2x2 matrix.

    Mathematical concept refers to https://en.wikipedia.org/wiki/Singular_value_decomposition.

    A is first polar-decomposed as A = R S (S symmetric). A single rotation V
    then diagonalizes S (V^T S V = diag(s1, s2)), so A = (R V) diag V^T.

    Args:
        A (ti.Matrix(2, 2)): input 2x2 matrix `A`.
        dt (DataType): data type of elements in matrix `A`, typically accepts ti.f32 or ti.f64.

    Returns:
        Decomposed 2x2 matrices `U`, `S` and `V`. `S` is diagonal and its
        entries are ordered so that S[0, 0] >= S[1, 1].
    """
    R, S = _polar_decompose2d(A, dt)
    c, s = ops.cast(0.0, dt), ops.cast(0.0, dt)
    s1, s2 = ops.cast(0.0, dt), ops.cast(0.0, dt)
    # NOTE(review): 1e-5 is an absolute threshold. For a matrix whose entries
    # are all tiny, S is treated as already diagonal even when it is not.
    # Confirm this is intended.
    if abs(S[0, 1]) < 1e-5:
        # S is (numerically) diagonal already: use the identity rotation.
        c, s = 1, 0
        s1, s2 = S[0, 0], S[1, 1]
    else:
        # Rotation (cosine c, sine s) that zeroes the off-diagonal of V^T S V.
        # The sign branch on tao avoids cancellation in the denominator.
        tao = ops.cast(0.5, dt) * (S[0, 0] - S[1, 1])
        w = ops.sqrt(tao**2 + S[0, 1] ** 2)
        t = ops.cast(0.0, dt)
        if tao > 0:
            t = S[0, 1] / (tao + w)
        else:
            t = S[0, 1] / (tao - w)
        c = 1 / ops.sqrt(t**2 + 1)
        s = -t * c
        # Diagonal entries of V^T S V with V = [[c, s], [-s, c]].
        s1 = c**2 * S[0, 0] - 2 * c * s * S[0, 1] + s**2 * S[1, 1]
        s2 = s**2 * S[0, 0] + 2 * c * s * S[0, 1] + c**2 * S[1, 1]
    V = Matrix.zero(dt, 2, 2)
    if s1 < s2:
        # Put the larger value first. The columns of V are swapped and one of
        # them is negated, which keeps det(V) = c^2 + s^2 = +1.
        tmp = s1
        s1 = s2
        s2 = tmp
        V = Matrix([[-s, c], [-c, -s]], dt=dt)
    else:
        V = Matrix([[c, s], [-s, c]], dt=dt)
    U = R @ V
    return U, Matrix([[s1, ops.cast(0, dt)], [ops.cast(0, dt), s2]], dt=dt), V
160
+
161
+
162
def _svd3d(A, dt, iters=None):
    """Singular value decomposition (A = U S V^T) of a 3x3 matrix.

    Mathematical concept refers to https://en.wikipedia.org/wiki/Singular_value_decomposition.

    The factorisation is emitted by the AST builder's built-in
    ``sifakis_svd_f32`` / ``sifakis_svd_f64`` routine. The routine returns 21
    scalars: 9 entries of U, then 9 of V, then the 3 singular values.

    Args:
        A (ti.Matrix(3, 3)): input 3x3 matrix `A`.
        dt (DataType): ti.f32 or ti.f64.
        iters (int): number of iterations handed to the solver; defaults to
            5 for f32 and 8 for f64.

    Returns:
        Decomposed 3x3 matrices `U`, `S` and `V` (S diagonal).
    """
    assert A.n == 3 and A.m == 3
    assert dt in [f32, f64]
    is_single = dt == f32
    if iters is None:
        iters = 5 if is_single else 8
    builder = get_runtime().compiling_callable.ast_builder()
    solve = builder.sifakis_svd_f32 if is_single else builder.sifakis_svd_f64
    rets = solve(A.ptr, iters)
    assert len(rets) == 21
    u_flat, v_flat, sigma_diag = rets[:9], rets[9:18], rets[18:]

    @func
    def get_result():
        U = Matrix.zero(dt, 3, 3)
        V = Matrix.zero(dt, 3, 3)
        sigma = Matrix.zero(dt, 3, 3)
        for row in static(range(3)):
            sigma[row, row] = sigma_diag[row]
            for col in static(range(3)):
                # Entries come back flattened in row-major order.
                U[row, col] = u_flat[3 * row + col]
                V[row, col] = v_flat[3 * row + col]
        return U, sigma, V

    return get_result()
204
+
205
+
206
@func
def _eig2x2(A, dt):
    """Compute the eigenvalues and right eigenvectors (Av=lambda v) of a 2x2 real matrix.

    Mathematical concept refers to https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix.

    The eigenvalues are the roots of x^2 - tr(A) x + det(A). Eigenvectors are
    read from the columns of A - lambda I. By Cayley-Hamilton,
    (A - lambda1 I)(A - lambda2 I) = 0, so every column of A - lambda2 I is an
    eigenvector for lambda1, and vice versa.

    Args:
        A (ti.Matrix(2, 2)): input 2x2 matrix `A`.
        dt (DataType): data type of elements in matrix `A`, typically accepts ti.f32 or ti.f64.

    Returns:
        eigenvalues (ti.Matrix(2, 2)): The eigenvalues in complex form. Each row stores one eigenvalue. The first number of the eigenvalue represents the real part and the second number represents the imaginary part.
        eigenvectors: (ti.Matrix(4, 2)): The eigenvectors in complex form. Each column stores one eigenvector. Each eigenvector consists of 2 entries, each of which is represented by two numbers for its real part and imaginary part.
    """
    tr = A.trace()
    det = A.determinant()
    gap = tr**2 - 4 * det
    lambda1 = Vector.zero(dt, 2)
    lambda2 = Vector.zero(dt, 2)
    v1 = Vector.zero(dt, 4)
    v2 = Vector.zero(dt, 4)
    if gap >= 0:
        # Real eigenvalues. gap == 0 (a repeated real eigenvalue) is handled
        # here rather than in the complex branch, so that multiples of the
        # identity and defective matrices still get finite eigenvectors.
        lambda1 = Vector([tr + ops.sqrt(gap), 0.0], dt=dt) * 0.5
        lambda2 = Vector([tr - ops.sqrt(gap), 0.0], dt=dt) * 0.5
        A1 = A - lambda1[0] * Matrix.identity(dt, 2)
        A2 = A - lambda2[0] * Matrix.identity(dt, 2)
        if all(A1 == Matrix.zero(dt, 2, 2)) and all(A2 == Matrix.zero(dt, 2, 2)):
            # A is a multiple of the identity: any basis is an eigenbasis.
            v1 = Vector([0.0, 0.0, 1.0, 0.0]).cast(dt)
            v2 = Vector([1.0, 0.0, 0.0, 0.0]).cast(dt)
        else:
            # Take the larger-norm column of A - lambda I. A single column
            # can be zero (e.g. for diagonal or triangular A), and normalizing
            # it would produce NaN.
            col_a = Vector([A2[0, 0], 0.0, A2[1, 0], 0.0], dt=dt)
            col_b = Vector([A2[0, 1], 0.0, A2[1, 1], 0.0], dt=dt)
            if col_b.norm_sqr() > col_a.norm_sqr():
                col_a = col_b
            v1 = col_a.normalized()
            col_c = Vector([A1[0, 0], 0.0, A1[1, 0], 0.0], dt=dt)
            col_d = Vector([A1[0, 1], 0.0, A1[1, 1], 0.0], dt=dt)
            if col_d.norm_sqr() > col_c.norm_sqr():
                col_c = col_d
            v2 = col_c.normalized()
    else:
        # Complex-conjugate pair tr/2 +/- i*sqrt(-gap)/2. The imaginary part
        # on the diagonal of A - lambda I keeps column 0 non-zero here.
        lambda1 = Vector([tr, ops.sqrt(-gap)], dt=dt) * 0.5
        lambda2 = Vector([tr, -ops.sqrt(-gap)], dt=dt) * 0.5
        A1r = A - lambda1[0] * Matrix.identity(dt, 2)
        A1i = -lambda1[1] * Matrix.identity(dt, 2)
        A2r = A - lambda2[0] * Matrix.identity(dt, 2)
        A2i = -lambda2[1] * Matrix.identity(dt, 2)
        v1 = Vector([A2r[0, 0], A2i[0, 0], A2r[1, 0], A2i[1, 0]], dt=dt).normalized()
        v2 = Vector([A1r[0, 0], A1i[0, 0], A1r[1, 0], A1i[1, 0]], dt=dt).normalized()
    eigenvalues = Matrix.rows([lambda1, lambda2])
    eigenvectors = Matrix.cols([v1, v2])

    return eigenvalues, eigenvectors
251
+
252
+
253
@func
def _sym_eig2x2(A, dt):
    """Compute the eigenvalues and right eigenvectors (Av=lambda v) of a 2x2 real symmetric matrix.

    Mathematical concept refers to https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix.

    Args:
        A (ti.Matrix(2, 2)): input 2x2 symmetric matrix `A`.
        dt (DataType): data type of elements in matrix `A`, typically accepts ti.f32 or ti.f64.

    Returns:
        eigenvalues (ti.Vector(2)): The eigenvalues, larger one first.
        eigenvectors (ti.Matrix(2, 2)): The eigenvectors. Column i is the eigenvector of eigenvalues[i].
    """
    assert all(A == A.transpose()), "A needs to be symmetric"
    tr = A.trace()
    # For symmetric A, tr^2 - 4 det == (a - d)^2 + 4 b^2. The sum-of-squares
    # form cannot turn negative through cancellation, so sqrt never yields NaN.
    # It also stays accurate when the two eigenvalues nearly coincide.
    diff = A[0, 0] - A[1, 1]
    gap = diff * diff + 4 * A[0, 1] * A[0, 1]
    root = ops.sqrt(gap)
    lambda1 = (tr + root) * 0.5
    lambda2 = (tr - root) * 0.5
    eigenvalues = Vector([lambda1, lambda2], dt=dt)

    A1 = A - lambda1 * Matrix.identity(dt, 2)
    A2 = A - lambda2 * Matrix.identity(dt, 2)
    v1 = Vector.zero(dt, 2)
    v2 = Vector.zero(dt, 2)
    if all(A1 == Matrix.zero(dt, 2, 2)) and all(A2 == Matrix.zero(dt, 2, 2)):
        # A is a multiple of the identity: any orthonormal basis works.
        v1 = Vector([0.0, 1.0]).cast(dt)
        v2 = Vector([1.0, 0.0]).cast(dt)
    else:
        # (A - lambda1 I)(A - lambda2 I) = 0, so the columns of A2 are
        # eigenvectors for lambda1 and those of A1 for lambda2. Take the
        # larger-norm column: one column alone can be zero (e.g. diag(2, 1)),
        # and normalizing it would produce NaN.
        col_a = Vector([A2[0, 0], A2[1, 0]], dt=dt)
        col_b = Vector([A2[0, 1], A2[1, 1]], dt=dt)
        if col_b.norm_sqr() > col_a.norm_sqr():
            col_a = col_b
        v1 = col_a.normalized()
        col_c = Vector([A1[0, 0], A1[1, 0]], dt=dt)
        col_d = Vector([A1[0, 1], A1[1, 1]], dt=dt)
        if col_d.norm_sqr() > col_c.norm_sqr():
            col_c = col_d
        v2 = col_c.normalized()
    eigenvectors = Matrix.cols([v1, v2])
    return eigenvalues, eigenvectors
287
+
288
+
289
@func
def dsytrd3(A, Q, dt):
    """Reduce a symmetric 3x3 matrix to tridiagonal form with one Householder reflection.

    Follows ``dsytrd3`` from the 3x3 eigensolver collection at
    https://www.mpi-hd.mpg.de/personalhomes/globes/3x3/. Only the upper
    triangle of ``A`` is read.

    Args:
        A (ti.Matrix(3, 3)): symmetric input matrix.
        Q (ti.Matrix(3, 3)): storage for the transformation. Its incoming
            contents are ignored, because it is reset to the identity first.
        dt (DataType): ti.f32 or ti.f64.

    Returns:
        (d, e, Q): the diagonal ``d`` and off-diagonal ``e`` (entries 0 and 1
        are used) of the tridiagonal matrix, and the accumulated Householder
        transformation ``Q``.
    """
    # Start from the full identity. dsyevq3 forwards the partially filled Q
    # built by _sym_eig3x3, so the off-diagonal entries must be cleared too,
    # not just the diagonal set to 1.
    for i in static(range(3)):
        for j in static(range(3)):
            Q[i, j] = 0.0
        Q[i, i] = 1.0
    e = Vector([0.0, 0.0, 0.0], dt=dt)
    u = Vector([0.0, 0.0, 0.0], dt=dt)
    q = Vector([0.0, 0.0, 0.0], dt=dt)
    d = Vector([0.0, 0.0, 0.0], dt=dt)
    h = A[0, 1] ** 2 + A[0, 2] ** 2
    # Declared with type dt: a bare 0.0 literal takes Taichi's default float
    # type and could truncate f64 results.
    g = ops.cast(0.0, dt)
    # Choose the sign of g opposite to A[0, 1] to avoid cancellation in u[1].
    if A[0, 1] > 0:
        g = -ops.sqrt(h)
    else:
        g = ops.sqrt(h)
    e[0] = g
    f = g * A[0, 1]
    u[1] = A[0, 1] - g
    u[2] = A[0, 2]
    omega = h - f
    if omega > 0.0:
        omega = 1.0 / omega
        K = ops.cast(0.0, dt)
        f = A[1, 1] * u[1] + A[1, 2] * u[2]
        q[1] = omega * f  # p
        K += u[1] * f  # u* A u

        f = A[1, 2] * u[1] + A[2, 2] * u[2]
        q[2] = omega * f  # p
        K += u[2] * f  # u* A u

        K *= 0.5 * omega * omega

        q[1] = q[1] - K * u[1]
        q[2] = q[2] - K * u[2]

        d[0] = A[0, 0]
        d[1] = A[1, 1] - 2.0 * q[1] * u[1]
        d[2] = A[2, 2] - 2.0 * q[2] * u[2]

        # Store the Householder transformation in Q.
        for j in range(1, 3):
            f = omega * u[j]
            for i in range(1, 3):
                Q[i, j] = Q[i, j] - f * u[i]

        # Calculate updated A[1, 2] and store it in e[1]
        e[1] = A[1, 2] - q[1] * u[2] - u[1] * q[2]
    else:
        # First row is already reduced: A is tridiagonal as given.
        d[0] = A[0, 0]
        d[1] = A[1, 1]
        d[2] = A[2, 2]
        e[1] = A[1, 2]
    return d, e, Q
342
+
343
+
344
@func
def dsyevq3(A, Q, w, dt):
    """Eigen-decompose a symmetric 3x3 matrix with the QL algorithm (implicit shifts).

    Follows ``dsyevq3`` from the 3x3 eigensolver collection at
    https://www.mpi-hd.mpg.de/personalhomes/globes/3x3/. The matrix is first
    tridiagonalised by ``dsytrd3``. The off-diagonal entries are then driven
    to zero by Givens rotations, which are accumulated into ``Q``.

    Args:
        A (ti.Matrix(3, 3)): symmetric input matrix.
        Q (ti.Matrix(3, 3)): passed through to ``dsytrd3``, which reinitialises it.
        w (ti.Vector(3)): ignored; overwritten with the diagonal from ``dsytrd3``.
        dt (DataType): ti.f32 or ti.f64.

    Returns:
        (Q, w): eigenvectors as the columns of ``Q`` and the matching
        eigenvalues ``w``. The eigenvalues are not sorted.
    """
    w, e, Q = dsytrd3(A, Q, dt)
    for l in range(0, 2):
        nIter = 0
        while True:
            # Find the first m >= l whose off-diagonal element e[m] is
            # negligible relative to its diagonal neighbours. If there is none,
            # m must end at 2 (the last row) so that the sweep below covers the
            # whole unreduced block. A `for i in range(l, 2): m = i` loop stops
            # at 1 instead, which skips l = 1 entirely and wrongly zeroes e[1].
            m = l
            while m < 2:
                g = ops.abs(w[m]) + ops.abs(w[m + 1])
                if ops.abs(e[m]) + g == g:
                    break
                m += 1
            if m == l:
                break

            nIter += 1
            assert nIter <= 30, "Timeout"
            # Taichi only checks asserts in debug mode. Stop iterating rather
            # than spinning forever when asserts are disabled.
            if nIter > 30:
                break

            # Calculate g = d_m - k
            g = (w[l + 1] - w[l]) / (e[l] + e[l])
            r = ops.sqrt(g * g + 1.0)
            if g > 0:
                g = w[m] - w[l] + e[l] / (g + r)
            else:
                g = w[m] - w[l] + e[l] / (g - r)

            # Declared with type dt so f64 inputs are not truncated to the
            # default float type.
            s = ops.cast(1.0, dt)
            c = ops.cast(1.0, dt)
            p = ops.cast(0.0, dt)
            i = m - 1
            while i >= l:
                f = s * e[i]
                b = c * e[i]
                # Givens rotation, computed in the overflow-safe orientation.
                if ops.abs(f) > ops.abs(g):
                    c = g / f
                    r = ops.sqrt(c * c + 1.0)
                    e[i + 1] = f * r
                    s = 1.0 / r
                    c *= s
                else:
                    s = f / g
                    r = ops.sqrt(s * s + 1.0)
                    e[i + 1] = g * r
                    c = 1.0 / r
                    s *= c

                g = w[i + 1] - p
                r = (w[i] - g) * s + 2.0 * c * b
                p = s * r
                w[i + 1] = g + p
                g = c * r - b

                # Accumulate the rotation into the eigenvector matrix.
                for k in range(0, 3):
                    t = Q[k, i + 1]
                    Q[k, i + 1] = s * Q[k, i] + c * t
                    Q[k, i] = c * Q[k, i] - s * t

                i -= 1
            w[l] -= p
            e[l] = g
            e[m] = 0.0
    return Q, w
407
+
408
+
409
@func
def _sym_eig3x3(A, dt):
    """Compute the eigenvalues and right eigenvectors (Av=lambda v) of a 3x3 real symmetric matrix using Cardano's method.

    Mathematical concept refers to https://www.mpi-hd.mpg.de/personalhomes/globes/3x3/.

    Eigenvalues come from the closed-form trigonometric solution of the
    characteristic cubic. The first two eigenvectors are each the cross
    product of the first two columns of A - lambda I. The third is the cross
    product of the first two eigenvectors. When a cross product is too small
    to normalize reliably, the problem is handed to the iterative QL solver
    ``dsyevq3`` instead.

    Args:
        A (ti.Matrix(3, 3)): input 3x3 symmetric matrix `A`.
        dt (DataType): data type of elements in matrix `A`, typically accepts ti.f32 or ti.f64.

    Returns:
        eigenvalues (ti.Vector(3)): The eigenvalues, sorted in ascending order.
        eigenvectors (ti.Matrix(3, 3)): The eigenvectors. Column i is the eigenvector of eigenvalues[i].
    """
    assert all(A == A.transpose()), "A needs to be symmetric"
    M_SQRT3 = 1.73205080756887729352744634151
    # NOTE(review): double-precision epsilon is used even when dt is f32, so
    # the fallback threshold below is far tighter than f32 rounding. Confirm
    # this is intended.
    DBL_EPSILON = 2.2204460492503131e-16
    # Characteristic polynomial: lambda^3 - m lambda^2 + c1 lambda + c0 = 0,
    # with m = tr(A), c1 = sum of principal 2x2 minors, and c0 = -det(A).
    m = A.trace()
    dd = A[0, 1] * A[0, 1]
    ee = A[1, 2] * A[1, 2]
    ff = A[0, 2] * A[0, 2]
    c1 = A[0, 0] * A[1, 1] + A[0, 0] * A[2, 2] + A[1, 1] * A[2, 2] - (dd + ee + ff)
    c0 = A[2, 2] * dd + A[0, 0] * ee + A[1, 1] * ff - A[0, 0] * A[1, 1] * A[2, 2] - 2.0 * A[0, 2] * A[0, 1] * A[1, 2]

    # Trigonometric solution of the cubic.
    p = m * m - 3.0 * c1
    q = m * (p - 1.5 * c1) - 13.5 * c0
    sqrt_p = ops.sqrt(ops.abs(p))
    phi = 27.0 * (0.25 * c1 * c1 * (p - c1) + c0 * (q + 6.75 * c0))
    phi = (1.0 / 3.0) * ops.atan2(ops.sqrt(ops.abs(phi)), q)

    c = sqrt_p * ops.cos(phi)
    s = (1.0 / M_SQRT3) * sqrt_p * ops.sin(phi)
    eigenvalues = Vector([0.0, 0.0, 0.0], dt=dt)
    eigenvalues_final = Vector([0.0, 0.0, 0.0], dt=dt)
    eigenvalues[1] = (1.0 / 3.0) * (m - c)
    eigenvalues[2] = eigenvalues[1] + s
    eigenvalues[0] = eigenvalues[1] + c
    eigenvalues[1] = eigenvalues[1] - s

    # Roundoff estimate: t = max |eigenvalue|, error = 256 eps u^2, where
    # u = t if t < 1 and u = t^2 otherwise.
    t = ops.abs(eigenvalues[0])
    u = ops.abs(eigenvalues[1])
    if u > t:
        t = u
    u = ops.abs(eigenvalues[2])
    if u > t:
        t = u
    if t < 1.0:
        u = t
    else:
        u = t * t
    error = 256.0 * DBL_EPSILON * u * u
    Q = Matrix.zero(dt, 3, 3)
    Q_final = Matrix.zero(dt, 3, 3)
    # Eigenvalue-independent parts of the cross product
    # (A - lambda I)[:, 0] x (A - lambda I)[:, 1], kept in column 1.
    Q[0, 1] = A[0, 1] * A[1, 2] - A[0, 2] * A[1, 1]
    Q[1, 1] = A[0, 2] * A[0, 1] - A[1, 2] * A[0, 0]
    Q[2, 1] = A[0, 1] * A[0, 1]

    # First eigenvector: the cross product for eigenvalues[0].
    Q[0, 0] = Q[0, 1] + A[0, 2] * eigenvalues[0]
    Q[1, 0] = Q[1, 1] + A[1, 2] * eigenvalues[0]
    Q[2, 0] = (A[0, 0] - eigenvalues[0]) * (A[1, 1] - eigenvalues[0]) - Q[2, 1]
    norm = Q[0, 0] * Q[0, 0] + Q[1, 0] * Q[1, 0] + Q[2, 0] * Q[2, 0]
    early_ret = 0
    if norm <= error:
        # The two columns are nearly parallel (e.g. a repeated eigenvalue):
        # fall back to the QL solver.
        # NOTE(review): Q is only partially filled at this point. The
        # dsyevq3 -> dsytrd3 path must fully reinitialize it.
        Q_final, eigenvalues_final = dsyevq3(A, Q, eigenvalues, dt)
        early_ret = 1
    else:
        norm = ops.sqrt(1.0 / norm)
        Q[0, 0] *= norm
        Q[1, 0] *= norm
        Q[2, 0] *= norm

    if not early_ret:
        # Second eigenvector: the same construction for eigenvalues[1].
        Q[0, 1] = Q[0, 1] + A[0, 2] * eigenvalues[1]
        Q[1, 1] = Q[1, 1] + A[1, 2] * eigenvalues[1]
        Q[2, 1] = (A[0, 0] - eigenvalues[1]) * (A[1, 1] - eigenvalues[1]) - Q[2, 1]
        norm = Q[0, 1] * Q[0, 1] + Q[1, 1] * Q[1, 1] + Q[2, 1] * Q[2, 1]
        if norm <= error:
            Q_final, eigenvalues_final = dsyevq3(A, Q, eigenvalues, dt)
            early_ret = 1
        else:
            norm = ops.sqrt(1.0 / norm)
            Q[0, 1] *= norm
            Q[1, 1] *= norm
            Q[2, 1] *= norm

    # Third eigenvector: cross product of the first two (already unit length).
    Q[0, 2] = Q[1, 0] * Q[2, 1] - Q[2, 0] * Q[1, 1]
    Q[1, 2] = Q[2, 0] * Q[0, 1] - Q[0, 0] * Q[2, 1]
    Q[2, 2] = Q[0, 0] * Q[1, 1] - Q[1, 0] * Q[0, 1]

    if early_ret:
        # Discard the analytic partial result in favour of the QL solution.
        Q = Q_final
        eigenvalues = eigenvalues_final

    # Three compare-and-swap steps sort the eigenvalues ascending, swapping
    # the matching eigenvector columns along with them.
    if eigenvalues[1] < eigenvalues[0]:
        tmp = eigenvalues[0]
        eigenvalues[0] = eigenvalues[1]
        eigenvalues[1] = tmp
        tmp2 = Q[:, 0]
        Q[:, 0] = Q[:, 1]
        Q[:, 1] = tmp2

    if eigenvalues[2] < eigenvalues[0]:
        tmp = eigenvalues[0]
        eigenvalues[0] = eigenvalues[2]
        eigenvalues[2] = tmp
        tmp2 = Q[:, 0]
        Q[:, 0] = Q[:, 2]
        Q[:, 2] = tmp2

    if eigenvalues[2] < eigenvalues[1]:
        tmp = eigenvalues[1]
        eigenvalues[1] = eigenvalues[2]
        eigenvalues[2] = tmp
        tmp2 = Q[:, 1]
        Q[:, 1] = Q[:, 2]
        Q[:, 2] = tmp2

    return eigenvalues, Q
527
+
528
+
529
def polar_decompose(A, dt=None):
    """Polar decomposition A = U P of a 2x2 or 3x3 matrix.

    Mathematical concept refers to https://en.wikipedia.org/wiki/Polar_decomposition.

    Args:
        A (ti.Matrix(n, n)): square input matrix with n equal to 2 or 3.
        dt (DataType): element type of the results, typically ti.f32 or
            ti.f64. Defaults to the runtime's default float type.

    Returns:
        The nxn factors `U` and `P`.

    Raises:
        Exception: if `A` is neither 2x2 nor 3x3.
    """
    dt = impl.get_runtime().default_fp if dt is None else dt
    if A.n not in (2, 3):
        raise Exception("Polar decomposition only supports 2D and 3D matrices.")
    solver = _polar_decompose2d if A.n == 2 else _polar_decompose3d
    return solver(A, dt)
548
+
549
+
550
def svd(A, dt=None):
    """Singular value decomposition A = U S V^T of a 2x2 or 3x3 matrix.

    Mathematical concept refers to https://en.wikipedia.org/wiki/Singular_value_decomposition.

    Args:
        A (ti.Matrix(n, n)): square input matrix with n equal to 2 or 3.
        dt (DataType): element type of the results, typically ti.f32 or
            ti.f64. Defaults to the runtime's default float type.

    Returns:
        The nxn matrices `U`, `S` and `V`.

    Raises:
        Exception: if `A` is neither 2x2 nor 3x3.
    """
    dt = impl.get_runtime().default_fp if dt is None else dt
    size = A.n
    if size == 2:
        return _svd2d(A, dt)
    elif size == 3:
        return _svd3d(A, dt)
    else:
        raise Exception("SVD only supports 2D and 3D matrices.")
569
+
570
+
571
def eig(A, dt=None):
    """Eigenvalues and right eigenvectors of a real 2x2 matrix.

    Mathematical concept refers to https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix.

    Args:
        A (ti.Matrix(n, n)): 2D Matrix for which the eigenvalues and right eigenvectors will be computed.
        dt (DataType): The datatype for the eigenvalues and right eigenvectors.
            Defaults to the runtime's default float type.

    Returns:
        eigenvalues (ti.Matrix(n, 2)): The eigenvalues in complex form. Each row stores one eigenvalue. The first number of the eigenvalue represents the real part and the second number represents the imaginary part.
        eigenvectors (ti.Matrix(n*2, n)): The eigenvectors in complex form. Each column stores one eigenvector. Each eigenvector consists of n entries, each of which is represented by two numbers for its real part and imaginary part.

    Raises:
        Exception: if `A` is not 2x2.
    """
    dt = impl.get_runtime().default_fp if dt is None else dt
    if A.n != 2:
        raise Exception("Eigen solver only supports 2D matrices.")
    return _eig2x2(A, dt)
589
+
590
+
591
def sym_eig(A, dt=None):
    """Compute the eigenvalues and right eigenvectors of a real symmetric 2x2 or 3x3 matrix.

    Mathematical concept refers to https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix.

    Args:
        A (ti.Matrix(n, n)): Symmetric matrix whose eigenvalues and right
            eigenvectors are computed; only n == 2 and n == 3 are supported.
        dt (DataType): The datatype for the eigenvalues and right eigenvectors.
            When omitted, the runtime's default floating-point type is used.

    Returns:
        eigenvalues (ti.Vector(n)): The eigenvalues. Each entry stores one eigenvalue.
        eigenvectors (ti.Matrix(n, n)): The eigenvectors. Each column stores one eigenvector.

    Raises:
        Exception: if `A.n` is neither 2 nor 3.
    """
    dt = impl.get_runtime().default_fp if dt is None else dt
    if A.n not in (2, 3):
        raise Exception("Symmetric eigen solver only supports 2D and 3D matrices.")
    # Dispatch to the size-specialised implementation.
    eigen_solver = _sym_eig2x2 if A.n == 2 else _sym_eig3x3
    return eigen_solver(A, dt)
611
+
612
+
613
@func
def _gauss_elimination_2x2(Ab, dt):
    """Solve a 2x2 linear system given as a 2x3 augmented matrix ``[A | b]``.

    Uses Gaussian elimination with partial pivoting followed by back
    substitution. Works directly on the entries of `Ab`.

    Args:
        Ab: augmented matrix with 2 rows and 3 columns (column 2 holds `b`).
        dt: data type of the returned solution vector.

    Returns:
        The 2-element solution vector `x` of ``A x = b``.
    """
    # Partial pivoting: bring the row with the larger |leading entry| to the top.
    if ops.abs(Ab[0, 0]) < ops.abs(Ab[1, 0]):
        for col in static(range(3)):
            Ab[0, col], Ab[1, col] = Ab[1, col], Ab[0, col]
    assert Ab[0, 0] != 0.0, "Matrix is singular in linear solve."
    # Eliminate the entry below the first pivot.
    scale = Ab[1, 0] / Ab[0, 0]
    Ab[1, 0] = 0.0
    for k in static(range(1, 3)):
        Ab[1, k] -= Ab[0, k] * scale
    # The second pivot can vanish even when the first does not (e.g. a rank-1
    # matrix), so check it before dividing by it below.
    assert Ab[1, 1] != 0.0, "Matrix is singular in linear solve."
    # Back substitution
    x = Vector.zero(dt, 2)
    x[1] = Ab[1, 2] / Ab[1, 1]
    x[0] = (Ab[0, 2] - Ab[0, 1] * x[1]) / Ab[0, 0]
    return x
629
+
630
+
631
@func
def _gauss_elimination_3x3(Ab, dt):
    """Solve a 3x3 linear system given as a 3x4 augmented matrix ``[A | b]``.

    Performs Gaussian elimination with partial pivoting on `Ab`, then back
    substitution.

    Args:
        Ab: augmented matrix; rows 0..2 and columns 0..3 are accessed, with
            column 3 holding the right-hand side.
        dt: data type of the returned solution vector.

    Returns:
        The 3-element solution vector `x` of ``A x = b``.
    """
    for i in static(range(3)):
        # Find the row at or below i with the largest |entry| in column i.
        max_row = i
        max_v = ops.abs(Ab[i, i])
        for j in static(range(i + 1, 3)):
            if ops.abs(Ab[j, i]) > max_v:
                max_row = j
                max_v = ops.abs(Ab[j, i])
        assert max_v != 0.0, "Matrix is singular in linear solve."
        if i != max_row:
            # max_row is assigned under a data-dependent condition, so instead of
            # indexing with it the swap branches on its possible values using
            # literal row indices. Here max_row > i >= 0, so it is either 1 or 2.
            if max_row == 1:
                for col in static(range(4)):
                    Ab[i, col], Ab[1, col] = Ab[1, col], Ab[i, col]
            else:
                for col in static(range(4)):
                    Ab[i, col], Ab[2, col] = Ab[2, col], Ab[i, col]
        # After the swap |Ab[i, i]| == max_v, so this repeats the check above.
        assert Ab[i, i] != 0.0, "Matrix is singular in linear solve."
        # Zero out column i below the pivot and update the rest of each row.
        for j in static(range(i + 1, 3)):
            scale = Ab[j, i] / Ab[i, i]
            Ab[j, i] = 0.0
            for k in static(range(i + 1, 4)):
                Ab[j, k] -= Ab[i, k] * scale
    # Back substitution
    x = Vector.zero(dt, 3)
    for i in static(range(2, -1, -1)):
        x[i] = Ab[i, 3]
        for k in static(range(i + 1, 3)):
            x[i] -= Ab[i, k] * x[k]
        x[i] = x[i] / Ab[i, i]
    return x
662
+
663
+
664
@func
def _combine(A, b, dt):
    """Build the augmented matrix ``[A | b]`` for the Gaussian-elimination solvers.

    Args:
        A: matrix whose leading ``A.n x A.n`` entries are copied.
        b: vector whose first ``A.n`` entries form the last column.
        dt: element data type of the augmented matrix.

    Returns:
        An ``A.n x (A.n + 1)`` matrix: columns ``0..A.n-1`` come from `A` and
        column ``A.n`` comes from `b`.
    """
    size = static(A.n)
    augmented = Matrix.zero(dt, size, size + 1)
    for row in static(range(size)):
        for col in static(range(size)):
            augmented[row, col] = A[row, col]
        # The extra last column holds the right-hand side.
        augmented[row, size] = b[row]
    return augmented
674
+
675
+
676
def solve(A, b, dt=None):
    """Solve the linear system ``A x = b`` by Gaussian elimination with partial pivoting.

    Args:
        A (ti.Matrix(n, n)): input nxn coefficient matrix `A`; n must be 2 or 3.
        b (ti.Vector(n, 1)): input nx1 right-hand-side vector `b`.
        dt (DataType): data type of the augmented matrix and of the solution.
            When omitted, the runtime's default floating-point type is used.

    Returns:
        x (ti.Vector(n, 1)): the solution of Ax=b.

    Raises:
        AssertionError: if `A` is not square, is not 2x2 or 3x3, or its
            column count does not match the length of `b`.
    """
    # Explicit raises instead of `assert` statements so the validation still
    # runs under `python -O`; AssertionError is kept for backward compatibility
    # with callers that already catch it.
    if A.n != A.m:
        raise AssertionError("Only square matrix is supported")
    if not 2 <= A.n <= 3:
        raise AssertionError("Only 2D and 3D matrices are supported")
    if A.m != b.n:
        raise AssertionError("Matrix and Vector dimension mismatch")
    if dt is None:
        dt = impl.get_runtime().default_fp
    Ab = _combine(A, b, dt)
    if A.n == 2:
        return _gauss_elimination_2x2(Ab, dt)
    # The size check above guarantees A.n == 3 here.
    return _gauss_elimination_3x3(Ab, dt)
698
+
699
+
700
@func
def field_fill_taichi_scope(F: template(), val: template()):
    """Write `val` into `F` at every index yielded by ``grouped(F)``, from Taichi scope."""
    for index in grouped(F):
        F[index] = val
704
+
705
+
706
# Public names exported by `from <module> import *`.
__all__ = [
    "randn",
    "polar_decompose",
    "eig",
    "sym_eig",
    "svd",
    "solve",
]