diffkalman-0.1.0-py3-none-any.whl
- diffkalman/__init__.py +6 -0
- diffkalman/em_loop.py +115 -0
- diffkalman/filter.py +654 -0
- diffkalman/joint_jacobian_transform.py +54 -0
- diffkalman/negative_log_likelihood.py +133 -0
- diffkalman/py.typed +0 -0
- diffkalman/utils.py +90 -0
- diffkalman-0.1.0.dist-info/METADATA +183 -0
- diffkalman-0.1.0.dist-info/RECORD +11 -0
- diffkalman-0.1.0.dist-info/WHEEL +4 -0
- diffkalman-0.1.0.dist-info/licenses/LICENSE +21 -0
diffkalman/__init__.py
ADDED
diffkalman/em_loop.py
ADDED
@@ -0,0 +1,115 @@
"""The EM loop module that implements the EM algorithm for the Differentiable Kalman Filter."""

from .filter import DiffrentiableKalmanFilter
from .utils import SymmetricPositiveDefiniteMatrix
import torch
import torch.nn as nn


__all__ = [
    "em_updates",
]


def em_updates(
    dkf: DiffrentiableKalmanFilter,
    z_seq: torch.Tensor,
    x0: torch.Tensor,
    P0: torch.Tensor,
    Q: SymmetricPositiveDefiniteMatrix,
    R: SymmetricPositiveDefiniteMatrix,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    num_cycles: int = 20,
    num_epochs: int = 100,
    h_args: tuple = (),
    f_args: tuple = (),
) -> dict:
    """A sample implementation of the EM algorithm for the Differentiable Kalman Filter.

    Args:
        dkf (DiffrentiableKalmanFilter): The differentiable Kalman filter module.
        z_seq (torch.Tensor): The noisy measurement sequence. Dimension: (seq_len, obs_dim)
        x0 (torch.Tensor): The initial state vector. Dimension: (state_dim,)
        P0 (torch.Tensor): The initial covariance matrix of the state vector. Dimension: (state_dim, state_dim)
        Q (SymmetricPositiveDefiniteMatrix): The process noise covariance matrix module.
        R (SymmetricPositiveDefiniteMatrix): The measurement noise covariance matrix module.
        optimizer (torch.optim.Optimizer): The optimizer.
        lr_scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
        num_cycles (int): The number of cycles. Default: 20
        num_epochs (int): The number of epochs. Default: 100
        h_args (tuple): Additional arguments for the measurement model. Default: ()
        f_args (tuple): Additional arguments for the transition function. Default: ()

    Returns:
        dict: Log-likelihoods of the model, indexed by epoch and cycle.

    Note:
        - Use the SymmetricPositiveDefiniteMatrix module to ensure that the process noise covariance matrix Q and the measurement noise covariance matrix R are symmetric and positive definite.
        - The optimizer and the learning rate scheduler should be initialized before calling this function.
        - The measurement model and the transition function should be differentiable torch modules.
    """
    likelihoods = torch.zeros(num_epochs, num_cycles)

    # Perform EM updates for num_epochs
    for e in range(num_epochs):

        ## The E-step
        ## Without gradient tracking, get the posterior state distribution w.r.t. the current parameter values
        with torch.no_grad():
            posterior = dkf.sequence_smooth(
                z_seq=z_seq,
                x0=x0,
                P0=P0,
                Q=Q().repeat(len(z_seq), 1, 1),
                R=R().repeat(len(z_seq), 1, 1),
                f_args=f_args,
                h_args=h_args,
            )

        ## The M-step: update the parameters with respect to the current posterior state distribution for num_cycles
        for c in range(num_cycles):
            # Zero the gradients
            optimizer.zero_grad()

            # Compute the marginal likelihood for logging
            with torch.no_grad():
                marginal_likelihood = dkf.marginal_log_likelihood(
                    z_seq=z_seq,
                    x0=x0,
                    P0=P0,
                    Q=Q().repeat(len(z_seq), 1, 1),
                    R=R().repeat(len(z_seq), 1, 1),
                    f_args=f_args,
                    h_args=h_args,
                )
                likelihoods[e, c] = marginal_likelihood

            # Forward pass: compute the expected complete joint log-likelihood with respect to the previous posterior state distribution and the current parameters
            complete_log_likelihood = dkf.monte_carlo_expected_joint_log_likekihood(
                z_seq=z_seq,
                x0=x0,
                P0=P0,
                # Below is the posterior state distribution from the E-step
                x0_smoothed=posterior["x0_smoothed"],
                P0_smoothed=posterior["P0_smoothed"],
                x_smoothed=posterior["x_smoothed"],
                P_smoothed=posterior["P_smoothed"],
                Q_seq=Q().repeat(len(z_seq), 1, 1),
                R_seq=R().repeat(len(z_seq), 1, 1),
                f_args=f_args,
                h_args=h_args,
            )

            # Update the parameters
            (-complete_log_likelihood).backward()
            optimizer.step()
            lr_scheduler.step()

            # Print the log likelihood
            print(
                f"Epoch {e + 1}/{num_epochs} Cycle {c + 1}/{num_cycles} Log Likelihood: {marginal_likelihood.item()}"
            )

    return {
        "likelihoods": likelihoods,
    }

diffkalman/filter.py
ADDED
@@ -0,0 +1,654 @@
"""
This module provides a differentiable Kalman Filter implementation designed for linear and non-linear dynamical systems perturbed by Gaussian noise.

The Kalman Filter is implemented using the PyTorch library, allowing seamless integration with neural networks to parameterize the dynamics, observation, and noise models. These parameters are optimized using Stochastic Variational Inference (SVI), which is mathematically equivalent to the classical Expectation-Maximization (EM) algorithm. The Gaussian assumption ensures that the posterior distribution of states remains analytically tractable using Rauch-Tung-Striebel smoother equations. This tractability facilitates the maximization of the log-likelihood of the observed data, given the model parameters.

Key Features:
- Differentiable Kalman Filter implementation.
- Flexible state transition and observation models (linear/non-linear, learnable, or fixed).
- Automatic computation of Jacobians using PyTorch's autograd functionality.
- Support for Monte Carlo sampling to evaluate the expected joint log-likelihood.

Classes:
    DifferentiableKalmanFilter: Implements the core Kalman Filter algorithm for linear and non-linear dynamical systems.

Functions:
    joint_jacobian_transform: Computes joint Jacobians for state and observation functions.
    log_likelihood: Computes the log-likelihood for Gaussian distributions.
    gaussian_log_likelihood: Computes Gaussian log-likelihoods for given inputs.

Example:
    # Define custom state transition and observation functions
    f = MyStateTransitionFunction()
    h = MyObservationFunction()

    # Initialize the Kalman filter
    kalman_filter = DifferentiableKalmanFilter(dim_x=4, dim_z=2, f=f, h=h)

    # Perform filtering on a sequence of observations
    results = kalman_filter.sequence_filter(
        z_seq=observations,
        x0=initial_state,
        P0=initial_covariance,
        Q=process_noise_covariance,
        R=observation_noise_covariance,
    )

Dependencies:
    - PyTorch: For deep learning integration and differentiability.
    - Custom Modules: `joint_jacobian_transform`, `negative_log_likelihood`.

__all__ = ["DifferentiableKalmanFilter"]
"""

import torch
import torch.nn as nn
from .joint_jacobian_transform import joint_jacobian_transform
from .negative_log_likelihood import log_likelihood, gaussain_log_likelihood


__all__ = ["DiffrentiableKalmanFilter"]


class DiffrentiableKalmanFilter(nn.Module):
    """
    Implements a differentiable Kalman Filter for linear and non-linear dynamical systems.

    This class provides methods for prediction, update, and smoothing, enabling filtering of
    sequences of observations while allowing the use of learnable neural network models for
    state transition and observation functions.

    Args:
        dim_x (int): Dimensionality of the state space.
        dim_z (int): Dimensionality of the observation space.
        f (nn.Module): State transition function, which can be linear/non-linear and learnable or fixed.
        h (nn.Module): Observation function, which can be linear/non-linear and learnable or fixed.
        mc_samples (int, optional): Number of Monte Carlo samples used for expected log-likelihood computation. Defaults to 100.

    Attributes:
        dim_x (int): Dimensionality of the state space.
        dim_z (int): Dimensionality of the observation space.
        f (nn.Module): State transition function.
        h (nn.Module): Observation function.
        I (torch.Tensor): Identity matrix of size `(dim_x, dim_x)`.
        _f_joint (Callable): Joint Jacobian function of the state transition function.
        _h_joint (Callable): Joint Jacobian function of the observation function.

    Methods:
        predict: Performs state prediction given the current state and covariance.
        update: Updates the state estimate based on the current observation.
        predict_update: Combines prediction and update steps for a single observation.
        sequence_filter: Filters a sequence of observations to estimate states over time.
        sequence_smooth: Smooths a sequence of observations for refined state estimates.

    Note:
        - PyTorch's autograd functionality is utilized to compute Jacobians of the state transition
          and observation functions, eliminating the need for manual derivation.
        - Monte Carlo sampling is used for approximating the expected log-likelihood in non-linear scenarios.
    """

    def __init__(
        self, dim_x: int, dim_z: int, f: nn.Module, h: nn.Module, mc_samples: int = 100
    ) -> None:
        """Initializes the DiffrentiableKalmanFilter object.

        Args:
            dim_x (int): Dimensionality of the state space.
            dim_z (int): Dimensionality of the observation space.
            f (nn.Module): State transition function, which may be parameterized by a neural network and can be linear/non-linear, learnable or fixed.
            h (nn.Module): Observation function, which may be parameterized by a neural network and can be linear/non-linear, learnable or fixed.
            mc_samples (int): Number of Monte Carlo samples to draw to calculate the expected joint log-likelihood. Default is 100.

        Returns:
            None

        Note:
            - The state transition function signature is f(x: torch.Tensor, *args) -> torch.Tensor.
            - The observation function signature is h(x: torch.Tensor, *args) -> torch.Tensor.
        """

        super().__init__()

        # Store the dimensionality of the state and observation space
        self.dim_x = dim_x
        self.dim_z = dim_z

        # Store the state transition and observation functions
        self.f = f
        self.h = h

        # Store the number of Monte Carlo samples
        self.mc_samples = mc_samples

        # Register the Jacobian functions of the state transition and observation functions
        self._f_joint = joint_jacobian_transform(self.f)
        self._h_joint = joint_jacobian_transform(self.h)

        # Identity matrix of shape (dim_x, dim_x)
        self.register_buffer("I", torch.eye(dim_x))

    def predict(
        self,
        x: torch.Tensor,
        P: torch.Tensor,
        Q: torch.Tensor,
        f_args: tuple = (),
    ) -> dict[str, torch.Tensor]:
        """
        Predicts the next state given the current state and the state transition function.

        Args:
            x (torch.Tensor): Current state estimate of shape `(dim_x,)`.
            P (torch.Tensor): Current state covariance of shape `(dim_x, dim_x)`.
            Q (torch.Tensor): Process noise covariance of shape `(dim_x, dim_x)`.
            f_args (tuple, optional): Additional arguments to the state transition function.

        Returns:
            dict[str, torch.Tensor]:
                - `x_prior`: Predicted state estimate of shape `(dim_x,)`.
                - `P_prior`: Predicted state covariance of shape `(dim_x, dim_x)`.
                - `state_jacobian`: Jacobian of the state transition function of shape `(dim_x, dim_x)`.

        Note:
            - The control input, if any, can be incorporated into the state transition function by passing it as an argument in `f_args`.
        """
        # Compute the predicted state estimate and covariance
        F, x_prior = self._f_joint(x, *f_args)

        P_prior = F @ P @ F.T + Q

        return {"x_prior": x_prior, "P_prior": P_prior, "state_jacobian": F}

    def update(
        self,
        z: torch.Tensor,
        x_prior: torch.Tensor,
        P_prior: torch.Tensor,
        R: torch.Tensor,
        h_args: tuple = (),
    ) -> dict[str, torch.Tensor]:
        """
        Updates the state estimate based on the current observation and observation model.

        Args:
            z (torch.Tensor): Current observation of shape `(dim_z,)`.
            x_prior (torch.Tensor): Predicted state estimate of shape `(dim_x,)`.
            P_prior (torch.Tensor): Predicted state covariance of shape `(dim_x, dim_x)`.
            R (torch.Tensor): Observation noise covariance of shape `(dim_z, dim_z)`.
            h_args (tuple, optional): Additional arguments to the observation function.

        Returns:
            dict[str, torch.Tensor]:
                - `x_post`: Updated state estimate of shape `(dim_x,)`.
                - `P_post`: Updated state covariance of shape `(dim_x, dim_x)`.
                - `innovation`: Observation residual of shape `(dim_z,)`.
                - `innovation_covariance`: Innovation covariance of shape `(dim_z, dim_z)`.
                - `observation_jacobian`: Jacobian of the observation function of shape `(dim_z, dim_x)`.
                - `kalman_gain`: Kalman gain of shape `(dim_x, dim_z)`.

        Note:
            - PyTorch's autograd is used to compute the Jacobian of the observation function, ensuring compatibility with differentiable models.
        """
        # Compute the Jacobian of the observation function
        H, z_pred = self._h_joint(x_prior, *h_args)

        # Compute the innovation
        y = z - z_pred
        # Compute the innovation covariance matrix
        S = H @ P_prior @ H.T + R
        # Compute the Kalman gain
        K = P_prior @ H.T @ torch.linalg.inv(S)

        # Update the state vector
        x_post = x_prior + K @ y
        # Update the state covariance matrix using the Joseph form
        factor = self.I - K @ H
        P_post = factor @ P_prior @ factor.T + K @ R @ K.T

        return {
            "x_post": x_post,
            "P_post": P_post,
            "innovation": y,
            "innovation_covariance": S,
            "observation_jacobian": H,
            "kalman_gain": K,
        }

    def predict_update(
        self,
        z: torch.Tensor,
        x: torch.Tensor,
        P: torch.Tensor,
        Q: torch.Tensor,
        R: torch.Tensor,
        f_args: tuple = (),
        h_args: tuple = (),
    ) -> dict[str, torch.Tensor]:
        """
        Combines prediction and update steps for a single observation.

        Args:
            z (torch.Tensor): Current observation of shape `(dim_z,)`.
            x (torch.Tensor): Current state estimate of shape `(dim_x,)`.
            P (torch.Tensor): Current state covariance of shape `(dim_x, dim_x)`.
            Q (torch.Tensor): Process noise covariance of shape `(dim_x, dim_x)`.
            R (torch.Tensor): Observation noise covariance of shape `(dim_z, dim_z)`.
            f_args (tuple, optional): Additional arguments to the state transition function.
            h_args (tuple, optional): Additional arguments to the observation function.

        Returns:
            dict[str, torch.Tensor]:
                - `x_post`: Updated state estimate of shape `(dim_x,)`.
                - `P_post`: Updated state covariance of shape `(dim_x, dim_x)`.
                - `innovation`: Observation residual of shape `(dim_z,)`.
                - `innovation_covariance`: Innovation covariance of shape `(dim_z, dim_z)`.
                - `kalman_gain`: Kalman gain of shape `(dim_x, dim_z)`.

        Note:
            - This method is particularly useful for real-time filtering where observations arrive sequentially.
        """
        # Predict the next state
        prediction = self.predict(x=x, P=P, Q=Q, f_args=f_args)

        # Update the state estimate
        update = self.update(
            z=z,
            x_prior=prediction["x_prior"],
            P_prior=prediction["P_prior"],
            R=R,
            h_args=h_args,
        )

        return {**prediction, **update}

    def sequence_filter(
        self,
        z_seq: torch.Tensor,
        x0: torch.Tensor,
        P0: torch.Tensor,
        Q: torch.Tensor,
        R: torch.Tensor,
        f_args: tuple = (),
        h_args: tuple = (),
    ) -> dict[str, torch.Tensor]:
        """
        Filters a sequence of observations to estimate states over time.

        Args:
            z_seq (torch.Tensor): Sequence of observations of shape `(T, dim_z)`, where `T` is the number of time steps.
            x0 (torch.Tensor): Initial state estimate of shape `(dim_x,)`.
            P0 (torch.Tensor): Initial state covariance of shape `(dim_x, dim_x)`.
            Q (torch.Tensor): Process noise covariance of shape `(T, dim_x, dim_x)`.
            R (torch.Tensor): Observation noise covariance of shape `(T, dim_z, dim_z)`.
            f_args (tuple, optional): Additional arguments to the state transition function.
            h_args (tuple, optional): Additional arguments to the observation function.

        Returns:
            dict[str, torch.Tensor]:
                - `x_post`: Filtered state estimates of shape `(T, dim_x)`.
                - `P_post`: Filtered state covariances of shape `(T, dim_x, dim_x)`.
                - Additional per-step terms (`x_prior`, `P_prior`, `state_jacobian`, `innovation`, `innovation_covariance`, `observation_jacobian`, `kalman_gain`) and the sequence `log_likelihood`.

        Note:
            - Each time step is processed iteratively using the `predict` and `update` methods.
        """
        # Get the sequence length
        T = z_seq.size(0)

        # Initialize the output tensors
        outs_terms = [
            "x_prior",
            "P_prior",
            "x_post",
            "P_post",
            "state_jacobian",
            "innovation",
            "innovation_covariance",
            "observation_jacobian",
            "kalman_gain",
        ]
        outputs = {term: [] for term in outs_terms}

        # Run the Kalman Filter sequentially
        for t in range(T):
            # Predict and update on the current observation
            results = self.predict_update(
                z=z_seq[t],
                x=x0 if t == 0 else outputs["x_post"][-1],
                P=P0 if t == 0 else outputs["P_post"][-1],
                Q=Q[t],
                R=R[t],
                f_args=(args[t] for args in f_args),
                h_args=(args[t] for args in h_args),
            )

            # Update the output tensors
            for term in outs_terms:
                outputs[term].append(results[term])

        # Stack the output tensors
        for term in outs_terms:
            outputs[term] = torch.stack(outputs[term])

        # Calculate the log-likelihood
        log_like = log_likelihood(
            innovation=outputs["innovation"],
            innovation_covariance=outputs["innovation_covariance"],
        )
        outputs["log_likelihood"] = log_like

        return outputs

    def marginal_log_likelihood(
        self,
        z_seq: torch.Tensor,
        x0: torch.Tensor,
        P0: torch.Tensor,
        Q: torch.Tensor,
        R: torch.Tensor,
        f_args: tuple = (),
        h_args: tuple = (),
    ) -> torch.Tensor:
        """
        Computes the marginal log-likelihood of the observed data given the model parameters.

        Args:
            z_seq (torch.Tensor): Sequence of observations of shape `(T, dim_z)`, where `T` is the number of time steps.
            x0 (torch.Tensor): Initial state estimate of shape `(dim_x,)`.
            P0 (torch.Tensor): Initial state covariance of shape `(dim_x, dim_x)`.
            Q (torch.Tensor): Process noise covariance of shape `(T, dim_x, dim_x)`.
            R (torch.Tensor): Observation noise covariance of shape `(T, dim_z, dim_z)`.
            f_args (tuple, optional): Additional arguments to the state transition function.
            h_args (tuple, optional): Additional arguments to the observation function.

        Returns:
            torch.Tensor: The marginal log-likelihood of the observed data given the model parameters.
        """
        # Run the Kalman Filter and only track the innovation and innovation covariance
        innov = []
        innov_cov = []

        # Initialize the state estimate
        x_post = x0
        P_post = P0

        for t in range(z_seq.size(0)):
            # Predict the next state
            prediction = self.predict(
                x=x_post,
                P=P_post,
                Q=Q[t],
                f_args=(args[t] for args in f_args),
            )

            # Update the state estimate
            update = self.update(
                z=z_seq[t],
                x_prior=prediction["x_prior"],
                P_prior=prediction["P_prior"],
                R=R[t],
                h_args=(args[t] for args in h_args),
            )

            # Carry the posterior forward
            x_post = update["x_post"]
            P_post = update["P_post"]

            # Store the innovation and innovation covariance
            innov.append(update["innovation"])
            innov_cov.append(update["innovation_covariance"])

        return log_likelihood(
            innovation=torch.stack(innov),
            innovation_covariance=torch.stack(innov_cov),
        )

    def rauch_tung_striebel_smoothing(
        self,
        x0: torch.Tensor,
        P0: torch.Tensor,
        x_prior: torch.Tensor,
        P_prior: torch.Tensor,
        x_filtered: torch.Tensor,
        P_filtered: torch.Tensor,
        state_jacobian: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """
        Smooths a sequence of filtered states for refined estimates over time.

        Args:
            x0 (torch.Tensor): Initial state estimate of shape `(dim_x,)`.
            P0 (torch.Tensor): Initial state covariance of shape `(dim_x, dim_x)`.
            x_prior (torch.Tensor): Predicted state estimates of shape `(T, dim_x)`.
            P_prior (torch.Tensor): Predicted state covariances of shape `(T, dim_x, dim_x)`.
            x_filtered (torch.Tensor): Filtered state estimates of shape `(T, dim_x)`.
            P_filtered (torch.Tensor): Filtered state covariances of shape `(T, dim_x, dim_x)`.
            state_jacobian (torch.Tensor): State transition Jacobians of shape `(T, dim_x, dim_x)`.

        Returns:
            dict[str, torch.Tensor]:
                - `x_smoothed`: Smoothed state estimates of shape `(T, dim_x)`.
                - `P_smoothed`: Smoothed state covariances of shape `(T, dim_x, dim_x)`.
                - `smoothing_gain`, `x0_smoothed`, `P0_smoothed`: Smoothing gains and the smoothed initial state.

        Note:
            - This method uses Rauch-Tung-Striebel smoothing equations to refine the estimates.
            - Smoothed states incorporate information from future observations, improving the overall estimation accuracy.
        """

        # Initialize the output tensors
        outs_terms = ["x_smoothed", "P_smoothed", "smoothing_gain"]
        outputs = {term: [] for term in outs_terms}

        # The last state estimate is already the smoothed state estimate
        # by definition
        outputs["x_smoothed"].append(x_filtered[-1])
        outputs["P_smoothed"].append(P_filtered[-1])

        # Sequence length
        T = x_filtered.size(0)

        # Run the backward recursion from the second-to-last time step
        # down to the first time step
        for t in range(T - 2, -1, -1):
            # Compute the smoothing gain
            L = (
                P_filtered[t]
                @ state_jacobian[t + 1].T
                @ torch.linalg.inv(P_prior[t + 1])
            )
            # Insert the smoothing gain into the output tensor
            outputs["smoothing_gain"].insert(0, L)
            # Compute the smoothed state estimate
            outputs["x_smoothed"].insert(
                0, x_filtered[t] + L @ (outputs["x_smoothed"][0] - x_prior[t + 1])
            )
            # Compute the smoothed state covariance
            outputs["P_smoothed"].insert(
                0, P_filtered[t] + L @ (outputs["P_smoothed"][0] - P_prior[t + 1]) @ L.T
            )

        # Smooth the initial state estimate
        L0 = P0 @ state_jacobian[0].T @ torch.linalg.inv(P_prior[0])
        outputs["smoothing_gain"].insert(0, L0)
        x0_smoothed = x0 + L0 @ (outputs["x_smoothed"][0] - x_prior[0])
        P0_smoothed = P0 + L0 @ (outputs["P_smoothed"][0] - P_prior[0]) @ L0.T

        return {
            "x_smoothed": torch.stack(outputs["x_smoothed"]),
            "P_smoothed": torch.stack(outputs["P_smoothed"]),
            "smoothing_gain": torch.stack(outputs["smoothing_gain"]),
            "x0_smoothed": x0_smoothed,
            "P0_smoothed": P0_smoothed,
        }

    def sequence_smooth(
        self,
        z_seq: torch.Tensor,
        x0: torch.Tensor,
        P0: torch.Tensor,
        Q: torch.Tensor,
        R: torch.Tensor,
        f_args: tuple = (),
        h_args: tuple = (),
    ) -> dict[str, torch.Tensor]:
        """
        Applies a smoothing algorithm to a sequence of observations using a two-step process:
        Kalman filtering followed by Rauch-Tung-Striebel smoothing.

        Args:
            z_seq (torch.Tensor): Sequence of observations with shape `(seq_len, dim_z)`.
            x0 (torch.Tensor): Initial state estimate with shape `(dim_x,)`.
            P0 (torch.Tensor): Initial state covariance matrix with shape `(dim_x, dim_x)`.
            Q (torch.Tensor): Process noise covariance matrix with shape `(seq_len, dim_x, dim_x)`.
            R (torch.Tensor): Observation noise covariance matrix with shape `(seq_len, dim_z, dim_z)`.
            f_args (tuple): Optional sequence of additional arguments for the state transition function.
            h_args (tuple): Optional sequence of additional arguments for the observation function.

        Returns:
            dict[str, torch.Tensor]: A dictionary containing:
                - Filtered state estimates (`x_post`) and covariances (`P_post`).
                - Smoothed state estimates (`x_smoothed`) and covariances (`P_smoothed`).

        Notes:
            - The `f_args` and `h_args` sequences correspond to specific time steps.
            - For time-invariant `Q` or `R`, use PyTorch's `repeat` functionality to create sequences.
        """
        # Run the Kalman Filter
        filter_results = self.sequence_filter(
            z_seq=z_seq,
            x0=x0,
            P0=P0,
            Q=Q,
            R=R,
            f_args=f_args,
            h_args=h_args,
        )

        # Run the Rauch-Tung-Striebel Smoothing
        smooth_results = self.rauch_tung_striebel_smoothing(
            x0=x0,
            P0=P0,
            x_prior=filter_results["x_prior"],
            P_prior=filter_results["P_prior"],
            x_filtered=filter_results["x_post"],
            P_filtered=filter_results["P_post"],
            state_jacobian=filter_results["state_jacobian"],
        )

        return {**filter_results, **smooth_results}

    def draw_samples(
        self, n_samples: int, mu: torch.Tensor, P: torch.Tensor
    ) -> torch.Tensor:
        """
        Generates samples from a multivariate normal distribution parameterized
        by a mean vector and covariance matrix.

        Args:
            n_samples (int): Number of samples to generate.
            mu (torch.Tensor): Mean vector of the distribution with shape `(dim_x,)`.
            P (torch.Tensor): Covariance matrix of the distribution with shape `(dim_x, dim_x)`.

        Returns:
            torch.Tensor: A tensor of generated samples with shape `(n_samples, dim_x)`.

        Notes:
            - Cholesky decomposition is used to ensure numerical stability.
            - The samples are drawn using PyTorch's distribution utilities.
        """
        # Draw samples from a standard multivariate normal distribution
        X_unit_samples = (
            torch.distributions.MultivariateNormal(
                loc=torch.zeros(self.dim_x), covariance_matrix=torch.eye(self.dim_x)
            )
            .sample((n_samples,))
            .to(device=mu.device, dtype=mu.dtype)
        )

        return mu + torch.einsum("kk,jk->jk", torch.linalg.cholesky(P), X_unit_samples)

    def monte_carlo_expected_joint_log_likekihood(
        self,
        z_seq: torch.Tensor,
        x0: torch.Tensor,
        P0: torch.Tensor,
        x0_smoothed: torch.Tensor,
        P0_smoothed: torch.Tensor,
        x_smoothed: torch.Tensor,
        P_smoothed: torch.Tensor,
        Q_seq: torch.Tensor,
        R_seq: torch.Tensor,
        f_args: tuple = (),
        h_args: tuple = (),
    ) -> torch.Tensor:
        """
        Computes the Monte Carlo approximation of the expected joint log-likelihood
        of the observed data given the posterior distribution of the states.

        The posterior distribution of the states is derived from the smoothed state estimates,
        which are tractable and exact in the case of linear Gaussian models. This simplifies the
        process compared to a full-fledged Variational Inference (VI) approach, where the posterior
        distribution is approximated using a variational distribution. Here, the exact posterior
        distribution is used, as it is directly obtained from the smoothing algorithm, eliminating
        the need for variational approximations.

        In the case where the dynamics and observation models are parameterized by neural networks,
        the variational distribution corresponds to the exact posterior distribution over the unknown
        variables (e.g., parameters of the models). This formulation maintains computational efficiency
        while leveraging the expressiveness of neural networks for system modeling.

        Args:
            z_seq (torch.Tensor): Sequence of observations of shape (seq_len, dim_z).
            x0 (torch.Tensor): Initial state estimate of shape (dim_x,).
            P0 (torch.Tensor): Initial state covariance of shape (dim_x, dim_x).
            x0_smoothed (torch.Tensor): Initial smoothed state estimate of shape (dim_x,).
            P0_smoothed (torch.Tensor): Initial smoothed state covariance of shape (dim_x, dim_x).
            x_smoothed (torch.Tensor): Smoothed state estimates of shape (seq_len, dim_x).
            P_smoothed (torch.Tensor): Smoothed state covariances of shape (seq_len, dim_x, dim_x).
            Q_seq (torch.Tensor): Process noise covariance of shape (seq_len, dim_x, dim_x).
            R_seq (torch.Tensor): Observation noise covariance of shape (seq_len, dim_z, dim_z).
            f_args (tuple): Additional arguments to the state transition function.
            h_args (tuple): Additional arguments to the observation function.

        Returns:
            torch.Tensor: The Monte Carlo approximation of the expected joint log-likelihood.
        """
        # Accumulate the Monte Carlo approximation of the expected joint log-likelihood
        E_log_likelihood = 0.0
        # Use vmap to vectorize the log-likelihood computation over samples
        vllf = torch.vmap(
            func=gaussain_log_likelihood,
            in_dims=(0, None),
        )

        # Draw samples from the smoothed initial distribution
        X0_samples = self.draw_samples(
            n_samples=self.mc_samples,
            mu=x0_smoothed,
            P=P0_smoothed,
        )
        innov_x0 = X0_samples - x0
        E_log_likelihood += vllf(innov_x0, P0).mean()

        # For each measurement calculate the expected log-likelihood
        for i in range(z_seq.size(0)):
            # Calculate the measurement term
            Xt = self.draw_samples(
                n_samples=self.mc_samples,
                mu=x_smoothed[i],
                P=P_smoothed[i],
            )
            innov_z = z_seq[i] - torch.vmap(
                lambda x: self.h(x, *(args[i] for args in h_args)),
            )(Xt)
            E_log_likelihood += vllf(innov_z, R_seq[i]).mean()

            # Calculate the transition term
            Xp = self.draw_samples(
                n_samples=self.mc_samples,
                mu=x_smoothed[i - 1] if i > 0 else x0_smoothed,
                P=P_smoothed[i - 1] if i > 0 else P0_smoothed,
            )
            innov_x = Xt - torch.vmap(
                lambda x: self.f(x, *(args[i] for args in f_args)),
            )(Xp)

            E_log_likelihood += vllf(innov_x, Q_seq[i]).mean()

        return E_log_likelihood

diffkalman/joint_jacobian_transform.py
ADDED
@@ -0,0 +1,54 @@
"""Joint Jacobian Transformation to compute the Jacobian of a function with respect to its first argument and the function output."""

from torch.func import jacrev

__all__ = ["joint_jacobian_transform"]


def joint_jacobian_transform(f: callable) -> callable:
    """
    Transforms a given function f into a joint function that returns both the
    function's output and its Jacobian matrix with respect to its first argument.

    The transformed function, when called, returns a tuple where the first element
    is the Jacobian matrix of the function `f` with respect to its first argument
    and the second element is the output of the function `f`.

    Args:
        f (callable): The function to be transformed. It should be of the form
                      `f(x, *args)` where `x` is the variable with respect to
                      which the Jacobian is computed and `*args` are additional
                      arguments.

    Returns:
        callable: A new function `F` such that `F(x, *args)` returns a tuple
                  `(Df(x, *args), f(x, *args))`, where `Df(x, *args)` is the
                  Jacobian matrix of `f` with respect to `x` and `f(x, *args)`
                  is the original output of `f`.

    Example:
        >>> def my_function(x, a, b):
        >>>     return a * x + b
        >>>
        >>> F = joint_jacobian_transform(my_function)
        >>> x = torch.tensor([1.0, 2.0])
        >>> a = torch.tensor([2.0])
        >>> b = torch.tensor([3.0])
        >>> jacobian, output = F(x, a, b)
        >>> print(jacobian)  # Jacobian of my_function with respect to x
        >>> print(output)    # Output of my_function
    """

    # Define the new joint function
    def joint_func(*args, **kwargs):
        # Compute the output of the function
        output = f(*args, **kwargs)
        return output, output

    # The new joint jacrev function
    Df = jacrev(
        func=joint_func,
        argnums=0,
        has_aux=True,
    )
    return Df

diffkalman/negative_log_likelihood.py
ADDED
@@ -0,0 +1,133 @@
"""
Utility functions for calculating the negative log likelihood of observations.

This module provides functions to compute the negative log likelihood of a sequence of observations given
their predicted values and associated covariance matrices. The negative log likelihood quantifies the
likelihood of observing the data under a given model, incorporating uncertainties in the measurements.

The log likelihood for a sequence of innovations and their covariances is computed from the
logarithm of the joint distribution of the observations given the parameters. It is expressed mathematically as:

$$
\ln(\mathcal{L}(\Phi \mid \mathcal{D})) = -\frac{nk}{2} \ln(2\pi) - \frac{1}{2} \sum_{i=1}^{n} \left[ \ln |S_{i}| + \Delta y_{i}^T S_{i}^{-1} \Delta y_i \right]
$$

where $ n $ is the sequence length, $ k $ is the dimension of the innovation vector,
$ \Delta y_i $ is the innovation vector at time step $ i $, and $ S_i $ is the innovation
covariance matrix at time step $ i $.

Functions:
    - `gaussain_log_likelihood`: Calculates the Gaussian log likelihood at a single time step.
    - `log_likelihood`: Computes the total log likelihood for a sequence of time steps.
    - `negative_log_likelihood`: Computes the total negative log likelihood for a sequence of time steps.

Note:
    - The constant factor $ -\frac{k}{2} \ln(2\pi) $ is included per time step, although it does not affect the optimization process.
"""

import torch

__all__ = ["gaussain_log_likelihood", "log_likelihood"]


def gaussain_log_likelihood(
    innovation: torch.Tensor,
    innovation_covariance: torch.Tensor,
) -> torch.Tensor:
    """
    Calculates the Gaussian log likelihood at a single time step.

    Args:
        innovation (torch.Tensor): The innovation vector, representing the difference
            between the observed and predicted measurements.
            Shape: (dim_z,)
        innovation_covariance (torch.Tensor): The innovation covariance matrix,
            representing the uncertainty in the innovation.
            Shape: (dim_z, dim_z)

    Returns:
        torch.Tensor: The log likelihood value for the given innovation and its covariance.

    Example:
        >>> innovation = torch.tensor([1.0, 2.0])
        >>> innovation_covariance = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
        >>> loss = gaussain_log_likelihood(innovation, innovation_covariance)
        >>> print(loss)
    """
    # Calculate the log determinant of the innovation covariance matrix
    log_det = torch.linalg.slogdet(innovation_covariance)[1]

    k = innovation.shape[0]
    C = (-k / 2) * torch.log(
        2 * torch.tensor(torch.pi, dtype=innovation.dtype, device=innovation.device)
    )
    return C - 0.5 * (
        log_det + innovation @ torch.linalg.inv(innovation_covariance) @ innovation
    )


def log_likelihood(
    innovation: torch.Tensor,
    innovation_covariance: torch.Tensor,
) -> torch.Tensor:
    """
    Calculates the log likelihood estimation for a sequence of time steps.

    Args:
        innovation (torch.Tensor): The innovation vector for each time step,
            representing the differences between observed and predicted
            measurements over a sequence. Shape: (seq_len, dim_z)
        innovation_covariance (torch.Tensor): The innovation covariance matrix for each time step,
            representing the uncertainty in the innovations
            over a sequence. Shape: (seq_len, dim_z, dim_z)

    Returns:
        torch.Tensor: The total log likelihood value for the sequence of innovations and
            their covariances.

    Example:
        >>> innovation = torch.tensor([[1.0, 2.0], [0.5, 1.5]])
        >>> innovation_covariance = torch.tensor([
        >>>     [[1.0, 0.0], [0.0, 1.0]],
        >>>     [[0.5, 0.0], [0.0, 0.5]]
        >>> ])
        >>> total_log_likelihood = log_likelihood(innovation, innovation_covariance)
        >>> print(total_log_likelihood)
    """
    return torch.vmap(gaussain_log_likelihood, in_dims=0)(
        innovation, innovation_covariance
    ).sum()


def negative_log_likelihood(
    innovation: torch.Tensor,
    innovation_covariance: torch.Tensor,
) -> torch.Tensor:
    """
    Calculates the negative log likelihood estimation for a sequence of time steps.

    Args:
        innovation (torch.Tensor): The innovation vector for each time step,
            representing the differences between observed and predicted
            measurements over a sequence. Shape: (seq_len, dim_z)
        innovation_covariance (torch.Tensor): The innovation covariance matrix for each time step,
            representing the uncertainty in the innovations
            over a sequence. Shape: (seq_len, dim_z, dim_z)

    Returns:
        torch.Tensor: The total negative log likelihood value for the sequence of innovations and
            their covariances.

    Example:
        >>> innovation = torch.tensor([[1.0, 2.0], [0.5, 1.5]])
        >>> innovation_covariance = torch.tensor([
        >>>     [[1.0, 0.0], [0.0, 1.0]],
        >>>     [[0.5, 0.0], [0.0, 0.5]]
        >>> ])
        >>> total_negative_log_likelihood = negative_log_likelihood(innovation, innovation_covariance)
        >>> print(total_negative_log_likelihood)
    """
    return -log_likelihood(innovation, innovation_covariance)

diffkalman/py.typed
ADDED
File without changes
diffkalman/utils.py
ADDED
@@ -0,0 +1,90 @@
"""Implementation of utility modules for the differentiable Kalman filter."""

import torch
import torch.nn as nn

__all__ = [
    "SymmetricPositiveDefiniteMatrix",
    "DiagonalSymmetricPositiveDefiniteMatrix",
]


class SymmetricPositiveDefiniteMatrix(nn.Module):
    """
    Module for ensuring a symmetric positive definite matrix using a parameterized approach.

    This module constructs a symmetric positive definite matrix from an initial matrix M.
    It ensures that the resultant matrix is symmetric and has positive eigenvalues.

    Attributes:
        diag_mask (torch.Tensor): Boolean mask selecting the diagonal entries of the parameter.
        W (nn.Parameter): Parameter holding the lower-triangular Cholesky factor of the matrix.

    Methods:
        forward():
            Performs the forward pass of the module.
            Returns a symmetric positive definite matrix derived from the input parameter.
    """

    def __init__(self, M: torch.Tensor, trainable: bool = True):
        """
        Initializes the SymmetricPositiveDefiniteMatrix module.

        Args:
            M (torch.Tensor): Initial matrix for the parameter.
            trainable (bool): Flag to indicate whether the parameter is trainable. Default: True.
        """
        super().__init__()

        # Cholesky decomposition of the initial matrix
        L = torch.linalg.cholesky(M, upper=False)

        # Initialize the parameter with the lower triangular part of the Cholesky decomposition
        self.W = nn.Parameter(L, requires_grad=trainable)

        # Mask for the diagonal entries
        self.diag_mask = torch.eye(M.shape[0], dtype=torch.bool)

    def forward(self) -> torch.Tensor:
        """
        Forward pass of the SymmetricPositiveDefiniteMatrix module.

        Returns:
            torch.Tensor: Symmetric positive definite matrix derived from the input parameter.
        """
        # Make the diagonal entries positive
        L = torch.tril(self.W)
        L[self.diag_mask] = torch.abs(L[self.diag_mask])

        return L @ L.T


class DiagonalSymmetricPositiveDefiniteMatrix(nn.Module):
    """
    A PyTorch module representing a diagonal symmetric positive definite matrix.

    This module takes a diagonal matrix `M` as input and constructs a diagonal symmetric positive definite matrix `W`.
    The diagonal entries of `W` are initialized with the diagonal entries of `M` and can be trained if `trainable` is set to True.

    Args:
        M (torch.Tensor): The diagonal matrix used to initialize the diagonal entries of `W`.
        trainable (bool, optional): Whether the diagonal entries of `W` should be trainable. Defaults to True.

    Returns:
        torch.Tensor: The diagonal symmetric positive definite matrix `W`.
    """

    def __init__(self, M: torch.Tensor, trainable: bool = True):
        super().__init__()

        self.W = nn.Parameter(torch.diag(M.diagonal()), requires_grad=trainable)

    def forward(self) -> torch.Tensor:
        """
        Forward pass of the module.

        Returns:
            torch.Tensor: The diagonal symmetric positive definite matrix `W`.
        """
        # Make the diagonal entries positive
        return torch.abs(self.W)

diffkalman-0.1.0.dist-info/METADATA
ADDED
@@ -0,0 +1,183 @@
Metadata-Version: 2.4
Name: diffkalman
Version: 0.1.0
Summary: A differentiable Kalman filter library for auto-tuning Kalman filters.
Author-email: hades <nischalbhattaraipi@gmail.com>
License-File: LICENSE
Requires-Python: >=3.11
Requires-Dist: numpy>=2.2.1
Requires-Dist: tqdm>=4.67.1
Provides-Extra: cpu
Requires-Dist: torch>=2.5.1; extra == 'cpu'
Provides-Extra: cu124
Requires-Dist: torch>=2.5.1; extra == 'cu124'
Description-Content-Type: text/markdown

# Differentiable Kalman Filter

A PyTorch-based implementation of a differentiable Kalman Filter designed for both linear and non-linear dynamical systems with Gaussian noise. This module seamlessly integrates with neural networks, enabling learnable dynamics, observation, and noise models optimized through Stochastic Variational Inference (SVI).

## Features

- **Fully Differentiable**: End-to-end differentiable implementation compatible with PyTorch's autograd
- **Flexible Models**: Support for both linear and non-linear state transition and observation models
- **Neural Network Integration**: Models can be parameterized using neural networks
- **Automatic Jacobian Computation**: Utilizes PyTorch's autograd for derivative calculations (see the sketch after this list)
- **Monte Carlo Sampling**: Supports evaluation of the expected joint log-likelihood to perform Expectation-Maximization (EM) learning
- **Rauch-Tung-Striebel Smoothing**: Implements forward-backward smoothing for improved state estimation using the RTS algorithm

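The automatic Jacobian computation is provided by `joint_jacobian_transform` (see `diffkalman/joint_jacobian_transform.py` earlier in this diff). A minimal sketch of what the wrapper does, assuming the package is importable; the toy observation function and tensor values here are illustrative only:

```python
import torch
from diffkalman.joint_jacobian_transform import joint_jacobian_transform

# A toy non-linear observation function (illustrative only)
def h(x: torch.Tensor) -> torch.Tensor:
    return torch.stack([x[0] ** 2, torch.sin(x[1])])

# Wrap it so each call returns (Jacobian, output), computed via torch.func.jacrev
h_joint = joint_jacobian_transform(h)

x = torch.tensor([1.0, 0.5])
H, z_pred = h_joint(x)  # H: (2, 2) Jacobian of h at x, z_pred: h(x)
print(H, z_pred)
```

This is the same transform the filter applies internally to both the state transition and observation modules.
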
## Installation

```bash
pip install torch # Required dependency
# Add your package installation command here
```

## Quick Start

Here's a simple example of using the Differentiable Kalman Filter:

```python
import torch
from diffkalman import DifferentiableKalmanFilter
from diffkalman.utils import SymmetricPositiveDefiniteMatrix
from diffkalman.em_loop import em_updates

# Define custom state transition and observation functions
class StateTransition(torch.nn.Module):
    def forward(self, x, *args):
        # Your state transition logic here
        return x

class ObservationModel(torch.nn.Module):
    def forward(self, x, *args):
        # Your observation logic here
        return x

# Initialize the filter
f = StateTransition()
h = ObservationModel()
Q = SymmetricPositiveDefiniteMatrix(M=torch.eye(4), trainable=True)
R = SymmetricPositiveDefiniteMatrix(M=torch.eye(2), trainable=True)
kalman_filter = DifferentiableKalmanFilter(
    dim_x=4,  # State dimension
    dim_z=2,  # Observation dimension
    f=f,      # State transition function
    h=h       # Observation function
)

# Run the filter
results = kalman_filter.sequence_filter(
    z_seq=observations,                     # Shape: (T, dim_z)
    x0=initial_state,                       # Shape: (dim_x,)
    P0=initial_covariance,                  # Shape: (dim_x, dim_x)
    Q=Q().repeat(len(observations), 1, 1),  # Shape: (T, dim_x, dim_x)
    R=R().repeat(len(observations), 1, 1)   # Shape: (T, dim_z, dim_z)
)
```

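The dictionary returned by `sequence_filter` (and extended by `sequence_smooth`) carries the full per-step history. A short sketch of inspecting it, continuing the hypothetical variables from the snippet above; the keys follow the return values in `diffkalman/filter.py`:

```python
x_filtered = results["x_post"]        # (T, dim_x) filtered state means
P_filtered = results["P_post"]        # (T, dim_x, dim_x) filtered covariances
innovations = results["innovation"]   # (T, dim_z) measurement residuals
log_lik = results["log_likelihood"]   # scalar log-likelihood of the sequence

# Smoothing returns the filtered terms plus the smoothed estimates
smoothed = kalman_filter.sequence_smooth(
    z_seq=observations,
    x0=initial_state,
    P0=initial_covariance,
    Q=Q().repeat(len(observations), 1, 1),
    R=R().repeat(len(observations), 1, 1),
)
x_smoothed = smoothed["x_smoothed"]   # (T, dim_x) smoothed state means
P_smoothed = smoothed["P_smoothed"]   # (T, dim_x, dim_x) smoothed covariances
```
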
## Detailed Usage

### State Estimation

The module provides three main estimation methods:

1. **Filtering**: Forward pass only
```python
filtered_results = kalman_filter.sequence_filter(
    z_seq=observations,
    x0=initial_state,
    P0=initial_covariance,
    Q=process_noise,
    R=observation_noise
)
```

2. **Smoothing**: Forward-backward pass
```python
smoothed_results = kalman_filter.sequence_smooth(
    z_seq=observations,
    x0=initial_state,
    P0=initial_covariance,
    Q=process_noise,
    R=observation_noise
)
```

3. **Single-step Prediction**: For real-time applications
```python
step_result = kalman_filter.predict_update(
    z=current_observation,
    x=current_state,
    P=current_covariance,
    Q=process_noise,
    R=observation_noise
)
```

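For streaming data, the single-step API composes into a simple loop. A sketch under the same hypothetical variable names as above, carrying only the posterior mean and covariance between steps:

```python
x, P = initial_state, initial_covariance
for z in observation_stream:  # any iterable of (dim_z,) observation tensors
    out = kalman_filter.predict_update(
        z=z, x=x, P=P, Q=process_noise, R=observation_noise
    )
    x, P = out["x_post"], out["P_post"]  # carry the posterior forward
```
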
### Parameter Learning

The module supports learning the model parameters through backpropagation, using the negative expected joint log-likelihood of the
data as the loss function.

```python
# Define the optimizer and a learning-rate scheduler
optimizer = torch.optim.Adam(params=[
    {'params': kalman_filter.f.parameters()},
    {'params': kalman_filter.h.parameters()},
    {'params': Q.parameters()},
    {'params': R.parameters()}
])
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)  # any torch scheduler works

NUM_EPOCHS = 10
NUM_CYCLES = 10

# Run the EM loop
marginal_likelihoods = em_updates(
    dkf=kalman_filter,
    z_seq=observations,
    x0=initial_state,
    P0=initial_covariance,
    Q=Q,
    R=R,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    num_cycles=NUM_CYCLES,
    num_epochs=NUM_EPOCHS
)
```
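
Once the EM loop finishes, the learned noise covariances can be read back from the `SymmetricPositiveDefiniteMatrix` modules, whose forward pass rebuilds the full matrix from the trainable Cholesky factor. A short sketch reusing the hypothetical variables defined above:

```python
with torch.no_grad():
    Q_learned = Q()  # (dim_x, dim_x) learned process noise covariance
    R_learned = R()  # (dim_z, dim_z) learned measurement noise covariance
    final_ll = kalman_filter.marginal_log_likelihood(
        z_seq=observations,
        x0=initial_state,
        P0=initial_covariance,
        Q=Q_learned.repeat(len(observations), 1, 1),
        R=R_learned.repeat(len(observations), 1, 1),
    )
print(final_ll.item())
```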

## API Reference

### DifferentiableKalmanFilter

Main class implementing the Kalman Filter algorithm.

#### Constructor Parameters:
- `dim_x` (int): State space dimension
- `dim_z` (int): Observation space dimension
- `f` (nn.Module): State transition function
- `h` (nn.Module): Observation function
- `mc_samples` (int, optional): Number of Monte Carlo samples for log-likelihood estimation

#### Key Methods:
- `predict`: State prediction step
- `update`: Measurement update step
- `predict_update`: Combined prediction and update
- `sequence_filter`: Full sequence filtering
- `sequence_smooth`: Full sequence smoothing
- `marginal_log_likelihood`: Compute marginal log-likelihood
- `monte_carlo_expected_joint_log_likekihood`: Estimate expected joint log-likelihood

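The noise covariance modules in `diffkalman.utils` complement this class. A brief sketch of constructing both variants from initial guesses (the initial values here are illustrative):

```python
import torch
from diffkalman.utils import (
    SymmetricPositiveDefiniteMatrix,
    DiagonalSymmetricPositiveDefiniteMatrix,
)

# Full covariance, parameterized through its lower-triangular Cholesky factor
Q = SymmetricPositiveDefiniteMatrix(M=0.1 * torch.eye(4), trainable=True)

# Diagonal covariance, parameterized directly by its (positive) diagonal entries
R = DiagonalSymmetricPositiveDefiniteMatrix(M=0.5 * torch.eye(2), trainable=True)

print(Q().shape, R().shape)  # torch.Size([4, 4]) torch.Size([2, 2])
```
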
## Requirements

- Python >= 3.11
- PyTorch >= 2.5.1
- NumPy >= 2.2.1
- tqdm >= 4.67.1

## Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

diffkalman-0.1.0.dist-info/RECORD
ADDED
@@ -0,0 +1,11 @@
diffkalman/__init__.py,sha256=owxzQKRKYYZU_0bQ9QEIanoOx-Fv138JSFl8VdvQfI8,180
diffkalman/em_loop.py,sha256=j-OJEGuotI-3z8l35joevBIMWZeNOvaKNUr1QLqzRbM,4682
diffkalman/filter.py,sha256=xGk-JqZlf7vhHFk3wV3GQQN9fZrf5XbGezECK6iCo2U,27412
diffkalman/joint_jacobian_transform.py,sha256=HRmWkcCwCQxrENZjJeE3RY7901bPUP9qnbqh7NRKiYw,2000
diffkalman/negative_log_likelihood.py,sha256=B1ICGZ0urrS1THxvg020rjJe3gua_AabvwB64L7Rhdo,5730
diffkalman/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
diffkalman/utils.py,sha256=ovlLCdPJksAA5mfGtDaNCqqA7hLBhH5-IMRdGvMQ22o,3230
diffkalman-0.1.0.dist-info/METADATA,sha256=s50Zc-lmWCINrgyBHP-SMgW-WvbtZYGBV9ALgCuE0UY,5423
diffkalman-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
diffkalman-0.1.0.dist-info/licenses/LICENSE,sha256=IHkfLg6o-U7zWIh6oiJUp0j0abQ59foTPoB1OLEPleQ,1063
diffkalman-0.1.0.dist-info/RECORD,,

diffkalman-0.1.0.dist-info/licenses/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 HadesX

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.