spikezoo 0.2.2__py3-none-any.whl → 0.2.3.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (86) hide show
  1. spikezoo/__init__.py +23 -7
  2. spikezoo/archs/bsf/models/bsf/bsf.py +37 -25
  3. spikezoo/archs/bsf/models/bsf/rep.py +2 -2
  4. spikezoo/archs/spk2imgnet/nets.py +1 -1
  5. spikezoo/archs/ssir/models/networks.py +1 -1
  6. spikezoo/archs/ssml/model.py +9 -5
  7. spikezoo/archs/stir/metrics/losses.py +1 -1
  8. spikezoo/archs/stir/models/networks_STIR.py +16 -9
  9. spikezoo/archs/tfi/nets.py +1 -1
  10. spikezoo/archs/tfp/nets.py +1 -1
  11. spikezoo/archs/wgse/dwtnets.py +6 -6
  12. spikezoo/datasets/__init__.py +11 -9
  13. spikezoo/datasets/base_dataset.py +10 -3
  14. spikezoo/datasets/realworld_dataset.py +1 -3
  15. spikezoo/datasets/{reds_small_dataset.py → reds_base_dataset.py} +9 -8
  16. spikezoo/datasets/reds_ssir_dataset.py +181 -0
  17. spikezoo/datasets/szdata_dataset.py +5 -15
  18. spikezoo/datasets/uhsr_dataset.py +4 -3
  19. spikezoo/models/__init__.py +8 -6
  20. spikezoo/models/base_model.py +120 -64
  21. spikezoo/models/bsf_model.py +11 -3
  22. spikezoo/models/spcsnet_model.py +19 -0
  23. spikezoo/models/spikeclip_model.py +4 -3
  24. spikezoo/models/spk2imgnet_model.py +9 -15
  25. spikezoo/models/ssir_model.py +4 -6
  26. spikezoo/models/ssml_model.py +44 -2
  27. spikezoo/models/stir_model.py +26 -5
  28. spikezoo/models/tfi_model.py +3 -1
  29. spikezoo/models/tfp_model.py +4 -2
  30. spikezoo/models/wgse_model.py +8 -14
  31. spikezoo/pipeline/base_pipeline.py +79 -55
  32. spikezoo/pipeline/ensemble_pipeline.py +10 -9
  33. spikezoo/pipeline/train_cfgs.py +89 -0
  34. spikezoo/pipeline/train_pipeline.py +129 -30
  35. spikezoo/utils/optimizer_utils.py +22 -0
  36. spikezoo/utils/other_utils.py +31 -6
  37. spikezoo/utils/scheduler_utils.py +25 -0
  38. spikezoo/utils/spike_utils.py +61 -29
  39. spikezoo-0.2.3.2.dist-info/METADATA +263 -0
  40. {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.2.dist-info}/RECORD +43 -80
  41. spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
  42. spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
  43. spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
  44. spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
  45. spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
  46. spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
  47. spikezoo/archs/spikeformer/CheckPoints/readme +0 -1
  48. spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +0 -60
  49. spikezoo/archs/spikeformer/DataProcess/DataLoader.py +0 -115
  50. spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +0 -39
  51. spikezoo/archs/spikeformer/EvalResults/readme +0 -1
  52. spikezoo/archs/spikeformer/LICENSE +0 -21
  53. spikezoo/archs/spikeformer/Metrics/Metrics.py +0 -50
  54. spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
  55. spikezoo/archs/spikeformer/Model/Loss.py +0 -89
  56. spikezoo/archs/spikeformer/Model/SpikeFormer.py +0 -230
  57. spikezoo/archs/spikeformer/Model/__init__.py +0 -0
  58. spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
  59. spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
  60. spikezoo/archs/spikeformer/README.md +0 -30
  61. spikezoo/archs/spikeformer/evaluate.py +0 -87
  62. spikezoo/archs/spikeformer/recon_real_data.py +0 -97
  63. spikezoo/archs/spikeformer/requirements.yml +0 -95
  64. spikezoo/archs/spikeformer/train.py +0 -173
  65. spikezoo/archs/spikeformer/utils.py +0 -22
  66. spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
  67. spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
  68. spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
  69. spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
  70. spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
  71. spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
  72. spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
  73. spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
  74. spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
  75. spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
  76. spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
  77. spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
  78. spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
  79. spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
  80. spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
  81. spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
  82. spikezoo/models/spikeformer_model.py +0 -50
  83. spikezoo-0.2.2.dist-info/METADATA +0 -196
  84. {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.2.dist-info}/LICENSE.txt +0 -0
  85. {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.2.dist-info}/WHEEL +0 -0
  86. {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.2.dist-info}/top_level.txt +0 -0
@@ -1,230 +0,0 @@
1
- '''
2
- Ref: To implement SpikeFormer, we referred to the code of "segformer-pytorch" published on GitHub
3
- (link: https://github.com/lucidrains/segformer-pytorch.git)
4
- '''
5
- from math import sqrt
6
- from functools import partial
7
- import torch
8
- from torch import nn, einsum
9
- from einops import rearrange, reduce
10
-
11
- # helpers
12
-
13
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    return not (val is None)
15
-
16
def cast_tuple(val, depth):
    """Return ``val`` unchanged if it is a tuple, otherwise a tuple holding ``depth`` copies of it."""
    if isinstance(val, tuple):
        return val
    return (val,) * depth
18
-
19
# NOTE: despite its name, this alias builds nn.InstanceNorm2d with learnable
# affine parameters; the only argument callers still supply is the channel count.
LayerNorm = partial(nn.InstanceNorm2d, affine=True)
20
-
21
- # classes
22
-
23
class DsConv2d(nn.Module):
    """Depthwise-separable 2-D convolution with a GELU after each convolution.

    A per-channel (``groups=dim_in``) convolution is followed by a 1x1
    convolution that maps ``dim_in`` channels to ``dim_out`` channels.

    Args:
        dim_in: number of input channels.
        dim_out: number of output channels.
        kernel_size: kernel size of the depthwise convolution.
        padding: padding of the depthwise convolution.
        stride: stride of the depthwise convolution.
        bias: whether both convolutions carry a bias term.
    """

    def __init__(self, dim_in, dim_out, kernel_size, padding, stride = 1, bias = True):
        super().__init__()
        # One filter per input channel: spatial mixing only.
        depthwise = nn.Conv2d(
            dim_in,
            dim_in,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=dim_in,
            bias=bias,
        )
        # 1x1 convolution: channel mixing only.
        pointwise = nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=bias)
        self.net = nn.Sequential(depthwise, nn.GELU(), pointwise, nn.GELU())

    def forward(self, x):
        """Apply the depthwise then pointwise convolution to ``x``."""
        return self.net(x)
34
-
35
class PreNorm(nn.Module):
    """Wrap a module so its input is normalised before the module runs.

    Args:
        dim: number of channels handed to the module-level ``LayerNorm`` factory.
        fn: the callable applied to the normalised tensor.
    """

    def __init__(self, dim, fn):
        super().__init__()
        # Keep fn registered before norm so parameter ordering is unchanged.
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        """Return ``fn(norm(x))``."""
        normalised = self.norm(x)
        return self.fn(normalised)
44
-
45
class EfficientSelfAttention(nn.Module):
    """Multi-head self-attention over a 2-D feature map with reduced keys/values.

    Queries come from a 1x1 convolution. Keys and values come from a single
    convolution whose kernel size and stride both equal ``reduction_ratio``,
    so they are computed on a spatially reduced map.

    Args:
        dim: number of input and output channels.
        heads: number of attention heads; channels are split evenly among them.
        reduction_ratio: kernel size and stride of the key/value convolution.
    """

    def __init__(self, *, dim, heads, reduction_ratio):
        super().__init__()
        self.scale = (dim // heads) ** -0.5
        self.heads = heads

        self.to_q = nn.Conv2d(dim, dim, 1, bias = False)
        self.to_kv = nn.Conv2d(dim, dim * 2, reduction_ratio, stride = reduction_ratio, bias = False)
        self.to_out = nn.Sequential(nn.Conv2d(dim, dim, 1, bias=False), nn.GELU())

    def forward(self, x):
        """Attend over all spatial positions of ``x`` and return a map of the same shape."""
        height, width = x.shape[-2:]
        n_heads = self.heads

        def split_heads(t):
            # Fold heads into the batch axis and flatten the spatial grid.
            return rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = n_heads)

        query = split_heads(self.to_q(x))
        key, value = (split_heads(t) for t in self.to_kv(x).chunk(2, dim = 1))

        scores = einsum('b i d, b j d -> b i j', query, key) * self.scale
        weights = scores.softmax(dim = -1)

        mixed = einsum('b i j, b j d -> b i d', weights, value)
        # Undo split_heads: restore the head channels and the original spatial grid.
        mixed = rearrange(mixed, '(b h) (x y) c -> b (h c) x y', h = n_heads, x = height, y = width)
        return self.to_out(mixed)
77
-
78
class MixFeedForward(nn.Module):
    """Convolutional feed-forward block: expand, mix spatially, project back.

    Args:
        dim: number of input and output channels.
        expansion_factor: multiplier giving the hidden channel count.
    """

    def __init__(self, *, dim, expansion_factor):
        super().__init__()
        hidden_dim = dim * expansion_factor

        expand = nn.Conv2d(dim, hidden_dim, 1)
        spatial_mix = DsConv2d(hidden_dim, hidden_dim, 3, padding = 1)
        project = nn.Conv2d(hidden_dim, dim, 1)

        # No separate activation after spatial_mix: DsConv2d already ends in a GELU.
        self.net = nn.Sequential(expand, nn.GELU(), spatial_mix, project, nn.GELU())

    def forward(self, x):
        """Run ``x`` through the feed-forward stack."""
        return self.net(x)
98
-
99
class MiT(nn.Module):
    """Four-stage hierarchical transformer encoder.

    Each stage unfolds the feature map into overlapping patches, embeds them
    with a 1x1 convolution, and then applies a stack of (attention,
    feed-forward) pairs with residual connections.

    Args:
        channels: channel count of the features before the caller's
            PixelShuffle(4); the first stage is therefore sized for
            ``channels / 16`` input channels (SpikeFormer applies that shuffle
            right before calling this encoder).
        dims: output channels of each stage.
        heads: attention heads of each stage.
        ff_expansion: feed-forward expansion factor of each stage.
        reduction_ratio: key/value reduction ratio of each stage.
        num_layers: number of (attention, feed-forward) pairs in each stage.
    """

    def __init__(self, *, channels, dims, heads, ff_expansion, reduction_ratio, num_layers):
        super().__init__()
        stage_kernel_stride_pad = ((7, 4, 3), (3, 2, 1), (3, 2, 1), (3, 2, 1))
        # True division is kept from the original; the int() below absorbs the float.
        stage_dims = (channels / 16, *dims)
        dim_pairs = zip(stage_dims[:-1], stage_dims[1:])

        self.stages = nn.ModuleList([])

        stage_configs = zip(dim_pairs, stage_kernel_stride_pad, num_layers, ff_expansion, heads, reduction_ratio)
        # Distinct loop names avoid shadowing the constructor's own parameters.
        for (dim_in, dim_out), (kernel, stride, padding), depth, expansion, n_heads, ratio in stage_configs:
            get_overlap_patches = nn.Unfold(kernel, stride = stride, padding = padding)
            # Unfold emits dim_in * kernel**2 values per patch.
            overlap_patch_embed = nn.Conv2d(int(dim_in * kernel ** 2), dim_out, 1)

            layers = nn.ModuleList([])
            for _ in range(depth):
                layers.append(nn.ModuleList([
                    PreNorm(dim_out, EfficientSelfAttention(dim = dim_out, heads = n_heads, reduction_ratio = ratio)),
                    PreNorm(dim_out, MixFeedForward(dim = dim_out, expansion_factor = expansion)),
                ]))

            self.stages.append(nn.ModuleList([
                get_overlap_patches,
                overlap_patch_embed,
                layers
            ]))

    def forward(self, x, return_layer_outputs = False):
        """Encode ``x`` through all stages.

        Args:
            x: input feature map; its last two axes are treated as height and width.
            return_layer_outputs: when true, return the list of every stage's
                output instead of only the last one.

        Returns:
            The final stage output, or the list of all stage outputs.
        """
        h, w = x.shape[-2:]

        layer_outputs = []
        for (get_overlap_patches, overlap_embed, layers) in self.stages:
            x = get_overlap_patches(x)
            num_patches = int(x.shape[-1])

            # Downsampling factor relative to the ORIGINAL input size, used to
            # recover the patch-grid height for the reshape below.
            ratio = int(sqrt((h * w) / num_patches))
            x = rearrange(x, 'b c (h w) -> b c h w', h = h // ratio)

            # Follow the embedding's own device and dtype instead of forcing
            # torch.cuda.FloatTensor, so the encoder also runs on CPU or on a
            # non-default GPU.
            x = x.to(device = overlap_embed.weight.device, dtype = overlap_embed.weight.dtype)

            x = overlap_embed(x)
            for (attn, ff) in layers:
                x = attn(x) + x
                x = ff(x) + x

            layer_outputs.append(x)

        return layer_outputs if return_layer_outputs else x
175
-
176
class SpikeFormer(nn.Module):
    """Image reconstruction network: conv stem, MiT encoder, fused decoder.

    The stem maps the input to ``channels`` feature maps, a PixelShuffle(4)
    trades channels for resolution, the four-stage MiT encoder produces
    multi-scale features, and each scale is projected, pixel-shuffled and
    concatenated before the final 1x1 restoration head.

    Args:
        inputDim: number of input channels.
        dims: output channels of each encoder stage (single value or 4-tuple).
        heads: attention heads per stage (single value or 4-tuple).
        ff_expansion: feed-forward expansion per stage (single value or 4-tuple).
        reduction_ratio: key/value reduction per stage (single value or 4-tuple).
        num_layers: layers per stage (single value or 4-tuple).
        channels: stem output channels; the encoder is built for ``channels / 16``.
        decoder_dim: channels each stage is projected to before pixel shuffling.
        out_channel: channels of the restored image.
    """

    def __init__(
        self,
        inputDim=64,
        dims = (32, 64, 160, 256),
        heads = (1, 2, 5, 8),
        ff_expansion = (8, 8, 4, 4),
        reduction_ratio = (8, 4, 2, 1),
        num_layers = 2,
        channels =64,
        decoder_dim = 256,
        out_channel = 1
    ):
        super().__init__()
        dims, heads, ff_expansion, reduction_ratio, num_layers = map(partial(cast_tuple, depth = 4), (dims, heads, ff_expansion, reduction_ratio, num_layers))
        assert all([*map(lambda t: len(t) == 4, (dims, heads, ff_expansion, reduction_ratio, num_layers))]), 'only four stages are allowed, all keyword arguments must be either a single value or a tuple of 4 values'

        self.mit = MiT(
            channels = channels,
            dims = dims,
            heads = heads,
            ff_expansion = ff_expansion,
            reduction_ratio = reduction_ratio,
            num_layers = num_layers
        )
        # The stem must emit `channels` maps, because MiT is sized for
        # channels / 16 after the PixelShuffle(4) in forward(). This was a
        # hard-coded 64, which only matched the default `channels`.
        self.channel_transform = nn.Sequential(
            nn.Conv2d(inputDim, channels, 3, 1, 1),
            nn.GELU()
        )

        self.to_fused = nn.ModuleList([nn.Sequential(
            nn.Conv2d(dim, decoder_dim, 1),
            nn.PixelShuffle(2 ** i),
            nn.GELU(),
        ) for i, dim in enumerate(dims)])

        # PixelShuffle(2 ** i) divides the channel count by 4 ** i, so the
        # concatenated width is the sum below (256 + 64 + 16 + 4 = 340 for the
        # default decoder_dim; previously hard-coded).
        fused_channels = sum(decoder_dim // (4 ** i) for i in range(len(dims)))
        self.to_restore = nn.Sequential(
            nn.Conv2d(fused_channels, decoder_dim, 1),
            nn.GELU(),
            nn.Conv2d(decoder_dim, out_channel, 1),
        )
        self.fournew = nn.PixelShuffle(4)

    def forward(self, x):
        """Reconstruct an image with ``out_channel`` channels from ``x``."""
        x = self.channel_transform(x)
        x = self.fournew(x)
        layer_outputs = self.mit(x, return_layer_outputs = True)

        fused = [to_fused(output) for output, to_fused in zip(layer_outputs, self.to_fused)]
        fused = torch.cat(fused, dim = 1)

        return self.to_restore(fused)
File without changes
@@ -1,30 +0,0 @@
1
- # SpikeFormer [![MIT Licence](https://badges.frapsoft.com/os/mit/mit.svg?v=103)](https://opensource.org/licenses/mit-license.php)
2
- PyTorch implementation of "SpikeFormer: Image Reconstruction from the Sequence of Spike Camera Based on Transformer" [[Paper]](https://dl.acm.org/doi/abs/10.1145/3512388.3512399)
3
-
4
- ## Prerequisites
5
- * Create a conda environment by running `conda env create -f requirements.yml`
6
-
7
- ## Dataset Structure
8
- * To train SpikeFormer, organize the dataset's file structure as follows:
9
- ```
10
- Dataset
11
- ├── test
12
- │   └── c.npz
13
- ├── train
14
- │   └── a.npz
15
- └── valid
16
-     └── b.npz
17
- ```
18
-
19
- ## Pretrained Model
20
- * Download the pretrained model [here](https://pan.baidu.com/s/1aeW15vQh0GXgRJtfStBHDg) using password: nwh5.
21
- * Place the downloaded model in `./CheckPoints/`.
22
-
23
- ## Training
24
- * Run `python train.py` to train SpikeFormer on the training set.
25
-
26
- ## Validation
27
- * Run `python evaluate.py` to evaluate the performance of the trained model on the testing set.
28
-
29
- ## Reconstruct Images from Real Spike Data
30
- * Run `python recon_real_data.py`.
@@ -1,87 +0,0 @@
1
- import os
2
- os.environ['CUDA_VISIBLE_DEVICES'] = "0"
3
- import torch
4
- import numpy as np
5
- from DataProcess import DataLoader as dl
6
- from Model import Loss
7
- from PIL import Image
8
- from Metrics.Metrics import Metrics
9
- from Model.SpikeFormer import SpikeFormer
10
- from utils import LoadModel
11
-
12
if __name__ == "__main__":
    # Evaluation script: run SpikeFormer over the 'valid' split, report PSNR and
    # SSIM, and save side-by-side prediction / ground-truth images.

    dataPath = "/home/storage2/shechen/Spike_Sample_250x400"
    spikeRadius = 32  # half length of the input spike sequence, excluding the middle frame
    spikeLen = 2 * spikeRadius + 1  # length of the input spike sequence
    batchSize = 4

    reuse = True
    checkPath = "CheckPoints/best.pth"
    resultDir = "EvalResults"

    validContainer = dl.DataContainer(dataPath=dataPath, dataType='valid',
                                      spikeRadius=spikeRadius, batchSize=batchSize)
    validData = validContainer.GetLoader()

    metrics = Metrics()

    model = SpikeFormer(
        inputDim=spikeLen,
        dims=(32, 64, 160, 256),  # dimensions of each stage
        heads=(1, 2, 5, 8),  # heads of each stage
        ff_expansion=(8, 8, 4, 4),  # feedforward expansion factor of each stage
        reduction_ratio=(8, 4, 2, 1),  # reduction ratio of each stage for efficient attention
        num_layers=2,  # num layers of each stage
        decoder_dim=256,  # decoder dimension
        out_channel=1  # channel of restored image
    ).cuda()

    if reuse:
        # NOTE(review): LoadModel presumably restores the checkpoint into
        # `model` in place; its return values were never used here - confirm.
        LoadModel(checkPath, model)

    def to_uint8(batch):
        # Map a [-1, 1] tensor to a uint8 numpy array in [0, 255].
        array = batch.clamp(min=-1., max=1.).detach().cpu().numpy()
        return ((array + 1.) / 2. * 255.).astype(np.uint8)

    model.eval()
    pres = []
    gts = []
    with torch.no_grad():
        for spikes, gtImg in validData:
            predImg = to_uint8(model(spikes.cuda()).squeeze(1))
            # Drop 3 rows at the top and bottom of the prediction.
            # NOTE(review): presumably undoes a 250 -> 256 height padding applied
            # by the data loader - confirm.
            predImg = predImg[:, 3:-3]

            pres.append(predImg)
            gts.append(to_uint8(gtImg))

    pres = np.concatenate(pres, axis=0)
    gts = np.concatenate(gts, axis=0)

    psnr = metrics.Cal_PSNR(pres, gts)
    ssim = metrics.Cal_SSIM(pres, gts)
    # NOTE(review): these values are not used below; the call is kept in case
    # GetBestMetrics has side effects.
    best_psnr, best_ssim, _ = metrics.GetBestMetrics()

    # Create the output directory up front so the first save cannot fail on a
    # fresh checkout.
    os.makedirs(resultDir, exist_ok=True)

    height = pres.shape[1]
    divide_line = np.zeros((height, 4), dtype=np.uint8)  # 4-pixel black separator
    for num, (pre, gt) in enumerate(zip(pres, gts), start=1):
        concatImg = np.concatenate([pre, divide_line, gt], axis=1)
        Image.fromarray(concatImg).save(os.path.join(resultDir, 'test_%s.jpg' % (num)))

    print('*********************************************************')
    print('PSNR: %s, SSIM: %s' % (psnr, ssim))
87
-
@@ -1,97 +0,0 @@
1
- import os
2
- os.environ['CUDA_VISIBLE_DEVICES'] = "0"
3
- import torch
4
- import numpy as np
5
- from PIL import Image
6
- from Metrics.Metrics import Metrics
7
- from Model.SpikeFormer import SpikeFormer
8
- from DataProcess.LoadSpike import load_spike_raw
9
- from utils import LoadModel
10
- import shutil
11
- from PIL import Image
12
-
13
def PredictImg(model, inputs):
    """Run ``model`` on a batch of spike inputs and return uint8 images.

    Args:
        model: callable mapping a float tensor batch to a tensor whose axis 1
            has size one (it is squeezed away) and whose values are
            interpreted as lying in [-1, 1].
        inputs: array-like batch convertible to a float32 tensor.

    Returns:
        A uint8 numpy array with values in [0, 255]; 3 rows are cropped from
        both ends of axis 1 of the squeezed prediction.
    """
    # Run on whatever device the model lives on instead of assuming CUDA.
    # Models without parameters keep the original CUDA behaviour.
    get_params = getattr(model, "parameters", None)
    first_param = next(iter(get_params()), None) if callable(get_params) else None
    device = first_param.device if first_param is not None else torch.device("cuda")

    batch = torch.as_tensor(inputs, dtype=torch.float32).to(device)

    predImg = model(batch).squeeze(dim=1)

    # After the clamp, (x + 1) / 2 * 255 already lies in [0, 255], so no
    # further clipping is needed before the uint8 cast.
    predImg = predImg.clamp(min=-1., max=1.)
    predImg = predImg.detach().cpu().numpy()
    predImg = ((predImg + 1.) / 2. * 255.).astype(np.uint8)

    # Remove the 3 rows on each side of axis 1 (the caller pads by 3 there).
    return predImg[:, 3:-3]
27
-
28
if __name__ == "__main__":
    # Reconstruct images from one real spike recording with a sliding window
    # and write them as numbered JPEGs into a per-scene result directory.

    dataName = "reds"
    spikeRadius = 32
    spikeLen = 2 * spikeRadius + 1
    stride = 32
    batchSize = 8
    reuse = True
    checkPath = "best.pth"
    sceneClass = {
        1: 'ballon.dat', 2: 'car-100kmh.dat',
        3: 'forest.dat', 4: 'railway.dat',
        5: 'rotation1.dat', 6: 'rotation2.dat',
        7: 'train-350kmh.dat', 8: 'viaduct-bridge.dat'
    }
    sceneName = sceneClass[2]
    dataPath = "/home/storage1/Dataset/SpikeImageData/RealData/%s" % (sceneName)
    resultPath = sceneName + "_stride_" + str(stride) + "/"

    # Start from an empty result directory. The previous one-liner removed an
    # existing directory without recreating it, so every save then failed.
    if os.path.exists(resultPath):
        shutil.rmtree(resultPath)
    os.mkdir(resultPath)

    spikes = load_spike_raw(dataPath)
    totalLen = spikes.shape[0]

    model = SpikeFormer(
        inputDim=spikeLen,
        dims=(32, 64, 160, 256),  # dimensions of each stage
        heads=(1, 2, 5, 8),  # heads of each stage
        ff_expansion=(8, 8, 4, 4),  # feedforward expansion factor of each stage
        reduction_ratio=(8, 4, 2, 1),  # reduction ratio of each stage for efficient attention
        num_layers=2,  # num layers of each stage
        decoder_dim=256,  # decoder dimension
        out_channel=1  # channel of restored image
    ).cuda()

    if reuse:
        # NOTE(review): LoadModel presumably restores the checkpoint into
        # `model` in place; its return values were never used here - confirm.
        LoadModel(checkPath, model)

    model.eval()
    pres = []
    with torch.no_grad():
        num = 0
        inputs = np.zeros((batchSize, spikeLen, 256, 400))
        # Window centres: every `stride` frames, leaving spikeRadius frames of
        # context on each side.
        for i in range(spikeRadius, totalLen - spikeRadius, stride):
            spike = spikes[i - spikeRadius: i + spikeRadius + 1]
            # Pad axis 1 by 3 on each side so the window fits the 256-row batch
            # buffer; padding happens before scaling, so padded rows become -1.
            spike = np.pad(spike, ((0, 0), (3, 3), (0, 0)), mode='constant')
            inputs[num % batchSize] = spike.astype(float) * 2 - 1
            num += 1

            if num % batchSize == 0:
                # Batch is full: reconstruct it and start a fresh buffer.
                pres.append(PredictImg(model, inputs))
                inputs = np.zeros((batchSize, spikeLen, 256, 400))

        # Flush the final partial batch, if any. Skipping the empty case avoids
        # pushing a zero-length batch through the model.
        remainder = num % batchSize
        if remainder:
            pres.append(PredictImg(model, inputs[:remainder]))

    if pres:
        predImgs = np.concatenate(pres, axis=0)
        for count, img in enumerate(predImgs, start=1):
            Image.fromarray(img).save(resultPath + '%s.jpg' % (count))
    else:
        print('No frames reconstructed: recording has only %s frames, need more than %s'
              % (totalLen, 2 * spikeRadius))
97
-
@@ -1,95 +0,0 @@
1
- name: SpikeFormer
2
- channels:
3
- - pytorch
4
- - conda-forge
5
- - defaults
6
- dependencies:
7
- - _libgcc_mutex=0.1=main
8
- - _openmp_mutex=5.1=1_gnu
9
- - blas=1.0=mkl
10
- - brotlipy=0.7.0=py38h0a891b7_1004
11
- - bzip2=1.0.8=h7f98852_4
12
- - ca-certificates=2022.07.19=h06a4308_0
13
- - certifi=2022.6.15=py38h06a4308_0
14
- - cffi=1.14.6=py38ha65f79e_0
15
- - charset-normalizer=2.1.1=pyhd8ed1ab_0
16
- - cloudpickle=2.0.0=pyhd3eb1b0_0
17
- - cryptography=37.0.2=py38h2b5fc30_0
18
- - cudatoolkit=11.6.0=hecad31d_10
19
- - cytoolz=0.11.0=py38h7b6447c_0
20
- - dask-core=2022.7.0=py38h06a4308_0
21
- - ffmpeg=4.3=hf484d3e_0
22
- - fftw=3.3.9=h27cfd23_1
23
- - freetype=2.10.4=h0708190_1
24
- - fsspec=2022.7.1=py38h06a4308_0
25
- - gmp=6.2.1=h58526e2_0
26
- - gnutls=3.6.13=h85f3911_1
27
- - idna=3.3=pyhd8ed1ab_0
28
- - imageio=2.9.0=pyhd3eb1b0_0
29
- - intel-openmp=2021.4.0=h06a4308_3561
30
- - jpeg=9e=h166bdaf_1
31
- - lame=3.100=h7f98852_1001
32
- - lcms2=2.12=hddcbb42_0
33
- - ld_impl_linux-64=2.38=h1181459_1
34
- - libffi=3.3=he6710b0_2
35
- - libgcc-ng=11.2.0=h1234567_1
36
- - libgfortran-ng=11.2.0=h00389a5_1
37
- - libgfortran5=11.2.0=h1234567_1
38
- - libgomp=11.2.0=h1234567_1
39
- - libiconv=1.17=h166bdaf_0
40
- - libpng=1.6.37=h21135ba_2
41
- - libstdcxx-ng=11.2.0=h1234567_1
42
- - libtiff=4.2.0=hf544144_3
43
- - libwebp-base=1.2.2=h7f98852_1
44
- - locket=1.0.0=py38h06a4308_0
45
- - lz4-c=1.9.3=h9c3ff4c_1
46
- - mkl=2021.4.0=h06a4308_640
47
- - mkl-service=2.4.0=py38h95df7f1_0
48
- - mkl_fft=1.3.1=py38h8666266_1
49
- - mkl_random=1.2.2=py38h1abd341_0
50
- - natsort=7.1.1=pyhd3eb1b0_0
51
- - ncurses=6.3=h5eee18b_3
52
- - nettle=3.6=he412f7d_0
53
- - networkx=2.8.4=py38h06a4308_0
54
- - numpy=1.21.5=py38h6c91a56_3
55
- - numpy-base=1.21.5=py38ha15fc14_3
56
- - olefile=0.46=pyh9f0ad1d_1
57
- - openh264=2.1.1=h780b84a_0
58
- - openjpeg=2.4.0=hb52868f_1
59
- - openssl=1.1.1q=h7f8727e_0
60
- - packaging=21.3=pyhd3eb1b0_0
61
- - partd=1.2.0=pyhd3eb1b0_1
62
- - pillow=8.2.0=py38ha0e1e83_1
63
- - pip=22.1.2=py38h06a4308_0
64
- - pycparser=2.21=pyhd8ed1ab_0
65
- - pyopenssl=22.0.0=pyhd8ed1ab_0
66
- - pyparsing=3.0.9=py38h06a4308_0
67
- - pysocks=1.7.1=pyha2e5f31_6
68
- - python=3.8.13=h12debd9_0
69
- - python_abi=3.8=2_cp38
70
- - pytorch=1.12.1=py3.8_cuda11.6_cudnn8.3.2_0
71
- - pytorch-mutex=1.0=cuda
72
- - pywavelets=1.3.0=py38h7f8727e_0
73
- - pyyaml=6.0=py38h7f8727e_1
74
- - readline=8.1.2=h7f8727e_1
75
- - requests=2.28.1=pyhd8ed1ab_1
76
- - scikit-image=0.19.2=py38h51133e4_0
77
- - scipy=1.7.3=py38h6c91a56_2
78
- - setuptools=63.4.1=py38h06a4308_0
79
- - six=1.16.0=pyh6c4a22f_0
80
- - sqlite=3.39.2=h5082296_0
81
- - tifffile=2020.10.1=py38hdd07704_2
82
- - tk=8.6.12=h1ccaba5_0
83
- - toolz=0.11.2=pyhd3eb1b0_0
84
- - torchaudio=0.12.1=py38_cu116
85
- - torchvision=0.13.1=py38_cu116
86
- - typing_extensions=4.3.0=pyha770c72_0
87
- - urllib3=1.26.11=pyhd8ed1ab_0
88
- - wheel=0.37.1=pyhd3eb1b0_0
89
- - xz=5.2.5=h7f8727e_1
90
- - yaml=0.2.5=h7b6447c_0
91
- - zlib=1.2.12=h7f8727e_2
92
- - zstd=1.5.0=ha95c52a_0
93
- - pip:
94
- - einops==0.4.1
95
- - opencv-python==4.6.0.66