spikezoo 0.1.2__py3-none-any.whl → 0.2.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (192) hide show
  1. spikezoo/__init__.py +13 -0
  2. spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
  3. spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
  4. spikezoo/archs/base/nets.py +34 -0
  5. spikezoo/archs/bsf/README.md +92 -0
  6. spikezoo/archs/bsf/datasets/datasets.py +328 -0
  7. spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
  8. spikezoo/archs/bsf/main.py +398 -0
  9. spikezoo/archs/bsf/metrics/psnr.py +22 -0
  10. spikezoo/archs/bsf/metrics/ssim.py +54 -0
  11. spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
  12. spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
  13. spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
  14. spikezoo/archs/bsf/models/bsf/align.py +154 -0
  15. spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
  16. spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
  17. spikezoo/archs/bsf/models/bsf/rep.py +44 -0
  18. spikezoo/archs/bsf/models/get_model.py +7 -0
  19. spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
  20. spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
  21. spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
  22. spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
  23. spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
  24. spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
  25. spikezoo/archs/bsf/requirements.txt +9 -0
  26. spikezoo/archs/bsf/test.py +16 -0
  27. spikezoo/archs/bsf/utils.py +154 -0
  28. spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
  29. spikezoo/archs/spikeclip/nets.py +40 -0
  30. spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
  31. spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
  32. spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
  33. spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
  34. spikezoo/archs/spikeformer/EvalResults/readme +1 -0
  35. spikezoo/archs/spikeformer/LICENSE +21 -0
  36. spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
  37. spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
  38. spikezoo/archs/spikeformer/Model/Loss.py +89 -0
  39. spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
  40. spikezoo/archs/spikeformer/Model/__init__.py +0 -0
  41. spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
  42. spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
  43. spikezoo/archs/spikeformer/README.md +30 -0
  44. spikezoo/archs/spikeformer/evaluate.py +87 -0
  45. spikezoo/archs/spikeformer/recon_real_data.py +97 -0
  46. spikezoo/archs/spikeformer/requirements.yml +95 -0
  47. spikezoo/archs/spikeformer/train.py +173 -0
  48. spikezoo/archs/spikeformer/utils.py +22 -0
  49. spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
  50. spikezoo/archs/spk2imgnet/.gitignore +150 -0
  51. spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
  52. spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
  53. spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
  54. spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
  55. spikezoo/archs/spk2imgnet/align_arch.py +159 -0
  56. spikezoo/archs/spk2imgnet/dataset.py +144 -0
  57. spikezoo/archs/spk2imgnet/nets.py +230 -0
  58. spikezoo/archs/spk2imgnet/readme.md +86 -0
  59. spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
  60. spikezoo/archs/spk2imgnet/train.py +189 -0
  61. spikezoo/archs/spk2imgnet/utils.py +64 -0
  62. spikezoo/archs/ssir/README.md +87 -0
  63. spikezoo/archs/ssir/configs/SSIR.yml +37 -0
  64. spikezoo/archs/ssir/configs/yml_parser.py +78 -0
  65. spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
  66. spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
  67. spikezoo/archs/ssir/losses.py +21 -0
  68. spikezoo/archs/ssir/main.py +326 -0
  69. spikezoo/archs/ssir/metrics/psnr.py +22 -0
  70. spikezoo/archs/ssir/metrics/ssim.py +54 -0
  71. spikezoo/archs/ssir/models/Vgg19.py +42 -0
  72. spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
  73. spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
  74. spikezoo/archs/ssir/models/layers.py +110 -0
  75. spikezoo/archs/ssir/models/networks.py +61 -0
  76. spikezoo/archs/ssir/requirements.txt +8 -0
  77. spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
  78. spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
  79. spikezoo/archs/ssir/test.py +3 -0
  80. spikezoo/archs/ssir/utils.py +154 -0
  81. spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
  82. spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
  83. spikezoo/archs/ssml/cbam.py +224 -0
  84. spikezoo/archs/ssml/model.py +290 -0
  85. spikezoo/archs/ssml/res.png +0 -0
  86. spikezoo/archs/ssml/test.py +67 -0
  87. spikezoo/archs/stir/.git-credentials +0 -0
  88. spikezoo/archs/stir/README.md +65 -0
  89. spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
  90. spikezoo/archs/stir/configs/STIR.yml +37 -0
  91. spikezoo/archs/stir/configs/utils.py +155 -0
  92. spikezoo/archs/stir/configs/yml_parser.py +78 -0
  93. spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
  94. spikezoo/archs/stir/datasets/ds_utils.py +66 -0
  95. spikezoo/archs/stir/eval_SREDS.sh +5 -0
  96. spikezoo/archs/stir/main.py +397 -0
  97. spikezoo/archs/stir/metrics/losses.py +219 -0
  98. spikezoo/archs/stir/metrics/psnr.py +22 -0
  99. spikezoo/archs/stir/metrics/ssim.py +54 -0
  100. spikezoo/archs/stir/models/Vgg19.py +42 -0
  101. spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
  102. spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
  103. spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
  104. spikezoo/archs/stir/models/networks_STIR.py +361 -0
  105. spikezoo/archs/stir/models/submodules.py +86 -0
  106. spikezoo/archs/stir/models/transformer_new.py +151 -0
  107. spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
  108. spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
  109. spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
  110. spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
  111. spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
  112. spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
  113. spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
  114. spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
  115. spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
  116. spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
  117. spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
  118. spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
  119. spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
  120. spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
  121. spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
  122. spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
  123. spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
  124. spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
  125. spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
  126. spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
  127. spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
  128. spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
  129. spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
  130. spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
  131. spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
  132. spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
  133. spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
  134. spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
  135. spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
  136. spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
  137. spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
  138. spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
  139. spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
  140. spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
  141. spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
  142. spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
  143. spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
  144. spikezoo/archs/stir/package_core/setup.py +5 -0
  145. spikezoo/archs/stir/requirements.txt +12 -0
  146. spikezoo/archs/stir/train_STIR.sh +9 -0
  147. spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
  148. spikezoo/archs/tfi/nets.py +43 -0
  149. spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
  150. spikezoo/archs/tfp/nets.py +13 -0
  151. spikezoo/archs/wgse/README.md +64 -0
  152. spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
  153. spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
  154. spikezoo/archs/wgse/dataset.py +59 -0
  155. spikezoo/archs/wgse/demo.png +0 -0
  156. spikezoo/archs/wgse/demo.py +83 -0
  157. spikezoo/archs/wgse/dwtnets.py +145 -0
  158. spikezoo/archs/wgse/eval.py +133 -0
  159. spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
  160. spikezoo/archs/wgse/submodules.py +68 -0
  161. spikezoo/archs/wgse/train.py +261 -0
  162. spikezoo/archs/wgse/transform.py +139 -0
  163. spikezoo/archs/wgse/utils.py +128 -0
  164. spikezoo/archs/wgse/weights/demo.png +0 -0
  165. spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
  166. spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
  167. spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
  168. spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
  169. spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
  170. spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
  171. spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
  172. spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
  173. spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
  174. spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
  175. spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
  176. spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
  177. spikezoo/datasets/base_dataset.py +2 -3
  178. spikezoo/metrics/__init__.py +1 -1
  179. spikezoo/models/base_model.py +1 -3
  180. spikezoo/pipeline/base_pipeline.py +7 -5
  181. spikezoo/pipeline/train_pipeline.py +1 -1
  182. spikezoo/utils/other_utils.py +16 -6
  183. spikezoo/utils/spike_utils.py +33 -29
  184. spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
  185. spikezoo-0.2.1.dist-info/METADATA +167 -0
  186. spikezoo-0.2.1.dist-info/RECORD +211 -0
  187. spikezoo/models/spcsnet_model.py +0 -19
  188. spikezoo-0.1.2.dist-info/METADATA +0 -39
  189. spikezoo-0.1.2.dist-info/RECORD +0 -36
  190. {spikezoo-0.1.2.dist-info → spikezoo-0.2.1.dist-info}/LICENSE.txt +0 -0
  191. {spikezoo-0.1.2.dist-info → spikezoo-0.2.1.dist-info}/WHEEL +0 -0
  192. {spikezoo-0.1.2.dist-info → spikezoo-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,721 @@
1
import torch
import torch.nn as nn
import torch.nn.functional as F

# Public, functional API of this module: conversions between angle-axis,
# quaternion, rotation-matrix and pixel-coordinate representations.
__all__ = [
    # functional api
    "pi",
    "rad2deg",
    "deg2rad",
    "convert_points_from_homogeneous",
    "convert_points_to_homogeneous",
    "angle_axis_to_rotation_matrix",
    "angle_axis_to_quaternion",
    "euler_to_rotation_matrix",
    "rotation_matrix_to_angle_axis",
    "rotation_matrix_to_quaternion",
    "rotation_matrix_to_euler",
    "quaternion_to_angle_axis",
    "quaternion_to_rotation_matrix",
    "quaternion_log_to_exp",
    "quaternion_exp_to_log",
    "denormalize_pixel_coordinates",
    "normalize_pixel_coordinates",
    "normalize_quaternion",
    "denormalize_pixel_coordinates3d",
    "normalize_pixel_coordinates3d",
]


# Constant with number pi, kept as a tensor so it can be moved to the
# device/dtype of whatever tensor it is combined with.
pi = torch.tensor(3.14159265358979323846)
33
+
34
def isclose(mat1, mat2, tol=1e-6):
    """Element-wise closeness test within absolute tolerance *tol*.

    Either argument may be a scalar; broadcasting applies. Returns a
    boolean tensor that is True where ``|mat1 - mat2| < tol``.
    """
    diff = (mat1 - mat2).abs_()  # in-place abs on the temporary difference
    return diff.lt(tol)
39
+
40
def rad2deg(tensor):
    r"""Convert angles from radians to degrees.

    Args:
        tensor (torch.Tensor): tensor of arbitrary shape, in radians.

    Returns:
        torch.Tensor: tensor with the same shape as the input, in degrees.

    Raises:
        TypeError: if *tensor* is not a ``torch.Tensor``.
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(tensor)}")

    # move the module-level pi constant onto the input's device/dtype
    pi_like = pi.to(tensor.device).type(tensor.dtype)
    return 180. * tensor / pi_like
55
+
56
+
57
def deg2rad(tensor):
    r"""Convert angles from degrees to radians.

    Args:
        tensor (torch.Tensor): tensor of arbitrary shape, in degrees.

    Returns:
        torch.Tensor: tensor with the same shape as the input, in radians.

    Raises:
        TypeError: if *tensor* is not a ``torch.Tensor``.
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(tensor)}")

    # move the module-level pi constant onto the input's device/dtype
    pi_like = pi.to(tensor.device).type(tensor.dtype)
    return tensor * pi_like / 180.
72
+
73
def convert_points_from_homogeneous(
        points, eps: float = 1e-8) -> torch.Tensor:
    r"""Convert points from homogeneous to Euclidean space.

    Args:
        points (torch.Tensor): points of shape :math:`(*, N, D+1)`.
        eps (float): magnitude below which the last (homogeneous)
            coordinate is treated as zero, i.e. a point at infinity.
            Default: 1e-8.

    Returns:
        torch.Tensor: Euclidean points of shape :math:`(*, N, D)`.

    Raises:
        TypeError: if *points* is not a ``torch.Tensor``.
        ValueError: if *points* has fewer than two dimensions.

    Example:
        >>> input = torch.rand(2, 4, 3)  # BxNx3
        >>> output = convert_points_from_homogeneous(input)  # BxNx2
    """
    if not isinstance(points, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(points)))

    if len(points.shape) < 2:
        raise ValueError("Input must be at least a 2D tensor. Got {}".format(
            points.shape))

    # we check for points at infinity
    z_vec = points[..., -1:]

    # set the results of division by zero/near-zero to 1.0
    # follow the convention of opencv:
    # https://github.com/opencv/opencv/pull/14411/files
    mask = torch.abs(z_vec) > eps
    # FIX: use a plain Python scalar for the reciprocal. The original built
    # `torch.tensor(1.0)` (a CPU float32 scalar) on every call, which is a
    # needless allocation and device/dtype-fragile when `points` lives on
    # another device; `1.0 / z_vec[mask]` stays on the input's device/dtype.
    scale = torch.ones_like(z_vec).masked_scatter_(mask, 1.0 / z_vec[mask])

    return scale * points[..., :-1]
99
+
100
def convert_points_to_homogeneous(points) -> torch.Tensor:
    r"""Convert points from Euclidean to homogeneous space.

    Args:
        points (torch.Tensor): points of shape :math:`(*, N, D)`.

    Returns:
        torch.Tensor: points of shape :math:`(*, N, D+1)` with a trailing
        coordinate fixed at 1.0.

    Raises:
        TypeError: if *points* is not a ``torch.Tensor``.
        ValueError: if *points* has fewer than two dimensions.
    """
    if not isinstance(points, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(points)}")
    if len(points.shape) < 2:
        raise ValueError(f"Input must be at least a 2D tensor. Got {points.shape}")

    # append the homogeneous coordinate by padding the last dim with 1.0
    return F.pad(points, [0, 1], "constant", 1.0)
114
+
115
+
116
def angle_axis_to_rotation_matrix(angle_axis) -> torch.Tensor:
    r"""Convert axis-angle rotation vectors to 3x3 rotation matrices.

    Args:
        angle_axis (torch.Tensor): axis-angle vectors of shape :math:`(N, 3)`.

    Returns:
        torch.Tensor: rotation matrices of shape :math:`(N, 3, 3)`.

    Raises:
        TypeError: if *angle_axis* is not a ``torch.Tensor``.
        ValueError: if the last dimension is not 3.
    """
    if not isinstance(angle_axis, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(angle_axis)))

    if not angle_axis.shape[-1] == 3:
        raise ValueError(
            "Input size must be a (*, 3) tensor. Got {}".format(
                angle_axis.shape))

    def _rodrigues(vec, theta2, eps=1e-6):
        # Full Rodrigues formula. Only selected where theta2 is bounded
        # away from zero, so sqrt/division are safe; eps guards the
        # normalization anyway.
        theta = torch.sqrt(theta2)
        wxyz = vec / (theta + eps)
        wx, wy, wz = torch.chunk(wxyz, 3, dim=1)
        cos_t = torch.cos(theta)
        sin_t = torch.sin(theta)
        one_minus = 1.0 - cos_t

        r00 = cos_t + wx * wx * one_minus
        r01 = wx * wy * one_minus - wz * sin_t
        r02 = wy * sin_t + wx * wz * one_minus
        r10 = wz * sin_t + wx * wy * one_minus
        r11 = cos_t + wy * wy * one_minus
        r12 = -wx * sin_t + wy * wz * one_minus
        r20 = -wy * sin_t + wx * wz * one_minus
        r21 = wx * sin_t + wy * wz * one_minus
        r22 = cos_t + wz * wz * one_minus
        rows = torch.cat([r00, r01, r02, r10, r11, r12, r20, r21, r22], dim=1)
        return rows.view(-1, 3, 3)

    def _taylor(vec):
        # First-order approximation I + [v]_x, accurate for tiny angles
        # and free of the division-by-zero hazard.
        rx, ry, rz = torch.chunk(vec, 3, dim=1)
        one = torch.ones_like(rx)
        rows = torch.cat([one, -rz, ry, rz, one, -rx, -ry, rx, one], dim=1)
        return rows.view(-1, 3, 3)

    # squared rotation angle via a batched dot product (adapted from
    # ceres/rotation.h)
    aa = torch.unsqueeze(angle_axis, dim=1)
    theta2 = torch.squeeze(torch.matmul(aa, aa.transpose(1, 2)), dim=1)

    exact = _rodrigues(angle_axis, theta2)
    approx = _taylor(angle_axis)

    # blend the two candidates with 0/1 weights per sample
    eps = 1e-6
    big_angle = (theta2 > eps).view(-1, 1, 1).to(theta2.device)
    w_exact = big_angle.type_as(theta2)
    w_approx = (big_angle == False).type_as(theta2)  # noqa

    batch_size = angle_axis.shape[0]
    rotation_matrix = torch.eye(3).to(angle_axis.device).type_as(angle_axis)
    rotation_matrix = rotation_matrix.view(1, 3, 3).repeat(batch_size, 1, 1)
    rotation_matrix[..., :3, :3] = w_exact * exact + w_approx * approx
    return rotation_matrix  # Nx3x3
193
+
194
+
195
def rotation_matrix_to_angle_axis(
        rotation_matrix) -> torch.Tensor:
    r"""Convert 3x3 rotation matrices to Rodrigues (angle-axis) vectors.

    Args:
        rotation_matrix (torch.Tensor): rotation matrices of shape
            :math:`(N, 3, 3)`.

    Returns:
        torch.Tensor: angle-axis vectors of shape :math:`(N, 3)`.

    Raises:
        TypeError: if the input is not a ``torch.Tensor``.
        ValueError: if the trailing dimensions are not ``(3, 3)``.
    """
    if not isinstance(rotation_matrix, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(rotation_matrix)}")

    if not rotation_matrix.shape[-2:] == (3, 3):
        raise ValueError(
            f"Input size must be a (*, 3, 3) tensor. Got {rotation_matrix.shape}")

    # route through quaternion space: matrix -> quaternion -> angle-axis
    return quaternion_to_angle_axis(rotation_matrix_to_quaternion(rotation_matrix))
219
+
220
+
221
def rotation_matrix_to_quaternion(
        rotation_matrix,
        eps: float = 1e-8) -> torch.Tensor:
    r"""Convert 3x3 rotation matrices to 4d quaternion vectors.

    The resulting quaternion has components in (x, y, z, w) order.

    Args:
        rotation_matrix (torch.Tensor): rotation matrices of shape
            :math:`(*, 3, 3)`.
        eps (float): small value guarding the square roots against slightly
            negative arguments caused by rounding. Default: 1e-8.

    Returns:
        torch.Tensor: quaternions of shape :math:`(*, 4)`.

    Raises:
        TypeError: if the input is not a ``torch.Tensor``.
        ValueError: if the trailing dimensions are not ``(3, 3)``.
    """
    if not isinstance(rotation_matrix, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(rotation_matrix)}")

    if not rotation_matrix.shape[-2:] == (3, 3):
        raise ValueError(
            f"Input size must be a (*, 3, 3) tensor. Got {rotation_matrix.shape}")

    def _safe_div(numerator,
                  denominator) -> torch.Tensor:
        # clamp the denominator to the smallest positive representable
        # value of the dtype so the division can never produce inf/nan
        tiny = torch.finfo(numerator.dtype).tiny  # type: ignore
        return numerator / torch.clamp(denominator, min=tiny)

    flat = rotation_matrix.view(*rotation_matrix.shape[:-2], 9)
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.chunk(
        flat, chunks=9, dim=-1)

    trace = m00 + m11 + m22

    # One candidate per Shepperd branch; torch.where selects the numerically
    # stable one per element below.
    def _from_trace():
        sq = torch.sqrt(trace + 1.0) * 2.  # sq = 4 * qw.
        qw = 0.25 * sq
        qx = _safe_div(m21 - m12, sq)
        qy = _safe_div(m02 - m20, sq)
        qz = _safe_div(m10 - m01, sq)
        return torch.cat([qx, qy, qz, qw], dim=-1)

    def _from_m00():
        sq = torch.sqrt(1.0 + m00 - m11 - m22 + eps) * 2.  # sq = 4 * qx.
        qw = _safe_div(m21 - m12, sq)
        qx = 0.25 * sq
        qy = _safe_div(m01 + m10, sq)
        qz = _safe_div(m02 + m20, sq)
        return torch.cat([qx, qy, qz, qw], dim=-1)

    def _from_m11():
        sq = torch.sqrt(1.0 + m11 - m00 - m22 + eps) * 2.  # sq = 4 * qy.
        qw = _safe_div(m02 - m20, sq)
        qx = _safe_div(m01 + m10, sq)
        qy = 0.25 * sq
        qz = _safe_div(m12 + m21, sq)
        return torch.cat([qx, qy, qz, qw], dim=-1)

    def _from_m22():
        sq = torch.sqrt(1.0 + m22 - m00 - m11 + eps) * 2.  # sq = 4 * qz.
        qw = _safe_div(m10 - m01, sq)
        qx = _safe_div(m02 + m20, sq)
        qy = _safe_div(m12 + m21, sq)
        qz = 0.25 * sq
        return torch.cat([qx, qy, qz, qw], dim=-1)

    inner = torch.where(m11 > m22, _from_m11(), _from_m22())
    off_trace = torch.where((m00 > m11) & (m00 > m22), _from_m00(), inner)
    quaternion = torch.where(trace > 0., _from_trace(), off_trace)
    return quaternion
299
+
300
+
301
def normalize_quaternion(quaternion,
                         eps: float = 1e-12) -> torch.Tensor:
    r"""Normalize a quaternion to unit length.

    The quaternion is expected in (x, y, z, w) order.

    Args:
        quaternion (torch.Tensor): quaternion(s) of shape :math:`(*, 4)`.
        eps (float): small value to avoid division by zero. Default: 1e-12.

    Returns:
        torch.Tensor: unit quaternion(s) of shape :math:`(*, 4)`.

    Raises:
        TypeError: if the input is not a ``torch.Tensor``.
        ValueError: if the last dimension is not 4.
    """
    if not isinstance(quaternion, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(quaternion)}")

    if not quaternion.shape[-1] == 4:
        raise ValueError(
            f"Input must be a tensor of shape (*, 4). Got {quaternion.shape}")
    # L2-normalize along the component axis
    return F.normalize(quaternion, p=2, dim=-1, eps=eps)
326
+
327
+
328
+ # based on:
329
+ # https://github.com/matthew-brett/transforms3d/blob/8965c48401d9e8e66b6a8c37c65f2fc200a076fa/transforms3d/quaternions.py#L101
330
+ # https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/geometry/transformation/rotation_matrix_3d.py#L247
331
+
332
def quaternion_to_rotation_matrix(quaternion) -> torch.Tensor:
    r"""Convert a quaternion to a rotation matrix.

    The quaternion is expected in (x, y, z, w) order and is normalized
    before the conversion.

    Args:
        quaternion (torch.Tensor): quaternion(s) of shape :math:`(*, 4)`.

    Returns:
        torch.Tensor: rotation matrices of shape :math:`(*, 3, 3)`; a
        single ``(4,)`` quaternion yields a single ``(3, 3)`` matrix.

    Raises:
        TypeError: if the input is not a ``torch.Tensor``.
        ValueError: if the last dimension is not 4.
    """
    if not isinstance(quaternion, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(quaternion)}")

    if not quaternion.shape[-1] == 4:
        raise ValueError(
            f"Input must be a tensor of shape (*, 4). Got {quaternion.shape}")

    # work on a unit quaternion and split its components
    x, y, z, w = torch.chunk(normalize_quaternion(quaternion), chunks=4, dim=-1)

    # standard quaternion-to-matrix products
    tx = 2.0 * x
    ty = 2.0 * y
    tz = 2.0 * z
    twx = tx * w
    twy = ty * w
    twz = tz * w
    txx = tx * x
    txy = ty * x
    txz = tz * x
    tyy = ty * y
    tyz = tz * y
    tzz = tz * z
    one = torch.tensor(1.)

    matrix = torch.stack([
        one - (tyy + tzz), txy - twz, txz + twy,
        txy + twz, one - (txx + tzz), tyz - twx,
        txz - twy, tyz + twx, one - (txx + tyy)
    ], dim=-1).view(-1, 3, 3)

    # drop the batch axis that view() introduced for a single quaternion
    if len(quaternion.shape) == 1:
        matrix = torch.squeeze(matrix, dim=0)
    return matrix
385
+
386
+
387
def quaternion_to_angle_axis(quaternion) -> torch.Tensor:
    """Convert quaternion vectors to angle-axis rotation vectors.

    The quaternion is expected in (x, y, z, w) order. Adapted from the
    ceres C++ library: ceres-solver/include/ceres/rotation.h.

    Args:
        quaternion (torch.Tensor): quaternions of shape :math:`(*, 4)`.

    Returns:
        torch.Tensor: angle-axis vectors of shape :math:`(*, 3)`.

    Raises:
        TypeError: if the input is not a ``torch.Tensor``.
        ValueError: if the last dimension is not 4.
    """
    if not torch.is_tensor(quaternion):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(quaternion)}")

    if not quaternion.shape[-1] == 4:
        raise ValueError(
            f"Input must be a tensor of shape Nx4 or 4. Got {quaternion.shape}")

    # vector part and its squared norm
    q1 = quaternion[..., 0]
    q2 = quaternion[..., 1]
    q3 = quaternion[..., 2]
    sin_sq_theta = q1 * q1 + q2 * q2 + q3 * q3

    sin_theta = torch.sqrt(sin_sq_theta)
    cos_theta = quaternion[..., 3]
    # flip the atan2 arguments when w < 0 so the angle stays in [0, pi]
    two_theta = 2.0 * torch.where(
        cos_theta < 0.0, torch.atan2(-sin_theta, -cos_theta),
        torch.atan2(sin_theta, cos_theta))

    # scale factor: theta/sin(theta/2)-style term, with the small-angle
    # limit 2.0 used where the vector part vanishes
    k_pos = two_theta / (sin_theta + 1e-8)
    k_neg = 2.0 * torch.ones_like(sin_theta)
    k = torch.where(sin_sq_theta > 0.0, k_pos, k_neg)

    angle_axis = torch.zeros_like(quaternion)[..., :3]
    angle_axis[..., 0] += q1 * k
    angle_axis[..., 1] += q2 * k
    angle_axis[..., 2] += q3 * k
    return angle_axis
431
+
432
+
433
def quaternion_log_to_exp(quaternion,
                          eps: float = 1e-8) -> torch.Tensor:
    r"""Apply the exponential map to a log quaternion.

    The result is a unit quaternion in (x, y, z, w) order.

    Args:
        quaternion (torch.Tensor): log quaternion(s) of shape :math:`(*, 3)`.
        eps (float): lower clamp for the norm to avoid division by zero.
            Default: 1e-8.

    Returns:
        torch.Tensor: exponential-map quaternion(s) of shape :math:`(*, 4)`.

    Raises:
        TypeError: if the input is not a ``torch.Tensor``.
        ValueError: if the last dimension is not 3.
    """
    if not isinstance(quaternion, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(quaternion)}")

    if not quaternion.shape[-1] == 3:
        raise ValueError(
            f"Input must be a tensor of shape (*, 3). Got {quaternion.shape}")

    # the norm of the log quaternion is the half rotation angle
    theta = torch.norm(
        quaternion, p=2, dim=-1, keepdim=True).clamp(min=eps)

    # exp map: vector part scaled by sinc(theta), scalar part cos(theta)
    vector = quaternion * torch.sin(theta) / theta
    scalar = torch.cos(theta)

    return torch.cat([vector, scalar], dim=-1)
467
+
468
+
469
def quaternion_exp_to_log(quaternion,
                          eps: float = 1e-8) -> torch.Tensor:
    r"""Apply the log map to a quaternion.

    The quaternion is expected in (x, y, z, w) order.

    Args:
        quaternion (torch.Tensor): quaternion(s) of shape :math:`(*, 4)`.
        eps (float): lower clamp for the vector norm to avoid division by
            zero. Default: 1e-8.

    Returns:
        torch.Tensor: log-map quaternion(s) of shape :math:`(*, 3)`.

    Raises:
        TypeError: if the input is not a ``torch.Tensor``.
        ValueError: if the last dimension is not 4.
    """
    if not isinstance(quaternion, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(quaternion)}")

    if not quaternion.shape[-1] == 4:
        raise ValueError(
            f"Input must be a tensor of shape (*, 4). Got {quaternion.shape}")

    # split into vector and scalar parts
    vector = quaternion[..., 0:3]
    scalar = quaternion[..., 3:4]

    norm_v = torch.norm(
        vector, p=2, dim=-1, keepdim=True).clamp(min=eps)

    # log map: unit axis scaled by the half angle; clamp guards acos's domain
    angle = torch.acos(torch.clamp(scalar, min=-1.0, max=1.0))
    return vector * angle / norm_v
503
+
504
+
505
+ # based on:
506
+ # https://github.com/facebookresearch/QuaterNet/blob/master/common/quaternion.py#L138
507
+
508
+
509
def angle_axis_to_quaternion(angle_axis) -> torch.Tensor:
    r"""Convert angle-axis rotation vectors to quaternions.

    The resulting quaternion has components in (x, y, z, w) order.
    Adapted from the ceres C++ library:
    ceres-solver/include/ceres/rotation.h.

    Args:
        angle_axis (torch.Tensor): angle-axis vectors of shape :math:`(*, 3)`.

    Returns:
        torch.Tensor: quaternions of shape :math:`(*, 4)`.

    Raises:
        TypeError: if the input is not a ``torch.Tensor``.
        ValueError: if the last dimension is not 3.
    """
    if not torch.is_tensor(angle_axis):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(angle_axis)}")

    if not angle_axis.shape[-1] == 3:
        raise ValueError(
            f"Input must be a tensor of shape Nx3 or 3. Got {angle_axis.shape}")

    # components and squared rotation angle
    a0 = angle_axis[..., 0:1]
    a1 = angle_axis[..., 1:2]
    a2 = angle_axis[..., 2:3]
    theta_squared = a0 * a0 + a1 * a1 + a2 * a2

    theta = torch.sqrt(theta_squared)
    half_theta = theta * 0.5

    nonzero = theta_squared > 0.0
    ones = torch.ones_like(half_theta)

    # sin(theta/2)/theta for real rotations, with its limit 0.5 as theta -> 0
    k_small = 0.5 * ones
    k_large = torch.sin(half_theta) / theta
    k = torch.where(nonzero, k_large, k_small)
    w = torch.where(nonzero, torch.cos(half_theta), ones)

    quaternion = torch.zeros_like(angle_axis)
    quaternion[..., 0:1] += a0 * k
    quaternion[..., 1:2] += a1 * k
    quaternion[..., 2:3] += a2 * k
    return torch.cat([quaternion, w], dim=-1)
554
+
555
+
556
+ # based on:
557
+ # https://github.com/ClementPinard/SfmLearner-Pytorch/blob/master/inverse_warp.py#L65-L71
558
+
559
def normalize_pixel_coordinates(
        pixel_coordinates,
        height: int,
        width: int,
        eps: float = 1e-8) -> torch.Tensor:
    r"""Normalize pixel coordinates between -1 and 1.

    Normalized, -1 if on extreme left, 1 if on extreme right (x = w-1).

    Args:
        pixel_coordinates (torch.Tensor): the grid with pixel coordinates.
          Shape can be :math:`(*, 2)`, last dim ordered (x, y).
        height (int): the maximum height in the y-axis.
        width (int): the maximum width in the x-axis.
        eps (float): safe division by zero. (default 1e-8).

    Return:
        torch.Tensor: the normalized pixel coordinates.
    """
    if pixel_coordinates.shape[-1] != 2:
        raise ValueError("Input pixel_coordinates must be of shape (*, 2). "
                         "Got {}".format(pixel_coordinates.shape))
    # per-axis sizes, matched to the input's device and dtype
    sizes = torch.stack([torch.tensor(width), torch.tensor(height)])
    sizes = sizes.to(pixel_coordinates.device).to(pixel_coordinates.dtype)

    # scale mapping [0, size-1] onto [-1, 1]; clamp guards size == 1
    scale = torch.tensor(2.) / (sizes - 1).clamp(eps)

    return scale * pixel_coordinates - 1
586
+
587
+
588
def denormalize_pixel_coordinates(
        pixel_coordinates,
        height: int,
        width: int,
        eps: float = 1e-8) -> torch.Tensor:
    r"""Denormalize pixel coordinates.

    The input is assumed to be -1 if on extreme left, 1 if on
    extreme right (x = w-1).

    Args:
        pixel_coordinates (torch.Tensor): the normalized grid coordinates.
          Shape can be :math:`(*, 2)`, last dim ordered (x, y).
        height (int): the maximum height in the y-axis.
        width (int): the maximum width in the x-axis.
        eps (float): safe division by zero. (default 1e-8).

    Return:
        torch.Tensor: the denormalized pixel coordinates.
    """
    if pixel_coordinates.shape[-1] != 2:
        raise ValueError("Input pixel_coordinates must be of shape (*, 2). "
                         "Got {}".format(pixel_coordinates.shape))
    # per-axis sizes, matched to the input's device and dtype
    sizes = torch.stack([torch.tensor(width), torch.tensor(height)])
    sizes = sizes.to(pixel_coordinates.device).to(pixel_coordinates.dtype)

    # inverse of the [-1, 1] normalization; clamp guards size == 1
    scale = torch.tensor(2.) / (sizes - 1).clamp(eps)

    return torch.tensor(1.) / scale * (pixel_coordinates + 1)
616
+
617
+
618
def normalize_pixel_coordinates3d(
        pixel_coordinates,
        depth: int,
        height: int,
        width: int,
        eps: float = 1e-8) -> torch.Tensor:
    r"""Normalize pixel coordinates between -1 and 1.

    Normalized, -1 if on extreme left, 1 if on extreme right (x = w-1).

    Args:
        pixel_coordinates (torch.Tensor): the grid with pixel coordinates.
          Shape can be :math:`(*, 3)`.
        depth (int): the maximum depth in the z-axis.
        height (int): the maximum height in the y-axis.
        width (int): the maximum width in the x-axis.
        eps (float): safe division by zero. (default 1e-8).

    Return:
        torch.Tensor: the normalized pixel coordinates.
    """
    if pixel_coordinates.shape[-1] != 3:
        raise ValueError("Input pixel_coordinates must be of shape (*, 3). "
                         "Got {}".format(pixel_coordinates.shape))
    # NOTE(review): sizes are stacked (depth, width, height), implying the
    # last-dim coordinate order (z, x, y) — differs from the 2D variant's
    # (x, y); confirm against callers before relying on axis meaning.
    sizes = torch.stack(
        [torch.tensor(depth), torch.tensor(width), torch.tensor(height)])
    sizes = sizes.to(pixel_coordinates.device).to(pixel_coordinates.dtype)

    # scale mapping [0, size-1] onto [-1, 1]; clamp guards size == 1
    scale = torch.tensor(2.) / (sizes - 1).clamp(eps)

    return scale * pixel_coordinates - 1
647
+
648
+
649
def denormalize_pixel_coordinates3d(
        pixel_coordinates,
        depth: int,
        height: int,
        width: int,
        eps: float = 1e-8) -> torch.Tensor:
    r"""Denormalize pixel coordinates.

    The input is assumed to be -1 if on extreme left, 1 if on
    extreme right (x = w-1).

    Args:
        pixel_coordinates (torch.Tensor): the normalized grid coordinates.
          Shape can be :math:`(*, 3)`.
        depth (int): the maximum depth in the z-axis.
        height (int): the maximum height in the y-axis.
        width (int): the maximum width in the x-axis.
        eps (float): safe division by zero. (default 1e-8).

    Return:
        torch.Tensor: the denormalized pixel coordinates.
    """
    if pixel_coordinates.shape[-1] != 3:
        raise ValueError("Input pixel_coordinates must be of shape (*, 3). "
                         "Got {}".format(pixel_coordinates.shape))
    # NOTE(review): sizes are stacked (depth, width, height), implying the
    # last-dim coordinate order (z, x, y) — differs from the 2D variant's
    # (x, y); confirm against callers before relying on axis meaning.
    sizes = torch.stack(
        [torch.tensor(depth), torch.tensor(width), torch.tensor(height)])
    sizes = sizes.to(pixel_coordinates.device).to(pixel_coordinates.dtype)

    # inverse of the [-1, 1] normalization; clamp guards size == 1
    scale = torch.tensor(2.) / (sizes - 1).clamp(eps)
    return torch.tensor(1.) / scale * (pixel_coordinates + 1)
678
+
679
def rotation_matrix_to_euler(R):
    r"""Convert a rotation matrix to XYZ Euler angles (radians).

    Goes through an (x, y, z, w) quaternion. The conversion below is
    degenerate when Rx = 90 degrees, which is usually known as the
    gimbal lock problem.
    """
    quat = rotation_matrix_to_quaternion(R)

    qx = quat[:, 0].unsqueeze(-1)
    qy = quat[:, 1].unsqueeze(-1)
    qz = quat[:, 2].unsqueeze(-1)
    qw = quat[:, 3].unsqueeze(-1)

    # standard quaternion -> Tait-Bryan extraction
    roll = torch.atan2(2.0 * (qw * qx + qy * qz),
                       1.0 - 2.0 * (qx * qx + qy * qy))
    pitch = torch.asin(2.0 * (qw * qy - qx * qz))
    yaw = torch.atan2(2.0 * (qw * qz + qx * qy),
                      1.0 - 2.0 * (qy * qy + qz * qz))

    return torch.cat([roll, pitch, yaw], dim=-1)
695
+
696
def euler_to_rotation_matrix(euler):
    r"""Convert XYZ Euler angles (radians) to a rotation matrix.

    The euler tensor is laid out as (Rx, Ry, Rz) along the last dim; an
    (x, y, z, w) quaternion is built first, then converted via
    quaternion_to_rotation_matrix.
    """
    rx = euler[:, 0].unsqueeze(-1)
    ry = euler[:, 1].unsqueeze(-1)
    rz = euler[:, 2].unsqueeze(-1)

    # half-angle sines/cosines per axis
    cx, sx = torch.cos(0.5 * rx), torch.sin(0.5 * rx)
    cy, sy = torch.cos(0.5 * ry), torch.sin(0.5 * ry)
    cz, sz = torch.cos(0.5 * rz), torch.sin(0.5 * rz)

    # quaternion components from the Euler half-angles
    qw = cx * cy * cz + sx * sy * sz
    qx = sx * cy * cz - cx * sy * sz
    qy = cx * sy * cz + sx * cy * sz
    qz = cx * cy * sz - sx * sy * cz

    quat = torch.cat([qx, qy, qz, qw], dim=-1)
    return quaternion_to_rotation_matrix(quat)
717
+
718
+ # R = torch.tensor([[[ 0.0266, 0.1109, 0.9935],
719
+ # [-0.0344, 0.9933, -0.1099],
720
+ # [-0.9991, -0.0313, 0.0302]]])
721
+ # print(rotation_matrix_to_quaternion(R))