spikezoo 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (86) hide show
  1. spikezoo/__init__.py +23 -7
  2. spikezoo/archs/bsf/models/bsf/bsf.py +37 -25
  3. spikezoo/archs/bsf/models/bsf/rep.py +2 -2
  4. spikezoo/archs/spk2imgnet/nets.py +1 -1
  5. spikezoo/archs/ssir/models/networks.py +1 -1
  6. spikezoo/archs/ssml/model.py +9 -5
  7. spikezoo/archs/stir/metrics/losses.py +1 -1
  8. spikezoo/archs/stir/models/networks_STIR.py +16 -9
  9. spikezoo/archs/tfi/nets.py +1 -1
  10. spikezoo/archs/tfp/nets.py +1 -1
  11. spikezoo/archs/wgse/dwtnets.py +6 -6
  12. spikezoo/datasets/__init__.py +11 -9
  13. spikezoo/datasets/base_dataset.py +10 -3
  14. spikezoo/datasets/realworld_dataset.py +1 -3
  15. spikezoo/datasets/{reds_small_dataset.py → reds_base_dataset.py} +9 -8
  16. spikezoo/datasets/reds_ssir_dataset.py +181 -0
  17. spikezoo/datasets/szdata_dataset.py +5 -15
  18. spikezoo/datasets/uhsr_dataset.py +4 -3
  19. spikezoo/models/__init__.py +8 -6
  20. spikezoo/models/base_model.py +120 -64
  21. spikezoo/models/bsf_model.py +11 -3
  22. spikezoo/models/spcsnet_model.py +19 -0
  23. spikezoo/models/spikeclip_model.py +4 -3
  24. spikezoo/models/spk2imgnet_model.py +9 -15
  25. spikezoo/models/ssir_model.py +4 -6
  26. spikezoo/models/ssml_model.py +44 -2
  27. spikezoo/models/stir_model.py +26 -5
  28. spikezoo/models/tfi_model.py +3 -1
  29. spikezoo/models/tfp_model.py +4 -2
  30. spikezoo/models/wgse_model.py +8 -14
  31. spikezoo/pipeline/base_pipeline.py +79 -55
  32. spikezoo/pipeline/ensemble_pipeline.py +10 -9
  33. spikezoo/pipeline/train_cfgs.py +89 -0
  34. spikezoo/pipeline/train_pipeline.py +129 -30
  35. spikezoo/utils/optimizer_utils.py +22 -0
  36. spikezoo/utils/other_utils.py +31 -6
  37. spikezoo/utils/scheduler_utils.py +25 -0
  38. spikezoo/utils/spike_utils.py +61 -29
  39. spikezoo-0.2.3.dist-info/METADATA +263 -0
  40. {spikezoo-0.2.1.dist-info → spikezoo-0.2.3.dist-info}/RECORD +43 -80
  41. spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
  42. spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
  43. spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
  44. spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
  45. spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
  46. spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
  47. spikezoo/archs/spikeformer/CheckPoints/readme +0 -1
  48. spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +0 -60
  49. spikezoo/archs/spikeformer/DataProcess/DataLoader.py +0 -115
  50. spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +0 -39
  51. spikezoo/archs/spikeformer/EvalResults/readme +0 -1
  52. spikezoo/archs/spikeformer/LICENSE +0 -21
  53. spikezoo/archs/spikeformer/Metrics/Metrics.py +0 -50
  54. spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
  55. spikezoo/archs/spikeformer/Model/Loss.py +0 -89
  56. spikezoo/archs/spikeformer/Model/SpikeFormer.py +0 -230
  57. spikezoo/archs/spikeformer/Model/__init__.py +0 -0
  58. spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
  59. spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
  60. spikezoo/archs/spikeformer/README.md +0 -30
  61. spikezoo/archs/spikeformer/evaluate.py +0 -87
  62. spikezoo/archs/spikeformer/recon_real_data.py +0 -97
  63. spikezoo/archs/spikeformer/requirements.yml +0 -95
  64. spikezoo/archs/spikeformer/train.py +0 -173
  65. spikezoo/archs/spikeformer/utils.py +0 -22
  66. spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
  67. spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
  68. spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
  69. spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
  70. spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
  71. spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
  72. spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
  73. spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
  74. spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
  75. spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
  76. spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
  77. spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
  78. spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
  79. spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
  80. spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
  81. spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
  82. spikezoo/models/spikeformer_model.py +0 -50
  83. spikezoo-0.2.1.dist-info/METADATA +0 -167
  84. {spikezoo-0.2.1.dist-info → spikezoo-0.2.3.dist-info}/LICENSE.txt +0 -0
  85. {spikezoo-0.2.1.dist-info → spikezoo-0.2.3.dist-info}/WHEEL +0 -0
  86. {spikezoo-0.2.1.dist-info → spikezoo-0.2.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,22 @@
1
+ # code borrow from https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/engine/optimizers.py
2
+ from dataclasses import dataclass
3
+ import torch
4
+ from typing import Any, Dict, List, Optional, Type,Tuple
5
+
6
@dataclass
class OptimizerConfig:
    """Base config whose attributes are forwarded to an optimizer constructor.

    Subclasses declare the optimizer arguments as dataclass fields together with
    a ``_target`` field holding the optimizer class that ``setup`` instantiates.
    """

    def setup(self, model_params) -> torch.optim.Optimizer:
        """Instantiate ``_target(model_params, **fields)``.

        Args:
            model_params: parameters (or parameter groups) handed to the
                optimizer as its first positional argument.

        Returns:
            The optimizer built from every instance attribute except ``_target``.

        Raises:
            KeyError: if the config defines no ``_target`` (e.g. the bare base class).
        """
        options = dict(vars(self))
        optimizer_cls = options.pop("_target")
        return optimizer_cls(model_params, **options)


@dataclass
class AdamOptimizerConfig(OptimizerConfig):
    """Config for ``torch.optim.Adam``; ``setup`` returns an Adam optimizer."""

    # Every field except ``_target`` is passed by name to ``torch.optim.Adam``.
    lr: float = 1e-4
    betas: Tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0
    _target: Type = torch.optim.Adam
22
+
@@ -3,7 +3,9 @@ from dataclasses import dataclass, field, asdict
3
3
  import requests
4
4
  from tqdm import tqdm
5
5
  import os
6
-
6
+ import torch
7
+ import numpy as np
8
+ import random
7
9
 
8
10
  # log info
9
11
  def setup_logging(log_file):
@@ -36,9 +38,9 @@ def save_config(cfg, filename, mode="w"):
36
38
  def download_file(url, output_path):
37
39
  headers = {
38
40
  "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/90.0.4430.212 Safari/537.36",
39
- "Accept": "*/*", # 表示接受所有类型的响应
40
- "Accept-Encoding": "gzip, deflate, br", # 支持压缩格式
41
- "Connection": "keep-alive" # 保持连接
41
+ "Accept": "*/*",
42
+ "Accept-Encoding": "gzip, deflate, br",
43
+ "Connection": "keep-alive"
42
44
  }
43
45
 
44
46
  try:
@@ -62,8 +64,31 @@ def download_file(url, output_path):
62
64
  raise RuntimeError(f"Files fail to download 😔😔😔. Try downloading it from {url} and move it to {output_path}.")
63
65
 
64
66
  except requests.exceptions.RequestException as e:
65
- # If an error occurs, remove the partially downloaded file
66
67
  if os.path.exists(output_path):
67
68
  os.remove(output_path)
68
69
  print(f"Partial download failed. The incomplete file has been removed. 😔😔😔")
69
- raise RuntimeError(f"Files fail to download 😔😔😔. Try downloading it from {url} and move it to {output_path}.")
70
+ raise RuntimeError(f"Files fail to download 😔😔😔. Try downloading it from {url} and move it to {output_path}.")
71
+
72
def check_file_exists(url, timeout=10):
    """Return True if ``url`` answers a HEAD request with HTTP 200.

    ``requests.head`` does not follow redirects unless asked to, so a URL that
    redirects to the real file location would otherwise report a 3xx status and
    be treated as missing. Redirects are therefore followed here.

    Args:
        url: URL to probe.
        timeout: seconds to wait for the server; without a timeout a stalled
            server would block the caller indefinitely.

    Returns:
        True when the final response status is 200, False otherwise.

    Raises:
        requests.exceptions.RequestException: on connection errors or timeouts.
    """
    response = requests.head(url, allow_redirects=True, timeout=timeout)
    return response.status_code == 200
78
+
79
def getattr_case_insensitive(obj, name):
    """Return the attribute of ``obj`` whose name equals ``name`` ignoring case.

    Candidates are taken from ``dir(obj)``; the first one whose lowercased
    name matches the lowercased ``name`` is returned.

    Args:
        obj: object (module, class, instance) to search.
        name: attribute name to look up, in any letter case.

    Raises:
        RuntimeError: if no attribute of ``obj`` matches ``name``.
    """
    wanted = name.lower()
    for attr in dir(obj):
        if attr.lower() == wanted:
            return getattr(obj, attr)
    # Name the missing attribute so the failure is actionable.
    raise RuntimeError(f"No attr found!! {name!r} (case-insensitive) is not an attribute of {obj!r}.")
85
+
86
+
87
def set_random_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every random number generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds all visible GPUs (silently ignored without CUDA); this already
    # includes the current device, so ``torch.cuda.manual_seed`` is redundant.
    torch.cuda.manual_seed_all(seed)
    # ``deterministic`` alone is not sufficient: with ``benchmark`` enabled cuDNN
    # may select a different convolution algorithm from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
@@ -0,0 +1,25 @@
1
+ # code borrow from https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/engine/optimizers.py
2
+ from dataclasses import dataclass
3
+ import torch
4
+ from typing import Any, Dict, List, Optional, Type,Tuple
5
+
6
@dataclass
class SchedulerConfig:
    """Base config whose attributes are forwarded to a learning-rate scheduler.

    Subclasses declare the scheduler arguments as dataclass fields together with
    a ``_target`` field holding the scheduler class that ``setup`` instantiates.
    """

    # The return annotation is a string so it is never evaluated at import time:
    # ``torch.optim.lr_scheduler.LRScheduler`` only became a public name in
    # PyTorch 2.0, and evaluating it eagerly breaks importing this module on
    # older releases.
    def setup(self, optimizer) -> "torch.optim.lr_scheduler.LRScheduler":
        """Instantiate ``_target(optimizer, **fields)``.

        Args:
            optimizer: the optimizer the scheduler will drive.

        Returns:
            The scheduler built from every instance attribute except ``_target``.

        Raises:
            KeyError: if the config defines no ``_target`` (e.g. the bare base class).
        """
        kwargs = vars(self).copy()
        target = kwargs.pop("_target")
        return target(optimizer, **kwargs)


@dataclass
class CosineAnnealingLRConfig(SchedulerConfig):
    """Config for ``torch.optim.lr_scheduler.CosineAnnealingLR``."""

    # Every field except ``_target`` is passed by name to the scheduler.
    T_max: int
    eta_min: float = 0
    _target: Type = torch.optim.lr_scheduler.CosineAnnealingLR


@dataclass
class MultiStepSchedulerConfig(SchedulerConfig):
    """Config for ``torch.optim.lr_scheduler.MultiStepLR``."""

    # Every field except ``_target`` is passed by name to the scheduler.
    milestones: List[int]
    gamma: float = 0.1
    _target: Type = torch.optim.lr_scheduler.MultiStepLR
@@ -2,23 +2,28 @@ import numpy as np
2
2
  import torch
3
3
  import torch.nn as nn
4
4
  import os
5
- from .vidar_loader import load_vidar_dat_cpp
6
5
  from typing import Literal
6
+ import platform
7
+ import cv2
8
+ import imageio
7
9
 
8
- def load_vidar_dat(filename, height, width,remove_head=False, version:Literal['python','cpp'] = "cpp", out_format : Literal['array','tensor']="array",):
10
+ _platform_check_done = False
11
+
12
+
13
+ def load_vidar_dat(filename, height, width, remove_head=False, version: Literal["python", "cpp"] = "cpp", out_format: Literal["array", "tensor"] = "array"):
9
14
  """Load the spike stream from the .dat file."""
15
+ global _platform_check_done
10
16
  # Spike decode
11
- if version == "python":
12
- if isinstance(filename, str):
13
- array = np.fromfile(filename, dtype=np.uint8)
14
- elif isinstance(filename, (list, tuple)):
15
- l = []
16
- for name in filename:
17
- a = np.fromfile(name, dtype=np.uint8)
18
- l.append(a)
19
- array = np.concatenate(l)
20
- else:
21
- raise NotImplementedError
17
+ if version == "cpp" and platform.system().lower() == "linux":
18
+ from .vidar_loader import load_vidar_dat_cpp
19
+
20
+ spikes = load_vidar_dat_cpp(filename, height, width)
21
+ else:
22
+ # todo double check
23
+ if version == "cpp" and platform.system().lower() != "linux" and _platform_check_done == False:
24
+ _platform_check_done = True
25
+ print("Cpp load version is only supported on the linux now. Auto transfer to the python version.")
26
+ array = np.fromfile(filename, dtype=np.uint8)
22
27
  len_per_frame = height * width // 8
23
28
  framecnt = len(array) // len_per_frame
24
29
  spikes = []
@@ -33,10 +38,6 @@ def load_vidar_dat(filename, height, width,remove_head=False, version:Literal['p
33
38
  spk = spk[:, :, :-16] if remove_head == True else spk
34
39
  spikes.append(spk)
35
40
  spikes = np.concatenate(spikes).astype(np.float32)
36
- elif version == "cpp":
37
- spikes = load_vidar_dat_cpp(filename, height, width)
38
- else:
39
- raise RuntimeError("Not recognized version.")
40
41
 
41
42
  # # Output format conversion
42
43
  format_dict = {"array": lambda x: x, "tensor": torch.from_numpy}
@@ -44,15 +45,14 @@ def load_vidar_dat(filename, height, width,remove_head=False, version:Literal['p
44
45
  return spikes
45
46
 
46
47
 
47
- def SpikeToRaw(save_path, SpikeSeq, filpud=True, delete_if_exists=True):
48
+ def save_vidar_dat(save_path, SpikeSeq, filpud=True):
48
49
  """Save the spike sequence to the .dat file."""
49
- if delete_if_exists:
50
- if os.path.exists(save_path):
51
- os.remove(save_path)
50
+ if os.path.exists(save_path):
51
+ os.remove(save_path)
52
52
  sfn, h, w = SpikeSeq.shape
53
53
  remainder = int((h * w) % 8)
54
54
  base = np.power(2, np.linspace(0, 7, 8))
55
- fid = open(save_path, 'ab')
55
+ fid = open(save_path, "ab")
56
56
  for img_id in range(sfn):
57
57
  if filpud:
58
58
  spike = np.flipud(SpikeSeq[img_id, :, :])
@@ -61,20 +61,54 @@ def SpikeToRaw(save_path, SpikeSeq, filpud=True, delete_if_exists=True):
61
61
  if remainder == 0:
62
62
  spike = spike.flatten()
63
63
  else:
64
- spike = np.concatenate([spike.flatten(), np.array([0]*(8-remainder))])
65
- spike = spike.reshape([int(h*w/8), 8])
64
+ spike = np.concatenate([spike.flatten(), np.array([0] * (8 - remainder))])
65
+ spike = spike.reshape([int(h * w / 8), 8])
66
66
  data = spike * base
67
67
  data = np.sum(data, axis=1).astype(np.uint8)
68
68
  fid.write(data.tobytes())
69
69
  fid.close()
70
- return
70
+
71
+
72
def merge_vidar_dat(filename, dat_files, height, width, remove_head=False):
    """Concatenate several spike ``.dat`` files along time and save the result.

    Each file is read with ``load_vidar_dat(path, height, width, remove_head)``,
    the loaded arrays are joined along axis 0, and the merged stream is written
    to ``filename`` with ``save_vidar_dat`` (its default ``filpud=True``).

    Args:
        filename: output ``.dat`` path; an existing file there is replaced.
        dat_files: iterable of input ``.dat`` paths, merged in the given order.
        height: frame height shared by all inputs.
        width: frame width shared by all inputs.
        remove_head: forwarded unchanged to ``load_vidar_dat``.

    Returns:
        The merged spike array that was written to disk.
    """
    merged = np.concatenate(
        [load_vidar_dat(path, height, width, remove_head) for path in dat_files],
        axis=0,
    )
    save_vidar_dat(filename, merged)
    return merged
81
+
82
def visual_vidar_dat(filename, spike, out_format: Literal["mp4", "gif"] = "gif", fps=15):
    """Convert a spike stream into a grayscale video file.

    Args:
        filename: path of the video to write.
        spike: array unpacked as ``(T, H, W)``. Each frame is multiplied by 255
            and cast to ``uint8`` (assumes binary 0/1 spikes; larger values
            would overflow the cast).
        out_format: ``"mp4"`` (OpenCV, ``mp4v`` codec) or ``"gif"`` (imageio,
            looping forever).
        fps: frames per second of the output video.

    Raises:
        ValueError: if ``out_format`` is neither ``"mp4"`` nor ``"gif"``.
        RuntimeError: if OpenCV cannot open ``filename`` for writing.
    """
    # Reject unknown formats up front; otherwise nothing would be written and
    # the call would silently succeed.
    if out_format not in ("mp4", "gif"):
        raise ValueError(f"Unsupported out_format {out_format!r}; expected 'mp4' or 'gif'.")
    _, height, width = spike.shape

    def _to_frame(spk):
        # Map spikes to 0/255 and replicate into 3 identical channels.
        spk = (255 * spk).astype(np.uint8)
        return spk[..., None].repeat(3, axis=-1)

    if out_format == "mp4":
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # or 'avc1'
        mp4_video = cv2.VideoWriter(filename, fourcc, fps, (width, height))
        # VideoWriter does not raise on failure; writes would be dropped silently.
        if not mp4_video.isOpened():
            raise RuntimeError(f"Could not open {filename!r} for mp4 writing.")
        try:
            for idx in range(len(spike)):
                mp4_video.write(_to_frame(spike[idx]))
        finally:
            # Release even if a frame fails so the output file is not left open.
            mp4_video.release()
    else:
        frames = [_to_frame(spike[idx]) for idx in range(len(spike))]
        imageio.mimsave(filename, frames, "GIF", fps=fps, loop=0)
104
+
71
105
 
72
106
  def video2spike_simulation(imgs, threshold=2.0):
73
107
  """Convert the images input to the spike stream."""
74
108
  imgs = np.array(imgs)
75
- T,H, W = imgs.shape
109
+ T, H, W = imgs.shape
76
110
  spike = np.zeros([T, H, W], np.uint8)
77
- integral = np.random.random(size=([H,W])) * threshold
111
+ integral = np.random.random(size=([H, W])) * threshold
78
112
  for t in range(0, T):
79
113
  integral += imgs[t]
80
114
  fire = (integral - threshold) >= 0
@@ -82,5 +116,3 @@ def video2spike_simulation(imgs, threshold=2.0):
82
116
  integral[fire_pos] -= threshold
83
117
  spike[t][fire_pos] = 1
84
118
  return spike
85
-
86
-
@@ -0,0 +1,263 @@
1
+ Metadata-Version: 2.2
2
+ Name: spikezoo
3
+ Version: 0.2.3
4
+ Summary: A deep learning toolbox for spike-to-image models.
5
+ Home-page: https://github.com/chenkang455/Spike-Zoo
6
+ Author: Kang Chen
7
+ Author-email: mrchenkang@stu.pku.edu.cn
8
+ Requires-Python: >=3.7
9
+ Description-Content-Type: text/markdown
10
+ License-File: LICENSE.txt
11
+ Requires-Dist: torch
12
+ Requires-Dist: requests
13
+ Requires-Dist: numpy
14
+ Requires-Dist: tqdm
15
+ Requires-Dist: scikit-image
16
+ Requires-Dist: lpips
17
+ Requires-Dist: pyiqa
18
+ Requires-Dist: opencv-python
19
+ Requires-Dist: thop
20
+ Requires-Dist: pytorch-wavelets
21
+ Requires-Dist: pytz
22
+ Requires-Dist: PyWavelets
23
+ Requires-Dist: pandas
24
+ Requires-Dist: pillow
25
+ Requires-Dist: scikit-learn
26
+ Requires-Dist: scipy
27
+ Requires-Dist: spikingjelly
28
+ Requires-Dist: setuptools
29
+ Dynamic: author
30
+ Dynamic: author-email
31
+ Dynamic: description
32
+ Dynamic: description-content-type
33
+ Dynamic: home-page
34
+ Dynamic: requires-dist
35
+ Dynamic: requires-python
36
+ Dynamic: summary
37
+
38
+ <p align="center">
39
+ <br>
40
+ <img src="imgs/spike-zoo.png" width="500"/>
41
+ <br>
42
+ </p>
43
+
44
+ <h5 align="center">
45
+
46
+ [![GitHub repo stars](https://img.shields.io/github/stars/chenkang455/Spike-Zoo?style=flat&logo=github&logoColor=whitesmoke&label=Stars)](https://github.com/chenkang455/Spike-Zoo/stargazers) [![GitHub Issues](https://img.shields.io/github/issues/chenkang455/Spike-Zoo?style=flat&logo=github&logoColor=whitesmoke&label=Issues)](https://github.com/chenkang455/Spike-Zoo/issues) <a href="https://badge.fury.io/py/spikezoo"><img src="https://badge.fury.io/py/spikezoo.svg" alt="PyPI version"></a> [![License](https://img.shields.io/badge/License-MIT-yellow)](https://github.com/chenkang455/Spike-Zoo)
47
+
48
+ </h5>
49
+
50
+ <!-- <h2 align="center">
51
+ <a href="">⚡Spike-Zoo: A Toolbox for Spike-to-Image Reconstruction
52
+ </a>
53
+ </h2> -->
54
+
55
+ ## 📖 About
56
+ ⚡Spike-Zoo is the go-to library for state-of-the-art pretrained **spike-to-image** models designed to reconstruct images from spike streams. Whether you're looking for a simple inference solution or aiming to train your own spike-to-image models, ⚡Spike-Zoo is a modular toolbox that supports both, with key features including:
57
+
58
+ - Fast inference with pre-trained models.
59
+ - Training support for custom-designed spike-to-image models.
60
+ - Specialized functions for processing spike data.
61
+
62
+
63
+
64
+ ## 🚩 Updates/Changelog
65
+ * **25-02-02:** Released the `Spike-Zoo v0.2` code, which supports more methods and provides more usages, such as training your own method from scratch.
66
+ * **24-07-19:** Released the `Spike-Zoo v0.1` code for baseline evaluation of SOTA methods.
67
+
68
+ ## 🍾 Quick Start
69
+ ### 1. Installation
70
+ For users focused on **utilizing pretrained models for spike-to-image conversion**, we recommend installing SpikeZoo using one of the following methods:
71
+
72
+ * Install the latest stable version `0.2.3` from PyPI:
73
+ ```
74
+ pip install spikezoo
75
+ ```
76
+ * Install the latest development version `0.2.3` from the source code:
77
+ ```
78
+ git clone https://github.com/chenkang455/Spike-Zoo
79
+ cd Spike-Zoo
80
+ python setup.py install
81
+ ```
82
+
83
+ For users interested in **training their own spike-to-image model based on our framework**, we recommend cloning the repository and modifying the related code directly.
84
+ ```
85
+ git clone https://github.com/chenkang455/Spike-Zoo
86
+ cd Spike-Zoo
87
+ python setup.py develop
88
+ ```
89
+
90
+ ### 2. Inference
91
+ Reconstructing images from spike streams is super easy with Spike-Zoo. Try the following code for a single model:
92
+ ``` python
93
+ from spikezoo.pipeline import Pipeline, PipelineConfig
94
+ import spikezoo as sz
95
+ pipeline = Pipeline(
96
+ cfg=PipelineConfig(save_folder="results",version="v023"),
97
+ model_cfg=sz.METHOD.BASE,
98
+ dataset_cfg=sz.DATASET.BASE
99
+ )
100
+ ```
101
+ You can also run multiple models at once by changing the pipeline (the `version` parameter corresponds to our different released versions listed in [Releases](https://github.com/chenkang455/Spike-Zoo/releases)):
102
+ ``` python
103
+ import spikezoo as sz
104
+ from spikezoo.pipeline import EnsemblePipeline, EnsemblePipelineConfig
105
+ pipeline = EnsemblePipeline(
106
+ cfg=EnsemblePipelineConfig(save_folder="results",version="v023"),
107
+ model_cfg_list=[
108
+ sz.METHOD.BASE,sz.METHOD.TFP,sz.METHOD.TFI,sz.METHOD.SPK2IMGNET,sz.METHOD.WGSE,
109
+ sz.METHOD.SSML,sz.METHOD.BSF,sz.METHOD.STIR,sz.METHOD.SPIKECLIP,sz.METHOD.SSIR],
110
+ dataset_cfg=sz.DATASET.BASE,
111
+ )
112
+ ```
113
+ Once the pipelines are established, we provide the following functions to use these spike-to-image models.
114
+
115
+ * I. Obtain the restoration metrics and save the recovered image from the given spike stream:
116
+ ``` python
117
+ # 1. spike-to-image from the given dataset
118
+ pipeline.infer_from_dataset(idx = 0)
119
+
120
+ # 2. spike-to-image from the given .dat file
121
+ pipeline.infer_from_file(file_path = 'data/scissor.dat',width = 400,height=250)
122
+
123
+ # 3. spike-to-image from the given spike
124
+ import spikezoo as sz
125
+ spike = sz.load_vidar_dat("data/scissor.dat",width = 400,height = 250)
126
+ pipeline.infer_from_spk(spike)
127
+ ```
128
+
129
+
130
+ * II. Save all images from the given dataset.
131
+ ``` python
132
+ pipeline.save_imgs_from_dataset()
133
+ ```
134
+
135
+ * III. Calculate the metrics for the specified dataset.
136
+ ``` python
137
+ pipeline.cal_metrics()
138
+ ```
139
+
140
+ * IV. Calculate the model complexity (params, FLOPs, latency) of the established pipeline.
141
+ ``` python
142
+ pipeline.cal_params()
143
+ ```
144
+
145
+ For detailed usage, please check [test_single.ipynb](examples/test/test_single.ipynb) and [test_ensemble.ipynb](examples/test/test_ensemble.ipynb).
146
+
147
+ ### 3. Training
148
+ We provide user-friendly code for training our `base` model (modified from `SpikeCLIP`) on the classic `REDS` dataset introduced in `Spk2ImgNet`:
149
+ ``` python
150
+ from spikezoo.pipeline import TrainPipelineConfig, TrainPipeline
151
+ from spikezoo.datasets.reds_base_dataset import REDS_BASEConfig
152
+ from spikezoo.models.base_model import BaseModelConfig
153
+ pipeline = TrainPipeline(
154
+ cfg=TrainPipelineConfig(save_folder="results", epochs = 10),
155
+ dataset_cfg=REDS_BASEConfig(root_dir = "spikezoo/data/REDS_BASE"),
156
+ model_cfg=BaseModelConfig(),
157
+ )
158
+ pipeline.train()
159
+ ```
160
+ The training finishes in `2 minutes` on a single 4090 GPU, achieving `32.8dB` in PSNR and `0.92` in SSIM.
161
+
162
+ > 🌟 We encourage users to develop their models with simple modifications to our framework, and the tutorial will be released soon.
163
+
164
+ We retrain all supported methods except `SPIKECLIP` on this REDS dataset (training scripts are placed in [examples/train_reds_base](examples/train_reds_base) and the evaluation script is placed in [test_REDS_base.py](examples/test/test_REDS_base.py)), with our reported metrics as follows:
165
+
166
+ | Method | PSNR | SSIM | LPIPS | NIQE | BRISQUE | PIQE | Params (M) | FLOPs (G) | Latency (ms) |
167
+ |----------------------|:-------:|:--------:|:---------:|:---------:|:----------:|:-------:|:------------:|:-----------:|:--------------:|
168
+ | `TFI` | 16.503 | 0.454 | 0.382 | 7.289 | 43.17 | 49.12 | 0.00 | 0.00 | 3.60 |
169
+ | `TFP` | 24.287 | 0.644 | 0.274 | 8.197 | 48.48 | 38.38 | 0.00 | 0.00 | 0.03 |
170
+ | `SPIKECLIP` | 21.873 | 0.578 | 0.333 | 7.802 | 42.08 | 54.01 | 0.19 | 23.69 | 1.27 |
171
+ | `SSIR` | 26.544 | 0.718 | 0.325 | 4.769 | 28.45 | 21.59 | 0.38 | 25.92 | 4.52 |
172
+ | `SSML` | 33.697 | 0.943 | 0.088 | 4.669 | 32.48 | 37.30 | 2.38 | 386.02 | 244.18 |
173
+ | `BASE` | 36.589 | 0.965 | 0.034 | 4.393 | 26.16 | 38.43 | 0.18 | 18.04 | 0.40 |
174
+ | `STIR` | 37.914 | 0.973 | 0.027 | 4.236 | 25.10 | 39.18 | 5.08 | 43.31 | 21.07 |
175
+ | `WGSE` | 39.036 | 0.978 | 0.023 | 4.231 | 25.76 | 44.11 | 3.81 | 415.26 | 73.62 |
176
+ | `SPK2IMGNET` | 39.154 | 0.978 | 0.022 | 4.243 | 25.20 | 43.09 | 3.90 | 1000.50 | 123.38 |
177
+ | `BSF` | 39.576 | 0.979 | 0.019 | 4.139 | 24.93 | 43.03 | 2.47 | 705.23 | 401.50 |
178
+
179
+ ### 4. Model Usage
180
+ We also provide a direct interface for users interested in taking the spike-to-image model as a part of their work:
181
+
182
+ ```python
183
+ import spikezoo as sz
184
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
185
+ # input data
186
+ spike = sz.load_vidar_dat("data/data.dat", width=400, height=250, out_format="tensor")
187
+ spike = spike[None].cuda()
188
+ print(f"Input spike shape: {spike.shape}")
189
+ # net
190
+ net = BaseModel(BaseModelConfig(model_params={"inDim": 41}))
191
+ net.build_network(mode = "debug")
192
+ # process
193
+ recon_img = net(spike)
194
+ print(recon_img.shape,recon_img.max(),recon_img.min())
195
+ ```
196
+ For detailed usage, please check [test_model.ipynb](examples/test/test_model.ipynb).
197
+
198
+ ### 5. Spike Utility
199
+ #### I. Faster spike loading interface
200
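+ We provide a faster `load_vidar_dat` function implemented with `cpp` (by [@zeal-ye](https://github.com/zeal-ye)); it is currently supported on Linux only, and other platforms automatically fall back to the `python` version: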
201
+ ``` python
202
+ import spikezoo as sz
203
+ spike = sz.load_vidar_dat("data/scissor.dat",width = 400,height = 250,version='cpp')
204
+ ```
205
+ 🚀 Results on [test_load_dat.py](examples/test_load_dat.py) show that the `cpp` version is more than 10 times faster than the `python` version.
206
+
207
+ #### II. Spike simulation pipeline.
208
+ We provide our overall spike simulation pipeline in [scripts](scripts/). Modify the config in `run.sh` and run the following command to start the simulation process:
209
+ ``` bash
210
+ bash run.sh
211
+ ```
212
+
213
+ #### III. Spike-related functions.
214
+ For other spike-related functions, please check [spike_utils.py](spikezoo/utils/spike_utils.py).
215
+
216
+ ## 📅 TODO
217
+ - [x] Support the overall pipeline for spike simulation.
218
+ - [ ] Provide the tutorials.
219
+ - [ ] Support more training settings.
220
+ - [ ] Support more spike-based image reconstruction methods and datasets.
221
+
222
+ ## 🤗 Supports
223
+ Run the following code to find our supported models, datasets and metrics:
224
+ ``` python
225
+ import spikezoo as sz
226
+ print(sz.METHODS)
227
+ print(sz.DATASETS)
228
+ print(sz.METRICS)
229
+ ```
230
+ **Supported Models:**
231
+ | Models | Source |
232
+ | ---- | ---- |
233
+ | `tfp`,`tfi` | Spike camera and its coding methods |
234
+ | `spk2imgnet` | Spk2ImgNet: Learning to Reconstruct Dynamic Scene from Continuous Spike Stream |
235
+ | `wgse` | Learning Temporal-Ordered Representation for Spike Streams Based on Discrete Wavelet Transforms |
236
+ | `ssml` | Self-Supervised Mutual Learning for Dynamic Scene Reconstruction of Spiking Camera |
237
+ | `ssir` | Spike Camera Image Reconstruction Using Deep Spiking Neural Networks |
238
+ | `bsf` | Boosting Spike Camera Image Reconstruction from a Perspective of Dealing with Spike Fluctuations |
239
+ | `stir` | Spatio-Temporal Interactive Learning for Efficient Image Reconstruction of Spiking Cameras |
240
+ | `base`,`spikeclip` | Rethinking High-speed Image Reconstruction Framework with Spike Camera |
241
+
242
+ **Supported Datasets:**
243
+ | Datasets | Source |
244
+ | ---- | ---- |
245
+ | `reds_base` | Spk2ImgNet: Learning to Reconstruct Dynamic Scene from Continuous Spike Stream |
246
+ | `uhsr` | Recognizing Ultra-High-Speed Moving Objects with Bio-Inspired Spike Camera |
247
+ | `realworld` | `recVidarReal2019`,`momVidarReal2021` in [SpikeCV](https://github.com/Zyj061/SpikeCV) |
248
+ | `szdata` | SpikeReveal: Unlocking Temporal Sequences from Real Blurry Inputs with Spike Streams |
249
+
250
+
251
+ ## ✨ Acknowledgment
252
+ Our code is built on the open-source projects of [SpikeCV](https://spikecv.github.io/), [IQA-PyTorch](https://github.com/chaofengc/IQA-PyTorch), [BasicSR](https://github.com/XPixelGroup/BasicSR) and [NeRFStudio](https://github.com/nerfstudio-project/nerfstudio). We appreciate the effort of the contributors to these repositories. Thanks to [@ruizhao26](https://github.com/ruizhao26), [@shiyan_chen](https://github.com/hnmizuho) and [@Leozhangjiyuan](https://github.com/Leozhangjiyuan) for their help in building this project.
253
+
254
+ ## 📑 Citation
255
+ If you find our code helpful for your research, please consider citing it as follows:
256
+ ```
257
+ @misc{spikezoo,
258
+ title={{Spike-Zoo}: A Toolbox for Spike-to-Image Reconstruction},
259
+ author={Kang Chen and Zhiyuan Ye},
260
+ year={2025},
261
+ howpublished = "[Online]. Available: \url{https://github.com/chenkang455/Spike-Zoo}"
262
+ }
263
+ ```