cellworld-npx 0.0.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cellworld_npx-0.0.2/LICENSE.txt +19 -0
- cellworld_npx-0.0.2/PKG-INFO +21 -0
- cellworld_npx-0.0.2/README.md +0 -0
- cellworld_npx-0.0.2/pyproject.toml +3 -0
- cellworld_npx-0.0.2/setup.cfg +7 -0
- cellworld_npx-0.0.2/setup.py +16 -0
- cellworld_npx-0.0.2/src/cellworld_npx/__init__.py +0 -0
- cellworld_npx-0.0.2/src/cellworld_npx/camera.py +49 -0
- cellworld_npx-0.0.2/src/cellworld_npx/celltile.py +147 -0
- cellworld_npx-0.0.2/src/cellworld_npx/classifier.py +387 -0
- cellworld_npx-0.0.2/src/cellworld_npx/cluster_metrics.py +734 -0
- cellworld_npx-0.0.2/src/cellworld_npx/coverage.py +120 -0
- cellworld_npx-0.0.2/src/cellworld_npx/decoder.py +160 -0
- cellworld_npx-0.0.2/src/cellworld_npx/episode.py +34 -0
- cellworld_npx-0.0.2/src/cellworld_npx/io.py +224 -0
- cellworld_npx-0.0.2/src/cellworld_npx/kalman.py +101 -0
- cellworld_npx-0.0.2/src/cellworld_npx/lfp.py +291 -0
- cellworld_npx-0.0.2/src/cellworld_npx/map.py +110 -0
- cellworld_npx-0.0.2/src/cellworld_npx/probe.py +312 -0
- cellworld_npx-0.0.2/src/cellworld_npx/recording.py +836 -0
- cellworld_npx-0.0.2/src/cellworld_npx/state_decoder.py +386 -0
- cellworld_npx-0.0.2/src/cellworld_npx/sync.py +389 -0
- cellworld_npx-0.0.2/src/cellworld_npx/utils.py +1163 -0
- cellworld_npx-0.0.2/src/cellworld_npx.egg-info/PKG-INFO +21 -0
- cellworld_npx-0.0.2/src/cellworld_npx.egg-info/SOURCES.txt +28 -0
- cellworld_npx-0.0.2/src/cellworld_npx.egg-info/dependency_links.txt +1 -0
- cellworld_npx-0.0.2/src/cellworld_npx.egg-info/not-zip-safe +1 -0
- cellworld_npx-0.0.2/src/cellworld_npx.egg-info/requires.txt +14 -0
- cellworld_npx-0.0.2/src/cellworld_npx.egg-info/top_level.txt +1 -0
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
Copyright (c) 2018 The Python Packaging Authority
|
|
2
|
+
|
|
3
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
4
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
5
|
+
in the Software without restriction, including without limitation the rights
|
|
6
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
7
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
8
|
+
furnished to do so, subject to the following conditions:
|
|
9
|
+
|
|
10
|
+
The above copyright notice and this permission notice shall be included in all
|
|
11
|
+
copies or substantial portions of the Software.
|
|
12
|
+
|
|
13
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
14
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
15
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
16
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
17
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
18
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
19
|
+
SOFTWARE.
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: cellworld_npx
|
|
3
|
+
Version: 0.0.2
|
|
4
|
+
Summary: Dombeck/MacIver labs neuroethology analysis package - combined cellworld behavior and neuropixels recordings
|
|
5
|
+
Author: Chris Angeloni
|
|
6
|
+
Author-email: chris.angeloni@gmail.com
|
|
7
|
+
License: MIT
|
|
8
|
+
License-File: LICENSE.txt
|
|
9
|
+
Requires-Dist: numpy
|
|
10
|
+
Requires-Dist: scipy
|
|
11
|
+
Requires-Dist: matplotlib
|
|
12
|
+
Requires-Dist: json-cpp
|
|
13
|
+
Requires-Dist: cellworld
|
|
14
|
+
Requires-Dist: npyx
|
|
15
|
+
Requires-Dist: pandas
|
|
16
|
+
Requires-Dist: astropy
|
|
17
|
+
Requires-Dist: rtree
|
|
18
|
+
Requires-Dist: kilosort
|
|
19
|
+
Requires-Dist: torch
|
|
20
|
+
Provides-Extra: kilosort
|
|
21
|
+
Requires-Dist: kilosort; extra == "kilosort"
|
|
Binary file
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from setuptools import setup, find_packages
|
|
2
|
+
import pathlib
|
|
3
|
+
|
|
4
|
+
here = pathlib.Path(__file__).parent.resolve()

# Package metadata, gathered in one mapping and handed to setuptools below.
_package_metadata = dict(
    name='cellworld_npx',
    version='0.0.2',
    description='Dombeck/MacIver labs neuroethology analysis package - combined cellworld behavior and neuropixels recordings',
    author='Chris Angeloni',
    author_email='chris.angeloni@gmail.com',
    license='MIT',
    # sources live under src/ (src-layout)
    package_dir={"": "src"},
    packages=find_packages(where="src"),
    install_requires=['numpy', 'scipy', 'matplotlib', 'json-cpp', 'cellworld', 'npyx', 'pandas', 'astropy', 'rtree', 'kilosort', 'torch'],
    extras_require={'kilosort': 'kilosort'},
    zip_safe=False,
)

setup(**_package_metadata)
|
|
File without changes
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
import cv2
|
|
2
|
+
import numpy as np
|
|
3
|
+
from json_cpp import JsonObject, JsonList
|
|
4
|
+
|
|
5
|
+
class Cameras(JsonList):
    """JSON-serialisable list whose elements are ``Camera`` objects."""

    def __init__(self):
        # ``Camera`` is resolved when an instance is built, so it may be
        # defined later in the module.
        element_type = Camera
        super().__init__(list_type=element_type)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Camera(JsonObject):
    """Video file description plus the pixel region of interest (``roi``).

    Constructing a ``Camera`` opens ``root`` with OpenCV and reads its frame
    rate, frame count and frame size into the corresponding attributes.

    Attributes:
        name: label for this camera.
        root: path handed to ``cv2.VideoCapture``.
        fps: frame rate reported by the container.
        frame_count: number of frames reported by the container.
        width, height: frame size in pixels.
        roi: ``(x, y, w, h)`` rectangle, the format returned by ``cv2.selectROI``.
    """

    def __init__(self, name=str(), root=str(), roi=(224, 351, 10, 14)):
        self.name = name
        self.root = root
        self.fps = float()
        self.frame_count = int()
        self.width = int()
        self.height = int()
        self.roi = roi
        self.get_capture_properties()

    def select_roi(self):
        """Interactively choose ``self.roi`` on the first frame of the video.

        Raises:
            IOError: if the first frame cannot be read from ``root``.
        """
        cap = cv2.VideoCapture(self.root)
        try:
            ret, frame = cap.read()
        finally:
            cap.release()
        if not ret:
            # without this check cv2.selectROI fails with an opaque error on a None frame
            raise IOError(f'could not read a frame from {self.root!r}')
        print("Please select the ROI by dragging a box.")
        self.roi = cv2.selectROI("Select ROI", frame, fromCenter=False, showCrosshair=True)
        cv2.destroyWindow("Select ROI")  # Close the ROI selection window

    def get_capture_properties(self):
        """Fill ``fps``, ``frame_count``, ``width`` and ``height`` from the video at ``root``."""
        cap = cv2.VideoCapture(self.root)
        try:
            self.fps = cap.get(cv2.CAP_PROP_FPS)
            self.frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # OpenCV >= 3 dropped the ``CV_`` prefix; ``cv2.CV_CAP_PROP_FRAME_*``
            # raised AttributeError here.
            self.width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            self.height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        finally:
            cap.release()
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def get_roi_intensity(filename, ROI=(224, 351, 10, 14)):
    """Return the mean channel-1 intensity inside ``ROI`` for every frame of a video.

    Args:
        filename: path passed to ``cv2.VideoCapture``.
        ROI: ``(x, y, w, h)`` rectangle in pixels (the ``cv2.selectROI`` convention).

    Returns:
        ``(values, fps)``: a list with one mean per decoded frame, and the frame
        rate reported by the container.
    """
    x, y, w, h = ROI
    cap = cv2.VideoCapture(filename)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        values = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # channel 1 of an OpenCV (BGR) frame is green. The column slice used
            # to be x:x+w+1, one column wider than the w-pixel ROI and
            # inconsistent with the row slice y:y+h.
            led = frame[y:y + h, x:x + w, 1]
            values.append(np.mean(led))
    finally:
        cap.release()
    return values, fps
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import matplotlib.pyplot as plt
|
|
3
|
+
import matplotlib.path as mpath
|
|
4
|
+
import matplotlib.patches as mpatches
|
|
5
|
+
import rtree
|
|
6
|
+
from cellworld import Display, World
|
|
7
|
+
|
|
8
|
+
def get_tiles(e, bins=np.linspace(0, 1, 100)):
    """Tile the world on a ``bins`` x ``bins`` grid and drop unusable tiles.

    Builds the canonical hexagonal world for ``e.occlusions``. It then removes
    grid points that lie outside the habitat polygon, or within 0.025 of a cell
    whose face colour has a zero first channel (the occlusions). The geometry
    comes from a temporary ``cellworld.Display``, whose figure is closed
    before returning.

    Args:
        e: experiment-like object; only ``e.occlusions`` is read.
        bins: 1-D coordinates used for both grid axes.

    Returns:
        ``(k, 2)`` array of the remaining ``(x, y)`` points.
    """
    w = World.get_from_parameters_names('hexagonal', 'canonical', e.occlusions)
    xv, yv = np.meshgrid(bins, bins, indexing='ij')
    points = np.concatenate((xv.reshape(1, -1), yv.reshape(1, -1))).T

    # Draw off-screen, then restore the caller's interactive mode instead of
    # unconditionally switching it on.
    was_interactive = plt.isinteractive()
    plt.ioff()
    try:
        d = Display(w, fig_size=(1, 1), padding=0, cell_edge_color="lightgrey")
    finally:
        if was_interactive:
            plt.ion()

    try:
        # points outside the habitat walls
        habitat = d.habitat_polygon
        wall_path = habitat.get_patch_transform().transform_path(habitat.get_path())
        excluded = [~wall_path.contains_points(points)]

        # points on or near occluded cells
        for poly in d.cell_polygons:
            if poly.get_facecolor()[0] == 0:
                cell_path = poly.get_patch_transform().transform_path(poly.get_path())
                excluded.append(cell_path.contains_points(points, radius=0.025))
    finally:
        # the figure used to be left open, leaking one figure per call
        plt.close(d.fig)

    index = np.any(np.vstack(excluded).T, axis=1)
    return points[~index, :]
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def plot_tiles(pts, sparse_arr, e):
    """Overlay tile points (green) and vertex points (magenta) on the world.

    The left panel shows the whole world. The right panel shows the same
    overlay zoomed to x, y in [0.25, 0.3].

    Args:
        pts: ``(N, 2)`` tile coordinates.
        sparse_arr: ``(M, 2)`` vertex coordinates.
        e: experiment-like object; only ``e.occlusions`` is read.

    Returns:
        ``[fig, ax]`` as created by ``plt.subplots(1, 2)``.
    """
    world = World.get_from_parameters_names('hexagonal', 'canonical', e.occlusions)
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    for panel in ax:
        Display(world, fig_size=(5, 5), padding=0, cell_edge_color="lightgrey", ax=panel)
        panel.scatter(pts[:, 0], pts[:, 1], 5, 'g', alpha=1)
        panel.scatter(sparse_arr[:, 0], sparse_arr[:, 1], 20, 'm')
    zoomed = ax[1]
    zoomed.set_xlim((.25, .3))
    zoomed.set_ylim((.25, .3))
    return [fig, ax]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def dist(p, q):
    """Return the Euclidean distance between the 2-D points ``p`` and ``q``."""
    # ``math`` is not imported at module level in this file, so the original
    # raised NameError on every call.
    import math
    return math.hypot(p[0] - q[0], p[1] - q[1])
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def sparse_subset(points, r):
    """Greedily keep points so that no two kept points are closer than ``r``.

    Points are visited in order. Each one is kept only if every already-kept
    point inside its ``2r`` x ``2r`` bounding box lies at distance ``>= r``.
    An R-tree over the kept points limits each check to nearby candidates.
    The result is maximal, though not necessarily maximum.

    Args:
        points: sequence of ``(x, y)`` pairs.
        r: minimum allowed separation.

    Returns:
        List of the kept elements of ``points``, in their original order.
    """
    kept = []
    tree = rtree.index.Index()
    for idx, point in enumerate(points):
        x, y = point
        neighbours = tree.intersection((x - r, y - r, x + r, y + r))
        if not all(dist(point, points[other]) >= r for other in neighbours):
            continue
        kept.append(point)
        tree.insert(idx, (x, y, x, y))
    return kept
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def get_vertices(e):
    """Return the vertices of all world cell polygons, merging near-duplicates.

    Args:
        e: experiment-like object; only ``e.occlusions`` is read.

    Returns:
        ``(k, 2)`` array of ``(x, y)`` vertices in which no two rows are closer
        than 0.01 (see ``sparse_subset``).
    """
    # ``Polygon_list`` was never imported in this module, so the original
    # raised NameError; it comes from cellworld, like World and Display.
    from cellworld import Polygon_list

    w = World.get_from_parameters_names('hexagonal', 'canonical', e.occlusions)
    all_polygons = Polygon_list.get_polygons(
        w.cells.get('location'),
        w.configuration.cell_shape.sides,
        w.implementation.cell_transformation.size / 2,
        w.implementation.space.transformation.rotation + w.implementation.cell_transformation.rotation)

    x = np.hstack([poly.vertices.get('x') for poly in all_polygons]).reshape(-1, 1)
    y = np.hstack([poly.vertices.get('y') for poly in all_polygons]).reshape(-1, 1)
    pts = np.concatenate((x, y), axis=1).tolist()

    # neighbouring cells share vertices; keep one representative per cluster
    sparse_pts = sparse_subset(pts, 0.01)
    return np.vstack(sparse_pts)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def get_world_mask(w, bins, wall_mask=True, occlusion_mask=True):
    """Flag grid-bin centres that lie outside the walls and/or near occlusions.

    Args:
        w: world object, or an occlusions name (passed on to the mask helpers).
        bins: 1-D bin edges used for both axes. The centres are the left edges
            plus half the mean edge spacing.
        wall_mask: include ``get_wall_mask`` in the result.
        occlusion_mask: include ``get_occlusion_mask`` in the result.

    Returns:
        Boolean array of length ``(len(bins) - 1) ** 2`` in ``ij`` meshgrid
        order; True where any enabled mask is True.
    """
    centres = bins[:-1] + np.mean(np.diff(bins)) / 2
    grid_x, grid_y = np.meshgrid(centres, centres, indexing='ij')
    points = np.concatenate((grid_x.reshape(1, -1), grid_y.reshape(1, -1))).T

    layers = [np.zeros([1, (len(bins) - 1) ** 2])]
    if wall_mask:
        layers.append(get_wall_mask(w, points)[np.newaxis, :])
    if occlusion_mask:
        layers.append(get_occlusion_mask(w, points)[np.newaxis, :])
    return np.any(np.concatenate(layers, axis=0), axis=0)
|
|
114
|
+
|
|
115
|
+
def get_occlusion_mask(w, locations, r=0.025):
    """Flag locations that fall on, or within ``r`` of, an occluded cell.

    Args:
        w: ``World`` object, or an occlusions-name string used to build the
            canonical hexagonal world.
        locations: ``(N, 2)`` array of ``(x, y)`` points.
        r: radius forwarded to ``Path.contains_points``.

    Returns:
        Boolean array of length ``N``; all False when no cell is occluded.
    """
    if isinstance(w, str):
        w = World.get_from_parameters_names('hexagonal', 'canonical', w)
    d = Display(w, fig_size=(1, 1), padding=0, cell_edge_color="lightgrey")
    plt.close(d.fig)

    inside = []
    for poly in d.cell_polygons:
        # occluded cells are the ones drawn with a zero first face-colour channel
        if poly.get_facecolor()[0] == 0:
            path = poly.get_patch_transform().transform_path(poly.get_path())
            inside.append(path.contains_points(locations, radius=r))
    if not inside:
        return np.zeros(len(locations), dtype=bool)
    return np.any(np.vstack(inside), axis=0)
|
|
133
|
+
|
|
134
|
+
def get_wall_mask(w, locations, r=0.025):
    """Flag locations that are not inside the habitat polygon.

    Args:
        w: ``World`` object, or an occlusions-name string used to build the
            canonical hexagonal world.
        locations: ``(N, 2)`` array of ``(x, y)`` points.
        r: radius forwarded to ``Path.contains_points``.

    Returns:
        Boolean array of length ``N``; True where the point is outside the walls.
    """
    if isinstance(w, str):
        w = World.get_from_parameters_names('hexagonal', 'canonical', w)
    d = Display(w, fig_size=(1, 1), padding=0, cell_edge_color="lightgrey")
    plt.close(d.fig)

    habitat = d.habitat_polygon
    path = habitat.get_patch_transform().transform_path(habitat.get_path())
    # a single test, so the former list / vstack / any round-trip was a no-op
    return ~path.contains_points(locations, radius=r)
|
|
147
|
+
|
|
@@ -0,0 +1,387 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pickle
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
import torch
|
|
5
|
+
from kilosort.io import BinaryFiltered, load_ops
|
|
6
|
+
from cellworld_npx.lfp import get_binary_file
|
|
7
|
+
from cellworld_npx.probe import cluster_probe_channels
|
|
8
|
+
from tqdm import tqdm
|
|
9
|
+
from sklearn.cluster import DBSCAN
|
|
10
|
+
from cellworld_npx.decoder import bin_recording, format_behavior_data
|
|
11
|
+
from replay_trajectory_classification import ClusterlessClassifier, ClusterlessDecoder, Environment, RandomWalk, Uniform, Identity
|
|
12
|
+
|
|
13
|
+
def group_channel_map(channel_map, n=4):
    """Split probe channels into groups of ``n`` and find each group's nearest channels.

    ``cluster_probe_channels`` labels every row of ``channel_map`` with a block
    id. Within each block, consecutive rows are grouped ``n`` at a time, and
    group ids keep increasing across blocks. The slicing assumes that each
    block's rows are contiguous in ``channel_map`` and that blocks appear in
    sorted label order. NOTE(review): confirm against
    ``probe.cluster_probe_channels``.

    Columns ``1:`` of ``channel_map`` are used as channel coordinates.

    Args:
        channel_map: 2-D array, one row per channel.
        n: channels per group.

    Returns:
        ``(group_channels, group_com)``. ``group_com[g]`` is the mean coordinate
        of group ``g``; ``group_channels[g]`` holds the row indices of the ``n``
        channels closest to it.
    """
    # TODO test overlapping groups
    blocks = cluster_probe_channels(channel_map)
    labels = np.zeros(blocks.shape)
    row_offset = 0
    first_group = 0
    for block_id in np.unique(blocks):
        block_size = (blocks == block_id).sum()
        starts = range(0, block_size, n)
        for k, start in enumerate(starts):
            row = row_offset + start
            labels[row:row + n] = first_group + k
        first_group += len(starts)
        row_offset += block_size

    # centre of mass of every group
    group_com = np.vstack([channel_map[labels == g, 1:].mean(0) for g in np.unique(labels)])

    # the n channels nearest to each centre
    coords = channel_map[:, 1:]
    nearest = []
    for centre in group_com:
        distance = np.sum((coords - centre) ** 2, axis=1) ** 0.5
        nearest.append(np.argsort(distance)[0:n])
    return np.vstack(nearest), group_com
|
|
40
|
+
|
|
41
|
+
def assign_spike_groups(spike_positions, group_com, show_progress=True):
    """Label every spike with the index of its nearest electrode-group centre.

    Args:
        spike_positions: ``(S, D)`` spike locations.
        group_com: ``(G, D)`` group centres (e.g. from ``group_channel_map``).
        show_progress: show a tqdm bar over the groups.

    Returns:
        Integer array of length ``S``: for each spike, the row of ``group_com``
        at the smallest Euclidean distance (NaN distances are ignored).
    """
    n_groups = group_com.shape[0]
    distances = np.zeros((n_groups, spike_positions.shape[0]))
    groups = tqdm(range(n_groups), desc='assigning spikes to groups', disable=not show_progress)
    for g in groups:
        offsets = spike_positions - group_com[g, :]
        distances[g, :] = np.sum(offsets ** 2, axis=1) ** 0.5
    return np.nanargmin(distances, axis=0)
|
|
49
|
+
|
|
50
|
+
def filter_spikes(R):
    """Flag spikes that fall inside a behavioural episode.

    A spike at time ``t`` is flagged when ``start_time < t < end_time`` for at
    least one episode in ``R.episodes``, and ``t`` precedes the last prey
    ``time_stamp`` returned by ``format_behavior_data``.

    Args:
        R: recording object providing ``get_spikes()`` and ``episodes``.

    Returns:
        ``(spike_times, spike_clusters, spike_index)``. ``spike_index`` is a
        float array of 0.0 / 1.0 flags aligned with ``spike_times``.
    """
    spike_times, spike_clusters, _ = R.get_spikes()
    d = format_behavior_data(R, agent='prey')
    # positional access works whether time_stamp is an array or a pandas Series
    last_time = np.asarray(d['time_stamp'])[-1]

    episode_times = np.vstack([R.episodes.get('start_time'), R.episodes.get('end_time')]).T
    in_episode = np.zeros(len(spike_times), dtype=bool)
    for start, end in episode_times:
        in_episode |= (spike_times > start) & (spike_times < end)

    # OR instead of a sum: overlapping episodes used to give a count of 2,
    # and callers testing ``== 1`` silently dropped those spikes
    spike_index = (in_episode & (spike_times < last_time)).astype(float)
    return spike_times, spike_clusters, spike_index
|
|
63
|
+
|
|
64
|
+
def extract_spike_template_features(spike_positions, spike_group, amplitudes, templates, spike_templates, group_channels, n=4,
                                    show_progress=False):
    """Peak of each spike's amplitude-scaled template on its group's channels.

    For spike ``i``, ``templates[spike_templates[i], :, channels]`` is taken with
    ``channels = group_channels[spike_group[i]]``. It is multiplied by
    ``amplitudes[i]`` and reduced with ``max`` over axis 1, the axis that came
    from the ``:`` slice. ``templates`` is presumably (n_templates, n_time,
    n_channels), as in Kilosort; TODO confirm.

    Args:
        spike_positions: array whose first dimension is the spike count.
        spike_group: group index per spike.
        amplitudes: scale factor per spike.
        templates: 3-D template array.
        spike_templates: template index per spike.
        group_channels: ``(G, n)`` channel indices per group.
        n: channels per group (width of the output).
        show_progress: show a tqdm bar.

    Returns:
        ``(S, n)`` float array.
    """
    n_spikes = spike_positions.shape[0]
    spike_features = np.zeros((n_spikes, n))
    spikes = tqdm(range(n_spikes), desc='extracting spike template features', disable=not show_progress)
    for k in spikes:
        channels = group_channels[spike_group[k], :]
        scaled = amplitudes[k] * templates[spike_templates[k], :, channels]
        spike_features[k, :] = np.max(scaled, axis=1)
    return spike_features
|
|
71
|
+
|
|
72
|
+
def extract_spike_amplitude_features(spike_positions, spike_group, spike_amps, group_channels, n=4, show_progress=False):
    """Gather each spike's amplitudes on the channels of its assigned group.

    Row ``i`` of the result is ``spike_amps[i, group_channels[spike_group[i]]]``.

    Args:
        spike_positions: array whose first dimension is the spike count (only
            its length is used).
        spike_group: group index per spike.
        spike_amps: ``(S, C)`` per-spike, per-channel amplitudes.
        group_channels: ``(G, n)`` channel indices per group.
        n: channels per group (width of the result).
        show_progress: accepted for backward compatibility. The gather is now a
            single vectorised operation, so there is no per-spike loop to
            report progress on.

    Returns:
        ``(S, n)`` float64 array.
    """
    n_spikes = spike_positions.shape[0]
    spike_features = np.zeros((n_spikes, n))
    if n_spikes == 0:
        return spike_features
    rows = np.arange(n_spikes)[:, np.newaxis]
    cols = group_channels[np.asarray(spike_group)[:n_spikes]]
    # One fancy-indexing gather replaces the per-spike Python loop. Assigning
    # into the preallocated array keeps the old float64 dtype and shape check.
    spike_features[:] = spike_amps[rows, cols]
    return spike_features
|
|
77
|
+
|
|
78
|
+
def get_bfile(R, return_ops=False, hp_filter=True, whiten=True, dshift=True, device=None):
    """Open the recording's first continuous file as a Kilosort ``BinaryFiltered``.

    Preprocessing arrays come from ``<R.spike_path>/ops.npy``: ``ops['fwav']``
    (high-pass filter), ``ops['Wrot']`` (whitening matrix) and ``ops['dshift']``
    (drift shifts). Each one is applied only when its flag is True.

    Args:
        R: recording object with ``spike_path`` and ``get_probes_continuous_paths()``.
        return_ops: also return the loaded ``ops`` dict.
        hp_filter: apply the high-pass filter.
        whiten: apply whitening.
        dshift: apply drift correction.
        device: torch device (or device string). Defaults to CUDA when it is
            available and CPU otherwise; CUDA used to be hard-coded, which
            failed on CPU-only hosts.

    Returns:
        ``bfile``, or ``(bfile, ops)`` when ``return_ops`` is True.
    """
    results_dir = Path(R.spike_path)
    filename = Path(R.get_probes_continuous_paths()[0])
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(device)
    ops = load_ops(results_dir / 'ops.npy', device=device)

    bfile = BinaryFiltered(filename, n_chan_bin=ops['n_chan_bin'], chan_map=ops['chanMap'], device=device,
                           hp_filter=ops['fwav'] if hp_filter else None,
                           whiten_mat=ops['Wrot'] if whiten else None,
                           dshift=ops['dshift'] if dshift else None)
    return (bfile, ops) if return_ops else bfile
|
|
104
|
+
|
|
105
|
+
def get_spike_amplitudes(R, show_progress=False):
    """Return per-spike amplitudes on every channel, computing and caching them once.

    The cache file is ``<R.spike_path>/spike_amplitudes.npy``. When it is
    missing, each spike in ``spike_times.npy`` is read from the binary file
    returned by ``get_binary_file(..., hp_filter=True, whiten=True,
    dshift=True)``. The sample at the spike time is stored for all channels,
    and the result is saved to the cache.

    Args:
        R: recording object with ``spike_path`` and ``get_binary_files()``.
        show_progress: show a progress bar while extracting, or print a
            message when loading from the cache.

    Returns:
        ``(n_spikes, ops['n_chan_bin'])`` float array.
    """
    results_dir = Path(R.spike_path)
    fn = results_dir / 'spike_amplitudes.npy'
    if fn.exists():
        if show_progress:
            print(f'loading spike amplitudes from {fn}')
        return np.load(fn)

    bfile = get_binary_file(R.get_binary_files()[0], hp_filter=True, whiten=True, dshift=True)
    spike_times = np.load(results_dir / 'spike_times.npy')
    ops = load_ops(results_dir / 'ops.npy')
    spike_amps = np.zeros((len(spike_times), ops['n_chan_bin']))

    # ``disable`` belongs to tqdm; it used to be passed to enumerate (TypeError)
    progress = tqdm(spike_times, desc='extracting spike amplitudes', disable=not show_progress)
    for i, t in enumerate(progress):
        # Cast to a Python int: spike_times.npy may be unsigned, and unsigned
        # subtraction wraps around instead of going negative near the start.
        t = int(t)
        tmin = t - bfile.nt0min
        tmax = t + (bfile.nt - bfile.nt0min) + 1
        if tmin < 0:
            tmin = 0; tmax = bfile.nt + 1
        if tmax > bfile.n_samples:
            tmax = bfile.n_samples; tmin = tmax - bfile.nt - 1
        window = bfile[tmin:tmax].cpu().numpy()
        # Column of the spike sample inside the window. This equals nt0min
        # unless the window was clipped at either end of the recording.
        spike_amps[i, :] = window[:, t - tmin].astype('float32')
    np.save(fn, spike_amps)
    return spike_amps
|
|
130
|
+
|
|
131
|
+
def get_multiunits(spike_times, spike_group, spike_features, bins, show_progress=False, check_duplicate_spikes=False):
    """Build a (time bin, feature, group) array of spike marks for clusterless decoding.

    Spikes are binned with ``np.digitize(..., right=True)``, so output row ``k``
    covers ``bins[k] < t <= bins[k + 1]``. Each occupied (bin, group) cell holds
    the features of a spike in that cell; when several spikes share a cell, the
    last one written wins. Empty cells are NaN. Spikes at or before ``bins[0]``
    or after ``bins[-1]`` land in a scratch row that is dropped.

    Args:
        spike_times: spike time per spike.
        spike_group: group label per spike.
        spike_features: ``(S, F)`` features per spike.
        bins: time bin edges.
        show_progress: show a tqdm bar.
        check_duplicate_spikes: also return per-(bin, group) spike counts.

    Returns:
        ``multiunits`` of shape ``(len(bins) - 1, F, n_groups)``. When
        ``check_duplicate_spikes`` is True, also ``spike_counts`` of shape
        ``(len(bins), n_groups)``, whose last row counts out-of-range spikes.
    """
    spike_bin = np.digitize(spike_times, bins=bins, right=True)
    ugroups = np.unique(spike_group)
    # Column of each spike's group in the sorted, unique ``ugroups``, computed
    # once rather than with an argwhere scan for every spike.
    group_col = np.searchsorted(ugroups, spike_group)
    multiunits = np.full((len(bins), spike_features.shape[1], len(ugroups)), np.nan)
    if check_duplicate_spikes:
        spike_counts = np.zeros((len(bins), len(ugroups)))
    for i, s in tqdm(enumerate(spike_bin), total=len(spike_bin), desc='adding spike features to multi-unit array', disable=not show_progress):
        # s == 0 maps to row -1 and s == len(bins) maps to row len(bins) - 1;
        # both are the scratch row removed below
        multiunits[s - 1, :, group_col[i]] = spike_features[i, :]
        if check_duplicate_spikes:
            spike_counts[s - 1, group_col[i]] += 1
    multiunits = multiunits[:-1, :, :]
    if check_duplicate_spikes:
        return multiunits, spike_counts
    return multiunits
|
|
147
|
+
|
|
148
|
+
def get_cv_folds(n_samples, n_folds=10, return_boolean=True):
    """Split ``n_samples`` into ``n_folds`` contiguous test blocks.

    Each fold tests on a block of ``ceil(n_samples / n_folds)`` consecutive
    samples (trailing blocks may be shorter or empty) and trains on every
    other sample.

    Args:
        n_samples: number of samples.
        n_folds: number of folds.
        return_boolean: return boolean masks instead of index arrays.

    Returns:
        List of ``[train, test]`` pairs. These are boolean masks of length
        ``n_samples`` when ``return_boolean`` is True, otherwise the
        ``np.argwhere`` index arrays of shape ``(k, 1)``.
    """
    block = int(np.ceil(n_samples / n_folds))
    folds = []
    for fold in range(n_folds):
        start = fold * block
        stop = min((fold + 1) * block, n_samples)
        test_mask = np.zeros(n_samples, dtype=bool)
        test_mask[start:stop] = True
        train_mask = ~test_mask
        if return_boolean:
            folds.append([train_mask, test_mask])
        else:
            folds.append([np.argwhere(train_mask), np.argwhere(test_mask)])
    return folds
|
|
164
|
+
|
|
165
|
+
def preprocess_data(R, ops, verbose=False):
    """Assemble the clusterless decoding inputs for recording ``R``.

    Steps:
    1. Format the behaviour data.
    2. Assign each spike to its nearest electrode group.
    3. Load or compute per-spike channel amplitudes.
    4. Keep in-episode spikes from good units.
    5. Gather per-group amplitude marks and bin them at ``ops['dt']``.
    6. Drop bins where the agent was not tracked.

    Args:
        R: recording object (``spike_path``, ``get_probe_channel_map()``,
            ``population``, ...).
        ops: options dict, as built by ``set_ops``.
        verbose: show progress bars.

    Returns:
        ``(data, binned_data)``. ``data`` holds:
        - ``'mua'``: bins x features x groups;
        - ``'time'``;
        - ``'position'`` and ``'velocity'``: the prey's, scaled by
          ``ops['canonical_to_cm']``.
        ``binned_data`` is the tracked subset of ``bin_recording``'s output.
    """
    print('PREPROCESSING DATA')
    # format behavior
    d = format_behavior_data(R, agent=ops['agents'])

    # get spike groups
    results_dir = Path(R.spike_path)
    spike_positions = np.load(results_dir / 'spike_positions.npy')
    channel_map = R.get_probe_channel_map()[0]
    group_channels, group_com = group_channel_map(channel_map, n=ops['n'])
    spike_group = assign_spike_groups(spike_positions, group_com, show_progress=verbose)

    # get spike amplitudes (takes a while first time)
    spike_amps = get_spike_amplitudes(R, show_progress=verbose)

    # remove out-of-episode spikes
    spike_times, spike_clusters, spike_index = filter_spikes(R)

    # remove spikes from noise clusters
    good_units = np.argwhere(R.population.get('good_unit'))
    print(f'including spikes from {len(good_units)} single/multi-units with good waveforms')
    # one vectorised membership test instead of a full pass over the spikes per unit
    good_spikes = np.isin(spike_clusters, good_units.ravel())
    keep = (spike_index == 1) & good_spikes

    # filter spikes
    spike_positions = spike_positions[keep]
    spike_times = spike_times[keep]
    spike_group = spike_group[keep]
    spike_amps = spike_amps[keep]

    # extract spike features; n must match group_channels' width (it used to
    # fall back to the default of 4, breaking any ops['n'] != 4)
    spike_features = extract_spike_amplitude_features(spike_positions, spike_group, spike_amps, group_channels,
                                                      n=ops['n'], show_progress=verbose)

    # create multiunit array
    bins = np.arange(np.nanmin(d['time_stamp']) - (ops['dt']/2), np.nanmax(d['time_stamp']) + (ops['dt']/2), ops['dt'])
    multiunits = get_multiunits(spike_times=spike_times, spike_group=spike_group, spike_features=spike_features,
                                bins=bins, show_progress=verbose)

    # bin the recording
    binned_data = bin_recording(R, agent=ops['agents'], dt=ops['dt'], skip_spikes=True,
                                kalman_filter=ops['kalman_filter'], show_progress=verbose)

    # remove data where agents were not tracked (first column whose name contains 'tracked')
    column = [c for c in binned_data.columns if 'tracked' in c][0]
    tracked = binned_data[column] == 1
    mua = multiunits[tracked, :, :]
    binned_data = binned_data[tracked]

    data = {
        'mua': mua,
        'time': np.array(binned_data['time_stamp']),
        'position': np.vstack(binned_data['prey_location']) * ops['canonical_to_cm'],
        'velocity': np.vstack(binned_data['prey_velocity']) * ops['canonical_to_cm']
    }
    return data, binned_data
|
|
222
|
+
|
|
223
|
+
def set_ops(dt=0.002, n=4, agents='prey', canonical_to_cm=234, velocity_cutoff=0.1, cv_folds=5, cv_type='fold', kalman_filter=True, bin_size=5,
            random_walk_var=6, mark_var=24, position_var=6, drop_causal_posterior=True, states=('continuous',)):
    """Build the options dict used by ``preprocess_data``, ``run_decoder`` and ``run_classifier``.

    Multi-state selections are reordered into the order that the transition
    matrices in ``build_continuous_transitions`` expect.

    Args:
        states: iterable of state names, or a single state-name string.
        (other arguments are stored under the key of the same name)

    Returns:
        dict of options. ``ops['states']`` is always a new list.
    """
    # Tuple default plus a copy: the old list default was returned inside
    # ``ops``, so mutating ops['states'] changed the default for later calls.
    states = [states] if isinstance(states, str) else list(states)

    # ensure correct ordering for classifier
    if 'continuous' in states and 'fragmented' in states:
        if 'stationary' in states:
            states = ['continuous', 'fragmented', 'stationary']
        else:
            states = ['continuous', 'fragmented']

    # options
    ops = {
        'dt': dt,  # decoding time-bin width, in time-stamp units (0.002 = 2 ms if seconds)
        'n': n,  # channels per electrode group (spike features)
        'agents': agents,  # agents to include
        'canonical_to_cm': canonical_to_cm,  # convert canonical units to cm
        'velocity_cutoff': velocity_cutoff,  # velocity cutoff for training (cm/s)
        'cv_folds': cv_folds,  # cross validation folds
        'cv_type': cv_type,  # cross validation type ("fold" for fold split, "speed" for speed split)
        'kalman_filter': kalman_filter,  # kalman smooth raw data
        'bin_size': bin_size,  # bin size for rate maps (cm)
        'random_walk_var': random_walk_var,  # variance of movement of transition model (cm)
        'mark_var': mark_var,  # variance of encoding model mark space (~uV)
        'position_var': position_var,  # variance of encoding model position (cm)
        'drop_causal_posterior': drop_causal_posterior,
        'states': states
    }
    return ops
|
|
250
|
+
|
|
251
|
+
def build_continuous_transitions(ops):
    """Return the continuous-state transition model(s) for ``ops['states']``.

    - continuous, fragmented and stationary all present: a 3x3 list of lists.
    - continuous and fragmented without stationary: a 2x2 list of lists.
    - exactly one state: a single transition object for that state
      (``KeyError`` if the name is unknown).

    Raises:
        AssertionError: for any other combination of states.
    """
    states = ops['states']
    walk_var = ops['random_walk_var']
    has_continuous = 'continuous' in states
    has_fragmented = 'fragmented' in states
    has_stationary = 'stationary' in states

    if has_continuous and has_fragmented and has_stationary:
        return [
            [RandomWalk(movement_var=walk_var), Uniform(), Identity()],
            [Uniform(), Uniform(), Uniform()],
            [RandomWalk(movement_var=walk_var), Uniform(), Identity()],
        ]
    if has_continuous and has_fragmented:
        return [
            [RandomWalk(movement_var=walk_var), Uniform()],
            [Uniform(), Uniform()]
        ]
    if len(states) == 1:
        factories = {
            'continuous': lambda: RandomWalk(movement_var=walk_var),
            'fragmented': Uniform,
            'stationary': Identity,
        }
        return factories[states[0]]()
    raise AssertionError("ops['states'] must be ['continuous', 'fragmented', 'stationary'], ['continuous', 'fragmented'] or ['continuous'], ['fragmented'], or ['stationary']")
|
|
270
|
+
|
|
271
|
+
def run_decoder(data, ops):
    """Cross-validated single-state clusterless position decoding.

    Cross-validation depends on ``ops['cv_type']``:
    - ``'speed'``: a single fold that trains on the first half of the bins
      with velocity above ``ops['velocity_cutoff']`` and tests on the second
      half.
    - otherwise: ``ops['cv_folds']`` contiguous folds from ``get_cv_folds``.

    Args:
        data: dict from ``preprocess_data`` (``'mua'``, ``'time'``,
            ``'position'``, ``'velocity'``).
        ops: options from ``set_ops``; ``ops['states']`` must hold exactly one state.

    Returns:
        dict with:
        - ``'dist_error'``: per-test-bin Euclidean error of the MAP estimate
          against the true position of the same bin, with test bins
          concatenated in fold order;
        - ``'map_estimate'``;
        - ``'cv_results'``: the fitted decoders and raw results per fold.

    Raises:
        AssertionError: if ``ops['states']`` does not hold exactly one state.
    """
    print('TRAIN/TEST DECODER')
    # explicit raise so the check survives ``python -O``
    if len(ops['states']) != 1:
        raise AssertionError(f'For standard decoding there must be one state, not {ops["states"]}, try run_classifier instead...')

    # cross validation
    if ops['cv_type'] == 'speed':
        moving = np.flatnonzero(data['velocity'].squeeze() > ops['velocity_cutoff'])
        half = int(len(moving) / 2)
        # One [train, test] pair. The old code iterated over [train, test]
        # itself, treating single indices as folds.
        cv_runs = [[moving[:half], moving[half:]]]
    else:
        cv_runs = get_cv_folds(data['position'].shape[0], ops['cv_folds'])

    # model setup
    lims = (ops['canonical_to_cm']*0.05, ops['canonical_to_cm']*1.05)
    environment = Environment(place_bin_size=ops['bin_size'], position_range=[lims, lims])
    transition_type = build_continuous_transitions(ops)
    clusterless_algorithm = 'multiunit_likelihood_gpu'
    clusterless_algorithm_params = {
        'mark_std': ops['mark_var'],
        'position_std': ops['position_var']
    }

    # cv loop
    decoders, results, test_index = [], [], []
    for train, test in tqdm(cv_runs, desc='cross-validation fold'):
        decoder = ClusterlessDecoder(
            environment=environment,
            transition_type=transition_type,
            clusterless_algorithm=clusterless_algorithm,
            clusterless_algorithm_params=clusterless_algorithm_params)
        decoder.fit(data['position'][train, :], data['mua'][train, :, :])
        decoders.append(decoder)
        results.append(decoder.predict(data['mua'][test, :, :], time=data['time'][test], use_gpu=True))
        # integer positions of this fold's test bins (folds may be masks or indices)
        test = np.asarray(test)
        test_index.append(np.flatnonzero(test) if test.dtype == bool else test.ravel())

    # compile across runs
    map_estimate = []
    for r in results:
        post = r.acausal_posterior.stack(position=['x_position', 'y_position'])
        best = post.position[post.argmax('position')]
        map_estimate.append(np.asarray(best.values.tolist()))
    map_estimate = np.vstack(map_estimate)
    # Compare with the bins that were actually decoded. For contiguous folds
    # this is every bin in order, exactly as before.
    true_position = data['position'][np.concatenate(test_index)]
    error = np.linalg.norm(map_estimate - true_position, axis=1)
    return {
        'dist_error': error,
        'map_estimate': map_estimate,
        'cv_results': {'decoders': decoders, 'results': results}
    }
|
|
326
|
+
|
|
327
|
+
def run_classifier(data, ops, drop_causal_posterior=True):
    """Cross-validated multi-state clusterless classification of position.

    Cross-validation follows ``run_decoder``: a single speed-split fold when
    ``ops['cv_type'] == 'speed'``, otherwise ``ops['cv_folds']`` contiguous
    folds from ``get_cv_folds``. The MAP position is taken from the acausal
    posterior summed over states.

    Args:
        data: dict from ``preprocess_data``.
        ops: options from ``set_ops``.
        drop_causal_posterior: remove the ``causal_posterior`` variable from
            each fold's result to save memory.

    Returns:
        dict with:
        - ``'dist_error'``: per-test-bin Euclidean error against the true
          position of the same bin, with test bins concatenated in fold order;
        - ``'map_estimate'``;
        - ``'cv_results'``: the fitted classifiers and results per fold.
    """
    print('TRAIN/TEST CLASSIFIER')
    # cross validation
    if ops['cv_type'] == 'speed':
        moving = np.flatnonzero(data['velocity'].squeeze() > ops['velocity_cutoff'])
        half = int(len(moving) / 2)
        # one [train, test] pair (previously train and test were each iterated as a "fold")
        cv_runs = [[moving[:half], moving[half:]]]
    else:
        cv_runs = get_cv_folds(data['position'].shape[0], ops['cv_folds'])

    # model setup
    environment = Environment(place_bin_size=ops['bin_size'],
                              position_range=[(0, ops['canonical_to_cm']), (0, ops['canonical_to_cm'])])
    continuous_transition_types = build_continuous_transitions(ops)
    if len(ops['states']) == 1:
        continuous_transition_types = [[continuous_transition_types]]
    clusterless_algorithm = 'multiunit_likelihood_gpu'
    clusterless_algorithm_params = {
        'mark_std': ops['mark_var'],
        'position_std': ops['position_var']
    }

    classifiers, results, test_index = [], [], []
    for train, test in tqdm(cv_runs, desc='cross-validation fold'):
        classifier = ClusterlessClassifier(
            environments=environment,
            continuous_transition_types=continuous_transition_types,
            clusterless_algorithm=clusterless_algorithm,
            clusterless_algorithm_params=clusterless_algorithm_params)
        classifier.fit(data['position'][train, :], data['mua'][train, :, :])
        classifiers.append(classifier)

        result = classifier.predict(data['mua'][test, :, :], time=data['time'][test], use_gpu=True)
        result['state'] = ops['states']
        if drop_causal_posterior:
            # Dataset.drop/drop_vars return a new object; the old call
            # discarded it, so the causal posterior was never actually dropped
            result = result.drop_vars('causal_posterior')
        results.append(result)
        test = np.asarray(test)
        test_index.append(np.flatnonzero(test) if test.dtype == bool else test.ravel())

    map_estimate = []
    for r in results:
        post = r.acausal_posterior.sum('state').stack(position=['x_position', 'y_position'])
        best = post.position[post.argmax('position')]
        map_estimate.append(np.asarray(best.values.tolist()))
    map_estimate = np.vstack(map_estimate)
    true_position = data['position'][np.concatenate(test_index)]
    error = np.linalg.norm(map_estimate - true_position, axis=1)
    return {
        'dist_error': error,
        'map_estimate': map_estimate,
        'cv_results': {'classifiers': classifiers, 'results': results}
    }
|
|
383
|
+
|
|
384
|
+
def save_results(fn, result_list: list):
    """Write each element of ``result_list`` to ``fn`` as consecutive pickles.

    Read the elements back by calling ``pickle.load`` repeatedly on one open file.
    """
    with open(fn, 'wb') as handle:
        for item in result_list:
            pickle.dump(item, handle)
|