UnitMatchPy 1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- UnitMatchPy/AssignUniqueID.py +143 -0
- UnitMatchPy/Bayes_fun.py +66 -0
- UnitMatchPy/Extract_raw_data.py +216 -0
- UnitMatchPy/GUI.py +1300 -0
- UnitMatchPy/Metrics_fun.py +550 -0
- UnitMatchPy/Overlord.py +102 -0
- UnitMatchPy/Param_fun.py +306 -0
- UnitMatchPy/Save_utils.py +324 -0
- UnitMatchPy/__init__.py +9 -0
- UnitMatchPy/utils.py +441 -0
- UnitMatchPy-1.0.dist-info/LICENSE +437 -0
- UnitMatchPy-1.0.dist-info/METADATA +47 -0
- UnitMatchPy-1.0.dist-info/RECORD +15 -0
- UnitMatchPy-1.0.dist-info/WHEEL +5 -0
- UnitMatchPy-1.0.dist-info/top_level.txt +1 -0
UnitMatchPy/utils.py
ADDED
|
@@ -0,0 +1,441 @@
|
|
|
1
|
+
# utility function for loading files etc
|
|
2
|
+
import numpy as np
|
|
3
|
+
import pandas as pd
|
|
4
|
+
import os
|
|
5
|
+
import matplotlib.pyplot as plt
|
|
6
|
+
|
|
7
|
+
def load_tsv(path):
    """
    Read a tab-separated file and return its rows (header row excluded)
    as a numpy array.
    """
    table = pd.read_csv(path, sep='\t', skiprows=0)
    return table.to_numpy()
|
|
13
|
+
|
|
14
|
+
def get_session_number(unitid, SessionSwitch):
    """
    Return the index i of the session whose half-open range
    [SessionSwitch[i], SessionSwitch[i+1]) contains unitid.

    Returns None (implicitly) when unitid lies outside every range.
    """
    n_sessions = len(SessionSwitch) - 1
    session = 0
    while session < n_sessions:
        low, high = SessionSwitch[session], SessionSwitch[session + 1]
        if low <= unitid < high:
            return session
        session += 1
|
|
19
|
+
|
|
20
|
+
def get_session_data(nUnitsPerSession):
    """
    Given the per-session unit counts (numpy array), return:
      nUnits        - total number of units across all sessions
      sessionid     - per-unit array labelling each unit with its session
      SessionSwitch - boundary indices in the form
                      [0, end of session 1, ..., end of final session]
      nSessions     - number of sessions
    """
    nSessions = len(nUnitsPerSession)
    nUnits = nUnitsPerSession.sum()

    # [0, cumulative ends...] gives the switch points between sessions
    SessionSwitch = np.concatenate(([0], np.cumsum(nUnitsPerSession)))
    # label unit k with the session it belongs to
    sessionid = np.repeat(np.arange(nSessions, dtype=int), nUnitsPerSession)

    return nUnits, sessionid, SessionSwitch, nSessions
|
|
36
|
+
|
|
37
|
+
def get_within_session(sessionid, param):
    """
    Build a (nUnits, nUnits) matrix that is 0 where the two units share a
    session and 1 where they come from different sessions.
    """
    nUnits = param['nUnits']  # expected to equal len(sessionid)

    # broadcast compare every session id against every other
    cross = sessionid.reshape(nUnits, 1) != sessionid.reshape(1, nUnits)
    return cross.astype(np.float64)
|
|
51
|
+
|
|
52
|
+
def get_default_param(param = None):
    """
    Build the dictionary of default UnitMatch parameters.

    If an existing dictionary is supplied, default values are added to it
    without overwriting any value the caller has already set.

    Parameters
    ----------
    param : dict or None
        Optional pre-existing parameter dictionary.

    Returns
    -------
    dict
        All default parameters merged with `param` (caller values win).
    """
    tmp = {'SpikeWidth' : 82, 'waveidx' : np.arange(33,56), 'ChannelRadius' : 150,
           'PeakLoc' : 40, 'MaxDist' : 100, 'NeighbourDist' : 50, 'stepsz' : 0.01,
           'SmoothProb' : 9, 'MinAngleDist' : 0.1, 'NoShanks' : 4, 'ShankDist' : 175,
           'MatchNumThreshold' : 15, 'MatchThreshold' : 0.5
           }
    tmp['ScoreVector'] = np.arange(tmp['stepsz']/2 ,1 ,tmp['stepsz'])
    tmp['Bins'] = np.arange(0, 1 + tmp['stepsz'], tmp['stepsz'])

    # `param == None` is the wrong idiom (and would attempt an element-wise
    # comparison if an array were ever passed); identity test is correct
    if param is None:
        out = tmp
    else:
        # dict union: values already present in `param` take precedence
        out = tmp | param
    return out
|
|
73
|
+
|
|
74
|
+
def _load_session_waveforms(wave_dir, unit_ids):
    """Load 'Unit{id}_RawSpikes.npy' for every id into one (n, ...) array."""
    # the first unit fixes the per-unit waveform shape
    first = np.load(os.path.join(wave_dir, f'Unit{unit_ids[0]}_RawSpikes.npy'))
    session_waveforms = np.zeros((len(unit_ids), *first.shape))
    for idx, uid in enumerate(unit_ids):
        session_waveforms[idx] = np.load(os.path.join(wave_dir, f'Unit{uid}_RawSpikes.npy'))
    return session_waveforms

def load_good_waveforms(WavePaths, UnitLabelPaths, param, GoodUnitsOnly = True):
    """
    Recommended entry point for reading in data.

    Loads the per-unit average waveforms ('Unit{id}_RawSpikes.npy') for every
    session, plus the session bookkeeping arrays UnitMatch needs.

    Parameters
    ----------
    WavePaths : list of str
        One directory per session containing the per-unit .npy waveform files.
    UnitLabelPaths : list of str
        One .tsv file per session (first column unit id, second column label).
    param : dict
        Parameter dictionary; nUnits/nSessions/nChannels etc. are filled in.
    GoodUnitsOnly : bool
        If True only units labelled 'good' are loaded; otherwise every file
        in each waveform directory is loaded (units assumed numbered 0..n-1).

    Returns
    -------
    waveform, sessionid, SessionSwitch, WithinSession, GoodUnits, param
    (or None after a printed warning when the path lists disagree in length)
    """
    if len(WavePaths) == len(UnitLabelPaths):
        nSessions = len(WavePaths)
    else:
        print('Warning: gave different number of paths for waveforms and labels!')
        return

    # collect the ids of the 'good' units and the total unit count per session
    GoodUnits = []
    nUnitsPerSessionALL = []
    for i in range(len(UnitLabelPaths)):
        UnitLabel = load_tsv(UnitLabelPaths[i])
        TmpIdx = np.argwhere(UnitLabel[:, 1] == 'good')
        GoodUnits.append(UnitLabel[TmpIdx, 0])
        nUnitsPerSessionALL.append(UnitLabel.shape[0])

    waveforms = []
    for ls in range(nSessions):
        if GoodUnitsOnly:
            unit_ids = [int(GoodUnits[ls][i].squeeze()) for i in range(len(GoodUnits[ls]))]
        else:
            # BUG FIX: the original loaded GoodUnits[ls][0] (the first good
            # unit) into EVERY slot here; load Unit0..Unit{n-1} instead
            unit_ids = list(range(len(os.listdir(WavePaths[ls]))))
        waveforms.append(_load_session_waveforms(WavePaths[ls], unit_ids))

    nUnitsPerSession = np.zeros(nSessions, dtype = 'int')
    waveform = np.array([])

    # stack every session's waveforms into one array, recording counts
    for i in range(nSessions):
        if i == 0:
            waveform = waveforms[i]
        else:
            waveform = np.concatenate((waveform, waveforms[i]), axis = 0)
        nUnitsPerSession[i] = waveforms[i].shape[0]

    param['nUnits'], sessionid, SessionSwitch, param['nSessions'] = get_session_data(nUnitsPerSession)
    WithinSession = get_within_session(sessionid, param)
    param['nChannels'] = waveform.shape[2]
    param['nUnitsPerSession'] = nUnitsPerSessionALL

    # if the data's spike width disagrees with the defaults, re-derive the
    # width-dependent parameters from the actual waveform length
    if param['SpikeWidth'] != waveform.shape[1]:
        param['SpikeWidth'] = waveform.shape[1]
        param['PeakLoc'] = np.floor(waveform.shape[1]/2).astype(int)
        param['waveidx'] = np.arange(param['PeakLoc'] - 8, param['PeakLoc'] + 15, dtype = int)

    return waveform, sessionid, SessionSwitch, WithinSession, GoodUnits, param
|
|
155
|
+
|
|
156
|
+
def get_good_units(UnitLabelPaths, good = True):
    """
    Read unit-label .tsv files and return, per session, the unit ids to use.

    Parameters
    ----------
    UnitLabelPaths : list of str
        Paths to .tsv files whose first column is the unit id and whose
        second column is the curation label.
    good : bool
        If True keep only units labelled 'good'; if False keep every unit.

    Returns
    -------
    list of numpy arrays (one per session), each of shape (n, 1), holding
    the unit ids read from column 0.
    """
    GoodUnits = []
    for i in range(len(UnitLabelPaths)):
        UnitLabel = load_tsv(UnitLabelPaths[i])
        if good == True:
            TmpIdx = np.argwhere(UnitLabel[:,1] == 'good')
        else:
            # BUG FIX: row POSITIONS are needed here, not the unit ids stored
            # in column 0 (ids need not be contiguous from 0, and using them
            # as indices mis-selects or overruns the array)
            TmpIdx = np.arange(UnitLabel.shape[0]).reshape(-1, 1)
        GoodUnitIdx = UnitLabel[TmpIdx, 0]
        GoodUnits.append(GoodUnitIdx)
    return GoodUnits
|
|
172
|
+
|
|
173
|
+
def load_good_units(GoodUnits, WavePaths, param):
    """
    Load the raw averaged waveforms for the requested units of every session.

    GoodUnits is a list (one entry per session) of arrays of unit ids;
    WavePaths is the matching list of directories holding the
    'Unit{id}_RawSpikes.npy' files.
    """
    if len(WavePaths) != len(GoodUnits):
        print('Warning: gave different number of paths for waveforms and labels!')
        return
    nSessions = len(WavePaths)

    per_session = []
    for sess in range(nSessions):
        ids = [int(GoodUnits[sess][k].squeeze()) for k in range(len(GoodUnits[sess]))]
        # the first unit's file fixes the per-unit waveform dimensions
        probe = np.load(os.path.join(WavePaths[sess], f'Unit{ids[0]}_RawSpikes.npy'))
        stack = np.zeros((len(ids),) + probe.shape)
        for k, uid in enumerate(ids):
            stack[k] = np.load(os.path.join(WavePaths[sess], f'Unit{uid}_RawSpikes.npy'))
        per_session.append(stack)

    nUnitsPerSession = np.array([s.shape[0] for s in per_session], dtype = 'int')
    waveform = np.concatenate(per_session, axis = 0)

    param['nUnits'], sessionid, SessionSwitch, param['nSessions'] = get_session_data(nUnitsPerSession)
    WithinSession = get_within_session(sessionid, param)
    param['nChannels'] = waveform.shape[2]
    return waveform, sessionid, SessionSwitch, WithinSession, param
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def compare_units(AvgWaveform, AvgCentroid, unit1, unit2):
    """
    Quick visual helper: overlay the two units' average waveforms (cv 0) on
    the current figure and print each unit's average centroid.
    """
    for unit in (unit1, unit2):
        plt.plot(AvgWaveform[:, unit, 0])
        print(f'Average centroid of unit {unit} is :{AvgCentroid[:,unit,0]}')
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def evaluate_output(output, param, WithinSession, SessionSwitch, MatchThreshold = 0.5):
    """
    Print summary statistics for a UnitMatch probability matrix.

    Parameters
    ----------
    output : ndarray, shape (nUnits, nUnits)
        Probability each pair of units is a match.
    param : dict
        Must contain 'nUnits' and 'nSessions'.
    WithinSession : ndarray
        0 where the two units share a session, 1 otherwise.
    SessionSwitch : ndarray
        Session boundary indices [0, end of session 1, ...].
    MatchThreshold : float
        Probability above which a pair counts as a match.

    Prints the percentage of units matched to themselves across cv (and its
    complement, the false-negative percentage), the rate of within-session
    off-diagonal miss-matches per expected match, and a per-session false
    positive percentage. Assumes the spike sorter made no mistakes.
    """
    OutputThreshold = np.zeros_like(output)
    OutputThreshold[output > MatchThreshold] = 1

    # diagonal entries = units matched to themselves across cv halves
    nDiag = np.sum(OutputThreshold[np.eye(param['nUnits']).astype(bool)])
    SelfMatch = nDiag / param['nUnits'] *100
    print(f'The percentage of units matched to themselves is: {SelfMatch}%')
    print(f'The percentage of false -ve\'s then is: {100 - SelfMatch}% \n')

    # within-session off-diagonal matches are miss-matches.
    # BUG FIX: the original assigned `nOffDiag = OutputThreshold` (an alias,
    # after a pointless zeros_like) and then mutated OutputThreshold in place;
    # take an explicit copy so the thresholded matrix stays intact
    nOffDiag = OutputThreshold.copy()
    nOffDiag[WithinSession == 1] = 0
    nOffDiag[np.eye(param['nUnits']) == 1] = 0
    FPest = nOffDiag.sum() / (param['nUnits'])
    print(f'The rate of miss-match(es) per expected match {FPest}')

    # per-session false positive rate over each within-session block
    FPestPerSession = np.zeros(param['nSessions'])
    for did in range(param['nSessions']):
        # copy: the slice is a view and zeroing its diagonal would otherwise
        # write through into OutputThreshold
        tmpDiag = OutputThreshold[SessionSwitch[did]:SessionSwitch[did + 1], SessionSwitch[did]:SessionSwitch[did + 1]].copy()
        nUnits = tmpDiag.shape[0]
        tmpDiag[np.eye(nUnits) == 1] = 0
        FPestPerSession[did] = tmpDiag.sum() / (nUnits ** 2 - nUnits) * 100
        print(f'The percentage of false +ve\'s is {FPestPerSession[did]}% for session {did +1}')

    print('\nThis assumes that the spike sorter has made no mistakes')
|
|
270
|
+
|
|
271
|
+
def currate_matches(MatchesGUI, IsMatch, NotMatch, Mode = 'And'):
    """
    Combine the two cross-validation match lists and apply manual curation.

    Parameters
    ----------
    MatchesGUI : sequence of two (n, 2) arrays
        The match pairs found for each cross-validation half.
    IsMatch : array-like of (unit, unit) pairs
        Pairs the user has confirmed as matches (always added).
    NotMatch : array-like of (unit, unit) pairs
        Pairs the user has rejected (always removed).
    Mode : {'And', 'Or'}
        'And' keeps a pair only if BOTH CV halves agree it is a match;
        'Or' keeps a pair if EITHER CV half found it.

    Returns
    -------
    ndarray of unique (unit, unit) match pairs.

    Raises
    ------
    ValueError
        If Mode is neither 'And' nor 'Or' (previously this fell through to a
        NameError on the undefined Matches variable).
    """
    MatchesA = MatchesGUI[0]
    MatchesB = MatchesGUI[1]

    IsMatch = np.array(IsMatch)
    NotMatch = np.array(NotMatch)

    if Mode == 'And':
        # a pair appearing twice after de-duplication was found by both halves
        MatchesTmp = np.concatenate((MatchesA, MatchesB), axis = 0)
        MatchesTmp, counts = np.unique(MatchesTmp, return_counts = True, axis = 0)
        Matches = MatchesTmp[counts == 2]
    elif Mode == 'Or':
        Matches = np.unique(np.concatenate((MatchesA, MatchesB), axis = 0), axis = 0)
    else:
        raise ValueError(f"Mode must be 'And' or 'Or', got {Mode!r}")

    # add the user-confirmed matches (guard: empty IsMatch would break concatenate)
    if IsMatch.size > 0:
        Matches = np.unique(np.concatenate((Matches, IsMatch), axis = 0), axis = 0)

    # remove the user-rejected matches.
    # BUG FIX: the old concatenate/unique/counts==1 trick also KEPT any
    # NotMatch pair that was not already in Matches, i.e. it ADDED rejected
    # pairs to the result; filter the existing matches instead
    if NotMatch.size > 0:
        rejected = {tuple(pair) for pair in NotMatch}
        kept = [pair for pair in Matches if tuple(pair) not in rejected]
        Matches = np.array(kept) if kept else np.empty((0, Matches.shape[1]), dtype = Matches.dtype)

    return Matches
|
|
300
|
+
|
|
301
|
+
def paths_fromKS(KSdirs):
    """
    Derive the standard UnitMatch input paths from KiloSort output directories.

    Returns
    -------
    WavePaths : list of 'RawWaveforms' directories, one per session
    UnitLabelPaths : list of 'cluster_group.tsv' files, one per session
    ChannelPos : list of channel-position arrays with a column of ones
        prepended (downstream code wants 3-D positions; only 2-D is used now)
    """
    nSessions = len(KSdirs)

    ChannelPos = []
    for sess in range(nSessions):
        raw_pos = np.load(os.path.join(KSdirs[sess], 'channel_positions.npy'))
        # prepend a dummy coordinate of 1s to make the 2-D positions 3-D
        ChannelPos.append(np.insert(raw_pos, 0, np.ones(raw_pos.shape[0]), axis = 1))

    UnitLabelPaths = [os.path.join(d, 'cluster_group.tsv') for d in KSdirs]
    WavePaths = [os.path.join(d, 'RawWaveforms') for d in KSdirs]

    return WavePaths, UnitLabelPaths, ChannelPos
|
|
323
|
+
|
|
324
|
+
##########################################################################################################################
|
|
325
|
+
#The following functions are the old way of reading in units, is slower and will not work if unit are missing e.g 1,2,4
|
|
326
|
+
|
|
327
|
+
# def load_waveforms(WavePaths, UnitLabelPaths, param):
|
|
328
|
+
# """
|
|
329
|
+
# This function uses a list of paths to the average waveforms and good units to load in all
|
|
330
|
+
# the waveforms and session related information.
|
|
331
|
+
# """
|
|
332
|
+
|
|
333
|
+
# #assuming the number of sessions is the same as length of WaveF_paths
|
|
334
|
+
# nSessions = len(WavePaths)
|
|
335
|
+
|
|
336
|
+
# # load in individual session waveforms as list of np arrays
|
|
337
|
+
# waveforms = []
|
|
338
|
+
# for i in range(len(WavePaths)):
|
|
339
|
+
# #waveforms.append(util.get_waveform(WaveF_paths[i]))
|
|
340
|
+
# tmp = get_waveform(WavePaths[i])
|
|
341
|
+
# GoodUnitIdxTmp = get_good_unit_idx(UnitLabelPaths[i])
|
|
342
|
+
# waveforms.append(good_units(tmp, GoodUnitIdxTmp))
|
|
343
|
+
# del tmp
|
|
344
|
+
# del GoodUnitIdxTmp
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
# nUnitsPerSession = np.zeros(nSessions, dtype = 'int')
|
|
348
|
+
# waveform = np.array([])
|
|
349
|
+
|
|
350
|
+
# #add all of the individual waveforms to one waveform array
|
|
351
|
+
# for i in range(nSessions):
|
|
352
|
+
# if i == 0:
|
|
353
|
+
# waveform = waveforms[i]
|
|
354
|
+
# else:
|
|
355
|
+
# waveform = np.concatenate((waveform, waveforms[i]), axis = 0)
|
|
356
|
+
|
|
357
|
+
# nUnitsPerSession[i] = waveforms[i].shape[0]
|
|
358
|
+
|
|
359
|
+
# param['n_units'], sessionid, SessionSwitch, param['n_days'] = get_session_data(nUnitsPerSession)
|
|
360
|
+
# WithinSession = get_within_session(sessionid, param)
|
|
361
|
+
# param['n_channels'] = waveform.shape[2]
|
|
362
|
+
|
|
363
|
+
# return waveform, sessionid, SessionSwitch, WithinSession, param
|
|
364
|
+
|
|
365
|
+
# def get_waveform(FolderPath):
|
|
366
|
+
# '''
|
|
367
|
+
# Assuming the raw spike are saved as Unitxxx_RawSpikes.npy where xxx is the number id of the spike,
|
|
368
|
+
# Requires the path to the folder where all the spike are saved.
|
|
369
|
+
# requires all spike to have same dimensions
|
|
370
|
+
|
|
371
|
+
# returns the waveform matrix (No. Units, spike dims), assumed (No.units, time, channel_no, first half/second half)
|
|
372
|
+
|
|
373
|
+
# Could:
|
|
374
|
+
# - parallelize
|
|
375
|
+
# - open in blocks of n units
|
|
376
|
+
# - adapt to open other types of files
|
|
377
|
+
# '''
|
|
378
|
+
# nFiles = len(os.listdir(FolderPath))
|
|
379
|
+
|
|
380
|
+
# tmp = np.load(FolderPath + r'\Unit0_RawSpikes.npy')
|
|
381
|
+
# waveform = np.zeros((nFiles, tmp.shape[0], tmp.shape[1], tmp.shape[2]))
|
|
382
|
+
|
|
383
|
+
# for i in range(nFiles):
|
|
384
|
+
# path = FolderPath + rf'\Unit{i}_RawSpikes.npy'
|
|
385
|
+
# waveform[i] = np.load(path)
|
|
386
|
+
# return waveform
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
# def get_files_outdated(FolderPath):
|
|
390
|
+
# ''' Similar to get_wave for (maybe more optimized), however:
|
|
391
|
+
# 1 doesn't return numerically order spikes
|
|
392
|
+
# 2 doesn't assume anything about how spikes are saved ( except numpy array of constant dims)
|
|
393
|
+
# '''
|
|
394
|
+
# for path, dirs, files in os.walk(FolderPath, topdown=True):
|
|
395
|
+
# i = 0
|
|
396
|
+
# tmp = np.load(os.path.join(path, files[0]))
|
|
397
|
+
# waveform = np.zeros((len(files),tmp.shape[0], tmp.shape[1], tmp.shape[2] ))
|
|
398
|
+
# for f in files:
|
|
399
|
+
# waveform[i] = np.load(os.path.join(path, f))
|
|
400
|
+
# i +=1
|
|
401
|
+
# return waveform
|
|
402
|
+
|
|
403
|
+
# def get_good_unit_idx(UnitLabelPath):
|
|
404
|
+
# """
|
|
405
|
+
# Assuming until label path, is the path to a tsv file where the second row onwards is 2 columns, where the second one is the unit label
|
|
406
|
+
# """
|
|
407
|
+
# UnitLabel = load_tsv(UnitLabelPath)
|
|
408
|
+
# TmpIdx = np.argwhere(UnitLabel[:,1] == 'good')
|
|
409
|
+
# GoodUnitIdx = UnitLabel[TmpIdx, 0]
|
|
410
|
+
# return GoodUnitIdx
|
|
411
|
+
|
|
412
|
+
# def good_units(waveform, GoodUnitIdx):
|
|
413
|
+
# """
|
|
414
|
+
# Using goodunit_idx, this function returns the good units of a waveform
|
|
415
|
+
# ** may want to edit, so it can select good units if the unit axes isn't the first axis and is adaptable to any shape of input
|
|
416
|
+
# """
|
|
417
|
+
# waveform = waveform[GoodUnitIdx,:,:,:].squeeze()
|
|
418
|
+
# return waveform
|
|
419
|
+
|
|
420
|
+
# #################################
|
|
421
|
+
# # The following is how the above function would be used to read in data
|
|
422
|
+
# # read in data and select the good units and exact metadata
|
|
423
|
+
|
|
424
|
+
# #loads in waveforms,
|
|
425
|
+
# waveform1 = util.get_waveform(WavePath1)
|
|
426
|
+
# waveform2 = util.get_waveform(WavePath2)
|
|
427
|
+
|
|
428
|
+
# #selects 'good' units
|
|
429
|
+
# GoodUnitIdx1 = util.get_good_unit_idx(UnitLabelPath1)
|
|
430
|
+
# waveform1 = util.good_units(waveform1, GoodUnitIdx1)
|
|
431
|
+
# GoodUnitIdx2 = util.get_good_unit_idx(UnitLabelPath2)
|
|
432
|
+
# waveform2 = util.good_units(waveform2, GoodUnitIdx2)
|
|
433
|
+
|
|
434
|
+
# #joins the waveforms together, and keep track of length of each session
|
|
435
|
+
# waveform = np.concatenate((waveform1,waveform2), axis = 0 )
|
|
436
|
+
# nUnitsPerSession = np.asarray([waveform1.shape[0], waveform2.shape[0]])
|
|
437
|
+
|
|
438
|
+
# # assigns a session id to each units and notes when the sessions switch
|
|
439
|
+
# param['nUnits'], sessionid, SessionSwitch, param['nSessions'] = util.get_session_data(nUnitsPerSession)
|
|
440
|
+
# WithinSession = util.get_within_session(sessionid, param)
|
|
441
|
+
# param['nChannels'] = waveform.shape[2]
|