ezmsg_learn-1.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/learn/__init__.py +2 -0
- ezmsg/learn/__version__.py +34 -0
- ezmsg/learn/dim_reduce/__init__.py +0 -0
- ezmsg/learn/dim_reduce/adaptive_decomp.py +274 -0
- ezmsg/learn/dim_reduce/incremental_decomp.py +173 -0
- ezmsg/learn/linear_model/__init__.py +1 -0
- ezmsg/learn/linear_model/adaptive_linear_regressor.py +12 -0
- ezmsg/learn/linear_model/cca.py +1 -0
- ezmsg/learn/linear_model/linear_regressor.py +9 -0
- ezmsg/learn/linear_model/sgd.py +9 -0
- ezmsg/learn/linear_model/slda.py +12 -0
- ezmsg/learn/model/__init__.py +0 -0
- ezmsg/learn/model/cca.py +122 -0
- ezmsg/learn/model/mlp.py +127 -0
- ezmsg/learn/model/mlp_old.py +49 -0
- ezmsg/learn/model/refit_kalman.py +369 -0
- ezmsg/learn/model/rnn.py +160 -0
- ezmsg/learn/model/transformer.py +175 -0
- ezmsg/learn/nlin_model/__init__.py +1 -0
- ezmsg/learn/nlin_model/mlp.py +10 -0
- ezmsg/learn/process/__init__.py +0 -0
- ezmsg/learn/process/adaptive_linear_regressor.py +142 -0
- ezmsg/learn/process/base.py +154 -0
- ezmsg/learn/process/linear_regressor.py +95 -0
- ezmsg/learn/process/mlp_old.py +188 -0
- ezmsg/learn/process/refit_kalman.py +403 -0
- ezmsg/learn/process/rnn.py +245 -0
- ezmsg/learn/process/sgd.py +117 -0
- ezmsg/learn/process/sklearn.py +241 -0
- ezmsg/learn/process/slda.py +110 -0
- ezmsg/learn/process/ssr.py +374 -0
- ezmsg/learn/process/torch.py +362 -0
- ezmsg/learn/process/transformer.py +215 -0
- ezmsg/learn/util.py +67 -0
- ezmsg_learn-1.1.0.dist-info/METADATA +30 -0
- ezmsg_learn-1.1.0.dist-info/RECORD +38 -0
- ezmsg_learn-1.1.0.dist-info/WHEEL +4 -0
- ezmsg_learn-1.1.0.dist-info/licenses/LICENSE +21 -0
ezmsg/learn/model/mlp.py
ADDED
@@ -0,0 +1,127 @@
import torch
import torch.nn


class MLP(torch.nn.Module):
    """
    A simple Multi-Layer Perceptron (MLP) model. Adapted from Ezmsg MLP.

    Attributes:
        feature_extractor (torch.nn.Sequential): The sequential feature extractor part of the MLP.
        heads (torch.nn.ModuleDict): A dictionary of output linear layers, one per output head.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int | list[int],
        num_layers: int | None = None,
        output_heads: int | dict[str, int] = 2,
        norm_layer: str | None = None,
        activation_layer: str | None = "ReLU",
        inplace: bool | None = None,
        bias: bool = True,
        dropout: float = 0.0,
    ):
        """
        Initialize the MLP model.

        Args:
            input_size (int): The size of the input features.
            hidden_size (int | list[int]): The sizes of the hidden layers. If a list, num_layers must be None or
                the length of the list. If a single integer, num_layers must be specified and determines the
                number of hidden layers.
            num_layers (int, optional): The number of hidden layers. Length of hidden_size if None. Default is None.
            output_heads (int | dict[str, int], optional): Number of output features or classes for a single-head
                output, or a dictionary mapping head names to output sizes for a multi-head output. Default is 2
                (single head).
            norm_layer (str, optional): A normalization layer to be applied after each linear layer. Common
                choices are "BatchNorm1d" or "LayerNorm". Default is None.
            activation_layer (str, optional): An activation function to be applied after each normalization
                layer. Default is "ReLU".
            inplace (bool, optional): Whether the activation function is performed in-place. Default is None.
            bias (bool, optional): Whether to use bias in the linear layers. Default is True.
            dropout (float, optional): The dropout rate to be applied after each linear layer. Default is 0.0.
        """
        super().__init__()
        if isinstance(hidden_size, int):
            if num_layers is None:
                raise ValueError("If hidden_size is an integer, num_layers must be specified.")
            hidden_size = [hidden_size] * num_layers
        if len(hidden_size) == 0:
            raise ValueError("hidden_size must have at least one element")
        if any(not isinstance(x, int) for x in hidden_size):
            raise ValueError("hidden_size must contain only integers")
        if num_layers is not None and len(hidden_size) != num_layers:
            raise ValueError("Length of hidden_size must match num_layers if num_layers is specified.")

        params = {} if inplace is None else {"inplace": inplace}

        layers = []
        in_dim = input_size

        def _get_layer_class(layer_name: str | None):
            # Resolve a torch.nn layer class from either a bare name ("ReLU")
            # or a fully qualified name ("torch.nn.ReLU"). Note: the original
            # only resolved fully qualified names, which silently dropped the
            # documented default of "ReLU".
            if layer_name is None:
                return None
            return getattr(torch.nn, layer_name.rsplit(".", 1)[-1])

        norm_layer_class = _get_layer_class(norm_layer)
        activation_layer_class = _get_layer_class(activation_layer)
        for hidden_dim in hidden_size[:-1]:
            layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
            if norm_layer_class is not None:
                layers.append(norm_layer_class(hidden_dim))
            if activation_layer_class is not None:
                layers.append(activation_layer_class(**params))
            layers.append(torch.nn.Dropout(dropout, **params))
            in_dim = hidden_dim

        layers.append(torch.nn.Linear(in_dim, hidden_size[-1], bias=bias))

        self.feature_extractor = torch.nn.Sequential(*layers)

        if isinstance(output_heads, int):
            output_heads = {"output": output_heads}
        self.heads = torch.nn.ModuleDict(
            {name: torch.nn.Linear(hidden_size[-1], output_size) for name, output_size in output_heads.items()}
        )

    @classmethod
    def infer_config_from_state_dict(cls, state_dict: dict) -> dict[str, int | list[int] | dict[str, int]]:
        """
        Infer the model configuration from a state dict.

        Args:
            state_dict: The state dict of the model.

        Returns:
            A dictionary containing the inferred configuration.
        """
        input_size = state_dict["feature_extractor.0.weight"].shape[1]
        hidden_size = [
            param.shape[0]
            for key, param in state_dict.items()
            if key.startswith("feature_extractor.") and key.endswith(".weight")
        ]
        output_heads = {
            key.split(".")[1]: param.shape[0]
            for key, param in state_dict.items()
            if key.startswith("heads.") and key.endswith(".bias")
        }

        return {
            "input_size": input_size,
            "hidden_size": hidden_size,
            "output_heads": output_heads,
        }

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        Forward pass through the MLP.

        Args:
            x (torch.Tensor): Input tensor of shape (batch, seq_len, input_size).

        Returns:
            dict[str, torch.Tensor]: A dictionary mapping head names to output tensors.
        """
        x = self.feature_extractor(x)
        return {name: head(x) for name, head in self.heads.items()}
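For orientation, a minimal usage sketch of this module (the hyperparameter values, head names, and data below are illustrative assumptions, not taken from the package):

    # Illustrative sketch only: sizes, head names, and data are made up.
    import torch
    from ezmsg.learn.model.mlp import MLP

    model = MLP(
        input_size=64,
        hidden_size=[128, 32],              # two hidden layers; num_layers inferred from the list
        output_heads={"pos": 2, "vel": 2},  # multi-head output -> dict of tensors
        activation_layer="ReLU",
        dropout=0.1,
    )
    x = torch.randn(8, 100, 64)             # (batch, seq_len, input_size)
    out = model(x)                          # {"pos": (8, 100, 2), "vel": (8, 100, 2)}

    # Round-trip: the configuration can be recovered from a checkpoint.
    cfg = MLP.infer_config_from_state_dict(model.state_dict())
    rebuilt = MLP(**cfg)
    rebuilt.load_state_dict(model.state_dict())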
ezmsg/learn/model/mlp_old.py
ADDED
@@ -0,0 +1,49 @@
import torch
import torch.nn


class MLP(torch.nn.Sequential):
    def __init__(
        self,
        in_channels: int,
        hidden_channels: list[int],
        norm_layer: type[torch.nn.Module] | None = None,
        activation_layer: type[torch.nn.Module] | None = torch.nn.ReLU,
        inplace: bool | None = None,
        bias: bool = True,
        dropout: float = 0.0,
    ):
        """
        Copy-pasted from torchvision MLP.

        :param in_channels: Number of input channels
        :param hidden_channels: List of the hidden channel dimensions
        :param norm_layer: Norm layer that will be stacked on top of the linear layer. If None, this layer won't be used.
        :param activation_layer: Activation function which will be stacked on top of the normalization layer
            (if not None), otherwise on top of the linear layer. If None, this layer won't be used.
        :param inplace: Parameter for the activation layer, which can optionally do the operation in-place.
            Default is None, which uses the respective default values of the activation_layer and Dropout layer.
        :param bias: Whether to use bias in the linear layers.
        :param dropout: The probability for the dropout layers.
        """
        if len(hidden_channels) == 0:
            raise ValueError("hidden_channels must have at least one element")
        if any(not isinstance(x, int) for x in hidden_channels):
            raise ValueError("hidden_channels must contain only integers")

        params = {} if inplace is None else {"inplace": inplace}

        layers = []
        in_dim = in_channels
        for hidden_dim in hidden_channels[:-1]:
            layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
            if norm_layer is not None:
                layers.append(norm_layer(hidden_dim))
            if activation_layer is not None:
                layers.append(activation_layer(**params))
            layers.append(torch.nn.Dropout(dropout, **params))
            in_dim = hidden_dim

        layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))

        super().__init__(*layers)
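For comparison, a brief sketch of this legacy interface (sizes are illustrative assumptions); unlike the newer module above, it takes layer classes rather than string names and returns a plain tensor:

    # Illustrative sketch only: sizes are made up.
    import torch
    from ezmsg.learn.model.mlp_old import MLP as LegacyMLP

    legacy = LegacyMLP(
        in_channels=64,
        hidden_channels=[128, 32],
        norm_layer=torch.nn.BatchNorm1d,  # a class, not a string
        activation_layer=torch.nn.ReLU,
        dropout=0.1,
    )
    y = legacy(torch.randn(8, 64))        # shape (8, 32); single tensor, no named heads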
ezmsg/learn/model/refit_kalman.py
ADDED
@@ -0,0 +1,369 @@
# refit_kalman.py

import numpy as np
from numpy.linalg import LinAlgError
from scipy.linalg import solve_discrete_are


class RefitKalmanFilter:
    """
    Refit Kalman filter for adaptive neural decoding.

    This class implements a Kalman filter that can be refitted online during operation.
    Unlike the standard Kalman filter, this version can adapt its observation model
    (H and Q matrices) based on new data while maintaining the state transition model
    (A and W matrices). This is particularly useful for brain-computer interfaces,
    where the relationship between neural activity and intended movements may change
    over time.

    The filter operates in two phases:
    1. Initial fitting: Learns all system matrices (A, W, H, Q) from training data
    2. Refitting: Updates only the observation model (H, Q) based on new data

    Attributes:
        A_state_transition_matrix: The state transition matrix A (n_states x n_states).
        W_process_noise_covariance: The process noise covariance matrix W (n_states x n_states).
        H_observation_matrix: The observation matrix H (n_observations x n_states).
        Q_measurement_noise_covariance: The measurement noise covariance matrix Q (n_observations x n_observations).
        K_kalman_gain: The Kalman gain matrix (n_states x n_observations).
        P_state_covariance: The state error covariance matrix (n_states x n_states).
        steady_state: Whether to use steady-state Kalman gain computation.
        is_fitted: Whether the model has been fitted with data.

    Example:
        >>> # Create and fit the filter
        >>> rkf = RefitKalmanFilter(steady_state=True)
        >>> rkf.fit(X_train, y_train)
        >>>
        >>> # Refit with new data
        >>> rkf.refit(X_new, Y_state, velocity_indices, targets, cursors, holds)
        >>>
        >>> # Predict and update with the refitted model
        >>> x_predicted, P_predicted = rkf.predict(x_current)
        >>> x_updated = rkf.update(z_measurement, x_predicted, P_predicted)
    """

    def __init__(
        self,
        A_state_transition_matrix=None,
        W_process_noise_covariance=None,
        H_observation_matrix=None,
        Q_measurement_noise_covariance=None,
        steady_state=False,
        enforce_state_structure=False,
        alpha_fading_memory=1.000,
        process_noise_scale=1,
        measurement_noise_scale=1.2,
    ):
        self.A_state_transition_matrix = A_state_transition_matrix
        self.W_process_noise_covariance = W_process_noise_covariance
        self.H_observation_matrix = H_observation_matrix
        self.Q_measurement_noise_covariance = Q_measurement_noise_covariance
        self.K_kalman_gain = None
        self.P_state_covariance = None
        self.alpha_fading_memory = alpha_fading_memory

        # Noise scaling factors for smoothing control
        self.process_noise_scale = process_noise_scale
        self.measurement_noise_scale = measurement_noise_scale

        self.steady_state = steady_state
        self.enforce_state_structure = enforce_state_structure
        self.is_fitted = False

    def _validate_state_vector(self, Y_state):
        """
        Validate that the state vector has proper dimensions.

        Args:
            Y_state: State vector to validate.

        Raises:
            ValueError: If the state vector has invalid dimensions.
        """
        if Y_state.ndim != 2:
            raise ValueError(f"State vector must be 2D, got {Y_state.ndim}D")

        if not hasattr(self, "H_observation_matrix") or self.H_observation_matrix is None:
            raise ValueError("Model must be fitted before refitting")

        expected_states = self.H_observation_matrix.shape[1]
        if Y_state.shape[1] != expected_states:
            raise ValueError(f"State vector has {Y_state.shape[1]} dimensions, expected {expected_states}")

    def fit(self, X_train, y_train):
        """
        Fit the Refit Kalman filter to the training data.

        This method learns all system matrices (A, W, H, Q) from training data
        using least-squares estimation, then computes the steady-state solution.
        This is the initial fitting phase that establishes the baseline model.

        Args:
            X_train: Neural activity (n_samples, n_neurons).
            y_train: Outputs being predicted (n_samples, n_states).

        Raises:
            ValueError: If training data has invalid dimensions.
            LinAlgError: If matrix operations fail during fitting.
        """
        # self._validate_state_vector(y_train)

        X = np.array(y_train)
        Z = np.array(X_train)
        n_samples = X.shape[0]

        # Calculate the transition matrix (from x_t to x_{t+1}) using least-squares
        X2 = X[1:, :]  # x_{t+1}
        X1 = X[:-1, :]  # x_t
        A = X2.T @ X1 @ np.linalg.inv(X1.T @ X1)  # Transition matrix
        W = (X2 - X1 @ A.T).T @ (X2 - X1 @ A.T) / (n_samples - 1)  # Covariance of transition residuals

        # Calculate the measurement matrix (from x_t to z_t) using least-squares
        H = Z.T @ X @ np.linalg.inv(X.T @ X)  # Measurement matrix
        Q = (Z - X @ H.T).T @ (Z - X @ H.T) / Z.shape[0]  # Covariance of measurement residuals

        self.A_state_transition_matrix = A
        self.W_process_noise_covariance = W * self.process_noise_scale
        self.H_observation_matrix = H
        self.Q_measurement_noise_covariance = Q * self.measurement_noise_scale

        self._compute_gain()
        self.is_fitted = True

    def refit(
        self,
        X_neural: np.ndarray,
        Y_state: np.ndarray,
        intention_velocity_indices: int | list[int] | None = None,
        target_positions: np.ndarray | None = None,
        cursor_positions: np.ndarray | None = None,
        hold_indices: np.ndarray | None = None,
    ):
        """
        Refit the observation model based on new data.

        This method updates only the observation model (H and Q matrices) while
        keeping the state transition model (A and W matrices) unchanged. The refitting
        process modifies the intended states based on target positions and hold flags
        to better align with user intentions.

        The refitting process:
        1. Modifies intended states based on target positions and hold flags
        2. Recalculates the observation matrix H using least-squares
        3. Recalculates the measurement noise covariance Q
        4. Updates the Kalman gain accordingly

        Args:
            X_neural: Neural activity data (n_samples, n_neurons).
            Y_state: State estimates (n_samples, n_states).
            intention_velocity_indices: Index of the first velocity component in the
                state vector (an int, or a single-element list/tuple).
            target_positions: Target positions for each sample (n_samples, 2).
            cursor_positions: Current cursor positions (n_samples, 2).
            hold_indices: Boolean flags indicating hold periods (n_samples,).

        Raises:
            ValueError: If input data has invalid dimensions or the model is not fitted.
        """
        self._validate_state_vector(Y_state)

        # Check if velocity indices are provided
        if intention_velocity_indices is None:
            # Assume (x, y, vx, vy)
            vel_idx = 2 if Y_state.shape[1] >= 4 else 0
            print(f"[RefitKalmanFilter] No velocity index provided — defaulting to {vel_idx}")
        elif isinstance(intention_velocity_indices, (list, tuple)):
            if len(intention_velocity_indices) != 1:
                raise ValueError("Only one velocity start index should be provided.")
            vel_idx = intention_velocity_indices[0]
        else:
            vel_idx = intention_velocity_indices

        # Only remap velocities if target and cursor positions are provided
        intended_states = Y_state.copy()
        if target_positions is not None and cursor_positions is not None:
            # Calculate intended velocities for each sample
            for i, (state, pos, target) in enumerate(zip(Y_state, cursor_positions, target_positions)):
                is_hold = hold_indices[i] if hold_indices is not None else False

                if is_hold:
                    # During hold periods, intended velocity is zero
                    intended_states[i, vel_idx : vel_idx + 2] = 0.0
                    if i > 0:
                        intended_states[i, :2] = intended_states[i - 1, :2]  # Same position as previous
                else:
                    # Calculate direction to target
                    to_target = target - pos
                    target_distance = np.linalg.norm(to_target)

                    if target_distance > 1e-5:  # Avoid division by zero
                        # Get current decoded velocity magnitude
                        current_velocity = state[vel_idx : vel_idx + 2]
                        current_speed = np.linalg.norm(current_velocity)

                        # Calculate intended velocity: same speed, but toward the target
                        target_direction = to_target / target_distance
                        intended_states[i, vel_idx : vel_idx + 2] = target_direction * current_speed
                    else:
                        # If the target is very close, keep the original velocity
                        intended_states[i, vel_idx : vel_idx + 2] = state[vel_idx : vel_idx + 2]

        Z = np.array(X_neural)

        # Recalculate observation matrix and noise covariance
        # (pinv() instead of inv() avoids singular-matrix errors)
        H = Z.T @ intended_states @ np.linalg.pinv(intended_states.T @ intended_states)
        Q = (Z - intended_states @ H.T).T @ (Z - intended_states @ H.T) / Z.shape[0]

        self.H_observation_matrix = H
        self.Q_measurement_noise_covariance = Q

        self._compute_gain()

    def _compute_gain(self):
        """
        Compute the Kalman gain matrix.

        This method computes the Kalman gain matrix based on the current system
        parameters. In steady-state mode, it solves the discrete-time algebraic
        Riccati equation (DARE) to find the optimal steady-state gain. In
        non-steady-state mode, the gain is instead recomputed from the current
        covariance during update().

        Raises:
            LinAlgError: If the Riccati equation cannot be solved or matrix operations fail.
        """
        # TODO: consider removing non-steady-state handling from _compute_gain() -
        # non-steady-state updates will occur during predict() and update()
        # if self.steady_state:
        try:
            # Try with the original matrices
            self.P_state_covariance = solve_discrete_are(
                self.A_state_transition_matrix.T,
                self.H_observation_matrix.T,
                self.W_process_noise_covariance,
                self.Q_measurement_noise_covariance,
            )
            self.K_kalman_gain = (
                self.P_state_covariance
                @ self.H_observation_matrix.T
                @ np.linalg.inv(
                    self.H_observation_matrix @ self.P_state_covariance @ self.H_observation_matrix.T
                    + self.Q_measurement_noise_covariance
                )
            )
        except LinAlgError:
            # Apply regularization and retry
            # A_reg = self.A_state_transition_matrix * 0.999  # Slight damping
            # W_reg = self.W_process_noise_covariance + 1e-7 * np.eye(
            #     self.W_process_noise_covariance.shape[0]
            # )
            Q_reg = self.Q_measurement_noise_covariance + 1e-7 * np.eye(self.Q_measurement_noise_covariance.shape[0])

            try:
                self.P_state_covariance = solve_discrete_are(
                    self.A_state_transition_matrix.T,
                    self.H_observation_matrix.T,
                    self.W_process_noise_covariance,
                    Q_reg,
                )
                self.K_kalman_gain = (
                    self.P_state_covariance
                    @ self.H_observation_matrix.T
                    @ np.linalg.inv(
                        self.H_observation_matrix @ self.P_state_covariance @ self.H_observation_matrix.T + Q_reg
                    )
                )
                print("Warning: Used regularized matrices for DARE solution")
            except LinAlgError:
                # Fallback to identity or manual initialization
                print("Warning: DARE failed, using identity covariance")
                self.P_state_covariance = np.eye(self.A_state_transition_matrix.shape[0])

        # else:
        #     n_states = self.A_state_transition_matrix.shape[0]
        #     self.P_state_covariance = np.eye(n_states) * 1000  # Large initial uncertainty

        # P_m = (
        #     self.A_state_transition_matrix
        #     @ self.P_state_covariance
        #     @ self.A_state_transition_matrix.T
        #     + self.W_process_noise_covariance
        # )

        # S = (
        #     self.H_observation_matrix @ P_m @ self.H_observation_matrix.T
        #     + self.Q_measurement_noise_covariance
        # )

        # self.K_kalman_gain = P_m @ self.H_observation_matrix.T @ np.linalg.pinv(S)

        # I_mat = np.eye(self.A_state_transition_matrix.shape[0])
        # self.P_state_covariance = (
        #     I_mat - self.K_kalman_gain @ self.H_observation_matrix
        # ) @ P_m

    def predict(self, x_current: np.ndarray) -> tuple[np.ndarray, np.ndarray | None]:
        """
        Predict the next state (and, in non-steady-state mode, its covariance)
        from the current state.
        """
        x_predicted = self.A_state_transition_matrix @ x_current
        if self.steady_state:
            return x_predicted, None
        # Fading-memory prediction: P = alpha^2 * (A P A^T) + W
        P_predicted = (
            self.alpha_fading_memory**2
            * (self.A_state_transition_matrix @ self.P_state_covariance @ self.A_state_transition_matrix.T)
            + self.W_process_noise_covariance
        )
        return x_predicted, P_predicted

    def update(
        self,
        z_measurement: np.ndarray,
        x_predicted: np.ndarray,
        P_predicted: np.ndarray | None = None,
    ) -> np.ndarray:
        """Update the state estimate and covariance based on measurement z."""

        # Compute the innovation (measurement residual)
        innovation = z_measurement - self.H_observation_matrix @ x_predicted

        if self.steady_state:
            x_updated = x_predicted + self.K_kalman_gain @ innovation
            return x_updated

        if P_predicted is None:
            raise ValueError("P_predicted must be provided for non-steady-state mode")

        # Non-steady-state mode
        # System uncertainty
        S = self.H_observation_matrix @ P_predicted @ self.H_observation_matrix.T + self.Q_measurement_noise_covariance

        # Kalman gain
        K = P_predicted @ self.H_observation_matrix.T @ np.linalg.pinv(S)

        # Updated state
        x_updated = x_predicted + K @ innovation

        # Joseph-form covariance update
        I_mat = np.eye(self.A_state_transition_matrix.shape[0])
        P_updated = (I_mat - K @ self.H_observation_matrix) @ P_predicted @ (
            I_mat - K @ self.H_observation_matrix
        ).T + K @ self.Q_measurement_noise_covariance @ K.T

        # Save updated values
        self.P_state_covariance = P_updated
        self.K_kalman_gain = K
        # self.S = S  # Optional: for diagnostics

        return x_updated
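To make the fit/refit cycle concrete, here is a minimal end-to-end sketch; the synthetic data, shapes, and (x, y, vx, vy) state layout are illustrative assumptions, not taken from the package:

    # Illustrative sketch only: synthetic data and shapes are made up.
    import numpy as np
    from ezmsg.learn.model.refit_kalman import RefitKalmanFilter

    rng = np.random.default_rng(0)
    n, n_neurons, n_states = 500, 32, 4             # assumed state layout: (x, y, vx, vy)
    Y = rng.standard_normal((n, n_states))          # behavioral states
    X = (Y @ rng.standard_normal((n_states, n_neurons))
         + 0.1 * rng.standard_normal((n, n_neurons)))  # noisy neural observations

    rkf = RefitKalmanFilter(steady_state=True)
    rkf.fit(X, Y)                                   # learns A, W, H, Q and the steady-state gain

    # Refit: only H and Q change; decoded velocities are remapped toward the target.
    rkf.refit(
        X_neural=X,
        Y_state=Y,
        intention_velocity_indices=2,               # velocities start at index 2
        target_positions=np.zeros((n, 2)),
        cursor_positions=Y[:, :2],
        hold_indices=np.zeros(n, dtype=bool),
    )

    x_pred, _ = rkf.predict(Y[0])                   # steady state: covariance not propagated
    x_est = rkf.update(X[0], x_pred)                # uses the precomputed steady-state gain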