spikezoo 0.2.2__py3-none-any.whl → 0.2.3.2__py3-none-any.whl

Files changed (86)
  1. spikezoo/__init__.py +23 -7
  2. spikezoo/archs/bsf/models/bsf/bsf.py +37 -25
  3. spikezoo/archs/bsf/models/bsf/rep.py +2 -2
  4. spikezoo/archs/spk2imgnet/nets.py +1 -1
  5. spikezoo/archs/ssir/models/networks.py +1 -1
  6. spikezoo/archs/ssml/model.py +9 -5
  7. spikezoo/archs/stir/metrics/losses.py +1 -1
  8. spikezoo/archs/stir/models/networks_STIR.py +16 -9
  9. spikezoo/archs/tfi/nets.py +1 -1
  10. spikezoo/archs/tfp/nets.py +1 -1
  11. spikezoo/archs/wgse/dwtnets.py +6 -6
  12. spikezoo/datasets/__init__.py +11 -9
  13. spikezoo/datasets/base_dataset.py +10 -3
  14. spikezoo/datasets/realworld_dataset.py +1 -3
  15. spikezoo/datasets/{reds_small_dataset.py → reds_base_dataset.py} +9 -8
  16. spikezoo/datasets/reds_ssir_dataset.py +181 -0
  17. spikezoo/datasets/szdata_dataset.py +5 -15
  18. spikezoo/datasets/uhsr_dataset.py +4 -3
  19. spikezoo/models/__init__.py +8 -6
  20. spikezoo/models/base_model.py +120 -64
  21. spikezoo/models/bsf_model.py +11 -3
  22. spikezoo/models/spcsnet_model.py +19 -0
  23. spikezoo/models/spikeclip_model.py +4 -3
  24. spikezoo/models/spk2imgnet_model.py +9 -15
  25. spikezoo/models/ssir_model.py +4 -6
  26. spikezoo/models/ssml_model.py +44 -2
  27. spikezoo/models/stir_model.py +26 -5
  28. spikezoo/models/tfi_model.py +3 -1
  29. spikezoo/models/tfp_model.py +4 -2
  30. spikezoo/models/wgse_model.py +8 -14
  31. spikezoo/pipeline/base_pipeline.py +79 -55
  32. spikezoo/pipeline/ensemble_pipeline.py +10 -9
  33. spikezoo/pipeline/train_cfgs.py +89 -0
  34. spikezoo/pipeline/train_pipeline.py +129 -30
  35. spikezoo/utils/optimizer_utils.py +22 -0
  36. spikezoo/utils/other_utils.py +31 -6
  37. spikezoo/utils/scheduler_utils.py +25 -0
  38. spikezoo/utils/spike_utils.py +61 -29
  39. spikezoo-0.2.3.2.dist-info/METADATA +263 -0
  40. {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.2.dist-info}/RECORD +43 -80
  41. spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
  42. spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
  43. spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
  44. spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
  45. spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
  46. spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
  47. spikezoo/archs/spikeformer/CheckPoints/readme +0 -1
  48. spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +0 -60
  49. spikezoo/archs/spikeformer/DataProcess/DataLoader.py +0 -115
  50. spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +0 -39
  51. spikezoo/archs/spikeformer/EvalResults/readme +0 -1
  52. spikezoo/archs/spikeformer/LICENSE +0 -21
  53. spikezoo/archs/spikeformer/Metrics/Metrics.py +0 -50
  54. spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
  55. spikezoo/archs/spikeformer/Model/Loss.py +0 -89
  56. spikezoo/archs/spikeformer/Model/SpikeFormer.py +0 -230
  57. spikezoo/archs/spikeformer/Model/__init__.py +0 -0
  58. spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
  59. spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
  60. spikezoo/archs/spikeformer/README.md +0 -30
  61. spikezoo/archs/spikeformer/evaluate.py +0 -87
  62. spikezoo/archs/spikeformer/recon_real_data.py +0 -97
  63. spikezoo/archs/spikeformer/requirements.yml +0 -95
  64. spikezoo/archs/spikeformer/train.py +0 -173
  65. spikezoo/archs/spikeformer/utils.py +0 -22
  66. spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
  67. spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
  68. spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
  69. spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
  70. spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
  71. spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
  72. spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
  73. spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
  74. spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
  75. spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
  76. spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
  77. spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
  78. spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
  79. spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
  80. spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
  81. spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
  82. spikezoo/models/spikeformer_model.py +0 -50
  83. spikezoo-0.2.2.dist-info/METADATA +0 -196
  84. {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.2.dist-info}/LICENSE.txt +0 -0
  85. {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.2.dist-info}/WHEEL +0 -0
  86. {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.2.dist-info}/top_level.txt +0 -0
spikezoo/datasets/szdata_dataset.py
@@ -5,8 +5,9 @@ import cv2
  import torch
  import numpy as np

+ # todo tobe evaluated
  @dataclass
- class SZData_Config(BaseDatasetConfig):
+ class SZDataConfig(BaseDatasetConfig):
  dataset_name: str = "szdata"
  root_dir: Path = Path(__file__).parent.parent / Path("data/dataset")
  width: int = 400
@@ -21,17 +22,6 @@ class SZData(BaseDataset):
  def __init__(self, cfg: BaseDatasetConfig):
  super(SZData, self).__init__(cfg)

- def get_img(self, idx):
- if self.cfg.with_img:
- spike_name = self.spike_list[idx]
- img_name = str(spike_name).replace(self.cfg.spike_dir_name,self.cfg.img_dir_name).replace(".dat",".png")
- img = cv2.imread(img_name)
- img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
- img = (img / 255).astype(np.float32)
- img = img[None]
- img = torch.from_numpy(img)
- else:
- spike = self.get_spike(idx)
- img = torch.mean(spike, dim=0, keepdim=True)
- return img
-
+ def prepare_data(self):
+ super().prepare_data()
+ self.img_list = [self.img_dir / Path(str(s.name).replace('.dat','.png')) for s in self.spike_list]
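The rewritten SZData no longer decodes images itself; prepare_data simply pairs each spike `.dat` file with a `.png` of the same stem and leaves decoding to the shared BaseDataset logic. A minimal, hypothetical sketch of that pairing (directory names are illustrative, not taken from the package):

```python
from pathlib import Path

# Hypothetical standalone version of the pairing done in SZData.prepare_data:
# every spike file "xxxx.dat" maps to a ground-truth image "xxxx.png" in img_dir.
def pair_spikes_with_images(spike_list, img_dir):
    img_dir = Path(img_dir)
    return [img_dir / Path(s).with_suffix(".png").name for s in spike_list]

spikes = ["data/dataset/spike/0000.dat", "data/dataset/spike/0001.dat"]
print(pair_spikes_with_images(spikes, "data/dataset/imgs"))
# [PosixPath('data/dataset/imgs/0000.png'), PosixPath('data/dataset/imgs/0001.png')]
```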
spikezoo/datasets/uhsr_dataset.py
@@ -4,8 +4,9 @@ from dataclasses import dataclass
  import numpy as np
  import torch

+
  @dataclass
- class UHSR_Config(BaseDatasetConfig):
+ class UHSRConfig(BaseDatasetConfig):
  dataset_name: str = "uhsr"
  root_dir: Path = Path(__file__).parent.parent / Path("data/U-CALTECH")
  width: int = 224
@@ -29,10 +30,10 @@ class UHSR(BaseDataset):
  files = path.glob("**/*.npz")
  return sorted(files)

- def load_spike(self,idx):
+ def load_spike(self, idx):
  spike_name = str(self.spike_list[idx])
  data = np.load(spike_name)
  spike = data["spk"].astype(np.float32)
  spike = torch.from_numpy(spike)
  spike = spike[:, 13:237, 13:237]
- return spike
+ return spike
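UHSR stores 250×250 spike arrays in `.npz` files under a `spk` key; the crop `[:, 13:237, 13:237]` above trims them to the 224×224 size declared in the config. A quick standalone check of that arithmetic (the array below is a dummy, not real U-CALTECH data):

```python
import numpy as np
import torch

# Dummy stand-in for one U-CALTECH .npz entry: (T, 250, 250) binary spikes.
spk = np.random.randint(0, 2, size=(100, 250, 250)).astype(np.float32)

spike = torch.from_numpy(spk)
spike = spike[:, 13:237, 13:237]  # 237 - 13 = 224 pixels per side
print(spike.shape)                # torch.Size([100, 224, 224])
```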
spikezoo/models/__init__.py
@@ -1,6 +1,7 @@
  import importlib
  import inspect
  from spikezoo.models.base_model import BaseModel,BaseModelConfig
+ from spikezoo.utils.other_utils import getattr_case_insensitive
  import os
  from pathlib import Path

@@ -17,8 +18,9 @@ def build_model_cfg(cfg: BaseModelConfig):
  module_name = "spikezoo.models." + module_name
  module = importlib.import_module(module_name)
  # model,model_config
- classes = sorted([name for name, obj in inspect.getmembers(module) if inspect.isclass(obj) and obj.__module__ == module.__name__])
- model_cls: BaseModel = getattr(module, classes[0])
+ model_name = cfg.model_name
+ model_name = model_name + 'Model' if model_name == "base" else model_name
+ model_cls: BaseModel = getattr_case_insensitive(module,model_name)
  model = model_cls(cfg)
  return model

@@ -30,8 +32,8 @@ def build_model_name(model_name: str):
  module_name = "spikezoo.models." + module_name
  module = importlib.import_module(module_name)
  # model,model_config
- classes = sorted([name for name, obj in inspect.getmembers(module) if inspect.isclass(obj) and obj.__module__ == module.__name__])
- model_cls: BaseModel = getattr(module, classes[0])
- model_cfg: BaseModelConfig = getattr(module, classes[1])()
+ model_name = model_name + 'Model' if model_name == "base" else model_name
+ model_cls: BaseModel = getattr_case_insensitive(module,model_name)
+ model_cfg: BaseModelConfig = getattr_case_insensitive(module, model_name + 'config')()
  model = model_cls(model_cfg)
- return model
+ return model
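Both builders now resolve the model class (and its config class) from the registered `model_name` rather than from the alphabetical order of classes in the module, which breaks as soon as extra classes are imported. The helper `getattr_case_insensitive` lives in `spikezoo/utils/other_utils.py` and its body is not shown in this diff; a plausible sketch of such a lookup, for illustration only:

```python
# Hypothetical sketch of a case-insensitive attribute lookup; the real helper in
# spikezoo/utils/other_utils.py may be implemented differently.
def getattr_case_insensitive(module, name):
    lowered = name.lower()
    for attr in dir(module):
        if attr.lower() == lowered:
            return getattr(module, attr)
    raise AttributeError(f"{module.__name__} has no attribute matching '{name}'")
```

With this kind of lookup, `build_model_name("tfp")` imports `spikezoo.models.tfp_model` and resolves `TFP` from "tfp" and `TFPConfig` from "tfpconfig", regardless of how many other classes the module pulls in.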
spikezoo/models/base_model.py
@@ -6,36 +6,44 @@ from dataclasses import dataclass, field
  from spikezoo.utils import load_network, download_file
  import os
  import time
- from typing import Dict
+ from typing import Dict, Literal
  from torch.optim import Adam
  from torch.optim.lr_scheduler import CosineAnnealingLR
  import functools
+ import torch.nn as nn
+ from typing import Optional, Union, List
+ from spikezoo.archs.base.nets import BaseNet


  # todo private design
  @dataclass
  class BaseModelConfig:
- # default params for BaseModel
+ # ------------- Not Recommended to Change -------------
  "Registerd model name."
  model_name: str = "base"
  "File name of the specified model."
  model_file_name: str = "nets"
  "Class name of the specified model in spikezoo/archs/base/{model_file_name}.py."
  model_cls_name: str = "BaseNet"
- "Spike input length for the specified model."
- model_win_length: int = 41
+ "Spike input length. (local mode)"
+ model_length: int = 41
+ "Spike input length for different versions."
+ model_length_dict: dict = field(default_factory=lambda: {"v010": 41, "v023": 41})
  "Model require model parameters or not."
- require_params: bool = False
- "Model stored path."
+ require_params: bool = True
+ "Model parameters. (local mode)"
+ model_params: dict = field(default_factory=lambda: {})
+ "Model parameters for different versions."
+ model_params_dict: dict = field(default_factory=lambda: {"v010": {}, "v023": {}})
+ # ------------- Config -------------
+ "Load ckpt path. Used on the local mode."
  ckpt_path: str = ""
- "Load pretrained weights or not."
- load_state: bool = True
- "Base url for storing pretrained models."
- base_url: str = "https://github.com/chenkang455/Spike-Zoo/releases/download/v0.1/"
+ "Load pretrained weights or not. (default false, set to true during the evaluation mode.)"
+ load_state: bool = False
  "Multi-GPU setting."
  multi_gpu: bool = False
- "Model parameters."
- model_params: dict = field(default_factory=lambda: {})
+ "Base url."
+ base_url: str = "https://github.com/chenkang455/Spike-Zoo/releases/download"


  class BaseModel(nn.Module):
@@ -43,33 +51,69 @@ class BaseModel(nn.Module):
  super(BaseModel, self).__init__()
  self.cfg = cfg
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
- self.net = self.build_network().to(self.device)
- self.net = nn.DataParallel(self.net) if cfg.multi_gpu == True else self.net
- self.model_half_win_length: int = cfg.model_win_length // 2
+ self.loss_func_cache = {}

- # ! Might lead to low speed training on the BSF.
  def forward(self, spike):
- """A simple implementation for the spike-to-image conversion, given the spike input and output the reconstructed image."""
+ return self.spk2img(spike)
+
+ # ! Might lead to low speed training on the BSF.
+ def spk2img(self, spike):
+ """A simple implementation for the spike-to-image conversion (**tailored for the evaluation mode**), given the spike input and output the reconstructed image."""
  spike = self.preprocess_spike(spike)
  img = self.net(spike)
  img = self.postprocess_img(img)
  return img

- def build_network(self):
+ def build_network(
+ self,
+ mode: Literal["debug", "train", "eval"] = "debug",
+ version: Literal["local", "v010", "v023"] = "local",
+ ):
  """Build the network and load the pretrained weight."""
  # network
  module = importlib.import_module(f"spikezoo.archs.{self.cfg.model_name}.{self.cfg.model_file_name}")
  model_cls = getattr(module, self.cfg.model_cls_name)
- model = model_cls(**self.cfg.model_params)
+ # load model config parameters
+ if version == "local":
+ model = model_cls(**self.cfg.model_params)
+ self.model_length = self.cfg.model_length
+ self.model_half_length = self.model_length // 2
+ else:
+ model = model_cls(**self.cfg.model_params_dict[version])
+ self.model_length = self.cfg.model_length_dict[version]
+ self.model_half_length = self.model_length // 2
+ model.train() if mode == "train" else model.eval()
+ # auto set the load_state to True under the eval mode
+ if mode == "eval" and self.cfg.load_state == False:
+ print(f"Method {self.cfg.model_name} on the evaluation mode, load_state is set to True automatically.")
+ self.cfg.load_state = True
+ # load model
  if self.cfg.load_state and self.cfg.require_params:
- load_folder = os.path.dirname(os.path.abspath(__file__))
- weight_path = os.path.join(load_folder, self.cfg.ckpt_path)
- if os.path.exists(weight_path) == False:
- os.makedirs(os.path.dirname(weight_path), exist_ok=True)
- self.download_weight(weight_path)
+ # load from the url version
+ if version != "local":
+ load_folder = os.path.dirname(os.path.abspath(__file__))
+ ckpt_name = f"{self.cfg.model_name}.{get_suffix(self.cfg.model_name,version)}"
+ ckpt_path = os.path.join("weights",version,ckpt_name)
+ ckpt_path = os.path.join(load_folder, ckpt_path)
+ ckpt_path_url = os.path.join(self.cfg.base_url,get_url_version(version),ckpt_name)
+ elif version == "local":
+ ckpt_path = self.cfg.ckpt_path
+
+ # no ckpt found on the device, try to download from the url
+ if os.path.isfile(ckpt_path) == False and version != "local":
+ os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
+ download_file(ckpt_path_url, ckpt_path)
  time.sleep(0.5)
- model = load_network(weight_path, model)
- return model
+ elif os.path.isfile(ckpt_path) == False and version == "local":
+ raise RuntimeError(
+ f"For the method {self.cfg.model_name}, no ckpt can be found on the {ckpt_path} !!! Try set the version to get the model from the url."
+ )
+ model = load_network(ckpt_path, model)
+ # to device
+ model = model.to(self.device)
+ model = nn.DataParallel(model) if self.cfg.multi_gpu == True else model
+ self.net = model
+ return self

  def save_network(self, save_path):
  """Save the network."""
@@ -81,27 +125,22 @@ class BaseModel(nn.Module):
  state_dict[key] = param.cpu()
  torch.save(state_dict, save_path)

- def download_weight(self, weight_path):
- """Download the pretrained weight from the given url."""
- url = self.cfg.base_url + os.path.basename(self.cfg.ckpt_path)
- download_file(url, weight_path)
-
  def crop_spike_length(self, spike):
  """Crop the spike length."""
  spike_length = spike.shape[1]
  spike_mid = spike_length // 2
- assert spike_length >= self.cfg.model_win_length, f"Spike input is not long enough, given {spike_length} frames < {self.cfg.model_win_length}."
+ assert spike_length >= self.model_length, f"Spike input is not long enough, given {spike_length} frames < {self.cfg.model_length}."
  # even length
- if self.cfg.model_win_length == self.model_half_win_length * 2:
+ if self.model_length == self.model_half_length * 2:
  spike = spike[
  :,
- spike_mid - self.model_half_win_length : spike_mid + self.model_half_win_length,
+ spike_mid - self.model_half_length : spike_mid + self.model_half_length,
  ]
  # odd length
  else:
  spike = spike[
  :,
- spike_mid - self.model_half_win_length : spike_mid + self.model_half_win_length + 1,
+ spike_mid - self.model_half_length : spike_mid + self.model_half_length + 1,
  ]
  self.spike_size = (spike.shape[2], spike.shape[3])
  return spike
@@ -116,60 +155,77 @@ class BaseModel(nn.Module):
  return image

  # -------------------- Training Part --------------------
- def setup_training(self, pipeline_cfg):
- """Setup training optimizer and loss."""
- from spikezoo.pipeline import TrainPipelineConfig
-
- cfg: TrainPipelineConfig = pipeline_cfg
- self.optimizer = Adam(self.net.parameters(), lr=cfg.lr, betas=(0.9, 0.99), weight_decay=0)
- self.scheduler = CosineAnnealingLR(self.optimizer, T_max=cfg.epochs, eta_min=0)
- self.criterion = nn.L1Loss()
-
  def get_outputs_dict(self, batch):
  """Get the output dict for the given input batch. (Designed for the training mode considering possible auxiliary output.)"""
  # data process
  spike = batch["spike"]
+ rate = batch["rate"].view(-1, 1, 1, 1).float()
  # outputs
  outputs = {}
- recon_img = self(spike)
- outputs["recon_img"] = recon_img
+ recon_img = self.spk2img(spike)
+ outputs["recon_img"] = recon_img / rate
  return outputs

  def get_visual_dict(self, batch, outputs):
  """Get the visual dict from the given input batch and outputs."""
  visual_dict = {}
- visual_dict["recon"] = outputs["recon_img"]
- visual_dict["img"] = batch["img"]
+ visual_dict["recon_img"] = outputs["recon_img"]
+ visual_dict["gt_img"] = batch["gt_img"]
  return visual_dict

- def get_loss_dict(self, outputs, batch):
+ def get_loss_dict(self, outputs, batch, loss_weight_dict):
  """Get the loss dict from the given input batch and outputs."""
  # data process
- gt_img = batch["img"]
+ gt_img = batch["gt_img"]
  # recon image
  recon_img = outputs["recon_img"]
  # loss dict
  loss_dict = {}
- loss_dict["l1"] = self.criterion(recon_img, gt_img)
+ for loss_name, weight in loss_weight_dict.items():
+ loss_dict[loss_name] = weight * self.get_loss_func(loss_name)(recon_img, gt_img)
+
+ # todo add your desired loss here by loss_dict["name"] = loss()
+
  loss_values_dict = {k: v.item() for k, v in loss_dict.items()}
- return loss_dict,loss_values_dict
+ return loss_dict, loss_values_dict
+
+ def get_loss_func(self, name: Literal["l1", "l2"]):
+ """Get the loss function from the given loss name."""
+ if name not in self.loss_func_cache:
+ if name == "l1":
+ self.loss_func_cache[name] = nn.L1Loss()
+ elif name == "l2":
+ self.loss_func_cache[name] = nn.MSELoss()
+ else:
+ self.loss_func_cache[name] = lambda x, y: 0
+ loss_func = self.loss_func_cache[name]
+ return loss_func

  def get_paired_imgs(self, batch, outputs):
+ """Get paired images for the metric calculation."""
  recon_img = outputs["recon_img"]
- img = batch["img"]
+ img = batch["gt_img"]
  return recon_img, img

- def optimize_parameters(self, loss_dict):
- """Optimize the parameters from the loss_dict."""
- loss = functools.reduce(torch.add, loss_dict.values())
- self.optimizer.zero_grad()
- loss.backward()
- self.optimizer.step()
-
- def update_learning_rate(self):
- """Update the learning rate."""
- self.scheduler.step()
-
  def feed_to_device(self, batch):
+ """Feed the batch data to the given device."""
  batch = {k: v.to(self.device, non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()}
  return batch
+
+
+ # functions
+ def get_suffix(model_name, version):
+ if version == "v010":
+ if model_name in ["ssml", "wgse"]:
+ return "pt"
+ else:
+ return "pth"
+ else:
+ return "pth"
+
+
+ def get_url_version(version):
+ major = version[1]
+ minor = version[2]
+ patch = version[3]
+ return f"v{major}.{minor}.{patch}"
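In the new flow, `build_network` is called explicitly after construction: it picks per-version hyperparameters (`model_length_dict` / `model_params_dict`), forces `load_state=True` in eval mode, and resolves the checkpoint either from `cfg.ckpt_path` (version `"local"`) or from a GitHub release URL assembled with `get_suffix` and `get_url_version`. A hedged usage sketch with class names taken from this diff (the surrounding pipeline wiring is an assumption):

```python
# Sketch only: how the reworked build_network is intended to be driven.
from spikezoo.models.spikeclip_model import SpikeCLIP, SpikeCLIPConfig

model = SpikeCLIP(SpikeCLIPConfig())
# mode="eval" switches load_state to True; version="v023" selects
# model_length_dict["v023"] / model_params_dict["v023"].
model.build_network(mode="eval", version="v023")

# Checkpoint resolution for version="v023":
#   get_suffix("spikeclip", "v023") -> "pth"   (only ssml/wgse on v010 use "pt")
#   get_url_version("v023")         -> "v0.2.3"
#   URL  = base_url + "/v0.2.3/spikeclip.pth"
#   file is cached under spikezoo/models/weights/v023/spikeclip.pth
```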
spikezoo/models/bsf_model.py
@@ -1,6 +1,12 @@
  import torch
- from dataclasses import dataclass
+ from dataclasses import dataclass, field
  from spikezoo.models.base_model import BaseModel, BaseModelConfig
+ from torch.optim import Adam
+ import torch.optim.lr_scheduler as lr_scheduler
+ import torch.nn as nn
+ from spikezoo.pipeline import TrainPipelineConfig
+ from typing import List
+ from spikezoo.archs.bsf.models.bsf.bsf import BSF


  @dataclass
@@ -9,9 +15,11 @@ class BSFConfig(BaseModelConfig):
  model_name: str = "bsf"
  model_file_name: str = "models.bsf.bsf"
  model_cls_name: str = "BSF"
- model_win_length: int = 61
+ model_length: int = 61
+ model_length_dict: dict = field(default_factory=lambda: {"v010": 61, "v023": 41})
  require_params: bool = True
- ckpt_path: str = "weights/bsf.pth"
+ model_params: dict = field(default_factory=lambda: {})
+ model_params_dict: dict = field(default_factory=lambda: {"v010": {"spike_dim": 61}, "v023": {"spike_dim": 41}})


  class BSF(BaseModel):
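BSFConfig now keys the architecture arguments on the weight version: v010 checkpoints expect a 61-frame window (`spike_dim=61`), while v023 checkpoints use 41 frames. The snippet below only traces how `build_network` turns the version key into constructor arguments and a centre crop (dummy tensor, illustration only):

```python
import torch

# Values copied from BSFConfig above; the spike tensor is a dummy.
model_length_dict = {"v010": 61, "v023": 41}
model_params_dict = {"v010": {"spike_dim": 61}, "v023": {"spike_dim": 41}}

version = "v023"
params = model_params_dict[version]              # {"spike_dim": 41} -> BSF(**params)
length = model_length_dict[version]              # 41
half = length // 2

spike = torch.zeros(1, 100, 250, 400)            # (B, T, H, W) spike stream
mid = spike.shape[1] // 2
cropped = spike[:, mid - half : mid + half + 1]  # odd-length branch of crop_spike_length
assert cropped.shape[1] == length
```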
spikezoo/models/spcsnet_model.py
@@ -0,0 +1,19 @@
+ from dataclasses import dataclass, field
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
+ from typing import List
+
+
+ @dataclass
+ class SPCSNetConfig(BaseModelConfig):
+ # default params for WGSE
+ model_name: str = "spcsnet"
+ model_file_name: str = "models"
+ model_cls_name: str = "SPCS_Net"
+ model_win_length: int = 41
+ require_params: bool = True
+ ckpt_path: str = 'weights/spcsnet.pth'
+
+
+ class SPCSNet(BaseModel):
+ def __init__(self, cfg: BaseModelConfig):
+ super(SPCSNet, self).__init__(cfg)
spikezoo/models/spikeclip_model.py
@@ -2,7 +2,8 @@ from dataclasses import dataclass
  from spikezoo.models.base_model import BaseModel, BaseModelConfig
  import torch
  import torch.nn.functional as F
-
+ from dataclasses import field
+ from spikezoo.archs.spikeclip.nets import LRN

  @dataclass
  class SpikeCLIPConfig(BaseModelConfig):
@@ -10,9 +11,9 @@ class SpikeCLIPConfig(BaseModelConfig):
  model_name: str = "spikeclip"
  model_file_name: str = "nets"
  model_cls_name: str = "LRN"
- model_win_length: int = 200
+ model_length: int = 200
+ model_length_dict: dict = field(default_factory=lambda: {"v010": 200, "v023": 200})
  require_params: bool = True
- ckpt_path: str = "weights/spikeclip.pth"


  class SpikeCLIP(BaseModel):
spikezoo/models/spk2imgnet_model.py
@@ -1,7 +1,12 @@
  import torch
  from dataclasses import dataclass, field
  from spikezoo.models.base_model import BaseModel, BaseModelConfig
-
+ from spikezoo.pipeline import TrainPipelineConfig
+ import torch.nn as nn
+ import torch.optim as optim
+ import torch.optim.lr_scheduler as lr_scheduler
+ from typing import List
+ from spikezoo.archs.spk2imgnet.nets import SpikeNet

  @dataclass
  class Spk2ImgNetConfig(BaseModelConfig):
@@ -9,22 +14,11 @@ class Spk2ImgNetConfig(BaseModelConfig):
  model_name: str = "spk2imgnet"
  model_file_name: str = "nets"
  model_cls_name: str = "SpikeNet"
- model_win_length: int = 41
+ model_length: int = 41
+ model_length_dict: dict = field(default_factory=lambda: {"v010": 41, "v023": 41})
  require_params: bool = True
- ckpt_path: str = "weights/spk2imgnet.pth"
  light_correction: bool = False

- # model params
- model_params: dict = field(
- default_factory=lambda: {
- "in_channels": 13,
- "features": 64,
- "out_channels": 1,
- "win_r": 6,
- "win_step": 7,
- }
- )
-

  class Spk2ImgNet(BaseModel):
  def __init__(self, cfg: BaseModelConfig):
@@ -45,7 +39,7 @@ class Spk2ImgNet(BaseModel):
  image = image[:, :, :250, :]
  elif self.spike_size == (480, 854):
  image = image[:, :, :, :854]
- # used on the REDS_small dataset.
+ # used on the REDS_BASE dataset.
  if self.cfg.light_correction == True:
  image = torch.clamp(image / 0.6, 0, 1)
  return image
spikezoo/models/ssir_model.py
@@ -1,5 +1,7 @@
  from dataclasses import dataclass
  from spikezoo.models.base_model import BaseModel, BaseModelConfig
+ from dataclasses import field
+ from spikezoo.archs.ssir.models.networks import SSIR


  @dataclass
@@ -8,15 +10,11 @@ class SSIRConfig(BaseModelConfig):
  model_name: str = "ssir"
  model_file_name: str = "models.networks"
  model_cls_name: str = "SSIR"
- model_win_length: int = 41
+ model_length: int = 41
+ model_length_dict: dict = field(default_factory=lambda: {"v010": 41, "v023": 41})
  require_params: bool = True
- ckpt_path: str = "weights/ssir.pth"


  class SSIR(BaseModel):
  def __init__(self, cfg: BaseModelConfig):
  super(SSIR, self).__init__(cfg)
-
- def postprocess_img(self, image):
- # image = image[0]
- return image
spikezoo/models/ssml_model.py
@@ -1,5 +1,8 @@
  from dataclasses import dataclass
  from spikezoo.models.base_model import BaseModel, BaseModelConfig
+ import torch
+ from dataclasses import field
+ from spikezoo.archs.ssml.model import DoubleNet


  @dataclass
@@ -8,11 +11,50 @@ class SSMLConfig(BaseModelConfig):
  model_name: str = "ssml"
  model_file_name: str = "model"
  model_cls_name: str = "DoubleNet"
- model_win_length: int = 41
+ model_length: int = 41
+ model_length_dict: dict = field(default_factory=lambda: {"v010": 41, "v023": 41})
+ tfp_label_length: int = 11
  require_params: bool = True
- ckpt_path: str = 'weights/ssml.pt'


+ # ! A simple version of SSML rather than the full version
  class SSML(BaseModel):
  def __init__(self, cfg: BaseModelConfig):
  super(SSML, self).__init__(cfg)
+
+ def get_outputs_dict(self, batch):
+ # data process
+ spike = batch["spike"]
+ spike = self.preprocess_spike(spike)
+ rate = batch["rate"].view(-1, 1, 1, 1).float()
+ # outputs
+ outputs = {}
+ bsn_pred, nbsn_pred = self.net(spike)
+ bsn_pred = self.postprocess_img(bsn_pred)
+ nbsn_pred = self.postprocess_img(nbsn_pred)
+ outputs["recon_img"] = nbsn_pred / rate
+ outputs["bsn_pred"] = bsn_pred / rate
+ # tfp-label
+ mid = spike.shape[1] // 2
+ tfp_label = torch.mean(spike[:, mid - self.cfg.tfp_label_length // 2 : mid + self.cfg.tfp_label_length // 2 + 1], dim=1, keepdim=True)
+ outputs["tfp_label"] = self.postprocess_img(tfp_label) / rate
+ return outputs
+
+ def get_visual_dict(self, batch, outputs):
+ visual_dict = super().get_visual_dict(batch, outputs)
+ visual_dict["bsn_pred"] = outputs["bsn_pred"]
+ visual_dict["tfp_label"] = outputs["tfp_label"]
+ return visual_dict
+
+ def get_loss_dict(self, outputs, batch, loss_weight_dict):
+ # recon image
+ recon_img = outputs["recon_img"]
+ bsn_pred = outputs["bsn_pred"]
+ tfp_label = outputs["tfp_label"]
+ # loss dict
+ loss_dict = {}
+ for loss_name, weight in loss_weight_dict.items():
+ loss_dict["bsn_loss_" + loss_name] = weight * self.get_loss_func(loss_name)(bsn_pred, tfp_label)
+ loss_dict["mutual_loss_" + loss_name] = 0.01 * weight * self.get_loss_func(loss_name)(recon_img, bsn_pred)
+ loss_values_dict = {k: v.item() for k, v in loss_dict.items()}
+ return loss_dict, loss_values_dict
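The SSML hooks above make training self-supervised: the blind-spot branch (`bsn_pred`) is supervised by a TFP pseudo-label, i.e. the mean firing rate over a short window around the centre frame, and a small mutual-consistency term (weight 0.01) ties the non-blind-spot output to the blind-spot one. A minimal sketch of the pseudo-label construction with dummy shapes (`tfp_label_length = 11` as in the config):

```python
import torch

# Dummy (B, T, H, W) spike tensor; real inputs come from the dataloader.
spike = torch.randint(0, 2, (1, 41, 64, 64)).float()
tfp_label_length = 11

mid = spike.shape[1] // 2
half = tfp_label_length // 2
# TFP pseudo-label: average firing rate over an 11-frame window centred on the key frame.
tfp_label = torch.mean(spike[:, mid - half : mid + half + 1], dim=1, keepdim=True)
print(tfp_label.shape)  # torch.Size([1, 1, 64, 64])
```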
spikezoo/models/stir_model.py
@@ -1,7 +1,14 @@
  import torch
- from dataclasses import dataclass
+ from dataclasses import dataclass, field
  from spikezoo.models.base_model import BaseModel, BaseModelConfig
-
+ from torch.optim import Adam
+ import torch.optim.lr_scheduler as lr_scheduler
+ import torch.nn as nn
+ from spikezoo.pipeline import TrainPipelineConfig
+ from typing import List
+ from spikezoo.archs.stir.metrics.losses import compute_per_loss_single
+ from spikezoo.archs.stir.models.Vgg19 import Vgg19
+ from spikezoo.archs.stir.models.networks_STIR import STIR

  @dataclass
  class STIRConfig(BaseModelConfig):
@@ -9,9 +16,11 @@ class STIRConfig(BaseModelConfig):
  model_name: str = "stir"
  model_file_name: str = "models.networks_STIR"
  model_cls_name: str = "STIR"
- model_win_length: int = 61
+ model_length: int = 61
+ model_length_dict: dict = field(default_factory=lambda: {"v010": 61, "v023": 41})
  require_params: bool = True
- ckpt_path: str = "weights/stir.pth"
+ model_params: dict = field(default_factory=lambda: {})
+ model_params_dict: dict = field(default_factory=lambda: {"v010": {"spike_dim": 61}, "v023": {"spike_dim": 41}})


  class STIR(BaseModel):
@@ -29,9 +38,21 @@ class STIR(BaseModel):
  return spike

  def postprocess_img(self, image):
- # recon, Fs_lv_0, Fs_lv_1, Fs_lv_2, Fs_lv_3, Fs_lv_4, Est = image
  if self.spike_size == (250, 400):
  image = image[:, :, :250, :]
  elif self.spike_size == (480, 854):
  image = image[:, :, :, :854]
  return image
+
+ def get_outputs_dict(self, batch):
+ # data process
+ spike = batch["spike"]
+ rate = batch["rate"].view(-1, 1, 1, 1).float()
+ # outputs
+ outputs = {}
+ spike = self.preprocess_spike(spike)
+ # pyramid loss is omitted owing to limited performance gain.
+ img_pred_0, Fs_lv_0, Fs_lv_1, Fs_lv_2, Fs_lv_3, Fs_lv_4, Est = self.net(spike)
+ img_pred_0 = self.postprocess_img(img_pred_0)
+ outputs["recon_img"] = img_pred_0 / rate
+ return outputs
spikezoo/models/tfi_model.py
@@ -1,5 +1,6 @@
  from dataclasses import dataclass, field
  from spikezoo.models.base_model import BaseModel, BaseModelConfig
+ from spikezoo.archs.tfi.nets import TFIModel


  @dataclass
@@ -8,7 +9,8 @@ class TFIConfig(BaseModelConfig):
  model_name: str = "tfi"
  model_file_name: str = "nets"
  model_cls_name: str = "TFIModel"
- model_win_length: int = 41
+ model_length: int = 41
+ model_length_dict: dict = field(default_factory=lambda: {"v010": 41, "v023": 41})
  require_params: bool = False
  model_params: dict = field(default_factory=lambda: {"model_win_length": 41})

spikezoo/models/tfp_model.py
@@ -1,5 +1,6 @@
- from dataclasses import dataclass,field
+ from dataclasses import dataclass, field
  from spikezoo.models.base_model import BaseModel, BaseModelConfig
+ from spikezoo.archs.tfp.nets import TFPModel


  @dataclass
@@ -8,7 +9,8 @@ class TFPConfig(BaseModelConfig):
  model_name: str = "tfp"
  model_file_name: str = "nets"
  model_cls_name: str = "TFPModel"
- model_win_length: int = 41
+ model_length: int = 41
+ model_length_dict: dict = field(default_factory=lambda: {"v010": 41, "v023": 41})
  require_params: bool = False
  model_params: dict = field(default_factory=lambda: {"model_win_length": 41})