bmtool 0.4.2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bmtool/bmplot.py CHANGED
@@ -14,12 +14,11 @@ import matplotlib.colors as colors
14
14
  import matplotlib.gridspec as gridspec
15
15
  from mpl_toolkits.mplot3d import Axes3D
16
16
  from IPython import get_ipython
17
- import math
17
+ from IPython.display import display, HTML
18
+ import statistics
18
19
  import pandas as pd
19
- import h5py
20
20
  import os
21
21
  import sys
22
- import time
23
22
 
24
23
  from .util.util import CellVarsFile #, missing_units
25
24
  from bmtk.analyzer.utils import listify
@@ -31,7 +30,36 @@ Plot BMTK models easily.
31
30
  python -m bmtool.plot
32
31
  """
33
32
 
34
- def connection_matrix(config=None,nodes=None,edges=None,title=None,sources=None, targets=None, sids=None, tids=None, size_scalar=1,no_prepend_pop=False,save_file=None,synaptic_info='0'):
33
+ def is_notebook() -> bool:
34
+ """
35
+ Used to tell if inside jupyter notebook or not. This is used to tell if we should use plt.show or not
36
+ """
37
+ try:
38
+ shell = get_ipython().__class__.__name__
39
+ if shell == 'ZMQInteractiveShell':
40
+ return True # Jupyter notebook or qtconsole
41
+ elif shell == 'TerminalInteractiveShell':
42
+ return False # Terminal running IPython
43
+ else:
44
+ return False # Other type (?)
45
+ except NameError:
46
+ return False # Probably standard Python interpreter
47
+
48
+ def total_connection_matrix(config=None,title=None,sources=None, targets=None, sids=None, tids=None,no_prepend_pop=False,save_file=None,synaptic_info='0',include_gap=True):
49
+ """
50
+ Generates connection plot displaying total connection or other stats
51
+ config: A BMTK simulation config
52
+ sources: network name(s) to plot
53
+ targets: network name(s) to plot
54
+ sids: source node identifier
55
+ tids: target node identifier
56
+ no_prepend_pop: dictates if population name is displayed before sid or tid when displaying graph
57
+ save_file: If plot should be saved
58
+ synaptic_info: '0' for total connections, '1' for mean and stdev connections, '2' for all synapse .mod files used, '3' for all synapse .json files used
59
+ include_gap: Determines if connectivity shown should include gap junctions + chemical synapses. False will only include chemical
60
+ """
61
+ if not config:
62
+ raise Exception("config not defined")
35
63
  if not sources or not targets:
36
64
  raise Exception("Sources or targets not defined")
37
65
  sources = sources.split(",")
@@ -44,7 +72,7 @@ def connection_matrix(config=None,nodes=None,edges=None,title=None,sources=None,
44
72
  tids = tids.split(",")
45
73
  else:
46
74
  tids = []
47
- text,num, source_labels, target_labels = util.connection_totals(config=config,nodes=None,edges=None,sources=sources,targets=targets,sids=sids,tids=tids,prepend_pop=not no_prepend_pop,synaptic_info=synaptic_info)
75
+ text,num, source_labels, target_labels = util.connection_totals(config=config,nodes=None,edges=None,sources=sources,targets=targets,sids=sids,tids=tids,prepend_pop=not no_prepend_pop,synaptic_info=synaptic_info,include_gap=include_gap)
48
76
 
49
77
  if title == None or title=="":
50
78
  title = "Total Connections"
@@ -58,18 +86,53 @@ def connection_matrix(config=None,nodes=None,edges=None,title=None,sources=None,
58
86
  plot_connection_info(text,num,source_labels,target_labels,title, syn_info=synaptic_info, save_file=save_file)
59
87
  return
60
88
 
61
- def percent_connection_matrix(config=None,nodes=None,edges=None,title=None,sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False,save_file=None):
62
- text,num, source_labels, target_labels = util.connection_totals(config=config,nodes=None,edges=None,sources=sources,targets=targets,sids=sids,tids=tids,prepend_pop=not no_prepend_pop)
63
-
89
+ def percent_connection_matrix(config=None,nodes=None,edges=None,title=None,sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False,save_file=None,method = 'total',include_gap=True):
90
+ """
91
+ Generates a plot showing the percent connectivity of a network
92
+ config: A BMTK simulation config
93
+ sources: network name(s) to plot
94
+ targets: network name(s) to plot
95
+ sids: source node identifier
96
+ tids: target node identifier
97
+ no_prepend_pop: dictates if population name is displayed before sid or tid when displaying graph
98
+ method: what percent to displace on the graph 'total','uni',or 'bi' for total connections, unidirectional connections or bidirectional connections
99
+ save_file: If plot should be saved
100
+ include_gap: Determines if connectivity shown should include gap junctions + chemical synapses. False will only include chemical
101
+ """
102
+ if not config:
103
+ raise Exception("config not defined")
104
+ if not sources or not targets:
105
+ raise Exception("Sources or targets not defined")
106
+
107
+ sources = sources.split(",")
108
+ targets = targets.split(",")
109
+ if sids:
110
+ sids = sids.split(",")
111
+ else:
112
+ sids = []
113
+ if tids:
114
+ tids = tids.split(",")
115
+ else:
116
+ tids = []
117
+ text,num, source_labels, target_labels = util.percent_connections(config=config,nodes=None,edges=None,sources=sources,targets=targets,sids=sids,tids=tids,prepend_pop=not no_prepend_pop,method=method,include_gap=include_gap)
64
118
  if title == None or title=="":
65
119
  title = "Percent Connectivity"
66
120
 
121
+
67
122
  plot_connection_info(text,num,source_labels,target_labels,title, save_file=save_file)
68
123
  return
69
124
 
70
125
  def probability_connection_matrix(config=None,nodes=None,edges=None,title=None,sources=None, targets=None, sids=None, tids=None,
71
- no_prepend_pop=False,save_file=None, dist_X=True,dist_Y=True,dist_Z=True,bins=8,line_plot=False,verbose=False):
72
- np.warnings.filterwarnings('ignore', category=np.VisibleDeprecationWarning)
126
+ no_prepend_pop=False,save_file=None, dist_X=True,dist_Y=True,dist_Z=True,bins=8,line_plot=False,verbose=False,include_gap=True):
127
+ """
128
+ Generates probability graphs
129
+ need to look into this more to see what it does
130
+ needs model_template to be defined to work
131
+ """
132
+ if not config:
133
+ raise Exception("config not defined")
134
+ if not sources or not targets:
135
+ raise Exception("Sources or targets not defined")
73
136
  if not sources or not targets:
74
137
  raise Exception("Sources or targets not defined")
75
138
  sources = sources.split(",")
@@ -85,7 +148,7 @@ def probability_connection_matrix(config=None,nodes=None,edges=None,title=None,s
85
148
 
86
149
  throwaway, data, source_labels, target_labels = util.connection_probabilities(config=config,nodes=None,
87
150
  edges=None,sources=sources,targets=targets,sids=sids,tids=tids,
88
- prepend_pop=not no_prepend_pop,dist_X=dist_X,dist_Y=dist_Y,dist_Z=dist_Z,num_bins=bins)
151
+ prepend_pop=not no_prepend_pop,dist_X=dist_X,dist_Y=dist_Y,dist_Z=dist_Z,num_bins=bins,include_gap=include_gap)
89
152
  if not data.any():
90
153
  return
91
154
  if data[0][0]==-1:
@@ -129,14 +192,44 @@ def probability_connection_matrix(config=None,nodes=None,edges=None,title=None,s
129
192
  st = fig.suptitle(tt, fontsize=14)
130
193
  fig.text(0.5, 0.04, 'Target', ha='center')
131
194
  fig.text(0.04, 0.5, 'Source', va='center', rotation='vertical')
132
- fig.show()
195
+ notebook = is_notebook
196
+ if notebook == False:
197
+ fig.show()
133
198
 
134
199
  return
135
200
 
136
- def convergence_connection_matrix(config=None,nodes=None,edges=None,title=None,sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False,save_file=None,convergence=True,method='mean'):
137
- return divergence_connection_matrix(config,nodes ,edges ,title ,sources, targets, sids, tids, no_prepend_pop, save_file ,convergence, method)
201
+ def convergence_connection_matrix(config=None,title=None,sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False,save_file=None,convergence=True,method='mean',include_gap=True):
202
+ """
203
+ Generates connection plot displaying convergence data
204
+ config: A BMTK simulation config
205
+ sources: network name(s) to plot
206
+ targets: network name(s) to plot
207
+ sids: source node identifier
208
+ tids: target node identifier
209
+ no_prepend_pop: dictates if population name is displayed before sid or tid when displaying graph
210
+ save_file: If plot should be saved
211
+ method: 'mean','min','max','stdev' for connvergence plot
212
+ """
213
+ if not config:
214
+ raise Exception("config not defined")
215
+ if not sources or not targets:
216
+ raise Exception("Sources or targets not defined")
217
+ return divergence_connection_matrix(config,title ,sources, targets, sids, tids, no_prepend_pop, save_file ,convergence, method,include_gap=include_gap)
138
218
 
139
- def divergence_connection_matrix(config=None,nodes=None,edges=None,title=None,sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False,save_file=None,convergence=False,method='mean'):
219
+ def divergence_connection_matrix(config=None,title=None,sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False,save_file=None,convergence=False,method='mean',include_gap=True):
220
+ """
221
+ Generates connection plot displaying divergence data
222
+ config: A BMTK simulation config
223
+ sources: network name(s) to plot
224
+ targets: network name(s) to plot
225
+ sids: source node identifier
226
+ tids: target node identifier
227
+ no_prepend_pop: dictates if population name is displayed before sid or tid when displaying graph
228
+ save_file: If plot should be saved
229
+ method: 'mean','min','max','stdev' for connvergence plot
230
+ """
231
+ if not config:
232
+ raise Exception("config not defined")
140
233
  if not sources or not targets:
141
234
  raise Exception("Sources or targets not defined")
142
235
  sources = sources.split(",")
@@ -150,7 +243,7 @@ def divergence_connection_matrix(config=None,nodes=None,edges=None,title=None,so
150
243
  else:
151
244
  tids = []
152
245
 
153
- syn_info, data, source_labels, target_labels = util.connection_divergence(config=config,nodes=None,edges=None,sources=sources,targets=targets,sids=sids,tids=tids,prepend_pop=not no_prepend_pop,convergence=convergence,method=method)
246
+ syn_info, data, source_labels, target_labels = util.connection_divergence(config=config,nodes=None,edges=None,sources=sources,targets=targets,sids=sids,tids=tids,prepend_pop=not no_prepend_pop,convergence=convergence,method=method,include_gap=include_gap)
154
247
 
155
248
 
156
249
  #data, labels = util.connection_divergence_average(config=config,nodes=nodes,edges=edges,populations=populations)
@@ -174,26 +267,69 @@ def divergence_connection_matrix(config=None,nodes=None,edges=None,title=None,so
174
267
  plot_connection_info(data,data,source_labels,target_labels,title, save_file=save_file)
175
268
  return
176
269
 
177
- def edge_histogram_matrix(**kwargs):
178
- config = kwargs["config"]
179
- sources = kwargs["sources"]
180
- targets = kwargs["targets"]
181
- sids = kwargs["sids"]
182
- tids = kwargs["tids"]
183
- no_prepend_pop = kwargs["no_prepend_pop"]
184
- edge_property = kwargs["edge_property"]
185
- time = int(kwargs["time"])
186
- time_compare = kwargs["time_compare"]
187
- report = kwargs["report"]
188
-
189
- title = kwargs["title"]
190
-
191
- save_file = kwargs["save_file"]
192
-
270
+ def connection_histogram(config=None,nodes=None,edges=None,sources=[],targets=[],sids=[],tids=[],prepend_pop=True,synaptic_info='0',
271
+ source_cell = None,target_cell = None,include_gap=True):
272
+ """
273
+ Generates histogram of number of connections individual cells in a population receieve from another population
274
+ config: A BMTK simulation config
275
+ sources: network name(s) to plot
276
+ targets: network name(s) to plot
277
+ sids: source node identifier
278
+ tids: target node identifier
279
+ no_prepend_pop: dictates if population name is displayed before sid or tid when displaying graph
280
+ source_cell: where connections are coming from
281
+ target_cell: where connections on coming onto
282
+ save_file: If plot should be saved
283
+ """
284
+ def connection_pair_histogram(**kwargs):
285
+ edges = kwargs["edges"]
286
+ source_id_type = kwargs["sid"]
287
+ target_id_type = kwargs["tid"]
288
+ source_id = kwargs["source_id"]
289
+ target_id = kwargs["target_id"]
290
+ if source_id == source_cell and target_id == target_cell:
291
+ temp = edges[(edges[source_id_type] == source_id) & (edges[target_id_type]==target_id)]
292
+ if include_gap == False:
293
+ temp = temp[temp['is_gap_junction'] != True]
294
+ node_pairs = temp.groupby('target_node_id')['source_node_id'].count()
295
+ conn_mean = statistics.mean(node_pairs.values)
296
+ conn_std = statistics.stdev(node_pairs.values)
297
+ conn_median = statistics.median(node_pairs.values)
298
+ label = "mean {:.2f} std ({:.2f}) median {:.2f}".format(conn_mean,conn_std,conn_median)
299
+ plt.hist(node_pairs.values,density=True,bins='auto',stacked=True,label=label)
300
+ plt.legend()
301
+ plt.xlabel("# of conns from {} to {}".format(source_cell,target_cell))
302
+ plt.ylabel("Density")
303
+ plt.show()
304
+ else: # dont care about other cell pairs so pass
305
+ pass
306
+
307
+ if not config:
308
+ raise Exception("config not defined")
193
309
  if not sources or not targets:
194
310
  raise Exception("Sources or targets not defined")
195
311
  sources = sources.split(",")
196
312
  targets = targets.split(",")
313
+ if sids:
314
+ sids = sids.split(",")
315
+ else:
316
+ sids = []
317
+ if tids:
318
+ tids = tids.split(",")
319
+ else:
320
+ tids = []
321
+ util.relation_matrix(config,nodes,edges,sources,targets,sids,tids,prepend_pop,relation_func=connection_pair_histogram,synaptic_info=synaptic_info)
322
+
323
+ def edge_histogram_matrix(config=None,sources = None,targets=None,sids=None,tids=None,no_prepend_pop=None,edge_property = None,time = None,time_compare = None,report=None,title=None,save_file=None):
324
+ """
325
+ write about function here
326
+ """
327
+
328
+ if not config:
329
+ raise Exception("config not defined")
330
+ if not sources or not targets:
331
+ raise Exception("Sources or targets not defined")
332
+ targets = targets.split(",")
197
333
  if sids:
198
334
  sids = sids.split(",")
199
335
  else:
@@ -231,8 +367,11 @@ def edge_histogram_matrix(**kwargs):
231
367
  fig.text(0.04, 0.5, 'Source', va='center', rotation='vertical')
232
368
  plt.draw()
233
369
 
234
-
235
370
  def plot_connection_info(text, num, source_labels,target_labels, title, syn_info='0', save_file=None):
371
+ """
372
+ write about function here
373
+ """
374
+
236
375
  #num = pd.DataFrame(num).fillna('nc').to_numpy() # replace nan with nc * does not work with imshow
237
376
 
238
377
  num_source=len(source_labels)
@@ -274,14 +413,17 @@ def plot_connection_info(text, num, source_labels,target_labels, title, syn_info
274
413
  ax1.set_xlabel('Target', size=11, weight = 'semibold')
275
414
  ax1.set_title(title,size=20, weight = 'semibold')
276
415
  #plt.tight_layout()
277
-
278
- fig1.show()
279
-
416
+ notebook = is_notebook()
417
+ if notebook == False:
418
+ fig1.show()
280
419
  if save_file:
281
420
  plt.savefig(save_file)
282
421
  return
283
422
 
284
423
  def raster_old(config=None,title=None,populations=['hippocampus']):
424
+ """
425
+ old function probs dep
426
+ """
285
427
  conf = util.load_config(config)
286
428
  spikes_path = os.path.join(conf["output"]["output_dir"],conf["output"]["spikes_file"])
287
429
  nodes = util.load_nodes_from_config(config)
@@ -289,6 +431,9 @@ def raster_old(config=None,title=None,populations=['hippocampus']):
289
431
  return
290
432
 
291
433
  def raster(config=None,title=None,population=None,group_key='pop_name'):
434
+ """
435
+ old function probs dep or more to new spike module?
436
+ """
292
437
  conf = util.load_config(config)
293
438
 
294
439
  cells_file = conf["networks"]["nodes"][0]["nodes_file"]
@@ -300,6 +445,9 @@ def raster(config=None,title=None,population=None,group_key='pop_name'):
300
445
  return
301
446
 
302
447
  def plot_spikes(nodes, spikes_file,save_file=None):
448
+ """
449
+ old function probs dep
450
+ """
303
451
  import h5py
304
452
 
305
453
  spikes_h5 = h5py.File(spikes_file, 'r')
@@ -349,12 +497,23 @@ def plot_spikes(nodes, spikes_file,save_file=None):
349
497
 
350
498
  return
351
499
 
352
- def plot_3d_positions(**kwargs):
353
- populations_list = kwargs["populations"]
354
- config = kwargs["config"]
355
- group_keys = kwargs["group_by"]
356
- title = kwargs["title"]
357
- save_file = kwargs["save_file"]
500
+ def plot_3d_positions(config=None,populations_list=None,group_by=None,title=None,save_file=None):
501
+ """
502
+ plots a 3D graph of all cells with x,y,z location
503
+ config: A BMTK simulation config
504
+ populations_list: Which network(s) to plot
505
+ group_by: How to name cell groups
506
+ title: plot title
507
+ save_file: If plot should be saved
508
+ """
509
+
510
+ if not config:
511
+ raise Exception("config not defined")
512
+ if populations_list == None:
513
+ populations_list = "all"
514
+ group_keys = group_by
515
+ if title == None:
516
+ title = "3D positions"
358
517
 
359
518
  nodes = util.load_nodes_from_config(config)
360
519
 
@@ -366,7 +525,7 @@ def plot_3d_positions(**kwargs):
366
525
  group_keys = group_keys.split(",")
367
526
  group_keys += (len(populations)-len(group_keys)) * ["node_type_id"] #Extend the array to default values if not enough given
368
527
  fig = plt.figure(figsize=(10,10))
369
- ax = Axes3D(fig)
528
+ ax = fig.add_subplot(projection='3d')
370
529
  handles = []
371
530
  for nodes_key,group_key in zip(list(nodes),group_keys):
372
531
  if 'all' not in populations and nodes_key not in populations:
@@ -401,115 +560,123 @@ def plot_3d_positions(**kwargs):
401
560
 
402
561
  if save_file:
403
562
  plt.savefig(save_file)
563
+ notebook = is_notebook
564
+ if notebook == False:
565
+ plt.show()
404
566
 
405
567
  return
406
568
 
569
+ def cell_rotation_3d(config=None, populations_list=None, group_by=None, title=None, save_file=None, quiver_length=None, arrow_length_ratio=None, group=None, max_cells=1000000):
570
+ from scipy.spatial.transform import Rotation as R
571
+ if not config:
572
+ raise Exception("config not defined")
407
573
 
408
- def cell_rotation_3d(**kwargs):
409
- populations_list = kwargs["populations"]
410
- config = kwargs["config"]
411
- group_keys = kwargs["group_by"]
412
- title = kwargs.get("title")
413
- save_file = kwargs["save_file"]
414
- quiver_length = kwargs["quiver_length"]
415
- arrow_length_ratio = kwargs["arrow_length_ratio"]
416
- group = kwargs["group"]
417
- max_cells = kwargs.get("max_cells",999999999)
418
- init_vector = kwargs.get("init_vector","1,0,0")
574
+ if populations_list is None:
575
+ populations_list = ["all"]
576
+
577
+ group_keys = group_by.split(",") if group_by else []
578
+
579
+ if title is None:
580
+ title = "Cell rotations"
419
581
 
420
582
  nodes = util.load_nodes_from_config(config)
421
583
 
422
584
  if 'all' in populations_list:
423
585
  populations = list(nodes)
424
586
  else:
425
- populations = populations_list.split(",")
587
+ populations = populations_list
426
588
 
427
- group_keys = group_keys.split(",")
428
- group_keys += (len(populations)-len(group_keys)) * ["node_type_id"] #Extend the array to default values if not enough given
429
- fig = plt.figure(figsize=(10,10))
430
- ax = Axes3D(fig)
589
+ fig = plt.figure(figsize=(10, 10))
590
+ ax = fig.add_subplot(111, projection='3d')
431
591
  handles = []
432
- for nodes_key,group_key in zip(list(nodes),group_keys):
592
+
593
+ for nodes_key, group_key in zip(list(nodes), group_keys):
433
594
  if 'all' not in populations and nodes_key not in populations:
434
595
  continue
435
-
596
+
436
597
  nodes_df = nodes[nodes_key]
437
598
 
438
599
  if group_key is not None:
439
- if group_key not in nodes_df:
440
- raise Exception('Could not find column {}'.format(group_key))
600
+ if group_key not in nodes_df.columns:
601
+ raise Exception(f'Could not find column {group_key}')
441
602
  groupings = nodes_df.groupby(group_key)
442
603
 
443
604
  n_colors = nodes_df[group_key].nunique()
444
- color_norm = colors.Normalize(vmin=0, vmax=(n_colors-1))
605
+ color_norm = colors.Normalize(vmin=0, vmax=(n_colors - 1))
445
606
  scalar_map = cmx.ScalarMappable(norm=color_norm, cmap='hsv')
446
- color_map = [scalar_map.to_rgba(i) for i in range(0, n_colors)]
607
+ color_map = [scalar_map.to_rgba(i) for i in range(n_colors)]
447
608
  else:
448
609
  groupings = [(None, nodes_df)]
449
610
  color_map = ['blue']
450
611
 
451
612
  cells_plotted = 0
452
613
  for color, (group_name, group_df) in zip(color_map, groupings):
453
- # if we selected a group and it's not in the list continue
454
614
  if group and group_name not in group.split(","):
455
615
  continue
456
616
 
457
- if "pos_x" not in group_df: #could also check model type == virtual
458
- continue #can't plot them if there isn't an xy coordinate (may be virtual)
617
+ if "pos_x" not in group_df or "rotation_angle_xaxis" not in group_df:
618
+ continue
459
619
 
460
- # if we exceed the max cells, stop plotting or limit
461
620
  if cells_plotted >= max_cells:
462
621
  continue
622
+
463
623
  if len(group_df) + cells_plotted > max_cells:
464
624
  total_remaining = max_cells - cells_plotted
465
625
  group_df = group_df[:total_remaining]
626
+
466
627
  cells_plotted += len(group_df)
467
628
 
468
629
  X = group_df["pos_x"]
469
630
  Y = group_df["pos_y"]
470
631
  Z = group_df["pos_z"]
471
- U = group_df.get("rotation_angle_xaxis")
472
- V = group_df.get("rotation_angle_yaxis")
473
- W = group_df.get("rotation_angle_zaxis")
632
+ U = group_df["rotation_angle_xaxis"].values
633
+ V = group_df["rotation_angle_yaxis"].values
634
+ W = group_df["rotation_angle_zaxis"].values
635
+
474
636
  if U is None:
475
637
  U = np.zeros(len(X))
476
638
  if V is None:
477
639
  V = np.zeros(len(Y))
478
640
  if W is None:
479
641
  W = np.zeros(len(Z))
480
-
481
- #Convert to arrow direction
482
- from scipy.spatial.transform import Rotation as R
483
- uvw = pd.DataFrame([U,V,W]).T
484
- init_vector = init_vector.split(',')
485
- init_vector = np.repeat([init_vector],len(X),axis=0)
486
-
487
- # To get the final cell orientation after rotation,
488
- # you need to use function Rotaion.apply(init_vec),
489
- # where init_vec is a vector of the initial orientation of a cell
490
- #rots = R.from_euler('xyz', uvw).apply(init_vector.astype(float))
491
- #rots = R.from_euler('xyz', pd.DataFrame([rots[:,0],rots[:,1],rots[:,2]]).T).as_rotvec().T
492
642
 
493
- rots = R.from_euler('zyx', uvw).apply(init_vector.astype(float)).T
494
- h = ax.quiver(X, Y, Z, rots[0],rots[1],rots[2],color=color,label=group_name, arrow_length_ratio = arrow_length_ratio, length=quiver_length)
643
+ # Create rotation matrices from Euler angles
644
+ rotations = R.from_euler('xyz', np.column_stack((U, V, W)), degrees=False)
495
645
 
496
- #h = ax.quiver(X, Y, Z, rots[0],rots[1],rots[2],color=color,label=group_name, arrow_length_ratio = arrow_length_ratio, length=quiver_length)
497
- ax.scatter(X,Y,Z,color=color,label=group_name)
646
+ # Define initial vectors
647
+ init_vectors = np.column_stack((np.ones(len(X)), np.zeros(len(Y)), np.zeros(len(Z))))
648
+
649
+ # Apply rotations to initial vectors
650
+ rots = np.dot(rotations.as_matrix(), init_vectors.T).T
651
+
652
+ # Extract x, y, and z components of the rotated vectors
653
+ rot_x = rots[:, 0]
654
+ rot_y = rots[:, 1]
655
+ rot_z = rots[:, 2]
656
+
657
+ h = ax.quiver(X, Y, Z, rot_x, rot_y, rot_z, color=color, label=group_name, arrow_length_ratio=arrow_length_ratio, length=quiver_length)
658
+ ax.scatter(X, Y, Z, color=color, label=group_name)
659
+ ax.set_xlim([min(X), max(X)])
660
+ ax.set_ylim([min(Y), max(Y)])
661
+ ax.set_zlim([min(Z), max(Z)])
498
662
  handles.append(h)
663
+
499
664
  if not handles:
500
665
  return
666
+
501
667
  plt.title(title)
502
668
  plt.legend(handles=handles)
503
-
504
669
  plt.draw()
505
670
 
506
671
  if save_file:
507
672
  plt.savefig(save_file)
508
-
509
- return
673
+ notebook = is_notebook
674
+ if notebook == False:
675
+ plt.show()
510
676
 
511
677
  def plot_network_graph(config=None,nodes=None,edges=None,title=None,sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False,save_file=None,edge_property='model_template'):
512
-
678
+ if not config:
679
+ raise Exception("config not defined")
513
680
  if not sources or not targets:
514
681
  raise Exception("Sources or targets not defined")
515
682
  sources = sources.split(",")
@@ -522,7 +689,7 @@ def plot_network_graph(config=None,nodes=None,edges=None,title=None,sources=None
522
689
  tids = tids.split(",")
523
690
  else:
524
691
  tids = []
525
- data, source_labels, target_labels = util.connection_graph_edge_types(nodes=None,edges=None,sources=sources,targets=targets,sids=sids,tids=tids,prepend_pop=not no_prepend_pop,edge_property=edge_property)
692
+ throw_away, data, source_labels, target_labels = util.connection_graph_edge_types(config=config,nodes=None,edges=None,sources=sources,targets=targets,sids=sids,tids=tids,prepend_pop=not no_prepend_pop,edge_property=edge_property)
526
693
 
527
694
  if title == None or title=="":
528
695
  title = "Network Graph"
@@ -638,7 +805,7 @@ def sim_setup(config_file='simulation_config.json',network=None):
638
805
  # Plot spike train info
639
806
  plot_inspikes(config_file)
640
807
  # Using bmtool, print total number of connections between cell groups
641
- connection_matrix(config=config_file,sources='all',targets='all',sids='pop_name',tids='pop_name',title='All Connections found', size_scalar=2, no_prepend_pop=True, synaptic_info='0')
808
+ total_connection_matrix(config=config_file,sources='all',targets='all',sids='pop_name',tids='pop_name',title='All Connections found', size_scalar=2, no_prepend_pop=True, synaptic_info='0')
642
809
  # Plot 3d positions of the network
643
810
  plot_3d_positions(populations='all',config=config_file,group_by='pop_name',title='3D Positions',save_file=None)
644
811
 
@@ -661,7 +828,7 @@ def plot_I_clamps(fp):
661
828
  plt.legend()
662
829
  num_clamps=num_clamps+1
663
830
 
664
- def plot_basic_cell_info(config_file,notebook=0):
831
+ def plot_basic_cell_info(config_file):
665
832
  print("Network and node info:")
666
833
  nodes=util.load_nodes_from_config(config_file)
667
834
  if not nodes:
@@ -695,7 +862,11 @@ def plot_basic_cell_info(config_file,notebook=0):
695
862
  count=1
696
863
  df1 = pd.DataFrame(CELLS, columns = ["node_type","pop_name","model_type","count"])
697
864
  print(j+':')
698
- print(df1)
865
+ notebook = is_notebook()
866
+ if notebook == True:
867
+ display(HTML(df1.to_html()))
868
+ else:
869
+ print(df1)
699
870
  elif node['model_type'][0]=='biophysical':
700
871
  CELLS=[]
701
872
  count=1
@@ -709,7 +880,7 @@ def plot_basic_cell_info(config_file,notebook=0):
709
880
  pop_name=node['pop_name'][i]
710
881
  model_type=node['model_type'][i]
711
882
  model_template=node['model_template'][i]
712
- morphology=node['morphology'][i] if node.get('morphology') else ''
883
+ morphology=node['morphology'][i] if node['morphology'][i] else ''
713
884
  CELLS.append([node_type,pop_name,model_type,model_template,morphology,count])
714
885
  count=1
715
886
  else:
@@ -717,17 +888,20 @@ def plot_basic_cell_info(config_file,notebook=0):
717
888
  pop_name=node['pop_name'][i]
718
889
  model_type=node['model_type'][i]
719
890
  model_template=node['model_template'][i]
720
- morphology=node['morphology'][i] if node.get('morphology') else ''
891
+ morphology=node['morphology'][i] if node['morphology'][i] else ''
721
892
  CELLS.append([node_type,pop_name,model_type,model_template,morphology,count])
722
893
  count=1
723
894
  df2 = pd.DataFrame(CELLS, columns = ["node_type","pop_name","model_type","model_template","morphology","count"])
724
895
  print(j+':')
725
896
  bio.append(j)
726
- print(df2)
897
+ notebook = is_notebook()
898
+ if notebook == True:
899
+ display(HTML(df2.to_html()))
900
+ else:
901
+ print(df2)
727
902
  if len(bio)>0:
728
903
  return bio[0]
729
904
 
730
-
731
905
  def plot_inspikes(fp):
732
906
 
733
907
  print("Plotting spike Train info...")