spikezoo 0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,90 @@
1
+ import torch
2
+ from dataclasses import dataclass
3
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
4
+
5
+
6
@dataclass
class BSFConfig(BaseModelConfig):
    """Default configuration for the BSF spike-to-image model."""

    # identification / loading of the network implementation
    model_name: str = "bsf"
    model_file_name: str = "models.bsf.bsf"
    model_cls_name: str = "BSF"
    # temporal window length used by this model
    model_win_length: int = 61
    # pretrained weights
    require_params: bool = True
    ckpt_path: str = "weights/bsf.pth"
15
+
16
+
17
class BSF(BaseModel):
    """Spikezoo wrapper for the BSF network.

    Pads the spike stream for two known resolutions, derives the DSFT
    inter-spike-interval representations the network consumes, and crops the
    reconstruction back afterwards.
    """

    def __init__(self, cfg: BaseModelConfig):
        super().__init__(cfg)

    def preprocess_spike(self, spike):
        """Build the network input dict from a spike tensor.

        Returns a dict with ``"dsft_dict"`` (see ``convert_dsft4``) and
        ``"spikes"`` (the cropped/padded spike tensor).
        """
        # length
        spike = self.crop_spike_length(spike)
        # size: replicate the last 2 rows (250x400) or last 2 columns (480x854);
        # postprocess_img crops the padding away again.
        if self.spike_size == (250, 400):
            spike = torch.cat([spike, spike[:, :, -2:]], dim=2)
        elif self.spike_size == (480, 854):
            spike = torch.cat([spike, spike[:, :, :, -2:]], dim=3)
        # dsft
        dsft = self.compute_dsft_core(spike)
        dsft_dict = self.convert_dsft4(dsft, spike)
        input_dict = {
            "dsft_dict": dsft_dict,
            "spikes": spike,
        }
        return input_dict

    def postprocess_img(self, image):
        """Remove the padding added in ``preprocess_spike``."""
        if self.spike_size == (250, 400):
            image = image[:, :, :250, :]
        elif self.spike_size == (480, 854):
            image = image[:, :, :, :854]
        return image

    def compute_dsft_core(self, spike):
        """Per-pixel, per-step length of the inter-spike interval containing each step.

        ``spike`` is unpacked as ``[bs, T, H, W]``. For every step ``t`` the result
        is ``next_spike(>t) - last_spike(<=t)``, clipped at 0. ``T`` is used when
        there is no later spike. All tensors are created on ``spike.device``, so
        CPU and any CUDA device work.

        NOTE: spike times are encoded as ``spike * t``, so a spike at ``t == 0`` is
        indistinguishable from "no spike" (both are 0).
        """
        bs, T, H, W = spike.shape
        device = spike.device
        time = spike * torch.arange(T, device=device).reshape(1, T, 1, 1)
        # latest spike time at or before t
        l_idx, _ = time.cummax(dim=1)
        # mark "no spike" with T so the reversed cummin finds the next spike
        time[time == 0] = T
        r_idx, _ = torch.flip(time, [1]).cummin(dim=1)
        r_idx = torch.flip(r_idx, [1])
        # shift by one step so r_idx[t] is the first spike strictly after t
        r_idx = torch.cat([r_idx[:, 1:, :, :], torch.ones([bs, 1, H, W], device=device) * T], dim=1)
        res = r_idx - l_idx
        res = torch.clip(res, 0)
        return res

    def convert_dsft4(self, dsft, spike):
        """Combine neighbouring DSFT intervals into four representations.

        ``dmls1`` is filled by a backward sweep, ``dmrs1`` by a forward sweep. At
        each step they take the DSFT value of the adjacent interval: the
        following interval for ``dmls1`` and the preceding interval for ``dmrs1``
        (presumably; confirm against the BSF paper). Until two spikes have been
        seen in the sweep direction, the pixel's own DSFT value is copied.

        Returns:
            dict with ``dsft11`` (dsft), ``dsft12`` (dsft + dmls1),
            ``dsft21`` (dsft + dmrs1) and ``dsft22`` (dsft + dmls1 + dmrs1).
        """
        b, T, h, w = spike.shape
        dmls1 = -1 * torch.ones(spike.shape, device=spike.device, dtype=torch.float32)
        dmrs1 = -1 * torch.ones(spike.shape, device=spike.device, dtype=torch.float32)
        # backward sweep: flag counts spikes seen from the end, starting at -2
        flag = -2 * torch.ones([b, h, w], device=spike.device, dtype=torch.float32)
        for ii in range(T - 1, 0 - 1, -1):
            flag += spike[:, ii] == 1
            copy_pad_coord = flag < 0
            dmls1[:, ii][copy_pad_coord] = dsft[:, ii][copy_pad_coord]
            if ii < T - 1:
                update_coord = (spike[:, ii + 1] == 1) * (~copy_pad_coord)
                dmls1[:, ii][update_coord] = dsft[:, ii + 1][update_coord]
                non_update_coord = (spike[:, ii + 1] != 1) * (~copy_pad_coord)
                dmls1[:, ii][non_update_coord] = dmls1[:, ii + 1][non_update_coord]
        # forward sweep: flag counts spikes seen from the start, starting at -2
        flag = -2 * torch.ones([b, h, w], device=spike.device, dtype=torch.float32)
        for ii in range(0, T, 1):
            flag += spike[:, ii] == 1
            copy_pad_coord = flag < 0
            dmrs1[:, ii][copy_pad_coord] = dsft[:, ii][copy_pad_coord]
            if ii > 0:
                update_coord = (spike[:, ii] == 1) * (~copy_pad_coord)
                dmrs1[:, ii][update_coord] = dsft[:, ii - 1][update_coord]
                non_update_coord = (spike[:, ii] != 1) * (~copy_pad_coord)
                dmrs1[:, ii][non_update_coord] = dmrs1[:, ii - 1][non_update_coord]
        dsft12 = dsft + dmls1
        dsft21 = dsft + dmrs1
        dsft22 = dsft + dmls1 + dmrs1
        dsft_dict = {
            "dsft11": dsft,
            "dsft12": dsft12,
            "dsft21": dsft21,
            "dsft22": dsft22,
        }
        return dsft_dict
@@ -0,0 +1,19 @@
1
+ from dataclasses import dataclass, field
2
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
3
+ from typing import List
4
+
5
+
6
@dataclass
class SPCSNetConfig(BaseModelConfig):
    """Default configuration for the SPCS-Net spike-to-image model."""

    # default params for SPCSNet
    model_name: str = "spcsnet"
    model_file_name: str = "models"
    model_cls_name: str = "SPCS_Net"
    model_win_length: int = 41
    require_params: bool = True
    ckpt_path: str = "weights/spcsnet.pth"
15
+
16
+
17
class SPCSNet(BaseModel):
    """Spikezoo wrapper for SPCS-Net; adds no overrides beyond the constructor."""

    def __init__(self, cfg: BaseModelConfig):
        super().__init__(cfg)
@@ -0,0 +1,32 @@
1
+ from dataclasses import dataclass
2
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
@dataclass
class SpikeCLIPConfig(BaseModelConfig):
    """Default configuration for the SpikeCLIP reconstruction model."""

    # identification / loading of the network implementation
    model_name: str = "spikeclip"
    model_file_name: str = "nets"
    model_cls_name: str = "LRN"
    # SpikeCLIP.preprocess_spike reshapes the cropped window into 50 x 4 frames
    model_win_length: int = 200
    # pretrained weights
    require_params: bool = True
    ckpt_path: str = "weights/spikeclip.pth"
16
+
17
+
18
class SpikeCLIP(BaseModel):
    """Spikezoo wrapper for SpikeCLIP: 4-frame binning plus 20 px reflect padding."""

    def __init__(self, cfg: BaseModelConfig):
        super().__init__(cfg)

    def preprocess_spike(self, spike):
        """Crop the window, sum every 4 consecutive frames and reflect-pad by 20 px.

        The reshape below fixes 50 groups of 4 frames, i.e. a 200-frame window.
        """
        spike = self.crop_spike_length(spike)
        height, width = spike.shape[-2], spike.shape[-1]
        # [.., 200, H, W] -> [.., 50, 4, H, W] -> sum over each group -> [.., 50, H, W]
        grouped = spike.reshape(-1, 50, 4, height, width)
        voxel = grouped.sum(dim=2)
        return F.pad(voxel, pad=(20, 20, 20, 20), mode="reflect", value=0)

    def postprocess_img(self, image):
        """Strip the 20 px border added by ``preprocess_spike``."""
        return image[:, :, 20:-20, 20:-20]
@@ -0,0 +1,50 @@
1
+ import torch
2
+ from dataclasses import dataclass, field
3
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
4
+
5
+
6
@dataclass
class SpikeFormerConfig(BaseModelConfig):
    """Default configuration for the SpikeFormer reconstruction model."""

    # identification / loading of the network implementation
    model_name: str = "spikeformer"
    model_file_name: str = "Model.SpikeFormer"
    model_cls_name: str = "SpikeFormer"
    model_win_length: int = 65
    # pretrained weights
    require_params: bool = True
    ckpt_path: str = "weights/spikeformer.pth"
    # constructor kwargs for the network (a fresh dict per config instance)
    model_params: dict = field(
        default_factory=lambda: dict(
            inputDim=65,
            dims=(32, 64, 160, 256),
            heads=(1, 2, 5, 8),
            ff_expansion=(8, 8, 4, 4),
            reduction_ratio=(8, 4, 2, 1),
            num_layers=2,
            decoder_dim=256,
            out_channel=1,
        )
    )
27
+
28
+
29
class SpikeFormer(BaseModel):
    """Spikezoo wrapper for SpikeFormer: pads known resolutions and rescales spikes."""

    def __init__(self, cfg: BaseModelConfig):
        super().__init__(cfg)

    def preprocess_spike(self, spike):
        """Crop the window, pad to the network size and rescale with ``2 * x - 1``.

        For binary 0/1 spikes the rescaling maps them to -1/+1.
        """
        spike = self.crop_spike_length(spike)
        size = self.spike_size
        if size == (250, 400):
            # copy 3 rows from each border along dim 2; postprocess_img removes them
            top, bottom = spike[:, :, :3, :], spike[:, :, -3:, :]
            spike = torch.cat((top, spike, bottom), dim=2)
        elif size == (480, 854):
            # replicate the last 2 columns along dim 3
            spike = torch.cat((spike, spike[:, :, :, -2:]), dim=3)
        return spike * 2 - 1

    def postprocess_img(self, image):
        """Undo the padding from ``preprocess_spike``."""
        size = self.spike_size
        if size == (250, 400):
            return image[:, :, 3:-3, :]
        if size == (480, 854):
            return image[:, :, :, :854]
        return image
@@ -0,0 +1,51 @@
1
+ import torch
2
+ from dataclasses import dataclass, field
3
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
4
+
5
+
6
@dataclass
class Spk2ImgNetConfig(BaseModelConfig):
    """Default configuration for the Spk2ImgNet reconstruction model."""

    # identification / loading of the network implementation
    model_name: str = "spk2imgnet"
    model_file_name: str = "nets"
    model_cls_name: str = "SpikeNet"
    model_win_length: int = 41
    # pretrained weights
    require_params: bool = True
    ckpt_path: str = "weights/spk2imgnet.pth"
    # when True, Spk2ImgNet.postprocess_img divides the output by 0.6 and clamps to [0, 1]
    light_correction: bool = False

    # constructor kwargs for the network (a fresh dict per config instance)
    model_params: dict = field(
        default_factory=lambda: dict(
            in_channels=13,
            features=64,
            out_channels=1,
            win_r=6,
            win_step=7,
        )
    )
27
+
28
+
29
class Spk2ImgNet(BaseModel):
    """Spikezoo wrapper for Spk2ImgNet with optional brightness correction."""

    def __init__(self, cfg: BaseModelConfig):
        super().__init__(cfg)

    def preprocess_spike(self, spike):
        """Crop the window and replicate 2 trailing rows/columns for known sizes."""
        spike = self.crop_spike_length(spike)
        size = self.spike_size
        if size == (250, 400):
            return torch.cat((spike, spike[:, :, -2:]), dim=2)
        if size == (480, 854):
            return torch.cat((spike, spike[:, :, :, -2:]), dim=3)
        return spike

    def postprocess_img(self, image):
        """Remove the padding and optionally apply light correction."""
        size = self.spike_size
        if size == (250, 400):
            image = image[:, :, :250, :]
        elif size == (480, 854):
            image = image[:, :, :, :854]
        # brightness correction, used on the REDS_small dataset
        if self.cfg.light_correction == True:
            image = (image / 0.6).clamp(0, 1)
        return image
@@ -0,0 +1,22 @@
1
+ from dataclasses import dataclass
2
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
3
+
4
+
5
@dataclass
class SSIRConfig(BaseModelConfig):
    """Default configuration for the SSIR reconstruction model."""

    # identification / loading of the network implementation
    model_name: str = "ssir"
    model_file_name: str = "models.networks"
    model_cls_name: str = "SSIR"
    model_win_length: int = 41
    # pretrained weights
    require_params: bool = True
    ckpt_path: str = "weights/ssir.pth"
14
+
15
+
16
class SSIR(BaseModel):
    """Spikezoo wrapper for SSIR."""

    def __init__(self, cfg: BaseModelConfig):
        super().__init__(cfg)

    def postprocess_img(self, image):
        """Identity: the network output is passed through unchanged."""
        return image
@@ -0,0 +1,18 @@
1
+ from dataclasses import dataclass
2
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
3
+
4
+
5
@dataclass
class SSMLConfig(BaseModelConfig):
    """Default configuration for the SSML reconstruction model."""

    # identification / loading of the network implementation
    model_name: str = "ssml"
    model_file_name: str = "model"
    model_cls_name: str = "DoubleNet"
    model_win_length: int = 41
    # pretrained weights
    require_params: bool = True
    ckpt_path: str = "weights/ssml.pt"
14
+
15
+
16
class SSML(BaseModel):
    """Spikezoo wrapper for SSML; adds no overrides beyond the constructor."""

    def __init__(self, cfg: BaseModelConfig):
        super().__init__(cfg)
@@ -0,0 +1,37 @@
1
+ import torch
2
+ from dataclasses import dataclass
3
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
4
+
5
+
6
@dataclass
class STIRConfig(BaseModelConfig):
    """Default configuration for the STIR reconstruction model."""

    # default params for STIR
    model_name: str = "stir"
    model_file_name: str = "models.networks_STIR"
    model_cls_name: str = "STIR"
    model_win_length: int = 61
    require_params: bool = True
    ckpt_path: str = "weights/stir.pth"
15
+
16
+
17
class STIR(BaseModel):
    """Spikezoo wrapper for STIR: pads known resolutions and crops the output back."""

    def __init__(self, cfg: BaseModelConfig):
        super().__init__(cfg)

    def preprocess_spike(self, spike):
        """Crop the window, then replicate trailing rows (250x400) or columns (480x854)."""
        spike = self.crop_spike_length(spike)
        size = self.spike_size
        if size == (250, 400):
            return torch.cat((spike, spike[:, :, -6:]), dim=2)
        if size == (480, 854):
            return torch.cat((spike, spike[:, :, :, -10:]), dim=3)
        return spike

    def postprocess_img(self, image):
        """Remove the padding from ``preprocess_spike``.

        NOTE(review): the network's full output reportedly is
        (recon, Fs_lv_0..Fs_lv_4, Est); here ``image`` is treated as a single tensor.
        """
        size = self.spike_size
        if size == (250, 400):
            return image[:, :, :250, :]
        if size == (480, 854):
            return image[:, :, :, :854]
        return image
@@ -0,0 +1,18 @@
1
+ from dataclasses import dataclass, field
2
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
3
+
4
+
5
@dataclass
class TFIConfig(BaseModelConfig):
    """Default configuration for the training-free TFI method (no checkpoint)."""

    # identification / loading of the implementation
    model_name: str = "tfi"
    model_file_name: str = "nets"
    model_cls_name: str = "TFIModel"
    model_win_length: int = 41
    require_params: bool = False
    # NOTE(review): duplicates model_win_length; overriding one does not update the other.
    model_params: dict = field(default_factory=lambda: dict(model_win_length=41))
14
+
15
+
16
class TFI(BaseModel):
    """Spikezoo wrapper for TFI; adds no overrides beyond the constructor."""

    def __init__(self, cfg: BaseModelConfig):
        super().__init__(cfg)
@@ -0,0 +1,18 @@
1
+ from dataclasses import dataclass,field
2
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
3
+
4
+
5
@dataclass
class TFPConfig(BaseModelConfig):
    """Default configuration for the training-free TFP method (no checkpoint)."""

    # identification / loading of the implementation
    model_name: str = "tfp"
    model_file_name: str = "nets"
    model_cls_name: str = "TFPModel"
    model_win_length: int = 41
    require_params: bool = False
    # NOTE(review): duplicates model_win_length; overriding one does not update the other.
    model_params: dict = field(default_factory=lambda: dict(model_win_length=41))
14
+
15
+
16
class TFP(BaseModel):
    """Spikezoo wrapper for TFP; adds no overrides beyond the constructor."""

    def __init__(self, cfg: BaseModelConfig):
        super().__init__(cfg)
@@ -0,0 +1,31 @@
1
+ from dataclasses import dataclass, field
2
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
3
+ from typing import List
4
+
5
+
6
@dataclass
class WGSEConfig(BaseModelConfig):
    """Default configuration for the WGSE reconstruction model."""

    # identification / loading of the network implementation
    model_name: str = "wgse"
    model_file_name: str = "dwtnets"
    model_cls_name: str = "Dwt1dResnetX_TCN"
    model_win_length: int = 41
    # pretrained weights
    require_params: bool = True
    ckpt_path: str = "weights/wgse.pt"
    # constructor kwargs for the network (a fresh dict per config instance)
    # NOTE(review): yl_size is a str while yh_size holds ints; confirm the network expects this.
    model_params: dict = field(
        default_factory=lambda: dict(
            wvlname="db8",
            J=5,
            yl_size="15",
            yh_size=[28, 21, 18, 16, 15],
            num_residual_blocks=3,
            norm=None,
            ks=3,
            store_features=True,
        )
    )
27
+
28
+
29
class WGSE(BaseModel):
    """Spikezoo wrapper for WGSE; adds no overrides beyond the constructor."""

    def __init__(self, cfg: BaseModelConfig):
        super().__init__(cfg)
@@ -0,0 +1,4 @@
1
+ from .base_pipeline import Pipeline, PipelineConfig
2
+ from .ensemble_pipeline import EnsemblePipelineConfig, EnsemblePipeline
3
+ from .train_pipeline import TrainPipelineConfig, TrainPipeline
4
+
@@ -0,0 +1,267 @@
1
+ import torch
2
+ from dataclasses import dataclass, field
3
+ import os
4
+ from spikezoo.utils.img_utils import tensor2npy, AverageMeter
5
+ from spikezoo.utils.spike_utils import load_vidar_dat
6
+ from spikezoo.metrics import cal_metric_pair, cal_metric_single
7
+ import numpy as np
8
+ import cv2
9
+ from pathlib import Path
10
+ from enum import Enum, auto
11
+ from typing import Literal
12
+ from spikezoo.metrics import metric_pair_names, metric_single_names, metric_all_names
13
+ from thop import profile
14
+ import time
15
+ from datetime import datetime
16
+ from spikezoo.utils import setup_logging, save_config
17
+ from tqdm import tqdm
18
+ from spikezoo.models import build_model_cfg, build_model_name, BaseModel, BaseModelConfig
19
+ from spikezoo.datasets import build_dataset_cfg, build_dataset_name, BaseDataset, BaseDatasetConfig, build_dataloader
20
+ from typing import Optional, Union, List
21
+
22
+
23
@dataclass
class PipelineConfig:
    """Configuration shared by the spike-to-image pipelines.

    Controls metric evaluation, image saving and normalization, the result
    folder layout and the pipeline mode.
    """

    # Evaluate metrics or not.
    save_metric: bool = True
    # Save recovered images or not.
    save_img: bool = True
    # Min-max normalize recovered images or not
    # (TFP, TFI, SpikeFormer and SpikeCLIP outputs are always normalized by the pipeline).
    save_img_norm: bool = False
    # Min-max normalize the ground-truth image or not.
    gt_img_norm: bool = False
    # Root folder for the run results; empty means the package's default "results" folder.
    save_folder: str = ""
    # Experiment name used as the run sub-folder; empty means a timestamp is used.
    exp_name: str = ""
    # Metric names for evaluation.
    metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim"])
    # Pipeline mode.
    _mode: Literal["single_mode", "multi_mode", "train_mode"] = "single_mode"
41
+
42
+
43
class Pipeline:
    """Spike-to-image inference and evaluation pipeline for a single model.

    On construction it builds the model and dataset from their configs (or
    registered names) and creates the result folder
    ``<save_folder>/<train|detect>/<exp_name or timestamp>``. It then sets up the
    ``result.log`` logger and writes the pipeline/model/dataset configs into that
    folder. Methods marked "To be overridden" are the hooks intended for
    multi-model subclasses.
    """

    def __init__(
        self,
        cfg: PipelineConfig,
        model_cfg: Union[str, BaseModelConfig],
        dataset_cfg: Union[str, BaseDatasetConfig],
    ):
        self.cfg = cfg
        self._setup_model_data(model_cfg, dataset_cfg)
        self._setup_pipeline()

    def _setup_model_data(self, model_cfg, dataset_cfg):
        """Model and Data setup.

        ``model_cfg`` / ``dataset_cfg`` are either a registered name (str) or a
        config object. Note: gradients are disabled globally here.
        """
        # model
        self.model: BaseModel = build_model_name(model_cfg) if isinstance(model_cfg, str) else build_model_cfg(model_cfg)
        self.model = self.model.eval()
        torch.set_grad_enabled(False)
        # dataset
        self.dataset: BaseDataset = build_dataset_name(dataset_cfg) if isinstance(dataset_cfg, str) else build_dataset_cfg(dataset_cfg)
        self.dataloader = build_dataloader(self.dataset)
        # device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def _setup_pipeline(self):
        """Pipeline setup: result folder, logger and saved configs."""
        # save folder: <root>/<train|detect>/<exp_name or timestamp>, always a Path
        self.thistime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")[:23]
        if len(self.cfg.save_folder) == 0:
            root_folder = Path(__file__).parent.parent / "results"
        else:
            root_folder = Path(self.cfg.save_folder)
        mode_name = "train" if self.cfg._mode == "train_mode" else "detect"
        run_name = self.thistime if len(self.cfg.exp_name) == 0 else self.cfg.exp_name
        self.save_folder = root_folder / mode_name / run_name
        save_folder = self.save_folder
        os.makedirs(str(save_folder), exist_ok=True)
        # logger result
        self.logger = setup_logging(save_folder / Path("result.log"))
        self.logger.info(f"Info logs are saved on the {save_folder}/result.log")
        # pipeline config
        save_config(self.cfg, save_folder / Path("cfg_pipeline.log"))
        # model config
        if self.cfg._mode == "single_mode":
            save_config(self.model.cfg, save_folder / Path("cfg_model.log"))
        elif self.cfg._mode == "multi_mode":
            # NOTE(review): self.model_list is not set in this class; presumably the
            # multi-model subclass defines it before this runs — confirm.
            for model in self.model_list:
                save_config(model.cfg, save_folder / Path("cfg_model.log"), mode="a")
        # dataset config
        save_config(self.dataset.cfg, save_folder / Path("cfg_dataset.log"))

    def spk2img_from_dataset(self, idx=0):
        """Func---Save the recovered image and calculate the metric from the given dataset sample."""
        # save folder
        save_folder = self.save_folder / Path(f"spk2img_from_dataset/{self.dataset.cfg.dataset_name}_dataset/{self.dataset.cfg.split}/{idx:06d}")
        os.makedirs(str(save_folder), exist_ok=True)

        # data process: add the batch dimension
        batch = self.dataset[idx]
        spike, img = batch["spike"], batch["img"]
        spike = spike[None].to(self.device)
        img = img[None].to(self.device) if self.dataset.cfg.with_img else None
        return self._spk2img(spike, img, save_folder)

    def spk2img_from_file(self, file_path, height, width, img_path=None, remove_head=False):
        """Func---Save the recovered image and calculate the metric from the given input file.

        Supports ``.dat`` spike files (read with ``height``/``width``/``remove_head``)
        and UHSR ``.npz`` files (key ``"spk"``, center-cropped by 13 px).

        Raises:
            RuntimeError: unsupported spike file extension.
            FileNotFoundError: ``img_path`` was given but could not be read.
        """
        file_path = str(file_path)
        # load spike from .dat
        if file_path.endswith(".dat"):
            spike = load_vidar_dat(file_path, height, width, remove_head)
        # load spike from .npz from UHSR; the context manager closes the NpzFile handle
        elif file_path.endswith(".npz"):
            with np.load(file_path) as data:
                spike = data["spk"].astype(np.float32)[:, 13:237, 13:237]
        else:
            raise RuntimeError("Not recognized spike input file.")
        # load img from .png/.jpg image file
        if img_path is not None:
            img = cv2.imread(str(img_path))
            if img is None:
                raise FileNotFoundError(f"Could not read image file: {img_path}")
            if img.ndim == 3:
                img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            img = (img / 255).astype(np.float32)
            img = torch.from_numpy(img)[None, None].to(self.device)
        else:
            img = None
        spike = torch.from_numpy(spike)[None].to(self.device)

        # save folder (created only once the input is known to be valid)
        save_folder = self.save_folder / Path(f"spk2img_from_file/{os.path.basename(file_path)}")
        os.makedirs(str(save_folder), exist_ok=True)
        return self._spk2img(spike, img, save_folder)

    def spk2img_from_spk(self, spike, img=None):
        """Func---Save the recovered image and calculate the metric from the given spike stream.

        ``spike`` may be a numpy array or tensor, ``[c,h,w]`` or ``[bs,c,h,w]``.
        ``img`` (optional) must be a numpy array; 3-dim arrays are treated as BGR,
        and values are divided by 255.
        """
        # save folder
        save_folder = self.save_folder / Path(f"spk2img_from_spk/{self.thistime}")
        os.makedirs(str(save_folder), exist_ok=True)

        # spike process
        if isinstance(spike, np.ndarray):
            spike = torch.from_numpy(spike)
        spike = spike.to(self.device)
        # [c,h,w] -> [1,c,h,w]
        if spike.dim() == 3:
            spike = spike[None]
        spike = spike.float()
        # img process
        if img is not None:
            if isinstance(img, np.ndarray):
                img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if img.ndim == 3 else img
                img = (img / 255).astype(np.float32)
                img = torch.from_numpy(img)[None, None].to(self.device)
            else:
                raise RuntimeError("Not recognized image input type.")
        return self._spk2img(spike, img, save_folder)

    def save_imgs_from_dataset(self):
        """Func---Save all images from the given dataset."""
        for idx in range(len(self.dataset)):
            self.spk2img_from_dataset(idx=idx)

    # TODO: To be overridden
    def cal_params(self):
        """Func---Calculate the parameters/flops/latency of the given method."""
        self._cal_prams_model(self.model)

    # TODO: To be overridden
    def cal_metrics(self):
        """Func---Calculate the metric of the given method."""
        self._cal_metrics_model(self.model)

    # TODO: To be overridden
    def _spk2img(self, spike, img, save_folder):
        """Spike-to-image: spike:[bs,c,h,w] (0-1), img:[bs,1,h,w] (0-1)"""
        return self._spk2img_model(self.model, spike, img, save_folder)

    def _spk2img_model(self, model, spike, img, save_folder):
        """Run ``model`` on ``spike``, optionally log metrics and save images.

        Returns:
            The raw reconstruction returned by ``model`` (before normalization).
        """
        # spike2image conversion
        model_name = model.cfg.model_name
        recon_img = model(spike)
        # keep an untouched copy to return; the names below get normalized for metrics/saving
        recon_img_copy = recon_img.clone()
        # normalization
        recon_img, img = self._post_process_img(model_name, recon_img, img)
        # metric
        if self.cfg.save_metric:
            self.logger.info(f"----------------------Method: {model_name.upper()}----------------------")
            # NOTE(review): iterates every known metric (metric_all_names), not cfg.metric_names.
            for metric_name in metric_all_names:
                if img is not None and metric_name in metric_pair_names:
                    self.logger.info(f"{metric_name.upper()}: {cal_metric_pair(recon_img,img,metric_name)}")
                elif metric_name in metric_single_names:
                    self.logger.info(f"{metric_name.upper()}: {cal_metric_single(recon_img,metric_name)}")
                else:
                    self.logger.info(f"{metric_name.upper()} not calculated since no ground truth provided.")

        # visual
        if self.cfg.save_img:
            recon_img = tensor2npy(recon_img[0, 0])
            cv2.imwrite(f"{save_folder}/{model.cfg.model_name}.png", recon_img)
            if img is not None:
                img = tensor2npy(img[0, 0])
                cv2.imwrite(f"{save_folder}/sharp_img.png", img)
            self.logger.info(f"Images are saved on the {save_folder}")

        return recon_img_copy

    @staticmethod
    def _minmax_norm(x):
        """Min-max normalize ``x`` (over all elements) to [0, 1].

        A constant (or NaN-range) tensor maps to zeros instead of the NaNs that
        ``0 / 0`` would produce.
        """
        x_min = x.min()
        value_range = x.max() - x_min
        if value_range > 0:
            return (x - x_min) / value_range
        return torch.zeros_like(x)

    def _post_process_img(self, model_name, recon_img, gt_img):
        """Post process the reconstructed image (and optionally the ground truth).

        TFP, TFI, SpikeFormer and SpikeCLIP outputs are always min-max normalized.
        Other models are normalized only when ``cfg.save_img_norm`` is set. The
        reconstruction is then clipped to [0, 1]. ``gt_img`` is min-max normalized
        when ``cfg.gt_img_norm`` is set and it is not None.
        """
        if model_name in ["tfp", "tfi", "spikeformer", "spikeclip"] or self.cfg.save_img_norm:
            recon_img = self._minmax_norm(recon_img)
        recon_img = recon_img.clip(0, 1)
        if self.cfg.gt_img_norm and gt_img is not None:
            gt_img = self._minmax_norm(gt_img)
        return recon_img, gt_img

    def _cal_metrics_model(self, model: BaseModel):
        """Calculate and log the dataset-averaged metrics for the given model.

        Paired metrics are only evaluated when the dataset provides images.
        """
        # metrics construct
        model_name = model.cfg.model_name
        metrics_dict = {}
        for metric_name in self.cfg.metric_names:
            if self.dataset.cfg.with_img or (metric_name in metric_single_names):
                metrics_dict[metric_name] = AverageMeter()

        # metrics calculate
        for batch in tqdm(self.dataloader):
            batch = model.feed_to_device(batch)
            outputs = model.get_outputs_dict(batch)
            recon_img, img = model.get_paired_imgs(batch, outputs)
            recon_img, img = self._post_process_img(model_name, recon_img, img)
            for metric_name in metrics_dict.keys():
                if metric_name in metric_pair_names:
                    metrics_dict[metric_name].update(cal_metric_pair(recon_img, img, metric_name))
                elif metric_name in metric_single_names:
                    metrics_dict[metric_name].update(cal_metric_single(recon_img, metric_name))

        # metrics logging
        self.logger.info(f"----------------------Method: {model_name.upper()}----------------------")
        for metric_name in metrics_dict.keys():
            self.logger.info(f"{metric_name.upper()}: {metrics_dict[metric_name].avg}")

    def _cal_prams_model(self, model):
        """Log parameter count, FLOPs (via thop) and mean forward latency of ``model``.

        Uses an all-zero dummy spike input of shape (1, 200, 250, 400) on
        ``self.device``, so it also runs on CPU-only machines.
        """
        network = model.net
        model_name = model.cfg.model_name.upper()
        # params
        params = sum(p.numel() for p in network.parameters())
        # latency
        spike = torch.zeros((1, 200, 250, 400), device=self.device)
        use_cuda = spike.is_cuda
        n_runs = 100
        model(spike)  # warm-up so one-off initialisation is not timed
        # CUDA kernels run asynchronously: synchronize so the timer measures real work
        if use_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(n_runs):
            model(spike)
        if use_cuda:
            torch.cuda.synchronize()
        latency = (time.perf_counter() - start_time) / n_runs
        # flop # todo thop bug for BSF
        flops, _ = profile(model, inputs=(spike,))
        re_msg = ", ".join(
            [
                "Total params: %.4fM" % (params / 1e6),
                "FLOPs:" + str(flops / 1e9) + "G",
                "Latency: {:.6f} seconds".format(latency),
            ]
        )
        self.logger.info(f"----------------------Method: {model_name}----------------------")
        self.logger.info(re_msg)