spike-encoding 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2 @@
1
+ from .delta_modulation_converter import DeltaModulationConverter
2
+ from .step_forward_converter import StepForwardConverter
@@ -0,0 +1,14 @@
1
+ import torch
2
+
3
class BaseConverter(torch.nn.Module):
    """Abstract interface for signal-to-spike converters.

    Subclasses implement :meth:`encode`, :meth:`decode` and
    :meth:`optimize`; this base class only hooks the converter into
    ``torch.nn.Module``.
    """

    def __init__(self):
        super().__init__()

    def encode(self, tensor):
        """Convert an input tensor into a spike representation."""
        raise NotImplementedError

    def decode(self, spikes: torch.Tensor):
        """Reconstruct a signal from its spike representation."""
        raise NotImplementedError

    def optimize(self, data: torch.Tensor):
        """Tune converter parameters against reference data."""
        raise NotImplementedError
@@ -0,0 +1,279 @@
1
+ import torch
2
+ from torch import Tensor
3
+ from torchmetrics import MeanSquaredError
4
+ from scipy.signal import firwin
5
+ import optuna
6
+ from optuna.samplers import TPESampler
7
+
8
+ from spike_encoding.base_converter import BaseConverter
9
+
10
+
11
class BensSpikerAlgorithm(BaseConverter):
    """Ben's Spiker Algorithm (BSA) converter.

    Encodes each channel (row) of a signal into a spike train by comparing
    the signal against an FIR low-pass filter response, and decodes by
    convolving the spike train with the same filter.

    Args:
        threshold: per-channel spiking threshold (scalar-like inputs are
            promoted to a 1-element tensor).
        down_spike: stored but not used by the code visible here.
        min_value / max_value / scale_factor: per-channel normalization
            statistics; filled in by :meth:`normalize_tensor` and used by
            :meth:`decode` to undo the scaling.
        filter_order: per-channel FIR filter order.
        filter_cutoff: per-channel normalized cutoff frequency.

    NOTE(review): the tensor default arguments are created once at class
    definition time and shared; they are never mutated here, but treat
    them as read-only.
    """

    def __init__(
        self,
        threshold: torch.Tensor = torch.tensor(
            [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], dtype=torch.float
        ),
        down_spike: bool = True,
        min_value=None,
        max_value=None,
        scale_factor=None,
        filter_order: torch.Tensor = torch.tensor(
            [10, 10, 10, 10, 10, 10, 10], dtype=torch.int
        ),
        filter_cutoff: torch.Tensor = torch.tensor(
            [0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2], dtype=torch.float
        ),
    ):
        super().__init__()
        # Coerce every parameter to a 1-D tensor of the expected dtype so
        # the per-channel indexing below always works.
        if isinstance(threshold, (float, int)):
            self.threshold = torch.tensor([threshold], dtype=torch.float)
        elif isinstance(threshold, (list, tuple)):
            self.threshold = torch.tensor(threshold, dtype=torch.float)
        elif isinstance(threshold, torch.Tensor):
            self.threshold = threshold.float()
        else:
            raise ValueError(
                "threshold must be a float, int, list, tuple, or torch.Tensor"
            )
        self.down_spike = down_spike
        self.min_value = min_value
        self.max_value = max_value
        self.scale_factor = scale_factor
        if isinstance(filter_order, int):
            self.filter_order = torch.tensor([filter_order], dtype=torch.int)
        elif isinstance(filter_order, (list, tuple)):
            self.filter_order = torch.tensor(filter_order, dtype=torch.int)
        elif isinstance(filter_order, torch.Tensor):
            self.filter_order = filter_order.int()
        else:
            raise ValueError("filter_order must be an int, list, tuple or torch.Tensor")

        if isinstance(filter_cutoff, (float, int)):
            self.filter_cutoff = torch.tensor([filter_cutoff], dtype=torch.float)
        elif isinstance(filter_cutoff, (list, tuple)):
            self.filter_cutoff = torch.tensor(filter_cutoff, dtype=torch.float)
        elif isinstance(filter_cutoff, torch.Tensor):
            self.filter_cutoff = filter_cutoff.float()
        else:
            raise ValueError(
                "filter_cutoff must be a float, int, list, tuple, or torch.Tensor"
            )

    def encode(
        self,
        signal: torch.Tensor,
        filter_order: torch.Tensor = torch.tensor([], dtype=torch.int),
        filter_cutoff: torch.Tensor = torch.tensor([], dtype=torch.float),
        threshold: torch.Tensor = torch.tensor([], dtype=torch.float),
        isNormed: bool = False,
    ):
        """
        Encodes a signal into spikes using an FIR filter and a threshold.

        Args:
            signal (torch.Tensor): the signal to be encoded, one row per channel.
            filter_order (torch.Tensor): optional per-channel filter-order override.
            filter_cutoff (torch.Tensor): optional per-channel cutoff override.
            threshold (torch.Tensor): optional per-channel threshold override.
            isNormed (bool): set True when `signal` is already scaled to [0, 1].

        Returns:
            spikes (torch.Tensor): the encoded spike trains, same shape as `signal`.

        Note:
            Non-empty overrides are written back to the instance so that a
            following `decode()` call uses the same filter (optimize()
            relies on this side effect).
        """
        # BUGFIX: the original called signal.squeeze_(), mutating the
        # caller's tensor in place; use the out-of-place variant instead.
        signal = torch.atleast_2d(signal.squeeze())
        if isNormed:
            # BUGFIX: clone so the BSA subtraction loop below cannot write
            # through to the caller's data.
            signal_norm = signal.clone()
        else:
            signal_norm = self.normalize_tensor(signal)

        self.filter_order = (
            self.filter_order if filter_order.numel() == 0 else filter_order
        )
        self.filter_cutoff = (
            self.filter_cutoff if filter_cutoff.numel() == 0 else filter_cutoff
        )
        self.threshold = self.threshold if threshold.numel() == 0 else threshold

        if (
            len(self.filter_order) < len(signal)
            or len(self.filter_cutoff) < len(signal)
            or len(self.threshold) < len(signal)
        ):
            raise ValueError(
                "filter_order, filter_cutoff and threshold must be a tensor of the same length as the input signal."
            )
        # Keep exactly one parameter per channel.
        self.filter_order = self.filter_order[: len(signal)]
        self.filter_cutoff = self.filter_cutoff[: len(signal)]
        self.threshold = self.threshold[: len(signal)]

        FIR = self.fir_filter()
        num_rows, L = signal.shape

        spike_trains = torch.zeros(num_rows, L)

        for row in range(num_rows):
            F = len(FIR[row])

            s = signal_norm[row]

            out = torch.zeros(L)
            for t in range(1, L):
                err1 = 0
                err2 = 0

                # Error of "emit a spike here" vs "stay silent" over the
                # remaining filter window.
                for k in range(1, F):
                    if t + k - 1 < L:
                        err1 += torch.abs(s[t + k - 1] - FIR[row][k])
                        err2 += torch.abs(s[t + k - 1])

                if err1 <= err2 - self.threshold[row]:
                    out[t] = 1
                    # Subtract the filter response from the working copy so
                    # later time steps see the residual signal.
                    for k in range(1, F):
                        if t + k - 1 < L:
                            s[t + k - 1] -= FIR[row][k]

            spike_trains[row] = out

        return spike_trains

    def decode(self, spikes: torch.Tensor):
        """
        Reconstructs the signal by FIR-filtering the spike trains and
        undoing the normalization recorded by `normalize_tensor`.

        Args:
            spikes (torch.Tensor): spike trains as produced by `encode`.

        Returns:
            list[list[float]]: one reconstructed signal per channel.
        """
        # BUGFIX: avoid the in-place squeeze_() on the caller's tensor.
        spikes = torch.atleast_2d(spikes.squeeze())

        FIR = self.fir_filter()

        feature_size = spikes.shape[0]
        L = spikes.shape[1]

        reconstructed_signals = torch.zeros((feature_size, L))

        for row in range(feature_size):
            spike_train = spikes[row].clone().detach()

            out = torch.zeros(L)
            F = len(FIR[row])
            padding = F // 2
            padded_input = torch.cat(
                (torch.zeros(padding), spike_train, torch.zeros(padding))
            )

            # Zero-padded FIR convolution of the spike train.
            for i in range(padding, L):
                for j in range(F):
                    if i + j < L + padding * 2:
                        out[i] += padded_input[i + j - padding + 1] * FIR[row][j]

            reconstructed_signals[row] = out

        result_vector = []
        for i in range(feature_size):
            # Undo the [0, 1] scaling applied during encoding.
            # NOTE(review): requires scale_factor/min_value to be set, i.e.
            # normalize_tensor must have run (or been passed to __init__)
            # before decoding — confirm callers guarantee this.
            scaled_signal = (
                reconstructed_signals[i] / self.scale_factor[i] + self.min_value[i]
            )
            result_vector.append(scaled_signal.tolist())

        return result_vector

    def optimize(
        self,
        data: torch.Tensor,
        trials: int = 100,
        error_function=MeanSquaredError(),
        plot_history=False,
    ):
        """
        Optimize threshold, filter order and FIR cutoff per channel with Optuna.

        Args:
            data (torch.Tensor): the signal(s) to optimize against, at most 2-D.
            trials (int): Optuna trials per channel.
            error_function: metric comparing decoded signal to the original.
            plot_history (bool): show Optuna optimization plots per channel.

        Returns:
            (filter_order, filter_cutoff, threshold): per-channel optimized
            tensors (also stored on the instance).
        """

        if data.ndim > 2:
            raise ValueError(
                f"{type(self).__name__}.{self.optimize.__name__}() only supports input tensors with dimension <=2, but got dimension {data.ndim}."
            )
        data = torch.atleast_2d(data)
        synth_sig = data  # self.normalize_tensor(data)
        f_order = torch.zeros(len(synth_sig))
        f_cutoff = torch.zeros(len(synth_sig))
        _threshold = torch.zeros(len(synth_sig))

        for i in range(len(synth_sig)):

            def loss_function(trial):
                filter_order = trial.suggest_int("filter_order", 35, 50)
                filter_cutoff = trial.suggest_float(
                    "filter_cutoff", 0.001, 0.25
                )  # 0 < cutoff < fs/2
                threshold = trial.suggest_float("threshold", 0.4, 1.1)
                filter_order_T = torch.tensor([filter_order], dtype=torch.int)
                filter_cutoff_T = torch.tensor([filter_cutoff], dtype=torch.float)
                threshold_T = torch.tensor([threshold], dtype=torch.float)
                encoded_signal = self.encode(
                    synth_sig[i],
                    filter_order=filter_order_T,
                    filter_cutoff=filter_cutoff_T,
                    threshold=threshold_T,
                    isNormed=False,
                )
                # decode() reuses the parameters encode() wrote to self.
                decoded_signal = self.decode(encoded_signal)
                loss = error_function(Tensor(decoded_signal)[0], data[i])

                return loss

            print(f"Optimizing for signal {i+1}/{len(synth_sig)}:")

            optuna.logging.set_verbosity(optuna.logging.WARNING)  # hide logging
            study = optuna.create_study(direction="minimize", sampler=TPESampler())
            study.optimize(loss_function, n_trials=trials, show_progress_bar=True)
            f_order[i] = study.best_params["filter_order"]
            f_cutoff[i] = study.best_params["filter_cutoff"]
            _threshold[i] = study.best_params["threshold"]

            print(
                f"best_order={f_order[i]}, best_cutoff={f_cutoff[i]}, best_threshold={_threshold[i]}"
            )

            if plot_history:
                # Plot the optimization values
                optuna.visualization.plot_optimization_history(study).show()
                optuna.visualization.plot_slice(
                    study, params=["filter_order", "filter_cutoff", "threshold"]
                ).show()

        self.filter_order = f_order.int()
        self.filter_cutoff = f_cutoff.float()
        self.threshold = _threshold.float()

        return self.filter_order, self.filter_cutoff, self.threshold

    def normalize_tensor(self, tensor):
        """Normalize each row of *tensor* to the range [0, 1], recording
        per-row min/max/scale on the instance for later use by decode()."""
        self.min_value = torch.min(tensor, dim=1)[0]
        self.max_value = torch.max(tensor, dim=1)[0]
        # NOTE(review): a constant row gives max == min and a division by
        # zero here — confirm inputs are non-constant.
        self.scale_factor = 1.0 / (self.max_value - self.min_value)

        normalized_tensor = torch.zeros(tensor.shape)
        for i in range(len(tensor)):
            normalized_tensor[i] = self.scale_factor[i] * (
                tensor[i] - self.min_value[i]
            )
        return normalized_tensor

    def fir_filter(self):
        """Set up the per-channel FIR low-pass filter coefficients."""
        filter_coeffs = []
        for i in range(len(self.filter_order)):
            fir = firwin(
                self.filter_order[i].item() + 1, self.filter_cutoff[i].item(), fs=1.0
            )  # fs=1.0 for normalized frequencies
            filter_coeffs.append(torch.tensor(fir, dtype=torch.float32))

        return filter_coeffs
@@ -0,0 +1,80 @@
1
+ from typing import List, Literal, Union
2
+ import numpy as np
3
+ from numpy.typing import NDArray as ndarray
4
+ import numpy.random as rndm
5
+
6
+ from spike_encoding.base_converter import BaseConverter
7
+ from spike_encoding.gymnasium_bounds_finder import ScalerFactory
8
+ from spike_encoding.encoder_common import poisson, rate
9
+
10
+
11
def gaussian_response(x, mu, sigma=0.3):
    """Gaussian receptive-field response of *x* centred at *mu*.

    Returns exp(-(x - mu)^2 / (2 * sigma^2)); peaks at 1.0 when x == mu.
    """
    z = (x - mu) / sigma
    return np.exp(-0.5 * z * z)
13
+
14
+
15
def transform_firing_rates(firing_rates, n_bins, sigma=0.3):
    """Population-code firing rates into *n_bins* Gaussian bins.

    Each input rate (assumed to lie in [0, 1]) is expanded into the
    responses of `n_bins` Gaussian receptive fields whose centres are
    spread evenly over [0, 1]; each group of responses is normalized to
    sum to 1.

    Args:
        firing_rates: iterable of rates scaled to [0, 1].
        n_bins: number of Gaussian bins per rate.
        sigma: width of each Gaussian receptive field.

    Returns:
        np.ndarray of shape (len(firing_rates) * n_bins,) holding the
        normalized bin responses concatenated per input rate.
    """
    # Assuming firing rates are scaled to [0, 1]
    bin_centers = np.linspace(0, 1, n_bins)
    transformed_rates = []

    # BUGFIX(idiom): the original loop variable was named `rate`, which
    # shadowed the module-level `rate` helper imported above; renamed.
    # The responses are computed vectorized over the bin centres.
    for firing_rate in firing_rates:
        responses = np.exp(
            -np.power(firing_rate - bin_centers, 2.0) / (2 * np.power(sigma, 2.0))
        )
        # Normalize the responses to sum to 1
        transformed_rates += (responses / np.sum(responses)).tolist()

    return np.array(transformed_rates)
27
+
28
+
29
class BinEncoder(BaseConverter):
    """
    This encoder creates spike trains from gymnasium observations using
    Gaussian population ("bin") coding. To use it, first install gymnasium
    https://gymnasium.farama.org/

    Args:
        seq_length: The number of timesteps in the spike train. E.g. [[1], [0], [1]] has a seq_length of 3
        min_values: per-feature minima used to scale inputs internally. For example, if your input goes up to 10, the scaler needs to map the value 10 to a firing rate of 1.0
        max_values: per-feature maxima; see min_values
        spike_train_conversion_method: determines how a firing rate is converted into a spike train. In poisson encoding, a firing rate of 0.1 means there is a 10% chance for any given timestep to be a spike. By chance there could be more or fewer spikes. If instead "deterministic" is chosen, you are guaranteed that 10% of timesteps are spikes
        n_bins: number of Gaussian bins each feature is expanded into
        max_firing_rate: multiplier for maximum firing rate. Typically the maximum is 1.0 (i.e. spikes at every step). You can set this to a lower value like 0.5 so on average you will only get a spike at every other step. This may be important for some R-STDP scenarios, where very high firing rates can impact synaptic tags
        sigma: width of the Gaussian receptive fields used for binning

    Returns:
        spike train: an array of arrays of spikes.
    """

    # Class-level default; overwritten per instance in __init__.
    spike_train_conversion_method: Literal["poisson", "deterministic"] = "poisson"

    def __init__(
        self,
        seq_length: int,
        min_values: Union[List[float], ndarray],
        max_values: Union[List[float], ndarray],
        spike_train_conversion_method: Literal["poisson", "deterministic"] = "poisson",
        n_bins=10,
        max_firing_rate=1.0,
        sigma=0.1,
    ):
        # BUGFIX: BaseConverter is a torch.nn.Module, so its __init__ must
        # run before this instance can be used as a module.
        super().__init__()
        self.sigma = sigma
        # BUGFIX: the original "seq_length if seq_length <= 1 else seq_length"
        # was a no-op; presumably a minimum length was intended — TODO confirm.
        self.seq_length = seq_length
        scaler_factory = ScalerFactory()
        self.scaler = scaler_factory.from_known_values(min_values, max_values)
        self.n_bins = n_bins
        self.spike_train_conversion_method = spike_train_conversion_method
        self.max_firing_rate = max_firing_rate

        # NOTE(review): this seeds numpy's *global* RNG and therefore
        # affects every other user of numpy.random in the process.
        self.seed = 42
        rndm.seed(self.seed)

    def encode(self, state: ndarray) -> ndarray:
        """Encode one observation vector into a spike array of shape
        (seq_length, 1, features * n_bins)."""
        # NOTE this uses batches for scaling and coding, but not for binning
        p_spikes = self.scaler.transform(np.atleast_2d(state))[0]
        p_bins = np.atleast_2d(
            [transform_firing_rates(p_spikes, self.n_bins, self.sigma)]
        )
        p_bins *= self.max_firing_rate

        if self.spike_train_conversion_method == "poisson":
            output = poisson(p_bins, self.seq_length)  # type: ignore
        else:
            output = rate(p_bins, self.seq_length)  # type: ignore

        return output
@@ -0,0 +1,52 @@
1
+ import torch
2
+ from .base_converter import BaseConverter
3
+
4
+
5
+
6
class DeltaModulationConverter(BaseConverter):
    """Encodes a tensor of input values by the difference between two subsequent
    timesteps, and emits a spike whenever that difference meets the threshold.

    Parameters:
        delta (float): Input with a change greater than the threshold across one timestep will generate a spike, defaults to ``0.1``
        normalized (bool): If ``False``, each row is min-max normalized to [0, 1] before encoding, defaults to ``True``
        padding (bool): If ``True``, the first time step will be compared with itself resulting in ``0``'s in spikes. If ``False``, it will be padded with ``0``'s, defaults to ``False``
        off_spike: If ``True``, spikes for negative changes less than ``-threshold``, defaults to ``False``
    """

    def __init__(
        self,
        delta: float = 0.1,
        normalized: bool = True,
        padding: bool = False,
        off_spike: bool = False,
    ):
        super().__init__()
        self.delta = delta
        self.normalized = normalized
        self.padding = padding
        self.off_spike = off_spike

    def forward(self, tensor):
        """Return a stacked (on_spikes, off_spikes) pair for *tensor*;
        off_spikes is all zeros unless ``off_spike=True``."""
        if tensor.ndimension() < 2:
            tensor = tensor.unsqueeze(0)

        if not self.normalized:
            # BUGFIX: clone so normalization does not mutate the caller's
            # tensor in place.
            tensor = tensor.clone()
            for i in range(tensor.shape[0]):
                tensor[i] = (tensor[i] - torch.min(tensor[i])) / (
                    torch.max(tensor[i]) - torch.min(tensor[i])
                )

        # BUGFIX: split AFTER normalization — the original split first, so
        # with padding=True the first-step offset used the raw (unscaled)
        # values against normalized data.
        cols = torch.split(tensor, 1, 1)

        if self.padding:
            # duplicate first time step, remove final step
            data_offset = torch.cat((cols[0], tensor), dim=1)[:, :-1]
        else:
            # add 0's to first step, remove final step
            data_offset = torch.cat((torch.zeros_like(cols[0]), tensor), dim=1)[:, :-1]

        diff = tensor - data_offset
        if not self.off_spike:
            return torch.stack(
                (torch.ones_like(tensor) * (diff >= self.delta), torch.zeros_like(tensor))
            )
        else:
            on_spk = torch.ones_like(tensor) * (diff >= self.delta)
            off_spk = -torch.ones_like(tensor) * (diff <= -self.delta)
            return torch.stack((on_spk, off_spk))

    def decode(self, spikes: torch.Tensor):
        # Delta modulation is lossy; reconstruction is not implemented.
        raise NotImplementedError
@@ -0,0 +1,31 @@
1
+ from numpy.typing import NDArray as ndarray
2
+ import numpy as np
3
+ import numpy.random as rndm
4
+
5
+
6
def poisson(spike_probabilities: ndarray, max_t: int) -> ndarray:
    """Sample a Poisson-style spike train.

    For each of *max_t* time steps every entry spikes independently with
    probability equal to its value in *spike_probabilities*.

    Args:
        spike_probabilities: per-entry spike probabilities in [0, 1].
        max_t: number of time steps to draw.

    Returns:
        Integer 0/1 array of shape (max_t, *spike_probabilities.shape).
    """
    draws = [
        spike_probabilities > rndm.random(spike_probabilities.shape)
        for _ in range(int(max_t))
    ]
    # "+ 0" converts the boolean draws to integers.
    return np.asarray(draws) + 0
18
+
19
+
20
def rate(spike_probabilities: ndarray, max_t: int) -> ndarray:
    """Deterministic rate coding.

    Each entry with rate p produces roughly floor(max_t * p) spikes spread
    over the *max_t* time steps.

    Args:
        spike_probabilities: per-entry firing rates in [0, 1].
        max_t: number of time steps.

    Returns:
        Integer 0/1 array of shape (max_t, *spike_probabilities.shape).
    """
    spikes = np.zeros((max_t, *spike_probabilities.shape), dtype=int)

    for pos in np.ndindex(spike_probabilities.shape):
        p = spike_probabilities[pos]
        # Spike times for rate p, clipped into the valid range [0, max_t - 1].
        times = np.array(np.arange(max_t * p, 0, -1) / p, dtype=int) - 1
        # Tuple indexing sets all chosen time steps for this entry at once.
        spikes[(np.clip(times, 0, max_t - 1),) + pos] = 1

    return spikes
@@ -0,0 +1,238 @@
1
+ from typing import List, Union
2
+ import gymnasium as gym
3
+ import tqdm
4
+ from joblib import Parallel, delayed
5
+ import joblib
6
+ from sklearn.preprocessing import MinMaxScaler
7
+ from torch import Tensor
8
+ import numpy as np
9
+ import os
10
+ from numpy.typing import NDArray as ndarray
11
+
12
+ from .rate_step_forward_converter import RateStepForwardConverter as Converter
13
+
14
+
15
def len_2(arr: ndarray) -> int:
    """Return the length of *arr*'s second axis; used as a sort key for runs."""
    return arr.shape[1]
17
+
18
+
19
def get_runs(env_name: str, num_steps: int, workers: int, print_updates=False) -> list:
    """Collect random-policy rollouts from a gymnasium environment.

    Args:
        env_name: gymnasium environment id passed to ``gym.make``.
        num_steps: total number of episodes to collect (split across workers).
        workers: number of parallel collectors; 1 disables parallelism.
        print_updates: when True, log progress.

    Returns:
        List of arrays of shape (features, episode_length), one per episode.
    """
    if print_updates:
        print("Generating " + str(num_steps) + " random runs...")

    # BUGFIX(idiom): the nested helper was also named `get_runs`, shadowing
    # this function; renamed for clarity.
    def _collect() -> list:
        env = gym.make(env_name)
        runs = []
        for _ in tqdm.tqdm(range(int(num_steps / workers))):
            # Get the initial state of the environment
            states = []
            state, _ = env.reset()
            done = False
            trunc = False
            while not (done or trunc):
                states.append(state)
                action = env.action_space.sample()
                state, _, done, trunc, _ = env.step(action)

            # Store as (features, time).
            runs.append(np.array(states).swapaxes(0, 1))

        return runs

    if workers > 1:
        # NOTE(review): n_jobs=-1 uses every core regardless of `workers`;
        # `workers` only controls how the episode count is split — confirm intent.
        results = Parallel(n_jobs=-1)(delayed(_collect)() for _ in [()] * workers)
        runs = []
        assert results is not None
        for res in results:
            runs.extend(res)
    else:
        runs = _collect()

    return runs
51
+
52
+
53
def save_scaler(path: str, scaler: MinMaxScaler, print_updates=False) -> bool:
    """Serialize a fitted scaler to *path* via joblib; always returns True
    (joblib raises on failure)."""
    joblib.dump(scaler, path)
    if print_updates:
        print("Saved scaler at " + path)
    return True
58
+
59
+
60
def load_presaved_env_data(path: str, print_updates=False) -> MinMaxScaler:
    """Load a scaler previously written by ``save_scaler`` from *path*."""
    scaler = joblib.load(path)
    if print_updates:
        print("Loaded scaler at " + path)
    return scaler
65
+
66
+
67
def save_treshold(path: str, treshold: Tensor, print_updates=False) -> bool:
    """Write threshold values to *path* as a Python-list literal.

    Args:
        path: destination text file.
        treshold: threshold values; anything providing ``.tolist()``.
        print_updates: when True, log the save location.

    Returns:
        True (always; open/write raise on failure).
    """
    serialized = str(treshold.tolist())
    with open(path, "w") as f:
        f.write(serialized)

    if print_updates:
        print("Saved thresholds at " + path)
    return True
74
+
75
+
76
def load_treshold(path: str, print_updates=False) -> Tensor:
    """Read thresholds previously written by ``save_treshold``.

    Only the last line of the file is parsed; the text between the first
    ``[`` and the following ``]`` is interpreted as comma-separated floats.

    Args:
        path: text file containing a list literal such as "[0.1, 0.2]".
        print_updates: when True, log the load location.

    Returns:
        A 1-D Tensor of the parsed threshold values.
    """
    with open(path, "r") as f:
        for line in f:
            th = line
    inner = th.split("[")[1].split("]")[0]
    threshold = [float(token) for token in inner.split(",")]
    if print_updates:
        print("Loaded thresholds at " + path)
    return Tensor(threshold)
84
+
85
+
86
def get_path_from_dirs(dirs: list) -> str:
    """Build (and create, if missing) a directory path under the CWD.

    Args:
        dirs: ordered directory names, joined onto ``os.getcwd()``.

    Returns:
        The path to the innermost directory.
    """
    path = os.path.join(os.getcwd(), *dirs)
    # BUGFIX: makedirs with exist_ok avoids the exists()/mkdir() race of
    # the original per-segment check-then-create loop.
    os.makedirs(path, exist_ok=True)
    return path
95
+
96
+
97
class ScalerFactory:
    """Builds MinMaxScaler instances for gymnasium observation spaces,
    either from known bounds or by sampling random rollouts."""

    # Parallel workers used when sampling rollouts.
    workers = 64
    # Base episode count; run_env samples num_steps * 10 episodes.
    num_steps = 1000

    def __init__(
        self,
        print_updates=False,
    ):
        # When True, progress messages are printed.
        self.print_updates = print_updates

    def has_presaved_env_data(self) -> bool:
        """True when a cached scaler file exists for the current env.

        NOTE(review): self.scaler_path is only set inside from_env();
        calling this method first raises AttributeError — confirm callers.
        """
        return os.path.exists(self.scaler_path)

    def from_env(self, env: gym.Env):
        """Return a scaler for *env*, loading a cached one when present."""
        assert env.spec is not None
        self.env_path = get_path_from_dirs(["data", "envs", env.spec.id])
        self.scaler_path = os.path.join(self.env_path, "scaler.save")
        if not self.has_presaved_env_data():
            return self.run_env(env)
        return load_presaved_env_data(self.scaler_path)

    def from_known_values(
        self, minima: Union[List[float], ndarray], maxima: Union[List[float], ndarray]
    ):
        """Fit a scaler directly from known per-feature minima and maxima."""
        scaler = MinMaxScaler()
        # Two rows (minima, maxima) are enough to fix the scaling range.
        data = np.concatenate(([minima], [maxima]), axis=0)
        scaler.fit(data)
        return scaler

    def run_env(self, env: gym.Env):
        """Sample random rollouts from *env*, fit a scaler, and cache it."""
        assert env.spec is not None
        runs = get_runs(env.spec.id, self.num_steps * 10, workers=self.workers)
        scaler = self._optimize_scaler(runs)
        save_scaler(self.scaler_path, scaler)
        return scaler

    def _optimize_scaler(self, runs: list) -> MinMaxScaler:
        """Fit a MinMaxScaler on the concatenated rollouts, forcing the
        fitted range to be symmetric around zero per feature."""
        if self.print_updates:
            print("Fitting scaler...")
        # Fit data between 0 and 1
        concat_runs = np.concatenate(runs, axis=1).swapaxes(0, 1)
        # Balance upper and lower bounds: per-feature max of |min| and max.
        upper = (
            np.concatenate([abs(concat_runs.min(axis=0)), concat_runs.max(axis=0)])
            .reshape(-1, concat_runs.shape[1])
            .max(axis=0)
        )
        lower = -upper
        # Append the symmetric bounds as extra samples so the fit spans them.
        concat_runs = np.append(concat_runs, upper.reshape(1, -1), axis=0)
        concat_runs = np.append(concat_runs, lower.reshape(1, -1), axis=0)

        # absolute values generate spikes from the first time steps

        # Multiplier for modifying the upper and lower bounds
        # Needs to be greater than 0
        # A value between 0 and 1 makes the converter more sensitive, i.e. it will generate more spikes
        # A value greater than 1 makes the converter less sensitive, i.e. it will generate fewer spikes
        max_modifier: int = 1
        concat_runs = concat_runs * max_modifier
        scaler = MinMaxScaler().fit(concat_runs)

        return scaler
159
+
160
+
161
class ConverterFactory:
    """Builds (and caches the thresholds of) a RateStepForwardConverter
    for a gymnasium environment."""

    # Parallel workers used when sampling rollouts.
    workers = 64
    # Maximum number of episodes used for threshold optimization.
    num_steps = 1000

    def __init__(self, env: gym.Env, scaler: MinMaxScaler, print_updates=False):
        self.env = env
        self.print_updates = print_updates
        self.scaler = scaler
        assert env.observation_space.shape is not None
        # Number of observation features = converter channels.
        self.features = env.observation_space.shape[0]
        assert env.spec is not None
        self.env_path = get_path_from_dirs(["data", "envs", env.spec.id])
        # File caching the optimized thresholds.
        self.enc_path = os.path.join(self.env_path, "th.txt")
        self.runs = []

    def initialize(self, runs):
        """Optimize converter thresholds on *runs* (sampling fresh rollouts
        when *runs* is empty), cache them, and return (converter, thresholds)."""
        if len(runs) == 0:
            assert self.env.spec is not None
            runs = get_runs(self.env.spec.id, self.num_steps, workers=self.workers)
        else:
            runs = runs[: self.num_steps]

        if self.print_updates:
            print("Initializing encoder...")
        conv = self._optimize_converter(runs)
        th = conv.threshold
        th_numpy = th.detach().numpy()
        # NOTE(review): save_treshold is annotated for a Tensor but receives
        # a numpy array here; both provide .tolist() — confirm intent.
        save_treshold(self.enc_path, th_numpy)
        return conv, th_numpy

    def generate(self):
        """Return (converter, thresholds), reusing cached thresholds when available."""
        if not self.initialized():
            return self.initialize(self.runs)
        else:
            th = load_treshold(self.enc_path)
            th_numpy = th.detach().numpy()
            if self.print_updates:
                print("Loaded existing encoder initialization at " + self.enc_path)
            return Converter(self.features, th_numpy), th_numpy

    def initialized(self) -> bool:
        """True when a cached threshold file exists for this environment."""
        return os.path.exists(self.enc_path)

    def _optimize_converter(self, runs: list) -> Converter:
        """Normalize, pad and stack the rollouts, then optimize the
        converter thresholds on the resulting tensor."""
        # Sort runs by length
        runs.sort(key=len_2)

        if self.print_updates:
            print("Normalizing data")

        # Scaler operates on (time, features), runs are (features, time).
        runs = [
            self.scaler.transform(arr.swapaxes(0, 1)).swapaxes(0, 1) for arr in runs
        ]

        if self.print_updates:
            print("Reshaping tensors...")

        # Fill all arrays to the same length and stack along run dimension
        # (shorter runs are padded with +inf; runs are sorted, so the last
        # one is the longest).
        max_len = runs[-1].shape[1]
        runs_tensor = Tensor(
            np.stack(
                [
                    np.append(
                        arr,
                        np.ones((self.features, max_len - arr.shape[1])) * np.inf,
                        axis=1,
                    )
                    for arr in runs
                ]
            )
        )

        # Start from a small uniform threshold and let the converter tune it.
        converter = Converter(self.features, np.ones(self.features) * 0.01)
        if self.print_updates:
            print("Optimizing converter tresholds...")
        converter.optimize(runs_tensor)

        return converter