spikezoo 0.2.3.5__py3-none-any.whl → 0.2.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
  2. spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
  3. spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
  4. spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
  5. spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
  6. spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
  7. spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
  8. spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
  9. spikezoo/archs/stir/metrics/__pycache__/losses.cpython-39.pyc +0 -0
  10. spikezoo/archs/stir/models/__pycache__/Vgg19.cpython-39.pyc +0 -0
  11. spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
  12. spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
  13. spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
  14. spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
  15. spikezoo/archs/stir/package_core/package_core/__pycache__/geometry.cpython-39.pyc +0 -0
  16. spikezoo/archs/stir/package_core/package_core/__pycache__/image_proc.cpython-39.pyc +0 -0
  17. spikezoo/archs/stir/package_core/package_core/__pycache__/losses.cpython-39.pyc +0 -0
  18. spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
  19. spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
  20. spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
  21. spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
  22. spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
  23. spikezoo/archs/yourmodel/arch/__pycache__/net.cpython-39.pyc +0 -0
  24. spikezoo/archs/yourmodel/arch/net.py +35 -0
  25. spikezoo/datasets/__init__.py +20 -21
  26. spikezoo/datasets/base_dataset.py +25 -19
  27. spikezoo/datasets/{realworld_dataset.py → realdata_dataset.py} +5 -7
  28. spikezoo/datasets/reds_base_dataset.py +1 -1
  29. spikezoo/datasets/szdata_dataset.py +1 -1
  30. spikezoo/datasets/uhsr_dataset.py +1 -1
  31. spikezoo/datasets/yourdataset_dataset.py +23 -0
  32. spikezoo/models/__init__.py +11 -18
  33. spikezoo/models/base_model.py +10 -4
  34. spikezoo/models/yourmodel_model.py +22 -0
  35. spikezoo/pipeline/base_pipeline.py +17 -10
  36. spikezoo/pipeline/ensemble_pipeline.py +2 -1
  37. spikezoo/pipeline/train_cfgs.py +32 -29
  38. spikezoo/pipeline/train_pipeline.py +14 -14
  39. spikezoo/utils/spike_utils.py +1 -1
  40. spikezoo-0.2.3.7.dist-info/METADATA +151 -0
  41. {spikezoo-0.2.3.5.dist-info → spikezoo-0.2.3.7.dist-info}/RECORD +44 -41
  42. spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
  43. spikezoo-0.2.3.5.dist-info/METADATA +0 -258
  44. {spikezoo-0.2.3.5.dist-info → spikezoo-0.2.3.7.dist-info}/LICENSE.txt +0 -0
  45. {spikezoo-0.2.3.5.dist-info → spikezoo-0.2.3.7.dist-info}/WHEEL +0 -0
  46. {spikezoo-0.2.3.5.dist-info → spikezoo-0.2.3.7.dist-info}/top_level.txt +0 -0
spikezoo/archs/yourmodel/arch/net.py
@@ -0,0 +1,35 @@
+ import torch.nn as nn
+
+ def conv_layer(inDim, outDim, ks, s, p, norm_layer="none"):
+     ## convolutional layer
+     conv = nn.Conv2d(inDim, outDim, kernel_size=ks, stride=s, padding=p)
+     relu = nn.ReLU(True)
+     assert norm_layer in ("batch", "instance", "none")
+     if norm_layer == "none":
+         seq = nn.Sequential(*[conv, relu])
+     else:
+         if norm_layer == "instance":
+             norm = nn.InstanceNorm2d(outDim, affine=False, track_running_stats=False)  # instance norm
+         else:
+             momentum = 0.1
+             norm = nn.BatchNorm2d(outDim, momentum=momentum, affine=True, track_running_stats=True)
+         seq = nn.Sequential(*[conv, norm, relu])
+     return seq
+
+
+ class YourNet(nn.Module):
+     """Borrow the structure from the SpikeCLIP. (https://arxiv.org/abs/2501.04477)"""
+
+     def __init__(self, inDim=41):
+         super(YourNet, self).__init__()
+         norm = "none"
+         outDim = 1
+         convBlock1 = conv_layer(inDim, 64, 3, 1, 1)
+         convBlock2 = conv_layer(64, 128, 3, 1, 1, norm)
+         convBlock3 = conv_layer(128, 64, 3, 1, 1, norm)
+         convBlock4 = conv_layer(64, 16, 3, 1, 1, norm)
+         conv = nn.Conv2d(16, outDim, 3, 1, 1)
+         self.seq = nn.Sequential(*[convBlock1, convBlock2, convBlock3, convBlock4, conv])
+
+     def forward(self, x):
+         return self.seq(x)
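A minimal usage sketch for the new example arch, assuming the default 41-frame input (the tensor shape below is illustrative, not taken from the package):

    import torch
    from spikezoo.archs.yourmodel.arch.net import YourNet

    net = YourNet(inDim=41)
    spike = torch.randn(1, 41, 250, 400)   # (batch, spike frames, height, width) -- assumed layout
    img = net(spike)                        # single-channel reconstruction, shape (1, 1, 250, 400)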
spikezoo/datasets/__init__.py
@@ -1,4 +1,5 @@
  from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
+
  from dataclasses import replace
  import importlib, inspect
  import os
@@ -12,23 +13,24 @@ dataset_list = [file.replace("_dataset.py", "") for file in files_list if file.e


  # todo register function
- def build_dataset_cfg(cfg: BaseDatasetConfig, split: Literal["train", "test"] = "test"):
+ def build_dataset_cfg(cfg: BaseDatasetConfig):
      """Build the dataset from the given dataset config."""
-     # build new cfg according to split
-     cfg = replace(cfg, split=split)
      # dataset module
-     module_name = cfg.dataset_name + "_dataset"
-     assert cfg.dataset_name in dataset_list, f"Given dataset {cfg.dataset_name} not in our dataset list {dataset_list}."
-     module_name = "spikezoo.datasets." + module_name
-     module = importlib.import_module(module_name)
-     # dataset,dataset_config
-     dataset_name = cfg.dataset_name
-     dataset_name = dataset_name + "Dataset" if dataset_name == "base" else dataset_name
-     dataset_cls: BaseDataset = getattr_case_insensitive(module, dataset_name)
+     if cfg.dataset_cls_local == None:
+         module_name = cfg.dataset_name + "_dataset"
+         assert cfg.dataset_name in dataset_list, f"Given dataset {cfg.dataset_name} not in our dataset list {dataset_list}."
+         module_name = "spikezoo.datasets." + module_name
+         module = importlib.import_module(module_name)
+         # dataset,dataset_config
+         dataset_name = cfg.dataset_name
+         dataset_name = dataset_name + "Dataset" if dataset_name == "base" else dataset_name
+         dataset_cls: BaseDataset = getattr_case_insensitive(module, dataset_name)
+     else:
+         dataset_cls = cfg.dataset_cls_local
      dataset = dataset_cls(cfg)
      return dataset

- def build_dataset_name(dataset_name: str, split: Literal["train", "test"] = "test"):
+ def build_dataset_name(dataset_name: str):
      """Build the default dataset from the given name."""
      module_name = dataset_name + "_dataset"
      assert dataset_name in dataset_list, f"Given dataset {dataset_name} not in our dataset list {dataset_list}."
@@ -37,22 +39,19 @@ def build_dataset_name(dataset_name: str, split: Literal["train", "test"] = "tes
      # dataset,dataset_config
      dataset_name = dataset_name + "Dataset" if dataset_name == "base" else dataset_name
      dataset_cls: BaseDataset = getattr_case_insensitive(module, dataset_name)
-     dataset_cfg: BaseDatasetConfig = getattr_case_insensitive(module, dataset_name + "config")(split=split)
+     dataset_cfg: BaseDatasetConfig = getattr_case_insensitive(module, dataset_name + "config")()
      dataset = dataset_cls(dataset_cfg)
      return dataset


  # todo to modify according to the basicsr
- def build_dataloader(dataset: BaseDataset, cfg=None):
+ def build_dataloader(dataset, cfg):
      # train dataloader
-     if dataset.cfg.split == "train":
-         if cfg is None:
-             return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
-         else:
-             return torch.utils.data.DataLoader(dataset, batch_size=cfg.bs_train, shuffle=True, num_workers=cfg.num_workers, pin_memory=cfg.pin_memory)
+     if dataset.split == "train" and cfg._mode == "train_mode":
+         return torch.utils.data.DataLoader(dataset, batch_size=cfg.bs_train, shuffle=True, num_workers=cfg.nw_train, pin_memory=cfg.pin_memory)
      # test dataloader
-     elif dataset.cfg.split == "test":
-         return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
+     else:
+         return torch.utils.data.DataLoader(dataset, batch_size=cfg.bs_test, shuffle=False, num_workers=cfg.nw_test, pin_memory=False)


  # dataset_size_dict = {}
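Taken together, the reworked builders appear to be used like this, with the split chosen when the source is built and local classes bypassing the name lookup (a sketch based on the YourDataset classes defined later in this diff; build_source asserts that root_dir actually exists):

    from spikezoo.datasets import build_dataset_cfg
    from spikezoo.datasets.yourdataset_dataset import YourDataset, YourDatasetConfig

    cfg = YourDatasetConfig(dataset_cls_local=YourDataset)  # skip the dataset_list lookup entirely
    dataset = build_dataset_cfg(cfg)
    dataset.build_source(split="test")   # split now lives on the dataset, not on the config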
spikezoo/datasets/base_dataset.py
@@ -11,6 +11,7 @@ import warnings
  import torch
  from tqdm import tqdm
  from spikezoo.utils.data_utils import Augmentor
+ from typing import Optional


  @dataclass
@@ -36,24 +37,18 @@ class BaseDatasetConfig:
      img_dir_name: str = "gt"
      "Rate. (-1 denotes variant)"
      rate: float = 0.6
-
+
      # ------------- Config -------------
-     "Dataset split: train/test. Default set as the 'test' for evaluation."
-     split: Literal["train", "test"] = "test"
      "Use the data augumentation technique or not."
      use_aug: bool = False
      "Use cache mechanism."
      use_cache: bool = False
      "Crop size."
      crop_size: tuple = (-1, -1)
-
-
-     # post process
-     def __post_init__(self):
-         self.spike_length = self.spike_length_train if self.split == "train" else self.spike_length_test
-         self.root_dir = Path(self.root_dir) if isinstance(self.root_dir, str) else self.root_dir
-         # todo try download
-         assert self.root_dir.exists(), f"No files found in {self.root_dir} for the specified dataset `{self.dataset_name}`."
+     "Load the dataset from local or spikezoo lib."
+     dataset_cls_local: Optional[Dataset] = None
+     "Spike load version. [python,cpp]"
+     spike_load_version: Literal["python", "cpp"] = "python"


  # todo cache mechanism
@@ -61,10 +56,6 @@ class BaseDataset(Dataset):
      def __init__(self, cfg: BaseDatasetConfig):
          super(BaseDataset, self).__init__()
          self.cfg = cfg
-         self.augmentor = Augmentor(self.cfg.crop_size) if self.cfg.use_aug == True and self.cfg.split == "train" else -1
-         self.prepare_data()
-         self.cache_data() if cfg.use_cache == True else -1
-         warnings.warn("Lengths of the image list and the spike list should be equal.") if len(self.img_list) != len(self.spike_list) else -1

      def __len__(self):
          return len(self.spike_list)
@@ -79,7 +70,7 @@
          img = self.get_img(idx)

          # process data
-         if self.cfg.use_aug == True and self.cfg.split == "train":
+         if self.cfg.use_aug == True and self.split == "train":
              spike, img = self.augmentor(spike, img)

          # rate
@@ -89,15 +80,29 @@
          batch = {"spike": spike, "gt_img": img, "rate": rate}
          return batch

+     def build_source(self, split: Literal["train", "test"] = "test"):
+         """Build the dataset source and prepare to be loaded files."""
+         # spike length
+         self.split = split
+         self.spike_length = self.cfg.spike_length_train if self.split == "train" else self.cfg.spike_length_test
+         # root dir
+         self.cfg.root_dir = Path(self.cfg.root_dir) if isinstance(self.cfg.root_dir, str) else self.cfg.root_dir
+         assert self.cfg.root_dir.exists(), f"No files found in {self.cfg.root_dir} for the specified dataset `{self.cfg.dataset_name}`."
+         # prepare
+         self.augmentor = Augmentor(self.cfg.crop_size) if self.cfg.use_aug == True and self.split == "train" else -1
+         self.prepare_data()
+         self.cache_data() if self.cfg.use_cache == True else -1
+         warnings.warn("Lengths of the image list and the spike list should be equal.") if len(self.img_list) != len(self.spike_list) else -1
+
      # todo: To be overridden
      def prepare_data(self):
          """Specify the spike and image files to be loaded."""
          # spike
-         self.spike_dir = self.cfg.root_dir / self.cfg.split / self.cfg.spike_dir_name
+         self.spike_dir = self.cfg.root_dir / self.split / self.cfg.spike_dir_name
          self.spike_list = self.get_spike_files(self.spike_dir)
          # gt
          if self.cfg.with_img == True:
-             self.img_dir = self.cfg.root_dir / self.cfg.split / self.cfg.img_dir_name
+             self.img_dir = self.cfg.root_dir / self.split / self.cfg.img_dir_name
              self.img_list = self.get_image_files(self.img_dir)

      # todo: To be overridden
@@ -115,12 +120,13 @@ class BaseDataset(Dataset):
              height=self.cfg.height,
              width=self.cfg.width,
              out_format="tensor",
+             version=self.cfg.spike_load_version
          )
          return spike

      def get_spike(self, idx):
          """Get and process the spike stream from the given idx."""
-         spike_length = self.cfg.spike_length
+         spike_length = self.spike_length
          spike = self.load_spike(idx)
          assert spike.shape[0] >= spike_length, f"Given spike length {spike.shape[0]} smaller than the required length {spike_length}"
          spike_mid = spike.shape[0] // 2
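With split and spike length now resolved in build_source() rather than __post_init__, the dataset life cycle roughly becomes the following (a sketch, assuming the bundled base data directory is present and all config fields keep their defaults):

    from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig

    cfg = BaseDatasetConfig(spike_load_version="python")  # "cpp" would select the compiled loader, if built
    dataset = BaseDataset(cfg)             # __init__ no longer touches the filesystem
    dataset.build_source(split="train")    # resolves root_dir, spike_length and the file lists
    batch = dataset[0]                     # {"spike": ..., "gt_img": ..., "rate": ...}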
spikezoo/datasets/{realworld_dataset.py → realdata_dataset.py}
@@ -4,21 +4,19 @@ from dataclasses import dataclass


  @dataclass
- class RealWorldConfig(BaseDatasetConfig):
-     dataset_name: str = "realworld"
-     root_dir: Path = Path(__file__).parent.parent / Path("data/recVidarReal2019")
+ class RealDataConfig(BaseDatasetConfig):
+     dataset_name: str = "realdata"
+     root_dir: Path = Path(__file__).parent.parent / Path("data/realdata")
      width: int = 400
      height: int = 250
      with_img: bool = False
      spike_length_train: int = -1
      spike_length_test: int = -1
      rate: float = 1
-

-
- class RealWorld(BaseDataset):
+ class RealData(BaseDataset):
      def __init__(self, cfg: BaseDatasetConfig):
-         super(RealWorld, self).__init__(cfg)
+         super(RealData, self).__init__(cfg)

      def prepare_data(self):
          self.spike_dir = self.cfg.root_dir
spikezoo/datasets/reds_base_dataset.py
@@ -8,7 +8,7 @@ import re
  @dataclass
  class REDS_BASEConfig(BaseDatasetConfig):
      dataset_name: str = "reds_base"
-     root_dir: Path = Path(__file__).parent.parent / Path("data/REDS_BASE")
+     root_dir: Path = Path(__file__).parent.parent / Path("data/reds_base")
      width: int = 400
      height: int = 250
      with_img: bool = True
spikezoo/datasets/szdata_dataset.py
@@ -9,7 +9,7 @@ import numpy as np
  @dataclass
  class SZDataConfig(BaseDatasetConfig):
      dataset_name: str = "szdata"
-     root_dir: Path = Path(__file__).parent.parent / Path("data/dataset")
+     root_dir: Path = Path(__file__).parent.parent / Path("data/szdata")
      width: int = 400
      height: int = 250
      with_img: bool = True
spikezoo/datasets/uhsr_dataset.py
@@ -7,7 +7,7 @@ import torch
  @dataclass
  class UHSRConfig(BaseDatasetConfig):
      dataset_name: str = "uhsr"
-     root_dir: Path = Path(__file__).parent.parent / Path("data/U-CALTECH")
+     root_dir: Path = Path(__file__).parent.parent / Path("data/u_caltech")
      width: int = 224
      height: int = 224
      with_img: bool = False
spikezoo/datasets/yourdataset_dataset.py
@@ -0,0 +1,23 @@
+ from torch.utils.data import Dataset
+ from pathlib import Path
+ from dataclasses import dataclass
+ from typing import Literal, Union
+ from typing import Optional
+ from spikezoo.datasets.base_dataset import BaseDatasetConfig,BaseDataset
+
+ @dataclass
+ class YourDatasetConfig(BaseDatasetConfig):
+     dataset_name: str = "yourdataset"
+     root_dir: Union[str, Path] = Path(__file__).parent.parent / Path("data/your_data_path")
+     width: int = 400
+     height: int = 250
+     with_img: bool = True
+     spike_length_train: int = -1
+     spike_length_test: int = -1
+     spike_dir_name: str = "spike_data"
+     img_dir_name: str = "sharp_data"
+     rate: float = 1
+
+ class YourDataset(BaseDataset):
+     def __init__(self, cfg: BaseDatasetConfig):
+         super(YourDataset, self).__init__(cfg)
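Given those defaults and the prepare_data() logic shown above, the expected on-disk layout under root_dir would look roughly like this (a sketch inferred from the code, not verified against the package):

    data/your_data_path/
        train/
            spike_data/    # .dat spike streams
            sharp_data/    # ground-truth images
        test/
            spike_data/
            sharp_data/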
spikezoo/models/__init__.py
@@ -1,16 +1,6 @@
  import importlib
  import inspect
  from spikezoo.models.base_model import BaseModel,BaseModelConfig
- from spikezoo.models.tfp_model import TFPModel,TFPConfig
- from spikezoo.models.tfi_model import TFIModel,TFIConfig
- from spikezoo.models.spk2imgnet_model import Spk2ImgNet,Spk2ImgNetConfig
- from spikezoo.models.wgse_model import WGSE,WGSEConfig
- from spikezoo.models.ssml_model import SSML,SSMLConfig
- from spikezoo.models.bsf_model import BSF,BSFConfig
- from spikezoo.models.stir_model import STIR,STIRConfig
- from spikezoo.models.ssir_model import SSIR,SSIRConfig
- from spikezoo.models.spikeclip_model import SpikeCLIP,SpikeCLIPConfig
-

  from spikezoo.utils.other_utils import getattr_case_insensitive
  import os
@@ -24,14 +14,17 @@ model_list = [file.split("_")[0] for file in files_list if file.endswith("_model
  def build_model_cfg(cfg: BaseModelConfig):
      """Build the model from the given model config."""
      # model module name
-     module_name = cfg.model_name + "_model"
-     assert cfg.model_name in model_list, f"Given model {cfg.model_name} not in our model zoo {model_list}."
-     module_name = "spikezoo.models." + module_name
-     module = importlib.import_module(module_name)
-     # model,model_config
-     model_name = cfg.model_name
-     model_name = model_name + 'Model' if model_name == "base" else model_name
-     model_cls: BaseModel = getattr_case_insensitive(module,model_name)
+     if cfg.model_cls_local == None:
+         module_name = cfg.model_name + "_model"
+         assert cfg.model_name in model_list, f"Given model {cfg.model_name} not in our model zoo {model_list}."
+         module_name = "spikezoo.models." + module_name
+         module = importlib.import_module(module_name)
+         # model,model_config
+         model_name = cfg.model_name
+         model_name = model_name + 'Model' if model_name == "base" else model_name
+         model_cls: BaseModel = getattr_case_insensitive(module,model_name)
+     else:
+         model_cls: BaseModel = cfg.model_cls_local
      model = model_cls(cfg)
      return model

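Because the per-model re-exports were dropped from spikezoo.models, the individual configs now appear to be imported from their own modules, e.g. (a sketch using the TFP model named in the removed imports):

    from spikezoo.models import build_model_cfg
    from spikezoo.models.tfp_model import TFPConfig   # previously re-exported by spikezoo.models

    model = build_model_cfg(TFPConfig())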
spikezoo/models/base_model.py
@@ -44,7 +44,10 @@ class BaseModelConfig:
      multi_gpu: bool = False
      "Base url."
      base_url: str = "https://github.com/chenkang455/Spike-Zoo/releases/download"
-
+     "Load the model from local class or spikezoo lib. (None)"
+     model_cls_local: Optional[nn.Module] = None
+     "Load the arch from local class or spikezoo lib. (None)"
+     arch_cls_local: Optional[nn.Module] = None

  class BaseModel(nn.Module):
      def __init__(self, cfg: BaseModelConfig):
@@ -71,8 +74,11 @@ class BaseModel(nn.Module):
      ):
          """Build the network and load the pretrained weight."""
          # network
-         module = importlib.import_module(f"spikezoo.archs.{self.cfg.model_name}.{self.cfg.model_file_name}")
-         model_cls = getattr(module, self.cfg.model_cls_name)
+         if self.cfg.arch_cls_local == None:
+             module = importlib.import_module(f"spikezoo.archs.{self.cfg.model_name}.{self.cfg.model_file_name}")
+             model_cls = getattr(module, self.cfg.model_cls_name)
+         else:
+             model_cls = self.cfg.arch_cls_local
          # load model config parameters
          if version == "local":
              model = model_cls(**self.cfg.model_params)
@@ -129,7 +135,7 @@
          """Crop the spike length."""
          spike_length = spike.shape[1]
          spike_mid = spike_length // 2
-         assert spike_length >= self.model_length, f"Spike input is not long enough, given {spike_length} frames < {self.cfg.model_length}."
+         assert spike_length >= self.model_length, f"Spike input is not long enough, given {spike_length} frames < {self.cfg.model_length} required by the {self.cfg.model_name}."
          # even length
          if self.model_length == self.model_half_length * 2:
              spike = spike[
spikezoo/models/yourmodel_model.py
@@ -0,0 +1,22 @@
+ from torch.utils.data import Dataset
+ from pathlib import Path
+ from dataclasses import dataclass
+ from typing import Literal, Union
+ from typing import Optional
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
+ from dataclasses import field
+ import torch.nn as nn
+
+
+ @dataclass
+ class YourModelConfig(BaseModelConfig):
+     model_name: str = "yourmodel"  # must stay consistent with the file name
+     model_file_name: str = "arch.net"  # module path under the archs directory
+     model_cls_name: str = "YourNet"  # model class name
+     model_length: int = 41
+     require_params: bool = True
+     model_params: dict = field(default_factory=lambda: {"inDim": 41})
+
+ class YourModel(BaseModel):
+     def __init__(self, cfg: BaseModelConfig):
+         super(YourModel, self).__init__(cfg)
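Putting the registration pieces together, a local model can apparently bypass the name-based lookup entirely; a minimal sketch using only symbols defined in this diff:

    from spikezoo.models import build_model_cfg
    from spikezoo.models.yourmodel_model import YourModel, YourModelConfig
    from spikezoo.archs.yourmodel.arch.net import YourNet

    # both the wrapper class and the arch class are supplied locally, so neither
    # model_list nor spikezoo.archs needs to already contain them
    cfg = YourModelConfig(model_cls_local=YourModel, arch_cls_local=YourNet)
    model = build_model_cfg(cfg)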
spikezoo/pipeline/base_pipeline.py
@@ -34,11 +34,17 @@ class PipelineConfig:
      "Evaluate metrics or not."
      save_metric: bool = True
      "Metric names for evaluation."
-     metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim","niqe","brisque"])
+     metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim", "niqe", "brisque"])
      "Save recoverd images or not."
      save_img: bool = True
      "Normalizing recoverd images and gt or not."
-     save_img_norm: bool = False
+     img_norm: bool = False
+     "Batch size for the test dataloader."
+     bs_test: int = 1
+     "Num_workers for the test dataloader."
+     nw_test: int = 0
+     "Pin_memory true or false for the dataloader."
+     pin_memory: bool = False
      "Different modes for the pipeline."
      _mode: Literal["single_mode", "multi_mode", "train_mode"] = "single_mode"

@@ -63,7 +69,8 @@ class Pipeline:
          torch.set_grad_enabled(False)
          # dataset
          self.dataset: BaseDataset = build_dataset_name(dataset_cfg) if isinstance(dataset_cfg, str) else build_dataset_cfg(dataset_cfg)
-         self.dataloader = build_dataloader(self.dataset)
+         self.dataset.build_source(split="test")
+         self.dataloader = build_dataloader(self.dataset, self.cfg)
          # device
          self.device = "cuda" if torch.cuda.is_available() else "cpu"

@@ -103,7 +110,7 @@
          """Function I---Save the recoverd image and calculate the metric from the given dataset."""
          # save folder
          self.logger.info("*********************** infer_from_dataset ***********************")
-         save_folder = self.save_folder / Path(f"infer_from_dataset/{self.dataset.cfg.dataset_name}_dataset/{self.dataset.cfg.split}/{idx:06d}")
+         save_folder = self.save_folder / Path(f"infer_from_dataset/{self.dataset.cfg.dataset_name}_dataset/{self.dataset.split}/{idx:06d}")
          os.makedirs(str(save_folder), exist_ok=True)

          # data process
@@ -117,7 +124,7 @@
              img = None
          return self.infer(spike, img, save_folder, rate)

-     def infer_from_file(self, file_path, height=-1, width=-1, img_path=None, rate=1, remove_head=False):
+     def infer_from_file(self, file_path, height=-1, width=-1, rate=1, img_path=None, remove_head=False):
          """Function II---Save the recoverd image and calculate the metric from the given input file."""
          # save folder
          self.logger.info("*********************** infer_from_file ***********************")
@@ -144,7 +151,7 @@
          spike = torch.from_numpy(spike)[None].to(self.device)
          return self.infer(spike, img, save_folder, rate)

-     def infer_from_spk(self, spike, img=None, rate=1):
+     def infer_from_spk(self, spike, rate=1, img=None):
          """Function III---Save the recoverd image and calculate the metric from the given spike stream."""
          # save folder
          self.logger.info("*********************** infer_from_spk ***********************")
@@ -181,7 +188,7 @@
          for idx in range(len(self.dataset)):
              self.infer_from_dataset(idx=idx)
          self.cfg.save_metric = base_setting
-
+
      # TODO: To be overridden
      def cal_params(self):
          """Function VI---Calculate the parameters/flops/latency of the given method."""
@@ -228,8 +235,8 @@
          # With no GT
          if recon_img == None:
              return None
-         # TFP, TFI, spikeclip algorithms are normalized automatically, others are normalized based on the self.cfg.use_norm
-         if model_name in ["tfp", "tfi", "spikeclip"] or self.cfg.save_img_norm == True:
+         # spikeclip is normalized automatically
+         if model_name in ["spikeclip"] or self.cfg.img_norm == True:
              recon_img = (recon_img - recon_img.min()) / (recon_img.max() - recon_img.min())
          else:
              recon_img = recon_img / rate
@@ -253,7 +260,7 @@
              batch = model.feed_to_device(batch)
              outputs = model.get_outputs_dict(batch)
              recon_img, img = model.get_paired_imgs(batch, outputs)
-             recon_img, img = self._post_process_img(recon_img, model_name), self._post_process_img(img, "auto")
+             recon_img, img = self._post_process_img(recon_img, model_name), self._post_process_img(img, "gt")
              for metric_name in metrics_dict.keys():
                  if metric_name in metric_pair_names:
                      metrics_dict[metric_name].update(cal_metric_pair(recon_img, img, metric_name))
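The new PipelineConfig knobs replace the previously hard-coded test dataloader; a sketch of setting them (assuming the remaining PipelineConfig fields keep their defaults):

    from spikezoo.pipeline.base_pipeline import PipelineConfig

    cfg = PipelineConfig(
        metric_names=["psnr", "ssim"],
        img_norm=False,   # renamed from save_img_norm
        bs_test=4,        # batch size for the test dataloader (was fixed at 1)
        nw_test=2,        # num_workers for the test dataloader
        pin_memory=False,
    )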
spikezoo/pipeline/ensemble_pipeline.py
@@ -48,7 +48,8 @@ class EnsemblePipeline(Pipeline):
          torch.set_grad_enabled(False)
          # data
          self.dataset: BaseDataset = build_dataset_name(dataset_cfg) if isinstance(dataset_cfg, str) else build_dataset_cfg(dataset_cfg)
-         self.dataloader = build_dataloader(self.dataset)
+         self.dataset.build_source(split="test")
+         self.dataloader = build_dataloader(self.dataset, self.cfg)

          # device
          self.device = "cuda" if torch.cuda.is_available() else "cpu"
spikezoo/pipeline/train_cfgs.py
@@ -18,50 +18,53 @@ class REDS_BASE_TrainConfig(TrainPipelineConfig):
      steps_per_save_imgs: int = 200
      steps_per_save_ckpt: int = 500
      steps_per_cal_metrics: int = 100
-     metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim","lpips","niqe","brisque","piqe"])
+     metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim", "lpips", "niqe", "brisque", "piqe"])

      # dataloader setting
      bs_train: int = 8
-     num_workers: int = 4
+     nw_train: int = 4
      pin_memory: bool = False

      # train setting - optimizer & scheduler & loss_dict
-     optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4)
-     scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[400], gamma=0.2)  # from wgse
+     optimizer_cfg: OptimizerConfig = field(default_factory=lambda: AdamOptimizerConfig(lr=1e-4))
+     scheduler_cfg: Optional[SchedulerConfig] = field(
+         default_factory=lambda: MultiStepSchedulerConfig(milestones=[400], gamma=0.2)
+     )  # from wgse
      loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})

- # ! Train Config for each method on the official setting, not recommended to utilize their default parameters owing to the dataset setting.
- @dataclass
- class BSFTrainConfig(TrainPipelineConfig):
-     """Training setting for BSF. https://github.com/ruizhao26/BSF"""
-
-     optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4, weight_decay=0.0)
-     scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[10, 20, 30, 40, 50, 60, 70, 80, 90, 100])
-     loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
+ # # ! Train Config for each method on the official setting, not recommended to utilize their default parameters owing to the dataset setting.
+ # @dataclass
+ # class BSFTrainConfig(TrainPipelineConfig):
+ #     """Training setting for BSF. https://github.com/ruizhao26/BSF"""
+
+ #     optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4, weight_decay=0.0)
+ #     scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[10, 20, 30, 40, 50, 60, 70, 80, 90, 100])
+ #     loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})

- @dataclass
- class WGSETrainConfig(TrainPipelineConfig):
-     """Training setting for WGSE. https://github.com/Leozhangjiyuan/WGSE-SpikeCamera"""
-
-     optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4, betas=(0.9, 0.99), weight_decay=0)
-     scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[400, 600], gamma=0.2)
-     loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
+ # @dataclass
+ # class WGSETrainConfig(TrainPipelineConfig):
+ #     """Training setting for WGSE. https://github.com/Leozhangjiyuan/WGSE-SpikeCamera"""
+
+ #     optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4, betas=(0.9, 0.99), weight_decay=0)
+ #     scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[400, 600], gamma=0.2)
+ #     loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})

- @dataclass
- class STIRTrainConfig(TrainPipelineConfig):
-     """Training setting for STIR. https://github.com/GitCVfb/STIR"""
-
-     optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4, betas=(0.9, 0.999))
-     scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70], gamma=0.7)
-     loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
+ # @dataclass
+ # class STIRTrainConfig(TrainPipelineConfig):
+ #     """Training setting for STIR. https://github.com/GitCVfb/STIR"""
+
+ #     optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4, betas=(0.9, 0.999))
+ #     scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70], gamma=0.7)
+ #     loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})

- @dataclass
- class Spk2ImgNetTrainConfig(TrainPipelineConfig):
-     """Training setting for Spk2ImgNet. https://github.com/Vspacer/Spk2ImgNet"""
-
-     optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4)
-     scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[20], gamma=0.1)
-     loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
+ # @dataclass
+ # class Spk2ImgNetTrainConfig(TrainPipelineConfig):
+ #     """Training setting for Spk2ImgNet. https://github.com/Vspacer/Spk2ImgNet"""
+
+ #     optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4)
+ #     scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[20], gamma=0.1)
+ #     loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
+ # loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})