spikezoo 0.1.1__py3-none-any.whl → 0.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- spikezoo/__init__.py +13 -0
- spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/base/nets.py +34 -0
- spikezoo/archs/bsf/README.md +92 -0
- spikezoo/archs/bsf/datasets/datasets.py +328 -0
- spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
- spikezoo/archs/bsf/main.py +398 -0
- spikezoo/archs/bsf/metrics/psnr.py +22 -0
- spikezoo/archs/bsf/metrics/ssim.py +54 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/align.py +154 -0
- spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
- spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
- spikezoo/archs/bsf/models/bsf/rep.py +44 -0
- spikezoo/archs/bsf/models/get_model.py +7 -0
- spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
- spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
- spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
- spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
- spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
- spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
- spikezoo/archs/bsf/requirements.txt +9 -0
- spikezoo/archs/bsf/test.py +16 -0
- spikezoo/archs/bsf/utils.py +154 -0
- spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spikeclip/nets.py +40 -0
- spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
- spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
- spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
- spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
- spikezoo/archs/spikeformer/EvalResults/readme +1 -0
- spikezoo/archs/spikeformer/LICENSE +21 -0
- spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
- spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/Loss.py +89 -0
- spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
- spikezoo/archs/spikeformer/Model/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/README.md +30 -0
- spikezoo/archs/spikeformer/evaluate.py +87 -0
- spikezoo/archs/spikeformer/recon_real_data.py +97 -0
- spikezoo/archs/spikeformer/requirements.yml +95 -0
- spikezoo/archs/spikeformer/train.py +173 -0
- spikezoo/archs/spikeformer/utils.py +22 -0
- spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
- spikezoo/archs/spk2imgnet/.gitignore +150 -0
- spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
- spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/align_arch.py +159 -0
- spikezoo/archs/spk2imgnet/dataset.py +144 -0
- spikezoo/archs/spk2imgnet/nets.py +230 -0
- spikezoo/archs/spk2imgnet/readme.md +86 -0
- spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
- spikezoo/archs/spk2imgnet/train.py +189 -0
- spikezoo/archs/spk2imgnet/utils.py +64 -0
- spikezoo/archs/ssir/README.md +87 -0
- spikezoo/archs/ssir/configs/SSIR.yml +37 -0
- spikezoo/archs/ssir/configs/yml_parser.py +78 -0
- spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
- spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
- spikezoo/archs/ssir/losses.py +21 -0
- spikezoo/archs/ssir/main.py +326 -0
- spikezoo/archs/ssir/metrics/psnr.py +22 -0
- spikezoo/archs/ssir/metrics/ssim.py +54 -0
- spikezoo/archs/ssir/models/Vgg19.py +42 -0
- spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/layers.py +110 -0
- spikezoo/archs/ssir/models/networks.py +61 -0
- spikezoo/archs/ssir/requirements.txt +8 -0
- spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
- spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
- spikezoo/archs/ssir/test.py +3 -0
- spikezoo/archs/ssir/utils.py +154 -0
- spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/cbam.py +224 -0
- spikezoo/archs/ssml/model.py +290 -0
- spikezoo/archs/ssml/res.png +0 -0
- spikezoo/archs/ssml/test.py +67 -0
- spikezoo/archs/stir/.git-credentials +0 -0
- spikezoo/archs/stir/README.md +65 -0
- spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
- spikezoo/archs/stir/configs/STIR.yml +37 -0
- spikezoo/archs/stir/configs/utils.py +155 -0
- spikezoo/archs/stir/configs/yml_parser.py +78 -0
- spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
- spikezoo/archs/stir/datasets/ds_utils.py +66 -0
- spikezoo/archs/stir/eval_SREDS.sh +5 -0
- spikezoo/archs/stir/main.py +397 -0
- spikezoo/archs/stir/metrics/losses.py +219 -0
- spikezoo/archs/stir/metrics/psnr.py +22 -0
- spikezoo/archs/stir/metrics/ssim.py +54 -0
- spikezoo/archs/stir/models/Vgg19.py +42 -0
- spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/networks_STIR.py +361 -0
- spikezoo/archs/stir/models/submodules.py +86 -0
- spikezoo/archs/stir/models/transformer_new.py +151 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
- spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
- spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
- spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
- spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
- spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
- spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
- spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
- spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
- spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
- spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
- spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
- spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
- spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
- spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
- spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
- spikezoo/archs/stir/package_core/setup.py +5 -0
- spikezoo/archs/stir/requirements.txt +12 -0
- spikezoo/archs/stir/train_STIR.sh +9 -0
- spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfi/nets.py +43 -0
- spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfp/nets.py +13 -0
- spikezoo/archs/wgse/README.md +64 -0
- spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/dataset.py +59 -0
- spikezoo/archs/wgse/demo.png +0 -0
- spikezoo/archs/wgse/demo.py +83 -0
- spikezoo/archs/wgse/dwtnets.py +145 -0
- spikezoo/archs/wgse/eval.py +133 -0
- spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
- spikezoo/archs/wgse/submodules.py +68 -0
- spikezoo/archs/wgse/train.py +261 -0
- spikezoo/archs/wgse/transform.py +139 -0
- spikezoo/archs/wgse/utils.py +128 -0
- spikezoo/archs/wgse/weights/demo.png +0 -0
- spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
- spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
- spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
- spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
- spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
- spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
- spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
- spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
- spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
- spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
- spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
- spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
- spikezoo/datasets/base_dataset.py +2 -3
- spikezoo/metrics/__init__.py +1 -1
- spikezoo/models/base_model.py +1 -3
- spikezoo/pipeline/base_pipeline.py +7 -5
- spikezoo/pipeline/train_pipeline.py +1 -1
- spikezoo/utils/other_utils.py +16 -6
- spikezoo/utils/spike_utils.py +33 -29
- spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
- spikezoo-0.2.dist-info/METADATA +163 -0
- spikezoo-0.2.dist-info/RECORD +211 -0
- spikezoo/models/spcsnet_model.py +0 -19
- spikezoo-0.1.1.dist-info/METADATA +0 -39
- spikezoo-0.1.1.dist-info/RECORD +0 -36
- {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/LICENSE.txt +0 -0
- {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/WHEEL +0 -0
- {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,721 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
import torch.nn.functional as F
|
4
|
+
|
5
|
+
# Public functional API of this module (names exported by ``import *``).
__all__ = [
    # constants and angle units
    "pi",
    "rad2deg",
    "deg2rad",
    # homogeneous coordinates
    "convert_points_from_homogeneous",
    "convert_points_to_homogeneous",
    # rotation representations
    "angle_axis_to_rotation_matrix",
    "angle_axis_to_quaternion",
    "euler_to_rotation_matrix",
    "rotation_matrix_to_angle_axis",
    "rotation_matrix_to_quaternion",
    "rotation_matrix_to_euler",
    "quaternion_to_angle_axis",
    "quaternion_to_rotation_matrix",
    "quaternion_log_to_exp",
    "quaternion_exp_to_log",
    # pixel-coordinate and quaternion normalization helpers
    "denormalize_pixel_coordinates",
    "normalize_pixel_coordinates",
    "normalize_quaternion",
    "denormalize_pixel_coordinates3d",
    "normalize_pixel_coordinates3d",
]
|
28
|
+
|
29
|
+
|
30
|
+
"""Constant with number pi
|
31
|
+
"""
|
32
|
+
# 0-dim tensor holding pi, created in torch's default floating-point dtype.
pi = torch.tensor(3.141592653589793)
|
33
|
+
|
34
|
+
def isclose(mat1, mat2, tol=1e-6):
    """Return a boolean tensor marking where ``|mat1 - mat2| < tol``.

    One of the two operands may be a Python scalar instead of a tensor.
    The comparison is strict, so a difference equal to ``tol`` is not "close".
    """
    distance = (mat1 - mat2).abs()
    return distance < tol
|
39
|
+
|
40
|
+
def rad2deg(tensor):
    r"""Convert angles from radians to degrees.

    Args:
        tensor (torch.Tensor): angles in radians, of arbitrary shape.

    Returns:
        torch.Tensor: angles in degrees, with the same shape as the input.

    Raises:
        TypeError: if ``tensor`` is not a ``torch.Tensor``.

    Example:
        >>> input = kornia.pi * torch.rand(1, 3, 3)
        >>> output = kornia.rad2deg(input)
    """
    if isinstance(tensor, torch.Tensor):
        # bring the module-level pi constant onto the input's device and dtype
        pi_matched = pi.to(tensor.device).type(tensor.dtype)
        return 180. * tensor / pi_matched
    raise TypeError("Input type is not a torch.Tensor. Got {}".format(
        type(tensor)))
|
55
|
+
|
56
|
+
|
57
|
+
def deg2rad(tensor):
    r"""Convert angles from degrees to radians.

    Args:
        tensor (torch.Tensor): angles in degrees, of arbitrary shape.

    Returns:
        torch.Tensor: angles in radians, with the same shape as the input.

    Raises:
        TypeError: if ``tensor`` is not a ``torch.Tensor``.

    Examples::
        >>> input = 360. * torch.rand(1, 3, 3)
        >>> output = kornia.deg2rad(input)
    """
    if isinstance(tensor, torch.Tensor):
        # bring the module-level pi constant onto the input's device and dtype
        pi_matched = pi.to(tensor.device).type(tensor.dtype)
        return tensor * pi_matched / 180.
    raise TypeError("Input type is not a torch.Tensor. Got {}".format(
        type(tensor)))
|
72
|
+
|
73
|
+
def convert_points_from_homogeneous(
        points, eps: float = 1e-8) -> torch.Tensor:
    r"""Convert points from homogeneous to Euclidean space.

    Every coordinate except the last is divided by the last one. Where the
    last coordinate has magnitude ``<= eps`` (a point at infinity), the
    divisor is treated as 1, following OpenCV's convention.

    Args:
        points (torch.Tensor): homogeneous points, at least 2-D; the last
          dimension holds the coordinates.
        eps (float): threshold under which the last coordinate is considered
          zero. Default: 1e-8.

    Returns:
        torch.Tensor: Euclidean points with one coordinate fewer.

    Raises:
        TypeError: if ``points`` is not a ``torch.Tensor``.
        ValueError: if ``points`` has fewer than two dimensions.

    Examples::
        >>> input = torch.rand(2, 4, 3)  # BxNx3
        >>> output = kornia.convert_points_from_homogeneous(input)  # BxNx2
    """
    if not isinstance(points, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(points)))
    if points.dim() < 2:
        raise ValueError("Input must be at least a 2D tensor. Got {}".format(
            points.shape))

    last_coord = points[..., -1:]

    # only invert entries that are safely away from zero; the rest keep a
    # scale of 1.0 (https://github.com/opencv/opencv/pull/14411/files)
    finite = last_coord.abs() > eps
    inv_scale = torch.ones_like(last_coord)
    inv_scale.masked_scatter_(finite, torch.tensor(1.0) / last_coord[finite])

    return inv_scale * points[..., :-1]
|
99
|
+
|
100
|
+
def convert_points_to_homogeneous(points) -> torch.Tensor:
    r"""Convert points from Euclidean to homogeneous space.

    A coordinate equal to 1 is appended along the last dimension.

    Args:
        points (torch.Tensor): Euclidean points, at least 2-D.

    Returns:
        torch.Tensor: homogeneous points with one extra coordinate.

    Raises:
        TypeError: if ``points`` is not a ``torch.Tensor``.
        ValueError: if ``points`` has fewer than two dimensions.

    Examples::
        >>> input = torch.rand(2, 4, 3)  # BxNx3
        >>> output = kornia.convert_points_to_homogeneous(input)  # BxNx4
    """
    if not isinstance(points, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(points)))
    if points.dim() < 2:
        raise ValueError("Input must be at least a 2D tensor. Got {}".format(
            points.shape))

    # pad the last dimension with zero entries on the left and one on the
    # right, filled with the constant 1.0
    return F.pad(points, (0, 1), mode="constant", value=1.0)
|
114
|
+
|
115
|
+
|
116
|
+
def angle_axis_to_rotation_matrix(angle_axis) -> torch.Tensor:
    r"""Convert 3d vector of axis-angle rotation to 3x3 rotation matrix.

    The rotation angle is the Euclidean norm of each vector and the rotation
    axis is its direction. Vectors whose squared angle is at most ``1e-6``
    use the first-order approximation ``I + [angle_axis]_x`` instead of the
    full Rodrigues formula.

    Args:
        angle_axis (torch.Tensor): tensor of 3d vector of axis-angle rotations.

    Returns:
        torch.Tensor: tensor of 3x3 rotation matrices.

    Raises:
        TypeError: if ``angle_axis`` is not a ``torch.Tensor``.
        ValueError: if the last dimension of ``angle_axis`` is not 3.

    Shape:
        - Input: :math:`(*, 3)`
        - Output: :math:`(*, 3, 3)`

    Example:
        >>> input = torch.rand(1, 3)  # Nx3
        >>> output = kornia.angle_axis_to_rotation_matrix(input)  # Nx3x3
    """
    if not isinstance(angle_axis, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(angle_axis)))

    if not angle_axis.shape[-1] == 3:
        raise ValueError(
            "Input size must be a (*, 3) tensor. Got {}".format(
                angle_axis.shape))

    # Squared-angle threshold between the two formulas; also guards the axis
    # normalisation below against division by zero.
    eps = 1e-6

    def _compute_rotation_matrix(angle_axis, theta2):
        # Rodrigues' formula. theta2 is clamped so that sqrt and its gradient
        # stay finite for zero rotations. Rows with theta2 <= eps are replaced
        # by the Taylor branch below, so the clamp never changes a value that
        # reaches the output.
        k_one = 1.0
        theta = torch.sqrt(torch.clamp(theta2, min=eps))
        wxyz = angle_axis / (theta + eps)
        wx, wy, wz = torch.chunk(wxyz, 3, dim=1)
        cos_theta = torch.cos(theta)
        sin_theta = torch.sin(theta)

        r00 = cos_theta + wx * wx * (k_one - cos_theta)
        r10 = wz * sin_theta + wx * wy * (k_one - cos_theta)
        r20 = -wy * sin_theta + wx * wz * (k_one - cos_theta)
        r01 = wx * wy * (k_one - cos_theta) - wz * sin_theta
        r11 = cos_theta + wy * wy * (k_one - cos_theta)
        r21 = wx * sin_theta + wy * wz * (k_one - cos_theta)
        r02 = wy * sin_theta + wx * wz * (k_one - cos_theta)
        r12 = -wx * sin_theta + wy * wz * (k_one - cos_theta)
        r22 = cos_theta + wz * wz * (k_one - cos_theta)
        rotation_matrix = torch.cat(
            [r00, r01, r02, r10, r11, r12, r20, r21, r22], dim=1)
        return rotation_matrix.view(-1, 3, 3)

    def _compute_rotation_matrix_taylor(angle_axis):
        # First-order approximation: identity plus the skew-symmetric matrix.
        rx, ry, rz = torch.chunk(angle_axis, 3, dim=1)
        k_one = torch.ones_like(rx)
        rotation_matrix = torch.cat(
            [k_one, -rz, ry, rz, k_one, -rx, -ry, rx, k_one], dim=1)
        return rotation_matrix.view(-1, 3, 3)

    # Flatten any leading dimensions into one batch of (N, 3) vectors so that
    # the (*, 3) inputs accepted by the check above are actually supported.
    batch_shape = angle_axis.shape[:-1]
    angle_axis = angle_axis.reshape(-1, 3)

    # squared rotation angle per vector, shape (N, 1) (as in ceres/rotation.h)
    _angle_axis = torch.unsqueeze(angle_axis, dim=1)
    theta2 = torch.matmul(_angle_axis, _angle_axis.transpose(1, 2))
    theta2 = torch.squeeze(theta2, dim=1)

    # compute rotation matrices with both formulas
    rotation_matrix_normal = _compute_rotation_matrix(angle_axis, theta2)
    rotation_matrix_taylor = _compute_rotation_matrix_taylor(angle_axis)

    # select, per vector, the formula appropriate for its angle
    mask = (theta2 > eps).view(-1, 1, 1)
    mask_pos = mask.type_as(theta2)
    mask_neg = (~mask).type_as(theta2)
    rotation_matrix = (mask_pos * rotation_matrix_normal
                       + mask_neg * rotation_matrix_taylor)
    return rotation_matrix.view(*batch_shape, 3, 3)  # (*, 3, 3)
|
193
|
+
|
194
|
+
|
195
|
+
def rotation_matrix_to_angle_axis(
        rotation_matrix) -> torch.Tensor:
    r"""Convert 3x3 rotation matrix to Rodrigues vector.

    The conversion goes through the quaternion representation
    (``rotation_matrix_to_quaternion`` followed by
    ``quaternion_to_angle_axis``).

    Args:
        rotation_matrix (torch.Tensor): rotation matrix.

    Returns:
        torch.Tensor: Rodrigues vector transformation.

    Raises:
        TypeError: if ``rotation_matrix`` is not a ``torch.Tensor``.
        ValueError: if the last two dimensions are not ``(3, 3)``.

    Shape:
        - Input: :math:`(*, 3, 3)`
        - Output: :math:`(*, 3)`

    Example:
        >>> input = torch.rand(2, 3, 3)  # Nx3x3
        >>> output = kornia.rotation_matrix_to_angle_axis(input)  # Nx3
    """
    if not isinstance(rotation_matrix, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(rotation_matrix)))

    if rotation_matrix.shape[-2:] != (3, 3):
        raise ValueError(
            "Input size must be a (*, 3, 3) tensor. Got {}".format(
                rotation_matrix.shape))

    return quaternion_to_angle_axis(
        rotation_matrix_to_quaternion(rotation_matrix))
|
219
|
+
|
220
|
+
|
221
|
+
def rotation_matrix_to_quaternion(
        rotation_matrix,
        eps: float = 1e-8) -> torch.Tensor:
    r"""Convert 3x3 rotation matrix to 4d quaternion vector.

    The quaternion vector has components in (x, y, z, w) format. The formula
    is chosen per matrix from the sign of the trace and, when the trace is not
    positive, from the largest diagonal element.

    Args:
        rotation_matrix (torch.Tensor): the rotation matrix to convert. It
          does not need to be contiguous (e.g. ``R.transpose(-1, -2)``).
        eps (float): small value added under the square roots of the
          non-positive-trace branches. Default: 1e-8.

    Return:
        torch.Tensor: the rotation in quaternion.

    Raises:
        TypeError: if ``rotation_matrix`` is not a ``torch.Tensor``.
        ValueError: if the last two dimensions are not ``(3, 3)``.

    Shape:
        - Input: :math:`(*, 3, 3)`
        - Output: :math:`(*, 4)`

    Example:
        >>> input = torch.rand(4, 3, 3)  # Nx3x3
        >>> output = kornia.rotation_matrix_to_quaternion(input)  # Nx4
    """
    if not isinstance(rotation_matrix, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(rotation_matrix)))

    if not rotation_matrix.shape[-2:] == (3, 3):
        raise ValueError(
            "Input size must be a (*, 3, 3) tensor. Got {}".format(
                rotation_matrix.shape))

    def safe_zero_division(numerator,
                           denominator) -> torch.Tensor:
        # Clamp the denominator to the dtype's smallest positive normal value
        # so that a zero denominator does not produce inf/NaN.
        tiny = torch.finfo(numerator.dtype).tiny  # type: ignore
        return numerator / torch.clamp(denominator, min=tiny)

    # ``reshape`` rather than ``view``: ``view`` raises for non-contiguous
    # inputs such as transposed matrices.
    rotation_matrix_vec = rotation_matrix.reshape(
        *rotation_matrix.shape[:-2], 9)

    # each entry has shape (*, 1)
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.chunk(
        rotation_matrix_vec, chunks=9, dim=-1)

    trace = m00 + m11 + m22

    def trace_positive_cond():
        sq = torch.sqrt(trace + 1.0) * 2.  # sq = 4 * qw.
        qw = 0.25 * sq
        qx = safe_zero_division(m21 - m12, sq)
        qy = safe_zero_division(m02 - m20, sq)
        qz = safe_zero_division(m10 - m01, sq)
        return torch.cat([qx, qy, qz, qw], dim=-1)

    def cond_1():
        sq = torch.sqrt(1.0 + m00 - m11 - m22 + eps) * 2.  # sq = 4 * qx.
        qw = safe_zero_division(m21 - m12, sq)
        qx = 0.25 * sq
        qy = safe_zero_division(m01 + m10, sq)
        qz = safe_zero_division(m02 + m20, sq)
        return torch.cat([qx, qy, qz, qw], dim=-1)

    def cond_2():
        sq = torch.sqrt(1.0 + m11 - m00 - m22 + eps) * 2.  # sq = 4 * qy.
        qw = safe_zero_division(m02 - m20, sq)
        qx = safe_zero_division(m01 + m10, sq)
        qy = 0.25 * sq
        qz = safe_zero_division(m12 + m21, sq)
        return torch.cat([qx, qy, qz, qw], dim=-1)

    def cond_3():
        sq = torch.sqrt(1.0 + m22 - m00 - m11 + eps) * 2.  # sq = 4 * qz.
        qw = safe_zero_division(m10 - m01, sq)
        qx = safe_zero_division(m02 + m20, sq)
        qy = safe_zero_division(m12 + m21, sq)
        qz = 0.25 * sq
        return torch.cat([qx, qy, qz, qw], dim=-1)

    # All branches are evaluated; torch.where picks the numerically
    # appropriate one per matrix.
    where_2 = torch.where(m11 > m22, cond_2(), cond_3())
    where_1 = torch.where(
        (m00 > m11) & (m00 > m22), cond_1(), where_2)

    quaternion = torch.where(
        trace > 0., trace_positive_cond(), where_1)
    return quaternion
|
299
|
+
|
300
|
+
|
301
|
+
def normalize_quaternion(quaternion,
                         eps: float = 1e-12) -> torch.Tensor:
    r"""Scale a quaternion to unit L2 norm.

    The quaternion should be in (x, y, z, w) format.

    Args:
        quaternion (torch.Tensor): quaternion(s) of shape :math:`(*, 4)`.
        eps (float): lower bound on the norm used as divisor, avoiding
          division by zero. Default: 1e-12.

    Return:
        torch.Tensor: the normalized quaternion of shape :math:`(*, 4)`.

    Raises:
        TypeError: if ``quaternion`` is not a ``torch.Tensor``.
        ValueError: if the last dimension of ``quaternion`` is not 4.

    Example:
        >>> quaternion = torch.tensor([1., 0., 1., 0.])
        >>> kornia.normalize_quaternion(quaternion)
        tensor([0.7071, 0.0000, 0.7071, 0.0000])
    """
    if not isinstance(quaternion, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(quaternion)))
    if quaternion.shape[-1] != 4:
        raise ValueError(
            "Input must be a tensor of shape (*, 4). Got {}".format(
                quaternion.shape))
    return F.normalize(quaternion, p=2, dim=-1, eps=eps)
|
326
|
+
|
327
|
+
|
328
|
+
# based on:
|
329
|
+
# https://github.com/matthew-brett/transforms3d/blob/8965c48401d9e8e66b6a8c37c65f2fc200a076fa/transforms3d/quaternions.py#L101
|
330
|
+
# https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/geometry/transformation/rotation_matrix_3d.py#L247
|
331
|
+
|
332
|
+
def quaternion_to_rotation_matrix(quaternion) -> torch.Tensor:
    r"""Converts a quaternion to a rotation matrix.

    The quaternion should be in (x, y, z, w) format. It is normalized to unit
    length with ``normalize_quaternion`` before the conversion.

    Args:
        quaternion (torch.Tensor): a tensor containing a quaternion to be
          converted. The tensor can be of shape :math:`(*, 4)`.

    Return:
        torch.Tensor: the rotation matrix of shape :math:`(*, 3, 3)`. All
        leading dimensions of the input are kept, so a single quaternion of
        shape :math:`(4,)` gives a :math:`(3, 3)` matrix.

    Raises:
        TypeError: if ``quaternion`` is not a ``torch.Tensor``.
        ValueError: if the last dimension of ``quaternion`` is not 4.

    Example:
        >>> quaternion = torch.tensor([0., 0., 1., 0.])
        >>> kornia.quaternion_to_rotation_matrix(quaternion)
        tensor([[-1.,  0.,  0.],
                [ 0., -1.,  0.],
                [ 0.,  0.,  1.]])
    """
    if not isinstance(quaternion, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(quaternion)))

    if not quaternion.shape[-1] == 4:
        raise ValueError(
            "Input must be a tensor of shape (*, 4). Got {}".format(
                quaternion.shape))
    # normalize the input quaternion
    quaternion_norm = normalize_quaternion(quaternion)

    # unpack the normalized quaternion components, each of shape (*, 1)
    x, y, z, w = torch.chunk(quaternion_norm, chunks=4, dim=-1)

    # compute the actual conversion
    tx = 2.0 * x
    ty = 2.0 * y
    tz = 2.0 * z
    twx = tx * w
    twy = ty * w
    twz = tz * w
    txx = tx * x
    txy = ty * x
    txz = tz * x
    tyy = ty * y
    tyz = tz * y
    tzz = tz * z
    # A Python scalar takes on the dtype and device of the tensors it is
    # combined with, so no extra tensor is allocated per call.
    one = 1.0

    matrix = torch.stack([
        one - (tyy + tzz), txy - twz, txz + twy,
        txy + twz, one - (txx + tzz), tyz - twx,
        txz - twy, tyz + twx, one - (txx + tyy)
    ], dim=-1)  # (*, 1, 9)

    # Keep the input's leading dimensions: (*, 4) -> (*, 3, 3), rather than
    # flattening them into a single batch dimension.
    return matrix.view(quaternion.shape[:-1] + (3, 3))
|
385
|
+
|
386
|
+
|
387
|
+
def quaternion_to_angle_axis(quaternion) -> torch.Tensor:
    """Convert quaternion vector to angle axis of rotation.

    The quaternion should be in (x, y, z, w) format.
    Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h

    Args:
        quaternion (torch.Tensor): tensor with quaternions.

    Return:
        torch.Tensor: tensor with angle axis of rotation.

    Raises:
        TypeError: if ``quaternion`` is not a tensor.
        ValueError: if the last dimension of ``quaternion`` is not 4.

    Shape:
        - Input: :math:`(*, 4)` where `*` means, any number of dimensions
        - Output: :math:`(*, 3)`

    Example:
        >>> quaternion = torch.rand(2, 4)  # Nx4
        >>> angle_axis = kornia.quaternion_to_angle_axis(quaternion)  # Nx3
    """
    if not torch.is_tensor(quaternion):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(quaternion)))

    if quaternion.shape[-1] != 4:
        raise ValueError(
            "Input must be a tensor of shape Nx4 or 4. Got {}".format(
                quaternion.shape))

    vx, vy, vz, cos_theta = (quaternion[..., idx] for idx in range(4))

    sin_squared_theta = vx * vx + vy * vy + vz * vz
    sin_theta = torch.sqrt(sin_squared_theta)

    # Flip to the equivalent quaternion with non-negative w when w < 0, so
    # the recovered rotation angle stays within [-pi, pi].
    two_theta = 2.0 * torch.where(
        cos_theta < 0.0,
        torch.atan2(-sin_theta, -cos_theta),
        torch.atan2(sin_theta, cos_theta))

    # For a zero vector part the scale is 2 (the small-angle limit); the 1e-8
    # keeps the other (discarded) branch finite.
    scale = torch.where(
        sin_squared_theta > 0.0,
        two_theta / (sin_theta + 1e-8),
        2.0 * torch.ones_like(sin_theta))

    angle_axis = torch.zeros_like(quaternion)[..., :3]
    for axis, component in enumerate((vx, vy, vz)):
        angle_axis[..., axis] += component * scale
    return angle_axis
|
431
|
+
|
432
|
+
|
433
|
+
def quaternion_log_to_exp(quaternion,
                          eps: float = 1e-8) -> torch.Tensor:
    r"""Apply the exponential map to a log quaternion.

    The resulting quaternion is in (x, y, z, w) format.

    Args:
        quaternion (torch.Tensor): log quaternion(s) of shape :math:`(*, 3)`.
        eps (float): lower bound on the norm used as divisor. Default: 1e-8.

    Return:
        torch.Tensor: the quaternion exponential map of shape :math:`(*, 4)`.

    Raises:
        TypeError: if ``quaternion`` is not a ``torch.Tensor``.
        ValueError: if the last dimension of ``quaternion`` is not 3.

    Example:
        >>> quaternion = torch.tensor([0., 0., 0.])
        >>> kornia.quaternion_log_to_exp(quaternion)
        tensor([0., 0., 0., 1.])
    """
    if not isinstance(quaternion, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(quaternion)))
    if quaternion.shape[-1] != 3:
        raise ValueError(
            "Input must be a tensor of shape (*, 3). Got {}".format(
                quaternion.shape))

    # clamping keeps the division finite for a zero log quaternion
    magnitude = quaternion.norm(p=2, dim=-1, keepdim=True).clamp(min=eps)

    vector_part = quaternion * torch.sin(magnitude) / magnitude
    scalar_part = torch.cos(magnitude)
    return torch.cat([vector_part, scalar_part], dim=-1)
|
467
|
+
|
468
|
+
|
469
|
+
def quaternion_exp_to_log(quaternion,
                          eps: float = 1e-8) -> torch.Tensor:
    r"""Apply the log map to a quaternion.

    The quaternion should be in (x, y, z, w) format.

    Args:
        quaternion (torch.Tensor): quaternion(s) of shape :math:`(*, 4)`.
        eps (float): lower bound on the vector-part norm used as divisor.
          Default: 1e-8.

    Return:
        torch.Tensor: the quaternion log map of shape :math:`(*, 3)`.

    Raises:
        TypeError: if ``quaternion`` is not a ``torch.Tensor``.
        ValueError: if the last dimension of ``quaternion`` is not 4.

    Example:
        >>> quaternion = torch.tensor([0., 0., 0., 1.])
        >>> kornia.quaternion_exp_to_log(quaternion)
        tensor([0., 0., 0.])
    """
    if not isinstance(quaternion, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(quaternion)))
    if quaternion.shape[-1] != 4:
        raise ValueError(
            "Input must be a tensor of shape (*, 4). Got {}".format(
                quaternion.shape))

    vector_part, scalar_part = torch.split(quaternion, [3, 1], dim=-1)

    magnitude = vector_part.norm(p=2, dim=-1, keepdim=True).clamp(min=eps)

    # clamp w into acos's domain to tolerate slightly non-unit inputs
    acos_w = torch.acos(scalar_part.clamp(min=-1.0, max=1.0))
    return vector_part * acos_w / magnitude
|
503
|
+
|
504
|
+
|
505
|
+
# based on:
|
506
|
+
# https://github.com/facebookresearch/QuaterNet/blob/master/common/quaternion.py#L138
|
507
|
+
|
508
|
+
|
509
|
+
def angle_axis_to_quaternion(angle_axis) -> torch.Tensor:
    r"""Convert an angle axis to a quaternion.

    The quaternion vector has components in (x, y, z, w) format.

    Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h

    The direction of the input vector is the rotation axis and its norm
    ``theta`` is the rotation angle. The output is
    ``(axis * sin(theta / 2), cos(theta / 2))``. For a zero vector the
    identity quaternion ``(0, 0, 0, 1)`` is returned.

    Args:
        angle_axis (torch.Tensor): tensor with angle axis.

    Return:
        torch.Tensor: tensor with quaternion.

    Raises:
        TypeError: if ``angle_axis`` is not a tensor.
        ValueError: if the last dimension of ``angle_axis`` is not 3.

    Shape:
        - Input: :math:`(*, 3)` where `*` means, any number of dimensions
        - Output: :math:`(*, 4)`

    Example:
        >>> angle_axis = torch.rand(2, 3)  # Nx3
        >>> quaternion = angle_axis_to_quaternion(angle_axis)  # Nx4
    """
    if not torch.is_tensor(angle_axis):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(angle_axis)))

    if not angle_axis.shape[-1] == 3:
        raise ValueError(
            "Input must be a tensor of shape Nx3 or 3. Got {}".format(
                angle_axis.shape))

    # unpack input and compute the squared rotation angle
    a0 = angle_axis[..., 0:1]
    a1 = angle_axis[..., 1:2]
    a2 = angle_axis[..., 2:3]
    theta_squared = a0 * a0 + a1 * a1 + a2 * a2

    mask = theta_squared > 0.0
    ones = torch.ones_like(theta_squared)

    # torch.where backpropagates through BOTH branches, so sqrt(0) and
    # sin(0) / 0 in the discarded branch would still inject inf/NaN into the
    # gradient. Substitute a harmless value (1) where theta == 0 before
    # taking the square root and dividing. The forward result is unchanged
    # because those entries are replaced by the where() calls below.
    theta = torch.sqrt(torch.where(mask, theta_squared, ones))
    half_theta = theta * 0.5

    # k scales the axis: sin(theta/2)/theta, with limit 0.5 as theta -> 0
    k_neg = 0.5 * ones
    k_pos = torch.sin(half_theta) / theta
    k = torch.where(mask, k_pos, k_neg)
    w = torch.where(mask, torch.cos(half_theta), ones)

    # (*, 3) * (*, 1) broadcasts k over the three vector components
    return torch.cat([angle_axis * k, w], dim=-1)
|
554
|
+
|
555
|
+
|
556
|
+
# based on:
|
557
|
+
# https://github.com/ClementPinard/SfmLearner-Pytorch/blob/master/inverse_warp.py#L65-L71
|
558
|
+
|
559
|
+
def normalize_pixel_coordinates(
        pixel_coordinates,
        height: int,
        width: int,
        eps: float = 1e-8) -> torch.Tensor:
    r"""Map pixel coordinates into the range [-1, 1].

    Column 0 is the x coordinate (scaled by ``width``) and column 1 is the y
    coordinate (scaled by ``height``). A value of 0 maps to -1 and
    ``size - 1`` (e.g. ``x = w - 1``) maps to 1.

    Args:
        pixel_coordinates (torch.Tensor): the grid with pixel coordinates.
          Shape can be :math:`(*, 2)`.
        height (int): the maximum height in the y-axis.
        width (int): the maximum width in the x-axis.
        eps (float): lower bound on ``size - 1`` to avoid division by zero.
          Default: 1e-8.

    Return:
        torch.Tensor: the normalized pixel coordinates.

    Raises:
        ValueError: if the last dimension of ``pixel_coordinates`` is not 2.
    """
    if pixel_coordinates.shape[-1] != 2:
        raise ValueError("Input pixel_coordinates must be of shape (*, 2). "
                         "Got {}".format(pixel_coordinates.shape))

    # width comes first so it lines up with the x (column 0) coordinate
    size_xy = torch.stack([torch.tensor(width), torch.tensor(height)])
    size_xy = size_xy.to(pixel_coordinates.device)
    size_xy = size_xy.to(pixel_coordinates.dtype)

    scale = torch.tensor(2.) / (size_xy - 1).clamp(eps)
    return scale * pixel_coordinates - 1
|
586
|
+
|
587
|
+
|
588
|
+
def denormalize_pixel_coordinates(
        pixel_coordinates,
        height: int,
        width: int,
        eps: float = 1e-8) -> torch.Tensor:
    r"""Map coordinates from [-1, 1] back to pixel units.

    This is the inverse of the 2-D normalization: -1 maps to 0 and 1 maps to
    ``size - 1`` (e.g. ``x = w - 1``). Column 0 uses ``width`` and column 1
    uses ``height``.

    Args:
        pixel_coordinates (torch.Tensor): the normalized grid coordinates.
          Shape can be :math:`(*, 2)`.
        height (int): the maximum height in the y-axis.
        width (int): the maximum width in the x-axis.
        eps (float): lower bound on ``size - 1`` to avoid division by zero.
          Default: 1e-8.

    Return:
        torch.Tensor: the denormalized pixel coordinates.

    Raises:
        ValueError: if the last dimension of ``pixel_coordinates`` is not 2.
    """
    if pixel_coordinates.shape[-1] != 2:
        raise ValueError("Input pixel_coordinates must be of shape (*, 2). "
                         "Got {}".format(pixel_coordinates.shape))

    # width first, matching the x (column 0) coordinate
    size_xy = torch.stack([torch.tensor(width), torch.tensor(height)])
    size_xy = size_xy.to(pixel_coordinates.device)
    size_xy = size_xy.to(pixel_coordinates.dtype)

    scale = torch.tensor(2.) / (size_xy - 1).clamp(eps)
    inverse_scale = torch.tensor(1.) / scale
    return inverse_scale * (pixel_coordinates + 1)
|
616
|
+
|
617
|
+
|
618
|
+
def normalize_pixel_coordinates3d(
        pixel_coordinates,
        depth: int,
        height: int,
        width: int,
        eps: float = 1e-8) -> torch.Tensor:
    r"""Map 3-D pixel coordinates into the range [-1, 1].

    Each column ``c`` is scaled by ``2 / (size_c - 1)`` and shifted by -1, so
    0 maps to -1 and ``size_c - 1`` maps to 1. The per-column sizes are
    ``[depth, width, height]`` for columns 0, 1, 2 respectively.

    Args:
        pixel_coordinates (torch.Tensor): the grid with pixel coordinates.
          Shape can be :math:`(*, 3)`.
        depth (int): the maximum depth in the z-axis.
        height (int): the maximum height in the y-axis.
        width (int): the maximum width in the x-axis.
        eps (float): lower bound on ``size - 1`` to avoid division by zero.
          Default: 1e-8.

    Return:
        torch.Tensor: the normalized pixel coordinates.

    Raises:
        ValueError: if the last dimension of ``pixel_coordinates`` is not 3.
    """
    if pixel_coordinates.shape[-1] != 3:
        raise ValueError("Input pixel_coordinates must be of shape (*, 3). "
                         "Got {}".format(pixel_coordinates.shape))

    # NOTE(review): the order [depth, width, height] matches neither (x, y, z)
    # = [width, height, depth] nor (z, y, x) = [depth, height, width]. It is
    # preserved as-is; confirm against callers before changing it.
    size_dwh = torch.stack(
        [torch.tensor(depth), torch.tensor(width), torch.tensor(height)])
    size_dwh = size_dwh.to(pixel_coordinates.device)
    size_dwh = size_dwh.to(pixel_coordinates.dtype)

    scale = torch.tensor(2.) / (size_dwh - 1).clamp(eps)
    return scale * pixel_coordinates - 1
|
647
|
+
|
648
|
+
|
649
|
+
def denormalize_pixel_coordinates3d(
        pixel_coordinates,
        depth: int,
        height: int,
        width: int,
        eps: float = 1e-8) -> torch.Tensor:
    r"""Map 3-D coordinates from [-1, 1] back to pixel units.

    This is the inverse of the 3-D normalization: -1 maps to 0 and 1 maps to
    ``size_c - 1``. The per-column sizes are ``[depth, width, height]`` for
    columns 0, 1, 2 respectively.

    Args:
        pixel_coordinates (torch.Tensor): the normalized grid coordinates.
          Shape can be :math:`(*, 3)`.
        depth (int): the maximum depth in the z-axis.
        height (int): the maximum height in the y-axis.
        width (int): the maximum width in the x-axis.
        eps (float): lower bound on ``size - 1`` to avoid division by zero.
          Default: 1e-8.

    Return:
        torch.Tensor: the denormalized pixel coordinates.

    Raises:
        ValueError: if the last dimension of ``pixel_coordinates`` is not 3.
    """
    if pixel_coordinates.shape[-1] != 3:
        raise ValueError("Input pixel_coordinates must be of shape (*, 3). "
                         "Got {}".format(pixel_coordinates.shape))

    # NOTE(review): [depth, width, height] ordering kept identical to the
    # forward normalization; confirm the intended axis order with callers.
    size_dwh = torch.stack(
        [torch.tensor(depth), torch.tensor(width), torch.tensor(height)])
    size_dwh = size_dwh.to(pixel_coordinates.device)
    size_dwh = size_dwh.to(pixel_coordinates.dtype)

    scale = torch.tensor(2.) / (size_dwh - 1).clamp(eps)
    inverse_scale = torch.tensor(1.) / scale
    return inverse_scale * (pixel_coordinates + 1)
|
678
|
+
|
679
|
+
def rotation_matrix_to_euler(R):
    r"""Convert rotation matrices to Euler angles ``[Rx, Ry, Rz]``.

    ``R`` is first converted with ``rotation_matrix_to_quaternion``. The result
    is unpacked below as (x, y, z, w); verify that convention against that
    helper. The angles are then recovered with the standard
    quaternion -> Euler formulas and returned in radians.

    Note:
        The conversion is singular (gimbal lock) when ``Ry`` is +/-90
        degrees. There ``Rx`` and ``Rz`` are no longer determined
        independently.

    Args:
        R: rotation matrix tensor accepted by ``rotation_matrix_to_quaternion``.

    Return:
        torch.Tensor: tensor whose last dimension holds ``[Rx, Ry, Rz]``.
    """
    Q = rotation_matrix_to_quaternion(R)

    # keep a trailing singleton dim so the three angles can be concatenated
    x = Q[..., 0:1]
    y = Q[..., 1:2]
    z = Q[..., 2:3]
    w = Q[..., 3:4]

    Rx = torch.atan2(2.0 * (w * x + y * z), 1.0 - 2.0 * (x * x + y * y))
    # Floating-point rounding can push the sine marginally outside [-1, 1],
    # where asin returns NaN; clamp to the valid domain first.
    sin_ry = torch.clamp(2.0 * (w * y - x * z), min=-1.0, max=1.0)
    Ry = torch.asin(sin_ry)
    Rz = torch.atan2(2.0 * (w * z + x * y), 1.0 - 2.0 * (y * y + z * z))

    return torch.cat([Rx, Ry, Rz], dim=-1)
|
695
|
+
|
696
|
+
def euler_to_rotation_matrix(euler):
    r"""Convert Euler angles ``[Rx, Ry, Rz]`` (columns 0, 1, 2) to matrices.

    The angles are first turned into a quaternion in (x, y, z, w) order via
    the half-angle products below. That quaternion is then handed to
    ``quaternion_to_rotation_matrix``.

    Args:
        euler (torch.Tensor): angles indexed as ``euler[:, 0..2]``, i.e. at
          least 2-D with the angle index on dim 1.

    Return:
        whatever ``quaternion_to_rotation_matrix`` returns for the
        concatenated quaternion.
    """
    # one (N, 1) column per angle: Rx, Ry, Rz
    rx, ry, rz = (euler[:, idx].unsqueeze(-1) for idx in range(3))

    # cosines and sines of the half angles
    cos_x, sin_x = torch.cos(0.5 * rx), torch.sin(0.5 * rx)
    cos_y, sin_y = torch.cos(0.5 * ry), torch.sin(0.5 * ry)
    cos_z, sin_z = torch.cos(0.5 * rz), torch.sin(0.5 * rz)

    quat = torch.cat([
        sin_x * cos_y * cos_z - cos_x * sin_y * sin_z,  # x
        cos_x * sin_y * cos_z + sin_x * cos_y * sin_z,  # y
        cos_x * cos_y * sin_z - sin_x * sin_y * cos_z,  # z
        cos_x * cos_y * cos_z + sin_x * sin_y * sin_z,  # w
    ], dim=-1)
    return quaternion_to_rotation_matrix(quat)
|
717
|
+
|
718
|
+
# R = torch.tensor([[[ 0.0266, 0.1109, 0.9935],
|
719
|
+
# [-0.0344, 0.9933, -0.1099],
|
720
|
+
# [-0.9991, -0.0313, 0.0302]]])
|
721
|
+
# print(rotation_matrix_to_quaternion(R))
|