spikezoo 0.1.2__py3-none-any.whl → 0.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- spikezoo/__init__.py +13 -0
- spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/base/nets.py +34 -0
- spikezoo/archs/bsf/README.md +92 -0
- spikezoo/archs/bsf/datasets/datasets.py +328 -0
- spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
- spikezoo/archs/bsf/main.py +398 -0
- spikezoo/archs/bsf/metrics/psnr.py +22 -0
- spikezoo/archs/bsf/metrics/ssim.py +54 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/align.py +154 -0
- spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
- spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
- spikezoo/archs/bsf/models/bsf/rep.py +44 -0
- spikezoo/archs/bsf/models/get_model.py +7 -0
- spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
- spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
- spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
- spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
- spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
- spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
- spikezoo/archs/bsf/requirements.txt +9 -0
- spikezoo/archs/bsf/test.py +16 -0
- spikezoo/archs/bsf/utils.py +154 -0
- spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spikeclip/nets.py +40 -0
- spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
- spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
- spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
- spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
- spikezoo/archs/spikeformer/EvalResults/readme +1 -0
- spikezoo/archs/spikeformer/LICENSE +21 -0
- spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
- spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/Loss.py +89 -0
- spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
- spikezoo/archs/spikeformer/Model/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/README.md +30 -0
- spikezoo/archs/spikeformer/evaluate.py +87 -0
- spikezoo/archs/spikeformer/recon_real_data.py +97 -0
- spikezoo/archs/spikeformer/requirements.yml +95 -0
- spikezoo/archs/spikeformer/train.py +173 -0
- spikezoo/archs/spikeformer/utils.py +22 -0
- spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
- spikezoo/archs/spk2imgnet/.gitignore +150 -0
- spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
- spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/align_arch.py +159 -0
- spikezoo/archs/spk2imgnet/dataset.py +144 -0
- spikezoo/archs/spk2imgnet/nets.py +230 -0
- spikezoo/archs/spk2imgnet/readme.md +86 -0
- spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
- spikezoo/archs/spk2imgnet/train.py +189 -0
- spikezoo/archs/spk2imgnet/utils.py +64 -0
- spikezoo/archs/ssir/README.md +87 -0
- spikezoo/archs/ssir/configs/SSIR.yml +37 -0
- spikezoo/archs/ssir/configs/yml_parser.py +78 -0
- spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
- spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
- spikezoo/archs/ssir/losses.py +21 -0
- spikezoo/archs/ssir/main.py +326 -0
- spikezoo/archs/ssir/metrics/psnr.py +22 -0
- spikezoo/archs/ssir/metrics/ssim.py +54 -0
- spikezoo/archs/ssir/models/Vgg19.py +42 -0
- spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/layers.py +110 -0
- spikezoo/archs/ssir/models/networks.py +61 -0
- spikezoo/archs/ssir/requirements.txt +8 -0
- spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
- spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
- spikezoo/archs/ssir/test.py +3 -0
- spikezoo/archs/ssir/utils.py +154 -0
- spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/cbam.py +224 -0
- spikezoo/archs/ssml/model.py +290 -0
- spikezoo/archs/ssml/res.png +0 -0
- spikezoo/archs/ssml/test.py +67 -0
- spikezoo/archs/stir/.git-credentials +0 -0
- spikezoo/archs/stir/README.md +65 -0
- spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
- spikezoo/archs/stir/configs/STIR.yml +37 -0
- spikezoo/archs/stir/configs/utils.py +155 -0
- spikezoo/archs/stir/configs/yml_parser.py +78 -0
- spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
- spikezoo/archs/stir/datasets/ds_utils.py +66 -0
- spikezoo/archs/stir/eval_SREDS.sh +5 -0
- spikezoo/archs/stir/main.py +397 -0
- spikezoo/archs/stir/metrics/losses.py +219 -0
- spikezoo/archs/stir/metrics/psnr.py +22 -0
- spikezoo/archs/stir/metrics/ssim.py +54 -0
- spikezoo/archs/stir/models/Vgg19.py +42 -0
- spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/networks_STIR.py +361 -0
- spikezoo/archs/stir/models/submodules.py +86 -0
- spikezoo/archs/stir/models/transformer_new.py +151 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
- spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
- spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
- spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
- spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
- spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
- spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
- spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
- spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
- spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
- spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
- spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
- spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
- spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
- spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
- spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
- spikezoo/archs/stir/package_core/setup.py +5 -0
- spikezoo/archs/stir/requirements.txt +12 -0
- spikezoo/archs/stir/train_STIR.sh +9 -0
- spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfi/nets.py +43 -0
- spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfp/nets.py +13 -0
- spikezoo/archs/wgse/README.md +64 -0
- spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/dataset.py +59 -0
- spikezoo/archs/wgse/demo.png +0 -0
- spikezoo/archs/wgse/demo.py +83 -0
- spikezoo/archs/wgse/dwtnets.py +145 -0
- spikezoo/archs/wgse/eval.py +133 -0
- spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
- spikezoo/archs/wgse/submodules.py +68 -0
- spikezoo/archs/wgse/train.py +261 -0
- spikezoo/archs/wgse/transform.py +139 -0
- spikezoo/archs/wgse/utils.py +128 -0
- spikezoo/archs/wgse/weights/demo.png +0 -0
- spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
- spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
- spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
- spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
- spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
- spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
- spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
- spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
- spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
- spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
- spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
- spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
- spikezoo/datasets/base_dataset.py +2 -3
- spikezoo/metrics/__init__.py +1 -1
- spikezoo/models/base_model.py +1 -3
- spikezoo/pipeline/base_pipeline.py +7 -5
- spikezoo/pipeline/train_pipeline.py +1 -1
- spikezoo/utils/other_utils.py +16 -6
- spikezoo/utils/spike_utils.py +33 -29
- spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
- spikezoo-0.2.dist-info/METADATA +163 -0
- spikezoo-0.2.dist-info/RECORD +211 -0
- spikezoo/models/spcsnet_model.py +0 -19
- spikezoo-0.1.2.dist-info/METADATA +0 -39
- spikezoo-0.1.2.dist-info/RECORD +0 -36
- {spikezoo-0.1.2.dist-info → spikezoo-0.2.dist-info}/LICENSE.txt +0 -0
- {spikezoo-0.1.2.dist-info → spikezoo-0.2.dist-info}/WHEEL +0 -0
- {spikezoo-0.1.2.dist-info → spikezoo-0.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,105 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
import torch.nn.functional as F
|
4
|
+
from .rep import MODF
|
5
|
+
from .align import Multi_Granularity_Align
|
6
|
+
|
7
|
+
|
8
|
+
class BasicModel(nn.Module):
    """nn.Module base class with small parameter-inspection and init helpers."""

    def __init__(self):
        super().__init__()

    ####################################################################################
    ## Tools functions for neural networks
    def weight_parameters(self):
        """Return every parameter whose qualified name contains 'weight'."""
        return [param for name, param in self.named_parameters() if 'weight' in name]

    def bias_parameters(self):
        """Return every parameter whose qualified name contains 'bias'."""
        return [param for name, param in self.named_parameters() if 'bias' in name]

    def num_parameters(self):
        """Return the number of scalar parameters that require gradients."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def init_weights(self):
        """Kaiming-normal init for every Conv2d / ConvTranspose2d; zero their biases.

        Iterates ``self.modules()`` (module objects). ``named_modules()`` would
        yield ``(name, module)`` tuples, which never pass the isinstance check,
        so nothing would be initialised.
        """
        for layer in self.modules():
            if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(layer.weight)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)
|
34
|
+
|
35
|
+
|
36
|
+
def split_and_b_cat(x):
    """Cut five overlapping windows along dim 1 and stack them on the batch axis.

    The windows are centred at indices 10, 20, 30, 40 and 50 with a half-width
    of 10, i.e. x[:, 0:21], x[:, 10:31], ..., x[:, 40:61]. The output batch is
    window-major: all batch items of the first window come first.
    """
    half_width = 10
    windows = [x[:, center - half_width:center + half_width + 1].clone()
               for center in range(10, 51, 10)]
    return torch.cat(windows, dim=0)
|
43
|
+
|
44
|
+
|
45
|
+
class Encoder(nn.Module):
    """Stack of `layers` residual conv blocks at a constant width of `base_dim`.

    Each block is conv3x3 -> act -> conv3x3. Its output is added to the block
    input and passed through `act` again.
    """

    def __init__(self, base_dim=64, layers=4, act=nn.ReLU()):
        super().__init__()

        def residual_block():
            return nn.Sequential(
                nn.Conv2d(base_dim, base_dim, kernel_size=3, padding=1),
                act,
                nn.Conv2d(base_dim, base_dim, kernel_size=3, padding=1),
            )

        self.conv_list = nn.ModuleList([residual_block() for _ in range(layers)])
        self.act = act

    def forward(self, x):
        out = x
        for block in self.conv_list:
            out = self.act(block(out) + out)
        return out
|
63
|
+
|
64
|
+
##########################################################################
|
65
|
+
class BSF(BasicModel):
    """Reconstruct a single-channel image from four DSFT variants.

    As wired in `forward`:
      1. each DSFT tensor is split into five windows stacked on the batch axis
         (see split_and_b_cat);
      2. the MODF representation and the residual Encoder run once over all
         windows;
      3. the five per-window feature maps are aligned by Multi_Granularity_Align;
      4. the aligned maps are concatenated on channels and decoded by `recons`.
    """

    def __init__(self, act=nn.ReLU()):
        super().__init__()
        dim = 64
        self.offset_groups = 4
        self.corr_max_disp = 3

        self.rep = MODF(base_dim=dim, act=act)
        self.encoder = Encoder(base_dim=dim, layers=4, act=act)
        self.align = Multi_Granularity_Align(base_dim=dim, groups=self.offset_groups, act=act, sc=3)
        self.recons = nn.Sequential(
            nn.Conv2d(dim * 5, dim * 3, kernel_size=3, padding=1),
            act,
            nn.Conv2d(dim * 3, dim, kernel_size=3, padding=1),
            act,
            nn.Conv2d(dim, 1, kernel_size=3, padding=1),
        )

    def forward(self, input_dict):
        """`input_dict['dsft_dict']` must contain 'dsft11', 'dsft12', 'dsft21', 'dsft22'."""
        dsft_dict = input_dict['dsft_dict']
        keys = ('dsft11', 'dsft12', 'dsft21', 'dsft22')
        tensors = [dsft_dict[key] for key in keys]
        windows = {key: split_and_b_cat(tensor) for key, tensor in zip(keys, tensors)}

        features = self.encoder(self.rep(windows))
        aligned = self.align(feat_list=features.chunk(5, dim=0))
        return self.recons(torch.cat(aligned, dim=1))
|
@@ -0,0 +1,96 @@
|
|
1
|
+
import torch
|
2
|
+
|
3
|
+
|
4
|
+
def convert_dsft4(dsft, spike):
    """Derive the four DSFT variants (1,1), (1,2), (2,1), (2,2) from DSFT(1,1).

    Args:
        dsft: DSFT(1,1) tensor of shape (b, T, h, w).
        spike: spike tensor of shape (b, T, h, w); a value of exactly 1 is a spike.

    Returns:
        dict with keys
          'dsft11': the input `dsft` object itself,
          'dsft12': dsft + dmls1,
          'dsft21': dsft + dmrs1,
          'dsft22': dsft + dmls1 + dmrs1,
        where dmls1[t] is dsft[s] for the first s > t with a spike, and
        dmrs1[t] is dsft[s - 1] for the last s <= t with a spike. Near the
        stream ends (fewer than two spikes in [t, T-1] for dmls1, or in
        [0, t] for dmrs1) the value is copy-padded with dsft[t].
    """
    b, T, h, w = spike.shape
    device = spike.device

    # dmls1: "dsft mask, left shift"; dmrs1: "dsft mask, right shift".
    dmls1 = torch.full(spike.shape, -1.0, device=device, dtype=torch.float32)
    dmrs1 = torch.full(spike.shape, -1.0, device=device, dtype=torch.float32)

    # Backward sweep for dmls1. `seen` starts at -2 and counts spikes in [t, T-1];
    # while it is still negative the boundary is copy-padded.
    seen = torch.full((b, h, w), -2.0, device=device, dtype=torch.float32)
    for t in reversed(range(T)):
        seen += spike[:, t] == 1
        pad = seen < 0
        dst = dmls1[:, t]
        dst[pad] = dsft[:, t][pad]
        if t == T - 1:
            continue
        fired = spike[:, t + 1] == 1
        keep = ~pad
        # Spike right after t: take the DSFT value there; otherwise inherit from t+1.
        dst[fired & keep] = dsft[:, t + 1][fired & keep]
        dst[~fired & keep] = dmls1[:, t + 1][~fired & keep]

    # Forward sweep for dmrs1. `seen` counts spikes in [0, t].
    seen = torch.full((b, h, w), -2.0, device=device, dtype=torch.float32)
    for t in range(T):
        seen += spike[:, t] == 1
        pad = seen < 0
        dst = dmrs1[:, t]
        dst[pad] = dsft[:, t][pad]
        if t == 0:
            continue
        fired = spike[:, t] == 1
        keep = ~pad
        # Spike at t: take the DSFT value just before it; otherwise inherit from t-1.
        dst[fired & keep] = dsft[:, t - 1][fired & keep]
        dst[~fired & keep] = dmrs1[:, t - 1][~fired & keep]

    return {
        'dsft11': dsft,
        'dsft12': dsft + dmls1,
        'dsft21': dsft + dmrs1,
        'dsft22': dsft + dmls1 + dmrs1,
    }
|
71
|
+
|
72
|
+
|
73
|
+
|
74
|
+
if __name__ == '__main__':
    # Alternative toy example:
    # spike = [0,0,1,0,1,0,0,0,0,1,0,0,1,0,1,0,0,0,1,0,0]
    # dsft = [2,2,2,2,2,5,5,5,5,5,3,3,3,2,2,4,4,4,4,4,4]

    spike_seq = [0,0,1,0,0,1,0,0,0,1,0,0,0,1,0,0,1,0,1,0,0,1,0,0,0,1,0,0,1,0,0,1,0]
    dsft_seq = [3,3,3,3,3,4,4,4,4,4,4,4,4,3,3,3,2,2,3,3,3,4,4,4,4,3,3,3,3,3,3,3,3]

    # Shape both sequences as (b=1, T, h=1, w=1).
    spike = torch.tensor(spike_seq, device='cpu', dtype=torch.float32)[None, :, None, None]
    dsft = torch.tensor(dsft_seq, device='cpu', dtype=torch.float32)[None, :, None, None]

    dsft_dict = convert_dsft4(dsft=dsft, spike=spike)
    for position, key in enumerate(('dsft11', 'dsft12', 'dsft21', 'dsft22')):
        if position:
            print()
        print(dsft_dict[key][0, :, 0, 0])
|
@@ -0,0 +1,44 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
|
4
|
+
class MODF(nn.Module):
    """Fuse the four DSFT variants into one `base_dim`-channel feature map.

    Each input becomes k / dsft with k = 1, 2, 2, 3 for dsft11, dsft12, dsft21
    and dsft22 (presumably a spike-rate estimate — confirm). All four pass
    through the shared stem `conv1`. The three non-(1,1) branches then get
    their own conv stack, are concatenated on channels, fused by `conv_fuse`,
    and added to the (1,1) branch as a residual.
    """

    def __init__(self, base_dim=64, act=nn.ReLU()):
        super().__init__()
        self.base_dim = base_dim
        dim = self.base_dim

        # Creation order (conv1, conv_for_others, conv_fuse) kept for identical init.
        self.conv1 = self._make_layer(input_dim=21, hidden_dim=dim, output_dim=dim, act=act)
        self.conv_for_others = nn.ModuleList(
            [self._make_layer(input_dim=dim, hidden_dim=dim, output_dim=dim, act=act) for _ in range(3)]
        )
        self.conv_fuse = self._make_layer(input_dim=dim * 3, hidden_dim=dim, output_dim=dim, act=act)

    def _make_layer(self, input_dim, hidden_dim, output_dim, act):
        """conv3x3 (input_dim -> hidden_dim) -> act -> conv3x3 (hidden_dim -> output_dim)."""
        return nn.Sequential(
            nn.Conv2d(input_dim, hidden_dim, kernel_size=3, padding=1),
            act,
            nn.Conv2d(hidden_dim, output_dim, kernel_size=3, padding=1),
        )

    def forward(self, dsft_dict):
        scales = (('dsft11', 1.0), ('dsft12', 2.0), ('dsft21', 2.0), ('dsft22', 3.0))
        rates = [numerator / dsft_dict[key] for key, numerator in scales]

        # Shared stem over all four inputs at once, then split back per input.
        stem = self.conv1(torch.cat(rates, dim=0))
        feat_11, *others = stem.chunk(4, dim=0)

        processed = [conv(feat) for conv, feat in zip(self.conv_for_others, others)]
        fused = self.conv_fuse(torch.cat(processed, dim=1))
        return feat_11 + fused
|
44
|
+
|
@@ -0,0 +1,62 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import torch
|
3
|
+
|
4
|
+
class DSFT:
    """Convert binary spike streams of size spike_h x spike_w into DSFT frames on `device`."""

    def __init__(self, spike_h, spike_w, device):
        self.spike_h = spike_h
        self.spike_w = spike_w
        self.device = device

    def spikes2images(self, spikes, max_search_half_window=20):
        """Convert a whole spike stream into a sequence of DSFT frames.

        Args:
            spikes: numpy array of shape (T, H, W), integer or float dtype;
                a value of exactly 1 marks a spike.
            max_search_half_window: number of frames trimmed from each end of
                the output. The interval computation itself uses the whole
                stream; it is not limited to this window.

        Returns:
            numpy uint8 array of shape (T - 2 * max_search_half_window, H, W).
            Between two spikes at p < s, frames p+1 .. s hold s - p. The value
            255 marks frames with no spike strictly before them or no spike at
            or after them. All values are clipped to [0, 255].

        Raises:
            ValueError: if T < 2 * max_search_half_window.
        """
        T = spikes.shape[0]
        half_win = max_search_half_window
        T_im = T - 2 * half_win

        if T_im < 0:
            raise ValueError('The length of spike stream {:d} is not enough for max_search half window length {:d}'.format(T, max_search_half_window))

        spikes = torch.from_numpy(spikes).to(self.device).float()
        shape = [T, self.spike_h, self.spike_w]

        # cur_idx[t]: index of the last spike at or before t (-1 if none).
        # pre_idx[t]: index of the last spike at or before t-1 (-1 if none).
        pre_idx = torch.full(shape, -1.0, dtype=torch.float32, device=self.device)
        cur_idx = torch.full(shape, -1.0, dtype=torch.float32, device=self.device)
        for ii in range(T):
            if ii > 0:
                pre_idx[ii] = cur_idx[ii - 1]
                cur_idx[ii] = cur_idx[ii - 1]
            cur_idx[ii][spikes[ii] == 1] = ii

        # Non-zero only at spike frames, where it equals the gap to the previous spike.
        diff = cur_idx - pre_idx

        # Backward fill: every frame takes the gap recorded at the next spike.
        interval = torch.full(shape, -1.0, dtype=torch.float32, device=self.device)
        for ii in range(T - 1, -1, -1):
            at_spike = diff[ii] != 0
            interval[ii][at_spike] = diff[ii][at_spike]
            if ii < T - 1:
                interval[ii][~at_spike] = interval[ii + 1][~at_spike]

        # boundary: no later spike, or no earlier spike
        interval[interval == -1] = 255
        interval[pre_idx == -1] = 255

        # for uint8
        interval = torch.clip(interval, 0, 255)

        # Use an explicit stop index: `x[w:-w]` is empty when w == 0.
        return interval[half_win:T - half_win].cpu().detach().numpy().astype(np.uint8)
|
@@ -0,0 +1,135 @@
|
|
1
|
+
import os
|
2
|
+
import os.path as osp
|
3
|
+
import argparse
|
4
|
+
import cv2
|
5
|
+
import numpy as np
|
6
|
+
from io_utils import *
|
7
|
+
import h5py
|
8
|
+
from tqdm import *
|
9
|
+
from DSFT import DSFT
|
10
|
+
|
11
|
+
parser = argparse.ArgumentParser()

# Input dataset root and output directory for the cropped data.
parser.add_argument("--root", default="/data/rzhao/REDS120fps", type=str)
parser.add_argument("--output_path", default="/data/rzhao/REDS120fps/crop", type=str)

# Parameters that select the spike sub-directory "eta_*_gamma_*_alpha_*".
parser.add_argument("--eta", default=1.0, type=float)
parser.add_argument("--gamma", default=60, type=int)
parser.add_argument("--alpha", default=0.7, type=float)

# GPU selection, exported as CUDA_VISIBLE_DEVICES below.
parser.add_argument("--cu", '-c', default='0', type=str)

# If set, crop the PNG frames under <root>/imgs instead of the spikes / DSFT.
parser.add_argument("--crop_image", action='store_true')

args = parser.parse_args()

os.environ.update(CUDA_VISIBLE_DEVICES=args.cu)
|
25
|
+
|
26
|
+
|
27
|
+
def _tile_slices(ii, jj, size=256):
    """Return (row_slice, col_slice) for tile (ii, jj) of the 3 x 4 crop grid.

    Tiles are `size` x `size`. The last tile row (ii == 2) and last tile column
    (jj == 3) are anchored to the bottom / right edge of the frame, so they
    overlap their neighbour when the frame is not an exact multiple of `size`.
    """
    rows = slice(-size, None) if ii == 2 else slice(size * ii, size * (ii + 1))
    cols = slice(-size, None) if jj == 3 else slice(size * jj, size * (jj + 1))
    return rows, cols


if __name__ == '__main__':
    imgs_path = osp.join(args.root, 'imgs', 'train')
    param_dir = "eta_{:.2f}_gamma_{:d}_alpha_{:.1f}".format(args.eta, args.gamma, args.alpha)
    spks_path = osp.join(args.root, 'spikes', 'train', param_dir)

    # Frames trimmed from each end of the spike stream by the DSFT conversion;
    # also used for dsft_idx_offset below so the two can never disagree.
    dsft_half_window = 100
    dsft_solver = DSFT(spike_h=720, spike_w=1280, device='cuda')

    # The 12 tiles in (ii, jj) order, named '00' .. '11'.
    tiles = [('{:02}'.format(ii * 4 + jj), _tile_slices(ii, jj))
             for ii in range(3) for jj in range(4)]

    scene_list = sorted(os.listdir(spks_path))
    for scene in tqdm(scene_list):
        scene_imgs_path = osp.join(imgs_path, scene)
        scene_spks_path = osp.join(spks_path, scene)

        # ---- Image mode: crop the PNG frames only, then move to the next scene.
        if args.crop_image:
            for im_idx in range(11, 28 + 1):
                img_file = osp.join(scene_imgs_path, '{:08d}.png'.format(im_idx))
                img = cv2.imread(img_file)
                if img is None:  # cv2.imread returns None instead of raising
                    raise FileNotFoundError('Cannot read image: {:s}'.format(img_file))
                # 1. central crop
                crop_img = img[32:-32, 128:-128]
                # 2. tile crop
                for tile_name, (rows, cols) in tiles:
                    cur_save_root = osp.join(args.output_path, 'train', 'imgs', scene, tile_name)
                    os.makedirs(cur_save_root, exist_ok=True)
                    cur_save_path = osp.join(cur_save_root, '{:08d}.png'.format(im_idx))
                    if osp.exists(cur_save_path):
                        os.remove(cur_save_path)
                    cv2.imwrite(cur_save_path, crop_img[rows, cols])
            continue

        # ---- Spike / DSFT mode.
        # Read all .dat files except the first (00000000.dat), which is skipped;
        # spike_idx_offset below accounts for it.
        dat_names = sorted(os.listdir(scene_spks_path))
        spikes = np.concatenate(
            [dat_to_spmat(dat_path=osp.join(scene_spks_path, dat_name), size=(720, 1280))
             for dat_name in dat_names[1:]],
            axis=0)
        dsft = dsft_solver.spikes2images(spikes, max_search_half_window=dsft_half_window)

        # crop spikes
        spike_idx_offset = 10  # frames in the skipped 00000000.dat (assumed 10 per file)
        # 1. central crop
        spikes = spikes[:, 32:-32, 128:-128]
        # 2. tile crop
        for spk_idx in range(11, 28 + 1):
            start = spk_idx * 10 - spike_idx_offset
            crop_spike = spikes[start:start + 10]
            for tile_name, (rows, cols) in tiles:
                cur_save_root = osp.join(args.output_path, 'train', param_dir, scene, tile_name, 'spikes')
                os.makedirs(cur_save_root, exist_ok=True)
                cur_save_path = osp.join(cur_save_root, '{:08d}.dat'.format(spk_idx))
                # SpikeToRaw opens its output in append mode, so clear old files first.
                if osp.exists(cur_save_path):
                    os.remove(cur_save_path)
                SpikeToRaw(SpikeSeq=crop_spike[:, rows, cols], save_path=cur_save_path)

        # crop dsft: DSFT frame 0 corresponds to spike frame dsft_half_window.
        dsft_idx_offset = spike_idx_offset + dsft_half_window
        # 1. central crop
        dsft = dsft[:, 32:-32, 128:-128]
        # 2. tile crop
        for dsft_idx in range(11, 28 + 1):
            start = dsft_idx * 10 - dsft_idx_offset
            crop_dsft = dsft[start:start + 10]
            for tile_name, (rows, cols) in tiles:
                cur_save_root = osp.join(args.output_path, 'train', param_dir, scene, tile_name, 'dsft')
                os.makedirs(cur_save_root, exist_ok=True)
                cur_save_path = osp.join(cur_save_root, '{:08d}.h5'.format(dsft_idx))
                if osp.exists(cur_save_path):
                    os.remove(cur_save_path)
                with h5py.File(cur_save_path, 'w') as f:
                    f['dsft'] = crop_dsft[:, rows, cols]
|
@@ -0,0 +1,139 @@
|
|
1
|
+
import os
|
2
|
+
import os.path as osp
|
3
|
+
import argparse
|
4
|
+
import cv2
|
5
|
+
import numpy as np
|
6
|
+
from io_utils import *
|
7
|
+
import h5py
|
8
|
+
from tqdm import *
|
9
|
+
from DSFT import DSFT
|
10
|
+
|
11
|
+
parser = argparse.ArgumentParser()

# Input dataset root and output directory for the cropped data.
parser.add_argument("--root", default="/data/rzhao/REDS120fps", type=str)
parser.add_argument("--output_path", default="/data/rzhao/REDS120fps/crop", type=str)

# Parameters that select the spike sub-directory "eta_*_gamma_*_alpha_*".
parser.add_argument("--eta", default=1.00, type=float)
parser.add_argument("--gamma", default=60, type=int)
parser.add_argument("--alpha", default=0.7, type=float)

# GPU selection, exported as CUDA_VISIBLE_DEVICES below.
parser.add_argument("--cu", '-c', default='0', type=str)

# If set, crop the PNG frames under <root>/imgs instead of the spikes / DSFT.
parser.add_argument("--crop_image", action='store_true')

args = parser.parse_args()

os.environ.update(CUDA_VISIBLE_DEVICES=args.cu)
|
25
|
+
|
26
|
+
def _quadrant_slices(sub_scene_idx, crop_h=384, crop_w=512):
    """Return (row_slice, col_slice) for validation sub-scene 0..3.

    0: top-left, 1: bottom-left, 2: top-right, 3: bottom-right; each is
    crop_h x crop_w, anchored to the corresponding frame corner.
    """
    rows = slice(None, crop_h) if sub_scene_idx in (0, 2) else slice(-crop_h, None)
    cols = slice(None, crop_w) if sub_scene_idx in (0, 1) else slice(-crop_w, None)
    return rows, cols


if __name__ == '__main__':
    imgs_path = osp.join(args.root, 'imgs', 'val')
    param_dir = "eta_{:.2f}_gamma_{:d}_alpha_{:.1f}".format(args.eta, args.gamma, args.alpha)
    spks_path = osp.join(args.root, 'spikes', 'val', param_dir)

    # Frames trimmed from each end of the spike stream by the DSFT conversion;
    # also used for dsft_idx_offset below so the two can never disagree.
    dsft_half_window = 100
    dsft_solver = DSFT(spike_h=720, spike_w=1280, device='cuda')

    quadrants = [(sub_scene_idx, _quadrant_slices(sub_scene_idx)) for sub_scene_idx in range(4)]

    scene_list = sorted(os.listdir(spks_path))
    for scene in tqdm(scene_list):
        scene_imgs_path = osp.join(imgs_path, scene)
        scene_spks_path = osp.join(spks_path, scene)

        # ---- Image mode: crop the PNG frames only, then move to the next scene.
        if args.crop_image:
            for im_idx in range(11, 28 + 1):
                img_file = osp.join(scene_imgs_path, '{:08d}.png'.format(im_idx))
                img = cv2.imread(img_file)
                if img is None:  # cv2.imread returns None instead of raising
                    raise FileNotFoundError('Cannot read image: {:s}'.format(img_file))
                # 1. central crop
                crop_img = img[32:-32, 128:-128]
                # 2. quadrant crop
                for sub_scene_idx, (rows, cols) in quadrants:
                    cur_scene = '{:s}_{:d}'.format(scene, sub_scene_idx)
                    cur_save_root = osp.join(args.output_path, 'val_small', 'imgs', cur_scene)
                    os.makedirs(cur_save_root, exist_ok=True)
                    cur_save_path = osp.join(cur_save_root, '{:08d}.png'.format(im_idx))
                    if osp.exists(cur_save_path):
                        os.remove(cur_save_path)
                    cv2.imwrite(cur_save_path, crop_img[rows, cols])
            continue

        # ---- Spike / DSFT mode.
        # Read all .dat files except the first (00000000.dat), which is skipped;
        # spike_idx_offset below accounts for it.
        dat_names = sorted(os.listdir(scene_spks_path))
        spikes = np.concatenate(
            [dat_to_spmat(dat_path=osp.join(scene_spks_path, dat_name), size=(720, 1280))
             for dat_name in dat_names[1:]],
            axis=0)
        dsft = dsft_solver.spikes2images(spikes, max_search_half_window=dsft_half_window)

        # crop spikes
        spike_idx_offset = 10  # frames in the skipped 00000000.dat (assumed 10 per file)
        # 1. central crop
        spikes = spikes[:, 32:-32, 128:-128]
        # 2. quadrant crop
        for spk_idx in range(11, 28 + 1):
            start = spk_idx * 10 - spike_idx_offset
            crop_spike = spikes[start:start + 10]
            for sub_scene_idx, (rows, cols) in quadrants:
                cur_scene = '{:s}_{:d}'.format(scene, sub_scene_idx)
                cur_save_root = osp.join(args.output_path, 'val_small', param_dir, cur_scene, 'spikes')
                os.makedirs(cur_save_root, exist_ok=True)
                cur_save_path = osp.join(cur_save_root, '{:08d}.dat'.format(spk_idx))
                # SpikeToRaw opens its output in append mode, so clear old files first.
                if osp.exists(cur_save_path):
                    os.remove(cur_save_path)
                SpikeToRaw(SpikeSeq=crop_spike[:, rows, cols], save_path=cur_save_path)

        # crop dsft: DSFT frame 0 corresponds to spike frame dsft_half_window.
        dsft_idx_offset = spike_idx_offset + dsft_half_window
        # 1. central crop
        dsft = dsft[:, 32:-32, 128:-128]
        # 2. quadrant crop
        for dsft_idx in range(11, 28 + 1):
            start = dsft_idx * 10 - dsft_idx_offset
            crop_dsft = dsft[start:start + 10]
            for sub_scene_idx, (rows, cols) in quadrants:
                cur_scene = '{:s}_{:d}'.format(scene, sub_scene_idx)
                cur_save_root = osp.join(args.output_path, 'val_small', param_dir, cur_scene, 'dsft')
                os.makedirs(cur_save_root, exist_ok=True)
                cur_save_path = osp.join(cur_save_root, '{:08d}.h5'.format(dsft_idx))
                if osp.exists(cur_save_path):
                    os.remove(cur_save_path)
                with h5py.File(cur_save_path, 'w') as f:
                    f['dsft'] = crop_dsft[:, rows, cols]
|
139
|
+
|
@@ -0,0 +1,64 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import os
|
3
|
+
import os.path as osp
|
4
|
+
|
5
|
+
|
6
|
+
def RawToSpike(video_seq, h, w, flipud=True):
    """Decode a packed spike byte stream into 0/1 frames.

    Each frame occupies (h*w)//8 bytes. Pixel k (row-major) is bit k % 8
    (least-significant first) of byte k // 8 within its frame.

    Args:
        video_seq: 1-D byte sequence; converted to uint8.
        h, w: frame height and width.
        flipud: if True, flip every decoded frame vertically.

    Returns:
        uint8 array of shape (len(video_seq) // ((h*w)//8), h, w).
    """
    raw = np.array(video_seq).astype(np.uint8)
    n_pixels = h * w
    bytes_per_frame = n_pixels // 8
    n_frames = len(raw) // bytes_per_frame
    frames = np.zeros([n_frames, h, w], np.uint8)

    pixel_index = np.arange(0, h * w).reshape(h, w)
    byte_index = pixel_index // 8
    bit_mask = np.left_shift(1, pixel_index % 8)

    for frame in np.arange(n_frames):
        start = frame * n_pixels // 8
        chunk = raw[start:start + bytes_per_frame]
        hits = np.bitwise_and(chunk[byte_index], bit_mask) == bit_mask
        frames[frame, :, :] = np.flipud(hits) if flipud else hits

    return frames
|
28
|
+
|
29
|
+
def dat_to_spmat(dat_path, size, flipud=True):
    """Read a packed .dat spike file and decode it with RawToSpike.

    Args:
        dat_path: path to the packed spike file.
        size: (height, width) of each spike frame.
        flipud: forwarded to RawToSpike; defaults to True, matching the
            previous behaviour.

    Returns:
        uint8 array of shape (frames, height, width).
    """
    # Context manager: the handle is closed as soon as the bytes are read.
    with open(dat_path, 'rb') as f:
        video_seq = f.read()
    video_seq = np.frombuffer(video_seq, 'b')
    return RawToSpike(video_seq, size[0], size[1], flipud=flipud)
|
35
|
+
|
36
|
+
|
37
|
+
## Save Raw dat files
|
38
|
+
def SpikeToRaw(SpikeSeq, save_path):
    """Append a 0/1 spike sequence to `save_path` in packed .dat format.

    Args:
        SpikeSeq: numpy array (sfn x h x w) of 0/1 values; h*w must be a
            multiple of 8.
        save_path: full output path. The file is opened in append mode, so
            callers must delete an existing file first to replace it.
    """
    sfn, h, w = SpikeSeq.shape
    # Bit weights 1, 2, 4, ..., 128: pixel k of each 8-pixel group maps to bit k.
    base = np.power(2, np.linspace(0, 7, 8))
    with open(save_path, 'ab') as fid:
        for img_id in range(sfn):
            # Flip vertically to mimic the camera's inverted image.
            spike = np.flipud(SpikeSeq[img_id, :, :])
            # numpy flattens row-major, matching the row order of the stored data.
            packed = spike.flatten().reshape([h * w // 8, 8]) * base
            fid.write(np.sum(packed, axis=1).astype(np.uint8).tobytes())
|
59
|
+
|
60
|
+
|
61
|
+
def save_to_h5(SpikeMatrix, h5path, name):
    """Write `SpikeMatrix` as dataset `name` in an HDF5 file at `h5path`.

    The file is opened with mode 'w', so an existing file is truncated.
    """
    # h5py is not imported at module level in this file; without this local
    # import the function raises NameError.
    import h5py
    with h5py.File(h5path, 'w') as f:
        f[name] = SpikeMatrix
|