plotastrodata 1.8.16__tar.gz → 1.8.18__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. {plotastrodata-1.8.16/plotastrodata.egg-info → plotastrodata-1.8.18}/PKG-INFO +1 -1
  2. {plotastrodata-1.8.16 → plotastrodata-1.8.18}/plotastrodata/__init__.py +1 -1
  3. {plotastrodata-1.8.16 → plotastrodata-1.8.18}/plotastrodata/analysis_utils.py +10 -9
  4. {plotastrodata-1.8.16 → plotastrodata-1.8.18}/plotastrodata/coord_utils.py +43 -0
  5. {plotastrodata-1.8.16 → plotastrodata-1.8.18}/plotastrodata/other_utils.py +112 -2
  6. {plotastrodata-1.8.16 → plotastrodata-1.8.18}/plotastrodata/plot_utils.py +168 -179
  7. {plotastrodata-1.8.16 → plotastrodata-1.8.18/plotastrodata.egg-info}/PKG-INFO +1 -1
  8. {plotastrodata-1.8.16 → plotastrodata-1.8.18}/LICENSE +0 -0
  9. {plotastrodata-1.8.16 → plotastrodata-1.8.18}/MANIFEST.in +0 -0
  10. {plotastrodata-1.8.16 → plotastrodata-1.8.18}/README.md +0 -0
  11. {plotastrodata-1.8.16 → plotastrodata-1.8.18}/plotastrodata/const_utils.py +0 -0
  12. {plotastrodata-1.8.16 → plotastrodata-1.8.18}/plotastrodata/ext_utils.py +0 -0
  13. {plotastrodata-1.8.16 → plotastrodata-1.8.18}/plotastrodata/fft_utils.py +0 -0
  14. {plotastrodata-1.8.16 → plotastrodata-1.8.18}/plotastrodata/fits_utils.py +0 -0
  15. {plotastrodata-1.8.16 → plotastrodata-1.8.18}/plotastrodata/fitting_utils.py +0 -0
  16. {plotastrodata-1.8.16 → plotastrodata-1.8.18}/plotastrodata/los_utils.py +0 -0
  17. {plotastrodata-1.8.16 → plotastrodata-1.8.18}/plotastrodata/matrix_utils.py +0 -0
  18. {plotastrodata-1.8.16 → plotastrodata-1.8.18}/plotastrodata/noise_utils.py +0 -0
  19. {plotastrodata-1.8.16 → plotastrodata-1.8.18}/plotastrodata.egg-info/SOURCES.txt +0 -0
  20. {plotastrodata-1.8.16 → plotastrodata-1.8.18}/plotastrodata.egg-info/dependency_links.txt +0 -0
  21. {plotastrodata-1.8.16 → plotastrodata-1.8.18}/plotastrodata.egg-info/not-zip-safe +0 -0
  22. {plotastrodata-1.8.16 → plotastrodata-1.8.18}/plotastrodata.egg-info/requires.txt +0 -0
  23. {plotastrodata-1.8.16 → plotastrodata-1.8.18}/plotastrodata.egg-info/top_level.txt +0 -0
  24. {plotastrodata-1.8.16 → plotastrodata-1.8.18}/setup.cfg +0 -0
  25. {plotastrodata-1.8.16 → plotastrodata-1.8.18}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: plotastrodata
3
- Version: 1.8.16
3
+ Version: 1.8.18
4
4
  Summary: plotastrodata is a tool for astronomers to create figures from FITS files and perform fundamental data analyses with ease.
5
5
  Home-page: https://github.com/yusukeaso-astron/plotastrodata
6
6
  Download-URL: https://github.com/yusukeaso-astron/plotastrodata
@@ -1,4 +1,4 @@
1
1
  import warnings
2
2
 
3
3
  warnings.simplefilter('ignore', FutureWarning)
4
- __version__ = '1.8.16'
4
+ __version__ = '1.8.18'
@@ -7,11 +7,12 @@ from scipy.signal import convolve
7
7
  from plotastrodata import const_utils as cu
8
8
  from plotastrodata.coord_utils import coord2xy, rel2abs, xy2coord
9
9
  from plotastrodata.fits_utils import data2fits, FitsData, Jy2K
10
- from plotastrodata.fitting_utils import (EmceeCorner, gaussian2d,
11
- gaussfit1d, gaussfit2d)
10
+ from plotastrodata.fitting_utils import (EmceeCorner, gaussfit1d,
11
+ gaussfit2d, gaussian2d)
12
12
  from plotastrodata.matrix_utils import dot2d, Mfac, Mrot
13
13
  from plotastrodata.noise_utils import estimate_rms
14
- from plotastrodata.other_utils import isdeg, RGIxy, RGIxyv, to4dim, trim
14
+ from plotastrodata.other_utils import (isdeg, nearest_index,
15
+ RGIxy, RGIxyv, to4dim, trim)
15
16
 
16
17
 
17
18
  def quadrantmean(data: np.ndarray, x: np.ndarray, y: np.ndarray,
@@ -197,10 +198,10 @@ class AstroData():
197
198
  includev (bool, optional): Centering in the v direction at each position. Defaults to False.
198
199
  """
199
200
  if includexy:
200
- xnew = self.x - self.x[np.argmin(np.abs(self.x))]
201
- ynew = self.y - self.y[np.argmin(np.abs(self.y))]
201
+ xnew = self.x - self.x[nearest_index(self.x)]
202
+ ynew = self.y - self.y[nearest_index(self.y)]
202
203
  if includev:
203
- vnew = self.v - self.v[np.argmin(np.abs(self.v))]
204
+ vnew = self.v - self.v[nearest_index(self.v)]
204
205
  if includexy and includev:
205
206
  self.data = RGIxyv(self.v, self.y, self.x, self.data,
206
207
  np.meshgrid(vnew, ynew, xnew, indexing='ij'),
@@ -518,10 +519,10 @@ class AstroData():
518
519
  """
519
520
  fhd = self.fitsheader
520
521
  h = {}
521
- ci = np.argmin(np.abs(self.x))
522
+ ci = nearest_index(self.x)
522
523
  cx = 0
523
524
  if not self.pv:
524
- cj = np.argmin(np.abs(self.y))
525
+ cj = nearest_index(self.y)
525
526
  if self.center is None:
526
527
  cx, cy = self.x[ci], self.x[cj]
527
528
  else:
@@ -538,7 +539,7 @@ class AstroData():
538
539
  h['CDELT1'] = float(self.dx / (3600 if indeg('1') else 1))
539
540
  if self.dv is not None:
540
541
  vaxis = '2' if self.pv else '3'
541
- ck = np.argmin(np.abs(self.v))
542
+ ck = nearest_index(self.v)
542
543
  cv = self.v[ck]
543
544
  dv = self.dv
544
545
  if self.restfreq is None or self.restfreq == 0:
@@ -149,3 +149,46 @@ def abs2rel(xabs: float, yabs: float,
149
149
  xrel = (xabs - x[0]) / (x[-1] - x[0])
150
150
  yrel = (yabs - y[0]) / (y[-1] - y[0])
151
151
  return np.array([xrel, yrel])
152
+
153
+
154
def get_sec(coord: str, mode: str) -> str:
    """Extract the seconds field from an hmsdms coordinate string.

    Args:
        coord (str): hmsdms string such as '04h31m38.4s +18d13m57.6s'.
        mode (str): 'ra' selects the first field; anything else selects the second (Dec).

    Returns:
        str: The seconds number as a string, without the trailing 's'.
    """
    axis = 0 if mode == 'ra' else 1
    fields = coord.split(' ')
    # Everything between the first 'm' and the next 'm' (if any).
    seconds_with_unit = fields[axis].split('m')[1]
    return seconds_with_unit.strip('s')
166
+
167
+
168
def get_min(coord: str, mode: str) -> str:
    """Extract the minutes field from an hmsdms coordinate string.

    Args:
        coord (str): hmsdms string such as '04h31m38.4s +18d13m57.6s'.
        mode (str): 'ra' reads the field after 'h' in the first token; anything else reads the field after 'd' in the second token.

    Returns:
        str: The minutes number as a string, without the unit.
    """
    is_ra = mode == 'ra'
    field = coord.split(' ')[0 if is_ra else 1]
    after_leading_unit = field.split('h' if is_ra else 'd')[1]
    return after_leading_unit.split('m')[0]
181
+
182
+
183
def get_hmdm(coord: str, mode: str) -> str:
    """Return the part of an hmsdms coordinate before the seconds field.

    Args:
        coord (str): hmsdms string such as '04h31m38.4s +18d13m57.6s'.
        mode (str): 'ra' uses the first token; anything else uses the second (Dec).

    Returns:
        str: The hours-minutes (R.A.) or degrees-minutes (Dec) string, ending with 'm'.
    """
    field = coord.split(' ')[int(mode != 'ra')]
    head = field.split('m')[0]
    return f'{head}m'
@@ -1,5 +1,6 @@
1
1
  import matplotlib.pyplot as plt
2
2
  import numpy as np
3
+ import warnings
3
4
  from scipy.interpolate import RegularGridInterpolator as RGI
4
5
 
5
6
 
@@ -33,6 +34,20 @@ def isdeg(s: str) -> bool:
33
34
  return False
34
35
 
35
36
 
37
def nearest_index(arr: np.ndarray, x: float = 0) -> int:
    """Get the index of a sorted array whose value is nearest to x.

    This is equivalent to np.argmin(np.abs(arr - x)), including its
    tie-breaking rule (the smaller index wins), but uses a binary search
    because arr is sorted. Both ascending and descending arrays are
    supported, e.g. an axis with a negative increment.

    Args:
        arr (np.ndarray): Sorted (ascending or descending) 1D array.
        x (float, optional): Value to approach. Defaults to 0.

    Returns:
        int: The index that gives a value nearest to x.

    Raises:
        ValueError: If arr is empty (as np.argmin does).
    """
    arr = np.asarray(arr)
    n = len(arr)
    if n == 0:
        raise ValueError('arr must not be empty.')
    if n == 1:
        return 0
    if arr[0] > arr[-1]:
        # np.searchsorted requires ascending order: search the reversed
        # array and map the result back to the original order.
        rev = arr[::-1]
        idx = min(max(int(np.searchsorted(rev, x)), 1), n - 1)
        # rev[idx] has the smaller original index, so it wins ties.
        j = idx - 1 if x - rev[idx - 1] < rev[idx] - x else idx
        return n - 1 - j
    # Clamp so that both neighbours arr[idx - 1] and arr[idx] exist.
    idx = min(max(int(np.searchsorted(arr, x)), 1), n - 1)
    return idx - 1 if x - arr[idx - 1] <= arr[idx] - x else idx
49
+
50
+
36
51
  def trim(data: np.ndarray | None = None, x: np.ndarray | None = None,
37
52
  y: np.ndarray | None = None, v: np.ndarray | None = None,
38
53
  xlim: list[float] | None = None,
@@ -57,8 +72,8 @@ def trim(data: np.ndarray | None = None, x: np.ndarray | None = None,
57
72
  def get_bounds(arr, lim):
58
73
  if arr is None or lim is None or None in lim:
59
74
  return arr, 0, None
60
- lo = np.argmin(np.abs(arr - max(np.min(arr), lim[0])))
61
- hi = np.argmin(np.abs(arr - min(np.max(arr), lim[1])))
75
+ lo = nearest_index(arr, max(np.min(arr), lim[0]))
76
+ hi = nearest_index(arr, min(np.max(arr), lim[1]))
62
77
  lo, hi = sorted((lo, hi))
63
78
  return arr[lo:hi + 1], lo, hi + 1
64
79
 
@@ -106,6 +121,101 @@ def to4dim(data: np.ndarray) -> np.ndarray:
106
121
  return d
107
122
 
108
123
 
124
+ def reform_grid(v: np.ndarray | None = None,
125
+ k0: int | None = None, k1: int | None = None,
126
+ vmin: float | None = None, vmax: float | None = None
127
+ ) -> np.ndarray:
128
+ """Extend or cut the given 1D array based on the given range.
129
+
130
+ Args:
131
+ v (np.ndarray | None, optional): Input 1D array. Defaults to None.
132
+ k0 (int | None, optional): How many channels are added before v[0]; the minus sign means extension. k0 has the priority over vmin. Defaults to None.
133
+ k1 (int | None, optional): How many channels are added after v[-1]; the plus sign means extension. k1 has the priority over vmax. Defaults to None.
134
+ vmin (float | None, optional): New minimum velocity. Defaults to None.
135
+ vmax (float | None, optional): New maximum velocity. Defaults to None.
136
+
137
+ Returns:
138
+ np.ndarray: Extended or cut 1D array.
139
+ """
140
+ if v is None or len(v) <= 1:
141
+ return v
142
+
143
+ dv = v[1] - v[0]
144
+ if k0 is None and vmin is not None:
145
+ k0 = int(round((vmin - v[0]) / dv))
146
+ if k0 is not None and k0 != 0:
147
+ if k0 < 0:
148
+ vpre = v[0] + dv * np.arange(k0, 0)
149
+ v = np.concatenate((vpre, v))
150
+ else:
151
+ v = v[k0:]
152
+ if k1 is None and vmax is not None:
153
+ k1 = int(round((vmax - v[-1]) / dv))
154
+ if k1 is not None and k1 != 0:
155
+ if k1 > 0:
156
+ vpost = v[-1] + dv * np.arange(1, k1 + 1)
157
+ v = np.concatenate((v, vpost))
158
+ else:
159
+ v = v[:len(v) + k1]
160
+ return v
161
+
162
+
163
+ def reform_data(c: np.ndarray, v_in: np.ndarray | None,
164
+ nv: int, v_org: np.ndarray | None = None,
165
+ vskip: int = 1) -> np.ndarray:
166
+ """Skip and fill channels with nan.
167
+
168
+ Args:
169
+ c (np.ndarray): The input 2D or 3D arrays.
170
+ v_in (np.ndarray): The input velocity 1D array.
171
+ nv (int): The number of channels with a label.
172
+ v (np.ndarray, optional): The velocity 1D array, including the channels with and without a label. Defaults to None.
173
+ vskip (int, optional): How many channels are skipped. Defaults to 1.
174
+
175
+ Returns:
176
+ np.ndarray: 3D arrays skipped and filled with nan.
177
+ """
178
+ if v_org is None:
179
+ return c
180
+
181
+ ndim = np.ndim(c)
182
+ if ndim not in [2, 3]:
183
+ print('c must be 2D or 3D.')
184
+ return
185
+
186
+ if ndim == 2:
187
+ d = np.full((nv, *np.shape(c)), c)
188
+ elif v_in is not None:
189
+ dv_org = v_org[1] - v_org[0]
190
+ dv_in = (v_in[1] - v_in[0]) * vskip
191
+ k0 = nearest_index(v_org, v_in[0])
192
+ k1 = nearest_index(v_org, v_in[-1])
193
+ if np.abs(dv_in - dv_org) / dv_org < 0.01:
194
+ d = c
195
+ else:
196
+ s = 'Velocity resolution mismatch (>1%).' \
197
+ + ' The cube needs to be regridded' \
198
+ + ' outside plotastrodata.'
199
+ warnings.warn(s, UserWarning)
200
+ n_valid = k1 - k0
201
+ d = [None] * n_valid
202
+ for k in range(n_valid):
203
+ k_tmp = nearest_index(v_in, v_org[k])
204
+ diffvel = np.abs(v_in[k_tmp] - v_org[k])
205
+ nearby = diffvel < dv_org * 0.5
206
+ d[k] = c[k_tmp] if nearby else c[0] * np.nan
207
+ d = np.array(d)
208
+ if k0 > 0:
209
+ prenan = np.full((k0, *np.shape(d)[1:]), np.nan)
210
+ d = np.concatenate((prenan, d))
211
+ d = d[::vskip]
212
+ shape = np.shape(d)
213
+ shape = (len(v_org) - shape[0], shape[1], shape[2])
214
+ postnan = np.full(shape, np.nan)
215
+ d = np.concatenate((d, postnan))
216
+ return d
217
+
218
+
109
219
  def RGIxy(y: np.ndarray, x: np.ndarray, data: np.ndarray,
110
220
  yxnew: tuple[np.ndarray, np.ndarray] | None = None,
111
221
  **kwargs) -> object | np.ndarray:
@@ -1,15 +1,16 @@
1
1
  import matplotlib as mpl
2
2
  import matplotlib.pyplot as plt
3
3
  import numpy as np
4
- import warnings
5
4
  from dataclasses import dataclass
6
5
  from matplotlib.patches import Ellipse, Rectangle
7
6
  from typing import TypeVar
8
7
 
9
8
  from plotastrodata.analysis_utils import AstroData, AstroFrame
10
- from plotastrodata.coord_utils import coord2xy, xy2coord
9
+ from plotastrodata.coord_utils import (coord2xy, xy2coord,
10
+ get_hmdm, get_min, get_sec)
11
11
  from plotastrodata.noise_utils import estimate_rms
12
- from plotastrodata.other_utils import close_figure, listing
12
+ from plotastrodata.other_utils import (close_figure, listing,
13
+ reform_grid, reform_data)
13
14
 
14
15
 
15
16
  plt.ioff() # force to turn off interactive mode
@@ -99,6 +100,89 @@ def logcbticks(vmin: float = 1e-3, vmax: float = 1e3
99
100
  return ticks[cond], ticklabels[cond]
100
101
 
101
102
 
103
+ def get_figsize(xmin: float, xmax: float, ymin: float, ymax: float,
104
+ figsize: tuple | None = None,
105
+ ncols: int = 1, nrows: int = 1, nchan: int = 1
106
+ ) -> tuple[float, float]:
107
+ """Get a nice figsize (tuple) with the given x and y ranges.
108
+
109
+ Args:
110
+ xmin (float): The figsize is based on the aspect ratio of (ymax - ymin) / (xmax - xmin).
111
+ xmax (float): The figsize is based on the aspect ratio of (ymax - ymin) / (xmax - xmin).
112
+ ymin (float): The figsize is based on the aspect ratio of (ymax - ymin) / (xmax - xmin).
113
+ ymax (float): The figsize is based on the aspect ratio of (ymax - ymin) / (xmax - xmin).
114
+ figsize (tuple | None, optional): If this is not None, this will be the output as is. Defaults to None.
115
+ ncols (int, optional): The number of columns for the channel map. Defaults to 1.
116
+ nrows (int, optional): The number of rows for the channel map. Defaults to 1.
117
+ nchan (int, optional): The number of total channels for the channel map. Defaults to 1.
118
+
119
+ Returns:
120
+ tuple[float, float]: figsize for matplotlib.pyplot.Figure.
121
+ """
122
+ if figsize is not None:
123
+ return figsize
124
+
125
+ sqrt_a = (ymax - ymin) / (xmax - xmin)
126
+ sqrt_a = np.sqrt(np.abs(sqrt_a))
127
+ if nchan == 1:
128
+ figsize = (7 / sqrt_a, 5 * sqrt_a)
129
+ else:
130
+ figsize = (ncols * 2 / sqrt_a, max(nrows*2, 3) * sqrt_a)
131
+ return figsize
132
+
133
+
134
def _get_gridwidth(mode: str, rmax: float) -> tuple[float, int]:
    """Choose a tick spacing of 1, 2 or 5 x 10^order for a half-width rmax.

    Args:
        mode (str): 'ra' uses an offset of 1.5 in log10; anything else (Dec) uses 0.5.
        rmax (float): Half-width of the plotted range (presumably arcsec — confirm at call site).

    Returns:
        tuple[float, int]: (grid width, order), where the width is base * 10**order with base in {1, 2, 5}.
    """
    # 10^1.5 / 15 ~ 2 grids for R.A.; 10^0.5 ~ 3 grids for Dec.
    offset = 1.5 if mode == 'ra' else 0.5
    exponent = np.log10(2. * rmax) - offset
    order = np.floor(exponent)
    mantissa = exponent - order
    base = 1 if mantissa <= 0.33 else (2 if mantissa <= 0.68 else 5)
    return base * 10**order, int(order)
148
+
149
+
150
+ def _get_v(p, v: np.ndarray | None = None,
151
+ restfreq: float | None = None,
152
+ vskip: int = 1) -> np.ndarray:
153
+ if p.fitsimage is not None:
154
+ p.read(d := AstroData(fitsimage=p.fitsimage,
155
+ restfreq=restfreq, sigma=None))
156
+ v = d.v
157
+ if v is None:
158
+ v = np.array([0])
159
+ if len(v) > 1:
160
+ v = reform_grid(v=v, vmin=p.vmin, vmax=p.vmax)
161
+ v = v[::vskip]
162
+ return v
163
+
164
+
165
def _get_nij2ch(nrows: int = 1, ncols: int = 1) -> object:
    """Return a function mapping (page n, row i, column j) to a flat channel index."""
    def nij2ch(n: int, i: int, j: int) -> int:
        # Row-major order within a page, pages of nrows * ncols panels.
        return (n * nrows + i) * ncols + j
    return nij2ch
169
+
170
+
171
def _get_ch2nij(nrows: int = 1, ncols: int = 1) -> object:
    """Return a function mapping a flat channel index to (page n, row i, column j)."""
    def ch2nij(ch: int) -> tuple[int, int, int]:
        page, within_page = divmod(ch, nrows * ncols)
        row, col = divmod(within_page, ncols)
        return page, row, col
    return ch2nij
178
+
179
+
180
+ def _get_vskipfill(nv: float, v_org: np.ndarray, vskip: int) -> object:
181
+ def vskipfill(c: np.ndarray, v_in: np.ndarray) -> np.ndarray:
182
+ return reform_data(c=c, v_in=v_in, nv=nv, v_org=v_org, vskip=vskip)
183
+ return vskipfill
184
+
185
+
102
186
  @dataclass
103
187
  class Stretcher():
104
188
  """Arguments and methods related to the stretch in PlotAstroData.add_color() and add_rgb().
@@ -383,7 +467,7 @@ class PlotAstroData(AstroFrame):
383
467
  kwargs is the arguments of AstroFrame to define plotting ranges.
384
468
 
385
469
  Args:
386
- v (np.ndarray, optional): Used to set up channels if fitsimage not given. Defaults to [0].
470
+ v (np.ndarray, optional): Used to set up channels if fitsimage not given. Defaults to None.
387
471
  vskip (int, optional): How many channels are skipped. Defaults to 1.
388
472
  veldigit (int, optional): How many digits after the decimal point. Defaults to 2.
389
473
  restfreq (float, optional): Used for velocity and brightness T. Defaults to None.
@@ -398,7 +482,7 @@ class PlotAstroData(AstroFrame):
398
482
  ax (optional): External fig.add_subplot(). Defaults to None.
399
483
  """
400
484
  def __init__(self,
401
- v: np.ndarray = np.array([0]), vskip: int = 1,
485
+ v: np.ndarray | None = None, vskip: int = 1,
402
486
  veldigit: int = 2, restfreq: float | None = None,
403
487
  channelnumber: int | None = None,
404
488
  nrows: int = 4, ncols: int = 6,
@@ -410,74 +494,42 @@ class PlotAstroData(AstroFrame):
410
494
  super().__init__(**kwargs)
411
495
  internalfig = fig is None
412
496
  internalax = ax is None
413
- if type(channelnumber) is int:
414
- nrows = ncols = 1
415
- if self.fitsimage is not None:
416
- self.read(d := AstroData(fitsimage=self.fitsimage,
417
- restfreq=restfreq, sigma=None))
418
- v = d.v
419
- if len(v) > 1:
420
- dv = v[1] - v[0]
421
- k0 = int(round((self.vmin - v[0]) / dv))
422
- if k0 < 0:
423
- vpre = v[0] - (1 + np.arange(-k0)[::-1]) * dv
424
- v = np.append(vpre, v)
425
- else:
426
- v = v[k0:]
427
- k1 = len(v) + int(round((self.vmax - v[-1]) / dv))
428
- if k1 > len(v):
429
- vpost = v[-1] + (1 + np.arange(k1 - len(v))) * dv
430
- v = np.append(v, vpost)
431
- else:
432
- v = v[:k1]
433
- if self.pv or v is None or len(v) == 1:
434
- nv = nrows = ncols = npages = nchan = 1
497
+ v = _get_v(p=self, v=v, restfreq=restfreq, vskip=vskip)
498
+ nv = len(v) # number of channels with a label
499
+ if self.pv or len(v) == 1 or channelnumber is not None:
500
+ nrows = ncols = npages = nchan = 1
435
501
  else:
436
- nv = len(v := v[::vskip])
437
502
  npages = int(np.ceil(nv / nrows / ncols))
438
503
  nchan = npages * nrows * ncols
439
- v = np.r_[v, v[-1] + (np.arange(nchan - nv) + 1) * dv]
440
- if type(channelnumber) is int:
441
- nchan = npages = 1
442
-
443
- def nij2ch(n: int, i: int, j: int):
444
- return n*nrows*ncols + i*ncols + j
445
-
446
- def ch2nij(ch: int) -> tuple:
447
- n = ch // (nrows*ncols)
448
- i = (ch - n*nrows*ncols) // ncols
449
- j = ch % ncols
450
- return n, i, j
451
-
504
+ v = reform_grid(v, k1=nchan - nv)
505
+ nij2ch = _get_nij2ch(nrows=nrows, ncols=ncols)
506
+ ch2nij = _get_ch2nij(nrows=nrows, ncols=ncols)
452
507
  if fontsize is None:
453
508
  fontsize = 18 if nchan == 1 else 12
454
509
  set_rcparams(fontsize=fontsize, nancolor=nancolor, dpi=dpi)
455
- ax = np.empty(nchan, dtype='object') if internalax else [ax]
456
- if figsize is None:
457
- sqrt_a = (self.ymax - self.ymin) / (self.xmax - self.xmin)
458
- sqrt_a = np.sqrt(np.abs(sqrt_a))
459
- if nchan == 1:
460
- figsize = (7 / sqrt_a, 5 * sqrt_a)
461
- else:
462
- figsize = (ncols * 2 / sqrt_a, max(nrows*2, 3) * sqrt_a)
510
+ ax = np.empty(nchan, dtype=object) if internalax else [ax]
511
+ figsize = get_figsize(xmin=self.xmin, xmax=self.xmax,
512
+ ymin=self.ymin, ymax=self.ymax,
513
+ figsize=figsize,
514
+ ncols=ncols, nrows=nrows, nchan=nchan)
515
+ need_vlabel = nchan > 1 or type(channelnumber) is int
463
516
  for ch in range(nchan):
464
517
  n, i, j = ch2nij(ch)
465
518
  if internalfig and n not in plt.get_fignums():
466
519
  fig = plt.figure(n, figsize=figsize)
467
- sharex = ax[nij2ch(n, i - 1, j)] if i > 0 else None
468
- sharey = ax[nij2ch(n, i, j - 1)] if j > 0 else None
520
+ if need_vlabel:
521
+ fig.subplots_adjust(hspace=0, wspace=0,
522
+ right=0.87, top=0.87)
469
523
  if internalax:
524
+ sharex = ax[nij2ch(n, i - 1, j)] if i > 0 else None
525
+ sharey = ax[nij2ch(n, i, j - 1)] if j > 0 else None
470
526
  ax[ch] = fig.add_subplot(nrows, ncols, i*ncols + j + 1,
471
527
  sharex=sharex, sharey=sharey)
472
- if nchan > 1 or type(channelnumber) is int:
473
- fig.subplots_adjust(hspace=0, wspace=0, right=0.87, top=0.87)
474
- if ch < nv:
475
- chnum = channelnumber
476
- vellabel = v[ch if chnum is None else chnum]
477
- vd = f'{veldigit:d}'
478
- ax[ch].text(0.9 * self.rmax, 0.7 * self.rmax,
479
- rf'${vellabel:.{vd}f}$', color='black',
480
- backgroundcolor='white', zorder=20)
528
+ if need_vlabel and ch < nv:
529
+ vlabel = v[channelnumber or ch]
530
+ ax[ch].text(0.9 * self.rmax, 0.7 * self.rmax,
531
+ rf'${vlabel:.{veldigit}f}$', color='black',
532
+ backgroundcolor='white', zorder=20)
481
533
  self.fig = None if internalfig else fig
482
534
  self.ax = ax
483
535
  self.rowcol = nrows * ncols
@@ -486,55 +538,7 @@ class PlotAstroData(AstroFrame):
486
538
  self.bottomleft = nij2ch(np.arange(npages), nrows - 1, 0)
487
539
  self.channelnumber = channelnumber
488
540
  self.v = v
489
-
490
- def vskipfill(c: np.ndarray,
491
- v_in: np.ndarray | None = None
492
- ) -> np.ndarray:
493
- """Skip and fill channels with nan.
494
-
495
- Args:
496
- c (np.ndarray): 2D or 3D arrays.
497
- v_in (np.ndarray): 1D array.
498
-
499
- Returns:
500
- np.ndarray: 3D arrays skipped and filled with nan.
501
- """
502
- if np.ndim(c) == 2:
503
- d = np.full((nv, *np.shape(c)), c)
504
- elif np.ndim(c) == 3:
505
- if v_in is not None:
506
- dv_org = self.v[1] - self.v[0]
507
- dv_in = (v_in[1] - v_in[0]) * vskip
508
- k0 = np.argmin(np.abs(self.v - v_in[0]))
509
- k1 = np.argmin(np.abs(self.v - v_in[-1]))
510
- if np.abs(dv_in - dv_org) / dv_org < 0.01:
511
- d = c
512
- else:
513
- s = 'Velocity resolution mismatch (>1%).' \
514
- + ' The cube needs to be regridded' \
515
- + ' outside plotastrodata.'
516
- warnings.warn(s, UserWarning)
517
- n_valid = k1 - k0
518
- d = [None] * n_valid
519
- for k in range(n_valid):
520
- k_tmp = np.argmin(np.abs(v_in - self.v[k]))
521
- diffvel = np.abs(v_in[k_tmp] - self.v[k])
522
- nearby = diffvel < dv_org * 0.5
523
- d[k] = c[k_tmp] if nearby else c[0] * np.nan
524
- d = np.array(d)
525
- if k0 > 0:
526
- prenan = np.full((k0, *np.shape(d)[1:]), np.nan)
527
- d = np.append(prenan, d, axis=0)
528
- d = d[::vskip]
529
- else:
530
- print('c must be 2D or 3D.')
531
- return
532
- n = nchan if channelnumber is None else nv
533
- shape = (n - len(d), len(d[0]), len(d[0, 0]))
534
- postnan = np.full(shape, d[0] * np.nan)
535
- d = np.append(d, postnan, axis=0)
536
- return d
537
- self.vskipfill = vskipfill
541
+ self.vskipfill = _get_vskipfill(nv=nv, v_org=v, vskip=vskip)
538
542
 
539
543
  def _map_init(self, kw: dict) -> tuple:
540
544
  """
@@ -1057,86 +1061,71 @@ class PlotAstroData(AstroFrame):
1057
1061
  center = '00h00m00s 00d00m00s'
1058
1062
  if len(csplit := center.split()) == 3:
1059
1063
  center = f'{csplit[1]} {csplit[2]}'
1060
-
1061
- def get_sec(x, i):
1062
- return x.split(' ')[i].split('m')[1].strip('s')
1063
-
1064
- def get_min(x, i):
1065
- s = 'h' if i == 0 else 'd'
1066
- return x.split(' ')[i].split(s)[1].split('m')[0]
1067
-
1068
- def get_hmdm(x, i):
1069
- return x.split(' ')[i].split('m')[0] + 'm'
1070
-
1071
- on_min_scale = self.rmax >= 60.0
1072
- if on_min_scale:
1064
+ if on_min_scale := (self.rmax >= 60.0):
1065
+ # On a 5-second grid.
1073
1066
  ra_s = np.floor(float(get_sec(center, 0)) / 5) * 5
1074
1067
  dec_s = 0.0
1075
- ra = get_hmdm(center, 0) + f'{ra_s:.1f}s'
1076
- dec = get_hmdm(center, 1) + f'{dec_s:.1f}s'
1068
+ ra = get_hmdm(center, 'ra') + f'{ra_s:.1f}s'
1069
+ dec = get_hmdm(center, 'dec') + f'{dec_s:.1f}s'
1077
1070
  center = f'{ra} {dec}'
1078
1071
 
1079
- dec = np.radians(coord2xy(center)[1])
1080
- log2r = np.log10(2. * self.rmax)
1081
- n = np.array([-3, -2, -1, 0, 1, 2, 3])
1082
-
1083
- def makegrid(second, mode):
1084
- second = float(second)
1085
- is_dec = mode == 'dec'
1086
- scale = 0.5 if is_dec else 1.5
1087
- factor = 1 if is_dec else 15 * np.cos(dec)
1088
- no_sec = on_min_scale and is_dec
1089
- if no_sec:
1090
- unit = r'$^{\prime}$' if is_dec else r'$^\mathrm{m}$'
1091
- else:
1092
- unit = r'$^{\prime\prime}$' if is_dec else r'$^\mathrm{s}$'
1093
- unit = r'.$\hspace{-0.4}$' + unit
1094
- dorder = log2r - scale - (order := np.floor(log2r - scale))
1095
- if 0.00 < dorder <= 0.33:
1096
- g = 1
1097
- elif 0.33 < dorder <= 0.68:
1098
- g = 2
1099
- elif 0.68 < dorder <= 1.00:
1100
- g = 5
1101
- g *= 10**order
1102
- decimals = max(-int(order), -1)
1103
- rounded = round(second, decimals)
1104
- lastdigit = round(rounded // 10**(-decimals-1) % 100 / 10) % 10
1105
- rounded -= lastdigit * 10**(-decimals) % g
1106
- ticks = (n*g - second + rounded) * factor
1107
- ticksminor = np.linspace(ticks[0], ticks[-1], 6*nticksminor + 1)
1108
- decimals = max(decimals, 0)
1109
- decimals = f'{decimals:d}'
1072
+ def get_tickvalues(ticks: np.ndarray, mode: str, no_sec: bool
1073
+ ) -> np.ndarray:
1074
+ xy = [np.zeros_like(ticks), ticks / 3600.]
1110
1075
  if mode == 'ra':
1111
- xy, i = [ticks / 3600., ticks * 0], 0
1112
- else:
1113
- xy, i = [ticks * 0, ticks / 3600.], 1
1076
+ xy.reverse()
1114
1077
  tickvalues = xy2coord(xy, center)
1115
- _get = get_min if no_sec else get_sec
1116
- tickvalues = np.array([float(_get(t, i)) for t in tickvalues])
1117
- tickvalues = np.divmod(tickvalues + 1e-7, 1)
1118
- tickvalues = (tickvalues[0] % 60, tickvalues[1])
1119
- ticklabels = [f'{int(i):02d}{unit}' + f'{j:.{decimals}f}'[2:]
1120
- for i, j in zip(*tickvalues)]
1078
+ getter = get_min if no_sec else get_sec
1079
+ tickvalues = [getter(t, mode) for t in tickvalues] # str
1080
+ tickvalues = np.array(tickvalues, dtype=float)
1081
+ # 7-digit precision for practical use.
1082
+ tickvalues = np.round(tickvalues, 7)
1083
+ return tickvalues
1084
+
1085
+ units = {'ra': {'h': r'$^\mathrm{h}$',
1086
+ 'm': r'$^\mathrm{m}$',
1087
+ 's': r'.$\hspace{-0.4}^\mathrm{s}$'},
1088
+ 'dec': {'d': r'$^{\circ}$',
1089
+ 'm': r'$^{\prime}$',
1090
+ 's': r'.$\hspace{-0.4}^{\prime\prime}$'}}
1091
+ cos_dec = np.cos(np.radians(coord2xy(center)[1]))
1092
+ intgrid = np.array([-3, -2, -1, 0, 1, 2, 3])
1093
+ i_mid = (len(intgrid) - 1) // 2
1094
+
1095
+ def makegrid(mode: str):
1096
+ second = float(get_sec(center, mode))
1097
+ no_sec = on_min_scale and (mode == 'dec')
1098
+ # gridwidth is a float like 2 x 10^order (arcsec).
1099
+ gridwidth, order = _get_gridwidth(mode, self.rmax)
1100
+ # ndigits = -1 is the largest case for 10", 20", ...
1101
+ decimals = str(max(-order, 0))
1102
+ rounded = round(second, ndigits=max(-order, -1))
1103
+ # Get a grid point closest to the input second.
1104
+ rounded = round(rounded / gridwidth) * gridwidth
1105
+ factor = 15 * cos_dec if mode == 'ra' else 1
1106
+ ticks = (intgrid * gridwidth - second + rounded) * factor
1107
+ ticksminor = np.linspace(ticks[0], ticks[-1], 6*nticksminor + 1)
1108
+ tickvalues = get_tickvalues(ticks, mode, no_sec)
1109
+ whole, frac = np.divmod(tickvalues, 1)
1110
+ u = units[mode]['m' if no_sec else 's']
1111
+ ticklabels = [f'{int(i):02d}{u}' + f'{j:.{decimals}f}'[2:]
1112
+ for i, j in zip(whole % 60, frac)]
1121
1113
  return ticks, ticksminor, ticklabels
1122
1114
 
1123
- ra_s = get_sec(center, 0)
1124
- dec_s = get_sec(center, 1)
1125
- xticks, xticksminor, xticklabels = makegrid(ra_s, 'ra')
1126
- yticks, yticksminor, yticklabels = makegrid(dec_s, 'dec')
1127
- ra_hm = get_hmdm(xy2coord([xticks[3] / 3600., 0], center), 0)
1128
- dec_dm = get_hmdm(xy2coord([0, yticks[3] / 3600.], center), 1)
1115
+ xticks, xticksminor, xticklabels = makegrid('ra')
1116
+ yticks, yticksminor, yticklabels = makegrid('dec')
1117
+ ra_hm = get_hmdm(xy2coord([xticks[i_mid] / 3600., 0], center), 'ra')
1118
+ dec_dm = get_hmdm(xy2coord([0, yticks[i_mid] / 3600.], center), 'dec')
1129
1119
  if on_min_scale:
1130
1120
  dec_dm = dec_dm.split('d')[0] + 'd'
1131
- trans = {'h': r'$^\mathrm{h}$', 'm': r'$^\mathrm{m}$'}
1132
- ra_hm = ra_hm.translate(str.maketrans(trans))
1133
- trans = {'d': r'$^{\circ}$', 'm': r'$^{\prime}$'}
1134
- dec_dm = dec_dm.translate(str.maketrans(trans))
1135
- xticklabels[3] = ra_hm + xticklabels[3]
1136
- yticklabels[3] = dec_dm + '\n' + yticklabels[3]
1137
- pa2 = PlotAxes2D(True, None, 'linear', 'linear', self.Xlim, self.Ylim,
1138
- xlabel, ylabel, xticks, yticks, xticklabels,
1139
- yticklabels, xticksminor, yticksminor, grid)
1121
+ ra_hm = ra_hm.translate(str.maketrans(units['ra']))
1122
+ dec_dm = dec_dm.translate(str.maketrans(units['dec']))
1123
+ xticklabels[i_mid] = ra_hm + xticklabels[i_mid]
1124
+ yticklabels[i_mid] = dec_dm + '\n' + yticklabels[i_mid]
1125
+ pa2 = PlotAxes2D(True, None, 'linear', 'linear',
1126
+ self.Xlim, self.Ylim, xlabel, ylabel,
1127
+ xticks, yticks, xticklabels, yticklabels,
1128
+ xticksminor, yticksminor, grid)
1140
1129
  self._set_axis_shared(pa2=pa2, title=title)
1141
1130
 
1142
1131
  def savefig(self, filename: str | None = None,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: plotastrodata
3
- Version: 1.8.16
3
+ Version: 1.8.18
4
4
  Summary: plotastrodata is a tool for astronomers to create figures from FITS files and perform fundamental data analyses with ease.
5
5
  Home-page: https://github.com/yusukeaso-astron/plotastrodata
6
6
  Download-URL: https://github.com/yusukeaso-astron/plotastrodata
File without changes
File without changes
File without changes
File without changes