gstaichi 0.1.18.dev1__cp310-cp310-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv.h +2568 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv.hpp +2579 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-0.1.18.dev1.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.dist-info/METADATA +108 -0
- gstaichi-0.1.18.dev1.dist-info/RECORD +219 -0
- gstaichi-0.1.18.dev1.dist-info/WHEEL +5 -0
- gstaichi-0.1.18.dev1.dist-info/entry_points.txt +2 -0
- gstaichi-0.1.18.dev1.dist-info/licenses/LICENSE +201 -0
- gstaichi-0.1.18.dev1.dist-info/top_level.txt +1 -0
- taichi/__init__.py +44 -0
- taichi/__main__.py +5 -0
- taichi/_funcs.py +706 -0
- taichi/_kernels.py +420 -0
- taichi/_lib/__init__.py +3 -0
- taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +1401 -0
- taichi/_lib/c_api/include/taichi/taichi.h +29 -0
- taichi/_lib/c_api/include/taichi/taichi_core.h +1111 -0
- taichi/_lib/c_api/include/taichi/taichi_cpu.h +29 -0
- taichi/_lib/c_api/include/taichi/taichi_metal.h +72 -0
- taichi/_lib/c_api/include/taichi/taichi_platform.h +55 -0
- taichi/_lib/c_api/include/taichi/taichi_unity.h +64 -0
- taichi/_lib/c_api/include/taichi/taichi_vulkan.h +151 -0
- taichi/_lib/c_api/lib/libtaichi_c_api.dylib +0 -0
- taichi/_lib/c_api/runtime/libMoltenVK.dylib +0 -0
- taichi/_lib/c_api/runtime/runtime_arm64.bc +0 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +29 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +65 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +121 -0
- taichi/_lib/core/__init__.py +0 -0
- taichi/_lib/core/py.typed +0 -0
- taichi/_lib/core/taichi_python.cpython-310-darwin.so +0 -0
- taichi/_lib/core/taichi_python.pyi +3077 -0
- taichi/_lib/runtime/libMoltenVK.dylib +0 -0
- taichi/_lib/runtime/runtime_arm64.bc +0 -0
- taichi/_lib/utils.py +249 -0
- taichi/_logging.py +131 -0
- taichi/_main.py +552 -0
- taichi/_snode/__init__.py +5 -0
- taichi/_snode/fields_builder.py +189 -0
- taichi/_snode/snode_tree.py +34 -0
- taichi/_ti_module/__init__.py +3 -0
- taichi/_ti_module/cppgen.py +309 -0
- taichi/_ti_module/module.py +145 -0
- taichi/_version.py +1 -0
- taichi/_version_check.py +100 -0
- taichi/ad/__init__.py +3 -0
- taichi/ad/_ad.py +530 -0
- taichi/algorithms/__init__.py +3 -0
- taichi/algorithms/_algorithms.py +117 -0
- taichi/aot/__init__.py +12 -0
- taichi/aot/_export.py +28 -0
- taichi/aot/conventions/__init__.py +3 -0
- taichi/aot/conventions/gfxruntime140/__init__.py +38 -0
- taichi/aot/conventions/gfxruntime140/dr.py +244 -0
- taichi/aot/conventions/gfxruntime140/sr.py +613 -0
- taichi/aot/module.py +253 -0
- taichi/aot/utils.py +151 -0
- taichi/assets/.git +1 -0
- taichi/assets/Go-Regular.ttf +0 -0
- taichi/assets/static/imgs/ti_gallery.png +0 -0
- taichi/examples/minimal.py +28 -0
- taichi/experimental.py +16 -0
- taichi/graph/__init__.py +3 -0
- taichi/graph/_graph.py +292 -0
- taichi/lang/__init__.py +50 -0
- taichi/lang/_ndarray.py +348 -0
- taichi/lang/_ndrange.py +152 -0
- taichi/lang/_texture.py +172 -0
- taichi/lang/_wrap_inspect.py +189 -0
- taichi/lang/any_array.py +99 -0
- taichi/lang/argpack.py +411 -0
- taichi/lang/ast/__init__.py +5 -0
- taichi/lang/ast/ast_transformer.py +1806 -0
- taichi/lang/ast/ast_transformer_utils.py +328 -0
- taichi/lang/ast/checkers.py +106 -0
- taichi/lang/ast/symbol_resolver.py +57 -0
- taichi/lang/ast/transform.py +9 -0
- taichi/lang/common_ops.py +310 -0
- taichi/lang/exception.py +80 -0
- taichi/lang/expr.py +180 -0
- taichi/lang/field.py +464 -0
- taichi/lang/impl.py +1246 -0
- taichi/lang/kernel_arguments.py +157 -0
- taichi/lang/kernel_impl.py +1415 -0
- taichi/lang/matrix.py +1877 -0
- taichi/lang/matrix_ops.py +341 -0
- taichi/lang/matrix_ops_utils.py +190 -0
- taichi/lang/mesh.py +687 -0
- taichi/lang/misc.py +807 -0
- taichi/lang/ops.py +1489 -0
- taichi/lang/runtime_ops.py +13 -0
- taichi/lang/shell.py +35 -0
- taichi/lang/simt/__init__.py +5 -0
- taichi/lang/simt/block.py +94 -0
- taichi/lang/simt/grid.py +7 -0
- taichi/lang/simt/subgroup.py +191 -0
- taichi/lang/simt/warp.py +96 -0
- taichi/lang/snode.py +487 -0
- taichi/lang/source_builder.py +150 -0
- taichi/lang/struct.py +855 -0
- taichi/lang/util.py +381 -0
- taichi/linalg/__init__.py +8 -0
- taichi/linalg/matrixfree_cg.py +310 -0
- taichi/linalg/sparse_cg.py +59 -0
- taichi/linalg/sparse_matrix.py +303 -0
- taichi/linalg/sparse_solver.py +123 -0
- taichi/math/__init__.py +11 -0
- taichi/math/_complex.py +204 -0
- taichi/math/mathimpl.py +886 -0
- taichi/profiler/__init__.py +6 -0
- taichi/profiler/kernel_metrics.py +260 -0
- taichi/profiler/kernel_profiler.py +592 -0
- taichi/profiler/memory_profiler.py +15 -0
- taichi/profiler/scoped_profiler.py +36 -0
- taichi/shaders/Circles_vk.frag +29 -0
- taichi/shaders/Circles_vk.vert +45 -0
- taichi/shaders/Circles_vk_frag.spv +0 -0
- taichi/shaders/Circles_vk_vert.spv +0 -0
- taichi/shaders/Lines_vk.frag +9 -0
- taichi/shaders/Lines_vk.vert +11 -0
- taichi/shaders/Lines_vk_frag.spv +0 -0
- taichi/shaders/Lines_vk_vert.spv +0 -0
- taichi/shaders/Mesh_vk.frag +71 -0
- taichi/shaders/Mesh_vk.vert +68 -0
- taichi/shaders/Mesh_vk_frag.spv +0 -0
- taichi/shaders/Mesh_vk_vert.spv +0 -0
- taichi/shaders/Particles_vk.frag +95 -0
- taichi/shaders/Particles_vk.vert +73 -0
- taichi/shaders/Particles_vk_frag.spv +0 -0
- taichi/shaders/Particles_vk_vert.spv +0 -0
- taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
- taichi/shaders/SceneLines_vk.frag +9 -0
- taichi/shaders/SceneLines_vk.vert +12 -0
- taichi/shaders/SceneLines_vk_frag.spv +0 -0
- taichi/shaders/SceneLines_vk_vert.spv +0 -0
- taichi/shaders/SetImage_vk.frag +21 -0
- taichi/shaders/SetImage_vk.vert +15 -0
- taichi/shaders/SetImage_vk_frag.spv +0 -0
- taichi/shaders/SetImage_vk_vert.spv +0 -0
- taichi/shaders/Triangles_vk.frag +16 -0
- taichi/shaders/Triangles_vk.vert +29 -0
- taichi/shaders/Triangles_vk_frag.spv +0 -0
- taichi/shaders/Triangles_vk_vert.spv +0 -0
- taichi/shaders/lines2quad_vk_comp.spv +0 -0
- taichi/sparse/__init__.py +3 -0
- taichi/sparse/_sparse_grid.py +77 -0
- taichi/tools/__init__.py +12 -0
- taichi/tools/diagnose.py +124 -0
- taichi/tools/np2ply.py +364 -0
- taichi/tools/vtk.py +38 -0
- taichi/types/__init__.py +19 -0
- taichi/types/annotations.py +47 -0
- taichi/types/compound_types.py +90 -0
- taichi/types/enums.py +49 -0
- taichi/types/ndarray_type.py +147 -0
- taichi/types/primitive_types.py +203 -0
- taichi/types/quant.py +88 -0
- taichi/types/texture_type.py +85 -0
- taichi/types/utils.py +13 -0
taichi/_funcs.py
ADDED
@@ -0,0 +1,706 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import math
|
4
|
+
|
5
|
+
from taichi.lang import impl, ops
|
6
|
+
from taichi.lang.impl import get_runtime, grouped, static
|
7
|
+
from taichi.lang.kernel_impl import func
|
8
|
+
from taichi.lang.matrix import Matrix, Vector
|
9
|
+
from taichi.types import f32, f64
|
10
|
+
from taichi.types.annotations import template
|
11
|
+
|
12
|
+
|
13
|
+
@func
def _randn(dt):
    """
    Draw one sample from the standard normal distribution N(0, 1).

    Implemented with the Box-Muller transform applied to two uniform draws.
    Only ``f32`` and ``f64`` are accepted for ``dt``.
    """
    assert dt == f32 or dt == f64
    # Presumably ops.random samples [0, 1); using 1 - U keeps the argument of
    # log away from zero under that assumption.
    uniform_a = ops.cast(1.0, dt) - ops.random(dt)
    uniform_b = ops.random(dt)
    radius = ops.sqrt(-2 * ops.log(uniform_a))
    return radius * ops.cos(math.tau * uniform_b)
|
26
|
+
|
27
|
+
|
28
|
+
def randn(dt=None):
    """Sample a float from the univariate standard normal distribution.

    The distribution is Gaussian with mean 0 and variance 1; the sample is
    produced with the Box-Muller transformation. Must be called in Taichi scope.

    Args:
        dt (DataType): Data type of the required random number. Defaults to `None`,
            in which case the runtime's default floating-point type is used.

    Returns:
        The generated random float.

    Example::

        >>> @ti.kernel
        >>> def main():
        >>>     print(ti.randn())
        >>>
        >>> main()
        -0.463608
    """
    target_dtype = impl.get_runtime().default_fp if dt is None else dt
    return _randn(target_dtype)
|
52
|
+
|
53
|
+
|
54
|
+
@func
def _polar_decompose2d(A, dt):
    """Polar decomposition A = U P of a 2x2 matrix.

    See https://en.wikipedia.org/wiki/Polar_decomposition.

    Args:
        A (ti.Matrix(2, 2)): the 2x2 matrix to decompose.
        dt (DataType): data type of the result, typically ti.f32 or ti.f64.

    Returns:
        Tuple ``(U, P)``: ``U`` is a 2x2 orthogonal matrix and ``P`` is a 2x2
        positive (semi-)definite matrix. For the zero matrix the pair
        ``(I, A)`` is returned.
    """
    zero = ops.cast(0.0, dt)
    U = Matrix.identity(dt, 2)
    P = ops.cast(A, dt)
    a, b = A[0, 0], A[0, 1]
    c, d = A[1, 0], A[1, 1]
    if a != zero or b != zero or c != zero or d != zero:
        det_a = a * d - c * b
        # B is proportional to U: a rotation-like matrix when det(A) >= 0,
        # and the reflection-like variant when det(A) < 0.
        B = Matrix([[a + d, b - c], [c - b, d + a]], dt)
        if det_a < zero:
            B = Matrix([[a - d, b + c], [c + b, d - a]], dt)
        # det(B) is non-zero whenever A is not the zero matrix.
        scale = ops.cast(1.0, dt) / ops.sqrt(abs(B[0, 0] * B[1, 1] - B[1, 0] * B[0, 1]))
        U = B * scale
        P = (A.transpose() @ A + abs(det_a) * Matrix.identity(dt, 2)) * scale
    return U, P
|
99
|
+
|
100
|
+
|
101
|
+
@func
def _polar_decompose3d(A, dt):
    """Polar decomposition A = U P of a 3x3 matrix, derived from its SVD.

    See https://en.wikipedia.org/wiki/Polar_decomposition.

    Args:
        A (ti.Matrix(3, 3)): the 3x3 matrix to decompose.
        dt (DataType): data type of the elements, typically ti.f32 or ti.f64.

    Returns:
        Tuple ``(U, P)`` with ``U = U_svd @ V^T`` and ``P = V @ Sigma @ V^T``,
        where ``A = U_svd @ Sigma @ V^T`` is the singular value decomposition.
    """
    U, sigma, V = _svd3d(A, dt)
    V_t = V.transpose()
    rotation = U @ V_t
    stretch = V @ sigma @ V_t
    return rotation, stretch
|
116
|
+
|
117
|
+
|
118
|
+
# https://www.seas.upenn.edu/~cffjiang/research/svd/svd.pdf
@func
def _svd2d(A, dt):
    """Singular value decomposition A = U S V^T of a 2x2 matrix.

    See https://en.wikipedia.org/wiki/Singular_value_decomposition.
    A polar decomposition A = R S_sym is computed first; the symmetric factor
    is then diagonalized with a single Givens rotation.

    Args:
        A (ti.Matrix(2, 2)): the 2x2 matrix to decompose.
        dt (DataType): data type of the elements, typically ti.f32 or ti.f64.

    Returns:
        Tuple ``(U, S, V)``; ``S`` is diagonal with ``S[0, 0] >= S[1, 1]``.
    """
    R, S = _polar_decompose2d(A, dt)
    cos_t, sin_t = ops.cast(0.0, dt), ops.cast(0.0, dt)
    sig1, sig2 = ops.cast(0.0, dt), ops.cast(0.0, dt)
    if abs(S[0, 1]) < 1e-5:
        # Already (numerically) diagonal: identity rotation.
        cos_t, sin_t = 1, 0
        sig1, sig2 = S[0, 0], S[1, 1]
    else:
        half_diff = ops.cast(0.5, dt) * (S[0, 0] - S[1, 1])
        hyp = ops.sqrt(half_diff**2 + S[0, 1] ** 2)
        tan_t = ops.cast(0.0, dt)
        # Pick the sign that avoids cancellation in the denominator.
        if half_diff > 0:
            tan_t = S[0, 1] / (half_diff + hyp)
        else:
            tan_t = S[0, 1] / (half_diff - hyp)
        cos_t = 1 / ops.sqrt(tan_t**2 + 1)
        sin_t = -tan_t * cos_t
        sig1 = cos_t**2 * S[0, 0] - 2 * cos_t * sin_t * S[0, 1] + sin_t**2 * S[1, 1]
        sig2 = sin_t**2 * S[0, 0] + 2 * cos_t * sin_t * S[0, 1] + cos_t**2 * S[1, 1]
    V = Matrix.zero(dt, 2, 2)
    if sig1 < sig2:
        # Order singular values descending and permute V accordingly.
        swap_tmp = sig1
        sig1 = sig2
        sig2 = swap_tmp
        V = Matrix([[-sin_t, cos_t], [-cos_t, -sin_t]], dt=dt)
    else:
        V = Matrix([[cos_t, sin_t], [-sin_t, cos_t]], dt=dt)
    U = R @ V
    off_diag = ops.cast(0, dt)
    return U, Matrix([[sig1, off_diag], [off_diag, sig2]], dt=dt), V
|
160
|
+
|
161
|
+
|
162
|
+
def _svd3d(A, dt, iters=None):
    """Singular value decomposition A = U S V^T of a 3x3 matrix.

    See https://en.wikipedia.org/wiki/Singular_value_decomposition.
    The work is delegated to the AST builder's ``sifakis_svd_f32`` /
    ``sifakis_svd_f64`` intrinsics.

    Args:
        A (ti.Matrix(3, 3)): the 3x3 matrix to decompose.
        dt (DataType): ti.f32 or ti.f64 (asserted).
        iters (int): iteration count controlling precision; defaults to 5 for
            f32 and 8 for f64.

    Returns:
        Tuple ``(U, S, V)`` of 3x3 matrices; ``S`` is diagonal.
    """
    assert A.n == 3 and A.m == 3
    assert dt in [f32, f64]
    if iters is None:
        iters = 5 if dt == f32 else 8
    builder = get_runtime().compiling_callable.ast_builder()
    sifakis_svd = builder.sifakis_svd_f32 if dt == f32 else builder.sifakis_svd_f64
    rets = sifakis_svd(A.ptr, iters)
    assert len(rets) == 21
    # Layout: 9 entries of U (row-major), 9 entries of V (row-major),
    # then the 3 diagonal entries of S.
    u_flat, v_flat, sigma_diag = rets[:9], rets[9:18], rets[18:]

    @func
    def assemble_factors():
        U = Matrix.zero(dt, 3, 3)
        V = Matrix.zero(dt, 3, 3)
        sigma = Matrix.zero(dt, 3, 3)
        for row in static(range(3)):
            sigma[row, row] = sigma_diag[row]
            for col in static(range(3)):
                U[row, col] = u_flat[row * 3 + col]
                V[row, col] = v_flat[row * 3 + col]
        return U, sigma, V

    return assemble_factors()
|
204
|
+
|
205
|
+
|
206
|
+
@func
def _eig2x2(A, dt):
    """Compute the eigenvalues and right eigenvectors (Av=lambda v) of a 2x2 real matrix.

    Mathematical concept refers to https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix.

    Args:
        A (ti.Matrix(2, 2)): input 2x2 matrix `A`.
        dt (DataType): data type of elements in matrix `A`, typically accepts ti.f32 or ti.f64.

    Returns:
        eigenvalues (ti.Matrix(2, 2)): The eigenvalues in complex form. Each row stores one eigenvalue. The first number of the eigenvalue represents the real part and the second number represents the imaginary part.
        eigenvectors: (ti.Matrix(4, 2)): The eigenvectors in complex form. Each column stores one eigenvector. Each eigenvector consists of 2 entries, each of which is represented by two numbers for its real part and imaginary part.
    """
    tr = A.trace()
    det = A.determinant()
    gap = tr**2 - 4 * det
    lambda1 = Vector.zero(dt, 2)
    lambda2 = Vector.zero(dt, 2)
    v1 = Vector.zero(dt, 4)
    v2 = Vector.zero(dt, 4)
    # gap == 0 (repeated real eigenvalue) is handled by the real branch so that
    # scalar matrices A = cI reach the identity special case below instead of
    # normalizing a zero vector (0/0 -> NaN) in the complex branch.
    if gap >= 0:
        lambda1 = Vector([tr + ops.sqrt(gap), 0.0], dt=dt) * 0.5
        lambda2 = Vector([tr - ops.sqrt(gap), 0.0], dt=dt) * 0.5
        A1 = A - lambda1[0] * Matrix.identity(dt, 2)
        A2 = A - lambda2[0] * Matrix.identity(dt, 2)
        if all(A1 == Matrix.zero(dt, 2, 2)) and all(A2 == Matrix.zero(dt, 2, 2)):
            v1 = Vector([0.0, 0.0, 1.0, 0.0]).cast(dt)
            v2 = Vector([1.0, 0.0, 0.0, 0.0]).cast(dt)
        else:
            # By Cayley-Hamilton, (A - lambda1 I)(A - lambda2 I) = 0, so every
            # column of A2 is an eigenvector for lambda1 (or zero), and vice
            # versa. Use the first column unless it is exactly zero, which
            # happens e.g. for diagonal or upper-triangular A.
            if A2[0, 0] != 0 or A2[1, 0] != 0:
                v1 = Vector([A2[0, 0], 0.0, A2[1, 0], 0.0], dt=dt).normalized()
            else:
                v1 = Vector([A2[0, 1], 0.0, A2[1, 1], 0.0], dt=dt).normalized()
            if A1[0, 0] != 0 or A1[1, 0] != 0:
                v2 = Vector([A1[0, 0], 0.0, A1[1, 0], 0.0], dt=dt).normalized()
            else:
                v2 = Vector([A1[0, 1], 0.0, A1[1, 1], 0.0], dt=dt).normalized()
    else:
        # Complex-conjugate pair. The imaginary part of A - lambda I is a
        # non-zero multiple of I here, so its first column is never zero.
        lambda1 = Vector([tr, ops.sqrt(-gap)], dt=dt) * 0.5
        lambda2 = Vector([tr, -ops.sqrt(-gap)], dt=dt) * 0.5
        A1r = A - lambda1[0] * Matrix.identity(dt, 2)
        A1i = -lambda1[1] * Matrix.identity(dt, 2)
        A2r = A - lambda2[0] * Matrix.identity(dt, 2)
        A2i = -lambda2[1] * Matrix.identity(dt, 2)
        v1 = Vector([A2r[0, 0], A2i[0, 0], A2r[1, 0], A2i[1, 0]], dt=dt).normalized()
        v2 = Vector([A1r[0, 0], A1i[0, 0], A1r[1, 0], A1i[1, 0]], dt=dt).normalized()
    eigenvalues = Matrix.rows([lambda1, lambda2])
    eigenvectors = Matrix.cols([v1, v2])

    return eigenvalues, eigenvectors
|
251
|
+
|
252
|
+
|
253
|
+
@func
def _sym_eig2x2(A, dt):
    """Compute the eigenvalues and right eigenvectors (Av=lambda v) of a 2x2 real symmetric matrix.

    Mathematical concept refers to https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix.

    Args:
        A (ti.Matrix(2, 2)): input 2x2 symmetric matrix `A`.
        dt (DataType): data type of elements in matrix `A`, typically accepts ti.f32 or ti.f64.

    Returns:
        eigenvalues (ti.Vector(2)): The eigenvalues (larger first). Each entry stores one eigenvalue.
        eigenvectors (ti.Matrix(2, 2)): The eigenvectors. Each column stores one eigenvector.
    """
    assert all(A == A.transpose()), "A needs to be symmetric"
    tr = A.trace()
    # Discriminant tr^2 - 4 det, written as (a - d)^2 + 4 b c. For a symmetric
    # matrix (b == c) this is a sum of non-negative terms even after rounding,
    # so sqrt() below cannot return NaN for nearly scalar inputs.
    gap = (A[0, 0] - A[1, 1]) ** 2 + 4 * A[0, 1] * A[1, 0]
    lambda1 = (tr + ops.sqrt(gap)) * 0.5
    lambda2 = (tr - ops.sqrt(gap)) * 0.5
    eigenvalues = Vector([lambda1, lambda2], dt=dt)

    A1 = A - lambda1 * Matrix.identity(dt, 2)
    A2 = A - lambda2 * Matrix.identity(dt, 2)
    v1 = Vector.zero(dt, 2)
    v2 = Vector.zero(dt, 2)
    if all(A1 == Matrix.zero(dt, 2, 2)) and all(A2 == Matrix.zero(dt, 2, 2)):
        # A is a multiple of the identity: any orthonormal basis works.
        v1 = Vector([0.0, 1.0]).cast(dt)
        v2 = Vector([1.0, 0.0]).cast(dt)
    else:
        # By Cayley-Hamilton, columns of A2 are eigenvectors for lambda1 (or
        # zero) and columns of A1 for lambda2. Fall back to the second column
        # when the first is exactly zero (e.g. diagonal A), which would
        # otherwise normalize a zero vector into NaN.
        if A2[0, 0] != 0 or A2[1, 0] != 0:
            v1 = Vector([A2[0, 0], A2[1, 0]], dt=dt).normalized()
        else:
            v1 = Vector([A2[0, 1], A2[1, 1]], dt=dt).normalized()
        if A1[0, 0] != 0 or A1[1, 0] != 0:
            v2 = Vector([A1[0, 0], A1[1, 0]], dt=dt).normalized()
        else:
            v2 = Vector([A1[0, 1], A1[1, 1]], dt=dt).normalized()
    eigenvectors = Matrix.cols([v1, v2])
    return eigenvalues, eigenvectors
|
287
|
+
|
288
|
+
|
289
|
+
@func
def dsytrd3(A, Q, dt):
    """Householder reduction of a symmetric 3x3 matrix to tridiagonal form.

    Follows the ``dsytrd3`` routine of the 3x3 eigensolver collection at
    https://www.mpi-hd.mpg.de/personalhomes/globes/3x3/. Only the upper
    triangle of ``A`` is read.

    Args:
        A (ti.Matrix(3, 3)): symmetric input matrix.
        Q (ti.Matrix(3, 3)): output matrix; every entry is overwritten.
        dt (DataType): data type of the work vectors.

    Returns:
        d (ti.Vector(3)): diagonal of the tridiagonal matrix.
        e (ti.Vector(3)): off-diagonal entries in e[0] and e[1]; e[2] stays 0.
        Q (ti.Matrix(3, 3)): the accumulated Householder transformation.
    """
    # Start Q from the identity. Callers may hand in a partially filled
    # matrix (``_sym_eig3x3`` does), so every entry must be reset, not just
    # the diagonal; stale off-diagonal values would corrupt the transform.
    for i in static(range(3)):
        for j in static(range(3)):
            Q[i, j] = 0.0
        Q[i, i] = 1.0
    e = Vector([0.0, 0.0, 0.0], dt=dt)
    u = Vector([0.0, 0.0, 0.0], dt=dt)
    q = Vector([0.0, 0.0, 0.0], dt=dt)
    d = Vector([0.0, 0.0, 0.0], dt=dt)
    h = A[0, 1] ** 2 + A[0, 2] ** 2
    g = 0.0
    # Sign chosen to avoid cancellation in u[1] = A[0, 1] - g.
    if A[0, 1] > 0:
        g = -ops.sqrt(h)
    else:
        g = ops.sqrt(h)
    e[0] = g
    f = g * A[0, 1]
    u[1] = A[0, 1] - g
    u[2] = A[0, 2]
    omega = h - f
    if omega > 0.0:
        omega = 1.0 / omega
        K = 0.0
        f = A[1, 1] * u[1] + A[1, 2] * u[2]
        q[1] = omega * f  # p
        K += u[1] * f  # u* A u

        f = A[1, 2] * u[1] + A[2, 2] * u[2]
        q[2] = omega * f  # p
        K += u[2] * f  # u* A u

        K *= 0.5 * omega * omega

        q[1] = q[1] - K * u[1]
        q[2] = q[2] - K * u[2]

        d[0] = A[0, 0]
        d[1] = A[1, 1] - 2.0 * q[1] * u[1]
        d[2] = A[2, 2] - 2.0 * q[2] * u[2]

        # Accumulate the Householder reflection into Q (rows/cols 1..2).
        for j in range(1, 3):
            f = omega * u[j]
            for i in range(1, 3):
                Q[i, j] = Q[i, j] - f * u[i]

        # Calculate updated A[1, 2] and store it in e[1]
        e[1] = A[1, 2] - q[1] * u[2] - u[1] * q[2]
    else:
        # First row is already reduced: A is tridiagonal as given.
        d[0] = A[0, 0]
        d[1] = A[1, 1]
        d[2] = A[2, 2]
        e[1] = A[1, 2]
    return d, e, Q
|
342
|
+
|
343
|
+
|
344
|
+
@func
def dsyevq3(A, Q, w, dt):
    """Eigen-decompose a symmetric 3x3 matrix with the implicitly shifted QL method.

    Follows the ``dsyevq3`` routine of the 3x3 eigensolver collection at
    https://www.mpi-hd.mpg.de/personalhomes/globes/3x3/: reduce ``A`` to
    tridiagonal form with ``dsytrd3``, then apply QL sweeps until every
    off-diagonal element is negligible.

    Args:
        A (ti.Matrix(3, 3)): symmetric input matrix.
        Q (ti.Matrix(3, 3)): scratch matrix, overwritten by ``dsytrd3``.
        w: ignored on input; replaced by the tridiagonal diagonal from ``dsytrd3``.
        dt (DataType): data type of the work vectors.

    Returns:
        Q (ti.Matrix(3, 3)): eigenvectors, one per column.
        w (ti.Vector(3)): corresponding eigenvalues (not sorted).
    """
    w, e, Q = dsytrd3(A, Q, dt)
    for l in range(0, 2):
        nIter = 0
        while True:
            # Check for convergence and exit iteration loop if off-diagonal
            # element e(l) is zero. Search for the first m >= l whose e[m] is
            # negligible next to its diagonal neighbours; if there is none, m
            # must end at 2 (the last row), exactly like the reference loop
            # ``for (m = l; m <= n - 2; m++)``. Stopping at m = 1 would skip the
            # l = 1 sweep entirely and zero e[1] without deflating it.
            m = l
            while m < 2:
                g = ops.abs(w[m]) + ops.abs(w[m + 1])
                if ops.abs(e[m]) + g == g:
                    break
                m += 1
            if m == l:
                break

            nIter += 1
            assert nIter <= 30, "Timeout"

            # Calculate g = d_m - k
            g = (w[l + 1] - w[l]) / (e[l] + e[l])
            r = ops.sqrt(g * g + 1.0)
            if g > 0:
                g = w[m] - w[l] + e[l] / (g + r)
            else:
                g = w[m] - w[l] + e[l] / (g - r)

            s = c = 1.0
            p = 0.0
            i = m - 1
            while i >= l:
                f = s * e[i]
                b = c * e[i]
                # Givens rotation, computed so that the ratio stays <= 1.
                if ops.abs(f) > ops.abs(g):
                    c = g / f
                    r = ops.sqrt(c * c + 1.0)
                    e[i + 1] = f * r
                    s = 1.0 / r
                    c *= s
                else:
                    s = f / g
                    r = ops.sqrt(s * s + 1.0)
                    e[i + 1] = g * r
                    c = 1.0 / r
                    s *= c

                g = w[i + 1] - p
                r = (w[i] - g) * s + 2.0 * c * b
                p = s * r
                w[i + 1] = g + p
                g = c * r - b

                # Apply the rotation to the eigenvector columns i and i + 1.
                for k in range(0, 3):
                    t = Q[k, i + 1]
                    Q[k, i + 1] = s * Q[k, i] + c * t
                    Q[k, i] = c * Q[k, i] - s * t

                i -= 1
            w[l] -= p
            e[l] = g
            e[m] = 0.0
    return Q, w
|
407
|
+
|
408
|
+
|
409
|
+
@func
def _sym_eig3x3(A, dt):
    """Compute the eigenvalues and right eigenvectors (Av=lambda v) of a 3x3 real symmetric matrix using Cardano's method.

    Mathematical concept refers to https://www.mpi-hd.mpg.de/personalhomes/globes/3x3/.

    The eigenvalues come from a trigonometric closed-form solution of the
    characteristic cubic. The first two eigenvectors are cross products of the
    first two rows of (A - lambda I); the third is the cross product of those
    two. If either of the first two vectors is too short to normalize
    reliably, the iterative QL solver ``dsyevq3`` is used instead.

    Args:
        A (ti.Matrix(3, 3)): input 3x3 symmetric matrix `A`.
        dt (DataType): data type of elements in matrix `A`, typically accepts ti.f32 or ti.f64.

    Returns:
        eigenvalues (ti.Vector(3)): The eigenvalues, sorted in ascending order.
        eigenvectors (ti.Matrix(3, 3)): The eigenvectors. Each column stores one eigenvector,
            in the same order as `eigenvalues`.
    """
    assert all(A == A.transpose()), "A needs to be symmetric"
    M_SQRT3 = 1.73205080756887729352744634151
    DBL_EPSILON = 2.2204460492503131e-16
    # Coefficients of the characteristic polynomial
    #   lambda^3 - m * lambda^2 + c1 * lambda + c0   (c0 == -det(A)).
    m = A.trace()
    dd = A[0, 1] * A[0, 1]
    ee = A[1, 2] * A[1, 2]
    ff = A[0, 2] * A[0, 2]
    c1 = A[0, 0] * A[1, 1] + A[0, 0] * A[2, 2] + A[1, 1] * A[2, 2] - (dd + ee + ff)
    c0 = A[2, 2] * dd + A[0, 0] * ee + A[1, 1] * ff - A[0, 0] * A[1, 1] * A[2, 2] - 2.0 * A[0, 2] * A[0, 1] * A[1, 2]

    # Trigonometric solution of the cubic.
    p = m * m - 3.0 * c1
    q = m * (p - 1.5 * c1) - 13.5 * c0
    sqrt_p = ops.sqrt(ops.abs(p))
    phi = 27.0 * (0.25 * c1 * c1 * (p - c1) + c0 * (q + 6.75 * c0))
    phi = (1.0 / 3.0) * ops.atan2(ops.sqrt(ops.abs(phi)), q)

    c = sqrt_p * ops.cos(phi)
    s = (1.0 / M_SQRT3) * sqrt_p * ops.sin(phi)
    eigenvalues = Vector([0.0, 0.0, 0.0], dt=dt)
    eigenvalues_final = Vector([0.0, 0.0, 0.0], dt=dt)
    eigenvalues[1] = (1.0 / 3.0) * (m - c)
    eigenvalues[2] = eigenvalues[1] + s
    eigenvalues[0] = eigenvalues[1] + c
    eigenvalues[1] = eigenvalues[1] - s

    # t = max |eigenvalue|; u = t if t < 1 else t^2. Eigenvector norms at or
    # below `error` are treated as too inaccurate to normalize.
    t = ops.abs(eigenvalues[0])
    u = ops.abs(eigenvalues[1])
    if u > t:
        t = u
    u = ops.abs(eigenvalues[2])
    if u > t:
        t = u
    if t < 1.0:
        u = t
    else:
        u = t * t
    error = 256.0 * DBL_EPSILON * u * u
    Q = Matrix.zero(dt, 3, 3)
    Q_final = Matrix.zero(dt, 3, 3)
    # Eigenvalue-independent parts of the cross product, reused for the first
    # two eigenvectors (Q[2, 1] temporarily holds A[0, 1]^2).
    Q[0, 1] = A[0, 1] * A[1, 2] - A[0, 2] * A[1, 1]
    Q[1, 1] = A[0, 2] * A[0, 1] - A[1, 2] * A[0, 0]
    Q[2, 1] = A[0, 1] * A[0, 1]

    # First eigenvector: row0(A - lambda0 I) x row1(A - lambda0 I).
    Q[0, 0] = Q[0, 1] + A[0, 2] * eigenvalues[0]
    Q[1, 0] = Q[1, 1] + A[1, 2] * eigenvalues[0]
    Q[2, 0] = (A[0, 0] - eigenvalues[0]) * (A[1, 1] - eigenvalues[0]) - Q[2, 1]
    norm = Q[0, 0] * Q[0, 0] + Q[1, 0] * Q[1, 0] + Q[2, 0] * Q[2, 0]
    # early_ret == 1 means the result comes from the dsyevq3 fallback.
    early_ret = 0
    if norm <= error:
        # NOTE(review): Q already holds partial data here, so correctness relies
        # on dsyevq3/dsytrd3 resetting every entry of Q (not only its diagonal) — confirm.
        Q_final, eigenvalues_final = dsyevq3(A, Q, eigenvalues, dt)
        early_ret = 1
    else:
        norm = ops.sqrt(1.0 / norm)
        Q[0, 0] *= norm
        Q[1, 0] *= norm
        Q[2, 0] *= norm

    if not early_ret:
        # Second eigenvector: same construction with lambda1.
        Q[0, 1] = Q[0, 1] + A[0, 2] * eigenvalues[1]
        Q[1, 1] = Q[1, 1] + A[1, 2] * eigenvalues[1]
        Q[2, 1] = (A[0, 0] - eigenvalues[1]) * (A[1, 1] - eigenvalues[1]) - Q[2, 1]
        norm = Q[0, 1] * Q[0, 1] + Q[1, 1] * Q[1, 1] + Q[2, 1] * Q[2, 1]
        if norm <= error:
            Q_final, eigenvalues_final = dsyevq3(A, Q, eigenvalues, dt)
            early_ret = 1
        else:
            norm = ops.sqrt(1.0 / norm)
            Q[0, 1] *= norm
            Q[1, 1] *= norm
            Q[2, 1] *= norm

    # Third eigenvector: cross product of the first two columns. (Discarded
    # below when the fallback path was taken.)
    Q[0, 2] = Q[1, 0] * Q[2, 1] - Q[2, 0] * Q[1, 1]
    Q[1, 2] = Q[2, 0] * Q[0, 1] - Q[0, 0] * Q[2, 1]
    Q[2, 2] = Q[0, 0] * Q[1, 1] - Q[1, 0] * Q[0, 1]

    if early_ret:
        Q = Q_final
        eigenvalues = eigenvalues_final

    # Sort eigenvalues ascending (three compare-and-swap steps), permuting the
    # eigenvector columns in lockstep.
    if eigenvalues[1] < eigenvalues[0]:
        tmp = eigenvalues[0]
        eigenvalues[0] = eigenvalues[1]
        eigenvalues[1] = tmp
        tmp2 = Q[:, 0]
        Q[:, 0] = Q[:, 1]
        Q[:, 1] = tmp2

    if eigenvalues[2] < eigenvalues[0]:
        tmp = eigenvalues[0]
        eigenvalues[0] = eigenvalues[2]
        eigenvalues[2] = tmp
        tmp2 = Q[:, 0]
        Q[:, 0] = Q[:, 2]
        Q[:, 2] = tmp2

    if eigenvalues[2] < eigenvalues[1]:
        tmp = eigenvalues[1]
        eigenvalues[1] = eigenvalues[2]
        eigenvalues[2] = tmp
        tmp2 = Q[:, 1]
        Q[:, 1] = Q[:, 2]
        Q[:, 2] = tmp2

    return eigenvalues, Q
|
527
|
+
|
528
|
+
|
529
|
+
def polar_decompose(A, dt=None):
    """Perform polar decomposition (A=UP) for a 2x2 or 3x3 matrix.

    Mathematical concept refers to https://en.wikipedia.org/wiki/Polar_decomposition.

    Args:
        A (ti.Matrix(n, n)): input nxn matrix `A`, with n equal to 2 or 3.
        dt (DataType): data type of elements in matrix `A`, typically ti.f32 or ti.f64.
            Defaults to the runtime's default floating-point type.

    Returns:
        Decomposed nxn matrices `U` and `P`.

    Raises:
        Exception: if ``A.n`` is neither 2 nor 3.
    """
    dt = impl.get_runtime().default_fp if dt is None else dt
    size = A.n
    if size == 2:
        return _polar_decompose2d(A, dt)
    elif size == 3:
        return _polar_decompose3d(A, dt)
    else:
        raise Exception("Polar decomposition only supports 2D and 3D matrices.")
|
548
|
+
|
549
|
+
|
550
|
+
def svd(A, dt=None):
    """Perform the singular value decomposition (A=USV^T) of a 2x2 or 3x3 matrix.

    See https://en.wikipedia.org/wiki/Singular_value_decomposition for the
    underlying mathematics.

    Args:
        A (ti.Matrix(n, n)): input nxn matrix `A`. Only n == 2 and n == 3 are
            supported; the dispatch is based on `A.n`.
        dt (DataType): data type of the elements of `A`, typically ti.f32 or
            ti.f64. When omitted, `impl.get_runtime().default_fp` is used.

    Returns:
        The decomposed nxn matrices `U`, `S` and `V`.

    Raises:
        Exception: if `A.n` is neither 2 nor 3.
    """
    dt = impl.get_runtime().default_fp if dt is None else dt

    size = A.n
    if size == 2:
        return _svd2d(A, dt)
    elif size == 3:
        return _svd3d(A, dt)
    else:
        raise Exception("SVD only supports 2D and 3D matrices.")
|
569
|
+
|
570
|
+
|
571
|
+
def eig(A, dt=None):
    """Compute the eigenvalues and right eigenvectors of a real 2x2 matrix.

    See https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix for the
    underlying mathematics.

    Args:
        A (ti.Matrix(n, n)): matrix whose eigenvalues and right eigenvectors
            are computed. Only n == 2 is supported (checked via `A.n`).
        dt (DataType): data type of the eigenvalues and right eigenvectors.
            When omitted, `impl.get_runtime().default_fp` is used.

    Returns:
        eigenvalues (ti.Matrix(n, 2)): the eigenvalues in complex form. Each
            row stores one eigenvalue: the first number is the real part and
            the second number is the imaginary part.
        eigenvectors (ti.Matrix(n*2, n)): the eigenvectors in complex form.
            Each column stores one eigenvector of n entries, and each entry is
            represented by two numbers, its real part and its imaginary part.

    Raises:
        Exception: if `A.n` is not 2.
    """
    dt = impl.get_runtime().default_fp if dt is None else dt

    # Guard clause: only the 2x2 solver exists.
    if A.n != 2:
        raise Exception("Eigen solver only supports 2D matrices.")
    return _eig2x2(A, dt)
|
589
|
+
|
590
|
+
|
591
|
+
def sym_eig(A, dt=None):
    """Compute the eigenvalues and right eigenvectors of a real symmetric matrix.

    See https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix for the
    underlying mathematics.

    Args:
        A (ti.Matrix(n, n)): symmetric matrix whose eigenvalues and right
            eigenvectors are computed. Only n == 2 and n == 3 are supported;
            the dispatch is based on `A.n`.
        dt (DataType): data type of the eigenvalues and right eigenvectors.
            When omitted, `impl.get_runtime().default_fp` is used.

    Returns:
        eigenvalues (ti.Vector(n)): the eigenvalues, one per entry.
        eigenvectors (ti.Matrix(n, n)): the eigenvectors, one per column.

    Raises:
        Exception: if `A.n` is neither 2 nor 3.
    """
    dt = impl.get_runtime().default_fp if dt is None else dt

    size = A.n
    if size == 2:
        return _sym_eig2x2(A, dt)
    elif size == 3:
        return _sym_eig3x3(A, dt)
    else:
        raise Exception("Symmetric eigen solver only supports 2D and 3D matrices.")
|
611
|
+
|
612
|
+
|
613
|
+
@func
def _gauss_elimination_2x2(Ab, dt):
    # Solve a 2x2 linear system given as the augmented matrix Ab = [A | b]
    # (2 rows, 3 columns; column 2 holds b) using Gaussian elimination with
    # partial pivoting. Returns the solution as a length-2 vector of type `dt`.
    # Ab is modified in place.

    # Partial pivoting: move the row with the larger |first entry| to the top.
    if ops.abs(Ab[0, 0]) < ops.abs(Ab[1, 0]):
        Ab[0, 0], Ab[1, 0] = Ab[1, 0], Ab[0, 0]
        Ab[0, 1], Ab[1, 1] = Ab[1, 1], Ab[0, 1]
        Ab[0, 2], Ab[1, 2] = Ab[1, 2], Ab[0, 2]
    assert Ab[0, 0] != 0.0, "Matrix is singular in linear solve."

    # Eliminate the first-column entry of row 1.
    scale = Ab[1, 0] / Ab[0, 0]
    Ab[1, 0] = 0.0
    for k in static(range(1, 3)):
        Ab[1, k] -= Ab[0, k] * scale

    # The second pivot must be non-zero as well. Without this check, a singular
    # matrix whose first pivot is non-zero (e.g. two proportional rows) would
    # divide by zero in the back substitution below. This mirrors the
    # per-pivot check done by the 3x3 solver.
    assert Ab[1, 1] != 0.0, "Matrix is singular in linear solve."

    x = Vector.zero(dt, 2)
    # Back substitution
    x[1] = Ab[1, 2] / Ab[1, 1]
    x[0] = (Ab[0, 2] - Ab[0, 1] * x[1]) / Ab[0, 0]
    return x
|
629
|
+
|
630
|
+
|
631
|
+
@func
def _gauss_elimination_3x3(Ab, dt):
    # Solve a 3x3 linear system given as the augmented matrix Ab = [A | b]
    # (3 rows, 4 columns; column 3 holds b). It uses Gaussian elimination with
    # partial pivoting followed by back substitution. Returns the solution as
    # a length-3 vector of type `dt`. Ab is modified in place.
    for i in static(range(3)):
        # Find the row (at or below row i) with the largest |entry| in column i.
        max_row = i
        max_v = ops.abs(Ab[i, i])
        for j in static(range(i + 1, 3)):
            if ops.abs(Ab[j, i]) > max_v:
                max_row = j
                max_v = ops.abs(Ab[j, i])
        assert max_v != 0.0, "Matrix is singular in linear solve."
        if i != max_row:
            # max_row is chosen by the non-static `if` above, while every
            # matrix index below is a literal/static value. Presumably that is
            # why each candidate pivot row gets its own branch instead of
            # indexing Ab[max_row, ...]. Here max_row > i, so it is 1 or 2.
            if max_row == 1:
                for col in static(range(4)):
                    Ab[i, col], Ab[1, col] = Ab[1, col], Ab[i, col]
            else:
                for col in static(range(4)):
                    Ab[i, col], Ab[2, col] = Ab[2, col], Ab[i, col]
        assert Ab[i, i] != 0.0, "Matrix is singular in linear solve."
        # Eliminate column i from every row below the pivot row.
        for j in static(range(i + 1, 3)):
            scale = Ab[j, i] / Ab[i, i]
            Ab[j, i] = 0.0
            for k in static(range(i + 1, 4)):
                Ab[j, k] -= Ab[i, k] * scale
    # Back substitution
    x = Vector.zero(dt, 3)
    for i in static(range(2, -1, -1)):
        x[i] = Ab[i, 3]
        for k in static(range(i + 1, 3)):
            x[i] -= Ab[i, k] * x[k]
        x[i] = x[i] / Ab[i, i]
    return x
|
662
|
+
|
663
|
+
|
664
|
+
@func
def _combine(A, b, dt):
    # Build the augmented matrix [A | b] with element type `dt`. The first
    # A.n columns are copied from A and the last column (index A.n) holds b,
    # giving an A.n x (A.n + 1) matrix that `solve` hands to the
    # Gaussian-elimination helpers.
    size = static(A.n)
    augmented = Matrix.zero(dt, size, size + 1)
    for row in static(range(size)):
        for col in static(range(size)):
            augmented[row, col] = A[row, col]
        augmented[row, size] = b[row]
    return augmented
|
674
|
+
|
675
|
+
|
676
|
+
def solve(A, b, dt=None):
    """Solve the linear system ``A x = b`` using Gaussian elimination.

    Args:
        A (ti.Matrix(n, n)): input nxn matrix `A`. Only n == 2 and n == 3 are
            supported.
        b (ti.Vector(n, 1)): input nx1 right-hand-side vector `b`.
        dt (DataType): the datatype used for `A` and `b` during the solve.
            When omitted, `impl.get_runtime().default_fp` is used.

    Returns:
        x (ti.Vector(n, 1)): the solution of ``A x = b``.

    Raises:
        AssertionError: if `A` is not square, is not 2x2 or 3x3, or its size
            does not match the length of `b`.
    """
    # These shape checks are raised explicitly instead of via `assert` so they
    # still run under `python -O`. AssertionError is kept as the exception type
    # for compatibility with callers of the previous assert-based version.
    if A.n != A.m:
        raise AssertionError("Only square matrix is supported")
    if not 2 <= A.n <= 3:
        raise AssertionError("Only 2D and 3D matrices are supported")
    if A.m != b.n:
        raise AssertionError("Matrix and Vector dimension mismatch")

    if dt is None:
        dt = impl.get_runtime().default_fp
    Ab = _combine(A, b, dt)
    if A.n == 2:
        return _gauss_elimination_2x2(Ab, dt)
    # A.n == 3 is guaranteed by the validation above.
    return _gauss_elimination_3x3(Ab, dt)
|
698
|
+
|
699
|
+
|
700
|
+
@func
def field_fill_taichi_scope(F: template(), val: template()):
    # Fill the field `F` with `val` from inside Taichi scope, visiting every
    # index of `F` produced by `grouped(F)`.
    for idx in grouped(F):
        F[idx] = val
|
704
|
+
|
705
|
+
|
706
|
+
# Names exported by `from <this module> import *`.
__all__ = [
    "randn",
    "polar_decompose",
    "eig",
    "sym_eig",
    "svd",
    "solve",
]
|