sting 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sting/__init__.py +8 -0
- sting/_version.py +24 -0
- sting/errors.py +677 -0
- sting/extract_streamline.py +425 -0
- sting/gradient_descent.py +1776 -0
- sting/outputs.py +1705 -0
- sting/stream_lines_grad.py +448 -0
- sting-0.2.0.dist-info/METADATA +251 -0
- sting-0.2.0.dist-info/RECORD +14 -0
- sting-0.2.0.dist-info/WHEEL +5 -0
- sting-0.2.0.dist-info/licenses/LICENCE +21 -0
- sting-0.2.0.dist-info/scm_file_list.json +26 -0
- sting-0.2.0.dist-info/scm_version.json +8 -0
- sting-0.2.0.dist-info/top_level.txt +1 -0
sting/errors.py
ADDED
|
@@ -0,0 +1,677 @@
|
|
|
1
|
+
# For computing the uncertainties, we must use some slightly different versions of the functions,
|
|
2
|
+
# because the original versions use jax.jit, which is great for optimization,
|
|
3
|
+
# but not compatible with taking second derivatives for the Hessian-based uncertainty estimation
|
|
4
|
+
# because they have a side effect of requiring computations such as argsort.
|
|
5
|
+
|
|
6
|
+
# So we put the slower but hessian-compatible versions of the relevant functions here, and use them in the uncertainty estimation.
|
|
7
|
+
|
|
8
|
+
import jax.numpy as jnp
|
|
9
|
+
import jax
|
|
10
|
+
from jax.experimental import checkify
|
|
11
|
+
import astropy.units as u
|
|
12
|
+
import math
|
|
13
|
+
jax.config.update("jax_enable_x64", True)
|
|
14
|
+
from typing import NamedTuple
|
|
15
|
+
from collections import namedtuple
|
|
16
|
+
from . import stream_lines_grad, extract_streamline, gradient_descent
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
## constants
eps = 1e-8 # small value to avoid division by zero
# Gravitational constant: the SI value 6.67430e-11 is rescaled by (1e-3)**2,
# multiplied by 1.988416e30 and divided by 1.4959787e11 to reach the units quoted below.
G = 6.67430e-11 * (1e-3)**2 * (1.988416e30) / (1.4959787e11) # in au (km/s)^2 * Msol^-1
au_to_km = 1.4959787e8 #km
FLOAT_DTYPE = jnp.float64

# settings and constants
# BIG is the sentinel this module adds to the sort key of rejected model points and
# substitutes for their distance metric before interpolation.
BIG = 1e30
# Accepted loss_method values. In chi2_loss_hsafe, 0 selects RA/Dec residuals and any
# other value selects projected radius/angle residuals.
LOSS_METHOD_CHOICES = [0, 1]
# astropy unit associated with each named parameter
CANONICAL_UNITS = {
    "r0": u.au,
    "theta0": u.rad,
    "phi0": u.rad,
    "inc": u.rad,
    "pa": u.rad,
    "v_r0": u.km / u.s,
    "mass": u.Msun,
    "rmin": u.au,
    "deltar": u.au,
    "v_lsr": u.km / u.s,
    "rc": u.au,
    "omega": 1/u.s,
    # mu = rc/r0 is dimensionless, so no units
}

# Tuple of streamline model parameter names (not referenced by the functions visible
# in this file; presumably consumed by other modules -- verify against callers).
STREAMLINE_MODEL_PARAM_KEYS = (
    'r0',
    'theta0',
    'phi0',
    'rc',
    'omega',
    'mu',
    'v_r0',
    'mass',
    'inc',
    'pa',
    'rmin',
    'deltar',
    'v_lsr',
)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def match_model_to_data_curve_hsafe(
    ra_model,
    dec_model,
    v_model,
    valid_mask_model,
    ra_data,
    dec_data,
    model_sort_idx,
    dmetric_model_frozen, #precomputed at best fit params and treated as constant
    data_valid,
):
    """Interpolate the model curve at the data positions without calling argsort.

    Second-derivative-safe counterpart of match_model_to_data_curve: the model
    ordering (model_sort_idx) and the model distance metric
    (dmetric_model_frozen) are supplied by the caller instead of being
    recomputed here.

    Returns
    -------
    ra_interp, dec_interp, v_interp, valid
        Model RA, Dec and velocity interpolated at the mapped data positions,
        and ``valid``, which is ``data_valid`` passed through unchanged.
    """
    f64 = jnp.float64
    ra_model = jnp.asarray(ra_model, dtype=f64)
    dec_model = jnp.asarray(dec_model, dtype=f64)
    v_model = jnp.asarray(v_model, dtype=f64)
    ra_data = jnp.asarray(ra_data, dtype=f64)
    dec_data = jnp.asarray(dec_data, dtype=f64)

    # the data metric depends on the data only, so it is simply recomputed here
    dmetric_data, _ = extract_streamline.get_distance_metric(ra_data, dec_data)

    # extent of the valid data along the metric
    data_lo = jnp.min(jnp.where(data_valid, dmetric_data, jnp.inf))
    data_hi = jnp.max(jnp.where(data_valid, dmetric_data, -jnp.inf))

    # keep model points that are valid, not inside the innermost data point, and below the sentinel
    model_keep = (
        valid_mask_model.astype(bool)
        & (dmetric_model_frozen >= data_lo)
        & (dmetric_model_frozen < BIG)
    )

    # extent of the kept model points along the metric
    model_lo = jnp.min(jnp.where(model_keep, dmetric_model_frozen, jnp.inf))
    model_hi = jnp.max(jnp.where(model_keep, dmetric_model_frozen, -jnp.inf))

    # guard against non-positive spans before dividing / scaling
    raw_model_span = model_hi - model_lo
    raw_data_span = data_hi - data_lo
    model_span_safe = jnp.where(raw_model_span > 0.0, raw_model_span, 1.0)
    data_span_safe = jnp.where(raw_data_span > 0.0, raw_data_span, 1.0)

    # map each data metric value onto the model metric range
    d_data = jnp.where(data_valid, dmetric_data, 0.0)
    d_goal = model_lo + ((d_data - data_lo) / data_span_safe) * model_span_safe

    # apply the precomputed ordering; rejected model points get the BIG sentinel as abscissa
    d_model = jnp.where(model_keep, dmetric_model_frozen, 0.0)
    keep_sorted = model_keep[model_sort_idx]
    xp = jnp.where(keep_sorted, d_model[model_sort_idx], BIG)

    ra_interp = jnp.interp(d_goal, xp, ra_model[model_sort_idx])
    dec_interp = jnp.interp(d_goal, xp, dec_model[model_sort_idx])
    v_interp = jnp.interp(d_goal, xp, v_model[model_sort_idx])

    return ra_interp, dec_interp, v_interp, data_valid
|
|
118
|
+
|
|
119
|
+
# NOTE(review): calling jax.jit with keyword arguments only (decorator-factory form) is
# accepted by recent JAX releases; functools.partial(jax.jit, static_argnames=...) is the
# spelling that also works on older releases -- confirm the minimum supported JAX version.
@jax.jit(static_argnames=("loss_method", "npoints", "priors_keys", "priors_means", "priors_sigmas"))
def chi2_loss_hsafe(model_params, distance_pc, prepared_data, loss_method, model_sort_idx, dmetric_model_frozen,npoints=10000,
                    priors_keys=(), priors_means=(), priors_sigmas=()):
    """Second-derivative-safe version of chi2_loss, using match_model_to_data_curve_hsafe.

    Parameters
    ----------
    model_params : dict
        Model parameters. Must contain one of 'mu', 'rc' or 'omega' (checked in
        that order) plus the keys read below ('mass', 'r0', 'theta0', 'phi0',
        'v_r0', 'inc', 'pa', 'deltar', 'v_lsr'). 'rmin' is optional; a missing
        or None value is replaced by 0.0.
    distance_pc : scalar
        Divisor applied to the model x and z coordinates.
    prepared_data : object
        Provides ra_data, dec_data, v_data, ra_sigma_safe, dec_sigma_safe,
        v_sigma_safe and data_finite_mask, plus r_proj_data and theta_proj_data
        when loss_method is not 0.
    loss_method : int (static)
        0 uses RA/Dec residuals; any other value uses projected radius/angle
        residuals.
    model_sort_idx, dmetric_model_frozen :
        Precomputed model ordering and model distance metric (see
        compute_model_sort_idx).
    npoints : int (static)
        Forwarded to stream_lines_grad.xyz_stream.
    priors_keys, priors_means, priors_sigmas : tuples (static)
        Forwarded to gradient_descent.compute_prior_penalty.

    Returns
    -------
    Scalar chi2 including the prior penalty.

    Raises
    ------
    ValueError
        If model_params contains none of 'rc', 'omega' or 'mu'.
    """
    distance_pc = jnp.asarray(distance_pc, dtype=jnp.float64)

    # 'rmin' is optional here, exactly as in compute_model_sort_idx: absent or None -> 0.0
    rmin = model_params.get('rmin')
    if rmin is None:
        rmin = jnp.asarray(0.0, dtype=jnp.float64)

    # resolve the rotation parameter to mu, whichever way it was supplied
    if 'mu' in model_params:
        mu = model_params['mu']
    elif 'rc' in model_params:
        mu = model_params['rc'] / model_params['r0']
    elif 'omega' in model_params:
        mu = stream_lines_grad.mu_from_omega(omega=model_params['omega'], mass=model_params['mass'], r0=model_params['r0'])
    else:
        raise ValueError("model_params must contain either 'rc', 'omega', or 'mu'")

    # copy before inserting mu so the caller's dict is not mutated
    model_params = dict(model_params)
    model_params['mu'] = mu

    xyz_stream_checked = checkify.checkify(stream_lines_grad.xyz_stream)
    # the checkify error value is not inspected in this function
    _errs, ((x, y, z), (vx, vy, vz), valid_mask) = xyz_stream_checked(
        mass=model_params['mass'],
        r0=model_params['r0'],
        theta0=model_params['theta0'],
        phi0=model_params['phi0'],
        mu=model_params['mu'],
        v_r0=model_params['v_r0'],
        inc=model_params['inc'],
        pa=model_params['pa'],
        rmin=rmin,
        deltar=model_params['deltar'],
        npoints=npoints,
    )

    # project the model: RA from -x, Dec from z, line-of-sight velocity from vy plus v_lsr
    ra_model = -x / distance_pc
    dec_model = z / distance_pc
    v_model = vy + model_params['v_lsr']
    # blank out model points flagged invalid by xyz_stream
    ra_model = jnp.where(valid_mask, ra_model, jnp.nan)
    dec_model = jnp.where(valid_mask, dec_model, jnp.nan)
    v_model = jnp.where(valid_mask, v_model, jnp.nan)

    ra_data = prepared_data.ra_data
    dec_data = prepared_data.dec_data
    v_data = prepared_data.v_data
    ra_sigma = prepared_data.ra_sigma_safe
    dec_sigma = prepared_data.dec_sigma_safe
    v_sigma = prepared_data.v_sigma_safe
    data_valid = prepared_data.data_finite_mask

    ra_interp, dec_interp, v_interp, valid = match_model_to_data_curve_hsafe(
        ra_model, dec_model, v_model, valid_mask,
        ra_data, dec_data,
        model_sort_idx,
        dmetric_model_frozen,
        data_valid
    )

    # 0/1 weights so that invalid data points drop out of every sum
    valid_weights = valid.astype(jnp.float64)

    chi2_v = jnp.sum(valid_weights * ((v_data - v_interp) / v_sigma)**2)

    if loss_method == 0:
        chi2_ra = jnp.sum(valid_weights * ((ra_data - ra_interp) / ra_sigma)**2)
        chi2_dec = jnp.sum(valid_weights * ((dec_data - dec_interp) / dec_sigma)**2)
        chi2_total = chi2_ra + chi2_dec + chi2_v
    else:
        r_proj_data = prepared_data.r_proj_data
        theta_proj_data = prepared_data.theta_proj_data
        r_proj_model, theta_proj_model = extract_streamline.cartesian_to_polar(ra_interp, dec_interp)
        dtheta = extract_streamline.wrap_to_pi(theta_proj_data - theta_proj_model)
        sigma_r = jnp.sqrt(ra_sigma**2 + dec_sigma**2)
        # floor the radius and the angular uncertainty to avoid division by zero
        r_eps = jnp.asarray(1e-8, dtype=jnp.float64)
        r_safe = jnp.maximum(jnp.abs(r_proj_data), r_eps)
        sigma_theta = jnp.sqrt((dec_data * ra_sigma)**2 + (ra_data * dec_sigma)**2) / r_safe**2
        sigma_theta = jnp.maximum(sigma_theta, r_eps)
        chi2_r = jnp.sum(valid_weights * ((r_proj_data - r_proj_model) / sigma_r)**2)
        chi2_theta = jnp.sum(valid_weights * (dtheta / sigma_theta)**2)
        chi2_total = chi2_r + chi2_theta + chi2_v

    # add prior penalty if any
    chi2_prior = gradient_descent.compute_prior_penalty(model_params, priors_means, priors_sigmas, priors_keys)
    chi2_total = chi2_total + chi2_prior

    return chi2_total
|
|
207
|
+
|
|
208
|
+
def compute_model_sort_idx(best_opt_params, fixed_params, distance_pc, prepared_data, npoints=10000):
    """Run the forward model once at the best fit and freeze its ordering.

    Returns
    -------
    (sort_idx, dmetric_model)
        ``sort_idx`` is the integer argsort of the model distance metric with
        rejected model points pushed to the end; ``dmetric_model`` is the
        unmasked model distance metric. Both are meant to be reused as
        constants by the Hessian-based uncertainty estimation.

    Raises
    ------
    ValueError
        If the merged parameters contain none of 'rc', 'omega' or 'mu'.
    """
    # fixed parameters take precedence over optimised ones on key clashes
    merged = {**best_opt_params, **fixed_params}

    rmin = merged.get('rmin')
    if rmin is None:
        rmin = 0.0

    # resolve the rotation parameter to mu, whichever way it was supplied
    if 'mu' in merged:
        mu = float(merged['mu'])
    elif 'rc' in merged:
        mu = float(merged['rc']) / float(merged['r0'])
    elif 'omega' in merged:
        mu = stream_lines_grad.mu_from_omega(omega=merged['omega'], mass=merged['mass'], r0=merged['r0'])
    else:
        raise ValueError("model_params must contain either 'rc', 'omega', or 'mu'")
    merged['mu'] = mu

    # every scalar goes to the forward model as a plain Python float
    scalar_names = ('mass', 'r0', 'theta0', 'phi0', 'mu', 'v_r0', 'inc', 'pa')
    stream_kwargs = {name: float(merged[name]) for name in scalar_names}
    stream_kwargs['rmin'] = float(rmin)
    stream_kwargs['deltar'] = float(merged['deltar'])

    positions, _velocities, valid_mask = stream_lines_grad.xyz_stream(npoints=npoints, **stream_kwargs)
    x, _y, z = positions

    # projected model coordinates, blanked where the forward model flags a point invalid
    distance = float(distance_pc)
    ra_model = jnp.where(valid_mask, jnp.asarray(-x, dtype=jnp.float64) / distance, jnp.nan)
    dec_model = jnp.where(valid_mask, jnp.asarray(z, dtype=jnp.float64) / distance, jnp.nan)

    dmetric_model, _ = extract_streamline.get_distance_metric(ra_model, dec_model)

    # innermost valid data point along the metric
    data_min = jnp.min(
        jnp.where(prepared_data.data_finite_mask, prepared_data.dmetric_data, jnp.inf)
    )

    model_keep = valid_mask.astype(bool) & (dmetric_model >= data_min) & (dmetric_model < BIG)
    w_model = model_keep.astype(jnp.float64)
    d_model = jnp.where(model_keep, dmetric_model, 0.0)

    # rejected points receive a BIG offset so that argsort places them last
    model_sort_key = d_model + (1.0 - w_model) * BIG
    sort_idx = jnp.argsort(model_sort_key)

    return sort_idx, dmetric_model
|
|
260
|
+
|
|
261
|
+
def estimate_parameter_errors(
    best_opt_params,
    fixed_params,
    distance_pc,
    prepared_data,
    loss_method=0,
    gradient_tol=1e-1,
    normalisation_spec=None,
    best_norm_opt_params=None,
    rotation_key=None,
    npoints=10000,
    priors_keys=(),
    priors_means=(),
    priors_sigmas=()
):
    """
    Estimate parameter uncertainties using Hessian of chi2 loss.
    Hessian is computed in normalised parameter space to keep every parameter on O(1) scale.
    Then resulting covariance is transformed back to physical parameter space.

    Parameters
    ----------
    best_opt_params : dict
        Best-fit values of the optimised parameters in physical space.
    fixed_params : dict
        Parameters held fixed; they override optimised parameters on key clashes
        when the model parameter dict is assembled.
    distance_pc : scalar
        Forwarded to the loss functions and to compute_model_sort_idx.
    prepared_data : PreparedData
        Precomputed data-only quantities (created via extract_streamline.prepare_data).
    loss_method : int
        Validated with gradient_descent.check_loss_method and forwarded to the losses.
    gradient_tol : float or None
        Tolerance on gradient norm in normalised space. If provided and
        normalised-space gradient norm > gradient_tol at best params, a
        warning is printed because the quadratic approximation may not be valid.
    normalisation_spec : dict
        Bounds-derived normalisation metadata for optimised parameters. Required.
    best_norm_opt_params : dict or None
        Normalised parameters at best-fit state.
        If not provided, will be computed from best_opt_params and normalisation_spec,
        but providing it can save a redundant computation.
    rotation_key : str or None
        If provided, must be 'rc' or 'omega'. Used to transform covariance matrix
        from optimised 'mu' to rotation_key.
    npoints : int
        Number of model points, forwarded to the forward model.
    priors_keys, priors_means, priors_sigmas : tuples
        Parallel tuples of prior information for optimised parameters. They are
        forwarded to the chi2 losses, which add the prior penalty.

    Returns
    -------
    tuple (error_dict, cov, cov_transformed_dict)
        error_dict : dict
            1-sigma uncertainties for each optimised parameter.
        cov : array
            Covariance matrix, in the key order of best_opt_params.
        cov_transformed_dict : dict or None
            If rotation_key was given and 'mu' was optimised:
            {'keys': new_keys, 'cov': new_cov, 'errors': error_dict}
            where new_keys is same as original keys but with 'mu' replaced by rotation_key,
            new_cov is the covariance matrix transformed accordingly,
            and error_dict is the dict of 1-sigma errors for each parameter in new_keys.
            None otherwise.

    Raises
    ------
    ValueError
        If gradient_tol is non-finite or not positive, if normalisation_spec is
        missing or lacks an optimised key (other than 'v_r0'), or if the best-fit
        v_r0 is not positive.
    """
    if gradient_tol is not None:
        gradient_tol = float(gradient_tol)
        if not math.isfinite(gradient_tol):
            raise ValueError('gradient_tol must be finite when provided.')
        if gradient_tol <= 0:
            raise ValueError('gradient_tol must be positive when provided.')

    if normalisation_spec is None:
        raise ValueError("normalisation_spec not provided")

    # v_r0 is handled through a softplus transform (v_r0 = softplus(raw_v_r0)),
    # so the best-fit value must be strictly positive for the inverse to exist
    has_v_r0 = 'v_r0' in best_opt_params
    if has_v_r0:
        v_r0_best = gradient_descent.to_float64(best_opt_params['v_r0'])
        if not bool(v_r0_best > 0): #should never be triggered
            raise ValueError(f"best fit v_r0 must be positive, got v_r0={v_r0_best}")

    # ordering of the optimised parameters; fixes the row/column order of cov
    keys = list(best_opt_params)
    loss_method = gradient_descent.check_loss_method(loss_method)

    # precompute model sort index at best fit params
    model_sort_idx, dmetric_model = compute_model_sort_idx(best_opt_params, fixed_params, distance_pc, prepared_data, npoints=npoints)

    if best_norm_opt_params is not None:
        norm_opt_params = best_norm_opt_params
    else:
        norm_opt_params = gradient_descent.normalise_opt_params(best_opt_params, normalisation_spec)

    norm_params_vec, norm_keys = params_dict_to_vector(norm_opt_params)
    missing_norm_keys = [key for key in keys if key not in normalisation_spec and key != 'v_r0']
    if missing_norm_keys:
        raise ValueError(
            f"normalisation_spec is missing optimised parameter keys required: {missing_norm_keys} "
        )

    def loss_vec_norm(theta_norm_vec):
        # chi2 as a function of the normalised parameter vector, using the
        # second-derivative-safe loss with the frozen model ordering
        norm_params = vector_to_params_dict(theta_norm_vec, norm_keys)
        physical_params = gradient_descent.denormalise_opt_params(norm_params, normalisation_spec)
        model_params = {**physical_params, **fixed_params}
        chi2_total = chi2_loss_hsafe(
            model_params,
            distance_pc,
            prepared_data,
            loss_method=loss_method,
            model_sort_idx=model_sort_idx,
            dmetric_model_frozen=dmetric_model,
            npoints=npoints,
            priors_keys=priors_keys,
            priors_means=priors_means,
            priors_sigmas=priors_sigmas
        )
        return chi2_total

    # Check gradient magnitude at best-fit parameters in normalised space.
    if gradient_tol is not None:
        def norm_loss_vec(theta_norm_vec):
            # same mapping as loss_vec_norm, but through gradient_descent.chi2_loss
            norm_params = vector_to_params_dict(theta_norm_vec, norm_keys)
            physical_params = gradient_descent.denormalise_opt_params(norm_params, normalisation_spec)
            model_params = {**physical_params, **fixed_params}
            chi2_total, _, _ = gradient_descent.chi2_loss(
                model_params,
                distance_pc,
                prepared_data,
                loss_method=loss_method,
                priors_keys=priors_keys,
                priors_means=priors_means,
                priors_sigmas=priors_sigmas
            )
            return chi2_total

        norm_grad_vec = jax.grad(norm_loss_vec)(norm_params_vec)

        norm_grad_norm = float(gradient_descent.gradient_l2_norm(norm_grad_vec))

        if norm_grad_norm > gradient_tol:
            print(
                "WARNING: normalised-space gradient norm at best fit = "
                f"{norm_grad_norm:.3e} exceeds tolerance {gradient_tol:.3e}"
            )
            print("Optimisation may not have reached a minimum yet.")
            print("Parameter uncertainties may be less reliable. Consider:")
            print("  - Increasing n_epochs")
            print("  - Reducing learning rate for finer convergence")
            print("  - Reducing loss_threshold, if used, to allow more optimisation steps")

    # compute Hessian in normalised space
    H_norm = jax.hessian(loss_vec_norm)(norm_params_vec)

    # invert to get covariance in normalised space
    cov_norm = jnp.linalg.inv(H_norm)

    # transform from normalised space to physical space
    def denormalise_vec(theta_norm_vec):
        norm_params = vector_to_params_dict(theta_norm_vec, norm_keys)
        physical_params = gradient_descent.denormalise_opt_params(norm_params, normalisation_spec)
        # denormalise_opt_params returns v_r0 already passed through softplus - convert back to raw space here so that J is correct for the transformation from raw_v_r0 to v_r0
        if has_v_r0:
            physical_params = dict(physical_params)
            physical_params['v_r0'] = gradient_descent.inv_softplus(physical_params['v_r0'])
        output = [physical_params[k] for k in keys]
        return jnp.stack(output)

    J = jax.jacobian(denormalise_vec)(norm_params_vec)
    cov = J @ cov_norm @ J.T

    # cov still carries raw_v_r0 at this point; map that row/column to v_r0
    if has_v_r0:
        v_r0_transformed = transform_cov_matrix(cov, keys, best_opt_params, fixed_params, rotation_key=None, v_r0_is_raw=True)
        cov = v_r0_transformed['cov']

    # parameter errors
    errors = jnp.sqrt(jnp.diag(cov))

    error_dict = {k: float(errors[i]) for i, k in enumerate(keys)}

    cov_transformed_dict = None
    if rotation_key is not None and 'mu' in keys:
        cov_transformed_dict = transform_cov_matrix(cov, keys, best_opt_params, fixed_params, rotation_key)

    return error_dict, cov, cov_transformed_dict
|
|
457
|
+
|
|
458
|
+
def transform_cov_matrix(cov, keys, best_opt_params, fixed_params, rotation_key=None, v_r0_is_raw=False):
    """Propagate a covariance matrix from optimisation space to output space.

    Maths:
        with A the optimisation-space parameter vector and B the transformed
        parameter vector, the covariance in B space is

            cov_B = J @ cov_A @ J^T

        where J = d(B) / d(A) is the Jacobian of the transformation, evaluated
        at the best-fit parameters (identity except for the rows being
        transformed).

    Parameters:
        cov: covariance matrix from estimate_parameter_errors, in same parameter order as 'keys'
        keys: list of str, optimised parameter names
        best_opt_params: dict of best-fit optimised parameters in physical space (the point at which J is evaluated)
        fixed_params: dict of fixed parameters (the point at which J is evaluated)
        rotation_key: str or None, the rotation parameter to transform 'mu' into ('rc' or 'omega')
        v_r0_is_raw: bool, whether the v_r0 row/column of cov is in raw space (before the softplus transformation)

    Returns:
        dict with entries
            'keys': list of str, same as keys but with 'mu' replaced by rotation_key if rotation_key is not None
            'cov': transformed covariance matrix, in the order of 'keys'
            'errors': dict of 1-sigma errors for each parameter in 'keys'

    Raises:
        ValueError: if rotation_key is not 'rc'/'omega', if 'mu' is not in keys
            or 'mass'/'r0' are unavailable while rotation_key is given, or if
            the best-fit v_r0 is not positive while v_r0_is_raw is set.
    """
    rotating = rotation_key is not None

    if rotating:
        if rotation_key not in ('rc', 'omega'):
            raise ValueError(f"rotation_key must be 'rc' or 'omega', got {rotation_key}")
        if 'mu' not in keys:
            raise ValueError(f"keys must include 'mu' for covariance transformation, got {keys}")
        # mass and r0 are needed to convert mu into rc or omega
        for required_key in ('mass', 'r0'):
            if required_key not in best_opt_params and required_key not in fixed_params:
                raise ValueError(f"'{required_key}' must be present in either best_opt_params or fixed_params")

    # point at which the Jacobian is evaluated
    start_values = []
    for k in keys:
        if k == 'v_r0' and v_r0_is_raw:
            v_r0_best = gradient_descent.to_float64(best_opt_params[k])
            if not bool(v_r0_best > 0): #should never be triggered
                raise ValueError(f"best fit v_r0 must be positive, got v_r0={v_r0_best}")
            start_values.append(gradient_descent.inv_softplus(v_r0_best))
            continue
        start_values.append(float(best_opt_params[k]))
    params_vec = jnp.array(start_values, dtype=jnp.float64)

    def transform(vec_A):
        # map the optimisation-space vector onto the output-space vector
        opt_params = vector_to_params_dict(vec_A, keys)
        combined_params = {**fixed_params, **opt_params}
        if rotating:
            mu = combined_params['mu']
            mass = combined_params['mass']
            r0 = combined_params['r0']
            if rotation_key == 'rc':
                rotation_val = mu * r0
            else: # 'omega'
                rotation_val = stream_lines_grad.omega_from_mu(mu=mu, mass=mass, r0=r0)

        transformed = []
        for k in keys:
            if rotating and k == 'mu':
                transformed.append(rotation_val)
            elif k == 'v_r0' and v_r0_is_raw:
                transformed.append(gradient_descent.softplus(opt_params[k]))
            else:
                transformed.append(opt_params[k])
        return jnp.stack(transformed)

    jac = jax.jacobian(transform)(params_vec)
    new_cov = jac @ cov @ jac.T

    new_keys = [rotation_key if (rotating and k == 'mu') else k for k in keys]
    new_sigmas = jnp.sqrt(jnp.diag(new_cov))
    error_dict = {}
    for i, k in enumerate(new_keys):
        error_dict[k] = float(new_sigmas[i])

    return {'keys': new_keys, 'cov': new_cov, 'errors': error_dict}
|
|
535
|
+
|
|
536
|
+
#-------------------- old gradient_descent.py ---------------------
|
|
537
|
+
|
|
538
|
+
def params_dict_to_vector(opt_params):
    """Flatten a parameter dict into a float64 vector plus its key order.

    Returns ``(vec, keys)`` where ``keys`` lists the dict keys in iteration
    order and ``vec[i]`` holds the value stored under ``keys[i]``.
    """
    names = list(opt_params)
    values = [opt_params[name] for name in names]
    return jnp.array(values, dtype=jnp.float64), names
|
|
543
|
+
|
|
544
|
+
def vector_to_params_dict(vec, keys):
    """Rebuild a parameter dict by pairing ``keys[i]`` with ``vec[i]``."""
    params = {}
    for position, name in enumerate(keys):
        params[name] = vec[position]
    return params
|
|
547
|
+
|
|
548
|
+
@jax.jit
def match_model_to_data_curve(ra_model, dec_model, v_model, ra_data, dec_data):
    """
    Extract model values corresponding to data positions using the distance metric from
    extract_streamline.get_distance_metric

    Method:
    1. Compute the distance metric for model and data points
    2. Apply finite masks
    3. Normalise both metrics to [0, 1] based on their finite ranges
    4. Map data normalised positions to model normalised positions
    5. Interpolate model RA, Dec, and velocity at the mapped positions

    Returns
    -------
    ra_model_interp, dec_model_interp, v_model_interp, valid, dmetric_model, matching_trace
        where valid is a boolean mask with shape len(original data), marking
        retained data points, dmetric_model is the unmasked model distance metric,
        and matching_trace is a dict of diagnostic counts and metric ranges.
    """
    ra_model = extract_streamline.to_float64(ra_model)
    dec_model = extract_streamline.to_float64(dec_model)
    v_model = extract_streamline.to_float64(v_model)
    ra_data = extract_streamline.to_float64(ra_data)
    dec_data = extract_streamline.to_float64(dec_data)

    # get distance metrics
    dmetric_model, _ = extract_streamline.get_distance_metric(ra_model, dec_model)
    dmetric_data, _ = extract_streamline.get_distance_metric(ra_data, dec_data)

    # only finite values are valid
    model_valid = (
        jnp.isfinite(ra_model)
        & jnp.isfinite(dec_model)
        & jnp.isfinite(v_model)
        & jnp.isfinite(dmetric_model)
    )

    data_valid = (
        jnp.isfinite(ra_data)
        & jnp.isfinite(dec_data)
        & jnp.isfinite(dmetric_data)
    )

    # we also filter model to keep only model points with dmetric >= minimum of data dmetric
    # this is because the model shouldn't go further in than the innermost data point
    # as this is where we no longer observe the streamer
    d_data_valid = jnp.where(data_valid, dmetric_data, jnp.inf)
    data_min = jnp.min(d_data_valid)

    # enforce both constraints on model
    model_keep = model_valid & (dmetric_model >= data_min)

    # weights: 0 = ignore, 1 = use. This is for jax/jit compatibility
    w_model = model_keep.astype(jnp.float64)

    # masked metrics: rejected entries are replaced by 0.0, so these never contain NaN
    d_model = jnp.where(model_keep, dmetric_model, 0.0)
    d_data = jnp.where(data_valid, dmetric_data, 0.0)

    ra = ra_model
    dec = dec_model
    v = v_model

    # ---- sort ONLY MODEL using metric + weight penalty ----
    model_sort_key = d_model + (1.0 - w_model) * BIG
    model_idx = jnp.argsort(model_sort_key)

    d_model_s = d_model[model_idx]
    ra_s = ra[model_idx]
    dec_s = dec[model_idx]
    v_s = v[model_idx]
    w_model_s = w_model[model_idx]

    # stats for trace and interpolation domain
    data_min_eff = jnp.min(jnp.where(data_valid, d_data, jnp.inf))
    data_max_eff = jnp.max(jnp.where(data_valid, d_data, -jnp.inf))

    model_min = jnp.min(jnp.where(model_keep, d_model, jnp.inf))
    model_max = jnp.max(jnp.where(model_keep, d_model, -jnp.inf))

    model_span = model_max - model_min
    data_span = data_max_eff - data_min_eff
    # avoid dividing / scaling by a zero span
    model_span_safe = jnp.where(model_span == 0.0, 1.0, model_span)
    data_span_safe = jnp.where(data_span == 0.0, 1.0, data_span)

    # normalise data metric
    d_data_norm = (d_data - data_min_eff) / data_span_safe
    d_goal = model_min + d_data_norm * model_span_safe

    # interpolate model at data points, using weights to ignore invalid model points
    # by giving them huge distance values so they don't affect the interpolation
    xp = jnp.where(w_model_s > 0, d_model_s, BIG)

    ra_interp = jnp.interp(d_goal, xp, ra_s)
    dec_interp = jnp.interp(d_goal, xp, dec_s)
    v_interp = jnp.interp(d_goal, xp, v_s)

    # things for trace
    valid = data_valid

    matching_trace = {
        "model_points_total": model_idx.size,
        # NaNs are counted on the raw metrics: the masked d_model / d_data have had
        # every non-finite entry replaced by 0.0, so counting there always gives 0
        "model_nan_count": jnp.sum(jnp.isnan(dmetric_model)),
        "model_valid_points": model_valid.sum(),
        "data_points_total": ra_data.size,
        "data_nan_count": jnp.sum(jnp.isnan(dmetric_data)),
        "data_valid_points": data_valid.sum(),
        "model_metric_min": model_min,
        "model_metric_max": model_max,
        "data_metric_min": data_min_eff,
        "data_metric_max": data_max_eff,
        "model_metric_span": model_span_safe,
        "data_metric_span": data_span_safe}

    return ra_interp, dec_interp, v_interp, valid, dmetric_model, matching_trace
|
|
662
|
+
|
|
663
|
+
# checkify-instrumented variant of match_model_to_data_curve; returns (error, outputs)
checked_matching = checkify.checkify(match_model_to_data_curve)


def checked_match_model_to_data_curve(*args, **kwargs):
    """Call match_model_to_data_curve through checkify and raise any recorded error.

    All arguments are forwarded unchanged; the wrapped function's outputs are
    returned once the checkify error value has been thrown (a no-op when no
    error was recorded).
    """
    err, outputs = checked_matching(*args, **kwargs)
    err.throw()
    return outputs
|
|
670
|
+
|
|
671
|
+
|
|
672
|
+
|
|
673
|
+
|
|
674
|
+
|
|
675
|
+
|
|
676
|
+
|
|
677
|
+
|