bmtool 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bmtool/graphs.py ADDED
@@ -0,0 +1,170 @@
1
+ import networkx as nx
2
+ import plotly.graph_objects as go
3
+ import pandas as pd
4
+ import bmtool.util.util as u
5
+ import pandas as pd
6
+
7
+
8
def generate_graph(config,source,target):
    """
    Build an undirected networkx graph of the connections between two networks.

    config: A BMTK simulation config
    source: network name
    target: network name

    returns a networkx Graph; every node loaded from the node tables carries a
    `pos` attribute (pos_x, pos_y, pos_z) and a `label` attribute (pop_name)

    raises KeyError if a network name or the "<source>_to_<target>" edge set
    is not present in what was loaded from the config
    """
    nodes, edges = u.load_nodes_edges_from_config(config)

    if source != target:
        # Stack both node tables so endpoints from either network get attributes.
        # NOTE(review): node ids are used directly as graph keys, so this
        # presumably relies on ids not colliding across the two networks - confirm.
        node_table = pd.concat([nodes[source], nodes[target]])
    else:
        node_table = nodes[source]

    edge_key = source + "_to_" + target
    try:
        edge_table = edges[edge_key]
    except KeyError as err:
        raise KeyError(f"no edge set named '{edge_key}' was loaded from the config") from err

    G = nx.Graph()

    # Iterate the columns directly instead of iterrows(): it is much faster and
    # keeps each column's own dtype rather than upcasting a whole row.
    node_columns = zip(
        node_table.index,
        node_table['pos_x'],
        node_table['pos_y'],
        node_table['pos_z'],
        node_table['pop_name'],
    )
    for node_id, x, y, z, pop_name in node_columns:
        G.add_node(node_id, pos=(x, y, z), label=pop_name)

    # Bulk-insert the edges; ids stay integers so they match the node keys above.
    G.add_edges_from(zip(edge_table['source_node_id'], edge_table['target_node_id']))

    return G
38
+
39
+
40
def plot_graph(Graph=None,show_edges = False,title = None):
    """
    Generate an interactive 3D plot of the network

    Graph: A Graph object whose nodes carry `pos` (x, y, z) and `label` attributes
    show_edges: Boolean to show edges in graph plot
    title: A string for the title of the graph (defaults to '3D plot')

    Nodes are coloured by their number of neighbours. The figure is displayed
    with fig.show(); nothing is returned.

    raises ValueError if no Graph is supplied
    """
    if Graph is None:
        raise ValueError("plot_graph requires a Graph object")

    # Extract node positions
    node_positions = nx.get_node_attributes(Graph, 'pos')
    node_x = [pos[0] for pos in node_positions.values()]
    node_y = [pos[1] for pos in node_positions.values()]
    node_z = [pos[2] for pos in node_positions.values()]

    # Neighbour count per node, computed once and reused for the marker colour,
    # the colour-scale maximum and the hover text.
    neighbor_counts = [len(list(Graph.neighbors(node))) for node in Graph.nodes()]

    traces = []
    if show_edges:
        # Each edge is a two-point segment; the trailing None breaks the line so
        # consecutive edges are not joined together.
        edge_x = []
        edge_y = []
        edge_z = []
        for start, end in Graph.edges():
            x0, y0, z0 = node_positions[start]
            x1, y1, z1 = node_positions[end]
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])
            edge_z.extend([z0, z1, None])

        traces.append(go.Scatter3d(
            x=edge_x,
            y=edge_y,
            z=edge_z,
            line=dict(width=1, color='#888'),
            hoverinfo='none',
            mode='lines',
            opacity=0.2))

    node_trace = go.Scatter3d(
        x=node_x,
        y=node_y,
        z=node_z,
        mode='markers',
        hoverinfo='text',
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            reversescale=True,
            color=neighbor_counts,
            size=5,
            colorbar=dict(
                thickness=15,
                # Nested title form replaces the deprecated `titleside` shortcut.
                title=dict(text='Node Connections', side='right'),
                xanchor='left'
            ),
            line_width=2,
            cmin=0,
            # default=0 keeps an empty graph from raising on max([]).
            cmax=max(neighbor_counts, default=0)
        ))

    # Define hover text for nodes
    node_trace.hovertext = [
        f'Node ID: {node_id}<br>Population Name: {node_data["label"]}<br># of Connections: {count}'
        for (node_id, node_data), count in zip(Graph.nodes(data=True), neighbor_counts)
    ]
    traces.append(node_trace)

    if title is None:
        title = '3D plot'

    hidden_axis = dict(showgrid=False, zeroline=False, showticklabels=False)
    fig = go.Figure(data=traces,
                    layout=go.Layout(
                        # Nested title form replaces the deprecated `titlefont_size`.
                        title=dict(text=title, font=dict(size=16)),
                        showlegend=False,
                        hovermode='closest',
                        margin=dict(b=20, l=5, r=5, t=40),
                        scene=dict(
                            xaxis=dict(hidden_axis),
                            yaxis=dict(hidden_axis),
                            zaxis=dict(hidden_axis)
                        ),
                        width=800,
                        height=800
                    ))

    # Show figure
    fig.show()
131
+
132
+
133
def export_node_connections_to_csv(Graph, filename):
    """
    Generates a csv file with one row per node: the node's label followed by
    how many of its neighbours carry each label

    Graph: a Graph object whose nodes have a `label` attribute
    filename: A string for the name of output must end in .csv

    An empty graph produces a file containing only the 'Node Label' header.
    """
    node_connections = {}

    for node in Graph.nodes():
        # Tally this node's neighbours by their label.
        connections = {}
        for neighbor in Graph.neighbors(node):
            column = f"{Graph.nodes[neighbor]['label']} Connections"
            connections[column] = connections.get(column, 0) + 1

        connections['Node Label'] = Graph.nodes[node]['label']
        node_connections[node] = connections

    # Nodes become rows after the transpose; a label a node never touches is 0.
    df = pd.DataFrame(node_connections).fillna(0).T

    # Put 'Node Label' first. reindex (rather than df[cols]) also creates the
    # column when the graph had no nodes, instead of raising KeyError.
    cols = ['Node Label'] + [col for col in df.columns.tolist() if col != 'Node Label']
    df = df.reindex(columns=cols)

    df.to_csv(filename)
170
+
bmtool/singlecell.py CHANGED
@@ -217,8 +217,8 @@ class Passive(CurrentClamp):
217
217
  self.v_final_time = self.t_vec[self.index_v_final]
218
218
 
219
219
  t_idx = slice(self.index_v_rest, self.index_v_final + 1)
220
- self.v_vec_inj = self.v_vec.as_numpy()[t_idx].copy()
221
- self.t_vec_inj = self.t_vec.as_numpy()[t_idx].copy() - self.v_rest_time
220
+ self.v_vec_inj = np.array(self.v_vec)[t_idx]
221
+ self.t_vec_inj = np.array(self.t_vec)[t_idx] - self.v_rest_time
222
222
 
223
223
  self.v_diff = self.cell_v_final - self.v_rest
224
224
  self.r_in = self.v_diff / self.inj_amp # MegaOhms
@@ -388,8 +388,8 @@ class ZAP(CurrentClamp):
388
388
  self.v_rest_time = self.t_vec[self.index_v_rest]
389
389
 
390
390
  t_idx = slice(self.index_v_rest, self.index_v_final + 1)
391
- self.v_vec_inj = self.v_vec.as_numpy()[t_idx].copy() - self.v_rest
392
- self.t_vec_inj = self.t_vec.as_numpy()[t_idx].copy() - self.v_rest_time
391
+ self.v_vec_inj = np.array(self.v_vec)[t_idx] - self.v_rest
392
+ self.t_vec_inj = np.array(self.t_vec)[t_idx] - self.v_rest_time
393
393
 
394
394
  self.cell_v_amp_max = np.abs(self.v_vec_inj).max()
395
395
  self.Z = np.fft.rfft(self.v_vec_inj) / np.fft.rfft(self.zap_vec_inj) # MOhms
bmtool/util/util.py CHANGED
@@ -258,18 +258,20 @@ def load_nodes_from_paths(node_paths):
258
258
  for group_id in range(n_group):
259
259
  group = nodes_grp[str(group_id)]
260
260
  idx = node_group_id == group_id
261
+ group_node = node_id[idx]
262
+ group_index = node_group_index[idx]
261
263
  for prop in group:
262
264
  if prop == 'positions':
263
- positions = group[prop][node_group_index[idx]]
265
+ positions = group[prop][group_index]
264
266
  for i in range(positions.shape[1]):
265
267
  if pos_labels[i] not in nodes_df:
266
268
  nodes_df[pos_labels[i]] = np.nan
267
- nodes_df.loc[node_id, pos_labels[i]] = positions[:, i]
269
+ nodes_df.loc[group_node, pos_labels[i]] = positions[:, i]
268
270
  else:
269
271
  # create new column with NaN if property does not exist
270
272
  if prop not in nodes_df:
271
273
  nodes_df[prop] = np.nan
272
- nodes_df.loc[idx, prop] = tuple(group[prop][node_group_index[idx]])
274
+ nodes_df.loc[group_node, prop] = group[prop][group_index]
273
275
  prop_dtype[prop] = group[prop].dtype
274
276
  # convert to original data type if possible
275
277
  for prop, dtype in prop_dtype.items():
@@ -566,7 +568,10 @@ def relation_matrix(config=None, nodes=None,edges=None,sources=[],targets=[],sid
566
568
 
567
569
  total = relation_func(source_nodes=source_nodes, target_nodes=target_nodes, edges=c_edges, source=source,sid="source_"+sids[s], target=target,tid="target_"+tids[t],source_id=s_type,target_id=t_type)
568
570
  if synaptic_info=='0':
569
- syn_info[source_index,target_index] = total
571
+ if isinstance(total, tuple):
572
+ syn_info[source_index, target_index] = str(round(total[0], 1)) + '\n' + str(round(total[1], 1))
573
+ else:
574
+ syn_info[source_index,target_index] = total
570
575
  elif synaptic_info=='1':
571
576
  mean = conn_mean_func(source_nodes=source_nodes, target_nodes=target_nodes, edges=c_edges, source=source,sid="source_"+sids[s], target=target,tid="target_"+tids[t],source_id=s_type,target_id=t_type)
572
577
  stdev = conn_stdev_func(source_nodes=source_nodes, target_nodes=target_nodes, edges=c_edges, source=source,sid="source_"+sids[s], target=target,tid="target_"+tids[t],source_id=s_type,target_id=t_type)
@@ -587,8 +592,11 @@ def relation_matrix(config=None, nodes=None,edges=None,sources=[],targets=[],sid
587
592
  syn_info[source_index,target_index] = ""
588
593
  else:
589
594
  syn_info[source_index,target_index] = syn_list
595
+ if isinstance(total, tuple):
596
+ e_matrix[source_index,target_index]=total[0]
597
+ else:
598
+ e_matrix[source_index,target_index]=total
590
599
 
591
- e_matrix[source_index,target_index]=total
592
600
 
593
601
  return syn_info, e_matrix, source_pop_names, target_pop_names
594
602
 
@@ -609,7 +617,9 @@ def connection_totals(config=None,nodes=None,edges=None,sources=[],targets=[],si
609
617
  return total
610
618
  return relation_matrix(config,nodes,edges,sources,targets,sids,tids,prepend_pop,relation_func=total_connection_relationship,synaptic_info=synaptic_info)
611
619
 
612
- def percent_connections(config=None,nodes=None,edges=None,sources=[],targets=[],sids=[],tids=[],prepend_pop=True,method=None,include_gap=True):
620
+
621
+ def percent_connections(config=None,nodes=None,edges=None,sources=[],targets=[],sids=[],tids=[],prepend_pop=True,type='convergence',method=None,include_gap=True):
622
+
613
623
 
614
624
  def precent_func(**kwargs):
615
625
  edges = kwargs["edges"]
@@ -638,8 +648,12 @@ def percent_connections(config=None,nodes=None,edges=None,sources=[],targets=[],
638
648
  num_bi = (cons_recip.count().source_node_id - cons_recip_dedup.count().source_node_id)
639
649
  num_uni = total_cons - num_bi
640
650
 
641
- num_sources = s_list.apply(pd.Series.value_counts)[source_id_type].dropna().sort_index().loc[source_id]
642
- num_targets = t_list.apply(pd.Series.value_counts)[target_id_type].dropna().sort_index().loc[target_id]
651
+ #num_sources = s_list.apply(pd.Series.value_counts)[source_id_type].dropna().sort_index().loc[source_id]
652
+ #num_targets = t_list.apply(pd.Series.value_counts)[target_id_type].dropna().sort_index().loc[target_id]
653
+
654
+ num_sources = s_list[source_id_type].value_counts().sort_index().loc[source_id]
655
+ num_targets = t_list[target_id_type].value_counts().sort_index().loc[target_id]
656
+
643
657
 
644
658
  total = round(total_cons / (num_sources*num_targets) * 100,2)
645
659
  uni = round(num_uni / (num_sources*num_targets) * 100,2)
@@ -654,7 +668,8 @@ def percent_connections(config=None,nodes=None,edges=None,sources=[],targets=[],
654
668
 
655
669
  return relation_matrix(config,nodes,edges,sources,targets,sids,tids,prepend_pop,relation_func=precent_func)
656
670
 
657
- def connection_divergence(config=None,nodes=None,edges=None,sources=[],targets=[],sids=[],tids=[],prepend_pop=True,convergence=False,method='mean',include_gap=True):
671
+
672
+ def connection_divergence(config=None,nodes=None,edges=None,sources=[],targets=[],sids=[],tids=[],prepend_pop=True,convergence=False,method='mean+std',include_gap=True):
658
673
 
659
674
  import pandas as pd
660
675
 
@@ -674,47 +689,97 @@ def connection_divergence(config=None,nodes=None,edges=None,sources=[],targets=[
674
689
 
675
690
  if convergence:
676
691
  if method == 'min':
677
- count = cons.apply(pd.Series.value_counts).target_node_id.dropna().min()
678
- return count
692
+ count = cons['target_node_id'].value_counts().min()
693
+ return round(count,2)
679
694
  elif method == 'max':
680
- count = cons.apply(pd.Series.value_counts).target_node_id.dropna().max()
681
- return count
695
+ count = cons['target_node_id'].value_counts().max()
696
+ return round(count,2)
682
697
  elif method == 'std':
683
- std = cons.apply(pd.Series.value_counts).target_node_id.dropna().std()
698
+ std = cons['target_node_id'].value_counts().std()
684
699
  return round(std,2)
685
- else: #mean
686
- vc = t_list.apply(pd.Series.value_counts)
687
- vc = vc[target_id_type].dropna().sort_index()
688
- count = vc.loc[target_id]#t_list[t_list[target_id_type]==target_id]
700
+ elif method == 'mean':
701
+ mean = cons['target_node_id'].value_counts().mean()
702
+ return round(mean,2)
703
+ elif method == 'mean+std': #default is mean + std
704
+ mean = cons['target_node_id'].value_counts().mean()
705
+ std = cons['target_node_id'].value_counts().std()
706
+ #std = cons.apply(pd.Series.value_counts).target_node_id.dropna().std() no longer a valid way
707
+ return (round(mean,2)), (round(std,2))
689
708
  else: #divergence
690
709
  if method == 'min':
691
- count = cons.apply(pd.Series.value_counts).source_node_id.dropna().min()
692
- return count
710
+ count = cons['source_node_id'].value_counts().min()
711
+ return round(count,2)
693
712
  elif method == 'max':
694
- count = cons.apply(pd.Series.value_counts).source_node_id.dropna().max()
695
- return count
713
+ count = cons['source_node_id'].value_counts().max()
714
+ return round(count,2)
696
715
  elif method == 'std':
697
- std = cons.apply(pd.Series.value_counts).source_node_id.dropna().std()
716
+ std = cons['source_node_id'].value_counts().std()
698
717
  return round(std,2)
699
- else: #mean
700
- #vc = s_list.apply(pd.Series.value_counts)[source_id_type].dropna().sort_index().loc[source_id]
701
- vc = s_list.apply(pd.Series.value_counts)
702
- vc = vc[source_id_type].dropna().sort_index()
703
- count = vc.loc[source_id]#count = s_list[s_list[source_id_type]==source_id]
704
-
705
- # Only executed when mean
706
- total = edges[(edges[source_id_type] == source_id) & (edges[target_id_type]==target_id)]
707
- if include_gap == False:
708
- total = total[total['is_gap_junction'] != True]
709
- total = total.count()
710
- total = total.source_node_id # may not be the best way to pick
711
- ret = round(total/count,1)
712
- if ret == 0:
713
- ret = None
714
- return ret
718
+ elif method == 'mean':
719
+ mean = cons['source_node_id'].value_counts().mean()
720
+ return round(mean,2)
721
+ elif method == 'mean+std': #default is mean + std
722
+ mean = cons['source_node_id'].value_counts().mean()
723
+ std = cons['source_node_id'].value_counts().std()
724
+ return (round(mean,2)), (round(std,2))
715
725
 
716
726
  return relation_matrix(config,nodes,edges,sources,targets,sids,tids,prepend_pop,relation_func=total_connection_relationship)
717
727
 
728
def gap_junction_connections(config=None,nodes=None,edges=None,sources=[],targets=[],sids=[],tids=[],prepend_pop=True,type='convergence'):
    """
    Build a relation matrix restricted to gap-junction edges.

    config, nodes, edges, sources, targets, sids, tids, prepend_pop: forwarded
        unchanged to relation_matrix
    type: 'convergence' -> each cell is (mean, std) of the number of gap-junction
              edges per target node, both rounded to 2 decimals
          'percent' -> each cell is the gap-junction edge count divided by
              (number of sources * number of targets), times 100, rounded to
              2 decimals

    returns whatever relation_matrix returns for the chosen relation function

    raises ValueError if `type` is not one of the two supported values
    """

    def _gap_junction_edges(kwargs):
        # Edges between the requested source/target groups that are gap junctions.
        edge_table = kwargs["edges"]
        in_pair = (edge_table[kwargs["sid"]] == kwargs["source_id"]) & (edge_table[kwargs["tid"]] == kwargs["target_id"])
        cons = edge_table[in_pair]
        # Element-wise comparison on a column, so `== True` is intentional here.
        return cons[cons['is_gap_junction'] == True]  # noqa: E712

    def convergence_relationship(**kwargs):
        # Number of gap-junction edges landing on each target node.
        per_target = _gap_junction_edges(kwargs)['target_node_id'].value_counts()
        return (round(per_target.mean(), 2)), (round(per_target.std(), 2))

    def percent_relationship(**kwargs):
        total_cons = _gap_junction_edges(kwargs).count().source_node_id

        source_id_type = kwargs["sid"]
        target_id_type = kwargs["tid"]
        num_sources = kwargs["source_nodes"][source_id_type].value_counts().sort_index().loc[kwargs["source_id"]]
        num_targets = kwargs["target_nodes"][target_id_type].value_counts().sort_index().loc[kwargs["target_id"]]

        return round(total_cons / (num_sources*num_targets) * 100,2)

    if type == 'convergence':
        relation_func = convergence_relationship
    elif type == 'percent':
        relation_func = percent_relationship
    else:
        # Previously an unknown value fell through and silently returned None.
        raise ValueError(f"type must be 'convergence' or 'percent', got {type!r}")

    return relation_matrix(config,nodes,edges,sources,targets,sids,tids,prepend_pop,relation_func=relation_func)
+
773
+
774
def gap_junction_percent_connections(
        config=None,
        nodes=None,
        edges=None,
        sources=[],
        targets=[],
        sids=[],
        tids=[],
        prepend_pop=True,
        method=None):
    """
    Placeholder: no computation is implemented yet.

    All arguments are accepted but unused; the function always returns None.
    """
    import pandas as pd
776
+
777
+
778
+
779
+
780
+
781
+
782
+
718
783
  def connection_probabilities(config=None,nodes=None,edges=None,sources=[],
719
784
  targets=[],sids=[],tids=[],prepend_pop=True,dist_X=True,dist_Y=True,dist_Z=True,num_bins=10,include_gap=True):
720
785
 
@@ -753,11 +818,11 @@ def connection_probabilities(config=None,nodes=None,edges=None,sources=[],
753
818
  def eudist(df,use_x=True,use_y=True,use_z=True):
754
819
  def _dist(x):
755
820
  if len(x) == 6:
756
- return distance.euclidean((x[0],x[1],x[2]),(x[3],x[4],x[5]))
821
+ return distance.euclidean((x.iloc[0], x.iloc[1], x.iloc[2]), (x.iloc[3], x.iloc[4], x.iloc[5]))
757
822
  elif len(x) == 4:
758
- return distance.euclidean((x[0],x[1]),(x[2],x[3]))
823
+ return distance.euclidean((x.iloc[0],x.iloc[1]),(x.iloc[2],x.iloc[3]))
759
824
  elif len(x) == 2:
760
- return distance.euclidean((x[0]),(x[1]))
825
+ return distance.euclidean((x.iloc[0]),(x.iloc[1]))
761
826
  else:
762
827
  return -1
763
828
 
@@ -835,6 +900,7 @@ def connection_graph_edge_types(config=None,nodes=None,edges=None,sources=[],tar
835
900
 
836
901
  return relation_matrix(config,nodes,edges,sources,targets,sids,tids,prepend_pop,relation_func=synapse_type_relationship,return_type=object)
837
902
 
903
+
838
904
  def edge_property_matrix(edge_property, config=None, nodes=None, edges=None, sources=[],targets=[],sids=[],tids=[],prepend_pop=True,report=None,time=-1,time_compare=None):
839
905
 
840
906
  var_report = None
@@ -906,10 +972,11 @@ def percent_connectivity(config=None,nodes=None,edges=None,sources=[],targets=[]
906
972
 
907
973
  return ret, source_labels, target_labels
908
974
 
909
-
975
+
910
976
def connection_average_synapses():
    """Placeholder with no implementation; always returns None."""
    return None
912
978
 
979
+
913
980
  def connection_divergence_average_old(config=None, nodes=None, edges=None,populations=[],convergence=False):
914
981
  """
915
982
  For each cell in source count # of connections in target and average