ExoIris 0.18.0__py3-none-any.whl → 0.19.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
exoiris/tsdata.py CHANGED
@@ -15,13 +15,10 @@
15
15
  # along with this program. If not, see <https://www.gnu.org/licenses/>.
16
16
 
17
17
  import warnings
18
- import numba
19
-
20
- import pandas as pd
21
18
  from collections.abc import Sequence
22
-
23
19
  from typing import Union, Optional
24
20
 
21
+ import numba
25
22
  from astropy.io import fits as pf
26
23
  from astropy.stats import mad_std
27
24
  from astropy.utils import deprecated
@@ -29,15 +26,38 @@ from matplotlib.axes import Axes
29
26
  from matplotlib.figure import Figure
30
27
  from matplotlib.pyplot import subplots, setp
31
28
  from matplotlib.ticker import LinearLocator, FuncFormatter
32
- from numpy import isfinite, median, where, all, zeros_like, diff, asarray, interp, arange, floor, ndarray, \
33
- ceil, newaxis, inf, array, ones, poly1d, polyfit, nanpercentile, atleast_2d, nan, linspace, any, sqrt, nanmedian
29
+ from numpy import (
30
+ isfinite,
31
+ where,
32
+ all,
33
+ zeros_like,
34
+ diff,
35
+ asarray,
36
+ interp,
37
+ arange,
38
+ floor,
39
+ ndarray,
40
+ ceil,
41
+ newaxis,
42
+ inf,
43
+ array,
44
+ ones,
45
+ poly1d,
46
+ polyfit,
47
+ nanpercentile,
48
+ atleast_2d,
49
+ nan,
50
+ sqrt,
51
+ nanmedian,
52
+ nanmean,
53
+ unique,
54
+ )
34
55
  from pytransit.orbits import fold
35
- from scipy.ndimage import median_filter
36
- from scipy.signal import medfilt
37
56
 
57
+ from .binning import Binning, CompoundBinning
38
58
  from .ephemeris import Ephemeris
39
59
  from .util import bin2d
40
- from .binning import Binning, CompoundBinning
60
+
41
61
 
42
62
  class TSData:
43
63
  """
@@ -45,9 +65,10 @@ class TSData:
45
65
  fluxes, and errors. It provides methods for manipulating and analyzing the data.
46
66
  """
47
67
  def __init__(self, time: Sequence, wavelength: Sequence, fluxes: Sequence, errors: Sequence, name: str,
48
- noise_group: str = 'a', wl_edges : Sequence | None = None, tm_edges : Sequence | None = None,
68
+ noise_group: int = 0, wl_edges : Sequence | None = None, tm_edges : Sequence | None = None,
49
69
  transit_mask: ndarray | None = None, ephemeris: Ephemeris | None = None, n_baseline: int = 1,
50
- mask: ndarray = None, ephemeris_group: int = 0, offset_group: int = 0) -> None:
70
+ mask: ndarray = None, epoch_group: int = 0, offset_group: int = 0,
71
+ mask_nonfinite_errors: bool = True) -> None:
51
72
  """
52
73
  Parameters
53
74
  ----------
@@ -81,26 +102,42 @@ class TSData:
81
102
  if n_baseline < 1:
82
103
  raise ValueError("n_baseline must be greater than zero.")
83
104
 
84
- if ephemeris_group < 0:
85
- raise ValueError("ephemeris_group must be a non-negative integer.")
105
+ if noise_group < 0:
106
+ raise ValueError("noise_group must be a positive integer.")
107
+
108
+ if epoch_group < 0:
109
+ raise ValueError("epoch_group must be a non-negative integer.")
86
110
 
87
111
  if offset_group < 0:
88
112
  raise ValueError("offset_group must be a non-negative integer.")
89
113
 
114
+ if not all(isfinite(time)):
115
+ raise ValueError("The time array must contain only finite values.")
116
+
117
+ if not all(isfinite(wavelength)):
118
+ raise ValueError("The wavelength array must contain only finite values.")
119
+
90
120
  self.name: str = name
121
+ self.mask_nonfinite_errors: bool = mask_nonfinite_errors
91
122
  self.time: ndarray = time.copy()
92
123
  self.wavelength: ndarray = wavelength
93
- self.mask: ndarray = mask if mask is not None else isfinite(fluxes) & isfinite(errors)
124
+ self.mask: ndarray = mask if mask is not None else isfinite(fluxes)
125
+ if self.mask_nonfinite_errors:
126
+ self.mask &= isfinite(errors)
94
127
  self.fluxes: ndarray = where(self.mask, fluxes, nan)
95
128
  self.errors: ndarray = where(self.mask, errors, nan)
96
129
  self.transit_mask: ndarray = transit_mask if transit_mask is not None else ones(time.size, dtype=bool)
97
- self.ngid: int = 0
98
- self.ephemeris: Ephemeris | None = ephemeris
130
+ self._ephemeris: Ephemeris | None = ephemeris
99
131
  self.n_baseline: int = n_baseline
100
- self._noise_group: str = noise_group
101
- self.ephemeris_group: int = ephemeris_group
132
+ self.noise_group: int = noise_group
133
+ self.epoch_group: int = epoch_group
102
134
  self.offset_group: int = offset_group
103
- self._dataset: Optional['TSDataSet'] = None
135
+ self._dataset: Optional['TSDataGroup'] = None
136
+ self.minwl: float = 0.0
137
+ self.maxwl: float = inf
138
+ self.mintm: float = 0.0
139
+ self.maxtm: float = inf
140
+
104
141
  self._update()
105
142
 
106
143
  if wl_edges is None:
@@ -137,7 +174,7 @@ class TSData:
137
174
  mask = pf.ImageHDU(self.mask.astype(int), name=f'mask_{self.name}')
138
175
  data.header['ngroup'] = self.noise_group
139
176
  data.header['nbasel'] = self.n_baseline
140
- data.header['epgroup'] = self.ephemeris_group
177
+ data.header['epgroup'] = self.epoch_group
141
178
  data.header['offgroup'] = self.offset_group
142
179
  #TODO: export ephemeris
143
180
  return pf.HDUList([time, wave, data, ootm, mask])
@@ -162,9 +199,21 @@ class TSData:
162
199
  data = hdul[f'DATA_{name}'].data.astype('d')
163
200
  ootm = hdul[f'OOTM_{name}'].data.astype(bool)
164
201
  mask = hdul[f'MASK_{name}'].data.astype(bool)
165
- noise_group = hdul[f'DATA_{name}'].header['NGROUP']
166
- ephemeris_group = hdul[f'DATA_{name}'].header['EPGROUP']
167
- offset_group = hdul[f'DATA_{name}'].header['OFFGROUP']
202
+
203
+ try:
204
+ noise_group = hdul[f'DATA_{name}'].header['NGROUP']
205
+ except KeyError:
206
+ noise_group = 0
207
+
208
+ try:
209
+ ephemeris_group = hdul[f'DATA_{name}'].header['EPGROUP']
210
+ except KeyError:
211
+ ephemeris_group = 0
212
+
213
+ try:
214
+ offset_group = hdul[f'DATA_{name}'].header['OFFGROUP']
215
+ except KeyError:
216
+ offset_group = 0
168
217
 
169
218
  try:
170
219
  n_baseline = hdul[f'DATA_{name}'].header['NBASEL']
@@ -173,21 +222,29 @@ class TSData:
173
222
 
174
223
  #TODO: import ephemeris
175
224
  return TSData(time, wave, data[0], data[1], name=name, noise_group=noise_group, transit_mask=ootm,
176
- n_baseline=n_baseline, mask=mask, ephemeris_group=ephemeris_group, offset_group=offset_group)
225
+ n_baseline=n_baseline, mask=mask, epoch_group=ephemeris_group, offset_group=offset_group)
177
226
 
178
227
  def __repr__(self) -> str:
179
228
  return f"TSData Name:'{self.name}' [{self.wavelength[0]:.2f} - {self.wavelength[-1]:.2f}] nwl={self.nwl} npt={self.npt}"
180
229
 
181
230
  @property
182
- def noise_group(self) -> str:
183
- """Noise group name."""
184
- return self._noise_group
231
+ def ephemeris(self) -> Ephemeris:
232
+ """Ephemeris."""
233
+ return self._ephemeris
234
+
235
+ @ephemeris.setter
236
+ def ephemeris(self, ep: Ephemeris) -> None:
237
+ self._ephemeris = ep
238
+ self.mask_transit(ephemeris=ep)
185
239
 
186
- @noise_group.setter
187
- def noise_group(self, ng: str) -> None:
188
- self._noise_group = ng
189
- if self._dataset is not None:
190
- self._dataset._update_nids()
240
+ @property
241
+ def bbox_wl(self) -> tuple[float, float]:
242
+ """Wavelength bounds of the bounding box."""
243
+ return self.minwl, self.maxwl
244
+
245
+ @property
246
+ def bbox_tm(self) -> tuple[float, float]:
247
+ return self.mintm, self.maxtm
191
248
 
192
249
  def mask_transit(self, t0: float | None = None, p: float | None = None, t14: float | None = None,
193
250
  ephemeris : Ephemeris | None = None, elims: tuple[int, int] | None = None) -> 'TSData':
@@ -208,9 +265,9 @@ class TSData:
208
265
  """
209
266
  if (t0 and p and t14) or ephemeris is not None:
210
267
  if ephemeris is not None:
211
- self.ephemeris = ephemeris
268
+ self._ephemeris = ephemeris
212
269
  else:
213
- self.ephemeris = Ephemeris(t0, p, t14)
270
+ self._ephemeris = Ephemeris(t0, p, t14)
214
271
  phase = fold(self.time, self.ephemeris.period, self.ephemeris.zero_epoch)
215
272
  self.transit_mask = abs(phase) > 0.502 * self.ephemeris.duration
216
273
  elif elims is not None:
@@ -238,7 +295,19 @@ class TSData:
238
295
  """Update the internal attributes."""
239
296
  self.nwl = self.wavelength.size
240
297
  self.npt = self.time.size
241
- self.wllims = self.wavelength.min(), self.wavelength.max()
298
+ self.minwl = self.wavelength.min()
299
+ self.maxwl = self.wavelength.max()
300
+ self.mintm = self.time.min()
301
+ self.maxtm = self.time.max()
302
+ if self._ephemeris is not None:
303
+ self.mask_transit(ephemeris=self._ephemeris)
304
+
305
+ def _update_data_mask(self) -> None:
306
+ self.mask = isfinite(self.fluxes)
307
+ if self.mask_nonfinite_errors:
308
+ self.mask &= isfinite(self.errors)
309
+ self.fluxes = where(self.mask, self.fluxes, nan)
310
+ self.errors = where(self.mask, self.errors, nan)
242
311
 
243
312
  def normalize_to_poly(self, deg: int = 1) -> 'TSData':
244
313
  """Normalize the baseline flux for each spectroscopic light curve.
@@ -266,11 +335,15 @@ class TSData:
266
335
  "Call TSData.mask_transit(...) first.")
267
336
 
268
337
  for ipb in range(self.nwl):
269
- bl = poly1d(polyfit(self.time[self.transit_mask & self.mask[ipb]],
270
- self.fluxes[ipb, self.transit_mask & self.mask[ipb]],
271
- deg=deg))(self.time)
272
- self.fluxes[ipb, :] /= bl
273
- self.errors[ipb, :] /= bl
338
+ mask = self.transit_mask & self.mask[ipb]
339
+ if mask.sum() > 2:
340
+ bl = poly1d(polyfit(self.time[mask], self.fluxes[ipb, mask], deg=deg))(self.time)
341
+ self.fluxes[ipb, :] /= bl
342
+ self.errors[ipb, :] /= bl
343
+ else:
344
+ self.fluxes[ipb, :] = nan
345
+ self.errors[ipb, :] = nan
346
+ self._update_data_mask()
274
347
  return self
275
348
 
276
349
  def normalize_to_median(self, s: slice) -> 'TSData':
@@ -286,7 +359,7 @@ class TSData:
286
359
  self.errors[:,:] /= n
287
360
  return self
288
361
 
289
- def partition_time(self, tlims: tuple[tuple[float,float]]) -> 'TSDataSet':
362
+ def partition_time(self, tlims: tuple[tuple[float,float]]) -> 'TSDataGroup':
290
363
  """Partition the data into n segments defined by tlims.
291
364
 
292
365
  Parameters
@@ -299,20 +372,22 @@ class TSData:
299
372
  d = TSData(name=f'{self.name}_1', time=self.time[m], wavelength=self.wavelength,
300
373
  fluxes=self.fluxes[:, m], errors=self.errors[:, m], mask=self.mask[:, m],
301
374
  noise_group=self.noise_group,
302
- ephemeris_group=self.ephemeris_group,
375
+ epoch_group=self.epoch_group,
303
376
  offset_group=self.offset_group,
304
377
  transit_mask=self.transit_mask[m],
305
378
  ephemeris=self.ephemeris,
306
- n_baseline=self.n_baseline)
379
+ n_baseline=self.n_baseline,
380
+ mask_nonfinite_errors=self.mask_nonfinite_errors)
307
381
  for i, m in enumerate(masks[1:]):
308
382
  d = d + TSData(name=f'{self.name}_{i+2}', time=self.time[m], wavelength=self.wavelength,
309
383
  fluxes=self.fluxes[:, m], errors=self.errors[:, m], mask=self.mask[:, m],
310
384
  noise_group=self.noise_group,
311
- ephemeris_group=self.ephemeris_group,
385
+ epoch_group=self.epoch_group,
312
386
  offset_group=self.offset_group,
313
387
  transit_mask=self.transit_mask[m],
314
388
  ephemeris=self.ephemeris,
315
- n_baseline=self.n_baseline)
389
+ n_baseline=self.n_baseline,
390
+ mask_nonfinite_errors=self.mask_nonfinite_errors)
316
391
  return d
317
392
 
318
393
  def crop_wavelength(self, lmin: float, lmax: float, inplace: bool = True) -> 'TSData':
@@ -344,12 +419,13 @@ class TSData:
344
419
  errors=self.errors[m],
345
420
  mask=self.mask[m],
346
421
  noise_group=self.noise_group,
347
- ephemeris_group=self.ephemeris_group,
422
+ epoch_group=self.epoch_group,
348
423
  offset_group=self.offset_group,
349
424
  wl_edges=(self._wl_l_edges[m], self._wl_r_edges[m]),
350
425
  tm_edges=(self._tm_l_edges, self._tm_r_edges),
351
426
  transit_mask=self.transit_mask, ephemeris=self.ephemeris,
352
- n_baseline=self.n_baseline)
427
+ n_baseline=self.n_baseline,
428
+ mask_nonfinite_errors=self.mask_nonfinite_errors)
353
429
 
354
430
  def crop_time(self, tmin: float, tmax: float, inplace: bool = True) -> 'TSData':
355
431
  """Crop the data to include only the time range between lmin and lmax.
@@ -381,19 +457,20 @@ class TSData:
381
457
  errors=self.errors[:, m],
382
458
  mask = self.mask[:, m],
383
459
  noise_group=self.noise_group,
384
- ephemeris_group=self.ephemeris_group,
460
+ epoch_group=self.epoch_group,
385
461
  offset_group=self.offset_group,
386
462
  wl_edges=(self._wl_l_edges, self._wl_r_edges),
387
463
  tm_edges=(self._tm_l_edges[m], self._tm_r_edges[m]),
388
464
  transit_mask=self.transit_mask[m], ephemeris=self.ephemeris,
389
- n_baseline=self.n_baseline)
465
+ n_baseline=self.n_baseline,
466
+ mask_nonfinite_errors=self.mask_nonfinite_errors)
390
467
 
391
- def remove_outliers(self, sigma: float = 5.0) -> 'TSData':
392
- """Remove outliers along the wavelength axis.
468
+ # TODO: separate mask into bad data mask and outlier mask.
469
+ def mask_outliers(self, sigma: float = 5.0) -> 'TSData':
470
+ """Mask outliers along the wavelength axis.
393
471
 
394
- Replace outliers along the wavelength axis with the value of a 5-point running median filter. Outliers are
395
- defined as data points that deviate from the median by more than sigma times the median absolute deviation
396
- along the wavelength axis.
472
+ Outliers are defined as data points that deviate from the running 5-point median by more
473
+ than sigma times the median absolute deviation along the wavelength axis.
397
474
 
398
475
  Parameters
399
476
  ----------
@@ -404,13 +481,18 @@ class TSData:
404
481
  ----
405
482
  The data will be modified in place.
406
483
  """
407
- fm = median(self.fluxes, axis=0)
408
- fe = mad_std(self.fluxes, axis=0)
484
+ fm = nanmedian(self.fluxes, axis=0)
485
+ fe = mad_std(self.fluxes, axis=0, ignore_nan=True)
409
486
  self.mask &= abs(self.fluxes - fm) / fe < sigma
410
487
  self.fluxes = where(self.mask, self.fluxes, nan)
411
488
  self.errors = where(self.mask, self.errors, nan)
412
489
  return self
413
490
 
491
+ @deprecated("0.10", alternative="TSData.mask_outliers")
492
+ def remove_outliers(self, sigma: float = 5.0) -> 'TSData':
493
+ """Remove outliers along the wavelength axis."""
494
+ self.mask_outliers(sigma=sigma)
495
+
414
496
  def plot(self, ax=None, vmin: float = None, vmax: float = None, cmap=None, figsize=None, data=None,
415
497
  plims: tuple[float, float] | None = None) -> Figure:
416
498
  """Plot the spectroscopic light curves as a 2D image.
@@ -510,7 +592,7 @@ class TSData:
510
592
  fig = ax.figure
511
593
  tref = floor(self.time.min())
512
594
 
513
- ax.plot(self.time, self.fluxes.mean(0))
595
+ ax.plot(self.time, nanmean(self.fluxes, 0))
514
596
  if self.ephemeris is not None:
515
597
  [ax.axvline(tl, ls='--', c='k') for tl in self.ephemeris.transit_limits(self.time.mean())]
516
598
 
@@ -547,7 +629,7 @@ class TSData:
547
629
  """
548
630
  return self.plot(ax=ax, figsize=figsize, data=where(self.transit_mask, self.fluxes, nan))
549
631
 
550
- def __add__(self, other: Union['TSData', 'TSDataSet']) -> 'TSDataSet':
632
+ def __add__(self, other: Union['TSData', 'TSDataGroup']) -> 'TSDataGroup':
551
633
  """Combine two transmission spectra along the wavelength axis.
552
634
 
553
635
  Parameters
@@ -557,12 +639,12 @@ class TSData:
557
639
 
558
640
  Returns
559
641
  -------
560
- TSDataSet
642
+ TSDataGroup
561
643
  """
562
644
  if isinstance(other, TSData):
563
- return TSDataSet([self, other])
645
+ return TSDataGroup([self, other])
564
646
  else:
565
- return TSDataSet([self]) + other
647
+ return TSDataGroup([self]) + other
566
648
 
567
649
  def bin_wavelength(self, binning: Optional[Union[Binning, CompoundBinning]] = None,
568
650
  nb: Optional[int] = None, bw: Optional[float] = None, r: Optional[float] = None,
@@ -592,7 +674,7 @@ class TSData:
592
674
  with warnings.catch_warnings():
593
675
  warnings.simplefilter('ignore', numba.NumbaPerformanceWarning)
594
676
  if binning is None:
595
- binning = Binning(self.wllims[0], self.wllims[1], nb=nb, bw=bw, r=r)
677
+ binning = Binning(self.bbox_wl[0], self.bbox_wl[1], nb=nb, bw=bw, r=r)
596
678
  bf, be = bin2d(self.fluxes, self.errors, self._wl_l_edges, self._wl_r_edges,
597
679
  binning.bins, estimate_errors=estimate_errors)
598
680
  if not all(isfinite(be)):
@@ -602,7 +684,7 @@ class TSData:
602
684
  name=self.name,
603
685
  tm_edges=(self._tm_l_edges, self._tm_r_edges),
604
686
  noise_group=self.noise_group,
605
- ephemeris_group=self.ephemeris_group,
687
+ epoch_group=self.epoch_group,
606
688
  offset_group=self.offset_group,
607
689
  transit_mask=self.transit_mask,
608
690
  ephemeris=self.ephemeris,
@@ -644,14 +726,14 @@ class TSData:
644
726
  noise_group=self.noise_group,
645
727
  ephemeris=self.ephemeris,
646
728
  n_baseline=self.n_baseline,
647
- ephemeris_group=self.ephemeris_group,
729
+ epoch_group=self.epoch_group,
648
730
  offset_group=self.offset_group)
649
731
  if self.ephemeris is not None:
650
732
  d.mask_transit(ephemeris=self.ephemeris)
651
733
  return d
652
734
 
653
- class TSDataSet:
654
- """`TSDataSet` is a high-level data storage class that can contain multiple `TSData` objects.
735
+ class TSDataGroup:
736
+ """`TSDataGroup` is a high-level data storage class that can contain multiple `TSData` objects.
655
737
  """
656
738
  def __init__(self, data: Sequence[TSData]):
657
739
  self.data: list[TSData] = []
@@ -659,7 +741,7 @@ class TSDataSet:
659
741
  self.wlmax: float = -inf
660
742
  self.tmin: float = inf
661
743
  self.tmax: float = -inf
662
- self.ngids: ndarray = array([])
744
+ self._noise_groups: array | None = None
663
745
  for d in data:
664
746
  self._add_data(d)
665
747
 
@@ -668,19 +750,12 @@ class TSDataSet:
668
750
  raise ValueError('A TSData object with the same name already exists.')
669
751
  d._dataset = self
670
752
  self.data.append(d)
671
- self._update_nids()
753
+ self._noise_groups = array([d.noise_group for d in self.data])
672
754
  self.wlmin = min(self.wlmin, d.wavelength.min())
673
755
  self.wlmax = max(self.wlmax, d.wavelength.max())
674
756
  self.tmin = min(self.tmin, d.time.min())
675
757
  self.tmax = max(self.tmax, d.time.max())
676
758
 
677
- def _update_nids(self):
678
- ngs = pd.Categorical(self.noise_groups)
679
- self.unique_noise_groups = list(ngs.categories)
680
- self.ngids = ngs.codes.astype(int)
681
- for i,d in enumerate(self.data):
682
- d.ngid = self.ngids[i]
683
-
684
759
  @property
685
760
  def names(self) -> list[str]:
686
761
  """List of data set names."""
@@ -707,20 +782,25 @@ class TSDataSet:
707
782
  return [d.errors for d in self.data]
708
783
 
709
784
  @property
710
- def noise_groups(self) -> list[str]:
711
- """List of noise group names."""
712
- return [d.noise_group for d in self.data]
785
+ def noise_groups(self) -> ndarray[int]:
786
+ """Array of noise groups."""
787
+ return self._noise_groups
713
788
 
714
789
  @property
715
790
  def n_noise_groups(self) -> int:
716
791
  """Number of noise groups."""
717
- return len(set(self.noise_groups))
792
+ return len(unique(self.noise_groups))
718
793
 
719
794
  @property
720
795
  def offset_groups(self) -> list[int]:
721
796
  """List of offset groups."""
722
797
  return [d.offset_group for d in self.data]
723
798
 
799
+ @property
800
+ def epoch_groups(self) -> list[int]:
801
+ """List of epoch groups."""
802
+ return [d.epoch_group for d in self.data]
803
+
724
804
  @property
725
805
  def n_baselines(self) -> list[int]:
726
806
  """Number of baseline coefficients for each data set."""
@@ -749,7 +829,7 @@ class TSDataSet:
749
829
  return hdul
750
830
 
751
831
  @staticmethod
752
- def import_fits(hdul: pf.HDUList) -> 'TSDataSet':
832
+ def import_fits(hdul: pf.HDUList) -> 'TSDataGroup':
753
833
  """Import all the data from a FITS HDU list.
754
834
 
755
835
  Parameters
@@ -759,14 +839,14 @@ class TSDataSet:
759
839
 
760
840
  Returns
761
841
  -------
762
- TSDataSet
842
+ TSDataGroup
763
843
  """
764
844
  ds = hdul['DATASET']
765
845
  data = []
766
846
  for i in range(ds.header['NDATA']):
767
847
  name = ds.header[f'NAME_{i}']
768
848
  data.append(TSData.import_fits(name, hdul))
769
- return TSDataSet(data)
849
+ return TSDataGroup(data)
770
850
 
771
851
  def mask_transit(self, tc: float, p: float, t14: float):
772
852
  for d in self.data:
@@ -779,7 +859,7 @@ class TSDataSet:
779
859
  return self.size
780
860
 
781
861
  def __repr__(self):
782
- return f"TSDataSet with {self.size} groups"
862
+ return f"TSDataGroup with {self.size} groups"
783
863
 
784
864
  def plot(self, axs=None, vmin: float = None, vmax: float = None, ncols: int = 1, cmap=None, figsize=None, data: ndarray | None = None) -> Figure:
785
865
  """Plot all the data sets.
@@ -849,6 +929,10 @@ class TSDataSet:
849
929
 
850
930
  def __add__(self, other):
851
931
  if isinstance(other, TSData):
852
- return TSDataSet(self.data + [other])
853
- elif isinstance(other, TSDataSet):
854
- return TSDataSet(self.data + other.data)
932
+ return TSDataGroup(self.data + [other])
933
+ elif isinstance(other, TSDataGroup):
934
+ return TSDataGroup(self.data + other.data)
935
+
936
+
937
+ class TSDataSet(TSDataGroup):
938
+ pass