cellworld-npx 0.0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29) hide show
  1. cellworld_npx-0.0.2/LICENSE.txt +19 -0
  2. cellworld_npx-0.0.2/PKG-INFO +21 -0
  3. cellworld_npx-0.0.2/README.md +0 -0
  4. cellworld_npx-0.0.2/pyproject.toml +3 -0
  5. cellworld_npx-0.0.2/setup.cfg +7 -0
  6. cellworld_npx-0.0.2/setup.py +16 -0
  7. cellworld_npx-0.0.2/src/cellworld_npx/__init__.py +0 -0
  8. cellworld_npx-0.0.2/src/cellworld_npx/camera.py +49 -0
  9. cellworld_npx-0.0.2/src/cellworld_npx/celltile.py +147 -0
  10. cellworld_npx-0.0.2/src/cellworld_npx/classifier.py +387 -0
  11. cellworld_npx-0.0.2/src/cellworld_npx/cluster_metrics.py +734 -0
  12. cellworld_npx-0.0.2/src/cellworld_npx/coverage.py +120 -0
  13. cellworld_npx-0.0.2/src/cellworld_npx/decoder.py +160 -0
  14. cellworld_npx-0.0.2/src/cellworld_npx/episode.py +34 -0
  15. cellworld_npx-0.0.2/src/cellworld_npx/io.py +224 -0
  16. cellworld_npx-0.0.2/src/cellworld_npx/kalman.py +101 -0
  17. cellworld_npx-0.0.2/src/cellworld_npx/lfp.py +291 -0
  18. cellworld_npx-0.0.2/src/cellworld_npx/map.py +110 -0
  19. cellworld_npx-0.0.2/src/cellworld_npx/probe.py +312 -0
  20. cellworld_npx-0.0.2/src/cellworld_npx/recording.py +836 -0
  21. cellworld_npx-0.0.2/src/cellworld_npx/state_decoder.py +386 -0
  22. cellworld_npx-0.0.2/src/cellworld_npx/sync.py +389 -0
  23. cellworld_npx-0.0.2/src/cellworld_npx/utils.py +1163 -0
  24. cellworld_npx-0.0.2/src/cellworld_npx.egg-info/PKG-INFO +21 -0
  25. cellworld_npx-0.0.2/src/cellworld_npx.egg-info/SOURCES.txt +28 -0
  26. cellworld_npx-0.0.2/src/cellworld_npx.egg-info/dependency_links.txt +1 -0
  27. cellworld_npx-0.0.2/src/cellworld_npx.egg-info/not-zip-safe +1 -0
  28. cellworld_npx-0.0.2/src/cellworld_npx.egg-info/requires.txt +14 -0
  29. cellworld_npx-0.0.2/src/cellworld_npx.egg-info/top_level.txt +1 -0
@@ -0,0 +1,19 @@
1
+ Copyright (c) 2018 The Python Packaging Authority
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ of this software and associated documentation files (the "Software"), to deal
5
+ in the Software without restriction, including without limitation the rights
6
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ copies of the Software, and to permit persons to whom the Software is
8
+ furnished to do so, subject to the following conditions:
9
+
10
+ The above copyright notice and this permission notice shall be included in all
11
+ copies or substantial portions of the Software.
12
+
13
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19
+ SOFTWARE.
@@ -0,0 +1,21 @@
1
+ Metadata-Version: 2.1
2
+ Name: cellworld_npx
3
+ Version: 0.0.2
4
+ Summary: Dombeck/MacIver labs neuroethology analysis package - combined cellworld behavior and neuropixels recordings
5
+ Author: Chris Angeloni
6
+ Author-email: chris.angeloni@gmail.com
7
+ License: MIT
8
+ License-File: LICENSE.txt
9
+ Requires-Dist: numpy
10
+ Requires-Dist: scipy
11
+ Requires-Dist: matplotlib
12
+ Requires-Dist: json-cpp
13
+ Requires-Dist: cellworld
14
+ Requires-Dist: npyx
15
+ Requires-Dist: pandas
16
+ Requires-Dist: astropy
17
+ Requires-Dist: rtree
18
+ Requires-Dist: kilosort
19
+ Requires-Dist: torch
20
+ Provides-Extra: kilosort
21
+ Requires-Dist: kilosort; extra == "kilosort"
Binary file
@@ -0,0 +1,3 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
@@ -0,0 +1,7 @@
1
+ [metadata]
2
+ license_files = LICENSE.txt
3
+
4
+ [egg_info]
5
+ tag_build =
6
+ tag_date = 0
7
+
@@ -0,0 +1,16 @@
1
"""Packaging script for the cellworld_npx analysis package."""
from setuptools import setup, find_packages

# Runtime dependencies. kilosort is listed here as a hard requirement and is
# also the only member of the "kilosort" extra below.
REQUIREMENTS = [
    'numpy',
    'scipy',
    'matplotlib',
    'json-cpp',
    'cellworld',
    'npyx',
    'pandas',
    'astropy',
    'rtree',
    'kilosort',
    'torch',
]

setup(
    name='cellworld_npx',
    version='0.0.2',
    description='Dombeck/MacIver labs neuroethology analysis package - combined cellworld behavior and neuropixels recordings',
    author='Chris Angeloni',
    author_email='chris.angeloni@gmail.com',
    license='MIT',
    # src-layout: packages live under ./src
    packages=find_packages(where="src"),
    package_dir={"": "src"},
    install_requires=REQUIREMENTS,
    extras_require={'kilosort': 'kilosort'},
    zip_safe=False,
)
File without changes
@@ -0,0 +1,49 @@
1
+ import cv2
2
+ import numpy as np
3
+ from json_cpp import JsonObject, JsonList
4
+
5
class Cameras(JsonList):
    """JSON-serialisable list whose elements are ``Camera`` objects."""

    def __init__(self):
        # Register Camera as the element type used by the json_cpp list.
        super(Cameras, self).__init__(list_type=Camera)
8
+
9
+
10
class Camera(JsonObject):
    """Video-file metadata plus a region of interest (ROI) for one camera.

    Attributes:
        name: Free-form camera label.
        root: Path of the video file opened with ``cv2.VideoCapture``.
        fps, frame_count, width, height: Capture properties, filled in by
            ``get_capture_properties`` during construction.
        roi: ``(x, y, width, height)`` rectangle in pixels, the format returned
            by ``cv2.selectROI``.
    """

    def __init__(self, name=str(), root=str(), roi=(224, 351, 10, 14)):
        self.name = name
        self.root = root
        self.fps = float()
        self.frame_count = int()
        self.width = int()
        self.height = int()
        self.roi = roi
        self.get_capture_properties()

    def select_roi(self):
        """Let the user drag a box on the first video frame and store it in ``self.roi``.

        Raises:
            OSError: If no frame can be read from ``self.root``.
        """
        cap = cv2.VideoCapture(self.root)
        try:
            ret, frame = cap.read()
        finally:
            cap.release()
        if not ret:
            # Previously a None frame was passed straight to cv2.selectROI.
            raise OSError(f'could not read a frame from {self.root}')
        print("Please select the ROI by dragging a box.")
        self.roi = cv2.selectROI("Select ROI", frame, fromCenter=False, showCrosshair=True)
        cv2.destroyWindow("Select ROI")  # Close the ROI selection window

    def get_capture_properties(self):
        """Read fps, frame count and frame size of the video at ``self.root``."""
        cap = cv2.VideoCapture(self.root)
        try:
            self.fps = cap.get(cv2.CAP_PROP_FPS)
            self.frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # The legacy CV_CAP_PROP_* names do not exist in the cv2 module;
            # the previous code raised AttributeError here on every construction.
            self.width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            self.height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        finally:
            cap.release()
36
+
37
+
38
def get_roi_intensity(filename, ROI=(224, 351, 10, 14)):
    """Return the mean channel-1 intensity inside ``ROI`` for every frame of a video.

    Args:
        filename: Path of the video file to read.
        ROI: ``(x, y, width, height)`` rectangle in pixels (the format returned
            by ``cv2.selectROI``).

    Returns:
        ``(values, fps)``: a list with one mean intensity per frame read, and
        the frame rate reported by ``cv2.CAP_PROP_FPS``.
    """
    x, y, width, height = ROI[:4]
    cap = cv2.VideoCapture(filename)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        values = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Half-open [start, start + size) on both axes; the column bound
            # used to be ``x + width + 1``, taking one extra column.
            # Channel index 1 is green in OpenCV's BGR frames.
            led = frame[y:y + height, x:x + width, 1]
            values.append(np.mean(led))
    finally:
        cap.release()
    return values, fps
@@ -0,0 +1,147 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ import matplotlib.path as mpath
4
+ import matplotlib.patches as mpatches
5
+ import rtree
6
+ from cellworld import Display, World
7
+
8
def get_tiles(e, bins=np.linspace(0,1,100)):
    """Return grid points tiled over the world, minus points outside the arena or near occlusions.

    A ``len(bins) x len(bins)`` grid is built from ``bins``. A point is removed when it
    is not inside the habitat polygon, or when it lies within radius 0.025 of
    an occluded cell polygon. A throw-away cellworld ``Display`` is created to
    obtain those polygons; it is built with interactive mode off and its
    figure is closed afterwards.

    Args:
        e: Experiment-like object; only ``e.occlusions`` is read.
        bins: 1-D coordinates used along both axes of the grid.

    Returns:
        ``(k, 2)`` array of the retained ``(x, y)`` points.
    """
    w = World.get_from_parameters_names('hexagonal', 'canonical', e.occlusions)
    xv, yv = np.meshgrid(bins, bins, indexing='ij')
    points = np.concatenate((xv.reshape(1, -1), yv.reshape(1, -1))).T

    # Build the Display off-screen. Restore the caller's interactive mode
    # afterwards; the old code unconditionally switched it on.
    was_interactive = plt.isinteractive()
    plt.ioff()
    try:
        d = Display(w, fig_size=(1, 1), padding=0, cell_edge_color="lightgrey")
    finally:
        if was_interactive:
            plt.ion()

    try:
        # points outside the arena walls
        habitat = d.habitat_polygon
        outline = habitat.get_patch_transform().transform_path(habitat.get_path())
        inside = [~outline.contains_points(points)]

        # points near occlusions (cells whose face colour has a zero first component)
        for poly in d.cell_polygons:
            if poly._facecolor[0] == 0:
                cell = poly.get_patch_transform().transform_path(poly.get_path())
                inside.append(cell.contains_points(points, radius=0.025))
    finally:
        # The figure was previously left open in pyplot's figure manager.
        plt.close(d.fig)

    index = np.any(np.vstack(inside).T, axis=1)
    return points[~index, :]
40
+
41
+
42
def plot_tiles(pts, sparse_arr, e):
    """Draw tile points (green) and sparse vertices (magenta) over the world twice.

    The left panel shows the whole world; the right panel shows the same data
    zoomed to x, y in [0.25, 0.3].

    Args:
        pts: ``(N, 2)`` array of tile points.
        sparse_arr: ``(M, 2)`` array of vertices.
        e: Experiment-like object; only ``e.occlusions`` is read.

    Returns:
        ``[fig, ax]`` from ``plt.subplots(1, 2)``.
    """
    world = World.get_from_parameters_names('hexagonal', 'canonical', e.occlusions)
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    for panel in ax:
        Display(world, fig_size=(5, 5), padding=0, cell_edge_color="lightgrey", ax=panel)
        panel.scatter(pts[:, 0], pts[:, 1], 5, 'g', alpha=1)
        panel.scatter(sparse_arr[:, 0], sparse_arr[:, 1], 20, 'm')
    # zoom the right-hand panel
    ax[1].set_xlim((.25, .3))
    ax[1].set_ylim((.25, .3))
    return [fig, ax]
58
+
59
+
60
def dist(p,q):
    """Return the Euclidean distance between 2-D points ``p`` and ``q``.

    Only the first two coordinates of each point are used.
    """
    # celltile.py never imported ``math`` at module level, so this function
    # used to raise NameError on every call.
    import math

    return math.hypot(p[0] - q[0], p[1] - q[1])
63
+
64
+
65
def sparse_subset(points,r):
    """Greedily thin ``points`` so that no two kept points are closer than ``r``.

    Points are visited in order. An R-tree holds the points kept so far, and
    a candidate is kept only if every kept point inside its ``2r x 2r``
    bounding box is at distance ``>= r``.

    Args:
        points: Sequence of ``(x, y)`` pairs.
        r: Minimum allowed distance between kept points.

    Returns:
        List of the kept elements of ``points``, in their original order.
    """
    kept = []
    tree = rtree.index.Index()
    for position, point in enumerate(points):
        x, y = point
        window = (x - r, y - r, x + r, y + r)
        neighbours = tree.intersection(window)
        far_enough = all(dist(point, points[other]) >= r for other in neighbours)
        if not far_enough:
            continue
        kept.append(point)
        # store the accepted point as a degenerate (zero-area) box
        tree.insert(position, (x, y, x, y))
    return kept
77
+
78
+
79
def get_vertices(e):
    """Return the distinct polygon vertices of all cells in the world of experiment ``e``.

    The vertices of every cell polygon are collected, then thinned with
    ``sparse_subset`` so that no two returned vertices are closer than 0.01.

    Args:
        e: Experiment-like object; only ``e.occlusions`` is read.

    Returns:
        ``(k, 2)`` array of ``(x, y)`` vertices.
    """
    # Polygon_list was referenced here but never imported in this module
    # (NameError). NOTE(review): assumes cellworld exports Polygon_list at top level.
    from cellworld import Polygon_list

    w = World.get_from_parameters_names('hexagonal', 'canonical', e.occlusions)
    all_polygons = Polygon_list.get_polygons(
        w.cells.get('location'),
        w.configuration.cell_shape.sides,
        w.implementation.cell_transformation.size / 2,
        w.implementation.space.transformation.rotation + w.implementation.cell_transformation.rotation)

    xs = np.hstack([poly.vertices.get('x') for poly in all_polygons])
    ys = np.hstack([poly.vertices.get('y') for poly in all_polygons])
    pts = np.column_stack((xs, ys)).tolist()

    # get unique vertices, removing those closeby
    sparse_pts = sparse_subset(pts, 0.01)
    return np.vstack(sparse_pts)
98
+
99
+
100
def get_world_mask(w, bins, wall_mask=True, occlusion_mask=True):
    """Return a flat boolean mask of grid-bin centres to exclude.

    Bin centres are ``bins[:-1]`` shifted by half the mean bin width. The
    ``(len(bins) - 1) ** 2`` centres are laid out in ``meshgrid(..., indexing='ij')``
    order. A centre is flagged True when any enabled mask (walls via
    ``get_wall_mask``, occlusions via ``get_occlusion_mask``) flags it.
    With both masks disabled, every entry is False.

    Args:
        w: World (or occlusion-set name) passed through to the mask helpers.
        bins: 1-D array of bin edges used on both axes.
        wall_mask: Include the arena-wall mask.
        occlusion_mask: Include the occlusion mask.

    Returns:
        Boolean array of length ``(len(bins) - 1) ** 2``.
    """
    half_step = np.mean(np.diff(bins)) / 2
    centres = bins[:-1] + half_step
    grid_x, grid_y = np.meshgrid(centres, centres, indexing='ij')
    points = np.column_stack((grid_x.ravel(), grid_y.ravel()))

    masks = [np.zeros((len(bins) - 1) ** 2)]
    if wall_mask:
        masks.append(get_wall_mask(w, points))
    if occlusion_mask:
        masks.append(get_occlusion_mask(w, points))
    return np.any(np.vstack(masks), axis=0)
114
+
115
def get_occlusion_mask(w, locations, r=0.025):
    """Flag locations that fall within radius ``r`` of an occluded cell.

    Args:
        w: cellworld ``World``, or a str used as the occlusion-set name to build one.
        locations: Array of ``(x, y)`` points, one per row.
        r: Radius passed to ``Path.contains_points``.

    Returns:
        Boolean array with one entry per location. All False when the world
        has no occluded cells.
    """
    if type(w) is str:
        w = World.get_from_parameters_names('hexagonal', 'canonical', w)
    display = Display(w, fig_size=(1,1), padding=0, cell_edge_color="lightgrey")
    plt.close(display.fig)

    hits = []
    for cell in display.cell_polygons:
        # occluded cells are the ones whose face colour has a zero first component
        if cell._facecolor[0] == 0:
            outline = cell.get_patch_transform().transform_path(cell.get_path())
            hits.append(outline.contains_points(locations, radius=r))

    if not hits:
        return np.zeros(len(locations), dtype=bool)
    return np.any(np.vstack(hits).T, axis=1)
133
+
134
def get_wall_mask(w, locations, r=0.025):
    """Flag locations that are not inside the arena (habitat) polygon.

    Args:
        w: cellworld ``World``, or a str used as the occlusion-set name to build one.
        locations: Array of ``(x, y)`` points, one per row.
        r: Radius passed to ``Path.contains_points``.

    Returns:
        Boolean array, True where ``contains_points`` does not report the
        location as inside the habitat polygon.
    """
    if type(w) is str:
        w = World.get_from_parameters_names('hexagonal', 'canonical', w)
    display = Display(w, fig_size=(1,1), padding=0, cell_edge_color="lightgrey")
    plt.close(display.fig)

    habitat = display.habitat_polygon
    raw_path = habitat.get_path()
    outline = habitat.get_patch_transform().transform_path(raw_path)
    return ~outline.contains_points(locations, radius=r)
147
+
@@ -0,0 +1,387 @@
1
+ import numpy as np
2
+ import pickle
3
+ from pathlib import Path
4
+ import torch
5
+ from kilosort.io import BinaryFiltered, load_ops
6
+ from cellworld_npx.lfp import get_binary_file
7
+ from cellworld_npx.probe import cluster_probe_channels
8
+ from tqdm import tqdm
9
+ from sklearn.cluster import DBSCAN
10
+ from cellworld_npx.decoder import bin_recording, format_behavior_data
11
+ from replay_trajectory_classification import ClusterlessClassifier, ClusterlessDecoder, Environment, RandomWalk, Uniform, Identity
12
+
13
def group_channel_map(channel_map, n=4):
    """Split probe channels into groups of ``n`` and return each group's nearest channels.

    Channels are first split into blocks with ``cluster_probe_channels``. Each
    block is cut into consecutive runs of ``n`` channels, numbered globally
    across blocks. For each run, the centre of mass of ``channel_map[:, 1:]``
    is computed. The ``n`` channels (over the whole probe) closest to that
    centre form the group.

    NOTE(review): runs are placed by a running positional offset, which assumes
    the channels of each block are contiguous in ``channel_map`` and ordered by
    block id — confirm against ``cluster_probe_channels``.

    Args:
        channel_map: 2-D array with one row per channel; columns ``1:`` are
            used as channel coordinates (column 0 is not used here).
        n: Channels per group.

    Returns:
        group_channels: ``(n_groups, n)`` channel row indices, nearest first.
        group_com: ``(n_groups, n_coords)`` group centres of mass.
    """
    # TODO test overlapping groups
    block_ids = cluster_probe_channels(channel_map)
    group_of_channel = np.zeros(block_ids.shape)
    offset = 0
    next_group = 0
    for block in np.unique(block_ids):
        block_size = (block_ids == block).sum()
        run_starts = range(0, block_size, n)
        for k, start in enumerate(run_starts):
            lo = offset + start
            # A short final run may spill past this block; the next block's
            # first run overwrites those entries.
            group_of_channel[lo:lo + n] = next_group + k
        next_group += len(run_starts)
        offset += block_size

    coords = channel_map[:, 1:]
    group_com = np.vstack([
        channel_map[group_of_channel == g, 1:].mean(0)
        for g in np.unique(group_of_channel)
    ])

    # the n channels closest to each group centre
    group_channels = np.vstack([
        np.argsort(np.sum((coords - centre) ** 2, axis=1) ** 0.5)[0:n]
        for centre in group_com
    ])
    return group_channels, group_com
40
+
41
def assign_spike_groups(spike_positions, group_com, show_progress=True):
    """Assign each spike to the electrode group whose centre of mass is nearest.

    Args:
        spike_positions: Array with one row of coordinates per spike.
        group_com: Array with one row of coordinates per group.
        show_progress: Show a tqdm progress bar over groups.

    Returns:
        Integer array with one entry per spike: the row of ``group_com`` at the
        smallest Euclidean distance (NaN distances ignored).
    """
    n_groups = group_com.shape[0]
    distances = np.zeros((n_groups, spike_positions.shape[0]))
    progress = tqdm(range(n_groups), desc='assigning spikes to groups', disable=not show_progress)
    for g in progress:
        offsets = spike_positions - group_com[g, :]
        distances[g, :] = np.sum(offsets ** 2, axis=1) ** 0.5
    return np.nanargmin(distances, axis=0)
49
+
50
def filter_spikes(R):
    """Load spikes and count, per spike, the episodes that contain it.

    A spike counts toward an episode when it lies strictly between the
    episode's start and end time and strictly before the last behaviour
    ``time_stamp`` returned by ``format_behavior_data(R, agent='prey')``.

    Returns:
        ``(spike_times, spike_clusters, spike_index)``. ``spike_index`` is a float
        array of per-spike episode counts; callers select spikes with
        ``spike_index == 1``.
    """
    spike_times, spike_clusters, _clust_info = R.get_spikes()
    behavior = format_behavior_data(R, agent='prey')

    starts = R.episodes.get('start_time')
    ends = R.episodes.get('end_time')
    episode_times = np.vstack([starts, ends]).T

    spike_index = np.zeros(len(spike_times))
    for start, end in episode_times:
        in_episode = (spike_times > start) & (spike_times < end)
        spike_index = spike_index + (in_episode & (spike_times < behavior['time_stamp'][-1]))
    return spike_times, spike_clusters, spike_index
63
+
64
def extract_spike_template_features(spike_positions, spike_group, amplitudes, templates, spike_templates, group_channels, n=4,
                                    show_progress=False):
    """Compute per-spike features from scaled templates on the spike's group channels.

    For spike ``k``, the template ``spike_templates[k]`` is indexed on the
    channels ``group_channels[spike_group[k]]`` and scaled by ``amplitudes[k]``.
    The maximum over the remaining template axis is taken for each channel.
    NOTE(review): assumes ``templates`` is (n_templates, n_time, n_channels) — confirm.

    Returns:
        ``(n_spikes, n)`` float array of features.
    """
    n_spikes = spike_positions.shape[0]
    spike_features = np.zeros((n_spikes, n))
    for k in tqdm(range(n_spikes), desc='extracting spike template features', disable=not show_progress):
        channels = group_channels[spike_group[k], :]
        scaled = amplitudes[k] * templates[spike_templates[k], :, channels]
        spike_features[k, :] = np.max(scaled, axis=1)
    return spike_features
71
+
72
def extract_spike_amplitude_features(spike_positions, spike_group, spike_amps, group_channels, n=4, show_progress=False):
    """Pick each spike's amplitudes on the channels of its assigned group.

    Args:
        spike_positions: Per-spike array; only its length is used.
        spike_group: Group index per spike (row of ``group_channels``).
        spike_amps: ``(n_spikes, n_channels)`` amplitude array.
        group_channels: ``(n_groups, n)`` channel indices per group.
        n: Features per spike; must equal ``group_channels.shape[1]``.
        show_progress: Show a tqdm progress bar.

    Returns:
        ``(n_spikes, n)`` float64 array.
    """
    total = spike_positions.shape[0]
    spike_features = np.zeros((total, n))
    rows = tqdm(range(total), desc='extracting spike amplitude features', disable=not show_progress)
    for k in rows:
        spike_features[k, :] = spike_amps[k, group_channels[spike_group[k]]]
    return spike_features
77
+
78
def get_bfile(R, return_ops=False, hp_filter=True, whiten=True, dshift=True, device='cuda'):
    """Open the recording's continuous data as a Kilosort ``BinaryFiltered`` file.

    The Kilosort ``ops.npy`` in ``R.spike_path`` supplies the channel count and
    channel map, plus the optional filter, whitening matrix and drift shifts.

    Args:
        R: Recording object; ``spike_path`` and ``get_probes_continuous_paths()``
            are used (the first continuous path is opened).
        return_ops: Also return the loaded ``ops`` dict.
        hp_filter: Apply ``ops['fwav']`` as the high-pass filter.
        whiten: Apply ``ops['Wrot']`` as the whitening matrix.
        dshift: Apply ``ops['dshift']`` drift correction.
        device: torch device (or device string) for loading and filtering.
            The default ``'cuda'`` keeps the previously hard-coded behaviour.

    Returns:
        ``bfile`` or, when ``return_ops`` is True, ``(bfile, ops)``.
    """
    results_dir = Path(R.spike_path)
    filename = Path(R.get_probes_continuous_paths()[0])
    if not isinstance(device, torch.device):
        device = torch.device(device)
    ops = load_ops(results_dir / 'ops.npy', device=device)
    bfile = BinaryFiltered(
        filename,
        n_chan_bin=ops['n_chan_bin'],
        chan_map=ops['chanMap'],
        device=device,
        hp_filter=ops['fwav'] if hp_filter else None,
        whiten_mat=ops['Wrot'] if whiten else None,
        dshift=ops['dshift'] if dshift else None,
    )
    if return_ops:
        return bfile, ops
    return bfile
104
+
105
def get_spike_amplitudes(R, show_progress=False):
    """Return each spike's amplitude on every channel, computing and caching it on first use.

    The result is cached as ``spike_amplitudes.npy`` in ``R.spike_path``. On a
    cache miss, the filtered binary is read around every spike time, and the
    sample at column ``ops['nt0min']`` of that window is stored for all channels.

    Args:
        R: Recording object; ``spike_path`` and ``get_binary_files()`` are used.
        show_progress: Show a progress bar / the cache-hit message.

    Returns:
        ``(n_spikes, ops['n_chan_bin'])`` float array.
    """
    results_dir = Path(R.spike_path)
    fn = results_dir / 'spike_amplitudes.npy'
    if fn.exists():
        if show_progress:
            print(f'loading spike amplitudes from {fn}')
        return np.load(fn)

    bfile = get_binary_file(R.get_binary_files()[0], hp_filter=True, whiten=True, dshift=True)
    spike_times = np.load(results_dir / 'spike_times.npy')
    ops = load_ops(results_dir / 'ops.npy')
    spike_amps = np.zeros((len(spike_times), ops['n_chan_bin']))
    # `disable` is a tqdm argument; it used to be passed to enumerate(), which
    # raised TypeError before any amplitude was computed.
    for i, t in enumerate(tqdm(spike_times, desc='extracting spike amplitudes', disable=not show_progress)):
        # Use a Python int so `t - nt0min` can go negative; with an unsigned
        # integer dtype it would wrap around and the tmin < 0 clamp would never fire.
        t = int(t)
        tmin = t - bfile.nt0min
        tmax = t + (bfile.nt - bfile.nt0min) + 1
        # clamp the window to the start / end of the file, keeping its length
        if tmin < 0:
            tmin = 0; tmax = bfile.nt + 1
        if tmax > bfile.n_samples:
            tmax = bfile.n_samples; tmin = tmax - bfile.nt - 1
        spike_amps[i, :] = bfile[tmin:tmax].cpu().numpy()[:, ops['nt0min']].astype('float32')
    np.save(fn, spike_amps)
    return spike_amps
130
+
131
def get_multiunits(spike_times, spike_group, spike_features, bins, show_progress=False, check_duplicate_spikes=False):
    """Scatter spike features into a (time bin, feature, electrode group) array.

    Spikes are binned with ``np.digitize(..., right=True)``, so a spike at time
    ``t`` lands in row ``k`` when ``bins[k] < t <= bins[k + 1]``. Entries with no
    spike are NaN. When several spikes of one group share a bin, the last one
    wins; ``check_duplicate_spikes`` reports how many spikes hit each bin.

    Args:
        spike_times: 1-D spike times.
        spike_group: Group label per spike.
        spike_features: ``(n_spikes, n_features)`` array.
        bins: Monotonic bin edges.
        show_progress: Show a tqdm progress bar.
        check_duplicate_spikes: Also return per-bin spike counts.

    Returns:
        multiunits: ``(len(bins) - 1, n_features, n_groups)`` array; the group
            axis follows ``np.unique(spike_group)``.
        spike_counts: Only when ``check_duplicate_spikes``. A ``(len(bins), n_groups)``
            array whose last row (as before) collects spikes outside
            ``(bins[0], bins[-1]]``.
    """
    spike_bin = np.digitize(spike_times, bins=bins, right=True)
    ugroups = np.unique(spike_group)
    # ugroups is sorted and contains every label, so searchsorted yields each
    # spike's group column at once instead of a per-spike np.argwhere scan.
    group_index = np.searchsorted(ugroups, spike_group)
    multiunits = np.full((len(bins), spike_features.shape[1], len(ugroups)), np.nan)
    if check_duplicate_spikes:
        spike_counts = np.zeros((len(bins), len(ugroups)))

    rows = range(len(spike_bin))
    if show_progress:
        rows = tqdm(rows, desc='adding spike features to multi-unit array')
    for i in rows:
        # digitize gives 0 for spikes at/before bins[0] and len(bins) after
        # bins[-1]; both map (via -1) onto the last row, discarded below.
        row = spike_bin[i] - 1
        multiunits[row, :, group_index[i]] = spike_features[i, :]
        if check_duplicate_spikes:
            spike_counts[row, group_index[i]] += 1

    multiunits = multiunits[:-1, :, :]
    if check_duplicate_spikes:
        return multiunits, spike_counts
    return multiunits
147
+
148
def get_cv_folds(n_samples, n_folds=10, return_boolean=True):
    """Split ``range(n_samples)`` into contiguous blocks for k-fold cross-validation.

    Fold ``i`` tests on samples ``[i * s, min((i + 1) * s, n_samples))`` with
    ``s = ceil(n_samples / n_folds)`` and trains on all other samples. Because
    ``s`` is rounded up, fewer than ``n_folds`` non-empty blocks can exist
    (e.g. 10 samples in 6 folds). Folds with an empty test block are omitted;
    previously they were returned and broke the downstream predict step.

    Args:
        n_samples: Number of samples.
        n_folds: Requested number of folds (>= 1).
        return_boolean: If True, return boolean masks of length ``n_samples``.
            Otherwise return 1-D integer index arrays (previously ``(k, 1)``
            arrays from ``np.argwhere``, which added an axis when used to index rows).

    Returns:
        List of ``[train, test]`` pairs.

    Raises:
        ValueError: If ``n_folds < 1``.
    """
    if n_folds < 1:
        raise ValueError(f'n_folds must be >= 1, got {n_folds}')
    cv_runs = []
    run_size = int(np.ceil(n_samples / n_folds))
    for i in range(n_folds):
        start = i * run_size
        stop = min((i + 1) * run_size, n_samples)
        if start >= stop:
            break  # every later block is empty as well
        test_mask = np.zeros(n_samples, dtype=bool)
        test_mask[start:stop] = True
        if return_boolean:
            cv_runs.append([~test_mask, test_mask])
        else:
            cv_runs.append([np.flatnonzero(~test_mask), np.flatnonzero(test_mask)])
    return cv_runs
164
+
165
def preprocess_data(R, ops, verbose=False):
    """Prepare binned clusterless spike features and behaviour for decoding.

    The function runs these steps in order:

    1. Format the behaviour data.
    2. Group the probe channels and assign each spike to its nearest group.
    3. Load (or compute and cache) per-channel spike amplitudes.
    4. Keep spikes that fall in exactly one episode and belong to a good unit.
    5. Use each spike's amplitudes on its group's channels as features.
    6. Scatter the features into time bins of width ``ops['dt']``.
    7. Bin the recording and keep only bins where the agent was tracked.

    Args:
        R: Recording object providing spikes, episodes, population and probe info.
        ops: Options dict from ``set_ops``.
        verbose: Forwarded as ``show_progress`` to the helpers.

    Returns:
        data: Dict with 'mua' (tracked bins x features x groups), 'time', and
            'position' / 'velocity' of the prey scaled by ``ops['canonical_to_cm']``.
        binned_data: ``bin_recording`` output restricted to tracked bins.
    """
    print('PREPROCESSING DATA')
    # format behavior
    d = format_behavior_data(R, agent=ops['agents'])

    # assign every spike to its nearest electrode group
    results_dir = Path(R.spike_path)
    spike_positions = np.load(results_dir / 'spike_positions.npy')
    channel_map = R.get_probe_channel_map()[0]
    group_channels, group_com = group_channel_map(channel_map, n=ops['n'])
    spike_group = assign_spike_groups(spike_positions, group_com, show_progress=verbose)

    # get spike amplitudes (takes a while first time)
    spike_amps = get_spike_amplitudes(R, show_progress=verbose)

    # spike_index holds, per spike, the number of episodes containing it
    spike_times, spike_clusters, spike_index = filter_spikes(R)

    # Keep only spikes from good units.
    # NOTE(review): row indices of R.population are used as cluster ids — confirm they coincide.
    good_units = np.flatnonzero(R.population.get('good_unit'))
    print(f'including spikes from {len(good_units)} single/multi-units with good waveforms')
    # np.isin replaces a loop that made one full pass over all spikes per unit.
    keep = (spike_index == 1) & np.isin(spike_clusters, good_units)

    spike_positions = spike_positions[keep]
    spike_times = spike_times[keep]
    spike_clusters = spike_clusters[keep]
    spike_group = spike_group[keep]
    spike_amps = spike_amps[keep]

    # Pass ops['n'] so the feature width matches group_channels; the default
    # n=4 used before broke whenever ops['n'] != 4.
    spike_features = extract_spike_amplitude_features(spike_positions, spike_group, spike_amps, group_channels,
                                                      n=ops['n'], show_progress=verbose)

    # create multiunit array
    half_dt = ops['dt'] / 2
    bins = np.arange(np.nanmin(d['time_stamp']) - half_dt, np.nanmax(d['time_stamp']) + half_dt, ops['dt'])
    multiunits = get_multiunits(spike_times=spike_times, spike_group=spike_group, spike_features=spike_features,
                                bins=bins, show_progress=verbose)

    # bin the recording
    binned_data = bin_recording(R, agent=ops['agents'], dt=ops['dt'], skip_spikes=True,
                                kalman_filter=ops['kalman_filter'], show_progress=verbose)

    # remove data where agents were not tracked
    column = [c for c in binned_data.columns if 'tracked' in c][0]
    tracked = binned_data[column] == 1
    mua = multiunits[np.asarray(tracked), :, :]
    binned_data = binned_data[tracked]

    # NOTE(review): position/velocity always use the prey columns, even if
    # ops['agents'] selects another agent.
    data = {
        'mua': mua,
        'time': np.array(binned_data['time_stamp']),
        'position': np.vstack(binned_data['prey_location']) * ops['canonical_to_cm'],
        'velocity': np.vstack(binned_data['prey_velocity']) * ops['canonical_to_cm'],
    }
    return data, binned_data
222
+
223
def set_ops(dt=0.002, n=4, agents='prey', canonical_to_cm=234, velocity_cutoff=0.1, cv_folds=5, cv_type='fold', kalman_filter=True, bin_size=5,
            random_walk_var=6, mark_var=24, position_var=6, drop_causal_posterior=True, states=None):
    """Return the options dict used by preprocessing, decoding and classification.

    Args:
        states: List of state names (a single str is also accepted); defaults
            to ``['continuous']``. Selections containing both 'continuous' and
            'fragmented' are reordered to the order ``build_continuous_transitions``
            expects.
        Other arguments are stored under their own names; see the inline comments.

    Returns:
        Dict of options. ``ops['states']`` is a fresh list on every call.
    """
    # The old ``states=['continuous']`` default was one list shared by every
    # returned dict, so mutating ops['states'] changed later defaults.
    if states is None:
        states = ['continuous']
    elif isinstance(states, str):
        states = [states]
    else:
        states = list(states)

    # ensure correct ordering for classifier
    if 'continuous' in states and 'fragmented' in states and 'stationary' in states:
        states = ['continuous', 'fragmented', 'stationary']
    elif 'continuous' in states and 'fragmented' in states and 'stationary' not in states:
        states = ['continuous', 'fragmented']

    # options
    ops = {
        'dt': dt,  # decoding bin width, in behaviour time-stamp units (presumably seconds)
        'n': n,  # channels / spike features per electrode group
        'agents': agents,  # agents to include
        'canonical_to_cm': canonical_to_cm,  # convert canonical units to cm
        'velocity_cutoff': velocity_cutoff,  # velocity cutoff for speed-split training
        'cv_folds': cv_folds,  # cross validation folds
        'cv_type': cv_type,  # cross validation type ("fold" for fold split, "speed" for speed split)
        'kalman_filter': kalman_filter,  # kalman smooth raw data
        'bin_size': bin_size,  # bin size for rate maps (cm)
        'random_walk_var': random_walk_var,  # variance of movement of transition model (cm)
        'mark_var': mark_var,  # variance of encoding model mark space (~uV)
        'position_var': position_var,  # variance of encoding model position (cm)
        'drop_causal_posterior': drop_causal_posterior,
        'states': states,
    }
    return ops
250
+
251
def build_continuous_transitions(ops):
    """Return the continuous-state transition model(s) for ``ops['states']``.

    * continuous + fragmented + stationary -> 3x3 nested list of models
    * continuous + fragmented (no stationary) -> 2x2 nested list
    * exactly one state -> that state's single model (RandomWalk for
      'continuous', Uniform for 'fragmented', Identity for 'stationary')

    Raises:
        AssertionError: For any other selection.
        KeyError: For a single unknown state name.
    """
    states = ops['states']
    movement_var = ops['random_walk_var']
    single_state = {
        'continuous': RandomWalk(movement_var=movement_var),
        'fragmented': Uniform(),
        'stationary': Identity(),
    }

    has_continuous = 'continuous' in states
    has_fragmented = 'fragmented' in states
    has_stationary = 'stationary' in states

    if has_continuous and has_fragmented:
        if has_stationary:
            return [
                [RandomWalk(movement_var=movement_var), Uniform(), Identity()],
                [Uniform(), Uniform(), Uniform()],
                [RandomWalk(movement_var=movement_var), Uniform(), Identity()],
            ]
        return [
            [RandomWalk(movement_var=movement_var), Uniform()],
            [Uniform(), Uniform()],
        ]

    if len(states) == 1:
        return single_state[states[0]]

    raise AssertionError("ops['states'] must be ['continuous', 'fragmented', 'stationary'], ['continuous', 'fragmented'] or ['continuous'], ['fragmented'], or ['stationary']")
270
+
271
def run_decoder(data, ops):
    """Cross-validated clusterless position decoding with a single state.

    Args:
        data: Dict from ``preprocess_data`` with 'mua', 'time', 'position', 'velocity'.
        ops: Options from ``set_ops``; ``ops['states']`` must hold exactly one state.
            With ``cv_type == 'speed'``, the moving bins (velocity above
            ``velocity_cutoff``) are split in half: the model is trained on the
            first half and tested on the second. Otherwise ``get_cv_folds`` is used.

    Returns:
        Dict with:
            'dist_error': Euclidean distance between the MAP estimate and the
                true position for every tested bin.
            'map_estimate': MAP positions, rows in the order of the
                concatenated test sets.
            'cv_results': ``{'decoders': [...], 'results': [...]}``.

    Raises:
        AssertionError: If ``ops['states']`` does not hold exactly one state.
    """
    print('TRAIN/TEST DECODER')
    # cross validation
    if ops['cv_type'] == 'speed':
        moving = np.flatnonzero(data['velocity'].squeeze() > ops['velocity_cutoff'])
        half = len(moving) // 2
        # One [train, test] fold. Previously cv_runs = [train, test], so the
        # loop treated each index array as a fold and fold[0]/fold[1] were single samples.
        cv_runs = [[moving[:half], moving[half:]]]
    else:
        cv_runs = get_cv_folds(data['position'].shape[0], ops['cv_folds'])

    # model setup
    lims = (ops['canonical_to_cm'] * 0.05, ops['canonical_to_cm'] * 1.05)
    environment = Environment(place_bin_size=ops['bin_size'], position_range=[lims, lims])
    if len(ops['states']) != 1:
        # explicit raise instead of `assert`, which is stripped under -O
        raise AssertionError(f'For standard decoding there must be one state, not {ops["states"]}, try run_classifier instead...')
    transition_type = build_continuous_transitions(ops)
    clusterless_algorithm = 'multiunit_likelihood_gpu'
    clusterless_algorithm_params = {
        'mark_std': ops['mark_var'],
        'position_std': ops['position_var'],
    }

    # cv loop
    decoders = []
    results = []
    for train, test in tqdm(cv_runs, desc='cross-validation fold'):
        decoder = ClusterlessDecoder(
            environment=environment,
            transition_type=transition_type,
            clusterless_algorithm=clusterless_algorithm,
            clusterless_algorithm_params=clusterless_algorithm_params)
        decoder.fit(data['position'][train, :], data['mua'][train, :, :])
        decoders.append(decoder)
        results.append(decoder.predict(data['mua'][test, :, :], time=data['time'][test], use_gpu=True))

    # compile across runs: MAP position of the acausal posterior per bin
    map_estimate = []
    for r in results:
        post = r.acausal_posterior.stack(position=['x_position', 'y_position'])
        best = post.position[post.argmax('position')]
        map_estimate.append(np.asarray(best.values.tolist()))
    map_estimate = np.vstack(map_estimate)
    # Compare against the positions that were actually tested (for the speed
    # split this is a subset; previously all positions were used).
    true_position = np.concatenate([data['position'][test, :] for _, test in cv_runs])
    error = np.linalg.norm(map_estimate - true_position, axis=1)
    return {
        'dist_error': error,
        'map_estimate': map_estimate,
        'cv_results': {'decoders': decoders, 'results': results},
    }
326
+
327
def run_classifier(data, ops, drop_causal_posterior=True):
    """Cross-validated clusterless classification over the states in ``ops['states']``.

    Args:
        data: Dict from ``preprocess_data`` with 'mua', 'time', 'position', 'velocity'.
        ops: Options from ``set_ops``. CV is chosen as in ``run_decoder``: a
            half/half split of moving bins for ``cv_type == 'speed'``, else
            ``get_cv_folds``.
        drop_causal_posterior: Remove 'causal_posterior' from each fold's result.
            NOTE(review): ``ops['drop_causal_posterior']`` is not consulted here.

    Returns:
        Dict with:
            'dist_error': Euclidean distance between the MAP estimate (of the
                acausal posterior summed over states) and the true position,
                for every tested bin.
            'map_estimate': MAP positions, rows in the order of the
                concatenated test sets.
            'cv_results': ``{'classifiers': [...], 'results': [...]}``.
    """
    print('TRAIN/TEST CLASSIFIER')
    # cross validation
    if ops['cv_type'] == 'speed':
        moving = np.flatnonzero(data['velocity'].squeeze() > ops['velocity_cutoff'])
        half = len(moving) // 2
        # One [train, test] fold (previously [train, test] itself was iterated).
        cv_runs = [[moving[:half], moving[half:]]]
    else:
        cv_runs = get_cv_folds(data['position'].shape[0], ops['cv_folds'])

    # model setup
    environment = Environment(place_bin_size=ops['bin_size'],
                              position_range=[(0, ops['canonical_to_cm']), (0, ops['canonical_to_cm'])])
    continuous_transition_types = build_continuous_transitions(ops)
    if len(ops['states']) == 1:
        # wrap the single model as a 1x1 transition matrix
        continuous_transition_types = [[continuous_transition_types]]
    clusterless_algorithm = 'multiunit_likelihood_gpu'
    clusterless_algorithm_params = {
        'mark_std': ops['mark_var'],
        'position_std': ops['position_var'],
    }

    classifiers = []
    results = []
    for train, test in tqdm(cv_runs, desc='cross-validation fold'):
        classifier = ClusterlessClassifier(
            environments=environment,
            continuous_transition_types=continuous_transition_types,
            clusterless_algorithm=clusterless_algorithm,
            clusterless_algorithm_params=clusterless_algorithm_params)
        classifier.fit(data['position'][train, :], data['mua'][train, :, :])
        classifiers.append(classifier)

        result = classifier.predict(data['mua'][test, :, :], time=data['time'][test], use_gpu=True)
        result['state'] = ops['states']
        if drop_causal_posterior:
            # Dataset.drop/drop_vars return a new object; the old call discarded
            # it, so the causal posterior was never actually removed.
            result = result.drop_vars('causal_posterior')
        results.append(result)

    map_estimate = []
    for r in results:
        post = r.acausal_posterior.sum('state').stack(position=['x_position', 'y_position'])
        best = post.position[post.argmax('position')]
        map_estimate.append(np.asarray(best.values.tolist()))
    map_estimate = np.vstack(map_estimate)
    # compare against the positions that were actually tested
    true_position = np.concatenate([data['position'][test, :] for _, test in cv_runs])
    error = np.linalg.norm(map_estimate - true_position, axis=1)
    return {
        'dist_error': error,
        'map_estimate': map_estimate,
        'cv_results': {'classifiers': classifiers, 'results': results},
    }
383
+
384
def save_results(fn, result_list: list):
    """Pickle every element of ``result_list``, one after another, into file ``fn``.

    The file is overwritten. Read it back by calling ``pickle.load`` once per
    element, in the same order.
    """
    with open(fn, mode='wb') as handle:
        dump = pickle.dump
        for item in result_list:
            dump(item, handle)