ExoIris 0.18.0__py3-none-any.whl → 0.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
exoiris/__init__.py CHANGED
@@ -15,7 +15,7 @@
  # along with this program. If not, see <https://www.gnu.org/licenses/>.
 
  from .exoiris import ExoIris, load_model # noqa
- from .tsdata import TSData, TSDataSet # noqa
+ from .tsdata import TSData, TSDataGroup # noqa
  from .binning import Binning # noqa
  from .ldtkld import LDTkLD # noqa
  from .tslpf import clean_knots # noqa
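
Note: TSDataSet has been renamed to TSDataGroup throughout the package, so code that imports it from exoiris needs the new name. A minimal migration sketch:

    # ExoIris 0.18.0:
    # from exoiris import TSData, TSDataSet
    # ExoIris 0.20.0:
    from exoiris import TSData, TSDataGroup
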
exoiris/exoiris.py CHANGED
@@ -24,27 +24,24 @@ from typing import Optional, Callable, Any, Literal
 
  import astropy.io.fits as pf
  import astropy.units as u
- import emcee
+ import matplotlib.axes
  import pandas as pd
- import pytransit.utils.de
- import seaborn as sb
  from astropy.table import Table
  from celerite2 import GaussianProcess, terms
  from emcee import EnsembleSampler
  from matplotlib.pyplot import subplots, setp, figure, Figure, Axes
- from numpy import (where, sqrt, clip, percentile, median, squeeze, floor, ndarray,
-                    array, inf, newaxis, arange, tile, sort, argsort, concatenate, full, nan, r_, nanpercentile, log10,
-                    ceil)
- from numpy.random import normal, permutation
+ from numpy import (any, where, sqrt, clip, percentile, median, squeeze, floor, ndarray, isfinite,
+                    array, inf, arange, argsort, concatenate, full, nan, r_, nanpercentile, log10,
+                    ceil, unique)
+ from numpy.random import normal
  from pytransit import UniformPrior, NormalPrior
- from pytransit.orbits import epoch
  from pytransit.param import ParameterSet
  from pytransit.utils.de import DiffEvol
  from scipy.stats import norm
  from uncertainties import UFloat
 
  from .ldtkld import LDTkLD
- from .tsdata import TSData, TSDataSet
+ from .tsdata import TSData, TSDataGroup
  from .tslpf import TSLPF
  from .wlpf import WhiteLPF
 
@@ -73,32 +70,69 @@ def load_model(fname: Path | str, name: str | None = None):
          If the file format is invalid or does not match the expected format.
      """
      with pf.open(fname) as hdul:
-         data = TSDataSet.import_fits(hdul)
+         data = TSDataGroup.import_fits(hdul)
+         hdr = hdul[0].header
 
-         if hdul[0].header['LDMODEL'] == 'ldtk':
-             filters, teff, logg, metal, dataset = pickle.loads(codecs.decode(json.loads(hdul[0].header['LDTKLD']).encode(), "base64"))
+         # Read the limb darkening model.
+         # ==============================
+         if hdr['LDMODEL'] == 'ldtk':
+             filters, teff, logg, metal, dataset = pickle.loads(codecs.decode(json.loads(hdr['LDTKLD']).encode(), "base64"))
              ldm = LDTkLD(filters, teff, logg, metal, dataset=dataset)
          else:
-             ldm = hdul[0].header['LDMODEL']
+             ldm = hdr['LDMODEL']
 
+         # Read the interpolation model.
+         # =============================
          try:
-             ip = hdul[0].header['INTERP']
+             ip = hdr['INTERP']
          except KeyError:
              ip = 'bspline'
 
-         #TODO: save and load the noise model information
-         a = ExoIris(name or hdul[0].header['NAME'], ldmodel=ldm, data=data, interpolation=ip)
+         # Read the noise model.
+         # =====================
+         try:
+             noise_model = hdr['NOISE']
+         except KeyError:
+             noise_model = "white"
+
+         # Setup the analysis.
+         # ===================
+         a = ExoIris(name or hdr['NAME'], ldmodel=ldm, data=data, noise_model=noise_model, interpolation=ip)
          a.set_radius_ratio_knots(hdul['K_KNOTS'].data.astype('d'))
          a.set_limb_darkening_knots(hdul['LD_KNOTS'].data.astype('d'))
 
+         # Read the white light curve models if they exist.
+         # ================================================
          try:
-             a.period = hdul[0].header['P']
-             a.zero_epoch = hdul[0].header['T0']
-             a.transit_duration = hdul[0].header['T14']
-             [d.mask_transit(a.zero_epoch, a.period, a.transit_duration) for d in a.data]
+             tb = Table.read(hdul['WHITE_DATA'])
+             white_ids = tb['id'].data
+             uids = unique(white_ids)
+             a._white_times = [tb['time'].data[white_ids == i] for i in uids]
+             a._white_fluxes = [tb['flux_obs'].data[white_ids == i] for i in uids]
+             a._white_errors = [tb['flux_obs_err'].data[white_ids == i] for i in uids]
+             a._white_models = [tb['flux_mod'].data[white_ids == i] for i in uids]
          except KeyError:
              pass
 
+         # Read the ephemeris if it exists.
+         # ================================
+         try:
+             a.period = hdr['P']
+             a.zero_epoch = hdr['T0']
+             a.transit_duration = hdr['T14']
+             [d.mask_transit(a.zero_epoch, a.period, a.transit_duration) for d in a.data]
+         except (KeyError, ValueError):
+             pass
+
+         # Read the spots if they exist.
+         # =============================
+         if 'SPOTS' in hdr and hdr['SPOTS'] is True:
+             a.initialize_spots(hdr["SP_TSTAR"], hdr["SP_REFWL"], hdr["SP_TLSE"])
+             for i in range(hdr['NSPOTS']):
+                 a.add_spot(hdr[f'SP{i+1:02d}_EG'])
+
+         # Read the priors.
+         # ================
          priors = pickle.loads(codecs.decode(json.loads(hdul['PRIORS'].header['PRIORS']).encode(), "base64"))
          a._tsa.ps = ParameterSet([pickle.loads(p) for p in priors])
          a._tsa.ps.freeze()
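
Note: load_model now restores the noise model (NOISE card), the white light curve data (WHITE_DATA table), the ephemeris (P, T0, T14 cards), and any star spots (SPOTS and SP_* cards) written at save time, instead of only the ephemeris. A short usage sketch; the file name is illustrative:

    from exoiris import load_model

    a = load_model('wasp39b_prism.fits')
    print(a.period, a.zero_epoch, a.transit_duration)
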
@@ -116,9 +150,9 @@ class ExoIris:
      """The core ExoIris class providing tools for exoplanet transit spectroscopy.
      """
 
-     def __init__(self, name: str, ldmodel, data: TSDataSet | TSData, nk: int = 50, nldc: int = 10, nthreads: int = 1,
-                  tmpars: dict | None = None, noise_model: str = 'white',
-                  interpolation: Literal['bspline', 'pchip', 'makima'] = 'bspline'):
+     def __init__(self, name: str, ldmodel, data: TSDataGroup | TSData, nk: int = 50, nldc: int = 10, nthreads: int = 1,
+                  tmpars: dict | None = None, noise_model: Literal["white", "fixed_gp", "free_gp"] = 'white',
+                  interpolation: Literal['bspline', 'pchip', 'makima', 'nearest', 'linear'] = 'makima'):
          """
          Parameters
          ----------
@@ -139,10 +173,31 @@ class ExoIris:
          noise_model
              The noise model to use. Should be either "white" for white noise or "fixed_gp" for Gaussian Process.
          """
-         data = TSDataSet([data]) if isinstance(data, TSData) else data
-         self._tsa: TSLPF = TSLPF(name, ldmodel, data, nk=nk, nldc=nldc, nthreads=nthreads, tmpars=tmpars,
+         data = TSDataGroup([data]) if isinstance(data, TSData) else data
+
+         for d in data:
+             if any(~isfinite(d.fluxes[d.mask])):
+                 raise ValueError(f"The {d.name} data set flux array contains unmasked noninfinite values.")
+
+             if any(~isfinite(d.errors[d.mask])):
+                 raise ValueError(f"The {d.name} data set error array contains unmasked noninfinite values.")
+
+         ngs = array(data.noise_groups)
+         if not ((ngs.min() == 0) and (ngs.max() + 1 == unique(ngs).size)):
+             raise ValueError("The noise groups must start from 0 and be consecutive.")
+
+         ogs = array(data.offset_groups)
+         if not ((ogs.min() == 0) and (ogs.max() + 1 == unique(ogs).size)):
+             raise ValueError("The offset groups must start from 0 and be consecutive.")
+
+         egs = array(data.epoch_groups)
+         if not ((egs.min() == 0) and (egs.max() + 1 == unique(egs).size)):
+             raise ValueError("The epoch groups must start from 0 and be consecutive.")
+
+         self._tsa = TSLPF(self, name, ldmodel, data, nk=nk, nldc=nldc, nthreads=nthreads, tmpars=tmpars,
                                  noise_model=noise_model, interpolation=interpolation)
-         self._wa: WhiteLPF | None = None
+         self._wa: None | WhiteLPF = None
+
          self.nthreads: int = nthreads
 
          self.period: float | None = None
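
Note: the constructor now validates its input before building the underlying TSLPF: fluxes and errors must be finite wherever unmasked, and the noise, offset, and epoch groups must be zero-based and consecutive. The noise_model and interpolation arguments are typed literals, and the default interpolation changes from 'bspline' to 'makima'. A minimal construction sketch against the 0.20.0 signature, assuming tsd is an existing TSData or TSDataGroup; the analysis name and limb darkening model string are illustrative:

    from exoiris import ExoIris

    a = ExoIris('example_analysis', ldmodel='power-2', data=tsd,
                nk=50, nldc=10, noise_model='fixed_gp', interpolation='makima')
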
@@ -150,6 +205,11 @@ class ExoIris:
          self.transit_duration: float | None= None
          self._tref = floor(self.data.tmin)
 
+         self._white_times: None | list[ndarray] = None
+         self._white_fluxes: None | list[ndarray] = None
+         self._white_errors: None | list[ndarray] = None
+         self._white_models: None | list[ndarray] = None
+
      def lnposterior(self, pvp: ndarray) -> ndarray:
          """Calculate the log posterior probability for a single parameter vector or an array of parameter vectors.
 
@@ -180,7 +240,7 @@ class ExoIris:
          """
          self._tsa.set_noise_model(noise_model)
 
-     def set_data(self, data: TSData | TSDataSet) -> None:
+     def set_data(self, data: TSData | TSDataGroup) -> None:
          """Set the model data.
 
          Parameters
@@ -188,7 +248,7 @@ class ExoIris:
          data
              The spectroscopic transit light curve.
          """
-         data = TSDataSet([data]) if isinstance(data, TSData) else data
+         data = TSDataGroup([data]) if isinstance(data, TSData) else data
          self._tsa.set_data(data)
 
      def set_prior(self, parameter: Literal['radius ratios', 'baselines', 'wn multipliers'] | str,
@@ -279,7 +339,7 @@ class ExoIris:
          metal = (metal.n, metal.s) if isinstance(metal, UFloat) else metal
          self._tsa.set_ldtk_prior(teff, logg, metal, dataset, width, uncertainty_multiplier)
 
-     def set_gp_hyperparameters(self, sigma: float, rho: float) -> None:
+     def set_gp_hyperparameters(self, sigma: float, rho: float, idata: None | int = None) -> None:
          """Set Gaussian Process (GP) hyperparameters assuming a Matern-3/2 kernel.
 
          Parameters
@@ -288,8 +348,10 @@ class ExoIris:
              The kernel amplitude parameter.
          rho
              The length scale parameter.
+         idata
+             The data set for which to set the hyperparameters. If None, the hyperparameters are set for all data sets.
          """
-         self._tsa.set_gp_hyperparameters(sigma, rho)
+         self._tsa.set_gp_hyperparameters(sigma, rho, idata)
 
      def set_gp_kernel(self, kernel: terms.Term) -> None:
          """Set the Gaussian Process (GP) kernel.
@@ -301,6 +363,36 @@ class ExoIris:
          """
          self._tsa.set_gp_kernel(kernel)
 
+     def initialize_spots(self, tstar: float, wlref: float, include_tlse: bool = True):
+         """Initialize star spot model using given stellar and wavelength reference values.
+
+         Parameters
+         ----------
+         tstar
+             Effective stellar temperature [K].
+         wlref
+             Reference wavelength where spot amplitude matches the amplitude parameter.
+         """
+         self._tsa.initialize_spots(tstar, wlref, include_tlse)
+
+     def add_spot(self, epoch_group: int) -> None:
+         """Add a new star spot and associate it with an epoch group.
+
+         Parameters
+         ----------
+         epoch_group
+             Identifier for the epoch group to which the spot will be added.
+         """
+         self._tsa.add_spot(epoch_group)
+
+     @property
+     def nspots(self) -> int:
+         """Number of star spots."""
+         if self._tsa.spot_model is None:
+             return 0
+         else:
+             return self._tsa.spot_model.nspots
+
      @property
      def name(self) -> str:
          """Analysis name."""
@@ -311,7 +403,7 @@ class ExoIris:
          self._tsa.name = name
 
      @property
-     def data(self) -> TSDataSet:
+     def data(self) -> TSDataGroup:
          """Analysis data set."""
          return self._tsa.data
 
@@ -378,22 +470,35 @@ class ExoIris:
      @property
      def white_times(self) -> list[ndarray]:
          """White light curve time arrays."""
-         return self._wa.times
+         if self._wa is None:
+             return self._white_times
+         else:
+             return self._wa.times
 
      @property
      def white_fluxes(self) -> list[ndarray]:
          """White light curve flux arrays."""
-         return self._wa.fluxes
+         if self._wa is None:
+             return self._white_fluxes
+         else:
+             return self._wa.fluxes
 
      @property
      def white_models(self) -> list[ndarray]:
-         fm = self._wa.flux_model(self._wa._local_minimization.x)
-         return [fm[sl] for sl in self._wa.lcslices]
+         """Fitted white light curve flux model arrays."""
+         if self._wa is None:
+             return self._white_models
+         else:
+             fm = self._wa.flux_model(self._wa._local_minimization.x)
+             return [fm[sl] for sl in self._wa.lcslices]
 
      @property
      def white_errors(self) -> list[ndarray]:
          """White light curve flux error arrays."""
-         return self._wa.std_errors
+         if self._wa is None:
+             return self._white_errors
+         else:
+             return self._wa.std_errors
 
      def add_radius_ratio_knots(self, knot_wavelengths: Sequence) -> None:
          """Add radius ratio (k) knots.
@@ -464,65 +569,36 @@ class ExoIris:
          """Print the model parameterization."""
          self._tsa.print_parameters(1)
 
-     def plot_setup(self, figsize: tuple[float, float] | None =None, xscale: str | None = None, xticks: Sequence | None = None) -> Figure:
+     def plot_setup(self, figsize: tuple[float, float] | None = None,
+                    ax: matplotlib.axes.Axes | None = None,
+                    xscale: str | None = None, xticks: Sequence | None = None,
+                    yshift: float = 0.1, mh:float = 0.08, side_margin: float = 0.05,
+                    lw: float = 0.5, c='k') -> Figure:
          """Plot the model setup with limb darkening knots, radius ratio knots, and data binning.
-
-         Parameters
-         ----------
-         figsize
-             The size of the figure in inches.
-         xscale
-             The scale of the x-axis. If provided, the x-axis scale of all the subplots will be set to this value.
-         xticks
-             The list of x-axis tick values for all the subplots. If provided, the x-axis ticks of all the subplots will
-             be set to these values.
-
-         Returns
-         -------
-         Figure
-             The matplotlib Figure object that contains the created subplots.
-
          """
-         using_ldtk = isinstance(self._tsa.ldmodel, LDTkLD)
+         if ax is None:
+             fig, ax = subplots(figsize=figsize, constrained_layout=True)
+         else:
+             fig = ax.figure
 
-         if not using_ldtk:
-             figsize = figsize or (13, 4)
-             fig, axs = subplots(3, 1, figsize=figsize, sharex='all', sharey='all')
-             axl, axk, axw = axs
+         ndata = self.data.size
 
-             axl.vlines(self._tsa.ld_knots, 0.1, 0.5, ec='k')
-             axl.text(0.01, 0.90, 'Limb darkening knots', va='top', transform=axl.transAxes)
-         else:
-             figsize = figsize or (13, 2*4/3)
-             fig, axs = subplots(2, 1, figsize=figsize, sharex='all', sharey='all')
-             axk, axw = axs
-             axl = None
-
-         axk.vlines(self._tsa.k_knots, 0.1, 0.5, ec='k')
-         axk.text(0.01, 0.90, 'Radius ratio knots', va='top', transform=axk.transAxes)
-         for ds in self.data:
-             axw.vlines(ds.wavelength, 0.1, 0.5, ec='k')
-         axw.text(0.01, 0.90, 'Wavelength bins', va='top', transform=axw.transAxes)
-
-         if not using_ldtk:
-             sb.despine(ax=axl, top=False, bottom=True, right=False)
-             sb.despine(ax=axk, top=True, bottom=True, right=False)
-         else:
-             sb.despine(ax=axk, top=False, bottom=True, right=False)
+         for i, d in enumerate(self.data):
+             ax.vlines(d.wavelength, ymin=i*yshift, ymax=i*yshift+mh, colors=c, lw=lw)
 
-         sb.despine(ax=axw, top=True, bottom=False, right=False)
-         setp(axs, xlim=(self.data.wlmin-0.02, self.data.wlmax+0.02), yticks=[], ylim=(0, 0.9))
-         setp(axw, xlabel=r'Wavelength [$\mu$m]')
-         setp(axs[0].get_xticklines(), visible=False)
-         setp(axs[0].get_xticklabels(), visible=False)
-         setp(axs[1].get_xticklines(), visible=False)
-         setp(axs[-1].get_xticklines(), visible=True)
+         i = ndata + 1
+         ax.vlines(self._tsa.ld_knots, ymin=i*yshift, ymax=i*yshift+mh, colors=c, lw=lw)
+
+         i = ndata + 3
+         ax.vlines(self.k_knots, ymin=i*yshift, ymax=i*yshift+mh, colors=c, lw=lw)
 
          if xscale:
-             setp(axs, xscale=xscale)
+             setp(ax, xscale=xscale)
          if xticks is not None:
-             [ax.set_xticks(xticks, labels=xticks) for ax in axs]
-         fig.tight_layout()
+             ax.set_xticks(xticks, labels=xticks)
+
+         setp(ax, yticks=[], xlim=(self.data.wlmin-side_margin, self.data.wlmax+side_margin), xlabel=r'Wavelength [$\mu$m]')
+         ax.set_yticks(concatenate([arange(ndata), arange(ndata+1, ndata+4, 2)])*yshift+0.5*mh, labels=[n.replace("_", " ") for n in self.data.names] + ["Limb darkening knots", "Radius ratio knots"])
          return fig
 
      def fit_white(self, niter: int = 500) -> None:
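
Note: plot_setup now draws everything into a single axis (one row per data set plus rows for the limb darkening and radius ratio knots) instead of per-model subplots, and it accepts an existing Axes. A small sketch, with the figure size and ticks illustrative:

    from matplotlib.pyplot import subplots

    fig, ax = subplots(figsize=(13, 4))
    a.plot_setup(ax=ax, xscale='log', xticks=[1, 2, 3, 4, 5])
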
@@ -537,10 +613,10 @@ class ExoIris:
          self._wa.optimize_global(niter, plot_convergence=False, use_tqdm=False)
          self._wa.optimize()
          pv = self._wa._local_minimization.x
-         self.period = pv[1]
+         self.period = pv[0]
          self.zero_epoch = self._wa.transit_center
          self.transit_duration = self._wa.transit_duration
-         self.data.mask_transit(pv[0], pv[1], self.transit_duration)
+         self.data.mask_transit(self.zero_epoch, self.period, self.transit_duration)
 
      def plot_white(self, axs=None, figsize: tuple[float, float] | None = None, ncols: int | None=None) -> Figure:
          """Plot the white light curve data with the best-fit model.
@@ -556,7 +632,25 @@ class ExoIris:
          """
          return self._wa.plot(axs=axs, figsize=figsize, ncols=ncols or min(self.data.size, 2))
 
-     def plot_white_gp_predictions(self, axs = None, ncol: int = 1, figsize: tuple[float, float] | None = None):
+     def plot_white_gp_predictions(self, axs = None, ncol: int = 1, figsize: tuple[float, float] | None = None) -> None:
+         """Plot the predictions of a Gaussian Process model for white light curves and residuals.
+
+         Parameters
+         ----------
+         axs
+             Axes array in which the plots are drawn. If None, new subplots are generated.
+         ncol
+             The number of columns for the created grid of subplots. Used only if axs is None.
+         figsize
+             Size of the figure in inches (width, height). Used only if axs is None. If None,
+             a default size is used.
+
+         Notes
+         -----
+         The number of rows for the subplots is determined dynamically based on the shape of
+         the data and the number of columns specified (ncol). If the provided axes array (axs)
+         does not accommodate all the subplots, the behavior is undefined.
+         """
          ndata = self.data.size
 
          if axs is None:
@@ -657,12 +751,18 @@ class ExoIris:
 
          pv0 = self._wa._local_minimization.x
          x0 = self._tsa.ps.sample_from_prior(npop)
-         x0[:, 0] = normal(pv0[2], 0.05, size=npop)
-         x0[:, 1] = normal(pv0[0], 1e-4, size=npop)
-         x0[:, 2] = normal(pv0[1], 1e-5, size=npop)
-         x0[:, 3] = clip(normal(pv0[3], 0.01, size=npop), 0.0, 1.0)
+         x0[:, 0] = clip(normal(pv0[1], 0.05, size=npop), 0.01, inf)
+         x0[:, 1] = clip(normal(pv0[0], 1e-4, size=npop), 0.01, inf)
+         x0[:, 2] = clip(normal(pv0[2], 1e-3, size=npop), 0.0, 1.0)
+
+         nep = max(self.data.epoch_groups) + 1
+         for i in range(nep):
+             pida = self.ps.find_pid(f'tc_{i:02d}')
+             pidb = self._wa.ps.find_pid(f'tc_{i:02d}')
+             x0[:, pida] = normal(pv0[pidb], 0.001, size=npop)
+
          sl = self._tsa._sl_rratios
-         x0[:, sl] = normal(sqrt(pv0[4]), 0.001, size=(npop, self.nk))
+         x0[:, sl] = normal(sqrt(pv0[self._wa.ps.find_pid('k2')]), 0.001, size=(npop, self.nk))
          for i in range(sl.start, sl.stop):
              x0[:, i] = clip(x0[:, i], 1.001*self.ps[i].prior.a, 0.999*self.ps[i].prior.b)
 
@@ -1043,11 +1143,16 @@ class ExoIris:
          pri.header['t14'] = self.transit_duration
          pri.header['ndgroups'] = self.data.size
          pri.header['interp'] = self._tsa.interpolation
+         pri.header['noise'] = self._tsa.noise_model
 
+         # Priors
+         # ======
          pr = pf.ImageHDU(name='priors')
          priors = [pickle.dumps(p) for p in self.ps]
          pr.header['priors'] = json.dumps(codecs.encode(pickle.dumps(priors), "base64").decode())
 
+         # Limb darkening
+         # ==============
          if isinstance(self._tsa.ldmodel, LDTkLD):
              ldm = self._tsa.ldmodel
              pri.header['ldmodel'] = 'ldtk'
@@ -1056,11 +1161,56 @@ class ExoIris:
          else:
              pri.header['ldmodel'] = self._tsa.ldmodel
 
+         # Knots
+         # =====
          k_knots = pf.ImageHDU(self._tsa.k_knots, name='k_knots')
          ld_knots = pf.ImageHDU(self._tsa.ld_knots, name='ld_knots')
          hdul = pf.HDUList([pri, k_knots, ld_knots, pr])
          hdul += self.data.export_fits()
 
+         # White light curve analysis
+         # ==========================
+         if self._wa is not None and self._wa._local_minimization is not None:
+             wa_data = pf.BinTableHDU(
+                 Table(
+                     [
+                         self._wa.lcids,
+                         self._wa.timea,
+                         concatenate(self.white_models),
+                         self._wa.ofluxa,
+                         concatenate(self._wa.std_errors),
+                     ],
+                     names="id time flux_mod flux_obs flux_obs_err".split(),
+                 ), name='white_data'
+             )
+             hdul.append(wa_data)
+
+             names = []
+             counts = {}
+             for p in self._wa.ps.names:
+                 if p not in counts.keys():
+                     counts[p] = 0
+                     names.append(p)
+                 else:
+                     counts[p] += 1
+                     names.append(f'{p}_{counts[p]}')
+
+             wa_params = pf.BinTableHDU(Table(self._wa._local_minimization.x, names=names), name='white_params')
+             hdul.append(wa_params)
+
+         # Spots
+         # =====
+         if self._tsa.spot_model is not None:
+             pri.header['spots'] = True
+             pri.header["sp_tstar"] = self._tsa.spot_model.tphot
+             pri.header["sp_refwl"] = self._tsa.spot_model.wlref
+             pri.header["sp_tlse"] = self._tsa.spot_model.include_tlse
+             pri.header["nspots"] = self.nspots
+             for i in range(self.nspots):
+                 pri.header[f"sp{i+1:02d}_eg"] = self._tsa.spot_model.spot_epoch_groups[i]
+
+         # Global optimization results
+         # ===========================
          if self._tsa.de is not None:
              de = pf.BinTableHDU(Table(self._tsa._de_population, names=self.ps.names), name='DE')
              de.header['npop'] = self._tsa.de.n_pop
@@ -1068,6 +1218,8 @@ class ExoIris:
              de.header['imin'] = self._tsa.de.minimum_index
              hdul.append(de)
 
+         # MCMC results
+         # ============
          if self._tsa.sampler is not None:
              mc = pf.BinTableHDU(Table(self._tsa.sampler.flatchain, names=self.ps.names), name='MCMC')
              mc.header['npop'] = self._tsa.sampler.nwalkers
@@ -1141,11 +1293,8 @@ class ExoIris:
             log10_rho_bounds: float | tuple[float, float] = (-5, 0),
             log10_sigma_prior=None, log10_rho_prior=None,
             npop: int = 10, niter: int = 100):
-         if self._tsa.noise_model != 'fixed_gp':
-             raise ValueError("The noise model must be set to 'fixed_gp' before the hyperparameter optimization.")
-
-         if self._wa is None:
-             raise ValueError("The white light curves must be fit using 'fit_white()' before the hyperparameter optimization.")
+         if self._tsa.noise_model not in ('fixed_gp', 'free_gp'):
+             raise ValueError("The noise model must be set to 'fixed_gp' or 'free_gp' before the hyperparameter optimization.")
 
          if log10_rho_prior is not None:
              if isinstance(log10_rho_prior, Sequence):
@@ -1172,7 +1321,7 @@ class ExoIris:
 
          match log10_sigma_bounds:
              case None:
-                 sb = [log10_sigma_guess-1, log10_sigma_guess+1]
+                 sb = [log10_sigma_guess - 1, log10_sigma_guess + 1]
              case _ if isinstance(log10_sigma_bounds, Sequence):
                  sb = log10_sigma_bounds
              case _ if isinstance(log10_sigma_bounds, float):
@@ -1180,7 +1329,7 @@ class ExoIris:
                  sb = [log10_sigma_guess - 1, log10_sigma_guess + 1]
 
          match log10_rho_bounds:
              case None:
-                 rb = [-5, -2]
+                 rb = [-5, -2]
              case _ if isinstance(log10_rho_bounds, Sequence):
                  rb = log10_rho_bounds
exoiris/ldtkld.py CHANGED
@@ -23,18 +23,18 @@ from numpy import sqrt, array, concatenate
  from pytransit import LDTkLD as PTLDTkLD
  from ldtk import BoxcarFilter
 
- from .tsdata import TSData, TSDataSet
+ from .tsdata import TSData, TSDataGroup
 
 
  class LDTkLD(PTLDTkLD):
-     def __init__(self, data: TSDataSet | TSData,
+     def __init__(self, data: TSDataGroup | TSData,
                   teff: tuple[float, float],
                   logg: tuple[float, float],
                   metal: tuple[float, float],
                   cache: str | Path | None = None,
                   dataset: str = 'vis-lowres') -> None:
 
-         data = TSDataSet([data]) if isinstance(data, TSData) else data
+         data = TSDataGroup([data]) if isinstance(data, TSData) else data
          wl_edges = concatenate([array([d._wl_l_edges, d._wl_r_edges]) for d in data], axis=1).T
          filters = [BoxcarFilter(f"{0.5*(wla+wlb):08.5f}", wla*1e3, wlb*1e3) for wla, wlb in wl_edges]
          super().__init__(filters, teff, logg, metal, cache, dataset)
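
Note: the LDTkLD helper now accepts a TSDataGroup (or a single TSData) and builds one LDTk BoxcarFilter per wavelength bin. A minimal sketch following the signature above, assuming tsd is an existing TSDataGroup; the stellar (value, uncertainty) pairs are illustrative:

    from exoiris import ExoIris, LDTkLD

    ldm = LDTkLD(tsd, teff=(4800, 100), logg=(4.5, 0.1), metal=(0.0, 0.1), dataset='vis-lowres')
    a = ExoIris('example', ldmodel=ldm, data=tsd)
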