plotastrodata 1.8.11.tar.gz → 1.8.12.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. {plotastrodata-1.8.11/plotastrodata.egg-info → plotastrodata-1.8.12}/PKG-INFO +1 -1
  2. {plotastrodata-1.8.11 → plotastrodata-1.8.12}/plotastrodata/__init__.py +1 -1
  3. {plotastrodata-1.8.11 → plotastrodata-1.8.12}/plotastrodata/analysis_utils.py +17 -38
  4. {plotastrodata-1.8.11 → plotastrodata-1.8.12}/plotastrodata/fits_utils.py +21 -31
  5. {plotastrodata-1.8.11 → plotastrodata-1.8.12}/plotastrodata/fitting_utils.py +49 -0
  6. {plotastrodata-1.8.11 → plotastrodata-1.8.12}/plotastrodata/other_utils.py +35 -45
  7. {plotastrodata-1.8.11 → plotastrodata-1.8.12}/plotastrodata/plot_utils.py +130 -138
  8. {plotastrodata-1.8.11 → plotastrodata-1.8.12/plotastrodata.egg-info}/PKG-INFO +1 -1
  9. {plotastrodata-1.8.11 → plotastrodata-1.8.12}/LICENSE +0 -0
  10. {plotastrodata-1.8.11 → plotastrodata-1.8.12}/MANIFEST.in +0 -0
  11. {plotastrodata-1.8.11 → plotastrodata-1.8.12}/README.md +0 -0
  12. {plotastrodata-1.8.11 → plotastrodata-1.8.12}/plotastrodata/const_utils.py +0 -0
  13. {plotastrodata-1.8.11 → plotastrodata-1.8.12}/plotastrodata/coord_utils.py +0 -0
  14. {plotastrodata-1.8.11 → plotastrodata-1.8.12}/plotastrodata/ext_utils.py +0 -0
  15. {plotastrodata-1.8.11 → plotastrodata-1.8.12}/plotastrodata/fft_utils.py +0 -0
  16. {plotastrodata-1.8.11 → plotastrodata-1.8.12}/plotastrodata/los_utils.py +0 -0
  17. {plotastrodata-1.8.11 → plotastrodata-1.8.12}/plotastrodata/matrix_utils.py +0 -0
  18. {plotastrodata-1.8.11 → plotastrodata-1.8.12}/plotastrodata/noise_utils.py +0 -0
  19. {plotastrodata-1.8.11 → plotastrodata-1.8.12}/plotastrodata.egg-info/SOURCES.txt +0 -0
  20. {plotastrodata-1.8.11 → plotastrodata-1.8.12}/plotastrodata.egg-info/dependency_links.txt +0 -0
  21. {plotastrodata-1.8.11 → plotastrodata-1.8.12}/plotastrodata.egg-info/not-zip-safe +0 -0
  22. {plotastrodata-1.8.11 → plotastrodata-1.8.12}/plotastrodata.egg-info/requires.txt +0 -0
  23. {plotastrodata-1.8.11 → plotastrodata-1.8.12}/plotastrodata.egg-info/top_level.txt +0 -0
  24. {plotastrodata-1.8.11 → plotastrodata-1.8.12}/setup.cfg +0 -0
  25. {plotastrodata-1.8.11 → plotastrodata-1.8.12}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: plotastrodata
3
- Version: 1.8.11
3
+ Version: 1.8.12
4
4
  Summary: plotastrodata is a tool for astronomers to create figures from FITS files and perform fundamental data analyses with ease.
5
5
  Home-page: https://github.com/yusukeaso-astron/plotastrodata
6
6
  Download-URL: https://github.com/yusukeaso-astron/plotastrodata
@@ -1,4 +1,4 @@
1
1
  import warnings
2
2
 
3
3
  warnings.simplefilter('ignore', FutureWarning)
4
- __version__ = '1.8.11'
4
+ __version__ = '1.8.12'
@@ -8,7 +8,7 @@ from scipy.signal import convolve
8
8
  from plotastrodata import const_utils as cu
9
9
  from plotastrodata.coord_utils import coord2xy, rel2abs, xy2coord
10
10
  from plotastrodata.fits_utils import data2fits, FitsData, Jy2K
11
- from plotastrodata.fitting_utils import EmceeCorner
11
+ from plotastrodata.fitting_utils import EmceeCorner, gaussfit1d
12
12
  from plotastrodata.matrix_utils import dot2d, Mfac, Mrot
13
13
  from plotastrodata.noise_utils import estimate_rms
14
14
  from plotastrodata.other_utils import (gaussian2d, isdeg,
@@ -442,7 +442,6 @@ class AstroData():
442
442
  xlist, ylist = coord2xy(coords, self.center) * 3600.
443
443
  nprof = len(xlist)
444
444
  v = self.v
445
- dv = self.dv
446
445
  data, xf, yf = filled2d(self.data, self.x, self.y, ninterp)
447
446
  x, y = np.meshgrid(xf, yf)
448
447
  prof = np.empty((nprof, len(v)))
@@ -469,22 +468,12 @@ class AstroData():
469
468
  prof *= np.abs(self.dx * self.dy) / Omega
470
469
  gfitres = {}
471
470
  if gaussfit:
472
- xmin, xmax = np.min(v), np.max(v)
473
- ymin, ymax = np.min(prof), np.max(prof)
474
- bounds = [[ymin, xmin, np.abs(dv)], [ymax, xmax, xmax - xmin]]
475
-
476
- def gauss(x, p, c, w):
477
- return p * np.exp(-4. * np.log(2.) * ((x - c) / w)**2)
478
-
479
- nprof = len(prof)
480
- best, error = [None] * nprof, [None] * nprof
471
+ res = [None] * nprof
481
472
  for i in range(nprof):
482
- popt, pcov = curve_fit(gauss, v, prof[i], bounds=bounds)
483
- perr = np.sqrt(np.diag(pcov))
484
- print('Gauss (peak, center, FWHM):', popt)
485
- print('Gauss uncertainties:', perr)
486
- best[i], error[i] = popt, perr
487
- gfitres = {'best': best, 'error': error}
473
+ res[i] = gaussfit1d(xdata=v, ydata=prof[i],
474
+ sigma=None, show=True)
475
+ gfitres['best'] = [a['popt'][:3] for a in res]
476
+ gfitres['error'] = [a['perr'][:3] for a in res]
488
477
  return v, prof, gfitres
489
478
 
490
479
  def rotate(self, pa: float = 0, **kwargs):
@@ -642,34 +631,24 @@ class AstroFrame():
642
631
  def __post_init__(self):
643
632
  self.xdir = -1 if self.xflip else 1
644
633
  self.ydir = -1 if self.yflip else 1
645
- if self.xmax is None:
646
- self.xmax = self.rmax
647
- if self.xmin is None:
648
- self.xmin = -self.rmax
649
- if self.ymax is None:
650
- self.ymax = self.rmax
651
- if self.ymin is None:
652
- self.ymin = -self.rmax
653
- if self.xdir == -1:
634
+ self.xmin = -self.rmax if self.xmin is None else self.xmin
635
+ self.xmax = self.rmax if self.xmax is None else self.xmax
636
+ self.ymin = -self.rmax if self.ymin is None else self.ymin
637
+ self.ymax = self.rmax if self.ymax is None else self.ymax
638
+ if self.xflip:
654
639
  self.xmin, self.xmax = self.xmax, self.xmin
655
- if self.ydir == -1:
640
+ if self.yflip:
656
641
  self.ymin, self.ymax = self.ymax, self.ymin
657
642
  xlim = [self.xoff + self.xmin, self.xoff + self.xmax]
658
643
  ylim = [self.yoff + self.ymin, self.yoff + self.ymax]
659
644
  vlim = [self.vmin, self.vmax]
660
645
  if self.pv:
661
- xlim = np.sort(xlim)
646
+ xlim = sorted(xlim)
662
647
  if not self.xflip:
663
- xlim = xlim[::-1]
664
- self.xlim = xlim
665
- self.ylim = ylim
666
- self.vlim = vlim
667
- if self.pv:
668
- self.Xlim = vlim if self.swapxy else xlim
669
- self.Ylim = xlim if self.swapxy else vlim
670
- else:
671
- self.Xlim = ylim if self.swapxy else xlim
672
- self.Ylim = xlim if self.swapxy else ylim
648
+ xlim.reverse()
649
+ self.xlim, self.ylim, self.vlim = xlim, ylim, vlim
650
+ _x, _y = (xlim, vlim) if self.pv else (xlim, ylim)
651
+ self.Xlim, self.Ylim = (_y, _x) if self.swapxy else (_x, _y)
673
652
  if self.quadrants is not None:
674
653
  self.Xlim = [0, self.rmax]
675
654
  self.Ylim = [0, min(self.vmax, -self.vmin)]
@@ -398,49 +398,39 @@ def fits2data(fitsimage: str, Tb: bool = False, log: bool = False,
398
398
  return fd.data, (fd.x, fd.y, fd.v), beam, bunit, rms
399
399
 
400
400
 
401
- def data2fits(d: np.ndarray | None = None, h: dict = {},
401
+ def data2fits(d: np.ndarray, h: dict = {},
402
402
  templatefits: str | None = None,
403
403
  fitsimage: str = 'test') -> None:
404
404
  """Make a fits file from a N-D array.
405
405
 
406
406
  Args:
407
- d (np.ndarray, optional): N-D array. Defaults to None.
408
- h (dict, optional): Fits header. Defaults to {}.
409
- templatefits (str, optional): Fits file to copy header. Defaults to None.
410
- fitsimage (str, optional): Output name. Defaults to 'test'.
407
+ d (np.ndarray): N-D array.
408
+ h (dict, optional): Additional FITS header. Defaults to {}.
409
+ templatefits (str, optional): FITS file whose header is used as a temperate. Defaults to None.
410
+ fitsimage (str, optional): Output filename, with or without '.fits'. Defaults to 'test'.
411
411
  """
412
- naxis = np.ndim(d)
413
- w = wcs.WCS(naxis=naxis)
414
412
  _h = {} if templatefits is None else FitsData(templatefits).get_header()
415
413
  _h.update(h)
414
+ naxis = np.ndim(d)
415
+ w = wcs.WCS(naxis=naxis)
416
416
  if _h == {}:
417
417
  w.wcs.crpix = [0] * naxis
418
418
  w.wcs.crval = [0] * naxis
419
419
  w.wcs.cdelt = [1] * naxis
420
- ctype = ['RA---SIN', 'DEC--SIN', 'FREQ']
421
- if 'CTYPE1' in _h:
422
- ctype[0] = _h['CTYPE1']
423
- if 'CTYPE2' in _h:
424
- ctype[1] = _h['CTYPE2']
425
- if 'CTYPE3' in _h:
426
- ctype[2] = _h['CTYPE3']
427
- w.wcs.ctype = ctype[:naxis]
428
- cunit = ['deg', 'deg', 'Hz']
429
- if 'CUNIT1' in _h:
430
- cunit[0] = _h['CUNIT1']
431
- if 'CUNIT2' in _h:
432
- cunit[1] = _h['CUNIT2']
433
- if 'CUNIT3' in _h:
434
- cunit[2] = _h['CUNIT3']
435
- w.wcs.cunit = cunit[:naxis]
420
+ defaults = {'CTYPE': ['RA---SIN', 'DEC--SIN', 'FREQ'],
421
+ 'CUNIT': ['deg', 'deg', 'Hz']}
422
+ for k, v in defaults.items():
423
+ for i in range(naxis):
424
+ _h.setdefault(f'{k}{i+1:d}', v[i])
425
+ w.wcs.ctype = [_h[f'CTYPE{i+1}'] for i in range(naxis)]
426
+ w.wcs.cunit = [_h[f'CUNIT{i+1}'] for i in range(naxis)]
427
+ _h.setdefault('BUNIT', 'Jy/beam')
428
+ if naxis >= 3:
429
+ _h.setdefault('SPECSYS', 'LSRK')
436
430
  header = w.to_header()
437
431
  hdu = fits.PrimaryHDU(d, header=header)
438
- if 'BUNIT' not in _h:
439
- _h['BUNIT'] = 'Jy/beam'
440
- if naxis >= 3 and 'SPECSYS' not in _h:
441
- _h['SPECSYS'] = 'LSRK'
442
- for k in _h:
443
- if not ('COMMENT' in k or 'HISTORY' in k) and _h[k] is not None:
444
- hdu.header[k] = _h[k]
432
+ for k, v in _h.items():
433
+ if v is not None and 'COMMENT' not in k and 'HISTORY' not in k:
434
+ hdu.header[k] = v
445
435
  hdu = fits.HDUList([hdu])
446
- hdu.writeto(fitsimage.replace('.fits', '') + '.fits', overwrite=True)
436
+ hdu.writeto(fitsimage.removesuffix('.fits') + '.fits', overwrite=True)
@@ -393,3 +393,52 @@ class EmceeCorner():
393
393
  error = evidence * results.logzerr[-1]
394
394
  self.evidence = evidence
395
395
  return {'evidence': evidence, 'error': error}
396
+
397
+
398
+ def gaussfit1d(xdata: np.ndarray, ydata: np.ndarray,
399
+ sigma: float | np.ndarray | None,
400
+ show: bool = False, **kwargs) -> dict:
401
+ """Gaussian fitting to a pair of 1D arrays.
402
+
403
+ Args:
404
+ xdata (np.ndarray): ydata is compared with Gauss(xdata).
405
+ ydata (np.ndarray): ydata is compared with Gauss(xdata).
406
+ sigma (float | np.ndarray | None): Noise level of ydata. If None is given, sigma is also a free parameter. Defaults to None.
407
+ show (bool, optional): True means to show the best-fit parameters and uncertainties. Defaults to False.
408
+
409
+ Returns:
410
+ dict: _description_
411
+ """
412
+ if sigma is not None and np.shape(sigma) == ():
413
+ sigma = [sigma] * len(xdata)
414
+ xmin, xmax = np.min(xdata), np.max(xdata)
415
+ ymin, ymax = np.min(ydata), np.max(ydata)
416
+ dx = np.abs(xdata[1] - xdata[0])
417
+ bounds = [[ymin, ymax], [xmin, xmax], [dx, xmax - xmin]]
418
+ if sigma is None:
419
+ bounds.append([np.log(ymax * 1e-6), np.log(ymax)])
420
+
421
+ def g(x, p):
422
+ a, c, w = p
423
+ return a * np.exp(-4. * np.log(2.) * ((x - c) / w)**2)
424
+
425
+ if sigma is None:
426
+ def logl(p):
427
+ sigmdl = np.exp(p[3])
428
+ chi2 = np.sum(((ydata - g(xdata, p[:3])) / sigmdl)**2)
429
+ return -0.5 * chi2 - p[3]
430
+ else:
431
+ def logl(p):
432
+ chi2 = np.sum(((ydata - g(xdata, p)) / sigma)**2)
433
+ return -0.5 * chi2
434
+
435
+ fitter = EmceeCorner(bounds=bounds, logl=logl)
436
+ fitter.fit(**kwargs)
437
+ popt = fitter.popt
438
+ plow = fitter.plow
439
+ phigh = fitter.phigh
440
+ perr = (phigh - plow) / 2
441
+ if show:
442
+ print('Gauss (peak, center, FWHM):', popt)
443
+ print('Gauss uncertainties:', perr)
444
+ return {'popt': popt, 'perr': perr}
@@ -37,11 +37,11 @@ def isdeg(s: str) -> bool:
37
37
 
38
38
  def trim(data: np.ndarray | None = None, x: np.ndarray | None = None,
39
39
  y: np.ndarray | None = None, v: np.ndarray | None = None,
40
- xlim: list[float, float] | None = None,
41
- ylim: list[float, float] | None = None,
42
- vlim: list[float, float] | None = None,
40
+ xlim: list[float] | None = None,
41
+ ylim: list[float] | None = None,
42
+ vlim: list[float] | None = None,
43
43
  pv: bool = False
44
- ) -> tuple[np.ndarray, list[np.ndarray, np.ndarray, np.ndarray]]:
44
+ ) -> tuple[np.ndarray | None, list[np.ndarray | None]]:
45
45
  """Trim 2D or 3D data by given coordinates and their limits.
46
46
 
47
47
  Args:
@@ -56,47 +56,37 @@ def trim(data: np.ndarray | None = None, x: np.ndarray | None = None,
56
56
  Returns:
57
57
  tuple: Trimmed (data, [x,y,v]).
58
58
  """
59
- xout, yout, vout, dataout = x, y, v, data
60
- i0 = j0 = k0 = 0
61
- i1 = j1 = k1 = 100000
62
- if x is not None and xlim is not None:
63
- if None not in xlim:
64
- x0 = np.max([np.min(x), xlim[0]])
65
- x1 = np.min([np.max(x), xlim[1]])
66
- i0 = np.argmin(np.abs(x - x0))
67
- i1 = np.argmin(np.abs(x - x1))
68
- i0, i1 = sorted([i0, i1])
69
- xout = x[i0:i1+1]
70
- if y is not None and ylim is not None:
71
- if None not in ylim:
72
- y0 = np.max([np.min(y), ylim[0]])
73
- y1 = np.min([np.max(y), ylim[1]])
74
- j0 = np.argmin(np.abs(y - y0))
75
- j1 = np.argmin(np.abs(y - y1))
76
- j0, j1 = sorted([j0, j1])
77
- yout = y[j0:j1+1]
78
- if v is not None and vlim is not None:
79
- if None not in vlim:
80
- v0 = np.max([np.min(v), vlim[0]])
81
- v1 = np.min([np.max(v), vlim[1]])
82
- k0 = np.argmin(np.abs(v - v0))
83
- k1 = np.argmin(np.abs(v - v1))
84
- k0, k1 = sorted([k0, k1])
85
- vout = v[k0:k1+1]
86
- if data is not None:
87
- d = np.squeeze(data)
88
- if np.ndim(d) == 0:
89
- print('data has only one pixel.')
90
- d = data
91
- if np.ndim(d) == 2:
92
- if pv:
93
- j0, j1 = k0, k1
94
- dataout = d[j0:j1+1, i0:i1+1]
95
- else:
96
- d = np.moveaxis(d, [-3, -2, -1], [0, 1, 2])
97
- d = d[k0:k1+1, j0:j1+1, i0:i1+1]
98
- d = np.moveaxis(d, [0, 1, 2], [-3, -2, -1])
99
- dataout = d
59
+ def get_bounds(arr, lim):
60
+ if arr is None or lim is None or None in lim:
61
+ return arr, 0, None
62
+ lo = np.argmin(np.abs(arr - max(np.min(arr), lim[0])))
63
+ hi = np.argmin(np.abs(arr - min(np.max(arr), lim[1])))
64
+ lo, hi = sorted((lo, hi))
65
+ return arr[lo:hi + 1], lo, hi + 1
66
+
67
+ xout, i0, i1 = get_bounds(x, xlim)
68
+ yout, j0, j1 = get_bounds(y, ylim)
69
+ vout, k0, k1 = get_bounds(v, vlim)
70
+
71
+ if data is None:
72
+ return None, [xout, yout, vout]
73
+
74
+ d = np.squeeze(data)
75
+
76
+ if d.ndim == 0:
77
+ print("data has only one pixel.")
78
+ return data, [xout, yout, vout]
79
+
80
+ if d.ndim == 2:
81
+ if pv:
82
+ j0, j1 = k0, k1
83
+ dataout = d[j0:j1, i0:i1]
84
+ else:
85
+ d = np.moveaxis(d, [-3, -2, -1], [0, 1, 2])
86
+ d = d[k0:k1, j0:j1, i0:i1]
87
+ d = np.moveaxis(d, [0, 1, 2], [-3, -2, -1])
88
+ dataout = d
89
+
100
90
  return dataout, [xout, yout, vout]
101
91
 
102
92
 
@@ -96,56 +96,110 @@ def logcbticks(vmin: float = 1e-3, vmax: float = 1e3
96
96
  return ticks[cond], ticklabels[cond]
97
97
 
98
98
 
99
- def do_stretch(x: list | np.ndarray,
100
- stretch: str, stretchscale: float,
101
- stretchpower: float) -> np.ndarray:
99
+ @dataclass
100
+ class Stretcher():
102
101
  """Get the stretched values.
103
102
 
104
- Args:
105
- x (list | np.ndarray): Input array in the linear scale.
106
- stretch (str): 'log', 'asinh', 'power', or 'linear'. Any other means 'linear'. 'log' means the mapped data are logarithmic. 'asinh' means the mapped data are arc sin hyperbolic. 'power' means the mapped data are power-law (see also stretchpower). Defaults to 'linear'.
107
- stretchscale (float): The output is asinh(data / stretchscale).
108
- stretchpower (float): The output is data**stretchpower / stretchpower. 1 means the linear scale, while 0 means the logarithmic scale.
109
-
110
- Returns:
111
- np.ndarray: Output stretched array.
103
+ Args:
104
+ stretch (str, optional): 'log', 'asinh', 'power', or 'linear'. Any other means 'linear'. 'log' means the mapped data are logarithmic. 'asinh' means the mapped data are arc sin hyperbolic. 'power' means the mapped data are power-law (see also stretchpower). Defaults to 'linear'.
105
+ stretchscale (float, optional): The output is asinh(data / stretchscale). Defaults to None.
106
+ stretchpower (float, optional): The output is data**stretchpower / stretchpower. 1 means the linear scale, while 0 means the logarithmic scale. Defaults to 0.5.
107
+ vmin (float, optional): The minimum value for Axes.pcolormesh() of matplotlib. Defaults to None.
108
+ vmax (float, optional): The maximum value for Axes.pcolormesh() of matplotlib. Defaults to None.
109
+ sigma (float, optional): Noise level. Defaults to 0.
112
110
  """
113
- t = np.array(x)
114
- match stretch:
115
- case 'log':
116
- t = np.log10(t) # To be consistent with logcbticks().
117
- case 'asinh':
118
- t = np.arcsinh(t / stretchscale)
119
- case 'power':
120
- p = 1e-6 if stretchpower == 0 else stretchpower
121
- t = t**p / p
122
- return t
123
-
124
-
125
- def undo_stretch(x: list | np.ndarray,
126
- stretch: str, stretchscale: float,
127
- stretchpower: float) -> np.ndarray:
128
- """Get the linear values from the stretched values.
111
+ stretch: str = 'linear'
112
+ stretchscale: float | None = None
113
+ stretchpower: float = 0.5
114
+ vmin: float | None = None
115
+ vmax: float | None = None
116
+ sigma: float = 0
117
+
118
+ def __post_init__(self):
119
+ self.n = 1 if type(self.stretch) is str else len(self.stretch)
120
+ stretch = self.stretch
121
+ stsc = self.stretchscale
122
+ vmin = self.vmin
123
+ sigma = self.sigma
124
+ if self.n == 1:
125
+ if stsc is None:
126
+ self.stretchscale = sigma
127
+ if (stretch == 'log' or stretch == 'power') and vmin is None:
128
+ self.vmin = sigma
129
+ else:
130
+ getsigma = np.equal(stsc, None)
131
+ self.stretchscale = np.where(getsigma, sigma, stsc)
132
+ islog = np.equal(stretch, 'log')
133
+ ispower = np.equal(stretch, 'power')
134
+ novmin = np.equal(vmin, None)
135
+ getsigma = (islog + ispower) * novmin
136
+ self.vmin = np.where(getsigma, sigma, vmin)
129
137
 
130
- Args:
131
- x (list | np.ndarray): Input stretched array.
132
- stretch (str): 'log', 'asinh', 'power', or 'linear'. Any other means 'linear'. 'log' means the mapped data are logarithmic. 'asinh' means the mapped data are arc sin hyperbolic. 'power' means the mapped data are power-law (see also stretchpower). Defaults to 'linear'.
133
- stretchscale (float): The input is asinh(data / stretchscale).
134
- stretchpower (float): The input is data**stretchpower / stretchpower. 1 means the linear scale, while 0 means the logarithmic scale.
138
+ def do(self, x: list | np.ndarray) -> np.ndarray:
139
+ """Get the stretched values.
135
140
 
136
- Returns:
137
- np.ndarray: Output array in the linear scale.
138
- """
139
- t = np.array(x)
140
- match stretch:
141
- case 'log':
142
- t = 10**t # To be consistent with logcbticks().
143
- case 'asinh':
144
- t = np.sinh(t) * stretchscale
145
- case 'power':
146
- p = 1e-6 if stretchpower == 0 else stretchpower
147
- t = (t * p)**(1 / p)
148
- return t
141
+ Args:
142
+ x (list | np.ndarray): Input array in the linear scale.
143
+
144
+ Returns:
145
+ np.ndarray: Output stretched array.
146
+ """
147
+ t = np.array(x)
148
+ match self.stretch:
149
+ case 'log':
150
+ t = np.log10(t) # To be consistent with logcbticks().
151
+ case 'asinh':
152
+ t = np.arcsinh(t / self.stretchscale)
153
+ case 'power':
154
+ p = 1e-6 if self.stretchpower == 0 else self.stretchpower
155
+ t = t**p / p
156
+ return t
157
+
158
+ def undo(self, x: list | np.ndarray) -> np.ndarray:
159
+ """Get the linear values from the stretched values.
160
+
161
+ Args:
162
+ x (list | np.ndarray): Input stretched array.
163
+
164
+ Returns:
165
+ np.ndarray: Output array in the linear scale.
166
+ """
167
+ t = np.array(x)
168
+ match self.stretch:
169
+ case 'log':
170
+ t = 10**t # To be consistent with logcbticks().
171
+ case 'asinh':
172
+ t = np.sinh(t) * self.stretchscale
173
+ case 'power':
174
+ p = 1e-6 if self.stretchpower == 0 else self.stretchpower
175
+ t = (t * p)**(1 / p)
176
+ return t
177
+
178
+ def set_minmax(self, data: np.ndarray
179
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
180
+ """Set vmin and vmax for color pcolormesh and RGB maps.
181
+
182
+ Args:
183
+ data (np.ndarray): 2D/3D data to plot.
184
+
185
+ Returns:
186
+ tuple[np.ndarray, np.ndarray, np.ndarray]: (Clipped stretched data, new vmin, new vmax).
187
+ """
188
+ single = self.n == 1
189
+ vminout = [self.vmin] if single else self.vmin
190
+ vmaxout = [self.vmax] if single else self.vmax
191
+ dataout = [data] if single else data
192
+ for i, (c, v0, v1) in enumerate(zip(dataout, vminout, vmaxout)):
193
+ dataout[i] = cout = self.do(c.clip(v0, v1))
194
+ vminout[i] = np.nanmin(cout)
195
+ vmaxout[i] = np.nanmax(cout)
196
+ if single:
197
+ dataout = dataout[0]
198
+ vminout = vminout[0]
199
+ vmaxout = vmaxout[0]
200
+ self.vmin = vminout
201
+ self.vmax = vmaxout
202
+ return dataout, vminout, vmaxout
149
203
 
150
204
 
151
205
  @dataclass
@@ -251,61 +305,6 @@ class PlotAxes2D():
251
305
  ax.set_aspect(self.aspect)
252
306
 
253
307
 
254
- def set_minmax(data: np.ndarray, stretch: str, stretchscale: float,
255
- stretchpower: float, sigma: float, kw: dict
256
- ) -> np.ndarray:
257
- """Set vmin and vmax for color pcolormesh and RGB maps.
258
-
259
- Args:
260
- data (np.ndarray): Plotted data.
261
- stretch (str): 'log', 'asinh', 'power', or 'linear'. Any other means 'linear'. 'log' means the mapped data are logarithmic. 'asinh' means the mapped data are arc sin hyperbolic. 'power' means the mapped data are power-law (see also stretchpower). Defaults to 'linear'.
262
- stretchscale (float): The input is asinh(data / stretchscale).
263
- stretchpower (float): The input is data**stretchpower / stretchpower. 1 means the linear scale, while 0 means the logarithmic scale.
264
- sigma (float): Noise level.
265
- kw (dict): Probably like {'vmin':0, 'vmax':1}.
266
-
267
- Returns:
268
- np.ndarray: Data clipped with the vmin and vmax.
269
- """
270
- if type(stretch) is str:
271
- data = [data]
272
- sigma = [sigma]
273
- stretch = [stretch]
274
- stretchscale = [stretchscale]
275
- stretchpower = [stretchpower]
276
- for k in ['vmin', 'vmax']:
277
- if k in kw:
278
- kw[k] = [kw[k]]
279
-
280
- n = len(data)
281
- for k in ['vmin', 'vmax']:
282
- if k not in kw:
283
- kw[k] = [None] * n
284
- isnan = np.equal(stretchscale, None)
285
- stretchscale = np.where(isnan, sigma, stretchscale)
286
- isnan = np.equal(kw['vmin'], None)
287
- cmin = np.where(isnan, sigma, kw['vmin'])
288
-
289
- argslist = (stretch, stretchscale, stretchpower)
290
- for i, stretch_args in enumerate(zip(*argslist)):
291
- c = data[i]
292
- if stretch_args[0] in ['log', 'power']:
293
- c = c.clip(cmin[i], None)
294
- c = do_stretch(c, *stretch_args)
295
- data[i] = c
296
- for k in ['vmin', 'vmax']:
297
- if kw[k][i] is None:
298
- kw[k][i] = np.nanmin(c) if k == 'vmin' else np.nanmax(c)
299
- else:
300
- kw[k][i] = do_stretch(kw[k][i], *stretch_args)
301
- data = [c.clip(a, b) for c, a, b in zip(data, kw['vmin'], kw['vmax'])]
302
- if n == 1:
303
- data = data[0]
304
- for k in ['vmin', 'vmax']:
305
- kw[k] = kw[k][0]
306
- return data
307
-
308
-
309
308
  def kwargs2AstroData(kw: dict) -> AstroData:
310
309
  """Get AstroData and remove its arguments from kwargs.
311
310
 
@@ -802,9 +801,7 @@ class PlotAstroData(AstroFrame):
802
801
  def _set_colorbar(self, mappable, ch: int, show_cbar: bool,
803
802
  cblabel: str, cbformat: str,
804
803
  cbticks: list | None, cbticklabels: list | None,
805
- cblocation: str, stretch: str,
806
- stretchscale: float, stretchpower: float,
807
- vmin: float, vmax: float):
804
+ cblocation: str, st: Stretcher):
808
805
  if not show_cbar:
809
806
  return
810
807
 
@@ -823,27 +820,19 @@ class PlotAstroData(AstroFrame):
823
820
  cb.ax.tick_params(labelsize=14)
824
821
  font = mpl.font_manager.FontProperties(size=16)
825
822
  cb.ax.yaxis.label.set_font_properties(font)
826
- stretch_args = (stretch, stretchscale, stretchpower)
827
- if cbticks is None and stretch == 'log':
828
- cbticks, cbticklabels = logcbticks(10**vmin, 10**vmax)
829
- if cbticks is not None:
830
- cbticks = do_stretch(cbticks, *stretch_args)
831
- else:
832
- cbticks = cb.get_ticks()
833
- cond = (vmin <= cbticks) * (cbticks <= vmax)
823
+ if cbticks is None and st.stretch == 'log':
824
+ cbticks, cbticklabels = logcbticks(10**st.vmin, 10**st.vmax)
825
+ cbticks = cb.get_ticks() if cbticks is None else st.do(cbticks)
826
+ cond = (st.vmin <= cbticks) * (cbticks <= st.vmax)
834
827
  cbticks = cbticks[cond]
835
828
  cb.set_ticks(cbticks)
836
- if cbticklabels is not None:
837
- cbticklabels = np.array(cbticklabels)[cond]
829
+ if cbticklabels is None:
830
+ cbticklabels = [f'{t:{cbformat[1:]}}' for t in st.undo(cbticks)]
838
831
  else:
839
- t = undo_stretch(cbticks, *stretch_args)
840
- cbticklabels = [f'{d:{cbformat[1:]}}' for d in t]
832
+ cbticklabels = np.array(cbticklabels)[cond]
841
833
  cb.set_ticklabels(cbticklabels)
842
834
 
843
835
  def add_color(self,
844
- stretch: str = 'linear',
845
- stretchscale: float | None = None,
846
- stretchpower: float = 0.5,
847
836
  show_cbar: bool = True,
848
837
  cblabel: str | None = None,
849
838
  cbformat: float = '%.1e',
@@ -851,12 +840,9 @@ class PlotAstroData(AstroFrame):
851
840
  cbticklabels: list[str] | None = None,
852
841
  cblocation: str = 'right',
853
842
  **kwargs) -> None:
854
- """Use Axes.pcolormesh of matplotlib. kwargs must include the arguments of AstroData to specify the data to be plotted. kwargs may include arguments for add_beam() and a dict of beam_kwargs to specify the beam patch in more detail. kwargs may include xskiip and yskip.
843
+ """Use Axes.pcolormesh of matplotlib. kwargs must include the arguments of AstroData to specify the data to be plotted. kwargs may include the arguments for Stretcher (stretch, stretchscale, and stretchpower) to specify the stretch parameters. kwargs may include arguments for add_beam() and a dict of beam_kwargs to specify the beam patch in more detail. kwargs may include xskiip and yskip.
855
844
 
856
845
  Args:
857
- stretch (str, optional): 'log', 'asinh', 'power', or 'linear'. Any other means 'linear'. 'log' means the mapped data are logarithmic. 'asinh' means the mapped data are arc sin hyperbolic. 'power' means the mapped data are power-law (see also stretchpower). Defaults to 'linear'.
858
- stretchscale (float, optional): Color scale is asinh(data / stretchscale). Defaults to None.
859
- stretchpower (float, optional): Color scale is data**stretchpower / stretchpower. 1 means the linear scale, while 0 means the logarithmic scale. Defaults to 0.5.
860
846
  show_cbar (bool, optional): Show color bar. Defaults to True.
861
847
  cblabel (str, optional): Colorbar label. Defaults to None.
862
848
  cbformat (float, optional): Format for ticklabels of colorbar. Defaults to '%.1e'.
@@ -865,7 +851,8 @@ class PlotAstroData(AstroFrame):
865
851
  cblocation (str, optional): 'left', 'top', 'left', 'right'. Only for 2D images. Defaults to 'right'.
866
852
  """
867
853
  self._kw = {'cmap': 'cubehelix', 'alpha': 1,
868
- 'edgecolors': 'none', 'zorder': 1}
854
+ 'edgecolors': 'none', 'zorder': 1,
855
+ 'vmin': None, 'vmax': None}
869
856
  c, x, y, v, beam, sigma, bunit, _kw, beam_kwargs, singlepix \
870
857
  = self._map_init(kwargs)
871
858
  if singlepix:
@@ -874,11 +861,16 @@ class PlotAstroData(AstroFrame):
874
861
 
875
862
  if cblabel is None:
876
863
  cblabel = bunit
877
- if stretchscale is None:
878
- stretchscale = sigma
879
864
 
880
- stretch_args = (stretch, stretchscale, stretchpower)
881
- c = set_minmax(c, *stretch_args, sigma, _kw)
865
+ stretch_params = {}
866
+ for k in ['stretch', 'stretchscale', 'stretchpower']:
867
+ if k in _kw:
868
+ stretch_params[k] = _kw.pop(k)
869
+ st = Stretcher(vmin=_kw['vmin'], vmax=_kw['vmax'],
870
+ sigma=sigma, **stretch_params)
871
+ c, cmin, cmax = st.set_minmax(c)
872
+ _kw['vmin'] = cmin
873
+ _kw['vmax'] = cmax
882
874
  c = self.vskipfill(c, v)
883
875
  if type(self.channelnumber) is int:
884
876
  c = [c[self.channelnumber]]
@@ -889,8 +881,7 @@ class PlotAstroData(AstroFrame):
889
881
  p[ch] = pnow
890
882
  for ch in self.bottomleft:
891
883
  self._set_colorbar(p, ch, show_cbar, cblabel, cbformat,
892
- cbticks, cbticklabels, cblocation,
893
- *stretch_args, _kw['vmin'], _kw['vmax'])
884
+ cbticks, cbticklabels, cblocation, st)
894
885
  self.add_beam(beam=beam, **beam_kwargs)
895
886
 
896
887
  def add_contour(self,
@@ -982,33 +973,34 @@ class PlotAstroData(AstroFrame):
982
973
  def add_rgb(self,
983
974
  stretch: list[str, str, str] = ['linear'] * 3,
984
975
  stretchscale: list[float | None, float | None, float | None] = [None] * 3,
985
- stretchpower: list[float, float, float] = [1, 1, 1],
976
+ stretchpower: list[float, float, float] = [0.5, 0.5, 0.5],
986
977
  **kwargs) -> None:
987
- """Use PIL.Image and imshow of matplotlib. kwargs must include the arguments of AstroData to specify the data to be plotted. A three-element array ([red, green, blue]) is supposed for all arguments, except for xskip, yskip and show_beam, including vmax and vmin. kwargs may include arguments for add_beam() and a dict of beam_kwargs to specify the beam patch in more detail. kwargs may include xskiip and yskip.
978
+ """Use PIL.Image and imshow of matplotlib. kwargs must include the arguments of AstroData to specify the data to be plotted. A three-element array ([red, green, blue]) is supposed for all arguments, except for xskip, yskip and show_beam, including vmax and vmin. kwargs may include the arguments for Stretcher (stretch, stretchscale, and stretchpower; three-element array for each) to specify the stretch parameters. kwargs may include arguments for add_beam() and a dict of beam_kwargs to specify the beam patch in more detail. kwargs may include xskiip and yskip.
988
979
 
989
980
  Args:
990
981
  stretch (str, optional): 'log', 'asinh', 'power', or 'linear'. Any other means 'linear'. 'log' means the mapped data are logarithmic. 'asinh' means the mapped data are arc sin hyperbolic. 'power' means the mapped data are power-law (see also stretchpower). Defaults to 'linear'.
991
982
  stretchscale (float, optional): Color scale is asinh(data / stretchscale). Defaults to None.
992
- stretchpower (float, optional): Color scale is data**stretchpower / stretchpower. 1 means the linear scale, while 0 means the logarithmic scale. Defaults to 1.
983
+ stretchpower (float, optional): Color scale is data**stretchpower / stretchpower. 1 means the linear scale, while 0 means the logarithmic scale. Defaults to 0.5.
993
984
  """
994
985
  from PIL import Image
995
986
 
996
- self._kw = {}
987
+ self._kw = {'vmin': [None] * 3, 'vmax': [None] * 3}
997
988
  c, x, y, v, beam, sigma, _, _kw, beam_kwargs, singlepix \
998
989
  = self._map_init(kwargs)
999
990
  if singlepix:
1000
991
  print('No pixel size. Skip add_rgb.')
1001
992
  return
1002
993
 
1003
- stretch_args = (stretch, stretchscale, stretchpower)
1004
- c = set_minmax(c, *stretch_args, sigma, _kw)
1005
994
  if not (np.shape(c[0]) == np.shape(c[1]) == np.shape(c[2])):
1006
995
  print('RGB shapes mismatch. Skip add_rgb.')
1007
996
  return
1008
997
 
1009
- for i, (cmin, cmax) in enumerate(zip(_kw['vmin'], _kw['vmax'])):
1010
- if cmax > cmin:
1011
- c[i] = (c[i] - cmin) / (cmax - cmin) * 255
998
+ st = Stretcher(stretch, stretchscale, stretchpower,
999
+ _kw['vmin'], _kw['vmax'], sigma)
1000
+ c, cmin, cmax = st.set_minmax(c)
1001
+ for i in range(st.n):
1002
+ if cmax[i] > cmin[i]:
1003
+ c[i] = (c[i] - cmin[i]) / (cmax[i] - cmin[i]) * 255
1012
1004
  c[i] = self.vskipfill(c[i], v)
1013
1005
  size = np.shape(c[0][0])[::-1]
1014
1006
  c = np.moveaxis(c, 1, 0)[:, :, ::-self.ydir, ::-self.xdir]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: plotastrodata
3
- Version: 1.8.11
3
+ Version: 1.8.12
4
4
  Summary: plotastrodata is a tool for astronomers to create figures from FITS files and perform fundamental data analyses with ease.
5
5
  Home-page: https://github.com/yusukeaso-astron/plotastrodata
6
6
  Download-URL: https://github.com/yusukeaso-astron/plotastrodata
File without changes
File without changes
File without changes
File without changes