dendrotweaks 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. dendrotweaks/__init__.py +10 -0
  2. dendrotweaks/analysis/__init__.py +11 -0
  3. dendrotweaks/analysis/ephys_analysis.py +482 -0
  4. dendrotweaks/analysis/morphometric_analysis.py +106 -0
  5. dendrotweaks/membrane/__init__.py +6 -0
  6. dendrotweaks/membrane/default_mod/AMPA.mod +65 -0
  7. dendrotweaks/membrane/default_mod/AMPA_NMDA.mod +100 -0
  8. dendrotweaks/membrane/default_mod/CaDyn.mod +54 -0
  9. dendrotweaks/membrane/default_mod/GABAa.mod +65 -0
  10. dendrotweaks/membrane/default_mod/Leak.mod +27 -0
  11. dendrotweaks/membrane/default_mod/NMDA.mod +72 -0
  12. dendrotweaks/membrane/default_mod/vecstim.mod +76 -0
  13. dendrotweaks/membrane/default_templates/NEURON_template.py +354 -0
  14. dendrotweaks/membrane/default_templates/default.py +73 -0
  15. dendrotweaks/membrane/default_templates/standard_channel.mod +87 -0
  16. dendrotweaks/membrane/default_templates/template_jaxley.py +108 -0
  17. dendrotweaks/membrane/default_templates/template_jaxley_new.py +108 -0
  18. dendrotweaks/membrane/distributions.py +324 -0
  19. dendrotweaks/membrane/groups.py +103 -0
  20. dendrotweaks/membrane/io/__init__.py +11 -0
  21. dendrotweaks/membrane/io/ast.py +201 -0
  22. dendrotweaks/membrane/io/code_generators.py +312 -0
  23. dendrotweaks/membrane/io/converter.py +108 -0
  24. dendrotweaks/membrane/io/factories.py +144 -0
  25. dendrotweaks/membrane/io/grammar.py +417 -0
  26. dendrotweaks/membrane/io/loader.py +90 -0
  27. dendrotweaks/membrane/io/parser.py +499 -0
  28. dendrotweaks/membrane/io/reader.py +212 -0
  29. dendrotweaks/membrane/mechanisms.py +574 -0
  30. dendrotweaks/model.py +1916 -0
  31. dendrotweaks/model_io.py +75 -0
  32. dendrotweaks/morphology/__init__.py +5 -0
  33. dendrotweaks/morphology/domains.py +100 -0
  34. dendrotweaks/morphology/io/__init__.py +5 -0
  35. dendrotweaks/morphology/io/factories.py +212 -0
  36. dendrotweaks/morphology/io/reader.py +66 -0
  37. dendrotweaks/morphology/io/validation.py +212 -0
  38. dendrotweaks/morphology/point_trees.py +681 -0
  39. dendrotweaks/morphology/reduce/__init__.py +16 -0
  40. dendrotweaks/morphology/reduce/reduce.py +155 -0
  41. dendrotweaks/morphology/reduce/reduced_cylinder.py +129 -0
  42. dendrotweaks/morphology/sec_trees.py +1112 -0
  43. dendrotweaks/morphology/seg_trees.py +157 -0
  44. dendrotweaks/morphology/trees.py +567 -0
  45. dendrotweaks/path_manager.py +261 -0
  46. dendrotweaks/simulators.py +235 -0
  47. dendrotweaks/stimuli/__init__.py +3 -0
  48. dendrotweaks/stimuli/iclamps.py +73 -0
  49. dendrotweaks/stimuli/populations.py +265 -0
  50. dendrotweaks/stimuli/synapses.py +203 -0
  51. dendrotweaks/utils.py +239 -0
  52. dendrotweaks-0.3.1.dist-info/METADATA +70 -0
  53. dendrotweaks-0.3.1.dist-info/RECORD +56 -0
  54. dendrotweaks-0.3.1.dist-info/WHEEL +5 -0
  55. dendrotweaks-0.3.1.dist-info/licenses/LICENSE +674 -0
  56. dendrotweaks-0.3.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,681 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+
5
+ from scipy.spatial.transform import Rotation
6
+
7
+ from dendrotweaks.utils import timeit
8
+
9
+ from dendrotweaks.morphology.trees import Node, Tree
10
+
11
+ from dendrotweaks.utils import timeit
12
+ from dendrotweaks.utils import get_swc_idx, get_domain_name
13
+ from dendrotweaks.utils import get_domain_color
14
+
15
+ from contextlib import contextmanager
16
+ import random
17
+
18
+
19
class Point(Node):
    """
    A single sample point of a morphological reconstruction (one SWC row).

    Parameters
    ----------
    idx : str
        The unique identifier of the node.
    type_idx : int
        The SWC type code of the node (e.g. soma-1, axon-2, dendrite-3).
    x : float
        The x-coordinate of the node.
    y : float
        The y-coordinate of the node.
    z : float
        The z-coordinate of the node.
    r : float
        The radius of the node.
    parent_idx : str
        The identifier of the parent node.

    Attributes
    ----------
    idx : str
        The unique identifier of the node.
    type_idx : int
        The SWC type code of the node (e.g. soma-1, axon-2, dendrite-3).
    x : float
        The x-coordinate of the node.
    y : float
        The y-coordinate of the node.
    z : float
        The z-coordinate of the node.
    r : float
        The radius of the node.
    parent_idx : str
        The identifier of the parent node.
    """

    def __init__(self, idx: str, type_idx: int,
                 x: float, y: float, z: float, r: float,
                 parent_idx: str) -> None:
        super().__init__(idx, parent_idx)
        self.type_idx = type_idx
        self.x, self.y, self.z = x, y, z
        self.r = r
        # Back-reference to the owning section; starts unset.
        self._section = None

    @property
    def domain(self):
        """
        The domain name derived from ``type_idx`` via ``get_domain_name``.
        """
        return get_domain_name(self.type_idx)

    @domain.setter
    def domain(self, value):
        # ``None`` clears the type; a string is translated to its SWC index.
        # Any other value is ignored and leaves ``type_idx`` untouched.
        if value is None:
            self.type_idx = None
        elif isinstance(value, str):
            self.type_idx = get_swc_idx(value)

    @property
    def distance_to_parent(self):
        """
        The Euclidean distance to the parent node, or 0 if there is no parent.
        """
        parent = self.parent
        if not parent:
            return 0
        dx = self.x - parent.x
        dy = self.y - parent.y
        dz = self.z - parent.z
        return np.sqrt(dx**2 + dy**2 + dz**2)

    def path_distance(self, within_domain=False, ancestor=None):
        """
        Accumulate parent-to-child distances while walking up the tree.

        The walk starts at this node and stops at the root, or earlier when
        one of the optional stop conditions is met. The edge on which a stop
        condition triggers is not added to the total.

        Args:
            within_domain (bool): If True, stop as soon as the parent's
                domain differs from the current node's domain.
            ancestor (Node, optional): If provided, stop when the parent
                equals this node.

        Returns:
            float: The accumulated distance.
        """
        total = 0
        current = self

        while current.parent:
            upstream = current.parent

            # Reached the requested ancestor
            if ancestor and upstream == ancestor:
                break

            # Crossing into a different domain
            if within_domain and upstream.domain != current.domain:
                break

            total += current.distance_to_parent
            current = upstream

        return total

    @property
    def df(self):
        """
        A one-row DataFrame with the SWC-style fields of this node.
        """
        columns = ('idx', 'type_idx', 'x', 'y', 'z', 'r', 'parent_idx')
        return pd.DataFrame({name: [getattr(self, name)] for name in columns})

    def info(self):
        """
        Print a human-readable summary of the node.
        """
        section_label = self._section.idx if self._section else 'None'
        lines = [
            f"Node {self.idx}:",
            f" Type: {get_domain_name(self.type_idx)}",
            f" Coordinates: ({self.x}, {self.y}, {self.z})",
            f" Radius: {self.r}",
            f" Parent: {self.parent_idx}",
            f" Children: {[child.idx for child in self.children]}",
            f" Siblings: {[sibling.idx for sibling in self.siblings]}",
            f" Section: {section_label}",
        ]
        print("\n".join(lines))

    def copy(self):
        """
        Create a new Point with the same constructor arguments.

        Only ``idx``, ``type_idx``, the coordinates, the radius and
        ``parent_idx`` are passed on; the section back-reference is not.

        Returns:
            Point: A copy of the node with the same attributes.
        """
        return Point(self.idx, self.type_idx,
                     self.x, self.y, self.z,
                     self.r, self.parent_idx)

    def overlaps_with(self, other, **kwargs) -> bool:
        """
        Tell whether this node and another have (nearly) equal coordinates.

        Args:
            other (Point): The other node to compare with.
            kwargs: Additional keyword arguments passed to np.allclose.

        Returns:
            bool: True if the coordinates overlap, False otherwise.
        """
        own_xyz = [self.x, self.y, self.z]
        other_xyz = [other.x, other.y, other.z]
        return np.allclose(own_xyz, other_xyz, **kwargs)
183
+
184
+
185
+
186
class PointTree(Tree):
    """
    A class representing a tree graph of points in a morphological reconstruction.

    Parameters
    ----------
    nodes : list[Point]
        A list of points in the tree.
    """

    def __init__(self, nodes: list[Point]) -> None:
        super().__init__(nodes)
        self._sections = []
        # True once `extend_sections` has inserted the overlapping start nodes.
        self._is_extended = False

    def __repr__(self):
        return f"PointTree(root={self.root!r}, num_nodes={len(self._nodes)})"

    # PROPERTIES

    @property
    def points(self):
        """
        The list of points in the tree. An alias for self._nodes.
        """
        return self._nodes

    # @property
    # def is_sectioned(self):
    #     return len(self._sections) > 0

    @property
    def soma_points(self):
        """
        The list of points representing the soma (type 1).
        """
        return [pt for pt in self.points if pt.type_idx == 1]

    @property
    def soma_center(self):
        """
        The center of the soma as the average of the coordinates of the soma points.
        """
        return np.mean([[pt.x, pt.y, pt.z]
                        for pt in self.soma_points], axis=0)

    @property
    def apical_center(self):
        """
        The center of the apical dendrite as the average of the coordinates
        of the apical points (type 4), or None if there are no such points.
        """
        apical_points = [pt for pt in self.points
                         if pt.type_idx == 4]
        if not apical_points:
            return None
        return np.mean([[pt.x, pt.y, pt.z]
                        for pt in apical_points], axis=0)

    @property
    def soma_notation(self):
        """
        The type of soma notation used in the tree.
        - '1PS': One-point soma
        - '2PS': Two-point soma
        - '3PS': Three-point soma
        - 'contour': Soma represented as a contour (any other point count)
        """
        # Count once: `soma_points` scans every point on each access.
        n_soma = len(self.soma_points)
        return {1: '1PS', 2: '2PS', 3: '3PS'}.get(n_soma, 'contour')

    @property
    def df(self):
        """
        A DataFrame representation of the tree (one row per node).
        """
        data = {
            'idx': [node.idx for node in self._nodes],
            'type_idx': [node.type_idx for node in self._nodes],
            'x': [node.x for node in self._nodes],
            'y': [node.y for node in self._nodes],
            'z': [node.z for node in self._nodes],
            'r': [node.r for node in self._nodes],
            'parent_idx': [node.parent_idx for node in self._nodes],
        }
        return pd.DataFrame(data)

    # SORTING METHODS

    def _sort_children(self):
        """
        Sort the children of every node in the tree.

        Children are ordered first by their ``type_idx`` and then by the
        number of bifurcations (nodes with more than one child) in each
        child's subtree, both ascending. Among children of the same type,
        those with fewer bifurcations in their subtrees therefore come first.

        Returns
        -------
        None
        """
        for node in self._nodes:
            node.children = sorted(
                node.children,
                key=lambda x: (x.type_idx, sum(1 for n in x.subtree if len(n.children) > 1)),
                reverse=False
            )

    # STANDARDIZATION METHODS

    def change_soma_notation(self, notation):
        """
        Convert the soma to the requested notation.

        Only the conversion from a one-point soma ('1PS') to a three-point
        soma ('3PS') is implemented. Nothing happens if the soma already
        uses the requested notation.

        Parameters
        ----------
        notation : str
            The target soma notation. Only '3PS' is supported as a target.

        Raises
        ------
        NotImplementedError
            If the requested conversion is not implemented.
        """
        current = self.soma_notation
        if current == notation:
            return

        if notation != '3PS':
            raise NotImplementedError(
                f'Conversion from {current} to {notation} notation is not implemented yet.')

        if current == 'contour':
            # if soma has contour notation, take the average
            # distance of the nodes from the center of the soma
            # and use it as radius, create 3 new nodes
            raise NotImplementedError('Conversion from contour is not implemented yet.')

        if current != '1PS':
            # e.g. '2PS': previously fell through and reported a conversion
            # that never happened.
            raise NotImplementedError(
                f'Conversion from {current} to 3PS notation is not implemented yet.')

        pt = self.soma_points[0]

        # Two extra soma points, one radius to each side along x.
        # NOTE(review): indices 2 and 3 are hard-coded; presumably the tree
        # is re-indexed afterwards -- confirm against Tree.add_subtree.
        pt_left = Point(
            idx=2,
            type_idx=1,
            x=pt.x - pt.r,
            y=pt.y,
            z=pt.z,
            r=pt.r,
            parent_idx=pt.idx)

        pt_right = Point(
            idx=3,
            type_idx=1,
            x=pt.x + pt.r,
            y=pt.y,
            z=pt.z,
            r=pt.r,
            parent_idx=pt.idx)

        self.add_subtree(pt_right, pt)
        self.add_subtree(pt_left, pt)

        print('Converted soma to 3PS notation.')

    # GEOMETRICAL METHODS

    def round_coordinates(self, decimals=8):
        """
        Round the coordinates and radii of all nodes to the specified number of decimals.

        Parameters
        ----------
        decimals : int, optional
            The number of decimals to round to, by default 8
        """
        for pt in self.points:
            pt.x = round(pt.x, decimals)
            pt.y = round(pt.y, decimals)
            pt.z = round(pt.z, decimals)
            pt.r = round(pt.r, decimals)

    def shift_coordinates_to_soma_center(self):
        """
        Shift all coordinates so that the soma center is at the origin (0, 0, 0).

        The shifted coordinates are rounded to 8 decimals.
        """
        soma_x, soma_y, soma_z = self.soma_center
        for pt in self.points:
            pt.x = round(pt.x - soma_x, 8)
            pt.y = round(pt.y - soma_y, 8)
            pt.z = round(pt.z - soma_z, 8)

    @timeit
    def rotate(self, angle_deg, axis='Y'):
        """Rotate the point cloud around the specified axis at the soma center using numpy.

        Parameters
        ----------
        angle_deg : float
            The rotation angle in degrees.
        axis : str, optional
            The rotation axis ('X', 'Y', or 'Z', case-insensitive), by default 'Y'.

        Raises
        ------
        ValueError
            If the axis is not one of 'X', 'Y', or 'Z'.
        """

        # Get the rotation center point
        rotation_point = self.soma_center

        # Define rotation matrix based on the specified axis
        angle = np.radians(angle_deg)
        cos_a, sin_a = np.cos(angle), np.sin(angle)
        axis = str(axis).upper()
        if axis == 'X':
            rotation_matrix = np.array([
                [1, 0, 0],
                [0, cos_a, -sin_a],
                [0, sin_a, cos_a]
            ])
        elif axis == 'Y':
            rotation_matrix = np.array([
                [cos_a, 0, sin_a],
                [0, 1, 0],
                [-sin_a, 0, cos_a]
            ])
        elif axis == 'Z':
            rotation_matrix = np.array([
                [cos_a, -sin_a, 0],
                [sin_a, cos_a, 0],
                [0, 0, 1]
            ])
        else:
            raise ValueError("Axis must be 'X', 'Y', or 'Z'")

        # Subtract rotation point to translate the cloud to the origin.
        # Force a float array: with integer coordinates the in-place
        # subtraction of a float center would fail to cast.
        coords = np.array([[pt.x, pt.y, pt.z] for pt in self.points], dtype=float)
        coords -= rotation_point

        # Apply rotation
        rotated_coords = np.dot(coords, rotation_matrix.T)

        # Translate back to the original position
        rotated_coords += rotation_point

        # Update the coordinates of the points
        for pt, (x, y, z) in zip(self.points, rotated_coords):
            pt.x, pt.y, pt.z = x, y, z

    def align_apical_dendrite(self, axis='Y', facing='up'):
        """
        Align the apical dendrite with the specified axis.

        The tree is rotated about the soma center so that the vector from
        the soma center to the apical center points along the target
        direction. Nothing happens if the tree has no apical points, or if
        the apical center coincides with the soma center (no direction).

        Parameters
        ----------
        axis : str, optional
            The axis to align the apical dendrite with ('X', 'Y', or 'Z'), by default 'Y'.
        facing : str, optional
            The direction the apical dendrite should face ('up' or 'down'), by default 'up'.
            Any value other than 'down' is treated as 'up'.

        Raises
        ------
        ValueError
            If the axis is not one of 'X', 'Y', or 'Z'.
        """
        soma_center = self.soma_center
        apical_center = self.apical_center

        if apical_center is None:
            return

        # Define the target vector based on the axis and facing
        target_vector = {
            'X': np.array([1.0, 0.0, 0.0]),
            'Y': np.array([0.0, 1.0, 0.0]),
            'Z': np.array([0.0, 0.0, 1.0])
        }.get(str(axis).upper(), None)

        if target_vector is None:
            raise ValueError("Axis must be 'X', 'Y', or 'Z'")

        if facing == 'down':
            target_vector = -target_vector

        # Calculate the current (unit) vector
        current_vector = apical_center - soma_center
        current_norm = np.linalg.norm(current_vector)
        if np.isclose(current_norm, 0):
            # No direction to align: normalising would produce NaNs that
            # would then be written to every point.
            print('Apical dendrite direction is undefined; skipping alignment.')
            return
        current_unit = current_vector / current_norm

        # Check if the apical dendrite is already aligned
        if np.allclose(current_unit, target_vector):
            print('Apical dendrite is already aligned.')
            return

        # Calculate the rotation vector and angle.
        # Clip guards arccos against round-off just outside [-1, 1].
        cos_angle = np.clip(np.dot(current_unit, target_vector), -1.0, 1.0)
        rotation_angle = np.arccos(cos_angle)
        rotation_vector = np.cross(current_unit, target_vector)
        rotation_norm = np.linalg.norm(rotation_vector)

        if np.isclose(rotation_norm, 0):
            if cos_angle > 0:
                print('Apical dendrite is already aligned.')
                return
            # Anti-parallel: the cross product vanishes, so rotate by 180
            # degrees about any axis perpendicular to the target.
            helper = (np.array([1.0, 0.0, 0.0])
                      if abs(target_vector[0]) < 0.9
                      else np.array([0.0, 1.0, 0.0]))
            rotation_vector = np.cross(target_vector, helper)
            rotation_norm = np.linalg.norm(rotation_vector)
            rotation_angle = np.pi

        # Create the rotation matrix
        rotation_matrix = Rotation.from_rotvec(
            rotation_angle * rotation_vector / rotation_norm).as_matrix()

        # Apply the rotation to all points about the soma center
        coords = np.array([[pt.x, pt.y, pt.z] for pt in self.points], dtype=float)
        rotated_coords = np.dot(coords - soma_center, rotation_matrix.T) + soma_center
        for pt, (x, y, z) in zip(self.points, rotated_coords):
            pt.x, pt.y, pt.z = x, y, z

    # I/O METHODS

    def remove_overlaps(self):
        """
        Remove overlapping nodes from the tree.

        A node is removed when its coordinates overlap with those of its
        parent (see `Point.overlaps_with`). The tree is marked as not
        extended afterwards.
        """
        nodes_before = len(self.points)

        # Collect first, then remove, so the traversal is not disturbed.
        overlapping_nodes = [
            pt for pt in self.traverse()
            if pt.parent is not None and pt.overlaps_with(pt.parent)
        ]
        for pt in overlapping_nodes:
            self.remove_node(pt)

        self._is_extended = False
        nodes_after = len(self.points)
        if nodes_before != nodes_after:
            print(f'Removed {nodes_before - nodes_after} overlapping nodes.')

    def extend_sections(self):
        """
        Extend each section by adding a node in the beginning
        overlapping with the parent node for geometrical continuity.

        For every bifurcation except the root, a copy of the bifurcation
        point is inserted before each of its children and given the child's
        domain.

        Raises
        ------
        ValueError
            If a child already overlaps with its bifurcation parent.
        """
        if self._is_extended:
            print('Tree is already extended.')
            return

        nodes_before = len(self.points)

        bifurcations_excluding_root = [
            b for b in self.bifurcations if b != self.root
        ]

        for pt in bifurcations_excluding_root:
            # Iterate over a snapshot: inserting nodes changes pt.children.
            children = pt.children[:]
            for child in children:
                if child.overlaps_with(pt):
                    raise ValueError(f'Child {child} already overlaps with parent {pt}.')
                new_node = pt.copy()
                new_node.domain = child.domain
                self.insert_node_before(new_node, child)

        self._is_extended = True
        nodes_after = len(self.points)
        print(f'Extended {nodes_after - nodes_before} nodes.')

    def to_swc(self, path_to_file):
        """
        Save the tree to an SWC file.

        Overlapping nodes are removed for the duration of the export.
        Indices are shifted by +1 on output; non-negative parent indices
        are shifted likewise, negative ones are written unchanged.

        Parameters
        ----------
        path_to_file : str
            The path of the file to write (space-separated, no header).
        """
        # `remove_overlaps` here is the module-level context manager,
        # not the method of the same name.
        with remove_overlaps(self):
            df = self.df.astype({
                'idx': int,
                'type_idx': int,
                'x': float,
                'y': float,
                'z': float,
                'r': float,
                'parent_idx': int
            })
            df['idx'] += 1
            df.loc[df['parent_idx'] >= 0, 'parent_idx'] += 1
            df.to_csv(path_to_file, sep=' ', index=False, header=False)

    # PLOTTING METHODS

    def plot(self, ax=None,
             show_nodes=True, show_edges=True, show_domains=True,
             annotate=False, projection='XY',
             highlight_nodes=None, focus_nodes=None):
        """
        Plot a 2D projection of the tree.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            The axes to plot on, by default None
        show_nodes : bool, optional
            Whether to plot the nodes, by default True
        show_edges : bool, optional
            Whether to plot the edges, by default True
        show_domains : bool, optional
            Whether to color the nodes based on their domains, by default True
        annotate : bool, optional
            Whether to annotate the nodes with their indices, by default False.
            Annotations are only drawn when fewer than 50 points are plotted.
        projection : str, optional
            The projection plane ('XY', 'XZ', or 'YZ'), by default 'XY'
        highlight_nodes : list, optional
            A list of nodes to highlight, by default None
        focus_nodes : list, optional
            A list of nodes to focus on, by default None
        """

        if ax is None:
            _, ax = plt.subplots(figsize=(10, 10))

        # Convert focus/highlight to sets for faster lookup
        focus_nodes = set(focus_nodes) if focus_nodes else None
        highlight_nodes = set(highlight_nodes) if highlight_nodes else None

        # Determine which points to consider
        if focus_nodes is None:
            points_to_plot = self.points
        else:
            points_to_plot = [pt for pt in self.points if pt in focus_nodes]

        # Extract coordinates for the two projected axes
        h_attr, v_attr = projection[0].lower(), projection[1].lower()
        xs = [getattr(pt, h_attr) for pt in points_to_plot]
        ys = [getattr(pt, v_attr) for pt in points_to_plot]

        # Draw only edges whose two end points are both plotted
        if show_edges:
            point_set = set(points_to_plot)
            for pt1, pt2 in self.edges:
                if pt1 in point_set and pt2 in point_set:
                    ax.plot(
                        [getattr(pt1, h_attr), getattr(pt2, h_attr)],
                        [getattr(pt1, v_attr), getattr(pt2, v_attr)],
                        color='C1'
                    )

        # Assign colors based on domains (built once; also defined when
        # there are no points to plot)
        if show_domains:
            colors = [get_domain_color(pt.domain) for pt in points_to_plot]
        else:
            colors = 'C0'

        # Plot nodes
        if show_nodes:
            ax.scatter(xs, ys, s=10, c=colors, marker='.', zorder=2)

        # Annotate nodes if few enough
        if annotate and len(points_to_plot) < 50:
            for pt, x, y in zip(points_to_plot, xs, ys):
                ax.annotate(f'{pt.idx}', (x, y), fontsize=8)

        # Highlight nodes
        if highlight_nodes:
            for pt, x, y in zip(points_to_plot, xs, ys):
                if pt in highlight_nodes:
                    ax.plot(x, y, 'o', color='C3', markersize=5)

        # Set labels and aspect ratio
        ax.set_xlabel(projection[0])
        ax.set_ylabel(projection[1])
        if projection in {"XY", "XZ", "YZ"}:
            ax.set_aspect('equal')

    def plot_radii_distribution(self, ax=None, highlight=None,
                                domains=True, show_soma=False):
        """
        Plot the radius of each point against its path distance from the root.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            The axes to plot on, by default None (a new figure is created).
        highlight : collection, optional
            Indices (``pt.idx``) of points to draw in red on top of the
            others, by default None
        domains : bool, optional
            Currently unused; points are always colored by domain.
        show_soma : bool, optional
            Whether to include points of the 'soma' domain, by default False
        """
        if ax is None:
            _, ax = plt.subplots(figsize=(8, 3))

        for pt in self.points:
            if not show_soma and pt.domain == 'soma':
                continue
            if highlight and pt.idx in highlight:
                color, zorder = 'red', 2
            else:
                color, zorder = get_domain_color(pt.domain), 1
            ax.plot(
                pt.path_distance(),
                pt.r,
                marker='.',
                color=color,
                zorder=zorder
            )
        ax.set_xlabel('Distance from root')
        ax.set_ylabel('Radius')
657
+
658
+
659
@contextmanager
def remove_overlaps(point_tree):
    """
    Temporarily strip overlapping nodes from ``point_tree``.

    On entry the overlaps are removed and the tree is sorted. On exit,
    including when the body raises, the tree is extended and sorted again,
    but only if it was marked as extended on entry. Primarily used for
    saving the tree to an SWC file without overlaps.
    """
    # Remember the state on entry: removing overlaps resets the flag.
    restore_extension = point_tree._is_extended

    point_tree.remove_overlaps()
    point_tree.sort()

    try:
        # Hand control to the body of the ``with`` statement
        yield
    finally:
        # Put the overlapping start nodes back if they were there before
        if restore_extension:
            point_tree.extend_sections()
            point_tree.sort()
@@ -0,0 +1,16 @@
1
+ from dendrotweaks.morphology.reduce.reduce import map_segs_to_params
2
+ from dendrotweaks.morphology.reduce.reduce import map_segs_to_locs
3
+ from dendrotweaks.morphology.reduce.reduce import map_segs_to_reduced_segs
4
+ from dendrotweaks.morphology.reduce.reduce import map_reduced_segs_to_params
5
+
6
+ from dendrotweaks.morphology.reduce.reduce import set_avg_params_to_reduced_segs
7
+ from dendrotweaks.morphology.reduce.reduce import interpolate_missing_values
8
+
9
+
10
+ import neuron_reduce
11
+ from dendrotweaks.morphology.reduce.reduced_cylinder import _get_subtree_biophysical_properties
12
+ neuron_reduce.reducing_methods._get_subtree_biophysical_properties = _get_subtree_biophysical_properties
13
+ from neuron_reduce.reducing_methods import reduce_subtree as get_unique_cable_properties
14
+
15
+ from dendrotweaks.morphology.reduce.reduced_cylinder import calculate_nsegs
16
+ from dendrotweaks.morphology.reduce.reduced_cylinder import apply_params_to_section