spikezoo 0.2.3.4__py3-none-any.whl → 0.2.3.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
- spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
- spikezoo/archs/stir/metrics/__pycache__/losses.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/Vgg19.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/geometry.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/image_proc.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/losses.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
- spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/yourmodel/arch/__pycache__/net.cpython-39.pyc +0 -0
- spikezoo/archs/yourmodel/arch/net.py +35 -0
- spikezoo/datasets/__init__.py +20 -21
- spikezoo/datasets/base_dataset.py +26 -21
- spikezoo/datasets/{realworld_dataset.py → realdata_dataset.py} +5 -7
- spikezoo/datasets/reds_base_dataset.py +1 -1
- spikezoo/datasets/szdata_dataset.py +1 -5
- spikezoo/datasets/uhsr_dataset.py +1 -1
- spikezoo/datasets/yourdataset_dataset.py +23 -0
- spikezoo/models/__init__.py +12 -8
- spikezoo/models/base_model.py +10 -4
- spikezoo/models/bsf_model.py +0 -1
- spikezoo/models/spk2imgnet_model.py +0 -1
- spikezoo/models/stir_model.py +0 -1
- spikezoo/models/wgse_model.py +0 -1
- spikezoo/models/yourmodel_model.py +22 -0
- spikezoo/pipeline/base_pipeline.py +17 -10
- spikezoo/pipeline/ensemble_pipeline.py +2 -1
- spikezoo/pipeline/train_cfgs.py +3 -1
- spikezoo/pipeline/train_pipeline.py +12 -12
- spikezoo/utils/spike_utils.py +2 -2
- spikezoo-0.2.3.6.dist-info/METADATA +151 -0
- {spikezoo-0.2.3.4.dist-info → spikezoo-0.2.3.6.dist-info}/RECORD +53 -23
- spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
- spikezoo-0.2.3.4.dist-info/METADATA +0 -259
- {spikezoo-0.2.3.4.dist-info → spikezoo-0.2.3.6.dist-info}/LICENSE.txt +0 -0
- {spikezoo-0.2.3.4.dist-info → spikezoo-0.2.3.6.dist-info}/WHEEL +0 -0
- {spikezoo-0.2.3.4.dist-info → spikezoo-0.2.3.6.dist-info}/top_level.txt +0 -0
spikezoo/archs/yourmodel/arch/net.py
ADDED
@@ -0,0 +1,35 @@
+import torch.nn as nn
+
+def conv_layer(inDim, outDim, ks, s, p, norm_layer="none"):
+    ## convolutional layer
+    conv = nn.Conv2d(inDim, outDim, kernel_size=ks, stride=s, padding=p)
+    relu = nn.ReLU(True)
+    assert norm_layer in ("batch", "instance", "none")
+    if norm_layer == "none":
+        seq = nn.Sequential(*[conv, relu])
+    else:
+        if norm_layer == "instance":
+            norm = nn.InstanceNorm2d(outDim, affine=False, track_running_stats=False)  # instance norm
+        else:
+            momentum = 0.1
+            norm = nn.BatchNorm2d(outDim, momentum=momentum, affine=True, track_running_stats=True)
+        seq = nn.Sequential(*[conv, norm, relu])
+    return seq
+
+
+class YourNet(nn.Module):
+    """Borrow the structure from the SpikeCLIP. (https://arxiv.org/abs/2501.04477)"""
+
+    def __init__(self, inDim=41):
+        super(YourNet, self).__init__()
+        norm = "none"
+        outDim = 1
+        convBlock1 = conv_layer(inDim, 64, 3, 1, 1)
+        convBlock2 = conv_layer(64, 128, 3, 1, 1, norm)
+        convBlock3 = conv_layer(128, 64, 3, 1, 1, norm)
+        convBlock4 = conv_layer(64, 16, 3, 1, 1, norm)
+        conv = nn.Conv2d(16, outDim, 3, 1, 1)
+        self.seq = nn.Sequential(*[convBlock1, convBlock2, convBlock3, convBlock4, conv])
+
+    def forward(self, x):
+        return self.seq(x)
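To sanity-check the template added above, a single forward pass is enough. The snippet below is a minimal sketch assuming the wheel's module layout (spikezoo.archs.yourmodel.arch.net) is importable as-is:

import torch
from spikezoo.archs.yourmodel.arch.net import YourNet

# One 41-frame spike clip at the 250x400 resolution used by the bundled datasets;
# all conv blocks are stride-1 with padding 1, so spatial size is preserved.
net = YourNet(inDim=41)
spike = torch.randn(1, 41, 250, 400)
out = net(spike)
print(out.shape)  # expected: torch.Size([1, 1, 250, 400]) -- one reconstructed image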
spikezoo/datasets/__init__.py
CHANGED
@@ -1,4 +1,5 @@
 from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
+
 from dataclasses import replace
 import importlib, inspect
 import os
@@ -12,23 +13,24 @@ dataset_list = [file.replace("_dataset.py", "") for file in files_list if file.e


 # todo register function
-def build_dataset_cfg(cfg: BaseDatasetConfig, split: Literal["train", "test"] = "test"):
+def build_dataset_cfg(cfg: BaseDatasetConfig):
     """Build the dataset from the given dataset config."""
-    # build new cfg according to split
-    cfg = replace(cfg, split=split, spike_length=cfg.spike_length_train if split == "train" else cfg.spike_length_test)
     # dataset module
-    module_name = cfg.dataset_name + "_dataset"
-    assert cfg.dataset_name in dataset_list, f"Given dataset {cfg.dataset_name} not in our dataset list {dataset_list}."
-    module_name = "spikezoo.datasets." + module_name
-    module = importlib.import_module(module_name)
-    # dataset,dataset_config
-    dataset_name = cfg.dataset_name
-    dataset_name = dataset_name + "Dataset" if dataset_name == "base" else dataset_name
-    dataset_cls: BaseDataset = getattr_case_insensitive(module, dataset_name)
+    if cfg.dataset_cls_local == None:
+        module_name = cfg.dataset_name + "_dataset"
+        assert cfg.dataset_name in dataset_list, f"Given dataset {cfg.dataset_name} not in our dataset list {dataset_list}."
+        module_name = "spikezoo.datasets." + module_name
+        module = importlib.import_module(module_name)
+        # dataset,dataset_config
+        dataset_name = cfg.dataset_name
+        dataset_name = dataset_name + "Dataset" if dataset_name == "base" else dataset_name
+        dataset_cls: BaseDataset = getattr_case_insensitive(module, dataset_name)
+    else:
+        dataset_cls = cfg.dataset_cls_local
     dataset = dataset_cls(cfg)
     return dataset

-def build_dataset_name(dataset_name: str, split: Literal["train", "test"] = "test"):
+def build_dataset_name(dataset_name: str):
     """Build the default dataset from the given name."""
     module_name = dataset_name + "_dataset"
     assert dataset_name in dataset_list, f"Given dataset {dataset_name} not in our dataset list {dataset_list}."
@@ -37,22 +39,19 @@ def build_dataset_name(dataset_name: str, split: Literal["train", "test"] = "tes
     # dataset,dataset_config
     dataset_name = dataset_name + "Dataset" if dataset_name == "base" else dataset_name
     dataset_cls: BaseDataset = getattr_case_insensitive(module, dataset_name)
-    dataset_cfg: BaseDatasetConfig = getattr_case_insensitive(module, dataset_name + "config")(split=split)
+    dataset_cfg: BaseDatasetConfig = getattr_case_insensitive(module, dataset_name + "config")()
     dataset = dataset_cls(dataset_cfg)
     return dataset


 # todo to modify according to the basicsr
-def build_dataloader(dataset
+def build_dataloader(dataset, cfg):
     # train dataloader
-    if dataset.
-
-        return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
-    else:
-        return torch.utils.data.DataLoader(dataset, batch_size=cfg.bs_train, shuffle=True, num_workers=cfg.num_workers, pin_memory=cfg.pin_memory)
+    if dataset.split == "train" and cfg._mode == "train_mode":
+        return torch.utils.data.DataLoader(dataset, batch_size=cfg.bs_train, shuffle=True, num_workers=cfg.nw_train, pin_memory=cfg.pin_memory)
     # test dataloader
-
-    return torch.utils.data.DataLoader(dataset, batch_size=
+    else:
+        return torch.utils.data.DataLoader(dataset, batch_size=cfg.bs_test, shuffle=False, num_workers=cfg.nw_test, pin_memory=False)


 # dataset_size_dict = {}
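Taken together, the new flow defers the train/test split to BaseDataset.build_source and lets a locally defined dataset class bypass the module lookup. A minimal sketch of the local-class path (MyDataset is a hypothetical placeholder):

from dataclasses import replace
from spikezoo.datasets import build_dataset_cfg
from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig

class MyDataset(BaseDataset):  # hypothetical local dataset
    def __init__(self, cfg: BaseDatasetConfig):
        super(MyDataset, self).__init__(cfg)

# dataset_cls_local short-circuits the importlib lookup inside build_dataset_cfg.
cfg = replace(BaseDatasetConfig(), dataset_cls_local=MyDataset)
dataset = build_dataset_cfg(cfg)
dataset.build_source(split="test")  # the split is now chosen here, not in the config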
spikezoo/datasets/base_dataset.py
CHANGED
@@ -11,6 +11,7 @@ import warnings
 import torch
 from tqdm import tqdm
 from spikezoo.utils.data_utils import Augmentor
+from typing import Optional


 @dataclass
@@ -30,31 +31,24 @@ class BaseDatasetConfig:
     spike_length_train: int = -1
     "Dataset spike length for the test data."
     spike_length_test: int = -1
-    "Dataset spike length for the instantiation dataclass."
-    spike_length: int = -1
     "Dir name for the spike."
     spike_dir_name: str = "spike"
     "Dir name for the image."
     img_dir_name: str = "gt"
+    "Rate. (-1 denotes variant)"
+    rate: float = 0.6

     # ------------- Config -------------
-    "Dataset split: train/test. Default set as the 'test' for evaluation."
-    split: Literal["train", "test"] = "test"
     "Use the data augumentation technique or not."
     use_aug: bool = False
     "Use cache mechanism."
     use_cache: bool = False
     "Crop size."
     crop_size: tuple = (-1, -1)
-    "
-
-
-
-    def __post_init__(self):
-        self.spike_length = self.spike_length_train if self.split == "train" else self.spike_length_test
-        self.root_dir = Path(self.root_dir) if isinstance(self.root_dir, str) else self.root_dir
-        # todo try download
-        assert self.root_dir.exists(), f"No files found in {self.root_dir} for the specified dataset `{self.dataset_name}`."
+    "Load the dataset from local or spikezoo lib."
+    dataset_cls_local: Optional[Dataset] = None
+    "Spike load version. [python,cpp]"
+    spike_load_version: Literal["python", "cpp"] = "python"


 # todo cache mechanism
@@ -62,10 +56,6 @@ class BaseDataset(Dataset):
     def __init__(self, cfg: BaseDatasetConfig):
         super(BaseDataset, self).__init__()
         self.cfg = cfg
-        self.augmentor = Augmentor(self.cfg.crop_size) if self.cfg.use_aug == True and self.cfg.split == "train" else -1
-        self.prepare_data()
-        self.cache_data() if cfg.use_cache == True else -1
-        warnings.warn("Lengths of the image list and the spike list should be equal.") if len(self.img_list) != len(self.spike_list) else -1

     def __len__(self):
         return len(self.spike_list)
@@ -80,7 +70,7 @@ class BaseDataset(Dataset):
         img = self.get_img(idx)

         # process data
-        if self.cfg.use_aug == True and self.cfg.split == "train":
+        if self.cfg.use_aug == True and self.split == "train":
             spike, img = self.augmentor(spike, img)

         # rate
@@ -90,15 +80,29 @@ class BaseDataset(Dataset):
         batch = {"spike": spike, "gt_img": img, "rate": rate}
         return batch

+    def build_source(self, split: Literal["train", "test"] = "test"):
+        """Build the dataset source and prepare to be loaded files."""
+        # spike length
+        self.split = split
+        self.spike_length = self.cfg.spike_length_train if self.split == "train" else self.cfg.spike_length_test
+        # root dir
+        self.cfg.root_dir = Path(self.cfg.root_dir) if isinstance(self.cfg.root_dir, str) else self.cfg.root_dir
+        assert self.cfg.root_dir.exists(), f"No files found in {self.cfg.root_dir} for the specified dataset `{self.cfg.dataset_name}`."
+        # prepare
+        self.augmentor = Augmentor(self.cfg.crop_size) if self.cfg.use_aug == True and self.split == "train" else -1
+        self.prepare_data()
+        self.cache_data() if self.cfg.use_cache == True else -1
+        warnings.warn("Lengths of the image list and the spike list should be equal.") if len(self.img_list) != len(self.spike_list) else -1
+
     # todo: To be overridden
     def prepare_data(self):
         """Specify the spike and image files to be loaded."""
         # spike
-        self.spike_dir = self.cfg.root_dir / self.cfg.split / self.cfg.spike_dir_name
+        self.spike_dir = self.cfg.root_dir / self.split / self.cfg.spike_dir_name
         self.spike_list = self.get_spike_files(self.spike_dir)
         # gt
         if self.cfg.with_img == True:
-            self.img_dir = self.cfg.root_dir / self.cfg.split / self.cfg.img_dir_name
+            self.img_dir = self.cfg.root_dir / self.split / self.cfg.img_dir_name
             self.img_list = self.get_image_files(self.img_dir)

     # todo: To be overridden
@@ -116,12 +120,13 @@ class BaseDataset(Dataset):
             height=self.cfg.height,
             width=self.cfg.width,
             out_format="tensor",
+            version=self.cfg.spike_load_version
         )
         return spike

     def get_spike(self, idx):
         """Get and process the spike stream from the given idx."""
-        spike_length = self.cfg.spike_length
+        spike_length = self.spike_length
         spike = self.load_spike(idx)
         assert spike.shape[0] >= spike_length, f"Given spike length {spike.shape[0]} smaller than the required length {spike_length}"
         spike_mid = spike.shape[0] // 2
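The net effect on base_dataset.py is a two-step lifecycle: construction only stores the config, and build_source resolves all split-dependent state. A hedged sketch of the sequence:

from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig

cfg = BaseDatasetConfig()             # spike_length_train/test both live in the config now
dataset = BaseDataset(cfg)            # no file I/O yet; __init__ only stores cfg
dataset.build_source(split="train")   # picks spike_length_train, checks root_dir, builds file lists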
spikezoo/datasets/{realworld_dataset.py → realdata_dataset.py}
RENAMED
@@ -4,21 +4,19 @@ from dataclasses import dataclass


 @dataclass
-class RealWorldConfig(BaseDatasetConfig):
-    dataset_name: str = "realworld"
-    root_dir: Path = Path(__file__).parent.parent / Path("data/realworld")
+class RealDataConfig(BaseDatasetConfig):
+    dataset_name: str = "realdata"
+    root_dir: Path = Path(__file__).parent.parent / Path("data/realdata")
     width: int = 400
     height: int = 250
     with_img: bool = False
     spike_length_train: int = -1
     spike_length_test: int = -1
     rate: float = 1
-

-
-class RealWorld(BaseDataset):
+class RealData(BaseDataset):
     def __init__(self, cfg: BaseDatasetConfig):
-        super(RealWorld, self).__init__(cfg)
+        super(RealData, self).__init__(cfg)

     def prepare_data(self):
         self.spike_dir = self.cfg.root_dir
spikezoo/datasets/reds_base_dataset.py
CHANGED
@@ -8,7 +8,7 @@ import re

 @dataclass
 class REDS_BASEConfig(BaseDatasetConfig):
     dataset_name: str = "reds_base"
-    root_dir: Path = Path(__file__).parent.parent / Path("data/
+    root_dir: Path = Path(__file__).parent.parent / Path("data/reds_base")
     width: int = 400
     height: int = 250
     with_img: bool = True
spikezoo/datasets/szdata_dataset.py
CHANGED
@@ -9,7 +9,7 @@ import numpy as np

 @dataclass
 class SZDataConfig(BaseDatasetConfig):
     dataset_name: str = "szdata"
-    root_dir: Path = Path(__file__).parent.parent / Path("data/
+    root_dir: Path = Path(__file__).parent.parent / Path("data/szdata")
     width: int = 400
     height: int = 250
     with_img: bool = True
@@ -22,7 +22,3 @@ class SZDataConfig(BaseDatasetConfig):
 class SZData(BaseDataset):
     def __init__(self, cfg: BaseDatasetConfig):
         super(SZData, self).__init__(cfg)
-
-    def prepare_data(self):
-        super().prepare_data()
-        self.img_list = [self.img_dir / Path(str(s.name).replace('.dat','.png')) for s in self.spike_list]
spikezoo/datasets/uhsr_dataset.py
CHANGED
@@ -7,7 +7,7 @@ import torch

 @dataclass
 class UHSRConfig(BaseDatasetConfig):
     dataset_name: str = "uhsr"
-    root_dir: Path = Path(__file__).parent.parent / Path("data/
+    root_dir: Path = Path(__file__).parent.parent / Path("data/u_caltech")
     width: int = 224
     height: int = 224
     with_img: bool = False
spikezoo/datasets/yourdataset_dataset.py
ADDED
@@ -0,0 +1,23 @@
+from torch.utils.data import Dataset
+from pathlib import Path
+from dataclasses import dataclass
+from typing import Literal, Union
+from typing import Optional
+from spikezoo.datasets.base_dataset import BaseDatasetConfig,BaseDataset
+
+@dataclass
+class YourDatasetConfig(BaseDatasetConfig):
+    dataset_name: str = "yourdataset"
+    root_dir: Union[str, Path] = Path(__file__).parent.parent / Path("data/your_data_path")
+    width: int = 400
+    height: int = 250
+    with_img: bool = True
+    spike_length_train: int = -1
+    spike_length_test: int = -1
+    spike_dir_name: str = "spike_data"
+    img_dir_name: str = "sharp_data"
+    rate: float = 1
+
+class YourDataset(BaseDataset):
+    def __init__(self, cfg: BaseDatasetConfig):
+        super(YourDataset, self).__init__(cfg)
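The dataset template is discovered by name through the same factory as the built-in datasets. A sketch with a placeholder directory layout:

from spikezoo.datasets import build_dataset_cfg
from spikezoo.datasets.yourdataset_dataset import YourDatasetConfig

# BaseDataset.prepare_data expects <root_dir>/<split>/spike_data and
# <root_dir>/<split>/sharp_data; "path/to/your_data" is a placeholder.
cfg = YourDatasetConfig(root_dir="path/to/your_data")
dataset = build_dataset_cfg(cfg)
dataset.build_source(split="test")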
spikezoo/models/__init__.py
CHANGED
@@ -1,6 +1,7 @@
 import importlib
 import inspect
 from spikezoo.models.base_model import BaseModel,BaseModelConfig
+
 from spikezoo.utils.other_utils import getattr_case_insensitive
 import os
 from pathlib import Path
@@ -13,14 +14,17 @@ model_list = [file.split("_")[0] for file in files_list if file.endswith("_model
 def build_model_cfg(cfg: BaseModelConfig):
     """Build the model from the given model config."""
     # model module name
-    module_name = cfg.model_name + "_model"
-    assert cfg.model_name in model_list, f"Given model {cfg.model_name} not in our model zoo {model_list}."
-    module_name = "spikezoo.models." + module_name
-    module = importlib.import_module(module_name)
-    # model,model_config
-    model_name = cfg.model_name
-    model_name = model_name + 'Model' if model_name == "base" else model_name
-    model_cls: BaseModel = getattr_case_insensitive(module,model_name)
+    if cfg.model_cls_local == None:
+        module_name = cfg.model_name + "_model"
+        assert cfg.model_name in model_list, f"Given model {cfg.model_name} not in our model zoo {model_list}."
+        module_name = "spikezoo.models." + module_name
+        module = importlib.import_module(module_name)
+        # model,model_config
+        model_name = cfg.model_name
+        model_name = model_name + 'Model' if model_name == "base" else model_name
+        model_cls: BaseModel = getattr_case_insensitive(module,model_name)
+    else:
+        model_cls: BaseModel = cfg.model_cls_local
     model = model_cls(cfg)
     return model

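build_model_cfg mirrors the dataset side: model_cls_local short-circuits the name-based lookup. A minimal sketch (MyModel is a hypothetical placeholder):

from dataclasses import replace
from spikezoo.models import build_model_cfg
from spikezoo.models.base_model import BaseModel, BaseModelConfig

class MyModel(BaseModel):  # hypothetical local model wrapper
    def __init__(self, cfg: BaseModelConfig):
        super(MyModel, self).__init__(cfg)

cfg = replace(BaseModelConfig(), model_cls_local=MyModel)
model = build_model_cfg(cfg)  # uses MyModel directly, skipping the spikezoo.models.* lookup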
spikezoo/models/base_model.py
CHANGED
@@ -44,7 +44,10 @@ class BaseModelConfig:
     multi_gpu: bool = False
     "Base url."
     base_url: str = "https://github.com/chenkang455/Spike-Zoo/releases/download"
-
+    "Load the model from local class or spikezoo lib. (None)"
+    model_cls_local: Optional[nn.Module] = None
+    "Load the arch from local class or spikezoo lib. (None)"
+    arch_cls_local: Optional[nn.Module] = None

 class BaseModel(nn.Module):
     def __init__(self, cfg: BaseModelConfig):
@@ -71,8 +74,11 @@ class BaseModel(nn.Module):
     ):
         """Build the network and load the pretrained weight."""
         # network
-        module = importlib.import_module(f"spikezoo.archs.{self.cfg.model_name}.{self.cfg.model_file_name}")
-        model_cls = getattr(module, self.cfg.model_cls_name)
+        if self.cfg.arch_cls_local == None:
+            module = importlib.import_module(f"spikezoo.archs.{self.cfg.model_name}.{self.cfg.model_file_name}")
+            model_cls = getattr(module, self.cfg.model_cls_name)
+        else:
+            model_cls = self.cfg.arch_cls_local
         # load model config parameters
         if version == "local":
             model = model_cls(**self.cfg.model_params)
@@ -129,7 +135,7 @@ class BaseModel(nn.Module):
         """Crop the spike length."""
         spike_length = spike.shape[1]
         spike_mid = spike_length // 2
-        assert spike_length >= self.model_length, f"Spike input is not long enough, given {spike_length} frames < {self.cfg.model_length}."
+        assert spike_length >= self.model_length, f"Spike input is not long enough, given {spike_length} frames < {self.cfg.model_length} required by the {self.cfg.model_name}."
         # even length
         if self.model_length == self.model_half_length * 2:
             spike = spike[
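arch_cls_local plays the same role one level down, in build_network: it replaces the spikezoo.archs.* import with a class handed in directly. A sketch reusing the YourNet template from this release (field names follow the diff above):

from dataclasses import replace
from spikezoo.models.base_model import BaseModelConfig
from spikezoo.archs.yourmodel.arch.net import YourNet

# For version="local", build_network instantiates arch_cls_local(**model_params),
# so the params dict must match the network's constructor signature.
cfg = replace(BaseModelConfig(), arch_cls_local=YourNet, model_params={"inDim": 41})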
spikezoo/models/bsf_model.py
CHANGED
@@ -4,7 +4,6 @@ from spikezoo.models.base_model import BaseModel, BaseModelConfig
 from torch.optim import Adam
 import torch.optim.lr_scheduler as lr_scheduler
 import torch.nn as nn
-from spikezoo.pipeline import TrainPipelineConfig
 from typing import List
 from spikezoo.archs.bsf.models.bsf.bsf import BSF

spikezoo/models/spk2imgnet_model.py
CHANGED
@@ -1,7 +1,6 @@
 import torch
 from dataclasses import dataclass, field
 from spikezoo.models.base_model import BaseModel, BaseModelConfig
-from spikezoo.pipeline import TrainPipelineConfig
 import torch.nn as nn
 import torch.optim as optim
 import torch.optim.lr_scheduler as lr_scheduler
spikezoo/models/stir_model.py
CHANGED
@@ -4,7 +4,6 @@ from spikezoo.models.base_model import BaseModel, BaseModelConfig
 from torch.optim import Adam
 import torch.optim.lr_scheduler as lr_scheduler
 import torch.nn as nn
-from spikezoo.pipeline import TrainPipelineConfig
 from typing import List
 from spikezoo.archs.stir.metrics.losses import compute_per_loss_single
 from spikezoo.archs.stir.models.Vgg19 import Vgg19
spikezoo/models/wgse_model.py
CHANGED
@@ -1,7 +1,6 @@
 from dataclasses import dataclass, field
 from spikezoo.models.base_model import BaseModel, BaseModelConfig
 from typing import List
-from spikezoo.pipeline import TrainPipelineConfig
 import torch.nn as nn
 import torch.optim as optim
 import torch.optim.lr_scheduler as lr_scheduler
spikezoo/models/yourmodel_model.py
ADDED
@@ -0,0 +1,22 @@
+from torch.utils.data import Dataset
+from pathlib import Path
+from dataclasses import dataclass
+from typing import Literal, Union
+from typing import Optional
+from spikezoo.models.base_model import BaseModel, BaseModelConfig
+from dataclasses import field
+import torch.nn as nn
+
+
+@dataclass
+class YourModelConfig(BaseModelConfig):
+    model_name: str = "yourmodel"  # must stay consistent with the file name
+    model_file_name: str = "arch.net"  # module path under the archs directory
+    model_cls_name: str = "YourNet"  # network class name
+    model_length: int = 41
+    require_params: bool = True
+    model_params: dict = field(default_factory=lambda: {"inDim": 41})
+
+class YourModel(BaseModel):
+    def __init__(self, cfg: BaseModelConfig):
+        super(YourModel, self).__init__(cfg)
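Registering the pair of templates by name then works through the standard factory; a short sketch:

from spikezoo.models import build_model_cfg
from spikezoo.models.yourmodel_model import YourModelConfig

# model_file_name="arch.net" and model_cls_name="YourNet" resolve to
# spikezoo.archs.yourmodel.arch.net.YourNet when the network is built.
model = build_model_cfg(YourModelConfig())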
spikezoo/pipeline/base_pipeline.py
CHANGED
@@ -34,11 +34,17 @@ class PipelineConfig:
     "Evaluate metrics or not."
     save_metric: bool = True
     "Metric names for evaluation."
-    metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim"])
+    metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim", "niqe", "brisque"])
     "Save recoverd images or not."
     save_img: bool = True
     "Normalizing recoverd images and gt or not."
-
+    img_norm: bool = False
+    "Batch size for the test dataloader."
+    bs_test: int = 1
+    "Num_workers for the test dataloader."
+    nw_test: int = 0
+    "Pin_memory true or false for the dataloader."
+    pin_memory: bool = False
     "Different modes for the pipeline."
     _mode: Literal["single_mode", "multi_mode", "train_mode"] = "single_mode"

@@ -63,7 +69,8 @@ class Pipeline:
         torch.set_grad_enabled(False)
         # dataset
         self.dataset: BaseDataset = build_dataset_name(dataset_cfg) if isinstance(dataset_cfg, str) else build_dataset_cfg(dataset_cfg)
-        self.dataloader = build_dataloader(self.dataset)
+        self.dataset.build_source(split="test")
+        self.dataloader = build_dataloader(self.dataset,self.cfg)
         # device
         self.device = "cuda" if torch.cuda.is_available() else "cpu"

@@ -103,7 +110,7 @@ class Pipeline:
         """Function I---Save the recoverd image and calculate the metric from the given dataset."""
         # save folder
         self.logger.info("*********************** infer_from_dataset ***********************")
-        save_folder = self.save_folder / Path(f"infer_from_dataset/{self.dataset.cfg.dataset_name}_dataset/{self.dataset.cfg.split}/{idx:06d}")
+        save_folder = self.save_folder / Path(f"infer_from_dataset/{self.dataset.cfg.dataset_name}_dataset/{self.dataset.split}/{idx:06d}")
         os.makedirs(str(save_folder), exist_ok=True)

         # data process
@@ -117,7 +124,7 @@ class Pipeline:
             img = None
         return self.infer(spike, img, save_folder, rate)

-    def infer_from_file(self, file_path, height=-1, width=-1,
+    def infer_from_file(self, file_path, height=-1, width=-1, rate=1, img_path=None, remove_head=False):
         """Function II---Save the recoverd image and calculate the metric from the given input file."""
         # save folder
         self.logger.info("*********************** infer_from_file ***********************")
@@ -144,7 +151,7 @@ class Pipeline:
         spike = torch.from_numpy(spike)[None].to(self.device)
         return self.infer(spike, img, save_folder, rate)

-    def infer_from_spk(self, spike,
+    def infer_from_spk(self, spike, rate=1, img=None):
         """Function III---Save the recoverd image and calculate the metric from the given spike stream."""
         # save folder
         self.logger.info("*********************** infer_from_spk ***********************")
@@ -181,7 +188,7 @@ class Pipeline:
         for idx in range(len(self.dataset)):
             self.infer_from_dataset(idx=idx)
         self.cfg.save_metric = base_setting
-
+
     # TODO: To be overridden
     def cal_params(self):
         """Function VI---Calculate the parameters/flops/latency of the given method."""
@@ -228,8 +235,8 @@ class Pipeline:
         # With no GT
         if recon_img == None:
             return None
-        #
-        if model_name in ["
+        # spikeclip is normalized automatically
+        if model_name in ["spikeclip"] or self.cfg.img_norm == True:
             recon_img = (recon_img - recon_img.min()) / (recon_img.max() - recon_img.min())
         else:
             recon_img = recon_img / rate
@@ -253,7 +260,7 @@ class Pipeline:
             batch = model.feed_to_device(batch)
             outputs = model.get_outputs_dict(batch)
             recon_img, img = model.get_paired_imgs(batch, outputs)
-            recon_img, img = self._post_process_img(recon_img, model_name), self._post_process_img(img, "
+            recon_img, img = self._post_process_img(recon_img, model_name), self._post_process_img(img, "gt")
             for metric_name in metrics_dict.keys():
                 if metric_name in metric_pair_names:
                     metrics_dict[metric_name].update(cal_metric_pair(recon_img, img, metric_name))
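On the pipeline side, the test dataloader is now configurable from PipelineConfig. A hedged example of dialing the new fields (other constructor arguments are assumed to keep their defaults):

from spikezoo.pipeline.base_pipeline import PipelineConfig

# metric_names now defaults to ["psnr", "ssim", "niqe", "brisque"]; trim it
# if the no-reference metrics (niqe/brisque) are not wanted.
cfg = PipelineConfig(
    metric_names=["psnr", "ssim"],
    bs_test=4,        # batch size for the test dataloader
    nw_test=2,        # dataloader workers for testing
    pin_memory=True,  # used by the train loader; the test loader pins False per build_dataloader
)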
spikezoo/pipeline/ensemble_pipeline.py
CHANGED
@@ -48,7 +48,8 @@ class EnsemblePipeline(Pipeline):
         torch.set_grad_enabled(False)
         # data
         self.dataset: BaseDataset = build_dataset_name(dataset_cfg) if isinstance(dataset_cfg, str) else build_dataset_cfg(dataset_cfg)
-        self.dataloader = build_dataloader(self.dataset)
+        self.dataset.build_source(split = "test")
+        self.dataloader = build_dataloader(self.dataset,self.cfg)
         # device
         self.device = "cuda" if torch.cuda.is_available() else "cpu"

spikezoo/pipeline/train_cfgs.py
CHANGED
@@ -22,7 +22,7 @@ class REDS_BASE_TrainConfig(TrainPipelineConfig):

     # dataloader setting
     bs_train: int = 8
-    num_workers: int = 4
+    nw_train: int = 4
     pin_memory: bool = False

     # train setting - optimizer & scheduler & loss_dict
@@ -30,6 +30,8 @@ class REDS_BASE_TrainConfig(TrainPipelineConfig):
     scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[400], gamma=0.2) # from wgse
     loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})

+
+
 # ! Train Config for each method on the official setting, not recommended to utilize their default parameters owing to the dataset setting.
 @dataclass
 class BSFTrainConfig(TrainPipelineConfig):