sting 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sting/errors.py ADDED
@@ -0,0 +1,677 @@
1
+ # For computing the uncertainties, we must use some slightly different versions of the functions,
2
+ # because the original versions use jax.jit, which is great for optimization,
3
+ # but not compatible with taking second derivatives for the Hessian-based uncertainty estimation
4
+ # because they have a side effect of requiring computations such as argsort.
5
+
6
+ # So we put the slower but hessian-compatible versions of the relevant functions here, and use them in the uncertainty estimation.
7
+
8
+ import jax.numpy as jnp
9
+ import jax
10
+ from jax.experimental import checkify
11
+ import astropy.units as u
12
+ import math
13
+ jax.config.update("jax_enable_x64", True)
14
+ from typing import NamedTuple
15
+ from collections import namedtuple
16
+ from . import stream_lines_grad, extract_streamline, gradient_descent
17
+
18
+
19
## constants
eps = 1e-8  # small value to avoid division by zero

# Newton's G converted from SI (m^3 kg^-1 s^-2) to au (km/s)^2 Msun^-1:
#   * (1e-3)**2      m^2/s^2 -> km^2/s^2
#   * 1.988416e30    per kg  -> per Msun
#   / 1.4959787e11   m       -> au
G = 6.67430e-11 * (1e-3)**2 * (1.988416e30) / (1.4959787e11)
au_to_km = 1.4959787e8  # kilometres in one au
FLOAT_DTYPE = jnp.float64

# settings and constants
BIG = 1e30  # large finite sentinel used to push masked/dropped points out of the way
LOSS_METHOD_CHOICES = [0, 1]

# Canonical astropy unit for each model parameter.
# mu = rc/r0 is dimensionless, so it has no entry.
CANONICAL_UNITS = dict(
    r0=u.au,
    theta0=u.rad,
    phi0=u.rad,
    inc=u.rad,
    pa=u.rad,
    v_r0=u.km / u.s,
    mass=u.Msun,
    rmin=u.au,
    deltar=u.au,
    v_lsr=u.km / u.s,
    rc=u.au,
    omega=1 / u.s,
)

# Canonical ordering of the streamline model parameter names.
STREAMLINE_MODEL_PARAM_KEYS = tuple(
    "r0 theta0 phi0 rc omega mu v_r0 mass inc pa rmin deltar v_lsr".split()
)
59
+
60
+
61
def match_model_to_data_curve_hsafe(
    ra_model,
    dec_model,
    v_model,
    valid_mask_model,
    ra_data,
    dec_data,
    model_sort_idx,
    dmetric_model_frozen,  # precomputed at best fit params and treated as constant
    data_valid,
):
    """Second-derivative-safe version of match_model_to_data_curve.

    The argsort of the original is replaced by a precomputed static integer index
    array (``model_sort_idx``), and the model distance metric is likewise frozen
    (``dmetric_model_frozen``). Both are expected to come from compute_model_sort_idx
    at the best-fit parameters, so the only parameter dependence flowing through
    this function is via the interpolated ra/dec/v values.

    Parameters
    ----------
    ra_model, dec_model, v_model : array
        Model curve positions and velocities. Entries at invalid model points may be NaN.
    valid_mask_model : array
        Mask of valid model points (cast to bool).
    ra_data, dec_data : array
        Observed data positions.
    model_sort_idx : integer array
        Frozen permutation ordering the model points by distance metric.
    dmetric_model_frozen : array
        Frozen model distance metric, one entry per model point.
    data_valid : bool array
        Mask of the data points to use.

    Returns
    -------
    ra_interp, dec_interp, v_interp : array
        Model values interpolated at each data point (same length as the data).
    valid : bool array
        ``data_valid``, passed through unchanged.
    """
    ra_model = jnp.asarray(ra_model, dtype=jnp.float64)
    dec_model = jnp.asarray(dec_model, dtype=jnp.float64)
    v_model = jnp.asarray(v_model, dtype=jnp.float64)
    ra_data = jnp.asarray(ra_data, dtype=jnp.float64)
    dec_data = jnp.asarray(dec_data, dtype=jnp.float64)

    dmetric_model = dmetric_model_frozen
    # not precomputed because it only depends on the data
    dmetric_data, _ = extract_streamline.get_distance_metric(ra_data, dec_data)
    model_valid = valid_mask_model.astype(bool)

    # metric range over the valid data points
    data_min_eff = jnp.min(jnp.where(data_valid, dmetric_data, jnp.inf))
    data_max_eff = jnp.max(jnp.where(data_valid, dmetric_data, -jnp.inf))

    # keep model points that are valid, not inside the innermost data point, and not BIG sentinels
    model_keep = model_valid & (dmetric_model >= data_min_eff) & (dmetric_model < BIG)
    w_model = model_keep.astype(jnp.float64)

    d_model = jnp.where(model_keep, dmetric_model, 0.0)
    d_data = jnp.where(data_valid, dmetric_data, 0.0)

    # apply precomputed sort index to model
    d_model_s = d_model[model_sort_idx]
    keep_s = w_model[model_sort_idx] > 0
    # Dropped points are moved to xp = BIG below, but jnp.interp still reads the first of
    # them as the right-hand neighbour when a query lands exactly on the largest kept
    # metric. Their values may be NaN (invalid model points), and 0 * NaN would make that
    # interpolated value NaN. Replace them with a finite value; the weight there is ~0.
    ra_s = jnp.where(keep_s, ra_model[model_sort_idx], 0.0)
    dec_s = jnp.where(keep_s, dec_model[model_sort_idx], 0.0)
    v_s = jnp.where(keep_s, v_model[model_sort_idx], 0.0)

    model_min = jnp.min(jnp.where(model_keep, dmetric_model, jnp.inf))
    model_max = jnp.max(jnp.where(model_keep, dmetric_model, -jnp.inf))

    model_span = model_max - model_min
    data_span = data_max_eff - data_min_eff
    # non-positive (single value) or -inf (no points) spans fall back to 1.0
    model_span_safe = jnp.where(model_span > 0.0, model_span, 1.0)
    data_span_safe = jnp.where(data_span > 0.0, data_span, 1.0)

    # map each data metric linearly from the data range onto the kept model range
    d_data_norm = (d_data - data_min_eff) / data_span_safe
    d_goal = model_min + d_data_norm * model_span_safe

    xp = jnp.where(keep_s, d_model_s, BIG)
    ra_interp = jnp.interp(d_goal, xp, ra_s)
    dec_interp = jnp.interp(d_goal, xp, dec_s)
    v_interp = jnp.interp(d_goal, xp, v_s)

    valid = data_valid
    return ra_interp, dec_interp, v_interp, valid
118
+
119
@jax.jit(static_argnames=("loss_method", "npoints", "priors_keys", "priors_means", "priors_sigmas"))
def chi2_loss_hsafe(model_params, distance_pc, prepared_data, loss_method, model_sort_idx, dmetric_model_frozen, npoints=10000,
                    priors_keys=(), priors_means=(), priors_sigmas=()):
    """Second-derivative-safe version of chi2_loss, using match_model_to_data_curve_hsafe.

    Parameters
    ----------
    model_params : dict
        Must contain 'mass', 'r0', 'theta0', 'phi0', 'v_r0', 'inc', 'pa', 'deltar',
        'v_lsr', and one of 'mu', 'rc' or 'omega' (checked in that order).
        'rmin' is optional; a missing or None value is treated as 0.
    distance_pc : float
        Divisor applied to the model x/z coordinates to obtain ra/dec.
    prepared_data : PreparedData
        Precomputed data-only quantities (ra/dec/v data, *_sigma_safe uncertainties,
        data_finite_mask, and r_proj_data/theta_proj_data for loss_method != 0).
    loss_method : int (static)
        0: chi2 in ra, dec and v; otherwise chi2 in projected radius, angle and v.
    model_sort_idx, dmetric_model_frozen : array
        Frozen sort order and model distance metric from compute_model_sort_idx.
    npoints : int (static)
        Number of streamline points.
    priors_keys, priors_means, priors_sigmas : tuple (static)
        Parallel tuples describing Gaussian-style prior penalties (see
        gradient_descent.compute_prior_penalty).

    Returns
    -------
    scalar
        Total chi2 including any prior penalty.
    """
    distance_pc = jnp.asarray(distance_pc, dtype=jnp.float64)

    # 'rmin' is optional (as in compute_model_sort_idx): missing or None means 0.
    rmin = model_params.get('rmin')
    if rmin is None:
        rmin = jnp.asarray(0.0, dtype=jnp.float64)

    if 'mu' in model_params:
        mu = model_params['mu']
    elif 'rc' in model_params:
        mu = model_params['rc'] / model_params['r0']
    elif 'omega' in model_params:
        mu = stream_lines_grad.mu_from_omega(omega=model_params['omega'], mass=model_params['mass'], r0=model_params['r0'])
    else:
        raise ValueError("model_params must contain either 'rc', 'omega', or 'mu'")

    model_params = dict(model_params)
    model_params['mu'] = mu

    # checkify functionalises any checks inside xyz_stream so it can be traced under jit.
    # NOTE(review): the returned error value is not inspected here, so check failures are
    # silently discarded; surfacing them would require checkify outside the jitted function.
    xyz_stream_checked = checkify.checkify(stream_lines_grad.xyz_stream)
    _errs, ((x, _y, z), (_vx, vy, _vz), valid_mask) = xyz_stream_checked(
        mass=model_params['mass'],
        r0=model_params['r0'],
        theta0=model_params['theta0'],
        phi0=model_params['phi0'],
        mu=model_params['mu'],
        v_r0=model_params['v_r0'],
        inc=model_params['inc'],
        pa=model_params['pa'],
        rmin=rmin,
        deltar=model_params['deltar'],
        npoints=npoints,
    )

    # project onto the sky; invalid model points become NaN
    ra_model = jnp.where(valid_mask, -x / distance_pc, jnp.nan)
    dec_model = jnp.where(valid_mask, z / distance_pc, jnp.nan)
    v_model = jnp.where(valid_mask, vy + model_params['v_lsr'], jnp.nan)

    ra_data = prepared_data.ra_data
    dec_data = prepared_data.dec_data
    v_data = prepared_data.v_data
    ra_sigma = prepared_data.ra_sigma_safe
    dec_sigma = prepared_data.dec_sigma_safe
    v_sigma = prepared_data.v_sigma_safe
    data_valid = prepared_data.data_finite_mask

    ra_interp, dec_interp, v_interp, valid = match_model_to_data_curve_hsafe(
        ra_model, dec_model, v_model, valid_mask,
        ra_data, dec_data,
        model_sort_idx,
        dmetric_model_frozen,
        data_valid
    )

    # weights of 0 drop invalid data points from every chi2 term
    valid_weights = valid.astype(jnp.float64)

    chi2_v = jnp.sum(valid_weights * ((v_data - v_interp) / v_sigma)**2)

    if loss_method == 0:
        chi2_ra = jnp.sum(valid_weights * ((ra_data - ra_interp) / ra_sigma)**2)
        chi2_dec = jnp.sum(valid_weights * ((dec_data - dec_interp) / dec_sigma)**2)
        chi2_total = chi2_ra + chi2_dec + chi2_v
    else:
        r_proj_data = prepared_data.r_proj_data
        theta_proj_data = prepared_data.theta_proj_data
        r_proj_model, theta_proj_model = extract_streamline.cartesian_to_polar(ra_interp, dec_interp)
        dtheta = extract_streamline.wrap_to_pi(theta_proj_data - theta_proj_model)
        # propagate ra/dec uncertainties into polar coordinates
        sigma_r = jnp.sqrt(ra_sigma**2 + dec_sigma**2)
        r_eps = jnp.asarray(1e-8, dtype=jnp.float64)
        r_safe = jnp.maximum(jnp.abs(r_proj_data), r_eps)
        sigma_theta = jnp.sqrt((dec_data * ra_sigma)**2 + (ra_data * dec_sigma)**2) / r_safe**2
        sigma_theta = jnp.maximum(sigma_theta, r_eps)
        chi2_r = jnp.sum(valid_weights * ((r_proj_data - r_proj_model) / sigma_r)**2)
        chi2_theta = jnp.sum(valid_weights * (dtheta / sigma_theta)**2)
        chi2_total = chi2_r + chi2_theta + chi2_v

    # add prior penalty if any
    chi2_prior = gradient_descent.compute_prior_penalty(model_params, priors_means, priors_sigmas, priors_keys)
    chi2_total = chi2_total + chi2_prior

    return chi2_total
207
+
208
def compute_model_sort_idx(best_opt_params, fixed_params, distance_pc, prepared_data, npoints=10000):
    """Run the forward model once at the best fit and freeze its point ordering.

    The Hessian-safe loss (chi2_loss_hsafe) reuses a fixed ordering of the model
    curve instead of re-sorting on every evaluation. This helper evaluates the
    streamline at the best-fit parameters and returns both the sort permutation and
    the distance metric, so they can be treated as constants.

    Parameters
    ----------
    best_opt_params, fixed_params : dict
        Optimised and fixed parameters. They are merged, with fixed values taking precedence.
    distance_pc : float
        Divisor applied to the model x/z coordinates to obtain ra/dec.
    prepared_data : PreparedData
        Must expose ``data_finite_mask`` and ``dmetric_data``.
    npoints : int
        Number of points passed to ``stream_lines_grad.xyz_stream``.

    Returns
    -------
    sort_idx : integer array
        Indices ordering model points by distance metric. Dropped points are placed after
        all kept points. A point is dropped if it is invalid, below the innermost data
        metric, or has metric >= BIG.
    dmetric_model : array
        Model distance metric at the best fit.
    """
    merged = {**best_opt_params, **fixed_params}

    rmin_value = merged.get('rmin')
    if rmin_value is None:
        rmin_value = 0.0

    # resolve whichever rotation parameterisation is present into mu
    if 'mu' in merged:
        mu_value = float(merged['mu'])
    elif 'rc' in merged:
        mu_value = float(merged['rc']) / float(merged['r0'])
    elif 'omega' in merged:
        mu_value = stream_lines_grad.mu_from_omega(omega=merged['omega'], mass=merged['mass'], r0=merged['r0'])
    else:
        raise ValueError("model_params must contain either 'rc', 'omega', or 'mu'")
    merged = {**merged, 'mu': mu_value}

    stream_kwargs = {
        name: float(merged[name])
        for name in ('mass', 'r0', 'theta0', 'phi0', 'mu', 'v_r0', 'inc', 'pa')
    }
    stream_kwargs['rmin'] = float(rmin_value)
    stream_kwargs['deltar'] = float(merged['deltar'])
    (pos_x, _pos_y, pos_z), (_vel_x, _vel_y, _vel_z), valid_mask = stream_lines_grad.xyz_stream(
        npoints=npoints, **stream_kwargs
    )

    # sky projection; invalid points become NaN
    distance = float(distance_pc)
    ra_model = jnp.where(valid_mask, jnp.asarray(-pos_x, dtype=jnp.float64) / distance, jnp.nan)
    dec_model = jnp.where(valid_mask, jnp.asarray(pos_z, dtype=jnp.float64) / distance, jnp.nan)

    dmetric_model, _ = extract_streamline.get_distance_metric(ra_model, dec_model)

    innermost_data = jnp.min(
        jnp.where(prepared_data.data_finite_mask, prepared_data.dmetric_data, jnp.inf)
    )
    kept = valid_mask.astype(bool) & (dmetric_model >= innermost_data) & (dmetric_model < BIG)

    # kept points are keyed by their metric; dropped points get key BIG so they sort last
    sort_key = jnp.where(kept, dmetric_model, 0.0) + (1.0 - kept.astype(jnp.float64)) * BIG
    return jnp.argsort(sort_key), dmetric_model
260
+
261
def estimate_parameter_errors(
    best_opt_params,
    fixed_params,
    distance_pc,
    prepared_data,
    loss_method=0,
    gradient_tol=1e-1,
    normalisation_spec=None,
    best_norm_opt_params=None,
    rotation_key=None,
    npoints=10000,
    priors_keys=(),
    priors_means=(),
    priors_sigmas=()
):
    """
    Estimate parameter uncertainties using the Hessian of the chi2 loss.

    The Hessian is computed in normalised parameter space to keep every parameter on an O(1) scale.
    The resulting covariance is then transformed back to physical parameter space.

    Parameters
    ----------
    best_opt_params : dict
        Best-fit optimised parameters in physical space. Their key order defines the
        row/column order of the returned covariance matrix.
    fixed_params : dict
        Parameters held fixed during the fit.
    distance_pc : float
        Distance passed through to the loss functions.
    prepared_data : PreparedData
        Precomputed data-only quantities (created via extract_streamline.prepare_data).
    loss_method : int
        Loss variant, validated with gradient_descent.check_loss_method.
    gradient_tol : float or None
        Tolerance on the gradient norm in normalised space. If provided and the
        normalised-space gradient norm at the best params exceeds gradient_tol, a
        warning is printed because the quadratic approximation may not be valid.
    normalisation_spec : dict
        Bounds-derived normalisation metadata for the optimised parameters (required).
    best_norm_opt_params : dict or None
        Normalised parameters at the best-fit state.
        If not provided, they are computed from best_opt_params and normalisation_spec.
        Providing them saves a redundant computation.
    rotation_key : str or None
        If provided, must be 'rc' or 'omega'. Used to transform the covariance matrix
        from the optimised 'mu' to rotation_key.
    npoints : int
        Number of streamline points used by the forward model.
    priors_keys, priors_means, priors_sigmas : tuples
        Parallel tuples of prior information for the optimised parameters. If provided,
        a prior penalty is added to the chi2 loss before computing the Hessian.
        If not provided, no prior penalty is added.

    Returns
    -------
    dict
        1-sigma uncertainties for each optimisable parameter.
    array
        Covariance matrix in physical space, ordered like best_opt_params.
    dict or None
        If rotation_key was given and 'mu' was optimised: {'keys': new_keys, 'cov': new_cov, 'errors': error_dict},
        where new_keys is the same as the original keys but with 'mu' replaced by rotation_key,
        new_cov is the covariance matrix transformed into that parameter space,
        and error_dict is the dict of 1-sigma errors for each parameter in new_keys.

    Raises
    ------
    ValueError
        If gradient_tol is non-finite or non-positive, normalisation_spec is missing or
        lacks an optimised key, or the best-fit v_r0 is not positive.
    """
    if gradient_tol is not None:
        gradient_tol = float(gradient_tol)
        if not math.isfinite(gradient_tol):
            raise ValueError('gradient_tol must be finite when provided.')
        if gradient_tol <= 0:
            raise ValueError('gradient_tol must be positive when provided.')

    if normalisation_spec is None:
        raise ValueError("normalisation_spec not provided")

    # v_r0 is handled through its raw (pre-softplus) coordinate, which requires v_r0 > 0
    has_v_r0 = 'v_r0' in best_opt_params
    if has_v_r0:
        v_r0_best = gradient_descent.to_float64(best_opt_params['v_r0'])
        if not bool(v_r0_best > 0):  # should never be triggered
            raise ValueError(f"best fit v_r0 must be positive, got v_r0={v_r0_best}")

    # keys fixes the parameter order of the output covariance matrix
    _, keys = params_dict_to_vector(best_opt_params)
    loss_method = gradient_descent.check_loss_method(loss_method)

    # validate before the (expensive) forward-model evaluation below
    missing_norm_keys = [key for key in keys if key not in normalisation_spec and key != 'v_r0']
    if missing_norm_keys:
        raise ValueError(
            f"normalisation_spec is missing optimised parameter keys required: {missing_norm_keys} "
        )

    # precompute model sort index at best fit params
    model_sort_idx, dmetric_model = compute_model_sort_idx(
        best_opt_params, fixed_params, distance_pc, prepared_data, npoints=npoints
    )

    if best_norm_opt_params is not None:
        norm_opt_params = best_norm_opt_params
    else:
        norm_opt_params = gradient_descent.normalise_opt_params(best_opt_params, normalisation_spec)
    norm_params_vec, norm_keys = params_dict_to_vector(norm_opt_params)

    def physical_from_norm(theta_norm_vec):
        """Map a normalised parameter vector to a dict of physical optimised parameters."""
        norm_params = vector_to_params_dict(theta_norm_vec, norm_keys)
        return gradient_descent.denormalise_opt_params(norm_params, normalisation_spec)

    def loss_vec_norm(theta_norm_vec):
        """Hessian-safe chi2 (frozen model ordering) as a function of the normalised vector."""
        model_params = {**physical_from_norm(theta_norm_vec), **fixed_params}
        return chi2_loss_hsafe(
            model_params,
            distance_pc,
            prepared_data,
            loss_method=loss_method,
            model_sort_idx=model_sort_idx,
            dmetric_model_frozen=dmetric_model,
            npoints=npoints,
            priors_keys=priors_keys,
            priors_means=priors_means,
            priors_sigmas=priors_sigmas
        )

    # Check gradient magnitude at best-fit parameters in normalised space.
    if gradient_tol is not None:
        def norm_loss_vec(theta_norm_vec):
            """Original chi2_loss; only its first derivative is taken here."""
            model_params = {**physical_from_norm(theta_norm_vec), **fixed_params}
            chi2_total, _, _ = gradient_descent.chi2_loss(
                model_params,
                distance_pc,
                prepared_data,
                loss_method=loss_method,
                priors_keys=priors_keys,
                priors_means=priors_means,
                priors_sigmas=priors_sigmas
            )
            return chi2_total

        norm_grad_vec = jax.grad(norm_loss_vec)(norm_params_vec)
        norm_grad_norm = float(gradient_descent.gradient_l2_norm(norm_grad_vec))

        if norm_grad_norm > gradient_tol:
            print(
                "WARNING: normalised-space gradient norm at best fit = "
                f"{norm_grad_norm:.3e} exceeds tolerance {gradient_tol:.3e}"
            )
            print("Optimisation may not have reached a minimum yet.")
            print("Parameter uncertainties may be less reliable. Consider:")
            print(" - Increasing n_epochs")
            print(" - Reducing learning rate for finer convergence")
            print(" - Reducing loss_threshold, if used, to allow more optimisation steps")

    # Hessian in normalised space, inverted to get the normalised-space covariance
    H_norm = jax.hessian(loss_vec_norm)(norm_params_vec)
    cov_norm = jnp.linalg.inv(H_norm)

    def denormalise_vec(theta_norm_vec):
        """Map the normalised vector to physical values (raw v_r0), ordered like keys."""
        physical_params = physical_from_norm(theta_norm_vec)
        # denormalise_opt_params returns v_r0 already passed through softplus - convert back
        # to raw space so that J is the transformation to raw_v_r0 (undone below)
        if has_v_r0:
            physical_params = dict(physical_params)
            physical_params['v_r0'] = gradient_descent.inv_softplus(physical_params['v_r0'])
        return jnp.stack([physical_params[k] for k in keys])

    # cov_phys = J @ cov_norm @ J^T
    J = jax.jacobian(denormalise_vec)(norm_params_vec)
    cov = J @ cov_norm @ J.T

    if has_v_r0:
        # raw_v_r0 -> v_r0 via the softplus Jacobian
        cov = transform_cov_matrix(
            cov, keys, best_opt_params, fixed_params, rotation_key=None, v_r0_is_raw=True
        )['cov']

    variances = jnp.diag(cov)
    if not bool(jnp.all(jnp.isfinite(variances) & (variances > 0))):
        print(
            "WARNING: covariance matrix has non-positive or non-finite variances; "
            "the Hessian may not be positive definite at the best fit, so the "
            "affected parameter uncertainties are unreliable (NaN for negative variances)."
        )

    # parameter errors
    errors = jnp.sqrt(variances)
    error_dict = {k: float(errors[i]) for i, k in enumerate(keys)}

    cov_transformed_dict = None
    if rotation_key is not None and 'mu' in keys:
        cov_transformed_dict = transform_cov_matrix(cov, keys, best_opt_params, fixed_params, rotation_key)

    return error_dict, cov, cov_transformed_dict
457
+
458
def transform_cov_matrix(cov, keys, best_opt_params, fixed_params, rotation_key=None, v_r0_is_raw=False):
    """Transform a covariance matrix from optimisation space to physical/output space.

    Maths:
        if A = original optimisation-space parameter vector
        and B = transformed parameter vector,
        then the covariance in B space is

            cov_B = J @ cov_A @ J^T

        where J = d(B) / d(A) is the Jacobian of the transformation from A to B, evaluated
        at the best-fit parameters (identity except for the rows being transformed).

    Parameters:
        cov: covariance matrix (e.g. from estimate_parameter_errors), in the same parameter order as 'keys'
        keys: list of str, optimised parameter names
        best_opt_params: dict of best-fit optimised parameters in physical space (where J is evaluated)
        fixed_params: dict of fixed parameters (also used when evaluating J)
        rotation_key: str or None, the rotation parameter to transform 'mu' into ('rc' or 'omega')
        v_r0_is_raw: bool, whether the v_r0 entry of cov is in raw space (before the softplus
            transformation); if True, v_r0 is mapped through softplus

    Returns:
        dict with
            'keys': list of str, same as keys but with 'mu' replaced by rotation_key if rotation_key is not None
            'cov': the transformed covariance matrix, in the order of 'keys'
            'errors': dict of 1-sigma errors (sqrt of the diagonal) for each parameter in 'keys'

    Raises:
        ValueError: if rotation_key is not 'rc'/'omega', 'mu' is not in keys, a parameter needed
            for the conversion is missing, or v_r0 is not positive when v_r0_is_raw is True.
    """
    if rotation_key is not None:
        if rotation_key not in ['rc', 'omega']:
            raise ValueError(f"rotation_key must be 'rc' or 'omega', got {rotation_key}")
        if 'mu' not in keys:
            raise ValueError(f"keys must include 'mu' for covariance transformation, got {keys}")
        # rc = mu * r0 only needs r0; omega additionally needs the mass
        required_keys = ('mass', 'r0') if rotation_key == 'omega' else ('r0',)
        for required_key in required_keys:
            if required_key not in best_opt_params and required_key not in fixed_params:
                raise ValueError(f"'{required_key}' must be present in either best_opt_params or fixed_params")

    # build the vector at which to evaluate the Jacobian
    params_list = []
    for k in keys:
        if k == 'v_r0' and v_r0_is_raw:
            v_r0_best = gradient_descent.to_float64(best_opt_params[k])
            if not bool(v_r0_best > 0):  # should never be triggered
                raise ValueError(f"best fit v_r0 must be positive, got v_r0={v_r0_best}")
            params_list.append(gradient_descent.inv_softplus(v_r0_best))
        else:
            params_list.append(float(best_opt_params[k]))
    params_vec = jnp.array(params_list, dtype=jnp.float64)

    def transform(vec_A):
        """Map optimisation-space vector A to output-space vector B (ordered like keys)."""
        opt_params = vector_to_params_dict(vec_A, keys)
        combined_params = {**fixed_params, **opt_params}
        rotation_val = None
        if rotation_key == 'rc':
            rotation_val = combined_params['mu'] * combined_params['r0']
        elif rotation_key == 'omega':
            rotation_val = stream_lines_grad.omega_from_mu(
                mu=combined_params['mu'], mass=combined_params['mass'], r0=combined_params['r0']
            )
        output = []
        for k in keys:
            if k == 'mu' and rotation_key is not None:
                output.append(rotation_val)
            elif k == 'v_r0' and v_r0_is_raw:
                output.append(gradient_descent.softplus(opt_params[k]))
            else:
                output.append(opt_params[k])
        return jnp.stack(output)

    J = jax.jacobian(transform)(params_vec)
    new_cov = J @ cov @ J.T

    new_keys = [rotation_key if (k == 'mu' and rotation_key is not None) else k for k in keys]
    new_sigmas = jnp.sqrt(jnp.diag(new_cov))
    error_dict = {k: float(new_sigmas[i]) for i, k in enumerate(new_keys)}

    return {'keys': new_keys, 'cov': new_cov, 'errors': error_dict}
535
+
536
+ #-------------------- old gradient_descent.py ---------------------
537
+
538
def params_dict_to_vector(opt_params):
    """Flatten a parameter dict into a float64 vector.

    Returns ``(vector, keys)``. ``keys`` lists the dict's keys in insertion order,
    and ``vector[i]`` holds ``opt_params[keys[i]]``.
    """
    ordered_names = [name for name in opt_params]
    values = [opt_params[name] for name in ordered_names]
    return jnp.array(values, dtype=jnp.float64), ordered_names
543
+
544
def vector_to_params_dict(vec, keys):
    """Rebuild a parameter dict by pairing each name in ``keys`` with ``vec`` at the same position."""
    params = {}
    for idx, name in enumerate(keys):
        params[name] = vec[idx]
    return params
547
+
548
@jax.jit
def match_model_to_data_curve(ra_model, dec_model, v_model, ra_data, dec_data):
    """
    Extract model values corresponding to data positions using the distance metric from
    extract_streamline.get_distance_metric

    Method:
    1. Compute the distance metric for model and data points
    2. Apply finite masks, and drop model points inside the innermost valid data point
    3. Map each data metric linearly from the data metric range onto the kept model metric range
    4. Interpolate model RA, Dec, and velocity at the mapped positions

    Returns
    -------
    ra_model_interp, dec_model_interp, v_model_interp, valid, dmetric_model, matching_trace
        where valid is a boolean mask with shape len(original data), marking
        retained data points; dmetric_model is the unmasked model distance metric; and
        matching_trace is a dict of diagnostic counts and metric ranges.
    """
    ra_model = extract_streamline.to_float64(ra_model)
    dec_model = extract_streamline.to_float64(dec_model)
    v_model = extract_streamline.to_float64(v_model)
    ra_data = extract_streamline.to_float64(ra_data)
    dec_data = extract_streamline.to_float64(dec_data)

    # get distance metrics
    dmetric_model, _ = extract_streamline.get_distance_metric(ra_model, dec_model)
    dmetric_data, _ = extract_streamline.get_distance_metric(ra_data, dec_data)

    # only finite values are valid
    model_valid = (
        jnp.isfinite(ra_model)
        & jnp.isfinite(dec_model)
        & jnp.isfinite(v_model)
        & jnp.isfinite(dmetric_model)
    )

    data_valid = (
        jnp.isfinite(ra_data)
        & jnp.isfinite(dec_data)
        & jnp.isfinite(dmetric_data)
    )

    # We also keep only model points with dmetric >= the minimum data dmetric.
    # The model shouldn't go further in than the innermost data point, as this is
    # where we no longer observe the streamer.
    data_min_eff = jnp.min(jnp.where(data_valid, dmetric_data, jnp.inf))
    data_max_eff = jnp.max(jnp.where(data_valid, dmetric_data, -jnp.inf))

    # enforce both constraints on model
    model_keep = model_valid & (dmetric_model >= data_min_eff)

    # weights: 0 = ignore, 1 = use. This is for jax/jit compatibility
    w_model = model_keep.astype(jnp.float64)

    d_model = jnp.where(model_keep, dmetric_model, 0.0)
    d_data = jnp.where(data_valid, dmetric_data, 0.0)

    # ---- sort ONLY MODEL using metric + weight penalty (dropped points go last) ----
    model_sort_key = d_model + (1.0 - w_model) * BIG
    model_idx = jnp.argsort(model_sort_key)

    d_model_s = d_model[model_idx]
    keep_s = w_model[model_idx] > 0
    # Dropped points are moved to xp = BIG below, but jnp.interp still reads the first of
    # them as the right-hand neighbour when a query lands exactly on the largest kept
    # metric. Their values may be non-finite, and 0 * NaN would make that interpolated
    # value NaN. Replace them with a finite value; the weight there is ~0.
    ra_s = jnp.where(keep_s, ra_model[model_idx], 0.0)
    dec_s = jnp.where(keep_s, dec_model[model_idx], 0.0)
    v_s = jnp.where(keep_s, v_model[model_idx], 0.0)

    # stats for trace and interpolation domain
    model_min = jnp.min(jnp.where(model_keep, d_model, jnp.inf))
    model_max = jnp.max(jnp.where(model_keep, d_model, -jnp.inf))

    model_span = model_max - model_min
    data_span = data_max_eff - data_min_eff
    model_span_safe = jnp.where(model_span == 0.0, 1.0, model_span)
    data_span_safe = jnp.where(data_span == 0.0, 1.0, data_span)

    # normalise data metric and map it onto the model metric range
    d_data_norm = (d_data - data_min_eff) / data_span_safe
    d_goal = model_min + d_data_norm * model_span_safe

    # Interpolate the model at the data points. Invalid model points are given huge
    # distance values so they don't affect the interpolation.
    xp = jnp.where(keep_s, d_model_s, BIG)

    ra_interp = jnp.interp(d_goal, xp, ra_s)
    dec_interp = jnp.interp(d_goal, xp, dec_s)
    v_interp = jnp.interp(d_goal, xp, v_s)

    valid = data_valid

    # NaN counts use the unmasked metrics; the masked d_model/d_data never contain NaN.
    matching_trace = {
        "model_points_total": model_idx.size,
        "model_nan_count": jnp.sum(jnp.isnan(dmetric_model)),
        "model_valid_points": model_valid.sum(),
        "data_points_total": ra_data.size,
        "data_nan_count": jnp.sum(jnp.isnan(dmetric_data)),
        "data_valid_points": data_valid.sum(),
        "model_metric_min": model_min,
        "model_metric_max": model_max,
        "data_metric_min": data_min_eff,
        "data_metric_max": data_max_eff,
        "model_metric_span": model_span_safe,
        "data_metric_span": data_span_safe}

    return ra_interp, dec_interp, v_interp, valid, dmetric_model, matching_trace
662
+
663
checked_matching = checkify.checkify(match_model_to_data_curve)


def checked_match_model_to_data_curve(*args, **kwargs):
    """Call match_model_to_data_curve through its checkify-functionalised wrapper.

    Accepts the same arguments as match_model_to_data_curve. Any recorded check
    failure is raised via ``Error.throw()``; otherwise the wrapped function's result
    tuple is returned unchanged.
    """
    check_error, outputs = checked_matching(*args, **kwargs)
    check_error.throw()
    return outputs
670
+
671
+
672
+
673
+
674
+
675
+
676
+
677
+