bmtool 0.7.0.6.4__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bmtool/util/util.py CHANGED
@@ -1,53 +1,69 @@
  import argparse
- from argparse import RawTextHelpFormatter,SUPPRESS
- import glob, json, os, re, sys
  import math
- import numpy as np
- from numpy import genfromtxt
+ import os
+ import smtplib
+ import sys
+ from argparse import SUPPRESS, RawTextHelpFormatter
+ from email.mime.application import MIMEApplication
+ from email.mime.multipart import MIMEMultipart
+ from email.mime.text import MIMEText
+ from email.utils import COMMASPACE, formatdate
+ from os.path import basename
+
  import h5py
- import pandas as pd
  import neuron
+ import numpy as np
+ import pandas as pd
  from neuron import h

- #from bmtk.utils.io.cell_vars import CellVarsFile
- #from bmtk.analyzer.cell_vars import _get_cell_report
- #from bmtk.analyzer.io_tools import load_config
+ # from bmtk.utils.io.cell_vars import CellVarsFile
+ # from bmtk.analyzer.cell_vars import _get_cell_report
+ # from bmtk.analyzer.io_tools import load_config
+

  def get_argparse(use_description):
- parser = argparse.ArgumentParser(description=use_description, formatter_class=RawTextHelpFormatter,usage=SUPPRESS)
+ parser = argparse.ArgumentParser(
+ description=use_description, formatter_class=RawTextHelpFormatter, usage=SUPPRESS
+ )
  return parser
-
+
+
  def verify_parse(parser):
  try:
  if not len(sys.argv) > 1:
  raise
- #if sys.argv[1] in ['-h','--h','-help','--help','help']:
+ # if sys.argv[1] in ['-h','--h','-help','--help','help']:
  # raise
  parser.parse_args()
  except:
  parser.print_help()
  sys.exit(0)
-
+
+
  use_description = """
  BMTK model utilties.

- python -m bmtool.util
+ python -m bmtool.util
  """

- if __name__ == '__main__':
+ if __name__ == "__main__":
  parser = get_argparse(use_description)
  verify_parse(parser)
-
+

  class CellVarsFile(object):
- VAR_UNKNOWN = 'Unknown'
- UNITS_UNKNOWN = 'NA'
+ VAR_UNKNOWN = "Unknown"
+ UNITS_UNKNOWN = "NA"

- def __init__(self, filename, mode='r', **params):
-
+ def __init__(self, filename, mode="r", **params):
  import h5py
- self._h5_handle = h5py.File(filename, 'r')
- self._h5_root = self._h5_handle[params['h5_root']] if 'h5_root' in params else self._h5_handle['/report/cortex']
+
+ self._h5_handle = h5py.File(filename, "r")
+ self._h5_root = (
+ self._h5_handle[params["h5_root"]]
+ if "h5_root" in params
+ else self._h5_handle["/report/cortex"]
+ )
  self._var_data = {}
  self._var_units = {}

@@ -58,38 +74,45 @@ class CellVarsFile(object):
  print(self._h5_root.keys())
  hf_grp = self._h5_root[var_name]

- if var_name == 'data':
+ if var_name == "data":
  # According to the sonata format the /data table should be located at the root
- var_name = self._h5_root['data'].attrs.get('variable_name', CellVarsFile.VAR_UNKNOWN)
- self._var_data[var_name] = self._h5_root['data']
- self._var_units[var_name] = self._find_units(self._h5_root['data'])
+ var_name = self._h5_root["data"].attrs.get(
+ "variable_name", CellVarsFile.VAR_UNKNOWN
+ )
+ self._var_data[var_name] = self._h5_root["data"]
+ self._var_units[var_name] = self._find_units(self._h5_root["data"])

  if not isinstance(hf_grp, h5py.Group):
  continue

- if var_name == 'mapping':
+ if var_name == "mapping":
  # Check for /mapping group
  self._mapping = hf_grp
  else:
  # In the bmtk we can support multiple variables in the same file (not sonata compliant but should be)
  # where each variable table is separated into its own group /<var_name>/data
- if 'data' not in hf_grp:
- print('Warning: could not find "data" dataset in {}. Skipping!'.format(var_name))
+ if "data" not in hf_grp:
+ print(
+ 'Warning: could not find "data" dataset in {}. Skipping!'.format(var_name)
+ )
  else:
- self._var_data[var_name] = hf_grp['data']
- self._var_units[var_name] = self._find_units(hf_grp['data'])
+ self._var_data[var_name] = hf_grp["data"]
+ self._var_units[var_name] = self._find_units(hf_grp["data"])

  # create map between gids and tables
  self._gid2data_table = {}
  if self._mapping is None:
- raise Exception('could not find /mapping group')
+ raise Exception("could not find /mapping group")
  else:
- gids_ds = self._mapping['node_ids']
- index_pointer_ds = self._mapping['index_pointer']
+ gids_ds = self._mapping["node_ids"]
+ index_pointer_ds = self._mapping["index_pointer"]
  for indx, gid in enumerate(gids_ds):
- self._gid2data_table[gid] = (index_pointer_ds[indx], index_pointer_ds[indx+1]) # slice(index_pointer_ds[indx], index_pointer_ds[indx+1])
+ self._gid2data_table[gid] = (
+ index_pointer_ds[indx],
+ index_pointer_ds[indx + 1],
+ ) # slice(index_pointer_ds[indx], index_pointer_ds[indx+1])

- time_ds = self._mapping['time']
+ time_ds = self._mapping["time"]
  self._t_start = time_ds[0]
  self._t_stop = time_ds[1]
  self._dt = time_ds[2]
@@ -124,7 +147,7 @@ class CellVarsFile(object):
  return self._h5_root

  def _find_units(self, data_set):
- return data_set.attrs.get('units', CellVarsFile.UNITS_UNKNOWN)
+ return data_set.attrs.get("units", CellVarsFile.UNITS_UNKNOWN)

  def units(self, var_name=VAR_UNKNOWN):
  return self._var_units[var_name]
@@ -135,128 +158,139 @@ class CellVarsFile(object):

  def compartment_ids(self, gid):
  bounds = self._gid2data_table[gid]
- return self._mapping['element_id'][bounds[0]:bounds[1]]
+ return self._mapping["element_id"][bounds[0] : bounds[1]]

  def compartment_positions(self, gid):
  bounds = self._gid2data_table[gid]
- return self._mapping['element_pos'][bounds[0]:bounds[1]]
+ return self._mapping["element_pos"][bounds[0] : bounds[1]]

- def data(self, gid, var_name=VAR_UNKNOWN,time_window=None, compartments='origin'):
+ def data(self, gid, var_name=VAR_UNKNOWN, time_window=None, compartments="origin"):
  print(self.variables)
  if var_name not in self.variables:
- raise Exception('Unknown variable {}'.format(var_name))
+ raise Exception("Unknown variable {}".format(var_name))

  if time_window is None:
  time_slice = slice(0, self._n_steps)
  else:
  if len(time_window) != 2:
- raise Exception('Invalid time_window, expecting tuple [being, end].')
+ raise Exception("Invalid time_window, expecting tuple [being, end].")

- window_beg = max(int((time_window[0] - self.t_start)/self.dt), 0)
- window_end = min(int((time_window[1] - self.t_start)/self.dt), self._n_steps/self.dt)
+ window_beg = max(int((time_window[0] - self.t_start) / self.dt), 0)
+ window_end = min(
+ int((time_window[1] - self.t_start) / self.dt), self._n_steps / self.dt
+ )
  time_slice = slice(window_beg, window_end)

  multi_compartments = True
- if compartments == 'origin' or self.n_compartments(gid) == 1:
+ if compartments == "origin" or self.n_compartments(gid) == 1:
  # Return the first (and possibly only) compartment for said gid
  gid_slice = self._gid2data_table[gid][0]
  multi_compartments = False
- elif compartments == 'all':
+ elif compartments == "all":
  # Return all compartments
  gid_slice = slice(self._gid2data_table[gid][0], self._gid2data_table[gid][1])
  else:
  # return all compartments with corresponding element id
- compartment_list = list(compartments) if isinstance(compartments, (long, int)) else compartments
+ compartment_list = list(compartments) if isinstance(compartments, int) else compartments
  begin = self._gid2data_table[gid][0]
  end = self._gid2data_table[gid][1]
  gid_slice = [i for i in range(begin, end) if self._mapping[i] in compartment_list]

  data = np.array(self._var_data[var_name][time_slice, gid_slice])
  return data.T if multi_compartments else data
-
+
+
  def load_config(config_file):
  import bmtk.simulator.core.simulation_config as config
+
  conf = config.from_json(config_file)
- #from bmtk.simulator import bionet
- #conf = bionet.Config.from_json(config_file, validate=True)
+ # from bmtk.simulator import bionet
+ # conf = bionet.Config.from_json(config_file, validate=True)
  return conf

+
  def load_nodes_edges_from_config(fp):
  if fp is None:
- fp = 'simulation_config.json'
+ fp = "simulation_config.json"
  config = load_config(fp)
- nodes = load_nodes_from_paths(config['networks']['nodes'])
- edges = load_edges_from_paths(config['networks']['edges'])
+ nodes = load_nodes_from_paths(config["networks"]["nodes"])
+ edges = load_edges_from_paths(config["networks"]["edges"])
  return nodes, edges

+
  def load_nodes(nodes_file, node_types_file):
- nodes_arr = [{"nodes_file":nodes_file,"node_types_file":node_types_file}]
+ nodes_arr = [{"nodes_file": nodes_file, "node_types_file": node_types_file}]
  nodes = list(load_nodes_from_paths(nodes_arr).items())[0] # single item
  return nodes # return (population, nodes_df)

+
  def load_nodes_from_config(config):
  if config is None:
- config = 'simulation_config.json'
- networks = load_config(config)['networks']
- return load_nodes_from_paths(networks['nodes'])
+ config = "simulation_config.json"
+ networks = load_config(config)["networks"]
+ return load_nodes_from_paths(networks["nodes"])
+

  def load_nodes_from_paths(node_paths):
  """
- node_paths must be in the format in a circuit config file:
- [
- {
- "nodes_file":"filepath",
- "node_types_file":"filepath"
- },...
- ]
- #Glob all files for *_nodes.h5
- #Glob all files for *_edges.h5
-
- Returns a dictionary indexed by population, of pandas tables in the following format:
- node_type_id model_template morphology model_type pop_name pos_x pos_y pos_z
- node_id
- 0 100 hoc:IzhiCell_EC blank.swc biophysical EC 1.5000 0.2500 10.0
-
- Where pop_name was a user defined cell property
+ node_paths must be in the format in a circuit config file:
+ [
+ {
+ "nodes_file":"filepath",
+ "node_types_file":"filepath"
+ },...
+ ]
+ #Glob all files for *_nodes.h5
+ #Glob all files for *_edges.h5
+
+ Returns a dictionary indexed by population, of pandas tables in the following format:
+ node_type_id model_template morphology model_type pop_name pos_x pos_y pos_z
+ node_id
+ 0 100 hoc:IzhiCell_EC blank.swc biophysical EC 1.5000 0.2500 10.0
+
+ Where pop_name was a user defined cell property
  """
  import h5py
  import pandas as pd
-
- #nodes_regex = "_nodes.h5"
- #node_types_regex = "_node_types.csv"

- #nodes_h5_fpaths = glob.glob(os.path.join(network_dir,'*'+nodes_regex))
- #node_types_fpaths = glob.glob(os.path.join(network_dir,'*'+node_types_regex))
+ # nodes_regex = "_nodes.h5"
+ # node_types_regex = "_node_types.csv"
+
+ # nodes_h5_fpaths = glob.glob(os.path.join(network_dir,'*'+nodes_regex))
+ # node_types_fpaths = glob.glob(os.path.join(network_dir,'*'+node_types_regex))

- #regions = [re.findall('^[^_]+', os.path.basename(n))[0] for n in nodes_h5_fpaths]
+ # regions = [re.findall('^[^_]+', os.path.basename(n))[0] for n in nodes_h5_fpaths]
  region_dict = {}

- pos_labels = ('pos_x', 'pos_y', 'pos_z')
+ pos_labels = ("pos_x", "pos_y", "pos_z")

- #Need to get all cell groups for each region
+ # Need to get all cell groups for each region
  def get_node_table(cell_models_file, cells_file, population=None):
- cm_df = pd.read_csv(cells_file, sep=' ')
- cm_df.set_index('node_type_id', inplace=True)
+ cm_df = pd.read_csv(cells_file, sep=" ")
+ cm_df.set_index("node_type_id", inplace=True)

- cells_h5 = h5py.File(cell_models_file, 'r')
+ cells_h5 = h5py.File(cell_models_file, "r")
  if population is None:
- if len(cells_h5['/nodes']) > 1:
- raise Exception('Multiple populations in nodes file. Not currently supported. Should be easy to implement when needed. Let Tyler know.')
+ if len(cells_h5["/nodes"]) > 1:
+ raise Exception(
+ "Multiple populations in nodes file. Not currently supported. Should be easy to implement when needed. Let Tyler know."
+ )
  else:
- population = list(cells_h5['/nodes'])[0]
+ population = list(cells_h5["/nodes"])[0]

- nodes_grp = cells_h5['/nodes'][population]
- c_df = pd.DataFrame({key: nodes_grp[key] for key in ('node_id', 'node_type_id')})
- c_df.set_index('node_id', inplace=True)
+ nodes_grp = cells_h5["/nodes"][population]
+ c_df = pd.DataFrame({key: nodes_grp[key] for key in ("node_id", "node_type_id")})
+ c_df.set_index("node_id", inplace=True)

- nodes_df = pd.merge(left=c_df, right=cm_df, how='left',
- left_on='node_type_id', right_index=True) # use 'model_id' key to merge, for right table the "model_id" is an index
+ nodes_df = pd.merge(
+ left=c_df, right=cm_df, how="left", left_on="node_type_id", right_index=True
+ ) # use 'model_id' key to merge, for right table the "model_id" is an index

  # extra properties of individual nodes (see SONATA Data format)
- if nodes_grp.get('0'):
- node_id = nodes_grp['node_id'][()]
- node_group_id = nodes_grp['node_group_id'][()]
- node_group_index = nodes_grp['node_group_index'][()]
+ if nodes_grp.get("0"):
+ node_id = nodes_grp["node_id"][()]
+ node_group_id = nodes_grp["node_group_id"][()]
+ node_group_index = nodes_grp["node_group_index"][()]
  n_group = node_group_id.max() + 1
  prop_dtype = {}
  for group_id in range(n_group):
@@ -265,7 +299,7 @@ def load_nodes_from_paths(node_paths):
  group_node = node_id[idx]
  group_index = node_group_index[idx]
  for prop in group:
- if prop == 'positions':
+ if prop == "positions":
  positions = group[prop][group_index]
  for i in range(positions.shape[1]):
  if pos_labels[i] not in nodes_df:
@@ -273,17 +307,17 @@ def load_nodes_from_paths(node_paths):
  nodes_df.loc[group_node, pos_labels[i]] = positions[:, i]
  else:
  # create new column with NaN if property does not exist
- if prop not in nodes_df:
+ if prop not in nodes_df:
  nodes_df[prop] = np.nan
  nodes_df.loc[group_node, prop] = group[prop][group_index]
  prop_dtype[prop] = group[prop].dtype
  # convert to original data type if possible
  for prop, dtype in prop_dtype.items():
- nodes_df[prop] = nodes_df[prop].astype(dtype, errors='ignore')
+ nodes_df[prop] = nodes_df[prop].astype(dtype, errors="ignore")

  return population, nodes_df

- #for region, cell_models_file, cells_file in zip(regions, node_types_fpaths, nodes_h5_fpaths):
+ # for region, cell_models_file, cells_file in zip(regions, node_types_fpaths, nodes_h5_fpaths):
  # region_dict[region] = get_node_table(cell_models_file,cells_file,population=region)
  for nodes in node_paths:
  cell_models_file = nodes["nodes_file"]
@@ -291,24 +325,27 @@ def load_nodes_from_paths(node_paths):
  region_name, region = get_node_table(cell_models_file, cells_file)
  region_dict[region_name] = region

- #cell_num = 2
- #print(region_dict["hippocampus"].iloc[cell_num]["node_type_id"])
- #print(list(set(region_dict["hippocampus"]["node_type_id"]))) #Unique
+ # cell_num = 2
+ # print(region_dict["hippocampus"].iloc[cell_num]["node_type_id"])
+ # print(list(set(region_dict["hippocampus"]["node_type_id"]))) #Unique

  return region_dict
-
+
+
  def load_edges_from_config(config):
  if config is None:
- config = 'simulation_config.json'
- networks = load_config(config)['networks']
- return load_edges_from_paths(networks['edges'])
+ config = "simulation_config.json"
+ networks = load_config(config)["networks"]
+ return load_edges_from_paths(networks["edges"])
+

  def load_edges(edges_file, edge_types_file):
- edges_arr = [{"edges_file":edges_file,"edge_types_file":edge_types_file}]
+ edges_arr = [{"edges_file": edges_file, "edge_types_file": edge_types_file}]
  edges = list(load_edges_from_paths(edges_arr).items())[0] # single item
  return edges # return (population, edges_df)

- def load_edges_from_paths(edge_paths):#network_dir='network'):
+
+ def load_edges_from_paths(edge_paths): # network_dir='network'):
  """
  Returns: A dictionary of connections with filenames (minus _edges.h5) as keys

@@ -323,44 +360,48 @@ def load_edges_from_paths(edge_paths):#network_dir='network'):
  """
  import h5py
  import pandas as pd
- #edges_regex = "_edges.h5"
- #edge_types_regex = "_edge_types.csv"
+ # edges_regex = "_edges.h5"
+ # edge_types_regex = "_edge_types.csv"

- #edges_h5_fpaths = glob.glob(os.path.join(network_dir,'*'+edges_regex))
- #edge_types_fpaths = glob.glob(os.path.join(network_dir,'*'+edge_types_regex))
+ # edges_h5_fpaths = glob.glob(os.path.join(network_dir,'*'+edges_regex))
+ # edge_types_fpaths = glob.glob(os.path.join(network_dir,'*'+edge_types_regex))

- #connections = [re.findall('^[A-Za-z0-9]+_[A-Za-z0-9][^_]+', os.path.basename(n))[0] for n in edges_h5_fpaths]
+ # connections = [re.findall('^[A-Za-z0-9]+_[A-Za-z0-9][^_]+', os.path.basename(n))[0] for n in edges_h5_fpaths]
  edges_dict = {}
- def get_edge_table(edges_file, edge_types_file, population=None):

+ def get_edge_table(edges_file, edge_types_file, population=None):
  # dataframe where each row is an edge type
- cm_df = pd.read_csv(edge_types_file, sep=' ')
- cm_df.set_index('edge_type_id', inplace=True)
+ cm_df = pd.read_csv(edge_types_file, sep=" ")
+ cm_df.set_index("edge_type_id", inplace=True)

- with h5py.File(edges_file, 'r') as connections_h5:
+ with h5py.File(edges_file, "r") as connections_h5:
  if population is None:
- if len(connections_h5['/edges']) > 1:
- raise Exception('Multiple populations in edges file. Not currently implemented, should not be hard to do, contact Tyler')
+ if len(connections_h5["/edges"]) > 1:
+ raise Exception(
+ "Multiple populations in edges file. Not currently implemented, should not be hard to do, contact Tyler"
+ )
  else:
- population = list(connections_h5['/edges'])[0]
- conn_grp = connections_h5['/edges'][population]
+ population = list(connections_h5["/edges"])[0]
+ conn_grp = connections_h5["/edges"][population]

  # dataframe where each row is an edge
- c_df = pd.DataFrame({key: conn_grp[key] for key in (
- 'edge_type_id', 'source_node_id', 'target_node_id')})
+ c_df = pd.DataFrame(
+ {key: conn_grp[key] for key in ("edge_type_id", "source_node_id", "target_node_id")}
+ )

  c_df.reset_index(inplace=True)
- c_df.rename(columns={'index': 'edge_id'}, inplace=True)
- c_df.set_index('edge_type_id', inplace=True)
+ c_df.rename(columns={"index": "edge_id"}, inplace=True)
+ c_df.set_index("edge_type_id", inplace=True)

  # add edge type properties to df of edges
- edges_df = pd.merge(left=c_df, right=cm_df, how='left',
- left_index=True, right_index=True)
+ edges_df = pd.merge(
+ left=c_df, right=cm_df, how="left", left_index=True, right_index=True
+ )

  # extra properties of individual edges (see SONATA Data format)
- if conn_grp.get('0'):
- edge_group_id = conn_grp['edge_group_id'][()]
- edge_group_index = conn_grp['edge_group_index'][()]
+ if conn_grp.get("0"):
+ edge_group_id = conn_grp["edge_group_id"][()]
+ edge_group_index = conn_grp["edge_group_index"][()]
  n_group = edge_group_id.max() + 1
  prop_dtype = {}
  for group_id in range(n_group):
@@ -368,17 +409,17 @@ def load_edges_from_paths(edge_paths):#network_dir='network'):
  idx = edge_group_id == group_id
  for prop in group:
  # create new column with NaN if property does not exist
- if prop not in edges_df:
+ if prop not in edges_df:
  edges_df[prop] = np.nan
  edges_df.loc[idx, prop] = tuple(group[prop][edge_group_index[idx]])
  prop_dtype[prop] = group[prop].dtype
  # convert to original data type if possible
  for prop, dtype in prop_dtype.items():
- edges_df[prop] = edges_df[prop].astype(dtype, errors='ignore')
+ edges_df[prop] = edges_df[prop].astype(dtype, errors="ignore")

  return population, edges_df

- #for edges_dict, conn_models_file, conns_file in zip(connections, edge_types_fpaths, edges_h5_fpaths):
+ # for edges_dict, conn_models_file, conns_file in zip(connections, edge_types_fpaths, edges_h5_fpaths):
  # connections_dict[connection] = get_connection_table(conn_models_file,conns_file)
  try:
  for nodes in edge_paths:
@@ -391,24 +432,27 @@ def load_edges_from_paths(edge_paths):#network_dir='network'):
  print("Hint: Are you loading the correct simulation config file?")
  print("Command Line: bmtool plot --config yourconfig.json <rest of command>")
  print("Python: bmplot.connection_matrix(config='yourconfig.json')")
-
+
  return edges_dict

+
  def load_mechanisms_from_config(config=None):
  """
  loads neuron mechanisms from BMTK config
  """
  if config is None:
- config = 'simulation_config.json'
+ config = "simulation_config.json"
  config = load_config(config)
- return neuron.load_mechanisms(config['components']['mechanisms_dir'])
+ return neuron.load_mechanisms(config["components"]["mechanisms_dir"])
+

  def load_templates_from_config(config=None):
  if config is None:
- config = 'simulation_config.json'
+ config = "simulation_config.json"
  config = load_config(config)
  load_mechanisms_from_config(config)
- return load_templates_from_paths(config['components']['templates_dir'])
+ return load_templates_from_paths(config["components"]["templates_dir"])
+

  def load_templates_from_paths(template_paths):
  # load all the files in the templates dir
@@ -418,7 +462,6 @@ def load_templates_from_paths(template_paths):
  print(f"loading {item_path}")
  h.load_file(item_path)

-

  def cell_positions_by_id(config=None, nodes=None, populations=[], popids=[], prepend_pop=True):
  """
@@ -427,41 +470,56 @@ def cell_positions_by_id(config=None, nodes=None, populations=[], popids=[], pre
  if not nodes:
  nodes = load_nodes_from_config(config)

- import pdb
-
- if 'all' in populations or not populations or not len(populations):
+ if "all" in populations or not populations or not len(populations):
  populations = list(nodes)

- popids += (len(populations)-len(popids)) * ["node_type_id"] #Extend the array to default values if not enough given
+ popids += (len(populations) - len(popids)) * [
+ "node_type_id"
+ ] # Extend the array to default values if not enough given
  cells_by_id = {}
- for population,pid in zip(populations,popids):
- #get a list of unique cell types based on pid
- pdb.set_trace()
+ for population, pid in zip(populations, popids):
+ # get a list of unique cell types based on pid
+ # Debug statement removed
  cell_types = list(nodes[population][str(pid)].unique())
  for ct in cell_types:
- cells_by_id[population+'_'+ct] = 0
-
+ cells_by_id[population + "_" + ct] = 0
+
  return cells_by_id

- def relation_matrix(config=None, nodes=None,edges=None,sources=[],targets=[],sids=[],tids=[],prepend_pop=True,relation_func=None,return_type=float,drop_point_process=False,synaptic_info='0'):
-
+
+ def relation_matrix(
+ config=None,
+ nodes=None,
+ edges=None,
+ sources=[],
+ targets=[],
+ sids=[],
+ tids=[],
+ prepend_pop=True,
+ relation_func=None,
+ return_type=float,
+ drop_point_process=False,
+ synaptic_info="0",
+ ):
  import pandas as pd
-
+
  if not nodes and not edges:
- nodes,edges = load_nodes_edges_from_config(config)
+ nodes, edges = load_nodes_edges_from_config(config)
  if not nodes:
  nodes = load_nodes_from_config(config)
  if not edges:
  edges = load_edges_from_config(config)
  if not edges and not nodes and not config:
  raise Exception("No information given to load nodes/edges")
-
- if 'all' in sources:
+
+ if "all" in sources:
  sources = list(nodes)
- if 'all' in targets:
+ if "all" in targets:
  targets = list(nodes)
- sids += (len(sources)-len(sids)) * ["node_type_id"] #Extend the array to default values if not enough given
- tids += (len(targets)-len(tids)) * ["node_type_id"]
+ sids += (len(sources) - len(sids)) * [
+ "node_type_id"
+ ] # Extend the array to default values if not enough given
+ tids += (len(targets) - len(tids)) * ["node_type_id"]

  total_source_cell_types = 0
  total_target_cell_types = 0
@@ -472,20 +530,20 @@ def relation_matrix(config=None, nodes=None,edges=None,sources=[],targets=[],sid
  source_totals = []
  target_totals = []

- source_map = {}#Sometimes we don't add the item to sources or targets, need to keep track of the index
- target_map = {}#Or change to be a dictionary sometime
+ source_map = {} # Sometimes we don't add the item to sources or targets, need to keep track of the index
+ target_map = {} # Or change to be a dictionary sometime

- for source,sid in zip(sources,sids):
+ for source, sid in zip(sources, sids):
  do_process = False
  for t, target in enumerate(targets):
- e_name = source+"_to_"+target
- if e_name in list(edges):
- do_process=True
- if not do_process: # This is not seen as an input, don't process it.
+ e_name = source + "_to_" + target
+ if isinstance(edges, dict) and e_name in edges:
+ do_process = True
+ if not do_process: # This is not seen as an input, don't process it.
  continue
-
+
  if drop_point_process:
- nodes_src = pd.DataFrame(nodes[source][nodes[source]['model_type']!='point_process'])
+ nodes_src = pd.DataFrame(nodes[source][nodes[source]["model_type"] != "point_process"])
  else:
  nodes_src = pd.DataFrame(nodes[source])
  total_source_cell_types = total_source_cell_types + len(list(set(nodes_src[sid])))
@@ -493,81 +551,109 @@ def relation_matrix(config=None, nodes=None,edges=None,sources=[],targets=[],sid
  source_uids.append(unique_)
  prepend_str = ""
  if prepend_pop:
- prepend_str = str(source) +"_"
- unique_= list(np.array((prepend_str+ pd.DataFrame(unique_).astype(str)).values.tolist()).ravel())
+ prepend_str = str(source) + "_"
+ unique_ = list(
+ np.array((prepend_str + pd.DataFrame(unique_).astype(str)).values.tolist()).ravel()
+ )
  source_pop_names = source_pop_names + unique_
  source_totals.append(len(unique_))
- source_map[source] = len(source_uids)-1
- for target,tid in zip(targets,tids):
+ source_map[source] = len(source_uids) - 1
+ for target, tid in zip(targets, tids):
  do_process = False
  for s, source in enumerate(sources):
- e_name = source+"_to_"+target
- if e_name in list(edges):
- do_process=True
+ e_name = source + "_to_" + target
+ if isinstance(edges, dict) and e_name in edges:
+ do_process = True
  if not do_process:
  continue

  if drop_point_process:
- nodes_trg = pd.DataFrame(nodes[target][nodes[target]['model_type']!='point_process'])
+ nodes_trg = pd.DataFrame(nodes[target][nodes[target]["model_type"] != "point_process"])
  else:
  nodes_trg = pd.DataFrame(nodes[target])

  total_target_cell_types = total_target_cell_types + len(list(set(nodes_trg[tid])))
-
+
  unique_ = nodes_trg[tid].unique()
  target_uids.append(unique_)
  prepend_str = ""
  if prepend_pop:
- prepend_str = str(target) +"_"
- unique_ = list(np.array((prepend_str + pd.DataFrame(unique_).astype(str)).values.tolist()).ravel())
+ prepend_str = str(target) + "_"
+ unique_ = list(
+ np.array((prepend_str + pd.DataFrame(unique_).astype(str)).values.tolist()).ravel()
+ )
  target_pop_names = target_pop_names + unique_
  target_totals.append(len(unique_))
- target_map[target] = len(target_uids) -1
+ target_map[target] = len(target_uids) - 1

- e_matrix = np.zeros((total_source_cell_types,total_target_cell_types),dtype=return_type)
- syn_info = np.zeros((total_source_cell_types,total_target_cell_types),dtype=object)
- sources_start = np.cumsum(source_totals) -source_totals
- target_start = np.cumsum(target_totals) -target_totals
+ e_matrix = np.zeros((total_source_cell_types, total_target_cell_types), dtype=return_type)
+ syn_info = np.zeros((total_source_cell_types, total_target_cell_types), dtype=object)
+ sources_start = np.cumsum(source_totals) - source_totals
+ target_start = np.cumsum(target_totals) - target_totals
  total = 0
- stdev=0
- mean=0
+ stdev = 0
+ mean = 0
  for s, source in enumerate(sources):
  for t, target in enumerate(targets):
- e_name = source+"_to_"+target
+ e_name = source + "_to_" + target
  if e_name not in list(edges):
  continue
  if relation_func:
- source_nodes = nodes[source].add_prefix('source_')
- target_nodes = nodes[target].add_prefix('target_')
-
- c_edges = pd.merge(left=edges[e_name],
- right=source_nodes,
- how='left',
- left_on='source_node_id',
- right_index=True)
-
- c_edges = pd.merge(left=c_edges,
- right=target_nodes,
- how='left',
- left_on='target_node_id',
- right_index=True)
-
+ source_nodes = nodes[source].add_prefix("source_")
+ target_nodes = nodes[target].add_prefix("target_")
+
+ c_edges = pd.merge(
+ left=edges[e_name],
+ right=source_nodes,
+ how="left",
+ left_on="source_node_id",
+ right_index=True,
+ )
+
+ c_edges = pd.merge(
+ left=c_edges,
+ right=target_nodes,
+ how="left",
+ left_on="target_node_id",
+ right_index=True,
+ )
+
  sm = source_map[source]
  tm = target_map[target]
-
+
  def syn_info_func(**kwargs):
  edges = kwargs["edges"]
  source_id_type = kwargs["sid"]
  target_id_type = kwargs["tid"]
  source_id = kwargs["source_id"]
  target_id = kwargs["target_id"]
- if edges[(edges[source_id_type] == source_id) & (edges[target_id_type]==target_id)]["dynamics_params"].count()!=0:
- params = str(edges[(edges[source_id_type] == source_id) & (edges[target_id_type]==target_id)]["dynamics_params"].drop_duplicates().values[0])
+ if (
+ edges[
+ (edges[source_id_type] == source_id)
+ & (edges[target_id_type] == target_id)
+ ]["dynamics_params"].count()
+ != 0
+ ):
+ params = str(
+ edges[
+ (edges[source_id_type] == source_id)
+ & (edges[target_id_type] == target_id)
+ ]["dynamics_params"]
+ .drop_duplicates()
+ .values[0]
+ )
  params = params[:-5]
- mod = str(edges[(edges[source_id_type] == source_id) & (edges[target_id_type]==target_id)]["model_template"].drop_duplicates().values[0])
- if mod and synaptic_info=='1':
+ mod = str(
+ edges[
+ (edges[source_id_type] == source_id)
+ & (edges[target_id_type] == target_id)
+ ]["model_template"]
+ .drop_duplicates()
+ .values[0]
+ )
+ if mod and synaptic_info == "1":
  return mod
- elif params and synaptic_info=='2':
+ elif params and synaptic_info == "2":
  return params
  else:
  return None
@@ -578,7 +664,14 @@ def relation_matrix(config=None, nodes=None,edges=None,sources=[],targets=[],sid
  target_id_type = kwargs["tid"]
  source_id = kwargs["source_id"]
  target_id = kwargs["target_id"]
- mean = edges[(edges[source_id_type] == source_id) & (edges[target_id_type]==target_id)]['target_node_id'].value_counts().mean()
+ mean = (
+ edges[
+ (edges[source_id_type] == source_id)
+ & (edges[target_id_type] == target_id)
+ ]["target_node_id"]
+ .value_counts()
+ .mean()
+ )
  return mean

  def conn_stdev_func(**kwargs):
@@ -587,51 +680,121 @@ def relation_matrix(config=None, nodes=None,edges=None,sources=[],targets=[],sid
  target_id_type = kwargs["tid"]
  source_id = kwargs["source_id"]
  target_id = kwargs["target_id"]
- stdev = edges[(edges[source_id_type] == source_id) & (edges[target_id_type]==target_id)]['target_node_id'].value_counts().std()
+ stdev = (
+ edges[
+ (edges[source_id_type] == source_id)
+ & (edges[target_id_type] == target_id)
+ ]["target_node_id"]
+ .value_counts()
+ .std()
+ )
  return stdev

- for s_type_ind,s_type in enumerate(source_uids[sm]):
-
- for t_type_ind,t_type in enumerate(target_uids[tm]):
- source_index = int(s_type_ind+sources_start[sm])
- target_index = int(t_type_ind+target_start[tm])
-
- total = relation_func(source_nodes=source_nodes, target_nodes=target_nodes, edges=c_edges, source=source,sid="source_"+sids[s], target=target,tid="target_"+tids[t],source_id=s_type,target_id=t_type)
- if synaptic_info=='0':
+ for s_type_ind, s_type in enumerate(source_uids[sm]):
+ for t_type_ind, t_type in enumerate(target_uids[tm]):
+ source_index = int(s_type_ind + sources_start[sm])
+ target_index = int(t_type_ind + target_start[tm])
+
+ total = relation_func(
+ source_nodes=source_nodes,
+ target_nodes=target_nodes,
+ edges=c_edges,
+ source=source,
+ sid="source_" + sids[s],
+ target=target,
+ tid="target_" + tids[t],
+ source_id=s_type,
+ target_id=t_type,
+ )
+ if synaptic_info == "0":
  if isinstance(total, tuple):
- syn_info[source_index, target_index] = str(round(total[0], 1)) + '\n' + str(round(total[1], 1))
+ syn_info[source_index, target_index] = (
+ str(round(total[0], 1)) + "\n" + str(round(total[1], 1))
+ )
  else:
- syn_info[source_index,target_index] = total
- elif synaptic_info=='1':
- mean = conn_mean_func(source_nodes=source_nodes, target_nodes=target_nodes, edges=c_edges, source=source,sid="source_"+sids[s], target=target,tid="target_"+tids[t],source_id=s_type,target_id=t_type)
- stdev = conn_stdev_func(source_nodes=source_nodes, target_nodes=target_nodes, edges=c_edges, source=source,sid="source_"+sids[s], target=target,tid="target_"+tids[t],source_id=s_type,target_id=t_type)
+ syn_info[source_index, target_index] = total
+ elif synaptic_info == "1":
+ mean = conn_mean_func(
+ source_nodes=source_nodes,
+ target_nodes=target_nodes,
+ edges=c_edges,
+ source=source,
+ sid="source_" + sids[s],
+ target=target,
+ tid="target_" + tids[t],
+ source_id=s_type,
+ target_id=t_type,
+ )
+ stdev = conn_stdev_func(
+ source_nodes=source_nodes,
+ target_nodes=target_nodes,
+ edges=c_edges,
+ source=source,
+ sid="source_" + sids[s],
+ target=target,
+ tid="target_" + tids[t],
+ source_id=s_type,
+ target_id=t_type,
+ )
  if math.isnan(mean):
- mean=0
+ mean = 0
  if math.isnan(stdev):
- stdev=0
- syn_info[source_index,target_index] = str(round(mean,1)) + '\n'+ str(round(stdev,1))
- elif synaptic_info=='2':
- syn_list = syn_info_func(source_nodes=source_nodes, target_nodes=target_nodes, edges=c_edges, source=source,sid="source_"+sids[s], target=target,tid="target_"+tids[t],source_id=s_type,target_id=t_type)
+ stdev = 0
+ syn_info[source_index, target_index] = (
+ str(round(mean, 1)) + "\n" + str(round(stdev, 1))
+ )
+ elif synaptic_info == "2":
+ syn_list = syn_info_func(
+ source_nodes=source_nodes,
+ target_nodes=target_nodes,
+ edges=c_edges,
+ source=source,
+ sid="source_" + sids[s],
+ target=target,
+ tid="target_" + tids[t],
+ source_id=s_type,
+ target_id=t_type,
+ )
  if syn_list is None:
- syn_info[source_index,target_index] = ""
+ syn_info[source_index, target_index] = ""
  else:
- syn_info[source_index,target_index] = syn_list
- elif synaptic_info=='3':
- syn_list = syn_info_func(source_nodes=source_nodes, target_nodes=target_nodes, edges=c_edges, source=source,sid="source_"+sids[s], target=target,tid="target_"+tids[t],source_id=s_type,target_id=t_type)
+ syn_info[source_index, target_index] = syn_list
+ elif synaptic_info == "3":
+ syn_list = syn_info_func(
+ source_nodes=source_nodes,
+ target_nodes=target_nodes,
+ edges=c_edges,
+ source=source,
+ sid="source_" + sids[s],
+ target=target,
+ tid="target_" + tids[t],
+ source_id=s_type,
+ target_id=t_type,
+ )
  if syn_list is None:
- syn_info[source_index,target_index] = ""
+ syn_info[source_index, target_index] = ""
  else:
- syn_info[source_index,target_index] = syn_list
+ syn_info[source_index, target_index] = syn_list
  if isinstance(total, tuple):
- e_matrix[source_index,target_index]=total[0]
+ e_matrix[source_index, target_index] = total[0]
  else:
- e_matrix[source_index,target_index]=total
+ e_matrix[source_index, target_index] = total

-
  return syn_info, e_matrix, source_pop_names, target_pop_names

- def connection_totals(config=None,nodes=None,edges=None,sources=[],targets=[],sids=[],tids=[],prepend_pop=True,synaptic_info='0',include_gap=True):
-
+
+ def connection_totals(
+ config=None,
+ nodes=None,
+ edges=None,
+ sources=[],
+ targets=[],
+ sids=[],
+ tids=[],
+ prepend_pop=True,
+ synaptic_info="0",
+ include_gap=True,
+ ):
  def total_connection_relationship(**kwargs):
  edges = kwargs["edges"]
  source_id_type = kwargs["sid"]
@@ -639,23 +802,46 @@ def connection_totals(config=None,nodes=None,edges=None,sources=[],targets=[],si
  source_id = kwargs["source_id"]
  target_id = kwargs["target_id"]

- total = edges[(edges[source_id_type] == source_id) & (edges[target_id_type]==target_id)]
- if include_gap == False:
- try:
- cons = cons[cons['is_gap_junction'] != True]
+ total = edges[(edges[source_id_type] == source_id) & (edges[target_id_type] == target_id)]
+ if not include_gap:
+ try:
+ total = total[~total["is_gap_junction"]]
  except:
- raise Exception("no gap junctions found to drop from connections")
-
+ # If there are no gap junctions, just continue
+ pass
+
  total = total.count()
- total = total.source_node_id # may not be the best way to pick
+ total = total.source_node_id # may not be the best way to pick
  return total
- return relation_matrix(config,nodes,edges,sources,targets,sids,tids,prepend_pop,relation_func=total_connection_relationship,synaptic_info=synaptic_info)
-
-
- def percent_connections(config=None,nodes=None,edges=None,sources=[],targets=[],sids=[],tids=[],prepend_pop=True,type='convergence',method=None,include_gap=True):
-

- def precent_func(**kwargs):
+ return relation_matrix(
+ config,
+ nodes,
+ edges,
+ sources,
+ targets,
+ sids,
+ tids,
+ prepend_pop,
+ relation_func=total_connection_relationship,
+ synaptic_info=synaptic_info,
+ )
+
+
+ def percent_connections(
+ config=None,
+ nodes=None,
+ edges=None,
+ sources=[],
+ targets=[],
+ sids=[],
+ tids=[],
+ prepend_pop=True,
+ type="convergence",
+ method=None,
+ include_gap=True,
+ ):
+ def precent_func(**kwargs):
  edges = kwargs["edges"]
  source_id_type = kwargs["sid"]
  target_id_type = kwargs["tid"]
@@ -664,53 +850,66 @@ def percent_connections(config=None,nodes=None,edges=None,sources=[],targets=[],
  t_list = kwargs["target_nodes"]
  s_list = kwargs["source_nodes"]

- cons = edges[(edges[source_id_type] == source_id) & (edges[target_id_type]==target_id)]
- if include_gap == False:
- try:
- cons = cons[cons['is_gap_junction'] != True]
+ cons = edges[(edges[source_id_type] == source_id) & (edges[target_id_type] == target_id)]
+ if not include_gap:
+ try:
+ cons = cons[~cons["is_gap_junction"]]
  except:
  raise Exception("no gap junctions found to drop from connections")
-
+
  total_cons = cons.count().source_node_id
  # to determine reciprocal connectivity
  # create a copy and flip source/dest
- cons_flip = edges[(edges[source_id_type] == target_id) & (edges[target_id_type]==source_id)]
- cons_flip = cons_flip.rename(columns={'source_node_id':'target_node_id','target_node_id':'source_node_id'})
- # append to original
+ cons_flip = edges[
+ (edges[source_id_type] == target_id) & (edges[target_id_type] == source_id)
+ ]
+ cons_flip = cons_flip.rename(
+ columns={"source_node_id": "target_node_id", "target_node_id": "source_node_id"}
+ )
+ # append to original
  cons_recip = pd.concat([cons, cons_flip])

  # determine dropped duplicates (keep=False)
- cons_recip_dedup = cons_recip.drop_duplicates(subset=['source_node_id','target_node_id'])
+ cons_recip_dedup = cons_recip.drop_duplicates(subset=["source_node_id", "target_node_id"])

  # note counts
- num_bi = (cons_recip.count().source_node_id - cons_recip_dedup.count().source_node_id)
- num_uni = total_cons - num_bi
+ num_bi = cons_recip.count().source_node_id - cons_recip_dedup.count().source_node_id
+ num_uni = total_cons - num_bi

- #num_sources = s_list.apply(pd.Series.value_counts)[source_id_type].dropna().sort_index().loc[source_id]
- #num_targets = t_list.apply(pd.Series.value_counts)[target_id_type].dropna().sort_index().loc[target_id]
+ # num_sources = s_list.apply(pd.Series.value_counts)[source_id_type].dropna().sort_index().loc[source_id]
+ # num_targets = t_list.apply(pd.Series.value_counts)[target_id_type].dropna().sort_index().loc[target_id]

  num_sources = s_list[source_id_type].value_counts().sort_index().loc[source_id]
  num_targets = t_list[target_id_type].value_counts().sort_index().loc[target_id]

-
- total = round(total_cons / (num_sources*num_targets) * 100,2)
- uni = round(num_uni / (num_sources*num_targets) * 100,2)
- bi = round(num_bi / (num_sources*num_targets) * 100,2)
- if method == 'total':
+ total = round(total_cons / (num_sources * num_targets) * 100, 2)
+ uni = round(num_uni / (num_sources * num_targets) * 100, 2)
+ bi = round(num_bi / (num_sources * num_targets) * 100, 2)
+ if method == "total":
  return total
- if method == 'uni':
+ if method == "uni":
  return uni
- if method == 'bi':
+ if method == "bi":
  return bi

-
- return relation_matrix(config,nodes,edges,sources,targets,sids,tids,prepend_pop,relation_func=precent_func)
-
-
- def connection_divergence(config=None,nodes=None,edges=None,sources=[],targets=[],sids=[],tids=[],prepend_pop=True,convergence=False,method='mean+std',include_gap=True):
-
- import pandas as pd
-
+ return relation_matrix(
+ config, nodes, edges, sources, targets, sids, tids, prepend_pop, relation_func=precent_func
+ )
+
+
+ def connection_divergence(
+ config=None,
+ nodes=None,
+ edges=None,
+ sources=[],
+ targets=[],
+ sids=[],
+ tids=[],
+ prepend_pop=True,
+ convergence=False,
+ method="mean+std",
+ include_gap=True,
+ ):
  def total_connection_relationship(**kwargs):
  edges = kwargs["edges"]
  source_id_type = kwargs["sid"]
@@ -721,73 +920,94 @@ def connection_divergence(config=None,nodes=None,edges=None,sources=[],targets=[
  s_list = kwargs["source_nodes"]
  count = 1

- cons = edges[(edges[source_id_type] == source_id) & (edges[target_id_type]==target_id)]
- if include_gap == False:
- try:
- cons = cons[cons['is_gap_junction'] != True]
+ cons = edges[(edges[source_id_type] == source_id) & (edges[target_id_type] == target_id)]
+ if not include_gap:
+ try:
+ cons = cons[~cons["is_gap_junction"]]
  except:
  raise Exception("no gap junctions found to drop from connections")

  if convergence:
- if method == 'min':
- count = cons['target_node_id'].value_counts().min()
- return round(count,2)
- elif method == 'max':
- count = cons['target_node_id'].value_counts().max()
- return round(count,2)
- elif method == 'std':
- std = cons['target_node_id'].value_counts().std()
- return round(std,2)
- elif method == 'mean':
- mean = cons['target_node_id'].value_counts().mean()
- return round(mean,2)
- elif method == 'mean+std': #default is mean + std
- mean = cons['target_node_id'].value_counts().mean()
- std = cons['target_node_id'].value_counts().std()
- #std = cons.apply(pd.Series.value_counts).target_node_id.dropna().std() no longer a valid way
- return (round(mean,2)), (round(std,2))
- else: #divergence
- if method == 'min':
- count = cons['source_node_id'].value_counts().min()
- return round(count,2)
- elif method == 'max':
- count = cons['source_node_id'].value_counts().max()
- return round(count,2)
- elif method == 'std':
- std = cons['source_node_id'].value_counts().std()
- return round(std,2)
- elif method == 'mean':
- mean = cons['source_node_id'].value_counts().mean()
- return round(mean,2)
- elif method == 'mean+std': #default is mean + std
- mean = cons['source_node_id'].value_counts().mean()
- std = cons['source_node_id'].value_counts().std()
- return (round(mean,2)), (round(std,2))
-
- return relation_matrix(config,nodes,edges,sources,targets,sids,tids,prepend_pop,relation_func=total_connection_relationship)
-
- def gap_junction_connections(config=None,nodes=None,edges=None,sources=[],targets=[],sids=[],tids=[],prepend_pop=True,method='convergence'):
-
-
- def total_connection_relationship(**kwargs): #reduced version of original function; only gets mean+std
+ if method == "min":
+ count = cons["target_node_id"].value_counts().min()
+ return round(count, 2)
+ elif method == "max":
+ count = cons["target_node_id"].value_counts().max()
+ return round(count, 2)
+ elif method == "std":
+ std = cons["target_node_id"].value_counts().std()
+ return round(std, 2)
+ elif method == "mean":
+ mean = cons["target_node_id"].value_counts().mean()
+ return round(mean, 2)
+ elif method == "mean+std": # default is mean + std
+ mean = cons["target_node_id"].value_counts().mean()
+ std = cons["target_node_id"].value_counts().std()
+ # std = cons.apply(pd.Series.value_counts).target_node_id.dropna().std() no longer a valid way
+ return (round(mean, 2)), (round(std, 2))
+ else: # divergence
+ if method == "min":
+ count = cons["source_node_id"].value_counts().min()
+ return round(count, 2)
+ elif method == "max":
+ count = cons["source_node_id"].value_counts().max()
+ return round(count, 2)
+ elif method == "std":
+ std = cons["source_node_id"].value_counts().std()
+ return round(std, 2)
+ elif method == "mean":
+ mean = cons["source_node_id"].value_counts().mean()
+ return round(mean, 2)
+ elif method == "mean+std": # default is mean + std
+ mean = cons["source_node_id"].value_counts().mean()
+ std = cons["source_node_id"].value_counts().std()
+ return (round(mean, 2)), (round(std, 2))
+
+ return relation_matrix(
+ config,
+ nodes,
+ edges,
+ sources,
+ targets,
+ sids,
+ tids,
+ prepend_pop,
+ relation_func=total_connection_relationship,
+ )
+
+
+ def gap_junction_connections(
+ config=None,
+ nodes=None,
+ edges=None,
+ sources=[],
+ targets=[],
+ sids=[],
+ tids=[],
+ prepend_pop=True,
+ method="convergence",
+ ):
+ def total_connection_relationship(
+ **kwargs
+ ): # reduced version of original function; only gets mean+std
  edges = kwargs["edges"]
  source_id_type = kwargs["sid"]
  target_id_type = kwargs["tid"]
  source_id = kwargs["source_id"]
  target_id = kwargs["target_id"]

- cons = edges[(edges[source_id_type] == source_id) & (edges[target_id_type]==target_id)]
- #print(cons)
-
- try:
- cons = cons[cons['is_gap_junction'] == True]
+ cons = edges[(edges[source_id_type] == source_id) & (edges[target_id_type] == target_id)]
+ # print(cons)
+
+ try:
+ cons = cons[cons["is_gap_junction"]]
  except:
  raise Exception("no gap junctions found to drop from connections")
- mean = cons['target_node_id'].value_counts().mean()
- std = cons['target_node_id'].value_counts().std()
- return (round(mean,2)), (round(std,2))
-
- def precent_func(**kwargs): #barely different than original function; only gets gap_junctions.
+ mean = cons["target_node_id"].value_counts().mean()
+ std = cons["target_node_id"].value_counts().std()
+ return (round(mean, 2)), (round(std, 2))
+
+ def precent_func(**kwargs): # barely different than original function; only gets gap_junctions.
  edges = kwargs["edges"]
  source_id_type = kwargs["sid"]
  target_id_type = kwargs["tid"]
@@ -796,34 +1016,68 @@ def gap_junction_connections(config=None,nodes=None,edges=None,sources=[],target
  t_list = kwargs["target_nodes"]
  s_list = kwargs["source_nodes"]

- cons = edges[(edges[source_id_type] == source_id) & (edges[target_id_type]==target_id)]
- #add functionality that shows only the one's with gap_junctions
- try:
- cons = cons[cons['is_gap_junction'] == True]
+ cons = edges[(edges[source_id_type] == source_id) & (edges[target_id_type] == target_id)]
+ # add functionality that shows only the one's with gap_junctions
+ try:
+ cons = cons[cons["is_gap_junction"]]
  except:
  raise Exception("no gap junctions found to drop from connections")
-
+
  total_cons = cons.count().source_node_id

  num_sources = s_list[source_id_type].value_counts().sort_index().loc[source_id]
  num_targets = t_list[target_id_type].value_counts().sort_index().loc[target_id]

-
- total = round(total_cons / (num_sources*num_targets) * 100,2) * 2 #not sure why but the percent is off by roughly 2 times ill make khuram fix it
+ total = (
+ round(total_cons / (num_sources * num_targets) * 100, 2) * 2
+ ) # not sure why but the percent is off by roughly 2 times ill make khuram fix it
  return total
-
- if method == 'convergence':
- return relation_matrix(config,nodes,edges,sources,targets,sids,tids,prepend_pop,relation_func=total_connection_relationship)
- elif method == 'percent':
- return relation_matrix(config,nodes,edges,sources,targets,sids,tids,prepend_pop,relation_func=precent_func)
-
-
- def connection_probabilities(config=None,nodes=None,edges=None,sources=[],
- targets=[],sids=[],tids=[],prepend_pop=True,dist_X=True,dist_Y=True,dist_Z=True,num_bins=10,include_gap=True):
-
+
+ if method == "convergence":
+ return relation_matrix(
+ config,
+ nodes,
+ edges,
+ sources,
+ targets,
+ sids,
+ tids,
+ prepend_pop,
+ relation_func=total_connection_relationship,
+ )
+ elif method == "percent":
+ return relation_matrix(
+ config,
+ nodes,
+ edges,
+ sources,
+ targets,
+ sids,
+ tids,
+ prepend_pop,
+ relation_func=precent_func,
+ )
+
+
+ def connection_probabilities(
+ config=None,
+ nodes=None,
+ edges=None,
+ sources=[],
+ targets=[],
+ sids=[],
+ tids=[],
+ prepend_pop=True,
+ dist_X=True,
+ dist_Y=True,
+ dist_Z=True,
+ num_bins=10,
+ include_gap=True,
+ ):
+ import matplotlib.pyplot as plt
  import pandas as pd
  from scipy.spatial import distance
- import matplotlib.pyplot as plt
+
  pd.options.mode.chained_assignment = None

  def connection_relationship(**kwargs):
@@ -835,7 +1089,6 @@ def connection_probabilities(config=None,nodes=None,edges=None,sources=[],
  t_list = kwargs["target_nodes"]
  s_list = kwargs["source_nodes"]

-
  """
  count = 1

@@ -852,82 +1105,124 @@ def connection_probabilities(config=None,nodes=None,edges=None,sources=[],
  total = total.source_node_id # may not be the best way to pick
  return round(total/count,1)
  """
-
- def eudist(df,use_x=True,use_y=True,use_z=True):
+
+ def eudist(df, use_x=True, use_y=True, use_z=True):
  def _dist(x):
  if len(x) == 6:
- return distance.euclidean((x.iloc[0], x.iloc[1], x.iloc[2]), (x.iloc[3], x.iloc[4], x.iloc[5]))
+ return distance.euclidean(
+ (x.iloc[0], x.iloc[1], x.iloc[2]), (x.iloc[3], x.iloc[4], x.iloc[5])
+ )
  elif len(x) == 4:
- return distance.euclidean((x.iloc[0],x.iloc[1]),(x.iloc[2],x.iloc[3]))
+ return distance.euclidean((x.iloc[0], x.iloc[1]), (x.iloc[2], x.iloc[3]))
  elif len(x) == 2:
- return distance.euclidean((x.iloc[0]),(x.iloc[1]))
+ return distance.euclidean((x.iloc[0]), (x.iloc[1]))
  else:
  return -1

- if use_x and use_y and use_z: #(XYZ)
- cols = ['source_pos_x','source_pos_y','source_pos_z',
- 'target_pos_x','target_pos_y','target_pos_z']
- elif use_x and use_y and not use_z: #(XY)
- cols = ['source_pos_x','source_pos_y',
- 'target_pos_x','target_pos_y',]
- elif use_x and not use_y and use_z: #(XZ)
- cols = ['source_pos_x','source_pos_z',
- 'target_pos_x','target_pos_z']
- elif not use_x and use_y and use_z: #(YZ)
- cols = ['source_pos_y','source_pos_z',
- 'target_pos_y','target_pos_z']
- elif use_x and not use_y and not use_z: #(X)
- cols = ['source_pos_x','target_pos_x']
- elif not use_x and use_y and not use_z: #(Y)
- cols = ['source_pos_y','target_pos_y']
- elif not use_x and not use_y and use_z: #(Z)
- cols = ['source_pos_z','target_pos_z']
+ if use_x and use_y and use_z: # (XYZ)
+ cols = [
+ "source_pos_x",
+ "source_pos_y",
+ "source_pos_z",
+ "target_pos_x",
+ "target_pos_y",
+ "target_pos_z",
+ ]
+ elif use_x and use_y and not use_z: # (XY)
+ cols = [
+ "source_pos_x",
+ "source_pos_y",
+ "target_pos_x",
+ "target_pos_y",
+ ]
+ elif use_x and not use_y and use_z: # (XZ)
+ cols = ["source_pos_x", "source_pos_z", "target_pos_x", "target_pos_z"]
+ elif not use_x and use_y and use_z: # (YZ)
+ cols = ["source_pos_y", "source_pos_z", "target_pos_y", "target_pos_z"]
+ elif use_x and not use_y and not use_z: # (X)
+ cols = ["source_pos_x", "target_pos_x"]
+ elif not use_x and use_y and not use_z: # (Y)
+ cols = ["source_pos_y", "target_pos_y"]
+ elif not use_x and not use_y and use_z: # (Z)
+ cols = ["source_pos_z", "target_pos_z"]
  else:
  cols = []

- if ('source_pos_x' in df and 'target_pos_x' in df) or ('source_pos_y' in df and 'target_pos_y' in df) or ('source_pos_' in df and 'target_pos_z' in df):
- ret = df.loc[:,cols].apply(_dist,axis=1)
+ if (
+ ("source_pos_x" in df and "target_pos_x" in df)
+ or ("source_pos_y" in df and "target_pos_y" in df)
+ or ("source_pos_" in df and "target_pos_z" in df)
+ ):
+ ret = df.loc[:, cols].apply(_dist, axis=1)
  else:
- print('No x, y, or z positions defined')
- ret=np.zeros(1)
-
+ print("No x, y, or z positions defined")
+ ret = np.zeros(1)
+
  return ret
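
For orientation, `eudist` just selects the source/target coordinate columns for the enabled axes and applies `scipy.spatial.distance.euclidean` row by row. A minimal sketch with toy XY data (column names follow the helper above):

    import pandas as pd
    from scipy.spatial import distance

    df = pd.DataFrame({
        "source_pos_x": [0.0, 1.0], "source_pos_y": [0.0, 1.0],
        "target_pos_x": [3.0, 1.0], "target_pos_y": [4.0, 3.0],
    })
    cols = ["source_pos_x", "source_pos_y", "target_pos_x", "target_pos_y"]
    d = df.loc[:, cols].apply(
        lambda r: distance.euclidean((r.iloc[0], r.iloc[1]), (r.iloc[2], r.iloc[3])),
        axis=1,
    )
    print(d.tolist())  # [5.0, 2.0]

Note that the guard in the helper still tests for "source_pos_" with no axis letter, which looks like a typo carried over from the old version.
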

- relevant_edges = edges[(edges[source_id_type] == source_id) & (edges[target_id_type]==target_id)]
- if include_gap == False:
- try:
- relevant_edges = relevant_edges[relevant_edges['is_gap_junction'] != True]
+ relevant_edges = edges[
+ (edges[source_id_type] == source_id) & (edges[target_id_type] == target_id)
+ ]
+ if not include_gap:
+ try:
+ relevant_edges = relevant_edges[~relevant_edges["is_gap_junction"]]
  except:
  raise Exception("no gap junctions found to drop from connections")
- connected_distances = eudist(relevant_edges,dist_X,dist_Y,dist_Z).values.tolist()
- if len(connected_distances)>0:
- if connected_distances[0]==0:
+ connected_distances = eudist(relevant_edges, dist_X, dist_Y, dist_Z).values.tolist()
+ if len(connected_distances) > 0:
+ if connected_distances[0] == 0:
  return -1
- sl = s_list[s_list[source_id_type]==source_id]
- tl = t_list[t_list[target_id_type]==target_id]
-
- target_rows = ["target_pos_x","target_pos_y","target_pos_z"]
-
+ sl = s_list[s_list[source_id_type] == source_id]
+ tl = t_list[t_list[target_id_type] == target_id]
+
+ target_rows = ["target_pos_x", "target_pos_y", "target_pos_z"]
+
  all_distances = []
  for target in tl.iterrows():
  target = target[1]
  for new_col in target_rows:
  sl[new_col] = target[new_col]
- #sl[target_rows] = target.loc[target_rows].tolist()
- row_distances = eudist(sl,dist_X,dist_Y,dist_Z).tolist()
+ # sl[target_rows] = target.loc[target_rows].tolist()
+ row_distances = eudist(sl, dist_X, dist_Y, dist_Z).tolist()
  all_distances = all_distances + row_distances
  plt.ioff()
- ns,bins,patches = plt.hist([connected_distances,all_distances],density=False,histtype='stepfilled',bins=num_bins)
+ ns, bins, patches = plt.hist(
+ [connected_distances, all_distances],
+ density=False,
+ histtype="stepfilled",
+ bins=num_bins,
+ )
  plt.ion()
- return {"ns":ns,"bins":bins}
- #import pdb;pdb.set_trace()
+ return {"ns": ns, "bins": bins}
+ # import pdb;pdb.set_trace()
  # edges contains all edges

- return relation_matrix(config,nodes,edges,sources,targets,sids,tids,prepend_pop,relation_func=connection_relationship,return_type=object,drop_point_process=True)
-
-
- def connection_graph_edge_types(config=None,nodes=None,edges=None,sources=[],targets=[],sids=[],tids=[],prepend_pop=True,edge_property='model_template'):
-
+ return relation_matrix(
+ config,
+ nodes,
+ edges,
+ sources,
+ targets,
+ sids,
+ tids,
+ prepend_pop,
+ relation_func=connection_relationship,
+ return_type=object,
+ drop_point_process=True,
+ )
+
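
`connection_relationship` returns the two stacked histograms as `{"ns": ns, "bins": bins}`, where `ns[0]` counts connected pairs and `ns[1]` counts all candidate pairs per distance bin. A sketch of the post-processing a caller might do (assumed, not part of the package) to turn that into a distance-dependent connection probability:

    import numpy as np

    ns = [np.array([5.0, 3.0, 1.0]), np.array([10.0, 12.0, 8.0])]  # connected, all pairs
    bins = np.array([0.0, 50.0, 100.0, 150.0])  # bin edges, same units as the positions
    with np.errstate(divide="ignore", invalid="ignore"):
        p = ns[0] / ns[1]  # estimated connection probability per distance bin
    print(dict(zip(bins[:-1].tolist(), p.tolist())))  # {0.0: 0.5, 50.0: 0.25, 100.0: 0.125}
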
+ def connection_graph_edge_types(
+ config=None,
+ nodes=None,
+ edges=None,
+ sources=[],
+ targets=[],
+ sids=[],
+ tids=[],
+ prepend_pop=True,
+ edge_property="model_template",
+ ):
  def synapse_type_relationship(**kwargs):
  edges = kwargs["edges"]
  source_id_type = kwargs["sid"]
@@ -935,21 +1230,46 @@ def connection_graph_edge_types(config=None,nodes=None,edges=None,sources=[],tar
  source_id = kwargs["source_id"]
  target_id = kwargs["target_id"]

- connections = edges[(edges[source_id_type] == source_id) & (edges[target_id_type]==target_id)]
-
- return list(connections[edge_property].unique())
-
- return relation_matrix(config,nodes,edges,sources,targets,sids,tids,prepend_pop,relation_func=synapse_type_relationship,return_type=object)
+ connections = edges[
+ (edges[source_id_type] == source_id) & (edges[target_id_type] == target_id)
+ ]

+ return list(connections[edge_property].unique())

- def edge_property_matrix(edge_property, config=None, nodes=None, edges=None, sources=[],targets=[],sids=[],tids=[],prepend_pop=True,report=None,time=-1,time_compare=None):
-
+ return relation_matrix(
+ config,
+ nodes,
+ edges,
+ sources,
+ targets,
+ sids,
+ tids,
+ prepend_pop,
+ relation_func=synapse_type_relationship,
+ return_type=object,
+ )
+
+
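
`connection_graph_edge_types` reduces each source/target pair to the set of distinct `edge_property` values on its edges. The core of `synapse_type_relationship` on a toy edge table (column values invented):

    import pandas as pd

    edges = pd.DataFrame({
        "source_node_type_id": [100, 100, 100],
        "target_node_type_id": [200, 200, 200],
        "model_template": ["exp2syn", "exp2syn", "gap_junction"],
    })
    pair = edges[(edges.source_node_type_id == 100) & (edges.target_node_type_id == 200)]
    print(list(pair["model_template"].unique()))  # ['exp2syn', 'gap_junction']
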
+ def edge_property_matrix(
+ edge_property,
+ config=None,
+ nodes=None,
+ edges=None,
+ sources=[],
+ targets=[],
+ sids=[],
+ tids=[],
+ prepend_pop=True,
+ report=None,
+ time=-1,
+ time_compare=None,
+ ):
  var_report = None
- if time>=0 and report:
+ if time >= 0 and report:
  cfg = load_config(config)
- #report_full, report_file = _get_cell_report(config,report)
- report_file = report # Same difference
- var_report = EdgeVarsFile(os.path.join(cfg['output']['output_dir'],report_file+'.h5'))
+ # report_full, report_file = _get_cell_report(config,report)
+ report_file = report # Same difference
+ var_report = EdgeVarsFile(os.path.join(cfg["output"]["output_dir"], report_file + ".h5"))

  def weight_hist_relationship(**kwargs):
  edges = kwargs["edges"]
@@ -958,36 +1278,60 @@ def edge_property_matrix(edge_property, config=None, nodes=None, edges=None, sou
  source_id = kwargs["source_id"]
  target_id = kwargs["target_id"]

- connections = edges[(edges[source_id_type] == source_id) & (edges[target_id_type]==target_id)]
+ connections = edges[
+ (edges[source_id_type] == source_id) & (edges[target_id_type] == target_id)
+ ]
  nonlocal time, report, var_report
  ret = []

- if time>=0 and report:
- sources = list(connections['source_node_id'].unique())
+ if time >= 0 and report:
+ sources = list(connections["source_node_id"].unique())
  sources.sort()
- targets = list(connections['target_node_id'].unique())
- targets.sort()
-
- data,sources,targets = get_synapse_vars(None,None,edge_property,targets,source_gids=sources,compartments='all',var_report=var_report,time=time,time_compare=time_compare)
- if len(data.shape) and data.shape[0]!=0:
- ret = data[:,0]
+ targets = list(connections["target_node_id"].unique())
+ targets.sort()
+
+ data, sources, targets = get_synapse_vars(
+ None,
+ None,
+ edge_property,
+ targets,
+ source_gids=sources,
+ compartments="all",
+ var_report=var_report,
+ time=time,
+ time_compare=time_compare,
+ )
+ if len(data.shape) and data.shape[0] != 0:
+ ret = data[:, 0]
  else:
  ret = []
  else:
- #if connections.get(edge_property) is not None: #Maybe we should fail if we can't find the variable...
+ # if connections.get(edge_property) is not None: #Maybe we should fail if we can't find the variable...
  ret = list(connections[edge_property])

  return ret

- return relation_matrix(config,nodes,edges,sources,targets,sids,tids,prepend_pop,relation_func=weight_hist_relationship,return_type=object)
-
-
- def percent_connectivity(config=None,nodes=None,edges=None,sources=[],targets=[],sids=[],tids=[],prepend_pop=True):
-
+ return relation_matrix(
+ config,
+ nodes,
+ edges,
+ sources,
+ targets,
+ sids,
+ tids,
+ prepend_pop,
+ relation_func=weight_hist_relationship,
+ return_type=object,
+ )
+
+
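
`edge_property_matrix` fills each cell of the relation matrix with a list of values rather than a scalar. Without a report, `weight_hist_relationship` simply collects the requested column for every edge in the pair, ready for a histogram; with `report`/`time` set it pulls the values from the saved report via `get_synapse_vars` instead. The no-report branch on toy data (values invented):

    import pandas as pd

    edges = pd.DataFrame({
        "source_node_type_id": [100, 100, 100],
        "target_node_type_id": [200, 200, 201],
        "syn_weight": [0.5, 0.7, 0.9],
    })
    pair = edges[(edges.source_node_type_id == 100) & (edges.target_node_type_id == 200)]
    print(list(pair["syn_weight"]))  # [0.5, 0.7]
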
+ def percent_connectivity(
+ config=None, nodes=None, edges=None, sources=[], targets=[], sids=[], tids=[], prepend_pop=True
+ ):
  import pandas as pd
-
+
  if not nodes and not edges:
- nodes,edges = load_nodes_edges_from_config(config)
+ nodes, edges = load_nodes_edges_from_config(config)
  if not nodes:
  nodes = load_nodes_from_config(config)
  if not edges:
@@ -995,38 +1339,49 @@ def percent_connectivity(config=None,nodes=None,edges=None,sources=[],targets=[]
  if not edges and not nodes and not config:
  raise Exception("No information given to load nodes/edges")

- data, source_labels, target_labels = connection_totals(config=config,nodes=None,edges=None,sources=sources,targets=targets,sids=sids,tids=tids,prepend_pop=prepend_pop)
-
- #total_cell_types = len(list(set(nodes[populations[0]]["node_type_id"])))
+ data, source_labels, target_labels = connection_totals(
+ config=config,
+ nodes=None,
+ edges=None,
+ sources=sources,
+ targets=targets,
+ sids=sids,
+ tids=tids,
+ prepend_pop=prepend_pop,
+ )
+
+ # total_cell_types = len(list(set(nodes[populations[0]]["node_type_id"])))
  vc = nodes[sources[0]].apply(pd.Series.value_counts)
  vc = vc["node_type_id"].dropna().sort_index()
  vc = list(vc)

- max_connect = np.ones((len(vc),len(vc)),dtype=np.float)
+ max_connect = np.ones((len(vc), len(vc)), dtype=np.float)

  for a, i in enumerate(vc):
  for b, j in enumerate(vc):
- max_connect[a,b] = i*j
- ret = data/max_connect
- ret = ret*100
+ max_connect[a, b] = i * j
+ ret = data / max_connect
+ ret = ret * 100
  ret = np.around(ret, decimals=1)

  return ret, source_labels, target_labels
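
The normalization in `percent_connectivity` divides the observed totals by the cell-count product of each population pair. A worked check with invented counts (note also that `np.float`, kept above as published, was removed in NumPy 1.24; `float` or `np.float64` is the modern spelling):

    import numpy as np

    vc = [3, 5]  # cells per node type
    data = np.array([[3.0, 6.0], [5.0, 15.0]])  # observed connection totals
    max_connect = np.ones((len(vc), len(vc)))
    for a, i in enumerate(vc):
        for b, j in enumerate(vc):
            max_connect[a, b] = i * j  # ceiling: every cell to every cell
    print(np.around(data / max_connect * 100, decimals=1))
    # [[33.3 40. ]
    #  [33.3 60. ]]
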
-
+

  def connection_average_synapses():
  return


- def connection_divergence_average_old(config=None, nodes=None, edges=None,populations=[],convergence=False):
+ def connection_divergence_average_old(
+ config=None, nodes=None, edges=None, populations=[], convergence=False
+ ):
  """
  For each cell in source count # of connections in target and average
  """

  import pandas as pd
-
+
  if not nodes and not edges:
- nodes,edges = load_nodes_edges_from_config(config)
+ nodes, edges = load_nodes_edges_from_config(config)
  if not nodes:
  nodes = load_nodes_from_config(config)
  if not edges:
@@ -1039,26 +1394,45 @@ def connection_divergence_average_old(config=None, nodes=None, edges=None,popula

  nodes = nodes[list(nodes)[1]]
  edges = edges[list(edges)[1]]
- pdb.set_trace()
- src_df = pd.DataFrame({'edge_node_id': nodes.index,'source_node_pop_name':nodes['pop_name'],'source_node_type_id':nodes['node_type_id']})
- tgt_df = pd.DataFrame({'edge_node_id': nodes.index,'target_node_pop_name':nodes['pop_name'],'target_node_type_id':nodes['node_type_id']})
-
- src_df.set_index('edge_node_id', inplace=True)
- tgt_df.set_index('edge_node_id', inplace=True)
-
- edges_df = pd.merge(left=edges,
- right=src_df,
- how='left',
- left_on='source_node_id',
- right_index=True)
-
- edges_df = pd.merge(left=edges_df,
- right=tgt_df,
- how='left',
- left_on='target_node_id',
- right_index=True)
-
- edges_df_trim = edges_df.drop(edges_df.columns.difference(['source_node_type_id','target_node_type_id','source_node_pop_name','target_node_pop_name']), 1, inplace=False)
+ # Debug statement removed
+ src_df = pd.DataFrame(
+ {
+ "edge_node_id": nodes.index,
+ "source_node_pop_name": nodes["pop_name"],
+ "source_node_type_id": nodes["node_type_id"],
+ }
+ )
+ tgt_df = pd.DataFrame(
+ {
+ "edge_node_id": nodes.index,
+ "target_node_pop_name": nodes["pop_name"],
+ "target_node_type_id": nodes["node_type_id"],
+ }
+ )
+
+ src_df.set_index("edge_node_id", inplace=True)
+ tgt_df.set_index("edge_node_id", inplace=True)
+
+ edges_df = pd.merge(
+ left=edges, right=src_df, how="left", left_on="source_node_id", right_index=True
+ )
+
+ edges_df = pd.merge(
+ left=edges_df, right=tgt_df, how="left", left_on="target_node_id", right_index=True
+ )
+
+ edges_df_trim = edges_df.drop(
+ edges_df.columns.difference(
+ [
+ "source_node_type_id",
+ "target_node_type_id",
+ "source_node_pop_name",
+ "target_node_pop_name",
+ ]
+ ),
+ 1,
+ inplace=False,
+ )

  vc = nodes.apply(pd.Series.value_counts)
  vc = vc["node_type_id"].dropna().sort_index()
@@ -1072,16 +1446,20 @@ def connection_divergence_average_old(config=None, nodes=None, edges=None,popula
  """
  src_list_node_types = list(set(edges_df_trim["source_node_type_id"]))
  tgt_list_node_types = list(set(edges_df_trim["target_node_type_id"]))
- node_types = list(set(src_list_node_types+tgt_list_node_types))
+ node_types = list(set(src_list_node_types + tgt_list_node_types))

- e_matrix = np.zeros((len(node_types),len(node_types)))
+ e_matrix = np.zeros((len(node_types), len(node_types)))

  for a, i in enumerate(node_types):
  for b, j in enumerate(node_types):
- num_conns = edges_df_trim[(edges_df_trim.source_node_type_id == i) & (edges_df_trim.target_node_type_id==j)].count()
- c = b if convergence else a #Show convergence if set. By dividing by targe totals instead of source
+ num_conns = edges_df_trim[
+ (edges_df_trim.source_node_type_id == i) & (edges_df_trim.target_node_type_id == j)
+ ].count()
+ c = (
+ b if convergence else a
+ ) # Show convergence if set. By dividing by targe totals instead of source

- e_matrix[a,b] = num_conns.source_node_type_id/vc[c]
+ e_matrix[a, b] = num_conns.source_node_type_id / vc[c]

  ret = np.around(e_matrix, decimals=1)

@@ -1089,29 +1467,39 @@ def connection_divergence_average_old(config=None, nodes=None, edges=None,popula


  class EdgeVarsFile(CellVarsFile):
- def __init__(self, filename, mode='r', **params):
+ def __init__(self, filename, mode="r", **params):
  super().__init__(filename, mode, **params)
  self._var_src_ids = []
  self._var_trg_ids = []
- for var_name in self._h5_root['mapping'].keys():
- if var_name == 'src_ids':
- self._var_src_ids = list(self._h5_root['mapping']['src_ids'])
- if var_name == 'trg_ids':
- self._var_trg_ids = list(self._h5_root['mapping']['trg_ids'])
- def sources(self,target_gid=None):
+ for var_name in self._h5_root["mapping"].keys():
+ if var_name == "src_ids":
+ self._var_src_ids = list(self._h5_root["mapping"]["src_ids"])
+ if var_name == "trg_ids":
+ self._var_trg_ids = list(self._h5_root["mapping"]["trg_ids"])
+
+ def sources(self, target_gid=None):
  if target_gid:
  tb = self._gid2data_table[target_gid]
- return self._h5_root['mapping']['src_ids'][tb[0]:tb[1]]
+ return self._h5_root["mapping"]["src_ids"][tb[0] : tb[1]]
  else:
  return self._var_src_ids
+
  def targets(self):
  return self._var_trg_ids
- def data(self,gid,var_name=CellVarsFile.VAR_UNKNOWN,time_window=None,compartments='origin',sources=None):
- d = super().data(gid,var_name,time_window,compartments)
+
+ def data(
+ self,
+ gid,
+ var_name=CellVarsFile.VAR_UNKNOWN,
+ time_window=None,
+ compartments="origin",
+ sources=None,
+ ):
+ d = super().data(gid, var_name, time_window, compartments)
  if not sources:
  return d
  else:
- if type(sources) is int:
+ if isinstance(sources, int):
  sources = [sources]
  d_new = None
  for dl, s in zip(d, self.sources()):
@@ -1119,238 +1507,254 @@ class EdgeVarsFile(CellVarsFile):
  if d_new is None:
  d_new = np.array([dl])
  else:
- d_new = np.append(d_new, [dl],axis=0)
+ d_new = np.append(d_new, [dl], axis=0)
  if d_new is None:
- d_new = np.array([])
+ d_new = np.array([])
  return d_new
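
Putting `EdgeVarsFile` together: it maps a SONATA-style edge report, exposes the source gids recorded for each target, and can filter a target's traces down to chosen sources. Hypothetical usage (file and variable names invented):

    # evf = EdgeVarsFile("output/syn_report.h5")
    # evf.sources(target_gid=31)               # source gids recorded for cell 31
    # trace = evf.data(gid=31, var_name="W_ampa", sources=[10, 12])
    # trace.shape                              # (n_matching_sources, n_timesteps)
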


- def get_synapse_vars(config,report,var_name,target_gids,source_gids=None,compartments='all',var_report=None,time=None,time_compare=None):
+ def get_synapse_vars(
+ config,
+ report,
+ var_name,
+ target_gids,
+ source_gids=None,
+ compartments="all",
+ var_report=None,
+ time=None,
+ time_compare=None,
+ ):
  """
  Ex: data, sources = get_synapse_vars('9999_simulation_config.json', 'syn_report', 'W_ampa', 31)
  """
  if not var_report:
  cfg = load_config(config)
- #report, report_file = _get_cell_report(config,report)
- report_file = report # Same difference
- var_report = EdgeVarsFile(os.path.join(cfg['output']['output_dir'],report_file+'.h5'))
+ # report, report_file = _get_cell_report(config,report)
+ report_file = report # Same difference
+ var_report = EdgeVarsFile(os.path.join(cfg["output"]["output_dir"], report_file + ".h5"))

- if type(target_gids) is int:
+ if isinstance(target_gids, int):
  target_gids = [target_gids]
-
+
  data_ret = None
  sources_ret = None
  targets_ret = None
-
+
  for target_gid in target_gids:
- if not var_report._gid2data_table.get(target_gid):#This cell was not reported
+ if not var_report._gid2data_table.get(target_gid): # This cell was not reported
  continue
  data = var_report.data(gid=target_gid, var_name=var_name, compartments=compartments)
- if(len(data.shape)==1):
- data = data.reshape(1,-1)
+ if len(data.shape) == 1:
+ data = data.reshape(1, -1)

  if time is not None and time_compare is not None:
- data = np.array(data[:,time_compare] - data[:,time]).reshape(-1,1)
+ data = np.array(data[:, time_compare] - data[:, time]).reshape(-1, 1)
  elif time is not None:
- data = np.delete(data,np.s_[time+1:],1)
- data = np.delete(data,np.s_[:time],1)
+ data = np.delete(data, np.s_[time + 1 :], 1)
+ data = np.delete(data, np.s_[:time], 1)

  sources = var_report.sources(target_gid=target_gid)
  if source_gids:
- if type(source_gids) is int:
+ if isinstance(source_gids, int):
  source_gids = [source_gids]
- data = [d for d,s in zip(data,sources) if s in source_gids]
+ data = [d for d, s in zip(data, sources) if s in source_gids]
  sources = [s for s in sources if s in source_gids]
-
+
  targets = np.zeros(len(sources))
  targets.fill(target_gid)

- if data_ret is None or data_ret is not None and len(data_ret)==0:
+ if data_ret is None or data_ret is not None and len(data_ret) == 0:
  data_ret = data
  else:
- data_ret = np.append(data_ret, data,axis=0)
- if sources_ret is None or sources_ret is not None and len(sources_ret)==0:
+ data_ret = np.append(data_ret, data, axis=0)
+ if sources_ret is None or sources_ret is not None and len(sources_ret) == 0:
  sources_ret = sources
  else:
- sources_ret = np.append(sources_ret, sources,axis=0)
- if targets_ret is None or targets_ret is not None and len(targets_ret)==0:
+ sources_ret = np.append(sources_ret, sources, axis=0)
+ if targets_ret is None or targets_ret is not None and len(targets_ret) == 0:
  targets_ret = targets
  else:
- targets_ret = np.append(targets_ret, targets,axis=0)
+ targets_ret = np.append(targets_ret, targets, axis=0)

  return np.array(data_ret), np.array(sources_ret), np.array(targets_ret)
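
Beyond the docstring's minimal call, the extra keywords narrow the result. Hypothetical usage (paths invented) pulling one variable for two targets, keeping a single source, and reducing the time axis to a difference between two sample indices:

    # data, srcs, tgts = get_synapse_vars(
    #     "9999_simulation_config.json", "syn_report", "W_ampa", [31, 32],
    #     source_gids=10, time=0, time_compare=-1,
    # )
    # data[:, 0]  # per-synapse change between the two samples
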


- def tk_email_input(title="Send Model Files (with simplified GUI)",prompt="Enter your email address. (CHECK YOUR SPAM FOLDER)"):
+ def tk_email_input(
+ title="Send Model Files (with simplified GUI)",
+ prompt="Enter your email address. (CHECK YOUR SPAM FOLDER)",
+ ):
  import tkinter as tk
  from tkinter import simpledialog
+
  root = tk.Tk()
  root.withdraw()
  # the input dialog
  user_inp = simpledialog.askstring(title=title, prompt=prompt)
  return user_inp

+
  def popupmsg(msg):
  import tkinter as tk
  from tkinter import ttk
+
  popup = tk.Tk()
  popup.wm_title("!")
  NORM_FONT = ("Helvetica", 10)
  label = ttk.Label(popup, text=msg, font=NORM_FONT)
  label.pack(side="top", fill="x", pady=10)
- B1 = ttk.Button(popup, text="Okay", command = popup.destroy)
+ B1 = ttk.Button(popup, text="Okay", command=popup.destroy)
  B1.pack()
  popup.mainloop()

- import smtplib
- from os.path import basename
- from email.mime.application import MIMEApplication
- from email.mime.multipart import MIMEMultipart
- from email.mime.text import MIMEText
- from email.utils import COMMASPACE, formatdate

-
- def send_mail(send_from, send_to, subject, text, files=None,server="127.0.0.1"):
+ def send_mail(send_from, send_to, subject, text, files=None, server="127.0.0.1"):
  assert isinstance(send_to, list)
  msg = MIMEMultipart()
- msg['From'] = send_from
- msg['To'] = COMMASPACE.join(send_to)
- msg['Date'] = formatdate(localtime=True)
- msg['Subject'] = subject
+ msg["From"] = send_from
+ msg["To"] = COMMASPACE.join(send_to)
+ msg["Date"] = formatdate(localtime=True)
+ msg["Subject"] = subject
  msg.attach(MIMEText(text))
  for f in files or []:
  with open(f, "rb") as fil:
- part = MIMEApplication(
- fil.read(),
- Name=basename(f)
- )
+ part = MIMEApplication(fil.read(), Name=basename(f))
  # After the file is closed
- part['Content-Disposition'] = 'attachment; filename="%s"' % basename(f)
+ part["Content-Disposition"] = 'attachment; filename="%s"' % basename(f)
  msg.attach(part)
  smtp = smtplib.SMTP(server)
  smtp.sendmail(send_from, send_to, msg.as_string())
  smtp.close()
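
A sketch of calling `send_mail` (addresses, file path, and server are placeholders; `send_to` must be a list, and each file is attached under its basename):

    # send_mail(
    #     "model@example.org",
    #     ["user@example.org"],
    #     "Model files",
    #     "Run finished; files attached.",
    #     files=["output/results.zip"],
    #     server="127.0.0.1",
    # )
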

+
  def load_csv(csvfile):
  # TODO: make the separator more flexible
  if isinstance(csvfile, pd.DataFrame):
  return csvfile

  # TODO: check if it is csv object and convert to a pd dataframe
- return pd.read_csv(csvfile, sep=' ', na_values='NONE')
+ return pd.read_csv(csvfile, sep=" ", na_values="NONE")
+
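
`load_csv` expects space-separated text with `NONE` as the missing-value marker. A minimal round trip on an in-memory buffer (toy data):

    import io
    import pandas as pd

    buf = io.StringIO("node_id pop_name\n0 PN\n1 NONE\n")
    df = pd.read_csv(buf, sep=" ", na_values="NONE")
    print(df["pop_name"].isna().tolist())  # [False, True]
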

  # The following code was developed by Matthew Stroud 7/15/21: Neural engineering group supervisor: Satish Nair
- # This is an extension of bmtool: a development of Tyler Banks.
+ # This is an extension of bmtool: a development of Tyler Banks.
  # These are helper functions for I_clamps and spike train information.

+
  def load_I_clamp_from_paths(Iclamp_paths):
  # Get info from .h5 files
- if Iclamp_paths.endswith('.h5'):
- f = h5py.File(Iclamp_paths, 'r')
- if 'amplitudes' in f and 'dts' in f and 'gids' in f:
- [amplitudes]=f['amplitudes'][:].tolist()
- dts=list(f['dts'])
- dts=dts[0]
- dset=f['gids']
+ if Iclamp_paths.endswith(".h5"):
+ f = h5py.File(Iclamp_paths, "r")
+ if "amplitudes" in f and "dts" in f and "gids" in f:
+ [amplitudes] = f["amplitudes"][:].tolist()
+ dts = list(f["dts"])
+ dts = dts[0]
+ dset = f["gids"]
  gids = dset[()]
- if gids == 'all':
- gids=' All biophysical cells'
- clamp=[amplitudes,dts,gids]
+ if gids == "all":
+ gids = " All biophysical cells"
+ clamp = [amplitudes, dts, gids]
  else:
- raise Exception('.h5 file is not in the format "amplitudes","dts","gids". Cannot parse.')
+ raise Exception(
+ '.h5 file is not in the format "amplitudes","dts","gids". Cannot parse.'
+ )
  else:
- raise Exception('Input file is not of type .h5. Cannot parse.')
+ raise Exception("Input file is not of type .h5. Cannot parse.")
  return clamp

+
  def load_I_clamp_from_config(fp):
  if fp is None:
- fp = 'config.json'
+ fp = "config.json"
  config = load_config(fp)
- inputs=config['inputs']
+ inputs = config["inputs"]
  # Load in all current clamps
- ICLAMPS=[]
+ ICLAMPS = []
  for i in inputs:
- if inputs[i]['input_type']=="current_clamp":
- I_clamp=inputs[i]
+ if inputs[i]["input_type"] == "current_clamp":
+ I_clamp = inputs[i]
  # Get current clamp info where an input file is provided
- if 'input_file' in I_clamp:
- ICLAMPS.append(load_I_clamp_from_paths(I_clamp['input_file']))
+ if "input_file" in I_clamp:
+ ICLAMPS.append(load_I_clamp_from_paths(I_clamp["input_file"]))
  # Get current clamp info when provided in "amp", "delay", "duration" format
- elif 'amp' in I_clamp and 'delay' in I_clamp and 'duration' in I_clamp:
+ elif "amp" in I_clamp and "delay" in I_clamp and "duration" in I_clamp:
  # Get simulation info from config
- run=config['run']
- dts=run['dt']
- if 'tstart' in run:
- tstart=run['tstart']
+ run = config["run"]
+ dts = run["dt"]
+ if "tstart" in run:
+ tstart = run["tstart"]
  else:
- tstart=0
- tstop=run['tstop']
- simlength=tstop-tstart
- nstep=int(simlength/dts)
+ tstart = 0
+ tstop = run["tstop"]
+ simlength = tstop - tstart
+ nstep = int(simlength / dts)
  # Get input info from config
- amp=I_clamp['amp']
- gids=I_clamp['node_set']
- delay=I_clamp['delay']
- duration=I_clamp['duration']
+ amp = I_clamp["amp"]
+ gids = I_clamp["node_set"]
+ delay = I_clamp["delay"]
+ duration = I_clamp["duration"]
  # Create a list with amplitude at each time step in the simulation
- amplitude=[]
+ amplitude = []
  for i in range(nstep):
- if i*dts>=delay and i*dts<=delay+duration:
+ if i * dts >= delay and i * dts <= delay + duration:
  amplitude.append(amp)
  else:
  amplitude.append(0)
- ICLAMPS.append([amplitude,dts,gids])
+ ICLAMPS.append([amplitude, dts, gids])
  else:
- raise Exception('No information found about this current clamp.')
+ raise Exception("No information found about this current clamp.")
  return ICLAMPS

+
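
The amp/delay/duration branch above expands a step current into one amplitude per simulation time step. A worked example with invented numbers (dt = 0.5 ms, tstop = 5 ms, a 0.2 nA step from 1 ms to 3 ms):

    dts, tstart, tstop = 0.5, 0, 5.0
    amp, delay, duration = 0.2, 1.0, 2.0
    nstep = int((tstop - tstart) / dts)
    amplitude = [amp if delay <= i * dts <= delay + duration else 0 for i in range(nstep)]
    print(amplitude)  # [0, 0, 0.2, 0.2, 0.2, 0.2, 0.2, 0, 0, 0]
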
  def load_inspikes_from_paths(inspike_paths):
  # Get info from .h5 files
- if inspike_paths.endswith('.h5'):
- f = h5py.File(inspike_paths, 'r')
+ if inspike_paths.endswith(".h5"):
+ f = h5py.File(inspike_paths, "r")
  # This is assuming that the first object in the file is named 'spikes'
- spikes=f['spikes']
+ spikes = f["spikes"]
  for i in spikes:
- inp=spikes[i]
- if 'node_ids' in inp and 'timestamps' in inp:
- node_ids=list(inp['node_ids'])
- timestamps=list(inp['timestamps'])
- data=[]
+ inp = spikes[i]
+ if "node_ids" in inp and "timestamps" in inp:
+ node_ids = list(inp["node_ids"])
+ timestamps = list(inp["timestamps"])
+ data = []
  for j in range(len(node_ids)):
- data.append([str(timestamps[j]),str(node_ids[j]),''])
- data=np.array(data, dtype=object)
- elif inspike_paths.endswith('.csv'):
+ data.append([str(timestamps[j]), str(node_ids[j]), ""])
+ data = np.array(data, dtype=object)
+ elif inspike_paths.endswith(".csv"):
  # Loads in .csv and if it is of the form (timestamps node_ids population) it skips the conditionals.
- data=np.loadtxt(open(inspike_paths, 'r'), delimiter=" ",dtype=object,skiprows=1)
- if len(data[0])==2: #This assumes gid in first column and spike times comma separated in second column
- temp=[]
+ data = np.loadtxt(open(inspike_paths, "r"), delimiter=" ", dtype=object, skiprows=1)
+ if (
+ len(data[0]) == 2
+ ): # This assumes gid in first column and spike times comma separated in second column
+ temp = []
  for i in data:
- timestamps=i[1]
- timestamps=timestamps.split(',')
+ timestamps = i[1]
+ timestamps = timestamps.split(",")
  for j in timestamps:
- temp.append([j,i[0],''])
- data=np.array(temp, dtype=object)
- #If the .csv is not in the form (timestamps node_ids population) or (gid timestamps)
- elif not len(data[0])==3:
- print('The .csv spike file '+ inspike_paths +' is not in the correct format')
+ temp.append([j, i[0], ""])
+ data = np.array(temp, dtype=object)
+ # If the .csv is not in the form (timestamps node_ids population) or (gid timestamps)
+ elif not len(data[0]) == 3:
+ print("The .csv spike file " + inspike_paths + " is not in the correct format")
  return
  else:
- raise Exception('Input file is not of type .h5 or .csv. Cannot parse.')
+ raise Exception("Input file is not of type .h5 or .csv. Cannot parse.")
  return data

+
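
The two-column CSV fallback above turns each (gid, "t1,t2,...") row into one [time, gid, ""] triplet per spike, matching the three-column layout the rest of the loader expects. The same expansion on toy rows:

    import numpy as np

    rows = [["0", "10.5,12.0"], ["3", "11.0"]]
    temp = []
    for gid, times in rows:
        for t in times.split(","):
            temp.append([t, gid, ""])
    data = np.array(temp, dtype=object)
    print(data.tolist())  # [['10.5', '0', ''], ['12.0', '0', ''], ['11.0', '3', '']]
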
  def load_inspikes_from_config(fp):
  if fp is None:
- fp = 'config.json'
+ fp = "config.json"
  config = load_config(fp)
- inputs=config['inputs']
+ inputs = config["inputs"]
  # Load in all current clamps
- INSPIKES=[]
+ INSPIKES = []
  for i in inputs:
- if inputs[i]['input_type']=="spikes":
- INPUT=inputs[i]
+ if inputs[i]["input_type"] == "spikes":
+ INPUT = inputs[i]
  # Get current clamp info where an input file is provided
- if 'input_file' in INPUT:
- INSPIKES.append(load_inspikes_from_paths(INPUT['input_file']))
+ if "input_file" in INPUT:
+ INSPIKES.append(load_inspikes_from_paths(INPUT["input_file"]))
  else:
- raise Exception('No information found about this current clamp.')
+ raise Exception("No information found about this current clamp.")
  return INSPIKES