spikezoo-0.1.tar.gz

Files changed (41)
  1. spikezoo-0.1/LICENSE.txt +17 -0
  2. spikezoo-0.1/PKG-INFO +39 -0
  3. spikezoo-0.1/README.md +2 -0
  4. spikezoo-0.1/setup.cfg +4 -0
  5. spikezoo-0.1/setup.py +23 -0
  6. spikezoo-0.1/spikezoo/__init__.py +0 -0
  7. spikezoo-0.1/spikezoo/archs/__init__.py +0 -0
  8. spikezoo-0.1/spikezoo/datasets/__init__.py +68 -0
  9. spikezoo-0.1/spikezoo/datasets/base_dataset.py +157 -0
  10. spikezoo-0.1/spikezoo/datasets/realworld_dataset.py +25 -0
  11. spikezoo-0.1/spikezoo/datasets/reds_small_dataset.py +27 -0
  12. spikezoo-0.1/spikezoo/datasets/szdata_dataset.py +37 -0
  13. spikezoo-0.1/spikezoo/datasets/uhsr_dataset.py +38 -0
  14. spikezoo-0.1/spikezoo/metrics/__init__.py +96 -0
  15. spikezoo-0.1/spikezoo/models/__init__.py +37 -0
  16. spikezoo-0.1/spikezoo/models/base_model.py +177 -0
  17. spikezoo-0.1/spikezoo/models/bsf_model.py +90 -0
  18. spikezoo-0.1/spikezoo/models/spcsnet_model.py +19 -0
  19. spikezoo-0.1/spikezoo/models/spikeclip_model.py +32 -0
  20. spikezoo-0.1/spikezoo/models/spikeformer_model.py +50 -0
  21. spikezoo-0.1/spikezoo/models/spk2imgnet_model.py +51 -0
  22. spikezoo-0.1/spikezoo/models/ssir_model.py +22 -0
  23. spikezoo-0.1/spikezoo/models/ssml_model.py +18 -0
  24. spikezoo-0.1/spikezoo/models/stir_model.py +37 -0
  25. spikezoo-0.1/spikezoo/models/tfi_model.py +18 -0
  26. spikezoo-0.1/spikezoo/models/tfp_model.py +18 -0
  27. spikezoo-0.1/spikezoo/models/wgse_model.py +31 -0
  28. spikezoo-0.1/spikezoo/pipeline/__init__.py +4 -0
  29. spikezoo-0.1/spikezoo/pipeline/base_pipeline.py +267 -0
  30. spikezoo-0.1/spikezoo/pipeline/ensemble_pipeline.py +64 -0
  31. spikezoo-0.1/spikezoo/pipeline/train_pipeline.py +94 -0
  32. spikezoo-0.1/spikezoo/utils/__init__.py +3 -0
  33. spikezoo-0.1/spikezoo/utils/data_utils.py +52 -0
  34. spikezoo-0.1/spikezoo/utils/img_utils.py +72 -0
  35. spikezoo-0.1/spikezoo/utils/other_utils.py +59 -0
  36. spikezoo-0.1/spikezoo/utils/spike_utils.py +82 -0
  37. spikezoo-0.1/spikezoo.egg-info/PKG-INFO +39 -0
  38. spikezoo-0.1/spikezoo.egg-info/SOURCES.txt +39 -0
  39. spikezoo-0.1/spikezoo.egg-info/dependency_links.txt +1 -0
  40. spikezoo-0.1/spikezoo.egg-info/requires.txt +18 -0
  41. spikezoo-0.1/spikezoo.egg-info/top_level.txt +1 -0
spikezoo-0.1/LICENSE.txt ADDED
@@ -0,0 +1,17 @@
+ MIT License
+ Copyright (c) 2018 YOUR NAME
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
spikezoo-0.1/PKG-INFO ADDED
@@ -0,0 +1,39 @@
+ Metadata-Version: 2.2
+ Name: spikezoo
+ Version: 0.1
+ Summary: A deep learning toolbox for spike-to-image models.
+ Home-page: https://github.com/chenkang455/Spike-Zoo
+ Author: Kang Chen
+ Author-email: mrchenkang@stu.pku.edu.cn
+ Requires-Python: >=3.7
+ Description-Content-Type: text/markdown
+ License-File: LICENSE.txt
+ Requires-Dist: torch
+ Requires-Dist: requests
+ Requires-Dist: numpy
+ Requires-Dist: tqdm
+ Requires-Dist: scikit-image
+ Requires-Dist: lpips
+ Requires-Dist: pyiqa
+ Requires-Dist: opencv-python
+ Requires-Dist: thop
+ Requires-Dist: pytorch-wavelets
+ Requires-Dist: pytz
+ Requires-Dist: PyWavelets
+ Requires-Dist: pandas
+ Requires-Dist: pillow
+ Requires-Dist: scikit-learn
+ Requires-Dist: scipy
+ Requires-Dist: spikingjelly
+ Requires-Dist: setuptools
+ Dynamic: author
+ Dynamic: author-email
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: home-page
+ Dynamic: requires-dist
+ Dynamic: requires-python
+ Dynamic: summary
+
+ ⚡Spike-Zoo is the go-to library for state-of-the-art pretrained **spike-to-image** models that reconstruct images from spike streams. Whether you're looking for a **simple inference** solution or **training** your own spike-to-image models, ⚡Spike-Zoo is a modular toolbox that supports both.
+
spikezoo-0.1/README.md ADDED
@@ -0,0 +1,2 @@
+ ⚡Spike-Zoo is the go-to library for state-of-the-art pretrained **spike-to-image** models that reconstruct images from spike streams. Whether you're looking for a **simple inference** solution or **training** your own spike-to-image models, ⚡Spike-Zoo is a modular toolbox that supports both.
+
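For orientation before the source diffs below, here is a hedged end-to-end sketch assembled from the builder helpers that appear later in this release (`build_model_name`, `build_dataset_name`, `build_dataloader`). The data paths and the call `model(batch["spike"])` are assumptions: the actual inference entry point lives in `base_model.py`, which this page lists but does not reproduce.

```python
# Hypothetical quick-start, stitched from the builders shown in this diff.
# Assumes the reds_small data exists under spikezoo/data/ (the dataset
# config asserts on its root_dir) and that the model is callable on a
# spike tensor; base_model.py is not reproduced on this page.
import torch
from spikezoo.models import build_model_name
from spikezoo.datasets import build_dataset_name, build_dataloader

model = build_model_name("tfp")                  # TFP: physics-based, no weights needed
dataset = build_dataset_name("reds_small", split="test")
dataloader = build_dataloader(dataset)

batch = next(iter(dataloader))
with torch.no_grad():
    recon = model(batch["spike"])                # reconstructed image, assumed [b,1,h,w]
```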
spikezoo-0.1/setup.cfg ADDED
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
spikezoo-0.1/setup.py ADDED
@@ -0,0 +1,23 @@
+ from setuptools import find_packages
+ from setuptools import setup
+
+ with open("./requirements.txt", "r", encoding="utf-8") as fh:
+     install_requires = fh.read()
+
+ with open("./README.md", "r", encoding="utf-8") as fh:
+     long_description = fh.read()
+
+ setup(
+     install_requires=install_requires,
+     name="spikezoo",
+     version="0.1",
+     author="Kang Chen",
+     author_email="mrchenkang@stu.pku.edu.cn",
+     description="A deep learning toolbox for spike-to-image models.",
+     long_description=long_description,
+     long_description_content_type="text/markdown",
+     url="https://github.com/chenkang455/Spike-Zoo",
+     packages=find_packages(),
+     python_requires=">=3.7",
+     include_package_data=False,
+ )
spikezoo-0.1/spikezoo/__init__.py: File without changes
spikezoo-0.1/spikezoo/archs/__init__.py: File without changes
spikezoo-0.1/spikezoo/datasets/__init__.py ADDED
@@ -0,0 +1,68 @@
+ from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
+ from dataclasses import replace
+ import importlib, inspect
+ import os
+ import torch
+ from typing import Literal
+
+ # todo auto detect/register datasets
+ files_list = os.listdir(os.path.dirname(os.path.abspath(__file__)))
+ dataset_list = [file.replace("_dataset.py", "") for file in files_list if file.endswith("_dataset.py")]
+
+ # todo register function
+ def build_dataset_cfg(cfg: BaseDatasetConfig, split: Literal["train", "test"] = "test"):
+     """Build the dataset from the given dataset config."""
+     # build a new cfg according to the split
+     cfg = replace(cfg, split=split, spike_length=cfg.spike_length_train if split == "train" else cfg.spike_length_test)
+     # dataset module
+     module_name = cfg.dataset_name + "_dataset"
+     assert cfg.dataset_name in dataset_list, f"Given dataset {cfg.dataset_name} not in our dataset list {dataset_list}."
+     module_name = "spikezoo.datasets." + module_name
+     module = importlib.import_module(module_name)
+     # dataset, dataset_config
+     classes = sorted([name for name, obj in inspect.getmembers(module) if inspect.isclass(obj) and obj.__module__ == module.__name__])
+     dataset_cls: BaseDataset = getattr(module, classes[0])
+     dataset = dataset_cls(cfg)
+     return dataset
+
+
+ def build_dataset_name(dataset_name: str, split: Literal["train", "test"] = "test"):
+     """Build the default dataset from the given name."""
+     module_name = dataset_name + "_dataset"
+     assert dataset_name in dataset_list, f"Given dataset {dataset_name} not in our dataset list {dataset_list}."
+     module_name = "spikezoo.datasets." + module_name
+     module = importlib.import_module(module_name)
+     # dataset, dataset_config
+     classes = sorted([name for name, obj in inspect.getmembers(module) if inspect.isclass(obj) and obj.__module__ == module.__name__])
+     dataset_cls: BaseDataset = getattr(module, classes[0])
+     dataset_cfg: BaseDatasetConfig = getattr(module, classes[1])(split=split)
+     dataset = dataset_cls(dataset_cfg)
+     return dataset
+
+
+ # todo to modify according to the basicsr
+ def build_dataloader(dataset: BaseDataset, cfg=None):
+     # train dataloader
+     if dataset.cfg.split == "train":
+         if cfg is None:
+             return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
+         else:
+             return torch.utils.data.DataLoader(dataset, batch_size=cfg.bs_train, shuffle=True, num_workers=cfg.num_workers, pin_memory=cfg.pin_memory)
+     # test dataloader
+     elif dataset.cfg.split == "test":
+         return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
+
+
+ # dataset_size_dict = {}
+ # for dataset in dataset_list:
+ #     module_name = dataset + "_dataset"
+ #     module_name = "spikezoo.datasets." + module_name
+ #     module = importlib.import_module(module_name)
+ #     classes = sorted([name for name, obj in inspect.getmembers(module) if inspect.isclass(obj)])
+ #     dataset_cfg: BaseDatasetConfig = getattr(module, classes[1])
+ #     dataset_size_dict[dataset] = (dataset_cfg.height, dataset_cfg.width)
+
+
+ # def get_dataset_size(name):
+ #     assert name in dataset_list, f"Given dataset {name} not in our dataset list {dataset_list}."
+ #     return dataset_size_dict[name]
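Taken together, `build_dataset_name` resolves a module by naming convention and instantiates its default config, while `build_dataset_cfg` accepts a customized config. A minimal sketch of both paths; the data paths are illustrative, and the dataset root must exist because `BaseDatasetConfig.__post_init__` asserts on it:

```python
# Sketch of the two dataset builder paths shown above.
from spikezoo.datasets import build_dataset_cfg, build_dataset_name, build_dataloader
from spikezoo.datasets.reds_small_dataset import REDS_Small_Config

# 1) by name, with the module's default config
test_set = build_dataset_name("reds_small", split="test")
test_loader = build_dataloader(test_set)  # batch_size=1, no shuffle for "test"

# 2) from an explicit config, e.g. pointing at a custom data root
cfg = REDS_Small_Config(root_dir="path/to/REDS_Small")  # illustrative path
train_set = build_dataset_cfg(cfg, split="train")
```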
spikezoo-0.1/spikezoo/datasets/base_dataset.py ADDED
@@ -0,0 +1,157 @@
+ import torch
+ from torch.utils.data import Dataset
+ from pathlib import Path
+ import cv2
+ import numpy as np
+ from spikezoo.utils.spike_utils import load_vidar_dat
+ import re
+ from dataclasses import dataclass, replace
+ from typing import Literal
+ import warnings
+ from tqdm import tqdm
+ from spikezoo.utils.data_utils import Augmentor
+
+
+ @dataclass
+ class BaseDatasetConfig:
+     # ------------- Not Recommended to Change -------------
+     "Dataset name."
+     dataset_name: str = "base"
+     "Directory specifying the location of the data."
+     root_dir: Path = Path(__file__).parent.parent / Path("data/base")
+     "Image width."
+     width: int = 400
+     "Image height."
+     height: int = 250
+     "Whether the spike stream is paired with an image."
+     with_img: bool = True
+     "Dataset spike length for the train data."
+     spike_length_train: int = -1
+     "Dataset spike length for the test data."
+     spike_length_test: int = -1
+     "Dataset spike length for the instantiated dataclass."
+     spike_length: int = -1
+     "Dir name for the spike."
+     spike_dir_name: str = "spike"
+     "Dir name for the image."
+     img_dir_name: str = "gt"
+
+     # ------------- Config -------------
+     "Dataset split: train/test. Defaults to 'test' for evaluation."
+     split: Literal["train", "test"] = "test"
+     "Whether to use data augmentation."
+     use_aug: bool = False
+     "Whether to use the cache mechanism."
+     use_cache: bool = False
+     "Crop size."
+     crop_size: tuple = (-1, -1)
+
+     # post process
+     def __post_init__(self):
+         self.spike_length = self.spike_length_train if self.split == "train" else self.spike_length_test
+         self.root_dir = Path(self.root_dir) if isinstance(self.root_dir, str) else self.root_dir
+         # todo try download
+         assert self.root_dir.exists(), f"No files found in {self.root_dir} for the specified dataset `{self.dataset_name}`."
+
+
+ # todo cache mechanism
+ class BaseDataset(Dataset):
+     def __init__(self, cfg: BaseDatasetConfig):
+         super(BaseDataset, self).__init__()
+         self.cfg = cfg
+         self.augmentor = Augmentor(self.cfg.crop_size) if self.cfg.use_aug and self.cfg.split == "train" else None
+         self.prepare_data()
+         if cfg.use_cache:
+             self.cache_data()
+         # only paired datasets define img_list, so guard the consistency check
+         if self.cfg.with_img and len(self.img_list) != len(self.spike_list):
+             warnings.warn("Lengths of the image list and the spike list should be equal.")
+
+     def __len__(self):
+         return len(self.spike_list)
+
+     def __getitem__(self, idx: int):
+         # load data
+         if self.cfg.use_cache:
+             spike, img = self.cache_spkimg[idx]
+             spike = spike.to(torch.float32)
+         else:
+             spike = self.get_spike(idx)
+             img = self.get_img(idx)
+
+         # process data
+         if self.cfg.use_aug and self.cfg.split == "train":
+             spike, img = self.augmentor(spike, img)
+
+         batch = {"spike": spike, "img": img}
+         return batch
+
+     # todo: To be overridden
+     def prepare_data(self):
+         """Specify the spike and image files to be loaded."""
+         # spike
+         self.spike_dir = self.cfg.root_dir / self.cfg.split / self.cfg.spike_dir_name
+         self.spike_list = self.get_spike_files(self.spike_dir)
+         # gt
+         if self.cfg.with_img:
+             self.img_dir = self.cfg.root_dir / self.cfg.split / self.cfg.img_dir_name
+             self.img_list = self.get_image_files(self.img_dir)
+
+     # todo: To be overridden
+     def get_spike_files(self, path: Path):
+         """Recognize spike files automatically (default .dat)."""
+         files = path.glob("**/*.dat")
+         return sorted(files)
+
+     # todo: To be overridden
+     def load_spike(self, idx):
+         """Load the spike stream at the given idx."""
+         spike_name = str(self.spike_list[idx])
+         spike = load_vidar_dat(
+             spike_name,
+             height=self.cfg.height,
+             width=self.cfg.width,
+             out_type="float",
+             out_format="tensor",
+         )
+         return spike
+
+     def get_spike(self, idx):
+         """Get and process the spike stream at the given idx."""
+         spike_length = self.cfg.spike_length
+         spike = self.load_spike(idx)
+         assert spike.shape[0] >= spike_length, f"Loaded spike length {spike.shape[0]} is smaller than the required length {spike_length}."
+         spike_mid = spike.shape[0] // 2
+         # crop a centered window of `spike_length` frames
+         if spike_length == -1:
+             pass  # use the full stream
+         elif spike_length % 2 == 1:
+             spike = spike[spike_mid - spike_length // 2 : spike_mid + spike_length // 2 + 1]
+         else:
+             spike = spike[spike_mid - spike_length // 2 : spike_mid + spike_length // 2]
+         return spike
+
+     def get_image_files(self, path: Path):
+         """Recognize image files automatically."""
+         files = [f for f in path.glob("**/*") if re.match(r".*\.(jpg|jpeg|png)$", f.name, re.IGNORECASE)]
+         return sorted(files)
+
+     # todo: To be overridden
+     def get_img(self, idx):
+         """Get the image at the given idx."""
+         if self.cfg.with_img:
+             img_name = str(self.img_list[idx])
+             img = cv2.imread(img_name)
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+             img = (img / 255).astype(np.float32)
+             img = img[None]
+             img = torch.from_numpy(img)
+         else:
+             # no ground truth: fall back to the temporal mean of the spikes
+             spike = self.get_spike(idx)
+             img = torch.mean(spike, dim=0, keepdim=True)
+         return img
+
+     def cache_data(self):
+         """Cache the data in memory, storing spikes as uint8 to save space."""
+         self.cache_spkimg = []
+         for idx in tqdm(range(len(self.spike_list)), desc="Caching data", unit="sample"):
+             spike = self.get_spike(idx).to(torch.uint8)
+             img = self.get_img(idx)
+             self.cache_spkimg.append([spike, img])
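The centered-window slice in `get_spike` is easiest to see with concrete numbers. A standalone illustration (not package code), using the REDS_Small lengths defined in the next file:

```python
# Centered spike window, as in BaseDataset.get_spike above.
# Example values: a 301-frame test stream cropped to a 41-frame train window.
full_length, spike_length = 301, 41
mid = full_length // 2                 # 150
start = mid - spike_length // 2        # 130
stop = mid + spike_length // 2 + 1     # 171; the +1 keeps odd windows centered on mid
assert stop - start == spike_length
```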
spikezoo-0.1/spikezoo/datasets/realworld_dataset.py ADDED
@@ -0,0 +1,25 @@
+ from pathlib import Path
+ from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class RealWorld_Config(BaseDatasetConfig):
+     dataset_name: str = "realworld"
+     root_dir: Path = Path(__file__).parent.parent / Path("data/recVidarReal2019")
+     width: int = 400
+     height: int = 250
+     with_img: bool = False
+     spike_length_train: int = -1
+     spike_length_test: int = -1
+
+
+ class RealWorld(BaseDataset):
+     def __init__(self, cfg: BaseDatasetConfig):
+         super(RealWorld, self).__init__(cfg)
+
+     def prepare_data(self):
+         # real-world captures have no ground-truth images; only spikes are listed
+         self.spike_dir = self.cfg.root_dir
+         self.spike_list = self.get_spike_files(self.spike_dir)
+
spikezoo-0.1/spikezoo/datasets/reds_small_dataset.py ADDED
@@ -0,0 +1,27 @@
+ from pathlib import Path
+ from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class REDS_Small_Config(BaseDatasetConfig):
+     dataset_name: str = "reds_small"
+     root_dir: Path = Path(__file__).parent.parent / Path("data/REDS_Small")
+     width: int = 400
+     height: int = 250
+     with_img: bool = True
+     spike_length_train: int = 41
+     spike_length_test: int = 301
+     spike_dir_name: str = "spike"
+     img_dir_name: str = "gt"
+
+
+ class REDS_Small(BaseDataset):
+     def __init__(self, cfg: BaseDatasetConfig):
+         super(REDS_Small, self).__init__(cfg)
+
+     def prepare_data(self):
+         super().prepare_data()
+         if self.cfg.split == "train":
+             # pair each spike file with the ground-truth png of the same stem
+             self.img_list = [self.img_dir / Path(str(s.name).replace(".dat", ".png")) for s in self.spike_list]
+
spikezoo-0.1/spikezoo/datasets/szdata_dataset.py ADDED
@@ -0,0 +1,37 @@
+ from pathlib import Path
+ from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
+ from dataclasses import dataclass
+ import cv2
+ import torch
+ import numpy as np
+
+
+ @dataclass
+ class SZData_Config(BaseDatasetConfig):
+     dataset_name: str = "szdata"
+     root_dir: Path = Path(__file__).parent.parent / Path("data/dataset")
+     width: int = 400
+     height: int = 250
+     with_img: bool = True
+     spike_length_train: int = -1
+     spike_length_test: int = -1
+     spike_dir_name: str = "spike_data"
+     img_dir_name: str = "sharp_data"
+
+
+ class SZData(BaseDataset):
+     def __init__(self, cfg: BaseDatasetConfig):
+         super(SZData, self).__init__(cfg)
+
+     def get_img(self, idx):
+         if self.cfg.with_img:
+             # derive the image path from the spike path by swapping dir name and suffix
+             spike_name = self.spike_list[idx]
+             img_name = str(spike_name).replace(self.cfg.spike_dir_name, self.cfg.img_dir_name).replace(".dat", ".png")
+             img = cv2.imread(img_name)
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+             img = (img / 255).astype(np.float32)
+             img = img[None]
+             img = torch.from_numpy(img)
+         else:
+             spike = self.get_spike(idx)
+             img = torch.mean(spike, dim=0, keepdim=True)
+         return img
+
spikezoo-0.1/spikezoo/datasets/uhsr_dataset.py ADDED
@@ -0,0 +1,38 @@
+ from pathlib import Path
+ from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
+ from dataclasses import dataclass
+ import numpy as np
+ import torch
+
+
+ @dataclass
+ class UHSR_Config(BaseDatasetConfig):
+     dataset_name: str = "uhsr"
+     root_dir: Path = Path(__file__).parent.parent / Path("data/U-CALTECH")
+     width: int = 224
+     height: int = 224
+     with_img: bool = False
+     spike_length_train: int = 200
+     spike_length_test: int = 200
+     spike_dir_name: str = "spike"
+     img_dir_name: str = ""
+
+
+ class UHSR(BaseDataset):
+     def __init__(self, cfg: BaseDatasetConfig):
+         super(UHSR, self).__init__(cfg)
+
+     def prepare_data(self):
+         # the config field is `split`; `dataset_split` does not exist
+         self.spike_dir = self.cfg.root_dir / self.cfg.split
+         self.spike_list = self.get_spike_files(self.spike_dir)
+
+     def get_spike_files(self, path: Path):
+         files = path.glob("**/*.npz")
+         return sorted(files)
+
+     def load_spike(self, idx):
+         spike_name = str(self.spike_list[idx])
+         data = np.load(spike_name)
+         spike = data["spk"].astype(np.float32)
+         spike = torch.from_numpy(spike)
+         # center-crop the stored spikes to the configured 224x224
+         spike = spike[:, 13:237, 13:237]
+         return spike
spikezoo-0.1/spikezoo/metrics/__init__.py ADDED
@@ -0,0 +1,96 @@
+ from skimage import metrics
+ import torch
+ import torch.hub
+ from lpips.lpips import LPIPS
+ import os
+ import pyiqa
+ import numpy as np
+ from torchvision import transforms
+ import torch.nn.functional as F
+
+ # todo with the union type
+ metric_pair_names = ["psnr", "ssim", "lpips", "mse"]
+ metric_single_names = ["niqe", "brisque", "piqe"]
+ metric_all_names = metric_pair_names + metric_single_names
+
+ metric_single_list = {}
+
+ metric_pair_list = {
+     "mse": metrics.mean_squared_error,
+     "ssim": metrics.structural_similarity,
+     "psnr": metrics.peak_signal_noise_ratio,
+     "lpips": None,
+ }
+
+
+ def cal_metric_single(img: torch.Tensor, metric_name="niqe"):
+     # lazily create and cache the no-reference IQA metric
+     if metric_name not in metric_single_list.keys():
+         if metric_name in pyiqa.list_models():
+             iqa_metric = pyiqa.create_metric(metric_name, device=torch.device("cuda"))
+             metric_single_list.update({metric_name: iqa_metric})
+         else:
+             raise RuntimeError(f"Metric {metric_name} not recognized by the IQA lib.")
+     # image process
+     if img.dim() == 3:
+         img = img[None]
+     elif img.dim() == 2:
+         img = img[None, None]
+
+     # resize so the short edge is 384 for liqe_mix
+     if metric_name == "liqe_mix":
+         short_edge = 384
+         h, w = img.shape[2], img.shape[3]
+         if h < w:
+             new_h, new_w = short_edge, int(w * short_edge / h)
+         else:
+             new_h, new_w = int(h * short_edge / w), short_edge
+         img = F.interpolate(img, size=(new_h, new_w), mode="bilinear", align_corners=False)
+     return metric_single_list[metric_name](img).item()
+
+
+ def cal_metric_pair(im1t: torch.Tensor, im2t: torch.Tensor, metric_name="mse"):
+     """
+     im1t, im2t: batched image tensors with values in the range [0, 1].
+     """
+     if metric_name not in metric_pair_list.keys():
+         raise RuntimeError(f"Metric {metric_name} not recognized.")
+     if metric_name == "lpips" and metric_pair_list[metric_name] is None:
+         metric_pair_list[metric_name] = LPIPS().cuda() if im1t.is_cuda else LPIPS().cpu()
+     metric_method = metric_pair_list[metric_name]
+
+     # convert from [0, 1] to [-1, 1]
+     im1t = (im1t * 2 - 1).clamp(-1, 1)
+     im2t = (im2t * 2 - 1).clamp(-1, 1)
+
+     # [c,h,w] -> [1,c,h,w]
+     if im1t.dim() == 3:
+         im1t = im1t[None]
+         im2t = im2t[None]
+     elif im1t.dim() == 2:
+         im1t = im1t[None, None]
+         im2t = im2t[None, None]
+
+     # [1,h,w,3] -> [1,3,h,w]
+     if im1t.shape[-1] == 3:
+         im1t = im1t.permute(0, 3, 1, 2)
+         im2t = im2t.permute(0, 3, 1, 2)
+
+     # img array: [1,h,w,3] imgt tensor: [1,3,h,w]
+     im1 = im1t.permute(0, 2, 3, 1).detach().cpu().numpy()
+     im2 = im2t.permute(0, 2, 3, 1).detach().cpu().numpy()
+     batchsz, hei, wid, _ = im1.shape
+
+     # batch processing
+     values = []
+     for i in range(batchsz):
+         if metric_name in ["mse", "psnr"]:
+             value = metric_method(im1[i], im2[i])
+         elif metric_name in ["ssim"]:
+             # data_range=2 because the images were mapped to [-1, 1]
+             value, ssimmap = metric_method(im1[i], im2[i], channel_axis=-1, data_range=2, full=True)
+         elif metric_name in ["lpips"]:
+             value = metric_method(im1t[i : i + 1], im2t[i : i + 1])[0, 0, 0, 0]
+             value = value.detach().cpu().float().item()
+         values.append(value)
+
+     return sum(values) / len(values)
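A hedged usage sketch for the paired metrics; random tensors stand in for a reconstruction and its ground truth. `lpips` and the no-reference metrics are left out here because they pull model weights (and, as written above, `cal_metric_single` builds its pyiqa metric on CUDA):

```python
# Sketch: scoring a reconstruction against ground truth with the helpers above.
import torch
from spikezoo.metrics import cal_metric_pair

pred = torch.rand(1, 1, 250, 400)  # [b,c,h,w], values in [0, 1]
gt = torch.rand(1, 1, 250, 400)

for name in ["psnr", "ssim", "mse"]:
    print(name, cal_metric_pair(pred, gt, metric_name=name))
```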
spikezoo-0.1/spikezoo/models/__init__.py ADDED
@@ -0,0 +1,37 @@
+ import importlib
+ import inspect
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
+ import os
+ from pathlib import Path
+
+ current_file_path = Path(__file__).parent
+ files_list = os.listdir(os.path.dirname(os.path.abspath(__file__)))
+ model_list = [file.split("_")[0] for file in files_list if file.endswith("_model.py")]
+
+ # todo register function
+ def build_model_cfg(cfg: BaseModelConfig):
+     """Build the model from the given model config."""
+     # model module name
+     module_name = cfg.model_name + "_model"
+     assert cfg.model_name in model_list, f"Given model {cfg.model_name} not in our model zoo {model_list}."
+     module_name = "spikezoo.models." + module_name
+     module = importlib.import_module(module_name)
+     # model, model_config
+     classes = sorted([name for name, obj in inspect.getmembers(module) if inspect.isclass(obj) and obj.__module__ == module.__name__])
+     model_cls: BaseModel = getattr(module, classes[0])
+     model = model_cls(cfg)
+     return model
+
+
+ def build_model_name(model_name: str):
+     """Build the default model from the given name."""
+     # model module name
+     module_name = model_name + "_model"
+     assert model_name in model_list, f"Given model {model_name} not in our model zoo {model_list}."
+     module_name = "spikezoo.models." + module_name
+     module = importlib.import_module(module_name)
+     # model, model_config
+     classes = sorted([name for name, obj in inspect.getmembers(module) if inspect.isclass(obj) and obj.__module__ == module.__name__])
+     model_cls: BaseModel = getattr(module, classes[0])
+     model_cfg: BaseModelConfig = getattr(module, classes[1])()
+     model = model_cls(model_cfg)
+     return model
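Usage mirrors the dataset builders. A minimal sketch; since `base_model.py` is listed above but not shown here, the only assumption made about `BaseModelConfig` is the `model_name` field that `build_model_cfg` itself reads:

```python
# Sketch of the two model builder paths. "tfp" comes from tfp_model.py
# in the file list; learned models may additionally need pretrained weights,
# which base_model.py (not reproduced on this page) would handle.
from spikezoo.models import build_model_cfg, build_model_name, model_list
from spikezoo.models.base_model import BaseModelConfig

print(model_list)                  # names derived from the *_model.py files
model = build_model_name("tfp")    # default config for the named model

cfg = BaseModelConfig(model_name="tfp")  # model_name is read by build_model_cfg
model = build_model_cfg(cfg)
```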