spikezoo 0.2.3__py3-none-any.whl → 0.2.3.3__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -47,7 +47,7 @@ class BaseDatasetConfig:
47
47
  "Crop size."
48
48
  crop_size: tuple = (-1, -1)
49
49
  "Rate. (-1 denotes variant)"
50
- rate: float = 1
50
+ rate: float = 0.6
51
51
 
52
52
  # post process
53
53
  def __post_init__(self):
@@ -12,6 +12,8 @@ class RealWorldConfig(BaseDatasetConfig):
12
12
  with_img: bool = False
13
13
  spike_length_train: int = -1
14
14
  spike_length_test: int = -1
15
+ rate: float = 1
16
+
15
17
 
16
18
 
17
19
  class RealWorld(BaseDataset):
@@ -17,6 +17,7 @@ class SZDataConfig(BaseDatasetConfig):
17
17
  spike_length_test: int = -1
18
18
  spike_dir_name: str = "spike_data"
19
19
  img_dir_name: str = "sharp_data"
20
+ rate: float = 1
20
21
 
21
22
  class SZData(BaseDataset):
22
23
  def __init__(self, cfg: BaseDatasetConfig):
@@ -4,7 +4,6 @@ from dataclasses import dataclass
4
4
  import numpy as np
5
5
  import torch
6
6
 
7
-
8
7
  @dataclass
9
8
  class UHSRConfig(BaseDatasetConfig):
10
9
  dataset_name: str = "uhsr"
@@ -16,7 +15,7 @@ class UHSRConfig(BaseDatasetConfig):
16
15
  spike_length_test: int = 200
17
16
  spike_dir_name: str = "spike"
18
17
  img_dir_name: str = ""
19
-
18
+ rate: float = 1
20
19
 
21
20
  class UHSR(BaseDataset):
22
21
  def __init__(self, cfg: BaseDatasetConfig):
@@ -93,9 +93,9 @@ class BaseModel(nn.Module):
93
93
  if version != "local":
94
94
  load_folder = os.path.dirname(os.path.abspath(__file__))
95
95
  ckpt_name = f"{self.cfg.model_name}.{get_suffix(self.cfg.model_name,version)}"
96
- ckpt_path = f"weights/{version}/{ckpt_name}"
96
+ ckpt_path = os.path.join("weights",version,ckpt_name)
97
97
  ckpt_path = os.path.join(load_folder, ckpt_path)
98
- ckpt_path_url = f"{self.cfg.base_url}/{get_url_version(version)}/{ckpt_name}"
98
+ ckpt_path_url = os.path.join(self.cfg.base_url,get_url_version(version),ckpt_name)
99
99
  elif version == "local":
100
100
  ckpt_path = self.cfg.ckpt_path
101
101
 
@@ -56,6 +56,7 @@ class Pipeline:
56
56
 
57
57
  def _setup_model_data(self, model_cfg, dataset_cfg):
58
58
  """Model and Data setup."""
59
+ self.logger.info("Model and dataset is setting up...")
59
60
  # model [1] build the model. [2] build the network.
60
61
  self.model: BaseModel = build_model_name(model_cfg) if isinstance(model_cfg, str) else build_model_cfg(model_cfg)
61
62
  self.model.build_network(mode="eval", version=self.cfg.version)
@@ -68,6 +69,7 @@ class Pipeline:
68
69
 
69
70
  def _setup_pipeline(self):
70
71
  """Pipeline setup."""
72
+ self.logger.info("Pipeline is setting up...")
71
73
  # save folder
72
74
  self.thistime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")[:23]
73
75
  self.save_folder = Path(f"results") if len(self.cfg.save_folder) == 0 else self.cfg.save_folder
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: spikezoo
3
- Version: 0.2.3
3
+ Version: 0.2.3.3
4
4
  Summary: A deep learning toolbox for spike-to-image models.
5
5
  Home-page: https://github.com/chenkang455/Spike-Zoo
6
6
  Author: Kang Chen
@@ -36,15 +36,11 @@ Dynamic: requires-python
36
36
  Dynamic: summary
37
37
 
38
38
  <p align="center">
39
- <br>
40
- <img src="imgs/spike-zoo.png" width="500"/>
41
- <br>
39
+ <img src="imgs/spike-zoo.png" width="350"/>
42
40
  <p>
43
-
44
41
  <h5 align="center">
45
42
 
46
43
  [![GitHub repo stars](https://img.shields.io/github/stars/chenkang455/Spike-Zoo?style=flat&logo=github&logoColor=whitesmoke&label=Stars)](https://github.com/chenkang455/Spike-Zoo/stargazers) [![GitHub Issues](https://img.shields.io/github/issues/chenkang455/Spike-Zoo?style=flat&logo=github&logoColor=whitesmoke&label=Stars)](https://github.com/chenkang455/Spike-Zoo/issues) <a href="https://badge.fury.io/py/spikezoo"><img src="https://badge.fury.io/py/spikezoo.svg" alt="PyPI version"></a> [![License](https://img.shields.io/badge/License-MIT-yellow)](https://github.com/chenkang455/Spike-Zoo)
47
-
48
44
  <p>
49
45
 
50
46
  <!-- <h2 align="center">
@@ -165,16 +161,16 @@ We retrain all supported methods except `SPIKECLIP` on this REDS dataset (traini
165
161
 
166
162
  | Method | PSNR | SSIM | LPIPS | NIQE | BRISQUE | PIQE | Params (M) | FLOPs (G) | Latency (ms) |
167
163
  |----------------------|:-------:|:--------:|:---------:|:---------:|:----------:|:-------:|:------------:|:-----------:|:--------------:|
168
- | `TFI` | 16.503 | 0.454 | 0.382 | 7.289 | 43.17 | 49.12 | 0.00 | 0.00 | 3.60 |
169
- | `TFP` | 24.287 | 0.644 | 0.274 | 8.197 | 48.48 | 38.38 | 0.00 | 0.00 | 0.03 |
170
- | `SPIKECLIP` | 21.873 | 0.578 | 0.333 | 7.802 | 42.08 | 54.01 | 0.19 | 23.69 | 1.27 |
171
- | `SSIR` | 26.544 | 0.718 | 0.325 | 4.769 | 28.45 | 21.59 | 0.38 | 25.92 | 4.52 |
172
- | `SSML` | 33.697 | 0.943 | 0.088 | 4.669 | 32.48 | 37.30 | 2.38 | 386.02 | 244.18 |
173
- | `BASE` | 36.589 | 0.965 | 0.034 | 4.393 | 26.16 | 38.43 | 0.18 | 18.04 | 0.40 |
174
- | `STIR` | 37.914 | 0.973 | 0.027 | 4.236 | 25.10 | 39.18 | 5.08 | 43.31 | 21.07 |
175
- | `WGSE` | 39.036 | 0.978 | 0.023 | 4.231 | 25.76 | 44.11 | 3.81 | 415.26 | 73.62 |
176
- | `SPK2IMGNET` | 39.154 | 0.978 | 0.022 | 4.243 | 25.20 | 43.09 | 3.90 | 1000.50 | 123.38 |
177
- | `BSF` | 39.576 | 0.979 | 0.019 | 4.139 | 24.93 | 43.03 | 2.47 | 705.23 | 401.50 |
164
+ | `tfi` | 16.503 | 0.454 | 0.382 | 7.289 | 43.17 | 49.12 | 0.00 | 0.00 | 3.60 |
165
+ | `tfp` | 24.287 | 0.644 | 0.274 | 8.197 | 48.48 | 38.38 | 0.00 | 0.00 | 0.03 |
166
+ | `spikeclip` | 21.873 | 0.578 | 0.333 | 7.802 | 42.08 | 54.01 | 0.19 | 23.69 | 1.27 |
167
+ | `ssir` | 26.544 | 0.718 | 0.325 | 4.769 | 28.45 | 21.59 | 0.38 | 25.92 | 4.52 |
168
+ | `ssml` | 33.697 | 0.943 | 0.088 | 4.669 | 32.48 | 37.30 | 2.38 | 386.02 | 244.18 |
169
+ | `base` | 36.589 | 0.965 | 0.034 | 4.393 | 26.16 | 38.43 | 0.18 | 18.04 | 0.40 |
170
+ | `stir` | 37.914 | 0.973 | 0.027 | 4.236 | 25.10 | 39.18 | 5.08 | 43.31 | 21.07 |
171
+ | `wgse` | 39.036 | 0.978 | 0.023 | 4.231 | 25.76 | 44.11 | 3.81 | 415.26 | 73.62 |
172
+ | `spk2imgnet` | 39.154 | 0.978 | 0.022 | 4.243 | 25.20 | 43.09 | 3.90 | 1000.50 | 123.38 |
173
+ | `bsf` | 39.576 | 0.979 | 0.019 | 4.139 | 24.93 | 43.03 | 2.47 | 705.23 | 401.50 |
178
174
 
179
175
  ### 4. Model Usage
180
176
  We also provide a direct interface for users interested in taking the spike-to-image model as a part of their work:
@@ -135,15 +135,14 @@ spikezoo/data/base/train/spike/203_part2_key_id151.dat,sha256=YEenLmbPvcxnKkVn3O
135
135
  spikezoo/data/base/train/spike/203_part3_key_id151.dat,sha256=MY9nM6XzKj-P-tRQ33WZ3G5xulNTpAXKP0y8ZQo7AIQ,3762500
136
136
  spikezoo/data/base/train/spike/203_part4_key_id151.dat,sha256=IVi2jics66YzpIF-WTkw47te4qOj9cjdgz56GmHpJKg,3762500
137
137
  spikezoo/datasets/__init__.py,sha256=lRJsvCfgbe3qrd9BKTlG9dsgfIJbfXqWOynnlAcBiUI,3346
138
- spikezoo/datasets/base_dataset.py,sha256=gaMcrgWB3JLfc5lAM5wZjgMj72lZaO5otGfXOwRRiL8,5954
139
- spikezoo/datasets/realworld_dataset.py,sha256=CSyXj_uo0y1a4TlLDSTkB0CyFiFdZiU1phrWdkwvqgg,701
138
+ spikezoo/datasets/base_dataset.py,sha256=oQ_AqWuMlaKnR712_sJ4WiTbqqPqVsfcukDNpFDYXb0,5956
139
+ spikezoo/datasets/realworld_dataset.py,sha256=VqT6zcLa72DL3Lg8f4TThhYUa1xSIifsrPwpjvk2uBE,726
140
140
  spikezoo/datasets/reds_base_dataset.py,sha256=W-IJv9H1bsKgp3RT3zsV40jw2PqY2M76jtIS4Qpif1o,859
141
- spikezoo/datasets/reds_ssir_dataset.py,sha256=t0hm6PUWX5hfvSXB0UEv_JuihIhc5-mufrfur5bIq_0,7076
142
- spikezoo/datasets/szdata_dataset.py,sha256=7lSqNzWkVUVTVufPhm7AjpLghpZvjfu1X9c0sCsrjwM,850
143
- spikezoo/datasets/uhsr_dataset.py,sha256=Ue0vfKMpmvR0TelQx8G4xdWB8HdSlb4HSjIVOyFe5oE,1168
141
+ spikezoo/datasets/szdata_dataset.py,sha256=xvgkZFHNSQ-Sk_rqmgRKAqpeb2gYpt_gmstJKJ8ooqU,870
142
+ spikezoo/datasets/uhsr_dataset.py,sha256=MKQeQsoCal10yMgHy3I7NJDgJJgkKgruH5tantP921A,1186
144
143
  spikezoo/metrics/__init__.py,sha256=LIKeWNeEMZLANITQD68XJBOhDq7iHiKC7ExtdrXMyQs,3273
145
144
  spikezoo/models/__init__.py,sha256=QZTELBoM3bUW8jZoxN4OuA2RYKeVUT1fboyeIuK8Rtk,1722
146
- spikezoo/models/base_model.py,sha256=RyWcK8FSdAYdP8UxKULnnqe1uNn6Qr5tBexee-DzSm8,9138
145
+ spikezoo/models/base_model.py,sha256=v3TD4AmjttTZUg0vEy736TOFdbbBgDLZg_RL-b4-vYM,9152
147
146
  spikezoo/models/bsf_model.py,sha256=XeZcVC_ODJxyS_I6-CtzlHXSWntgsUtbuAKjczIQ_0M,3972
148
147
  spikezoo/models/spcsnet_model.py,sha256=kLzv-ASXZGnqEFx0jUBONBeRCrsnQ_omkQUYEnr6uJc,540
149
148
  spikezoo/models/spikeclip_model.py,sha256=Ej84RuYbkFRthtBMV1JtmTkUshAqINlrrJ7yiKIsC9s,1125
@@ -155,7 +154,7 @@ spikezoo/models/tfi_model.py,sha256=tgD_HsiXk9jGuh5f_Bh6c3BqJi1p5DWCVo4N1tp5fgs,
155
154
  spikezoo/models/tfp_model.py,sha256=ihl1H__bWIbE9oair_t8rNJ5qnPJPKl-r_DpaO-0Sdk,663
156
155
  spikezoo/models/wgse_model.py,sha256=Kl9uV-LeO0Lj7SuPQ9pglw1Khs2b-7miS3A_faL6WSU,805
157
156
  spikezoo/pipeline/__init__.py,sha256=WPsukNR4cannwsghiukqNsWbWGH5DVPapR_Ly-WOU4Q,188
158
- spikezoo/pipeline/base_pipeline.py,sha256=ns2rBRbKaq8M9yDtWOiD-t6lo_5_c5rE2ZWH9sR78P4,13361
157
+ spikezoo/pipeline/base_pipeline.py,sha256=3laGM-cMhNTSfCXx5jY53V8GY5BhPhHK48yjqm-Gre0,13478
159
158
  spikezoo/pipeline/ensemble_pipeline.py,sha256=ljZkGiCCpxvpC04Aa-r_tvBnqcBpUVi9fl_878tJAcg,2555
160
159
  spikezoo/pipeline/train_cfgs.py,sha256=6NO7DfPc7yjJfOrcIPQPfUPbUODz6eRKurEIDjMmaxA,3836
161
160
  spikezoo/pipeline/train_pipeline.py,sha256=BgHUsdv33B_OKauOVclNt7yIPb-_O-93ZHLHIjrwWaA,8459
@@ -167,8 +166,8 @@ spikezoo/utils/other_utils.py,sha256=uWNWaII9Jv7fkWNfkAD9wD-4ID-GAzbR-gGYT-1FF_c
167
166
  spikezoo/utils/scheduler_utils.py,sha256=5RBh-hl3-2y-IomxMs47T1p3JsbicZNYLza6q1uAKHo,828
168
167
  spikezoo/utils/spike_utils.py,sha256=u4Haa6Sp5xFqs61ztvq161oXTA_aZmNW3VYUZcayNW0,4296
169
168
  spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so,sha256=uXqu7ME---cZRRU5LUcLiNrjjtlOjxNwWHyTIQ10BGg,199088
170
- spikezoo-0.2.3.dist-info/LICENSE.txt,sha256=ukEi8E0PKq1dQGTXHUflg3rppLymwAhr7il9x-0nPgg,1062
171
- spikezoo-0.2.3.dist-info/METADATA,sha256=XajepZVqBR32M502EF2u-ALRC_s6cHk2IJK6AYgUo78,11960
172
- spikezoo-0.2.3.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
173
- spikezoo-0.2.3.dist-info/top_level.txt,sha256=xF2iuOstrACJh43NW4dsTwIdgKfXPXAb_Xzl3M1ricM,9
174
- spikezoo-0.2.3.dist-info/RECORD,,
169
+ spikezoo-0.2.3.3.dist-info/LICENSE.txt,sha256=ukEi8E0PKq1dQGTXHUflg3rppLymwAhr7il9x-0nPgg,1062
170
+ spikezoo-0.2.3.3.dist-info/METADATA,sha256=3aqIRpJr6TAfjldU8ZLmWy6uuCUla400W5tGpkH-X2M,11941
171
+ spikezoo-0.2.3.3.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
172
+ spikezoo-0.2.3.3.dist-info/top_level.txt,sha256=xF2iuOstrACJh43NW4dsTwIdgKfXPXAb_Xzl3M1ricM,9
173
+ spikezoo-0.2.3.3.dist-info/RECORD,,
@@ -1,181 +0,0 @@
1
- from torch.utils.data import Dataset
2
- from pathlib import Path
3
- from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
4
- from dataclasses import dataclass
5
- import re
6
-
7
-
8
- @dataclass
9
- class REDS_SSIRConfig(BaseDatasetConfig):
10
- dataset_name: str = "reds_ssir"
11
- root_dir: Path = Path(__file__).parent.parent / Path("data/REDS_SSIR")
12
- train_width: int = 96
13
- train_height: int = 96
14
- test_width: int = 1280
15
- test_height: int = 720
16
- width: int = -1
17
- height: int = -1
18
- with_img: bool = True
19
- spike_length_train: int = 41
20
- spike_length_test: int = 301
21
-
22
- # post process
23
- def __post_init__(self):
24
- self.root_dir = Path(self.root_dir) if isinstance(self.root_dir, str) else self.root_dir
25
- # todo try download
26
- assert self.root_dir.exists(), f"No files found in {self.root_dir} for the specified dataset `{self.dataset_name}`."
27
- # train/test split
28
- if self.split == "train":
29
- self.spike_length = self.spike_length_train
30
- self.width = self.train_width
31
- self.height = self.train_height
32
- else:
33
- self.spike_length = self.spike_length_test
34
- self.width = self.test_width
35
- self.height = self.test_height
36
-
37
-
38
- class REDS_SSIR(BaseDataset):
39
- def __init__(self, cfg: BaseDatasetConfig):
40
- super(REDS_SSIR, self).__init__(cfg)
41
-
42
- def prepare_data(self):
43
- """Specify the spike and image files to be loaded."""
44
- # spike/imgs dir train/test
45
- if self.cfg.split == "train":
46
- self.img_dir = self.cfg.root_dir / Path("crop_mini/spike/train/interp_20_alpha_0.40")
47
- self.spike_dir = self.cfg.root_dir / Path("crop_mini/image/train/train_orig")
48
- else:
49
- self.img_dir = self.cfg.root_dir / Path("imgs/val/val_orig")
50
- self.spike_dir = self.cfg.root_dir / Path("spike/val/interp_20_alpha_0.40")
51
- # get files
52
- self.spike_list = self.get_spike_files(self.spike_dir)
53
- self.img_list = []
54
-
55
-
56
- class sreds_train(torch.utils.data.Dataset):
57
- def __init__(self, cfg):
58
- self.cfg = cfg
59
- self.pair_step = self.cfg["loader"]["pair_step"]
60
- self.augmentor = Augmentor(crop_size=self.cfg["loader"]["crop_size"])
61
- self.samples = self.collect_samples()
62
- print("The samples num of training data: {:d}".format(len(self.samples)))
63
-
64
- def confirm_exist(self, path_list_list):
65
- for pl in path_list_list:
66
- for p in pl:
67
- if not osp.exists(p):
68
- return 0
69
- return 1
70
-
71
- def collect_samples(self):
72
- spike_path = osp.join(
73
- self.cfg["data"]["root"], "crop_mini", "spike", "train", "interp_{:d}_alpha_{:.2f}".format(self.cfg["data"]["interp"], self.cfg["data"]["alpha"])
74
- )
75
- image_path = osp.join(self.cfg["data"]["root"], "crop_mini", "image", "train", "train_orig")
76
- scene_list = sorted(os.listdir(spike_path))
77
- samples = []
78
-
79
- for scene in scene_list:
80
- spike_dir = osp.join(spike_path, scene)
81
- image_dir = osp.join(image_path, scene)
82
- spk_path_list = sorted(os.listdir(spike_dir))
83
-
84
- spklen = len(spk_path_list)
85
- seq_len = self.cfg["model"]["seq_len"] + 2
86
- """
87
- for st in range(0, spklen - ((spklen - self.pair_step) % seq_len) - seq_len, self.pair_step):
88
- # 按照文件名称读取
89
- spikes_path_list = [osp.join(spike_dir, spk_path_list[ii]) for ii in range(st, st+seq_len)]
90
- images_path_list = [osp.join(image_dir, spk_path_list[ii][:-4]+'.png') for ii in range(st, st+seq_len)]
91
-
92
- if(self.confirm_exist([spikes_path_list, images_path_list])):
93
- s = {}
94
- s['spikes_paths'] = spikes_path_list
95
- s['images_paths'] = images_path_list
96
- samples.append(s)
97
- """
98
- # 按照文件名称读取
99
- spikes_path_list = [osp.join(spike_dir, spk_path_list[ii]) for ii in range(spklen)]
100
- images_path_list = [osp.join(image_dir, spk_path_list[ii][:-4] + ".png") for ii in range(spklen)]
101
-
102
- if self.confirm_exist([spikes_path_list, images_path_list]):
103
- s = {}
104
- s["spikes_paths"] = spikes_path_list
105
- s["images_paths"] = images_path_list
106
- samples.append(s)
107
-
108
- return samples
109
-
110
- def _load_sample(self, s):
111
- data = {}
112
-
113
- data["spikes"] = [np.array(dat_to_spmat(p, size=(96, 96)), dtype=np.float32) for p in s["spikes_paths"]]
114
- data["images"] = [read_img_gray(p) for p in s["images_paths"]]
115
-
116
- data["spikes"], data["images"] = self.augmentor(data["spikes"], data["images"])
117
- # print("data['spikes'][0].shape, data['images'][0].shape", data['spikes'][0].shape, data['images'][0].shape)
118
-
119
- return data
120
-
121
- def __len__(self):
122
- return len(self.samples)
123
-
124
- def __getitem__(self, index):
125
- data = self._load_sample(self.samples[index])
126
- return data
127
-
128
-
129
- class sreds_test(torch.utils.data.Dataset):
130
- def __init__(self, cfg):
131
- self.cfg = cfg
132
- self.samples = self.collect_samples()
133
- print("The samples num of testing data: {:d}".format(len(self.samples)))
134
-
135
- def confirm_exist(self, path_list_list):
136
- for pl in path_list_list:
137
- for p in pl:
138
- if not osp.exists(p):
139
- return 0
140
- return 1
141
-
142
- def collect_samples(self):
143
- spike_path = osp.join(
144
- self.cfg["data"]["root"], "spike", "val", "interp_{:d}_alpha_{:.2f}".format(self.cfg["data"]["interp"], self.cfg["data"]["alpha"])
145
- )
146
- image_path = osp.join(self.cfg["data"]["root"], "imgs", "val", "val_orig")
147
- scene_list = sorted(os.listdir(spike_path))
148
- samples = []
149
-
150
- for scene in scene_list:
151
- spike_dir = osp.join(spike_path, scene)
152
- image_dir = osp.join(image_path, scene)
153
- spk_path_list = sorted(os.listdir(spike_dir))
154
-
155
- spklen = len(spk_path_list)
156
- # seq_len = self.cfg['model']['seq_len']
157
-
158
- # 按照文件名称读取
159
- spikes_path_list = [osp.join(spike_dir, spk_path_list[ii]) for ii in range(spklen)]
160
- images_path_list = [osp.join(image_dir, spk_path_list[ii][:-4] + ".png") for ii in range(spklen)]
161
-
162
- if self.confirm_exist([spikes_path_list, images_path_list]):
163
- s = {}
164
- s["spikes_paths"] = spikes_path_list
165
- s["images_paths"] = images_path_list
166
- samples.append(s)
167
-
168
- return samples
169
-
170
- def _load_sample(self, s):
171
- data = {}
172
- data["spikes"] = [np.array(dat_to_spmat(p, size=(720, 1280)), dtype=np.float32) for p in s["spikes_paths"]]
173
- data["images"] = [read_img_gray(p) for p in s["images_paths"]]
174
- return data
175
-
176
- def __len__(self):
177
- return len(self.samples)
178
-
179
- def __getitem__(self, index):
180
- data = self._load_sample(self.samples[index])
181
- return data