ppdmod 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ppdmod/__init__.py +1 -0
- ppdmod/base.py +225 -0
- ppdmod/components.py +557 -0
- ppdmod/config/standard_parameters.toml +290 -0
- ppdmod/data.py +485 -0
- ppdmod/fitting.py +546 -0
- ppdmod/options.py +164 -0
- ppdmod/parameter.py +152 -0
- ppdmod/plot.py +1241 -0
- ppdmod/utils.py +575 -0
- ppdmod-2.0.0.dist-info/METADATA +68 -0
- ppdmod-2.0.0.dist-info/RECORD +15 -0
- ppdmod-2.0.0.dist-info/WHEEL +5 -0
- ppdmod-2.0.0.dist-info/licenses/LICENSE +21 -0
- ppdmod-2.0.0.dist-info/top_level.txt +1 -0
ppdmod/fitting.py
ADDED
@@ -0,0 +1,546 @@
|
|
1
|
+
import sys
|
2
|
+
from multiprocessing import Pool
|
3
|
+
from pathlib import Path
|
4
|
+
from typing import Any, List, Tuple
|
5
|
+
|
6
|
+
import dynesty.utils as dyutils
|
7
|
+
import numpy as np
|
8
|
+
from dynesty import DynamicNestedSampler
|
9
|
+
from numpy.typing import NDArray
|
10
|
+
|
11
|
+
from .base import Component
|
12
|
+
from .data import get_weights
|
13
|
+
from .options import OPTIONS
|
14
|
+
from .parameter import Parameter
|
15
|
+
from .utils import compare_angles, compute_t3, compute_vis, get_band_indices
|
16
|
+
|
17
|
+
# Handle to this module; param_transform looks up the functions named in
# OPTIONS.fit.conditions on it via getattr.
CURRENT_MODULE = sys.modules[__name__]
|
18
|
+
|
19
|
+
|
20
|
+
def get_fit_params(components: List[Component], key: str | None = None) -> NDArray[Any]:
    """Gets the free parameters to be fitted from the components.

    The free, non-shared parameters of every component come first (in
    component order), followed by the free, shared parameters, which are read
    from the last component only (the same convention ``get_labels`` and
    ``set_components_from_theta`` use).

    Parameters
    ----------
    components : list of Component
        The components to be used in the model.
    key : str, optional
        If a key is provided, that attribute of each parameter is returned
        instead of the parameter itself (``None`` where it is missing).

    Returns
    -------
    params : numpy.typing.NDArray
        The parameters (or the requested attribute of each). Empty if
        ``components`` is empty.
    """
    params = []
    for component in components:
        params.extend(component.get_params(free=True, shared=False).values())

    # Only the last component's shared parameters are used, so each shared
    # parameter enters the vector exactly once.
    if components:
        params.extend(components[-1].get_params(free=True, shared=True).values())

    if key is not None:
        return np.array([getattr(param, key, None) for param in params])

    return np.array(params)
|
50
|
+
|
51
|
+
|
52
|
+
def get_labels(components: List[Component]) -> List[str]:
    r"""Gets the labels of the free parameters from the components.

    Non-shared parameters are labelled ``"{name}-{index}"``, where ``index``
    is the component's position in ``components``. Shared parameters are
    labelled ``"{name}-\mathrm{sh}"``; they are taken from the last
    component only and appended after all non-shared labels.

    Parameters
    ----------
    components : list of Component
        The components to be used in the model.

    Returns
    -------
    labels : list of str
        The labels, in the same order as the parameters of ``get_fit_params``.
        Empty if ``components`` is empty.
    """
    labels = []
    for index, component in enumerate(components):
        labels.extend(f"{name}-{index}" for name in component.get_params(free=True))

    if components:
        labels.extend(
            rf"{name}-\mathrm{{sh}}"
            for name in components[-1].get_params(free=True, shared=True)
        )

    return labels
|
77
|
+
|
78
|
+
|
79
|
+
def get_priors(components: List[Component]) -> NDArray[Any]:
    """Collects the prior limits of every free parameter.

    Returns
    -------
    priors : numpy.typing.NDArray
        One row per free parameter (ordered as in ``get_fit_params``) holding
        what ``Parameter.get_limits`` returns; ``transform_uniform_prior``
        reads column 0 as the lower and column 1 as the upper limit.
    """
    limits = [fit_param.get_limits() for fit_param in get_fit_params(components)]
    return np.array(limits)
|
82
|
+
|
83
|
+
|
84
|
+
def get_units(components: List[Component]) -> NDArray[Any]:
    """Gets the ``unit`` attribute of every free parameter (``None`` if absent)."""
    units = get_fit_params(components, key="unit")
    return units
|
87
|
+
|
88
|
+
|
89
|
+
def get_theta(components: List[Component]) -> NDArray[Any]:
    """Gets the current ``value`` of every free parameter as the theta vector."""
    theta = get_fit_params(components, key="value")
    return theta
|
92
|
+
|
93
|
+
|
94
|
+
def set_components_from_theta(
    theta: NDArray[Any], uniform: NDArray[Any] | None = None
) -> List[Component]:
    """Sets copies of the model components from a theta vector.

    Parameters
    ----------
    theta : numpy.typing.NDArray
        The parameter values, ordered as ``get_fit_params`` returns them:
        the free non-shared parameters of each component, followed by the
        shared parameters of the last component.
    uniform : numpy.typing.NDArray, optional
        The unit-cube samples matching ``theta``. If given and non-empty,
        they are stored on each parameter's ``uniform`` attribute.

    Returns
    -------
    components : list of Component
        Copies of ``OPTIONS.model.components`` with the values set.
    """
    theta = np.asarray(theta)
    uniform = np.array([]) if uniform is None else np.asarray(uniform)
    components = [component.copy() for component in OPTIONS.model.components]

    # Shared parameter names are read directly from the last component (the
    # source get_labels uses for its shared labels) rather than by filtering
    # labels for the substring "sh", which also matches non-shared names.
    shared_names = list(components[-1].get_params(free=True, shared=True))
    nshared = len(shared_names)

    split = theta.size - nshared
    values, shared_values = theta[:split], theta[split:]
    usplit = uniform.size - nshared
    uniforms, shared_uniforms = uniform[:usplit], uniform[usplit:]

    # Plain Python lists with a running position avoid O(n^2) ``pop(0)``.
    value_list, uniform_list = values.tolist(), uniforms.tolist()
    position = 0
    for component in components:
        for param in component.get_params(free=True).values():
            param.value = value_list[position]
            param.free = True
            if uniforms.size != 0:
                param.uniform = uniform_list[position]
            position += 1

        for index, (name, value) in enumerate(zip(shared_names, shared_values)):
            if hasattr(component, name):
                param = getattr(component, name)
                param.value = value
                param.free = param.shared = True

                if shared_uniforms.size != 0:
                    param.uniform = shared_uniforms[index]

    return components
|
131
|
+
|
132
|
+
|
133
|
+
def compute_residuals(
    data: NDArray[Any], model_data: NDArray[Any], kind: str = "linear"
) -> NDArray[Any]:
    """Computes the residuals between data and model.

    Parameters
    ----------
    data : numpy.typing.NDArray
        The observed values.
    model_data : numpy.typing.NDArray
        The model values.
    kind : str, optional
        ``"periodic"`` treats both inputs as angles in degrees: they are
        converted to radians, compared with ``compare_angles`` and the result
        is converted back to degrees. Any other value returns the plain
        difference ``data - model_data``.
    """
    if kind != "periodic":
        return data - model_data

    data_rad, model_rad = np.deg2rad(data), np.deg2rad(model_data)
    return np.rad2deg(compare_angles(data_rad, model_rad))
|
140
|
+
|
141
|
+
|
142
|
+
def compute_chi_sq(
    data: NDArray[Any],
    sigma_sq: NDArray[Any],
    model: NDArray[Any],
    kind: str = "linear",
) -> float:
    """Computes the chi square of the model against the data.

    ``chi_sq = sum(residuals**2 / sigma_sq)``, where the residuals come from
    ``compute_residuals`` with the given ``kind``.
    """
    residuals = compute_residuals(data, model, kind)
    return np.sum(np.square(residuals) / sigma_sq)
|
150
|
+
|
151
|
+
|
152
|
+
def compute_loglike(
    data: NDArray[Any],
    error: NDArray[Any],
    model: NDArray[Any],
    kind: str = "linear",
    lnf: float | None = None,
) -> float:
    r"""Computes the Gaussian log-likelihood of the model given the data.

    .. math::

        \ln L = -\frac{1}{2}\left[\chi^2 + \sum_i \ln(2\pi s_i^2)\right],
        \qquad s_i^2 = \sigma_i^2 + m_i^2 e^{2 \ln f}

    Parameters
    ----------
    data : numpy.typing.NDArray
        The real data.
    error : numpy.typing.NDArray
        The real data's error.
    model : numpy.typing.NDArray
        The model data.
    kind : str, optional
        The method to determine the residuals of the dataset.
        Either "linear" or "periodic". Default is "linear".
    lnf : float, optional
        Log of a fractional error term. If given, ``model**2 * exp(2 * lnf)``
        is added to the variance.

    Returns
    -------
    loglike : float
        The log-likelihood.
    """
    sn = np.square(error)
    if lnf is not None:
        sn = sn + np.square(model) * np.exp(2 * lnf)

    chi_sq = compute_chi_sq(data, sn, model, kind)
    # sum_i ln(2 pi s_i^2) = N ln(2 pi) + sum_i ln(s_i^2). Adding N ln(2 pi)
    # to every element before summing would count the constant N times.
    lnorm = np.sum(np.log(2 * np.pi * sn))
    return -0.5 * (chi_sq + lnorm)
|
184
|
+
|
185
|
+
|
186
|
+
def compute_observables(
    components: List[Component],
    time: NDArray[Any] | None = None,
    wavelength: NDArray[Any] | None = None,
) -> Tuple[NDArray[Any], NDArray[Any], NDArray[Any]]:
    """Calculates the model observables (flux, visibilities, closure phases).

    For every epoch, the complex visibilities of all components are summed at
    the (u, v) coordinates of the visibility data (``vis2`` if it is among
    ``OPTIONS.fit.data``, otherwise ``vis``) and of the closure-phase data.

    Parameters
    ----------
    components : list of Component
        The components to be used in the model.
    time : optional
        The number of epochs; the epochs ``range(time)`` are evaluated.
        Defaults to ``OPTIONS.data.nt``. (Although annotated as an array, the
        value is passed to ``range``.)
    wavelength : numpy.typing.NDArray, optional
        The wavelengths; defaults to ``OPTIONS.fit.wls``.

    Returns
    -------
    flux_model : numpy.typing.NDArray
        Column 0 of the summed visibilities (presumably the zero-frequency
        term — confirm), tiled to the last-axis length of
        ``OPTIONS.data.flux.val`` and reduced to its real part when non-empty.
    vis_model : numpy.typing.NDArray
        ``compute_vis`` of the remaining columns.
    t3_model : numpy.typing.NDArray
        ``compute_t3`` of the triangle visibilities, without its first column.
    """
    if wavelength is None:
        wavelength = OPTIONS.fit.wls
    ntimes = OPTIONS.data.nt if time is None else time
    times = range(ntimes)

    vis = OPTIONS.data.vis2 if "vis2" in OPTIONS.fit.data else OPTIONS.data.vis
    t3 = OPTIONS.data.t3

    def _summed_vis(ucoord, vcoord, epoch):
        """Sums the components' complex visibilities at the given coordinates."""
        per_component = [
            comp.compute_complex_vis(ucoord, vcoord, epoch, wavelength)
            for comp in components
        ]
        return np.sum(per_component, axis=0)

    vis_per_epoch, t3_per_epoch = [], []
    for epoch in times:
        vis_per_epoch.append(_summed_vis(vis.u[epoch], vis.v[epoch], epoch))
        t3_per_epoch.append(_summed_vis(t3.u[epoch], t3.v[epoch], epoch))

    complex_vis = np.array(vis_per_epoch)
    complex_t3 = np.array(t3_per_epoch)

    t3_model = np.array(
        [
            compute_t3(complex_t3[epoch], OPTIONS.data.t3.i123[epoch])[:, 1:]
            for epoch in times
        ]
    )
    flux_model = np.array([complex_vis[epoch, :, 0].reshape(-1, 1) for epoch in times])
    vis_model = np.array([compute_vis(complex_vis[epoch, :, 1:]) for epoch in times])

    if flux_model.size > 0:
        ncols = OPTIONS.data.flux.val.shape[-1]
        flux_model = np.array(
            [np.tile(flux_model[epoch], ncols).real for epoch in times]
        )

    return flux_model, vis_model, t3_model
|
242
|
+
|
243
|
+
|
244
|
+
def compute_nband_fit_chi_sq(
    model_data: NDArray[Any],
    ndim: int,
    reduced: bool = False,
) -> float:
    """Calculates the log-likelihood of a flux-only (N-band) model.

    Parameters
    ----------
    model_data : numpy.typing.NDArray
        The model's total flux; after squeezing it must match the shape of
        the squeezed ``OPTIONS.data.flux.val``.
    ndim : int
        The number of (parameter) dimensions. One parameter is treated as
        fixed, so ``ndim - 1`` is used for the degrees of freedom.
    reduced : bool, optional
        Whether to divide the result by the degrees of freedom.

    Returns
    -------
    chi_sq : float
        The log-likelihood from ``compute_loglike`` (optionally reduced).
    """
    # NOTE: The -1 here indicates that one of the parameters is actually fixed
    flux, ndim = OPTIONS.data.flux, ndim - 1
    val, err = flux.val.squeeze(), flux.err.squeeze()
    # getmaskarray always gives a full boolean array, so ``~mask`` selects
    # elements correctly even when the data carry no mask (``nomask``).
    mask = np.ma.getmaskarray(val)
    chi_sq = compute_loglike(
        val.compressed(),
        err.compressed(),
        model_data.squeeze()[~mask],
    )

    if reduced:
        # Only the unmasked data points count towards the degrees of freedom.
        return chi_sq / (val.compressed().size - ndim)

    return chi_sq
|
278
|
+
|
279
|
+
|
280
|
+
def compute_interferometric_loglike(
    components: List[Component],
) -> Tuple[float, NDArray[Any]]:
    """Calculates the weighted log-likelihood of the interferometric model.

    For every observable in ``OPTIONS.fit.data`` and every band in
    ``OPTIONS.fit.bands``, the log-likelihood of the unmasked data is
    computed. Closure phases (``"t3"``) use periodic residuals; ``"vis2"``
    data are compared against the ``"vis"`` model. The results are combined
    with the band and general weights from ``get_weights``.

    Parameters
    ----------
    components : list of Component
        The components to be used in the model.

    Returns
    -------
    loglike : float
        The weighted total log-likelihood.
    loglikes : numpy.typing.NDArray
        The individual log-likelihoods, one row per observable in
        ``OPTIONS.fit.data`` and one column per band.
    """
    observables = ["flux", "vis", "t3"]
    model_data = dict(zip(observables, compute_observables(components)))
    wls = OPTIONS.fit.wls.value
    # The band indices depend only on the wavelengths; compute them once.
    band_indices = [get_band_indices(wls, [band]) for band in OPTIONS.fit.bands]

    loglikes = []
    for key in OPTIONS.fit.data:
        data = getattr(OPTIONS.data, key)
        model_key = "vis" if key == "vis2" else key
        kind = "periodic" if model_key == "t3" else "linear"

        loglikes_bands = []
        for indices in band_indices:
            val = data.val[:, indices]
            # Full boolean mask, robust to arrays carrying ``nomask``.
            mask = np.ma.getmaskarray(val)
            loglikes_bands.append(
                compute_loglike(
                    val.compressed(),
                    data.err[:, indices].compressed(),
                    model_data[model_key][:, indices][~mask],
                    kind=kind,
                )
            )
        loglikes.append(loglikes_bands)

    loglikes = np.array(loglikes).astype(float)
    weights_general = get_weights(kind="general")
    weights_bands = get_weights(kind="bands")
    return np.sum((weights_bands * loglikes).sum(1) * weights_general), loglikes
|
326
|
+
|
327
|
+
|
328
|
+
def sample_uniform(
    param: Parameter | None = None,
    theta: float | None = None,
    prior: List[float] | None = None,
) -> float:
    """Maps a unit-interval sample onto a uniform prior.

    If ``param`` is given, its ``uniform`` sample is scaled onto
    ``[param.min, param.max]``. Otherwise ``theta`` is scaled onto
    ``[prior[0], prior[1]]``.
    """
    if param is None:
        lower, upper, unit = prior[0], prior[1], theta
    else:
        lower, upper, unit = param.min, param.max, param.uniform
    return lower + (upper - lower) * unit
|
337
|
+
|
338
|
+
|
339
|
+
def transform_uniform_prior(theta: NDArray[Any], priors: NDArray[Any]) -> NDArray[Any]:
    """Prior transform for uniform priors.

    Maps unit-cube samples onto ``[lower, upper]`` for each parameter.

    Parameters
    ----------
    theta : array_like
        Samples in [0, 1], one per parameter.
    priors : array_like
        Shape (nparams, 2): column 0 is the lower, column 1 the upper limit.

    Returns
    -------
    params : numpy.typing.NDArray
        The transformed parameters.
    """
    theta, priors = np.asarray(theta), np.asarray(priors)
    lower, upper = priors[:, 0], priors[:, 1]
    return lower + (upper - lower) * theta
|
342
|
+
|
343
|
+
|
344
|
+
def nband(params: NDArray[Any], labels: List[str], theta: NDArray[Any]) -> NDArray[Any]:
    """Re-parameterises the "weight" parameters so that they sum to 100.

    This is a stick-breaking scheme:

    * each weight except the last takes the fraction ``theta[i]`` of what
      remains of 100;
    * the last weight receives the remainder.

    All other entries of ``params`` are left unchanged.

    Parameters
    ----------
    params : numpy.typing.NDArray
        The transformed parameters; modified in place and returned.
    labels : list of str
        The parameter labels; entries containing "weight" are re-parameterised.
    theta : numpy.typing.NDArray
        The unit-cube samples matching ``params``.

    Returns
    -------
    params : numpy.typing.NDArray
        ``params`` itself (unchanged if no label contains "weight").
    """
    indices = [index for index, label in enumerate(labels) if "weight" in label]
    if not indices:
        return params

    remainder = 100
    for index in indices[:-1]:
        params[index] = remainder * theta[index]
        remainder -= params[index]

    params[indices[-1]] = remainder
    return params
|
354
|
+
|
355
|
+
|
356
|
+
# NOTE: This ignores the first component (that being the star) -> Not generalised.
# Also only works if all the other components are based on the Ring class
def radii(components: List[Component]) -> List[Component]:
    """Forces the radii to be sequential.

    Applies to every component after the first whose name contains "Ring",
    "TempGrad" or "GreyBody" and whose ``rin`` is free and not shared:

    * ``rin`` is re-sampled from its uniform sample. From the third component
      on, its lower limit is first raised to the previous component's
      ``rout``.
    * ``rout`` is then re-sampled with its lower limit raised to the new
      ``rin``.

    The components are modified in place and returned.
    """
    for index, component in enumerate(components[1:], start=1):
        # Skip components whose name does not mark them as ring-like.
        if not any(name in component.name for name in ["Ring", "TempGrad", "GreyBody"]):
            continue

        if component.rin.free and not component.rin.shared:
            if index > 1:
                # NOTE(review): assumes the previous component has ``rout``.
                component.rin.min = components[index - 1].rout.value

            # NOTE(review): the nesting of the following three lines inside the
            # ``rin.free`` branch was inferred (source indentation was lost).
            # ``rout`` is re-sampled from ``rout.uniform`` even if ``rout`` is
            # not free — confirm that attribute is always set.
            component.rin.value = sample_uniform(component.rin)
            component.rout.min = component.rin.value
            component.rout.value = sample_uniform(component.rout)

    return components
|
373
|
+
|
374
|
+
|
375
|
+
def shift(components: List[Component] | None = None) -> List[Component]:
    """Forces the shift radius ``r`` of each component to stay within bounds.

    Meant to be listed in ``OPTIONS.fit.conditions`` and called as
    ``shift(components)``, like ``radii``. The previous implementation
    referenced undefined names and could not run.

    For every component whose free parameters include ``r``, upper bounds
    are taken from its neighbours:

    * previous component: ``min(rin - prev.rout, prev.rout)``, or the
      component's current ``r`` if either ``rin`` or ``prev.rout`` is missing;
    * next component: ``next.rin - rout``, or ``next.rin`` if the component
      has no ``rout`` (skipped if the next component has no ``rin``).

    ``r`` is then re-sampled uniformly between ``r.min`` and the smallest
    bound, using ``r.uniform``. Components without a free ``r``, or without
    any usable bound, are left unchanged.

    Parameters
    ----------
    components : list of Component, optional
        The components to constrain; modified in place.

    Returns
    -------
    components : list of Component
        The same list (an empty list if ``None`` was given).
    """
    # NOTE: Removes overlap caused by photosphere shift
    # TODO: This does not account for direction -> Problem? Try in a fit.
    # Subtract the distance also from the centre?
    components = [] if components is None else components
    last = len(components) - 1
    for index, comp in enumerate(components):
        # Same set as the "r-{index}" labels that get_labels builds.
        if "r" not in comp.get_params(free=True):
            continue

        bounds = []
        if index != 0:
            prev_comp = components[index - 1]
            try:
                bounds.append(
                    min(comp.rin.value - prev_comp.rout.value, prev_comp.rout.value)
                )
            except AttributeError:
                bounds.append(comp.r.value)

        if index != last:
            next_comp = components[index + 1]
            try:
                bounds.append(next_comp.rin.value - comp.rout.value)
            except AttributeError:
                next_rin = getattr(next_comp, "rin", None)
                if next_rin is not None:
                    bounds.append(next_rin.value)

        if not bounds:
            continue

        upper, lower = min(bounds), comp.r.min
        comp.r.value = lower + (upper - lower) * comp.r.uniform

    return components
|
403
|
+
|
404
|
+
|
405
|
+
def param_transform(theta: NDArray[Any]) -> NDArray[Any]:
    """Transforms unit-cube samples into model parameters (the prior transform).

    The samples are first mapped onto the uniform priors of the free
    parameters. Then:

    * for ``OPTIONS.fit.type == "nband"``, the weights are re-parameterised
      by ``nband``;
    * otherwise, the values are set on copies of the model components, each
      function named in ``OPTIONS.fit.conditions`` (looked up in this module)
      is applied in turn, and the resulting parameter values are returned.

    Parameters
    ----------
    theta : numpy.typing.NDArray
        Samples from the unit cube, one per free parameter.

    Returns
    -------
    params : numpy.typing.NDArray
        The transformed parameter vector.
    """
    theta = np.asarray(theta)
    params = transform_uniform_prior(theta, get_priors(OPTIONS.model.components))
    if OPTIONS.fit.type == "nband":
        return nband(params, get_labels(OPTIONS.model.components), theta)

    components = set_components_from_theta(params, theta)
    # ``conditions`` defaults to None in the options; treat that as "none".
    for option in OPTIONS.fit.conditions or []:
        components = getattr(CURRENT_MODULE, option)(components)

    return get_theta(components)
|
416
|
+
|
417
|
+
|
418
|
+
def lnprob(theta: NDArray[Any]) -> float:
    """Computes the log-likelihood of a parameter vector.

    Parameters
    ----------
    theta: numpy.typing.NDArray
        The parameter values, ordered as ``get_fit_params`` returns them.

    Returns
    -------
    float
        For ``OPTIONS.fit.type == "nband"``: the flux-only log-likelihood of
        the first component's flux. Otherwise: the weighted interferometric
        log-likelihood.
    """
    components = set_components_from_theta(theta)
    if OPTIONS.fit.type != "nband":
        total, _ = compute_interferometric_loglike(components)
        return total

    flux = components[0].compute_flux(0, OPTIONS.fit.wls)
    return compute_nband_fit_chi_sq(flux, ndim=theta.size)
|
439
|
+
|
440
|
+
|
441
|
+
def run_fit(
    sample: str = "rwalk",
    bound: str = "multi",
    ncores: int = 6,
    debug: bool = False,
    save_dir: Path | None = None,
    **kwargs,
) -> DynamicNestedSampler:
    """Runs the dynesty nested sampler.

    Parameters
    ----------
    sample : str, optional
        The sampling method. Either "rwalk" or "unif".
    bound : str, optional
        The bounding method. Either "multi" or "single".
    ncores : int, optional
        The number of cores to use.
    debug : bool, optional
        Whether to run the sampler in debug mode.
        This will not use multiprocessing.
    save_dir : Path, optional
        The directory in which the sampler checkpoint ("sampler.save") is
        written. If None, no checkpoint is written.
    **kwargs
        Optional overrides: "nlive_batch" (default 500), "dlogz_init" (0.01),
        "nlive_init" (1000), "ptform" (``param_transform``) and "lnprob"
        (``lnprob``). Other keys are ignored.

    Returns
    -------
    sampler : dynesty.DynamicNestedSampler
    """
    # Pass None (not the string "None") so dynesty only checkpoints on request.
    checkpoint_file = None
    if save_dir is not None:
        checkpoint_file = str(Path(save_dir) / "sampler.save")

    fit_params = get_fit_params(OPTIONS.model.components)
    # None when no parameter is periodic / reflective.
    periodic = [i for i, param in enumerate(fit_params) if param.periodic] or None
    reflective = [i for i, param in enumerate(fit_params) if param.reflective] or None

    pool = None if debug else Pool(processes=ncores)
    try:
        general_kwargs = {
            "bound": bound,
            "queue_size": None if debug else 2 * ncores,
            "sample": sample,
            "periodic": periodic,
            "reflective": reflective,
            "pool": pool,
        }
        run_kwargs = {
            "nlive_batch": kwargs.pop("nlive_batch", 500),
            "dlogz_init": kwargs.pop("dlogz_init", 0.01),
            "nlive_init": kwargs.pop("nlive_init", 1000),
        }

        print(f"Executing Dynesty.\n{'':-^50}")
        labels = get_labels(OPTIONS.model.components)
        ptform = kwargs.pop("ptform", param_transform)
        sampler = DynamicNestedSampler(
            kwargs.pop("lnprob", lnprob),
            ptform,
            len(labels),
            **general_kwargs,
        )
        sampler.run_nested(
            **run_kwargs, print_progress=True, checkpoint_file=checkpoint_file
        )
    except BaseException:
        # Don't leave worker processes behind if sampling fails or is interrupted.
        if pool is not None:
            pool.terminate()
            pool.join()
        raise

    if pool is not None:
        pool.close()
        pool.join()

    return sampler
|
524
|
+
|
525
|
+
|
526
|
+
def get_best_fit(
    sampler: DynamicNestedSampler,
    method: str = "max",
) -> Tuple[NDArray[Any], NDArray[Any]]:
    """Gets the best-fit parameters and their uncertainties from the sampler.

    Weighted quantiles (``OPTIONS.fit.quantiles``, in percent) are computed
    per parameter from the importance-weighted samples. The best fit is the
    quantile at index 1 (the median for the default [2.5, 50, 97.5]). With
    ``method="max"`` it is replaced by the sample with the highest
    log-likelihood.

    Returns
    -------
    best_fit : numpy.typing.NDArray
        One value per parameter.
    uncertainties : numpy.typing.NDArray
        Differences between successive quantiles (with the best fit in the
        middle), one row per parameter.
    """
    results = sampler.results
    weights = results.importance_weights()
    probabilities = np.array(OPTIONS.fit.quantiles) / 100
    quantiles = np.array(
        [
            dyutils.quantile(column, probabilities, weights=weights)
            for column in results.samples.T
        ]
    )

    if method == "max":
        quantiles[:, 1] = results.samples[results.logl.argmax()]

    best_fit = quantiles[:, 1]
    uncertainties = np.diff(quantiles, axis=1)
    return best_fit, uncertainties
|
ppdmod/options.py
ADDED
@@ -0,0 +1,164 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
from types import SimpleNamespace
|
3
|
+
from typing import Any, Dict, List
|
4
|
+
|
5
|
+
import astropy.units as u
|
6
|
+
import matplotlib.pyplot as plt
|
7
|
+
import numpy as np
|
8
|
+
import toml
|
9
|
+
from matplotlib import colormaps as mcm
|
10
|
+
from matplotlib.colors import ListedColormap
|
11
|
+
|
12
|
+
|
13
|
+
def convert_style_to_colormap(style: str) -> ListedColormap:
    """Converts a matplotlib style's colour cycle into a colormap.

    The style is applied only temporarily via ``plt.style.context``, so the
    previous global rcParams are restored afterwards. They are not reset to
    the matplotlib defaults, which would discard user settings.

    Parameters
    ----------
    style : str
        The name of a matplotlib style.

    Returns
    -------
    colormap : ListedColormap
        A colormap whose entries are the style's ``axes.prop_cycle`` colours.
    """
    with plt.style.context(style):
        colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    return ListedColormap(colors)
|
19
|
+
|
20
|
+
|
21
|
+
def get_colormap(colormap: str) -> ListedColormap:
    """Gets a colormap by name, falling back to a matplotlib style.

    The name is first looked up in ``matplotlib.colormaps``. If that raises
    ``ValueError``, the name is treated as a style and its colour cycle is
    converted with ``convert_style_to_colormap``.
    """
    try:
        cmap = mcm.get_cmap(colormap)
    except ValueError:
        cmap = convert_style_to_colormap(colormap)
    return cmap
|
27
|
+
|
28
|
+
|
29
|
+
def get_colorlist(colormap: str, ncolors: int = 10) -> List[tuple]:
    """Gets the first ``ncolors`` entries of a colormap as RGBA tuples.

    The colormap is resolved once and then called with the integers
    ``0 .. ncolors - 1``. Resolving it once avoids re-applying a style on
    every call. Per matplotlib's ``Colormap.__call__``, integer inputs index
    the lookup table directly (they are not fractions of the range), and
    indices past its end map to the "over" colour.

    Parameters
    ----------
    colormap : str
        A colormap or matplotlib style name (see ``get_colormap``).
    ncolors : int, optional
        The number of entries to return.

    Returns
    -------
    colors : list of tuple
        RGBA tuples.
    """
    cmap = get_colormap(colormap)
    return [cmap(index) for index in range(ncolors)]
|
32
|
+
|
33
|
+
|
34
|
+
def get_units(dictionary: Dict[str, Any]) -> Dict[str, Any]:
    """Converts the units in a dictionary to astropy units.

    Each value that is a dict with a ``"unit"`` entry gets that entry
    converted: ``"one"`` becomes ``u.one``, and anything else is passed to
    ``u.Unit``. Non-dict values are copied unchanged. The input, including
    its nested dicts, is not modified; a converted copy is returned.

    Parameters
    ----------
    dictionary : dict
        Mapping of names to parameter tables.

    Returns
    -------
    converted_dictionary : dict
    """
    converted_dictionary = {}
    for name, entry in dictionary.items():
        if isinstance(entry, dict):
            # Copy the nested table so the caller's data stays untouched.
            entry = dict(entry)
            if "unit" in entry:
                unit = entry["unit"]
                entry["unit"] = u.one if unit == "one" else u.Unit(unit)
        converted_dictionary[name] = entry

    return converted_dictionary
|
45
|
+
|
46
|
+
|
47
|
+
def load_toml_to_namespace(toml_file: Path) -> SimpleNamespace:
    """Loads the ``STANDARD_PARAMETERS`` table of a TOML file into a namespace.

    Units in the table are converted with ``get_units``.

    Parameters
    ----------
    toml_file : Path
        The TOML file to read.

    Returns
    -------
    namespace : SimpleNamespace
        One attribute per entry of ``STANDARD_PARAMETERS``.

    Raises
    ------
    KeyError
        If the file has no ``STANDARD_PARAMETERS`` table.
    """
    # TOML is UTF-8 by specification; don't depend on the locale's default.
    with open(toml_file, "r", encoding="utf-8") as file:
        data = toml.load(file)["STANDARD_PARAMETERS"]

    return SimpleNamespace(**get_units(data))
|
53
|
+
|
54
|
+
|
55
|
+
# Standard parameter definitions shipped with the package.
STANDARD_PARAMS = load_toml_to_namespace(
    Path(__file__).parent.joinpath("config", "standard_parameters.toml")
)
|
58
|
+
|
59
|
+
|
60
|
+
# NOTE: Data
# Empty placeholders: ``val``/``err``-like entries are 1-D with length 0,
# ``u``/``v`` are 2-D with shape (1, 0).
vis_data = SimpleNamespace(
    val=np.empty(0), err=np.empty(0), u=np.empty((1, 0)), v=np.empty((1, 0))
)
vis2_data = SimpleNamespace(
    val=np.empty(0), err=np.empty(0), u=np.empty((1, 0)), v=np.empty((1, 0))
)
t3_data = SimpleNamespace(
    val=np.empty(0),
    err=np.empty(0),
    u123=np.empty(0),
    v123=np.empty(0),
    u=np.empty((1, 0)),
    v=np.empty((1, 0)),
    i123=np.empty(0),
)
flux_data = SimpleNamespace(val=np.empty(0), err=np.empty(0))
gravity = SimpleNamespace(index=20)
dtype = SimpleNamespace(complex=np.complex128, real=np.float64)
# Bin widths per band; a fresh Quantity is created for every band.
binning = SimpleNamespace(
    **{band: 0.2 * u.um for band in ("unknown", "kband", "hband")},
    **{band: 0.1 * u.um for band in ("lband", "mband", "lmband", "nband")},
)
interpolation = SimpleNamespace(dim=10, kind="linear", fill_value=None)
data = SimpleNamespace(
    readouts=[],
    readouts_t=[],
    hduls=[],
    hduls_t=[],
    nt=1,
    bands=[],
    resolutions=[],
    do_bin=True,
    flux=flux_data,
    vis=vis_data,
    vis2=vis2_data,
    t3=t3_data,
    gravity=gravity,
    binning=binning,
    dtype=dtype,
    interpolation=interpolation,
)
|
113
|
+
|
114
|
+
# NOTE: Model
# ``components`` stays None until a model is configured.
model = SimpleNamespace(components=None, output="non-normed", gridtype="logarithmic")
|
120
|
+
|
121
|
+
# NOTE: Plot
color = SimpleNamespace(
    background="white",
    colormap="plasma",
    number=100,
    list=get_colorlist("plasma", 100),
)
# Keyword arguments for matplotlib errorbar / scatter calls.
errorbar = SimpleNamespace(
    color=None, markeredgecolor="black", markeredgewidth=0.2,
    capsize=5, capthick=3, ecolor="gray", zorder=2,
)
scatter = SimpleNamespace(
    color="",
    edgecolor="black",
    linewidths=0.2,
    zorder=3,
)
plot = SimpleNamespace(
    dim=256,
    dpi=300,
    ticks=[1.7, 2.15, 3.2, 4.7, 8.0, 9.0, 10.0, 11.0, 12.0, 12.75],
    color=color, errorbar=errorbar, scatter=scatter,
)
|
143
|
+
|
144
|
+
# NOTE: Weights
# A separate namespace per observable, in the order flux, t3, vis.
weights = SimpleNamespace(
    **{observable: SimpleNamespace(general=1) for observable in ("flux", "t3", "vis")}
)

# NOTE: Fitting
fit = SimpleNamespace(
    weights=weights,
    type="disc",
    data=["flux", "vis", "t3"],
    bands=["all"],
    wls=None,
    quantiles=[2.5, 50, 97.5],
    fitter="dynesty",
    conditions=None,
)

# NOTE: All options
OPTIONS = SimpleNamespace(data=data, model=model, plot=plot, fit=fit)
|