spikezoo 0.1__py3-none-any.whl

@@ -0,0 +1,90 @@
+ import torch
+ from dataclasses import dataclass
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
+
+
+ @dataclass
+ class BSFConfig(BaseModelConfig):
+     # default params for BSF
+     model_name: str = "bsf"
+     model_file_name: str = "models.bsf.bsf"
+     model_cls_name: str = "BSF"
+     model_win_length: int = 61
+     require_params: bool = True
+     ckpt_path: str = "weights/bsf.pth"
+
+
+ class BSF(BaseModel):
+     def __init__(self, cfg: BaseModelConfig):
+         super(BSF, self).__init__(cfg)
+
+     def preprocess_spike(self, spike):
+         # crop the spike stream to the supported window length
+         spike = self.crop_spike_length(spike)
+         # pad the spatial size by replicating border rows/columns
+         if self.spike_size == (250, 400):
+             spike = torch.cat([spike, spike[:, :, -2:]], dim=2)
+         elif self.spike_size == (480, 854):
+             spike = torch.cat([spike, spike[:, :, :, -2:]], dim=3)
+         # dsft
+         dsft = self.compute_dsft_core(spike)
+         dsft_dict = self.convert_dsft4(dsft, spike)
+         input_dict = {
+             "dsft_dict": dsft_dict,
+             "spikes": spike,
+         }
+         return input_dict
+
+     def postprocess_img(self, image):
+         # undo the padding applied in preprocess_spike
+         if self.spike_size == (250, 400):
+             image = image[:, :, :250, :]
+         elif self.spike_size == (480, 854):
+             image = image[:, :, :, :854]
+         return image
+
+     def compute_dsft_core(self, spike):
+         bs, T, H, W = spike.shape
+         device = spike.device  # avoid hardcoding "cuda" so CPU tensors also work
+         time = spike * torch.arange(T, device=device).reshape(1, T, 1, 1)
+         l_idx, _ = time.cummax(dim=1)
+         time[time == 0] = T
+         r_idx, _ = torch.flip(time, [1]).cummin(dim=1)
+         r_idx = torch.flip(r_idx, [1])
+         r_idx = torch.cat([r_idx[:, 1:, :, :], torch.ones([bs, 1, H, W], device=device) * T], dim=1)
+         res = r_idx - l_idx
+         res = torch.clip(res, 0)
+         return res
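+
+     # Worked example for compute_dsft_core (editor's sketch, not part of the
+     # original source): for one pixel with spike train s = [0, 1, 0, 0, 1, 0]
+     # (T = 6), spikes fire at t = 1 and t = 4, so every t in [1, 4) gets
+     # l_idx = 1 and r_idx = 4, i.e. a DSFT value of 3 (the inter-spike
+     # interval); positions before the first spike fall back to l_idx = 0 and
+     # positions from the last spike onward use the sentinel r_idx = T.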
+
+     def convert_dsft4(self, dsft, spike):
+         b, T, h, w = spike.shape
+         dmls1 = -1 * torch.ones(spike.shape, device=spike.device, dtype=torch.float32)
+         dmrs1 = -1 * torch.ones(spike.shape, device=spike.device, dtype=torch.float32)
+         # sweep right-to-left to gather the DSFT of one neighboring interval
+         flag = -2 * torch.ones([b, h, w], device=spike.device, dtype=torch.float32)
+         for ii in range(T - 1, 0 - 1, -1):
+             flag += spike[:, ii] == 1
+             copy_pad_coord = flag < 0
+             dmls1[:, ii][copy_pad_coord] = dsft[:, ii][copy_pad_coord]
+             if ii < T - 1:
+                 update_coord = (spike[:, ii + 1] == 1) * (~copy_pad_coord)
+                 dmls1[:, ii][update_coord] = dsft[:, ii + 1][update_coord]
+                 non_update_coord = (spike[:, ii + 1] != 1) * (~copy_pad_coord)
+                 dmls1[:, ii][non_update_coord] = dmls1[:, ii + 1][non_update_coord]
+         # sweep left-to-right for the neighboring interval on the other side
+         flag = -2 * torch.ones([b, h, w], device=spike.device, dtype=torch.float32)
+         for ii in range(0, T, 1):
+             flag += spike[:, ii] == 1
+             copy_pad_coord = flag < 0
+             dmrs1[:, ii][copy_pad_coord] = dsft[:, ii][copy_pad_coord]
+             if ii > 0:
+                 update_coord = (spike[:, ii] == 1) * (~copy_pad_coord)
+                 dmrs1[:, ii][update_coord] = dsft[:, ii - 1][update_coord]
+                 non_update_coord = (spike[:, ii] != 1) * (~copy_pad_coord)
+                 dmrs1[:, ii][non_update_coord] = dmrs1[:, ii - 1][non_update_coord]
+         dsft12 = dsft + dmls1
+         dsft21 = dsft + dmrs1
+         dsft22 = dsft + dmls1 + dmrs1
+         dsft_dict = {
+             "dsft11": dsft,
+             "dsft12": dsft12,
+             "dsft21": dsft21,
+             "dsft22": dsft22,
+         }
+         return dsft_dict
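+
+
+ # Note on convert_dsft4 (hedged editor's reading): dsft11 measures a single
+ # inter-spike interval, dsft12 and dsft21 each extend it by one neighboring
+ # interval on one side, and dsft22 extends it on both sides; the BSF network
+ # consumes all four maps from dsft_dict.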
@@ -0,0 +1,19 @@
+ from dataclasses import dataclass
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
+
+
+ @dataclass
+ class SPCSNetConfig(BaseModelConfig):
+     # default params for SPCSNet
+     model_name: str = "spcsnet"
+     model_file_name: str = "models"
+     model_cls_name: str = "SPCS_Net"
+     model_win_length: int = 41
+     require_params: bool = True
+     ckpt_path: str = "weights/spcsnet.pth"
+
+
+ class SPCSNet(BaseModel):
+     def __init__(self, cfg: BaseModelConfig):
+         super(SPCSNet, self).__init__(cfg)
@@ -0,0 +1,32 @@
+ from dataclasses import dataclass
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
+ import torch
+ import torch.nn.functional as F
+
+
+ @dataclass
+ class SpikeCLIPConfig(BaseModelConfig):
+     # default params for SpikeCLIP
+     model_name: str = "spikeclip"
+     model_file_name: str = "nets"
+     model_cls_name: str = "LRN"
+     model_win_length: int = 200
+     require_params: bool = True
+     ckpt_path: str = "weights/spikeclip.pth"
+
+
+ class SpikeCLIP(BaseModel):
+     def __init__(self, cfg: BaseModelConfig):
+         super(SpikeCLIP, self).__init__(cfg)
+
+     def preprocess_spike(self, spike):
+         # crop the spike stream to the supported window length
+         spike = self.crop_spike_length(spike)
+         # bin every 4 consecutive frames into one voxel channel, e.g. [200,224,224] -> [50,224,224]
+         voxel = torch.sum(spike.reshape(-1, 50, 4, spike.shape[-2], spike.shape[-1]), dim=2)
+         voxel = F.pad(voxel, pad=(20, 20, 20, 20), mode="reflect")
+         return voxel
+
+     def postprocess_img(self, image):
+         # crop away the 20-pixel reflect padding added in preprocess_spike
+         image = image[:, :, 20:-20, 20:-20]
+         return image
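+
+
+ # Shape walkthrough (editor's sketch, assuming the 224x224 input noted above):
+ # 200 frames are reshaped to [B, 50, 4, H, W] and summed over the group of 4,
+ # so each of the 50 voxel channels counts spikes in a 4-frame bin; the reflect
+ # padding then grows H and W by 40 each (224 -> 264), and postprocess_img
+ # crops the 20-pixel border back off.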
@@ -0,0 +1,50 @@
+ import torch
+ from dataclasses import dataclass, field
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
+
+
+ @dataclass
+ class SpikeFormerConfig(BaseModelConfig):
+     # default params for SpikeFormer
+     model_name: str = "spikeformer"
+     model_file_name: str = "Model.SpikeFormer"
+     model_cls_name: str = "SpikeFormer"
+     model_win_length: int = 65
+     require_params: bool = True
+     ckpt_path: str = "weights/spikeformer.pth"
+     model_params: dict = field(
+         default_factory=lambda: {
+             "inputDim": 65,
+             "dims": (32, 64, 160, 256),
+             "heads": (1, 2, 5, 8),
+             "ff_expansion": (8, 8, 4, 4),
+             "reduction_ratio": (8, 4, 2, 1),
+             "num_layers": 2,
+             "decoder_dim": 256,
+             "out_channel": 1,
+         }
+     )
+
+
+ class SpikeFormer(BaseModel):
+     def __init__(self, cfg: BaseModelConfig):
+         super(SpikeFormer, self).__init__(cfg)
+
+     def preprocess_spike(self, spike):
+         # crop the spike stream to the supported window length
+         spike = self.crop_spike_length(spike)
+         # pad the spatial size by replicating border rows/columns
+         if self.spike_size == (250, 400):
+             spike = torch.cat([spike[:, :, :3, :], spike, spike[:, :, -3:, :]], dim=2)
+         elif self.spike_size == (480, 854):
+             spike = torch.cat([spike, spike[:, :, :, -2:]], dim=3)
+         # map the binary stream {0,1} to {-1,1} as the network input
+         spike = 2 * spike - 1
+         return spike
+
+     def postprocess_img(self, image):
+         # undo the padding applied in preprocess_spike
+         if self.spike_size == (250, 400):
+             image = image[:, :, 3:-3, :]
+         elif self.spike_size == (480, 854):
+             image = image[:, :, :, :854]
+         return image
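+
+
+ # Note (editor's assumption): the padded sizes above (250 -> 256 rows,
+ # 854 -> 856 columns) appear chosen so the spatial dimensions divide evenly
+ # through the encoder's downsampling stages; the exact factor is fixed by the
+ # SpikeFormer network itself, not by this wrapper.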
@@ -0,0 +1,51 @@
+ import torch
+ from dataclasses import dataclass, field
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
+
+
+ @dataclass
+ class Spk2ImgNetConfig(BaseModelConfig):
+     # default params for Spk2ImgNet
+     model_name: str = "spk2imgnet"
+     model_file_name: str = "nets"
+     model_cls_name: str = "SpikeNet"
+     model_win_length: int = 41
+     require_params: bool = True
+     ckpt_path: str = "weights/spk2imgnet.pth"
+     light_correction: bool = False
+
+     # model params
+     model_params: dict = field(
+         default_factory=lambda: {
+             "in_channels": 13,
+             "features": 64,
+             "out_channels": 1,
+             "win_r": 6,
+             "win_step": 7,
+         }
+     )
+
+
+ class Spk2ImgNet(BaseModel):
+     def __init__(self, cfg: BaseModelConfig):
+         super(Spk2ImgNet, self).__init__(cfg)
+
+     def preprocess_spike(self, spike):
+         # crop the spike stream to the supported window length
+         spike = self.crop_spike_length(spike)
+         # pad the spatial size by replicating border rows/columns
+         if self.spike_size == (250, 400):
+             spike = torch.cat([spike, spike[:, :, -2:]], dim=2)
+         elif self.spike_size == (480, 854):
+             spike = torch.cat([spike, spike[:, :, :, -2:]], dim=3)
+         return spike
+
+     def postprocess_img(self, image):
+         if self.spike_size == (250, 400):
+             image = image[:, :, :250, :]
+         elif self.spike_size == (480, 854):
+             image = image[:, :, :, :854]
+         # brightness correction used on the REDS_small dataset
+         if self.cfg.light_correction:
+             image = torch.clamp(image / 0.6, 0, 1)
+         return image
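+
+
+ # Note (editor's reading of Spk2ImgNetConfig above): in_channels = 2 * win_r + 1
+ # = 13, i.e. the network sees 13-frame sub-windows, and win_step = 7 slides the
+ # sub-window across the 41-frame input (yielding 5 sub-windows).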
@@ -0,0 +1,22 @@
+ from dataclasses import dataclass
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
+
+
+ @dataclass
+ class SSIRConfig(BaseModelConfig):
+     # default params for SSIR
+     model_name: str = "ssir"
+     model_file_name: str = "models.networks"
+     model_cls_name: str = "SSIR"
+     model_win_length: int = 41
+     require_params: bool = True
+     ckpt_path: str = "weights/ssir.pth"
+
+
+ class SSIR(BaseModel):
+     def __init__(self, cfg: BaseModelConfig):
+         super(SSIR, self).__init__(cfg)
+
+     def postprocess_img(self, image):
+         # image = image[0]
+         return image
@@ -0,0 +1,18 @@
+ from dataclasses import dataclass
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
+
+
+ @dataclass
+ class SSMLConfig(BaseModelConfig):
+     # default params for SSML
+     model_name: str = "ssml"
+     model_file_name: str = "model"
+     model_cls_name: str = "DoubleNet"
+     model_win_length: int = 41
+     require_params: bool = True
+     ckpt_path: str = "weights/ssml.pt"
+
+
+ class SSML(BaseModel):
+     def __init__(self, cfg: BaseModelConfig):
+         super(SSML, self).__init__(cfg)
@@ -0,0 +1,37 @@
+ import torch
+ from dataclasses import dataclass
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
+
+
+ @dataclass
+ class STIRConfig(BaseModelConfig):
+     # default params for STIR
+     model_name: str = "stir"
+     model_file_name: str = "models.networks_STIR"
+     model_cls_name: str = "STIR"
+     model_win_length: int = 61
+     require_params: bool = True
+     ckpt_path: str = "weights/stir.pth"
+
+
+ class STIR(BaseModel):
+     def __init__(self, cfg: BaseModelConfig):
+         super(STIR, self).__init__(cfg)
+
+     def preprocess_spike(self, spike):
+         # crop the spike stream to the supported window length
+         spike = self.crop_spike_length(spike)
+         # pad the spatial size by replicating border rows/columns
+         if self.spike_size == (250, 400):
+             spike = torch.cat([spike, spike[:, :, -6:]], dim=2)
+         elif self.spike_size == (480, 854):
+             spike = torch.cat([spike, spike[:, :, :, -10:]], dim=3)
+         return spike
+
+     def postprocess_img(self, image):
+         # recon, Fs_lv_0, Fs_lv_1, Fs_lv_2, Fs_lv_3, Fs_lv_4, Est = image
+         if self.spike_size == (250, 400):
+             image = image[:, :, :250, :]
+         elif self.spike_size == (480, 854):
+             image = image[:, :, :, :854]
+         return image
@@ -0,0 +1,18 @@
+ from dataclasses import dataclass, field
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
+
+
+ @dataclass
+ class TFIConfig(BaseModelConfig):
+     # default params for TFI
+     model_name: str = "tfi"
+     model_file_name: str = "nets"
+     model_cls_name: str = "TFIModel"
+     model_win_length: int = 41
+     require_params: bool = False
+     model_params: dict = field(default_factory=lambda: {"model_win_length": 41})
+
+
+ class TFI(BaseModel):
+     def __init__(self, cfg: BaseModelConfig):
+         super(TFI, self).__init__(cfg)
@@ -0,0 +1,18 @@
+ from dataclasses import dataclass, field
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
+
+
+ @dataclass
+ class TFPConfig(BaseModelConfig):
+     # default params for TFP
+     model_name: str = "tfp"
+     model_file_name: str = "nets"
+     model_cls_name: str = "TFPModel"
+     model_win_length: int = 41
+     require_params: bool = False
+     model_params: dict = field(default_factory=lambda: {"model_win_length": 41})
+
+
+ class TFP(BaseModel):
+     def __init__(self, cfg: BaseModelConfig):
+         super(TFP, self).__init__(cfg)
@@ -0,0 +1,31 @@
+ from dataclasses import dataclass, field
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
+
+
+ @dataclass
+ class WGSEConfig(BaseModelConfig):
+     # default params for WGSE
+     model_name: str = "wgse"
+     model_file_name: str = "dwtnets"
+     model_cls_name: str = "Dwt1dResnetX_TCN"
+     model_win_length: int = 41
+     require_params: bool = True
+     ckpt_path: str = "weights/wgse.pt"
+     model_params: dict = field(
+         default_factory=lambda: {
+             "wvlname": "db8",
+             "J": 5,
+             "yl_size": "15",
+             "yh_size": [28, 21, 18, 16, 15],
+             "num_residual_blocks": 3,
+             "norm": None,
+             "ks": 3,
+             "store_features": True,
+         }
+     )
+
+
+ class WGSE(BaseModel):
+     def __init__(self, cfg: BaseModelConfig):
+         super(WGSE, self).__init__(cfg)
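+
+
+ # Note (editor's reading of WGSEConfig above): with a 41-frame window, a db8
+ # wavelet (16-tap filters) and J = 5 decomposition levels, the per-level
+ # detail-coefficient lengths come out to 28, 21, 18, 16 and 15 (yh_size), and
+ # the final approximation length is 15 (yl_size), so these sizes follow from
+ # model_win_length.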
@@ -0,0 +1,4 @@
+ from .base_pipeline import Pipeline, PipelineConfig
+ from .ensemble_pipeline import EnsemblePipelineConfig, EnsemblePipeline
+ from .train_pipeline import TrainPipelineConfig, TrainPipeline
@@ -0,0 +1,267 @@
+ import torch
+ from dataclasses import dataclass, field
+ import os
+ from spikezoo.utils.img_utils import tensor2npy, AverageMeter
+ from spikezoo.utils.spike_utils import load_vidar_dat
+ from spikezoo.metrics import cal_metric_pair, cal_metric_single
+ import numpy as np
+ import cv2
+ from pathlib import Path
+ from typing import Literal, Union, List
+ from spikezoo.metrics import metric_pair_names, metric_single_names, metric_all_names
+ from thop import profile
+ import time
+ from datetime import datetime
+ from spikezoo.utils import setup_logging, save_config
+ from tqdm import tqdm
+ from spikezoo.models import build_model_cfg, build_model_name, BaseModel, BaseModelConfig
+ from spikezoo.datasets import build_dataset_cfg, build_dataset_name, BaseDataset, BaseDatasetConfig, build_dataloader
+
+
+ @dataclass
+ class PipelineConfig:
+     "Evaluate metrics or not."
+     save_metric: bool = True
+     "Save recovered images or not."
+     save_img: bool = True
+     "Normalize recovered images or not."
+     save_img_norm: bool = False
+     "Normalize the ground truth or not."
+     gt_img_norm: bool = False
+     "Save folder for the running results."
+     save_folder: str = ""
+     "Saved experiment name."
+     exp_name: str = ""
+     "Metric names for evaluation."
+     metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim"])
+     "Different modes for the pipeline."
+     _mode: Literal["single_mode", "multi_mode", "train_mode"] = "single_mode"
+
+
+ class Pipeline:
+     def __init__(
+         self,
+         cfg: PipelineConfig,
+         model_cfg: Union[str, BaseModelConfig],
+         dataset_cfg: Union[str, BaseDatasetConfig],
+     ):
+         self.cfg = cfg
+         self._setup_model_data(model_cfg, dataset_cfg)
+         self._setup_pipeline()
+
+     def _setup_model_data(self, model_cfg, dataset_cfg):
+         """Model and data setup."""
+         # model: build from a registered name or from an explicit config
+         self.model: BaseModel = build_model_name(model_cfg) if isinstance(model_cfg, str) else build_model_cfg(model_cfg)
+         self.model = self.model.eval()
+         torch.set_grad_enabled(False)
+         # dataset
+         self.dataset: BaseDataset = build_dataset_name(dataset_cfg) if isinstance(dataset_cfg, str) else build_dataset_cfg(dataset_cfg)
+         self.dataloader = build_dataloader(self.dataset)
+         # device
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     def _setup_pipeline(self):
+         """Pipeline setup."""
+         # save folder
+         self.thistime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")[:23]
+         self.save_folder = Path(__file__).parent.parent / "results" if len(self.cfg.save_folder) == 0 else Path(self.cfg.save_folder)
+         mode_name = "train" if self.cfg._mode == "train_mode" else "detect"
+         self.save_folder = (
+             self.save_folder / Path(f"{mode_name}/{self.thistime}")
+             if len(self.cfg.exp_name) == 0
+             else self.save_folder / Path(f"{mode_name}/{self.cfg.exp_name}")
+         )
+         save_folder = self.save_folder
+         os.makedirs(str(save_folder), exist_ok=True)
+         # logger result
+         self.logger = setup_logging(save_folder / Path("result.log"))
+         self.logger.info(f"Info logs are saved to {save_folder}/result.log")
+         # pipeline config
+         save_config(self.cfg, save_folder / Path("cfg_pipeline.log"))
+         # model config
+         if self.cfg._mode == "single_mode":
+             save_config(self.model.cfg, save_folder / Path("cfg_model.log"))
+         elif self.cfg._mode == "multi_mode":
+             for model in self.model_list:
+                 save_config(model.cfg, save_folder / Path("cfg_model.log"), mode="a")
+         # dataset config
+         save_config(self.dataset.cfg, save_folder / Path("cfg_dataset.log"))
+
+     def spk2img_from_dataset(self, idx=0):
+         """Func---Save the recovered image and calculate the metrics for a sample from the given dataset."""
+         # save folder
+         save_folder = self.save_folder / Path(f"spk2img_from_dataset/{self.dataset.cfg.dataset_name}_dataset/{self.dataset.cfg.split}/{idx:06d}")
+         os.makedirs(str(save_folder), exist_ok=True)
+
+         # data process
+         batch = self.dataset[idx]
+         spike, img = batch["spike"], batch["img"]
+         spike = spike[None].to(self.device)
+         if self.dataset.cfg.with_img:
+             img = img[None].to(self.device)
+         else:
+             img = None
+         return self._spk2img(spike, img, save_folder)
+
+     def spk2img_from_file(self, file_path, height, width, img_path=None, remove_head=False):
+         """Func---Save the recovered image and calculate the metrics for the given input file."""
+         # save folder
+         save_folder = self.save_folder / Path(f"spk2img_from_file/{os.path.basename(file_path)}")
+         os.makedirs(str(save_folder), exist_ok=True)
+
+         # load spike from a .dat file
+         if file_path.endswith(".dat"):
+             spike = load_vidar_dat(file_path, height, width, remove_head)
+         # load spike from a .npz file (UHSR)
+         elif file_path.endswith(".npz"):
+             spike = np.load(file_path)["spk"].astype(np.float32)[:, 13:237, 13:237]
+         else:
+             raise RuntimeError("Unrecognized spike input file.")
+         # load the ground-truth image from a .png/.jpg file
+         if img_path is not None:
+             img = cv2.imread(img_path)
+             if img.ndim == 3:
+                 img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+             img = (img / 255).astype(np.float32)
+             img = torch.from_numpy(img)[None, None].to(self.device)
+         else:
+             img = None
+         spike = torch.from_numpy(spike)[None].to(self.device)
+         return self._spk2img(spike, img, save_folder)
+
+     def spk2img_from_spk(self, spike, img=None):
+         """Func---Save the recovered image and calculate the metrics for the given spike stream."""
+         # save folder
+         save_folder = self.save_folder / Path(f"spk2img_from_spk/{self.thistime}")
+         os.makedirs(str(save_folder), exist_ok=True)
+
+         # spike process
+         if isinstance(spike, np.ndarray):
+             spike = torch.from_numpy(spike)
+         spike = spike.to(self.device)
+         # [c,h,w] -> [1,c,h,w]
+         if spike.dim() == 3:
+             spike = spike[None]
+         spike = spike.float()
+         # img process
+         if img is not None:
+             if isinstance(img, np.ndarray):
+                 img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if img.ndim == 3 else img
+                 img = (img / 255).astype(np.float32)
+                 img = torch.from_numpy(img)[None, None].to(self.device)
+             else:
+                 raise RuntimeError("Unrecognized image input type.")
+         return self._spk2img(spike, img, save_folder)
+
+     def save_imgs_from_dataset(self):
+         """Func---Save all images from the given dataset."""
+         for idx in range(len(self.dataset)):
+             self.spk2img_from_dataset(idx=idx)
+
+     # TODO: To be overridden
+     def cal_params(self):
+         """Func---Calculate the parameters/FLOPs/latency of the given method."""
+         self._cal_prams_model(self.model)
+
+     # TODO: To be overridden
+     def cal_metrics(self):
+         """Func---Calculate the metrics of the given method."""
+         self._cal_metrics_model(self.model)
+
+     # TODO: To be overridden
+     def _spk2img(self, spike, img, save_folder):
+         """Spike-to-image conversion: spike:[bs,c,h,w] (0-1), img:[bs,1,h,w] (0-1)."""
+         return self._spk2img_model(self.model, spike, img, save_folder)
+
+     def _spk2img_model(self, model, spike, img, save_folder):
+         """Spike-to-image conversion with the given model."""
+         # spike-to-image conversion
+         model_name = model.cfg.model_name
+         recon_img = model(spike)
+         recon_img_copy = recon_img.clone()
+         # normalization
+         recon_img, img = self._post_process_img(model_name, recon_img, img)
+         # metric
+         if self.cfg.save_metric:
+             self.logger.info(f"----------------------Method: {model_name.upper()}----------------------")
+             # paired metrics need the ground truth, single metrics do not
+             for metric_name in metric_all_names:
+                 if img is not None and metric_name in metric_pair_names:
+                     self.logger.info(f"{metric_name.upper()}: {cal_metric_pair(recon_img, img, metric_name)}")
+                 elif metric_name in metric_single_names:
+                     self.logger.info(f"{metric_name.upper()}: {cal_metric_single(recon_img, metric_name)}")
+                 else:
+                     self.logger.info(f"{metric_name.upper()} not calculated since no ground truth is provided.")
+
+         # visual
+         if self.cfg.save_img:
+             recon_img = tensor2npy(recon_img[0, 0])
+             cv2.imwrite(f"{save_folder}/{model.cfg.model_name}.png", recon_img)
+             if img is not None:
+                 img = tensor2npy(img[0, 0])
+                 cv2.imwrite(f"{save_folder}/sharp_img.png", img)
+             self.logger.info(f"Images are saved to {save_folder}")
+
+         return recon_img_copy
+
+     def _post_process_img(self, model_name, recon_img, gt_img):
+         """Post-process the reconstructed image."""
+         # TFP, TFI, SpikeFormer and SpikeCLIP outputs are normalized automatically; the others follow self.cfg.save_img_norm
+         if model_name in ["tfp", "tfi", "spikeformer", "spikeclip"]:
+             recon_img = (recon_img - recon_img.min()) / (recon_img.max() - recon_img.min())
+         elif self.cfg.save_img_norm:
+             recon_img = (recon_img - recon_img.min()) / (recon_img.max() - recon_img.min())
+         recon_img = recon_img.clip(0, 1)
+         gt_img = (gt_img - gt_img.min()) / (gt_img.max() - gt_img.min()) if self.cfg.gt_img_norm and gt_img is not None else gt_img
+         return recon_img, gt_img
+
+     def _cal_metrics_model(self, model: BaseModel):
+         """Calculate the metrics for the given model."""
+         # construct metric meters
+         model_name = model.cfg.model_name
+         metrics_dict = {}
+         for metric_name in self.cfg.metric_names:
+             if self.dataset.cfg.with_img or (metric_name in metric_single_names):
+                 metrics_dict[metric_name] = AverageMeter()
+
+         # accumulate metrics over the dataset
+         for batch_idx, batch in enumerate(tqdm(self.dataloader)):
+             batch = model.feed_to_device(batch)
+             outputs = model.get_outputs_dict(batch)
+             recon_img, img = model.get_paired_imgs(batch, outputs)
+             recon_img, img = self._post_process_img(model_name, recon_img, img)
+             for metric_name in metrics_dict.keys():
+                 if metric_name in metric_pair_names:
+                     metrics_dict[metric_name].update(cal_metric_pair(recon_img, img, metric_name))
+                 elif metric_name in metric_single_names:
+                     metrics_dict[metric_name].update(cal_metric_single(recon_img, metric_name))
+
+         # log metrics
+         self.logger.info(f"----------------------Method: {model_name.upper()}----------------------")
+         for metric_name in metrics_dict.keys():
+             self.logger.info(f"{metric_name.upper()}: {metrics_dict[metric_name].avg}")
+
+     def _cal_prams_model(self, model):
+         """Calculate the parameters/FLOPs/latency for the given model."""
+         network = model.net
+         model_name = model.cfg.model_name.upper()
+         # params
+         params = sum(p.numel() for p in network.parameters())
+         # latency, averaged over 100 runs (synchronize so GPU kernels are timed correctly)
+         spike = torch.zeros((1, 200, 250, 400), device=self.device)
+         if self.device == "cuda":
+             torch.cuda.synchronize()
+         start_time = time.time()
+         for _ in range(100):
+             model(spike)
+         if self.device == "cuda":
+             torch.cuda.synchronize()
+         latency = (time.time() - start_time) / 100
+         # flops # todo: thop bug for BSF
+         flops, _ = profile(model, inputs=(spike,))
+         re_msg = (
+             "Total params: %.4fM" % (params / 1e6),
+             "FLOPs: %.4fG" % (flops / 1e9),
+             "Latency: {:.6f} seconds".format(latency),
+         )
+         self.logger.info(f"----------------------Method: {model_name}----------------------")
+         self.logger.info(re_msg)
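+
+
+ # Minimal usage sketch (editor's addition; "tfp" is a model name registered
+ # above, while the dataset name and .dat path are illustrative placeholders):
+ #
+ #     from spikezoo.pipeline import Pipeline, PipelineConfig
+ #     pipeline = Pipeline(cfg=PipelineConfig(), model_cfg="tfp", dataset_cfg="base")
+ #     pipeline.spk2img_from_file("data/spike.dat", height=250, width=400)
+ #     pipeline.cal_metrics()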