spikezoo 0.2.3.6__py3-none-any.whl → 0.2.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spikezoo/pipeline/train_cfgs.py +30 -29
- spikezoo/pipeline/train_pipeline.py +3 -3
- {spikezoo-0.2.3.6.dist-info → spikezoo-0.2.3.7.dist-info}/METADATA +2 -2
- {spikezoo-0.2.3.6.dist-info → spikezoo-0.2.3.7.dist-info}/RECORD +7 -7
- {spikezoo-0.2.3.6.dist-info → spikezoo-0.2.3.7.dist-info}/LICENSE.txt +0 -0
- {spikezoo-0.2.3.6.dist-info → spikezoo-0.2.3.7.dist-info}/WHEEL +0 -0
- {spikezoo-0.2.3.6.dist-info → spikezoo-0.2.3.7.dist-info}/top_level.txt +0 -0
spikezoo/pipeline/train_cfgs.py
CHANGED
@@ -18,7 +18,7 @@ class REDS_BASE_TrainConfig(TrainPipelineConfig):
|
|
18
18
|
steps_per_save_imgs: int = 200
|
19
19
|
steps_per_save_ckpt: int = 500
|
20
20
|
steps_per_cal_metrics: int = 100
|
21
|
-
metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim","lpips","niqe","brisque","piqe"])
|
21
|
+
metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim", "lpips", "niqe", "brisque", "piqe"])
|
22
22
|
|
23
23
|
# dataloader setting
|
24
24
|
bs_train: int = 8
|
@@ -26,44 +26,45 @@ class REDS_BASE_TrainConfig(TrainPipelineConfig):
|
|
26
26
|
pin_memory: bool = False
|
27
27
|
|
28
28
|
# train setting - optimizer & scheduler & loss_dict
|
29
|
-
optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4)
|
30
|
-
scheduler_cfg: Optional[SchedulerConfig] =
|
29
|
+
optimizer_cfg: OptimizerConfig = field(default_factory=lambda: AdamOptimizerConfig(lr=1e-4))
|
30
|
+
scheduler_cfg: Optional[SchedulerConfig] = field(
|
31
|
+
default_factory=lambda: MultiStepSchedulerConfig(milestones=[400], gamma=0.2)
|
32
|
+
) # from wgse
|
31
33
|
loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
|
32
34
|
|
33
35
|
|
36
|
+
# # ! Train Config for each method on the official setting, not recommended to utilize their default parameters owing to the dataset setting.
|
37
|
+
# @dataclass
|
38
|
+
# class BSFTrainConfig(TrainPipelineConfig):
|
39
|
+
# """Training setting for BSF. https://github.com/ruizhao26/BSF"""
|
34
40
|
|
35
|
-
#
|
36
|
-
|
37
|
-
|
38
|
-
"""Training setting for BSF. https://github.com/ruizhao26/BSF"""
|
41
|
+
# optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4, weight_decay=0.0)
|
42
|
+
# scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[10, 20, 30, 40, 50, 60, 70, 80, 90, 100])
|
43
|
+
# loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
|
39
44
|
|
40
|
-
optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4, weight_decay=0.0)
|
41
|
-
scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[10, 20, 30, 40, 50, 60, 70, 80, 90, 100])
|
42
|
-
loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
|
43
45
|
|
46
|
+
# @dataclass
|
47
|
+
# class WGSETrainConfig(TrainPipelineConfig):
|
48
|
+
# """Training setting for WGSE. https://github.com/Leozhangjiyuan/WGSE-SpikeCamera"""
|
44
49
|
|
45
|
-
|
46
|
-
|
47
|
-
|
50
|
+
# optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4, betas=(0.9, 0.99), weight_decay=0)
|
51
|
+
# scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[400, 600], gamma=0.2)
|
52
|
+
# loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
|
48
53
|
|
49
|
-
optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4, betas=(0.9, 0.99), weight_decay=0)
|
50
|
-
scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[400, 600], gamma=0.2)
|
51
|
-
loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
|
52
54
|
|
55
|
+
# @dataclass
|
56
|
+
# class STIRTrainConfig(TrainPipelineConfig):
|
57
|
+
# """Training setting for STIR. https://github.com/GitCVfb/STIR"""
|
53
58
|
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4, betas=(0.9, 0.999))
|
59
|
-
scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70], gamma=0.7)
|
60
|
-
loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
|
59
|
+
# optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4, betas=(0.9, 0.999))
|
60
|
+
# scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70], gamma=0.7)
|
61
|
+
# loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
|
61
62
|
|
62
63
|
|
63
|
-
@dataclass
|
64
|
-
class Spk2ImgNetTrainConfig(TrainPipelineConfig):
|
65
|
-
|
64
|
+
# @dataclass
|
65
|
+
# class Spk2ImgNetTrainConfig(TrainPipelineConfig):
|
66
|
+
# """Training setting for Spk2ImgNet. https://github.com/Vspacer/Spk2ImgNet"""
|
66
67
|
|
67
|
-
|
68
|
-
|
69
|
-
|
68
|
+
# optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4)
|
69
|
+
# scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[20], gamma=0.1)
|
70
|
+
# loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
|
@@ -52,7 +52,7 @@ class TrainPipelineConfig(PipelineConfig):
|
|
52
52
|
|
53
53
|
# train setting - optimizer & scheduler & loss_dict
|
54
54
|
"Optimizer config."
|
55
|
-
optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-3)
|
55
|
+
optimizer_cfg: OptimizerConfig = field(default_factory=lambda: AdamOptimizerConfig(lr=1e-3))
|
56
56
|
"Scheduler config."
|
57
57
|
scheduler_cfg: Optional[SchedulerConfig] = None
|
58
58
|
"Loss dict {loss_name,weight}."
|
@@ -82,7 +82,7 @@ class TrainPipeline(Pipeline):
|
|
82
82
|
"""Model and Data setup."""
|
83
83
|
# model
|
84
84
|
self.model: BaseModel = build_model_name(model_cfg) if isinstance(model_cfg, str) else build_model_cfg(model_cfg)
|
85
|
-
self.model.build_network(mode
|
85
|
+
self.model.build_network(mode="train", version="local")
|
86
86
|
torch.set_grad_enabled(True)
|
87
87
|
# data
|
88
88
|
if isinstance(dataset_cfg, str):
|
@@ -94,7 +94,7 @@ class TrainPipeline(Pipeline):
|
|
94
94
|
self.train_dataset.build_source("train")
|
95
95
|
self.dataset.build_source("test")
|
96
96
|
self.train_dataloader = build_dataloader(self.train_dataset, self.cfg)
|
97
|
-
self.dataloader = build_dataloader(self.dataset,self.cfg)
|
97
|
+
self.dataloader = build_dataloader(self.dataset, self.cfg)
|
98
98
|
# device
|
99
99
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
100
100
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: spikezoo
|
3
|
-
Version: 0.2.3.
|
3
|
+
Version: 0.2.3.7
|
4
4
|
Summary: A deep learning toolbox for spike-to-image models.
|
5
5
|
Home-page: https://github.com/chenkang455/Spike-Zoo
|
6
6
|
Author: Kang Chen
|
@@ -41,7 +41,7 @@ Dynamic: summary
|
|
41
41
|
|
42
42
|
<h5 align="center">
|
43
43
|
|
44
|
-
[](https://github.com/chenkang455/Spike-Zoo/stargazers) [](https://github.com/chenkang455/Spike-Zoo/stargazers) [](https://github.com/chenkang455/Spike-Zoo/issues) <a href="https://badge.fury.io/py/spikezoo"><img src="https://badge.fury.io/py/spikezoo.svg" alt="PyPI version"></a> <a href='https://spike-zoo.readthedocs.io/zh-cn/latest/index.html'><img src='https://readthedocs.com/projects/plenoptix-nerfstudio/badge/?version=latest' alt='Documentation Status' /></a>[](https://github.com/chenkang455/Spike-Zoo)
|
45
45
|
<p>
|
46
46
|
|
47
47
|
|
@@ -186,8 +186,8 @@ spikezoo/models/yourmodel_model.py,sha256=mQ3hRsDbHovxL6NhsxAKO-W3tvx5WwAHRZDyyG
|
|
186
186
|
spikezoo/pipeline/__init__.py,sha256=WPsukNR4cannwsghiukqNsWbWGH5DVPapR_Ly-WOU4Q,188
|
187
187
|
spikezoo/pipeline/base_pipeline.py,sha256=9-0vt70x2oftLlNvzRmmLIhnJZ9MtenFiZjQEZn3x58,13625
|
188
188
|
spikezoo/pipeline/ensemble_pipeline.py,sha256=cn-QzK-j7T9B43ONsRTr-lJQkquRyDSJfU9gutEO6nk,2614
|
189
|
-
spikezoo/pipeline/train_cfgs.py,sha256=
|
190
|
-
spikezoo/pipeline/train_pipeline.py,sha256=
|
189
|
+
spikezoo/pipeline/train_cfgs.py,sha256=fnxYmX070XolVx8rXjY0Nm4WdMU_geZdLcL9pWm3Uww,3157
|
190
|
+
spikezoo/pipeline/train_pipeline.py,sha256=CKFOF4mG3oXYJlzcTMG8AHVjKh0P9ndcGlBs9EoNIvE,8438
|
191
191
|
spikezoo/utils/__init__.py,sha256=bYLlusAXwLCoY4s6nhVgviax9ioRA9aea8qgRmj2HpI,152
|
192
192
|
spikezoo/utils/data_utils.py,sha256=mk1xeyIb7o_E1J7Z6-gtPq-rpKiMTxAWSTcvvPvVku8,2033
|
193
193
|
spikezoo/utils/img_utils.py,sha256=0O9z58VzLxQEAuz-GGWCbpeHuHPOCpgBVjCBV9kf6sI,2257
|
@@ -196,8 +196,8 @@ spikezoo/utils/other_utils.py,sha256=uWNWaII9Jv7fkWNfkAD9wD-4ID-GAzbR-gGYT-1FF_c
|
|
196
196
|
spikezoo/utils/scheduler_utils.py,sha256=5RBh-hl3-2y-IomxMs47T1p3JsbicZNYLza6q1uAKHo,828
|
197
197
|
spikezoo/utils/spike_utils.py,sha256=XBFo3JOiNeyAQhsdgd_e6v9vVSViHx8DzN0hO3SbxnE,4300
|
198
198
|
spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so,sha256=uXqu7ME---cZRRU5LUcLiNrjjtlOjxNwWHyTIQ10BGg,199088
|
199
|
-
spikezoo-0.2.3.
|
200
|
-
spikezoo-0.2.3.
|
201
|
-
spikezoo-0.2.3.
|
202
|
-
spikezoo-0.2.3.
|
203
|
-
spikezoo-0.2.3.
|
199
|
+
spikezoo-0.2.3.7.dist-info/LICENSE.txt,sha256=ukEi8E0PKq1dQGTXHUflg3rppLymwAhr7il9x-0nPgg,1062
|
200
|
+
spikezoo-0.2.3.7.dist-info/METADATA,sha256=W9E0K3HLmcnsrCoReh_DcHTyj1zpDmKvqwAdKa9NTYk,7205
|
201
|
+
spikezoo-0.2.3.7.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
202
|
+
spikezoo-0.2.3.7.dist-info/top_level.txt,sha256=xF2iuOstrACJh43NW4dsTwIdgKfXPXAb_Xzl3M1ricM,9
|
203
|
+
spikezoo-0.2.3.7.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|