ppdmod 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ppdmod/__init__.py +1 -0
- ppdmod/base.py +225 -0
- ppdmod/components.py +557 -0
- ppdmod/config/standard_parameters.toml +290 -0
- ppdmod/data.py +485 -0
- ppdmod/fitting.py +546 -0
- ppdmod/options.py +164 -0
- ppdmod/parameter.py +152 -0
- ppdmod/plot.py +1241 -0
- ppdmod/utils.py +575 -0
- ppdmod-2.0.0.dist-info/METADATA +68 -0
- ppdmod-2.0.0.dist-info/RECORD +15 -0
- ppdmod-2.0.0.dist-info/WHEEL +5 -0
- ppdmod-2.0.0.dist-info/licenses/LICENSE +21 -0
- ppdmod-2.0.0.dist-info/top_level.txt +1 -0
ppdmod/plot.py
ADDED
@@ -0,0 +1,1241 @@
|
|
1
|
+
import re
from itertools import chain, zip_longest
from pathlib import Path
from typing import Dict, List, Tuple

import astropy.constants as const
import astropy.units as u
import corner
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
from dynesty import DynamicNestedSampler, NestedSampler
from dynesty import plotting as dyplot
from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpec
from matplotlib.legend import Legend

from .base import FourierComponent
from .fitting import compute_observables, get_best_fit
from .options import OPTIONS, get_colormap
from .utils import (
    compare_angles,
    get_band_indices,
    transform_coordinates,
)
|
28
|
+
|
29
|
+
|
30
|
+
def get_best_plot_arrangement(nplots):
    """Gets the best (rows, cols) grid arrangement for a given number of plots.

    The grid starts from the square root of ``nplots`` (columns rounded up,
    rows rounded down), grows until it can hold every plot and then drops
    trailing rows that would stay completely empty.

    Parameters
    ----------
    nplots : int
        The number of plots to arrange.

    Returns
    -------
    rows : int
        The number of rows (0 if ``nplots`` is 0).
    cols : int
        The number of columns (0 if ``nplots`` is 0).

    Raises
    ------
    ValueError
        If ``nplots`` is negative.
    """
    if nplots < 0:
        raise ValueError(f"Number of plots must be non-negative, got {nplots}.")
    if nplots == 0:
        # With cols == 0 the row-trimming loop below would never terminate,
        # as (rows - 1) * 0 >= 0 holds for any number of rows.
        return 0, 0

    sqrt_nplots = np.sqrt(nplots)
    cols = int(np.ceil(sqrt_nplots))
    rows = int(np.floor(sqrt_nplots))

    # Grow the grid (rows first, unless there are fewer columns) until it fits.
    while rows * cols < nplots:
        if cols < rows:
            cols += 1
        else:
            rows += 1

    # Drop rows that would not contain any plot.
    while rows > 1 and (rows - 1) * cols >= nplots:
        rows -= 1

    return rows, cols
|
46
|
+
|
47
|
+
|
48
|
+
def set_axes_color(
    ax: Axes,
    background_color: str,
    set_label: bool = True,
    direction: str | None = None,
) -> None:
    """Sets the axes' facecolor and gives spines, ticks and labels a contrasting colour.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes to style.
    background_color : str
        The facecolor. If it is "black", spines, ticks and (optionally) the
        axis labels are drawn in white, otherwise in black.
    set_label : bool, optional
        If True, also colours the x- and y-axis labels. The default is True.
    direction : str, optional
        The tick direction ("in", "out" or "inout"). If None (default), the
        current tick direction is left unchanged.
    """
    opposite_color = "white" if background_color == "black" else "black"
    ax.set_facecolor(background_color)
    for side in ("bottom", "top", "right", "left"):
        ax.spines[side].set_color(opposite_color)

    if set_label:
        ax.xaxis.label.set_color(opposite_color)
        ax.yaxis.label.set_color(opposite_color)

    # Only forward ``direction`` when it is given: ``tick_params(direction=None)``
    # resets the ticks to the rcParams default direction, which would undo a
    # direction set on these axes earlier.
    tick_kwargs = {"axis": "both", "colors": opposite_color}
    if direction is not None:
        tick_kwargs["direction"] = direction
    ax.tick_params(**tick_kwargs)
|
67
|
+
|
68
|
+
|
69
|
+
def set_legend_color(legend: Legend, background_color: str) -> None:
    """Fills the legend frame with the background colour and draws its
    texts in the contrasting colour (white on black, black otherwise)."""
    text_color = "black"
    if background_color == "black":
        text_color = "white"

    legend_texts = legend.get_texts()
    plt.setp(legend_texts, color=text_color)
    frame = legend.get_frame()
    frame.set_facecolor(background_color)
|
74
|
+
|
75
|
+
|
76
|
+
def format_labels(
|
77
|
+
labels: List[str], units: List[str] | None = None, split: bool = False
|
78
|
+
) -> List[str]:
|
79
|
+
"""Formats the labels in LaTeX.
|
80
|
+
|
81
|
+
Parameters
|
82
|
+
----------
|
83
|
+
labels : list of str
|
84
|
+
The labels.
|
85
|
+
units : list of str, optional
|
86
|
+
The units. The default is None.
|
87
|
+
split : bool, optional
|
88
|
+
If True, splits into labels, units, and uncertainties.
|
89
|
+
The default is False.
|
90
|
+
|
91
|
+
Returns
|
92
|
+
-------
|
93
|
+
labels : list of str
|
94
|
+
The formatted labels.
|
95
|
+
units : list of str, optional
|
96
|
+
The formatted units. If split is True
|
97
|
+
"""
|
98
|
+
nice_labels = {
|
99
|
+
"rin": {"letter": "R", "indices": [r"\text{in}"]},
|
100
|
+
"rout": {"letter": "R", "indices": [r"\text{out}"]},
|
101
|
+
"p": {"letter": "p"},
|
102
|
+
"q": {"letter": "q"},
|
103
|
+
"rho": {"letter": r"\rho"},
|
104
|
+
"theta": {"letter": r"\theta"},
|
105
|
+
"logsigma0": {"letter": r"\Sigma", "indices": ["0"]},
|
106
|
+
"sigma0": {"letter": r"\Sigma", "indices": ["0"]},
|
107
|
+
"weight_cont": {"letter": "w", "indices": [r"\text{cont}"]},
|
108
|
+
"pa": {"letter": r"\theta", "indices": []},
|
109
|
+
"cinc": {"letter": r"\cos\left(i\right)"},
|
110
|
+
"temp0": {"letter": "T", "indices": ["0"]},
|
111
|
+
"tempc": {"letter": "T", "indices": [r"\text{c}"]},
|
112
|
+
"f": {"letter": "f"},
|
113
|
+
"fr": {"letter": "fr"},
|
114
|
+
"fwhm": {"letter": r"\sigma"},
|
115
|
+
"r": {"letter": "r"},
|
116
|
+
"phi": {"letter": r"\phi"},
|
117
|
+
}
|
118
|
+
|
119
|
+
formatted_labels = []
|
120
|
+
for label in labels:
|
121
|
+
if "-" in label:
|
122
|
+
name, index = label.split("-")
|
123
|
+
else:
|
124
|
+
name, index = label, ""
|
125
|
+
|
126
|
+
if name in nice_labels or name[-1].isdigit():
|
127
|
+
if ".t" in name:
|
128
|
+
name, time_index = name.split(".")
|
129
|
+
else:
|
130
|
+
time_index = None
|
131
|
+
|
132
|
+
if name not in nice_labels and name[-1].isdigit():
|
133
|
+
letter = nice_labels[name[:-1]]["letter"]
|
134
|
+
indices = [name[-1]]
|
135
|
+
if index:
|
136
|
+
indices.append(index)
|
137
|
+
else:
|
138
|
+
letter = nice_labels[name]["letter"]
|
139
|
+
if name in ["temp0", "tempc"]:
|
140
|
+
indices = nice_labels[name].get("indices", [])
|
141
|
+
else:
|
142
|
+
indices = [*nice_labels[name].get("indices", [])]
|
143
|
+
if index:
|
144
|
+
indices.append(rf"\mathrm{{{index}}}")
|
145
|
+
|
146
|
+
if time_index is not None:
|
147
|
+
indices.append(rf"\mathrm{{{time_index}}}")
|
148
|
+
|
149
|
+
indices = r",\,".join(indices)
|
150
|
+
formatted_label = f"{letter}_{{{indices}}}"
|
151
|
+
if "log" in label:
|
152
|
+
formatted_label = rf"\log_{{10}}\left({formatted_label}\right)"
|
153
|
+
|
154
|
+
formatted_labels.append(f"$ {formatted_label} $")
|
155
|
+
else:
|
156
|
+
if "weight" in name:
|
157
|
+
name, letter = name.replace("weight", ""), "w"
|
158
|
+
|
159
|
+
indices = []
|
160
|
+
if "small" in name:
|
161
|
+
name = name.replace("small", "")
|
162
|
+
indices = [r"\text{small}"]
|
163
|
+
elif "large" in name:
|
164
|
+
name = name.replace("large", "")
|
165
|
+
indices = [r"\text{large}"]
|
166
|
+
name = name.replace("_", "")
|
167
|
+
indices.append(rf"\text{{{name}}}")
|
168
|
+
|
169
|
+
indices = r",\,".join(indices)
|
170
|
+
formatted_label = f"{letter}_{{{indices}}}"
|
171
|
+
if "log" in label:
|
172
|
+
formatted_label = rf"\log_{{10}}\left({formatted_label}\right)"
|
173
|
+
elif "scale" in name:
|
174
|
+
formatted_label = rf"w_{{\text{{{name.replace('scale_', '')}}}}}"
|
175
|
+
elif "lnf" in name:
|
176
|
+
formatted_label = (
|
177
|
+
rf"\ln\left(f\right)_{{\text{{{name.split('_')[0]}}}}}"
|
178
|
+
)
|
179
|
+
else:
|
180
|
+
formatted_label = label
|
181
|
+
|
182
|
+
formatted_labels.append(f"$ {formatted_label} $")
|
183
|
+
|
184
|
+
if units is not None:
|
185
|
+
reformatted_units = []
|
186
|
+
for unit in units:
|
187
|
+
if unit == u.g / u.cm**2:
|
188
|
+
unit = r"\si{\gram\per\square\centi\metre}"
|
189
|
+
elif unit == u.au:
|
190
|
+
unit = r"\si{\astronomicalunit}"
|
191
|
+
elif unit == u.deg:
|
192
|
+
unit = r"\si{\degree}"
|
193
|
+
elif unit == u.pct:
|
194
|
+
unit = r"\si{\percent}"
|
195
|
+
|
196
|
+
reformatted_units.append(unit)
|
197
|
+
|
198
|
+
reformatted_units = [
|
199
|
+
rf"$ (\text{{{str(unit).strip()}}}) $" if str(unit) else ""
|
200
|
+
for unit in reformatted_units
|
201
|
+
]
|
202
|
+
if split:
|
203
|
+
return formatted_labels, reformatted_units
|
204
|
+
|
205
|
+
formatted_labels = [
|
206
|
+
rf"{label} {unit}"
|
207
|
+
for label, unit in zip(formatted_labels, reformatted_units)
|
208
|
+
]
|
209
|
+
return formatted_labels
|
210
|
+
|
211
|
+
|
212
|
+
def needs_sci_notation(ax):
    """Checks if scientific notation is needed.

    This is the case when any of the four axis limits (x and y, lower and
    upper) has an absolute value of at most 1e-3. A limit of exactly zero
    therefore also triggers it.
    """
    threshold = 1e-3
    limits = (*ax.get_xlim(), *ax.get_ylim())
    return any(abs(limit) <= threshold for limit in limits)
|
222
|
+
|
223
|
+
|
224
|
+
def get_exponent(num: float) -> int:
    """Gets the exponent of a number for scientific notation.

    The returned value is the *negated* decimal exponent,
    ``-floor(log10(|num|))``, so that ``|num| * 10**get_exponent(num)`` lies
    in ``[1, 10)`` (e.g. 4 for 2.5e-4 and -2 for 150).

    Parameters
    ----------
    num : float
        The number. Must be finite and non-zero; the sign is ignored.

    Returns
    -------
    int
        The negated decimal exponent.

    Raises
    ------
    ValueError
        If ``num`` is zero or not finite.
    """
    if num == 0:
        raise ValueError("Number must be non-zero")
    if not np.isfinite(num):
        raise ValueError("Number must be finite")

    # Use |num| so negative numbers get the same exponent as their magnitude
    # instead of taking the logarithm of a negative value (NaN).
    return int(-np.floor(np.log10(abs(num))))
|
232
|
+
|
233
|
+
|
234
|
+
def plot_corner(
    sampler,
    labels: List[str],
    units: List[str] | None = None,
    fontsize: int = 12,
    discard: int = 0,
    savefig: Path | None = None,
    **kwargs,
) -> None:
    """Plots the corner of the posterior spread.

    Parameters
    ----------
    sampler :
        The sampler. If ``OPTIONS.fit.fitter`` is "dynesty", its ``results``
        are passed to ``dynesty.plotting.cornerplot``; otherwise its flattened
        chain (``sampler.get_chain(discard=..., flat=True)``) is passed to
        ``corner.corner``.
    labels : list of str
        The parameter labels.
    units : list of str, optional
        The parameter units, appended to the labels. The default is None.
    fontsize : int, optional
        The fontsize. The default is 12.
    discard : int, optional
        The number of chain steps to discard (only used for non-dynesty
        samplers). The default is 0.
    savefig : pathlib.Path, optional
        The save path. If None (default), the figure is shown instead.
    **kwargs
        Unused; accepted for call compatibility.
    """
    labels = format_labels(labels, units)
    quantiles = [x / 100 for x in OPTIONS.fit.quantiles]
    if OPTIONS.fit.fitter == "dynesty":
        _, axarr = dyplot.cornerplot(
            sampler.results,
            color="blue",
            labels=labels,
            show_titles=True,
            max_n_ticks=3,
            title_quantiles=quantiles,
            quantiles=quantiles,
        )

        theta, uncertainties = get_best_fit(sampler)
        for index, row in enumerate(axarr):
            for ax in row:
                if ax is None:
                    continue

                if needs_sci_notation(ax):
                    if "Sigma" in ax.get_xlabel():
                        ax.xaxis.set_major_formatter(ticker.ScalarFormatter())
                        ax.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))

                    if "Sigma" in ax.get_ylabel():
                        ax.yaxis.set_major_formatter(ticker.ScalarFormatter())
                        ax.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
                        ax.yaxis.get_offset_text().set_position((-0.2, 0))

                # Rewrite the titles of tiny best-fit values as
                # "mantissa_{-lower}^{+upper} 1e-N". Zero has no exponent and
                # the magnitude is used so negative values work as well.
                title = ax.get_title()
                value = theta[index]
                if title and value != 0 and np.abs(value) <= 1e-3:
                    exponent = get_exponent(np.abs(value))
                    factor = 10.0**exponent
                    lower, upper = uncertainties[index][0], uncertainties[index][1]
                    formatted_title = (
                        rf"${value * factor:.2f}_{{-{lower * factor:.2f}}}"
                        rf"^{{+{upper * factor:.2f}}}\,1\mathrm{{e}}-{exponent}$"
                    )
                    ax.set_title(
                        f"{labels[index]} = {formatted_title}",
                        fontsize=fontsize - 2,
                    )
    else:
        samples = sampler.get_chain(discard=discard, flat=True)
        corner.corner(samples, labels=labels)

    # Save or show (as ``plot_chains`` does); previously the figure was
    # closed unseen when no save path was given.
    if savefig is not None:
        plt.savefig(savefig, format=Path(savefig).suffix[1:], dpi=OPTIONS.plot.dpi)
    else:
        plt.show()
    plt.close()
|
304
|
+
|
305
|
+
|
306
|
+
def plot_chains(
    sampler: NestedSampler | DynamicNestedSampler,
    labels: List[str],
    units: List[str] | None = None,
    savefig: Path | None = None,
    **kwargs,
) -> None:
    """Plots the fitter's chains as a dynesty trace plot.

    Parameters
    ----------
    sampler : dynesty.NestedSampler or dynesty.DynamicNestedSampler
        The sampler whose ``results`` are plotted.
    labels : list of str
        The parameter labels.
    units : list of str, optional
        The parameter units, appended to the labels. The default is None.
    savefig : pathlib.Path, optional
        The save path. If not given, the figure is shown instead.
    **kwargs
        Unused; accepted for call compatibility.
    """
    formatted = format_labels(labels, units)
    trace_options = dict(
        labels=formatted,
        truths=np.zeros(len(formatted)),
        quantiles=[q / 100 for q in OPTIONS.fit.quantiles],
        truth_color="black",
        show_titles=True,
        trace_cmap="viridis",
        connect=True,
        connect_highlight=range(5),
    )
    dyplot.traceplot(sampler.results, **trace_options)

    if not savefig:
        plt.show()
    else:
        plt.savefig(savefig, format=Path(savefig).suffix[1:], dpi=OPTIONS.plot.dpi)
    plt.close()
|
346
|
+
|
347
|
+
|
348
|
+
class LogNorm(mcolors.Normalize):
    """Logarithmic normalisation ``log1p(x - vmin) / log1p(vmax - vmin)``.

    Maps ``vmin`` to 0 and ``vmax`` to 1 with a logarithmic stretch. Because
    the data is offset by ``vmin`` the transform stays defined on the whole
    ``[vmin, vmax]`` interval, even for zero or negative data. Results that
    are NaN are masked.
    """

    def __init__(self, vmin=None, vmax=None, clip=False):
        super().__init__(vmin, vmax, clip)

    def __call__(self, value, clip=None):
        """Normalises ``value`` (``vmin`` -> 0, ``vmax`` -> 1).

        Parameters
        ----------
        value : float or array_like
            The data to normalise.
        clip : bool, optional
            If True, values are clipped to ``[vmin, vmax]`` first. Defaults
            to the ``clip`` given at construction.

        Returns
        -------
        numpy.ma.MaskedArray
            The normalised values, masked where the result is NaN.
        """
        if clip is None:
            clip = self.clip

        value = np.asanyarray(value)
        # As in matplotlib's Normalize: fill unset limits from the data.
        self.autoscale_None(value)
        vmin, vmax = self.vmin, self.vmax

        if clip:
            value = np.clip(value, vmin, vmax)

        if vmax == vmin:
            # Degenerate range: avoid 0/0 and map everything to 0, like
            # matplotlib's Normalize does.
            normalized_value = np.zeros_like(value, dtype=float)
        else:
            normalized_value = np.log1p(value - vmin) / np.log1p(vmax - vmin)
        return np.ma.masked_array(normalized_value, np.isnan(normalized_value))

    def inverse(self, value):
        """Maps normalised values back to data values.

        Raises
        ------
        ValueError
            If ``vmin`` or ``vmax`` is not set.
        """
        if not self.scaled():
            raise ValueError("Not invertible until both vmin and vmax are set")
        value = np.asanyarray(value)
        return np.expm1(value * np.log1p(self.vmax - self.vmin)) + self.vmin
|
360
|
+
|
361
|
+
|
362
|
+
def set_axis_information(
    axarr: Dict[str, List[Axes]],
    key: str,
    cinc=None,
) -> Tuple[Axes, Axes | None]:
    """Sets the axis labels and tick parameters for the different keys.

    Parameters
    ----------
    axarr : dict
        Maps each data key to either a single axes or an ``(upper, lower)``
        pair of axes, the lower one holding the residuals.
    key : str
        The data key: "flux", "vis", "vis2" or "t3".
    cinc : float, optional
        If given, the "vis"/"vis2" x-label refers to the effective baseline
        ``B_eff`` instead of ``B``.

    Returns
    -------
    upper_ax : matplotlib.axes.Axes
        The axes for the data.
    lower_ax : matplotlib.axes.Axes or None
        The axes for the residuals, or None if ``axarr[key]`` is a single axes.

    Raises
    ------
    ValueError
        If ``key`` is not one of the supported data keys.
    """
    panel = axarr[key]
    if isinstance(panel, (tuple, list, np.ndarray)):
        upper_ax, lower_ax = panel
    else:
        upper_ax, lower_ax = panel, None

    if key == "flux":
        xlabel = r"$ \lambda (\mathrm{\mu}\text{m}) $"
        residual_label = "Residuals (Jy)"
        ylabel = r"$ F_{\nu} $ (Jy)"
    elif key in ("vis", "vis2"):
        if cinc is None:
            xlabel = r"$ B (\text{M}\lambda)$"
        else:
            xlabel = r"$ B_{\text{eff}} (\text{M}\lambda) $"

        if key == "vis":
            ylabel = r"$ F_{\nu,\,\text{corr}} $ (Jy)"
            residual_label = "Residuals (Jy)"
        else:
            ylabel = "$ V^{2} $ (a.u.)"
            residual_label = "Residuals (a.u.)"
            upper_ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.2f"))
    elif key == "t3":
        xlabel = r"$ B_{\text{max}} (\text{M}\lambda) $"
        ylabel = r"$ \Phi_{\text{cp}} (^{\circ}) $"
        residual_label = r"Residuals $ (^{\circ}) $"
    else:
        raise ValueError(
            f"Unknown data key {key!r}; expected 'flux', 'vis', 'vis2' or 't3'."
        )

    # With a residual axes the x-label goes on the lower axes, so the upper
    # axes' bottom tick labels are hidden.
    upper_ax.tick_params(
        axis="x",
        which="both",
        bottom=True,
        top=False,
        labelbottom=lower_ax is None,
    )
    upper_ax.set_ylabel(ylabel)
    if lower_ax is None:
        upper_ax.set_xlabel(xlabel)
    else:
        lower_ax.set_xlabel(xlabel)
        lower_ax.set_ylabel(residual_label)

    return upper_ax, lower_ax
|
414
|
+
|
415
|
+
|
416
|
+
def plot_data_vs_model(
    axarr,
    wavelengths: np.ndarray,
    val: np.ndarray,
    err: np.ndarray,
    key: str,
    baselines: np.ndarray | None = None,
    model_val: np.ndarray | None = None,
    colormap: str = OPTIONS.plot.color.colormap,
    bands: List[str] | str = "all",
    cinc: float | None = None,
    ylims: Dict | None = None,
    norm=None,
):
    """Plots the data versus the model, or just the data if no model data is given.

    Each wavelength row is drawn in the colour ``colormap(norm(wavelength))``.
    If model values are given and ``axarr[key]`` has a lower axes, the model
    is overplotted as crosses and the residuals are drawn on the lower axes.

    Parameters
    ----------
    axarr : dict
        Maps the data keys to a single axes or an ``(upper, lower)`` pair
        (see ``set_axis_information``).
    wavelengths : astropy.units.Quantity
        The wavelengths, one per row of ``val`` (their ``.value`` is used).
    val : numpy.ma.MaskedArray
        The data, one row per wavelength. Its mask is applied to ``model_val``.
    err : numpy.ndarray
        The data's uncertainties, indexed per row like ``val``.
    key : str
        The data key: "flux", "vis", "vis2" or "t3".
    baselines : numpy.ndarray, optional
        The baselines; divided by the wavelengths they give the x-values.
        If None, each row is plotted against its wavelength.
    model_val : numpy.ndarray, optional
        The model values, indexed per row like ``val``.
    colormap : str, optional
        The colormap (resolved with ``get_colormap``).
    bands : list of str or str, optional
        The bands to plot. "all" (default) or None plots every wavelength.
    cinc : float, optional
        Passed on to ``set_axis_information`` for the axis labels.
    ylims : dict, optional
        Explicit y-limits per key, overriding the automatic ones.
    norm : callable, optional
        Maps wavelengths to [0, 1] for the colormap. Defaults to a
        ``LogNorm`` spanning the given wavelengths.
    """
    ylims = {} if ylims is None else ylims
    background = OPTIONS.plot.color.background
    upper_ax, lower_ax = set_axis_information(axarr, key, cinc)
    colormap = get_colormap(colormap)
    alpha = 1 if lower_ax is None else 0.55
    hline_color = "gray" if background == "white" else "white"

    # Work on copies so the per-wavelength colours (and the edge colours for
    # a black background) do not leak into the global plot options.
    errorbar_params = dict(vars(OPTIONS.plot.errorbar))
    scatter_params = dict(vars(OPTIONS.plot.scatter))
    if background == "black":
        errorbar_params["markeredgecolor"] = "white"
        scatter_params["edgecolor"] = "white"

    if norm is None:
        norm = LogNorm(vmin=wavelengths.value.min(), vmax=wavelengths.value.max())

    if model_val is not None:
        model_val = np.ma.masked_array(model_val, mask=val.mask)

    if bands == "all" or bands is None:
        band_indices = np.arange(len(wavelengths.value))
    else:
        band_indices = get_band_indices(wavelengths.value, bands)

    wavelengths = wavelengths[band_indices]
    val, err = val[band_indices], err[band_indices]
    if model_val is not None:
        model_val = model_val[band_indices]

    set_axes_color(upper_ax, background)
    colors = colormap(norm(wavelengths.value))
    if baselines is None:
        grid = [wl.repeat(val.shape[-1]) for wl in wavelengths.value]
    else:
        grid = baselines / wavelengths.value[:, np.newaxis]

    plot_model = model_val is not None and lower_ax is not None
    if plot_model and key == "t3":
        upper_ax.axhline(0, color="grey", linestyle="--")

    ymin, ymax = 0, 0
    ymin_res, ymax_res = 0, 0
    for index, color in enumerate(colors):
        errorbar_params["color"] = scatter_params["color"] = color
        upper_ax.errorbar(
            grid[index],
            val[index],
            err[index],
            fmt="o",
            **errorbar_params,
        )
        ymin = min(ymin, np.nanmin(val[index]))
        ymax = max(ymax, np.nanmax(val[index]))
        if not plot_model:
            continue

        upper_ax.scatter(
            grid[index],
            model_val[index],
            marker="X",
            alpha=alpha,
            **scatter_params,
        )
        if key == "t3":
            # Closure phases (degrees) are compared as angles.
            residuals = np.rad2deg(
                compare_angles(
                    np.deg2rad(val[index]),
                    np.deg2rad(model_val[index]),
                )
            )
        else:
            residuals = val[index] - model_val[index]

        ymin = min(ymin, np.nanmin(model_val[index]))
        ymax = max(ymax, np.nanmax(model_val[index]))
        ymin_res = min(ymin_res, np.nanmin(residuals))
        ymax_res = max(ymax_res, np.nanmax(residuals))

        lower_ax.errorbar(
            grid[index],
            residuals,
            err[index],
            fmt="o",
            **errorbar_params,
        )

    if plot_model:
        lower_ax.axhline(y=0, color=hline_color, linestyle="--")

    # Pad the automatic limits by 25 % (ymin <= 0 <= ymax by construction).
    ymin, ymax = ymin - np.abs(ymin) * 0.25, ymax + ymax * 0.25
    if key in ("flux", "vis"):
        ylim = ylims.get(key, [0, ymax])
    elif key == "vis2":
        ylim = ylims.get(key, [0, 1])
    else:
        ylim = ylims.get("t3", [ymin, ymax])
    upper_ax.set_ylim(ylim)

    # TODO: Improve the residual plots
    if lower_ax is not None:
        upper_ax.tick_params(axis="x", which="both", direction="in")
        lower_ax.set_ylim(
            (ymin_res - np.abs(ymin_res) * 0.25, ymax_res + ymax_res * 0.25)
        )

    # A "Data"/"Model" legend is only added for a single panel with a model.
    if plot_model and len(axarr) <= 1:
        label_color = "lightgray" if background == "black" else "k"
        dot_label = mlines.Line2D(
            [],
            [],
            color=label_color,
            marker="o",
            linestyle="None",
            label="Data",
            alpha=0.6,
        )
        x_label = mlines.Line2D(
            [], [], color=label_color, marker="X", linestyle="None", label="Model"
        )
        legend = upper_ax.legend(handles=[dot_label, x_label])
        set_legend_color(legend, background)
|
546
|
+
|
547
|
+
|
548
|
+
def plot_fit(
    components: List | None = None,
    data_to_plot: List[str | None] | None = None,
    cmap: str = OPTIONS.plot.color.colormap,
    ylims: Dict[str, List[float]] | None = None,
    bands: List[str] | str = "all",
    title: str | None = None,
    ax: List[List[Axes]] | None = None,
    colorbar: bool = True,
    savefig: Path | None = None,
):
    """Plots the deviation of a model from real data of an object for
    total flux, visibilities and closure phases.

    One figure is drawn per epoch (``OPTIONS.data.nt``). Every data type gets
    a column with data and model on top and the residuals below.

    Parameters
    ----------
    components : list
        The model components. The observables are computed from them, and
        ``cinc``/``pa`` of the first component are passed to
        ``transform_coordinates`` for the baselines.
    data_to_plot : list of str, optional
        The data to plot. The default is OPTIONS.fit.data.
    cmap : str, optional
        The colormap.
    ylims : dict of list of float, optional
        The y-limits for the individual keys.
    bands : list of str or str, optional
        The bands to be plotted. The default is "all".
    title : str, optional
        The title, set on the data axes of the last panel. The default is None.
    ax : list of list of matplotlib.axes.Axes, optional
        Existing ``[upper, lower]`` axes per data type. If None, a new figure
        is created and shown for every epoch.
    colorbar : bool, optional
        If True (default), adds a wavelength colorbar.
    savefig : pathlib.Path, optional
        The save path; the epoch is appended to the file stem
        (``<stem>_t<t><suffix>``). The default is None.

    Raises
    ------
    ValueError
        If no components are given.
    """
    if not components:
        raise ValueError("At least one model component is required.")

    data_to_plot = OPTIONS.fit.data if data_to_plot is None else data_to_plot
    ylims = {} if ylims is None else ylims
    savefig = Path(savefig) if savefig is not None else None
    background = OPTIONS.plot.color.background

    # "vis" and "vis2" share one panel. It is registered under the key that
    # is actually plotted, which ``set_axis_information`` looks up.
    vis_key = "vis" if "vis" in data_to_plot else "vis2"
    flux, t3 = OPTIONS.data.flux, OPTIONS.data.t3
    vis = OPTIONS.data.vis if vis_key == "vis" else OPTIONS.data.vis2
    wls = OPTIONS.fit.wls
    norm = LogNorm(vmin=wls[0].value, vmax=wls[-1].value)

    data_types = []
    for key in data_to_plot:
        if key in ("vis", "vis2"):
            key = vis_key
        if key not in data_types:
            data_types.append(key)
    nplots = len(data_types)

    # The model and its geometry do not depend on the epoch.
    model_flux, model_vis, model_t3 = compute_observables(components)
    # NOTE: This won't work with differing cinc and pa
    cinc, pa = components[0].cinc(), components[0].pa()

    for t in range(OPTIONS.data.nt):
        if ax is None:
            figsize = (16, 5) if nplots == 3 else ((12, 5) if nplots == 2 else None)
            fig = plt.figure(figsize=figsize, facecolor=background)
            gs = GridSpec(2, nplots, height_ratios=[2.5, 1.5], hspace=0.00)
            panels = [
                [fig.add_subplot(gs[j, i], facecolor=background) for j in range(2)]
                for i in range(nplots)
            ]
        else:
            # Draw into the caller's figure instead of creating a stray one.
            panels = ax
            first = panels[0]
            if isinstance(first, (list, tuple, np.ndarray)):
                first = first[0]
            fig = first.figure

        axarr = dict(zip(data_types, panels))
        plot_kwargs = {
            "norm": norm,
            "colormap": cmap,
            "ylims": ylims,
            "bands": bands,
            "cinc": cinc,
        }
        if "flux" in data_to_plot:
            plot_data_vs_model(
                axarr,
                wls,
                flux.val[t],
                flux.err[t],
                "flux",
                model_val=model_flux[t],
                **plot_kwargs,
            )

        if "vis" in data_to_plot or "vis2" in data_to_plot:
            baselines = np.hypot(*transform_coordinates(vis.u[t], vis.v[t], cinc, pa))
            plot_data_vs_model(
                axarr,
                wls,
                vis.val[t],
                vis.err[t],
                vis_key,
                baselines=baselines[:, 1:],
                model_val=model_vis[t],
                **plot_kwargs,
            )

        if "t3" in data_to_plot:
            baselines = np.hypot(*transform_coordinates(t3.u[t], t3.v[t], cinc, pa))
            baselines = baselines[t3.i123[t]].T.max(1).reshape(1, -1)
            plot_data_vs_model(
                axarr,
                wls,
                t3.val[t],
                t3.err[t],
                "t3",
                baselines=baselines[:, 1:],
                model_val=model_t3[t],
                **plot_kwargs,
            )

        if colorbar:
            sm = cm.ScalarMappable(cmap=get_colormap(cmap), norm=norm)
            sm.set_array([])
            cbar = fig.colorbar(sm, ax=axarr[data_types[-1]])
            cbar.set_ticks(OPTIONS.plot.ticks)
            cbar.set_ticklabels(
                [f"{wavelength:.1f}" for wavelength in OPTIONS.plot.ticks]
            )

            if background == "black":
                cbar.ax.yaxis.set_tick_params(color="white")
                plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="white")
                for spine in cbar.ax.spines.values():
                    spine.set_edgecolor("white")

            text_color = "white" if background == "black" else "black"
            cbar.set_label(label=r"$\lambda$ ($\mathrm{\mu}$m)", color=text_color)

        if title is not None:
            last_panel = axarr[data_types[-1]]
            if isinstance(last_panel, (list, tuple, np.ndarray)):
                last_panel = last_panel[0]
            last_panel.set_title(title)

        if savefig is not None:
            fig.savefig(
                savefig.parent / f"{savefig.stem}_t{t}{savefig.suffix}",
                format=savefig.suffix[1:],
                dpi=OPTIONS.plot.dpi,
                bbox_inches="tight",
            )

        if ax is None:
            plt.show()

    # TODO: Implement plt.close() again here
|
698
|
+
|
699
|
+
|
700
|
+
def plot_overview(
    data_to_plot: List[str | None] = None,
    colormap: str = OPTIONS.plot.color.colormap,
    ylims: Dict[str, List[float]] | None = None,
    title: str | None = None,
    cinc: float | None = None,
    pa: float | None = None,
    bands: List[str] | str = "all",
    colorbar: bool = True,
    axarr: Axes | None = None,
    savefig: Path | None = None,
) -> None:
    """Plots an overview over the total data for baselines [Mlambda].

    One panel per data type is drawn for every epoch (``OPTIONS.data.nt``).

    Parameters
    ----------
    data_to_plot : list of str, optional
        The data to plot. The default is OPTIONS.fit.data.
    colormap : str, optional
        The colormap.
    ylims : dict of list of float, optional
        The y-limits for the individual keys.
    title : str, optional
        The title, set on the last panel. The default is None.
    cinc : float, optional
        Passed with ``pa`` to ``transform_coordinates`` for the baselines.
    pa : float, optional
        Passed with ``cinc`` to ``transform_coordinates`` for the baselines.
    bands : list of str or str, optional
        The bands to be plotted. The default is "all".
    colorbar : bool, optional
        If True (default), adds a wavelength colorbar.
    axarr : matplotlib.axes.Axes or sequence of Axes, optional
        Existing axes, one per data type. If None, a new figure is created
        per epoch.
    savefig : pathlib.Path, optional
        The save path; the epoch is appended to the file stem
        (``<stem>_t<t><suffix>``). If None and no axes were given, the
        figures are shown instead.
    """
    data_to_plot = OPTIONS.fit.data if data_to_plot is None else data_to_plot
    ylims = {} if ylims is None else ylims
    savefig = Path(savefig) if savefig is not None else None
    background = OPTIONS.plot.color.background
    wls = OPTIONS.fit.wls
    norm = LogNorm(vmin=wls[0].value, vmax=wls[-1].value)

    # "vis" and "vis2" share one panel. It is registered under the key that
    # is actually plotted, which ``set_axis_information`` looks up.
    vis_key = "vis" if "vis" in data_to_plot else "vis2"
    data_types = []
    for key in data_to_plot:
        if key in ("vis", "vis2"):
            key = vis_key
        if key not in data_types:
            data_types.append(key)
    nplots = len(data_types)

    flux, t3 = OPTIONS.data.flux, OPTIONS.data.t3
    vis = OPTIONS.data.vis if vis_key == "vis" else OPTIONS.data.vis2

    created_figures = []
    for t in range(OPTIONS.data.nt):
        # Keep the ``axarr`` argument untouched so every epoch (and the final
        # show/close decision) sees what the caller passed.
        if axarr is None:
            figsize = (15, 5) if nplots == 3 else ((12, 5) if nplots == 2 else None)
            fig, axes = plt.subplots(
                1,
                nplots,
                figsize=figsize,
                tight_layout=True,
                facecolor=background,
            )
            created_figures.append(fig)
        else:
            axes = axarr

        if isinstance(axes, np.ndarray):
            axes = list(axes.flatten())
        elif not isinstance(axes, (list, tuple)):
            axes = [axes]

        if axarr is not None:
            first = axes[0]
            if isinstance(first, (list, tuple, np.ndarray)):
                first = first[0]
            fig = first.figure

        panels = dict(zip(data_types, axes))
        plot_kwargs = {
            "norm": norm,
            "colormap": colormap,
            "ylims": ylims,
            "bands": bands,
            "cinc": cinc,
        }
        if "flux" in data_to_plot:
            plot_data_vs_model(
                panels,
                wls,
                flux.val[t],
                flux.err[t],
                "flux",
                **plot_kwargs,
            )

        if "vis" in data_to_plot or "vis2" in data_to_plot:
            baselines = np.hypot(*transform_coordinates(vis.u[t], vis.v[t], cinc, pa))
            plot_data_vs_model(
                panels,
                wls,
                vis.val[t],
                vis.err[t],
                vis_key,
                baselines=baselines[:, 1:],
                **plot_kwargs,
            )

        if "t3" in data_to_plot:
            baselines = np.hypot(*transform_coordinates(t3.u[t], t3.v[t], cinc, pa))
            baselines = baselines[t3.i123[t]].T.max(1).reshape(1, -1)
            plot_data_vs_model(
                panels,
                wls,
                t3.val[t],
                t3.err[t],
                "t3",
                baselines=baselines[:, 1:],
                **plot_kwargs,
            )

        if colorbar:
            # Resolve the colormap like ``plot_data_vs_model`` does so the
            # colorbar matches the plotted points.
            sm = cm.ScalarMappable(cmap=get_colormap(colormap), norm=norm)
            sm.set_array([])
            cbar = fig.colorbar(sm, ax=panels[data_types[-1]])

            # TODO: Set the ticks, but make it so that it is flexible for the band
            cbar.set_ticks(OPTIONS.plot.ticks)
            cbar.set_ticklabels(
                [f"{wavelength:.1f}" for wavelength in OPTIONS.plot.ticks]
            )

            if background == "black":
                cbar.ax.yaxis.set_tick_params(color="white")
                plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="white")
                for spine in cbar.ax.spines.values():
                    spine.set_edgecolor("white")
            opposite_color = "white" if background == "black" else "black"
            cbar.set_label(label=r"$\lambda$ ($\mathrm{\mu}$m)", color=opposite_color)

        if title is not None:
            last_panel = panels[data_types[-1]]
            if isinstance(last_panel, (list, tuple, np.ndarray)):
                last_panel = last_panel[0]
            last_panel.set_title(title)

        if savefig is not None:
            fig.savefig(
                savefig.parent / f"{savefig.stem}_t{t}{savefig.suffix}",
                format=savefig.suffix[1:],
                dpi=OPTIONS.plot.dpi,
            )

    if axarr is not None:
        # The caller owns the axes; showing/closing is left to them.
        return

    if savefig is None:
        plt.show()
    for fig in created_figures:
        plt.close(fig)
|
836
|
+
|
837
|
+
|
838
|
+
def plot_sed(
    wavelength_range: u.um,
    components: List[FourierComponent] | None = None,
    scaling: str = "nu",
    no_model: bool = False,
    ax: plt.Axes | None = None,
    savefig: Path | None = None,
):
    """Plot the spectral energy distribution of the data and, optionally, the model.

    Parameters
    ----------
    wavelength_range : astropy.units.Quantity
        The (start, end) wavelengths of the grid (``OPTIONS.plot.dim`` points)
        on which the model flux is evaluated.
    components : list of FourierComponent, optional
        The model components. Components named "Point Source" are excluded
        from the model SED. Required unless ``no_model`` is True.
    scaling : str, optional
        The scaling of the SED. "nu" plots nu * F_nu in W m^-2; any other value
        leaves the flux in Jy. The default is "nu".
    no_model : bool, optional
        If True, only the data in ``OPTIONS.data.readouts`` are plotted.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. If not given, a new figure is created and the axis
        labels and legend are added.
    savefig : pathlib.Path, optional
        If given, the figure is saved there (format taken from the suffix) and
        closed. If not given, a figure created here is shown; a caller-supplied
        ``ax`` is left open for further drawing.

    Raises
    ------
    ValueError
        If ``components`` is None while ``no_model`` is False.
    """
    import re

    color = OPTIONS.plot.color
    readouts = OPTIONS.data.readouts
    wavelength = np.linspace(wavelength_range[0], wavelength_range[1], OPTIONS.plot.dim)

    if not no_model:
        if components is None:
            raise ValueError("'components' must be given unless 'no_model' is True.")
        components = [comp for comp in components if comp.name != "Point Source"]
        flux = np.sum([comp.compute_flux(0, wavelength) for comp in components], axis=0)
        if flux.size > 0:
            flux = np.tile(flux, len(readouts)).real

    if ax is None:
        fig = plt.figure(facecolor=color.background, tight_layout=True)
        ax = plt.axes(facecolor=color.background)
        set_axes_color(ax, color.background)
    else:
        fig = None

    if len(readouts) > 1:
        # Several files: label each by the first YYYY-MM-DD date in its name.
        names = [
            re.findall(r"(\d{4}-\d{2}-\d{2})", readout.fits_file.name)[0]
            for readout in readouts
        ]
    else:
        names = [readouts[0].fits_file.name]

    # Sorted so that the colour progression follows the (date) name order
    # deterministically, instead of the arbitrary iteration order of a set.
    unique_names = sorted(set(names))
    cmap = plt.get_cmap(color.colormap)
    norm = mcolors.LogNorm(vmin=1, vmax=len(unique_names))
    date_to_color = {
        name: cmap(norm(index)) for index, name in enumerate(unique_names, start=1)
    }

    values = []
    for name, readout in sorted(zip(names, readouts), key=lambda pair: pair[0]):
        if readout.flux.val.size == 0:
            continue

        readout_wl = readout.wl.value
        readout_flux = readout.flux.val.flatten()
        readout_err = readout.flux.err.flatten()
        # Relative error, re-applied after the flux has been rescaled.
        relative_err = readout_err / readout_flux

        if scaling == "nu":
            # Jy -> W m^-2 Hz^-1, then multiply by nu = c / lambda
            # (readout.wl is taken to be in micron).
            readout_flux = (readout_flux * u.Jy).to(u.W / u.m**2 / u.Hz)
            readout_flux = (
                readout_flux * (const.c / ((readout_wl * u.um).to(u.m))).to(u.Hz)
            ).value

        readout_err = relative_err * readout_flux
        lower_err, upper_err = readout_flux - readout_err, readout_flux + readout_err
        line_color = date_to_color[name]

        if "HAW" in readout.fits_file.name:
            # Only the 4.55-4.9 and 3.1-3.9 micron windows are plotted.
            windows = [
                np.where((readout_wl >= 4.55) & (readout_wl <= 4.9)),
                np.where((readout_wl >= 3.1) & (readout_wl <= 3.9)),
            ]
        else:
            windows = [slice(None)]

        for indices in windows:
            ax.plot(readout_wl[indices], readout_flux[indices], color=line_color)
            ax.fill_between(
                readout_wl[indices],
                lower_err[indices],
                upper_err[indices],
                color=line_color,
                alpha=0.5,
            )
        # Only the plotted values contribute to the y-limit.
        values.append(np.concatenate([readout_flux[indices] for indices in windows]))

    # The label depends only on the scaling: the data are rescaled even when
    # no model is plotted.
    if scaling == "nu":
        flux_label = r"$\nu F_{\nu}$ (W m$^{-2}$)"
    else:
        flux_label = r"$F_{\nu}$ (Jy)"

    if not no_model:
        flux = flux[:, 0]
        if scaling == "nu":
            flux = (flux * u.Jy).to(u.W / u.m**2 / u.Hz)
            flux = (flux * (const.c / (wavelength.to(u.m))).to(u.Hz)).value

        ax.plot(wavelength, flux, label="Model", color="red")
        values.append(flux)

    if fig is not None:
        ax.set_xlabel(r"$\lambda$ ($\mathrm{\mu}$m)")
        ax.set_ylabel(flux_label)
        ax.legend()

    if values:
        max_value = np.concatenate(values).max()
        ax.set_ylim([0, max_value + 0.2 * max_value])

    if savefig is not None:
        ax.figure.savefig(
            savefig, format=Path(savefig).suffix[1:], dpi=OPTIONS.plot.dpi
        )
    elif fig is not None:
        plt.show()
    else:
        # The caller owns the axes: leave the figure open for further drawing.
        return

    plt.close(ax.figure)
|
962
|
+
|
963
|
+
|
964
|
+
def plot_product(
    points,
    product,
    xlabel,
    ylabel,
    save_path=None,
    ax=None,
    colorbar=False,
    cmap: str | None = None,
    scale=None,
    label=None,
):
    """Plot one or several curves of a model product against ``points``.

    Parameters
    ----------
    points : array_like
        The x-values shared by all curves.
    product : numpy.ndarray
        A 1D array (one curve) or a 2D array whose rows are plotted as
        separate curves.
    xlabel, ylabel : str
        The axis labels.
    save_path : pathlib.Path or str, optional
        If given, the figure is saved there (format taken from the suffix)
        and then closed.
    ax : matplotlib.axes.Axes, optional
        The axes to draw on. A new figure is created if not given.
    colorbar : bool, optional
        If True, add a wavelength colorbar (ticks from ``OPTIONS.plot.ticks``)
        instead of a legend.
    cmap : str, optional
        Colormap name. Defaults to ``OPTIONS.plot.color.colormap``, read at
        call time.
    scale : str, optional
        "log" (log y-axis), "loglog" (log x- and y-axis) or "sci"
        (scientific-notation y tick labels).
    label : str or array_like, optional
        Label(s) of the curve(s). If it is an array (or Quantity) for a 2D
        ``product``, row ``i`` is labelled ``label[i]`` and coloured by its
        value, normalised between ``label[0]`` and ``label[-1]``.
    """
    cmap = OPTIONS.plot.color.colormap if cmap is None else cmap

    norm, label_values = None, None
    if isinstance(label, (np.ndarray, u.Quantity)):
        # Plain arrays have no ``.value``; only strip units from Quantities.
        label_values = label.value if isinstance(label, u.Quantity) else np.asarray(label)
        norm = mcolors.Normalize(vmin=label_values[0], vmax=label_values[-1])

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure

    if product.ndim > 1:
        colormap = get_colormap(cmap) if norm is not None else None
        labels = [None] * len(product) if label is None else label
        for index, (lb, prod) in enumerate(zip(labels, product)):
            color = None if norm is None else colormap(norm(label_values[index]))
            ax.plot(points, prod, label=lb, color=color)
        if not colorbar and label is not None:
            ax.legend()
    else:
        ax.plot(points, product, label=label)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    if scale == "log":
        ax.set_yscale("log")
    elif scale == "loglog":
        ax.set_yscale("log")
        ax.set_xscale("log")
    elif scale == "sci":
        ax.yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True))
        ax.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))

    if colorbar:
        sm = cm.ScalarMappable(cmap=get_colormap(cmap), norm=norm)
        sm.set_array([])
        # Attach to this axes' figure, not whatever figure is current.
        cbar = fig.colorbar(sm, ax=ax)
        cbar.set_ticks(OPTIONS.plot.ticks)
        cbar.set_ticklabels([f"{wavelength:.1f}" for wavelength in OPTIONS.plot.ticks])
        cbar.set_label(label=r"$\lambda$ ($\mathrm{\mu}$m)")

    if save_path is not None:
        fig.savefig(save_path, format=Path(save_path).suffix[1:], dpi=OPTIONS.plot.dpi)
        plt.close(fig)
|
1021
|
+
|
1022
|
+
|
1023
|
+
def _interleave_with_gaps(segments, gaps):
    """Return ``[segments[0], gaps[0], segments[1], ..., segments[-1]]``.

    ``gaps`` is expected to hold one entry fewer than ``segments`` (the
    regions between consecutive components); the trailing ``None`` that
    ``zip_longest`` pads in that case is dropped.
    """
    return list(chain.from_iterable(zip_longest(segments, gaps)))[:-1]


# TODO: Clean and split this function into multiple ones
def plot_products(
    dim: int,
    components: List[FourierComponent],
    component_labels: List[str],
    save_dir: Path | None = None,
) -> None:
    """Plot the intermediate products of the model (fluxes, temperature, density, ...).

    For every index ``t`` in ``range(OPTIONS.data.nt)`` the PNG files
    ``fluxes_t{t}`` and ``flux_ratios_t{t}`` are written to ``save_dir``. If at
    least one component has a radial grid (i.e. is not a "Point", "Gauss" or
    "BBGauss" component), ``temperature_t{t}``, ``surface_density_t{t}`` and
    ``optical_depths_t{t}`` are written as well.

    Parameters
    ----------
    dim : int
        Number of wavelength points, and of radial points in each gap between
        consecutive components. Note: assigned to every component's
        ``dim`` parameter as a side effect.
    components : list of FourierComponent
        The model components.
    component_labels : list of str
        Labels of the components; "_"-separated words are title-cased.
    save_dir : pathlib.Path, optional
        Output directory. Defaults to the current working directory.
    """
    save_dir = Path.cwd() if save_dir is None else Path(save_dir)
    component_labels = [
        " ".join(map(str.title, label.split("_"))) for label in component_labels
    ]
    # Components without a radial grid; used both when collecting profiles
    # and when picking the component that provides the temperature.
    non_disc_names = ["Point", "Gauss", "BBGauss"]
    wavelength_label = r"$\lambda$ ($\mathrm{\mu}$m)"

    for t in range(OPTIONS.data.nt):
        wls = np.linspace(OPTIONS.fit.wls[0], OPTIONS.fit.wls[-1], dim)
        radii, surface_density, optical_depth, fluxes = [], [], [], []

        fig, ax = plt.subplots(figsize=(5, 5))
        for label, component in zip(component_labels, components):
            component.dim.value = dim
            flux = component.fr(t, wls) * component.compute_flux(t, wls).squeeze()
            plot_product(
                wls,
                flux,
                wavelength_label,
                r"$F_{\nu}$ (Jy)",
                scale="log",
                ax=ax,
                label=label,
            )
            fluxes.append(flux)
            if component.name in non_disc_names:
                continue

            radius = component.compute_internal_grid(t, wls)
            radii.append(radius)
            surface_density.append(component.compute_surface_density(radius, t, wls))
            optical_depth.append(
                component.compute_optical_depth(radius, t, wls[:, np.newaxis])
            )

        total_flux = np.sum(fluxes, axis=0)
        ax.plot(wls, total_flux, label="Total")
        ax.set_yscale("log")
        ax.set_ylim([1e-1, None])
        ax.legend()
        fig.savefig(save_dir / f"fluxes_t{t}.png", format="png", dpi=OPTIONS.plot.dpi)
        plt.close(fig)

        fig, ax = plt.subplots(figsize=(5, 5))
        for label, flux_ratio in zip(component_labels, np.array(fluxes) / total_flux):
            plot_product(
                wls,
                flux_ratio * 100,
                wavelength_label,
                r"$F_{\nu}$ / $F_{\nu,\,\mathrm{tot}}$ (%)",
                ax=ax,
                label=label,
            )

        ax.legend()
        ax.set_ylim([0, 100])
        fig.savefig(
            save_dir / f"flux_ratios_t{t}.png", format="png", dpi=OPTIONS.plot.dpi
        )
        plt.close(fig)

        if not radii:
            # No component with a radial grid: no radial profiles to plot.
            continue

        surface_density = u.Quantity(surface_density)
        optical_depth = u.Quantity(optical_depth)

        # Fill the radial gap between consecutive components with ``dim``
        # points whose profile values are zero, so each profile is drawn over
        # one continuous radial axis.
        radii_bounds = [
            (prev[-1], current[0]) for prev, current in zip(radii[:-1], radii[1:])
        ]
        fill_radii = [np.linspace(lower, upper, dim) for lower, upper in radii_bounds]
        merged_radii = u.Quantity(
            np.concatenate(_interleave_with_gaps(radii, fill_radii), axis=0)
        )
        fill_zeros = np.zeros((len(fill_radii), wls.size, dim))
        disc_component = next(
            comp for comp in components if comp.name not in non_disc_names
        )

        # TODO: Make the temperatures continuous across the gaps
        # (e.g. reuse the components' own temperatures or interpolate).
        temperature = disc_component.compute_temperature(merged_radii, t, wls)
        surface_density = np.concatenate(
            u.Quantity(
                _interleave_with_gaps(
                    surface_density, fill_zeros[:, 0, :] * u.g / u.cm**2
                )
            ),
            axis=0,
        )
        optical_depth = np.hstack(
            u.Quantity(_interleave_with_gaps(optical_depth, fill_zeros))
        )

        plot_product(
            merged_radii.value,
            temperature.value,
            "$R$ (AU)",
            "$T$ (K)",
            scale="log",
            save_path=save_dir / f"temperature_t{t}.png",
        )
        plot_product(
            merged_radii.value,
            surface_density.value,
            "$R$ (au)",
            r"$\Sigma$ (g cm$^{-2}$)",
            save_path=save_dir / f"surface_density_t{t}.png",
            scale="sci",
        )
        # Row i of ``optical_depth`` was computed at ``wls[i]``, so the curves
        # are labelled/coloured with that grid (unitless values are assumed
        # to be in micron -- TODO confirm against OPTIONS.fit.wls).
        plot_product(
            merged_radii.value,
            optical_depth.value,
            "$R$ (AU)",
            r"$\tau_{\nu}$",
            save_path=save_dir / f"optical_depths_t{t}.png",
            scale="log",
            colorbar=True,
            label=u.Quantity(wls, u.um),
        )
|