dendrotweaks 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dendrotweaks/__init__.py +10 -0
- dendrotweaks/analysis/__init__.py +11 -0
- dendrotweaks/analysis/ephys_analysis.py +482 -0
- dendrotweaks/analysis/morphometric_analysis.py +106 -0
- dendrotweaks/membrane/__init__.py +6 -0
- dendrotweaks/membrane/default_mod/AMPA.mod +65 -0
- dendrotweaks/membrane/default_mod/AMPA_NMDA.mod +100 -0
- dendrotweaks/membrane/default_mod/CaDyn.mod +54 -0
- dendrotweaks/membrane/default_mod/GABAa.mod +65 -0
- dendrotweaks/membrane/default_mod/Leak.mod +27 -0
- dendrotweaks/membrane/default_mod/NMDA.mod +72 -0
- dendrotweaks/membrane/default_mod/vecstim.mod +76 -0
- dendrotweaks/membrane/default_templates/NEURON_template.py +354 -0
- dendrotweaks/membrane/default_templates/default.py +73 -0
- dendrotweaks/membrane/default_templates/standard_channel.mod +87 -0
- dendrotweaks/membrane/default_templates/template_jaxley.py +108 -0
- dendrotweaks/membrane/default_templates/template_jaxley_new.py +108 -0
- dendrotweaks/membrane/distributions.py +324 -0
- dendrotweaks/membrane/groups.py +103 -0
- dendrotweaks/membrane/io/__init__.py +11 -0
- dendrotweaks/membrane/io/ast.py +201 -0
- dendrotweaks/membrane/io/code_generators.py +312 -0
- dendrotweaks/membrane/io/converter.py +108 -0
- dendrotweaks/membrane/io/factories.py +144 -0
- dendrotweaks/membrane/io/grammar.py +417 -0
- dendrotweaks/membrane/io/loader.py +90 -0
- dendrotweaks/membrane/io/parser.py +499 -0
- dendrotweaks/membrane/io/reader.py +212 -0
- dendrotweaks/membrane/mechanisms.py +574 -0
- dendrotweaks/model.py +1916 -0
- dendrotweaks/model_io.py +75 -0
- dendrotweaks/morphology/__init__.py +5 -0
- dendrotweaks/morphology/domains.py +100 -0
- dendrotweaks/morphology/io/__init__.py +5 -0
- dendrotweaks/morphology/io/factories.py +212 -0
- dendrotweaks/morphology/io/reader.py +66 -0
- dendrotweaks/morphology/io/validation.py +212 -0
- dendrotweaks/morphology/point_trees.py +681 -0
- dendrotweaks/morphology/reduce/__init__.py +16 -0
- dendrotweaks/morphology/reduce/reduce.py +155 -0
- dendrotweaks/morphology/reduce/reduced_cylinder.py +129 -0
- dendrotweaks/morphology/sec_trees.py +1112 -0
- dendrotweaks/morphology/seg_trees.py +157 -0
- dendrotweaks/morphology/trees.py +567 -0
- dendrotweaks/path_manager.py +261 -0
- dendrotweaks/simulators.py +235 -0
- dendrotweaks/stimuli/__init__.py +3 -0
- dendrotweaks/stimuli/iclamps.py +73 -0
- dendrotweaks/stimuli/populations.py +265 -0
- dendrotweaks/stimuli/synapses.py +203 -0
- dendrotweaks/utils.py +239 -0
- dendrotweaks-0.3.1.dist-info/METADATA +70 -0
- dendrotweaks-0.3.1.dist-info/RECORD +56 -0
- dendrotweaks-0.3.1.dist-info/WHEEL +5 -0
- dendrotweaks-0.3.1.dist-info/licenses/LICENSE +674 -0
- dendrotweaks-0.3.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,681 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import pandas as pd
|
3
|
+
import matplotlib.pyplot as plt
|
4
|
+
|
5
|
+
from scipy.spatial.transform import Rotation
|
6
|
+
|
7
|
+
from dendrotweaks.utils import timeit
|
8
|
+
|
9
|
+
from dendrotweaks.morphology.trees import Node, Tree
|
10
|
+
|
11
|
+
from dendrotweaks.utils import timeit
|
12
|
+
from dendrotweaks.utils import get_swc_idx, get_domain_name
|
13
|
+
from dendrotweaks.utils import get_domain_color
|
14
|
+
|
15
|
+
from contextlib import contextmanager
|
16
|
+
import random
|
17
|
+
|
18
|
+
|
19
|
+
class Point(Node):
    """
    A class representing a single point in a morphological reconstruction.

    Parameters
    ----------
    idx : str
        The unique identifier of the node.
    type_idx : int
        The type of the node according to the SWC specification
        (e.g. soma-1, axon-2, dendrite-3).
    x : float
        The x-coordinate of the node.
    y : float
        The y-coordinate of the node.
    z : float
        The z-coordinate of the node.
    r : float
        The radius of the node.
    parent_idx : str
        The identifier of the parent node.

    Attributes
    ----------
    idx : str
        The unique identifier of the node.
    type_idx : int
        The type of the node according to the SWC specification
        (e.g. soma-1, axon-2, dendrite-3).
    x : float
        The x-coordinate of the node.
    y : float
        The y-coordinate of the node.
    z : float
        The z-coordinate of the node.
    r : float
        The radius of the node.
    parent_idx : str
        The identifier of the parent node.
    """

    def __init__(self, idx: str, type_idx: int,
                 x: float, y: float, z: float, r: float,
                 parent_idx: str) -> None:
        super().__init__(idx, parent_idx)
        self.type_idx = type_idx
        self.x = x
        self.y = y
        self.z = z
        self.r = r
        # Section this point belongs to. It is None here; presumably it is
        # assigned by section-building code elsewhere (confirm).
        self._section = None

    @property
    def domain(self):
        """
        The morphological or functional domain of the node.

        The domain is derived from ``type_idx`` via ``get_domain_name``.
        """
        return get_domain_name(self.type_idx)

    @domain.setter
    def domain(self, value):
        """
        Set the domain by name or clear it.

        A string is converted to its SWC type index with ``get_swc_idx``.
        None clears ``type_idx``.

        Raises
        ------
        TypeError
            If `value` is neither a string nor None.
        """
        if value is None:
            self.type_idx = None
        elif isinstance(value, str):
            self.type_idx = get_swc_idx(value)
        else:
            # Fail loudly instead of silently ignoring the assignment.
            raise TypeError(
                f'Domain must be a string or None, got {type(value).__name__}.'
            )

    @property
    def distance_to_parent(self):
        """
        The Euclidean distance from this node to its parent.

        The value is 0 if the node has no parent.
        """
        if self.parent:
            return np.sqrt((self.x - self.parent.x)**2 +
                           (self.y - self.parent.y)**2 +
                           (self.z - self.parent.z)**2)
        return 0

    def path_distance(self, within_domain=False, ancestor=None):
        """
        Compute the path length from this node towards the root.

        Edge lengths are accumulated while walking up the parent chain. The
        walk ends at the root, or earlier as described below.

        Parameters
        ----------
        within_domain : bool, optional
            If True, stop as soon as the parent's domain differs from the
            current node's domain. The edge into that parent is not counted.
            By default False.
        ancestor : Node, optional
            If given, stop when the parent of the current node is `ancestor`.
            The edge into `ancestor` itself is not counted.
            NOTE(review): confirm that excluding this last edge is intended.
            If `ancestor` is not actually an ancestor, the walk continues to
            the root.

        Returns
        -------
        float
            The accumulated distance.
        """
        distance = 0
        node = self

        while node.parent:
            if ancestor and node.parent == ancestor:
                break  # Stop if we reach the specified ancestor

            if within_domain and node.parent.domain != node.domain:
                break  # Stop if domain changes

            distance += node.distance_to_parent
            node = node.parent

        return distance

    @property
    def df(self):
        """
        Return a one-row DataFrame representation of the node.
        """
        return pd.DataFrame({'idx': [self.idx],
                             'type_idx': [self.type_idx],
                             'x': [self.x],
                             'y': [self.y],
                             'z': [self.z],
                             'r': [self.r],
                             'parent_idx': [self.parent_idx]})

    def info(self):
        """
        Print information about the node.

        The output includes type, coordinates, radius, parent, children,
        siblings and section.
        """
        info = (
            f"Node {self.idx}:\n"
            f"  Type: {get_domain_name(self.type_idx)}\n"
            f"  Coordinates: ({self.x}, {self.y}, {self.z})\n"
            f"  Radius: {self.r}\n"
            f"  Parent: {self.parent_idx}\n"
            f"  Children: {[child.idx for child in self.children]}\n"
            f"  Siblings: {[sibling.idx for sibling in self.siblings]}\n"
            f"  Section: {self._section.idx if self._section else 'None'}"
        )
        print(info)

    def copy(self):
        """
        Create a detached copy of the node.

        Only idx, type_idx, x, y, z, r and parent_idx are copied. The section
        reference and any tree links are not.

        Returns
        -------
        Point
            A new node with the same attributes.
        """
        new_node = Point(self.idx, self.type_idx, self.x,
                         self.y, self.z, self.r, self.parent_idx)
        return new_node

    def overlaps_with(self, other, **kwargs) -> bool:
        """
        Check if the coordinates of this node overlap with another node.

        Parameters
        ----------
        other : Point
            The other node to compare with.
        **kwargs
            Additional keyword arguments passed to ``np.allclose``
            (e.g. ``atol``, ``rtol``).

        Returns
        -------
        bool
            True if the (x, y, z) coordinates are close, False otherwise.
            Radii are not compared.
        """
        return np.allclose(
            [self.x, self.y, self.z],
            [other.x, other.y, other.z],
            **kwargs
        )
|
183
|
+
|
184
|
+
|
185
|
+
|
186
|
+
class PointTree(Tree):
    """
    A class representing a tree graph of points in a morphological reconstruction.

    Parameters
    ----------
    nodes : list[Point]
        A list of points in the tree.
    """

    def __init__(self, nodes: list[Point]) -> None:
        super().__init__(nodes)
        self._sections = []
        # Set to True by extend_sections() and back to False by
        # remove_overlaps(); prevents extending the tree twice.
        self._is_extended = False

    def __repr__(self):
        return f"PointTree(root={self.root!r}, num_nodes={len(self._nodes)})"

    # PROPERTIES

    @property
    def points(self):
        """
        The list of points in the tree. An alias for self._nodes.
        """
        return self._nodes

    @property
    def soma_points(self):
        """
        The list of points representing the soma (type 1).
        """
        return [pt for pt in self.points if pt.type_idx == 1]

    @property
    def soma_center(self):
        """
        The center of the soma as the average of the coordinates of the soma points.

        If the tree has no soma points, the NumPy mean of an empty list yields
        NaN. The geometric methods of this class raise ValueError in that case
        instead.
        """
        return np.mean([[pt.x, pt.y, pt.z]
                        for pt in self.soma_points], axis=0)

    @property
    def apical_center(self):
        """
        The center of the apical dendrite as the average of the coordinates
        of the apical points (type 4).

        The value is None if there are no apical points.
        """
        apical_points = [pt for pt in self.points if pt.type_idx == 4]
        if not apical_points:
            return None
        return np.mean([[pt.x, pt.y, pt.z]
                        for pt in apical_points], axis=0)

    @property
    def soma_notation(self):
        """
        The type of soma notation used in the tree.

        - '1PS': One-point soma
        - '2PS': Two-point soma
        - '3PS': Three-point soma
        - 'contour': Soma represented as a contour. This is also returned
          when there are no soma points.
        """
        # Compute the soma points once instead of once per comparison.
        n_soma_points = len(self.soma_points)
        return {1: '1PS', 2: '2PS', 3: '3PS'}.get(n_soma_points, 'contour')

    @property
    def df(self):
        """
        A DataFrame representation of the tree, one row per point.
        """
        data = {
            'idx': [node.idx for node in self._nodes],
            'type_idx': [node.type_idx for node in self._nodes],
            'x': [node.x for node in self._nodes],
            'y': [node.y for node in self._nodes],
            'z': [node.z for node in self._nodes],
            'r': [node.r for node in self._nodes],
            'parent_idx': [node.parent_idx for node in self._nodes],
        }
        return pd.DataFrame(data)

    def _get_soma_center_or_raise(self):
        """
        Return the soma center, raising ValueError if the tree has no soma points.

        Without this guard, the NaN center would silently propagate into every
        coordinate during geometric transforms.
        """
        soma_points = self.soma_points
        if not soma_points:
            raise ValueError(
                'The tree has no soma points (type_idx == 1), '
                'so the soma center is undefined.'
            )
        return np.mean([[pt.x, pt.y, pt.z] for pt in soma_points], axis=0)

    # SORTING METHODS

    def _sort_children(self):
        """
        Sort the children of every node in the tree in place.

        Children are ordered ascending by a key with two parts:

        1. their SWC type index (``type_idx``);
        2. the number of bifurcations (nodes with more than one child) in
           their subtree.

        Among children of the same type, less-branched subtrees come first.

        Returns
        -------
        None
        """
        for node in self._nodes:
            node.children = sorted(
                node.children,
                key=lambda child: (
                    child.type_idx,
                    sum(1 for n in child.subtree if len(n.children) > 1),
                ),
            )

    # STANDARDIZATION METHODS

    def change_soma_notation(self, notation):
        """
        Change the soma notation of the tree.

        Only conversion from one-point ('1PS') to three-point ('3PS')
        notation is implemented. For that conversion, two soma points are
        added as children of the single soma point, at ``x - r`` and
        ``x + r``, each with radius ``r``.

        Parameters
        ----------
        notation : str
            The target notation: '1PS', '2PS', '3PS' or 'contour'.

        Raises
        ------
        ValueError
            If `notation` is not a known notation.
        NotImplementedError
            If the requested conversion is not supported.
        """
        valid_notations = ('1PS', '2PS', '3PS', 'contour')
        if notation not in valid_notations:
            raise ValueError(
                f'Unknown soma notation {notation!r}; '
                f'expected one of {valid_notations}.'
            )

        current = self.soma_notation
        if current == notation:
            return

        if current == '1PS' and notation == '3PS':
            pt = self.soma_points[0]

            # NOTE(review): indices 2 and 3 are hard-coded. Confirm that they
            # cannot collide with existing indices, or that the tree is
            # re-indexed afterwards.
            pt_left = Point(
                idx=2,
                type_idx=1,
                x=pt.x - pt.r,
                y=pt.y,
                z=pt.z,
                r=pt.r,
                parent_idx=pt.idx)

            pt_right = Point(
                idx=3,
                type_idx=1,
                x=pt.x + pt.r,
                y=pt.y,
                z=pt.z,
                r=pt.r,
                parent_idx=pt.idx)

            self.add_subtree(pt_right, pt)
            self.add_subtree(pt_left, pt)

            print('Converted soma to 3PS notation.')
            return

        raise NotImplementedError(
            f'Conversion from {current} to {notation} notation is not implemented yet.'
        )

    # GEOMETRICAL METHODS

    def round_coordinates(self, decimals=8):
        """
        Round the coordinates and radii of all nodes to the specified number of decimals.

        Parameters
        ----------
        decimals : int, optional
            The number of decimals to round to, by default 8.
        """
        for pt in self.points:
            pt.x = round(pt.x, decimals)
            pt.y = round(pt.y, decimals)
            pt.z = round(pt.z, decimals)
            pt.r = round(pt.r, decimals)

    def shift_coordinates_to_soma_center(self, decimals=8):
        """
        Shift all coordinates so that the soma center is at the origin (0, 0, 0).

        Parameters
        ----------
        decimals : int, optional
            The number of decimals the shifted coordinates are rounded to,
            by default 8.

        Raises
        ------
        ValueError
            If the tree has no soma points.
        """
        soma_x, soma_y, soma_z = self._get_soma_center_or_raise()
        for pt in self.points:
            pt.x = round(pt.x - soma_x, decimals)
            pt.y = round(pt.y - soma_y, decimals)
            pt.z = round(pt.z - soma_z, decimals)

    @timeit
    def rotate(self, angle_deg, axis='Y'):
        """
        Rotate the point cloud around the specified axis at the soma center.

        Parameters
        ----------
        angle_deg : float
            The rotation angle in degrees.
        axis : str, optional
            The rotation axis ('X', 'Y', or 'Z', case-insensitive),
            by default 'Y'.

        Raises
        ------
        ValueError
            If `axis` is not 'X', 'Y' or 'Z', or if the tree has no soma points.
        """
        axis_key = axis.upper() if isinstance(axis, str) else axis
        if axis_key not in ('X', 'Y', 'Z'):
            raise ValueError("Axis must be 'X', 'Y', or 'Z'")

        angle = np.radians(angle_deg)
        cos_a, sin_a = np.cos(angle), np.sin(angle)
        rotation_matrix = {
            'X': np.array([[1, 0, 0],
                           [0, cos_a, -sin_a],
                           [0, sin_a, cos_a]]),
            'Y': np.array([[cos_a, 0, sin_a],
                           [0, 1, 0],
                           [-sin_a, 0, cos_a]]),
            'Z': np.array([[cos_a, -sin_a, 0],
                           [sin_a, cos_a, 0],
                           [0, 0, 1]]),
        }[axis_key]

        rotation_point = self._get_soma_center_or_raise()

        # Force float dtype: an in-place subtraction of the float center from
        # an integer coordinate array would raise a casting error.
        coords = np.array([[pt.x, pt.y, pt.z] for pt in self.points], dtype=float)

        # Translate to the origin, rotate (row vectors, hence the transpose),
        # and translate back.
        rotated_coords = (coords - rotation_point) @ rotation_matrix.T + rotation_point

        for pt, (x, y, z) in zip(self.points, rotated_coords):
            pt.x, pt.y, pt.z = x, y, z

    def align_apical_dendrite(self, axis='Y', facing='up'):
        """
        Align the apical dendrite with the specified axis.

        The whole tree is rotated about the soma center so that the vector
        from the soma center to the apical center points along `axis`. Nothing
        happens if the tree has no apical points (type 4).

        Parameters
        ----------
        axis : str, optional
            The axis to align the apical dendrite with ('X', 'Y', or 'Z',
            case-insensitive), by default 'Y'.
        facing : str, optional
            The direction the apical dendrite should face ('up' or 'down',
            case-insensitive), by default 'up'.

        Raises
        ------
        ValueError
            If `axis` or `facing` is invalid, if the tree has no soma points,
            or if the apical center coincides with the soma center.
        """
        apical_center = self.apical_center
        if apical_center is None:
            return

        target_vector = {
            'X': np.array([1.0, 0.0, 0.0]),
            'Y': np.array([0.0, 1.0, 0.0]),
            'Z': np.array([0.0, 0.0, 1.0]),
        }.get(axis.upper(), None)

        if target_vector is None:
            raise ValueError("Axis must be 'X', 'Y', or 'Z'")

        facing = facing.lower()
        if facing not in ('up', 'down'):
            raise ValueError("Facing must be 'up' or 'down'")
        if facing == 'down':
            target_vector = -target_vector

        soma_center = self._get_soma_center_or_raise()

        current_vector = apical_center - soma_center
        current_norm = np.linalg.norm(current_vector)
        if current_norm == 0:
            raise ValueError(
                'The apical center coincides with the soma center; '
                'the apical direction is undefined.'
            )
        current_unit = current_vector / current_norm

        if np.allclose(current_unit, target_vector):
            print('Apical dendrite is already aligned.')
            return

        # Clip guards arccos against dot products marginally outside [-1, 1].
        rotation_angle = np.arccos(np.clip(np.dot(current_unit, target_vector), -1.0, 1.0))
        rotation_axis = np.cross(current_unit, target_vector)
        axis_norm = np.linalg.norm(rotation_axis)

        if np.isclose(axis_norm, 0.0):
            # The vectors are anti-parallel, so the cross product vanishes.
            # Any axis perpendicular to the target works for the 180-degree
            # rotation.
            helper = (np.array([1.0, 0.0, 0.0]) if abs(target_vector[0]) < 0.9
                      else np.array([0.0, 1.0, 0.0]))
            rotation_axis = np.cross(target_vector, helper)
            axis_norm = np.linalg.norm(rotation_axis)

        rotation_matrix = Rotation.from_rotvec(
            rotation_angle * rotation_axis / axis_norm
        ).as_matrix()

        coords = np.array([[pt.x, pt.y, pt.z] for pt in self.points], dtype=float)
        rotated_coords = (coords - soma_center) @ rotation_matrix.T + soma_center
        for pt, (x, y, z) in zip(self.points, rotated_coords):
            pt.x, pt.y, pt.z = x, y, z

    # TOPOLOGY METHODS

    def remove_overlaps(self):
        """
        Remove every node whose coordinates overlap with its parent's.

        Overlap is determined by ``Point.overlaps_with``. The tree is then
        marked as not extended.
        """
        nodes_before = len(self.points)

        # Collect first, then remove, so the traversal is not mutated mid-walk.
        overlapping_nodes = [
            pt for pt in self.traverse()
            if pt.parent is not None and pt.overlaps_with(pt.parent)
        ]
        for pt in overlapping_nodes:
            self.remove_node(pt)

        self._is_extended = False
        nodes_after = len(self.points)
        if nodes_before != nodes_after:
            print(f'Removed {nodes_before - nodes_after} overlapping nodes.')

    def extend_sections(self):
        """
        Extend each section by adding a node at its beginning that overlaps
        with the parent node, for geometrical continuity.

        For every bifurcation except the root, a copy of the bifurcation point
        is inserted before each child. The copy takes the child's domain.

        Raises
        ------
        ValueError
            If a child of a bifurcation already overlaps with it.
        """
        nodes_before = len(self.points)

        if self._is_extended:
            print('Tree is already extended.')
            return

        bifurcations_excluding_root = [
            b for b in self.bifurcations if b != self.root
        ]

        for pt in bifurcations_excluding_root:
            # Iterate over a snapshot: insert_node_before changes pt.children.
            children = pt.children[:]
            for child in children:
                if child.overlaps_with(pt):
                    raise ValueError(f'Child {child} already overlaps with parent {pt}.')
                new_node = pt.copy()
                new_node.domain = child.domain
                self.insert_node_before(new_node, child)

        self._is_extended = True
        nodes_after = len(self.points)
        print(f'Extended {nodes_after - nodes_before} nodes.')

    # I/O METHODS

    def to_swc(self, path_to_file):
        """
        Save the tree to an SWC file.

        Rows are written space-separated with no header, in the column order
        idx, type_idx, x, y, z, r, parent_idx. Indices are shifted from 0- to
        1-based. Negative parent indices (presumably -1 for the root) are left
        unchanged.

        Parameters
        ----------
        path_to_file : str or path-like
            Destination file path.
        """
        # `remove_overlaps` here is the module-level context manager, not the
        # method. It strips overlapping nodes for the duration of the write.
        with remove_overlaps(self):
            df = self.df.astype({
                'idx': int,
                'type_idx': int,
                'x': float,
                'y': float,
                'z': float,
                'r': float,
                'parent_idx': int
            })
            df['idx'] += 1
            df.loc[df['parent_idx'] >= 0, 'parent_idx'] += 1
            df.to_csv(path_to_file, sep=' ', index=False, header=False)

    # PLOTTING METHODS

    def plot(self, ax=None,
             show_nodes=True, show_edges=True, show_domains=True,
             annotate=False, projection='XY',
             highlight_nodes=None, focus_nodes=None):
        """
        Plot a 2D projection of the tree.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            The axes to plot on. By default a new 10x10 figure is created.
        show_nodes : bool, optional
            Whether to plot the nodes, by default True
        show_edges : bool, optional
            Whether to plot the edges, by default True
        show_domains : bool, optional
            Whether to color the nodes based on their domains, by default True
        annotate : bool, optional
            Whether to annotate the nodes with their indices, by default False.
            Annotations are drawn only when fewer than 50 points are plotted.
        projection : str, optional
            Two axis letters for the horizontal and vertical axes (e.g. 'XY',
            'XZ', 'YZ'; case-insensitive), by default 'XY'
        highlight_nodes : list, optional
            A list of nodes to highlight, by default None
        focus_nodes : list, optional
            A list of nodes to restrict the plot to, by default None (all nodes)

        Raises
        ------
        ValueError
            If `projection` is not two letters from 'X', 'Y', 'Z'.
        """
        projection = projection.upper()
        if len(projection) != 2 or not set(projection) <= {'X', 'Y', 'Z'}:
            raise ValueError(
                "Projection must be two of the letters 'X', 'Y', 'Z' (e.g. 'XY')."
            )

        if ax is None:
            _, ax = plt.subplots(figsize=(10, 10))

        # Convert focus/highlight to sets for O(1) membership tests.
        focus_nodes = set(focus_nodes) if focus_nodes else None
        highlight_nodes = set(highlight_nodes) if highlight_nodes else None

        points_to_plot = (self.points if focus_nodes is None
                          else [pt for pt in self.points if pt in focus_nodes])

        h_attr, v_attr = projection[0].lower(), projection[1].lower()
        xs = [getattr(pt, h_attr) for pt in points_to_plot]
        ys = [getattr(pt, v_attr) for pt in points_to_plot]

        if show_edges:
            point_set = set(points_to_plot)
            for pt1, pt2 in self.edges:
                if pt1 in point_set and pt2 in point_set:
                    ax.plot(
                        [getattr(pt1, h_attr), getattr(pt2, h_attr)],
                        [getattr(pt1, v_attr), getattr(pt2, v_attr)],
                        color='C1'
                    )

        # Compute colors once (previously recomputed inside a redundant loop).
        if show_domains:
            colors = [get_domain_color(pt.domain) for pt in points_to_plot]
        else:
            colors = 'C0'

        if show_nodes:
            ax.scatter(xs, ys, s=10, c=colors, marker='.', zorder=2)

        if annotate and len(points_to_plot) < 50:
            for pt, x, y in zip(points_to_plot, xs, ys):
                ax.annotate(f'{pt.idx}', (x, y), fontsize=8)

        if highlight_nodes:
            for pt, x, y in zip(points_to_plot, xs, ys):
                if pt in highlight_nodes:
                    ax.plot(x, y, 'o', color='C3', markersize=5)

        ax.set_xlabel(projection[0])
        ax.set_ylabel(projection[1])
        # Both axes are spatial, so keep a 1:1 aspect ratio.
        ax.set_aspect('equal')

    def plot_radii_distribution(self, ax=None, highlight=None,
                                domains=True, show_soma=False):
        """
        Plot the radius of each point against its path distance from the root.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            The axes to plot on. By default a new 8x3 figure is created.
        highlight : collection, optional
            Point indices (``pt.idx``) to draw in red above the others,
            by default None.
        domains : bool, optional
            Whether to color points by their domain. If False, all
            non-highlighted points are drawn in 'C0'. By default True.
        show_soma : bool, optional
            Whether to include soma points, by default False.
        """
        if ax is None:
            _, ax = plt.subplots(figsize=(8, 3))

        for pt in self.points:
            if not show_soma and pt.domain == 'soma':
                continue
            if highlight and pt.idx in highlight:
                color, zorder = 'red', 2
            else:
                color = get_domain_color(pt.domain) if domains else 'C0'
                zorder = 1
            ax.plot(
                pt.path_distance(),
                pt.r,
                marker='.',
                color=color,
                zorder=zorder
            )
        ax.set_xlabel('Distance from root')
        ax.set_ylabel('Radius')
|
657
|
+
|
658
|
+
|
659
|
+
@contextmanager
def remove_overlaps(point_tree):
    """
    Temporarily strip overlapping nodes from `point_tree`.

    On entry, the context manager calls ``point_tree.remove_overlaps()`` and
    then ``point_tree.sort()``. It is primarily used when saving the tree to
    an SWC file without overlaps.

    On exit, whether normal or via an exception, it checks the tree's
    ``_is_extended`` flag as recorded before entry. If the flag was set, it
    calls ``point_tree.extend_sections()`` and then ``point_tree.sort()`` to
    restore geometrical continuity. Overlaps removed from a tree that was not
    extended are not put back.

    Parameters
    ----------
    point_tree : PointTree
        The tree to modify in place.
    """
    was_extended_on_entry = point_tree._is_extended

    for strip_step in (point_tree.remove_overlaps, point_tree.sort):
        strip_step()

    try:
        yield
    finally:
        # Re-extend only trees that were extended before entering the block.
        if was_extended_on_entry:
            for restore_step in (point_tree.extend_sections, point_tree.sort):
                restore_step()
|
@@ -0,0 +1,16 @@
|
|
1
|
+
from dendrotweaks.morphology.reduce.reduce import map_segs_to_params
|
2
|
+
from dendrotweaks.morphology.reduce.reduce import map_segs_to_locs
|
3
|
+
from dendrotweaks.morphology.reduce.reduce import map_segs_to_reduced_segs
|
4
|
+
from dendrotweaks.morphology.reduce.reduce import map_reduced_segs_to_params
|
5
|
+
|
6
|
+
from dendrotweaks.morphology.reduce.reduce import set_avg_params_to_reduced_segs
|
7
|
+
from dendrotweaks.morphology.reduce.reduce import interpolate_missing_values
|
8
|
+
|
9
|
+
|
10
|
+
# Monkey-patch neuron_reduce: replace the module-level
# `_get_subtree_biophysical_properties` in `neuron_reduce.reducing_methods`
# with DendroTweaks' implementation from `reduced_cylinder`.
# NOTE(review): this presumably makes `reduce_subtree` (re-exported below as
# `get_unique_cable_properties`) use the patched helper, because the helper is
# looked up as a module global at call time. Confirm against neuron_reduce.
# The submodule is imported explicitly so the attribute assignment does not
# depend on `neuron_reduce/__init__.py` happening to import it.
import neuron_reduce
import neuron_reduce.reducing_methods
from dendrotweaks.morphology.reduce.reduced_cylinder import _get_subtree_biophysical_properties
neuron_reduce.reducing_methods._get_subtree_biophysical_properties = _get_subtree_biophysical_properties
from neuron_reduce.reducing_methods import reduce_subtree as get_unique_cable_properties
|
14
|
+
|
15
|
+
from dendrotweaks.morphology.reduce.reduced_cylinder import calculate_nsegs
|
16
|
+
from dendrotweaks.morphology.reduce.reduced_cylinder import apply_params_to_section
|