diffkalman 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffkalman-0.1.0/.gitignore +10 -0
- diffkalman-0.1.0/.python-version +1 -0
- diffkalman-0.1.0/LICENSE +21 -0
- diffkalman-0.1.0/PKG-INFO +183 -0
- diffkalman-0.1.0/README.md +168 -0
- diffkalman-0.1.0/pyproject.toml +57 -0
- diffkalman-0.1.0/src/diffkalman/__init__.py +6 -0
- diffkalman-0.1.0/src/diffkalman/em_loop.py +115 -0
- diffkalman-0.1.0/src/diffkalman/filter.py +654 -0
- diffkalman-0.1.0/src/diffkalman/joint_jacobian_transform.py +54 -0
- diffkalman-0.1.0/src/diffkalman/negative_log_likelihood.py +133 -0
- diffkalman-0.1.0/src/diffkalman/py.typed +0 -0
- diffkalman-0.1.0/src/diffkalman/utils.py +90 -0
- diffkalman-0.1.0/uv.lock +1265 -0
diffkalman-0.1.0/.python-version
ADDED
@@ -0,0 +1 @@
3.11
diffkalman-0.1.0/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 HadesX

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
diffkalman-0.1.0/PKG-INFO
ADDED
@@ -0,0 +1,183 @@
Metadata-Version: 2.4
Name: diffkalman
Version: 0.1.0
Summary: A differentiable Kalman filter library for auto-tuning Kalman filters.
Author-email: hades <nischalbhattaraipi@gmail.com>
License-File: LICENSE
Requires-Python: >=3.11
Requires-Dist: numpy>=2.2.1
Requires-Dist: tqdm>=4.67.1
Provides-Extra: cpu
Requires-Dist: torch>=2.5.1; extra == 'cpu'
Provides-Extra: cu124
Requires-Dist: torch>=2.5.1; extra == 'cu124'
Description-Content-Type: text/markdown

# Differentiable Kalman Filter

A PyTorch-based implementation of a differentiable Kalman Filter designed for both linear and non-linear dynamical systems with Gaussian noise. This module seamlessly integrates with neural networks, enabling learnable dynamics, observation, and noise models optimized through Stochastic Variational Inference (SVI).

## Features

- **Fully Differentiable**: End-to-end differentiable implementation compatible with PyTorch's autograd
- **Flexible Models**: Support for both linear and non-linear state transition and observation models
- **Neural Network Integration**: Models can be parameterized using neural networks
- **Automatic Jacobian Computation**: Utilizes PyTorch's autograd for derivative calculations
- **Monte Carlo Sampling**: Supports evaluation of the expected joint log-likelihood to perform Expectation-Maximization (EM) learning
- **Rauch-Tung-Striebel Smoothing**: Implements forward-backward smoothing for improved state estimation using the RTS algorithm

## Installation

```bash
# torch is declared as an optional extra (cpu or cu124), so install one of:
pip install "diffkalman[cpu]"
pip install "diffkalman[cu124]"
```

## Quick Start

Here's a simple example of using the Differentiable Kalman Filter:

```python
import torch
from diffkalman import DifferentiableKalmanFilter
from diffkalman.utils import SymmetricPositiveDefiniteMatrix
from diffkalman.em_loop import em_updates

# Define custom state transition and observation functions
class StateTransition(torch.nn.Module):
    def forward(self, x, *args):
        # Your state transition logic here
        return x

class ObservationModel(torch.nn.Module):
    def forward(self, x, *args):
        # Your observation logic here
        return x

# Initialize the filter
f = StateTransition()
h = ObservationModel()
Q = SymmetricPositiveDefiniteMatrix(dim=4, trainable=True)
R = SymmetricPositiveDefiniteMatrix(dim=2, trainable=True)
kalman_filter = DifferentiableKalmanFilter(
    dim_x=4,  # State dimension
    dim_z=2,  # Observation dimension
    f=f,      # State transition function
    h=h       # Observation function
)

# Run the filter
results = kalman_filter.sequence_filter(
    z_seq=observations,     # Shape: (T, dim_z)
    x0=initial_state,       # Shape: (dim_x,)
    P0=initial_covariance,  # Shape: (dim_x, dim_x)
    Q=Q().repeat(len(observations), 1, 1),  # Shape: (T, dim_x, dim_x)
    R=R().repeat(len(observations), 1, 1)   # Shape: (T, dim_z, dim_z)
)
```

## Detailed Usage

### State Estimation

The module provides three main estimation methods:

1. **Filtering**: Forward pass only
```python
filtered_results = kalman_filter.sequence_filter(
    z_seq=observations,
    x0=initial_state,
    P0=initial_covariance,
    Q=process_noise,
    R=observation_noise
)
```

2. **Smoothing**: Forward-backward pass
```python
smoothed_results = kalman_filter.sequence_smooth(
    z_seq=observations,
    x0=initial_state,
    P0=initial_covariance,
    Q=process_noise,
    R=observation_noise
)
```

3. **Single-step Prediction**: For real-time applications
```python
step_result = kalman_filter.predict_update(
    z=current_observation,
    x=current_state,
    P=current_covariance,
    Q=process_noise,
    R=observation_noise
)
```

### Parameter Learning

The module supports learning model parameters through backpropagation, using the negative expected joint log-likelihood of the data as the loss function.

```python
# Define the optimizer and learning-rate scheduler
optimizer = torch.optim.Adam(params=[
    {'params': kalman_filter.f.parameters()},
    {'params': kalman_filter.h.parameters()},
    {'params': Q.parameters()},
    {'params': R.parameters()}
])
# em_updates expects a learning-rate scheduler; any torch scheduler works.
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

NUM_EPOCHS = 10
NUM_CYCLES = 10

# Run the EM loop
marginal_likelihoods = em_updates(
    dkf=kalman_filter,
    z_seq=observations,
    x0=initial_state,
    P0=initial_covariance,
    Q=Q,
    R=R,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    num_cycles=NUM_CYCLES,
    num_epochs=NUM_EPOCHS
)
```

## API Reference

### DifferentiableKalmanFilter

Main class implementing the Kalman Filter algorithm.

#### Constructor Parameters:
- `dim_x` (int): State space dimension
- `dim_z` (int): Observation space dimension
- `f` (nn.Module): State transition function
- `h` (nn.Module): Observation function
- `mc_samples` (int, optional): Number of Monte Carlo samples for log-likelihood estimation

#### Key Methods:
- `predict`: State prediction step
- `update`: Measurement update step
- `predict_update`: Combined prediction and update
- `sequence_filter`: Full sequence filtering
- `sequence_smooth`: Full sequence smoothing
- `marginal_log_likelihood`: Compute marginal log-likelihood
- `monte_carlo_expected_joint_log_likekihood`: Estimate expected joint log-likelihood

## Requirements

- PyTorch >= 2.5.1
- Python >= 3.11

## Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
diffkalman-0.1.0/README.md
ADDED
@@ -0,0 +1,168 @@
# Differentiable Kalman Filter

A PyTorch-based implementation of a differentiable Kalman Filter designed for both linear and non-linear dynamical systems with Gaussian noise. This module seamlessly integrates with neural networks, enabling learnable dynamics, observation, and noise models optimized through Stochastic Variational Inference (SVI).

## Features

- **Fully Differentiable**: End-to-end differentiable implementation compatible with PyTorch's autograd
- **Flexible Models**: Support for both linear and non-linear state transition and observation models
- **Neural Network Integration**: Models can be parameterized using neural networks
- **Automatic Jacobian Computation**: Utilizes PyTorch's autograd for derivative calculations
- **Monte Carlo Sampling**: Supports evaluation of the expected joint log-likelihood to perform Expectation-Maximization (EM) learning
- **Rauch-Tung-Striebel Smoothing**: Implements forward-backward smoothing for improved state estimation using the RTS algorithm

## Installation

```bash
# torch is declared as an optional extra (cpu or cu124), so install one of:
pip install "diffkalman[cpu]"
pip install "diffkalman[cu124]"
```

## Quick Start

Here's a simple example of using the Differentiable Kalman Filter:

```python
import torch
from diffkalman import DifferentiableKalmanFilter
from diffkalman.utils import SymmetricPositiveDefiniteMatrix
from diffkalman.em_loop import em_updates

# Define custom state transition and observation functions
class StateTransition(torch.nn.Module):
    def forward(self, x, *args):
        # Your state transition logic here
        return x

class ObservationModel(torch.nn.Module):
    def forward(self, x, *args):
        # Your observation logic here
        return x

# Initialize the filter
f = StateTransition()
h = ObservationModel()
Q = SymmetricPositiveDefiniteMatrix(dim=4, trainable=True)
R = SymmetricPositiveDefiniteMatrix(dim=2, trainable=True)
kalman_filter = DifferentiableKalmanFilter(
    dim_x=4,  # State dimension
    dim_z=2,  # Observation dimension
    f=f,      # State transition function
    h=h       # Observation function
)

# Run the filter
results = kalman_filter.sequence_filter(
    z_seq=observations,     # Shape: (T, dim_z)
    x0=initial_state,       # Shape: (dim_x,)
    P0=initial_covariance,  # Shape: (dim_x, dim_x)
    Q=Q().repeat(len(observations), 1, 1),  # Shape: (T, dim_x, dim_x)
    R=R().repeat(len(observations), 1, 1)   # Shape: (T, dim_z, dim_z)
)
```
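
For a concrete (if deliberately simple) choice of `f` and `h`, a constant-velocity model with position-only measurements matches the `dim_x=4` / `dim_z=2` sizes used above. This is an illustrative sketch, not a model shipped with the package:

```python
import torch

class ConstantVelocityTransition(torch.nn.Module):
    """Example dynamics: state [px, py, vx, vy] advanced by one time step dt."""
    def __init__(self, dt: float = 1.0):
        super().__init__()
        F = torch.eye(4)
        F[0, 2] = dt  # px += vx * dt
        F[1, 3] = dt  # py += vy * dt
        self.register_buffer("F", F)

    def forward(self, x, *args):
        return self.F @ x

class PositionObservation(torch.nn.Module):
    """Example measurement model: observe the position components only."""
    def forward(self, x, *args):
        return x[:2]
```

Because both models are ordinary `torch.nn.Module`s, the same pattern extends to learned dynamics, for example a small MLP in place of the fixed transition matrix.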

## Detailed Usage

### State Estimation

The module provides three main estimation methods:

1. **Filtering**: Forward pass only
```python
filtered_results = kalman_filter.sequence_filter(
    z_seq=observations,
    x0=initial_state,
    P0=initial_covariance,
    Q=process_noise,
    R=observation_noise
)
```

2. **Smoothing**: Forward-backward pass (the fields of the returned dictionary are sketched just after this list)
```python
smoothed_results = kalman_filter.sequence_smooth(
    z_seq=observations,
    x0=initial_state,
    P0=initial_covariance,
    Q=process_noise,
    R=observation_noise
)
```

3. **Single-step Prediction**: For real-time applications
```python
step_result = kalman_filter.predict_update(
    z=current_observation,
    x=current_state,
    P=current_covariance,
    Q=process_noise,
    R=observation_noise
)
```
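
The smoother returns a dictionary. Judging from how `em_loop.em_updates` consumes it, it exposes at least the smoothed initial state and the per-step smoothed means and covariances (additional keys may exist):

```python
# Keys taken from the em_loop.py source in this release.
x_hat = smoothed_results["x_smoothed"]     # smoothed state means
P_hat = smoothed_results["P_smoothed"]     # smoothed state covariances
x0_hat = smoothed_results["x0_smoothed"]   # smoothed initial state mean
P0_hat = smoothed_results["P0_smoothed"]   # smoothed initial state covariance
```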

### Parameter Learning

The module supports learning model parameters through backpropagation, using the negative expected joint log-likelihood of the data as the loss function.

```python
# Define the optimizer and learning-rate scheduler
optimizer = torch.optim.Adam(params=[
    {'params': kalman_filter.f.parameters()},
    {'params': kalman_filter.h.parameters()},
    {'params': Q.parameters()},
    {'params': R.parameters()}
])
# em_updates expects a learning-rate scheduler; any torch scheduler works.
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

NUM_EPOCHS = 10
NUM_CYCLES = 10

# Run the EM loop
marginal_likelihoods = em_updates(
    dkf=kalman_filter,
    z_seq=observations,
    x0=initial_state,
    P0=initial_covariance,
    Q=Q,
    R=R,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    num_cycles=NUM_CYCLES,
    num_epochs=NUM_EPOCHS
)
```
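
Per the `em_loop` source shipped in this release, `em_updates` returns a dictionary holding the per-cycle marginal log-likelihoods, which is convenient for monitoring convergence:

```python
# Shape: (NUM_EPOCHS, NUM_CYCLES); the curve should generally trend upward
# as the model, observation, and noise parameters improve.
log_liks = marginal_likelihoods["likelihoods"]
print("initial:", log_liks[0, 0].item(), "final:", log_liks[-1, -1].item())
```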

## API Reference

### DifferentiableKalmanFilter

Main class implementing the Kalman Filter algorithm.

#### Constructor Parameters:
- `dim_x` (int): State space dimension
- `dim_z` (int): Observation space dimension
- `f` (nn.Module): State transition function
- `h` (nn.Module): Observation function
- `mc_samples` (int, optional): Number of Monte Carlo samples for log-likelihood estimation

#### Key Methods:
- `predict`: State prediction step
- `update`: Measurement update step
- `predict_update`: Combined prediction and update
- `sequence_filter`: Full sequence filtering
- `sequence_smooth`: Full sequence smoothing
- `marginal_log_likelihood`: Compute marginal log-likelihood
- `monte_carlo_expected_joint_log_likekihood`: Estimate expected joint log-likelihood
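
As a quick illustration of the likelihood methods, `marginal_log_likelihood` accepts the same sequence arguments as the filtering routines (signature inferred from how `em_loop.py` invokes it):

```python
with torch.no_grad():
    mll = kalman_filter.marginal_log_likelihood(
        z_seq=observations,
        x0=initial_state,
        P0=initial_covariance,
        Q=Q().repeat(len(observations), 1, 1),
        R=R().repeat(len(observations), 1, 1),
        f_args=(),
        h_args=(),
    )
print(f"Marginal log-likelihood: {mll.item():.3f}")
```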

## Requirements

- PyTorch >= 2.5.1
- Python >= 3.11

## Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
diffkalman-0.1.0/pyproject.toml
ADDED
@@ -0,0 +1,57 @@
[project]
name = "diffkalman"
version = "0.1.0"
description = "A differentiable Kalman filter library for auto-tuning Kalman filters."
readme = "README.md"
authors = [
    { name = "hades", email = "nischalbhattaraipi@gmail.com" }
]
requires-python = ">=3.11"
dependencies = [
    "numpy>=2.2.1",
    "tqdm>=4.67.1",
]

[project.optional-dependencies]
cpu = [
    "torch>=2.5.1",
]
cu124 = [
    "torch>=2.5.1",
]

[tool.uv]
conflicts = [
    [
        { extra = "cpu" },
        { extra = "cu124" },
    ],
]

[tool.uv.sources]
torch = [
    { index = "pytorch-cpu", extra = "cpu" },
    { index = "pytorch-cu124", extra = "cu124" },
]

[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true

[[tool.uv.index]]
name = "pytorch-cu124"
url = "https://download.pytorch.org/whl/cu124"
explicit = true

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[dependency-groups]
dev = [
    "ipykernel>=6.29.5",
    "matplotlib>=3.10.0",
    "pandas>=2.2.3",
]
diffkalman-0.1.0/src/diffkalman/em_loop.py
ADDED
@@ -0,0 +1,115 @@
"""The EM loop module that implements the EM algorithm for the Differentiable Kalman Filter."""

from .filter import DiffrentiableKalmanFilter
from .utils import SymmetricPositiveDefiniteMatrix
import torch
import torch.nn as nn


__all__ = [
    "em_updates",
]


def em_updates(
    dkf: DiffrentiableKalmanFilter,
    z_seq: torch.Tensor,
    x0: torch.Tensor,
    P0: torch.Tensor,
    Q: SymmetricPositiveDefiniteMatrix,
    R: SymmetricPositiveDefiniteMatrix,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    num_cycles: int = 20,
    num_epochs: int = 100,
    h_args: tuple = (),
    f_args: tuple = (),
) -> dict:
    """A sample implementation of the EM algorithm for the Differentiable Kalman Filter.

    Args:
        z_seq (torch.Tensor): The noisy measurements sequence. Dimension: (seq_len, obs_dim)
        x0 (torch.Tensor): The initial state vector. Dimension: (state_dim,)
        P0 (torch.Tensor): The initial covariance matrix of the state vector. Dimension: (state_dim, state_dim)
        Q (SymmetricPositiveDefiniteMatrix): The process noise covariance matrix module.
        R (SymmetricPositiveDefiniteMatrix): The measurement noise covariance matrix module.
        optimizer (torch.optim.Optimizer): The optimizer
        lr_scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler
        num_cycles (int): The number of cycles. Default: 20
        num_epochs (int): The number of epochs. Default: 100
        h_args (tuple): Additional arguments for the measurement model. Default: ()
        f_args (tuple): Additional arguments for the transition function. Default: ()

    Returns:
        dict: log likelihoods of the model with respect to the number of epochs and cycles

    Note:
        - Use the SymmetricPositiveDefiniteMatrix module to ensure that the process noise covariance matrix Q and the measurement noise covariance matrix R are symmetric and positive definite.
        - The optimizer and the learning rate scheduler should be initialized before calling this function.
        - The measurement model and the transition function should be differentiable torch modules.
    """
    likelihoods = torch.zeros(num_epochs, num_cycles)

    # Perform EM updates for num_epochs
    for e in range(num_epochs):

        ## The E-step:
        ## without gradient tracking, get the posterior state distribution wrt the current values of the parameters
        with torch.no_grad():
            posterior = dkf.sequence_smooth(
                z_seq=z_seq,
                x0=x0,
                P0=P0,
                Q=Q().repeat(len(z_seq), 1, 1),
                R=R().repeat(len(z_seq), 1, 1),
                f_args=f_args,
                h_args=h_args,
            )

        ## The M-step: update the parameters with respect to the current posterior state distribution for num_cycles
        for c in range(num_cycles):
            # Zero the gradients
            optimizer.zero_grad()

            # Compute the marginal likelihood for logging
            with torch.no_grad():
                marginal_likelihood = dkf.marginal_log_likelihood(
                    z_seq=z_seq,
                    x0=x0,
                    P0=P0,
                    Q=Q().repeat(len(z_seq), 1, 1),
                    R=R().repeat(len(z_seq), 1, 1),
                    f_args=f_args,
                    h_args=h_args,
                )
            likelihoods[e, c] = marginal_likelihood

            # Perform the forward pass, i.e. compute the expected complete joint log-likelihood
            # with respect to the previous posterior state distribution and the current parameters
            complete_log_likelihood = dkf.monte_carlo_expected_joint_log_likekihood(
                z_seq=z_seq,
                x0=x0,
                P0=P0,
                # below represents the posterior state distribution
                x0_smoothed=posterior["x0_smoothed"],
                P0_smoothed=posterior["P0_smoothed"],
                x_smoothed=posterior["x_smoothed"],
                P_smoothed=posterior["P_smoothed"],
                Q_seq=Q().repeat(len(z_seq), 1, 1),
                R_seq=R().repeat(len(z_seq), 1, 1),
                f_args=f_args,
                h_args=h_args,
            )

            # Update the parameters
            (-complete_log_likelihood).backward()
            optimizer.step()
            lr_scheduler.step()

            # Print the log likelihood
            print(
                f"Epoch {e + 1}/{num_epochs} Cycle {c + 1}/{num_cycles} Log Likelihood: {marginal_likelihood.item()}"
            )

    return {
        "likelihoods": likelihoods,
    }