ppdmod 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ppdmod/plot.py ADDED
@@ -0,0 +1,1241 @@
1
+ from itertools import chain, zip_longest
2
+ from pathlib import Path
3
+ from typing import Dict, List, Tuple
4
+
5
+ import astropy.constants as const
6
+ import astropy.units as u
7
+ import corner
8
+ import matplotlib.cm as cm
9
+ import matplotlib.colors as mcolors
10
+ import matplotlib.lines as mlines
11
+ import matplotlib.pyplot as plt
12
+ import matplotlib.ticker as ticker
13
+ import numpy as np
14
+ from dynesty import DynamicNestedSampler, NestedSampler
15
+ from dynesty import plotting as dyplot
16
+ from matplotlib.axes import Axes
17
+ from matplotlib.gridspec import GridSpec
18
+ from matplotlib.legend import Legend
19
+
20
+ from .base import FourierComponent
21
+ from .fitting import compute_observables, get_best_fit
22
+ from .options import OPTIONS, get_colormap
23
+ from .utils import (
24
+ compare_angles,
25
+ get_band_indices,
26
+ transform_coordinates,
27
+ )
28
+
29
+
30
def get_best_plot_arrangement(nplots):
    """Gets the best plot arrangement for a given number of plots.

    The grid is kept as close to square as possible: the number of columns
    is the ceiling of the square root of ``nplots`` and the number of rows
    is the smallest value for which ``rows * cols >= nplots``.

    Parameters
    ----------
    nplots : int
        The number of plots to arrange. Must not be negative.

    Returns
    -------
    rows : int
        The number of rows.
    cols : int
        The number of columns.

    Raises
    ------
    ValueError
        If ``nplots`` is negative.
    """
    if nplots < 0:
        raise ValueError("The number of plots must not be negative.")

    # Nothing to arrange: an empty grid (also avoids dividing by zero below).
    if nplots == 0:
        return 0, 0

    cols = int(np.ceil(np.sqrt(nplots)))
    # Integer ceiling division: fewest rows that still hold every plot.
    rows = int(-(-nplots // cols))
    return rows, cols
46
+
47
+
48
def set_axes_color(
    ax: Axes,
    background_color: str,
    set_label: bool = True,
    direction: str | None = None,
) -> None:
    """Colours an axes for the given background.

    The face is painted in ``background_color``; spines, ticks and
    (optionally) the axis labels get the contrasting colour, which is
    white on a black background and black otherwise.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes to recolour.
    background_color : str
        The face colour of the axes.
    set_label : bool, optional
        If True, the x- and y-axis labels are recoloured as well.
        The default is True.
    direction : str, optional
        Forwarded as ``direction`` to ``ax.tick_params``. The default is None.
    """
    contrast = "black" if background_color != "black" else "white"

    ax.set_facecolor(background_color)
    for side in ("bottom", "top", "right", "left"):
        ax.spines[side].set_color(contrast)

    if set_label:
        for axis in (ax.xaxis, ax.yaxis):
            axis.label.set_color(contrast)

    ax.tick_params(axis="both", colors=contrast, direction=direction)
67
+
68
+
69
def set_legend_color(legend: Legend, background_color: str) -> None:
    """Colours a legend for the given background.

    The legend texts get the contrasting colour (white on a black
    background, black otherwise) and the frame is filled with
    ``background_color``.

    Parameters
    ----------
    legend : matplotlib.legend.Legend
        The legend to recolour.
    background_color : str
        The face colour of the legend frame.
    """
    text_color = "black"
    if background_color == "black":
        text_color = "white"

    plt.setp(legend.get_texts(), color=text_color)
    frame = legend.get_frame()
    frame.set_facecolor(background_color)
74
+
75
+
76
+ def format_labels(
77
+ labels: List[str], units: List[str] | None = None, split: bool = False
78
+ ) -> List[str]:
79
+ """Formats the labels in LaTeX.
80
+
81
+ Parameters
82
+ ----------
83
+ labels : list of str
84
+ The labels.
85
+ units : list of str, optional
86
+ The units. The default is None.
87
+ split : bool, optional
88
+ If True, splits into labels, units, and uncertainties.
89
+ The default is False.
90
+
91
+ Returns
92
+ -------
93
+ labels : list of str
94
+ The formatted labels.
95
+ units : list of str, optional
96
+ The formatted units. If split is True
97
+ """
98
+ nice_labels = {
99
+ "rin": {"letter": "R", "indices": [r"\text{in}"]},
100
+ "rout": {"letter": "R", "indices": [r"\text{out}"]},
101
+ "p": {"letter": "p"},
102
+ "q": {"letter": "q"},
103
+ "rho": {"letter": r"\rho"},
104
+ "theta": {"letter": r"\theta"},
105
+ "logsigma0": {"letter": r"\Sigma", "indices": ["0"]},
106
+ "sigma0": {"letter": r"\Sigma", "indices": ["0"]},
107
+ "weight_cont": {"letter": "w", "indices": [r"\text{cont}"]},
108
+ "pa": {"letter": r"\theta", "indices": []},
109
+ "cinc": {"letter": r"\cos\left(i\right)"},
110
+ "temp0": {"letter": "T", "indices": ["0"]},
111
+ "tempc": {"letter": "T", "indices": [r"\text{c}"]},
112
+ "f": {"letter": "f"},
113
+ "fr": {"letter": "fr"},
114
+ "fwhm": {"letter": r"\sigma"},
115
+ "r": {"letter": "r"},
116
+ "phi": {"letter": r"\phi"},
117
+ }
118
+
119
+ formatted_labels = []
120
+ for label in labels:
121
+ if "-" in label:
122
+ name, index = label.split("-")
123
+ else:
124
+ name, index = label, ""
125
+
126
+ if name in nice_labels or name[-1].isdigit():
127
+ if ".t" in name:
128
+ name, time_index = name.split(".")
129
+ else:
130
+ time_index = None
131
+
132
+ if name not in nice_labels and name[-1].isdigit():
133
+ letter = nice_labels[name[:-1]]["letter"]
134
+ indices = [name[-1]]
135
+ if index:
136
+ indices.append(index)
137
+ else:
138
+ letter = nice_labels[name]["letter"]
139
+ if name in ["temp0", "tempc"]:
140
+ indices = nice_labels[name].get("indices", [])
141
+ else:
142
+ indices = [*nice_labels[name].get("indices", [])]
143
+ if index:
144
+ indices.append(rf"\mathrm{{{index}}}")
145
+
146
+ if time_index is not None:
147
+ indices.append(rf"\mathrm{{{time_index}}}")
148
+
149
+ indices = r",\,".join(indices)
150
+ formatted_label = f"{letter}_{{{indices}}}"
151
+ if "log" in label:
152
+ formatted_label = rf"\log_{{10}}\left({formatted_label}\right)"
153
+
154
+ formatted_labels.append(f"$ {formatted_label} $")
155
+ else:
156
+ if "weight" in name:
157
+ name, letter = name.replace("weight", ""), "w"
158
+
159
+ indices = []
160
+ if "small" in name:
161
+ name = name.replace("small", "")
162
+ indices = [r"\text{small}"]
163
+ elif "large" in name:
164
+ name = name.replace("large", "")
165
+ indices = [r"\text{large}"]
166
+ name = name.replace("_", "")
167
+ indices.append(rf"\text{{{name}}}")
168
+
169
+ indices = r",\,".join(indices)
170
+ formatted_label = f"{letter}_{{{indices}}}"
171
+ if "log" in label:
172
+ formatted_label = rf"\log_{{10}}\left({formatted_label}\right)"
173
+ elif "scale" in name:
174
+ formatted_label = rf"w_{{\text{{{name.replace('scale_', '')}}}}}"
175
+ elif "lnf" in name:
176
+ formatted_label = (
177
+ rf"\ln\left(f\right)_{{\text{{{name.split('_')[0]}}}}}"
178
+ )
179
+ else:
180
+ formatted_label = label
181
+
182
+ formatted_labels.append(f"$ {formatted_label} $")
183
+
184
+ if units is not None:
185
+ reformatted_units = []
186
+ for unit in units:
187
+ if unit == u.g / u.cm**2:
188
+ unit = r"\si{\gram\per\square\centi\metre}"
189
+ elif unit == u.au:
190
+ unit = r"\si{\astronomicalunit}"
191
+ elif unit == u.deg:
192
+ unit = r"\si{\degree}"
193
+ elif unit == u.pct:
194
+ unit = r"\si{\percent}"
195
+
196
+ reformatted_units.append(unit)
197
+
198
+ reformatted_units = [
199
+ rf"$ (\text{{{str(unit).strip()}}}) $" if str(unit) else ""
200
+ for unit in reformatted_units
201
+ ]
202
+ if split:
203
+ return formatted_labels, reformatted_units
204
+
205
+ formatted_labels = [
206
+ rf"{label} {unit}"
207
+ for label, unit in zip(formatted_labels, reformatted_units)
208
+ ]
209
+ return formatted_labels
210
+
211
+
212
def needs_sci_notation(ax):
    """Checks if scientific notation is needed.

    Scientific notation is deemed necessary as soon as the absolute value
    of any of the four axis limits (x-min, x-max, y-min, y-max) is at
    most 1e-3.
    """
    threshold = 1e-3
    limits = (*ax.get_xlim(), *ax.get_ylim())
    return any(abs(limit) <= threshold for limit in limits)
222
+
223
+
224
def get_exponent(num: float) -> int:
    """Gets the exponent of a number for scientific notation.

    The returned value is the *negated* decimal exponent, i.e. the power
    of ten the number has to be multiplied with to bring its magnitude
    into the range [1, 10): ``num * 10**get_exponent(num)``.

    Parameters
    ----------
    num : float
        The number. The sign is ignored.

    Returns
    -------
    exponent : int
        The negated decimal exponent, e.g. 4 for 2e-4 and -2 for 250.

    Raises
    ------
    ValueError
        If the number is zero.
    """
    if num == 0:
        raise ValueError("Number must be non-zero")

    # Use the magnitude so that negative numbers do not end up in log10.
    return -int(np.floor(np.log10(abs(num))))
232
+
233
+
234
def plot_corner(
    sampler,
    labels: List[str],
    units: List[str] | None = None,
    fontsize: int = 12,
    discard: int = 0,
    savefig: Path | None = None,
    **kwargs,
) -> None:
    """Plots the corner of the posterior spread.

    For the "dynesty" fitter the corner plot is made with dynesty's
    plotting module and the titles of small-valued parameters
    (0 < |value| <= 1e-3) are rewritten in scientific notation. For any
    other fitter the flattened chain is passed to ``corner.corner``.

    Parameters
    ----------
    sampler :
        The sampler. Needs a ``results`` attribute for "dynesty" and a
        ``get_chain`` method otherwise.
    labels : list of str
        The parameter labels.
    units : list of str, optional
        The parameter units. The default is None.
    fontsize : int, optional
        The fontsize. The default is 12.
    discard : int, optional
        The number of steps to discard from the chain (non-dynesty
        fitters only). The default is 0.
    savefig : pathlib.Path, optional
        The save path. The default is None.
    **kwargs
        Accepted for call compatibility; not used.
    """
    labels = format_labels(labels, units)
    quantiles = [x / 100 for x in OPTIONS.fit.quantiles]
    if OPTIONS.fit.fitter == "dynesty":
        results = sampler.results
        _, axarr = dyplot.cornerplot(
            results,
            color="blue",
            labels=labels,
            show_titles=True,
            max_n_ticks=3,
            title_quantiles=quantiles,
            quantiles=quantiles,
        )

        theta, uncertainties = get_best_fit(sampler)
        for index, row in enumerate(axarr):
            for ax in row:
                if ax is None:
                    continue

                if needs_sci_notation(ax):
                    if "Sigma" in ax.get_xlabel():
                        ax.xaxis.set_major_formatter(ticker.ScalarFormatter())
                        ax.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))

                    if "Sigma" in ax.get_ylabel():
                        ax.yaxis.set_major_formatter(ticker.ScalarFormatter())
                        ax.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
                        ax.yaxis.get_offset_text().set_position((-0.2, 0))

                # Only axes that carry a title are retitled. A value of
                # exactly zero has no exponent and keeps its title.
                title = ax.get_title()
                if not title or not 0 < np.abs(theta[index]) <= 1e-3:
                    continue

                # The magnitude is passed so that negative values work too.
                exponent = get_exponent(np.abs(theta[index]))
                factor = 10.0**exponent
                formatted_title = (
                    rf"${theta[index] * factor:.2f}_{{-{uncertainties[index][0] * factor:.2f}}}"
                    rf"^{{+{uncertainties[index][1] * factor:.2f}}}\,1\mathrm{{e}}-{exponent}$"
                )
                ax.set_title(
                    f"{labels[index]} = {formatted_title}",
                    fontsize=fontsize - 2,
                )
    else:
        samples = sampler.get_chain(discard=discard, flat=True)
        corner.corner(samples, labels=labels)

    if savefig is not None:
        plt.savefig(savefig, format=Path(savefig).suffix[1:], dpi=OPTIONS.plot.dpi)
    plt.close()
304
+
305
+
306
def plot_chains(
    sampler: NestedSampler | DynamicNestedSampler,
    labels: List[str],
    units: List[str] | None = None,
    savefig: Path | None = None,
    **kwargs,
) -> None:
    """Plots the fitter's chains with dynesty's trace plot.

    Parameters
    ----------
    sampler : dynesty.NestedSampler or dynesty.DynamicNestedSampler
        The sampler whose ``results`` are plotted.
    labels : list of str
        The parameter labels.
    units : list of str, optional
        The parameter units. The default is None.
    savefig : pathlib.Path, optional
        The save path. If not given the figure is shown instead.
        The default is None.
    **kwargs
        Accepted for call compatibility; not used.
    """
    formatted = format_labels(labels, units)
    fractions = [percent / 100 for percent in OPTIONS.fit.quantiles]
    trace_options = {
        "labels": formatted,
        "truths": np.zeros(len(formatted)),
        "quantiles": fractions,
        "truth_color": "black",
        "show_titles": True,
        "trace_cmap": "viridis",
        "connect": True,
        "connect_highlight": range(5),
    }
    dyplot.traceplot(sampler.results, **trace_options)

    if not savefig:
        plt.show()
    else:
        plt.savefig(savefig, format=Path(savefig).suffix[1:], dpi=OPTIONS.plot.dpi)
    plt.close()
346
+
347
+
348
class LogNorm(mcolors.Normalize):
    """Logarithmic normalisation based on ``log1p``.

    Values are shifted by ``vmin`` before the logarithm is taken, so that
    ``vmin`` maps to 0 and ``vmax`` maps to 1.
    """

    def __init__(self, vmin=None, vmax=None, clip=False):
        super().__init__(vmin, vmax, clip)

    def __call__(self, value, clip=None):
        """Normalises the value(s); NaN results are masked.

        The ``clip`` argument is accepted for interface compatibility
        but not used.
        """
        shifted = np.log1p(value - self.vmin)
        span = np.log1p(self.vmax - self.vmin)
        if span == 0:
            # Degenerate range (vmin == vmax): map everything to 0 instead
            # of dividing by zero, which would mask every value.
            normalized_value = shifted * 0.0
        else:
            normalized_value = shifted / span
        return np.ma.masked_array(normalized_value, np.isnan(normalized_value))

    def inverse(self, value):
        """Maps normalised value(s) back to the data range."""
        return np.expm1(value * np.log1p(self.vmax - self.vmin)) + self.vmin
360
+
361
+
362
def set_axis_information(
    axarr: Dict[str, List[Axes]],
    key: str,
    cinc=None,
) -> Tuple[Axes, Axes]:
    """Labels the axes belonging to an observable.

    Parameters
    ----------
    axarr : dict
        Maps the observable key either to a single axes or to an
        (upper, lower) pair, the lower one being the residual panel.
    key : str
        The observable: "flux", "vis", "vis2" or "t3".
    cinc : optional
        If not None the baseline axis of "vis"/"vis2" is labelled as the
        effective baseline. The default is None.

    Returns
    -------
    upper_ax : matplotlib.axes.Axes
        The main axes.
    lower_ax : matplotlib.axes.Axes or None
        The residual axes, or None if there is just a single axes.
    """
    panel = axarr[key]
    if isinstance(panel, (tuple, list, np.ndarray)):
        upper_ax, lower_ax = panel
    else:
        upper_ax, lower_ax = panel, None

    has_residual_panel = lower_ax is not None

    if key == "flux":
        xlabel = r"$ \lambda (\mathrm{\mu}\text{m}) $"
        ylabel = r"$ F_{\nu} $ (Jy)"
        residual_label = "Residuals (Jy)"

    elif key in ("vis", "vis2"):
        if cinc is None:
            xlabel = r"$ B (\text{M}\lambda)$"
        else:
            xlabel = r"$ B_{\text{eff}} (\text{M}\lambda) $"

        if key != "vis":
            ylabel = "$ V^{2} $ (a.u.)"
            residual_label = "Residuals (a.u.)"
            upper_ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.2f"))
        else:
            ylabel = r"$ F_{\nu,\,\text{corr}} $ (Jy)"
            residual_label = "Residuals (Jy)"

    elif key == "t3":
        xlabel = r"$ B_{\text{max}} (\text{M}\lambda) $"
        ylabel = r"$ \Phi_{\text{cp}} (^{\circ}) $"
        residual_label = r"Residuals $ (^{\circ}) $"

    # The upper panel only shows x tick labels if there is no residual panel.
    upper_ax.tick_params(
        axis="x",
        which="both",
        bottom=True,
        top=False,
        labelbottom=not has_residual_panel,
    )
    upper_ax.set_ylabel(ylabel)
    if has_residual_panel:
        lower_ax.set_xlabel(xlabel)
        lower_ax.set_ylabel(residual_label)
    else:
        upper_ax.set_xlabel(xlabel)

    return upper_ax, lower_ax
414
+
415
+
416
def plot_data_vs_model(
    axarr,
    wavelengths: np.ndarray,
    val: np.ndarray,
    err: np.ndarray,
    key: str,
    baselines: np.ndarray | None = None,
    model_val: np.ndarray | None = None,
    colormap: str = OPTIONS.plot.color.colormap,
    bands: List[str] | str = "all",
    cinc: float | None = None,
    ylims: Dict = {},
    norm=None,
):
    """Plots the data versus the model or just the data if not model data given.

    Parameters
    ----------
    axarr : dict
        Maps the observable keys to their axes; passed on to
        ``set_axis_information`` together with ``key``.
    wavelengths : astropy.units.Quantity
        The wavelengths (``.value`` is accessed), one per row of ``val``.
    val : numpy.ma.MaskedArray
        The data values. Its ``mask`` is applied to ``model_val``.
    err : numpy.ndarray
        The data errors.
    key : str
        The observable: "flux", "vis", "vis2" or "t3".
    baselines : numpy.ndarray, optional
        The baselines. They are divided by the wavelength to form the x-axis.
        If None the wavelength itself is used as x-axis. The default is None.
    model_val : numpy.ndarray, optional
        The model values. The default is None.
    colormap : str, optional
        The colormap name, resolved with ``get_colormap``.
    bands : list of str or str, optional
        The bands to be plotted. The default is "all".
    cinc : float, optional
        Only forwarded to ``set_axis_information``. The default is None.
    ylims : dict, optional
        The y-limits of the upper axes, looked up by ``key``.
    norm : callable, optional
        Maps the wavelengths to the colormap range.
        NOTE(review): it is called unconditionally, so leaving it at its
        default of None raises a TypeError.
    """
    upper_ax, lower_ax = set_axis_information(axarr, key, cinc)
    colormap, alpha = get_colormap(colormap), 1 if lower_ax is None else 0.55
    hline_color = "gray" if OPTIONS.plot.color.background == "white" else "white"
    # NOTE(review): these look like the shared option objects, not copies, so
    # the attribute assignments below presumably persist beyond this call.
    errorbar_params, scatter_params = OPTIONS.plot.errorbar, OPTIONS.plot.scatter
    if OPTIONS.plot.color.background == "black":
        errorbar_params.markeredgecolor = "white"
        scatter_params.edgecolor = "white"

    if model_val is not None:
        model_val = np.ma.masked_array(model_val, mask=val.mask)

    # Select the wavelengths (rows) to plot.
    if bands == "all" or bands is None:
        band_indices = np.where(np.ones_like(wavelengths.value).astype(bool))[0]
    else:
        band_indices = get_band_indices(wavelengths.value, bands)

    wavelengths = wavelengths[band_indices]
    val, err = val[band_indices], err[band_indices]
    if model_val is not None:
        model_val = model_val[band_indices]

    set_axes_color(upper_ax, OPTIONS.plot.color.background)
    color = colormap(norm(wavelengths.value))
    # The x-axis: the wavelength itself or the baseline divided by wavelength.
    if baselines is None:
        grid = [wl.repeat(val.shape[-1]) for wl in wavelengths.value]
    else:
        grid = baselines / wavelengths.value[:, np.newaxis]

    # Running extrema (starting from zero) for the axis limits set below.
    ymin, ymax = 0, 0
    ymin_res, ymax_res = 0, 0
    for index, _ in enumerate(wavelengths.value):
        errorbar_params.color = scatter_params.color = color[index]
        upper_ax.errorbar(
            grid[index],
            val[index],
            err[index],
            fmt="o",
            **vars(errorbar_params),
        )

        ymin = min(ymin, np.nanmin(val[index]))
        ymax = max(ymax, np.nanmax(val[index]))
        # Model points and residuals are only drawn if there is a lower axes.
        if model_val is not None and lower_ax is not None:
            upper_ax.scatter(
                grid[index],
                model_val[index],
                marker="X",
                alpha=alpha,
                **vars(scatter_params),
            )

            if key == "t3":
                upper_ax.axhline(0, color="grey", linestyle="--")
                # Angular residuals: compared in radians, shown in degrees.
                residuals = np.rad2deg(
                    compare_angles(
                        np.deg2rad(val[index]),
                        np.deg2rad(model_val[index]),
                    )
                )
            else:
                residuals = val[index] - model_val[index]

            residual_errs = err[index]

            ymin = min(ymin, np.nanmin(model_val[index]))
            ymax = max(ymax, np.nanmax(model_val[index]))
            ymin_res = min(ymin_res, np.nanmin(residuals))
            ymax_res = max(ymax_res, np.nanmax(residuals))

            lower_ax.errorbar(
                grid[index],
                residuals,
                residual_errs,
                fmt="o",
                **vars(errorbar_params),
            )
            lower_ax.axhline(y=0, color=hline_color, linestyle="--")

    # Pad the limits by 25 percent.
    ymin, ymax = ymin - np.abs(ymin) * 0.25, ymax + ymax * 0.25
    if key in ["flux", "vis"]:
        ylim = ylims.get(key, [0, ymax])
    elif key == "vis2":
        ylim = ylims.get(key, [0, 1])
    else:
        ylim = ylims.get("t3", [ymin, ymax])

    upper_ax.set_ylim(ylim)
    # TODO: Improve the residual plots
    if lower_ax is not None:
        upper_ax.tick_params(axis="x", which="both", direction="in")
        ymin_res, ymax_res = (
            ymin_res - np.abs(ymin_res) * 0.25,
            ymax_res + ymax_res * 0.25,
        )
        # NOTE(review): tick_diff is not used anywhere below.
        tick_diff = np.diff(upper_ax.get_yticks())[0]
        lower_ax.set_ylim((ymin_res, ymax_res))

    # A data/model legend is only added if there is just a single observable.
    if not len(axarr) > 1:
        label_color = "lightgray" if OPTIONS.plot.color.background == "black" else "k"
        dot_label = mlines.Line2D(
            [],
            [],
            color=label_color,
            marker="o",
            linestyle="None",
            label="Data",
            alpha=0.6,
        )
        x_label = mlines.Line2D(
            [], [], color=label_color, marker="X", linestyle="None", label="Model"
        )
        legend = upper_ax.legend(handles=[dot_label, x_label])
        set_legend_color(legend, OPTIONS.plot.color.background)

    # Reset the per-wavelength colour set inside the loop.
    errorbar_params.color = scatter_params.color = None
546
+
547
+
548
def plot_fit(
    components: List | None = None,
    data_to_plot: List[str | None] | None = None,
    cmap: str = OPTIONS.plot.color.colormap,
    ylims: Dict[str, List[float]] = {},
    bands: List[str] | str = "all",
    title: str | None = None,
    ax: List[List[Axes]] | None = None,
    colorbar: bool = True,
    savefig: Path | None = None,
):
    """Plots the deviation of a model from real data of an object for
    total flux, visibilities and closure phases.

    One figure is made per time index in ``range(OPTIONS.data.nt)``.

    Parameters
    ----------
    components : list, optional
        The model components. They are passed to ``compute_observables`` and
        ``cinc`` and ``pa`` are read from the first one.
        NOTE(review): the default of None is indexed, which raises a TypeError.
    data_to_plot : list of str, optional
        The data to plot. The default is OPTIONS.fit.data.
    cmap : str, optional
        The colormap.
    ylims : dict of list of float, optional
        The ylimits for the individual keys.
    bands : list of str or str, optional
        The bands to be plotted. The default is "all".
    title : str, optional
        The title. The default is None.
    ax : list of list of matplotlib.axes.Axes, optional
        The axes to plot on, one entry per observable. If None a grid of an
        upper and a lower (residual) axes per observable is created.
    colorbar : bool, optional
        If True, adds a wavelength colorbar. The default is True.
    savefig : pathlib.Path, optional
        The save path. The time index is appended to the file stem.
        The default is None.
    """
    data_to_plot = OPTIONS.fit.data if data_to_plot is None else data_to_plot
    flux, t3 = OPTIONS.data.flux, OPTIONS.data.t3
    vis = OPTIONS.data.vis if "vis" in data_to_plot else OPTIONS.data.vis2
    nts, wls = range(OPTIONS.data.nt), OPTIONS.fit.wls
    norm = LogNorm(vmin=wls[0].value, vmax=wls[-1].value)

    # NOTE(review): the first of "vis"/"vis2" is stored under the key "vis",
    # while plot_data_vs_model below may be called with "vis2" -- confirm that
    # the axes lookup works if only "vis2" is requested.
    data_types, nplots = [], 0
    for key in data_to_plot:
        if key in ["vis", "vis2"] and "vis" not in data_types:
            data_types.append("vis")
        else:
            data_types.append(key)
        nplots += 1

    for t in nts:
        model_flux, model_vis, model_t3 = compute_observables(components)

        # NOTE: This won't work with differing cinc and pa
        cinc, pa = components[0].cinc(), components[0].pa()
        figsize = (16, 5) if nplots == 3 else ((12, 5) if nplots == 2 else None)
        # NOTE(review): the figure is created even if axes are passed in.
        fig = plt.figure(figsize=figsize, facecolor=OPTIONS.plot.color.background)
        if ax is None:
            # One column per observable: data on top, residuals below.
            gs = GridSpec(2, nplots, height_ratios=[2.5, 1.5], hspace=0.00)
            axarr = [
                [
                    fig.add_subplot(gs[j, i], facecolor=OPTIONS.plot.color.background)
                    for j in range(2)
                ]
                for i in range(nplots)
            ]
        else:
            axarr = ax

        axarr = dict(zip(data_types, axarr))
        plot_kwargs = {"norm": norm, "colormap": cmap}
        if "flux" in data_to_plot:
            plot_data_vs_model(
                axarr,
                wls,
                flux.val[t],
                flux.err[t],
                "flux",
                ylims=ylims,
                bands=bands,
                model_val=model_flux[t],
                cinc=cinc,
                **plot_kwargs,
            )

        if "vis" in data_to_plot or "vis2" in data_to_plot:
            baselines = np.hypot(*transform_coordinates(vis.u[t], vis.v[t], cinc, pa))
            # NOTE(review): the first baseline column is dropped -- presumably
            # the zero baseline; confirm against the data layout.
            plot_data_vs_model(
                axarr,
                wls,
                vis.val[t],
                vis.err[t],
                "vis" if "vis" in data_to_plot else "vis2",
                ylims=ylims,
                bands=bands,
                baselines=baselines[:, 1:],
                model_val=model_vis[t],
                cinc=cinc,
                **plot_kwargs,
            )

        if "t3" in data_to_plot:
            baselines = np.hypot(*transform_coordinates(t3.u[t], t3.v[t], cinc, pa))
            # The longest baseline of each triangle.
            baselines = baselines[t3.i123[t]].T.max(1).reshape(1, -1)
            plot_data_vs_model(
                axarr,
                wls,
                t3.val[t],
                t3.err[t],
                "t3",
                ylims=ylims,
                bands=bands,
                baselines=baselines[:, 1:],
                model_val=model_t3[t],
                cinc=cinc,
                **plot_kwargs,
            )

        if colorbar:
            # The colorbar is attached to the axes of the last observable.
            sm = cm.ScalarMappable(cmap=get_colormap(cmap), norm=norm)
            sm.set_array([])
            cbar = plt.colorbar(sm, ax=axarr[data_types[-1]])
            cbar.set_ticks(OPTIONS.plot.ticks)
            cbar.set_ticklabels(
                [f"{wavelength:.1f}" for wavelength in OPTIONS.plot.ticks]
            )

            if OPTIONS.plot.color.background == "black":
                cbar.ax.yaxis.set_tick_params(color="white")
                plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="white")
                for spine in cbar.ax.spines.values():
                    spine.set_edgecolor("white")

            text_color = (
                "white" if OPTIONS.plot.color.background == "black" else "black"
            )
            cbar.set_label(label=r"$\lambda$ ($\mathrm{\mu}$m)", color=text_color)

        if title is not None:
            plt.title(title)

        if savefig is not None:
            plt.savefig(
                savefig.parent / f"{savefig.stem}_t{t}{savefig.suffix}",
                format=Path(savefig).suffix[1:],
                dpi=OPTIONS.plot.dpi,
                bbox_inches="tight",
            )

        if ax is None:
            plt.show()

        # TODO: Implement plt.close() again here
698
+
699
+
700
def plot_overview(
    data_to_plot: List[str | None] = None,
    colormap: str = OPTIONS.plot.color.colormap,
    ylims: Dict[str, List[float]] = {},
    title: str | None = None,
    cinc: float | None = None,
    pa: float | None = None,
    bands: List[str] | str = "all",
    colorbar: bool = True,
    axarr: Axes | None = None,
    savefig: Path | None = None,
) -> None:
    """Plots an overview over the total data for baselines [Mlambda].

    Only the data are plotted (no model values are passed on), once per time
    index in ``range(OPTIONS.data.nt)``.

    Parameters
    ----------
    data_to_plot : list of str, optional
        The data to plot. The default is OPTIONS.fit.data.
    colormap : str, optional
        The colormap.
    ylims : dict of list of float, optional
        The ylimits for the individual keys.
    title : str, optional
        The title. The default is None.
    cinc : float, optional
        Passed to ``transform_coordinates`` and ``plot_data_vs_model``.
        The default is None.
    pa : float, optional
        Passed to ``transform_coordinates``. The default is None.
    bands : list of str or str, optional
        The bands to be plotted. The default is "all".
    colorbar : bool, optional
        If True, adds a wavelength colorbar. The default is True.
    axarr : matplotlib.axes.Axes or array of them, optional
        The axes to plot on. If None a row of axes is created.
    savefig : pathlib.Path, optional
        The save path. The time index is appended to the file stem.
        The default is None.
    """
    data_to_plot = OPTIONS.fit.data if data_to_plot is None else data_to_plot
    nts, wls = range(OPTIONS.data.nt), OPTIONS.fit.wls
    norm = LogNorm(vmin=wls[0].value, vmax=wls[-1].value)

    # NOTE(review): the first of "vis"/"vis2" is stored under the key "vis",
    # while plot_data_vs_model below may be called with "vis2" -- confirm that
    # the axes lookup works if only "vis2" is requested.
    data_types, nplots = [], 0
    for key in data_to_plot:
        if key in ["vis", "vis2"] and "vis" not in data_types:
            data_types.append("vis")
        else:
            data_types.append(key)
        nplots += 1

    for t in nts:
        if axarr is None:
            figsize = (15, 5) if nplots == 3 else ((12, 5) if nplots == 2 else None)
            _, axarr = plt.subplots(
                1,
                nplots,
                figsize=figsize,
                tight_layout=True,
                facecolor=OPTIONS.plot.color.background,
            )

        # NOTE(review): axarr is rebound to a dict here, so from the second
        # time index on it is neither None nor an array -- verify the
        # behaviour for OPTIONS.data.nt > 1.
        axarr = axarr.flatten() if isinstance(axarr, np.ndarray) else [axarr]
        axarr = dict(zip(data_types, axarr))

        flux, t3 = OPTIONS.data.flux, OPTIONS.data.t3
        vis = OPTIONS.data.vis if "vis" in OPTIONS.fit.data else OPTIONS.data.vis2

        errorbar_params = OPTIONS.plot.errorbar
        if OPTIONS.plot.color.background == "black":
            errorbar_params.markeredgecolor = "white"

        plot_kwargs = {"norm": norm, "colormap": colormap}
        if "flux" in data_to_plot:
            plot_data_vs_model(
                axarr,
                wls,
                flux.val[t],
                flux.err[t],
                "flux",
                ylims=ylims,
                bands=bands,
                cinc=cinc,
                **plot_kwargs,
            )

        if "vis" in data_to_plot or "vis2" in data_to_plot:
            baselines = np.hypot(*transform_coordinates(vis.u[t], vis.v[t], cinc, pa))
            # NOTE(review): the first baseline column is dropped -- presumably
            # the zero baseline; confirm against the data layout.
            plot_data_vs_model(
                axarr,
                wls,
                vis.val[t],
                vis.err[t],
                "vis" if "vis" in data_to_plot else "vis2",
                ylims=ylims,
                bands=bands,
                baselines=baselines[:, 1:],
                cinc=cinc,
                **plot_kwargs,
            )

        if "t3" in data_to_plot:
            baselines = np.hypot(*transform_coordinates(t3.u[t], t3.v[t], cinc, pa))
            # The longest baseline of each triangle.
            baselines = baselines[t3.i123[t]].T.max(1).reshape(1, -1)
            plot_data_vs_model(
                axarr,
                wls,
                t3.val[t],
                t3.err[t],
                "t3",
                ylims=ylims,
                bands=bands,
                baselines=baselines[:, 1:],
                cinc=cinc,
                **plot_kwargs,
            )

        if colorbar:
            # The colorbar is attached to the axes of the last observable.
            sm = cm.ScalarMappable(cmap=colormap, norm=norm)
            sm.set_array([])
            cbar = plt.colorbar(sm, ax=axarr[data_types[-1]])

            # TODO: Set the ticks, but make it so that it is flexible for the band
            cbar.set_ticks(OPTIONS.plot.ticks)
            cbar.set_ticklabels(
                [f"{wavelength:.1f}" for wavelength in OPTIONS.plot.ticks]
            )

            if OPTIONS.plot.color.background == "black":
                cbar.ax.yaxis.set_tick_params(color="white")
                plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="white")
                for spine in cbar.ax.spines.values():
                    spine.set_edgecolor("white")
            opposite_color = (
                "white" if OPTIONS.plot.color.background == "black" else "black"
            )
            cbar.set_label(label=r"$\lambda$ ($\mathrm{\mu}$m)", color=opposite_color)

        if title is not None:
            plt.title(title)

        if savefig is not None:
            plt.savefig(
                savefig.parent / f"{savefig.stem}_t{t}{savefig.suffix}",
                format=Path(savefig).suffix[1:],
                dpi=OPTIONS.plot.dpi,
            )

        # NOTE(review): axarr has been rebound to a dict above and is never
        # None here, so without a save path the function returns during the
        # first time index and plt.show() is not reached.
        if savefig is None:
            if axarr is not None:
                return

            plt.show()
        plt.close()
836
+
837
+
838
def plot_sed(
    wavelength_range: u.um,
    components: List[FourierComponent | None] = None,
    scaling: str = "nu",
    no_model: bool = False,
    ax: plt.Axes | None = None,
    savefig: Path | None = None,
):
    """Plots the observed fluxes of all readouts and, optionally, the model flux.

    Parameters
    ----------
    wavelength_range : astropy.units.um
        The lower and upper wavelength in between which the model flux
        is computed.
    components : list of FourierComponent, optional
        The model components. Required unless ``no_model`` is True.
        Components named "Point Source" are ignored.
    scaling : str, optional
        The scaling of the SED. If "nu", the fluxes are multiplied by the
        frequency (and given in W m^-2), for any other value they are
        in Jy. The default is "nu".
    no_model : bool, optional
        If True, only the data are plotted. The default is False.
    ax : matplotlib.axes.Axes, optional
        The axes to plot on. If None a new figure is created and
        the axes labels and legend are set. The default is None.
    savefig : pathlib.Path, optional
        The save path. If None nothing is saved and the figure is
        left open. The default is None.
    """
    # Local import: "re" is not among the module-level imports visible here.
    import re

    color = OPTIONS.plot.color
    wavelength = np.linspace(wavelength_range[0], wavelength_range[1], OPTIONS.plot.dim)

    if not no_model:
        components = [comp for comp in components if comp.name != "Point Source"]
        flux = np.sum([comp.compute_flux(0, wavelength) for comp in components], axis=0)
        if flux.size > 0:
            flux = np.tile(flux, (len(OPTIONS.data.readouts))).real

    if ax is None:
        fig = plt.figure(facecolor=color.background, tight_layout=True)
        ax = plt.axes(facecolor=color.background)
        set_axes_color(ax, color.background)
    else:
        fig = None

    # Several readouts are identified (and coloured) by the date in their
    # file name, a single readout by the file name itself.
    if len(OPTIONS.data.readouts) > 1:
        names = [
            re.findall(r"(\d{4}-\d{2}-\d{2})", readout.fits_file.name)[0]
            for readout in OPTIONS.data.readouts
        ]
    else:
        names = [OPTIONS.data.readouts[0].fits_file.name]

    cmap = plt.get_cmap(color.colormap)
    norm = mcolors.LogNorm(vmin=1, vmax=len(set(names)))
    colors = [cmap(norm(i)) for i in range(1, len(set(names)) + 1)]
    date_to_color = dict(zip(set(names), colors))
    sorted_readouts = np.array(OPTIONS.data.readouts.copy())[np.argsort(names)].tolist()

    # Collects everything that is plotted to derive the upper y-limit.
    values = []
    for name, readout in zip(np.sort(names), sorted_readouts):
        if readout.flux.val.size == 0:
            continue

        readout_wl = readout.wl.value
        readout_flux, readout_err = (
            readout.flux.val.flatten(),
            readout.flux.err.flatten(),
        )
        # The relative error is kept so it can be re-applied after scaling.
        readout_err_percentage = readout_err / readout_flux

        if scaling == "nu":
            readout_flux = (readout_flux * u.Jy).to(u.W / u.m**2 / u.Hz)
            readout_flux = (
                readout_flux * (const.c / ((readout_wl * u.um).to(u.m))).to(u.Hz)
            ).value

        readout_err = readout_err_percentage * readout_flux
        lower_err, upper_err = readout_flux - readout_err, readout_flux + readout_err
        if "HAW" in readout.fits_file.name:
            # Only two wavelength windows are plotted for these files.
            indices_high = np.where((readout_wl >= 4.55) & (readout_wl <= 4.9))
            indices_low = np.where((readout_wl >= 3.1) & (readout_wl <= 3.9))
            for indices in [indices_high, indices_low]:
                line = ax.plot(
                    readout_wl[indices],
                    readout_flux[indices],
                    color=date_to_color[name],
                )
                ax.fill_between(
                    readout_wl[indices],
                    lower_err[indices],
                    upper_err[indices],
                    color=line[0].get_color(),
                    alpha=0.5,
                )
            value_indices = np.hstack([indices_high, indices_low])
            lim_values = readout_flux[value_indices].flatten()
        else:
            line = ax.plot(readout_wl, readout_flux, color=date_to_color[name])
            ax.fill_between(
                readout_wl,
                lower_err,
                upper_err,
                color=line[0].get_color(),
                alpha=0.5,
            )
            lim_values = readout_flux
        values.append(lim_values)

    # The data are scaled independently of the model, so is the label.
    flux_label = r"$F_{\nu}$ (Jy)"
    if scaling == "nu":
        flux_label = r"$\nu F_{\nu}$ (W m$^{-2}$)"

    if not no_model:
        flux = flux[:, 0]
        if scaling == "nu":
            flux = (flux * u.Jy).to(u.W / u.m**2 / u.Hz)
            flux = (flux * (const.c / (wavelength.to(u.m))).to(u.Hz)).value

        ax.plot(wavelength, flux, label="Model", color="red")
        values.append(flux)

    if fig is not None:
        ax.set_xlabel(r"$\lambda$ ($\mathrm{\mu}$m)")
        ax.set_ylabel(flux_label)
        ax.legend()

    # Without any plotted values there is nothing to derive a limit from.
    if values:
        max_value = np.concatenate(values).max()
        ax.set_ylim([0, max_value + 0.2 * max_value])

    if savefig is not None:
        plt.savefig(savefig, format=Path(savefig).suffix[1:], dpi=OPTIONS.plot.dpi)
        plt.close()
962
+
963
+
964
def plot_product(
    points,
    product,
    xlabel,
    ylabel,
    save_path=None,
    ax=None,
    colorbar=False,
    cmap: str = OPTIONS.plot.color.colormap,
    scale=None,
    label=None,
):
    """Plots a model product (one or several curves) against the given points.

    Parameters
    ----------
    points : array_like
        The x-values.
    product : numpy.ndarray
        The y-values. If it has more than one dimension, each row is
        plotted as its own curve and paired with one entry of ``label``.
    xlabel : str
        The x-axis label.
    ylabel : str
        The y-axis label.
    save_path : pathlib.Path, optional
        The save path. The default is None.
    ax : matplotlib.axes.Axes, optional
        The axes to plot on. If None a new figure is created.
    colorbar : bool, optional
        If True, adds a colorbar instead of a legend. The default is False.
    cmap : str, optional
        The colormap name, resolved with ``get_colormap``.
    scale : str, optional
        "log" for a logarithmic y-axis, "loglog" for logarithmic x- and
        y-axes and "sci" for scientific notation on the y-axis.
        The default is None.
    label : str or array, optional
        The curve label(s). An array is also used to colour the curves.
        NOTE(review): ``.value`` is accessed on its elements, so a plain
        numpy array without units presumably fails -- confirm.
    """
    norm = None
    if label is not None:
        # An array of labels spans the colour range of the curves.
        if isinstance(label, (np.ndarray, u.Quantity)):
            norm = mcolors.Normalize(vmin=label[0].value, vmax=label[-1].value)

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure

    if product.ndim > 1:
        # NOTE(review): label is iterated here, so it must not be None for a
        # multi-dimensional product.
        for lb, prod in zip(label, product):
            color = None
            if norm is not None:
                colormap = get_colormap(cmap)
                color = colormap(norm(lb.value))
            ax.plot(points, prod, label=lb, color=color)
        if not colorbar:
            ax.legend()
    else:
        ax.plot(points, product, label=label)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    if scale == "log":
        ax.set_yscale("log")
    elif scale == "loglog":
        ax.set_yscale("log")
        ax.set_xscale("log")
    elif scale == "sci":
        ax.yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True))
        ax.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))

    if colorbar:
        # NOTE(review): norm is still None here unless an array label was given.
        sm = cm.ScalarMappable(cmap=get_colormap(cmap), norm=norm)
        sm.set_array([])
        cbar = plt.colorbar(sm, ax=ax)
        cbar.set_ticks(OPTIONS.plot.ticks)
        cbar.set_ticklabels([f"{wavelength:.1f}" for wavelength in OPTIONS.plot.ticks])
        cbar.set_label(label=r"$\lambda$ ($\mathrm{\mu}$m)")

    if save_path is not None:
        fig.savefig(save_path, format=Path(save_path).suffix[1:], dpi=OPTIONS.plot.dpi)
        plt.close(fig)
1021
+
1022
+
1023
+ # TODO: Clean and split this function into multiple ones
1024
def _interleave(segments, gaps):
    """Alternates ``segments`` and ``gaps`` element-wise, ending on a segment.

    Expects ``gaps`` to hold one element fewer than ``segments``; the padding
    that ``zip_longest`` appends for the missing last gap is dropped.
    """
    return list(chain.from_iterable(zip_longest(segments, gaps)))[:-1]


def plot_products(
    dim: int,
    components: List[FourierComponent],
    component_labels: List[str],
    save_dir: Path | None = None,
) -> None:
    """Plots the intermediate products of the model (temperature, density, etc.).

    For every index ``t`` in ``range(OPTIONS.data.nt)`` the following files
    are written to ``save_dir``:

    - ``fluxes_t{t}.png``: the flux of every component and their sum.
    - ``flux_ratios_t{t}.png``: the share of every component in the summed
      flux (in percent).
    - ``temperature_t{t}.png``, ``surface_density_t{t}.png`` and
      ``optical_depths_t{t}.png``: radial profiles over the merged grids of
      all components that are not named "Point", "Gauss" or "BBGauss". The
      gaps between two consecutive grids are filled with zeros. These three
      files are skipped if there is no such component.

    Parameters
    ----------
    dim : int
        The number of wavelength samples and of grid points per gap. It is
        also assigned to ``component.dim.value`` of every component, i.e.,
        the components are modified in place.
    components : list of FourierComponent
        The components of the model.
    component_labels : list of str
        One label per component. Underscores are replaced by spaces and the
        words are title-cased for the legends.
    save_dir : pathlib.Path, optional
        The directory the plots are written to. Defaults to the current
        working directory. The directory is not created.
    """
    save_dir = Path.cwd() if save_dir is None else Path(save_dir)
    component_labels = [
        " ".join(map(str.title, label.split("_"))) for label in component_labels
    ]

    # TODO: Code this in a better manner
    colorbar_wls = [1.7, 2.15, 3.4, 8, 11.3, 13] * u.um

    for t in range(OPTIONS.data.nt):
        wls = np.linspace(OPTIONS.fit.wls[0], OPTIONS.fit.wls[-1], dim)
        radii, surface_density, optical_depth, fluxes = [], [], [], []

        fig, ax = plt.subplots(figsize=(5, 5))
        for label, component in zip(component_labels, components):
            component.dim.value = dim
            flux = component.fr(t, wls) * component.compute_flux(t, wls).squeeze()
            plot_product(
                wls,
                flux,
                r"$\lambda$ ($\mathrm{\mu}$m)",
                r"$F_{\nu}$ (Jy)",
                scale="log",
                ax=ax,
                label=label,
            )
            fluxes.append(flux)

            # Only the remaining components contribute radial profiles.
            if component.name in ["Point", "Gauss", "BBGauss"]:
                continue

            radius = component.compute_internal_grid(t, wls)
            radii.append(radius)
            surface_density.append(component.compute_surface_density(radius, t, wls))
            optical_depth.append(
                component.compute_optical_depth(radius, t, wls[:, np.newaxis])
            )

        total_flux = np.sum(fluxes, axis=0)
        ax.plot(wls, total_flux, label="Total")
        ax.set_yscale("log")
        ax.set_ylim([1e-1, None])
        ax.legend()
        fig.savefig(save_dir / f"fluxes_t{t}.png", format="png", dpi=OPTIONS.plot.dpi)
        plt.close(fig)

        fig, ax = plt.subplots(figsize=(5, 5))
        for label, flux_ratio in zip(component_labels, np.array(fluxes) / total_flux):
            plot_product(
                wls,
                flux_ratio * 100,
                r"$\lambda$ ($\mathrm{\mu}$m)",
                r"$F_{\nu}$ / $F_{\nu,\,\mathrm{tot}}$ (%)",
                ax=ax,
                label=label,
            )

        ax.legend()
        ax.set_ylim([0, 100])
        fig.savefig(
            save_dir / f"flux_ratios_t{t}.png", format="png", dpi=OPTIONS.plot.dpi
        )
        plt.close(fig)

        # Without any radial grid there is nothing to merge or to plot.
        if not radii:
            continue

        # Bridge the gap between the outer edge of one grid and the inner
        # edge of the next, so that the merged radial axis has no holes.
        radii_bounds = [
            (prev[-1], current[0]) for prev, current in zip(radii[:-1], radii[1:])
        ]
        fill_radii = [np.linspace(lower, upper, dim) for lower, upper in radii_bounds]
        merged_radii = u.Quantity(
            np.concatenate(_interleave(radii, fill_radii), axis=0)
        )
        fill_zeros = np.zeros((len(fill_radii), wls.size, dim))

        # NOTE(review): The first loop also skips "BBGauss", this filter does
        # not -- confirm that "BBGauss" may serve as the temperature source.
        disc_component = next(
            comp for comp in components if comp.name not in ["Point", "Gauss"]
        )

        # TODO: Make it so that the temperatures are somehow continuous in the
        # plot? (Maybe check for self.temps in the models?) or interpolate
        # smoothly somehow
        temperature = disc_component.compute_temperature(merged_radii, t, wls)

        surface_density = u.Quantity(
            _interleave(
                u.Quantity(surface_density), fill_zeros[:, 0, :] * u.g / u.cm**2
            )
        )
        surface_density = np.concatenate(surface_density, axis=0)

        # Stacks the per-grid blocks along their second axis.
        optical_depth = u.Quantity(
            _interleave(u.Quantity(optical_depth), fill_zeros)
        )
        optical_depth = np.hstack(optical_depth)

        # TODO: Add the plots of the cumulative flux ratio, the emissivity
        # and the intensity.

        plot_product(
            merged_radii.value,
            temperature.value,
            "$R$ (AU)",
            "$T$ (K)",
            scale="log",
            save_path=save_dir / f"temperature_t{t}.png",
        )
        plot_product(
            merged_radii.value,
            surface_density.value,
            "$R$ (au)",
            r"$\Sigma$ (g cm$^{-2}$)",
            save_path=save_dir / f"surface_density_t{t}.png",
            scale="sci",
        )
        # NOTE(review): The optical depths were computed on the ``wls`` grid,
        # whereas the labels below are six fixed wavelengths that
        # ``plot_product`` pairs positionally with the leading entries of
        # ``optical_depth`` -- presumably the curves at these wavelengths
        # were intended; confirm.
        plot_product(
            merged_radii.value,
            optical_depth.value,
            "$R$ (AU)",
            r"$\tau_{\nu}$",
            save_path=save_dir / f"optical_depths_t{t}.png",
            scale="log",
            colorbar=True,
            label=colorbar_wls,
        )