spikezoo 0.1.1__py3-none-any.whl → 0.2__py3-none-any.whl

Files changed (192)
  1. spikezoo/__init__.py +13 -0
  2. spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
  3. spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
  4. spikezoo/archs/base/nets.py +34 -0
  5. spikezoo/archs/bsf/README.md +92 -0
  6. spikezoo/archs/bsf/datasets/datasets.py +328 -0
  7. spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
  8. spikezoo/archs/bsf/main.py +398 -0
  9. spikezoo/archs/bsf/metrics/psnr.py +22 -0
  10. spikezoo/archs/bsf/metrics/ssim.py +54 -0
  11. spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
  12. spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
  13. spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
  14. spikezoo/archs/bsf/models/bsf/align.py +154 -0
  15. spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
  16. spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
  17. spikezoo/archs/bsf/models/bsf/rep.py +44 -0
  18. spikezoo/archs/bsf/models/get_model.py +7 -0
  19. spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
  20. spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
  21. spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
  22. spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
  23. spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
  24. spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
  25. spikezoo/archs/bsf/requirements.txt +9 -0
  26. spikezoo/archs/bsf/test.py +16 -0
  27. spikezoo/archs/bsf/utils.py +154 -0
  28. spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
  29. spikezoo/archs/spikeclip/nets.py +40 -0
  30. spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
  31. spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
  32. spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
  33. spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
  34. spikezoo/archs/spikeformer/EvalResults/readme +1 -0
  35. spikezoo/archs/spikeformer/LICENSE +21 -0
  36. spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
  37. spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
  38. spikezoo/archs/spikeformer/Model/Loss.py +89 -0
  39. spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
  40. spikezoo/archs/spikeformer/Model/__init__.py +0 -0
  41. spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
  42. spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
  43. spikezoo/archs/spikeformer/README.md +30 -0
  44. spikezoo/archs/spikeformer/evaluate.py +87 -0
  45. spikezoo/archs/spikeformer/recon_real_data.py +97 -0
  46. spikezoo/archs/spikeformer/requirements.yml +95 -0
  47. spikezoo/archs/spikeformer/train.py +173 -0
  48. spikezoo/archs/spikeformer/utils.py +22 -0
  49. spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
  50. spikezoo/archs/spk2imgnet/.gitignore +150 -0
  51. spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
  52. spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
  53. spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
  54. spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
  55. spikezoo/archs/spk2imgnet/align_arch.py +159 -0
  56. spikezoo/archs/spk2imgnet/dataset.py +144 -0
  57. spikezoo/archs/spk2imgnet/nets.py +230 -0
  58. spikezoo/archs/spk2imgnet/readme.md +86 -0
  59. spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
  60. spikezoo/archs/spk2imgnet/train.py +189 -0
  61. spikezoo/archs/spk2imgnet/utils.py +64 -0
  62. spikezoo/archs/ssir/README.md +87 -0
  63. spikezoo/archs/ssir/configs/SSIR.yml +37 -0
  64. spikezoo/archs/ssir/configs/yml_parser.py +78 -0
  65. spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
  66. spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
  67. spikezoo/archs/ssir/losses.py +21 -0
  68. spikezoo/archs/ssir/main.py +326 -0
  69. spikezoo/archs/ssir/metrics/psnr.py +22 -0
  70. spikezoo/archs/ssir/metrics/ssim.py +54 -0
  71. spikezoo/archs/ssir/models/Vgg19.py +42 -0
  72. spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
  73. spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
  74. spikezoo/archs/ssir/models/layers.py +110 -0
  75. spikezoo/archs/ssir/models/networks.py +61 -0
  76. spikezoo/archs/ssir/requirements.txt +8 -0
  77. spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
  78. spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
  79. spikezoo/archs/ssir/test.py +3 -0
  80. spikezoo/archs/ssir/utils.py +154 -0
  81. spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
  82. spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
  83. spikezoo/archs/ssml/cbam.py +224 -0
  84. spikezoo/archs/ssml/model.py +290 -0
  85. spikezoo/archs/ssml/res.png +0 -0
  86. spikezoo/archs/ssml/test.py +67 -0
  87. spikezoo/archs/stir/.git-credentials +0 -0
  88. spikezoo/archs/stir/README.md +65 -0
  89. spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
  90. spikezoo/archs/stir/configs/STIR.yml +37 -0
  91. spikezoo/archs/stir/configs/utils.py +155 -0
  92. spikezoo/archs/stir/configs/yml_parser.py +78 -0
  93. spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
  94. spikezoo/archs/stir/datasets/ds_utils.py +66 -0
  95. spikezoo/archs/stir/eval_SREDS.sh +5 -0
  96. spikezoo/archs/stir/main.py +397 -0
  97. spikezoo/archs/stir/metrics/losses.py +219 -0
  98. spikezoo/archs/stir/metrics/psnr.py +22 -0
  99. spikezoo/archs/stir/metrics/ssim.py +54 -0
  100. spikezoo/archs/stir/models/Vgg19.py +42 -0
  101. spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
  102. spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
  103. spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
  104. spikezoo/archs/stir/models/networks_STIR.py +361 -0
  105. spikezoo/archs/stir/models/submodules.py +86 -0
  106. spikezoo/archs/stir/models/transformer_new.py +151 -0
  107. spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
  108. spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
  109. spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
  110. spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
  111. spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
  112. spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
  113. spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
  114. spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
  115. spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
  116. spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
  117. spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
  118. spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
  119. spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
  120. spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
  121. spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
  122. spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
  123. spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
  124. spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
  125. spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
  126. spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
  127. spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
  128. spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
  129. spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
  130. spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
  131. spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
  132. spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
  133. spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
  134. spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
  135. spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
  136. spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
  137. spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
  138. spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
  139. spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
  140. spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
  141. spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
  142. spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
  143. spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
  144. spikezoo/archs/stir/package_core/setup.py +5 -0
  145. spikezoo/archs/stir/requirements.txt +12 -0
  146. spikezoo/archs/stir/train_STIR.sh +9 -0
  147. spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
  148. spikezoo/archs/tfi/nets.py +43 -0
  149. spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
  150. spikezoo/archs/tfp/nets.py +13 -0
  151. spikezoo/archs/wgse/README.md +64 -0
  152. spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
  153. spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
  154. spikezoo/archs/wgse/dataset.py +59 -0
  155. spikezoo/archs/wgse/demo.png +0 -0
  156. spikezoo/archs/wgse/demo.py +83 -0
  157. spikezoo/archs/wgse/dwtnets.py +145 -0
  158. spikezoo/archs/wgse/eval.py +133 -0
  159. spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
  160. spikezoo/archs/wgse/submodules.py +68 -0
  161. spikezoo/archs/wgse/train.py +261 -0
  162. spikezoo/archs/wgse/transform.py +139 -0
  163. spikezoo/archs/wgse/utils.py +128 -0
  164. spikezoo/archs/wgse/weights/demo.png +0 -0
  165. spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
  166. spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
  167. spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
  168. spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
  169. spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
  170. spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
  171. spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
  172. spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
  173. spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
  174. spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
  175. spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
  176. spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
  177. spikezoo/datasets/base_dataset.py +2 -3
  178. spikezoo/metrics/__init__.py +1 -1
  179. spikezoo/models/base_model.py +1 -3
  180. spikezoo/pipeline/base_pipeline.py +7 -5
  181. spikezoo/pipeline/train_pipeline.py +1 -1
  182. spikezoo/utils/other_utils.py +16 -6
  183. spikezoo/utils/spike_utils.py +33 -29
  184. spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
  185. spikezoo-0.2.dist-info/METADATA +163 -0
  186. spikezoo-0.2.dist-info/RECORD +211 -0
  187. spikezoo/models/spcsnet_model.py +0 -19
  188. spikezoo-0.1.1.dist-info/METADATA +0 -39
  189. spikezoo-0.1.1.dist-info/RECORD +0 -36
  190. {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/LICENSE.txt +0 -0
  191. {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/WHEEL +0 -0
  192. {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,458 @@
+ """
+ A collection of geometric transformation operations
+ @author: Zhaoyang Lv
+ @Date: March, 2019
+ """
+
+ from __future__ import absolute_import
+ from __future__ import division
+ from __future__ import print_function
+ from __future__ import unicode_literals
+
+ import torch
+ import torch.nn as nn
+ from torch import sin, cos, atan2, acos
+
+ _NEXT_AXIS = [1, 2, 0, 1]
+
+ # map axes strings to/from tuples of inner axis, parity, repetition, frame
+ _AXES2TUPLE = {
+     'sxyz': (0, 0, 0, 0), 'sxyx': (0, 0, 1, 0), 'sxzy': (0, 1, 0, 0),
+     'sxzx': (0, 1, 1, 0), 'syzx': (1, 0, 0, 0), 'syzy': (1, 0, 1, 0),
+     'syxz': (1, 1, 0, 0), 'syxy': (1, 1, 1, 0), 'szxy': (2, 0, 0, 0),
+     'szxz': (2, 0, 1, 0), 'szyx': (2, 1, 0, 0), 'szyz': (2, 1, 1, 0),
+     'rzyx': (0, 0, 0, 1), 'rxyx': (0, 0, 1, 1), 'ryzx': (0, 1, 0, 1),
+     'rxzx': (0, 1, 1, 1), 'rxzy': (1, 0, 0, 1), 'ryzy': (1, 0, 1, 1),
+     'rzxy': (1, 1, 0, 1), 'ryxy': (1, 1, 1, 1), 'ryxz': (2, 0, 0, 1),
+     'rzxz': (2, 0, 1, 1), 'rxyz': (2, 1, 0, 1), 'rzyz': (2, 1, 1, 1)}
+
+ _TUPLE2AXES = dict((v, k) for k, v in _AXES2TUPLE.items())
+
+ def meshgrid(H, W, B=None, is_cuda=False):
+     """ torch version of the numpy meshgrid function
+     :input
+     :param height
+     :param width
+     :param batch size
+     :param initialize a cuda tensor if true
+     -------
+     :return
+     :param meshgrid in column
+     :param meshgrid in row
+     """
+     u = torch.arange(0, W)
+     v = torch.arange(0, H)
+
+     if is_cuda:
+         u, v = u.cuda(), v.cuda()
+
+     u = u.repeat(H, 1).view(1, H, W)
+     v = v.repeat(W, 1).t_().view(1, H, W)
+
+     if B is not None:
+         u, v = u.repeat(B, 1, 1, 1), v.repeat(B, 1, 1, 1)
+     return u, v
+
+ def generate_xy_grid(B, H, W, K):
+     """ Generate a batch of image grids from image space to world space
+         px = (u - cx) / fx
+         py = (v - cy) / fy
+     function tested in 'test_geometry.py'
+     :input
+     :param batch size
+     :param height
+     :param width
+     :param camera intrinsic array [fx,fy,cx,cy]
+     ---------
+     :return
+     :param px the x coordinate map Bx1xHxW
+     :param py the y coordinate map Bx1xHxW
+     """
+     fx, fy, cx, cy = K.split(1, dim=1)
+     uv_grid = meshgrid(H, W, B)
+     u_grid, v_grid = [uv.type_as(cx) for uv in uv_grid]
+     px = ((u_grid.view(B, -1) - cx) / fx).view(B, 1, H, W)
+     py = ((v_grid.view(B, -1) - cy) / fy).view(B, 1, H, W)
+     return px, py
+
+ def batch_inverse_Rt(R, t):
+     """ The inverse of the R, t: [R' | -R't]
+     function tested in 'test_geometry.py'
+     :input
+     :param rotation Bx3x3
+     :param translation Bx3
+     ----------
+     :return
+     :param rotation inverse Bx3x3
+     :param translation inverse Bx3
+     """
+     R_t = R.transpose(1, 2)
+     t_inv = -torch.bmm(R_t, t.contiguous().view(-1, 3, 1))
+
+     return R_t, t_inv.view(-1, 3)
+
+ def batch_Rt_compose(d_R, d_t, R0, t0):
+     """ Compose operator of R, t: [d_R*R | d_R*t + d_t]
+     We use the left-multiplication rule here.
+     function tested in 'test_geometry.py'
+
+     :input
+     :param rotation incremental Bx3x3
+     :param translation incremental Bx3
+     :param initial rotation Bx3x3
+     :param initial translation Bx3
+     ----------
+     :return
+     :param composed rotation Bx3x3
+     :param composed translation Bx3
+     """
+     R1 = d_R.bmm(R0)
+     t1 = d_R.bmm(t0.view(-1, 3, 1)) + d_t.view(-1, 3, 1)
+     return R1, t1.view(-1, 3)
+
+ def batch_Rt_between(R0, t0, R1, t1):
+     """ Between operator of R, t: the transform from T_0=[R0, t0] to T_1=[R1, t1],
+     which is T_1 \compose T^{-1}_0
+     function tested in 'test_geometry.py'
+
+     :input
+     :param rotation of source Bx3x3
+     :param translation of source Bx3
+     :param rotation of target Bx3x3
+     :param translation of target Bx3
+     ----------
+     :return
+     :param incremental rotation Bx3x3
+     :param incremental translation Bx3
+     """
+     R0t = R0.transpose(1, 2)
+     dR = R1.bmm(R0t)
+     dt = t1.view(-1, 3) - dR.bmm(t0.view(-1, 3, 1)).view(-1, 3)
+     return dR, dt
+
+ def batch_skew(w):
+     """ Generate a batch of skew-symmetric matrices.
+     function tested in 'test_geometry.py'
+     :input
+     :param skew symmetric matrix entries Bx3
+     ---------
+     :return
+     :param the skew-symmetric matrix Bx3x3
+     """
+     B, D = w.size()
+     assert(D == 3)
+     o = torch.zeros(B).type_as(w)
+     w0, w1, w2 = w[:, 0], w[:, 1], w[:, 2]
+     return torch.stack((o, -w2, w1, w2, o, -w0, -w1, w0, o), 1).view(B, 3, 3)
+
+ def batch_twist2Mat(twist):
+     """ The exponential map from so3 to SO3
+     Calculate the rotation matrix using Rodrigues' Rotation Formula
+     http://electroncastle.com/wp/?p=39
+     or Ethan Eade's Lie group note:
+     http://ethaneade.com/lie.pdf equations (13)-(15)
+     @todo: may rename the interface to batch_so3expmap(twist)
+     function tested against the cv2.Rodrigues implementation in 'test_geometry.py'
+     :input
+     :param twist/axis angle Bx3 \in \so3 space
+     ----------
+     :return
+     :param rotation matrix Bx3x3 \in \SO3 space
+     """
+     B = twist.size()[0]
+     theta = twist.norm(p=2, dim=1).view(B, 1)
+     w_so3 = twist / theta.expand(B, 3)
+     W = batch_skew(w_so3)
+     return torch.eye(3).repeat(B, 1, 1).type_as(W) \
+         + W * sin(theta.view(B, 1, 1)) \
+         + W.bmm(W) * (1 - cos(theta).view(B, 1, 1))
+
+ def batch_mat2angle(R):
+     """ Calculate the axis angle (twist) from a batch of rotation matrices
+     Ethan Eade's Lie group note:
+     http://ethaneade.com/lie.pdf equation (17)
+     function tested in 'test_geometry.py'
+     :input
+     :param rotation matrix Bx3x3 \in \SO3 space
+     --------
+     :return
+     :param the axis angle B
+     """
+     R1 = [torch.trace(R[i]) for i in range(R.size()[0])]
+     R_trace = torch.stack(R1)
+     # clamp if the angle is too large (breaks the small-angle assumption)
+     # @todo: not sure whether it is absolutely necessary in training.
+     angle = acos(((R_trace - 1) / 2).clamp(-1, 1))
+     return angle
+
+ def batch_mat2twist(R):
+     """ The log map from SO3 to so3
+     Calculate the twist vector from a rotation matrix
+     Ethan Eade's Lie group note:
+     http://ethaneade.com/lie.pdf equation (18)
+     @todo: may rename the interface to batch_so3logmap(R)
+     function tested in 'test_geometry.py'
+     @note: it currently does not consider extremely small values.
+     If you use it as a training loss, you may run into problems
+     :input
+     :param rotation matrix Bx3x3 \in \SO3 space
+     --------
+     :return
+     :param the twist vector Bx3 \in \so3 space
+     """
+     B = R.size()[0]
+
+     R1 = [torch.trace(R[i]) for i in range(R.size()[0])]
+     tr = torch.stack(R1)
+     theta = acos(((tr - 1) / 2).clamp(-1, 1))
+
+     r11, r12, r13, r21, r22, r23, r31, r32, r33 = torch.split(R.view(B, -1), 1, dim=1)
+     res = torch.cat([r32 - r23, r13 - r31, r21 - r12], dim=1)
+
+     magnitude = (0.5 * theta / sin(theta))
+
+     return magnitude.view(B, 1) * res
+
+ def batch_warp_inverse_depth(p_x, p_y, p_invD, pose, K):
+     """ Compute the warping grid w.r.t. the SE3 transform given the inverse depth
+     :input
+     :param p_x the x coordinate map
+     :param p_y the y coordinate map
+     :param p_invD the inverse depth
+     :param pose the 3D transform in SE3
+     :param K the intrinsics
+     --------
+     :return
+     :param projected u coordinate in image space Bx1xHxW
+     :param projected v coordinate in image space Bx1xHxW
+     :param projected inverse depth Bx1xHxW
+     """
+     [R, t] = pose
+     B, _, H, W = p_x.shape
+
+     I = torch.ones((B, 1, H, W)).type_as(p_invD)
+     x_y_1 = torch.cat((p_x, p_y, I), dim=1)
+
+     warped = torch.bmm(R, x_y_1.view(B, 3, H * W)) + \
+         t.view(B, 3, 1).expand(B, 3, H * W) * p_invD.view(B, 1, H * W).expand(B, 3, H * W)
+
+     x_, y_, s_ = torch.split(warped, 1, dim=1)
+     fx, fy, cx, cy = torch.split(K, 1, dim=1)
+
+     u_ = (x_ / s_).view(B, -1) * fx + cx
+     v_ = (y_ / s_).view(B, -1) * fy + cy
+
+     inv_z_ = p_invD / s_.view(B, 1, H, W)
+
+     return u_.view(B, 1, H, W), v_.view(B, 1, H, W), inv_z_
+
+ def batch_warp_affine(pu, pv, affine):
+     # A = affine[:, :, :2]
+     # t = affine[:, :, 2]
+     B, _, H, W = pu.shape
+     ones = torch.ones(pu.shape).type_as(pu)
+     uv = torch.cat((pu, pv, ones), dim=1)
+     uv = torch.bmm(affine, uv.view(B, 3, -1))  # + t.view(B,2,1)
+     return uv[:, 0].view(B, 1, H, W), uv[:, 1].view(B, 1, H, W)
+
+ def check_occ(inv_z_buffer, inv_z_ref, u, v, thres=1e-1):
+     """ z-buffering check of occlusion
+     :param inverse depth of the target frame
+     :param inverse depth of the reference frame
+     """
+     B, _, H, W = inv_z_buffer.shape
+
+     inv_z_warped = warp_features(inv_z_ref, u, v)
+     inlier = (inv_z_buffer > inv_z_warped - thres)
+
+     inviews = inlier & (u > 0) & (u < W) & \
+         (v > 0) & (v < H)
+
+     return 1 - inviews
+
+ def warp_features(F, u, v):
+     """
+     Warp the feature map (F) w.r.t. the grid (u, v)
+     """
+     B, C, H, W = F.shape
+
+     u_norm = u / ((W - 1) / 2) - 1
+     v_norm = v / ((H - 1) / 2) - 1
+     uv_grid = torch.cat((u_norm.view(B, H, W, 1), v_norm.view(B, H, W, 1)), dim=3)
+     F_warped = nn.functional.grid_sample(F, uv_grid,
+                                          mode='bilinear', padding_mode='border')
+     return F_warped
+
+ def batch_transform_xyz(xyz_tensor, R, t, get_Jacobian=True):
+     '''
+     transform the point cloud w.r.t. the transformation matrix
+     :param xyz_tensor: B * 3 * H * W
+     :param R: rotation matrix B * 3 * 3
+     :param t: translation vector B * 3
+     '''
+     B, C, H, W = xyz_tensor.size()
+     t_tensor = t.contiguous().view(B, 3, 1).repeat(1, 1, H * W)
+     p_tensor = xyz_tensor.contiguous().view(B, C, H * W)
+     # the transformation process is simply:
+     # p' = t + R*p
+     xyz_t_tensor = torch.baddbmm(t_tensor, R, p_tensor)
+
+     if get_Jacobian:
+         # return both the transformed tensor and its Jacobian matrix
+         J_r = R.bmm(batch_skew_symmetric_matrix(-1 * p_tensor.permute(0, 2, 1)))
+         J_t = -1 * torch.eye(3).view(1, 3, 3).expand(B, 3, 3)
+         J = torch.cat((J_r, J_t), 1)
+         return xyz_t_tensor.view(B, C, H, W), J
+     else:
+         return xyz_t_tensor.view(B, C, H, W)
+
+ def flow_from_rigid_transform(depth, extrinsic, intrinsic):
+     """
+     Get the optical flow induced by rigid transform [R,t] and depth
+     """
+     [R, t] = extrinsic
+     [fx, fy, cx, cy] = intrinsic
+
+ def batch_project(xyz_tensor, K):
+     """ Project a point cloud into pixels (u,v) given intrinsic K
+         [u'; v'; w] = [K] [x; y; z]
+         u = u' / w; v = v' / w
+     :param the xyz points
+     :param calibration is a torch array composed of [fx, fy, cx, cy]
+     -------
+     :return u, v grid tensor in image coordinates
+     (tested through inverse projection)
+     """
+     B, _, H, W = xyz_tensor.size()
+     batch_K = K.expand(H, W, B, 4).permute(2, 3, 0, 1)
+
+     x, y, z = torch.split(xyz_tensor, 1, dim=1)
+     fx, fy, cx, cy = torch.split(batch_K, 1, dim=1)
+
+     u = fx * x / z + cx
+     v = fy * y / z + cy
+     return torch.cat((u, v), dim=1)
+
+ def batch_inverse_project(depth, K):
+     """ Inverse-project pixels (u,v) to a point cloud given the intrinsics
+     :param depth dim B*H*W
+     :param calibration is a torch array composed of [fx, fy, cx, cy]
+     :param color (optional) dim B*3*H*W
+     -------
+     :return xyz tensor (batch of point clouds)
+     (tested through projection)
+     """
+     if depth.dim() == 3:
+         B, H, W = depth.size()
+     else:
+         B, _, H, W = depth.size()
+
+     x, y = generate_xy_grid(B, H, W, K)
+     z = depth.view(B, 1, H, W)
+     return torch.cat((x * z, y * z, z), dim=1)
+
+ def batch_euler2mat(ai, aj, ak, axes='sxyz'):
+     """ A torch implementation of euler2mat from transforms3d:
+     https://github.com/matthew-brett/transforms3d/blob/master/transforms3d/euler.py
+     :param ai : First rotation angle (according to `axes`).
+     :param aj : Second rotation angle (according to `axes`).
+     :param ak : Third rotation angle (according to `axes`).
+     :param axes : Axis specification; one of 24 axis sequences as string or encoded tuple - e.g. ``sxyz`` (the default).
+     -------
+     :return rotation matrix, array-like shape (B, 3, 3)
+     Tested w.r.t. the transforms3d.euler module
+     """
+     B = ai.size()[0]
+
+     try:
+         firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
+     except (AttributeError, KeyError):
+         _TUPLE2AXES[axes]  # validation
+         firstaxis, parity, repetition, frame = axes
+
+     i = firstaxis
+     j = _NEXT_AXIS[i + parity]
+     k = _NEXT_AXIS[i - parity + 1]
+     order = [i, j, k]
+
+     if frame:
+         ai, ak = ak, ai
+     if parity:
+         ai, aj, ak = -ai, -aj, -ak
+
+     si, sj, sk = sin(ai), sin(aj), sin(ak)
+     ci, cj, ck = cos(ai), cos(aj), cos(ak)
+     cc, cs = ci * ck, ci * sk
+     sc, ss = si * ck, si * sk
+
+     # M = torch.zeros(B, 3, 3).cuda()
+     if repetition:
+         c_i = [cj, sj * si, sj * ci]
+         c_j = [sj * sk, -cj * ss + cc, -cj * cs - sc]
+         c_k = [-sj * ck, cj * sc + cs, cj * cc - ss]
+     else:
+         c_i = [cj * ck, sj * sc - cs, sj * cc + ss]
+         c_j = [cj * sk, sj * ss + cc, sj * cs - sc]
+         c_k = [-sj, cj * si, cj * ci]
+
+     def permute(X):  # sort X w.r.t. the axis indices
+         return [x for (y, x) in sorted(zip(order, X))]
+
+     c_i = permute(c_i)
+     c_j = permute(c_j)
+     c_k = permute(c_k)
+
+     r = [torch.stack(c_i, 1),
+          torch.stack(c_j, 1),
+          torch.stack(c_k, 1)]
+     r = permute(r)
+
+     return torch.stack(r, 1)
+
+ def batch_mat2euler(M, axes='sxyz'):
+     """ A torch implementation of mat2euler from transforms3d:
+     https://github.com/matthew-brett/transforms3d/blob/master/transforms3d/euler.py
+     :param array-like shape (3, 3) or (4, 4). Rotation matrix or affine.
+     :param Axis specification; one of 24 axis sequences as string or encoded tuple - e.g. ``sxyz`` (the default).
+     --------
+     :returns
+     :param ai : First rotation angle (according to `axes`).
+     :param aj : Second rotation angle (according to `axes`).
+     :param ak : Third rotation angle (according to `axes`).
+     """
+     try:
+         firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
+     except (AttributeError, KeyError):
+         _TUPLE2AXES[axes]  # validation
+         firstaxis, parity, repetition, frame = axes
+
+     i = firstaxis
+     j = _NEXT_AXIS[i + parity]
+     k = _NEXT_AXIS[i - parity + 1]
+
+     if repetition:
+         sy = torch.sqrt(M[:, i, j]**2 + M[:, i, k]**2)
+         # A lazy way to cope with batch data. Can be more efficient
+         mask = ~(sy > 1e-8)
+         ax = atan2(M[:, i, j], M[:, i, k])
+         ay = atan2(sy, M[:, i, i])
+         az = atan2(M[:, j, i], -M[:, k, i])
+         if mask.sum() > 0:
+             ax[mask] = atan2(-M[:, j, k][mask], M[:, j, j][mask])
+             ay[mask] = atan2(sy[mask], M[:, i, i][mask])
+             az[mask] = 0.0
+     else:
+         cy = torch.sqrt(M[:, i, i]**2 + M[:, j, i]**2)
+         mask = ~(cy > 1e-8)
+         ax = atan2(M[:, k, j], M[:, k, k])
+         ay = atan2(-M[:, k, i], cy)
+         az = atan2(M[:, j, i], M[:, i, i])
+         if mask.sum() > 0:
+             ax[mask] = atan2(-M[:, j, k][mask], M[:, j, j][mask])
+             ay[mask] = atan2(-M[:, k, i][mask], cy[mask])
+             az[mask] = 0.0
+
+     if parity:
+         ax, ay, az = -ax, -ay, -az
+     if frame:
+         ax, az = az, ax
+     return ax, ay, az
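
The geometry helpers above come in exp/log pairs: batch_twist2Mat and batch_mat2twist map between so3 and SO3, and batch_euler2mat / batch_mat2euler port the transforms3d Euler conversions to torch. Below is a minimal round-trip sanity check, as a sketch: it assumes the package_core package bundled above has been installed (via spikezoo/archs/stir/package_core/setup.py) so that this module imports as package_core.geometry.

import torch
from package_core.geometry import (batch_twist2Mat, batch_mat2twist,
                                   batch_euler2mat, batch_mat2euler)

# Random axis-angle vectors with angles kept inside (0, pi), away from the
# theta -> 0 singularity that batch_mat2twist's docstring warns about.
B = 4
w = torch.randn(B, 3)
w = w / w.norm(dim=1, keepdim=True) * torch.empty(B, 1).uniform_(0.1, 3.0)

R = batch_twist2Mat(w)        # exp map: so3 -> SO3
w_back = batch_mat2twist(R)   # log map: SO3 -> so3
print(torch.allclose(w, w_back, atol=1e-4))   # expect True

# Euler round trip with the default 'sxyz' convention; aj stays within
# (-pi/2, pi/2) so the non-degenerate branch of batch_mat2euler applies.
ai, aj, ak = [torch.empty(B).uniform_(-1.0, 1.0) for _ in range(3)]
M = batch_euler2mat(ai, aj, ak)
ai2, aj2, ak2 = batch_mat2euler(M)
print(torch.allclose(ai, ai2, atol=1e-5),
      torch.allclose(aj, aj2, atol=1e-5),
      torch.allclose(ak, ak2, atol=1e-5))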
@@ -0,0 +1,183 @@
+ import cv2
+ import torch
+ import numpy as np
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ def white_balance(img):
+     img = (img * 255.).astype(np.uint8)
+     img = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
+     avg_a = np.average(img[:, :, 1])
+     avg_b = np.average(img[:, :, 2])
+     img[:, :, 1] = img[:, :, 1] - ((avg_a - 128) * (img[:, :, 0] / 255.0) * 1.1)
+     img[:, :, 2] = img[:, :, 2] - ((avg_b - 128) * (img[:, :, 0] / 255.0) * 1.1)
+     img = cv2.cvtColor(img, cv2.COLOR_LAB2BGR)
+     img = img.astype(np.float64) / 255.  # np.float was removed in NumPy 1.24; use np.float64
+     return img
+
+ def warp_image_flow(ref_image, flow):
+     [B, _, H, W] = ref_image.size()
+
+     # mesh grid
+     xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
+     yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
+     xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
+     yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
+     grid = torch.cat((xx, yy), 1).float()
+
+     if ref_image.is_cuda:
+         grid = grid.cuda()
+
+     flow_f = flow + grid
+     flow_fx = flow_f[:, 0, :, :]
+     flow_fy = flow_f[:, 1, :, :]
+
+     with torch.no_grad():
+         mask_x = ~((flow_fx < 0) | (flow_fx > (W - 1)))
+         mask_y = ~((flow_fy < 0) | (flow_fy > (H - 1)))
+         mask = mask_x & mask_y
+         mask = mask.unsqueeze(1)
+
+     flow_fx = flow_fx / float(W) * 2. - 1.
+     flow_fy = flow_fy / float(H) * 2. - 1.
+
+     flow_fxy = torch.stack([flow_fx, flow_fy], dim=-1)
+     img = torch.nn.functional.grid_sample(ref_image, flow_fxy, padding_mode='zeros')
+     return img, mask
+
+ def warp_image(depth_cur, T_cur2ref, K, img_ref, crop_tl_h=0, crop_tl_w=0):
+     B, _, H, W = depth_cur.size()
+     fx, fy, cx, cy = K[:, 0, 0], K[:, 1, 1], K[:, 0, 2], K[:, 1, 2]
+     fx = fx.unsqueeze(-1)
+     fy = fy.unsqueeze(-1)
+     cx = cx.unsqueeze(-1)
+     cy = cy.unsqueeze(-1)
+
+     x_ref = torch.arange(0, W, 1).float().cuda() + crop_tl_w
+     y_ref = torch.arange(0, H, 1).float().cuda() + crop_tl_h
+
+     x_ref = x_ref.unsqueeze(0).repeat(B, 1)
+     y_ref = y_ref.unsqueeze(0).repeat(B, 1)
+
+     x_ref = (x_ref - cx) / fx
+     y_ref = (y_ref - cy) / fy
+
+     x_ref = x_ref.unsqueeze(1)
+     y_ref = y_ref.unsqueeze(-1)
+
+     xx_ref = x_ref.repeat(1, H, 1).unsqueeze(1)
+     yy_ref = y_ref.repeat(1, 1, W).unsqueeze(1)
+     ones = torch.ones_like(xx_ref)
+
+     p3d_ref = torch.cat([xx_ref, yy_ref, ones], dim=1) * depth_cur
+     ones = torch.ones_like(depth_cur)
+
+     p4d_ref = torch.cat([p3d_ref, ones], dim=1)
+     p4d_ref = p4d_ref.view(B, 4, -1)
+
+     p3d_cur = T_cur2ref.bmm(p4d_ref)
+     p3d_cur = p3d_cur / (p3d_cur[:, 2, :].unsqueeze(dim=1) + 1e-8)
+     p2d_cur = K.clone().bmm(p3d_cur)[:, :2, :]
+
+     p2d_cur[:, 0, :] = p2d_cur[:, 0, :] - crop_tl_w
+     p2d_cur[:, 1, :] = p2d_cur[:, 1, :] - crop_tl_h
+
+     # normalize
+     p2d_cur[:, 0, :] = 2.0 * (p2d_cur[:, 0, :] - W * 0.5 + 0.5) / (W - 1.)
+     p2d_cur[:, 1, :] = 2.0 * (p2d_cur[:, 1, :] - H * 0.5 + 0.5) / (H - 1.)
+
+     p2d_cur = p2d_cur.permute(0, 2, 1)
+     p2d_cur = p2d_cur.view(B, H, W, 2)
+
+     with torch.no_grad():
+         mask_x = ~((p2d_cur[:, :, :, 0] < -1.) | (p2d_cur[:, :, :, 0] > 1.))
+         mask_y = ~((p2d_cur[:, :, :, 1] < -1.) | (p2d_cur[:, :, :, 1] > 1.))
+         mask = mask_x & mask_y
+         mask = mask.unsqueeze(1)
+
+     syn_ref_image = torch.nn.functional.grid_sample(img_ref, p2d_cur, padding_mode='zeros')
+     return syn_ref_image, mask.float()
+
+ class Grid_gradient_central_diff():
+     def __init__(self, nc, padding=True, diagonal=False):
+         self.conv_x = nn.Conv2d(nc, nc, kernel_size=2, stride=1, bias=False)
+         self.conv_y = nn.Conv2d(nc, nc, kernel_size=2, stride=1, bias=False)
+         self.conv_xy = None
+         if diagonal:
+             self.conv_xy = nn.Conv2d(nc, nc, kernel_size=2, stride=1, bias=False)
+
+         self.padding = None
+         if padding:
+             self.padding = nn.ReplicationPad2d([0, 1, 0, 1])
+
+         fx = torch.zeros(nc, nc, 2, 2).float().cuda()
+         fy = torch.zeros(nc, nc, 2, 2).float().cuda()
+         if diagonal:
+             fxy = torch.zeros(nc, nc, 2, 2).float().cuda()
+
+         fx_ = torch.tensor([[1, -1], [0, 0]]).cuda()
+         fy_ = torch.tensor([[1, 0], [-1, 0]]).cuda()
+         if diagonal:
+             fxy_ = torch.tensor([[1, 0], [0, -1]]).cuda()
+
+         for i in range(nc):
+             fx[i, i, :, :] = fx_
+             fy[i, i, :, :] = fy_
+             if diagonal:
+                 fxy[i, i, :, :] = fxy_
+
+         self.conv_x.weight = nn.Parameter(fx)
+         self.conv_y.weight = nn.Parameter(fy)
+         if diagonal:
+             self.conv_xy.weight = nn.Parameter(fxy)
+
+     def __call__(self, grid_2d):
+         _image = grid_2d
+         if self.padding is not None:
+             _image = self.padding(_image)
+         dx = self.conv_x(_image)
+         dy = self.conv_y(_image)
+
+         if self.conv_xy is not None:
+             dxy = self.conv_xy(_image)
+             return dx, dy, dxy
+         return dx, dy
+
+ class RandomScaleCrop(object):
+     """Randomly zooms images up to 15% and crops them to keep the same size as before."""
+     def __call__(self, images, intrinsics):
+         assert intrinsics is not None
+         output_intrinsics = intrinsics.clone()
+
+         _, _, in_h, in_w = images.size()
+         x_scaling, y_scaling = np.random.uniform(1, 1.15, 2)
+         scaled_h, scaled_w = int(in_h * y_scaling), int(in_w * x_scaling)
+
+         output_intrinsics[:, 0, 0] *= x_scaling
+         output_intrinsics[:, 0, 2] *= x_scaling
+         output_intrinsics[:, 1, 1] *= y_scaling
+         output_intrinsics[:, 1, 2] *= y_scaling
+         scaled_images = F.interpolate(images, size=(scaled_h, scaled_w), mode='bilinear')
+
+         offset_y = np.random.randint(scaled_h - in_h + 1)
+         offset_x = np.random.randint(scaled_w - in_w + 1)
+         cropped_images = scaled_images[:, :, offset_y:offset_y + in_h, offset_x:offset_x + in_w]
+
+         output_intrinsics[:, 0, 2] -= offset_x
+         output_intrinsics[:, 1, 2] -= offset_y
+
+         return cropped_images, output_intrinsics
+
+ if __name__ == '__main__':
+     from skimage import io
+     import cv2
+     image = io.imread('/home/peidong/leonhard/project/infk/cvg/liup/mydata/KITTI/odometry/resized/832x256/test/09/image_2/001545.jpg')
+     image = torch.from_numpy(image.transpose(2, 0, 1)).float() / 255.
+     image = image.unsqueeze(0)
+
+     intrinsics = torch.eye(3).float().unsqueeze(0)
+     cropped_images, intrinsics = RandomScaleCrop()(image, intrinsics)
+
+     cv2.imshow('orig', image.numpy().transpose(0, 2, 3, 1)[0])
+     cv2.imshow('crop', cropped_images.numpy().transpose(0, 2, 3, 1)[0])
+     cv2.waitKey(0)
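
RandomScaleCrop above keeps the output resolution fixed: it upsamples by up to 15%, crops back to the input size, scales fx, fy, cx, cy by the zoom factors, and shifts the principal point by the crop offset. A quick CPU check of the shapes and intrinsics, as a sketch under the same hypothetical package_core import as above:

import torch
from package_core.image_proc import RandomScaleCrop

# A batch of two 3-channel 64x96 images with a matching pinhole intrinsic matrix.
images = torch.rand(2, 3, 64, 96)
K = torch.tensor([[100., 0., 48.],
                  [0., 100., 32.],
                  [0., 0., 1.]]).unsqueeze(0).repeat(2, 1, 1)

cropped, K_out = RandomScaleCrop()(images, K)
print(cropped.shape)   # torch.Size([2, 3, 64, 96]) -- spatial size is preserved
print(K_out[0])        # fx, fy scaled; cx, cy scaled, then shifted by the crop offset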
@@ -0,0 +1,40 @@
+ import torch
+
+ def transformation_matrix_multiply(T1, T2):
+     if T1.dim() == 2:
+         T1 = torch.unsqueeze(T1, 0)
+
+     if T2.dim() == 2:
+         T2 = torch.unsqueeze(T2, 0)
+
+     R1 = T1[:, :, :3]
+     t1 = T1[:, :, 3:4]
+
+     R2 = T2[:, :, :3]
+     t2 = T2[:, :, 3:4]
+
+     R = torch.bmm(R1, R2)
+     t = torch.bmm(R1, t2) + t1
+
+     return torch.cat([R, t], dim=2)
+
+ def transformation_matrix_inverse(T):
+     if T.dim() == 2:
+         T = T.unsqueeze(dim=0)
+     R = T[:, :, :3]
+     t = T[:, :, 3:4]
+
+     R_inv = R.transpose(2, 1)
+     t_inv = torch.bmm(R_inv, t)
+     t_inv = -1. * t_inv
+     return torch.cat([R_inv, t_inv], dim=2)
+
+ def skew_matrix(phi):
+     Phi = torch.zeros([phi.shape[0], 3, 3], dtype=phi.dtype, device=phi.device)
+     Phi[:, 0, 1] = -phi[:, 2]
+     Phi[:, 1, 0] = phi[:, 2]
+     Phi[:, 0, 2] = phi[:, 1]
+     Phi[:, 2, 0] = -phi[:, 1]
+     Phi[:, 1, 2] = -phi[:, 0]
+     Phi[:, 2, 1] = phi[:, 0]
+     return Phi
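
transformation_matrix_multiply and transformation_matrix_inverse operate on batched 3x4 rigid transforms [R | t], with the inverse given by [R' | -R't]; composing a transform with its inverse should therefore yield the identity transform [I | 0]. The sketch below checks this, building R from an axis-angle via the Rodrigues formula using skew_matrix (same hypothetical package_core import as above):

import torch
from package_core.linalg import (transformation_matrix_multiply,
                                 transformation_matrix_inverse, skew_matrix)

# A random rigid transform [R | t]: R = I + sin(theta)*Phi + (1 - cos(theta))*Phi^2.
phi = torch.randn(1, 3)
theta = phi.norm(dim=1, keepdim=True)
Phi = skew_matrix(phi / theta)
R = (torch.eye(3).unsqueeze(0)
     + torch.sin(theta).view(1, 1, 1) * Phi
     + (1. - torch.cos(theta)).view(1, 1, 1) * Phi.bmm(Phi))
T = torch.cat([R, torch.randn(1, 3, 1)], dim=2)

# T composed with its inverse should be [I | 0] up to float rounding.
identity = torch.cat([torch.eye(3), torch.zeros(3, 1)], dim=1).unsqueeze(0)
print(torch.allclose(transformation_matrix_multiply(T, transformation_matrix_inverse(T)),
                     identity, atol=1e-5))   # expect True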