dendrotweaks 0.4.5__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dendrotweaks/__init__.py +1 -1
- dendrotweaks/analysis/ephys_analysis.py +63 -39
- dendrotweaks/biophys/default_mod/vecstim.mod +1 -11
- dendrotweaks/biophys/distributions.py +3 -3
- dendrotweaks/biophys/io/converter.py +4 -0
- dendrotweaks/biophys/mechanisms.py +11 -1
- dendrotweaks/model.py +143 -1087
- dendrotweaks/model_io.py +736 -39
- dendrotweaks/model_simulation.py +326 -0
- dendrotweaks/morphology/io/reader.py +2 -2
- dendrotweaks/path_manager.py +2 -2
- dendrotweaks/prerun.py +63 -0
- dendrotweaks/utils.py +114 -29
- {dendrotweaks-0.4.5.dist-info → dendrotweaks-0.4.6.dist-info}/METADATA +1 -1
- {dendrotweaks-0.4.5.dist-info → dendrotweaks-0.4.6.dist-info}/RECORD +18 -16
- {dendrotweaks-0.4.5.dist-info → dendrotweaks-0.4.6.dist-info}/WHEEL +0 -0
- {dendrotweaks-0.4.5.dist-info → dendrotweaks-0.4.6.dist-info}/licenses/LICENSE +0 -0
- {dendrotweaks-0.4.5.dist-info → dendrotweaks-0.4.6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,326 @@
|
|
1
|
+
# Imports
|
2
|
+
import os
|
3
|
+
|
4
|
+
# DendroTweaks imports
|
5
|
+
from dendrotweaks.stimuli.populations import Population
|
6
|
+
from dendrotweaks.stimuli.iclamps import IClamp
|
7
|
+
from dendrotweaks.morphology.io import create_segment_tree
|
8
|
+
from dendrotweaks.prerun import prerun
|
9
|
+
from dendrotweaks.utils import calculate_lambda_f
|
10
|
+
from dendrotweaks.utils import POPULATIONS
|
11
|
+
|
12
|
+
# Warnings configuration
|
13
|
+
import warnings
|
14
|
+
|
15
|
+
def custom_warning_formatter(message, category, filename, lineno, file=None, line=None):
    """
    Format a warning as a short two-line message.

    Only ``message``, ``filename`` and ``lineno`` are used; the remaining
    parameters are accepted so the function matches the signature expected
    of a ``warnings.formatwarning`` replacement.

    Returns
    -------
    str
        ``"WARNING: <message>"`` followed by the base name of the source
        file and the line number in parentheses.
    """
    # Show only the file name, not the full path, to keep the output compact.
    source_name = os.path.basename(filename)
    return f"WARNING: {message}\n({source_name}, line {lineno})\n"
|
17
|
+
|
18
|
+
warnings.formatwarning = custom_warning_formatter
|
19
|
+
|
20
|
+
|
21
|
+
class SimulationMixin:
    """
    Mixin class for model simulation functionalities.

    The methods defined here rely on attributes and methods supplied by the
    host class: ``simulator``, ``sec_tree``, ``seg_tree``, ``iclamps``,
    ``populations``, ``path_manager``, ``distribute``, ``distribute_all``,
    ``export_stimuli`` and ``load_stimuli``.
    """

    @property
    def recordings(self):
        """
        The recordings of the model. Reference to the recordings in the simulator.
        """
        return self.simulator.recordings

    @recordings.setter
    def recordings(self, recordings):
        self.simulator.recordings = recordings

    # ========================================================================
    # SEGMENTATION
    # ========================================================================

    # TODO Make a context manager for this
    def _temp_clear_stimuli(self):
        """
        Temporarily save and clear stimuli.

        Stimuli are exported under the name ``'_temp_stimuli'`` before all
        stimuli and all recordings are removed from the model.
        """
        self.export_stimuli(file_name='_temp_stimuli')
        self.remove_all_stimuli()
        self.remove_all_recordings()

    def _temp_reload_stimuli(self):
        """
        Load stimuli from a temporary file and clean up.

        After loading, the temporary ``json`` and ``csv`` files are deleted
        if they exist.
        """
        self.load_stimuli(file_name='_temp_stimuli')
        for ext in ['json', 'csv']:
            temp_path = self.path_manager.get_file_path('stimuli', '_temp_stimuli', extension=ext)
            if os.path.exists(temp_path):
                os.remove(temp_path)

    def set_segmentation(self, d_lambda=0.1, f=100):
        """
        Set the number of segments in each section based on the geometry.

        Stimuli and recordings are saved and cleared before the segment tree
        is rebuilt, and reloaded afterwards.

        Parameters
        ----------
        d_lambda : float
            The lambda value to use. Also stored as ``self.d_lambda``.
        f : float
            The frequency value to use (passed to ``calculate_lambda_f``).
        """
        self.d_lambda = d_lambda

        # Temporarily save and clear stimuli
        self._temp_clear_stimuli()

        # Pre-distribute parameters needed for lambda_f calculation
        for param_name in ['cm', 'Ra']:
            self.distribute(param_name)

        # Calculate lambda_f and set nseg for each section
        for sec in self.sec_tree.sections:
            lambda_f = calculate_lambda_f(sec.distances, sec.diameters, sec.Ra, sec.cm, f)
            # The "/ 2) * 2 + 1" form always yields an odd number of segments.
            nseg = max(1, int((sec.L / (d_lambda * lambda_f) + 0.9) / 2) * 2 + 1)
            sec._nseg = sec._ref.nseg = nseg

        # Rebuild the segment tree and redistribute parameters
        self.seg_tree = create_segment_tree(self.sec_tree)
        self.distribute_all()

        # Reload stimuli and clean up temporary files
        self._temp_reload_stimuli()

    # -----------------------------------------------------------------------
    # ICLAMPS
    # -----------------------------------------------------------------------

    def add_iclamp(self, sec, loc, amp=0, delay=100, dur=100):
        """
        Add an IClamp to a section.

        If an IClamp is already registered for the segment at ``sec(loc)``,
        it is removed first.

        Parameters
        ----------
        sec : Section
            The section to add the IClamp to.
        loc : float
            The location of the IClamp in the section.
        amp : float, optional
            The amplitude of the IClamp. Default is 0.
        delay : float, optional
            The delay of the IClamp. Default is 100.
        dur : float, optional
            The duration of the IClamp. Default is 100.
        """
        seg = sec(loc)
        if self.iclamps.get(seg):
            self.remove_iclamp(sec, loc)
        iclamp = IClamp(sec, loc, amp, delay, dur)
        print(f'IClamp added to sec {sec} at loc {loc}.')
        self.iclamps[seg] = iclamp

    def remove_iclamp(self, sec, loc):
        """
        Remove an IClamp from a section.

        Does nothing if no IClamp is registered for the segment.

        Parameters
        ----------
        sec : Section
            The section to remove the IClamp from.
        loc : float
            The location of the IClamp in the section.
        """
        seg = sec(loc)
        if self.iclamps.get(seg):
            self.iclamps.pop(seg)

    def remove_all_iclamps(self):
        """
        Remove all IClamps from the model.

        A warning is emitted if any entry is left after the individual
        removals; the registry is then reset to an empty dict regardless.
        """
        # Iterate over a snapshot of the keys because entries are popped
        # while looping.
        for seg in list(self.iclamps):
            sec, loc = seg._section, seg.x
            self.remove_iclamp(sec, loc)
        if self.iclamps:
            warnings.warn(f'Not all iclamps were removed: {self.iclamps}')
        self.iclamps = {}

    # -----------------------------------------------------------------------
    # SYNAPSES
    # -----------------------------------------------------------------------

    def _add_population(self, population):
        """Register ``population`` under its synapse type and name."""
        self.populations[population.syn_type][population.name] = population

    def add_population(self, segments, N, syn_type):
        """
        Add a population of synapses to the model.

        Parameters
        ----------
        segments : list[Segment]
            The segments to add the synapses to.
        N : int
            The number of synapses to add.
        syn_type : str
            The type of synapse to add.
        """
        # The index is the number of populations of this type already present.
        idx = len(self.populations[syn_type])
        population = Population(idx, segments, N, syn_type)
        population.allocate_synapses()
        population.create_inputs()
        self._add_population(population)

    def update_population_kinetic_params(self, pop_name, **params):
        """
        Update the kinetic parameters of a population of synapses.

        Parameters
        ----------
        pop_name : str
            The name of the population. The synapse type is taken to be
            everything before the last underscore.
        params : dict
            The parameters to update.
        """
        syn_type, _ = pop_name.rsplit('_', 1)
        population = self.populations[syn_type][pop_name]
        population.update_kinetic_params(**params)
        print(population.kinetic_params)

    def update_population_input_params(self, pop_name, **params):
        """
        Update the input parameters of a population of synapses.

        Parameters
        ----------
        pop_name : str
            The name of the population. The synapse type is taken to be
            everything before the last underscore.
        params : dict
            The parameters to update.
        """
        syn_type, _ = pop_name.rsplit('_', 1)
        population = self.populations[syn_type][pop_name]
        population.update_input_params(**params)
        print(population.input_params)

    def remove_population(self, name):
        """
        Remove a population of synapses from the model.

        Parameters
        ----------
        name : str
            The name of the population. The synapse type is taken to be
            everything before the last underscore.
        """
        syn_type, _ = name.rsplit('_', 1)
        population = self.populations[syn_type].pop(name)
        population.clean()

    def remove_all_populations(self):
        """
        Remove all populations of synapses from the model.

        A warning is emitted if any population is left after the individual
        removals; the registry is then reset to empty per-type dicts.
        """
        for syn_type in self.populations:
            # Snapshot the names because remove_population pops entries.
            for name in list(self.populations[syn_type]):
                self.remove_population(name)
        if any(self.populations.values()):
            warnings.warn(f'Not all populations were removed: {self.populations}')
        # Build a fresh dict instead of assigning the module-level
        # POPULATIONS object: sharing it would let populations added later
        # leak into the global default and into other models.
        self.populations = {syn_type: {} for syn_type in POPULATIONS}

    def remove_all_stimuli(self):
        """
        Remove all stimuli from the model.
        """
        self.remove_all_iclamps()
        self.remove_all_populations()

    # ========================================================================
    # SIMULATION
    # ========================================================================

    def add_recording(self, sec, loc, var='v'):
        """
        Add a recording to the model.

        Parameters
        ----------
        sec : Section
            The section to record from.
        loc : float
            The location along the normalized section length to record from.
        var : str, optional
            The variable to record. Default is 'v'.
        """
        self.simulator.add_recording(sec, loc, var)
        print(f'Recording added to sec {sec} at loc {loc}.')

    def remove_recording(self, sec, loc, var='v'):
        """
        Remove a recording from the model.

        Parameters
        ----------
        sec : Section
            The section to remove the recording from.
        loc : float
            The location along the normalized section length to remove the recording from.
        var : str, optional
            The recorded variable to remove. Default is 'v'.
        """
        self.simulator.remove_recording(sec, loc, var)

    def remove_all_recordings(self, var=None):
        """
        Remove all recordings from the model.

        Parameters
        ----------
        var : str, optional
            Forwarded unchanged to the simulator. Default is None.
        """
        self.simulator.remove_all_recordings(var=var)

    def run(self, duration=300, prerun_time=0, truncate=True):
        """
        Run the simulation for a specified duration, optionally preceded by a prerun period
        to stabilize the model.

        Parameters
        ----------
        duration : float
            Duration of the main simulation (excluding prerun).
        prerun_time : float
            Optional prerun period to run before the main simulation.
        truncate : bool
            Whether to truncate prerun data after the simulation.

        Raises
        ------
        ValueError
            If ``duration`` is not positive or ``prerun_time`` is negative.
        """
        if duration <= 0:
            raise ValueError("Simulation duration must be positive.")
        if prerun_time < 0:
            raise ValueError("Prerun time must be non-negative.")

        total_time = duration + prerun_time

        if prerun_time > 0:
            # The prerun context shifts stimulus timings for the duration of
            # the run and restores them on exit.
            with prerun(self, duration=prerun_time, truncate=truncate):
                self.simulator.run(total_time)
        else:
            self.simulator.run(duration)

    def get_traces(self):
        """Return the traces provided by the simulator."""
        return self.simulator.get_traces()

    def plot(self, *args, **kwargs):
        """Forward all arguments to the simulator's ``plot`` method."""
        self.simulator.plot(*args, **kwargs)
|
@@ -70,6 +70,6 @@ class SWCReader():
|
|
70
70
|
for t in df['Type'].unique():
|
71
71
|
color = types_to_colors.get(t, 'k')
|
72
72
|
mask = df['Type'] == t
|
73
|
-
ax.scatter(df[mask]['X'], df[mask]['Y'], df[mask]['Z'],
|
74
|
-
|
73
|
+
ax.scatter(xs=df[mask]['X'], ys=df[mask]['Y'], zs=df[mask]['Z'],
|
74
|
+
c=color, s=1, label=f'Type {t}')
|
75
75
|
ax.legend()
|
dendrotweaks/path_manager.py
CHANGED
@@ -50,9 +50,9 @@ class PathManager:
|
|
50
50
|
@property
|
51
51
|
def path_to_data(self):
|
52
52
|
"""
|
53
|
-
The path to the data directory.
|
53
|
+
The path to the data directory, which is always the parent directory of path_to_model.
|
54
54
|
"""
|
55
|
-
return os.path.
|
55
|
+
return os.path.abspath(os.path.join(self.path_to_model, os.pardir))
|
56
56
|
|
57
57
|
def __repr__(self):
|
58
58
|
return f"PathManager({self.path_to_model})"
|
dendrotweaks/prerun.py
ADDED
@@ -0,0 +1,63 @@
|
|
1
|
+
class prerun:
    """
    Context manager that shifts stimulus timings to make room for a prerun.

    On entry, every IClamp delay and every population's ``start``/``end``
    input parameter is moved later by ``duration``. On exit the original
    values are restored and, if requested, the first ``duration`` worth of
    simulator time points and recordings is dropped.
    """

    def __init__(self, model, duration=300, truncate=True):
        """
        Parameters
        ----------
        model
            Object exposing ``iclamps``, ``populations``, ``recordings``
            and ``simulator``.
        duration : float
            Length of the prerun period. Must be positive.
        truncate : bool
            Whether to cut the prerun period from the results on exit.

        Raises
        ------
        ValueError
            If ``duration`` is not positive.
        """
        self.model = model
        if duration <= 0:
            raise ValueError("Duration must be a positive number.")
        self.duration = duration
        self.truncate = truncate
        self._original_iclamp_delays = {}
        self._original_input_params = {}

    def __enter__(self):
        model = self.model
        shift = self.duration

        # Remember the unshifted timings so __exit__ can restore them.
        saved_delays = {}
        for seg, clamp in model.iclamps.items():
            saved_delays[seg] = clamp.delay
        self._original_iclamp_delays = saved_delays

        saved_windows = {}
        for syn_type, pops in model.populations.items():
            for pop_name, pop in pops.items():
                params = pop.input_params
                saved_windows[(syn_type, pop_name)] = (params['start'], params['end'])
        self._original_input_params = saved_windows

        # Push every stimulus later by the prerun duration.
        for clamp in model.iclamps.values():
            clamp.delay += shift

        for pops in model.populations.values():
            for pop in pops.values():
                new_start = pop.input_params['start'] + shift
                new_end = pop.input_params['end'] + shift
                pop.update_input_params(start=new_start, end=new_end)

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        model = self.model

        # Restore iClamp delays
        for seg, delay in self._original_iclamp_delays.items():
            model.iclamps[seg].delay = delay

        # Restore input timings
        for key, window in self._original_input_params.items():
            syn_type, pop_name = key
            start, end = window
            model.populations[syn_type][pop_name].update_input_params(start=start, end=end)

        # Only truncate when more samples were recorded than the prerun spans.
        n_prerun_points = int(self.duration / model.simulator.dt)
        if len(model.simulator.t) > n_prerun_points and self.truncate:
            self._truncate()

    def _truncate(self):
        """Drop the prerun period from the time vector and all recordings."""
        simulator = self.model.simulator
        first_kept = int(self.duration / simulator.dt)

        # Cut the prerun samples, then re-base time so it starts at the
        # beginning of the main simulation.
        simulator.t = simulator.t[first_kept:]
        simulator.t = [time_point - self.duration for time_point in simulator.t]

        for recs in self.model.recordings.values():
            for seg, rec in recs.items():
                recs[seg] = rec[first_kept:]

        simulator._duration -= self.duration
|
dendrotweaks/utils.py
CHANGED
@@ -26,24 +26,41 @@ SWC_ID_TO_DOMAIN = {
|
|
26
26
|
8: 'reduced',
|
27
27
|
}
|
28
28
|
|
29
|
+
POPULATIONS = {'AMPA': {}, 'NMDA': {}, 'AMPA_NMDA': {}, 'GABAa': {}}
|
30
|
+
|
31
|
+
INDEPENDENT_PARAMS = {
|
32
|
+
'cm': 1, # uF/cm2
|
33
|
+
'Ra': 100, # Ohm cm
|
34
|
+
'ena': 50, # mV
|
35
|
+
'ek': -77, # mV
|
36
|
+
'eca': 140 # mV
|
37
|
+
}
|
38
|
+
|
39
|
+
DOMAIN_TO_GROUP = {
|
40
|
+
'soma': 'somatic',
|
41
|
+
'axon': 'axonal',
|
42
|
+
'dend': 'dendritic',
|
43
|
+
'apic': 'apical',
|
44
|
+
}
|
45
|
+
|
29
46
|
DOMAIN_TO_SWC_ID = {
|
30
47
|
v: k for k, v in SWC_ID_TO_DOMAIN.items()
|
31
48
|
}
|
32
49
|
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
50
|
+
DOMAINS_TO_NEURON = {
|
51
|
+
'soma': 'soma',
|
52
|
+
'perisomatic': 'dend_11',
|
53
|
+
'axon': 'axon',
|
54
|
+
'apic': 'apic',
|
55
|
+
'dend': 'dend',
|
56
|
+
'basal': 'dend_31',
|
57
|
+
'trunk': 'dend_41',
|
58
|
+
'tuft': 'dend_42',
|
59
|
+
'oblique': 'dend_43',
|
60
|
+
'custom': 'dend_5',
|
61
|
+
'reduced': 'dend_8',
|
62
|
+
'undefined': 'dend_0',
|
63
|
+
}
|
47
64
|
|
48
65
|
DOMAINS_TO_COLORS = {
|
49
66
|
'soma': '#E69F00',
|
@@ -60,11 +77,24 @@ DOMAINS_TO_COLORS = {
|
|
60
77
|
'undefined': '#7F7F7F',
|
61
78
|
}
|
62
79
|
|
63
|
-
def
|
80
|
+
def get_swc_idx(domain_name):
    """
    Return the SWC type index for a domain name.

    Names whose part before the first underscore is ``'reduced'`` or
    ``'custom'`` map to the digit 8 or 5 respectively, followed by the
    text after the underscore, converted to an int. Any other name is
    looked up in ``DOMAIN_TO_SWC_ID`` by its part before the first
    underscore, defaulting to 0.
    """
    prefix, _, suffix = domain_name.partition('_')
    leading_digit = {'reduced': '8', 'custom': '5'}.get(prefix)
    if leading_digit is not None:
        return int(leading_digit + suffix)
    return DOMAIN_TO_SWC_ID.get(prefix, 0)
|
66
87
|
|
88
|
+
def get_domain_name(swc_idx):
    """
    Return the domain name for an SWC type index.

    An index whose string form starts with ``'8'`` becomes
    ``'reduced_<rest>'`` and one starting with ``'5'`` becomes
    ``'custom_<rest>'``, where ``<rest>`` is the string form without its
    first character. Anything else is looked up in ``SWC_ID_TO_DOMAIN``,
    defaulting to ``'undefined'``.
    """
    # NOTE(review): a bare 8 or 5 yields 'reduced_' / 'custom_' with an
    # empty suffix because the prefix check runs before the table lookup --
    # confirm this is intended.
    text = str(swc_idx)
    for first_digit, prefix in (('8', 'reduced_'), ('5', 'custom_')):
        if text.startswith(first_digit):
            return prefix + text[1:]
    return SWC_ID_TO_DOMAIN.get(swc_idx, 'undefined')
|
67
94
|
|
95
|
+
def get_domain_color(domain_name):
    """
    Return the color registered for a domain name.

    Only the part of the name before the first underscore is used for the
    lookup in ``DOMAINS_TO_COLORS``; unknown domains get ``'#7F7F7F'``.
    """
    base = domain_name.split('_', 1)[0]
    return DOMAINS_TO_COLORS.get(base, '#7F7F7F')
|
68
98
|
|
69
99
|
def timeit(func):
|
70
100
|
def wrapper(*args, **kwargs):
|
@@ -75,7 +105,6 @@ def timeit(func):
|
|
75
105
|
return result
|
76
106
|
return wrapper
|
77
107
|
|
78
|
-
|
79
108
|
def calculate_lambda_f(distances, diameters, Ra=35.4, Cm=1, frequency=100):
|
80
109
|
"""
|
81
110
|
Calculate the frequency-dependent length constant (lambda_f) according to NEURON's implementation,
|
@@ -118,10 +147,6 @@ def calculate_lambda_f(distances, diameters, Ra=35.4, Cm=1, frequency=100):
|
|
118
147
|
# Return section_L/lam (electrotonic length of the section)
|
119
148
|
return section_L / lam
|
120
149
|
|
121
|
-
if (__name__ == '__main__'):
|
122
|
-
print('Executing as standalone script')
|
123
|
-
|
124
|
-
|
125
150
|
def dynamic_import(module_name, class_name):
|
126
151
|
"""
|
127
152
|
Dynamically import a class from a module.
|
@@ -142,20 +167,17 @@ def dynamic_import(module_name, class_name):
|
|
142
167
|
module = import_module(module_name)
|
143
168
|
return getattr(module, class_name)
|
144
169
|
|
145
|
-
|
146
170
|
def list_folders(path_to_folder):
    """
    Return the names of the sub-directories of ``path_to_folder``.

    The names are sorted case-insensitively.
    """
    names = []
    for entry in os.listdir(path_to_folder):
        if os.path.isdir(os.path.join(path_to_folder, entry)):
            names.append(entry)
    names.sort(key=lambda name: name.lower())
    return names
|
151
175
|
|
152
|
-
|
153
176
|
def list_files(path_to_folder, extension):
    """
    Return the entries of ``path_to_folder`` whose name ends with ``extension``.

    The entries are returned in the order given by ``os.listdir``.
    """
    matching = []
    for entry in os.listdir(path_to_folder):
        if entry.endswith(extension):
            matching.append(entry)
    return matching
|
157
180
|
|
158
|
-
|
159
181
|
def write_file(content: str, path_to_file: str, verbose: bool = True) -> None:
|
160
182
|
"""
|
161
183
|
Write content to a file.
|
@@ -175,13 +197,11 @@ def write_file(content: str, path_to_file: str, verbose: bool = True) -> None:
|
|
175
197
|
f.write(content)
|
176
198
|
print(f"Saved content to {path_to_file}")
|
177
199
|
|
178
|
-
|
179
200
|
def read_file(path_to_file):
    """Open ``path_to_file`` in text mode and return its entire content."""
    with open(path_to_file, 'r') as handle:
        return handle.read()
|
183
204
|
|
184
|
-
|
185
205
|
def download_example_data(path_to_destination, include_templates=True, include_modfiles=True):
|
186
206
|
"""
|
187
207
|
Download and extract specific folders from the DendroTweaks GitHub repository:
|
@@ -240,8 +260,6 @@ def download_example_data(path_to_destination, include_templates=True, include_m
|
|
240
260
|
os.remove(zip_path)
|
241
261
|
print(f"Data downloaded and extracted successfully to {path_to_destination}/.")
|
242
262
|
|
243
|
-
|
244
|
-
|
245
263
|
def apply_dark_theme():
|
246
264
|
"""
|
247
265
|
Apply a dark theme to matplotlib plots.
|
@@ -259,4 +277,71 @@ def apply_dark_theme():
|
|
259
277
|
'ytick.color': 'white',
|
260
278
|
'text.color': 'white',
|
261
279
|
'axes.prop_cycle': plt.cycler(color=plt.cm.tab10.colors), # use standard matplotlib colors
|
262
|
-
})
|
280
|
+
})
|
281
|
+
|
282
|
+
def mse(y_true, y_pred):
    """Return the mean of the squared element-wise differences of the inputs."""
    errors = np.array(y_true) - np.array(y_pred)
    return np.mean(errors ** 2)
|
284
|
+
|
285
|
+
def poly_fit(x, y, max_degree=6, tolerance=1e-6):
    """
    Fit the lowest-degree polynomial that reproduces the data.

    Degrees 0 through ``max_degree`` are tried in increasing order and the
    search stops at the first one whose absolute error is below
    ``tolerance`` at every point. If no degree qualifies, the fit of
    degree ``max_degree`` is returned.

    Returns
    -------
    tuple
        The polynomial coefficients (as produced by ``np.polyfit``) and
        the predicted values at ``x``.
    """
    for degree in range(max_degree + 1):
        coeffs = np.polyfit(x, y, degree)
        y_pred = np.polyval(coeffs, x)
        abs_errors = np.abs(np.array(y) - y_pred)
        # Every point must be within tolerance for the fit to be accepted.
        if np.all(abs_errors < tolerance):
            break
    return coeffs, y_pred
|
295
|
+
|
296
|
+
def step_fit(x, y):
    """
    Fit a single step function with variable-width transition zone.

    The data are sorted by ``x``. Every pair of data points
    ``(start, end)`` with ``start`` before ``end`` is tried as the window
    boundaries: points strictly between them take the mean of the inside
    values, all others take the mean of the outside values (means ignore
    NaNs). The pair with the lowest mean squared error is kept.

    Returns
    -------
    tuple
        ``(best_params, best_pred)`` where ``best_params`` is
        ``(start, end, low_val, high_val)`` -- ``low_val`` being the level
        outside the window and ``high_val`` the level inside -- and
        ``best_pred`` holds the predicted values in sorted-``x`` order.
        Both are ``None`` when no candidate window has a point strictly
        inside it (for example with fewer than three points).
    """
    x = np.array(x)
    y = np.array(y)

    # Work in ascending x order; predictions are returned in this order.
    order = np.argsort(x)
    x = x[order]
    y = y[order]

    best_mse = float('inf')
    best_params = None
    best_pred = None

    n = len(x)
    for i in range(n - 1):
        start = x[i]
        for j in range(i + 1, n):
            end = x[j]
            inside = (x > start) & (x < end)
            outside = ~inside

            # A window needs points on both sides to define two levels.
            if not np.any(inside) or not np.any(outside):
                continue

            high_val = np.nanmean(y[inside])
            low_val = np.nanmean(y[outside])

            pred = np.where(inside, high_val, low_val)
            # Mean squared error of this candidate.
            score = np.mean((y - pred) ** 2)

            if score < best_mse:
                best_mse = score
                best_params = (start, end, low_val, high_val)
                best_pred = pred

    return best_params, best_pred
|
335
|
+
|
336
|
+
DEFAULT_FIT_MODELS = {
|
337
|
+
'poly': {
|
338
|
+
'fit': poly_fit,
|
339
|
+
'score': mse,
|
340
|
+
'complexity': lambda coeffs: len(coeffs) - 1 # degree of polynomial
|
341
|
+
},
|
342
|
+
'step': {
|
343
|
+
'fit': step_fit,
|
344
|
+
'score': mse,
|
345
|
+
'complexity': lambda params: 4
|
346
|
+
}
|
347
|
+
}
|