spikezoo 0.2.3.5__py3-none-any.whl → 0.2.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
- spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
- spikezoo/archs/stir/metrics/__pycache__/losses.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/Vgg19.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/geometry.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/image_proc.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/losses.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
- spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/yourmodel/arch/__pycache__/net.cpython-39.pyc +0 -0
- spikezoo/archs/yourmodel/arch/net.py +35 -0
- spikezoo/datasets/__init__.py +20 -21
- spikezoo/datasets/base_dataset.py +25 -19
- spikezoo/datasets/{realworld_dataset.py → realdata_dataset.py} +5 -7
- spikezoo/datasets/reds_base_dataset.py +1 -1
- spikezoo/datasets/szdata_dataset.py +1 -1
- spikezoo/datasets/uhsr_dataset.py +1 -1
- spikezoo/datasets/yourdataset_dataset.py +23 -0
- spikezoo/models/__init__.py +11 -18
- spikezoo/models/base_model.py +10 -4
- spikezoo/models/yourmodel_model.py +22 -0
- spikezoo/pipeline/base_pipeline.py +17 -10
- spikezoo/pipeline/ensemble_pipeline.py +2 -1
- spikezoo/pipeline/train_cfgs.py +32 -29
- spikezoo/pipeline/train_pipeline.py +14 -14
- spikezoo/utils/spike_utils.py +1 -1
- spikezoo-0.2.3.7.dist-info/METADATA +151 -0
- {spikezoo-0.2.3.5.dist-info → spikezoo-0.2.3.7.dist-info}/RECORD +44 -41
- spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
- spikezoo-0.2.3.5.dist-info/METADATA +0 -258
- {spikezoo-0.2.3.5.dist-info → spikezoo-0.2.3.7.dist-info}/LICENSE.txt +0 -0
- {spikezoo-0.2.3.5.dist-info → spikezoo-0.2.3.7.dist-info}/WHEEL +0 -0
- {spikezoo-0.2.3.5.dist-info → spikezoo-0.2.3.7.dist-info}/top_level.txt +0 -0
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
@@ -0,0 +1,35 @@
|
|
1
|
+
import torch.nn as nn
|
2
|
+
|
3
|
+
def conv_layer(inDim, outDim, ks, s, p, norm_layer="none"):
|
4
|
+
## convolutional layer
|
5
|
+
conv = nn.Conv2d(inDim, outDim, kernel_size=ks, stride=s, padding=p)
|
6
|
+
relu = nn.ReLU(True)
|
7
|
+
assert norm_layer in ("batch", "instance", "none")
|
8
|
+
if norm_layer == "none":
|
9
|
+
seq = nn.Sequential(*[conv, relu])
|
10
|
+
else:
|
11
|
+
if norm_layer == "instance":
|
12
|
+
norm = nn.InstanceNorm2d(outDim, affine=False, track_running_stats=False) # instance norm
|
13
|
+
else:
|
14
|
+
momentum = 0.1
|
15
|
+
norm = nn.BatchNorm2d(outDim, momentum=momentum, affine=True, track_running_stats=True)
|
16
|
+
seq = nn.Sequential(*[conv, norm, relu])
|
17
|
+
return seq
|
18
|
+
|
19
|
+
|
20
|
+
class YourNet(nn.Module):
|
21
|
+
"""Borrow the structure from the SpikeCLIP. (https://arxiv.org/abs/2501.04477)"""
|
22
|
+
|
23
|
+
def __init__(self, inDim=41):
|
24
|
+
super(YourNet, self).__init__()
|
25
|
+
norm = "none"
|
26
|
+
outDim = 1
|
27
|
+
convBlock1 = conv_layer(inDim, 64, 3, 1, 1)
|
28
|
+
convBlock2 = conv_layer(64, 128, 3, 1, 1, norm)
|
29
|
+
convBlock3 = conv_layer(128, 64, 3, 1, 1, norm)
|
30
|
+
convBlock4 = conv_layer(64, 16, 3, 1, 1, norm)
|
31
|
+
conv = nn.Conv2d(16, outDim, 3, 1, 1)
|
32
|
+
self.seq = nn.Sequential(*[convBlock1, convBlock2, convBlock3, convBlock4, conv])
|
33
|
+
|
34
|
+
def forward(self, x):
|
35
|
+
return self.seq(x)
|
spikezoo/datasets/__init__.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
|
2
|
+
|
2
3
|
from dataclasses import replace
|
3
4
|
import importlib, inspect
|
4
5
|
import os
|
@@ -12,23 +13,24 @@ dataset_list = [file.replace("_dataset.py", "") for file in files_list if file.e
|
|
12
13
|
|
13
14
|
|
14
15
|
# todo register function
|
15
|
-
def build_dataset_cfg(cfg: BaseDatasetConfig
|
16
|
+
def build_dataset_cfg(cfg: BaseDatasetConfig):
|
16
17
|
"""Build the dataset from the given dataset config."""
|
17
|
-
# build new cfg according to split
|
18
|
-
cfg = replace(cfg, split=split)
|
19
18
|
# dataset module
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
19
|
+
if cfg.dataset_cls_local == None:
|
20
|
+
module_name = cfg.dataset_name + "_dataset"
|
21
|
+
assert cfg.dataset_name in dataset_list, f"Given dataset {cfg.dataset_name} not in our dataset list {dataset_list}."
|
22
|
+
module_name = "spikezoo.datasets." + module_name
|
23
|
+
module = importlib.import_module(module_name)
|
24
|
+
# dataset,dataset_config
|
25
|
+
dataset_name = cfg.dataset_name
|
26
|
+
dataset_name = dataset_name + "Dataset" if dataset_name == "base" else dataset_name
|
27
|
+
dataset_cls: BaseDataset = getattr_case_insensitive(module, dataset_name)
|
28
|
+
else:
|
29
|
+
dataset_cls = cfg.dataset_cls_local
|
28
30
|
dataset = dataset_cls(cfg)
|
29
31
|
return dataset
|
30
32
|
|
31
|
-
def build_dataset_name(dataset_name: str
|
33
|
+
def build_dataset_name(dataset_name: str):
|
32
34
|
"""Build the default dataset from the given name."""
|
33
35
|
module_name = dataset_name + "_dataset"
|
34
36
|
assert dataset_name in dataset_list, f"Given dataset {dataset_name} not in our dataset list {dataset_list}."
|
@@ -37,22 +39,19 @@ def build_dataset_name(dataset_name: str, split: Literal["train", "test"] = "tes
|
|
37
39
|
# dataset,dataset_config
|
38
40
|
dataset_name = dataset_name + "Dataset" if dataset_name == "base" else dataset_name
|
39
41
|
dataset_cls: BaseDataset = getattr_case_insensitive(module, dataset_name)
|
40
|
-
dataset_cfg: BaseDatasetConfig = getattr_case_insensitive(module, dataset_name + "config")(
|
42
|
+
dataset_cfg: BaseDatasetConfig = getattr_case_insensitive(module, dataset_name + "config")()
|
41
43
|
dataset = dataset_cls(dataset_cfg)
|
42
44
|
return dataset
|
43
45
|
|
44
46
|
|
45
47
|
# todo to modify according to the basicsr
|
46
|
-
def build_dataloader(dataset
|
48
|
+
def build_dataloader(dataset, cfg):
|
47
49
|
# train dataloader
|
48
|
-
if dataset.
|
49
|
-
|
50
|
-
return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
|
51
|
-
else:
|
52
|
-
return torch.utils.data.DataLoader(dataset, batch_size=cfg.bs_train, shuffle=True, num_workers=cfg.num_workers, pin_memory=cfg.pin_memory)
|
50
|
+
if dataset.split == "train" and cfg._mode == "train_mode":
|
51
|
+
return torch.utils.data.DataLoader(dataset, batch_size=cfg.bs_train, shuffle=True, num_workers=cfg.nw_train, pin_memory=cfg.pin_memory)
|
53
52
|
# test dataloader
|
54
|
-
|
55
|
-
return torch.utils.data.DataLoader(dataset, batch_size=
|
53
|
+
else:
|
54
|
+
return torch.utils.data.DataLoader(dataset, batch_size=cfg.bs_test, shuffle=False, num_workers=cfg.nw_test,pin_memory=False)
|
56
55
|
|
57
56
|
|
58
57
|
# dataset_size_dict = {}
|
@@ -11,6 +11,7 @@ import warnings
|
|
11
11
|
import torch
|
12
12
|
from tqdm import tqdm
|
13
13
|
from spikezoo.utils.data_utils import Augmentor
|
14
|
+
from typing import Optional
|
14
15
|
|
15
16
|
|
16
17
|
@dataclass
|
@@ -36,24 +37,18 @@ class BaseDatasetConfig:
|
|
36
37
|
img_dir_name: str = "gt"
|
37
38
|
"Rate. (-1 denotes variant)"
|
38
39
|
rate: float = 0.6
|
39
|
-
|
40
|
+
|
40
41
|
# ------------- Config -------------
|
41
|
-
"Dataset split: train/test. Default set as the 'test' for evaluation."
|
42
|
-
split: Literal["train", "test"] = "test"
|
43
42
|
"Use the data augumentation technique or not."
|
44
43
|
use_aug: bool = False
|
45
44
|
"Use cache mechanism."
|
46
45
|
use_cache: bool = False
|
47
46
|
"Crop size."
|
48
47
|
crop_size: tuple = (-1, -1)
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
self.spike_length = self.spike_length_train if self.split == "train" else self.spike_length_test
|
54
|
-
self.root_dir = Path(self.root_dir) if isinstance(self.root_dir, str) else self.root_dir
|
55
|
-
# todo try download
|
56
|
-
assert self.root_dir.exists(), f"No files found in {self.root_dir} for the specified dataset `{self.dataset_name}`."
|
48
|
+
"Load the dataset from local or spikezoo lib."
|
49
|
+
dataset_cls_local: Optional[Dataset] = None
|
50
|
+
"Spike load version. [python,cpp]"
|
51
|
+
spike_load_version: Literal["python", "cpp"] = "python"
|
57
52
|
|
58
53
|
|
59
54
|
# todo cache mechanism
|
@@ -61,10 +56,6 @@ class BaseDataset(Dataset):
|
|
61
56
|
def __init__(self, cfg: BaseDatasetConfig):
|
62
57
|
super(BaseDataset, self).__init__()
|
63
58
|
self.cfg = cfg
|
64
|
-
self.augmentor = Augmentor(self.cfg.crop_size) if self.cfg.use_aug == True and self.cfg.split == "train" else -1
|
65
|
-
self.prepare_data()
|
66
|
-
self.cache_data() if cfg.use_cache == True else -1
|
67
|
-
warnings.warn("Lengths of the image list and the spike list should be equal.") if len(self.img_list) != len(self.spike_list) else -1
|
68
59
|
|
69
60
|
def __len__(self):
|
70
61
|
return len(self.spike_list)
|
@@ -79,7 +70,7 @@ class BaseDataset(Dataset):
|
|
79
70
|
img = self.get_img(idx)
|
80
71
|
|
81
72
|
# process data
|
82
|
-
if self.cfg.use_aug == True and self.
|
73
|
+
if self.cfg.use_aug == True and self.split == "train":
|
83
74
|
spike, img = self.augmentor(spike, img)
|
84
75
|
|
85
76
|
# rate
|
@@ -89,15 +80,29 @@ class BaseDataset(Dataset):
|
|
89
80
|
batch = {"spike": spike, "gt_img": img, "rate": rate}
|
90
81
|
return batch
|
91
82
|
|
83
|
+
def build_source(self, split: Literal["train", "test"] = "test"):
|
84
|
+
"""Build the dataset source and prepare to be loaded files."""
|
85
|
+
# spike length
|
86
|
+
self.split = split
|
87
|
+
self.spike_length = self.cfg.spike_length_train if self.split == "train" else self.cfg.spike_length_test
|
88
|
+
# root dir
|
89
|
+
self.cfg.root_dir = Path(self.cfg.root_dir) if isinstance(self.cfg.root_dir, str) else self.cfg.root_dir
|
90
|
+
assert self.cfg.root_dir.exists(), f"No files found in {self.cfg.root_dir} for the specified dataset `{self.cfg.dataset_name}`."
|
91
|
+
# prepare
|
92
|
+
self.augmentor = Augmentor(self.cfg.crop_size) if self.cfg.use_aug == True and self.split == "train" else -1
|
93
|
+
self.prepare_data()
|
94
|
+
self.cache_data() if self.cfg.use_cache == True else -1
|
95
|
+
warnings.warn("Lengths of the image list and the spike list should be equal.") if len(self.img_list) != len(self.spike_list) else -1
|
96
|
+
|
92
97
|
# todo: To be overridden
|
93
98
|
def prepare_data(self):
|
94
99
|
"""Specify the spike and image files to be loaded."""
|
95
100
|
# spike
|
96
|
-
self.spike_dir = self.cfg.root_dir / self.
|
101
|
+
self.spike_dir = self.cfg.root_dir / self.split / self.cfg.spike_dir_name
|
97
102
|
self.spike_list = self.get_spike_files(self.spike_dir)
|
98
103
|
# gt
|
99
104
|
if self.cfg.with_img == True:
|
100
|
-
self.img_dir = self.cfg.root_dir / self.
|
105
|
+
self.img_dir = self.cfg.root_dir / self.split / self.cfg.img_dir_name
|
101
106
|
self.img_list = self.get_image_files(self.img_dir)
|
102
107
|
|
103
108
|
# todo: To be overridden
|
@@ -115,12 +120,13 @@ class BaseDataset(Dataset):
|
|
115
120
|
height=self.cfg.height,
|
116
121
|
width=self.cfg.width,
|
117
122
|
out_format="tensor",
|
123
|
+
version=self.cfg.spike_load_version
|
118
124
|
)
|
119
125
|
return spike
|
120
126
|
|
121
127
|
def get_spike(self, idx):
|
122
128
|
"""Get and process the spike stream from the given idx."""
|
123
|
-
spike_length = self.
|
129
|
+
spike_length = self.spike_length
|
124
130
|
spike = self.load_spike(idx)
|
125
131
|
assert spike.shape[0] >= spike_length, f"Given spike length {spike.shape[0]} smaller than the required length {spike_length}"
|
126
132
|
spike_mid = spike.shape[0] // 2
|
@@ -4,21 +4,19 @@ from dataclasses import dataclass
|
|
4
4
|
|
5
5
|
|
6
6
|
@dataclass
|
7
|
-
class
|
8
|
-
dataset_name: str = "
|
9
|
-
root_dir: Path = Path(__file__).parent.parent / Path("data/
|
7
|
+
class RealDataConfig(BaseDatasetConfig):
|
8
|
+
dataset_name: str = "realdata"
|
9
|
+
root_dir: Path = Path(__file__).parent.parent / Path("data/realdata")
|
10
10
|
width: int = 400
|
11
11
|
height: int = 250
|
12
12
|
with_img: bool = False
|
13
13
|
spike_length_train: int = -1
|
14
14
|
spike_length_test: int = -1
|
15
15
|
rate: float = 1
|
16
|
-
|
17
16
|
|
18
|
-
|
19
|
-
class RealWorld(BaseDataset):
|
17
|
+
class RealData(BaseDataset):
|
20
18
|
def __init__(self, cfg: BaseDatasetConfig):
|
21
|
-
super(
|
19
|
+
super(RealData, self).__init__(cfg)
|
22
20
|
|
23
21
|
def prepare_data(self):
|
24
22
|
self.spike_dir = self.cfg.root_dir
|
@@ -8,7 +8,7 @@ import re
|
|
8
8
|
@dataclass
|
9
9
|
class REDS_BASEConfig(BaseDatasetConfig):
|
10
10
|
dataset_name: str = "reds_base"
|
11
|
-
root_dir: Path = Path(__file__).parent.parent / Path("data/
|
11
|
+
root_dir: Path = Path(__file__).parent.parent / Path("data/reds_base")
|
12
12
|
width: int = 400
|
13
13
|
height: int = 250
|
14
14
|
with_img: bool = True
|
@@ -9,7 +9,7 @@ import numpy as np
|
|
9
9
|
@dataclass
|
10
10
|
class SZDataConfig(BaseDatasetConfig):
|
11
11
|
dataset_name: str = "szdata"
|
12
|
-
root_dir: Path = Path(__file__).parent.parent / Path("data/
|
12
|
+
root_dir: Path = Path(__file__).parent.parent / Path("data/szdata")
|
13
13
|
width: int = 400
|
14
14
|
height: int = 250
|
15
15
|
with_img: bool = True
|
@@ -7,7 +7,7 @@ import torch
|
|
7
7
|
@dataclass
|
8
8
|
class UHSRConfig(BaseDatasetConfig):
|
9
9
|
dataset_name: str = "uhsr"
|
10
|
-
root_dir: Path = Path(__file__).parent.parent / Path("data/
|
10
|
+
root_dir: Path = Path(__file__).parent.parent / Path("data/u_caltech")
|
11
11
|
width: int = 224
|
12
12
|
height: int = 224
|
13
13
|
with_img: bool = False
|
@@ -0,0 +1,23 @@
|
|
1
|
+
from torch.utils.data import Dataset
|
2
|
+
from pathlib import Path
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from typing import Literal, Union
|
5
|
+
from typing import Optional
|
6
|
+
from spikezoo.datasets.base_dataset import BaseDatasetConfig,BaseDataset
|
7
|
+
|
8
|
+
@dataclass
|
9
|
+
class YourDatasetConfig(BaseDatasetConfig):
|
10
|
+
dataset_name: str = "yourdataset"
|
11
|
+
root_dir: Union[str, Path] = Path(__file__).parent.parent / Path("data/your_data_path")
|
12
|
+
width: int = 400
|
13
|
+
height: int = 250
|
14
|
+
with_img: bool = True
|
15
|
+
spike_length_train: int = -1
|
16
|
+
spike_length_test: int = -1
|
17
|
+
spike_dir_name: str = "spike_data"
|
18
|
+
img_dir_name: str = "sharp_data"
|
19
|
+
rate: float = 1
|
20
|
+
|
21
|
+
class YourDataset(BaseDataset):
|
22
|
+
def __init__(self, cfg: BaseDatasetConfig):
|
23
|
+
super(YourDataset, self).__init__(cfg)
|
spikezoo/models/__init__.py
CHANGED
@@ -1,16 +1,6 @@
|
|
1
1
|
import importlib
|
2
2
|
import inspect
|
3
3
|
from spikezoo.models.base_model import BaseModel,BaseModelConfig
|
4
|
-
from spikezoo.models.tfp_model import TFPModel,TFPConfig
|
5
|
-
from spikezoo.models.tfi_model import TFIModel,TFIConfig
|
6
|
-
from spikezoo.models.spk2imgnet_model import Spk2ImgNet,Spk2ImgNetConfig
|
7
|
-
from spikezoo.models.wgse_model import WGSE,WGSEConfig
|
8
|
-
from spikezoo.models.ssml_model import SSML,SSMLConfig
|
9
|
-
from spikezoo.models.bsf_model import BSF,BSFConfig
|
10
|
-
from spikezoo.models.stir_model import STIR,STIRConfig
|
11
|
-
from spikezoo.models.ssir_model import SSIR,SSIRConfig
|
12
|
-
from spikezoo.models.spikeclip_model import SpikeCLIP,SpikeCLIPConfig
|
13
|
-
|
14
4
|
|
15
5
|
from spikezoo.utils.other_utils import getattr_case_insensitive
|
16
6
|
import os
|
@@ -24,14 +14,17 @@ model_list = [file.split("_")[0] for file in files_list if file.endswith("_model
|
|
24
14
|
def build_model_cfg(cfg: BaseModelConfig):
|
25
15
|
"""Build the model from the given model config."""
|
26
16
|
# model module name
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
17
|
+
if cfg.model_cls_local == None:
|
18
|
+
module_name = cfg.model_name + "_model"
|
19
|
+
assert cfg.model_name in model_list, f"Given model {cfg.model_name} not in our model zoo {model_list}."
|
20
|
+
module_name = "spikezoo.models." + module_name
|
21
|
+
module = importlib.import_module(module_name)
|
22
|
+
# model,model_config
|
23
|
+
model_name = cfg.model_name
|
24
|
+
model_name = model_name + 'Model' if model_name == "base" else model_name
|
25
|
+
model_cls: BaseModel = getattr_case_insensitive(module,model_name)
|
26
|
+
else:
|
27
|
+
model_cls: BaseModel = cfg.model_cls_local
|
35
28
|
model = model_cls(cfg)
|
36
29
|
return model
|
37
30
|
|
spikezoo/models/base_model.py
CHANGED
@@ -44,7 +44,10 @@ class BaseModelConfig:
|
|
44
44
|
multi_gpu: bool = False
|
45
45
|
"Base url."
|
46
46
|
base_url: str = "https://github.com/chenkang455/Spike-Zoo/releases/download"
|
47
|
-
|
47
|
+
"Load the model from local class or spikezoo lib. (None)"
|
48
|
+
model_cls_local: Optional[nn.Module] = None
|
49
|
+
"Load the arch from local class or spikezoo lib. (None)"
|
50
|
+
arch_cls_local: Optional[nn.Module] = None
|
48
51
|
|
49
52
|
class BaseModel(nn.Module):
|
50
53
|
def __init__(self, cfg: BaseModelConfig):
|
@@ -71,8 +74,11 @@ class BaseModel(nn.Module):
|
|
71
74
|
):
|
72
75
|
"""Build the network and load the pretrained weight."""
|
73
76
|
# network
|
74
|
-
|
75
|
-
|
77
|
+
if self.cfg.arch_cls_local == None:
|
78
|
+
module = importlib.import_module(f"spikezoo.archs.{self.cfg.model_name}.{self.cfg.model_file_name}")
|
79
|
+
model_cls = getattr(module, self.cfg.model_cls_name)
|
80
|
+
else:
|
81
|
+
model_cls = self.cfg.arch_cls_local
|
76
82
|
# load model config parameters
|
77
83
|
if version == "local":
|
78
84
|
model = model_cls(**self.cfg.model_params)
|
@@ -129,7 +135,7 @@ class BaseModel(nn.Module):
|
|
129
135
|
"""Crop the spike length."""
|
130
136
|
spike_length = spike.shape[1]
|
131
137
|
spike_mid = spike_length // 2
|
132
|
-
assert spike_length >= self.model_length, f"Spike input is not long enough, given {spike_length} frames < {self.cfg.model_length}."
|
138
|
+
assert spike_length >= self.model_length, f"Spike input is not long enough, given {spike_length} frames < {self.cfg.model_length} required by the {self.cfg.model_name}."
|
133
139
|
# even length
|
134
140
|
if self.model_length == self.model_half_length * 2:
|
135
141
|
spike = spike[
|
@@ -0,0 +1,22 @@
|
|
1
|
+
from torch.utils.data import Dataset
|
2
|
+
from pathlib import Path
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from typing import Literal, Union
|
5
|
+
from typing import Optional
|
6
|
+
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
7
|
+
from dataclasses import field
|
8
|
+
import torch.nn as nn
|
9
|
+
|
10
|
+
|
11
|
+
@dataclass
|
12
|
+
class YourModelConfig(BaseModelConfig):
|
13
|
+
model_name: str = "yourmodel" # 需与文件名保持一致
|
14
|
+
model_file_name: str = "arch.net" # archs路径下的模块路径
|
15
|
+
model_cls_name: str = "YourNet" # 模型类名
|
16
|
+
model_length: int = 41
|
17
|
+
require_params: bool = True
|
18
|
+
model_params: dict = field(default_factory=lambda: {"inDim": 41})
|
19
|
+
|
20
|
+
class YourModel(BaseModel):
|
21
|
+
def __init__(self, cfg: BaseModelConfig):
|
22
|
+
super(YourModel, self).__init__(cfg)
|
@@ -34,11 +34,17 @@ class PipelineConfig:
|
|
34
34
|
"Evaluate metrics or not."
|
35
35
|
save_metric: bool = True
|
36
36
|
"Metric names for evaluation."
|
37
|
-
metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim","niqe","brisque"])
|
37
|
+
metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim", "niqe", "brisque"])
|
38
38
|
"Save recoverd images or not."
|
39
39
|
save_img: bool = True
|
40
40
|
"Normalizing recoverd images and gt or not."
|
41
|
-
|
41
|
+
img_norm: bool = False
|
42
|
+
"Batch size for the test dataloader."
|
43
|
+
bs_test: int = 1
|
44
|
+
"Num_workers for the test dataloader."
|
45
|
+
nw_test: int = 0
|
46
|
+
"Pin_memory true or false for the dataloader."
|
47
|
+
pin_memory: bool = False
|
42
48
|
"Different modes for the pipeline."
|
43
49
|
_mode: Literal["single_mode", "multi_mode", "train_mode"] = "single_mode"
|
44
50
|
|
@@ -63,7 +69,8 @@ class Pipeline:
|
|
63
69
|
torch.set_grad_enabled(False)
|
64
70
|
# dataset
|
65
71
|
self.dataset: BaseDataset = build_dataset_name(dataset_cfg) if isinstance(dataset_cfg, str) else build_dataset_cfg(dataset_cfg)
|
66
|
-
self.
|
72
|
+
self.dataset.build_source(split="test")
|
73
|
+
self.dataloader = build_dataloader(self.dataset,self.cfg)
|
67
74
|
# device
|
68
75
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
69
76
|
|
@@ -103,7 +110,7 @@ class Pipeline:
|
|
103
110
|
"""Function I---Save the recoverd image and calculate the metric from the given dataset."""
|
104
111
|
# save folder
|
105
112
|
self.logger.info("*********************** infer_from_dataset ***********************")
|
106
|
-
save_folder = self.save_folder / Path(f"infer_from_dataset/{self.dataset.cfg.dataset_name}_dataset/{self.dataset.
|
113
|
+
save_folder = self.save_folder / Path(f"infer_from_dataset/{self.dataset.cfg.dataset_name}_dataset/{self.dataset.split}/{idx:06d}")
|
107
114
|
os.makedirs(str(save_folder), exist_ok=True)
|
108
115
|
|
109
116
|
# data process
|
@@ -117,7 +124,7 @@ class Pipeline:
|
|
117
124
|
img = None
|
118
125
|
return self.infer(spike, img, save_folder, rate)
|
119
126
|
|
120
|
-
def infer_from_file(self, file_path, height=-1, width=-1,
|
127
|
+
def infer_from_file(self, file_path, height=-1, width=-1, rate=1, img_path=None, remove_head=False):
|
121
128
|
"""Function II---Save the recoverd image and calculate the metric from the given input file."""
|
122
129
|
# save folder
|
123
130
|
self.logger.info("*********************** infer_from_file ***********************")
|
@@ -144,7 +151,7 @@ class Pipeline:
|
|
144
151
|
spike = torch.from_numpy(spike)[None].to(self.device)
|
145
152
|
return self.infer(spike, img, save_folder, rate)
|
146
153
|
|
147
|
-
def infer_from_spk(self, spike,
|
154
|
+
def infer_from_spk(self, spike, rate=1, img=None):
|
148
155
|
"""Function III---Save the recoverd image and calculate the metric from the given spike stream."""
|
149
156
|
# save folder
|
150
157
|
self.logger.info("*********************** infer_from_spk ***********************")
|
@@ -181,7 +188,7 @@ class Pipeline:
|
|
181
188
|
for idx in range(len(self.dataset)):
|
182
189
|
self.infer_from_dataset(idx=idx)
|
183
190
|
self.cfg.save_metric = base_setting
|
184
|
-
|
191
|
+
|
185
192
|
# TODO: To be overridden
|
186
193
|
def cal_params(self):
|
187
194
|
"""Function VI---Calculate the parameters/flops/latency of the given method."""
|
@@ -228,8 +235,8 @@ class Pipeline:
|
|
228
235
|
# With no GT
|
229
236
|
if recon_img == None:
|
230
237
|
return None
|
231
|
-
#
|
232
|
-
if model_name in ["
|
238
|
+
# spikeclip is normalized automatically
|
239
|
+
if model_name in ["spikeclip"] or self.cfg.img_norm == True:
|
233
240
|
recon_img = (recon_img - recon_img.min()) / (recon_img.max() - recon_img.min())
|
234
241
|
else:
|
235
242
|
recon_img = recon_img / rate
|
@@ -253,7 +260,7 @@ class Pipeline:
|
|
253
260
|
batch = model.feed_to_device(batch)
|
254
261
|
outputs = model.get_outputs_dict(batch)
|
255
262
|
recon_img, img = model.get_paired_imgs(batch, outputs)
|
256
|
-
recon_img, img = self._post_process_img(recon_img, model_name), self._post_process_img(img, "
|
263
|
+
recon_img, img = self._post_process_img(recon_img, model_name), self._post_process_img(img, "gt")
|
257
264
|
for metric_name in metrics_dict.keys():
|
258
265
|
if metric_name in metric_pair_names:
|
259
266
|
metrics_dict[metric_name].update(cal_metric_pair(recon_img, img, metric_name))
|
@@ -48,7 +48,8 @@ class EnsemblePipeline(Pipeline):
|
|
48
48
|
torch.set_grad_enabled(False)
|
49
49
|
# data
|
50
50
|
self.dataset: BaseDataset = build_dataset_name(dataset_cfg) if isinstance(dataset_cfg, str) else build_dataset_cfg(dataset_cfg)
|
51
|
-
self.
|
51
|
+
self.dataset.build_source(split = "test")
|
52
|
+
self.dataloader = build_dataloader(self.dataset,self.cfg)
|
52
53
|
# device
|
53
54
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
54
55
|
|
spikezoo/pipeline/train_cfgs.py
CHANGED
@@ -18,50 +18,53 @@ class REDS_BASE_TrainConfig(TrainPipelineConfig):
|
|
18
18
|
steps_per_save_imgs: int = 200
|
19
19
|
steps_per_save_ckpt: int = 500
|
20
20
|
steps_per_cal_metrics: int = 100
|
21
|
-
metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim","lpips","niqe","brisque","piqe"])
|
21
|
+
metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim", "lpips", "niqe", "brisque", "piqe"])
|
22
22
|
|
23
23
|
# dataloader setting
|
24
24
|
bs_train: int = 8
|
25
|
-
|
25
|
+
nw_train: int = 4
|
26
26
|
pin_memory: bool = False
|
27
27
|
|
28
28
|
# train setting - optimizer & scheduler & loss_dict
|
29
|
-
optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4)
|
30
|
-
scheduler_cfg: Optional[SchedulerConfig] =
|
29
|
+
optimizer_cfg: OptimizerConfig = field(default_factory=lambda: AdamOptimizerConfig(lr=1e-4))
|
30
|
+
scheduler_cfg: Optional[SchedulerConfig] = field(
|
31
|
+
default_factory=lambda: MultiStepSchedulerConfig(milestones=[400], gamma=0.2)
|
32
|
+
) # from wgse
|
31
33
|
loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
|
32
34
|
|
33
|
-
# ! Train Config for each method on the official setting, not recommended to utilize their default parameters owing to the dataset setting.
|
34
|
-
@dataclass
|
35
|
-
class BSFTrainConfig(TrainPipelineConfig):
|
36
|
-
"""Training setting for BSF. https://github.com/ruizhao26/BSF"""
|
37
35
|
|
38
|
-
|
39
|
-
|
40
|
-
|
36
|
+
# # ! Train Config for each method on the official setting, not recommended to utilize their default parameters owing to the dataset setting.
|
37
|
+
# @dataclass
|
38
|
+
# class BSFTrainConfig(TrainPipelineConfig):
|
39
|
+
# """Training setting for BSF. https://github.com/ruizhao26/BSF"""
|
41
40
|
|
41
|
+
# optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4, weight_decay=0.0)
|
42
|
+
# scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[10, 20, 30, 40, 50, 60, 70, 80, 90, 100])
|
43
|
+
# loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
|
42
44
|
|
43
|
-
@dataclass
|
44
|
-
class WGSETrainConfig(TrainPipelineConfig):
|
45
|
-
"""Training setting for WGSE. https://github.com/Leozhangjiyuan/WGSE-SpikeCamera"""
|
46
45
|
|
47
|
-
|
48
|
-
|
49
|
-
|
46
|
+
# @dataclass
|
47
|
+
# class WGSETrainConfig(TrainPipelineConfig):
|
48
|
+
# """Training setting for WGSE. https://github.com/Leozhangjiyuan/WGSE-SpikeCamera"""
|
50
49
|
|
50
|
+
# optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4, betas=(0.9, 0.99), weight_decay=0)
|
51
|
+
# scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[400, 600], gamma=0.2)
|
52
|
+
# loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
|
51
53
|
|
52
|
-
@dataclass
|
53
|
-
class STIRTrainConfig(TrainPipelineConfig):
|
54
|
-
"""Training setting for STIR. https://github.com/GitCVfb/STIR"""
|
55
54
|
|
56
|
-
|
57
|
-
|
58
|
-
|
55
|
+
# @dataclass
|
56
|
+
# class STIRTrainConfig(TrainPipelineConfig):
|
57
|
+
# """Training setting for STIR. https://github.com/GitCVfb/STIR"""
|
59
58
|
|
59
|
+
# optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4, betas=(0.9, 0.999))
|
60
|
+
# scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70], gamma=0.7)
|
61
|
+
# loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
|
60
62
|
|
61
|
-
@dataclass
|
62
|
-
class Spk2ImgNetTrainConfig(TrainPipelineConfig):
|
63
|
-
"""Training setting for Spk2ImgNet. https://github.com/Vspacer/Spk2ImgNet"""
|
64
63
|
|
65
|
-
|
66
|
-
|
67
|
-
|
64
|
+
# @dataclass
|
65
|
+
# class Spk2ImgNetTrainConfig(TrainPipelineConfig):
|
66
|
+
# """Training setting for Spk2ImgNet. https://github.com/Vspacer/Spk2ImgNet"""
|
67
|
+
|
68
|
+
# optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4)
|
69
|
+
# scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[20], gamma=0.1)
|
70
|
+
# loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
|