spikezoo 0.1.1__py3-none-any.whl → 0.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- spikezoo/__init__.py +13 -0
- spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/base/nets.py +34 -0
- spikezoo/archs/bsf/README.md +92 -0
- spikezoo/archs/bsf/datasets/datasets.py +328 -0
- spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
- spikezoo/archs/bsf/main.py +398 -0
- spikezoo/archs/bsf/metrics/psnr.py +22 -0
- spikezoo/archs/bsf/metrics/ssim.py +54 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/align.py +154 -0
- spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
- spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
- spikezoo/archs/bsf/models/bsf/rep.py +44 -0
- spikezoo/archs/bsf/models/get_model.py +7 -0
- spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
- spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
- spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
- spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
- spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
- spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
- spikezoo/archs/bsf/requirements.txt +9 -0
- spikezoo/archs/bsf/test.py +16 -0
- spikezoo/archs/bsf/utils.py +154 -0
- spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spikeclip/nets.py +40 -0
- spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
- spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
- spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
- spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
- spikezoo/archs/spikeformer/EvalResults/readme +1 -0
- spikezoo/archs/spikeformer/LICENSE +21 -0
- spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
- spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/Loss.py +89 -0
- spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
- spikezoo/archs/spikeformer/Model/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/README.md +30 -0
- spikezoo/archs/spikeformer/evaluate.py +87 -0
- spikezoo/archs/spikeformer/recon_real_data.py +97 -0
- spikezoo/archs/spikeformer/requirements.yml +95 -0
- spikezoo/archs/spikeformer/train.py +173 -0
- spikezoo/archs/spikeformer/utils.py +22 -0
- spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
- spikezoo/archs/spk2imgnet/.gitignore +150 -0
- spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
- spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/align_arch.py +159 -0
- spikezoo/archs/spk2imgnet/dataset.py +144 -0
- spikezoo/archs/spk2imgnet/nets.py +230 -0
- spikezoo/archs/spk2imgnet/readme.md +86 -0
- spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
- spikezoo/archs/spk2imgnet/train.py +189 -0
- spikezoo/archs/spk2imgnet/utils.py +64 -0
- spikezoo/archs/ssir/README.md +87 -0
- spikezoo/archs/ssir/configs/SSIR.yml +37 -0
- spikezoo/archs/ssir/configs/yml_parser.py +78 -0
- spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
- spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
- spikezoo/archs/ssir/losses.py +21 -0
- spikezoo/archs/ssir/main.py +326 -0
- spikezoo/archs/ssir/metrics/psnr.py +22 -0
- spikezoo/archs/ssir/metrics/ssim.py +54 -0
- spikezoo/archs/ssir/models/Vgg19.py +42 -0
- spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/layers.py +110 -0
- spikezoo/archs/ssir/models/networks.py +61 -0
- spikezoo/archs/ssir/requirements.txt +8 -0
- spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
- spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
- spikezoo/archs/ssir/test.py +3 -0
- spikezoo/archs/ssir/utils.py +154 -0
- spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/cbam.py +224 -0
- spikezoo/archs/ssml/model.py +290 -0
- spikezoo/archs/ssml/res.png +0 -0
- spikezoo/archs/ssml/test.py +67 -0
- spikezoo/archs/stir/.git-credentials +0 -0
- spikezoo/archs/stir/README.md +65 -0
- spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
- spikezoo/archs/stir/configs/STIR.yml +37 -0
- spikezoo/archs/stir/configs/utils.py +155 -0
- spikezoo/archs/stir/configs/yml_parser.py +78 -0
- spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
- spikezoo/archs/stir/datasets/ds_utils.py +66 -0
- spikezoo/archs/stir/eval_SREDS.sh +5 -0
- spikezoo/archs/stir/main.py +397 -0
- spikezoo/archs/stir/metrics/losses.py +219 -0
- spikezoo/archs/stir/metrics/psnr.py +22 -0
- spikezoo/archs/stir/metrics/ssim.py +54 -0
- spikezoo/archs/stir/models/Vgg19.py +42 -0
- spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/networks_STIR.py +361 -0
- spikezoo/archs/stir/models/submodules.py +86 -0
- spikezoo/archs/stir/models/transformer_new.py +151 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
- spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
- spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
- spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
- spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
- spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
- spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
- spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
- spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
- spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
- spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
- spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
- spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
- spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
- spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
- spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
- spikezoo/archs/stir/package_core/setup.py +5 -0
- spikezoo/archs/stir/requirements.txt +12 -0
- spikezoo/archs/stir/train_STIR.sh +9 -0
- spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfi/nets.py +43 -0
- spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfp/nets.py +13 -0
- spikezoo/archs/wgse/README.md +64 -0
- spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/dataset.py +59 -0
- spikezoo/archs/wgse/demo.png +0 -0
- spikezoo/archs/wgse/demo.py +83 -0
- spikezoo/archs/wgse/dwtnets.py +145 -0
- spikezoo/archs/wgse/eval.py +133 -0
- spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
- spikezoo/archs/wgse/submodules.py +68 -0
- spikezoo/archs/wgse/train.py +261 -0
- spikezoo/archs/wgse/transform.py +139 -0
- spikezoo/archs/wgse/utils.py +128 -0
- spikezoo/archs/wgse/weights/demo.png +0 -0
- spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
- spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
- spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
- spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
- spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
- spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
- spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
- spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
- spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
- spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
- spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
- spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
- spikezoo/datasets/base_dataset.py +2 -3
- spikezoo/metrics/__init__.py +1 -1
- spikezoo/models/base_model.py +1 -3
- spikezoo/pipeline/base_pipeline.py +7 -5
- spikezoo/pipeline/train_pipeline.py +1 -1
- spikezoo/utils/other_utils.py +16 -6
- spikezoo/utils/spike_utils.py +33 -29
- spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
- spikezoo-0.2.dist-info/METADATA +163 -0
- spikezoo-0.2.dist-info/RECORD +211 -0
- spikezoo/models/spcsnet_model.py +0 -19
- spikezoo-0.1.1.dist-info/METADATA +0 -39
- spikezoo-0.1.1.dist-info/RECORD +0 -36
- {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/LICENSE.txt +0 -0
- {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/WHEEL +0 -0
- {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,133 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
import torch.nn.functional as F
|
4
|
+
from torch.nn.init import xavier_uniform_, zeros_
|
5
|
+
|
6
|
+
|
7
|
+
def downsample_conv(in_planes, out_planes, kernel_size=3):
    """Build a two-layer conv+ReLU stage whose first conv has stride 2.

    The first convolution halves the spatial size; the second keeps it.
    Both use padding ``(kernel_size - 1) // 2``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels of both convolutions.
        kernel_size: square kernel size shared by both convolutions.

    Returns:
        ``nn.Sequential`` of [Conv2d(stride=2), ReLU, Conv2d, ReLU].
    """
    pad = (kernel_size - 1) // 2
    layers = [
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=pad),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, padding=pad),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
|
14
|
+
|
15
|
+
|
16
|
+
def predict_disp(in_planes):
    """Build a disparity head: 3x3 conv down to one channel, then a sigmoid.

    Args:
        in_planes: number of input feature channels.

    Returns:
        ``nn.Sequential`` producing a single-channel map with values in (0, 1).
    """
    to_single_channel = nn.Conv2d(in_planes, 1, kernel_size=3, padding=1)
    squash = nn.Sigmoid()
    return nn.Sequential(to_single_channel, squash)
|
21
|
+
|
22
|
+
|
23
|
+
def conv(in_planes, out_planes):
    """Build a resolution-preserving 3x3 conv followed by an in-place ReLU.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.

    Returns:
        ``nn.Sequential`` of [Conv2d(kernel 3, padding 1), ReLU].
    """
    layer = nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(layer, activation)
|
28
|
+
|
29
|
+
|
30
|
+
def upconv(in_planes, out_planes):
    """Build a 2x upsampling stage: transposed 3x3 conv (stride 2) plus ReLU.

    With padding 1 and output_padding 1 the output is exactly twice the input
    height and width.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.

    Returns:
        ``nn.Sequential`` of [ConvTranspose2d, ReLU].
    """
    upsample = nn.ConvTranspose2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=2,
        padding=1,
        output_padding=1,
    )
    return nn.Sequential(upsample, nn.ReLU(inplace=True))
|
35
|
+
|
36
|
+
|
37
|
+
def crop_like(input, ref):
    """Crop ``input`` to the height/width (dims 2 and 3) of ``ref``.

    The top-left corner is kept. ``input`` must be at least as large as
    ``ref`` in both spatial dims; otherwise an AssertionError is raised.
    """
    target_h = ref.size(2)
    target_w = ref.size(3)
    assert(input.size(2) >= target_h and input.size(3) >= target_w)
    return input[:, :, :target_h, :target_w]
|
40
|
+
|
41
|
+
|
42
|
+
class DispNetS(nn.Module):
    """Encoder-decoder network that predicts disparity at four scales.

    The encoder is seven ``downsample_conv`` stages (each halves H and W).
    The decoder upsamples back with ``upconv`` and fuses encoder skip
    connections with ``conv``. Disparity heads (``predict_disp``, which ends
    in a sigmoid) are applied at the four finest decoder levels. Each output
    is rescaled as ``alpha * sigmoid_output + beta``.
    """

    def __init__(self, alpha=10, beta=0.01):
        super(DispNetS, self).__init__()

        # Affine rescaling applied to every sigmoid disparity output.
        self.alpha = alpha
        self.beta = beta

        enc = [32, 64, 128, 256, 512, 512, 512]
        dec = [512, 512, 256, 128, 64, 32, 16]

        # Encoder stages (conv1 = finest, conv7 = coarsest).
        self.conv1 = downsample_conv(3, enc[0], kernel_size=7)
        self.conv2 = downsample_conv(enc[0], enc[1], kernel_size=5)
        self.conv3 = downsample_conv(enc[1], enc[2])
        self.conv4 = downsample_conv(enc[2], enc[3])
        self.conv5 = downsample_conv(enc[3], enc[4])
        self.conv6 = downsample_conv(enc[4], enc[5])
        self.conv7 = downsample_conv(enc[5], enc[6])

        # Decoder 2x upsampling stages, coarsest first.
        self.upconv7 = upconv(enc[6], dec[0])
        self.upconv6 = upconv(dec[0], dec[1])
        self.upconv5 = upconv(dec[1], dec[2])
        self.upconv4 = upconv(dec[2], dec[3])
        self.upconv3 = upconv(dec[3], dec[4])
        self.upconv2 = upconv(dec[4], dec[5])
        self.upconv1 = upconv(dec[5], dec[6])

        # Fusion convs. The three finest levels get one extra input channel
        # for the upsampled coarser disparity map.
        self.iconv7 = conv(dec[0] + enc[5], dec[0])
        self.iconv6 = conv(dec[1] + enc[4], dec[1])
        self.iconv5 = conv(dec[2] + enc[3], dec[2])
        self.iconv4 = conv(dec[3] + enc[2], dec[3])
        self.iconv3 = conv(1 + dec[4] + enc[1], dec[4])
        self.iconv2 = conv(1 + dec[5] + enc[0], dec[5])
        self.iconv1 = conv(1 + dec[6], dec[6])

        # Disparity heads, coarse (4) to fine (1).
        self.predict_disp4 = predict_disp(dec[3])
        self.predict_disp3 = predict_disp(dec[4])
        self.predict_disp2 = predict_disp(dec[5])
        self.predict_disp1 = predict_disp(dec[6])

    def init_weights(self):
        """Xavier-uniform init for all conv / transposed-conv weights; zero biases."""
        conv_types = (nn.Conv2d, nn.ConvTranspose2d)
        for module in self.modules():
            if not isinstance(module, conv_types):
                continue
            xavier_uniform_(module.weight)
            if module.bias is not None:
                zeros_(module.bias)

    def forward(self, x):
        """Return ``(disp1, disp2, disp3, disp4)``, finest scale first.

        Each map has one channel; for positive ``alpha`` its values lie in
        ``(beta, alpha + beta)``. ``disp1`` is cropped to the spatial size
        of ``x``.
        """
        def upsample2x(disp):
            return F.interpolate(disp, scale_factor=2, mode='bilinear', align_corners=False)

        # Encoder: keep each stage's output for the skip connections.
        encoder = (self.conv1, self.conv2, self.conv3, self.conv4,
                   self.conv5, self.conv6, self.conv7)
        skips = []
        feat = x
        for stage in encoder:
            feat = stage(feat)
            skips.append(feat)
        c1, c2, c3, c4, c5, c6, c7 = skips

        # Coarse decoder levels: upsample, crop to the skip, concatenate, fuse.
        feat = c7
        coarse_levels = (
            (self.upconv7, self.iconv7, c6),
            (self.upconv6, self.iconv6, c5),
            (self.upconv5, self.iconv5, c4),
            (self.upconv4, self.iconv4, c3),
        )
        for up, fuse, skip in coarse_levels:
            feat = fuse(torch.cat((crop_like(up(feat), skip), skip), 1))
        disp4 = self.alpha * self.predict_disp4(feat) + self.beta

        # Fine decoder levels additionally consume the upsampled coarser disparity.
        up3 = crop_like(self.upconv3(feat), c2)
        feat = self.iconv3(torch.cat((up3, c2, crop_like(upsample2x(disp4), c2)), 1))
        disp3 = self.alpha * self.predict_disp3(feat) + self.beta

        up2 = crop_like(self.upconv2(feat), c1)
        feat = self.iconv2(torch.cat((up2, c1, crop_like(upsample2x(disp3), c1)), 1))
        disp2 = self.alpha * self.predict_disp2(feat) + self.beta

        up1 = crop_like(self.upconv1(feat), x)
        feat = self.iconv1(torch.cat((up1, crop_like(upsample2x(disp2), x)), 1))
        disp1 = self.alpha * self.predict_disp1(feat) + self.beta

        return disp1, disp2, disp3, disp4
|
133
|
+
|
@@ -0,0 +1,167 @@
|
|
1
|
+
import cv2
|
2
|
+
import sys
|
3
|
+
import numpy as np
|
4
|
+
import argparse
|
5
|
+
|
6
|
+
def load_flow(path):
    """Read a Middlebury ``.flo`` optical-flow file.

    File layout: float32 magic number ``202021.25``, int32 width, int32
    height, then ``height * width * 2`` float32 values (row-major, two
    components per pixel).

    Args:
        path: path of the ``.flo`` file.

    Returns:
        float32 array of shape (height, width, 2). Returns ``None`` if the
        file is empty or does not start with the expected magic number.

    Raises:
        ValueError: if the magic number matches but the header dimensions are
            missing/negative or the payload is shorter than the header claims.
            (Previously ``ndarray.resize`` silently zero-padded such data.)
    """
    with open(path, 'rb') as f:
        magic = np.fromfile(f, np.float32, count=1)
        if magic.size != 1 or float(magic[0]) != 202021.25:
            return None
        dims = np.fromfile(f, np.int32, count=2)
        if dims.size != 2:
            raise ValueError('truncated .flo header in %r' % (path,))
        # Convert to Python ints so h * w * 2 cannot overflow int32.
        w, h = int(dims[0]), int(dims[1])
        if w < 0 or h < 0:
            raise ValueError('invalid .flo dimensions %dx%d in %r' % (w, h, path))
        expected = h * w * 2
        data = np.fromfile(f, np.float32, count=expected)
    if data.size != expected:
        raise ValueError(
            'truncated .flo payload in %r: expected %d values, got %d'
            % (path, expected, data.size)
        )
    return data.reshape((h, w, 2))
|
15
|
+
|
16
|
+
def save_flow(path, flow):
    """Write a flow field to a Middlebury ``.flo`` file.

    Writes the float32 magic ``202021.25``, int32 width, int32 height, then
    the flow values as float32 in C order.

    Args:
        path: destination file path (overwritten).
        flow: array-like of shape (H, W, 2). It is converted to float32
            before writing. Previously a float64 array was written with 8-byte
            elements, so the file no longer matched the float32 layout that
            ``load_flow`` reads.

    Raises:
        ValueError: if ``flow`` is not three-dimensional with 2 channels.
    """
    flow = np.asarray(flow, dtype=np.float32)
    if flow.ndim != 3 or flow.shape[2] != 2:
        raise ValueError('flow must have shape (H, W, 2), got %r' % (flow.shape,))
    h, w = flow.shape[:2]
    magic = np.array([202021.25], np.float32)
    with open(path, 'wb') as f:
        magic.tofile(f)
        np.array([w], np.int32).tofile(f)
        np.array([h], np.int32).tofile(f)
        flow.tofile(f)  # tofile always emits C order
|
23
|
+
|
24
|
+
def makeColorwheel():
    """Build the 55-entry color wheel used for optical-flow visualization.

    The wheel is adapted from the color-circle idea at
    http://members.shaw.ca/quadibloc/other/colint.htm. It is made of six
    segments: red->yellow (15), yellow->green (6), green->cyan (4),
    cyan->blue (11), blue->magenta (13) and magenta->red (6). In each segment
    one channel is held at 255 while another ramps linearly up or down.

    Returns:
        float64 array of shape (55, 3) with r, g, b values in [0, 255].
    """
    # (segment length, channel held at 255, ramping channel, ramp rises?)
    segments = (
        (15, 0, 1, True),   # red -> yellow
        (6, 1, 0, False),   # yellow -> green
        (4, 1, 2, True),    # green -> cyan
        (11, 2, 1, False),  # cyan -> blue
        (13, 2, 0, True),   # blue -> magenta
        (6, 0, 2, False),   # magenta -> red
    )
    total = sum(seg[0] for seg in segments)
    colorwheel = np.zeros([total, 3])  # r g b

    start = 0
    for length, full_channel, ramp_channel, rising in segments:
        stop = start + length
        ramp = np.floor(255 * np.arange(0, length, 1) / length)
        colorwheel[start:stop, full_channel] = 255
        colorwheel[start:stop, ramp_channel] = ramp if rising else 255 - ramp
        start = stop
    return colorwheel
|
72
|
+
|
73
|
+
def computeColor(u, v):
    """Colorize flow components with the color wheel from ``makeColorwheel``.

    Hue comes from the flow direction (``arctan2(-v, -u)``). Pixels with
    radius <= 1 are desaturated toward white as the radius shrinks. Pixels
    with radius > 1 are darkened by a factor of 0.75.

    Args:
        u: 2-D array of horizontal flow components.
        v: 2-D array of vertical flow components, same shape as ``u``.
            NaN in either component is treated as zero flow for that pixel.
            The inputs are NOT modified; the previous implementation zeroed
            NaNs in the caller's arrays.

    Returns:
        uint8 array of shape (H, W, 3). Channel i of the wheel is written to
        ``img[:, :, 2 - i]``, i.e. the output is in BGR order.
    """
    colorwheel = makeColorwheel()

    # Work on copies so NaN-zeroing does not leak into the caller's arrays.
    u = np.array(u, copy=True)
    v = np.array(v, copy=True)
    nan_mask = np.isnan(u) | np.isnan(v)
    u[nan_mask] = 0
    v[nan_mask] = 0

    ncols = colorwheel.shape[0]
    radius = np.sqrt(u ** 2 + v ** 2)
    # Angle in [-1, 1] (units of pi) mapped to a fractional wheel index in [0, ncols - 1].
    a = np.arctan2(-v, -u) / np.pi
    fk = (a + 1) / 2 * (ncols - 1)
    k0 = fk.astype(np.uint8)  # truncation == floor since fk >= 0
    k1 = k0 + 1
    k1[k1 == ncols] = 0  # wrap around the wheel
    f = fk - k0  # interpolation weight between wheel entries k0 and k1

    in_range = radius <= 1  # loop-invariant, hoisted
    img = np.empty([k1.shape[0], k1.shape[1], 3])
    for i in range(colorwheel.shape[1]):
        tmp = colorwheel[:, i]
        col0 = tmp[k0] / 255
        col1 = tmp[k1] / 255
        col = (1 - f) * col0 + f * col1
        col[in_range] = 1 - radius[in_range] * (1 - col[in_range])  # increase saturation with radius
        col[~in_range] *= 0.75  # out of range
        img[:, :, 2 - i] = np.floor(255 * col).astype(np.uint8)

    return img.astype(np.uint8)
|
108
|
+
|
109
|
+
|
110
|
+
def flow2rgb(flow):
    """Render a dense flow field as an RGB color image.

    Args:
        flow: array of shape (H, W, 2). Channel 0 is used as u, channel 1 as
            v. Components above 1e9 are treated as unknown flow and zeroed in
            a private copy. The input is NOT modified; the previous
            implementation wrote through views into ``flow``.

    Returns:
        uint8 array of shape (H, W, 3) in RGB order.

    Notes:
        Flow is normalized by the largest non-NaN magnitude before
        colorizing. Previously ``max([-1, nan])`` evaluated to -1 whenever any
        NaN was present, which negated the flow instead of normalizing it.
    """
    eps = sys.float_info.epsilon
    UNKNOWN_FLOW_THRESH = 1e9

    # Private copies: slicing ``flow`` yields views into the caller's array.
    u = np.array(flow[:, :, 0], copy=True)
    v = np.array(flow[:, :, 1], copy=True)

    # Fix unknown flow: zero both components where either exceeds the threshold.
    unknown = (u > UNKNOWN_FLOW_THRESH) | (v > UNKNOWN_FLOW_THRESH)
    u[unknown] = 0
    v[unknown] = 0

    # Normalize by the maximum magnitude, ignoring NaNs (computeColor zeroes them).
    rad = np.sqrt(np.multiply(u, u) + np.multiply(v, v))
    valid = ~np.isnan(rad)
    maxrad = rad[valid].max() if valid.any() else 0.0

    u = u / (maxrad + eps)
    v = v / (maxrad + eps)
    img = computeColor(u, v)
    return img[:, :, [2, 1, 0]]  # computeColor returns BGR
|
147
|
+
|
148
|
+
def flow_to_numpy_rgb(flow):
    """Colorize a batch of flow tensors.

    Args:
        flow: tensor of shape (B, 2, H, W). It is detached and moved to the
            CPU before conversion.

    Returns:
        uint8 numpy array of shape (B, H, W, 3), one ``flow2rgb`` image per
        batch item.
    """
    flows_bhwc = flow.detach().cpu().numpy().transpose((0, 2, 3, 1))
    batch, height, width, _ = flows_bhwc.shape
    rendered = np.empty([batch, height, width, 3])
    for idx, single_flow in enumerate(flows_bhwc):
        rendered[idx] = flow2rgb(single_flow)
    return rendered.astype(np.uint8)
|
156
|
+
|
157
|
+
#if __name__ == '__main__':
|
158
|
+
# import matplotlib.pyplot as plt#
|
159
|
+
#
|
160
|
+
# flow = load_flow('/home/autovision/mycode/PWC-Net/PyTorch/tmp/frame_0010.flo')
|
161
|
+
# #flow = load_flow('datasets/Sintel/training/flow/alley_1/frame_0001.flo')
|
162
|
+
# img = flow2rgb(flow)
|
163
|
+
# import imageio
|
164
|
+
# imageio.imsave('test.png', img)
|
165
|
+
# import cv2
|
166
|
+
# cv2.imshow('', img[:,:,:])
|
167
|
+
# cv2.waitKey()
|
@@ -0,0 +1,76 @@
|
|
1
|
+
import os
|
2
|
+
import time
|
3
|
+
from skimage import io
|
4
|
+
|
5
|
+
from .metrics import *
|
6
|
+
from .image_proc import *
|
7
|
+
|
8
|
+
class Generic_train_test():
    """Generic epoch/step training loop driven by an options object.

    Subclasses implement ``decode_input`` (turn a dataloader item into the
    model's input) and, when ``dataloader_val`` is given, ``validation``.

    Calls made on the collaborators:
      * ``model``: ``set_input``, ``optimize_parameters``,
        ``get_current_scalars``, ``get_current_visuals``, ``update_lr`` and
        ``save_checkpoint``.
      * ``logger``: ``add_scalar`` and ``add_images``. Presumably a
        TensorBoard-style ``SummaryWriter``; confirm with the caller.
      * ``opts``: ``log_freq``, ``batch_sz``, ``start_epoch``,
        ``max_epochs``, ``lr_start_epoch_decay``, ``lr_step``, ``save_freq``
        and ``save_begin``.
    """

    def __init__(self, model, opts, dataloader, logger, dataloader_val=None):
        self.model = model
        self.opts = opts
        self.dataloader = dataloader
        self.logger = logger
        self.dataloader_val = dataloader_val

    def decode_input(self, data):
        """Convert one dataloader item into the model's input. Must be overridden."""
        raise NotImplementedError()

    def validation(self, epoch=None):
        """Validate at the end of ``epoch``. Must be overridden.

        ``train`` calls ``self.validation(epoch)``. The base signature accepts
        that argument so a missing override raises NotImplementedError
        rather than TypeError.
        """
        raise NotImplementedError()

    def train_single_iterate(self, data, total_steps, epoch):
        """Run one optimization step and log every ``opts.log_freq`` steps."""
        _input = self.decode_input(data)

        self.model.set_input(_input)
        self.model.optimize_parameters()

        #=========== visualize results ============#
        if total_steps % self.opts.log_freq == 0:
            info = self.model.get_current_scalars()
            for tag, value in info.items():
                self.logger.add_scalar(tag, value, total_steps)

            results = self.model.get_current_visuals()
            for tag, images in results.items():
                self.logger.add_images(tag, images, total_steps)

            print('epoch', epoch, 'steps', total_steps)
            print('losses', info)

    def train(self):
        """Train from ``opts.start_epoch`` up to (excluding) ``opts.max_epochs``.

        Each epoch does the following, in order:
          * Optionally update the learning rate.
          * Optionally save a checkpoint named after the epoch.
          * Iterate the dataloader. With no dataloader, run 10000 steps with
            ``data=None``.
          * Validate if a validation dataloader was supplied.
        """
        total_steps = 0
        if self.dataloader is not None:
            print('#training images ', len(self.dataloader) * self.opts.batch_sz)

        for epoch in range(self.opts.start_epoch, self.opts.max_epochs):
            if epoch > self.opts.lr_start_epoch_decay - self.opts.lr_step:
                self.model.update_lr()

            if epoch % self.opts.save_freq == 0 or epoch <= self.opts.save_begin:
                self.model.save_checkpoint(str(epoch))

            if self.dataloader is not None:
                for data in self.dataloader:
                    total_steps += 1
                    self.train_single_iterate(data, total_steps, epoch)
            else:
                # No dataloader: decode_input receives None. Presumably the
                # subclass synthesizes its own data; confirm.
                for _ in range(10000):
                    total_steps += 1
                    self.train_single_iterate(None, total_steps, epoch)

            # validation if dataloader provided
            if self.dataloader_val is not None:
                self.validation(epoch)

    def train_single_instance(self):
        """Repeatedly train on the first item of the dataloader.

        Runs 10000 epochs of 1000 steps each.
        """
        total_steps = 0
        # Builtin next(): Python 3 iterators have no ``.next()`` method.
        data = next(iter(self.dataloader))

        for epoch in range(10000):
            for _ in range(1000):
                total_steps += 1
                self.train_single_iterate(data, total_steps, epoch)
|
74
|
+
|
75
|
+
|
76
|
+
|