spikezoo 0.1__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- spikezoo-0.1/LICENSE.txt +17 -0
- spikezoo-0.1/PKG-INFO +39 -0
- spikezoo-0.1/README.md +2 -0
- spikezoo-0.1/setup.cfg +4 -0
- spikezoo-0.1/setup.py +23 -0
- spikezoo-0.1/spikezoo/__init__.py +0 -0
- spikezoo-0.1/spikezoo/archs/__init__.py +0 -0
- spikezoo-0.1/spikezoo/datasets/__init__.py +68 -0
- spikezoo-0.1/spikezoo/datasets/base_dataset.py +157 -0
- spikezoo-0.1/spikezoo/datasets/realworld_dataset.py +25 -0
- spikezoo-0.1/spikezoo/datasets/reds_small_dataset.py +27 -0
- spikezoo-0.1/spikezoo/datasets/szdata_dataset.py +37 -0
- spikezoo-0.1/spikezoo/datasets/uhsr_dataset.py +38 -0
- spikezoo-0.1/spikezoo/metrics/__init__.py +96 -0
- spikezoo-0.1/spikezoo/models/__init__.py +37 -0
- spikezoo-0.1/spikezoo/models/base_model.py +177 -0
- spikezoo-0.1/spikezoo/models/bsf_model.py +90 -0
- spikezoo-0.1/spikezoo/models/spcsnet_model.py +19 -0
- spikezoo-0.1/spikezoo/models/spikeclip_model.py +32 -0
- spikezoo-0.1/spikezoo/models/spikeformer_model.py +50 -0
- spikezoo-0.1/spikezoo/models/spk2imgnet_model.py +51 -0
- spikezoo-0.1/spikezoo/models/ssir_model.py +22 -0
- spikezoo-0.1/spikezoo/models/ssml_model.py +18 -0
- spikezoo-0.1/spikezoo/models/stir_model.py +37 -0
- spikezoo-0.1/spikezoo/models/tfi_model.py +18 -0
- spikezoo-0.1/spikezoo/models/tfp_model.py +18 -0
- spikezoo-0.1/spikezoo/models/wgse_model.py +31 -0
- spikezoo-0.1/spikezoo/pipeline/__init__.py +4 -0
- spikezoo-0.1/spikezoo/pipeline/base_pipeline.py +267 -0
- spikezoo-0.1/spikezoo/pipeline/ensemble_pipeline.py +64 -0
- spikezoo-0.1/spikezoo/pipeline/train_pipeline.py +94 -0
- spikezoo-0.1/spikezoo/utils/__init__.py +3 -0
- spikezoo-0.1/spikezoo/utils/data_utils.py +52 -0
- spikezoo-0.1/spikezoo/utils/img_utils.py +72 -0
- spikezoo-0.1/spikezoo/utils/other_utils.py +59 -0
- spikezoo-0.1/spikezoo/utils/spike_utils.py +82 -0
- spikezoo-0.1/spikezoo.egg-info/PKG-INFO +39 -0
- spikezoo-0.1/spikezoo.egg-info/SOURCES.txt +39 -0
- spikezoo-0.1/spikezoo.egg-info/dependency_links.txt +1 -0
- spikezoo-0.1/spikezoo.egg-info/requires.txt +18 -0
- spikezoo-0.1/spikezoo.egg-info/top_level.txt +1 -0
spikezoo-0.1/LICENSE.txt
ADDED
@@ -0,0 +1,17 @@
|
|
1
|
+
MIT License
|
2
|
+
Copyright (c) 2018 YOUR NAME
|
3
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
4
|
+
of this software and associated documentation files (the "Software"), to deal
|
5
|
+
in the Software without restriction, including without limitation the rights
|
6
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
7
|
+
copies of the Software, and to permit persons to whom the Software is
|
8
|
+
furnished to do so, subject to the following conditions:
|
9
|
+
The above copyright notice and this permission notice shall be included in all
|
10
|
+
copies or substantial portions of the Software.
|
11
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
12
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
13
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
14
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
15
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
16
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
17
|
+
SOFTWARE.
|
spikezoo-0.1/PKG-INFO
ADDED
@@ -0,0 +1,39 @@
|
|
1
|
+
Metadata-Version: 2.2
|
2
|
+
Name: spikezoo
|
3
|
+
Version: 0.1
|
4
|
+
Summary: A deep learning toolbox for spike-to-image models.
|
5
|
+
Home-page: https://github.com/chenkang455/Spike-Zoo
|
6
|
+
Author: Kang Chen
|
7
|
+
Author-email: mrchenkang@stu.pku.edu.cn
|
8
|
+
Requires-Python: >=3.7
|
9
|
+
Description-Content-Type: text/markdown
|
10
|
+
License-File: LICENSE.txt
|
11
|
+
Requires-Dist: torch
|
12
|
+
Requires-Dist: requests
|
13
|
+
Requires-Dist: numpy
|
14
|
+
Requires-Dist: tqdm
|
15
|
+
Requires-Dist: scikit-image
|
16
|
+
Requires-Dist: lpips
|
17
|
+
Requires-Dist: pyiqa
|
18
|
+
Requires-Dist: opencv-python
|
19
|
+
Requires-Dist: thop
|
20
|
+
Requires-Dist: pytorch-wavelets
|
21
|
+
Requires-Dist: pytz
|
22
|
+
Requires-Dist: PyWavelets
|
23
|
+
Requires-Dist: pandas
|
24
|
+
Requires-Dist: pillow
|
25
|
+
Requires-Dist: scikit-learn
|
26
|
+
Requires-Dist: scipy
|
27
|
+
Requires-Dist: spikingjelly
|
28
|
+
Requires-Dist: setuptools
|
29
|
+
Dynamic: author
|
30
|
+
Dynamic: author-email
|
31
|
+
Dynamic: description
|
32
|
+
Dynamic: description-content-type
|
33
|
+
Dynamic: home-page
|
34
|
+
Dynamic: requires-dist
|
35
|
+
Dynamic: requires-python
|
36
|
+
Dynamic: summary
|
37
|
+
|
38
|
+
⚡Spike-Zoo is the go-to library for state-of-the-art pretrained **spike-to-image** models for reconstructing the image from the given spike stream. Whether you're looking for a **simple inference** solution or **training** your own spike-to-image models, ⚡Spike-Zoo is a modular toolbox that supports both.
|
39
|
+
|
spikezoo-0.1/README.md
ADDED
@@ -0,0 +1,2 @@
|
|
1
|
+
⚡Spike-Zoo is the go-to library for state-of-the-art pretrained **spike-to-image** models, which reconstruct an image from a given spike stream. Whether you're looking for a **simple inference** solution or want to **train** your own spike-to-image models, ⚡Spike-Zoo is a modular toolbox that supports both.
|
2
|
+
|
spikezoo-0.1/setup.cfg
ADDED
spikezoo-0.1/setup.py
ADDED
@@ -0,0 +1,23 @@
|
|
1
|
+
from pathlib import Path

from setuptools import find_packages
from setuptools import setup

# Resolve every file relative to this script so the build does not depend on
# the directory the installer happens to be started from.
_HERE = Path(__file__).resolve().parent


def _read_requirements():
    """Return the install requirements as a list of requirement strings.

    ``requirements.txt`` is used when it sits next to this script.  Source
    distributions of this package do not ship that file, so the
    ``requires.txt`` written into ``spikezoo.egg-info`` is used as a fallback.

    Raises:
        FileNotFoundError: if neither file is available.
    """
    candidates = (
        _HERE / "requirements.txt",
        _HERE / "spikezoo.egg-info" / "requires.txt",
    )
    for candidate in candidates:
        if not candidate.is_file():
            continue
        requirements = []
        for raw_line in candidate.read_text(encoding="utf-8").splitlines():
            line = raw_line.strip()
            if line.startswith("["):
                # ``requires.txt`` lists extras in ``[section]`` blocks after
                # the unconditional requirements; only the latter are wanted.
                break
            if line and not line.startswith("#"):
                requirements.append(line)
        return requirements
    raise FileNotFoundError(
        "Could not find requirements.txt or spikezoo.egg-info/requires.txt next to setup.py."
    )


install_requires = _read_requirements()
long_description = (_HERE / "README.md").read_text(encoding="utf-8")

setup(
    install_requires=install_requires,
    name="spikezoo",
    version="0.1",
    author="Kang Chen",
    author_email="mrchenkang@stu.pku.edu.cn",
    description="A deep learning toolbox for spike-to-image models.",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/chenkang455/Spike-Zoo",
    packages=find_packages(),
    python_requires='>=3.7',
    include_package_data=False
)
|
File without changes
|
File without changes
|
@@ -0,0 +1,68 @@
|
|
1
|
+
from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
|
2
|
+
from dataclasses import replace
|
3
|
+
import importlib, inspect
|
4
|
+
import os
|
5
|
+
import torch
|
6
|
+
from typing import Literal
|
7
|
+
|
8
|
+
# todo auto detect/register datasets
|
9
|
+
# Names of all entries that sit next to this module.
files_list = os.listdir(os.path.dirname(os.path.abspath(__file__)))
# Every ``<name>_dataset.py`` file contributes the dataset name ``<name>``.
dataset_list = [
    entry.replace("_dataset.py", "")
    for entry in files_list
    if entry.endswith("_dataset.py")
]
|
11
|
+
|
12
|
+
# todo register function
|
13
|
+
def build_dataset_cfg(cfg: BaseDatasetConfig, split: Literal["train", "test"] = "test"):
    """Build the dataset from the given dataset config.

    Args:
        cfg: Dataset configuration; ``cfg.dataset_name`` selects the module
            ``spikezoo.datasets.<dataset_name>_dataset``.
        split: Split to build. A copy of ``cfg`` is made with this split and
            the matching spike length; the caller's ``cfg`` is not modified.

    Returns:
        The dataset instance constructed from the adjusted config.

    Raises:
        AssertionError: if ``cfg.dataset_name`` is not in ``dataset_list``.
    """
    # build new cfg according to split
    cfg = replace(
        cfg,
        split=split,
        spike_length=cfg.spike_length_train if split == "train" else cfg.spike_length_test,
    )
    # dataset module
    module_name = cfg.dataset_name + "_dataset"
    assert cfg.dataset_name in dataset_list, f"Given dataset {cfg.dataset_name} not in our dataset list {dataset_list}."
    module = importlib.import_module("spikezoo.datasets." + module_name)
    # Classes declared in the module itself (imported ones are excluded);
    # ``inspect.getmembers`` yields them sorted by name.
    local_classes = [
        obj for _, obj in inspect.getmembers(module, inspect.isclass) if obj.__module__ == module.__name__
    ]
    # Pick the dataset class by type rather than by its alphabetical position;
    # fall back to the first class when nothing derives from BaseDataset.
    dataset_classes = [obj for obj in local_classes if issubclass(obj, BaseDataset)]
    dataset_cls = dataset_classes[0] if dataset_classes else local_classes[0]
    return dataset_cls(cfg)
|
27
|
+
|
28
|
+
|
29
|
+
def build_dataset_name(dataset_name: str, split: Literal["train", "test"] = "test"):
    """Build the default dataset from the given name.

    Args:
        dataset_name: Name of the dataset; the module
            ``spikezoo.datasets.<dataset_name>_dataset`` is imported.
        split: Split passed to the dataset's default config.

    Returns:
        The dataset instance built from the module's default config.

    Raises:
        AssertionError: if ``dataset_name`` is not in ``dataset_list``.
    """
    module_name = dataset_name + "_dataset"
    assert dataset_name in dataset_list, f"Given dataset {dataset_name} not in our dataset list {dataset_list}."
    module = importlib.import_module("spikezoo.datasets." + module_name)
    # Classes declared in the module itself (imported ones are excluded);
    # ``inspect.getmembers`` yields them sorted by name.
    local_classes = [
        obj for _, obj in inspect.getmembers(module, inspect.isclass) if obj.__module__ == module.__name__
    ]
    # Resolve both classes by type; the positional fallbacks reproduce the
    # previous "first is the dataset, second is the config" convention.
    dataset_classes = [obj for obj in local_classes if issubclass(obj, BaseDataset)]
    config_classes = [obj for obj in local_classes if issubclass(obj, BaseDatasetConfig)]
    dataset_cls = dataset_classes[0] if dataset_classes else local_classes[0]
    config_cls = config_classes[0] if config_classes else local_classes[1]
    dataset_cfg = config_cls(split=split)
    return dataset_cls(dataset_cfg)
|
41
|
+
|
42
|
+
|
43
|
+
# todo to modify according to the basicsr
|
44
|
+
def build_dataloader(dataset: BaseDataset,cfg = None):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

    A train-split dataset with a ``cfg`` uses ``cfg.bs_train``,
    ``cfg.num_workers`` and ``cfg.pin_memory`` with shuffling enabled.  A
    train-split dataset without ``cfg`` and any test-split dataset use a
    fixed loader: batch size 1, no shuffling, one worker.  Any other split
    value yields ``None``.
    """
    split = dataset.cfg.split
    # configured train dataloader
    if split == "train" and cfg is not None:
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg.bs_train,
            shuffle=True,
            num_workers=cfg.num_workers,
            pin_memory=cfg.pin_memory,
        )
    # default dataloader (train without cfg, or test)
    if split == "train" or split == "test":
        return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
    return None
|
54
|
+
|
55
|
+
|
56
|
+
# dataset_size_dict = {}
|
57
|
+
# for dataset in dataset_list:
|
58
|
+
# module_name = dataset + "_dataset"
|
59
|
+
# module_name = "spikezoo.datasets." + module_name
|
60
|
+
# module = importlib.import_module(module_name)
|
61
|
+
# classes = sorted([name for name, obj in inspect.getmembers(module) if inspect.isclass(obj)])
|
62
|
+
# dataset_cfg: BaseDatasetConfig = getattr(module, classes[1])
|
63
|
+
# dataset_size_dict[dataset] = (dataset_cfg.height, dataset_cfg.width)
|
64
|
+
|
65
|
+
|
66
|
+
# def get_dataset_size(name):
|
67
|
+
# assert name in dataset_list, f"Given dataset {name} not in our dataset list {dataset_list}."
|
68
|
+
# return dataset_size_dict[name]
|
@@ -0,0 +1,157 @@
|
|
1
|
+
import torch
|
2
|
+
from torch.utils.data import Dataset
|
3
|
+
from pathlib import Path
|
4
|
+
import cv2
|
5
|
+
import numpy as np
|
6
|
+
from spikezoo.utils.spike_utils import load_vidar_dat
|
7
|
+
import re
|
8
|
+
from dataclasses import dataclass, replace
|
9
|
+
from typing import Literal
|
10
|
+
import warnings
|
11
|
+
import torch
|
12
|
+
from tqdm import tqdm
|
13
|
+
from spikezoo.utils.data_utils import Augmentor
|
14
|
+
|
15
|
+
|
16
|
+
@dataclass
class BaseDatasetConfig:
    """Configuration shared by all spike datasets.

    ``spike_length`` is derived in ``__post_init__`` from ``split`` and the
    two per-split lengths, so a value passed for it is overwritten.

    Raises:
        AssertionError: on construction, if ``root_dir`` does not exist.
    """

    # ------------- Not Recommended to Change -------------
    # Dataset name.
    dataset_name: str = "base"
    # Directory specifying location of data.
    root_dir: Path = Path(__file__).parent.parent / Path("data/base")
    # Image width.
    width: int = 400
    # Image height.
    height: int = 250
    # Spike paired with the image or not.
    with_img: bool = True
    # Dataset spike length for the train data.
    spike_length_train: int = -1
    # Dataset spike length for the test data.
    spike_length_test: int = -1
    # Dataset spike length for the instantiated dataclass (set in __post_init__).
    spike_length: int = -1
    # Dir name for the spike.
    spike_dir_name: str = "spike"
    # Dir name for the image.
    img_dir_name: str = "gt"

    # ------------- Config -------------
    # Dataset split: train/test. Default set as the 'test' for evaluation.
    split: Literal["train", "test"] = "test"
    # Use the data augmentation technique or not.
    use_aug: bool = False
    # Use cache mechanism.
    use_cache: bool = False
    # Crop size.
    crop_size: tuple = (-1, -1)

    # post process
    def __post_init__(self):
        """Derive ``spike_length``, normalise ``root_dir`` and check it exists."""
        self.spike_length = self.spike_length_train if self.split == "train" else self.spike_length_test
        # Accept plain strings and other path-like values as well as Path.
        self.root_dir = Path(self.root_dir)
        # todo try download
        assert self.root_dir.exists(), f"No files found in {self.root_dir} for the specified dataset `{self.dataset_name}`."
|
56
|
+
|
57
|
+
# todo cache mechanism
|
58
|
+
class BaseDataset(Dataset):
    """Base spike dataset yielding ``{"spike": ..., "img": ...}`` samples.

    Subclasses customise the behaviour by overriding ``prepare_data``,
    ``get_spike_files``, ``load_spike`` and/or ``get_img``.  When the config
    has ``with_img`` disabled, the "image" of a sample is the mean of its
    spike stream over the first dimension.
    """

    def __init__(self, cfg: "BaseDatasetConfig"):
        """Store the config, collect the file lists and optionally cache data.

        Args:
            cfg: Dataset configuration (see ``BaseDatasetConfig``).
        """
        super(BaseDataset, self).__init__()
        self.cfg = cfg
        # ``-1`` is kept as the "no augmentor" placeholder used so far.
        use_aug = self.cfg.use_aug and self.cfg.split == "train"
        self.augmentor = Augmentor(self.cfg.crop_size) if use_aug else -1
        self.prepare_data()
        # ``prepare_data`` implementations for datasets without ground truth
        # never set ``img_list``; give it a defined value instead of failing
        # with an AttributeError below.
        if not hasattr(self, "img_list"):
            self.img_list = []
        if self.cfg.use_cache:
            self.cache_data()
        # The length check is only meaningful when images are paired with spikes.
        if self.cfg.with_img and len(self.img_list) != len(self.spike_list):
            warnings.warn("Lengths of the image list and the spike list should be equal.")

    def __len__(self):
        """Return the number of spike files."""
        return len(self.spike_list)

    def __getitem__(self, idx: int):
        """Return the sample at ``idx`` as ``{"spike": spike, "img": img}``."""
        # load data
        if self.cfg.use_cache:
            spike, img = self.cache_spkimg[idx]
            # Spikes are cached as uint8 (see ``cache_data``); restore float32.
            spike = spike.to(torch.float32)
        else:
            spike = self.get_spike(idx)
            img = self.get_img(idx)

        # process data
        if self.cfg.use_aug and self.cfg.split == "train":
            spike, img = self.augmentor(spike, img)

        return {"spike": spike, "img": img}

    # todo: To be overridden
    def prepare_data(self):
        """Specify the spike and image files to be loaded."""
        # spike
        self.spike_dir = self.cfg.root_dir / self.cfg.split / self.cfg.spike_dir_name
        self.spike_list = self.get_spike_files(self.spike_dir)
        # gt
        if self.cfg.with_img:
            self.img_dir = self.cfg.root_dir / self.cfg.split / self.cfg.img_dir_name
            self.img_list = self.get_image_files(self.img_dir)

    # todo: To be overridden
    def get_spike_files(self, path: Path):
        """Recognize spike files automatically (default .dat).

        Returns the sorted list of ``*.dat`` paths found recursively in ``path``.
        """
        return sorted(path.glob("**/*.dat"))

    # todo: To be overridden
    def load_spike(self, idx):
        """Load the spike stream from the given idx."""
        spike_name = str(self.spike_list[idx])
        return load_vidar_dat(
            spike_name,
            height=self.cfg.height,
            width=self.cfg.width,
            out_type="float",
            out_format="tensor",
        )

    def get_spike(self, idx):
        """Get and process the spike stream from the given idx.

        The stream is cut to ``cfg.spike_length`` entries along the first
        dimension, taken around the middle index; ``-1`` keeps it whole.

        Raises:
            AssertionError: if the stream is shorter than ``cfg.spike_length``.
        """
        spike_length = self.cfg.spike_length
        spike = self.load_spike(idx)
        assert spike.shape[0] >= spike_length, f"Given spike length {spike.shape[0]} smaller than the required length {spike_length}"
        if spike_length == -1:
            return spike
        # One formula covers odd and even lengths: the window starts
        # ``spike_length // 2`` entries before the middle index.
        start = spike.shape[0] // 2 - spike_length // 2
        return spike[start : start + spike_length]

    def get_image_files(self, path: Path):
        """Recognize image files automatically.

        Returns the sorted list of jpg/jpeg/png paths (case-insensitive)
        found recursively in ``path``.
        """
        files = [f for f in path.glob("**/*") if re.match(r".*\.(jpg|jpeg|png)$", f.name, re.IGNORECASE)]
        return sorted(files)

    # todo: To be overridden
    def get_img(self, idx):
        """Get the image from the given idx.

        With ``cfg.with_img`` the image file is read, converted to grayscale,
        divided by 255 and returned as a float32 tensor with a leading
        dimension of size 1.  Otherwise the mean of the spike stream over its
        first dimension is returned.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        if self.cfg.with_img:
            img_name = str(self.img_list[idx])
            img = cv2.imread(img_name)
            if img is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise FileNotFoundError(f"Could not read image file {img_name}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            img = (img / 255).astype(np.float32)
            img = torch.from_numpy(img[None])
        else:
            spike = self.get_spike(idx)
            img = torch.mean(spike, dim=0, keepdim=True)
        return img

    def cache_data(self):
        """Cache the data.

        Every sample is loaded once; spikes are stored as uint8 and converted
        back to float32 in ``__getitem__``.
        """
        self.cache_spkimg = []
        for idx in tqdm(range(len(self.spike_list)), desc="Caching data", unit="sample"):
            spike = self.get_spike(idx).to(torch.uint8)
            img = self.get_img(idx)
            self.cache_spkimg.append([spike, img])
|
@@ -0,0 +1,25 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
|
3
|
+
from dataclasses import dataclass
|
4
|
+
|
5
|
+
|
6
|
+
@dataclass
class RealWorld_Config(BaseDatasetConfig):
    """Default configuration for the ``realworld`` dataset (spikes only)."""

    # Dataset name; also selects the ``realworld_dataset`` module.
    dataset_name: str = "realworld"
    # Default data location inside the package.
    root_dir: Path = Path(__file__).parent.parent / Path("data/recVidarReal2019")
    # Image width.
    width: int = 400
    # Image height.
    height: int = 250
    # No image is paired with the spike streams of this dataset.
    with_img: bool = False
    # -1: the spike stream is not cut for the train split.
    spike_length_train: int = -1
    # -1: the spike stream is not cut for the test split.
    spike_length_test: int = -1
|
15
|
+
|
16
|
+
|
17
|
+
class RealWorld(BaseDataset):
    """Dataset whose spike files are searched directly under ``cfg.root_dir``."""

    def __init__(self, cfg: BaseDatasetConfig):
        super().__init__(cfg)

    def prepare_data(self):
        """Collect the spike files; no split or spike sub-directory is used."""
        root = self.cfg.root_dir
        self.spike_dir = root
        self.spike_list = self.get_spike_files(root)
|
24
|
+
|
25
|
+
|
@@ -0,0 +1,27 @@
|
|
1
|
+
from torch.utils.data import Dataset
|
2
|
+
from pathlib import Path
|
3
|
+
from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
|
4
|
+
from dataclasses import dataclass
|
5
|
+
import re
|
6
|
+
|
7
|
+
@dataclass
class REDS_Small_Config(BaseDatasetConfig):
    """Default configuration for the ``reds_small`` dataset (spikes paired with images)."""

    # Dataset name; also selects the ``reds_small_dataset`` module.
    dataset_name: str = "reds_small"
    # Default data location inside the package.
    root_dir: Path = Path(__file__).parent.parent / Path("data/REDS_Small")
    # Image width.
    width: int = 400
    # Image height.
    height: int = 250
    # Each spike stream is paired with an image.
    with_img: bool = True
    # Spike length used for the train split.
    spike_length_train: int = 41
    # Spike length used for the test split.
    spike_length_test: int = 301
    # Dir name for the spike.
    spike_dir_name: str = "spike"
    # Dir name for the image.
    img_dir_name: str = "gt"
|
18
|
+
|
19
|
+
class REDS_Small(BaseDataset):
    """REDS_Small dataset; train images are matched to spikes by file name."""

    def __init__(self, cfg: BaseDatasetConfig):
        super().__init__(cfg)

    def prepare_data(self):
        """Collect files as the base class does, then re-pair the train images.

        For the train split the image list is rebuilt so that entry ``i`` is
        the file in ``img_dir`` carrying the name of spike file ``i`` with
        ``.dat`` replaced by ``.png``.
        """
        super().prepare_data()
        if self.cfg.split != "train":
            return
        paired_images = []
        for spike_file in self.spike_list:
            image_name = str(spike_file.name).replace('.dat','.png')
            paired_images.append(self.img_dir / Path(image_name))
        self.img_list = paired_images
|
27
|
+
|
@@ -0,0 +1,37 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
|
3
|
+
from dataclasses import dataclass
|
4
|
+
import cv2
|
5
|
+
import torch
|
6
|
+
import numpy as np
|
7
|
+
|
8
|
+
@dataclass
class SZData_Config(BaseDatasetConfig):
    """Default configuration for the ``szdata`` dataset (spikes paired with images)."""

    # Dataset name; also selects the ``szdata_dataset`` module.
    dataset_name: str = "szdata"
    # Default data location inside the package.
    root_dir: Path = Path(__file__).parent.parent / Path("data/dataset")
    # Image width.
    width: int = 400
    # Image height.
    height: int = 250
    # Each spike stream is paired with an image.
    with_img: bool = True
    # -1: the spike stream is not cut for the train split.
    spike_length_train: int = -1
    # -1: the spike stream is not cut for the test split.
    spike_length_test: int = -1
    # Dir name for the spike.
    spike_dir_name: str = "spike_data"
    # Dir name for the image.
    img_dir_name: str = "sharp_data"
|
19
|
+
|
20
|
+
class SZData(BaseDataset):
    """SZData dataset; the image path is derived from the spike path."""

    def __init__(self, cfg: BaseDatasetConfig):
        super(SZData, self).__init__(cfg)

    def get_img(self, idx):
        """Get the image paired with the spike stream at ``idx``.

        With ``cfg.with_img`` the image path is the spike path with
        ``cfg.spike_dir_name`` replaced by ``cfg.img_dir_name`` and ``.dat``
        replaced by ``.png``; the file is read, converted to grayscale,
        divided by 255 and returned as a float32 tensor with a leading
        dimension of size 1.  Otherwise the mean of the spike stream over its
        first dimension is returned.

        Raises:
            FileNotFoundError: if the derived image file cannot be read.
        """
        if self.cfg.with_img:
            spike_name = self.spike_list[idx]
            img_name = str(spike_name).replace(self.cfg.spike_dir_name, self.cfg.img_dir_name).replace(".dat", ".png")
            img = cv2.imread(img_name)
            if img is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise FileNotFoundError(f"Could not read image file {img_name}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            img = (img / 255).astype(np.float32)
            img = img[None]
            img = torch.from_numpy(img)
        else:
            spike = self.get_spike(idx)
            img = torch.mean(spike, dim=0, keepdim=True)
        return img
|
37
|
+
|
@@ -0,0 +1,38 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
|
3
|
+
from dataclasses import dataclass
|
4
|
+
import numpy as np
|
5
|
+
import torch
|
6
|
+
|
7
|
+
@dataclass
class UHSR_Config(BaseDatasetConfig):
    """Default configuration for the ``uhsr`` dataset (spikes only)."""

    # Dataset name; also selects the ``uhsr_dataset`` module.
    dataset_name: str = "uhsr"
    # Default data location inside the package.
    root_dir: Path = Path(__file__).parent.parent / Path("data/U-CALTECH")
    # Image width.
    width: int = 224
    # Image height.
    height: int = 224
    # No image is paired with the spike streams of this dataset.
    with_img: bool = False
    # Spike length used for the train split.
    spike_length_train: int = 200
    # Spike length used for the test split.
    spike_length_test: int = 200
    # Dir name for the spike.
    spike_dir_name: str = "spike"
    # Dir name for the image (empty: no image directory).
    img_dir_name: str = ""
|
18
|
+
|
19
|
+
|
20
|
+
class UHSR(BaseDataset):
    """UHSR dataset reading spike streams from ``.npz`` archives."""

    def __init__(self, cfg: BaseDatasetConfig):
        super(UHSR, self).__init__(cfg)

    def prepare_data(self):
        """Collect the spike archives located under ``root_dir / <split>``."""
        # The config field is named ``split``; ``dataset_split`` is not defined
        # by the dataset configs and is only honoured if a caller supplies it.
        split = getattr(self.cfg, "dataset_split", self.cfg.split)
        self.spike_dir = self.cfg.root_dir / split
        self.spike_list = self.get_spike_files(self.spike_dir)

    def get_spike_files(self, path: Path):
        """Return the sorted list of ``*.npz`` paths found recursively in ``path``."""
        files = path.glob("**/*.npz")
        return sorted(files)

    def load_spike(self,idx):
        """Load the ``"spk"`` array of archive ``idx`` as a cropped float32 tensor."""
        spike_name = str(self.spike_list[idx])
        # Close the archive as soon as the array has been read.
        with np.load(spike_name) as data:
            spike = data["spk"].astype(np.float32)
        spike = torch.from_numpy(spike)
        # Keep indices 13..236 of the last two dimensions (a 224 x 224 window).
        spike = spike[:, 13:237, 13:237]
        return spike
|
@@ -0,0 +1,96 @@
|
|
1
|
+
from skimage import metrics
|
2
|
+
import torch
|
3
|
+
import torch.hub
|
4
|
+
from lpips.lpips import LPIPS
|
5
|
+
import os
|
6
|
+
import os
|
7
|
+
import pyiqa
|
8
|
+
import numpy as np
|
9
|
+
from torchvision import transforms
|
10
|
+
import torch.nn.functional as F
|
11
|
+
|
12
|
+
# todo with the union type
|
13
|
+
# Metrics computed from a pair of images.
metric_pair_names = ["psnr", "ssim", "lpips", "mse"]
# Metrics computed from a single image.
metric_single_names = ["niqe", "brisque", "piqe"]
metric_all_names = metric_pair_names + metric_single_names

# Cache of single-image metric objects, filled lazily by ``cal_metric_single``.
metric_single_list = {}

# Pair-metric name -> callable.  The "lpips" entry starts as None and is
# replaced by an LPIPS instance on first use in ``cal_metric_pair``.
metric_pair_list = {
    "mse": metrics.mean_squared_error,
    "ssim": metrics.structural_similarity,
    "psnr": metrics.peak_signal_noise_ratio,
    "lpips": None,
}
|
25
|
+
|
26
|
+
|
27
|
+
def cal_metric_single(img: torch.Tensor, metric_name="niqe"):
    """Compute a single-image (no-reference) metric through ``pyiqa``.

    The metric object is created on first use and cached in
    ``metric_single_list``.  It is placed on CUDA when available and on the
    CPU otherwise.

    Args:
        img: Image tensor with 2, 3 or 4 dimensions; 2-D and 3-D inputs get
            leading dimensions added until the tensor is 4-D.
        metric_name: Name of a metric listed by ``pyiqa.list_models()``.

    Returns:
        The metric value as a Python number.

    Raises:
        RuntimeError: if ``metric_name`` is not known to ``pyiqa``.
    """
    if metric_name not in metric_single_list:
        if metric_name not in pyiqa.list_models():
            raise RuntimeError(f"Metric {metric_name} not recognized by the IQA lib.")
        # Do not assume a GPU: fall back to the CPU on machines without CUDA.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        metric_single_list[metric_name] = pyiqa.create_metric(metric_name, device=device)

    # image process
    if img.dim() == 3:
        img = img[None]
    elif img.dim() == 2:
        img = img[None, None]

    # resize
    if metric_name == "liqe_mix":
        # Scale so that the shorter edge becomes 384 pixels.
        short_edge = 384
        h, w = img.shape[2], img.shape[3]
        if h < w:
            new_h, new_w = short_edge, int(w * short_edge / h)
        else:
            new_h, new_w = int(h * short_edge / w), short_edge
        img = F.interpolate(img, size=(new_h, new_w), mode="bilinear", align_corners=False)
    return metric_single_list[metric_name](img).item()
|
50
|
+
|
51
|
+
|
52
|
+
def cal_metric_pair(im1t: torch.Tensor, im2t: torch.Tensor, metric_name="mse"):
    """Compute a paired metric between two images or image batches.

    Args:
        im1t, im2t: torch tensors in the range (0, 1).  Accepted layouts are
            ``[h, w]``, ``[c, h, w]``, ``[b, c, h, w]`` and ``[b, h, w, 3]``.
        metric_name: One of the keys of ``metric_pair_list``.

    Returns:
        The metric averaged over the batch.  All metrics are evaluated on the
        inputs rescaled to [-1, 1].

    Raises:
        RuntimeError: if ``metric_name`` is not a key of ``metric_pair_list``.
    """
    if metric_name not in metric_pair_list:
        raise RuntimeError(f"Metric {metric_name} not recognized")
    if metric_name == "lpips":
        if metric_pair_list[metric_name] is None:
            metric_pair_list[metric_name] = LPIPS()
        # Keep the cached network on the same device as the current input.
        metric_pair_list[metric_name] = metric_pair_list[metric_name].to(im1t.device)
    metric_method = metric_pair_list[metric_name]

    # convert from [0, 1] to [-1, 1]
    im1t = (im1t * 2 - 1).clamp(-1, 1)
    im2t = (im2t * 2 - 1).clamp(-1, 1)

    # [c,h,w] -> [1,c,h,w]
    if im1t.dim() == 3:
        im1t = im1t[None]
        im2t = im2t[None]
    elif im1t.dim() == 2:
        im1t = im1t[None, None]
        im2t = im2t[None, None]

    # [1,h,w,3] -> [1,3,h,w]
    if im1t.shape[-1] == 3:
        im1t = im1t.permute(0, 3, 1, 2)
        im2t = im2t.permute(0, 3, 1, 2)

    # img array: [1,h,w,3] imgt tensor: [1,3,h,w]
    im1 = im1t.permute(0, 2, 3, 1).detach().cpu().numpy()
    im2 = im2t.permute(0, 2, 3, 1).detach().cpu().numpy()

    # batch processing
    values = []
    for i in range(im1.shape[0]):
        if metric_name == "mse":
            value = metric_method(im1[i], im2[i])
        elif metric_name == "psnr":
            # The images span [-1, 1]; state the range explicitly so it is not
            # guessed from the content of the reference image.
            value = metric_method(im1[i], im2[i], data_range=2)
        elif metric_name == "ssim":
            value = metric_method(im1[i], im2[i], channel_axis=-1, data_range=2)
        elif metric_name == "lpips":
            value = metric_method(im1t[i : i + 1], im2t[i : i + 1])[0, 0, 0, 0]
            value = value.detach().cpu().float().item()
        values.append(value)

    return sum(values) / len(values)
|
@@ -0,0 +1,37 @@
|
|
1
|
+
import importlib
|
2
|
+
import inspect
|
3
|
+
from spikezoo.models.base_model import BaseModel,BaseModelConfig
|
4
|
+
import os
|
5
|
+
from pathlib import Path
|
6
|
+
|
7
|
+
current_file_path = Path(__file__).parent
files_list = os.listdir(os.path.dirname(os.path.abspath(__file__)))
# Every ``<name>_model.py`` file contributes the model name ``<name>``.  Only
# the suffix is stripped, so names that contain underscores stay intact and
# match the ``<name>_model`` module that the builders import.
model_list = [file[: -len("_model.py")] for file in files_list if file.endswith("_model.py")]
|
10
|
+
|
11
|
+
# todo register function
|
12
|
+
def build_model_cfg(cfg: BaseModelConfig):
    """Instantiate the model described by ``cfg``.

    The module ``spikezoo.models.<cfg.model_name>_model`` is imported and the
    alphabetically first class defined in it is constructed with ``cfg``.

    Raises:
        AssertionError: if ``cfg.model_name`` is not in ``model_list``.
    """
    short_module = cfg.model_name + "_model"
    assert cfg.model_name in model_list, f"Given model {cfg.model_name} not in our model zoo {model_list}."
    module = importlib.import_module("spikezoo.models." + short_module)
    # names of the classes declared in the module itself, in sorted order
    own_class_names = []
    for attr_name, member in inspect.getmembers(module):
        if inspect.isclass(member) and member.__module__ == module.__name__:
            own_class_names.append(attr_name)
    own_class_names.sort()
    model_cls: BaseModel = getattr(module, own_class_names[0])
    return model_cls(cfg)
|
24
|
+
|
25
|
+
def build_model_name(model_name: str):
    """Build the default model from the given name.

    The module ``spikezoo.models.<model_name>_model`` is imported; of the
    classes defined in it, the alphabetically first is used as the model and
    the second as its config, which is created with default arguments.

    Raises:
        AssertionError: if ``model_name`` is not in ``model_list``.
    """
    short_module = model_name + "_model"
    assert model_name in model_list, f"Given model {model_name} not in our model zoo {model_list}."
    module = importlib.import_module("spikezoo.models." + short_module)
    # names of the classes declared in the module itself, in sorted order
    own_class_names = []
    for attr_name, member in inspect.getmembers(module):
        if inspect.isclass(member) and member.__module__ == module.__name__:
            own_class_names.append(attr_name)
    own_class_names.sort()
    model_cls: BaseModel = getattr(module, own_class_names[0])
    default_cfg: BaseModelConfig = getattr(module, own_class_names[1])()
    return model_cls(default_cfg)
|