bmtool 0.7.0.6.2__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,40 +2,37 @@
2
2
  Want to be able to take multiple plot names in and plot them all at the same time, to save time
3
3
  https://stackoverflow.com/questions/458209/is-there-a-way-to-detach-matplotlib-plots-so-that-the-computation-can-continue
4
4
  """
5
- from ..util import util
6
- import numpy as np
5
+ import re
6
+ import statistics
7
+
7
8
  import matplotlib
8
- import matplotlib.pyplot as plt
9
9
  import matplotlib.cm as cmx
10
10
  import matplotlib.colors as colors
11
- from IPython import get_ipython
12
- from IPython.display import display, HTML
13
- import statistics
11
+ import matplotlib.pyplot as plt
12
+ import numpy as np
14
13
  import pandas as pd
15
- import os
16
- import sys
17
- import re
18
-
19
- from ..util.util import CellVarsFile,load_nodes_from_config,load_templates_from_config #, missing_units
20
- from bmtk.analyzer.utils import listify
14
+ from IPython import get_ipython
21
15
  from neuron import h
22
16
 
17
+ from ..util import util
18
+
23
19
  use_description = """
24
20
 
25
21
  Plot BMTK models easily.
26
22
 
27
- python -m bmtool.plot
23
+ python -m bmtool.plot
28
24
  """
29
25
 
26
+
30
27
  def is_notebook() -> bool:
31
28
  """
32
29
  Detect if code is running in a Jupyter notebook environment.
33
-
30
+
34
31
  Returns:
35
32
  --------
36
33
  bool
37
34
  True if running in a Jupyter notebook, False otherwise.
38
-
35
+
39
36
  Notes:
40
37
  ------
41
38
  This is used to determine whether to call plt.show() explicitly or
@@ -43,20 +40,31 @@ def is_notebook() -> bool:
43
40
  """
44
41
  try:
45
42
  shell = get_ipython().__class__.__name__
46
- if shell == 'ZMQInteractiveShell':
47
- return True # Jupyter notebook or qtconsole
48
- elif shell == 'TerminalInteractiveShell':
43
+ if shell == "ZMQInteractiveShell":
44
+ return True # Jupyter notebook or qtconsole
45
+ elif shell == "TerminalInteractiveShell":
49
46
  return False # Terminal running IPython
50
47
  else:
51
48
  return False # Other type (?)
52
49
  except NameError:
53
- return False # Probably standard Python interpreter
54
-
55
-
56
- def total_connection_matrix(config=None, title=None, sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False, save_file=None, synaptic_info='0', include_gap=True):
50
+ return False # Probably standard Python interpreter
51
+
52
+
53
+ def total_connection_matrix(
54
+ config=None,
55
+ title=None,
56
+ sources=None,
57
+ targets=None,
58
+ sids=None,
59
+ tids=None,
60
+ no_prepend_pop=False,
61
+ save_file=None,
62
+ synaptic_info="0",
63
+ include_gap=True,
64
+ ):
57
65
  """
58
66
  Generate a plot displaying total connections or other synaptic statistics.
59
-
67
+
60
68
  Parameters:
61
69
  -----------
62
70
  config : str
@@ -84,7 +92,7 @@ def total_connection_matrix(config=None, title=None, sources=None, targets=None,
84
92
  include_gap : bool, optional
85
93
  If True, include gap junctions and chemical synapses in the analysis.
86
94
  If False, only include chemical synapses.
87
-
95
+
88
96
  Returns:
89
97
  --------
90
98
  None
@@ -104,27 +112,53 @@ def total_connection_matrix(config=None, title=None, sources=None, targets=None,
104
112
  tids = tids.split(",")
105
113
  else:
106
114
  tids = []
107
- text,num, source_labels, target_labels = util.connection_totals(config=config,nodes=None,edges=None,sources=sources,targets=targets,sids=sids,tids=tids,prepend_pop=not no_prepend_pop,synaptic_info=synaptic_info,include_gap=include_gap)
115
+ text, num, source_labels, target_labels = util.connection_totals(
116
+ config=config,
117
+ nodes=None,
118
+ edges=None,
119
+ sources=sources,
120
+ targets=targets,
121
+ sids=sids,
122
+ tids=tids,
123
+ prepend_pop=not no_prepend_pop,
124
+ synaptic_info=synaptic_info,
125
+ include_gap=include_gap,
126
+ )
108
127
 
109
- if title == None or title=="":
128
+ if title is None or title == "":
110
129
  title = "Total Connections"
111
- if synaptic_info=='1':
112
- title = "Mean and Stdev # of Conn on Target"
113
- if synaptic_info=='2':
130
+ if synaptic_info == "1":
131
+ title = "Mean and Stdev # of Conn on Target"
132
+ if synaptic_info == "2":
114
133
  title = "All Synapse .mod Files Used"
115
- if synaptic_info=='3':
134
+ if synaptic_info == "3":
116
135
  title = "All Synapse .json Files Used"
117
- plot_connection_info(text,num,source_labels,target_labels,title, syn_info=synaptic_info, save_file=save_file)
136
+ plot_connection_info(
137
+ text, num, source_labels, target_labels, title, syn_info=synaptic_info, save_file=save_file
138
+ )
118
139
  return
119
140
 
120
141
 
121
- def percent_connection_matrix(config=None,nodes=None,edges=None,title=None,sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False,save_file=None,method = 'total',include_gap=True):
142
+ def percent_connection_matrix(
143
+ config=None,
144
+ nodes=None,
145
+ edges=None,
146
+ title=None,
147
+ sources=None,
148
+ targets=None,
149
+ sids=None,
150
+ tids=None,
151
+ no_prepend_pop=False,
152
+ save_file=None,
153
+ method="total",
154
+ include_gap=True,
155
+ ):
122
156
  """
123
157
  Generates a plot showing the percent connectivity of a network
124
- config: A BMTK simulation config
158
+ config: A BMTK simulation config
125
159
  sources: network name(s) to plot
126
160
  targets: network name(s) to plot
127
- sids: source node identifier
161
+ sids: source node identifier
128
162
  tids: target node identifier
129
163
  no_prepend_pop: dictates if population name is displayed before sid or tid when displaying graph
130
164
  method: what percent to displace on the graph 'total','uni',or 'bi' for total connections, unidirectional connections or bidirectional connections
@@ -135,7 +169,7 @@ def percent_connection_matrix(config=None,nodes=None,edges=None,title=None,sourc
135
169
  raise Exception("config not defined")
136
170
  if not sources or not targets:
137
171
  raise Exception("Sources or targets not defined")
138
-
172
+
139
173
  sources = sources.split(",")
140
174
  targets = targets.split(",")
141
175
  if sids:
@@ -146,17 +180,44 @@ def percent_connection_matrix(config=None,nodes=None,edges=None,title=None,sourc
146
180
  tids = tids.split(",")
147
181
  else:
148
182
  tids = []
149
- text,num, source_labels, target_labels = util.percent_connections(config=config,nodes=None,edges=None,sources=sources,targets=targets,sids=sids,tids=tids,prepend_pop=not no_prepend_pop,method=method,include_gap=include_gap)
150
- if title == None or title=="":
183
+ text, num, source_labels, target_labels = util.percent_connections(
184
+ config=config,
185
+ nodes=None,
186
+ edges=None,
187
+ sources=sources,
188
+ targets=targets,
189
+ sids=sids,
190
+ tids=tids,
191
+ prepend_pop=not no_prepend_pop,
192
+ method=method,
193
+ include_gap=include_gap,
194
+ )
195
+ if title is None or title == "":
151
196
  title = "Percent Connectivity"
152
197
 
153
-
154
- plot_connection_info(text,num,source_labels,target_labels,title, save_file=save_file)
198
+ plot_connection_info(text, num, source_labels, target_labels, title, save_file=save_file)
155
199
  return
156
200
 
157
201
 
158
- def probability_connection_matrix(config=None,nodes=None,edges=None,title=None,sources=None, targets=None, sids=None, tids=None,
159
- no_prepend_pop=False,save_file=None, dist_X=True,dist_Y=True,dist_Z=True,bins=8,line_plot=False,verbose=False,include_gap=True):
202
+ def probability_connection_matrix(
203
+ config=None,
204
+ nodes=None,
205
+ edges=None,
206
+ title=None,
207
+ sources=None,
208
+ targets=None,
209
+ sids=None,
210
+ tids=None,
211
+ no_prepend_pop=False,
212
+ save_file=None,
213
+ dist_X=True,
214
+ dist_Y=True,
215
+ dist_Z=True,
216
+ bins=8,
217
+ line_plot=False,
218
+ verbose=False,
219
+ include_gap=True,
220
+ ):
160
221
  """
161
222
  Generates probability graphs
162
223
  need to look into this more to see what it does
@@ -179,41 +240,53 @@ def probability_connection_matrix(config=None,nodes=None,edges=None,title=None,s
179
240
  else:
180
241
  tids = []
181
242
 
182
- throwaway, data, source_labels, target_labels = util.connection_probabilities(config=config,nodes=None,
183
- edges=None,sources=sources,targets=targets,sids=sids,tids=tids,
184
- prepend_pop=not no_prepend_pop,dist_X=dist_X,dist_Y=dist_Y,dist_Z=dist_Z,num_bins=bins,include_gap=include_gap)
243
+ throwaway, data, source_labels, target_labels = util.connection_probabilities(
244
+ config=config,
245
+ nodes=None,
246
+ edges=None,
247
+ sources=sources,
248
+ targets=targets,
249
+ sids=sids,
250
+ tids=tids,
251
+ prepend_pop=not no_prepend_pop,
252
+ dist_X=dist_X,
253
+ dist_Y=dist_Y,
254
+ dist_Z=dist_Z,
255
+ num_bins=bins,
256
+ include_gap=include_gap,
257
+ )
185
258
  if not data.any():
186
259
  return
187
- if data[0][0]==-1:
260
+ if data[0][0] == -1:
188
261
  return
189
- #plot_connection_info(data,source_labels,target_labels,title, save_file=save_file)
262
+ # plot_connection_info(data,source_labels,target_labels,title, save_file=save_file)
190
263
 
191
- #plt.clf()# clears previous plots
192
- np.seterr(divide='ignore', invalid='ignore')
264
+ # plt.clf()# clears previous plots
265
+ np.seterr(divide="ignore", invalid="ignore")
193
266
  num_src, num_tar = data.shape
194
- fig, axes = plt.subplots(nrows=num_src, ncols=num_tar, figsize=(12,12))
267
+ fig, axes = plt.subplots(nrows=num_src, ncols=num_tar, figsize=(12, 12))
195
268
  fig.subplots_adjust(hspace=0.5, wspace=0.5)
196
269
 
197
270
  for x in range(num_src):
198
271
  for y in range(num_tar):
199
272
  ns = data[x][y]["ns"]
200
273
  bins = data[x][y]["bins"]
201
-
274
+
202
275
  XX = bins[:-1]
203
- YY = ns[0]/ns[1]
276
+ YY = ns[0] / ns[1]
204
277
 
205
278
  if line_plot:
206
- axes[x,y].plot(XX,YY)
279
+ axes[x, y].plot(XX, YY)
207
280
  else:
208
- axes[x,y].bar(XX,YY)
281
+ axes[x, y].bar(XX, YY)
209
282
 
210
- if x == num_src-1:
211
- axes[x,y].set_xlabel(target_labels[y])
283
+ if x == num_src - 1:
284
+ axes[x, y].set_xlabel(target_labels[y])
212
285
  if y == 0:
213
- axes[x,y].set_ylabel(source_labels[x])
286
+ axes[x, y].set_ylabel(source_labels[x])
214
287
 
215
288
  if verbose:
216
- print("Source: [" + source_labels[x] + "] | Target: ["+ target_labels[y] +"]")
289
+ print("Source: [" + source_labels[x] + "] | Target: [" + target_labels[y] + "]")
217
290
  print("X:")
218
291
  print(XX)
219
292
  print("Y:")
@@ -223,41 +296,80 @@ def probability_connection_matrix(config=None,nodes=None,edges=None,title=None,s
223
296
  if title:
224
297
  tt = title
225
298
  st = fig.suptitle(tt, fontsize=14)
226
- fig.text(0.5, 0.04, 'Target', ha='center')
227
- fig.text(0.04, 0.5, 'Source', va='center', rotation='vertical')
299
+ fig.text(0.5, 0.04, "Target", ha="center")
300
+ fig.text(0.04, 0.5, "Source", va="center", rotation="vertical")
228
301
  notebook = is_notebook
229
- if notebook == False:
302
+ if not notebook:
230
303
  fig.show()
231
304
 
232
305
  return
233
306
 
234
307
 
235
- def convergence_connection_matrix(config=None,title=None,sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False,save_file=None,convergence=True,method='mean+std',include_gap=True,return_dict=None):
308
+ def convergence_connection_matrix(
309
+ config=None,
310
+ title=None,
311
+ sources=None,
312
+ targets=None,
313
+ sids=None,
314
+ tids=None,
315
+ no_prepend_pop=False,
316
+ save_file=None,
317
+ convergence=True,
318
+ method="mean+std",
319
+ include_gap=True,
320
+ return_dict=None,
321
+ ):
236
322
  """
237
323
  Generates connection plot displaying convergence data
238
- config: A BMTK simulation config
324
+ config: A BMTK simulation config
239
325
  sources: network name(s) to plot
240
326
  targets: network name(s) to plot
241
- sids: source node identifier
327
+ sids: source node identifier
242
328
  tids: target node identifier
243
329
  no_prepend_pop: dictates if population name is displayed before sid or tid when displaying graph
244
330
  save_file: If plot should be saved
245
- method: 'mean','min','max','stdev' or 'mean+std' connvergence plot
331
+ method: 'mean','min','max','stdev' or 'mean+std' connvergence plot
246
332
  """
247
333
  if not config:
248
334
  raise Exception("config not defined")
249
335
  if not sources or not targets:
250
336
  raise Exception("Sources or targets not defined")
251
- return divergence_connection_matrix(config,title ,sources, targets, sids, tids, no_prepend_pop, save_file ,convergence, method,include_gap=include_gap,return_dict=return_dict)
337
+ return divergence_connection_matrix(
338
+ config,
339
+ title,
340
+ sources,
341
+ targets,
342
+ sids,
343
+ tids,
344
+ no_prepend_pop,
345
+ save_file,
346
+ convergence,
347
+ method,
348
+ include_gap=include_gap,
349
+ return_dict=return_dict,
350
+ )
252
351
 
253
352
 
254
- def divergence_connection_matrix(config=None,title=None,sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False,save_file=None,convergence=False,method='mean+std',include_gap=True,return_dict=None):
353
+ def divergence_connection_matrix(
354
+ config=None,
355
+ title=None,
356
+ sources=None,
357
+ targets=None,
358
+ sids=None,
359
+ tids=None,
360
+ no_prepend_pop=False,
361
+ save_file=None,
362
+ convergence=False,
363
+ method="mean+std",
364
+ include_gap=True,
365
+ return_dict=None,
366
+ ):
255
367
  """
256
368
  Generates connection plot displaying divergence data
257
- config: A BMTK simulation config
369
+ config: A BMTK simulation config
258
370
  sources: network name(s) to plot
259
371
  targets: network name(s) to plot
260
- sids: source node identifier
372
+ sids: source node identifier
261
373
  tids: target node identifier
262
374
  no_prepend_pop: dictates if population name is displayed before sid or tid when displaying graph
263
375
  save_file: If plot should be saved
@@ -278,43 +390,73 @@ def divergence_connection_matrix(config=None,title=None,sources=None, targets=No
278
390
  else:
279
391
  tids = []
280
392
 
281
- syn_info, data, source_labels, target_labels = util.connection_divergence(config=config,nodes=None,edges=None,sources=sources,targets=targets,sids=sids,tids=tids,prepend_pop=not no_prepend_pop,convergence=convergence,method=method,include_gap=include_gap)
282
-
283
-
284
- #data, labels = util.connection_divergence_average(config=config,nodes=nodes,edges=edges,populations=populations)
393
+ syn_info, data, source_labels, target_labels = util.connection_divergence(
394
+ config=config,
395
+ nodes=None,
396
+ edges=None,
397
+ sources=sources,
398
+ targets=targets,
399
+ sids=sids,
400
+ tids=tids,
401
+ prepend_pop=not no_prepend_pop,
402
+ convergence=convergence,
403
+ method=method,
404
+ include_gap=include_gap,
405
+ )
285
406
 
286
- if title == None or title=="":
407
+ # data, labels = util.connection_divergence_average(config=config,nodes=nodes,edges=edges,populations=populations)
287
408
 
288
- if method == 'min':
409
+ if title is None or title == "":
410
+ if method == "min":
289
411
  title = "Minimum "
290
- elif method == 'max':
412
+ elif method == "max":
291
413
  title = "Maximum "
292
- elif method == 'std':
414
+ elif method == "std":
293
415
  title = "Standard Deviation "
294
- elif method == 'mean':
416
+ elif method == "mean":
295
417
  title = "Mean "
296
- else:
297
- title = 'Mean + Std '
418
+ else:
419
+ title = "Mean + Std "
298
420
 
299
421
  if convergence:
300
422
  title = title + "Synaptic Convergence"
301
423
  else:
302
424
  title = title + "Synaptic Divergence"
303
425
  if return_dict:
304
- dict = plot_connection_info(syn_info,data,source_labels,target_labels,title, save_file=save_file,return_dict=return_dict)
426
+ dict = plot_connection_info(
427
+ syn_info,
428
+ data,
429
+ source_labels,
430
+ target_labels,
431
+ title,
432
+ save_file=save_file,
433
+ return_dict=return_dict,
434
+ )
305
435
  return dict
306
436
  else:
307
- plot_connection_info(syn_info,data,source_labels,target_labels,title, save_file=save_file)
437
+ plot_connection_info(
438
+ syn_info, data, source_labels, target_labels, title, save_file=save_file
439
+ )
308
440
  return
309
441
 
310
442
 
311
- def gap_junction_matrix(config=None,title=None,sources=None, targets=None, sids=None,tids=None, no_prepend_pop=False,save_file=None,method='convergence'):
443
+ def gap_junction_matrix(
444
+ config=None,
445
+ title=None,
446
+ sources=None,
447
+ targets=None,
448
+ sids=None,
449
+ tids=None,
450
+ no_prepend_pop=False,
451
+ save_file=None,
452
+ method="convergence",
453
+ ):
312
454
  """
313
455
  Generates connection plot displaying gap junction data.
314
- config: A BMTK simulation config
456
+ config: A BMTK simulation config
315
457
  sources: network name(s) to plot
316
458
  targets: network name(s) to plot
317
- sids: source node identifier
459
+ sids: source node identifier
318
460
  tids: target node identifier
319
461
  no_prepend_pop: dictates if population name is displayed before sid or tid when displaying graph
320
462
  save_file: If plot should be saved
@@ -324,7 +466,7 @@ def gap_junction_matrix(config=None,title=None,sources=None, targets=None, sids=
324
466
  raise Exception("config not defined")
325
467
  if not sources or not targets:
326
468
  raise Exception("Sources or targets not defined")
327
- if method !='convergence' and method!='percent':
469
+ if method != "convergence" and method != "percent":
328
470
  raise Exception("type must be 'convergence' or 'percent'")
329
471
  sources = sources.split(",")
330
472
  targets = targets.split(",")
@@ -336,16 +478,25 @@ def gap_junction_matrix(config=None,title=None,sources=None, targets=None, sids=
336
478
  tids = tids.split(",")
337
479
  else:
338
480
  tids = []
339
- syn_info, data, source_labels, target_labels = util.gap_junction_connections(config=config,nodes=None,edges=None,sources=sources,targets=targets,sids=sids,tids=tids,prepend_pop=not no_prepend_pop,method=method)
340
-
341
-
481
+ syn_info, data, source_labels, target_labels = util.gap_junction_connections(
482
+ config=config,
483
+ nodes=None,
484
+ edges=None,
485
+ sources=sources,
486
+ targets=targets,
487
+ sids=sids,
488
+ tids=tids,
489
+ prepend_pop=not no_prepend_pop,
490
+ method=method,
491
+ )
492
+
342
493
  def filter_rows(syn_info, data, source_labels, target_labels):
343
494
  """
344
495
  Filters out rows in a connectivity matrix that contain only NaN or zero values.
345
-
496
+
346
497
  This function is used to clean up connection matrices by removing rows that have no meaningful data,
347
498
  which helps create more informative visualizations of network connectivity.
348
-
499
+
349
500
  Parameters:
350
501
  -----------
351
502
  syn_info : numpy.ndarray
@@ -356,7 +507,7 @@ def gap_junction_matrix(config=None,title=None,sources=None, targets=None, sids=
356
507
  List of labels for the source populations corresponding to rows in the data matrix.
357
508
  target_labels : list
358
509
  List of labels for the target populations corresponding to columns in the data matrix.
359
-
510
+
360
511
  Returns:
361
512
  --------
362
513
  tuple
@@ -375,11 +526,11 @@ def gap_junction_matrix(config=None,title=None,sources=None, targets=None, sids=
375
526
  def filter_rows_and_columns(syn_info, data, source_labels, target_labels):
376
527
  """
377
528
  Filters out both rows and columns in a connectivity matrix that contain only NaN or zero values.
378
-
529
+
379
530
  This function performs a two-step filtering process: first removing rows with no data,
380
531
  then transposing the matrix and removing columns with no data (by treating them as rows).
381
532
  This creates a cleaner, more informative connectivity matrix visualization.
382
-
533
+
383
534
  Parameters:
384
535
  -----------
385
536
  syn_info : numpy.ndarray
@@ -390,7 +541,7 @@ def gap_junction_matrix(config=None,title=None,sources=None, targets=None, sids=
390
541
  List of labels for the source populations corresponding to rows in the data matrix.
391
542
  target_labels : list
392
543
  List of labels for the target populations corresponding to columns in the data matrix.
393
-
544
+
394
545
  Returns:
395
546
  --------
396
547
  tuple
@@ -398,7 +549,9 @@ def gap_junction_matrix(config=None,title=None,sources=None, targets=None, sids=
398
549
  invalid rows and columns removed.
399
550
  """
400
551
  # Filter rows first
401
- syn_info, data, source_labels, target_labels = filter_rows(syn_info, data, source_labels, target_labels)
552
+ syn_info, data, source_labels, target_labels = filter_rows(
553
+ syn_info, data, source_labels, target_labels
554
+ )
402
555
 
403
556
  # Transpose data to filter columns
404
557
  transposed_syn_info = np.transpose(syn_info)
@@ -407,7 +560,12 @@ def gap_junction_matrix(config=None,title=None,sources=None, targets=None, sids=
407
560
  transposed_target_labels = source_labels
408
561
 
409
562
  # Filter columns (by treating them as rows in transposed data)
410
- transposed_syn_info, transposed_data, transposed_source_labels, transposed_target_labels = filter_rows(
563
+ (
564
+ transposed_syn_info,
565
+ transposed_data,
566
+ transposed_source_labels,
567
+ transposed_target_labels,
568
+ ) = filter_rows(
411
569
  transposed_syn_info, transposed_data, transposed_source_labels, transposed_target_labels
412
570
  )
413
571
 
@@ -419,40 +577,54 @@ def gap_junction_matrix(config=None,title=None,sources=None, targets=None, sids=
419
577
 
420
578
  return filtered_syn_info, filtered_data, filtered_source_labels, filtered_target_labels
421
579
 
422
-
423
- syn_info, data, source_labels, target_labels = filter_rows_and_columns(syn_info, data, source_labels, target_labels)
580
+ syn_info, data, source_labels, target_labels = filter_rows_and_columns(
581
+ syn_info, data, source_labels, target_labels
582
+ )
424
583
 
425
- if title == None or title=="":
426
- title = 'Gap Junction'
427
- if method == 'convergence':
428
- title+=' Syn Convergence'
429
- elif method == 'percent':
430
- title+=' Percent Connectivity'
431
- plot_connection_info(syn_info,data,source_labels,target_labels,title, save_file=save_file)
584
+ if title is None or title == "":
585
+ title = "Gap Junction"
586
+ if method == "convergence":
587
+ title += " Syn Convergence"
588
+ elif method == "percent":
589
+ title += " Percent Connectivity"
590
+ plot_connection_info(syn_info, data, source_labels, target_labels, title, save_file=save_file)
432
591
  return
433
592
 
434
593
 
435
- def connection_histogram(config=None,nodes=None,edges=None,sources=[],targets=[],sids=[],tids=[],no_prepend_pop=True,synaptic_info='0',
436
- source_cell = None,target_cell = None,include_gap=True):
594
+ def connection_histogram(
595
+ config=None,
596
+ nodes=None,
597
+ edges=None,
598
+ sources=[],
599
+ targets=[],
600
+ sids=[],
601
+ tids=[],
602
+ no_prepend_pop=True,
603
+ synaptic_info="0",
604
+ source_cell=None,
605
+ target_cell=None,
606
+ include_gap=True,
607
+ ):
437
608
  """
438
609
  Generates histogram of number of connections individual cells in a population receieve from another population
439
- config: A BMTK simulation config
610
+ config: A BMTK simulation config
440
611
  sources: network name(s) to plot
441
612
  targets: network name(s) to plot
442
- sids: source node identifier
613
+ sids: source node identifier
443
614
  tids: target node identifier
444
615
  no_prepend_pop: dictates if population name is displayed before sid or tid when displaying graph
445
616
  source_cell: where connections are coming from
446
617
  target_cell: where connections on coming onto
447
618
  save_file: If plot should be saved
448
619
  """
620
+
449
621
  def connection_pair_histogram(**kwargs):
450
622
  """
451
623
  Creates a histogram showing the distribution of connection counts between a specific source and target cell type.
452
-
624
+
453
625
  This function is designed to be used with the relation_matrix utility and will only create histograms
454
626
  for the specified source and target cell types, ignoring all other combinations.
455
-
627
+
456
628
  Parameters:
457
629
  -----------
458
630
  kwargs : dict
@@ -462,7 +634,7 @@ def connection_histogram(config=None,nodes=None,edges=None,sources=[],targets=[]
462
634
  - tid: Column name for target ID type in the edges DataFrame
463
635
  - source_id: Value to filter edges by source ID type
464
636
  - target_id: Value to filter edges by target ID type
465
-
637
+
466
638
  Global parameters used:
467
639
  ---------------------
468
640
  source_cell : str
@@ -471,37 +643,41 @@ def connection_histogram(config=None,nodes=None,edges=None,sources=[],targets=[]
471
643
  The target cell type to plot.
472
644
  include_gap : bool
473
645
  Whether to include gap junctions in the analysis. If False, gap junctions are excluded.
474
-
646
+
475
647
  Returns:
476
648
  --------
477
649
  None
478
650
  Displays a histogram showing the distribution of connection counts.
479
651
  """
480
- edges = kwargs["edges"]
652
+ edges = kwargs["edges"]
481
653
  source_id_type = kwargs["sid"]
482
654
  target_id_type = kwargs["tid"]
483
655
  source_id = kwargs["source_id"]
484
656
  target_id = kwargs["target_id"]
485
657
  if source_id == source_cell and target_id == target_cell:
486
- temp = edges[(edges[source_id_type] == source_id) & (edges[target_id_type]==target_id)]
487
- if include_gap == False:
488
- temp = temp[temp['is_gap_junction'] != True]
489
- node_pairs = temp.groupby('target_node_id')['source_node_id'].count()
658
+ temp = edges[
659
+ (edges[source_id_type] == source_id) & (edges[target_id_type] == target_id)
660
+ ]
661
+ if not include_gap:
662
+ temp = temp[~temp["is_gap_junction"]]
663
+ node_pairs = temp.groupby("target_node_id")["source_node_id"].count()
490
664
  try:
491
665
  conn_mean = statistics.mean(node_pairs.values)
492
666
  conn_std = statistics.stdev(node_pairs.values)
493
667
  conn_median = statistics.median(node_pairs.values)
494
- label = "mean {:.2f} std {:.2f} median {:.2f}".format(conn_mean,conn_std,conn_median)
495
- except: # lazy fix for std not calculated with 1 node
668
+ label = "mean {:.2f} std {:.2f} median {:.2f}".format(
669
+ conn_mean, conn_std, conn_median
670
+ )
671
+ except: # lazy fix for std not calculated with 1 node
496
672
  conn_mean = statistics.mean(node_pairs.values)
497
673
  conn_median = statistics.median(node_pairs.values)
498
- label = "mean {:.2f} median {:.2f}".format(conn_mean,conn_median)
499
- plt.hist(node_pairs.values,density=False,bins='auto',stacked=True,label=label)
674
+ label = "mean {:.2f} median {:.2f}".format(conn_mean, conn_median)
675
+ plt.hist(node_pairs.values, density=False, bins="auto", stacked=True, label=label)
500
676
  plt.legend()
501
- plt.xlabel("# of conns from {} to {}".format(source_cell,target_cell))
677
+ plt.xlabel("# of conns from {} to {}".format(source_cell, target_cell))
502
678
  plt.ylabel("# of cells")
503
679
  plt.show()
504
- else: # dont care about other cell pairs so pass
680
+ else: # dont care about other cell pairs so pass
505
681
  pass
506
682
 
507
683
  if not config:
@@ -518,18 +694,35 @@ def connection_histogram(config=None,nodes=None,edges=None,sources=[],targets=[]
518
694
  tids = tids.split(",")
519
695
  else:
520
696
  tids = []
521
- util.relation_matrix(config,nodes,edges,sources,targets,sids,tids,not no_prepend_pop,relation_func=connection_pair_histogram,synaptic_info=synaptic_info)
697
+ util.relation_matrix(
698
+ config,
699
+ nodes,
700
+ edges,
701
+ sources,
702
+ targets,
703
+ sids,
704
+ tids,
705
+ not no_prepend_pop,
706
+ relation_func=connection_pair_histogram,
707
+ synaptic_info=synaptic_info,
708
+ )
522
709
 
523
710
 
524
- def connection_distance(config: str,sources: str,targets: str,
525
- source_cell_id: int,target_id_type: str,ignore_z:bool=False) -> None:
711
+ def connection_distance(
712
+ config: str,
713
+ sources: str,
714
+ targets: str,
715
+ source_cell_id: int,
716
+ target_id_type: str,
717
+ ignore_z: bool = False,
718
+ ) -> None:
526
719
  """
527
720
  Plots the 3D spatial distribution of target nodes relative to a source node
528
721
  and a histogram of distances from the source node to each target node.
529
722
 
530
723
  Parameters:
531
724
  ----------
532
- config: (str) A BMTK simulation config
725
+ config: (str) A BMTK simulation config
533
726
  sources: (str) network name(s) to plot
534
727
  targets: (str) network name(s) to plot
535
728
  source_cell_id : (int) ID of the source cell for calculating distances to target nodes.
@@ -541,22 +734,22 @@ def connection_distance(config: str,sources: str,targets: str,
541
734
  raise Exception("config not defined")
542
735
  if not sources or not targets:
543
736
  raise Exception("Sources or targets not defined")
544
- #if source != target:
545
- #raise Exception("Code is setup for source and target to be the same! Look at source code for function to add feature")
546
-
737
+ # if source != target:
738
+ # raise Exception("Code is setup for source and target to be the same! Look at source code for function to add feature")
739
+
547
740
  # Load nodes and edges based on config file
548
741
  nodes, edges = util.load_nodes_edges_from_config(config)
549
-
742
+
550
743
  edge_network = sources + "_to_" + targets
551
744
  node_network = sources
552
745
 
553
746
  # Filter edges to obtain connections originating from the source node
554
747
  edge = edges[edge_network]
555
- edge = edge[edge['source_node_id'] == source_cell_id]
748
+ edge = edge[edge["source_node_id"] == source_cell_id]
556
749
  if target_id_type:
557
- edge = edge[edge['target_query'].str.contains(target_id_type, na=False)]
750
+ edge = edge[edge["target_query"].str.contains(target_id_type, na=False)]
558
751
 
559
- target_node_ids = edge['target_node_id']
752
+ target_node_ids = edge["target_node_id"]
560
753
 
561
754
  # Filter nodes to obtain only the target and source nodes
562
755
  node = nodes[node_network]
@@ -565,24 +758,40 @@ def connection_distance(config: str,sources: str,targets: str,
565
758
 
566
759
  # Calculate distances between source node and each target node
567
760
  if ignore_z:
568
- target_positions = target_nodes[['pos_x', 'pos_y']].values
569
- source_position = np.array([source_node['pos_x'], source_node['pos_y']]).ravel() # Ensure 1D shape
761
+ target_positions = target_nodes[["pos_x", "pos_y"]].values
762
+ source_position = np.array(
763
+ [source_node["pos_x"], source_node["pos_y"]]
764
+ ).ravel() # Ensure 1D shape
570
765
  else:
571
- target_positions = target_nodes[['pos_x', 'pos_y', 'pos_z']].values
572
- source_position = np.array([source_node['pos_x'], source_node['pos_y'], source_node['pos_z']]).ravel() # Ensure 1D shape
766
+ target_positions = target_nodes[["pos_x", "pos_y", "pos_z"]].values
767
+ source_position = np.array(
768
+ [source_node["pos_x"], source_node["pos_y"], source_node["pos_z"]]
769
+ ).ravel() # Ensure 1D shape
573
770
  distances = np.linalg.norm(target_positions - source_position, axis=1)
574
771
 
575
772
  # Plot positions of source and target nodes in 3D space or 2D
576
773
  if ignore_z:
577
- fig = plt.figure(figsize=(8, 6))
774
+ fig = plt.figure(figsize=(8, 6))
578
775
  ax = fig.add_subplot(111)
579
- ax.scatter(target_nodes['pos_x'], target_nodes['pos_y'], c='blue', label="target cells")
580
- ax.scatter(source_node['pos_x'], source_node['pos_y'], c='red', label="source cell")
776
+ ax.scatter(target_nodes["pos_x"], target_nodes["pos_y"], c="blue", label="target cells")
777
+ ax.scatter(source_node["pos_x"], source_node["pos_y"], c="red", label="source cell")
581
778
  else:
582
- fig = plt.figure(figsize=(8, 6))
583
- ax = fig.add_subplot(111, projection='3d')
584
- ax.scatter(target_nodes['pos_x'], target_nodes['pos_y'], target_nodes['pos_z'], c='blue', label="target cells")
585
- ax.scatter(source_node['pos_x'], source_node['pos_y'], source_node['pos_z'], c='red', label="source cell")
779
+ fig = plt.figure(figsize=(8, 6))
780
+ ax = fig.add_subplot(111, projection="3d")
781
+ ax.scatter(
782
+ target_nodes["pos_x"],
783
+ target_nodes["pos_y"],
784
+ target_nodes["pos_z"],
785
+ c="blue",
786
+ label="target cells",
787
+ )
788
+ ax.scatter(
789
+ source_node["pos_x"],
790
+ source_node["pos_y"],
791
+ source_node["pos_z"],
792
+ c="red",
793
+ label="source cell",
794
+ )
586
795
 
587
796
  # Optional: Add text annotations for distances
588
797
  # for i, distance in enumerate(distances):
@@ -593,23 +802,36 @@ def connection_distance(config: str,sources: str,targets: str,
593
802
  plt.show()
594
803
 
595
804
  # Plot distances in a separate 2D plot
596
- plt.figure(figsize=(8, 6))
597
- plt.hist(distances, bins=20, color='blue', edgecolor='black')
805
+ plt.figure(figsize=(8, 6))
806
+ plt.hist(distances, bins=20, color="blue", edgecolor="black")
598
807
  plt.xlabel("Distance")
599
808
  plt.ylabel("Count")
600
- plt.title(f"Distance from Source Node to Each Target Node")
809
+ plt.title("Distance from Source Node to Each Target Node")
601
810
  plt.grid(True)
602
811
  plt.show()
603
812
 
604
813
 
605
- def edge_histogram_matrix(config=None,sources = None,targets=None,sids=None,tids=None,no_prepend_pop=None,edge_property = None,time = None,time_compare = None,report=None,title=None,save_file=None):
814
+ def edge_histogram_matrix(
815
+ config=None,
816
+ sources=None,
817
+ targets=None,
818
+ sids=None,
819
+ tids=None,
820
+ no_prepend_pop=None,
821
+ edge_property=None,
822
+ time=None,
823
+ time_compare=None,
824
+ report=None,
825
+ title=None,
826
+ save_file=None,
827
+ ):
606
828
  """
607
829
  Generates a matrix of histograms showing the distribution of edge properties between different populations.
608
-
609
- This function creates a grid of histograms where each cell in the grid represents the distribution of a
610
- specific edge property (e.g., synaptic weights, delays) between a source population (row) and
830
+
831
+ This function creates a grid of histograms where each cell in the grid represents the distribution of a
832
+ specific edge property (e.g., synaptic weights, delays) between a source population (row) and
611
833
  target population (column).
612
-
834
+
613
835
  Parameters:
614
836
  -----------
615
837
  config : str
@@ -636,13 +858,13 @@ def edge_histogram_matrix(config=None,sources = None,targets=None,sids=None,tids
636
858
  Custom title for the plot.
637
859
  save_file : str, optional
638
860
  Path to save the generated plot.
639
-
861
+
640
862
  Returns:
641
863
  --------
642
864
  None
643
865
  Displays a matrix of histograms.
644
866
  """
645
-
867
+
646
868
  if not config:
647
869
  raise Exception("config not defined")
648
870
  if not sources or not targets:
@@ -660,34 +882,48 @@ def edge_histogram_matrix(config=None,sources = None,targets=None,sids=None,tids
660
882
  if time_compare:
661
883
  time_compare = int(time_compare)
662
884
 
663
- data, source_labels, target_labels = util.edge_property_matrix(edge_property,nodes=None,edges=None,config=config,sources=sources,targets=targets,sids=sids,tids=tids,prepend_pop=not no_prepend_pop,report=report,time=time,time_compare=time_compare)
885
+ data, source_labels, target_labels = util.edge_property_matrix(
886
+ edge_property,
887
+ nodes=None,
888
+ edges=None,
889
+ config=config,
890
+ sources=sources,
891
+ targets=targets,
892
+ sids=sids,
893
+ tids=tids,
894
+ prepend_pop=not no_prepend_pop,
895
+ report=report,
896
+ time=time,
897
+ time_compare=time_compare,
898
+ )
664
899
 
665
900
  # Fantastic resource
666
- # https://stackoverflow.com/questions/7941207/is-there-a-function-to-make-scatterplot-matrices-in-matplotlib
901
+ # https://stackoverflow.com/questions/7941207/is-there-a-function-to-make-scatterplot-matrices-in-matplotlib
667
902
  num_src, num_tar = data.shape
668
- fig, axes = plt.subplots(nrows=num_src, ncols=num_tar, figsize=(12,12))
903
+ fig, axes = plt.subplots(nrows=num_src, ncols=num_tar, figsize=(12, 12))
669
904
  fig.subplots_adjust(hspace=0.5, wspace=0.5)
670
905
 
671
906
  for x in range(num_src):
672
907
  for y in range(num_tar):
673
- axes[x,y].hist(data[x][y])
908
+ axes[x, y].hist(data[x][y])
674
909
 
675
- if x == num_src-1:
676
- axes[x,y].set_xlabel(target_labels[y])
910
+ if x == num_src - 1:
911
+ axes[x, y].set_xlabel(target_labels[y])
677
912
  if y == 0:
678
- axes[x,y].set_ylabel(source_labels[x])
913
+ axes[x, y].set_ylabel(source_labels[x])
679
914
 
680
915
  tt = edge_property + " Histogram Matrix"
681
916
  if title:
682
917
  tt = title
683
918
  st = fig.suptitle(tt, fontsize=14)
684
- fig.text(0.5, 0.04, 'Target', ha='center')
685
- fig.text(0.04, 0.5, 'Source', va='center', rotation='vertical')
919
+ fig.text(0.5, 0.04, "Target", ha="center")
920
+ fig.text(0.04, 0.5, "Source", va="center", rotation="vertical")
686
921
  plt.draw()
687
922
 
688
923
 
689
- def distance_delay_plot(simulation_config: str,source: str,target: str,
690
- group_by: str,sid: str,tid: str) -> None:
924
+ def distance_delay_plot(
925
+ simulation_config: str, source: str, target: str, group_by: str, sid: str, tid: str
926
+ ) -> None:
691
927
  """
692
928
  Plots the relationship between the distance and delay of connections between nodes in a neural network simulation.
693
929
 
@@ -708,25 +944,27 @@ def distance_delay_plot(simulation_config: str,source: str,target: str,
708
944
  """
709
945
  nodes, edges = util.load_nodes_edges_from_config(simulation_config)
710
946
  nodes = nodes[target]
711
- #node id is index of nodes df
947
+ # node id is index of nodes df
712
948
  node_id_source = nodes[nodes[group_by] == sid].index
713
949
  node_id_target = nodes[nodes[group_by] == tid].index
714
950
 
715
- edges = edges[f'{source}_to_{target}']
716
- edges = edges[edges['source_node_id'].isin(node_id_source) & edges['target_node_id'].isin(node_id_target)]
951
+ edges = edges[f"{source}_to_{target}"]
952
+ edges = edges[
953
+ edges["source_node_id"].isin(node_id_source) & edges["target_node_id"].isin(node_id_target)
954
+ ]
717
955
 
718
956
  stuff_to_plot = []
719
957
  for index, row in edges.iterrows():
720
958
  try:
721
- source_node = row['source_node_id']
722
- target_node = row['target_node_id']
723
-
724
- source_pos = nodes.loc[[source_node], ['pos_x', 'pos_y', 'pos_z']]
725
- target_pos = nodes.loc[[target_node], ['pos_x', 'pos_y', 'pos_z']]
726
-
959
+ source_node = row["source_node_id"]
960
+ target_node = row["target_node_id"]
961
+
962
+ source_pos = nodes.loc[[source_node], ["pos_x", "pos_y", "pos_z"]]
963
+ target_pos = nodes.loc[[target_node], ["pos_x", "pos_y", "pos_z"]]
964
+
727
965
  distance = np.linalg.norm(source_pos.values - target_pos.values)
728
966
 
729
- delay = row['delay'] # This line may raise KeyError
967
+ delay = row["delay"] # This line may raise KeyError
730
968
  stuff_to_plot.append([distance, delay])
731
969
 
732
970
  except KeyError as e:
@@ -737,9 +975,9 @@ def distance_delay_plot(simulation_config: str,source: str,target: str,
737
975
  print(f"Unexpected error at edge index {index}: {e}")
738
976
 
739
977
  plt.scatter([x[0] for x in stuff_to_plot], [x[1] for x in stuff_to_plot])
740
- plt.xlabel('Distance')
741
- plt.ylabel('Delay')
742
- plt.title(f'Distance vs Delay for edge between {sid} and {tid}')
978
+ plt.xlabel("Distance")
979
+ plt.ylabel("Delay")
980
+ plt.title(f"Distance vs Delay for edge between {sid} and {tid}")
743
981
  plt.show()
744
982
 
745
983
 
@@ -748,7 +986,7 @@ def plot_synapse_location_histograms(config, target_model, source=None, target=N
748
986
  generates a histogram of the positions of the synapses on a cell broken down by section
749
987
  config: a BMTK config
750
988
  target_model: the name of the model_template used when building the BMTK node
751
- source: The source BMTK network
989
+ source: The source BMTK network
752
990
  target: The target BMTK network
753
991
  """
754
992
  # Load mechanisms and template
@@ -758,37 +996,36 @@ def plot_synapse_location_histograms(config, target_model, source=None, target=N
758
996
  # Load node and edge data
759
997
  nodes, edges = util.load_nodes_edges_from_config(config)
760
998
  nodes = nodes[source]
761
- edges = edges[f'{source}_to_{target}']
999
+ edges = edges[f"{source}_to_{target}"]
762
1000
 
763
1001
  # Map target_node_id to model_template
764
- edges['target_model_template'] = edges['target_node_id'].map(nodes['model_template'])
1002
+ edges["target_model_template"] = edges["target_node_id"].map(nodes["model_template"])
765
1003
 
766
1004
  # Map source_node_id to pop_name
767
- edges['source_pop_name'] = edges['source_node_id'].map(nodes['pop_name'])
1005
+ edges["source_pop_name"] = edges["source_node_id"].map(nodes["pop_name"])
768
1006
 
769
- edges = edges[edges['target_model_template'] == target_model]
1007
+ edges = edges[edges["target_model_template"] == target_model]
770
1008
 
771
1009
  # Create the cell model from target model
772
- cell = getattr(h, target_model.split(':')[1])()
773
-
1010
+ cell = getattr(h, target_model.split(":")[1])()
1011
+
774
1012
  # Create a mapping from section index to section name
775
1013
  section_id_to_name = {}
776
1014
  for idx, sec in enumerate(cell.all):
777
1015
  section_id_to_name[idx] = sec.name()
778
1016
 
779
1017
  # Add a new column with section names based on afferent_section_id
780
- edges['afferent_section_name'] = edges['afferent_section_id'].map(section_id_to_name)
1018
+ edges["afferent_section_name"] = edges["afferent_section_id"].map(section_id_to_name)
781
1019
 
782
1020
  # Get unique sections and source populations
783
- unique_pops = edges['source_pop_name'].unique()
1021
+ unique_pops = edges["source_pop_name"].unique()
784
1022
 
785
1023
  # Filter to only include sections with data
786
- section_counts = edges['afferent_section_name'].value_counts()
1024
+ section_counts = edges["afferent_section_name"].value_counts()
787
1025
  sections_with_data = section_counts[section_counts > 0].index.tolist()
788
1026
 
789
-
790
1027
  # Create a figure with subplots for each section
791
- plt.figure(figsize=(8,12))
1028
+ plt.figure(figsize=(8, 12))
792
1029
 
793
1030
  # Color map for source populations
794
1031
  color_map = plt.cm.tab10(np.linspace(0, 1, len(unique_pops)))
@@ -796,30 +1033,37 @@ def plot_synapse_location_histograms(config, target_model, source=None, target=N
796
1033
 
797
1034
  # Create a histogram for each section
798
1035
  for i, section in enumerate(sections_with_data):
799
- ax = plt.subplot(len(sections_with_data), 1, i+1)
800
-
1036
+ ax = plt.subplot(len(sections_with_data), 1, i + 1)
1037
+
801
1038
  # Get data for this section
802
- section_data = edges[edges['afferent_section_name'] == section]
803
-
1039
+ section_data = edges[edges["afferent_section_name"] == section]
1040
+
804
1041
  # Group by source population
805
- for pop_name, pop_group in section_data.groupby('source_pop_name'):
1042
+ for pop_name, pop_group in section_data.groupby("source_pop_name"):
806
1043
  if len(pop_group) > 0:
807
- ax.hist(pop_group['afferent_section_pos'], bins=15, alpha=0.7,
808
- label=pop_name, color=pop_colors[pop_name])
809
-
1044
+ ax.hist(
1045
+ pop_group["afferent_section_pos"],
1046
+ bins=15,
1047
+ alpha=0.7,
1048
+ label=pop_name,
1049
+ color=pop_colors[pop_name],
1050
+ )
1051
+
810
1052
  # Set title and labels
811
1053
  ax.set_title(f"{section}", fontsize=10)
812
- ax.set_xlabel('Section Position', fontsize=8)
813
- ax.set_ylabel('Frequency', fontsize=8)
1054
+ ax.set_xlabel("Section Position", fontsize=8)
1055
+ ax.set_ylabel("Frequency", fontsize=8)
814
1056
  ax.tick_params(labelsize=7)
815
1057
  ax.grid(True, alpha=0.3)
816
-
1058
+
817
1059
  # Only add legend to the first plot
818
1060
  if i == 0:
819
1061
  ax.legend(fontsize=8)
820
1062
 
821
1063
  plt.tight_layout()
822
- plt.suptitle('Connection Distribution by Cell Section and Source Population', fontsize=16, y=1.02)
1064
+ plt.suptitle(
1065
+ "Connection Distribution by Cell Section and Source Population", fontsize=16, y=1.02
1066
+ )
823
1067
  if is_notebook:
824
1068
  plt.show()
825
1069
  else:
@@ -828,105 +1072,147 @@ def plot_synapse_location_histograms(config, target_model, source=None, target=N
828
1072
  # Create a summary table
829
1073
  print("Summary of connections by section and source population:")
830
1074
  pivot_table = edges.pivot_table(
831
- values='afferent_section_id',
832
- index='afferent_section_name',
833
- columns='source_pop_name',
834
- aggfunc='count',
835
- fill_value=0
1075
+ values="afferent_section_id",
1076
+ index="afferent_section_name",
1077
+ columns="source_pop_name",
1078
+ aggfunc="count",
1079
+ fill_value=0,
836
1080
  )
837
1081
  print(pivot_table)
838
1082
 
839
1083
 
840
- def plot_connection_info(text, num, source_labels, target_labels, title, syn_info='0', save_file=None, return_dict=None):
1084
+ def plot_connection_info(
1085
+ text, num, source_labels, target_labels, title, syn_info="0", save_file=None, return_dict=None
1086
+ ):
841
1087
  """
842
1088
  Function to plot connection information as a heatmap, including handling missing source and target values.
843
1089
  If there is no source or target, set the value to 0.
844
1090
  """
845
-
1091
+
846
1092
  # Ensure text dimensions match num dimensions
847
1093
  num_source = len(source_labels)
848
1094
  num_target = len(target_labels)
849
-
1095
+
850
1096
  # Set color map
851
- matplotlib.rc('image', cmap='viridis')
852
-
1097
+ matplotlib.rc("image", cmap="viridis")
1098
+
853
1099
  # Create figure and axis for the plot
854
1100
  fig1, ax1 = plt.subplots(figsize=(num_source, num_target))
855
- num = np.nan_to_num(num, nan=0) # replace NaN with 0
1101
+ num = np.nan_to_num(num, nan=0) # replace NaN with 0
856
1102
  im1 = ax1.imshow(num)
857
-
1103
+
858
1104
  # Set ticks and labels for source and target
859
1105
  ax1.set_xticks(list(np.arange(len(target_labels))))
860
1106
  ax1.set_yticks(list(np.arange(len(source_labels))))
861
1107
  ax1.set_xticklabels(target_labels)
862
- ax1.set_yticklabels(source_labels, size=12, weight='semibold')
863
-
1108
+ ax1.set_yticklabels(source_labels, size=12, weight="semibold")
1109
+
864
1110
  # Rotate the tick labels for better visibility
865
- plt.setp(ax1.get_xticklabels(), rotation=45, ha="right",
866
- rotation_mode="anchor", size=12, weight='semibold')
867
-
1111
+ plt.setp(
1112
+ ax1.get_xticklabels(),
1113
+ rotation=45,
1114
+ ha="right",
1115
+ rotation_mode="anchor",
1116
+ size=12,
1117
+ weight="semibold",
1118
+ )
1119
+
868
1120
  # Dictionary to store connection information
869
1121
  graph_dict = {}
870
-
1122
+
871
1123
  # Loop over data dimensions and create text annotations
872
1124
  for i in range(num_source):
873
1125
  for j in range(num_target):
874
1126
  # Get the edge info, or set it to '0' if it's missing
875
1127
  edge_info = text[i, j] if text[i, j] is not None else 0
876
-
1128
+
877
1129
  # Initialize the dictionary for the source node if not already done
878
1130
  if source_labels[i] not in graph_dict:
879
1131
  graph_dict[source_labels[i]] = {}
880
-
1132
+
881
1133
  # Add edge info for the target node
882
1134
  graph_dict[source_labels[i]][target_labels[j]] = edge_info
883
-
1135
+
884
1136
  # Set text annotations based on syn_info type
885
- if syn_info == '2' or syn_info == '3':
1137
+ if syn_info == "2" or syn_info == "3":
886
1138
  if num_source > 8 and num_source < 20:
887
- fig_text = ax1.text(j, i, edge_info,
888
- ha="center", va="center", color="w", rotation=37.5, size=8, weight='semibold')
1139
+ fig_text = ax1.text(
1140
+ j,
1141
+ i,
1142
+ edge_info,
1143
+ ha="center",
1144
+ va="center",
1145
+ color="w",
1146
+ rotation=37.5,
1147
+ size=8,
1148
+ weight="semibold",
1149
+ )
889
1150
  elif num_source > 20:
890
- fig_text = ax1.text(j, i, edge_info,
891
- ha="center", va="center", color="w", rotation=37.5, size=7, weight='semibold')
1151
+ fig_text = ax1.text(
1152
+ j,
1153
+ i,
1154
+ edge_info,
1155
+ ha="center",
1156
+ va="center",
1157
+ color="w",
1158
+ rotation=37.5,
1159
+ size=7,
1160
+ weight="semibold",
1161
+ )
892
1162
  else:
893
- fig_text = ax1.text(j, i, edge_info,
894
- ha="center", va="center", color="w", rotation=37.5, size=11, weight='semibold')
1163
+ fig_text = ax1.text(
1164
+ j,
1165
+ i,
1166
+ edge_info,
1167
+ ha="center",
1168
+ va="center",
1169
+ color="w",
1170
+ rotation=37.5,
1171
+ size=11,
1172
+ weight="semibold",
1173
+ )
895
1174
  else:
896
- fig_text = ax1.text(j, i, edge_info,
897
- ha="center", va="center", color="w", size=11, weight='semibold')
898
-
1175
+ fig_text = ax1.text(
1176
+ j, i, edge_info, ha="center", va="center", color="w", size=11, weight="semibold"
1177
+ )
1178
+
899
1179
  # Set labels and title for the plot
900
- ax1.set_ylabel('Source', size=11, weight='semibold')
901
- ax1.set_xlabel('Target', size=11, weight='semibold')
902
- ax1.set_title(title, size=20, weight='semibold')
903
-
1180
+ ax1.set_ylabel("Source", size=11, weight="semibold")
1181
+ ax1.set_xlabel("Target", size=11, weight="semibold")
1182
+ ax1.set_title(title, size=20, weight="semibold")
1183
+
904
1184
  # Display the plot or save it based on the environment and arguments
905
1185
  notebook = is_notebook() # Check if running in a Jupyter notebook
906
- if notebook == False:
1186
+ if not notebook:
907
1187
  fig1.show()
908
-
1188
+
909
1189
  if save_file:
910
1190
  plt.savefig(save_file)
911
-
1191
+
912
1192
  if return_dict:
913
1193
  return graph_dict
914
1194
  else:
915
1195
  return
916
1196
 
917
1197
 
918
- def connector_percent_matrix(csv_path: str = None, exclude_strings=None, assemb_key=None, title: str = 'Percent connection matrix', pop_order=None) -> None:
1198
+ def connector_percent_matrix(
1199
+ csv_path: str = None,
1200
+ exclude_strings=None,
1201
+ assemb_key=None,
1202
+ title: str = "Percent connection matrix",
1203
+ pop_order=None,
1204
+ ) -> None:
919
1205
  """
920
1206
  Generates and plots a connection matrix based on connection probabilities from a CSV file produced by bmtool.connector.
921
1207
 
922
- This function is useful for visualizing percent connectivity while factoring in population distance and other parameters.
923
- It processes the connection data by filtering the 'Source' and 'Target' columns in the CSV, and displays the percentage of
1208
+ This function is useful for visualizing percent connectivity while factoring in population distance and other parameters.
1209
+ It processes the connection data by filtering the 'Source' and 'Target' columns in the CSV, and displays the percentage of
924
1210
  connected pairs for each population combination in a matrix.
925
1211
 
926
1212
  Parameters:
927
1213
  -----------
928
1214
  csv_path : str
929
- Path to the CSV file containing the connection data. The CSV should be an output from the bmtool.connector
1215
+ Path to the CSV file containing the connection data. The CSV should be an output from the bmtool.connector
930
1216
  classes, specifically generated by the `save_connection_report()` function.
931
1217
  exclude_strings : list of str, optional
932
1218
  List of strings to exclude rows where 'Source' or 'Target' contain these strings.
@@ -949,15 +1235,14 @@ def connector_percent_matrix(csv_path: str = None, exclude_strings=None, assemb_
949
1235
  # Filter the DataFrame based on exclude_strings
950
1236
  def filter_dataframe(df, column_name, exclude_strings):
951
1237
  def process_string(string):
952
-
953
1238
  match = re.search(r"\[\'(.*?)\'\]", string)
954
1239
  if exclude_strings and any(ex_string in string for ex_string in exclude_strings):
955
1240
  return None
956
1241
  elif match:
957
1242
  filtered_string = match.group(1)
958
- if 'Gap' in string:
959
- filtered_string = filtered_string + "-Gap"
960
-
1243
+ if "Gap" in string:
1244
+ filtered_string = filtered_string + "-Gap"
1245
+
961
1246
  if assemb_key:
962
1247
  if assemb_key in string:
963
1248
  filtered_string = filtered_string + assemb_key
@@ -965,65 +1250,73 @@ def connector_percent_matrix(csv_path: str = None, exclude_strings=None, assemb_
965
1250
  return filtered_string # Return matched string
966
1251
 
967
1252
  return string # If no match, return the original string
968
-
1253
+
969
1254
  df[column_name] = df[column_name].apply(process_string)
970
1255
  df = df.dropna(subset=[column_name])
971
-
1256
+
972
1257
  return df
973
1258
 
974
- df = filter_dataframe(df, 'Source', exclude_strings)
975
- df = filter_dataframe(df, 'Target', exclude_strings)
976
-
977
- #process assem rows and combine them into one prob per assem type
1259
+ df = filter_dataframe(df, "Source", exclude_strings)
1260
+ df = filter_dataframe(df, "Target", exclude_strings)
1261
+
1262
+ # process assem rows and combine them into one prob per assem type
978
1263
  if assemb_key:
979
- assems = df[df['Source'].str.contains(assemb_key)]
980
- unique_sources = assems['Source'].unique()
1264
+ assems = df[df["Source"].str.contains(assemb_key)]
1265
+ unique_sources = assems["Source"].unique()
981
1266
 
982
1267
  for source in unique_sources:
983
- source_assems = assems[assems['Source'] == source]
984
- unique_targets = source_assems['Target'].unique() # Filter targets for the current source
1268
+ source_assems = assems[assems["Source"] == source]
1269
+ unique_targets = source_assems[
1270
+ "Target"
1271
+ ].unique() # Filter targets for the current source
985
1272
 
986
1273
  for target in unique_targets:
987
1274
  # Filter the assemblies with the current source and target
988
- unique_assems = source_assems[source_assems['Target'] == target]
989
-
1275
+ unique_assems = source_assems[source_assems["Target"] == target]
1276
+
990
1277
  # find the prob of a conn
991
1278
  forward_probs = []
992
- for _,row in unique_assems.iterrows():
1279
+ for _, row in unique_assems.iterrows():
993
1280
  selected_percentage = row[selected_column]
994
- selected_percentage = [float(p) for p in selected_percentage.strip('[]').split()]
1281
+ selected_percentage = [
1282
+ float(p) for p in selected_percentage.strip("[]").split()
1283
+ ]
995
1284
  if len(selected_percentage) == 1 or len(selected_percentage) == 2:
996
1285
  forward_probs.append(selected_percentage[0])
997
1286
  if len(selected_percentage) == 3:
998
1287
  forward_probs.append(selected_percentage[0])
999
1288
  forward_probs.append(selected_percentage[1])
1000
-
1289
+
1001
1290
  mean_probs = np.mean(forward_probs)
1002
1291
  source = source.replace(assemb_key, "")
1003
1292
  target = target.replace(assemb_key, "")
1004
- new_row = pd.DataFrame({
1005
- 'Source': [source],
1006
- 'Target': [target],
1007
- 'Percent connectionivity within possible connections': [mean_probs],
1008
- 'Percent connectionivity within all connections': [0]
1009
- })
1293
+ new_row = pd.DataFrame(
1294
+ {
1295
+ "Source": [source],
1296
+ "Target": [target],
1297
+ "Percent connectionivity within possible connections": [mean_probs],
1298
+ "Percent connectionivity within all connections": [0],
1299
+ }
1300
+ )
1010
1301
 
1011
1302
  df = pd.concat([df, new_row], ignore_index=False)
1012
-
1303
+
1013
1304
  # Prepare connection data
1014
1305
  connection_data = {}
1015
1306
  for _, row in df.iterrows():
1016
- source, target, selected_percentage = row['Source'], row['Target'], row[selected_column]
1307
+ source, target, selected_percentage = row["Source"], row["Target"], row[selected_column]
1017
1308
  if isinstance(selected_percentage, str):
1018
- selected_percentage = [float(p) for p in selected_percentage.strip('[]').split()]
1309
+ selected_percentage = [float(p) for p in selected_percentage.strip("[]").split()]
1019
1310
  connection_data[(source, target)] = selected_percentage
1020
1311
 
1021
1312
  # Determine population order
1022
- populations = sorted(list(set(df['Source'].unique()) | set(df['Target'].unique())))
1313
+ populations = sorted(list(set(df["Source"].unique()) | set(df["Target"].unique())))
1023
1314
  if pop_order:
1024
- populations = [pop for pop in pop_order if pop in populations] # Order according to pop_order, if provided
1315
+ populations = [
1316
+ pop for pop in pop_order if pop in populations
1317
+ ] # Order according to pop_order, if provided
1025
1318
  num_populations = len(populations)
1026
-
1319
+
1027
1320
  # Create an empty matrix and populate it
1028
1321
  connection_matrix = np.zeros((num_populations, num_populations), dtype=float)
1029
1322
  for (source, target), probabilities in connection_data.items():
@@ -1031,40 +1324,49 @@ def connector_percent_matrix(csv_path: str = None, exclude_strings=None, assemb_
1031
1324
  source_idx = populations.index(source)
1032
1325
  target_idx = populations.index(target)
1033
1326
 
1034
- if type(probabilities) == float:
1327
+ if isinstance(probabilities, float):
1035
1328
  connection_matrix[source_idx][target_idx] = probabilities
1036
1329
  elif len(probabilities) == 1:
1037
1330
  connection_matrix[source_idx][target_idx] = probabilities[0]
1038
1331
  elif len(probabilities) == 2:
1039
1332
  connection_matrix[source_idx][target_idx] = probabilities[0]
1040
1333
  elif len(probabilities) == 3:
1041
- connection_matrix[source_idx][target_idx] = probabilities[0]
1334
+ connection_matrix[source_idx][target_idx] = probabilities[0]
1042
1335
  connection_matrix[target_idx][source_idx] = probabilities[1]
1043
1336
  else:
1044
1337
  raise Exception("unsupported format")
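# One reading of the branches above (an interpretation, not documented anywhere
# in this diff): a 1- or 2-element list carries only the forward probability,
# while a 3-element list is treated as [forward, reciprocal, ...], with the
# second value mirrored into the transposed cell. Hypothetical entry:
#     connection_data[("PV", "PN")] = [12.0, 8.0, 4.0]
#     # -> connection_matrix[PV][PN] = 12.0 and connection_matrix[PN][PV] = 8.0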
1045
1338
 
1046
1339
  # Plotting
1047
1340
  fig, ax = plt.subplots(figsize=(10, 8))
1048
- im = ax.imshow(connection_matrix, cmap='viridis', interpolation='nearest')
1341
+ im = ax.imshow(connection_matrix, cmap="viridis", interpolation="nearest")
1049
1342
 
1050
1343
  # Add annotations
1051
1344
  for i in range(num_populations):
1052
1345
  for j in range(num_populations):
1053
- text = ax.text(j, i, f"{connection_matrix[i, j]:.2f}%", ha="center", va="center", color="w", size=10, weight='semibold')
1346
+ text = ax.text(
1347
+ j,
1348
+ i,
1349
+ f"{connection_matrix[i, j]:.2f}%",
1350
+ ha="center",
1351
+ va="center",
1352
+ color="w",
1353
+ size=10,
1354
+ weight="semibold",
1355
+ )
1054
1356
 
1055
1357
  # Add colorbar
1056
- plt.colorbar(im, label=f'{selected_column}')
1358
+ plt.colorbar(im, label=f"{selected_column}")
1057
1359
 
1058
1360
  # Set title and axis labels
1059
1361
  ax.set_title(title)
1060
- ax.set_xlabel('Target Population')
1061
- ax.set_ylabel('Source Population')
1362
+ ax.set_xlabel("Target Population")
1363
+ ax.set_ylabel("Source Population")
1062
1364
 
1063
1365
  # Set ticks and labels based on populations in specified order
1064
1366
  ax.set_xticks(np.arange(num_populations))
1065
1367
  ax.set_yticks(np.arange(num_populations))
1066
- ax.set_xticklabels(populations, rotation=45, ha="right", size=12, weight='semibold')
1067
- ax.set_yticklabels(populations, size=12, weight='semibold')
1368
+ ax.set_xticklabels(populations, rotation=45, ha="right", size=12, weight="semibold")
1369
+ ax.set_yticklabels(populations, size=12, weight="semibold")
1068
1370
 
1069
1371
  plt.tight_layout()
1070
1372
  plt.show()
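A minimal usage sketch of the reformatted connector_percent_matrix. csv_path, exclude_strings, and assemb_key appear in the signature and body above; selected_column and title are assumed to be keyword arguments because the body references them, and the import path, file name, and title below are hypothetical:

    from bmtool.bmplot import connector_percent_matrix  # module path assumed

    connector_percent_matrix(
        csv_path="connection_report.csv",  # hypothetical CSV from a bmtool connector report
        selected_column="Percent connectionivity within possible connections",
        title="Connection probability (%)",
    )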
@@ -1073,22 +1375,22 @@ def connector_percent_matrix(csv_path: str = None, exclude_strings=None, assemb_
1073
1375
  def plot_3d_positions(config=None, sources=None, sid=None, title=None, save_file=None, subset=None):
1074
1376
  """
1075
1377
  Plots a 3D graph of all cells with x, y, z location.
1076
-
1378
+
1077
1379
  Parameters:
1078
- - config: A BMTK simulation config
1079
- - sources: Which network(s) to plot
1380
+ - config: A BMTK simulation config
1381
+ - sources: Which network(s) to plot
1080
1382
  - sid: How to name cell groups
1081
1383
  - title: Plot title
1082
1384
  - save_file: If plot should be saved
1083
1385
  - subset: Take every Nth row. This will make plotting large network graphs easier to see.
1084
1386
  """
1085
-
1387
+
1086
1388
  if not config:
1087
1389
  raise Exception("config not defined")
1088
-
1390
+
1089
1391
  if sources is None:
1090
1392
  sources = "all"
1091
-
1393
+
1092
1394
  # Set group keys (e.g., node types)
1093
1395
  group_keys = sid
1094
1396
  if title is None:
@@ -1096,66 +1398,75 @@ def plot_3d_positions(config=None, sources=None, sid=None, title=None, save_file
1096
1398
 
1097
1399
  # Load nodes from the configuration
1098
1400
  nodes = util.load_nodes_from_config(config)
1099
-
1401
+
1100
1402
  # Get the list of populations to plot
1101
- if 'all' in sources:
1403
+ if "all" in sources:
1102
1404
  populations = list(nodes)
1103
1405
  else:
1104
1406
  populations = sources.split(",")
1105
-
1106
- # Split group_by into list
1407
+
1408
+ # Split group_by into list
1107
1409
  group_keys = group_keys.split(",")
1108
- group_keys += (len(populations) - len(group_keys)) * ["node_type_id"] # Extend the array to default values if not enough given
1410
+ group_keys += (len(populations) - len(group_keys)) * [
1411
+ "node_type_id"
1412
+ ] # Extend the array to default values if not enough given
1109
1413
  if len(group_keys) > 1:
1110
1414
  raise Exception("Only one group by is supported currently!")
1111
-
1415
+
1112
1416
  fig = plt.figure(figsize=(10, 10))
1113
- ax = fig.add_subplot(projection='3d')
1417
+ ax = fig.add_subplot(projection="3d")
1114
1418
  handles = []
1115
1419
 
1116
- for pop in (list(nodes)):
1117
-
1118
- if 'all' not in populations and pop not in populations:
1420
+ for pop in list(nodes):
1421
+ if "all" not in populations and pop not in populations:
1119
1422
  continue
1120
-
1423
+
1121
1424
  nodes_df = nodes[pop]
1122
- group_key = group_keys[0]
1123
-
1425
+ group_key = group_keys[0]
1426
+
1124
1427
  # If group_key is provided, ensure the column exists in the dataframe
1125
1428
  if group_key is not None:
1126
1429
  if group_key not in nodes_df:
1127
1430
  raise Exception(f"Could not find column '{group_key}' in {pop}")
1128
-
1431
+
1129
1432
  groupings = nodes_df.groupby(group_key)
1130
1433
  n_colors = nodes_df[group_key].nunique()
1131
1434
  color_norm = colors.Normalize(vmin=0, vmax=(n_colors - 1))
1132
- scalar_map = cmx.ScalarMappable(norm=color_norm, cmap='hsv')
1435
+ scalar_map = cmx.ScalarMappable(norm=color_norm, cmap="hsv")
1133
1436
  color_map = [scalar_map.to_rgba(i) for i in range(n_colors)]
1134
1437
  else:
1135
1438
  groupings = [(None, nodes_df)]
1136
- color_map = ['blue']
1439
+ color_map = ["blue"]
1137
1440
 
1138
1441
  # Loop over groupings and plot
1139
1442
  for color, (group_name, group_df) in zip(color_map, groupings):
1140
1443
  if "pos_x" not in group_df or "pos_y" not in group_df or "pos_z" not in group_df:
1141
- print(f"Warning: Missing position columns in group '{group_name}' for {pop}. Skipping this group.")
1444
+ print(
1445
+ f"Warning: Missing position columns in group '{group_name}' for {pop}. Skipping this group."
1446
+ )
1142
1447
  continue # Skip if position columns are missing
1143
1448
 
1144
1449
  # Subset the dataframe by taking every Nth row if subset is provided
1145
1450
  if subset is not None:
1146
1451
  group_df = group_df.iloc[::subset]
1147
1452
 
1148
- h = ax.scatter(group_df["pos_x"], group_df["pos_y"], group_df["pos_z"], color=color, label=group_name)
1453
+ h = ax.scatter(
1454
+ group_df["pos_x"],
1455
+ group_df["pos_y"],
1456
+ group_df["pos_z"],
1457
+ color=color,
1458
+ label=group_name,
1459
+ )
1149
1460
  handles.append(h)
1150
-
1461
+
1151
1462
  if not handles:
1152
1463
  print("No data to plot.")
1153
1464
  return
1154
-
1465
+
1155
1466
  # Set plot title and legend
1156
1467
  plt.title(title)
1157
1468
  plt.legend(handles=handles)
1158
-
1469
+
1159
1470
  # Draw the plot
1160
1471
  plt.draw()
1161
1472
 
@@ -1170,8 +1481,19 @@ def plot_3d_positions(config=None, sources=None, sid=None, title=None, save_file
1170
1481
  return ax
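A minimal usage sketch of plot_3d_positions as refactored above. The signature comes from the def line; the config path, network name, and grouping column are hypothetical. sources and sid are comma-separated strings, matching the .split(",") calls in the body:

    ax = plot_3d_positions(
        config="simulation_config.json",  # hypothetical BMTK config
        sources="cortex",                 # hypothetical network name; "all" plots every population
        sid="pop_name",                   # hypothetical grouping column in the nodes table
        subset=10,                        # keep every 10th cell to thin a large network
    )
    ax.view_init(elev=20, azim=45)        # the returned 3D axes can be adjusted further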
1171
1482
 
1172
1483
 
1173
- def plot_3d_cell_rotation(config=None, sources=None, sids=None, title=None, save_file=None, quiver_length=None, arrow_length_ratio=None, group=None, subset=None):
1484
+ def plot_3d_cell_rotation(
1485
+ config=None,
1486
+ sources=None,
1487
+ sids=None,
1488
+ title=None,
1489
+ save_file=None,
1490
+ quiver_length=None,
1491
+ arrow_length_ratio=None,
1492
+ group=None,
1493
+ subset=None,
1494
+ ):
1174
1495
  from scipy.spatial.transform import Rotation as R
1496
+
1175
1497
  if not config:
1176
1498
  raise Exception("config not defined")
1177
1499
 
@@ -1185,38 +1507,38 @@ def plot_3d_cell_rotation(config=None, sources=None, sids=None, title=None, save
1185
1507
 
1186
1508
  nodes = util.load_nodes_from_config(config)
1187
1509
 
1188
- if 'all' in sources:
1510
+ if "all" in sources:
1189
1511
  populations = list(nodes)
1190
1512
  else:
1191
1513
  populations = sources
1192
1514
 
1193
1515
  fig = plt.figure(figsize=(10, 10))
1194
- ax = fig.add_subplot(111, projection='3d')
1516
+ ax = fig.add_subplot(111, projection="3d")
1195
1517
  handles = []
1196
1518
 
1197
1519
  for nodes_key, group_key in zip(list(nodes), group_keys):
1198
- if 'all' not in populations and nodes_key not in populations:
1520
+ if "all" not in populations and nodes_key not in populations:
1199
1521
  continue
1200
1522
 
1201
1523
  nodes_df = nodes[nodes_key]
1202
1524
 
1203
1525
  if group_key is not None:
1204
1526
  if group_key not in nodes_df.columns:
1205
- raise Exception(f'Could not find column {group_key}')
1527
+ raise Exception(f"Could not find column {group_key}")
1206
1528
  groupings = nodes_df.groupby(group_key)
1207
1529
 
1208
1530
  n_colors = nodes_df[group_key].nunique()
1209
1531
  color_norm = colors.Normalize(vmin=0, vmax=(n_colors - 1))
1210
- scalar_map = cmx.ScalarMappable(norm=color_norm, cmap='hsv')
1532
+ scalar_map = cmx.ScalarMappable(norm=color_norm, cmap="hsv")
1211
1533
  color_map = [scalar_map.to_rgba(i) for i in range(n_colors)]
1212
1534
  else:
1213
1535
  groupings = [(None, nodes_df)]
1214
- color_map = ['blue']
1536
+ color_map = ["blue"]
1215
1537
 
1216
1538
  for color, (group_name, group_df) in zip(color_map, groupings):
1217
1539
  if subset is not None:
1218
1540
  group_df = group_df.iloc[::subset]
1219
-
1541
+
1220
1542
  if group and group_name not in group.split(","):
1221
1543
  continue
1222
1544
 
@@ -1238,9 +1560,9 @@ def plot_3d_cell_rotation(config=None, sources=None, sids=None, title=None, save
1238
1560
  W = np.zeros(len(Z))
1239
1561
 
1240
1562
  # Create rotation matrices from Euler angles
1241
- rotations = R.from_euler('xyz', np.column_stack((U, V, W)), degrees=False)
1563
+ rotations = R.from_euler("xyz", np.column_stack((U, V, W)), degrees=False)
1242
1564
 
1243
- # Define initial vectors
1565
+ # Define initial vectors
1244
1566
  init_vectors = np.column_stack((np.ones(len(X)), np.zeros(len(Y)), np.zeros(len(Z))))
1245
1567
 
1246
1568
  # Apply rotations to initial vectors
@@ -1251,7 +1573,18 @@ def plot_3d_cell_rotation(config=None, sources=None, sids=None, title=None, save
1251
1573
  rot_y = rots[:, 1]
1252
1574
  rot_z = rots[:, 2]
1253
1575
 
1254
- h = ax.quiver(X, Y, Z, rot_x, rot_y, rot_z, color=color, label=group_name, arrow_length_ratio=arrow_length_ratio, length=quiver_length)
1576
+ h = ax.quiver(
1577
+ X,
1578
+ Y,
1579
+ Z,
1580
+ rot_x,
1581
+ rot_y,
1582
+ rot_z,
1583
+ color=color,
1584
+ label=group_name,
1585
+ arrow_length_ratio=arrow_length_ratio,
1586
+ length=quiver_length,
1587
+ )
1255
1588
  ax.scatter(X, Y, Z, color=color, label=group_name)
1256
1589
  ax.set_xlim([min(X), max(X)])
1257
1590
  ax.set_ylim([min(Y), max(Y)])
@@ -1268,5 +1601,5 @@ def plot_3d_cell_rotation(config=None, sources=None, sids=None, title=None, save
1268
1601
  if save_file:
1269
1602
  plt.savefig(save_file)
1270
1603
  notebook = is_notebook()
1271
- if notebook == False:
1604
+ if not notebook:
1272
1605
  plt.show()
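The quiver directions above come from rotating the unit x-vector by each cell's Euler angles with SciPy, following the from_euler("xyz", ...) / apply(...) pattern shown in the hunk. A standalone sketch of that step with a made-up angle (the single cell and the 90-degree z rotation are illustrative only):

    import numpy as np
    from scipy.spatial.transform import Rotation as R

    angles = np.array([[0.0, 0.0, np.pi / 2]])   # one cell, rotated 90 degrees about z
    rots = R.from_euler("xyz", angles, degrees=False)
    init_vectors = np.array([[1.0, 0.0, 0.0]])   # unit x-vector, as in the plot code
    arrows = rots.apply(init_vectors)
    print(np.round(arrows, 3))                   # [[0. 1. 0.]] -- +x maps onto +y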