spikezoo 0.1__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- spikezoo/__init__.py +0 -0
- spikezoo/archs/__init__.py +0 -0
- spikezoo/datasets/__init__.py +68 -0
- spikezoo/datasets/base_dataset.py +157 -0
- spikezoo/datasets/realworld_dataset.py +25 -0
- spikezoo/datasets/reds_small_dataset.py +27 -0
- spikezoo/datasets/szdata_dataset.py +37 -0
- spikezoo/datasets/uhsr_dataset.py +38 -0
- spikezoo/metrics/__init__.py +96 -0
- spikezoo/models/__init__.py +37 -0
- spikezoo/models/base_model.py +177 -0
- spikezoo/models/bsf_model.py +90 -0
- spikezoo/models/spcsnet_model.py +19 -0
- spikezoo/models/spikeclip_model.py +32 -0
- spikezoo/models/spikeformer_model.py +50 -0
- spikezoo/models/spk2imgnet_model.py +51 -0
- spikezoo/models/ssir_model.py +22 -0
- spikezoo/models/ssml_model.py +18 -0
- spikezoo/models/stir_model.py +37 -0
- spikezoo/models/tfi_model.py +18 -0
- spikezoo/models/tfp_model.py +18 -0
- spikezoo/models/wgse_model.py +31 -0
- spikezoo/pipeline/__init__.py +4 -0
- spikezoo/pipeline/base_pipeline.py +267 -0
- spikezoo/pipeline/ensemble_pipeline.py +64 -0
- spikezoo/pipeline/train_pipeline.py +94 -0
- spikezoo/utils/__init__.py +3 -0
- spikezoo/utils/data_utils.py +52 -0
- spikezoo/utils/img_utils.py +72 -0
- spikezoo/utils/other_utils.py +59 -0
- spikezoo/utils/spike_utils.py +82 -0
- spikezoo-0.1.dist-info/LICENSE.txt +17 -0
- spikezoo-0.1.dist-info/METADATA +39 -0
- spikezoo-0.1.dist-info/RECORD +36 -0
- spikezoo-0.1.dist-info/WHEEL +5 -0
- spikezoo-0.1.dist-info/top_level.txt +1 -0
spikezoo/__init__.py
ADDED
File without changes
|
File without changes
|
@@ -0,0 +1,68 @@
|
|
1
|
+
from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
|
2
|
+
from dataclasses import replace
|
3
|
+
import importlib, inspect
|
4
|
+
import os
|
5
|
+
import torch
|
6
|
+
from typing import Literal
|
7
|
+
|
8
|
+
# TODO: auto-detect / register datasets.
# Every "<name>_dataset.py" module located next to this file is exposed as a
# dataset called "<name>" (e.g. "reds_small_dataset.py" -> "reds_small").
_DATASET_FILE_SUFFIX = "_dataset.py"
files_list = os.listdir(os.path.dirname(os.path.abspath(__file__)))
dataset_list = [
    file_name.replace(_DATASET_FILE_SUFFIX, "")
    for file_name in files_list
    if file_name.endswith(_DATASET_FILE_SUFFIX)
]
|
11
|
+
|
12
|
+
# TODO: replace the file-name based lookup with a registration function.
def build_dataset_cfg(cfg: BaseDatasetConfig, split: Literal["train", "test"] = "test"):
    """Build a dataset instance from an existing dataset config.

    The config is copied with ``dataclasses.replace`` so that ``split`` is the
    requested one and ``spike_length`` equals ``spike_length_train`` (train) or
    ``spike_length_test`` (otherwise). The dataset class is taken from the module
    ``spikezoo.datasets.<cfg.dataset_name>_dataset``.

    Args:
        cfg: Dataset config; ``cfg.dataset_name`` must appear in ``dataset_list``.
        split: Split to build, "train" or "test".

    Returns:
        An instance of the ``BaseDataset`` subclass defined in the dataset module.

    Raises:
        AssertionError: If ``cfg.dataset_name`` is not a known dataset.
        ImportError: If the dataset module defines no ``BaseDataset`` subclass.
    """
    assert cfg.dataset_name in dataset_list, f"Given dataset {cfg.dataset_name} not in our dataset list {dataset_list}."
    # Rebuild the config for the requested split.
    cfg = replace(cfg, split=split, spike_length=cfg.spike_length_train if split == "train" else cfg.spike_length_test)
    module = importlib.import_module(f"spikezoo.datasets.{cfg.dataset_name}_dataset")
    # Classes defined in the module itself (re-exported imports excluded), sorted by name.
    local_classes = [obj for _, obj in inspect.getmembers(module, inspect.isclass) if obj.__module__ == module.__name__]
    # Select the dataset class by type instead of relying on it sorting first by name.
    dataset_cls = next((obj for obj in local_classes if issubclass(obj, BaseDataset)), None)
    if dataset_cls is None:
        raise ImportError(f"No BaseDataset subclass found in module {module.__name__}.")
    return dataset_cls(cfg)
|
27
|
+
|
28
|
+
|
29
|
+
def build_dataset_name(dataset_name: str, split: Literal["train", "test"] = "test"):
    """Build a dataset with its default config from the dataset name.

    Both the dataset class and its config class are taken from the module
    ``spikezoo.datasets.<dataset_name>_dataset``. The config is instantiated with
    only ``split`` overridden.

    Args:
        dataset_name: Registered dataset name; must appear in ``dataset_list``.
        split: Split to build, "train" or "test".

    Returns:
        An instance of the ``BaseDataset`` subclass defined in the dataset module.

    Raises:
        AssertionError: If ``dataset_name`` is not a known dataset.
        ImportError: If the module lacks a dataset class or a config class.
    """
    assert dataset_name in dataset_list, f"Given dataset {dataset_name} not in our dataset list {dataset_list}."
    module = importlib.import_module(f"spikezoo.datasets.{dataset_name}_dataset")
    # Classes defined in the module itself (re-exported imports excluded), sorted by name.
    local_classes = [obj for _, obj in inspect.getmembers(module, inspect.isclass) if obj.__module__ == module.__name__]
    # Select the classes by type instead of relying on their alphabetical positions.
    dataset_cls = next((obj for obj in local_classes if issubclass(obj, BaseDataset)), None)
    cfg_cls = next((obj for obj in local_classes if issubclass(obj, BaseDatasetConfig)), None)
    if dataset_cls is None or cfg_cls is None:
        raise ImportError(f"Module {module.__name__} must define a BaseDataset subclass and a BaseDatasetConfig subclass.")
    dataset_cfg: BaseDatasetConfig = cfg_cls(split=split)
    return dataset_cls(dataset_cfg)
|
41
|
+
|
42
|
+
|
43
|
+
# TODO: align with the BasicSR dataloader builder.
def build_dataloader(dataset: "BaseDataset", cfg=None):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

    Train split with ``cfg``: batch size ``cfg.bs_train``, shuffling,
    ``cfg.num_workers`` workers and ``cfg.pin_memory``.
    Train split without ``cfg``, and the test split: batch size 1, no shuffling,
    one worker.

    Args:
        dataset: Dataset exposing ``dataset.cfg.split``.
        cfg: Optional pipeline config; only read for the train split.

    Returns:
        The configured ``DataLoader``.

    Raises:
        ValueError: If ``dataset.cfg.split`` is neither "train" nor "test".
    """
    split = dataset.cfg.split
    if split == "train" and cfg is not None:
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg.bs_train,
            shuffle=True,
            num_workers=cfg.num_workers,
            pin_memory=cfg.pin_memory,
        )
    if split in ("train", "test"):
        return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
    # Previously an unknown split fell through and returned None silently.
    raise ValueError(f"Unknown dataset split {split!r}; expected 'train' or 'test'.")
|
54
|
+
|
55
|
+
|
56
|
+
# dataset_size_dict = {}
|
57
|
+
# for dataset in dataset_list:
|
58
|
+
# module_name = dataset + "_dataset"
|
59
|
+
# module_name = "spikezoo.datasets." + module_name
|
60
|
+
# module = importlib.import_module(module_name)
|
61
|
+
# classes = sorted([name for name, obj in inspect.getmembers(module) if inspect.isclass(obj)])
|
62
|
+
# dataset_cfg: BaseDatasetConfig = getattr(module, classes[1])
|
63
|
+
# dataset_size_dict[dataset] = (dataset_cfg.height, dataset_cfg.width)
|
64
|
+
|
65
|
+
|
66
|
+
# def get_dataset_size(name):
|
67
|
+
# assert name in dataset_list, f"Given dataset {name} not in our dataset list {dataset_list}."
|
68
|
+
# return dataset_size_dict[name]
|
@@ -0,0 +1,157 @@
|
|
1
|
+
import torch
|
2
|
+
from torch.utils.data import Dataset
|
3
|
+
from pathlib import Path
|
4
|
+
import cv2
|
5
|
+
import numpy as np
|
6
|
+
from spikezoo.utils.spike_utils import load_vidar_dat
|
7
|
+
import re
|
8
|
+
from dataclasses import dataclass, replace
|
9
|
+
from typing import Literal
|
10
|
+
import warnings
|
11
|
+
import torch
|
12
|
+
from tqdm import tqdm
|
13
|
+
from spikezoo.utils.data_utils import Augmentor
|
14
|
+
|
15
|
+
|
16
|
+
@dataclass
class BaseDatasetConfig:
    # Base configuration shared by every dataset; subclasses override defaults.
    # ------------- Not Recommended to Change -------------
    "Dataset name."
    dataset_name: str = "base"
    "Directory specifying location of data."
    root_dir: Path = Path(__file__).parent.parent / Path("data/base")
    "Image width."
    width: int = 400
    "Image height."
    height: int = 250
    "Spike paried with the image or not."
    with_img: bool = True
    "Dataset spike length for the train data."
    spike_length_train: int = -1
    "Dataset spike length for the test data."
    spike_length_test: int = -1
    "Dataset spike length for the instantiation dataclass."
    spike_length: int = -1
    "Dir name for the spike."
    spike_dir_name: str = "spike"
    "Dir name for the image."
    img_dir_name: str = "gt"

    # ------------- Config -------------
    "Dataset split: train/test. Default set as the 'test' for evaluation."
    split: Literal["train", "test"] = "test"
    "Use the data augumentation technique or not."
    use_aug: bool = False
    "Use cache mechanism."
    use_cache: bool = False
    "Crop size."
    crop_size: tuple = (-1, -1)

    def __post_init__(self):
        """Derive ``spike_length`` from the split and validate ``root_dir``."""
        # Accept plain strings for convenience; normalise them to Path.
        if isinstance(self.root_dir, str):
            self.root_dir = Path(self.root_dir)
        # The active spike length always follows the selected split.
        if self.split == "train":
            self.spike_length = self.spike_length_train
        else:
            self.spike_length = self.spike_length_test
        # TODO: try to download the data when root_dir is missing.
        assert self.root_dir.exists(), f"No files found in {self.root_dir} for the specified dataset `{self.dataset_name}`."
|
56
|
+
|
57
|
+
# TODO: cache mechanism.
class BaseDataset(Dataset):
    """Spike-stream dataset, optionally paired with ground-truth images.

    Subclasses customise file discovery (``prepare_data``, ``get_spike_files``,
    ``get_image_files``) and loading (``load_spike``, ``get_img``). Each item is a
    dict ``{"spike": Tensor, "img": Tensor}``.
    """

    def __init__(self, cfg: "BaseDatasetConfig"):
        super(BaseDataset, self).__init__()
        self.cfg = cfg
        # Augmentation is only used for the training split.
        self.augmentor = Augmentor(self.cfg.crop_size) if self.cfg.use_aug and self.cfg.split == "train" else None
        self.prepare_data()
        if cfg.use_cache:
            self.cache_data()
        # Datasets without ground-truth images (e.g. with_img=False) never define
        # ``img_list``; only compare lengths when an image list actually exists.
        img_list = getattr(self, "img_list", None)
        if img_list is not None and len(img_list) != len(self.spike_list):
            warnings.warn("Lengths of the image list and the spike list should be equal.")

    def __len__(self):
        return len(self.spike_list)

    def __getitem__(self, idx: int):
        """Return ``{"spike": spike, "img": img}`` for sample ``idx``."""
        if self.cfg.use_cache:
            spike, img = self.cache_spkimg[idx]
            # Cached spikes are stored as uint8 (see cache_data); restore float32.
            spike = spike.to(torch.float32)
        else:
            spike = self.get_spike(idx)
            img = self.get_img(idx)

        if self.cfg.use_aug and self.cfg.split == "train":
            spike, img = self.augmentor(spike, img)

        return {"spike": spike, "img": img}

    # todo: To be overridden
    def prepare_data(self):
        """Specify the spike and image files to be loaded.

        Spikes are read from ``<root_dir>/<split>/<spike_dir_name>`` and, when
        ``with_img`` is set, images from ``<root_dir>/<split>/<img_dir_name>``.
        """
        self.spike_dir = self.cfg.root_dir / self.cfg.split / self.cfg.spike_dir_name
        self.spike_list = self.get_spike_files(self.spike_dir)
        if self.cfg.with_img:
            self.img_dir = self.cfg.root_dir / self.cfg.split / self.cfg.img_dir_name
            self.img_list = self.get_image_files(self.img_dir)

    # todo: To be overridden
    def get_spike_files(self, path: Path):
        """Recognize spike files automatically (default: recursive ``*.dat``), sorted."""
        return sorted(path.glob("**/*.dat"))

    # todo: To be overridden
    def load_spike(self, idx):
        """Load the full spike stream of sample ``idx`` as a float tensor."""
        spike_name = str(self.spike_list[idx])
        spike = load_vidar_dat(
            spike_name,
            height=self.cfg.height,
            width=self.cfg.width,
            out_type="float",
            out_format="tensor",
        )
        return spike

    def get_spike(self, idx):
        """Load spike ``idx`` and keep ``cfg.spike_length`` frames around its temporal centre.

        A ``spike_length`` of -1 keeps the whole stream.
        """
        spike_length = self.cfg.spike_length
        spike = self.load_spike(idx)
        assert spike.shape[0] >= spike_length, f"Given spike length {spike.shape[0]} smaller than the required length {spike_length}"
        spike_mid = spike.shape[0] // 2
        # -1 must be tested first: in Python, -1 % 2 == 1.
        if spike_length == -1:
            return spike
        half = spike_length // 2
        if spike_length % 2 == 1:
            return spike[spike_mid - half : spike_mid + half + 1]
        return spike[spike_mid - half : spike_mid + half]

    def get_image_files(self, path: Path):
        """Recognize image files (jpg/jpeg/png, case-insensitive) recursively, sorted."""
        files = [f for f in path.glob("**/*") if re.match(r".*\.(jpg|jpeg|png)$", f.name, re.IGNORECASE)]
        return sorted(files)

    # todo: To be overridden
    def get_img(self, idx):
        """Return image ``idx`` as a (1, H, W) float32 grayscale tensor in [0, 1].

        Without ground-truth images (``with_img`` False) the temporal mean of the
        spike stream is returned instead.

        Raises:
            FileNotFoundError: If the image file cannot be read by OpenCV.
        """
        if self.cfg.with_img:
            img_name = str(self.img_list[idx])
            img = cv2.imread(img_name)
            # cv2.imread signals failure by returning None rather than raising.
            if img is None:
                raise FileNotFoundError(f"Failed to read image {img_name}.")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            img = (img / 255).astype(np.float32)
            img = torch.from_numpy(img[None])
        else:
            spike = self.get_spike(idx)
            img = torch.mean(spike, dim=0, keepdim=True)
        return img

    def cache_data(self):
        """Preload every (spike, image) pair into memory; spikes are stored as uint8."""
        self.cache_spkimg = []
        for idx in tqdm(range(len(self.spike_list)), desc="Caching data", unit="sample"):
            spike = self.get_spike(idx).to(torch.uint8)
            img = self.get_img(idx)
            self.cache_spkimg.append([spike, img])
|
@@ -0,0 +1,25 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
|
3
|
+
from dataclasses import dataclass
|
4
|
+
|
5
|
+
|
6
|
+
@dataclass
class RealWorld_Config(BaseDatasetConfig):
    # Config for the real-world spike recordings. No ground-truth images are
    # paired (with_img=False), so BaseDataset.get_img falls back to the spike mean.
    dataset_name: str = "realworld"
    # Default data location: <package root>/data/recVidarReal2019.
    root_dir: Path = Path(__file__).parent.parent / Path("data/recVidarReal2019")
    width: int = 400
    height: int = 250
    with_img: bool = False
    # -1 keeps the full spike stream for both splits.
    spike_length_train: int = -1
    spike_length_test: int = -1
|
15
|
+
|
16
|
+
|
17
|
+
class RealWorld(BaseDataset):
    """Real-world spike dataset: every spike file found under ``root_dir`` is a sample.

    The data is not organised into train/test sub-folders, and no image list is built.
    """

    def __init__(self, cfg: BaseDatasetConfig):
        super().__init__(cfg)

    def prepare_data(self):
        # Spike files are searched directly under the dataset root.
        spike_root = self.cfg.root_dir
        self.spike_dir = spike_root
        self.spike_list = self.get_spike_files(spike_root)
|
24
|
+
|
25
|
+
|
@@ -0,0 +1,27 @@
|
|
1
|
+
from torch.utils.data import Dataset
|
2
|
+
from pathlib import Path
|
3
|
+
from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
|
4
|
+
from dataclasses import dataclass
|
5
|
+
import re
|
6
|
+
|
7
|
+
@dataclass
class REDS_Small_Config(BaseDatasetConfig):
    # Config for the REDS-Small spike dataset with paired ground-truth images.
    dataset_name: str = "reds_small"
    # Default data location: <package root>/data/REDS_Small.
    root_dir: Path = Path(__file__).parent.parent / Path("data/REDS_Small")
    width: int = 400
    height: int = 250
    with_img: bool = True
    # Number of frames kept around the temporal centre of each stream, per split.
    spike_length_train: int = 41
    spike_length_test: int = 301
    # Sub-directories of <root_dir>/<split>/ holding spikes and ground-truth images.
    spike_dir_name: str = "spike"
    img_dir_name: str = "gt"
|
18
|
+
|
19
|
+
class REDS_Small(BaseDataset):
    """REDS-Small spike dataset with paired ground-truth images."""

    def __init__(self, cfg: BaseDatasetConfig):
        super().__init__(cfg)

    def prepare_data(self):
        super().prepare_data()
        if self.cfg.split != "train":
            return
        # Training images are paired by file name: <name>.dat -> <img_dir>/<name>.png.
        self.img_list = [
            self.img_dir / Path(spike_file.name.replace('.dat', '.png'))
            for spike_file in self.spike_list
        ]
|
27
|
+
|
@@ -0,0 +1,37 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
|
3
|
+
from dataclasses import dataclass
|
4
|
+
import cv2
|
5
|
+
import torch
|
6
|
+
import numpy as np
|
7
|
+
|
8
|
+
@dataclass
class SZData_Config(BaseDatasetConfig):
    # Config for the "szdata" spike dataset with paired sharp images.
    dataset_name: str = "szdata"
    # Default data location: <package root>/data/dataset.
    root_dir: Path = Path(__file__).parent.parent / Path("data/dataset")
    width: int = 400
    height: int = 250
    with_img: bool = True
    # -1 keeps the full spike stream for both splits.
    spike_length_train: int = -1
    spike_length_test: int = -1
    # Sub-directories of <root_dir>/<split>/ holding spikes and sharp images.
    spike_dir_name: str = "spike_data"
    img_dir_name: str = "sharp_data"
|
19
|
+
|
20
|
+
class SZData(BaseDataset):
    """Spike dataset whose image directory mirrors the layout of the spike directory."""

    def __init__(self, cfg: BaseDatasetConfig):
        super(SZData, self).__init__(cfg)

    def get_img(self, idx):
        """Return the image paired with spike ``idx`` as a (1, H, W) float32 grayscale tensor in [0, 1].

        The image is ``<img_dir>/<spike path relative to spike_dir>`` with a ``.png``
        suffix. Without ground-truth images (``with_img`` False) the temporal mean of
        the spike stream is returned instead.

        Raises:
            FileNotFoundError: If the paired image cannot be read by OpenCV.
        """
        if self.cfg.with_img:
            spike_path = Path(self.spike_list[idx])
            # Path arithmetic instead of str.replace: only the spike directory prefix
            # and the file suffix are swapped; other parts of the path are untouched.
            img_path = self.img_dir / spike_path.relative_to(self.spike_dir).with_suffix(".png")
            img = cv2.imread(str(img_path))
            # cv2.imread signals failure by returning None rather than raising.
            if img is None:
                raise FileNotFoundError(f"Failed to read image {img_path}.")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            img = (img / 255).astype(np.float32)
            img = torch.from_numpy(img[None])
        else:
            spike = self.get_spike(idx)
            img = torch.mean(spike, dim=0, keepdim=True)
        return img
|
37
|
+
|
@@ -0,0 +1,38 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
|
3
|
+
from dataclasses import dataclass
|
4
|
+
import numpy as np
|
5
|
+
import torch
|
6
|
+
|
7
|
+
@dataclass
class UHSR_Config(BaseDatasetConfig):
    # Config for the UHSR (U-CALTECH) spike dataset; no ground-truth images.
    dataset_name: str = "uhsr"
    # Default data location: <package root>/data/U-CALTECH.
    root_dir: Path = Path(__file__).parent.parent / Path("data/U-CALTECH")
    # Spatial size after the UHSR loader's fixed crop.
    width: int = 224
    height: int = 224
    with_img: bool = False
    # Number of frames kept around the temporal centre of each stream, per split.
    spike_length_train: int = 200
    spike_length_test: int = 200
    spike_dir_name: str = "spike"
    # Unused while with_img is False.
    img_dir_name: str = ""
|
18
|
+
|
19
|
+
|
20
|
+
class UHSR(BaseDataset):
    """UHSR (U-CALTECH) spike dataset stored as ``.npz`` files, without ground-truth images."""

    def __init__(self, cfg: BaseDatasetConfig):
        super(UHSR, self).__init__(cfg)

    def prepare_data(self):
        # Spike files live under <root_dir>/<split>. The config field is ``split``;
        # BaseDatasetConfig has no ``dataset_split`` attribute.
        self.spike_dir = self.cfg.root_dir / self.cfg.split
        self.spike_list = self.get_spike_files(self.spike_dir)

    def get_spike_files(self, path: Path):
        """Return every ``.npz`` file under ``path`` (recursive), sorted."""
        return sorted(path.glob("**/*.npz"))

    def load_spike(self, idx):
        """Load the ``spk`` array of sample ``idx`` as float32 and crop it spatially.

        The last two axes are cropped with ``[13:237]`` (224 pixels), matching the
        224x224 width/height of the UHSR config.
        """
        spike_name = str(self.spike_list[idx])
        # np.load on an .npz returns an NpzFile that holds the file open; close it.
        with np.load(spike_name) as data:
            spike = data["spk"].astype(np.float32)
        spike = torch.from_numpy(spike)
        return spike[:, 13:237, 13:237]
|
@@ -0,0 +1,96 @@
|
|
1
|
+
from skimage import metrics
|
2
|
+
import torch
|
3
|
+
import torch.hub
|
4
|
+
from lpips.lpips import LPIPS
|
5
|
+
import os
|
6
|
+
import os
|
7
|
+
import pyiqa
|
8
|
+
import numpy as np
|
9
|
+
from torchvision import transforms
|
10
|
+
import torch.nn.functional as F
|
11
|
+
|
12
|
+
# todo with the union type
# Full-reference metrics: compare a reconstructed image against a reference image.
metric_pair_names = ["psnr", "ssim", "lpips", "mse"]
# No-reference metrics: score a single image; evaluated through pyiqa.
metric_single_names = ["niqe", "brisque", "piqe"]
metric_all_names = metric_pair_names + metric_single_names

# Cache of pyiqa metric objects keyed by metric name, filled lazily by cal_metric_single.
metric_single_list = {}

# Pair-metric implementations keyed by name. "lpips" starts as None and is
# instantiated on demand by cal_metric_pair.
metric_pair_list = {
    "mse": metrics.mean_squared_error,
    "ssim": metrics.structural_similarity,
    "psnr": metrics.peak_signal_noise_ratio,
    "lpips": None,
}
|
25
|
+
|
26
|
+
|
27
|
+
def cal_metric_single(img: torch.Tensor, metric_name="niqe"):
    """Compute a no-reference image-quality score for ``img`` with pyiqa.

    The pyiqa metric object is created on first use and cached in
    ``metric_single_list``. It is placed on CUDA when available and on the CPU
    otherwise, so CPU-only machines no longer crash.

    Args:
        img: Image tensor shaped (H, W), (C, H, W) or (B, C, H, W).
        metric_name: Any metric name known to ``pyiqa.list_models()``.

    Returns:
        The metric value as a Python float. ``.item()`` is used, so a
        single-element result is expected.

    Raises:
        RuntimeError: If ``metric_name`` is not recognised by pyiqa.
    """
    if metric_name not in metric_single_list:
        if metric_name not in pyiqa.list_models():
            raise RuntimeError(f"Metric {metric_name} not recognized by the IQA lib.")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        metric_single_list[metric_name] = pyiqa.create_metric(metric_name, device=device)

    # Bring the image to 4D: (C,H,W) -> (1,C,H,W); (H,W) -> (1,1,H,W).
    if img.dim() == 3:
        img = img[None]
    elif img.dim() == 2:
        img = img[None, None]

    # liqe_mix: resize so that the short edge is 384 px, keeping the aspect ratio.
    if metric_name == "liqe_mix":
        short_edge = 384
        h, w = img.shape[2], img.shape[3]
        if h < w:
            new_h, new_w = short_edge, int(w * short_edge / h)
        else:
            new_h, new_w = int(h * short_edge / w), short_edge
        img = F.interpolate(img, size=(new_h, new_w), mode="bilinear", align_corners=False)
    return metric_single_list[metric_name](img).item()
|
50
|
+
|
51
|
+
|
52
|
+
def cal_metric_pair(im1t: torch.Tensor, im2t: torch.Tensor, metric_name="mse"):
    """Compute a full-reference metric between two images, averaged over the batch.

    Args:
        im1t, im2t: Image tensors with values in [0, 1], shaped (H, W), (C, H, W),
            (B, C, H, W) or channel-last (B, H, W, 3). ``im1t`` is passed to
            skimage as the reference image.
        metric_name: A key of ``metric_pair_list`` ("mse", "ssim", "psnr", "lpips").

    Returns:
        The mean metric value over the batch.

    Notes:
        Inputs are mapped to [-1, 1] before evaluation, so ``data_range=2`` is
        passed to both SSIM and PSNR. Without it, skimage infers PSNR's data range
        from the reference image's minimum, which varies between 1 and 2 per image.
        MSE is reported on the [-1, 1] scale.

    Raises:
        RuntimeError: If ``metric_name`` is unknown.
        ValueError: If the batch is empty.
    """
    if metric_name not in metric_pair_list:
        raise RuntimeError(f"Metric {metric_name} not recognized")
    if metric_name == "lpips":
        if metric_pair_list[metric_name] is None:
            metric_pair_list[metric_name] = LPIPS()
        # Keep the cached LPIPS network on the inputs' device, even if it changes between calls.
        metric_pair_list[metric_name] = metric_pair_list[metric_name].to(im1t.device)
    metric_method = metric_pair_list[metric_name]

    # Convert from [0, 1] to [-1, 1].
    im1t = (im1t * 2 - 1).clamp(-1, 1)
    im2t = (im2t * 2 - 1).clamp(-1, 1)

    # (C,H,W) -> (1,C,H,W); (H,W) -> (1,1,H,W).
    if im1t.dim() == 3:
        im1t, im2t = im1t[None], im2t[None]
    elif im1t.dim() == 2:
        im1t, im2t = im1t[None, None], im2t[None, None]

    # Channel-last (B,H,W,3) -> (B,3,H,W).
    if im1t.shape[-1] == 3:
        im1t = im1t.permute(0, 3, 1, 2)
        im2t = im2t.permute(0, 3, 1, 2)

    # skimage expects channel-last numpy arrays: (B,C,H,W) -> (B,H,W,C).
    im1 = im1t.permute(0, 2, 3, 1).detach().cpu().numpy()
    im2 = im2t.permute(0, 2, 3, 1).detach().cpu().numpy()
    batchsz = im1.shape[0]
    if batchsz == 0:
        raise ValueError("cal_metric_pair received an empty batch.")

    values = []
    for i in range(batchsz):
        if metric_name == "mse":
            value = metric_method(im1[i], im2[i])
        elif metric_name == "psnr":
            value = metric_method(im1[i], im2[i], data_range=2)
        elif metric_name == "ssim":
            value = metric_method(im1[i], im2[i], channel_axis=-1, data_range=2)
        elif metric_name == "lpips":
            value = metric_method(im1t[i : i + 1], im2t[i : i + 1])[0, 0, 0, 0]
            value = value.detach().cpu().float().item()
        else:
            raise RuntimeError(f"Metric {metric_name} not recognized")
        values.append(value)

    return sum(values) / len(values)
|
@@ -0,0 +1,37 @@
|
|
1
|
+
import importlib
|
2
|
+
import inspect
|
3
|
+
from spikezoo.models.base_model import BaseModel,BaseModelConfig
|
4
|
+
import os
|
5
|
+
from pathlib import Path
|
6
|
+
|
7
|
+
current_file_path = Path(__file__).parent
files_list = os.listdir(os.path.dirname(os.path.abspath(__file__)))


def _names_with_suffix(file_names, suffix):
    """Return ``name`` for every ``<name><suffix>`` entry of ``file_names``, keeping order."""
    return [file_name[: -len(suffix)] for file_name in file_names if file_name.endswith(suffix)]


# Every "<name>_model.py" next to this file registers a model called "<name>".
# Only the suffix is stripped, so model names may contain underscores
# (file.split("_")[0] would have truncated them).
model_list = _names_with_suffix(files_list, "_model.py")
|
10
|
+
|
11
|
+
# TODO: replace the file-name based lookup with a registration function.
def build_model_cfg(cfg: BaseModelConfig):
    """Build a model instance from an existing model config.

    The model class is taken from the module ``spikezoo.models.<cfg.model_name>_model``.

    Args:
        cfg: Model config; ``cfg.model_name`` must appear in ``model_list``.

    Returns:
        An instance of the ``BaseModel`` subclass defined in that module.

    Raises:
        AssertionError: If ``cfg.model_name`` is not a known model.
        ImportError: If the module defines no ``BaseModel`` subclass.
    """
    assert cfg.model_name in model_list, f"Given model {cfg.model_name} not in our model zoo {model_list}."
    module = importlib.import_module(f"spikezoo.models.{cfg.model_name}_model")
    # Classes defined in the module itself (re-exported imports excluded), sorted by name.
    local_classes = [obj for _, obj in inspect.getmembers(module, inspect.isclass) if obj.__module__ == module.__name__]
    # Select the model class by type instead of relying on it sorting first by name.
    model_cls = next((obj for obj in local_classes if issubclass(obj, BaseModel)), None)
    if model_cls is None:
        raise ImportError(f"No BaseModel subclass found in module {module.__name__}.")
    return model_cls(cfg)
|
24
|
+
|
25
|
+
def build_model_name(model_name: str):
    """Build a model with its default config from the model name.

    Both the model class and its config class are taken from the module
    ``spikezoo.models.<model_name>_model``. The config is instantiated with its defaults.

    Args:
        model_name: Registered model name; must appear in ``model_list``.

    Returns:
        An instance of the ``BaseModel`` subclass defined in that module.

    Raises:
        AssertionError: If ``model_name`` is not a known model.
        ImportError: If the module lacks a model class or a config class.
    """
    assert model_name in model_list, f"Given model {model_name} not in our model zoo {model_list}."
    module = importlib.import_module(f"spikezoo.models.{model_name}_model")
    # Classes defined in the module itself (re-exported imports excluded), sorted by name.
    local_classes = [obj for _, obj in inspect.getmembers(module, inspect.isclass) if obj.__module__ == module.__name__]
    # Select the classes by type instead of relying on their alphabetical positions.
    model_cls = next((obj for obj in local_classes if issubclass(obj, BaseModel)), None)
    cfg_cls = next((obj for obj in local_classes if issubclass(obj, BaseModelConfig)), None)
    if model_cls is None or cfg_cls is None:
        raise ImportError(f"Module {module.__name__} must define a BaseModel subclass and a BaseModelConfig subclass.")
    model_cfg: BaseModelConfig = cfg_cls()
    return model_cls(model_cfg)
|
@@ -0,0 +1,177 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
import importlib
|
4
|
+
import inspect
|
5
|
+
from dataclasses import dataclass, field
|
6
|
+
from spikezoo.utils import load_network, download_file
|
7
|
+
import os
|
8
|
+
import time
|
9
|
+
from typing import Dict
|
10
|
+
from torch.optim import Adam
|
11
|
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
12
|
+
import functools
|
13
|
+
|
14
|
+
|
15
|
+
# TODO: private design.
@dataclass
class BaseModelConfig:
    # Default parameters shared by every BaseModel; model configs override them.
    "Registerd model name."
    model_name: str = "base"
    "File name of the specified model."
    model_file_name: str = "nets"
    "Class name of the specified model in spikezoo/archs/base/{model_file_name}.py."
    model_cls_name: str = "BaseNet"
    "Spike input length for the specified model."
    model_win_length: int = 41
    "Model require model parameters or not."
    require_params: bool = False
    "Model stored path."
    ckpt_path: str = ""
    "Load pretrained weights or not."
    load_state: bool = True
    "Base url for storing pretrained models."
    base_url: str = "https://github.com/chenkang455/Spike-Zoo/releases/download/v0.1/"
    "Multi-GPU setting."
    multi_gpu: bool = False
    "Model parameters."
    # A fresh empty dict per instance (keyword arguments for the network constructor).
    model_params: dict = field(default_factory=dict)
|
39
|
+
|
40
|
+
|
41
|
+
class BaseModel(nn.Module):
    """Wrapper around a spike-to-image network from ``spikezoo.archs``.

    It builds the network named by the config and optionally downloads and loads
    pretrained weights. It crops input spike streams to ``cfg.model_win_length``
    frames around their temporal centre, and exposes the hooks used by the
    training pipeline.
    """

    def __init__(self, cfg: "BaseModelConfig"):
        super(BaseModel, self).__init__()
        self.cfg = cfg
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.net = self.build_network().to(self.device)
        if cfg.multi_gpu:
            self.net = nn.DataParallel(self.net)
        # (H, W) of the first cropped spike input; recorded once in crop_spike_length.
        self.spike_size = None
        self.model_half_win_length: int = cfg.model_win_length // 2

    # ! Might lead to low speed training on the BSF.
    def forward(self, spike):
        """Reconstruct an image from ``spike``: preprocess -> network -> postprocess."""
        spike = self.preprocess_spike(spike)
        img = self.net(spike)
        img = self.postprocess_img(img)
        return img

    def build_network(self):
        """Instantiate ``spikezoo.archs.<model_name>.<model_file_name>.<model_cls_name>``.

        When both ``load_state`` and ``require_params`` are set, the checkpoint at
        ``ckpt_path`` is loaded. The path is resolved relative to this file's
        directory, and the checkpoint is downloaded first if it is missing.
        """
        module = importlib.import_module(f"spikezoo.archs.{self.cfg.model_name}.{self.cfg.model_file_name}")
        model_cls = getattr(module, self.cfg.model_cls_name)
        model = model_cls(**self.cfg.model_params)
        if self.cfg.load_state and self.cfg.require_params:
            load_folder = os.path.dirname(os.path.abspath(__file__))
            weight_path = os.path.join(load_folder, self.cfg.ckpt_path)
            if not os.path.exists(weight_path):
                os.makedirs(os.path.dirname(weight_path), exist_ok=True)
                self.download_weight(weight_path)
                time.sleep(0.5)
            model = load_network(weight_path, model)
        return model

    def save_network(self, save_path):
        """Save the (unwrapped) network's state dict, with tensors moved to the CPU."""
        network = self.net.module if isinstance(self.net, nn.DataParallel) else self.net
        state_dict = network.state_dict()
        for key, param in state_dict.items():
            state_dict[key] = param.cpu()
        torch.save(state_dict, save_path)

    def download_weight(self, weight_path):
        """Download ``base_url + basename(ckpt_path)`` to ``weight_path``."""
        url = self.cfg.base_url + os.path.basename(self.cfg.ckpt_path)
        download_file(url, weight_path)

    def crop_spike_length(self, spike):
        """Keep ``model_win_length`` frames of ``spike`` around the centre of dim 1.

        ``spike`` is indexed as (B, T, H, W). The spatial size of the first input
        seen is recorded in ``self.spike_size``.
        """
        spike_length = spike.shape[1]
        spike_mid = spike_length // 2
        assert spike_length >= self.cfg.model_win_length, f"Spike input is not long enough, given {spike_length} frames < {self.cfg.model_win_length}."
        half = self.model_half_win_length
        # Odd windows take one extra frame after the centre.
        end_extra = 0 if self.cfg.model_win_length == half * 2 else 1
        spike = spike[:, spike_mid - half : spike_mid + half + end_extra]
        if self.spike_size is None:
            self.spike_size = (spike.shape[2], spike.shape[3])
        return spike

    def preprocess_spike(self, spike):
        """Preprocess the spike (temporal length crop)."""
        spike = self.crop_spike_length(spike)
        return spike

    def postprocess_img(self, image):
        """Postprocess the image (identity by default)."""
        return image

    # -------------------- Training Part --------------------
    def setup_training(self, pipeline_cfg):
        """Create the Adam optimizer, cosine LR scheduler and L1 loss from the pipeline config."""
        from spikezoo.pipeline import TrainPipelineConfig

        cfg: TrainPipelineConfig = pipeline_cfg
        self.optimizer = Adam(self.net.parameters(), lr=cfg.lr, betas=(0.9, 0.99), weight_decay=0)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=cfg.epochs, eta_min=0)
        self.criterion = nn.L1Loss()

    def get_outputs_dict(self, batch):
        """Return ``{"recon_img": ...}`` for ``batch["spike"]``; subclasses may add auxiliary outputs."""
        spike = batch["spike"]
        outputs = {}
        outputs["recon_img"] = self(spike)
        return outputs

    def get_visual_dict(self, batch, outputs):
        """Return the tensors to visualise: reconstruction and ground-truth image."""
        visual_dict = {}
        visual_dict["recon"] = outputs["recon_img"]
        visual_dict["img"] = batch["img"]
        return visual_dict

    def get_loss_dict(self, outputs, batch):
        """Return ``(loss tensors, loss floats)`` dicts; by default a single L1 loss."""
        gt_img = batch["img"]
        recon_img = outputs["recon_img"]
        loss_dict = {}
        loss_dict["l1"] = self.criterion(recon_img, gt_img)
        loss_values_dict = {k: v.item() for k, v in loss_dict.items()}
        return loss_dict, loss_values_dict

    def get_paired_imgs(self, batch, outputs):
        """Return ``(reconstruction, ground truth)`` for metric computation."""
        recon_img = outputs["recon_img"]
        img = batch["img"]
        return recon_img, img

    def optimize_parameters(self, loss_dict):
        """Sum all losses in ``loss_dict`` and take one optimizer step."""
        loss = functools.reduce(torch.add, loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_learning_rate(self):
        """Advance the learning-rate scheduler by one step."""
        self.scheduler.step()

    def feed_to_device(self, batch):
        """Move every tensor in ``batch`` to ``self.device`` (non-blocking); other values pass through."""
        batch = {k: v.to(self.device, non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()}
        return batch