spikezoo 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl

Files changed (86)
  1. spikezoo/__init__.py +23 -7
  2. spikezoo/archs/bsf/models/bsf/bsf.py +37 -25
  3. spikezoo/archs/bsf/models/bsf/rep.py +2 -2
  4. spikezoo/archs/spk2imgnet/nets.py +1 -1
  5. spikezoo/archs/ssir/models/networks.py +1 -1
  6. spikezoo/archs/ssml/model.py +9 -5
  7. spikezoo/archs/stir/metrics/losses.py +1 -1
  8. spikezoo/archs/stir/models/networks_STIR.py +16 -9
  9. spikezoo/archs/tfi/nets.py +1 -1
  10. spikezoo/archs/tfp/nets.py +1 -1
  11. spikezoo/archs/wgse/dwtnets.py +6 -6
  12. spikezoo/datasets/__init__.py +11 -9
  13. spikezoo/datasets/base_dataset.py +10 -3
  14. spikezoo/datasets/realworld_dataset.py +1 -3
  15. spikezoo/datasets/{reds_small_dataset.py → reds_base_dataset.py} +9 -8
  16. spikezoo/datasets/reds_ssir_dataset.py +181 -0
  17. spikezoo/datasets/szdata_dataset.py +5 -15
  18. spikezoo/datasets/uhsr_dataset.py +4 -3
  19. spikezoo/models/__init__.py +8 -6
  20. spikezoo/models/base_model.py +120 -64
  21. spikezoo/models/bsf_model.py +11 -3
  22. spikezoo/models/spcsnet_model.py +19 -0
  23. spikezoo/models/spikeclip_model.py +4 -3
  24. spikezoo/models/spk2imgnet_model.py +9 -15
  25. spikezoo/models/ssir_model.py +4 -6
  26. spikezoo/models/ssml_model.py +44 -2
  27. spikezoo/models/stir_model.py +26 -5
  28. spikezoo/models/tfi_model.py +3 -1
  29. spikezoo/models/tfp_model.py +4 -2
  30. spikezoo/models/wgse_model.py +8 -14
  31. spikezoo/pipeline/base_pipeline.py +79 -55
  32. spikezoo/pipeline/ensemble_pipeline.py +10 -9
  33. spikezoo/pipeline/train_cfgs.py +89 -0
  34. spikezoo/pipeline/train_pipeline.py +129 -30
  35. spikezoo/utils/optimizer_utils.py +22 -0
  36. spikezoo/utils/other_utils.py +31 -6
  37. spikezoo/utils/scheduler_utils.py +25 -0
  38. spikezoo/utils/spike_utils.py +61 -29
  39. spikezoo-0.2.3.dist-info/METADATA +263 -0
  40. {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.dist-info}/RECORD +43 -80
  41. spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
  42. spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
  43. spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
  44. spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
  45. spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
  46. spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
  47. spikezoo/archs/spikeformer/CheckPoints/readme +0 -1
  48. spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +0 -60
  49. spikezoo/archs/spikeformer/DataProcess/DataLoader.py +0 -115
  50. spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +0 -39
  51. spikezoo/archs/spikeformer/EvalResults/readme +0 -1
  52. spikezoo/archs/spikeformer/LICENSE +0 -21
  53. spikezoo/archs/spikeformer/Metrics/Metrics.py +0 -50
  54. spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
  55. spikezoo/archs/spikeformer/Model/Loss.py +0 -89
  56. spikezoo/archs/spikeformer/Model/SpikeFormer.py +0 -230
  57. spikezoo/archs/spikeformer/Model/__init__.py +0 -0
  58. spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
  59. spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
  60. spikezoo/archs/spikeformer/README.md +0 -30
  61. spikezoo/archs/spikeformer/evaluate.py +0 -87
  62. spikezoo/archs/spikeformer/recon_real_data.py +0 -97
  63. spikezoo/archs/spikeformer/requirements.yml +0 -95
  64. spikezoo/archs/spikeformer/train.py +0 -173
  65. spikezoo/archs/spikeformer/utils.py +0 -22
  66. spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
  67. spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
  68. spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
  69. spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
  70. spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
  71. spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
  72. spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
  73. spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
  74. spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
  75. spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
  76. spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
  77. spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
  78. spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
  79. spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
  80. spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
  81. spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
  82. spikezoo/models/spikeformer_model.py +0 -50
  83. spikezoo-0.2.2.dist-info/METADATA +0 -196
  84. {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.dist-info}/LICENSE.txt +0 -0
  85. {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.dist-info}/WHEEL +0 -0
  86. {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.dist-info}/top_level.txt +0 -0
spikezoo/models/wgse_model.py
@@ -1,6 +1,12 @@
 from dataclasses import dataclass, field
 from spikezoo.models.base_model import BaseModel, BaseModelConfig
 from typing import List
+from spikezoo.pipeline import TrainPipelineConfig
+import torch.nn as nn
+import torch.optim as optim
+import torch.optim.lr_scheduler as lr_scheduler
+from typing import List
+from spikezoo.archs.wgse.dwtnets import Dwt1dResnetX_TCN
 
 
 @dataclass
@@ -9,21 +15,9 @@ class WGSEConfig(BaseModelConfig):
     model_name: str = "wgse"
     model_file_name: str = "dwtnets"
     model_cls_name: str = "Dwt1dResnetX_TCN"
-    model_win_length: int = 41
+    model_length: int = 41
+    model_length_dict: dict = field(default_factory=lambda: {"v010": 41, "v023": 41})
     require_params: bool = True
-    ckpt_path: str = "weights/wgse.pt"
-    model_params: dict = field(
-        default_factory=lambda: {
-            "wvlname": "db8",
-            "J": 5,
-            "yl_size": "15",
-            "yh_size": [28, 21, 18, 16, 15],
-            "num_residual_blocks": 3,
-            "norm": None,
-            "ks": 3,
-            "store_features": True,
-        }
-    )
 
 
 class WGSE(BaseModel):
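
The WGSE config now resolves its architecture parameters per released weight version ("v010"/"v023") instead of carrying a hard-coded checkpoint path and parameter dict. A minimal usage sketch, assuming Pipeline and PipelineConfig are re-exported from spikezoo.pipeline the way TrainPipelineConfig is in the imports above; the dataset name string is illustrative:

    import spikezoo as sz
    from spikezoo.pipeline import Pipeline, PipelineConfig  # re-export assumed
    from spikezoo.models.wgse_model import WGSEConfig

    # PipelineConfig.version selects which released weights to load;
    # WGSE then resolves its length from model_length_dict["v023"].
    pipeline = Pipeline(
        cfg=PipelineConfig(version="v023", save_folder="results", exp_name="wgse_demo"),
        model_cfg=WGSEConfig(),
        dataset_cfg="reds_base",  # illustrative dataset name
    )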
spikezoo/pipeline/base_pipeline.py
@@ -18,24 +18,27 @@ from tqdm import tqdm
 from spikezoo.models import build_model_cfg, build_model_name, BaseModel, BaseModelConfig
 from spikezoo.datasets import build_dataset_cfg, build_dataset_name, BaseDataset, BaseDatasetConfig, build_dataloader
 from typing import Optional, Union, List
+import shutil
+from spikingjelly.clock_driven import functional
+import spikezoo as sz
 
 
 @dataclass
 class PipelineConfig:
-    "Evaluate metrics or not."
-    save_metric: bool = True
-    "Save recoverd images or not."
-    save_img: bool = True
-    "Normalizing recoverd images or not."
-    save_img_norm: bool = False
-    "Normalizing gt or not."
-    gt_img_norm: bool = False
+    "Loading weights from local or version on the url."
+    version: Literal["local", "v010", "v023"] = "local"
     "Save folder for the code running result."
     save_folder: str = ""
     "Saved experiment name."
     exp_name: str = ""
+    "Evaluate metrics or not."
+    save_metric: bool = True
     "Metric names for evaluation."
     metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim"])
+    "Save recoverd images or not."
+    save_img: bool = True
+    "Normalizing recoverd images and gt or not."
+    save_img_norm: bool = False
     "Different modes for the pipeline."
     _mode: Literal["single_mode", "multi_mode", "train_mode"] = "single_mode"
 
@@ -44,8 +47,8 @@ class Pipeline:
     def __init__(
         self,
         cfg: PipelineConfig,
-        model_cfg: Union[str, BaseModelConfig],
-        dataset_cfg: Union[str, BaseDatasetConfig],
+        model_cfg: Union[sz.METHOD, BaseModelConfig],
+        dataset_cfg: Union[sz.DATASET, BaseDatasetConfig],
     ):
         self.cfg = cfg
         self._setup_model_data(model_cfg, dataset_cfg)
@@ -53,9 +56,9 @@ class Pipeline:
 
     def _setup_model_data(self, model_cfg, dataset_cfg):
         """Model and Data setup."""
-        # model
+        # model [1] build the model. [2] build the network.
         self.model: BaseModel = build_model_name(model_cfg) if isinstance(model_cfg, str) else build_model_cfg(model_cfg)
-        self.model = self.model.eval()
+        self.model.build_network(mode="eval", version=self.cfg.version)
         torch.set_grad_enabled(False)
         # dataset
         self.dataset: BaseDataset = build_dataset_name(dataset_cfg) if isinstance(dataset_cfg, str) else build_dataset_cfg(dataset_cfg)
@@ -74,15 +77,19 @@ class Pipeline:
             if len(self.cfg.exp_name) == 0
             else self.save_folder / Path(f"{mode_name}/{self.cfg.exp_name}")
         )
-        save_folder = self.save_folder
-        os.makedirs(str(save_folder), exist_ok=True)
+        # remove and establish folder
+        save_folder = str(self.save_folder)
+        if os.path.exists(save_folder):
+            shutil.rmtree(save_folder)
+        os.makedirs(save_folder)
+        save_folder = Path(save_folder)
         # logger result
         self.logger = setup_logging(save_folder / Path("result.log"))
         self.logger.info(f"Info logs are saved on the {save_folder}/result.log")
         # pipeline config
         save_config(self.cfg, save_folder / Path("cfg_pipeline.log"))
         # model config
-        if self.cfg._mode == "single_mode":
+        if self.cfg._mode in ["single_mode", "train_mode"]:
             save_config(self.model.cfg, save_folder / Path("cfg_model.log"))
         elif self.cfg._mode == "multi_mode":
             for model in self.model_list:
@@ -90,28 +97,29 @@ class Pipeline:
         # dataset config
         save_config(self.dataset.cfg, save_folder / Path("cfg_dataset.log"))
 
-    def spk2img_from_dataset(self, idx=0):
-        """Func---Save the recoverd image and calculate the metric from the given dataset."""
+    def infer_from_dataset(self, idx=0):
+        """Function I---Save the recoverd image and calculate the metric from the given dataset."""
         # save folder
-        self.logger.info("*********************** spk2img_from_dataset ***********************")
-        save_folder = self.save_folder / Path(f"spk2img_from_dataset/{self.dataset.cfg.dataset_name}_dataset/{self.dataset.cfg.split}/{idx:06d}")
+        self.logger.info("*********************** infer_from_dataset ***********************")
+        save_folder = self.save_folder / Path(f"infer_from_dataset/{self.dataset.cfg.dataset_name}_dataset/{self.dataset.cfg.split}/{idx:06d}")
         os.makedirs(str(save_folder), exist_ok=True)
 
         # data process
+        # todo
         batch = self.dataset[idx]
-        spike, img = batch["spike"], batch["img"]
+        spike, img, rate = batch["spike"], batch["gt_img"], batch["rate"]
         spike = spike[None].to(self.device)
         if self.dataset.cfg.with_img == True:
             img = img[None].to(self.device)
         else:
             img = None
-        return self._spk2img(spike, img, save_folder)
+        return self.infer(spike, img, save_folder, rate)
 
-    def spk2img_from_file(self, file_path, height = -1, width = -1, img_path=None, remove_head=False):
-        """Func---Save the recoverd image and calculate the metric from the given input file."""
+    def infer_from_file(self, file_path, height=-1, width=-1, img_path=None, rate=1, remove_head=False):
+        """Function II---Save the recoverd image and calculate the metric from the given input file."""
         # save folder
-        self.logger.info("*********************** spk2img_from_file ***********************")
-        save_folder = self.save_folder / Path(f"spk2img_from_file/{os.path.basename(file_path)}")
+        self.logger.info("*********************** infer_from_file ***********************")
+        save_folder = self.save_folder / Path(f"infer_from_file/{os.path.basename(file_path)}")
         os.makedirs(str(save_folder), exist_ok=True)
 
         # load spike from .dat
@@ -132,13 +140,13 @@ class Pipeline:
         else:
             img = img_path
         spike = torch.from_numpy(spike)[None].to(self.device)
-        return self._spk2img(spike, img, save_folder)
+        return self.infer(spike, img, save_folder, rate)
 
-    def spk2img_from_spk(self, spike, img=None):
-        """Func---Save the recoverd image and calculate the metric from the given spike stream."""
+    def infer_from_spk(self, spike, img=None, rate=1):
+        """Function III---Save the recoverd image and calculate the metric from the given spike stream."""
         # save folder
-        self.logger.info("*********************** spk2img_from_spk ***********************")
-        save_folder = self.save_folder / Path(f"spk2img_from_spk/{self.thistime}")
+        self.logger.info("*********************** infer_from_spk ***********************")
+        save_folder = self.save_folder / Path(f"infer_from_spk")
         os.makedirs(str(save_folder), exist_ok=True)
 
         # spike process
@@ -157,36 +165,40 @@ class Pipeline:
             img = torch.from_numpy(img)[None, None].to(self.device)
         else:
             raise RuntimeError("Not recognized image input type.")
-        return self._spk2img(spike, img, save_folder)
+        return self.infer(spike, img, save_folder, rate)
+
+    # TODO: To be overridden
+    def infer(self, spike, img, save_folder, rate):
+        """Function IV---Spike-to-image conversion interface, input data format: spike [bs,c,h,w] (0-1), img [bs,1,h,w] (0-1)"""
+        return self._infer_model(self.model, spike, img, save_folder, rate)
 
     def save_imgs_from_dataset(self):
-        """Func---Save all images from the given dataset."""
+        """Function V---Save all images from the given dataset."""
+        base_setting = self.cfg.save_metric
+        self.cfg.save_metric = False
        for idx in range(len(self.dataset)):
-            self.spk2img_from_dataset(idx=idx)
-
+            self.infer_from_dataset(idx=idx)
+        self.cfg.save_metric = base_setting
+
     # TODO: To be overridden
     def cal_params(self):
-        """Func---Calculate the parameters/flops/latency of the given method."""
+        """Function VI---Calculate the parameters/flops/latency of the given method."""
         self._cal_prams_model(self.model)
 
     # TODO: To be overridden
     def cal_metrics(self):
-        """Func---Calculate the metric of the given method."""
-        self._cal_metrics_model(self.model)
-
-    # TODO: To be overridden
-    def _spk2img(self, spike, img, save_folder):
-        """Spike-to-image: spike:[bs,c,h,w] (0-1), img:[bs,1,h,w] (0-1)"""
-        return self._spk2img_model(self.model, spike, img, save_folder)
+        """Function VII---Calculate the metric of the given method."""
+        return self._cal_metrics_model(self.model)
 
-    def _spk2img_model(self, model, spike, img, save_folder):
+    def _infer_model(self, model, spike, img, save_folder, rate):
         """Spike-to-image from the given model."""
         # spike2image conversion
         model_name = model.cfg.model_name
-        recon_img = model(spike)
+        recon_img = model.spk2img(spike)
         recon_img_copy = recon_img.clone()
         # normalization
-        recon_img, img = self._post_process_img(model_name, recon_img, img)
+        recon_img, img = self._post_process_img(recon_img, model_name, rate), self._post_process_img(img, None, 1)
+        self._state_reset(model)
         # metric
         if self.cfg.save_metric == True:
             self.logger.info(f"----------------------Method: {model_name.upper()}----------------------")
@@ -209,19 +221,24 @@ class Pipeline:
         self.logger.info(f"Images are saved on the {save_folder}")
         return recon_img_copy
 
-    def _post_process_img(self, model_name, recon_img, gt_img):
+    def _post_process_img(self, recon_img, model_name, rate=1):
         """Post process the reconstructed image."""
-        # TFP and TFI algorithms are normalized automatically, others are normalized based on the self.cfg.use_norm
-        if model_name in ["tfp", "tfi", "spikeformer", "spikeclip"]:
-            recon_img = (recon_img - recon_img.min()) / (recon_img.max() - recon_img.min())
-        elif self.cfg.save_img_norm == True:
+        # With no GT
+        if recon_img == None:
+            return None
+        # TFP, TFI, spikeclip algorithms are normalized automatically, others are normalized based on the self.cfg.use_norm
+        if model_name in ["tfp", "tfi", "spikeclip"] or self.cfg.save_img_norm == True:
             recon_img = (recon_img - recon_img.min()) / (recon_img.max() - recon_img.min())
+        else:
+            recon_img = recon_img / rate
         recon_img = recon_img.clip(0, 1)
-        gt_img = (gt_img - gt_img.min()) / (gt_img.max() - gt_img.min()) if self.cfg.gt_img_norm == True and gt_img is not None else gt_img
-        return recon_img, gt_img
+        return recon_img
 
     def _cal_metrics_model(self, model: BaseModel):
         """Calculate the metrics for the given model."""
+        # metric state reset (since get_outputs_dict from the training state is utilized)
+        model_state = model.net.training
+        model.net.training = True
         # metrics construct
         model_name = model.cfg.model_name
         metrics_dict = {}
@@ -234,17 +251,19 @@ class Pipeline:
             batch = model.feed_to_device(batch)
             outputs = model.get_outputs_dict(batch)
             recon_img, img = model.get_paired_imgs(batch, outputs)
-            recon_img, img = self._post_process_img(model_name, recon_img, img)
+            recon_img, img = self._post_process_img(recon_img, model_name), self._post_process_img(img, "auto")
             for metric_name in metrics_dict.keys():
                 if metric_name in metric_pair_names:
                     metrics_dict[metric_name].update(cal_metric_pair(recon_img, img, metric_name))
                 elif metric_name in metric_single_names:
                     metrics_dict[metric_name].update(cal_metric_single(recon_img, metric_name))
-
+            self._state_reset(model)
+        model.net.training = model_state
         # metrics self.logger.info
         self.logger.info(f"----------------------Method: {model_name.upper()}----------------------")
         for metric_name in metrics_dict.keys():
             self.logger.info(f"{metric_name.upper()}: {metrics_dict[metric_name].avg}")
+        return metrics_dict
 
     def _cal_prams_model(self, model):
         """Calculate the parameters for the given model."""
@@ -256,7 +275,7 @@ class Pipeline:
         spike = torch.zeros((1, 200, 250, 400)).cuda()
         start_time = time.time()
         for _ in range(100):
-            model(spike)
+            model.spk2img(spike)
         latency = (time.time() - start_time) / 100
         # flop # todo thop bug for BSF
         flops, _ = profile((model), inputs=(spike,))
@@ -267,3 +286,8 @@ class Pipeline:
         )
         self.logger.info(f"----------------------Method: {model_name}----------------------")
         self.logger.info(re_msg)
+
+    def _state_reset(self, model):
+        """State reset for the snn-based method."""
+        if model.cfg.model_name == "ssir":
+            functional.reset_net(model.net)
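
The spk2img_from_* entry points are renamed to infer_from_* and now thread a rate factor through to _post_process_img, where non-normalized reconstructions are divided by it before clipping. A sketch of the three calls against a Pipeline instance like the one built above (the .dat path and resolution are placeholders):

    # from a raw .dat spike file; height/width describe the raw stream
    recon = pipeline.infer_from_file("data/spike.dat", height=250, width=400, rate=1)

    # from an in-memory spike stream; img is optional and only used for metrics
    recon = pipeline.infer_from_spk(spike, img=None, rate=1)

    # from the bound dataset, which now supplies gt_img and rate itself
    recon = pipeline.infer_from_dataset(idx=0)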
spikezoo/pipeline/ensemble_pipeline.py
@@ -19,6 +19,7 @@ from spikezoo.models import build_model_cfg, build_model_name, BaseModel, BaseMo
 from spikezoo.datasets import build_dataset_cfg, build_dataset_name, BaseDataset, BaseDatasetConfig, build_dataloader
 from typing import Optional, Union, List
 from spikezoo.pipeline.base_pipeline import Pipeline, PipelineConfig
+import spikezoo as sz
 
 
 @dataclass
@@ -30,30 +31,30 @@ class EnsemblePipeline(Pipeline):
     def __init__(
         self,
         cfg: PipelineConfig,
-        model_cfg_list: Union[List[str], List[BaseModelConfig]],
-        dataset_cfg: Union[str, BaseDatasetConfig],
+        model_cfg_list: Union[List[sz.METHOD], List[BaseModelConfig]],
+        dataset_cfg: Union[sz.DATASET, BaseDatasetConfig],
     ):
         self.cfg = cfg
-        self._setup_model_data(model_cfg_list,dataset_cfg)
+        self._setup_model_data(model_cfg_list, dataset_cfg)
         self._setup_pipeline()
 
-    def _setup_model_data(self,model_cfg_list,dataset_cfg):
+    def _setup_model_data(self, model_cfg_list, dataset_cfg):
         """Model and Data setup."""
         # model
         self.model_list: List[BaseModel] = (
-            [build_model_name(name) for name in model_cfg_list] if isinstance(model_cfg_list[0],str) else [build_model_cfg(cfg) for cfg in model_cfg_list]
+            [build_model_name(name) for name in model_cfg_list] if isinstance(model_cfg_list[0], str) else [build_model_cfg(cfg) for cfg in model_cfg_list]
         )
-        self.model_list = [model.eval() for model in self.model_list]
+        self.model_list = [model.build_network(mode="eval", version=self.cfg.version) for model in self.model_list]
         torch.set_grad_enabled(False)
         # data
         self.dataset: BaseDataset = build_dataset_name(dataset_cfg) if isinstance(dataset_cfg, str) else build_dataset_cfg(dataset_cfg)
         self.dataloader = build_dataloader(self.dataset)
         # device
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-
-    def _spk2img(self, spike, img, save_folder):
+
+    def infer(self, spike, img, save_folder, rate):
         for model in self.model_list:
-            self._spk2img_model(model, spike, img, save_folder)
+            self._infer_model(model, spike, img, save_folder, rate)
 
     def cal_params(self):
         for model in self.model_list:
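
EnsemblePipeline now builds every network in the list through the same build_network flow, so one inference call fans out over all methods. A hedged sketch; the method and dataset name strings are illustrative, and the base PipelineConfig stands in for whatever ensemble-specific config the hidden @dataclass above defines:

    from spikezoo.pipeline.base_pipeline import PipelineConfig
    from spikezoo.pipeline.ensemble_pipeline import EnsemblePipeline

    pipeline = EnsemblePipeline(
        cfg=PipelineConfig(version="v023"),
        model_cfg_list=["tfp", "tfi"],  # illustrative method names
        dataset_cfg="reds_base",        # illustrative dataset name
    )
    pipeline.infer_from_dataset(idx=0)  # runs every model in the list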
spikezoo/pipeline/train_cfgs.py (new file)
@@ -0,0 +1,89 @@
+import torch.nn as nn
+import torch.optim as optimizer
+import torch.optim.lr_scheduler as lr_scheduler
+import functools
+from spikezoo.utils.optimizer_utils import OptimizerConfig, AdamOptimizerConfig
+from spikezoo.utils.scheduler_utils import SchedulerConfig, MultiStepSchedulerConfig, CosineAnnealingLRConfig
+from dataclasses import dataclass, field
+from spikezoo.pipeline.train_pipeline import TrainPipelineConfig
+from typing import Optional, Dict, List
+
+
+@dataclass
+class REDS_BASE_TrainConfig(TrainPipelineConfig):
+    """Training setting for methods on the REDS-BASE dataset."""
+
+    # parameters setting
+    epochs: int = 600
+    steps_per_save_imgs: int = 200
+    steps_per_save_ckpt: int = 500
+    steps_per_cal_metrics: int = 100
+    metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim", "lpips", "niqe", "brisque", "piqe"])
+
+    # dataloader setting
+    bs_train: int = 8
+    num_workers: int = 4
+    pin_memory: bool = False
+
+    # train setting - optimizer & scheduler & loss_dict
+    optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4)
+    scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[400], gamma=0.2)  # from wgse
+    loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
+
+# @dataclass
+# class REDS_BASE_TrainConfig(TrainPipelineConfig):
+#     """Training setting for methods on the REDS-BASE dataset."""
+
+#     # parameters setting
+#     epochs: int = 700
+#     steps_per_save_imgs: int = 200
+#     steps_per_save_ckpt: int = 500
+#     steps_per_cal_metrics: int = 100
+#     metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim"])
+
+#     # dataloader setting
+#     bs_train: int = 8
+#     num_workers: int = 4
+#     pin_memory: bool = False
+
+#     # train setting - optimizer & scheduler & loss_dict
+#     optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4)
+#     scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[400, 600], gamma=0.2)  # from wgse
+#     loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
+
+
+# ! Train Config for each method on the official setting, not recommended to utilize their default parameters owing to the dataset setting.
+@dataclass
+class BSFTrainConfig(TrainPipelineConfig):
+    """Training setting for BSF. https://github.com/ruizhao26/BSF"""
+
+    optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4, weight_decay=0.0)
+    scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[10, 20, 30, 40, 50, 60, 70, 80, 90, 100])
+    loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
+
+
+@dataclass
+class WGSETrainConfig(TrainPipelineConfig):
+    """Training setting for WGSE. https://github.com/Leozhangjiyuan/WGSE-SpikeCamera"""
+
+    optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4, betas=(0.9, 0.99), weight_decay=0)
+    scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[400, 600], gamma=0.2)
+    loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
+
+
+@dataclass
+class STIRTrainConfig(TrainPipelineConfig):
+    """Training setting for STIR. https://github.com/GitCVfb/STIR"""
+
+    optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4, betas=(0.9, 0.999))
+    scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70], gamma=0.7)
+    loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
+
+
+@dataclass
+class Spk2ImgNetTrainConfig(TrainPipelineConfig):
+    """Training setting for Spk2ImgNet. https://github.com/Vspacer/Spk2ImgNet"""
+
+    optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4)
+    scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[20], gamma=0.1)
+    loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
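
These recipe dataclasses drop straight into TrainPipeline in place of a hand-rolled TrainPipelineConfig. A sketch of training with the bundled REDS-BASE recipe (the method and dataset name strings are illustrative):

    from spikezoo.pipeline.train_cfgs import REDS_BASE_TrainConfig
    from spikezoo.pipeline.train_pipeline import TrainPipeline

    cfg = REDS_BASE_TrainConfig(save_folder="results", exp_name="wgse_reds_base")
    pipeline = TrainPipeline(cfg, model_cfg="wgse", dataset_cfg="reds_base")
    pipeline.train()  # ckpts under ckpt/, visuals under imgs/, scalars in tensorboard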
spikezoo/pipeline/train_pipeline.py
@@ -1,27 +1,64 @@
 import torch
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 import os
 from spikezoo.utils.img_utils import tensor2npy
 import cv2
 from pathlib import Path
-from typing import Literal
+from typing import Literal, Dict
 from tqdm import tqdm
 from spikezoo.models import build_model_cfg, build_model_name, BaseModel, BaseModelConfig
 from spikezoo.datasets import build_dataset_cfg, build_dataset_name, BaseDataset, BaseDatasetConfig, build_dataloader
-from typing import Union
+from typing import Union, List, Optional
 from spikezoo.pipeline.base_pipeline import Pipeline, PipelineConfig
+import torch.nn as nn
+import torch.optim as optimizer
+import torch.optim.lr_scheduler as lr_scheduler
+import functools
+from spikezoo.utils.optimizer_utils import OptimizerConfig, AdamOptimizerConfig
+from spikezoo.utils.scheduler_utils import SchedulerConfig, MultiStepSchedulerConfig, CosineAnnealingLRConfig
+from torch.utils.tensorboard import SummaryWriter
+import subprocess
+import webbrowser
+import time
+import re
+from spikezoo.utils.other_utils import set_random_seed
+from spikingjelly.clock_driven import functional
 
 
 @dataclass
 class TrainPipelineConfig(PipelineConfig):
-    bs_train: int = 4
-    epochs: int = 100
-    lr: float = 1e-3
+    # parameters setting
+    "Training epochs."
+    epochs: int = 1000
+    "Steps per to save images."
+    steps_per_save_imgs: int = 200
+    "Steps per to save model weights."
+    steps_per_save_ckpt: int = 500
+    "Steps per to calculate the metrics."
+    steps_per_cal_metrics: int = 100
+    "Step for gradient accumulation. (for snn methods)"
+    steps_grad_accumulation: int = 4
+    "Pipeline mode."
+    _mode: Literal["single_mode", "multi_mode", "train_mode"] = "train_mode"
+    "Use tensorboard or not"
+    use_tensorboard: bool = True
+    "Random seed."
+    seed: int = 521
+    # dataloader setting
+    "Batch size for the train dataloader."
+    bs_train: int = 8
+    "Num_workers for the train dataloader."
     num_workers: int = 4
+    "Pin_memory true or false for the train dataloader."
     pin_memory: bool = False
-    steps_per_save_imgs = 10
-    steps_per_cal_metrics = 10
-    _mode: Literal["single_mode", "multi_mode", "train_mode"] = "train_mode"
+
+    # train setting - optimizer & scheduler & loss_dict
+    "Optimizer config."
+    optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-3)
+    "Scheduler config."
+    scheduler_cfg: Optional[SchedulerConfig] = None
+    "Loss dict {loss_name,weight}."
+    loss_weight_dict: Dict[Literal["l1", "l2"], float] = field(default_factory=lambda: {"l1": 1})
 
 
 class TrainPipeline(Pipeline):
@@ -34,13 +71,20 @@ class TrainPipeline(Pipeline):
         self.cfg = cfg
         self._setup_model_data(model_cfg, dataset_cfg)
         self._setup_pipeline()
-        self.model.setup_training(cfg)
+        self._setup_training()
+
+    def _setup_pipeline(self):
+        super()._setup_pipeline()
+        set_random_seed(self.cfg.seed)
+        if self.cfg.use_tensorboard:
+            self.writer = SummaryWriter(self.save_folder / Path(""))
+            subprocess.Popen(["tensorboard", f"--logdir={self.save_folder}"])
 
     def _setup_model_data(self, model_cfg, dataset_cfg):
         """Model and Data setup."""
         # model
         self.model: BaseModel = build_model_name(model_cfg) if isinstance(model_cfg, str) else build_model_cfg(model_cfg)
-        self.model = self.model.train()
+        self.model.build_network(mode = "train",version="local")
         torch.set_grad_enabled(True)
         # data
         if isinstance(dataset_cfg, str):
@@ -54,41 +98,96 @@ class TrainPipeline(Pipeline):
         # device
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
+    def _setup_training(self):
+        """Setup training optimizer."""
+        self.optimizer = self.cfg.optimizer_cfg.setup(self.model.net.parameters())
+        self.scheduler = self.cfg.scheduler_cfg.setup(self.optimizer) if self.cfg.scheduler_cfg != None else None
+        self.cnt_grad = 0
+
     def save_network(self, epoch):
         """Save the network."""
         save_folder = self.save_folder / Path("ckpt")
         os.makedirs(save_folder, exist_ok=True)
         self.model.save_network(save_folder / f"{epoch:06d}.pth")
 
+    def train(self):
+        """Training code."""
+        self.logger.info("Start Training!")
+        for epoch in range(self.cfg.epochs):
+            # training
+            for batch_idx, batch in enumerate(tqdm(self.train_dataloader)):
+                batch = self.model.feed_to_device(batch)
+                outputs = self.model.get_outputs_dict(batch)
+                loss_dict, loss_values_dict = self.model.get_loss_dict(outputs, batch, self.cfg.loss_weight_dict)
+                self.optimize_parameters(loss_dict, batch_idx == len(self.train_dataloader) - 1)
+            self.update_learning_rate()
+            self.write_log_train(epoch, loss_values_dict)
+
+            # save visual results & save ckpt & evaluate metrics
+            with torch.no_grad():
+                if epoch % self.cfg.steps_per_save_imgs == 0 or epoch == self.cfg.epochs - 1:
+                    self.save_visual(epoch)
+                if epoch % self.cfg.steps_per_save_ckpt == 0 or epoch == self.cfg.epochs - 1:
+                    self.save_network(epoch)
+                if epoch % self.cfg.steps_per_cal_metrics == 0 or epoch == self.cfg.epochs - 1:
+                    metrics_dict = self.cal_metrics()
+                    self.write_log_test(epoch, metrics_dict)
+
     def save_visual(self, epoch):
         """Save the visual results."""
         self.logger.info("Saving visual results...")
         save_folder = self.save_folder / Path("imgs") / Path(f"{epoch:06d}")
         os.makedirs(save_folder, exist_ok=True)
         for batch_idx, batch in enumerate(tqdm(self.dataloader)):
-            if batch_idx % (len(self.dataloader) // 4) != 0:
-                continue
+            if batch_idx % (len(self.dataloader) // 3) != 0:
+                continue
             batch = self.model.feed_to_device(batch)
             outputs = self.model.get_outputs_dict(batch)
             visual_dict = self.model.get_visual_dict(batch, outputs)
+            self._state_reset(self.model)
             # save
             for key, img in visual_dict.items():
+                img = self._post_process_img(img, model_name=self.model.cfg.model_name)
                 cv2.imwrite(str(save_folder / Path(f"{batch_idx:06d}_{key}.png")), tensor2npy(img))
+                if self.cfg.use_tensorboard == True and batch_idx == 0:
+                    self.writer.add_image(f"imgs/{key}", img[0].detach().cpu(), epoch)
 
-    def train(self):
-        """Training code."""
-        self.logger.info("Start Training!")
-        for epoch in range(self.cfg.epochs):
-            # training
-            for batch_idx, batch in enumerate(tqdm(self.train_dataloader)):
-                batch = self.model.feed_to_device(batch)
-                outputs = self.model.get_outputs_dict(batch)
-                loss_dict, loss_values_dict = self.model.get_loss_dict(outputs, batch)
-                self.model.optimize_parameters(loss_dict)
-                self.model.update_learning_rate()
-            self.logger.info(f"EPOCH {epoch}/{self.cfg.epochs}: Train Loss: {loss_values_dict}")
-            # save visual results & evaluate metrics
-            if epoch % self.cfg.steps_per_save_imgs == 0 or epoch == self.cfg.epochs - 1:
-                self.save_visual(epoch)
-            if epoch % self.cfg.steps_per_cal_metrics == 0 or epoch == self.cfg.epochs - 1:
-                self.cal_metrics()
+    def optimize_parameters(self, loss_dict, final_flag):
+        """Optimize the parameters from the loss_dict."""
+        loss = functools.reduce(torch.add, loss_dict.values())
+        step_grad = self.cfg.steps_grad_accumulation
+        # for snn methods
+        if self.model.cfg.model_name == "ssir":
+            if self.cnt_grad % step_grad != step_grad - 1 and final_flag == False:
+                loss.backward(retain_graph=True)
+                self.cnt_grad += 1
+            else:
+                loss.backward(retain_graph=False)
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+                self._state_reset(self.model)
+                self.cnt_grad = 0
+        else:
+            # for cnn methods
+            self.optimizer.zero_grad()
+            loss.backward()
+            self.optimizer.step()
+
+    def update_learning_rate(self):
+        """Update the learning rate."""
+        self.scheduler.step() if self.cfg.scheduler_cfg != None else None
+
+    def write_log_train(self, epoch, loss_values_dict):
+        """Write the train log information."""
+        self.logger.info(f"EPOCH {epoch}/{self.cfg.epochs}: Train Loss: {loss_values_dict}")
+        if self.cfg.use_tensorboard:
+            for name, val in loss_values_dict.items():
+                self.writer.add_scalar(f"Loss/{name}", val, epoch)
+            lr = self.optimizer.param_groups[0]["lr"]
+            self.writer.add_scalar(f"Loss/lr", lr, epoch)
+
+    def write_log_test(self, epoch, metrics_dict):
+        """Write the test log information."""
+        if self.cfg.use_tensorboard:
+            for name in metrics_dict.keys():
+                self.writer.add_scalar(f"Test/{name}", metrics_dict[name].avg, epoch)
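
Since the optimizer, scheduler, and loss weights are now plain config fields (replacing the old bare lr float), a custom recipe needs no subclassing. A minimal sketch using only config classes shown in this diff; the concrete values are illustrative:

    from spikezoo.pipeline.train_pipeline import TrainPipelineConfig
    from spikezoo.utils.optimizer_utils import AdamOptimizerConfig
    from spikezoo.utils.scheduler_utils import MultiStepSchedulerConfig

    cfg = TrainPipelineConfig(
        save_folder="results",
        epochs=400,
        optimizer_cfg=AdamOptimizerConfig(lr=1e-4),
        scheduler_cfg=MultiStepSchedulerConfig(milestones=[200, 300], gamma=0.2),
        loss_weight_dict={"l1": 1.0},  # keys limited to "l1"/"l2" by the type hint
        use_tensorboard=True,
        seed=521,
    )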