spikezoo 0.1__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- spikezoo/__init__.py +0 -0
- spikezoo/archs/__init__.py +0 -0
- spikezoo/datasets/__init__.py +68 -0
- spikezoo/datasets/base_dataset.py +157 -0
- spikezoo/datasets/realworld_dataset.py +25 -0
- spikezoo/datasets/reds_small_dataset.py +27 -0
- spikezoo/datasets/szdata_dataset.py +37 -0
- spikezoo/datasets/uhsr_dataset.py +38 -0
- spikezoo/metrics/__init__.py +96 -0
- spikezoo/models/__init__.py +37 -0
- spikezoo/models/base_model.py +177 -0
- spikezoo/models/bsf_model.py +90 -0
- spikezoo/models/spcsnet_model.py +19 -0
- spikezoo/models/spikeclip_model.py +32 -0
- spikezoo/models/spikeformer_model.py +50 -0
- spikezoo/models/spk2imgnet_model.py +51 -0
- spikezoo/models/ssir_model.py +22 -0
- spikezoo/models/ssml_model.py +18 -0
- spikezoo/models/stir_model.py +37 -0
- spikezoo/models/tfi_model.py +18 -0
- spikezoo/models/tfp_model.py +18 -0
- spikezoo/models/wgse_model.py +31 -0
- spikezoo/pipeline/__init__.py +4 -0
- spikezoo/pipeline/base_pipeline.py +267 -0
- spikezoo/pipeline/ensemble_pipeline.py +64 -0
- spikezoo/pipeline/train_pipeline.py +94 -0
- spikezoo/utils/__init__.py +3 -0
- spikezoo/utils/data_utils.py +52 -0
- spikezoo/utils/img_utils.py +72 -0
- spikezoo/utils/other_utils.py +59 -0
- spikezoo/utils/spike_utils.py +82 -0
- spikezoo-0.1.dist-info/LICENSE.txt +17 -0
- spikezoo-0.1.dist-info/METADATA +39 -0
- spikezoo-0.1.dist-info/RECORD +36 -0
- spikezoo-0.1.dist-info/WHEEL +5 -0
- spikezoo-0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,90 @@
|
|
1
|
+
import torch
|
2
|
+
from dataclasses import dataclass
|
3
|
+
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
4
|
+
|
5
|
+
|
6
|
+
@dataclass
class BSFConfig(BaseModelConfig):
    """Default configuration for the BSF spike-to-image model wrapper."""

    # default params for BSF
    model_name: str = "bsf"
    # Module path / class name of the underlying network implementation
    # (presumably resolved by BaseModel relative to the BSF arch package — confirm).
    model_file_name: str = "models.bsf.bsf"
    model_cls_name: str = "BSF"
    # Number of spike frames per reconstruction; assumed to be the length used by
    # BaseModel.crop_spike_length — TODO confirm.
    model_win_length: int = 61
    # Presumably: whether pretrained weights are loaded from ckpt_path — verify in BaseModel.
    require_params: bool = True
    ckpt_path: str = "weights/bsf.pth"
|
15
|
+
|
16
|
+
|
17
|
+
class BSF(BaseModel):
    """BSF model wrapper.

    Pre-processing crops the spike stream to the model window, pads it for the
    two known sensor sizes, and builds the DSFT feature dictionary consumed by
    the network. Post-processing crops the padding back off the output.
    """

    def __init__(self, cfg: BaseModelConfig):
        super(BSF, self).__init__(cfg)

    def preprocess_spike(self, spike):
        """Build the network input dict from a spike stream.

        Returns ``{"dsft_dict": <dict of 4 DSFT tensors>, "spikes": <padded spike>}``.
        """
        # length
        spike = self.crop_spike_length(spike)
        # size: repeat the last 2 rows (dim 2: 250 -> 252) or the last 2 columns
        # (dim 3: 854 -> 856); postprocess_img removes them again.
        if self.spike_size == (250, 400):
            spike = torch.cat([spike, spike[:, :, -2:]], dim=2)
        elif self.spike_size == (480, 854):
            spike = torch.cat([spike, spike[:, :, :, -2:]], dim=3)
        # dsft
        dsft = self.compute_dsft_core(spike)
        dsft_dict = self.convert_dsft4(dsft, spike)
        input_dict = {
            "dsft_dict": dsft_dict,
            "spikes": spike,
        }
        return input_dict

    def postprocess_img(self, image):
        """Crop away the padding added in ``preprocess_spike``."""
        if self.spike_size == (250, 400):
            image = image[:, :, :250, :]
        elif self.spike_size == (480, 854):
            image = image[:, :, :, :854]
        return image

    def compute_dsft_core(self, spike):
        """Per-pixel inter-spike interval (DSFT) for every time step.

        ``spike`` is unpacked as ``(bs, T, H, W)``. For each step t,
        ``l_idx`` is the index of the latest spike at or before t (0 if none) and
        ``r_idx`` the index of the first spike strictly after t (T if none); the
        result is ``clip(r_idx - l_idx, 0)``.

        NOTE: a spike at t == 0 yields time index 0, which is indistinguishable
        from "no spike" in the masking below; this mirrors the original algorithm.
        """
        bs, T, H, W = spike.shape
        # Allocate on the input's device (previously hard-coded to "cuda",
        # which failed whenever the pipeline fell back to CPU).
        time = spike * torch.arange(T, device=spike.device).reshape(1, T, 1, 1)
        l_idx, _ = time.cummax(dim=1)
        time[time == 0] = T
        r_idx, _ = torch.flip(time, [1]).cummin(dim=1)
        r_idx = torch.flip(r_idx, [1])
        # Shift left by one step so r_idx refers to spikes strictly after t;
        # the last step has no successor and gets T.
        r_idx = torch.cat([r_idx[:, 1:, :, :], torch.ones([bs, 1, H, W], device=spike.device) * T], dim=1)
        res = r_idx - l_idx
        res = torch.clip(res, 0)
        return res

    def convert_dsft4(self, dsft, spike):
        """Combine ``dsft`` with neighbouring-interval values into four variants.

        ``dmls1[:, ii]`` takes the DSFT value just after the next spike (propagated
        backward in time) and ``dmrs1[:, ii]`` the value just before the most recent
        spike (propagated forward). Where fewer than two spikes have been seen
        from the respective end, ``dsft`` itself is copied.

        Returns a dict with keys ``dsft11`` (dsft), ``dsft12`` (dsft + dmls1),
        ``dsft21`` (dsft + dmrs1) and ``dsft22`` (dsft + dmls1 + dmrs1).
        """
        b, T, h, w = spike.shape
        dmls1 = -1 * torch.ones(spike.shape, device=spike.device, dtype=torch.float32)
        dmrs1 = -1 * torch.ones(spike.shape, device=spike.device, dtype=torch.float32)
        # Backward sweep: flag counts spikes seen in [ii, T-1], starting at -2.
        flag = -2 * torch.ones([b, h, w], device=spike.device, dtype=torch.float32)
        for ii in range(T - 1, 0 - 1, -1):
            flag += spike[:, ii] == 1
            copy_pad_coord = flag < 0
            dmls1[:, ii][copy_pad_coord] = dsft[:, ii][copy_pad_coord]
            if ii < T - 1:
                update_coord = (spike[:, ii + 1] == 1) * (~copy_pad_coord)
                dmls1[:, ii][update_coord] = dsft[:, ii + 1][update_coord]
                non_update_coord = (spike[:, ii + 1] != 1) * (~copy_pad_coord)
                dmls1[:, ii][non_update_coord] = dmls1[:, ii + 1][non_update_coord]
        # Forward sweep: flag counts spikes seen in [0, ii], starting at -2.
        flag = -2 * torch.ones([b, h, w], device=spike.device, dtype=torch.float32)
        for ii in range(0, T, 1):
            flag += spike[:, ii] == 1
            copy_pad_coord = flag < 0
            dmrs1[:, ii][copy_pad_coord] = dsft[:, ii][copy_pad_coord]
            if ii > 0:
                update_coord = (spike[:, ii] == 1) * (~copy_pad_coord)
                dmrs1[:, ii][update_coord] = dsft[:, ii - 1][update_coord]
                non_update_coord = (spike[:, ii] != 1) * (~copy_pad_coord)
                dmrs1[:, ii][non_update_coord] = dmrs1[:, ii - 1][non_update_coord]
        dsft12 = dsft + dmls1
        dsft21 = dsft + dmrs1
        dsft22 = dsft + dmls1 + dmrs1
        dsft_dict = {
            "dsft11": dsft,
            "dsft12": dsft12,
            "dsft21": dsft21,
            "dsft22": dsft22,
        }
        return dsft_dict
|
@@ -0,0 +1,19 @@
|
|
1
|
+
from dataclasses import dataclass, field
|
2
|
+
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
3
|
+
from typing import List
|
4
|
+
|
5
|
+
|
6
|
+
@dataclass
class SPCSNetConfig(BaseModelConfig):
    """Default configuration for the SPCS-Net model wrapper."""

    # default params for SPCSNet
    model_name: str = "spcsnet"
    # Module / class of the underlying network (resolution presumably done by BaseModel — confirm).
    model_file_name: str = "models"
    model_cls_name: str = "SPCS_Net"
    # Spike frames per reconstruction (assumed to drive BaseModel.crop_spike_length — TODO confirm).
    model_win_length: int = 41
    # Presumably: whether pretrained weights are loaded from ckpt_path — verify in BaseModel.
    require_params: bool = True
    ckpt_path: str = 'weights/spcsnet.pth'
|
15
|
+
|
16
|
+
|
17
|
+
class SPCSNet(BaseModel):
    """SPCS-Net wrapper; defines no overrides beyond construction."""

    def __init__(self, cfg: BaseModelConfig):
        # Nothing model-specific to initialise: defer to BaseModel.
        super().__init__(cfg)
|
@@ -0,0 +1,32 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
3
|
+
import torch
|
4
|
+
import torch.nn.functional as F
|
5
|
+
|
6
|
+
|
7
|
+
@dataclass
class SpikeCLIPConfig(BaseModelConfig):
    """Default configuration for the SpikeCLIP model wrapper."""

    # default params for SpikeCLIP
    model_name: str = "spikeclip"
    # Module / class of the underlying network (resolution presumably done by BaseModel — confirm).
    model_file_name: str = "nets"
    model_cls_name: str = "LRN"
    # SpikeCLIP.preprocess_spike reshapes to 50 bins of 4 frames, i.e. it relies on 200 frames.
    model_win_length: int = 200
    # Presumably: whether pretrained weights are loaded from ckpt_path — verify in BaseModel.
    require_params: bool = True
    ckpt_path: str = "weights/spikeclip.pth"
|
16
|
+
|
17
|
+
|
18
|
+
class SpikeCLIP(BaseModel):
    """SpikeCLIP wrapper: voxelizes the spike stream and reflect-pads it by 20 px."""

    def __init__(self, cfg: BaseModelConfig):
        super().__init__(cfg)

    def preprocess_spike(self, spike):
        """Crop to the model window, sum every 4 consecutive frames into 50 bins, pad.

        Assumes the cropped window holds 200 frames per sample (model_win_length) — TODO confirm.
        """
        cropped = self.crop_spike_length(spike)
        height, width = cropped.shape[-2], cropped.shape[-1]
        # [..., 200, H, W] -> [..., 50, 4, H, W] -> sum over the group of 4.
        binned = cropped.reshape(-1, 50, 4, height, width).sum(dim=2)
        return F.pad(binned, pad=(20, 20, 20, 20), mode="reflect", value=0)

    def postprocess_img(self, image):
        """Strip the 20-pixel reflect padding added in ``preprocess_spike``."""
        return image[:, :, 20:-20, 20:-20]
|
@@ -0,0 +1,50 @@
|
|
1
|
+
import torch
|
2
|
+
from dataclasses import dataclass, field
|
3
|
+
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
4
|
+
|
5
|
+
|
6
|
+
@dataclass
class SpikeFormerConfig(BaseModelConfig):
    """Default configuration for the SpikeFormer model wrapper."""

    # default params for SpikeFormer
    model_name: str = "spikeformer"
    # Module / class of the underlying network (resolution presumably done by BaseModel — confirm).
    model_file_name: str = "Model.SpikeFormer"
    model_cls_name: str = "SpikeFormer"
    # Matches model_params["inputDim"] below; keep the two in sync.
    model_win_length: int = 65
    # Presumably: whether pretrained weights are loaded from ckpt_path — verify in BaseModel.
    require_params: bool = True
    ckpt_path: str = "weights/spikeformer.pth"
    # Constructor kwargs for the network (presumably passed through by BaseModel — confirm).
    model_params: dict = field(
        default_factory=lambda: {
            "inputDim": 65,
            "dims": (32, 64, 160, 256),
            "heads": (1, 2, 5, 8),
            "ff_expansion": (8, 8, 4, 4),
            "reduction_ratio": (8, 4, 2, 1),
            "num_layers": 2,
            "decoder_dim": 256,
            "out_channel": 1,
        }
    )
|
27
|
+
|
28
|
+
|
29
|
+
class SpikeFormer(BaseModel):
    """SpikeFormer wrapper: pads the input, rescales it, and crops the output back."""

    def __init__(self, cfg: BaseModelConfig):
        super().__init__(cfg)

    def preprocess_spike(self, spike):
        """Crop to the model window, pad to a network-friendly size, map x -> 2x - 1."""
        spike = self.crop_spike_length(spike)
        size = self.spike_size
        if size == (250, 400):
            # 250 -> 256 rows: repeat the first 3 and the last 3 rows (dim 2).
            head_rows = spike[:, :, :3, :]
            tail_rows = spike[:, :, -3:, :]
            spike = torch.cat([head_rows, spike, tail_rows], dim=2)
        elif size == (480, 854):
            # 854 -> 856 columns: repeat the last 2 columns (dim 3).
            spike = torch.cat([spike, spike[:, :, :, -2:]], dim=3)
        # For binary spikes this maps {0, 1} to {-1, 1}.
        return spike * 2 - 1

    def postprocess_img(self, image):
        """Remove the padding added in ``preprocess_spike``."""
        size = self.spike_size
        if size == (250, 400):
            return image[:, :, 3:-3, :]
        if size == (480, 854):
            return image[:, :, :, :854]
        return image
|
@@ -0,0 +1,51 @@
|
|
1
|
+
import torch
|
2
|
+
from dataclasses import dataclass, field
|
3
|
+
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
4
|
+
|
5
|
+
|
6
|
+
@dataclass
class Spk2ImgNetConfig(BaseModelConfig):
    """Default configuration for the Spk2ImgNet model wrapper."""

    # default params for Spk2ImgNet
    model_name: str = "spk2imgnet"
    # Module / class of the underlying network (resolution presumably done by BaseModel — confirm).
    model_file_name: str = "nets"
    model_cls_name: str = "SpikeNet"
    # Spike frames per reconstruction (assumed to drive BaseModel.crop_spike_length — TODO confirm).
    model_win_length: int = 41
    # Presumably: whether pretrained weights are loaded from ckpt_path — verify in BaseModel.
    require_params: bool = True
    ckpt_path: str = "weights/spk2imgnet.pth"
    # When True, Spk2ImgNet.postprocess_img divides the output by 0.6 and clamps to [0, 1].
    light_correction: bool = False

    # model params (network constructor kwargs, presumably passed through by BaseModel — confirm)
    model_params: dict = field(
        default_factory=lambda: {
            "in_channels": 13,
            "features": 64,
            "out_channels": 1,
            "win_r": 6,
            "win_step": 7,
        }
    )
|
27
|
+
|
28
|
+
|
29
|
+
class Spk2ImgNet(BaseModel):
    """Spk2ImgNet wrapper: pads the input by 2 rows/columns and optionally brightens output."""

    def __init__(self, cfg: BaseModelConfig):
        super().__init__(cfg)

    def preprocess_spike(self, spike):
        """Crop to the model window and pad for the two known sensor sizes."""
        spike = self.crop_spike_length(spike)
        size = self.spike_size
        if size == (250, 400):
            # Repeat the last 2 rows (dim 2): 250 -> 252.
            extra = spike[:, :, -2:]
            spike = torch.cat([spike, extra], dim=2)
        elif size == (480, 854):
            # Repeat the last 2 columns (dim 3): 854 -> 856.
            extra = spike[:, :, :, -2:]
            spike = torch.cat([spike, extra], dim=3)
        return spike

    def postprocess_img(self, image):
        """Crop the padding off and, if enabled, apply the light correction."""
        size = self.spike_size
        if size == (250, 400):
            image = image[:, :, :250, :]
        elif size == (480, 854):
            image = image[:, :, :, :854]
        # Light correction used on the REDS_small dataset: scale by 1/0.6, clamp to [0, 1].
        if self.cfg.light_correction == True:
            image = torch.clamp(image / 0.6, 0, 1)
        return image
|
@@ -0,0 +1,22 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
3
|
+
|
4
|
+
|
5
|
+
@dataclass
class SSIRConfig(BaseModelConfig):
    """Default configuration for the SSIR model wrapper."""

    # default params for SSIR
    model_name: str = "ssir"
    # Module / class of the underlying network (resolution presumably done by BaseModel — confirm).
    model_file_name: str = "models.networks"
    model_cls_name: str = "SSIR"
    # Spike frames per reconstruction (assumed to drive BaseModel.crop_spike_length — TODO confirm).
    model_win_length: int = 41
    # Presumably: whether pretrained weights are loaded from ckpt_path — verify in BaseModel.
    require_params: bool = True
    ckpt_path: str = "weights/ssir.pth"
|
14
|
+
|
15
|
+
|
16
|
+
class SSIR(BaseModel):
    """SSIR wrapper whose post-processing is the identity."""

    def __init__(self, cfg: BaseModelConfig):
        super().__init__(cfg)

    def postprocess_img(self, image):
        """Return the network output untouched."""
        return image
|
@@ -0,0 +1,18 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
3
|
+
|
4
|
+
|
5
|
+
@dataclass
class SSMLConfig(BaseModelConfig):
    """Default configuration for the SSML model wrapper."""

    # default params for SSML
    model_name: str = "ssml"
    # Module / class of the underlying network (resolution presumably done by BaseModel — confirm).
    model_file_name: str = "model"
    model_cls_name: str = "DoubleNet"
    # Spike frames per reconstruction (assumed to drive BaseModel.crop_spike_length — TODO confirm).
    model_win_length: int = 41
    # Presumably: whether pretrained weights are loaded from ckpt_path — verify in BaseModel.
    require_params: bool = True
    ckpt_path: str = 'weights/ssml.pt'
|
14
|
+
|
15
|
+
|
16
|
+
class SSML(BaseModel):
    """SSML wrapper; defines no overrides beyond construction."""

    def __init__(self, cfg: BaseModelConfig):
        # Nothing model-specific to initialise: defer to BaseModel.
        super().__init__(cfg)
|
@@ -0,0 +1,37 @@
|
|
1
|
+
import torch
|
2
|
+
from dataclasses import dataclass
|
3
|
+
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
4
|
+
|
5
|
+
|
6
|
+
@dataclass
class STIRConfig(BaseModelConfig):
    """Default configuration for the STIR model wrapper."""

    # default params for STIR
    model_name: str = "stir"
    # Module / class of the underlying network (resolution presumably done by BaseModel — confirm).
    model_file_name: str = "models.networks_STIR"
    model_cls_name: str = "STIR"
    # Spike frames per reconstruction (assumed to drive BaseModel.crop_spike_length — TODO confirm).
    model_win_length: int = 61
    # Presumably: whether pretrained weights are loaded from ckpt_path — verify in BaseModel.
    require_params: bool = True
    ckpt_path: str = "weights/stir.pth"
|
15
|
+
|
16
|
+
|
17
|
+
class STIR(BaseModel):
    """STIR wrapper: pads the input to a network-friendly size and crops the output back."""

    def __init__(self, cfg: BaseModelConfig):
        super().__init__(cfg)

    def preprocess_spike(self, spike):
        """Crop to the model window and pad for the two known sensor sizes."""
        spike = self.crop_spike_length(spike)
        size = self.spike_size
        if size == (250, 400):
            # Repeat the last 6 rows (dim 2): 250 -> 256.
            spike = torch.cat([spike, spike[:, :, -6:]], dim=2)
        elif size == (480, 854):
            # Repeat the last 10 columns (dim 3): 854 -> 864.
            spike = torch.cat([spike, spike[:, :, :, -10:]], dim=3)
        return spike

    def postprocess_img(self, image):
        """Crop away the padding added in ``preprocess_spike``."""
        size = self.spike_size
        if size == (250, 400):
            return image[:, :, :250, :]
        if size == (480, 854):
            return image[:, :, :, :854]
        return image
|
@@ -0,0 +1,18 @@
|
|
1
|
+
from dataclasses import dataclass, field
|
2
|
+
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
3
|
+
|
4
|
+
|
5
|
+
@dataclass
class TFIConfig(BaseModelConfig):
    """Default configuration for the TFI (parameter-free) model wrapper."""

    # default params for TFI
    model_name: str = "tfi"
    # Module / class of the underlying implementation (resolution presumably done by BaseModel — confirm).
    model_file_name: str = "nets"
    model_cls_name: str = "TFIModel"
    model_win_length: int = 41
    # No checkpoint is configured for this model.
    require_params: bool = False
    # Duplicates model_win_length; keep the two values in sync.
    model_params: dict = field(default_factory=lambda: {"model_win_length": 41})
|
14
|
+
|
15
|
+
|
16
|
+
class TFI(BaseModel):
    """TFI wrapper; defines no overrides beyond construction."""

    def __init__(self, cfg: BaseModelConfig):
        # Nothing model-specific to initialise: defer to BaseModel.
        super().__init__(cfg)
|
@@ -0,0 +1,18 @@
|
|
1
|
+
from dataclasses import dataclass,field
|
2
|
+
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
3
|
+
|
4
|
+
|
5
|
+
@dataclass
class TFPConfig(BaseModelConfig):
    """Default configuration for the TFP (parameter-free) model wrapper."""

    # default params for TFP
    model_name: str = "tfp"
    # Module / class of the underlying implementation (resolution presumably done by BaseModel — confirm).
    model_file_name: str = "nets"
    model_cls_name: str = "TFPModel"
    model_win_length: int = 41
    # No checkpoint is configured for this model.
    require_params: bool = False
    # Duplicates model_win_length; keep the two values in sync.
    model_params: dict = field(default_factory=lambda: {"model_win_length": 41})
|
14
|
+
|
15
|
+
|
16
|
+
class TFP(BaseModel):
    """TFP wrapper; defines no overrides beyond construction."""

    def __init__(self, cfg: BaseModelConfig):
        # Nothing model-specific to initialise: defer to BaseModel.
        super().__init__(cfg)
|
@@ -0,0 +1,31 @@
|
|
1
|
+
from dataclasses import dataclass, field
|
2
|
+
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
3
|
+
from typing import List
|
4
|
+
|
5
|
+
|
6
|
+
@dataclass
class WGSEConfig(BaseModelConfig):
    """Default configuration for the WGSE model wrapper."""

    # default params for WGSE
    model_name: str = "wgse"
    # Module / class of the underlying network (resolution presumably done by BaseModel — confirm).
    model_file_name: str = "dwtnets"
    model_cls_name: str = "Dwt1dResnetX_TCN"
    # Spike frames per reconstruction (assumed to drive BaseModel.crop_spike_length — TODO confirm).
    model_win_length: int = 41
    # Presumably: whether pretrained weights are loaded from ckpt_path — verify in BaseModel.
    require_params: bool = True
    ckpt_path: str = "weights/wgse.pt"
    # Network constructor kwargs (presumably passed through by BaseModel — confirm).
    model_params: dict = field(
        default_factory=lambda: {
            "wvlname": "db8",
            "J": 5,
            # NOTE(review): "yl_size" is a string while "yh_size" holds ints —
            # confirm the network really expects a str here.
            "yl_size": "15",
            "yh_size": [28, 21, 18, 16, 15],
            "num_residual_blocks": 3,
            "norm": None,
            "ks": 3,
            "store_features": True,
        }
    )
|
27
|
+
|
28
|
+
|
29
|
+
class WGSE(BaseModel):
    """WGSE wrapper; defines no overrides beyond construction."""

    def __init__(self, cfg: BaseModelConfig):
        # Nothing model-specific to initialise: defer to BaseModel.
        super().__init__(cfg)
|
@@ -0,0 +1,267 @@
|
|
1
|
+
import torch
|
2
|
+
from dataclasses import dataclass, field
|
3
|
+
import os
|
4
|
+
from spikezoo.utils.img_utils import tensor2npy, AverageMeter
|
5
|
+
from spikezoo.utils.spike_utils import load_vidar_dat
|
6
|
+
from spikezoo.metrics import cal_metric_pair, cal_metric_single
|
7
|
+
import numpy as np
|
8
|
+
import cv2
|
9
|
+
from pathlib import Path
|
10
|
+
from enum import Enum, auto
|
11
|
+
from typing import Literal
|
12
|
+
from spikezoo.metrics import metric_pair_names, metric_single_names, metric_all_names
|
13
|
+
from thop import profile
|
14
|
+
import time
|
15
|
+
from datetime import datetime
|
16
|
+
from spikezoo.utils import setup_logging, save_config
|
17
|
+
from tqdm import tqdm
|
18
|
+
from spikezoo.models import build_model_cfg, build_model_name, BaseModel, BaseModelConfig
|
19
|
+
from spikezoo.datasets import build_dataset_cfg, build_dataset_name, BaseDataset, BaseDatasetConfig, build_dataloader
|
20
|
+
from typing import Optional, Union, List
|
21
|
+
|
22
|
+
|
23
|
+
@dataclass
class PipelineConfig:
    """Options controlling what the reconstruction pipeline saves, evaluates and where.

    Field order defines the generated ``__init__`` signature and is unchanged.
    """

    # Evaluate metrics or not.
    save_metric: bool = True
    # Save recovered images or not.
    save_img: bool = True
    # Min-max normalize recovered images or not (TFP, TFI, SpikeFormer and
    # SpikeCLIP outputs are always normalized by the pipeline).
    save_img_norm: bool = False
    # Min-max normalize the ground-truth image or not.
    gt_img_norm: bool = False
    # Save folder for the code running result; empty -> the package "results" folder.
    save_folder: str = ""
    # Saved experiment name; empty -> a timestamp is used as the run folder name.
    exp_name: str = ""
    # Metric names for evaluation.
    metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim"])
    # Different modes for the pipeline.
    _mode: Literal["single_mode", "multi_mode", "train_mode"] = "single_mode"
|
41
|
+
|
42
|
+
|
43
|
+
class Pipeline:
    """Inference/evaluation pipeline for spike-to-image reconstruction.

    Builds a model and a dataset (from registered names or config objects),
    prepares an output folder with a log file and config dumps, and exposes
    helpers to reconstruct images from a dataset sample, a spike file or an
    in-memory spike stream, to evaluate metrics over the dataset, and to
    measure parameters / FLOPs / latency.
    """

    def __init__(
        self,
        cfg: PipelineConfig,
        model_cfg: Union[str, BaseModelConfig],
        dataset_cfg: Union[str, BaseDatasetConfig],
    ):
        self.cfg = cfg
        self._setup_model_data(model_cfg, dataset_cfg)
        self._setup_pipeline()

    def _setup_model_data(self, model_cfg, dataset_cfg):
        """Model and data setup.

        A ``str`` config is treated as a registered name, anything else as a
        config object. Note: gradients are disabled process-wide here.
        """
        # model
        self.model: BaseModel = build_model_name(model_cfg) if isinstance(model_cfg, str) else build_model_cfg(model_cfg)
        self.model = self.model.eval()
        torch.set_grad_enabled(False)
        # dataset
        self.dataset: BaseDataset = build_dataset_name(dataset_cfg) if isinstance(dataset_cfg, str) else build_dataset_cfg(dataset_cfg)
        self.dataloader = build_dataloader(self.dataset)
        # device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def _setup_pipeline(self):
        """Pipeline setup: output folder, logger and config dumps.

        Results go to ``<save_folder>/<train|detect>/<exp_name or timestamp>``;
        ``save_folder`` defaults to ``results`` under this file's grandparent directory.
        """
        # save folder
        self.thistime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")[:23]
        # Wrap a user-supplied folder in Path so str and Path inputs behave alike.
        root = Path(self.cfg.save_folder) if self.cfg.save_folder else Path(__file__).parent.parent / Path("results")
        mode_name = "train" if self.cfg._mode == "train_mode" else "detect"
        run_name = self.cfg.exp_name if self.cfg.exp_name else self.thistime
        self.save_folder = root / Path(f"{mode_name}/{run_name}")
        save_folder = self.save_folder
        os.makedirs(str(save_folder), exist_ok=True)
        # logger result
        self.logger = setup_logging(save_folder / Path("result.log"))
        self.logger.info(f"Info logs are saved on the {save_folder}/result.log")
        # pipeline config
        save_config(self.cfg, save_folder / Path("cfg_pipeline.log"))
        # model config
        if self.cfg._mode == "single_mode":
            save_config(self.model.cfg, save_folder / Path("cfg_model.log"))
        elif self.cfg._mode == "multi_mode":
            # NOTE(review): ``self.model_list`` is not set in this class; presumably
            # a multi-model subclass provides it before this method runs.
            for model in self.model_list:
                save_config(model.cfg, save_folder / Path("cfg_model.log"), mode="a")
        # dataset config
        save_config(self.dataset.cfg, save_folder / Path("cfg_dataset.log"))

    def spk2img_from_dataset(self, idx=0):
        """Func---Save the recovered image and calculate the metric from the given dataset.

        Args:
            idx: index of the dataset sample to reconstruct.

        Returns:
            The raw model output (before post-processing).
        """
        # save folder
        save_folder = self.save_folder / Path(f"spk2img_from_dataset/{self.dataset.cfg.dataset_name}_dataset/{self.dataset.cfg.split}/{idx:06d}")
        os.makedirs(str(save_folder), exist_ok=True)

        # data process: add a batch dimension and move to the pipeline device
        batch = self.dataset[idx]
        spike, img = batch["spike"], batch["img"]
        spike = spike[None].to(self.device)
        if self.dataset.cfg.with_img == True:
            img = img[None].to(self.device)
        else:
            img = None
        return self._spk2img(spike, img, save_folder)

    def spk2img_from_file(self, file_path, height, width, img_path=None, remove_head=False):
        """Func---Save the recovered image and calculate the metric from the given input file.

        Args:
            file_path: ``.dat`` spike file (read with ``load_vidar_dat`` using
                ``height``/``width``/``remove_head``) or ``.npz`` file whose ``"spk"``
                array is cropped to rows/cols 13:237. ``str`` or ``Path``.
            height, width: spatial size used to decode ``.dat`` files.
            img_path: optional ground-truth image file (converted to gray, scaled by 1/255).
            remove_head: forwarded to ``load_vidar_dat``.

        Returns:
            The raw model output (before post-processing).

        Raises:
            RuntimeError: unsupported spike file extension.
            FileNotFoundError: ``img_path`` could not be read by OpenCV.
        """
        file_path = str(file_path)
        # save folder
        save_folder = self.save_folder / Path(f"spk2img_from_file/{os.path.basename(file_path)}")
        os.makedirs(str(save_folder), exist_ok=True)

        # load spike from .dat
        if file_path.endswith(".dat"):
            spike = load_vidar_dat(file_path, height, width, remove_head)
        # load spike from .npz from UHSR (crop to the central 224x224 window)
        elif file_path.endswith(".npz"):
            spike = np.load(file_path)["spk"].astype(np.float32)[:, 13:237, 13:237]
        else:
            raise RuntimeError("Not recognized spike input file.")
        # load img from .png/.jpg image file
        if img_path is not None:
            img = cv2.imread(str(img_path))
            # cv2.imread signals failure by returning None rather than raising.
            if img is None:
                raise FileNotFoundError(f"Could not read image file: {img_path}")
            if img.ndim == 3:
                img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            img = (img / 255).astype(np.float32)
            img = torch.from_numpy(img)[None, None].to(self.device)
        else:
            img = None
        spike = torch.from_numpy(spike)[None].to(self.device)
        return self._spk2img(spike, img, save_folder)

    def spk2img_from_spk(self, spike, img=None):
        """Func---Save the recovered image and calculate the metric from the given spike stream.

        Args:
            spike: ``np.ndarray`` or tensor, ``[c,h,w]`` or already batched ``[b,c,h,w]``.
            img: optional ``np.ndarray`` ground truth in 0-255 (BGR or gray).

        Returns:
            The raw model output (before post-processing).
        """
        # save folder
        save_folder = self.save_folder / Path(f"spk2img_from_spk/{self.thistime}")
        os.makedirs(str(save_folder), exist_ok=True)

        # spike process
        if isinstance(spike, np.ndarray):
            spike = torch.from_numpy(spike)
        spike = spike.to(self.device)
        # [c,h,w] -> [1,c,h,w]
        if spike.dim() == 3:
            spike = spike[None]
        spike = spike.float()
        # img process
        if img is not None:
            if isinstance(img, np.ndarray):
                img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if img.ndim == 3 else img
                img = (img / 255).astype(np.float32)
                img = torch.from_numpy(img)[None, None].to(self.device)
            else:
                raise RuntimeError("Not recognized image input type.")
        return self._spk2img(spike, img, save_folder)

    def save_imgs_from_dataset(self):
        """Func---Save all images from the given dataset."""
        for idx in range(len(self.dataset)):
            self.spk2img_from_dataset(idx=idx)

    # TODO: To be overridden
    def cal_params(self):
        """Func---Calculate the parameters/flops/latency of the given method."""
        self._cal_prams_model(self.model)

    # TODO: To be overridden
    def cal_metrics(self):
        """Func---Calculate the metric of the given method."""
        self._cal_metrics_model(self.model)

    # TODO: To be overridden
    def _spk2img(self, spike, img, save_folder):
        """Spike-to-image: spike:[bs,c,h,w] (0-1), img:[bs,1,h,w] (0-1)"""
        return self._spk2img_model(self.model, spike, img, save_folder)

    def _spk2img_model(self, model, spike, img, save_folder):
        """Spike-to-image from the given model.

        Logs every metric in ``metric_all_names`` (paired ones only when ``img``
        is given), optionally saves the first image of the batch, and returns a
        copy of the raw model output.
        """
        # spike2image conversion
        model_name = model.cfg.model_name
        recon_img = model(spike)
        # Keep the un-normalized output for the caller.
        recon_img_copy = recon_img.clone()
        # normalization
        recon_img, img = self._post_process_img(model_name, recon_img, img)
        # metric
        if self.cfg.save_metric == True:
            self.logger.info(f"----------------------Method: {model_name.upper()}----------------------")
            for metric_name in metric_all_names:
                if img is not None and metric_name in metric_pair_names:
                    self.logger.info(f"{metric_name.upper()}: {cal_metric_pair(recon_img,img,metric_name)}")
                elif metric_name in metric_single_names:
                    self.logger.info(f"{metric_name.upper()}: {cal_metric_single(recon_img,metric_name)}")
                else:
                    self.logger.info(f"{metric_name.upper()} not calculated since no ground truth provided.")

        # visual: only the first image of the batch is written
        if self.cfg.save_img == True:
            recon_img = tensor2npy(recon_img[0, 0])
            cv2.imwrite(f"{save_folder}/{model.cfg.model_name}.png", recon_img)
            if img is not None:
                img = tensor2npy(img[0, 0])
                cv2.imwrite(f"{save_folder}/sharp_img.png", img)
            self.logger.info(f"Images are saved on the {save_folder}")

        return recon_img_copy

    def _post_process_img(self, model_name, recon_img, gt_img):
        """Post-process the reconstructed (and optionally the ground-truth) image.

        TFP, TFI, SpikeFormer and SpikeCLIP outputs are always min-max normalized;
        other models only when ``cfg.save_img_norm`` is set. The reconstruction is
        then clipped to [0, 1]. ``gt_img`` is min-max normalized when
        ``cfg.gt_img_norm`` is set and it is not None.
        """
        if model_name in ["tfp", "tfi", "spikeformer", "spikeclip"] or self.cfg.save_img_norm == True:
            recon_img = self._min_max_norm(recon_img)
        recon_img = recon_img.clip(0, 1)
        if self.cfg.gt_img_norm == True and gt_img is not None:
            gt_img = self._min_max_norm(gt_img)
        return recon_img, gt_img

    @staticmethod
    def _min_max_norm(x):
        """Rescale ``x`` to [0, 1] using its global min/max.

        A constant input has zero range; return zeros instead of the NaNs a plain
        division would produce.
        """
        x_min, x_max = x.min(), x.max()
        value_range = x_max - x_min
        if value_range == 0:
            return torch.zeros_like(x)
        return (x - x_min) / value_range

    def _cal_metrics_model(self, model: BaseModel):
        """Calculate the average of each configured metric over the whole dataset.

        Paired metrics are only evaluated when the dataset provides images.
        """
        # metrics construct
        model_name = model.cfg.model_name
        metrics_dict = {}
        for metric_name in self.cfg.metric_names:
            if (self.dataset.cfg.with_img == True) or (metric_name in metric_single_names):
                metrics_dict[metric_name] = AverageMeter()

        # metrics calculate
        for batch in tqdm(self.dataloader):
            batch = model.feed_to_device(batch)
            outputs = model.get_outputs_dict(batch)
            recon_img, img = model.get_paired_imgs(batch, outputs)
            recon_img, img = self._post_process_img(model_name, recon_img, img)
            for metric_name in metrics_dict.keys():
                if metric_name in metric_pair_names:
                    metrics_dict[metric_name].update(cal_metric_pair(recon_img, img, metric_name))
                elif metric_name in metric_single_names:
                    metrics_dict[metric_name].update(cal_metric_single(recon_img, metric_name))

        # metrics logging
        self.logger.info(f"----------------------Method: {model_name.upper()}----------------------")
        for metric_name in metrics_dict.keys():
            self.logger.info(f"{metric_name.upper()}: {metrics_dict[metric_name].avg}")

    def _cal_prams_model(self, model):
        """Log parameter count, FLOPs (via thop) and mean latency of ``model``.

        Uses an all-zero dummy spike stream of shape (1, 200, 250, 400) on the
        pipeline device and averages latency over 100 forward passes.
        """
        network = model.net
        model_name = model.cfg.model_name.upper()
        # params
        params = sum(p.numel() for p in network.parameters())
        # latency: use the pipeline device instead of a hard-coded .cuda(), and
        # synchronize around the timed region because CUDA kernels run
        # asynchronously (otherwise only launch overhead is measured).
        spike = torch.zeros((1, 200, 250, 400), device=self.device)
        use_cuda = str(self.device).startswith("cuda")
        if use_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(100):
            model(spike)
        if use_cuda:
            torch.cuda.synchronize()
        latency = (time.perf_counter() - start_time) / 100
        # flop # todo thop bug for BSF
        flops, _ = profile(model, inputs=(spike,))
        re_msg = (
            "Total params: %.4fM" % (params / 1e6),
            "FLOPs:" + str(flops / 1e9) + "{}".format("G"),
            "Latency: {:.6f} seconds".format(latency),
        )
        self.logger.info(f"----------------------Method: {model_name}----------------------")
        # Join the parts instead of logging the tuple's repr.
        self.logger.info(", ".join(re_msg))
|