spikezoo 0.2.3.2__py3-none-any.whl → 0.2.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spikezoo/datasets/base_dataset.py +1 -1
- spikezoo/datasets/realworld_dataset.py +2 -0
- spikezoo/datasets/szdata_dataset.py +1 -0
- spikezoo/datasets/uhsr_dataset.py +1 -2
- spikezoo/pipeline/base_pipeline.py +2 -0
- {spikezoo-0.2.3.2.dist-info → spikezoo-0.2.3.3.dist-info}/METADATA +12 -16
- {spikezoo-0.2.3.2.dist-info → spikezoo-0.2.3.3.dist-info}/RECORD +10 -11
- spikezoo/datasets/reds_ssir_dataset.py +0 -181
- {spikezoo-0.2.3.2.dist-info → spikezoo-0.2.3.3.dist-info}/LICENSE.txt +0 -0
- {spikezoo-0.2.3.2.dist-info → spikezoo-0.2.3.3.dist-info}/WHEEL +0 -0
- {spikezoo-0.2.3.2.dist-info → spikezoo-0.2.3.3.dist-info}/top_level.txt +0 -0
@@ -4,7 +4,6 @@ from dataclasses import dataclass
|
|
4
4
|
import numpy as np
|
5
5
|
import torch
|
6
6
|
|
7
|
-
|
8
7
|
@dataclass
|
9
8
|
class UHSRConfig(BaseDatasetConfig):
|
10
9
|
dataset_name: str = "uhsr"
|
@@ -16,7 +15,7 @@ class UHSRConfig(BaseDatasetConfig):
|
|
16
15
|
spike_length_test: int = 200
|
17
16
|
spike_dir_name: str = "spike"
|
18
17
|
img_dir_name: str = ""
|
19
|
-
|
18
|
+
rate: float = 1
|
20
19
|
|
21
20
|
class UHSR(BaseDataset):
|
22
21
|
def __init__(self, cfg: BaseDatasetConfig):
|
@@ -56,6 +56,7 @@ class Pipeline:
|
|
56
56
|
|
57
57
|
def _setup_model_data(self, model_cfg, dataset_cfg):
|
58
58
|
"""Model and Data setup."""
|
59
|
+
self.logger.info("Model and dataset is setting up...")
|
59
60
|
# model [1] build the model. [2] build the network.
|
60
61
|
self.model: BaseModel = build_model_name(model_cfg) if isinstance(model_cfg, str) else build_model_cfg(model_cfg)
|
61
62
|
self.model.build_network(mode="eval", version=self.cfg.version)
|
@@ -68,6 +69,7 @@ class Pipeline:
|
|
68
69
|
|
69
70
|
def _setup_pipeline(self):
|
70
71
|
"""Pipeline setup."""
|
72
|
+
self.logger.info("Pipeline is setting up...")
|
71
73
|
# save folder
|
72
74
|
self.thistime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")[:23]
|
73
75
|
self.save_folder = Path(f"results") if len(self.cfg.save_folder) == 0 else self.cfg.save_folder
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: spikezoo
|
3
|
-
Version: 0.2.3.
|
3
|
+
Version: 0.2.3.3
|
4
4
|
Summary: A deep learning toolbox for spike-to-image models.
|
5
5
|
Home-page: https://github.com/chenkang455/Spike-Zoo
|
6
6
|
Author: Kang Chen
|
@@ -36,15 +36,11 @@ Dynamic: requires-python
|
|
36
36
|
Dynamic: summary
|
37
37
|
|
38
38
|
<p align="center">
|
39
|
-
<
|
40
|
-
<img src="imgs/spike-zoo.png" width="500"/>
|
41
|
-
<br>
|
39
|
+
<img src="imgs/spike-zoo.png" width="350"/>
|
42
40
|
<p>
|
43
|
-
|
44
41
|
<h5 align="center">
|
45
42
|
|
46
43
|
[](https://github.com/chenkang455/Spike-Zoo/stargazers) [](https://github.com/chenkang455/Spike-Zoo/issues) <a href="https://badge.fury.io/py/spikezoo"><img src="https://badge.fury.io/py/spikezoo.svg" alt="PyPI version"></a> [](https://github.com/chenkang455/Spike-Zoo)
|
47
|
-
|
48
44
|
<p>
|
49
45
|
|
50
46
|
<!-- <h2 align="center">
|
@@ -165,16 +161,16 @@ We retrain all supported methods except `SPIKECLIP` on this REDS dataset (traini
|
|
165
161
|
|
166
162
|
| Method | PSNR | SSIM | LPIPS | NIQE | BRISQUE | PIQE | Params (M) | FLOPs (G) | Latency (ms) |
|
167
163
|
|----------------------|:-------:|:--------:|:---------:|:---------:|:----------:|:-------:|:------------:|:-----------:|:--------------:|
|
168
|
-
| `
|
169
|
-
| `
|
170
|
-
| `
|
171
|
-
| `
|
172
|
-
| `
|
173
|
-
| `
|
174
|
-
| `
|
175
|
-
| `
|
176
|
-
| `
|
177
|
-
| `
|
164
|
+
| `tfi` | 16.503 | 0.454 | 0.382 | 7.289 | 43.17 | 49.12 | 0.00 | 0.00 | 3.60 |
|
165
|
+
| `tfp` | 24.287 | 0.644 | 0.274 | 8.197 | 48.48 | 38.38 | 0.00 | 0.00 | 0.03 |
|
166
|
+
| `spikeclip` | 21.873 | 0.578 | 0.333 | 7.802 | 42.08 | 54.01 | 0.19 | 23.69 | 1.27 |
|
167
|
+
| `ssir` | 26.544 | 0.718 | 0.325 | 4.769 | 28.45 | 21.59 | 0.38 | 25.92 | 4.52 |
|
168
|
+
| `ssml` | 33.697 | 0.943 | 0.088 | 4.669 | 32.48 | 37.30 | 2.38 | 386.02 | 244.18 |
|
169
|
+
| `base` | 36.589 | 0.965 | 0.034 | 4.393 | 26.16 | 38.43 | 0.18 | 18.04 | 0.40 |
|
170
|
+
| `stir` | 37.914 | 0.973 | 0.027 | 4.236 | 25.10 | 39.18 | 5.08 | 43.31 | 21.07 |
|
171
|
+
| `wgse` | 39.036 | 0.978 | 0.023 | 4.231 | 25.76 | 44.11 | 3.81 | 415.26 | 73.62 |
|
172
|
+
| `spk2imgnet` | 39.154 | 0.978 | 0.022 | 4.243 | 25.20 | 43.09 | 3.90 | 1000.50 | 123.38 |
|
173
|
+
| `bsf` | 39.576 | 0.979 | 0.019 | 4.139 | 24.93 | 43.03 | 2.47 | 705.23 | 401.50 |
|
178
174
|
|
179
175
|
### 4. Model Usage
|
180
176
|
We also provide a direct interface for users interested in taking the spike-to-image model as a part of their work:
|
@@ -135,12 +135,11 @@ spikezoo/data/base/train/spike/203_part2_key_id151.dat,sha256=YEenLmbPvcxnKkVn3O
|
|
135
135
|
spikezoo/data/base/train/spike/203_part3_key_id151.dat,sha256=MY9nM6XzKj-P-tRQ33WZ3G5xulNTpAXKP0y8ZQo7AIQ,3762500
|
136
136
|
spikezoo/data/base/train/spike/203_part4_key_id151.dat,sha256=IVi2jics66YzpIF-WTkw47te4qOj9cjdgz56GmHpJKg,3762500
|
137
137
|
spikezoo/datasets/__init__.py,sha256=lRJsvCfgbe3qrd9BKTlG9dsgfIJbfXqWOynnlAcBiUI,3346
|
138
|
-
spikezoo/datasets/base_dataset.py,sha256=
|
139
|
-
spikezoo/datasets/realworld_dataset.py,sha256=
|
138
|
+
spikezoo/datasets/base_dataset.py,sha256=oQ_AqWuMlaKnR712_sJ4WiTbqqPqVsfcukDNpFDYXb0,5956
|
139
|
+
spikezoo/datasets/realworld_dataset.py,sha256=VqT6zcLa72DL3Lg8f4TThhYUa1xSIifsrPwpjvk2uBE,726
|
140
140
|
spikezoo/datasets/reds_base_dataset.py,sha256=W-IJv9H1bsKgp3RT3zsV40jw2PqY2M76jtIS4Qpif1o,859
|
141
|
-
spikezoo/datasets/
|
142
|
-
spikezoo/datasets/
|
143
|
-
spikezoo/datasets/uhsr_dataset.py,sha256=Ue0vfKMpmvR0TelQx8G4xdWB8HdSlb4HSjIVOyFe5oE,1168
|
141
|
+
spikezoo/datasets/szdata_dataset.py,sha256=xvgkZFHNSQ-Sk_rqmgRKAqpeb2gYpt_gmstJKJ8ooqU,870
|
142
|
+
spikezoo/datasets/uhsr_dataset.py,sha256=MKQeQsoCal10yMgHy3I7NJDgJJgkKgruH5tantP921A,1186
|
144
143
|
spikezoo/metrics/__init__.py,sha256=LIKeWNeEMZLANITQD68XJBOhDq7iHiKC7ExtdrXMyQs,3273
|
145
144
|
spikezoo/models/__init__.py,sha256=QZTELBoM3bUW8jZoxN4OuA2RYKeVUT1fboyeIuK8Rtk,1722
|
146
145
|
spikezoo/models/base_model.py,sha256=v3TD4AmjttTZUg0vEy736TOFdbbBgDLZg_RL-b4-vYM,9152
|
@@ -155,7 +154,7 @@ spikezoo/models/tfi_model.py,sha256=tgD_HsiXk9jGuh5f_Bh6c3BqJi1p5DWCVo4N1tp5fgs,
|
|
155
154
|
spikezoo/models/tfp_model.py,sha256=ihl1H__bWIbE9oair_t8rNJ5qnPJPKl-r_DpaO-0Sdk,663
|
156
155
|
spikezoo/models/wgse_model.py,sha256=Kl9uV-LeO0Lj7SuPQ9pglw1Khs2b-7miS3A_faL6WSU,805
|
157
156
|
spikezoo/pipeline/__init__.py,sha256=WPsukNR4cannwsghiukqNsWbWGH5DVPapR_Ly-WOU4Q,188
|
158
|
-
spikezoo/pipeline/base_pipeline.py,sha256=
|
157
|
+
spikezoo/pipeline/base_pipeline.py,sha256=3laGM-cMhNTSfCXx5jY53V8GY5BhPhHK48yjqm-Gre0,13478
|
159
158
|
spikezoo/pipeline/ensemble_pipeline.py,sha256=ljZkGiCCpxvpC04Aa-r_tvBnqcBpUVi9fl_878tJAcg,2555
|
160
159
|
spikezoo/pipeline/train_cfgs.py,sha256=6NO7DfPc7yjJfOrcIPQPfUPbUODz6eRKurEIDjMmaxA,3836
|
161
160
|
spikezoo/pipeline/train_pipeline.py,sha256=BgHUsdv33B_OKauOVclNt7yIPb-_O-93ZHLHIjrwWaA,8459
|
@@ -167,8 +166,8 @@ spikezoo/utils/other_utils.py,sha256=uWNWaII9Jv7fkWNfkAD9wD-4ID-GAzbR-gGYT-1FF_c
|
|
167
166
|
spikezoo/utils/scheduler_utils.py,sha256=5RBh-hl3-2y-IomxMs47T1p3JsbicZNYLza6q1uAKHo,828
|
168
167
|
spikezoo/utils/spike_utils.py,sha256=u4Haa6Sp5xFqs61ztvq161oXTA_aZmNW3VYUZcayNW0,4296
|
169
168
|
spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so,sha256=uXqu7ME---cZRRU5LUcLiNrjjtlOjxNwWHyTIQ10BGg,199088
|
170
|
-
spikezoo-0.2.3.
|
171
|
-
spikezoo-0.2.3.
|
172
|
-
spikezoo-0.2.3.
|
173
|
-
spikezoo-0.2.3.
|
174
|
-
spikezoo-0.2.3.
|
169
|
+
spikezoo-0.2.3.3.dist-info/LICENSE.txt,sha256=ukEi8E0PKq1dQGTXHUflg3rppLymwAhr7il9x-0nPgg,1062
|
170
|
+
spikezoo-0.2.3.3.dist-info/METADATA,sha256=3aqIRpJr6TAfjldU8ZLmWy6uuCUla400W5tGpkH-X2M,11941
|
171
|
+
spikezoo-0.2.3.3.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
172
|
+
spikezoo-0.2.3.3.dist-info/top_level.txt,sha256=xF2iuOstrACJh43NW4dsTwIdgKfXPXAb_Xzl3M1ricM,9
|
173
|
+
spikezoo-0.2.3.3.dist-info/RECORD,,
|
@@ -1,181 +0,0 @@
|
|
1
|
-
from torch.utils.data import Dataset
|
2
|
-
from pathlib import Path
|
3
|
-
from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
|
4
|
-
from dataclasses import dataclass
|
5
|
-
import re
|
6
|
-
|
7
|
-
|
8
|
-
@dataclass
|
9
|
-
class REDS_SSIRConfig(BaseDatasetConfig):
|
10
|
-
dataset_name: str = "reds_ssir"
|
11
|
-
root_dir: Path = Path(__file__).parent.parent / Path("data/REDS_SSIR")
|
12
|
-
train_width: int = 96
|
13
|
-
train_height: int = 96
|
14
|
-
test_width: int = 1280
|
15
|
-
test_height: int = 720
|
16
|
-
width: int = -1
|
17
|
-
height: int = -1
|
18
|
-
with_img: bool = True
|
19
|
-
spike_length_train: int = 41
|
20
|
-
spike_length_test: int = 301
|
21
|
-
|
22
|
-
# post process
|
23
|
-
def __post_init__(self):
|
24
|
-
self.root_dir = Path(self.root_dir) if isinstance(self.root_dir, str) else self.root_dir
|
25
|
-
# todo try download
|
26
|
-
assert self.root_dir.exists(), f"No files found in {self.root_dir} for the specified dataset `{self.dataset_name}`."
|
27
|
-
# train/test split
|
28
|
-
if self.split == "train":
|
29
|
-
self.spike_length = self.spike_length_train
|
30
|
-
self.width = self.train_width
|
31
|
-
self.height = self.train_height
|
32
|
-
else:
|
33
|
-
self.spike_length = self.spike_length_test
|
34
|
-
self.width = self.test_width
|
35
|
-
self.height = self.test_height
|
36
|
-
|
37
|
-
|
38
|
-
class REDS_SSIR(BaseDataset):
|
39
|
-
def __init__(self, cfg: BaseDatasetConfig):
|
40
|
-
super(REDS_SSIR, self).__init__(cfg)
|
41
|
-
|
42
|
-
def prepare_data(self):
|
43
|
-
"""Specify the spike and image files to be loaded."""
|
44
|
-
# spike/imgs dir train/test
|
45
|
-
if self.cfg.split == "train":
|
46
|
-
self.img_dir = self.cfg.root_dir / Path("crop_mini/spike/train/interp_20_alpha_0.40")
|
47
|
-
self.spike_dir = self.cfg.root_dir / Path("crop_mini/image/train/train_orig")
|
48
|
-
else:
|
49
|
-
self.img_dir = self.cfg.root_dir / Path("imgs/val/val_orig")
|
50
|
-
self.spike_dir = self.cfg.root_dir / Path("spike/val/interp_20_alpha_0.40")
|
51
|
-
# get files
|
52
|
-
self.spike_list = self.get_spike_files(self.spike_dir)
|
53
|
-
self.img_list = []
|
54
|
-
|
55
|
-
|
56
|
-
class sreds_train(torch.utils.data.Dataset):
|
57
|
-
def __init__(self, cfg):
|
58
|
-
self.cfg = cfg
|
59
|
-
self.pair_step = self.cfg["loader"]["pair_step"]
|
60
|
-
self.augmentor = Augmentor(crop_size=self.cfg["loader"]["crop_size"])
|
61
|
-
self.samples = self.collect_samples()
|
62
|
-
print("The samples num of training data: {:d}".format(len(self.samples)))
|
63
|
-
|
64
|
-
def confirm_exist(self, path_list_list):
|
65
|
-
for pl in path_list_list:
|
66
|
-
for p in pl:
|
67
|
-
if not osp.exists(p):
|
68
|
-
return 0
|
69
|
-
return 1
|
70
|
-
|
71
|
-
def collect_samples(self):
|
72
|
-
spike_path = osp.join(
|
73
|
-
self.cfg["data"]["root"], "crop_mini", "spike", "train", "interp_{:d}_alpha_{:.2f}".format(self.cfg["data"]["interp"], self.cfg["data"]["alpha"])
|
74
|
-
)
|
75
|
-
image_path = osp.join(self.cfg["data"]["root"], "crop_mini", "image", "train", "train_orig")
|
76
|
-
scene_list = sorted(os.listdir(spike_path))
|
77
|
-
samples = []
|
78
|
-
|
79
|
-
for scene in scene_list:
|
80
|
-
spike_dir = osp.join(spike_path, scene)
|
81
|
-
image_dir = osp.join(image_path, scene)
|
82
|
-
spk_path_list = sorted(os.listdir(spike_dir))
|
83
|
-
|
84
|
-
spklen = len(spk_path_list)
|
85
|
-
seq_len = self.cfg["model"]["seq_len"] + 2
|
86
|
-
"""
|
87
|
-
for st in range(0, spklen - ((spklen - self.pair_step) % seq_len) - seq_len, self.pair_step):
|
88
|
-
# 按照文件名称读取
|
89
|
-
spikes_path_list = [osp.join(spike_dir, spk_path_list[ii]) for ii in range(st, st+seq_len)]
|
90
|
-
images_path_list = [osp.join(image_dir, spk_path_list[ii][:-4]+'.png') for ii in range(st, st+seq_len)]
|
91
|
-
|
92
|
-
if(self.confirm_exist([spikes_path_list, images_path_list])):
|
93
|
-
s = {}
|
94
|
-
s['spikes_paths'] = spikes_path_list
|
95
|
-
s['images_paths'] = images_path_list
|
96
|
-
samples.append(s)
|
97
|
-
"""
|
98
|
-
# 按照文件名称读取
|
99
|
-
spikes_path_list = [osp.join(spike_dir, spk_path_list[ii]) for ii in range(spklen)]
|
100
|
-
images_path_list = [osp.join(image_dir, spk_path_list[ii][:-4] + ".png") for ii in range(spklen)]
|
101
|
-
|
102
|
-
if self.confirm_exist([spikes_path_list, images_path_list]):
|
103
|
-
s = {}
|
104
|
-
s["spikes_paths"] = spikes_path_list
|
105
|
-
s["images_paths"] = images_path_list
|
106
|
-
samples.append(s)
|
107
|
-
|
108
|
-
return samples
|
109
|
-
|
110
|
-
def _load_sample(self, s):
|
111
|
-
data = {}
|
112
|
-
|
113
|
-
data["spikes"] = [np.array(dat_to_spmat(p, size=(96, 96)), dtype=np.float32) for p in s["spikes_paths"]]
|
114
|
-
data["images"] = [read_img_gray(p) for p in s["images_paths"]]
|
115
|
-
|
116
|
-
data["spikes"], data["images"] = self.augmentor(data["spikes"], data["images"])
|
117
|
-
# print("data['spikes'][0].shape, data['images'][0].shape", data['spikes'][0].shape, data['images'][0].shape)
|
118
|
-
|
119
|
-
return data
|
120
|
-
|
121
|
-
def __len__(self):
|
122
|
-
return len(self.samples)
|
123
|
-
|
124
|
-
def __getitem__(self, index):
|
125
|
-
data = self._load_sample(self.samples[index])
|
126
|
-
return data
|
127
|
-
|
128
|
-
|
129
|
-
class sreds_test(torch.utils.data.Dataset):
|
130
|
-
def __init__(self, cfg):
|
131
|
-
self.cfg = cfg
|
132
|
-
self.samples = self.collect_samples()
|
133
|
-
print("The samples num of testing data: {:d}".format(len(self.samples)))
|
134
|
-
|
135
|
-
def confirm_exist(self, path_list_list):
|
136
|
-
for pl in path_list_list:
|
137
|
-
for p in pl:
|
138
|
-
if not osp.exists(p):
|
139
|
-
return 0
|
140
|
-
return 1
|
141
|
-
|
142
|
-
def collect_samples(self):
|
143
|
-
spike_path = osp.join(
|
144
|
-
self.cfg["data"]["root"], "spike", "val", "interp_{:d}_alpha_{:.2f}".format(self.cfg["data"]["interp"], self.cfg["data"]["alpha"])
|
145
|
-
)
|
146
|
-
image_path = osp.join(self.cfg["data"]["root"], "imgs", "val", "val_orig")
|
147
|
-
scene_list = sorted(os.listdir(spike_path))
|
148
|
-
samples = []
|
149
|
-
|
150
|
-
for scene in scene_list:
|
151
|
-
spike_dir = osp.join(spike_path, scene)
|
152
|
-
image_dir = osp.join(image_path, scene)
|
153
|
-
spk_path_list = sorted(os.listdir(spike_dir))
|
154
|
-
|
155
|
-
spklen = len(spk_path_list)
|
156
|
-
# seq_len = self.cfg['model']['seq_len']
|
157
|
-
|
158
|
-
# 按照文件名称读取
|
159
|
-
spikes_path_list = [osp.join(spike_dir, spk_path_list[ii]) for ii in range(spklen)]
|
160
|
-
images_path_list = [osp.join(image_dir, spk_path_list[ii][:-4] + ".png") for ii in range(spklen)]
|
161
|
-
|
162
|
-
if self.confirm_exist([spikes_path_list, images_path_list]):
|
163
|
-
s = {}
|
164
|
-
s["spikes_paths"] = spikes_path_list
|
165
|
-
s["images_paths"] = images_path_list
|
166
|
-
samples.append(s)
|
167
|
-
|
168
|
-
return samples
|
169
|
-
|
170
|
-
def _load_sample(self, s):
|
171
|
-
data = {}
|
172
|
-
data["spikes"] = [np.array(dat_to_spmat(p, size=(720, 1280)), dtype=np.float32) for p in s["spikes_paths"]]
|
173
|
-
data["images"] = [read_img_gray(p) for p in s["images_paths"]]
|
174
|
-
return data
|
175
|
-
|
176
|
-
def __len__(self):
|
177
|
-
return len(self.samples)
|
178
|
-
|
179
|
-
def __getitem__(self, index):
|
180
|
-
data = self._load_sample(self.samples[index])
|
181
|
-
return data
|
File without changes
|
File without changes
|
File without changes
|