dendrotweaks-0.3.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56):
  1. dendrotweaks/__init__.py +10 -0
  2. dendrotweaks/analysis/__init__.py +11 -0
  3. dendrotweaks/analysis/ephys_analysis.py +482 -0
  4. dendrotweaks/analysis/morphometric_analysis.py +106 -0
  5. dendrotweaks/membrane/__init__.py +6 -0
  6. dendrotweaks/membrane/default_mod/AMPA.mod +65 -0
  7. dendrotweaks/membrane/default_mod/AMPA_NMDA.mod +100 -0
  8. dendrotweaks/membrane/default_mod/CaDyn.mod +54 -0
  9. dendrotweaks/membrane/default_mod/GABAa.mod +65 -0
  10. dendrotweaks/membrane/default_mod/Leak.mod +27 -0
  11. dendrotweaks/membrane/default_mod/NMDA.mod +72 -0
  12. dendrotweaks/membrane/default_mod/vecstim.mod +76 -0
  13. dendrotweaks/membrane/default_templates/NEURON_template.py +354 -0
  14. dendrotweaks/membrane/default_templates/default.py +73 -0
  15. dendrotweaks/membrane/default_templates/standard_channel.mod +87 -0
  16. dendrotweaks/membrane/default_templates/template_jaxley.py +108 -0
  17. dendrotweaks/membrane/default_templates/template_jaxley_new.py +108 -0
  18. dendrotweaks/membrane/distributions.py +324 -0
  19. dendrotweaks/membrane/groups.py +103 -0
  20. dendrotweaks/membrane/io/__init__.py +11 -0
  21. dendrotweaks/membrane/io/ast.py +201 -0
  22. dendrotweaks/membrane/io/code_generators.py +312 -0
  23. dendrotweaks/membrane/io/converter.py +108 -0
  24. dendrotweaks/membrane/io/factories.py +144 -0
  25. dendrotweaks/membrane/io/grammar.py +417 -0
  26. dendrotweaks/membrane/io/loader.py +90 -0
  27. dendrotweaks/membrane/io/parser.py +499 -0
  28. dendrotweaks/membrane/io/reader.py +212 -0
  29. dendrotweaks/membrane/mechanisms.py +574 -0
  30. dendrotweaks/model.py +1916 -0
  31. dendrotweaks/model_io.py +75 -0
  32. dendrotweaks/morphology/__init__.py +5 -0
  33. dendrotweaks/morphology/domains.py +100 -0
  34. dendrotweaks/morphology/io/__init__.py +5 -0
  35. dendrotweaks/morphology/io/factories.py +212 -0
  36. dendrotweaks/morphology/io/reader.py +66 -0
  37. dendrotweaks/morphology/io/validation.py +212 -0
  38. dendrotweaks/morphology/point_trees.py +681 -0
  39. dendrotweaks/morphology/reduce/__init__.py +16 -0
  40. dendrotweaks/morphology/reduce/reduce.py +155 -0
  41. dendrotweaks/morphology/reduce/reduced_cylinder.py +129 -0
  42. dendrotweaks/morphology/sec_trees.py +1112 -0
  43. dendrotweaks/morphology/seg_trees.py +157 -0
  44. dendrotweaks/morphology/trees.py +567 -0
  45. dendrotweaks/path_manager.py +261 -0
  46. dendrotweaks/simulators.py +235 -0
  47. dendrotweaks/stimuli/__init__.py +3 -0
  48. dendrotweaks/stimuli/iclamps.py +73 -0
  49. dendrotweaks/stimuli/populations.py +265 -0
  50. dendrotweaks/stimuli/synapses.py +203 -0
  51. dendrotweaks/utils.py +239 -0
  52. dendrotweaks-0.3.1.dist-info/METADATA +70 -0
  53. dendrotweaks-0.3.1.dist-info/RECORD +56 -0
  54. dendrotweaks-0.3.1.dist-info/WHEEL +5 -0
  55. dendrotweaks-0.3.1.dist-info/licenses/LICENSE +674 -0
  56. dendrotweaks-0.3.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,265 @@
1
+ from dendrotweaks.morphology.seg_trees import Segment
2
+ from dendrotweaks.stimuli.synapses import Synapse
3
+
4
+ from collections import defaultdict
5
+
6
+ from typing import List
7
+ import numpy as np
8
+
9
# Default kinetic parameters per synapse type. The outer keys are the
# `syn_type` names accepted by Population/Synapse; each inner dict holds the
# defaults used for `Population.kinetic_params` of that type.
KINETIC_PARAMS = dict(
    AMPA=dict(gmax=0.001, tau_rise=0.1, tau_decay=2.5, e=0),
    NMDA=dict(
        gmax=0.7 * 0.001,
        tau_rise=2,
        tau_decay=30,
        e=0,
        gamma=0.062,  # voltage dependence of the magnesium block
        mu=0.28,      # sensitivity of the magnesium block to Mg2+
    ),
    AMPA_NMDA=dict(
        gmax_AMPA=0.001,
        gmax_NMDA=0.7 * 0.001,
        tau_rise_AMPA=0.1,
        tau_decay_AMPA=2.5,
        tau_rise_NMDA=2,
        tau_decay_NMDA=30,
        e=0,
        gamma=0.062,
        mu=0.28,
    ),
    GABAa=dict(gmax=0.001, tau_rise=0.1, tau_decay=8, e=-70),
)
42
+
43
class Population():
    """
    A population of "virtual" presynaptic neurons forming synapses on the
    explicitly modelled postsynaptic neuron.

    The population is defined by the number of synapses N, the segments
    on which the synapses are placed, and the type of synapse. All synapses
    in the population share the same kinetic parameters. Global input parameters
    such as rate, noise, etc. are shared by all synapses in the population,
    however, each synapse receives a unique input spike train.

    Parameters
    ----------
    idx : str
        The index of the population.
    segments : List[Segment]
        The segments on which the synapses are placed.
    N : int
        The number of synapses in the population.
    syn_type : str
        The type of synapse to create, i.e. a key of KINETIC_PARAMS:
        'AMPA', 'NMDA', 'AMPA_NMDA' or 'GABAa'.

    Attributes
    ----------
    idx : str
        The index of the population.
    segments : List[Segment]
        The segments on which the synapses are placed.
    sections : list
        The distinct sections owning `segments`, in first-seen order.
    N : int
        The number of synapses in the population.
    syn_type : str
        The type of synapse to create, i.e. a key of KINETIC_PARAMS.
    synapses : dict
        A dictionary of synapses in the population, mapping a
        ``(section, x)`` location tuple to the list of Synapse objects
        placed there.
    input_params : dict
        The input parameters of the synapses in the population.
    kinetic_params : dict
        The kinetic parameters of the synapses in the population (a private
        copy of the defaults in KINETIC_PARAMS for `syn_type`).

    Raises
    ------
    KeyError
        If `syn_type` is not a key of KINETIC_PARAMS.
    """

    def __init__(self, idx: str, segments: List[Segment], N: int, syn_type: str) -> None:

        self.idx = idx
        self.segments = segments
        # dict.fromkeys deduplicates while keeping first-seen order. A set
        # would order sections by hash (identity-based unless the section
        # class defines __hash__), so the candidate locations -- and hence
        # synapse placement under a fixed NumPy seed -- could vary between runs.
        self.sections = list(dict.fromkeys(seg._section for seg in segments))
        # Segments of the chosen sections that were NOT requested; randomly
        # drawn locations that fall into them are rejected.
        self._excluded_segments = [seg for sec in self.sections for seg in sec
                                   if seg not in segments]
        self.syn_type = syn_type

        self.N = N

        self.synapses = {}

        self.input_params = {
            'rate': 1,
            'noise': 0,
            'start': 100,
            'end': 200,
            'weight': 1,
            'delay': 0
        }

        # Copy the defaults: update_kinetic_params() mutates this dict in
        # place, and without a copy that would rewrite the shared
        # KINETIC_PARAMS entry for every later population of this type.
        self.kinetic_params = dict(KINETIC_PARAMS[syn_type])

    def __repr__(self):
        return f"<Population({self.name}, N={self.N})>"

    @property
    def name(self):
        """A unique name for the population."""
        return f"{self.syn_type}_{self.idx}"


    @property
    def spike_times(self):
        """
        Return the spike times of the synapses in the population.

        Returns
        -------
        dict
            Maps each Synapse object to the list of its spike times.
        """
        spike_times = defaultdict(list)
        for syns in self.synapses.values():
            for syn in syns:
                spike_times[syn].extend(syn.spike_times)
        return dict(spike_times)


    def update_kinetic_params(self, **params):
        """
        Update the kinetic parameters of the synapses.

        The values are stored in self.kinetic_params and pushed onto every
        allocated synapse's NEURON object that has an attribute of that name.

        Parameters
        ----------
        **params : dict
            The parameters to update self.kinetic_params.
            Options are:
            - gmax: the maximum conductance of the synapse
            - tau_rise: the rise time of the synapse
            - tau_decay: the decay time of the synapse
            - e: the reversal potential of the synapse
            - gamma: the voltage dependence of the magnesium block (NMDA only)
            - mu: the sensitivity of the magnesium block to Mg2+ concentration (NMDA only)
        """
        self.kinetic_params.update(params)
        for syns in self.synapses.values():
            for syn in syns:
                for key, value in params.items():
                    # Skip parameters the mechanism does not define.
                    if hasattr(syn._ref_syn, key):
                        setattr(syn._ref_syn, key, value)

    def update_input_params(self, **params):
        """
        Update the input parameters of the synapses and recreate the inputs.

        Parameters
        ----------
        **params : dict
            The parameters to update self.input_params.
            Options are:
            - rate: the rate of the input in Hz
            - noise: the noise level of the input
            - start: the start time of the input
            - end: the end time of the input
            - weight: the weight of the synapse
            - delay: the delay of the synapse
        """
        self.input_params.update(params)
        self.create_inputs()

    # ALLOCATION METHODS

    def _choose_synapse_locations(self):
        """
        Randomly choose N synapse locations, with replacement.

        Candidates are 1001 evenly spaced positions x in [0, 1] on each
        section, excluding those that fall into a non-requested segment.

        Returns
        -------
        list of (section, x) tuples

        Raises
        ------
        ValueError
            If N > 0 but no candidate location lies in the requested segments.
        """
        valid_locs = [(sec, x) for sec in self.sections
                      for x in np.linspace(0, 1, 1001)
                      if sec(x) not in self._excluded_segments]

        if self.N > 0 and not valid_locs:
            raise ValueError(
                f"No valid synapse locations for population {self.name}: "
                "every candidate location lies outside the requested segments."
            )

        syn_locs = [valid_locs[np.random.choice(len(valid_locs))] for _ in range(self.N)]

        return syn_locs


    def allocate_synapses(self, syn_locs=None):
        """
        Create the synapses of the population.

        Any previously allocated synapses are replaced, and the current
        kinetic parameters are applied to the new ones.

        Parameters
        ----------
        syn_locs : list of (section, x) tuples, optional
            Where to place the synapses. If None, N locations are drawn at
            random. Repeated locations receive several synapses.
        """
        if syn_locs is None:
            syn_locs = self._choose_synapse_locations()
        syn_type = self.syn_type
        self.synapses = {(sec, x) : [] for sec, x in syn_locs}
        for sec, x in syn_locs:
            self.synapses[(sec, x)].append(Synapse(syn_type, sec, x))

        self.update_kinetic_params(**self.kinetic_params)



    # CREATION METHODS

    def create_inputs(self):
        """
        Create the input spike trains and connections of every synapse.

        'start' is used as the delay of each spike train and
        'end' - 'start' as its duration; 'delay' and 'weight' configure the
        connection. This method should be called after the synapses have
        been allocated.
        """
        for syns in self.synapses.values():
            for syn in syns:

                syn.create_stim(
                    rate=self.input_params['rate'],
                    noise=self.input_params['noise'],
                    duration=self.input_params['end'] - self.input_params['start'],
                    delay=self.input_params['start']
                )

                syn.create_con(
                    delay=self.input_params['delay'],
                    weight=self.input_params['weight']
                )


    def to_dict(self):
        """
        Convert the population to a dictionary (parameter dicts are copied).
        """
        return {
            'name': self.name,
            'syn_type': self.syn_type,
            'N': self.N,
            'input_params': {**self.input_params},
            'kinetic_params': {**self.kinetic_params},
        }

    @property
    def flat_synapses(self):
        """
        Return a flat list of synapses.
        """
        return [syn for syns in self.synapses.values() for syn in syns]

    def to_csv(self):
        """
        Prepare the data about the location of synapses for saving to a CSV file.

        Returns
        -------
        dict
            Column name -> list of values, one entry per synapse.
        """
        flat_synapses = self.flat_synapses
        return {
            'syn_type': [self.syn_type] * len(flat_synapses),
            'name': [self.name] * len(flat_synapses),
            'sec_idx': [syn.sec.idx for syn in flat_synapses],
            'loc': [syn.loc for syn in flat_synapses],
        }


    def clean(self):
        """
        Clear the synapses and connections from the simulator.

        Drops the stimulus and NetCon references of every synapse, then
        empties self.synapses.
        """
        for syns in self.synapses.values():
            for syn in syns:
                if syn._ref_stim:
                    syn._clear_stim()
                if syn._ref_con:
                    syn._clear_con()
        self.synapses.clear()
264
+
265
+
@@ -0,0 +1,203 @@
1
+ from typing import List
2
+ from neuron import h
3
+ import numpy as np
4
+ from dendrotweaks.morphology.seg_trees import Segment
5
+
6
class Synapse():
    """
    A synapse object that can be placed on a section of a neuron.

    Contains references to the NEURON synapse object, the stimulus object
    (a VecStim playing a Vector of spike times), and the connection object
    (NetCon).

    Parameters
    ----------
    syn_type : str
        The type of synapse to create e.g. 'AMPA', 'NMDA', 'AMPA_NMDA', 'GABAa'.
        It is looked up by name as an attribute of NEURON's ``h``.
    sec : Section
        The section on which the synapse is placed.
    loc : float
        The location on the section where the synapse is placed, between 0 and 1.

    Attributes
    ----------
    sec : Section
        The section on which the synapse is placed.
    loc : float
        The location on the section where the synapse is placed, between 0 and 1.
    """

    def __init__(self, syn_type: str, sec, loc=0.5) -> None:
        """
        Creates a new synapse object and its NEURON point process.
        """
        # e.g. h.AMPA: the NEURON class named by syn_type.
        self._Model = getattr(h, syn_type)
        self.sec = sec
        self.loc = loc

        self._ref_syn = self._Model(self.seg._ref)
        # [VecStim, Vector] once create_stim() has run, otherwise None.
        self._ref_stim = None
        # NetCon once create_con() has run, otherwise None.
        self._ref_con = None

    @property
    def seg(self):
        """
        The segment on which the synapse is placed.
        """
        return self.sec(self.loc)

    def __repr__(self):
        return f"<Synapse({self.sec}({self.loc:.3f}))>"

    @property
    def spike_times(self):
        """
        The spike times played by the stimulus (via ``Vector.to_python()``),
        or an empty list if no stimulus has been created.
        """
        if self._ref_stim is not None:
            return self._ref_stim[1].to_python()
        return []

    def _clear_stim(self):
        """
        Clears the stimulus (VecStim) object and its spike-time Vector.

        Does nothing if no stimulus exists.
        """
        if self._ref_stim is None:
            return
        # Empty the list in place (as the original pops did) so any alias of
        # it no longer holds the NEURON objects, then drop our reference.
        self._ref_stim.clear()
        self._ref_stim = None

    def create_stim(self, **kwargs):
        """
        Creates a stimulus (VecStim) for the synapse, replacing any existing one.

        Parameters
        ----------
        **kwargs : dict
            Keyword arguments for the create_spike_times function
            (rate, noise, duration, delay).
        """

        if self._ref_stim is not None:
            self._clear_stim()

        spike_times = create_spike_times(**kwargs)
        spike_vec = h.Vector(spike_times)
        stim = h.VecStim()
        stim.play(spike_vec)

        # Keep the Vector referenced alongside the VecStim that plays it.
        self._ref_stim = [stim, spike_vec]

    def _clear_con(self):
        """
        Clears the connection (NetCon) object.
        """
        self._ref_con = None

    def create_con(self, delay, weight):
        """
        Create a connection (NetCon) between the stimulus and the synapse,
        replacing any existing one.

        Parameters
        ----------
        delay : int
            The delay of the connection, in ms.
        weight : float
            The weight of the connection.

        Raises
        ------
        RuntimeError
            If no stimulus exists yet (call `create_stim` first).
        """
        if self._ref_stim is None:
            raise RuntimeError(
                "Cannot create a NetCon without a stimulus; "
                "call create_stim() before create_con()."
            )
        if self._ref_con is not None:
            self._clear_con()
        # NetCon(source, target, threshold, delay, weight)
        self._ref_con = h.NetCon(self._ref_stim[0],
                                 self._ref_syn,
                                 0,
                                 delay,
                                 weight)
115
+
116
+
117
def create_spike_times(rate=1, noise=1, duration=300, delay=0):
    """
    Build a spike train whose regularity is controlled by `noise`.

    Parameters
    ----------
    rate : float
        Mean firing rate of the train, in Hz.
    noise : float
        Regularity control in [0, 1]: exactly 1 gives a Poisson process,
        any other value a regular train jittered in proportion to `noise`
        (0 means perfectly regular).
    duration : int
        Length of the window in which spikes are generated, in ms.
    delay : int
        Offset added to every spike time, in ms.

    Returns
    -------
    np.array
        The spike times, in ms.
    """
    if noise != 1:
        train = generate_jittered_spikes(rate, duration, noise)
    else:
        train = generate_poisson_process(rate, duration)
    return train + delay
143
+
144
+
145
def generate_poisson_process(lam, dur):
    """
    Generate a homogeneous Poisson process.

    The spike count is drawn from a Poisson distribution with mean
    ``lam * dur / 1000`` and the spikes are placed uniformly at random in
    the window, then sorted. This is an exact sample of a Poisson process.

    Parameters
    ----------
    lam : float
        The rate parameter (lambda) of the Poisson process, in Hz.
    dur : int
        The total time to run the simulation for, in ms.

    Returns
    -------
    np.array
        The sorted spike times as a vector, in ms, within [0, dur].

    Raises
    ------
    ValueError
        If `lam` or `dur` is negative.
    """
    if lam < 0:
        raise ValueError(f"Poisson rate must be non-negative, got {lam}")
    if dur < 0:
        raise ValueError(f"Duration must be non-negative, got {dur}")

    dur_s = dur / 1000
    # Sampling the count first keeps it truly Poisson-distributed. Drawing a
    # fixed int(lam * dur_s) exponential intervals and truncating them would
    # cap the count below its mean, and would give no spikes at all whenever
    # lam * dur_s < 1 (e.g. 1 Hz over 100 ms).
    n_spikes = np.random.poisson(lam * dur_s)
    spike_times = np.sort(np.random.uniform(0, dur_s, n_spikes))
    spike_times_ms = spike_times * 1000

    return spike_times_ms
168
+
169
+
170
def generate_jittered_spikes(rate, dur, noise):
    """
    Generate a jittered spike train.

    A regular train with period 1/rate starting at 0 is perturbed by
    Gaussian noise with standard deviation noise/rate; spikes pushed outside
    [0, dur] are dropped and the rest are sorted.

    Parameters
    ----------
    rate : float
        The rate of the spike train, in Hz. A rate of 0 yields no spikes.
    dur : int
        The total time to run the simulation for, in ms.
    noise : float
        A parameter between 0 and 1 that controls the regularity of the spike train.
        0 corresponds to a regular spike train. 1 corresponds to a Poisson process.

    Returns
    -------
    np.array
        The spike times as a vector, in ms.

    Raises
    ------
    ValueError
        If `rate` is negative.
    """
    if rate < 0:
        raise ValueError(f"Spike rate must be non-negative, got {rate}")
    if rate == 0:
        # A silent input; also avoids dividing by zero for the period.
        return np.array([])

    dur_s = dur / 1000
    spike_times = np.arange(0, dur_s, 1/rate)

    # Add noise
    noise_values = np.random.normal(0, noise/rate, len(spike_times))
    spike_times += noise_values

    # Ensure spike times are within the duration and sort them
    spike_times = spike_times[(spike_times >= 0) & (spike_times <= dur_s)]
    spike_times.sort()

    spike_times_ms = spike_times * 1000

    return spike_times_ms
dendrotweaks/utils.py ADDED
@@ -0,0 +1,239 @@
1
+ """
2
+ Utility functions for dendrotweaks package.
3
+ """
4
+
5
+ import time
6
+ import numpy as np
7
+ import os
8
+ import zipfile
9
+ import urllib.request
10
+ import matplotlib.pyplot as plt
11
+
12
# SWC type identifier -> domain name. Two-digit codes (11, 31, 41-43)
# appear to refine the single-digit code they start with (1, 3, 4).
SWC_ID_TO_DOMAIN = dict([
    (0, 'undefined'),
    (1, 'soma'),
    (11, 'perisomatic'),
    (2, 'axon'),
    (3, 'dend'),
    (31, 'basal'),
    (4, 'apic'),
    (41, 'trunk'),
    (42, 'tuft'),
    (43, 'oblique'),
    (5, 'custom'),
    (6, 'neurite'),
    (7, 'glia'),
    (8, 'reduced'),
])
28
+
29
# Inverse lookup of SWC_ID_TO_DOMAIN: domain name -> SWC type identifier.
DOMAIN_TO_SWC_ID = dict(zip(SWC_ID_TO_DOMAIN.values(), SWC_ID_TO_DOMAIN.keys()))
32
+
33
def get_swc_idx(domain_name):
    """
    Return the SWC type identifier for a domain name.

    'reduced_<i>' / 'custom_<i>' become the integer formed by prefixing <i>
    with 8 / 5 (e.g. 'reduced_3' -> 83); a bare 'reduced' / 'custom' gives
    8 / 5. Any other name is looked up in DOMAIN_TO_SWC_ID by the part
    before its first underscore, falling back to 0.
    """
    base_domain, _, suffix = domain_name.partition('_')
    numbered_prefixes = {'reduced': '8', 'custom': '5'}
    if base_domain in numbered_prefixes:
        return int(numbered_prefixes[base_domain] + suffix)
    return DOMAIN_TO_SWC_ID.get(base_domain, 0)
40
+
41
def get_domain_name(swc_idx):
    """
    Return the domain name for an SWC type identifier (inverse of get_swc_idx).

    Identifiers whose decimal form starts with 8 / 5 map to 'reduced_<rest>'
    / 'custom_<rest>' (e.g. 83 -> 'reduced_3'). The bare codes 8 and 5 map to
    'reduced' and 'custom'; previously they produced 'reduced_' and 'custom_',
    which did not round-trip through get_swc_idx. All other identifiers are
    looked up in SWC_ID_TO_DOMAIN, falling back to 'undefined'.
    """
    text = str(swc_idx)
    numbered_bases = {'8': 'reduced', '5': 'custom'}
    base = numbered_bases.get(text[:1])
    if base is not None:
        suffix = text[1:]
        return f"{base}_{suffix}" if suffix else base
    return SWC_ID_TO_DOMAIN.get(swc_idx, 'undefined')
47
+
48
# Base domain name -> hex color used when plotting that domain.
DOMAINS_TO_COLORS = dict(
    soma='#E69F00',
    apic='#0072B2',
    dend='#019E73',
    basal='#31A354',
    axon='#F0E442',
    trunk='#56B4E9',
    tuft='#A55194',
    oblique='#8C564B',
    perisomatic='#D55E00',
    custom='#D62728',
    reduced='#E377C2',
    undefined='#7F7F7F',
)
62
+
63
def get_domain_color(domain_name):
    """
    Return the hex color for a domain, chosen by the part of the name before
    the first underscore (so 'reduced_3' gets the 'reduced' color); unknown
    domains get grey.
    """
    fallback = '#7F7F7F'
    base_domain = domain_name.split('_', 1)[0]
    return DOMAINS_TO_COLORS.get(base_domain, fallback)
66
+
67
+
68
+
69
def timeit(func):
    """
    Decorator that prints the elapsed time of every call to `func`.

    The wrapped function's return value is passed through unchanged, and its
    name and docstring are preserved on the wrapper.
    """
    from functools import wraps

    @wraps(func)  # keep __name__, __doc__, etc. of the decorated function
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        result = func(*args, **kwargs)
        end = time.perf_counter()
        print(f" Elapsed time: {round(end-start, 3)} seconds")
        return result
    return wrapper
77
+
78
+
79
def calculate_lambda_f(distances, diameters, Ra=35.4, Cm=1, frequency=100):
    """
    Calculate the frequency-dependent length constant (lambda_f) according to NEURON's implementation,
    using 3D point data for accurate representation of varying diameter frusta.

    The electrotonic length of the section is accumulated frustum by
    frustum; lambda_f is the section length divided by that electrotonic
    length.

    Args:
        distances (list/array): Cumulative euclidean distances between 3D points along the section from 0 to section length
        diameters (list/array): Corresponding diameters at each position in micrometers
        Ra (float): Axial resistance in ohm*cm
        Cm (float): Specific membrane capacitance in µF/cm²
        frequency (float): Frequency in Hz

    Returns:
        float: Lambda_f in micrometers

    Raises:
        ValueError: If fewer than 2 points are given, or if `distances` and
            `diameters` differ in length.
    """
    if len(distances) < 2 or len(diameters) < 2:
        raise ValueError("At least 2 points are required for 3D calculation")

    if len(distances) != len(diameters):
        raise ValueError("distances and diameters must have the same length")

    distances = np.asarray(distances, dtype=float)
    diameters = np.asarray(diameters, dtype=float)
    section_L = distances[-1]

    # Each frustum contributes its length divided by sqrt of the SUM of its
    # two end diameters; the sqrt(2) factor below turns that sum into the
    # average diameter.
    frustum_lengths = np.diff(distances)
    lam = np.sum(frustum_lengths / np.sqrt(diameters[:-1] + diameters[1:]))

    # Apply the frequency-dependent factor: lam becomes the (dimensionless)
    # electrotonic length of the section.
    lam *= np.sqrt(2) * 1e-5 * np.sqrt(4 * np.pi * frequency * Ra * Cm)

    # Section length / electrotonic length = lambda_f, in micrometers.
    return section_L / lam
120
+
121
if __name__ == '__main__':
    # Reached only when this module is run directly, not when imported.
    print('Executing as standalone script')
123
+
124
+
125
def dynamic_import(module_name, class_name):
    """
    Dynamically import a class from a module.

    The relative directory 'app/src' is added to ``sys.path`` (once, not on
    every call) so that modules located there can be imported.

    Parameters
    ----------
    module_name : str
        Name of the module to import.
    class_name : str
        Name of the class to import.

    Returns
    -------
    object
        The attribute `class_name` of the imported module.

    Raises
    ------
    ImportError
        If the module cannot be imported.
    AttributeError
        If the module has no attribute `class_name`.
    """

    from importlib import import_module

    import sys
    # Appending unconditionally grew sys.path by one entry per call.
    if 'app/src' not in sys.path:
        sys.path.append('app/src')
    print(f"Importing class {class_name} from module {module_name}.py")
    module = import_module(module_name)
    return getattr(module, class_name)
144
+
145
+
146
def list_folders(path_to_folder):
    """
    Return the names of the sub-directories of `path_to_folder`, sorted
    case-insensitively.
    """
    def is_dir(name):
        return os.path.isdir(os.path.join(path_to_folder, name))

    return sorted(filter(is_dir, os.listdir(path_to_folder)),
                  key=lambda name: name.lower())
151
+
152
+
153
def list_files(path_to_folder, extension):
    """
    Return the entries of `path_to_folder` whose names end with `extension`,
    in the order given by os.listdir.
    """
    entries = os.listdir(path_to_folder)
    return list(filter(lambda name: name.endswith(extension), entries))
157
+
158
+
159
def write_file(content: str, path_to_file: str, verbose: bool = True) -> None:
    """
    Write content to a file, creating missing parent directories.

    Parameters
    ----------
    content : str
        The content to write to the file.
    path_to_file : str
        The path to the file.
    verbose : bool, optional
        Whether to print a message after writing the file. The default is True.
    """
    parent_dir = os.path.dirname(path_to_file)
    # A bare filename has no directory part, and os.makedirs('') raises,
    # so only create a parent directory when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(path_to_file, 'w') as f:
        f.write(content)
    if verbose:
        print(f"Saved content to {path_to_file}")
177
+
178
+
179
def read_file(path_to_file):
    """Return the entire text content of the file at `path_to_file`."""
    with open(path_to_file, 'r') as handle:
        return handle.read()
183
+
184
+
185
def download_example_data(path_to_destination):
    """
    Download the examples subfolder from the DendroTweaks GitHub repository.

    The main branch is downloaded as a zip archive into
    `path_to_destination`. Only members under ``examples/`` are extracted
    there, and the archive is removed afterwards, also if extraction fails.

    Parameters
    ----------
    path_to_destination : str
        The path to the destination folder where the examples will be downloaded.

    Raises
    ------
    ValueError
        If an archive member would be extracted outside `path_to_destination`.
    """
    import shutil

    os.makedirs(path_to_destination, exist_ok=True)

    repo_url = "https://github.com/Poirazi-Lab/DendroTweaks/archive/refs/heads/main.zip"
    zip_path = os.path.join(path_to_destination, "examples.zip")
    dest_root = os.path.realpath(path_to_destination)

    print(f"Downloading examples from {repo_url}...")
    try:
        urllib.request.urlretrieve(repo_url, zip_path)

        print(f"Extracting examples to {path_to_destination}")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            for member in zip_ref.namelist():
                if not member.startswith("DendroTweaks-main/examples/"):
                    continue
                # Extract the file with the correct path
                member_path = os.path.relpath(member, "DendroTweaks-main/examples")
                target_path = os.path.join(path_to_destination, member_path)
                # Refuse members (e.g. containing '..') that would land
                # outside the destination folder.
                real_target = os.path.realpath(target_path)
                if os.path.commonpath([dest_root, real_target]) != dest_root:
                    raise ValueError(
                        f"Refusing to extract {member!r} outside {path_to_destination}"
                    )
                if member.endswith('/'):
                    os.makedirs(target_path, exist_ok=True)
                else:
                    os.makedirs(os.path.dirname(target_path), exist_ok=True)
                    # Stream to disk rather than reading the whole member into memory.
                    with zip_ref.open(member) as source, open(target_path, 'wb') as target:
                        shutil.copyfileobj(source, target)
    finally:
        if os.path.exists(zip_path):
            os.remove(zip_path)  # Clean up the zip file

    print(f"Examples downloaded successfully to {path_to_destination}/.")
219
+
220
+
221
+
222
def apply_dark_theme():
    """
    Switch matplotlib to a dark look for all subsequent plots.

    Starts from matplotlib's built-in 'dark_background' style, then uses a
    #131416 figure/axes background, white axes, labels, ticks and text, and
    the tab10 color cycle.
    """
    background = '#131416'
    foreground = 'white'

    plt.style.use('dark_background')

    overrides = {
        'figure.facecolor': background,
        'axes.facecolor': background,
        'axes.edgecolor': foreground,
        'axes.labelcolor': foreground,
        'xtick.color': foreground,
        'ytick.color': foreground,
        'text.color': foreground,
        # use standard matplotlib colors
        'axes.prop_cycle': plt.cycler(color=plt.cm.tab10.colors),
    }
    plt.rcParams.update(overrides)