pyDANT 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pyDANT/AutoCuration.py ADDED
@@ -0,0 +1,313 @@
1
+ import numpy as np
2
+ import os
3
+ from scipy.sparse.csgraph import connected_components
4
+ from .utils import graphEditNumber
5
+ import matplotlib
6
+ matplotlib.use('Agg') # Use a non-interactive backend for matplotlib
7
+ import matplotlib.pyplot as plt
8
+
9
def autoCuration(user_settings):
    """Automatically curate the clustering results and plot summary figures.

    The function loads the clustering results, unpairs units that share a
    session with another unit of the same cluster, optionally splits clusters
    that are disconnected in the good-matches graph, drops clusters left with
    at most one unit, relabels the remaining clusters, and saves the curated
    results together with three summary figures.

    Arguments:
        - user_settings (dict): User settings. Keys read here:
          "path_to_data", "output_folder", ["autoCuration"]["auto_split"]
          and ["clustering"]["features"].

    Inputs read from disk:
        - <path_to_data>/session_index.npy
        - <output_folder>/ClusteringResults.npz

    Outputs (written to <output_folder>):
        - ClusterMatrix.npy: The connectivity matrix of clusters after curation.
        - IdxCluster.npy: The cluster index of each unit after curation.
          -1 indicates unpaired units.
        - MatchedPairs.npy: The matched pairs of units after curation.
        - CurationPairs.npy: The pairs of units that were curated.
        - CurationTypes.npy: The type of curation applied to each pair
          (1 = same-session removal, 2 = split; these look like 1-based
          indices into CurationTypeNames -- TODO confirm with consumers).
        - CurationTypeNames.npy: The names of the curation types.
        - Output.npz: A dictionary containing other information about the
          final results.
        - Figures/MatchedProbability.png, Figures/SimilarityDistribution.png,
          Figures/FeatureScatter.png
    """
    data_folder = user_settings["path_to_data"]
    output_folder = user_settings["output_folder"]

    sessions = np.load(os.path.join(data_folder, 'session_index.npy'))

    # Read everything needed up front so the archive handle is closed again.
    with np.load(os.path.join(output_folder, 'ClusteringResults.npz')) as clustering_result:
        idx_cluster_hdbscan = clustering_result['idx_cluster_hdbscan']
        good_matches_matrix = clustering_result['good_matches_matrix']
        hdbscan_matrix = clustering_result['hdbscan_matrix']
        similarity_matrix = clustering_result['similarity_matrix']
        leafOrder = clustering_result['leafOrder']
        similarity_all = clustering_result['similarity_all']
        idx_unit_pairs = clustering_result['idx_unit_pairs']
        weights = clustering_result['weights']

    # Initialize the curation bookkeeping
    curation_type_names = ['Removal_SameSession', 'Removal_LowSimilarity']
    curation_pairs = []
    curation_types = []

    n_cluster = np.max(idx_cluster_hdbscan)
    print(f'{n_cluster} clusters and {int((np.sum(hdbscan_matrix) - hdbscan_matrix.shape[0])/2)} pairs before removing bad units!')

    hdbscan_matrix_raw = hdbscan_matrix.copy()

    # Remove bad units in clusters (two units of one session in one cluster)
    _remove_same_session_units(
        idx_cluster_hdbscan, sessions, similarity_matrix, n_cluster,
        curation_pairs, curation_types)

    # Split clusters if there's a clear boundary in good_matches_matrix
    if user_settings['autoCuration']['auto_split']:
        n_cluster = _split_disconnected_clusters(
            idx_cluster_hdbscan, good_matches_matrix, n_cluster,
            curation_pairs, curation_types)

    # Drop clusters with <= 1 unit and close the gaps in the numbering
    _drop_small_clusters(idx_cluster_hdbscan, n_cluster)

    # Cluster labels must now be exactly 1..n. The check looks only at the
    # non-(-1) labels so that it also holds when no unit is unpaired.
    labels_present = np.unique(idx_cluster_hdbscan[idx_cluster_hdbscan != -1])
    assert np.array_equal(labels_present, np.arange(1, len(labels_present) + 1))
    n_cluster = np.max(idx_cluster_hdbscan)

    # Rebuild the cluster connectivity matrix from the curated labels
    hdbscan_matrix = np.zeros_like(similarity_matrix, dtype=bool)
    for k in range(1, n_cluster + 1):
        members = np.where(idx_cluster_hdbscan == k)[0]
        hdbscan_matrix[np.ix_(members, members)] = True

    np.fill_diagonal(hdbscan_matrix, True)

    # NOTE(review): graphEditNumber is a project helper; it presumably counts
    # the pairs shared by / present in the two matrices -- confirm in utils.
    num_same, num_before, num_after = graphEditNumber(hdbscan_matrix_raw, hdbscan_matrix)
    assert num_same == num_after

    num_removal = num_before - num_after
    print(f'{num_removal} deleting steps are done!')
    print(f'{n_cluster} clusters and {int((np.sum(hdbscan_matrix) - hdbscan_matrix.shape[0])/2)} pairs after removing bad units!')

    # Save curated results
    hdbscan_matrix_curated = hdbscan_matrix.copy()
    idx_cluster_hdbscan_curated = idx_cluster_hdbscan.copy()

    # All matched pairs (k, j) with k < j, in row-major order
    matched_pairs_curated = np.argwhere(np.triu(hdbscan_matrix_curated, k=1))

    # Save final output
    np.save(os.path.join(output_folder, 'ClusterMatrix.npy'), hdbscan_matrix_curated)
    np.save(os.path.join(output_folder, 'IdxCluster.npy'), idx_cluster_hdbscan_curated)
    np.save(os.path.join(output_folder, 'MatchedPairs.npy'), matched_pairs_curated)
    np.save(os.path.join(output_folder, 'CurationPairs.npy'), np.array(curation_pairs))
    np.save(os.path.join(output_folder, 'CurationTypes.npy'), np.array(curation_types))
    np.save(os.path.join(output_folder, 'CurationTypeNames.npy'), np.array(curation_type_names))

    Output = {
        'NumClusters': np.max(idx_cluster_hdbscan_curated),
        'NumUnits': len(idx_cluster_hdbscan_curated),
        'IdxSort': leafOrder,
        'IdxCluster': idx_cluster_hdbscan_curated,
        'SimilarityMatrix': similarity_matrix,
        'GoodMatchesMatrix': good_matches_matrix,
        'ClusterMatrix': hdbscan_matrix_curated,
        'MatchedPairs': matched_pairs_curated,
        'Params': user_settings,
        'NumSession': np.max(sessions),
        'Sessions': sessions
    }

    # NOTE(review): the dict is passed positionally, so the archive holds a
    # single pickled object entry ('arr_0') rather than one entry per key.
    # Left unchanged because readers may rely on this layout.
    np.savez(os.path.join(output_folder, 'Output.npz'), Output)
    print(f'DANT done! Output saved to {os.path.join(output_folder, "Output.npz")}')
    print(f'Found {Output["NumClusters"]} clusters and {len(Output["MatchedPairs"])} matches from {Output["NumUnits"]} units during {Output["NumSession"]} sessions!')

    # Plot the results
    figures_folder = os.path.join(output_folder, 'Figures')
    os.makedirs(figures_folder, exist_ok=True)  # savefig does not create folders

    _plot_matched_probability(
        idx_cluster_hdbscan_curated, sessions,
        os.path.join(figures_folder, 'MatchedProbability.png'))

    # Whether each precomputed unit pair ended up in the same curated cluster
    is_matched = hdbscan_matrix_curated[idx_unit_pairs[:, 0], idx_unit_pairs[:, 1]]
    similarity_names = user_settings['clustering']['features']

    _plot_similarity_distribution(
        similarity_all, weights, is_matched, similarity_names,
        os.path.join(figures_folder, 'SimilarityDistribution.png'))

    _plot_feature_scatter(
        similarity_all, is_matched, similarity_names,
        os.path.join(figures_folder, 'FeatureScatter.png'))


def _record_pairs(curation_pairs, curation_types, unit1, partners, curation_type):
    """Append one sorted (unit1, partner) pair per partner with the given type."""
    for unit2 in partners:
        curation_pairs.append(np.sort([unit1, unit2]))
        curation_types.append(curation_type)


def _remove_same_session_units(idx_cluster, sessions, similarity_matrix,
                               n_cluster, curation_pairs, curation_types):
    """Unpair units in place until no cluster holds two units of one session.

    For every same-session pair inside a cluster, the unit with the lower
    (or equal) mean similarity to the cluster members is set to -1 in
    ``idx_cluster``. Each removal is logged as curation type 1 against every
    unit that is still in the cluster at that point.
    """
    for k in range(1, n_cluster + 1):
        units = np.where(idx_cluster == k)[0]
        sessions_this = sessions[units]
        similarity_this = similarity_matrix[np.ix_(units, units)]

        while len(sessions_this) != len(np.unique(sessions_this)):
            mean_similarity = [np.mean(similarity_this[i, :]) for i in range(len(units))]

            idx_remove = []
            for j in range(len(sessions_this)):
                for i in range(j + 1, len(sessions_this)):
                    if sessions_this[i] != sessions_this[j]:
                        continue
                    idx_remove.append(i if mean_similarity[i] <= mean_similarity[j] else j)

            # A unit can lose several comparisons; keep each index once (in
            # first-seen order) so that its pairs are not logged repeatedly.
            idx_remove = list(dict.fromkeys(idx_remove))

            # Update curation pairs and types. Units removed earlier in this
            # pass (and the unit itself) are skipped so each broken pair is
            # logged exactly once.
            keep = np.ones(len(units), dtype=bool)
            for idx in idx_remove:
                keep[idx] = False
                _record_pairs(curation_pairs, curation_types, units[idx], units[keep], 1)

            idx_cluster[units[~keep]] = -1
            units = units[keep]
            sessions_this = sessions[units]
            similarity_this = similarity_matrix[np.ix_(units, units)]


def _split_disconnected_clusters(idx_cluster, good_matches_matrix, n_cluster,
                                 curation_pairs, curation_types):
    """Split clusters that are disconnected in the good-matches graph.

    The first connected component keeps the original label; every further
    component receives a new label appended after the current maximum.
    ``idx_cluster`` is modified in place and each separated pair is logged as
    curation type 2. Returns the new number of cluster labels.
    """
    n_cluster_new = n_cluster
    for k in range(1, n_cluster + 1):
        units = np.where(idx_cluster == k)[0]
        if len(units) < 2:
            # Nothing to split
            continue

        graph_this = good_matches_matrix[np.ix_(units, units)]
        n_sub_clusters, idx_sub_clusters = connected_components(graph_this, directed=False)

        if n_sub_clusters <= 1:
            continue

        for sub in range(1, n_sub_clusters):
            units_this = units[idx_sub_clusters == sub]
            idx_cluster[units_this] = n_cluster_new + sub

            # Pair against component 0 and the components not yet split off;
            # pairs with earlier components were logged when those were split.
            partners = units[(idx_sub_clusters == 0) | (idx_sub_clusters > sub)]
            for unit1 in units_this:
                _record_pairs(curation_pairs, curation_types, unit1, partners, 2)

        n_cluster_new += n_sub_clusters - 1

    return n_cluster_new


def _drop_small_clusters(idx_cluster, n_cluster):
    """Unpair clusters with at most one unit and renumber the rest in place."""
    removed_labels = []
    for k in range(1, n_cluster + 1):
        units = np.where(idx_cluster == k)[0]
        if len(units) <= 1:
            idx_cluster[units] = -1
            removed_labels.append(k)

    # Highest label first so that earlier shifts do not move pending labels.
    for k in sorted(removed_labels, reverse=True):
        idx_cluster[idx_cluster >= k] -= 1


def _plot_matched_probability(idx_cluster, sessions, figure_path):
    """Plot the matching probability between sessions and versus session lag.

    Sessions are used as 1-based integer indices (they are shifted by -1 to
    address the count matrix).
    """
    n_session = int(np.max(sessions))
    n_cluster = np.max(idx_cluster)
    n_matched_matrix = np.zeros((n_session, n_session))

    n_units_each_session = np.array(
        [np.sum(sessions == i) for i in range(1, n_session + 1)]
    )

    # Count matched unit pairs for every pair of sessions
    for k in range(1, n_cluster + 1):
        units = np.where(idx_cluster == k)[0]
        for j in range(len(units)):
            for i in range(j + 1, len(units)):
                session_j = sessions[units[j]] - 1
                session_i = sessions[units[i]] - 1
                n_matched_matrix[session_j, session_i] += 1
                n_matched_matrix[session_i, session_j] += 1

    d_session = np.arange(-n_session + 1, n_session)
    p_matched = [[] for _ in range(len(d_session))]
    p_matched_matrix = np.zeros((n_session, n_session))
    for k in range(n_session):
        for j in range(k + 1, n_session):
            p_relative_to_k = n_matched_matrix[k, j] / n_units_each_session[k]
            p_relative_to_j = n_matched_matrix[k, j] / n_units_each_session[j]

            p_matched_matrix[k, j] = p_relative_to_k
            p_matched_matrix[j, k] = p_relative_to_j

            # d_session[d + n_session - 1] == d
            p_matched[(j - k) + n_session - 1].append(p_relative_to_j)
            p_matched[(k - j) + n_session - 1].append(p_relative_to_k)

    # The lag-0 bin never receives a value and therefore stays NaN.
    p_matches_mean = np.array(
        [np.mean(values) if values else np.nan for values in p_matched])

    plt.figure(figsize=(12, 5))

    plt.subplot(121)
    plt.imshow(p_matched_matrix, cmap='plasma')
    plt.gca().invert_yaxis()
    plt.colorbar(label='Prob. of matched')
    plt.xlabel('Sessions')
    plt.ylabel('Sessions')

    plt.subplot(122)
    plt.plot(d_session, p_matches_mean, 'k.-')
    plt.xlabel('Δ session')
    plt.ylabel('Prob. of matched')

    plt.savefig(figure_path, dpi=300)
    plt.close()


def _plot_similarity_distribution(similarity_all, weights, is_matched,
                                  similarity_names, figure_path):
    """Plot histograms of the weighted and per-feature similarities.

    Unmatched pairs are drawn in black and matched pairs in blue.
    """
    n_features = similarity_all.shape[1]
    mean_similarity = similarity_all @ weights

    n_plots = n_features + 1
    plt.figure(figsize=(4*n_plots + 1, 4))

    plt.subplot(1, n_plots, 1)
    plt.hist(mean_similarity[~is_matched], bins=50, color='k', density=True)
    plt.hist(mean_similarity[is_matched], bins=50, color='b', density=True)
    plt.xlabel('Mean similarity')
    plt.ylabel('Density')
    plt.title('Weighted summation')

    for k in range(n_features):
        plt.subplot(1, n_plots, k + 2)
        plt.hist(similarity_all[~is_matched, k], bins=50, color='k', density=True)
        plt.hist(similarity_all[is_matched, k], bins=50, color='b', density=True)
        plt.xlabel(similarity_names[k])
        plt.title('weight = ' + f'{weights[k]:.2f}')

    plt.savefig(figure_path, dpi=300)
    plt.close()


def _plot_feature_scatter(similarity_all, is_matched, similarity_names,
                          figure_path, n_points_max=5000):
    """Scatter every pair of features against each other.

    At most ``n_points_max`` matched and ``n_points_max`` unmatched pairs are
    drawn; larger groups are randomly subsampled without replacement.
    """
    n_features = similarity_all.shape[1]

    idx_matched_rnd = np.where(is_matched)[0]
    idx_unmatched_rnd = np.where(~is_matched)[0]

    if len(idx_matched_rnd) > n_points_max:
        idx_matched_rnd = np.random.choice(idx_matched_rnd, n_points_max, replace=False)

    if len(idx_unmatched_rnd) > n_points_max:
        idx_unmatched_rnd = np.random.choice(idx_unmatched_rnd, n_points_max, replace=False)

    n_plots = n_features*(n_features - 1)//2

    plt.figure(figsize=(4*n_plots + 1, 4))
    count = 0
    for k in range(n_features):
        for j in range(k + 1, n_features):
            count += 1
            plt.subplot(1, n_plots, count)
            plt.plot(similarity_all[idx_unmatched_rnd, k], similarity_all[idx_unmatched_rnd, j], 'k.', markersize=1, alpha=0.3)
            plt.plot(similarity_all[idx_matched_rnd, k], similarity_all[idx_matched_rnd, j], 'b.', markersize=1, alpha=0.3)

            plt.xlabel(similarity_names[k])
            plt.ylabel(similarity_names[j])

    plt.savefig(figure_path, dpi=300)
    plt.close()
311
+
312
+
313
+
@@ -0,0 +1,78 @@
1
+ import numpy as np
2
+ from joblib import Parallel, delayed
3
+ from tqdm import tqdm
4
+ import os
5
+ from .utils import waveformEstimation
6
+ import copy
7
+
8
def computeWaveformFeatures(user_settings, waveform_all, motion):
    """Compute the motion-corrected waveforms of all units and save them.

    For every unit the waveform is re-estimated at a depth shifted by the
    motion of its session (via ``waveformEstimation``; Kriging interpolation
    according to the original documentation). With ``n_templates == 2`` two
    versions are produced, referenced to the minimum and to the maximum of
    the motion range respectively; for any other ``n_templates`` the motion
    is applied without an extra shift.

    Arguments:
        - user_settings (dict): User settings. Keys read here:
          "path_to_data", "output_folder", "n_jobs" and
          ["waveformCorrection"]["n_templates"].
        - waveform_all (numpy.ndarray): The waveforms of all units
          (n_unit, n_channel, n_sample)
        - motion (Motion): The motion object; its ``LinearScale``, ``Linear``
          and ``Constant`` attributes and ``get_motion(session, depth)`` are
          used here.

    Inputs read from disk:
        - <path_to_data>/channel_locations.npy
        - <path_to_data>/session_index.npy
        - <output_folder>/locations.npy

    Outputs:
        - waveforms_corrected.npy: The corrected waveforms
          (n_unit, n_channel, n_sample, n_templates), written to
          <output_folder>.
    """
    data_folder = user_settings["path_to_data"]
    output_folder = user_settings["output_folder"]
    n_templates = user_settings["waveformCorrection"]["n_templates"]

    channel_locations = np.load(os.path.join(data_folder, 'channel_locations.npy'))
    sessions = np.load(os.path.join(data_folder, 'session_index.npy'))

    locations = np.load(os.path.join(output_folder, 'locations.npy'))

    n_unit, n_channel, n_sample = waveform_all.shape[:3]

    # Motion at the deepest and shallowest channel bounds the motion range
    # (the second column of channel_locations is used as depth).
    min_channel_depth = np.min(channel_locations[:, 1])
    max_channel_depth = np.max(channel_locations[:, 1])

    motion_bottom = motion.LinearScale*motion.Linear*min_channel_depth + motion.Constant
    motion_top = motion.LinearScale*motion.Linear*max_channel_depth + motion.Constant

    motion_range = np.concatenate((motion_bottom, motion_top))
    min_motion = np.min(motion_range)
    max_motion = np.max(motion_range)
    print('The range of motion: [%.1f μm ~ %.1f μm]\n' % (min_motion, max_motion))

    def process_spike(locations_this, motion, channel_locations, waveform_this, session_this, n_templates, min_motion, max_motion):
        """Return the corrected waveforms (n_channel, n_sample, n_templates) of one unit."""
        # Work on a private copy so the caller's motion object is never modified.
        motion_this = copy.deepcopy(motion)
        base_constant = motion_this.Constant
        waveforms_corrected = np.zeros((n_channel, n_sample, n_templates))

        for k in range(n_templates):
            if n_templates == 2:
                # Always shift relative to the ORIGINAL constant. Subtracting
                # from the already shifted value would move the second
                # template by min_motion + max_motion instead of max_motion.
                shift = min_motion if k == 0 else max_motion
                motion_this.Constant = base_constant - shift

            # NOTE(review): get_motion is assumed to return the displacement
            # of this session at the given depth -- confirm in Motion.
            dy = motion_this.get_motion(session_this, locations_this[1])
            location_new = locations_this.copy()
            location_new[1] -= dy

            waveforms_corrected[:, :, k] = waveformEstimation(
                waveform_this, locations_this, channel_locations, location_new)

        return waveforms_corrected

    # Run parallel processing with progress bar
    out = Parallel(n_jobs=user_settings["n_jobs"])(
        delayed(process_spike)(locations[k, :2], motion, channel_locations, waveform_all[k, :, :], sessions[k], n_templates, min_motion, max_motion)
        for k in tqdm(range(n_unit), desc='Computing waveform features')
    )

    # Collect the per-unit results into one array
    waveforms_corrected = np.zeros((n_unit, n_channel, n_sample, n_templates))
    for k, waveforms_this in enumerate(out):
        waveforms_corrected[k, :, :, :] = waveforms_this

    # Save the corrected waveforms
    np.save(os.path.join(output_folder, 'waveforms_corrected.npy'), waveforms_corrected)