labdata 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- labdata/__init__.py +17 -0
- labdata/cli.py +499 -0
- labdata/compute/__init__.py +27 -0
- labdata/compute/ec2.py +198 -0
- labdata/compute/ephys.py +469 -0
- labdata/compute/pose.py +281 -0
- labdata/compute/schedulers.py +194 -0
- labdata/compute/singularity.py +95 -0
- labdata/compute/utils.py +561 -0
- labdata/copy.py +351 -0
- labdata/rules/__init__.py +78 -0
- labdata/rules/ephys.py +188 -0
- labdata/rules/imaging.py +618 -0
- labdata/rules/utils.py +290 -0
- labdata/s3.py +317 -0
- labdata/schema/__init__.py +24 -0
- labdata/schema/ephys.py +547 -0
- labdata/schema/general.py +647 -0
- labdata/schema/histology.py +309 -0
- labdata/schema/onephoton.py +93 -0
- labdata/schema/procedures.py +102 -0
- labdata/schema/tasks.py +66 -0
- labdata/schema/twophoton.py +142 -0
- labdata/schema/utils.py +25 -0
- labdata/schema/video.py +243 -0
- labdata/stacks.py +182 -0
- labdata/utils.py +598 -0
- labdata/widgets.py +412 -0
- labdata-0.0.3.dist-info/METADATA +42 -0
- labdata-0.0.3.dist-info/RECORD +36 -0
- labdata-0.0.3.dist-info/WHEEL +5 -0
- labdata-0.0.3.dist-info/entry_points.txt +2 -0
- labdata-0.0.3.dist-info/licenses/LICENSE +674 -0
- labdata-0.0.3.dist-info/top_level.txt +2 -0
- labdata_frontend/Home.py +39 -0
- labdata_frontend/__init__.py +0 -0
labdata/compute/ec2.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
from ..utils import *
|
|
2
|
+
|
|
3
|
+
def ec2_connect(access_key = None, secret_key = None, region = None):
    '''Open a boto3 session and its EC2 resource.

    Every argument left as None is filled from ``prefs['compute']['aws']``
    (keys ``access_key``, ``secret_key`` and ``region``); arguments passed
    explicitly take precedence over the preference file.

    A region ending in a letter (e.g. ``us-west-2b``) is treated as an
    availability zone and the trailing letter is stripped.

    Returns:
        (botosession, ec2): the boto3.Session and its 'ec2' resource.

    Raises:
        ValueError: no access key or no region could be determined.
    '''
    import boto3
    awsprefs = prefs['compute'].get('aws') or {}
    # only fill in what the caller did not supply
    if access_key is None:
        access_key = awsprefs.get('access_key')
    if secret_key is None:
        secret_key = awsprefs.get('secret_key')
    if region is None:
        region = awsprefs.get('region')
    if access_key is None:
        raise ValueError('Need to supply an access key to access ec2, set compute:aws:access_key in the preference file.')
    if not region:
        raise ValueError('Need to supply a region to access ec2, set compute:aws:region in the preference file.')
    if region[-1].isalpha(): # then it includes the availability zone
        region = region[:-1]
    botosession = boto3.Session(
        aws_access_key_id = access_key,
        aws_secret_access_key = secret_key,
        region_name = region)
    ec2 = botosession.resource('ec2', region_name = region)
    return (botosession, ec2)
|
|
20
|
+
|
|
21
|
+
def ec2_get_key(ec2 = None, keyname = None):
    '''Return stored EC2 key-pair info, creating a new key pair if none is stored.

    Key-pair info lives as JSON files (keys ``key_name``, ``key_material``,
    ``key_pair_id``) in ``prefs['compute']['aws']['access_key_folder']``.

    Parameters:
        ec2: boto3 EC2 resource. It is only used when a key pair must be
            created; a connection is opened if it is None.
        keyname: when given and a file with that name exists in the key
            folder, that file is loaded. Otherwise the first file in the folder
            is loaded. If the folder holds no key file, a key pair is created
            under ``keyname``, or under a name built from ``prefs['hostname']``
            and the current time when ``keyname`` is None.

    Returns:
        dict with ``key_name``, ``key_material`` and ``key_pair_id``.
    '''
    import os
    keyspath = Path(prefs['compute']['aws']['access_key_folder'])
    if keyname is not None and (keyspath / keyname).is_file():
        stored = [keyspath / keyname]
    else:
        # skip sub-directories: only files hold key info
        stored = [p for p in keyspath.glob('*') if p.is_file()]
    if stored:
        with open(stored[0], 'r') as fd:
            return json.load(fd)

    if keyname is None:
        # no ':' in the timestamp: the key name is also used as a file name
        date = datetime.now().strftime('%Y%m%d_%H%M%S')
        keyname = f"ec2-labdata-{prefs['hostname']}-{date}"
    if ec2 is None:
        _, ec2 = ec2_connect()
    key = ec2.create_key_pair(KeyName = keyname)
    keyinfo = dict(key_name = key.key_name,
                   key_material = key.key_material,
                   key_pair_id = key.key_pair_id)
    # save key info; the file holds the private key, so create it owner-only
    keyspath.mkdir(parents = True, exist_ok = True)
    fdnum = os.open(keyspath / keyname, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fdnum, 'w') as fd:
        json.dump(keyinfo, fd, indent = 4)
    return keyinfo
|
|
43
|
+
|
|
44
|
+
def ec2_instance_from_id(ec2,instance_id):
    '''Look up an EC2 instance by id.

    Opens a connection when ``ec2`` is None. Returns the instance when
    exactly one matches, the list of matches when several are returned, and
    None (after printing a message) when nothing matches.
    '''
    if ec2 is None:
        _, ec2 = ec2_connect()
    matches = [inst for inst in ec2.instances.filter(InstanceIds=[instance_id])]
    if len(matches) == 1:
        return matches[0]
    if matches:
        print(f'There are multiple instances with id: {instance_id}')
        return matches
    print(f'There are no instances with id: {instance_id}')
    return None
|
|
55
|
+
|
|
56
|
+
def ec2_create_instance(ec2,
                        image_id = "linux",
                        instance_type = "t2.micro",
                        key_name = None,
                        availability_zone = None,
                        security_groups = None, # these should come from the preferences
                        user_data = 'echo hostname'):
    '''Launch one EC2 instance described by the preference file.

    Parameters:
        ec2: boto3 EC2 resource (a connection is opened when None).
        image_id: key into ``prefs['compute']['aws']['image_ids']``; the entry
            provides the ``ami`` to launch and the login ``user``.
        instance_type: EC2 instance type.
        key_name: key-pair name; defaults to the stored/created one (ec2_get_key).
        availability_zone: placement zone; defaults to ``prefs['compute']['aws']['region']``.
        security_groups: defaults to ``prefs['compute']['aws']['security_groups']``.
        user_data: script passed as instance user data.

    The instance terminates itself when shut down from inside.

    Returns:
        dict with ``instance``, ``key_name``, ``instance_type``, ``user_name``,
        ``ami`` and ``id``.

    Raises:
        ValueError: ``image_id`` is not in the preference file.
    '''
    images = prefs['compute']['aws']['image_ids']
    if image_id not in images.keys():
        raise ValueError(f'image_id {image_id} is not in the preference_file {list(images.keys())}')

    if ec2 is None:
        _, ec2 = ec2_connect()
    awsprefs = prefs['compute']['aws']
    groups = awsprefs['security_groups'] if security_groups is None else security_groups
    image = images[image_id]
    if key_name is None:
        key_name = ec2_get_key(ec2)['key_name']
    zone = awsprefs['region'] if availability_zone is None else availability_zone
    launched = ec2.create_instances(ImageId = image['ami'],
                                    MinCount = 1,
                                    MaxCount = 1,
                                    InstanceType = instance_type,
                                    KeyName = key_name,
                                    InstanceInitiatedShutdownBehavior = 'terminate',
                                    UserData = user_data,
                                    Placement = {'AvailabilityZone': zone},
                                    SecurityGroups = groups)
    instance = launched[0]
    return dict(instance = instance,
                key_name = key_name,
                instance_type = instance_type,
                user_name = image['user'],
                ami = image['ami'],
                id = instance.id)
|
|
93
|
+
|
|
94
|
+
def ec2_wait_for_instance(ec2,instancedict,desired = 'running',interval = 0.05):
    '''Block until the instance in ``instancedict`` reaches state ``desired``.

    First waits with boto3's ``wait_until_running``, then polls the instance
    every ``interval`` seconds until ``state['Name'] == desired``.

    Parameters:
        ec2: boto3 EC2 resource (a connection is opened when None).
        instancedict: dict with the instance ``id`` (as returned by
            ec2_create_instance); its ``instance`` entry is updated.
        desired: target state name.
        interval: polling period in seconds.

    Returns:
        the refreshed instance object.

    Raises:
        ValueError: the id does not resolve to exactly one instance.
    '''
    import time
    if ec2 is None:
        _, ec2 = ec2_connect()

    def _lookup():
        found = ec2_instance_from_id(ec2, instancedict['id'])
        # ec2_instance_from_id returns None (no match) or a list (several matches)
        if found is None or isinstance(found, list):
            raise ValueError(f"Could not find a unique instance with id {instancedict['id']}")
        return found

    instance = _lookup()
    instance.wait_until_running()
    while instance.state['Name'] != desired:
        time.sleep(interval)
        instance = _lookup()
    instancedict['instance'] = instance
    return instance
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def ec2_instance_ssh(instance, user = 'ubuntu'):
    '''Open a paramiko SSH connection to an EC2 instance.

    Connects to ``instance.public_dns_name`` as ``user`` with the RSA private
    key stored by ec2_get_key.

    Returns:
        the connected paramiko.SSHClient (the caller must close it).
    '''
    from io import StringIO
    import paramiko

    ip_address = instance.public_dns_name

    ssh = paramiko.SSHClient()
    # NOTE(review): AutoAddPolicy trusts any host key; acceptable for freshly
    # launched instances but offers no protection against MITM.
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())

    privkey = paramiko.RSAKey.from_private_key(StringIO(ec2_get_key()['key_material']))

    print('SSH into the instance: {}'.format(ip_address))
    try:
        ssh.connect(hostname = ip_address,
                    username = user, pkey = privkey)
    except Exception:
        # do not leak the client (and its transport) when the connection fails
        ssh.close()
        raise
    return ssh
|
|
127
|
+
|
|
128
|
+
def ec2_cmd_for_launch(singularity_container,
                       singularity_command,
                       singularity_cuda = False,
                       nvme = 'nvme1n1',
                       shutdown_when_done = True,
                       is_self_contained = True,
                       append_log = True):
    '''Build the bash user-data script that runs one singularity command on EC2.

    The script, in order:
      * formats /dev/<nvme> as ext4, mounts it on /data and sets HOME=/data
        (skipped when ``nvme`` is None);
      * writes the S3 credentials of
        ``prefs['storage'][prefs['compute']['containers']['storage']]`` to
        $HOME/.aws/credentials and downloads ``<singularity_container>.sif``
        from that bucket;
      * loads nvidia-uvm and passes ``--nv`` to singularity when
        ``singularity_cuda``;
      * when ``is_self_contained``, writes a user_preferences.json for the
        instance (paths under /data, database/storage/path_rules copied from
        ``prefs``) and copies it to /home/ubuntu/labdata;
      * runs ``singularity_command`` in the container as user ubuntu, piping
        its output to ``labdata2 logpipe <append_log>`` unless ``append_log`` is None;
      * shuts the instance down when ``shutdown_when_done``.

    Returns:
        str: the user-data script.
    '''
    # cloud-init only executes user data as a script when it starts with '#!',
    # so the shebang is emitted whether or not a drive is mounted.
    userdata = '#! /bin/bash\n'
    if nvme is not None:
        userdata += f'''# mount the data drive and set permissions
mkfs -t ext4 /dev/{nvme}
mkdir /data
mount /dev/{nvme} /data/
chown -R ubuntu /data
# make home point to /data where there is space
export HOME="/data"
'''

    storage = prefs['storage'][prefs['compute']['containers']['storage']]
    # NOTE(review): these credentials are stored in plain text in the instance
    # user data; an IAM instance profile would avoid shipping them.
    userdata += f'''
mkdir $HOME/.aws
echo "[default]" > $HOME/.aws/credentials
echo aws_access_key_id={storage['access_key']} >> $HOME/.aws/credentials
echo aws_secret_access_key={storage['secret_key']} >> $HOME/.aws/credentials
mkdir -p $HOME/labdata/containers
echo "Downloading container"
aws s3 cp s3://{storage['bucket']}/containers/{singularity_container}.sif $HOME/labdata/containers/

'''
    prefs_section = ''
    if is_self_contained:
        ec2pref = json.dumps(dict(compute = dict(containers = dict(local = '/data'),
                                                 analysis = prefs['compute']['analysis'],
                                                 default_target = 'local'),
                                  database = prefs['database'],
                                  local_paths = ['/data'],
                                  scratch_path = '/data',
                                  path_rules = prefs['path_rules'],
                                  storage = prefs['storage'],
                                  allow_s3_download = True,
                                  use_awscli = True))
        # The heredoc delimiter is quoted so bash writes the JSON verbatim
        # ($, ` and backslashes in passwords or paths are not expanded).
        prefs_section = f'''
cat > $HOME/labdata/user_preferences.json << 'EOL'
{ec2pref}
EOL

mkdir -p /home/ubuntu/labdata
cp $HOME/labdata/user_preferences.json /home/ubuntu/labdata/'''
    cuda = ''
    if singularity_cuda:
        cuda = '--nv'
        userdata += '''
modprobe nvidia-uvm
nvidia-container-cli -k list
'''
    userdata += f'''{prefs_section}
mkdir -p /home/ubuntu/.cache/torch/kernels
sudo chown -R ubuntu /home/ubuntu

sudo -u ubuntu bash -c "cd /home/ubuntu; singularity exec {cuda} --bind /data:/data $HOME/labdata/containers/{singularity_container}.sif {singularity_command}'''
    if append_log is not None:
        userdata += f' |& singularity exec $HOME/labdata/containers/{singularity_container}.sif labdata2 logpipe {append_log}'
    userdata += '"'
    if shutdown_when_done:
        userdata += '''

shutdown now -h
'''

    return userdata
|
labdata/compute/ephys.py
ADDED
|
@@ -0,0 +1,469 @@
|
|
|
1
|
+
from ..utils import *
|
|
2
|
+
from .utils import BaseCompute
|
|
3
|
+
|
|
4
|
+
class SpksCompute(BaseCompute):
    '''Spike-sorting job built on the spks package.

    For every probe of the dataset in ``self.dataset_key`` the AP-band files
    are fetched, sorted with Kilosort through ``spks.sorting.run_kilosort``,
    post-processed, and inserted into ``SpikeSorting`` and its part tables.
    The sorting parameters are kept in ``self.parameters`` and registered in
    ``SpikeSortingParams``.
    '''
    container = 'labdata-spks'
    cuda = True
    ec2 = dict(small = dict(instance_type = 'g4dn.2xlarge'), # 8 cpus, 32 GB mem, 200 GB nvme, 1 gpu
               large = dict(instance_type = 'g6.4xlarge',
                            availability_zone = 'us-west-2b')) # 16 cpus, 64 GB mem, 600 GB nvme, 1 gpu
    name = 'spks'
    url = 'http://github.com/spkware/spks'
    # algorithm_name -> ``version`` argument of spks.sorting.run_kilosort
    _kilosort_versions = {'spks_kilosort2.5': '2.5',
                          'spks_kilosort3.0': '3.0',
                          'spks_kilosort4.0': '4.0'}
    # --method value -> (algorithm_name, default thresholds); matched in this
    # order with a substring test.
    _method_defaults = (('ks2.5', 'spks_kilosort2.5', (9., 3.)),
                        ('ks3.0', 'spks_kilosort3.0', (9., 9.)),
                        ('ks4.0', 'spks_kilosort4.0', (9., 8.)))

    def __init__(self,job_id, allow_s3 = None, delete_results = True, **kwargs):
        '''Set up a spike-sorting job.

        Workflow:
            1) find the files
            2) copy just the file you need to scratch
            3) run spike sorting on that file/folder
            4) delete the raw files
            5) repeat until all probes are processed.

        Parameters:
            job_id: job identifier passed to BaseCompute; when not None the
                parameter set is registered right away (add_parameter_key).
            allow_s3: forwarded to BaseCompute.
            delete_results: remove the sorting results folder after insertion.
            **kwargs: accepted but not used.
        '''
        super().__init__(job_id, allow_s3 = allow_s3)
        self.file_filters = ['.ap.']
        # default parameters
        self.parameters = dict(algorithm_name = 'spks_kilosort4.0',
                               motion_correction = True,
                               low_pass = 300.,
                               high_pass = 13000.)
        self.use_hdf5 = True # flag to use h5py or zarr format for the waveforms.
        self.parameter_set_num = None # identifier in SpikeSortingParams
        self._init_job()
        if self.job_id is not None:
            self.add_parameter_key()
        self.delete_results = delete_results

    def _get_parameter_number(self):
        '''Return (parameter_set_num, parameters_table) for ``self.parameters``.

        If a SpikeSortingParams entry stores exactly ``self.parameters`` its
        number is returned (the last match wins). Otherwise the next free number
        is returned (1 for an empty table). ``parameters_table`` is the whole
        table as a DataFrame.
        '''
        from ..schema import SpikeSortingParams
        parameters = pd.DataFrame(SpikeSortingParams().fetch())
        parameter_set_num = None
        for _, row in parameters.iterrows():
            # go through every algo
            if self.parameters == json.loads(row.parameters_dict):
                parameter_set_num = row.parameter_set_num
        if parameter_set_num is None:
            if len(parameters) == 0:
                parameter_set_num = 1
            else:
                parameter_set_num = np.max(parameters.parameter_set_num.values) + 1
        return parameter_set_num, parameters

    def add_parameter_key(self):
        '''Register ``self.parameters`` in SpikeSortingParams and check the dataset is not sorted yet.

        Sets ``self.parameter_set_num``.

        Raises:
            ValueError: every ProbeSetting of the dataset already has a
                SpikeSorting with this parameter set (the job is marked FAILED).
        '''
        parameter_set_num, parameters = self._get_parameter_number()
        from ..schema import SpikeSorting, SpikeSortingParams, EphysRecording
        if parameter_set_num not in parameters.parameter_set_num.values:
            SpikeSortingParams().insert1(dict(parameter_set_num = parameter_set_num,
                                              algorithm_name = self.parameters['algorithm_name'],
                                              parameters_dict = json.dumps(self.parameters),
                                              code_link = self.url),
                                         skip_duplicates = True)
        self.parameter_set_num = parameter_set_num
        recordings = EphysRecording.ProbeSetting() & dict(self.dataset_key)
        sortings = SpikeSorting() & dict(self.dataset_key, parameter_set_num = self.parameter_set_num)
        if len(recordings) == len(sortings):
            self.set_job_status(
                job_status = 'FAILED',
                job_waiting = 0,
                job_log = f'{self.dataset_key} was already sorted with parameters {self.parameter_set_num}.')
            raise ValueError(f'{self.dataset_key} was already sorted with parameters {self.parameter_set_num}.')

    def _secondary_parse(self,arguments,parameter_number = None):
        '''Set ``self.parameters`` from a stored parameter set or from command-line arguments.

        Parameters:
            arguments: command-line tokens; ``arguments[0]`` is skipped.
            parameter_number: when given, the parameters are the decoded
                ``parameters_dict`` of that SpikeSortingParams entry and
                ``arguments`` is ignored.

        Raises:
            ValueError: the requested parameter set does not exist.
            NotImplementedError: unknown ``--method``.
        '''
        if parameter_number is not None:
            from ..schema import SpikeSortingParams
            rows = (SpikeSortingParams() & f'parameter_set_num = {parameter_number}').fetch(as_dict = True)
            if not len(rows):
                raise ValueError(f'Could not find parameter {parameter_number} in SpikeSortingParams.')
            # The row wraps the parameters; the sorting code needs the parameters
            # themselves (stored with json.dumps by add_parameter_key).
            self.parameters = json.loads(rows[0]['parameters_dict'])
            return

        import argparse
        parser = argparse.ArgumentParser(
            description = 'Analysis of spike data using kilosort version 2.5 through the spks package.',
            usage = 'spks -a <SUBJECT> -s <SESSION> -- <PARAMETERS>')

        parser.add_argument('-p','--probe',
                            action = 'store', default = None, type = int,
                            help = "THIS DOES NOTHING NOW. WILL BE FOR OPENING PHY")
        parser.add_argument('-m','--method', action = 'store', default = 'ks2.5', type = str,
                            help = 'Method for spike sorting ks2.5 [Kilosort], ks3.0, ks4.0, ms5 [MountainSort]')
        parser.add_argument('-l','--low-pass',
                            action = 'store', default = self.parameters['low_pass'], type = float,
                            help = "Lowpass filter (default 300.Hz)")
        parser.add_argument('-i','--high-pass',
                            action = 'store', default = self.parameters['high_pass'], type = float,
                            help = "Highpass filter (default 13000.Hz)")
        parser.add_argument('-t','--thresholds',
                            action = 'store', default = None, type = float, nargs = 2,
                            help = "Thresholds for spike detection default depends on method.")
        parser.add_argument('-n','--no-motion-correction',
                            action = 'store_false', default = True,
                            help = "Skip motion correction")
        parser.add_argument('-c','--remove_cross-unit-duplicates',
                            action = 'store_true', default = False,
                            help = "Also remove duplicate spikes across units.")

        args = parser.parse_args(arguments[1:])
        for method_key, algorithm_name, thresholds in self._method_defaults:
            if method_key in args.method:
                break
        else:
            raise NotImplementedError(f'{args.method} not implemented.')
        self.parameters = dict(algorithm_name = algorithm_name,
                               motion_correction = args.no_motion_correction,
                               low_pass = args.low_pass,
                               high_pass = args.high_pass,
                               thresholds = list(thresholds),
                               remove_cross_duplicates = args.remove_cross_unit_duplicates)
        if args.thresholds is not None:
            self.parameters['thresholds'] = args.thresholds

        self.probe = args.probe

    def find_datasets(self, subject_name = None, session_name = None):
        '''Return the EphysRecording datasets to spike-sort.

        Parameters:
            subject_name: one subject, as a one-element list (command-line
                form) or a plain string. An empty name selects nothing.
            session_name: sessions of that subject (list or plain string).
                When None, all sessions of the subject that have no SpikeSorting
                with the current parameter set are used.

        Returns:
            list of dicts with subject_name, session_name and dataset_name.

        Raises:
            ValueError: more than one subject was given.
        '''
        if subject_name is None and session_name is None:
            print("\n\nPlease specify a 'subject_name' and a 'session_name' to perform spike-sorting.\n\n")
        from ..schema import EphysRecording, SpikeSorting

        if isinstance(subject_name, str):
            subject_name = [subject_name]
        if isinstance(session_name, str):
            session_name = [session_name]

        keys = []
        if subject_name is not None:
            if len(subject_name) > 1:
                raise ValueError(f'Please submit one subject at a time {subject_name}.')
            subject = subject_name[0] if len(subject_name) else ''
            if subject != '':
                if session_name is not None:
                    keys = [dict(subject_name = subject, session_name = s)
                            for s in session_name if s != '']
                else:
                    # find all sessions that can be spike sorted
                    parameter_set_num, _ = self._get_parameter_number()
                    unsorted = ((EphysRecording() & dict(subject_name = subject)) -
                                (SpikeSorting() & dict(parameter_set_num = int(parameter_set_num))))
                    keys = [dict(subject_name = subject, session_name = ses)
                            for ses in np.unique(unsorted.fetch('session_name'))]
        datasets = []
        for k in keys:
            datasets += (EphysRecording() & k).proj('subject_name','session_name','dataset_name').fetch(as_dict = True)
        return datasets

    def _select_probe_files(self, datasets, probe_num):
        '''Return the rows of ``datasets`` with the AP-band data (and spikeglx metadata) of ``probe_num``.'''
        files = datasets[datasets.probe_num.values == probe_num]
        keep = [i for i, f in files.iterrows()
                # ap.meta is required for spikeglx recordings
                if any(tag in f.file_path for tag in ('ap.cbin', 'ap.ch', 'ap.meta'))]
        dset = files.loc[keep]
        if not len(dset):
            print(files)
            raise ValueError(f'Could not find ap.cbin files for probe {probe_num}')
        return dset

    def _run_sorter(self, probepath):
        '''Run the algorithm in ``self.parameters`` on ``probepath`` and return the results folder.'''
        algorithm = self.parameters['algorithm_name']
        version = self._kilosort_versions.get(algorithm)
        if version is None:
            raise NotImplementedError(f"[{self.name} job] - Algorithm {algorithm} not implemented.")
        from spks.sorting import run_kilosort
        return run_kilosort(version = version,
                            sessionfiles = probepath,
                            temporary_folder = prefs['scratch_path'],
                            do_post_processing = False,
                            motion_correction = self.parameters['motion_correction'],
                            thresholds = self.parameters['thresholds'],
                            lowpass = self.parameters['low_pass'],
                            highpass = self.parameters['high_pass'])

    def _shutdown_joblib_workers(self):
        '''Best-effort shutdown of joblib's reusable loky executor.'''
        self.unregister_safe_exit() # in case these get triggered by shutdown
        try:
            from joblib.externals.loky import get_reusable_executor
            get_reusable_executor().shutdown(wait = True)
        except Exception:
            print(f'[{self.name} job] Tried to clear joblib Loky executers and failed.')
        finally:
            self.register_safe_exit() # put it back..

    def _compute(self):
        '''Sort every probe of ``self.dataset_key`` and insert the results.'''
        from ..schema import EphysRecording
        datasets = pd.DataFrame((EphysRecording.ProbeFile() & self.dataset_key).fetch())

        for probe_num in np.unique(datasets.probe_num):
            self.set_job_status(job_log = f'Sorting {probe_num}')
            dset = self._select_probe_files(datasets, probe_num)
            localfiles = self.get_files(dset, allowed_extensions = ['.ap.bin'])
            probepath = [f for f in localfiles if str(f).endswith('bin')]
            results_folder = self._run_sorter(probepath)
            self.set_job_status(job_log = f'Probe {probe_num} sorted, running post-processing.')
            self.postprocess_and_insert(results_folder,
                                        probe_num = probe_num,
                                        remove_duplicates = True,
                                        n_pre_samples = 45)
            self._shutdown_joblib_workers()

            if self.delete_results:
                print(f'[{self.name} job] Removing the results folder.')
                import shutil
                shutil.rmtree(results_folder)
            # delete local files if they did not exist
            if not self.files_existed:
                for f in localfiles:
                    os.unlink(f)

    def prepare_results(self,results_folder,
                        probe_num,
                        remove_duplicates,
                        n_pre_samples):
        '''Load a sorting result with spks and build the rows to insert.

        Parameters:
            results_folder: sorter output folder (loaded with spks.Clusters).
            probe_num: probe number added to the keys.
            remove_duplicates: first remove duplicate spikes (rewrites the phy
                files in ``results_folder``).
            n_pre_samples: stored in the SpikeSorting row.

        Returns:
            (clu, base_key, ssdict, udict, featurestosave): the Clusters object,
            the SpikeSorting key, the SpikeSorting row, one row per unit, and
            the arrays for the features file (``template_features`` is None when
            the result has no PC features).
        '''
        from spks import Clusters
        if remove_duplicates:
            clu = Clusters(results_folder, get_waveforms = False, get_metrics = False)
            clu.remove_duplicate_spikes(
                overwrite_phy = True,
                remove_cross_duplicates = self.parameters.get('remove_cross_duplicates', False))
            del clu
        # reload so that the (possibly rewritten) phy files are used
        clu = Clusters(results_folder, get_waveforms = False, get_metrics = False)
        clu.compute_template_amplitudes_and_depths()

        base_key = dict(self.dataset_key,
                        probe_num = probe_num,
                        parameter_set_num = self.parameter_set_num)
        ssdict = dict(base_key,
                      n_pre_samples = n_pre_samples,
                      n_sorted_units = len(clu),
                      n_detected_spikes = len(clu.spike_times),
                      sorting_datetime = datetime.fromtimestamp(
                          Path(results_folder).stat().st_ctime),
                      channel_indices = clu.channel_map.flatten(),
                      channel_coords = clu.channel_positions)
        udict = [] # unit
        for iclu in clu.cluster_id:
            idx = np.where(clu.spike_clusters == iclu)[0]
            udict.append(dict(
                base_key, unit_id = iclu,
                spike_positions = clu.spike_positions[idx,:].astype(np.float32),
                spike_times = clu.spike_times[idx].flatten().astype(np.uint64),
                spike_amplitudes = clu.spike_amplitudes[idx].flatten().astype(np.float32)))

        pc_features = clu.spike_pc_features
        featurestosave = dict(template_features = None if pc_features is None else pc_features.astype(np.float32),
                              spike_templates = clu.spike_templates,
                              cluster_indices = clu.spike_clusters,
                              whitening_matrix = clu.whitening_matrix,
                              templates = clu.templates,
                              template_feature_ind = clu.template_pc_features_ind)
        return clu, base_key, ssdict, udict, featurestosave

    def postprocess_and_insert(self,
                               results_folder,
                               probe_num,
                               remove_duplicates = True,
                               n_pre_samples = 45):
        '''Post-process one probe's sorting and insert it into the database.

        Saves features and waveform snippets to HDF5 files in
        ``results_folder``, uploads them as AnalysisFile entries, inserts the
        sync events found in the recording metadata, inserts SpikeSorting and
        its Unit / Waveforms / Segment parts, then populates UnitMetrics.
        '''
        from tqdm import tqdm
        # get the results in a dictionary and remove duplicates
        clu, base_key, ssdict, udict, featurestosave = self.prepare_results(results_folder,
                                                                            probe_num,
                                                                            remove_duplicates,
                                                                            n_pre_samples)
        folder = Path(results_folder)
        waveforms_file = folder/'waveforms.hdf5'
        features_file = folder/'features.hdf5'
        has_features = featurestosave['template_features'] is not None
        # save the features to a file, will take like 2 min
        if has_features:
            save_dict_to_h5(features_file, featurestosave)
        n_jobs = DEFAULT_N_JOBS # gets the default number of jobs from labdata
        # fewer workers for large sortings to prevent running out of memory
        n_jobs_wave = 2 if len(clu) > 800 else n_jobs
        udict, binaryfile, nchannels, res = self.extract_waveforms(udict,
                                                                   clu,
                                                                   results_folder,
                                                                   n_pre_samples,
                                                                   n_jobs_wave)
        waves_dict = []
        extras = dict(compression = 'gzip',
                      compression_opts = 1,
                      chunks = True,
                      shuffle = True)
        print('Collecting waveforms and saving.')
        if self.use_hdf5: # zarr not implemented yet.
            import h5py as h5
            with h5.File(waveforms_file,'w') as wavefile:
                for u, w in tqdm(zip(udict, res), desc = 'Saving waveforms to file'):
                    if w is None:
                        print(f"Unit {u['unit_id']} had no spikes extracted")
                        continue
                    median = np.median(w.astype(np.float32), axis = 0)*clu.channel_gains
                    waves_dict.append(dict(base_key,
                                           unit_id = u['unit_id'],
                                           waveform_median = median))
                    # save to the file
                    wavefile.create_dataset(str(u['unit_id'])+'/waveforms', data = w, **extras)
                    wavefile.create_dataset(str(u['unit_id'])+'/indices', data = u['waveform_indices'], **extras)

        stream_name = f'imec{probe_num}' # to save the events and files
        # only upload the features file when it was written
        src = [waveforms_file] + ([features_file] if has_features else [])
        dataset = dict(**self.dataset_key)
        dataset['dataset_name'] = f'spike_sorting/{stream_name}/{self.parameter_set_num}'
        from ..schema import AnalysisFile
        filekeys = AnalysisFile().upload_files(src, dataset)
        ssdict['waveforms_file'] = filekeys[0]['file_path']
        ssdict['waveforms_storage'] = filekeys[0]['storage']
        if has_features:
            ssdict['features_file'] = filekeys[1]['file_path']
            ssdict['features_storage'] = filekeys[1]['storage']

        # insert the syncs
        events = []
        for c in clu.metadata.keys():
            if 'sync_onsets' in c:
                for k in clu.metadata[c].keys():
                    events.append(dict(self.dataset_key,
                                       stream_name = stream_name,
                                       event_name = str(k),
                                       event_timestamps = clu.metadata[c][k].astype(np.uint64)))
        from ..schema import SpikeSorting, DatasetEvents
        if len(events):
            # Add stream
            DatasetEvents.insert1(dict(self.dataset_key,
                                       stream_name = stream_name),
                                  skip_duplicates = True, allow_direct_insert = True)
            DatasetEvents.Digital.insert(events,
                                         skip_duplicates = True,
                                         allow_direct_insert = True)

        # do all the inserts here
        import logging
        logging.getLogger('datajoint').setLevel(logging.WARNING)
        # these can't be done in a safe way quickly so if they fail we have delete SpikeSorting
        SpikeSorting.insert1(ssdict, skip_duplicates = True)
        # Insert datajoint in parallel.
        Parallel(n_jobs = n_jobs)(delayed(SpikeSorting.Unit.insert1)(
            u,
            skip_duplicates = True,
            ignore_extra_fields = True) for u in tqdm(udict))
        Parallel(n_jobs = n_jobs)(delayed(SpikeSorting.Waveforms.insert1)(
            u,
            skip_duplicates = True,
            ignore_extra_fields = True) for u in tqdm(waves_dict))
        # Add a 2 s segment from a random location.
        from spks.io import map_binary
        dat = map_binary(binaryfile, nchannels = nchannels)
        nsamples = int(clu.sampling_rate*2)
        last_offset = len(dat) - nsamples - 1
        if last_offset > nsamples:
            offset_samples = int(np.random.uniform(nsamples, last_offset))
        else:
            # recording too short to keep a margin at both ends
            offset_samples = 0
        SpikeSorting.Segment.insert1(dict(base_key,
                                          segment_num = 1,
                                          offset_samples = offset_samples,
                                          segment = np.array(dat[offset_samples : offset_samples + nsamples])))
        del dat
        self.set_job_status(job_log = f'Completed {base_key}')
        from labdata.schema import UnitMetrics
        # limit number of jobs because of memory constraints
        UnitMetrics.populate(base_key, processes = int(max(1, np.ceil(n_jobs/2))))

    def extract_waveforms(self,udict, clu, results_folder,n_pre_samples,n_jobs):
        '''Select spikes for each unit and extract their waveforms from the filtered recording.

        Returns:
            (udict, binaryfile, nchannels, res): ``udict`` with
            ``waveform_indices`` added to each unit, the filtered binary file,
            its channel count, and a generator with one waveform array (or
            None) per unit.

        Raises:
            FileNotFoundError: no filtered_recording*.bin in ``results_folder``.
        '''
        from spks.io import map_binary
        binaries = list(Path(results_folder).glob("filtered_recording*.bin"))
        if not binaries:
            raise FileNotFoundError(f'No filtered_recording*.bin file in {results_folder}')
        binaryfile = binaries[0]
        nchannels = clu.metadata['nchannels']
        dat = map_binary(binaryfile, nchannels = nchannels) # to get the duration
        duration = dat.shape[0]
        del dat

        udict = select_random_waveforms(udict,
                                        wpre = n_pre_samples,
                                        wpost = n_pre_samples,
                                        duration = duration)
        res = get_waveforms_from_binary(binaryfile, nchannels,
                                        [u['waveform_indices'] for u in udict],
                                        wpre = n_pre_samples,
                                        wpost = n_pre_samples,
                                        n_jobs = n_jobs)
        return udict, binaryfile, nchannels, res
|
|
418
|
+
|
|
419
|
+
def select_random_waveforms(unit_dict,
                            wpre = 45,
                            wpost = 45,
                            duration = None, # size of the file
                            nmax_waves = 500):
    '''Choose which spikes of each unit to extract waveforms for.

    Spikes come from the first quarter of the recording (after ``wpre + 2``)
    and from the last quarter (before ``duration - 2*wpost``). When a window
    holds more than ``nmax_waves`` spikes, ``nmax_waves`` distinct spikes are
    drawn at random from it.

    Parameters:
        unit_dict: list of dicts, each with a ``spike_times`` array (in
            samples). Modified in place.
        wpre, wpost: samples before/after each spike that will be extracted.
        duration: number of samples in the recording; when None, the largest
            spike time across units is used.
        nmax_waves: maximum number of spikes per window and unit.

    Returns:
        ``unit_dict``, with a sorted ``waveform_indices`` array in every unit.
    '''
    if duration is None:
        # units without spikes are skipped (np.max of an empty array raises)
        last_spikes = [np.max(u['spike_times']) for u in unit_dict if len(u['spike_times'])]
        duration = max(last_spikes) if last_spikes else 0
    # plain Python int: avoids unsigned wrap-around in duration - 2*wpost
    duration = int(duration)
    quarter = duration // 4

    def _sample(spikes):
        # without replacement, so no spike is selected twice
        if len(spikes) > nmax_waves:
            return np.random.choice(spikes, nmax_waves, replace = False)
        return spikes

    for u in unit_dict:
        s = np.asarray(u['spike_times'])
        s_begin = s[(s > (wpre + 2)) & (s < quarter)]
        s_end = s[(s > (3*quarter)) & (s < (duration - 2*wpost))]
        sel = np.concatenate([np.ravel(_sample(s_begin)), np.ravel(_sample(s_end))])
        u['waveform_indices'] = np.sort(sel)
    return unit_dict
|
|
442
|
+
|
|
443
|
+
def get_spike_waveforms(data,indices,wpre = 45,wpost = 45):
|
|
444
|
+
idx = np.arange(-wpre,wpost,dtype = np.int64)
|
|
445
|
+
waves = []
|
|
446
|
+
for i in indices.astype(np.int64):
|
|
447
|
+
waves.append(np.array(np.take(data,idx+i,axis = 0)))
|
|
448
|
+
if len(waves):
|
|
449
|
+
return np.stack(waves,dtype = data.dtype)
|
|
450
|
+
else:
|
|
451
|
+
return None
|
|
452
|
+
|
|
453
|
+
def get_waveforms_from_binary(binary_file,
                              binary_file_nchannels,
                              waveform_indices,
                              wpre = 45,
                              wpost = 45,
                              n_jobs = 8):
    '''Extract waveform snippets for several units from a binary recording, in parallel.

    The file is memory-mapped with ``spks.io.map_binary`` and
    get_spike_waveforms runs once per entry of ``waveform_indices`` with the
    joblib loky backend.

    Returns:
        a generator (joblib ``return_as='generator'``) yielding, per entry
        of ``waveform_indices``, a waveform array or None.
    '''
    from tqdm import tqdm
    from spks.io import map_binary
    recording = map_binary(binary_file, nchannels = binary_file_nchannels)
    # a generator hands results over as they are produced instead of
    # collecting every unit's waveforms in one list
    runner = Parallel(backend = 'loky', n_jobs = n_jobs, return_as = 'generator')
    return runner(delayed(get_spike_waveforms)(recording,
                                               unit_indices,
                                               wpre = wpre,
                                               wpost = wpost)
                  for unit_indices in tqdm(waveform_indices, desc = "Extracting waveforms"))
|