diffkalman-0.1.0-py3-none-any.whl

diffkalman/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ """Top level module for diffkalman package."""
2
+ from .filter import DiffrentiableKalmanFilter
3
+ from .em_loop import em_updates
4
+
5
+
6
+ __all__ = ['DiffrentiableKalmanFilter', 'em_updates']
diffkalman/em_loop.py ADDED
@@ -0,0 +1,115 @@
1
+ """The EM loop module that implements the EM algorithm for the Differentiable Kalman Filter."""
2
+
3
+ from .filter import DiffrentiableKalmanFilter
4
+ from .utils import SymmetricPositiveDefiniteMatrix
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
+ __all__ = [
10
+ "em_updates",
11
+ ]
12
+
13
+
14
+ def em_updates(
15
+ dkf: DiffrentiableKalmanFilter,
16
+ z_seq: torch.Tensor,
17
+ x0: torch.Tensor,
18
+ P0: torch.Tensor,
19
+ Q: SymmetricPositiveDefiniteMatrix,
20
+ R: SymmetricPositiveDefiniteMatrix,
21
+ optimizer: torch.optim.Optimizer,
22
+ lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
23
+ num_cycles: int = 20,
24
+ num_epochs: int = 100,
25
+ h_args: tuple = (),
26
+ f_args: tuple = (),
27
+ ) -> dict:
28
+ """A sample implementation of the EM algorithm for the Differentiable Kalman Filter.
29
+
30
+ Args:
31
+ dkf (DiffrentiableKalmanFilter): The differentiable Kalman filter whose parameters are being learned.
+ z_seq (torch.Tensor): The noisy measurement sequence. Dimension: (seq_len, obs_dim)
32
+ x0 (torch.Tensor): The initial state vector. Dimension: (state_dim,)
33
+ P0 (torch.Tensor): The initial covariance matrix of the state vector. Dimension: (state_dim, state_dim)
34
+ Q (SymmetricPositiveDefiniteMatrix): The process noise covariance matrix module.
35
+ R (SymmetricPositiveDefiniteMatrix): The measurement noise covariance matrix module.
36
+ optimizer (torch.optim.Optimizer): The optimizer
37
+ lr_scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler
38
+ num_cycles (int): The number of cycles. Default: 20
39
+ num_epochs (int): The number of epochs. Default: 100
40
+ h_args (tuple): Additional arguments for the measurement model. Default: ()
41
+ f_args (tuple): Additional arguments for the transition function. Default: ()
42
+
43
+ Returns:
44
+ dict: log likelihoods of the model with respect to the number of epochs and cycles
45
+
46
+ Note:
47
+ - Use the SymmetricPositiveDefiniteMatrix module to ensure that the process noise covariance matrix Q and the measurement noise covariance matrix R are symmetric and positive definite.
48
+ - The optimizer and the learning rate scheduler should be initialized before calling this function.
49
+ - The measurement model and the transition function should be differentiable torch modules.
50
+ """
51
+ likelihoods = torch.zeros(num_epochs, num_cycles)
52
+
53
+ # Perform EM updates for num_epochs
54
+ for e in range(num_epochs):
55
+
56
+ ## The E-step
57
+ ## Without gradient tracking, compute the posterior state distribution w.r.t. the current parameter values
58
+ with torch.no_grad():
59
+ posterior = dkf.sequence_smooth(
60
+ z_seq=z_seq,
61
+ x0=x0,
62
+ P0=P0,
63
+ Q=Q().repeat(len(z_seq), 1, 1),
64
+ R=R().repeat(len(z_seq), 1, 1),
65
+ f_args=f_args,
66
+ h_args=h_args,
67
+ )
68
+
69
+ ## The M-step (Update the parameters) with respect to the current posterior state distribution for num_cycles
70
+ for c in range(num_cycles):
71
+ # Zero the gradients
72
+ optimizer.zero_grad()
73
+
74
+ # Compute the marginal likelihood for logging
75
+ with torch.no_grad():
76
+ marginal_likelihood = dkf.marginal_log_likelihood(
77
+ z_seq=z_seq,
78
+ x0=x0,
79
+ P0=P0,
80
+ Q=Q().repeat(len(z_seq), 1, 1),
81
+ R=R().repeat(len(z_seq), 1, 1),
82
+ f_args=f_args,
83
+ h_args=h_args,
84
+ )
85
+ likelihoods[e, c] = marginal_likelihood
86
+
87
+ # Forward pass: compute the expected complete-data joint log-likelihood under the previous posterior state distribution and the current parameters
88
+ complete_log_likelihood = dkf.monte_carlo_expected_joint_log_likekihood(
89
+ z_seq=z_seq,
90
+ x0=x0,
91
+ P0=P0,
92
+ # below represents the posterior state distribution
93
+ x0_smoothed=posterior["x0_smoothed"],
94
+ P0_smoothed=posterior["P0_smoothed"],
95
+ x_smoothed=posterior["x_smoothed"],
96
+ P_smoothed=posterior["P_smoothed"],
97
+ Q_seq=Q().repeat(len(z_seq), 1, 1),
98
+ R_seq=R().repeat(len(z_seq), 1, 1),
99
+ f_args=f_args,
100
+ h_args=h_args,
101
+ )
102
+
103
+ # Update the parameters
104
+ (-complete_log_likelihood).backward()
105
+ optimizer.step()
106
+ lr_scheduler.step()
107
+
108
+ # Print the log likelihood
109
+ print(
110
+ f"Epoch {e + 1}/{num_epochs} Cycle {c + 1}/{num_cycles} Log Likelihood: {marginal_likelihood.item()}"
111
+ )
112
+
113
+ return {
114
+ "likelihoods": likelihoods,
115
+ }
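Below is a minimal, self-contained sketch of driving `em_updates`; the random-walk transition, identity observation model, and all shapes and hyperparameters are illustrative assumptions, not part of the package.

```python
# Hypothetical end-to-end sketch of em_updates (illustrative model and data).
import torch
import torch.nn as nn
from diffkalman import DiffrentiableKalmanFilter, em_updates
from diffkalman.utils import SymmetricPositiveDefiniteMatrix


class RandomWalk(nn.Module):
    def forward(self, x, *args):
        return x  # identity transition: x_{t+1} = x_t + noise


class IdentityObservation(nn.Module):
    def forward(self, x, *args):
        return x  # observe the full state directly


dim = 2
dkf = DiffrentiableKalmanFilter(dim_x=dim, dim_z=dim, f=RandomWalk(), h=IdentityObservation())
Q = SymmetricPositiveDefiniteMatrix(0.1 * torch.eye(dim))   # learnable process noise
R = SymmetricPositiveDefiniteMatrix(0.5 * torch.eye(dim))   # learnable measurement noise

z_seq = torch.randn(50, dim)   # replace with real measurements
x0, P0 = torch.zeros(dim), torch.eye(dim)

optimizer = torch.optim.Adam(list(Q.parameters()) + list(R.parameters()), lr=1e-2)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)

history = em_updates(
    dkf=dkf, z_seq=z_seq, x0=x0, P0=P0, Q=Q, R=R,
    optimizer=optimizer, lr_scheduler=lr_scheduler,
    num_cycles=5, num_epochs=3,
)
print(history["likelihoods"].shape)  # (num_epochs, num_cycles)
```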
diffkalman/filter.py ADDED
@@ -0,0 +1,654 @@
1
+ """
2
+ This module provides a differentiable Kalman Filter implementation designed for linear and non-linear dynamical systems perturbed by Gaussian noise.
3
+
4
+ The Kalman Filter is implemented using the PyTorch library, allowing seamless integration with neural networks to parameterize the dynamics, observation, and noise models. These parameters are optimized using Stochastic Variational Inference (SVI), which is mathematically equivalent to the classical Expectation-Maximization (EM) algorithm. The Gaussian assumption ensures that the posterior distribution of states remains analytically tractable using Rauch-Tung-Striebel smoother equations. This tractability facilitates the maximization of the log-likelihood of the observed data, given the model parameters.
5
+
6
+ Key Features:
7
+ - Differentiable Kalman Filter implementation.
8
+ - Flexible state transition and observation models (linear/non-linear, learnable, or fixed).
9
+ - Automatic computation of Jacobians using PyTorch's autograd functionality.
10
+ - Support for Monte Carlo sampling to evaluate the expected joint log-likelihood.
11
+
12
+ Classes:
13
+ DiffrentiableKalmanFilter: Implements the core Kalman Filter algorithm for linear and non-linear dynamical systems.
14
+
15
+ Functions:
16
+ joint_jacobian_transform: Computes joint Jacobians for state and observation functions.
17
+ log_likelihood: Computes the log-likelihood for Gaussian distributions.
18
+ gaussian_log_likelihood: Computes Gaussian log-likelihoods for given inputs.
19
+
20
+ Example:
21
+ # Define custom state transition and observation functions
22
+ f = MyStateTransitionFunction()
23
+ h = MyObservationFunction()
24
+
25
+ # Initialize the Kalman filter
26
+ kalman_filter = DiffrentiableKalmanFilter(dim_x=4, dim_z=2, f=f, h=h)
27
+
28
+ # Perform filtering on a sequence of observations
29
+ results = kalman_filter.sequence_filter(
30
+ z_seq=observations,
31
+ x0=initial_state,
32
+ P0=initial_covariance,
33
+ Q=process_noise_covariance,
34
+ R=observation_noise_covariance,
35
+ )
36
+
37
+ Dependencies:
38
+ - PyTorch: For deep learning integration and differentiability.
39
+ - Custom Modules: `joint_jacobian_transform`, `negative_log_likelihood`.
40
+
41
+ __all__ = ["DifferentiableKalmanFilter"]
42
+ """
43
+
44
+ import torch
45
+ import torch.nn as nn
46
+ from .joint_jacobian_transform import joint_jacobian_transform
47
+ from .negative_log_likelihood import log_likelihood, gaussain_log_likelihood
48
+
49
+
50
+ __all__ = ["DiffrentiableKalmanFilter"]
51
+
52
+
53
+ class DiffrentiableKalmanFilter(nn.Module):
54
+ """
55
+ Implements a differentiable Kalman Filter for linear and non-linear dynamical systems.
56
+
57
+ This class provides methods for prediction, update, and smoothing, enabling filtering of
58
+ sequences of observations while allowing the use of learnable neural network models for
59
+ state transition and observation functions.
60
+
61
+ Args:
62
+ dim_x (int): Dimensionality of the state space.
63
+ dim_z (int): Dimensionality of the observation space.
64
+ f (nn.Module): State transition function, which can be linear/non-linear and learnable or fixed.
65
+ h (nn.Module): Observation function, which can be linear/non-linear and learnable or fixed.
66
+ mc_samples (int, optional): Number of Monte Carlo samples used for expected log-likelihood computation. Defaults to 100.
67
+
68
+ Attributes:
69
+ dim_x (int): Dimensionality of the state space.
70
+ dim_z (int): Dimensionality of the observation space.
71
+ f (nn.Module): State transition function.
72
+ h (nn.Module): Observation function.
73
+ I (torch.Tensor): Identity matrix of size `(dim_x, dim_x)`.
74
+ _f_joint (Callable): Joint Jacobian function of the state transition function.
75
+ _h_joint (Callable): Joint Jacobian function of the observation function.
76
+
77
+ Methods:
78
+ predict: Performs state prediction given the current state and covariance.
79
+ update: Updates the state estimate based on the current observation.
80
+ predict_update: Combines prediction and update steps for a single observation.
81
+ sequence_filter: Filters a sequence of observations to estimate states over time.
82
+ sequence_smooth: Smooths a sequence of observations for refined state estimates.
83
+
84
+ Note:
85
+ - PyTorch's autograd functionality is utilized to compute Jacobians of the state transition
86
+ and observation functions, eliminating the need for manual derivation.
87
+ - Monte Carlo sampling is used for approximating the expected log-likelihood in non-linear scenarios.
88
+ """
89
+
90
+ def __init__(
91
+ self, dim_x: int, dim_z: int, f: nn.Module, h: nn.Module, mc_samples: int = 100
92
+ ) -> None:
93
+ """Initializes the DiffrentiableKalmanFilter object.
94
+
95
+ Args:
96
+ dim_x (int): Dimensionality of the state space.
97
+ dim_z (int): Dimensionality of the observation space.
98
+ f (nn.Module): State transition function, which may be parameterized by a neural network; it can be linear or non-linear, and learnable or fixed.
+ h (nn.Module): Observation function, which may be parameterized by a neural network; it can be linear or non-linear, and learnable or fixed.
+ mc_samples (int): Number of Monte Carlo samples drawn to compute the expected joint log-likelihood. Default is 100.
101
+
102
+ Returns:
103
+ None
104
+
105
+ Note:
106
+ - The state transition function signature is f(x: torch.Tensor, *args) -> torch.Tensor.
107
+ - The observation function signature is h(x: torch.Tensor, *args) -> torch.Tensor.
108
+ """
109
+
110
+ super().__init__()
111
+
112
+ # Store the dimensionality of the state and observation space
113
+ self.dim_x = dim_x
114
+ self.dim_z = dim_z
115
+
116
+ # Store the state transition and observation functions
117
+ self.f = f
118
+ self.h = h
119
+
120
+ # Store the number of Monte Carlo samples
121
+ self.mc_samples = mc_samples
122
+
123
+ # Register the Jacobian functions of the state transition and observation functions
124
+ self._f_joint = joint_jacobian_transform(self.f)
125
+ self._h_joint = joint_jacobian_transform(self.h)
126
+
127
+ # Identity matrix of shape (dim_x, dim_x)
128
+ self.register_buffer("I", torch.eye(dim_x))
129
+
130
+ def predict(
131
+ self,
132
+ x: torch.Tensor,
133
+ P: torch.Tensor,
134
+ Q: torch.Tensor,
135
+ f_args: tuple = (),
136
+ ) -> dict[str, torch.Tensor]:
137
+ """
138
+ Predicts the next state given the current state and the state transition function.
139
+
140
+ Args:
141
+ x (torch.Tensor): Current state estimate of shape `(dim_x,)`.
142
+ P (torch.Tensor): Current state covariance of shape `(dim_x, dim_x)`.
143
+ Q (torch.Tensor): Process noise covariance of shape `(dim_x, dim_x)`.
144
+ f_args (tuple, optional): Additional arguments to the state transition function.
145
+
146
+ Returns:
147
+ dict[str, torch.Tensor]:
148
+ - `x_prior`: Predicted state estimate of shape `(dim_x,)`.
149
+ - `P_prior`: Predicted state covariance of shape `(dim_x, dim_x)`.
150
+ - `state_jacobian`: Jacobian of the state transition function of shape `(dim_x, dim_x)`.
151
+
152
+ Note:
153
+ - The control input, if any, can be incorporated into the state transition function by passing it as an argument in `f_args`.
154
+ """
155
+ # Compute the predicted state estimate and covariance
156
+ F, x_prior = self._f_joint(x, *f_args)
157
+
158
+ P_prior = F @ P @ F.T + Q
159
+
160
+ return {"x_prior": x_prior, "P_prior": P_prior, "state_jacobian": F}
161
+
162
+ def update(
163
+ self,
164
+ z: torch.Tensor,
165
+ x_prior: torch.Tensor,
166
+ P_prior: torch.Tensor,
167
+ R: torch.Tensor,
168
+ h_args: tuple = (),
169
+ ) -> dict[str, torch.Tensor]:
170
+ """
171
+ Updates the state estimate based on the current observation and observation model.
172
+
173
+ Args:
174
+ x (torch.Tensor): Predicted state estimate of shape `(dim_x,)`.
175
+ P (torch.Tensor): Predicted state covariance of shape `(dim_x, dim_x)`.
176
+ z (torch.Tensor): Current observation of shape `(dim_z,)`.
177
+ R (torch.Tensor): Observation noise covariance of shape `(dim_z, dim_z)`.
178
+ h_args (tuple, optional): Additional arguments to the observation function.
179
+
180
+ Returns:
181
+ dict[str, torch.Tensor]:
182
+ - `x_post`: Updated state estimate of shape `(dim_x,)`.
183
+ - `P_post`: Updated state covariance of shape `(dim_x, dim_x)`.
184
+ - `innovation`: Observation residual (innovation) of shape `(dim_z,)`.
+ - `innovation_covariance`: Innovation covariance of shape `(dim_z, dim_z)`.
+ - `observation_jacobian`: Jacobian of the observation function of shape `(dim_z, dim_x)`.
+ - `kalman_gain`: Kalman gain of shape `(dim_x, dim_z)`.
187
+
188
+ Note:
189
+ - PyTorch's autograd is used to compute the Jacobian of the observation function, ensuring compatibility with differentiable models.
190
+ """
191
+ # Compute the Jacobian of the observation function
192
+ H, z_pred = self._h_joint(x_prior, *h_args)
193
+
194
+ # Compute the innovation
195
+ y = z - z_pred
196
+ # Compute the innovation covariance matrix
197
+ S = H @ P_prior @ H.T + R
198
+ # Compute the Kalman gain
199
+ K = P_prior @ H.T @ torch.linalg.inv(S)
200
+
201
+ # Update the state vector
202
+ x_post = x_prior + K @ y
203
+ # Update the state covariance matrix using joseph form
204
+ factor = self.I - K @ H
205
+ P_post = factor @ P_prior @ factor.T + K @ R @ K.T
206
+
207
+ return {
208
+ "x_post": x_post,
209
+ "P_post": P_post,
210
+ "innovation": y,
211
+ "innovation_covariance": S,
212
+ "observation_jacobian": H,
213
+ "kalman_gain": K,
214
+ }
215
+
216
+ def predict_update(
217
+ self,
218
+ z: torch.Tensor,
219
+ x: torch.Tensor,
220
+ P: torch.Tensor,
221
+ Q: torch.Tensor,
222
+ R: torch.Tensor,
223
+ f_args: tuple = (),
224
+ h_args: tuple = (),
225
+ ) -> dict[str, torch.Tensor]:
226
+ """
227
+ Combines prediction and update steps for a single observation.
228
+
229
+ Args:
230
+ x (torch.Tensor): Current state estimate of shape `(dim_x,)`.
231
+ P (torch.Tensor): Current state covariance of shape `(dim_x, dim_x)`.
232
+ z (torch.Tensor): Current observation of shape `(dim_z,)`.
233
+ Q (torch.Tensor): Process noise covariance of shape `(dim_x, dim_x)`.
234
+ R (torch.Tensor): Observation noise covariance of shape `(dim_z, dim_z)`.
235
+ f_args (tuple, optional): Additional arguments to the state transition function.
236
+ h_args (tuple, optional): Additional arguments to the observation function.
237
+
238
+ Returns:
239
+ dict[str, torch.Tensor]:
240
+ - `x_post`: Updated state estimate of shape `(dim_x,)`.
241
+ - `P_post`: Updated state covariance of shape `(dim_x, dim_x)`.
242
+ - `y`: Observation residual (innovation) of shape `(dim_z,)`.
243
+ - `S`: Innovation covariance of shape `(dim_z, dim_z)`.
244
+ - `K`: Kalman gain of shape `(dim_x, dim_z)`.
245
+
246
+ Note:
247
+ - This method is particularly useful for real-time filtering where observations arrive sequentially.
248
+ """
249
+ # Predict the next state
250
+ prediction = self.predict(x=x, P=P, Q=Q, f_args=f_args)
251
+
252
+ # Update the state estimate
253
+ update = self.update(
254
+ z=z,
255
+ x_prior=prediction["x_prior"],
256
+ P_prior=prediction["P_prior"],
257
+ R=R,
258
+ h_args=h_args,
259
+ )
260
+
261
+ return {**prediction, **update}
262
+
263
+ def sequence_filter(
264
+ self,
265
+ z_seq: torch.Tensor,
266
+ x0: torch.Tensor,
267
+ P0: torch.Tensor,
268
+ Q: torch.Tensor,
269
+ R: torch.Tensor,
270
+ f_args: tuple = (),
271
+ h_args: tuple = (),
272
+ ) -> dict[str, torch.Tensor]:
273
+ """
274
+ Filters a sequence of observations to estimate states over time.
275
+
276
+ Args:
277
+ z_seq (torch.Tensor): Sequence of observations of shape `(T, dim_z)`, where `T` is the number of time steps.
278
+ x0 (torch.Tensor): Initial state estimate of shape `(dim_x,)`.
279
+ P0 (torch.Tensor): Initial state covariance of shape `(dim_x, dim_x)`.
280
+ Q (torch.Tensor): Process noise covariance of shape `(T, dim_x, dim_x)`.
281
+ R (torch.Tensor): Observation noise covariance of shape `(T, dim_z, dim_z)`.
282
+ f_args (tuple, optional): Additional arguments to the state transition function.
283
+ h_args (tuple, optional): Additional arguments to the observation function.
284
+
285
+ Returns:
286
+ dict[str, torch.Tensor]:
287
+ - `x_post`: Filtered state estimates of shape `(T, dim_x)`.
+ - `P_post`: Filtered state covariances of shape `(T, dim_x, dim_x)`.
+ - per-step terms `x_prior`, `P_prior`, `state_jacobian`, `innovation`, `innovation_covariance`, `observation_jacobian`, `kalman_gain`, and the scalar `log_likelihood` of the sequence.
289
+
290
+ Note:
291
+ - Each time step is processed iteratively using the `predict` and `update` methods.
292
+ """
293
+ # Get the sequence length
294
+ T = z_seq.size(0)
295
+
296
+ # Initialize the output tensors
297
+ outs_terms = [
298
+ "x_prior",
299
+ "P_prior",
300
+ "x_post",
301
+ "P_post",
302
+ "state_jacobian",
303
+ "innovation",
304
+ "innovation_covariance",
305
+ "observation_jacobian",
306
+ "kalman_gain",
307
+ ]
308
+ outputs = {term: [] for term in outs_terms}
309
+
310
+ # Run the Kalman Filter sequentially
311
+ for t in range(T):
312
+ # Predict the next state
313
+ results = self.predict_update(
314
+ z=z_seq[t],
315
+ x=x0 if t == 0 else outputs["x_post"][-1],
316
+ P=P0 if t == 0 else outputs["P_post"][-1],
317
+ Q=Q[t],
318
+ R=R[t],
319
+ f_args=(args[t] for args in f_args),
320
+ h_args=(args[t] for args in h_args),
321
+ )
322
+
323
+ # Update the output tensors
324
+ for term in outs_terms:
325
+ outputs[term].append(results[term])
326
+
327
+ # Stack the output tensors
328
+ for term in outs_terms:
329
+ outputs[term] = torch.stack(outputs[term])
330
+
331
+ # Calculate the log-likelihood
332
+ log_like = log_likelihood(
333
+ innovation=outputs["innovation"],
334
+ innovation_covariance=outputs["innovation_covariance"],
335
+ )
336
+ outputs["log_likelihood"] = log_like
337
+
338
+ return outputs
339
+
340
+ def marginal_log_likelihood(
341
+ self,
342
+ z_seq: torch.Tensor,
343
+ x0: torch.Tensor,
344
+ P0: torch.Tensor,
345
+ Q: torch.Tensor,
346
+ R: torch.Tensor,
347
+ f_args: tuple = (),
348
+ h_args: tuple = (),
349
+ ) -> torch.Tensor:
350
+ """
351
+ Computes the marginal log-likelihood of the observed data given the model parameters.
352
+
353
+ Args:
354
+ z_seq (torch.Tensor): Sequence of observations of shape `(T, dim_z)`, where `T` is the number of time steps.
355
+ x0 (torch.Tensor): Initial state estimate of shape `(dim_x,)`.
356
+ P0 (torch.Tensor): Initial state covariance of shape `(dim_x, dim_x)`.
357
+ Q (torch.Tensor): Process noise covariance of shape `(T, dim_x, dim_x)`.
358
+ R (torch.Tensor): Observation noise covariance of shape `(T, dim_z, dim_z)`.
359
+ f_args (tuple, optional): Additional arguments to the state transition function.
360
+ h_args (tuple, optional): Additional arguments to the observation function.
361
+
362
+ Returns:
363
+ torch.Tensor: The marginal log-likelihood of the observed data given the model parameters.
364
+ """
365
+ # Run the Kalman Filter and only track innovation and innovation covariance
366
+ innov = []
367
+ innov_cov = []
368
+
369
+ # Initialize the state estimate
370
+ x_post = x0
371
+ P_post = P0
372
+
373
+ for t in range(z_seq.size(0)):
374
+ # Predict the next state
375
+ prediction = self.predict(
376
+ x=x_post,
377
+ P=P_post,
378
+ Q=Q[t],
379
+ f_args=(args[t] for args in f_args),
380
+ )
381
+
382
+ # Update the state estimate
383
+ update = self.update(
384
+ z=z_seq[t],
385
+ x_prior=prediction["x_prior"],
386
+ P_prior=prediction["P_prior"],
387
+ R=R[t],
388
+ h_args=(args[t] for args in h_args),
389
+ )
390
+
391
+ # Update the state estimate
392
+ x_post = update["x_post"]
393
+ P_post = update["P_post"]
394
+
395
+ # Store the innovation and innovation covariance
396
+ innov.append(update["innovation"])
397
+ innov_cov.append(update["innovation_covariance"])
398
+
399
+ return log_likelihood(
400
+ innovation=torch.stack(innov),
401
+ innovation_covariance=torch.stack(innov_cov),
402
+ )
403
+
404
+ def rauch_tung_striebel_smoothing(
405
+ self,
406
+ x0: torch.Tensor,
407
+ P0: torch.Tensor,
408
+ x_prior: torch.Tensor,
409
+ P_prior: torch.Tensor,
410
+ x_filtered: torch.Tensor,
411
+ P_filtered: torch.Tensor,
412
+ state_jacobian: torch.Tensor,
413
+ ) -> dict[str, torch.Tensor]:
414
+ """
415
+ Smooths a sequence of filtered states for refined estimates over time.
416
+
417
+ Args:
418
+ x0 (torch.Tensor): Initial state estimate of shape `(dim_x,)`.
+ P0 (torch.Tensor): Initial state covariance of shape `(dim_x, dim_x)`.
+ x_prior (torch.Tensor): Predicted (prior) state estimates of shape `(T, dim_x)`.
+ P_prior (torch.Tensor): Predicted (prior) state covariances of shape `(T, dim_x, dim_x)`.
+ x_filtered (torch.Tensor): Filtered state estimates of shape `(T, dim_x)`.
+ P_filtered (torch.Tensor): Filtered state covariances of shape `(T, dim_x, dim_x)`.
+ state_jacobian (torch.Tensor): State transition Jacobians of shape `(T, dim_x, dim_x)`.
422
+
423
+ Returns:
424
+ dict[str, torch.Tensor]:
425
+ - `x_smoothed`: Smoothed state estimates of shape `(T, dim_x)`.
+ - `P_smoothed`: Smoothed state covariances of shape `(T, dim_x, dim_x)`.
+ - `smoothing_gain`: Smoothing gains of shape `(T, dim_x, dim_x)`.
+ - `x0_smoothed`, `P0_smoothed`: Smoothed initial state estimate and covariance.
427
+
428
+ Note:
429
+ - This method uses Rauch-Tung-Striebel smoothing equations to refine the estimates.
430
+ - Smoothed states incorporate information from future observations, improving the overall estimation accuracy.
431
+ """
432
+
433
+ # Initialize the output tensors
434
+ outs_terms = ["x_smoothed", "P_smoothed", "smoothing_gain"]
435
+ outputs = {term: [] for term in outs_terms}
436
+
437
+ # Last state estimate is already the smoothed state estimate
438
+ # by definition
439
+ outputs["x_smoothed"].append(x_filtered[-1])
440
+ outputs["P_smoothed"].append(P_filtered[-1])
441
+
442
+ # Sequence length
443
+ T = x_filtered.size(0)
444
+
445
+ # Start the backward recursion from the second-to-last time step
446
+ # to the first time step
447
+ for t in range(T - 2, -1, -1):
448
+ # Compute the smoothing gain
449
+ L = (
450
+ P_filtered[t]
451
+ @ state_jacobian[t + 1].T
452
+ @ torch.linalg.inv(P_prior[t + 1])
453
+ )
454
+ # Insert the smoothing gain into the output tensor
455
+ outputs["smoothing_gain"].insert(0, L)
456
+ # Compute the smoothed state estimate
457
+ outputs["x_smoothed"].insert(
458
+ 0, x_filtered[t] + L @ (outputs["x_smoothed"][0] - x_prior[t + 1])
459
+ )
460
+ # Compute the smoothed state covariance
461
+ outputs["P_smoothed"].insert(
462
+ 0, P_filtered[t] + L @ (outputs["P_smoothed"][0] - P_prior[t + 1]) @ L.T
463
+ )
464
+
465
+ # Smooth the initial state estimate
466
+ L0 = P0 @ state_jacobian[0].T @ torch.linalg.inv(P_prior[0])
467
+ outputs["smoothing_gain"].insert(0, L0)
468
+ x0_smoothed = x0 + L0 @ (outputs["x_smoothed"][0] - x_prior[0])
469
+ P0_smoothed = P0 + L0 @ (outputs["P_smoothed"][0] - P_prior[0]) @ L0.T
470
+
471
+ return {
472
+ "x_smoothed": torch.stack(outputs["x_smoothed"]),
473
+ "P_smoothed": torch.stack(outputs["P_smoothed"]),
474
+ "smoothing_gain": torch.stack(outputs["smoothing_gain"]),
475
+ "x0_smoothed": x0_smoothed,
476
+ "P0_smoothed": P0_smoothed,
477
+ }
478
+
479
+ def sequence_smooth(
480
+ self,
481
+ z_seq: torch.Tensor,
482
+ x0: torch.Tensor,
483
+ P0: torch.Tensor,
484
+ Q: torch.Tensor,
485
+ R: torch.Tensor,
486
+ f_args: tuple = (),
487
+ h_args: tuple = (),
488
+ ) -> dict[str, torch.Tensor]:
489
+ """
490
+ Applies a smoothing algorithm to a sequence of observations using a two-step process:
491
+ Kalman filtering followed by Rauch-Tung-Striebel smoothing.
492
+
493
+ Args:
494
+ z_seq (torch.Tensor): Sequence of observations with shape `(seq_len, dim_z)`.
495
+ x0 (torch.Tensor): Initial state estimate with shape `(dim_x,)`.
496
+ P0 (torch.Tensor): Initial state covariance matrix with shape `(dim_x, dim_x)`.
497
+ Q (torch.Tensor): Process noise covariance matrix with shape `(seq_len, dim_x, dim_x)`.
498
+ R (torch.Tensor): Observation noise covariance matrix with shape `(seq_len, dim_z, dim_z)`.
499
+ f_args (tuple): Optional sequence of additional arguments for the state transition function.
500
+ h_args (tuple): Optional sequence of additional arguments for the observation function.
501
+
502
+ Returns:
503
+ dict[str, torch.Tensor]: A dictionary containing:
504
+ - Filtered state estimates (`x_post`) and covariances (`P_post`).
505
+ - Smoothed state estimates (`x_smoothed`) and covariances (`P_smoothed`).
506
+
507
+ Notes:
508
+ - The `f_args` and `h_args` sequences correspond to specific time steps.
509
+ - For time-invariant `Q` or `R`, use PyTorch's `repeat` functionality to create sequences.
510
+ """
511
+ # Run the Kalman Filter
512
+ filter_results = self.sequence_filter(
513
+ z_seq=z_seq,
514
+ x0=x0,
515
+ P0=P0,
516
+ Q=Q,
517
+ R=R,
518
+ f_args=f_args,
519
+ h_args=h_args,
520
+ )
521
+
522
+ # Run the Rauch-Tung-Striebel Smoothing
523
+ smooth_results = self.rauch_tung_striebel_smoothing(
524
+ x0=x0,
525
+ P0=P0,
526
+ x_prior=filter_results["x_prior"],
527
+ P_prior=filter_results["P_prior"],
528
+ x_filtered=filter_results["x_post"],
529
+ P_filtered=filter_results["P_post"],
530
+ state_jacobian=filter_results["state_jacobian"],
531
+ )
532
+
533
+ return {**filter_results, **smooth_results}
534
+
535
+ def draw_samples(
536
+ self, n_samples: int, mu: torch.Tensor, P: torch.Tensor
537
+ ) -> torch.Tensor:
538
+ """
539
+ Generates samples from a multivariate normal distribution parameterized
540
+ by a mean vector and covariance matrix.
541
+
542
+ Args:
543
+ n_samples (int): Number of samples to generate.
544
+ mu (torch.Tensor): Mean vector of the distribution with shape `(dim_x,)`.
545
+ P (torch.Tensor): Covariance matrix of the distribution with shape `(dim_x, dim_x)`.
546
+
547
+ Returns:
548
+ torch.Tensor: A tensor of generated samples with shape `(n_samples, dim_x)`.
549
+
550
+ Notes:
551
+ - Cholesky decomposition is used to ensure numerical stability.
552
+ - The samples are drawn using PyTorch's distribution utilities.
553
+ """
554
+ # Draw samples from a multivariate normal distribution
555
+ X_unit_samples = (
556
+ torch.distributions.MultivariateNormal(
557
+ loc=torch.zeros(self.dim_x), covariance_matrix=torch.eye(self.dim_x)
558
+ )
559
+ .sample((n_samples,))
560
+ .to(device=mu.device, dtype=mu.dtype)
561
+ )
562
+
563
+ # Scale the unit samples by the full Cholesky factor so they have covariance P
+ return mu + X_unit_samples @ torch.linalg.cholesky(P).T
564
+
565
+ def monte_carlo_expected_joint_log_likekihood(
566
+ self,
567
+ z_seq: torch.Tensor,
568
+ x0: torch.Tensor,
569
+ P0: torch.Tensor,
570
+ x0_smoothed: torch.Tensor,
571
+ P0_smoothed: torch.Tensor,
572
+ x_smoothed: torch.Tensor,
573
+ P_smoothed: torch.Tensor,
574
+ Q_seq: torch.Tensor,
575
+ R_seq: torch.Tensor,
576
+ f_args: tuple = (),
577
+ h_args: tuple = (),
578
+ ) -> torch.Tensor:
579
+ """
580
+ This function computes the Monte Carlo approximation of the expected joint log-likelihood
581
+ of the observed data given the posterior distribution of the states.
582
+
583
+ The posterior distribution of the states is derived from the smoothed state estimates,
584
+ which are tractable and exact in the case of linear Gaussian models. This simplifies the
585
+ process compared to a full-fledged Variational Inference (VI) approach, where the posterior
586
+ distribution is approximated using a variational distribution. Here, the exact posterior
587
+ distribution is used, as it is directly obtained from the smoothing algorithm, eliminating
588
+ the need for variational approximations.
589
+
590
+ In the case where the dynamics and observation models are parameterized by neural networks,
591
+ the variational distribution corresponds to the exact posterior distribution over the unknown
592
+ variables (e.g., parameters of the models). This formulation maintains computational efficiency
593
+ while leveraging the expressiveness of neural networks for system modeling.
594
+
595
+ Args:
596
+ z_seq (torch.Tensor): Sequence of observations of shape (seq_len, dim_z).
597
+ x0 (torch.Tensor): Initial state estimate of shape (dim_x,).
598
+ P0 (torch.Tensor): Initial state covariance of shape (dim_x, dim_x).
599
+ x0_smoothed (torch.Tensor): Initial smoothed state estimate of shape (dim_x,).
600
+ P0_smoothed (torch.Tensor): Initial smoothed state covariance of shape (dim_x, dim_x).
601
+ x_smoothed (torch.Tensor): Smoothed state estimates of shape (seq_len, dim_x).
602
+ P_smoothed (torch.Tensor): Smoothed state covariances of shape (seq_len, dim_x, dim_x).
603
+ Q_seq (torch.Tensor): Process noise covariance of shape (seq_len, dim_x, dim_x).
604
+ R_seq (torch.Tensor): Observation noise covariance of shape (seq_len, dim_z, dim_z).
605
+ f_args (tuple): Additional arguments to the state transition function.
606
+ h_args (tuple): Additional arguments to the observation function.
607
+
608
+
609
+ Returns:
610
+ torch.Tensor: The Monte Carlo approximation of the expected joint log-likelihood.
611
+ """
612
+ # Loop to calculate the monte-carlo approximation of the expected joint log-likelihood
613
+ E_log_likelihood = 0.0
614
+ # Here we will use the vmap function to vectorize the computation of the log-likelihood
615
+ vllf = torch.vmap(
616
+ func=gaussain_log_likelihood,
617
+ in_dims=(0, None),
618
+ )
619
+
620
+ # Draw samples from the smoothed initial distribution
621
+ X0_samples = self.draw_samples(
622
+ n_samples=self.mc_samples,
623
+ mu=x0_smoothed,
624
+ P=P0_smoothed,
625
+ )
626
+ innov_x0 = X0_samples - x0
627
+ E_log_likelihood += vllf(innov_x0, P0).mean()
628
+
629
+ # For each measurement calculate the expected log-likelihood
630
+ for i in range(z_seq.size(0)):
631
+ # Calculate the measurement term
632
+ Xt = self.draw_samples(
633
+ n_samples=self.mc_samples,
634
+ mu=x_smoothed[i],
635
+ P=P_smoothed[i],
636
+ )
637
+ innov_z = z_seq[i] - torch.vmap(
638
+ lambda x: self.h(x, *(args[i] for args in h_args)),
639
+ )(Xt)
640
+ E_log_likelihood += vllf(innov_z, R_seq[i]).mean()
641
+
642
+ # Calculate the transition term
643
+ Xp = self.draw_samples(
644
+ n_samples=self.mc_samples,
645
+ mu=x_smoothed[i - 1] if i > 0 else x0_smoothed,
646
+ P=P_smoothed[i - 1] if i > 0 else P0_smoothed,
647
+ )
648
+ innov_x = Xt - torch.vmap(
649
+ lambda x: self.f(x, *(args[i] for args in f_args)),
650
+ )(Xp)
651
+
652
+ E_log_likelihood += vllf(innov_x, Q_seq[i]).mean()
653
+
654
+ return E_log_likelihood
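To make the filtering API above concrete, here is a minimal sketch assuming a 1D constant-velocity model; the module definitions and numeric values are illustrative, not part of the package.

```python
# Sketch: 1D constant-velocity tracking with DiffrentiableKalmanFilter (illustrative values).
import torch
import torch.nn as nn
from diffkalman import DiffrentiableKalmanFilter

DT = 1.0


class ConstantVelocity(nn.Module):
    """State x = [position, velocity]; linear transition written as a module."""
    def forward(self, x, *args):
        F = torch.tensor([[1.0, DT], [0.0, 1.0]], dtype=x.dtype)
        return F @ x


class PositionObservation(nn.Module):
    """Observe the position component only."""
    def forward(self, x, *args):
        return x[:1]


dkf = DiffrentiableKalmanFilter(dim_x=2, dim_z=1, f=ConstantVelocity(), h=PositionObservation())

T = 30
z_seq = torch.linspace(0.0, 29.0, T).unsqueeze(-1) + 0.3 * torch.randn(T, 1)  # noisy positions
Q = (1e-3 * torch.eye(2)).repeat(T, 1, 1)
R = (0.09 * torch.eye(1)).repeat(T, 1, 1)

out = dkf.sequence_smooth(z_seq=z_seq, x0=torch.zeros(2), P0=torch.eye(2), Q=Q, R=R)
print(out["x_post"].shape, out["x_smoothed"].shape)  # both (30, 2)
```
diffkalman/joint_jacobian_transform.py ADDED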
@@ -0,0 +1,54 @@
1
+ """Joint Jacobian Transformation to compute the Jacobian of a function with respect to its first argument and the function output."""
2
+
3
+ from torch.func import jacrev
4
+
5
+ __all__ = ["joint_jacobian_transform"]
6
+
7
+
8
+ def joint_jacobian_transform(f: callable) -> callable:
9
+ """
10
+ Transforms a given function f into a joint function that returns both the
11
+ function's output and its Jacobian matrix with respect to its first argument.
12
+
13
+ The transformed function, when called, returns a tuple where the first element
14
+ is the Jacobian matrix of the function `f` with respect to its first argument
15
+ and the second element is the output of the function `f`.
16
+
17
+ Args:
18
+ f (callable): The function to be transformed. It should be of the form
19
+ `f(x, *args)` where `x` is the variable with respect to
20
+ which the Jacobian is computed and `*args` are additional
21
+ arguments.
22
+
23
+ Returns:
24
+ callable: A new function `F` such that `F(x, *args)` returns a tuple
25
+ `(Df(x, *args), f(x, *args))`, where `Df(x, *args)` is the
26
+ Jacobian matrix of `f` with respect to `x` and `f(x, *args)`
27
+ is the original output of `f`.
28
+
29
+ Example:
30
+ >>> def my_function(x, a, b):
31
+ >>> return a * x + b
32
+ >>>
33
+ >>> F = joint_jacobian_transform(my_function)
34
+ >>> x = torch.tensor([1.0, 2.0])
35
+ >>> a = torch.tensor([2.0])
36
+ >>> b = torch.tensor([3.0])
37
+ >>> jacobian, output = F(x, a, b)
38
+ >>> print(jacobian) # Jacobian of my_function with respect to x
39
+ >>> print(output) # Output of my_function
40
+ """
41
+
42
+ # Define the new joint function
43
+ def joint_func(*args, **kwargs):
44
+ # Compute the output of the function
45
+ output = f(*args, **kwargs)
46
+ return output, output
47
+
48
+ # The new joint jacrev function
49
+ Df = jacrev(
50
+ func=joint_func,
51
+ argnums=0,
52
+ has_aux=True,
53
+ )
54
+ return Df
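A quick sketch of checking `joint_jacobian_transform` against a hand-derived Jacobian; the affine test function is an illustrative assumption.

```python
# Sketch: the Jacobian of an affine map A @ x + b with respect to x is A.
import torch
from diffkalman.joint_jacobian_transform import joint_jacobian_transform


def affine(x, A, b):
    return A @ x + b


F = joint_jacobian_transform(affine)
A = torch.tensor([[1.0, 2.0], [0.0, 1.0]])
b = torch.tensor([0.5, -0.5])
x = torch.tensor([1.0, 3.0])

jac, out = F(x, A, b)
assert torch.allclose(jac, A) and torch.allclose(out, A @ x + b)
```
diffkalman/negative_log_likelihood.py ADDED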
@@ -0,0 +1,133 @@
1
+ """
2
+ Utility functions for calculating negative log likelihood of observations.
3
+
4
+ This module provides functions to compute the negative log likelihood of a sequence of observations given
5
+ their predicted values and associated covariance matrices. The negative log likelihood quantifies the
6
+ likelihood of observing the data under a given model, incorporating uncertainties in the measurements.
7
+
8
+ The negative log likelihood for a sequence of innovations and their covariances is computed using the
9
+ logarithm of the joint distribution of the observations given the parameters. It is expressed mathematically as:
10
+
11
+ $$
12
+ \ln(\mathcal{L}(\Phi \mid \mathcal{D})) = -\frac{nk}{2} \ln(2\pi) - \frac{1}{2} \sum_{i=1}^{n} \left[ \ln |S_{i}| + \Delta y_{i}^T S_{i}^{-1} \Delta y_i \right]
13
+ $$
14
+
15
+ where $ n $ is the sequence length, $ k $ is the dimension of the innovation vector,
16
+ $ \Delta y_i $ is the innovation vector at time step $ i $, and $ S_i $ is the innovation
17
+ covariance matrix at time step $ i $.
18
+
19
+ Functions:
20
+ - `gaussain_log_likelihood`: Calculates the Gaussian log likelihood at a single time step.
21
+ - `log_likelihood`: Computes total log likelihood for a sequence of time steps.
22
+ - `negative_log_likelihood`: Computes total negative log likelihood for a sequence of time steps.
23
+
24
+ Note:
25
+ - The constant factor $ \frac{nk}{2} \ln(2\pi) $ is included per time step; since it is constant, it does not affect the optimization.
26
+ """
27
+
28
+ import torch
29
+
30
+ __all__ = ["gaussain_log_likelihood", "log_likelihood"]
31
+
32
+
33
+ def gaussain_log_likelihood(
34
+ innovation: torch.Tensor,
35
+ innovation_covariance: torch.Tensor,
36
+ ) -> torch.Tensor:
37
+ """
38
+ Calculates the Gaussian log likelihood at a single time step.
39
+
40
+ Args:
41
+ innovation (torch.Tensor): The innovation vector, representing the difference
42
+ between the observed and predicted measurements.
43
+ Shape: (dim_z,)
44
+ innovation_covariance (torch.Tensor): The innovation covariance matrix,
45
+ representing the uncertainty in the innovation.
46
+ Shape: (dim_z, dim_z)
47
+
48
+ Returns:
49
+ torch.Tensor: The Gaussian log likelihood of the innovation under its covariance.
50
+
51
+ Note:
52
+ - The constant term -(k/2) * ln(2*pi) is included; being constant, it does not affect the optimization.
53
+
54
+ Example:
55
+ >>> innovation = torch.tensor([1.0, 2.0])
56
+ >>> innovation_covariance = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
57
+ >>> loss = likelihood_at_time(innovation, innovation_covariance)
58
+ >>> print(loss)
59
+ """
60
+ # Calculate the log determinant of the innovation covariance matrix
61
+ log_det = torch.linalg.slogdet(innovation_covariance)[1]
62
+
63
+ k = innovation.shape[0]
64
+ C = (-k / 2) * torch.log(
65
+ 2 * torch.tensor(torch.pi, dtype=innovation.dtype, device=innovation.device)
66
+ )
67
+ return C - 0.5 * (
68
+ log_det + innovation @ torch.linalg.inv(innovation_covariance) @ innovation
69
+ )
70
+
71
+
72
+ def log_likelihood(
73
+ innovation: torch.Tensor,
74
+ innovation_covariance: torch.Tensor,
75
+ ) -> torch.Tensor:
76
+ """
77
+ Calculates the log likelihood estimation for a sequence of time steps.
78
+
79
+ Args:
80
+ innovation (torch.Tensor): The innovation vector for each time step,
81
+ representing the differences between observed and predicted
82
+ measurements over a sequence. Shape: (seq_len, dim_z)
83
+ innovation_covariance (torch.Tensor): The innovation covariance matrix for each time step,
84
+ representing the uncertainty in the innovations
85
+ over a sequence. Shape: (seq_len, dim_z, dim_z)
86
+
87
+ Returns:
88
+ torch.Tensor: The total log likelihood value for the sequence of innovations and
89
+ their covariances.
90
+
91
+ Example:
92
+ >>> innovation = torch.tensor([[1.0, 2.0], [0.5, 1.5]])
93
+ >>> innovation_covariance = torch.tensor([
94
+ >>> [[1.0, 0.0], [0.0, 1.0]],
95
+ >>> [[0.5, 0.0], [0.0, 0.5]]
96
+ >>> ])
97
+ >>> total_log_likelihood = log_likelihood(innovation, innovation_covariance)
98
+ >>> print(total_log_likelihood)
99
+ """
100
+ return torch.vmap(gaussain_log_likelihood, in_dims=0)(
101
+ innovation, innovation_covariance
102
+ ).sum()
103
+
104
+
105
+ def negative_log_likelihood(
106
+ innovation: torch.Tensor,
107
+ innovation_covariance: torch.Tensor,
108
+ ) -> torch.Tensor:
109
+ """
110
+ Calculates the negative log likelihood estimation for a sequence of time steps.
111
+
112
+ Args:
113
+ innovation (torch.Tensor): The innovation vector for each time step,
114
+ representing the differences between observed and predicted
115
+ measurements over a sequence. Shape: (seq_len, dim_z)
116
+ innovation_covariance (torch.Tensor): The innovation covariance matrix for each time step,
117
+ representing the uncertainty in the innovations
118
+ over a sequence. Shape: (seq_len, dim_z, dim_z)
119
+
120
+ Returns:
121
+ torch.Tensor: The total negative log likelihood value for the sequence of innovations and
122
+ their covariances.
123
+
124
+ Example:
125
+ >>> innovation = torch.tensor([[1.0, 2.0], [0.5, 1.5]])
126
+ >>> innovation_covariance = torch.tensor([
127
+ >>> [[1.0, 0.0], [0.0, 1.0]],
128
+ >>> [[0.5, 0.0], [0.0, 0.5]]
129
+ >>> ])
130
+ >>> total_negative_log_likelihood = negative_log_likelihood(innovation, innovation_covariance)
131
+ >>> print(total_negative_log_likelihood)
132
+ """
133
+ return -log_likelihood(innovation, innovation_covariance)
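As a sanity check, `gaussain_log_likelihood` should agree with `torch.distributions.MultivariateNormal.log_prob`; a minimal sketch with illustrative values:

```python
# Sketch: cross-checking the Gaussian log likelihood against torch.distributions.
import torch
from diffkalman.negative_log_likelihood import gaussain_log_likelihood, log_likelihood

innovation = torch.tensor([1.0, 2.0])
S = torch.tensor([[2.0, 0.3], [0.3, 1.0]])

ours = gaussain_log_likelihood(innovation, S)
reference = torch.distributions.MultivariateNormal(torch.zeros(2), covariance_matrix=S).log_prob(innovation)
assert torch.allclose(ours, reference, atol=1e-5)

# The sequence version sums the per-step terms.
seq = log_likelihood(innovation.repeat(3, 1), S.repeat(3, 1, 1))
assert torch.allclose(seq, 3 * reference, atol=1e-4)
```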
diffkalman/py.typed ADDED
File without changes
diffkalman/utils.py ADDED
@@ -0,0 +1,90 @@
1
+ """Implementation of utility modules for the differentiable Kalman filter."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ __all__ = [
7
+ "SymmetricPositiveDefiniteMatrix",
8
+ "DiagonalSymmetricPositiveDefiniteMatrix",
9
+ ]
10
+
11
+
12
+ class SymmetricPositiveDefiniteMatrix(nn.Module):
13
+ """
14
+ Module for ensuring a symmetric positive definite matrix using a parameterized approach.
15
+
16
+ This module constructs a symmetric positive definite matrix from an initial matrix Q_0.
17
+ It ensures that the resultant matrix is symmetric and has positive eigenvalues.
18
+
19
+ Attributes:
20
+ diag_mask (torch.Tensor): Boolean mask selecting the diagonal entries.
+ W (nn.Parameter): Lower-triangular Cholesky factor used to reconstruct the matrix.
22
+
23
+ Methods:
24
+ forward():
25
+ Performs the forward pass of the module.
26
+ Returns a symmetric positive definite matrix derived from the input parameter.
27
+ """
28
+
29
+ def __init__(self, M: torch.Tensor, trainable: bool = True):
30
+ """
31
+ Initializes the SymmetricPositiveDefiniteMatrix module.
32
+
33
+ Args:
34
+ M (torch.Tensor): Initial matrix for the parameter.
35
+ trainable (bool): Flag to indicate whether the parameter is trainable. Default: True.
36
+ """
37
+ super().__init__()
38
+
39
+ # Cholesky decomposition of the initial matrix
40
+ L = torch.linalg.cholesky(M, upper=False)
41
+
42
+ # Initialize the parameter with the lower triangular part of the Cholesky decomposition
43
+ self.W = nn.Parameter(L, requires_grad=trainable)
44
+
45
+ # Mask for the diagonal entries
46
+ self.diag_mask = torch.eye(M.shape[0], dtype=torch.bool)
47
+
48
+ def forward(self) -> torch.Tensor:
49
+ """
50
+ Forward pass of the SymmetricPositiveDefiniteMatrix module.
51
+
52
+ Returns:
53
+ torch.Tensor: Symmetric positive definite matrix derived from the input parameter.
54
+ """
55
+ # Make the diagonal entries positive
56
+ L = torch.tril(self.W)
57
+ L[self.diag_mask] = torch.abs(L[self.diag_mask])
58
+
59
+ return L @ L.T
60
+
61
+
62
+ class DiagonalSymmetricPositiveDefiniteMatrix(nn.Module):
63
+ """
64
+ A PyTorch module representing a diagonal symmetric positive definite matrix.
65
+
66
+ This module takes a diagonal matrix `M` as input and constructs a diagonal symmetric positive definite matrix `W`.
67
+ The diagonal entries of `W` are initialized with the diagonal entries of `M` and can be trained if `trainable` is set to True.
68
+
69
+ Args:
70
+ M (torch.Tensor): The diagonal matrix used to initialize the diagonal entries of `W`.
71
+ trainable (bool, optional): Whether the diagonal entries of `W` should be trainable. Defaults to True.
72
+
73
+ Returns:
74
+ torch.Tensor: The diagonal symmetric positive definite matrix `W`.
75
+ """
76
+
77
+ def __init__(self, M: torch.Tensor, trainable: bool = True):
78
+ super().__init__()
79
+
80
+ self.W = nn.Parameter(torch.diag(M.diagonal()), requires_grad=trainable)
81
+
82
+ def forward(self) -> torch.Tensor:
83
+ """
84
+ Forward pass of the module.
85
+
86
+ Returns:
87
+ torch.Tensor: The diagonal symmetric positive definite matrix `W`.
88
+ """
89
+ # Make the diagonal entries positive
90
+ return torch.abs(self.W)
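A small usage sketch of the SPD modules (the initial values are illustrative):

```python
# Sketch: the SPD modules as learnable noise covariances.
import torch
from diffkalman.utils import SymmetricPositiveDefiniteMatrix, DiagonalSymmetricPositiveDefiniteMatrix

Q = SymmetricPositiveDefiniteMatrix(0.1 * torch.eye(3))          # full covariance, trainable
R = DiagonalSymmetricPositiveDefiniteMatrix(0.5 * torch.eye(2))  # diagonal covariance, trainable

Q_mat = Q()  # calling the module returns the current matrix
assert torch.allclose(Q_mat, Q_mat.T)                # symmetric
assert (torch.linalg.eigvalsh(Q_mat) > 0).all()      # positive definite
print(list(Q.parameters())[0].shape)                 # torch.Size([3, 3])
```
diffkalman-0.1.0.dist-info/METADATA ADDED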
@@ -0,0 +1,183 @@
1
+ Metadata-Version: 2.4
2
+ Name: diffkalman
3
+ Version: 0.1.0
4
+ Summary: A differentiable Kalman filter library for auto-tuning Kalman filters.
5
+ Author-email: hades <nischalbhattaraipi@gmail.com>
6
+ License-File: LICENSE
7
+ Requires-Python: >=3.11
8
+ Requires-Dist: numpy>=2.2.1
9
+ Requires-Dist: tqdm>=4.67.1
10
+ Provides-Extra: cpu
11
+ Requires-Dist: torch>=2.5.1; extra == 'cpu'
12
+ Provides-Extra: cu124
13
+ Requires-Dist: torch>=2.5.1; extra == 'cu124'
14
+ Description-Content-Type: text/markdown
15
+
16
+ # Differentiable Kalman Filter
17
+
18
+ A PyTorch-based implementation of a differentiable Kalman Filter designed for both linear and non-linear dynamical systems with Gaussian noise. This module seamlessly integrates with neural networks, enabling learnable dynamics, observation, and noise models optimized through Stochastic Variational Inference (SVI).
19
+
20
+ ## Features
21
+
22
+ - **Fully Differentiable**: End-to-end differentiable implementation compatible with PyTorch's autograd
23
+ - **Flexible Models**: Support for both linear and non-linear state transition and observation models
24
+ - **Neural Network Integration**: Models can be parameterized using neural networks
25
+ - **Automatic Jacobian Computation**: Utilizes PyTorch's autograd for derivative calculations
26
+ - **Monte Carlo Sampling**: Supports evaluation of expected joint log-likelihood to perform Expectation-Maximization (EM) learning
27
+ - **Rauch-Tung-Striebel Smoothing**: Implements forward-backward smoothing for improved state estimation using RTS algorithm
28
+
29
+ ## Installation
30
+
31
+ ```bash
32
+ pip install torch # Required dependency
33
+ # Add your package installation command here
34
+ ```
35
+
36
+ ## Quick Start
37
+
38
+ Here's a simple example of using the Differentiable Kalman Filter:
39
+
40
+ ```python
41
+ import torch
42
+ from diffkalman import DiffrentiableKalmanFilter
43
+ from diffkalman.utils import SymmetricPositiveDefiniteMatrix
44
+ from diffkalman.em_loop import em_updates
45
+
46
+ # Define custom state transition and observation functions
47
+ class StateTransition(torch.nn.Module):
48
+ def forward(self, x, *args):
49
+ # Your state transition logic here
50
+ return x
51
+
52
+ class ObservationModel(torch.nn.Module):
53
+ def forward(self, x, *args):
54
+ # Your observation logic here
55
+ return x
56
+
57
+ # Initialize the filter
58
+ f = StateTransition()
59
+ h = ObservationModel()
60
+ Q = SymmetricPositiveDefiniteMatrix(torch.eye(4), trainable=True)   # initial process-noise guess
+ R = SymmetricPositiveDefiniteMatrix(torch.eye(2), trainable=True)   # initial measurement-noise guess
62
+ kalman_filter = DiffrentiableKalmanFilter(
63
+ dim_x=4, # State dimension
64
+ dim_z=2, # Observation dimension
65
+ f=f, # State transition function
66
+ h=h # Observation function
67
+ )
68
+
69
+ # Run the filter
70
+ results = kalman_filter.sequence_filter(
71
+ z_seq=observations, # Shape: (T, dim_z)
72
+ x0=initial_state, # Shape: (dim_x,)
73
+ P0=initial_covariance, # Shape: (dim_x, dim_x)
74
+ Q=Q().repeat(len(observations), 1, 1), # Shape: (T, dim_x, dim_x)
75
+ R=R().repeat(len(observations), 1, 1) # Shape: (T, dim_z, dim_z)
76
+ )
77
+ ```
78
+
79
+ ## Detailed Usage
80
+
81
+ ### State Estimation
82
+
83
+ The module provides three main estimation methods:
84
+
85
+ 1. **Filtering**: Forward pass only
86
+ ```python
87
+ filtered_results = kalman_filter.sequence_filter(
88
+ z_seq=observations,
89
+ x0=initial_state,
90
+ P0=initial_covariance,
91
+ Q=process_noise,
92
+ R=observation_noise
93
+ )
94
+ ```
95
+
96
+ 2. **Smoothing**: Forward-backward pass
97
+ ```python
98
+ smoothed_results = kalman_filter.sequence_smooth(
99
+ z_seq=observations,
100
+ x0=initial_state,
101
+ P0=initial_covariance,
102
+ Q=process_noise,
103
+ R=observation_noise
104
+ )
105
+ ```
106
+
107
+ 3. **Single-step Prediction**: For real-time applications
108
+ ```python
109
+ step_result = kalman_filter.predict_update(
110
+ z=current_observation,
111
+ x=current_state,
112
+ P=current_covariance,
113
+ Q=process_noise,
114
+ R=observation_noise
115
+ )
116
+ ```
117
+
118
+ ### Parameter Learning
119
+
120
+ The module supports learning model parameters through using backpropagation using the negative expected joint log-likelihood of the
121
+ data as the loss function.
122
+
123
+ ```python
124
+ # Define optimizer
125
+ optimizer = torch.optim.Adam(params=[
126
+ {'params': kalman_filter.f.parameters()},
127
+ {'params': kalman_filter.h.parameters()},
128
+ {'params': Q.parameters()},
129
+ {'params': R.parameters()}
130
+ ]
131
+
132
+ NUM_EPOCHS = 10
133
+ NUM_CYCLES = 10
134
+
135
+ # Run the EM loop
136
+ marginal_likelihoods = em_updates(
137
+ kalman_filter=kalman_filter,
138
+ z_seq=observations,
139
+ x0=initial_state,
140
+ P0=initial_covariance,
141
+ Q=Q,
142
+ R=R,
143
+ optimizer=optimizer,
144
+ num_cycles=NUM_CYCLES,
145
+ num_epochs=NUM_EPOCHS
146
+ )
147
+
148
+ ```
149
+
150
+ ## API Reference
151
+
152
+ ### DiffrentiableKalmanFilter
153
+
154
+ Main class implementing the Kalman Filter algorithm.
155
+
156
+ #### Constructor Parameters:
157
+ - `dim_x` (int): State space dimension
158
+ - `dim_z` (int): Observation space dimension
159
+ - `f` (nn.Module): State transition function
160
+ - `h` (nn.Module): Observation function
161
+ - `mc_samples` (int, optional): Number of Monte Carlo samples for log-likelihood estimation
162
+
163
+ #### Key Methods:
164
+ - `predict`: State prediction step
165
+ - `update`: Measurement update step
166
+ - `predict_update`: Combined prediction and update
167
+ - `sequence_filter`: Full sequence filtering
168
+ - `sequence_smooth`: Full sequence smoothing
169
+ - `marginal_log_likelihood`: Compute marginal log-likelihood
170
+ - `monte_carlo_expected_joint_log_likekihood`: Estimate expected joint log-likelihood
171
+
172
+ ## Requirements
173
+
174
+ - PyTorch >= 1.9.0
175
+ - Python >= 3.7
176
+
177
+ ## Contributing
178
+
179
+ Contributions are welcome! Please feel free to submit a Pull Request.
180
+
181
+ ## License
182
+
183
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
diffkalman-0.1.0.dist-info/RECORD ADDED
@@ -0,0 +1,11 @@
1
+ diffkalman/__init__.py,sha256=owxzQKRKYYZU_0bQ9QEIanoOx-Fv138JSFl8VdvQfI8,180
2
+ diffkalman/em_loop.py,sha256=j-OJEGuotI-3z8l35joevBIMWZeNOvaKNUr1QLqzRbM,4682
3
+ diffkalman/filter.py,sha256=xGk-JqZlf7vhHFk3wV3GQQN9fZrf5XbGezECK6iCo2U,27412
4
+ diffkalman/joint_jacobian_transform.py,sha256=HRmWkcCwCQxrENZjJeE3RY7901bPUP9qnbqh7NRKiYw,2000
5
+ diffkalman/negative_log_likelihood.py,sha256=B1ICGZ0urrS1THxvg020rjJe3gua_AabvwB64L7Rhdo,5730
6
+ diffkalman/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
+ diffkalman/utils.py,sha256=ovlLCdPJksAA5mfGtDaNCqqA7hLBhH5-IMRdGvMQ22o,3230
8
+ diffkalman-0.1.0.dist-info/METADATA,sha256=s50Zc-lmWCINrgyBHP-SMgW-WvbtZYGBV9ALgCuE0UY,5423
9
+ diffkalman-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
10
+ diffkalman-0.1.0.dist-info/licenses/LICENSE,sha256=IHkfLg6o-U7zWIh6oiJUp0j0abQ59foTPoB1OLEPleQ,1063
11
+ diffkalman-0.1.0.dist-info/RECORD,,
diffkalman-0.1.0.dist-info/WHEEL ADDED
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.27.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
diffkalman-0.1.0.dist-info/licenses/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 HadesX
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.