spikezoo 0.2.2__py3-none-any.whl → 0.2.3.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (86) hide show
  1. spikezoo/__init__.py +23 -7
  2. spikezoo/archs/bsf/models/bsf/bsf.py +37 -25
  3. spikezoo/archs/bsf/models/bsf/rep.py +2 -2
  4. spikezoo/archs/spk2imgnet/nets.py +1 -1
  5. spikezoo/archs/ssir/models/networks.py +1 -1
  6. spikezoo/archs/ssml/model.py +9 -5
  7. spikezoo/archs/stir/metrics/losses.py +1 -1
  8. spikezoo/archs/stir/models/networks_STIR.py +16 -9
  9. spikezoo/archs/tfi/nets.py +1 -1
  10. spikezoo/archs/tfp/nets.py +1 -1
  11. spikezoo/archs/wgse/dwtnets.py +6 -6
  12. spikezoo/datasets/__init__.py +11 -9
  13. spikezoo/datasets/base_dataset.py +10 -3
  14. spikezoo/datasets/realworld_dataset.py +1 -3
  15. spikezoo/datasets/{reds_small_dataset.py → reds_base_dataset.py} +9 -8
  16. spikezoo/datasets/reds_ssir_dataset.py +181 -0
  17. spikezoo/datasets/szdata_dataset.py +5 -15
  18. spikezoo/datasets/uhsr_dataset.py +4 -3
  19. spikezoo/models/__init__.py +8 -6
  20. spikezoo/models/base_model.py +120 -64
  21. spikezoo/models/bsf_model.py +11 -3
  22. spikezoo/models/spcsnet_model.py +19 -0
  23. spikezoo/models/spikeclip_model.py +4 -3
  24. spikezoo/models/spk2imgnet_model.py +9 -15
  25. spikezoo/models/ssir_model.py +4 -6
  26. spikezoo/models/ssml_model.py +44 -2
  27. spikezoo/models/stir_model.py +26 -5
  28. spikezoo/models/tfi_model.py +3 -1
  29. spikezoo/models/tfp_model.py +4 -2
  30. spikezoo/models/wgse_model.py +8 -14
  31. spikezoo/pipeline/base_pipeline.py +79 -55
  32. spikezoo/pipeline/ensemble_pipeline.py +10 -9
  33. spikezoo/pipeline/train_cfgs.py +89 -0
  34. spikezoo/pipeline/train_pipeline.py +129 -30
  35. spikezoo/utils/optimizer_utils.py +22 -0
  36. spikezoo/utils/other_utils.py +31 -6
  37. spikezoo/utils/scheduler_utils.py +25 -0
  38. spikezoo/utils/spike_utils.py +61 -29
  39. spikezoo-0.2.3.2.dist-info/METADATA +263 -0
  40. {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.2.dist-info}/RECORD +43 -80
  41. spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
  42. spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
  43. spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
  44. spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
  45. spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
  46. spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
  47. spikezoo/archs/spikeformer/CheckPoints/readme +0 -1
  48. spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +0 -60
  49. spikezoo/archs/spikeformer/DataProcess/DataLoader.py +0 -115
  50. spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +0 -39
  51. spikezoo/archs/spikeformer/EvalResults/readme +0 -1
  52. spikezoo/archs/spikeformer/LICENSE +0 -21
  53. spikezoo/archs/spikeformer/Metrics/Metrics.py +0 -50
  54. spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
  55. spikezoo/archs/spikeformer/Model/Loss.py +0 -89
  56. spikezoo/archs/spikeformer/Model/SpikeFormer.py +0 -230
  57. spikezoo/archs/spikeformer/Model/__init__.py +0 -0
  58. spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
  59. spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
  60. spikezoo/archs/spikeformer/README.md +0 -30
  61. spikezoo/archs/spikeformer/evaluate.py +0 -87
  62. spikezoo/archs/spikeformer/recon_real_data.py +0 -97
  63. spikezoo/archs/spikeformer/requirements.yml +0 -95
  64. spikezoo/archs/spikeformer/train.py +0 -173
  65. spikezoo/archs/spikeformer/utils.py +0 -22
  66. spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
  67. spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
  68. spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
  69. spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
  70. spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
  71. spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
  72. spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
  73. spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
  74. spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
  75. spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
  76. spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
  77. spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
  78. spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
  79. spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
  80. spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
  81. spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
  82. spikezoo/models/spikeformer_model.py +0 -50
  83. spikezoo-0.2.2.dist-info/METADATA +0 -196
  84. {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.2.dist-info}/LICENSE.txt +0 -0
  85. {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.2.dist-info}/WHEEL +0 -0
  86. {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.2.dist-info}/top_level.txt +0 -0
spikezoo/__init__.py CHANGED
@@ -1,13 +1,29 @@
1
- from .utils.spike_utils import load_vidar_dat
1
+ from .utils.spike_utils import *
2
2
  from .models import model_list
3
3
  from .datasets import dataset_list
4
4
  from .metrics import metric_all_names
5
5
 
6
- def get_datasets():
7
- return dataset_list
6
+ # METHOD NAME DEFINITION
7
+ METHODS = model_list
8
+ class METHOD:
9
+ BASE = "base"
10
+ TFP = "tfp"
11
+ TFI = "tfi"
12
+ SPK2IMGNET = "spk2imgnet"
13
+ WGSE = "wgse"
14
+ SSML = "ssml"
15
+ BSF = "bsf"
16
+ STIR = "stir"
17
+ SSIR = "ssir"
18
+ SPIKECLIP = "spikeclip"
8
19
 
9
- def get_models():
10
- return model_list
20
+ # DATASET NAME DEFINITION
21
+ DATASETS = dataset_list
22
+ class DATASET:
23
+ BASE = "base"
24
+ REDS_BASE = "reds_base"
25
+ REALWORLD = "realworld"
26
+ UHSR = "uhsr"
11
27
 
12
- def get_metrics():
13
- return metric_all_names
28
+ # METRIC NAME DEFINITION
29
+ METRICS = metric_all_names
@@ -8,18 +8,18 @@ from .align import Multi_Granularity_Align
8
8
  class BasicModel(nn.Module):
9
9
  def __init__(self):
10
10
  super().__init__()
11
-
11
+
12
12
  ####################################################################################
13
13
  ## Tools functions for neural networks
14
14
  def weight_parameters(self):
15
- return [param for name, param in self.named_parameters() if 'weight' in name]
15
+ return [param for name, param in self.named_parameters() if "weight" in name]
16
16
 
17
17
  def bias_parameters(self):
18
- return [param for name, param in self.named_parameters() if 'bias' in name]
18
+ return [param for name, param in self.named_parameters() if "bias" in name]
19
19
 
20
20
  def num_parameters(self):
21
21
  return sum([p.data.nelement() if p.requires_grad else 0 for p in self.parameters()])
22
-
22
+
23
23
  def init_weights(self):
24
24
  for layer in self.named_modules():
25
25
  if isinstance(layer, nn.Conv2d):
@@ -33,12 +33,21 @@ class BasicModel(nn.Module):
33
33
  nn.init.constant_(layer.bias, 0)
34
34
 
35
35
 
36
- def split_and_b_cat(x):
37
- x0 = x[:, 10-10:10+10+1].clone()
38
- x1 = x[:, 20-10:20+10+1].clone()
39
- x2 = x[:, 30-10:30+10+1].clone()
40
- x3 = x[:, 40-10:40+10+1].clone()
41
- x4 = x[:, 50-10:50+10+1].clone()
36
+ from typing import Literal
37
+
38
+
39
+ def split_and_b_cat(x, spike_dim: Literal[41, 61] = 61):
40
+ if spike_dim == 61:
41
+ win_r = 10
42
+ win_step = 10
43
+ elif spike_dim == 41:
44
+ win_r = 6
45
+ win_step = 7
46
+ x0 = x[:, 0 : 2 * win_r + 1, :, :].clone()
47
+ x1 = x[:, win_step : win_step + 2 * win_r + 1, :, :].clone()
48
+ x2 = x[:, 2 * win_step : 2 * win_step + 2 * win_r + 1, :, :].clone()
49
+ x3 = x[:, 3 * win_step : 3 * win_step + 2 * win_r + 1, :, :].clone()
50
+ x4 = x[:, 4 * win_step : 4 * win_step + 2 * win_r + 1, :, :].clone()
42
51
  return torch.cat([x0, x1, x2, x3, x4], dim=0)
43
52
 
44
53
 
@@ -61,39 +70,42 @@ class Encoder(nn.Module):
61
70
  x = self.act(conv(x) + x)
62
71
  return x
63
72
 
73
+
64
74
  ##########################################################################
65
75
  class BSF(BasicModel):
66
- def __init__(self, act=nn.ReLU()):
76
+ def __init__(self, spike_dim=61, act=nn.ReLU()):
67
77
  super().__init__()
78
+ self.spike_dim = spike_dim
68
79
  self.offset_groups = 4
69
80
  self.corr_max_disp = 3
70
-
71
- self.rep = MODF(base_dim=64, act=act)
72
-
81
+ if spike_dim == 61:
82
+ self.rep = MODF(in_dim=21,base_dim=64, act=act)
83
+ elif spike_dim == 41:
84
+ self.rep = MODF(in_dim=13,base_dim=64, act=act)
73
85
  self.encoder = Encoder(base_dim=64, layers=4, act=act)
74
86
 
75
87
  self.align = Multi_Granularity_Align(base_dim=64, groups=self.offset_groups, act=act, sc=3)
76
88
 
77
89
  self.recons = nn.Sequential(
78
- nn.Conv2d(64*5, 64*3, kernel_size=3, padding=1),
90
+ nn.Conv2d(64 * 5, 64 * 3, kernel_size=3, padding=1),
79
91
  act,
80
- nn.Conv2d(64*3, 64, kernel_size=3, padding=1),
92
+ nn.Conv2d(64 * 3, 64, kernel_size=3, padding=1),
81
93
  act,
82
94
  nn.Conv2d(64, 1, kernel_size=3, padding=1),
83
95
  )
84
96
 
85
97
  def forward(self, input_dict):
86
- dsft_dict = input_dict['dsft_dict']
87
- dsft11 = dsft_dict['dsft11']
88
- dsft12 = dsft_dict['dsft12']
89
- dsft21 = dsft_dict['dsft21']
90
- dsft22 = dsft_dict['dsft22']
98
+ dsft_dict = input_dict["dsft_dict"]
99
+ dsft11 = dsft_dict["dsft11"]
100
+ dsft12 = dsft_dict["dsft12"]
101
+ dsft21 = dsft_dict["dsft21"]
102
+ dsft22 = dsft_dict["dsft22"]
91
103
 
92
104
  dsft_b_cat = {
93
- 'dsft11': split_and_b_cat(dsft11),
94
- 'dsft12': split_and_b_cat(dsft12),
95
- 'dsft21': split_and_b_cat(dsft21),
96
- 'dsft22': split_and_b_cat(dsft22),
105
+ "dsft11": split_and_b_cat(dsft11, self.spike_dim),
106
+ "dsft12": split_and_b_cat(dsft12, self.spike_dim),
107
+ "dsft21": split_and_b_cat(dsft21, self.spike_dim),
108
+ "dsft22": split_and_b_cat(dsft22, self.spike_dim),
97
109
  }
98
110
 
99
111
  feat_b_cat = self.rep(dsft_b_cat)
@@ -2,11 +2,11 @@ import torch
2
2
  import torch.nn as nn
3
3
 
4
4
  class MODF(nn.Module):
5
- def __init__(self, base_dim=64, act=nn.ReLU()):
5
+ def __init__(self, in_dim = 21, base_dim=64, act=nn.ReLU()):
6
6
  super().__init__()
7
7
  self.base_dim = base_dim
8
8
 
9
- self.conv1 = self._make_layer(input_dim=21, hidden_dim=self.base_dim, output_dim=self.base_dim, act=act)
9
+ self.conv1 = self._make_layer(input_dim=in_dim, hidden_dim=self.base_dim, output_dim=self.base_dim, act=act)
10
10
  self.conv_for_others = nn.ModuleList([
11
11
  self._make_layer(input_dim=self.base_dim, hidden_dim=self.base_dim, output_dim=self.base_dim, act=act) for ii in range(3)
12
12
  ])
@@ -167,7 +167,7 @@ class FusionMaskV1(nn.Module):
167
167
 
168
168
  # current best model
169
169
  class SpikeNet(nn.Module):
170
- def __init__(self, in_channels, features, out_channels, win_r, win_step):
170
+ def __init__(self, in_channels = 13, features = 64, out_channels = 1, win_r = 6, win_step = 7):
171
171
  super(SpikeNet, self).__init__()
172
172
  self.extractor = FeatureExtractor(
173
173
  in_channels=in_channels,
@@ -56,6 +56,6 @@ class SSIR(BasicModel):
56
56
  out3 = self.pred3(x7)
57
57
 
58
58
  if self.training:
59
- return [out3]
59
+ return out3
60
60
  else:
61
61
  return out3
@@ -272,18 +272,22 @@ class BSN(nn.Module):
272
272
  diff = W - H
273
273
  x0 = x0[:, :, (diff // 2):(diff // 2 + H), 0:W]
274
274
 
275
- return x0,tfi_label,tfp_label
275
+ return x0
276
276
 
277
277
  class DoubleNet(nn.Module):
278
278
  def __init__(self):
279
279
  super().__init__()
280
280
  self.nbsn = BSN(n_channels=41, n_output=1,blind=False)
281
- # self.bsn = BSN(n_channels=41, n_output=1,blind=True)
281
+ self.bsn = BSN(n_channels=41, n_output=1,blind=True)
282
282
 
283
283
  def forward(self, x):
284
- out1,_,_ = self.nbsn(x)
285
-
286
- return out1
284
+ if self.training:
285
+ bsn_pred = self.bsn(x)
286
+ nbsn_pred = self.nbsn(x)
287
+ return bsn_pred,nbsn_pred
288
+ else:
289
+ nbsn_pred = self.nbsn(x)
290
+ return nbsn_pred
287
291
 
288
292
  if __name__ == '__main__':
289
293
  a=DoubleNet().cuda()
@@ -6,7 +6,7 @@ import torch.nn.functional as F
6
6
 
7
7
  import math
8
8
 
9
- from package_core.losses import *
9
+ from ..package_core.package_core.losses import *
10
10
 
11
11
  def compute_l1_loss(img_list, gt):
12
12
  l1_loss = 0.0
@@ -292,16 +292,21 @@ class STIRDecorder(nn.Module):#second and third levels
292
292
 
293
293
  ##############################Our Model####################################
294
294
  class STIR(BasicModel):
295
- def __init__(self, hidd_chs=8, win_r=6, win_step=7):
295
+ def __init__(self, spike_dim = 61,hidd_chs=8, win_r=6, win_step=7):
296
296
  super().__init__()
297
297
 
298
298
  self.init_chs = [16, 24, 32, 64, 96]
299
299
  self.hidd_chs = hidd_chs
300
+ self.spike_dim = spike_dim
300
301
  self.attn_num_splits = 1
301
302
 
302
303
  self.N_group = 3
303
-
304
- dim_tfp = 16
304
+ if spike_dim == 61:
305
+ self.resnet = ResidualBlock(in_channles=21, num_channles=11, use_1x1conv=True)
306
+ dim_tfp = 16 # 5 + num_channels
307
+ elif spike_dim == 41:
308
+ self.resnet = ResidualBlock(in_channles=15, num_channles=11, use_1x1conv=True)
309
+ dim_tfp = 15 # 4 + num_channels
305
310
  self.encoder = ImageEncoder(in_chs=dim_tfp, init_chs=self.init_chs)
306
311
 
307
312
  self.transformer = CrossTransformerBlock(dim=self.init_chs[-1], num_heads=4, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias')
@@ -314,14 +319,16 @@ class STIR(BasicModel):
314
319
  self.win_r = win_r
315
320
  self.win_step = win_step
316
321
 
317
- self.resnet = ResidualBlock(in_channles=21, num_channles=11, use_1x1conv=True)
318
-
319
322
  def forward(self, x):
320
323
  b,_,h,w=x.size()
321
-
322
- block1 = x[:, 0 : 21, :, :]
323
- block2 = x[:, 20 : 41, :, :]
324
- block3 = x[:, 40 : 61, :, :]
324
+ if self.spike_dim == 61:
325
+ block1 = x[:, 0 : 21, :, :]
326
+ block2 = x[:, 20 : 41, :, :]
327
+ block3 = x[:, 40 : 61, :, :]
328
+ elif self.spike_dim == 41:
329
+ block1 = x[:, 0 : 15, :, :]
330
+ block2 = x[:, 13 : 28, :, :]
331
+ block3 = x[:, 26 : 41, :, :]
325
332
 
326
333
  repre1 = TFP(block1, channel_step=2)#C: 5
327
334
  repre2 = TFP(block2, channel_step=2)
@@ -6,7 +6,7 @@ import torch
6
6
 
7
7
 
8
8
  class TFIModel(nn.Module):
9
- def __init__(self, model_win_length):
9
+ def __init__(self, model_win_length = 41):
10
10
  super(TFIModel, self).__init__()
11
11
  self.window = model_win_length
12
12
  self.hald_window = model_win_length // 2
@@ -3,7 +3,7 @@ import torch
3
3
 
4
4
 
5
5
  class TFPModel(nn.Module):
6
- def __init__(self, model_win_length):
6
+ def __init__(self, model_win_length = 41):
7
7
  self.window = model_win_length
8
8
  super(TFPModel, self).__init__()
9
9
 
@@ -94,15 +94,15 @@ class Dwt1dModule_Tcn(nn.Module):
94
94
  class Dwt1dResnetX_TCN(nn.Module):
95
95
  def __init__(
96
96
  self,
97
- wvlname='db1',
98
- J=3,
99
- yl_size=14,
100
- yh_size=[26, 18, 14],
101
- num_residual_blocks=2,
97
+ wvlname='db8',
98
+ J=5,
99
+ yl_size=15,
100
+ yh_size=[28, 21, 18, 16, 15],
101
+ num_residual_blocks=3,
102
102
  norm=None,
103
103
  inc=41,
104
104
  ks=3,
105
- store_features=False
105
+ store_features=True
106
106
  ):
107
107
  super().__init__()
108
108
 
@@ -4,28 +4,30 @@ import importlib, inspect
4
4
  import os
5
5
  import torch
6
6
  from typing import Literal
7
+ from spikezoo.utils.other_utils import getattr_case_insensitive
7
8
 
8
9
  # todo auto detect/register datasets
9
10
  files_list = os.listdir(os.path.dirname(os.path.abspath(__file__)))
10
11
  dataset_list = [file.replace("_dataset.py", "") for file in files_list if file.endswith("_dataset.py")]
11
12
 
13
+
12
14
  # todo register function
13
15
  def build_dataset_cfg(cfg: BaseDatasetConfig, split: Literal["train", "test"] = "test"):
14
16
  """Build the dataset from the given dataset config."""
15
17
  # build new cfg according to split
16
- cfg = replace(cfg,split = split,spike_length = cfg.spike_length_train if split == "train" else cfg.spike_length_test)
18
+ cfg = replace(cfg, split=split, spike_length=cfg.spike_length_train if split == "train" else cfg.spike_length_test)
17
19
  # dataset module
18
20
  module_name = cfg.dataset_name + "_dataset"
19
21
  assert cfg.dataset_name in dataset_list, f"Given dataset {cfg.dataset_name} not in our dataset list {dataset_list}."
20
22
  module_name = "spikezoo.datasets." + module_name
21
23
  module = importlib.import_module(module_name)
22
24
  # dataset,dataset_config
23
- classes = sorted([name for name, obj in inspect.getmembers(module) if inspect.isclass(obj) and obj.__module__ == module.__name__])
24
- dataset_cls: BaseDataset = getattr(module, classes[0])
25
+ dataset_name = cfg.dataset_name
26
+ dataset_name = dataset_name + "Dataset" if dataset_name == "base" else dataset_name
27
+ dataset_cls: BaseDataset = getattr_case_insensitive(module, dataset_name)
25
28
  dataset = dataset_cls(cfg)
26
29
  return dataset
27
30
 
28
-
29
31
  def build_dataset_name(dataset_name: str, split: Literal["train", "test"] = "test"):
30
32
  """Build the default dataset from the given name."""
31
33
  module_name = dataset_name + "_dataset"
@@ -33,21 +35,21 @@ def build_dataset_name(dataset_name: str, split: Literal["train", "test"] = "tes
33
35
  module_name = "spikezoo.datasets." + module_name
34
36
  module = importlib.import_module(module_name)
35
37
  # dataset,dataset_config
36
- classes = sorted([name for name, obj in inspect.getmembers(module) if inspect.isclass(obj) and obj.__module__ == module.__name__])
37
- dataset_cls: BaseDataset = getattr(module, classes[0])
38
- dataset_cfg: BaseDatasetConfig = getattr(module, classes[1])(split=split)
38
+ dataset_name = dataset_name + "Dataset" if dataset_name == "base" else dataset_name
39
+ dataset_cls: BaseDataset = getattr_case_insensitive(module, dataset_name)
40
+ dataset_cfg: BaseDatasetConfig = getattr_case_insensitive(module, dataset_name + "config")(split=split)
39
41
  dataset = dataset_cls(dataset_cfg)
40
42
  return dataset
41
43
 
42
44
 
43
45
  # todo to modify according to the basicsr
44
- def build_dataloader(dataset: BaseDataset,cfg = None):
46
+ def build_dataloader(dataset: BaseDataset, cfg=None):
45
47
  # train dataloader
46
48
  if dataset.cfg.split == "train":
47
49
  if cfg is None:
48
50
  return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
49
51
  else:
50
- return torch.utils.data.DataLoader(dataset, batch_size=cfg.bs_train, shuffle=True, num_workers=cfg.num_workers,pin_memory=cfg.pin_memory)
52
+ return torch.utils.data.DataLoader(dataset, batch_size=cfg.bs_train, shuffle=True, num_workers=cfg.num_workers, pin_memory=cfg.pin_memory)
51
53
  # test dataloader
52
54
  elif dataset.cfg.split == "test":
53
55
  return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
@@ -6,7 +6,7 @@ import numpy as np
6
6
  from spikezoo.utils.spike_utils import load_vidar_dat
7
7
  import re
8
8
  from dataclasses import dataclass, replace
9
- from typing import Literal,Union
9
+ from typing import Literal, Union
10
10
  import warnings
11
11
  import torch
12
12
  from tqdm import tqdm
@@ -19,7 +19,7 @@ class BaseDatasetConfig:
19
19
  "Dataset name."
20
20
  dataset_name: str = "base"
21
21
  "Directory specifying location of data."
22
- root_dir: Union[str,Path] = Path(__file__).parent.parent / Path("data/base")
22
+ root_dir: Union[str, Path] = Path(__file__).parent.parent / Path("data/base")
23
23
  "Image width."
24
24
  width: int = 400
25
25
  "Image height."
@@ -46,6 +46,8 @@ class BaseDatasetConfig:
46
46
  use_cache: bool = False
47
47
  "Crop size."
48
48
  crop_size: tuple = (-1, -1)
49
+ "Rate. (-1 denotes variant)"
50
+ rate: float = 1
49
51
 
50
52
  # post process
51
53
  def __post_init__(self):
@@ -54,6 +56,7 @@ class BaseDatasetConfig:
54
56
  # todo try download
55
57
  assert self.root_dir.exists(), f"No files found in {self.root_dir} for the specified dataset `{self.dataset_name}`."
56
58
 
59
+
57
60
  # todo cache mechanism
58
61
  class BaseDataset(Dataset):
59
62
  def __init__(self, cfg: BaseDatasetConfig):
@@ -80,7 +83,11 @@ class BaseDataset(Dataset):
80
83
  if self.cfg.use_aug == True and self.cfg.split == "train":
81
84
  spike, img = self.augmentor(spike, img)
82
85
 
83
- batch = {"spike": spike, "img": img}
86
+ # rate
87
+ rate = self.cfg.rate
88
+
89
+ # ! spike and gt_img names are necessary
90
+ batch = {"spike": spike, "gt_img": img, "rate": rate}
84
91
  return batch
85
92
 
86
93
  # todo: To be overridden
@@ -4,7 +4,7 @@ from dataclasses import dataclass
4
4
 
5
5
 
6
6
  @dataclass
7
- class RealWorld_Config(BaseDatasetConfig):
7
+ class RealWorldConfig(BaseDatasetConfig):
8
8
  dataset_name: str = "realworld"
9
9
  root_dir: Path = Path(__file__).parent.parent / Path("data/recVidarReal2019")
10
10
  width: int = 400
@@ -21,5 +21,3 @@ class RealWorld(BaseDataset):
21
21
  def prepare_data(self):
22
22
  self.spike_dir = self.cfg.root_dir
23
23
  self.spike_list = self.get_spike_files(self.spike_dir)
24
-
25
-
@@ -4,10 +4,11 @@ from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
4
4
  from dataclasses import dataclass
5
5
  import re
6
6
 
7
+
7
8
  @dataclass
8
- class REDS_Small_Config(BaseDatasetConfig):
9
- dataset_name: str = "reds_small"
10
- root_dir: Path = Path(__file__).parent.parent / Path("data/REDS_Small")
9
+ class REDS_BASEConfig(BaseDatasetConfig):
10
+ dataset_name: str = "reds_base"
11
+ root_dir: Path = Path(__file__).parent.parent / Path("data/REDS_BASE")
11
12
  width: int = 400
12
13
  height: int = 250
13
14
  with_img: bool = True
@@ -15,13 +16,13 @@ class REDS_Small_Config(BaseDatasetConfig):
15
16
  spike_length_test: int = 301
16
17
  spike_dir_name: str = "spike"
17
18
  img_dir_name: str = "gt"
19
+ rate: float = 0.6
20
+
18
21
 
19
- class REDS_Small(BaseDataset):
22
+ class REDS_BASE(BaseDataset):
20
23
  def __init__(self, cfg: BaseDatasetConfig):
21
- super(REDS_Small, self).__init__(cfg)
24
+ super(REDS_BASE, self).__init__(cfg)
22
25
 
23
26
  def prepare_data(self):
24
27
  super().prepare_data()
25
- if self.cfg.split == "train":
26
- self.img_list = [self.img_dir / Path(str(s.name).replace('.dat','.png')) for s in self.spike_list]
27
-
28
+ self.img_list = [self.img_dir / Path(str(s.name).replace(".dat", ".png")) for s in self.spike_list]
@@ -0,0 +1,181 @@
1
+ from torch.utils.data import Dataset
2
+ from pathlib import Path
3
+ from spikezoo.datasets.base_dataset import BaseDataset, BaseDatasetConfig
4
+ from dataclasses import dataclass
5
+ import re
6
+
7
+
8
+ @dataclass
9
+ class REDS_SSIRConfig(BaseDatasetConfig):
10
+ dataset_name: str = "reds_ssir"
11
+ root_dir: Path = Path(__file__).parent.parent / Path("data/REDS_SSIR")
12
+ train_width: int = 96
13
+ train_height: int = 96
14
+ test_width: int = 1280
15
+ test_height: int = 720
16
+ width: int = -1
17
+ height: int = -1
18
+ with_img: bool = True
19
+ spike_length_train: int = 41
20
+ spike_length_test: int = 301
21
+
22
+ # post process
23
+ def __post_init__(self):
24
+ self.root_dir = Path(self.root_dir) if isinstance(self.root_dir, str) else self.root_dir
25
+ # todo try download
26
+ assert self.root_dir.exists(), f"No files found in {self.root_dir} for the specified dataset `{self.dataset_name}`."
27
+ # train/test split
28
+ if self.split == "train":
29
+ self.spike_length = self.spike_length_train
30
+ self.width = self.train_width
31
+ self.height = self.train_height
32
+ else:
33
+ self.spike_length = self.spike_length_test
34
+ self.width = self.test_width
35
+ self.height = self.test_height
36
+
37
+
38
+ class REDS_SSIR(BaseDataset):
39
+ def __init__(self, cfg: BaseDatasetConfig):
40
+ super(REDS_SSIR, self).__init__(cfg)
41
+
42
+ def prepare_data(self):
43
+ """Specify the spike and image files to be loaded."""
44
+ # spike/imgs dir train/test
45
+ if self.cfg.split == "train":
46
+ self.img_dir = self.cfg.root_dir / Path("crop_mini/spike/train/interp_20_alpha_0.40")
47
+ self.spike_dir = self.cfg.root_dir / Path("crop_mini/image/train/train_orig")
48
+ else:
49
+ self.img_dir = self.cfg.root_dir / Path("imgs/val/val_orig")
50
+ self.spike_dir = self.cfg.root_dir / Path("spike/val/interp_20_alpha_0.40")
51
+ # get files
52
+ self.spike_list = self.get_spike_files(self.spike_dir)
53
+ self.img_list = []
54
+
55
+
56
+ class sreds_train(torch.utils.data.Dataset):
57
+ def __init__(self, cfg):
58
+ self.cfg = cfg
59
+ self.pair_step = self.cfg["loader"]["pair_step"]
60
+ self.augmentor = Augmentor(crop_size=self.cfg["loader"]["crop_size"])
61
+ self.samples = self.collect_samples()
62
+ print("The samples num of training data: {:d}".format(len(self.samples)))
63
+
64
+ def confirm_exist(self, path_list_list):
65
+ for pl in path_list_list:
66
+ for p in pl:
67
+ if not osp.exists(p):
68
+ return 0
69
+ return 1
70
+
71
+ def collect_samples(self):
72
+ spike_path = osp.join(
73
+ self.cfg["data"]["root"], "crop_mini", "spike", "train", "interp_{:d}_alpha_{:.2f}".format(self.cfg["data"]["interp"], self.cfg["data"]["alpha"])
74
+ )
75
+ image_path = osp.join(self.cfg["data"]["root"], "crop_mini", "image", "train", "train_orig")
76
+ scene_list = sorted(os.listdir(spike_path))
77
+ samples = []
78
+
79
+ for scene in scene_list:
80
+ spike_dir = osp.join(spike_path, scene)
81
+ image_dir = osp.join(image_path, scene)
82
+ spk_path_list = sorted(os.listdir(spike_dir))
83
+
84
+ spklen = len(spk_path_list)
85
+ seq_len = self.cfg["model"]["seq_len"] + 2
86
+ """
87
+ for st in range(0, spklen - ((spklen - self.pair_step) % seq_len) - seq_len, self.pair_step):
88
+ # 按照文件名称读取
89
+ spikes_path_list = [osp.join(spike_dir, spk_path_list[ii]) for ii in range(st, st+seq_len)]
90
+ images_path_list = [osp.join(image_dir, spk_path_list[ii][:-4]+'.png') for ii in range(st, st+seq_len)]
91
+
92
+ if(self.confirm_exist([spikes_path_list, images_path_list])):
93
+ s = {}
94
+ s['spikes_paths'] = spikes_path_list
95
+ s['images_paths'] = images_path_list
96
+ samples.append(s)
97
+ """
98
+ # 按照文件名称读取
99
+ spikes_path_list = [osp.join(spike_dir, spk_path_list[ii]) for ii in range(spklen)]
100
+ images_path_list = [osp.join(image_dir, spk_path_list[ii][:-4] + ".png") for ii in range(spklen)]
101
+
102
+ if self.confirm_exist([spikes_path_list, images_path_list]):
103
+ s = {}
104
+ s["spikes_paths"] = spikes_path_list
105
+ s["images_paths"] = images_path_list
106
+ samples.append(s)
107
+
108
+ return samples
109
+
110
+ def _load_sample(self, s):
111
+ data = {}
112
+
113
+ data["spikes"] = [np.array(dat_to_spmat(p, size=(96, 96)), dtype=np.float32) for p in s["spikes_paths"]]
114
+ data["images"] = [read_img_gray(p) for p in s["images_paths"]]
115
+
116
+ data["spikes"], data["images"] = self.augmentor(data["spikes"], data["images"])
117
+ # print("data['spikes'][0].shape, data['images'][0].shape", data['spikes'][0].shape, data['images'][0].shape)
118
+
119
+ return data
120
+
121
+ def __len__(self):
122
+ return len(self.samples)
123
+
124
+ def __getitem__(self, index):
125
+ data = self._load_sample(self.samples[index])
126
+ return data
127
+
128
+
129
+ class sreds_test(torch.utils.data.Dataset):
130
+ def __init__(self, cfg):
131
+ self.cfg = cfg
132
+ self.samples = self.collect_samples()
133
+ print("The samples num of testing data: {:d}".format(len(self.samples)))
134
+
135
+ def confirm_exist(self, path_list_list):
136
+ for pl in path_list_list:
137
+ for p in pl:
138
+ if not osp.exists(p):
139
+ return 0
140
+ return 1
141
+
142
+ def collect_samples(self):
143
+ spike_path = osp.join(
144
+ self.cfg["data"]["root"], "spike", "val", "interp_{:d}_alpha_{:.2f}".format(self.cfg["data"]["interp"], self.cfg["data"]["alpha"])
145
+ )
146
+ image_path = osp.join(self.cfg["data"]["root"], "imgs", "val", "val_orig")
147
+ scene_list = sorted(os.listdir(spike_path))
148
+ samples = []
149
+
150
+ for scene in scene_list:
151
+ spike_dir = osp.join(spike_path, scene)
152
+ image_dir = osp.join(image_path, scene)
153
+ spk_path_list = sorted(os.listdir(spike_dir))
154
+
155
+ spklen = len(spk_path_list)
156
+ # seq_len = self.cfg['model']['seq_len']
157
+
158
+ # 按照文件名称读取
159
+ spikes_path_list = [osp.join(spike_dir, spk_path_list[ii]) for ii in range(spklen)]
160
+ images_path_list = [osp.join(image_dir, spk_path_list[ii][:-4] + ".png") for ii in range(spklen)]
161
+
162
+ if self.confirm_exist([spikes_path_list, images_path_list]):
163
+ s = {}
164
+ s["spikes_paths"] = spikes_path_list
165
+ s["images_paths"] = images_path_list
166
+ samples.append(s)
167
+
168
+ return samples
169
+
170
+ def _load_sample(self, s):
171
+ data = {}
172
+ data["spikes"] = [np.array(dat_to_spmat(p, size=(720, 1280)), dtype=np.float32) for p in s["spikes_paths"]]
173
+ data["images"] = [read_img_gray(p) for p in s["images_paths"]]
174
+ return data
175
+
176
+ def __len__(self):
177
+ return len(self.samples)
178
+
179
+ def __getitem__(self, index):
180
+ data = self._load_sample(self.samples[index])
181
+ return data