spikezoo 0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spikezoo/__init__.py +0 -0
- spikezoo/archs/__init__.py +0 -0
- spikezoo/datasets/__init__.py +68 -0
- spikezoo/datasets/base_dataset.py +157 -0
- spikezoo/datasets/realworld_dataset.py +25 -0
- spikezoo/datasets/reds_small_dataset.py +27 -0
- spikezoo/datasets/szdata_dataset.py +37 -0
- spikezoo/datasets/uhsr_dataset.py +38 -0
- spikezoo/metrics/__init__.py +96 -0
- spikezoo/models/__init__.py +37 -0
- spikezoo/models/base_model.py +177 -0
- spikezoo/models/bsf_model.py +90 -0
- spikezoo/models/spcsnet_model.py +19 -0
- spikezoo/models/spikeclip_model.py +32 -0
- spikezoo/models/spikeformer_model.py +50 -0
- spikezoo/models/spk2imgnet_model.py +51 -0
- spikezoo/models/ssir_model.py +22 -0
- spikezoo/models/ssml_model.py +18 -0
- spikezoo/models/stir_model.py +37 -0
- spikezoo/models/tfi_model.py +18 -0
- spikezoo/models/tfp_model.py +18 -0
- spikezoo/models/wgse_model.py +31 -0
- spikezoo/pipeline/__init__.py +4 -0
- spikezoo/pipeline/base_pipeline.py +267 -0
- spikezoo/pipeline/ensemble_pipeline.py +64 -0
- spikezoo/pipeline/train_pipeline.py +94 -0
- spikezoo/utils/__init__.py +3 -0
- spikezoo/utils/data_utils.py +52 -0
- spikezoo/utils/img_utils.py +72 -0
- spikezoo/utils/other_utils.py +59 -0
- spikezoo/utils/spike_utils.py +82 -0
- spikezoo-0.1.dist-info/LICENSE.txt +17 -0
- spikezoo-0.1.dist-info/METADATA +39 -0
- spikezoo-0.1.dist-info/RECORD +36 -0
- spikezoo-0.1.dist-info/WHEEL +5 -0
- spikezoo-0.1.dist-info/top_level.txt +1 -0
spikezoo/__init__.py
ADDED
File without changes
|
File without changes
|
@@ -0,0 +1,68 @@
|
|
1
|
+
from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
|
2
|
+
from dataclasses import replace
|
3
|
+
import importlib, inspect
|
4
|
+
import os
|
5
|
+
import torch
|
6
|
+
from typing import Literal
|
7
|
+
|
8
|
+
# todo auto detect/register datasets
|
9
|
+
# Folder containing this package; each ``<name>_dataset.py`` inside it registers ``<name>``.
_package_dir = os.path.dirname(os.path.abspath(__file__))
files_list = os.listdir(_package_dir)
dataset_list = [entry.replace("_dataset.py", "") for entry in files_list if entry.endswith("_dataset.py")]
|
11
|
+
|
12
|
+
# todo register function
|
13
|
+
def _load_dataset_classes(dataset_name: str):
    """Import ``spikezoo.datasets.<dataset_name>_dataset`` and return ``(dataset_cls, config_cls)``.

    The classes are chosen among those *defined* in that module by their base class
    (``BaseDataset`` / ``BaseDatasetConfig``) rather than by alphabetical position, so
    extra helper classes or unusual naming cannot swap them. Either entry is ``None``
    when the module defines no matching class.

    Raises:
        ValueError: If ``dataset_name`` is not in ``dataset_list``.
    """
    if dataset_name not in dataset_list:
        raise ValueError(f"Given dataset {dataset_name} not in our dataset list {dataset_list}.")
    module = importlib.import_module("spikezoo.datasets." + dataset_name + "_dataset")
    # Ignore classes that are merely imported into the module (e.g. the base classes).
    local_classes = [
        obj for _, obj in inspect.getmembers(module, inspect.isclass) if obj.__module__ == module.__name__
    ]
    dataset_cls = next((c for c in local_classes if issubclass(c, BaseDataset)), None)
    config_cls = next((c for c in local_classes if issubclass(c, BaseDatasetConfig)), None)
    return dataset_cls, config_cls


def build_dataset_cfg(cfg: BaseDatasetConfig, split: Literal["train", "test"] = "test"):
    """Build the dataset described by ``cfg`` for the given split.

    A copy of ``cfg`` is made with ``split`` set and ``spike_length`` taken from
    ``spike_length_train`` (train) or ``spike_length_test`` (otherwise).

    Raises:
        ValueError: If ``cfg.dataset_name`` is not a registered dataset.
        RuntimeError: If the dataset module defines no ``BaseDataset`` subclass.
    """
    cfg = replace(cfg, split=split, spike_length=cfg.spike_length_train if split == "train" else cfg.spike_length_test)
    dataset_cls, _ = _load_dataset_classes(cfg.dataset_name)
    if dataset_cls is None:
        raise RuntimeError(f"No BaseDataset subclass defined for dataset {cfg.dataset_name}.")
    return dataset_cls(cfg)


def build_dataset_name(dataset_name: str, split: Literal["train", "test"] = "test"):
    """Build the dataset registered as ``dataset_name`` using its default config for ``split``.

    Raises:
        ValueError: If ``dataset_name`` is not a registered dataset.
        RuntimeError: If the dataset module lacks a dataset or config class.
    """
    dataset_cls, config_cls = _load_dataset_classes(dataset_name)
    if dataset_cls is None or config_cls is None:
        raise RuntimeError(
            f"Dataset module for {dataset_name} must define a BaseDataset and a BaseDatasetConfig subclass."
        )
    return dataset_cls(config_cls(split=split))
|
41
|
+
|
42
|
+
|
43
|
+
# todo to modify according to the basicsr
|
44
|
+
def build_dataloader(dataset: BaseDataset, cfg=None):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader`` suited to its split.

    Args:
        dataset: Dataset whose ``cfg.split`` is ``"train"`` or ``"test"``.
        cfg: Optional pipeline config providing ``bs_train``, ``num_workers`` and
            ``pin_memory``; only consulted for the train split.

    Returns:
        A shuffled, batched loader for train data when ``cfg`` is given; otherwise an
        ordered loader with batch size 1 and a single worker.

    Raises:
        ValueError: If ``dataset.cfg.split`` is neither ``"train"`` nor ``"test"``
            (previously ``None`` was returned silently).
    """
    split = dataset.cfg.split
    if split == "train" and cfg is not None:
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg.bs_train,
            shuffle=True,
            num_workers=cfg.num_workers,
            pin_memory=cfg.pin_memory,
        )
    if split in ("train", "test"):
        return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
    raise ValueError(f"Unsupported dataset split {split!r}; expected 'train' or 'test'.")
|
54
|
+
|
55
|
+
|
56
|
+
# dataset_size_dict = {}
|
57
|
+
# for dataset in dataset_list:
|
58
|
+
# module_name = dataset + "_dataset"
|
59
|
+
# module_name = "spikezoo.datasets." + module_name
|
60
|
+
# module = importlib.import_module(module_name)
|
61
|
+
# classes = sorted([name for name, obj in inspect.getmembers(module) if inspect.isclass(obj)])
|
62
|
+
# dataset_cfg: BaseDatasetConfig = getattr(module, classes[1])
|
63
|
+
# dataset_size_dict[dataset] = (dataset_cfg.height, dataset_cfg.width)
|
64
|
+
|
65
|
+
|
66
|
+
# def get_dataset_size(name):
|
67
|
+
# assert name in dataset_list, f"Given dataset {name} not in our dataset list {dataset_list}."
|
68
|
+
# return dataset_size_dict[name]
|
@@ -0,0 +1,157 @@
|
|
1
|
+
import torch
|
2
|
+
from torch.utils.data import Dataset
|
3
|
+
from pathlib import Path
|
4
|
+
import cv2
|
5
|
+
import numpy as np
|
6
|
+
from spikezoo.utils.spike_utils import load_vidar_dat
|
7
|
+
import re
|
8
|
+
from dataclasses import dataclass, replace
|
9
|
+
from typing import Literal
|
10
|
+
import warnings
|
11
|
+
import torch
|
12
|
+
from tqdm import tqdm
|
13
|
+
from spikezoo.utils.data_utils import Augmentor
|
14
|
+
|
15
|
+
|
16
|
+
@dataclass
class BaseDatasetConfig:
    """Configuration shared by all spike datasets.

    ``spike_length`` is derived in ``__post_init__`` from ``split``: it takes
    ``spike_length_train`` for the train split and ``spike_length_test`` otherwise.
    """

    # ------------- Not Recommended to Change -------------
    # Dataset name.
    dataset_name: str = "base"
    # Directory specifying location of data.
    root_dir: Path = Path(__file__).parent.parent / Path("data/base")
    # Image width.
    width: int = 400
    # Image height.
    height: int = 250
    # Spike paired with the image or not.
    with_img: bool = True
    # Dataset spike length for the train data (-1 keeps the whole stream).
    spike_length_train: int = -1
    # Dataset spike length for the test data (-1 keeps the whole stream).
    spike_length_test: int = -1
    # Active spike length; overwritten in __post_init__ according to split.
    spike_length: int = -1
    # Dir name for the spike.
    spike_dir_name: str = "spike"
    # Dir name for the image.
    img_dir_name: str = "gt"

    # ------------- Config -------------
    # Dataset split: train/test. Default set as the 'test' for evaluation.
    split: Literal["train", "test"] = "test"
    # Use the data augmentation technique or not.
    use_aug: bool = False
    # Use cache mechanism.
    use_cache: bool = False
    # Crop size.
    crop_size: tuple = (-1, -1)

    # post process
    def __post_init__(self):
        """Derive ``spike_length`` from ``split``, normalise ``root_dir`` and check it exists.

        Raises:
            FileNotFoundError: If ``root_dir`` does not exist (an explicit error instead of
                an ``assert``, which would be skipped under ``python -O``).
        """
        self.spike_length = self.spike_length_train if self.split == "train" else self.spike_length_test
        # Accept str or any path-like value for root_dir.
        self.root_dir = Path(self.root_dir)
        # todo try download
        if not self.root_dir.exists():
            raise FileNotFoundError(
                f"No files found in {self.root_dir} for the specified dataset `{self.dataset_name}`."
            )
|
56
|
+
|
57
|
+
# todo cache mechanism
|
58
|
+
class BaseDataset(Dataset):
    """Dataset of spike streams, optionally paired with ground-truth images.

    By default spikes are ``.dat`` files under ``<root_dir>/<split>/<spike_dir_name>`` and
    images are jpg/jpeg/png files under ``<root_dir>/<split>/<img_dir_name>``, paired by
    sorted order. Subclasses customise discovery and loading by overriding
    ``prepare_data``, ``get_spike_files``, ``load_spike`` and ``get_img``.
    """

    def __init__(self, cfg: BaseDatasetConfig):
        super(BaseDataset, self).__init__()
        self.cfg = cfg
        # Augmentation is only applied to the training split; None marks "not used".
        self.augmentor = Augmentor(self.cfg.crop_size) if self.cfg.use_aug and self.cfg.split == "train" else None
        self.prepare_data()
        if cfg.use_cache:
            self.cache_data()
        # Only compare lengths when images are expected: datasets with with_img=False
        # (e.g. prepare_data overrides that collect spikes only) need not define img_list.
        if self.cfg.with_img and len(getattr(self, "img_list", [])) != len(self.spike_list):
            warnings.warn("Lengths of the image list and the spike list should be equal.")

    def __len__(self):
        return len(self.spike_list)

    def __getitem__(self, idx: int):
        """Return ``{"spike": spike, "img": img}`` for sample ``idx``."""
        # load data
        if self.cfg.use_cache:
            spike, img = self.cache_spkimg[idx]
            # Cached spikes are stored as uint8 (see cache_data); restore float32.
            spike = spike.to(torch.float32)
        else:
            spike = self.get_spike(idx)
            img = self.get_img(idx)

        # process data
        if self.cfg.use_aug and self.cfg.split == "train":
            spike, img = self.augmentor(spike, img)

        return {"spike": spike, "img": img}

    # todo: To be overridden
    def prepare_data(self):
        """Specify the spike and image files to be loaded."""
        # spike
        self.spike_dir = self.cfg.root_dir / self.cfg.split / self.cfg.spike_dir_name
        self.spike_list = self.get_spike_files(self.spike_dir)
        # gt
        if self.cfg.with_img:
            self.img_dir = self.cfg.root_dir / self.cfg.split / self.cfg.img_dir_name
            self.img_list = self.get_image_files(self.img_dir)

    # todo: To be overridden
    def get_spike_files(self, path: Path):
        """Recognize spike files automatically (default .dat, searched recursively), sorted."""
        return sorted(path.glob("**/*.dat"))

    # todo: To be overridden
    def load_spike(self, idx):
        """Load the spike stream of the idx-th file as a float tensor via ``load_vidar_dat``."""
        spike_name = str(self.spike_list[idx])
        spike = load_vidar_dat(
            spike_name,
            height=self.cfg.height,
            width=self.cfg.width,
            out_type="float",
            out_format="tensor",
        )
        return spike

    def get_spike(self, idx):
        """Load the spike stream for ``idx`` and keep ``cfg.spike_length`` frames around its centre.

        Frames are counted along dim 0; ``spike_length == -1`` keeps the whole stream.
        Odd lengths keep the centre frame plus ``spike_length // 2`` frames on each side.
        """
        spike_length = self.cfg.spike_length
        spike = self.load_spike(idx)
        assert spike.shape[0] >= spike_length, f"Given spike length {spike.shape[0]} smaller than the required length {spike_length}"
        if spike_length == -1:
            return spike
        spike_mid = spike.shape[0] // 2
        half = spike_length // 2
        # The extra ``spike_length % 2`` frame makes odd windows include the centre frame.
        return spike[spike_mid - half : spike_mid + half + spike_length % 2]

    def get_image_files(self, path: Path):
        """Recognize image files (jpg/jpeg/png, case-insensitive) recursively, sorted."""
        files = [f for f in path.glob("**/*") if re.match(r".*\.(jpg|jpeg|png)$", f.name, re.IGNORECASE)]
        return sorted(files)

    # todo: To be overridden
    def get_img(self, idx):
        """Get the image for ``idx`` as a (1, H, W) float32 grayscale tensor scaled by 1/255.

        Without ground truth (``with_img`` False) the mean of the spike stream over dim 0
        is returned instead.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        if self.cfg.with_img:
            img_name = str(self.img_list[idx])
            img = cv2.imread(img_name)
            # cv2.imread reports a missing/unreadable file by returning None.
            if img is None:
                raise FileNotFoundError(f"Failed to read image {img_name}.")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            img = (img / 255).astype(np.float32)
            img = torch.from_numpy(img[None])
        else:
            spike = self.get_spike(idx)
            img = torch.mean(spike, dim=0, keepdim=True)
        return img

    def cache_data(self):
        """Pre-load every (spike, image) pair into ``self.cache_spkimg``."""
        self.cache_spkimg = []
        for idx in tqdm(range(len(self.spike_list)), desc="Caching data", unit="sample"):
            # uint8 storage assumes small non-negative integer spike values (binary spikes) — TODO confirm.
            spike = self.get_spike(idx).to(torch.uint8)
            img = self.get_img(idx)
            self.cache_spkimg.append([spike, img])
|
@@ -0,0 +1,25 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
|
3
|
+
from dataclasses import dataclass
|
4
|
+
|
5
|
+
|
6
|
+
@dataclass
class RealWorld_Config(BaseDatasetConfig):
    """Configuration for the real-world recVidarReal2019 recordings, which have no ground-truth images."""

    dataset_name: str = "realworld"
    # Recordings live in the package data folder.
    root_dir: Path = Path(__file__).parent.parent / "data" / "recVidarReal2019"
    # Spatial resolution of the spike frames.
    width: int = 400
    height: int = 250
    with_img: bool = False
    # -1 keeps the full spike stream for both splits.
    spike_length_train: int = -1
    spike_length_test: int = -1
|
15
|
+
|
16
|
+
|
17
|
+
class RealWorld(BaseDataset):
    """Real-world spike recordings without ground truth; all ``.dat`` files under ``root_dir`` are used."""

    def __init__(self, cfg: BaseDatasetConfig):
        super(RealWorld, self).__init__(cfg)

    def prepare_data(self):
        """Collect spike files from ``root_dir`` directly (the split does not select a sub-folder)."""
        self.spike_dir = self.cfg.root_dir
        self.spike_list = self.get_spike_files(self.spike_dir)
        # No ground-truth images exist. BaseDataset.__init__ may compare img_list with
        # spike_list, so define it (empty) to avoid an AttributeError there.
        self.img_list = []
|
24
|
+
|
25
|
+
|
@@ -0,0 +1,27 @@
|
|
1
|
+
from torch.utils.data import Dataset
|
2
|
+
from pathlib import Path
|
3
|
+
from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
|
4
|
+
from dataclasses import dataclass
|
5
|
+
import re
|
6
|
+
|
7
|
+
@dataclass
class REDS_Small_Config(BaseDatasetConfig):
    """Configuration for the REDS_Small paired spike/image dataset."""

    dataset_name: str = "reds_small"
    root_dir: Path = Path(__file__).parent.parent / "data" / "REDS_Small"
    # Frame size of the spike streams.
    width: int = 400
    height: int = 250
    # Ground-truth images are available.
    with_img: bool = True
    # Centred spike windows: 41 frames for training, 301 for testing.
    spike_length_train: int = 41
    spike_length_test: int = 301
    # Sub-folders (under <root_dir>/<split>) holding spikes and ground truth.
    spike_dir_name: str = "spike"
    img_dir_name: str = "gt"
|
18
|
+
|
19
|
+
class REDS_Small(BaseDataset):
    """REDS_Small dataset; training images are matched to spike files by file name."""

    def __init__(self, cfg: BaseDatasetConfig):
        super().__init__(cfg)

    def prepare_data(self):
        """Run the default discovery, then re-pair training GT images by spike file name."""
        super().prepare_data()
        if self.cfg.split != "train":
            return
        # Each training spike ``x.dat`` is paired with ``<img_dir>/x.png``.
        paired_images = []
        for spike_path in self.spike_list:
            png_name = str(spike_path.name).replace('.dat', '.png')
            paired_images.append(self.img_dir / Path(png_name))
        self.img_list = paired_images
|
27
|
+
|
@@ -0,0 +1,37 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
|
3
|
+
from dataclasses import dataclass
|
4
|
+
import cv2
|
5
|
+
import torch
|
6
|
+
import numpy as np
|
7
|
+
|
8
|
+
@dataclass
class SZData_Config(BaseDatasetConfig):
    """Configuration for the SZData paired spike/sharp-image dataset."""

    dataset_name: str = "szdata"
    root_dir: Path = Path(__file__).parent.parent / "data" / "dataset"
    # Frame size of the spike streams.
    width: int = 400
    height: int = 250
    # Ground-truth sharp images are available.
    with_img: bool = True
    # -1 keeps the full spike stream for both splits.
    spike_length_train: int = -1
    spike_length_test: int = -1
    # Sub-folders (under <root_dir>/<split>) holding spikes and sharp images.
    spike_dir_name: str = "spike_data"
    img_dir_name: str = "sharp_data"
|
19
|
+
|
20
|
+
class SZData(BaseDataset):
    """SZData dataset: each spike file has a same-named ``.png`` at the mirrored path in the image folder."""

    def __init__(self, cfg: BaseDatasetConfig):
        super(SZData, self).__init__(cfg)

    def get_img(self, idx):
        """Get the GT image paired with the idx-th spike file as a (1, H, W) float32 tensor.

        Without ground truth (``with_img`` False) the mean of the spike stream over dim 0
        is returned instead.

        Raises:
            FileNotFoundError: If OpenCV cannot read the paired image.
        """
        if self.cfg.with_img:
            spike_path = Path(self.spike_list[idx])
            # Mirror the spike file's position below spike_dir under img_dir with a .png suffix.
            # (Substring-replacing the absolute path could also rewrite unrelated parts of it,
            # e.g. a root directory whose name contains the spike folder name.)
            img_name = str(self.img_dir / spike_path.relative_to(self.spike_dir).with_suffix(".png"))
            img = cv2.imread(img_name)
            # cv2.imread reports a missing/unreadable file by returning None.
            if img is None:
                raise FileNotFoundError(f"Failed to read image {img_name}.")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            img = (img / 255).astype(np.float32)
            img = torch.from_numpy(img[None])
        else:
            spike = self.get_spike(idx)
            img = torch.mean(spike, dim=0, keepdim=True)
        return img
|
37
|
+
|
@@ -0,0 +1,38 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
|
3
|
+
from dataclasses import dataclass
|
4
|
+
import numpy as np
|
5
|
+
import torch
|
6
|
+
|
7
|
+
@dataclass
class UHSR_Config(BaseDatasetConfig):
    """Configuration for the U-CALTECH spike dataset (spikes only, stored as .npz)."""

    dataset_name: str = "uhsr"
    root_dir: Path = Path(__file__).parent.parent / "data" / "U-CALTECH"
    # Frame size after the crop applied when loading spikes.
    width: int = 224
    height: int = 224
    # No ground-truth images.
    with_img: bool = False
    # Centred 200-frame window for both splits.
    spike_length_train: int = 200
    spike_length_test: int = 200
    spike_dir_name: str = "spike"
    img_dir_name: str = ""
|
18
|
+
|
19
|
+
|
20
|
+
class UHSR(BaseDataset):
    """U-CALTECH spike dataset stored as ``.npz`` files (array key ``spk``) under ``<root_dir>/<split>``."""

    def __init__(self, cfg: BaseDatasetConfig):
        super(UHSR, self).__init__(cfg)

    def prepare_data(self):
        """Collect the ``.npz`` spike files of the configured split."""
        # The config field is ``split``; ``dataset_split`` does not exist and raised AttributeError.
        self.spike_dir = self.cfg.root_dir / self.cfg.split
        self.spike_list = self.get_spike_files(self.spike_dir)
        # No ground-truth images. BaseDataset.__init__ may compare img_list with
        # spike_list, so define it (empty) to avoid an AttributeError there.
        self.img_list = []

    def get_spike_files(self, path: Path):
        """Recognize spike files (.npz, searched recursively), sorted."""
        files = path.glob("**/*.npz")
        return sorted(files)

    def load_spike(self, idx):
        """Load the ``spk`` array of the idx-th file as float32 and crop dims 1 and 2 to 13:237."""
        spike_name = str(self.spike_list[idx])
        # Use the NpzFile as a context manager so its file handle is closed; astype copies the data.
        with np.load(spike_name) as data:
            spike = data["spk"].astype(np.float32)
        spike = torch.from_numpy(spike)
        # 224-pixel crop (matches UHSR_Config width/height); presumably centring 250x250 frames — TODO confirm.
        spike = spike[:, 13:237, 13:237]
        return spike
|
@@ -0,0 +1,96 @@
|
|
1
|
+
from skimage import metrics
|
2
|
+
import torch
|
3
|
+
import torch.hub
|
4
|
+
from lpips.lpips import LPIPS
|
5
|
+
import os
|
6
|
+
import os
|
7
|
+
import pyiqa
|
8
|
+
import numpy as np
|
9
|
+
from torchvision import transforms
|
10
|
+
import torch.nn.functional as F
|
11
|
+
|
12
|
+
# todo with the union type
|
13
|
+
# Full-reference metrics: compare a reconstruction with a ground-truth image.
metric_pair_names = ["psnr", "ssim", "lpips", "mse"]
# No-reference metrics: score a single image (computed through pyiqa).
metric_single_names = ["niqe", "brisque", "piqe"]
metric_all_names = [*metric_pair_names, *metric_single_names]

# Cache of pyiqa metric objects keyed by metric name, filled on first use.
metric_single_list = {}

# Full-reference metric callables; "lpips" is None until its network is created on first use.
metric_pair_list = dict(
    mse=metrics.mean_squared_error,
    ssim=metrics.structural_similarity,
    psnr=metrics.peak_signal_noise_ratio,
    lpips=None,
)
|
25
|
+
|
26
|
+
|
27
|
+
def cal_metric_single(img: torch.Tensor, metric_name="niqe"):
    """Compute a no-reference image-quality score with pyiqa.

    Args:
        img: Image tensor shaped (H, W), (C, H, W) or (N, C, H, W); leading singleton
            dims are added to reach 4-D. The expected value range depends on the pyiqa
            metric (presumably [0, 1]) — TODO confirm.
        metric_name: Any name listed by ``pyiqa.list_models()``.

    Returns:
        The score as a Python number (via ``.item()``, so the metric output must hold a
        single value).

    Raises:
        RuntimeError: If pyiqa does not know ``metric_name``.
    """
    # Use CUDA only when present instead of hard-coding it, so CPU-only machines work too.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Metric objects are created once per name and cached in metric_single_list.
    if metric_name not in metric_single_list:
        if metric_name not in pyiqa.list_models():
            raise RuntimeError(f"Metric {metric_name} not recognized by the IQA lib.")
        metric_single_list[metric_name] = pyiqa.create_metric(metric_name, device=device)

    # image process: [h,w] / [c,h,w] -> [1,c,h,w]
    if img.dim() == 3:
        img = img[None]
    elif img.dim() == 2:
        img = img[None, None]
    # Keep the input on the same device as the metric network.
    img = img.to(device)

    # resize: liqe_mix gets its shorter edge scaled to 384 px, keeping the aspect ratio.
    if metric_name == "liqe_mix":
        short_edge = 384
        h, w = img.shape[2], img.shape[3]
        if h < w:
            new_h, new_w = short_edge, int(w * short_edge / h)
        else:
            new_h, new_w = int(h * short_edge / w), short_edge
        img = F.interpolate(img, size=(new_h, new_w), mode="bilinear", align_corners=False)
    return metric_single_list[metric_name](img).item()
|
50
|
+
|
51
|
+
|
52
|
+
def cal_metric_pair(im1t: torch.Tensor, im2t: torch.Tensor, metric_name="mse"):
    """Compute a full-reference metric between two images, averaged over the batch.

    Args:
        im1t, im2t: Image tensors with values in [0, 1], shaped (H, W), (C, H, W),
            (N, C, H, W) or channels-last (N, H, W, 3).
        metric_name: One of the keys of ``metric_pair_list`` ("mse", "ssim", "psnr", "lpips").

    Returns:
        The metric averaged over the batch.

    Raises:
        RuntimeError: If ``metric_name`` is not a supported full-reference metric.

    Note:
        Images are mapped to [-1, 1] before scoring (the range LPIPS expects), so "mse"
        is measured on that scale, i.e. 4x the MSE of the [0, 1] images.
    """
    if metric_name not in metric_pair_list:
        raise RuntimeError(f"Metric {metric_name} not recognized")
    if metric_name == "lpips" and metric_pair_list[metric_name] is None:
        # Create the LPIPS network lazily on the device of the first input it sees.
        metric_pair_list[metric_name] = LPIPS().cuda() if im1t.is_cuda else LPIPS().cpu()
    metric_method = metric_pair_list[metric_name]

    # convert from [0, 1] to [-1, 1]
    im1t = (im1t * 2 - 1).clamp(-1, 1)
    im2t = (im2t * 2 - 1).clamp(-1, 1)
    # Width of the [-1, 1] range. Passed explicitly: without it skimage's PSNR infers
    # data_range 1 or 2 per image depending on whether any value is negative.
    data_range = 2

    # [c,h,w] -> [1,c,h,w]
    if im1t.dim() == 3:
        im1t = im1t[None]
        im2t = im2t[None]
    elif im1t.dim() == 2:
        im1t = im1t[None, None]
        im2t = im2t[None, None]

    # [1,h,w,3] -> [1,3,h,w]
    if im1t.shape[-1] == 3:
        im1t = im1t.permute(0, 3, 1, 2)
        im2t = im2t.permute(0, 3, 1, 2)

    # img array: [1,h,w,c] (for skimage), imgt tensor: [1,c,h,w] (for LPIPS)
    im1 = im1t.permute(0, 2, 3, 1).detach().cpu().numpy()
    im2 = im2t.permute(0, 2, 3, 1).detach().cpu().numpy()
    batchsz = im1.shape[0]

    # batch processing
    values = []
    for i in range(batchsz):
        if metric_name == "mse":
            value = metric_method(im1[i], im2[i])
        elif metric_name == "psnr":
            value = metric_method(im1[i], im2[i], data_range=data_range)
        elif metric_name == "ssim":
            value = metric_method(im1[i], im2[i], channel_axis=-1, data_range=data_range)
        elif metric_name == "lpips":
            value = metric_method(im1t[i : i + 1], im2t[i : i + 1])[0, 0, 0, 0]
            value = value.detach().cpu().float().item()
        values.append(value)

    return sum(values) / len(values)
|
@@ -0,0 +1,37 @@
|
|
1
|
+
import importlib
|
2
|
+
import inspect
|
3
|
+
from spikezoo.models.base_model import BaseModel,BaseModelConfig
|
4
|
+
import os
|
5
|
+
from pathlib import Path
|
6
|
+
|
7
|
+
current_file_path = Path(__file__).parent
|
8
|
+
files_list = os.listdir(os.path.dirname(os.path.abspath(__file__)))
|
9
|
+
model_list = [file.split("_")[0] for file in files_list if file.endswith("_model.py")]
|
10
|
+
|
11
|
+
# todo register function
|
12
|
+
def _load_model_classes(model_name: str):
    """Import ``spikezoo.models.<model_name>_model`` and return ``(model_cls, config_cls)``.

    The classes are chosen among those *defined* in that module by their base class
    (``BaseModel`` / ``BaseModelConfig``) rather than by alphabetical position. Either
    entry is ``None`` when the module defines no matching class.

    Raises:
        ValueError: If ``model_name`` is not in ``model_list``.
    """
    if model_name not in model_list:
        raise ValueError(f"Given model {model_name} not in our model zoo {model_list}.")
    module = importlib.import_module("spikezoo.models." + model_name + "_model")
    # Ignore classes that are merely imported into the module (e.g. the base classes).
    local_classes = [
        obj for _, obj in inspect.getmembers(module, inspect.isclass) if obj.__module__ == module.__name__
    ]
    model_cls = next((c for c in local_classes if issubclass(c, BaseModel)), None)
    config_cls = next((c for c in local_classes if issubclass(c, BaseModelConfig)), None)
    return model_cls, config_cls


def build_model_cfg(cfg: BaseModelConfig):
    """Build the model from the given model config.

    Raises:
        ValueError: If ``cfg.model_name`` is not in the model zoo.
        RuntimeError: If the model module defines no ``BaseModel`` subclass.
    """
    model_cls, _ = _load_model_classes(cfg.model_name)
    if model_cls is None:
        raise RuntimeError(f"No BaseModel subclass defined for model {cfg.model_name}.")
    return model_cls(cfg)


def build_model_name(model_name: str):
    """Build the model registered as ``model_name`` with its default config.

    Raises:
        ValueError: If ``model_name`` is not in the model zoo.
        RuntimeError: If the model module lacks a model or config class.
    """
    model_cls, config_cls = _load_model_classes(model_name)
    if model_cls is None or config_cls is None:
        raise RuntimeError(
            f"Model module for {model_name} must define a BaseModel and a BaseModelConfig subclass."
        )
    return model_cls(config_cls())
|
@@ -0,0 +1,177 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
import importlib
|
4
|
+
import inspect
|
5
|
+
from dataclasses import dataclass, field
|
6
|
+
from spikezoo.utils import load_network, download_file
|
7
|
+
import os
|
8
|
+
import time
|
9
|
+
from typing import Dict
|
10
|
+
from torch.optim import Adam
|
11
|
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
12
|
+
import functools
|
13
|
+
|
14
|
+
|
15
|
+
# todo private design
|
16
|
+
@dataclass
class BaseModelConfig:
    """Settings describing which network to build and how to load its weights."""

    # Registered model name (also the sub-package under spikezoo/archs).
    model_name: str = "base"
    # File name of the specified model.
    model_file_name: str = "nets"
    # Class name of the specified model in spikezoo/archs/base/{model_file_name}.py.
    model_cls_name: str = "BaseNet"
    # Spike input length for the specified model.
    model_win_length: int = 41
    # Whether the model requires pretrained parameters.
    require_params: bool = False
    # Model checkpoint path.
    ckpt_path: str = ""
    # Load pretrained weights or not.
    load_state: bool = True
    # Base url for storing pretrained models.
    base_url: str = "https://github.com/chenkang455/Spike-Zoo/releases/download/v0.1/"
    # Multi-GPU setting.
    multi_gpu: bool = False
    # Keyword arguments forwarded to the network constructor.
    model_params: dict = field(default_factory=dict)
|
39
|
+
|
40
|
+
|
41
|
+
class BaseModel(nn.Module):
    """Spike-to-image model wrapping ``spikezoo.archs.<model_name>.<model_file_name>.<model_cls_name>``.

    Provides inference (``forward``) plus the training hooks (optimizer, losses,
    visuals) used by the pipelines.
    """

    def __init__(self, cfg: BaseModelConfig):
        super(BaseModel, self).__init__()
        self.cfg = cfg
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.net = self.build_network().to(self.device)
        if cfg.multi_gpu:
            self.net = nn.DataParallel(self.net)
        # (H, W) of the spike input, recorded the first time crop_spike_length runs.
        self.spike_size = None
        self.model_half_win_length: int = cfg.model_win_length // 2

    # ! Might lead to low speed training on the BSF.
    def forward(self, spike):
        """Reconstruct an image from ``spike``: preprocess, run the network, postprocess."""
        spike = self.preprocess_spike(spike)
        img = self.net(spike)
        img = self.postprocess_img(img)
        return img

    def build_network(self):
        """Build the network and, when requested, load its pretrained weights.

        Weights are read from ``cfg.ckpt_path`` relative to this file's folder and are
        downloaded from ``cfg.base_url`` first if that file does not exist.
        """
        # network
        module = importlib.import_module(f"spikezoo.archs.{self.cfg.model_name}.{self.cfg.model_file_name}")
        model_cls = getattr(module, self.cfg.model_cls_name)
        model = model_cls(**self.cfg.model_params)
        if self.cfg.load_state and self.cfg.require_params:
            load_folder = os.path.dirname(os.path.abspath(__file__))
            weight_path = os.path.join(load_folder, self.cfg.ckpt_path)
            if not os.path.exists(weight_path):
                os.makedirs(os.path.dirname(weight_path), exist_ok=True)
                self.download_weight(weight_path)
                # Short pause after downloading, kept from the original implementation.
                time.sleep(0.5)
            model = load_network(weight_path, model)
        return model

    def save_network(self, save_path):
        """Save the network's state dict (tensors moved to CPU) to ``save_path``."""
        network = self.net
        # Save the wrapped module so the keys carry no DataParallel "module." prefix.
        if isinstance(network, nn.DataParallel):
            network = network.module
        state_dict = network.state_dict()
        # Update in place so the state dict object (and its metadata) is preserved.
        for key, param in state_dict.items():
            state_dict[key] = param.cpu()
        torch.save(state_dict, save_path)

    def download_weight(self, weight_path):
        """Download the pretrained weight named after ``cfg.ckpt_path`` from ``cfg.base_url``."""
        url = self.cfg.base_url + os.path.basename(self.cfg.ckpt_path)
        download_file(url, weight_path)

    def crop_spike_length(self, spike):
        """Keep the ``cfg.model_win_length`` frames (dim 1) around the centre of ``spike``.

        Odd windows include the centre frame plus ``model_half_win_length`` frames on each
        side. The spatial size (dims 2 and 3) of the first input is stored in ``spike_size``.
        """
        spike_length = spike.shape[1]
        spike_mid = spike_length // 2
        assert spike_length >= self.cfg.model_win_length, f"Spike input is not long enough, given {spike_length} frames < {self.cfg.model_win_length}."
        start = spike_mid - self.model_half_win_length
        # The extra ``model_win_length % 2`` frame makes odd windows include the centre frame.
        end = spike_mid + self.model_half_win_length + self.cfg.model_win_length % 2
        spike = spike[:, start:end]
        if self.spike_size is None:
            self.spike_size = (spike.shape[2], spike.shape[3])
        return spike

    def preprocess_spike(self, spike):
        """Preprocess the spike (length size)."""
        spike = self.crop_spike_length(spike)
        return spike

    def postprocess_img(self, image):
        """Postprocess the image (identity by default)."""
        return image

    # -------------------- Training Part --------------------
    def setup_training(self, pipeline_cfg):
        """Set up the Adam optimizer, cosine LR schedule (over ``cfg.epochs``) and L1 loss."""
        from spikezoo.pipeline import TrainPipelineConfig

        cfg: TrainPipelineConfig = pipeline_cfg
        self.optimizer = Adam(self.net.parameters(), lr=cfg.lr, betas=(0.9, 0.99), weight_decay=0)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=cfg.epochs, eta_min=0)
        self.criterion = nn.L1Loss()

    def get_outputs_dict(self, batch):
        """Get the output dict for the given input batch. (Designed for the training mode considering possible auxiliary output.)"""
        # data process
        spike = batch["spike"]
        # outputs
        outputs = {}
        recon_img = self(spike)
        outputs["recon_img"] = recon_img
        return outputs

    def get_visual_dict(self, batch, outputs):
        """Get the visual dict (reconstruction and ground truth) from the batch and outputs."""
        visual_dict = {}
        visual_dict["recon"] = outputs["recon_img"]
        visual_dict["img"] = batch["img"]
        return visual_dict

    def get_loss_dict(self, outputs, batch):
        """Return the loss tensors and their float values, keyed by loss name ("l1")."""
        # data process
        gt_img = batch["img"]
        # recon image
        recon_img = outputs["recon_img"]
        # loss dict
        loss_dict = {}
        loss_dict["l1"] = self.criterion(recon_img, gt_img)
        loss_values_dict = {k: v.item() for k, v in loss_dict.items()}
        return loss_dict, loss_values_dict

    def get_paired_imgs(self, batch, outputs):
        """Return ``(reconstructed image, ground-truth image)`` for metric computation."""
        recon_img = outputs["recon_img"]
        img = batch["img"]
        return recon_img, img

    def optimize_parameters(self, loss_dict):
        """Backpropagate the sum of all losses in ``loss_dict`` and take one optimizer step."""
        loss = functools.reduce(torch.add, loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_learning_rate(self):
        """Update the learning rate."""
        self.scheduler.step()

    def feed_to_device(self, batch):
        """Move every tensor in ``batch`` to ``self.device``; other values are kept as-is."""
        batch = {k: v.to(self.device, non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()}
        return batch
|