bmtool 0.6.9.1__py3-none-any.whl → 0.6.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bmtool/bmplot.py CHANGED
@@ -306,6 +306,28 @@ def gap_junction_matrix(config=None,title=None,sources=None, targets=None, sids=
306
306
 
307
307
 
308
308
  def filter_rows(syn_info, data, source_labels, target_labels):
309
+ """
310
+ Filters out rows in a connectivity matrix that contain only NaN or zero values.
311
+
312
+ This function is used to clean up connection matrices by removing rows that have no meaningful data,
313
+ which helps create more informative visualizations of network connectivity.
314
+
315
+ Parameters:
316
+ -----------
317
+ syn_info : numpy.ndarray
318
+ Array containing synaptic information corresponding to the data matrix.
319
+ data : numpy.ndarray
320
+ 2D matrix containing connectivity data with rows representing sources and columns representing targets.
321
+ source_labels : list
322
+ List of labels for the source populations corresponding to rows in the data matrix.
323
+ target_labels : list
324
+ List of labels for the target populations corresponding to columns in the data matrix.
325
+
326
+ Returns:
327
+ --------
328
+ tuple
329
+ A tuple containing the filtered (syn_info, data, source_labels, target_labels) with invalid rows removed.
330
+ """
309
331
  # Identify rows with all NaN or all zeros
310
332
  valid_rows = ~np.all(np.isnan(data), axis=1) & ~np.all(data == 0, axis=1)
311
333
 
@@ -317,6 +339,30 @@ def gap_junction_matrix(config=None,title=None,sources=None, targets=None, sids=
317
339
  return new_syn_info, new_data, new_source_labels, target_labels
318
340
 
319
341
  def filter_rows_and_columns(syn_info, data, source_labels, target_labels):
342
+ """
343
+ Filters out both rows and columns in a connectivity matrix that contain only NaN or zero values.
344
+
345
+ This function performs a two-step filtering process: first removing rows with no data,
346
+ then transposing the matrix and removing columns with no data (by treating them as rows).
347
+ This creates a cleaner, more informative connectivity matrix visualization.
348
+
349
+ Parameters:
350
+ -----------
351
+ syn_info : numpy.ndarray
352
+ Array containing synaptic information corresponding to the data matrix.
353
+ data : numpy.ndarray
354
+ 2D matrix containing connectivity data with rows representing sources and columns representing targets.
355
+ source_labels : list
356
+ List of labels for the source populations corresponding to rows in the data matrix.
357
+ target_labels : list
358
+ List of labels for the target populations corresponding to columns in the data matrix.
359
+
360
+ Returns:
361
+ --------
362
+ tuple
363
+ A tuple containing the filtered (syn_info, data, source_labels, target_labels) with both
364
+ invalid rows and columns removed.
365
+ """
320
366
  # Filter rows first
321
367
  syn_info, data, source_labels, target_labels = filter_rows(syn_info, data, source_labels, target_labels)
322
368
 
@@ -366,6 +412,36 @@ def connection_histogram(config=None,nodes=None,edges=None,sources=[],targets=[]
366
412
  save_file: If plot should be saved
367
413
  """
368
414
  def connection_pair_histogram(**kwargs):
415
+ """
416
+ Creates a histogram showing the distribution of connection counts between a specific source and target cell type.
417
+
418
+ This function is designed to be used with the relation_matrix utility and will only create histograms
419
+ for the specified source and target cell types, ignoring all other combinations.
420
+
421
+ Parameters:
422
+ -----------
423
+ kwargs : dict
424
+ Dictionary containing the following keys:
425
+ - edges: DataFrame containing edge information
426
+ - sid: Column name for source ID type in the edges DataFrame
427
+ - tid: Column name for target ID type in the edges DataFrame
428
+ - source_id: Value to filter edges by source ID type
429
+ - target_id: Value to filter edges by target ID type
430
+
431
+ Global parameters used:
432
+ ---------------------
433
+ source_cell : str
434
+ The source cell type to plot.
435
+ target_cell : str
436
+ The target cell type to plot.
437
+ include_gap : bool
438
+ Whether to include gap junctions in the analysis. If False, gap junctions are excluded.
439
+
440
+ Returns:
441
+ --------
442
+ None
443
+ Displays a histogram showing the distribution of connection counts.
444
+ """
369
445
  edges = kwargs["edges"]
370
446
  source_id_type = kwargs["sid"]
371
447
  target_id_type = kwargs["tid"]
@@ -491,7 +567,43 @@ def connection_distance(config: str,sources: str,targets: str,
491
567
 
492
568
  def edge_histogram_matrix(config=None,sources = None,targets=None,sids=None,tids=None,no_prepend_pop=None,edge_property = None,time = None,time_compare = None,report=None,title=None,save_file=None):
493
569
  """
494
- write about function here
570
+ Generates a matrix of histograms showing the distribution of edge properties between different populations.
571
+
572
+ This function creates a grid of histograms where each cell in the grid represents the distribution of a
573
+ specific edge property (e.g., synaptic weights, delays) between a source population (row) and
574
+ target population (column).
575
+
576
+ Parameters:
577
+ -----------
578
+ config : str
579
+ Path to a BMTK simulation config file.
580
+ sources : str
581
+ Comma-separated list of source network names.
582
+ targets : str
583
+ Comma-separated list of target network names.
584
+ sids : str, optional
585
+ Comma-separated list of source node identifiers to filter by.
586
+ tids : str, optional
587
+ Comma-separated list of target node identifiers to filter by.
588
+ no_prepend_pop : bool, optional
589
+ If True, population names are not prepended to node identifiers in the display.
590
+ edge_property : str
591
+ The edge property to analyze and display in the histograms (e.g., 'syn_weight', 'delay').
592
+ time : int, optional
593
+ Time point to analyze from a time series report.
594
+ time_compare : int, optional
595
+ Second time point for comparison with 'time'.
596
+ report : str, optional
597
+ Name of the report to analyze.
598
+ title : str, optional
599
+ Custom title for the plot. If None, defaults to "{edge_property} Histogram Matrix".
600
+ save_file : str, optional
601
+ Path to save the generated plot.
602
+
603
+ Returns:
604
+ --------
605
+ None
606
+ Displays a matrix of histograms.
495
607
  """
496
608
 
497
609
  if not config:
@@ -655,8 +767,11 @@ def connector_percent_matrix(csv_path: str = None, exclude_strings=None, assemb_
655
767
  filtered_string = match.group(1)
656
768
  if 'Gap' in string:
657
769
  filtered_string = filtered_string + "-Gap"
658
- if assemb_key in string:
659
- filtered_string = filtered_string + assemb_key
770
+
771
+ if assemb_key:
772
+ if assemb_key in string:
773
+ filtered_string = filtered_string + assemb_key
774
+
660
775
  return filtered_string # Return matched string
661
776
 
662
777
  return string # If no match, return the original string
@@ -670,39 +785,40 @@ def connector_percent_matrix(csv_path: str = None, exclude_strings=None, assemb_
670
785
  df = filter_dataframe(df, 'Target', exclude_strings)
671
786
 
672
787
  #process assem rows and combine them into one prob per assem type
673
- assems = df[df['Source'].str.contains(assemb_key)]
674
- unique_sources = assems['Source'].unique()
675
-
676
- for source in unique_sources:
677
- source_assems = assems[assems['Source'] == source]
678
- unique_targets = source_assems['Target'].unique() # Filter targets for the current source
679
-
680
- for target in unique_targets:
681
- # Filter the assemblies with the current source and target
682
- unique_assems = source_assems[source_assems['Target'] == target]
683
-
684
- # find the prob of a conn
685
- forward_probs = []
686
- for _,row in unique_assems.iterrows():
687
- selected_percentage = row[selected_column]
688
- selected_percentage = [float(p) for p in selected_percentage.strip('[]').split()]
689
- if len(selected_percentage) == 1 or len(selected_percentage) == 2:
690
- forward_probs.append(selected_percentage[0])
691
- if len(selected_percentage) == 3:
692
- forward_probs.append(selected_percentage[0])
693
- forward_probs.append(selected_percentage[1])
694
-
695
- mean_probs = np.mean(forward_probs)
696
- source = source.replace(assemb_key, "")
697
- target = target.replace(assemb_key, "")
698
- new_row = pd.DataFrame({
699
- 'Source': [source],
700
- 'Target': [target],
701
- 'Percent connectionivity within possible connections': [mean_probs],
702
- 'Percent connectionivity within all connections': [0]
703
- })
704
-
705
- df = pd.concat([df, new_row], ignore_index=False)
788
+ if assemb_key:
789
+ assems = df[df['Source'].str.contains(assemb_key)]
790
+ unique_sources = assems['Source'].unique()
791
+
792
+ for source in unique_sources:
793
+ source_assems = assems[assems['Source'] == source]
794
+ unique_targets = source_assems['Target'].unique() # Filter targets for the current source
795
+
796
+ for target in unique_targets:
797
+ # Filter the assemblies with the current source and target
798
+ unique_assems = source_assems[source_assems['Target'] == target]
799
+
800
+ # find the prob of a conn
801
+ forward_probs = []
802
+ for _,row in unique_assems.iterrows():
803
+ selected_percentage = row[selected_column]
804
+ selected_percentage = [float(p) for p in selected_percentage.strip('[]').split()]
805
+ if len(selected_percentage) == 1 or len(selected_percentage) == 2:
806
+ forward_probs.append(selected_percentage[0])
807
+ if len(selected_percentage) == 3:
808
+ forward_probs.append(selected_percentage[0])
809
+ forward_probs.append(selected_percentage[1])
810
+
811
+ mean_probs = np.mean(forward_probs)
812
+ source = source.replace(assemb_key, "")
813
+ target = target.replace(assemb_key, "")
814
+ new_row = pd.DataFrame({
815
+ 'Source': [source],
816
+ 'Target': [target],
817
+ 'Percent connectionivity within possible connections': [mean_probs],
818
+ 'Percent connectionivity within all connections': [0]
819
+ })
820
+
821
+ df = pd.concat([df, new_row], ignore_index=False)
706
822
 
707
823
  # Prepare connection data
708
824
  connection_data = {}
@@ -1013,6 +1129,23 @@ def plot_firing_rate_distribution(individual_stats: pd.DataFrame, groupby: Union
1013
1129
  return ax
1014
1130
 
1015
1131
  def plot_entrainment():
1132
+ """
1133
+ Plots entrainment analysis for oscillatory network activity.
1134
+
1135
+ This function analyzes and visualizes how well neural populations entrain to rhythmic
1136
+ input or how synchronized they become during oscillatory activity. It can show phase
1137
+ locking, coherence, or other entrainment metrics.
1138
+
1139
+ Note: This is currently a placeholder function and not yet implemented.
1140
+
1141
+ Parameters:
1142
+ -----------
1143
+ None
1144
+
1145
+ Returns:
1146
+ --------
1147
+ None
1148
+ """
1016
1149
  pass
1017
1150
 
1018
1151
  def plot_3d_positions(config=None, populations_list=None, group_by=None, title=None, save_file=None, subset=None):
@@ -1216,6 +1349,43 @@ def plot_3d_cell_rotation(config=None, populations_list=None, group_by=None, tit
1216
1349
  plt.show()
1217
1350
 
1218
1351
  def plot_network_graph(config=None,nodes=None,edges=None,title=None,sources=None, targets=None, sids=None, tids=None, no_prepend_pop=False,save_file=None,edge_property='model_template'):
1352
+ """
1353
+ Creates a directed graph visualization of the network connectivity using NetworkX.
1354
+
1355
+ This function generates a network diagram showing the connections between different
1356
+ cell populations, with edge labels indicating the connection types based on the specified
1357
+ edge property.
1358
+
1359
+ Parameters:
1360
+ -----------
1361
+ config : str
1362
+ Path to a BMTK simulation configuration file.
1363
+ nodes : dict, optional
1364
+ Dictionary of node information (if already loaded).
1365
+ edges : dict, optional
1366
+ Dictionary of edge information (if already loaded).
1367
+ title : str, optional
1368
+ Custom title for the plot. If None, defaults to "Network Graph".
1369
+ sources : str
1370
+ Comma-separated list of source network names.
1371
+ targets : str
1372
+ Comma-separated list of target network names.
1373
+ sids : str, optional
1374
+ Comma-separated list of source node identifiers to filter by.
1375
+ tids : str, optional
1376
+ Comma-separated list of target node identifiers to filter by.
1377
+ no_prepend_pop : bool, default=False
1378
+ If True, population names are not prepended to node identifiers in the display.
1379
+ save_file : str, optional
1380
+ Path to save the generated plot.
1381
+ edge_property : str, default='model_template'
1382
+ The edge property to use for labeling connections in the graph.
1383
+
1384
+ Returns:
1385
+ --------
1386
+ None
1387
+ Displays a network graph visualization.
1388
+ """
1219
1389
  if not config:
1220
1390
  raise Exception("config not defined")
1221
1391
  if not sources or not targets:
@@ -1308,6 +1478,28 @@ def plot_report(config_file=None, report_file=None, report_name=None, variables=
1308
1478
  plt.show()
1309
1479
 
1310
1480
  def plot_report_default(config, report_name, variables, gids):
1481
+ """
1482
+ A simplified interface for plotting cell report variables from BMTK simulations.
1483
+
1484
+ This function handles the common case of plotting specific variables for specific cells
1485
+ from a BMTK report file, with minimal parameter requirements.
1486
+
1487
+ Parameters:
1488
+ -----------
1489
+ config : str
1490
+ Path to a BMTK simulation configuration file.
1491
+ report_name : str
1492
+ Name of the report to plot (without file extension).
1493
+ variables : str
1494
+ Comma-separated list of variable names to plot (e.g., 'v,i_na,i_k').
1495
+ gids : str
1496
+ Comma-separated list of cell IDs (gids) to plot data for.
1497
+
1498
+ Returns:
1499
+ --------
1500
+ None
1501
+ Displays plots of the specified variables for the specified cells.
1502
+ """
1311
1503
 
1312
1504
  if variables:
1313
1505
  variables = variables.split(',')