lisaanalysistools 1.0.0__cp312-cp312-macosx_10_9_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lisaanalysistools might be problematic. Click here for more details.
- lisaanalysistools-1.0.0.dist-info/LICENSE +201 -0
- lisaanalysistools-1.0.0.dist-info/METADATA +80 -0
- lisaanalysistools-1.0.0.dist-info/RECORD +37 -0
- lisaanalysistools-1.0.0.dist-info/WHEEL +5 -0
- lisaanalysistools-1.0.0.dist-info/top_level.txt +2 -0
- lisatools/__init__.py +0 -0
- lisatools/_version.py +4 -0
- lisatools/analysiscontainer.py +438 -0
- lisatools/cutils/detector.cpython-312-darwin.so +0 -0
- lisatools/datacontainer.py +292 -0
- lisatools/detector.py +410 -0
- lisatools/diagnostic.py +976 -0
- lisatools/glitch.py +193 -0
- lisatools/sampling/__init__.py +0 -0
- lisatools/sampling/likelihood.py +882 -0
- lisatools/sampling/moves/__init__.py +0 -0
- lisatools/sampling/moves/gbgroupstretch.py +53 -0
- lisatools/sampling/moves/gbmultipletryrj.py +1287 -0
- lisatools/sampling/moves/gbspecialgroupstretch.py +671 -0
- lisatools/sampling/moves/gbspecialstretch.py +1836 -0
- lisatools/sampling/moves/mbhspecialmove.py +286 -0
- lisatools/sampling/moves/placeholder.py +16 -0
- lisatools/sampling/moves/skymodehop.py +110 -0
- lisatools/sampling/moves/specialforegroundmove.py +564 -0
- lisatools/sampling/prior.py +508 -0
- lisatools/sampling/stopping.py +320 -0
- lisatools/sampling/utility.py +324 -0
- lisatools/sensitivity.py +888 -0
- lisatools/sources/__init__.py +0 -0
- lisatools/sources/emri/__init__.py +1 -0
- lisatools/sources/emri/tdiwaveform.py +72 -0
- lisatools/stochastic.py +291 -0
- lisatools/utils/__init__.py +0 -0
- lisatools/utils/constants.py +40 -0
- lisatools/utils/multigpudataholder.py +730 -0
- lisatools/utils/pointeradjust.py +106 -0
- lisatools/utils/utility.py +240 -0
|
@@ -0,0 +1,286 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import cupy as xp
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
|
|
5
|
+
from eryn.moves import RedBlueMove, StretchMove
|
|
6
|
+
from eryn.state import State
|
|
7
|
+
|
|
8
|
+
from lisatools.sampling.moves.skymodehop import SkyMove
|
|
9
|
+
from bbhx.likelihood import NewHeterodynedLikelihood
|
|
10
|
+
from tqdm import tqdm
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class MBHSpecialMove(RedBlueMove):
    """Leaf-by-leaf update of MBH sources against a shared data residual.

    For each MBH leaf (visited once per call, starting at a random offset):

    1. the current cold-chain (temperature index 0) waveforms of that leaf are
       added back into ``data_residuals``;
    2. a heterodyned likelihood is built, referenced on the cold-chain walker
       whose residual had the highest log-likelihood before the add-back;
    3. ``num_repeats`` rounds of split-ensemble Metropolis-Hastings proposals
       (move drawn from ``moves``) and temperature swaps are run for that leaf;
    4. the updated cold-chain waveforms are subtracted from
       ``data_residuals`` again.

    Args:
        waveform_gen: MBH waveform generator, called with the transformed
            parameters, ``fill=True, freqs=fd`` and ``mbh_kwargs``.
        fd: Frequency array passed to the waveform generator.
        data_residuals: Residual data array, modified in place. Presumably
            ``(nchannels, nwalkers, nfreqs)`` cupy array -- TODO confirm.
        psd: PSD array with the same layout as ``data_residuals``.
        num_repeats (int): Number of proposal rounds per leaf.
        transform_fn: Object whose ``both_transforms`` maps sampling
            coordinates to waveform-generator inputs.
        mbh_priors (dict): Must contain ``"mbh"`` with a ``logpdf`` method.
        mbh_kwargs (dict): Extra keyword arguments for ``waveform_gen``.
        moves (list): ``(move, weight)`` pairs. The weights are used as the
            ``p`` argument of ``random.choice``, so they must sum to 1.
        df (float): Frequency resolution used in the inner products.
        temperature_controls (list): One temperature controller per leaf.
        **kwargs: Passed to :class:`eryn.moves.RedBlueMove`.
    """

    def __init__(self, waveform_gen, fd, data_residuals, psd, num_repeats, transform_fn, mbh_priors, mbh_kwargs, moves, df, temperature_controls, **kwargs):

        RedBlueMove.__init__(self, **kwargs)
        self.fd = fd
        self.data_residuals = data_residuals
        self.psd = psd
        self.waveform_gen = waveform_gen
        self.num_repeats = num_repeats
        self.transform_fn = transform_fn
        self.mbh_priors = mbh_priors
        self.mbh_kwargs = mbh_kwargs
        self.moves = [move[0] for move in moves]
        self.move_weights = [move[1] for move in moves]
        self.df = df
        self.temperature_controls = temperature_controls

    def get_logl(self):
        """Return ``-1/2 * 4 df * sum(conj(r) r / S)`` for the first two channels.

        The sum runs over the last axis only, and the result is moved to host
        memory with ``.get()``.
        """
        # Use the module-level ``xp`` (cupy), as the rest of this class does:
        # ``self.xp`` is never assigned here, and ``.get()`` needs a cupy array.
        return -1 / 2 * (
            4 * self.df * xp.sum(
                (self.data_residuals[:2].conj() * self.data_residuals[:2]) / self.psd[:2],
                axis=-1,
            )
        ).get()

    def propose(self, model, state):
        """Run one leaf-by-leaf MBH update.

        Args:
            model: Sampler model; ``model.random`` supplies the random state.
            state: Current sampler state with an ``"mbh"`` branch and a
                per-leaf ``betas_all`` attribute.

        Returns:
            tuple: ``(new_state, accepted)``. ``accepted`` is the boolean
            ``(ntemps, nwalkers)`` mask from the final repeat of the last leaf
            processed (all False if ``num_repeats == 0``).
        """
        new_state = deepcopy(state)

        # TODO: in TF, can we do multiple together?
        ntemps, nwalkers, nleaves, ndim = new_state.branches["mbh"].shape

        assert len(self.temperature_controls) == nleaves

        temp_inds_base = np.repeat(np.arange(ntemps)[:, None], nwalkers, axis=-1)
        walker_inds_base = np.tile(np.arange(nwalkers), (ntemps, 1))

        # Defined up front so the return value exists even with no repeats.
        accepted = np.zeros((ntemps, nwalkers), dtype=bool)

        # Visit every leaf once, starting from a random offset.
        start_leaf = np.random.randint(0, nleaves)
        for base_leaf in range(nleaves):
            leaf = (base_leaf + start_leaf) % nleaves

            temperature_control_here = self.temperature_controls[leaf]
            temperature_control_here.betas[:] = new_state.betas_all[leaf]

            # --- put this leaf's cold-chain sources back into the residual ---
            xp.get_default_memory_pool().free_all_blocks()
            removal_coords = new_state.branches["mbh"].coords[0, :, leaf]
            removal_coords_in = self.transform_fn.both_transforms(removal_coords)

            removal_waveforms = self.waveform_gen(
                *removal_coords_in.T, fill=True, freqs=self.fd, **self.mbh_kwargs
            ).transpose(1, 0, 2)
            assert removal_waveforms.shape == self.data_residuals.shape

            # TODO: fix T channel (only the first two channels are used).
            # Residual log-likelihood per walker, computed *before* the
            # add-back; its argmax selects the heterodyning reference point.
            ll_tmp2 = (
                (-1 / 2 * 4 * self.df * xp.sum(
                    self.data_residuals[:2].conj() * self.data_residuals[:2] / self.psd[:2],
                    axis=(0, 2),
                ))
                - xp.sum(xp.log(xp.asarray(self.psd[:2])), axis=(0, 2))
            ).get()

            # d - h -> need to add removal waveforms
            self.data_residuals[:2] += removal_waveforms[:2]

            psd_term = -xp.sum(xp.log(xp.asarray(self.psd[:2])), axis=(0, 2)).get()

            keep_het = ll_tmp2.argmax().item()
            del removal_waveforms
            xp.get_default_memory_pool().free_all_blocks()

            # --- heterodyned likelihood referenced on the best residual walker ---
            data_index = xp.arange(removal_coords_in.shape[0], dtype=np.int32)
            noise_index = xp.arange(removal_coords_in.shape[0], dtype=np.int32)
            het_coords = np.tile(removal_coords_in[keep_het], (removal_coords_in.shape[0], 1))

            like_het = NewHeterodynedLikelihood(
                self.waveform_gen,
                self.fd,
                self.data_residuals.transpose(1, 0, 2).copy(),
                self.psd.transpose(1, 0, 2).copy(),
                het_coords,
                256,
                data_index=data_index,
                noise_index=noise_index,
                use_gpu=True,  # self.use_gpu,
            )

            ll_tmp = like_het.get_ll(
                removal_coords_in,
                constants_index=data_index,
            ) + psd_term

            old_coords = new_state.branches["mbh"].coords[:, :, leaf].reshape(-1, ndim)
            old_coords_in = self.transform_fn.both_transforms(old_coords)

            # NOTE(review): presumably consumed by waveform_gen's likelihood
            # routines as the <d|d> term -- confirm.
            self.waveform_gen.d_d = ll_tmp

            del ll_tmp2
            xp.get_default_memory_pool().free_all_blocks()

            # --- current tempered posterior of every (temperature, walker) ---
            data_index_in = xp.tile(xp.arange(nwalkers), (ntemps, 1)).flatten().astype(xp.int32)
            prev_logl = like_het.get_ll(
                old_coords_in,
                constants_index=data_index_in,
            ).reshape((ntemps, nwalkers)).real + psd_term

            prev_logp = self.mbh_priors["mbh"].logpdf(old_coords).reshape((ntemps, nwalkers))

            prev_logP = temperature_control_here.compute_log_posterior_tempered(prev_logl, prev_logp)

            xp.get_default_memory_pool().free_all_blocks()
            for _repeat in range(self.num_repeats):
                # pick move according to the configured weights
                move_here = self.moves[model.random.choice(np.arange(len(self.moves)), p=self.move_weights)]

                # Split the ensemble and iterate over the subsets.
                accepted = np.zeros((ntemps, nwalkers), dtype=bool)
                all_inds = np.tile(np.arange(nwalkers), (ntemps, 1))
                inds = all_inds % self.nsplits
                if self.randomize_split:
                    # Each row is a view, so shuffling it permutes ``inds`` in place.
                    for row in inds:
                        np.random.shuffle(row)

                for split in range(self.nsplits):
                    split_mask = inds == split
                    nwalkers_here = np.sum(split_mask[0])

                    temp_inds_here = temp_inds_base[split_mask]
                    walker_inds_here = walker_inds_base[split_mask]

                    # goes into the proposal as (ntemps, nwalkers / nsplits, 1, ndim)
                    sets = [
                        new_state.branches["mbh"].coords[inds == j][:, leaf].reshape(ntemps, -1, 1, ndim)
                        for j in range(self.nsplits)
                    ]

                    s = {"mbh": sets[split]}
                    c = {"mbh": sets[:split] + sets[split + 1:]}

                    # Get the move-specific proposal.
                    if isinstance(move_here, StretchMove):
                        q, factors = move_here.get_proposal(s, c, model.random)
                    else:
                        q, factors = move_here.get_proposal(s, model.random)

                    new_points = q["mbh"].reshape((ntemps, nwalkers_here, ndim))

                    # Prior of the proposed positions; points outside the prior
                    # skip the likelihood and keep logl = -1e300.
                    logp = self.mbh_priors["mbh"].logpdf(new_points.reshape(-1, ndim))
                    in_prior = ~np.isinf(logp)

                    new_points_in = self.transform_fn.both_transforms(new_points.reshape(-1, ndim)[in_prior])

                    data_index = xp.asarray(walker_inds_here[in_prior].astype(np.int32))

                    logl = np.full_like(logp, -1e300)
                    logl[in_prior] = like_het.get_ll(
                        new_points_in,
                        constants_index=data_index,
                    ) + psd_term[data_index.get()]
                    logl = logl.reshape(ntemps, nwalkers_here)
                    logp = logp.reshape(ntemps, nwalkers_here)

                    prev_logp_here = prev_logp[split_mask].reshape(ntemps, nwalkers_here)
                    prev_logl_here = prev_logl[split_mask].reshape(ntemps, nwalkers_here)

                    prev_logP_here = temperature_control_here.compute_log_posterior_tempered(prev_logl_here, prev_logp_here)
                    logP = temperature_control_here.compute_log_posterior_tempered(logl, logp)

                    # Metropolis-Hastings acceptance test.
                    lnpdiff = factors + logP - prev_logP_here
                    keep = lnpdiff > np.log(model.random.rand(ntemps, nwalkers_here))

                    temp_inds_update = temp_inds_here[keep.flatten()]
                    walker_inds_update = walker_inds_here[keep.flatten()]

                    accepted[(temp_inds_update, walker_inds_update)] = True

                    # update state information
                    new_state.branches["mbh"].coords[
                        (temp_inds_update, walker_inds_update, np.full_like(walker_inds_update, leaf))
                    ] = new_points[keep].reshape(len(temp_inds_update), ndim)

                    prev_logl[(temp_inds_update, walker_inds_update)] = logl[keep].flatten()
                    prev_logp[(temp_inds_update, walker_inds_update)] = logp[keep].flatten()
                    prev_logP[(temp_inds_update, walker_inds_update)] = logP[keep].flatten()

                # acceptance tracking
                self.accepted += accepted
                self.num_proposals += 1

                # TODO: include PSD likelihood in swaps?
                # NOTE(review): swaps run once per repeat so later repeats see
                # the swapped log-probabilities -- confirm against upstream.
                coords_for_swap = {"mbh": new_state.branches_coords["mbh"][:, :, leaf].copy()[:, :, None]}

                # TODO: check permute make sure it is okay
                (
                    coords_for_swap,
                    prev_logP,
                    prev_logl,
                    prev_logp,
                    _swap_inds,
                    _swap_blobs,
                    _swap_supps,
                    _swap_branch_supps,
                ) = temperature_control_here.temperature_swaps(
                    coords_for_swap,
                    prev_logP.copy(),
                    prev_logl.copy(),
                    prev_logp.copy(),
                    branch_supps={"mbh": None},
                )

                temperature_control_here.adapt_temps()

                new_state.branches_coords["mbh"][:, :, leaf] = coords_for_swap["mbh"][:, :, 0]

            # --- remove this leaf's updated cold-chain sources again ---
            xp.get_default_memory_pool().free_all_blocks()

            add_coords = new_state.branches["mbh"].coords[0, :, leaf]
            add_coords_in = self.transform_fn.both_transforms(add_coords)

            add_waveforms = self.waveform_gen(
                *add_coords_in.T, fill=True, freqs=self.fd, **self.mbh_kwargs
            ).transpose(1, 0, 2)
            assert add_waveforms.shape == self.data_residuals.shape

            # d - h -> need to subtract added waveforms
            self.data_residuals[:2] -= add_waveforms[:2]

            del like_het
            del add_waveforms
            xp.get_default_memory_pool().free_all_blocks()

            # read out all betas from temperature controls
            new_state.betas_all[leaf] = temperature_control_here.betas[:]

        # --- refresh the cold-chain log-likelihood and log-prior ---
        current_ll = (
            -1 / 2 * 4 * self.df * xp.sum(
                self.data_residuals[:2].conj() * self.data_residuals[:2] / self.psd[:2],
                axis=(0, 2),
            )
            - xp.sum(xp.log(xp.asarray(self.psd[:2])), axis=(0, 2))
        ).get()
        xp.get_default_memory_pool().free_all_blocks()
        # TODO: add check with last used logl

        current_lp = self.mbh_priors["mbh"].logpdf(
            new_state.branches["mbh"].coords[0, :, :].reshape(-1, ndim)
        ).reshape(new_state.branches["mbh"].shape[1:-1]).sum(axis=-1)

        new_state.log_like[0] = current_ll
        new_state.log_prior[0] = current_lp
        xp.get_default_memory_pool().free_all_blocks()

        self.best_last_ll = current_ll.max()
        self.low_last_ll = current_ll.min()

        # Report the first leaf's swap statistics through the base move's
        # controller, when one is attached.
        if getattr(self, "temperature_control", None) is not None:
            self.temperature_control.swaps_accepted = self.temperature_controls[0].swaps_accepted
        return new_state, accepted
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
from eryn.moves import Move
|
|
4
|
+
|
|
5
|
+
class PlaceHolder(Move):
    """Move that leaves the state unchanged and reports no accepted proposals."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def propose(self, model, state):
        """Return ``state`` untouched together with an all-False acceptance mask.

        Args:
            model: Sampler model (unused).
            state: Current sampler state.

        Returns:
            tuple: ``(state, accepted)`` where ``accepted`` is a boolean zeros
            array with the shape of the state's per-walker log-likelihood.
        """
        # Prefer ``log_like`` (used by the other moves in this package) and fall
        # back to ``log_prob`` for state objects that only provide that name.
        log_like = getattr(state, "log_like", None)
        shape = log_like.shape if log_like is not None else state.log_prob.shape
        # bool zeros add safely into int or float acceptance accumulators.
        accepted = np.zeros(shape, dtype=bool)

        # No proposals were made, so no tempering swaps are reported.
        temperature_control = getattr(self, "temperature_control", None)
        if temperature_control is not None:
            temperature_control.swaps_accepted = np.zeros(len(temperature_control.betas) - 1, dtype=int)

        return state, accepted
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from eryn.moves import MHMove
|
|
6
|
+
|
|
7
|
+
# Public API of this module (``GaussianMove`` is not defined here).
__all__ = ["SkyMove"]
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class SkyMove(MHMove):
    """Metropolis-Hastings move that hops between discrete sky-mode copies of a point.

    Depending on ``which``, each proposal applies:

    * ``"lat"``: negate ``sinbeta`` and ``cosinc`` and map ``psi -> pi - psi``;
    * ``"long"``: add the same random multiple of pi/2 (0, 1, 2 or 3 times) to
      ``psi`` and ``lam``, then wrap ``psi`` modulo pi and ``lam`` modulo 2 pi;
    * ``"both"``: apply ``"lat"`` to each point independently with
      probability 1/2, then apply ``"long"`` to every point.

    :meth:`get_proposal` returns all-zero proposal factors, i.e. the move is
    treated as symmetric.

    Args:
        ind_map (dict, optional): Column index of each parameter in the
            coordinate arrays, with keys ``"cosinc"``, ``"lam"``, ``"sinbeta"``
            and ``"psi"``. Defaults to ``dict(cosinc=6, lam=7, sinbeta=8, psi=9)``.
        which (str, optional): ``"both"``, ``"lat"`` or ``"long"``.
            Defaults to ``"both"``.
        **kwargs: Passed to :class:`eryn.moves.MHMove`.

    Raises:
        ValueError: If ``ind_map`` is given but is not a dict, or if ``which``
            is not one of the allowed values.
    """

    def __init__(self, ind_map=None, which="both", **kwargs):

        if ind_map is None:
            ind_map = dict(cosinc=6, lam=7, sinbeta=8, psi=9)

        elif not isinstance(ind_map, dict):
            raise ValueError("If providing the ind_map kwarg, it must be a dict.")

        if which not in ["both", "lat", "long"]:
            raise ValueError("which kwarg must be 'both', 'lat', or 'long'.")

        self.ind_map = ind_map
        self.which = which
        # Bind the selected transform with getattr instead of exec: same bound
        # method, no dynamic code evaluation.
        self.transform = getattr(self, f"{which}_transform")
        super(SkyMove, self).__init__(**kwargs)

    def lat_transform(self, coords, random):
        """Map points to their reflected-latitude mode.

        Args:
            coords (np.ndarray): 2D array ``(npoints, ndim)`` holding all
                parameters; the ``sinbeta`` column is assumed to be the sine of
                the latitude.
            random: Random state (unused; kept for a uniform transform signature).

        Returns:
            np.ndarray: Transformed copy of ``coords``; the input is not modified.
        """
        temp = coords.copy()

        temp[:, self.ind_map["sinbeta"]] *= -1
        temp[:, self.ind_map["cosinc"]] *= -1
        temp[:, self.ind_map["psi"]] = np.pi - temp[:, self.ind_map["psi"]]

        return temp

    def long_transform(self, coords, random):
        """Rotate ``lam`` and ``psi`` by a random multiple of pi/2 per point.

        Args:
            coords (np.ndarray): 2D array ``(npoints, ndim)`` holding all parameters.
            random: Random state providing ``randint``.

        Returns:
            np.ndarray: Transformed copy of ``coords`` with ``psi`` wrapped
            modulo pi and ``lam`` wrapped modulo 2 pi.
        """
        temp = coords.copy()

        move_amount = random.randint(0, 4, size=coords.shape[0]) * np.pi / 2.0

        temp[:, self.ind_map["psi"]] += move_amount
        temp[:, self.ind_map["lam"]] += move_amount

        temp[:, self.ind_map["psi"]] %= np.pi
        temp[:, self.ind_map["lam"]] %= 2 * np.pi

        return temp

    def both_transform(self, coords, random):
        """Randomly reflect latitude, then rotate longitude.

        Does not assume the point crosses the latitude plane: each point is
        reflected with probability 1/2, so the move selects among 8 modes.

        Args:
            coords (np.ndarray): 2D array ``(npoints, ndim)`` holding all parameters.
            random: Random state providing ``randint``.

        Returns:
            np.ndarray: Transformed copy of ``coords``; the input is not modified.
        """
        # Work on a copy so the caller's array is never modified in place.
        out = coords.copy()
        inds_lat_change = random.randint(0, 2, size=coords.shape[0]).astype(bool)
        out[inds_lat_change] = self.lat_transform(out[inds_lat_change], random)
        return self.long_transform(out, random)

    def get_proposal(self, branches_coords, random, branches_inds=None, **kwargs):
        """Propose sky-mode hops for every active leaf.

        Args:
            branches_coords (dict): Keys are branch names; values are arrays of
                shape ``(ntemps, nwalkers, nleaves_max, ndim)``.
            random: Current random state object.
            branches_inds (dict, optional): Keys are branch names; values are
                boolean arrays ``(ntemps, nwalkers, nleaves_max)`` marking the
                leaves in use. If None, every leaf is treated as in use.
            **kwargs: Ignored.

        Returns:
            tuple: ``(q, factors)`` where ``q`` maps each branch name to its
            proposed coordinates and ``factors`` is a zeros array of shape
            ``(ntemps, nwalkers)``.

        Raises:
            ValueError: If ``branches_coords`` is empty.
        """
        if not branches_coords:
            raise ValueError("branches_coords must contain at least one branch.")

        q = {}
        for name, coords in branches_coords.items():

            if branches_inds is None:
                inds = np.ones(coords.shape[:-1], dtype=bool)
            else:
                inds = branches_inds[name]

            ntemps, nwalkers, _, _ = coords.shape
            inds_here = np.where(inds)

            q[name] = coords.copy()
            q[name][inds_here] = self.transform(coords[inds_here], random)

        return q, np.zeros((ntemps, nwalkers))
|