spikezoo 0.2.2__py3-none-any.whl → 0.2.3.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- spikezoo/__init__.py +23 -7
- spikezoo/archs/bsf/models/bsf/bsf.py +37 -25
- spikezoo/archs/bsf/models/bsf/rep.py +2 -2
- spikezoo/archs/spk2imgnet/nets.py +1 -1
- spikezoo/archs/ssir/models/networks.py +1 -1
- spikezoo/archs/ssml/model.py +9 -5
- spikezoo/archs/stir/metrics/losses.py +1 -1
- spikezoo/archs/stir/models/networks_STIR.py +16 -9
- spikezoo/archs/tfi/nets.py +1 -1
- spikezoo/archs/tfp/nets.py +1 -1
- spikezoo/archs/wgse/dwtnets.py +6 -6
- spikezoo/datasets/__init__.py +11 -9
- spikezoo/datasets/base_dataset.py +10 -3
- spikezoo/datasets/realworld_dataset.py +1 -3
- spikezoo/datasets/{reds_small_dataset.py → reds_base_dataset.py} +9 -8
- spikezoo/datasets/reds_ssir_dataset.py +181 -0
- spikezoo/datasets/szdata_dataset.py +5 -15
- spikezoo/datasets/uhsr_dataset.py +4 -3
- spikezoo/models/__init__.py +8 -6
- spikezoo/models/base_model.py +120 -64
- spikezoo/models/bsf_model.py +11 -3
- spikezoo/models/spcsnet_model.py +19 -0
- spikezoo/models/spikeclip_model.py +4 -3
- spikezoo/models/spk2imgnet_model.py +9 -15
- spikezoo/models/ssir_model.py +4 -6
- spikezoo/models/ssml_model.py +44 -2
- spikezoo/models/stir_model.py +26 -5
- spikezoo/models/tfi_model.py +3 -1
- spikezoo/models/tfp_model.py +4 -2
- spikezoo/models/wgse_model.py +8 -14
- spikezoo/pipeline/base_pipeline.py +79 -55
- spikezoo/pipeline/ensemble_pipeline.py +10 -9
- spikezoo/pipeline/train_cfgs.py +89 -0
- spikezoo/pipeline/train_pipeline.py +129 -30
- spikezoo/utils/optimizer_utils.py +22 -0
- spikezoo/utils/other_utils.py +31 -6
- spikezoo/utils/scheduler_utils.py +25 -0
- spikezoo/utils/spike_utils.py +61 -29
- spikezoo-0.2.3.2.dist-info/METADATA +263 -0
- {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.2.dist-info}/RECORD +43 -80
- spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
- spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/CheckPoints/readme +0 -1
- spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +0 -60
- spikezoo/archs/spikeformer/DataProcess/DataLoader.py +0 -115
- spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +0 -39
- spikezoo/archs/spikeformer/EvalResults/readme +0 -1
- spikezoo/archs/spikeformer/LICENSE +0 -21
- spikezoo/archs/spikeformer/Metrics/Metrics.py +0 -50
- spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/Loss.py +0 -89
- spikezoo/archs/spikeformer/Model/SpikeFormer.py +0 -230
- spikezoo/archs/spikeformer/Model/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/README.md +0 -30
- spikezoo/archs/spikeformer/evaluate.py +0 -87
- spikezoo/archs/spikeformer/recon_real_data.py +0 -97
- spikezoo/archs/spikeformer/requirements.yml +0 -95
- spikezoo/archs/spikeformer/train.py +0 -173
- spikezoo/archs/spikeformer/utils.py +0 -22
- spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
- spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/models/spikeformer_model.py +0 -50
- spikezoo-0.2.2.dist-info/METADATA +0 -196
- {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.2.dist-info}/LICENSE.txt +0 -0
- {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.2.dist-info}/WHEEL +0 -0
- {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.2.dist-info}/top_level.txt +0 -0
spikezoo/__init__.py
CHANGED
@@ -1,13 +1,29 @@
|
|
1
|
-
from .utils.spike_utils import
|
1
|
+
from .utils.spike_utils import *
|
2
2
|
from .models import model_list
|
3
3
|
from .datasets import dataset_list
|
4
4
|
from .metrics import metric_all_names
|
5
5
|
|
6
|
-
|
7
|
-
|
6
|
+
# METHOD NAME DEFINITION
|
7
|
+
METHODS = model_list
|
8
|
+
class METHOD:
|
9
|
+
BASE = "base"
|
10
|
+
TFP = "tfp"
|
11
|
+
TFI = "tfi"
|
12
|
+
SPK2IMGNET = "spk2imgnet"
|
13
|
+
WGSE = "wgse"
|
14
|
+
SSML = "ssml"
|
15
|
+
BSF = "bsf"
|
16
|
+
STIR = "stir"
|
17
|
+
SSIR = "ssir"
|
18
|
+
SPIKECLIP = "spikeclip"
|
8
19
|
|
9
|
-
|
10
|
-
|
20
|
+
# DATASET NAME DEFINITION
|
21
|
+
DATASETS = dataset_list
|
22
|
+
class DATASET:
|
23
|
+
BASE = "base"
|
24
|
+
REDS_BASE = "reds_base"
|
25
|
+
REALWORLD = "realworld"
|
26
|
+
UHSR = "uhsr"
|
11
27
|
|
12
|
-
|
13
|
-
|
28
|
+
# METRIC NAME DEFINITION
|
29
|
+
METRICS = metric_all_names
|
@@ -8,18 +8,18 @@ from .align import Multi_Granularity_Align
|
|
8
8
|
class BasicModel(nn.Module):
|
9
9
|
def __init__(self):
|
10
10
|
super().__init__()
|
11
|
-
|
11
|
+
|
12
12
|
####################################################################################
|
13
13
|
## Tools functions for neural networks
|
14
14
|
def weight_parameters(self):
|
15
|
-
return [param for name, param in self.named_parameters() if
|
15
|
+
return [param for name, param in self.named_parameters() if "weight" in name]
|
16
16
|
|
17
17
|
def bias_parameters(self):
|
18
|
-
return [param for name, param in self.named_parameters() if
|
18
|
+
return [param for name, param in self.named_parameters() if "bias" in name]
|
19
19
|
|
20
20
|
def num_parameters(self):
|
21
21
|
return sum([p.data.nelement() if p.requires_grad else 0 for p in self.parameters()])
|
22
|
-
|
22
|
+
|
23
23
|
def init_weights(self):
|
24
24
|
for layer in self.named_modules():
|
25
25
|
if isinstance(layer, nn.Conv2d):
|
@@ -33,12 +33,21 @@ class BasicModel(nn.Module):
|
|
33
33
|
nn.init.constant_(layer.bias, 0)
|
34
34
|
|
35
35
|
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
36
|
+
from typing import Literal
|
37
|
+
|
38
|
+
|
39
|
+
def split_and_b_cat(x, spike_dim: Literal[41, 61] = 61):
|
40
|
+
if spike_dim == 61:
|
41
|
+
win_r = 10
|
42
|
+
win_step = 10
|
43
|
+
elif spike_dim == 41:
|
44
|
+
win_r = 6
|
45
|
+
win_step = 7
|
46
|
+
x0 = x[:, 0 : 2 * win_r + 1, :, :].clone()
|
47
|
+
x1 = x[:, win_step : win_step + 2 * win_r + 1, :, :].clone()
|
48
|
+
x2 = x[:, 2 * win_step : 2 * win_step + 2 * win_r + 1, :, :].clone()
|
49
|
+
x3 = x[:, 3 * win_step : 3 * win_step + 2 * win_r + 1, :, :].clone()
|
50
|
+
x4 = x[:, 4 * win_step : 4 * win_step + 2 * win_r + 1, :, :].clone()
|
42
51
|
return torch.cat([x0, x1, x2, x3, x4], dim=0)
|
43
52
|
|
44
53
|
|
@@ -61,39 +70,42 @@ class Encoder(nn.Module):
|
|
61
70
|
x = self.act(conv(x) + x)
|
62
71
|
return x
|
63
72
|
|
73
|
+
|
64
74
|
##########################################################################
|
65
75
|
class BSF(BasicModel):
|
66
|
-
def __init__(self, act=nn.ReLU()):
|
76
|
+
def __init__(self, spike_dim=61, act=nn.ReLU()):
|
67
77
|
super().__init__()
|
78
|
+
self.spike_dim = spike_dim
|
68
79
|
self.offset_groups = 4
|
69
80
|
self.corr_max_disp = 3
|
70
|
-
|
71
|
-
|
72
|
-
|
81
|
+
if spike_dim == 61:
|
82
|
+
self.rep = MODF(in_dim=21,base_dim=64, act=act)
|
83
|
+
elif spike_dim == 41:
|
84
|
+
self.rep = MODF(in_dim=13,base_dim=64, act=act)
|
73
85
|
self.encoder = Encoder(base_dim=64, layers=4, act=act)
|
74
86
|
|
75
87
|
self.align = Multi_Granularity_Align(base_dim=64, groups=self.offset_groups, act=act, sc=3)
|
76
88
|
|
77
89
|
self.recons = nn.Sequential(
|
78
|
-
nn.Conv2d(64*5, 64*3, kernel_size=3, padding=1),
|
90
|
+
nn.Conv2d(64 * 5, 64 * 3, kernel_size=3, padding=1),
|
79
91
|
act,
|
80
|
-
nn.Conv2d(64*3, 64, kernel_size=3, padding=1),
|
92
|
+
nn.Conv2d(64 * 3, 64, kernel_size=3, padding=1),
|
81
93
|
act,
|
82
94
|
nn.Conv2d(64, 1, kernel_size=3, padding=1),
|
83
95
|
)
|
84
96
|
|
85
97
|
def forward(self, input_dict):
|
86
|
-
dsft_dict = input_dict[
|
87
|
-
dsft11 = dsft_dict[
|
88
|
-
dsft12 = dsft_dict[
|
89
|
-
dsft21 = dsft_dict[
|
90
|
-
dsft22 = dsft_dict[
|
98
|
+
dsft_dict = input_dict["dsft_dict"]
|
99
|
+
dsft11 = dsft_dict["dsft11"]
|
100
|
+
dsft12 = dsft_dict["dsft12"]
|
101
|
+
dsft21 = dsft_dict["dsft21"]
|
102
|
+
dsft22 = dsft_dict["dsft22"]
|
91
103
|
|
92
104
|
dsft_b_cat = {
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
105
|
+
"dsft11": split_and_b_cat(dsft11, self.spike_dim),
|
106
|
+
"dsft12": split_and_b_cat(dsft12, self.spike_dim),
|
107
|
+
"dsft21": split_and_b_cat(dsft21, self.spike_dim),
|
108
|
+
"dsft22": split_and_b_cat(dsft22, self.spike_dim),
|
97
109
|
}
|
98
110
|
|
99
111
|
feat_b_cat = self.rep(dsft_b_cat)
|
@@ -2,11 +2,11 @@ import torch
|
|
2
2
|
import torch.nn as nn
|
3
3
|
|
4
4
|
class MODF(nn.Module):
|
5
|
-
def __init__(self, base_dim=64, act=nn.ReLU()):
|
5
|
+
def __init__(self, in_dim = 21, base_dim=64, act=nn.ReLU()):
|
6
6
|
super().__init__()
|
7
7
|
self.base_dim = base_dim
|
8
8
|
|
9
|
-
self.conv1 = self._make_layer(input_dim=
|
9
|
+
self.conv1 = self._make_layer(input_dim=in_dim, hidden_dim=self.base_dim, output_dim=self.base_dim, act=act)
|
10
10
|
self.conv_for_others = nn.ModuleList([
|
11
11
|
self._make_layer(input_dim=self.base_dim, hidden_dim=self.base_dim, output_dim=self.base_dim, act=act) for ii in range(3)
|
12
12
|
])
|
@@ -167,7 +167,7 @@ class FusionMaskV1(nn.Module):
|
|
167
167
|
|
168
168
|
# current best model
|
169
169
|
class SpikeNet(nn.Module):
|
170
|
-
def __init__(self, in_channels, features, out_channels, win_r, win_step):
|
170
|
+
def __init__(self, in_channels = 13, features = 64, out_channels = 1, win_r = 6, win_step = 7):
|
171
171
|
super(SpikeNet, self).__init__()
|
172
172
|
self.extractor = FeatureExtractor(
|
173
173
|
in_channels=in_channels,
|
spikezoo/archs/ssml/model.py
CHANGED
@@ -272,18 +272,22 @@ class BSN(nn.Module):
|
|
272
272
|
diff = W - H
|
273
273
|
x0 = x0[:, :, (diff // 2):(diff // 2 + H), 0:W]
|
274
274
|
|
275
|
-
return x0
|
275
|
+
return x0
|
276
276
|
|
277
277
|
class DoubleNet(nn.Module):
|
278
278
|
def __init__(self):
|
279
279
|
super().__init__()
|
280
280
|
self.nbsn = BSN(n_channels=41, n_output=1,blind=False)
|
281
|
-
|
281
|
+
self.bsn = BSN(n_channels=41, n_output=1,blind=True)
|
282
282
|
|
283
283
|
def forward(self, x):
|
284
|
-
|
285
|
-
|
286
|
-
|
284
|
+
if self.training:
|
285
|
+
bsn_pred = self.bsn(x)
|
286
|
+
nbsn_pred = self.nbsn(x)
|
287
|
+
return bsn_pred,nbsn_pred
|
288
|
+
else:
|
289
|
+
nbsn_pred = self.nbsn(x)
|
290
|
+
return nbsn_pred
|
287
291
|
|
288
292
|
if __name__ == '__main__':
|
289
293
|
a=DoubleNet().cuda()
|
@@ -292,16 +292,21 @@ class STIRDecorder(nn.Module):#second and third levels
|
|
292
292
|
|
293
293
|
##############################Our Model####################################
|
294
294
|
class STIR(BasicModel):
|
295
|
-
def __init__(self, hidd_chs=8, win_r=6, win_step=7):
|
295
|
+
def __init__(self, spike_dim = 61,hidd_chs=8, win_r=6, win_step=7):
|
296
296
|
super().__init__()
|
297
297
|
|
298
298
|
self.init_chs = [16, 24, 32, 64, 96]
|
299
299
|
self.hidd_chs = hidd_chs
|
300
|
+
self.spike_dim = spike_dim
|
300
301
|
self.attn_num_splits = 1
|
301
302
|
|
302
303
|
self.N_group = 3
|
303
|
-
|
304
|
-
|
304
|
+
if spike_dim == 61:
|
305
|
+
self.resnet = ResidualBlock(in_channles=21, num_channles=11, use_1x1conv=True)
|
306
|
+
dim_tfp = 16 # 5 + num_channels
|
307
|
+
elif spike_dim == 41:
|
308
|
+
self.resnet = ResidualBlock(in_channles=15, num_channles=11, use_1x1conv=True)
|
309
|
+
dim_tfp = 15 # 4 + num_channels
|
305
310
|
self.encoder = ImageEncoder(in_chs=dim_tfp, init_chs=self.init_chs)
|
306
311
|
|
307
312
|
self.transformer = CrossTransformerBlock(dim=self.init_chs[-1], num_heads=4, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias')
|
@@ -314,14 +319,16 @@ class STIR(BasicModel):
|
|
314
319
|
self.win_r = win_r
|
315
320
|
self.win_step = win_step
|
316
321
|
|
317
|
-
self.resnet = ResidualBlock(in_channles=21, num_channles=11, use_1x1conv=True)
|
318
|
-
|
319
322
|
def forward(self, x):
|
320
323
|
b,_,h,w=x.size()
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
324
|
+
if self.spike_dim == 61:
|
325
|
+
block1 = x[:, 0 : 21, :, :]
|
326
|
+
block2 = x[:, 20 : 41, :, :]
|
327
|
+
block3 = x[:, 40 : 61, :, :]
|
328
|
+
elif self.spike_dim == 41:
|
329
|
+
block1 = x[:, 0 : 15, :, :]
|
330
|
+
block2 = x[:, 13 : 28, :, :]
|
331
|
+
block3 = x[:, 26 : 41, :, :]
|
325
332
|
|
326
333
|
repre1 = TFP(block1, channel_step=2)#C: 5
|
327
334
|
repre2 = TFP(block2, channel_step=2)
|
spikezoo/archs/tfi/nets.py
CHANGED
spikezoo/archs/tfp/nets.py
CHANGED
spikezoo/archs/wgse/dwtnets.py
CHANGED
@@ -94,15 +94,15 @@ class Dwt1dModule_Tcn(nn.Module):
|
|
94
94
|
class Dwt1dResnetX_TCN(nn.Module):
|
95
95
|
def __init__(
|
96
96
|
self,
|
97
|
-
wvlname='
|
98
|
-
J=
|
99
|
-
yl_size=
|
100
|
-
yh_size=[
|
101
|
-
num_residual_blocks=
|
97
|
+
wvlname='db8',
|
98
|
+
J=5,
|
99
|
+
yl_size=15,
|
100
|
+
yh_size=[28, 21, 18, 16, 15],
|
101
|
+
num_residual_blocks=3,
|
102
102
|
norm=None,
|
103
103
|
inc=41,
|
104
104
|
ks=3,
|
105
|
-
store_features=
|
105
|
+
store_features=True
|
106
106
|
):
|
107
107
|
super().__init__()
|
108
108
|
|
spikezoo/datasets/__init__.py
CHANGED
@@ -4,28 +4,30 @@ import importlib, inspect
|
|
4
4
|
import os
|
5
5
|
import torch
|
6
6
|
from typing import Literal
|
7
|
+
from spikezoo.utils.other_utils import getattr_case_insensitive
|
7
8
|
|
8
9
|
# todo auto detect/register datasets
|
9
10
|
files_list = os.listdir(os.path.dirname(os.path.abspath(__file__)))
|
10
11
|
dataset_list = [file.replace("_dataset.py", "") for file in files_list if file.endswith("_dataset.py")]
|
11
12
|
|
13
|
+
|
12
14
|
# todo register function
|
13
15
|
def build_dataset_cfg(cfg: BaseDatasetConfig, split: Literal["train", "test"] = "test"):
|
14
16
|
"""Build the dataset from the given dataset config."""
|
15
17
|
# build new cfg according to split
|
16
|
-
cfg = replace(cfg,split
|
18
|
+
cfg = replace(cfg, split=split, spike_length=cfg.spike_length_train if split == "train" else cfg.spike_length_test)
|
17
19
|
# dataset module
|
18
20
|
module_name = cfg.dataset_name + "_dataset"
|
19
21
|
assert cfg.dataset_name in dataset_list, f"Given dataset {cfg.dataset_name} not in our dataset list {dataset_list}."
|
20
22
|
module_name = "spikezoo.datasets." + module_name
|
21
23
|
module = importlib.import_module(module_name)
|
22
24
|
# dataset,dataset_config
|
23
|
-
|
24
|
-
|
25
|
+
dataset_name = cfg.dataset_name
|
26
|
+
dataset_name = dataset_name + "Dataset" if dataset_name == "base" else dataset_name
|
27
|
+
dataset_cls: BaseDataset = getattr_case_insensitive(module, dataset_name)
|
25
28
|
dataset = dataset_cls(cfg)
|
26
29
|
return dataset
|
27
30
|
|
28
|
-
|
29
31
|
def build_dataset_name(dataset_name: str, split: Literal["train", "test"] = "test"):
|
30
32
|
"""Build the default dataset from the given name."""
|
31
33
|
module_name = dataset_name + "_dataset"
|
@@ -33,21 +35,21 @@ def build_dataset_name(dataset_name: str, split: Literal["train", "test"] = "tes
|
|
33
35
|
module_name = "spikezoo.datasets." + module_name
|
34
36
|
module = importlib.import_module(module_name)
|
35
37
|
# dataset,dataset_config
|
36
|
-
|
37
|
-
dataset_cls: BaseDataset =
|
38
|
-
dataset_cfg: BaseDatasetConfig =
|
38
|
+
dataset_name = dataset_name + "Dataset" if dataset_name == "base" else dataset_name
|
39
|
+
dataset_cls: BaseDataset = getattr_case_insensitive(module, dataset_name)
|
40
|
+
dataset_cfg: BaseDatasetConfig = getattr_case_insensitive(module, dataset_name + "config")(split=split)
|
39
41
|
dataset = dataset_cls(dataset_cfg)
|
40
42
|
return dataset
|
41
43
|
|
42
44
|
|
43
45
|
# todo to modify according to the basicsr
|
44
|
-
def build_dataloader(dataset: BaseDataset,cfg
|
46
|
+
def build_dataloader(dataset: BaseDataset, cfg=None):
|
45
47
|
# train dataloader
|
46
48
|
if dataset.cfg.split == "train":
|
47
49
|
if cfg is None:
|
48
50
|
return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
|
49
51
|
else:
|
50
|
-
return torch.utils.data.DataLoader(dataset, batch_size=cfg.bs_train, shuffle=True, num_workers=cfg.num_workers,pin_memory=cfg.pin_memory)
|
52
|
+
return torch.utils.data.DataLoader(dataset, batch_size=cfg.bs_train, shuffle=True, num_workers=cfg.num_workers, pin_memory=cfg.pin_memory)
|
51
53
|
# test dataloader
|
52
54
|
elif dataset.cfg.split == "test":
|
53
55
|
return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
|
@@ -6,7 +6,7 @@ import numpy as np
|
|
6
6
|
from spikezoo.utils.spike_utils import load_vidar_dat
|
7
7
|
import re
|
8
8
|
from dataclasses import dataclass, replace
|
9
|
-
from typing import Literal,Union
|
9
|
+
from typing import Literal, Union
|
10
10
|
import warnings
|
11
11
|
import torch
|
12
12
|
from tqdm import tqdm
|
@@ -19,7 +19,7 @@ class BaseDatasetConfig:
|
|
19
19
|
"Dataset name."
|
20
20
|
dataset_name: str = "base"
|
21
21
|
"Directory specifying location of data."
|
22
|
-
root_dir: Union[str,Path] = Path(__file__).parent.parent / Path("data/base")
|
22
|
+
root_dir: Union[str, Path] = Path(__file__).parent.parent / Path("data/base")
|
23
23
|
"Image width."
|
24
24
|
width: int = 400
|
25
25
|
"Image height."
|
@@ -46,6 +46,8 @@ class BaseDatasetConfig:
|
|
46
46
|
use_cache: bool = False
|
47
47
|
"Crop size."
|
48
48
|
crop_size: tuple = (-1, -1)
|
49
|
+
"Rate. (-1 denotes variant)"
|
50
|
+
rate: float = 1
|
49
51
|
|
50
52
|
# post process
|
51
53
|
def __post_init__(self):
|
@@ -54,6 +56,7 @@ class BaseDatasetConfig:
|
|
54
56
|
# todo try download
|
55
57
|
assert self.root_dir.exists(), f"No files found in {self.root_dir} for the specified dataset `{self.dataset_name}`."
|
56
58
|
|
59
|
+
|
57
60
|
# todo cache mechanism
|
58
61
|
class BaseDataset(Dataset):
|
59
62
|
def __init__(self, cfg: BaseDatasetConfig):
|
@@ -80,7 +83,11 @@ class BaseDataset(Dataset):
|
|
80
83
|
if self.cfg.use_aug == True and self.cfg.split == "train":
|
81
84
|
spike, img = self.augmentor(spike, img)
|
82
85
|
|
83
|
-
|
86
|
+
# rate
|
87
|
+
rate = self.cfg.rate
|
88
|
+
|
89
|
+
# ! spike and gt_img names are necessary
|
90
|
+
batch = {"spike": spike, "gt_img": img, "rate": rate}
|
84
91
|
return batch
|
85
92
|
|
86
93
|
# todo: To be overridden
|
@@ -4,7 +4,7 @@ from dataclasses import dataclass
|
|
4
4
|
|
5
5
|
|
6
6
|
@dataclass
|
7
|
-
class
|
7
|
+
class RealWorldConfig(BaseDatasetConfig):
|
8
8
|
dataset_name: str = "realworld"
|
9
9
|
root_dir: Path = Path(__file__).parent.parent / Path("data/recVidarReal2019")
|
10
10
|
width: int = 400
|
@@ -21,5 +21,3 @@ class RealWorld(BaseDataset):
|
|
21
21
|
def prepare_data(self):
|
22
22
|
self.spike_dir = self.cfg.root_dir
|
23
23
|
self.spike_list = self.get_spike_files(self.spike_dir)
|
24
|
-
|
25
|
-
|
@@ -4,10 +4,11 @@ from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
|
|
4
4
|
from dataclasses import dataclass
|
5
5
|
import re
|
6
6
|
|
7
|
+
|
7
8
|
@dataclass
|
8
|
-
class
|
9
|
-
dataset_name: str = "
|
10
|
-
root_dir: Path = Path(__file__).parent.parent / Path("data/
|
9
|
+
class REDS_BASEConfig(BaseDatasetConfig):
|
10
|
+
dataset_name: str = "reds_base"
|
11
|
+
root_dir: Path = Path(__file__).parent.parent / Path("data/REDS_BASE")
|
11
12
|
width: int = 400
|
12
13
|
height: int = 250
|
13
14
|
with_img: bool = True
|
@@ -15,13 +16,13 @@ class REDS_Small_Config(BaseDatasetConfig):
|
|
15
16
|
spike_length_test: int = 301
|
16
17
|
spike_dir_name: str = "spike"
|
17
18
|
img_dir_name: str = "gt"
|
19
|
+
rate: float = 0.6
|
20
|
+
|
18
21
|
|
19
|
-
class
|
22
|
+
class REDS_BASE(BaseDataset):
|
20
23
|
def __init__(self, cfg: BaseDatasetConfig):
|
21
|
-
super(
|
24
|
+
super(REDS_BASE, self).__init__(cfg)
|
22
25
|
|
23
26
|
def prepare_data(self):
|
24
27
|
super().prepare_data()
|
25
|
-
|
26
|
-
self.img_list = [self.img_dir / Path(str(s.name).replace('.dat','.png')) for s in self.spike_list]
|
27
|
-
|
28
|
+
self.img_list = [self.img_dir / Path(str(s.name).replace(".dat", ".png")) for s in self.spike_list]
|
@@ -0,0 +1,181 @@
|
|
1
|
+
from torch.utils.data import Dataset
|
2
|
+
from pathlib import Path
|
3
|
+
from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
|
4
|
+
from dataclasses import dataclass
|
5
|
+
import re
|
6
|
+
|
7
|
+
|
8
|
+
@dataclass
|
9
|
+
class REDS_SSIRConfig(BaseDatasetConfig):
|
10
|
+
dataset_name: str = "reds_ssir"
|
11
|
+
root_dir: Path = Path(__file__).parent.parent / Path("data/REDS_SSIR")
|
12
|
+
train_width: int = 96
|
13
|
+
train_height: int = 96
|
14
|
+
test_width: int = 1280
|
15
|
+
test_height: int = 720
|
16
|
+
width: int = -1
|
17
|
+
height: int = -1
|
18
|
+
with_img: bool = True
|
19
|
+
spike_length_train: int = 41
|
20
|
+
spike_length_test: int = 301
|
21
|
+
|
22
|
+
# post process
|
23
|
+
def __post_init__(self):
|
24
|
+
self.root_dir = Path(self.root_dir) if isinstance(self.root_dir, str) else self.root_dir
|
25
|
+
# todo try download
|
26
|
+
assert self.root_dir.exists(), f"No files found in {self.root_dir} for the specified dataset `{self.dataset_name}`."
|
27
|
+
# train/test split
|
28
|
+
if self.split == "train":
|
29
|
+
self.spike_length = self.spike_length_train
|
30
|
+
self.width = self.train_width
|
31
|
+
self.height = self.train_height
|
32
|
+
else:
|
33
|
+
self.spike_length = self.spike_length_test
|
34
|
+
self.width = self.test_width
|
35
|
+
self.height = self.test_height
|
36
|
+
|
37
|
+
|
38
|
+
class REDS_SSIR(BaseDataset):
|
39
|
+
def __init__(self, cfg: BaseDatasetConfig):
|
40
|
+
super(REDS_SSIR, self).__init__(cfg)
|
41
|
+
|
42
|
+
def prepare_data(self):
|
43
|
+
"""Specify the spike and image files to be loaded."""
|
44
|
+
# spike/imgs dir train/test
|
45
|
+
if self.cfg.split == "train":
|
46
|
+
self.img_dir = self.cfg.root_dir / Path("crop_mini/spike/train/interp_20_alpha_0.40")
|
47
|
+
self.spike_dir = self.cfg.root_dir / Path("crop_mini/image/train/train_orig")
|
48
|
+
else:
|
49
|
+
self.img_dir = self.cfg.root_dir / Path("imgs/val/val_orig")
|
50
|
+
self.spike_dir = self.cfg.root_dir / Path("spike/val/interp_20_alpha_0.40")
|
51
|
+
# get files
|
52
|
+
self.spike_list = self.get_spike_files(self.spike_dir)
|
53
|
+
self.img_list = []
|
54
|
+
|
55
|
+
|
56
|
+
class sreds_train(torch.utils.data.Dataset):
|
57
|
+
def __init__(self, cfg):
|
58
|
+
self.cfg = cfg
|
59
|
+
self.pair_step = self.cfg["loader"]["pair_step"]
|
60
|
+
self.augmentor = Augmentor(crop_size=self.cfg["loader"]["crop_size"])
|
61
|
+
self.samples = self.collect_samples()
|
62
|
+
print("The samples num of training data: {:d}".format(len(self.samples)))
|
63
|
+
|
64
|
+
def confirm_exist(self, path_list_list):
|
65
|
+
for pl in path_list_list:
|
66
|
+
for p in pl:
|
67
|
+
if not osp.exists(p):
|
68
|
+
return 0
|
69
|
+
return 1
|
70
|
+
|
71
|
+
def collect_samples(self):
|
72
|
+
spike_path = osp.join(
|
73
|
+
self.cfg["data"]["root"], "crop_mini", "spike", "train", "interp_{:d}_alpha_{:.2f}".format(self.cfg["data"]["interp"], self.cfg["data"]["alpha"])
|
74
|
+
)
|
75
|
+
image_path = osp.join(self.cfg["data"]["root"], "crop_mini", "image", "train", "train_orig")
|
76
|
+
scene_list = sorted(os.listdir(spike_path))
|
77
|
+
samples = []
|
78
|
+
|
79
|
+
for scene in scene_list:
|
80
|
+
spike_dir = osp.join(spike_path, scene)
|
81
|
+
image_dir = osp.join(image_path, scene)
|
82
|
+
spk_path_list = sorted(os.listdir(spike_dir))
|
83
|
+
|
84
|
+
spklen = len(spk_path_list)
|
85
|
+
seq_len = self.cfg["model"]["seq_len"] + 2
|
86
|
+
"""
|
87
|
+
for st in range(0, spklen - ((spklen - self.pair_step) % seq_len) - seq_len, self.pair_step):
|
88
|
+
# 按照文件名称读取
|
89
|
+
spikes_path_list = [osp.join(spike_dir, spk_path_list[ii]) for ii in range(st, st+seq_len)]
|
90
|
+
images_path_list = [osp.join(image_dir, spk_path_list[ii][:-4]+'.png') for ii in range(st, st+seq_len)]
|
91
|
+
|
92
|
+
if(self.confirm_exist([spikes_path_list, images_path_list])):
|
93
|
+
s = {}
|
94
|
+
s['spikes_paths'] = spikes_path_list
|
95
|
+
s['images_paths'] = images_path_list
|
96
|
+
samples.append(s)
|
97
|
+
"""
|
98
|
+
# 按照文件名称读取
|
99
|
+
spikes_path_list = [osp.join(spike_dir, spk_path_list[ii]) for ii in range(spklen)]
|
100
|
+
images_path_list = [osp.join(image_dir, spk_path_list[ii][:-4] + ".png") for ii in range(spklen)]
|
101
|
+
|
102
|
+
if self.confirm_exist([spikes_path_list, images_path_list]):
|
103
|
+
s = {}
|
104
|
+
s["spikes_paths"] = spikes_path_list
|
105
|
+
s["images_paths"] = images_path_list
|
106
|
+
samples.append(s)
|
107
|
+
|
108
|
+
return samples
|
109
|
+
|
110
|
+
def _load_sample(self, s):
|
111
|
+
data = {}
|
112
|
+
|
113
|
+
data["spikes"] = [np.array(dat_to_spmat(p, size=(96, 96)), dtype=np.float32) for p in s["spikes_paths"]]
|
114
|
+
data["images"] = [read_img_gray(p) for p in s["images_paths"]]
|
115
|
+
|
116
|
+
data["spikes"], data["images"] = self.augmentor(data["spikes"], data["images"])
|
117
|
+
# print("data['spikes'][0].shape, data['images'][0].shape", data['spikes'][0].shape, data['images'][0].shape)
|
118
|
+
|
119
|
+
return data
|
120
|
+
|
121
|
+
def __len__(self):
|
122
|
+
return len(self.samples)
|
123
|
+
|
124
|
+
def __getitem__(self, index):
|
125
|
+
data = self._load_sample(self.samples[index])
|
126
|
+
return data
|
127
|
+
|
128
|
+
|
129
|
+
class sreds_test(torch.utils.data.Dataset):
|
130
|
+
def __init__(self, cfg):
|
131
|
+
self.cfg = cfg
|
132
|
+
self.samples = self.collect_samples()
|
133
|
+
print("The samples num of testing data: {:d}".format(len(self.samples)))
|
134
|
+
|
135
|
+
def confirm_exist(self, path_list_list):
|
136
|
+
for pl in path_list_list:
|
137
|
+
for p in pl:
|
138
|
+
if not osp.exists(p):
|
139
|
+
return 0
|
140
|
+
return 1
|
141
|
+
|
142
|
+
def collect_samples(self):
|
143
|
+
spike_path = osp.join(
|
144
|
+
self.cfg["data"]["root"], "spike", "val", "interp_{:d}_alpha_{:.2f}".format(self.cfg["data"]["interp"], self.cfg["data"]["alpha"])
|
145
|
+
)
|
146
|
+
image_path = osp.join(self.cfg["data"]["root"], "imgs", "val", "val_orig")
|
147
|
+
scene_list = sorted(os.listdir(spike_path))
|
148
|
+
samples = []
|
149
|
+
|
150
|
+
for scene in scene_list:
|
151
|
+
spike_dir = osp.join(spike_path, scene)
|
152
|
+
image_dir = osp.join(image_path, scene)
|
153
|
+
spk_path_list = sorted(os.listdir(spike_dir))
|
154
|
+
|
155
|
+
spklen = len(spk_path_list)
|
156
|
+
# seq_len = self.cfg['model']['seq_len']
|
157
|
+
|
158
|
+
# 按照文件名称读取
|
159
|
+
spikes_path_list = [osp.join(spike_dir, spk_path_list[ii]) for ii in range(spklen)]
|
160
|
+
images_path_list = [osp.join(image_dir, spk_path_list[ii][:-4] + ".png") for ii in range(spklen)]
|
161
|
+
|
162
|
+
if self.confirm_exist([spikes_path_list, images_path_list]):
|
163
|
+
s = {}
|
164
|
+
s["spikes_paths"] = spikes_path_list
|
165
|
+
s["images_paths"] = images_path_list
|
166
|
+
samples.append(s)
|
167
|
+
|
168
|
+
return samples
|
169
|
+
|
170
|
+
def _load_sample(self, s):
|
171
|
+
data = {}
|
172
|
+
data["spikes"] = [np.array(dat_to_spmat(p, size=(720, 1280)), dtype=np.float32) for p in s["spikes_paths"]]
|
173
|
+
data["images"] = [read_img_gray(p) for p in s["images_paths"]]
|
174
|
+
return data
|
175
|
+
|
176
|
+
def __len__(self):
|
177
|
+
return len(self.samples)
|
178
|
+
|
179
|
+
def __getitem__(self, index):
|
180
|
+
data = self._load_sample(self.samples[index])
|
181
|
+
return data
|