ezmsg_learn-1.0-py3-none-any.whl

@@ -0,0 +1,133 @@
+ import torch
+ import torch.nn
+ 
+ 
+ class MLP(torch.nn.Module):
+     """
+     A simple Multi-Layer Perceptron (MLP) model. Adapted from Ezmsg MLP.
+ 
+     Attributes:
+         feature_extractor (torch.nn.Sequential): The sequential feature extractor part of the MLP.
+         heads (torch.nn.ModuleDict): A dictionary of output linear layers for each output head.
+     """
+ 
+     def __init__(
+         self,
+         input_size: int,
+         hidden_size: int | list[int],
+         num_layers: int | None = None,
+         output_heads: int | dict[str, int] = 2,
+         norm_layer: str | None = None,
+         activation_layer: str | None = "ReLU",
+         inplace: bool | None = None,
+         bias: bool = True,
+         dropout: float = 0.0,
+     ):
+         """
+         Initialize the MLP model.
+ 
+         Args:
+             input_size (int): The size of the input features.
+             hidden_size (int | list[int]): The sizes of the hidden layers. If a list, num_layers must be None or
+                 the length of the list. If a single integer, num_layers must be specified and determines the
+                 number of hidden layers.
+             num_layers (int, optional): The number of hidden layers. Defaults to the length of hidden_size
+                 if None.
+             output_heads (int | dict[str, int], optional): Number of output features or classes for a
+                 single-head output, or a dictionary mapping head names to output sizes for a multi-head
+                 output. Default is 2 (single head).
+             norm_layer (str, optional): A normalization layer to be applied after each linear layer.
+                 Common choices are "BatchNorm1d" or "LayerNorm". Default is None.
+             activation_layer (str, optional): An activation function to be applied after each normalization
+                 layer. Default is "ReLU".
+             inplace (bool, optional): Whether the activation function is performed in-place. Default is None.
+             bias (bool, optional): Whether to use bias in the linear layers. Default is True.
+             dropout (float, optional): The dropout rate to be applied after each linear layer. Default is 0.0.
+         """
+         super().__init__()
+         if isinstance(hidden_size, int):
+             if num_layers is None:
+                 raise ValueError(
+                     "If hidden_size is an integer, num_layers must be specified."
+                 )
+             hidden_size = [hidden_size] * num_layers
+         if len(hidden_size) == 0:
+             raise ValueError("hidden_size must have at least one element")
+         if any(not isinstance(x, int) for x in hidden_size):
+             raise ValueError("hidden_size must contain only integers")
+         if num_layers is not None and len(hidden_size) != num_layers:
+             raise ValueError(
+                 "Length of hidden_size must match num_layers if num_layers is specified."
+             )
+ 
+         params = {} if inplace is None else {"inplace": inplace}
+ 
+         layers = []
+         in_dim = input_size
+ 
+         def _get_layer_class(layer_name: str | None):
+             # Accept both bare class names ("ReLU") and qualified names
+             # ("torch.nn.ReLU"). Requiring the "torch.nn" prefix here would
+             # silently skip the documented default activation "ReLU".
+             if layer_name is None:
+                 return None
+             return getattr(torch.nn, layer_name.rsplit(".", 1)[-1])
+ 
+         norm_layer_class = _get_layer_class(norm_layer)
+         activation_layer_class = _get_layer_class(activation_layer)
+         for hidden_dim in hidden_size[:-1]:
+             layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
+             if norm_layer_class is not None:
+                 layers.append(norm_layer_class(hidden_dim))
+             if activation_layer_class is not None:
+                 layers.append(activation_layer_class(**params))
+             layers.append(torch.nn.Dropout(dropout, **params))
+             in_dim = hidden_dim
+ 
+         layers.append(torch.nn.Linear(in_dim, hidden_size[-1], bias=bias))
+ 
+         self.feature_extractor = torch.nn.Sequential(*layers)
+ 
+         if isinstance(output_heads, int):
+             output_heads = {"output": output_heads}
+         self.heads = torch.nn.ModuleDict(
+             {
+                 name: torch.nn.Linear(hidden_size[-1], output_size)
+                 for name, output_size in output_heads.items()
+             }
+         )
+ 
+     @classmethod
+     def infer_config_from_state_dict(
+         cls, state_dict: dict
+     ) -> dict[str, int | list[int] | dict[str, int]]:
+         """
+         Infer the configuration from the state dict.
+ 
+         Args:
+             state_dict: The state dict of the model.
+ 
+         Returns:
+             dict[str, int | list[int] | dict[str, int]]: A dictionary containing the inferred configuration.
+         """
+         input_size = state_dict["feature_extractor.0.weight"].shape[1]
+         # Note: this assumes the feature extractor contains no norm layers,
+         # whose ".weight" entries would otherwise also be collected here.
+         hidden_size = [
+             param.shape[0]
+             for key, param in state_dict.items()
+             if key.startswith("feature_extractor.") and key.endswith(".weight")
+         ]
+         output_heads = {
+             key.split(".")[1]: param.shape[0]
+             for key, param in state_dict.items()
+             if key.startswith("heads.") and key.endswith(".bias")
+         }
+ 
+         return {
+             "input_size": input_size,
+             "hidden_size": hidden_size,
+             "output_heads": output_heads,
+         }
+ 
+     def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+         """
+         Forward pass through the MLP.
+ 
+         Args:
+             x (torch.Tensor): Input tensor of shape (batch, seq_len, input_size).
+ 
+         Returns:
+             dict[str, torch.Tensor]: A dictionary mapping head names to output tensors.
+         """
+         x = self.feature_extractor(x)
+         return {name: head(x) for name, head in self.heads.items()}
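
A minimal usage sketch for the multi-head MLP above (not part of the package; the import path, sizes, and head names are illustrative assumptions):

    import torch
    from mlp import MLP  # assumed local import; adjust to the package layout

    # Two hidden layers of 64 units feeding two named output heads.
    model = MLP(
        input_size=32,
        hidden_size=64,
        num_layers=2,
        output_heads={"position": 2, "velocity": 2},
        activation_layer="ReLU",
    )
    x = torch.randn(8, 100, 32)  # (batch, seq_len, input_size)
    out = model(x)  # {"position": (8, 100, 2), "velocity": (8, 100, 2)} tensors

    # The architecture can be recovered from a checkpoint:
    cfg = MLP.infer_config_from_state_dict(model.state_dict())
    model2 = MLP(**cfg)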
@@ -0,0 +1,49 @@
+ from collections.abc import Callable
+ 
+ import torch
+ import torch.nn
+ 
+ 
+ class MLP(torch.nn.Sequential):
+     def __init__(
+         self,
+         in_channels: int,
+         hidden_channels: list[int],
+         norm_layer: Callable[..., torch.nn.Module] | None = None,
+         activation_layer: Callable[..., torch.nn.Module] | None = torch.nn.ReLU,
+         inplace: bool | None = None,
+         bias: bool = True,
+         dropout: float = 0.0,
+     ):
+         """
+         Copy-pasted from torchvision MLP
+ 
+         :param in_channels: Number of input channels
+         :param hidden_channels: List of the hidden channel dimensions
+         :param norm_layer: Norm layer that will be stacked on top of the linear layer. If None this layer won't be used.
+         :param activation_layer: Activation function which will be stacked on top of the normalization layer
+             (if not None), otherwise on top of the linear layer. If None this layer won't be used.
+         :param inplace: Parameter for the activation layer, which can optionally do the operation in-place.
+             Default is None, which uses the respective default values of the activation_layer and Dropout layer.
+         :param bias: Whether to use bias in the linear layer.
+         :param dropout: The probability for the dropout layer.
+         """
+         if len(hidden_channels) == 0:
+             raise ValueError("hidden_channels must have at least one element")
+         if any(not isinstance(x, int) for x in hidden_channels):
+             raise ValueError("hidden_channels must contain only integers")
+ 
+         params = {} if inplace is None else {"inplace": inplace}
+ 
+         layers = []
+         in_dim = in_channels
+         for hidden_dim in hidden_channels[:-1]:
+             layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
+             if norm_layer is not None:
+                 layers.append(norm_layer(hidden_dim))
+             if activation_layer is not None:
+                 layers.append(activation_layer(**params))
+             layers.append(torch.nn.Dropout(dropout, **params))
+             in_dim = hidden_dim
+ 
+         layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
+ 
+         super().__init__(*layers)
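
Unlike the string-based module above, this torchvision-style variant takes the layer classes themselves and subclasses torch.nn.Sequential. A short sketch (values are illustrative):

    import torch

    mlp = MLP(
        in_channels=16,
        hidden_channels=[32, 32, 4],    # the last entry is the output size
        norm_layer=torch.nn.LayerNorm,  # passed as a class, not a string
        activation_layer=torch.nn.ReLU,
        dropout=0.1,
    )
    y = mlp(torch.randn(8, 16))  # -> shape (8, 4)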
@@ -0,0 +1,401 @@
+ # refit_kalman.py
+ 
+ import numpy as np
+ from numpy.linalg import LinAlgError
+ from scipy.linalg import solve_discrete_are
+ 
+ 
+ class RefitKalmanFilter:
+     """
+     Refit Kalman filter for adaptive neural decoding.
+ 
+     This class implements a Kalman filter that can be refitted online during operation.
+     Unlike the standard Kalman filter, this version can adapt its observation model
+     (H and Q matrices) based on new data while maintaining the state transition model
+     (A and W matrices). This is particularly useful for brain-computer interfaces
+     where the relationship between neural activity and intended movements may change
+     over time.
+ 
+     The filter operates in two phases:
+     1. Initial fitting: Learns all system matrices (A, W, H, Q) from training data
+     2. Refitting: Updates only the observation model (H, Q) based on new data
+ 
+     Attributes:
+         A_state_transition_matrix: The state transition matrix A (n_states x n_states).
+         W_process_noise_covariance: The process noise covariance matrix W (n_states x n_states).
+         H_observation_matrix: The observation matrix H (n_observations x n_states).
+         Q_measurement_noise_covariance: The measurement noise covariance matrix Q (n_observations x n_observations).
+         K_kalman_gain: The Kalman gain matrix (n_states x n_observations).
+         P_state_covariance: The state error covariance matrix (n_states x n_states).
+         steady_state: Whether to use steady-state Kalman gain computation.
+         is_fitted: Whether the model has been fitted with data.
+ 
+     Example:
+         >>> # Create and fit the filter
+         >>> rkf = RefitKalmanFilter(steady_state=True)
+         >>> rkf.fit(X_train, y_train)
+         >>>
+         >>> # Refit with new data
+         >>> rkf.refit(X_new, Y_state, velocity_indices, targets, cursors, holds)
+         >>>
+         >>> # Predict, then update with the latest measurement
+         >>> x_predicted, P_predicted = rkf.predict(x_current)
+         >>> x_updated = rkf.update(z_measurement, x_predicted, P_predicted)
+     """
+ 
+     def __init__(
+         self,
+         A_state_transition_matrix=None,
+         W_process_noise_covariance=None,
+         H_observation_matrix=None,
+         Q_measurement_noise_covariance=None,
+         steady_state=False,
+         enforce_state_structure=False,
+         alpha_fading_memory=1.000,
+         process_noise_scale=1,
+         measurement_noise_scale=1.2,
+     ):
+         self.A_state_transition_matrix = A_state_transition_matrix
+         self.W_process_noise_covariance = W_process_noise_covariance
+         self.H_observation_matrix = H_observation_matrix
+         self.Q_measurement_noise_covariance = Q_measurement_noise_covariance
+         self.K_kalman_gain = None
+         self.P_state_covariance = None
+         self.alpha_fading_memory = alpha_fading_memory
+ 
+         # Noise scaling factors for smoothing control
+         self.process_noise_scale = process_noise_scale
+         self.measurement_noise_scale = measurement_noise_scale
+ 
+         self.steady_state = steady_state
+         self.enforce_state_structure = enforce_state_structure
+         self.is_fitted = False
+ 
+     def _validate_state_vector(self, Y_state):
+         """
+         Validate that the state vector has proper dimensions.
+ 
+         Args:
+             Y_state: State vector to validate
+ 
+         Raises:
+             ValueError: If state vector has invalid dimensions
+         """
+         if Y_state.ndim != 2:
+             raise ValueError(f"State vector must be 2D, got {Y_state.ndim}D")
+ 
+         if (
+             not hasattr(self, "H_observation_matrix")
+             or self.H_observation_matrix is None
+         ):
+             raise ValueError("Model must be fitted before refitting")
+ 
+         expected_states = self.H_observation_matrix.shape[1]
+         if Y_state.shape[1] != expected_states:
+             raise ValueError(
+                 f"State vector has {Y_state.shape[1]} dimensions, expected {expected_states}"
+             )
+ 
+     def fit(self, X_train, y_train):
+         """
+         Fit the Refit Kalman filter to the training data.
+ 
+         This method learns all system matrices (A, W, H, Q) from training data
+         using least-squares estimation, then computes the steady-state solution.
+         This is the initial fitting phase that establishes the baseline model.
+ 
+         Args:
+             X_train: Neural activity (n_samples, n_neurons).
+             y_train: Outputs being predicted (n_samples, n_states).
+ 
+         Raises:
+             ValueError: If training data has invalid dimensions.
+             LinAlgError: If matrix operations fail during fitting.
+         """
+         # self._validate_state_vector(y_train)
+ 
+         X = np.array(y_train)
+         Z = np.array(X_train)
+         n_samples = X.shape[0]
+ 
+         # Calculate the transition matrix (from x_t to x_t+1) using least-squares
+         X2 = X[1:, :]  # x_{t+1}
+         X1 = X[:-1, :]  # x_t
+         A = X2.T @ X1 @ np.linalg.inv(X1.T @ X1)  # Transition matrix
+         W = (
+             (X2 - X1 @ A.T).T @ (X2 - X1 @ A.T) / (n_samples - 1)
+         )  # Covariance of transition matrix
+ 
+         # Calculate the measurement matrix (from x_t to z_t) using least-squares
+         H = Z.T @ X @ np.linalg.inv(X.T @ X)  # Measurement matrix
+         Q = (
+             (Z - X @ H.T).T @ (Z - X @ H.T) / Z.shape[0]
+         )  # Covariance of measurement matrix
+ 
+         self.A_state_transition_matrix = A
+         self.W_process_noise_covariance = W * self.process_noise_scale
+         self.H_observation_matrix = H
+         self.Q_measurement_noise_covariance = Q * self.measurement_noise_scale
+ 
+         self._compute_gain()
+         self.is_fitted = True
+ 
+     def refit(
+         self,
+         X_neural: np.ndarray,
+         Y_state: np.ndarray,
+         intention_velocity_indices: int | list[int] | tuple[int, ...] | None = None,
+         target_positions: np.ndarray | None = None,
+         cursor_positions: np.ndarray | None = None,
+         hold_indices: np.ndarray | None = None,
+     ):
+         """
+         Refit the observation model based on new data.
+ 
+         This method updates only the observation model (H and Q matrices) while
+         keeping the state transition model (A and W matrices) unchanged. The refitting
+         process modifies the intended states based on target positions and hold flags
+         to better align with user intentions.
+ 
+         The refitting process:
+         1. Modifies intended states based on target positions and hold flags
+         2. Recalculates the observation matrix H using least-squares
+         3. Recalculates the measurement noise covariance Q
+         4. Updates the Kalman gain accordingly
+ 
+         Args:
+             X_neural: Neural activity data (n_samples, n_neurons).
+             Y_state: State estimates (n_samples, n_states).
+             intention_velocity_indices: Start index of the velocity components in the
+                 state vector, given as a single int or a one-element list/tuple.
+             target_positions: Target positions for each sample (n_samples, 2).
+             cursor_positions: Current cursor positions (n_samples, 2).
+             hold_indices: Boolean flags indicating hold periods (n_samples,).
+ 
+         Raises:
+             ValueError: If input data has invalid dimensions or the model is not fitted.
+         """
+         self._validate_state_vector(Y_state)
+ 
+         # Check if velocity indices are provided
+         if intention_velocity_indices is None:
+             # Assume (x, y, vx, vy)
+             vel_idx = 2 if Y_state.shape[1] >= 4 else 0
+             print(
+                 f"[RefitKalmanFilter] No velocity index provided — defaulting to {vel_idx}"
+             )
+         else:
+             if isinstance(intention_velocity_indices, (list, tuple)):
+                 if len(intention_velocity_indices) != 1:
+                     raise ValueError(
+                         "Only one velocity start index should be provided."
+                     )
+                 vel_idx = intention_velocity_indices[0]
+             else:
+                 vel_idx = intention_velocity_indices
+ 
+         # Only remap velocity if target and cursor positions are provided
+         intended_states = Y_state.copy()
+         if target_positions is not None and cursor_positions is not None:
+             # Calculate intended velocities for each sample
+             for i, (state, pos, target) in enumerate(
+                 zip(Y_state, cursor_positions, target_positions)
+             ):
+                 is_hold = hold_indices[i] if hold_indices is not None else False
+ 
+                 if is_hold:
+                     # During hold periods, intended velocity is zero
+                     intended_states[i, vel_idx : vel_idx + 2] = 0.0
+                     if i > 0:
+                         intended_states[i, :2] = intended_states[
+                             i - 1, :2
+                         ]  # Same position as previous
+                 else:
+                     # Calculate direction to target
+                     to_target = target - pos
+                     target_distance = np.linalg.norm(to_target)
+ 
+                     if target_distance > 1e-5:  # Avoid division by zero
+                         # Get current decoded velocity magnitude
+                         current_velocity = state[vel_idx : vel_idx + 2]
+                         current_speed = np.linalg.norm(current_velocity)
+ 
+                         # Calculate intended velocity: same speed, but toward target
+                         target_direction = to_target / target_distance
+                         intended_velocity = target_direction * current_speed
+ 
+                         # Update intended state with new velocity
+                         intended_states[i, vel_idx : vel_idx + 2] = intended_velocity
+                     # If target is very close, keep original velocity
+                     else:
+                         intended_states[i, vel_idx : vel_idx + 2] = state[
+                             vel_idx : vel_idx + 2
+                         ]
+ 
+         intended_states = np.array(intended_states)
+         Z = np.array(X_neural)
+ 
+         # Recalculate observation matrix and noise covariance
+         H = (
+             Z.T @ intended_states @ np.linalg.pinv(intended_states.T @ intended_states)
+         )  # Using pinv() instead of inv() to avoid singular matrix errors
+         Q = (Z - intended_states @ H.T).T @ (Z - intended_states @ H.T) / Z.shape[0]
+ 
+         self.H_observation_matrix = H
+         self.Q_measurement_noise_covariance = Q
+ 
+         self._compute_gain()
+ 
248
+
249
+ def _compute_gain(self):
250
+ """
251
+ Compute the Kalman gain matrix.
252
+
253
+ This method computes the Kalman gain matrix based on the current system
254
+ parameters. In steady-state mode, it solves the discrete-time algebraic
255
+ Riccati equation to find the optimal steady-state gain. In non-steady-state
256
+ mode, it computes the gain using the current covariance matrix.
257
+
258
+ Raises:
259
+ LinAlgError: If the Riccati equation cannot be solved or matrix operations fail.
260
+ """
261
+ ## TODO: consider removing non-steady-state for compute_gain() - non_steady_state updates will occur during predict() and update()
262
+ # if self.steady_state:
263
+ try:
264
+ # Try with original matrices
265
+ self.P_state_covariance = solve_discrete_are(
266
+ self.A_state_transition_matrix.T,
267
+ self.H_observation_matrix.T,
268
+ self.W_process_noise_covariance,
269
+ self.Q_measurement_noise_covariance,
270
+ )
271
+ self.K_kalman_gain = (
272
+ self.P_state_covariance
273
+ @ self.H_observation_matrix.T
274
+ @ np.linalg.inv(
275
+ self.H_observation_matrix
276
+ @ self.P_state_covariance
277
+ @ self.H_observation_matrix.T
278
+ + self.Q_measurement_noise_covariance
279
+ )
280
+ )
281
+ except LinAlgError:
282
+ # Apply regularization and retry
283
+ # A_reg = self.A_state_transition_matrix * 0.999 # Slight damping
284
+ # W_reg = self.W_process_noise_covariance + 1e-7 * np.eye(
285
+ # self.W_process_noise_covariance.shape[0]
286
+ # )
287
+ Q_reg = self.Q_measurement_noise_covariance + 1e-7 * np.eye(
288
+ self.Q_measurement_noise_covariance.shape[0]
289
+ )
290
+
291
+ try:
292
+ self.P_state_covariance = solve_discrete_are(
293
+ self.A_state_transition_matrix.T,
294
+ self.H_observation_matrix.T,
295
+ self.W_process_noise_covariance,
296
+ Q_reg,
297
+ )
298
+ self.K_kalman_gain = (
299
+ self.P_state_covariance
300
+ @ self.H_observation_matrix.T
301
+ @ np.linalg.inv(
302
+ self.H_observation_matrix
303
+ @ self.P_state_covariance
304
+ @ self.H_observation_matrix.T
305
+ + Q_reg
306
+ )
307
+ )
308
+ print("Warning: Used regularized matrices for DARE solution")
+             except LinAlgError:
+                 # Fallback to identity covariance; also derive a usable gain
+                 # from it, otherwise K_kalman_gain would be left stale or None.
+                 print("Warning: DARE failed, using identity covariance")
+                 P = np.eye(self.A_state_transition_matrix.shape[0])
+                 H = self.H_observation_matrix
+                 self.P_state_covariance = P
+                 self.K_kalman_gain = P @ H.T @ np.linalg.pinv(
+                     H @ P @ H.T + self.Q_measurement_noise_covariance
+                 )
+ 
+         # else:
+         #     n_states = self.A_state_transition_matrix.shape[0]
+         #     self.P_state_covariance = (
+         #         np.eye(n_states) * 1000
+         #     )  # Large initial uncertainty
+ 
+         #     P_m = (
+         #         self.A_state_transition_matrix
+         #         @ self.P_state_covariance
+         #         @ self.A_state_transition_matrix.T
+         #         + self.W_process_noise_covariance
+         #     )
+ 
+         #     S = (
+         #         self.H_observation_matrix @ P_m @ self.H_observation_matrix.T
+         #         + self.Q_measurement_noise_covariance
+         #     )
+ 
+         #     self.K_kalman_gain = P_m @ self.H_observation_matrix.T @ np.linalg.pinv(S)
+ 
+         #     I_mat = np.eye(self.A_state_transition_matrix.shape[0])
+         #     self.P_state_covariance = (
+         #         I_mat - self.K_kalman_gain @ self.H_observation_matrix
+         #     ) @ P_m
+ 
+     def predict(self, x_current: np.ndarray) -> tuple[np.ndarray, np.ndarray | None]:
+         """
+         Predict the next state and covariance.
+ 
+         This method propagates the current state through the state transition
+         model. In steady-state mode the predicted covariance is not tracked
+         and is returned as None.
+         """
+         x_predicted = self.A_state_transition_matrix @ x_current
+         if self.steady_state:
+             return x_predicted, None
+         else:
+             P_predicted = self.alpha_fading_memory**2 * (
+                 self.A_state_transition_matrix
+                 @ self.P_state_covariance
+                 @ self.A_state_transition_matrix.T
+                 + self.W_process_noise_covariance
+             )
+             return x_predicted, P_predicted
+ 
+     def update(
+         self,
+         z_measurement: np.ndarray,
+         x_predicted: np.ndarray,
+         P_predicted: np.ndarray | None = None,
+     ) -> np.ndarray:
+         """Update state estimate and covariance based on measurement z."""
+ 
+         # Compute residual
+         innovation = z_measurement - self.H_observation_matrix @ x_predicted
+ 
+         if self.steady_state:
+             x_updated = x_predicted + self.K_kalman_gain @ innovation
+             return x_updated
+ 
+         if P_predicted is None:
+             raise ValueError("P_predicted must be provided for non-steady-state mode")
+ 
+         # Non-steady-state mode
+         # System uncertainty
+         S = (
+             self.H_observation_matrix @ P_predicted @ self.H_observation_matrix.T
+             + self.Q_measurement_noise_covariance
+         )
+ 
+         # Kalman gain
+         K = P_predicted @ self.H_observation_matrix.T @ np.linalg.pinv(S)
+ 
+         # Updated state
+         x_updated = x_predicted + K @ innovation
+ 
+         # Covariance update (Joseph form, numerically stable)
+         I_mat = np.eye(self.A_state_transition_matrix.shape[0])
+         P_updated = (I_mat - K @ self.H_observation_matrix) @ P_predicted @ (
+             I_mat - K @ self.H_observation_matrix
+         ).T + K @ self.Q_measurement_noise_covariance @ K.T
+ 
+         # Save updated values
+         self.P_state_covariance = P_updated
+         self.K_kalman_gain = K
+         # self.S = S  # Optional: for diagnostics
+ 
+         return x_updated
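
A minimal end-to-end sketch of the ReFIT workflow above, using synthetic data (the shapes, the (x, y, vx, vy) state layout, and the random observation model are illustrative assumptions, not part of the package):

    import numpy as np

    rng = np.random.default_rng(0)
    n, n_neurons, n_states = 500, 24, 4      # state = (x, y, vx, vy)
    Y = rng.standard_normal((n, n_states))   # state trajectories
    X = Y @ rng.standard_normal((n_states, n_neurons))  # neural activity
    X += 0.1 * rng.standard_normal((n, n_neurons))      # observation noise

    rkf = RefitKalmanFilter(steady_state=True)
    rkf.fit(X_train=X, y_train=Y)

    # Closed-loop decoding with the steady-state gain
    x_est = np.zeros(n_states)
    for z in X[:10]:
        x_pred, _ = rkf.predict(x_est)
        x_est = rkf.update(z, x_pred)

    # ReFIT pass: rotate decoded velocities toward the known targets
    targets = rng.standard_normal((n, 2))
    cursors = Y[:, :2]
    holds = np.zeros(n, dtype=bool)
    rkf.refit(X, Y, intention_velocity_indices=2,
              target_positions=targets, cursor_positions=cursors,
              hold_indices=holds)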