spike-encoding 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spike_encoding/__init__.py +2 -0
- spike_encoding/base_converter.py +14 -0
- spike_encoding/bens_spiker_algorithm.py +279 -0
- spike_encoding/bin_encoder.py +80 -0
- spike_encoding/delta_modulation_converter.py +52 -0
- spike_encoding/encoder_common.py +31 -0
- spike_encoding/gymnasium_bounds_finder.py +238 -0
- spike_encoding/gymnasium_encoder.py +149 -0
- spike_encoding/lif_based_encoding.py +393 -0
- spike_encoding/pulse_width_modulation.py +337 -0
- spike_encoding/rate_step_forward_converter.py +190 -0
- spike_encoding/step_forward_converter.py +454 -0
- spike_encoding-0.1.0.dist-info/METADATA +444 -0
- spike_encoding-0.1.0.dist-info/RECORD +15 -0
- spike_encoding-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
class BaseConverter(torch.nn.Module):
    """Common interface shared by the spike converters in this package.

    Subclasses override ``encode`` (signal -> spikes), ``decode``
    (spikes -> signal) and ``optimize`` (fit converter parameters to data).
    The base implementations only raise ``NotImplementedError``.
    """

    def __init__(self):
        # Zero-argument super() resolves to the same MRO entry as the explicit
        # ``super(BaseConverter, self)`` form.
        super().__init__()

    def encode(self, tensor):
        """Convert ``tensor`` into spikes. Must be overridden."""
        raise NotImplementedError

    def decode(self, spikes: torch.Tensor):
        """Reconstruct a signal from ``spikes``. Must be overridden."""
        raise NotImplementedError

    def optimize(self, data: torch.Tensor):
        """Tune the converter's parameters for ``data``. Must be overridden."""
        raise NotImplementedError
|
|
@@ -0,0 +1,279 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import Tensor
|
|
3
|
+
from torchmetrics import MeanSquaredError
|
|
4
|
+
from scipy.signal import firwin
|
|
5
|
+
import optuna
|
|
6
|
+
from optuna.samplers import TPESampler
|
|
7
|
+
|
|
8
|
+
from spike_encoding.base_converter import BaseConverter
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class BensSpikerAlgorithm(BaseConverter):
    """Ben's Spiker Algorithm (BSA): FIR-filter based spike encoding.

    Every row (channel) of the input has its own FIR filter order, cutoff and
    spike threshold. ``encode`` stores the parameters it used (and, unless
    ``isNormed`` is set, the per-row normalisation statistics) on the
    instance; ``decode`` relies on that stored state.

    Args:
        threshold: Per-row spike threshold (scalar, list/tuple or tensor).
        down_spike: Stored on the instance; not used by this class.
        min_value, max_value, scale_factor: Per-row normalisation statistics.
            Overwritten by ``normalize_tensor`` whenever ``encode`` normalises.
        filter_order: Per-row FIR filter order (number of taps minus one).
        filter_cutoff: Per-row FIR cutoff as a fraction of the sampling rate
            (``fs=1.0``), so it must lie in ``(0, 0.5)``.

    Raises:
        ValueError: if a parameter has an unsupported type.
    """

    def __init__(
        self,
        threshold: torch.Tensor = torch.tensor(
            [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], dtype=torch.float
        ),
        down_spike: bool = True,
        min_value=None,
        max_value=None,
        scale_factor=None,
        filter_order: torch.Tensor = torch.tensor(
            [10, 10, 10, 10, 10, 10, 10], dtype=torch.int
        ),
        filter_cutoff: torch.Tensor = torch.tensor(
            [0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2], dtype=torch.float
        ),
    ):
        super().__init__()
        self.threshold = self._to_param_tensor(
            threshold,
            (float, int),
            torch.float,
            "threshold must be a float, int, list, tuple, or torch.Tensor",
        )
        self.down_spike = down_spike
        self.min_value = min_value
        self.max_value = max_value
        self.scale_factor = scale_factor
        self.filter_order = self._to_param_tensor(
            filter_order,
            (int,),
            torch.int,
            "filter_order must be an int, list, tuple or torch.Tensor",
        )
        self.filter_cutoff = self._to_param_tensor(
            filter_cutoff,
            (float, int),
            torch.float,
            "filter_cutoff must be a float, int, list, tuple, or torch.Tensor",
        )

    @staticmethod
    def _to_param_tensor(value, scalar_types, dtype, error_message):
        """Convert a scalar / list / tuple / tensor parameter into a 1-D tensor of ``dtype``."""
        if isinstance(value, scalar_types):
            return torch.tensor([value], dtype=dtype)
        if isinstance(value, (list, tuple)):
            return torch.tensor(value, dtype=dtype)
        if isinstance(value, torch.Tensor):
            # Copy so instances never alias the default tensors of the
            # signature (shared by every instance) or a caller's tensor.
            return value.to(dtype).clone()
        raise ValueError(error_message)

    def encode(
        self,
        signal: torch.Tensor,
        filter_order: torch.Tensor = torch.tensor([], dtype=torch.int),
        filter_cutoff: torch.Tensor = torch.tensor([], dtype=torch.float),
        threshold: torch.Tensor = torch.tensor([], dtype=torch.float),
        isNormed: bool = False,
    ):
        """
        Encodes a signal into spikes using an FIR filter and a threshold.

        Args:
            signal (torch.Tensor): The signal to encode: 1-D (one channel) or
                2-D (one row per channel); singleton dimensions are squeezed
                away. The caller's tensor is never modified.
            filter_order (torch.Tensor): Per-row filter order. Empty (the
                default) keeps the value stored on the instance.
            filter_cutoff (torch.Tensor): Per-row FIR cutoff. Empty keeps the
                stored value.
            threshold (torch.Tensor): Per-row spike threshold. Empty keeps the
                stored value.
            isNormed (bool): If True, ``signal`` is used as-is; otherwise each
                row is min-max normalised first and the statistics are stored
                for ``decode``.

        Returns:
            spikes (torch.Tensor): ``(rows, length)`` tensor of 0/1 spikes.

        Raises:
            ValueError: if fewer parameters than signal rows are available.
        """
        # Non-in-place squeeze: ``squeeze_`` would reshape the caller's tensor.
        signal = torch.atleast_2d(signal.squeeze())
        signal_norm = signal if isNormed else self.normalize_tensor(signal)

        # Explicit arguments replace the stored parameters; decode() reuses them.
        if filter_order.numel() != 0:
            self.filter_order = filter_order
        if filter_cutoff.numel() != 0:
            self.filter_cutoff = filter_cutoff
        if threshold.numel() != 0:
            self.threshold = threshold

        if (
            len(self.filter_order) < len(signal)
            or len(self.filter_cutoff) < len(signal)
            or len(self.threshold) < len(signal)
        ):
            raise ValueError(
                "filter_order, filter_cutoff and threshold must be a tensor of the same length as the input signal."
            )
        self.filter_order = self.filter_order[: len(signal)]
        self.filter_cutoff = self.filter_cutoff[: len(signal)]
        self.threshold = self.threshold[: len(signal)]

        fir_filters = self.fir_filter()
        num_rows, length = signal.shape
        spike_trains = torch.zeros(num_rows, length)

        for row in range(num_rows):
            # The algorithm compares the signal window starting at t with
            # taps 1..F-1 of the filter (tap 0 is not used).
            taps = fir_filters[row][1:]
            # Work on a copy: with isNormed=True, signal_norm *is* the caller's
            # data and the spike subtraction below would overwrite it.
            residual = signal_norm[row].clone()
            for t in range(1, length):
                window = residual[t : t + len(taps)]
                kernel = taps[: len(window)]  # truncated at the end of the signal
                err_with_spike = torch.sum(torch.abs(window - kernel))
                err_without_spike = torch.sum(torch.abs(window))
                if err_with_spike <= err_without_spike - self.threshold[row]:
                    spike_trains[row, t] = 1
                    residual[t : t + len(kernel)] -= kernel

        return spike_trains

    def decode(self, spikes: torch.Tensor):
        """Reconstruct signals from spike trains produced by ``encode``.

        Each spike train is convolved with its row's FIR filter. The result is
        then mapped back to the original value range with the normalisation
        statistics stored by ``encode``. When no statistics are stored (e.g.
        ``encode`` ran with ``isNormed=True``), the result is returned
        unscaled.

        Returns:
            list[list[float]]: One reconstructed signal per row.
        """
        spikes = torch.atleast_2d(spikes.squeeze())
        fir_filters = self.fir_filter()

        feature_size = spikes.shape[0]
        length = spikes.shape[1]
        reconstructed_signals = torch.zeros((feature_size, length))

        for row in range(feature_size):
            fir = fir_filters[row]
            taps = len(fir)
            padding = taps // 2
            padded_input = torch.cat(
                (torch.zeros(padding), spikes[row].detach(), torch.zeros(padding))
            ).to(fir.dtype)

            out = torch.zeros(length)
            if length > padding:
                # out[i] = sum_j padded_input[i - padding + 1 + j] * fir[j] for
                # i in [padding, length); window k = i - padding + 1 of unfold.
                windows = padded_input.unfold(0, taps, 1)[1 : length - padding + 1]
                out[padding:] = windows @ fir
            reconstructed_signals[row] = out

        if self.scale_factor is None or self.min_value is None:
            return [signal.tolist() for signal in reconstructed_signals]
        return [
            (reconstructed_signals[i] / self.scale_factor[i] + self.min_value[i]).tolist()
            for i in range(feature_size)
        ]

    def optimize(
        self,
        data: torch.Tensor,
        trials: int = 100,
        error_function=None,
        plot_history=False,
    ):
        """
        Optimize the filter order, the FIR cutoff and the threshold per signal row.

        Args:
            - data (torch.Tensor): The signal(s) to fit, 1-D or 2-D.
            - trials (int): Number of optuna trials per row.
            - error_function: Callable ``(prediction, target) -> loss``;
              defaults to a fresh ``torchmetrics.MeanSquaredError()``.
            - plot_history (bool): Show optuna history/slice plots per row.

        Returns:
            - (filter_order, filter_cutoff, threshold): the best parameters per
              row, which are also stored on the instance.

        Raises:
            ValueError: if ``data`` has more than 2 dimensions.
        """
        if error_function is None:
            # A fresh metric per call: a default instance in the signature
            # would be shared (and accumulate state) across every call.
            error_function = MeanSquaredError()

        if data.ndim > 2:
            raise ValueError(
                f"{type(self).__name__}.{self.optimize.__name__}() only supports input tensors with dimension <=2, but got dimension {data.ndim}."
            )
        data = torch.atleast_2d(data)
        num_signals = len(data)
        f_order = torch.zeros(num_signals)
        f_cutoff = torch.zeros(num_signals)
        _threshold = torch.zeros(num_signals)

        for i in range(num_signals):
            signal_row = data[i]

            def loss_function(trial):
                filter_order = trial.suggest_int("filter_order", 35, 50)
                filter_cutoff = trial.suggest_float(
                    "filter_cutoff", 0.001, 0.25
                )  # 0 < cutoff < fs/2
                threshold = trial.suggest_float("threshold", 0.4, 1.1)
                encoded_signal = self.encode(
                    signal_row,
                    filter_order=torch.tensor([filter_order], dtype=torch.int),
                    filter_cutoff=torch.tensor([filter_cutoff], dtype=torch.float),
                    threshold=torch.tensor([threshold], dtype=torch.float),
                    isNormed=False,
                )
                decoded_signal = self.decode(encoded_signal)
                return error_function(Tensor(decoded_signal)[0], signal_row)

            print(f"Optimizing for signal {i+1}/{len(data)}:")

            optuna.logging.set_verbosity(optuna.logging.WARNING)  # hide logging
            study = optuna.create_study(direction="minimize", sampler=TPESampler())
            study.optimize(loss_function, n_trials=trials, show_progress_bar=True)
            f_order[i] = study.best_params["filter_order"]
            f_cutoff[i] = study.best_params["filter_cutoff"]
            _threshold[i] = study.best_params["threshold"]

            print(
                f"best_order={f_order[i]}, best_cutoff={f_cutoff[i]}, best_threshold={_threshold[i]}"
            )

            if plot_history:
                # Plot the optimization values
                optuna.visualization.plot_optimization_history(study).show()
                optuna.visualization.plot_slice(
                    study, params=["filter_order", "filter_cutoff", "threshold"]
                ).show()

        self.filter_order = f_order.int()
        self.filter_cutoff = f_cutoff.float()
        self.threshold = _threshold.float()

        return self.filter_order, self.filter_cutoff, self.threshold

    def normalize_tensor(self, tensor):
        """Min-max normalise each row of a 2-D tensor to the range [0, 1].

        Stores the per-row ``min_value``, ``max_value`` and ``scale_factor`` on
        the instance so ``decode`` can undo the scaling. Constant rows
        (max == min) map to 0 instead of NaN.
        """
        self.min_value = torch.min(tensor, dim=1)[0]
        self.max_value = torch.max(tensor, dim=1)[0]
        value_range = self.max_value - self.min_value
        # A zero range would give 0 * inf = NaN; use 1 so constant rows become 0.
        value_range = torch.where(
            value_range == 0, torch.ones_like(value_range), value_range
        )
        self.scale_factor = 1.0 / value_range

        normalized = self.scale_factor[:, None] * (tensor - self.min_value[:, None])
        # The result always uses the default float dtype, whatever the input dtype.
        return normalized.to(torch.get_default_dtype())

    def fir_filter(self):
        """Build one low-pass FIR filter per stored ``filter_order`` entry.

        Returns:
            list[torch.Tensor]: float32 taps; row ``i`` has ``filter_order[i] + 1`` taps.
        """
        return [
            torch.tensor(
                # fs=1.0 for normalized frequencies
                firwin(
                    self.filter_order[i].item() + 1,
                    self.filter_cutoff[i].item(),
                    fs=1.0,
                ),
                dtype=torch.float32,
            )
            for i in range(len(self.filter_order))
        ]
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
from typing import List, Literal, Union
|
|
2
|
+
import numpy as np
|
|
3
|
+
from numpy.typing import NDArray as ndarray
|
|
4
|
+
import numpy.random as rndm
|
|
5
|
+
|
|
6
|
+
from spike_encoding.base_converter import BaseConverter
|
|
7
|
+
from spike_encoding.gymnasium_bounds_finder import ScalerFactory
|
|
8
|
+
from spike_encoding.encoder_common import poisson, rate
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def gaussian_response(x, mu, sigma=0.3):
    """Gaussian tuning curve centred at ``mu`` with peak 1 and width ``sigma``.

    Works element-wise on scalars and on broadcastable NumPy arrays.
    """
    return np.exp(-np.power(x - mu, 2.0) / (2 * np.power(sigma, 2.0)))


def transform_firing_rates(firing_rates, n_bins, sigma=0.3):
    """Spread each firing rate over ``n_bins`` Gaussian bins.

    Bin centres are evenly spaced on ``[0, 1]`` (rates are assumed to be scaled
    to that range). Each rate yields ``n_bins`` responses normalised to sum to
    1. The results for all rates are concatenated into one flat array of
    length ``len(firing_rates) * n_bins``.

    When a rate lies so far from every centre that all its responses underflow
    to 0, normalising would give NaN (0/0). Such a rate puts all its mass on
    the nearest bin instead.
    """
    bin_centers = np.linspace(0, 1, n_bins)
    rates = np.asarray(firing_rates, dtype=float).reshape(-1, 1)
    # One row per rate, one column per bin.
    responses = gaussian_response(rates, bin_centers, sigma=sigma)
    totals = responses.sum(axis=1, keepdims=True)

    degenerate = totals[:, 0] == 0
    if bin_centers.size > 0 and np.any(degenerate):
        nearest = np.abs(rates[degenerate] - bin_centers).argmin(axis=1)
        responses[degenerate] = 0.0
        responses[np.flatnonzero(degenerate), nearest] = 1.0
        totals[degenerate] = 1.0

    return (responses / totals).ravel()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class BinEncoder(BaseConverter):
    """
    This encoder creates spike trains from gymnasium observations by spreading each scaled observation value over ``n_bins`` Gaussian bins. To use it, first install gymnasium https://gymnasium.farama.org/

    Args:
        seq_length: The number of timesteps in the spike train. E.g. [[1], [0], [1]] has a seq_length of 3
        min_values: Known minimum of every observation feature. The internal MinMaxScaler is fitted on these and ``max_values``, so a value equal to the minimum maps to a firing rate of 0.0
        max_values: Known maximum of every observation feature; a value equal to the maximum maps to a firing rate of 1.0
        spike_train_conversion_method: determines how a firing rate is converted into a spike train. In poisson encoding, a firing rate of 0.1 means there is a 10% chance for any given timestep to be a spike. By chance there could be more or fewer spikes. If instead "deterministic" is chosen (any value other than "poisson" selects it), you are guaranteed that 10% of timesteps are spikes
        n_bins: number of Gaussian bins each feature is spread over
        max_firing_rate: multiplier for maximum firing rate. typically the maximum is 1.0 (i.e. spikes at every step). You can set this to a lower value like 0.5 so on average you will only get a spike at every other step. This may be important for some R-STDP scenarios, where very high firing rates can impact synaptic tags
        sigma: width of each Gaussian bin

    Returns:
        spike train: an array of arrays of spikes.
    """

    spike_train_conversion_method: Literal["poisson", "deterministic"] = "poisson"

    def __init__(
        self,
        seq_length: int,
        min_values: Union[List[float], ndarray],
        max_values: Union[List[float], ndarray],
        spike_train_conversion_method: Literal["poisson", "deterministic"] = "poisson",
        n_bins=10,
        max_firing_rate=1.0,
        sigma=0.1,
    ):
        # BaseConverter is a torch.nn.Module: it must be initialised, otherwise
        # repr(), .to(), state_dict() etc. fail on missing internal attributes.
        super().__init__()
        self.sigma = sigma
        self.seq_length = seq_length
        self.scaler = ScalerFactory().from_known_values(min_values, max_values)
        self.n_bins = n_bins
        self.spike_train_conversion_method = spike_train_conversion_method
        self.max_firing_rate = max_firing_rate

        # NOTE: this reseeds NumPy's *global* RNG (which ``poisson`` draws
        # from) every time an encoder is constructed.
        self.seed = 42
        rndm.seed(self.seed)

    def encode(self, state: ndarray) -> ndarray:
        """Encode one observation into spikes.

        Returns an array of shape ``(seq_length, 1, n_features * n_bins)``.
        """
        # NOTE this uses batches for scaling and coding, but not for binning
        p_spikes = self.scaler.transform(np.atleast_2d(state))[0]
        p_bins = np.atleast_2d(
            [transform_firing_rates(p_spikes, self.n_bins, self.sigma)]
        )
        p_bins *= self.max_firing_rate

        if self.spike_train_conversion_method == "poisson":
            return poisson(p_bins, self.seq_length)  # type: ignore
        return rate(p_bins, self.seq_length)  # type: ignore
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from .base_converter import BaseConverter
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class DeltaModulationConverter(BaseConverter):
    """Encodes a tensor of input values by the difference between two subsequent timesteps, and emits a spike
    whenever that difference reaches the threshold ``delta``.

    Parameters:
        delta (float): A change of at least ``delta`` across one timestep generates a spike, defaults to ``0.1``
        normalized (bool): If ``False``, each row (index along dim 0) is min-max normalised to ``[0, 1]`` before encoding; constant rows become ``0``'s. The caller's tensor is never modified. Defaults to ``True``
        padding (bool): If ``True``, the first time step will be compared with itself resulting in ``0``'s in spikes. If ``False``, it will be padded with ``0``'s, defaults to ``False``
        off_spike: If ``True``, changes of ``-delta`` or less produce ``-1`` off-spikes, defaults to ``False``

    ``forward``/``encode`` return ``torch.stack((on_spikes, off_spikes))``; the off-spike half is all zeros when ``off_spike`` is ``False``.
    """

    def __init__(
        self,
        delta: float = 0.1,
        normalized: bool = True,
        padding: bool = False,
        off_spike: bool = False,
    ):
        super().__init__()
        self.delta = delta
        self.normalized = normalized
        self.padding = padding
        self.off_spike = off_spike

    @staticmethod
    def _normalize_rows(tensor):
        """Return a min-max normalised copy of ``tensor``, one scale per index along dim 0."""
        if not tensor.is_floating_point():
            # Writing floats into an integer tensor would truncate them.
            tensor = tensor.float()
        reduce_dims = tuple(range(1, tensor.ndim))
        lowest = tensor.amin(dim=reduce_dims, keepdim=True)
        span = tensor.amax(dim=reduce_dims, keepdim=True) - lowest
        # Constant rows would compute 0/0 = NaN; map them to 0 instead.
        span = torch.where(span == 0, torch.ones_like(span), span)
        return (tensor - lowest) / span

    def forward(self, tensor):
        if tensor.ndimension() < 2:
            tensor = tensor.unsqueeze(0)

        if not self.normalized:
            tensor = self._normalize_rows(tensor)

        first_step = tensor[:, :1]
        if self.padding:
            # duplicate first time step, remove final step
            data_offset = torch.cat((first_step, tensor), dim=1)[:, :-1]
        else:
            # add 0's to first step, remove final step
            data_offset = torch.cat((torch.zeros_like(first_step), tensor), dim=1)[:, :-1]

        change = tensor - data_offset
        on_spk = torch.ones_like(tensor) * (change >= self.delta)
        if not self.off_spike:
            return torch.stack((on_spk, torch.zeros_like(tensor)))
        off_spk = -torch.ones_like(tensor) * (change <= -self.delta)
        return torch.stack((on_spk, off_spk))

    def encode(self, tensor):
        """Alias for ``forward`` so the converter satisfies the ``BaseConverter`` interface."""
        return self.forward(tensor)

    def decode(self, spikes: torch.Tensor):
        raise NotImplementedError
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from numpy.typing import NDArray as ndarray
|
|
2
|
+
import numpy as np
|
|
3
|
+
import numpy.random as rndm
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def poisson(spike_probabilities: ndarray, max_t: int) -> ndarray:
    """Sample Bernoulli spike trains from per-element spike probabilities.

    Args:
        spike_probabilities: Array (or array-like) of spike probabilities.
        max_t: Number of timesteps to sample.

    Returns:
        Integer 0/1 array of shape ``(max_t, *spike_probabilities.shape)``.
        Entry ``[t, ...]`` is 1 when ``spike_probabilities[...]`` exceeds a
        uniform draw from NumPy's global RNG.

    A single vectorised draw consumes the global RNG stream exactly like one
    ``spike_probabilities.shape`` draw per timestep, so seeded results are
    unchanged. For ``max_t == 0`` the shape is ``(0, *shape)``, matching
    ``rate``.
    """
    probabilities = np.asarray(spike_probabilities)
    draws = rndm.random((int(max_t), *probabilities.shape))
    return (probabilities > draws).astype(int)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def rate(spike_probabilities: ndarray, max_t: int) -> ndarray:
    """Deterministic ("rate") spike trains.

    For each element with probability ``p > 0``, spikes are placed about
    ``1 / p`` steps apart, counting backwards from the end of the
    ``max_t``-step window. Elements with ``p <= 0`` never spike.

    Returns:
        Integer 0/1 array of shape ``(max_t, *spike_probabilities.shape)``.
    """
    trains = np.zeros((max_t, *spike_probabilities.shape), dtype=int)
    last_step = max_t - 1

    for position in np.ndindex(spike_probabilities.shape):
        p = spike_probabilities[position]
        # max_t*p, max_t*p - 1, ..., scaled by 1/p -> spike times, made 0-based.
        scaled_times = np.arange(max_t * p, 0, -1) / p
        spike_times = np.array(scaled_times, dtype=int) - 1
        spike_times = np.clip(spike_times, 0, last_step)  # keep inside the window
        trains[(spike_times, *position)] = 1

    return trains
|
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
from typing import List, Union
|
|
2
|
+
import gymnasium as gym
|
|
3
|
+
import tqdm
|
|
4
|
+
from joblib import Parallel, delayed
|
|
5
|
+
import joblib
|
|
6
|
+
from sklearn.preprocessing import MinMaxScaler
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
import numpy as np
|
|
9
|
+
import os
|
|
10
|
+
from numpy.typing import NDArray as ndarray
|
|
11
|
+
|
|
12
|
+
from .rate_step_forward_converter import RateStepForwardConverter as Converter
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def len_2(arr: ndarray) -> int:
    """Return the size of ``arr`` along its second axis (used as a sort key for runs)."""
    shape = arr.shape
    return shape[1]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_runs(env_name: str, num_steps: int, workers: int, print_updates=False) -> list:
    """Collect observation trajectories from episodes driven by random actions.

    Args:
        env_name: gymnasium environment id passed to ``gym.make``.
        num_steps: Total number of episodes to run.
        workers: Number of joblib tasks the episodes are split across
            (values below 1 are treated as 1; 1 runs serially).
        print_updates: Print a status line before starting.

    Returns:
        list of arrays, one per episode, shaped ``(n_features, episode_length)``.
    """
    if print_updates:
        print("Generating " + str(num_steps) + " random runs...")

    workers = max(1, int(workers))
    # Spread the episodes so the total is exactly num_steps; the previous
    # int(num_steps / workers) per worker silently dropped the remainder.
    base, extra = divmod(max(0, int(num_steps)), workers)
    episode_counts = [base + (1 if w < extra else 0) for w in range(workers)]
    episode_counts = [count for count in episode_counts if count > 0]

    def _collect(n_episodes: int) -> list:
        """Run ``n_episodes`` random episodes in a fresh environment."""
        env = gym.make(env_name)
        runs = []
        try:
            for _ in tqdm.tqdm(range(n_episodes)):
                # Get the initial state of the environment
                states = []
                state, _ = env.reset()
                done = trunc = False
                while not (done or trunc):
                    states.append(state)
                    action = env.action_space.sample()
                    state, _, done, trunc, _ = env.step(action)
                # (time, features) -> (features, time)
                runs.append(np.array(states).swapaxes(0, 1))
        finally:
            env.close()  # release simulator / renderer resources
        return runs

    if not episode_counts:
        return []
    if len(episode_counts) == 1:
        return _collect(episode_counts[0])

    results = Parallel(n_jobs=-1)(delayed(_collect)(count) for count in episode_counts)
    return [run for batch in (results or []) for run in batch]
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def save_scaler(path: str, scaler: MinMaxScaler, print_updates=False) -> bool:
    """Persist ``scaler`` to ``path`` with ``joblib.dump`` and return True."""
    joblib.dump(scaler, path)
    if not print_updates:
        return True
    print("Saved scaler at " + path)
    return True
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def load_presaved_env_data(path: str, print_updates=False) -> MinMaxScaler:
    """Load a scaler previously written by ``save_scaler``.

    NOTE(review): ``joblib.load`` unpickles the file, so only load files this
    program wrote itself, never untrusted input.
    """
    loaded_scaler = joblib.load(path)
    if print_updates:
        print("Loaded scaler at " + path)
    return loaded_scaler
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def save_treshold(path: str, treshold: Tensor, print_updates=False) -> bool:
    """Write ``treshold`` to ``path`` as the text of a Python list and return True.

    Any object with ``.tolist()`` works; ``ConverterFactory`` passes a NumPy array.
    """
    with open(path, "w") as handle:
        handle.write(f"{treshold.tolist()}")

    if print_updates:
        print("Saved thresholds at " + path)
    return True
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def load_treshold(path: str, print_updates=False) -> Tensor:
    """Read thresholds written by ``save_treshold``.

    The last non-blank line of the file is parsed as a list like
    ``[a, b, ...]``. An empty list yields an empty tensor.

    Raises:
        ValueError: if the file contains no bracketed threshold list.
    """
    with open(path, "r") as f:
        lines = [line for line in f if line.strip()]
    if not lines or "[" not in lines[-1] or "]" not in lines[-1]:
        raise ValueError("No thresholds found in " + path)

    inner = lines[-1].split("[", 1)[1].split("]", 1)[0]
    # str([]) == "[]": skip blank fields instead of calling float("").
    threshold = [float(nr) for nr in inner.split(",") if nr.strip()]
    if print_updates:
        print("Loaded thresholds at " + path)
    return Tensor(threshold)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def get_path_from_dirs(dirs: list) -> str:
    """Join ``dirs`` under the current working directory, creating each missing level.

    Returns the resulting path (the cwd itself when ``dirs`` is empty).
    """
    current = os.getcwd()
    for name in dirs:
        current = os.path.join(current, name)
        if os.path.exists(current):
            continue
        os.mkdir(current)
    return current
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class ScalerFactory:
    """Builds ``MinMaxScaler`` objects for observation values.

    Scalers are fitted either on known bounds (``from_known_values``) or on
    observations from random episodes (``from_env``). In the latter case the
    scaler is cached at ``<cwd>/data/envs/<env id>/scaler.save``.
    """

    workers = 64
    num_steps = 1000

    def __init__(
        self,
        print_updates=False,
    ):
        self.print_updates = print_updates

    def _set_env_paths(self, env: gym.Env) -> None:
        """Point ``env_path``/``scaler_path`` at the cache directory for ``env``."""
        assert env.spec is not None
        self.env_path = get_path_from_dirs(["data", "envs", env.spec.id])
        self.scaler_path = os.path.join(self.env_path, "scaler.save")

    def has_presaved_env_data(self) -> bool:
        """True if a cached scaler exists at ``scaler_path`` (False if no env was set yet)."""
        scaler_path = getattr(self, "scaler_path", None)
        return scaler_path is not None and os.path.exists(scaler_path)

    def from_env(self, env: gym.Env):
        """Return the cached scaler for ``env``, fitting and caching one if missing."""
        self._set_env_paths(env)
        if not self.has_presaved_env_data():
            return self.run_env(env)
        return load_presaved_env_data(self.scaler_path, self.print_updates)

    def from_known_values(
        self, minima: Union[List[float], ndarray], maxima: Union[List[float], ndarray]
    ):
        """Fit a scaler that maps ``minima`` to 0 and ``maxima`` to 1 per feature."""
        scaler = MinMaxScaler()
        data = np.concatenate(([minima], [maxima]), axis=0)
        scaler.fit(data)
        return scaler

    def run_env(self, env: gym.Env):
        """Fit a scaler on random episodes of ``env`` and cache it on disk."""
        assert env.spec is not None
        # Works even if from_env() was not called first.
        self._set_env_paths(env)
        runs = get_runs(
            env.spec.id,
            self.num_steps * 10,
            workers=self.workers,
            print_updates=self.print_updates,
        )
        scaler = self._optimize_scaler(runs)
        save_scaler(self.scaler_path, scaler, self.print_updates)
        return scaler

    def _optimize_scaler(self, runs: list) -> MinMaxScaler:
        """Fit a scaler on symmetric per-feature bounds +-max(|min|, max) of ``runs``.

        Raises:
            ValueError: if ``runs`` is empty.
        """
        if not runs:
            raise ValueError("Cannot fit a scaler without any runs.")
        if self.print_updates:
            print("Fitting scaler...")
        # Fit data between 0 and 1; rows are timesteps, columns are features.
        concat_runs = np.concatenate(runs, axis=1).swapaxes(0, 1)
        # Balance upper and lower bounds so that, for a non-zero range, 0 maps to 0.5.
        upper = np.maximum(np.abs(concat_runs.min(axis=0)), concat_runs.max(axis=0))
        lower = -upper
        concat_runs = np.vstack([concat_runs, upper, lower])

        # absolute values generate spikes from the first time steps

        # Multiplier for modifying the upper and lower bounds
        # Needs to be greater than 0
        # A value between 0 and 1 makes the converter more sensitive, i.e. it will generate more spikes
        # A value greater than 1 makes the converter less sensitive, i.e. it will generate fewer spikes
        max_modifier: float = 1
        concat_runs = concat_runs * max_modifier
        return MinMaxScaler().fit(concat_runs)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class ConverterFactory:
    """Creates a ``RateStepForwardConverter`` with thresholds tuned for an environment.

    Tuned thresholds are cached at ``<cwd>/data/envs/<env id>/th.txt``. Later
    ``generate()`` calls rebuild the converter from that file.
    """

    workers = 64
    num_steps = 1000

    def __init__(self, env: gym.Env, scaler: MinMaxScaler, print_updates=False):
        self.env = env
        self.print_updates = print_updates
        self.scaler = scaler
        assert env.observation_space.shape is not None
        self.features = env.observation_space.shape[0]
        assert env.spec is not None
        self.env_path = get_path_from_dirs(["data", "envs", env.spec.id])
        self.enc_path = os.path.join(self.env_path, "th.txt")
        self.runs = []

    def initialize(self, runs):
        """Tune a converter on ``runs`` (or freshly generated runs if empty) and cache its thresholds.

        Returns:
            (converter, thresholds as a NumPy array)
        """
        if len(runs) == 0:
            assert self.env.spec is not None
            runs = get_runs(
                self.env.spec.id,
                self.num_steps,
                workers=self.workers,
                print_updates=self.print_updates,
            )
        else:
            runs = runs[: self.num_steps]

        if self.print_updates:
            print("Initializing encoder...")
        conv = self._optimize_converter(runs)
        th_numpy = conv.threshold.detach().numpy()
        save_treshold(self.enc_path, th_numpy, self.print_updates)
        return conv, th_numpy

    def generate(self):
        """Return ``(converter, thresholds)``, loading cached thresholds when available."""
        if not self.initialized():
            return self.initialize(self.runs)
        th_numpy = load_treshold(self.enc_path).detach().numpy()
        if self.print_updates:
            print("Loaded existing encoder initialization at " + self.enc_path)
        return Converter(self.features, th_numpy), th_numpy

    def initialized(self) -> bool:
        """True if a thresholds file already exists for this environment."""
        return os.path.exists(self.enc_path)

    def _optimize_converter(self, runs: list) -> Converter:
        """Scale, pad and stack ``runs``, then optimise a converter's thresholds on them.

        Raises:
            ValueError: if ``runs`` is empty.
        """
        if not runs:
            raise ValueError("Cannot optimize a converter without any runs.")
        # Sort runs by length; sorted() leaves the caller's list untouched.
        runs = sorted(runs, key=len_2)

        if self.print_updates:
            print("Normalizing data")
        runs = [
            self.scaler.transform(arr.swapaxes(0, 1)).swapaxes(0, 1) for arr in runs
        ]

        if self.print_updates:
            print("Reshaping tensors...")
        # Fill all arrays to the same length with +inf and stack along a run dimension.
        max_len = runs[-1].shape[1]
        runs_tensor = Tensor(
            np.stack(
                [
                    np.append(
                        arr,
                        np.full((self.features, max_len - arr.shape[1]), np.inf),
                        axis=1,
                    )
                    for arr in runs
                ]
            )
        )

        converter = Converter(self.features, np.ones(self.features) * 0.01)
        if self.print_updates:
            print("Optimizing converter tresholds...")
        converter.optimize(runs_tensor)
        return converter
|