dendrotweaks 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dendrotweaks/__init__.py +10 -0
- dendrotweaks/analysis/__init__.py +11 -0
- dendrotweaks/analysis/ephys_analysis.py +482 -0
- dendrotweaks/analysis/morphometric_analysis.py +106 -0
- dendrotweaks/membrane/__init__.py +6 -0
- dendrotweaks/membrane/default_mod/AMPA.mod +65 -0
- dendrotweaks/membrane/default_mod/AMPA_NMDA.mod +100 -0
- dendrotweaks/membrane/default_mod/CaDyn.mod +54 -0
- dendrotweaks/membrane/default_mod/GABAa.mod +65 -0
- dendrotweaks/membrane/default_mod/Leak.mod +27 -0
- dendrotweaks/membrane/default_mod/NMDA.mod +72 -0
- dendrotweaks/membrane/default_mod/vecstim.mod +76 -0
- dendrotweaks/membrane/default_templates/NEURON_template.py +354 -0
- dendrotweaks/membrane/default_templates/default.py +73 -0
- dendrotweaks/membrane/default_templates/standard_channel.mod +87 -0
- dendrotweaks/membrane/default_templates/template_jaxley.py +108 -0
- dendrotweaks/membrane/default_templates/template_jaxley_new.py +108 -0
- dendrotweaks/membrane/distributions.py +324 -0
- dendrotweaks/membrane/groups.py +103 -0
- dendrotweaks/membrane/io/__init__.py +11 -0
- dendrotweaks/membrane/io/ast.py +201 -0
- dendrotweaks/membrane/io/code_generators.py +312 -0
- dendrotweaks/membrane/io/converter.py +108 -0
- dendrotweaks/membrane/io/factories.py +144 -0
- dendrotweaks/membrane/io/grammar.py +417 -0
- dendrotweaks/membrane/io/loader.py +90 -0
- dendrotweaks/membrane/io/parser.py +499 -0
- dendrotweaks/membrane/io/reader.py +212 -0
- dendrotweaks/membrane/mechanisms.py +574 -0
- dendrotweaks/model.py +1916 -0
- dendrotweaks/model_io.py +75 -0
- dendrotweaks/morphology/__init__.py +5 -0
- dendrotweaks/morphology/domains.py +100 -0
- dendrotweaks/morphology/io/__init__.py +5 -0
- dendrotweaks/morphology/io/factories.py +212 -0
- dendrotweaks/morphology/io/reader.py +66 -0
- dendrotweaks/morphology/io/validation.py +212 -0
- dendrotweaks/morphology/point_trees.py +681 -0
- dendrotweaks/morphology/reduce/__init__.py +16 -0
- dendrotweaks/morphology/reduce/reduce.py +155 -0
- dendrotweaks/morphology/reduce/reduced_cylinder.py +129 -0
- dendrotweaks/morphology/sec_trees.py +1112 -0
- dendrotweaks/morphology/seg_trees.py +157 -0
- dendrotweaks/morphology/trees.py +567 -0
- dendrotweaks/path_manager.py +261 -0
- dendrotweaks/simulators.py +235 -0
- dendrotweaks/stimuli/__init__.py +3 -0
- dendrotweaks/stimuli/iclamps.py +73 -0
- dendrotweaks/stimuli/populations.py +265 -0
- dendrotweaks/stimuli/synapses.py +203 -0
- dendrotweaks/utils.py +239 -0
- dendrotweaks-0.3.1.dist-info/METADATA +70 -0
- dendrotweaks-0.3.1.dist-info/RECORD +56 -0
- dendrotweaks-0.3.1.dist-info/WHEEL +5 -0
- dendrotweaks-0.3.1.dist-info/licenses/LICENSE +674 -0
- dendrotweaks-0.3.1.dist-info/top_level.txt +1 -0
dendrotweaks/model_io.py
ADDED
@@ -0,0 +1,75 @@
|
|
1
|
+
import jinja2
|
2
|
+
import os
|
3
|
+
from collections import defaultdict
|
4
|
+
|
5
|
+
# Translation table from DendroTweaks domain names to the domain labels used
# on the NEURON side. Insertion order is kept identical to the original table.
DOMAINS_TO_NEURON = dict(
    soma='soma',
    perisomatic='dend_11',
    axon='axon',
    apic='apic',
    dend='dend',
    basal='dend_31',
    trunk='dend_41',
    tuft='dend_42',
    oblique='dend_43',
    # For these two bases, get_neuron_domain appends an all-digit suffix of
    # the domain name (e.g. 'custom_3') directly to the mapped label.
    custom='dend_5',
    reduced='dend_8',
    # Same label get_neuron_domain falls back to for unknown names.
    undefined='dend_0',
)
|
19
|
+
|
20
|
+
def get_neuron_domain(domain_name):
    """
    Translate a DendroTweaks domain name into its NEURON domain label.

    The text before the first underscore selects an entry of
    ``DOMAINS_TO_NEURON``. For the ``'reduced'`` and ``'custom'`` bases, an
    all-digit remainder is appended verbatim to the mapped label
    (``'custom_3'`` -> ``'dend_5' + '3'``). Any other unknown base maps to
    ``'dend_0'``.

    Parameters
    ----------
    domain_name : str
        The DendroTweaks domain name.

    Returns
    -------
    str
        The NEURON domain label.
    """
    prefix, _, suffix = domain_name.partition('_')
    is_numbered = prefix in ('reduced', 'custom') and suffix.isdigit()
    if not is_numbered:
        return DOMAINS_TO_NEURON.get(prefix, 'dend_0')
    return DOMAINS_TO_NEURON[prefix] + suffix
|
25
|
+
|
26
|
+
def render_template(path_to_template, context):
    """
    Render a Jinja2 template.

    The template file is read as UTF-8 text, wrapped in a ``jinja2.Template``
    and rendered with ``context``.

    Parameters
    ----------
    path_to_template : str
        The path to the Jinja2 template.
    context : dict
        The context to render the template with.

    Returns
    -------
    str
        The rendered template text.
    """
    # Explicit encoding: the default depends on the platform locale, which
    # makes reading non-ASCII templates fail or garble on some systems.
    with open(path_to_template, 'r', encoding='utf-8') as f:
        template = jinja2.Template(f.read())
    return template.render(context)
|
40
|
+
|
41
|
+
|
42
|
+
def get_params_to_valid_domains(model):
    """
    Collect, per parameter and group, the NEURON domains where it applies.

    For every ``(param, mech)`` in ``model.params_to_mechs`` and every group
    name in ``model.params[param]``, the group's domains are kept when the
    mechanism is ``'Independent'`` or is listed in
    ``model.domains_to_mechs[domain]``. Kept domains are translated with
    ``get_neuron_domain``.

    Parameters
    ----------
    model : Model
        The model providing ``params_to_mechs``, ``params``, ``groups`` and
        ``domains_to_mechs``.

    Returns
    -------
    dict
        ``{param: {group_name: [neuron_domain, ...]}}``. Parameters with no
        groups in ``model.params`` do not appear.
    """
    result = {}
    for param, mech in model.params_to_mechs.items():
        always_valid = mech == 'Independent'
        for group_name in model.params[param].keys():
            domains = model.groups[group_name].domains
            # The inner dict is only created once a group exists for param.
            result.setdefault(param, {})[group_name] = [
                get_neuron_domain(domain)
                for domain in domains
                if always_valid or mech in model.domains_to_mechs[domain]
            ]
    return result
|
53
|
+
|
54
|
+
|
55
|
+
def filter_params(model):
    """
    Filter out kinetic parameters from the model's parameter table.

    Only conductance parameters (the keys of ``model.conductances``) and the
    parameters ``'cm'``, ``'Ra'``, ``'ena'``, ``'ek'`` and ``'eca'`` keep
    their group distributions. Every other parameter is still present in the
    result, but mapped to an empty dict.

    Parameters
    ----------
    model : Model
        The model whose ``params`` should be filtered. It is not modified.

    Returns
    -------
    dict
        ``{param: {group_name: distribution}}`` with the same parameter keys
        and order as ``model.params``; inner dicts are fresh copies.
    """
    # Built once instead of once per (param, group) pair.
    kept_params = set(model.conductances.keys()) | {'cm', 'Ra', 'ena', 'ek', 'eca'}
    return {
        param: (
            {group_name: distribution
             for group_name, distribution in distributions.items()}
            if param in kept_params else {}
        )
        for param, distributions in model.params.items()
    }
|
@@ -0,0 +1,5 @@
|
|
1
|
+
from dendrotweaks.morphology.trees import Node, Tree
|
2
|
+
from dendrotweaks.morphology.point_trees import Point, PointTree
|
3
|
+
from dendrotweaks.morphology.sec_trees import Section, SectionTree, Domain
|
4
|
+
from dendrotweaks.morphology.seg_trees import Segment, SegmentTree
|
5
|
+
from dendrotweaks.morphology.io.validation import validate_tree
|
@@ -0,0 +1,100 @@
|
|
1
|
+
|
2
|
+
class Domain:
    """
    A class representing a morphological or functional domain in a neuron.

    Parameters
    ----------
    name : str
        The name of the domain.
    sections : List[Section], optional
        A list of sections in the domain. A non-empty list is stored as given
        (not copied), and the sections' ``domain`` attributes are not updated.

    Attributes
    ----------
    name : str
        The name of the domain.
    """

    def __init__(self, name: str, sections=None) -> None:
        self.name = name
        # Falsy input (None or an empty list) is replaced by a fresh list.
        self._sections = sections if sections else []

    def __repr__(self):
        return f'<Domain({self.name}, {len(self.sections)} sections)>'

    def __contains__(self, section):
        return section in self.sections

    @property
    def sections(self):
        """
        A list of sections in the domain.
        """
        return self._sections

    def add_section(self, sec: "Section"):
        """
        Add a section to the domain.

        Sets ``sec.domain`` to the domain name and ``sec.domain_idx`` to the
        section's position in this domain. Adding a section that is already
        present only emits a warning.

        Parameters
        ----------
        sec : Section
            The section to be added to the domain.
        """
        # Local import: this module has no module-level imports, and the
        # warning path below would otherwise raise NameError.
        import warnings

        if sec in self._sections:
            warnings.warn(f'Section {sec} already in domain {self.name}.')
            return
        sec.domain = self.name
        sec.domain_idx = len(self._sections)
        self._sections.append(sec)

    def remove_section(self, sec):
        """
        Remove a section from the domain.

        Sets ``sec.domain`` and ``sec.domain_idx`` to None. Removing a section
        that is not present only emits a warning. The ``domain_idx`` of the
        remaining sections is not renumbered.

        Parameters
        ----------
        sec : Section
            The section to be removed from the domain.
        """
        # Local import: this module has no module-level imports.
        import warnings

        if sec not in self.sections:
            warnings.warn(f'Section {sec} not in domain {self.name}.')
            return
        sec.domain = None
        sec.domain_idx = None
        self._sections.remove(sec)

    def is_empty(self):
        """
        Check if the domain is empty.

        Returns
        -------
        bool
            True if the domain is empty, False otherwise.
        """
        return not bool(self._sections)
|
@@ -0,0 +1,5 @@
|
|
1
|
+
from dendrotweaks.morphology.io.reader import SWCReader
|
2
|
+
from dendrotweaks.morphology.io.validation import validate_tree
|
3
|
+
from dendrotweaks.morphology.io.factories import create_point_tree
|
4
|
+
from dendrotweaks.morphology.io.factories import create_section_tree
|
5
|
+
from dendrotweaks.morphology.io.factories import create_segment_tree
|
@@ -0,0 +1,212 @@
|
|
1
|
+
from dendrotweaks.morphology.trees import Node, Tree
|
2
|
+
from dendrotweaks.morphology.point_trees import Point, PointTree
|
3
|
+
from dendrotweaks.morphology.sec_trees import Section, SectionTree
|
4
|
+
from dendrotweaks.morphology.seg_trees import Segment, SegmentTree
|
5
|
+
|
6
|
+
from dendrotweaks.morphology.io.reader import SWCReader
|
7
|
+
|
8
|
+
from typing import List, Union
|
9
|
+
import numpy as np
|
10
|
+
from pandas import DataFrame
|
11
|
+
|
12
|
+
from dendrotweaks.morphology.io.validation import validate_tree
|
13
|
+
|
14
|
+
|
15
|
+
|
16
|
+
def create_point_tree(source: Union[str, DataFrame]) -> PointTree:
    """
    Build a point tree from an SWC file path or an SWC DataFrame.

    A string is read with ``SWCReader.read_file``; a DataFrame is used as is.
    One ``Point`` is created per row (columns Index, Type, X, Y, Z, R,
    Parent), and ``remove_overlaps`` is called on the resulting tree.

    Parameters
    ----------
    source : Union[str, DataFrame]
        The source of the SWC data. Can be a file path or a DataFrame.

    Returns
    -------
    PointTree
        The point tree representing the reconstruction of the neuron morphology.

    Raises
    ------
    ValueError
        If ``source`` is neither a string nor a DataFrame.
    """
    if isinstance(source, DataFrame):
        df = source
    elif isinstance(source, str):
        df = SWCReader().read_file(source)
    else:
        raise ValueError("Source must be a file path (str) or a DataFrame.")

    # Positional order expected by Point(...)
    columns = ('Index', 'Type', 'X', 'Y', 'Z', 'R', 'Parent')
    points = [Point(*(row[col] for col in columns)) for _, row in df.iterrows()]

    tree = PointTree(points)
    tree.remove_overlaps()
    return tree
|
46
|
+
|
47
|
+
|
48
|
+
def create_section_tree(point_tree: PointTree):
    """
    Build a section tree by splitting a point tree into sections.

    ``extend_sections()`` and ``sort()`` are called on ``point_tree`` before it
    is split, and the resulting section tree keeps a reference to it in
    ``_point_tree``.

    Parameters
    ----------
    point_tree : PointTree
        The point tree to split into sections.

    Returns
    -------
    SectionTree
        The section tree representing the neuron morphology on a more abstract level.
    """
    for prepare in (point_tree.extend_sections, point_tree.sort):
        prepare()

    section_tree = SectionTree(_split_to_sections(point_tree))
    section_tree._point_tree = point_tree
    return section_tree
|
72
|
+
|
73
|
+
|
74
|
+
def _split_to_sections(point_tree: PointTree) -> List[Section]:
    """
    Split the point tree into sections.

    A section starts at the root and at every child of a bifurcation point,
    then follows first children until the next section start or a
    termination point. Every point gets a ``_section`` back-reference, and
    each section's ``parent``/``parent_idx`` come from the section of its
    first point's parent. For '3PS' soma notation the soma sections are then
    merged with ``_merge_soma``.

    Parameters
    ----------
    point_tree : PointTree
        The point tree to split (already extended and sorted by the caller).

    Returns
    -------
    List[Section]
        Sections numbered in the order of their first point in
        ``point_tree._nodes``.
    """
    candidates = {point_tree.root}
    candidates.update(child for b in point_tree.bifurcations for child in b.children)
    # Filter through _nodes to enforce the original point order.
    ordered_starts = [node for node in point_tree._nodes if node in candidates]
    # Set for O(1) membership in the per-point loop below (a list scan here
    # made splitting quadratic in the number of points).
    section_starts = set(ordered_starts)

    sections = []
    for i, start in enumerate(ordered_starts):
        section = Section(idx=i, parent_idx=-1, points=[start])
        sections.append(section)
        start._section = section
        # Propagate the section until the next section start or a termination point.
        point = start
        while point.children:
            next_point = point.children[0]
            if next_point in section_starts:
                break
            next_point._section = section
            section.points.append(next_point)
            point = next_point

        first_parent = section.points[0].parent
        section.parent = first_parent._section if first_parent else None
        section.parent_idx = section.parent.idx if section.parent else -1

    if point_tree.soma_notation == '3PS':
        sections = _merge_soma(sections, point_tree)

    return sections
|
110
|
+
|
111
|
+
|
112
|
+
def _merge_soma(sections: List[Section], point_tree: PointTree):
    """
    If soma has 3PS notation, merge it into one section.

    The root point's section becomes the single soma section (idx 0, no
    parent). The two other sections of domain ``'soma'`` must each consist of
    exactly one point. They are removed from ``sections``, and their points
    become the first and last points of the soma section. The remaining
    sections are renumbered and their ``parent_idx`` refreshed.

    Parameters
    ----------
    sections : List[Section]
        Sections produced by splitting the point tree; modified in place.
    point_tree : PointTree
        The point tree the sections were built from.

    Returns
    -------
    List[Section]
        The same ``sections`` list, without the two merged soma sections.

    Raises
    ------
    ValueError
        If there are not exactly two extra soma sections, or one of them does
        not have exactly one point. ``sections`` is not modified in that case.
    """
    true_soma = point_tree.root._section
    true_soma.idx = 0
    true_soma.parent_idx = -1

    false_somas = [sec for sec in sections
                   if sec.domain == 'soma' and sec is not true_soma]
    if len(false_somas) != 2:
        raise ValueError(
            'Soma must have exactly 2 children of domain soma. '
            f'Found {len(false_somas)}: {false_somas}')
    # Validate before touching `sections` so a failure leaves it intact.
    if any(len(sec.points) != 1 for sec in false_somas):
        raise ValueError('Soma children must have exactly 1 point.')

    for sec in false_somas:
        sections.remove(sec)
        for pt in sec.points:
            pt._section = true_soma

    first, last = false_somas
    true_soma.points = [first.points[0], true_soma.points[0], last.points[0]]

    for sec in sections:
        if sec is true_soma:
            continue
        # NOTE(review): assumes both removed soma sections preceded every other
        # section, and that parents are renumbered before their children.
        sec.idx -= 2
        sec.parent_idx = sec.points[0].parent._section.idx

    return sections
|
148
|
+
|
149
|
+
|
150
|
+
def create_segment_tree(sec_tree):
    """
    Build the segment tree that discretizes a section tree for simulation.

    The new tree is also stored on ``sec_tree`` as ``_seg_tree``.

    Parameters
    ----------
    sec_tree : SectionTree
        The section tree to split into segments.

    Returns
    -------
    SegmentTree
        The segment tree representing spatial discretization of the neuron morphology for numerical simulations.
    """
    seg_tree = SegmentTree(_create_segments(sec_tree))
    sec_tree._seg_tree = seg_tree
    return seg_tree
|
171
|
+
|
172
|
+
|
173
|
+
def _create_segments(sec_tree) -> List[Segment]:
    """
    Create a list of Segment objects from a SectionTree object.

    Sections are visited depth-first in pre-order from ``sec_tree.root``,
    with children taken in ``sec.children`` order. The segments of a section
    (obtained by iterating ``sec._ref``) get consecutive indices, and
    numbering continues across sections.

    Parent links are assigned as follows:

    - The first root segment gets ``parent_idx == -1``.
    - Every other segment is parented to the previous segment of its section.
    - The first segment of a child section is parented to the segment that
      ``child._ref.parentseg()`` designates.

    ``sec.segments`` is reset and filled for every visited section.

    Parameters
    ----------
    sec_tree : SectionTree
        The section tree to discretize.

    Returns
    -------
    List[Segment]
        All segments, in increasing ``idx`` order.
    """
    segments = []
    idx_counter = 0
    # Explicit stack instead of recursion so deep section trees cannot hit
    # Python's recursion limit. Entries: (section, parent segment idx).
    stack = [(sec_tree.root, -1)]
    while stack:
        sec, parent_idx = stack.pop()

        segs = {seg: idx_counter + i for i, seg in enumerate(sec._ref)}
        sec.segments = []
        for seg, idx in segs.items():
            segment = Segment(
                idx=idx, parent_idx=parent_idx, neuron_seg=seg, section=sec)
            segments.append(segment)
            sec.segments.append(segment)
            # The next segment of this section hangs off this one.
            parent_idx = idx
        idx_counter += len(segs)

        seg_indices = list(segs.values())
        pending = []
        for child in sec.children:
            parent_seg = child._ref.parentseg()
            # IMPORTANT: the 0 and 1 ends are not explicitly included in the
            # section's segment list, so map them to the first/last segment.
            if parent_seg.x == 1:
                child_parent_idx = seg_indices[-1]
            elif parent_seg.x == 0:
                child_parent_idx = seg_indices[0]
            else:
                child_parent_idx = segs[parent_seg]
            pending.append((child, child_parent_idx))
        # Push in reverse so children are popped (and numbered) in order,
        # reproducing the recursive pre-order traversal exactly.
        stack.extend(reversed(pending))

    return segments
|
212
|
+
|
@@ -0,0 +1,66 @@
|
|
1
|
+
import pandas as pd
|
2
|
+
|
3
|
+
class SWCReader:
    """
    Reads an SWC file and returns a DataFrame.
    """

    def __init__(self):
        pass

    @staticmethod
    def read_file(path_to_swc_file: str) -> pd.DataFrame:
        """
        Read the SWC file and return a DataFrame.

        Blank lines are dropped and runs of whitespace are collapsed to single
        spaces in memory before parsing; the file on disk is not modified.
        Text after ``#`` is treated as a comment. If every radius is 0, all
        radii are set to 1.0.

        Parameters
        ----------
        path_to_swc_file : str
            The full path to the SWC file.

        Returns
        -------
        pd.DataFrame
            The DataFrame containing the SWC data, with columns
            Index, Type, X, Y, Z, R, Parent.

        Raises
        ------
        ValueError
            If the file contains duplicate node ids.
        """
        import io  # local import: only this method needs it

        with open(path_to_swc_file, 'r') as f:
            lines = [' '.join(line.split()) for line in f if line.strip()]

        # Parse the normalized text from memory instead of overwriting the
        # user's input file with it.
        df = pd.read_csv(
            io.StringIO('\n'.join(lines)),
            sep=' ',
            header=None,
            comment='#',
            names=['Index', 'Type', 'X', 'Y', 'Z', 'R', 'Parent'],
            index_col=False
        )

        if (df['R'] == 0).all():
            df['R'] = 1.0

        if df['Index'].duplicated().any():
            raise ValueError("The SWC file contains duplicate node ids.")
        return df

    @staticmethod
    def plot_raw_data(df, ax):
        """
        Plot the raw data from the SWC file using the DataFrame.

        One scatter series is drawn per point type; unknown types are black.

        Parameters
        ----------
        df : pd.DataFrame
            The DataFrame containing the SWC data (generated by read_file).
        ax : matplotlib.pyplot.Axes
            The axes to plot on (called with X, Y and Z coordinates, so
            presumably a 3D axes).
        """
        types_to_colors = {1: 'C1', 2: 'C3', 3: 'C2', 4: 'C0', 31: 'green', 41: 'blue', 42: 'magenta', 43: 'brown'}
        for t in df['Type'].unique():
            color = types_to_colors.get(t, 'k')
            mask = df['Type'] == t
            ax.scatter(df[mask]['X'], df[mask]['Y'], df[mask]['Z'],
                       c=color, s=1, label=f'Type {t}')
        ax.legend()
|
@@ -0,0 +1,212 @@
|
|
1
|
+
from dendrotweaks.morphology.trees import Tree
|
2
|
+
from dendrotweaks.morphology.point_trees import PointTree
|
3
|
+
from dendrotweaks.morphology.sec_trees import SectionTree
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
import warnings
|
7
|
+
|
8
|
+
# def custom_warning_formatter(message, category, filename, lineno, file=None, line=None):
|
9
|
+
# return f"WARNING: {message}\n({os.path.basename(filename)}, line {lineno})\n"
|
10
|
+
|
11
|
+
# warnings.formatwarning = custom_warning_formatter
|
12
|
+
|
13
|
+
|
14
|
+
def validate_tree(tree):
    """
    Validate the topological structure of a tree graph.

    Runs the generic id and connectivity checks, then the checks specific to
    point trees and/or section trees, and finally checks whether the tree is
    sorted. Problems are reported through warnings and progress is printed.

    Parameters
    ----------
    tree : Tree
        The tree to validate.
    """
    # Node ids first, then connectivity.
    generic_checks = (
        check_unique_ids,
        check_unique_root,
        check_unique_children,
        check_connections,
        check_loops,
        check_bifurcations,
    )
    for check in generic_checks:
        check(tree)

    # Type-specific checks; both isinstance tests run independently.
    if isinstance(tree, PointTree):
        validate_point_tree(tree)
    if isinstance(tree, SectionTree):
        validate_section_tree(tree)

    print("Checking if the tree is sorted...")
    if not tree.is_sorted:
        warnings.warn("Tree is not sorted")

    print("***Validation complete.***")
|
47
|
+
|
48
|
+
|
49
|
+
# -----------------------------------------------------------------------------
|
50
|
+
# Indicies
|
51
|
+
# -----------------------------------------------------------------------------
|
52
|
+
|
53
|
+
def check_unique_ids(tree):
    """Warn when several nodes of ``tree`` share the same ``idx``."""
    print("Checking for unique node ids...")
    n_nodes = len(tree._nodes)
    n_duplicates = n_nodes - len({node.idx for node in tree._nodes})
    if n_duplicates:
        warnings.warn(f"Tree contains {n_duplicates} duplicate node ids.")
|
58
|
+
|
59
|
+
|
60
|
+
def check_unique_children(tree):
    """Warn once for every node whose ``children`` list repeats a child."""
    print("Checking for duplicate children...")
    offenders = (node for node in tree._nodes
                 if len(set(node.children)) != len(node.children))
    for node in offenders:
        warnings.warn(f"Node {node} contains duplicate children.")
|
66
|
+
|
67
|
+
|
68
|
+
def check_unique_root(tree):
    """
    Warn unless exactly one node of ``tree`` looks like a root.

    A node counts as a root when its ``parent`` is None or its ``parent_idx``
    is None, -1 or '-1'.
    """
    print("Checking for unique root node...")
    root_markers = {None, -1, '-1'}
    roots = set()
    for node in tree._nodes:
        if node.parent is None or node.parent_idx in root_markers:
            roots.add(node)

    n_roots = len(roots)
    if n_roots == 0:
        warnings.warn("Tree does not contain a root node.")
    elif n_roots > 1:
        warnings.warn(f"Found {n_roots} root nodes.")
|
77
|
+
|
78
|
+
|
79
|
+
# -----------------------------------------------------------------------------
|
80
|
+
# Connectivity
|
81
|
+
# -----------------------------------------------------------------------------
|
82
|
+
|
83
|
+
def check_connections(tree):
    """
    Validate the parent-child relationships in the tree.

    Emits a warning (never raises) when:

    - ``tree.is_connected`` is false, listing the nodes outside ``tree.root.subtree``;
    - a node is missing from its parent's ``children`` list;
    - a child's ``parent`` is not the node that lists it.
    """
    print("Checking tree connectivity...")
    if not tree.is_connected:
        orphans = set(tree._nodes) - set(tree.root.subtree)
        warnings.warn(f"The following nodes are not connected to the root node: {orphans}")

    for node in tree._nodes:
        parent = node.parent
        # The node must appear among its parent's children.
        if parent is not None and node not in parent.children:
            warnings.warn(
                f"Validation Warning: Node {node} is not listed in the children of its parent {parent}. "
                f"Expected parent.children to include {node}, but it does not."
            )

        # Each child must point back at this node.
        for child in node.children:
            if child.parent is node:
                continue
            warnings.warn(
                f"Validation Warning: Node {child} has an incorrect parent. "
                f"Expected parent {node}, but found {child.parent}."
            )
|
113
|
+
|
114
|
+
|
115
|
+
def check_loops(tree):
    """Warn for every node that appears among the children of one of the nodes in its own subtree."""
    print("Checking for loops...")
    for node in tree._nodes:
        closing_nodes = (d for d in node.subtree if node in d.children)
        for descendant in closing_nodes:
            warnings.warn(f"Node {node} is a descendant of itself. Loop detected at node {descendant}.")
|
121
|
+
|
122
|
+
def check_bifurcations(tree):
    """Warn about non-root bifurcation nodes that have more than two children."""
    print("Checking for bifurcations with more than 2 children...")
    too_many_children = {}
    for node in tree.bifurcations:
        if len(node.children) > 2 and node is not tree.root:
            too_many_children[node] = len(node.children)

    if not too_many_children:
        return
    issue_lines = [f"Node {node.idx:<6} has {count} children"
                   for node, count in too_many_children.items()]
    issues_str = "\n".join(issue_lines)
    warnings.warn(f"Tree contains bifurcations with more than 2 children:\n{issues_str}")
|
128
|
+
|
129
|
+
|
130
|
+
# =============================================================================
|
131
|
+
# Point-specific validation
|
132
|
+
# =============================================================================
|
133
|
+
|
134
|
+
def validate_point_tree(point_tree):
    """
    Validate the geometry of a point tree.

    Emits warnings (never raises) for:

    - NaN values in ``point_tree.df``;
    - non-root bifurcation points with ``type_idx == 1`` (treated as soma
      branching);
    - for extended trees (``point_tree._is_extended``), children of non-root
      bifurcation points that do not overlap with their parent.

    Parameters
    ----------
    point_tree : PointTree
        The point tree to validate.
    """
    print("Checking for NaN values...")
    nan_counts = point_tree.df.isnull().sum()
    total_nans = int(nan_counts.sum())
    if total_nans > 0:
        # Report the total plus only the affected columns, rather than
        # interpolating the whole per-column Series into the message.
        per_column = {col: int(count) for col, count in nan_counts.items() if count}
        warnings.warn(f"Found {total_nans} NaN values in the DataFrame (per column: {per_column})")

    print("Checking for bifurcations in the soma...")
    bifurcations_without_root = [pt for pt in point_tree.bifurcations
                                 if pt is not point_tree.root]
    bifurcations_within_soma = [pt for pt in bifurcations_without_root
                                if pt.type_idx == 1]
    if bifurcations_within_soma:
        warnings.warn(f"Soma must be non-branching. Found bifurcations: {bifurcations_within_soma}")

    if point_tree._is_extended:
        print("Checking the extended tree for geometric continuity...")
        non_overlapping_children = [
            (pt, child) for pt in bifurcations_without_root for child in pt.children
            if not child.overlaps_with(pt)
        ]
        if non_overlapping_children:
            issues_str = "\n".join([f"Child {child} does not overlap with parent {pt}"
                                    for pt, child in non_overlapping_children])
            warnings.warn(f"Found non-overlapping children for bifurcations:\n{issues_str}")
|
169
|
+
|
170
|
+
|
171
|
+
# =============================================================================
|
172
|
+
# Section-specific validation
|
173
|
+
# =============================================================================
|
174
|
+
|
175
|
+
def validate_section_tree(section_tree):
    """
    Validate a section tree.

    Emits warnings (never raises) when:

    - a section contains points from another domain (one warning per such section);
    - any section has zero length;
    - any non-root section has a number of children other than 0 or 2;
    - the root section's domain is not 'soma'.

    Parameters
    ----------
    section_tree : SectionTree
        The section tree to validate.
    """
    print("Checking that all points in a section belong to the same domain...")
    for sec in section_tree:
        has_foreign_points = any(pt.domain != sec.domain for pt in sec.points)
        if has_foreign_points:
            warnings.warn('All points in a section must belong to the same domain.')

    print("Checking that all sections have a non-zero length...")
    has_zero_length = any(sec.length == 0 for sec in section_tree)
    if has_zero_length:
        warnings.warn('Found sections with zero length.')

    print("Checking that all sections (except soma) have 0 or 2 children...")
    has_bad_branching = any(
        len(sec.children) not in {0, 2} and sec is not section_tree.root
        for sec in section_tree
    )
    if has_bad_branching:
        warnings.warn('Found sections with an incorrect number of children.')

    print("Checking that the root section has domain soma...")
    if section_tree.root.domain != 'soma':
        warnings.warn('Root section must have domain soma.')
|
201
|
+
|
202
|
+
|
203
|
+
# =============================================================================
|
204
|
+
# Validation utilities
|
205
|
+
# =============================================================================
|
206
|
+
|
207
|
+
def shuffle_indices_for_testing(df, rng=None):
    """
    Randomly relabel the node indices of an SWC DataFrame (testing helper).

    Each value of ``df['Index']`` is mapped to a value from a random
    permutation of ``range(max_index - min_index + 1)`` (so new indices start
    at 0). ``df['Parent']`` is remapped consistently, and parents equal to -1
    are left untouched.

    Parameters
    ----------
    df : pandas.DataFrame
        SWC data with at least ``'Index'`` and ``'Parent'`` columns.
        Modified in place.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. Defaults to the global ``numpy.random`` state
        (the previous behaviour). Pass a seeded generator for reproducible
        shuffles.

    Returns
    -------
    pandas.DataFrame
        The same ``df`` object, relabelled.
    """
    random_source = np.random if rng is None else rng
    idx_range = int(df['Index'].max() - df['Index'].min()) + 1
    permutation = random_source.permutation(idx_range)
    random_mapping = dict(zip(df['Index'], permutation))
    df['Index'] = df['Index'].map(random_mapping)
    # Map only the rows being assigned: mapping the -1 sentinel would yield
    # NaN and a float-typed intermediate Series.
    has_parent = df['Parent'] != -1
    df.loc[has_parent, 'Parent'] = df.loc[has_parent, 'Parent'].map(random_mapping)
    return df
|