plotastrodata 1.8.17__tar.gz → 1.9.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. {plotastrodata-1.8.17/plotastrodata.egg-info → plotastrodata-1.9.0}/PKG-INFO +1 -1
  2. {plotastrodata-1.8.17 → plotastrodata-1.9.0}/plotastrodata/__init__.py +1 -1
  3. {plotastrodata-1.8.17 → plotastrodata-1.9.0}/plotastrodata/analysis_utils.py +10 -9
  4. {plotastrodata-1.8.17 → plotastrodata-1.9.0}/plotastrodata/coord_utils.py +43 -0
  5. {plotastrodata-1.8.17 → plotastrodata-1.9.0}/plotastrodata/other_utils.py +112 -2
  6. {plotastrodata-1.8.17 → plotastrodata-1.9.0}/plotastrodata/plot_utils.py +136 -148
  7. {plotastrodata-1.8.17 → plotastrodata-1.9.0/plotastrodata.egg-info}/PKG-INFO +1 -1
  8. {plotastrodata-1.8.17 → plotastrodata-1.9.0}/LICENSE +0 -0
  9. {plotastrodata-1.8.17 → plotastrodata-1.9.0}/MANIFEST.in +0 -0
  10. {plotastrodata-1.8.17 → plotastrodata-1.9.0}/README.md +0 -0
  11. {plotastrodata-1.8.17 → plotastrodata-1.9.0}/plotastrodata/const_utils.py +0 -0
  12. {plotastrodata-1.8.17 → plotastrodata-1.9.0}/plotastrodata/ext_utils.py +0 -0
  13. {plotastrodata-1.8.17 → plotastrodata-1.9.0}/plotastrodata/fft_utils.py +0 -0
  14. {plotastrodata-1.8.17 → plotastrodata-1.9.0}/plotastrodata/fits_utils.py +0 -0
  15. {plotastrodata-1.8.17 → plotastrodata-1.9.0}/plotastrodata/fitting_utils.py +0 -0
  16. {plotastrodata-1.8.17 → plotastrodata-1.9.0}/plotastrodata/los_utils.py +0 -0
  17. {plotastrodata-1.8.17 → plotastrodata-1.9.0}/plotastrodata/matrix_utils.py +0 -0
  18. {plotastrodata-1.8.17 → plotastrodata-1.9.0}/plotastrodata/noise_utils.py +0 -0
  19. {plotastrodata-1.8.17 → plotastrodata-1.9.0}/plotastrodata.egg-info/SOURCES.txt +0 -0
  20. {plotastrodata-1.8.17 → plotastrodata-1.9.0}/plotastrodata.egg-info/dependency_links.txt +0 -0
  21. {plotastrodata-1.8.17 → plotastrodata-1.9.0}/plotastrodata.egg-info/not-zip-safe +0 -0
  22. {plotastrodata-1.8.17 → plotastrodata-1.9.0}/plotastrodata.egg-info/requires.txt +0 -0
  23. {plotastrodata-1.8.17 → plotastrodata-1.9.0}/plotastrodata.egg-info/top_level.txt +0 -0
  24. {plotastrodata-1.8.17 → plotastrodata-1.9.0}/setup.cfg +0 -0
  25. {plotastrodata-1.8.17 → plotastrodata-1.9.0}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: plotastrodata
3
- Version: 1.8.17
3
+ Version: 1.9.0
4
4
  Summary: plotastrodata is a tool for astronomers to create figures from FITS files and perform fundamental data analyses with ease.
5
5
  Home-page: https://github.com/yusukeaso-astron/plotastrodata
6
6
  Download-URL: https://github.com/yusukeaso-astron/plotastrodata
@@ -1,4 +1,4 @@
1
1
  import warnings
2
2
 
3
3
  warnings.simplefilter('ignore', FutureWarning)
4
- __version__ = '1.8.17'
4
+ __version__ = '1.9.0'
@@ -7,11 +7,12 @@ from scipy.signal import convolve
7
7
  from plotastrodata import const_utils as cu
8
8
  from plotastrodata.coord_utils import coord2xy, rel2abs, xy2coord
9
9
  from plotastrodata.fits_utils import data2fits, FitsData, Jy2K
10
- from plotastrodata.fitting_utils import (EmceeCorner, gaussian2d,
11
- gaussfit1d, gaussfit2d)
10
+ from plotastrodata.fitting_utils import (EmceeCorner, gaussfit1d,
11
+ gaussfit2d, gaussian2d)
12
12
  from plotastrodata.matrix_utils import dot2d, Mfac, Mrot
13
13
  from plotastrodata.noise_utils import estimate_rms
14
- from plotastrodata.other_utils import isdeg, RGIxy, RGIxyv, to4dim, trim
14
+ from plotastrodata.other_utils import (isdeg, nearest_index,
15
+ RGIxy, RGIxyv, to4dim, trim)
15
16
 
16
17
 
17
18
  def quadrantmean(data: np.ndarray, x: np.ndarray, y: np.ndarray,
@@ -197,10 +198,10 @@ class AstroData():
197
198
  includev (bool, optional): Centering in the v direction at each position. Defaults to False.
198
199
  """
199
200
  if includexy:
200
- xnew = self.x - self.x[np.argmin(np.abs(self.x))]
201
- ynew = self.y - self.y[np.argmin(np.abs(self.y))]
201
+ xnew = self.x - self.x[nearest_index(self.x)]
202
+ ynew = self.y - self.y[nearest_index(self.y)]
202
203
  if includev:
203
- vnew = self.v - self.v[np.argmin(np.abs(self.v))]
204
+ vnew = self.v - self.v[nearest_index(self.v)]
204
205
  if includexy and includev:
205
206
  self.data = RGIxyv(self.v, self.y, self.x, self.data,
206
207
  np.meshgrid(vnew, ynew, xnew, indexing='ij'),
@@ -518,10 +519,10 @@ class AstroData():
518
519
  """
519
520
  fhd = self.fitsheader
520
521
  h = {}
521
- ci = np.argmin(np.abs(self.x))
522
+ ci = nearest_index(self.x)
522
523
  cx = 0
523
524
  if not self.pv:
524
- cj = np.argmin(np.abs(self.y))
525
+ cj = nearest_index(self.y)
525
526
  if self.center is None:
526
527
  cx, cy = self.x[ci], self.x[cj]
527
528
  else:
@@ -538,7 +539,7 @@ class AstroData():
538
539
  h['CDELT1'] = float(self.dx / (3600 if indeg('1') else 1))
539
540
  if self.dv is not None:
540
541
  vaxis = '2' if self.pv else '3'
541
- ck = np.argmin(np.abs(self.v))
542
+ ck = nearest_index(self.v)
542
543
  cv = self.v[ck]
543
544
  dv = self.dv
544
545
  if self.restfreq is None or self.restfreq == 0:
@@ -149,3 +149,46 @@ def abs2rel(xabs: float, yabs: float,
149
149
  xrel = (xabs - x[0]) / (x[-1] - x[0])
150
150
  yrel = (yabs - y[0]) / (y[-1] - y[0])
151
151
  return np.array([xrel, yrel])
152
+
153
+
154
+ def get_sec(coord: str, mode: str) -> str:
155
+ """Pick up the second number from a hmsdms string.
156
+
157
+ Args:
158
+ coord (str): hmsdms string.
159
+ mode (str): 'ra' or 'dec'
160
+
161
+ Returns:
162
+ str: The second number as a string without the unit.
163
+ """
164
+ i_axis = 0 if mode == 'ra' else 1
165
+ return coord.split(' ')[i_axis].split('m')[1].strip('s')
# NOTE(review): any `mode` other than the exact string 'ra' is treated as
#   Dec (i_axis = 1). PlotAstroData calls get_sec(center, 0) when snapping
#   the R.A. seconds onto a 5-second grid. Because 0 != 'ra', that call
#   returns the Dec seconds, and the R.A. tick center is built from the
#   wrong field. The call should pass 'ra'. Raising ValueError unless mode
#   is 'ra' or 'dec' would surface such misuse; get_min and get_hmdm share
#   the same silent fallback.
166
+
167
+
168
+ def get_min(coord: str, mode: str) -> str:
169
+ """Pick up the minute number from a hmsdms string.
170
+
171
+ Args:
172
+ coord (str): hmsdms string.
173
+ mode (str): 'ra' or 'dec'
174
+
175
+ Returns:
176
+ str: The minute number as a string without the unit.
177
+ """
178
+ i_axis = 0 if mode == 'ra' else 1
179
+ s = 'h' if mode == 'ra' else 'd'
180
+ return coord.split(' ')[i_axis].split(s)[1].split('m')[0]
181
+
182
+
183
+ def get_hmdm(coord: str, mode: str) -> str:
184
+ """Pick up the coordinate string before the second part from a hsmdms string.
185
+
186
+ Args:
187
+ coord (str): hmsdms string.
188
+ mode (str): 'ra' or 'dec'
189
+
190
+ Returns:
191
+ str: The hm or dm string with the units.
192
+ """
193
+ i_axis = 0 if mode == 'ra' else 1
194
+ return coord.split(' ')[i_axis].split('m')[0] + 'm'
@@ -1,5 +1,6 @@
1
1
  import matplotlib.pyplot as plt
2
2
  import numpy as np
3
+ import warnings
3
4
  from scipy.interpolate import RegularGridInterpolator as RGI
4
5
 
5
6
 
@@ -33,6 +34,20 @@ def isdeg(s: str) -> bool:
33
34
  return False
34
35
 
35
36
 
37
+ def nearest_index(arr: np.ndarray, x: float = 0) -> int:
38
+ """Get the index of the (sorted) arrary that gives a value nearest to x. This is equivalent to np.argmin(np.abs(arr - x)) but optimized for the sorted array.
39
+
40
+ Args:
41
+ arr (np.ndarray): Sorted array.
42
+ x (float, optional): Value to approach. Defaults to 0.
43
+
44
+ Returns:
45
+ int: The index that gives a value nearest to x.
46
+ """
47
+ idx = np.searchsorted(arr, x).clip(1, len(arr) - 1)
48
+ return idx - 1 if x - arr[idx - 1] <= arr[idx] - x else idx
# NOTE(review): np.searchsorted assumes `arr` is in ASCENDING order. The
#   np.argmin(np.abs(...)) expressions this replaces worked for any order.
#   In this diff the callers pass self.x, self.y and self.v (AstroData
#   centering and FITS-header code), plus velocity axes (reform_data).
#   These may be descending, e.g. an x grid written back with a negative
#   CDELT1 as is usual for R.A. -- confirm.
#   Example: for arr = [2, 1, 0, -1, -2] and x = 0, this returns 0 instead
#   of 2.
# NOTE(review): for len(arr) == 1, .clip(1, 0) yields 0, so the function
#   returns -1 whenever x <= arr[0]. In trim()'s get_bounds this gives
#   lo = hi = -1 and the slice arr[-1:0], which is empty; np.argmin
#   returned 0 there. This affects e.g. a single-channel v axis with vlim
#   set.
#   An empty `arr` raises IndexError here instead of argmin's ValueError.
# NOTE(review): suggested fix that keeps the fast path for ascending input
#   and falls back to the old expression otherwise:
#       arr = np.asarray(arr)
#       if len(arr) < 2 or arr[0] > arr[-1]:
#           return int(np.argmin(np.abs(arr - x)))
#       idx = int(np.clip(np.searchsorted(arr, x), 1, len(arr) - 1))
#       return idx - 1 if x - arr[idx - 1] <= arr[idx] - x else idx
#   Docstring typo: "arrary" -> "array".
49
+
50
+
36
51
  def trim(data: np.ndarray | None = None, x: np.ndarray | None = None,
37
52
  y: np.ndarray | None = None, v: np.ndarray | None = None,
38
53
  xlim: list[float] | None = None,
@@ -57,8 +72,8 @@ def trim(data: np.ndarray | None = None, x: np.ndarray | None = None,
57
72
  def get_bounds(arr, lim):
58
73
  if arr is None or lim is None or None in lim:
59
74
  return arr, 0, None
60
- lo = np.argmin(np.abs(arr - max(np.min(arr), lim[0])))
61
- hi = np.argmin(np.abs(arr - min(np.max(arr), lim[1])))
75
+ lo = nearest_index(arr, max(np.min(arr), lim[0]))
76
+ hi = nearest_index(arr, min(np.max(arr), lim[1]))
62
77
  lo, hi = sorted((lo, hi))
63
78
  return arr[lo:hi + 1], lo, hi + 1
64
79
 
@@ -106,6 +121,101 @@ def to4dim(data: np.ndarray) -> np.ndarray:
106
121
  return d
107
122
 
108
123
 
124
+ def reform_grid(v: np.ndarray | None = None,
125
+ k0: int | None = None, k1: int | None = None,
126
+ vmin: float | None = None, vmax: float | None = None
127
+ ) -> np.ndarray:
128
+ """Extend or cut the given 1D array based on the given range.
129
+
130
+ Args:
131
+ v (np.ndarray | None, optional): Input 1D array. Defaults to None.
132
+ k0 (int | None, optional): How many channels are added before v[0]; the minus sign means extension. k0 has the priority over vmin. Defaults to None.
133
+ k1 (int | None, optional): How many channels are added after v[-1]; the plus sign means extension. k1 has the priority over vmax. Defaults to None.
134
+ vmin (float | None, optional): New minimum velocity. Defaults to None.
135
+ vmax (float | None, optional): New maximum velocity. Defaults to None.
136
+
137
+ Returns:
138
+ np.ndarray: Extended or cut 1D array.
139
+ """
140
+ if v is None or len(v) <= 1:
141
+ return v
142
+
143
+ dv = v[1] - v[0]
144
+ if k0 is None and vmin is not None:
145
+ k0 = int(round((vmin - v[0]) / dv))
146
+ if k0 is not None and k0 != 0:
147
+ if k0 < 0:
148
+ vpre = v[0] + dv * np.arange(k0, 0)
149
+ v = np.concatenate((vpre, v))
150
+ else:
151
+ v = v[k0:]
152
+ if k1 is None and vmax is not None:
153
+ k1 = int(round((vmax - v[-1]) / dv))
154
+ if k1 is not None and k1 != 0:
155
+ if k1 > 0:
156
+ vpost = v[-1] + dv * np.arange(1, k1 + 1)
157
+ v = np.concatenate((v, vpost))
158
+ else:
159
+ v = v[:len(v) + k1]
160
+ return v
161
+
162
+
163
+ def reform_data(c: np.ndarray, v_in: np.ndarray | None,
164
+ nv: int, v_org: np.ndarray | None = None,
165
+ vskip: int = 1) -> np.ndarray:
166
+ """Skip and fill channels with nan.
167
+
168
+ Args:
169
+ c (np.ndarray): The input 2D or 3D arrays.
170
+ v_in (np.ndarray): The input velocity 1D array.
171
+ nv (int): The number of channels with a label.
172
+ v (np.ndarray, optional): The velocity 1D array, including the channels with and without a label. Defaults to None.
173
+ vskip (int, optional): How many channels are skipped. Defaults to 1.
174
+
175
+ Returns:
176
+ np.ndarray: 3D arrays skipped and filled with nan.
177
+ """
178
+ if v_org is None:
179
+ return c
180
+
181
+ ndim = np.ndim(c)
182
+ if ndim not in [2, 3]:
183
+ print('c must be 2D or 3D.')
184
+ return
185
+
186
+ if ndim == 2:
187
+ d = np.full((nv, *np.shape(c)), c)
188
+ elif v_in is not None:
189
+ dv_org = v_org[1] - v_org[0]
190
+ dv_in = (v_in[1] - v_in[0]) * vskip
191
+ k0 = nearest_index(v_org, v_in[0])
192
+ k1 = nearest_index(v_org, v_in[-1])
193
+ if np.abs(dv_in - dv_org) / dv_org < 0.01:
194
+ d = c
195
+ else:
196
+ s = 'Velocity resolution mismatch (>1%).' \
197
+ + ' The cube needs to be regridded' \
198
+ + ' outside plotastrodata.'
199
+ warnings.warn(s, UserWarning)
200
+ n_valid = k1 - k0
201
+ d = [None] * n_valid
202
+ for k in range(n_valid):
203
+ k_tmp = nearest_index(v_in, v_org[k])
204
+ diffvel = np.abs(v_in[k_tmp] - v_org[k])
205
+ nearby = diffvel < dv_org * 0.5
206
+ d[k] = c[k_tmp] if nearby else c[0] * np.nan
207
+ d = np.array(d)
208
+ if k0 > 0:
209
+ prenan = np.full((k0, *np.shape(d)[1:]), np.nan)
210
+ d = np.concatenate((prenan, d))
211
+ d = d[::vskip]
212
+ shape = np.shape(d)
213
+ shape = (len(v_org) - shape[0], shape[1], shape[2])
214
+ postnan = np.full(shape, np.nan)
215
+ d = np.concatenate((d, postnan))
216
+ return d
217
+
218
+
109
219
  def RGIxy(y: np.ndarray, x: np.ndarray, data: np.ndarray,
110
220
  yxnew: tuple[np.ndarray, np.ndarray] | None = None,
111
221
  **kwargs) -> object | np.ndarray:
@@ -1,15 +1,16 @@
1
1
  import matplotlib as mpl
2
2
  import matplotlib.pyplot as plt
3
3
  import numpy as np
4
- import warnings
5
4
  from dataclasses import dataclass
6
5
  from matplotlib.patches import Ellipse, Rectangle
7
6
  from typing import TypeVar
8
7
 
9
8
  from plotastrodata.analysis_utils import AstroData, AstroFrame
10
- from plotastrodata.coord_utils import coord2xy, xy2coord
9
+ from plotastrodata.coord_utils import (coord2xy, xy2coord,
10
+ get_hmdm, get_min, get_sec)
11
11
  from plotastrodata.noise_utils import estimate_rms
12
- from plotastrodata.other_utils import close_figure, listing
12
+ from plotastrodata.other_utils import (close_figure, listing,
13
+ reform_grid, reform_data)
13
14
 
14
15
 
15
16
  plt.ioff() # force to turn off interactive mode
@@ -99,28 +100,42 @@ def logcbticks(vmin: float = 1e-3, vmax: float = 1e3
99
100
  return ticks[cond], ticklabels[cond]
100
101
 
101
102
 
102
- def _get_sec(coord: str, mode: str) -> str:
103
- i_axis = 0 if mode == 'ra' else 1
104
- return coord.split(' ')[i_axis].split('m')[1].strip('s')
103
+ def get_figsize(xmin: float, xmax: float, ymin: float, ymax: float,
104
+ figsize: tuple | None = None,
105
+ ncols: int = 1, nrows: int = 1, nchan: int = 1
106
+ ) -> tuple[float, float]:
107
+ """Get a nice figsize (tuple) with the given x and y ranges.
105
108
 
109
+ Args:
110
+ xmin (float): The figsize is based on the aspect ratio of (ymax - ymin) / (xmax - xmin).
111
+ xmax (float): The figsize is based on the aspect ratio of (ymax - ymin) / (xmax - xmin).
112
+ ymin (float): The figsize is based on the aspect ratio of (ymax - ymin) / (xmax - xmin).
113
+ ymax (float): The figsize is based on the aspect ratio of (ymax - ymin) / (xmax - xmin).
114
+ figsize (tuple | None, optional): If this is not None, this will be the output as is. Defaults to None.
115
+ ncols (int, optional): The number of columns for the channel map. Defaults to 1.
116
+ nrows (int, optional): The number of rows for the channel map. Defaults to 1.
117
+ nchan (int, optional): The number of total channels for the channel map. Defaults to 1.
106
118
 
107
- def _get_min(coord: str, mode: str) -> str:
108
- i_axis = 0 if mode == 'ra' else 1
109
- s = 'h' if mode == 'ra' else 'd'
110
- return coord.split(' ')[i_axis].split(s)[1].split('m')[0]
111
-
119
+ Returns:
120
+ tuple[float, float]: figsize for matplotlib.pyplot.Figure.
121
+ """
122
+ if figsize is not None:
123
+ return figsize
112
124
 
113
- def _get_hmdm(coord: str, mode: str) -> str:
114
- i_axis = 0 if mode == 'ra' else 1
115
- return coord.split(' ')[i_axis].split('m')[0] + 'm'
125
+ sqrt_a = (ymax - ymin) / (xmax - xmin)
126
+ sqrt_a = np.sqrt(np.abs(sqrt_a))
127
+ if nchan == 1:
128
+ figsize = (7 / sqrt_a, 5 * sqrt_a)
129
+ else:
130
+ figsize = (ncols * 2 / sqrt_a, max(nrows*2, 3) * sqrt_a)
131
+ return figsize
116
132
 
117
133
 
118
134
  def _get_gridwidth(mode: str, rmax: float) -> tuple[float, int]:
119
- # 10^1.5 / 15 ~ 2 grids for R.A.
120
- # 10^0.5 ~ 3 grids for Dec.
121
- scale = 1.5 if mode == 'ra' else 0.5
122
- log2r = np.log10(2. * rmax)
123
- x = log2r - scale
135
+ # 10^1.45 / 15 ~ 2 grids for R.A.
136
+ # 10^0.45 ~ 3 grids for Dec.
137
+ scale = 1.45 if mode == 'ra' else 0.45
138
+ x = np.log10(2. * rmax) - scale
124
139
  order = np.floor(x)
125
140
  frac = x - order
126
141
  if frac <= 0.33:
@@ -132,6 +147,42 @@ def _get_gridwidth(mode: str, rmax: float) -> tuple[float, int]:
132
147
  return base * 10**order, int(order)
133
148
 
134
149
 
150
+ def _get_v(p, v: np.ndarray | None = None,
151
+ restfreq: float | None = None,
152
+ vskip: int = 1) -> np.ndarray:
153
+ if p.fitsimage is not None:
154
+ p.read(d := AstroData(fitsimage=p.fitsimage,
155
+ restfreq=restfreq, sigma=None))
156
+ v = d.v
157
+ if v is None:
158
+ v = np.array([0])
159
+ if len(v) > 1:
160
+ v = reform_grid(v=v, vmin=p.vmin, vmax=p.vmax)
161
+ v = v[::vskip]
162
+ return v
163
+
164
+
165
+ def _get_nij2ch(nrows: int = 1, ncols: int = 1) -> object:
166
+ def nij2ch(n: int, i: int, j: int) -> int:
167
+ return n*nrows*ncols + i*ncols + j
168
+ return nij2ch
169
+
170
+
171
+ def _get_ch2nij(nrows: int = 1, ncols: int = 1) -> object:
172
+ def ch2nij(ch: int) -> tuple[int, int, int]:
173
+ n = ch // (nrows*ncols)
174
+ i = (ch - n*nrows*ncols) // ncols
175
+ j = ch % ncols
176
+ return n, i, j
177
+ return ch2nij
178
+
179
+
180
+ def _get_vskipfill(nv: float, v_org: np.ndarray, vskip: int) -> object:
181
+ def vskipfill(c: np.ndarray, v_in: np.ndarray) -> np.ndarray:
182
+ return reform_data(c=c, v_in=v_in, nv=nv, v_org=v_org, vskip=vskip)
183
+ return vskipfill
184
+
185
+
135
186
  @dataclass
136
187
  class Stretcher():
137
188
  """Arguments and methods related to the stretch in PlotAstroData.add_color() and add_rgb().
@@ -171,43 +222,51 @@ class Stretcher():
171
222
  getsigma = (islog + ispower) * novmin
172
223
  self.vmin = np.where(getsigma, sigma, vmin)
173
224
 
174
- def do(self, x: list | np.ndarray) -> np.ndarray:
225
+ def do(self, x: list | np.ndarray, i: int = 0) -> np.ndarray:
175
226
  """Get the stretched values.
176
227
 
177
228
  Args:
178
229
  x (list | np.ndarray): Input array in the linear scale.
230
+ i (int): Which element is used in the case where the stretch parameters are lists.
179
231
 
180
232
  Returns:
181
233
  np.ndarray: Output stretched array.
182
234
  """
235
+ st = self.stretch[i] if self.n > 1 else self.stretch
236
+ stsc = self.stretchscale[i] if self.n > 1 else self.stretchscale
237
+ stpw = self.stretchpower[i] if self.n > 1 else self.stretchpower
183
238
  t = np.array(x)
184
- match self.stretch:
239
+ match st:
185
240
  case 'log':
186
241
  t = np.log10(t) # To be consistent with logcbticks().
187
242
  case 'asinh':
188
- t = np.arcsinh(t / self.stretchscale)
243
+ t = np.arcsinh(t / stsc)
189
244
  case 'power':
190
- p = 1e-6 if self.stretchpower == 0 else self.stretchpower
245
+ p = 1e-6 if stpw == 0 else stpw
191
246
  t = t**p / p
192
247
  return t
193
248
 
194
- def undo(self, x: list | np.ndarray) -> np.ndarray:
249
+ def undo(self, x: list | np.ndarray, i: int = 0) -> np.ndarray:
195
250
  """Get the linear values from the stretched values.
196
251
 
197
252
  Args:
198
253
  x (list | np.ndarray): Input stretched array.
254
+ i (int): Which element is used in the case where the stretch parameters are lists.
199
255
 
200
256
  Returns:
201
257
  np.ndarray: Output array in the linear scale.
202
258
  """
259
+ st = self.stretch[i] if self.n > 1 else self.stretch
260
+ stsc = self.stretchscale[i] if self.n > 1 else self.stretchscale
261
+ stpw = self.stretchpower[i] if self.n > 1 else self.stretchpower
203
262
  t = np.array(x)
204
- match self.stretch:
263
+ match st:
205
264
  case 'log':
206
265
  t = 10**t # To be consistent with logcbticks().
207
266
  case 'asinh':
208
- t = np.sinh(t) * self.stretchscale
267
+ t = np.sinh(t) * stsc
209
268
  case 'power':
210
- p = 1e-6 if self.stretchpower == 0 else self.stretchpower
269
+ p = 1e-6 if stpw == 0 else stpw
211
270
  t = (t * p)**(1 / p)
212
271
  return t
213
272
 
@@ -226,7 +285,7 @@ class Stretcher():
226
285
  vmaxout = [self.vmax] if single else self.vmax
227
286
  dataout = [data] if single else data
228
287
  for i, (c, v0, v1) in enumerate(zip(dataout, vminout, vmaxout)):
229
- dataout[i] = cout = self.do(c.clip(v0, v1))
288
+ dataout[i] = cout = self.do(c.clip(v0, v1), i)
230
289
  vminout[i] = np.nanmin(cout)
231
290
  vmaxout[i] = np.nanmax(cout)
232
291
  if single:
@@ -377,7 +436,7 @@ def kwargs2instance(cls: type[T], kw: dict) -> T:
377
436
 
378
437
  Args:
379
438
  cls (class): Class to make the instance.
380
- kw (dict): Parameters to make Stretcher.
439
+ kw (dict): Parameters to make the instance.
381
440
 
382
441
  Returns:
383
442
  instance: an instance of cls made from the parameters in kwargs.
@@ -416,7 +475,7 @@ class PlotAstroData(AstroFrame):
416
475
  kwargs is the arguments of AstroFrame to define plotting ranges.
417
476
 
418
477
  Args:
419
- v (np.ndarray, optional): Used to set up channels if fitsimage not given. Defaults to [0].
478
+ v (np.ndarray, optional): Used to set up channels if fitsimage not given. Defaults to None.
420
479
  vskip (int, optional): How many channels are skipped. Defaults to 1.
421
480
  veldigit (int, optional): How many digits after the decimal point. Defaults to 2.
422
481
  restfreq (float, optional): Used for velocity and brightness T. Defaults to None.
@@ -431,7 +490,7 @@ class PlotAstroData(AstroFrame):
431
490
  ax (optional): External fig.add_subplot(). Defaults to None.
432
491
  """
433
492
  def __init__(self,
434
- v: np.ndarray = np.array([0]), vskip: int = 1,
493
+ v: np.ndarray | None = None, vskip: int = 1,
435
494
  veldigit: int = 2, restfreq: float | None = None,
436
495
  channelnumber: int | None = None,
437
496
  nrows: int = 4, ncols: int = 6,
@@ -443,74 +502,42 @@ class PlotAstroData(AstroFrame):
443
502
  super().__init__(**kwargs)
444
503
  internalfig = fig is None
445
504
  internalax = ax is None
446
- if type(channelnumber) is int:
447
- nrows = ncols = 1
448
- if self.fitsimage is not None:
449
- self.read(d := AstroData(fitsimage=self.fitsimage,
450
- restfreq=restfreq, sigma=None))
451
- v = d.v
452
- if len(v) > 1:
453
- dv = v[1] - v[0]
454
- k0 = int(round((self.vmin - v[0]) / dv))
455
- if k0 < 0:
456
- vpre = v[0] - (1 + np.arange(-k0)[::-1]) * dv
457
- v = np.append(vpre, v)
458
- else:
459
- v = v[k0:]
460
- k1 = len(v) + int(round((self.vmax - v[-1]) / dv))
461
- if k1 > len(v):
462
- vpost = v[-1] + (1 + np.arange(k1 - len(v))) * dv
463
- v = np.append(v, vpost)
464
- else:
465
- v = v[:k1]
466
- if self.pv or v is None or len(v) == 1:
467
- nv = nrows = ncols = npages = nchan = 1
505
+ v = _get_v(p=self, v=v, restfreq=restfreq, vskip=vskip)
506
+ nv = len(v) # number of channels with a label
507
+ if self.pv or len(v) == 1 or channelnumber is not None:
508
+ nrows = ncols = npages = nchan = 1
468
509
  else:
469
- nv = len(v := v[::vskip])
470
510
  npages = int(np.ceil(nv / nrows / ncols))
471
511
  nchan = npages * nrows * ncols
472
- v = np.r_[v, v[-1] + (np.arange(nchan - nv) + 1) * dv]
473
- if type(channelnumber) is int:
474
- nchan = npages = 1
475
-
476
- def nij2ch(n: int, i: int, j: int):
477
- return n*nrows*ncols + i*ncols + j
478
-
479
- def ch2nij(ch: int) -> tuple:
480
- n = ch // (nrows*ncols)
481
- i = (ch - n*nrows*ncols) // ncols
482
- j = ch % ncols
483
- return n, i, j
484
-
512
+ v = reform_grid(v, k1=nchan - nv)
513
+ nij2ch = _get_nij2ch(nrows=nrows, ncols=ncols)
514
+ ch2nij = _get_ch2nij(nrows=nrows, ncols=ncols)
485
515
  if fontsize is None:
486
516
  fontsize = 18 if nchan == 1 else 12
487
517
  set_rcparams(fontsize=fontsize, nancolor=nancolor, dpi=dpi)
488
- ax = np.empty(nchan, dtype='object') if internalax else [ax]
489
- if figsize is None:
490
- sqrt_a = (self.ymax - self.ymin) / (self.xmax - self.xmin)
491
- sqrt_a = np.sqrt(np.abs(sqrt_a))
492
- if nchan == 1:
493
- figsize = (7 / sqrt_a, 5 * sqrt_a)
494
- else:
495
- figsize = (ncols * 2 / sqrt_a, max(nrows*2, 3) * sqrt_a)
518
+ ax = np.empty(nchan, dtype=object) if internalax else [ax]
519
+ figsize = get_figsize(xmin=self.xmin, xmax=self.xmax,
520
+ ymin=self.ymin, ymax=self.ymax,
521
+ figsize=figsize,
522
+ ncols=ncols, nrows=nrows, nchan=nchan)
523
+ need_vlabel = nchan > 1 or type(channelnumber) is int
496
524
  for ch in range(nchan):
497
525
  n, i, j = ch2nij(ch)
498
526
  if internalfig and n not in plt.get_fignums():
499
527
  fig = plt.figure(n, figsize=figsize)
500
- sharex = ax[nij2ch(n, i - 1, j)] if i > 0 else None
501
- sharey = ax[nij2ch(n, i, j - 1)] if j > 0 else None
528
+ if need_vlabel:
529
+ fig.subplots_adjust(hspace=0, wspace=0,
530
+ right=0.87, top=0.87)
502
531
  if internalax:
532
+ sharex = ax[nij2ch(n, i - 1, j)] if i > 0 else None
533
+ sharey = ax[nij2ch(n, i, j - 1)] if j > 0 else None
503
534
  ax[ch] = fig.add_subplot(nrows, ncols, i*ncols + j + 1,
504
535
  sharex=sharex, sharey=sharey)
505
- if nchan > 1 or type(channelnumber) is int:
506
- fig.subplots_adjust(hspace=0, wspace=0, right=0.87, top=0.87)
507
- if ch < nv:
508
- chnum = channelnumber
509
- vellabel = v[ch if chnum is None else chnum]
510
- vd = f'{veldigit:d}'
511
- ax[ch].text(0.9 * self.rmax, 0.7 * self.rmax,
512
- rf'${vellabel:.{vd}f}$', color='black',
513
- backgroundcolor='white', zorder=20)
536
+ if need_vlabel and ch < nv:
537
+ vlabel = v[channelnumber or ch]
538
+ ax[ch].text(0.9 * self.rmax, 0.7 * self.rmax,
539
+ rf'${vlabel:.{veldigit}f}$', color='black',
540
+ backgroundcolor='white', zorder=20)
514
541
  self.fig = None if internalfig else fig
515
542
  self.ax = ax
516
543
  self.rowcol = nrows * ncols
@@ -519,55 +546,7 @@ class PlotAstroData(AstroFrame):
519
546
  self.bottomleft = nij2ch(np.arange(npages), nrows - 1, 0)
520
547
  self.channelnumber = channelnumber
521
548
  self.v = v
522
-
523
- def vskipfill(c: np.ndarray,
524
- v_in: np.ndarray | None = None
525
- ) -> np.ndarray:
526
- """Skip and fill channels with nan.
527
-
528
- Args:
529
- c (np.ndarray): 2D or 3D arrays.
530
- v_in (np.ndarray): 1D array.
531
-
532
- Returns:
533
- np.ndarray: 3D arrays skipped and filled with nan.
534
- """
535
- if np.ndim(c) == 2:
536
- d = np.full((nv, *np.shape(c)), c)
537
- elif np.ndim(c) == 3:
538
- if v_in is not None:
539
- dv_org = self.v[1] - self.v[0]
540
- dv_in = (v_in[1] - v_in[0]) * vskip
541
- k0 = np.argmin(np.abs(self.v - v_in[0]))
542
- k1 = np.argmin(np.abs(self.v - v_in[-1]))
543
- if np.abs(dv_in - dv_org) / dv_org < 0.01:
544
- d = c
545
- else:
546
- s = 'Velocity resolution mismatch (>1%).' \
547
- + ' The cube needs to be regridded' \
548
- + ' outside plotastrodata.'
549
- warnings.warn(s, UserWarning)
550
- n_valid = k1 - k0
551
- d = [None] * n_valid
552
- for k in range(n_valid):
553
- k_tmp = np.argmin(np.abs(v_in - self.v[k]))
554
- diffvel = np.abs(v_in[k_tmp] - self.v[k])
555
- nearby = diffvel < dv_org * 0.5
556
- d[k] = c[k_tmp] if nearby else c[0] * np.nan
557
- d = np.array(d)
558
- if k0 > 0:
559
- prenan = np.full((k0, *np.shape(d)[1:]), np.nan)
560
- d = np.append(prenan, d, axis=0)
561
- d = d[::vskip]
562
- else:
563
- print('c must be 2D or 3D.')
564
- return
565
- n = nchan if channelnumber is None else nv
566
- shape = (n - len(d), len(d[0]), len(d[0, 0]))
567
- postnan = np.full(shape, d[0] * np.nan)
568
- d = np.append(d, postnan, axis=0)
569
- return d
570
- self.vskipfill = vskipfill
549
+ self.vskipfill = _get_vskipfill(nv=nv, v_org=v, vskip=vskip)
571
550
 
572
551
  def _map_init(self, kw: dict) -> tuple:
573
552
  """
@@ -816,7 +795,9 @@ class PlotAstroData(AstroFrame):
816
795
  def _set_colorbar(self, mappable, ch: int, show_cbar: bool,
817
796
  cblabel: str, cbformat: str,
818
797
  cbticks: list | None, cbticklabels: list | None,
819
- cblocation: str, st: Stretcher):
798
+ cblocation: str,
799
+ cblabelfontsize: int, cbtickfontsize: int,
800
+ st: Stretcher):
820
801
  if not show_cbar:
821
802
  return
822
803
 
@@ -832,8 +813,8 @@ class PlotAstroData(AstroFrame):
832
813
  cax = plt.axes([0.88, 0.105, 0.015, 0.77])
833
814
  cb = fig.colorbar(mappable[ch], cax=cax, label=cblabel,
834
815
  format=cbformat)
835
- cb.ax.tick_params(labelsize=14)
836
- font = mpl.font_manager.FontProperties(size=16)
816
+ cb.ax.tick_params(labelsize=cbtickfontsize)
817
+ font = mpl.font_manager.FontProperties(size=cblabelfontsize)
837
818
  cb.ax.yaxis.label.set_font_properties(font)
838
819
  if cbticks is None and st.stretch == 'log':
839
820
  cbticks, cbticklabels = logcbticks(10**st.vmin, 10**st.vmax)
@@ -854,6 +835,8 @@ class PlotAstroData(AstroFrame):
854
835
  cbticks: list[float] | None = None,
855
836
  cbticklabels: list[str] | None = None,
856
837
  cblocation: str = 'right',
838
+ cblabelfontsize: int = 16,
839
+ cbtickfontsize: int = 14,
857
840
  **kwargs) -> None:
858
841
  """Use Axes.pcolormesh of matplotlib. kwargs must include the arguments of AstroData to specify the data to be plotted. kwargs may include the arguments for Stretcher (stretch, stretchscale, and stretchpower) to specify the stretch parameters. kwargs may include arguments of Beam; a dict of beam_kwargs specifies the beam patch in more detail. kwargs may include xskiip and yskip.
859
842
 
@@ -864,6 +847,8 @@ class PlotAstroData(AstroFrame):
864
847
  cbticks (list, optional): Ticks of colorbar. Defaults to None.
865
848
  cbticklabels (list, optional): Ticklabels of colorbar. Defaults to None.
866
849
  cblocation (str, optional): 'left', 'top', 'left', 'right'. Only for 2D images. Defaults to 'right'.
850
+ cblabelfontsize (int, optional): Fontsize for the colorbar label. This is independent of set_rcparams().
851
+ cbtickfontsize (int, optional): Fontsize for the colorbar ticks. This is independent of set_rcparams().
867
852
  """
868
853
  self._kw = {'cmap': 'cubehelix', 'alpha': 1,
869
854
  'edgecolors': 'none', 'zorder': 1,
@@ -890,7 +875,8 @@ class PlotAstroData(AstroFrame):
890
875
  p[ch] = pnow
891
876
  for ch in self.bottomleft:
892
877
  self._set_colorbar(p, ch, show_cbar, cblabel, cbformat,
893
- cbticks, cbticklabels, cblocation, st)
878
+ cbticks, cbticklabels, cblocation,
879
+ cblabelfontsize, cbtickfontsize, st)
894
880
 
895
881
  def add_contour(self,
896
882
  levels: list[float] = [-12, -6, -3, 3, 6, 12, 24, 48, 96, 192, 384],
@@ -1092,10 +1078,10 @@ class PlotAstroData(AstroFrame):
1092
1078
  center = f'{csplit[1]} {csplit[2]}'
1093
1079
  if on_min_scale := (self.rmax >= 60.0):
1094
1080
  # On a 5-second grid.
1095
- ra_s = np.floor(float(_get_sec(center, 0)) / 5) * 5
1081
+ ra_s = np.floor(float(get_sec(center, 0)) / 5) * 5
1096
1082
  dec_s = 0.0
1097
- ra = _get_hmdm(center, 'ra') + f'{ra_s:.1f}s'
1098
- dec = _get_hmdm(center, 'dec') + f'{dec_s:.1f}s'
1083
+ ra = get_hmdm(center, 'ra') + f'{ra_s:.1f}s'
1084
+ dec = get_hmdm(center, 'dec') + f'{dec_s:.1f}s'
1099
1085
  center = f'{ra} {dec}'
1100
1086
 
1101
1087
  def get_tickvalues(ticks: np.ndarray, mode: str, no_sec: bool
@@ -1104,7 +1090,7 @@ class PlotAstroData(AstroFrame):
1104
1090
  if mode == 'ra':
1105
1091
  xy.reverse()
1106
1092
  tickvalues = xy2coord(xy, center)
1107
- getter = _get_min if no_sec else _get_sec
1093
+ getter = get_min if no_sec else get_sec
1108
1094
  tickvalues = [getter(t, mode) for t in tickvalues] # str
1109
1095
  tickvalues = np.array(tickvalues, dtype=float)
1110
1096
  # 7-digit precision for practical use.
@@ -1117,12 +1103,14 @@ class PlotAstroData(AstroFrame):
1117
1103
  'dec': {'d': r'$^{\circ}$',
1118
1104
  'm': r'$^{\prime}$',
1119
1105
  's': r'.$\hspace{-0.4}^{\prime\prime}$'}}
1120
- cos_dec = np.cos(np.radians(coord2xy(center)[1]))
1106
+ dec_center = coord2xy(center)[1]
1107
+ sign_dec = np.sign(dec_center)
1108
+ cos_dec = np.cos(np.radians(dec_center))
1121
1109
  intgrid = np.array([-3, -2, -1, 0, 1, 2, 3])
1122
1110
  i_mid = (len(intgrid) - 1) // 2
1123
1111
 
1124
1112
  def makegrid(mode: str):
1125
- second = float(_get_sec(center, mode))
1113
+ second = float(get_sec(center, mode))
1126
1114
  no_sec = on_min_scale and (mode == 'dec')
1127
1115
  # gridwidth is a float like 2 x 10^order (arcsec).
1128
1116
  gridwidth, order = _get_gridwidth(mode, self.rmax)
@@ -1131,7 +1119,7 @@ class PlotAstroData(AstroFrame):
1131
1119
  rounded = round(second, ndigits=max(-order, -1))
1132
1120
  # Get a grid point closest to the input second.
1133
1121
  rounded = round(rounded / gridwidth) * gridwidth
1134
- factor = 15 * cos_dec if mode == 'ra' else 1
1122
+ factor = 15 * cos_dec if mode == 'ra' else sign_dec
1135
1123
  ticks = (intgrid * gridwidth - second + rounded) * factor
1136
1124
  ticksminor = np.linspace(ticks[0], ticks[-1], 6*nticksminor + 1)
1137
1125
  tickvalues = get_tickvalues(ticks, mode, no_sec)
@@ -1143,8 +1131,8 @@ class PlotAstroData(AstroFrame):
1143
1131
 
1144
1132
  xticks, xticksminor, xticklabels = makegrid('ra')
1145
1133
  yticks, yticksminor, yticklabels = makegrid('dec')
1146
- ra_hm = _get_hmdm(xy2coord([xticks[i_mid] / 3600., 0], center), 'ra')
1147
- dec_dm = _get_hmdm(xy2coord([0, yticks[i_mid] / 3600.], center), 'dec')
1134
+ ra_hm = get_hmdm(xy2coord([xticks[i_mid] / 3600., 0], center), 'ra')
1135
+ dec_dm = get_hmdm(xy2coord([0, yticks[i_mid] / 3600.], center), 'dec')
1148
1136
  if on_min_scale:
1149
1137
  dec_dm = dec_dm.split('d')[0] + 'd'
1150
1138
  ra_hm = ra_hm.translate(str.maketrans(units['ra']))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: plotastrodata
3
- Version: 1.8.17
3
+ Version: 1.9.0
4
4
  Summary: plotastrodata is a tool for astronomers to create figures from FITS files and perform fundamental data analyses with ease.
5
5
  Home-page: https://github.com/yusukeaso-astron/plotastrodata
6
6
  Download-URL: https://github.com/yusukeaso-astron/plotastrodata
File without changes
File without changes
File without changes
File without changes