spikezoo 0.2.3.2__py3-none-any.whl → 0.2.3.3__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- spikezoo/datasets/base_dataset.py +1 -1
- spikezoo/datasets/realworld_dataset.py +2 -0
- spikezoo/datasets/szdata_dataset.py +1 -0
- spikezoo/datasets/uhsr_dataset.py +1 -2
- spikezoo/pipeline/base_pipeline.py +2 -0
- {spikezoo-0.2.3.2.dist-info → spikezoo-0.2.3.3.dist-info}/METADATA +12 -16
- {spikezoo-0.2.3.2.dist-info → spikezoo-0.2.3.3.dist-info}/RECORD +10 -11
- spikezoo/datasets/reds_ssir_dataset.py +0 -181
- {spikezoo-0.2.3.2.dist-info → spikezoo-0.2.3.3.dist-info}/LICENSE.txt +0 -0
- {spikezoo-0.2.3.2.dist-info → spikezoo-0.2.3.3.dist-info}/WHEEL +0 -0
- {spikezoo-0.2.3.2.dist-info → spikezoo-0.2.3.3.dist-info}/top_level.txt +0 -0
@@ -4,7 +4,6 @@ from dataclasses import dataclass
|
|
4
4
|
import numpy as np
|
5
5
|
import torch
|
6
6
|
|
7
|
-
|
8
7
|
@dataclass
|
9
8
|
class UHSRConfig(BaseDatasetConfig):
|
10
9
|
dataset_name: str = "uhsr"
|
@@ -16,7 +15,7 @@ class UHSRConfig(BaseDatasetConfig):
|
|
16
15
|
spike_length_test: int = 200
|
17
16
|
spike_dir_name: str = "spike"
|
18
17
|
img_dir_name: str = ""
|
19
|
-
|
18
|
+
rate: float = 1
|
20
19
|
|
21
20
|
class UHSR(BaseDataset):
|
22
21
|
def __init__(self, cfg: BaseDatasetConfig):
|
@@ -56,6 +56,7 @@ class Pipeline:
|
|
56
56
|
|
57
57
|
def _setup_model_data(self, model_cfg, dataset_cfg):
|
58
58
|
"""Model and Data setup."""
|
59
|
+
self.logger.info("Model and dataset is setting up...")
|
59
60
|
# model [1] build the model. [2] build the network.
|
60
61
|
self.model: BaseModel = build_model_name(model_cfg) if isinstance(model_cfg, str) else build_model_cfg(model_cfg)
|
61
62
|
self.model.build_network(mode="eval", version=self.cfg.version)
|
@@ -68,6 +69,7 @@ class Pipeline:
|
|
68
69
|
|
69
70
|
def _setup_pipeline(self):
|
70
71
|
"""Pipeline setup."""
|
72
|
+
self.logger.info("Pipeline is setting up...")
|
71
73
|
# save folder
|
72
74
|
self.thistime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")[:23]
|
73
75
|
self.save_folder = Path(f"results") if len(self.cfg.save_folder) == 0 else self.cfg.save_folder
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: spikezoo
|
3
|
-
Version: 0.2.3.
|
3
|
+
Version: 0.2.3.3
|
4
4
|
Summary: A deep learning toolbox for spike-to-image models.
|
5
5
|
Home-page: https://github.com/chenkang455/Spike-Zoo
|
6
6
|
Author: Kang Chen
|
@@ -36,15 +36,11 @@ Dynamic: requires-python
|
|
36
36
|
Dynamic: summary
|
37
37
|
|
38
38
|
<p align="center">
|
39
|
-
<
|
40
|
-
<img src="imgs/spike-zoo.png" width="500"/>
|
41
|
-
<br>
|
39
|
+
<img src="imgs/spike-zoo.png" width="350"/>
|
42
40
|
<p>
|
43
|
-
|
44
41
|
<h5 align="center">
|
45
42
|
|
46
43
|
[data:image/s3,"s3://crabby-images/8f6ba/8f6bad60765d4c0565fd31a12c603a2259433773" alt="GitHub repo stars"](https://github.com/chenkang455/Spike-Zoo/stargazers) [data:image/s3,"s3://crabby-images/239d1/239d1b8c04ba7ce2bfe65b5e92c17da7957ee27d" alt="GitHub Issues"](https://github.com/chenkang455/Spike-Zoo/issues) <a href="https://badge.fury.io/py/spikezoo"><img src="https://badge.fury.io/py/spikezoo.svg" alt="PyPI version"></a> [data:image/s3,"s3://crabby-images/b47ac/b47ac140313f21e07d58e3af01454f0def406d87" alt="License"](https://github.com/chenkang455/Spike-Zoo)
|
47
|
-
|
48
44
|
<p>
|
49
45
|
|
50
46
|
<!-- <h2 align="center">
|
@@ -165,16 +161,16 @@ We retrain all supported methods except `SPIKECLIP` on this REDS dataset (traini
|
|
165
161
|
|
166
162
|
| Method | PSNR | SSIM | LPIPS | NIQE | BRISQUE | PIQE | Params (M) | FLOPs (G) | Latency (ms) |
|
167
163
|
|----------------------|:-------:|:--------:|:---------:|:---------:|:----------:|:-------:|:------------:|:-----------:|:--------------:|
|
168
|
-
| `
|
169
|
-
| `
|
170
|
-
| `
|
171
|
-
| `
|
172
|
-
| `
|
173
|
-
| `
|
174
|
-
| `
|
175
|
-
| `
|
176
|
-
| `
|
177
|
-
| `
|
164
|
+
| `tfi` | 16.503 | 0.454 | 0.382 | 7.289 | 43.17 | 49.12 | 0.00 | 0.00 | 3.60 |
|
165
|
+
| `tfp` | 24.287 | 0.644 | 0.274 | 8.197 | 48.48 | 38.38 | 0.00 | 0.00 | 0.03 |
|
166
|
+
| `spikeclip` | 21.873 | 0.578 | 0.333 | 7.802 | 42.08 | 54.01 | 0.19 | 23.69 | 1.27 |
|
167
|
+
| `ssir` | 26.544 | 0.718 | 0.325 | 4.769 | 28.45 | 21.59 | 0.38 | 25.92 | 4.52 |
|
168
|
+
| `ssml` | 33.697 | 0.943 | 0.088 | 4.669 | 32.48 | 37.30 | 2.38 | 386.02 | 244.18 |
|
169
|
+
| `base` | 36.589 | 0.965 | 0.034 | 4.393 | 26.16 | 38.43 | 0.18 | 18.04 | 0.40 |
|
170
|
+
| `stir` | 37.914 | 0.973 | 0.027 | 4.236 | 25.10 | 39.18 | 5.08 | 43.31 | 21.07 |
|
171
|
+
| `wgse` | 39.036 | 0.978 | 0.023 | 4.231 | 25.76 | 44.11 | 3.81 | 415.26 | 73.62 |
|
172
|
+
| `spk2imgnet` | 39.154 | 0.978 | 0.022 | 4.243 | 25.20 | 43.09 | 3.90 | 1000.50 | 123.38 |
|
173
|
+
| `bsf` | 39.576 | 0.979 | 0.019 | 4.139 | 24.93 | 43.03 | 2.47 | 705.23 | 401.50 |
|
178
174
|
|
179
175
|
### 4. Model Usage
|
180
176
|
We also provide a direct interface for users interested in taking the spike-to-image model as a part of their work:
|
@@ -135,12 +135,11 @@ spikezoo/data/base/train/spike/203_part2_key_id151.dat,sha256=YEenLmbPvcxnKkVn3O
|
|
135
135
|
spikezoo/data/base/train/spike/203_part3_key_id151.dat,sha256=MY9nM6XzKj-P-tRQ33WZ3G5xulNTpAXKP0y8ZQo7AIQ,3762500
|
136
136
|
spikezoo/data/base/train/spike/203_part4_key_id151.dat,sha256=IVi2jics66YzpIF-WTkw47te4qOj9cjdgz56GmHpJKg,3762500
|
137
137
|
spikezoo/datasets/__init__.py,sha256=lRJsvCfgbe3qrd9BKTlG9dsgfIJbfXqWOynnlAcBiUI,3346
|
138
|
-
spikezoo/datasets/base_dataset.py,sha256=
|
139
|
-
spikezoo/datasets/realworld_dataset.py,sha256=
|
138
|
+
spikezoo/datasets/base_dataset.py,sha256=oQ_AqWuMlaKnR712_sJ4WiTbqqPqVsfcukDNpFDYXb0,5956
|
139
|
+
spikezoo/datasets/realworld_dataset.py,sha256=VqT6zcLa72DL3Lg8f4TThhYUa1xSIifsrPwpjvk2uBE,726
|
140
140
|
spikezoo/datasets/reds_base_dataset.py,sha256=W-IJv9H1bsKgp3RT3zsV40jw2PqY2M76jtIS4Qpif1o,859
|
141
|
-
spikezoo/datasets/
|
142
|
-
spikezoo/datasets/
|
143
|
-
spikezoo/datasets/uhsr_dataset.py,sha256=Ue0vfKMpmvR0TelQx8G4xdWB8HdSlb4HSjIVOyFe5oE,1168
|
141
|
+
spikezoo/datasets/szdata_dataset.py,sha256=xvgkZFHNSQ-Sk_rqmgRKAqpeb2gYpt_gmstJKJ8ooqU,870
|
142
|
+
spikezoo/datasets/uhsr_dataset.py,sha256=MKQeQsoCal10yMgHy3I7NJDgJJgkKgruH5tantP921A,1186
|
144
143
|
spikezoo/metrics/__init__.py,sha256=LIKeWNeEMZLANITQD68XJBOhDq7iHiKC7ExtdrXMyQs,3273
|
145
144
|
spikezoo/models/__init__.py,sha256=QZTELBoM3bUW8jZoxN4OuA2RYKeVUT1fboyeIuK8Rtk,1722
|
146
145
|
spikezoo/models/base_model.py,sha256=v3TD4AmjttTZUg0vEy736TOFdbbBgDLZg_RL-b4-vYM,9152
|
@@ -155,7 +154,7 @@ spikezoo/models/tfi_model.py,sha256=tgD_HsiXk9jGuh5f_Bh6c3BqJi1p5DWCVo4N1tp5fgs,
|
|
155
154
|
spikezoo/models/tfp_model.py,sha256=ihl1H__bWIbE9oair_t8rNJ5qnPJPKl-r_DpaO-0Sdk,663
|
156
155
|
spikezoo/models/wgse_model.py,sha256=Kl9uV-LeO0Lj7SuPQ9pglw1Khs2b-7miS3A_faL6WSU,805
|
157
156
|
spikezoo/pipeline/__init__.py,sha256=WPsukNR4cannwsghiukqNsWbWGH5DVPapR_Ly-WOU4Q,188
|
158
|
-
spikezoo/pipeline/base_pipeline.py,sha256=
|
157
|
+
spikezoo/pipeline/base_pipeline.py,sha256=3laGM-cMhNTSfCXx5jY53V8GY5BhPhHK48yjqm-Gre0,13478
|
159
158
|
spikezoo/pipeline/ensemble_pipeline.py,sha256=ljZkGiCCpxvpC04Aa-r_tvBnqcBpUVi9fl_878tJAcg,2555
|
160
159
|
spikezoo/pipeline/train_cfgs.py,sha256=6NO7DfPc7yjJfOrcIPQPfUPbUODz6eRKurEIDjMmaxA,3836
|
161
160
|
spikezoo/pipeline/train_pipeline.py,sha256=BgHUsdv33B_OKauOVclNt7yIPb-_O-93ZHLHIjrwWaA,8459
|
@@ -167,8 +166,8 @@ spikezoo/utils/other_utils.py,sha256=uWNWaII9Jv7fkWNfkAD9wD-4ID-GAzbR-gGYT-1FF_c
|
|
167
166
|
spikezoo/utils/scheduler_utils.py,sha256=5RBh-hl3-2y-IomxMs47T1p3JsbicZNYLza6q1uAKHo,828
|
168
167
|
spikezoo/utils/spike_utils.py,sha256=u4Haa6Sp5xFqs61ztvq161oXTA_aZmNW3VYUZcayNW0,4296
|
169
168
|
spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so,sha256=uXqu7ME---cZRRU5LUcLiNrjjtlOjxNwWHyTIQ10BGg,199088
|
170
|
-
spikezoo-0.2.3.
|
171
|
-
spikezoo-0.2.3.
|
172
|
-
spikezoo-0.2.3.
|
173
|
-
spikezoo-0.2.3.
|
174
|
-
spikezoo-0.2.3.
|
169
|
+
spikezoo-0.2.3.3.dist-info/LICENSE.txt,sha256=ukEi8E0PKq1dQGTXHUflg3rppLymwAhr7il9x-0nPgg,1062
|
170
|
+
spikezoo-0.2.3.3.dist-info/METADATA,sha256=3aqIRpJr6TAfjldU8ZLmWy6uuCUla400W5tGpkH-X2M,11941
|
171
|
+
spikezoo-0.2.3.3.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
172
|
+
spikezoo-0.2.3.3.dist-info/top_level.txt,sha256=xF2iuOstrACJh43NW4dsTwIdgKfXPXAb_Xzl3M1ricM,9
|
173
|
+
spikezoo-0.2.3.3.dist-info/RECORD,,
|
@@ -1,181 +0,0 @@
|
|
1
|
-
from torch.utils.data import Dataset
|
2
|
-
from pathlib import Path
|
3
|
-
from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
|
4
|
-
from dataclasses import dataclass
|
5
|
-
import re
|
6
|
-
|
7
|
-
|
8
|
-
@dataclass
|
9
|
-
class REDS_SSIRConfig(BaseDatasetConfig):
|
10
|
-
dataset_name: str = "reds_ssir"
|
11
|
-
root_dir: Path = Path(__file__).parent.parent / Path("data/REDS_SSIR")
|
12
|
-
train_width: int = 96
|
13
|
-
train_height: int = 96
|
14
|
-
test_width: int = 1280
|
15
|
-
test_height: int = 720
|
16
|
-
width: int = -1
|
17
|
-
height: int = -1
|
18
|
-
with_img: bool = True
|
19
|
-
spike_length_train: int = 41
|
20
|
-
spike_length_test: int = 301
|
21
|
-
|
22
|
-
# post process
|
23
|
-
def __post_init__(self):
|
24
|
-
self.root_dir = Path(self.root_dir) if isinstance(self.root_dir, str) else self.root_dir
|
25
|
-
# todo try download
|
26
|
-
assert self.root_dir.exists(), f"No files found in {self.root_dir} for the specified dataset `{self.dataset_name}`."
|
27
|
-
# train/test split
|
28
|
-
if self.split == "train":
|
29
|
-
self.spike_length = self.spike_length_train
|
30
|
-
self.width = self.train_width
|
31
|
-
self.height = self.train_height
|
32
|
-
else:
|
33
|
-
self.spike_length = self.spike_length_test
|
34
|
-
self.width = self.test_width
|
35
|
-
self.height = self.test_height
|
36
|
-
|
37
|
-
|
38
|
-
class REDS_SSIR(BaseDataset):
|
39
|
-
def __init__(self, cfg: BaseDatasetConfig):
|
40
|
-
super(REDS_SSIR, self).__init__(cfg)
|
41
|
-
|
42
|
-
def prepare_data(self):
|
43
|
-
"""Specify the spike and image files to be loaded."""
|
44
|
-
# spike/imgs dir train/test
|
45
|
-
if self.cfg.split == "train":
|
46
|
-
self.img_dir = self.cfg.root_dir / Path("crop_mini/spike/train/interp_20_alpha_0.40")
|
47
|
-
self.spike_dir = self.cfg.root_dir / Path("crop_mini/image/train/train_orig")
|
48
|
-
else:
|
49
|
-
self.img_dir = self.cfg.root_dir / Path("imgs/val/val_orig")
|
50
|
-
self.spike_dir = self.cfg.root_dir / Path("spike/val/interp_20_alpha_0.40")
|
51
|
-
# get files
|
52
|
-
self.spike_list = self.get_spike_files(self.spike_dir)
|
53
|
-
self.img_list = []
|
54
|
-
|
55
|
-
|
56
|
-
class sreds_train(torch.utils.data.Dataset):
|
57
|
-
def __init__(self, cfg):
|
58
|
-
self.cfg = cfg
|
59
|
-
self.pair_step = self.cfg["loader"]["pair_step"]
|
60
|
-
self.augmentor = Augmentor(crop_size=self.cfg["loader"]["crop_size"])
|
61
|
-
self.samples = self.collect_samples()
|
62
|
-
print("The samples num of training data: {:d}".format(len(self.samples)))
|
63
|
-
|
64
|
-
def confirm_exist(self, path_list_list):
|
65
|
-
for pl in path_list_list:
|
66
|
-
for p in pl:
|
67
|
-
if not osp.exists(p):
|
68
|
-
return 0
|
69
|
-
return 1
|
70
|
-
|
71
|
-
def collect_samples(self):
|
72
|
-
spike_path = osp.join(
|
73
|
-
self.cfg["data"]["root"], "crop_mini", "spike", "train", "interp_{:d}_alpha_{:.2f}".format(self.cfg["data"]["interp"], self.cfg["data"]["alpha"])
|
74
|
-
)
|
75
|
-
image_path = osp.join(self.cfg["data"]["root"], "crop_mini", "image", "train", "train_orig")
|
76
|
-
scene_list = sorted(os.listdir(spike_path))
|
77
|
-
samples = []
|
78
|
-
|
79
|
-
for scene in scene_list:
|
80
|
-
spike_dir = osp.join(spike_path, scene)
|
81
|
-
image_dir = osp.join(image_path, scene)
|
82
|
-
spk_path_list = sorted(os.listdir(spike_dir))
|
83
|
-
|
84
|
-
spklen = len(spk_path_list)
|
85
|
-
seq_len = self.cfg["model"]["seq_len"] + 2
|
86
|
-
"""
|
87
|
-
for st in range(0, spklen - ((spklen - self.pair_step) % seq_len) - seq_len, self.pair_step):
|
88
|
-
# 按照文件名称读取
|
89
|
-
spikes_path_list = [osp.join(spike_dir, spk_path_list[ii]) for ii in range(st, st+seq_len)]
|
90
|
-
images_path_list = [osp.join(image_dir, spk_path_list[ii][:-4]+'.png') for ii in range(st, st+seq_len)]
|
91
|
-
|
92
|
-
if(self.confirm_exist([spikes_path_list, images_path_list])):
|
93
|
-
s = {}
|
94
|
-
s['spikes_paths'] = spikes_path_list
|
95
|
-
s['images_paths'] = images_path_list
|
96
|
-
samples.append(s)
|
97
|
-
"""
|
98
|
-
# 按照文件名称读取
|
99
|
-
spikes_path_list = [osp.join(spike_dir, spk_path_list[ii]) for ii in range(spklen)]
|
100
|
-
images_path_list = [osp.join(image_dir, spk_path_list[ii][:-4] + ".png") for ii in range(spklen)]
|
101
|
-
|
102
|
-
if self.confirm_exist([spikes_path_list, images_path_list]):
|
103
|
-
s = {}
|
104
|
-
s["spikes_paths"] = spikes_path_list
|
105
|
-
s["images_paths"] = images_path_list
|
106
|
-
samples.append(s)
|
107
|
-
|
108
|
-
return samples
|
109
|
-
|
110
|
-
def _load_sample(self, s):
|
111
|
-
data = {}
|
112
|
-
|
113
|
-
data["spikes"] = [np.array(dat_to_spmat(p, size=(96, 96)), dtype=np.float32) for p in s["spikes_paths"]]
|
114
|
-
data["images"] = [read_img_gray(p) for p in s["images_paths"]]
|
115
|
-
|
116
|
-
data["spikes"], data["images"] = self.augmentor(data["spikes"], data["images"])
|
117
|
-
# print("data['spikes'][0].shape, data['images'][0].shape", data['spikes'][0].shape, data['images'][0].shape)
|
118
|
-
|
119
|
-
return data
|
120
|
-
|
121
|
-
def __len__(self):
|
122
|
-
return len(self.samples)
|
123
|
-
|
124
|
-
def __getitem__(self, index):
|
125
|
-
data = self._load_sample(self.samples[index])
|
126
|
-
return data
|
127
|
-
|
128
|
-
|
129
|
-
class sreds_test(torch.utils.data.Dataset):
|
130
|
-
def __init__(self, cfg):
|
131
|
-
self.cfg = cfg
|
132
|
-
self.samples = self.collect_samples()
|
133
|
-
print("The samples num of testing data: {:d}".format(len(self.samples)))
|
134
|
-
|
135
|
-
def confirm_exist(self, path_list_list):
|
136
|
-
for pl in path_list_list:
|
137
|
-
for p in pl:
|
138
|
-
if not osp.exists(p):
|
139
|
-
return 0
|
140
|
-
return 1
|
141
|
-
|
142
|
-
def collect_samples(self):
|
143
|
-
spike_path = osp.join(
|
144
|
-
self.cfg["data"]["root"], "spike", "val", "interp_{:d}_alpha_{:.2f}".format(self.cfg["data"]["interp"], self.cfg["data"]["alpha"])
|
145
|
-
)
|
146
|
-
image_path = osp.join(self.cfg["data"]["root"], "imgs", "val", "val_orig")
|
147
|
-
scene_list = sorted(os.listdir(spike_path))
|
148
|
-
samples = []
|
149
|
-
|
150
|
-
for scene in scene_list:
|
151
|
-
spike_dir = osp.join(spike_path, scene)
|
152
|
-
image_dir = osp.join(image_path, scene)
|
153
|
-
spk_path_list = sorted(os.listdir(spike_dir))
|
154
|
-
|
155
|
-
spklen = len(spk_path_list)
|
156
|
-
# seq_len = self.cfg['model']['seq_len']
|
157
|
-
|
158
|
-
# 按照文件名称读取
|
159
|
-
spikes_path_list = [osp.join(spike_dir, spk_path_list[ii]) for ii in range(spklen)]
|
160
|
-
images_path_list = [osp.join(image_dir, spk_path_list[ii][:-4] + ".png") for ii in range(spklen)]
|
161
|
-
|
162
|
-
if self.confirm_exist([spikes_path_list, images_path_list]):
|
163
|
-
s = {}
|
164
|
-
s["spikes_paths"] = spikes_path_list
|
165
|
-
s["images_paths"] = images_path_list
|
166
|
-
samples.append(s)
|
167
|
-
|
168
|
-
return samples
|
169
|
-
|
170
|
-
def _load_sample(self, s):
|
171
|
-
data = {}
|
172
|
-
data["spikes"] = [np.array(dat_to_spmat(p, size=(720, 1280)), dtype=np.float32) for p in s["spikes_paths"]]
|
173
|
-
data["images"] = [read_img_gray(p) for p in s["images_paths"]]
|
174
|
-
return data
|
175
|
-
|
176
|
-
def __len__(self):
|
177
|
-
return len(self.samples)
|
178
|
-
|
179
|
-
def __getitem__(self, index):
|
180
|
-
data = self._load_sample(self.samples[index])
|
181
|
-
return data
|
File without changes
|
File without changes
|
File without changes
|