bmtool 0.7.0.6.4__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bmtool/SLURM.py +162 -109
- bmtool/__init__.py +1 -1
- bmtool/__main__.py +8 -7
- bmtool/analysis/entrainment.py +250 -143
- bmtool/analysis/lfp.py +279 -134
- bmtool/analysis/netcon_reports.py +41 -44
- bmtool/analysis/spikes.py +114 -73
- bmtool/bmplot/connections.py +658 -325
- bmtool/bmplot/entrainment.py +17 -18
- bmtool/bmplot/lfp.py +24 -17
- bmtool/bmplot/netcon_reports.py +0 -4
- bmtool/bmplot/spikes.py +97 -48
- bmtool/connectors.py +394 -251
- bmtool/debug/commands.py +13 -7
- bmtool/debug/debug.py +2 -2
- bmtool/graphs.py +26 -19
- bmtool/manage.py +6 -11
- bmtool/plot_commands.py +350 -151
- bmtool/singlecell.py +357 -195
- bmtool/synapses.py +564 -470
- bmtool/util/commands.py +1079 -627
- bmtool/util/neuron/celltuner.py +989 -609
- bmtool/util/util.py +992 -588
- {bmtool-0.7.0.6.4.dist-info → bmtool-0.7.1.dist-info}/METADATA +40 -2
- bmtool-0.7.1.dist-info/RECORD +34 -0
- {bmtool-0.7.0.6.4.dist-info → bmtool-0.7.1.dist-info}/WHEEL +1 -1
- bmtool-0.7.0.6.4.dist-info/RECORD +0 -34
- {bmtool-0.7.0.6.4.dist-info → bmtool-0.7.1.dist-info}/entry_points.txt +0 -0
- {bmtool-0.7.0.6.4.dist-info → bmtool-0.7.1.dist-info}/licenses/LICENSE +0 -0
- {bmtool-0.7.0.6.4.dist-info → bmtool-0.7.1.dist-info}/top_level.txt +0 -0
bmtool/bmplot/connections.py
CHANGED
@@ -2,40 +2,37 @@
|
|
2
2
|
Want to be able to take multiple plot names in and plot them all at the same time, to save time
|
3
3
|
https://stackoverflow.com/questions/458209/is-there-a-way-to-detach-matplotlib-plots-so-that-the-computation-can-continue
|
4
4
|
"""
|
5
|
-
|
6
|
-
import
|
5
|
+
import re
|
6
|
+
import statistics
|
7
|
+
|
7
8
|
import matplotlib
|
8
|
-
import matplotlib.pyplot as plt
|
9
9
|
import matplotlib.cm as cmx
|
10
10
|
import matplotlib.colors as colors
|
11
|
-
|
12
|
-
|
13
|
-
import statistics
|
11
|
+
import matplotlib.pyplot as plt
|
12
|
+
import numpy as np
|
14
13
|
import pandas as pd
|
15
|
-
import
|
16
|
-
import sys
|
17
|
-
import re
|
18
|
-
|
19
|
-
from ..util.util import CellVarsFile,load_nodes_from_config,load_templates_from_config #, missing_units
|
20
|
-
from bmtk.analyzer.utils import listify
|
14
|
+
from IPython import get_ipython
|
21
15
|
from neuron import h
|
22
16
|
|
17
|
+
from ..util import util
|
18
|
+
|
23
19
|
use_description = """
|
24
20
|
|
25
21
|
Plot BMTK models easily.
|
26
22
|
|
27
|
-
python -m bmtool.plot
|
23
|
+
python -m bmtool.plot
|
28
24
|
"""
|
29
25
|
|
26
|
+
|
30
27
|
def is_notebook() -> bool:
|
31
28
|
"""
|
32
29
|
Detect if code is running in a Jupyter notebook environment.
|
33
|
-
|
30
|
+
|
34
31
|
Returns:
|
35
32
|
--------
|
36
33
|
bool
|
37
34
|
True if running in a Jupyter notebook, False otherwise.
|
38
|
-
|
35
|
+
|
39
36
|
Notes:
|
40
37
|
------
|
41
38
|
This is used to determine whether to call plt.show() explicitly or
|
@@ -43,20 +40,31 @@ def is_notebook() -> bool:
|
|
43
40
|
"""
|
44
41
|
try:
|
45
42
|
shell = get_ipython().__class__.__name__
|
46
|
-
if shell ==
|
47
|
-
return True
|
48
|
-
elif shell ==
|
43
|
+
if shell == "ZMQInteractiveShell":
|
44
|
+
return True # Jupyter notebook or qtconsole
|
45
|
+
elif shell == "TerminalInteractiveShell":
|
49
46
|
return False # Terminal running IPython
|
50
47
|
else:
|
51
48
|
return False # Other type (?)
|
52
49
|
except NameError:
|
53
|
-
return False
|
54
|
-
|
55
|
-
|
56
|
-
def total_connection_matrix(
|
50
|
+
return False # Probably standard Python interpreter
|
51
|
+
|
52
|
+
|
53
|
+
def total_connection_matrix(
|
54
|
+
config=None,
|
55
|
+
title=None,
|
56
|
+
sources=None,
|
57
|
+
targets=None,
|
58
|
+
sids=None,
|
59
|
+
tids=None,
|
60
|
+
no_prepend_pop=False,
|
61
|
+
save_file=None,
|
62
|
+
synaptic_info="0",
|
63
|
+
include_gap=True,
|
64
|
+
):
|
57
65
|
"""
|
58
66
|
Generate a plot displaying total connections or other synaptic statistics.
|
59
|
-
|
67
|
+
|
60
68
|
Parameters:
|
61
69
|
-----------
|
62
70
|
config : str
|
@@ -84,7 +92,7 @@ def total_connection_matrix(config=None, title=None, sources=None, targets=None,
|
|
84
92
|
include_gap : bool, optional
|
85
93
|
If True, include gap junctions and chemical synapses in the analysis.
|
86
94
|
If False, only include chemical synapses.
|
87
|
-
|
95
|
+
|
88
96
|
Returns:
|
89
97
|
--------
|
90
98
|
None
|
@@ -104,27 +112,53 @@ def total_connection_matrix(config=None, title=None, sources=None, targets=None,
|
|
104
112
|
tids = tids.split(",")
|
105
113
|
else:
|
106
114
|
tids = []
|
107
|
-
text,num, source_labels, target_labels = util.connection_totals(
|
115
|
+
text, num, source_labels, target_labels = util.connection_totals(
|
116
|
+
config=config,
|
117
|
+
nodes=None,
|
118
|
+
edges=None,
|
119
|
+
sources=sources,
|
120
|
+
targets=targets,
|
121
|
+
sids=sids,
|
122
|
+
tids=tids,
|
123
|
+
prepend_pop=not no_prepend_pop,
|
124
|
+
synaptic_info=synaptic_info,
|
125
|
+
include_gap=include_gap,
|
126
|
+
)
|
108
127
|
|
109
|
-
if title
|
128
|
+
if title is None or title == "":
|
110
129
|
title = "Total Connections"
|
111
|
-
if synaptic_info==
|
112
|
-
title = "Mean and Stdev # of Conn on Target"
|
113
|
-
if synaptic_info==
|
130
|
+
if synaptic_info == "1":
|
131
|
+
title = "Mean and Stdev # of Conn on Target"
|
132
|
+
if synaptic_info == "2":
|
114
133
|
title = "All Synapse .mod Files Used"
|
115
|
-
if synaptic_info==
|
134
|
+
if synaptic_info == "3":
|
116
135
|
title = "All Synapse .json Files Used"
|
117
|
-
plot_connection_info(
|
136
|
+
plot_connection_info(
|
137
|
+
text, num, source_labels, target_labels, title, syn_info=synaptic_info, save_file=save_file
|
138
|
+
)
|
118
139
|
return
|
119
140
|
|
120
141
|
|
121
|
-
def percent_connection_matrix(
|
142
|
+
def percent_connection_matrix(
|
143
|
+
config=None,
|
144
|
+
nodes=None,
|
145
|
+
edges=None,
|
146
|
+
title=None,
|
147
|
+
sources=None,
|
148
|
+
targets=None,
|
149
|
+
sids=None,
|
150
|
+
tids=None,
|
151
|
+
no_prepend_pop=False,
|
152
|
+
save_file=None,
|
153
|
+
method="total",
|
154
|
+
include_gap=True,
|
155
|
+
):
|
122
156
|
"""
|
123
157
|
Generates a plot showing the percent connectivity of a network
|
124
|
-
config: A BMTK simulation config
|
158
|
+
config: A BMTK simulation config
|
125
159
|
sources: network name(s) to plot
|
126
160
|
targets: network name(s) to plot
|
127
|
-
sids: source node identifier
|
161
|
+
sids: source node identifier
|
128
162
|
tids: target node identifier
|
129
163
|
no_prepend_pop: dictates if population name is displayed before sid or tid when displaying graph
|
130
164
|
method: what percent to displace on the graph 'total','uni',or 'bi' for total connections, unidirectional connections or bidirectional connections
|
@@ -135,7 +169,7 @@ def percent_connection_matrix(config=None,nodes=None,edges=None,title=None,sourc
|
|
135
169
|
raise Exception("config not defined")
|
136
170
|
if not sources or not targets:
|
137
171
|
raise Exception("Sources or targets not defined")
|
138
|
-
|
172
|
+
|
139
173
|
sources = sources.split(",")
|
140
174
|
targets = targets.split(",")
|
141
175
|
if sids:
|
@@ -146,17 +180,44 @@ def percent_connection_matrix(config=None,nodes=None,edges=None,title=None,sourc
|
|
146
180
|
tids = tids.split(",")
|
147
181
|
else:
|
148
182
|
tids = []
|
149
|
-
text,num, source_labels, target_labels = util.percent_connections(
|
150
|
-
|
183
|
+
text, num, source_labels, target_labels = util.percent_connections(
|
184
|
+
config=config,
|
185
|
+
nodes=None,
|
186
|
+
edges=None,
|
187
|
+
sources=sources,
|
188
|
+
targets=targets,
|
189
|
+
sids=sids,
|
190
|
+
tids=tids,
|
191
|
+
prepend_pop=not no_prepend_pop,
|
192
|
+
method=method,
|
193
|
+
include_gap=include_gap,
|
194
|
+
)
|
195
|
+
if title is None or title == "":
|
151
196
|
title = "Percent Connectivity"
|
152
197
|
|
153
|
-
|
154
|
-
plot_connection_info(text,num,source_labels,target_labels,title, save_file=save_file)
|
198
|
+
plot_connection_info(text, num, source_labels, target_labels, title, save_file=save_file)
|
155
199
|
return
|
156
200
|
|
157
201
|
|
158
|
-
def probability_connection_matrix(
|
159
|
-
|
202
|
+
def probability_connection_matrix(
|
203
|
+
config=None,
|
204
|
+
nodes=None,
|
205
|
+
edges=None,
|
206
|
+
title=None,
|
207
|
+
sources=None,
|
208
|
+
targets=None,
|
209
|
+
sids=None,
|
210
|
+
tids=None,
|
211
|
+
no_prepend_pop=False,
|
212
|
+
save_file=None,
|
213
|
+
dist_X=True,
|
214
|
+
dist_Y=True,
|
215
|
+
dist_Z=True,
|
216
|
+
bins=8,
|
217
|
+
line_plot=False,
|
218
|
+
verbose=False,
|
219
|
+
include_gap=True,
|
220
|
+
):
|
160
221
|
"""
|
161
222
|
Generates probability graphs
|
162
223
|
need to look into this more to see what it does
|
@@ -179,41 +240,53 @@ def probability_connection_matrix(config=None,nodes=None,edges=None,title=None,s
|
|
179
240
|
else:
|
180
241
|
tids = []
|
181
242
|
|
182
|
-
throwaway, data, source_labels, target_labels = util.connection_probabilities(
|
183
|
-
|
184
|
-
|
243
|
+
throwaway, data, source_labels, target_labels = util.connection_probabilities(
|
244
|
+
config=config,
|
245
|
+
nodes=None,
|
246
|
+
edges=None,
|
247
|
+
sources=sources,
|
248
|
+
targets=targets,
|
249
|
+
sids=sids,
|
250
|
+
tids=tids,
|
251
|
+
prepend_pop=not no_prepend_pop,
|
252
|
+
dist_X=dist_X,
|
253
|
+
dist_Y=dist_Y,
|
254
|
+
dist_Z=dist_Z,
|
255
|
+
num_bins=bins,
|
256
|
+
include_gap=include_gap,
|
257
|
+
)
|
185
258
|
if not data.any():
|
186
259
|
return
|
187
|
-
if data[0][0]
|
260
|
+
if data[0][0] == -1:
|
188
261
|
return
|
189
|
-
#plot_connection_info(data,source_labels,target_labels,title, save_file=save_file)
|
262
|
+
# plot_connection_info(data,source_labels,target_labels,title, save_file=save_file)
|
190
263
|
|
191
|
-
#plt.clf()# clears previous plots
|
192
|
-
np.seterr(divide=
|
264
|
+
# plt.clf()# clears previous plots
|
265
|
+
np.seterr(divide="ignore", invalid="ignore")
|
193
266
|
num_src, num_tar = data.shape
|
194
|
-
fig, axes = plt.subplots(nrows=num_src, ncols=num_tar, figsize=(12,12))
|
267
|
+
fig, axes = plt.subplots(nrows=num_src, ncols=num_tar, figsize=(12, 12))
|
195
268
|
fig.subplots_adjust(hspace=0.5, wspace=0.5)
|
196
269
|
|
197
270
|
for x in range(num_src):
|
198
271
|
for y in range(num_tar):
|
199
272
|
ns = data[x][y]["ns"]
|
200
273
|
bins = data[x][y]["bins"]
|
201
|
-
|
274
|
+
|
202
275
|
XX = bins[:-1]
|
203
|
-
YY = ns[0]/ns[1]
|
276
|
+
YY = ns[0] / ns[1]
|
204
277
|
|
205
278
|
if line_plot:
|
206
|
-
axes[x,y].plot(XX,YY)
|
279
|
+
axes[x, y].plot(XX, YY)
|
207
280
|
else:
|
208
|
-
axes[x,y].bar(XX,YY)
|
281
|
+
axes[x, y].bar(XX, YY)
|
209
282
|
|
210
|
-
if x == num_src-1:
|
211
|
-
axes[x,y].set_xlabel(target_labels[y])
|
283
|
+
if x == num_src - 1:
|
284
|
+
axes[x, y].set_xlabel(target_labels[y])
|
212
285
|
if y == 0:
|
213
|
-
axes[x,y].set_ylabel(source_labels[x])
|
286
|
+
axes[x, y].set_ylabel(source_labels[x])
|
214
287
|
|
215
288
|
if verbose:
|
216
|
-
print("Source: [" + source_labels[x] + "] | Target: ["+ target_labels[y] +"]")
|
289
|
+
print("Source: [" + source_labels[x] + "] | Target: [" + target_labels[y] + "]")
|
217
290
|
print("X:")
|
218
291
|
print(XX)
|
219
292
|
print("Y:")
|
@@ -223,41 +296,80 @@ def probability_connection_matrix(config=None,nodes=None,edges=None,title=None,s
|
|
223
296
|
if title:
|
224
297
|
tt = title
|
225
298
|
st = fig.suptitle(tt, fontsize=14)
|
226
|
-
fig.text(0.5, 0.04,
|
227
|
-
fig.text(0.04, 0.5,
|
299
|
+
fig.text(0.5, 0.04, "Target", ha="center")
|
300
|
+
fig.text(0.04, 0.5, "Source", va="center", rotation="vertical")
|
228
301
|
notebook = is_notebook
|
229
|
-
if notebook
|
302
|
+
if not notebook:
|
230
303
|
fig.show()
|
231
304
|
|
232
305
|
return
|
233
306
|
|
234
307
|
|
235
|
-
def convergence_connection_matrix(
|
308
|
+
def convergence_connection_matrix(
|
309
|
+
config=None,
|
310
|
+
title=None,
|
311
|
+
sources=None,
|
312
|
+
targets=None,
|
313
|
+
sids=None,
|
314
|
+
tids=None,
|
315
|
+
no_prepend_pop=False,
|
316
|
+
save_file=None,
|
317
|
+
convergence=True,
|
318
|
+
method="mean+std",
|
319
|
+
include_gap=True,
|
320
|
+
return_dict=None,
|
321
|
+
):
|
236
322
|
"""
|
237
323
|
Generates connection plot displaying convergence data
|
238
|
-
config: A BMTK simulation config
|
324
|
+
config: A BMTK simulation config
|
239
325
|
sources: network name(s) to plot
|
240
326
|
targets: network name(s) to plot
|
241
|
-
sids: source node identifier
|
327
|
+
sids: source node identifier
|
242
328
|
tids: target node identifier
|
243
329
|
no_prepend_pop: dictates if population name is displayed before sid or tid when displaying graph
|
244
330
|
save_file: If plot should be saved
|
245
|
-
method: 'mean','min','max','stdev' or 'mean+std' connvergence plot
|
331
|
+
method: 'mean','min','max','stdev' or 'mean+std' connvergence plot
|
246
332
|
"""
|
247
333
|
if not config:
|
248
334
|
raise Exception("config not defined")
|
249
335
|
if not sources or not targets:
|
250
336
|
raise Exception("Sources or targets not defined")
|
251
|
-
return divergence_connection_matrix(
|
337
|
+
return divergence_connection_matrix(
|
338
|
+
config,
|
339
|
+
title,
|
340
|
+
sources,
|
341
|
+
targets,
|
342
|
+
sids,
|
343
|
+
tids,
|
344
|
+
no_prepend_pop,
|
345
|
+
save_file,
|
346
|
+
convergence,
|
347
|
+
method,
|
348
|
+
include_gap=include_gap,
|
349
|
+
return_dict=return_dict,
|
350
|
+
)
|
252
351
|
|
253
352
|
|
254
|
-
def divergence_connection_matrix(
|
353
|
+
def divergence_connection_matrix(
|
354
|
+
config=None,
|
355
|
+
title=None,
|
356
|
+
sources=None,
|
357
|
+
targets=None,
|
358
|
+
sids=None,
|
359
|
+
tids=None,
|
360
|
+
no_prepend_pop=False,
|
361
|
+
save_file=None,
|
362
|
+
convergence=False,
|
363
|
+
method="mean+std",
|
364
|
+
include_gap=True,
|
365
|
+
return_dict=None,
|
366
|
+
):
|
255
367
|
"""
|
256
368
|
Generates connection plot displaying divergence data
|
257
|
-
config: A BMTK simulation config
|
369
|
+
config: A BMTK simulation config
|
258
370
|
sources: network name(s) to plot
|
259
371
|
targets: network name(s) to plot
|
260
|
-
sids: source node identifier
|
372
|
+
sids: source node identifier
|
261
373
|
tids: target node identifier
|
262
374
|
no_prepend_pop: dictates if population name is displayed before sid or tid when displaying graph
|
263
375
|
save_file: If plot should be saved
|
@@ -278,43 +390,73 @@ def divergence_connection_matrix(config=None,title=None,sources=None, targets=No
|
|
278
390
|
else:
|
279
391
|
tids = []
|
280
392
|
|
281
|
-
syn_info, data, source_labels, target_labels = util.connection_divergence(
|
282
|
-
|
283
|
-
|
284
|
-
|
393
|
+
syn_info, data, source_labels, target_labels = util.connection_divergence(
|
394
|
+
config=config,
|
395
|
+
nodes=None,
|
396
|
+
edges=None,
|
397
|
+
sources=sources,
|
398
|
+
targets=targets,
|
399
|
+
sids=sids,
|
400
|
+
tids=tids,
|
401
|
+
prepend_pop=not no_prepend_pop,
|
402
|
+
convergence=convergence,
|
403
|
+
method=method,
|
404
|
+
include_gap=include_gap,
|
405
|
+
)
|
285
406
|
|
286
|
-
|
407
|
+
# data, labels = util.connection_divergence_average(config=config,nodes=nodes,edges=edges,populations=populations)
|
287
408
|
|
288
|
-
|
409
|
+
if title is None or title == "":
|
410
|
+
if method == "min":
|
289
411
|
title = "Minimum "
|
290
|
-
elif method ==
|
412
|
+
elif method == "max":
|
291
413
|
title = "Maximum "
|
292
|
-
elif method ==
|
414
|
+
elif method == "std":
|
293
415
|
title = "Standard Deviation "
|
294
|
-
elif method ==
|
416
|
+
elif method == "mean":
|
295
417
|
title = "Mean "
|
296
|
-
else:
|
297
|
-
title =
|
418
|
+
else:
|
419
|
+
title = "Mean + Std "
|
298
420
|
|
299
421
|
if convergence:
|
300
422
|
title = title + "Synaptic Convergence"
|
301
423
|
else:
|
302
424
|
title = title + "Synaptic Divergence"
|
303
425
|
if return_dict:
|
304
|
-
dict = plot_connection_info(
|
426
|
+
dict = plot_connection_info(
|
427
|
+
syn_info,
|
428
|
+
data,
|
429
|
+
source_labels,
|
430
|
+
target_labels,
|
431
|
+
title,
|
432
|
+
save_file=save_file,
|
433
|
+
return_dict=return_dict,
|
434
|
+
)
|
305
435
|
return dict
|
306
436
|
else:
|
307
|
-
plot_connection_info(
|
437
|
+
plot_connection_info(
|
438
|
+
syn_info, data, source_labels, target_labels, title, save_file=save_file
|
439
|
+
)
|
308
440
|
return
|
309
441
|
|
310
442
|
|
311
|
-
def gap_junction_matrix(
|
443
|
+
def gap_junction_matrix(
|
444
|
+
config=None,
|
445
|
+
title=None,
|
446
|
+
sources=None,
|
447
|
+
targets=None,
|
448
|
+
sids=None,
|
449
|
+
tids=None,
|
450
|
+
no_prepend_pop=False,
|
451
|
+
save_file=None,
|
452
|
+
method="convergence",
|
453
|
+
):
|
312
454
|
"""
|
313
455
|
Generates connection plot displaying gap junction data.
|
314
|
-
config: A BMTK simulation config
|
456
|
+
config: A BMTK simulation config
|
315
457
|
sources: network name(s) to plot
|
316
458
|
targets: network name(s) to plot
|
317
|
-
sids: source node identifier
|
459
|
+
sids: source node identifier
|
318
460
|
tids: target node identifier
|
319
461
|
no_prepend_pop: dictates if population name is displayed before sid or tid when displaying graph
|
320
462
|
save_file: If plot should be saved
|
@@ -324,7 +466,7 @@ def gap_junction_matrix(config=None,title=None,sources=None, targets=None, sids=
|
|
324
466
|
raise Exception("config not defined")
|
325
467
|
if not sources or not targets:
|
326
468
|
raise Exception("Sources or targets not defined")
|
327
|
-
if method !=
|
469
|
+
if method != "convergence" and method != "percent":
|
328
470
|
raise Exception("type must be 'convergence' or 'percent'")
|
329
471
|
sources = sources.split(",")
|
330
472
|
targets = targets.split(",")
|
@@ -336,16 +478,25 @@ def gap_junction_matrix(config=None,title=None,sources=None, targets=None, sids=
|
|
336
478
|
tids = tids.split(",")
|
337
479
|
else:
|
338
480
|
tids = []
|
339
|
-
syn_info, data, source_labels, target_labels = util.gap_junction_connections(
|
340
|
-
|
341
|
-
|
481
|
+
syn_info, data, source_labels, target_labels = util.gap_junction_connections(
|
482
|
+
config=config,
|
483
|
+
nodes=None,
|
484
|
+
edges=None,
|
485
|
+
sources=sources,
|
486
|
+
targets=targets,
|
487
|
+
sids=sids,
|
488
|
+
tids=tids,
|
489
|
+
prepend_pop=not no_prepend_pop,
|
490
|
+
method=method,
|
491
|
+
)
|
492
|
+
|
342
493
|
def filter_rows(syn_info, data, source_labels, target_labels):
|
343
494
|
"""
|
344
495
|
Filters out rows in a connectivity matrix that contain only NaN or zero values.
|
345
|
-
|
496
|
+
|
346
497
|
This function is used to clean up connection matrices by removing rows that have no meaningful data,
|
347
498
|
which helps create more informative visualizations of network connectivity.
|
348
|
-
|
499
|
+
|
349
500
|
Parameters:
|
350
501
|
-----------
|
351
502
|
syn_info : numpy.ndarray
|
@@ -356,7 +507,7 @@ def gap_junction_matrix(config=None,title=None,sources=None, targets=None, sids=
|
|
356
507
|
List of labels for the source populations corresponding to rows in the data matrix.
|
357
508
|
target_labels : list
|
358
509
|
List of labels for the target populations corresponding to columns in the data matrix.
|
359
|
-
|
510
|
+
|
360
511
|
Returns:
|
361
512
|
--------
|
362
513
|
tuple
|
@@ -375,11 +526,11 @@ def gap_junction_matrix(config=None,title=None,sources=None, targets=None, sids=
|
|
375
526
|
def filter_rows_and_columns(syn_info, data, source_labels, target_labels):
|
376
527
|
"""
|
377
528
|
Filters out both rows and columns in a connectivity matrix that contain only NaN or zero values.
|
378
|
-
|
529
|
+
|
379
530
|
This function performs a two-step filtering process: first removing rows with no data,
|
380
531
|
then transposing the matrix and removing columns with no data (by treating them as rows).
|
381
532
|
This creates a cleaner, more informative connectivity matrix visualization.
|
382
|
-
|
533
|
+
|
383
534
|
Parameters:
|
384
535
|
-----------
|
385
536
|
syn_info : numpy.ndarray
|
@@ -390,7 +541,7 @@ def gap_junction_matrix(config=None,title=None,sources=None, targets=None, sids=
|
|
390
541
|
List of labels for the source populations corresponding to rows in the data matrix.
|
391
542
|
target_labels : list
|
392
543
|
List of labels for the target populations corresponding to columns in the data matrix.
|
393
|
-
|
544
|
+
|
394
545
|
Returns:
|
395
546
|
--------
|
396
547
|
tuple
|
@@ -398,7 +549,9 @@ def gap_junction_matrix(config=None,title=None,sources=None, targets=None, sids=
|
|
398
549
|
invalid rows and columns removed.
|
399
550
|
"""
|
400
551
|
# Filter rows first
|
401
|
-
syn_info, data, source_labels, target_labels = filter_rows(
|
552
|
+
syn_info, data, source_labels, target_labels = filter_rows(
|
553
|
+
syn_info, data, source_labels, target_labels
|
554
|
+
)
|
402
555
|
|
403
556
|
# Transpose data to filter columns
|
404
557
|
transposed_syn_info = np.transpose(syn_info)
|
@@ -407,7 +560,12 @@ def gap_junction_matrix(config=None,title=None,sources=None, targets=None, sids=
|
|
407
560
|
transposed_target_labels = source_labels
|
408
561
|
|
409
562
|
# Filter columns (by treating them as rows in transposed data)
|
410
|
-
|
563
|
+
(
|
564
|
+
transposed_syn_info,
|
565
|
+
transposed_data,
|
566
|
+
transposed_source_labels,
|
567
|
+
transposed_target_labels,
|
568
|
+
) = filter_rows(
|
411
569
|
transposed_syn_info, transposed_data, transposed_source_labels, transposed_target_labels
|
412
570
|
)
|
413
571
|
|
@@ -419,40 +577,54 @@ def gap_junction_matrix(config=None,title=None,sources=None, targets=None, sids=
|
|
419
577
|
|
420
578
|
return filtered_syn_info, filtered_data, filtered_source_labels, filtered_target_labels
|
421
579
|
|
422
|
-
|
423
|
-
|
580
|
+
syn_info, data, source_labels, target_labels = filter_rows_and_columns(
|
581
|
+
syn_info, data, source_labels, target_labels
|
582
|
+
)
|
424
583
|
|
425
|
-
if title
|
426
|
-
title =
|
427
|
-
if method ==
|
428
|
-
title+=
|
429
|
-
elif method ==
|
430
|
-
title+=
|
431
|
-
plot_connection_info(syn_info,data,source_labels,target_labels,title, save_file=save_file)
|
584
|
+
if title is None or title == "":
|
585
|
+
title = "Gap Junction"
|
586
|
+
if method == "convergence":
|
587
|
+
title += " Syn Convergence"
|
588
|
+
elif method == "percent":
|
589
|
+
title += " Percent Connectivity"
|
590
|
+
plot_connection_info(syn_info, data, source_labels, target_labels, title, save_file=save_file)
|
432
591
|
return
|
433
592
|
|
434
593
|
|
435
|
-
def connection_histogram(
|
436
|
-
|
594
|
+
def connection_histogram(
|
595
|
+
config=None,
|
596
|
+
nodes=None,
|
597
|
+
edges=None,
|
598
|
+
sources=[],
|
599
|
+
targets=[],
|
600
|
+
sids=[],
|
601
|
+
tids=[],
|
602
|
+
no_prepend_pop=True,
|
603
|
+
synaptic_info="0",
|
604
|
+
source_cell=None,
|
605
|
+
target_cell=None,
|
606
|
+
include_gap=True,
|
607
|
+
):
|
437
608
|
"""
|
438
609
|
Generates histogram of number of connections individual cells in a population receieve from another population
|
439
|
-
config: A BMTK simulation config
|
610
|
+
config: A BMTK simulation config
|
440
611
|
sources: network name(s) to plot
|
441
612
|
targets: network name(s) to plot
|
442
|
-
sids: source node identifier
|
613
|
+
sids: source node identifier
|
443
614
|
tids: target node identifier
|
444
615
|
no_prepend_pop: dictates if population name is displayed before sid or tid when displaying graph
|
445
616
|
source_cell: where connections are coming from
|
446
617
|
target_cell: where connections on coming onto
|
447
618
|
save_file: If plot should be saved
|
448
619
|
"""
|
620
|
+
|
449
621
|
def connection_pair_histogram(**kwargs):
|
450
622
|
"""
|
451
623
|
Creates a histogram showing the distribution of connection counts between a specific source and target cell type.
|
452
|
-
|
624
|
+
|
453
625
|
This function is designed to be used with the relation_matrix utility and will only create histograms
|
454
626
|
for the specified source and target cell types, ignoring all other combinations.
|
455
|
-
|
627
|
+
|
456
628
|
Parameters:
|
457
629
|
-----------
|
458
630
|
kwargs : dict
|
@@ -462,7 +634,7 @@ def connection_histogram(config=None,nodes=None,edges=None,sources=[],targets=[]
|
|
462
634
|
- tid: Column name for target ID type in the edges DataFrame
|
463
635
|
- source_id: Value to filter edges by source ID type
|
464
636
|
- target_id: Value to filter edges by target ID type
|
465
|
-
|
637
|
+
|
466
638
|
Global parameters used:
|
467
639
|
---------------------
|
468
640
|
source_cell : str
|
@@ -471,37 +643,41 @@ def connection_histogram(config=None,nodes=None,edges=None,sources=[],targets=[]
|
|
471
643
|
The target cell type to plot.
|
472
644
|
include_gap : bool
|
473
645
|
Whether to include gap junctions in the analysis. If False, gap junctions are excluded.
|
474
|
-
|
646
|
+
|
475
647
|
Returns:
|
476
648
|
--------
|
477
649
|
None
|
478
650
|
Displays a histogram showing the distribution of connection counts.
|
479
651
|
"""
|
480
|
-
edges = kwargs["edges"]
|
652
|
+
edges = kwargs["edges"]
|
481
653
|
source_id_type = kwargs["sid"]
|
482
654
|
target_id_type = kwargs["tid"]
|
483
655
|
source_id = kwargs["source_id"]
|
484
656
|
target_id = kwargs["target_id"]
|
485
657
|
if source_id == source_cell and target_id == target_cell:
|
486
|
-
temp = edges[
|
487
|
-
|
488
|
-
|
489
|
-
|
658
|
+
temp = edges[
|
659
|
+
(edges[source_id_type] == source_id) & (edges[target_id_type] == target_id)
|
660
|
+
]
|
661
|
+
if not include_gap:
|
662
|
+
temp = temp[~temp["is_gap_junction"]]
|
663
|
+
node_pairs = temp.groupby("target_node_id")["source_node_id"].count()
|
490
664
|
try:
|
491
665
|
conn_mean = statistics.mean(node_pairs.values)
|
492
666
|
conn_std = statistics.stdev(node_pairs.values)
|
493
667
|
conn_median = statistics.median(node_pairs.values)
|
494
|
-
label = "mean {:.2f} std {:.2f} median {:.2f}".format(
|
495
|
-
|
668
|
+
label = "mean {:.2f} std {:.2f} median {:.2f}".format(
|
669
|
+
conn_mean, conn_std, conn_median
|
670
|
+
)
|
671
|
+
except: # lazy fix for std not calculated with 1 node
|
496
672
|
conn_mean = statistics.mean(node_pairs.values)
|
497
673
|
conn_median = statistics.median(node_pairs.values)
|
498
|
-
label = "mean {:.2f} median {:.2f}".format(conn_mean,conn_median)
|
499
|
-
plt.hist(node_pairs.values,density=False,bins=
|
674
|
+
label = "mean {:.2f} median {:.2f}".format(conn_mean, conn_median)
|
675
|
+
plt.hist(node_pairs.values, density=False, bins="auto", stacked=True, label=label)
|
500
676
|
plt.legend()
|
501
|
-
plt.xlabel("# of conns from {} to {}".format(source_cell,target_cell))
|
677
|
+
plt.xlabel("# of conns from {} to {}".format(source_cell, target_cell))
|
502
678
|
plt.ylabel("# of cells")
|
503
679
|
plt.show()
|
504
|
-
else:
|
680
|
+
else: # dont care about other cell pairs so pass
|
505
681
|
pass
|
506
682
|
|
507
683
|
if not config:
|
@@ -518,18 +694,35 @@ def connection_histogram(config=None,nodes=None,edges=None,sources=[],targets=[]
|
|
518
694
|
tids = tids.split(",")
|
519
695
|
else:
|
520
696
|
tids = []
|
521
|
-
util.relation_matrix(
|
697
|
+
util.relation_matrix(
|
698
|
+
config,
|
699
|
+
nodes,
|
700
|
+
edges,
|
701
|
+
sources,
|
702
|
+
targets,
|
703
|
+
sids,
|
704
|
+
tids,
|
705
|
+
not no_prepend_pop,
|
706
|
+
relation_func=connection_pair_histogram,
|
707
|
+
synaptic_info=synaptic_info,
|
708
|
+
)
|
522
709
|
|
523
710
|
|
524
|
-
def connection_distance(
|
525
|
-
|
711
|
+
def connection_distance(
|
712
|
+
config: str,
|
713
|
+
sources: str,
|
714
|
+
targets: str,
|
715
|
+
source_cell_id: int,
|
716
|
+
target_id_type: str,
|
717
|
+
ignore_z: bool = False,
|
718
|
+
) -> None:
|
526
719
|
"""
|
527
720
|
Plots the 3D spatial distribution of target nodes relative to a source node
|
528
721
|
and a histogram of distances from the source node to each target node.
|
529
722
|
|
530
723
|
Parameters:
|
531
724
|
----------
|
532
|
-
config: (str) A BMTK simulation config
|
725
|
+
config: (str) A BMTK simulation config
|
533
726
|
sources: (str) network name(s) to plot
|
534
727
|
targets: (str) network name(s) to plot
|
535
728
|
source_cell_id : (int) ID of the source cell for calculating distances to target nodes.
|
@@ -541,22 +734,22 @@ def connection_distance(config: str,sources: str,targets: str,
|
|
541
734
|
raise Exception("config not defined")
|
542
735
|
if not sources or not targets:
|
543
736
|
raise Exception("Sources or targets not defined")
|
544
|
-
#if source != target:
|
545
|
-
|
546
|
-
|
737
|
+
# if source != target:
|
738
|
+
# raise Exception("Code is setup for source and target to be the same! Look at source code for function to add feature")
|
739
|
+
|
547
740
|
# Load nodes and edges based on config file
|
548
741
|
nodes, edges = util.load_nodes_edges_from_config(config)
|
549
|
-
|
742
|
+
|
550
743
|
edge_network = sources + "_to_" + targets
|
551
744
|
node_network = sources
|
552
745
|
|
553
746
|
# Filter edges to obtain connections originating from the source node
|
554
747
|
edge = edges[edge_network]
|
555
|
-
edge = edge[edge[
|
748
|
+
edge = edge[edge["source_node_id"] == source_cell_id]
|
556
749
|
if target_id_type:
|
557
|
-
edge = edge[edge[
|
750
|
+
edge = edge[edge["target_query"].str.contains(target_id_type, na=False)]
|
558
751
|
|
559
|
-
target_node_ids = edge[
|
752
|
+
target_node_ids = edge["target_node_id"]
|
560
753
|
|
561
754
|
# Filter nodes to obtain only the target and source nodes
|
562
755
|
node = nodes[node_network]
|
@@ -565,24 +758,40 @@ def connection_distance(config: str,sources: str,targets: str,
|
|
565
758
|
|
566
759
|
# Calculate distances between source node and each target node
|
567
760
|
if ignore_z:
|
568
|
-
target_positions = target_nodes[[
|
569
|
-
source_position = np.array(
|
761
|
+
target_positions = target_nodes[["pos_x", "pos_y"]].values
|
762
|
+
source_position = np.array(
|
763
|
+
[source_node["pos_x"], source_node["pos_y"]]
|
764
|
+
).ravel() # Ensure 1D shape
|
570
765
|
else:
|
571
|
-
target_positions = target_nodes[[
|
572
|
-
source_position = np.array(
|
766
|
+
target_positions = target_nodes[["pos_x", "pos_y", "pos_z"]].values
|
767
|
+
source_position = np.array(
|
768
|
+
[source_node["pos_x"], source_node["pos_y"], source_node["pos_z"]]
|
769
|
+
).ravel() # Ensure 1D shape
|
573
770
|
distances = np.linalg.norm(target_positions - source_position, axis=1)
|
574
771
|
|
575
772
|
# Plot positions of source and target nodes in 3D space or 2D
|
576
773
|
if ignore_z:
|
577
|
-
fig = plt.figure(figsize=(8, 6))
|
774
|
+
fig = plt.figure(figsize=(8, 6))
|
578
775
|
ax = fig.add_subplot(111)
|
579
|
-
ax.scatter(target_nodes[
|
580
|
-
ax.scatter(source_node[
|
776
|
+
ax.scatter(target_nodes["pos_x"], target_nodes["pos_y"], c="blue", label="target cells")
|
777
|
+
ax.scatter(source_node["pos_x"], source_node["pos_y"], c="red", label="source cell")
|
581
778
|
else:
|
582
|
-
fig = plt.figure(figsize=(8, 6))
|
583
|
-
ax = fig.add_subplot(111, projection=
|
584
|
-
ax.scatter(
|
585
|
-
|
779
|
+
fig = plt.figure(figsize=(8, 6))
|
780
|
+
ax = fig.add_subplot(111, projection="3d")
|
781
|
+
ax.scatter(
|
782
|
+
target_nodes["pos_x"],
|
783
|
+
target_nodes["pos_y"],
|
784
|
+
target_nodes["pos_z"],
|
785
|
+
c="blue",
|
786
|
+
label="target cells",
|
787
|
+
)
|
788
|
+
ax.scatter(
|
789
|
+
source_node["pos_x"],
|
790
|
+
source_node["pos_y"],
|
791
|
+
source_node["pos_z"],
|
792
|
+
c="red",
|
793
|
+
label="source cell",
|
794
|
+
)
|
586
795
|
|
587
796
|
# Optional: Add text annotations for distances
|
588
797
|
# for i, distance in enumerate(distances):
|
@@ -593,23 +802,36 @@ def connection_distance(config: str,sources: str,targets: str,
|
|
593
802
|
plt.show()
|
594
803
|
|
595
804
|
# Plot distances in a separate 2D plot
|
596
|
-
plt.figure(figsize=(8, 6))
|
597
|
-
plt.hist(distances, bins=20, color=
|
805
|
+
plt.figure(figsize=(8, 6))
|
806
|
+
plt.hist(distances, bins=20, color="blue", edgecolor="black")
|
598
807
|
plt.xlabel("Distance")
|
599
808
|
plt.ylabel("Count")
|
600
|
-
plt.title(
|
809
|
+
plt.title("Distance from Source Node to Each Target Node")
|
601
810
|
plt.grid(True)
|
602
811
|
plt.show()
|
603
812
|
|
604
813
|
|
605
|
-
def edge_histogram_matrix(
|
814
|
+
def edge_histogram_matrix(
|
815
|
+
config=None,
|
816
|
+
sources=None,
|
817
|
+
targets=None,
|
818
|
+
sids=None,
|
819
|
+
tids=None,
|
820
|
+
no_prepend_pop=None,
|
821
|
+
edge_property=None,
|
822
|
+
time=None,
|
823
|
+
time_compare=None,
|
824
|
+
report=None,
|
825
|
+
title=None,
|
826
|
+
save_file=None,
|
827
|
+
):
|
606
828
|
"""
|
607
829
|
Generates a matrix of histograms showing the distribution of edge properties between different populations.
|
608
|
-
|
609
|
-
This function creates a grid of histograms where each cell in the grid represents the distribution of a
|
610
|
-
specific edge property (e.g., synaptic weights, delays) between a source population (row) and
|
830
|
+
|
831
|
+
This function creates a grid of histograms where each cell in the grid represents the distribution of a
|
832
|
+
specific edge property (e.g., synaptic weights, delays) between a source population (row) and
|
611
833
|
target population (column).
|
612
|
-
|
834
|
+
|
613
835
|
Parameters:
|
614
836
|
-----------
|
615
837
|
config : str
|
@@ -636,13 +858,13 @@ def edge_histogram_matrix(config=None,sources = None,targets=None,sids=None,tids
|
|
636
858
|
Custom title for the plot.
|
637
859
|
save_file : str, optional
|
638
860
|
Path to save the generated plot.
|
639
|
-
|
861
|
+
|
640
862
|
Returns:
|
641
863
|
--------
|
642
864
|
None
|
643
865
|
Displays a matrix of histograms.
|
644
866
|
"""
|
645
|
-
|
867
|
+
|
646
868
|
if not config:
|
647
869
|
raise Exception("config not defined")
|
648
870
|
if not sources or not targets:
|
@@ -660,34 +882,48 @@ def edge_histogram_matrix(config=None,sources = None,targets=None,sids=None,tids
|
|
660
882
|
if time_compare:
|
661
883
|
time_compare = int(time_compare)
|
662
884
|
|
663
|
-
data, source_labels, target_labels = util.edge_property_matrix(
|
885
|
+
data, source_labels, target_labels = util.edge_property_matrix(
|
886
|
+
edge_property,
|
887
|
+
nodes=None,
|
888
|
+
edges=None,
|
889
|
+
config=config,
|
890
|
+
sources=sources,
|
891
|
+
targets=targets,
|
892
|
+
sids=sids,
|
893
|
+
tids=tids,
|
894
|
+
prepend_pop=not no_prepend_pop,
|
895
|
+
report=report,
|
896
|
+
time=time,
|
897
|
+
time_compare=time_compare,
|
898
|
+
)
|
664
899
|
|
665
900
|
# Fantastic resource
|
666
|
-
# https://stackoverflow.com/questions/7941207/is-there-a-function-to-make-scatterplot-matrices-in-matplotlib
|
901
|
+
# https://stackoverflow.com/questions/7941207/is-there-a-function-to-make-scatterplot-matrices-in-matplotlib
|
667
902
|
num_src, num_tar = data.shape
|
668
|
-
fig, axes = plt.subplots(nrows=num_src, ncols=num_tar, figsize=(12,12))
|
903
|
+
fig, axes = plt.subplots(nrows=num_src, ncols=num_tar, figsize=(12, 12))
|
669
904
|
fig.subplots_adjust(hspace=0.5, wspace=0.5)
|
670
905
|
|
671
906
|
for x in range(num_src):
|
672
907
|
for y in range(num_tar):
|
673
|
-
axes[x,y].hist(data[x][y])
|
908
|
+
axes[x, y].hist(data[x][y])
|
674
909
|
|
675
|
-
if x == num_src-1:
|
676
|
-
axes[x,y].set_xlabel(target_labels[y])
|
910
|
+
if x == num_src - 1:
|
911
|
+
axes[x, y].set_xlabel(target_labels[y])
|
677
912
|
if y == 0:
|
678
|
-
axes[x,y].set_ylabel(source_labels[x])
|
913
|
+
axes[x, y].set_ylabel(source_labels[x])
|
679
914
|
|
680
915
|
tt = edge_property + " Histogram Matrix"
|
681
916
|
if title:
|
682
917
|
tt = title
|
683
918
|
st = fig.suptitle(tt, fontsize=14)
|
684
|
-
fig.text(0.5, 0.04,
|
685
|
-
fig.text(0.04, 0.5,
|
919
|
+
fig.text(0.5, 0.04, "Target", ha="center")
|
920
|
+
fig.text(0.04, 0.5, "Source", va="center", rotation="vertical")
|
686
921
|
plt.draw()
|
687
922
|
|
688
923
|
|
689
|
-
def distance_delay_plot(
|
690
|
-
group_by: str,sid: str,tid: str
|
924
|
+
def distance_delay_plot(
|
925
|
+
simulation_config: str, source: str, target: str, group_by: str, sid: str, tid: str
|
926
|
+
) -> None:
|
691
927
|
"""
|
692
928
|
Plots the relationship between the distance and delay of connections between nodes in a neural network simulation.
|
693
929
|
|
@@ -708,25 +944,27 @@ def distance_delay_plot(simulation_config: str,source: str,target: str,
|
|
708
944
|
"""
|
709
945
|
nodes, edges = util.load_nodes_edges_from_config(simulation_config)
|
710
946
|
nodes = nodes[target]
|
711
|
-
#node id is index of nodes df
|
947
|
+
# node id is index of nodes df
|
712
948
|
node_id_source = nodes[nodes[group_by] == sid].index
|
713
949
|
node_id_target = nodes[nodes[group_by] == tid].index
|
714
950
|
|
715
|
-
edges = edges[f
|
716
|
-
edges = edges[
|
951
|
+
edges = edges[f"{source}_to_{target}"]
|
952
|
+
edges = edges[
|
953
|
+
edges["source_node_id"].isin(node_id_source) & edges["target_node_id"].isin(node_id_target)
|
954
|
+
]
|
717
955
|
|
718
956
|
stuff_to_plot = []
|
719
957
|
for index, row in edges.iterrows():
|
720
958
|
try:
|
721
|
-
source_node = row[
|
722
|
-
target_node = row[
|
723
|
-
|
724
|
-
source_pos = nodes.loc[[source_node], [
|
725
|
-
target_pos = nodes.loc[[target_node], [
|
726
|
-
|
959
|
+
source_node = row["source_node_id"]
|
960
|
+
target_node = row["target_node_id"]
|
961
|
+
|
962
|
+
source_pos = nodes.loc[[source_node], ["pos_x", "pos_y", "pos_z"]]
|
963
|
+
target_pos = nodes.loc[[target_node], ["pos_x", "pos_y", "pos_z"]]
|
964
|
+
|
727
965
|
distance = np.linalg.norm(source_pos.values - target_pos.values)
|
728
966
|
|
729
|
-
delay = row[
|
967
|
+
delay = row["delay"] # This line may raise KeyError
|
730
968
|
stuff_to_plot.append([distance, delay])
|
731
969
|
|
732
970
|
except KeyError as e:
|
@@ -737,9 +975,9 @@ def distance_delay_plot(simulation_config: str,source: str,target: str,
|
|
737
975
|
print(f"Unexpected error at edge index {index}: {e}")
|
738
976
|
|
739
977
|
plt.scatter([x[0] for x in stuff_to_plot], [x[1] for x in stuff_to_plot])
|
740
|
-
plt.xlabel(
|
741
|
-
plt.ylabel(
|
742
|
-
plt.title(f
|
978
|
+
plt.xlabel("Distance")
|
979
|
+
plt.ylabel("Delay")
|
980
|
+
plt.title(f"Distance vs Delay for edge between {sid} and {tid}")
|
743
981
|
plt.show()
|
744
982
|
|
745
983
|
|
@@ -748,7 +986,7 @@ def plot_synapse_location_histograms(config, target_model, source=None, target=N
|
|
748
986
|
generates a histogram of the positions of the synapses on a cell broken down by section
|
749
987
|
config: a BMTK config
|
750
988
|
target_model: the name of the model_template used when building the BMTK node
|
751
|
-
source: The source BMTK network
|
989
|
+
source: The source BMTK network
|
752
990
|
target: The target BMTK network
|
753
991
|
"""
|
754
992
|
# Load mechanisms and template
|
@@ -758,37 +996,36 @@ def plot_synapse_location_histograms(config, target_model, source=None, target=N
|
|
758
996
|
# Load node and edge data
|
759
997
|
nodes, edges = util.load_nodes_edges_from_config(config)
|
760
998
|
nodes = nodes[source]
|
761
|
-
edges = edges[f
|
999
|
+
edges = edges[f"{source}_to_{target}"]
|
762
1000
|
|
763
1001
|
# Map target_node_id to model_template
|
764
|
-
edges[
|
1002
|
+
edges["target_model_template"] = edges["target_node_id"].map(nodes["model_template"])
|
765
1003
|
|
766
1004
|
# Map source_node_id to pop_name
|
767
|
-
edges[
|
1005
|
+
edges["source_pop_name"] = edges["source_node_id"].map(nodes["pop_name"])
|
768
1006
|
|
769
|
-
edges = edges[edges[
|
1007
|
+
edges = edges[edges["target_model_template"] == target_model]
|
770
1008
|
|
771
1009
|
# Create the cell model from target model
|
772
|
-
cell = getattr(h, target_model.split(
|
773
|
-
|
1010
|
+
cell = getattr(h, target_model.split(":")[1])()
|
1011
|
+
|
774
1012
|
# Create a mapping from section index to section name
|
775
1013
|
section_id_to_name = {}
|
776
1014
|
for idx, sec in enumerate(cell.all):
|
777
1015
|
section_id_to_name[idx] = sec.name()
|
778
1016
|
|
779
1017
|
# Add a new column with section names based on afferent_section_id
|
780
|
-
edges[
|
1018
|
+
edges["afferent_section_name"] = edges["afferent_section_id"].map(section_id_to_name)
|
781
1019
|
|
782
1020
|
# Get unique sections and source populations
|
783
|
-
unique_pops = edges[
|
1021
|
+
unique_pops = edges["source_pop_name"].unique()
|
784
1022
|
|
785
1023
|
# Filter to only include sections with data
|
786
|
-
section_counts = edges[
|
1024
|
+
section_counts = edges["afferent_section_name"].value_counts()
|
787
1025
|
sections_with_data = section_counts[section_counts > 0].index.tolist()
|
788
1026
|
|
789
|
-
|
790
1027
|
# Create a figure with subplots for each section
|
791
|
-
plt.figure(figsize=(8,12))
|
1028
|
+
plt.figure(figsize=(8, 12))
|
792
1029
|
|
793
1030
|
# Color map for source populations
|
794
1031
|
color_map = plt.cm.tab10(np.linspace(0, 1, len(unique_pops)))
|
@@ -796,30 +1033,37 @@ def plot_synapse_location_histograms(config, target_model, source=None, target=N
|
|
796
1033
|
|
797
1034
|
# Create a histogram for each section
|
798
1035
|
for i, section in enumerate(sections_with_data):
|
799
|
-
ax = plt.subplot(len(sections_with_data), 1, i+1)
|
800
|
-
|
1036
|
+
ax = plt.subplot(len(sections_with_data), 1, i + 1)
|
1037
|
+
|
801
1038
|
# Get data for this section
|
802
|
-
section_data = edges[edges[
|
803
|
-
|
1039
|
+
section_data = edges[edges["afferent_section_name"] == section]
|
1040
|
+
|
804
1041
|
# Group by source population
|
805
|
-
for pop_name, pop_group in section_data.groupby(
|
1042
|
+
for pop_name, pop_group in section_data.groupby("source_pop_name"):
|
806
1043
|
if len(pop_group) > 0:
|
807
|
-
ax.hist(
|
808
|
-
|
809
|
-
|
1044
|
+
ax.hist(
|
1045
|
+
pop_group["afferent_section_pos"],
|
1046
|
+
bins=15,
|
1047
|
+
alpha=0.7,
|
1048
|
+
label=pop_name,
|
1049
|
+
color=pop_colors[pop_name],
|
1050
|
+
)
|
1051
|
+
|
810
1052
|
# Set title and labels
|
811
1053
|
ax.set_title(f"{section}", fontsize=10)
|
812
|
-
ax.set_xlabel(
|
813
|
-
ax.set_ylabel(
|
1054
|
+
ax.set_xlabel("Section Position", fontsize=8)
|
1055
|
+
ax.set_ylabel("Frequency", fontsize=8)
|
814
1056
|
ax.tick_params(labelsize=7)
|
815
1057
|
ax.grid(True, alpha=0.3)
|
816
|
-
|
1058
|
+
|
817
1059
|
# Only add legend to the first plot
|
818
1060
|
if i == 0:
|
819
1061
|
ax.legend(fontsize=8)
|
820
1062
|
|
821
1063
|
plt.tight_layout()
|
822
|
-
plt.suptitle(
|
1064
|
+
plt.suptitle(
|
1065
|
+
"Connection Distribution by Cell Section and Source Population", fontsize=16, y=1.02
|
1066
|
+
)
|
823
1067
|
if is_notebook:
|
824
1068
|
plt.show()
|
825
1069
|
else:
|
@@ -828,105 +1072,147 @@ def plot_synapse_location_histograms(config, target_model, source=None, target=N
|
|
828
1072
|
# Create a summary table
|
829
1073
|
print("Summary of connections by section and source population:")
|
830
1074
|
pivot_table = edges.pivot_table(
|
831
|
-
values=
|
832
|
-
index=
|
833
|
-
columns=
|
834
|
-
aggfunc=
|
835
|
-
fill_value=0
|
1075
|
+
values="afferent_section_id",
|
1076
|
+
index="afferent_section_name",
|
1077
|
+
columns="source_pop_name",
|
1078
|
+
aggfunc="count",
|
1079
|
+
fill_value=0,
|
836
1080
|
)
|
837
1081
|
print(pivot_table)
|
838
1082
|
|
839
1083
|
|
840
|
-
def plot_connection_info(
|
1084
|
+
def plot_connection_info(
|
1085
|
+
text, num, source_labels, target_labels, title, syn_info="0", save_file=None, return_dict=None
|
1086
|
+
):
|
841
1087
|
"""
|
842
1088
|
Function to plot connection information as a heatmap, including handling missing source and target values.
|
843
1089
|
If there is no source or target, set the value to 0.
|
844
1090
|
"""
|
845
|
-
|
1091
|
+
|
846
1092
|
# Ensure text dimensions match num dimensions
|
847
1093
|
num_source = len(source_labels)
|
848
1094
|
num_target = len(target_labels)
|
849
|
-
|
1095
|
+
|
850
1096
|
# Set color map
|
851
|
-
matplotlib.rc(
|
852
|
-
|
1097
|
+
matplotlib.rc("image", cmap="viridis")
|
1098
|
+
|
853
1099
|
# Create figure and axis for the plot
|
854
1100
|
fig1, ax1 = plt.subplots(figsize=(num_source, num_target))
|
855
|
-
num = np.nan_to_num(num, nan=0)
|
1101
|
+
num = np.nan_to_num(num, nan=0) # replace NaN with 0
|
856
1102
|
im1 = ax1.imshow(num)
|
857
|
-
|
1103
|
+
|
858
1104
|
# Set ticks and labels for source and target
|
859
1105
|
ax1.set_xticks(list(np.arange(len(target_labels))))
|
860
1106
|
ax1.set_yticks(list(np.arange(len(source_labels))))
|
861
1107
|
ax1.set_xticklabels(target_labels)
|
862
|
-
ax1.set_yticklabels(source_labels, size=12, weight=
|
863
|
-
|
1108
|
+
ax1.set_yticklabels(source_labels, size=12, weight="semibold")
|
1109
|
+
|
864
1110
|
# Rotate the tick labels for better visibility
|
865
|
-
plt.setp(
|
866
|
-
|
867
|
-
|
1111
|
+
plt.setp(
|
1112
|
+
ax1.get_xticklabels(),
|
1113
|
+
rotation=45,
|
1114
|
+
ha="right",
|
1115
|
+
rotation_mode="anchor",
|
1116
|
+
size=12,
|
1117
|
+
weight="semibold",
|
1118
|
+
)
|
1119
|
+
|
868
1120
|
# Dictionary to store connection information
|
869
1121
|
graph_dict = {}
|
870
|
-
|
1122
|
+
|
871
1123
|
# Loop over data dimensions and create text annotations
|
872
1124
|
for i in range(num_source):
|
873
1125
|
for j in range(num_target):
|
874
1126
|
# Get the edge info, or set it to '0' if it's missing
|
875
1127
|
edge_info = text[i, j] if text[i, j] is not None else 0
|
876
|
-
|
1128
|
+
|
877
1129
|
# Initialize the dictionary for the source node if not already done
|
878
1130
|
if source_labels[i] not in graph_dict:
|
879
1131
|
graph_dict[source_labels[i]] = {}
|
880
|
-
|
1132
|
+
|
881
1133
|
# Add edge info for the target node
|
882
1134
|
graph_dict[source_labels[i]][target_labels[j]] = edge_info
|
883
|
-
|
1135
|
+
|
884
1136
|
# Set text annotations based on syn_info type
|
885
|
-
if syn_info ==
|
1137
|
+
if syn_info == "2" or syn_info == "3":
|
886
1138
|
if num_source > 8 and num_source < 20:
|
887
|
-
fig_text = ax1.text(
|
888
|
-
|
1139
|
+
fig_text = ax1.text(
|
1140
|
+
j,
|
1141
|
+
i,
|
1142
|
+
edge_info,
|
1143
|
+
ha="center",
|
1144
|
+
va="center",
|
1145
|
+
color="w",
|
1146
|
+
rotation=37.5,
|
1147
|
+
size=8,
|
1148
|
+
weight="semibold",
|
1149
|
+
)
|
889
1150
|
elif num_source > 20:
|
890
|
-
fig_text = ax1.text(
|
891
|
-
|
1151
|
+
fig_text = ax1.text(
|
1152
|
+
j,
|
1153
|
+
i,
|
1154
|
+
edge_info,
|
1155
|
+
ha="center",
|
1156
|
+
va="center",
|
1157
|
+
color="w",
|
1158
|
+
rotation=37.5,
|
1159
|
+
size=7,
|
1160
|
+
weight="semibold",
|
1161
|
+
)
|
892
1162
|
else:
|
893
|
-
fig_text = ax1.text(
|
894
|
-
|
1163
|
+
fig_text = ax1.text(
|
1164
|
+
j,
|
1165
|
+
i,
|
1166
|
+
edge_info,
|
1167
|
+
ha="center",
|
1168
|
+
va="center",
|
1169
|
+
color="w",
|
1170
|
+
rotation=37.5,
|
1171
|
+
size=11,
|
1172
|
+
weight="semibold",
|
1173
|
+
)
|
895
1174
|
else:
|
896
|
-
fig_text = ax1.text(
|
897
|
-
|
898
|
-
|
1175
|
+
fig_text = ax1.text(
|
1176
|
+
j, i, edge_info, ha="center", va="center", color="w", size=11, weight="semibold"
|
1177
|
+
)
|
1178
|
+
|
899
1179
|
# Set labels and title for the plot
|
900
|
-
ax1.set_ylabel(
|
901
|
-
ax1.set_xlabel(
|
902
|
-
ax1.set_title(title, size=20, weight=
|
903
|
-
|
1180
|
+
ax1.set_ylabel("Source", size=11, weight="semibold")
|
1181
|
+
ax1.set_xlabel("Target", size=11, weight="semibold")
|
1182
|
+
ax1.set_title(title, size=20, weight="semibold")
|
1183
|
+
|
904
1184
|
# Display the plot or save it based on the environment and arguments
|
905
1185
|
notebook = is_notebook() # Check if running in a Jupyter notebook
|
906
|
-
if notebook
|
1186
|
+
if not notebook:
|
907
1187
|
fig1.show()
|
908
|
-
|
1188
|
+
|
909
1189
|
if save_file:
|
910
1190
|
plt.savefig(save_file)
|
911
|
-
|
1191
|
+
|
912
1192
|
if return_dict:
|
913
1193
|
return graph_dict
|
914
1194
|
else:
|
915
1195
|
return
|
916
1196
|
|
917
1197
|
|
918
|
-
def connector_percent_matrix(
|
1198
|
+
def connector_percent_matrix(
|
1199
|
+
csv_path: str = None,
|
1200
|
+
exclude_strings=None,
|
1201
|
+
assemb_key=None,
|
1202
|
+
title: str = "Percent connection matrix",
|
1203
|
+
pop_order=None,
|
1204
|
+
) -> None:
|
919
1205
|
"""
|
920
1206
|
Generates and plots a connection matrix based on connection probabilities from a CSV file produced by bmtool.connector.
|
921
1207
|
|
922
|
-
This function is useful for visualizing percent connectivity while factoring in population distance and other parameters.
|
923
|
-
It processes the connection data by filtering the 'Source' and 'Target' columns in the CSV, and displays the percentage of
|
1208
|
+
This function is useful for visualizing percent connectivity while factoring in population distance and other parameters.
|
1209
|
+
It processes the connection data by filtering the 'Source' and 'Target' columns in the CSV, and displays the percentage of
|
924
1210
|
connected pairs for each population combination in a matrix.
|
925
1211
|
|
926
1212
|
Parameters:
|
927
1213
|
-----------
|
928
1214
|
csv_path : str
|
929
|
-
Path to the CSV file containing the connection data. The CSV should be an output from the bmtool.connector
|
1215
|
+
Path to the CSV file containing the connection data. The CSV should be an output from the bmtool.connector
|
930
1216
|
classes, specifically generated by the `save_connection_report()` function.
|
931
1217
|
exclude_strings : list of str, optional
|
932
1218
|
List of strings to exclude rows where 'Source' or 'Target' contain these strings.
|
@@ -949,15 +1235,14 @@ def connector_percent_matrix(csv_path: str = None, exclude_strings=None, assemb_
|
|
949
1235
|
# Filter the DataFrame based on exclude_strings
|
950
1236
|
def filter_dataframe(df, column_name, exclude_strings):
|
951
1237
|
def process_string(string):
|
952
|
-
|
953
1238
|
match = re.search(r"\[\'(.*?)\'\]", string)
|
954
1239
|
if exclude_strings and any(ex_string in string for ex_string in exclude_strings):
|
955
1240
|
return None
|
956
1241
|
elif match:
|
957
1242
|
filtered_string = match.group(1)
|
958
|
-
if
|
959
|
-
|
960
|
-
|
1243
|
+
if "Gap" in string:
|
1244
|
+
filtered_string = filtered_string + "-Gap"
|
1245
|
+
|
961
1246
|
if assemb_key:
|
962
1247
|
if assemb_key in string:
|
963
1248
|
filtered_string = filtered_string + assemb_key
|
@@ -965,65 +1250,73 @@ def connector_percent_matrix(csv_path: str = None, exclude_strings=None, assemb_
|
|
965
1250
|
return filtered_string # Return matched string
|
966
1251
|
|
967
1252
|
return string # If no match, return the original string
|
968
|
-
|
1253
|
+
|
969
1254
|
df[column_name] = df[column_name].apply(process_string)
|
970
1255
|
df = df.dropna(subset=[column_name])
|
971
|
-
|
1256
|
+
|
972
1257
|
return df
|
973
1258
|
|
974
|
-
df = filter_dataframe(df,
|
975
|
-
df = filter_dataframe(df,
|
976
|
-
|
977
|
-
#process assem rows and combine them into one prob per assem type
|
1259
|
+
df = filter_dataframe(df, "Source", exclude_strings)
|
1260
|
+
df = filter_dataframe(df, "Target", exclude_strings)
|
1261
|
+
|
1262
|
+
# process assem rows and combine them into one prob per assem type
|
978
1263
|
if assemb_key:
|
979
|
-
assems = df[df[
|
980
|
-
unique_sources = assems[
|
1264
|
+
assems = df[df["Source"].str.contains(assemb_key)]
|
1265
|
+
unique_sources = assems["Source"].unique()
|
981
1266
|
|
982
1267
|
for source in unique_sources:
|
983
|
-
source_assems = assems[assems[
|
984
|
-
unique_targets = source_assems[
|
1268
|
+
source_assems = assems[assems["Source"] == source]
|
1269
|
+
unique_targets = source_assems[
|
1270
|
+
"Target"
|
1271
|
+
].unique() # Filter targets for the current source
|
985
1272
|
|
986
1273
|
for target in unique_targets:
|
987
1274
|
# Filter the assemblies with the current source and target
|
988
|
-
unique_assems = source_assems[source_assems[
|
989
|
-
|
1275
|
+
unique_assems = source_assems[source_assems["Target"] == target]
|
1276
|
+
|
990
1277
|
# find the prob of a conn
|
991
1278
|
forward_probs = []
|
992
|
-
for _,row in unique_assems.iterrows():
|
1279
|
+
for _, row in unique_assems.iterrows():
|
993
1280
|
selected_percentage = row[selected_column]
|
994
|
-
selected_percentage = [
|
1281
|
+
selected_percentage = [
|
1282
|
+
float(p) for p in selected_percentage.strip("[]").split()
|
1283
|
+
]
|
995
1284
|
if len(selected_percentage) == 1 or len(selected_percentage) == 2:
|
996
1285
|
forward_probs.append(selected_percentage[0])
|
997
1286
|
if len(selected_percentage) == 3:
|
998
1287
|
forward_probs.append(selected_percentage[0])
|
999
1288
|
forward_probs.append(selected_percentage[1])
|
1000
|
-
|
1289
|
+
|
1001
1290
|
mean_probs = np.mean(forward_probs)
|
1002
1291
|
source = source.replace(assemb_key, "")
|
1003
1292
|
target = target.replace(assemb_key, "")
|
1004
|
-
new_row = pd.DataFrame(
|
1005
|
-
|
1006
|
-
|
1007
|
-
|
1008
|
-
|
1009
|
-
|
1293
|
+
new_row = pd.DataFrame(
|
1294
|
+
{
|
1295
|
+
"Source": [source],
|
1296
|
+
"Target": [target],
|
1297
|
+
"Percent connectionivity within possible connections": [mean_probs],
|
1298
|
+
"Percent connectionivity within all connections": [0],
|
1299
|
+
}
|
1300
|
+
)
|
1010
1301
|
|
1011
1302
|
df = pd.concat([df, new_row], ignore_index=False)
|
1012
|
-
|
1303
|
+
|
1013
1304
|
# Prepare connection data
|
1014
1305
|
connection_data = {}
|
1015
1306
|
for _, row in df.iterrows():
|
1016
|
-
source, target, selected_percentage = row[
|
1307
|
+
source, target, selected_percentage = row["Source"], row["Target"], row[selected_column]
|
1017
1308
|
if isinstance(selected_percentage, str):
|
1018
|
-
selected_percentage = [float(p) for p in selected_percentage.strip(
|
1309
|
+
selected_percentage = [float(p) for p in selected_percentage.strip("[]").split()]
|
1019
1310
|
connection_data[(source, target)] = selected_percentage
|
1020
1311
|
|
1021
1312
|
# Determine population order
|
1022
|
-
populations = sorted(list(set(df[
|
1313
|
+
populations = sorted(list(set(df["Source"].unique()) | set(df["Target"].unique())))
|
1023
1314
|
if pop_order:
|
1024
|
-
populations = [
|
1315
|
+
populations = [
|
1316
|
+
pop for pop in pop_order if pop in populations
|
1317
|
+
] # Order according to pop_order, if provided
|
1025
1318
|
num_populations = len(populations)
|
1026
|
-
|
1319
|
+
|
1027
1320
|
# Create an empty matrix and populate it
|
1028
1321
|
connection_matrix = np.zeros((num_populations, num_populations), dtype=float)
|
1029
1322
|
for (source, target), probabilities in connection_data.items():
|
@@ -1031,40 +1324,49 @@ def connector_percent_matrix(csv_path: str = None, exclude_strings=None, assemb_
|
|
1031
1324
|
source_idx = populations.index(source)
|
1032
1325
|
target_idx = populations.index(target)
|
1033
1326
|
|
1034
|
-
if
|
1327
|
+
if isinstance(probabilities, float):
|
1035
1328
|
connection_matrix[source_idx][target_idx] = probabilities
|
1036
1329
|
elif len(probabilities) == 1:
|
1037
1330
|
connection_matrix[source_idx][target_idx] = probabilities[0]
|
1038
1331
|
elif len(probabilities) == 2:
|
1039
1332
|
connection_matrix[source_idx][target_idx] = probabilities[0]
|
1040
1333
|
elif len(probabilities) == 3:
|
1041
|
-
connection_matrix[source_idx][target_idx] = probabilities[0]
|
1334
|
+
connection_matrix[source_idx][target_idx] = probabilities[0]
|
1042
1335
|
connection_matrix[target_idx][source_idx] = probabilities[1]
|
1043
1336
|
else:
|
1044
1337
|
raise Exception("unsupported format")
|
1045
1338
|
|
1046
1339
|
# Plotting
|
1047
1340
|
fig, ax = plt.subplots(figsize=(10, 8))
|
1048
|
-
im = ax.imshow(connection_matrix, cmap=
|
1341
|
+
im = ax.imshow(connection_matrix, cmap="viridis", interpolation="nearest")
|
1049
1342
|
|
1050
1343
|
# Add annotations
|
1051
1344
|
for i in range(num_populations):
|
1052
1345
|
for j in range(num_populations):
|
1053
|
-
text = ax.text(
|
1346
|
+
text = ax.text(
|
1347
|
+
j,
|
1348
|
+
i,
|
1349
|
+
f"{connection_matrix[i, j]:.2f}%",
|
1350
|
+
ha="center",
|
1351
|
+
va="center",
|
1352
|
+
color="w",
|
1353
|
+
size=10,
|
1354
|
+
weight="semibold",
|
1355
|
+
)
|
1054
1356
|
|
1055
1357
|
# Add colorbar
|
1056
|
-
plt.colorbar(im, label=f
|
1358
|
+
plt.colorbar(im, label=f"{selected_column}")
|
1057
1359
|
|
1058
1360
|
# Set title and axis labels
|
1059
1361
|
ax.set_title(title)
|
1060
|
-
ax.set_xlabel(
|
1061
|
-
ax.set_ylabel(
|
1362
|
+
ax.set_xlabel("Target Population")
|
1363
|
+
ax.set_ylabel("Source Population")
|
1062
1364
|
|
1063
1365
|
# Set ticks and labels based on populations in specified order
|
1064
1366
|
ax.set_xticks(np.arange(num_populations))
|
1065
1367
|
ax.set_yticks(np.arange(num_populations))
|
1066
|
-
ax.set_xticklabels(populations, rotation=45, ha="right", size=12, weight=
|
1067
|
-
ax.set_yticklabels(populations, size=12, weight=
|
1368
|
+
ax.set_xticklabels(populations, rotation=45, ha="right", size=12, weight="semibold")
|
1369
|
+
ax.set_yticklabels(populations, size=12, weight="semibold")
|
1068
1370
|
|
1069
1371
|
plt.tight_layout()
|
1070
1372
|
plt.show()
|
@@ -1073,22 +1375,22 @@ def connector_percent_matrix(csv_path: str = None, exclude_strings=None, assemb_
|
|
1073
1375
|
def plot_3d_positions(config=None, sources=None, sid=None, title=None, save_file=None, subset=None):
|
1074
1376
|
"""
|
1075
1377
|
Plots a 3D graph of all cells with x, y, z location.
|
1076
|
-
|
1378
|
+
|
1077
1379
|
Parameters:
|
1078
|
-
- config: A BMTK simulation config
|
1079
|
-
- sources: Which network(s) to plot
|
1380
|
+
- config: A BMTK simulation config
|
1381
|
+
- sources: Which network(s) to plot
|
1080
1382
|
- sid: How to name cell groups
|
1081
1383
|
- title: Plot title
|
1082
1384
|
- save_file: If plot should be saved
|
1083
1385
|
- subset: Take every Nth row. This will make plotting large network graphs easier to see.
|
1084
1386
|
"""
|
1085
|
-
|
1387
|
+
|
1086
1388
|
if not config:
|
1087
1389
|
raise Exception("config not defined")
|
1088
|
-
|
1390
|
+
|
1089
1391
|
if sources is None:
|
1090
1392
|
sources = "all"
|
1091
|
-
|
1393
|
+
|
1092
1394
|
# Set group keys (e.g., node types)
|
1093
1395
|
group_keys = sid
|
1094
1396
|
if title is None:
|
@@ -1096,66 +1398,75 @@ def plot_3d_positions(config=None, sources=None, sid=None, title=None, save_file
|
|
1096
1398
|
|
1097
1399
|
# Load nodes from the configuration
|
1098
1400
|
nodes = util.load_nodes_from_config(config)
|
1099
|
-
|
1401
|
+
|
1100
1402
|
# Get the list of populations to plot
|
1101
|
-
if
|
1403
|
+
if "all" in sources:
|
1102
1404
|
populations = list(nodes)
|
1103
1405
|
else:
|
1104
1406
|
populations = sources.split(",")
|
1105
|
-
|
1106
|
-
# Split group_by into list
|
1407
|
+
|
1408
|
+
# Split group_by into list
|
1107
1409
|
group_keys = group_keys.split(",")
|
1108
|
-
group_keys += (len(populations) - len(group_keys)) * [
|
1410
|
+
group_keys += (len(populations) - len(group_keys)) * [
|
1411
|
+
"node_type_id"
|
1412
|
+
] # Extend the array to default values if not enough given
|
1109
1413
|
if len(group_keys) > 1:
|
1110
1414
|
raise Exception("Only one group by is supported currently!")
|
1111
|
-
|
1415
|
+
|
1112
1416
|
fig = plt.figure(figsize=(10, 10))
|
1113
|
-
ax = fig.add_subplot(projection=
|
1417
|
+
ax = fig.add_subplot(projection="3d")
|
1114
1418
|
handles = []
|
1115
1419
|
|
1116
|
-
for pop in
|
1117
|
-
|
1118
|
-
if 'all' not in populations and pop not in populations:
|
1420
|
+
for pop in list(nodes):
|
1421
|
+
if "all" not in populations and pop not in populations:
|
1119
1422
|
continue
|
1120
|
-
|
1423
|
+
|
1121
1424
|
nodes_df = nodes[pop]
|
1122
|
-
group_key = group_keys[0]
|
1123
|
-
|
1425
|
+
group_key = group_keys[0]
|
1426
|
+
|
1124
1427
|
# If group_key is provided, ensure the column exists in the dataframe
|
1125
1428
|
if group_key is not None:
|
1126
1429
|
if group_key not in nodes_df:
|
1127
1430
|
raise Exception(f"Could not find column '{group_key}' in {pop}")
|
1128
|
-
|
1431
|
+
|
1129
1432
|
groupings = nodes_df.groupby(group_key)
|
1130
1433
|
n_colors = nodes_df[group_key].nunique()
|
1131
1434
|
color_norm = colors.Normalize(vmin=0, vmax=(n_colors - 1))
|
1132
|
-
scalar_map = cmx.ScalarMappable(norm=color_norm, cmap=
|
1435
|
+
scalar_map = cmx.ScalarMappable(norm=color_norm, cmap="hsv")
|
1133
1436
|
color_map = [scalar_map.to_rgba(i) for i in range(n_colors)]
|
1134
1437
|
else:
|
1135
1438
|
groupings = [(None, nodes_df)]
|
1136
|
-
color_map = [
|
1439
|
+
color_map = ["blue"]
|
1137
1440
|
|
1138
1441
|
# Loop over groupings and plot
|
1139
1442
|
for color, (group_name, group_df) in zip(color_map, groupings):
|
1140
1443
|
if "pos_x" not in group_df or "pos_y" not in group_df or "pos_z" not in group_df:
|
1141
|
-
print(
|
1444
|
+
print(
|
1445
|
+
f"Warning: Missing position columns in group '{group_name}' for {pop}. Skipping this group."
|
1446
|
+
)
|
1142
1447
|
continue # Skip if position columns are missing
|
1143
1448
|
|
1144
1449
|
# Subset the dataframe by taking every Nth row if subset is provided
|
1145
1450
|
if subset is not None:
|
1146
1451
|
group_df = group_df.iloc[::subset]
|
1147
1452
|
|
1148
|
-
h = ax.scatter(
|
1453
|
+
h = ax.scatter(
|
1454
|
+
group_df["pos_x"],
|
1455
|
+
group_df["pos_y"],
|
1456
|
+
group_df["pos_z"],
|
1457
|
+
color=color,
|
1458
|
+
label=group_name,
|
1459
|
+
)
|
1149
1460
|
handles.append(h)
|
1150
|
-
|
1461
|
+
|
1151
1462
|
if not handles:
|
1152
1463
|
print("No data to plot.")
|
1153
1464
|
return
|
1154
|
-
|
1465
|
+
|
1155
1466
|
# Set plot title and legend
|
1156
1467
|
plt.title(title)
|
1157
1468
|
plt.legend(handles=handles)
|
1158
|
-
|
1469
|
+
|
1159
1470
|
# Draw the plot
|
1160
1471
|
plt.draw()
|
1161
1472
|
|
@@ -1170,8 +1481,19 @@ def plot_3d_positions(config=None, sources=None, sid=None, title=None, save_file
|
|
1170
1481
|
return ax
|
1171
1482
|
|
1172
1483
|
|
1173
|
-
def plot_3d_cell_rotation(
|
1484
|
+
def plot_3d_cell_rotation(
|
1485
|
+
config=None,
|
1486
|
+
sources=None,
|
1487
|
+
sids=None,
|
1488
|
+
title=None,
|
1489
|
+
save_file=None,
|
1490
|
+
quiver_length=None,
|
1491
|
+
arrow_length_ratio=None,
|
1492
|
+
group=None,
|
1493
|
+
subset=None,
|
1494
|
+
):
|
1174
1495
|
from scipy.spatial.transform import Rotation as R
|
1496
|
+
|
1175
1497
|
if not config:
|
1176
1498
|
raise Exception("config not defined")
|
1177
1499
|
|
@@ -1185,38 +1507,38 @@ def plot_3d_cell_rotation(config=None, sources=None, sids=None, title=None, save
|
|
1185
1507
|
|
1186
1508
|
nodes = util.load_nodes_from_config(config)
|
1187
1509
|
|
1188
|
-
if
|
1510
|
+
if "all" in sources:
|
1189
1511
|
populations = list(nodes)
|
1190
1512
|
else:
|
1191
1513
|
populations = sources
|
1192
1514
|
|
1193
1515
|
fig = plt.figure(figsize=(10, 10))
|
1194
|
-
ax = fig.add_subplot(111, projection=
|
1516
|
+
ax = fig.add_subplot(111, projection="3d")
|
1195
1517
|
handles = []
|
1196
1518
|
|
1197
1519
|
for nodes_key, group_key in zip(list(nodes), group_keys):
|
1198
|
-
if
|
1520
|
+
if "all" not in populations and nodes_key not in populations:
|
1199
1521
|
continue
|
1200
1522
|
|
1201
1523
|
nodes_df = nodes[nodes_key]
|
1202
1524
|
|
1203
1525
|
if group_key is not None:
|
1204
1526
|
if group_key not in nodes_df.columns:
|
1205
|
-
raise Exception(f
|
1527
|
+
raise Exception(f"Could not find column {group_key}")
|
1206
1528
|
groupings = nodes_df.groupby(group_key)
|
1207
1529
|
|
1208
1530
|
n_colors = nodes_df[group_key].nunique()
|
1209
1531
|
color_norm = colors.Normalize(vmin=0, vmax=(n_colors - 1))
|
1210
|
-
scalar_map = cmx.ScalarMappable(norm=color_norm, cmap=
|
1532
|
+
scalar_map = cmx.ScalarMappable(norm=color_norm, cmap="hsv")
|
1211
1533
|
color_map = [scalar_map.to_rgba(i) for i in range(n_colors)]
|
1212
1534
|
else:
|
1213
1535
|
groupings = [(None, nodes_df)]
|
1214
|
-
color_map = [
|
1536
|
+
color_map = ["blue"]
|
1215
1537
|
|
1216
1538
|
for color, (group_name, group_df) in zip(color_map, groupings):
|
1217
1539
|
if subset is not None:
|
1218
1540
|
group_df = group_df.iloc[::subset]
|
1219
|
-
|
1541
|
+
|
1220
1542
|
if group and group_name not in group.split(","):
|
1221
1543
|
continue
|
1222
1544
|
|
@@ -1238,9 +1560,9 @@ def plot_3d_cell_rotation(config=None, sources=None, sids=None, title=None, save
|
|
1238
1560
|
W = np.zeros(len(Z))
|
1239
1561
|
|
1240
1562
|
# Create rotation matrices from Euler angles
|
1241
|
-
rotations = R.from_euler(
|
1563
|
+
rotations = R.from_euler("xyz", np.column_stack((U, V, W)), degrees=False)
|
1242
1564
|
|
1243
|
-
# Define initial vectors
|
1565
|
+
# Define initial vectors
|
1244
1566
|
init_vectors = np.column_stack((np.ones(len(X)), np.zeros(len(Y)), np.zeros(len(Z))))
|
1245
1567
|
|
1246
1568
|
# Apply rotations to initial vectors
|
@@ -1251,7 +1573,18 @@ def plot_3d_cell_rotation(config=None, sources=None, sids=None, title=None, save
|
|
1251
1573
|
rot_y = rots[:, 1]
|
1252
1574
|
rot_z = rots[:, 2]
|
1253
1575
|
|
1254
|
-
h = ax.quiver(
|
1576
|
+
h = ax.quiver(
|
1577
|
+
X,
|
1578
|
+
Y,
|
1579
|
+
Z,
|
1580
|
+
rot_x,
|
1581
|
+
rot_y,
|
1582
|
+
rot_z,
|
1583
|
+
color=color,
|
1584
|
+
label=group_name,
|
1585
|
+
arrow_length_ratio=arrow_length_ratio,
|
1586
|
+
length=quiver_length,
|
1587
|
+
)
|
1255
1588
|
ax.scatter(X, Y, Z, color=color, label=group_name)
|
1256
1589
|
ax.set_xlim([min(X), max(X)])
|
1257
1590
|
ax.set_ylim([min(Y), max(Y)])
|
@@ -1268,5 +1601,5 @@ def plot_3d_cell_rotation(config=None, sources=None, sids=None, title=None, save
|
|
1268
1601
|
if save_file:
|
1269
1602
|
plt.savefig(save_file)
|
1270
1603
|
notebook = is_notebook
|
1271
|
-
if notebook
|
1604
|
+
if not notebook:
|
1272
1605
|
plt.show()
|