spikezoo 0.2.3.4__py3-none-any.whl → 0.2.3.6__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
- spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
- spikezoo/archs/stir/metrics/__pycache__/losses.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/Vgg19.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/geometry.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/image_proc.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/losses.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
- spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/yourmodel/arch/__pycache__/net.cpython-39.pyc +0 -0
- spikezoo/archs/yourmodel/arch/net.py +35 -0
- spikezoo/datasets/__init__.py +20 -21
- spikezoo/datasets/base_dataset.py +26 -21
- spikezoo/datasets/{realworld_dataset.py → realdata_dataset.py} +5 -7
- spikezoo/datasets/reds_base_dataset.py +1 -1
- spikezoo/datasets/szdata_dataset.py +1 -5
- spikezoo/datasets/uhsr_dataset.py +1 -1
- spikezoo/datasets/yourdataset_dataset.py +23 -0
- spikezoo/models/__init__.py +12 -8
- spikezoo/models/base_model.py +10 -4
- spikezoo/models/bsf_model.py +0 -1
- spikezoo/models/spk2imgnet_model.py +0 -1
- spikezoo/models/stir_model.py +0 -1
- spikezoo/models/wgse_model.py +0 -1
- spikezoo/models/yourmodel_model.py +22 -0
- spikezoo/pipeline/base_pipeline.py +17 -10
- spikezoo/pipeline/ensemble_pipeline.py +2 -1
- spikezoo/pipeline/train_cfgs.py +3 -1
- spikezoo/pipeline/train_pipeline.py +12 -12
- spikezoo/utils/spike_utils.py +2 -2
- spikezoo-0.2.3.6.dist-info/METADATA +151 -0
- {spikezoo-0.2.3.4.dist-info → spikezoo-0.2.3.6.dist-info}/RECORD +53 -23
- spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
- spikezoo-0.2.3.4.dist-info/METADATA +0 -259
- {spikezoo-0.2.3.4.dist-info → spikezoo-0.2.3.6.dist-info}/LICENSE.txt +0 -0
- {spikezoo-0.2.3.4.dist-info → spikezoo-0.2.3.6.dist-info}/WHEEL +0 -0
- {spikezoo-0.2.3.4.dist-info → spikezoo-0.2.3.6.dist-info}/top_level.txt +0 -0
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
@@ -0,0 +1,35 @@
|
|
1
|
+
import torch.nn as nn
|
2
|
+
|
3
|
+
def conv_layer(inDim, outDim, ks, s, p, norm_layer="none"):
|
4
|
+
## convolutional layer
|
5
|
+
conv = nn.Conv2d(inDim, outDim, kernel_size=ks, stride=s, padding=p)
|
6
|
+
relu = nn.ReLU(True)
|
7
|
+
assert norm_layer in ("batch", "instance", "none")
|
8
|
+
if norm_layer == "none":
|
9
|
+
seq = nn.Sequential(*[conv, relu])
|
10
|
+
else:
|
11
|
+
if norm_layer == "instance":
|
12
|
+
norm = nn.InstanceNorm2d(outDim, affine=False, track_running_stats=False) # instance norm
|
13
|
+
else:
|
14
|
+
momentum = 0.1
|
15
|
+
norm = nn.BatchNorm2d(outDim, momentum=momentum, affine=True, track_running_stats=True)
|
16
|
+
seq = nn.Sequential(*[conv, norm, relu])
|
17
|
+
return seq
|
18
|
+
|
19
|
+
|
20
|
+
class YourNet(nn.Module):
|
21
|
+
"""Borrow the structure from the SpikeCLIP. (https://arxiv.org/abs/2501.04477)"""
|
22
|
+
|
23
|
+
def __init__(self, inDim=41):
|
24
|
+
super(YourNet, self).__init__()
|
25
|
+
norm = "none"
|
26
|
+
outDim = 1
|
27
|
+
convBlock1 = conv_layer(inDim, 64, 3, 1, 1)
|
28
|
+
convBlock2 = conv_layer(64, 128, 3, 1, 1, norm)
|
29
|
+
convBlock3 = conv_layer(128, 64, 3, 1, 1, norm)
|
30
|
+
convBlock4 = conv_layer(64, 16, 3, 1, 1, norm)
|
31
|
+
conv = nn.Conv2d(16, outDim, 3, 1, 1)
|
32
|
+
self.seq = nn.Sequential(*[convBlock1, convBlock2, convBlock3, convBlock4, conv])
|
33
|
+
|
34
|
+
def forward(self, x):
|
35
|
+
return self.seq(x)
|
spikezoo/datasets/__init__.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
|
2
|
+
|
2
3
|
from dataclasses import replace
|
3
4
|
import importlib, inspect
|
4
5
|
import os
|
@@ -12,23 +13,24 @@ dataset_list = [file.replace("_dataset.py", "") for file in files_list if file.e
|
|
12
13
|
|
13
14
|
|
14
15
|
# todo register function
|
15
|
-
def build_dataset_cfg(cfg: BaseDatasetConfig
|
16
|
+
def build_dataset_cfg(cfg: BaseDatasetConfig):
|
16
17
|
"""Build the dataset from the given dataset config."""
|
17
|
-
# build new cfg according to split
|
18
|
-
cfg = replace(cfg, split=split, spike_length=cfg.spike_length_train if split == "train" else cfg.spike_length_test)
|
19
18
|
# dataset module
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
19
|
+
if cfg.dataset_cls_local == None:
|
20
|
+
module_name = cfg.dataset_name + "_dataset"
|
21
|
+
assert cfg.dataset_name in dataset_list, f"Given dataset {cfg.dataset_name} not in our dataset list {dataset_list}."
|
22
|
+
module_name = "spikezoo.datasets." + module_name
|
23
|
+
module = importlib.import_module(module_name)
|
24
|
+
# dataset,dataset_config
|
25
|
+
dataset_name = cfg.dataset_name
|
26
|
+
dataset_name = dataset_name + "Dataset" if dataset_name == "base" else dataset_name
|
27
|
+
dataset_cls: BaseDataset = getattr_case_insensitive(module, dataset_name)
|
28
|
+
else:
|
29
|
+
dataset_cls = cfg.dataset_cls_local
|
28
30
|
dataset = dataset_cls(cfg)
|
29
31
|
return dataset
|
30
32
|
|
31
|
-
def build_dataset_name(dataset_name: str
|
33
|
+
def build_dataset_name(dataset_name: str):
|
32
34
|
"""Build the default dataset from the given name."""
|
33
35
|
module_name = dataset_name + "_dataset"
|
34
36
|
assert dataset_name in dataset_list, f"Given dataset {dataset_name} not in our dataset list {dataset_list}."
|
@@ -37,22 +39,19 @@ def build_dataset_name(dataset_name: str, split: Literal["train", "test"] = "tes
|
|
37
39
|
# dataset,dataset_config
|
38
40
|
dataset_name = dataset_name + "Dataset" if dataset_name == "base" else dataset_name
|
39
41
|
dataset_cls: BaseDataset = getattr_case_insensitive(module, dataset_name)
|
40
|
-
dataset_cfg: BaseDatasetConfig = getattr_case_insensitive(module, dataset_name + "config")(
|
42
|
+
dataset_cfg: BaseDatasetConfig = getattr_case_insensitive(module, dataset_name + "config")()
|
41
43
|
dataset = dataset_cls(dataset_cfg)
|
42
44
|
return dataset
|
43
45
|
|
44
46
|
|
45
47
|
# todo to modify according to the basicsr
|
46
|
-
def build_dataloader(dataset
|
48
|
+
def build_dataloader(dataset, cfg):
|
47
49
|
# train dataloader
|
48
|
-
if dataset.
|
49
|
-
|
50
|
-
return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
|
51
|
-
else:
|
52
|
-
return torch.utils.data.DataLoader(dataset, batch_size=cfg.bs_train, shuffle=True, num_workers=cfg.num_workers, pin_memory=cfg.pin_memory)
|
50
|
+
if dataset.split == "train" and cfg._mode == "train_mode":
|
51
|
+
return torch.utils.data.DataLoader(dataset, batch_size=cfg.bs_train, shuffle=True, num_workers=cfg.nw_train, pin_memory=cfg.pin_memory)
|
53
52
|
# test dataloader
|
54
|
-
|
55
|
-
return torch.utils.data.DataLoader(dataset, batch_size=
|
53
|
+
else:
|
54
|
+
return torch.utils.data.DataLoader(dataset, batch_size=cfg.bs_test, shuffle=False, num_workers=cfg.nw_test,pin_memory=False)
|
56
55
|
|
57
56
|
|
58
57
|
# dataset_size_dict = {}
|
@@ -11,6 +11,7 @@ import warnings
|
|
11
11
|
import torch
|
12
12
|
from tqdm import tqdm
|
13
13
|
from spikezoo.utils.data_utils import Augmentor
|
14
|
+
from typing import Optional
|
14
15
|
|
15
16
|
|
16
17
|
@dataclass
|
@@ -30,31 +31,24 @@ class BaseDatasetConfig:
|
|
30
31
|
spike_length_train: int = -1
|
31
32
|
"Dataset spike length for the test data."
|
32
33
|
spike_length_test: int = -1
|
33
|
-
"Dataset spike length for the instantiation dataclass."
|
34
|
-
spike_length: int = -1
|
35
34
|
"Dir name for the spike."
|
36
35
|
spike_dir_name: str = "spike"
|
37
36
|
"Dir name for the image."
|
38
37
|
img_dir_name: str = "gt"
|
38
|
+
"Rate. (-1 denotes variant)"
|
39
|
+
rate: float = 0.6
|
39
40
|
|
40
41
|
# ------------- Config -------------
|
41
|
-
"Dataset split: train/test. Default set as the 'test' for evaluation."
|
42
|
-
split: Literal["train", "test"] = "test"
|
43
42
|
"Use the data augumentation technique or not."
|
44
43
|
use_aug: bool = False
|
45
44
|
"Use cache mechanism."
|
46
45
|
use_cache: bool = False
|
47
46
|
"Crop size."
|
48
47
|
crop_size: tuple = (-1, -1)
|
49
|
-
"
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
def __post_init__(self):
|
54
|
-
self.spike_length = self.spike_length_train if self.split == "train" else self.spike_length_test
|
55
|
-
self.root_dir = Path(self.root_dir) if isinstance(self.root_dir, str) else self.root_dir
|
56
|
-
# todo try download
|
57
|
-
assert self.root_dir.exists(), f"No files found in {self.root_dir} for the specified dataset `{self.dataset_name}`."
|
48
|
+
"Load the dataset from local or spikezoo lib."
|
49
|
+
dataset_cls_local: Optional[Dataset] = None
|
50
|
+
"Spike load version. [python,cpp]"
|
51
|
+
spike_load_version: Literal["python", "cpp"] = "python"
|
58
52
|
|
59
53
|
|
60
54
|
# todo cache mechanism
|
@@ -62,10 +56,6 @@ class BaseDataset(Dataset):
|
|
62
56
|
def __init__(self, cfg: BaseDatasetConfig):
|
63
57
|
super(BaseDataset, self).__init__()
|
64
58
|
self.cfg = cfg
|
65
|
-
self.augmentor = Augmentor(self.cfg.crop_size) if self.cfg.use_aug == True and self.cfg.split == "train" else -1
|
66
|
-
self.prepare_data()
|
67
|
-
self.cache_data() if cfg.use_cache == True else -1
|
68
|
-
warnings.warn("Lengths of the image list and the spike list should be equal.") if len(self.img_list) != len(self.spike_list) else -1
|
69
59
|
|
70
60
|
def __len__(self):
|
71
61
|
return len(self.spike_list)
|
@@ -80,7 +70,7 @@ class BaseDataset(Dataset):
|
|
80
70
|
img = self.get_img(idx)
|
81
71
|
|
82
72
|
# process data
|
83
|
-
if self.cfg.use_aug == True and self.
|
73
|
+
if self.cfg.use_aug == True and self.split == "train":
|
84
74
|
spike, img = self.augmentor(spike, img)
|
85
75
|
|
86
76
|
# rate
|
@@ -90,15 +80,29 @@ class BaseDataset(Dataset):
|
|
90
80
|
batch = {"spike": spike, "gt_img": img, "rate": rate}
|
91
81
|
return batch
|
92
82
|
|
83
|
+
def build_source(self, split: Literal["train", "test"] = "test"):
|
84
|
+
"""Build the dataset source and prepare to be loaded files."""
|
85
|
+
# spike length
|
86
|
+
self.split = split
|
87
|
+
self.spike_length = self.cfg.spike_length_train if self.split == "train" else self.cfg.spike_length_test
|
88
|
+
# root dir
|
89
|
+
self.cfg.root_dir = Path(self.cfg.root_dir) if isinstance(self.cfg.root_dir, str) else self.cfg.root_dir
|
90
|
+
assert self.cfg.root_dir.exists(), f"No files found in {self.cfg.root_dir} for the specified dataset `{self.cfg.dataset_name}`."
|
91
|
+
# prepare
|
92
|
+
self.augmentor = Augmentor(self.cfg.crop_size) if self.cfg.use_aug == True and self.split == "train" else -1
|
93
|
+
self.prepare_data()
|
94
|
+
self.cache_data() if self.cfg.use_cache == True else -1
|
95
|
+
warnings.warn("Lengths of the image list and the spike list should be equal.") if len(self.img_list) != len(self.spike_list) else -1
|
96
|
+
|
93
97
|
# todo: To be overridden
|
94
98
|
def prepare_data(self):
|
95
99
|
"""Specify the spike and image files to be loaded."""
|
96
100
|
# spike
|
97
|
-
self.spike_dir = self.cfg.root_dir / self.
|
101
|
+
self.spike_dir = self.cfg.root_dir / self.split / self.cfg.spike_dir_name
|
98
102
|
self.spike_list = self.get_spike_files(self.spike_dir)
|
99
103
|
# gt
|
100
104
|
if self.cfg.with_img == True:
|
101
|
-
self.img_dir = self.cfg.root_dir / self.
|
105
|
+
self.img_dir = self.cfg.root_dir / self.split / self.cfg.img_dir_name
|
102
106
|
self.img_list = self.get_image_files(self.img_dir)
|
103
107
|
|
104
108
|
# todo: To be overridden
|
@@ -116,12 +120,13 @@ class BaseDataset(Dataset):
|
|
116
120
|
height=self.cfg.height,
|
117
121
|
width=self.cfg.width,
|
118
122
|
out_format="tensor",
|
123
|
+
version=self.cfg.spike_load_version
|
119
124
|
)
|
120
125
|
return spike
|
121
126
|
|
122
127
|
def get_spike(self, idx):
|
123
128
|
"""Get and process the spike stream from the given idx."""
|
124
|
-
spike_length = self.
|
129
|
+
spike_length = self.spike_length
|
125
130
|
spike = self.load_spike(idx)
|
126
131
|
assert spike.shape[0] >= spike_length, f"Given spike length {spike.shape[0]} smaller than the required length {spike_length}"
|
127
132
|
spike_mid = spike.shape[0] // 2
|
@@ -4,21 +4,19 @@ from dataclasses import dataclass
|
|
4
4
|
|
5
5
|
|
6
6
|
@dataclass
|
7
|
-
class
|
8
|
-
dataset_name: str = "
|
9
|
-
root_dir: Path = Path(__file__).parent.parent / Path("data/
|
7
|
+
class RealDataConfig(BaseDatasetConfig):
|
8
|
+
dataset_name: str = "realdata"
|
9
|
+
root_dir: Path = Path(__file__).parent.parent / Path("data/realdata")
|
10
10
|
width: int = 400
|
11
11
|
height: int = 250
|
12
12
|
with_img: bool = False
|
13
13
|
spike_length_train: int = -1
|
14
14
|
spike_length_test: int = -1
|
15
15
|
rate: float = 1
|
16
|
-
|
17
16
|
|
18
|
-
|
19
|
-
class RealWorld(BaseDataset):
|
17
|
+
class RealData(BaseDataset):
|
20
18
|
def __init__(self, cfg: BaseDatasetConfig):
|
21
|
-
super(
|
19
|
+
super(RealData, self).__init__(cfg)
|
22
20
|
|
23
21
|
def prepare_data(self):
|
24
22
|
self.spike_dir = self.cfg.root_dir
|
@@ -8,7 +8,7 @@ import re
|
|
8
8
|
@dataclass
|
9
9
|
class REDS_BASEConfig(BaseDatasetConfig):
|
10
10
|
dataset_name: str = "reds_base"
|
11
|
-
root_dir: Path = Path(__file__).parent.parent / Path("data/
|
11
|
+
root_dir: Path = Path(__file__).parent.parent / Path("data/reds_base")
|
12
12
|
width: int = 400
|
13
13
|
height: int = 250
|
14
14
|
with_img: bool = True
|
@@ -9,7 +9,7 @@ import numpy as np
|
|
9
9
|
@dataclass
|
10
10
|
class SZDataConfig(BaseDatasetConfig):
|
11
11
|
dataset_name: str = "szdata"
|
12
|
-
root_dir: Path = Path(__file__).parent.parent / Path("data/
|
12
|
+
root_dir: Path = Path(__file__).parent.parent / Path("data/szdata")
|
13
13
|
width: int = 400
|
14
14
|
height: int = 250
|
15
15
|
with_img: bool = True
|
@@ -22,7 +22,3 @@ class SZDataConfig(BaseDatasetConfig):
|
|
22
22
|
class SZData(BaseDataset):
|
23
23
|
def __init__(self, cfg: BaseDatasetConfig):
|
24
24
|
super(SZData, self).__init__(cfg)
|
25
|
-
|
26
|
-
def prepare_data(self):
|
27
|
-
super().prepare_data()
|
28
|
-
self.img_list = [self.img_dir / Path(str(s.name).replace('.dat','.png')) for s in self.spike_list]
|
@@ -7,7 +7,7 @@ import torch
|
|
7
7
|
@dataclass
|
8
8
|
class UHSRConfig(BaseDatasetConfig):
|
9
9
|
dataset_name: str = "uhsr"
|
10
|
-
root_dir: Path = Path(__file__).parent.parent / Path("data/
|
10
|
+
root_dir: Path = Path(__file__).parent.parent / Path("data/u_caltech")
|
11
11
|
width: int = 224
|
12
12
|
height: int = 224
|
13
13
|
with_img: bool = False
|
@@ -0,0 +1,23 @@
|
|
1
|
+
from torch.utils.data import Dataset
|
2
|
+
from pathlib import Path
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from typing import Literal, Union
|
5
|
+
from typing import Optional
|
6
|
+
from spikezoo.datasets.base_dataset import BaseDatasetConfig,BaseDataset
|
7
|
+
|
8
|
+
@dataclass
|
9
|
+
class YourDatasetConfig(BaseDatasetConfig):
|
10
|
+
dataset_name: str = "yourdataset"
|
11
|
+
root_dir: Union[str, Path] = Path(__file__).parent.parent / Path("data/your_data_path")
|
12
|
+
width: int = 400
|
13
|
+
height: int = 250
|
14
|
+
with_img: bool = True
|
15
|
+
spike_length_train: int = -1
|
16
|
+
spike_length_test: int = -1
|
17
|
+
spike_dir_name: str = "spike_data"
|
18
|
+
img_dir_name: str = "sharp_data"
|
19
|
+
rate: float = 1
|
20
|
+
|
21
|
+
class YourDataset(BaseDataset):
|
22
|
+
def __init__(self, cfg: BaseDatasetConfig):
|
23
|
+
super(YourDataset, self).__init__(cfg)
|
spikezoo/models/__init__.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
import importlib
|
2
2
|
import inspect
|
3
3
|
from spikezoo.models.base_model import BaseModel,BaseModelConfig
|
4
|
+
|
4
5
|
from spikezoo.utils.other_utils import getattr_case_insensitive
|
5
6
|
import os
|
6
7
|
from pathlib import Path
|
@@ -13,14 +14,17 @@ model_list = [file.split("_")[0] for file in files_list if file.endswith("_model
|
|
13
14
|
def build_model_cfg(cfg: BaseModelConfig):
|
14
15
|
"""Build the model from the given model config."""
|
15
16
|
# model module name
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
17
|
+
if cfg.model_cls_local == None:
|
18
|
+
module_name = cfg.model_name + "_model"
|
19
|
+
assert cfg.model_name in model_list, f"Given model {cfg.model_name} not in our model zoo {model_list}."
|
20
|
+
module_name = "spikezoo.models." + module_name
|
21
|
+
module = importlib.import_module(module_name)
|
22
|
+
# model,model_config
|
23
|
+
model_name = cfg.model_name
|
24
|
+
model_name = model_name + 'Model' if model_name == "base" else model_name
|
25
|
+
model_cls: BaseModel = getattr_case_insensitive(module,model_name)
|
26
|
+
else:
|
27
|
+
model_cls: BaseModel = cfg.model_cls_local
|
24
28
|
model = model_cls(cfg)
|
25
29
|
return model
|
26
30
|
|
spikezoo/models/base_model.py
CHANGED
@@ -44,7 +44,10 @@ class BaseModelConfig:
|
|
44
44
|
multi_gpu: bool = False
|
45
45
|
"Base url."
|
46
46
|
base_url: str = "https://github.com/chenkang455/Spike-Zoo/releases/download"
|
47
|
-
|
47
|
+
"Load the model from local class or spikezoo lib. (None)"
|
48
|
+
model_cls_local: Optional[nn.Module] = None
|
49
|
+
"Load the arch from local class or spikezoo lib. (None)"
|
50
|
+
arch_cls_local: Optional[nn.Module] = None
|
48
51
|
|
49
52
|
class BaseModel(nn.Module):
|
50
53
|
def __init__(self, cfg: BaseModelConfig):
|
@@ -71,8 +74,11 @@ class BaseModel(nn.Module):
|
|
71
74
|
):
|
72
75
|
"""Build the network and load the pretrained weight."""
|
73
76
|
# network
|
74
|
-
|
75
|
-
|
77
|
+
if self.cfg.arch_cls_local == None:
|
78
|
+
module = importlib.import_module(f"spikezoo.archs.{self.cfg.model_name}.{self.cfg.model_file_name}")
|
79
|
+
model_cls = getattr(module, self.cfg.model_cls_name)
|
80
|
+
else:
|
81
|
+
model_cls = self.cfg.arch_cls_local
|
76
82
|
# load model config parameters
|
77
83
|
if version == "local":
|
78
84
|
model = model_cls(**self.cfg.model_params)
|
@@ -129,7 +135,7 @@ class BaseModel(nn.Module):
|
|
129
135
|
"""Crop the spike length."""
|
130
136
|
spike_length = spike.shape[1]
|
131
137
|
spike_mid = spike_length // 2
|
132
|
-
assert spike_length >= self.model_length, f"Spike input is not long enough, given {spike_length} frames < {self.cfg.model_length}."
|
138
|
+
assert spike_length >= self.model_length, f"Spike input is not long enough, given {spike_length} frames < {self.cfg.model_length} required by the {self.cfg.model_name}."
|
133
139
|
# even length
|
134
140
|
if self.model_length == self.model_half_length * 2:
|
135
141
|
spike = spike[
|
spikezoo/models/bsf_model.py
CHANGED
@@ -4,7 +4,6 @@ from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
|
4
4
|
from torch.optim import Adam
|
5
5
|
import torch.optim.lr_scheduler as lr_scheduler
|
6
6
|
import torch.nn as nn
|
7
|
-
from spikezoo.pipeline import TrainPipelineConfig
|
8
7
|
from typing import List
|
9
8
|
from spikezoo.archs.bsf.models.bsf.bsf import BSF
|
10
9
|
|
@@ -1,7 +1,6 @@
|
|
1
1
|
import torch
|
2
2
|
from dataclasses import dataclass, field
|
3
3
|
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
4
|
-
from spikezoo.pipeline import TrainPipelineConfig
|
5
4
|
import torch.nn as nn
|
6
5
|
import torch.optim as optim
|
7
6
|
import torch.optim.lr_scheduler as lr_scheduler
|
spikezoo/models/stir_model.py
CHANGED
@@ -4,7 +4,6 @@ from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
|
4
4
|
from torch.optim import Adam
|
5
5
|
import torch.optim.lr_scheduler as lr_scheduler
|
6
6
|
import torch.nn as nn
|
7
|
-
from spikezoo.pipeline import TrainPipelineConfig
|
8
7
|
from typing import List
|
9
8
|
from spikezoo.archs.stir.metrics.losses import compute_per_loss_single
|
10
9
|
from spikezoo.archs.stir.models.Vgg19 import Vgg19
|
spikezoo/models/wgse_model.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1
1
|
from dataclasses import dataclass, field
|
2
2
|
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
3
3
|
from typing import List
|
4
|
-
from spikezoo.pipeline import TrainPipelineConfig
|
5
4
|
import torch.nn as nn
|
6
5
|
import torch.optim as optim
|
7
6
|
import torch.optim.lr_scheduler as lr_scheduler
|
@@ -0,0 +1,22 @@
|
|
1
|
+
from torch.utils.data import Dataset
|
2
|
+
from pathlib import Path
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from typing import Literal, Union
|
5
|
+
from typing import Optional
|
6
|
+
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
7
|
+
from dataclasses import field
|
8
|
+
import torch.nn as nn
|
9
|
+
|
10
|
+
|
11
|
+
@dataclass
|
12
|
+
class YourModelConfig(BaseModelConfig):
|
13
|
+
model_name: str = "yourmodel" # 需与文件名保持一致
|
14
|
+
model_file_name: str = "arch.net" # archs路径下的模块路径
|
15
|
+
model_cls_name: str = "YourNet" # 模型类名
|
16
|
+
model_length: int = 41
|
17
|
+
require_params: bool = True
|
18
|
+
model_params: dict = field(default_factory=lambda: {"inDim": 41})
|
19
|
+
|
20
|
+
class YourModel(BaseModel):
|
21
|
+
def __init__(self, cfg: BaseModelConfig):
|
22
|
+
super(YourModel, self).__init__(cfg)
|
@@ -34,11 +34,17 @@ class PipelineConfig:
|
|
34
34
|
"Evaluate metrics or not."
|
35
35
|
save_metric: bool = True
|
36
36
|
"Metric names for evaluation."
|
37
|
-
metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim"])
|
37
|
+
metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim", "niqe", "brisque"])
|
38
38
|
"Save recoverd images or not."
|
39
39
|
save_img: bool = True
|
40
40
|
"Normalizing recoverd images and gt or not."
|
41
|
-
|
41
|
+
img_norm: bool = False
|
42
|
+
"Batch size for the test dataloader."
|
43
|
+
bs_test: int = 1
|
44
|
+
"Num_workers for the test dataloader."
|
45
|
+
nw_test: int = 0
|
46
|
+
"Pin_memory true or false for the dataloader."
|
47
|
+
pin_memory: bool = False
|
42
48
|
"Different modes for the pipeline."
|
43
49
|
_mode: Literal["single_mode", "multi_mode", "train_mode"] = "single_mode"
|
44
50
|
|
@@ -63,7 +69,8 @@ class Pipeline:
|
|
63
69
|
torch.set_grad_enabled(False)
|
64
70
|
# dataset
|
65
71
|
self.dataset: BaseDataset = build_dataset_name(dataset_cfg) if isinstance(dataset_cfg, str) else build_dataset_cfg(dataset_cfg)
|
66
|
-
self.
|
72
|
+
self.dataset.build_source(split="test")
|
73
|
+
self.dataloader = build_dataloader(self.dataset,self.cfg)
|
67
74
|
# device
|
68
75
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
69
76
|
|
@@ -103,7 +110,7 @@ class Pipeline:
|
|
103
110
|
"""Function I---Save the recoverd image and calculate the metric from the given dataset."""
|
104
111
|
# save folder
|
105
112
|
self.logger.info("*********************** infer_from_dataset ***********************")
|
106
|
-
save_folder = self.save_folder / Path(f"infer_from_dataset/{self.dataset.cfg.dataset_name}_dataset/{self.dataset.
|
113
|
+
save_folder = self.save_folder / Path(f"infer_from_dataset/{self.dataset.cfg.dataset_name}_dataset/{self.dataset.split}/{idx:06d}")
|
107
114
|
os.makedirs(str(save_folder), exist_ok=True)
|
108
115
|
|
109
116
|
# data process
|
@@ -117,7 +124,7 @@ class Pipeline:
|
|
117
124
|
img = None
|
118
125
|
return self.infer(spike, img, save_folder, rate)
|
119
126
|
|
120
|
-
def infer_from_file(self, file_path, height=-1, width=-1,
|
127
|
+
def infer_from_file(self, file_path, height=-1, width=-1, rate=1, img_path=None, remove_head=False):
|
121
128
|
"""Function II---Save the recoverd image and calculate the metric from the given input file."""
|
122
129
|
# save folder
|
123
130
|
self.logger.info("*********************** infer_from_file ***********************")
|
@@ -144,7 +151,7 @@ class Pipeline:
|
|
144
151
|
spike = torch.from_numpy(spike)[None].to(self.device)
|
145
152
|
return self.infer(spike, img, save_folder, rate)
|
146
153
|
|
147
|
-
def infer_from_spk(self, spike,
|
154
|
+
def infer_from_spk(self, spike, rate=1, img=None):
|
148
155
|
"""Function III---Save the recoverd image and calculate the metric from the given spike stream."""
|
149
156
|
# save folder
|
150
157
|
self.logger.info("*********************** infer_from_spk ***********************")
|
@@ -181,7 +188,7 @@ class Pipeline:
|
|
181
188
|
for idx in range(len(self.dataset)):
|
182
189
|
self.infer_from_dataset(idx=idx)
|
183
190
|
self.cfg.save_metric = base_setting
|
184
|
-
|
191
|
+
|
185
192
|
# TODO: To be overridden
|
186
193
|
def cal_params(self):
|
187
194
|
"""Function VI---Calculate the parameters/flops/latency of the given method."""
|
@@ -228,8 +235,8 @@ class Pipeline:
|
|
228
235
|
# With no GT
|
229
236
|
if recon_img == None:
|
230
237
|
return None
|
231
|
-
#
|
232
|
-
if model_name in ["
|
238
|
+
# spikeclip is normalized automatically
|
239
|
+
if model_name in ["spikeclip"] or self.cfg.img_norm == True:
|
233
240
|
recon_img = (recon_img - recon_img.min()) / (recon_img.max() - recon_img.min())
|
234
241
|
else:
|
235
242
|
recon_img = recon_img / rate
|
@@ -253,7 +260,7 @@ class Pipeline:
|
|
253
260
|
batch = model.feed_to_device(batch)
|
254
261
|
outputs = model.get_outputs_dict(batch)
|
255
262
|
recon_img, img = model.get_paired_imgs(batch, outputs)
|
256
|
-
recon_img, img = self._post_process_img(recon_img, model_name), self._post_process_img(img, "
|
263
|
+
recon_img, img = self._post_process_img(recon_img, model_name), self._post_process_img(img, "gt")
|
257
264
|
for metric_name in metrics_dict.keys():
|
258
265
|
if metric_name in metric_pair_names:
|
259
266
|
metrics_dict[metric_name].update(cal_metric_pair(recon_img, img, metric_name))
|
@@ -48,7 +48,8 @@ class EnsemblePipeline(Pipeline):
|
|
48
48
|
torch.set_grad_enabled(False)
|
49
49
|
# data
|
50
50
|
self.dataset: BaseDataset = build_dataset_name(dataset_cfg) if isinstance(dataset_cfg, str) else build_dataset_cfg(dataset_cfg)
|
51
|
-
self.
|
51
|
+
self.dataset.build_source(split = "test")
|
52
|
+
self.dataloader = build_dataloader(self.dataset,self.cfg)
|
52
53
|
# device
|
53
54
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
54
55
|
|
spikezoo/pipeline/train_cfgs.py
CHANGED
@@ -22,7 +22,7 @@ class REDS_BASE_TrainConfig(TrainPipelineConfig):
|
|
22
22
|
|
23
23
|
# dataloader setting
|
24
24
|
bs_train: int = 8
|
25
|
-
|
25
|
+
nw_train: int = 4
|
26
26
|
pin_memory: bool = False
|
27
27
|
|
28
28
|
# train setting - optimizer & scheduler & loss_dict
|
@@ -30,6 +30,8 @@ class REDS_BASE_TrainConfig(TrainPipelineConfig):
|
|
30
30
|
scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[400], gamma=0.2) # from wgse
|
31
31
|
loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
|
32
32
|
|
33
|
+
|
34
|
+
|
33
35
|
# ! Train Config for each method on the official setting, not recommended to utilize their default parameters owing to the dataset setting.
|
34
36
|
@dataclass
|
35
37
|
class BSFTrainConfig(TrainPipelineConfig):
|