dendrotweaks 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. dendrotweaks/__init__.py +10 -0
  2. dendrotweaks/analysis/__init__.py +11 -0
  3. dendrotweaks/analysis/ephys_analysis.py +482 -0
  4. dendrotweaks/analysis/morphometric_analysis.py +106 -0
  5. dendrotweaks/membrane/__init__.py +6 -0
  6. dendrotweaks/membrane/default_mod/AMPA.mod +65 -0
  7. dendrotweaks/membrane/default_mod/AMPA_NMDA.mod +100 -0
  8. dendrotweaks/membrane/default_mod/CaDyn.mod +54 -0
  9. dendrotweaks/membrane/default_mod/GABAa.mod +65 -0
  10. dendrotweaks/membrane/default_mod/Leak.mod +27 -0
  11. dendrotweaks/membrane/default_mod/NMDA.mod +72 -0
  12. dendrotweaks/membrane/default_mod/vecstim.mod +76 -0
  13. dendrotweaks/membrane/default_templates/NEURON_template.py +354 -0
  14. dendrotweaks/membrane/default_templates/default.py +73 -0
  15. dendrotweaks/membrane/default_templates/standard_channel.mod +87 -0
  16. dendrotweaks/membrane/default_templates/template_jaxley.py +108 -0
  17. dendrotweaks/membrane/default_templates/template_jaxley_new.py +108 -0
  18. dendrotweaks/membrane/distributions.py +324 -0
  19. dendrotweaks/membrane/groups.py +103 -0
  20. dendrotweaks/membrane/io/__init__.py +11 -0
  21. dendrotweaks/membrane/io/ast.py +201 -0
  22. dendrotweaks/membrane/io/code_generators.py +312 -0
  23. dendrotweaks/membrane/io/converter.py +108 -0
  24. dendrotweaks/membrane/io/factories.py +144 -0
  25. dendrotweaks/membrane/io/grammar.py +417 -0
  26. dendrotweaks/membrane/io/loader.py +90 -0
  27. dendrotweaks/membrane/io/parser.py +499 -0
  28. dendrotweaks/membrane/io/reader.py +212 -0
  29. dendrotweaks/membrane/mechanisms.py +574 -0
  30. dendrotweaks/model.py +1916 -0
  31. dendrotweaks/model_io.py +75 -0
  32. dendrotweaks/morphology/__init__.py +5 -0
  33. dendrotweaks/morphology/domains.py +100 -0
  34. dendrotweaks/morphology/io/__init__.py +5 -0
  35. dendrotweaks/morphology/io/factories.py +212 -0
  36. dendrotweaks/morphology/io/reader.py +66 -0
  37. dendrotweaks/morphology/io/validation.py +212 -0
  38. dendrotweaks/morphology/point_trees.py +681 -0
  39. dendrotweaks/morphology/reduce/__init__.py +16 -0
  40. dendrotweaks/morphology/reduce/reduce.py +155 -0
  41. dendrotweaks/morphology/reduce/reduced_cylinder.py +129 -0
  42. dendrotweaks/morphology/sec_trees.py +1112 -0
  43. dendrotweaks/morphology/seg_trees.py +157 -0
  44. dendrotweaks/morphology/trees.py +567 -0
  45. dendrotweaks/path_manager.py +261 -0
  46. dendrotweaks/simulators.py +235 -0
  47. dendrotweaks/stimuli/__init__.py +3 -0
  48. dendrotweaks/stimuli/iclamps.py +73 -0
  49. dendrotweaks/stimuli/populations.py +265 -0
  50. dendrotweaks/stimuli/synapses.py +203 -0
  51. dendrotweaks/utils.py +239 -0
  52. dendrotweaks-0.3.1.dist-info/METADATA +70 -0
  53. dendrotweaks-0.3.1.dist-info/RECORD +56 -0
  54. dendrotweaks-0.3.1.dist-info/WHEEL +5 -0
  55. dendrotweaks-0.3.1.dist-info/licenses/LICENSE +674 -0
  56. dendrotweaks-0.3.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1112 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ from matplotlib.gridspec import GridSpec
5
+ import matplotlib.colors as mcolors
6
+
7
+ from typing import Callable, List
8
+ from neuron import h
9
+
10
+ from dendrotweaks.morphology.trees import Node, Tree
11
+ from dendrotweaks.morphology.domains import Domain
12
+ from dataclasses import dataclass, field
13
+ from bisect import bisect_left
14
+
15
+ import warnings
16
+
17
+ from dendrotweaks.utils import get_domain_color
18
+
19
def custom_warning_formatter(message, category, filename, lineno, file=None, line=None):
    """
    Format a warning as ``Category: message (file.py, line N)``.

    Has the same signature as :func:`warnings.formatwarning` so it can be
    installed as its replacement. Only the base name of ``filename`` is
    shown; ``file`` and ``line`` are accepted for compatibility and ignored.

    Returns
    -------
    str
        The formatted warning, terminated by a newline.
    """
    # Imported locally so this function does not depend on a module-level
    # `os` import (none exists among this module's imports).
    import os
    return f"{category.__name__}: {message} ({os.path.basename(filename)}, line {lineno})\n"


# Install the compact formatter for all warnings emitted in the process.
warnings.formatwarning = custom_warning_formatter
23
+
24
+
25
+
26
class Section(Node):
    """
    A class representing a section in a neuron morphology.

    A section is a continuous part of a neuron's morphology between two
    branching points.

    Parameters
    ----------
    idx : str
        The index of the section.
    parent_idx : str
        The index of the parent section.
    points : List[Point]
        The points that define the section. All points must belong to the
        same domain.

    Attributes
    ----------
    points : List[Point]
        The points that define the section.
    segments : List[Segment]
        The segments into which the section is divided.
    _ref : h.Section
        The reference to the NEURON section (None until it is created).

    Raises
    ------
    ValueError
        If ``points`` is empty or the points belong to different domains.
    """

    def __init__(self, idx: str, parent_idx: str, points: List[Node]) -> None:
        super().__init__(idx, parent_idx)
        if not points:
            raise ValueError('A section must contain at least one point.')
        self.domain_idx = None
        self.points = points
        self.segments = []
        self._ref = None
        # The section takes the domain of its first point ...
        self._domain = self.points[0].domain

        # ... and every other point must agree with it.
        if not all(pt.domain == self._domain for pt in points):
            raise ValueError('All points in a section must belong to the same domain.')


    # MAGIC METHODS

    def __call__(self, x: float):
        """
        Return the segment at a given normalized position.

        Parameters
        ----------
        x : float
            Normalized position along the section, in [0, 1].

        Returns
        -------
        Segment
            The segment whose NEURON segment matches ``sec(x)``. For
            ``x == 0`` and ``x == 1`` the first and last segments are returned.

        Raises
        ------
        ValueError
            If the section is not referenced in NEURON, ``x`` is out of
            range, or no segment matches.
        """
        if self._ref is None:
            raise ValueError('Section is not referenced in NEURON.')
        if x < 0 or x > 1:
            raise ValueError('Location x must be in the range [0, 1].')
        if x == 0:
            # TODO: Decide how to handle sec(0) and sec(1)
            # as they are not shown in the seg_graph
            return self.segments[0]
        if x == 1:
            return self.segments[-1]
        nrn_seg = self._ref(x)
        for seg in self.segments:
            if nrn_seg == seg._ref:
                return seg
        raise ValueError(f'No segment found at location {x}')


    def __iter__(self):
        """
        Iterate over the segments in the section.
        """
        yield from self.segments


    # PROPERTIES

    @property
    def domain(self):
        """
        The morphological or functional domain of the node.
        """
        return self._domain


    @domain.setter
    def domain(self, domain):
        # Keep the points' domains consistent with the section's domain.
        self._domain = domain
        for pt in self.points:
            pt.domain = domain


    @property
    def df_points(self):
        """
        A DataFrame of the points in the section, built by concatenating
        the DataFrames of the individual points.
        """
        return pd.concat([pt.df for pt in self.points])


    @property
    def df(self):
        """
        A one-row DataFrame with the section's index and parent index.
        """
        return pd.DataFrame({'idx': [self.idx],
                             'parent_idx': [self.parent_idx]})


    # NOTE: In NEURON, sec.diam returns the diameter of the segment at the
    # center of the section. A length-weighted average diameter computed from
    # the points' radii would generally differ from this value; the NEURON
    # value is used here.
    @property
    def diam(self):
        """
        Diameter of the central segment of the section (from NEURON).
        Requires the NEURON section to exist.
        """
        return self._ref.diam


    @property
    def L(self):
        """
        Length of the section (from NEURON).
        Requires the NEURON section to exist.
        """
        return self._ref.L


    @property
    def cm(self):
        """
        Specific membrane capacitance of the section (from NEURON).
        Requires the NEURON section to exist.
        """
        return self._ref.cm


    @property
    def Ra(self):
        """
        Axial resistance of the section (from NEURON).
        Requires the NEURON section to exist.
        """
        return self._ref.Ra


    @property
    def nseg(self):
        """
        Number of segments in the section (from NEURON).

        Setting this value updates NEURON's ``nseg``, creates new
        DendroTweaks segments for the new NEURON segments, splices them into
        the segment tree in place of the old ones and re-sorts that tree.
        The value must be an odd number >= 1.
        """
        return self._ref.nseg

    @nseg.setter
    def nseg(self, value):
        if value < 1:
            raise ValueError('Number of segments must be at least 1.')
        if value % 2 == 0:
            raise ValueError('Number of segments must be odd.')
        # Set the number in NEURON
        self._ref.nseg = value
        # Get the new NEURON segments
        nrnsegs = [seg for seg in self._ref]

        # Create new DendroTweaks segments. The local import presumably
        # avoids a circular import with seg_trees — verify before moving it.
        from dendrotweaks.morphology.seg_trees import Segment
        old_segments = self.segments
        # idx/parent_idx are placeholders; presumably reassigned when the
        # segment tree is sorted below — TODO confirm.
        new_segments = [Segment(idx=0, parent_idx=0, neuron_seg=seg, section=self)
                        for seg in nrnsegs]

        seg_tree = self._tree._seg_tree
        first_segment = old_segments[0]

        # Insert the new segments as a chain in front of the old first segment.
        for i, seg in enumerate(new_segments):
            if i == 0:
                seg_tree.insert_node_before(seg, first_segment)
            else:
                seg_tree.insert_node_before(seg, new_segments[i - 1])

        # Drop the old segments from the tree.
        for seg in old_segments:
            seg_tree.remove_node(seg)

        # Sort the tree
        seg_tree.sort()
        # Update the section's segments
        self.segments = new_segments


    @property
    def radii(self):
        """
        Radii of the points in the section.
        """
        return [pt.r for pt in self.points]


    @property
    def diameters(self):
        """
        Diameters of the points in the section.
        """
        return [2 * pt.r for pt in self.points]


    @property
    def xs(self):
        """
        X-coordinates of the points in the section.
        """
        return [pt.x for pt in self.points]


    @property
    def ys(self):
        """
        Y-coordinates of the points in the section.
        """
        return [pt.y for pt in self.points]


    @property
    def zs(self):
        """
        Z-coordinates of the points in the section.
        """
        return [pt.z for pt in self.points]


    @property
    def seg_centers(self):
        """
        Centers of the section's segments as distances from the 0 end,
        in the units of the NEURON section length ``L`` (not normalized).

        Raises
        ------
        ValueError
            If the section is not referenced in NEURON.
        """
        if self._ref is None:
            raise ValueError('Section is not referenced in NEURON.')
        nseg = self._ref.nseg
        length = self._ref.L
        return [(2 * i - 1) / (2 * nseg) * length for i in range(1, nseg + 1)]


    @property
    def seg_borders(self):
        """
        Normalized positions in [0, 1] of the segment borders, including
        both ends of the section.

        Raises
        ------
        ValueError
            If the section is not referenced in NEURON.
        """
        if self._ref is None:
            raise ValueError('Section is not referenced in NEURON.')
        nseg = int(self._ref.nseg)
        return [i / nseg for i in range(nseg + 1)]


    @property
    def distances(self):
        """
        Cumulative Euclidean distances of the points along the section,
        starting with 0 at the first point.
        """
        coords = np.array([[pt.x, pt.y, pt.z] for pt in self.points])
        deltas = np.diff(coords, axis=0)
        frusta_distances = np.sqrt(np.sum(deltas**2, axis=1))
        return np.insert(np.cumsum(frusta_distances), 0, 0)


    @property
    def center(self):
        """
        The mean of the points' coordinates as an (x, y, z) tuple.
        """
        return np.mean(self.xs), np.mean(self.ys), np.mean(self.zs)


    @property
    def length(self):
        """
        The length of the section calculated as the sum of the distances between the points.
        """
        return self.distances[-1]


    @property
    def area(self):
        """
        The lateral surface area of the section, calculated as the sum of
        the lateral areas of the frusta between consecutive points.
        """
        radii = self.radii
        frusta_heights = np.diff(self.distances)
        # `dh` (not `h`) to avoid shadowing NEURON's `h` in this module.
        areas = [np.pi * (r1 + r2) * np.sqrt((r1 - r2)**2 + dh**2)
                 for r1, r2, dh in zip(radii[:-1], radii[1:], frusta_heights)]
        return sum(areas)

    def has_mechanism(self, mech_name):
        """
        Check if the section has a mechanism inserted.

        Parameters
        ----------
        mech_name : str
            The name of the mechanism to check.

        Returns
        -------
        bool
            The result of NEURON's ``has_membrane`` for this section.
        """
        return self._ref.has_membrane(mech_name)


    # REFERENCING METHODS

    def create_and_reference(self, simulator_name='NEURON'):
        """
        Add a reference to the section in the simulator.

        Parameters
        ----------
        simulator_name : str
            The name of the simulator to create the section in
            ('NEURON' or 'JAXLEY'). Other names are currently ignored.
        """
        if simulator_name == 'NEURON':
            self.create_NEURON_section()
        elif simulator_name == 'JAXLEY':
            self.create_JAXLEY_section()


    def _connect_NEURON_to_parent(self):
        """
        Connect this NEURON section to its parent's NEURON section.

        Children of the root section (soma) attach to the parent's midpoint
        (0.5); all other sections attach to the parent's 1 end. Does nothing
        if the section has no parent.
        """
        if self.parent is None:
            return
        # TODO: Attaching basal to soma 0
        if self.parent.parent is None:  # parent is the root (soma)
            self._ref.connect(self.parent._ref(0.5))
        else:
            self._ref.connect(self.parent._ref(1))


    def create_NEURON_section(self):
        """
        Create a NEURON section, connect it to the parent's NEURON section
        and add the section's points as 3D points.
        """
        self._ref = h.Section() # name=f'Sec_{self.idx}'
        self._connect_NEURON_to_parent()
        # Add 3D points to the section
        self._ref.pt3dclear()
        for pt in self.points:
            # Rounded to 16 decimals — presumably to suppress float noise
            # from 2*r; verify before changing.
            diam = round(2 * pt.r, 16)
            self._ref.pt3dadd(pt.x, pt.y, pt.z, diam)

    def create_JAXLEY_section(self):
        """
        Create a JAXLEY section (not implemented yet).
        """
        raise NotImplementedError


    # MECHANISM METHODS

    def insert_mechanism(self, name: str):
        """
        Inserts a mechanism in the section if
        it is not already inserted.
        """
        if self._ref.has_membrane(name):
            return
        self._ref.insert(name)


    def uninsert_mechanism(self, name: str):
        """
        Uninserts a mechanism in the section if
        it was inserted.
        """
        if not self._ref.has_membrane(name):
            return
        self._ref.uninsert(name)


    # PARAMETER METHODS

    def get_param_value(self, param_name):
        """
        Get the average parameter of the section's segments.

        Parameters
        ----------
        param_name : str
            The name of the parameter to get.

        Returns
        -------
        float
            The mean value of the parameter over the section's segments,
            rounded to 16 decimal places.
        """
        seg_values = [seg.get_param_value(param_name) for seg in self.segments]
        return round(np.mean(seg_values), 16)


    def path_distance(self, relative_position: float = 0,
                      within_domain: bool = False) -> float:
        """
        Calculate the distance from the section to the root at a given relative position.

        Parameters
        ----------
        relative_position : float
            The position along the section's normalized length [0, 1].
        within_domain : bool
            Whether to stop at the domain boundary.

        Returns
        -------
        float
            The distance from the section to the root. The root section's
            own length is not included.

        Raises
        ------
        ValueError
            If ``relative_position`` is outside [0, 1].

        Notes
        -----
        Assumes that we always attach the 0 end of the child.
        """
        if not (0 <= relative_position <= 1):
            raise ValueError('Relative position must be between 0 and 1.')

        distance = 0
        factor = relative_position
        node = self

        while node.parent:
            # Only a fraction of this section counts; every ancestor counts fully.
            distance += factor * node.length

            if within_domain and node.parent.domain != node.domain:
                break

            node = node.parent
            factor = 1

        return distance


    def disconnect_from_parent(self):
        """
        Detach the section from its parent section in the section tree,
        in NEURON (if referenced), in the point tree and in the segment tree
        (if segments exist).
        """
        # In SectionTree
        super().disconnect_from_parent()
        # In NEURON
        if self._ref:
            h.disconnect(sec=self._ref) #from parent
        # In PointTree
        self.points[0].disconnect_from_parent()
        # In SegmentTree
        if self.segments:
            self.segments[0].disconnect_from_parent()


    def connect_to_parent(self, parent):
        """
        Attach the section to a parent section in the section tree, in
        NEURON (if referenced), in the point tree and in the segment tree
        (if segments exist).

        Parameters
        ----------
        parent : Section
            The parent section to attach to.
        """
        # In SectionTree
        super().connect_to_parent(parent)
        # In NEURON
        if self._ref:
            self._connect_NEURON_to_parent()

        # In PointTree
        if self.parent is not None:
            if self.parent.parent is None: # if parent is soma
                parent_sec = self.parent
                parent_pt = parent_sec.points[1] if len(parent_sec.points) > 1 else parent_sec.points[0]
                self.points[0].connect_to_parent(parent_pt) # attach to the middle of the parent
            else:
                self.points[0].connect_to_parent(parent.points[-1]) # attach to the end of the parent
        # In SegmentTree
        if self.segments:
            self.segments[0].connect_to_parent(parent.segments[-1])


    # PLOTTING METHODS

    def plot(self, ax=None, plot_parent=True, section_color=None, parent_color='gray',
             show_labels=True, aspect_equal=True):
        """
        Plot section morphology in 3D projections (XZ, YZ, XY) and radii distribution.

        Parameters
        ----------
        ax : list or array of matplotlib.axes.Axes, optional
            Four axes for plotting (XZ, YZ, XY, radii). If None, creates a new figure with axes.
            A nested (e.g. 2x2) collection is flattened.
        plot_parent : bool, optional
            Whether to include parent section in the visualization.
        section_color : str or None, optional
            Color for the current section. If None, assigns based on section domain.
        parent_color : str, optional
            Color for the parent section.
        show_labels : bool, optional
            Whether to show axis labels and titles.
        aspect_equal : bool, optional
            Whether to set aspect ratio to 'equal' for the projections.

        Returns
        -------
        ax : list of matplotlib.axes.Axes
            The axes containing the plots.
        """
        created_figure = ax is None
        # Create figure and axes if not provided
        if created_figure:
            fig = plt.figure(figsize=(10, 8))
            gs = GridSpec(2, 3, width_ratios=[1, 1, 1.2], figure=fig)

            # Create the three projection axes and one radius axis
            ax_xz = fig.add_subplot(gs[0, 0])
            ax_yz = fig.add_subplot(gs[0, 1])
            ax_xy = fig.add_subplot(gs[1, 0])
            ax_radii = fig.add_subplot(gs[1, 1:])
            ax = [ax_xz, ax_yz, ax_xy, ax_radii]
        else:
            # Use provided axes
            if len(ax) != 4:
                # flatten
                ax = [ai for a in ax for ai in a]
            ax_xz, ax_yz, ax_xy, ax_radii = ax

        # Determine section color based on domain if not provided
        if section_color is None:
            section_color = get_domain_color(self.domain)

        # Extract coordinates
        xs = np.array([p.x for p in self.points])
        ys = np.array([p.y for p in self.points])
        zs = np.array([p.z for p in self.points])

        # Plot section projections
        self._plot_projection(ax_xz, xs, zs, 'X', 'Z', 'XZ Projection',
                              section_color, show_labels, aspect_equal)
        self._plot_projection(ax_yz, ys, zs, 'Y', 'Z', 'YZ Projection',
                              section_color, show_labels, aspect_equal)
        self._plot_projection(ax_xy, xs, ys, 'X', 'Y', 'XY Projection',
                              section_color, show_labels, aspect_equal)

        # Plot radius distribution
        self._plot_radii_distribution(ax_radii, plot_parent, section_color, parent_color)

        # Plot parent section if requested
        if plot_parent and self.parent:
            # Only plot parent projections, radii are handled in _plot_radii_distribution
            parent_xs = np.array([p.x for p in self.parent.points])
            parent_ys = np.array([p.y for p in self.parent.points])
            parent_zs = np.array([p.z for p in self.parent.points])

            self.parent._plot_projection(ax_xz, parent_xs, parent_zs, None, None, None,
                                         parent_color, False, aspect_equal)
            self.parent._plot_projection(ax_yz, parent_ys, parent_zs, None, None, None,
                                         parent_color, False, aspect_equal)
            self.parent._plot_projection(ax_xy, parent_xs, parent_ys, None, None, None,
                                         parent_color, False, aspect_equal)

        # Add overall title only if we created the figure; leave caller-provided
        # figures (title and layout) untouched.
        if created_figure and show_labels:
            fig = ax_xz.get_figure()
            fig.suptitle(f"Section {self.idx} ({self.domain})", fontsize=14)
            fig.tight_layout()

        return ax

    def _plot_projection(self, ax, x_coords, y_coords, x_label, y_label, title,
                         color, show_labels, aspect_equal):
        """Helper method to plot a 2D projection of the section."""
        ax.plot(x_coords, y_coords, 'o-', color=color, markerfacecolor=color,
                markeredgecolor='black', markersize=4, linewidth=1.5)

        if show_labels:
            if x_label:
                ax.set_xlabel(x_label)
            if y_label:
                ax.set_ylabel(y_label)
            if title:
                ax.set_title(title)

        if aspect_equal:
            ax.set_aspect('equal')

    def _plot_radii_distribution(self, ax, plot_parent, section_color, parent_color):
        """
        Helper method to plot radius distribution along the section.

        The section spans [0, section_length]. The parent, if plotted, is
        shifted so that it ends at 0. NEURON segment radii, when available,
        are drawn as bars.
        """
        distances = self.distances
        # Get section length for normalization
        section_length = distances[-1] - distances[0]

        # Normalize distances to start at 0 and end at section_length
        normalized_distances = distances - distances[0]

        # Plot section radii
        ax.plot(normalized_distances, self.radii, 'o-', color=section_color,
                label=f"{self.domain} ({self.idx})", linewidth=2)

        # Plot reference NEURON segments if available
        if hasattr(self, '_ref') and self._ref:
            # Calculate normalized segment centers
            normalized_seg_centers = np.array(self.seg_centers) - distances[0]

            # Extract radii from segments
            seg_radii = np.array([seg.diam / 2 for seg in self._ref])

            # One bar per segment, each L / nseg wide
            bar_width = [self._ref.L / self._ref.nseg] * self._ref.nseg

            # Plot segment radii as bars
            ax.bar(normalized_seg_centers, seg_radii, width=bar_width,
                   alpha=0.5, color=section_color, edgecolor='white',
                   label=f"{self.domain} segments")

        # Plot parent section if requested
        if plot_parent and self.parent:
            parent_distances = self.parent.distances

            # Normalize parent distances to end at 0 (connecting to child)
            # Parent section goes from -parent_length to 0
            normalized_parent_distances = parent_distances - parent_distances[-1]

            # Plot parent radii
            ax.plot(normalized_parent_distances, self.parent.radii, 'o-',
                    color=parent_color, linewidth=2,
                    label=f"Parent {self.parent.domain} ({self.parent.idx})")

            # Plot parent reference segments if available
            if hasattr(self.parent, '_ref') and self.parent._ref:
                # Normalize parent segment centers to the same scale
                normalized_parent_seg_centers = (np.array(self.parent.seg_centers) -
                                                 parent_distances[-1])

                # Extract parent segment radii
                parent_seg_radii = np.array([seg.diam / 2 for seg in self.parent._ref])

                # One bar per parent segment, each L / nseg wide
                parent_bar_width = [self.parent._ref.L / self.parent._ref.nseg] * self.parent._ref.nseg

                # Plot parent segment radii as bars
                ax.bar(normalized_parent_seg_centers, parent_seg_radii,
                       width=parent_bar_width, alpha=0.5, color=parent_color,
                       edgecolor='white', label="Parent segments")

        # Set plot labels and legend
        ax.set_xlabel('Distance (µm)')
        ax.set_ylabel('Radius (µm)')
        ax.set_title('Radius Distribution')

        # Ensure y-axis starts at 0
        ax.set_ylim(bottom=0)

        # Adjust x-axis to show the full section(s)
        if plot_parent and self.parent:
            parent_distances = self.parent.distances
            parent_length = parent_distances[-1] - parent_distances[0]
            ax.set_xlim(-parent_length * 1.05, section_length * 1.05)
        else:
            ax.set_xlim(-section_length * 0.05, section_length * 1.05)

        # Add legend if we have multiple data series
        if ((hasattr(self, '_ref') and self._ref) or
                (plot_parent and self.parent)):
            ax.legend(loc='best', frameon=True, framealpha=0.8)

    def plot_radii(self, ax=None, include_parent=False, section_color=None, parent_color='gray'):
        """
        Plot just the radius distribution for the section.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            Axes to plot on. If None, creates a new figure and axes.
        include_parent : bool, optional
            Whether to include parent section in the plot.
        section_color : str or None, optional
            Color for current section. If None, assigns based on section domain
            (same palette as :meth:`plot`).
        parent_color : str, optional
            Color for parent section if included.

        Returns
        -------
        ax : matplotlib.axes.Axes
            The axes containing the plot.
        """
        # Create new figure and axes if not provided
        if ax is None:
            fig, ax = plt.subplots(figsize=(10, 5))

        # Determine section color based on domain if not provided
        if section_color is None:
            section_color = get_domain_color(self.domain)

        # Plot radius distribution
        self._plot_radii_distribution(ax, include_parent, section_color, parent_color)

        # Add a section-specific title only for a standalone plot,
        # i.e. when this is the only axes in its figure.
        fig = ax.get_figure()
        if len(fig.get_axes()) == 1:
            ax.set_title(f"Radius Distribution - Section {self.idx} ({self.domain})")
            fig.tight_layout()

        return ax
739
+
740
+
741
+ class SectionTree(Tree):
742
+ """
743
+ A class representing a tree graph of sections in a neuron morphology.
744
+
745
+ Parameters
746
+ ----------
747
+ sections : List[Section]
748
+ A list of sections in the tree.
749
+
750
+ Attributes
751
+ ----------
752
+ domains : Dict[str, Domain]
753
+ A dictionary of domains in the tree.
754
+ """
755
+
756
+ def __init__(self, sections: list[Section]) -> None:
757
+ super().__init__(sections)
758
+ self._create_domains()
759
+ self._point_tree = None
760
+ self._seg_tree = None
761
+
762
+
763
+ def __repr__(self):
764
+ return f"SectionTree(root={self.root!r}, num_nodes={len(self._nodes)})"
765
+
766
+
767
+ def _create_domains(self):
768
+ """
769
+ Create domains using the data from the sections (from the points in the sections).
770
+ """
771
+
772
+ unique_domain_names = set([sec.domain for sec in self.sections])
773
+ self.domains = {name: Domain(name) for name in unique_domain_names}
774
+
775
+ for sec in self.sections:
776
+ self.domains[sec.domain].add_section(sec)
777
+
778
+
779
+ # PROPERTIES
780
+
781
+ @property
782
+ def sections(self):
783
+ """
784
+ A list of sections in the tree. Alias for self._nodes.
785
+ """
786
+ return self._nodes
787
+
788
+
789
+ @property
790
+ def soma(self):
791
+ """
792
+ The soma section of the tree. Alias for self.root.
793
+ """
794
+ return self.root
795
+
796
+
797
+ @property
798
+ def sections_by_depth(self):
799
+ """
800
+ A dictionary of sections grouped by depth in the tree
801
+ (depth is the number of edges from the root).
802
+ """
803
+ sections_by_depth = {}
804
+ for sec in self.sections:
805
+ if sec.depth not in sections_by_depth:
806
+ sections_by_depth[sec.depth] = []
807
+ sections_by_depth[sec.depth].append(sec)
808
+ return sections_by_depth
809
+
810
+
811
+ @property
812
+ def df(self):
813
+ """
814
+ A DataFrame of the sections in the tree.
815
+ """
816
+ data = {
817
+ 'idx': [],
818
+ 'domain': [],
819
+ 'x': [],
820
+ 'y': [],
821
+ 'z': [],
822
+ 'r': [],
823
+ 'parent_idx': [],
824
+ 'section_idx': [],
825
+ 'parent_section_idx': [],
826
+ }
827
+
828
+ for sec in self.sections:
829
+ points = sec.points if sec.parent is None or sec.parent.parent is None else sec.points[1:]
830
+ for pt in points:
831
+ data['idx'].append(pt.idx)
832
+ data['domain'].append(pt.domain)
833
+ data['x'].append(pt.x)
834
+ data['y'].append(pt.y)
835
+ data['z'].append(pt.z)
836
+ data['r'].append(pt.r)
837
+ data['parent_idx'].append(pt.parent_idx)
838
+ data['section_idx'].append(sec.idx)
839
+ data['parent_section_idx'].append(sec.parent_idx)
840
+
841
+ return pd.DataFrame(data)
842
+
843
+
844
+ def sort(self, **kwargs):
845
+ """
846
+ Sort the sections in the tree using a depth-first traversal.
847
+
848
+ Parameters
849
+ ----------
850
+ sort_children : bool, optional
851
+ Whether to sort the children of each node
852
+ based on the number of bifurcations in their subtrees. Defaults to True.
853
+ force : bool, optional
854
+ Whether to force the sorting of the tree even if it is already sorted. Defaults to False.
855
+ """
856
+ super().sort(**kwargs)
857
+ self._point_tree.sort(**kwargs)
858
+ if self._seg_tree:
859
+ self._seg_tree.sort(**kwargs)
860
+
861
+
862
+ # STRUCTURE METHODS
863
+
864
+ def remove_subtree(self, section):
865
+ """
866
+ Remove a section and its subtree from the tree.
867
+
868
+ Parameters
869
+ ----------
870
+ section : Section
871
+ The section to remove.
872
+ """
873
+ super().remove_subtree(section)
874
+ # Domains
875
+ for domain in self.domains.values():
876
+ for sec in section.subtree:
877
+ if sec in domain.sections:
878
+ domain.remove_section(sec)
879
+ # Points
880
+ self._point_tree.remove_subtree(section.points[0])
881
+ # Segments
882
+ if self._seg_tree:
883
+ self._seg_tree.remove_subtree(section.segments[0])
884
+ # NEURON
885
+ if section._ref:
886
+ h.disconnect(sec=section._ref)
887
+ for sec in section.subtree:
888
+ h.delete_section(sec=sec._ref)
889
+
890
+
891
+ def remove_zero_length_sections(self):
892
+ """
893
+ Remove sections with zero length.
894
+ """
895
+ for sec in self.sections:
896
+ if sec.length == 0:
897
+ for pt in sec.points:
898
+ self._point_tree.remove_node(pt)
899
+ for seg in sec.segments:
900
+ self._seg_tree.remove_node(seg)
901
+ self.remove_node(sec)
902
+
903
+
904
def downsample(self, factor: float):
    """
    Thin out the points of every non-soma section, keeping both endpoints.

    :param factor: Fraction of each section's points to retain (0.5 keeps
        roughly half). The retained count is never below 2, so a factor of 0
        leaves only the first and last points of each section.

    Sections with fewer than three points are left unchanged. The point
    tree is re-sorted once all sections have been processed.
    """
    for section in self.sections:
        if section is self.soma or len(section.points) < 3:
            continue

        n_points = len(section.points)
        # int(n_points * 0) == 0, so the clamp alone covers the factor == 0 case.
        n_keep = max(2, int(n_points * factor))

        # Evenly spaced indices; linspace always includes index 0 and the last index.
        kept_indices = set(np.linspace(0, n_points - 1, n_keep, dtype=int))
        doomed = [
            point
            for position, point in enumerate(section.points)
            if position not in kept_indices
        ]

        print(f'Removing {len(doomed)} points from section {section.idx}')

        for point in doomed:
            self._point_tree.remove_node(point)
            section.points.remove(point)

    self._point_tree.sort()
938
+
939
+
940
+ # def plot_sections_as_matrix(self, ax=None):
941
+ # """
942
+ # Plot the sections as a connectivity matrix.
943
+ # """
944
+ # if ax is None:
945
+ # fig, ax = plt.subplots()
946
+
947
+ # n = len(self.sections)
948
+ # matrix = np.zeros((n, n))
949
+ # for section in self.sections:
950
+ # if section.parent:
951
+ # matrix[section.idx, section.parent.idx] = section.idx
952
+ # matrix[matrix == 0] = np.nan
953
+
954
+ # ax.imshow(matrix.T, cmap='jet_r')
955
+ # ax.set_xlabel('Section ID')
956
+ # ax.set_ylabel('Parent ID')
957
+
958
+
959
+ # PLOTTING METHODS
960
+
961
def plot(self, ax=None, show_points=False, show_lines=True,
         show_domains=True, annotate=False,
         projection='XY', highlight_sections=None, focus_sections=None):
    """
    Draw every section of the tree as a 2D projection of its points.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        Target axes; a new 10x10 figure is created when omitted.
    show_points : bool, optional
        Draw each point as a marker.
    show_lines : bool, optional
        Draw a polyline through each section's points.
    show_domains : bool, optional
        Colour sections by domain instead of by section index.
    annotate : bool, optional
        Label each section with its index at the mean of its projected points.
    projection : str or tuple, optional
        Two axis names, e.g. 'XY', 'XZ', 'YZ' or ('X', 'Z'); each is
        lower-cased and read as a point attribute.
    highlight_sections : list of Section, optional
        Sections drawn in red, overriding the other colouring.
    focus_sections : list of Section, optional
        If given and non-empty, only these sections are drawn.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 10))

    highlighted = set(highlight_sections) if highlight_sections else None
    focused = set(focus_sections) if focus_sections else None
    horizontal = projection[0].lower()
    vertical = projection[1].lower()
    n_sections = len(self.sections)

    for sec in self.sections:
        if focused and sec not in focused:
            continue

        xs = [getattr(pt, horizontal) for pt in sec.points]
        ys = [getattr(pt, vertical) for pt in sec.points]

        # Evaluated in the same order as before: index colour, then domain
        # colour (if requested), then the highlight override.
        index_color = plt.cm.jet(1 - sec.idx / n_sections)
        base_color = get_domain_color(sec.domain) if show_domains else index_color
        color = 'red' if highlighted and sec in highlighted else base_color

        if show_points:
            ax.plot(xs, ys, '.', color=color, markersize=7, markeredgecolor='black')
        if show_lines:
            ax.plot(xs, ys, color=color, zorder=0)

        if annotate:
            anchor = (np.mean(xs), np.mean(ys))
            label_box = dict(facecolor='black', edgecolor='white',
                             boxstyle='round,pad=0.3')
            ax.annotate(f'{sec.idx}', anchor, fontsize=8,
                        color='white', bbox=label_box)

    ax.set_xlabel(projection[0])
    ax.set_ylabel(projection[1])
    ax.set_aspect('equal')
1029
+
1030
+
1031
+
1032
def plot_radii_distribution(self, ax=None, highlight=None,
                            domains=True, show_soma=False):
    """
    Plot the radius distribution of the sections in the tree.

    Each section is drawn as a line of its radii against the path distance
    of its points.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates a new figure and axes.
    highlight : list of int, optional
        Indices of sections to highlight (drawn in red, on top).
    domains : bool, optional
        Whether to color sections based on their domain. When False,
        sections are colored by their index with the ``jet`` colormap, the
        same scheme ``plot`` uses when ``show_domains`` is False.
    show_soma : bool, optional
        Whether to show the root (parentless) section in the plot.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 3))

    section_count = len(self.sections)

    for sec in self.sections:
        if not show_soma and sec.parent is None:
            continue

        if highlight and sec.idx in highlight:
            color, zorder = 'red', 2
        elif domains:
            color, zorder = get_domain_color(sec.domain), 1
        else:
            # Previously ``domains`` was ignored and domain colors were
            # always used; honor the flag with index-based coloring.
            color, zorder = plt.cm.jet(1 - sec.idx / section_count), 1

        ax.plot(
            [pt.path_distance() for pt in sec.points],
            sec.radii,
            marker='.',
            color=color,
            zorder=zorder
        )
    ax.set_xlabel('Distance from root')
    ax.set_ylabel('Radius')
1073
+
1074
+
1075
+ # EXPORT METHODS
1076
+
1077
def to_swc(self, path_to_file: str):
    """
    Write the tree's points to ``path_to_file`` in SWC layout.

    One space-separated row is written per point, with the columns idx,
    type_idx, x, y, z, r and parent_idx. No header and no index column are
    written.

    The root section and the sections directly attached to it contribute
    all of their points. Every deeper section omits its first point,
    presumably because that point duplicates one already written for its
    parent section (confirm against the section-building code).

    Parameters
    ----------
    path_to_file : str
        Destination file path.

    Raises
    ------
    ValueError
        If either the section tree or its point tree is not sorted.
    """
    sorted_ok = self.is_sorted and self._point_tree.is_sorted
    if not sorted_ok:
        raise ValueError('The tree must be sorted before saving.')

    # Column names double as the point attribute names, in output order.
    columns = ('idx', 'type_idx', 'x', 'y', 'z', 'r', 'parent_idx')
    table = {name: [] for name in columns}

    for sec in self.sections:
        near_root = sec.parent is None or sec.parent.parent is None
        selected = sec.points if near_root else sec.points[1:]
        for pt in selected:
            for name in columns:
                table[name].append(getattr(pt, name))

    pd.DataFrame(table).to_csv(path_to_file, sep=' ', index=False, header=False)
1112
+