spikezoo 0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spikezoo/__init__.py +0 -0
- spikezoo/archs/__init__.py +0 -0
- spikezoo/datasets/__init__.py +68 -0
- spikezoo/datasets/base_dataset.py +157 -0
- spikezoo/datasets/realworld_dataset.py +25 -0
- spikezoo/datasets/reds_small_dataset.py +27 -0
- spikezoo/datasets/szdata_dataset.py +37 -0
- spikezoo/datasets/uhsr_dataset.py +38 -0
- spikezoo/metrics/__init__.py +96 -0
- spikezoo/models/__init__.py +37 -0
- spikezoo/models/base_model.py +177 -0
- spikezoo/models/bsf_model.py +90 -0
- spikezoo/models/spcsnet_model.py +19 -0
- spikezoo/models/spikeclip_model.py +32 -0
- spikezoo/models/spikeformer_model.py +50 -0
- spikezoo/models/spk2imgnet_model.py +51 -0
- spikezoo/models/ssir_model.py +22 -0
- spikezoo/models/ssml_model.py +18 -0
- spikezoo/models/stir_model.py +37 -0
- spikezoo/models/tfi_model.py +18 -0
- spikezoo/models/tfp_model.py +18 -0
- spikezoo/models/wgse_model.py +31 -0
- spikezoo/pipeline/__init__.py +4 -0
- spikezoo/pipeline/base_pipeline.py +267 -0
- spikezoo/pipeline/ensemble_pipeline.py +64 -0
- spikezoo/pipeline/train_pipeline.py +94 -0
- spikezoo/utils/__init__.py +3 -0
- spikezoo/utils/data_utils.py +52 -0
- spikezoo/utils/img_utils.py +72 -0
- spikezoo/utils/other_utils.py +59 -0
- spikezoo/utils/spike_utils.py +82 -0
- spikezoo-0.1.dist-info/LICENSE.txt +17 -0
- spikezoo-0.1.dist-info/METADATA +39 -0
- spikezoo-0.1.dist-info/RECORD +36 -0
- spikezoo-0.1.dist-info/WHEEL +5 -0
- spikezoo-0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,90 @@
|
|
1
|
+
import torch
|
2
|
+
from dataclasses import dataclass
|
3
|
+
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
4
|
+
|
5
|
+
|
6
|
+
@dataclass
class BSFConfig(BaseModelConfig):
    """Default configuration for the BSF model (overrides BaseModelConfig defaults)."""

    # default params for BSF
    model_name: str = "bsf"
    # module / class of the underlying network (presumably resolved by BaseModel — not shown here)
    model_file_name: str = "models.bsf.bsf"
    model_cls_name: str = "BSF"
    # number of spike frames kept by crop_spike_length (assumption — see BaseModel)
    model_win_length: int = 61
    require_params: bool = True
    ckpt_path: str = "weights/bsf.pth"
|
15
|
+
|
16
|
+
|
17
|
+
class BSF(BaseModel):
    """BSF model wrapper: builds DSFT features from spikes and crops the padded output."""

    def __init__(self, cfg: BaseModelConfig):
        super(BSF, self).__init__(cfg)

    def preprocess_spike(self, spike):
        """Prepare the BSF network input.

        Args:
            spike: spike tensor indexed as [bs, T, H, W].

        Returns:
            dict with "dsft_dict" (output of ``convert_dsft4``) and "spikes"
            (the time-cropped, spatially padded spike tensor).
        """
        # length
        spike = self.crop_spike_length(spike)
        # size: repeat the last 2 rows (250x400) or the last 2 columns (480x854);
        # postprocess_img crops the output back to the original size.
        if self.spike_size == (250, 400):
            spike = torch.cat([spike, spike[:, :, -2:]], dim=2)
        elif self.spike_size == (480, 854):
            spike = torch.cat([spike, spike[:, :, :, -2:]], dim=3)
        # dsft
        dsft = self.compute_dsft_core(spike)
        dsft_dict = self.convert_dsft4(dsft, spike)
        input_dict = {
            "dsft_dict": dsft_dict,
            "spikes": spike,
        }
        return input_dict

    def postprocess_img(self, image):
        """Crop away the padding that preprocess_spike added for the known sensor sizes."""
        if self.spike_size == (250, 400):
            image = image[:, :, :250, :]
        elif self.spike_size == (480, 854):
            image = image[:, :, :, :854]
        return image

    def compute_dsft_core(self, spike):
        """Length of the inter-spike interval that surrounds every frame.

        For each pixel and frame t:
        - ``l_idx`` is the index of the latest spike at or before t (0 if none).
        - ``r_idx`` is the index of the first spike strictly after t (T if none).

        The result is ``r_idx - l_idx`` clipped at 0. Spikes are assumed to be
        binary 0/1. A spike at frame 0 gets time value 0, so both searches
        treat it like "no spike".

        Args:
            spike: tensor indexed as [bs, T, H, W].

        Returns:
            Tensor with the same [bs, T, H, W] shape, on ``spike``'s device.
        """
        bs, T, H, W = spike.shape
        # Allocate on the input's device; this used to be hard-coded to "cuda",
        # which failed whenever the pipeline ran on CPU.
        device = spike.device
        time = spike * torch.arange(T, device=device).reshape(1, T, 1, 1)
        l_idx, _ = time.cummax(dim=1)
        # Non-spike frames get T so they never win the right-side minimum.
        time[time == 0] = T
        r_idx, _ = torch.flip(time, [1]).cummin(dim=1)
        r_idx = torch.flip(r_idx, [1])
        # Shift by one frame so r_idx[t] only considers frames > t; the last frame gets T.
        r_idx = torch.cat([r_idx[:, 1:, :, :], torch.ones([bs, 1, H, W], device=device) * T], dim=1)
        res = r_idx - l_idx
        res = torch.clip(res, 0)
        return res

    def convert_dsft4(self, dsft, spike):
        """Derive the four DSFT feature maps consumed by the BSF network.

        Args:
            dsft: output of ``compute_dsft_core`` for ``spike`` (same shape).
            spike: spike tensor indexed as [b, T, h, w]; spikes are compared with 1.

        Returns:
            dict with:
            - "dsft11": ``dsft`` itself
            - "dsft12": ``dsft + dmls1``
            - "dsft21": ``dsft + dmrs1``
            - "dsft22": ``dsft + dmls1 + dmrs1``
        """
        b, T, h, w = spike.shape
        dmls1 = -1 * torch.ones(spike.shape, device=spike.device, dtype=torch.float32)
        dmrs1 = -1 * torch.ones(spike.shape, device=spike.device, dtype=torch.float32)
        # Backward sweep (t = T-1 .. 0). `flag` starts at -2 and counts the spikes
        # seen so far, so `flag < 0` means fewer than two spikes have been seen for
        # that pixel; there dmls1 simply copies the current dsft value.
        flag = -2 * torch.ones([b, h, w], device=spike.device, dtype=torch.float32)
        for ii in range(T - 1, 0 - 1, -1):
            flag += spike[:, ii] == 1
            copy_pad_coord = flag < 0
            dmls1[:, ii][copy_pad_coord] = dsft[:, ii][copy_pad_coord]
            if ii < T - 1:
                # Spike at t+1: take dsft at t+1; otherwise carry the value stored at t+1.
                update_coord = (spike[:, ii + 1] == 1) * (~copy_pad_coord)
                dmls1[:, ii][update_coord] = dsft[:, ii + 1][update_coord]
                non_update_coord = (spike[:, ii + 1] != 1) * (~copy_pad_coord)
                dmls1[:, ii][non_update_coord] = dmls1[:, ii + 1][non_update_coord]
        # Forward sweep (t = 0 .. T-1), the mirror image of the loop above, filling dmrs1.
        flag = -2 * torch.ones([b, h, w], device=spike.device, dtype=torch.float32)
        for ii in range(0, T, 1):
            flag += spike[:, ii] == 1
            copy_pad_coord = flag < 0
            dmrs1[:, ii][copy_pad_coord] = dsft[:, ii][copy_pad_coord]
            if ii > 0:
                # Spike at t: take dsft at t-1; otherwise carry the value stored at t-1.
                update_coord = (spike[:, ii] == 1) * (~copy_pad_coord)
                dmrs1[:, ii][update_coord] = dsft[:, ii - 1][update_coord]
                non_update_coord = (spike[:, ii] != 1) * (~copy_pad_coord)
                dmrs1[:, ii][non_update_coord] = dmrs1[:, ii - 1][non_update_coord]
        dsft12 = dsft + dmls1
        dsft21 = dsft + dmrs1
        dsft22 = dsft + dmls1 + dmrs1
        dsft_dict = {
            "dsft11": dsft,
            "dsft12": dsft12,
            "dsft21": dsft21,
            "dsft22": dsft22,
        }
        return dsft_dict
|
@@ -0,0 +1,19 @@
|
|
1
|
+
from dataclasses import dataclass, field
|
2
|
+
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
3
|
+
from typing import List
|
4
|
+
|
5
|
+
|
6
|
+
@dataclass
class SPCSNetConfig(BaseModelConfig):
    """Default configuration for the SPCS-Net model (overrides BaseModelConfig defaults)."""

    # default params for SPCSNet
    model_name: str = "spcsnet"
    model_file_name: str = "models"
    model_cls_name: str = "SPCS_Net"
    model_win_length: int = 41
    require_params: bool = True
    ckpt_path: str = 'weights/spcsnet.pth'
|
15
|
+
|
16
|
+
|
17
|
+
class SPCSNet(BaseModel):
    """SPCS-Net wrapper; it adds no pre/post-processing of its own beyond BaseModel."""

    def __init__(self, cfg: BaseModelConfig):
        # Zero-argument super() resolves to the same BaseModel.__init__.
        super().__init__(cfg)
|
@@ -0,0 +1,32 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
3
|
+
import torch
|
4
|
+
import torch.nn.functional as F
|
5
|
+
|
6
|
+
|
7
|
+
@dataclass
class SpikeCLIPConfig(BaseModelConfig):
    """Default configuration for the SpikeCLIP (LRN) model."""

    # default params for SpikeCLIP
    model_name: str = "spikeclip"
    model_file_name: str = "nets"
    model_cls_name: str = "LRN"
    # SpikeCLIP.preprocess_spike reshapes the cropped stream into 50 bins of 4
    # frames, so this value must stay a multiple of 200 per sample.
    model_win_length: int = 200
    require_params: bool = True
    ckpt_path: str = "weights/spikeclip.pth"
|
16
|
+
|
17
|
+
|
18
|
+
class SpikeCLIP(BaseModel):
    """SpikeCLIP wrapper: bins spikes into a 50-channel voxel grid with reflect padding."""

    def __init__(self, cfg: BaseModelConfig):
        super().__init__(cfg)

    def preprocess_spike(self, spike):
        """Turn the spike stream into the padded voxel grid consumed by the network.

        The time-cropped stream is viewed as [-1, 50, 4, H, W] and summed over
        each group of 4 consecutive frames (200 frames -> 50 bins). The result
        is then reflect-padded by 20 pixels on every spatial side.
        """
        spike = self.crop_spike_length(spike)
        height, width = spike.shape[-2], spike.shape[-1]
        grouped = spike.reshape(-1, 50, 4, height, width)
        voxel = torch.sum(grouped, dim=2)
        padded = F.pad(voxel, pad=(20, 20, 20, 20), mode="reflect", value=0)
        return padded

    def postprocess_img(self, image):
        """Strip the 20-pixel reflect border added by preprocess_spike."""
        border = 20
        return image[:, :, border:-border, border:-border]
|
@@ -0,0 +1,50 @@
|
|
1
|
+
import torch
|
2
|
+
from dataclasses import dataclass, field
|
3
|
+
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
4
|
+
|
5
|
+
|
6
|
+
@dataclass
class SpikeFormerConfig(BaseModelConfig):
    """Default configuration for the SpikeFormer model."""

    # default params for SpikeFormer
    model_name: str = "spikeformer"
    model_file_name: str = "Model.SpikeFormer"
    model_cls_name: str = "SpikeFormer"
    model_win_length: int = 65
    require_params: bool = True
    ckpt_path: str = "weights/spikeformer.pth"
    # constructor kwargs for the network; a fresh dict per instance via default_factory
    model_params: dict = field(
        default_factory=lambda: {
            # NOTE(review): equals model_win_length (65); keep the two in sync if either changes.
            "inputDim": 65,
            "dims": (32, 64, 160, 256),
            "heads": (1, 2, 5, 8),
            "ff_expansion": (8, 8, 4, 4),
            "reduction_ratio": (8, 4, 2, 1),
            "num_layers": 2,
            "decoder_dim": 256,
            "out_channel": 1,
        }
    )
|
27
|
+
|
28
|
+
|
29
|
+
class SpikeFormer(BaseModel):
    """SpikeFormer wrapper: pads spikes for known sensor sizes and maps them to {-1, 1}."""

    def __init__(self, cfg: BaseModelConfig):
        super().__init__(cfg)

    def preprocess_spike(self, spike):
        """Crop spikes in time, pad spatially and rescale {0, 1} to {-1, 1}.

        For 250x400 inputs, the first 3 and last 3 entries of dim 2 are
        replicated on their respective sides. For 480x854 inputs, the last 2
        entries of dim 3 are replicated.
        """
        spike = self.crop_spike_length(spike)
        size = self.spike_size
        if size == (250, 400):
            head = spike[:, :, :3, :]
            tail = spike[:, :, -3:, :]
            spike = torch.cat([head, spike, tail], dim=2)
        elif size == (480, 854):
            tail = spike[:, :, :, -2:]
            spike = torch.cat([spike, tail], dim=3)
        return spike * 2 - 1

    def postprocess_img(self, image):
        """Undo the spatial padding from preprocess_spike."""
        size = self.spike_size
        if size == (250, 400):
            return image[:, :, 3:-3, :]
        if size == (480, 854):
            return image[:, :, :, :854]
        return image
|
@@ -0,0 +1,51 @@
|
|
1
|
+
import torch
|
2
|
+
from dataclasses import dataclass, field
|
3
|
+
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
4
|
+
|
5
|
+
|
6
|
+
@dataclass
class Spk2ImgNetConfig(BaseModelConfig):
    """Default configuration for the Spk2ImgNet model."""

    # default params for Spk2ImgNet
    model_name: str = "spk2imgnet"
    model_file_name: str = "nets"
    model_cls_name: str = "SpikeNet"
    model_win_length: int = 41
    require_params: bool = True
    ckpt_path: str = "weights/spk2imgnet.pth"
    # when True, Spk2ImgNet.postprocess_img scales the output by 1/0.6 and clamps to [0, 1]
    light_correction: bool = False

    # model params
    model_params: dict = field(
        default_factory=lambda: {
            # NOTE(review): in_channels == 2 * win_r + 1; presumably a window of 13 frames — keep in sync.
            "in_channels": 13,
            "features": 64,
            "out_channels": 1,
            "win_r": 6,
            "win_step": 7,
        }
    )
|
27
|
+
|
28
|
+
|
29
|
+
class Spk2ImgNet(BaseModel):
    """Spk2ImgNet wrapper: pads spikes for known sensor sizes and optionally brightens the output."""

    def __init__(self, cfg: BaseModelConfig):
        super().__init__(cfg)

    def preprocess_spike(self, spike):
        """Crop spikes in time, then replicate the last 2 entries of dim 2 (250x400) or dim 3 (480x854)."""
        spike = self.crop_spike_length(spike)
        size = self.spike_size
        if size == (250, 400):
            return torch.cat([spike, spike[:, :, -2:]], dim=2)
        if size == (480, 854):
            return torch.cat([spike, spike[:, :, :, -2:]], dim=3)
        return spike

    def postprocess_img(self, image):
        """Crop the padding back off and apply the optional light correction."""
        size = self.spike_size
        if size == (250, 400):
            cropped = image[:, :, :250, :]
        elif size == (480, 854):
            cropped = image[:, :, :, :854]
        else:
            cropped = image
        # used on the REDS_small dataset.
        if self.cfg.light_correction == True:
            return torch.clamp(cropped / 0.6, 0, 1)
        return cropped
|
@@ -0,0 +1,22 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
3
|
+
|
4
|
+
|
5
|
+
@dataclass
class SSIRConfig(BaseModelConfig):
    """Default configuration for the SSIR model."""

    # default params for SSIR
    model_name: str = "ssir"
    model_file_name: str = "models.networks"
    model_cls_name: str = "SSIR"
    model_win_length: int = 41
    require_params: bool = True
    ckpt_path: str = "weights/ssir.pth"
|
14
|
+
|
15
|
+
|
16
|
+
class SSIR(BaseModel):
    """SSIR wrapper; the network output is passed through without post-processing."""

    def __init__(self, cfg: BaseModelConfig):
        super().__init__(cfg)

    def postprocess_img(self, image):
        """Return the network output unchanged."""
        return image
|
@@ -0,0 +1,18 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
3
|
+
|
4
|
+
|
5
|
+
@dataclass
class SSMLConfig(BaseModelConfig):
    """Default configuration for the SSML (DoubleNet) model."""

    # default params for SSML
    model_name: str = "ssml"
    model_file_name: str = "model"
    model_cls_name: str = "DoubleNet"
    model_win_length: int = 41
    require_params: bool = True
    ckpt_path: str = 'weights/ssml.pt'
|
14
|
+
|
15
|
+
|
16
|
+
class SSML(BaseModel):
    """SSML wrapper; it adds no pre/post-processing of its own beyond BaseModel."""

    def __init__(self, cfg: BaseModelConfig):
        # Zero-argument super() resolves to the same BaseModel.__init__.
        super().__init__(cfg)
|
@@ -0,0 +1,37 @@
|
|
1
|
+
import torch
|
2
|
+
from dataclasses import dataclass
|
3
|
+
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
4
|
+
|
5
|
+
|
6
|
+
@dataclass
class STIRConfig(BaseModelConfig):
    """Default configuration for the STIR model."""

    # default params for STIR
    model_name: str = "stir"
    model_file_name: str = "models.networks_STIR"
    model_cls_name: str = "STIR"
    model_win_length: int = 61
    require_params: bool = True
    ckpt_path: str = "weights/stir.pth"
|
15
|
+
|
16
|
+
|
17
|
+
class STIR(BaseModel):
    """STIR wrapper: pads spikes for known sensor sizes and crops the output back."""

    def __init__(self, cfg: BaseModelConfig):
        super().__init__(cfg)

    def preprocess_spike(self, spike):
        """Crop spikes in time, then pad the spatial dims.

        For 250x400 inputs, the last 6 entries of dim 2 are replicated
        (250 -> 256). For 480x854 inputs, the last 10 entries of dim 3 are
        replicated (854 -> 864).
        """
        cropped = self.crop_spike_length(spike)
        size = self.spike_size
        if size == (250, 400):
            return torch.cat((cropped, cropped[:, :, -6:]), dim=2)
        if size == (480, 854):
            return torch.cat((cropped, cropped[:, :, :, -10:]), dim=3)
        return cropped

    def postprocess_img(self, image):
        """Remove the padding from preprocess_spike (the network output is used as-is otherwise)."""
        size = self.spike_size
        if size == (250, 400):
            return image[:, :, :250, :]
        if size == (480, 854):
            return image[:, :, :, :854]
        return image
|
@@ -0,0 +1,18 @@
|
|
1
|
+
from dataclasses import dataclass, field
|
2
|
+
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
3
|
+
|
4
|
+
|
5
|
+
@dataclass
class TFIConfig(BaseModelConfig):
    """Default configuration for the TFI method."""

    # default params for TFI
    model_name: str = "tfi"
    model_file_name: str = "nets"
    model_cls_name: str = "TFIModel"
    model_win_length: int = 41
    # no checkpoint path is configured here because require_params is False
    require_params: bool = False
    # NOTE(review): duplicates model_win_length; keep both values in sync.
    model_params: dict = field(default_factory=lambda: {"model_win_length": 41})
|
14
|
+
|
15
|
+
|
16
|
+
class TFI(BaseModel):
    """TFI wrapper; it adds no pre/post-processing of its own beyond BaseModel."""

    def __init__(self, cfg: BaseModelConfig):
        # Zero-argument super() resolves to the same BaseModel.__init__.
        super().__init__(cfg)
|
@@ -0,0 +1,18 @@
|
|
1
|
+
from dataclasses import dataclass,field
|
2
|
+
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
3
|
+
|
4
|
+
|
5
|
+
@dataclass
class TFPConfig(BaseModelConfig):
    """Default configuration for the TFP method."""

    # default params for TFP
    model_name: str = "tfp"
    model_file_name: str = "nets"
    model_cls_name: str = "TFPModel"
    model_win_length: int = 41
    # no checkpoint path is configured here because require_params is False
    require_params: bool = False
    # NOTE(review): duplicates model_win_length; keep both values in sync.
    model_params: dict = field(default_factory=lambda: {"model_win_length": 41})
|
14
|
+
|
15
|
+
|
16
|
+
class TFP(BaseModel):
    """TFP wrapper; it adds no pre/post-processing of its own beyond BaseModel."""

    def __init__(self, cfg: BaseModelConfig):
        # Zero-argument super() resolves to the same BaseModel.__init__.
        super().__init__(cfg)
|
@@ -0,0 +1,31 @@
|
|
1
|
+
from dataclasses import dataclass, field
|
2
|
+
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
3
|
+
from typing import List
|
4
|
+
|
5
|
+
|
6
|
+
@dataclass
class WGSEConfig(BaseModelConfig):
    """Default configuration for the WGSE (Dwt1dResnetX_TCN) model."""

    # default params for WGSE
    model_name: str = "wgse"
    model_file_name: str = "dwtnets"
    model_cls_name: str = "Dwt1dResnetX_TCN"
    model_win_length: int = 41
    require_params: bool = True
    ckpt_path: str = "weights/wgse.pt"
    # constructor kwargs for the network; a fresh dict per instance via default_factory
    model_params: dict = field(
        default_factory=lambda: {
            "wvlname": "db8",
            "J": 5,
            # NOTE(review): string while yh_size holds ints — confirm what the network constructor expects.
            "yl_size": "15",
            "yh_size": [28, 21, 18, 16, 15],
            "num_residual_blocks": 3,
            "norm": None,
            "ks": 3,
            "store_features": True,
        }
    )
|
27
|
+
|
28
|
+
|
29
|
+
class WGSE(BaseModel):
    """WGSE wrapper; it adds no pre/post-processing of its own beyond BaseModel."""

    def __init__(self, cfg: BaseModelConfig):
        # Zero-argument super() resolves to the same BaseModel.__init__.
        super().__init__(cfg)
|
@@ -0,0 +1,267 @@
|
|
1
|
+
import torch
|
2
|
+
from dataclasses import dataclass, field
|
3
|
+
import os
|
4
|
+
from spikezoo.utils.img_utils import tensor2npy, AverageMeter
|
5
|
+
from spikezoo.utils.spike_utils import load_vidar_dat
|
6
|
+
from spikezoo.metrics import cal_metric_pair, cal_metric_single
|
7
|
+
import numpy as np
|
8
|
+
import cv2
|
9
|
+
from pathlib import Path
|
10
|
+
from enum import Enum, auto
|
11
|
+
from typing import Literal
|
12
|
+
from spikezoo.metrics import metric_pair_names, metric_single_names, metric_all_names
|
13
|
+
from thop import profile
|
14
|
+
import time
|
15
|
+
from datetime import datetime
|
16
|
+
from spikezoo.utils import setup_logging, save_config
|
17
|
+
from tqdm import tqdm
|
18
|
+
from spikezoo.models import build_model_cfg, build_model_name, BaseModel, BaseModelConfig
|
19
|
+
from spikezoo.datasets import build_dataset_cfg, build_dataset_name, BaseDataset, BaseDatasetConfig, build_dataloader
|
20
|
+
from typing import Optional, Union, List
|
21
|
+
|
22
|
+
|
23
|
+
@dataclass
class PipelineConfig:
    # Options for Pipeline. Each bare string below documents the field that
    # follows it; note that the first string is also what Python stores as the
    # class __doc__.
    "Evaluate metrics or not."
    save_metric: bool = True
    "Save recoverd images or not."
    save_img: bool = True
    "Normalizing recoverd images or not."
    save_img_norm: bool = False
    "Normalizing gt or not."
    gt_img_norm: bool = False
    "Save folder for the code running result."
    # empty string -> Pipeline uses a "results" folder next to the package
    save_folder: str = ""
    "Saved experiment name."
    # empty string -> Pipeline uses a timestamp as the run folder name
    exp_name: str = ""
    "Metric names for evaluation."
    metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim"])
    "Different modes for the pipeline."
    # "train_mode" writes under a "train" folder (others under "detect");
    # "single_mode" / "multi_mode" select which model config(s) get saved.
    _mode: Literal["single_mode", "multi_mode", "train_mode"] = "single_mode"
|
41
|
+
|
42
|
+
|
43
|
+
class Pipeline:
    """Spike-to-image pipeline.

    Builds one model and one dataset, reconstructs images, evaluates metrics,
    and writes logs/images under ``self.save_folder``.
    """

    def __init__(
        self,
        cfg: PipelineConfig,
        model_cfg: Union[str, BaseModelConfig],
        dataset_cfg: Union[str, BaseDatasetConfig],
    ):
        """
        Args:
            cfg: pipeline options (saving, metrics, experiment name, mode).
            model_cfg: a model name (built with ``build_model_name``) or a model config.
            dataset_cfg: a dataset name (built with ``build_dataset_name``) or a dataset config.
        """
        self.cfg = cfg
        self._setup_model_data(model_cfg, dataset_cfg)
        self._setup_pipeline()

    def _setup_model_data(self, model_cfg, dataset_cfg):
        """Model and Data setup."""
        # model
        self.model: BaseModel = build_model_name(model_cfg) if isinstance(model_cfg, str) else build_model_cfg(model_cfg)
        self.model = self.model.eval()
        # NOTE: this disables autograd process-wide, not only for this pipeline.
        torch.set_grad_enabled(False)
        # dataset
        self.dataset: BaseDataset = build_dataset_name(dataset_cfg) if isinstance(dataset_cfg, str) else build_dataset_cfg(dataset_cfg)
        self.dataloader = build_dataloader(self.dataset)
        # device used for every tensor this pipeline creates
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def _setup_pipeline(self):
        """Pipeline setup: create the result folder and logger, then dump all configs."""
        # Timestamp with millisecond precision ([:23] drops the last 3 digits of %f).
        self.thistime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")[:23]
        if len(self.cfg.save_folder) == 0:
            self.save_folder = Path(__file__).parent.parent / Path("results")
        else:
            self.save_folder = Path(self.cfg.save_folder)
        mode_name = "train" if self.cfg._mode == "train_mode" else "detect"
        run_name = self.thistime if len(self.cfg.exp_name) == 0 else self.cfg.exp_name
        self.save_folder = self.save_folder / Path(f"{mode_name}/{run_name}")
        save_folder = self.save_folder
        os.makedirs(str(save_folder), exist_ok=True)
        # logger result
        self.logger = setup_logging(save_folder / Path("result.log"))
        self.logger.info(f"Info logs are saved on the {save_folder}/result.log")
        # pipeline config
        save_config(self.cfg, save_folder / Path("cfg_pipeline.log"))
        # model config; "multi_mode" reads self.model_list, which this class does
        # not define (presumably set by a subclass before this runs).
        if self.cfg._mode == "single_mode":
            save_config(self.model.cfg, save_folder / Path("cfg_model.log"))
        elif self.cfg._mode == "multi_mode":
            for model in self.model_list:
                save_config(model.cfg, save_folder / Path("cfg_model.log"), mode="a")
        # dataset config
        save_config(self.dataset.cfg, save_folder / Path("cfg_dataset.log"))

    def spk2img_from_dataset(self, idx=0):
        """Func---Save the recovered image and calculate the metric for dataset sample ``idx``."""
        # save folder
        save_folder = self.save_folder / Path(f"spk2img_from_dataset/{self.dataset.cfg.dataset_name}_dataset/{self.dataset.cfg.split}/{idx:06d}")
        os.makedirs(str(save_folder), exist_ok=True)

        # data process: add a batch dimension and move to the pipeline device
        batch = self.dataset[idx]
        spike, img = batch["spike"], batch["img"]
        spike = spike[None].to(self.device)
        if self.dataset.cfg.with_img == True:
            img = img[None].to(self.device)
        else:
            img = None
        return self._spk2img(spike, img, save_folder)

    def spk2img_from_file(self, file_path, height, width, img_path=None, remove_head=False):
        """Func---Save the recovered image and calculate the metric from the given input file.

        Args:
            file_path: ``.dat`` (vidar) or ``.npz`` (UHSR, key "spk") spike file; str or Path.
            height, width: spatial size passed to ``load_vidar_dat`` for ``.dat`` files.
            img_path: optional ground-truth image (read with OpenCV, converted to gray).
            remove_head: forwarded to ``load_vidar_dat``.

        Raises:
            RuntimeError: unknown spike file extension or unreadable image file.
        """
        file_path = str(file_path)
        # save folder
        save_folder = self.save_folder / Path(f"spk2img_from_file/{os.path.basename(file_path)}")
        os.makedirs(str(save_folder), exist_ok=True)

        # load spike from .dat
        if file_path.endswith(".dat"):
            spike = load_vidar_dat(file_path, height, width, remove_head)
        # load spike from .npz from UHSR; keeps rows/cols 13:237 (a 224x224 window)
        elif file_path.endswith(".npz"):
            spike = np.load(file_path)["spk"].astype(np.float32)[:, 13:237, 13:237]
        else:
            raise RuntimeError("Not recognized spike input file.")
        # load img from .png/.jpg image file
        if img_path is not None:
            img = cv2.imread(img_path)
            # cv2.imread signals failure by returning None instead of raising.
            if img is None:
                raise RuntimeError(f"Could not read image file: {img_path}")
            if img.ndim == 3:
                img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            img = (img / 255).astype(np.float32)
            img = torch.from_numpy(img)[None, None].to(self.device)
        else:
            img = None
        spike = torch.from_numpy(spike)[None].to(self.device)
        return self._spk2img(spike, img, save_folder)

    def spk2img_from_spk(self, spike, img=None):
        """Func---Save the recovered image and calculate the metric from the given spike stream."""
        # save folder
        save_folder = self.save_folder / Path(f"spk2img_from_spk/{self.thistime}")
        os.makedirs(str(save_folder), exist_ok=True)

        # spike process
        if isinstance(spike, np.ndarray):
            spike = torch.from_numpy(spike)
        spike = spike.to(self.device)
        # [c,h,w] -> [1,c,h,w]
        if spike.dim() == 3:
            spike = spike[None]
        spike = spike.float()
        # img process: only numpy images (BGR or gray, 0-255) are accepted
        if img is not None:
            if isinstance(img, np.ndarray):
                img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if img.ndim == 3 else img
                img = (img / 255).astype(np.float32)
                img = torch.from_numpy(img)[None, None].to(self.device)
            else:
                raise RuntimeError("Not recognized image input type.")
        return self._spk2img(spike, img, save_folder)

    def save_imgs_from_dataset(self):
        """Func---Save all images from the given dataset."""
        for idx in range(len(self.dataset)):
            self.spk2img_from_dataset(idx=idx)

    # TODO: To be overridden
    def cal_params(self):
        """Func---Calculate the parameters/flops/latency of the given method."""
        self._cal_prams_model(self.model)

    # TODO: To be overridden
    def cal_metrics(self):
        """Func---Calculate the metric of the given method."""
        self._cal_metrics_model(self.model)

    # TODO: To be overridden
    def _spk2img(self, spike, img, save_folder):
        """Spike-to-image: spike:[bs,c,h,w] (0-1), img:[bs,1,h,w] (0-1)"""
        return self._spk2img_model(self.model, spike, img, save_folder)

    def _spk2img_model(self, model, spike, img, save_folder):
        """Spike-to-image from the given model.

        Logs metrics and saves PNGs according to ``self.cfg``. Returns the raw
        (un-normalized) reconstruction.
        """
        # spike2image conversion
        model_name = model.cfg.model_name
        recon_img = model(spike)
        recon_img_copy = recon_img.clone()
        # normalization
        recon_img, img = self._post_process_img(model_name, recon_img, img)
        # metric
        if self.cfg.save_metric == True:
            self.logger.info(f"----------------------Method: {model_name.upper()}----------------------")
            # paired metrics need a ground truth; single metrics do not
            for metric_name in metric_all_names:
                if img is not None and metric_name in metric_pair_names:
                    self.logger.info(f"{metric_name.upper()}: {cal_metric_pair(recon_img,img,metric_name)}")
                elif metric_name in metric_single_names:
                    self.logger.info(f"{metric_name.upper()}: {cal_metric_single(recon_img,metric_name)}")
                else:
                    self.logger.info(f"{metric_name.upper()} not calculated since no ground truth provided.")

        # visual
        if self.cfg.save_img == True:
            recon_img = tensor2npy(recon_img[0, 0])
            cv2.imwrite(f"{save_folder}/{model.cfg.model_name}.png", recon_img)
            if img is not None:
                img = tensor2npy(img[0, 0])
                cv2.imwrite(f"{save_folder}/sharp_img.png", img)
            self.logger.info(f"Images are saved on the {save_folder}")

        return recon_img_copy

    @staticmethod
    def _min_max_normalize(x):
        """Rescale ``x`` linearly to [0, 1] over all its elements; a constant tensor becomes all zeros (not NaN)."""
        x_min, x_max = x.min(), x.max()
        value_range = x_max - x_min
        if value_range == 0:
            return torch.zeros_like(x)
        return (x - x_min) / value_range

    def _post_process_img(self, model_name, recon_img, gt_img):
        """Post process the reconstructed image (and optionally the ground truth).

        TFP/TFI/SpikeFormer/SpikeCLIP outputs are always min-max normalized;
        other models only when ``self.cfg.save_img_norm`` is set. Normalization
        uses one min/max over the whole tensor (all batch items jointly). The
        reconstruction is then clipped to [0, 1]. ``gt_img`` is min-max
        normalized when ``self.cfg.gt_img_norm`` is set and it is not None.
        """
        if model_name in ["tfp", "tfi", "spikeformer", "spikeclip"] or self.cfg.save_img_norm == True:
            recon_img = self._min_max_normalize(recon_img)
        recon_img = recon_img.clip(0, 1)
        if self.cfg.gt_img_norm == True and gt_img is not None:
            gt_img = self._min_max_normalize(gt_img)
        return recon_img, gt_img

    def _cal_metrics_model(self, model: BaseModel):
        """Calculate and log the dataset-averaged metrics for the given model."""
        # metrics construct: paired metrics only when the dataset provides images
        model_name = model.cfg.model_name
        metrics_dict = {}
        for metric_name in self.cfg.metric_names:
            if (self.dataset.cfg.with_img == True) or (metric_name in metric_single_names):
                metrics_dict[metric_name] = AverageMeter()

        # metrics calculate
        for batch in tqdm(self.dataloader):
            batch = model.feed_to_device(batch)
            outputs = model.get_outputs_dict(batch)
            recon_img, img = model.get_paired_imgs(batch, outputs)
            recon_img, img = self._post_process_img(model_name, recon_img, img)
            for metric_name in metrics_dict.keys():
                if metric_name in metric_pair_names:
                    metrics_dict[metric_name].update(cal_metric_pair(recon_img, img, metric_name))
                elif metric_name in metric_single_names:
                    metrics_dict[metric_name].update(cal_metric_single(recon_img, metric_name))

        # metrics self.logger.info
        self.logger.info(f"----------------------Method: {model_name.upper()}----------------------")
        for metric_name in metrics_dict.keys():
            self.logger.info(f"{metric_name.upper()}: {metrics_dict[metric_name].avg}")

    def _cal_prams_model(self, model):
        """Log parameter count, FLOPs (via thop) and mean latency of ``model``.

        All three are measured on a dummy [1, 200, 250, 400] zero input.
        """
        network = model.net
        model_name = model.cfg.model_name.upper()
        # params
        params = sum(p.numel() for p in network.parameters())
        # latency: use the pipeline device instead of hard-coding CUDA
        spike = torch.zeros((1, 200, 250, 400), device=self.device)
        use_cuda = spike.is_cuda
        if use_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(100):
            model(spike)
        # CUDA kernels run asynchronously; wait for them before stopping the clock.
        if use_cuda:
            torch.cuda.synchronize()
        latency = (time.perf_counter() - start_time) / 100
        # flop # todo thop bug for BSF
        flops, _ = profile((model), inputs=(spike,))
        re_msg = ", ".join(
            (
                "Total params: %.4fM" % (params / 1e6),
                "FLOPs:" + str(flops / 1e9) + "{}".format("G"),
                "Latency: {:.6f} seconds".format(latency),
            )
        )
        self.logger.info(f"----------------------Method: {model_name}----------------------")
        self.logger.info(re_msg)
|