dendrotweaks 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. dendrotweaks/__init__.py +10 -0
  2. dendrotweaks/analysis/__init__.py +11 -0
  3. dendrotweaks/analysis/ephys_analysis.py +482 -0
  4. dendrotweaks/analysis/morphometric_analysis.py +106 -0
  5. dendrotweaks/membrane/__init__.py +6 -0
  6. dendrotweaks/membrane/default_mod/AMPA.mod +65 -0
  7. dendrotweaks/membrane/default_mod/AMPA_NMDA.mod +100 -0
  8. dendrotweaks/membrane/default_mod/CaDyn.mod +54 -0
  9. dendrotweaks/membrane/default_mod/GABAa.mod +65 -0
  10. dendrotweaks/membrane/default_mod/Leak.mod +27 -0
  11. dendrotweaks/membrane/default_mod/NMDA.mod +72 -0
  12. dendrotweaks/membrane/default_mod/vecstim.mod +76 -0
  13. dendrotweaks/membrane/default_templates/NEURON_template.py +354 -0
  14. dendrotweaks/membrane/default_templates/default.py +73 -0
  15. dendrotweaks/membrane/default_templates/standard_channel.mod +87 -0
  16. dendrotweaks/membrane/default_templates/template_jaxley.py +108 -0
  17. dendrotweaks/membrane/default_templates/template_jaxley_new.py +108 -0
  18. dendrotweaks/membrane/distributions.py +324 -0
  19. dendrotweaks/membrane/groups.py +103 -0
  20. dendrotweaks/membrane/io/__init__.py +11 -0
  21. dendrotweaks/membrane/io/ast.py +201 -0
  22. dendrotweaks/membrane/io/code_generators.py +312 -0
  23. dendrotweaks/membrane/io/converter.py +108 -0
  24. dendrotweaks/membrane/io/factories.py +144 -0
  25. dendrotweaks/membrane/io/grammar.py +417 -0
  26. dendrotweaks/membrane/io/loader.py +90 -0
  27. dendrotweaks/membrane/io/parser.py +499 -0
  28. dendrotweaks/membrane/io/reader.py +212 -0
  29. dendrotweaks/membrane/mechanisms.py +574 -0
  30. dendrotweaks/model.py +1916 -0
  31. dendrotweaks/model_io.py +75 -0
  32. dendrotweaks/morphology/__init__.py +5 -0
  33. dendrotweaks/morphology/domains.py +100 -0
  34. dendrotweaks/morphology/io/__init__.py +5 -0
  35. dendrotweaks/morphology/io/factories.py +212 -0
  36. dendrotweaks/morphology/io/reader.py +66 -0
  37. dendrotweaks/morphology/io/validation.py +212 -0
  38. dendrotweaks/morphology/point_trees.py +681 -0
  39. dendrotweaks/morphology/reduce/__init__.py +16 -0
  40. dendrotweaks/morphology/reduce/reduce.py +155 -0
  41. dendrotweaks/morphology/reduce/reduced_cylinder.py +129 -0
  42. dendrotweaks/morphology/sec_trees.py +1112 -0
  43. dendrotweaks/morphology/seg_trees.py +157 -0
  44. dendrotweaks/morphology/trees.py +567 -0
  45. dendrotweaks/path_manager.py +261 -0
  46. dendrotweaks/simulators.py +235 -0
  47. dendrotweaks/stimuli/__init__.py +3 -0
  48. dendrotweaks/stimuli/iclamps.py +73 -0
  49. dendrotweaks/stimuli/populations.py +265 -0
  50. dendrotweaks/stimuli/synapses.py +203 -0
  51. dendrotweaks/utils.py +239 -0
  52. dendrotweaks-0.3.1.dist-info/METADATA +70 -0
  53. dendrotweaks-0.3.1.dist-info/RECORD +56 -0
  54. dendrotweaks-0.3.1.dist-info/WHEEL +5 -0
  55. dendrotweaks-0.3.1.dist-info/licenses/LICENSE +674 -0
  56. dendrotweaks-0.3.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,75 @@
1
+ import jinja2
2
+ import os
3
+ from collections import defaultdict
4
+
5
# Mapping from DendroTweaks domain names to their NEURON-side names,
# looked up by ``get_neuron_domain``.
DOMAINS_TO_NEURON = dict(
    soma='soma',
    perisomatic='dend_11',
    axon='axon',
    apic='apic',
    dend='dend',
    basal='dend_31',
    trunk='dend_41',
    tuft='dend_42',
    oblique='dend_43',
    custom='dend_5',
    reduced='dend_8',
    undefined='dend_0',
)
19
+
20
def get_neuron_domain(domain_name):
    """
    Translate a DendroTweaks domain name into its NEURON counterpart.

    Only the part of ``domain_name`` before the first underscore is looked
    up in ``DOMAINS_TO_NEURON``. For the ``reduced`` and ``custom`` domains
    a purely numeric suffix is appended to the mapped name
    (e.g. ``'custom_3'`` -> ``'dend_53'``). Unknown names fall back to
    ``'dend_0'``.

    Parameters
    ----------
    domain_name : str
        The domain name, optionally followed by ``_<suffix>``.

    Returns
    -------
    str
        The corresponding NEURON domain name.
    """
    prefix, _, suffix = domain_name.partition('_')
    is_indexed = prefix in ('reduced', 'custom') and suffix.isdigit()
    if not is_indexed:
        return DOMAINS_TO_NEURON.get(prefix, 'dend_0')
    return DOMAINS_TO_NEURON[prefix] + suffix
25
+
26
def render_template(path_to_template, context):
    """
    Render a Jinja2 template file.

    Parameters
    ----------
    path_to_template : str
        The path to the Jinja2 template.
    context : dict
        The context to render the template with.

    Returns
    -------
    str
        The rendered template text.
    """
    # Read explicitly as UTF-8 so that rendering does not depend on the
    # platform's default text encoding.
    with open(path_to_template, 'r', encoding='utf-8') as f:
        template = jinja2.Template(f.read())
    return template.render(context)
40
+
41
+
42
def get_params_to_valid_domains(model):
    """
    Map each parameter and group to the NEURON domains where it applies.

    A domain of a group is kept for a parameter if the parameter's mechanism
    is named ``'Independent'`` or if that mechanism is listed in
    ``model.domains_to_mechs`` for the domain. Kept domains are translated
    with ``get_neuron_domain``.

    Parameters
    ----------
    model : Model
        The model providing ``params_to_mechs``, ``params``, ``groups`` and
        ``domains_to_mechs``.

    Returns
    -------
    dict
        Nested mapping ``{param: {group_name: [neuron_domain, ...]}}``.
    """
    params_to_valid_domains = defaultdict(dict)

    for param, mech in model.params_to_mechs.items():
        # Only the group names are needed; the distributions are not used.
        for group_name in model.params[param]:
            group = model.groups[group_name]
            params_to_valid_domains[param][group_name] = [
                get_neuron_domain(domain)
                for domain in group.domains
                if mech == 'Independent' or mech in model.domains_to_mechs[domain]
            ]

    return dict(params_to_valid_domains)
53
+
54
+
55
def filter_params(model):
    """
    Filter out kinetic parameters from the model's parameters.

    Conductances (the keys of ``model.conductances``) and the parameters
    ``cm``, ``Ra``, ``ena``, ``ek`` and ``eca`` keep their group
    distributions. Every other parameter is still present in the result,
    but it maps to an empty dict.

    Parameters
    ----------
    model : Model
        The model whose ``params`` are filtered.

    Returns
    -------
    dict
        Mapping ``{param: {group_name: distribution}}``. Filtered-out
        parameters map to ``{}``.
    """
    # Build the set of kept parameter names once, instead of rebuilding a
    # list for every (parameter, group) pair.
    kept_params = set(model.conductances.keys()) | {'cm', 'Ra', 'ena', 'ek', 'eca'}
    return {
        param: (
            {group_name: distribution
             for group_name, distribution in distributions.items()}
            if param in kept_params else {}
        )
        for param, distributions in model.params.items()
    }
@@ -0,0 +1,5 @@
1
+ from dendrotweaks.morphology.trees import Node, Tree
2
+ from dendrotweaks.morphology.point_trees import Point, PointTree
3
+ from dendrotweaks.morphology.sec_trees import Section, SectionTree, Domain
4
+ from dendrotweaks.morphology.seg_trees import Segment, SegmentTree
5
+ from dendrotweaks.morphology.io.validation import validate_tree
@@ -0,0 +1,100 @@
1
+
2
+ class Domain:
3
+ """
4
+ A class representing
5
+ a morphological or functional domain in a neuron.
6
+
7
+ Parameters
8
+ ----------
9
+ name : str
10
+ The name of the domain.
11
+ sections : List[Section], optional
12
+ A list of sections in the domain.
13
+
14
+ Attributes
15
+ ----------
16
+ name : str
17
+ The name of the domain.
18
+ """
19
+
20
+ def __init__(self, name: str, sections = None) -> None:
21
+ self.name = name
22
+ self._sections = sections if sections else []
23
+
24
+
25
+ def __repr__(self):
26
+ return f'<Domain({self.name}, {len(self.sections)} sections)>'
27
+
28
+
29
+ def __contains__(self, section):
30
+ return section in self.sections
31
+
32
+
33
+ @property
34
+ def sections(self):
35
+ """
36
+ A list of sections in the domain.
37
+ """
38
+ return self._sections
39
+
40
+
41
+ # def merge(self, other):
42
+ # """
43
+ # Merge the sections of the other domain into this domain.
44
+ # """
45
+ # self.inserted_mechanisms.update(other.inserted_mechanisms)
46
+ # sections = self.sections + other.sections
47
+ # self._sections = []
48
+ # for sec in sections:
49
+ # self.add_section(sec)
50
+
51
+
52
+ def add_section(self, sec: "Section"):
53
+ """
54
+ Add a section to the domain.
55
+
56
+ Changes the domain attribute of the section.
57
+
58
+ Parameters
59
+ ----------
60
+ sec : Section
61
+ The section to be added to the domain.
62
+ """
63
+ if sec in self._sections:
64
+ warnings.warn(f'Section {sec} already in domain {self.name}.')
65
+ return
66
+ sec.domain = self.name
67
+ sec.domain_idx = len(self._sections)
68
+ self._sections.append(sec)
69
+
70
+
71
+ def remove_section(self, sec):
72
+ """
73
+ Remove a section from the domain.
74
+
75
+ Sets the domain attribute of the section
76
+ to None.
77
+
78
+ Parameters
79
+ ----------
80
+ sec : Section
81
+ The section to be removed from the domain.
82
+ """
83
+ if sec not in self.sections:
84
+ warnings.warn(f'Section {sec} not in domain {self.name}.')
85
+ return
86
+ sec.domain = None
87
+ sec.domain_idx = None
88
+ self._sections.remove(sec)
89
+
90
+
91
+ def is_empty(self):
92
+ """
93
+ Check if the domain is empty.
94
+
95
+ Returns
96
+ -------
97
+ bool
98
+ True if the domain is empty, False otherwise.
99
+ """
100
+ return not bool(self._sections)
@@ -0,0 +1,5 @@
1
+ from dendrotweaks.morphology.io.reader import SWCReader
2
+ from dendrotweaks.morphology.io.validation import validate_tree
3
+ from dendrotweaks.morphology.io.factories import create_point_tree
4
+ from dendrotweaks.morphology.io.factories import create_section_tree
5
+ from dendrotweaks.morphology.io.factories import create_segment_tree
@@ -0,0 +1,212 @@
1
+ from dendrotweaks.morphology.trees import Node, Tree
2
+ from dendrotweaks.morphology.point_trees import Point, PointTree
3
+ from dendrotweaks.morphology.sec_trees import Section, SectionTree
4
+ from dendrotweaks.morphology.seg_trees import Segment, SegmentTree
5
+
6
+ from dendrotweaks.morphology.io.reader import SWCReader
7
+
8
+ from typing import List, Union
9
+ import numpy as np
10
+ from pandas import DataFrame
11
+
12
+ from dendrotweaks.morphology.io.validation import validate_tree
13
+
14
+
15
+
16
def create_point_tree(source: Union[str, DataFrame]) -> PointTree:
    """
    Build a PointTree from an SWC file path or an already loaded DataFrame.

    One Point is created per row, from the columns Index, Type, X, Y, Z, R
    and Parent, in that order. Overlapping points are then removed from the
    resulting tree.

    Parameters
    ----------
    source : Union[str, DataFrame]
        Path to an SWC file (read with ``SWCReader``) or a DataFrame with
        the SWC columns.

    Returns
    -------
    PointTree
        The point tree representing the reconstruction of the neuron morphology.

    Raises
    ------
    ValueError
        If ``source`` is neither a str nor a DataFrame.
    """
    if isinstance(source, DataFrame):
        df = source
    elif isinstance(source, str):
        df = SWCReader().read_file(source)
    else:
        raise ValueError("Source must be a file path (str) or a DataFrame.")

    columns = ('Index', 'Type', 'X', 'Y', 'Z', 'R', 'Parent')
    # NOTE(review): DataFrame.iterrows yields each row as a Series, which
    # upcasts integer columns (Index, Type, Parent) to float when another
    # column is float. Confirm that Point handles float ids.
    nodes = []
    for _, row in df.iterrows():
        nodes.append(Point(*(row[col] for col in columns)))

    point_tree = PointTree(nodes)
    point_tree.remove_overlaps()
    return point_tree
46
+
47
+
48
def create_section_tree(point_tree: PointTree):
    """
    Build a SectionTree by splitting a point tree into sections.

    ``point_tree.extend_sections()`` and ``point_tree.sort()`` are called on
    the given tree first. The result keeps a reference to the point tree in
    ``_point_tree``.

    Parameters
    ----------
    point_tree : PointTree
        The point tree to create the section tree from by splitting it into sections.

    Returns
    -------
    SectionTree
        The section tree created representing the neuron morphology on a more abstract level.
    """
    # Both calls are made for their effect on point_tree itself.
    point_tree.extend_sections()
    point_tree.sort()

    sec_tree = SectionTree(_split_to_sections(point_tree))
    sec_tree._point_tree = point_tree
    return sec_tree
72
+
73
+
74
def _split_to_sections(point_tree: PointTree) -> List[Section]:
    """
    Split the point tree into sections.

    A section starts at the root and at every child of a bifurcation point.
    It then follows first children until the next section start or a
    terminal point is reached. Sections are numbered in the order in which
    their first points appear in ``point_tree._nodes``. If the soma uses the
    3-point notation (``soma_notation == '3PS'``), the soma sections are
    merged with ``_merge_soma``.

    Parameters
    ----------
    point_tree : PointTree
        The point tree to split (``create_section_tree`` extends and sorts
        it before calling this function).

    Returns
    -------
    List[Section]
        The created sections.
    """
    candidates = set(
        [point_tree.root]
        + [child for b in point_tree.bifurcations for child in b.children]
    )
    # Filter by the point order in the tree to enforce the original order.
    section_starts = [node for node in point_tree._nodes if node in candidates]
    # A set gives O(1) membership tests inside the walk below. With a list
    # this function was O(points x bifurcations).
    start_set = set(section_starts)

    sections = []
    for i, child in enumerate(section_starts):
        section = Section(idx=i, parent_idx=-1, points=[child])
        sections.append(section)
        child._section = section
        # Propagate the section to the children until the next
        # section start or a termination point is reached.
        while child.children:
            next_child = child.children[0]
            if next_child in start_set:
                break
            next_child._section = section
            section.points.append(next_child)
            child = next_child

        first_point = section.points[0]
        section.parent = first_point.parent._section if first_point.parent else None
        section.parent_idx = section.parent.idx if section.parent else -1

    if point_tree.soma_notation == '3PS':
        sections = _merge_soma(sections, point_tree)

    return sections
110
+
111
+
112
def _merge_soma(sections: 'List[Section]', point_tree: 'PointTree'):
    """
    Merge a soma in 3-point notation (3PS) into a single section.

    The section of the root point becomes the only soma section. The two
    other sections of domain ``'soma'`` must each contain exactly one point.
    They are removed from ``sections``, and their points are moved into the
    root section as ``[first, root, second]``. The remaining sections are
    then renumbered contiguously (the soma gets index 0), and their
    ``parent_idx`` values are recomputed from their first point's parent.

    Parameters
    ----------
    sections : List[Section]
        The sections produced by splitting the point tree. Modified in place.
    point_tree : PointTree
        The point tree the sections were created from.

    Returns
    -------
    List[Section]
        The updated list of sections.

    Raises
    ------
    ValueError
        If there are not exactly two extra soma sections, or if one of them
        does not consist of exactly one point.
    """
    true_soma = point_tree.root._section

    false_somas = [sec for sec in sections
                   if sec.domain == 'soma' and sec is not true_soma]
    if len(false_somas) != 2:
        raise ValueError(
            'Soma must have exactly 2 children of domain soma. '
            f'Found: {false_somas}')

    for sec in false_somas:
        sections.remove(sec)
        if len(sec.points) != 1:
            raise ValueError('Soma children must have exactly 1 point.')
        for pt in sec.points:
            pt._section = true_soma

    true_soma.points = [
        false_somas[0].points[0],
        true_soma.points[0],
        false_somas[1].points[0]
    ]

    true_soma.idx = 0
    true_soma.parent_idx = -1
    # Renumber contiguously instead of shifting every index by 2. The shift
    # is only correct when the removed soma sections had indices 1 and 2.
    others = [sec for sec in sections if sec is not true_soma]
    for new_idx, sec in enumerate(others, start=1):
        sec.idx = new_idx
    # Recompute parent indices only after every section was renumbered.
    for sec in others:
        sec.parent_idx = sec.points[0].parent._section.idx

    return sections
148
+
149
+
150
def create_segment_tree(sec_tree):
    """
    Build a SegmentTree from the segments of a section tree.

    The new tree is also stored on ``sec_tree._seg_tree``.

    Parameters
    ----------
    sec_tree : SectionTree
        The section tree to create the segment tree from by splitting it into segments.

    Returns
    -------
    SegmentTree
        The segment tree representing spatial discretization of the neuron morphology for numerical simulations.
    """
    seg_tree = SegmentTree(_create_segments(sec_tree))
    # Attach the segment tree to the section tree it was derived from.
    sec_tree._seg_tree = seg_tree
    return seg_tree
171
+
172
+
173
def _create_segments(sec_tree) -> List[Segment]:
    """
    Create a list of Segment objects from a SectionTree object.

    Sections are visited depth-first in pre-order, starting at
    ``sec_tree.root`` and taking children in order. The segments of each
    section (iterated from ``sec._ref``) are numbered consecutively in that
    order. Within a section, each segment's parent is the previous segment.
    The first segment's parent is the segment of the parent section that
    ``parentseg()`` points to. As a side effect, ``sec.segments`` is rebuilt
    for every visited section.

    An explicit stack is used instead of recursion, so deeply nested trees
    cannot exceed Python's recursion limit.

    Parameters
    ----------
    sec_tree : SectionTree
        The section tree whose sections are discretized.

    Returns
    -------
    List[Segment]
        All segments, ordered by index.
    """
    segments = []
    idx_counter = 0
    # Each stack entry is (section, index of the segment it attaches to).
    stack = [(sec_tree.root, -1)]

    while stack:
        sec, parent_idx = stack.pop()

        segs = {seg: idx + idx_counter for idx, seg in enumerate(sec._ref)}
        sec.segments = []
        for seg, idx in segs.items():
            segment = Segment(
                idx=idx, parent_idx=parent_idx, neuron_seg=seg, section=sec)
            segments.append(segment)
            sec.segments.append(segment)
            parent_idx = idx
        idx_counter += len(segs)

        pending = []
        for child in sec.children:
            # IMPORTANT: This is needed since 0 and 1 segments are not explicitly
            # defined in the section segments list
            parent_x = child._ref.parentseg().x
            if parent_x == 1:
                new_parent_idx = list(segs.values())[-1]
            elif parent_x == 0:
                new_parent_idx = list(segs.values())[0]
            else:
                new_parent_idx = segs[child._ref.parentseg()]
            pending.append((child, new_parent_idx))

        # Push in reverse so the first child is processed next. This
        # reproduces the pre-order of the former recursive traversal.
        stack.extend(reversed(pending))

    return segments
212
+
@@ -0,0 +1,66 @@
1
+ import pandas as pd
2
+
3
class SWCReader():
    """
    Reads an SWC file and returns a DataFrame.
    """

    def __init__(self):
        pass

    @staticmethod
    def read_file(path_to_swc_file: str) -> pd.DataFrame:
        """
        Read the SWC file and return a DataFrame.

        Blank lines are dropped, and whitespace in every remaining line is
        collapsed to single spaces before parsing, so columns may be
        separated by any mix of spaces and tabs. Lines starting with ``#``
        are treated as comments. The file on disk is not modified.

        Parameters
        ----------
        path_to_swc_file : str
            The full path to the SWC file.

        Returns
        -------
        pd.DataFrame
            The DataFrame containing the SWC data, with columns
            Index, Type, X, Y, Z, R and Parent.

        Raises
        ------
        ValueError
            If the SWC file contains duplicate node ids.
        """
        import io

        with open(path_to_swc_file, 'r') as f:
            lines = [' '.join(line.split()) for line in f if line.strip()]

        # Parse the normalized text from memory instead of writing it back
        # over the user's input file.
        df = pd.read_csv(
            io.StringIO('\n'.join(lines)),
            sep=' ',
            header=None,
            comment='#',
            names=['Index', 'Type', 'X', 'Y', 'Z', 'R', 'Parent'],
            index_col=False
        )

        # If every radius is zero, use a radius of 1.0 for all points.
        if (df['R'] == 0).all():
            df['R'] = 1.0

        if df['Index'].duplicated().any():
            raise ValueError("The SWC file contains duplicate node ids.")
        return df

    @staticmethod
    def plot_raw_data(df, ax):
        """
        Plot the raw data from the SWC file using the DataFrame.

        Points are drawn as a 3D scatter (X, Y, Z), one color per SWC type.
        Types without a predefined color are drawn in black.

        Parameters
        ----------
        df : pd.DataFrame
            The DataFrame containing the SWC data (generated by read_file).
        ax : matplotlib.pyplot.Axes
            The axes to plot on (must accept three coordinate arrays,
            presumably a 3D axes).
        """
        types_to_colors = {1: 'C1', 2: 'C3', 3: 'C2', 4: 'C0', 31: 'green', 41: 'blue', 42: 'magenta', 43: 'brown'}
        for t in df['Type'].unique():
            color = types_to_colors.get(t, 'k')
            mask = df['Type'] == t
            ax.scatter(df[mask]['X'], df[mask]['Y'], df[mask]['Z'],
                       c=color, s=1, label=f'Type {t}')
        ax.legend()
@@ -0,0 +1,212 @@
1
+ from dendrotweaks.morphology.trees import Tree
2
+ from dendrotweaks.morphology.point_trees import PointTree
3
+ from dendrotweaks.morphology.sec_trees import SectionTree
4
+
5
+ import numpy as np
6
+ import warnings
7
+
8
+ # def custom_warning_formatter(message, category, filename, lineno, file=None, line=None):
9
+ # return f"WARNING: {message}\n({os.path.basename(filename)}, line {lineno})\n"
10
+
11
+ # warnings.formatwarning = custom_warning_formatter
12
+
13
+
14
def validate_tree(tree):
    """
    Validate the topological structure of a tree graph.

    Runs the generic checks in this order: unique ids, single root,
    no duplicate children, connectivity, loops and bifurcations. It then
    runs the PointTree- or SectionTree-specific checks and finally checks
    whether the tree is sorted. Problems are reported with
    ``warnings.warn``; progress messages are printed.

    Parameters
    ----------
    tree : Tree
        The tree to validate.
    """
    generic_checks = (
        check_unique_ids,
        check_unique_root,
        check_unique_children,
        check_connections,
        check_loops,
        check_bifurcations,
    )
    for check in generic_checks:
        check(tree)

    if isinstance(tree, PointTree):
        validate_point_tree(tree)
    if isinstance(tree, SectionTree):
        validate_section_tree(tree)

    print("Checking if the tree is sorted...")
    if not tree.is_sorted:
        warnings.warn("Tree is not sorted")

    print("***Validation complete.***")
47
+
48
+
49
+ # -----------------------------------------------------------------------------
50
+ # Indicies
51
+ # -----------------------------------------------------------------------------
52
+
53
def check_unique_ids(tree):
    """Warn if several nodes of ``tree`` share the same ``idx``."""
    print("Checking for unique node ids...")
    n_nodes = len(tree._nodes)
    n_unique = len(set(node.idx for node in tree._nodes))
    n_duplicates = n_nodes - n_unique
    if n_duplicates:
        warnings.warn(f"Tree contains {n_duplicates} duplicate node ids.")
58
+
59
+
60
def check_unique_children(tree):
    """Warn about every node whose ``children`` list contains a node more than once."""
    print("Checking for duplicate children...")
    for node in tree._nodes:
        kids = node.children
        has_duplicates = len(set(kids)) != len(kids)
        if has_duplicates:
            warnings.warn(f"Node {node} contains duplicate children.")
66
+
67
+
68
def check_unique_root(tree):
    """
    Warn unless ``tree`` has exactly one root node.

    A node counts as a root if its ``parent`` is None or its
    ``parent_idx`` is None, -1 or '-1'.
    """
    print("Checking for unique root node...")
    root_markers = {None, -1, '-1'}
    n_roots = len({
        node for node in tree._nodes
        if node.parent is None or node.parent_idx in root_markers
    })
    if n_roots == 0:
        warnings.warn("Tree does not contain a root node.")
    elif n_roots > 1:
        warnings.warn(f"Found {n_roots} root nodes.")
77
+
78
+
79
+ # -----------------------------------------------------------------------------
80
+ # Connectivity
81
+ # -----------------------------------------------------------------------------
82
+
83
def check_connections(tree):
    """
    Validate the parent-child relationships in the tree.

    1. If the tree is not connected, warn about the nodes that are missing
       from the root's subtree.
    2. Warn about every node that is not listed in its parent's children.
    3. Warn about every child whose ``parent`` is not the node listing it.
    """
    print("Checking tree connectivity...")
    if not tree.is_connected:
        all_nodes = set(tree._nodes)
        not_connected = all_nodes - set(tree.root.subtree)
        warnings.warn(f"The following nodes are not connected to the root node: {not_connected}")

    for node in tree._nodes:
        parent = node.parent
        if parent is not None and node not in parent.children:
            warnings.warn(
                f"Validation Warning: Node {node} is not listed in the children of its parent {parent}. "
                f"Expected parent.children to include {node}, but it does not."
            )

        for child in node.children:
            if child.parent is node:
                continue
            warnings.warn(
                f"Validation Warning: Node {child} has an incorrect parent. "
                f"Expected parent {node}, but found {child.parent}."
            )
113
+
114
+
115
def check_loops(tree):
    """
    Warn about loops: nodes that are a child of one of the nodes in their
    own ``subtree``.
    """
    print("Checking for loops...")
    for node in tree._nodes:
        offenders = (desc for desc in node.subtree if node in desc.children)
        for descendant in offenders:
            warnings.warn(f"Node {node} is a descendant of itself. Loop detected at node {descendant}.")
121
+
122
def check_bifurcations(tree):
    """Warn about non-root bifurcation nodes that have more than two children."""
    print("Checking for bifurcations with more than 2 children...")
    offenders = {}
    for node in tree.bifurcations:
        n_children = len(node.children)
        if n_children > 2 and node is not tree.root:
            offenders[node] = n_children
    if not offenders:
        return
    lines = [f"Node {node.idx:<6} has {count} children" for node, count in offenders.items()]
    warnings.warn("Tree contains bifurcations with more than 2 children:\n" + "\n".join(lines))
128
+
129
+
130
+ # =============================================================================
131
+ # Point-specific validation
132
+ # =============================================================================
133
+
134
def validate_point_tree(point_tree):
    """
    Validate the geometry of a point tree.

    The following checks are run, and each problem is reported with
    ``warnings.warn``:

    * NaN values in ``point_tree.df`` (total count plus per-column counts);
    * bifurcations inside the soma, i.e. non-root bifurcation points with
      ``type_idx == 1``;
    * if the tree has been extended (``_is_extended``), children of non-root
      bifurcations that do not overlap with their parent point.

    Parameters
    ----------
    point_tree : PointTree
        The point tree to validate.
    """

    print("Checking for NaN values...")
    nan_counts = point_tree.df.isnull().sum()
    total_nans = int(nan_counts.sum())
    if total_nans > 0:
        # Report the total and only the affected columns, instead of
        # interpolating the whole per-column Series into the message.
        per_column = ", ".join(
            f"{col}: {count}" for col, count in nan_counts.items() if count > 0)
        warnings.warn(
            f"Found {total_nans} NaN values in the DataFrame ({per_column})")

    print("Checking for bifurcations in the soma...")
    bifurcations_without_root = [pt for pt in point_tree.bifurcations
                                 if pt is not point_tree.root]
    bifurcations_within_soma = [pt for pt in bifurcations_without_root
                                if pt.type_idx == 1]
    if bifurcations_within_soma:
        warnings.warn(f"Soma must be non-branching. Found bifurcations: {bifurcations_within_soma}")

    if point_tree._is_extended:
        print("Checking the extended tree for geometric continuity...")
        non_overlapping_children = [
            (pt, child) for pt in bifurcations_without_root for child in pt.children
            if not child.overlaps_with(pt)
        ]
        if non_overlapping_children:
            issues_str = "\n".join([f"Child {child} does not overlap with parent {pt}" for pt, child in non_overlapping_children])
            warnings.warn(f"Found non-overlapping children for bifurcations:\n{issues_str}")
169
+
170
+
171
+ # =============================================================================
172
+ # Section-specific validation
173
+ # =============================================================================
174
+
175
def validate_section_tree(section_tree):
    """
    Validate a section tree.

    Warns once per section whose points do not all share the section's
    domain. It also warns (at most once each) if any section has zero length,
    if any non-root section has a number of children other than 0 or 2,
    and if the root section's domain is not ``'soma'``.

    Parameters
    ----------
    section_tree : SectionTree
        The section tree to validate.
    """

    print("Checking that all points in a section belong to the same domain...")
    for sec in section_tree:
        consistent = all(pt.domain == sec.domain for pt in sec.points)
        if not consistent:
            warnings.warn('All points in a section must belong to the same domain.')

    print("Checking that all sections have a non-zero length...")
    has_zero_length = any(sec.length == 0 for sec in section_tree)
    if has_zero_length:
        warnings.warn('Found sections with zero length.')

    print("Checking that all sections (except soma) have 0 or 2 children...")
    has_bad_branching = any(
        len(sec.children) not in {0, 2} and sec is not section_tree.root
        for sec in section_tree
    )
    if has_bad_branching:
        warnings.warn('Found sections with an incorrect number of children.')

    print("Checking that the root section has domain soma...")
    root_is_soma = section_tree.root.domain == 'soma'
    if not root_is_soma:
        warnings.warn('Root section must have domain soma.')
201
+
202
+
203
+ # =============================================================================
204
+ # Validation utilities
205
+ # =============================================================================
206
+
207
def shuffle_indices_for_testing(df, seed=None):
    """
    Randomly relabel the node indices of an SWC DataFrame, for testing.

    Each value in ``Index`` is mapped to a value from a random permutation
    of ``range(max(Index) - min(Index) + 1)``, so distinct indices get
    distinct labels. ``Parent`` is updated with the same mapping, except
    for entries equal to -1. ``df`` is modified in place and returned.

    Parameters
    ----------
    df : pandas.DataFrame
        SWC data with ``Index`` and ``Parent`` columns.
    seed : int or None, optional
        Seed for a dedicated random generator, which makes the shuffle
        reproducible. If None (default), NumPy's global random state is used.

    Returns
    -------
    pandas.DataFrame
        The same DataFrame with relabelled indices.
    """
    idx_range = int(df['Index'].max() - df['Index'].min()) + 1
    if seed is None:
        new_labels = np.random.permutation(idx_range)
    else:
        new_labels = np.random.default_rng(seed).permutation(idx_range)
    random_mapping = dict(zip(df['Index'], new_labels))
    df['Index'] = df['Index'].map(random_mapping)
    # Root entries (Parent == -1) keep their sentinel value.
    df.loc[df['Parent'] != -1, 'Parent'] = df['Parent'].map(random_mapping)
    return df
+ return df