plotastrodata 1.9.2__tar.gz → 1.9.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. {plotastrodata-1.9.2/plotastrodata.egg-info → plotastrodata-1.9.4}/PKG-INFO +1 -1
  2. {plotastrodata-1.9.2 → plotastrodata-1.9.4}/plotastrodata/__init__.py +1 -1
  3. {plotastrodata-1.9.2 → plotastrodata-1.9.4}/plotastrodata/analysis_utils.py +11 -4
  4. {plotastrodata-1.9.2 → plotastrodata-1.9.4}/plotastrodata/fitting_utils.py +101 -67
  5. {plotastrodata-1.9.2 → plotastrodata-1.9.4/plotastrodata.egg-info}/PKG-INFO +1 -1
  6. {plotastrodata-1.9.2 → plotastrodata-1.9.4}/LICENSE +0 -0
  7. {plotastrodata-1.9.2 → plotastrodata-1.9.4}/MANIFEST.in +0 -0
  8. {plotastrodata-1.9.2 → plotastrodata-1.9.4}/README.md +0 -0
  9. {plotastrodata-1.9.2 → plotastrodata-1.9.4}/plotastrodata/const_utils.py +0 -0
  10. {plotastrodata-1.9.2 → plotastrodata-1.9.4}/plotastrodata/coord_utils.py +0 -0
  11. {plotastrodata-1.9.2 → plotastrodata-1.9.4}/plotastrodata/ext_utils.py +0 -0
  12. {plotastrodata-1.9.2 → plotastrodata-1.9.4}/plotastrodata/fft_utils.py +0 -0
  13. {plotastrodata-1.9.2 → plotastrodata-1.9.4}/plotastrodata/fits_utils.py +0 -0
  14. {plotastrodata-1.9.2 → plotastrodata-1.9.4}/plotastrodata/los_utils.py +0 -0
  15. {plotastrodata-1.9.2 → plotastrodata-1.9.4}/plotastrodata/matrix_utils.py +0 -0
  16. {plotastrodata-1.9.2 → plotastrodata-1.9.4}/plotastrodata/noise_utils.py +0 -0
  17. {plotastrodata-1.9.2 → plotastrodata-1.9.4}/plotastrodata/other_utils.py +0 -0
  18. {plotastrodata-1.9.2 → plotastrodata-1.9.4}/plotastrodata/plot_utils.py +0 -0
  19. {plotastrodata-1.9.2 → plotastrodata-1.9.4}/plotastrodata.egg-info/SOURCES.txt +0 -0
  20. {plotastrodata-1.9.2 → plotastrodata-1.9.4}/plotastrodata.egg-info/dependency_links.txt +0 -0
  21. {plotastrodata-1.9.2 → plotastrodata-1.9.4}/plotastrodata.egg-info/not-zip-safe +0 -0
  22. {plotastrodata-1.9.2 → plotastrodata-1.9.4}/plotastrodata.egg-info/requires.txt +0 -0
  23. {plotastrodata-1.9.2 → plotastrodata-1.9.4}/plotastrodata.egg-info/top_level.txt +0 -0
  24. {plotastrodata-1.9.2 → plotastrodata-1.9.4}/setup.cfg +0 -0
  25. {plotastrodata-1.9.2 → plotastrodata-1.9.4}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: plotastrodata
3
- Version: 1.9.2
3
+ Version: 1.9.4
4
4
  Summary: plotastrodata is a tool for astronomers to create figures from FITS files and perform fundamental data analyses with ease.
5
5
  Home-page: https://github.com/yusukeaso-astron/plotastrodata
6
6
  Download-URL: https://github.com/yusukeaso-astron/plotastrodata
@@ -1,4 +1,4 @@
1
1
  import warnings
2
2
 
3
3
  warnings.simplefilter('ignore', FutureWarning)
4
- __version__ = '1.9.2'
4
+ __version__ = '1.9.4'
@@ -147,6 +147,9 @@ class AstroData():
147
147
  width (list, optional): Number of channels, y-pixels, and x-pixels for binning. Defaults to [1, 1, 1].
148
148
  """
149
149
  w = [1] * (4 - len(width)) + list(width)
150
+ if self.pv:
151
+ w[2] = max(w[1], w[2])
152
+ w[1] = 1
150
153
  d = to4dim(self.data)
151
154
  size = np.array(np.shape(d))
152
155
  w = np.array(w, dtype=int)
@@ -155,11 +158,12 @@ class AstroData():
155
158
  ws = ', '.join([f'{s:d}' for s in w[1:]])
156
159
  print(f'width was changed to [{ws}].')
157
160
  newsize = size // w
158
- if w[1] > 1:
159
- print(f'sigma has been divided by sqrt({w[1]:d})'
161
+ if (not self.pv and w[1] > 1) or (self.pv and w[2] > 1):
162
+ width_v = w[2] if self.pv else w[1]
163
+ print(f'sigma has been divided by sqrt({width_v:d})'
160
164
  + ' because of binning in the v-axis.')
161
- self.sigma = self.sigma / np.sqrt(w[1])
162
- if w[2] > 1 or w[3] > 1:
165
+ self.sigma = self.sigma / np.sqrt(width_v)
166
+ if (not self.pv and w[2] > 1) or w[3] > 1:
163
167
  print('Binning in the x- or y-axis does not update sigma.')
164
168
  grid = [None, self.v, self.y, self.x]
165
169
  dgrid = [None, self.dv, self.dy, self.dx]
@@ -187,6 +191,9 @@ class AstroData():
187
191
  self.data = np.squeeze(d)
188
192
  _, self.v, self.y, self.x = grid
189
193
  _, self.dv, self.dy, self.dx = dgrid
194
+ if self.pv:
195
+ self.v = self.y
196
+ self.dv = self.dy
190
197
 
191
198
  def centering(self, includexy: bool = True,
192
199
  includev: bool = False,
@@ -36,7 +36,7 @@ def logp(x: np.ndarray) -> float:
36
36
 
37
37
  def _get_GR(samples: np.ndarray, nwalkers: int, ndata: int, dim: int
38
38
  ) -> np.ndarray:
39
- # Gelman-Rubin statistics #
39
+ """Calculate the Gelman-Rubin statistics."""
40
40
  B = np.std(np.mean(samples, axis=1), axis=0)
41
41
  W = np.mean(np.std(samples, axis=1), axis=0)
42
42
  V = (len(samples[0]) - 1) / len(samples[0]) * W \
@@ -46,6 +46,20 @@ def _get_GR(samples: np.ndarray, nwalkers: int, ndata: int, dim: int
46
46
  return GR
47
47
 
48
48
 
49
+ def _check_GR(samples: np.ndarray, nwalkers: int, ndata: int, dim: int,
50
+ i: int, ntry: int = 1, grcheck: bool = False) -> int:
51
+ if not grcheck:
52
+ return ntry
53
+
54
+ GR = _get_GR(samples=samples, nwalkers=nwalkers, ndata=ndata, dim=dim)
55
+ if np.max(GR) <= 1.25:
56
+ return ntry
57
+
58
+ if i == ntry:
59
+ print(f'!!! Max GR >1.25 during {ntry:d} trials.!!!')
60
+ return i
61
+
62
+
49
63
  class EmceeCorner():
50
64
  warnings.simplefilter('ignore', RuntimeWarning)
51
65
 
@@ -87,6 +101,72 @@ class EmceeCorner():
87
101
  self.percent = percent
88
102
  self.ndata = 10000 if xdata is None else len(xdata)
89
103
 
104
+ def _get_pos0(self, ntemps: int, nwalkers: int, pt: bool) -> np.ndarray:
105
+ """Create initial walker positions within parameter bounds."""
106
+ lower = self.bounds[:, 0]
107
+ upper = self.bounds[:, 1]
108
+ width = upper - lower
109
+ pos0 = np.random.rand(ntemps, nwalkers, self.dim) * width + lower
110
+ return pos0 if pt else pos0[0]
111
+
112
def _run_sampler(self, pos0: np.ndarray, pt: bool,
                 ncores: int, ntemps: int,
                 nsteps: int, nwalkers: int) -> object:
    """Create the sampler, run it for ``nsteps`` steps, and return it.

    Args:
        pos0 (np.ndarray): Initial walker positions.
        pt (bool): If True, use ``ptemcee.Sampler`` (parallel tempering);
            otherwise use ``emcee.EnsembleSampler``.
        ncores (int): Number of processes. Values > 1 run the sampler
            with a ``multiprocessing.Pool``.
        ntemps (int): Number of temperatures (used only when ``pt``).
        nsteps (int): Number of MCMC steps to run.
        nwalkers (int): Number of walkers.

    Returns:
        object: The sampler after ``run_mcmc`` has completed.

    Note:
        Without parallel tempering and with ``ncores > 1``, only
        ``self.logl`` is sampled (``self.logp`` is not added), so that the
        log-probability function can be pickled for the worker pool.
    """
    if pt:
        sampler_cls = ptemcee.Sampler
        sampler_kwargs = {'ntemps': ntemps,
                          'nwalkers': nwalkers, 'dim': self.dim,
                          'logl': self.logl, 'logp': self.logp}
    else:
        if ncores > 1:
            print('Use logl as log_prob_fn to avoid function-in-function.')
            log_prob_fn = self.logl
        else:
            def log_prob_fn(x):
                return self.logp(x) + self.logl(x)

        sampler_cls = emcee.EnsembleSampler
        sampler_kwargs = {'nwalkers': nwalkers, 'ndim': self.dim,
                          'log_prob_fn': log_prob_fn}
    if ncores > 1:
        # run_mcmc must execute while the pool is still alive:
        # Pool.__exit__ terminates the workers, so running the chain after
        # leaving the with-block would use a dead pool.
        with Pool(ncores) as pool:
            sampler = sampler_cls(**sampler_kwargs, pool=pool)
            sampler.run_mcmc(pos0, nsteps)
    else:
        sampler = sampler_cls(**sampler_kwargs, pool=None)
        sampler.run_mcmc(pos0, nsteps)
    return sampler
139
+
140
+ def _get_samples(self, sampler, nburnin: int, pt: bool) -> np.ndarray:
141
+ """Extract post-burn-in samples from sampler chain."""
142
+ if pt:
143
+ return sampler.chain[0, :, nburnin:, :] # temperatures, walkers, steps, dim
144
+ else:
145
+ return sampler.chain[:, nburnin:, :] # walkers, steps, dim
146
+
147
+ def _get_lnp_popt(self, sampler, pt: bool, nburnin: int,
148
+ ) -> tuple[np.ndarray, np.ndarray]:
149
+ """Get log probabilities and best-fit parameters from sampler."""
150
+ if pt:
151
+ lnp = sampler.logprobability[0] # 0th temperature chain
152
+ chain = sampler.chain[0]
153
+ else:
154
+ lnp = sampler.lnprobability
155
+ chain = sampler.chain
156
+ idx_best = np.unravel_index(np.argmax(lnp), lnp.shape)
157
+ popt = chain[idx_best]
158
+ lnp = lnp[:, nburnin:]
159
+ return lnp, popt
160
+
161
+ def _get_percentiles(self, samples: np.ndarray
162
+ ) -> tuple[float, float, float]:
163
+ """Compute summary statistics (percentiles) from MCMC samples."""
164
+ s = samples.reshape(-1, self.dim)
165
+ plow = np.percentile(s, self.percent[0], axis=0)
166
+ pmid = np.percentile(s, 50, axis=0)
167
+ phigh = np.percentile(s, self.percent[1], axis=0)
168
+ return plow, pmid, phigh
169
+
90
170
  def fit(self, nwalkersperdim: int = 2,
91
171
  ntemps: int = 1, nsteps: int = 1000,
92
172
  nburnin: int = 500, ntry: int = 1,
@@ -112,79 +192,33 @@ class EmceeCorner():
112
192
  print('nwalkersperdim < 2 is not allowed.'
113
193
  + f' Use 2 instead of {nwalkersperdim:d}.')
114
194
  nwalkers = max(nwalkersperdim, 2) * self.dim # must be even and >= 2 * dim
115
- if ntemps > 1 and not pt:
116
- print('ntemps>1 is supported only with pt=True. Set pt=True.')
195
+ if ntemps > 1:
117
196
  pt = True
118
197
  if global_progressbar:
119
- bar = tqdm(total=ntry * ntemps * nwalkers * (nsteps + 1) // ncores)
198
+ total = ntry * ntemps * nwalkers * (nsteps + 1) // ncores
199
+ bar = tqdm(total=total)
120
200
  bar.set_description('Within the ranges')
121
-
122
- GR = [2] * self.dim
123
- i = 0
124
- while np.max(GR) > 1.25 and i < ntry:
125
- i += 1
201
+ samples = None
202
+ sampler = None
203
+ for i in range(1, ntry + 1):
126
204
  if pos0 is None:
127
- pos0 = np.random.rand(ntemps, nwalkers, self.dim)
128
- pos0 = pos0 * (self.bounds[:, 1] - self.bounds[:, 0])
129
- pos0 = pos0 + self.bounds[:, 0]
130
- if not pt:
131
- pos0 = pos0[0]
132
- if pt:
133
- pars = {'ntemps': ntemps, 'nwalkers': nwalkers, 'dim': self.dim,
134
- 'logl': self.logl, 'logp': self.logp}
135
- if ncores > 1:
136
- with Pool(ncores) as pool:
137
- sampler = ptemcee.Sampler(**pars, pool=pool)
138
- sampler.run_mcmc(pos0, nsteps)
139
- else:
140
- sampler = ptemcee.Sampler(**pars)
141
- sampler.run_mcmc(pos0, nsteps)
142
- samples = sampler.chain[0, :, nburnin:, :] # temperatures, walkers, steps, dim
143
- else:
144
- if ncores > 1:
145
- print('Use logl as log_prob_fn to avoid'
146
- + ' function-in-function.')
147
- log_prob_fn = self.logl
148
- else:
149
- def log_prob_fn(x):
150
- return self.logp(x) + self.logl(x)
151
-
152
- pars = {'nwalkers': nwalkers, 'ndim': self.dim,
153
- 'log_prob_fn': log_prob_fn}
154
- if ncores > 1:
155
- with Pool(ncores) as pool:
156
- sampler = emcee.EnsembleSampler(**pars, pool=pool)
157
- sampler.run_mcmc(pos0, nsteps)
158
- else:
159
- sampler = emcee.EnsembleSampler(**pars)
160
- sampler.run_mcmc(pos0, nsteps)
161
- samples = sampler.chain[:, nburnin:, :] # walkers, steps, dim
162
- if grcheck:
163
- GR = _get_GR(samples=samples, nwalkers=nwalkers,
164
- ndata=self.ndata, dim=self.dim)
165
- else:
166
- GR = np.zeros(self.dim)
167
- if i == ntry - 1 and np.max(GR) > 1.25:
168
- print(f'!!! Max GR >1.25 during {ntry:d} trials.!!!')
169
-
170
- self.samples = samples
205
+ pos0 = self._get_pos0(ntemps=ntemps, nwalkers=nwalkers, pt=pt)
206
+ sampler = self._run_sampler(pos0=pos0, pt=pt, ncores=ncores,
207
+ ntemps=ntemps, nsteps=nsteps,
208
+ nwalkers=nwalkers)
209
+ samples = self._get_samples(sampler=sampler,
210
+ nburnin=nburnin, pt=pt)
211
+ i = _check_GR(samples=samples, nwalkers=nwalkers,
212
+ ndata=self.ndata, dim=self.dim,
213
+ i=i, ntry=ntry, grcheck=grcheck)
171
214
  if savechain is not None:
172
215
  np.save(savechain.removesuffix('.npy') + '.npy', samples)
173
- if pt:
174
- lnps = sampler.logprobability[0] # [0] is in the temperature axis.
175
- idx_best = np.unravel_index(np.argmax(lnps), lnps.shape)
176
- self.popt = sampler.chain[0][idx_best] # [0] is in the temperature axis.
177
- else:
178
- lnps = sampler.lnprobability
179
- idx_best = np.unravel_index(np.argmax(lnps), lnps.shape)
180
- self.popt = sampler.chain[idx_best]
181
- self.lnps = lnps[:, nburnin:]
182
- s = samples.reshape((-1, self.dim))
183
- self.plow = np.percentile(s, self.percent[0], axis=0)
184
- self.pmid = np.percentile(s, 50, axis=0)
185
- self.phigh = np.percentile(s, self.percent[1], axis=0)
216
+ self.lnp, self.popt = self._get_lnp_popt(sampler=sampler, pt=pt,
217
+ nburnin=nburnin)
218
+ self.plow, self.pmid, self.phigh = self._get_percentiles(samples)
219
+ self.samples = samples
186
220
  if global_progressbar:
187
- print('')
221
+ print()
188
222
 
189
223
  def plotcorner(self, labels: list[str] | None = None,
190
224
  cornerrange: list[float] | None = None,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: plotastrodata
3
- Version: 1.9.2
3
+ Version: 1.9.4
4
4
  Summary: plotastrodata is a tool for astronomers to create figures from FITS files and perform fundamental data analyses with ease.
5
5
  Home-page: https://github.com/yusukeaso-astron/plotastrodata
6
6
  Download-URL: https://github.com/yusukeaso-astron/plotastrodata
File without changes
File without changes
File without changes
File without changes
File without changes