spikezoo 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- spikezoo/__init__.py +23 -7
- spikezoo/archs/bsf/models/bsf/bsf.py +37 -25
- spikezoo/archs/bsf/models/bsf/rep.py +2 -2
- spikezoo/archs/spk2imgnet/nets.py +1 -1
- spikezoo/archs/ssir/models/networks.py +1 -1
- spikezoo/archs/ssml/model.py +9 -5
- spikezoo/archs/stir/metrics/losses.py +1 -1
- spikezoo/archs/stir/models/networks_STIR.py +16 -9
- spikezoo/archs/tfi/nets.py +1 -1
- spikezoo/archs/tfp/nets.py +1 -1
- spikezoo/archs/wgse/dwtnets.py +6 -6
- spikezoo/datasets/__init__.py +11 -9
- spikezoo/datasets/base_dataset.py +10 -3
- spikezoo/datasets/realworld_dataset.py +1 -3
- spikezoo/datasets/{reds_small_dataset.py → reds_base_dataset.py} +9 -8
- spikezoo/datasets/reds_ssir_dataset.py +181 -0
- spikezoo/datasets/szdata_dataset.py +5 -15
- spikezoo/datasets/uhsr_dataset.py +4 -3
- spikezoo/models/__init__.py +8 -6
- spikezoo/models/base_model.py +120 -64
- spikezoo/models/bsf_model.py +11 -3
- spikezoo/models/spcsnet_model.py +19 -0
- spikezoo/models/spikeclip_model.py +4 -3
- spikezoo/models/spk2imgnet_model.py +9 -15
- spikezoo/models/ssir_model.py +4 -6
- spikezoo/models/ssml_model.py +44 -2
- spikezoo/models/stir_model.py +26 -5
- spikezoo/models/tfi_model.py +3 -1
- spikezoo/models/tfp_model.py +4 -2
- spikezoo/models/wgse_model.py +8 -14
- spikezoo/pipeline/base_pipeline.py +79 -55
- spikezoo/pipeline/ensemble_pipeline.py +10 -9
- spikezoo/pipeline/train_cfgs.py +89 -0
- spikezoo/pipeline/train_pipeline.py +129 -30
- spikezoo/utils/optimizer_utils.py +22 -0
- spikezoo/utils/other_utils.py +31 -6
- spikezoo/utils/scheduler_utils.py +25 -0
- spikezoo/utils/spike_utils.py +61 -29
- spikezoo-0.2.3.dist-info/METADATA +263 -0
- {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.dist-info}/RECORD +43 -80
- spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
- spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/CheckPoints/readme +0 -1
- spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +0 -60
- spikezoo/archs/spikeformer/DataProcess/DataLoader.py +0 -115
- spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +0 -39
- spikezoo/archs/spikeformer/EvalResults/readme +0 -1
- spikezoo/archs/spikeformer/LICENSE +0 -21
- spikezoo/archs/spikeformer/Metrics/Metrics.py +0 -50
- spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/Loss.py +0 -89
- spikezoo/archs/spikeformer/Model/SpikeFormer.py +0 -230
- spikezoo/archs/spikeformer/Model/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/README.md +0 -30
- spikezoo/archs/spikeformer/evaluate.py +0 -87
- spikezoo/archs/spikeformer/recon_real_data.py +0 -97
- spikezoo/archs/spikeformer/requirements.yml +0 -95
- spikezoo/archs/spikeformer/train.py +0 -173
- spikezoo/archs/spikeformer/utils.py +0 -22
- spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
- spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/models/spikeformer_model.py +0 -50
- spikezoo-0.2.2.dist-info/METADATA +0 -196
- {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.dist-info}/LICENSE.txt +0 -0
- {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.dist-info}/WHEEL +0 -0
- {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.dist-info}/top_level.txt +0 -0
spikezoo/models/wgse_model.py
CHANGED
@@ -1,6 +1,12 @@
|
|
1
1
|
from dataclasses import dataclass, field
|
2
2
|
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
3
3
|
from typing import List
|
4
|
+
from spikezoo.pipeline import TrainPipelineConfig
|
5
|
+
import torch.nn as nn
|
6
|
+
import torch.optim as optim
|
7
|
+
import torch.optim.lr_scheduler as lr_scheduler
|
8
|
+
from typing import List
|
9
|
+
from spikezoo.archs.wgse.dwtnets import Dwt1dResnetX_TCN
|
4
10
|
|
5
11
|
|
6
12
|
@dataclass
|
@@ -9,21 +15,9 @@ class WGSEConfig(BaseModelConfig):
|
|
9
15
|
model_name: str = "wgse"
|
10
16
|
model_file_name: str = "dwtnets"
|
11
17
|
model_cls_name: str = "Dwt1dResnetX_TCN"
|
12
|
-
|
18
|
+
model_length: int = 41
|
19
|
+
model_length_dict: dict = field(default_factory=lambda: {"v010": 41, "v023": 41})
|
13
20
|
require_params: bool = True
|
14
|
-
ckpt_path: str = "weights/wgse.pt"
|
15
|
-
model_params: dict = field(
|
16
|
-
default_factory=lambda: {
|
17
|
-
"wvlname": "db8",
|
18
|
-
"J": 5,
|
19
|
-
"yl_size": "15",
|
20
|
-
"yh_size": [28, 21, 18, 16, 15],
|
21
|
-
"num_residual_blocks": 3,
|
22
|
-
"norm": None,
|
23
|
-
"ks": 3,
|
24
|
-
"store_features": True,
|
25
|
-
}
|
26
|
-
)
|
27
21
|
|
28
22
|
|
29
23
|
class WGSE(BaseModel):
|
@@ -18,24 +18,27 @@ from tqdm import tqdm
|
|
18
18
|
from spikezoo.models import build_model_cfg, build_model_name, BaseModel, BaseModelConfig
|
19
19
|
from spikezoo.datasets import build_dataset_cfg, build_dataset_name, BaseDataset, BaseDatasetConfig, build_dataloader
|
20
20
|
from typing import Optional, Union, List
|
21
|
+
import shutil
|
22
|
+
from spikingjelly.clock_driven import functional
|
23
|
+
import spikezoo as sz
|
21
24
|
|
22
25
|
|
23
26
|
@dataclass
|
24
27
|
class PipelineConfig:
|
25
|
-
"
|
26
|
-
|
27
|
-
"Save recoverd images or not."
|
28
|
-
save_img: bool = True
|
29
|
-
"Normalizing recoverd images or not."
|
30
|
-
save_img_norm: bool = False
|
31
|
-
"Normalizing gt or not."
|
32
|
-
gt_img_norm: bool = False
|
28
|
+
"Loading weights from local or version on the url."
|
29
|
+
version: Literal["local", "v010", "v023"] = "local"
|
33
30
|
"Save folder for the code running result."
|
34
31
|
save_folder: str = ""
|
35
32
|
"Saved experiment name."
|
36
33
|
exp_name: str = ""
|
34
|
+
"Evaluate metrics or not."
|
35
|
+
save_metric: bool = True
|
37
36
|
"Metric names for evaluation."
|
38
37
|
metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim"])
|
38
|
+
"Save recoverd images or not."
|
39
|
+
save_img: bool = True
|
40
|
+
"Normalizing recoverd images and gt or not."
|
41
|
+
save_img_norm: bool = False
|
39
42
|
"Different modes for the pipeline."
|
40
43
|
_mode: Literal["single_mode", "multi_mode", "train_mode"] = "single_mode"
|
41
44
|
|
@@ -44,8 +47,8 @@ class Pipeline:
|
|
44
47
|
def __init__(
|
45
48
|
self,
|
46
49
|
cfg: PipelineConfig,
|
47
|
-
model_cfg: Union[
|
48
|
-
dataset_cfg: Union[
|
50
|
+
model_cfg: Union[sz.METHOD, BaseModelConfig],
|
51
|
+
dataset_cfg: Union[sz.DATASET, BaseDatasetConfig],
|
49
52
|
):
|
50
53
|
self.cfg = cfg
|
51
54
|
self._setup_model_data(model_cfg, dataset_cfg)
|
@@ -53,9 +56,9 @@ class Pipeline:
|
|
53
56
|
|
54
57
|
def _setup_model_data(self, model_cfg, dataset_cfg):
|
55
58
|
"""Model and Data setup."""
|
56
|
-
# model
|
59
|
+
# model [1] build the model. [2] build the network.
|
57
60
|
self.model: BaseModel = build_model_name(model_cfg) if isinstance(model_cfg, str) else build_model_cfg(model_cfg)
|
58
|
-
self.model =
|
61
|
+
self.model.build_network(mode="eval", version=self.cfg.version)
|
59
62
|
torch.set_grad_enabled(False)
|
60
63
|
# dataset
|
61
64
|
self.dataset: BaseDataset = build_dataset_name(dataset_cfg) if isinstance(dataset_cfg, str) else build_dataset_cfg(dataset_cfg)
|
@@ -74,15 +77,19 @@ class Pipeline:
|
|
74
77
|
if len(self.cfg.exp_name) == 0
|
75
78
|
else self.save_folder / Path(f"{mode_name}/{self.cfg.exp_name}")
|
76
79
|
)
|
77
|
-
|
78
|
-
|
80
|
+
# remove and establish folder
|
81
|
+
save_folder = str(self.save_folder)
|
82
|
+
if os.path.exists(save_folder):
|
83
|
+
shutil.rmtree(save_folder)
|
84
|
+
os.makedirs(save_folder)
|
85
|
+
save_folder = Path(save_folder)
|
79
86
|
# logger result
|
80
87
|
self.logger = setup_logging(save_folder / Path("result.log"))
|
81
88
|
self.logger.info(f"Info logs are saved on the {save_folder}/result.log")
|
82
89
|
# pipeline config
|
83
90
|
save_config(self.cfg, save_folder / Path("cfg_pipeline.log"))
|
84
91
|
# model config
|
85
|
-
if self.cfg._mode
|
92
|
+
if self.cfg._mode in ["single_mode", "train_mode"]:
|
86
93
|
save_config(self.model.cfg, save_folder / Path("cfg_model.log"))
|
87
94
|
elif self.cfg._mode == "multi_mode":
|
88
95
|
for model in self.model_list:
|
@@ -90,28 +97,29 @@ class Pipeline:
|
|
90
97
|
# dataset config
|
91
98
|
save_config(self.dataset.cfg, save_folder / Path("cfg_dataset.log"))
|
92
99
|
|
93
|
-
def
|
94
|
-
"""
|
100
|
+
def infer_from_dataset(self, idx=0):
|
101
|
+
"""Function I---Save the recoverd image and calculate the metric from the given dataset."""
|
95
102
|
# save folder
|
96
|
-
self.logger.info("***********************
|
97
|
-
save_folder = self.save_folder / Path(f"
|
103
|
+
self.logger.info("*********************** infer_from_dataset ***********************")
|
104
|
+
save_folder = self.save_folder / Path(f"infer_from_dataset/{self.dataset.cfg.dataset_name}_dataset/{self.dataset.cfg.split}/{idx:06d}")
|
98
105
|
os.makedirs(str(save_folder), exist_ok=True)
|
99
106
|
|
100
107
|
# data process
|
108
|
+
# todo
|
101
109
|
batch = self.dataset[idx]
|
102
|
-
spike, img = batch["spike"], batch["
|
110
|
+
spike, img, rate = batch["spike"], batch["gt_img"], batch["rate"]
|
103
111
|
spike = spike[None].to(self.device)
|
104
112
|
if self.dataset.cfg.with_img == True:
|
105
113
|
img = img[None].to(self.device)
|
106
114
|
else:
|
107
115
|
img = None
|
108
|
-
return self.
|
116
|
+
return self.infer(spike, img, save_folder, rate)
|
109
117
|
|
110
|
-
def
|
111
|
-
"""
|
118
|
+
def infer_from_file(self, file_path, height=-1, width=-1, img_path=None, rate=1, remove_head=False):
|
119
|
+
"""Function II---Save the recoverd image and calculate the metric from the given input file."""
|
112
120
|
# save folder
|
113
|
-
self.logger.info("***********************
|
114
|
-
save_folder = self.save_folder / Path(f"
|
121
|
+
self.logger.info("*********************** infer_from_file ***********************")
|
122
|
+
save_folder = self.save_folder / Path(f"infer_from_file/{os.path.basename(file_path)}")
|
115
123
|
os.makedirs(str(save_folder), exist_ok=True)
|
116
124
|
|
117
125
|
# load spike from .dat
|
@@ -132,13 +140,13 @@ class Pipeline:
|
|
132
140
|
else:
|
133
141
|
img = img_path
|
134
142
|
spike = torch.from_numpy(spike)[None].to(self.device)
|
135
|
-
return self.
|
143
|
+
return self.infer(spike, img, save_folder, rate)
|
136
144
|
|
137
|
-
def
|
138
|
-
"""
|
145
|
+
def infer_from_spk(self, spike, img=None, rate=1):
|
146
|
+
"""Function III---Save the recoverd image and calculate the metric from the given spike stream."""
|
139
147
|
# save folder
|
140
|
-
self.logger.info("***********************
|
141
|
-
save_folder = self.save_folder / Path(f"
|
148
|
+
self.logger.info("*********************** infer_from_spk ***********************")
|
149
|
+
save_folder = self.save_folder / Path(f"infer_from_spk")
|
142
150
|
os.makedirs(str(save_folder), exist_ok=True)
|
143
151
|
|
144
152
|
# spike process
|
@@ -157,36 +165,40 @@ class Pipeline:
|
|
157
165
|
img = torch.from_numpy(img)[None, None].to(self.device)
|
158
166
|
else:
|
159
167
|
raise RuntimeError("Not recognized image input type.")
|
160
|
-
return self.
|
168
|
+
return self.infer(spike, img, save_folder, rate)
|
169
|
+
|
170
|
+
# TODO: To be overridden
|
171
|
+
def infer(self, spike, img, save_folder, rate):
|
172
|
+
"""Function IV---Spike-to-image conversion interface, input data format: spike [bs,c,h,w] (0-1), img [bs,1,h,w] (0-1)"""
|
173
|
+
return self._infer_model(self.model, spike, img, save_folder, rate)
|
161
174
|
|
162
175
|
def save_imgs_from_dataset(self):
|
163
|
-
"""
|
176
|
+
"""Function V---Save all images from the given dataset."""
|
177
|
+
base_setting = self.cfg.save_metric
|
178
|
+
self.cfg.save_metric = False
|
164
179
|
for idx in range(len(self.dataset)):
|
165
|
-
self.
|
166
|
-
|
180
|
+
self.infer_from_dataset(idx=idx)
|
181
|
+
self.cfg.save_metric = base_setting
|
182
|
+
|
167
183
|
# TODO: To be overridden
|
168
184
|
def cal_params(self):
|
169
|
-
"""
|
185
|
+
"""Function VI---Calculate the parameters/flops/latency of the given method."""
|
170
186
|
self._cal_prams_model(self.model)
|
171
187
|
|
172
188
|
# TODO: To be overridden
|
173
189
|
def cal_metrics(self):
|
174
|
-
"""
|
175
|
-
self._cal_metrics_model(self.model)
|
176
|
-
|
177
|
-
# TODO: To be overridden
|
178
|
-
def _spk2img(self, spike, img, save_folder):
|
179
|
-
"""Spike-to-image: spike:[bs,c,h,w] (0-1), img:[bs,1,h,w] (0-1)"""
|
180
|
-
return self._spk2img_model(self.model, spike, img, save_folder)
|
190
|
+
"""Function VII---Calculate the metric of the given method."""
|
191
|
+
return self._cal_metrics_model(self.model)
|
181
192
|
|
182
|
-
def
|
193
|
+
def _infer_model(self, model, spike, img, save_folder, rate):
|
183
194
|
"""Spike-to-image from the given model."""
|
184
195
|
# spike2image conversion
|
185
196
|
model_name = model.cfg.model_name
|
186
|
-
recon_img = model(spike)
|
197
|
+
recon_img = model.spk2img(spike)
|
187
198
|
recon_img_copy = recon_img.clone()
|
188
199
|
# normalization
|
189
|
-
recon_img, img = self._post_process_img(model_name,
|
200
|
+
recon_img, img = self._post_process_img(recon_img, model_name, rate), self._post_process_img(img, None, 1)
|
201
|
+
self._state_reset(model)
|
190
202
|
# metric
|
191
203
|
if self.cfg.save_metric == True:
|
192
204
|
self.logger.info(f"----------------------Method: {model_name.upper()}----------------------")
|
@@ -209,19 +221,24 @@ class Pipeline:
|
|
209
221
|
self.logger.info(f"Images are saved on the {save_folder}")
|
210
222
|
return recon_img_copy
|
211
223
|
|
212
|
-
def _post_process_img(self,
|
224
|
+
def _post_process_img(self, recon_img, model_name, rate=1):
|
213
225
|
"""Post process the reconstructed image."""
|
214
|
-
#
|
215
|
-
if
|
216
|
-
|
217
|
-
|
226
|
+
# With no GT
|
227
|
+
if recon_img == None:
|
228
|
+
return None
|
229
|
+
# TFP, TFI, spikeclip algorithms are normalized automatically, others are normalized based on the self.cfg.use_norm
|
230
|
+
if model_name in ["tfp", "tfi", "spikeclip"] or self.cfg.save_img_norm == True:
|
218
231
|
recon_img = (recon_img - recon_img.min()) / (recon_img.max() - recon_img.min())
|
232
|
+
else:
|
233
|
+
recon_img = recon_img / rate
|
219
234
|
recon_img = recon_img.clip(0, 1)
|
220
|
-
|
221
|
-
return recon_img, gt_img
|
235
|
+
return recon_img
|
222
236
|
|
223
237
|
def _cal_metrics_model(self, model: BaseModel):
|
224
238
|
"""Calculate the metrics for the given model."""
|
239
|
+
# metric state reset (since get_outputs_dict from the training state is utilized)
|
240
|
+
model_state = model.net.training
|
241
|
+
model.net.training = True
|
225
242
|
# metrics construct
|
226
243
|
model_name = model.cfg.model_name
|
227
244
|
metrics_dict = {}
|
@@ -234,17 +251,19 @@ class Pipeline:
|
|
234
251
|
batch = model.feed_to_device(batch)
|
235
252
|
outputs = model.get_outputs_dict(batch)
|
236
253
|
recon_img, img = model.get_paired_imgs(batch, outputs)
|
237
|
-
recon_img, img = self._post_process_img(model_name,
|
254
|
+
recon_img, img = self._post_process_img(recon_img, model_name), self._post_process_img(img, "auto")
|
238
255
|
for metric_name in metrics_dict.keys():
|
239
256
|
if metric_name in metric_pair_names:
|
240
257
|
metrics_dict[metric_name].update(cal_metric_pair(recon_img, img, metric_name))
|
241
258
|
elif metric_name in metric_single_names:
|
242
259
|
metrics_dict[metric_name].update(cal_metric_single(recon_img, metric_name))
|
243
|
-
|
260
|
+
self._state_reset(model)
|
261
|
+
model.net.training = model_state
|
244
262
|
# metrics self.logger.info
|
245
263
|
self.logger.info(f"----------------------Method: {model_name.upper()}----------------------")
|
246
264
|
for metric_name in metrics_dict.keys():
|
247
265
|
self.logger.info(f"{metric_name.upper()}: {metrics_dict[metric_name].avg}")
|
266
|
+
return metrics_dict
|
248
267
|
|
249
268
|
def _cal_prams_model(self, model):
|
250
269
|
"""Calculate the parameters for the given model."""
|
@@ -256,7 +275,7 @@ class Pipeline:
|
|
256
275
|
spike = torch.zeros((1, 200, 250, 400)).cuda()
|
257
276
|
start_time = time.time()
|
258
277
|
for _ in range(100):
|
259
|
-
model(spike)
|
278
|
+
model.spk2img(spike)
|
260
279
|
latency = (time.time() - start_time) / 100
|
261
280
|
# flop # todo thop bug for BSF
|
262
281
|
flops, _ = profile((model), inputs=(spike,))
|
@@ -267,3 +286,8 @@ class Pipeline:
|
|
267
286
|
)
|
268
287
|
self.logger.info(f"----------------------Method: {model_name}----------------------")
|
269
288
|
self.logger.info(re_msg)
|
289
|
+
|
290
|
+
def _state_reset(self, model):
|
291
|
+
"""State reset for the snn-based method."""
|
292
|
+
if model.cfg.model_name == "ssir":
|
293
|
+
functional.reset_net(model.net)
|
@@ -19,6 +19,7 @@ from spikezoo.models import build_model_cfg, build_model_name, BaseModel, BaseMo
|
|
19
19
|
from spikezoo.datasets import build_dataset_cfg, build_dataset_name, BaseDataset, BaseDatasetConfig, build_dataloader
|
20
20
|
from typing import Optional, Union, List
|
21
21
|
from spikezoo.pipeline.base_pipeline import Pipeline, PipelineConfig
|
22
|
+
import spikezoo as sz
|
22
23
|
|
23
24
|
|
24
25
|
@dataclass
|
@@ -30,30 +31,30 @@ class EnsemblePipeline(Pipeline):
|
|
30
31
|
def __init__(
|
31
32
|
self,
|
32
33
|
cfg: PipelineConfig,
|
33
|
-
model_cfg_list: Union[List[
|
34
|
-
dataset_cfg: Union[
|
34
|
+
model_cfg_list: Union[List[sz.METHOD], List[BaseModelConfig]],
|
35
|
+
dataset_cfg: Union[sz.DATASET, BaseDatasetConfig],
|
35
36
|
):
|
36
37
|
self.cfg = cfg
|
37
|
-
self._setup_model_data(model_cfg_list,dataset_cfg)
|
38
|
+
self._setup_model_data(model_cfg_list, dataset_cfg)
|
38
39
|
self._setup_pipeline()
|
39
40
|
|
40
|
-
def _setup_model_data(self,model_cfg_list,dataset_cfg):
|
41
|
+
def _setup_model_data(self, model_cfg_list, dataset_cfg):
|
41
42
|
"""Model and Data setup."""
|
42
43
|
# model
|
43
44
|
self.model_list: List[BaseModel] = (
|
44
|
-
[build_model_name(name) for name in model_cfg_list] if isinstance(model_cfg_list[0],str) else [build_model_cfg(cfg) for cfg in model_cfg_list]
|
45
|
+
[build_model_name(name) for name in model_cfg_list] if isinstance(model_cfg_list[0], str) else [build_model_cfg(cfg) for cfg in model_cfg_list]
|
45
46
|
)
|
46
|
-
self.model_list = [model.eval
|
47
|
+
self.model_list = [model.build_network(mode="eval", version=self.cfg.version) for model in self.model_list]
|
47
48
|
torch.set_grad_enabled(False)
|
48
49
|
# data
|
49
50
|
self.dataset: BaseDataset = build_dataset_name(dataset_cfg) if isinstance(dataset_cfg, str) else build_dataset_cfg(dataset_cfg)
|
50
51
|
self.dataloader = build_dataloader(self.dataset)
|
51
52
|
# device
|
52
53
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
53
|
-
|
54
|
-
def
|
54
|
+
|
55
|
+
def infer(self, spike, img, save_folder, rate):
|
55
56
|
for model in self.model_list:
|
56
|
-
self.
|
57
|
+
self._infer_model(model, spike, img, save_folder, rate)
|
57
58
|
|
58
59
|
def cal_params(self):
|
59
60
|
for model in self.model_list:
|
@@ -0,0 +1,89 @@
|
|
1
|
+
import torch.nn as nn
|
2
|
+
import torch.optim as optimizer
|
3
|
+
import torch.optim.lr_scheduler as lr_scheduler
|
4
|
+
import functools
|
5
|
+
from spikezoo.utils.optimizer_utils import OptimizerConfig, AdamOptimizerConfig
|
6
|
+
from spikezoo.utils.scheduler_utils import SchedulerConfig, MultiStepSchedulerConfig, CosineAnnealingLRConfig
|
7
|
+
from dataclasses import dataclass, field
|
8
|
+
from spikezoo.pipeline.train_pipeline import TrainPipelineConfig
|
9
|
+
from typing import Optional, Dict, List
|
10
|
+
|
11
|
+
|
12
|
+
@dataclass
|
13
|
+
class REDS_BASE_TrainConfig(TrainPipelineConfig):
|
14
|
+
"""Training setting for methods on the REDS-BASE dataset."""
|
15
|
+
|
16
|
+
# parameters setting
|
17
|
+
epochs: int = 600
|
18
|
+
steps_per_save_imgs: int = 200
|
19
|
+
steps_per_save_ckpt: int = 500
|
20
|
+
steps_per_cal_metrics: int = 100
|
21
|
+
metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim","lpips","niqe","brisque","piqe"])
|
22
|
+
|
23
|
+
# dataloader setting
|
24
|
+
bs_train: int = 8
|
25
|
+
num_workers: int = 4
|
26
|
+
pin_memory: bool = False
|
27
|
+
|
28
|
+
# train setting - optimizer & scheduler & loss_dict
|
29
|
+
optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4)
|
30
|
+
scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[400], gamma=0.2) # from wgse
|
31
|
+
loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
|
32
|
+
|
33
|
+
# @dataclass
|
34
|
+
# class REDS_BASE_TrainConfig(TrainPipelineConfig):
|
35
|
+
# """Training setting for methods on the REDS-BASE dataset."""
|
36
|
+
|
37
|
+
# # parameters setting
|
38
|
+
# epochs: int = 700
|
39
|
+
# steps_per_save_imgs: int = 200
|
40
|
+
# steps_per_save_ckpt: int = 500
|
41
|
+
# steps_per_cal_metrics: int = 100
|
42
|
+
# metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim"])
|
43
|
+
|
44
|
+
# # dataloader setting
|
45
|
+
# bs_train: int = 8
|
46
|
+
# num_workers: int = 4
|
47
|
+
# pin_memory: bool = False
|
48
|
+
|
49
|
+
# # train setting - optimizer & scheduler & loss_dict
|
50
|
+
# optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4)
|
51
|
+
# scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[400, 600], gamma=0.2) # from wgse
|
52
|
+
# loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
|
53
|
+
|
54
|
+
|
55
|
+
# ! Train Config for each method on the official setting, not recommended to utilize their default parameters owing to the dataset setting.
|
56
|
+
@dataclass
|
57
|
+
class BSFTrainConfig(TrainPipelineConfig):
|
58
|
+
"""Training setting for BSF. https://github.com/ruizhao26/BSF"""
|
59
|
+
|
60
|
+
optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4, weight_decay=0.0)
|
61
|
+
scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[10, 20, 30, 40, 50, 60, 70, 80, 90, 100])
|
62
|
+
loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
|
63
|
+
|
64
|
+
|
65
|
+
@dataclass
|
66
|
+
class WGSETrainConfig(TrainPipelineConfig):
|
67
|
+
"""Training setting for WGSE. https://github.com/Leozhangjiyuan/WGSE-SpikeCamera"""
|
68
|
+
|
69
|
+
optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4, betas=(0.9, 0.99), weight_decay=0)
|
70
|
+
scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[400, 600], gamma=0.2)
|
71
|
+
loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
|
72
|
+
|
73
|
+
|
74
|
+
@dataclass
|
75
|
+
class STIRTrainConfig(TrainPipelineConfig):
|
76
|
+
"""Training setting for STIR. https://github.com/GitCVfb/STIR"""
|
77
|
+
|
78
|
+
optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4, betas=(0.9, 0.999))
|
79
|
+
scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70], gamma=0.7)
|
80
|
+
loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
|
81
|
+
|
82
|
+
|
83
|
+
@dataclass
|
84
|
+
class Spk2ImgNetTrainConfig(TrainPipelineConfig):
|
85
|
+
"""Training setting for Spk2ImgNet. https://github.com/Vspacer/Spk2ImgNet"""
|
86
|
+
|
87
|
+
optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4)
|
88
|
+
scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[20], gamma=0.1)
|
89
|
+
loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
|
@@ -1,27 +1,64 @@
|
|
1
1
|
import torch
|
2
|
-
from dataclasses import dataclass
|
2
|
+
from dataclasses import dataclass, field
|
3
3
|
import os
|
4
4
|
from spikezoo.utils.img_utils import tensor2npy
|
5
5
|
import cv2
|
6
6
|
from pathlib import Path
|
7
|
-
from typing import Literal
|
7
|
+
from typing import Literal, Dict
|
8
8
|
from tqdm import tqdm
|
9
9
|
from spikezoo.models import build_model_cfg, build_model_name, BaseModel, BaseModelConfig
|
10
10
|
from spikezoo.datasets import build_dataset_cfg, build_dataset_name, BaseDataset, BaseDatasetConfig, build_dataloader
|
11
|
-
from typing import
|
11
|
+
from typing import Union, List, Optional
|
12
12
|
from spikezoo.pipeline.base_pipeline import Pipeline, PipelineConfig
|
13
|
+
import torch.nn as nn
|
14
|
+
import torch.optim as optimizer
|
15
|
+
import torch.optim.lr_scheduler as lr_scheduler
|
16
|
+
import functools
|
17
|
+
from spikezoo.utils.optimizer_utils import OptimizerConfig, AdamOptimizerConfig
|
18
|
+
from spikezoo.utils.scheduler_utils import SchedulerConfig, MultiStepSchedulerConfig, CosineAnnealingLRConfig
|
19
|
+
from torch.utils.tensorboard import SummaryWriter
|
20
|
+
import subprocess
|
21
|
+
import webbrowser
|
22
|
+
import time
|
23
|
+
import re
|
24
|
+
from spikezoo.utils.other_utils import set_random_seed
|
25
|
+
from spikingjelly.clock_driven import functional
|
13
26
|
|
14
27
|
|
15
28
|
@dataclass
|
16
29
|
class TrainPipelineConfig(PipelineConfig):
|
17
|
-
|
18
|
-
epochs
|
19
|
-
|
30
|
+
# parameters setting
|
31
|
+
"Training epochs."
|
32
|
+
epochs: int = 1000
|
33
|
+
"Steps per to save images."
|
34
|
+
steps_per_save_imgs: int = 200
|
35
|
+
"Steps per to save model weights."
|
36
|
+
steps_per_save_ckpt: int = 500
|
37
|
+
"Steps per to calculate the metrics."
|
38
|
+
steps_per_cal_metrics: int = 100
|
39
|
+
"Step for gradient accumulation. (for snn methods)"
|
40
|
+
steps_grad_accumulation: int = 4
|
41
|
+
"Pipeline mode."
|
42
|
+
_mode: Literal["single_mode", "multi_mode", "train_mode"] = "train_mode"
|
43
|
+
"Use tensorboard or not"
|
44
|
+
use_tensorboard: bool = True
|
45
|
+
"Random seed."
|
46
|
+
seed: int = 521
|
47
|
+
# dataloader setting
|
48
|
+
"Batch size for the train dataloader."
|
49
|
+
bs_train: int = 8
|
50
|
+
"Num_workers for the train dataloader."
|
20
51
|
num_workers: int = 4
|
52
|
+
"Pin_memory true or false for the train dataloader."
|
21
53
|
pin_memory: bool = False
|
22
|
-
|
23
|
-
|
24
|
-
|
54
|
+
|
55
|
+
# train setting - optimizer & scheduler & loss_dict
|
56
|
+
"Optimizer config."
|
57
|
+
optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-3)
|
58
|
+
"Scheduler config."
|
59
|
+
scheduler_cfg: Optional[SchedulerConfig] = None
|
60
|
+
"Loss dict {loss_name,weight}."
|
61
|
+
loss_weight_dict: Dict[Literal["l1", "l2"], float] = field(default_factory=lambda: {"l1": 1})
|
25
62
|
|
26
63
|
|
27
64
|
class TrainPipeline(Pipeline):
|
@@ -34,13 +71,20 @@ class TrainPipeline(Pipeline):
|
|
34
71
|
self.cfg = cfg
|
35
72
|
self._setup_model_data(model_cfg, dataset_cfg)
|
36
73
|
self._setup_pipeline()
|
37
|
-
self.
|
74
|
+
self._setup_training()
|
75
|
+
|
76
|
+
def _setup_pipeline(self):
|
77
|
+
super()._setup_pipeline()
|
78
|
+
set_random_seed(self.cfg.seed)
|
79
|
+
if self.cfg.use_tensorboard:
|
80
|
+
self.writer = SummaryWriter(self.save_folder / Path(""))
|
81
|
+
subprocess.Popen(["tensorboard", f"--logdir={self.save_folder}"])
|
38
82
|
|
39
83
|
def _setup_model_data(self, model_cfg, dataset_cfg):
|
40
84
|
"""Model and Data setup."""
|
41
85
|
# model
|
42
86
|
self.model: BaseModel = build_model_name(model_cfg) if isinstance(model_cfg, str) else build_model_cfg(model_cfg)
|
43
|
-
self.model =
|
87
|
+
self.model.build_network(mode = "train",version="local")
|
44
88
|
torch.set_grad_enabled(True)
|
45
89
|
# data
|
46
90
|
if isinstance(dataset_cfg, str):
|
@@ -54,41 +98,96 @@ class TrainPipeline(Pipeline):
|
|
54
98
|
# device
|
55
99
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
56
100
|
|
101
|
+
def _setup_training(self):
|
102
|
+
"""Setup training optimizer."""
|
103
|
+
self.optimizer = self.cfg.optimizer_cfg.setup(self.model.net.parameters())
|
104
|
+
self.scheduler = self.cfg.scheduler_cfg.setup(self.optimizer) if self.cfg.scheduler_cfg != None else None
|
105
|
+
self.cnt_grad = 0
|
106
|
+
|
57
107
|
def save_network(self, epoch):
|
58
108
|
"""Save the network."""
|
59
109
|
save_folder = self.save_folder / Path("ckpt")
|
60
110
|
os.makedirs(save_folder, exist_ok=True)
|
61
111
|
self.model.save_network(save_folder / f"{epoch:06d}.pth")
|
62
112
|
|
113
|
+
def train(self):
|
114
|
+
"""Training code."""
|
115
|
+
self.logger.info("Start Training!")
|
116
|
+
for epoch in range(self.cfg.epochs):
|
117
|
+
# training
|
118
|
+
for batch_idx, batch in enumerate(tqdm(self.train_dataloader)):
|
119
|
+
batch = self.model.feed_to_device(batch)
|
120
|
+
outputs = self.model.get_outputs_dict(batch)
|
121
|
+
loss_dict, loss_values_dict = self.model.get_loss_dict(outputs, batch, self.cfg.loss_weight_dict)
|
122
|
+
self.optimize_parameters(loss_dict, batch_idx == len(self.train_dataloader) - 1)
|
123
|
+
self.update_learning_rate()
|
124
|
+
self.write_log_train(epoch, loss_values_dict)
|
125
|
+
|
126
|
+
# save visual results & save ckpt & evaluate metrics
|
127
|
+
with torch.no_grad():
|
128
|
+
if epoch % self.cfg.steps_per_save_imgs == 0 or epoch == self.cfg.epochs - 1:
|
129
|
+
self.save_visual(epoch)
|
130
|
+
if epoch % self.cfg.steps_per_save_ckpt == 0 or epoch == self.cfg.epochs - 1:
|
131
|
+
self.save_network(epoch)
|
132
|
+
if epoch % self.cfg.steps_per_cal_metrics == 0 or epoch == self.cfg.epochs - 1:
|
133
|
+
metrics_dict = self.cal_metrics()
|
134
|
+
self.write_log_test(epoch, metrics_dict)
|
135
|
+
|
63
136
|
def save_visual(self, epoch):
|
64
137
|
"""Save the visual results."""
|
65
138
|
self.logger.info("Saving visual results...")
|
66
139
|
save_folder = self.save_folder / Path("imgs") / Path(f"{epoch:06d}")
|
67
140
|
os.makedirs(save_folder, exist_ok=True)
|
68
141
|
for batch_idx, batch in enumerate(tqdm(self.dataloader)):
|
69
|
-
if batch_idx % (len(self.dataloader) //
|
70
|
-
continue
|
142
|
+
if batch_idx % (len(self.dataloader) // 3) != 0:
|
143
|
+
continue
|
71
144
|
batch = self.model.feed_to_device(batch)
|
72
145
|
outputs = self.model.get_outputs_dict(batch)
|
73
146
|
visual_dict = self.model.get_visual_dict(batch, outputs)
|
147
|
+
self._state_reset(self.model)
|
74
148
|
# save
|
75
149
|
for key, img in visual_dict.items():
|
150
|
+
img = self._post_process_img(img, model_name=self.model.cfg.model_name)
|
76
151
|
cv2.imwrite(str(save_folder / Path(f"{batch_idx:06d}_{key}.png")), tensor2npy(img))
|
152
|
+
if self.cfg.use_tensorboard == True and batch_idx == 0:
|
153
|
+
self.writer.add_image(f"imgs/{key}", img[0].detach().cpu(), epoch)
|
77
154
|
|
78
|
-
def
|
79
|
-
"""
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
self.
|
93
|
-
|
94
|
-
|
155
|
+
def optimize_parameters(self, loss_dict, final_flag):
|
156
|
+
"""Optimize the parameters from the loss_dict."""
|
157
|
+
loss = functools.reduce(torch.add, loss_dict.values())
|
158
|
+
step_grad = self.cfg.steps_grad_accumulation
|
159
|
+
# for snn methods
|
160
|
+
if self.model.cfg.model_name == "ssir":
|
161
|
+
if self.cnt_grad % step_grad != step_grad - 1 and final_flag == False:
|
162
|
+
loss.backward(retain_graph=True)
|
163
|
+
self.cnt_grad += 1
|
164
|
+
else:
|
165
|
+
loss.backward(retain_graph=False)
|
166
|
+
self.optimizer.step()
|
167
|
+
self.optimizer.zero_grad()
|
168
|
+
self._state_reset(self.model)
|
169
|
+
self.cnt_grad = 0
|
170
|
+
else:
|
171
|
+
# for cnn methods
|
172
|
+
self.optimizer.zero_grad()
|
173
|
+
loss.backward()
|
174
|
+
self.optimizer.step()
|
175
|
+
|
176
|
+
def update_learning_rate(self):
|
177
|
+
"""Update the learning rate."""
|
178
|
+
self.scheduler.step() if self.cfg.scheduler_cfg != None else None
|
179
|
+
|
180
|
+
def write_log_train(self, epoch, loss_values_dict):
|
181
|
+
"""Write the train log information."""
|
182
|
+
self.logger.info(f"EPOCH {epoch}/{self.cfg.epochs}: Train Loss: {loss_values_dict}")
|
183
|
+
if self.cfg.use_tensorboard:
|
184
|
+
for name, val in loss_values_dict.items():
|
185
|
+
self.writer.add_scalar(f"Loss/{name}", val, epoch)
|
186
|
+
lr = self.optimizer.param_groups[0]["lr"]
|
187
|
+
self.writer.add_scalar(f"Loss/lr", lr, epoch)
|
188
|
+
|
189
|
+
def write_log_test(self, epoch, metrics_dict):
|
190
|
+
"""Write the test log information."""
|
191
|
+
if self.cfg.use_tensorboard:
|
192
|
+
for name in metrics_dict.keys():
|
193
|
+
self.writer.add_scalar(f"Test/{name}", metrics_dict[name].avg, epoch)
|