kalmax 0.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kalmax-0.0.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2022 Tom M George
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
kalmax-0.0.0/PKG-INFO ADDED
@@ -0,0 +1,123 @@
1
+ Metadata-Version: 2.1
2
+ Name: kalmax
3
+ Version: 0.0.0
4
+ Summary: Kalman based neural decoding in Jax
5
+ Home-page: https://github.com/TomGeorge1234/KalMax
6
+ Author: Tom George
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: License :: OSI Approved :: MIT License
9
+ Classifier: Operating System :: OS Independent
10
+ Requires-Python: >=3.6
11
+ Description-Content-Type: text/markdown
12
+ Provides-Extra: demo
13
+ License-File: LICENSE
14
+
15
+ # **KalMax**: Kalman based neural decoding in Jax
16
+ **KalMax** = **Kal**man smoothing of **Max**imum likelihood estimates in Jax.
17
+
18
+ You provide $\mathbf{S} \in \mathbb{N}^{T \times N}$ (spike counts) and $\mathbf{X} \in \mathbb{R}^{T \times D}$ (a continuous variable, e.g. position) and `KalMax` provides jax-optimised functions and classes for:
19
+
20
+ 1. **Fitting rate maps** using kernel density estimation (KDE)
21
+ 2. **Calculating likelihood** maps $p(\mathbf{s}_t|\mathbf{x})$
22
+ 3. **Kalman filter / smoother**
23
+
24
+ <img src="figures/display_figures/input_data.png" width=350>
25
+
26
+
27
+
28
+
29
+ #### Why are these functionalities combined into one package?...
30
+
31
+ Because Likelihood Estimation + Kalman filtering = Powerful neural decoding. By Kalman filtering/smoothing the maximum likelihood estimates (as opposed to the spikes themselves) we bypass the issues of naive Kalman filters (spikes are rarely linearly related to position) and maximum likelihood decoding (which does not account for temporal continuity in the trajectory), outperforming both for no extra computational cost.
32
+ <img src="figures/display_figures/filter_comparisons.gif" width=850>
33
+
34
+
35
+ Core `KalMax` functions are optimised and jit-compiled in jax making them **very fast**. For example `KalMax` kalman filtering is >13 times faster than an equivalent numpy implementation by the popular [`pykalman`](https://github.com/pykalman/pykalman/tree/master) library (see [demo](./kalmax_demo.ipynb)).
36
+
37
+ <img src="figures/display_figures/kalman_speed_comparison.png" width=150>
38
+
39
+
40
+ # Install
41
+ ```
42
+ git clone https://github.com/TomGeorge1234/KalMax.git
43
+ cd KalMax
44
+ pip install -e .
45
+ ```
46
+ (`-e`) is optional for developer install.
47
+
48
+ Alternatively
49
+ ```
50
+ pip install git+https://github.com/TomGeorge1234/KalMax.git
51
+ ```
52
+
53
+ # Usage
54
+
55
+ A full demo [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/TomGeorge1234/KalMax/blob/main/kalmax_demo.ipynb) is provided in the [`kalmax_demo.ipynb`](./kalmax_demo.ipynb). Pseudo-code is provided below.
56
+
57
+ ```python
58
+ import kalmax
59
+ import jax.numpy as jnp
60
+ ```
61
+
62
+ ```python
63
+ # 0. PREPARE DATA IN JAX ARRAYS
64
+ S_train = jnp.array(...) # (T, N_CELLS) train spike counts
65
+ Z_train = jnp.array(...) # (T, DIMS) train continuous variable
66
+ S_test = jnp.array(...) # (T_TEST, N_CELLS) test spike counts
67
+ bins = jnp.array(...) # (N_BINS, DIMS) coordinates at which to estimate receptive fields / likelihoods)
68
+ ```
69
+ <img src="figures/display_figures/data.png" width=850>
70
+
71
+ ```python
72
+ # 1. FIT RECEPTIVE FIELDS using kalmax.kde
73
+ firing_rate = kalmax.kde.kde(
74
+ bins = bins,
75
+ trajectory = Z_train,
76
+ spikes = S_train,
77
+ kernel = kalmax.kernels.gaussian_kernel,
78
+ kernel_kwargs = {'covariance':0.01**2*jnp.eye(DIMS)}, # kernel bandwidth
79
+ ) # --> (N_CELLS, N_BINS)
80
+ ```
81
+ <img src="figures/display_figures/receptive_fields.png" width=850>
82
+
83
+
84
+ ```python
85
+ # 2.1 CALCULATE LIKELIHOODS using kalmax.poisson_log_likelihood
86
+ log_likelihoods = kalmax.kde.poisson_log_likelihood(
87
+ spikes = S_test,
88
+ mean_rate = firing_rate,
89
+ ) # --> (T_TEST, N_CELLS)
90
+
91
+ # 2.2 FIT GAUSSIAN TO LIKELIHOODS using kalmax.utils.fit_gaussian
92
+ MLE_means, MLE_modes, MLE_covs = kalmax.utils.fit_gaussian_vmap(
93
+ x = bins,
94
+ likelihoods = jnp.exp(log_likelihoods),
95
+ ) # --> (T_TEST, DIMS), (T_TEST, DIMS), (T_TEST, DIMS, DIMS)
96
+ ```
97
+ <img src="figures/display_figures/likelihood_maps_fitted.png" width=850>
98
+
99
+ ```python
100
+ # 3. KALMAN FILTER / SMOOTH using kalmax.KalmanFilter.KalmanFilter
101
+ kalman_filter = kalmax.kalman.KalmanFilter(
102
+ dim_Z = DIMS,
103
+ dim_Y = N_CELLS,
104
+ # SEE DEMO FOR HOW TO FIT/SET THESE
105
+ F=F, # state transition matrix
106
+ Q=Q, # state noise covariance
107
+ H=H, # observation matrix
108
+ R=R, # observation noise covariance
109
+ )
110
+
111
+ # [FILTER]
112
+ mus_f, sigmas_f = kalman_filter.filter(
113
+ Y = Y,
114
+ mu0 = mu0,
115
+ sigma0 = sigma0,
116
+ ) # --> (T, DIMS), (T, DIMS, DIMS)
117
+
118
+ # [SMOOTH]
119
+ mus_s, sigmas_s = kalman_filter.smooth(
120
+ mus_f = mus_f,
121
+ sigmas_f = sigmas_f,
122
+ ) # --> (T, DIMS), (T, DIMS, DIMS)
123
+ ```
kalmax-0.0.0/README.md ADDED
@@ -0,0 +1,109 @@
1
+ # **KalMax**: Kalman based neural decoding in Jax
2
+ **KalMax** = **Kal**man smoothing of **Max**imum likelihood estimates in Jax.
3
+
4
+ You provide $\mathbf{S} \in \mathbb{N}^{T \times N}$ (spike counts) and $\mathbf{X} \in \mathbb{R}^{T \times D}$ (a continuous variable, e.g. position) and `KalMax` provides jax-optimised functions and classes for:
5
+
6
+ 1. **Fitting rate maps** using kernel density estimation (KDE)
7
+ 2. **Calculating likelihood** maps $p(\mathbf{s}_t|\mathbf{x})$
8
+ 3. **Kalman filter / smoother**
9
+
10
+ <img src="figures/display_figures/input_data.png" width=350>
11
+
12
+
13
+
14
+
15
+ #### Why are these functionalities combined into one package?...
16
+
17
+ Because Likelihood Estimation + Kalman filtering = Powerful neural decoding. By Kalman filtering/smoothing the maximum likelihood estimates (as opposed to the spikes themselves) we bypass the issues of naive Kalman filters (spikes are rarely linearly related to position) and maximum likelihood decoding (which does not account for temporal continuity in the trajectory), outperforming both for no extra computational cost.
18
+ <img src="figures/display_figures/filter_comparisons.gif" width=850>
19
+
20
+
21
+ Core `KalMax` functions are optimised and jit-compiled in jax making them **very fast**. For example `KalMax` kalman filtering is >13 times faster than an equivalent numpy implementation by the popular [`pykalman`](https://github.com/pykalman/pykalman/tree/master) library (see [demo](./kalmax_demo.ipynb)).
22
+
23
+ <img src="figures/display_figures/kalman_speed_comparison.png" width=150>
24
+
25
+
26
+ # Install
27
+ ```
28
+ git clone https://github.com/TomGeorge1234/KalMax.git
29
+ cd KalMax
30
+ pip install -e .
31
+ ```
32
+ (`-e`) is optional for developer install.
33
+
34
+ Alternatively
35
+ ```
36
+ pip install git+https://github.com/TomGeorge1234/KalMax.git
37
+ ```
38
+
39
+ # Usage
40
+
41
+ A full demo [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/TomGeorge1234/KalMax/blob/main/kalmax_demo.ipynb) is provided in the [`kalmax_demo.ipynb`](./kalmax_demo.ipynb). Pseudo-code is provided below.
42
+
43
+ ```python
44
+ import kalmax
45
+ import jax.numpy as jnp
46
+ ```
47
+
48
+ ```python
49
+ # 0. PREPARE DATA IN JAX ARRAYS
50
+ S_train = jnp.array(...) # (T, N_CELLS) train spike counts
51
+ Z_train = jnp.array(...) # (T, DIMS) train continuous variable
52
+ S_test = jnp.array(...) # (T_TEST, N_CELLS) test spike counts
53
+ bins = jnp.array(...) # (N_BINS, DIMS) coordinates at which to estimate receptive fields / likelihoods)
54
+ ```
55
+ <img src="figures/display_figures/data.png" width=850>
56
+
57
+ ```python
58
+ # 1. FIT RECEPTIVE FIELDS using kalmax.kde
59
+ firing_rate = kalmax.kde.kde(
60
+ bins = bins,
61
+ trajectory = Z_train,
62
+ spikes = S_train,
63
+ kernel = kalmax.kernels.gaussian_kernel,
64
+ kernel_kwargs = {'covariance':0.01**2*jnp.eye(DIMS)}, # kernel bandwidth
65
+ ) # --> (N_CELLS, N_BINS)
66
+ ```
67
+ <img src="figures/display_figures/receptive_fields.png" width=850>
68
+
69
+
70
+ ```python
71
+ # 2.1 CALCULATE LIKELIHOODS using kalmax.poisson_log_likelihood
72
+ log_likelihoods = kalmax.kde.poisson_log_likelihood(
73
+ spikes = S_test,
74
+ mean_rate = firing_rate,
75
+ ) # --> (T_TEST, N_CELLS)
76
+
77
+ # 2.2 FIT GAUSSIAN TO LIKELIHOODS using kalmax.utils.fit_gaussian
78
+ MLE_means, MLE_modes, MLE_covs = kalmax.utils.fit_gaussian_vmap(
79
+ x = bins,
80
+ likelihoods = jnp.exp(log_likelihoods),
81
+ ) # --> (T_TEST, DIMS), (T_TEST, DIMS), (T_TEST, DIMS, DIMS)
82
+ ```
83
+ <img src="figures/display_figures/likelihood_maps_fitted.png" width=850>
84
+
85
+ ```python
86
+ # 3. KALMAN FILTER / SMOOTH using kalmax.KalmanFilter.KalmanFilter
87
+ kalman_filter = kalmax.kalman.KalmanFilter(
88
+ dim_Z = DIMS,
89
+ dim_Y = N_CELLS,
90
+ # SEE DEMO FOR HOW TO FIT/SET THESE
91
+ F=F, # state transition matrix
92
+ Q=Q, # state noise covariance
93
+ H=H, # observation matrix
94
+ R=R, # observation noise covariance
95
+ )
96
+
97
+ # [FILTER]
98
+ mus_f, sigmas_f = kalman_filter.filter(
99
+ Y = Y,
100
+ mu0 = mu0,
101
+ sigma0 = sigma0,
102
+ ) # --> (T, DIMS), (T, DIMS, DIMS)
103
+
104
+ # [SMOOTH]
105
+ mus_s, sigmas_s = kalman_filter.smooth(
106
+ mus_f = mus_f,
107
+ sigmas_f = sigmas_f,
108
+ ) # --> (T, DIMS), (T, DIMS, DIMS)
109
+ ```
File without changes
@@ -0,0 +1,119 @@
1
+ # Some (mostly plotting and poorly written) utilities for the Kalman filter demo
2
+ import matplotlib
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+
6
def animate_over_time(
    plot_function,
    t_start,
    t_end,
    speed_up=10,  # playback speed relative to real time
    fps=20,):
    """Animate `plot_function` over the interval [t_start, t_end].

    Params
    ------
    plot_function: callable
        Must act as its own init function when called with no arguments
        (creating the figure and returning the axis to animate). On each
        frame it is called as

            plot_function(ax=ax, t_start=t_start, t_end=time)

        and must redraw the data between `t_start` and `time`, returning
        the axis.
    t_start: float
        Time at which the animation begins.
    t_end: float
        Time at which the animation ends.
    speed_up: float, optional
        Playback speed relative to real time (default 10).
    fps: int, optional
        Frames per second of the resulting animation (default 20).

    Returns
    -------
    matplotlib.animation.FuncAnimation
    """
    from matplotlib.animation import FuncAnimation

    # Amount of (data) time covered by each frame, so the animation plays
    # `speed_up` times faster than real time at the requested fps.
    time_per_frame = speed_up / fps
    times = np.arange(t_start, t_end, time_per_frame)

    def update(time, ax, fig):
        # Clear every axis so each frame is redrawn from scratch.
        for ax_ in fig.get_axes():
            ax_.clear()
        ax = plot_function(ax=ax, t_start=t_start, t_end=time)
        return ax

    ax = plot_function()  # initial draw; also creates the current figure
    anim = FuncAnimation(plt.gcf(), update, frames=times,
                         fargs=(ax, plt.gcf(),), interval=1000 / fps)
    # Close the figure so the static frame isn't displayed alongside the animation.
    plt.close()
    return anim
42
+
43
def plot_trajectory(
    trajectory,
    time_stamps,
    ax=None,
    t_start=None,
    t_end=None,
    **plot_kwargs):
    """Plot a 2D trajectory between two times.

    Params
    ------
    trajectory: np.ndarray
        Trajectory to plot, shape (T, 2)
    time_stamps: np.ndarray
        Time stamps for the trajectory, shape (T,)
    ax: matplotlib axis, optional
        Axis to plot on; a new 6x6-inch figure is created when None.
    t_start: float, optional
        Start time (defaults to the first time stamp)
    t_end: float, optional
        End time (defaults to the last time stamp)
    plot_kwargs: dict, optional
        Additional plotting arguments like 'color', 'alpha', 'label',
        'scatter_points', 'show_line', 'linewidth', 'title', 'xlabel',
        'ylabel', 'min_x', 'max_x', 'min_y', 'max_y'.

    Returns
    -------
    ax: matplotlib axis
    """
    if t_start is None:
        t_start = time_stamps[0]
    if t_end is None:
        t_end = time_stamps[-1]
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(6, 6))

    # Convert the requested times to the nearest sample indices and slice.
    id_start = np.argmin(np.abs(time_stamps - t_start))
    id_end = np.argmin(np.abs(time_stamps - t_end))
    segment = trajectory[id_start:id_end]

    # Styling options (defaults match the package's plotting conventions).
    get = plot_kwargs.get
    color = get('color', 'k')
    scatter_points = get('scatter_points', True)
    show_line = get('show_line', True)
    linewidth = get('linewidth', 1)
    title = get('title', None)
    xlabel = get('xlabel', 'x [m]')
    ylabel = get('ylabel', 'y [m]')
    alpha = get('alpha', 1)
    label = get('label', None)
    # Axis limits default to the rounded extent of the FULL trajectory,
    # not just the plotted segment, so limits stay fixed across frames.
    min_x = get('min_x', trajectory[:, 0].min().round(1))
    max_x = get('max_x', trajectory[:, 0].max().round(1))
    min_y = get('min_y', trajectory[:, 1].min().round(1))
    max_y = get('max_y', trajectory[:, 1].max().round(1))

    if show_line:
        ax.plot(segment[:, 0], segment[:, 1],
                color=color, linewidth=linewidth, alpha=alpha, label=label)
    if scatter_points:
        ax.scatter(segment[:, 0], segment[:, 1],
                   color=color, linewidth=0, s=6, alpha=alpha)

    ax.set_xlim(min_x, max_x)
    ax.set_ylim(min_y, max_y)
    ax.set_aspect('equal', 'box')
    ax.set_xticks([min_x, max_x])
    ax.set_yticks([min_y, max_y])
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if title is not None:
        ax.set_title(title)

    return ax
105
+
106
def plot_ellipse(ax, mean, cov, color):
    """Draw a one-standard-deviation covariance ellipse on an axis.

    Params
    ------
    ax: matplotlib axis
        Axis to draw on.
    mean: array-like, shape (2,)
        Centre of the ellipse.
    cov: np.ndarray, shape (2, 2)
        Covariance matrix; its eigenvectors give the ellipse orientation
        and the square roots of its eigenvalues the semi-axis lengths.
    color: matplotlib color
        Used for both the edge and the (semi-transparent) face.

    Returns
    -------
    ax: matplotlib axis
    """
    lambda_, v = np.linalg.eig(cov)
    lambda_ = np.sqrt(lambda_) # convert from variance to standard deviation along the eigenvectors
    # arctan2 handles a vertical leading eigenvector (v[0, 0] == 0) without the
    # divide-by-zero RuntimeWarning that arctan(v[1, 0] / v[0, 0]) produced;
    # the resulting angle is equivalent (mod 180 deg) so the ellipse is unchanged.
    angle = np.rad2deg(np.arctan2(v[:, 0][1], v[:, 0][0]))
    ell = matplotlib.patches.Ellipse(xy=mean,
                                     width=lambda_[0] * 2,
                                     height=lambda_[1] * 2,
                                     angle=angle,
                                     lw=1,
                                     fill=True,
                                     edgecolor=color,
                                     facecolor=color,
                                     alpha=0.5,)
    ax.add_artist(ell)
    return ax