spikezoo 0.2.3.2__py3-none-any.whl → 0.2.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -47,7 +47,7 @@ class BaseDatasetConfig:
47
47
  "Crop size."
48
48
  crop_size: tuple = (-1, -1)
49
49
  "Rate. (-1 denotes variant)"
50
- rate: float = 1
50
+ rate: float = 0.6
51
51
 
52
52
  # post process
53
53
  def __post_init__(self):
@@ -12,6 +12,8 @@ class RealWorldConfig(BaseDatasetConfig):
12
12
  with_img: bool = False
13
13
  spike_length_train: int = -1
14
14
  spike_length_test: int = -1
15
+ rate: float = 1
16
+
15
17
 
16
18
 
17
19
  class RealWorld(BaseDataset):
@@ -17,6 +17,7 @@ class SZDataConfig(BaseDatasetConfig):
17
17
  spike_length_test: int = -1
18
18
  spike_dir_name: str = "spike_data"
19
19
  img_dir_name: str = "sharp_data"
20
+ rate: float = 1
20
21
 
21
22
  class SZData(BaseDataset):
22
23
  def __init__(self, cfg: BaseDatasetConfig):
@@ -4,7 +4,6 @@ from dataclasses import dataclass
4
4
  import numpy as np
5
5
  import torch
6
6
 
7
-
8
7
  @dataclass
9
8
  class UHSRConfig(BaseDatasetConfig):
10
9
  dataset_name: str = "uhsr"
@@ -16,7 +15,7 @@ class UHSRConfig(BaseDatasetConfig):
16
15
  spike_length_test: int = 200
17
16
  spike_dir_name: str = "spike"
18
17
  img_dir_name: str = ""
19
-
18
+ rate: float = 1
20
19
 
21
20
  class UHSR(BaseDataset):
22
21
  def __init__(self, cfg: BaseDatasetConfig):
@@ -56,6 +56,7 @@ class Pipeline:
56
56
 
57
57
  def _setup_model_data(self, model_cfg, dataset_cfg):
58
58
  """Model and Data setup."""
59
+ self.logger.info("Model and dataset is setting up...")
59
60
  # model [1] build the model. [2] build the network.
60
61
  self.model: BaseModel = build_model_name(model_cfg) if isinstance(model_cfg, str) else build_model_cfg(model_cfg)
61
62
  self.model.build_network(mode="eval", version=self.cfg.version)
@@ -68,6 +69,7 @@ class Pipeline:
68
69
 
69
70
  def _setup_pipeline(self):
70
71
  """Pipeline setup."""
72
+ self.logger.info("Pipeline is setting up...")
71
73
  # save folder
72
74
  self.thistime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")[:23]
73
75
  self.save_folder = Path(f"results") if len(self.cfg.save_folder) == 0 else self.cfg.save_folder
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: spikezoo
3
- Version: 0.2.3.2
3
+ Version: 0.2.3.3
4
4
  Summary: A deep learning toolbox for spike-to-image models.
5
5
  Home-page: https://github.com/chenkang455/Spike-Zoo
6
6
  Author: Kang Chen
@@ -36,15 +36,11 @@ Dynamic: requires-python
36
36
  Dynamic: summary
37
37
 
38
38
  <p align="center">
39
- <br>
40
- <img src="imgs/spike-zoo.png" width="500"/>
41
- <br>
39
+ <img src="imgs/spike-zoo.png" width="350"/>
42
40
  <p>
43
-
44
41
  <h5 align="center">
45
42
 
46
43
  [![GitHub repo stars](https://img.shields.io/github/stars/chenkang455/Spike-Zoo?style=flat&logo=github&logoColor=whitesmoke&label=Stars)](https://github.com/chenkang455/Spike-Zoo/stargazers) [![GitHub Issues](https://img.shields.io/github/issues/chenkang455/Spike-Zoo?style=flat&logo=github&logoColor=whitesmoke&label=Stars)](https://github.com/chenkang455/Spike-Zoo/issues) <a href="https://badge.fury.io/py/spikezoo"><img src="https://badge.fury.io/py/spikezoo.svg" alt="PyPI version"></a> [![License](https://img.shields.io/badge/License-MIT-yellow)](https://github.com/chenkang455/Spike-Zoo)
47
-
48
44
  <p>
49
45
 
50
46
  <!-- <h2 align="center">
@@ -165,16 +161,16 @@ We retrain all supported methods except `SPIKECLIP` on this REDS dataset (traini
165
161
 
166
162
  | Method | PSNR | SSIM | LPIPS | NIQE | BRISQUE | PIQE | Params (M) | FLOPs (G) | Latency (ms) |
167
163
  |----------------------|:-------:|:--------:|:---------:|:---------:|:----------:|:-------:|:------------:|:-----------:|:--------------:|
168
- | `TFI` | 16.503 | 0.454 | 0.382 | 7.289 | 43.17 | 49.12 | 0.00 | 0.00 | 3.60 |
169
- | `TFP` | 24.287 | 0.644 | 0.274 | 8.197 | 48.48 | 38.38 | 0.00 | 0.00 | 0.03 |
170
- | `SPIKECLIP` | 21.873 | 0.578 | 0.333 | 7.802 | 42.08 | 54.01 | 0.19 | 23.69 | 1.27 |
171
- | `SSIR` | 26.544 | 0.718 | 0.325 | 4.769 | 28.45 | 21.59 | 0.38 | 25.92 | 4.52 |
172
- | `SSML` | 33.697 | 0.943 | 0.088 | 4.669 | 32.48 | 37.30 | 2.38 | 386.02 | 244.18 |
173
- | `BASE` | 36.589 | 0.965 | 0.034 | 4.393 | 26.16 | 38.43 | 0.18 | 18.04 | 0.40 |
174
- | `STIR` | 37.914 | 0.973 | 0.027 | 4.236 | 25.10 | 39.18 | 5.08 | 43.31 | 21.07 |
175
- | `WGSE` | 39.036 | 0.978 | 0.023 | 4.231 | 25.76 | 44.11 | 3.81 | 415.26 | 73.62 |
176
- | `SPK2IMGNET` | 39.154 | 0.978 | 0.022 | 4.243 | 25.20 | 43.09 | 3.90 | 1000.50 | 123.38 |
177
- | `BSF` | 39.576 | 0.979 | 0.019 | 4.139 | 24.93 | 43.03 | 2.47 | 705.23 | 401.50 |
164
+ | `tfi` | 16.503 | 0.454 | 0.382 | 7.289 | 43.17 | 49.12 | 0.00 | 0.00 | 3.60 |
165
+ | `tfp` | 24.287 | 0.644 | 0.274 | 8.197 | 48.48 | 38.38 | 0.00 | 0.00 | 0.03 |
166
+ | `spikeclip` | 21.873 | 0.578 | 0.333 | 7.802 | 42.08 | 54.01 | 0.19 | 23.69 | 1.27 |
167
+ | `ssir` | 26.544 | 0.718 | 0.325 | 4.769 | 28.45 | 21.59 | 0.38 | 25.92 | 4.52 |
168
+ | `ssml` | 33.697 | 0.943 | 0.088 | 4.669 | 32.48 | 37.30 | 2.38 | 386.02 | 244.18 |
169
+ | `base` | 36.589 | 0.965 | 0.034 | 4.393 | 26.16 | 38.43 | 0.18 | 18.04 | 0.40 |
170
+ | `stir` | 37.914 | 0.973 | 0.027 | 4.236 | 25.10 | 39.18 | 5.08 | 43.31 | 21.07 |
171
+ | `wgse` | 39.036 | 0.978 | 0.023 | 4.231 | 25.76 | 44.11 | 3.81 | 415.26 | 73.62 |
172
+ | `spk2imgnet` | 39.154 | 0.978 | 0.022 | 4.243 | 25.20 | 43.09 | 3.90 | 1000.50 | 123.38 |
173
+ | `bsf` | 39.576 | 0.979 | 0.019 | 4.139 | 24.93 | 43.03 | 2.47 | 705.23 | 401.50 |
178
174
 
179
175
  ### 4. Model Usage
180
176
  We also provide a direct interface for users interested in taking the spike-to-image model as a part of their work:
@@ -135,12 +135,11 @@ spikezoo/data/base/train/spike/203_part2_key_id151.dat,sha256=YEenLmbPvcxnKkVn3O
135
135
  spikezoo/data/base/train/spike/203_part3_key_id151.dat,sha256=MY9nM6XzKj-P-tRQ33WZ3G5xulNTpAXKP0y8ZQo7AIQ,3762500
136
136
  spikezoo/data/base/train/spike/203_part4_key_id151.dat,sha256=IVi2jics66YzpIF-WTkw47te4qOj9cjdgz56GmHpJKg,3762500
137
137
  spikezoo/datasets/__init__.py,sha256=lRJsvCfgbe3qrd9BKTlG9dsgfIJbfXqWOynnlAcBiUI,3346
138
- spikezoo/datasets/base_dataset.py,sha256=gaMcrgWB3JLfc5lAM5wZjgMj72lZaO5otGfXOwRRiL8,5954
139
- spikezoo/datasets/realworld_dataset.py,sha256=CSyXj_uo0y1a4TlLDSTkB0CyFiFdZiU1phrWdkwvqgg,701
138
+ spikezoo/datasets/base_dataset.py,sha256=oQ_AqWuMlaKnR712_sJ4WiTbqqPqVsfcukDNpFDYXb0,5956
139
+ spikezoo/datasets/realworld_dataset.py,sha256=VqT6zcLa72DL3Lg8f4TThhYUa1xSIifsrPwpjvk2uBE,726
140
140
  spikezoo/datasets/reds_base_dataset.py,sha256=W-IJv9H1bsKgp3RT3zsV40jw2PqY2M76jtIS4Qpif1o,859
141
- spikezoo/datasets/reds_ssir_dataset.py,sha256=t0hm6PUWX5hfvSXB0UEv_JuihIhc5-mufrfur5bIq_0,7076
142
- spikezoo/datasets/szdata_dataset.py,sha256=7lSqNzWkVUVTVufPhm7AjpLghpZvjfu1X9c0sCsrjwM,850
143
- spikezoo/datasets/uhsr_dataset.py,sha256=Ue0vfKMpmvR0TelQx8G4xdWB8HdSlb4HSjIVOyFe5oE,1168
141
+ spikezoo/datasets/szdata_dataset.py,sha256=xvgkZFHNSQ-Sk_rqmgRKAqpeb2gYpt_gmstJKJ8ooqU,870
142
+ spikezoo/datasets/uhsr_dataset.py,sha256=MKQeQsoCal10yMgHy3I7NJDgJJgkKgruH5tantP921A,1186
144
143
  spikezoo/metrics/__init__.py,sha256=LIKeWNeEMZLANITQD68XJBOhDq7iHiKC7ExtdrXMyQs,3273
145
144
  spikezoo/models/__init__.py,sha256=QZTELBoM3bUW8jZoxN4OuA2RYKeVUT1fboyeIuK8Rtk,1722
146
145
  spikezoo/models/base_model.py,sha256=v3TD4AmjttTZUg0vEy736TOFdbbBgDLZg_RL-b4-vYM,9152
@@ -155,7 +154,7 @@ spikezoo/models/tfi_model.py,sha256=tgD_HsiXk9jGuh5f_Bh6c3BqJi1p5DWCVo4N1tp5fgs,
155
154
  spikezoo/models/tfp_model.py,sha256=ihl1H__bWIbE9oair_t8rNJ5qnPJPKl-r_DpaO-0Sdk,663
156
155
  spikezoo/models/wgse_model.py,sha256=Kl9uV-LeO0Lj7SuPQ9pglw1Khs2b-7miS3A_faL6WSU,805
157
156
  spikezoo/pipeline/__init__.py,sha256=WPsukNR4cannwsghiukqNsWbWGH5DVPapR_Ly-WOU4Q,188
158
- spikezoo/pipeline/base_pipeline.py,sha256=ns2rBRbKaq8M9yDtWOiD-t6lo_5_c5rE2ZWH9sR78P4,13361
157
+ spikezoo/pipeline/base_pipeline.py,sha256=3laGM-cMhNTSfCXx5jY53V8GY5BhPhHK48yjqm-Gre0,13478
159
158
  spikezoo/pipeline/ensemble_pipeline.py,sha256=ljZkGiCCpxvpC04Aa-r_tvBnqcBpUVi9fl_878tJAcg,2555
160
159
  spikezoo/pipeline/train_cfgs.py,sha256=6NO7DfPc7yjJfOrcIPQPfUPbUODz6eRKurEIDjMmaxA,3836
161
160
  spikezoo/pipeline/train_pipeline.py,sha256=BgHUsdv33B_OKauOVclNt7yIPb-_O-93ZHLHIjrwWaA,8459
@@ -167,8 +166,8 @@ spikezoo/utils/other_utils.py,sha256=uWNWaII9Jv7fkWNfkAD9wD-4ID-GAzbR-gGYT-1FF_c
167
166
  spikezoo/utils/scheduler_utils.py,sha256=5RBh-hl3-2y-IomxMs47T1p3JsbicZNYLza6q1uAKHo,828
168
167
  spikezoo/utils/spike_utils.py,sha256=u4Haa6Sp5xFqs61ztvq161oXTA_aZmNW3VYUZcayNW0,4296
169
168
  spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so,sha256=uXqu7ME---cZRRU5LUcLiNrjjtlOjxNwWHyTIQ10BGg,199088
170
- spikezoo-0.2.3.2.dist-info/LICENSE.txt,sha256=ukEi8E0PKq1dQGTXHUflg3rppLymwAhr7il9x-0nPgg,1062
171
- spikezoo-0.2.3.2.dist-info/METADATA,sha256=k4-w6mwZAfSx1YhunWGVx3LYuocZEqs8ypcMr9WhkZ8,11962
172
- spikezoo-0.2.3.2.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
173
- spikezoo-0.2.3.2.dist-info/top_level.txt,sha256=xF2iuOstrACJh43NW4dsTwIdgKfXPXAb_Xzl3M1ricM,9
174
- spikezoo-0.2.3.2.dist-info/RECORD,,
169
+ spikezoo-0.2.3.3.dist-info/LICENSE.txt,sha256=ukEi8E0PKq1dQGTXHUflg3rppLymwAhr7il9x-0nPgg,1062
170
+ spikezoo-0.2.3.3.dist-info/METADATA,sha256=3aqIRpJr6TAfjldU8ZLmWy6uuCUla400W5tGpkH-X2M,11941
171
+ spikezoo-0.2.3.3.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
172
+ spikezoo-0.2.3.3.dist-info/top_level.txt,sha256=xF2iuOstrACJh43NW4dsTwIdgKfXPXAb_Xzl3M1ricM,9
173
+ spikezoo-0.2.3.3.dist-info/RECORD,,
@@ -1,181 +0,0 @@
1
- from torch.utils.data import Dataset
2
- from pathlib import Path
3
- from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
4
- from dataclasses import dataclass
5
- import re
6
-
7
-
8
- @dataclass
9
- class REDS_SSIRConfig(BaseDatasetConfig):
10
- dataset_name: str = "reds_ssir"
11
- root_dir: Path = Path(__file__).parent.parent / Path("data/REDS_SSIR")
12
- train_width: int = 96
13
- train_height: int = 96
14
- test_width: int = 1280
15
- test_height: int = 720
16
- width: int = -1
17
- height: int = -1
18
- with_img: bool = True
19
- spike_length_train: int = 41
20
- spike_length_test: int = 301
21
-
22
- # post process
23
- def __post_init__(self):
24
- self.root_dir = Path(self.root_dir) if isinstance(self.root_dir, str) else self.root_dir
25
- # todo try download
26
- assert self.root_dir.exists(), f"No files found in {self.root_dir} for the specified dataset `{self.dataset_name}`."
27
- # train/test split
28
- if self.split == "train":
29
- self.spike_length = self.spike_length_train
30
- self.width = self.train_width
31
- self.height = self.train_height
32
- else:
33
- self.spike_length = self.spike_length_test
34
- self.width = self.test_width
35
- self.height = self.test_height
36
-
37
-
38
- class REDS_SSIR(BaseDataset):
39
- def __init__(self, cfg: BaseDatasetConfig):
40
- super(REDS_SSIR, self).__init__(cfg)
41
-
42
- def prepare_data(self):
43
- """Specify the spike and image files to be loaded."""
44
- # spike/imgs dir train/test
45
- if self.cfg.split == "train":
46
- self.img_dir = self.cfg.root_dir / Path("crop_mini/spike/train/interp_20_alpha_0.40")
47
- self.spike_dir = self.cfg.root_dir / Path("crop_mini/image/train/train_orig")
48
- else:
49
- self.img_dir = self.cfg.root_dir / Path("imgs/val/val_orig")
50
- self.spike_dir = self.cfg.root_dir / Path("spike/val/interp_20_alpha_0.40")
51
- # get files
52
- self.spike_list = self.get_spike_files(self.spike_dir)
53
- self.img_list = []
54
-
55
-
56
- class sreds_train(torch.utils.data.Dataset):
57
- def __init__(self, cfg):
58
- self.cfg = cfg
59
- self.pair_step = self.cfg["loader"]["pair_step"]
60
- self.augmentor = Augmentor(crop_size=self.cfg["loader"]["crop_size"])
61
- self.samples = self.collect_samples()
62
- print("The samples num of training data: {:d}".format(len(self.samples)))
63
-
64
- def confirm_exist(self, path_list_list):
65
- for pl in path_list_list:
66
- for p in pl:
67
- if not osp.exists(p):
68
- return 0
69
- return 1
70
-
71
- def collect_samples(self):
72
- spike_path = osp.join(
73
- self.cfg["data"]["root"], "crop_mini", "spike", "train", "interp_{:d}_alpha_{:.2f}".format(self.cfg["data"]["interp"], self.cfg["data"]["alpha"])
74
- )
75
- image_path = osp.join(self.cfg["data"]["root"], "crop_mini", "image", "train", "train_orig")
76
- scene_list = sorted(os.listdir(spike_path))
77
- samples = []
78
-
79
- for scene in scene_list:
80
- spike_dir = osp.join(spike_path, scene)
81
- image_dir = osp.join(image_path, scene)
82
- spk_path_list = sorted(os.listdir(spike_dir))
83
-
84
- spklen = len(spk_path_list)
85
- seq_len = self.cfg["model"]["seq_len"] + 2
86
- """
87
- for st in range(0, spklen - ((spklen - self.pair_step) % seq_len) - seq_len, self.pair_step):
88
- # 按照文件名称读取
89
- spikes_path_list = [osp.join(spike_dir, spk_path_list[ii]) for ii in range(st, st+seq_len)]
90
- images_path_list = [osp.join(image_dir, spk_path_list[ii][:-4]+'.png') for ii in range(st, st+seq_len)]
91
-
92
- if(self.confirm_exist([spikes_path_list, images_path_list])):
93
- s = {}
94
- s['spikes_paths'] = spikes_path_list
95
- s['images_paths'] = images_path_list
96
- samples.append(s)
97
- """
98
- # 按照文件名称读取
99
- spikes_path_list = [osp.join(spike_dir, spk_path_list[ii]) for ii in range(spklen)]
100
- images_path_list = [osp.join(image_dir, spk_path_list[ii][:-4] + ".png") for ii in range(spklen)]
101
-
102
- if self.confirm_exist([spikes_path_list, images_path_list]):
103
- s = {}
104
- s["spikes_paths"] = spikes_path_list
105
- s["images_paths"] = images_path_list
106
- samples.append(s)
107
-
108
- return samples
109
-
110
- def _load_sample(self, s):
111
- data = {}
112
-
113
- data["spikes"] = [np.array(dat_to_spmat(p, size=(96, 96)), dtype=np.float32) for p in s["spikes_paths"]]
114
- data["images"] = [read_img_gray(p) for p in s["images_paths"]]
115
-
116
- data["spikes"], data["images"] = self.augmentor(data["spikes"], data["images"])
117
- # print("data['spikes'][0].shape, data['images'][0].shape", data['spikes'][0].shape, data['images'][0].shape)
118
-
119
- return data
120
-
121
- def __len__(self):
122
- return len(self.samples)
123
-
124
- def __getitem__(self, index):
125
- data = self._load_sample(self.samples[index])
126
- return data
127
-
128
-
129
- class sreds_test(torch.utils.data.Dataset):
130
- def __init__(self, cfg):
131
- self.cfg = cfg
132
- self.samples = self.collect_samples()
133
- print("The samples num of testing data: {:d}".format(len(self.samples)))
134
-
135
- def confirm_exist(self, path_list_list):
136
- for pl in path_list_list:
137
- for p in pl:
138
- if not osp.exists(p):
139
- return 0
140
- return 1
141
-
142
- def collect_samples(self):
143
- spike_path = osp.join(
144
- self.cfg["data"]["root"], "spike", "val", "interp_{:d}_alpha_{:.2f}".format(self.cfg["data"]["interp"], self.cfg["data"]["alpha"])
145
- )
146
- image_path = osp.join(self.cfg["data"]["root"], "imgs", "val", "val_orig")
147
- scene_list = sorted(os.listdir(spike_path))
148
- samples = []
149
-
150
- for scene in scene_list:
151
- spike_dir = osp.join(spike_path, scene)
152
- image_dir = osp.join(image_path, scene)
153
- spk_path_list = sorted(os.listdir(spike_dir))
154
-
155
- spklen = len(spk_path_list)
156
- # seq_len = self.cfg['model']['seq_len']
157
-
158
- # 按照文件名称读取
159
- spikes_path_list = [osp.join(spike_dir, spk_path_list[ii]) for ii in range(spklen)]
160
- images_path_list = [osp.join(image_dir, spk_path_list[ii][:-4] + ".png") for ii in range(spklen)]
161
-
162
- if self.confirm_exist([spikes_path_list, images_path_list]):
163
- s = {}
164
- s["spikes_paths"] = spikes_path_list
165
- s["images_paths"] = images_path_list
166
- samples.append(s)
167
-
168
- return samples
169
-
170
- def _load_sample(self, s):
171
- data = {}
172
- data["spikes"] = [np.array(dat_to_spmat(p, size=(720, 1280)), dtype=np.float32) for p in s["spikes_paths"]]
173
- data["images"] = [read_img_gray(p) for p in s["images_paths"]]
174
- return data
175
-
176
- def __len__(self):
177
- return len(self.samples)
178
-
179
- def __getitem__(self, index):
180
- data = self._load_sample(self.samples[index])
181
- return data