diffkalman-0.1.0.tar.gz
- diffkalman-0.1.0/.gitignore +10 -0
- diffkalman-0.1.0/.python-version +1 -0
- diffkalman-0.1.0/LICENSE +21 -0
- diffkalman-0.1.0/PKG-INFO +183 -0
- diffkalman-0.1.0/README.md +168 -0
- diffkalman-0.1.0/pyproject.toml +57 -0
- diffkalman-0.1.0/src/diffkalman/__init__.py +6 -0
- diffkalman-0.1.0/src/diffkalman/em_loop.py +115 -0
- diffkalman-0.1.0/src/diffkalman/filter.py +654 -0
- diffkalman-0.1.0/src/diffkalman/joint_jacobian_transform.py +54 -0
- diffkalman-0.1.0/src/diffkalman/negative_log_likelihood.py +133 -0
- diffkalman-0.1.0/src/diffkalman/py.typed +0 -0
- diffkalman-0.1.0/src/diffkalman/utils.py +90 -0
- diffkalman-0.1.0/uv.lock +1265 -0
diffkalman-0.1.0/.python-version
ADDED
@@ -0,0 +1 @@
3.11
diffkalman-0.1.0/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 HadesX

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
diffkalman-0.1.0/PKG-INFO
ADDED
@@ -0,0 +1,183 @@
Metadata-Version: 2.4
Name: diffkalman
Version: 0.1.0
Summary: A differentiable Kalman filter library for auto-tuning Kalman filters.
Author-email: hades <nischalbhattaraipi@gmail.com>
License-File: LICENSE
Requires-Python: >=3.11
Requires-Dist: numpy>=2.2.1
Requires-Dist: tqdm>=4.67.1
Provides-Extra: cpu
Requires-Dist: torch>=2.5.1; extra == 'cpu'
Provides-Extra: cu124
Requires-Dist: torch>=2.5.1; extra == 'cu124'
Description-Content-Type: text/markdown

# Differentiable Kalman Filter

A PyTorch-based implementation of a differentiable Kalman Filter designed for both linear and non-linear dynamical systems with Gaussian noise. The module integrates with neural networks, enabling learnable dynamics, observation, and noise models optimized through Stochastic Variational Inference (SVI).

## Features

- **Fully Differentiable**: End-to-end differentiable implementation compatible with PyTorch's autograd
- **Flexible Models**: Supports both linear and non-linear state transition and observation models
- **Neural Network Integration**: Models can be parameterized with neural networks
- **Automatic Jacobian Computation**: Uses PyTorch's autograd for derivative calculations
- **Monte Carlo Sampling**: Supports evaluation of the expected joint log-likelihood for Expectation-Maximization (EM) learning
- **Rauch-Tung-Striebel Smoothing**: Implements forward-backward smoothing for improved state estimation using the RTS algorithm

## Installation

```bash
pip install torch       # required dependency
pip install diffkalman
```

## Quick Start

Here's a simple example of using the Differentiable Kalman Filter:

```python
import torch
from diffkalman import DifferentiableKalmanFilter
from diffkalman.utils import SymmetricPositiveDefiniteMatrix
from diffkalman.em_loop import em_updates

# Define custom state transition and observation functions
class StateTransition(torch.nn.Module):
    def forward(self, x, *args):
        # Your state transition logic here
        return x

class ObservationModel(torch.nn.Module):
    def forward(self, x, *args):
        # Your observation logic here
        return x

# Initialize the filter
f = StateTransition()
h = ObservationModel()
Q = SymmetricPositiveDefiniteMatrix(dim=4, trainable=True)
R = SymmetricPositiveDefiniteMatrix(dim=2, trainable=True)
kalman_filter = DifferentiableKalmanFilter(
    dim_x=4,  # State dimension
    dim_z=2,  # Observation dimension
    f=f,      # State transition function
    h=h       # Observation function
)

# Run the filter
results = kalman_filter.sequence_filter(
    z_seq=observations,                     # Shape: (T, dim_z)
    x0=initial_state,                       # Shape: (dim_x,)
    P0=initial_covariance,                  # Shape: (dim_x, dim_x)
    Q=Q().repeat(len(observations), 1, 1),  # Shape: (T, dim_x, dim_x)
    R=R().repeat(len(observations), 1, 1)   # Shape: (T, dim_z, dim_z)
)
```
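
The placeholder modules above simply pass `x` through. As a point of reference, here is a minimal sketch of what concrete `f` and `h` modules might look like for a 2-D constant-velocity tracking problem; the class names and the `dt` argument are illustrative and not part of the package:

```python
import torch

class ConstantVelocityTransition(torch.nn.Module):
    """Illustrative f: constant-velocity motion for a state [px, py, vx, vy]."""
    def __init__(self, dt: float = 1.0):
        super().__init__()
        # Fixed linear transition matrix; a learned module could replace it
        self.register_buffer("F", torch.tensor([
            [1.0, 0.0, dt,  0.0],
            [0.0, 1.0, 0.0, dt ],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ]))

    def forward(self, x, *args):
        # x has shape (dim_x,); returns the propagated state
        return self.F @ x

class PositionObservation(torch.nn.Module):
    """Illustrative h: observe only the position components [px, py]."""
    def __init__(self):
        super().__init__()
        self.register_buffer("H", torch.tensor([
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
        ]))

    def forward(self, x, *args):
        return self.H @ x
```

Because both models are ordinary `nn.Module`s, the same pattern extends to learned dynamics: swap the fixed matrices for small MLPs and their parameters become trainable through the filter.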

## Detailed Usage

### State Estimation

The module provides three main estimation methods:

1. **Filtering**: Forward pass only
```python
filtered_results = kalman_filter.sequence_filter(
    z_seq=observations,
    x0=initial_state,
    P0=initial_covariance,
    Q=process_noise,
    R=observation_noise
)
```

2. **Smoothing**: Forward-backward pass
```python
smoothed_results = kalman_filter.sequence_smooth(
    z_seq=observations,
    x0=initial_state,
    P0=initial_covariance,
    Q=process_noise,
    R=observation_noise
)
```

3. **Single-step Prediction**: For real-time applications
```python
step_result = kalman_filter.predict_update(
    z=current_observation,
    x=current_state,
    P=current_covariance,
    Q=process_noise,
    R=observation_noise
)
```

### Parameter Learning

The module supports learning the model parameters through backpropagation, using the negative expected joint log-likelihood of the data as the loss function.

```python
# Define the optimizer over all learnable components
optimizer = torch.optim.Adam(params=[
    {'params': kalman_filter.f.parameters()},
    {'params': kalman_filter.h.parameters()},
    {'params': Q.parameters()},
    {'params': R.parameters()}
])
# em_updates also expects a learning-rate scheduler (any torch scheduler works;
# it is stepped once per M-step cycle)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

NUM_EPOCHS = 10
NUM_CYCLES = 10

# Run the EM loop
marginal_likelihoods = em_updates(
    dkf=kalman_filter,
    z_seq=observations,
    x0=initial_state,
    P0=initial_covariance,
    Q=Q,
    R=R,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    num_cycles=NUM_CYCLES,
    num_epochs=NUM_EPOCHS
)
```
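
If you want finer control than `em_updates`, the loop body can be written out by hand. The sketch below mirrors what `em_updates` does internally for one E-step and one M-step; the argument names follow the `em_loop.py` source included in this release, and `observations`, `initial_state`, and `initial_covariance` are the same placeholders as above.

```python
# E-step: smoothed posterior under the current parameters (no gradients needed)
with torch.no_grad():
    posterior = kalman_filter.sequence_smooth(
        z_seq=observations,
        x0=initial_state,
        P0=initial_covariance,
        Q=Q().repeat(len(observations), 1, 1),
        R=R().repeat(len(observations), 1, 1),
    )

# M-step: maximize the expected joint log-likelihood under that posterior
optimizer.zero_grad()
ell = kalman_filter.monte_carlo_expected_joint_log_likekihood(
    z_seq=observations,
    x0=initial_state,
    P0=initial_covariance,
    x0_smoothed=posterior["x0_smoothed"],
    P0_smoothed=posterior["P0_smoothed"],
    x_smoothed=posterior["x_smoothed"],
    P_smoothed=posterior["P_smoothed"],
    Q_seq=Q().repeat(len(observations), 1, 1),
    R_seq=R().repeat(len(observations), 1, 1),
)
(-ell).backward()   # minimize the negative expected joint log-likelihood
optimizer.step()
```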

## API Reference

### DifferentiableKalmanFilter

Main class implementing the Kalman filter algorithm.

#### Constructor Parameters:
- `dim_x` (int): State space dimension
- `dim_z` (int): Observation space dimension
- `f` (nn.Module): State transition function
- `h` (nn.Module): Observation function
- `mc_samples` (int, optional): Number of Monte Carlo samples for log-likelihood estimation

#### Key Methods:
- `predict`: State prediction step
- `update`: Measurement update step
- `predict_update`: Combined prediction and update
- `sequence_filter`: Full-sequence filtering
- `sequence_smooth`: Full-sequence smoothing
- `marginal_log_likelihood`: Compute the marginal log-likelihood
- `monte_carlo_expected_joint_log_likekihood`: Estimate the expected joint log-likelihood
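
As a quick illustration of the likelihood methods, the snippet below scores a measurement sequence under the current parameters; argument names follow the usage in `em_loop.py`, so treat it as a sketch rather than canonical API documentation.

```python
# Evaluate the marginal log-likelihood of a sequence under the current
# f, h, Q, and R (no gradient tracking needed for pure evaluation).
with torch.no_grad():
    mll = kalman_filter.marginal_log_likelihood(
        z_seq=observations,
        x0=initial_state,
        P0=initial_covariance,
        Q=Q().repeat(len(observations), 1, 1),
        R=R().repeat(len(observations), 1, 1),
    )
print(f"Marginal log-likelihood: {mll.item():.3f}")
```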

## Requirements

- PyTorch >= 2.5.1
- Python >= 3.11

## Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
diffkalman-0.1.0/README.md
ADDED
@@ -0,0 +1,168 @@
# Differentiable Kalman Filter

A PyTorch-based implementation of a differentiable Kalman Filter designed for both linear and non-linear dynamical systems with Gaussian noise. The module integrates with neural networks, enabling learnable dynamics, observation, and noise models optimized through Stochastic Variational Inference (SVI).

## Features

- **Fully Differentiable**: End-to-end differentiable implementation compatible with PyTorch's autograd
- **Flexible Models**: Supports both linear and non-linear state transition and observation models
- **Neural Network Integration**: Models can be parameterized with neural networks
- **Automatic Jacobian Computation**: Uses PyTorch's autograd for derivative calculations
- **Monte Carlo Sampling**: Supports evaluation of the expected joint log-likelihood for Expectation-Maximization (EM) learning
- **Rauch-Tung-Striebel Smoothing**: Implements forward-backward smoothing for improved state estimation using the RTS algorithm

## Installation

```bash
pip install torch       # required dependency
pip install diffkalman
```

## Quick Start

Here's a simple example of using the Differentiable Kalman Filter:

```python
import torch
from diffkalman import DifferentiableKalmanFilter
from diffkalman.utils import SymmetricPositiveDefiniteMatrix
from diffkalman.em_loop import em_updates

# Define custom state transition and observation functions
class StateTransition(torch.nn.Module):
    def forward(self, x, *args):
        # Your state transition logic here
        return x

class ObservationModel(torch.nn.Module):
    def forward(self, x, *args):
        # Your observation logic here
        return x

# Initialize the filter
f = StateTransition()
h = ObservationModel()
Q = SymmetricPositiveDefiniteMatrix(dim=4, trainable=True)
R = SymmetricPositiveDefiniteMatrix(dim=2, trainable=True)
kalman_filter = DifferentiableKalmanFilter(
    dim_x=4,  # State dimension
    dim_z=2,  # Observation dimension
    f=f,      # State transition function
    h=h       # Observation function
)

# Run the filter
results = kalman_filter.sequence_filter(
    z_seq=observations,                     # Shape: (T, dim_z)
    x0=initial_state,                       # Shape: (dim_x,)
    P0=initial_covariance,                  # Shape: (dim_x, dim_x)
    Q=Q().repeat(len(observations), 1, 1),  # Shape: (T, dim_x, dim_x)
    R=R().repeat(len(observations), 1, 1)   # Shape: (T, dim_z, dim_z)
)
```

## Detailed Usage

### State Estimation

The module provides three main estimation methods:

1. **Filtering**: Forward pass only
```python
filtered_results = kalman_filter.sequence_filter(
    z_seq=observations,
    x0=initial_state,
    P0=initial_covariance,
    Q=process_noise,
    R=observation_noise
)
```

2. **Smoothing**: Forward-backward pass
```python
smoothed_results = kalman_filter.sequence_smooth(
    z_seq=observations,
    x0=initial_state,
    P0=initial_covariance,
    Q=process_noise,
    R=observation_noise
)
```

3. **Single-step Prediction**: For real-time applications
```python
step_result = kalman_filter.predict_update(
    z=current_observation,
    x=current_state,
    P=current_covariance,
    Q=process_noise,
    R=observation_noise
)
```

### Parameter Learning

The module supports learning the model parameters through backpropagation, using the negative expected joint log-likelihood of the data as the loss function.

```python
# Define the optimizer over all learnable components
optimizer = torch.optim.Adam(params=[
    {'params': kalman_filter.f.parameters()},
    {'params': kalman_filter.h.parameters()},
    {'params': Q.parameters()},
    {'params': R.parameters()}
])
# em_updates also expects a learning-rate scheduler (any torch scheduler works;
# it is stepped once per M-step cycle)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

NUM_EPOCHS = 10
NUM_CYCLES = 10

# Run the EM loop
marginal_likelihoods = em_updates(
    dkf=kalman_filter,
    z_seq=observations,
    x0=initial_state,
    P0=initial_covariance,
    Q=Q,
    R=R,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    num_cycles=NUM_CYCLES,
    num_epochs=NUM_EPOCHS
)
```

## API Reference

### DifferentiableKalmanFilter

Main class implementing the Kalman filter algorithm.

#### Constructor Parameters:
- `dim_x` (int): State space dimension
- `dim_z` (int): Observation space dimension
- `f` (nn.Module): State transition function
- `h` (nn.Module): Observation function
- `mc_samples` (int, optional): Number of Monte Carlo samples for log-likelihood estimation

#### Key Methods:
- `predict`: State prediction step
- `update`: Measurement update step
- `predict_update`: Combined prediction and update
- `sequence_filter`: Full-sequence filtering
- `sequence_smooth`: Full-sequence smoothing
- `marginal_log_likelihood`: Compute the marginal log-likelihood
- `monte_carlo_expected_joint_log_likekihood`: Estimate the expected joint log-likelihood

## Requirements

- PyTorch >= 2.5.1
- Python >= 3.11

## Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
diffkalman-0.1.0/pyproject.toml
ADDED
@@ -0,0 +1,57 @@
[project]
name = "diffkalman"
version = "0.1.0"
description = "A differentiable Kalman filter library for auto-tuning Kalman filters."
readme = "README.md"
authors = [
    { name = "hades", email = "nischalbhattaraipi@gmail.com" }
]
requires-python = ">=3.11"
dependencies = [
    "numpy>=2.2.1",
    "tqdm>=4.67.1",
]

[project.optional-dependencies]
cpu = [
    "torch>=2.5.1",
]
cu124 = [
    "torch>=2.5.1",
]

[tool.uv]
conflicts = [
    [
        { extra = "cpu" },
        { extra = "cu124" },
    ],
]

[tool.uv.sources]
torch = [
    { index = "pytorch-cpu", extra = "cpu" },
    { index = "pytorch-cu124", extra = "cu124" },
]

[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true

[[tool.uv.index]]
name = "pytorch-cu124"
url = "https://download.pytorch.org/whl/cu124"
explicit = true

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[dependency-groups]
dev = [
    "ipykernel>=6.29.5",
    "matplotlib>=3.10.0",
    "pandas>=2.2.3",
]
diffkalman-0.1.0/src/diffkalman/em_loop.py
ADDED
@@ -0,0 +1,115 @@
"""The EM loop module that implements the EM algorithm for the Differentiable Kalman Filter."""

from .filter import DiffrentiableKalmanFilter
from .utils import SymmetricPositiveDefiniteMatrix
import torch
import torch.nn as nn


__all__ = [
    "em_updates",
]


def em_updates(
    dkf: DiffrentiableKalmanFilter,
    z_seq: torch.Tensor,
    x0: torch.Tensor,
    P0: torch.Tensor,
    Q: SymmetricPositiveDefiniteMatrix,
    R: SymmetricPositiveDefiniteMatrix,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    num_cycles: int = 20,
    num_epochs: int = 100,
    h_args: tuple = (),
    f_args: tuple = (),
) -> dict:
    """A sample implementation of the EM algorithm for the Differentiable Kalman Filter.

    Args:
        dkf (DiffrentiableKalmanFilter): The differentiable Kalman filter instance.
        z_seq (torch.Tensor): The noisy measurement sequence. Dimension: (seq_len, obs_dim)
        x0 (torch.Tensor): The initial state vector. Dimension: (state_dim,)
        P0 (torch.Tensor): The initial covariance matrix of the state vector. Dimension: (state_dim, state_dim)
        Q (SymmetricPositiveDefiniteMatrix): The process noise covariance matrix module.
        R (SymmetricPositiveDefiniteMatrix): The measurement noise covariance matrix module.
        optimizer (torch.optim.Optimizer): The optimizer.
        lr_scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
        num_cycles (int): The number of cycles. Default: 20
        num_epochs (int): The number of epochs. Default: 100
        h_args (tuple): Additional arguments for the measurement model. Default: ()
        f_args (tuple): Additional arguments for the transition function. Default: ()

    Returns:
        dict: Log-likelihoods of the model indexed by epoch and cycle.

    Note:
        - Use the SymmetricPositiveDefiniteMatrix module to ensure that the process noise covariance matrix Q and the measurement noise covariance matrix R are symmetric and positive definite.
        - The optimizer and the learning rate scheduler should be initialized before calling this function.
        - The measurement model and the transition function should be differentiable torch modules.
    """
    likelihoods = torch.zeros(num_epochs, num_cycles)

    # Perform EM updates for num_epochs
    for e in range(num_epochs):

        ## The E-step
        ## Without gradient tracking, get the posterior state distribution w.r.t. the current parameter values
        with torch.no_grad():
            posterior = dkf.sequence_smooth(
                z_seq=z_seq,
                x0=x0,
                P0=P0,
                Q=Q().repeat(len(z_seq), 1, 1),
                R=R().repeat(len(z_seq), 1, 1),
                f_args=f_args,
                h_args=h_args,
            )

        ## The M-step (update the parameters) with respect to the current posterior state distribution, repeated for num_cycles
        for c in range(num_cycles):
            # Zero the gradients
            optimizer.zero_grad()

            # Compute the marginal likelihood for logging
            with torch.no_grad():
                marginal_likelihood = dkf.marginal_log_likelihood(
                    z_seq=z_seq,
                    x0=x0,
                    P0=P0,
                    Q=Q().repeat(len(z_seq), 1, 1),
                    R=R().repeat(len(z_seq), 1, 1),
                    f_args=f_args,
                    h_args=h_args,
                )
            likelihoods[e, c] = marginal_likelihood

            # Perform the forward pass, i.e. compute the expected complete joint log-likelihood with respect
            # to the previous posterior state distribution and the current parameters
            complete_log_likelihood = dkf.monte_carlo_expected_joint_log_likekihood(
                z_seq=z_seq,
                x0=x0,
                P0=P0,
                # below represents the posterior state distribution
                x0_smoothed=posterior["x0_smoothed"],
                P0_smoothed=posterior["P0_smoothed"],
                x_smoothed=posterior["x_smoothed"],
                P_smoothed=posterior["P_smoothed"],
                Q_seq=Q().repeat(len(z_seq), 1, 1),
                R_seq=R().repeat(len(z_seq), 1, 1),
                f_args=f_args,
                h_args=h_args,
            )

            # Update the parameters
            (-complete_log_likelihood).backward()
            optimizer.step()
            lr_scheduler.step()

            # Print the log likelihood
            print(
                f"Epoch {e + 1}/{num_epochs} Cycle {c + 1}/{num_cycles} Log Likelihood: {marginal_likelihood.item()}"
            )

    return {
        "likelihoods": likelihoods,
    }