UnitMatchPy-1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,143 @@
+ import numpy as np
+
+ def check_is_in(TestArray, ParentArray):
+     """
+     Test to see if a row in TestArray is in ParentArray
+     Arguments:
+     TestArray -- ndarray (N, 2)
+     ParentArray -- ndarray (M, 2)
+
+     Returns:
+     IsIn -- ndarray (N) dtype - bool
+     """
+     IsIn = (TestArray[:, None] == ParentArray).all(-1).any(-1)
+     return IsIn
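For reference, a minimal illustration (not part of the package) of what check_is_in computes: one boolean per row of TestArray, True where that row appears anywhere in ParentArray. The array values are made up for the example.

import numpy as np

test = np.array([[0, 1], [2, 3]])    # rows to look up
parent = np.array([[2, 3], [4, 5]])  # rows to search in
# broadcasting compares every test row against every parent row
is_in = (test[:, None] == parent).all(-1).any(-1)
print(is_in)  # -> [False  True]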
+
+ def AssignUID(Output, param, ClusInfo):
+
+     AllClusterIDs = ClusInfo['OriginalID'] # each unit has a unique ID
+
+     # create arrays for the unique IDs
+     UniqueIDLiberal = np.arange(AllClusterIDs.shape[0])
+     OriUniqueID = np.arange(AllClusterIDs.shape[0])
+     UniqueIDConservative = np.arange(AllClusterIDs.shape[0])
+     UniqueID = np.arange(AllClusterIDs.shape[0])
+
+     GoodRecSesID = ClusInfo['SessionID']
+     RecOpt = np.unique(ClusInfo['SessionID'])
+     nRec = RecOpt.shape[0]
+
+     # data-driven threshold?
+     if param.get('UseDataDrivenProbThrs', False):
+         stepsz = 0.1
+         binedges = np.arange(0, 1 + stepsz, stepsz)
+         plotvec = np.arange(stepsz / 2, 1, stepsz)
+
+         hw, __ = np.histogram(np.diag(Output), bins = len(binedges), density = True)
+
+         Threshold = plotvec[np.diff(hw) > 0.1]
+     else:
+         Threshold = param['MatchThreshold']
+
+     Pairs = np.argwhere(Output > Threshold)
+     Pairs = np.delete(Pairs, np.argwhere(Pairs[:,0] == Pairs[:,1]), axis = 0) # delete self-matches
+     Pairs = np.sort(Pairs, axis = 1) # arrange so the smaller pair ID is in the first column
+     # only keep one copy of each pair, and only if both CVs agree it is a match
+     PairsUnique, Count = np.unique(Pairs, axis = 0, return_counts = True)
+     PairsUniqueFilt = np.delete(PairsUnique, Count == 1, axis = 0) # if Count == 1, only one CV flagged that pair
+
+     # get the mean probability for each match
+     ProbMean = np.nanmean(np.vstack((Output[PairsUniqueFilt[:,0], PairsUniqueFilt[:,1]], Output[PairsUniqueFilt[:,1], PairsUniqueFilt[:,0]])), axis = 0)
+     # sort by the mean probability
+     PairsProb = np.hstack((PairsUniqueFilt, ProbMean[:, np.newaxis]))
+     SortIdxs = np.argsort(-PairsProb[:,2], axis = 0) # go in descending order
+     PairsProbSorted = PairsProb[SortIdxs,:]
+
+     # create an array which has both copies of each match, e.g. (1,2) and (2,1), for easier comparison
+     PairsAll = np.zeros((PairsUniqueFilt.shape[0]*2, 2))
+     PairsAll[:PairsUniqueFilt.shape[0],:] = PairsUniqueFilt
+     PairsAll[PairsUniqueFilt.shape[0]:,:] = PairsUniqueFilt[:, (1,0)]
+
+     nMatchesConservative = 0
+     nMatchesLiberal = 0
+     nMatches = 0
+     # go through each pair and assign it to a group
+     for pair in PairsProbSorted[:,:2]:
+         pair = pair.astype(np.int16)
+
+         # get the conservative group ID for the current two units
+         UnitAConservativeID = UniqueIDConservative[pair[0]]
+         UnitBConservativeID = UniqueIDConservative[pair[1]]
+         # get all units which have the same ID
+         SameGroupIdA = np.argwhere(UniqueIDConservative == UnitAConservativeID).squeeze()
+         SameGroupIdB = np.argwhere(UniqueIDConservative == UnitBConservativeID).squeeze()
+         # reshape to a 1-d array if needed
+         if len(SameGroupIdA.shape) == 0:
+             SameGroupIdA = SameGroupIdA[np.newaxis]
+         if len(SameGroupIdB.shape) == 0:
+             SameGroupIdB = SameGroupIdB[np.newaxis]
+
+         # need to check whether pair[0] has a match with SameGroupIdB and vice versa
+         CheckPairsA = np.stack((SameGroupIdB, np.broadcast_to(np.array(pair[0]), SameGroupIdB.shape)), axis = -1)
+         CheckPairsB = np.stack((SameGroupIdA, np.broadcast_to(np.array(pair[1]), SameGroupIdA.shape)), axis = -1)
+         # delete the potential self-matches
+         CheckPairsA = np.delete(CheckPairsA, np.argwhere(CheckPairsA[:,0] == CheckPairsA[:,1]), axis = 0)
+         CheckPairsB = np.delete(CheckPairsB, np.argwhere(CheckPairsB[:,0] == CheckPairsB[:,1]), axis = 0)
+
+         if (np.logical_and(np.all(check_is_in(CheckPairsA, PairsAll)), np.all(check_is_in(CheckPairsB, PairsAll)))):
+             # if each unit matches with every unit in the other unit's group,
+             # the pair can be added as a match for all classes
+             UniqueIDConservative[pair[0]] = min(UniqueIDConservative[pair])
+             UniqueIDConservative[pair[1]] = min(UniqueIDConservative[pair])
+             nMatchesConservative += 1
+
+             UniqueID[pair[0]] = min(UniqueID[pair])
+             UniqueID[pair[1]] = min(UniqueID[pair])
+             nMatches += 1
+
+             UniqueIDLiberal[pair[0]] = min(UniqueIDLiberal[pair])
+             UniqueIDLiberal[pair[1]] = min(UniqueIDLiberal[pair])
+             nMatchesLiberal += 1
+         else:
+             # now test whether each unit matches with every unit in the other group
+             # IF they are in the same/adjacent sessions
+             UnitAID = UniqueID[pair[0]]
+             UnitBID = UniqueID[pair[1]]
+
+             SameGroupIdA = np.argwhere(UniqueID == UnitAID).squeeze()
+             SameGroupIdB = np.argwhere(UniqueID == UnitBID).squeeze()
+             if len(SameGroupIdA.shape) == 0:
+                 SameGroupIdA = SameGroupIdA[np.newaxis]
+             if len(SameGroupIdB.shape) == 0:
+                 SameGroupIdB = SameGroupIdB[np.newaxis]
+
+             CheckPairsA = np.stack((SameGroupIdB, np.broadcast_to(np.array(pair[0]), SameGroupIdB.shape)), axis = -1)
+             CheckPairsB = np.stack((SameGroupIdA, np.broadcast_to(np.array(pair[1]), SameGroupIdA.shape)), axis = -1)
+             # delete potential self-matches
+             CheckPairsA = np.delete(CheckPairsA, np.argwhere(CheckPairsA[:,0] == CheckPairsA[:,1]), axis = 0)
+             CheckPairsB = np.delete(CheckPairsB, np.argwhere(CheckPairsB[:,0] == CheckPairsB[:,1]), axis = 0)
+
+             # check whether they are in the same or adjacent sessions
+             InNearSessionA = np.abs(np.diff(ClusInfo['SessionID'][CheckPairsA])) <= 1
+             InNearSessionB = np.abs(np.diff(ClusInfo['SessionID'][CheckPairsB])) <= 1
+
+             CheckPairsNearA = CheckPairsA[InNearSessionA.squeeze()]
+             CheckPairsNearB = CheckPairsB[InNearSessionB.squeeze()]
+
+             if (np.logical_and(np.all(check_is_in(CheckPairsNearA, PairsAll)), np.all(check_is_in(CheckPairsNearB, PairsAll)))):
+                 UniqueID[pair[0]] = min(UniqueID[pair])
+                 UniqueID[pair[1]] = min(UniqueID[pair])
+                 nMatches += 1
+
+                 UniqueIDLiberal[pair[0]] = min(UniqueIDLiberal[pair])
+                 UniqueIDLiberal[pair[1]] = min(UniqueIDLiberal[pair])
+                 nMatchesLiberal += 1
+             else:
+                 UniqueIDLiberal[pair[0]] = min(UniqueIDLiberal[pair])
+                 UniqueIDLiberal[pair[1]] = min(UniqueIDLiberal[pair])
+                 nMatchesLiberal += 1
+
+     print(f'Number of Liberal Matches: {nMatchesLiberal}')
+     print(f'Number of Intermediate Matches: {nMatches}')
+     print(f'Number of Conservative Matches: {nMatchesConservative}')
+     return [UniqueIDLiberal, UniqueID, UniqueIDConservative, OriUniqueID]
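To orient readers, a hedged usage sketch of AssignUID. The shapes and dictionary keys below are inferred from the code above; the values are illustrative only, not documented defaults.

import numpy as np

nUnits = 4
rng = np.random.default_rng(0)
Output = rng.uniform(size = (nUnits, nUnits))      # pairwise match probabilities; (i, j) and (j, i) hold the two CV halves
ClusInfo = {'OriginalID': np.arange(nUnits),       # original cluster ID per unit
            'SessionID': np.array([0, 0, 1, 1])}   # recording session per unit
param = {'MatchThreshold': 0.5}                    # fixed threshold; 'UseDataDrivenProbThrs' omitted

UniqueIDLiberal, UniqueID, UniqueIDConservative, OriUniqueID = AssignUID(Output, param, ClusInfo)
# units that share a value in UniqueID are treated as the same neuron tracked across sessions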
@@ -0,0 +1,66 @@
+ import numpy as np
+ import UnitMatchPy.Param_fun as pf
+
+ def get_ParameterKernels(Scores2Include, labels, Cond, param, addone = 1):
+     """
+     Requires Scores2Include, a dictionary where the keys are the metrics used and the values are
+     (nUnits, nUnits) arrays with the score for each pair of units.
+
+     Smoothing and add-one are applied to compensate for the fact that the histogram used to estimate
+     the probability distribution has few values; smoothing nearby peaks and troughs makes the estimate
+     more similar to the true distribution and reduces shot noise.
+     """
+
+     ScoreVector = param['ScoreVector']
+     Bins = param['Bins']
+     SmoothProb = param['SmoothProb']
+
+     ParameterKernels = np.full((len(ScoreVector), len(Scores2Include), len(Cond)), np.nan)
+
+     ScoreID = 0
+     for sc in Scores2Include:
+         Scorestmp = Scores2Include[sc]
+
+         SmoothTmp = SmoothProb # not using different smoothing per score for now (default is the same)
+
+         for Ck in range(len(Cond)):
+
+             HistTmp, __ = np.histogram(Scorestmp[labels == Ck], Bins)
+             ParameterKernels[:, ScoreID, Ck] = pf.smooth(HistTmp, SmoothTmp)
+             ParameterKernels[:, ScoreID, Ck] /= np.sum(ParameterKernels[:, ScoreID, Ck])
+
+             ParameterKernels[:, ScoreID, Ck] += addone * np.min(ParameterKernels[ParameterKernels[:, ScoreID, Ck] != 0, ScoreID, Ck], axis = 0)
+
+         ScoreID += 1
+
+     return ParameterKernels
+
+ def apply_naive_bayes(ParameterKernels, Priors, Predictors, param, Cond):
+     """
+     Using the parameter kernels, priors and predictors, calculate the probability that each pair of
+     units is a match.
+     """
+     ScoreVector = param['ScoreVector']
+
+     nPairs = Predictors.shape[0] ** 2
+
+     unravel = np.reshape(Predictors, (Predictors.shape[0] * Predictors.shape[1], Predictors.shape[2], 1))
+     x1 = np.tile(unravel, (1, 1, len(ScoreVector)))
+     tmp = np.expand_dims(ScoreVector, axis = (0, 1))
+     x2 = np.tile(tmp, (x1.shape[0], x1.shape[1], 1))
+     minidx = np.argmin(np.abs(x1 - x2), axis = 2)
+
+     likelihood = np.full((nPairs, len(Cond)), np.nan)
+     for Ck in range(len(Cond)):
+         tmpp = np.zeros_like(minidx, np.float64)
+         for yy in range(minidx.shape[1]):
+             tmpp[:, yy] = ParameterKernels[minidx[:, yy], yy, Ck]
+         likelihood[:, Ck] = np.prod(tmpp, axis = 1)
+
+     Probability = np.full((nPairs, len(Cond)), np.nan)
+     for Ck in range(len(Cond)):
+         Probability[:, Ck] = Priors[Ck] * likelihood[:, Ck] / np.nansum((Priors * likelihood), axis = 1)
+
+     return Probability
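A hedged sketch of how the two functions above might be chained, assuming they are in scope. Every name and value below is illustrative; only the shape relations are taken from the code (ParameterKernels has len(ScoreVector) rows, so Bins needs one more edge than ScoreVector has points).

import numpy as np

nUnits, nScores = 10, 3
ScoreVector = np.linspace(0, 1, 101)
param = {'ScoreVector': ScoreVector,
         'Bins': np.linspace(0, 1, 102),   # len(ScoreVector) + 1 edges -> len(ScoreVector) histogram counts
         'SmoothProb': 9}                  # illustrative smoothing window for pf.smooth
Cond = [0, 1]                              # e.g. non-match / match classes
Priors = np.array([0.9, 0.1])              # prior probability of each class

Scores2Include = {f'score{k}': np.random.rand(nUnits, nUnits) for k in range(nScores)}
labels = np.random.randint(0, 2, (nUnits, nUnits))               # candidate class label per pair
Predictors = np.stack(list(Scores2Include.values()), axis = 2)   # (nUnits, nUnits, nScores)

kernels = get_ParameterKernels(Scores2Include, labels, Cond, param, addone = 1)
prob = apply_naive_bayes(kernels, Priors, Predictors, param, Cond)
match_prob = prob[:, 1].reshape(nUnits, nUnits)                  # probability of the second class per pair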
@@ -0,0 +1,216 @@
+ # Functions for extracting and averaging raw data
+
+ import os
+ import numpy as np
+ from pathlib import Path
+ from scipy.ndimage import gaussian_filter
+ from mtscomp import decompress
+ from joblib import Parallel, delayed
+ import UnitMatchPy.utils as util
+
+ # Decompressed data functions
+ def Read_Meta(metaPath):
+     "Read in meta data as a dictionary"
+     metaDict = {}
+     with metaPath.open() as f:
+         mdatList = f.read().splitlines()
+         # convert the list entries into key-value pairs
+         for m in mdatList:
+             csList = m.split(sep='=')
+             if csList[0][0] == '~':
+                 currKey = csList[0][1:len(csList[0])]
+             else:
+                 currKey = csList[0]
+             metaDict.update({currKey: csList[1]})
+
+     return metaDict
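A small usage sketch for Read_Meta; the path is a placeholder, and the key shown follows the SpikeGLX .meta convention (all values come back as strings).

from pathlib import Path

meta = Read_Meta(Path('/data/session1/raw/recording_g0_t0.imec0.ap.meta'))  # placeholder path
nChannels = int(meta['nSavedChans'])   # number of saved channels, parsed from its string value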
+
+ def get_sample_idx(SpikeTimes, UnitIDs, SampleAmount, units):
+     """
+     Needs spike times, unit IDs (from the KiloSort directory) and the maximum number of samples per unit.
+     Returns a (nUnits, SampleAmount) array with the spike times to sample for every unit; the selected spikes
+     are evenly spaced over time, and the row is filled with NaN if the unit has fewer spikes than SampleAmount.
+     """
+
+     UniqueUnitIDs = np.unique(UnitIDs)
+     nUnitsALL = len(UniqueUnitIDs)
+
+     SampleIdx = np.zeros((nUnitsALL, SampleAmount))
+     # process every requested unit
+     for i, idx in enumerate(units):
+         UnitTimes = SpikeTimes[UnitIDs == idx]
+         if SampleAmount < len(UnitTimes):
+             ChooseIdx = np.linspace(0, len(UnitTimes) - 1, SampleAmount, dtype = int) # -1 so we cannot index out of range
+             SampleIdx[i,:] = UnitTimes[ChooseIdx]
+         else:
+             SampleIdx[i,:len(UnitTimes)] = UnitTimes
+             SampleIdx[i,len(UnitTimes):] = np.nan
+
+     return SampleIdx
+
+ def Extract_A_Unit(SampleIdx, Data, HalfWidth, SpikeWidth, nChannels, SampleAmount):
+     """
+     This function extracts and averages the raw data for a unit, and splits the unit's spikes into two
+     halves for cross-verification. Returns AvgWaveforms with shape (SpikeWidth, nChannels, 2).
+
+     NOTE - here SampleIdx is an array of shape (SampleAmount,), i.e. use SampleIdx[UnitIdx] to get the
+     average waveform for that unit.
+     """
+
+     Channels = np.arange(0, nChannels)
+
+     AllSampleWaveforms = np.zeros((SampleAmount, SpikeWidth, nChannels))
+     for i, idx in enumerate(SampleIdx[:]):
+         if np.isnan(idx):
+             continue
+         tmp = Data[int(idx - HalfWidth - 1):int(idx + HalfWidth - 1), Channels] # -1, to better fit with the MATLAB version
+         tmp = tmp.astype(np.float32)
+         # Gaussian smooth over time, window = 5, sigma = window size / 5
+         tmp = gaussian_filter(tmp, 1, radius = 2, axes = 0) # edges are handled differently to the MATLAB version
+         # window ~ radius * 2 + 1
+         tmp = tmp - np.mean(tmp[:20,:], axis = 0)
+         AllSampleWaveforms[i] = tmp
+
+     # take the median and split into CVs
+     nWavs = np.sum(~np.isnan(SampleIdx[:]))
+     CVlim = np.floor(nWavs / 2).astype(int)
+
+     # find the median over samples
+     AvgWaveforms = np.zeros((SpikeWidth, nChannels, 2))
+     AvgWaveforms[:, :, 0] = np.median(AllSampleWaveforms[:CVlim, :, :], axis = 0) # median over samples
+     AvgWaveforms[:, :, 1] = np.median(AllSampleWaveforms[CVlim:nWavs, :, :], axis = 0) # median over samples
+     return AvgWaveforms
+
+
+ def Extract_A_UnitKS4(SampleIdx, Data, SamplesBefore, SamplesAfter, SpikeWidth, nChannels, SampleAmount):
+     """
+     This function extracts and averages the raw data for a unit, and splits the unit's spikes into two
+     halves for cross-verification. Returns AvgWaveforms with shape (SpikeWidth, nChannels, 2).
+
+     NOTE - here SampleIdx is an array of shape (SampleAmount,), i.e. use SampleIdx[UnitIdx] to get the
+     average waveform for that unit.
+     """
+
+     Channels = np.arange(0, nChannels)
+
+     AllSampleWaveforms = np.zeros((SampleAmount, SpikeWidth, nChannels))
+     for i, idx in enumerate(SampleIdx[:]):
+         if np.isnan(idx):
+             continue
+         tmp = Data[int(idx - SamplesBefore - 1):int(idx + SamplesAfter - 1), Channels] # -1, to better fit with the MATLAB version
+         tmp = tmp.astype(np.float32)
+         # Gaussian smooth over time, window = 5, sigma = window size / 5
+         tmp = gaussian_filter(tmp, 1, radius = 2, axes = 0) # edges are handled differently to the MATLAB version
+         # window ~ radius * 2 + 1
+         tmp = tmp - np.mean(tmp[:20,:], axis = 0)
+         AllSampleWaveforms[i] = tmp
+
+     # take the median and split into CVs
+     nWavs = np.sum(~np.isnan(SampleIdx[:]))
+     CVlim = np.floor(nWavs / 2).astype(int)
+
+     # find the median over samples
+     AvgWaveforms = np.zeros((SpikeWidth, nChannels, 2))
+     AvgWaveforms[:, :, 0] = np.median(AllSampleWaveforms[:CVlim, :, :], axis = 0) # median over samples
+     AvgWaveforms[:, :, 1] = np.median(AllSampleWaveforms[CVlim:nWavs, :, :], axis = 0) # median over samples
+     return AvgWaveforms
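A hedged sketch of how the extraction helpers above might be used for one session. The window sizes, the memory-mapped .bin path and the joblib parallelisation are assumptions made for illustration, not documented usage; get_sample_idx and Extract_A_Unit are the functions defined above.

import numpy as np
from joblib import Parallel, delayed

SampleAmount = 1000                    # max spikes sampled per unit (illustrative)
SpikeWidth = 82                        # samples per snippet (illustrative)
HalfWidth = SpikeWidth // 2
nChannels = 384

# hypothetical decompressed recording, shape (nSamples, nChannels)
Data = np.memmap('/data/session1/raw/recording.ap.bin', dtype = np.int16, mode = 'r').reshape(-1, nChannels)

SpikeTimes = np.load('/data/session1/kilosort/spike_times.npy').squeeze()    # hypothetical KiloSort outputs
SpikeIDs = np.load('/data/session1/kilosort/spike_clusters.npy').squeeze()
units = np.unique(SpikeIDs)

SampleIdx = get_sample_idx(SpikeTimes, SpikeIDs, SampleAmount, units)        # (nUnits, SampleAmount)
# one job per unit; each returns an array of shape (SpikeWidth, nChannels, 2)
AvgWaveforms = Parallel(n_jobs = -1)(
    delayed(Extract_A_Unit)(SampleIdx[u], Data, HalfWidth, SpikeWidth, nChannels, SampleAmount)
    for u in range(SampleIdx.shape[0]))
AvgWaveforms = np.stack(AvgWaveforms)                                        # (nUnits, SpikeWidth, nChannels, 2)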
+
+
+ def Save_AvgWaveforms(AvgWaveforms, SaveDir, GoodUnits, ExtractGoodUnitsOnly = False):
+     """
+     Will save the extracted average waveforms in a folder called 'RawWaveforms' in the given SaveDir.
+     Each waveform will be saved in a unique .npy file called 'UnitX_RawSpikes.npy'.
+     Supply GoodUnits, an array of which indices are included, if you are not extracting all units
+     from the recording session.
+     """
+     CurrentDir = os.getcwd()
+     os.chdir(SaveDir)
+     DirList = os.listdir()
+     if 'RawWaveforms' not in DirList:
+         os.mkdir('RawWaveforms')
+     TmpPath = os.path.join(SaveDir, 'RawWaveforms')
+
+     os.chdir(TmpPath)
+
+     # first axis is each unit
+
+     # ALL waveforms from 0 -> nUnits
+     if not ExtractGoodUnitsOnly:
+         for i in range(AvgWaveforms.shape[0]):
+             np.save(f'Unit{i}_RawSpikes.npy', AvgWaveforms[i,:,:,:])
+
+     # if only extracting good units
+     else:
+         for i, idx in enumerate(GoodUnits):
+             # need idx[0] to select the value so it saves with the correct name
+             np.save(f'Unit{idx[0]}_RawSpikes.npy', AvgWaveforms[i,:,:,:])
+
+     os.chdir(CurrentDir)
+
+
+ # Load in necessary files from the KiloSort directories and raw data directories
+ # for the n sessions being extracted
+ def get_raw_data_paths(RawDataDirPaths):
+     """
+     This function requires RawDataDirPaths, a list of paths to the raw data directories, e.g. where the
+     .cbin, .ch and .meta files are.
+     It returns lists of paths to the .cbin, .ch and .meta files.
+     """
+     cbinPaths = []
+     chPaths = []
+     metaPaths = []
+
+     for i in range(len(RawDataDirPaths)):
+         for f in os.listdir(RawDataDirPaths[i]):
+             name, ext = os.path.splitext(f)
+
+             if ext == '.cbin':
+                 cbinPaths.append(os.path.join(RawDataDirPaths[i], name + ext))
+
+             if ext == '.ch':
+                 chPaths.append(os.path.join(RawDataDirPaths[i], name + ext))
+
+             if ext == '.meta':
+                 metaPaths.append(os.path.join(RawDataDirPaths[i], name + ext))
+
+     return cbinPaths, chPaths, metaPaths
+
+
+ def extract_KSdata(KSdirs, ExtractGoodUnitsOnly = False):
+     """
+     This function requires KSdirs, a list of KiloSort directories, one per session.
+     It will load in the spike times, spike IDs and, if requested, the good units.
+     """
+     nSessions = len(KSdirs)
+
+     # load spike times
+     SpikeTimes = []
+     for i in range(nSessions):
+         PathTmp = os.path.join(KSdirs[i], 'spike_times.npy')
+         SpikeTimestmp = np.load(PathTmp)
+         SpikeTimes.append(SpikeTimestmp)
+
+     # load spike IDs
+     SpikeIDs = []
+     for i in range(nSessions):
+         PathTmp = os.path.join(KSdirs[i], 'spike_clusters.npy')
+         SpikeIDstmp = np.load(PathTmp)
+         SpikeIDs.append(SpikeIDstmp)
+
+     if ExtractGoodUnitsOnly:
+         # good unit IDs
+         UnitLabelPaths = []
+
+         # load good unit paths
+         for i in range(nSessions):
+             UnitLabelPaths.append(os.path.join(KSdirs[i], 'cluster_group.tsv'))
+
+         GoodUnits = util.get_good_units(UnitLabelPaths)
+
+         return SpikeIDs, SpikeTimes, GoodUnits
+
+     else:
+         return SpikeIDs, SpikeTimes, [None for s in range(nSessions)]
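Finally, a hedged sketch of how the session-level helpers in this file chain together. The paths are placeholders, and the waveform-extraction step is summarised in comments (see the sketch after Extract_A_UnitKS4 above).

from pathlib import Path

KSdirs = ['/data/session1/kilosort', '/data/session2/kilosort']   # placeholder paths
RawDataDirPaths = ['/data/session1/raw', '/data/session2/raw']

cbinPaths, chPaths, metaPaths = get_raw_data_paths(RawDataDirPaths)
SpikeIDs, SpikeTimes, GoodUnits = extract_KSdata(KSdirs, ExtractGoodUnitsOnly = True)

for s in range(len(KSdirs)):
    meta = Read_Meta(Path(metaPaths[s]))
    nChannels = int(meta['nSavedChans'])      # SpikeGLX meta key, stored as a string
    # 1. open/decompress the raw recording for session s (e.g. the .cbin/.ch pair) as Data
    # 2. build SampleIdx with get_sample_idx and extract per-unit AvgWaveforms (see sketch above)
    # 3. save one .npy per good unit into KSdirs[s]/RawWaveforms:
    # Save_AvgWaveforms(AvgWaveforms, KSdirs[s], GoodUnits[s], ExtractGoodUnitsOnly = True)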