bmtool 0.4.2__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bmtool/bmplot.py +274 -100
- bmtool/connectors.py +45 -51
- bmtool/util/util.py +65 -8
- bmtool-0.5.1.dist-info/METADATA +519 -0
- {bmtool-0.4.2.dist-info → bmtool-0.5.1.dist-info}/RECORD +9 -9
- bmtool-0.4.2.dist-info/METADATA +0 -550
- {bmtool-0.4.2.dist-info → bmtool-0.5.1.dist-info}/LICENSE +0 -0
- {bmtool-0.4.2.dist-info → bmtool-0.5.1.dist-info}/WHEEL +0 -0
- {bmtool-0.4.2.dist-info → bmtool-0.5.1.dist-info}/entry_points.txt +0 -0
- {bmtool-0.4.2.dist-info → bmtool-0.5.1.dist-info}/top_level.txt +0 -0
bmtool/bmplot.py
CHANGED
@@ -14,12 +14,11 @@ import matplotlib.colors as colors
|
|
14
14
|
import matplotlib.gridspec as gridspec
|
15
15
|
from mpl_toolkits.mplot3d import Axes3D
|
16
16
|
from IPython import get_ipython
|
17
|
-
import
|
17
|
+
from IPython.display import display, HTML
|
18
|
+
import statistics
|
18
19
|
import pandas as pd
|
19
|
-
import h5py
|
20
20
|
import os
|
21
21
|
import sys
|
22
|
-
import time
|
23
22
|
|
24
23
|
from .util.util import CellVarsFile #, missing_units
|
25
24
|
from bmtk.analyzer.utils import listify
|
@@ -31,7 +30,36 @@ Plot BMTK models easily.
|
|
31
30
|
python -m bmtool.plot
|
32
31
|
"""
|
33
32
|
|
34
|
-
def
|
33
|
+
def is_notebook() -> bool:
|
34
|
+
"""
|
35
|
+
Used to tell if inside jupyter notebook or not. This is used to tell if we should use plt.show or not
|
36
|
+
"""
|
37
|
+
try:
|
38
|
+
shell = get_ipython().__class__.__name__
|
39
|
+
if shell == 'ZMQInteractiveShell':
|
40
|
+
return True # Jupyter notebook or qtconsole
|
41
|
+
elif shell == 'TerminalInteractiveShell':
|
42
|
+
return False # Terminal running IPython
|
43
|
+
else:
|
44
|
+
return False # Other type (?)
|
45
|
+
except NameError:
|
46
|
+
return False # Probably standard Python interpreter
|
47
|
+
|
48
|
+
def total_connection_matrix(config=None,title=None,sources=None, targets=None, sids=None, tids=None,no_prepend_pop=False,save_file=None,synaptic_info='0',include_gap=True):
|
49
|
+
"""
|
50
|
+
Generates connection plot displaying total connection or other stats
|
51
|
+
config: A BMTK simulation config
|
52
|
+
sources: network name(s) to plot
|
53
|
+
targets: network name(s) to plot
|
54
|
+
sids: source node identifier
|
55
|
+
tids: target node identifier
|
56
|
+
no_prepend_pop: dictates if population name is displayed before sid or tid when displaying graph
|
57
|
+
save_file: If plot should be saved
|
58
|
+
synaptic_info: '0' for total connections, '1' for mean and stdev connections, '2' for all synapse .mod files used, '3' for all synapse .json files used
|
59
|
+
include_gap: Determines if connectivity shown should include gap junctions + chemical synapses. False will only include chemical
|
60
|
+
"""
|
61
|
+
if not config:
|
62
|
+
raise Exception("config not defined")
|
35
63
|
if not sources or not targets:
|
36
64
|
raise Exception("Sources or targets not defined")
|
37
65
|
sources = sources.split(",")
|
@@ -44,7 +72,7 @@ def connection_matrix(config=None,nodes=None,edges=None,title=None,sources=None,
|
|
44
72
|
tids = tids.split(",")
|
45
73
|
else:
|
46
74
|
tids = []
|
47
|
-
text,num, source_labels, target_labels = util.connection_totals(config=config,nodes=None,edges=None,sources=sources,targets=targets,sids=sids,tids=tids,prepend_pop=not no_prepend_pop,synaptic_info=synaptic_info)
|
75
|
+
text,num, source_labels, target_labels = util.connection_totals(config=config,nodes=None,edges=None,sources=sources,targets=targets,sids=sids,tids=tids,prepend_pop=not no_prepend_pop,synaptic_info=synaptic_info,include_gap=include_gap)
|
48
76
|
|
49
77
|
if title == None or title=="":
|
50
78
|
title = "Total Connections"
|
@@ -58,18 +86,53 @@ def connection_matrix(config=None,nodes=None,edges=None,title=None,sources=None,
|
|
58
86
|
plot_connection_info(text,num,source_labels,target_labels,title, syn_info=synaptic_info, save_file=save_file)
|
59
87
|
return
|
60
88
|
|
61
|
-
def percent_connection_matrix(config=None,nodes=None,edges=None,title=None,sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False,save_file=None):
|
62
|
-
|
63
|
-
|
89
|
+
def percent_connection_matrix(config=None,nodes=None,edges=None,title=None,sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False,save_file=None,method = 'total',include_gap=True):
|
90
|
+
"""
|
91
|
+
Generates a plot showing the percent connectivity of a network
|
92
|
+
config: A BMTK simulation config
|
93
|
+
sources: network name(s) to plot
|
94
|
+
targets: network name(s) to plot
|
95
|
+
sids: source node identifier
|
96
|
+
tids: target node identifier
|
97
|
+
no_prepend_pop: dictates if population name is displayed before sid or tid when displaying graph
|
98
|
+
method: what percent to displace on the graph 'total','uni',or 'bi' for total connections, unidirectional connections or bidirectional connections
|
99
|
+
save_file: If plot should be saved
|
100
|
+
include_gap: Determines if connectivity shown should include gap junctions + chemical synapses. False will only include chemical
|
101
|
+
"""
|
102
|
+
if not config:
|
103
|
+
raise Exception("config not defined")
|
104
|
+
if not sources or not targets:
|
105
|
+
raise Exception("Sources or targets not defined")
|
106
|
+
|
107
|
+
sources = sources.split(",")
|
108
|
+
targets = targets.split(",")
|
109
|
+
if sids:
|
110
|
+
sids = sids.split(",")
|
111
|
+
else:
|
112
|
+
sids = []
|
113
|
+
if tids:
|
114
|
+
tids = tids.split(",")
|
115
|
+
else:
|
116
|
+
tids = []
|
117
|
+
text,num, source_labels, target_labels = util.percent_connections(config=config,nodes=None,edges=None,sources=sources,targets=targets,sids=sids,tids=tids,prepend_pop=not no_prepend_pop,method=method,include_gap=include_gap)
|
64
118
|
if title == None or title=="":
|
65
119
|
title = "Percent Connectivity"
|
66
120
|
|
121
|
+
|
67
122
|
plot_connection_info(text,num,source_labels,target_labels,title, save_file=save_file)
|
68
123
|
return
|
69
124
|
|
70
125
|
def probability_connection_matrix(config=None,nodes=None,edges=None,title=None,sources=None, targets=None, sids=None, tids=None,
|
71
|
-
no_prepend_pop=False,save_file=None, dist_X=True,dist_Y=True,dist_Z=True,bins=8,line_plot=False,verbose=False):
|
72
|
-
|
126
|
+
no_prepend_pop=False,save_file=None, dist_X=True,dist_Y=True,dist_Z=True,bins=8,line_plot=False,verbose=False,include_gap=True):
|
127
|
+
"""
|
128
|
+
Generates probability graphs
|
129
|
+
need to look into this more to see what it does
|
130
|
+
needs model_template to be defined to work
|
131
|
+
"""
|
132
|
+
if not config:
|
133
|
+
raise Exception("config not defined")
|
134
|
+
if not sources or not targets:
|
135
|
+
raise Exception("Sources or targets not defined")
|
73
136
|
if not sources or not targets:
|
74
137
|
raise Exception("Sources or targets not defined")
|
75
138
|
sources = sources.split(",")
|
@@ -85,7 +148,7 @@ def probability_connection_matrix(config=None,nodes=None,edges=None,title=None,s
|
|
85
148
|
|
86
149
|
throwaway, data, source_labels, target_labels = util.connection_probabilities(config=config,nodes=None,
|
87
150
|
edges=None,sources=sources,targets=targets,sids=sids,tids=tids,
|
88
|
-
prepend_pop=not no_prepend_pop,dist_X=dist_X,dist_Y=dist_Y,dist_Z=dist_Z,num_bins=bins)
|
151
|
+
prepend_pop=not no_prepend_pop,dist_X=dist_X,dist_Y=dist_Y,dist_Z=dist_Z,num_bins=bins,include_gap=include_gap)
|
89
152
|
if not data.any():
|
90
153
|
return
|
91
154
|
if data[0][0]==-1:
|
@@ -129,14 +192,44 @@ def probability_connection_matrix(config=None,nodes=None,edges=None,title=None,s
|
|
129
192
|
st = fig.suptitle(tt, fontsize=14)
|
130
193
|
fig.text(0.5, 0.04, 'Target', ha='center')
|
131
194
|
fig.text(0.04, 0.5, 'Source', va='center', rotation='vertical')
|
132
|
-
|
195
|
+
notebook = is_notebook
|
196
|
+
if notebook == False:
|
197
|
+
fig.show()
|
133
198
|
|
134
199
|
return
|
135
200
|
|
136
|
-
def convergence_connection_matrix(config=None,
|
137
|
-
|
201
|
+
def convergence_connection_matrix(config=None,title=None,sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False,save_file=None,convergence=True,method='mean',include_gap=True):
|
202
|
+
"""
|
203
|
+
Generates connection plot displaying convergence data
|
204
|
+
config: A BMTK simulation config
|
205
|
+
sources: network name(s) to plot
|
206
|
+
targets: network name(s) to plot
|
207
|
+
sids: source node identifier
|
208
|
+
tids: target node identifier
|
209
|
+
no_prepend_pop: dictates if population name is displayed before sid or tid when displaying graph
|
210
|
+
save_file: If plot should be saved
|
211
|
+
method: 'mean','min','max','stdev' for connvergence plot
|
212
|
+
"""
|
213
|
+
if not config:
|
214
|
+
raise Exception("config not defined")
|
215
|
+
if not sources or not targets:
|
216
|
+
raise Exception("Sources or targets not defined")
|
217
|
+
return divergence_connection_matrix(config,title ,sources, targets, sids, tids, no_prepend_pop, save_file ,convergence, method,include_gap=include_gap)
|
138
218
|
|
139
|
-
def divergence_connection_matrix(config=None,
|
219
|
+
def divergence_connection_matrix(config=None,title=None,sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False,save_file=None,convergence=False,method='mean',include_gap=True):
|
220
|
+
"""
|
221
|
+
Generates connection plot displaying divergence data
|
222
|
+
config: A BMTK simulation config
|
223
|
+
sources: network name(s) to plot
|
224
|
+
targets: network name(s) to plot
|
225
|
+
sids: source node identifier
|
226
|
+
tids: target node identifier
|
227
|
+
no_prepend_pop: dictates if population name is displayed before sid or tid when displaying graph
|
228
|
+
save_file: If plot should be saved
|
229
|
+
method: 'mean','min','max','stdev' for connvergence plot
|
230
|
+
"""
|
231
|
+
if not config:
|
232
|
+
raise Exception("config not defined")
|
140
233
|
if not sources or not targets:
|
141
234
|
raise Exception("Sources or targets not defined")
|
142
235
|
sources = sources.split(",")
|
@@ -150,7 +243,7 @@ def divergence_connection_matrix(config=None,nodes=None,edges=None,title=None,so
|
|
150
243
|
else:
|
151
244
|
tids = []
|
152
245
|
|
153
|
-
syn_info, data, source_labels, target_labels = util.connection_divergence(config=config,nodes=None,edges=None,sources=sources,targets=targets,sids=sids,tids=tids,prepend_pop=not no_prepend_pop,convergence=convergence,method=method)
|
246
|
+
syn_info, data, source_labels, target_labels = util.connection_divergence(config=config,nodes=None,edges=None,sources=sources,targets=targets,sids=sids,tids=tids,prepend_pop=not no_prepend_pop,convergence=convergence,method=method,include_gap=include_gap)
|
154
247
|
|
155
248
|
|
156
249
|
#data, labels = util.connection_divergence_average(config=config,nodes=nodes,edges=edges,populations=populations)
|
@@ -174,26 +267,69 @@ def divergence_connection_matrix(config=None,nodes=None,edges=None,title=None,so
|
|
174
267
|
plot_connection_info(data,data,source_labels,target_labels,title, save_file=save_file)
|
175
268
|
return
|
176
269
|
|
177
|
-
def
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
270
|
+
def connection_histogram(config=None,nodes=None,edges=None,sources=[],targets=[],sids=[],tids=[],prepend_pop=True,synaptic_info='0',
|
271
|
+
source_cell = None,target_cell = None,include_gap=True):
|
272
|
+
"""
|
273
|
+
Generates histogram of number of connections individual cells in a population receieve from another population
|
274
|
+
config: A BMTK simulation config
|
275
|
+
sources: network name(s) to plot
|
276
|
+
targets: network name(s) to plot
|
277
|
+
sids: source node identifier
|
278
|
+
tids: target node identifier
|
279
|
+
no_prepend_pop: dictates if population name is displayed before sid or tid when displaying graph
|
280
|
+
source_cell: where connections are coming from
|
281
|
+
target_cell: where connections on coming onto
|
282
|
+
save_file: If plot should be saved
|
283
|
+
"""
|
284
|
+
def connection_pair_histogram(**kwargs):
|
285
|
+
edges = kwargs["edges"]
|
286
|
+
source_id_type = kwargs["sid"]
|
287
|
+
target_id_type = kwargs["tid"]
|
288
|
+
source_id = kwargs["source_id"]
|
289
|
+
target_id = kwargs["target_id"]
|
290
|
+
if source_id == source_cell and target_id == target_cell:
|
291
|
+
temp = edges[(edges[source_id_type] == source_id) & (edges[target_id_type]==target_id)]
|
292
|
+
if include_gap == False:
|
293
|
+
temp = temp[temp['is_gap_junction'] != True]
|
294
|
+
node_pairs = temp.groupby('target_node_id')['source_node_id'].count()
|
295
|
+
conn_mean = statistics.mean(node_pairs.values)
|
296
|
+
conn_std = statistics.stdev(node_pairs.values)
|
297
|
+
conn_median = statistics.median(node_pairs.values)
|
298
|
+
label = "mean {:.2f} std ({:.2f}) median {:.2f}".format(conn_mean,conn_std,conn_median)
|
299
|
+
plt.hist(node_pairs.values,density=True,bins='auto',stacked=True,label=label)
|
300
|
+
plt.legend()
|
301
|
+
plt.xlabel("# of conns from {} to {}".format(source_cell,target_cell))
|
302
|
+
plt.ylabel("Density")
|
303
|
+
plt.show()
|
304
|
+
else: # dont care about other cell pairs so pass
|
305
|
+
pass
|
306
|
+
|
307
|
+
if not config:
|
308
|
+
raise Exception("config not defined")
|
193
309
|
if not sources or not targets:
|
194
310
|
raise Exception("Sources or targets not defined")
|
195
311
|
sources = sources.split(",")
|
196
312
|
targets = targets.split(",")
|
313
|
+
if sids:
|
314
|
+
sids = sids.split(",")
|
315
|
+
else:
|
316
|
+
sids = []
|
317
|
+
if tids:
|
318
|
+
tids = tids.split(",")
|
319
|
+
else:
|
320
|
+
tids = []
|
321
|
+
util.relation_matrix(config,nodes,edges,sources,targets,sids,tids,prepend_pop,relation_func=connection_pair_histogram,synaptic_info=synaptic_info)
|
322
|
+
|
323
|
+
def edge_histogram_matrix(config=None,sources = None,targets=None,sids=None,tids=None,no_prepend_pop=None,edge_property = None,time = None,time_compare = None,report=None,title=None,save_file=None):
|
324
|
+
"""
|
325
|
+
write about function here
|
326
|
+
"""
|
327
|
+
|
328
|
+
if not config:
|
329
|
+
raise Exception("config not defined")
|
330
|
+
if not sources or not targets:
|
331
|
+
raise Exception("Sources or targets not defined")
|
332
|
+
targets = targets.split(",")
|
197
333
|
if sids:
|
198
334
|
sids = sids.split(",")
|
199
335
|
else:
|
@@ -231,8 +367,11 @@ def edge_histogram_matrix(**kwargs):
|
|
231
367
|
fig.text(0.04, 0.5, 'Source', va='center', rotation='vertical')
|
232
368
|
plt.draw()
|
233
369
|
|
234
|
-
|
235
370
|
def plot_connection_info(text, num, source_labels,target_labels, title, syn_info='0', save_file=None):
|
371
|
+
"""
|
372
|
+
write about function here
|
373
|
+
"""
|
374
|
+
|
236
375
|
#num = pd.DataFrame(num).fillna('nc').to_numpy() # replace nan with nc * does not work with imshow
|
237
376
|
|
238
377
|
num_source=len(source_labels)
|
@@ -274,14 +413,17 @@ def plot_connection_info(text, num, source_labels,target_labels, title, syn_info
|
|
274
413
|
ax1.set_xlabel('Target', size=11, weight = 'semibold')
|
275
414
|
ax1.set_title(title,size=20, weight = 'semibold')
|
276
415
|
#plt.tight_layout()
|
277
|
-
|
278
|
-
|
279
|
-
|
416
|
+
notebook = is_notebook()
|
417
|
+
if notebook == False:
|
418
|
+
fig1.show()
|
280
419
|
if save_file:
|
281
420
|
plt.savefig(save_file)
|
282
421
|
return
|
283
422
|
|
284
423
|
def raster_old(config=None,title=None,populations=['hippocampus']):
|
424
|
+
"""
|
425
|
+
old function probs dep
|
426
|
+
"""
|
285
427
|
conf = util.load_config(config)
|
286
428
|
spikes_path = os.path.join(conf["output"]["output_dir"],conf["output"]["spikes_file"])
|
287
429
|
nodes = util.load_nodes_from_config(config)
|
@@ -289,6 +431,9 @@ def raster_old(config=None,title=None,populations=['hippocampus']):
|
|
289
431
|
return
|
290
432
|
|
291
433
|
def raster(config=None,title=None,population=None,group_key='pop_name'):
|
434
|
+
"""
|
435
|
+
old function probs dep or more to new spike module?
|
436
|
+
"""
|
292
437
|
conf = util.load_config(config)
|
293
438
|
|
294
439
|
cells_file = conf["networks"]["nodes"][0]["nodes_file"]
|
@@ -300,6 +445,9 @@ def raster(config=None,title=None,population=None,group_key='pop_name'):
|
|
300
445
|
return
|
301
446
|
|
302
447
|
def plot_spikes(nodes, spikes_file,save_file=None):
|
448
|
+
"""
|
449
|
+
old function probs dep
|
450
|
+
"""
|
303
451
|
import h5py
|
304
452
|
|
305
453
|
spikes_h5 = h5py.File(spikes_file, 'r')
|
@@ -349,12 +497,23 @@ def plot_spikes(nodes, spikes_file,save_file=None):
|
|
349
497
|
|
350
498
|
return
|
351
499
|
|
352
|
-
def plot_3d_positions(
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
500
|
+
def plot_3d_positions(config=None,populations_list=None,group_by=None,title=None,save_file=None):
|
501
|
+
"""
|
502
|
+
plots a 3D graph of all cells with x,y,z location
|
503
|
+
config: A BMTK simulation config
|
504
|
+
populations_list: Which network(s) to plot
|
505
|
+
group_by: How to name cell groups
|
506
|
+
title: plot title
|
507
|
+
save_file: If plot should be saved
|
508
|
+
"""
|
509
|
+
|
510
|
+
if not config:
|
511
|
+
raise Exception("config not defined")
|
512
|
+
if populations_list == None:
|
513
|
+
populations_list = "all"
|
514
|
+
group_keys = group_by
|
515
|
+
if title == None:
|
516
|
+
title = "3D positions"
|
358
517
|
|
359
518
|
nodes = util.load_nodes_from_config(config)
|
360
519
|
|
@@ -366,7 +525,7 @@ def plot_3d_positions(**kwargs):
|
|
366
525
|
group_keys = group_keys.split(",")
|
367
526
|
group_keys += (len(populations)-len(group_keys)) * ["node_type_id"] #Extend the array to default values if not enough given
|
368
527
|
fig = plt.figure(figsize=(10,10))
|
369
|
-
ax =
|
528
|
+
ax = fig.add_subplot(projection='3d')
|
370
529
|
handles = []
|
371
530
|
for nodes_key,group_key in zip(list(nodes),group_keys):
|
372
531
|
if 'all' not in populations and nodes_key not in populations:
|
@@ -401,115 +560,123 @@ def plot_3d_positions(**kwargs):
|
|
401
560
|
|
402
561
|
if save_file:
|
403
562
|
plt.savefig(save_file)
|
563
|
+
notebook = is_notebook
|
564
|
+
if notebook == False:
|
565
|
+
plt.show()
|
404
566
|
|
405
567
|
return
|
406
568
|
|
569
|
+
def cell_rotation_3d(config=None, populations_list=None, group_by=None, title=None, save_file=None, quiver_length=None, arrow_length_ratio=None, group=None, max_cells=1000000):
|
570
|
+
from scipy.spatial.transform import Rotation as R
|
571
|
+
if not config:
|
572
|
+
raise Exception("config not defined")
|
407
573
|
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
group_keys =
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
arrow_length_ratio = kwargs["arrow_length_ratio"]
|
416
|
-
group = kwargs["group"]
|
417
|
-
max_cells = kwargs.get("max_cells",999999999)
|
418
|
-
init_vector = kwargs.get("init_vector","1,0,0")
|
574
|
+
if populations_list is None:
|
575
|
+
populations_list = ["all"]
|
576
|
+
|
577
|
+
group_keys = group_by.split(",") if group_by else []
|
578
|
+
|
579
|
+
if title is None:
|
580
|
+
title = "Cell rotations"
|
419
581
|
|
420
582
|
nodes = util.load_nodes_from_config(config)
|
421
583
|
|
422
584
|
if 'all' in populations_list:
|
423
585
|
populations = list(nodes)
|
424
586
|
else:
|
425
|
-
populations = populations_list
|
587
|
+
populations = populations_list
|
426
588
|
|
427
|
-
|
428
|
-
|
429
|
-
fig = plt.figure(figsize=(10,10))
|
430
|
-
ax = Axes3D(fig)
|
589
|
+
fig = plt.figure(figsize=(10, 10))
|
590
|
+
ax = fig.add_subplot(111, projection='3d')
|
431
591
|
handles = []
|
432
|
-
|
592
|
+
|
593
|
+
for nodes_key, group_key in zip(list(nodes), group_keys):
|
433
594
|
if 'all' not in populations and nodes_key not in populations:
|
434
595
|
continue
|
435
|
-
|
596
|
+
|
436
597
|
nodes_df = nodes[nodes_key]
|
437
598
|
|
438
599
|
if group_key is not None:
|
439
|
-
if group_key not in nodes_df:
|
440
|
-
raise Exception('Could not find column {}'
|
600
|
+
if group_key not in nodes_df.columns:
|
601
|
+
raise Exception(f'Could not find column {group_key}')
|
441
602
|
groupings = nodes_df.groupby(group_key)
|
442
603
|
|
443
604
|
n_colors = nodes_df[group_key].nunique()
|
444
|
-
color_norm = colors.Normalize(vmin=0, vmax=(n_colors-1))
|
605
|
+
color_norm = colors.Normalize(vmin=0, vmax=(n_colors - 1))
|
445
606
|
scalar_map = cmx.ScalarMappable(norm=color_norm, cmap='hsv')
|
446
|
-
color_map = [scalar_map.to_rgba(i) for i in range(
|
607
|
+
color_map = [scalar_map.to_rgba(i) for i in range(n_colors)]
|
447
608
|
else:
|
448
609
|
groupings = [(None, nodes_df)]
|
449
610
|
color_map = ['blue']
|
450
611
|
|
451
612
|
cells_plotted = 0
|
452
613
|
for color, (group_name, group_df) in zip(color_map, groupings):
|
453
|
-
# if we selected a group and it's not in the list continue
|
454
614
|
if group and group_name not in group.split(","):
|
455
615
|
continue
|
456
616
|
|
457
|
-
if "pos_x" not in group_df
|
458
|
-
continue
|
617
|
+
if "pos_x" not in group_df or "rotation_angle_xaxis" not in group_df:
|
618
|
+
continue
|
459
619
|
|
460
|
-
# if we exceed the max cells, stop plotting or limit
|
461
620
|
if cells_plotted >= max_cells:
|
462
621
|
continue
|
622
|
+
|
463
623
|
if len(group_df) + cells_plotted > max_cells:
|
464
624
|
total_remaining = max_cells - cells_plotted
|
465
625
|
group_df = group_df[:total_remaining]
|
626
|
+
|
466
627
|
cells_plotted += len(group_df)
|
467
628
|
|
468
629
|
X = group_df["pos_x"]
|
469
630
|
Y = group_df["pos_y"]
|
470
631
|
Z = group_df["pos_z"]
|
471
|
-
U = group_df
|
472
|
-
V = group_df
|
473
|
-
W = group_df
|
632
|
+
U = group_df["rotation_angle_xaxis"].values
|
633
|
+
V = group_df["rotation_angle_yaxis"].values
|
634
|
+
W = group_df["rotation_angle_zaxis"].values
|
635
|
+
|
474
636
|
if U is None:
|
475
637
|
U = np.zeros(len(X))
|
476
638
|
if V is None:
|
477
639
|
V = np.zeros(len(Y))
|
478
640
|
if W is None:
|
479
641
|
W = np.zeros(len(Z))
|
480
|
-
|
481
|
-
#Convert to arrow direction
|
482
|
-
from scipy.spatial.transform import Rotation as R
|
483
|
-
uvw = pd.DataFrame([U,V,W]).T
|
484
|
-
init_vector = init_vector.split(',')
|
485
|
-
init_vector = np.repeat([init_vector],len(X),axis=0)
|
486
|
-
|
487
|
-
# To get the final cell orientation after rotation,
|
488
|
-
# you need to use function Rotaion.apply(init_vec),
|
489
|
-
# where init_vec is a vector of the initial orientation of a cell
|
490
|
-
#rots = R.from_euler('xyz', uvw).apply(init_vector.astype(float))
|
491
|
-
#rots = R.from_euler('xyz', pd.DataFrame([rots[:,0],rots[:,1],rots[:,2]]).T).as_rotvec().T
|
492
642
|
|
493
|
-
|
494
|
-
|
643
|
+
# Create rotation matrices from Euler angles
|
644
|
+
rotations = R.from_euler('xyz', np.column_stack((U, V, W)), degrees=False)
|
495
645
|
|
496
|
-
#
|
497
|
-
|
646
|
+
# Define initial vectors
|
647
|
+
init_vectors = np.column_stack((np.ones(len(X)), np.zeros(len(Y)), np.zeros(len(Z))))
|
648
|
+
|
649
|
+
# Apply rotations to initial vectors
|
650
|
+
rots = np.dot(rotations.as_matrix(), init_vectors.T).T
|
651
|
+
|
652
|
+
# Extract x, y, and z components of the rotated vectors
|
653
|
+
rot_x = rots[:, 0]
|
654
|
+
rot_y = rots[:, 1]
|
655
|
+
rot_z = rots[:, 2]
|
656
|
+
|
657
|
+
h = ax.quiver(X, Y, Z, rot_x, rot_y, rot_z, color=color, label=group_name, arrow_length_ratio=arrow_length_ratio, length=quiver_length)
|
658
|
+
ax.scatter(X, Y, Z, color=color, label=group_name)
|
659
|
+
ax.set_xlim([min(X), max(X)])
|
660
|
+
ax.set_ylim([min(Y), max(Y)])
|
661
|
+
ax.set_zlim([min(Z), max(Z)])
|
498
662
|
handles.append(h)
|
663
|
+
|
499
664
|
if not handles:
|
500
665
|
return
|
666
|
+
|
501
667
|
plt.title(title)
|
502
668
|
plt.legend(handles=handles)
|
503
|
-
|
504
669
|
plt.draw()
|
505
670
|
|
506
671
|
if save_file:
|
507
672
|
plt.savefig(save_file)
|
508
|
-
|
509
|
-
|
673
|
+
notebook = is_notebook
|
674
|
+
if notebook == False:
|
675
|
+
plt.show()
|
510
676
|
|
511
677
|
def plot_network_graph(config=None,nodes=None,edges=None,title=None,sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False,save_file=None,edge_property='model_template'):
|
512
|
-
|
678
|
+
if not config:
|
679
|
+
raise Exception("config not defined")
|
513
680
|
if not sources or not targets:
|
514
681
|
raise Exception("Sources or targets not defined")
|
515
682
|
sources = sources.split(",")
|
@@ -522,7 +689,7 @@ def plot_network_graph(config=None,nodes=None,edges=None,title=None,sources=None
|
|
522
689
|
tids = tids.split(",")
|
523
690
|
else:
|
524
691
|
tids = []
|
525
|
-
data, source_labels, target_labels = util.connection_graph_edge_types(nodes=None,edges=None,sources=sources,targets=targets,sids=sids,tids=tids,prepend_pop=not no_prepend_pop,edge_property=edge_property)
|
692
|
+
throw_away, data, source_labels, target_labels = util.connection_graph_edge_types(config=config,nodes=None,edges=None,sources=sources,targets=targets,sids=sids,tids=tids,prepend_pop=not no_prepend_pop,edge_property=edge_property)
|
526
693
|
|
527
694
|
if title == None or title=="":
|
528
695
|
title = "Network Graph"
|
@@ -638,7 +805,7 @@ def sim_setup(config_file='simulation_config.json',network=None):
|
|
638
805
|
# Plot spike train info
|
639
806
|
plot_inspikes(config_file)
|
640
807
|
# Using bmtool, print total number of connections between cell groups
|
641
|
-
|
808
|
+
total_connection_matrix(config=config_file,sources='all',targets='all',sids='pop_name',tids='pop_name',title='All Connections found', size_scalar=2, no_prepend_pop=True, synaptic_info='0')
|
642
809
|
# Plot 3d positions of the network
|
643
810
|
plot_3d_positions(populations='all',config=config_file,group_by='pop_name',title='3D Positions',save_file=None)
|
644
811
|
|
@@ -661,7 +828,7 @@ def plot_I_clamps(fp):
|
|
661
828
|
plt.legend()
|
662
829
|
num_clamps=num_clamps+1
|
663
830
|
|
664
|
-
def plot_basic_cell_info(config_file
|
831
|
+
def plot_basic_cell_info(config_file):
|
665
832
|
print("Network and node info:")
|
666
833
|
nodes=util.load_nodes_from_config(config_file)
|
667
834
|
if not nodes:
|
@@ -695,7 +862,11 @@ def plot_basic_cell_info(config_file,notebook=0):
|
|
695
862
|
count=1
|
696
863
|
df1 = pd.DataFrame(CELLS, columns = ["node_type","pop_name","model_type","count"])
|
697
864
|
print(j+':')
|
698
|
-
|
865
|
+
notebook = is_notebook()
|
866
|
+
if notebook == True:
|
867
|
+
display(HTML(df1.to_html()))
|
868
|
+
else:
|
869
|
+
print(df1)
|
699
870
|
elif node['model_type'][0]=='biophysical':
|
700
871
|
CELLS=[]
|
701
872
|
count=1
|
@@ -709,7 +880,7 @@ def plot_basic_cell_info(config_file,notebook=0):
|
|
709
880
|
pop_name=node['pop_name'][i]
|
710
881
|
model_type=node['model_type'][i]
|
711
882
|
model_template=node['model_template'][i]
|
712
|
-
morphology=node['morphology'][i] if node
|
883
|
+
morphology=node['morphology'][i] if node['morphology'][i] else ''
|
713
884
|
CELLS.append([node_type,pop_name,model_type,model_template,morphology,count])
|
714
885
|
count=1
|
715
886
|
else:
|
@@ -717,17 +888,20 @@ def plot_basic_cell_info(config_file,notebook=0):
|
|
717
888
|
pop_name=node['pop_name'][i]
|
718
889
|
model_type=node['model_type'][i]
|
719
890
|
model_template=node['model_template'][i]
|
720
|
-
morphology=node['morphology'][i] if node
|
891
|
+
morphology=node['morphology'][i] if node['morphology'][i] else ''
|
721
892
|
CELLS.append([node_type,pop_name,model_type,model_template,morphology,count])
|
722
893
|
count=1
|
723
894
|
df2 = pd.DataFrame(CELLS, columns = ["node_type","pop_name","model_type","model_template","morphology","count"])
|
724
895
|
print(j+':')
|
725
896
|
bio.append(j)
|
726
|
-
|
897
|
+
notebook = is_notebook()
|
898
|
+
if notebook == True:
|
899
|
+
display(HTML(df2.to_html()))
|
900
|
+
else:
|
901
|
+
print(df2)
|
727
902
|
if len(bio)>0:
|
728
903
|
return bio[0]
|
729
904
|
|
730
|
-
|
731
905
|
def plot_inspikes(fp):
|
732
906
|
|
733
907
|
print("Plotting spike Train info...")
|