UnitMatchPy 1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
UnitMatchPy/utils.py ADDED
@@ -0,0 +1,441 @@
1
+ # utility function for loading files etc
2
+ import numpy as np
3
+ import pandas as pd
4
+ import os
5
+ import matplotlib.pyplot as plt
6
+
7
def load_tsv(path):
    """
    Read a tab-separated file with pandas and return its contents as a numpy array.

    The first line of the file is taken as the column header, so it is not part of
    the returned array.
    """
    table = pd.read_csv(path, delimiter='\t')
    return table.values
13
+
14
def get_session_number(unitid, SessionSwitch):
    """
    Return the index of the session that unit `unitid` falls in.

    SessionSwitch holds the session boundaries [0, end of session 1, ..., end of final session];
    session i covers SessionSwitch[i] <= unitid < SessionSwitch[i + 1].
    Returns None when unitid lies outside every session.
    """
    bounds = zip(SessionSwitch[:-1], SessionSwitch[1:])
    for session, (start, stop) in enumerate(bounds):
        if start <= unitid < stop:
            return session
    return None
19
+
20
def get_session_data(nUnitsPerSession):
    """
    Derive per-unit session bookkeeping from the number of units in each day/session.

    nUnitsPerSession must be a numpy array (its .sum() method is used).
    Returns (nUnits, sessionid, SessionSwitch, nSessions):
        nUnits        - total number of units
        sessionid     - int array, one entry per unit, holding the index of the session the unit is in
        SessionSwitch - indices where the session switches: [0, end of session 1, ..., end of final session]
        nSessions     - number of sessions
    """
    nSessions = len(nUnitsPerSession)
    nUnits = nUnitsPerSession.sum()

    SessionSwitch = np.insert(np.cumsum(nUnitsPerSession), 0, 0)
    # session i repeated nUnitsPerSession[i] times gives each unit its session label
    sessionid = np.repeat(np.arange(nSessions, dtype=int), nUnitsPerSession)

    return nUnits, sessionid, SessionSwitch, nSessions
36
+
37
def get_within_session(sessionid, param):
    """
    Build an nUnits x nUnits float matrix flagging pairs of units from different sessions.

    Entry (i, j) is 0 when units i and j share a session id and 1 when they come from
    different sessions. The matrix size is read from param['nUnits'], which must equal
    len(sessionid).
    """
    n = param['nUnits']
    same_session = np.equal.outer(sessionid, sessionid)

    WithinSession = np.ones((n, n))
    WithinSession[same_session] = 0
    return WithinSession
51
+
52
def get_default_param(param = None):
    """
    Create param, a dictionary with the default parameters.

    If a dictionary is given, the defaults are merged into a new dictionary without
    overwriting any value the given dictionary already holds. The given dictionary itself
    is not modified. You do not need to give a dictionary.

    'ScoreVector' and 'Bins' are derived from the final 'stepsz', so a user-supplied
    'stepsz' is respected. If 'ScoreVector' or 'Bins' are themselves supplied, they are
    kept unchanged.
    """
    defaults = {'SpikeWidth' : 82, 'waveidx' : np.arange(33,56), 'ChannelRadius' : 150,
                'PeakLoc' : 40, 'MaxDist' : 100, 'NeighbourDist' : 50, 'stepsz' : 0.01,
                'SmoothProb' : 9, 'MinAngleDist' : 0.1, 'NoShanks' : 4, 'ShankDist' : 175,
                'MatchNumThreshold' : 15, 'MatchThreshold' : 0.5
                }

    # if no dictionary is given just use the defaults; otherwise values already in param win
    out = defaults if param is None else defaults | param
    supplied = {} if param is None else param

    # derive the score/bin grids after merging so they agree with the step size actually used
    stepsz = out['stepsz']
    if 'ScoreVector' not in supplied:
        out['ScoreVector'] = np.arange(stepsz / 2, 1, stepsz)
    if 'Bins' not in supplied:
        out['Bins'] = np.arange(0, 1 + stepsz, stepsz)
    return out
73
+
74
def _load_session_waveforms(WavePath, UnitIdx):
    """
    Load Unit{id}_RawSpikes.npy from the folder WavePath for every unit id in UnitIdx and
    stack them into one float array whose first axis is the unit.

    The shape of each waveform is taken from the first file. For that reason, an empty
    UnitIdx raises ValueError.
    """
    if len(UnitIdx) == 0:
        raise ValueError(f'No units to load from {WavePath}')

    def unit_file(uid):
        # uid may be a length-1 row of an (n, 1) id array; squeeze it to a scalar id
        return os.path.join(WavePath, f'Unit{int(np.squeeze(uid))}_RawSpikes.npy')

    first = np.load(unit_file(UnitIdx[0]))
    SessionWaveform = np.zeros((len(UnitIdx),) + first.shape)
    SessionWaveform[0] = first
    for i, uid in enumerate(UnitIdx[1:], start=1):
        SessionWaveform[i] = np.load(unit_file(uid))
    return SessionWaveform


def load_good_waveforms(WavePaths, UnitLabelPaths, param, GoodUnitsOnly = True):
    """
    Load the averaged waveforms of several sessions and build the session metadata.

    This is the recommended way to read in data. For each session, the unit label .tsv
    file is read: the first column is the unit id and the second is the label. Then
    Unit{id}_RawSpikes.npy is loaded from the matching waveform folder for every
    selected unit id.

    Parameters
    ----------
    WavePaths : list of folders holding the Unit{id}_RawSpikes.npy files, one per session
    UnitLabelPaths : list of unit label .tsv files, one per session
    param : dict, updated in place with 'nUnits', 'nSessions', 'nChannels' and
        'nUnitsPerSession' (the number of rows in each label file). If the loaded spike
        width differs from param['SpikeWidth'], then 'SpikeWidth', 'PeakLoc' and
        'waveidx' are also updated.
    GoodUnitsOnly : if True, only units labelled 'good' are loaded; otherwise every unit
        listed in the label file is loaded

    Returns
    -------
    waveform, sessionid, SessionSwitch, WithinSession, GoodUnits, param
        GoodUnits is a list (one entry per session) of (n, 1) arrays holding the ids of
        the units that were loaded, in the order they appear in `waveform`.
        Returns None (after printing a warning) if the two path lists differ in length.
    """
    if len(WavePaths) != len(UnitLabelPaths):
        print('Warning: gave different number of paths for waveforms and labels!')
        return

    GoodUnits = []
    nUnitsPerSessionALL = []
    for LabelPath in UnitLabelPaths:
        UnitLabel = load_tsv(LabelPath)
        if GoodUnitsOnly:
            # (k, 1) array with the ids (first column) of the rows labelled 'good'
            UnitIdx = UnitLabel[np.argwhere(UnitLabel[:, 1] == 'good'), 0]
        else:
            # (n, 1) array with every unit id in the label file, same layout as above
            UnitIdx = UnitLabel[:, [0]]
        GoodUnits.append(UnitIdx)
        nUnitsPerSessionALL.append(UnitLabel.shape[0])

    waveforms = [_load_session_waveforms(WavePath, UnitIdx)
                 for WavePath, UnitIdx in zip(WavePaths, GoodUnits)]

    # one concatenate instead of re-copying a growing array once per session
    waveform = np.concatenate(waveforms, axis = 0)
    nUnitsPerSession = np.array([w.shape[0] for w in waveforms], dtype = 'int')

    param['nUnits'], sessionid, SessionSwitch, param['nSessions'] = get_session_data(nUnitsPerSession)
    WithinSession = get_within_session(sessionid, param)
    param['nChannels'] = waveform.shape[2]
    param['nUnitsPerSession'] = nUnitsPerSessionALL

    # if the set of default parameters has a different spike width, update these parameters
    if param['SpikeWidth'] != waveform.shape[1]:
        param['SpikeWidth'] = waveform.shape[1]
        param['PeakLoc'] = np.floor(waveform.shape[1]/2).astype(int)
        param['waveidx'] = np.arange(param['PeakLoc'] - 8, param['PeakLoc'] + 15, dtype = int)

    return waveform, sessionid, SessionSwitch, WithinSession, GoodUnits, param
155
+
156
def get_good_units(UnitLabelPaths, good = True):
    """
    Read the unit ids from a list of unit label .tsv files, one file per session.

    Each .tsv holds the unit id in its first column and the unit label in its second.
    Returns a list with one (n, 1) numpy array of unit ids per session.
    By default only units labelled 'good' are returned; pass good = False to get ALL unit ids.
    """
    GoodUnits = []
    for LabelPath in UnitLabelPaths:
        UnitLabel = load_tsv(LabelPath)
        if good:
            # (k, 1) array with the ids of the rows labelled 'good'
            GoodUnitIdx = UnitLabel[np.argwhere(UnitLabel[:, 1] == 'good'), 0]
        else:
            # every id in the first column, kept in the same (n, 1) layout as above.
            # The ids must not be used as row indices: they need not be 0..n-1, and the
            # column is object dtype.
            GoodUnitIdx = UnitLabel[:, [0]]
        GoodUnits.append(GoodUnitIdx)
    return GoodUnits
172
+
173
def load_good_units(GoodUnits, WavePaths, param):
    """
    Load the averaged waveforms of the requested units for every session and build the session metadata.

    GoodUnits is a list (one entry per session) of unit-id arrays. For each id, the file
    Unit{id}_RawSpikes.npy is read from the matching folder in WavePaths.
    param is updated in place with 'nUnits', 'nSessions' and 'nChannels'.

    Returns waveform, sessionid, SessionSwitch, WithinSession, param.
    Returns None (after printing a warning) if GoodUnits and WavePaths differ in length.
    """
    if len(WavePaths) != len(GoodUnits):
        print('Warning: gave different number of paths for waveforms and labels!')
        return

    per_session = []
    for folder, unit_ids in zip(WavePaths, GoodUnits):
        # the first requested unit fixes the shape of every waveform in this session
        probe = np.load(os.path.join(folder, f'Unit{int(unit_ids[0].squeeze())}_RawSpikes.npy'))
        stack = np.zeros((len(unit_ids), probe.shape[0], probe.shape[1], probe.shape[2]))
        for row, uid in enumerate(unit_ids):
            stack[row] = np.load(os.path.join(folder, f'Unit{int(uid.squeeze())}_RawSpikes.npy'))
        per_session.append(stack)

    # stack all sessions along the unit axis
    waveform = np.concatenate(per_session, axis = 0) if per_session else np.array([])
    nUnitsPerSession = np.array([stack.shape[0] for stack in per_session], dtype = 'int')

    param['nUnits'], sessionid, SessionSwitch, param['nSessions'] = get_session_data(nUnitsPerSession)
    WithinSession = get_within_session(sessionid, param)
    param['nChannels'] = waveform.shape[2]
    return waveform, sessionid, SessionSwitch, WithinSession, param
218
+
219
+
220
def compare_units(AvgWaveform, AvgCentroid, unit1, unit2):
    """
    Quick helper to compare two units by eye.

    Plots AvgWaveform[:, unit, 0] (cv 0) for both units on the current matplotlib axes.
    Then prints AvgCentroid[:, unit, 0] for each unit.
    """
    pair = (unit1, unit2)
    for unit in pair:
        plt.plot(AvgWaveform[:, unit, 0])
    for unit in pair:
        print(f'Average centroid of unit {unit} is :{AvgCentroid[:,unit,0]}')
228
+
229
+
230
def evaluate_output(output, param, WithinSession, SessionSwitch, MatchThreshold = 0.5):
    """
    Print summary statistics of the match probabilities, assuming the spike sorter made no mistakes.

    Input:
        output - the nUnits x nUnits probability matrix (each value is the probability those units match)
        param - dict providing 'nUnits' and 'nSessions'
        WithinSession - nUnits x nUnits matrix, 1 where two units come from different sessions
        SessionSwitch - session boundaries [0, end of session 1, ..., end of final session]
        MatchThreshold - a pair counts as a match when its probability is > MatchThreshold

    This function then prints:
        the percentage of units matched to themselves across cv, and the false negative %
        (how many did not match themselves across cv);
        the number of off-diagonal matches between units of the same session per unit
        (false matches per expected match);
        the false positive % per session, i.e. the fraction of off-diagonal within-session
        pairs that are matches. This is nan for a session with fewer than two units.
    Nothing is returned and the inputs are not modified.
    """
    nUnits = param['nUnits']

    OutputThreshold = np.zeros_like(output)
    OutputThreshold[output > MatchThreshold] = 1

    # get the number of diagonal matches
    nDiag = np.sum(OutputThreshold[np.eye(nUnits).astype(bool)])
    SelfMatch = nDiag / nUnits * 100
    print(f'The percentage of units matched to themselves is: {SelfMatch}%')
    print(f'The percentage of false -ve\'s then is: {100 - SelfMatch}% \n')

    # off-diagonal miss-matches within sessions; copy so OutputThreshold itself is untouched
    nOffDiag = OutputThreshold.copy()
    nOffDiag[WithinSession == 1] = 0
    nOffDiag[np.eye(nUnits) == 1] = 0
    FPest = nOffDiag.sum() / nUnits
    print(f'The rate of miss-match(es) per expected match {FPest}')

    # compute matlab FP per session per session
    FPestPerSession = np.zeros(param['nSessions'])
    for did in range(param['nSessions']):
        start, stop = SessionSwitch[did], SessionSwitch[did + 1]
        tmpDiag = OutputThreshold[start:stop, start:stop].copy()
        nUnitsSession = tmpDiag.shape[0]
        tmpDiag[np.eye(nUnitsSession) == 1] = 0
        nPairs = nUnitsSession ** 2 - nUnitsSession
        # a session with fewer than two units has no off-diagonal pairs, so the rate is undefined
        FPestPerSession[did] = tmpDiag.sum() / nPairs * 100 if nPairs > 0 else np.nan
        print(f'The percentage of false +ve\'s is {FPestPerSession[did]}% for session {did +1}')

    print('\nThis assumes that the spike sorter has made no mistakes')
270
+
271
def currate_matches(MatchesGUI, IsMatch, NotMatch, Mode = 'And'):
    """
    Combine the matches of the two cross-validation (CV) halves with manual curation.

    There are two options for Mode, 'And' and 'Or':
        'And' gives a match if both CVs give it as a match.
        'Or' gives a match if either CV gives it as a match.
    The rows in IsMatch are then added, and every row that also appears in NotMatch is removed.

    Parameters
    ----------
    MatchesGUI : sequence whose first two entries are 2-D arrays of matches (one match per row), one per CV
    IsMatch : rows to add as matches (may be empty)
    NotMatch : rows to remove from the matches (may be empty)
    Mode : 'And' or 'Or'

    Returns
    -------
    Matches : 2-D array of the unique curated rows, ordered as by np.unique

    Raises
    ------
    ValueError if Mode is not 'And' or 'Or'.
    """
    MatchesA = np.asarray(MatchesGUI[0])
    MatchesB = np.asarray(MatchesGUI[1])

    def as_rows(x):
        # an empty list/array means "no rows"; give it the same 2-D layout as the CV matches
        x = np.asarray(x)
        if x.size == 0:
            return np.empty((0, MatchesA.shape[1]), dtype = MatchesA.dtype)
        return x

    IsMatch = as_rows(IsMatch)
    NotMatch = as_rows(NotMatch)

    if Mode == 'And':
        # de-duplicate each CV first so a row repeated within one CV cannot count as "both"
        MatchesTmp = np.concatenate((np.unique(MatchesA, axis = 0), np.unique(MatchesB, axis = 0)), axis = 0)
        MatchesTmp, counts = np.unique(MatchesTmp, return_counts = True, axis = 0)
        Matches = MatchesTmp[counts == 2]
    elif Mode == 'Or':
        Matches = np.unique(np.concatenate((MatchesA, MatchesB), axis = 0), axis = 0)
    else:
        raise ValueError(f"Mode must be 'And' or 'Or', got {Mode!r}")

    # add matches in IsMatch
    Matches = np.unique(np.concatenate((Matches, IsMatch), axis = 0), axis = 0)

    # remove every row of Matches that also appears in NotMatch (rows only in NotMatch are ignored)
    if NotMatch.shape[0] > 0:
        InNotMatch = (Matches[:, np.newaxis, :] == NotMatch[np.newaxis, :, :]).all(axis = 2).any(axis = 1)
        Matches = Matches[~InNotMatch]

    return Matches
300
+
301
def paths_fromKS(KSdirs):
    """
    Build the per-session input paths and channel positions from a list of Kilosort output directories.

    For each directory:
        ChannelPos     - channel_positions.npy with a column of ones inserted at position 0
        UnitLabelPaths - <dir>/cluster_group.tsv
        WavePaths      - <dir>/RawWaveforms
    Returns (WavePaths, UnitLabelPaths, ChannelPos), each a list with one entry per directory.
    """
    ChannelPos = []
    for ks_dir in KSdirs:
        positions = np.load(os.path.join(ks_dir, 'channel_positions.npy'))
        # 3-D positions are wanted, but the code currently only needs 2-D, so a placeholder
        # column of ones is added at axis-1 position 0
        ChannelPos.append(np.insert(positions, 0, np.ones(positions.shape[0]), axis = 1))

    UnitLabelPaths = [os.path.join(ks_dir, 'cluster_group.tsv') for ks_dir in KSdirs]
    WavePaths = [os.path.join(ks_dir, 'RawWaveforms') for ks_dir in KSdirs]

    return WavePaths, UnitLabelPaths, ChannelPos
323
+
324
+ ##########################################################################################################################
325
+ #The following functions are the old way of reading in units; it is slower and will not work if units are missing, e.g. 1, 2, 4
326
+
327
+ # def load_waveforms(WavePaths, UnitLabelPaths, param):
328
+ # """
329
+ # This function uses a list of paths to the average waveforms and good units to load in all
330
+ # the waveforms and session related information.
331
+ # """
332
+
333
+ # #assuming the number of sessions is the same as length of WaveF_paths
334
+ # nSessions = len(WavePaths)
335
+
336
+ # # load in individual session waveforms as list of np arrays
337
+ # waveforms = []
338
+ # for i in range(len(WavePaths)):
339
+ # #waveforms.append(util.get_waveform(WaveF_paths[i]))
340
+ # tmp = get_waveform(WavePaths[i])
341
+ # GoodUnitIdxTmp = get_good_unit_idx(UnitLabelPaths[i])
342
+ # waveforms.append(good_units(tmp, GoodUnitIdxTmp))
343
+ # del tmp
344
+ # del GoodUnitIdxTmp
345
+
346
+
347
+ # nUnitsPerSession = np.zeros(nSessions, dtype = 'int')
348
+ # waveform = np.array([])
349
+
350
+ # #add all of the individual waveforms to one waveform array
351
+ # for i in range(nSessions):
352
+ # if i == 0:
353
+ # waveform = waveforms[i]
354
+ # else:
355
+ # waveform = np.concatenate((waveform, waveforms[i]), axis = 0)
356
+
357
+ # nUnitsPerSession[i] = waveforms[i].shape[0]
358
+
359
+ # param['n_units'], sessionid, SessionSwitch, param['n_days'] = get_session_data(nUnitsPerSession)
360
+ # WithinSession = get_within_session(sessionid, param)
361
+ # param['n_channels'] = waveform.shape[2]
362
+
363
+ # return waveform, sessionid, SessionSwitch, WithinSession, param
364
+
365
+ # def get_waveform(FolderPath):
366
+ # '''
367
+ # Assuming the raw spike are saved as Unitxxx_RawSpikes.npy where xxx is the number id of the spike,
368
+ # Requires the path to the folder where all the spike are saved.
369
+ # requires all spike to have same dimensions
370
+
371
+ # returns the waveform matrix (No. Units, spike dims), assumed (No.units, time, channel_no, first half/second half)
372
+
373
+ # Could:
374
+ # - parallelize
375
+ # - open in blocks of n units
376
+ # - adapt to open other types of files
377
+ # '''
378
+ # nFiles = len(os.listdir(FolderPath))
379
+
380
+ # tmp = np.load(FolderPath + r'\Unit0_RawSpikes.npy')
381
+ # waveform = np.zeros((nFiles, tmp.shape[0], tmp.shape[1], tmp.shape[2]))
382
+
383
+ # for i in range(nFiles):
384
+ # path = FolderPath + rf'\Unit{i}_RawSpikes.npy'
385
+ # waveform[i] = np.load(path)
386
+ # return waveform
387
+
388
+
389
+ # def get_files_outdated(FolderPath):
390
+ # ''' Similar to get_wave for (maybe more optimized), however:
391
+ # 1 doesn't return numerically order spikes
392
+ # 2 doesn't assume anything about how spikes are saved ( except numpy array of constant dims)
393
+ # '''
394
+ # for path, dirs, files in os.walk(FolderPath, topdown=True):
395
+ # i = 0
396
+ # tmp = np.load(os.path.join(path, files[0]))
397
+ # waveform = np.zeros((len(files),tmp.shape[0], tmp.shape[1], tmp.shape[2] ))
398
+ # for f in files:
399
+ # waveform[i] = np.load(os.path.join(path, f))
400
+ # i +=1
401
+ # return waveform
402
+
403
+ # def get_good_unit_idx(UnitLabelPath):
404
+ # """
405
+ # Assuming until label path, is the path to a tsv file where the second row onwards is 2 columns, where the second one is the unit label
406
+ # """
407
+ # UnitLabel = load_tsv(UnitLabelPath)
408
+ # TmpIdx = np.argwhere(UnitLabel[:,1] == 'good')
409
+ # GoodUnitIdx = UnitLabel[TmpIdx, 0]
410
+ # return GoodUnitIdx
411
+
412
+ # def good_units(waveform, GoodUnitIdx):
413
+ # """
414
+ # Using goodunit_idx, this function returns the good units of a waveform
415
+ # ** may want to edit, so it can select good units if the unit axes isn't the first axis and is adaptable to any shape of input
416
+ # """
417
+ # waveform = waveform[GoodUnitIdx,:,:,:].squeeze()
418
+ # return waveform
419
+
420
+ # #################################
421
+ # # The following is how the above function would be used to read in data
422
+ # # read in data and select the good units and exact metadata
423
+
424
+ # #loads in waveforms,
425
+ # waveform1 = util.get_waveform(WavePath1)
426
+ # waveform2 = util.get_waveform(WavePath2)
427
+
428
+ # #selects 'good' units
429
+ # GoodUnitIdx1 = util.get_good_unit_idx(UnitLabelPath1)
430
+ # waveform1 = util.good_units(waveform1, GoodUnitIdx1)
431
+ # GoodUnitIdx2 = util.get_good_unit_idx(UnitLabelPath2)
432
+ # waveform2 = util.good_units(waveform2, GoodUnitIdx2)
433
+
434
+ # #joins the waveforms together, and keep track of length of each session
435
+ # waveform = np.concatenate((waveform1,waveform2), axis = 0 )
436
+ # nUnitsPerSession = np.asarray([waveform1.shape[0], waveform2.shape[0]])
437
+
438
+ # # assigns a session id to each units and notes when the sessions switch
439
+ # param['nUnits'], sessionid, SessionSwitch, param['nSessions'] = util.get_session_data(nUnitsPerSession)
440
+ # WithinSession = util.get_within_session(sessionid, param)
441
+ # param['nChannels'] = waveform.shape[2]