ExoIris 0.18.0__py3-none-any.whl → 0.19.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
exoiris/__init__.py CHANGED
@@ -15,7 +15,7 @@
15
15
  # along with this program. If not, see <https://www.gnu.org/licenses/>.
16
16
 
17
17
  from .exoiris import ExoIris, load_model # noqa
18
- from .tsdata import TSData, TSDataSet # noqa
18
+ from .tsdata import TSData, TSDataGroup # noqa
19
19
  from .binning import Binning # noqa
20
20
  from .ldtkld import LDTkLD # noqa
21
21
  from .tslpf import clean_knots # noqa
exoiris/exoiris.py CHANGED
@@ -24,27 +24,24 @@ from typing import Optional, Callable, Any, Literal
24
24
 
25
25
  import astropy.io.fits as pf
26
26
  import astropy.units as u
27
- import emcee
27
+ import matplotlib.axes
28
28
  import pandas as pd
29
- import pytransit.utils.de
30
- import seaborn as sb
31
29
  from astropy.table import Table
32
30
  from celerite2 import GaussianProcess, terms
33
31
  from emcee import EnsembleSampler
34
32
  from matplotlib.pyplot import subplots, setp, figure, Figure, Axes
35
- from numpy import (where, sqrt, clip, percentile, median, squeeze, floor, ndarray,
36
- array, inf, newaxis, arange, tile, sort, argsort, concatenate, full, nan, r_, nanpercentile, log10,
37
- ceil)
38
- from numpy.random import normal, permutation
33
+ from numpy import (any, where, sqrt, clip, percentile, median, squeeze, floor, ndarray, isfinite,
34
+ array, inf, arange, argsort, concatenate, full, nan, r_, nanpercentile, log10,
35
+ ceil, unique)
36
+ from numpy.random import normal
39
37
  from pytransit import UniformPrior, NormalPrior
40
- from pytransit.orbits import epoch
41
38
  from pytransit.param import ParameterSet
42
39
  from pytransit.utils.de import DiffEvol
43
40
  from scipy.stats import norm
44
41
  from uncertainties import UFloat
45
42
 
46
43
  from .ldtkld import LDTkLD
47
- from .tsdata import TSData, TSDataSet
44
+ from .tsdata import TSData, TSDataGroup
48
45
  from .tslpf import TSLPF
49
46
  from .wlpf import WhiteLPF
50
47
 
@@ -73,7 +70,7 @@ def load_model(fname: Path | str, name: str | None = None):
73
70
  If the file format is invalid or does not match the expected format.
74
71
  """
75
72
  with pf.open(fname) as hdul:
76
- data = TSDataSet.import_fits(hdul)
73
+ data = TSDataGroup.import_fits(hdul)
77
74
 
78
75
  if hdul[0].header['LDMODEL'] == 'ldtk':
79
76
  filters, teff, logg, metal, dataset = pickle.loads(codecs.decode(json.loads(hdul[0].header['LDTKLD']).encode(), "base64"))
@@ -86,11 +83,28 @@ def load_model(fname: Path | str, name: str | None = None):
86
83
  except KeyError:
87
84
  ip = 'bspline'
88
85
 
89
- #TODO: save and load the noise model information
90
- a = ExoIris(name or hdul[0].header['NAME'], ldmodel=ldm, data=data, interpolation=ip)
86
+ try:
87
+ noise_model = hdul[0].header['NOISE']
88
+ except KeyError:
89
+ noise_model = "white"
90
+
91
+ a = ExoIris(name or hdul[0].header['NAME'], ldmodel=ldm, data=data, noise_model=noise_model, interpolation=ip)
91
92
  a.set_radius_ratio_knots(hdul['K_KNOTS'].data.astype('d'))
92
93
  a.set_limb_darkening_knots(hdul['LD_KNOTS'].data.astype('d'))
93
94
 
95
+ # Read the white light curve models if they exist.
96
+ try:
97
+ tb = Table.read(hdul['WHITE_DATA'])
98
+ white_ids = tb['id'].data
99
+ uids = unique(white_ids)
100
+ a._white_times = [tb['time'].data[white_ids == i] for i in uids]
101
+ a._white_fluxes = [tb['flux_obs'].data[white_ids == i] for i in uids]
102
+ a._white_errors = [tb['flux_obs_err'].data[white_ids == i] for i in uids]
103
+ a._white_models = [tb['flux_mod'].data[white_ids == i] for i in uids]
104
+
105
+ except KeyError:
106
+ pass
107
+
94
108
  try:
95
109
  a.period = hdul[0].header['P']
96
110
  a.zero_epoch = hdul[0].header['T0']
@@ -116,8 +130,8 @@ class ExoIris:
116
130
  """The core ExoIris class providing tools for exoplanet transit spectroscopy.
117
131
  """
118
132
 
119
- def __init__(self, name: str, ldmodel, data: TSDataSet | TSData, nk: int = 50, nldc: int = 10, nthreads: int = 1,
120
- tmpars: dict | None = None, noise_model: str = 'white',
133
+ def __init__(self, name: str, ldmodel, data: TSDataGroup | TSData, nk: int = 50, nldc: int = 10, nthreads: int = 1,
134
+ tmpars: dict | None = None, noise_model: Literal["white", "fixed_gp", "free_gp"] = 'white',
121
135
  interpolation: Literal['bspline', 'pchip', 'makima'] = 'bspline'):
122
136
  """
123
137
  Parameters
@@ -139,10 +153,31 @@ class ExoIris:
139
153
  noise_model
140
154
  The noise model to use. Should be either "white" for white noise or "fixed_gp" for Gaussian Process.
141
155
  """
142
- data = TSDataSet([data]) if isinstance(data, TSData) else data
143
- self._tsa: TSLPF = TSLPF(name, ldmodel, data, nk=nk, nldc=nldc, nthreads=nthreads, tmpars=tmpars,
156
+ data = TSDataGroup([data]) if isinstance(data, TSData) else data
157
+
158
+ for d in data:
159
+ if any(~isfinite(d.fluxes[d.mask])):
160
+ raise ValueError(f"The {d.name} data set flux array contains unmasked noninfinite values.")
161
+
162
+ if any(~isfinite(d.errors[d.mask])):
163
+ raise ValueError(f"The {d.name} data set error array contains unmasked noninfinite values.")
164
+
165
+ ngs = array(data.noise_groups)
166
+ if not ((ngs.min() == 0) and (ngs.max() + 1 == unique(ngs).size)):
167
+ raise ValueError("The noise groups must start from 0 and be consecutive.")
168
+
169
+ ogs = array(data.offset_groups)
170
+ if not ((ogs.min() == 0) and (ogs.max() + 1 == unique(ogs).size)):
171
+ raise ValueError("The offset groups must start from 0 and be consecutive.")
172
+
173
+ egs = array(data.epoch_groups)
174
+ if not ((egs.min() == 0) and (egs.max() + 1 == unique(egs).size)):
175
+ raise ValueError("The epoch groups must start from 0 and be consecutive.")
176
+
177
+ self._tsa = TSLPF(self, name, ldmodel, data, nk=nk, nldc=nldc, nthreads=nthreads, tmpars=tmpars,
144
178
  noise_model=noise_model, interpolation=interpolation)
145
- self._wa: WhiteLPF | None = None
179
+ self._wa: None | WhiteLPF = None
180
+
146
181
  self.nthreads: int = nthreads
147
182
 
148
183
  self.period: float | None = None
@@ -150,6 +185,11 @@ class ExoIris:
150
185
  self.transit_duration: float | None= None
151
186
  self._tref = floor(self.data.tmin)
152
187
 
188
+ self._white_times: None | list[ndarray] = None
189
+ self._white_fluxes: None | list[ndarray] = None
190
+ self._white_errors: None | list[ndarray] = None
191
+ self._white_models: None | list[ndarray] = None
192
+
153
193
  def lnposterior(self, pvp: ndarray) -> ndarray:
154
194
  """Calculate the log posterior probability for a single parameter vector or an array of parameter vectors.
155
195
 
@@ -180,7 +220,7 @@ class ExoIris:
180
220
  """
181
221
  self._tsa.set_noise_model(noise_model)
182
222
 
183
- def set_data(self, data: TSData | TSDataSet) -> None:
223
+ def set_data(self, data: TSData | TSDataGroup) -> None:
184
224
  """Set the model data.
185
225
 
186
226
  Parameters
@@ -188,7 +228,7 @@ class ExoIris:
188
228
  data
189
229
  The spectroscopic transit light curve.
190
230
  """
191
- data = TSDataSet([data]) if isinstance(data, TSData) else data
231
+ data = TSDataGroup([data]) if isinstance(data, TSData) else data
192
232
  self._tsa.set_data(data)
193
233
 
194
234
  def set_prior(self, parameter: Literal['radius ratios', 'baselines', 'wn multipliers'] | str,
@@ -279,7 +319,7 @@ class ExoIris:
279
319
  metal = (metal.n, metal.s) if isinstance(metal, UFloat) else metal
280
320
  self._tsa.set_ldtk_prior(teff, logg, metal, dataset, width, uncertainty_multiplier)
281
321
 
282
- def set_gp_hyperparameters(self, sigma: float, rho: float) -> None:
322
+ def set_gp_hyperparameters(self, sigma: float, rho: float, idata: None | int = None) -> None:
283
323
  """Set Gaussian Process (GP) hyperparameters assuming a Matern-3/2 kernel.
284
324
 
285
325
  Parameters
@@ -288,8 +328,10 @@ class ExoIris:
288
328
  The kernel amplitude parameter.
289
329
  rho
290
330
  The length scale parameter.
331
+ idata
332
+ The data set for which to set the hyperparameters. If None, the hyperparameters are set for all data sets.
291
333
  """
292
- self._tsa.set_gp_hyperparameters(sigma, rho)
334
+ self._tsa.set_gp_hyperparameters(sigma, rho, idata)
293
335
 
294
336
  def set_gp_kernel(self, kernel: terms.Term) -> None:
295
337
  """Set the Gaussian Process (GP) kernel.
@@ -311,7 +353,7 @@ class ExoIris:
311
353
  self._tsa.name = name
312
354
 
313
355
  @property
314
- def data(self) -> TSDataSet:
356
+ def data(self) -> TSDataGroup:
315
357
  """Analysis data set."""
316
358
  return self._tsa.data
317
359
 
@@ -378,22 +420,35 @@ class ExoIris:
378
420
  @property
379
421
  def white_times(self) -> list[ndarray]:
380
422
  """White light curve time arrays."""
381
- return self._wa.times
423
+ if self._wa is None:
424
+ return self._white_times
425
+ else:
426
+ return self._wa.times
382
427
 
383
428
  @property
384
429
  def white_fluxes(self) -> list[ndarray]:
385
430
  """White light curve flux arrays."""
386
- return self._wa.fluxes
431
+ if self._wa is None:
432
+ return self._white_fluxes
433
+ else:
434
+ return self._wa.fluxes
387
435
 
388
436
  @property
389
437
  def white_models(self) -> list[ndarray]:
390
- fm = self._wa.flux_model(self._wa._local_minimization.x)
391
- return [fm[sl] for sl in self._wa.lcslices]
438
+ """Fitted white light curve flux model arrays."""
439
+ if self._wa is None:
440
+ return self._white_models
441
+ else:
442
+ fm = self._wa.flux_model(self._wa._local_minimization.x)
443
+ return [fm[sl] for sl in self._wa.lcslices]
392
444
 
393
445
  @property
394
446
  def white_errors(self) -> list[ndarray]:
395
447
  """White light curve flux error arrays."""
396
- return self._wa.std_errors
448
+ if self._wa is None:
449
+ return self._white_errors
450
+ else:
451
+ return self._wa.std_errors
397
452
 
398
453
  def add_radius_ratio_knots(self, knot_wavelengths: Sequence) -> None:
399
454
  """Add radius ratio (k) knots.
@@ -464,65 +519,36 @@ class ExoIris:
464
519
  """Print the model parameterization."""
465
520
  self._tsa.print_parameters(1)
466
521
 
467
- def plot_setup(self, figsize: tuple[float, float] | None =None, xscale: str | None = None, xticks: Sequence | None = None) -> Figure:
522
+ def plot_setup(self, figsize: tuple[float, float] | None = None,
523
+ ax: matplotlib.axes.Axes | None = None,
524
+ xscale: str | None = None, xticks: Sequence | None = None,
525
+ yshift: float = 0.1, mh:float = 0.08, side_margin: float = 0.05,
526
+ lw: float = 0.5, c='k') -> Figure:
468
527
  """Plot the model setup with limb darkening knots, radius ratio knots, and data binning.
469
-
470
- Parameters
471
- ----------
472
- figsize
473
- The size of the figure in inches.
474
- xscale
475
- The scale of the x-axis. If provided, the x-axis scale of all the subplots will be set to this value.
476
- xticks
477
- The list of x-axis tick values for all the subplots. If provided, the x-axis ticks of all the subplots will
478
- be set to these values.
479
-
480
- Returns
481
- -------
482
- Figure
483
- The matplotlib Figure object that contains the created subplots.
484
-
485
528
  """
486
- using_ldtk = isinstance(self._tsa.ldmodel, LDTkLD)
529
+ if ax is None:
530
+ fig, ax = subplots(figsize=figsize, constrained_layout=True)
531
+ else:
532
+ fig = ax.figure
487
533
 
488
- if not using_ldtk:
489
- figsize = figsize or (13, 4)
490
- fig, axs = subplots(3, 1, figsize=figsize, sharex='all', sharey='all')
491
- axl, axk, axw = axs
534
+ ndata = self.data.size
492
535
 
493
- axl.vlines(self._tsa.ld_knots, 0.1, 0.5, ec='k')
494
- axl.text(0.01, 0.90, 'Limb darkening knots', va='top', transform=axl.transAxes)
495
- else:
496
- figsize = figsize or (13, 2*4/3)
497
- fig, axs = subplots(2, 1, figsize=figsize, sharex='all', sharey='all')
498
- axk, axw = axs
499
- axl = None
500
-
501
- axk.vlines(self._tsa.k_knots, 0.1, 0.5, ec='k')
502
- axk.text(0.01, 0.90, 'Radius ratio knots', va='top', transform=axk.transAxes)
503
- for ds in self.data:
504
- axw.vlines(ds.wavelength, 0.1, 0.5, ec='k')
505
- axw.text(0.01, 0.90, 'Wavelength bins', va='top', transform=axw.transAxes)
506
-
507
- if not using_ldtk:
508
- sb.despine(ax=axl, top=False, bottom=True, right=False)
509
- sb.despine(ax=axk, top=True, bottom=True, right=False)
510
- else:
511
- sb.despine(ax=axk, top=False, bottom=True, right=False)
536
+ for i, d in enumerate(self.data):
537
+ ax.vlines(d.wavelength, ymin=i*yshift, ymax=i*yshift+mh, colors=c, lw=lw)
512
538
 
513
- sb.despine(ax=axw, top=True, bottom=False, right=False)
514
- setp(axs, xlim=(self.data.wlmin-0.02, self.data.wlmax+0.02), yticks=[], ylim=(0, 0.9))
515
- setp(axw, xlabel=r'Wavelength [$\mu$m]')
516
- setp(axs[0].get_xticklines(), visible=False)
517
- setp(axs[0].get_xticklabels(), visible=False)
518
- setp(axs[1].get_xticklines(), visible=False)
519
- setp(axs[-1].get_xticklines(), visible=True)
539
+ i = ndata + 1
540
+ ax.vlines(self._tsa.ld_knots, ymin=i*yshift, ymax=i*yshift+mh, colors=c, lw=lw)
541
+
542
+ i = ndata + 3
543
+ ax.vlines(self.k_knots, ymin=i*yshift, ymax=i*yshift+mh, colors=c, lw=lw)
520
544
 
521
545
  if xscale:
522
- setp(axs, xscale=xscale)
546
+ setp(ax, xscale=xscale)
523
547
  if xticks is not None:
524
- [ax.set_xticks(xticks, labels=xticks) for ax in axs]
525
- fig.tight_layout()
548
+ ax.set_xticks(xticks, labels=xticks)
549
+
550
+ setp(ax, yticks=[], xlim=(self.data.wlmin-side_margin, self.data.wlmax+side_margin), xlabel=r'Wavelength [$\mu$m]')
551
+ ax.set_yticks(concatenate([arange(ndata), arange(ndata+1, ndata+4, 2)])*yshift+0.5*mh, labels=[n.replace("_", " ") for n in self.data.names] + ["Limb darkening knots", "Radius ratio knots"])
526
552
  return fig
527
553
 
528
554
  def fit_white(self, niter: int = 500) -> None:
@@ -537,10 +563,10 @@ class ExoIris:
537
563
  self._wa.optimize_global(niter, plot_convergence=False, use_tqdm=False)
538
564
  self._wa.optimize()
539
565
  pv = self._wa._local_minimization.x
540
- self.period = pv[1]
566
+ self.period = pv[0]
541
567
  self.zero_epoch = self._wa.transit_center
542
568
  self.transit_duration = self._wa.transit_duration
543
- self.data.mask_transit(pv[0], pv[1], self.transit_duration)
569
+ self.data.mask_transit(self.zero_epoch, self.period, self.transit_duration)
544
570
 
545
571
  def plot_white(self, axs=None, figsize: tuple[float, float] | None = None, ncols: int | None=None) -> Figure:
546
572
  """Plot the white light curve data with the best-fit model.
@@ -556,7 +582,25 @@ class ExoIris:
556
582
  """
557
583
  return self._wa.plot(axs=axs, figsize=figsize, ncols=ncols or min(self.data.size, 2))
558
584
 
559
- def plot_white_gp_predictions(self, axs = None, ncol: int = 1, figsize: tuple[float, float] | None = None):
585
+ def plot_white_gp_predictions(self, axs = None, ncol: int = 1, figsize: tuple[float, float] | None = None) -> None:
586
+ """Plot the predictions of a Gaussian Process model for white light curves and residuals.
587
+
588
+ Parameters
589
+ ----------
590
+ axs
591
+ Axes array in which the plots are drawn. If None, new subplots are generated.
592
+ ncol
593
+ The number of columns for the created grid of subplots. Used only if axs is None.
594
+ figsize
595
+ Size of the figure in inches (width, height). Used only if axs is None. If None,
596
+ a default size is used.
597
+
598
+ Notes
599
+ -----
600
+ The number of rows for the subplots is determined dynamically based on the shape of
601
+ the data and the number of columns specified (ncol). If the provided axes array (axs)
602
+ does not accommodate all the subplots, the behavior is undefined.
603
+ """
560
604
  ndata = self.data.size
561
605
 
562
606
  if axs is None:
@@ -657,12 +701,18 @@ class ExoIris:
657
701
 
658
702
  pv0 = self._wa._local_minimization.x
659
703
  x0 = self._tsa.ps.sample_from_prior(npop)
660
- x0[:, 0] = normal(pv0[2], 0.05, size=npop)
661
- x0[:, 1] = normal(pv0[0], 1e-4, size=npop)
662
- x0[:, 2] = normal(pv0[1], 1e-5, size=npop)
663
- x0[:, 3] = clip(normal(pv0[3], 0.01, size=npop), 0.0, 1.0)
704
+ x0[:, 0] = clip(normal(pv0[1], 0.05, size=npop), 0.01, inf)
705
+ x0[:, 1] = clip(normal(pv0[0], 1e-4, size=npop), 0.01, inf)
706
+ x0[:, 2] = clip(normal(pv0[2], 1e-3, size=npop), 0.0, 1.0)
707
+
708
+ nep = max(self.data.epoch_groups) + 1
709
+ for i in range(nep):
710
+ pida = self.ps.find_pid(f'tc_{i:02d}')
711
+ pidb = self._wa.ps.find_pid(f'tc_{i:02d}')
712
+ x0[:, pida] = normal(pv0[pidb], 0.001, size=npop)
713
+
664
714
  sl = self._tsa._sl_rratios
665
- x0[:, sl] = normal(sqrt(pv0[4]), 0.001, size=(npop, self.nk))
715
+ x0[:, sl] = normal(sqrt(pv0[self._wa.ps.find_pid('k2')]), 0.001, size=(npop, self.nk))
666
716
  for i in range(sl.start, sl.stop):
667
717
  x0[:, i] = clip(x0[:, i], 1.001*self.ps[i].prior.a, 0.999*self.ps[i].prior.b)
668
718
 
@@ -1043,6 +1093,7 @@ class ExoIris:
1043
1093
  pri.header['t14'] = self.transit_duration
1044
1094
  pri.header['ndgroups'] = self.data.size
1045
1095
  pri.header['interp'] = self._tsa.interpolation
1096
+ pri.header['noise'] = self._tsa.noise_model
1046
1097
 
1047
1098
  pr = pf.ImageHDU(name='priors')
1048
1099
  priors = [pickle.dumps(p) for p in self.ps]
@@ -1061,6 +1112,34 @@ class ExoIris:
1061
1112
  hdul = pf.HDUList([pri, k_knots, ld_knots, pr])
1062
1113
  hdul += self.data.export_fits()
1063
1114
 
1115
+ if self._wa is not None and self._wa._local_minimization is not None:
1116
+ wa_data = pf.BinTableHDU(
1117
+ Table(
1118
+ [
1119
+ self._wa.lcids,
1120
+ self._wa.timea,
1121
+ concatenate(self.white_models),
1122
+ self._wa.ofluxa,
1123
+ concatenate(self._wa.std_errors),
1124
+ ],
1125
+ names="id time flux_mod flux_obs flux_obs_err".split(),
1126
+ ), name='white_data'
1127
+ )
1128
+ hdul.append(wa_data)
1129
+
1130
+ names = []
1131
+ counts = {}
1132
+ for p in self._wa.ps.names:
1133
+ if p not in counts.keys():
1134
+ counts[p] = 0
1135
+ names.append(p)
1136
+ else:
1137
+ counts[p] += 1
1138
+ names.append(f'{p}_{counts[p]}')
1139
+
1140
+ wa_params = pf.BinTableHDU(Table(self._wa._local_minimization.x, names=names), name='white_params')
1141
+ hdul.append(wa_params)
1142
+
1064
1143
  if self._tsa.de is not None:
1065
1144
  de = pf.BinTableHDU(Table(self._tsa._de_population, names=self.ps.names), name='DE')
1066
1145
  de.header['npop'] = self._tsa.de.n_pop
@@ -1141,11 +1220,8 @@ class ExoIris:
1141
1220
  log10_rho_bounds: float | tuple[float, float] = (-5, 0),
1142
1221
  log10_sigma_prior=None, log10_rho_prior=None,
1143
1222
  npop: int = 10, niter: int = 100):
1144
- if self._tsa.noise_model != 'fixed_gp':
1145
- raise ValueError("The noise model must be set to 'fixed_gp' before the hyperparameter optimization.")
1146
-
1147
- if self._wa is None:
1148
- raise ValueError("The white light curves must be fit using 'fit_white()' before the hyperparameter optimization.")
1223
+ if self._tsa.noise_model not in ('fixed_gp', 'free_gp'):
1224
+ raise ValueError("The noise model must be set to 'fixed_gp' or 'free_gp' before the hyperparameter optimization.")
1149
1225
 
1150
1226
  if log10_rho_prior is not None:
1151
1227
  if isinstance(log10_rho_prior, Sequence):
@@ -1172,7 +1248,7 @@ class ExoIris:
1172
1248
 
1173
1249
  match log10_sigma_bounds:
1174
1250
  case None:
1175
- sb = [log10_sigma_guess-1, log10_sigma_guess+1]
1251
+ sb = [log10_sigma_guess - 1, log10_sigma_guess + 1]
1176
1252
  case _ if isinstance(log10_sigma_bounds, Sequence):
1177
1253
  sb = log10_sigma_bounds
1178
1254
  case _ if isinstance(log10_sigma_bounds, float):
@@ -1180,7 +1256,7 @@ class ExoIris:
1180
1256
 
1181
1257
  match log10_rho_bounds:
1182
1258
  case None:
1183
- rb = [-5, -2]
1259
+ rb = [-5, -2]
1184
1260
  case _ if isinstance(log10_rho_bounds, Sequence):
1185
1261
  rb = log10_rho_bounds
1186
1262
  case _ if isinstance(log10_rho_bounds, float):
exoiris/ldtkld.py CHANGED
@@ -23,18 +23,18 @@ from numpy import sqrt, array, concatenate
23
23
  from pytransit import LDTkLD as PTLDTkLD
24
24
  from ldtk import BoxcarFilter
25
25
 
26
- from .tsdata import TSData, TSDataSet
26
+ from .tsdata import TSData, TSDataGroup
27
27
 
28
28
 
29
29
  class LDTkLD(PTLDTkLD):
30
- def __init__(self, data: TSDataSet | TSData,
30
+ def __init__(self, data: TSDataGroup | TSData,
31
31
  teff: tuple[float, float],
32
32
  logg: tuple[float, float],
33
33
  metal: tuple[float, float],
34
34
  cache: str | Path | None = None,
35
35
  dataset: str = 'vis-lowres') -> None:
36
36
 
37
- data = TSDataSet([data]) if isinstance(data, TSData) else data
37
+ data = TSDataGroup([data]) if isinstance(data, TSData) else data
38
38
  wl_edges = concatenate([array([d._wl_l_edges, d._wl_r_edges]) for d in data], axis=1).T
39
39
  filters = [BoxcarFilter(f"{0.5*(wla+wlb):08.5f}", wla*1e3, wlb*1e3) for wla, wlb in wl_edges]
40
40
  super().__init__(filters, teff, logg, metal, cache, dataset)