bmtool 0.7.0.6.4__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bmtool/connectors.py CHANGED
@@ -1,32 +1,33 @@
- from abc import ABC, abstractmethod
- import numpy as np
- from scipy.special import erf
- from scipy.optimize import minimize_scalar
- from functools import partial
  import time
  import types
+ from abc import ABC, abstractmethod
+ from functools import partial
+
+ import numpy as np
  import pandas as pd
- import re
+ from scipy.optimize import minimize_scalar
+ from scipy.special import erf

  rng = np.random.default_rng()

- report_name = 'conn.csv'
+ report_name = "conn.csv"

  ##############################################################################
  ############################## CONNECT CELLS #################################

+
  # Utility Functions
  def num_prop(ratio, N):
  """
  Calculate numbers of total N in proportion to ratio.
-
+
  Parameters:
  -----------
  ratio : array-like
  Proportions to distribute N across.
  N : int
  Total number to distribute.
-
+
  Returns:
  --------
  numpy.ndarray
@@ -40,14 +41,14 @@ def num_prop(ratio, N):
  def decision(prob, size=None):
  """
  Make random decision(s) based on input probability.
-
+
  Parameters:
  -----------
  prob : float
  Probability threshold between 0 and 1.
  size : int or tuple, optional
  Size of the output array. If None, a single decision is returned.
-
+
  Returns:
  --------
  bool or numpy.ndarray
@@ -60,16 +61,16 @@ def decision(prob, size=None):
  def decisions(prob):
  """
  Make multiple random decisions based on input probabilities.
-
+
  Parameters:
  -----------
  prob : array-like
  Array of probability thresholds between 0 and 1.
-
+
  Returns:
  --------
  numpy.ndarray
- Boolean array with the same shape as prob, containing results of
+ Boolean array with the same shape as prob, containing results of
  the random decisions.
  """
  prob = np.asarray(prob)
@@ -82,17 +83,17 @@ def euclid_dist(p1, p2):
  p1, p2: Coordinates in numpy array
  """
  dvec = np.asarray(p1) - np.asarray(p2)
- return (dvec @ dvec) ** .5
+ return (dvec @ dvec) ** 0.5


  def spherical_dist(node1, node2):
  """Spherical distance between two input nodes"""
- return euclid_dist(node1['positions'], node2['positions']).item()
+ return euclid_dist(node1["positions"], node2["positions"]).item()


  def cylindrical_dist_z(node1, node2):
  """Cylindircal distance between two input nodes (ignoring z-axis)"""
- return euclid_dist(node1['positions'][:2], node2['positions'][:2]).item()
+ return euclid_dist(node1["positions"][:2], node2["positions"][:2]).item()


  # Probability Classes
@@ -118,8 +119,8 @@ class ProbabilityFunction(ABC):
  class DistantDependentProbability(ProbabilityFunction):
  """Base class for distance dependent probability"""

- def __init__(self, min_dist=0., max_dist=np.inf):
- assert(min_dist >= 0 and min_dist < max_dist)
+ def __init__(self, min_dist=0.0, max_dist=np.inf):
+ assert min_dist >= 0 and min_dist < max_dist
  self.min_dist, self.max_dist = min_dist, max_dist

  def __call__(self, dist, *arg, **kwargs):
@@ -127,7 +128,7 @@ class DistantDependentProbability(ProbabilityFunction):
  if dist >= self.min_dist and dist <= self.max_dist:
  return self.probability(dist)
  else:
- return 0.
+ return 0.0

  def decisions(self, dist):
  """Return bool array of decisions given distance array"""
@@ -144,31 +145,32 @@ class DistantDependentProbability(ProbabilityFunction):
  class UniformInRange(DistantDependentProbability):
  """Constant probability within a distance range"""

- def __init__(self, p=0., min_dist=0., max_dist=np.inf):
+ def __init__(self, p=0.0, min_dist=0.0, max_dist=np.inf):
  super().__init__(min_dist=min_dist, max_dist=max_dist)
  self.p = np.array(p)
- assert(self.p.size == 1)
- assert(p >= 0. and p <= 1.)
+ assert self.p.size == 1
+ assert p >= 0.0 and p <= 1.0

  def probability(self, dist):
  return self.p


- NORM_COEF = (2 * np.pi) ** (-.5) # coefficient of standard normal PDF
+ NORM_COEF = (2 * np.pi) ** (-0.5) # coefficient of standard normal PDF
+

- def gaussian(x, mean=0., stdev=1., pmax=NORM_COEF):
+ def gaussian(x, mean=0.0, stdev=1.0, pmax=NORM_COEF):
  """Gaussian function. Default is the PDF of standard normal distribution"""
  x = (x - mean) / stdev
- return pmax * np.exp(- x * x / 2)
+ return pmax * np.exp(-x * x / 2)


  class GaussianDropoff(DistantDependentProbability):
  """
  Connection probability class that follows a Gaussian function of distance.
-
+
  This class calculates connection probabilities using a Gaussian function
  of the distance between cells, with options for spherical or cylindrical metrics.
-
+
  Parameters:
  -----------
  mean : float, optional
@@ -188,7 +190,7 @@ class GaussianDropoff(DistantDependentProbability):
  Distance range (min_dist, max_dist) for calculating pmax when ptotal is provided.
  dist_type : str, optional
  Distance metric to use, either 'spherical' (default) or 'cylindrical'.
-
+
  Notes:
  ------
  When ptotal is specified, the maximum probability (pmax) is calculated to achieve
@@ -196,16 +198,24 @@ class GaussianDropoff(DistantDependentProbability):
  assuming homogeneous cell density.
  """

- def __init__(self, mean=0., stdev=1., min_dist=0., max_dist=np.inf,
- pmax=1, ptotal=None, ptotal_dist_range=None,
- dist_type='spherical'):
+ def __init__(
+ self,
+ mean=0.0,
+ stdev=1.0,
+ min_dist=0.0,
+ max_dist=np.inf,
+ pmax=1,
+ ptotal=None,
+ ptotal_dist_range=None,
+ dist_type="spherical",
+ ):
  super().__init__(min_dist=min_dist, max_dist=max_dist)
  self.mean, self.stdev = mean, stdev
  self.ptotal = ptotal
- self.ptotal_dist_range = (min_dist, max_dist) \
- if ptotal_dist_range is None else ptotal_dist_range
- self.dist_type = dist_type if dist_type in \
- ['cylindrical'] else 'spherical'
+ self.ptotal_dist_range = (
+ (min_dist, max_dist) if ptotal_dist_range is None else ptotal_dist_range
+ )
+ self.dist_type = dist_type if dist_type in ["cylindrical"] else "spherical"
  self.pmax = pmax if ptotal is None else self.calc_pmax_from_ptotal()
  self.set_probability_func()

@@ -233,18 +243,21 @@ class GaussianDropoff(DistantDependentProbability):
  mu, sig = self.mean, self.stdev
  r1, r2 = self.ptotal_dist_range[:2]
  x1, x2 = (r1 - mu) / sig, (r2 - mu) / sig # normalized distance
- if self.dist_type == 'cylindrical':
- dr = r2 ** 2 - r1 ** 2
+ if self.dist_type == "cylindrical":
+ dr = r2**2 - r1**2
+
  def F(x):
- f1 = sig * mu / NORM_COEF * erf(x / 2**.5)
- f2 = -2 * sig * sig * gaussian(x, pmax=1.)
+ f1 = sig * mu / NORM_COEF * erf(x / 2**0.5)
+ f2 = -2 * sig * sig * gaussian(x, pmax=1.0)
  return f1 + f2
  else:
- dr = r2 ** 3 - r1 ** 3
+ dr = r2**3 - r1**3
+
  def F(x):
- f1 = 1.5 * sig * (sig**2 + mu**2) / NORM_COEF * erf(x / 2**.5)
- f2 = -3 * sig * sig * (2 * mu + sig * x) * gaussian(x, pmax=1.)
+ f1 = 1.5 * sig * (sig**2 + mu**2) / NORM_COEF * erf(x / 2**0.5)
+ f2 = -3 * sig * sig * (2 * mu + sig * x) * gaussian(x, pmax=1.0)
  return f1 + f2
+
  return self.ptotal * dr / (F(x2) - F(x1))

  def probability(self):
@@ -252,24 +265,30 @@ class GaussianDropoff(DistantDependentProbability):

  def set_probability_func(self):
  """Set up function for calculating probability"""
- keys = ['mean', 'stdev', 'pmax']
+ keys = ["mean", "stdev", "pmax"]
  kwargs = {key: getattr(self, key) for key in keys}
  probability = partial(gaussian, **kwargs)

  # Verify maximum probability
  # (is not self.pmax if self.mean outside distance range)
  bounds = (self.min_dist, min(self.max_dist, 1e9))
- pmax = self.pmax if self.mean >= bounds[0] and self.mean <= bounds[1] \
+ pmax = (
+ self.pmax
+ if self.mean >= bounds[0] and self.mean <= bounds[1]
  else probability(np.asarray(bounds)).max()
+ )
  if pmax > 1:
- d = minimize_scalar(lambda x: (probability(x) - 1)**2,
- method='bounded', bounds=bounds).x
- warn = ("\nWarning: Maximum probability=%.3f is greater than 1. "
- "Probability crosses 1 at distance %.3g.\n") % (pmax, d)
+ d = minimize_scalar(
+ lambda x: (probability(x) - 1) ** 2, method="bounded", bounds=bounds
+ ).x
+ warn = (
+ "\nWarning: Maximum probability=%.3f is greater than 1. "
+ "Probability crosses 1 at distance %.3g.\n"
+ ) % (pmax, d)
  if self.ptotal is not None:
  warn += " ptotal may not be reached."
- print(warn,flush=True)
- self.probability = lambda dist: np.fmin(probability(dist), 1.)
+ print(warn, flush=True)
+ self.probability = lambda dist: np.fmin(probability(dist), 1.0)
  else:
  self.probability = probability

@@ -280,7 +299,7 @@ class NormalizedReciprocalRate(ProbabilityFunction):
  connection probability and the connection probability for a randomly
  connected network where the two unidirectional connections between any pair
  of neurons are independent. NRR = pr / (p0 * p1)
-
+
  Parameters:
  NRR: a constant or distance dependent function for normalized reciprocal
  rate. When being a function, it should be accept vectorized input.
@@ -288,7 +307,7 @@ class NormalizedReciprocalRate(ProbabilityFunction):
  A callable object that returns the probability value.
  """

- def __init__(self, NRR=1.):
+ def __init__(self, NRR=1.0):
  self.NRR = NRR if callable(NRR) else lambda *x: NRR

  def probability(self, dist, p0, p1):
@@ -311,17 +330,18 @@ class NormalizedReciprocalRate(ProbabilityFunction):
  dist, p0, p1 = map(np.asarray, (dist, p0, p1))
  pr = np.empty(dist.shape)
  pr[:] = self.probability(dist, p0, p1)
- pr = np.clip(pr, a_min=np.fmax(p0 + p1 - 1., 0.), a_max=np.fmin(p0, p1))
+ pr = np.clip(pr, a_min=np.fmax(p0 + p1 - 1.0, 0.0), a_max=np.fmin(p0, p1))
  if cond is not None:
  mask = np.asarray(cond[1])
  pr[mask] /= p1 if cond[0] else p0
- pr[~mask] = 0.
+ pr[~mask] = 0.0
  return decisions(pr)


  # Connector Classes
  class AbstractConnector(ABC):
  """Abstract base class for connectors"""
+
  @abstractmethod
  def setup_nodes(self, source=None, target=None):
  """After network nodes are added to the BMTK network. Pass in the
@@ -338,8 +358,10 @@ class AbstractConnector(ABC):
  @staticmethod
  def constant_function(val):
  """Convert a constant to a constant function"""
+
  def constant(*arg):
  return val
+
  return constant


@@ -348,29 +370,31 @@ def is_same_pop(source, target, quick=False):
  """Check whether two NodePool objects direct to the same population"""
  if quick:
  # Quick check (compare filter conditions)
- same = (source.network_name == target.network_name and
- source._NodePool__properties ==
- target._NodePool__properties)
+ same = (
+ source.network_name == target.network_name
+ and source._NodePool__properties == target._NodePool__properties
+ )
  else:
  # Strict check (compare all nodes)
- same = (source.network_name == target.network_name and
- len(source) == len(target) and
- all([s.node_id == t.node_id
- for s, t in zip(source, target)]))
+ same = (
+ source.network_name == target.network_name
+ and len(source) == len(target)
+ and all([s.node_id == t.node_id for s, t in zip(source, target)])
+ )
  return same


  class Timer(object):
- def __init__(self, unit='sec'):
- if unit == 'ms':
+ def __init__(self, unit="sec"):
+ if unit == "ms":
  self.scale = 1e3
- elif unit == 'us':
+ elif unit == "us":
  self.scale = 1e6
- elif unit == 'min':
+ elif unit == "min":
  self.scale = 1 / 60
  else:
  self.scale = 1
- unit = 'sec'
+ unit = "sec"
  self.unit = unit
  self.start()

@@ -380,28 +404,30 @@ class Timer(object):
  def end(self):
  return (time.perf_counter() - self._start) * self.scale

- def report(self, msg='Run time'):
- print((msg + ": %.3f " + self.unit) % self.end(),flush=True)
+ def report(self, msg="Run time"):
+ print((msg + ": %.3f " + self.unit) % self.end(), flush=True)


  def pr_2_rho(p0, p1, pr):
  """Calculate correlation coefficient rho given reciprocal probability pr"""
  for p in (p0, p1):
- assert(p > 0 and p < 1)
- assert(pr >= 0 and pr <= p0 and pr <= p1 and pr >= p0 + p1 - 1)
- return (pr - p0 * p1) / (p0 * (1 - p0) * p1 * (1 - p1)) ** .5
+ assert p > 0 and p < 1
+ assert pr >= 0 and pr <= p0 and pr <= p1 and pr >= p0 + p1 - 1
+ return (pr - p0 * p1) / (p0 * (1 - p0) * p1 * (1 - p1)) ** 0.5


  def rho_2_pr(p0, p1, rho):
  """Calculate reciprocal probability pr given correlation coefficient rho"""
  for p in (p0, p1):
- assert(p > 0 and p < 1)
- pr = p0 * p1 + rho * (p0 * (1 - p0) * p1 * (1 - p1)) ** .5
+ assert p > 0 and p < 1
+ pr = p0 * p1 + rho * (p0 * (1 - p0) * p1 * (1 - p1)) ** 0.5
  if not (pr >= 0 and pr <= p0 and pr <= p1 and pr >= p0 + p1 - 1):
- pr0, pr = pr, np.max((0., p0 + p1 - 1, np.min((p0, p1, pr))))
- rho0, rho = rho, (pr - p0 * p1) / (p0 * (1 - p0) * p1 * (1 - p1)) ** .5
- print('rho changed from %.3f to %.3f; pr changed from %.3f to %.3f'
- % (rho0, rho, pr0, pr),flush=True)
+ pr0, pr = pr, np.max((0.0, p0 + p1 - 1, np.min((p0, p1, pr))))
+ rho0, rho = rho, (pr - p0 * p1) / (p0 * (1 - p0) * p1 * (1 - p1)) ** 0.5
+ print(
+ "rho changed from %.3f to %.3f; pr changed from %.3f to %.3f" % (rho0, rho, pr0, pr),
+ flush=True,
+ )
  return pr


@@ -575,15 +601,31 @@ class ReciprocalConnector(AbstractConnector):
  properties, so that they can access the information here.
  """

- def __init__(self, p0=1., p1=1., symmetric_p1=False,
- p0_arg=None, p1_arg=None, symmetric_p1_arg=False,
- pr=0., pr_arg=None, estimate_rho=True, rho=None,
- dist_range_forward=None, dist_range_backward=None,
- n_syn0=1, n_syn1=1, autapses=False,
- quick_pop_check=False, cache_data=True, verbose=True,save_report=True,report_name=None):
+ def __init__(
+ self,
+ p0=1.0,
+ p1=1.0,
+ symmetric_p1=False,
+ p0_arg=None,
+ p1_arg=None,
+ symmetric_p1_arg=False,
+ pr=0.0,
+ pr_arg=None,
+ estimate_rho=True,
+ rho=None,
+ dist_range_forward=None,
+ dist_range_backward=None,
+ n_syn0=1,
+ n_syn1=1,
+ autapses=False,
+ quick_pop_check=False,
+ cache_data=True,
+ verbose=True,
+ save_report=True,
+ report_name=None,
+ ):
  args = locals()
- var_set = ('p0', 'p0_arg', 'p1', 'p1_arg',
- 'pr', 'pr_arg', 'n_syn0', 'n_syn1')
+ var_set = ("p0", "p0_arg", "p1", "p1_arg", "pr", "pr_arg", "n_syn0", "n_syn1")
  self.vars = {key: args[key] for key in var_set}

  self.symmetric_p1 = symmetric_p1 and symmetric_p1_arg
@@ -601,7 +643,7 @@ class ReciprocalConnector(AbstractConnector):
  self.save_report = save_report

  if report_name is None:
- report_name = globals().get('report_name', 'default_report.csv')
+ report_name = globals().get("report_name", "default_report.csv")
  self.report_name = report_name

  self.conn_prop = [{}, {}]
@@ -613,9 +655,12 @@ class ReciprocalConnector(AbstractConnector):
  """Must run this before building connections"""
  if self.stage:
  # check whether the correct populations
- if (source is None or target is None or
- not is_same_pop(source, self.target, quick=self.quick) or
- not is_same_pop(target, self.source, quick=self.quick)):
+ if (
+ source is None
+ or target is None
+ or not is_same_pop(source, self.target, quick=self.quick)
+ or not is_same_pop(target, self.source, quick=self.quick)
+ ):
  raise ValueError("Source or target population not consistent.")
  # Skip adding nodes for the backward stage.
  return
@@ -648,22 +693,28 @@ class ReciprocalConnector(AbstractConnector):
  if self.recurrent:
  self.symmetric_p1_arg = True
  self.symmetric_p1 = True
- self.vars['n_syn1'] = self.vars['n_syn0']
+ self.vars["n_syn1"] = self.vars["n_syn0"]
  if self.symmetric_p1_arg:
- self.vars['p1_arg'] = self.vars['p0_arg']
+ self.vars["p1_arg"] = self.vars["p0_arg"]
  if self.symmetric_p1:
- self.vars['p1'] = self.vars['p0']
+ self.vars["p1"] = self.vars["p0"]

  def edge_params(self):
  """Create the arguments for BMTK add_edges() method"""
  if self.stage == 0:
- params = {'source': self.source, 'target': self.target,
- 'iterator': 'one_to_all',
- 'connection_rule': self.make_forward_connection}
+ params = {
+ "source": self.source,
+ "target": self.target,
+ "iterator": "one_to_all",
+ "connection_rule": self.make_forward_connection,
+ }
  else:
- params = {'source': self.target, 'target': self.source,
- 'iterator': 'all_to_one',
- 'connection_rule': self.make_backward_connection}
+ params = {
+ "source": self.target,
+ "target": self.source,
+ "iterator": "all_to_one",
+ "connection_rule": self.make_backward_connection,
+ }
  self.stage += 1
  return params

@@ -687,6 +738,7 @@ class ReciprocalConnector(AbstractConnector):
  val = func(*args)
  output.append(val)
  return val
+
  setattr(self, func_name, writer)
  else:
  setattr(self, func_name, func)
@@ -694,15 +746,17 @@ class ReciprocalConnector(AbstractConnector):
  def write_mode(self):
  for val in self._output.values():
  val.clear()
- self.mode = 'write'
+ self.mode = "write"
  self.iter_count = 0

  def fetch_output(self, func_name, fetch=True):
  output = self._output[func_name]

  if fetch:
+
  def reader(*args):
  return output[self.iter_count]
+
  setattr(self, func_name, reader)
  else:
  setattr(self, func_name, self.cache_dict[func_name])
@@ -718,36 +772,44 @@ class ReciprocalConnector(AbstractConnector):
  for func_name, out_len in zip(self._output, output_len):
  fetch = out_len > 0
  if not fetch:
- print("\nWarning: Cache did not work properly for "
- + func_name + '\n',flush=True)
+ print(
+ "\nWarning: Cache did not work properly for " + func_name + "\n",
+ flush=True,
+ )
  self.fetch_output(func_name, fetch)
  self.iter_count = 0
  else:
  # if output not correct, disable and use original function
- print("\nWarning: Cache did not work properly.\n",flush=True)
+ print("\nWarning: Cache did not work properly.\n", flush=True)
  for func_name in self.cache_dict:
  self.fetch_output(func_name, False)
  self.enable = False
- self.mode = 'read'
+ self.mode = "read"

  def set_next_it(self):
  if self.enable:
+
  def next_it():
  self.iter_count += 1
  else:
+
  def next_it():
  pass
+
  self.next_it = next_it

  def node_2_idx_input(self, var_func, reverse=False):
  """Convert a function that accept nodes as input
  to accept indices as input"""
  if reverse:
+
  def idx_2_var(j, i):
  return var_func(self.target_list[j], self.source_list[i])
  else:
+
  def idx_2_var(i, j):
  return var_func(self.source_list[i], self.target_list[j])
+
  return idx_2_var

  def iterate_pairs(self):
@@ -784,6 +846,7 @@ class ReciprocalConnector(AbstractConnector):
  if self.rho is None:
  # Determine by pr for each pair
  if self.verbose:
+
  def cond_backward(cond, p0, p1, pr):
  if p0 > 0:
  pr_bound = (p0 + p1 - 1, min(p0, p1))
@@ -795,6 +858,7 @@ class ReciprocalConnector(AbstractConnector):
  else:
  return p1
  else:
+
  def cond_backward(cond, p0, p1, pr):
  if p0 > 0:
  pr_bound = (p0 + p1 - 1, min(p0, p1))
@@ -810,10 +874,11 @@ class ReciprocalConnector(AbstractConnector):
  # Dependent with fixed correlation coefficient rho
  def cond_backward(cond, p0, p1, pr):
  # Standard deviation of r.v. for p1
- sd = ((1 - p1) * p1) ** .5
+ sd = ((1 - p1) * p1) ** 0.5
  # Z-score of random variable for p0
- zs = ((1 - p0) / p0) ** .5 if cond else - (p0 / (1 - p0)) ** .5
+ zs = ((1 - p0) / p0) ** 0.5 if cond else -((p0 / (1 - p0)) ** 0.5)
  return p1 + self.rho * sd * zs
+
  self.cond_backward = cond_backward

  def add_conn_prop(self, src, trg, prop, stage=0):
@@ -833,9 +898,9 @@ class ReciprocalConnector(AbstractConnector):
  # *** A sequence of major methods executed during build ***
  def setup_variables(self):
  # If pr_arg is string, use the same value as p0_arg or p1_arg
- if isinstance(self.vars['pr_arg'], str):
- pr_arg_func = 'p1_arg' if '1' in self.vars['pr_arg'] else 'p0_arg'
- self.vars['pr_arg'] = self.vars[pr_arg_func]
+ if isinstance(self.vars["pr_arg"], str):
+ pr_arg_func = "p1_arg" if "1" in self.vars["pr_arg"] else "p0_arg"
+ self.vars["pr_arg"] = self.vars[pr_arg_func]
  else:
  pr_arg_func = None

@@ -850,60 +915,70 @@ class ReciprocalConnector(AbstractConnector):
  self.callable_set = callable_set

  # Make callable variables except a few, accept index input instead
- for name in callable_set - {'p0', 'p1', 'pr'}:
+ for name in callable_set - {"p0", "p1", "pr"}:
  var = self.vars[name]
- setattr(self, name, self.node_2_idx_input(var, '1' in name))
+ setattr(self, name, self.node_2_idx_input(var, "1" in name))

  # Set up function for pr_arg if use value from p0_arg or p1_arg
  if pr_arg_func is None:
  self._pr_arg = self.pr_arg # use specified pr_arg
  else:
- self._pr_arg_val = 0. # storing current value from p_arg
+ self._pr_arg_val = 0.0 # storing current value from p_arg
  p_arg = getattr(self, pr_arg_func)
+
  def p_arg_4_pr(*args, **kwargs):
  val = p_arg(*args, **kwargs)
  self._pr_arg_val = val
  return val
+
  setattr(self, pr_arg_func, p_arg_4_pr)
+
  def pr_arg(self, *arg):
  return self._pr_arg_val
+
  self._pr_arg = types.MethodType(pr_arg, self)

  def cache_variables(self):
  # Select cacheable attrilbutes
- cache_set = {'p0', 'p0_arg', 'p1', 'p1_arg'}
+ cache_set = {"p0", "p0_arg", "p1", "p1_arg"}
  if self.symmetric_p1:
- cache_set.remove('p1')
+ cache_set.remove("p1")
  if self.symmetric_p1_arg:
- cache_set.remove('p1_arg')
+ cache_set.remove("p1_arg")
  # Output of callable variables will be cached
  # Constant functions will be called from cache but output not cached
  for name in cache_set:
  var = getattr(self, name)
  self.cache.cache_output(var, name, name in self.callable_set)
  if self.verbose and len(self.cache.cache_dict):
- print('Output of %s will be cached.'
- % ', '.join(self.cache.cache_dict),flush=True)
+ print("Output of %s will be cached." % ", ".join(self.cache.cache_dict), flush=True)

  def setup_dist_range_checker(self):
  # Checker that determines whether to consider a pair for rho estimation
  if self.dist_range_forward is None and self.dist_range_backward is None:
+
  def checker(var):
  p0, p1 = var[2:]
  return p0 > 0 and p1 > 0
  else:
+
  def in_range(p_arg, dist_range):
  return p_arg >= dist_range[0] and p_arg <= dist_range[1]
+
  r0, r1 = self.dist_range_forward, self.dist_range_backward
  if r1 is None:
+
  def checker(var):
  return in_range(var[0], r0)
  elif r0 is None:
+
  def checker(var):
  return in_range(var[1], r1)
  else:
+
  def checker(var):
  return in_range(var[0], r0) and in_range(var[1], r1)
+
  return checker

  def initialize(self):
@@ -918,8 +993,9 @@ class ReciprocalConnector(AbstractConnector):
  """The major part of the algorithm run at beginning of BMTK iterator"""
  if self.verbose:
  src_str, trg_str = self.get_nodes_info()
- print("\nStart building connection between: \n "
- + src_str + "\n " + trg_str,flush=True)
+ print(
+ "\nStart building connection between: \n " + src_str + "\n " + trg_str, flush=True
+ )
  self.initialize()
  cache = self.cache # write mode

@@ -928,8 +1004,8 @@ class ReciprocalConnector(AbstractConnector):
  self.timer = Timer()
  if self.estimate_rho:
  dist_range_checker = self.setup_dist_range_checker()
- p0p1_sum = 0.
- norm_fac_sum = 0.
+ p0p1_sum = 0.0
+ norm_fac_sum = 0.0
  n = 0
  # Make sure each cacheable function runs excatly once per iteration
  for i, j in self.iterate_pairs():
@@ -939,22 +1015,25 @@ class ReciprocalConnector(AbstractConnector):
  n += 1
  p0, p1 = var[2:]
  p0p1_sum += p0 * p1
- norm_fac_sum += (p0 * (1 - p0) * p1 * (1 - p1)) ** .5
+ norm_fac_sum += (p0 * (1 - p0) * p1 * (1 - p1)) ** 0.5
  if norm_fac_sum > 0:
  rho = (self.pr() * n - p0p1_sum) / norm_fac_sum
  if abs(rho) > 1:
- print("\nWarning: Estimated value of rho=%.3f "
- "outside the range [-1, 1]." % rho,flush=True)
+ print(
+ "\nWarning: Estimated value of rho=%.3f "
+ "outside the range [-1, 1]." % rho,
+ flush=True,
+ )
  rho = np.clip(rho, -1, 1).item()
- print("Force rho to be %.0f.\n" % rho,flush=True)
+ print("Force rho to be %.0f.\n" % rho, flush=True)
  elif self.verbose:
- print("Estimated value of rho=%.3f" % rho,flush=True)
+ print("Estimated value of rho=%.3f" % rho, flush=True)
  self.rho = rho
  else:
  self.rho = 0

  if self.verbose:
- self.timer.report('Time for estimating rho')
+ self.timer.report("Time for estimating rho")

  # Setup function for calculating conditional backward probability
  self.setup_conditional_backward_probability()
@@ -998,15 +1077,15 @@ class ReciprocalConnector(AbstractConnector):
  self.possible_count = possible_count

  if self.verbose:
- self.timer.report('Total time for creating connection matrix')
+ self.timer.report("Total time for creating connection matrix")
  if self.wrong_pr:
- print("Warning: Value of 'pr' outside the bounds occurred.\n",flush=True)
+ print("Warning: Value of 'pr' outside the bounds occurred.\n", flush=True)
  self.connection_number_info()
  if self.save_report:
  self.save_connection_report()

  def make_connection(self):
- """ Assign number of synapses per iteration.
+ """Assign number of synapses per iteration.
  Use iterator one_to_all for forward and all_to_one for backward.
  """
  nsyns = self.conn_mat[self.stage, self.iter_count, :]
@@ -1017,7 +1096,7 @@ class ReciprocalConnector(AbstractConnector):
  self.iter_count = 0
  if self.stage == self.end_stage:
  if self.verbose:
- self.timer.report('Done! \nTime for building connections')
+ self.timer.report("Done! \nTime for building connections")
  self.free_memory()
  return nsyns

@@ -1028,7 +1107,7 @@ class ReciprocalConnector(AbstractConnector):
  self.stage = 0
  self.initial_all_to_all()
  if self.verbose:
- print("Assigning forward connections.",flush=True)
+ print("Assigning forward connections.", flush=True)
  self.timer.start()
  return self.make_connection()

@@ -1037,22 +1116,21 @@ class ReciprocalConnector(AbstractConnector):
  if self.iter_count == 0:
  self.stage = 1
  if self.verbose:
- print("Assigning backward connections.",flush=True)
+ print("Assigning backward connections.", flush=True)
  return self.make_connection()

  def free_memory(self):
  """Free up memory after connections are built"""
  # Do not clear self.conn_prop if it will be used by conn.add_properties
- variables = ('conn_mat', 'source_list', 'target_list',
- 'source_ids', 'target_ids')
+ variables = ("conn_mat", "source_list", "target_list", "source_ids", "target_ids")
  for var in variables:
  setattr(self, var, None)

  # *** Helper functions for verbose ***
  def get_nodes_info(self):
  """Get strings with source and target population information"""
- source_str = self.source.network_name + ': ' + self.source.filter_str
- target_str = self.target.network_name + ': ' + self.target.filter_str
+ source_str = self.source.network_name + ": " + self.source.filter_str
+ target_str = self.target.network_name + ": " + self.target.filter_str
  return source_str, target_str

  def connection_number(self):
@@ -1080,24 +1158,31 @@ class ReciprocalConnector(AbstractConnector):
  n_conn = np.append(n_conn, n_recp)
  n_pair = int(n_pair)
  fraction = np.array([n_conn / n_poss, n_conn / n_pair])
- fraction[np.isnan(fraction)] = 0.
+ fraction[np.isnan(fraction)] = 0.0
  return n_conn, n_poss, n_pair, fraction

  def connection_number_info(self):
  """Print connection numbers after connections built"""
+
  def arr2str(a, f):
- return ', '.join([f] * a.size) % tuple(a.tolist())
+ return ", ".join([f] * a.size) % tuple(a.tolist())
+
  n_conn, n_poss, n_pair, fraction = self.connection_number()
- conn_type = "(all, reciprocal)" if self.recurrent \
- else "(forward, backward, reciprocal)"
- print("Numbers of " + conn_type + " connections:",flush=True)
- print("Number of connected pairs: (%s)" % arr2str(n_conn, '%d'),flush=True)
- print("Number of possible connections: (%s)" % arr2str(n_poss, '%d'),flush=True)
- print("Fraction of connected pairs in possible ones: (%s)"
- % arr2str(100 * fraction[0], '%.2f%%'),flush=True)
- print("Number of total pairs: %d" % n_pair,flush=True)
- print("Fraction of connected pairs in all pairs: (%s)\n"
- % arr2str(100 * fraction[1], '%.2f%%'),flush=True)
+ conn_type = "(all, reciprocal)" if self.recurrent else "(forward, backward, reciprocal)"
+ print("Numbers of " + conn_type + " connections:", flush=True)
+ print("Number of connected pairs: (%s)" % arr2str(n_conn, "%d"), flush=True)
+ print("Number of possible connections: (%s)" % arr2str(n_poss, "%d"), flush=True)
+ print(
+ "Fraction of connected pairs in possible ones: (%s)"
+ % arr2str(100 * fraction[0], "%.2f%%"),
+ flush=True,
+ )
+ print("Number of total pairs: %d" % n_pair, flush=True)
+ print(
+ "Fraction of connected pairs in all pairs: (%s)\n"
+ % arr2str(100 * fraction[1], "%.2f%%"),
+ flush=True,
+ )

  def save_connection_report(self):
  """Save connections into a CSV file to be read from later"""
@@ -1108,21 +1193,21 @@ class ReciprocalConnector(AbstractConnector):
  data = {
  "Source": [src_str],
  "Target": [trg_str],
- "Percent connectionivity within possible connections": [fraction[0]*100],
- "Percent connectionivity within all connections": [fraction[1]*100]
+ "Percent connectionivity within possible connections": [fraction[0] * 100],
+ "Percent connectionivity within all connections": [fraction[1] * 100],
  }
  df = pd.DataFrame(data)
-
+
  # Append the data to the CSV file
  try:
  # Check if the file exists by trying to read it
  existing_df = pd.read_csv(self.report_name)
  # If no exception is raised, append without header
- df.to_csv(self.report_name, mode='a', header=False, index=False)
+ df.to_csv(self.report_name, mode="a", header=False, index=False)
  except FileNotFoundError:
  # If the file does not exist, write with header
- df.to_csv(self.report_name, mode='w', header=True, index=False)
-
+ df.to_csv(self.report_name, mode="w", header=True, index=False)
+

  class UnidirectionConnector(AbstractConnector):
  """
@@ -1155,15 +1240,17 @@ class UnidirectionConnector(AbstractConnector):
  This is useful in similar manner as in ReciprocalConnector.
  """

- def __init__(self, p=1., p_arg=None, n_syn=1, verbose=True,save_report=True,report_name=None):
+ def __init__(
+ self, p=1.0, p_arg=None, n_syn=1, verbose=True, save_report=True, report_name=None
+ ):
  args = locals()
- var_set = ('p', 'p_arg', 'n_syn')
+ var_set = ("p", "p_arg", "n_syn")
  self.vars = {key: args[key] for key in var_set}

  self.verbose = verbose
  self.save_report = save_report
  if report_name is None:
- report_name = globals().get('report_name', 'default_report.csv')
+ report_name = globals().get("report_name", "default_report.csv")
  self.report_name = report_name

  self.conn_prop = {}
@@ -1185,9 +1272,12 @@ class UnidirectionConnector(AbstractConnector):
1272
 
1186
1273
  def edge_params(self):
1187
1274
  """Create the arguments for BMTK add_edges() method"""
1188
- params = {'source': self.source, 'target': self.target,
1189
- 'iterator': 'one_to_one',
1190
- 'connection_rule': self.make_connection}
1275
+ params = {
1276
+ "source": self.source,
1277
+ "target": self.target,
1278
+ "iterator": "one_to_one",
1279
+ "connection_rule": self.make_connection,
1280
+ }
1191
1281
  return params
1192
1282
 
1193
1283
  # *** Methods executed during bmtk network.build() ***
@@ -1223,8 +1313,10 @@ class UnidirectionConnector(AbstractConnector):
1223
1313
  self.initialize()
1224
1314
  if self.verbose:
1225
1315
  src_str, trg_str = self.get_nodes_info()
1226
- print("\nStart building connection \n from "
1227
- + src_str + "\n to " + trg_str,flush=True)
1316
+ print(
1317
+ "\nStart building connection \n from " + src_str + "\n to " + trg_str,
1318
+ flush=True,
1319
+ )
1228
1320
 
1229
1321
  # Make random connections
1230
1322
 
@@ -1245,7 +1337,7 @@ class UnidirectionConnector(AbstractConnector):
1245
1337
  if self.iter_count == self.n_pair:
1246
1338
  if self.verbose:
1247
1339
  self.connection_number_info()
1248
- self.timer.report('Done! \nTime for building connections')
1340
+ self.timer.report("Done! \nTime for building connections")
1249
1341
  if self.save_report:
1250
1342
  self.save_connection_report()
1251
1343
 
@@ -1254,45 +1346,52 @@ class UnidirectionConnector(AbstractConnector):
1254
1346
  # *** Helper functions for verbose ***
1255
1347
  def get_nodes_info(self):
1256
1348
  """Get strings with source and target population information"""
1257
- source_str = self.source.network_name + ': ' + self.source.filter_str
1258
- target_str = self.target.network_name + ': ' + self.target.filter_str
1349
+ source_str = self.source.network_name + ": " + self.source.filter_str
1350
+ target_str = self.target.network_name + ": " + self.target.filter_str
1259
1351
  return source_str, target_str
1260
1352
 
1261
1353
  def connection_number_info(self):
1262
1354
  """Print connection numbers after connections built"""
1263
- print("Number of connected pairs: %d" % self.n_conn,flush=True)
1264
- print("Number of possible connections: %d" % self.n_poss,flush=True)
1265
- print("Fraction of connected pairs in possible ones: %.2f%%"
1266
- % (100. * self.n_conn / self.n_poss) if self.n_poss else 0.)
1267
- print("Number of total pairs: %d" % self.n_pair,flush=True)
1268
- print("Fraction of connected pairs in all pairs: %.2f%%\n"
1269
- % (100. * self.n_conn / self.n_pair),flush=True)
1270
-
1355
+ print("Number of connected pairs: %d" % self.n_conn, flush=True)
1356
+ print("Number of possible connections: %d" % self.n_poss, flush=True)
1357
+ print(
1358
+ "Fraction of connected pairs in possible ones: %.2f%%"
1359
+ % (100.0 * self.n_conn / self.n_poss)
1360
+ if self.n_poss
1361
+ else 0.0
1362
+ )
1363
+ print("Number of total pairs: %d" % self.n_pair, flush=True)
1364
+ print(
1365
+ "Fraction of connected pairs in all pairs: %.2f%%\n"
1366
+ % (100.0 * self.n_conn / self.n_pair),
1367
+ flush=True,
1368
+ )
1369
+
1271
1370
  def save_connection_report(self):
1272
1371
  """Save connections into a CSV file to be read from later"""
1273
1372
  src_str, trg_str = self.get_nodes_info()
1274
-
1275
- possible_fraction = (100. * self.n_conn / self.n_poss)
1276
- all_fraction = (100. * self.n_conn / self.n_pair)
1373
+
1374
+ possible_fraction = 100.0 * self.n_conn / self.n_poss
1375
+ all_fraction = 100.0 * self.n_conn / self.n_pair
1277
1376
 
1278
1377
  # Extract the population name from source_str and target_str
1279
1378
  data = {
1280
1379
  "Source": [src_str],
1281
1380
  "Target": [trg_str],
1282
1381
  "Percent connectionivity within possible connections": [possible_fraction],
1283
- "Percent connectionivity within all connections": [all_fraction]
1382
+ "Percent connectionivity within all connections": [all_fraction],
1284
1383
  }
1285
1384
  df = pd.DataFrame(data)
1286
-
1385
+
1287
1386
  # Append the data to the CSV file
1288
1387
  try:
1289
1388
  # Check if the file exists by trying to read it
1290
1389
  existing_df = pd.read_csv(self.report_name)
1291
1390
  # If no exception is raised, append without header
1292
- df.to_csv(self.report_name, mode='a', header=False, index=False)
1391
+ df.to_csv(self.report_name, mode="a", header=False, index=False)
1293
1392
  except FileNotFoundError:
1294
1393
  # If the file does not exist, write with header
1295
- df.to_csv(self.report_name, mode='w', header=True, index=False)
1394
+ df.to_csv(self.report_name, mode="w", header=True, index=False)
1296
1395
 
1297
1396
 
1298
1397
  class GapJunction(UnidirectionConnector):
@@ -1316,16 +1415,19 @@ class GapJunction(UnidirectionConnector):
1316
1415
  Similar to `UnidirectionConnector`.
1317
1416
  """
1318
1417
 
1319
- def __init__(self, p=1., p_arg=None, verbose=True,save_report=True,report_name=None):
1320
- super().__init__(p=p, p_arg=p_arg, verbose=verbose,save_report=save_report,report_name=None)
1321
-
1418
+ def __init__(self, p=1.0, p_arg=None, verbose=True, save_report=True, report_name=None):
1419
+ super().__init__(
1420
+ p=p, p_arg=p_arg, verbose=verbose, save_report=save_report, report_name=None
1421
+ )
1322
1422
 
1323
1423
  def setup_nodes(self, source=None, target=None):
1324
1424
  super().setup_nodes(source=source, target=target)
1325
1425
  if len(self.source) != len(self.target):
1326
1426
  src_str, trg_str = self.get_nodes_info()
1327
- raise ValueError(f"Source and target must be the same for "
1328
- f"gap junction. Nodes are {src_str} and {trg_str}")
1427
+ raise ValueError(
1428
+ f"Source and target must be the same for "
1429
+ f"gap junction. Nodes are {src_str} and {trg_str}"
1430
+ )
1329
1431
  self.n_source = len(self.source)
1330
1432
 
1331
1433
  def make_connection(self, source, target, *args, **kwargs):
@@ -1335,7 +1437,7 @@ class GapJunction(UnidirectionConnector):
1335
1437
  self.initialize()
1336
1438
  if self.verbose:
1337
1439
  src_str, _ = self.get_nodes_info()
1338
- print("\nStart building gap junction \n in " + src_str,flush=True)
1440
+ print("\nStart building gap junction \n in " + src_str, flush=True)
1339
1441
 
1340
1442
  # Consider each pair only once
1341
1443
  nsyns = 0
@@ -1358,7 +1460,7 @@ class GapJunction(UnidirectionConnector):
1358
1460
  if self.iter_count == self.n_pair:
1359
1461
  if self.verbose:
1360
1462
  self.connection_number_info()
1361
- self.timer.report('Done! \nTime for building connections')
1463
+ self.timer.report("Done! \nTime for building connections")
1362
1464
  if self.save_report:
1363
1465
  self.save_connection_report()
1364
1466
  return nsyns
@@ -1373,27 +1475,27 @@ class GapJunction(UnidirectionConnector):
1373
1475
  """Save connections into a CSV file to be read from later"""
1374
1476
  src_str, trg_str = self.get_nodes_info()
1375
1477
  n_pair = self.n_pair
1376
- fraction_0 = self.n_conn / self.n_poss if self.n_poss else 0.
1478
+ fraction_0 = self.n_conn / self.n_poss if self.n_poss else 0.0
1377
1479
  fraction_1 = self.n_conn / self.n_pair
1378
1480
 
1379
1481
  # Convert fraction to percentage and prepare data for the DataFrame
1380
1482
  data = {
1381
- "Source": [src_str+"Gap"],
1382
- "Target": [trg_str+"Gap"],
1383
- "Percent connectionivity within possible connections": [fraction_0*100],
1384
- "Percent connectionivity within all connections": [fraction_1*100]
1483
+ "Source": [src_str + "Gap"],
1484
+ "Target": [trg_str + "Gap"],
1485
+ "Percent connectionivity within possible connections": [fraction_0 * 100],
1486
+ "Percent connectionivity within all connections": [fraction_1 * 100],
1385
1487
  }
1386
1488
  df = pd.DataFrame(data)
1387
-
1489
+
1388
1490
  # Append the data to the CSV file
1389
1491
  try:
1390
1492
  # Check if the file exists by trying to read it
1391
1493
  existing_df = pd.read_csv(self.report_name)
1392
1494
  # If no exception is raised, append without header
1393
- df.to_csv(self.report_name, mode='a', header=False, index=False)
1495
+ df.to_csv(self.report_name, mode="a", header=False, index=False)
1394
1496
  except FileNotFoundError:
1395
1497
  # If the file does not exist, write with header
1396
- df.to_csv(self.report_name, mode='w', header=True, index=False)
1498
+ df.to_csv(self.report_name, mode="w", header=True, index=False)
1397
1499
 
1398
1500
 
1399
1501
  class CorrelatedGapJunction(GapJunction):
@@ -1423,12 +1525,23 @@ class CorrelatedGapJunction(GapJunction):
1423
1525
  Similar to `UnidirectionConnector`.
1424
1526
  """
1425
1527
 
1426
- def __init__(self, p_non=1., p_uni=1., p_rec=1., p_arg=None,
1427
- connector=None, verbose=True,save_report=True,report_name=None):
1428
- super().__init__(p=p_non, p_arg=p_arg, verbose=verbose,save_report=save_report,report_name=None)
1429
- self.vars['p_non'] = self.vars.pop('p')
1430
- self.vars['p_uni'] = p_uni
1431
- self.vars['p_rec'] = p_rec
1528
+ def __init__(
1529
+ self,
1530
+ p_non=1.0,
1531
+ p_uni=1.0,
1532
+ p_rec=1.0,
1533
+ p_arg=None,
1534
+ connector=None,
1535
+ verbose=True,
1536
+ save_report=True,
1537
+ report_name=None,
1538
+ ):
1539
+ super().__init__(
1540
+ p=p_non, p_arg=p_arg, verbose=verbose, save_report=save_report, report_name=None
1541
+ )
1542
+ self.vars["p_non"] = self.vars.pop("p")
1543
+ self.vars["p_uni"] = p_uni
1544
+ self.vars["p_rec"] = p_rec
1432
1545
  self.connector = connector
1433
1546
  conn_prop = connector.conn_prop
1434
1547
  if isinstance(conn_prop, list):
@@ -1448,10 +1561,10 @@ class CorrelatedGapJunction(GapJunction):
1448
1561
  return conn0 + conn1, prop0 if conn0 else prop1
1449
1562
 
1450
1563
  def initialize(self):
1451
- self.has_p_arg = self.vars['p_arg'] is not None
1564
+ self.has_p_arg = self.vars["p_arg"] is not None
1452
1565
  if not self.has_p_arg:
1453
1566
  var = self.connector.vars
1454
- self.vars['p_arg'] = var.get('p_arg', var.get('p0_arg', None))
1567
+ self.vars["p_arg"] = var.get("p_arg", var.get("p0_arg", None))
1455
1568
  super().initialize()
1456
1569
  self.ps = [self.p_non, self.p_uni, self.p_rec]
1457
1570
 
@@ -1462,7 +1575,7 @@ class CorrelatedGapJunction(GapJunction):
1462
1575
  self.initialize()
1463
1576
  if self.verbose:
1464
1577
  src_str, _ = self.get_nodes_info()
1465
- print("\nStart building gap junction \n in " + src_str,flush=True)
1578
+ print("\nStart building gap junction \n in " + src_str, flush=True)
1466
1579
 
1467
1580
  # Consider each pair only once
1468
1581
  nsyns = 0
@@ -1487,7 +1600,7 @@ class CorrelatedGapJunction(GapJunction):
1487
1600
  if self.iter_count == self.n_pair:
1488
1601
  if self.verbose:
1489
1602
  self.connection_number_info()
1490
- self.timer.report('Done! \nTime for building connections')
1603
+ self.timer.report("Done! \nTime for building connections")
1491
1604
  if self.save_report:
1492
1605
  self.save_connection_report()
1493
1606
  return nsyns
@@ -1551,13 +1664,16 @@ class OneToOneSequentialConnector(AbstractConnector):
1551
1664
  if self.target_count == 0:
1552
1665
  if source is None or len(source) == 0:
1553
1666
  src_str, trg_str = self.get_nodes_info()
1554
- raise ValueError((f"{trg_str}" if self.partition_source else
1555
- f"{src_str}") + " nodes do not exists")
1667
+ raise ValueError(
1668
+ (f"{trg_str}" if self.partition_source else f"{src_str}")
1669
+ + " nodes do not exists"
1670
+ )
1556
1671
  self.source = source
1557
1672
  self.n_source = len(source)
1558
1673
  if target is None or len(target) == 0:
1559
- raise ValueError(("Source" if self.partition_source else
1560
- "Target") + " nodes do not exists")
1674
+ raise ValueError(
1675
+ ("Source" if self.partition_source else "Target") + " nodes do not exists"
1676
+ )
1561
1677
 
1562
1678
  self.targets.append(target)
1563
1679
  self.idx_range.append(self.idx_range[-1] + len(target))
@@ -1567,25 +1683,33 @@ class OneToOneSequentialConnector(AbstractConnector):
1567
1683
  if self.partition_source:
1568
1684
  raise ValueError(
1569
1685
  "Total target populations exceed the source population."
1570
- if self.partition_source else
1571
- "Total source populations exceed the target population."
1572
- )
1686
+ if self.partition_source
1687
+ else "Total source populations exceed the target population."
1688
+ )
1573
1689
 
1574
1690
  if self.verbose and self.idx_range[-1] == self.n_source:
1575
- print("All " + ("source" if self.partition_source else "target")
1576
- + " population partitions are filled.",flush=True)
1691
+ print(
1692
+ "All "
1693
+ + ("source" if self.partition_source else "target")
1694
+ + " population partitions are filled.",
1695
+ flush=True,
1696
+ )
1577
1697
 
1578
1698
  def edge_params(self, target_pop_idx=-1):
1579
1699
  """Create the arguments for BMTK add_edges() method"""
1580
1700
  if self.partition_source:
1581
- params = {'source': self.targets[target_pop_idx],
1582
- 'target': self.source,
1583
- 'iterator': 'one_to_all'}
1701
+ params = {
1702
+ "source": self.targets[target_pop_idx],
1703
+ "target": self.source,
1704
+ "iterator": "one_to_all",
1705
+ }
1584
1706
  else:
1585
- params = {'source': self.source,
1586
- 'target': self.targets[target_pop_idx],
1587
- 'iterator': 'all_to_one'}
1588
- params['connection_rule'] = self.make_connection
1707
+ params = {
1708
+ "source": self.source,
1709
+ "target": self.targets[target_pop_idx],
1710
+ "iterator": "all_to_one",
1711
+ }
1712
+ params["connection_rule"] = self.make_connection
1589
1713
  return params
1590
1714
 
1591
1715
  # *** Methods executed during bmtk network.build() ***
@@ -1597,15 +1721,23 @@ class OneToOneSequentialConnector(AbstractConnector):
1597
1721
  # Very beginning
1598
1722
  self.target_count = 0
1599
1723
  src_str, trg_str = self.get_nodes_info()
1600
- print("\nStart building connection " +
1601
- ("to " if self.partition_source else "from ") + src_str,flush=True)
1724
+ print(
1725
+ "\nStart building connection "
1726
+ + ("to " if self.partition_source else "from ")
1727
+ + src_str,
1728
+ flush=True,
1729
+ )
1602
1730
  self.timer = Timer()
1603
1731
 
1604
1732
  if self.iter_count == self.idx_range[self.target_count]:
1605
1733
  # Beginning of each target population
1606
1734
  src_str, trg_str = self.get_nodes_info(self.target_count)
1607
- print((" %d. " % self.target_count) +
1608
- ("from " if self.partition_source else "to ") + trg_str,flush=True)
1735
+ print(
1736
+ (" %d. " % self.target_count)
1737
+ + ("from " if self.partition_source else "to ")
1738
+ + trg_str,
1739
+ flush=True,
1740
+ )
1609
1741
  self.target_count += 1
1610
1742
  self.timer_part = Timer()
1611
1743
 
@@ -1618,18 +1750,18 @@ class OneToOneSequentialConnector(AbstractConnector):
1618
1750
  if self.verbose:
1619
1751
  if self.iter_count == self.idx_range[self.target_count]:
1620
1752
  # End of each target population
1621
- self.timer_part.report(' Time for this partition')
1753
+ self.timer_part.report(" Time for this partition")
1622
1754
  if self.iter_count == self.n_source:
1623
1755
  # Very end
1624
- self.timer.report('Done! \nTime for building connections')
1756
+ self.timer.report("Done! \nTime for building connections")
1625
1757
  return nsyns
1626
1758
 
1627
1759
  # *** Helper functions for verbose ***
1628
1760
  def get_nodes_info(self, target_pop_idx=-1):
1629
1761
  """Get strings with source and target population information"""
1630
1762
  target = self.targets[target_pop_idx]
1631
- source_str = self.source.network_name + ': ' + self.source.filter_str
1632
- target_str = target.network_name + ': ' + target.filter_str
1763
+ source_str = self.source.network_name + ": " + self.source.filter_str
1764
+ target_str = target.network_name + ": " + target.filter_str
1633
1765
  return source_str, target_str
1634
1766
 
1635
1767
 
@@ -1637,27 +1769,38 @@ class OneToOneSequentialConnector(AbstractConnector):
1637
1769
  ######################### ADDTIONAL EDGE PROPERTIES ##########################
1638
1770
 
1639
1771
  SYN_MIN_DELAY = 0.8 # ms
1640
- SYN_VELOCITY = 1000. # um/ms
1772
+ SYN_VELOCITY = 1000.0 # um/ms
1641
1773
  FLUC_STDEV = 0.2 # ms
1642
1774
  DELAY_LOWBOUND = 0.2 # ms must be greater than h.dt
1643
1775
  DELAY_UPBOUND = 2.0 # ms
1644
1776
 
1645
- def syn_const_delay(source=None, target = None, dist=100,
1646
- min_delay=SYN_MIN_DELAY, velocity=SYN_VELOCITY,
1647
- fluc_stdev=FLUC_STDEV, delay_bound=(DELAY_LOWBOUND, DELAY_UPBOUND),
1648
- connector=None):
1649
- """Synapse delay constant with some random fluctuation.
1650
- """
1777
+
1778
+ def syn_const_delay(
1779
+ source=None,
1780
+ target=None,
1781
+ dist=100,
1782
+ min_delay=SYN_MIN_DELAY,
1783
+ velocity=SYN_VELOCITY,
1784
+ fluc_stdev=FLUC_STDEV,
1785
+ delay_bound=(DELAY_LOWBOUND, DELAY_UPBOUND),
1786
+ connector=None,
1787
+ ):
1788
+ """Synapse delay constant with some random fluctuation."""
1651
1789
  del_fluc = fluc_stdev * rng.normal()
1652
1790
  delay = dist / SYN_VELOCITY + SYN_MIN_DELAY + del_fluc
1653
1791
  delay = min(max(delay, DELAY_LOWBOUND), DELAY_UPBOUND)
1654
1792
  return delay
1655
1793
 
1656
1794
 
1657
- def syn_dist_delay_feng(source, target, min_delay=SYN_MIN_DELAY,
1658
- velocity=SYN_VELOCITY, fluc_stdev=FLUC_STDEV,
1659
- delay_bound=(DELAY_LOWBOUND, DELAY_UPBOUND),
1660
- connector=None):
1795
+ def syn_dist_delay_feng(
1796
+ source,
1797
+ target,
1798
+ min_delay=SYN_MIN_DELAY,
1799
+ velocity=SYN_VELOCITY,
1800
+ fluc_stdev=FLUC_STDEV,
1801
+ delay_bound=(DELAY_LOWBOUND, DELAY_UPBOUND),
1802
+ connector=None,
1803
+ ):
1661
1804
  """Synpase delay linearly dependent on distance.
1662
1805
  min_delay: minimum delay (ms)
1663
1806
  velocity: synapse conduction velocity (micron/ms)
@@ -1666,7 +1809,7 @@ def syn_dist_delay_feng(source, target, min_delay=SYN_MIN_DELAY,
1666
1809
  connector: connector object from which to read distance
1667
1810
  """
1668
1811
  if connector is None:
1669
- dist = euclid_dist(target['positions'], source['positions'])
1812
+ dist = euclid_dist(target["positions"], source["positions"])
1670
1813
  else:
1671
1814
  dist = connector.get_conn_prop(source.node_id, target.node_id)
1672
1815
  del_fluc = fluc_stdev * rng.normal()
@@ -1675,30 +1818,30 @@ def syn_dist_delay_feng(source, target, min_delay=SYN_MIN_DELAY,
1675
1818
  return delay
1676
1819
 
1677
1820
 
1678
- def syn_section_PN(source, target, p=0.9,
1679
- sec_id=(1, 2), sec_x=(0.4, 0.6), **kwargs):
1821
+ def syn_section_PN(source, target, p=0.9, sec_id=(1, 2), sec_x=(0.4, 0.6), **kwargs):
1680
1822
  """Synapse location follows a Bernoulli distribution, with probability p
1681
1823
  to obtain the former in sec_id and sec_x"""
1682
1824
  syn_loc = int(not decision(p))
1683
1825
  return sec_id[syn_loc], sec_x[syn_loc]
1684
1826
 
1685
1827
 
1686
- def syn_const_delay_feng_section_PN(source, target, p=0.9,
1687
- sec_id=(1, 2), sec_x=(0.4, 0.6), **kwargs):
1828
+ def syn_const_delay_feng_section_PN(
1829
+ source, target, p=0.9, sec_id=(1, 2), sec_x=(0.4, 0.6), **kwargs
1830
+ ):
1688
1831
  """Assign both synapse delay and location with constant distance assumed"""
1689
- delay = syn_const_delay(source, target,**kwargs)
1832
+ delay = syn_const_delay(source, target, **kwargs)
1690
1833
  s_id, s_x = syn_section_PN(source, target, p=p, sec_id=sec_id, sec_x=sec_x)
1691
1834
  return delay, s_id, s_x
1692
1835
 
1693
1836
 
1694
- def syn_dist_delay_feng_section_PN(source, target, p=0.9,
1695
- sec_id=(1, 2), sec_x=(0.4, 0.6), **kwargs):
1837
+ def syn_dist_delay_feng_section_PN(
1838
+ source, target, p=0.9, sec_id=(1, 2), sec_x=(0.4, 0.6), **kwargs
1839
+ ):
1696
1840
  """Assign both synapse delay and location"""
1697
1841
  delay = syn_dist_delay_feng(source, target, **kwargs)
1698
1842
  s_id, s_x = syn_section_PN(source, target, p=p, sec_id=sec_id, sec_x=sec_x)
1699
1843
  return delay, s_id, s_x
1700
1844
 
1701
1845
 
1702
- def syn_uniform_delay_section(source, target, low=DELAY_LOWBOUND,
1703
- high=DELAY_UPBOUND, **kwargs):
1846
+ def syn_uniform_delay_section(source, target, low=DELAY_LOWBOUND, high=DELAY_UPBOUND, **kwargs):
1704
1847
  return rng.uniform(low, high)
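
The connector API shown in this diff is unchanged by the 0.7.1 release; the changes above are quoting, line-wrapping, import-ordering, and assert-style cleanups. As a point of reference, a minimal usage sketch of these classes follows. It is an illustration under assumptions, not code from the package: the BMTK NetworkBuilder calls are standard bmtk usage, while the population names, sizes, and parameter values are made up.

# Hedged usage sketch: hypothetical populations and parameter values, shown only
# to illustrate how the connectors in bmtool/connectors.py plug into BMTK.
import numpy as np
from bmtk.builder import NetworkBuilder
from bmtool.connectors import GaussianDropoff, UnidirectionConnector, spherical_dist

net = NetworkBuilder("example")
# Illustrative node populations with random 3D positions (assumed layout).
net.add_nodes(N=100, pop_name="PN", positions=np.random.rand(100, 3) * 300.0)
net.add_nodes(N=25, pop_name="ITN", positions=np.random.rand(25, 3) * 300.0)

# Distance-dependent connection probability: Gaussian dropoff of spherical distance.
prob = GaussianDropoff(stdev=100.0, pmax=0.3, max_dist=300.0, dist_type="spherical")

# One-directional connector; p_arg supplies the distance that is passed to prob.
conn = UnidirectionConnector(p=prob, p_arg=spherical_dist, n_syn=1)
conn.setup_nodes(source=net.nodes(pop_name="PN"), target=net.nodes(pop_name="ITN"))

# edge_params() returns the source/target/iterator/connection_rule arguments;
# extra keyword arguments become edge properties (values here are illustrative).
net.add_edges(**conn.edge_params(), syn_weight=1.0, delay=0.8)
net.build()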