kalmax 0.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kalmax-0.0.0/LICENSE +21 -0
- kalmax-0.0.0/PKG-INFO +123 -0
- kalmax-0.0.0/README.md +109 -0
- kalmax-0.0.0/kalmax/__init__.py +0 -0
- kalmax-0.0.0/kalmax/demo_utils.py +119 -0
- kalmax-0.0.0/kalmax/kalman.py +644 -0
- kalmax-0.0.0/kalmax/kde.py +179 -0
- kalmax-0.0.0/kalmax/kernels.py +159 -0
- kalmax-0.0.0/kalmax/utils.py +166 -0
- kalmax-0.0.0/kalmax.egg-info/PKG-INFO +123 -0
- kalmax-0.0.0/kalmax.egg-info/SOURCES.txt +14 -0
- kalmax-0.0.0/kalmax.egg-info/dependency_links.txt +1 -0
- kalmax-0.0.0/kalmax.egg-info/requires.txt +7 -0
- kalmax-0.0.0/kalmax.egg-info/top_level.txt +1 -0
- kalmax-0.0.0/setup.cfg +4 -0
- kalmax-0.0.0/setup.py +27 -0
kalmax-0.0.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2022 Tom M George
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
kalmax-0.0.0/PKG-INFO
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: kalmax
|
|
3
|
+
Version: 0.0.0
|
|
4
|
+
Summary: Kalman based neural decoding in Jax
|
|
5
|
+
Home-page: https://github.com/TomGeorge1234/KalMax
|
|
6
|
+
Author: Tom George
|
|
7
|
+
Classifier: Programming Language :: Python :: 3
|
|
8
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
9
|
+
Classifier: Operating System :: OS Independent
|
|
10
|
+
Requires-Python: >=3.6
|
|
11
|
+
Description-Content-Type: text/markdown
|
|
12
|
+
Provides-Extra: demo
|
|
13
|
+
License-File: LICENSE
|
|
14
|
+
|
|
15
|
+
# **KalMax**: Kalman based neural decoding in Jax
|
|
16
|
+
**KalMax** = **Kal**man smoothing of **Max**imum likelihood estimates in Jax.
|
|
17
|
+
|
|
18
|
+
You provide $\mathbf{S} \in \mathbb{N}^{T \times N}$ (spike counts) and $\mathbf{X} \in \mathbb{R}^{T \times D}$ (a continuous variable, e.g. position) and `KalMax` provides jax-optimised functions and classes for:
|
|
19
|
+
|
|
20
|
+
1. **Fitting rate maps** using kernel density estimation (KDE)
|
|
21
|
+
2. **Calculating likelihood** maps $p(\mathbf{s}_t|\mathbf{x})$
|
|
22
|
+
3. **Kalman filter / smoother**
|
|
23
|
+
|
|
24
|
+
<img src="figures/display_figures/input_data.png" width=350>
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
#### Why are these functionalities combined into one package?...
|
|
30
|
+
|
|
31
|
+
Because Likelihood Estimation + Kalman filtering = Powerful neural decoding. By Kalman filtering/smoothing the maximum likelihood estimates (as opposed to the spikes themselves) we bypass the issues of naive Kalman filters (spikes are rarely linearly related to position) and maximum likelihood decoding (which does not account for temporal continuity in the trajectory), outperforming both for no extra computational cost.
|
|
32
|
+
<img src="figures/display_figures/filter_comparisons.gif" width=850>
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
Core `KalMax` functions are optimised and jit-compiled in jax making them **very fast**. For example `KalMax` Kalman filtering is >13 times faster than an equivalent numpy implementation by the popular [`pykalman`](https://github.com/pykalman/pykalman/tree/master) library (see [demo](./kalmax_demo.ipynb)).
|
|
36
|
+
|
|
37
|
+
<img src="figures/display_figures/kalman_speed_comparison.png" width=150>
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# Install
|
|
41
|
+
```
|
|
42
|
+
git clone https://github.com/TomGeorge1234/KalMax.git
|
|
43
|
+
cd KalMax
|
|
44
|
+
pip install -e .
|
|
45
|
+
```
|
|
46
|
+
The `-e` flag is optional; it performs an editable (developer) install.
|
|
47
|
+
|
|
48
|
+
Alternatively
|
|
49
|
+
```
|
|
50
|
+
pip install git+https://github.com/TomGeorge1234/KalMax.git
|
|
51
|
+
```
|
|
52
|
+
|
|
53
|
+
# Usage
|
|
54
|
+
|
|
55
|
+
A full demo [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/TomGeorge1234/KalMax/blob/main/kalmax_demo.ipynb) is provided in the [`kalmax_demo.ipynb`](./kalmax_demo.ipynb) notebook. Pseudo-code is provided below.
|
|
56
|
+
|
|
57
|
+
```python
|
|
58
|
+
import kalmax
|
|
59
|
+
import jax.numpy as jnp
|
|
60
|
+
```
|
|
61
|
+
|
|
62
|
+
```python
|
|
63
|
+
# 0. PREPARE DATA IN JAX ARRAYS
|
|
64
|
+
S_train = jnp.array(...) # (T, N_CELLS) train spike counts
|
|
65
|
+
Z_train = jnp.array(...) # (T, DIMS) train continuous variable
|
|
66
|
+
S_test = jnp.array(...) # (T_TEST, N_CELLS) test spike counts
|
|
67
|
+
bins = jnp.array(...) # (N_BINS, DIMS) coordinates at which to estimate receptive fields / likelihoods)
|
|
68
|
+
```
|
|
69
|
+
<img src="figures/display_figures/data.png" width=850>
|
|
70
|
+
|
|
71
|
+
```python
|
|
72
|
+
# 1. FIT RECEPTIVE FIELDS using kalmax.kde
|
|
73
|
+
firing_rate = kalmax.kde.kde(
|
|
74
|
+
bins = bins,
|
|
75
|
+
trajectory = Z_train,
|
|
76
|
+
spikes = S_train,
|
|
77
|
+
kernel = kalmax.kernels.gaussian_kernel,
|
|
78
|
+
kernel_kwargs = {'covariance':0.01**2*jnp.eye(DIMS)}, # kernel bandwidth
|
|
79
|
+
) # --> (N_CELLS, N_BINS)
|
|
80
|
+
```
|
|
81
|
+
<img src="figures/display_figures/receptive_fields.png" width=850>
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
```python
|
|
85
|
+
# 2.1 CALCULATE LIKELIHOODS using kalmax.kde.poisson_log_likelihood
|
|
86
|
+
log_likelihoods = kalmax.kde.poisson_log_likelihood(
|
|
87
|
+
spikes = S_test,
|
|
88
|
+
mean_rate = firing_rate,
|
|
89
|
+
) # --> (T_TEST, N_BINS)
|
|
90
|
+
|
|
91
|
+
# 2.2 FIT GAUSSIAN TO LIKELIHOODS using kalmax.utils.fit_gaussian
|
|
92
|
+
MLE_means, MLE_modes, MLE_covs = kalmax.utils.fit_gaussian_vmap(
|
|
93
|
+
x = bins,
|
|
94
|
+
likelihoods = jnp.exp(log_likelihoods),
|
|
95
|
+
) # --> (T_TEST, DIMS), (T_TEST, DIMS), (T_TEST, DIMS, DIMS)
|
|
96
|
+
```
|
|
97
|
+
<img src="figures/display_figures/likelihood_maps_fitted.png" width=850>
|
|
98
|
+
|
|
99
|
+
```python
|
|
100
|
+
# 3. KALMAN FILTER / SMOOTH using kalmax.kalman.KalmanFilter
|
|
101
|
+
kalman_filter = kalmax.kalman.KalmanFilter(
|
|
102
|
+
dim_Z = DIMS,
|
|
103
|
+
dim_Y = N_CELLS,
|
|
104
|
+
# SEE DEMO FOR HOW TO FIT/SET THESE
|
|
105
|
+
F=F, # state transition matrix
|
|
106
|
+
Q=Q, # state noise covariance
|
|
107
|
+
H=H, # observation matrix
|
|
108
|
+
R=R, # observation noise covariance
|
|
109
|
+
)
|
|
110
|
+
|
|
111
|
+
# [FILTER]
|
|
112
|
+
mus_f, sigmas_f = kalman_filter.filter(
|
|
113
|
+
Y = Y,
|
|
114
|
+
mu0 = mu0,
|
|
115
|
+
sigma0 = sigma0,
|
|
116
|
+
) # --> (T, DIMS), (T, DIMS, DIMS)
|
|
117
|
+
|
|
118
|
+
# [SMOOTH]
|
|
119
|
+
mus_s, sigmas_s = kalman_filter.smooth(
|
|
120
|
+
mus_f = mus_f,
|
|
121
|
+
sigmas_f = sigmas_f,
|
|
122
|
+
) # --> (T, DIMS), (T, DIMS, DIMS)
|
|
123
|
+
```
|
kalmax-0.0.0/README.md
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
# **KalMax**: Kalman based neural decoding in Jax
|
|
2
|
+
**KalMax** = **Kal**man smoothing of **Max**imum likelihood estimates in Jax.
|
|
3
|
+
|
|
4
|
+
You provide $\mathbf{S} \in \mathbb{N}^{T \times N}$ (spike counts) and $\mathbf{X} \in \mathbb{R}^{T \times D}$ (a continuous variable, e.g. position) and `KalMax` provides jax-optimised functions and classes for:
|
|
5
|
+
|
|
6
|
+
1. **Fitting rate maps** using kernel density estimation (KDE)
|
|
7
|
+
2. **Calculating likelihood** maps $p(\mathbf{s}_t|\mathbf{x})$
|
|
8
|
+
3. **Kalman filter / smoother**
|
|
9
|
+
|
|
10
|
+
<img src="figures/display_figures/input_data.png" width=350>
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
#### Why are these functionalities combined into one package?...
|
|
16
|
+
|
|
17
|
+
Because Likelihood Estimation + Kalman filtering = Powerful neural decoding. By Kalman filtering/smoothing the maximum likelihood estimates (as opposed to the spikes themselves) we bypass the issues of naive Kalman filters (spikes are rarely linearly related to position) and maximum likelihood decoding (which does not account for temporal continuity in the trajectory), outperforming both for no extra computational cost.
|
|
18
|
+
<img src="figures/display_figures/filter_comparisons.gif" width=850>
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
Core `KalMax` functions are optimised and jit-compiled in jax making them **very fast**. For example `KalMax` Kalman filtering is >13 times faster than an equivalent numpy implementation by the popular [`pykalman`](https://github.com/pykalman/pykalman/tree/master) library (see [demo](./kalmax_demo.ipynb)).
|
|
22
|
+
|
|
23
|
+
<img src="figures/display_figures/kalman_speed_comparison.png" width=150>
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Install
|
|
27
|
+
```
|
|
28
|
+
git clone https://github.com/TomGeorge1234/KalMax.git
|
|
29
|
+
cd KalMax
|
|
30
|
+
pip install -e .
|
|
31
|
+
```
|
|
32
|
+
The `-e` flag is optional; it performs an editable (developer) install.
|
|
33
|
+
|
|
34
|
+
Alternatively
|
|
35
|
+
```
|
|
36
|
+
pip install git+https://github.com/TomGeorge1234/KalMax.git
|
|
37
|
+
```
|
|
38
|
+
|
|
39
|
+
# Usage
|
|
40
|
+
|
|
41
|
+
A full demo [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/TomGeorge1234/KalMax/blob/main/kalmax_demo.ipynb) is provided in the [`kalmax_demo.ipynb`](./kalmax_demo.ipynb) notebook. Pseudo-code is provided below.
|
|
42
|
+
|
|
43
|
+
```python
|
|
44
|
+
import kalmax
|
|
45
|
+
import jax.numpy as jnp
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
```python
|
|
49
|
+
# 0. PREPARE DATA IN JAX ARRAYS
|
|
50
|
+
S_train = jnp.array(...) # (T, N_CELLS) train spike counts
|
|
51
|
+
Z_train = jnp.array(...) # (T, DIMS) train continuous variable
|
|
52
|
+
S_test = jnp.array(...) # (T_TEST, N_CELLS) test spike counts
|
|
53
|
+
bins = jnp.array(...) # (N_BINS, DIMS) coordinates at which to estimate receptive fields / likelihoods)
|
|
54
|
+
```
|
|
55
|
+
<img src="figures/display_figures/data.png" width=850>
|
|
56
|
+
|
|
57
|
+
```python
|
|
58
|
+
# 1. FIT RECEPTIVE FIELDS using kalmax.kde
|
|
59
|
+
firing_rate = kalmax.kde.kde(
|
|
60
|
+
bins = bins,
|
|
61
|
+
trajectory = Z_train,
|
|
62
|
+
spikes = S_train,
|
|
63
|
+
kernel = kalmax.kernels.gaussian_kernel,
|
|
64
|
+
kernel_kwargs = {'covariance':0.01**2*jnp.eye(DIMS)}, # kernel bandwidth
|
|
65
|
+
) # --> (N_CELLS, N_BINS)
|
|
66
|
+
```
|
|
67
|
+
<img src="figures/display_figures/receptive_fields.png" width=850>
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
```python
|
|
71
|
+
# 2.1 CALCULATE LIKELIHOODS using kalmax.kde.poisson_log_likelihood
|
|
72
|
+
log_likelihoods = kalmax.kde.poisson_log_likelihood(
|
|
73
|
+
spikes = S_test,
|
|
74
|
+
mean_rate = firing_rate,
|
|
75
|
+
) # --> (T_TEST, N_BINS)
|
|
76
|
+
|
|
77
|
+
# 2.2 FIT GAUSSIAN TO LIKELIHOODS using kalmax.utils.fit_gaussian
|
|
78
|
+
MLE_means, MLE_modes, MLE_covs = kalmax.utils.fit_gaussian_vmap(
|
|
79
|
+
x = bins,
|
|
80
|
+
likelihoods = jnp.exp(log_likelihoods),
|
|
81
|
+
) # --> (T_TEST, DIMS), (T_TEST, DIMS), (T_TEST, DIMS, DIMS)
|
|
82
|
+
```
|
|
83
|
+
<img src="figures/display_figures/likelihood_maps_fitted.png" width=850>
|
|
84
|
+
|
|
85
|
+
```python
|
|
86
|
+
# 3. KALMAN FILTER / SMOOTH using kalmax.kalman.KalmanFilter
|
|
87
|
+
kalman_filter = kalmax.kalman.KalmanFilter(
|
|
88
|
+
dim_Z = DIMS,
|
|
89
|
+
dim_Y = N_CELLS,
|
|
90
|
+
# SEE DEMO FOR HOW TO FIT/SET THESE
|
|
91
|
+
F=F, # state transition matrix
|
|
92
|
+
Q=Q, # state noise covariance
|
|
93
|
+
H=H, # observation matrix
|
|
94
|
+
R=R, # observation noise covariance
|
|
95
|
+
)
|
|
96
|
+
|
|
97
|
+
# [FILTER]
|
|
98
|
+
mus_f, sigmas_f = kalman_filter.filter(
|
|
99
|
+
Y = Y,
|
|
100
|
+
mu0 = mu0,
|
|
101
|
+
sigma0 = sigma0,
|
|
102
|
+
) # --> (T, DIMS), (T, DIMS, DIMS)
|
|
103
|
+
|
|
104
|
+
# [SMOOTH]
|
|
105
|
+
mus_s, sigmas_s = kalman_filter.smooth(
|
|
106
|
+
mus_f = mus_f,
|
|
107
|
+
sigmas_f = sigmas_f,
|
|
108
|
+
) # --> (T, DIMS), (T, DIMS, DIMS)
|
|
109
|
+
```
|
|
File without changes
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
# Some (mostly plotting and poorly written) utilities for the Kalman filter demo
|
|
2
|
+
import matplotlib
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
def animate_over_time(
    plot_function,
    t_start,
    t_end,
    speed_up=10,  # relative to real time
    fps=20,
):
    """Animate ``plot_function`` over the time window [t_start, t_end).

    Params
    ------
    plot_function: callable
        Called once with no arguments to initialise the figure; it must act as
        its own init function when ``ax`` is None and return the axis (or
        axes) to draw on. On every frame it is then called as::

            plot_function(ax=ax, t_start=t_start, t_end=time)

        where ``time`` is the current frame time, and it should return ``ax``.
    t_start: float
        Time of the first frame. Also passed unchanged as ``t_start`` to
        every per-frame call (relevant when plotting a trajectory).
    t_end: float
        Frames are generated up to, but not including, this time
        (``np.arange`` semantics).
    speed_up: float, optional
        Playback speed relative to real time (default 10). Consecutive frames
        are ``speed_up / fps`` time units apart.
    fps: float, optional
        Frames per second of the animation (default 20). Sets the frame
        interval to ``1000 / fps`` milliseconds.

    Returns
    -------
    anim: matplotlib.animation.FuncAnimation
        The animation object. The current figure is closed before returning.
    """
    from matplotlib.animation import FuncAnimation

    # Amount of (data) time that elapses between consecutive frames.
    time_per_frame = speed_up / fps
    times = np.arange(t_start, t_end, time_per_frame)

    def update(time, ax, fig):
        # Wipe every axis in the figure, then redraw the window [t_start, time].
        for ax_ in fig.get_axes():
            ax_.clear()
        ax = plot_function(ax=ax, t_start=t_start, t_end=time)
        return ax

    # The no-argument call creates the figure/axes that every frame redraws into.
    ax = plot_function()
    fig = plt.gcf()
    anim = FuncAnimation(fig, update, frames=times, fargs=(ax, fig), interval=1000 / fps)
    # Close the current figure (presumably so a static copy is not also
    # displayed alongside the animation, e.g. inline in a notebook).
    plt.close()
    return anim
|
|
42
|
+
|
|
43
|
+
def plot_trajectory(
    trajectory,
    time_stamps,
    ax=None,
    t_start=None,
    t_end=None,
    **plot_kwargs):
    """Plot a 2D trajectory over a given time range.

    Params
    ------
    trajectory: np.ndarray
        Trajectory to plot, shape (T, 2)
    time_stamps: np.ndarray
        Time stamps for the trajectory, shape (T,)
    ax: matplotlib axis, optional
        Axis to plot on. If None, a new 6x6 figure with a single axis is created.
    t_start: float, optional
        Start time (defaults to time_stamps[0]). Snapped to the nearest time stamp.
    t_end: float, optional
        End time (defaults to time_stamps[-1]). Snapped to the nearest time
        stamp; that sample is included in the plot.
    plot_kwargs: dict, optional
        Additional plotting arguments:
        'color' (default 'k'), 'alpha' (1), 'label' (None),
        'scatter_points' (True), 'show_line' (True), 'linewidth' (1),
        'title' (None), 'xlabel' ('x [m]'), 'ylabel' ('y [m]'), and axis
        limits 'min_x', 'max_x', 'min_y', 'max_y' (default: the range of the
        FULL trajectory, rounded to 1 decimal place).

    Returns
    -------
    ax: matplotlib axis
    """
    if t_start is None:
        t_start = time_stamps[0]
    if t_end is None:
        t_end = time_stamps[-1]
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(6, 6))

    # Indices of the samples closest to the requested start / end times.
    id_start = np.argmin(np.abs(time_stamps - t_start))
    id_end = np.argmin(np.abs(time_stamps - t_end))
    # +1 makes the end index inclusive: a plain slice would drop the sample
    # nearest t_end (e.g. the final point when t_end defaults to time_stamps[-1]).
    trajectory_ = trajectory[id_start:id_end + 1]

    # Get plot kwargs
    color = plot_kwargs.get('color', 'k')
    scatter_points = plot_kwargs.get('scatter_points', True)
    show_line = plot_kwargs.get('show_line', True)
    linewidth = plot_kwargs.get('linewidth', 1)
    title = plot_kwargs.get('title', None)
    xlabel = plot_kwargs.get('xlabel', 'x [m]')
    ylabel = plot_kwargs.get('ylabel', 'y [m]')
    alpha = plot_kwargs.get('alpha', 1)
    label = plot_kwargs.get('label', None)
    # Axis limits default to the whole trajectory (not just the plotted window)
    # so the view stays fixed, e.g. across animation frames.
    min_x = plot_kwargs.get('min_x', trajectory[:, 0].min().round(1))
    max_x = plot_kwargs.get('max_x', trajectory[:, 0].max().round(1))
    min_y = plot_kwargs.get('min_y', trajectory[:, 1].min().round(1))
    max_y = plot_kwargs.get('max_y', trajectory[:, 1].max().round(1))

    if show_line:
        ax.plot(trajectory_[:, 0], trajectory_[:, 1], color=color, linewidth=linewidth, alpha=alpha, label=label)
    if scatter_points:
        ax.scatter(trajectory_[:, 0], trajectory_[:, 1], color=color, linewidth=0, s=6, alpha=alpha)
    ax.set_xlim(min_x, max_x)
    ax.set_ylim(min_y, max_y)
    ax.set_aspect('equal', 'box')
    ax.set_xticks([min_x, max_x])
    ax.set_yticks([min_y, max_y])
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if title is not None:
        ax.set_title(title)

    return ax
|
|
105
|
+
|
|
106
|
+
def _ellipse_geometry(cov):
    """Return (width, height, angle_deg) of the 1-std ellipse of a 2x2 covariance.

    ``width`` and ``height`` are twice the standard deviations along the two
    eigenvectors of ``cov``. ``angle_deg`` is the direction (degrees,
    anticlockwise from the x-axis) of the eigenvector that ``width`` refers to.
    """
    # eigh is the routine for symmetric matrices: it returns real eigenvalues
    # and orthonormal eigenvectors, whereas np.linalg.eig may return complex
    # output for a covariance that is only symmetric up to round-off.
    variances, vecs = np.linalg.eigh(np.asarray(cov))
    # Clip tiny negative eigenvalues (round-off on near-singular covariances)
    # so the sqrt cannot produce NaN.
    stds = np.sqrt(np.clip(variances, 0, None))
    # arctan2 handles a vertical eigenvector (x-component 0) without a divide-by-zero.
    angle = np.rad2deg(np.arctan2(vecs[1, 0], vecs[0, 0]))
    return stds[0] * 2, stds[1] * 2, angle


def plot_ellipse(ax, mean, cov, color):
    """Draw the 1-standard-deviation ellipse of a 2D Gaussian on ``ax``.

    Params
    ------
    ax: matplotlib axis
        Axis to draw on.
    mean: array-like, shape (2,)
        Centre of the ellipse.
    cov: array-like, shape (2, 2)
        Covariance matrix (assumed symmetric positive semi-definite).
    color: matplotlib colour
        Used for both edge and face (drawn with alpha=0.5).

    Returns
    -------
    ax: matplotlib axis
    """
    width, height, angle = _ellipse_geometry(cov)
    ell = matplotlib.patches.Ellipse(
        xy=mean,
        width=width,
        height=height,
        angle=angle,
        lw=1,
        fill=True,
        edgecolor=color,
        facecolor=color,
        alpha=0.5,
    )
    # add_artist (rather than add_patch) leaves the axis data limits untouched.
    ax.add_artist(ell)
    return ax
|