spikezoo-0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spikezoo/__init__.py ADDED
File without changes
@@ -0,0 +1,68 @@
+ from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
+ from dataclasses import replace
+ import importlib, inspect
+ import os
+ import torch
+ from typing import Literal
+
+ # todo auto detect/register datasets
+ files_list = os.listdir(os.path.dirname(os.path.abspath(__file__)))
+ dataset_list = [file.replace("_dataset.py", "") for file in files_list if file.endswith("_dataset.py")]
+
+ # todo register function
+ def build_dataset_cfg(cfg: BaseDatasetConfig, split: Literal["train", "test"] = "test"):
+     """Build the dataset from the given dataset config."""
+     # build new cfg according to split
+     cfg = replace(cfg, split=split, spike_length=cfg.spike_length_train if split == "train" else cfg.spike_length_test)
+     # dataset module
+     module_name = cfg.dataset_name + "_dataset"
+     assert cfg.dataset_name in dataset_list, f"Given dataset {cfg.dataset_name} not in our dataset list {dataset_list}."
+     module_name = "spikezoo.datasets." + module_name
+     module = importlib.import_module(module_name)
+     # dataset, dataset_config
+     classes = sorted([name for name, obj in inspect.getmembers(module) if inspect.isclass(obj) and obj.__module__ == module.__name__])
+     dataset_cls: BaseDataset = getattr(module, classes[0])
+     dataset = dataset_cls(cfg)
+     return dataset
+
+
+ def build_dataset_name(dataset_name: str, split: Literal["train", "test"] = "test"):
+     """Build the default dataset from the given name."""
+     module_name = dataset_name + "_dataset"
+     assert dataset_name in dataset_list, f"Given dataset {dataset_name} not in our dataset list {dataset_list}."
+     module_name = "spikezoo.datasets." + module_name
+     module = importlib.import_module(module_name)
+     # dataset, dataset_config
+     classes = sorted([name for name, obj in inspect.getmembers(module) if inspect.isclass(obj) and obj.__module__ == module.__name__])
+     dataset_cls: BaseDataset = getattr(module, classes[0])
+     dataset_cfg: BaseDatasetConfig = getattr(module, classes[1])(split=split)
+     dataset = dataset_cls(dataset_cfg)
+     return dataset
+
+
+ # todo to modify according to the basicsr
+ def build_dataloader(dataset: BaseDataset, cfg=None):
+     # train dataloader
+     if dataset.cfg.split == "train":
+         if cfg is None:
+             return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
+         else:
+             return torch.utils.data.DataLoader(dataset, batch_size=cfg.bs_train, shuffle=True, num_workers=cfg.num_workers, pin_memory=cfg.pin_memory)
+     # test dataloader
+     elif dataset.cfg.split == "test":
+         return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
+
+
+ # dataset_size_dict = {}
+ # for dataset in dataset_list:
+ #     module_name = dataset + "_dataset"
+ #     module_name = "spikezoo.datasets." + module_name
+ #     module = importlib.import_module(module_name)
+ #     classes = sorted([name for name, obj in inspect.getmembers(module) if inspect.isclass(obj)])
+ #     dataset_cfg: BaseDatasetConfig = getattr(module, classes[1])
+ #     dataset_size_dict[dataset] = (dataset_cfg.height, dataset_cfg.width)
+
+
+ # def get_dataset_size(name):
+ #     assert name in dataset_list, f"Given dataset {name} not in our dataset list {dataset_list}."
+ #     return dataset_size_dict[name]
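For orientation, a minimal usage sketch of the builders above (assuming this file is spikezoo/datasets/__init__.py and that the corresponding data has been placed under spikezoo/data; the dataset name must match one of the auto-detected *_dataset.py modules):

    from spikezoo.datasets import build_dataset_name, build_dataloader

    # build the default "reds_small" test split and wrap it in a single-sample dataloader
    dataset = build_dataset_name("reds_small", split="test")
    dataloader = build_dataloader(dataset)
    for batch in dataloader:
        spike, img = batch["spike"], batch["img"]  # batched spike stream and ground-truth image
        break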
@@ -0,0 +1,158 @@
+ import torch
+ from torch.utils.data import Dataset
+ from pathlib import Path
+ import cv2
+ import numpy as np
+ from spikezoo.utils.spike_utils import load_vidar_dat
+ import re
+ from dataclasses import dataclass, replace
+ from typing import Literal
+ import warnings
+ from tqdm import tqdm
+ from spikezoo.utils.data_utils import Augmentor
+
+
+ @dataclass
+ class BaseDatasetConfig:
+     # ------------- Not Recommended to Change -------------
+     "Dataset name."
+     dataset_name: str = "base"
+     "Directory specifying location of data."
+     root_dir: Path = Path(__file__).parent.parent / Path("data/base")
+     "Image width."
+     width: int = 400
+     "Image height."
+     height: int = 250
+     "Spike paired with the image or not."
+     with_img: bool = True
+     "Dataset spike length for the train data."
+     spike_length_train: int = -1
+     "Dataset spike length for the test data."
+     spike_length_test: int = -1
+     "Dataset spike length for the instantiation dataclass."
+     spike_length: int = -1
+     "Dir name for the spike."
+     spike_dir_name: str = "spike"
+     "Dir name for the image."
+     img_dir_name: str = "gt"
+
+     # ------------- Config -------------
+     "Dataset split: train/test. Default set as 'test' for evaluation."
+     split: Literal["train", "test"] = "test"
+     "Use the data augmentation technique or not."
+     use_aug: bool = False
+     "Use cache mechanism."
+     use_cache: bool = False
+     "Crop size."
+     crop_size: tuple = (-1, -1)
+
+     # post process
+     def __post_init__(self):
+         self.spike_length = self.spike_length_train if self.split == "train" else self.spike_length_test
+         self.root_dir = Path(self.root_dir) if isinstance(self.root_dir, str) else self.root_dir
+         # todo try download
+         assert self.root_dir.exists(), f"No files found in {self.root_dir} for the specified dataset `{self.dataset_name}`."
+
+ # todo cache mechanism
+ class BaseDataset(Dataset):
+     def __init__(self, cfg: BaseDatasetConfig):
+         super(BaseDataset, self).__init__()
+         self.cfg = cfg
+         self.augmentor = Augmentor(self.cfg.crop_size) if self.cfg.use_aug and self.cfg.split == "train" else None
+         self.prepare_data()
+         if cfg.use_cache:
+             self.cache_data()
+         if self.cfg.with_img and len(self.img_list) != len(self.spike_list):
+             warnings.warn("Lengths of the image list and the spike list should be equal.")
+
+     def __len__(self):
+         return len(self.spike_list)
+
+     def __getitem__(self, idx: int):
+         # load data
+         if self.cfg.use_cache:
+             spike, img = self.cache_spkimg[idx]
+             spike = spike.to(torch.float32)
+         else:
+             spike = self.get_spike(idx)
+             img = self.get_img(idx)
+
+         # process data
+         if self.cfg.use_aug and self.cfg.split == "train":
+             spike, img = self.augmentor(spike, img)
+
+         batch = {"spike": spike, "img": img}
+         return batch
+
+     # todo: To be overridden
+     def prepare_data(self):
+         """Specify the spike and image files to be loaded."""
+         # spike
+         self.spike_dir = self.cfg.root_dir / self.cfg.split / self.cfg.spike_dir_name
+         self.spike_list = self.get_spike_files(self.spike_dir)
+         # gt
+         if self.cfg.with_img:
+             self.img_dir = self.cfg.root_dir / self.cfg.split / self.cfg.img_dir_name
+             self.img_list = self.get_image_files(self.img_dir)
+
+     # todo: To be overridden
+     def get_spike_files(self, path: Path):
+         """Recognize spike files automatically (default .dat)."""
+         files = path.glob("**/*.dat")
+         return sorted(files)
+
+     # todo: To be overridden
+     def load_spike(self, idx):
+         """Load the spike stream from the given idx."""
+         spike_name = str(self.spike_list[idx])
+         spike = load_vidar_dat(
+             spike_name,
+             height=self.cfg.height,
+             width=self.cfg.width,
+             out_type="float",
+             out_format="tensor",
+         )
+         return spike
+
+     def get_spike(self, idx):
+         """Get and process the spike stream from the given idx."""
+         spike_length = self.cfg.spike_length
+         spike = self.load_spike(idx)
+         assert spike.shape[0] >= spike_length, f"Given spike length {spike.shape[0]} is smaller than the required length {spike_length}."
+         spike_mid = spike.shape[0] // 2
+         # spike length process
+         if spike_length == -1:
+             spike = spike
+         elif spike_length % 2 == 1:
+             spike = spike[spike_mid - spike_length // 2 : spike_mid + spike_length // 2 + 1]
+         elif spike_length % 2 == 0:
+             spike = spike[spike_mid - spike_length // 2 : spike_mid + spike_length // 2]
+         return spike
+
+     def get_image_files(self, path: Path):
+         """Recognize image files automatically."""
+         files = [f for f in path.glob("**/*") if re.match(r".*\.(jpg|jpeg|png)$", f.name, re.IGNORECASE)]
+         return sorted(files)
+
+     # todo: To be overridden
+     def get_img(self, idx):
+         """Get the image from the given idx."""
+         if self.cfg.with_img:
+             img_name = str(self.img_list[idx])
+             img = cv2.imread(img_name)
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+             img = (img / 255).astype(np.float32)
+             img = img[None]
+             img = torch.from_numpy(img)
+         else:
+             spike = self.get_spike(idx)
+             img = torch.mean(spike, dim=0, keepdim=True)
+         return img
+
+     def cache_data(self):
+         """Cache the data."""
+         self.cache_spkimg = []
+         for idx in tqdm(range(len(self.spike_list)), desc="Caching data", unit="sample"):
+             spike = self.get_spike(idx).to(torch.uint8)
+             img = self.get_img(idx)
+             self.cache_spkimg.append([spike, img])
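As the "To be overridden" markers suggest, new datasets are added by subclassing this pair; a minimal sketch with hypothetical names and paths (not part of the package):

    from dataclasses import dataclass
    from pathlib import Path
    from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig

    @dataclass
    class MyData_Config(BaseDatasetConfig):  # hypothetical config
        dataset_name: str = "mydata"
        root_dir: Path = Path("path/to/mydata")  # expects <root>/<split>/spike and <root>/<split>/gt
        width: int = 400
        height: int = 250

    class MyData(BaseDataset):  # reuses the default .dat spike and .png/.jpg image loading
        def __init__(self, cfg: BaseDatasetConfig):
            super(MyData, self).__init__(cfg)

For the name-based builder to pick it up, such a pair would have to live in a file named mydata_dataset.py inside spikezoo/datasets.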
@@ -0,0 +1,25 @@
+ from pathlib import Path
+ from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class RealWorld_Config(BaseDatasetConfig):
+     dataset_name: str = "realworld"
+     root_dir: Path = Path(__file__).parent.parent / Path("data/recVidarReal2019")
+     width: int = 400
+     height: int = 250
+     with_img: bool = False
+     spike_length_train: int = -1
+     spike_length_test: int = -1
+
+
+ class RealWorld(BaseDataset):
+     def __init__(self, cfg: BaseDatasetConfig):
+         super(RealWorld, self).__init__(cfg)
+
+     def prepare_data(self):
+         self.spike_dir = self.cfg.root_dir
+         self.spike_list = self.get_spike_files(self.spike_dir)
+
+
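A sketch of building this split through build_dataset_cfg with an overridden root_dir (hypothetical local path); because with_img is False, BaseDataset.get_img falls back to the temporal mean of the spike stream:

    from spikezoo.datasets import build_dataset_cfg
    from spikezoo.datasets.realworld_dataset import RealWorld_Config

    cfg = RealWorld_Config(root_dir="path/to/recVidarReal2019")  # hypothetical path
    dataset = build_dataset_cfg(cfg, split="test")
    batch = dataset[0]  # batch["img"] is the mean of the spike stream over time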
@@ -0,0 +1,27 @@
+ from torch.utils.data import Dataset
+ from pathlib import Path
+ from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
+ from dataclasses import dataclass
+ import re
+
+ @dataclass
+ class REDS_Small_Config(BaseDatasetConfig):
+     dataset_name: str = "reds_small"
+     root_dir: Path = Path(__file__).parent.parent / Path("data/REDS_Small")
+     width: int = 400
+     height: int = 250
+     with_img: bool = True
+     spike_length_train: int = 41
+     spike_length_test: int = 301
+     spike_dir_name: str = "spike"
+     img_dir_name: str = "gt"
+
+ class REDS_Small(BaseDataset):
+     def __init__(self, cfg: BaseDatasetConfig):
+         super(REDS_Small, self).__init__(cfg)
+
+     def prepare_data(self):
+         super().prepare_data()
+         if self.cfg.split == "train":
+             self.img_list = [self.img_dir / Path(str(s.name).replace(".dat", ".png")) for s in self.spike_list]
+
@@ -0,0 +1,37 @@
+ from pathlib import Path
+ from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
+ from dataclasses import dataclass
+ import cv2
+ import torch
+ import numpy as np
+
+ @dataclass
+ class SZData_Config(BaseDatasetConfig):
+     dataset_name: str = "szdata"
+     root_dir: Path = Path(__file__).parent.parent / Path("data/dataset")
+     width: int = 400
+     height: int = 250
+     with_img: bool = True
+     spike_length_train: int = -1
+     spike_length_test: int = -1
+     spike_dir_name: str = "spike_data"
+     img_dir_name: str = "sharp_data"
+
+ class SZData(BaseDataset):
+     def __init__(self, cfg: BaseDatasetConfig):
+         super(SZData, self).__init__(cfg)
+
+     def get_img(self, idx):
+         if self.cfg.with_img:
+             spike_name = self.spike_list[idx]
+             img_name = str(spike_name).replace(self.cfg.spike_dir_name, self.cfg.img_dir_name).replace(".dat", ".png")
+             img = cv2.imread(img_name)
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+             img = (img / 255).astype(np.float32)
+             img = img[None]
+             img = torch.from_numpy(img)
+         else:
+             spike = self.get_spike(idx)
+             img = torch.mean(spike, dim=0, keepdim=True)
+         return img
+
@@ -0,0 +1,38 @@
+ from pathlib import Path
+ from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
+ from dataclasses import dataclass
+ import numpy as np
+ import torch
+
+ @dataclass
+ class UHSR_Config(BaseDatasetConfig):
+     dataset_name: str = "uhsr"
+     root_dir: Path = Path(__file__).parent.parent / Path("data/U-CALTECH")
+     width: int = 224
+     height: int = 224
+     with_img: bool = False
+     spike_length_train: int = 200
+     spike_length_test: int = 200
+     spike_dir_name: str = "spike"
+     img_dir_name: str = ""
+
+
+ class UHSR(BaseDataset):
+     def __init__(self, cfg: BaseDatasetConfig):
+         super(UHSR, self).__init__(cfg)
+
+     def prepare_data(self):
+         self.spike_dir = self.cfg.root_dir / self.cfg.split
+         self.spike_list = self.get_spike_files(self.spike_dir)
+
+     def get_spike_files(self, path: Path):
+         files = path.glob("**/*.npz")
+         return sorted(files)
+
+     def load_spike(self, idx):
+         spike_name = str(self.spike_list[idx])
+         data = np.load(spike_name)
+         spike = data["spk"].astype(np.float32)
+         spike = torch.from_numpy(spike)
+         spike = spike[:, 13:237, 13:237]
+         return spike
@@ -0,0 +1,95 @@
+ from skimage import metrics
+ import torch
+ import torch.hub
+ from lpips.lpips import LPIPS
+ import os
+ import pyiqa
+ import numpy as np
+ from torchvision import transforms
+ import torch.nn.functional as F
+
+ # todo with the union type
+ metric_pair_names = ["psnr", "ssim", "lpips", "mse"]
+ metric_single_names = ["niqe", "brisque", "piqe"]
+ metric_all_names = metric_pair_names + metric_single_names
+
+ metric_single_list = {}
+
+ metric_pair_list = {
+     "mse": metrics.mean_squared_error,
+     "ssim": metrics.structural_similarity,
+     "psnr": metrics.peak_signal_noise_ratio,
+     "lpips": None,
+ }
+
+
+ def cal_metric_single(img: torch.Tensor, metric_name="niqe"):
+     if metric_name not in metric_single_list.keys():
+         if metric_name in pyiqa.list_models():
+             iqa_metric = pyiqa.create_metric(metric_name, device=torch.device("cuda"))
+             metric_single_list.update({metric_name: iqa_metric})
+         else:
+             raise RuntimeError(f"Metric {metric_name} not recognized by the IQA lib.")
+     # image process
+     if img.dim() == 3:
+         img = img[None]
+     elif img.dim() == 2:
+         img = img[None, None]
+
+     # resize
+     if metric_name == "liqe_mix":
+         short_edge = 384
+         h, w = img.shape[2], img.shape[3]
+         if h < w:
+             new_h, new_w = short_edge, int(w * short_edge / h)
+         else:
+             new_h, new_w = int(h * short_edge / w), short_edge
+         img = F.interpolate(img, size=(new_h, new_w), mode="bilinear", align_corners=False)
+     return metric_single_list[metric_name](img).item()
+
+
+ def cal_metric_pair(im1t: torch.Tensor, im2t: torch.Tensor, metric_name="mse"):
+     """
+     im1t, im2t: torch tensors with batched image shape, values in the range (0, 1).
+     """
+     if metric_name not in metric_pair_list.keys():
+         raise RuntimeError(f"Metric {metric_name} not recognized")
+     if metric_name == "lpips" and metric_pair_list[metric_name] is None:
+         metric_pair_list[metric_name] = LPIPS().cuda() if im1t.is_cuda else LPIPS().cpu()
+     metric_method = metric_pair_list[metric_name]
+
+     # convert from [0, 1] to [-1, 1]
+     im1t = (im1t * 2 - 1).clamp(-1, 1)
+     im2t = (im2t * 2 - 1).clamp(-1, 1)
+
+     # [c,h,w] -> [1,c,h,w]
+     if im1t.dim() == 3:
+         im1t = im1t[None]
+         im2t = im2t[None]
+     elif im1t.dim() == 2:
+         im1t = im1t[None, None]
+         im2t = im2t[None, None]
+
+     # [1,h,w,3] -> [1,3,h,w]
+     if im1t.shape[-1] == 3:
+         im1t = im1t.permute(0, 3, 1, 2)
+         im2t = im2t.permute(0, 3, 1, 2)
+
+     # img array: [1,h,w,3] imgt tensor: [1,3,h,w]
+     im1 = im1t.permute(0, 2, 3, 1).detach().cpu().numpy()
+     im2 = im2t.permute(0, 2, 3, 1).detach().cpu().numpy()
+     batchsz, hei, wid, _ = im1.shape
+
+     # batch processing
+     values = []
+     for i in range(batchsz):
+         if metric_name in ["mse", "psnr"]:
+             value = metric_method(im1[i], im2[i])
+         elif metric_name in ["ssim"]:
+             value, ssimmap = metric_method(im1[i], im2[i], channel_axis=-1, data_range=2, full=True)
+         elif metric_name in ["lpips"]:
+             value = metric_method(im1t[i : i + 1], im2t[i : i + 1])[0, 0, 0, 0]
+             value = value.detach().cpu().float().item()
+         values.append(value)
+
+     return sum(values) / len(values)
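A small usage sketch, assuming this module is importable as spikezoo.metrics (the exact module path is not visible in this diff); paired metrics take two tensors in [0, 1], while no-reference metrics take a single image and need a CUDA device for pyiqa:

    import torch
    from spikezoo.metrics import cal_metric_pair, cal_metric_single  # assumed import path

    pred = torch.rand(1, 1, 250, 400)  # reconstructed image in [0, 1]
    gt = torch.rand(1, 1, 250, 400)    # ground-truth image in [0, 1]
    print(cal_metric_pair(pred, gt, metric_name="psnr"))
    print(cal_metric_pair(pred, gt, metric_name="ssim"))
    print(cal_metric_single(pred, metric_name="niqe"))  # builds the pyiqa metric on first call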
@@ -0,0 +1,37 @@
+ import importlib
+ import inspect
+ from spikezoo.models.base_model import BaseModel, BaseModelConfig
+ import os
+ from pathlib import Path
+
+ current_file_path = Path(__file__).parent
+ files_list = os.listdir(os.path.dirname(os.path.abspath(__file__)))
+ model_list = [file.split("_")[0] for file in files_list if file.endswith("_model.py")]
+
+ # todo register function
+ def build_model_cfg(cfg: BaseModelConfig):
+     """Build the model from the given model config."""
+     # model module name
+     module_name = cfg.model_name + "_model"
+     assert cfg.model_name in model_list, f"Given model {cfg.model_name} not in our model zoo {model_list}."
+     module_name = "spikezoo.models." + module_name
+     module = importlib.import_module(module_name)
+     # model, model_config
+     classes = sorted([name for name, obj in inspect.getmembers(module) if inspect.isclass(obj) and obj.__module__ == module.__name__])
+     model_cls: BaseModel = getattr(module, classes[0])
+     model = model_cls(cfg)
+     return model
+
+ def build_model_name(model_name: str):
+     """Build the default model from the given name."""
+     # model module name
+     module_name = model_name + "_model"
+     assert model_name in model_list, f"Given model {model_name} not in our model zoo {model_list}."
+     module_name = "spikezoo.models." + module_name
+     module = importlib.import_module(module_name)
+     # model, model_config
+     classes = sorted([name for name, obj in inspect.getmembers(module) if inspect.isclass(obj) and obj.__module__ == module.__name__])
+     model_cls: BaseModel = getattr(module, classes[0])
+     model_cfg: BaseModelConfig = getattr(module, classes[1])()
+     model = model_cls(model_cfg)
+     return model
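A usage sketch, assuming this file is spikezoo/models/__init__.py and that the referenced architectures ship under spikezoo/archs (model names are the prefixes of the auto-detected *_model.py files):

    import torch
    from spikezoo.models import build_model_name

    model = build_model_name("base")  # default config; weights are downloaded only if require_params is set
    spike = torch.randint(0, 2, (1, 41, 250, 400)).float().to(model.device)  # [B, T, H, W] spike stream
    with torch.no_grad():
        img = model(spike)  # reconstructed image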
@@ -0,0 +1,177 @@
+ import torch
+ import torch.nn as nn
+ import importlib
+ import inspect
+ from dataclasses import dataclass, field
+ from spikezoo.utils import load_network, download_file
+ import os
+ import time
+ from typing import Dict
+ from torch.optim import Adam
+ from torch.optim.lr_scheduler import CosineAnnealingLR
+ import functools
+
+
+ # todo private design
+ @dataclass
+ class BaseModelConfig:
+     # default params for BaseModel
+     "Registered model name."
+     model_name: str = "base"
+     "File name of the specified model."
+     model_file_name: str = "nets"
+     "Class name of the specified model in spikezoo/archs/base/{model_file_name}.py."
+     model_cls_name: str = "BaseNet"
+     "Spike input length for the specified model."
+     model_win_length: int = 41
+     "Whether the model requires pretrained parameters or not."
+     require_params: bool = False
+     "Path where the model checkpoint is stored."
+     ckpt_path: str = ""
+     "Load pretrained weights or not."
+     load_state: bool = True
+     "Base URL for storing pretrained models."
+     base_url: str = "https://github.com/chenkang455/Spike-Zoo/releases/download/v0.1/"
+     "Multi-GPU setting."
+     multi_gpu: bool = False
+     "Model parameters."
+     model_params: dict = field(default_factory=lambda: {})
+
+
+ class BaseModel(nn.Module):
+     def __init__(self, cfg: BaseModelConfig):
+         super(BaseModel, self).__init__()
+         self.cfg = cfg
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.net = self.build_network().to(self.device)
+         self.net = nn.DataParallel(self.net) if cfg.multi_gpu else self.net
+         self.spike_size = None
+         self.model_half_win_length: int = cfg.model_win_length // 2
+
+     # ! Might lead to low speed training on the BSF.
+     def forward(self, spike):
+         """A simple implementation of spike-to-image conversion: given the spike input, output the reconstructed image."""
+         spike = self.preprocess_spike(spike)
+         img = self.net(spike)
+         img = self.postprocess_img(img)
+         return img
+
+     def build_network(self):
+         """Build the network and load the pretrained weight."""
+         # network
+         module = importlib.import_module(f"spikezoo.archs.{self.cfg.model_name}.{self.cfg.model_file_name}")
+         model_cls = getattr(module, self.cfg.model_cls_name)
+         model = model_cls(**self.cfg.model_params)
+         if self.cfg.load_state and self.cfg.require_params:
+             load_folder = os.path.dirname(os.path.abspath(__file__))
+             weight_path = os.path.join(load_folder, self.cfg.ckpt_path)
+             if not os.path.exists(weight_path):
+                 os.makedirs(os.path.dirname(weight_path), exist_ok=True)
+                 self.download_weight(weight_path)
+                 time.sleep(0.5)
+             model = load_network(weight_path, model)
+         return model
+
+     def save_network(self, save_path):
+         """Save the network."""
+         network = self.net
+         if isinstance(network, nn.DataParallel):
+             network = network.module
+         state_dict = network.state_dict()
+         for key, param in state_dict.items():
+             state_dict[key] = param.cpu()
+         torch.save(state_dict, save_path)
+
+     def download_weight(self, weight_path):
+         """Download the pretrained weight from the given url."""
+         url = self.cfg.base_url + os.path.basename(self.cfg.ckpt_path)
+         download_file(url, weight_path)
+
+     def crop_spike_length(self, spike):
+         """Crop the spike length."""
+         spike_length = spike.shape[1]
+         spike_mid = spike_length // 2
+         assert spike_length >= self.cfg.model_win_length, f"Spike input is not long enough, given {spike_length} frames < {self.cfg.model_win_length}."
+         # even length
+         if self.cfg.model_win_length == self.model_half_win_length * 2:
+             spike = spike[
+                 :,
+                 spike_mid - self.model_half_win_length : spike_mid + self.model_half_win_length,
+             ]
+         # odd length
+         else:
+             spike = spike[
+                 :,
+                 spike_mid - self.model_half_win_length : spike_mid + self.model_half_win_length + 1,
+             ]
+         if self.spike_size is None:
+             self.spike_size = (spike.shape[2], spike.shape[3])
+         return spike
+
+     def preprocess_spike(self, spike):
+         """Preprocess the spike (length size)."""
+         spike = self.crop_spike_length(spike)
+         return spike
+
+     def postprocess_img(self, image):
+         """Postprocess the image."""
+         return image
+
+     # -------------------- Training Part --------------------
+     def setup_training(self, pipeline_cfg):
+         """Setup training optimizer and loss."""
+         from spikezoo.pipeline import TrainPipelineConfig
+
+         cfg: TrainPipelineConfig = pipeline_cfg
+         self.optimizer = Adam(self.net.parameters(), lr=cfg.lr, betas=(0.9, 0.99), weight_decay=0)
+         self.scheduler = CosineAnnealingLR(self.optimizer, T_max=cfg.epochs, eta_min=0)
+         self.criterion = nn.L1Loss()
+
+     def get_outputs_dict(self, batch):
+         """Get the output dict for the given input batch. (Designed for the training mode considering possible auxiliary output.)"""
+         # data process
+         spike = batch["spike"]
+         # outputs
+         outputs = {}
+         recon_img = self(spike)
+         outputs["recon_img"] = recon_img
+         return outputs
+
+     def get_visual_dict(self, batch, outputs):
+         """Get the visual dict from the given input batch and outputs."""
+         visual_dict = {}
+         visual_dict["recon"] = outputs["recon_img"]
+         visual_dict["img"] = batch["img"]
+         return visual_dict
+
+     def get_loss_dict(self, outputs, batch):
+         """Get the loss dict from the given input batch and outputs."""
+         # data process
+         gt_img = batch["img"]
+         # recon image
+         recon_img = outputs["recon_img"]
+         # loss dict
+         loss_dict = {}
+         loss_dict["l1"] = self.criterion(recon_img, gt_img)
+         loss_values_dict = {k: v.item() for k, v in loss_dict.items()}
+         return loss_dict, loss_values_dict
+
+     def get_paired_imgs(self, batch, outputs):
+         recon_img = outputs["recon_img"]
+         img = batch["img"]
+         return recon_img, img
+
+     def optimize_parameters(self, loss_dict):
+         """Optimize the parameters from the loss_dict."""
+         loss = functools.reduce(torch.add, loss_dict.values())
+         self.optimizer.zero_grad()
+         loss.backward()
+         self.optimizer.step()
+
+     def update_learning_rate(self):
+         """Update the learning rate."""
+         self.scheduler.step()
+
+     def feed_to_device(self, batch):
+         batch = {k: v.to(self.device, non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()}
+         return batch
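Putting the training hooks together, a rough single-epoch sketch; it assumes the rest of the package (spikezoo.pipeline, the arch modules, the data) is available, and uses a SimpleNamespace as a stand-in for the TrainPipelineConfig expected by setup_training:

    from types import SimpleNamespace
    from spikezoo.datasets import build_dataset_name, build_dataloader
    from spikezoo.models import build_model_name

    model = build_model_name("base")
    dataset = build_dataset_name("reds_small", split="train")
    dataloader = build_dataloader(dataset)

    model.setup_training(SimpleNamespace(lr=1e-4, epochs=10))  # stand-in config exposing lr and epochs
    for batch in dataloader:
        batch = model.feed_to_device(batch)
        outputs = model.get_outputs_dict(batch)
        loss_dict, loss_values = model.get_loss_dict(outputs, batch)
        model.optimize_parameters(loss_dict)
    model.update_learning_rate()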