sting 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1776 @@
1
+ '''
2
+ This file contains the loss function and optimisation routines for streamfit.
3
+
4
+ The optimisation uses the Adam (adaptive moment estimation) optimiser to fit
5
+ streamline model parameters to observed data by minimizing chi-squared loss.
6
+
7
+ Last updated: 19-06-2026
8
+ '''
9
+
10
+ import os
11
+ from collections import namedtuple
12
+
13
+ import jax.numpy as jnp
14
+ from jax import value_and_grad, lax
15
+ import jax
16
+ from jax.experimental import checkify
17
+ from jax.random import key
18
+ import optax
19
+ from . import stream_lines_grad
20
+ from . import extract_streamline
21
+ import csv
22
+ import astropy.units as u
23
+ import math
24
+ import traceback
25
+ jax.config.update("jax_enable_x64", True)
26
+
27
# Numerical settings and shared constants.
# NOTE(review): VR0_MIN is presumably a lower floor for v_r0 -- confirm at its point of use.
VR0_MIN = 1e-6
# Large positive / negative sentinels used as fill values for masked min/max.
BIG = 1e30
BIG_NEG = -1e30

# Identifiers of the supported loss methods (0: radecvel, 1: rthetavel).
LOSS_METHOD_CHOICES = [0, 1]

# Names of the chi-squared components reported for each loss method.
LOSS_METHOD_COMPONENT_KEYS = {
    0: ('chi2_ra', 'chi2_dec', 'chi2_v', 'chi2_prior'),     # radecvel
    1: ('chi2_r', 'chi2_theta', 'chi2_v', 'chi2_prior'),    # rthetavel
}

# Trace csv columns shared by all loss methods. The first two entries
# ('epoch', 'loss') are sliced off by trace_fieldnames_for_loss_method.
TRACE_COMMON_FIELDNAMES = [
    'epoch', 'loss',
    'chi2_total', 'grad_norm',
    'model_points_total', 'model_nan_count', 'model_valid_points',
    'model_metric_span',
    'model_inner_count', 'data_inner_count',
    'data_points_total', 'data_valid_points',
    'data_retained_count', 'model_retained_count',
    'overlap_metric_min', 'overlap_metric_max',
]
57
+
58
# Units every dimensional parameter is converted to before its units are stripped.
# 'mu' (= rc/r0) is dimensionless and therefore has no entry.
CANONICAL_UNITS = {
    "r0": u.au,
    "theta0": u.rad,
    "phi0": u.rad,
    "inc": u.rad,
    "pa": u.rad,
    "v_r0": u.km / u.s,
    "mass": u.Msun,
    "rmin": u.au,
    "deltar": u.au,
    "v_lsr": u.km / u.s,
    "rc": u.au,
    "omega": 1 / u.s,
}

# Parameters that are angles (stored in radians, displayed in degrees).
ANGLE_KEYS = {'theta0', 'phi0', 'inc', 'pa'}

# Fixed ranges, in radians, applied automatically to the angle parameters.
ANGLE_BOUNDS_RAD = {
    'theta0': (0.0, jnp.pi),
    'phi0': (0.0, 2 * jnp.pi),
    'inc': (-jnp.pi/2, jnp.pi/2),
    'pa': (0.0, 2 * jnp.pi),
}

# Unit suffixes used when printing non-angle parameters ('mu' is dimensionless).
DISPLAY_UNITS = {
    'r0': 'au',
    'v_r0': 'km/s',
    'mass': 'M_sun',
    'rmin': 'au',
    'deltar': 'au',
    'v_lsr': 'km/s',
    'rc': 'au',
    'omega': '1/s',
}

# Every parameter name accepted by the streamline model.
STREAMLINE_MODEL_PARAM_KEYS = (
    'r0', 'theta0', 'phi0',
    'rc', 'omega', 'mu',
    'v_r0', 'mass',
    'inc', 'pa',
    'rmin', 'deltar', 'v_lsr',
)
111
+
112
#### Return types

# Covariance matrix and parameter-error information for a fitted model.
CovarianceResult = namedtuple(
    'CovarianceResult',
    (
        # 2-D array: physical-space covariance matrix from the Hessian,
        # ordered consistently with opt_keys / best_params / fixed_params.
        'covariance',
        # list[str]: parameter names for the rows/cols of covariance
        # (mu-substituted, i.e. 'mu' in place of 'rc'/'omega').
        'opt_keys',
        # dict: best-fit values in the mu-substituted parameterisation.
        'best_opt_params',
        # dict: fixed parameters in the mu-substituted parameterisation.
        'fixed_params',
        # dict: 1-sigma errors keyed by opt_keys names, or None.
        'param_errors',
        # dict or None: Jacobian-transformed result when 'mu' was substituted
        # for 'rc'/'omega'; keys are 'keys', 'cov', 'errors'.
        'transformed_cov',
    ),
)

# Outcome of a model fit, including the best-fit parameters.
FitResult = namedtuple(
    'FitResult',
    (
        # dict: best-fit optimised parameters in the original user-supplied
        # parameterisation (rc/omega restored).
        'best_opt_params',
        # list[float]: loss value at every epoch.
        'loss_history',
        # dict or None: 1-sigma errors in display parameterisation, or None
        # if uncertainty estimation failed.
        'param_errors',
        # CovarianceResult or None: full covariance information needed for
        # sampling, or None if estimation failed.
        'covariance_result',
    ),
)
137
+
138
def convert_and_strip_bound_units(bounds):
    """
    Convert bounds given as astropy quantities into canonical units and
    return them as plain ``(float, float)`` tuples.

    Entries that are not astropy quantities are assumed to already be in
    canonical units and are only cast to floats. ``None`` yields ``{}``.

    Required because the JAX optimiser works with unitless arrays.

    Raises
    ------
    ValueError : if a quantity is supplied for a parameter that has no
        entry in CANONICAL_UNITS.
    """
    if bounds is None:
        return {}

    stripped = {}
    for name, pair in bounds.items():
        if not isinstance(pair, u.Quantity):
            # already unitless
            stripped[name] = tuple(float(edge) for edge in pair)
            continue
        if name not in CANONICAL_UNITS:
            raise ValueError(f"The parameter {name} doesn't have defined canonical units...")
        canonical = pair.to(CANONICAL_UNITS[name])
        lower = float(canonical[0].value)
        upper = float(canonical[1].value)
        stripped[name] = (lower, upper)
    return stripped
165
+
166
def check_loss_method(loss_method):
    """Return ``loss_method`` unchanged if it is one of LOSS_METHOD_CHOICES.

    Raises
    ------
    ValueError : if the method is not recognised.
    """
    if loss_method in LOSS_METHOD_CHOICES:
        return loss_method
    raise ValueError(
        f"Unknown loss_method '{loss_method}'. "
        "Choose from: 0: radecvel 1: rthetavel"
    )
174
+
175
+
176
def trace_fieldnames_for_loss_method(loss_method):
    """Return the ordered trace csv column names for the chosen loss method.

    The header is 'epoch', 'loss', then the chi-squared components of the
    method, then the common columns (with their leading 'epoch'/'loss'
    entries dropped).
    """
    method = check_loss_method(loss_method)
    header = ['epoch', 'loss']
    header.extend(LOSS_METHOD_COMPONENT_KEYS[method])
    header.extend(TRACE_COMMON_FIELDNAMES[2:])
    return header
180
+
181
+
182
def is_numeric_value(value):
    """Return True when ``value`` converts to a JAX array with a numeric,
    non-boolean dtype; False otherwise (including when conversion fails)."""
    try:
        dtype = jnp.asarray(value).dtype
    except Exception:
        return False
    if dtype == jnp.bool_:
        # booleans are deliberately not counted as numbers
        return False
    is_number = jnp.issubdtype(dtype, jnp.number)
    return bool(is_number)
191
+
192
@jax.jit
def to_float64(value):
    """Return ``value`` (scalar or array-like) as a JAX array with float64 dtype requested"""
    as_float = jnp.asarray(value, dtype=jnp.float64)
    return as_float
196
+
197
@jax.jit
def softplus(x):
    """Softplus, log(1 + exp(x)), evaluated as logaddexp(x, 0); used for v_r0"""
    zero = 0.0
    return jnp.logaddexp(x, zero)
201
+
202
@jax.jit
def inv_softplus(y):
    """Inverse of softplus, used for v_r0: returns x such that softplus(x) = y.

    Computes ``y + log(1 - exp(-y))``. The ``1 - exp(-y)`` term is evaluated
    as ``-expm1(-y)`` so that it keeps its precision for very small positive
    ``y``; evaluating it via ``exp(-y)`` rounds to exactly 1 there and turns
    the result into -inf.

    Only defined for y > 0 (y = 0 gives -inf).
    """
    y = to_float64(y)
    return y + jnp.log(-jnp.expm1(-y))
207
+
208
+
209
def get_checkify_error_message(err):
    """Return a human-readable message for a checkify-style error object.

    If ``err`` has a ``get`` attribute its return value is used. Otherwise
    ``err.throw()`` is called and the text of any raised exception is
    returned; None is returned when nothing is raised.
    """
    if not hasattr(err, 'get'):
        try:
            err.throw()
        except Exception as exc:
            return str(exc)
        else:
            return None
    return err.get()
219
+
220
def make_data_tuple_float64(values):
    """Return a tuple holding each element of ``values`` passed through to_float64"""
    converted = [to_float64(item) for item in values]
    return tuple(converted)
223
+
224
+
225
def clean_model_param_dict(params, dict_name):
    """Return a cleaned copy of a parameter dictionary.

    - ``None`` is treated as an empty dictionary.
    - ``None`` values are kept as ``None``.
    - astropy quantities are converted to CANONICAL_UNITS and stripped.
    - every other value is assumed to be in canonical units already and is
      converted to a float64 JAX array.
    - a 'theta0' that is (close to) exactly 0 or pi is nudged by 1e-8 into
      the open interval (0, pi).

    ``dict_name`` is only used in error messages.

    Raises
    ------
    TypeError : if ``params`` is neither None nor a dict.
    ValueError : if a quantity is given for a key without canonical units.
    KeyError : if a key is not in STREAMLINE_MODEL_PARAM_KEYS.
    """
    if params is None:
        params = {}
    elif not isinstance(params, dict):
        raise TypeError(f"{dict_name} must be a dictionary, got {type(params).__name__}.")

    cleaned = {}
    for name, raw in params.items():
        if raw is None:
            cleaned[name] = None
            continue
        if isinstance(raw, u.Quantity):
            if name not in CANONICAL_UNITS:
                raise ValueError(f"The parameter {name} doesn't have defined canonical units...")
            raw = raw.to(CANONICAL_UNITS[name]).value
        # raw numbers are assumed to be in canonical units already
        cleaned[name] = jnp.asarray(raw, dtype=jnp.float64)

    nudge = to_float64(1e-8)
    # Exact polar-angle edge values (theta=0 or theta=pi) can cause downstream
    # numerical issues, so move them slightly inside (0, pi).
    if 'theta0' in cleaned:
        try:
            theta = to_float64(cleaned['theta0'])
            if bool(jnp.all(jnp.isclose(theta, to_float64(0.0)))):
                cleaned['theta0'] = theta + nudge
            elif bool(jnp.all(jnp.isclose(theta, to_float64(jnp.pi)))):
                cleaned['theta0'] = theta - nudge
        except Exception:
            # best effort only: leave theta0 untouched if the check cannot be evaluated
            pass

    unknown = sorted(name for name in cleaned if name not in STREAMLINE_MODEL_PARAM_KEYS)
    if unknown:
        raise KeyError(
            f"Unknown parameter keys in {dict_name}: {unknown} "
            f"Supported keys are: {list(STREAMLINE_MODEL_PARAM_KEYS)}"
        )

    return cleaned
268
+
269
+
270
def check_param_types(opt_params, fixed_params):
    """Check that all model parameters are numeric.

    Booleans are rejected. ``None`` is accepted only for a fixed 'rmin';
    an optimisable 'rmin' of ``None`` is an error.

    Raises
    ------
    ValueError : if 'rmin' is None in ``opt_params``.
    TypeError : if any other value is not numeric.
    """
    for name, val in opt_params.items():
        if name == 'rmin' and val is None:
            raise ValueError("'rmin' cannot be None")
        numeric = (not isinstance(val, bool)) and is_numeric_value(val)
        if not numeric:
            raise TypeError(
                f"Optimisable parameter '{name}' must be numeric, "
                f"got value of type {type(val).__name__}."
            )

    for name, val in fixed_params.items():
        if name == 'rmin' and val is None:
            # a fixed rmin may legitimately be left unset
            continue
        numeric = (not isinstance(val, bool)) and is_numeric_value(val)
        if not numeric:
            raise TypeError(
                f"Fixed parameter '{name}' must be numeric"
                " (or None only for 'rmin'), "
                f"got value of type {type(val).__name__}."
            )
290
+
291
+
292
def sanitize_param_partition(initial_opt_params, fixed_params, require_nonempty_opt=False):
    """Clean and validate the split of parameters into optimised and fixed.

    Both dictionaries are passed through clean_model_param_dict, then checked:
    - no parameter may appear in both dictionaries;
    - exactly one of the rotation keys 'rc', 'omega', 'mu' must be present
      across the two dictionaries;
    - every other key of STREAMLINE_MODEL_PARAM_KEYS must be present;
    - if ``require_nonempty_opt`` is set, at least one parameter must be
      optimisable;
    - all values must pass check_param_types.

    Returns
    -------
    (opt_params, fixed_params) : the cleaned dictionaries.

    Raises
    ------
    KeyError : overlapping, ambiguous-rotation or missing parameters.
    ValueError : empty ``initial_opt_params`` when one is required.
    """
    opt_params = clean_model_param_dict(initial_opt_params, 'initial_opt_params')
    fixed_params = clean_model_param_dict(fixed_params, 'fixed_params')

    overlap = sorted(set(opt_params).intersection(fixed_params))
    if overlap:
        raise KeyError(
            f"Parameters cannot be present in both initial_opt_params and fixed_params! Overlap: {overlap}"
        )

    supplied = set(opt_params).union(fixed_params)
    rotation_names = ('rc', 'omega', 'mu')

    # exactly one of the rotation keys must be supplied
    rotation_keys_present = [name for name in rotation_names if name in supplied]
    if len(rotation_keys_present) != 1:
        raise KeyError(
            f"Exactly one of 'rc', 'omega', or 'mu' must be provided. You have provided: {rotation_keys_present}"
        )

    # every non-rotation parameter is mandatory
    missing = [
        name for name in STREAMLINE_MODEL_PARAM_KEYS
        if name not in supplied and name not in rotation_names
    ]
    if missing:
        raise KeyError(
            "Missing required streamline parameters across initial_opt_params and fixed_params: "
            f"{missing}."
        )

    if require_nonempty_opt and not opt_params:
        raise ValueError(
            "initial_opt_params must contain at least one optimisable parameter. "
        )

    check_param_types(opt_params, fixed_params)

    return opt_params, fixed_params
333
+
334
+
335
def prepare_model_params(opt_params, fixed_params):
    """Sanitise the opt/fixed dictionaries and merge them into one.

    Returns
    -------
    (model_params, opt_params, fixed_params) : the merged dictionary (the
    optimised values take precedence) followed by the two cleaned inputs.
    """
    clean_opt, clean_fixed = sanitize_param_partition(opt_params, fixed_params)
    merged = {**clean_fixed, **clean_opt}
    return merged, clean_opt, clean_fixed
341
+
342
+
343
def standardise_param_bounds(param_bounds):
    """Return a shallow dict copy of ``param_bounds`` after checking its keys.

    ``None`` is passed through unchanged.

    Raises
    ------
    KeyError : if a key is not in STREAMLINE_MODEL_PARAM_KEYS.
    """
    if param_bounds is None:
        return None

    bounds_copy = dict(param_bounds)
    unknown = sorted(
        name for name in bounds_copy if name not in STREAMLINE_MODEL_PARAM_KEYS
    )
    if not unknown:
        return bounds_copy

    raise KeyError(
        f"Unknown params in param_bounds: {unknown}. "
        f"Supported params are: {list(STREAMLINE_MODEL_PARAM_KEYS)}"
    )
358
+
359
+
360
def build_normalisation_spec(opt_params, param_bounds):
    """Build the offset and scale used to normalise optimised parameters.

    For every optimised parameter except 'v_r0' (which is handled by a
    softplus transform instead), the bounds ``(min, max)`` give
    ``offset = min`` and ``scale = max - min``.

    Parameters
    ----------
    opt_params : dict : optimised parameters and their initial values
    param_bounds : dict : ``{name: (min, max)}``; required for every
        optimised parameter other than 'v_r0'

    Returns
    -------
    dict : ``{name: {'offset': ..., 'scale': ...}}``

    Raises
    ------
    ValueError : if ``param_bounds`` is None, bounds are missing, malformed,
        non-finite, not strictly increasing, or the initial value lies
        outside them.
    """
    if param_bounds is None:
        raise ValueError(
            "param_bounds is required because optimisation is performed in normalised space. "
            "Provide bounds for 'mass', 'r0' if you are optimising these."
        )

    missing = [key for key in opt_params if key != 'v_r0' and key not in param_bounds]
    if missing:
        raise ValueError(
            "Missing bounds for optimised parameters: "
            f"{missing}. Please add (min, max) entries for all parameters you want to optimise."
        )

    normalisation_spec = {}
    for key, value in opt_params.items():
        if key == 'v_r0':
            if key in param_bounds:
                print("Notice: Ignoring user-supplied bounds for 'v_r0' since we use a softplus transform for this parameter instead of normalisation.")
            continue

        bounds = param_bounds[key]
        if not isinstance(bounds, (tuple, list)) or len(bounds) != 2:
            raise ValueError(
                f"Bounds for '{key}' must be a 2-element (min, max) tuple. "
                f"Got: {bounds!r}"
            )

        lower_bound = to_float64(bounds[0])
        upper_bound = to_float64(bounds[1])
        if not bool(jnp.isfinite(lower_bound)) or not bool(jnp.isfinite(upper_bound)):
            raise ValueError(
                f"Bounds for '{key}' must be finite. Got ({float(lower_bound)}, {float(upper_bound)})"
            )
        if not bool(upper_bound > lower_bound):
            raise ValueError(
                f"Bounds for '{key}' must satisfy min < max. Got ({float(lower_bound)}, {float(upper_bound)})"
            )

        value = to_float64(value)
        if not bool((value >= lower_bound) & (value <= upper_bound)):
            raise ValueError(
                f"Initial value for '{key}' ({float(value)}) is outside bounds "
                f"({float(lower_bound)}, {float(upper_bound)})."
            )

        normalisation_spec[key] = {
            'offset': lower_bound,
            'scale': upper_bound - lower_bound,
        }

    return normalisation_spec
417
+
418
+
419
def normalise_opt_params(opt_params, normalisation_spec):
    """Map optimised parameters into the space used by the optimiser.

    Every parameter except 'v_r0' is rescaled as ``(value - offset) / scale``
    using ``normalisation_spec``. 'v_r0' is instead stored as the raw value
    for which ``v_r0 = softplus(raw)``.

    Raises
    ------
    ValueError : if 'v_r0' is negative.
    """
    normalised = {}
    for key, value in opt_params.items():
        if key == 'v_r0':
            value = to_float64(value)
            if value < 0:
                raise ValueError(f"v_r0 must be non-negative, got {float(value)}")
            # softplus only produces strictly positive values and
            # inv_softplus(0) evaluates to -inf, so an accepted v_r0 of exactly
            # zero is lifted to the small positive floor VR0_MIN first.
            value = jnp.maximum(value, to_float64(VR0_MIN))
            normalised[key] = inv_softplus(value)
            continue
        offset = normalisation_spec[key]['offset']
        scale = normalisation_spec[key]['scale']
        normalised[key] = (to_float64(value) - offset) / scale
    return normalised
434
+
435
+
436
def denormalise_opt_params(norm_opt_params, normalisation_spec):
    """Map optimiser-space parameters back to physical values.

    'v_r0' is recovered with softplus; everything else is rescaled as
    ``value * scale + offset``. The circular angles 'phi0' and 'pa' are
    additionally wrapped into [0, 2*pi).
    """
    circular = ('phi0', 'pa')
    physical = {}
    for name, norm_val in norm_opt_params.items():
        if name == 'v_r0':
            physical[name] = softplus(to_float64(norm_val))
            continue
        spec = normalisation_spec[name]
        offset, scale = spec['offset'], spec['scale']
        rescaled = to_float64(norm_val) * scale + offset
        if name in circular:
            # wrap because the angle is periodic
            rescaled = jnp.mod(rescaled, 2*jnp.pi)
        physical[name] = rescaled
    return physical
451
+
452
def log_header(param_key):
    """Return the log column header for a parameter: ``'name [unit]'`` when
    the parameter has canonical units, otherwise just the name ('mu' is
    always returned bare)."""
    if param_key == 'mu':
        return 'mu'
    unit = CANONICAL_UNITS.get(param_key)
    if unit is None:
        return param_key
    return f'{param_key} [{unit}]'
457
+
458
+
459
+
460
def get_rotation_param_key(opt_params, fixed_params):
    """Return the first of 'rc', 'omega', 'mu' (checked in that order) found
    in either dictionary.

    Raises
    ------
    KeyError : if none of the three is present.
    """
    supplied = set(opt_params).union(fixed_params)
    found = next((name for name in ('rc', 'omega', 'mu') if name in supplied), None)
    if found is None:
        raise KeyError("None of the parameters 'rc', 'omega', or 'mu' are present in the parameters!")
    return found
467
+
468
+
469
def mu_from_rotation_param(rotation_key, value, mass, r0):
    """Convert whichever rotation parameter is present into mu.

    Parameters
    ----------
    rotation_key : str : one of 'mu', 'rc', 'omega'
    value : value of that rotation parameter
    mass, r0 : only used for the 'rc' (r0) and 'omega' (mass, r0) conversions

    Returns
    -------
    mu : ``value`` itself for 'mu', ``value / r0`` for 'rc', and the result
    of stream_lines_grad.mu_from_omega for 'omega'.

    Raises
    ------
    ValueError : if ``rotation_key`` is not one of the three supported names
        (previously this silently returned None).
    """
    if rotation_key == 'mu':
        return value
    if rotation_key == 'rc':
        return value / r0
    if rotation_key == 'omega':
        return stream_lines_grad.mu_from_omega(omega=value, mass=mass, r0=r0)
    raise ValueError(
        f"Unknown rotation parameter '{rotation_key}'. Expected one of 'rc', 'omega', 'mu'."
    )
477
+
478
def rotation_param_from_mu(rotation_key, mu, mass, r0):
    """Convert mu into the rotation parameter named by ``rotation_key``
    (the one originally supplied by the user).

    Parameters
    ----------
    rotation_key : str : one of 'mu', 'rc', 'omega'
    mu : value of mu
    mass, r0 : only used for the 'rc' (r0) and 'omega' (mass, r0) conversions

    Returns
    -------
    ``mu`` itself for 'mu', ``mu * r0`` for 'rc', and the result of
    stream_lines_grad.omega_from_mu for 'omega'.

    Raises
    ------
    ValueError : if ``rotation_key`` is not one of the three supported names
        (previously this silently returned None).
    """
    if rotation_key == 'mu':
        return mu
    if rotation_key == 'rc':
        return mu * r0
    if rotation_key == 'omega':
        return stream_lines_grad.omega_from_mu(mu=mu, mass=mass, r0=r0)
    raise ValueError(
        f"Unknown rotation parameter '{rotation_key}'. Expected one of 'rc', 'omega', 'mu'."
    )
486
+
487
def with_mu_substituted(opt_params, fixed_params, param_bounds=None):
    """Replace the user's rotation parameter ('rc', 'omega' or 'mu') by 'mu'.

    'mu' is the parameter used internally for the optimisation because it
    has natural bounds (0, 1), which keeps the optimiser away from rc > r0.

    The inputs are not modified; copies are returned. The rotation entry is
    removed from whichever dictionary held it and 'mu' is added to that same
    dictionary. When the rotation parameter is optimised, any user-supplied
    bounds for it are dropped (with a printed notice) and 'mu' gets bounds
    just inside (0, 1).

    Returns
    -------
    (opt_params, fixed_params, param_bounds, rotation_key)
    """
    rotation_key = get_rotation_param_key(opt_params, fixed_params)
    opt_params = dict(opt_params)
    fixed_params = dict(fixed_params)
    param_bounds = {} if param_bounds is None else dict(param_bounds)

    merged = {**fixed_params, **opt_params}
    mass = merged['mass']
    r0 = merged['r0']

    is_optimised = rotation_key in opt_params
    # the dictionary that currently owns the rotation parameter
    holder = opt_params if is_optimised else fixed_params

    mu_value = mu_from_rotation_param(rotation_key, holder[rotation_key], mass, r0)
    del holder[rotation_key]
    holder['mu'] = mu_value

    if is_optimised:
        # user bounds on rc/omega/mu are replaced by the (0, 1) bounds on mu
        if rotation_key in param_bounds:
            print(
                f"Notice: Ignoring user-supplied bounds for '{rotation_key}', since the optimisation is performed in 'mu' space. "
                f"Using (0, 1) bounds for 'mu' instead."
            )
            del param_bounds[rotation_key]
        # tiny epsilon to avoid rc=r0 or rc=0
        param_bounds['mu'] = (0.0+1e-6, 1.0-1e-6)

    return opt_params, fixed_params, param_bounds, rotation_key
526
+
527
def auto_fill_angle_bounds(opt_params, param_bounds):
    """Return a copy of ``param_bounds`` with the fixed ranges from
    ANGLE_BOUNDS_RAD filled in for every angle that is being optimised.

    Bounds the user supplied for such an angle are overwritten and a notice
    is printed. ``None`` is treated as an empty dictionary.
    """
    filled = dict(param_bounds) if param_bounds is not None else {}
    for name, fixed_range in ANGLE_BOUNDS_RAD.items():
        if name in opt_params:
            if name in filled:
                print(f"Notice: Ignoring supplied bounds for '{name}', since the bounds are fixed by physics. Using automatic bounds ({math.degrees(fixed_range[0])} - {math.degrees(fixed_range[1])} degrees).")
            filled[name] = fixed_range

    return filled
542
+
543
def format_param(key, value):
    """
    Format a parameter value for display, with 6 significant digits and units.

    - angles (keys in ANGLE_KEYS) are converted from radians to degrees
    - 'mu' is printed without units
    - other parameters get the suffix from DISPLAY_UNITS, if any
    """
    number = float(value)
    if key in ANGLE_KEYS:
        return f"{math.degrees(number):.6g} deg"
    if key == 'mu':
        return f"{number:.6g}"
    if key == 'omega':
        return f"{number:.6g} 1/s"
    unit = DISPLAY_UNITS.get(key, '')
    suffix = f" {unit}" if unit else ""
    return f"{number:.6g}{suffix}"
562
+
563
def validate_priors(priors, opt_params, fixed_params):
    """Check the validity of the priors dictionary supplied by the user.

    priors must be a dict of the form::

        {param_name: (mean, sigma), ...}

    where each entry is either a 2-element tuple/list, or a length-2 astropy
    Quantity such as ``(4.0, 1.0) * u.Msun``. Plain numbers are assumed to be
    in the canonical units used by the rest of the code; quantities are
    converted to those units.

    Only parameters that are being optimised (i.e. keys present in
    opt_params) can have priors.

    Parameters
    ----------
    priors : dict or None
    opt_params : dict : optimisable parameters (after sanitisation)
    fixed_params : dict : fixed parameters (after sanitisation)

    Returns
    -------
    dict : ``{param_name: (mean, sigma)}`` as plain floats, or {} if None.

    Raises
    ------
    TypeError : if priors is not a dict.
    KeyError : if a key is not a known model parameter.
    ValueError : if a key is fixed or not optimised, the entry is malformed,
        a quantity is given for a parameter without canonical units, or
        sigma is not a positive number.
    """
    if priors is None:
        return {}
    if not isinstance(priors, dict):
        raise TypeError(f"priors must be a dict, got {type(priors).__name__}.")

    def to_canonical_float(key, entry):
        # Quantities are converted to the canonical unit of the parameter;
        # plain numbers are assumed to be expressed in it already.
        if isinstance(entry, u.Quantity):
            if key not in CANONICAL_UNITS:
                raise ValueError(f"No canonical units defined for '{key}'.")
            return float(entry.to(CANONICAL_UNITS[key]).value)
        return float(entry)

    validated = {}
    for key, val in priors.items():
        if key not in STREAMLINE_MODEL_PARAM_KEYS:
            raise KeyError(
                f"Unknown parameter '{key}' in priors. "
                f"Supported keys are: {list(STREAMLINE_MODEL_PARAM_KEYS)}"
            )
        if key in fixed_params:
            raise ValueError(
                f"Parameter '{key}' is fixed and cannot have a prior. "
                "Move it to initial_opt_params if you want to optimise it with a prior."
            )
        if key not in opt_params:
            raise ValueError(
                f"Parameter '{key}' has a prior but is not being optimised. "
                "Add it to initial_opt_params or remove it from priors."
            )

        if isinstance(val, u.Quantity) and val.shape == (2,):
            mean, sigma = val[0], val[1]
        elif isinstance(val, (tuple, list)) and len(val) == 2:
            mean, sigma = val
        else:
            raise ValueError(
                f"Prior for '{key}' must be a 2-element (mean, sigma) tuple, "
                f"or a length-2 Quantity e.g. (4.0, 1.0) * u.Msun. Got: {val!r}"
            )

        mean = to_canonical_float(key, mean)
        sigma = to_canonical_float(key, sigma)
        # "not sigma > 0" (rather than "sigma <= 0") also rejects NaN
        if not sigma > 0:
            raise ValueError(
                f"Prior sigma for '{key}' must be positive, got {sigma}."
            )
        validated[key] = (mean, sigma)
    return validated
637
+
638
+
639
+ @jax.jit(static_argnames=('priors_keys',))
640
+ def compute_prior_penalty(model_params, priors_means, priors_sigmas, priors_keys):
641
+ """Compute the Gaussian prior penalty term:
642
+ sum_k ((theta_k - mu_k) / sigma_k)^2.
643
+
644
+ Parameters
645
+ ----------
646
+ model_params : dict: physical model parameters
647
+ priors_means : tuple of float: prior means, one per prior key
648
+ priors_sigmas : tuple of float: prior sigmas, one per prior key
649
+ priors_keys : tuple of str: corresponding prior keys (parameter names)
650
+
651
+ Returns
652
+ -------
653
+ scalar JAX array : the prior chi-squared contribution
654
+ """
655
+ penalty = jnp.asarray(0.0, dtype=jnp.float64)
656
+ for key, mu_p, sigma_p in zip(priors_keys, priors_means, priors_sigmas):
657
+ mu_p = jnp.asarray(mu_p, dtype=jnp.float64)
658
+ sigma_p = jnp.asarray(sigma_p, dtype=jnp.float64)
659
+ penalty = penalty + ((model_params[key] - mu_p) / sigma_p) ** 2
660
+ return penalty
661
+
662
def add_rc_omega_to_log(row, opt_params, fixed_params, all_param_keys):
    """Add 'rc' and 'omega' (derived from mu) to a log row, for user-friendly output.

    Nothing is added when 'mu' is not among ``all_param_keys`` or when any of
    mass, r0, mu cannot be found (optimised values take precedence over
    fixed ones). ``row`` is modified in place and returned.
    """
    if 'mu' not in all_param_keys:
        return row

    def lookup(name):
        # prefer the optimised value, fall back to the fixed one (or None)
        if name in opt_params:
            return opt_params[name]
        return fixed_params.get(name, None)

    mass_val = lookup('mass')
    r0_val = lookup('r0')
    mu_val = lookup('mu')
    if mass_val is None or r0_val is None or mu_val is None:
        return row

    rc_val = mu_val * r0_val
    omega_val = stream_lines_grad.omega_from_mu(mu=mu_val, mass=mass_val, r0=r0_val)
    row['rc'] = float(rc_val)
    row['omega'] = float(omega_val)
    return row
675
+
676
+
677
def build_trace_row(epoch, loss_value, loss_trace, grad_norm, loss_method):
    """Flatten a loss-trace dictionary into a single csv row (a flat dict).

    Missing chi-squared / metric values default to NaN and missing counts
    default to 0. Column order is: epoch, loss, the chi-squared components of
    the loss method, then the common trace columns.
    """
    loss_method = check_loss_method(loss_method)
    nan = float('nan')

    chi2 = loss_trace.get('chi2_components', {})
    matching = loss_trace.get('matching', {})
    model_metric = matching.get('distance_metric_model', {})
    data_metric = matching.get('distance_metric_data', {})

    row = {'epoch': epoch, 'loss': loss_value}
    for component in LOSS_METHOD_COMPONENT_KEYS[loss_method]:
        row[component] = chi2.get(component, float('nan'))

    row['chi2_total'] = chi2.get('chi2_total', nan)
    row['grad_norm'] = grad_norm
    row['model_points_total'] = matching.get('model_points_total', 0)
    row['model_nan_count'] = matching.get('model_nan_count', 0)
    row['model_valid_points'] = matching.get('model_valid_points', 0)
    row['model_metric_span'] = matching.get('model_metric_span', nan)
    row['model_inner_count'] = model_metric.get('inner_count', 0)
    row['data_inner_count'] = data_metric.get('inner_count', 0)
    row['data_points_total'] = matching.get('data_points_total', 0)
    row['data_valid_points'] = matching.get('data_valid_points', 0)
    row['data_retained_count'] = matching.get('data_retained_count', 0)
    row['model_retained_count'] = matching.get('model_retained_count', 0)
    row['overlap_metric_min'] = matching.get('overlap_metric_min', nan)
    row['overlap_metric_max'] = matching.get('overlap_metric_max', nan)

    return row
710
+
711
def trace_tree_to_python(value):
    """Recursively convert scalar (0-d) arrays inside a trace tree to Python scalars.

    dicts, lists and tuples are rebuilt with their contents converted; None
    is preserved; anything that cannot be turned into a JAX array, or that
    is not 0-dimensional, is returned unchanged.
    """
    if value is None:
        return None
    if isinstance(value, dict):
        return {name: trace_tree_to_python(item) for name, item in value.items()}
    if isinstance(value, list):
        return [trace_tree_to_python(item) for item in value]
    if isinstance(value, tuple):
        return tuple(trace_tree_to_python(item) for item in value)

    try:
        as_array = jnp.asarray(value)
    except Exception:
        # not array-like (e.g. a string): leave as-is
        return value
    if as_array.ndim != 0:
        return value
    return as_array.item()
732
+
733
+
734
@jax.jit
def gradient_l2_norm(grad_tree):
    """Return the L2 norm of the gradients over all leaves of a pytree
    (a nested structure of lists/dicts/tuples containing arrays, which is
    what jax uses for gradients)."""
    squared_sum = jnp.asarray(0.0, dtype=jnp.float64)
    for leaf in jax.tree_util.tree_leaves(grad_tree):
        leaf_contribution = jnp.sum(jnp.square(leaf))
        squared_sum = squared_sum + leaf_contribution
    return jnp.sqrt(squared_sum)
744
+
745
@jax.jit(static_argnames=("npoints",))
def forward_model(model_params, distance_pc, npoints=10000):
    """
    Run the forward model using stream_lines_grad.checked_xyz_stream

    Parameters:
    -----------
    model_params: dict
        Dictionary of model parameters, including both optimised and fixed
        parameters. Must contain one of 'mu', 'rc' or 'omega' (checked in
        that order); 'rmin' may be None.
    distance_pc : float
        Distance to source in parsecs
    npoints : int
        Number of points to sample along the streamer (static under jit).
        The arrays have this fixed length for jax/jit compatibility; entries
        not flagged in valid_mask are set to 0.

    Returns:
    --------
    tuple: (ra_model, dec_model, v_model, valid_mask, err)
        - ra_model : RA offsets, -x / distance_pc (arcsec per the original notes)
        - dec_model : Dec offsets, z / distance_pc (arcsec per the original notes)
        - v_model : y-velocity component plus v_lsr (km/s per the original notes)
        - valid_mask : boolean mask of valid points, as returned by checked_xyz_stream
        - err : error object returned by checked_xyz_stream (not thrown here)

    Raises:
    -------
    ValueError : if none of 'mu', 'rc', 'omega' is in model_params.
    """
    distance_pc = to_float64(distance_pc)

    rmin = model_params['rmin']
    if rmin is None:
        # NOTE(review): the original comment states that rc*0.5 then always
        # dominates in a jnp.maximum downstream -- not visible in this file.
        rmin = to_float64(0.0)

    # derive mu from whichever rotation parameter is provided
    if 'mu' in model_params:
        mu = model_params['mu']
    elif 'rc' in model_params:
        mu = model_params['rc'] / model_params['r0']
    elif 'omega' in model_params:
        mu = stream_lines_grad.mu_from_omega(omega=model_params['omega'], mass=model_params['mass'], r0=model_params['r0'])
    else:
        raise ValueError("model_params must contain either 'rc', 'omega', or 'mu'")
    # work on a copy so the caller's dictionary is not modified
    model_params = dict(model_params)
    model_params['mu'] = mu

    err, (positions, velocities, valid_mask) = stream_lines_grad.checked_xyz_stream(
        mass=model_params['mass'],
        r0=model_params['r0'],
        theta0=model_params['theta0'],
        phi0=model_params['phi0'],
        mu=model_params['mu'],
        v_r0=model_params['v_r0'],
        inc=model_params['inc'],
        pa=model_params['pa'],
        rmin=rmin,
        deltar=model_params['deltar'],
        npoints=npoints
    )
    x, y, z = positions
    vx, vy, vz = velocities

    # x -> RA offset (sign flipped for the standard RA convention),
    # z -> Dec offset, vy -> line-of-sight velocity
    ra_model = -x / distance_pc
    dec_model = z / distance_pc
    # make velocity absolute by adding back v_lsr
    v_model = vy + model_params['v_lsr']

    # model points outside valid_mask are set to 0
    ra_model = jnp.where(valid_mask, ra_model, 0.0)
    dec_model = jnp.where(valid_mask, dec_model, 0.0)
    v_model = jnp.where(valid_mask, v_model, 0.0)

    return ra_model, dec_model, v_model, valid_mask, err
822
+
823
+
824
@jax.jit
def distance_metric_overlap(dmetric_model, model_finite_mask, dmetric_data, data_finite_mask):
    """Compute the overlapping range in the streamline distance metric between data and model.

    Masked-out entries are replaced by BIG before taking a minimum and by
    BIG_NEG before taking a maximum, so they never win.

    Returns
    -------
    (model_min, model_max, data_min, data_max, overlap_min, overlap_max)
        where overlap_min = max(model_min, data_min) and
        overlap_max = min(model_max, data_max).
    """
    def masked_min(metric, mask):
        return jnp.min(jnp.where(mask, metric, to_float64(BIG)))

    def masked_max(metric, mask):
        return jnp.max(jnp.where(mask, metric, to_float64(BIG_NEG)))

    model_min = masked_min(dmetric_model, model_finite_mask)
    model_max = masked_max(dmetric_model, model_finite_mask)
    data_min = masked_min(dmetric_data, data_finite_mask)
    data_max = masked_max(dmetric_data, data_finite_mask)

    overlap_min = jnp.maximum(model_min, data_min)
    overlap_max = jnp.minimum(model_max, data_max)
    return model_min, model_max, data_min, data_max, overlap_min, overlap_max
840
+
841
@jax.jit
def match_model_to_data_curve(ra_model, dec_model, v_model, valid_mask_model, ra_data, dec_data):
    """
    Extract model values corresponding to data positions using the distance metric from
    extract_streamline.get_distance_metric

    Method:
    1. Compute the distance metric for model and data points
    2. Build masks: data points must have finite RA, Dec and metric; model points must be
       flagged in valid_mask_model and have a metric >= the smallest valid data metric
    3. Normalise the data metric by its valid range (valid points land in [0, 1])
    4. Rescale the normalised data positions onto the retained model metric range
    5. Interpolate model RA, Dec, and velocity at the mapped positions

    Parameters
    ----------
    ra_model, dec_model, v_model : array
        Model curve samples (converted with to_float64).
    valid_mask_model : array
        Mask of usable model points (cast to bool before use).
    ra_data, dec_data : array
        Data positions (converted with to_float64).

    Returns
    -------
    ra_model_interp, dec_model_interp, v_model_interp, valid, model_keep, dmetric_model, matching_trace
        where valid is a boolean mask with shape len(original data), marking
        retained data points, model_keep is the boolean mask of model points used
        for the interpolation, dmetric_model is the distance metric of every model
        point and matching_trace is a dict of diagnostic counts and metric ranges.
        Interpolated values are produced for every data point, including those
        where valid is False; callers must apply the mask themselves.
    """
    ra_model = to_float64(ra_model)
    dec_model = to_float64(dec_model)
    v_model = to_float64(v_model)
    ra_data = to_float64(ra_data)
    dec_data = to_float64(dec_data)

    big = to_float64(BIG)
    big_neg = to_float64(BIG_NEG)

    # get distance metrics
    dmetric_model, _ = extract_streamline.get_distance_metric(ra_model, dec_model)
    dmetric_data, _ = extract_streamline.get_distance_metric(ra_data, dec_data)

    model_valid = valid_mask_model

    data_valid = (
        jnp.isfinite(ra_data)
        & jnp.isfinite(dec_data)
        & jnp.isfinite(dmetric_data)
    )

    # metric range covered by the valid data (sentinels stand in for invalid points)
    data_min_eff = jnp.min(jnp.where(data_valid, dmetric_data, big))
    data_max_eff = jnp.max(jnp.where(data_valid, dmetric_data, big_neg))

    # we also filter the model to keep only model points with dmetric >= minimum of data dmetric
    # this is because the model shouldn't go further in than the innermost data point
    # as this is where we no longer observe the streamer
    model_keep = model_valid.astype(bool) & (dmetric_model >= data_min_eff)

    # weights: 0 = ignore, 1 = use. This is for jax/jit compatibility
    w_model = model_keep.astype(jnp.float64)

    d_model = jnp.where(model_keep, dmetric_model, 0.0)
    d_data = jnp.where(data_valid, dmetric_data, 0.0)

    # ---- sort model using metric + weight penalty (pushes invalid points to the end) ----
    model_sort_key = d_model + (1.0 - w_model) * big
    model_idx = jnp.argsort(model_sort_key)

    d_model_s = d_model[model_idx]
    ra_s = ra_model[model_idx]
    dec_s = dec_model[model_idx]
    v_s = v_model[model_idx]
    w_model_s = w_model[model_idx]

    # metric range covered by the retained model points
    model_min = jnp.min(jnp.where(model_keep, dmetric_model, big))
    model_max = jnp.max(jnp.where(model_keep, dmetric_model, big_neg))

    # guard against zero (or negative, when nothing is retained) spans before dividing
    model_span = model_max - model_min
    data_span = data_max_eff - data_min_eff
    model_span_safe = jnp.where(model_span > to_float64(0.0), model_span, to_float64(1.0))
    data_span_safe = jnp.where(data_span > to_float64(0.0), data_span, to_float64(1.0))

    # normalise data metric, then stretch it onto the retained model range
    d_data_norm = (d_data - data_min_eff) / data_span_safe
    d_goal = model_min + d_data_norm * model_span_safe

    # interpolate model at data points, using weights to ignore invalid model points
    # by giving them huge distance values so they don't affect the interpolation
    xp = jnp.where(w_model_s > 0, d_model_s, big)

    ra_interp = jnp.interp(d_goal, xp, ra_s)
    dec_interp = jnp.interp(d_goal, xp, dec_s)
    v_interp = jnp.interp(d_goal, xp, v_s)

    # the retained-data mask is exactly the finite-data mask
    valid = data_valid

    # things for trace
    matching_trace = {
        "model_points_total": model_idx.size,
        "model_nan_count": jnp.sum(jnp.isnan(dmetric_model)),
        "model_valid_points": model_valid.sum(),
        "data_points_total": ra_data.size,
        "data_nan_count": jnp.sum(jnp.isnan(dmetric_data)),
        "data_valid_points": data_valid.sum(),
        "model_metric_min": model_min,
        "model_metric_max": model_max,
        "data_metric_min": data_min_eff,
        "data_metric_max": data_max_eff,
        "model_metric_span": model_span_safe,
        "data_metric_span": data_span_safe}

    return ra_interp, dec_interp, v_interp, valid, model_keep, dmetric_model, matching_trace
949
+
950
# checkify-wrapped matcher: calling it yields (error, outputs) rather than outputs alone
checked_matching = checkify.checkify(match_model_to_data_curve)


@jax.jit
def checked_match_model_to_data_curve(*args, **kwargs):
    """Call the checkified match_model_to_data_curve and hand back only its outputs.

    All positional and keyword arguments are forwarded untouched. The error
    object produced by the checkified call is thrown before returning.
    """
    check_error, matched = checked_matching(*args, **kwargs)
    # NOTE(review): throw() is invoked inside a jitted function here; confirm against the
    # checkify documentation that failures surface as intended in this context.
    check_error.throw()
    return matched
959
+
960
@jax.jit(static_argnames=("loss_method", "npoints", "priors_keys", "priors_means", "priors_sigmas"))
def chi2_loss(
    model_params,
    distance_pc,
    prepared_data,
    loss_method=0,
    npoints=10000,
    priors_keys=(),
    priors_means=(),
    priors_sigmas=()
):
    """Generates model via forward model and calculates loss between data and model.

    chi2_loss = chi2_data + chi2_priors.

    chi2_priors is the sum of Gaussian prior penalty terms for any optimised parameters
    where priors were given. See compute_prior_penalty

    Data points rejected by the model/data matching contribute exactly zero to
    every chi2 term. They are masked with jnp.where (on the inputs and on the
    squared terms) rather than multiplied by a 0/1 weight: rejected points are
    the ones with non-finite values, and 0 * NaN is NaN, which would otherwise
    poison both the loss and its gradient.

    Parameters
    ----------
    model_params : dict
        Model parameters, passed to forward_model and compute_prior_penalty.
    distance_pc : float
        Distance to source in parsecs (converted with to_float64).
    prepared_data : object
        Pre-computed data container. Attributes read here: ra_data, dec_data,
        v_data, ra_sigma_safe, dec_sigma_safe, v_sigma_safe, dmetric_data and,
        for loss_method 1, r_proj_data and theta_proj_data.
    loss_method : int
        0 uses RA, Dec and velocity residuals; otherwise projected radius,
        polar angle and velocity residuals are used. Static under jit.
    npoints : int
        Number of model points requested from forward_model. Static under jit.
    priors_keys, priors_means, priors_sigmas : tuple
        Prior specification forwarded to compute_prior_penalty. Static under
        jit, so the tuples must be hashable.

    Returns
    -------
    chi2_total, loss_trace, err
        chi2_total is the data chi2 plus the prior penalty. loss_trace is a dict
        with keys 'chi2_components', 'matching' and 'loss_method'. err is the
        error object returned by forward_model.
    """

    loss_method = check_loss_method(loss_method)

    distance_pc = to_float64(distance_pc)

    ra_data = prepared_data.ra_data
    dec_data = prepared_data.dec_data
    v_data = prepared_data.v_data
    ra_sigma = prepared_data.ra_sigma_safe
    dec_sigma = prepared_data.dec_sigma_safe
    v_sigma = prepared_data.v_sigma_safe

    ra_model, dec_model, v_model, valid_mask_model, err = forward_model(model_params, distance_pc, npoints=npoints)
    valid_mask_model = valid_mask_model.astype(jnp.bool_)

    ra_model_interp, dec_model_interp, v_model_interp, valid, _, dmetric_model, _ = (
        checked_match_model_to_data_curve(ra_model, dec_model, v_model, valid_mask_model, ra_data, dec_data)
    )

    dmetric_data = prepared_data.dmetric_data
    valid = jnp.asarray(valid, dtype=bool)

    model_finite_mask = (
        jnp.isfinite(ra_model)
        & jnp.isfinite(dec_model)
        & jnp.isfinite(v_model)
        & jnp.isfinite(dmetric_model)
    )

    def only_valid(values, fill):
        # replace entries of rejected data points by a harmless finite value
        return jnp.where(valid, values, fill)

    def masked_chi2(residual, sigma):
        # residual must already be built from only_valid(...) data so that it is
        # finite everywhere; the final where zeroes the rejected points
        z = residual / only_valid(sigma, 1.0)
        return jnp.sum(only_valid(z ** 2, 0.0))

    # Only compute chi2 on valid/retained data points
    chi2_v = masked_chi2(only_valid(v_data, 0.0) - v_model_interp, v_sigma)

    if loss_method == 0:
        chi2_ra = masked_chi2(only_valid(ra_data, 0.0) - ra_model_interp, ra_sigma)
        chi2_dec = masked_chi2(only_valid(dec_data, 0.0) - dec_model_interp, dec_sigma)
        chi2_total = chi2_ra + chi2_dec + chi2_v
        position_components = {
            'chi2_ra': chi2_ra,
            'chi2_dec': chi2_dec,
        }
    else:
        r_proj_data = prepared_data.r_proj_data
        theta_proj_data = prepared_data.theta_proj_data
        r_proj_model, theta_proj_model = extract_streamline.cartesian_to_polar(
            ra_model_interp,
            dec_model_interp,
        )

        dtheta = extract_streamline.wrap_to_pi(only_valid(theta_proj_data, 0.0) - theta_proj_model)

        # propagate the RA/Dec uncertainties to projected radius and polar angle
        sigma_r = jnp.sqrt(ra_sigma**2 + dec_sigma**2)
        r_eps = to_float64(1e-8)
        r_safe = jnp.maximum(jnp.abs(r_proj_data), r_eps)
        sigma_theta = jnp.sqrt(((dec_data * ra_sigma)**2 + (ra_data * dec_sigma)**2)) / (r_safe**2)
        sigma_theta = jnp.maximum(sigma_theta, r_eps)

        chi2_r = masked_chi2(only_valid(r_proj_data, 0.0) - r_proj_model, sigma_r)
        chi2_theta = masked_chi2(dtheta, sigma_theta)
        chi2_total = chi2_r + chi2_theta + chi2_v
        position_components = {
            'chi2_r': chi2_r,
            'chi2_theta': chi2_theta,
        }

    # ---- diagnostics for the trace (do not feed back into the loss) ----
    data_finite_mask = (
        jnp.isfinite(ra_data)
        & jnp.isfinite(dec_data)
        & jnp.isfinite(dmetric_data)
    )

    model_min, model_max, data_min, data_max, overlap_min, overlap_max = distance_metric_overlap(
        dmetric_model,
        model_finite_mask,
        dmetric_data,
        data_finite_mask,
    )

    model_nan_count = jnp.sum(~model_finite_mask)
    model_points_total = ra_model.size
    model_valid_points = model_points_total - model_nan_count

    data_keep = data_finite_mask & (dmetric_data >= overlap_min) & (dmetric_data <= overlap_max)
    model_keep = model_finite_mask & (dmetric_model >= overlap_min) & (dmetric_model <= overlap_max)

    sort_idx = jnp.argsort(dmetric_model)
    d_model_sorted = dmetric_model[sort_idx]
    d_diff = jnp.diff(d_model_sorted)
    # array sizes are static under jit, so this is an ordinary Python branch
    if d_diff.size > 0:
        model_metric_min_gap = jnp.min(d_diff)
        model_metric_near_tie_count = jnp.sum(jnp.abs(d_diff) <= 1e-8)
        model_metric_duplicate_count = jnp.sum(d_diff == 0.0)
        model_metric_non_monotonic_count = jnp.sum(d_diff < 0.0)
    else:
        model_metric_min_gap = to_float64(float('nan'))
        model_metric_near_tie_count = to_float64(0.0)
        model_metric_duplicate_count = to_float64(0.0)
        model_metric_non_monotonic_count = to_float64(0.0)

    model_metric_span = d_model_sorted[-1] - d_model_sorted[0] if d_model_sorted.size > 1 else to_float64(0.0)

    # compute and add the prior penalty term to the total chi2 if priors are provided
    chi2_prior = compute_prior_penalty(model_params, priors_means, priors_sigmas, priors_keys)
    chi2_total = chi2_total + chi2_prior

    chi2_components = {
        **position_components,
        'chi2_v': chi2_v,
        'chi2_prior': chi2_prior,
        'overlap_width': overlap_max - overlap_min,
        'chi2_total': chi2_total,
    }

    matching_trace = {
        'model_points_total': model_points_total,
        'model_nan_count': model_nan_count,
        'model_valid_points': model_valid_points,
        'model_retained_count': jnp.sum(model_keep),
        'data_points_total': ra_data.size,
        'data_valid_points': jnp.sum(data_finite_mask),
        'data_retained_count': jnp.sum(data_keep),
        'overlap_metric_min': overlap_min,
        'overlap_metric_max': overlap_max,
        'model_metric_span': model_metric_span,
        'model_metric_min_gap': model_metric_min_gap,
        'model_metric_near_tie_count': model_metric_near_tie_count,
        'model_metric_duplicate_count': model_metric_duplicate_count,
        'model_metric_non_monotonic_count': model_metric_non_monotonic_count,
    }

    loss_trace = {
        'chi2_components': chi2_components,
        'matching': matching_trace,
        'loss_method': loss_method,
    }
    return chi2_total, loss_trace, err
1121
+
1122
+
1123
+
1124
# Result record of evaluate_initial_guess; field order is part of the interface.
InitialGuessResult = namedtuple(
    'InitialGuessResult',
    (
        # merged parameter set
        'model_params',
        # full model curve
        'ra_model', 'dec_model', 'v_model',
        # model curve interpolated at the data positions
        'ra_model_interp', 'dec_model_interp', 'v_model_interp',
        # mask of retained data points
        'valid',
        # loss summary
        'chi2_total', 'chi2_components',
    ),
)
1136
+
1137
+
1138
def evaluate_initial_guess(
    initial_opt_params,
    fixed_params,
    data,
    uncertainties,
    distance_pc,
    n_elements=10,
    loss_method=0,
    priors=None,
):
    """
    Run the forward model and compute chi2 loss for the initial parameter guess.

    Parameters
    ----------
    initial_opt_params : dict
        Initial guesses for the parameters to optimise.
    fixed_params : dict
        Fixed (non-optimised) parameters. Together with initial_opt_params
        this must provide a full, non-overlapping partition of
        STREAMLINE_MODEL_PARAM_KEYS.
    data : tuple of arrays (ra_data, dec_data, v_data)
        Observed RA offset (arcsec), Dec offset (arcsec), velocity (km/s).
    uncertainties : tuple of arrays (ra_sigma, dec_sigma, v_sigma)
        Uncertainties on the data.
    distance_pc : float
        Distance to source in parsecs.
    n_elements : int
        Number of distance-metric partitions, i.e. the number of 1D data
        points. Must match the value used when reducing the cube.
    loss_method : int
        Loss definition to use. Options:
        - 0: radecvel — RA, Dec, and velocity residuals.
        - 1: rthetavel — radial distance, polar angle, and velocity residuals.
    priors : dict or None
        Optional Gaussian priors on optimised parameters, in the form
        {param_name: (mean, sigma), ...}. Only optimised parameters can have priors.
        Mean and sigma must be in the same canonical units as the rest of the code

    Returns
    -------
    InitialGuessResult
        Named tuple with entries:
        - model_params     : merged dict of all model parameters (float64)
        - ra_model         : full model RA offsets (arcsec)
        - dec_model        : full model Dec offsets (arcsec)
        - v_model          : full model velocities (km/s)
        - ra_model_interp  : model RA interpolated at data positions
        - dec_model_interp : model Dec interpolated at data positions
        - v_model_interp   : model velocity interpolated at data positions
        - valid            : boolean mask of retained data points
        - chi2_total       : total chi2 loss (float)
        - chi2_components  : dict of floats holding the per-component chi2 values,
                             overlap_width and chi2_total

    Raises
    ------
    Whatever err.throw() raises when the forward model reports a failed check.
    """
    loss_method = check_loss_method(loss_method)

    model_params, opt_params_clean, fixed_params_clean = prepare_model_params(initial_opt_params, fixed_params)
    validated_priors = validate_priors(priors, opt_params_clean, fixed_params_clean)
    # The prior tuples are static jit arguments of chi2_loss and therefore have to be
    # hashable: convert to plain Python floats, exactly as fit_streamline does.
    priors_keys = tuple(validated_priors.keys())
    priors_means = tuple(float(v[0]) for v in validated_priors.values())
    priors_sigmas = tuple(float(v[1]) for v in validated_priors.values())

    # full model curve for the initial guess; surface any forward-model check failure
    ra_model, dec_model, v_model, valid_mask_model, err = forward_model(
        model_params, distance_pc
    )
    err.throw()

    # model values at the data positions (only RA and Dec of the data are needed here)
    ra_model_interp, dec_model_interp, v_model_interp, valid, _, _, _ = (
        checked_match_model_to_data_curve(
            ra_model, dec_model, v_model, valid_mask_model,
            jnp.asarray(data[0], dtype=jnp.float64),
            jnp.asarray(data[1], dtype=jnp.float64),
        )
    )

    prepared_data = extract_streamline.prepare_data(data, uncertainties, n_elements=n_elements)
    chi2_total, loss_trace, _ = chi2_loss(
        model_params, distance_pc, prepared_data, loss_method=loss_method,
        priors_keys=priors_keys, priors_means=priors_means, priors_sigmas=priors_sigmas
    )
    chi2_components = loss_trace['chi2_components']

    return InitialGuessResult(
        model_params=model_params,
        ra_model=ra_model,
        dec_model=dec_model,
        v_model=v_model,
        ra_model_interp=ra_model_interp,
        dec_model_interp=dec_model_interp,
        v_model_interp=v_model_interp,
        valid=valid,
        chi2_total=float(chi2_total),
        chi2_components={k: float(v) for k, v in chi2_components.items()},
    )
1232
+
1233
+
1234
+ def fit_streamline(initial_opt_params, fixed_params, streamer, distance_pc,
1235
+ learning_rate=0.005, param_bounds=None, n_epochs=1000,
1236
+ beta1=0.9, beta2=0.999,
1237
+ info_every=100, loss_threshold=None, loss_threshold_epochs=1,
1238
+ gradient_tol=None, gradient_tol_epochs=1,
1239
+ early_stopping_patience=50,
1240
+ save_folder='sting_results',
1241
+ loss_method=1, # 0: radecvel, 1: rthetavel
1242
+ priors=None,
1243
+ v_lsr=None,
1244
+ show_plots=False
1245
+ ):
1246
+ """
1247
+ Fit streamline model parameters to data using Adam optimiser.
1248
+ Any supported streamline parameter can be optimised or fixed.
1249
+ Parameters are split by dictionary:
1250
+ - keys in initial_opt_params are optimised
1251
+ - keys in fixed_params are held fixed
1252
+ The union must contain each key in STREAMLINE_MODEL_PARAM_KEYS exactly once.
1253
+
1254
+ Parameters:
1255
+ -----------
1256
+ initial_opt_params : dict
1257
+ Initial guesses for the parameters to optimise.
1258
+ Allowed keys are STREAMLINE_MODEL_PARAM_KEYS.
1259
+ fixed_params : dict
1260
+ Fixed (non-optimised) parameters using the same key space.
1261
+ Together with initial_opt_params, this must provide a full,
1262
+ non-overlapping partition of STREAMLINE_MODEL_PARAM_KEYS.
1263
+ streamer: NamedTuple with fields:
1264
+ pc_coords, ra_data, dec_data, v_data, ra_sigma, dec_sigma, v_sigma, data, uncertainties
1265
+ data : tuple of arrays (ra_data, dec_data, v_data)
1266
+ Observed RA offset (arcsec), Dec offset (arcsec), velocity (km/s)
1267
+ uncertainties : tuple of arrays (ra_sigma, dec_sigma, v_sigma)
1268
+ Uncertainties on the data
1269
+ distance_pc : float
1270
+ Distance to source in parsecs
1271
+ learning_rate : float
1272
+ Adam learning rate applied uniformly to all normalised parameters.
1273
+ param_bounds : dict or None
1274
+ Parameter bounds in physical/log parameter units.
1275
+ Parameters that require bounds here (if optimised): r0, mass, inc, pa
1276
+ Provide param_bounds as a dictionary with values as (min, max) tuples for each parameter
1277
+ n_epochs : int
1278
+ Maximum number of optimisation iterations
1279
+ beta1 : float
1280
+ Adam exponential decay rate for first moment
1281
+ beta2 : float
1282
+ Adam exponential decay rate for second moment
1283
+ info_every : int
1284
+ Print loss every N epochs
1285
+ early_stopping_patience : int
1286
+ Stop if loss doesn't improve for N epochs
1287
+ save_folder : str
1288
+ Folder to save output CSV and trace files, and figures. Created if it doesn't exist.
1289
+ loss_method : int
1290
+ Loss definition to use. Options:
1291
+ - 0: radecvel: optimise RA, Dec, and velocity residuals.
1292
+ - 1: rthetavel: optimise projected radial distance, polar angle, and velocity residuals.
1293
+ Both options use the same model-data matching and overlap penalty.
1294
+ priors : dict or None
1295
+ Optional Gaussian priors on optimised parameters, in the form
1296
+ {param_name: (mean, sigma), ...}. Only optimised parameters can have priors.
1297
+ Mean and sigma must be in the same canonical units as the rest of the code
1298
+ v_lsr : float or None
1299
+ Systemic velocity (km/s). When provided, drawn as a reference line on the best-fit
1300
+ velocity-radius plot
1301
+ loss_threshold : float or None
1302
+ Optional absolute loss threshold for threshold-based stopping.
1303
+ If provided, optimisation stops after loss is <= loss_threshold for
1304
+ loss_threshold_epochs consecutive epochs.
1305
+ loss_threshold_epochs : int
1306
+ Number of consecutive epochs with loss <= loss_threshold required to
1307
+ trigger threshold-based early stopping. Must be >= 1.
1308
+ gradient_tol : float or None
1309
+ Optional gradient norm tolerance for stopping in normalised space.
1310
+ If provided, optimisation stops when the L2 norm of gradients with
1311
+ respect to normalised parameters
1312
+ is less than this threshold for gradient_tol_epochs consecutive epochs,
1313
+ indicating convergence.
1314
+ gradient_tol_epochs : int
1315
+ Number of consecutive epochs with ||grad|| < gradient_tol required to
1316
+ trigger normalised-space gradient norm-based early stopping. Must be >= 1.
1317
+ show_plots : bool
1318
+ Whether to show diagnostic plots during optimisation
1319
+
1320
+ Epoch 0: initial state before any updates, with initial_opt_params
1321
+ Epoch n (n>=1): state after applying parameter update n
1322
+ Tracking and checks are all performed at the end of each epoch. So e.g. loss n = loss after applying update n, using the updated parameters
1323
+
1324
+ Returns:
1325
+ --------
1326
+ FitResult namedtuple with fields:
1327
+ - best_opt_params : dict of best-fit optimised parameters (physical/log units)
1328
+ - loss_history : list of loss values at each epoch (float)
1329
+ - param_errors: dict of estimated 1-sigma uncertainties for each optimised parameter in the display parameterisation (or None if uncertainty estimation failed)
1330
+ - covariance_result: CovarianceResult or None: full covariance information needed for sampling, or None if estimation failed. Fields:
1331
+ - covariance : 2D array of covariance matrix in physical/log units
1332
+ - opt_keys: list of parameter keys corresponding to covariance_matrix rows/columns
1333
+ - best_opt_params: dict of best-fit optimised parameters (physical/log units)
1334
+ - fixed_params: dict of fixed parameters (physical/log units)
1335
+ - param_errors: dict of 1-sigma parameter uncertainties (physical/log units)
1336
+ - transformed_cov: dict of Jacobian-transformed covariance when 'rc'/'omega' was substituted by 'mu', keys are 'keys', 'cov', 'errors'
1337
+ """
1338
+ # lazy imports to avoid circular imports
1339
+ from . import outputs
1340
+ from . import errors
1341
+ # Initialize parameters
1342
+ loss_method = check_loss_method(loss_method)
1343
+
1344
+ opt_params, fixed_params = sanitize_param_partition(
1345
+ initial_opt_params,
1346
+ fixed_params,
1347
+ require_nonempty_opt=True,
1348
+ )
1349
+
1350
+
1351
+ param_bounds = standardise_param_bounds(param_bounds)
1352
+ param_bounds = convert_and_strip_bound_units(param_bounds)
1353
+
1354
+ # we perform optimisation in mu-space when either rc or omega is present. conversion is here
1355
+ # rotation_key records which of 'rc', 'omega', or 'mu' is input as rotation parameter by user,
1356
+ # so we know which one to convert back to at the end
1357
+ opt_params, fixed_params, param_bounds, rotation_key = with_mu_substituted(opt_params, fixed_params, param_bounds)
1358
+
1359
+ # add bounds for angles if they are being optimised
1360
+ param_bounds = auto_fill_angle_bounds(opt_params, param_bounds)
1361
+
1362
+ # check priors are valid and match optimised parameters
1363
+ validated_priors = validate_priors(priors, opt_params, fixed_params)
1364
+ priors_keys = tuple(validated_priors.keys())
1365
+ priors_means = tuple(float(v[0]) for v in validated_priors.values())
1366
+ priors_sigmas = tuple(float(v[1]) for v in validated_priors.values())
1367
+
1368
+
1369
+ opt_param_keys = list(opt_params.keys())
1370
+ data = make_data_tuple_float64(streamer.data)
1371
+ uncertainties = make_data_tuple_float64(streamer.uncertainties)
1372
+ distance_pc = to_float64(distance_pc)
1373
+ learning_rate = to_float64(learning_rate)
1374
+ if not bool(jnp.isfinite(learning_rate)):
1375
+ raise ValueError(f'learning_rate must be finite. Got {learning_rate}.')
1376
+ if not bool(learning_rate > 0):
1377
+ raise ValueError(f'learning_rate must be > 0. Got {float(learning_rate)}.')
1378
+ normalisation_spec = build_normalisation_spec(opt_params, param_bounds)
1379
+
1380
+ # Keep optimisation variables in normalised coordinates; convert back to
1381
+ # physical/log units only when evaluating the forward model and diagnostics.
1382
+ opt_params_norm = normalise_opt_params(opt_params, normalisation_spec)
1383
+
1384
+ # Use one global learning rate on normalised parameters
1385
+ solver = optax.adam(learning_rate=learning_rate, b1=beta1, b2=beta2)
1386
+
1387
+ opt_state = solver.init(opt_params_norm)
1388
+
1389
+ # Precompute data-only quantities once before optimisation loop
1390
+ prepared_data = extract_streamline.prepare_data(data, uncertainties, n_elements=len(data[0]))
1391
+ # npoints for forward model evaluation: fixed large number set by max r0 bound and deltar
1392
+ # this is necessary to ensure forward model has constant array lengths for jax/jit compatibility
1393
+ if 'r0' in param_bounds:
1394
+ max_r0 = param_bounds['r0'][1]
1395
+ deltar = fixed_params['deltar'] if 'deltar' in fixed_params else 1.0
1396
+ npoints = int(jnp.ceil(max_r0 / deltar))
1397
+ else:
1398
+ npoints = 50000
1399
+
1400
+ fixed_params_for_core = fixed_params
1401
+
1402
+ @jax.jit
1403
+ def loss_from_normalised(norm_opt_params):
1404
+ physical_opt_params = denormalise_opt_params(norm_opt_params, normalisation_spec)
1405
+ model_params = {**fixed_params_for_core, **physical_opt_params}
1406
+ chi2_total, loss_trace, err = chi2_loss(
1407
+ model_params,
1408
+ distance_pc,
1409
+ prepared_data,
1410
+ loss_method=loss_method,
1411
+ npoints=npoints,
1412
+ priors_keys=priors_keys,
1413
+ priors_means=priors_means,
1414
+ priors_sigmas=priors_sigmas
1415
+ )
1416
+ return chi2_total, (loss_trace, err)
1417
+
1418
+
1419
+ # Create gradient functions in normalised space.
1420
+ loss_and_grad_fn = value_and_grad(loss_from_normalised, has_aux=True)
1421
+
1422
+
1423
+ # Track loss history
1424
+ loss_history = []
1425
+ initial_loss, (_, initial_err) = loss_from_normalised(opt_params_norm)
1426
+
1427
+ # raise any initial errors
1428
+ initial_error_message = get_checkify_error_message(initial_err)
1429
+ if initial_error_message is not None:
1430
+ raise ValueError(
1431
+ f"Initial loss computation failed with error: {initial_error_message}. "
1432
+ )
1433
+
1434
+ initial_loss = float(initial_loss)
1435
+ loss_history.append(initial_loss) # 'epoch 0' loss (initial state, before any updates)
1436
+ best_loss = initial_loss
1437
+ best_opt_params = opt_params.copy()
1438
+ best_opt_params_norm = opt_params_norm.copy()
1439
+ best_epoch = 0
1440
+ patience_counter = 0
1441
+ loss_threshold_counter = 0
1442
+ gradient_tol_counter = 0
1443
+ ordered_best_opt_params = {k: best_opt_params[k] for k in opt_param_keys}
1444
+
1445
+ if loss_threshold is not None:
1446
+ loss_threshold = float(loss_threshold)
1447
+ if not math.isfinite(loss_threshold):
1448
+ raise ValueError('loss_threshold must be finite when provided.')
1449
+ if loss_threshold_epochs < 1:
1450
+ raise ValueError('loss_threshold_epochs must be >= 1 when loss_threshold is provided.')
1451
+
1452
+ if gradient_tol is not None:
1453
+ gradient_tol = float(gradient_tol)
1454
+ if not math.isfinite(gradient_tol):
1455
+ raise ValueError('gradient_tol must be finite when provided.')
1456
+ if gradient_tol <= 0:
1457
+ raise ValueError('gradient_tol must be positive when provided.')
1458
+ if gradient_tol_epochs < 1:
1459
+ raise ValueError('gradient_tol_epochs must be >= 1 when gradient_tol is provided.')
1460
+
1461
+ # initialise log and trace files if save_folder is provided
1462
+ log_file = None
1463
+ log_writer = None
1464
+ trace_file = None
1465
+ trace_writer = None
1466
+ if save_folder is not None:
1467
+ os.makedirs(save_folder, exist_ok=True)
1468
+
1469
+ log_file = os.path.join(save_folder, 'optimisation_log.csv')
1470
+ log_file = open(log_file, 'w', newline='')
1471
+ # Create header: epoch, loss, then all optimisable params
1472
+ fieldnames = ['epoch', 'loss'] + [log_header(k) for k in opt_param_keys]
1473
+
1474
+ all_param_keys = set(opt_param_keys) | set(fixed_params.keys())
1475
+ if 'mu' in all_param_keys:
1476
+ # also log derived rc and omega when mu is present for convenience
1477
+ if log_header('rc') not in fieldnames:
1478
+ fieldnames.append(log_header('rc'))
1479
+ if log_header('omega') not in fieldnames:
1480
+ fieldnames.append(log_header('omega'))
1481
+ log_writer = csv.DictWriter(log_file, fieldnames=fieldnames)
1482
+ log_writer.writeheader()
1483
+ log_file.flush()
1484
+
1485
+ trace_file = os.path.join(save_folder, 'optimisation_trace.csv')
1486
+ trace_file = open(trace_file, 'w', newline='')
1487
+ trace_writer = csv.DictWriter(
1488
+ trace_file,
1489
+ fieldnames=trace_fieldnames_for_loss_method(loss_method),
1490
+ )
1491
+ trace_writer.writeheader()
1492
+ trace_file.flush()
1493
+
1494
+ print(f"Starting optimisation with {n_epochs} epochs...")
1495
+ print(f"Loss method: {loss_method}")
1496
+ print(f"optimising parameters: {opt_param_keys}")
1497
+ print(f"Fixed parameters: {list(fixed_params.keys())}")
1498
+ if validated_priors:
1499
+ print(f"Priors:")
1500
+ for key, (mu_p, sigma_p) in validated_priors.items():
1501
+ print(f" {key}: mean={format_param(key, mu_p)}, sigma={format_param(key, sigma_p)}")
1502
+ if loss_threshold is not None:
1503
+ print(
1504
+ f"Threshold-based stopping enabled: loss <= {loss_threshold:.6g} "
1505
+ f"for {loss_threshold_epochs} consecutive epochs."
1506
+ )
1507
+ if gradient_tol is not None:
1508
+ print(
1509
+ f"Gradient norm stopping enabled (normalised space): ||grad|| < {gradient_tol:.6g} "
1510
+ f"for {gradient_tol_epochs} consecutive epochs."
1511
+ )
1512
+ print(f"Initial optimisable values:")
1513
+ for key in opt_param_keys:
1514
+ print(f" {key}: {format_param(key, opt_params[key])}")
1515
+ print(f"Initial loss: {initial_loss:.6g}")
1516
+
1517
+ # Log epoch 0: initial state (before any updates)
1518
+ initial_loss = float(initial_loss)
1519
+ if log_writer is not None:
1520
+ row = {'epoch': 0, 'loss': initial_loss}
1521
+ for key in opt_param_keys:
1522
+ row[log_header(key)] = float(opt_params[key])
1523
+ row = add_rc_omega_to_log(row, opt_params, fixed_params, all_param_keys)
1524
+ if 'rc' in row:
1525
+ row[log_header('rc')] = row.pop('rc')
1526
+ if 'omega' in row:
1527
+ row[log_header('omega')] = row.pop('omega')
1528
+ log_writer.writerow(row)
1529
+ log_file.flush()
1530
+
1531
+ # Log epoch 0 trace if trace file is requested
1532
+ if trace_writer is not None:
1533
+ # Compute initial loss and trace
1534
+ (loss_value_trace, (loss_trace_raw, _)), norm_grads_trace = loss_and_grad_fn(opt_params_norm)
1535
+ loss_trace = trace_tree_to_python(loss_trace_raw)
1536
+ grad_norm = float(gradient_l2_norm(norm_grads_trace))
1537
+
1538
+ # Build and write trace row for epoch 0
1539
+ trace_row = build_trace_row(0, float(loss_value_trace), loss_trace, grad_norm, loss_method)
1540
+ trace_writer.writerow(trace_row)
1541
+ trace_file.flush()
1542
+
1543
+
1544
+ try:
1545
+ for epoch in range(1, n_epochs + 1):
1546
+ if epoch % info_every == 0:
1547
+ print(f"\n Starting Epoch {epoch} -------------------------")
1548
+ # Compute loss and gradients at pre-update normalised parameters.
1549
+ loss_trace = None
1550
+ (loss_before, _), norm_grads = loss_and_grad_fn(opt_params_norm)
1551
+
1552
+
1553
+ # Perform Optax Adam step in normalised space (apply update).
1554
+ updates, opt_state = solver.update(norm_grads, opt_state, params=opt_params_norm)
1555
+ opt_params_norm = optax.apply_updates(opt_params_norm, updates)
1556
+
1557
+ # Enforce normalised bounds and map back to physical/log values.
1558
+ for key in opt_param_keys:
1559
+ if key == 'phi0':
1560
+ # phi0 is cyclic; wrap to [0, 1) in normalised space
1561
+ opt_params_norm[key] = jnp.mod(opt_params_norm[key], 1.0)
1562
+ elif key == 'rc':
1563
+ # must be positive >0
1564
+ opt_params_norm[key] = jnp.clip(opt_params_norm[key], to_float64(1e-6), 1.0)
1565
+ elif key == 'omega':
1566
+ # must be positive >0
1567
+ opt_params_norm[key] = jnp.clip(opt_params_norm[key], to_float64(1e-6), 1.0)
1568
+ elif key == 'v_r0':
1569
+ #already dealt with
1570
+ continue
1571
+ else:
1572
+ opt_params_norm[key] = jnp.clip(opt_params_norm[key], 0.0, 1.0)
1573
+
1574
+ # Now materialize physical parameters from the (possibly clamped)
1575
+ # normalised parameters.
1576
+ opt_params = denormalise_opt_params(opt_params_norm, normalisation_spec)
1577
+
1578
+ # Compute loss and gradient at the post-update state S(epoch) for logging
1579
+ (loss_value, (loss_trace_raw, err)), norm_grads = loss_and_grad_fn(opt_params_norm)
1580
+ # L2 norm of the post-update gradient (used for trace logging and the gradient-tolerance check)
1581
+ grad_norm = float(gradient_l2_norm(norm_grads))
1582
+
1583
+ # stop the optimisation loop if checkify reported an error
1584
+ error_message = get_checkify_error_message(err)
1585
+ if error_message is not None:
1586
+ print(
1587
+ f"\nStopping at epoch {epoch}: {error_message} "
1588
+ )
1589
+ break
1590
+
1591
+
1592
+ loss_trace = trace_tree_to_python(loss_trace_raw)
1593
+ loss_value = float(loss_value)
1594
+
1595
+
1596
+ # Log post-update state for this epoch
1597
+ if log_writer is not None:
1598
+ row = {'epoch': epoch, 'loss': loss_value}
1599
+ for key in opt_param_keys:
1600
+ row[log_header(key)] = float(opt_params[key])
1601
+
1602
+ row = add_rc_omega_to_log(row, opt_params, fixed_params, all_param_keys)
1603
+
1604
+ if 'rc' in row:
1605
+ row[log_header('rc')] = row.pop('rc')
1606
+ if 'omega' in row:
1607
+ row[log_header('omega')] = row.pop('omega')
1608
+ log_writer.writerow(row)
1609
+ log_file.flush()
1610
+
1611
+ # Track loss (store the loss for the loss_history)
1612
+ loss_history.append(loss_value)
1613
+
1614
+ if trace_writer is not None and loss_trace is not None:
1615
+ # Use the post-update loss value for trace logging
1616
+ trace_row = build_trace_row(epoch, loss_value, loss_trace, grad_norm, loss_method)
1617
+ trace_writer.writerow(trace_row)
1618
+ trace_file.flush()
1619
+
1620
+ # Early stopping checks (use post-update loss)
1621
+ if loss_value < best_loss:
1622
+ best_loss = loss_value
1623
+ best_opt_params = opt_params.copy()
1624
+ best_opt_params_norm = opt_params_norm.copy()
1625
+ best_epoch = epoch
1626
+ patience_counter = 0
1627
+ else:
1628
+ patience_counter += 1
1629
+
1630
+ if loss_threshold is not None:
1631
+ if loss_value <= loss_threshold:
1632
+ loss_threshold_counter += 1
1633
+ else:
1634
+ loss_threshold_counter = 0
1635
+
1636
+ if gradient_tol is not None:
1637
+ if grad_norm < gradient_tol:
1638
+ gradient_tol_counter += 1
1639
+ else:
1640
+ gradient_tol_counter = 0
1641
+
1642
+ # Print progress
1643
+ if epoch % info_every == 0:
1644
+ if gradient_tol is not None:
1645
+ print(f'Epoch {epoch}/{n_epochs}, Loss: {loss_value:.6f}, Best Loss: {best_loss:.6f}, ||grad||: {grad_norm:.6e}')
1646
+ else:
1647
+ print(f'Epoch {epoch}/{n_epochs}, Loss: {loss_value:.6f}, Best Loss: {best_loss:.6f}')
1648
+
1649
+ # Early stopping conditions (any one is sufficient to stop)
1650
+ if loss_threshold is not None and loss_threshold_counter >= loss_threshold_epochs:
1651
+ print(
1652
+ f"\nEarly stopping at epoch {epoch}: loss <= {loss_threshold:.6g} "
1653
+ f"for {loss_threshold_epochs} consecutive epochs"
1654
+ )
1655
+ break
1656
+
1657
+ if gradient_tol is not None and gradient_tol_counter >= gradient_tol_epochs:
1658
+ print(
1659
+ f"\nEarly stopping at epoch {epoch}: normalised gradient norm {grad_norm:.6e} < {gradient_tol:.6e} "
1660
+ f"for {gradient_tol_epochs} consecutive epochs"
1661
+ )
1662
+ break
1663
+
1664
+ if patience_counter >= early_stopping_patience:
1665
+ print(f"\nEarly stopping at epoch {epoch}: no improvement for {early_stopping_patience} epochs")
1666
+ break
1667
+
1668
+
1669
+ # restore canonical parameter order before returning
1670
+ ordered_best_opt_params = {k: best_opt_params[k] for k in opt_param_keys}
1671
+
1672
+ finally:
1673
+ # Always close the CSV file if it was opened
1674
+ if log_file is not None:
1675
+ log_path = log_file.name
1676
+ log_file.close()
1677
+ print(f"Optimisation log saved to: {log_path}")
1678
+ if trace_file is not None:
1679
+ trace_path = trace_file.name
1680
+ trace_file.close()
1681
+ print(f"Matching trace log saved to: {trace_path}")
1682
+
1683
+ print(f"Optimisation complete!")
1684
+ print(f"Best-fit parameters found at epoch: {best_epoch}, with loss: {best_loss:.6f}")
1685
+
1686
+ # compute errors on best-fit parameters
1687
+ print("\nEstimating parameter uncertainties from Hessian...")
1688
+ param_errors = None
1689
+ cov_matrix = None
1690
+ cov_transformed_dict = None
1691
+ # only pass the rotation key if it actually needs transforming back from mu
1692
+ key_needs_transform = rotation_key if rotation_key in ('rc', 'omega') else None
1693
+ try:
1694
+ param_errors, cov_matrix, cov_transformed_dict = errors.estimate_parameter_errors(
1695
+ ordered_best_opt_params,
1696
+ fixed_params,
1697
+ distance_pc,
1698
+ prepared_data,
1699
+ loss_method=loss_method,
1700
+ gradient_tol=gradient_tol,
1701
+ normalisation_spec=normalisation_spec,
1702
+ best_norm_opt_params=best_opt_params_norm,
1703
+ rotation_key=key_needs_transform,
1704
+ npoints=npoints,
1705
+ priors_keys=priors_keys,
1706
+ priors_means=priors_means,
1707
+ priors_sigmas=priors_sigmas
1708
+ )
1709
+ except Exception as e:
1710
+ print(f"\nWarning: parameter uncertainty estimation failed: ({e}).")
1711
+ traceback.print_exc()
1712
+ print("Continuing without error estimates")
1713
+
1714
+ display_opt_params = dict(ordered_best_opt_params)
1715
+ display_fixed_params = dict(fixed_params)
1716
+ display_param_errors = dict(param_errors) if param_errors is not None else None
1717
+
1718
+ if cov_transformed_dict is not None and key_needs_transform is not None and display_param_errors is not None:
1719
+ if key_needs_transform in cov_transformed_dict['keys']:
1720
+ all_params_for_transform = {**display_fixed_params, **display_opt_params}
1721
+ mu_best = float(ordered_best_opt_params['mu'])
1722
+ mass_val = float(all_params_for_transform['mass'])
1723
+ r0_val = float(all_params_for_transform['r0'])
1724
+ display_opt_params[key_needs_transform] = rotation_param_from_mu(key_needs_transform, mu_best, mass_val, r0_val)
1725
+ display_opt_params.pop('mu', None)
1726
+ display_param_errors[key_needs_transform] = cov_transformed_dict['errors'][key_needs_transform]
1727
+ display_param_errors.pop('mu', None)
1728
+
1729
+ print("\nFinal parameters at best-fit:")
1730
+ all_display_params = {**display_fixed_params, **display_opt_params}
1731
+ for key in all_display_params.keys():
1732
+ value = all_display_params[key]
1733
+ if display_param_errors is not None and key in display_param_errors:
1734
+ error = display_param_errors[key]
1735
+ print(f" {key}: {format_param(key, value)} ± {format_param(key, error)}")
1736
+ else:
1737
+ print(f" {key}: {format_param(key, value)}")
1738
+
1739
+ outputs.save_best_fit_params(display_opt_params, display_fixed_params, display_param_errors, save_folder=save_folder)
1740
+
1741
+ # now we will make some plots of the results
1742
+ if save_folder is not None:
1743
+ print("\nMaking diagnostic plots...")
1744
+ outputs.plot_fitting_results(
1745
+ ordered_best_opt_params,
1746
+ opt_param_keys,
1747
+ fixed_params,
1748
+ streamer,
1749
+ distance_pc,
1750
+ loss_history,
1751
+ param_errors=param_errors,
1752
+ cov_matrix=cov_matrix,
1753
+ v_lsr=v_lsr,
1754
+ save_folder=save_folder,
1755
+ show_plots=show_plots,
1756
+ transformed_cov_result=cov_transformed_dict,
1757
+ )
1758
+
1759
+ # save results to CovarianceResult and FitResult namedtuples
1760
+ cov_result = None
1761
+ if cov_matrix is not None:
1762
+ cov_result = CovarianceResult(
1763
+ covariance=cov_matrix,
1764
+ opt_keys=list(ordered_best_opt_params.keys()),
1765
+ best_opt_params=ordered_best_opt_params,
1766
+ fixed_params=fixed_params,
1767
+ param_errors=param_errors,
1768
+ transformed_cov=cov_transformed_dict
1769
+ )
1770
+
1771
+ return FitResult(
1772
+ best_opt_params=display_opt_params,
1773
+ loss_history=loss_history,
1774
+ param_errors=display_param_errors,
1775
+ covariance_result=cov_result,
1776
+ )