labdata 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
labdata/compute/ec2.py ADDED
@@ -0,0 +1,198 @@
+ from ..utils import *
+
+ def ec2_connect(access_key = None, secret_key = None, region = None):
+     import boto3
+     if 'aws' in prefs['compute'].keys():
+         if 'access_key' in prefs['compute']['aws'].keys():
+             access_key = prefs['compute']['aws']['access_key']
+             secret_key = prefs['compute']['aws']['secret_key']
+             region = prefs['compute']['aws']['region']
+     if access_key is None:
+         raise ValueError('Need to supply an access key to access ec2, set compute:aws:access_key in the preference file.')
+     if region[-1].isalpha(): # then it includes the availability zone
+         region = region[:-1]
+     botosession = boto3.Session(
+         aws_access_key_id = access_key,
+         aws_secret_access_key = secret_key,
+         region_name = region)
+
+     ec2 = botosession.resource('ec2', region_name = region)
+     return (botosession, ec2)
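
A minimal usage sketch (not part of the package), assuming the compute:aws keys are set in the preference file:

    # connect with credentials from the preference file and list instances
    session, ec2 = ec2_connect()
    for instance in ec2.instances.all():
        print(instance.id, instance.state['Name'])
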
+
+ def ec2_get_key(ec2 = None, keyname = None):
+     keyspath = Path(prefs['compute']['aws']['access_key_folder'])
+     keys = list(keyspath.glob('*'))
+     if not len(keys):
+         date = datetime.now().strftime('%Y%m%d_%H:%M:%S')
+         keyname = f"ec2-labdata-{prefs['hostname']}-{date}"
+         if ec2 is None:
+             session, ec2 = ec2_connect()
+         key = ec2.create_key_pair(KeyName = keyname)
+         # save the key info
+         keyspath.mkdir(parents = True, exist_ok = True)
+         with open(keyspath/keyname, 'w') as fd:
+             keyinfo = dict(key_name = key.key_name,
+                            key_material = key.key_material,
+                            key_pair_id = key.key_pair_id)
+             json.dump(keyinfo, fd, indent = 4)
+     else:
+         with open(keys[0], 'r') as fd:
+             keyinfo = json.load(fd)
+     return keyinfo
+
+ def ec2_instance_from_id(ec2, instance_id):
+     if ec2 is None:
+         session, ec2 = ec2_connect()
+     instances = list(ec2.instances.filter(InstanceIds = [instance_id]))
+     if not len(instances):
+         print(f'There are no instances with id: {instance_id}')
+     elif len(instances) != 1:
+         print(f'There are multiple instances with id: {instance_id}')
+         return instances
+     else:
+         return instances[0]
+
+ def ec2_create_instance(ec2,
+                         image_id = "linux",
+                         instance_type = "t2.micro",
+                         key_name = None,
+                         availability_zone = None,
+                         security_groups = None, # these should come from the preferences
+                         user_data = 'echo hostname'):
+     if image_id not in prefs['compute']['aws']['image_ids'].keys():
+         raise ValueError(f'image_id {image_id} is not in the preference file {list(prefs["compute"]["aws"]["image_ids"].keys())}')
+
+     if ec2 is None:
+         session, ec2 = ec2_connect()
+     if security_groups is None:
+         security_groups = prefs['compute']['aws']['security_groups']
+     image_id = prefs['compute']['aws']['image_ids'][image_id]
+     if key_name is None:
+         keyinfo = ec2_get_key(ec2)
+         key_name = keyinfo['key_name']
+     if availability_zone is None:
+         availability_zone = prefs['compute']['aws']['region']
+     insdict = dict(instance = ec2.create_instances(
+                        ImageId = image_id['ami'],
+                        MinCount = 1,
+                        MaxCount = 1,
+                        InstanceType = instance_type,
+                        KeyName = key_name,
+                        InstanceInitiatedShutdownBehavior = 'terminate',
+                        UserData = user_data,
+                        Placement = {'AvailabilityZone': availability_zone},
+                        SecurityGroups = security_groups)[0],
+                    key_name = key_name,
+                    instance_type = instance_type,
+                    user_name = image_id['user'],
+                    ami = image_id['ami'])
+     insdict['id'] = insdict['instance'].id
+     return insdict
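
A hedged launch sketch; the 'linux' entry must exist under compute:aws:image_ids in the preference file, and the boot script is a placeholder:

    session, ec2 = ec2_connect()
    ins = ec2_create_instance(ec2,
                              image_id = 'linux',        # must be defined in the preference file
                              instance_type = 't2.micro',
                              user_data = 'echo hello')  # placeholder boot script
    print('launched instance', ins['id'])
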
+
+ def ec2_wait_for_instance(ec2, instancedict, desired = 'running', interval = 0.05):
+     if ec2 is None:
+         session, ec2 = ec2_connect()
+
+     instance = ec2_instance_from_id(ec2, instancedict['id'])
+     instance.wait_until_running()
+     import time
+     while instance.state['Name'] != desired:
+         time.sleep(interval)
+         # re-fetch the instance to refresh its state
+         instance = ec2_instance_from_id(ec2, instancedict['id'])
+     instancedict['instance'] = instance
+     return instance
+
+
+ def ec2_instance_ssh(instance, user = 'ubuntu'):
+     import paramiko
+     from io import StringIO
+
+     ip_address = instance.public_dns_name
+
+     ssh = paramiko.SSHClient()
+     ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+
+     privkey = ec2_get_key()['key_material']
+     privkey = paramiko.RSAKey.from_private_key(StringIO(privkey))
+
+     print('SSH into the instance: {}'.format(ip_address))
+     ssh.connect(hostname = ip_address,
+                 username = user,
+                 pkey = privkey)
+     return ssh
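
One possible way to combine the helpers above, assuming ins is the dictionary returned by ec2_create_instance (sshd may not accept connections immediately after the state reaches 'running', so retrying connect() can be necessary):

    instance = ec2_wait_for_instance(ec2, ins)
    ssh = ec2_instance_ssh(instance, user = ins['user_name'])
    stdin, stdout, stderr = ssh.exec_command('uname -a')
    print(stdout.read().decode())
    ssh.close()
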
+
+ def ec2_cmd_for_launch(singularity_container,
+                        singularity_command,
+                        singularity_cuda = False,
+                        nvme = 'nvme1n1',
+                        shutdown_when_done = True,
+                        is_self_contained = True,
+                        append_log = True):
+     userdata = ''
+     if nvme is not None:
+         userdata += f'''#! /bin/bash
+ # mount the data drive and set permissions
+ mkfs -t ext4 /dev/{nvme}
+ mkdir /data
+ mount /dev/{nvme} /data/
+ chown -R ubuntu /data
+ # make home point to /data where there is space
+ export HOME="/data"
+ '''
+
+     storage = prefs['storage'][prefs['compute']['containers']['storage']]
+
+     userdata += f'''
+ mkdir $HOME/.aws
+ echo "[default]" > $HOME/.aws/credentials
+ echo aws_access_key_id={storage['access_key']} >> $HOME/.aws/credentials
+ echo aws_secret_access_key={storage['secret_key']} >> $HOME/.aws/credentials
+ mkdir -p $HOME/labdata/containers
+ echo "Downloading container"
+ aws s3 cp s3://{storage['bucket']}/containers/{singularity_container}.sif $HOME/labdata/containers/
+ '''
+     if is_self_contained:
+         # write a preference file so the instance can run the job on its own
+         ec2pref = json.dumps(dict(compute = dict(containers = dict(local = '/data'),
+                                                  analysis = prefs['compute']['analysis'],
+                                                  default_target = 'local'),
+                                   database = prefs['database'],
+                                   local_paths = ['/data'],
+                                   scratch_path = '/data',
+                                   path_rules = prefs['path_rules'],
+                                   storage = prefs['storage'],
+                                   allow_s3_download = True,
+                                   use_awscli = True))
+         userdata += f'''
+ cat > $HOME/labdata/user_preferences.json << EOL
+ {ec2pref}
+ EOL
+ '''
+     cuda = ''
+     if singularity_cuda:
+         cuda = '--nv'
+         userdata += '''
+ modprobe nvidia-uvm
+ nvidia-container-cli -k list
+ '''
+     userdata += f'''
+ mkdir -p /home/ubuntu/labdata
+ cp $HOME/labdata/user_preferences.json /home/ubuntu/labdata/
+ mkdir -p /home/ubuntu/.cache/torch/kernels
+ sudo chown -R ubuntu /home/ubuntu
+
+ sudo -u ubuntu bash -c "cd /home/ubuntu; singularity exec {cuda} --bind /data:/data $HOME/labdata/containers/{singularity_container}.sif {singularity_command}'''
+     if append_log is not None:
+         userdata += f' |& singularity exec $HOME/labdata/containers/{singularity_container}.sif labdata2 logpipe {append_log}'
+     userdata += '"'
+     if shutdown_when_done:
+         userdata += '''
+ shutdown -h now
+ '''
+     return userdata
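
Putting the pieces together, a sketch of how the boot script might be consumed (the container name, the command, and the instance type are illustrative; append_log = None skips the log pipe):

    user_data = ec2_cmd_for_launch('labdata-spks',
                                   'labdata2 run spks',  # illustrative command
                                   singularity_cuda = True,
                                   append_log = None)
    session, ec2 = ec2_connect()
    ins = ec2_create_instance(ec2,
                              image_id = 'linux',
                              instance_type = 'g4dn.2xlarge',
                              user_data = user_data)
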
@@ -0,0 +1,469 @@
+ from ..utils import *
+ from .utils import BaseCompute
+
+ class SpksCompute(BaseCompute):
+     container = 'labdata-spks'
+     cuda = True
+     ec2 = dict(small = dict(instance_type = 'g4dn.2xlarge'),   # 8 cpus, 32 GB mem, 200 GB nvme, 1 gpu
+                large = dict(instance_type = 'g6.4xlarge',
+                             availability_zone = 'us-west-2b')) # 16 cpus, 64 GB mem, 600 GB nvme, 1 gpu
+     name = 'spks'
+     url = 'http://github.com/spkware/spks'
+     def __init__(self, job_id, allow_s3 = None, delete_results = True, **kwargs):
+         '''
+         1) find the files
+         2) copy just the file you need to scratch
+         3) run spike sorting on that file/folder
+         4) delete the raw files
+         5) repeat until all probes are processed
+         '''
+         super(SpksCompute, self).__init__(job_id, allow_s3 = allow_s3)
+         self.file_filters = ['.ap.']
+         # default parameters
+         self.parameters = dict(algorithm_name = 'spks_kilosort4.0',
+                                motion_correction = True,
+                                low_pass = 300.,
+                                high_pass = 13000.)
+         self.use_hdf5 = True # flag to use the h5py (rather than zarr) format for the waveforms
+         self.parameter_set_num = None # identifier in SpikeSortingParams
+         self._init_job()
+         if self.job_id is not None:
+             self.add_parameter_key()
+         self.delete_results = delete_results
+
+     def _get_parameter_number(self):
+         parameter_set_num = None
+         from ..schema import SpikeSorting, SpikeSortingParams, EphysRecording
+         # check if these parameters are already in SpikeSortingParams
+         parameters = pd.DataFrame(SpikeSortingParams().fetch())
+         for i, r in parameters.iterrows():
+             # go through every stored parameter set
+             if self.parameters == json.loads(r.parameters_dict):
+                 parameter_set_num = r.parameter_set_num
+         if parameter_set_num is None:
+             if len(parameters) == 0:
+                 parameter_set_num = 1
+             else:
+                 parameter_set_num = np.max(parameters.parameter_set_num.values) + 1
+         return parameter_set_num, parameters
+
+     def add_parameter_key(self):
+         parameter_set_num, parameters = self._get_parameter_number()
+         from ..schema import SpikeSorting, SpikeSortingParams, EphysRecording
+         if parameter_set_num not in parameters.parameter_set_num.values:
+             SpikeSortingParams().insert1(dict(parameter_set_num = parameter_set_num,
+                                               algorithm_name = self.parameters['algorithm_name'],
+                                               parameters_dict = json.dumps(self.parameters),
+                                               code_link = self.url),
+                                          skip_duplicates = True)
+         self.parameter_set_num = parameter_set_num
+         recordings = EphysRecording.ProbeSetting() & dict(self.dataset_key)
+         sortings = SpikeSorting() & dict(self.dataset_key, parameter_set_num = self.parameter_set_num)
+         if len(recordings) == len(sortings):
+             self.set_job_status(
+                 job_status = 'FAILED',
+                 job_waiting = 0,
+                 job_log = f'{self.dataset_key} was already sorted with parameters {self.parameter_set_num}.')
+             raise ValueError(f'{self.dataset_key} was already sorted with parameters {self.parameter_set_num}.')
+
+     def _secondary_parse(self, arguments, parameter_number = None):
+         '''
+         Handles parsing the command line interface
+         '''
+         if parameter_number is not None:
+             from ..schema import SpikeSortingParams
+             self.parameters = (SpikeSortingParams() & f'parameter_set_num = {parameter_number}').fetch(as_dict = True)
+             if not len(self.parameters):
+                 raise ValueError(f'Could not find parameter {parameter_number} in SpikeSortingParams.')
+             self.parameters = self.parameters[0]
+         else:
+             import argparse
+             parser = argparse.ArgumentParser(
+                 description = 'Analysis of spike data through the spks package (kilosort 2.5, 3.0 or 4.0).',
+                 usage = 'spks -a <SUBJECT> -s <SESSION> -- <PARAMETERS>')
+
+             parser.add_argument('-p', '--probe',
+                                 action = 'store', default = None, type = int,
+                                 help = "THIS DOES NOTHING NOW. WILL BE FOR OPENING PHY")
+             parser.add_argument('-m', '--method',
+                                 action = 'store', default = 'ks2.5', type = str,
+                                 help = 'Method for spike sorting: ks2.5 [Kilosort], ks3.0, ks4.0, ms5 [MountainSort]')
+             parser.add_argument('-l', '--low-pass',
+                                 action = 'store', default = self.parameters['low_pass'], type = float,
+                                 help = "Lowpass filter (default 300 Hz)")
+             parser.add_argument('-i', '--high-pass',
+                                 action = 'store', default = self.parameters['high_pass'], type = float,
+                                 help = "Highpass filter (default 13000 Hz)")
+             parser.add_argument('-t', '--thresholds',
+                                 action = 'store', default = None, type = float, nargs = 2,
+                                 help = "Thresholds for spike detection; the default depends on the method.")
+             parser.add_argument('-n', '--no-motion-correction',
+                                 action = 'store_false', default = True,
+                                 help = "Skip motion correction")
+             parser.add_argument('-c', '--remove_cross-unit-duplicates',
+                                 action = 'store_true', default = False,
+                                 help = "Remove duplicate spikes across units.")
+
+             args = parser.parse_args(arguments[1:])
+             if 'ks2.5' in args.method: # defaults for ks2.5
+                 self.parameters = dict(algorithm_name = 'spks_kilosort2.5',
+                                        motion_correction = args.no_motion_correction,
+                                        low_pass = args.low_pass,
+                                        high_pass = args.high_pass,
+                                        thresholds = [9., 3.],
+                                        remove_cross_duplicates = args.remove_cross_unit_duplicates)
+             elif 'ks3.0' in args.method: # defaults for ks3.0
+                 self.parameters = dict(algorithm_name = 'spks_kilosort3.0',
+                                        motion_correction = args.no_motion_correction,
+                                        low_pass = args.low_pass,
+                                        high_pass = args.high_pass,
+                                        thresholds = [9., 9.],
+                                        remove_cross_duplicates = args.remove_cross_unit_duplicates)
+             elif 'ks4.0' in args.method: # defaults for ks4.0
+                 self.parameters = dict(algorithm_name = 'spks_kilosort4.0',
+                                        motion_correction = args.no_motion_correction,
+                                        low_pass = args.low_pass,
+                                        high_pass = args.high_pass,
+                                        thresholds = [9., 8.],
+                                        remove_cross_duplicates = args.remove_cross_unit_duplicates)
+             else:
+                 raise NotImplementedError(f'{args.method} not implemented.')
+
+             if args.thresholds is not None:
+                 self.parameters['thresholds'] = args.thresholds
+
+             self.probe = args.probe
+
+     def find_datasets(self, subject_name = None, session_name = None):
+         '''
+         Searches for subjects and sessions in EphysRecording
+         '''
+         if subject_name is None and session_name is None:
+             print("\n\nPlease specify a 'subject_name' and a 'session_name' to perform spike-sorting.\n\n")
+         from ..schema import EphysRecording, SpikeSorting
+
+         keys = []
+         if subject_name is not None:
+             if len(subject_name) > 1:
+                 raise ValueError(f'Please submit one subject at a time {subject_name}.')
+             if not subject_name[0] == '':
+                 subject_name = subject_name[0]
+                 if session_name is not None:
+                     for s in session_name:
+                         if not s == '':
+                             keys.append(dict(subject_name = subject_name,
+                                              session_name = s))
+                 else:
+                     # find all the sessions that can be spike sorted
+                     parameter_set_num, parameters = self._get_parameter_number()
+                     sessions = np.unique((
+                         (EphysRecording() & f'subject_name = "{subject_name}"') -
+                         (SpikeSorting() & f'parameter_set_num = {parameter_set_num}')).fetch('session_name'))
+                     for ses in sessions:
+                         keys.append(dict(subject_name = subject_name,
+                                          session_name = ses))
+         datasets = []
+         for k in keys:
+             datasets += (EphysRecording() & k).proj('subject_name', 'session_name', 'dataset_name').fetch(as_dict = True)
+         return datasets
+
+     def _compute(self):
+         from ..schema import EphysRecording
+         datasets = pd.DataFrame((EphysRecording.ProbeFile() & self.dataset_key).fetch())
+
+         for probe_num in np.unique(datasets.probe_num):
+             self.set_job_status(job_log = f'Sorting {probe_num}')
+             files = datasets[datasets.probe_num.values == probe_num]
+             dset = []
+             for i, f in files.iterrows():
+                 if 'ap.cbin' in f.file_path or 'ap.ch' in f.file_path:
+                     dset.append(i)
+                 elif 'ap.meta' in f.file_path: # requires a metadata file (spikeglx)
+                     dset.append(i)
+             dset = files.loc[dset]
+             if not len(dset):
+                 print(files)
+                 raise ValueError(f'Could not find ap.cbin files for probe {probe_num}')
+             localfiles = self.get_files(dset, allowed_extensions = ['.ap.bin'])
+             probepath = list(filter(lambda x: str(x).endswith('bin'), localfiles))
+             if 'kilosort' in self.parameters['algorithm_name']:
+                 from spks.sorting import run_kilosort
+                 # the supported names are spks_kilosort2.5, spks_kilosort3.0 and
+                 # spks_kilosort4.0; the version is everything after "kilosort"
+                 version = self.parameters['algorithm_name'].split('kilosort')[-1]
+                 if version not in ['2.5', '3.0', '4.0']:
+                     raise NotImplementedError(f"[{self.name} job] - Algorithm {self.parameters['algorithm_name']} not implemented.")
+                 results_folder = run_kilosort(version = version,
+                                               sessionfiles = probepath,
+                                               temporary_folder = prefs['scratch_path'],
+                                               do_post_processing = False,
+                                               motion_correction = self.parameters['motion_correction'],
+                                               thresholds = self.parameters['thresholds'],
+                                               lowpass = self.parameters['low_pass'],
+                                               highpass = self.parameters['high_pass'])
+             else: # includes spks_mountainsort5, which is not implemented yet
+                 raise NotImplementedError(f"[{self.name} job] - Algorithm {self.parameters['algorithm_name']} not implemented.")
+             self.set_job_status(job_log = f'Probe {probe_num} sorted, running post-processing.')
+             self.postprocess_and_insert(results_folder,
+                                         probe_num = probe_num,
+                                         remove_duplicates = True,
+                                         n_pre_samples = 45)
+             self.unregister_safe_exit() # in case these get triggered by shutdown
+             try:
+                 from joblib.externals.loky import get_reusable_executor
+                 get_reusable_executor().shutdown(wait = True)
+             except Exception:
+                 print(f'[{self.name} job] Tried to clear the joblib loky executors and failed.')
+             self.register_safe_exit() # put it back
+
+             if self.delete_results:
+                 print(f'[{self.name} job] Removing the results folder.')
+                 import shutil
+                 shutil.rmtree(results_folder)
+                 # delete the local files if they did not exist before the job
+                 if not self.files_existed:
+                     for f in localfiles:
+                         os.unlink(f)
+
+     def prepare_results(self, results_folder,
+                         probe_num,
+                         remove_duplicates,
+                         n_pre_samples):
+         from spks import Clusters
+         if remove_duplicates:
+             clu = Clusters(results_folder, get_waveforms = False, get_metrics = False)
+             clu.remove_duplicate_spikes(
+                 overwrite_phy = True,
+                 remove_cross_duplicates = self.parameters['remove_cross_duplicates'])
+             del clu
+         clu = Clusters(results_folder, get_waveforms = False, get_metrics = False)
+         clu.compute_template_amplitudes_and_depths()
+         # the waveforms are extracted later, from the binary file
+
+         base_key = dict(self.dataset_key,
+                         probe_num = probe_num,
+                         parameter_set_num = self.parameter_set_num)
+         ssdict = dict(base_key,
+                       n_pre_samples = n_pre_samples,
+                       n_sorted_units = len(clu),
+                       n_detected_spikes = len(clu.spike_times),
+                       sorting_datetime = datetime.fromtimestamp(
+                           Path(results_folder).stat().st_ctime),
+                       channel_indices = clu.channel_map.flatten(),
+                       channel_coords = clu.channel_positions)
+         udict = [] # one entry per unit
+         for iclu in clu.cluster_id:
+             idx = np.where(clu.spike_clusters == iclu)[0]
+             udict.append(dict(
+                 base_key, unit_id = iclu,
+                 spike_positions = clu.spike_positions[idx, :].astype(np.float32),
+                 spike_times = clu.spike_times[idx].flatten().astype(np.uint64),
+                 spike_amplitudes = clu.spike_amplitudes[idx].flatten().astype(np.float32)))
+
+         featurestosave = dict(template_features = clu.spike_pc_features.astype(np.float32),
+                               spike_templates = clu.spike_templates,
+                               cluster_indices = clu.spike_clusters,
+                               whitening_matrix = clu.whitening_matrix,
+                               templates = clu.templates,
+                               template_feature_ind = clu.template_pc_features_ind)
+         return clu, base_key, ssdict, udict, featurestosave
+
+     def postprocess_and_insert(self,
+                                results_folder,
+                                probe_num,
+                                remove_duplicates = True,
+                                n_pre_samples = 45):
+         '''Runs the post-processing for a spike sorting and inserts the results.'''
+         # get the results in a dictionary and remove duplicates
+         clu, base_key, ssdict, udict, featurestosave = self.prepare_results(results_folder,
+                                                                             probe_num,
+                                                                             remove_duplicates,
+                                                                             n_pre_samples)
+         # save the features to a file; this can take a couple of minutes
+         if featurestosave['template_features'] is not None:
+             save_dict_to_h5(Path(results_folder)/'features.hdf5', featurestosave)
+         n_jobs = DEFAULT_N_JOBS # gets the default number of jobs from labdata
+         # extract the waveforms from the binary file
+         n_jobs_wave = n_jobs
+         if len(clu) > 800:
+             n_jobs_wave = 2 # to prevent running out of memory when collecting waveforms
+         udict, binaryfile, nchannels, res = self.extract_waveforms(udict,
+                                                                    clu,
+                                                                    results_folder,
+                                                                    n_pre_samples,
+                                                                    n_jobs_wave)
+         def median_waves(r, gains):
+             if r is not None:
+                 return np.median(r.astype(np.float32), axis = 0)*gains
+             else:
+                 return None
+         waves_dict = []
+         extras = dict(compression = 'gzip',
+                       compression_opts = 1,
+                       chunks = True,
+                       shuffle = True)
+         from tqdm import tqdm
+         print('Collecting waveforms and saving.')
+         # these could be saved to zarr to compress faster
+         if self.use_hdf5: # zarr is not implemented yet
+             import h5py as h5
+             with h5.File(Path(results_folder)/'waveforms.hdf5', 'w') as wavefile:
+                 for u, w in tqdm(zip(udict, res), desc = 'Saving waveforms to file'):
+                     m = median_waves(w, gains = clu.channel_gains)
+                     if w is not None:
+                         waves_dict.append(dict(base_key,
+                                                unit_id = u['unit_id'],
+                                                waveform_median = m))
+                         # save to the file
+                         wavefile.create_dataset(str(u['unit_id'])+'/waveforms', data = w, **extras)
+                         wavefile.create_dataset(str(u['unit_id'])+'/indices', data = u['waveform_indices'], **extras)
+                     else:
+                         print(f"Unit {u['unit_id']} had no spikes extracted")
+         stream_name = f'imec{probe_num}' # to save the events and files
+         src = [Path(results_folder)/'waveforms.hdf5', Path(results_folder)/'features.hdf5']
+         dataset = dict(**self.dataset_key)
+         dataset['dataset_name'] = f'spike_sorting/{stream_name}/{self.parameter_set_num}'
+         from ..schema import AnalysisFile
+         filekeys = AnalysisFile().upload_files(src, dataset)
+         ssdict['waveforms_file'] = filekeys[0]['file_path']
+         ssdict['waveforms_storage'] = filekeys[0]['storage']
+         if featurestosave['template_features'] is not None:
+             ssdict['features_file'] = filekeys[1]['file_path']
+             ssdict['features_storage'] = filekeys[1]['storage']
+         # insert the syncs
+         events = []
+         for c in clu.metadata.keys():
+             if 'sync_onsets' in c:
+                 for k in clu.metadata[c].keys():
+                     events.append(dict(self.dataset_key,
+                                        stream_name = stream_name,
+                                        event_name = str(k),
+                                        event_timestamps = clu.metadata[c][k].astype(np.uint64)))
+         from ..schema import SpikeSorting, SpikeSortingParams, EphysRecording, DatasetEvents
+         if len(events):
+             # add the stream
+             DatasetEvents.insert1(dict(self.dataset_key,
+                                        stream_name = stream_name),
+                                   skip_duplicates = True, allow_direct_insert = True)
+             DatasetEvents.Digital.insert(events,
+                                          skip_duplicates = True,
+                                          allow_direct_insert = True)
+
+         # do all the inserts here
+         import logging
+         logging.getLogger('datajoint').setLevel(logging.WARNING)
+         # these can't be done in a safe way quickly, so if they fail we have to delete SpikeSorting
+         SpikeSorting.insert1(ssdict, skip_duplicates = True)
+         # insert in datajoint in parallel
+         Parallel(n_jobs = n_jobs)(delayed(SpikeSorting.Unit.insert1)(
+             u,
+             skip_duplicates = True,
+             ignore_extra_fields = True) for u in tqdm(udict))
+         Parallel(n_jobs = n_jobs)(delayed(SpikeSorting.Waveforms.insert1)(
+             u,
+             skip_duplicates = True,
+             ignore_extra_fields = True) for u in tqdm(waves_dict))
+         # add a segment from a random location
+         from spks.io import map_binary
+         dat = map_binary(binaryfile, nchannels = nchannels)
+         nsamples = int(clu.sampling_rate*2)
+         offset_samples = int(np.random.uniform(nsamples, len(dat) - nsamples - 1))
+         SpikeSorting.Segment.insert1(dict(base_key,
+                                           segment_num = 1,
+                                           offset_samples = offset_samples,
+                                           segment = np.array(dat[offset_samples : offset_samples + nsamples])))
+         del dat
+         self.set_job_status(job_log = f'Completed {base_key}')
+         from labdata.schema import UnitMetrics
+         # limit the number of processes because of memory constraints
+         UnitMetrics.populate(base_key, processes = int(max(1, np.ceil(n_jobs/2))))
+
+     def extract_waveforms(self, udict, clu, results_folder, n_pre_samples, n_jobs):
+         # extract the waveforms from the filtered binary file
+         from spks.io import map_binary
+         binaryfile = list(Path(results_folder).glob("filtered_recording*.bin"))[0]
+         nchannels = clu.metadata['nchannels']
+         dat = map_binary(binaryfile, nchannels = nchannels) # just to get the duration
+
+         udict = select_random_waveforms(udict,
+                                         wpre = n_pre_samples,
+                                         wpost = n_pre_samples,
+                                         duration = dat.shape[0])
+         del dat
+         res = get_waveforms_from_binary(binaryfile, nchannels,
+                                         [u['waveform_indices'] for u in udict],
+                                         wpre = n_pre_samples,
+                                         wpost = n_pre_samples,
+                                         n_jobs = n_jobs)
+         return udict, binaryfile, nchannels, res
+
+ def select_random_waveforms(unit_dict,
+                             wpre = 45,
+                             wpost = 45,
+                             duration = None, # size of the file in samples
+                             nmax_waves = 500):
+
+     if duration is None:
+         duration = np.max([np.max(u['spike_times']) for u in unit_dict])
+     for u in unit_dict:
+         s = u['spike_times']
+         s_begin = s[(s > (wpre + 2)) & (s < (duration//4))]
+         s_end = s[(s > (3*(duration//4))) & (s < (duration - 2*wpost))]
+         sel = []
+         if len(s_begin) > nmax_waves:
+             sel = [t for t in np.random.choice(s_begin, nmax_waves, replace = True)]
+         else:
+             sel = [t for t in s_begin]
+         if len(s_end) > nmax_waves:
+             sel += [t for t in np.random.choice(s_end, nmax_waves, replace = True)]
+         else:
+             sel += [t for t in s_end]
+         u['waveform_indices'] = np.sort(np.array(sel).flatten()) # add the selected indices to the unit dict
+     return unit_dict
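
A toy check of the selection policy above (the numbers are made up, roughly 30000 samples/s for one minute): spikes are drawn only from the first and last quarters of the recording, capped at nmax_waves per end.

    import numpy as np
    unit = dict(unit_id = 0,
                spike_times = np.sort(np.random.randint(100, 30000*60, 5000)))
    out = select_random_waveforms([unit], wpre = 45, wpost = 45,
                                  duration = 30000*60, nmax_waves = 500)
    print(len(out[0]['waveform_indices'])) # at most 2*nmax_waves indices
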
+
+ def get_spike_waveforms(data, indices, wpre = 45, wpost = 45):
+     idx = np.arange(-wpre, wpost, dtype = np.int64)
+     waves = []
+     for i in indices.astype(np.int64):
+         waves.append(np.array(np.take(data, idx + i, axis = 0)))
+     if len(waves):
+         return np.stack(waves, dtype = data.dtype)
+     else:
+         return None
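
A self-contained toy example of the windowing above (the array contents are arbitrary): each index yields a window of wpre+wpost samples across all channels.

    import numpy as np
    data = np.arange(40, dtype = np.int16).reshape(20, 2) # 20 samples, 2 channels
    waves = get_spike_waveforms(data, np.array([5, 10]), wpre = 2, wpost = 2)
    print(waves.shape) # (2 spikes, 4 samples, 2 channels)
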
+
+ def get_waveforms_from_binary(binary_file,
+                               binary_file_nchannels,
+                               waveform_indices,
+                               wpre = 45,
+                               wpost = 45,
+                               n_jobs = 8):
+     from tqdm import tqdm
+     from spks.io import map_binary
+     dat = map_binary(binary_file, nchannels = binary_file_nchannels)
+     # return a generator to avoid using huge amounts of memory
+     res = Parallel(backend = 'loky', n_jobs = n_jobs, return_as = 'generator')(delayed(get_spike_waveforms)(
+         dat,
+         w,
+         wpre = wpre,
+         wpost = wpost) for w in tqdm(
+             waveform_indices, desc = "Extracting waveforms"))
+     return res
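
A hedged usage sketch; the file name, channel count, and indices are placeholders. Because the result is a generator, the waveforms of one unit can be consumed and discarded before the next unit is extracted.

    import numpy as np
    unit_indices = [np.array([1000, 5000, 9000]), np.array([2000, 6000])]
    waves = get_waveforms_from_binary('filtered_recording.bin', 384,
                                      unit_indices, wpre = 45, wpost = 45, n_jobs = 2)
    for w in waves: # one (nspikes, nsamples, nchannels) array per unit
        print(None if w is None else w.shape)
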