bmtool 0.5.4__py3-none-any.whl → 0.5.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bmtool/connectors.py CHANGED
@@ -5,6 +5,8 @@ from scipy.optimize import minimize_scalar
5
5
  from functools import partial
6
6
  import time
7
7
  import types
8
+ import pandas as pd
9
+ import re
8
10
 
9
11
  rng = np.random.default_rng()
10
12
 
@@ -220,7 +222,7 @@ class GaussianDropoff(DistantDependentProbability):
220
222
  "Probability crosses 1 at distance %.3g.\n") % (pmax, d)
221
223
  if self.ptotal is not None:
222
224
  warn += " ptotal may not be reached."
223
- print(warn)
225
+ print(warn,flush=True)
224
226
  self.probability = lambda dist: np.fmin(probability(dist), 1.)
225
227
  else:
226
228
  self.probability = probability
@@ -333,7 +335,7 @@ class Timer(object):
333
335
  return (time.perf_counter() - self._start) * self.scale
334
336
 
335
337
  def report(self, msg='Run time'):
336
- print((msg + ": %.3f " + self.unit) % self.end())
338
+ print((msg + ": %.3f " + self.unit) % self.end(),flush=True)
337
339
 
338
340
 
339
341
  def pr_2_rho(p0, p1, pr):
@@ -353,7 +355,7 @@ def rho_2_pr(p0, p1, rho):
353
355
  pr0, pr = pr, np.max((0., p0 + p1 - 1, np.min((p0, p1, pr))))
354
356
  rho0, rho = rho, (pr - p0 * p1) / (p0 * (1 - p0) * p1 * (1 - p1)) ** .5
355
357
  print('rho changed from %.3f to %.3f; pr changed from %.3f to %.3f'
356
- % (rho0, rho, pr0, pr))
358
+ % (rho0, rho, pr0, pr),flush=True)
357
359
  return pr
358
360
 
359
361
 
@@ -532,7 +534,7 @@ class ReciprocalConnector(AbstractConnector):
532
534
  pr=0., pr_arg=None, estimate_rho=True, rho=None,
533
535
  dist_range_forward=None, dist_range_backward=None,
534
536
  n_syn0=1, n_syn1=1, autapses=False,
535
- quick_pop_check=False, cache_data=True, verbose=True):
537
+ quick_pop_check=False, cache_data=True, verbose=True,save_report=True,report_name='connection_report.csv'):
536
538
  args = locals()
537
539
  var_set = ('p0', 'p0_arg', 'p1', 'p1_arg',
538
540
  'pr', 'pr_arg', 'n_syn0', 'n_syn1')
@@ -550,6 +552,8 @@ class ReciprocalConnector(AbstractConnector):
550
552
  self.quick = quick_pop_check
551
553
  self.cache = self.ConnectorCache(cache_data and self.estimate_rho)
552
554
  self.verbose = verbose
555
+ self.save_report = save_report
556
+ self.report_name = report_name
553
557
 
554
558
  self.conn_prop = [{}, {}]
555
559
  self.stage = 0
@@ -664,12 +668,12 @@ class ReciprocalConnector(AbstractConnector):
664
668
  fetch = out_len > 0
665
669
  if not fetch:
666
670
  print("\nWarning: Cache did not work properly for "
667
- + func_name + '\n')
671
+ + func_name + '\n',flush=True)
668
672
  self.fetch_output(func_name, fetch)
669
673
  self.iter_count = 0
670
674
  else:
671
675
  # if output not correct, disable and use original function
672
- print("\nWarning: Cache did not work properly.\n")
676
+ print("\nWarning: Cache did not work properly.\n",flush=True)
673
677
  for func_name in self.cache_dict:
674
678
  self.fetch_output(func_name, False)
675
679
  self.enable = False
@@ -828,7 +832,7 @@ class ReciprocalConnector(AbstractConnector):
828
832
  self.cache.cache_output(var, name, name in self.callable_set)
829
833
  if self.verbose and len(self.cache.cache_dict):
830
834
  print('Output of %s will be cached.'
831
- % ', '.join(self.cache.cache_dict))
835
+ % ', '.join(self.cache.cache_dict),flush=True)
832
836
 
833
837
  def setup_dist_range_checker(self):
834
838
  # Checker that determines whether to consider a pair for rho estimation
@@ -864,7 +868,7 @@ class ReciprocalConnector(AbstractConnector):
864
868
  if self.verbose:
865
869
  src_str, trg_str = self.get_nodes_info()
866
870
  print("\nStart building connection between: \n "
867
- + src_str + "\n " + trg_str)
871
+ + src_str + "\n " + trg_str,flush=True)
868
872
  self.initialize()
869
873
  cache = self.cache # write mode
870
874
 
@@ -889,11 +893,11 @@ class ReciprocalConnector(AbstractConnector):
889
893
  rho = (self.pr() * n - p0p1_sum) / norm_fac_sum
890
894
  if abs(rho) > 1:
891
895
  print("\nWarning: Estimated value of rho=%.3f "
892
- "outside the range [-1, 1]." % rho)
896
+ "outside the range [-1, 1]." % rho,flush=True)
893
897
  rho = np.clip(rho, -1, 1).item()
894
- print("Force rho to be %.0f.\n" % rho)
898
+ print("Force rho to be %.0f.\n" % rho,flush=True)
895
899
  elif self.verbose:
896
- print("Estimated value of rho=%.3f" % rho)
900
+ print("Estimated value of rho=%.3f" % rho,flush=True)
897
901
  self.rho = rho
898
902
  else:
899
903
  self.rho = 0
@@ -945,8 +949,10 @@ class ReciprocalConnector(AbstractConnector):
945
949
  if self.verbose:
946
950
  self.timer.report('Total time for creating connection matrix')
947
951
  if self.wrong_pr:
948
- print("Warning: Value of 'pr' outside the bounds occurred.\n")
952
+ print("Warning: Value of 'pr' outside the bounds occurred.\n",flush=True)
949
953
  self.connection_number_info()
954
+ if self.save_report:
955
+ self.save_connection_report()
950
956
 
951
957
  def make_connection(self):
952
958
  """ Assign number of synapses per iteration.
@@ -971,7 +977,7 @@ class ReciprocalConnector(AbstractConnector):
971
977
  self.stage = 0
972
978
  self.initial_all_to_all()
973
979
  if self.verbose:
974
- print("Assigning forward connections.")
980
+ print("Assigning forward connections.",flush=True)
975
981
  self.timer.start()
976
982
  return self.make_connection()
977
983
 
@@ -980,7 +986,7 @@ class ReciprocalConnector(AbstractConnector):
980
986
  if self.iter_count == 0:
981
987
  self.stage = 1
982
988
  if self.verbose:
983
- print("Assigning backward connections.")
989
+ print("Assigning backward connections.",flush=True)
984
990
  return self.make_connection()
985
991
 
986
992
  def free_memory(self):
@@ -1033,15 +1039,39 @@ class ReciprocalConnector(AbstractConnector):
1033
1039
  n_conn, n_poss, n_pair, fraction = self.connection_number()
1034
1040
  conn_type = "(all, reciprocal)" if self.recurrent \
1035
1041
  else "(forward, backward, reciprocal)"
1036
- print("Numbers of " + conn_type + " connections:")
1037
- print("Number of connected pairs: (%s)" % arr2str(n_conn, '%d'))
1038
- print("Number of possible connections: (%s)" % arr2str(n_poss, '%d'))
1042
+ print("Numbers of " + conn_type + " connections:",flush=True)
1043
+ print("Number of connected pairs: (%s)" % arr2str(n_conn, '%d'),flush=True)
1044
+ print("Number of possible connections: (%s)" % arr2str(n_poss, '%d'),flush=True)
1039
1045
  print("Fraction of connected pairs in possible ones: (%s)"
1040
- % arr2str(100 * fraction[0], '%.2f%%'))
1041
- print("Number of total pairs: %d" % n_pair)
1046
+ % arr2str(100 * fraction[0], '%.2f%%'),flush=True)
1047
+ print("Number of total pairs: %d" % n_pair,flush=True)
1042
1048
  print("Fraction of connected pairs in all pairs: (%s)\n"
1043
- % arr2str(100 * fraction[1], '%.2f%%'))
1049
+ % arr2str(100 * fraction[1], '%.2f%%'),flush=True)
1044
1050
 
1051
def save_connection_report(self):
    """Append a one-row summary of this connection to the CSV report.

    The row holds the source/target descriptions from ``get_nodes_info()``
    and the two fractions from ``connection_number()`` converted to
    percentages. The report ``self.report_name`` is opened in append mode;
    the header row is written only when the file is new or empty, so
    repeated calls accumulate a single table.
    """
    src_str, trg_str = self.get_nodes_info()
    # Only the fractions are reported; the raw counts are not needed here.
    _, _, _, fraction = self.connection_number()

    data = {
        "Source": [src_str],
        "Target": [trg_str],
        "Fraction of connected pairs in possible ones (%)": [fraction[0] * 100],
        "Fraction of connected pairs in all pairs (%)": [fraction[1] * 100],
    }
    df = pd.DataFrame(data)

    # Decide on the header from the file position instead of parsing the
    # whole existing CSV: this avoids re-reading an ever-growing report and
    # also handles an existing empty file (pd.read_csv would raise
    # EmptyDataError there).
    with open(self.report_name, "a", newline="", encoding="utf-8") as report_file:
        write_header = report_file.tell() == 0
        df.to_csv(report_file, header=write_header, index=False)
1074
+
1045
1075
 
1046
1076
  class UnidirectionConnector(AbstractConnector):
1047
1077
  """
@@ -1074,12 +1104,14 @@ class UnidirectionConnector(AbstractConnector):
1074
1104
  This is useful in similar manner as in ReciprocalConnector.
1075
1105
  """
1076
1106
 
1077
- def __init__(self, p=1., p_arg=None, n_syn=1, verbose=True):
1107
+ def __init__(self, p=1., p_arg=None, n_syn=1, verbose=True,save_report=True,report_name='connection_report.csv'):
1078
1108
  args = locals()
1079
1109
  var_set = ('p', 'p_arg', 'n_syn')
1080
1110
  self.vars = {key: args[key] for key in var_set}
1081
1111
 
1082
1112
  self.verbose = verbose
1113
+ self.save_report = save_report
1114
+ self.report_name = report_name
1083
1115
  self.conn_prop = {}
1084
1116
  self.iter_count = 0
1085
1117
 
@@ -1136,7 +1168,7 @@ class UnidirectionConnector(AbstractConnector):
1136
1168
  if self.verbose:
1137
1169
  src_str, trg_str = self.get_nodes_info()
1138
1170
  print("\nStart building connection \n from "
1139
- + src_str + "\n to " + trg_str)
1171
+ + src_str + "\n to " + trg_str,flush=True)
1140
1172
 
1141
1173
  # Make random connections
1142
1174
  p_arg = self.p_arg(source, target)
@@ -1157,6 +1189,9 @@ class UnidirectionConnector(AbstractConnector):
1157
1189
  if self.verbose:
1158
1190
  self.connection_number_info()
1159
1191
  self.timer.report('Done! \nTime for building connections')
1192
+ if self.save_report:
1193
+ self.save_connection_report()
1194
+
1160
1195
  return nsyns
1161
1196
 
1162
1197
  # *** Helper functions for verbose ***
@@ -1168,13 +1203,37 @@ class UnidirectionConnector(AbstractConnector):
1168
1203
 
1169
1204
  def connection_number_info(self):
1170
1205
  """Print connection numbers after connections built"""
1171
- print("Number of connected pairs: %d" % self.n_conn)
1172
- print("Number of possible connections: %d" % self.n_poss)
1206
+ print("Number of connected pairs: %d" % self.n_conn,flush=True)
1207
+ print("Number of possible connections: %d" % self.n_poss,flush=True)
1173
1208
  print("Fraction of connected pairs in possible ones: %.2f%%"
1174
1209
  % (100. * self.n_conn / self.n_poss) if self.n_poss else 0.)
1175
- print("Number of total pairs: %d" % self.n_pair)
1210
+ print("Number of total pairs: %d" % self.n_pair,flush=True)
1176
1211
  print("Fraction of connected pairs in all pairs: %.2f%%\n"
1177
- % (100. * self.n_conn / self.n_pair))
1212
+ % (100. * self.n_conn / self.n_pair),flush=True)
1213
+
1214
def save_connection_report(self):
    """Append a one-row summary of this connection to the CSV report.

    Percentages are derived from the ``n_conn``, ``n_poss`` and ``n_pair``
    counters (the same values ``connection_number_info`` prints); a zero
    denominator is reported as 0 instead of raising. The header row is
    written only when ``self.report_name`` is new or empty.
    """
    src_str, trg_str = self.get_nodes_info()
    # Compute from the counters directly, like GapJunction does.
    # NOTE(review): connection_number() is only shown on ReciprocalConnector,
    # so calling it here is presumed to fail for this class.
    fraction_0 = self.n_conn / self.n_poss if self.n_poss else 0.
    fraction_1 = self.n_conn / self.n_pair if self.n_pair else 0.

    data = {
        "Source": [src_str],
        "Target": [trg_str],
        "Fraction of connected pairs in possible ones (%)": [fraction_0 * 100],
        "Fraction of connected pairs in all pairs (%)": [fraction_1 * 100],
    }
    df = pd.DataFrame(data)

    # Header only for a new/empty file; no need to parse the existing CSV.
    with open(self.report_name, "a", newline="", encoding="utf-8") as report_file:
        write_header = report_file.tell() == 0
        df.to_csv(report_file, header=write_header, index=False)
1178
1237
 
1179
1238
 
1180
1239
  class GapJunction(UnidirectionConnector):
@@ -1198,8 +1257,8 @@ class GapJunction(UnidirectionConnector):
1198
1257
  Similar to `UnidirectionConnector`.
1199
1258
  """
1200
1259
 
1201
def __init__(self, p=1., p_arg=None, verbose=True,
             report_name='connection_report.csv', save_report=True):
    """Configure the gap-junction connector.

    Parameters are forwarded to ``UnidirectionConnector``. ``save_report``
    is appended after the existing parameters (so positional callers are
    unaffected) and forwarded, allowing report writing to be disabled for
    gap junctions as it can be for the parent connector.
    """
    super().__init__(p=p, p_arg=p_arg, verbose=verbose,
                     save_report=save_report, report_name=report_name)
1203
1262
 
1204
1263
  def setup_nodes(self, source=None, target=None):
1205
1264
  super().setup_nodes(source=source, target=target)
@@ -1215,7 +1274,7 @@ class GapJunction(UnidirectionConnector):
1215
1274
  self.initialize()
1216
1275
  if self.verbose:
1217
1276
  src_str, _ = self.get_nodes_info()
1218
- print("\nStart building gap junction \n in " + src_str)
1277
+ print("\nStart building gap junction \n in " + src_str,flush=True)
1219
1278
 
1220
1279
  # Consider each pair only once
1221
1280
  nsyns = 0
@@ -1239,6 +1298,8 @@ class GapJunction(UnidirectionConnector):
1239
1298
  if self.verbose:
1240
1299
  self.connection_number_info()
1241
1300
  self.timer.report('Done! \nTime for building connections')
1301
+ if self.save_report:
1302
+ self.save_connection_report()
1242
1303
  return nsyns
1243
1304
 
1244
1305
  def connection_number_info(self):
@@ -1247,6 +1308,32 @@ class GapJunction(UnidirectionConnector):
1247
1308
  super().connection_number_info()
1248
1309
  self.n_pair = n_pair
1249
1310
 
1311
def save_connection_report(self):
    """Append a one-row gap-junction summary to the CSV report.

    Source/Target get a ``"Gap"`` suffix to distinguish them from chemical
    synapse rows. Percentages come from the ``n_conn``, ``n_poss`` and
    ``n_pair`` counters; a zero denominator is reported as 0 instead of
    raising. The header row is written only when the file is new or empty.
    """
    src_str, trg_str = self.get_nodes_info()
    fraction_0 = self.n_conn / self.n_poss if self.n_poss else 0.
    fraction_1 = self.n_conn / self.n_pair if self.n_pair else 0.

    data = {
        "Source": [src_str + "Gap"],
        "Target": [trg_str + "Gap"],
        "Fraction of connected pairs in possible ones (%)": [fraction_0 * 100],
        "Fraction of connected pairs in all pairs (%)": [fraction_1 * 100],
    }
    df = pd.DataFrame(data)

    # Header only for a new/empty file; no need to parse the existing CSV.
    with open(self.report_name, "a", newline="", encoding="utf-8") as report_file:
        write_header = report_file.tell() == 0
        df.to_csv(report_file, header=write_header, index=False)
1336
+
1250
1337
 
1251
1338
  class CorrelatedGapJunction(GapJunction):
1252
1339
  """
@@ -1314,7 +1401,7 @@ class CorrelatedGapJunction(GapJunction):
1314
1401
  self.initialize()
1315
1402
  if self.verbose:
1316
1403
  src_str, _ = self.get_nodes_info()
1317
- print("\nStart building gap junction \n in " + src_str)
1404
+ print("\nStart building gap junction \n in " + src_str,flush=True)
1318
1405
 
1319
1406
  # Consider each pair only once
1320
1407
  nsyns = 0
@@ -1340,6 +1427,8 @@ class CorrelatedGapJunction(GapJunction):
1340
1427
  if self.verbose:
1341
1428
  self.connection_number_info()
1342
1429
  self.timer.report('Done! \nTime for building connections')
1430
+ if self.save_report:
1431
+ self.save_connection_report()
1343
1432
  return nsyns
1344
1433
 
1345
1434
 
@@ -1422,7 +1511,7 @@ class OneToOneSequentialConnector(AbstractConnector):
1422
1511
 
1423
1512
  if self.verbose and self.idx_range[-1] == self.n_source:
1424
1513
  print("All " + ("source" if self.partition_source else "target")
1425
- + " population partitions are filled.")
1514
+ + " population partitions are filled.",flush=True)
1426
1515
 
1427
1516
  def edge_params(self, target_pop_idx=-1):
1428
1517
  """Create the arguments for BMTK add_edges() method"""
@@ -1447,14 +1536,14 @@ class OneToOneSequentialConnector(AbstractConnector):
1447
1536
  self.target_count = 0
1448
1537
  src_str, trg_str = self.get_nodes_info()
1449
1538
  print("\nStart building connection " +
1450
- ("to " if self.partition_source else "from ") + src_str)
1539
+ ("to " if self.partition_source else "from ") + src_str,flush=True)
1451
1540
  self.timer = Timer()
1452
1541
 
1453
1542
  if self.iter_count == self.idx_range[self.target_count]:
1454
1543
  # Beginning of each target population
1455
1544
  src_str, trg_str = self.get_nodes_info(self.target_count)
1456
1545
  print((" %d. " % self.target_count) +
1457
- ("from " if self.partition_source else "to ") + trg_str)
1546
+ ("from " if self.partition_source else "to ") + trg_str,flush=True)
1458
1547
  self.target_count += 1
1459
1548
  self.timer_part = Timer()
1460
1549
 
@@ -1491,6 +1580,18 @@ FLUC_STDEV = 0.2 # ms
1491
1580
  DELAY_LOWBOUND = 0.2 # ms must be greater than h.dt
1492
1581
  DELAY_UPBOUND = 2.0 # ms
1493
1582
 
1583
def syn_const_delay(source=None, target=None, dist=100,
                    min_delay=SYN_MIN_DELAY, velocity=SYN_VELOCITY,
                    fluc_stdev=FLUC_STDEV,
                    delay_bound=(DELAY_LOWBOUND, DELAY_UPBOUND),
                    connector=None):
    """Synapse delay for a fixed distance, with random fluctuation.

    delay = dist / velocity + min_delay + fluc_stdev * N(0, 1),
    then clipped to ``[delay_bound[0], delay_bound[1]]``.

    ``source``, ``target`` and ``connector`` are not used here; presumably
    they are kept so this matches the call signature of the other delay
    rules.
    """
    del_fluc = fluc_stdev * rng.normal()
    # Use the arguments rather than the module constants, so callers
    # overriding min_delay / velocity / delay_bound actually take effect.
    # The defaults are those constants, so default behavior is unchanged.
    delay = dist / velocity + min_delay + del_fluc
    lower, upper = delay_bound
    delay = min(max(delay, lower), upper)
    return delay
1593
+
1594
+
1494
1595
  def syn_dist_delay_feng(source, target, min_delay=SYN_MIN_DELAY,
1495
1596
  velocity=SYN_VELOCITY, fluc_stdev=FLUC_STDEV,
1496
1597
  delay_bound=(DELAY_LOWBOUND, DELAY_UPBOUND),
@@ -1520,6 +1621,14 @@ def syn_section_PN(source, target, p=0.9,
1520
1621
  return sec_id[syn_loc], sec_x[syn_loc]
1521
1622
 
1522
1623
 
1624
def syn_const_delay_feng_section_PN(source, target, p=0.9,
                                    sec_id=(1, 2), sec_x=(0.4, 0.6),
                                    **kwargs):
    """Return ``(delay, section_id, section_x)`` for one synapse.

    The delay comes from ``syn_const_delay`` (distance treated as constant),
    with any extra keyword arguments forwarded to it. The section placement
    comes from ``syn_section_PN`` using ``p``, ``sec_id`` and ``sec_x``.
    """
    # Keep the delay draw before the placement draw so that any random
    # numbers are consumed in the same order as before.
    syn_delay = syn_const_delay(source, target, **kwargs)
    placement = syn_section_PN(source, target, p=p,
                               sec_id=sec_id, sec_x=sec_x)
    section_id, section_x = placement
    return syn_delay, section_id, section_x
1630
+
1631
+
1523
1632
  def syn_dist_delay_feng_section_PN(source, target, p=0.9,
1524
1633
  sec_id=(1, 2), sec_x=(0.4, 0.6), **kwargs):
1525
1634
  """Assign both synapse delay and location"""
bmtool/graphs.py CHANGED
@@ -1,5 +1,4 @@
1
1
  import networkx as nx
2
- import plotly.graph_objects as go
3
2
  import pandas as pd
4
3
  import bmtool.util.util as u
5
4
  import pandas as pd
@@ -24,7 +23,7 @@ def generate_graph(config,source,target):
24
23
  edges = edges[edge_to_grap]
25
24
 
26
25
  # Create an empty graph
27
- G = nx.Graph()
26
+ G = nx.DiGraph()
28
27
 
29
28
  # Add nodes to the graph with their positions and labels
30
29
  for index, node_data in nodes.iterrows():
@@ -36,122 +35,124 @@ def generate_graph(config,source,target):
36
35
 
37
36
  return G
38
37
 
39
-
40
- def plot_graph(Graph=None,show_edges = False,title = None):
41
- """
42
- Generate an interactive plot of the network
43
- Graph: A Graph object
44
- show_edges: Boolean to show edges in graph plot
45
- title: A string for the title of the graph
38
+ # import plotly.graph_objects as go
39
+ # def plot_graph(Graph=None,show_edges = False,title = None):
40
+ # """
41
+ # Generate an interactive plot of the network
42
+ # Graph: A Graph object
43
+ # show_edges: Boolean to show edges in graph plot
44
+ # title: A string for the title of the graph
46
45
 
47
- """
48
-
49
- # Extract node positions
50
- node_positions = nx.get_node_attributes(Graph, 'pos')
51
- node_x = [data[0] for data in node_positions.values()]
52
- node_y = [data[1] for data in node_positions.values()]
53
- node_z = [data[2] for data in node_positions.values()]
54
-
55
- # Create edge traces
56
- edge_x = []
57
- edge_y = []
58
- edge_z = []
59
- for edge in Graph.edges():
60
- x0, y0, z0 = node_positions[edge[0]]
61
- x1, y1, z1 = node_positions[edge[1]]
62
- edge_x.extend([x0, x1, None])
63
- edge_y.extend([y0, y1, None])
64
- edge_z.extend([z0, z1, None])
65
-
66
- # Create edge trace
67
- edge_trace = go.Scatter3d(
68
- x=edge_x,
69
- y=edge_y,
70
- z=edge_z,
71
- line=dict(width=1, color='#888'),
72
- hoverinfo='none',
73
- mode='lines',
74
- opacity=0.2)
75
-
76
- # Create node trace
77
- node_trace = go.Scatter3d(
78
- x=node_x,
79
- y=node_y,
80
- z=node_z,
81
- mode='markers',
82
- hoverinfo='text',
83
- marker=dict(
84
- showscale=True,
85
- colorscale='YlGnBu', # Adjust color scale here
86
- reversescale=True,
87
- color=[len(list(Graph.neighbors(node))) for node in Graph.nodes()], # Assign color data here
88
- size=5, # Adjust the size of the nodes here
89
- colorbar=dict(
90
- thickness=15,
91
- title='Node Connections',
92
- xanchor='left',
93
- titleside='right'
94
- ),
95
- line_width=2,
96
- cmin=0, # Adjust color scale range here
97
- cmax=max([len(list(Graph.neighbors(node))) for node in Graph.nodes()]) # Adjust color scale range here
98
- ))
99
-
100
- # Define hover text for nodes
101
- node_hover_text = [f'Node ID: {node_id}<br>Population Name: {node_data["label"]}<br># of Connections: {len(list(Graph.neighbors(node_id)))}' for node_id, node_data in Graph.nodes(data=True)]
102
- node_trace.hovertext = node_hover_text
103
-
104
- # Create figure
105
- if show_edges:
106
- graph_prop = [edge_trace,node_trace]
107
- else:
108
- graph_prop = [node_trace]
109
-
110
- if title == None:
111
- title = '3D plot'
46
+ # """
47
+
48
+ # # Extract node positions
49
+ # node_positions = nx.get_node_attributes(Graph, 'pos')
50
+ # node_x = [data[0] for data in node_positions.values()]
51
+ # node_y = [data[1] for data in node_positions.values()]
52
+ # node_z = [data[2] for data in node_positions.values()]
53
+
54
+ # # Create edge traces
55
+ # edge_x = []
56
+ # edge_y = []
57
+ # edge_z = []
58
+ # for edge in Graph.edges():
59
+ # x0, y0, z0 = node_positions[edge[0]]
60
+ # x1, y1, z1 = node_positions[edge[1]]
61
+ # edge_x.extend([x0, x1, None])
62
+ # edge_y.extend([y0, y1, None])
63
+ # edge_z.extend([z0, z1, None])
64
+
65
+ # # Create edge trace
66
+ # edge_trace = go.Scatter3d(
67
+ # x=edge_x,
68
+ # y=edge_y,
69
+ # z=edge_z,
70
+ # line=dict(width=1, color='#888'),
71
+ # hoverinfo='none',
72
+ # mode='lines',
73
+ # opacity=0.2)
74
+
75
+ # # Create node trace
76
+ # node_trace = go.Scatter3d(
77
+ # x=node_x,
78
+ # y=node_y,
79
+ # z=node_z,
80
+ # mode='markers',
81
+ # hoverinfo='text',
82
+ # marker=dict(
83
+ # showscale=True,
84
+ # colorscale='YlGnBu', # Adjust color scale here
85
+ # reversescale=True,
86
+ # color=[len(list(Graph.neighbors(node))) for node in Graph.nodes()], # Assign color data here
87
+ # size=5, # Adjust the size of the nodes here
88
+ # colorbar=dict(
89
+ # thickness=15,
90
+ # title='Node Connections',
91
+ # xanchor='left',
92
+ # titleside='right'
93
+ # ),
94
+ # line_width=2,
95
+ # cmin=0, # Adjust color scale range here
96
+ # cmax=max([len(list(Graph.neighbors(node))) for node in Graph.nodes()]) # Adjust color scale range here
97
+ # ))
98
+
99
+ # # Define hover text for nodes
100
+ # node_hover_text = [f'Node ID: {node_id}<br>Population Name: {node_data["label"]}<br># of Connections: {len(list(Graph.neighbors(node_id)))}' for node_id, node_data in Graph.nodes(data=True)]
101
+ # node_trace.hovertext = node_hover_text
102
+
103
+ # # Create figure
104
+ # if show_edges:
105
+ # graph_prop = [edge_trace,node_trace]
106
+ # else:
107
+ # graph_prop = [node_trace]
108
+
109
+ # if title == None:
110
+ # title = '3D plot'
112
111
 
113
- fig = go.Figure(data=graph_prop,
114
- layout=go.Layout(
115
- title=title,
116
- titlefont_size=16,
117
- showlegend=False,
118
- hovermode='closest',
119
- margin=dict(b=20, l=5, r=5, t=40),
120
- scene=dict(
121
- xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
122
- yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
123
- zaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
124
- ),
125
- width=800,
126
- height=800
127
- ))
128
-
129
- # Show figure
130
- fig.show()
112
+ # fig = go.Figure(data=graph_prop,
113
+ # layout=go.Layout(
114
+ # title=title,
115
+ # titlefont_size=16,
116
+ # showlegend=False,
117
+ # hovermode='closest',
118
+ # margin=dict(b=20, l=5, r=5, t=40),
119
+ # scene=dict(
120
+ # xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
121
+ # yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
122
+ # zaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
123
+ # ),
124
+ # width=800,
125
+ # height=800
126
+ # ))
127
+
128
+ # # Show figure
129
+ # fig.show()
131
130
 
132
131
 
133
132
  def export_node_connections_to_csv(Graph, filename):
134
133
  """
135
- Generates a csv file with node type and all connections that node receives
136
- Graph: a Graph object
137
- filename: A string for the name of output must end in .csv
134
+ Generates a CSV file with node type and all outgoing connections that node has.
135
+
136
+ Parameters:
137
+ Graph: a DiGraph object (directed graph)
138
+ filename: A string for the name of output, must end in .csv
138
139
  """
139
140
  # Create an empty dictionary to store the connections for each node
140
141
  node_connections = {}
141
142
 
142
143
  # Iterate over each node in the graph
143
144
  for node in Graph.nodes():
144
- # Initialize a dictionary to store the connections for the current node
145
+ # Initialize a dictionary to store the outgoing connections for the current node
145
146
  connections = {}
146
147
  node_label = Graph.nodes[node]['label']
147
148
 
148
- # Iterate over each neighbor of the current node
149
- for neighbor in Graph.neighbors(node):
150
- # Get the label of the neighbor node
151
- neighbor_label = Graph.nodes[neighbor]['label']
149
+ # Iterate over each presuccessor (ingoing neighbor) of the current node
150
+ for successor in Graph.predecessors(node):
151
+ # Get the label of the successor node
152
+ successor_label = Graph.nodes[successor]['label']
152
153
 
153
- # Increment the connection count for the current node and neighbor label
154
- connections[f'{neighbor_label} Connections'] = connections.get(f'{neighbor_label} Connections', 0) + 1
154
+ # Increment the connection count for the current node and successor label
155
+ connections[f'{successor_label} incoming Connections'] = connections.get(f'{successor_label} incoming Connections', 0) + 1
155
156
 
156
157
  # Add the connections information for the current node to the dictionary
157
158
  connections['Node Label'] = node_label
@@ -166,5 +167,4 @@ def export_node_connections_to_csv(Graph, filename):
166
167
  df = df[cols]
167
168
 
168
169
  # Write the DataFrame to a CSV file
169
- df.to_csv(filename)
170
-
170
+ df.to_csv(filename)