pyDANT 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyDANT/AutoCuration.py +313 -0
- pyDANT/ComputeWaveformFeatures.py +78 -0
- pyDANT/IterativeClustering.py +307 -0
- pyDANT/MotionEstimation.py +232 -0
- pyDANT/Preprocess.py +268 -0
- pyDANT/__init__.py +5 -0
- pyDANT/utils.py +292 -0
- pydant-0.0.7.dist-info/METADATA +80 -0
- pydant-0.0.7.dist-info/RECORD +12 -0
- pydant-0.0.7.dist-info/WHEEL +5 -0
- pydant-0.0.7.dist-info/licenses/LICENSE +674 -0
- pydant-0.0.7.dist-info/top_level.txt +1 -0
pyDANT/AutoCuration.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import os
|
|
3
|
+
from scipy.sparse.csgraph import connected_components
|
|
4
|
+
from .utils import graphEditNumber
|
|
5
|
+
import matplotlib
|
|
6
|
+
matplotlib.use('Agg') # Use a non-interactive backend for matplotlib
|
|
7
|
+
import matplotlib.pyplot as plt
|
|
8
|
+
|
|
9
|
+
def autoCuration(user_settings):
    """Automatic curation of clustering results.

    Perform automatic curation of the clustering results saved in
    ``ClusteringResults.npz``:

    1. In every cluster, units that share a session with another unit of the
       same cluster are unassigned: for each such pair, the unit with the
       lower mean similarity to the cluster is dropped.
    2. Optionally (``user_settings['autoCuration']['auto_split']``), clusters
       whose units are not connected in ``good_matches_matrix`` are split
       into their connected components.
    3. Clusters left with at most one unit are dissolved and the remaining
       clusters are renumbered 1..N.

    The curated results and three diagnostic figures are then saved.

    Arguments:
        - user_settings (dict): User settings. Keys read here: ``path_to_data``,
          ``output_folder``, ``autoCuration.auto_split`` and ``clustering.features``.

    Outputs (written to ``output_folder``):
        - ClusterMatrix.npy: The connectivity matrix of clusters after curation.
        - IdxCluster.npy: The cluster index of each unit after curation. -1 indicates unpaired units.
        - MatchedPairs.npy: The matched unit pairs (k, j), k < j, after curation.
        - CurationPairs.npy: The unit pairs whose match was removed by curation.
        - CurationTypes.npy: The curation type of each pair (1-based index into CurationTypeNames).
        - CurationTypeNames.npy: The names of the curation types.
        - Output.npz: The ``Output`` dict, stored as one pickled object under key
          ``arr_0`` (load with ``np.load(path, allow_pickle=True)['arr_0'].item()``).
        - Figures/MatchedProbability.png, Figures/SimilarityDistribution.png,
          Figures/FeatureScatter.png: Diagnostic plots.
    """
    data_folder = user_settings["path_to_data"]
    output_folder = user_settings["output_folder"]
    figure_folder = os.path.join(output_folder, 'Figures')

    sessions = np.load(os.path.join(data_folder, 'session_index.npy'))

    # Read everything needed from the archive up front so the file is closed.
    with np.load(os.path.join(output_folder, 'ClusteringResults.npz')) as clustering_result:
        idx_cluster_hdbscan = clustering_result['idx_cluster_hdbscan']
        good_matches_matrix = clustering_result['good_matches_matrix']
        hdbscan_matrix_raw = clustering_result['hdbscan_matrix']
        similarity_matrix = clustering_result['similarity_matrix']
        leafOrder = clustering_result['leafOrder']
        similarity_all = clustering_result['similarity_all']
        idx_unit_pairs = clustering_result['idx_unit_pairs']
        weights = clustering_result['weights']

    # Codes stored in CurationTypes.npy are 1-based indices into this list.
    curation_type_names = ['Removal_SameSession', 'Removal_LowSimilarity']
    curation_pairs = []
    curation_types = []

    n_cluster = np.max(idx_cluster_hdbscan)
    print(f'{n_cluster} clusters and {_count_pairs(hdbscan_matrix_raw)} pairs before removing bad units!')

    # Remove units that share a session with another unit of the same cluster
    removal_pairs = _remove_same_session_units(idx_cluster_hdbscan, sessions, similarity_matrix, n_cluster)
    curation_pairs.extend(removal_pairs)
    curation_types.extend([1] * len(removal_pairs))

    # Split clusters if there's a clear boundary in good_matches_matrix
    if user_settings['autoCuration']['auto_split']:
        n_cluster, split_pairs = _split_disconnected_clusters(idx_cluster_hdbscan, good_matches_matrix, n_cluster)
        curation_pairs.extend(split_pairs)
        curation_types.extend([2] * len(split_pairs))

    # Dissolve clusters with at most one unit and renumber the rest to 1..N
    _drop_small_clusters(idx_cluster_hdbscan, n_cluster)
    positive_labels = np.unique(idx_cluster_hdbscan[idx_cluster_hdbscan > 0])
    assert np.array_equal(positive_labels, np.arange(1, len(positive_labels) + 1)), \
        'Cluster labels must be contiguous from 1 after curation'
    n_cluster = np.max(idx_cluster_hdbscan)

    hdbscan_matrix_curated = _build_cluster_matrix(idx_cluster_hdbscan, similarity_matrix.shape, n_cluster)

    num_same, num_before, num_after = graphEditNumber(hdbscan_matrix_raw, hdbscan_matrix_curated)
    assert num_same == num_after

    num_removal = num_before - num_after
    print(f'{num_removal} deleting steps are done!')
    print(f'{n_cluster} clusters and {_count_pairs(hdbscan_matrix_curated)} pairs after removing bad units!')

    idx_cluster_hdbscan_curated = idx_cluster_hdbscan.copy()

    # All matched pairs (k, j) with k < j, in row-major order (same order as a k/j double loop)
    matched_pairs_curated = np.argwhere(np.triu(hdbscan_matrix_curated, k=1))

    # Save final output
    np.save(os.path.join(output_folder, 'ClusterMatrix.npy'), hdbscan_matrix_curated)
    np.save(os.path.join(output_folder, 'IdxCluster.npy'), idx_cluster_hdbscan_curated)
    np.save(os.path.join(output_folder, 'MatchedPairs.npy'), matched_pairs_curated)
    np.save(os.path.join(output_folder, 'CurationPairs.npy'), np.array(curation_pairs))
    np.save(os.path.join(output_folder, 'CurationTypes.npy'), np.array(curation_types))
    np.save(os.path.join(output_folder, 'CurationTypeNames.npy'), np.array(curation_type_names))

    Output = {
        'NumClusters': np.max(idx_cluster_hdbscan_curated),
        'NumUnits': len(idx_cluster_hdbscan_curated),
        'IdxSort': leafOrder,
        'IdxCluster': idx_cluster_hdbscan_curated,
        'SimilarityMatrix': similarity_matrix,
        'GoodMatchesMatrix': good_matches_matrix,
        'ClusterMatrix': hdbscan_matrix_curated,
        'MatchedPairs': matched_pairs_curated,
        'Params': user_settings,
        'NumSession': np.max(sessions),
        'Sessions': sessions
    }

    # Passing the dict positionally stores it as a single pickled object array
    # under key 'arr_0'; kept as-is so existing readers of Output.npz still work.
    np.savez(os.path.join(output_folder, 'Output.npz'), Output)
    print(f'DANT done! Output saved to {os.path.join(output_folder, "Output.npz")}')
    print(f'Found {Output["NumClusters"]} clusters and {len(Output["MatchedPairs"])} matches from {Output["NumUnits"]} units during {Output["NumSession"]} sessions!')

    # Plot the results (savefig does not create missing folders)
    os.makedirs(figure_folder, exist_ok=True)

    _plot_matched_probability(sessions, idx_cluster_hdbscan,
                              os.path.join(figure_folder, 'MatchedProbability.png'))

    similarity_names = user_settings['clustering']['features']
    pairs = idx_unit_pairs[:similarity_all.shape[0]]
    is_matched = hdbscan_matrix_curated[pairs[:, 0], pairs[:, 1]]

    _plot_similarity_distribution(similarity_all, weights, is_matched, similarity_names,
                                  os.path.join(figure_folder, 'SimilarityDistribution.png'))
    _plot_feature_scatter(similarity_all, is_matched, similarity_names,
                          os.path.join(figure_folder, 'FeatureScatter.png'))


def _count_pairs(cluster_matrix):
    """Return the number of off-diagonal pairs in a symmetric matrix whose diagonal is all True."""
    return int((np.sum(cluster_matrix) - cluster_matrix.shape[0]) / 2)


def _remove_same_session_units(idx_cluster, sessions, similarity_matrix, n_cluster):
    """Unassign units that share a session with another unit of their cluster.

    For every same-session pair within cluster k, the unit with the lower mean
    similarity to the cluster members is set to -1 in ``idx_cluster`` (in place).

    Returns:
        list of sorted ``[unit1, unit2]`` arrays: every cluster pair broken by a
        removal, each recorded once and never pairing a unit with itself.
    """
    broken_pairs = []
    for k in range(1, n_cluster + 1):
        units = np.where(idx_cluster == k)[0]
        sessions_this = sessions[units]
        similarity_this = similarity_matrix[np.ix_(units, units)]

        while len(sessions_this) != len(np.unique(sessions_this)):
            row_mean = [np.mean(similarity_this[i, :]) for i in range(len(units))]

            idx_remove = []
            for j in range(len(units)):
                for i in range(j + 1, len(units)):
                    if sessions_this[i] == sessions_this[j]:
                        idx_remove.append(i if row_mean[i] <= row_mean[j] else j)
            # A unit can lose several comparisons; keep each once, in first-seen order.
            idx_remove = list(dict.fromkeys(idx_remove))

            # Pair each removed unit with every unit not removed before it (nor itself),
            # so every broken pair is recorded exactly once.
            already_removed = set()
            for r in idx_remove:
                already_removed.add(r)
                for i, unit2 in enumerate(units):
                    if i not in already_removed:
                        broken_pairs.append(np.sort([units[r], unit2]))

            idx_remove = np.unique(idx_remove)
            idx_cluster[units[idx_remove]] = -1
            units = np.delete(units, idx_remove)
            sessions_this = sessions[units]
            similarity_this = similarity_matrix[np.ix_(units, units)]

    return broken_pairs


def _split_disconnected_clusters(idx_cluster, good_matches_matrix, n_cluster):
    """Split clusters whose units form several connected components in good_matches_matrix.

    Component 0 keeps label k; components 1..n-1 receive new labels after the
    current maximum. ``idx_cluster`` is updated in place.

    Returns:
        (new number of clusters, list of sorted ``[unit1, unit2]`` arrays for the
        pairs separated by the splits).
    """
    broken_pairs = []
    n_cluster_new = n_cluster
    for k in range(1, n_cluster + 1):
        units = np.where(idx_cluster == k)[0]
        if len(units) < 2:
            continue
        n_sub_clusters, idx_sub_clusters = connected_components(
            good_matches_matrix[np.ix_(units, units)], directed=False)
        if n_sub_clusters <= 1:
            continue

        for c in range(1, n_sub_clusters):
            units_this = units[idx_sub_clusters == c]
            idx_cluster[units_this] = n_cluster_new + c

            # Pairs with components 1..c-1 were already recorded when those were split off.
            others = units[(idx_sub_clusters == 0) | (idx_sub_clusters > c)]
            for unit1 in units_this:
                for unit2 in others:
                    broken_pairs.append(np.sort([unit1, unit2]))

        n_cluster_new += n_sub_clusters - 1

    return n_cluster_new, broken_pairs


def _drop_small_clusters(idx_cluster, n_cluster):
    """Set clusters with at most one unit to -1 and renumber the remaining labels contiguously (in place)."""
    idx_remove = []
    for k in range(1, n_cluster + 1):
        units = np.where(idx_cluster == k)[0]
        if len(units) <= 1:
            idx_cluster[units] = -1
            idx_remove.append(k)

    # Highest first, so earlier shifts do not disturb labels still to be processed.
    for k in sorted(idx_remove, reverse=True):
        idx_cluster[idx_cluster > k] -= 1


def _build_cluster_matrix(idx_cluster, shape, n_cluster):
    """Return a boolean matrix that is True for every pair of units in the same cluster and on the diagonal."""
    cluster_matrix = np.zeros(shape, dtype=bool)
    for k in range(1, n_cluster + 1):
        idx = np.where(idx_cluster == k)[0]
        cluster_matrix[np.ix_(idx, idx)] = True
    np.fill_diagonal(cluster_matrix, True)
    return cluster_matrix


def _plot_matched_probability(sessions, idx_cluster, figure_path):
    """Plot the probability of a unit being matched between sessions, as a matrix and versus session gap.

    Assumes session indices run from 1 to max(sessions).
    """
    n_session = np.max(sessions)
    n_cluster = np.max(idx_cluster)
    n_matched_matrix = np.zeros((n_session, n_session))
    n_units_each_session = np.array([np.sum(sessions == i) for i in range(1, n_session + 1)])

    for k in range(1, n_cluster + 1):
        units = np.where(idx_cluster == k)[0]
        for j in range(len(units)):
            for i in range(j + 1, len(units)):
                s_j = sessions[units[j]] - 1
                s_i = sessions[units[i]] - 1
                n_matched_matrix[s_j, s_i] += 1
                n_matched_matrix[s_i, s_j] += 1

    d_session = np.arange(-n_session + 1, n_session)
    offset = n_session - 1  # position of Δ = 0 in d_session
    p_matched = [[] for _ in range(len(d_session))]
    p_matched_matrix = np.zeros((n_session, n_session))
    for k in range(n_session):
        for j in range(k + 1, n_session):
            n_kj = n_matched_matrix[k, j]
            p_matched_matrix[k, j] = n_kj / n_units_each_session[k]
            p_matched_matrix[j, k] = n_kj / n_units_each_session[j]
            p_matched[offset + (j - k)].append(n_kj / n_units_each_session[j])
            p_matched[offset + (k - j)].append(n_kj / n_units_each_session[k])

    # Δ = 0 never receives samples; use NaN directly instead of np.mean([]) (which warns).
    p_matches_mean = np.array([np.mean(p) if p else np.nan for p in p_matched])

    plt.figure(figsize=(12, 5))

    plt.subplot(121)
    plt.imshow(p_matched_matrix, cmap='plasma')
    plt.gca().invert_yaxis()
    plt.colorbar(label='Prob. of matched')
    plt.xlabel('Sessions')
    plt.ylabel('Sessions')

    plt.subplot(122)
    plt.plot(d_session, p_matches_mean, 'k.-')
    plt.xlabel('Δ session')
    plt.ylabel('Prob. of matched')

    plt.savefig(figure_path, dpi=300)
    plt.close()


def _plot_similarity_distribution(similarity_all, weights, is_matched, similarity_names, figure_path):
    """Plot histograms of the weighted similarity and of each feature, for matched vs unmatched pairs."""
    n_features = similarity_all.shape[1]
    mean_similarity = similarity_all @ weights

    n_plots = n_features + 1
    plt.figure(figsize=(4 * n_plots + 1, 4))

    plt.subplot(1, n_plots, 1)
    plt.hist(mean_similarity[~is_matched], bins=50, color='k', density=True)
    plt.hist(mean_similarity[is_matched], bins=50, color='b', density=True)
    plt.xlabel('Mean similarity')
    plt.ylabel('Density')
    plt.title('Weighted summation')

    for k in range(n_features):
        plt.subplot(1, n_plots, k + 2)
        plt.hist(similarity_all[~is_matched, k], bins=50, color='k', density=True)
        plt.hist(similarity_all[is_matched, k], bins=50, color='b', density=True)
        plt.xlabel(similarity_names[k])
        plt.title('weight = ' + f'{weights[k]:.2f}')

    plt.savefig(figure_path, dpi=300)
    plt.close()


def _plot_feature_scatter(similarity_all, is_matched, similarity_names, figure_path, n_points_max=5000):
    """Scatter every pair of features for matched vs unmatched pairs, subsampled to n_points_max each."""
    n_features = similarity_all.shape[1]
    idx_matched_rnd = np.where(is_matched)[0]
    idx_unmatched_rnd = np.where(~is_matched)[0]

    if len(idx_matched_rnd) > n_points_max:
        idx_matched_rnd = np.random.choice(idx_matched_rnd, n_points_max, replace=False)
    if len(idx_unmatched_rnd) > n_points_max:
        idx_unmatched_rnd = np.random.choice(idx_unmatched_rnd, n_points_max, replace=False)

    n_plots = n_features * (n_features - 1) // 2
    plt.figure(figsize=(4 * n_plots + 1, 4))
    count = 0
    for k in range(n_features):
        for j in range(k + 1, n_features):
            count += 1
            plt.subplot(1, n_plots, count)
            plt.plot(similarity_all[idx_unmatched_rnd, k], similarity_all[idx_unmatched_rnd, j],
                     'k.', markersize=1, alpha=0.3)
            plt.plot(similarity_all[idx_matched_rnd, k], similarity_all[idx_matched_rnd, j],
                     'b.', markersize=1, alpha=0.3)
            plt.xlabel(similarity_names[k])
            plt.ylabel(similarity_names[j])

    plt.savefig(figure_path, dpi=300)
    plt.close()
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from joblib import Parallel, delayed
|
|
3
|
+
from tqdm import tqdm
|
|
4
|
+
import os
|
|
5
|
+
from .utils import waveformEstimation
|
|
6
|
+
import copy
|
|
7
|
+
|
|
8
|
+
def computeWaveformFeatures(user_settings, waveform_all, motion):
    """ Compute the corrected waveforms based on the motion of the probe.

    Each unit's waveform is re-estimated at its motion-corrected location on the
    reference probe with ``waveformEstimation`` (described by the authors as
    Kriging interpolation), and the results are saved to the output folder.

    When ``n_templates == 2``, two templates are computed per unit. The motion
    model's constant term is shifted by the minimum motion (template 0) and by
    the maximum motion (template 1). Both extremes are taken over the probe's
    top and bottom channels. For any other ``n_templates``, every template uses
    the unshifted motion.

    Arguments:
        - user_settings (dict): User settings. Keys read here: ``path_to_data``,
          ``output_folder``, ``waveformCorrection.n_templates`` and ``n_jobs``.
        - waveform_all (numpy.ndarray): The waveforms of all units (n_unit, n_channel, n_sample)
        - motion (Motion): The motion object containing the linear and constant parameters
          for correction (``Linear``, ``LinearScale``, ``Constant``, ``get_motion``).
          It is deep-copied before any change and is not modified.
    Outputs:
        - waveforms_corrected.npy: The corrected waveforms, shape (n_unit, n_channel, n_sample, n_templates).

    """
    data_folder = user_settings["path_to_data"]
    output_folder = user_settings["output_folder"]
    n_templates = user_settings["waveformCorrection"]["n_templates"]

    channel_locations = np.load(os.path.join(data_folder, 'channel_locations.npy'))
    sessions = np.load(os.path.join(data_folder, 'session_index.npy'))

    locations = np.load(os.path.join(output_folder, 'locations.npy'))

    n_unit, n_channel, n_sample = waveform_all.shape[:3]

    # Motion evaluated at the shallowest and deepest channel depths.
    min_channel_depth = np.min(channel_locations[:, 1])
    max_channel_depth = np.max(channel_locations[:, 1])

    motion_bottom = motion.LinearScale * motion.Linear * min_channel_depth + motion.Constant
    motion_top = motion.LinearScale * motion.Linear * max_channel_depth + motion.Constant

    min_motion = np.min(np.concatenate((motion_bottom, motion_top)))
    max_motion = np.max(np.concatenate((motion_bottom, motion_top)))
    print('The range of motion: [%.1f μm ~ %.1f μm]\n' % (min_motion, max_motion))

    def process_spike(locations_this, motion, channel_locations, waveform_this, session_this, n_templates, min_motion, max_motion):
        """Return the corrected waveforms (n_channel, n_sample, n_templates) of one unit."""
        waveforms_corrected = np.zeros((n_channel, n_sample, n_templates))

        for k in range(n_templates):
            # Fresh copy per template: the shift must be applied relative to the
            # original Constant, not accumulated over previous templates.
            motion_this = copy.deepcopy(motion)
            if n_templates == 2:
                shift = min_motion if k == 0 else max_motion
                motion_this.Constant = motion_this.Constant - shift

            dy = motion_this.get_motion(session_this, locations_this[1])
            location_new = locations_this.copy()
            location_new[1] -= dy

            waveforms_corrected[:, :, k] = waveformEstimation(
                waveform_this, locations_this, channel_locations, location_new)

        return waveforms_corrected

    # Run parallel processing with progress bar
    out = Parallel(n_jobs=user_settings["n_jobs"])(
        delayed(process_spike)(locations[k, :2], motion, channel_locations, waveform_all[k, :, :],
                               sessions[k], n_templates, min_motion, max_motion)
        for k in tqdm(range(n_unit), desc='Computing waveform features')
    )

    waveforms_corrected = np.zeros((n_unit, n_channel, n_sample, n_templates))
    for k, waveform_k in enumerate(out):
        waveforms_corrected[k] = waveform_k

    # Save the corrected waveforms
    np.save(os.path.join(output_folder, 'waveforms_corrected.npy'), waveforms_corrected)
|