spikezoo 0.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
spikezoo/__init__.py ADDED
File without changes
File without changes
@@ -0,0 +1,68 @@
1
+ from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
2
+ from dataclasses import replace
3
+ import importlib, inspect
4
+ import os
5
+ import torch
6
+ from typing import Literal
7
+
8
# TODO: auto detect / register datasets.
# Every ``<name>_dataset.py`` module sitting next to this file is exposed as
# the selectable dataset ``<name>``.
files_list = os.listdir(os.path.dirname(os.path.abspath(__file__)))
dataset_list = [
    entry.replace("_dataset.py", "")
    for entry in files_list
    if entry.endswith("_dataset.py")
]
11
+
12
# TODO: register function
def _dataset_module_classes(dataset_name: str):
    """Import ``spikezoo.datasets.<dataset_name>_dataset`` and return the classes it defines, sorted by name."""
    assert dataset_name in dataset_list, f"Given dataset {dataset_name} not in our dataset list {dataset_list}."
    module = importlib.import_module("spikezoo.datasets." + dataset_name + "_dataset")
    # Keep only classes defined in the module itself (skip imported BaseDataset etc.).
    # inspect.getmembers already returns members sorted by name.
    return [obj for _, obj in inspect.getmembers(module, inspect.isclass) if obj.__module__ == module.__name__]


def _pick_class(classes, base, fallback_idx):
    """Return the first class in ``classes`` subclassing ``base``.

    Falls back to ``classes[fallback_idx]`` (the historical name-order rule) when none matches.
    """
    for cls in classes:
        if issubclass(cls, base):
            return cls
    return classes[fallback_idx]


def build_dataset_cfg(cfg: BaseDatasetConfig, split: Literal["train", "test"] = "test"):
    """Build the dataset from the given dataset config.

    Args:
        cfg: dataset configuration; a copy is made with ``split`` applied.
        split: "train" or "test"; selects ``spike_length_train`` / ``spike_length_test``.

    Returns:
        An instance of the dataset class defined in ``<cfg.dataset_name>_dataset``.
    """
    # Build a new cfg for the requested split (dataclasses.replace re-runs __post_init__).
    cfg = replace(cfg, split=split, spike_length=cfg.spike_length_train if split == "train" else cfg.spike_length_test)
    dataset_cls = _pick_class(_dataset_module_classes(cfg.dataset_name), BaseDataset, fallback_idx=0)
    return dataset_cls(cfg)


def build_dataset_name(dataset_name: str, split: Literal["train", "test"] = "test"):
    """Build the default dataset registered as ``dataset_name``.

    The module's own config class is instantiated with only ``split`` overridden.
    """
    classes = _dataset_module_classes(dataset_name)
    dataset_cls = _pick_class(classes, BaseDataset, fallback_idx=0)
    cfg_cls = _pick_class(classes, BaseDatasetConfig, fallback_idx=1)
    return dataset_cls(cfg_cls(split=split))
41
+
42
+
43
# TODO: modify according to basicsr
def build_dataloader(dataset: "BaseDataset", cfg=None):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader`` suited to its split.

    - train split with ``cfg``: ``batch_size=cfg.bs_train``, shuffled,
      ``num_workers=cfg.num_workers``, ``pin_memory=cfg.pin_memory``.
    - train split without ``cfg``, and the test split: ``batch_size=1``,
      no shuffling, one worker.

    Raises:
        ValueError: if ``dataset.cfg.split`` is neither "train" nor "test".
    """
    split = dataset.cfg.split
    if split == "train" and cfg is not None:
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg.bs_train,
            shuffle=True,
            num_workers=cfg.num_workers,
            pin_memory=cfg.pin_memory,
        )
    if split in ("train", "test"):
        return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
    # Fail loudly instead of silently returning None.
    raise ValueError(f"Unsupported dataset split {split!r}; expected 'train' or 'test'.")
54
+
55
+
56
+ # dataset_size_dict = {}
57
+ # for dataset in dataset_list:
58
+ # module_name = dataset + "_dataset"
59
+ # module_name = "spikezoo.datasets." + module_name
60
+ # module = importlib.import_module(module_name)
61
+ # classes = sorted([name for name, obj in inspect.getmembers(module) if inspect.isclass(obj)])
62
+ # dataset_cfg: BaseDatasetConfig = getattr(module, classes[1])
63
+ # dataset_size_dict[dataset] = (dataset_cfg.height, dataset_cfg.width)
64
+
65
+
66
+ # def get_dataset_size(name):
67
+ # assert name in dataset_list, f"Given dataset {name} not in our dataset list {dataset_list}."
68
+ # return dataset_size_dict[name]
@@ -0,0 +1,157 @@
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ from pathlib import Path
4
+ import cv2
5
+ import numpy as np
6
+ from spikezoo.utils.spike_utils import load_vidar_dat
7
+ import re
8
+ from dataclasses import dataclass, replace
9
+ from typing import Literal
10
+ import warnings
11
+ import torch
12
+ from tqdm import tqdm
13
+ from spikezoo.utils.data_utils import Augmentor
14
+
15
+
16
@dataclass
class BaseDatasetConfig:
    """Default configuration shared by all spikezoo datasets.

    ``spike_length`` is derived in ``__post_init__`` from ``spike_length_train``
    or ``spike_length_test`` according to ``split``; ``root_dir`` is normalised
    to ``Path`` and must exist.
    """

    # ------------- Not Recommended to Change -------------
    "Dataset name."
    dataset_name: str = "base"
    "Directory specifying location of data."
    root_dir: Path = Path(__file__).parent.parent / Path("data/base")
    "Image width."
    width: int = 400
    "Image height."
    height: int = 250
    "Spike paired with the image or not."
    with_img: bool = True
    "Dataset spike length for the train data."
    spike_length_train: int = -1
    "Dataset spike length for the test data."
    spike_length_test: int = -1
    "Dataset spike length for the instantiation dataclass."
    spike_length: int = -1
    "Dir name for the spike."
    spike_dir_name: str = "spike"
    "Dir name for the image."
    img_dir_name: str = "gt"

    # ------------- Config -------------
    "Dataset split: train/test. Default set as the 'test' for evaluation."
    split: Literal["train", "test"] = "test"
    "Use the data augmentation technique or not."
    use_aug: bool = False
    "Use cache mechanism."
    use_cache: bool = False
    "Crop size."
    crop_size: tuple = (-1, -1)

    def __post_init__(self):
        # spike_length always follows the active split, overriding any value passed in.
        self.spike_length = self.spike_length_train if self.split == "train" else self.spike_length_test
        # Accept str or any os.PathLike (e.g. PurePath) and normalise to Path.
        self.root_dir = Path(self.root_dir)
        # TODO: try download
        assert self.root_dir.exists(), f"No files found in {self.root_dir} for the specified dataset `{self.dataset_name}`."
56
+
57
# TODO: cache mechanism
class BaseDataset(Dataset):
    """Generic spike dataset returning ``{"spike": ..., "img": ...}`` samples.

    Subclasses customise discovery/loading by overriding ``prepare_data``,
    ``get_spike_files``, ``load_spike`` and ``get_img``.
    """

    def __init__(self, cfg: "BaseDatasetConfig"):
        super().__init__()
        self.cfg = cfg
        # Augmentation is only used on the training split; None otherwise.
        self.augmentor = Augmentor(self.cfg.crop_size) if self.cfg.use_aug and self.cfg.split == "train" else None
        self.prepare_data()
        if cfg.use_cache:
            self.cache_data()
        # img_list only exists for image-paired datasets; unpaired ones
        # (with_img=False) never set it, so don't touch it there.
        if self.cfg.with_img and len(self.img_list) != len(self.spike_list):
            warnings.warn("Lengths of the image list and the spike list should be equal.")

    def __len__(self):
        return len(self.spike_list)

    def __getitem__(self, idx: int):
        """Return the (optionally augmented) spike/image pair for ``idx``."""
        if self.cfg.use_cache:
            spike, img = self.cache_spkimg[idx]
            # Cached spikes are stored as uint8 (see cache_data); restore float32.
            spike = spike.to(torch.float32)
        else:
            spike = self.get_spike(idx)
            img = self.get_img(idx)

        if self.cfg.use_aug and self.cfg.split == "train":
            spike, img = self.augmentor(spike, img)

        return {"spike": spike, "img": img}

    # TODO: to be overridden
    def prepare_data(self):
        """Specify the spike and image files to be loaded."""
        self.spike_dir = self.cfg.root_dir / self.cfg.split / self.cfg.spike_dir_name
        self.spike_list = self.get_spike_files(self.spike_dir)
        if self.cfg.with_img:
            self.img_dir = self.cfg.root_dir / self.cfg.split / self.cfg.img_dir_name
            self.img_list = self.get_image_files(self.img_dir)

    # TODO: to be overridden
    def get_spike_files(self, path: Path):
        """Recognize spike files automatically (default: every ``.dat`` under ``path``)."""
        return sorted(path.glob("**/*.dat"))

    # TODO: to be overridden
    def load_spike(self, idx):
        """Load the spike stream stored at ``self.spike_list[idx]``."""
        spike_name = str(self.spike_list[idx])
        return load_vidar_dat(
            spike_name,
            height=self.cfg.height,
            width=self.cfg.width,
            out_type="float",
            out_format="tensor",
        )

    def get_spike(self, idx):
        """Load spike ``idx`` and keep the central ``cfg.spike_length`` frames along dim 0.

        ``cfg.spike_length == -1`` keeps the whole stream.
        """
        spike_length = self.cfg.spike_length
        spike = self.load_spike(idx)
        assert spike.shape[0] >= spike_length, f"Given spike length {spike.shape[0]} smaller than the required length {spike_length}"
        if spike_length == -1:
            return spike
        # One formula covers odd and even lengths: exactly spike_length frames
        # centred on the middle frame.
        start = spike.shape[0] // 2 - spike_length // 2
        return spike[start : start + spike_length]

    def get_image_files(self, path: Path):
        """Recognize image files (jpg/jpeg/png, case-insensitive) under ``path``."""
        files = [f for f in path.glob("**/*") if re.match(r".*\.(jpg|jpeg|png)$", f.name, re.IGNORECASE)]
        return sorted(files)

    # TODO: to be overridden
    def get_img(self, idx):
        """Return image ``idx`` as a (1, H, W) float32 grayscale tensor in [0, 1].

        Without paired images the temporal mean of the spike stream is returned instead.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        if not self.cfg.with_img:
            spike = self.get_spike(idx)
            return torch.mean(spike, dim=0, keepdim=True)
        img_name = str(self.img_list[idx])
        img = cv2.imread(img_name)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Failed to read image {img_name}.")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        img = (img / 255).astype(np.float32)
        return torch.from_numpy(img[None])

    def cache_data(self):
        """Pre-load every sample into ``self.cache_spkimg`` (spikes stored as uint8)."""
        self.cache_spkimg = []
        for idx in tqdm(range(len(self.spike_list)), desc="Caching data", unit="sample"):
            spike = self.get_spike(idx).to(torch.uint8)
            img = self.get_img(idx)
            self.cache_spkimg.append([spike, img])
@@ -0,0 +1,25 @@
1
+ from pathlib import Path
2
+ from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
3
+ from dataclasses import dataclass
4
+
5
+
6
# Configuration of the real-world recordings stored in data/recVidarReal2019.
# Unpaired data: there are no ground-truth images (with_img=False), and a
# spike length of -1 keeps the full stream for both splits.
@dataclass
class RealWorld_Config(BaseDatasetConfig):
    dataset_name: str = "realworld"
    root_dir: Path = Path(__file__).parent.parent / Path("data/recVidarReal2019")
    # Spatial resolution of the spike frames.
    width: int = 400
    height: int = 250
    with_img: bool = False
    spike_length_train: int = -1
    spike_length_test: int = -1
15
+
16
+
17
class RealWorld(BaseDataset):
    """Unpaired real-world spike dataset: every spike file under ``root_dir`` is one sample."""

    def __init__(self, cfg: BaseDatasetConfig):
        super().__init__(cfg)

    def prepare_data(self):
        # No train/test sub-folders here: scan the dataset root directly.
        root = self.cfg.root_dir
        self.spike_dir = root
        self.spike_list = self.get_spike_files(root)
24
+
25
+
@@ -0,0 +1,27 @@
1
+ from torch.utils.data import Dataset
2
+ from pathlib import Path
3
+ from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
4
+ from dataclasses import dataclass
5
+ import re
6
+
7
# Configuration of the paired spike / ground-truth dataset in data/REDS_Small.
# Training samples use 41-frame spike windows, test samples 301-frame windows.
@dataclass
class REDS_Small_Config(BaseDatasetConfig):
    dataset_name: str = "reds_small"
    root_dir: Path = Path(__file__).parent.parent / Path("data/REDS_Small")
    # Spatial resolution of the spike frames / images.
    width: int = 400
    height: int = 250
    with_img: bool = True
    spike_length_train: int = 41
    spike_length_test: int = 301
    # Sub-directory names under <root_dir>/<split>/.
    spike_dir_name: str = "spike"
    img_dir_name: str = "gt"
18
+
19
class REDS_Small(BaseDataset):
    """REDS_Small dataset: spike streams paired with same-named PNG ground truths."""

    def __init__(self, cfg: BaseDatasetConfig):
        super().__init__(cfg)

    def prepare_data(self):
        super().prepare_data()
        # On the train split, pair each spike file with ``<img_dir>/<stem>.png``.
        # Only the final suffix is swapped (a blanket str.replace would also
        # rewrite ".dat" occurring elsewhere in the file name). img_dir only
        # exists when images are paired.
        if self.cfg.split == "train" and self.cfg.with_img:
            self.img_list = [self.img_dir / spike_path.with_suffix(".png").name for spike_path in self.spike_list]
27
+
@@ -0,0 +1,37 @@
1
+ from pathlib import Path
2
+ from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
3
+ from dataclasses import dataclass
4
+ import cv2
5
+ import torch
6
+ import numpy as np
7
+
8
# Configuration of the paired dataset in data/dataset, whose split folders hold
# "spike_data" (spike streams) and "sharp_data" (ground-truth images).
# A spike length of -1 keeps the full stream.
@dataclass
class SZData_Config(BaseDatasetConfig):
    dataset_name: str = "szdata"
    root_dir: Path = Path(__file__).parent.parent / Path("data/dataset")
    # Spatial resolution of the spike frames / images.
    width: int = 400
    height: int = 250
    with_img: bool = True
    spike_length_train: int = -1
    spike_length_test: int = -1
    spike_dir_name: str = "spike_data"
    img_dir_name: str = "sharp_data"
19
+
20
class SZData(BaseDataset):
    """Paired dataset whose image for ``<spike_dir>/<rel>.dat`` is ``<img_dir>/<rel>.png``."""

    def __init__(self, cfg: BaseDatasetConfig):
        super().__init__(cfg)

    def get_img(self, idx):
        """Return image ``idx`` as a (1, H, W) float32 grayscale tensor in [0, 1].

        Without paired images the temporal mean of the spike stream is returned instead.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        if not self.cfg.with_img:
            spike = self.get_spike(idx)
            return torch.mean(spike, dim=0, keepdim=True)
        spike_path = Path(self.spike_list[idx])
        # Mirror the spike file's location relative to spike_dir under img_dir.
        # Only that directory is swapped, unlike a str.replace over the whole
        # absolute path (which would also rewrite matching parent directories).
        img_name = str(self.img_dir / spike_path.relative_to(self.spike_dir).with_suffix(".png"))
        img = cv2.imread(img_name)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Failed to read image {img_name}.")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        img = (img / 255).astype(np.float32)
        return torch.from_numpy(img[None])
37
+
@@ -0,0 +1,38 @@
1
+ from pathlib import Path
2
+ from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
3
+ from dataclasses import dataclass
4
+ import numpy as np
5
+ import torch
6
+
7
# Configuration of the U-CALTECH spike data in data/U-CALTECH.
# Unpaired (with_img=False) 224x224 frames; both splits use 200-frame windows.
@dataclass
class UHSR_Config(BaseDatasetConfig):
    dataset_name: str = "uhsr"
    root_dir: Path = Path(__file__).parent.parent / Path("data/U-CALTECH")
    width: int = 224
    height: int = 224
    with_img: bool = False
    spike_length_train: int = 200
    spike_length_test: int = 200
    spike_dir_name: str = "spike"
    # Unused: there are no ground-truth images.
    img_dir_name: str = ""
18
+
19
+
20
class UHSR(BaseDataset):
    """Unpaired spike dataset stored as ``.npz`` files (array key ``"spk"``) under ``<root_dir>/<split>``."""

    def __init__(self, cfg: BaseDatasetConfig):
        super().__init__(cfg)

    def prepare_data(self):
        # The config field is ``split``; ``dataset_split`` does not exist and
        # raised AttributeError.
        self.spike_dir = self.cfg.root_dir / self.cfg.split
        self.spike_list = self.get_spike_files(self.spike_dir)

    def get_spike_files(self, path: Path):
        """Return every ``.npz`` file under ``path``, sorted."""
        return sorted(path.glob("**/*.npz"))

    def load_spike(self, idx):
        """Load ``spk`` from the ``.npz`` file and centre-crop it to ``cfg.height`` x ``cfg.width``."""
        spike_name = str(self.spike_list[idx])
        # np.load returns an NpzFile holding an open file handle; close it
        # once the array has been copied out.
        with np.load(spike_name) as data:
            spike = data["spk"].astype(np.float32)
        spike = torch.from_numpy(spike)
        # Centre crop; for 250-pixel sides and the 224 default this is the
        # former hard-coded 13:237 window.
        top = max(0, (spike.shape[-2] - self.cfg.height) // 2)
        left = max(0, (spike.shape[-1] - self.cfg.width) // 2)
        return spike[:, top : top + self.cfg.height, left : left + self.cfg.width]
@@ -0,0 +1,96 @@
1
+ from skimage import metrics
2
+ import torch
3
+ import torch.hub
4
+ from lpips.lpips import LPIPS
5
+ import os
6
+ import os
7
+ import pyiqa
8
+ import numpy as np
9
+ from torchvision import transforms
10
+ import torch.nn.functional as F
11
+
12
# TODO: express these name groups with a union/Literal type.
# Full-reference metrics compare a reconstruction against a ground-truth image.
metric_pair_names = ["psnr", "ssim", "lpips", "mse"]
# No-reference metrics score a single image (created through pyiqa on demand).
metric_single_names = ["niqe", "brisque", "piqe"]
metric_all_names = [*metric_pair_names, *metric_single_names]

# Cache of instantiated no-reference metrics, filled lazily by cal_metric_single.
metric_single_list = {}

# Full-reference metric callables; the LPIPS network is created on first use.
metric_pair_list = dict(
    mse=metrics.mean_squared_error,
    ssim=metrics.structural_similarity,
    psnr=metrics.peak_signal_noise_ratio,
    lpips=None,
)
25
+
26
+
27
def cal_metric_single(img: torch.Tensor, metric_name="niqe"):
    """Compute a no-reference image-quality score with pyiqa.

    The pyiqa metric is created on first use and cached in ``metric_single_list``.

    Args:
        img: tensor shaped (H, W), (C, H, W) or (B, C, H, W); missing leading
            dims are added. NOTE(review): the value range pyiqa expects is not
            checked here — presumably [0, 1].
        metric_name: any model name known to ``pyiqa.list_models()``.

    Returns:
        The score as a Python float (``.item()``, so the metric must return a single value).

    Raises:
        RuntimeError: if pyiqa does not provide ``metric_name``.
    """
    if metric_name not in metric_single_list:
        if metric_name not in pyiqa.list_models():
            raise RuntimeError(f"Metric {metric_name} not recognized by the IQA lib.")
        # Use CUDA when present, otherwise CPU (hard-coding CUDA crashed on CPU-only hosts).
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        metric_single_list[metric_name] = pyiqa.create_metric(metric_name, device=device)

    # Promote to (B, C, H, W).
    if img.dim() == 3:
        img = img[None]
    elif img.dim() == 2:
        img = img[None, None]

    # liqe_mix: resize so the shorter edge is 384 px, keeping the aspect ratio.
    if metric_name == "liqe_mix":
        short_edge = 384
        h, w = img.shape[2], img.shape[3]
        if h < w:
            new_h, new_w = short_edge, int(w * short_edge / h)
        else:
            new_h, new_w = int(h * short_edge / w), short_edge
        img = F.interpolate(img, size=(new_h, new_w), mode="bilinear", align_corners=False)
    return metric_single_list[metric_name](img).item()
50
+
51
+
52
def cal_metric_pair(im1t: torch.Tensor, im2t: torch.Tensor, metric_name="mse"):
    """Compute a full-reference metric between two images, averaged over the batch.

    Args:
        im1t, im2t: tensors with values in [0, 1], shaped (H, W), (C, H, W),
            (B, C, H, W), or channel-last with 3 channels. Both are mapped to
            [-1, 1] before evaluation, so MSE is reported on that scale and
            PSNR/SSIM use a data range of 2.
        metric_name: one of the keys of ``metric_pair_list``.

    Returns:
        The mean metric value over the batch (float).

    Raises:
        RuntimeError: for an unknown ``metric_name``.
    """
    if metric_name not in metric_pair_list:
        raise RuntimeError(f"Metric {metric_name} not recognized")
    if metric_name == "lpips":
        if metric_pair_list["lpips"] is None:
            metric_pair_list["lpips"] = LPIPS()
        # Keep the cached network on the inputs' device (it used to stay on
        # whichever device the first call happened to use).
        metric_pair_list["lpips"] = metric_pair_list["lpips"].to(im1t.device)
    metric_method = metric_pair_list[metric_name]

    # Convert from [0, 1] to [-1, 1].
    im1t = (im1t * 2 - 1).clamp(-1, 1)
    im2t = (im2t * 2 - 1).clamp(-1, 1)

    # [c,h,w] -> [1,c,h,w]; [h,w] -> [1,1,h,w]
    if im1t.dim() == 3:
        im1t, im2t = im1t[None], im2t[None]
    elif im1t.dim() == 2:
        im1t, im2t = im1t[None, None], im2t[None, None]

    # [1,h,w,3] -> [1,3,h,w]
    if im1t.shape[-1] == 3:
        im1t = im1t.permute(0, 3, 1, 2)
        im2t = im2t.permute(0, 3, 1, 2)

    # numpy arrays are channel-last [b,h,w,c]; tensors stay [b,c,h,w] for LPIPS.
    im1 = im1t.permute(0, 2, 3, 1).detach().cpu().numpy()
    im2 = im2t.permute(0, 2, 3, 1).detach().cpu().numpy()

    values = []
    for i in range(im1.shape[0]):
        if metric_name == "mse":
            value = metric_method(im1[i], im2[i])
        elif metric_name == "psnr":
            # Explicit range: skimage otherwise infers 1 (not 2) whenever the
            # reference happens to contain no negative values.
            value = metric_method(im1[i], im2[i], data_range=2)
        elif metric_name == "ssim":
            value = metric_method(im1[i], im2[i], channel_axis=-1, data_range=2)
        else:  # lpips
            with torch.no_grad():
                value = metric_method(im1t[i : i + 1], im2t[i : i + 1])[0, 0, 0, 0]
            value = value.detach().cpu().float().item()
        values.append(value)

    return sum(values) / len(values)
@@ -0,0 +1,37 @@
1
+ import importlib
2
+ import inspect
3
+ from spikezoo.models.base_model import BaseModel,BaseModelConfig
4
+ import os
5
+ from pathlib import Path
6
+
7
current_file_path = Path(__file__).parent
files_list = os.listdir(os.path.dirname(os.path.abspath(__file__)))
# Strip the whole "_model.py" suffix instead of splitting at the first "_",
# so model names containing underscores map back to their module name.
model_list = [file[: -len("_model.py")] for file in files_list if file.endswith("_model.py")]
10
+
11
# TODO: register function
def _model_module_classes(model_name: str):
    """Import ``spikezoo.models.<model_name>_model`` and return the classes it defines, sorted by name."""
    assert model_name in model_list, f"Given model {model_name} not in our model zoo {model_list}."
    module = importlib.import_module("spikezoo.models." + model_name + "_model")
    # Keep only classes defined in the module itself (skip imported BaseModel etc.).
    # inspect.getmembers already returns members sorted by name.
    return [obj for _, obj in inspect.getmembers(module, inspect.isclass) if obj.__module__ == module.__name__]


def _pick_class(classes, base, fallback_idx):
    """Return the first class in ``classes`` subclassing ``base``.

    Falls back to ``classes[fallback_idx]`` (the historical name-order rule) when none matches.
    """
    for cls in classes:
        if issubclass(cls, base):
            return cls
    return classes[fallback_idx]


def build_model_cfg(cfg: BaseModelConfig):
    """Build the model described by ``cfg`` (module ``<cfg.model_name>_model``)."""
    model_cls = _pick_class(_model_module_classes(cfg.model_name), BaseModel, fallback_idx=0)
    return model_cls(cfg)


def build_model_name(model_name: str):
    """Build the model registered as ``model_name`` with its default config."""
    classes = _model_module_classes(model_name)
    model_cls = _pick_class(classes, BaseModel, fallback_idx=0)
    cfg_cls = _pick_class(classes, BaseModelConfig, fallback_idx=1)
    return model_cls(cfg_cls())
@@ -0,0 +1,177 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import importlib
4
+ import inspect
5
+ from dataclasses import dataclass, field
6
+ from spikezoo.utils import load_network, download_file
7
+ import os
8
+ import time
9
+ from typing import Dict
10
+ from torch.optim import Adam
11
+ from torch.optim.lr_scheduler import CosineAnnealingLR
12
+ import functools
13
+
14
+
15
# TODO: private design
@dataclass
class BaseModelConfig:
    """Default configuration shared by all spikezoo models."""

    # default params for BaseModel
    "Registered model name."
    model_name: str = "base"
    "File name of the specified model."
    model_file_name: str = "nets"
    "Class name of the specified model in spikezoo/archs/base/{model_file_name}.py."
    model_cls_name: str = "BaseNet"
    "Spike input length for the specified model."
    model_win_length: int = 41
    "Model require model parameters or not."
    require_params: bool = False
    "Model stored path."
    ckpt_path: str = ""
    "Load pretrained weights or not."
    load_state: bool = True
    "Base url for storing pretrained models."
    base_url: str = "https://github.com/chenkang455/Spike-Zoo/releases/download/v0.1/"
    "Multi-GPU setting."
    multi_gpu: bool = False
    "Model parameters."
    model_params: dict = field(default_factory=dict)
39
+
40
+
41
class BaseModel(nn.Module):
    """Wrapper around a spike-to-image network from ``spikezoo.archs``.

    The wrapped network lives in ``self.net``. Subclasses customise
    ``preprocess_spike`` / ``postprocess_img`` and the training hooks.
    """

    def __init__(self, cfg: "BaseModelConfig"):
        super().__init__()
        self.cfg = cfg
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.net = self.build_network().to(self.device)
        if cfg.multi_gpu:
            self.net = nn.DataParallel(self.net)
        # (H, W) of the spike input, recorded by the first crop_spike_length call.
        self.spike_size = None
        self.model_half_win_length: int = cfg.model_win_length // 2

    # ! Might lead to low speed training on the BSF.
    def forward(self, spike):
        """Reconstruct an image from ``spike``: preprocess -> net -> postprocess."""
        spike = self.preprocess_spike(spike)
        img = self.net(spike)
        return self.postprocess_img(img)

    def build_network(self):
        """Instantiate ``cfg.model_cls_name`` from ``spikezoo.archs.<model_name>.<model_file_name>``.

        When both ``cfg.load_state`` and ``cfg.require_params`` are set, the
        weights at ``cfg.ckpt_path`` (relative to this file's directory) are
        loaded, after downloading them from ``cfg.base_url`` if missing.
        """
        module = importlib.import_module(f"spikezoo.archs.{self.cfg.model_name}.{self.cfg.model_file_name}")
        model_cls = getattr(module, self.cfg.model_cls_name)
        model = model_cls(**self.cfg.model_params)
        if self.cfg.load_state and self.cfg.require_params:
            load_folder = os.path.dirname(os.path.abspath(__file__))
            weight_path = os.path.join(load_folder, self.cfg.ckpt_path)
            if not os.path.exists(weight_path):
                os.makedirs(os.path.dirname(weight_path), exist_ok=True)
                self.download_weight(weight_path)
                # NOTE(review): short pause after downloading; purpose not visible here.
                time.sleep(0.5)
            model = load_network(weight_path, model)
        return model

    def save_network(self, save_path):
        """Save the (DataParallel-unwrapped) network's state dict with all tensors on CPU."""
        network = self.net.module if isinstance(self.net, nn.DataParallel) else self.net
        state_dict = network.state_dict()
        for key, param in state_dict.items():
            state_dict[key] = param.cpu()
        torch.save(state_dict, save_path)

    def download_weight(self, weight_path):
        """Download ``<cfg.base_url><basename(cfg.ckpt_path)>`` to ``weight_path``."""
        url = self.cfg.base_url + os.path.basename(self.cfg.ckpt_path)
        download_file(url, weight_path)

    def crop_spike_length(self, spike):
        """Keep the central ``cfg.model_win_length`` frames along dim 1 of ``spike``."""
        spike_length = spike.shape[1]
        assert spike_length >= self.cfg.model_win_length, f"Spike input is not long enough, given {spike_length} frames < {self.cfg.model_win_length}."
        # One formula yields exactly model_win_length frames for odd and even windows.
        start = spike_length // 2 - self.model_half_win_length
        spike = spike[:, start : start + self.cfg.model_win_length]
        if self.spike_size is None:
            self.spike_size = (spike.shape[2], spike.shape[3])
        return spike

    def preprocess_spike(self, spike):
        """Preprocess the spike (length cropping)."""
        return self.crop_spike_length(spike)

    def postprocess_img(self, image):
        """Postprocess the image (identity by default)."""
        return image

    # -------------------- Training Part --------------------
    def setup_training(self, pipeline_cfg):
        """Create the Adam optimizer, cosine LR schedule and L1 loss from the pipeline config."""
        from spikezoo.pipeline import TrainPipelineConfig

        cfg: TrainPipelineConfig = pipeline_cfg
        self.optimizer = Adam(self.net.parameters(), lr=cfg.lr, betas=(0.9, 0.99), weight_decay=0)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=cfg.epochs, eta_min=0)
        self.criterion = nn.L1Loss()

    def get_outputs_dict(self, batch):
        """Run the model on ``batch["spike"]`` and return ``{"recon_img": ...}`` (room for auxiliary outputs)."""
        spike = batch["spike"]
        return {"recon_img": self(spike)}

    def get_visual_dict(self, batch, outputs):
        """Collect the reconstructed and ground-truth images for visualisation."""
        return {"recon": outputs["recon_img"], "img": batch["img"]}

    def get_loss_dict(self, outputs, batch):
        """Return ``(loss_dict, loss_values_dict)``: L1 loss tensors and their float values."""
        gt_img = batch["img"]
        recon_img = outputs["recon_img"]
        loss_dict = {"l1": self.criterion(recon_img, gt_img)}
        loss_values_dict = {k: v.item() for k, v in loss_dict.items()}
        return loss_dict, loss_values_dict

    def get_paired_imgs(self, batch, outputs):
        """Return ``(reconstruction, ground truth)`` for metric computation."""
        return outputs["recon_img"], batch["img"]

    def optimize_parameters(self, loss_dict):
        """Back-propagate the sum of all losses in ``loss_dict`` and take one optimizer step."""
        loss = functools.reduce(torch.add, loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_learning_rate(self):
        """Advance the LR scheduler by one step."""
        self.scheduler.step()

    def feed_to_device(self, batch):
        """Move every tensor value of ``batch`` to ``self.device``; other values pass through."""
        return {k: v.to(self.device, non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()}