UnitMatchPy-1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- UnitMatchPy/AssignUniqueID.py +143 -0
- UnitMatchPy/Bayes_fun.py +66 -0
- UnitMatchPy/Extract_raw_data.py +216 -0
- UnitMatchPy/GUI.py +1300 -0
- UnitMatchPy/Metrics_fun.py +550 -0
- UnitMatchPy/Overlord.py +102 -0
- UnitMatchPy/Param_fun.py +306 -0
- UnitMatchPy/Save_utils.py +324 -0
- UnitMatchPy/__init__.py +9 -0
- UnitMatchPy/utils.py +441 -0
- UnitMatchPy-1.0.dist-info/LICENSE +437 -0
- UnitMatchPy-1.0.dist-info/METADATA +47 -0
- UnitMatchPy-1.0.dist-info/RECORD +15 -0
- UnitMatchPy-1.0.dist-info/WHEEL +5 -0
- UnitMatchPy-1.0.dist-info/top_level.txt +1 -0
UnitMatchPy/AssignUniqueID.py
ADDED

@@ -0,0 +1,143 @@

import numpy as np

def check_is_in(TestArray, ParentArray):
    """
    Test to see if a row of TestArray is in ParentArray

    Arguments:
    TestArray -- ndarray (N, 2)
    ParentArray -- ndarray (M, 2)

    Returns:
    IsIn -- ndarray (N), dtype bool
    """
    IsIn = (TestArray[:, None] == ParentArray).all(-1).any(-1)
    return IsIn
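
check_is_in relies on NumPy broadcasting: TestArray[:, None] has shape (N, 1, 2), comparing it against ParentArray (M, 2) gives (N, M, 2), .all(-1) flags exact row matches, and .any(-1) collapses over ParentArray's rows. A minimal sketch of the behaviour (the array values are illustrative only):

import numpy as np

test = np.array([[0, 1], [2, 3]])
parent = np.array([[2, 3], [4, 5]])
# [0, 1] is not a row of parent, [2, 3] is
print((test[:, None] == parent).all(-1).any(-1))  # [False  True]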

def AssignUID(Output, param, ClusInfo):

    AllClusterIDs = ClusInfo['OriginalID'] # each unit has a unique ID

    # create arrays for the unique IDs
    UniqueIDLiberal = np.arange(AllClusterIDs.shape[0])
    OriUniqueID = np.arange(AllClusterIDs.shape[0])
    UniqueIDConservative = np.arange(AllClusterIDs.shape[0])
    UniqueID = np.arange(AllClusterIDs.shape[0])

    GoodRecSesID = ClusInfo['SessionID']
    RecOpt = np.unique(ClusInfo['SessionID'])
    nRec = RecOpt.shape[0]

    # data-driven threshold?
    if param.get('UseDataDrivenProbThrs', False):
        stepsz = 0.1
        binedges = np.arange(0, 1 + stepsz, stepsz)
        plotvec = np.arange(stepsz / 2, 1, stepsz)

        hw, __ = np.histogram(np.diag(Output), bins = len(binedges), density = True)

        Threshold = plotvec[np.diff(hw) > 0.1]
    else:
        Threshold = param['MatchThreshold']

    Pairs = np.argwhere(Output > Threshold)
    Pairs = np.delete(Pairs, np.argwhere(Pairs[:, 0] == Pairs[:, 1]), axis = 0) # delete self-matches
    Pairs = np.sort(Pairs, axis = 1) # arrange so the smaller unit ID is in the first column
    # keep one copy of each pair, and only if both CVs agree it is a match
    PairsUnique, Count = np.unique(Pairs, axis = 0, return_counts = True)
    PairsUniqueFilt = np.delete(PairsUnique, Count == 1, axis = 0) # if Count == 1, only one CV found that pair!
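
Output holds both cross-validation directions of each comparison, so a pair that passes the threshold in both directions appears twice after sorting, once as (i, j) and once as (j, i). A small sketch of how the count-based filter keeps only those pairs (toy indices):

import numpy as np

pairs = np.sort(np.array([[0, 2], [2, 0], [1, 3]]), axis=1)  # (1, 3) passed in one direction only
unique, count = np.unique(pairs, axis=0, return_counts=True)
print(unique[count > 1])  # [[0 2]] -- only the pair both CVs agree on survives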

    # get the mean probability for each match
    ProbMean = np.nanmean(np.vstack((Output[PairsUniqueFilt[:, 0], PairsUniqueFilt[:, 1]], Output[PairsUniqueFilt[:, 1], PairsUniqueFilt[:, 0]])), axis = 0)
    # sort by the mean probability
    PairsProb = np.hstack((PairsUniqueFilt, ProbMean[:, np.newaxis]))
    SortIdxs = np.argsort(-PairsProb[:, 2], axis = 0) # so pairs go in descending order of probability
    PairsProbSorted = PairsProb[SortIdxs, :]

    # create a list holding both copies of each match, e.g. (1, 2) and (2, 1), for easier comparison
    PairsAll = np.zeros((PairsUniqueFilt.shape[0] * 2, 2))
    PairsAll[:PairsUniqueFilt.shape[0], :] = PairsUniqueFilt
    PairsAll[PairsUniqueFilt.shape[0]:, :] = PairsUniqueFilt[:, (1, 0)]

    nMatchesConservative = 0
    nMatchesLiberal = 0
    nMatches = 0
    # go through each pair and assign it to a group
    for pair in PairsProbSorted[:, :2]:
        pair = pair.astype(np.int16)

        # get the conservative group ID for the current two units
        UnitAConservativeID = UniqueIDConservative[pair[0]]
        UnitBConservativeID = UniqueIDConservative[pair[1]]
        # get all units which have the same ID
        SameGroupIdA = np.argwhere(UniqueIDConservative == UnitAConservativeID).squeeze()
        SameGroupIdB = np.argwhere(UniqueIDConservative == UnitBConservativeID).squeeze()
        # reshape 0-d results back into 1-d arrays if needed
        if len(SameGroupIdA.shape) == 0:
            SameGroupIdA = SameGroupIdA[np.newaxis]
        if len(SameGroupIdB.shape) == 0:
            SameGroupIdB = SameGroupIdB[np.newaxis]

        # need to check whether pair[0] matches with SameGroupIdB and vice versa
        CheckPairsA = np.stack((SameGroupIdB, np.broadcast_to(np.array(pair[0]), SameGroupIdB.shape)), axis = -1)
        CheckPairsB = np.stack((SameGroupIdA, np.broadcast_to(np.array(pair[1]), SameGroupIdA.shape)), axis = -1)
        # delete the potential self-matches
        CheckPairsA = np.delete(CheckPairsA, np.argwhere(CheckPairsA[:, 0] == CheckPairsA[:, 1]), axis = 0)
        CheckPairsB = np.delete(CheckPairsB, np.argwhere(CheckPairsB[:, 0] == CheckPairsB[:, 1]), axis = 0)

        if (np.logical_and(np.all(check_is_in(CheckPairsA, PairsAll)), np.all(check_is_in(CheckPairsB, PairsAll)))):
            # if each unit matches with every unit in the other unit's group,
            # the pair can be added as a match in all classes
            UniqueIDConservative[pair[0]] = min(UniqueIDConservative[pair])
            UniqueIDConservative[pair[1]] = min(UniqueIDConservative[pair])
            nMatchesConservative += 1

            UniqueID[pair[0]] = min(UniqueID[pair])
            UniqueID[pair[1]] = min(UniqueID[pair])
            nMatches += 1

            UniqueIDLiberal[pair[0]] = min(UniqueIDLiberal[pair])
            UniqueIDLiberal[pair[1]] = min(UniqueIDLiberal[pair])
            nMatchesLiberal += 1
        else:
            # otherwise, test whether each unit matches every unit in the other
            # unit's group that lies in the same or an adjacent session
            UnitAID = UniqueID[pair[0]]
            UnitBID = UniqueID[pair[1]]

            SameGroupIdA = np.argwhere(UniqueID == UnitAID).squeeze()
            SameGroupIdB = np.argwhere(UniqueID == UnitBID).squeeze()
            if len(SameGroupIdA.shape) == 0:
                SameGroupIdA = SameGroupIdA[np.newaxis]
            if len(SameGroupIdB.shape) == 0:
                SameGroupIdB = SameGroupIdB[np.newaxis]

            CheckPairsA = np.stack((SameGroupIdB, np.broadcast_to(np.array(pair[0]), SameGroupIdB.shape)), axis = -1)
            CheckPairsB = np.stack((SameGroupIdA, np.broadcast_to(np.array(pair[1]), SameGroupIdA.shape)), axis = -1)
            # delete potential self-matches
            CheckPairsA = np.delete(CheckPairsA, np.argwhere(CheckPairsA[:, 0] == CheckPairsA[:, 1]), axis = 0)
            CheckPairsB = np.delete(CheckPairsB, np.argwhere(CheckPairsB[:, 0] == CheckPairsB[:, 1]), axis = 0)

            # check whether the pairs are in the same or adjacent sessions
            InNearSessionA = np.abs(np.diff(ClusInfo['SessionID'][CheckPairsA])) <= 1
            InNearSessionB = np.abs(np.diff(ClusInfo['SessionID'][CheckPairsB])) <= 1

            CheckPairsNearA = CheckPairsA[InNearSessionA.squeeze()]
            CheckPairsNearB = CheckPairsB[InNearSessionB.squeeze()]

            if (np.logical_and(np.all(check_is_in(CheckPairsNearA, PairsAll)), np.all(check_is_in(CheckPairsNearB, PairsAll)))):
                UniqueID[pair[0]] = min(UniqueID[pair])
                UniqueID[pair[1]] = min(UniqueID[pair])
                nMatches += 1

                UniqueIDLiberal[pair[0]] = min(UniqueIDLiberal[pair])
                UniqueIDLiberal[pair[1]] = min(UniqueIDLiberal[pair])
                nMatchesLiberal += 1
            else:
                UniqueIDLiberal[pair[0]] = min(UniqueIDLiberal[pair])
                UniqueIDLiberal[pair[1]] = min(UniqueIDLiberal[pair])
                nMatchesLiberal += 1

    print(f'Number of Liberal Matches: {nMatchesLiberal}')
    print(f'Number of Intermediate Matches: {nMatches}')
    print(f'Number of Conservative Matches: {nMatchesConservative}')
    return [UniqueIDLiberal, UniqueID, UniqueIDConservative, OriUniqueID]
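
A hedged usage sketch for AssignUID; the probability matrix, parameters and cluster info below are illustrative stand-ins, not values shipped with the package:

import numpy as np
from UnitMatchPy.AssignUniqueID import AssignUID

# 4 units over 2 sessions; Output[i, j] is the match probability between units i and j,
# with the two cross-validation directions stored at (i, j) and (j, i)
Output = np.array([[1.0, 0.1, 0.9, 0.1],
                   [0.1, 1.0, 0.1, 0.1],
                   [0.9, 0.1, 1.0, 0.1],
                   [0.1, 0.1, 0.1, 1.0]])
param = {'MatchThreshold': 0.5}
ClusInfo = {'OriginalID': np.arange(4), 'SessionID': np.array([0, 0, 1, 1])}
Liberal, Intermediate, Conservative, Original = AssignUID(Output, param, ClusInfo)
# units 0 and 2 now share a unique ID in all three schemes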
UnitMatchPy/Bayes_fun.py
ADDED

@@ -0,0 +1,66 @@

import numpy as np
import UnitMatchPy.Param_fun as pf

def get_ParameterKernels(Scores2Include, labels, Cond, param, addone = 1):
    """
    Requires Scores2Include, a dictionary where the keys are the metrics used and the values are
    nUnits * nUnits arrays with the score for each pair of units.

    Smoothing and add-one are applied to compensate for the fact that the histogram used to estimate the
    probability distribution has few values; smoothing nearby peaks and troughs reduces shot noise and
    should make the estimate more similar to the true distribution.
    """

    ScoreVector = param['ScoreVector']
    Bins = param['Bins']
    SmoothProb = param['SmoothProb']

    ParameterKernels = np.full((len(ScoreVector), len(Scores2Include), len(Cond)), np.nan)

    ScoreID = 0
    for sc in Scores2Include:
        Scorestmp = Scores2Include[sc]

        SmoothTmp = SmoothProb # not using per-metric smoothing for now (default: the same for all)

        for Ck in range(len(Cond)):

            HistTmp, __ = np.histogram(Scorestmp[labels == Ck], Bins)
            ParameterKernels[:, ScoreID, Ck] = pf.smooth(HistTmp, SmoothTmp)
            ParameterKernels[:, ScoreID, Ck] /= np.sum(ParameterKernels[:, ScoreID, Ck])

            ParameterKernels[:, ScoreID, Ck] += addone * np.min(ParameterKernels[ParameterKernels[:, ScoreID, Ck] != 0, ScoreID, Ck], axis = 0)

        ScoreID += 1

    return ParameterKernels
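
A minimal sketch of the smooth / normalise / add-one sequence for one metric and one condition; pf.smooth is Param_fun's own helper, so the moving average below is only a stand-in for it, and the scores are random toy data:

import numpy as np

scores = np.random.rand(200)                               # toy scores for one condition
hist, _ = np.histogram(scores, bins=np.linspace(0, 1, 101))
kernel = np.convolve(hist.astype(float), np.ones(9) / 9, mode='same')  # stand-in for pf.smooth
kernel /= kernel.sum()                                     # normalise to a probability distribution
kernel += kernel[kernel != 0].min()                        # add-one: lift empty bins off zero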

def apply_naive_bayes(ParameterKernels, Priors, Predictors, param, Cond):
    """
    Using the parameter kernels, priors and predictors, calculate the probability that each pair of units is a
    match.
    """
    ScoreVector = param['ScoreVector']

    nPairs = Predictors.shape[0] ** 2

    unravel = np.reshape(Predictors, (Predictors.shape[0] * Predictors.shape[1], Predictors.shape[2], 1))
    x1 = np.tile(unravel, (1, 1, len(ScoreVector)))
    tmp = np.expand_dims(ScoreVector, axis = (0, 1))
    x2 = np.tile(tmp, (x1.shape[0], x1.shape[1], 1))
    minidx = np.argmin(np.abs(x1 - x2), axis = 2) # index of the nearest ScoreVector bin for each predictor

    likelihood = np.full((nPairs, len(Cond)), np.nan)
    for Ck in range(len(Cond)):
        tmpp = np.zeros_like(minidx, np.float64)
        for yy in range(minidx.shape[1]):
            tmpp[:, yy] = ParameterKernels[minidx[:, yy], yy, Ck]
        likelihood[:, Ck] = np.prod(tmpp, axis = 1)

    Probability = np.full((nPairs, 2), np.nan)
    for Ck in range(len(Cond)):
        Probability[:, Ck] = Priors[Ck] * likelihood[:, Ck] / np.nansum((Priors * likelihood), axis = 1)

    return Probability
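
The posterior per pair is the usual naive Bayes ratio: P(Ck | scores) = Priors[Ck] * prod_m P(score_m | Ck) / sum_k Priors[k] * prod_m P(score_m | k). A toy check of that last normalisation step, with made-up numbers for a single pair and two conditions (non-match, match):

import numpy as np

Priors = np.array([0.9, 0.1])
likelihood = np.array([[0.02, 0.30]])  # product over metrics, one row per pair
posterior = Priors * likelihood / np.nansum(Priors * likelihood, axis=1)
print(posterior)  # [[0.375 0.625]] -> 62.5% match probability for this pair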

UnitMatchPy/Extract_raw_data.py
ADDED

@@ -0,0 +1,216 @@

# Functions for extracting and averaging raw data

import os
import numpy as np
from pathlib import Path
from scipy.ndimage import gaussian_filter
from mtscomp import decompress
from joblib import Parallel, delayed
import UnitMatchPy.utils as util

# Decompressed data functions
def Read_Meta(metaPath):
    "Read in the metadata file as a dictionary"
    metaDict = {}
    with metaPath.open() as f:
        mdatList = f.read().splitlines()
        # convert the list entries into key-value pairs
        for m in mdatList:
            csList = m.split(sep='=')
            if csList[0][0] == '~':
                currKey = csList[0][1:len(csList[0])]
            else:
                currKey = csList[0]
            metaDict.update({currKey: csList[1]})

    return metaDict
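
A quick sketch of the key=value parsing; the metadata fragment below is made up in SpikeGLX style, and the leading '~' is stripped from the key exactly as in the function:

from pathlib import Path

Path('example.meta').write_text('nSavedChans=385\n~snsShankMap=(1,2,480)\n')
print(Read_Meta(Path('example.meta')))
# {'nSavedChans': '385', 'snsShankMap': '(1,2,480)'}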

def get_sample_idx(SpikeTimes, UnitIDs, SampleAmount, units):
    """
    Needs spike times, unit IDs (from the KiloSort directory) and the maximum number of samples per unit.
    Returns a (nUnits, SampleAmount) array of which spikes to sample for every unit; the selected spikes are evenly
    spaced over time, and the row is padded with NaN if the unit has fewer spikes than SampleAmount.
    """

    UniqueUnitIDs = np.unique(UnitIDs)
    nUnitsALL = len(UniqueUnitIDs)

    SampleIdx = np.zeros((nUnitsALL, SampleAmount))
    # process every unit
    for i, idx in enumerate(units):
        UnitTimes = SpikeTimes[UnitIDs == idx]
        if SampleAmount < len(UnitTimes):
            ChooseIdx = np.linspace(0, len(UnitTimes) - 1, SampleAmount, dtype = int) # -1 so we cannot index out of range
            SampleIdx[i, :] = UnitTimes[ChooseIdx]
        else:
            SampleIdx[i, :len(UnitTimes)] = UnitTimes
            SampleIdx[i, len(UnitTimes):] = np.nan

    return SampleIdx
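
The even spacing comes straight from np.linspace over the unit's spike indices, as this toy example shows:

import numpy as np

spike_times = np.arange(0, 700, 7)                     # a toy spike train, 100 spikes
choose = np.linspace(0, len(spike_times) - 1, 10, dtype=int)
print(spike_times[choose])                             # 10 spikes spread evenly over the recording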

def Extract_A_Unit(SampleIdx, Data, HalfWidth, SpikeWidth, nChannels, SampleAmount):
    """
    Extracts and averages the raw data for a unit, and splits the unit's spikes into two halves for cross-verification.
    Returns AvgWaveforms of shape (SpikeWidth, nChannels, 2).

    NOTE - Here SampleIdx is an array of shape (SampleAmount), i.e. use SampleIdx[UnitIdx] to get the average waveform for that unit.
    """

    Channels = np.arange(0, nChannels)

    AllSampleWaveforms = np.zeros((SampleAmount, SpikeWidth, nChannels))
    for i, idx in enumerate(SampleIdx[:]):
        if np.isnan(idx):
            continue
        tmp = Data[int(idx - HalfWidth - 1): int(idx + HalfWidth - 1), Channels] # -1, to better fit with the MATLAB version
        tmp = tmp.astype(np.float32)
        # Gaussian smooth over time; Gaussian window = 5, sigma = window size / 5
        tmp = gaussian_filter(tmp, 1, radius = 2, axes = 0) # edges are handled differently to the MATLAB version
        # window ~ radius * 2 + 1
        tmp = tmp - np.mean(tmp[:20, :], axis = 0) # baseline-correct using the first 20 samples
        AllSampleWaveforms[i] = tmp

    # take the median and split into the two CVs
    nWavs = np.sum(~np.isnan(SampleIdx[:]))
    CVlim = np.floor(nWavs / 2).astype(int)

    # find the median over samples
    AvgWaveforms = np.zeros((SpikeWidth, nChannels, 2))
    AvgWaveforms[:, :, 0] = np.median(AllSampleWaveforms[:CVlim, :, :], axis = 0) # median over samples
    AvgWaveforms[:, :, 1] = np.median(AllSampleWaveforms[CVlim:nWavs, :, :], axis = 0) # median over samples
    return AvgWaveforms
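
The cross-verification split is just a median over each half of the sampled spikes; a toy version, with random data standing in for the extracted snippets:

import numpy as np

waves = np.random.randn(40, 82, 384)          # (sampled spikes, time, channels)
half = waves.shape[0] // 2
cv = np.stack((np.median(waves[:half], axis=0),
               np.median(waves[half:], axis=0)), axis=-1)  # (time, channels, 2)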

def Extract_A_UnitKS4(SampleIdx, Data, SamplesBefore, SamplesAfter, SpikeWidth, nChannels, SampleAmount):
    """
    Extracts and averages the raw data for a unit, and splits the unit's spikes into two halves for cross-verification.
    Returns AvgWaveforms of shape (SpikeWidth, nChannels, 2).

    NOTE - Here SampleIdx is an array of shape (SampleAmount), i.e. use SampleIdx[UnitIdx] to get the average waveform for that unit.
    """

    Channels = np.arange(0, nChannels)

    AllSampleWaveforms = np.zeros((SampleAmount, SpikeWidth, nChannels))
    for i, idx in enumerate(SampleIdx[:]):
        if np.isnan(idx):
            continue
        tmp = Data[int(idx - SamplesBefore - 1): int(idx + SamplesAfter - 1), Channels] # -1, to better fit with the MATLAB version
        tmp = tmp.astype(np.float32)
        # Gaussian smooth over time; Gaussian window = 5, sigma = window size / 5
        tmp = gaussian_filter(tmp, 1, radius = 2, axes = 0) # edges are handled differently to the MATLAB version
        # window ~ radius * 2 + 1
        tmp = tmp - np.mean(tmp[:20, :], axis = 0) # baseline-correct using the first 20 samples
        AllSampleWaveforms[i] = tmp

    # take the median and split into the two CVs
    nWavs = np.sum(~np.isnan(SampleIdx[:]))
    CVlim = np.floor(nWavs / 2).astype(int)

    # find the median over samples
    AvgWaveforms = np.zeros((SpikeWidth, nChannels, 2))
    AvgWaveforms[:, :, 0] = np.median(AllSampleWaveforms[:CVlim, :, :], axis = 0) # median over samples
    AvgWaveforms[:, :, 1] = np.median(AllSampleWaveforms[CVlim:nWavs, :, :], axis = 0) # median over samples
    return AvgWaveforms

def Save_AvgWaveforms(AvgWaveforms, SaveDir, GoodUnits, ExtractGoodUnitsOnly = False):
    """
    Saves the extracted average waveforms in a folder called 'RawWaveforms' in the given SaveDir.
    Each waveform is saved in its own .npy file called 'UnitX_RawSpikes.npy'.
    Supply GoodUnits, an array of which indices are included, if you are not extracting all units
    from the recording session.
    """
    CurrentDir = os.getcwd()
    os.chdir(SaveDir)
    DirList = os.listdir()
    if 'RawWaveforms' in DirList:
        TmpPath = os.path.join(SaveDir, 'RawWaveforms')

    else:
        os.mkdir('RawWaveforms')
        TmpPath = os.path.join(SaveDir, 'RawWaveforms')

    os.chdir(TmpPath)

    # the first axis is the unit

    # all waveforms, from 0 -> nUnits
    if ExtractGoodUnitsOnly == False:
        for i in range(AvgWaveforms.shape[0]):
            np.save(f'Unit{i}_RawSpikes.npy', AvgWaveforms[i, :, :, :])

    # if only extracting good units
    else:
        for i, idx in enumerate(GoodUnits):
            # need idx[0] to select the scalar value so the file is saved with the correct name
            np.save(f'Unit{idx[0]}_RawSpikes.npy', AvgWaveforms[i, :, :, :])

    os.chdir(CurrentDir)

# Load in the necessary files from the KiloSort directory and the raw data directory
# for each of the n sessions
def get_raw_data_paths(RawDataDirPaths):
    """
    Requires RawDataDirPaths, a list of paths to the raw data directories, i.e. where the .cbin, .ch and .meta files are.
    Returns lists of paths to the .cbin, .ch and .meta files.
    """
    cbinPaths = []
    chPaths = []
    metaPaths = []

    for i in range(len(RawDataDirPaths)):
        for f in os.listdir(RawDataDirPaths[i]):
            name, ext = os.path.splitext(f)

            if ext == '.cbin':
                cbinPaths.append(os.path.join(RawDataDirPaths[i], name + ext))

            if ext == '.ch':
                chPaths.append(os.path.join(RawDataDirPaths[i], name + ext))

            if ext == '.meta':
                metaPaths.append(os.path.join(RawDataDirPaths[i], name + ext))

    return cbinPaths, chPaths, metaPaths
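
Usage is a one-liner; the directory names below are hypothetical:

raw_dirs = ['/data/mouse1/session1', '/data/mouse1/session2']
cbinPaths, chPaths, metaPaths = get_raw_data_paths(raw_dirs)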

def extract_KSdata(KSdirs, ExtractGoodUnitsOnly = False):
    """
    Requires KSdirs, a list of KiloSort directories, one per session.
    Loads the spike times, spike cluster IDs and, if ExtractGoodUnitsOnly is True, the good units.
    """
    nSessions = len(KSdirs)

    # load spike times
    SpikeTimes = []
    for i in range(nSessions):
        PathTmp = os.path.join(KSdirs[i], 'spike_times.npy')
        SpikeTimestmp = np.load(PathTmp)
        SpikeTimes.append(SpikeTimestmp)

    # load spike IDs
    SpikeIDs = []
    for i in range(nSessions):
        PathTmp = os.path.join(KSdirs[i], 'spike_clusters.npy')
        SpikeIDstmp = np.load(PathTmp)
        SpikeIDs.append(SpikeIDstmp)

    if ExtractGoodUnitsOnly:
        # good unit IDs
        UnitLabelPaths = []

        # load the good unit paths
        for i in range(nSessions):
            UnitLabelPaths.append(os.path.join(KSdirs[i], 'cluster_group.tsv'))

        GoodUnits = util.get_good_units(UnitLabelPaths)

        return SpikeIDs, SpikeTimes, GoodUnits

    else:
        return SpikeIDs, SpikeTimes, [None for s in range(nSessions)]
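
A hedged usage sketch; the KiloSort directory paths are hypothetical:

KSdirs = ['/data/mouse1/session1/kilosort', '/data/mouse1/session2/kilosort']
SpikeIDs, SpikeTimes, GoodUnits = extract_KSdata(KSdirs, ExtractGoodUnitsOnly=True)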