ExoIris 0.22.0__tar.gz → 0.23.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. {exoiris-0.22.0 → exoiris-0.23.1}/CHANGELOG.md +16 -0
  2. {exoiris-0.22.0 → exoiris-0.23.1}/ExoIris.egg-info/PKG-INFO +1 -1
  3. {exoiris-0.22.0 → exoiris-0.23.1}/PKG-INFO +1 -1
  4. {exoiris-0.22.0 → exoiris-0.23.1}/exoiris/exoiris.py +11 -8
  5. {exoiris-0.22.0 → exoiris-0.23.1}/exoiris/tsdata.py +46 -12
  6. {exoiris-0.22.0 → exoiris-0.23.1}/exoiris/tslpf.py +57 -50
  7. {exoiris-0.22.0 → exoiris-0.23.1}/.github/workflows/python-package.yml +0 -0
  8. {exoiris-0.22.0 → exoiris-0.23.1}/.gitignore +0 -0
  9. {exoiris-0.22.0 → exoiris-0.23.1}/.readthedocs.yaml +0 -0
  10. {exoiris-0.22.0 → exoiris-0.23.1}/CODE_OF_CONDUCT.md +0 -0
  11. {exoiris-0.22.0 → exoiris-0.23.1}/ExoIris.egg-info/SOURCES.txt +0 -0
  12. {exoiris-0.22.0 → exoiris-0.23.1}/ExoIris.egg-info/dependency_links.txt +0 -0
  13. {exoiris-0.22.0 → exoiris-0.23.1}/ExoIris.egg-info/requires.txt +0 -0
  14. {exoiris-0.22.0 → exoiris-0.23.1}/ExoIris.egg-info/top_level.txt +0 -0
  15. {exoiris-0.22.0 → exoiris-0.23.1}/LICENSE +0 -0
  16. {exoiris-0.22.0 → exoiris-0.23.1}/README.md +0 -0
  17. {exoiris-0.22.0 → exoiris-0.23.1}/doc/Makefile +0 -0
  18. {exoiris-0.22.0 → exoiris-0.23.1}/doc/make.bat +0 -0
  19. {exoiris-0.22.0 → exoiris-0.23.1}/doc/requirements.txt +0 -0
  20. {exoiris-0.22.0 → exoiris-0.23.1}/doc/source/_static/css/custom.css +0 -0
  21. {exoiris-0.22.0 → exoiris-0.23.1}/doc/source/api/binning.rst +0 -0
  22. {exoiris-0.22.0 → exoiris-0.23.1}/doc/source/api/exoiris.rst +0 -0
  23. {exoiris-0.22.0 → exoiris-0.23.1}/doc/source/api/tsdata.rst +0 -0
  24. {exoiris-0.22.0 → exoiris-0.23.1}/doc/source/conf.py +0 -0
  25. {exoiris-0.22.0 → exoiris-0.23.1}/doc/source/examples/e01/01a_not_so_short_intro.ipynb +0 -0
  26. {exoiris-0.22.0 → exoiris-0.23.1}/doc/source/examples/e01/01b_short_intro.ipynb +0 -0
  27. {exoiris-0.22.0 → exoiris-0.23.1}/doc/source/examples/e01/02_increasing_knot_resolution.ipynb +0 -0
  28. {exoiris-0.22.0 → exoiris-0.23.1}/doc/source/examples/e01/03_increasing_data_resolution.ipynb +0 -0
  29. {exoiris-0.22.0 → exoiris-0.23.1}/doc/source/examples/e01/04_gaussian_processes.ipynb +0 -0
  30. {exoiris-0.22.0 → exoiris-0.23.1}/doc/source/examples/e01/05a_ldtkldm.ipynb +0 -0
  31. {exoiris-0.22.0 → exoiris-0.23.1}/doc/source/examples/e01/A2_full_data_resolution.ipynb +0 -0
  32. {exoiris-0.22.0 → exoiris-0.23.1}/doc/source/examples/e01/appendix_1_data_preparation.ipynb +0 -0
  33. {exoiris-0.22.0 → exoiris-0.23.1}/doc/source/examples/e01/data/README.txt +0 -0
  34. {exoiris-0.22.0 → exoiris-0.23.1}/doc/source/examples/e01/data/nirHiss_order_1.h5 +0 -0
  35. {exoiris-0.22.0 → exoiris-0.23.1}/doc/source/examples/e01/data/nirHiss_order_2.h5 +0 -0
  36. {exoiris-0.22.0 → exoiris-0.23.1}/doc/source/examples/e01/example1.png +0 -0
  37. {exoiris-0.22.0 → exoiris-0.23.1}/doc/source/examples/e01/plot_1.ipynb +0 -0
  38. {exoiris-0.22.0 → exoiris-0.23.1}/doc/source/examples/figures.ipynb +0 -0
  39. {exoiris-0.22.0 → exoiris-0.23.1}/doc/source/examples/friendly_introduction.ipynb +0 -0
  40. {exoiris-0.22.0 → exoiris-0.23.1}/doc/source/examples/index.rst +0 -0
  41. {exoiris-0.22.0 → exoiris-0.23.1}/doc/source/examples/k_knot_example.svg +0 -0
  42. {exoiris-0.22.0 → exoiris-0.23.1}/doc/source/examples/setup_multiprocessing.py +0 -0
  43. {exoiris-0.22.0 → exoiris-0.23.1}/doc/source/index.rst +0 -0
  44. {exoiris-0.22.0 → exoiris-0.23.1}/doc/source/install.rst +0 -0
  45. {exoiris-0.22.0 → exoiris-0.23.1}/exoiris/__init__.py +0 -0
  46. {exoiris-0.22.0 → exoiris-0.23.1}/exoiris/binning.py +0 -0
  47. {exoiris-0.22.0 → exoiris-0.23.1}/exoiris/ephemeris.py +0 -0
  48. {exoiris-0.22.0 → exoiris-0.23.1}/exoiris/ldtkld.py +0 -0
  49. {exoiris-0.22.0 → exoiris-0.23.1}/exoiris/loglikelihood.py +0 -0
  50. {exoiris-0.22.0 → exoiris-0.23.1}/exoiris/spotmodel.py +0 -0
  51. {exoiris-0.22.0 → exoiris-0.23.1}/exoiris/tsmodel.py +0 -0
  52. {exoiris-0.22.0 → exoiris-0.23.1}/exoiris/util.py +0 -0
  53. {exoiris-0.22.0 → exoiris-0.23.1}/exoiris/wlpf.py +0 -0
  54. {exoiris-0.22.0 → exoiris-0.23.1}/pyproject.toml +0 -0
  55. {exoiris-0.22.0 → exoiris-0.23.1}/requirements.txt +0 -0
  56. {exoiris-0.22.0 → exoiris-0.23.1}/setup.cfg +0 -0
  57. {exoiris-0.22.0 → exoiris-0.23.1}/tests/test_binning.py +0 -0
@@ -5,6 +5,22 @@ All notable changes to ExoIris will be documented in this file.
5
5
  The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
6
6
  and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
7
7
 
8
+ ## [0.23.1] - 2025-12-18
9
+
10
+ ### Fixed
11
+ - Fixed least-squares baseline fitting for transit models with NaNs.
12
+ - Added "bspline-cubic" interpolation option as an alias for "bspline". The "bspline" option will be removed in a future
13
+ release.
14
+
15
+ ## [0.23.0] - 2025-12-16
16
+
17
+ ### Changed
18
+ - Switched baseline modeling to a least-squares approach.
19
+
20
+ ### Fixed
21
+ - Corrected prior loading for parameter sets.
22
+ - Validated `samples` before posterior spectrum calculation to prevent runtime errors.
23
+
8
24
  ## [0.22.0] - 2025-12-13
9
25
 
10
26
  ### Added
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ExoIris
3
- Version: 0.22.0
3
+ Version: 0.23.1
4
4
  Summary: Easy and robust exoplanet transmission spectroscopy.
5
5
  Author-email: Hannu Parviainen <hannu@iac.es>
6
6
  License: GPLv3
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ExoIris
3
- Version: 0.22.0
3
+ Version: 0.23.1
4
4
  Summary: Easy and robust exoplanet transmission spectroscopy.
5
5
  Author-email: Hannu Parviainen <hannu@iac.es>
6
6
  License: GPLv3
@@ -142,8 +142,11 @@ def load_model(fname: Path | str, name: str | None = None):
142
142
  # Read the priors.
143
143
  # ================
144
144
  priors = pickle.loads(codecs.decode(json.loads(hdul['PRIORS'].header['PRIORS']).encode(), "base64"))
145
- a._tsa.ps = ParameterSet([pickle.loads(p) for p in priors])
146
- a._tsa.ps.freeze()
145
+ for praw in priors:
146
+ p = pickle.loads(praw)
147
+ if p.name in a._tsa.ps.names:
148
+ a._tsa.set_prior(p.name, p.prior)
149
+
147
150
  if 'DE' in hdul:
148
151
  a._tsa._de_population = Table(hdul['DE'].data).to_pandas().values
149
152
  a._tsa._de_imin = hdul['DE'].header['IMIN']
@@ -160,7 +163,7 @@ class ExoIris:
160
163
 
161
164
  def __init__(self, name: str, ldmodel, data: TSDataGroup | TSData, nk: int = 50, nldc: int = 10, nthreads: int = 1,
162
165
  tmpars: dict | None = None, noise_model: Literal["white", "fixed_gp", "free_gp"] = 'white',
163
- interpolation: Literal['nearest', 'linear', 'pchip', 'makima', 'bspline', 'bspline-quadratic'] = 'makima'):
166
+ interpolation: Literal['nearest', 'linear', 'pchip', 'makima', 'bspline', 'bspline-quadratic', 'bspline-cubic'] = 'linear'):
164
167
  """
165
168
  Parameters
166
169
  ----------
@@ -281,9 +284,6 @@ class ExoIris:
281
284
  if parameter == 'radius ratios':
282
285
  for l in self._tsa.k_knots:
283
286
  self.set_prior(f'k_{l:08.5f}', prior, *nargs)
284
- elif parameter == 'baselines':
285
- for par in self.ps[self._tsa._sl_baseline]:
286
- self.set_prior(par.name, prior, *nargs)
287
287
  elif parameter == 'wn multipliers':
288
288
  for par in self.ps[self._tsa._sl_wnm]:
289
289
  self.set_prior(par.name, prior, *nargs)
@@ -1192,7 +1192,7 @@ class ExoIris:
1192
1192
  wavelengths. The representation (radius ratio or depth) depends on the
1193
1193
  specified `kind`.
1194
1194
  """
1195
- if self.mcmc_chains is None:
1195
+ if self.mcmc_chains is None and samples is None:
1196
1196
  raise ValueError("Cannot calculate posterior transmission spectrum before running the MCMC sampler.")
1197
1197
 
1198
1198
  if kind not in ('radius_ratio', 'depth'):
@@ -1206,8 +1206,11 @@ class ExoIris:
1206
1206
  wavelengths.sort()
1207
1207
 
1208
1208
  k_posteriors = zeros((samples.shape[0], wavelengths.size))
1209
+ k_knots = self._tsa.k_knots.copy()
1209
1210
  for i, pv in enumerate(samples):
1210
- k_posteriors[i, :] = self._tsa._ip(wavelengths, self._tsa.k_knots, pv[self._tsa._sl_rratios])
1211
+ if self._tsa.free_k_knot_ids is not None:
1212
+ k_knots[self._tsa.free_k_knot_ids] = pv[self._tsa._sl_kloc]
1213
+ k_posteriors[i, :] = self._tsa._ip(wavelengths, k_knots, pv[self._tsa._sl_rratios])
1211
1214
 
1212
1215
  if kind == 'radius_ratio':
1213
1216
  return wavelengths, k_posteriors
@@ -27,6 +27,8 @@ from matplotlib.figure import Figure
27
27
  from matplotlib.pyplot import subplots, setp
28
28
  from matplotlib.ticker import LinearLocator, FuncFormatter
29
29
  from numpy import (
30
+ any,
31
+ all,
30
32
  isfinite,
31
33
  where,
32
34
  all,
@@ -51,6 +53,9 @@ from numpy import (
51
53
  nanmedian,
52
54
  nanmean,
53
55
  unique,
56
+ ascontiguousarray,
57
+ vstack,
58
+ ones_like,
54
59
  )
55
60
  from pytransit.orbits import fold
56
61
 
@@ -68,7 +73,7 @@ class TSData:
68
73
  noise_group: int = 0, wl_edges : Sequence | None = None, tm_edges : Sequence | None = None,
69
74
  transit_mask: ndarray | None = None, ephemeris: Ephemeris | None = None, n_baseline: int = 1,
70
75
  mask: ndarray = None, epoch_group: int = 0, offset_group: int = 0,
71
- mask_nonfinite_errors: bool = True) -> None:
76
+ mask_nonfinite_errors: bool = True, covs: ndarray | None = None) -> None:
72
77
  """
73
78
  Parameters
74
79
  ----------
@@ -99,7 +104,7 @@ class TSData:
99
104
  if transit_mask is not None and transit_mask.size != time.size:
100
105
  raise ValueError("The size of the out-of-transit mask array must match the size of the time array.")
101
106
 
102
- if n_baseline < 1:
107
+ if n_baseline < 0:
103
108
  raise ValueError("n_baseline must be greater than zero.")
104
109
 
105
110
  if noise_group < 0:
@@ -126,7 +131,18 @@ class TSData:
126
131
  self.mask &= isfinite(errors)
127
132
  self.fluxes: ndarray = where(self.mask, fluxes, nan)
128
133
  self.errors: ndarray = where(self.mask, errors, nan)
134
+
135
+ if covs is not None:
136
+ self.covs: ndarray = covs
137
+ else:
138
+ ctime = self.time - self.time.mean()
139
+ self.covs = ascontiguousarray(vstack([ones(self.time.size)]+[ctime**i for i in range(1, n_baseline+1)]).T)
140
+ self.covs[:, 1:] /= self.covs[:, 1:].std(axis=0)
141
+
129
142
  self.transit_mask: ndarray = transit_mask if transit_mask is not None else ones(time.size, dtype=bool)
143
+ self._wlmask: ndarray = all(self.mask, 1)
144
+ self._wls_with_nan: ndarray = where(~self._wlmask)[0]
145
+
130
146
  self._ephemeris: Ephemeris | None = ephemeris
131
147
  self.n_baseline: int = n_baseline
132
148
  self.noise_group: int = noise_group
@@ -170,6 +186,7 @@ class TSData:
170
186
  time = pf.ImageHDU(self.time, name=f'time_{self.name}')
171
187
  wave = pf.ImageHDU(self.wavelength, name=f'wave_{self.name}')
172
188
  data = pf.ImageHDU(array([self.fluxes, self.errors]), name=f'data_{self.name}')
189
+ covs = pf.ImageHDU(self.covs, name=f'covs_{self.name}')
173
190
  ootm = pf.ImageHDU(self.transit_mask.astype(int), name=f'ootm_{self.name}')
174
191
  mask = pf.ImageHDU(self.mask.astype(int), name=f'mask_{self.name}')
175
192
  data.header['ngroup'] = self.noise_group
@@ -177,7 +194,7 @@ class TSData:
177
194
  data.header['epgroup'] = self.epoch_group
178
195
  data.header['offgroup'] = self.offset_group
179
196
  #TODO: export ephemeris
180
- return pf.HDUList([time, wave, data, ootm, mask])
197
+ return pf.HDUList([time, wave, data, covs, ootm, mask])
181
198
 
182
199
  @staticmethod
183
200
  def import_fits(name: str, hdul: pf.HDUList) -> 'TSData':
@@ -200,6 +217,11 @@ class TSData:
200
217
  ootm = hdul[f'OOTM_{name}'].data.astype(bool)
201
218
  mask = hdul[f'MASK_{name}'].data.astype(bool)
202
219
 
220
+ try:
221
+ covs = hdul[f'COVS_{name}'].data.astype('d')
222
+ except KeyError:
223
+ covs = None
224
+
203
225
  try:
204
226
  noise_group = hdul[f'DATA_{name}'].header['NGROUP']
205
227
  except KeyError:
@@ -222,7 +244,8 @@ class TSData:
222
244
 
223
245
  #TODO: import ephemeris
224
246
  return TSData(time, wave, data[0], data[1], name=name, noise_group=noise_group, transit_mask=ootm,
225
- n_baseline=n_baseline, mask=mask, epoch_group=ephemeris_group, offset_group=offset_group)
247
+ n_baseline=n_baseline, mask=mask, epoch_group=ephemeris_group, offset_group=offset_group,
248
+ covs=covs)
226
249
 
227
250
  def __repr__(self) -> str:
228
251
  return f"TSData Name:'{self.name}' [{self.wavelength[0]:.2f} - {self.wavelength[-1]:.2f}] nwl={self.nwl} npt={self.npt}"
@@ -301,6 +324,8 @@ class TSData:
301
324
  self.maxtm = self.time.max()
302
325
  if self._ephemeris is not None:
303
326
  self.mask_transit(ephemeris=self._ephemeris)
327
+ self._wlmask = all(self.mask, 1)
328
+ self._wls_with_nan = where(~self._wlmask)[0]
304
329
 
305
330
  def _update_data_mask(self) -> None:
306
331
  self.mask = isfinite(self.fluxes)
@@ -377,7 +402,8 @@ class TSData:
377
402
  transit_mask=self.transit_mask[m],
378
403
  ephemeris=self.ephemeris,
379
404
  n_baseline=self.n_baseline,
380
- mask_nonfinite_errors=self.mask_nonfinite_errors)
405
+ mask_nonfinite_errors=self.mask_nonfinite_errors,
406
+ covs=self.covs[m])
381
407
  for i, m in enumerate(masks[1:]):
382
408
  d = d + TSData(name=f'{self.name}_{i+2}', time=self.time[m], wavelength=self.wavelength,
383
409
  fluxes=self.fluxes[:, m], errors=self.errors[:, m], mask=self.mask[:, m],
@@ -387,7 +413,8 @@ class TSData:
387
413
  transit_mask=self.transit_mask[m],
388
414
  ephemeris=self.ephemeris,
389
415
  n_baseline=self.n_baseline,
390
- mask_nonfinite_errors=self.mask_nonfinite_errors)
416
+ mask_nonfinite_errors=self.mask_nonfinite_errors,
417
+ covs=self.covs[m])
391
418
  return d
392
419
 
393
420
  def crop_wavelength(self, lmin: float, lmax: float, inplace: bool = True) -> 'TSData':
@@ -448,6 +475,7 @@ class TSData:
448
475
  self.transit_mask = self.transit_mask[m]
449
476
  self._tm_l_edges = self._tm_l_edges[m]
450
477
  self._tm_r_edges = self._tm_r_edges[m]
478
+ self.covs = self.covs[m]
451
479
  self._update()
452
480
  return self
453
481
  else:
@@ -463,7 +491,8 @@ class TSData:
463
491
  tm_edges=(self._tm_l_edges[m], self._tm_r_edges[m]),
464
492
  transit_mask=self.transit_mask[m], ephemeris=self.ephemeris,
465
493
  n_baseline=self.n_baseline,
466
- mask_nonfinite_errors=self.mask_nonfinite_errors)
494
+ mask_nonfinite_errors=self.mask_nonfinite_errors,
495
+ covs=self.covs[m])
467
496
 
468
497
  # TODO: separate mask into bad data mask and outlier mask.
469
498
  def mask_outliers(self, sigma: float = 5.0) -> 'TSData':
@@ -486,6 +515,8 @@ class TSData:
486
515
  self.mask &= abs(self.fluxes - fm) / fe < sigma
487
516
  self.fluxes = where(self.mask, self.fluxes, nan)
488
517
  self.errors = where(self.mask, self.errors, nan)
518
+ self._wlmask = all(self.mask, 1)
519
+ self._wls_with_nan = where(~self._wlmask)[0]
489
520
  return self
490
521
 
491
522
  @deprecated("0.10", alternative="TSData.mask_outliers")
@@ -688,7 +719,8 @@ class TSData:
688
719
  offset_group=self.offset_group,
689
720
  transit_mask=self.transit_mask,
690
721
  ephemeris=self.ephemeris,
691
- n_baseline=self.n_baseline)
722
+ n_baseline=self.n_baseline,
723
+ covs=self.covs)
692
724
 
693
725
  def bin_time(self, binning: Optional[Union[Binning, CompoundBinning]] = None,
694
726
  nb: Optional[int] = None, bw: Optional[float] = None,
@@ -719,6 +751,7 @@ class TSData:
719
751
  binning = Binning(self.time.min(), self.time.max(), nb=nb, bw=bw/(24*60*60) if bw is not None else None)
720
752
  bf, be = bin2d(self.fluxes.T, self.errors.T, self._tm_l_edges, self._tm_r_edges,
721
753
  binning.bins, estimate_errors=estimate_errors)
754
+ bc, _ = bin2d(self.covs, ones_like(self.covs), self._tm_l_edges, self._tm_r_edges, binning.bins, False)
722
755
  d = TSData(binning.bins.mean(1), self.wavelength, bf.T, be.T,
723
756
  wl_edges=(self._wl_l_edges, self._wl_r_edges),
724
757
  tm_edges=(binning.bins[:,0], binning.bins[:,1]),
@@ -727,7 +760,8 @@ class TSData:
727
760
  ephemeris=self.ephemeris,
728
761
  n_baseline=self.n_baseline,
729
762
  epoch_group=self.epoch_group,
730
- offset_group=self.offset_group)
763
+ offset_group=self.offset_group,
764
+ covs=bc)
731
765
  if self.ephemeris is not None:
732
766
  d.mask_transit(ephemeris=self.ephemeris)
733
767
  return d
@@ -741,7 +775,7 @@ class TSDataGroup:
741
775
  self.wlmax: float = -inf
742
776
  self.tmin: float = inf
743
777
  self.tmax: float = -inf
744
- self._noise_groups: array | None = None
778
+ self._noise_groups: ndarray | None = None
745
779
  for d in data:
746
780
  self._add_data(d)
747
781
 
@@ -782,7 +816,7 @@ class TSDataGroup:
782
816
  return [d.errors for d in self.data]
783
817
 
784
818
  @property
785
- def noise_groups(self) -> ndarray[int]:
819
+ def noise_groups(self) -> ndarray[int] | None:
786
820
  """Array of noise groups."""
787
821
  return self._noise_groups
788
822
 
@@ -935,4 +969,4 @@ class TSDataGroup:
935
969
 
936
970
 
937
971
  class TSDataSet(TSDataGroup):
938
- pass
972
+ pass
@@ -17,18 +17,36 @@
17
17
  from copy import deepcopy
18
18
  from typing import Optional, Literal
19
19
 
20
- from ldtk import BoxcarFilter, LDPSetCreator # noqa
20
+ from celerite2 import GaussianProcess as GP, terms
21
+ from ldtk import BoxcarFilter, LDPSetCreator # noqa
21
22
  from numba import njit, prange
22
- from numpy import zeros, log, pi, linspace, inf, atleast_2d, newaxis, clip, arctan2, ones, floor, sum, concatenate, \
23
- sort, ndarray, zeros_like, array, tile, arange, squeeze, dstack, nan, diff, all
23
+ from numpy import (
24
+ zeros,
25
+ log,
26
+ pi,
27
+ linspace,
28
+ inf,
29
+ atleast_2d,
30
+ newaxis,
31
+ clip,
32
+ arctan2,
33
+ sum,
34
+ concatenate,
35
+ sort,
36
+ ndarray,
37
+ array,
38
+ tile,
39
+ arange,
40
+ dstack,
41
+ diff,
42
+ ascontiguousarray,
43
+ nan,
44
+ )
45
+ from numpy.linalg import lstsq, LinAlgError
24
46
  from numpy.random import default_rng
25
- from celerite2 import GaussianProcess as GP, terms
26
-
27
47
  from pytransit.lpf.logposteriorfunction import LogPosteriorFunction
28
-
29
- from pytransit.orbits import as_from_rhop, i_from_ba, fold, i_from_baew, d_from_pkaiews, epoch
48
+ from pytransit.orbits import as_from_rhop, i_from_ba
30
49
  from pytransit.param import ParameterSet, UniformPrior as UP, NormalPrior as NP, GParameter
31
- from pytransit.stars import create_bt_settl_interpolator
32
50
  from scipy.interpolate import (
33
51
  pchip_interpolate,
34
52
  splrep,
@@ -37,10 +55,10 @@ from scipy.interpolate import (
37
55
  interp1d,
38
56
  )
39
57
 
40
- from .tsmodel import TransmissionSpectroscopyModel as TSModel
41
- from .tsdata import TSDataGroup
42
58
  from .ldtkld import LDTkLD
43
59
  from .spotmodel import SpotModel
60
+ from .tsdata import TSDataGroup
61
+ from .tsmodel import TransmissionSpectroscopyModel as TSModel
44
62
 
45
63
  NM_WHITE = 0
46
64
  NM_GP_FIXED = 1
@@ -49,6 +67,17 @@ NM_GP_FREE = 2
49
67
  noise_models = dict(white=NM_WHITE, fixed_gp=NM_GP_FIXED, free_gp=NM_GP_FREE)
50
68
 
51
69
 
70
+ @njit
71
+ def nlstsq(covs, res, mask, wlmask, with_nans):
72
+ nwl = res.shape[0]
73
+ nc = covs.shape[1]
74
+ x = zeros((nc, nwl))
75
+ x[:, wlmask] = lstsq(covs, ascontiguousarray(res[wlmask].T))[0]
76
+ for i in with_nans:
77
+ x[:, i] = lstsq(covs[mask[i]], res[i, mask[i]])[0]
78
+ return x
79
+
80
+
52
81
  @njit(parallel=True, cache=False)
53
82
  def lnlike_normal(o, m, e, f, mask):
54
83
  nwl, nt = o.shape
@@ -105,11 +134,11 @@ def add_knots(x_new, x_old):
105
134
  return sort(concatenate([x_new, x_old]))
106
135
 
107
136
 
108
- interpolator_choices = ("bspline", "pchip", "makima", "nearest", "linear", "bspline-quadratic")
137
+ interpolator_choices = ("bspline", "pchip", "makima", "nearest", "linear", "bspline-quadratic", "bspline-cubic")
109
138
 
110
139
 
111
- interpolators = {'bspline': ip_bspline, 'bspline-quadratic': ip_bspline_quadratic, 'pchip': ip_pchip,
112
- 'makima': ip_makima, 'nearest': ip_nearest, 'linear': ip_linear}
140
+ interpolators = {'bspline': ip_bspline, 'bspline-cubic': ip_bspline, 'bspline-quadratic': ip_bspline_quadratic,
141
+ 'pchip': ip_pchip, 'makima': ip_makima, 'nearest': ip_nearest, 'linear': ip_linear}
113
142
 
114
143
 
115
144
  def clean_knots(knots, min_distance, lmin=0, lmax=inf):
@@ -153,7 +182,7 @@ def clean_knots(knots, min_distance, lmin=0, lmax=inf):
153
182
  class TSLPF(LogPosteriorFunction):
154
183
  def __init__(self, runner, name: str, ldmodel, data: TSDataGroup, nk: int = 50, nldc: int = 10, nthreads: int = 1,
155
184
  tmpars = None, noise_model: Literal["white", "fixed_gp", "free_gp"] = 'white',
156
- interpolation: Literal['nearest', 'linear', 'pchip', 'makima', 'bspline', 'bspline-quadratic'] = 'makima'):
185
+ interpolation: Literal['nearest', 'linear', 'pchip', 'makima', 'bspline', 'bspline-quadratic', 'bspline-cubic'] = 'linear'):
157
186
  super().__init__(name)
158
187
  self._runner = runner
159
188
  self._original_data: TSDataGroup | None = None
@@ -244,7 +273,6 @@ class TSLPF(LogPosteriorFunction):
244
273
  self._init_p_noise()
245
274
  if self._nm == NM_GP_FREE:
246
275
  self._init_p_gp()
247
- self._init_p_baseline()
248
276
  self._init_p_bias()
249
277
  self.ps.freeze()
250
278
 
@@ -399,23 +427,6 @@ class TSLPF(LogPosteriorFunction):
399
427
  self._start_gp = ps.blocks[-1].start
400
428
  self._sl_gp = ps.blocks[-1].slice
401
429
 
402
- def _init_p_baseline(self):
403
- ps = self.ps
404
- self.n_baselines = self.data.n_baselines
405
- self.baseline_knots = []
406
- pp = []
407
- for i, d in enumerate(self.data):
408
- if d.n_baseline== 1:
409
- self.baseline_knots.append([])
410
- pp.append(GParameter(f'bl_{i:02d}_c', 'baseline constant', '', NP(1.0, 1e-6), (0, inf)))
411
- elif d.n_baseline > 1:
412
- knots = linspace(d.wavelength.min(), d.wavelength.max(), d.n_baseline)
413
- self.baseline_knots.append(knots)
414
- pp.extend([GParameter(f'bl_{i:02d}_{k:08.5f}', fr'baseline at {k:08.5f} $\mu$m', '', NP(1.0, 1e-6), (0, inf)) for k in knots])
415
- ps.add_global_block('baseline_coefficients', pp)
416
- self._start_baseline = ps.blocks[-1].start
417
- self._sl_baseline = ps.blocks[-1].slice
418
-
419
430
  def _init_p_bias(self):
420
431
  ps = self.ps
421
432
  offset_groups = self.data.offset_groups
@@ -639,9 +650,8 @@ class TSLPF(LogPosteriorFunction):
639
650
  self.de = None
640
651
  self._de_population = pvpn
641
652
 
642
- def _eval_k(self, pvp):
643
- """
644
- Evaluate the radius ratio model.
653
+ def _eval_k(self, pvp) -> list[ndarray]:
654
+ """Evaluate the radius ratio model.
645
655
 
646
656
  Parameters
647
657
  ----------
@@ -733,32 +743,29 @@ class TSLPF(LogPosteriorFunction):
733
743
  fluxes[i] = biases + (1.0 - biases) * fluxes[i]
734
744
  return fluxes
735
745
 
736
- def baseline_model(self, pv):
737
- pv = atleast_2d(pv)[:, self._sl_baseline]
738
- npv = pv.shape[0]
746
+ def baseline_model(self, mtransit):
747
+ npv = mtransit[0].shape[0]
739
748
  if self._baseline_models is None or self._baseline_models[0].shape[0] != npv:
740
- self._baseline_models = [zeros((npv, d.nwl)) for d in self.data]
741
- j = 0
749
+ self._baseline_models = [zeros(m.shape) for m in mtransit]
742
750
  for i, d in enumerate(self.data):
743
- nbl = d.n_baseline
744
- m = self._baseline_models[i]
745
- if nbl == 1:
746
- m[:, :] = pv[:, j][:, newaxis]
747
- else:
748
- for ipv in range(npv):
749
- m[ipv, :] = splev(d.wavelength, splrep(self.baseline_knots[i], pv[ipv, j:j+nbl], k=min(nbl-1, 3)))
750
- j += nbl
751
+ for ipv in range(npv):
752
+ res = d.fluxes / mtransit[i][ipv]
753
+ try:
754
+ coeffs = nlstsq(d.covs, res, d.mask, d._wlmask, d._wls_with_nan)
755
+ self._baseline_models[i][ipv, :, :] = (d.covs @ coeffs).T
756
+ except LinAlgError:
757
+ self._baseline_models[i][ipv, :, :] = nan
751
758
  return self._baseline_models
752
759
 
753
760
  def flux_model(self, pv):
754
761
  transit_models = self.transit_model(pv)
755
- baseline_models = self.baseline_model(pv)
762
+ baseline_models = self.baseline_model(transit_models)
756
763
  if self.spot_model is not None:
757
764
  self.spot_model.apply_spots(pv, transit_models)
758
765
  if self.spot_model.include_tlse:
759
766
  self.spot_model.apply_tlse(pv, transit_models)
760
767
  for i in range(self.data.size):
761
- transit_models[i][:, :, :] *= baseline_models[i][:, :, newaxis]
768
+ transit_models[i][:, :, :] *= baseline_models[i][:, :, :]
762
769
  return transit_models
763
770
 
764
771
  def create_pv_population(self, npop: int = 50) -> ndarray:
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes