dendrotweaks 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dendrotweaks/__init__.py +10 -0
- dendrotweaks/analysis/__init__.py +11 -0
- dendrotweaks/analysis/ephys_analysis.py +482 -0
- dendrotweaks/analysis/morphometric_analysis.py +106 -0
- dendrotweaks/membrane/__init__.py +6 -0
- dendrotweaks/membrane/default_mod/AMPA.mod +65 -0
- dendrotweaks/membrane/default_mod/AMPA_NMDA.mod +100 -0
- dendrotweaks/membrane/default_mod/CaDyn.mod +54 -0
- dendrotweaks/membrane/default_mod/GABAa.mod +65 -0
- dendrotweaks/membrane/default_mod/Leak.mod +27 -0
- dendrotweaks/membrane/default_mod/NMDA.mod +72 -0
- dendrotweaks/membrane/default_mod/vecstim.mod +76 -0
- dendrotweaks/membrane/default_templates/NEURON_template.py +354 -0
- dendrotweaks/membrane/default_templates/default.py +73 -0
- dendrotweaks/membrane/default_templates/standard_channel.mod +87 -0
- dendrotweaks/membrane/default_templates/template_jaxley.py +108 -0
- dendrotweaks/membrane/default_templates/template_jaxley_new.py +108 -0
- dendrotweaks/membrane/distributions.py +324 -0
- dendrotweaks/membrane/groups.py +103 -0
- dendrotweaks/membrane/io/__init__.py +11 -0
- dendrotweaks/membrane/io/ast.py +201 -0
- dendrotweaks/membrane/io/code_generators.py +312 -0
- dendrotweaks/membrane/io/converter.py +108 -0
- dendrotweaks/membrane/io/factories.py +144 -0
- dendrotweaks/membrane/io/grammar.py +417 -0
- dendrotweaks/membrane/io/loader.py +90 -0
- dendrotweaks/membrane/io/parser.py +499 -0
- dendrotweaks/membrane/io/reader.py +212 -0
- dendrotweaks/membrane/mechanisms.py +574 -0
- dendrotweaks/model.py +1916 -0
- dendrotweaks/model_io.py +75 -0
- dendrotweaks/morphology/__init__.py +5 -0
- dendrotweaks/morphology/domains.py +100 -0
- dendrotweaks/morphology/io/__init__.py +5 -0
- dendrotweaks/morphology/io/factories.py +212 -0
- dendrotweaks/morphology/io/reader.py +66 -0
- dendrotweaks/morphology/io/validation.py +212 -0
- dendrotweaks/morphology/point_trees.py +681 -0
- dendrotweaks/morphology/reduce/__init__.py +16 -0
- dendrotweaks/morphology/reduce/reduce.py +155 -0
- dendrotweaks/morphology/reduce/reduced_cylinder.py +129 -0
- dendrotweaks/morphology/sec_trees.py +1112 -0
- dendrotweaks/morphology/seg_trees.py +157 -0
- dendrotweaks/morphology/trees.py +567 -0
- dendrotweaks/path_manager.py +261 -0
- dendrotweaks/simulators.py +235 -0
- dendrotweaks/stimuli/__init__.py +3 -0
- dendrotweaks/stimuli/iclamps.py +73 -0
- dendrotweaks/stimuli/populations.py +265 -0
- dendrotweaks/stimuli/synapses.py +203 -0
- dendrotweaks/utils.py +239 -0
- dendrotweaks-0.3.1.dist-info/METADATA +70 -0
- dendrotweaks-0.3.1.dist-info/RECORD +56 -0
- dendrotweaks-0.3.1.dist-info/WHEEL +5 -0
- dendrotweaks-0.3.1.dist-info/licenses/LICENSE +674 -0
- dendrotweaks-0.3.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1112 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import pandas as pd
|
3
|
+
import matplotlib.pyplot as plt
|
4
|
+
from matplotlib.gridspec import GridSpec
|
5
|
+
import matplotlib.colors as mcolors
|
6
|
+
|
7
|
+
from typing import Callable, List
|
8
|
+
from neuron import h
|
9
|
+
|
10
|
+
from dendrotweaks.morphology.trees import Node, Tree
|
11
|
+
from dendrotweaks.morphology.domains import Domain
|
12
|
+
from dataclasses import dataclass, field
|
13
|
+
from bisect import bisect_left
|
14
|
+
|
15
|
+
import warnings
|
16
|
+
|
17
|
+
from dendrotweaks.utils import get_domain_color
|
18
|
+
|
19
|
+
def custom_warning_formatter(message, category, filename, lineno, file=None, line=None):
    """
    Format a warning as a single line: ``Category: message (file.py, line N)``.

    Installed below as ``warnings.formatwarning``. Only the base name of the
    file that issued the warning is shown. ``file`` and ``line`` are accepted
    to match the ``warnings.formatwarning`` signature but are not used.
    """
    # Imported locally because this module does not import ``os`` at top level.
    import os

    return f"{category.__name__}: {message} ({os.path.basename(filename)}, line {lineno})\n"


warnings.formatwarning = custom_warning_formatter
|
23
|
+
|
24
|
+
|
25
|
+
|
26
|
+
class Section(Node):
    """
    A class representing a section in a neuron morphology.

    A section is a continuous part of a neuron's morphology between two
    branching points.

    Parameters
    ----------
    idx : str
        The index of the section.
    parent_idx : str
        The index of the parent section.
    points : List[Point]
        The points that define the section. There must be at least one
        point, and all points must belong to the same domain.

    Attributes
    ----------
    points : List[Point]
        The points that define the section.
    segments : List[Segment]
        The segments into which the section is divided.
    _ref : h.Section
        The reference to the NEURON section (None until the section is
        created in NEURON).
    """

    def __init__(self, idx: str, parent_idx: str, points: List[Node]) -> None:
        """
        Create a section from its points.

        Raises
        ------
        ValueError
            If ``points`` is empty or the points belong to different domains.
        """
        if len(points) == 0:
            raise ValueError('A section must contain at least one point.')
        super().__init__(idx, parent_idx)
        self.domain_idx = None
        self.points = points
        self.segments = []
        self._ref = None
        # The section takes the domain of its points.
        self._domain = self.points[0].domain

        if not all(pt.domain == self._domain for pt in points):
            raise ValueError('All points in a section must belong to the same domain.')


    # HELPERS

    def _require_ref(self):
        """Raise ValueError if the section has not been created in NEURON."""
        if self._ref is None:
            raise ValueError('Section is not referenced in NEURON.')


    # MAGIC METHODS

    def __call__(self, x: float):
        """
        Return the segment at a given position.

        Parameters
        ----------
        x : float
            Normalized position along the section, in [0, 1].

        Returns
        -------
        Segment
            The segment containing ``x``. ``x == 0`` and ``x == 1`` map to
            the first and the last segment, respectively.

        Raises
        ------
        ValueError
            If the section has no NEURON reference, ``x`` is outside
            [0, 1], or no segment matches ``x``.
        """
        self._require_ref()
        if x < 0 or x > 1:
            raise ValueError('Location x must be in the range [0, 1].')
        # TODO: Decide how to handle sec(0) and sec(1)
        # as they are not shown in the seg_graph
        if x == 0:
            return self.segments[0]
        if x == 1:
            return self.segments[-1]
        # Resolve the NEURON segment once, then find its DendroTweaks wrapper.
        nrn_seg = self._ref(x)
        for seg in self.segments:
            if nrn_seg == seg._ref:
                return seg
        raise ValueError(f'No segment found at location {x}')


    def __iter__(self):
        """
        Iterate over the segments in the section.
        """
        yield from self.segments


    # PROPERTIES

    @property
    def domain(self):
        """
        The morphological or functional domain of the node.
        """
        return self._domain


    @domain.setter
    def domain(self, domain):
        # Keep the points' domain in sync with the section's domain.
        self._domain = domain
        for pt in self.points:
            pt.domain = domain


    @property
    def df_points(self):
        """
        A DataFrame of the points in the section (the points' own
        DataFrames concatenated in order).
        """
        return pd.concat([pt.df for pt in self.points])


    @property
    def df(self):
        """
        A one-row DataFrame with the section's ``idx`` and ``parent_idx``.
        """
        return pd.DataFrame({'idx': [self.idx],
                             'parent_idx': [self.parent_idx]})


    @property
    def diam(self):
        """
        Diameter of the central segment of the section (from NEURON).

        NEURON's ``sec.diam`` is the diameter at the center of the section,
        so it generally differs from a length-weighted average of the point
        diameters.
        """
        return self._ref.diam


    @property
    def L(self):
        """
        Length of the section (from NEURON).
        """
        return self._ref.L


    @property
    def cm(self):
        """
        Specific membrane capacitance of the section (from NEURON).
        """
        return self._ref.cm


    @property
    def Ra(self):
        """
        Axial resistance of the section (from NEURON).
        """
        return self._ref.Ra


    @property
    def nseg(self):
        """
        Number of segments in the section (from NEURON).
        """
        return self._ref.nseg

    @nseg.setter
    def nseg(self, value):
        # Validate, set nseg in NEURON, then replace this section's segments
        # in the segment tree with wrappers for the new NEURON segments.
        if value < 1:
            raise ValueError('Number of segments must be at least 1.')
        if value % 2 == 0:
            raise ValueError('Number of segments must be odd.')
        # Imported here rather than at module level, as in the original
        # (presumably to avoid a circular import — verify).
        from dendrotweaks.morphology.seg_trees import Segment

        # Set the number in NEURON and get the new NEURON segments
        self._ref.nseg = value
        nrnsegs = list(self._ref)

        # Create new DendroTweaks segments
        old_segments = self.segments
        new_segments = [Segment(idx=0, parent_idx=0, neuron_seg=seg, section=self)
                        for seg in nrnsegs]

        seg_tree = self._tree._seg_tree
        # The first new segment is inserted before the current first segment;
        # each following one is inserted before the previously inserted one.
        anchor = old_segments[0]
        for seg in new_segments:
            seg_tree.insert_node_before(seg, anchor)
            anchor = seg

        for seg in old_segments:
            seg_tree.remove_node(seg)

        # Sort the tree and update the section's segments
        seg_tree.sort()
        self.segments = new_segments


    @property
    def radii(self):
        """
        Radii of the points in the section.
        """
        return [pt.r for pt in self.points]


    @property
    def diameters(self):
        """
        Diameters of the points in the section.
        """
        return [2 * pt.r for pt in self.points]


    @property
    def xs(self):
        """
        X-coordinates of the points in the section.
        """
        return [pt.x for pt in self.points]


    @property
    def ys(self):
        """
        Y-coordinates of the points in the section.
        """
        return [pt.y for pt in self.points]


    @property
    def zs(self):
        """
        Z-coordinates of the points in the section.
        """
        return [pt.z for pt in self.points]


    @property
    def seg_centers(self):
        """
        Positions of the segment centers along the section, in the units of
        NEURON's ``L``: the normalized centers ``(2i - 1) / (2 * nseg)``
        scaled by the section length ``L``.

        Raises
        ------
        ValueError
            If the section has no NEURON reference.
        """
        self._require_ref()
        nseg = self._ref.nseg
        length = self._ref.L
        return [(2 * i - 1) / (2 * nseg) * length for i in range(1, nseg + 1)]


    @property
    def seg_borders(self):
        """
        Normalized positions (in [0, 1]) of the segment borders, including
        both ends of the section.

        Raises
        ------
        ValueError
            If the section has no NEURON reference.
        """
        self._require_ref()
        nseg = int(self._ref.nseg)
        return [i / nseg for i in range(nseg + 1)]


    @property
    def distances(self):
        """
        The cumulative euclidean distances of the points in the section,
        as a NumPy array starting at 0.
        """
        coords = np.array([[pt.x, pt.y, pt.z] for pt in self.points])
        deltas = np.diff(coords, axis=0)
        frusta_distances = np.sqrt(np.sum(deltas**2, axis=1))
        return np.insert(np.cumsum(frusta_distances), 0, 0)


    @property
    def center(self):
        """
        The coordinates of the center of the section (mean of the point
        coordinates).
        """
        return np.mean(self.xs), np.mean(self.ys), np.mean(self.zs)


    @property
    def length(self):
        """
        The length of the section calculated as the sum of the distances between the points.
        """
        return self.distances[-1]


    @property
    def area(self):
        """
        The surface area of the section, calculated as the sum of the lateral
        areas of the frusta between consecutive points:
        ``pi * (r1 + r2) * sqrt((r1 - r2)**2 + l**2)``.
        """
        radii = self.radii
        frustum_lengths = np.diff(self.distances)
        # ``dl`` (not ``h``) to avoid shadowing NEURON's ``h``.
        return sum(np.pi * (r1 + r2) * np.sqrt((r1 - r2)**2 + dl**2)
                   for r1, r2, dl in zip(radii[:-1], radii[1:], frustum_lengths))


    def has_mechanism(self, mech_name):
        """
        Check if the section has a mechanism inserted.

        Parameters
        ----------
        mech_name : str
            The name of the mechanism to check.

        Returns
        -------
        bool
            Result of NEURON's ``has_membrane`` for this section.
        """
        return self._ref.has_membrane(mech_name)


    # REFERENCING METHODS

    def create_and_reference(self, simulator_name='NEURON'):
        """
        Add a reference to the section in the simulator.

        Parameters
        ----------
        simulator_name : str
            The name of the simulator to create the section in
            ('NEURON' or 'JAXLEY').

        Raises
        ------
        ValueError
            If ``simulator_name`` is not a supported simulator.
        """
        if simulator_name == 'NEURON':
            self.create_NEURON_section()
        elif simulator_name == 'JAXLEY':
            self.create_JAXLEY_section()
        else:
            raise ValueError(f'Unknown simulator: {simulator_name!r}. '
                             "Expected 'NEURON' or 'JAXLEY'.")


    def _connect_NEURON_to_parent(self):
        """
        Connect this section's NEURON section to its parent's.

        Children of the root section (soma) attach at the parent's
        midpoint (0.5); all other sections attach at the parent's end (1).
        Does nothing if the section has no parent.
        """
        if self.parent is None:
            return
        # TODO: Attaching basal to soma 0
        attach_at = 0.5 if self.parent.parent is None else 1
        self._ref.connect(self.parent._ref(attach_at))


    def create_NEURON_section(self):
        """
        Create a NEURON section, connect it to the parent's NEURON section
        and add the section's points as 3D points.
        """
        self._ref = h.Section()  # name=f'Sec_{self.idx}'
        self._connect_NEURON_to_parent()
        # Add 3D points to the section
        self._ref.pt3dclear()
        for pt in self.points:
            # Diameter rounded to 16 decimals (presumably to suppress
            # floating-point noise — confirm).
            diam = round(2 * pt.r, 16)
            self._ref.pt3dadd(pt.x, pt.y, pt.z, diam)

    def create_JAXLEY_section(self):
        """
        Create a JAXLEY section (not implemented yet).
        """
        raise NotImplementedError


    # MECHANISM METHODS

    def insert_mechanism(self, name: str):
        """
        Inserts a mechanism in the section if
        it is not already inserted.
        """
        if self.has_mechanism(name):
            return
        self._ref.insert(name)


    def uninsert_mechanism(self, name: str):
        """
        Uninserts a mechanism in the section if
        it was inserted.
        """
        if not self.has_mechanism(name):
            return
        self._ref.uninsert(name)


    # PARAMETER METHODS

    def get_param_value(self, param_name):
        """
        Get the average parameter of the section's segments.

        Parameters
        ----------
        param_name : str
            The name of the parameter to get.

        Returns
        -------
        float
            The mean value of the parameter over the section's segments,
            rounded to 16 decimals.
        """
        seg_values = [seg.get_param_value(param_name) for seg in self.segments]
        return round(np.mean(seg_values), 16)


    def path_distance(self, relative_position: float = 0,
                      within_domain: bool = False) -> float:
        """
        Calculate the distance from the section to the root at a given relative position.

        Parameters
        ----------
        relative_position : float
            The position along the section's normalized length [0, 1].
        within_domain : bool
            Whether to stop at the domain boundary.

        Returns
        -------
        float
            The distance from the section to the root. The root section's
            own length is not included.

        Raises
        ------
        ValueError
            If ``relative_position`` is outside [0, 1].

        Important
        ---------
        Assumes that we always attach the 0 end of the child.
        """
        if not (0 <= relative_position <= 1):
            raise ValueError('Relative position must be between 0 and 1.')

        distance = 0
        factor = relative_position  # only part of the starting section counts
        node = self

        while node.parent is not None:
            distance += factor * node.length

            if within_domain and node.parent.domain != node.domain:
                break

            node = node.parent
            factor = 1  # ancestors are traversed along their full length

        return distance


    def disconnect_from_parent(self):
        """
        Detach the section from its parent section in the section tree,
        NEURON (if created), the point tree and the segment tree (if any).
        """
        # In SectionTree
        super().disconnect_from_parent()
        # In NEURON
        if self._ref:
            h.disconnect(sec=self._ref)  # from parent
        # In PointTree
        self.points[0].disconnect_from_parent()
        # In SegmentTree
        if self.segments:
            self.segments[0].disconnect_from_parent()


    def connect_to_parent(self, parent):
        """
        Attach the section to a parent section in the section tree,
        NEURON (if created), the point tree and the segment tree (if any).

        Parameters
        ----------
        parent : Section
            The parent section to attach to.
        """
        # In SectionTree
        super().connect_to_parent(parent)
        # In NEURON
        if self._ref:
            self._connect_NEURON_to_parent()

        # In PointTree
        if self.parent is not None:
            if self.parent.parent is None:  # if parent is soma
                parent_sec = self.parent
                # Attach to the middle of the parent: its second point if it
                # has more than one (presumably the soma center — confirm).
                parent_pt = parent_sec.points[1] if len(parent_sec.points) > 1 else parent_sec.points[0]
                self.points[0].connect_to_parent(parent_pt)
            else:
                self.points[0].connect_to_parent(parent.points[-1])  # attach to the end of the parent
        # In SegmentTree
        if self.segments:
            self.segments[0].connect_to_parent(parent.segments[-1])


    # PLOTTING METHODS

    def plot(self, ax=None, plot_parent=True, section_color=None, parent_color='gray',
             show_labels=True, aspect_equal=True):
        """
        Plot section morphology in 3D projections (XZ, YZ, XY) and radii distribution.

        Parameters
        ----------
        ax : list or array of matplotlib.axes.Axes, optional
            Four axes for plotting (XZ, YZ, XY, radii), flat or nested.
            If None, creates a new figure with axes.
        plot_parent : bool, optional
            Whether to include parent section in the visualization.
        section_color : str or None, optional
            Color for the current section. If None, assigns based on section domain.
        parent_color : str, optional
            Color for the parent section.
        show_labels : bool, optional
            Whether to show axis labels and titles. The overall figure title
            is only added to a figure created by this method.
        aspect_equal : bool, optional
            Whether to set aspect ratio to 'equal' for the projections.

        Returns
        -------
        ax : list of matplotlib.axes.Axes
            The axes containing the plots.
        """
        # Only a figure created here gets the overall title and layout pass;
        # a caller-supplied figure is left as the caller arranged it.
        created_figure = ax is None
        if created_figure:
            fig = plt.figure(figsize=(10, 8))
            gs = GridSpec(2, 3, width_ratios=[1, 1, 1.2], figure=fig)

            # Create the three projection axes and one radius axis
            ax_xz = fig.add_subplot(gs[0, 0])
            ax_yz = fig.add_subplot(gs[0, 1])
            ax_xy = fig.add_subplot(gs[1, 0])
            ax_radii = fig.add_subplot(gs[1, 1:])
            ax = [ax_xz, ax_yz, ax_xy, ax_radii]
        else:
            # Accept a flat sequence of four axes or a nested one (e.g. 2x2)
            if len(ax) != 4:
                ax = [ai for a in ax for ai in a]
            ax_xz, ax_yz, ax_xy, ax_radii = ax

        # Determine section color based on domain if not provided
        if section_color is None:
            section_color = get_domain_color(self.domain)

        xs, ys, zs = np.array(self.xs), np.array(self.ys), np.array(self.zs)

        # Plot section projections
        self._plot_projection(ax_xz, xs, zs, 'X', 'Z', 'XZ Projection',
                              section_color, show_labels, aspect_equal)
        self._plot_projection(ax_yz, ys, zs, 'Y', 'Z', 'YZ Projection',
                              section_color, show_labels, aspect_equal)
        self._plot_projection(ax_xy, xs, ys, 'X', 'Y', 'XY Projection',
                              section_color, show_labels, aspect_equal)

        # Plot radius distribution (also handles the parent's radii)
        self._plot_radii_distribution(ax_radii, plot_parent, section_color, parent_color)

        # Plot parent section projections if requested
        if plot_parent and self.parent is not None:
            parent = self.parent
            parent_xs = np.array(parent.xs)
            parent_ys = np.array(parent.ys)
            parent_zs = np.array(parent.zs)

            parent._plot_projection(ax_xz, parent_xs, parent_zs, None, None, None,
                                    parent_color, False, aspect_equal)
            parent._plot_projection(ax_yz, parent_ys, parent_zs, None, None, None,
                                    parent_color, False, aspect_equal)
            parent._plot_projection(ax_xy, parent_xs, parent_ys, None, None, None,
                                    parent_color, False, aspect_equal)

        # Add overall title if we created the figure
        if created_figure and show_labels:
            fig.suptitle(f"Section {self.idx} ({self.domain})", fontsize=14)
            fig.tight_layout()

        return ax

    def _plot_projection(self, ax, x_coords, y_coords, x_label, y_label, title,
                         color, show_labels, aspect_equal):
        """
        Helper method to plot a 2D projection of the section.

        Labels and title are only set when ``show_labels`` is true and the
        corresponding argument is non-empty.
        """
        ax.plot(x_coords, y_coords, 'o-', color=color, markerfacecolor=color,
                markeredgecolor='black', markersize=4, linewidth=1.5)

        if show_labels:
            if x_label:
                ax.set_xlabel(x_label)
            if y_label:
                ax.set_ylabel(y_label)
            if title:
                ax.set_title(title)

        if aspect_equal:
            ax.set_aspect('equal')

    def _plot_radii_distribution(self, ax, plot_parent, section_color, parent_color):
        """
        Helper method to plot radius distribution along the section.

        Point radii are drawn as a line over the distance from the section
        start. If the section exists in NEURON, its segment radii are drawn
        as bars. If ``plot_parent`` is set and there is a parent, the
        parent is drawn to the left, ending at distance 0.
        """
        distances = self.distances
        section_length = distances[-1] - distances[0]

        # Normalize distances to start at 0 and end at section_length
        normalized_distances = distances - distances[0]

        # Plot section radii
        ax.plot(normalized_distances, self.radii, 'o-', color=section_color,
                label=f"{self.domain} ({self.idx})", linewidth=2)

        # Plot reference NEURON segments if available
        has_ref = bool(getattr(self, '_ref', None))
        if has_ref:
            normalized_seg_centers = np.array(self.seg_centers) - distances[0]
            seg_radii = np.array([seg.diam / 2 for seg in self._ref])
            bar_width = [self._ref.L / self._ref.nseg] * self._ref.nseg
            ax.bar(normalized_seg_centers, seg_radii, width=bar_width,
                   alpha=0.5, color=section_color, edgecolor='white',
                   label=f"{self.domain} segments")

        # Plot parent section if requested
        show_parent = plot_parent and self.parent is not None
        if show_parent:
            parent = self.parent
            parent_distances = parent.distances
            parent_length = parent_distances[-1] - parent_distances[0]

            # Shift parent distances so the parent ends at 0, where the
            # child begins (parent spans -parent_length to 0)
            normalized_parent_distances = parent_distances - parent_distances[-1]

            ax.plot(normalized_parent_distances, parent.radii, 'o-',
                    color=parent_color, linewidth=2,
                    label=f"Parent {parent.domain} ({parent.idx})")

            # Plot parent reference segments if available
            if getattr(parent, '_ref', None):
                normalized_parent_seg_centers = (np.array(parent.seg_centers) -
                                                 parent_distances[-1])
                parent_seg_radii = np.array([seg.diam / 2 for seg in parent._ref])
                parent_bar_width = [parent._ref.L / parent._ref.nseg] * parent._ref.nseg
                ax.bar(normalized_parent_seg_centers, parent_seg_radii,
                       width=parent_bar_width, alpha=0.5, color=parent_color,
                       edgecolor='white', label="Parent segments")

        # Set plot labels
        ax.set_xlabel('Distance (µm)')
        ax.set_ylabel('Radius (µm)')
        ax.set_title('Radius Distribution')

        # Ensure y-axis starts at 0
        ax.set_ylim(bottom=0)

        # Adjust x-axis to show the full section(s)
        if show_parent:
            ax.set_xlim(-parent_length * 1.05, section_length * 1.05)
        else:
            ax.set_xlim(-section_length * 0.05, section_length * 1.05)

        # Add legend if we have multiple data series
        if has_ref or show_parent:
            ax.legend(loc='best', frameon=True, framealpha=0.8)

    def plot_radii(self, ax=None, include_parent=False, section_color=None, parent_color='gray'):
        """
        Plot just the radius distribution for the section.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            Axes to plot on. If None, creates a new figure and axes.
        include_parent : bool, optional
            Whether to include parent section in the plot.
        section_color : str or None, optional
            Color for current section. If None, assigns based on section
            domain (same colors as ``plot``).
        parent_color : str, optional
            Color for parent section if included.

        Returns
        -------
        ax : matplotlib.axes.Axes
            The axes containing the plot.
        """
        # Create new figure and axes if not provided
        if ax is None:
            _, ax = plt.subplots(figsize=(10, 5))

        # Determine section color based on domain if not provided
        if section_color is None:
            section_color = get_domain_color(self.domain)

        # Plot radius distribution
        self._plot_radii_distribution(ax, include_parent, section_color, parent_color)

        # Add a specific title when this is the only axes in its figure
        fig = ax.get_figure()
        if len(fig.get_axes()) == 1:
            ax.set_title(f"Radius Distribution - Section {self.idx} ({self.domain})")
            fig.tight_layout()

        return ax
|
739
|
+
|
740
|
+
|
741
|
+
class SectionTree(Tree):
|
742
|
+
"""
|
743
|
+
A class representing a tree graph of sections in a neuron morphology.
|
744
|
+
|
745
|
+
Parameters
|
746
|
+
----------
|
747
|
+
sections : List[Section]
|
748
|
+
A list of sections in the tree.
|
749
|
+
|
750
|
+
Attributes
|
751
|
+
----------
|
752
|
+
domains : Dict[str, Domain]
|
753
|
+
A dictionary of domains in the tree.
|
754
|
+
"""
|
755
|
+
|
756
|
+
def __init__(self, sections: list[Section]) -> None:
    """
    Build the tree from its sections and group them into domains.

    The derived point and segment trees are not built here; both
    attributes start out as None.
    """
    super().__init__(sections)
    self._create_domains()
    # Derived trees, attached later by whoever builds them.
    self._point_tree = self._seg_tree = None
|
761
|
+
|
762
|
+
|
763
|
+
def __repr__(self):
    """Show the root section and the number of sections in the tree."""
    root_section = self.root
    node_count = len(self._nodes)
    return "SectionTree(root={!r}, num_nodes={})".format(root_section, node_count)
|
765
|
+
|
766
|
+
|
767
|
+
def _create_domains(self):
    """
    Create domains using the data from the sections (from the points in the sections).

    Populates ``self.domains``, a dict mapping each domain name to a
    ``Domain`` holding the sections of that domain. Domains are keyed in
    order of first appearance in ``self.sections``, so iterating over
    ``self.domains`` gives the same order on every run (iterating a ``set``
    of strings does not, because string hashing is randomized per process).
    """
    self.domains = {}
    for sec in self.sections:
        if sec.domain not in self.domains:
            self.domains[sec.domain] = Domain(sec.domain)
        self.domains[sec.domain].add_section(sec)
|
777
|
+
|
778
|
+
|
779
|
+
# PROPERTIES
|
780
|
+
|
781
|
+
@property
def sections(self):
    """
    The list of sections held by the tree (alias for ``self._nodes``;
    the same list object is returned, not a copy).
    """
    nodes = self._nodes
    return nodes
|
787
|
+
|
788
|
+
|
789
|
+
@property
def soma(self):
    """
    The root section of the tree, treated as the soma (alias for ``self.root``).
    """
    root_section = self.root
    return root_section
|
795
|
+
|
796
|
+
|
797
|
+
@property
def sections_by_depth(self):
    """
    Group the tree's sections by depth.

    Returns
    -------
    dict
        Maps each depth (number of edges from the root) to the list of
        sections at that depth, in the order they appear in ``self.sections``.
    """
    grouped = {}
    for section in self.sections:
        grouped.setdefault(section.depth, []).append(section)
    return grouped
|
809
|
+
|
810
|
+
|
811
|
+
@property
def df(self):
    """
    A DataFrame of the points of all sections, one row per point.

    Columns are the point's ``idx``, ``domain``, ``x``, ``y``, ``z``, ``r``
    and ``parent_idx``, followed by the owning section's ``section_idx``
    and ``parent_section_idx``. For the root section and its direct
    children all points are listed. For deeper sections the first point
    is omitted (it appears to be the point shared with the parent section;
    TODO confirm against the point-tree builder).
    """
    columns = ('idx', 'domain', 'x', 'y', 'z', 'r',
               'parent_idx', 'section_idx', 'parent_section_idx')
    data = {column: [] for column in columns}

    for sec in self.sections:
        keep_all_points = sec.parent is None or sec.parent.parent is None
        selected = sec.points if keep_all_points else sec.points[1:]
        for pt in selected:
            record = (pt.idx, pt.domain, pt.x, pt.y, pt.z, pt.r,
                      pt.parent_idx, sec.idx, sec.parent_idx)
            for column, value in zip(columns, record):
                data[column].append(value)

    return pd.DataFrame(data)
|
842
|
+
|
843
|
+
|
844
|
+
def sort(self, **kwargs):
    """
    Sort the sections in the tree using a depth-first traversal.

    The point tree, and the segment tree if one exists, are sorted too.
    They receive the same keyword arguments.

    Parameters
    ----------
    sort_children : bool, optional
        Whether to sort the children of each node
        based on the number of bifurcations in their subtrees. Defaults to True.
    force : bool, optional
        Whether to force the sorting of the tree even if it is already sorted. Defaults to False.
    """
    super().sort(**kwargs)
    self._point_tree.sort(**kwargs)
    segment_tree = self._seg_tree
    # The segment tree is optional, so only sort it when present.
    if segment_tree:
        segment_tree.sort(**kwargs)
|
860
|
+
|
861
|
+
|
862
|
+
# STRUCTURE METHODS
|
863
|
+
|
864
|
+
def remove_subtree(self, section):
    """
    Remove a section and its subtree from the tree.

    In addition to the section-tree removal, this:

    - detaches every removed section from the domains that list it;
    - removes the subtree rooted at the section's first point from the
      point tree;
    - removes the subtree rooted at its first segment from the segment
      tree, if one exists;
    - when ``section._ref`` is set, disconnects it in NEURON and deletes
      the NEURON sections of the whole subtree.

    Parameters
    ----------
    section : Section
        The root of the subtree to remove.
    """
    super().remove_subtree(section)
    # Evaluate the subtree once rather than once per domain plus once more
    # for the NEURON cleanup below.
    subtree = list(section.subtree)
    # Domains
    for domain in self.domains.values():
        for sec in subtree:
            if sec in domain.sections:
                domain.remove_section(sec)
    # Points
    self._point_tree.remove_subtree(section.points[0])
    # Segments
    if self._seg_tree:
        self._seg_tree.remove_subtree(section.segments[0])
    # NEURON
    if section._ref:
        h.disconnect(sec=section._ref)
        for sec in subtree:
            h.delete_section(sec=sec._ref)
|
889
|
+
|
890
|
+
|
891
|
+
def remove_zero_length_sections(self):
    """
    Remove sections with zero length.

    For each zero-length section:

    - its points are removed from the point tree;
    - its segments are removed from the segment tree, if one exists;
    - the section itself is removed from this tree.

    NOTE(review): unlike ``remove_subtree``, this does not update
    ``self.domains``. Confirm whether ``remove_node`` handles that.
    """
    # Take a snapshot first. ``self.sections`` is the same list as
    # ``self._nodes``, which ``remove_node`` presumably shrinks. Removing
    # items from a list while iterating over it silently skips the element
    # that follows each removed one.
    zero_length = [sec for sec in self.sections if sec.length == 0]
    for sec in zero_length:
        for pt in list(sec.points):
            self._point_tree.remove_node(pt)
        # The segment tree is optional; the other methods guard on it too.
        if self._seg_tree:
            for seg in list(sec.segments):
                self._seg_tree.remove_node(seg)
        self.remove_node(sec)
|
902
|
+
|
903
|
+
|
904
|
+
def downsample(self, factor: float):
    """
    Thin out the points of every section except the soma.

    Each section keeps its first and last points. A section with fewer
    than three points is left unchanged.

    Retained points are evenly spaced by index (via ``np.linspace``).
    Removed points are dropped from both the point tree and the section's
    own point list. The point tree is re-sorted at the end, and the number
    of removed points is printed for each processed section.

    Parameters
    ----------
    factor : float
        Proportion of points to keep (e.g. 0.5 keeps 50% of points). If
        factor is 0, only the first and last points are kept. At least two
        points are always kept.
    """
    for sec in self.sections:
        # The soma is never thinned; short sections have no interior points.
        if sec is self.soma or len(sec.points) < 3:
            continue

        n_points = len(sec.points)
        target = 2 if factor == 0 else max(2, int(n_points * factor))

        # Index 0 and index n_points - 1 are always among these positions.
        kept_positions = set(np.linspace(0, n_points - 1, target, dtype=int))
        doomed = [
            pt for position, pt in enumerate(sec.points)
            if position not in kept_positions
        ]

        print(f'Removing {len(doomed)} points from section {sec.idx}')

        for pt in doomed:
            self._point_tree.remove_node(pt)
            sec.points.remove(pt)

    self._point_tree.sort()
|
938
|
+
|
939
|
+
|
940
|
+
# def plot_sections_as_matrix(self, ax=None):
|
941
|
+
# """
|
942
|
+
# Plot the sections as a connectivity matrix.
|
943
|
+
# """
|
944
|
+
# if ax is None:
|
945
|
+
# fig, ax = plt.subplots()
|
946
|
+
|
947
|
+
# n = len(self.sections)
|
948
|
+
# matrix = np.zeros((n, n))
|
949
|
+
# for section in self.sections:
|
950
|
+
# if section.parent:
|
951
|
+
# matrix[section.idx, section.parent.idx] = section.idx
|
952
|
+
# matrix[matrix == 0] = np.nan
|
953
|
+
|
954
|
+
# ax.imshow(matrix.T, cmap='jet_r')
|
955
|
+
# ax.set_xlabel('Section ID')
|
956
|
+
# ax.set_ylabel('Parent ID')
|
957
|
+
|
958
|
+
|
959
|
+
# PLOTTING METHODS
|
960
|
+
|
961
|
+
def plot(self, ax=None, show_points=False, show_lines=True,
         show_domains=True, annotate=False,
         projection='XY', highlight_sections=None, focus_sections=None):
    """
    Draw the tree's sections as a 2D projection.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. When omitted, a new 10x10 figure is created.
    show_points : bool, optional
        Draw each point as a marker.
    show_lines : bool, optional
        Draw lines through the points of each section.
    show_domains : bool, optional
        Color by domain. When False, color by section index using the
        ``jet`` colormap.
    annotate : bool, optional
        Label each section with its index at the mean of its plotted
        coordinates.
    projection : str or tuple, optional
        The first two items name the point attributes (case-insensitive)
        used for the horizontal and vertical axes, e.g. 'XY', 'XZ', 'YZ'
        or a tuple of two axis names.
    highlight_sections : list of Section, optional
        Sections drawn in red.
    focus_sections : list of Section, optional
        If non-empty, only these sections are drawn.
    """
    if ax is None:
        _fig, ax = plt.subplots(figsize=(10, 10))

    highlighted = set(highlight_sections) if highlight_sections else None
    focused = set(focus_sections) if focus_sections else None
    x_attr = projection[0].lower()
    y_attr = projection[1].lower()

    total = len(self.sections)  # computed once for the colormap scaling

    for sec in self.sections:
        if focused and sec not in focused:
            continue

        xs = [getattr(pt, x_attr) for pt in sec.points]
        ys = [getattr(pt, y_attr) for pt in sec.points]

        index_color = plt.cm.jet(1 - sec.idx / total)
        color = get_domain_color(sec.domain) if show_domains else index_color
        if highlighted and sec in highlighted:
            color = 'red'

        if show_points:
            ax.plot(xs, ys, '.', color=color, markersize=7, markeredgecolor='black')
        if show_lines:
            ax.plot(xs, ys, color=color, zorder=0)

        if annotate:
            label_box = {
                'facecolor': 'black',
                'edgecolor': 'white',
                'boxstyle': 'round,pad=0.3',
            }
            ax.annotate(
                f'{sec.idx}',
                (np.mean(xs), np.mean(ys)),
                fontsize=8,
                color='white',
                bbox=label_box,
            )

    ax.set_xlabel(projection[0])
    ax.set_ylabel(projection[1])
    ax.set_aspect('equal')
|
1029
|
+
|
1030
|
+
|
1031
|
+
|
1032
|
+
def plot_radii_distribution(self, ax=None, highlight=None,
                            domains=True, show_soma=False):
    """
    Plot the radius of each section's points against path distance.

    For every section, point radii (``sec.radii``) are plotted against each
    point's ``path_distance()``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates a new figure and axes.
    highlight : list of int, optional
        Indices of sections to highlight in the plot. Highlighted sections
        are drawn in red, above the others.
    domains : bool, optional
        Whether to color sections based on their domain. If False, sections
        are colored by index using the ``jet`` colormap.
    show_soma : bool, optional
        Whether to show the root (parentless) section in the plot.
    """
    if ax is None:
        _fig, ax = plt.subplots(figsize=(8, 3))

    # Set membership keeps the per-section highlight check O(1).
    highlight = set(highlight) if highlight else None
    section_count = len(self.sections)

    for sec in self.sections:
        if not show_soma and sec.parent is None:
            continue

        # Previously `domains` was ignored and domain colors were always used.
        if domains:
            color = get_domain_color(sec.domain)
        else:
            color = plt.cm.jet(1 - sec.idx / section_count)
        zorder = 1
        if highlight and sec.idx in highlight:
            color, zorder = 'red', 2

        ax.plot(
            [pt.path_distance() for pt in sec.points],
            sec.radii,
            marker='.',
            color=color,
            zorder=zorder,
        )

    ax.set_xlabel('Distance from root')
    ax.set_ylabel('Radius')
|
1073
|
+
|
1074
|
+
|
1075
|
+
# EXPORT METHODS
|
1076
|
+
|
1077
|
+
def to_swc(self, path_to_file: str):
    """
    Write the tree's points to an SWC file.

    Each point becomes one space-separated line with the columns ``idx``,
    ``type_idx``, ``x``, ``y``, ``z``, ``r``, ``parent_idx`` (no header, no
    index). For sections deeper than the root's direct children, the first
    point is left out, mirroring the ``df`` property.

    Parameters
    ----------
    path_to_file : str
        Destination path of the SWC file.

    Raises
    ------
    ValueError
        If either the section tree or the point tree is not sorted.
    """
    if not (self.is_sorted and self._point_tree.is_sorted):
        raise ValueError('The tree must be sorted before saving.')

    swc_columns = ('idx', 'type_idx', 'x', 'y', 'z', 'r', 'parent_idx')
    data = {column: [] for column in swc_columns}

    for sec in self.sections:
        keep_first_point = sec.parent is None or sec.parent.parent is None
        selected = sec.points if keep_first_point else sec.points[1:]
        for pt in selected:
            for column in swc_columns:
                data[column].append(getattr(pt, column))

    table = pd.DataFrame(data)
    table.to_csv(path_to_file, sep=' ', index=False, header=False)
|
1112
|
+
|