sting 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sting/__init__.py +8 -0
- sting/_version.py +24 -0
- sting/errors.py +677 -0
- sting/extract_streamline.py +425 -0
- sting/gradient_descent.py +1776 -0
- sting/outputs.py +1705 -0
- sting/stream_lines_grad.py +448 -0
- sting-0.2.0.dist-info/METADATA +251 -0
- sting-0.2.0.dist-info/RECORD +14 -0
- sting-0.2.0.dist-info/WHEEL +5 -0
- sting-0.2.0.dist-info/licenses/LICENCE +21 -0
- sting-0.2.0.dist-info/scm_file_list.json +26 -0
- sting-0.2.0.dist-info/scm_version.json +8 -0
- sting-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1776 @@
|
|
|
1
|
+
'''
|
|
2
|
+
This file contains the loss function and optimisation routines for streamfit.
|
|
3
|
+
|
|
4
|
+
The optimisation uses the Adam (adaptive moment estimation) optimiser to fit
|
|
5
|
+
streamline model parameters to observed data by minimizing chi-squared loss.
|
|
6
|
+
|
|
7
|
+
Last updated: 19-06-2026
|
|
8
|
+
'''
|
|
9
|
+
|
|
10
|
+
import os
|
|
11
|
+
from collections import namedtuple
|
|
12
|
+
|
|
13
|
+
import jax.numpy as jnp
|
|
14
|
+
from jax import value_and_grad, lax
|
|
15
|
+
import jax
|
|
16
|
+
from jax.experimental import checkify
|
|
17
|
+
from jax.random import key
|
|
18
|
+
import optax
|
|
19
|
+
from . import stream_lines_grad
|
|
20
|
+
from . import extract_streamline
|
|
21
|
+
import csv
|
|
22
|
+
import astropy.units as u
|
|
23
|
+
import math
|
|
24
|
+
import traceback
|
|
25
|
+
# Enable 64-bit (float64) arithmetic in JAX for the whole process.
jax.config.update("jax_enable_x64", True)

# settings and constants
# NOTE(review): VR0_MIN, BIG and BIG_NEG are not referenced in the visible part of
# this module. VR0_MIN presumably is a lower floor for v_r0 (km/s) and BIG/BIG_NEG
# large sentinel values -- confirm against their users.
VR0_MIN = 1e-6
BIG = 1e30
BIG_NEG = -1e30

# Allowed values of ``loss_method``: 0 = radecvel, 1 = rthetavel.
LOSS_METHOD_CHOICES = [0, 1]

# chi-squared component names reported for each loss method (trace csv columns).
LOSS_METHOD_COMPONENT_KEYS = {
    0: ('chi2_ra', 'chi2_dec', 'chi2_v', 'chi2_prior'),  # radecvel
    1: ('chi2_r', 'chi2_theta', 'chi2_v', 'chi2_prior'),  # rthetavel
}

# Trace csv columns shared by every loss method. The method-specific chi2
# components are inserted after the first two entries ('epoch', 'loss').
TRACE_COMMON_FIELDNAMES = [
    'epoch',
    'loss',
    'chi2_total',
    'grad_norm',
    'model_points_total',
    'model_nan_count',
    'model_valid_points',
    'model_metric_span',
    'model_inner_count',
    'data_inner_count',
    'data_points_total',
    'data_valid_points',
    'data_retained_count',
    'model_retained_count',
    'overlap_metric_min',
    'overlap_metric_max',
]

# Canonical units: astropy Quantities are converted to these units and then
# stripped, so plain numbers inside the optimiser are interpreted in these units.
CANONICAL_UNITS = {
    "r0": u.au,
    "theta0": u.rad,
    "phi0": u.rad,
    "inc": u.rad,
    "pa": u.rad,
    "v_r0": u.km / u.s,
    "mass": u.Msun,
    "rmin": u.au,
    "deltar": u.au,
    "v_lsr": u.km / u.s,
    "rc": u.au,
    "omega": 1 / u.s,
    # mu = rc/r0 is dimensionless, so no units
}

# Angular parameters: stored in radians internally, displayed in degrees.
ANGLE_KEYS = {'theta0', 'phi0', 'inc', 'pa'}

# Physical ranges (radians) of the angular parameters. When one of these angles
# is optimised, this range is used as its bound, overriding user-supplied bounds.
ANGLE_BOUNDS_RAD = {
    'theta0': (0.0, jnp.pi),
    'phi0': (0.0, 2 * jnp.pi),
    'inc': (-jnp.pi/2, jnp.pi/2),
    'pa': (0.0, 2 * jnp.pi),
}

# Unit labels used when formatting parameter values for display.
DISPLAY_UNITS = {
    'r0': 'au',
    'v_r0': 'km/s',
    'mass': 'M_sun',
    'rmin': 'au',
    'deltar': 'au',
    'v_lsr': 'km/s',
    'rc': 'au',
    'omega': '1/s',
    # mu is dimensionless
}

# Every parameter name accepted by the streamline model.
STREAMLINE_MODEL_PARAM_KEYS = (
    'r0',
    'theta0',
    'phi0',
    'rc',
    'omega',
    'mu',
    'v_r0',
    'mass',
    'inc',
    'pa',
    'rmin',
    'deltar',
    'v_lsr',
)
|
|
111
|
+
|
|
112
|
+
# ---- Return types ----

# Covariance matrix and 1-sigma errors of the fitted model.
# Fields:
#   covariance      : 2-D array, physical-space covariance matrix from the Hessian,
#                     ordered consistently with opt_keys / best_params / fixed_params.
#   opt_keys        : list[str], parameter names for rows/cols of covariance
#                     (mu-substituted, i.e. 'mu' in place of 'rc'/'omega').
#   best_opt_params : dict, best-fit values in the mu-substituted parameterisation.
#   fixed_params    : dict, fixed parameters in the mu-substituted parameterisation.
#   param_errors    : dict of 1-sigma errors keyed by opt_keys names, or None.
#   transformed_cov : dict or None, Jacobian-transformed result when 'mu' was
#                     substituted for 'rc'/'omega'; keys are 'keys', 'cov', 'errors'.
CovarianceResult = namedtuple(
    'CovarianceResult',
    (
        'covariance',
        'opt_keys',
        'best_opt_params',
        'fixed_params',
        'param_errors',
        'transformed_cov',
    ),
)

# Model fit result, including best-fit parameters.
# Fields:
#   best_opt_params   : dict, best-fit optimised parameters in the original
#                       user-supplied parameterisation (rc/omega restored).
#   loss_history      : list[float], loss value at every epoch.
#   param_errors      : dict or None, 1-sigma errors in display parameterisation,
#                       or None if uncertainty estimation failed.
#   covariance_result : CovarianceResult or None, full covariance information
#                       needed for sampling, or None if estimation failed.
FitResult = namedtuple(
    'FitResult',
    (
        'best_opt_params',
        'loss_history',
        'param_errors',
        'covariance_result',
    ),
)
|
|
137
|
+
|
|
138
|
+
def convert_and_strip_bound_units(bounds):
    """
    Convert (min, max) bounds given as astropy Quantities into canonical units,
    then strip the units.

    If a bound is already plain numeric, it is assumed to already be in
    canonical units and its entries are simply cast to float.

    Required as the JAX optimiser works with unitless arrays.

    Parameters
    ----------
    bounds : dict or None
        Mapping of parameter name -> bounds, each either a Quantity indexable
        as [min, max] or an iterable of plain numbers.

    Returns
    -------
    dict
        Mapping of parameter name -> tuple of floats ({} if ``bounds`` is None).

    Raises
    ------
    ValueError
        If a Quantity is given for a parameter with no entry in CANONICAL_UNITS.
    """
    if bounds is None:
        return {}

    output = {}

    for key, val in bounds.items():
        if isinstance(val, u.Quantity):
            if key not in CANONICAL_UNITS:
                raise ValueError(f"The parameter {key} doesn't have defined canonical units...")
            # Separate name so the ``bounds`` dict being iterated is not shadowed.
            converted = val.to(CANONICAL_UNITS[key])
            output[key] = (
                float(converted[0].value),
                float(converted[1].value),
            )
        else:
            # already unitless
            output[key] = tuple(float(v) for v in val)
    return output
|
|
165
|
+
|
|
166
|
+
def check_loss_method(loss_method):
    """Validate ``loss_method`` against LOSS_METHOD_CHOICES.

    Returns the value unchanged when it is allowed; otherwise raises ValueError.
    """
    if loss_method in LOSS_METHOD_CHOICES:
        return loss_method
    raise ValueError(
        f"Unknown loss_method '{loss_method}'. "
        "Choose from: 0: radecvel 1: rthetavel"
    )
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def trace_fieldnames_for_loss_method(loss_method):
    """Return the ordered trace-csv column names for ``loss_method``.

    Order: 'epoch', 'loss', the method's chi-squared components, then the
    remaining shared columns of TRACE_COMMON_FIELDNAMES.
    """
    method = check_loss_method(loss_method)
    # TRACE_COMMON_FIELDNAMES begins with 'epoch', 'loss'; those are placed first explicitly.
    shared_tail = TRACE_COMMON_FIELDNAMES[2:]
    return ['epoch', 'loss'] + list(LOSS_METHOD_COMPONENT_KEYS[method]) + list(shared_tail)
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def is_numeric_value(value):
    """Return True if ``value`` converts to a JAX array with a numeric dtype.

    Values JAX cannot convert (e.g. strings) and boolean values/arrays give False.
    """
    try:
        dtype = jnp.asarray(value).dtype
    except Exception:
        return False
    is_bool = dtype == jnp.bool_
    return (not is_bool) and bool(jnp.issubdtype(dtype, jnp.number))
|
|
191
|
+
|
|
192
|
+
@jax.jit
def to_float64(value):
    """Return ``value`` (scalar or array-like) as a float64 JAX array."""
    target_dtype = jnp.float64
    converted = jnp.asarray(value, dtype=target_dtype)
    return converted
|
|
196
|
+
|
|
197
|
+
@jax.jit
def softplus(x):
    """Softplus, log(1 + exp(x)); used so that v_r0 = softplus(raw) >= 0."""
    return jax.nn.softplus(x)
|
|
201
|
+
|
|
202
|
+
@jax.jit
def inv_softplus(y):
    """Inverse of softplus, used for v_r0: returns x such that softplus(x) == y.

    Uses x = y + log(1 - exp(-y)). The factor 1 - exp(-y) is evaluated as
    -expm1(-y) so it stays accurate for small y, where 1 - exp(-y) would
    otherwise suffer catastrophic cancellation. Meant for y > 0; y == 0 gives
    -inf and y < 0 gives NaN.
    """
    y = to_float64(y)
    return y + jnp.log(-jnp.expm1(-y))
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def get_checkify_error_message(err):
    """Return a human-readable message for a checkify error object, or None.

    If ``err`` has a ``get`` attribute, the result of ``err.get()`` is returned.
    Otherwise ``err.throw()`` is called: the text of any exception it raises is
    returned, and None is returned if it raises nothing.
    """
    if not hasattr(err, 'get'):
        try:
            err.throw()
        except Exception as exc:
            return str(exc)
        return None
    return err.get()
|
|
219
|
+
|
|
220
|
+
def make_data_tuple_float64(values):
    """Return a tuple holding each element of ``values`` converted by to_float64."""
    return tuple(map(to_float64, values))
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def clean_model_param_dict(params, dict_name):
    """Validate a model-parameter dictionary and convert its values to float64.

    Parameters
    ----------
    params : dict or None
        Mapping of parameter name -> value. None is treated as an empty dict.
        Values may be astropy Quantities (converted to the unit given in
        CANONICAL_UNITS, then stripped), plain numbers/arrays (assumed to be in
        canonical units already), or None (kept as None).
    dict_name : str
        Name of the dictionary, used in error messages.

    Returns
    -------
    dict
        New dict mapping each key to a float64 JAX array (or None). A theta0
        equal to 0 or pi (within jnp.isclose tolerance) is nudged by 1e-8 into
        the open interval (0, pi).

    Raises
    ------
    TypeError
        If ``params`` is not a dict, or a value is a bool or non-numeric.
    ValueError
        If a Quantity is given for a parameter with no entry in CANONICAL_UNITS.
    KeyError
        If a key is not in STREAMLINE_MODEL_PARAM_KEYS.
    """
    if params is None:
        params = {}
    if not isinstance(params, dict):
        raise TypeError(f"{dict_name} must be a dictionary, got {type(params).__name__}.")

    sanitized = {}

    for key, val in params.items():
        if val is None:
            # Kept as None; whether None is acceptable (only 'rmin' in the fixed
            # params) is decided later by check_param_types.
            sanitized[key] = None
            continue
        if isinstance(val, u.Quantity):
            if key not in CANONICAL_UNITS:
                raise ValueError(f"The parameter {key} doesn't have defined canonical units...")
            val = val.to(CANONICAL_UNITS[key]).value
        # if it's already a raw number, assume it's already in canonical units.
        # Reject bools and non-numeric values *before* the float64 cast, which
        # would otherwise silently turn True/False into 1.0/0.0.
        if isinstance(val, bool) or not is_numeric_value(val):
            raise TypeError(
                f"Parameter '{key}' in {dict_name} must be numeric, "
                f"got value of type {type(val).__name__}."
            )
        sanitized[key] = jnp.asarray(val, dtype=jnp.float64)

    # Protect against polar-angle edge values (theta0 = 0 or pi), which can
    # cause downstream numerical issues: if the user supplied 0 or pi, nudge
    # it by a tiny amount into the open interval (0, pi).
    theta_val = sanitized.get('theta0')
    if theta_val is not None:
        tiny = to_float64(1e-8)
        if bool(jnp.all(jnp.isclose(theta_val, to_float64(0.0)))):
            sanitized['theta0'] = theta_val + tiny
        elif bool(jnp.all(jnp.isclose(theta_val, to_float64(jnp.pi)))):
            sanitized['theta0'] = theta_val - tiny

    unknown = sorted(key for key in sanitized if key not in STREAMLINE_MODEL_PARAM_KEYS)
    if unknown:
        raise KeyError(
            f"Unknown parameter keys in {dict_name}: {unknown} "
            f"Supported keys are: {list(STREAMLINE_MODEL_PARAM_KEYS)}"
        )

    return sanitized
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def check_param_types(opt_params, fixed_params):
    """Raise if any model parameter value has the wrong type.

    Optimised parameters must all be numeric (bools are rejected), and 'rmin'
    may not be None there (ValueError). Fixed parameters must be numeric too,
    except that a fixed 'rmin' may be None. Non-numeric values raise TypeError.
    """
    def _not_numeric(candidate):
        return isinstance(candidate, bool) or not is_numeric_value(candidate)

    for name, candidate in opt_params.items():
        if name == 'rmin' and candidate is None:
            raise ValueError("'rmin' cannot be None")
        if _not_numeric(candidate):
            raise TypeError(
                f"Optimisable parameter '{name}' must be numeric, "
                f"got value of type {type(candidate).__name__}."
            )

    for name, candidate in fixed_params.items():
        rmin_unset = name == 'rmin' and candidate is None
        if not rmin_unset and _not_numeric(candidate):
            raise TypeError(
                f"Fixed parameter '{name}' must be numeric"
                " (or None only for 'rmin'), "
                f"got value of type {type(candidate).__name__}."
            )
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def sanitize_param_partition(initial_opt_params, fixed_params, require_nonempty_opt=False):
    """Sanitize and validate the optimised/fixed parameter partition for streamline modelling.

    Both dictionaries are cleaned with clean_model_param_dict, then checked:

    - no parameter may appear in both dictionaries;
    - exactly one rotation parameter, 'rc', 'omega' or 'mu', must be present
      across the two dictionaries (it determines mu = rc/r0);
    - every other parameter in STREAMLINE_MODEL_PARAM_KEYS must be present in
      one of them;
    - if ``require_nonempty_opt`` is True, at least one parameter must be optimised;
    - all values must pass check_param_types.

    Returns
    -------
    tuple of dict
        (opt_params, fixed_params), the cleaned dictionaries.

    Raises
    ------
    KeyError
        Overlapping keys, not exactly one rotation parameter, or missing
        parameters (clean_model_param_dict may also raise KeyError).
    ValueError
        ``require_nonempty_opt`` set with no optimised parameters, or raised
        by clean_model_param_dict / check_param_types.
    TypeError
        Raised by clean_model_param_dict / check_param_types for bad values.
    """
    opt_params = clean_model_param_dict(initial_opt_params, 'initial_opt_params')
    fixed_params = clean_model_param_dict(fixed_params, 'fixed_params')

    overlap = sorted(set(opt_params) & set(fixed_params))
    if overlap:
        raise KeyError(
            f"Parameters cannot be present in both initial_opt_params and fixed_params! Overlap: {overlap}"
        )

    all_params = set(opt_params) | set(fixed_params)

    # check that exactly one of rc, omega, or mu (the rotation keys) is supplied
    rotation_keys = ('rc', 'omega', 'mu')
    rotation_keys_present = [key for key in rotation_keys if key in all_params]
    if len(rotation_keys_present) != 1:
        raise KeyError(
            f"Exactly one of 'rc', 'omega', or 'mu' must be provided. You have provided: {rotation_keys_present}"
        )

    # check all other required parameters are present (rotation keys were handled above)
    missing = [
        key for key in STREAMLINE_MODEL_PARAM_KEYS
        if key not in all_params and key not in rotation_keys
    ]
    if missing:
        raise KeyError(
            "Missing required streamline parameters across initial_opt_params and fixed_params: "
            f"{missing}."
        )

    if require_nonempty_opt and len(opt_params) == 0:
        raise ValueError(
            "initial_opt_params must contain at least one optimisable parameter. "
        )

    check_param_types(opt_params, fixed_params)

    return opt_params, fixed_params
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
def prepare_model_params(opt_params, fixed_params):
    """Validate the parameter partition and merge it into a single model dict.

    Returns (model_params, opt_params, fixed_params). model_params holds every
    parameter, with optimised values layered over fixed ones; the other two
    are the cleaned dictionaries from sanitize_param_partition.
    """
    cleaned_opt, cleaned_fixed = sanitize_param_partition(opt_params, fixed_params)
    merged = {**cleaned_fixed, **cleaned_opt}
    return merged, cleaned_opt, cleaned_fixed
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def standardise_param_bounds(param_bounds):
    """Return a shallow copy of ``param_bounds`` after checking its keys.

    None is passed through. Raises KeyError if any key is not a known
    streamline parameter (STREAMLINE_MODEL_PARAM_KEYS).
    """
    if param_bounds is None:
        return None

    copied = dict(param_bounds)
    unknown = sorted(name for name in copied if name not in STREAMLINE_MODEL_PARAM_KEYS)
    if not unknown:
        return copied

    raise KeyError(
        f"Unknown params in param_bounds: {unknown}. "
        f"Supported params are: {list(STREAMLINE_MODEL_PARAM_KEYS)}"
    )
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
def build_normalisation_spec(opt_params, param_bounds):
    """Build the shift and scale used to normalise optimised parameters, from bounds.

    For each optimised parameter the spec is {'offset': min, 'scale': max - min},
    so that (x - offset) / scale maps [min, max] onto [0, 1]. v_r0 is skipped
    because it uses a softplus transform instead; any bounds given for it are
    ignored with a printed notice.

    Parameters
    ----------
    opt_params : dict
        Optimised parameters and their initial values.
    param_bounds : dict
        Mapping of parameter name -> (min, max), as a tuple or list.

    Returns
    -------
    dict
        {param_name: {'offset': float64 array, 'scale': float64 array}}.

    Raises
    ------
    ValueError
        If param_bounds is None, bounds are missing for an optimised parameter,
        bounds are not a 2-element tuple/list, are not finite, do not satisfy
        min < max, or the initial value lies outside them.
    """
    if param_bounds is None:
        raise ValueError(
            "param_bounds is required because optimisation is performed in normalised space. "
            "Provide bounds for 'mass', 'r0' if you are optimising these."
        )

    # v_r0 never needs bounds (softplus transform), so it is excluded here.
    missing = [key for key in opt_params if key != 'v_r0' and key not in param_bounds]
    if missing:
        raise ValueError(
            "Missing bounds for optimised parameters: "
            f"{missing}. Please add (min, max) entries for all parameters you want to optimise."
        )

    normalisation_spec = {}
    for key, value in opt_params.items():
        if key == 'v_r0':
            if key in param_bounds:
                print("Notice: Ignoring user-supplied bounds for 'v_r0' since we use a softplus transform for this parameter instead of normalisation.")
            continue

        bounds = param_bounds[key]
        if not isinstance(bounds, (tuple, list)) or len(bounds) != 2:
            raise ValueError(
                f"Bounds for '{key}' must be a 2-element (min, max) tuple. "
                f"Got: {bounds!r}"
            )

        lower_bound = to_float64(bounds[0])
        upper_bound = to_float64(bounds[1])
        if not bool(jnp.isfinite(lower_bound)) or not bool(jnp.isfinite(upper_bound)):
            raise ValueError(
                f"Bounds for '{key}' must be finite. "
                f"Got ({float(lower_bound)}, {float(upper_bound)})"
            )
        if not bool(upper_bound > lower_bound):
            raise ValueError(
                f"Bounds for '{key}' must satisfy min < max. Got ({float(lower_bound)}, {float(upper_bound)})"
            )

        value = to_float64(value)
        if not bool((value >= lower_bound) & (value <= upper_bound)):
            raise ValueError(
                f"Initial value for '{key}' ({float(value)}) is outside bounds "
                f"({float(lower_bound)}, {float(upper_bound)})."
            )

        normalisation_spec[key] = {
            'offset': lower_bound,
            'scale': upper_bound - lower_bound,
        }

    return normalisation_spec
|
|
417
|
+
|
|
418
|
+
|
|
419
|
+
def normalise_opt_params(opt_params, normalisation_spec):
    """Map optimised parameters into the space the optimiser works in.

    Every parameter except v_r0 is mapped linearly with its normalisation_spec
    entry, x_norm = (x - offset) / scale, so values within the bounds land in
    [0, 1]. v_r0 is instead stored as ``raw`` with v_r0 = softplus(raw).

    Parameters
    ----------
    opt_params : dict
        Optimised parameters in physical (canonical) units.
    normalisation_spec : dict
        Output of build_normalisation_spec.

    Returns
    -------
    dict
        Normalised parameter values (float64 JAX arrays).

    Raises
    ------
    ValueError
        If the initial v_r0 is not finite or is negative.
    KeyError
        If a non-v_r0 parameter has no entry in normalisation_spec.
    """
    normalised = {}
    for key, value in opt_params.items():
        if key == 'v_r0':
            # save the value 'raw' such that v_r0 = softplus(raw)
            value = to_float64(value)
            if not bool(jnp.isfinite(value)):
                raise ValueError(f"v_r0 must be finite, got {float(value)}")
            if bool(value < 0):
                raise ValueError(f"v_r0 must be non-negative, got {float(value)}")
            # inv_softplus(0) is -inf, where softplus has zero gradient, so the
            # optimiser would get no signal to move v_r0 away from 0. Clamp to
            # VR0_MIN so the raw starting value is finite.
            normalised[key] = inv_softplus(jnp.maximum(value, VR0_MIN))
            continue
        offset = normalisation_spec[key]['offset']
        scale = normalisation_spec[key]['scale']
        normalised[key] = (to_float64(value) - offset) / scale
    return normalised
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
def denormalise_opt_params(norm_opt_params, normalisation_spec):
    """Invert normalise_opt_params: map optimiser-space values back to physical values.

    - v_r0: softplus(raw).
    - phi0, pa: value * scale + offset, wrapped into [0, 2*pi) since they are periodic.
    - all other parameters: value * scale + offset.
    """
    physical = {}
    for name, raw in norm_opt_params.items():
        if name == 'v_r0':
            physical[name] = softplus(to_float64(raw))
            continue
        offset = normalisation_spec[name]['offset']
        scale = normalisation_spec[name]['scale']
        linear = to_float64(raw) * scale + offset
        if name in ('phi0', 'pa'):
            physical[name] = jnp.mod(linear, 2 * jnp.pi)
        else:
            physical[name] = linear
    return physical
|
|
451
|
+
|
|
452
|
+
def log_header(param_key):
    """Return the csv column header for ``param_key``.

    The header is 'name [unit]' using the parameter's canonical unit. 'mu', and
    any key without a canonical unit, gets the bare name.
    """
    if param_key != 'mu':
        unit = CANONICAL_UNITS.get(param_key)
        if unit is not None:
            return f'{param_key} [{unit}]'
    return param_key
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
def get_rotation_param_key(opt_params, fixed_params):
    """Return the rotation parameter name ('rc', 'omega' or 'mu') used in either dict.

    If several are present, the first in the order rc, omega, mu wins. Raises
    KeyError if none is present.
    """
    supplied = set(opt_params).union(fixed_params)
    found = [name for name in ('rc', 'omega', 'mu') if name in supplied]
    if not found:
        raise KeyError("None of the parameters 'rc', 'omega', or 'mu' are present in the parameters!")
    return found[0]
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
def mu_from_rotation_param(rotation_key, value, mass, r0):
    """Convert whichever rotation parameter is present into mu = rc / r0.

    Parameters
    ----------
    rotation_key : str
        Which rotation parameter ``value`` is: 'mu', 'rc' or 'omega'.
    value : float or array
        Value of that rotation parameter (canonical units).
    mass, r0 : float or array
        Mass and r0 in canonical units; ``mass`` is only used for 'omega'.

    Returns
    -------
    mu : float or array

    Raises
    ------
    ValueError
        If ``rotation_key`` is not 'mu', 'rc' or 'omega'.
    """
    if rotation_key == 'mu':
        return value
    if rotation_key == 'rc':
        return value / r0
    if rotation_key == 'omega':
        return stream_lines_grad.mu_from_omega(omega=value, mass=mass, r0=r0)
    # Fail loudly instead of silently returning None for an unsupported key.
    raise ValueError(
        f"Unknown rotation parameter key {rotation_key!r}; expected 'rc', 'omega' or 'mu'."
    )
|
|
477
|
+
|
|
478
|
+
def rotation_param_from_mu(rotation_key, mu, mass, r0):
    """Convert mu = rc / r0 back into the rotation parameter the user supplied.

    Parameters
    ----------
    rotation_key : str
        Target parameter: 'mu', 'rc' or 'omega'.
    mu : float or array
        Value of mu.
    mass, r0 : float or array
        Mass and r0 in canonical units; ``mass`` is only used for 'omega'.

    Returns
    -------
    float or array
        The value of ``rotation_key`` corresponding to ``mu``.

    Raises
    ------
    ValueError
        If ``rotation_key`` is not 'mu', 'rc' or 'omega'.
    """
    if rotation_key == 'mu':
        return mu
    if rotation_key == 'rc':
        return mu * r0
    if rotation_key == 'omega':
        return stream_lines_grad.omega_from_mu(mu=mu, mass=mass, r0=r0)
    # Fail loudly instead of silently returning None for an unsupported key.
    raise ValueError(
        f"Unknown rotation parameter key {rotation_key!r}; expected 'rc', 'omega' or 'mu'."
    )
|
|
486
|
+
|
|
487
|
+
def with_mu_substituted(opt_params, fixed_params, param_bounds=None):
    """Replace the user's rotation parameter ('rc', 'omega' or 'mu') with 'mu' = rc / r0.

    mu is the parameter used internally for the physics and the optimisation,
    because it has obvious bounds (0, 1), so the optimiser won't explore
    regions where rc > r0.

    Parameters
    ----------
    opt_params, fixed_params : dict
        Optimised / fixed parameters. Exactly one of them holds the rotation
        parameter; together they must contain 'mass' and 'r0'.
    param_bounds : dict or None
        User-supplied bounds. When the rotation parameter is optimised, any
        bounds supplied for it are dropped (with a printed notice) and 'mu'
        gets bounds (1e-6, 1 - 1e-6).

    Returns
    -------
    tuple
        (opt_params, fixed_params, param_bounds, rotation_key): new dicts (the
        inputs are not mutated) with 'mu' stored last in whichever dict held
        the rotation parameter, plus the name of the user's rotation parameter.
    """
    rotation_key = get_rotation_param_key(opt_params, fixed_params)
    new_opt = dict(opt_params)
    new_fixed = dict(fixed_params)
    new_bounds = {} if param_bounds is None else dict(param_bounds)

    merged = {**new_fixed, **new_opt}
    mass, r0 = merged['mass'], merged['r0']

    is_optimised = rotation_key in new_opt
    holder = new_opt if is_optimised else new_fixed
    # pop + re-insert puts 'mu' at the end of the dict, even when rotation_key is already 'mu'
    raw_rotation = holder.pop(rotation_key)
    holder['mu'] = mu_from_rotation_param(rotation_key, raw_rotation, mass, r0)

    if is_optimised:
        # drop any rc/omega/mu bounds the user supplied, we use (0,1) bounds for mu
        if rotation_key in new_bounds:
            print(
                f"Notice: Ignoring user-supplied bounds for '{rotation_key}', since the optimisation is performed in 'mu' space. "
                "Using (0, 1) bounds for 'mu' instead."
            )
            new_bounds.pop(rotation_key)
        new_bounds['mu'] = (0.0+1e-6, 1.0-1e-6)  # tiny epsilon to avoid rc=r0 or rc=0

    return new_opt, new_fixed, new_bounds, rotation_key
|
|
526
|
+
|
|
527
|
+
def auto_fill_angle_bounds(opt_params, param_bounds):
    """Give every optimised angle (theta0, phi0, inc, pa) its physical range as bounds.

    The ranges come from ANGLE_BOUNDS_RAD. Angles that are not optimised are
    left untouched. Bounds the user supplied for an optimised angle are replaced,
    with a printed notice. Returns a new dict; ``param_bounds`` (which may be
    None) is not modified.
    """
    filled = {} if param_bounds is None else dict(param_bounds)
    optimised_angles = [name for name in ANGLE_BOUNDS_RAD if name in opt_params]
    for name in optimised_angles:
        physical_range = ANGLE_BOUNDS_RAD[name]
        if name in filled:
            lo_deg = math.degrees(physical_range[0])
            hi_deg = math.degrees(physical_range[1])
            print(f"Notice: Ignoring supplied bounds for '{name}', since the bounds are fixed by physics. Using automatic bounds ({lo_deg} - {hi_deg} degrees).")
        filled[name] = physical_range
    return filled
|
|
542
|
+
|
|
543
|
+
def format_param(key, value):
    """Return a display string for a parameter value, to 6 significant figures.

    Angles (theta0, phi0, inc, pa) are converted from radians to degrees and
    suffixed ' deg'. 'mu' is printed without units. 'omega' gets ' 1/s'. Other
    parameters get their DISPLAY_UNITS label, if they have one.
    """
    number = float(value)
    if key in ANGLE_KEYS:
        return f"{math.degrees(number):.6g} deg"
    if key == 'mu':
        return f"{number:.6g}"
    if key == 'omega':
        return f"{number:.6g} 1/s"
    label = DISPLAY_UNITS.get(key, '')
    return f"{number:.6g} {label}" if label else f"{number:.6g}"
|
|
562
|
+
|
|
563
|
+
def validate_priors(priors, opt_params, fixed_params):
    """Check the validity of the priors dictionary supplied by the user.

    priors must be a dict of the form::

        {param_name: (mean, sigma), ...}

    where each (mean, sigma) is one of:

    - a length-2 astropy Quantity, e.g. ``(4.0, 1.0) * u.Msun``;
    - a 2-tuple/list whose items are Quantities and/or plain floats.

    Quantities are converted to the canonical units (CANONICAL_UNITS). Plain
    floats are assumed to already be in those units.

    Only parameters that are being optimised (i.e. keys present in opt_params)
    can have priors.

    Parameters
    ----------
    priors : dict or None
    opt_params : dict
        Optimisable parameters (after sanitisation).
    fixed_params : dict
        Fixed parameters (after sanitisation).

    Returns
    -------
    dict
        {param_name: (mean, sigma)} as plain floats in canonical units, or {}
        if priors is None.

    Raises
    ------
    TypeError
        If priors is not a dict.
    KeyError
        If a prior names an unknown parameter.
    ValueError
        If any of the following holds:
        - a prior targets a fixed or non-optimised parameter;
        - a prior is not a (mean, sigma) pair;
        - a prior uses units for a parameter without canonical units;
        - mean or sigma is not finite;
        - sigma <= 0.
    """
    if priors is None:
        return {}
    if not isinstance(priors, dict):
        raise TypeError(f"priors must be a dict, got {type(priors).__name__}.")

    def _to_canonical_float(key, quantity_or_number):
        """Convert a Quantity to canonical units and strip them; cast plain numbers to float."""
        if isinstance(quantity_or_number, u.Quantity):
            if key not in CANONICAL_UNITS:
                raise ValueError(f"No canonical units defined for '{key}'.")
            return float(quantity_or_number.to(CANONICAL_UNITS[key]).value)
        return float(quantity_or_number)  # assume it's already in the right units

    validated = {}
    for key, val in priors.items():
        if key not in STREAMLINE_MODEL_PARAM_KEYS:
            raise KeyError(
                f"Unknown parameter '{key}' in priors. "
                f"Supported keys are: {list(STREAMLINE_MODEL_PARAM_KEYS)}"
            )
        if key in fixed_params:
            raise ValueError(
                f"Parameter '{key}' is fixed and cannot have a prior. "
                "Move it to initial_opt_params if you want to optimise it with a prior."
            )
        if key not in opt_params:
            raise ValueError(
                f"Parameter '{key}' has a prior but is not being optimised. "
                "Add it to initial_opt_params or remove it from priors."
            )

        if isinstance(val, u.Quantity) and val.shape == (2,):
            mean, sigma = val[0], val[1]
        elif isinstance(val, (tuple, list)) and len(val) == 2:
            mean, sigma = val
        else:
            raise ValueError(
                f"Prior for '{key}' must be a 2-element (mean, sigma) tuple, "
                f"or a length-2 Quantity e.g. (4.0, 1.0) * u.Msun. Got: {val!r}"
            )

        mean = _to_canonical_float(key, mean)
        sigma = _to_canonical_float(key, sigma)
        # NaN would slip through the `sigma <= 0` check and poison the loss.
        if not (math.isfinite(mean) and math.isfinite(sigma)):
            raise ValueError(
                f"Prior mean and sigma for '{key}' must be finite, got ({mean}, {sigma})."
            )
        if sigma <= 0:
            raise ValueError(
                f"Prior sigma for '{key}' must be positive, got {sigma}."
            )
        validated[key] = (mean, sigma)
    return validated
|
|
637
|
+
|
|
638
|
+
|
|
639
|
+
@jax.jit(static_argnames=('priors_keys',))
|
|
640
|
+
def compute_prior_penalty(model_params, priors_means, priors_sigmas, priors_keys):
|
|
641
|
+
"""Compute the Gaussian prior penalty term:
|
|
642
|
+
sum_k ((theta_k - mu_k) / sigma_k)^2.
|
|
643
|
+
|
|
644
|
+
Parameters
|
|
645
|
+
----------
|
|
646
|
+
model_params : dict: physical model parameters
|
|
647
|
+
priors_means : tuple of float: prior means, one per prior key
|
|
648
|
+
priors_sigmas : tuple of float: prior sigmas, one per prior key
|
|
649
|
+
priors_keys : tuple of str: corresponding prior keys (parameter names)
|
|
650
|
+
|
|
651
|
+
Returns
|
|
652
|
+
-------
|
|
653
|
+
scalar JAX array : the prior chi-squared contribution
|
|
654
|
+
"""
|
|
655
|
+
penalty = jnp.asarray(0.0, dtype=jnp.float64)
|
|
656
|
+
for key, mu_p, sigma_p in zip(priors_keys, priors_means, priors_sigmas):
|
|
657
|
+
mu_p = jnp.asarray(mu_p, dtype=jnp.float64)
|
|
658
|
+
sigma_p = jnp.asarray(sigma_p, dtype=jnp.float64)
|
|
659
|
+
penalty = penalty + ((model_params[key] - mu_p) / sigma_p) ** 2
|
|
660
|
+
return penalty
|
|
661
|
+
|
|
662
|
+
def add_rc_omega_to_log(row, opt_params, fixed_params, all_param_keys):
    """Add 'rc' and 'omega' columns, derived from 'mu', to a log row.

    The function only acts when 'mu' is in ``all_param_keys``. Each of mass, r0
    and mu is read from opt_params first, then fixed_params. If any of them is
    missing or None, the row is left unchanged. ``row`` is modified in place
    and also returned.
    """
    if 'mu' not in all_param_keys:
        return row

    mass_val = opt_params.get('mass', fixed_params.get('mass'))
    r0_val = opt_params.get('r0', fixed_params.get('r0'))
    mu_val = opt_params.get('mu', fixed_params.get('mu'))
    if all(v is not None for v in (mass_val, r0_val, mu_val)):
        rc_val = mu_val * r0_val
        omega_val = stream_lines_grad.omega_from_mu(mu=mu_val, mass=mass_val, r0=r0_val)
        row['rc'] = float(rc_val)
        row['omega'] = float(omega_val)
    return row
|
|
675
|
+
|
|
676
|
+
|
|
677
|
+
def build_trace_row(epoch, loss_value, loss_trace, grad_norm, loss_method):
    """Flatten one epoch's loss trace into a dict keyed by trace-csv column name.

    Keys follow trace_fieldnames_for_loss_method(loss_method). Missing chi2
    components and missing float diagnostics default to NaN; missing counts
    default to 0.
    """
    method = check_loss_method(loss_method)
    components = loss_trace.get('chi2_components', {})
    match_info = loss_trace.get('matching', {})
    model_metric = match_info.get('distance_metric_model', {})
    data_metric = match_info.get('distance_metric_data', {})

    row = dict(epoch=epoch, loss=loss_value)
    for name in LOSS_METHOD_COMPONENT_KEYS[method]:
        row[name] = components.get(name, float('nan'))

    row['chi2_total'] = components.get('chi2_total', float('nan'))
    row['grad_norm'] = grad_norm
    row['model_points_total'] = match_info.get('model_points_total', 0)
    row['model_nan_count'] = match_info.get('model_nan_count', 0)
    row['model_valid_points'] = match_info.get('model_valid_points', 0)
    row['model_metric_span'] = match_info.get('model_metric_span', float('nan'))
    row['model_inner_count'] = model_metric.get('inner_count', 0)
    row['data_inner_count'] = data_metric.get('inner_count', 0)
    row['data_points_total'] = match_info.get('data_points_total', 0)
    row['data_valid_points'] = match_info.get('data_valid_points', 0)
    row['data_retained_count'] = match_info.get('data_retained_count', 0)
    row['model_retained_count'] = match_info.get('model_retained_count', 0)
    row['overlap_metric_min'] = match_info.get('overlap_metric_min', float('nan'))
    row['overlap_metric_max'] = match_info.get('overlap_metric_max', float('nan'))
    return row
|
|
710
|
+
|
|
711
|
+
def trace_tree_to_python(value):
    """Recursively convert a trace tree into plain Python values.

    Dicts, lists and tuples are rebuilt (as plain dict/list/tuple) with their
    contents converted. ``None`` is returned unchanged. Any other leaf is
    passed through ``jnp.asarray``: zero-dimensional results are unwrapped to
    a Python scalar with ``.item()``, while leaves that cannot be converted
    (e.g. strings) and non-scalar arrays are returned as-is.
    """
    if isinstance(value, dict):
        return {k: trace_tree_to_python(item) for k, item in value.items()}
    if isinstance(value, (list, tuple)):
        converted = [trace_tree_to_python(item) for item in value]
        return converted if isinstance(value, list) else tuple(converted)
    if value is None:
        return None

    try:
        as_array = jnp.asarray(value)
    except Exception:
        # not array-like (e.g. a string): leave untouched
        return value
    return as_array.item() if as_array.ndim == 0 else value
|
|
732
|
+
|
|
733
|
+
|
|
734
|
+
@jax.jit
def gradient_l2_norm(grad_tree):
    """Return the global L2 norm over every array leaf of a pytree.

    ``grad_tree`` is any JAX pytree (nested dicts/lists/tuples of arrays),
    such as a gradient produced by ``jax.grad``. The squared entries of all
    leaves are accumulated into a float64 total before the square root.
    """
    per_leaf_sq = (
        jnp.sum(jnp.square(leaf)) for leaf in jax.tree_util.tree_leaves(grad_tree)
    )
    return jnp.sqrt(sum(per_leaf_sq, jnp.asarray(0.0, dtype=jnp.float64)))
|
|
744
|
+
|
|
745
|
+
@jax.jit(static_argnames=("npoints",))
def forward_model(model_params, distance_pc, npoints=10000):
    """
    Run the forward model using stream_lines_grad.checked_xyz_stream.

    Parameters:
    -----------
    model_params : dict
        Model parameters (optimised and fixed together). Must contain 'mass',
        'r0', 'theta0', 'phi0', 'v_r0', 'inc', 'pa', 'rmin', 'deltar' and
        'v_lsr', plus one rotation parameter: 'mu', 'rc' or 'omega' (checked
        in that order of precedence; rc and omega are converted to mu).
        'rmin' may be None, in which case 0.0 is used.
    distance_pc : float
        Distance to source in parsecs.
    npoints : int
        Length of the returned arrays (static under jit). The number of
        physically valid points is set by the model parameters, so some
        entries may be invalid; see valid_mask.

    Returns:
    --------
    tuple: (ra_model, dec_model, v_model, valid_mask, err)
        - ra_model : RA offsets in arcsec (x negated for the standard RA convention)
        - dec_model : Dec offsets in arcsec
        - v_model : line-of-sight velocities in km/s, with v_lsr added (absolute)
        - valid_mask : mask of valid points; ra/dec/v are set to 0.0 where it is False
        - err : checkify error from stream_lines_grad.checked_xyz_stream. It is
          returned without being thrown, so callers decide when to call err.throw().
    """
    distance_pc = to_float64(distance_pc)

    # Run the forward model - returns positions in au, velocities in km/s.
    # valid_mask is a boolean array marking which points are valid in the
    # returned arrays, which can be used for masking in the loss function.
    rmin = model_params['rmin']
    if rmin is None:
        rmin = to_float64(0.0)  # rc*0.5 will always dominate in jnp.maximum

    # derive mu from rc or omega (whichever is provided)
    if 'mu' in model_params:
        mu = model_params['mu']
    elif 'rc' in model_params:
        mu = model_params['rc'] / model_params['r0']
    elif 'omega' in model_params:
        mu = stream_lines_grad.mu_from_omega(omega=model_params['omega'], mass=model_params['mass'], r0=model_params['r0'])
    else:
        raise ValueError("model_params must contain either 'rc', 'omega', or 'mu'")
    # copy so the caller's dict is not mutated
    model_params = dict(model_params)
    model_params['mu'] = mu

    err, ((x, y, z), (vx, vy, vz), valid_mask) = stream_lines_grad.checked_xyz_stream(
        mass=model_params['mass'],
        r0=model_params['r0'],
        theta0=model_params['theta0'],
        phi0=model_params['phi0'],
        mu=model_params['mu'],
        v_r0=model_params['v_r0'],
        inc=model_params['inc'],
        pa=model_params['pa'],
        rmin=rmin,
        deltar=model_params['deltar'],
        npoints=npoints
    )
    # err is deliberately not thrown here (this function is jitted); it is
    # returned to the caller instead.

    # Convert positions from au to arcsec offsets (1 au at 1 pc = 1 arcsec):
    # x = RA offset (negated for standard RA convention)
    # z = Dec offset
    # y = line-of-sight axis; vy is the line-of-sight velocity
    ra_model = -x / distance_pc  # arcsec
    dec_model = z / distance_pc  # arcsec
    # make velocity absolute by adding back v_lsr
    v_model = vy + model_params['v_lsr']  # km/s
    # make sure model points outside valid_mask are set to 0
    ra_model = jnp.where(valid_mask, ra_model, 0.0)
    dec_model = jnp.where(valid_mask, dec_model, 0.0)
    v_model = jnp.where(valid_mask, v_model, 0.0)

    return ra_model, dec_model, v_model, valid_mask, err
|
|
822
|
+
|
|
823
|
+
|
|
824
|
+
@jax.jit
def distance_metric_overlap(dmetric_model, model_finite_mask, dmetric_data, data_finite_mask):
    """Return the overlapping streamline distance-metric range of data and model.

    Only entries selected by the corresponding mask take part. Returns
    ``(model_min, model_max, data_min, data_max, overlap_min, overlap_max)``,
    where the overlap is ``[max(model_min, data_min), min(model_max, data_max)]``.
    If a mask selects nothing, that side's min/max are +BIG/-BIG, and the
    resulting overlap has overlap_min > overlap_max.
    """
    def _masked_extent(metric, mask):
        # Masked-out entries are replaced by +BIG / -BIG so they never win min / max.
        lowest = jnp.min(jnp.where(mask, metric, to_float64(BIG)))
        highest = jnp.max(jnp.where(mask, metric, to_float64(BIG_NEG)))
        return lowest, highest

    model_lo, model_hi = _masked_extent(dmetric_model, model_finite_mask)
    data_lo, data_hi = _masked_extent(dmetric_data, data_finite_mask)

    return (
        model_lo,
        model_hi,
        data_lo,
        data_hi,
        jnp.maximum(model_lo, data_lo),
        jnp.minimum(model_hi, data_hi),
    )
|
|
840
|
+
|
|
841
|
+
@jax.jit
def match_model_to_data_curve(ra_model, dec_model, v_model, valid_mask_model, ra_data, dec_data):
    """
    Interpolate the model streamline at the data positions using the distance
    metric from extract_streamline.get_distance_metric.

    Method:
    1. Compute the distance metric for model and data points.
    2. Keep data points whose RA, Dec and metric are finite. Keep model points
       that are flagged valid and whose metric is >= the smallest valid data
       metric (the model should not extend inwards of the innermost observed
       data point, where the streamer is no longer observed).
    3. Normalise the data metric to [0, 1] over its valid range and map it
       linearly onto the retained model metric range [model_min, model_max].
    4. Linearly interpolate model RA, Dec and velocity at the mapped positions.

    Parameters
    ----------
    ra_model, dec_model, v_model : arrays
        Model RA/Dec offsets and velocities.
    valid_mask_model : array
        Mask of valid model points (cast to bool).
    ra_data, dec_data : arrays
        Observed RA/Dec offsets.

    Returns
    -------
    ra_model_interp, dec_model_interp, v_model_interp, valid, model_keep, dmetric_model, matching_trace
        - *_interp : model values interpolated at each data point (length of
          the data); entries where ``valid`` is False are not meaningful.
        - valid : boolean mask, shape of the original data, marking retained
          data points (finite RA, Dec and distance metric).
        - model_keep : boolean mask of model points used for interpolation.
        - dmetric_model : distance metric of every model point.
        - matching_trace : dict of diagnostic counts and metric ranges.
    """
    ra_model = to_float64(ra_model)
    dec_model = to_float64(dec_model)
    v_model = to_float64(v_model)
    ra_data = to_float64(ra_data)
    dec_data = to_float64(dec_data)

    big = to_float64(BIG)
    big_neg = to_float64(BIG_NEG)

    # get distance metrics
    dmetric_model, _ = extract_streamline.get_distance_metric(ra_model, dec_model)
    dmetric_data, _ = extract_streamline.get_distance_metric(ra_data, dec_data)

    model_valid = valid_mask_model

    data_valid = (
        jnp.isfinite(ra_data)
        & jnp.isfinite(dec_data)
        & jnp.isfinite(dmetric_data)
    )

    # valid data metric range; invalid entries pushed to +/-BIG so they never win
    data_min = jnp.min(jnp.where(data_valid, dmetric_data, big))
    data_max = jnp.max(jnp.where(data_valid, dmetric_data, big_neg))

    # keep only valid model points with dmetric >= the innermost data point:
    # the model shouldn't go further in than where the streamer is observed
    model_keep = model_valid.astype(bool) & (dmetric_model >= data_min)

    # weights: 0 = ignore, 1 = use (fixed-size arrays for jax/jit compatibility)
    w_model = model_keep.astype(jnp.float64)

    d_model = jnp.where(model_keep, dmetric_model, 0.0)
    d_data = jnp.where(data_valid, dmetric_data, 0.0)

    # sort model by metric; the weight penalty pushes discarded points to the end
    model_sort_key = d_model + (1.0 - w_model) * big
    model_idx = jnp.argsort(model_sort_key)

    d_model_s = d_model[model_idx]
    ra_s = ra_model[model_idx]
    dec_s = dec_model[model_idx]
    v_s = v_model[model_idx]
    w_model_s = w_model[model_idx]

    # retained model metric range (interpolation domain)
    model_min = jnp.min(jnp.where(model_keep, dmetric_model, big))
    model_max = jnp.max(jnp.where(model_keep, dmetric_model, big_neg))

    model_span = model_max - model_min
    data_span = data_max - data_min
    # guard against zero/negative spans (single point, or nothing retained)
    model_span_safe = jnp.where(model_span > to_float64(0.0), model_span, to_float64(1.0))
    data_span_safe = jnp.where(data_span > to_float64(0.0), data_span, to_float64(1.0))

    # normalise data metric to [0, 1] and map it onto the model metric range
    d_data_norm = (d_data - data_min) / data_span_safe
    d_goal = model_min + d_data_norm * model_span_safe

    # discarded model points get xp = BIG (they sit at the end of the sorted
    # array) so they do not affect interpolation inside the retained range
    xp = jnp.where(w_model_s > 0, d_model_s, big)

    ra_interp = jnp.interp(d_goal, xp, ra_s)
    dec_interp = jnp.interp(d_goal, xp, dec_s)
    v_interp = jnp.interp(d_goal, xp, v_s)

    valid = data_valid

    matching_trace = {
        "model_points_total": model_idx.size,
        "model_nan_count": jnp.sum(jnp.isnan(dmetric_model)),
        "model_valid_points": model_valid.sum(),
        "data_points_total": ra_data.size,
        "data_nan_count": jnp.sum(jnp.isnan(dmetric_data)),
        "data_valid_points": data_valid.sum(),
        "model_metric_min": model_min,
        "model_metric_max": model_max,
        "data_metric_min": data_min,
        "data_metric_max": data_max,
        "model_metric_span": model_span_safe,
        "data_metric_span": data_span_safe}

    return ra_interp, dec_interp, v_interp, valid, model_keep, dmetric_model, matching_trace
|
|
949
|
+
|
|
950
|
+
# checkify-functionalised match_model_to_data_curve: calling it returns
# (error, outputs) instead of raising from inside traced code.
checked_matching = checkify.checkify(match_model_to_data_curve)


@jax.jit
def checked_match_model_to_data_curve(*args, **kwargs):
    """Run match_model_to_data_curve through checkify, raising on any recorded error.

    Takes exactly the arguments of match_model_to_data_curve and returns its
    outputs unchanged; any checkify error is raised via ``throw()``.

    NOTE(review): the JAX checkify docs say a raising check cannot be staged
    out by jit unless functionalised by an outer checkify; confirm ``throw()``
    behaves as intended under this @jax.jit.
    """
    err, outputs = checked_matching(*args, **kwargs)
    err.throw()
    return outputs
|
|
959
|
+
|
|
960
|
+
def _masked_chi2_term(residual, sigma, mask):
    """Return the sum of (residual / sigma)**2 over entries where mask is True.

    Entries outside ``mask`` are replaced (residual -> 0, sigma -> 1) before
    dividing and squaring, so non-finite values there contribute exactly zero
    to both the sum and its gradient. Multiplying by a 0/1 weight instead would
    not: 0 * NaN is NaN.
    """
    safe_residual = jnp.where(mask, residual, 0.0)
    safe_sigma = jnp.where(mask, sigma, 1.0)
    return jnp.sum(jnp.where(mask, (safe_residual / safe_sigma) ** 2, 0.0))


@jax.jit(static_argnames=("loss_method", "npoints", "priors_keys", "priors_means", "priors_sigmas"))
def chi2_loss(
    model_params,
    distance_pc,
    prepared_data,
    loss_method=0,
    npoints=10000,
    priors_keys=(),
    priors_means=(),
    priors_sigmas=()
):
    """Generate the model via the forward model and compute the loss against the data.

    Returns (chi2_total, loss_trace, err):
        - chi2_total : scalar total loss (data terms + prior penalty)
        - loss_trace : dict with 'chi2_components', 'matching' and 'loss_method'
        - err : checkify error returned by forward_model. It is not thrown
          here, so callers decide whether to call err.throw().

    chi2_loss = chi2_data + chi2_priors.

    chi2_data only uses data points retained by the model/data matching
    (finite RA, Dec and distance metric); points with non-finite velocity are
    additionally excluded from chi2_v. Excluded points contribute exactly zero,
    even if their values are NaN.

    chi2_priors is the sum of Gaussian prior penalty terms for any optimised parameters where priors were given.
    See compute_prior_penalty"""

    loss_method = check_loss_method(loss_method)

    distance_pc = to_float64(distance_pc)

    ra_data = prepared_data.ra_data
    dec_data = prepared_data.dec_data
    v_data = prepared_data.v_data
    ra_sigma = prepared_data.ra_sigma_safe
    dec_sigma = prepared_data.dec_sigma_safe
    v_sigma = prepared_data.v_sigma_safe

    ra_model, dec_model, v_model, valid_mask_model, err = forward_model(model_params, distance_pc, npoints=npoints)
    valid_mask_model = valid_mask_model.astype(jnp.bool_)

    # The matcher's own model_keep is not used; an overlap-based model_keep is
    # recomputed below for the trace.
    ra_model_interp, dec_model_interp, v_model_interp, valid, _, dmetric_model, _ = (
        checked_match_model_to_data_curve(ra_model, dec_model, v_model, valid_mask_model, ra_data, dec_data)
    )

    dmetric_data = prepared_data.dmetric_data
    valid = jnp.asarray(valid, dtype=bool)
    # velocity residuals additionally need finite velocity data
    v_valid = valid & jnp.isfinite(v_data)

    model_finite_mask = (
        jnp.isfinite(ra_model)
        & jnp.isfinite(dec_model)
        & jnp.isfinite(v_model)
        & jnp.isfinite(dmetric_model)
    )

    # Only compute chi2 on valid/retained data points
    chi2_v = _masked_chi2_term(v_data - v_model_interp, v_sigma, v_valid)

    if loss_method == 0:
        chi2_ra = _masked_chi2_term(ra_data - ra_model_interp, ra_sigma, valid)
        chi2_dec = _masked_chi2_term(dec_data - dec_model_interp, dec_sigma, valid)
        chi2_total = chi2_ra + chi2_dec + chi2_v
    else:
        # Replace excluded data entries before any nonlinear op (wrap_to_pi) so
        # non-finite values cannot leak NaNs into the gradient.
        r_proj_data = jnp.where(valid, prepared_data.r_proj_data, 0.0)
        theta_proj_data = jnp.where(valid, prepared_data.theta_proj_data, 0.0)
        r_proj_model, theta_proj_model = extract_streamline.cartesian_to_polar(
            ra_model_interp,
            dec_model_interp,
        )

        dtheta = extract_streamline.wrap_to_pi(theta_proj_data - theta_proj_model)

        sigma_r = jnp.sqrt(ra_sigma**2 + dec_sigma**2)
        r_eps = to_float64(1e-8)
        r_safe = jnp.maximum(jnp.abs(r_proj_data), r_eps)
        # propagated angular uncertainty: sqrt((dec*s_ra)^2 + (ra*s_dec)^2) / r^2
        sigma_theta = jnp.sqrt(((dec_data * ra_sigma)**2 + (ra_data * dec_sigma)**2)) / (r_safe**2)
        sigma_theta = jnp.maximum(sigma_theta, r_eps)

        chi2_r = _masked_chi2_term(r_proj_data - r_proj_model, sigma_r, valid)
        chi2_theta = _masked_chi2_term(dtheta, sigma_theta, valid)
        chi2_total = chi2_r + chi2_theta + chi2_v

    data_finite_mask = (
        jnp.isfinite(ra_data)
        & jnp.isfinite(dec_data)
        & jnp.isfinite(dmetric_data)
    )

    model_min, model_max, data_min, data_max, overlap_min, overlap_max = distance_metric_overlap(
        dmetric_model,
        model_finite_mask,
        dmetric_data,
        data_finite_mask,
    )

    model_nan_count = jnp.sum(~model_finite_mask)
    model_points_total = ra_model.size
    model_valid_points = model_points_total - model_nan_count

    data_keep = data_finite_mask & (dmetric_data >= overlap_min) & (dmetric_data <= overlap_max)
    model_keep = model_finite_mask & (dmetric_model >= overlap_min) & (dmetric_model <= overlap_max)

    sort_idx = jnp.argsort(dmetric_model)
    d_model_sorted = dmetric_model[sort_idx]
    d_diff = jnp.diff(d_model_sorted)
    if d_diff.size > 0:
        model_metric_min_gap = jnp.min(d_diff)
        model_metric_near_tie_count = jnp.sum(jnp.abs(d_diff) <= 1e-8)
        model_metric_duplicate_count = jnp.sum(d_diff == 0.0)
        # NOTE(review): d_diff comes from the *sorted* metric, so (NaN aside)
        # it is never negative and this count is always 0 - confirm whether
        # the unsorted model order was intended.
        model_metric_non_monotonic_count = jnp.sum(d_diff < 0.0)
    else:
        model_metric_min_gap = to_float64(float('nan'))
        model_metric_near_tie_count = to_float64(0.0)
        model_metric_duplicate_count = to_float64(0.0)
        model_metric_non_monotonic_count = to_float64(0.0)

    model_metric_span = d_model_sorted[-1] - d_model_sorted[0] if d_model_sorted.size > 1 else to_float64(0.0)

    # compute and add the prior penalty term to the total chi2 if priors are provided
    chi2_prior = compute_prior_penalty(model_params, priors_means, priors_sigmas, priors_keys)
    chi2_total = chi2_total + chi2_prior

    if loss_method == 0:
        chi2_components = {
            'chi2_ra': chi2_ra,
            'chi2_dec': chi2_dec,
            'chi2_v': chi2_v,
            'chi2_prior': chi2_prior,
            'overlap_width': overlap_max - overlap_min,
            'chi2_total': chi2_total,
        }
    else:
        chi2_components = {
            'chi2_r': chi2_r,
            'chi2_theta': chi2_theta,
            'chi2_v': chi2_v,
            'chi2_prior': chi2_prior,
            'overlap_width': overlap_max - overlap_min,
            'chi2_total': chi2_total,
        }

    matching_trace = {
        'model_points_total': model_points_total,
        'model_nan_count': model_nan_count,
        'model_valid_points': model_valid_points,
        'model_retained_count': jnp.sum(model_keep),
        'data_points_total': ra_data.size,
        'data_valid_points': jnp.sum(data_finite_mask),
        'data_retained_count': jnp.sum(data_keep),
        'overlap_metric_min': overlap_min,
        'overlap_metric_max': overlap_max,
        'model_metric_span': model_metric_span,
        'model_metric_min_gap': model_metric_min_gap,
        'model_metric_near_tie_count': model_metric_near_tie_count,
        'model_metric_duplicate_count': model_metric_duplicate_count,
        'model_metric_non_monotonic_count': model_metric_non_monotonic_count,
    }

    loss_trace = {
        'chi2_components': chi2_components,
        'matching': matching_trace,
        'loss_method': loss_method,
    }
    return chi2_total, loss_trace, err
|
|
1121
|
+
|
|
1122
|
+
|
|
1123
|
+
|
|
1124
|
+
# Result bundle returned by evaluate_initial_guess: merged model parameters,
# the full forward-model arrays, the model interpolated at the data positions,
# the retained-data mask, and the chi2 summary.
InitialGuessResult = namedtuple(
    'InitialGuessResult',
    'model_params ra_model dec_model v_model '
    'ra_model_interp dec_model_interp v_model_interp '
    'valid chi2_total chi2_components',
)
|
|
1136
|
+
|
|
1137
|
+
|
|
1138
|
+
def evaluate_initial_guess(
    initial_opt_params,
    fixed_params,
    data,
    uncertainties,
    distance_pc,
    n_elements=10,
    loss_method=0,
    priors=None,
):
    """
    Run the forward model and compute chi2 loss for the initial parameter guess.

    Parameters
    ----------
    initial_opt_params : dict
        Initial guesses for the parameters to optimise.
    fixed_params : dict
        Fixed (non-optimised) parameters. Together with initial_opt_params
        this must provide a full, non-overlapping partition of
        STREAMLINE_MODEL_PARAM_KEYS.
    data : tuple of arrays (ra_data, dec_data, v_data)
        Observed RA offset (arcsec), Dec offset (arcsec), velocity (km/s).
    uncertainties : tuple of arrays (ra_sigma, dec_sigma, v_sigma)
        Uncertainties on the data.
    distance_pc : float
        Distance to source in parsecs.
    n_elements : int
        Number of distance-metric partitions, i.e. the number of 1D data
        points. Must match the value used when reducing the cube.
    loss_method : int
        Loss definition to use. Options:
        - 0: radecvel — RA, Dec, and velocity residuals.
        - 1: rthetavel — radial distance, polar angle, and velocity residuals.
    priors : dict or None
        Optional Gaussian priors on optimised parameters, in the form
        {param_name: (mean, sigma), ...}. Only optimised parameters can have priors.
        Mean and sigma must be in the same canonical units as the rest of the code.

    Returns
    -------
    InitialGuessResult
        Named tuple with entries:
        - model_params : merged dict of all model parameters (float64)
        - ra_model : full model RA offsets (arcsec)
        - dec_model : full model Dec offsets (arcsec)
        - v_model : full model velocities (km/s)
        - ra_model_interp : model RA interpolated at data positions
        - dec_model_interp : model Dec interpolated at data positions
        - v_model_interp : model velocity interpolated at data positions
        - valid : boolean mask of retained data points
        - chi2_total : total chi2 loss (float)
        - chi2_components : dict of per-component chi2 values and chi2_total

    Raises
    ------
    The checkify error from the forward model, via err.throw(), if its checks fail.
    """
    loss_method = check_loss_method(loss_method)

    model_params, opt_params_clean, fixed_params_clean = prepare_model_params(initial_opt_params, fixed_params)
    validated_priors = validate_priors(priors, opt_params_clean, fixed_params_clean)
    # Prior means/sigmas are static jit arguments of chi2_loss, so they must be
    # hashable Python floats (the same conversion fit_streamline performs).
    priors_keys = tuple(validated_priors.keys())
    priors_means = tuple(float(v[0]) for v in validated_priors.values())
    priors_sigmas = tuple(float(v[1]) for v in validated_priors.values())

    ra_model, dec_model, v_model, valid_mask_model, err = forward_model(
        model_params, distance_pc
    )
    err.throw()

    ra_model_interp, dec_model_interp, v_model_interp, valid, _, _, _ = (
        checked_match_model_to_data_curve(
            ra_model, dec_model, v_model, valid_mask_model,
            jnp.asarray(data[0], dtype=jnp.float64),
            jnp.asarray(data[1], dtype=jnp.float64),
        )
    )

    prepared_data = extract_streamline.prepare_data(data, uncertainties, n_elements=n_elements)
    chi2_total, loss_trace, _ = chi2_loss(
        model_params, distance_pc, prepared_data, loss_method=loss_method,
        priors_keys=priors_keys, priors_means=priors_means, priors_sigmas=priors_sigmas
    )
    chi2_components = loss_trace['chi2_components']

    return InitialGuessResult(
        model_params=model_params,
        ra_model=ra_model,
        dec_model=dec_model,
        v_model=v_model,
        ra_model_interp=ra_model_interp,
        dec_model_interp=dec_model_interp,
        v_model_interp=v_model_interp,
        valid=valid,
        chi2_total=float(chi2_total),
        chi2_components={k: float(v) for k, v in chi2_components.items()},
    )
|
|
1232
|
+
|
|
1233
|
+
|
|
1234
|
+
def fit_streamline(initial_opt_params, fixed_params, streamer, distance_pc,
|
|
1235
|
+
learning_rate=0.005, param_bounds=None, n_epochs=1000,
|
|
1236
|
+
beta1=0.9, beta2=0.999,
|
|
1237
|
+
info_every=100, loss_threshold=None, loss_threshold_epochs=1,
|
|
1238
|
+
gradient_tol=None, gradient_tol_epochs=1,
|
|
1239
|
+
early_stopping_patience=50,
|
|
1240
|
+
save_folder='sting_results',
|
|
1241
|
+
loss_method=1, # 0: radecvel, 1: rthetavel
|
|
1242
|
+
priors=None,
|
|
1243
|
+
v_lsr=None,
|
|
1244
|
+
show_plots=False
|
|
1245
|
+
):
|
|
1246
|
+
"""
|
|
1247
|
+
Fit streamline model parameters to data using Adam optimiser.
|
|
1248
|
+
Any supported streamline parameter can be optimised or fixed.
|
|
1249
|
+
Parameters are split by dictionary:
|
|
1250
|
+
- keys in initial_opt_params are optimised
|
|
1251
|
+
- keys in fixed_params are held fixed
|
|
1252
|
+
The union must contain each key in STREAMLINE_MODEL_PARAM_KEYS exactly once.
|
|
1253
|
+
|
|
1254
|
+
Parameters:
|
|
1255
|
+
-----------
|
|
1256
|
+
initial_opt_params : dict
|
|
1257
|
+
Initial guesses for the parameters to optimise.
|
|
1258
|
+
Allowed keys are STREAMLINE_MODEL_PARAM_KEYS.
|
|
1259
|
+
fixed_params : dict
|
|
1260
|
+
Fixed (non-optimised) parameters using the same key space.
|
|
1261
|
+
Together with initial_opt_params, this must provide a full,
|
|
1262
|
+
non-overlapping partition of STREAMLINE_MODEL_PARAM_KEYS.
|
|
1263
|
+
streamer: NamedTuple with fields:
|
|
1264
|
+
pc_coords, ra_data, dec_data, v_data, ra_sigma, dec_sigma, v_sigma, data, uncertainties
|
|
1265
|
+
data : tuple of arrays (ra_data, dec_data, v_data)
|
|
1266
|
+
Observed RA offset (arcsec), Dec offset (arcsec), velocity (km/s)
|
|
1267
|
+
uncertainties : tuple of arrays (ra_sigma, dec_sigma, v_sigma)
|
|
1268
|
+
Uncertainties on the data
|
|
1269
|
+
distance_pc : float
|
|
1270
|
+
Distance to source in parsecs
|
|
1271
|
+
learning_rate : float
|
|
1272
|
+
Adam learning rate applied uniformly to all normalised parameters.
|
|
1273
|
+
param_bounds : dict or None
|
|
1274
|
+
Parameter bounds in physical/log parameter units.
|
|
1275
|
+
Parameters that require bounds here (if optimised): r0, mass, inc, pa
|
|
1276
|
+
Provide param_bounds as a dictionary with values as (min, max) tuples for each parameter
|
|
1277
|
+
n_epochs : int
|
|
1278
|
+
Maximum number of optimisation iterations
|
|
1279
|
+
beta1 : float
|
|
1280
|
+
Adam exponential decay rate for first moment
|
|
1281
|
+
beta2 : float
|
|
1282
|
+
Adam exponential decay rate for second moment
|
|
1283
|
+
info_every : int
|
|
1284
|
+
Print loss every N epochs
|
|
1285
|
+
early_stopping_patience : int
|
|
1286
|
+
Stop if loss doesn't improve for N epochs
|
|
1287
|
+
save_folder : str
|
|
1288
|
+
Folder to save output CSV and trace files, and figures. Created if it doesn't exist.
|
|
1289
|
+
loss_method : int
|
|
1290
|
+
Loss definition to use. Options:
|
|
1291
|
+
- 0: radecvel: optimise RA, Dec, and velocity residuals.
|
|
1292
|
+
- 1: rthetavel: optimise projected radial distance, polar angle, and velocity residuals.
|
|
1293
|
+
Both options use the same model-data matching and overlap penalty.
|
|
1294
|
+
priors : dict or None
|
|
1295
|
+
Optional Gaussian priors on optimised parameters, in the form
|
|
1296
|
+
{param_name: (mean, sigma), ...}. Only optimised parameters can have priors.
|
|
1297
|
+
Mean and sigma must be in the same canonical units as the rest of the code
|
|
1298
|
+
v_lsr : float or None
|
|
1299
|
+
Systemic velocity (km/s). When provided, drawn as a reference line on the best-fit
|
|
1300
|
+
velocity-radius plot
|
|
1301
|
+
loss_threshold : float or None
|
|
1302
|
+
Optional absolute loss threshold for threshold-based stopping.
|
|
1303
|
+
If provided, optimisation stops after loss is <= loss_threshold for
|
|
1304
|
+
loss_threshold_epochs consecutive epochs.
|
|
1305
|
+
loss_threshold_epochs : int
|
|
1306
|
+
Number of consecutive epochs with loss <= loss_threshold required to
|
|
1307
|
+
trigger threshold-based early stopping. Must be >= 1.
|
|
1308
|
+
gradient_tol : float or None
|
|
1309
|
+
Optional gradient norm tolerance for stopping in normalised space.
|
|
1310
|
+
If provided, optimisation stops when the L2 norm of gradients with
|
|
1311
|
+
respect to normalised parameters
|
|
1312
|
+
is less than this threshold for gradient_tol_epochs consecutive epochs,
|
|
1313
|
+
indicating convergence.
|
|
1314
|
+
gradient_tol_epochs : int
|
|
1315
|
+
Number of consecutive epochs with ||grad|| < gradient_tol required to
|
|
1316
|
+
trigger normalised-space gradient norm-based early stopping. Must be >= 1.
|
|
1317
|
+
show_plots : bool
|
|
1318
|
+
Whether to show diagnostic plots during optimisation
|
|
1319
|
+
|
|
1320
|
+
Epoch 0: initial state before any updates, with initial_opt_params
|
|
1321
|
+
Epoch n (n>=1): state after applying parameter update n
|
|
1322
|
+
Tracking and checks are all performed at the end of each epoch. So e.g. loss n = loss after applying update n, using the updated parameters
|
|
1323
|
+
|
|
1324
|
+
Returns:
|
|
1325
|
+
--------
|
|
1326
|
+
FitResult namedtuple with fields:
|
|
1327
|
+
- best_opt_params : dict of best-fit optimised parameters (physical/log units)
|
|
1328
|
+
- loss_history : list of loss values at each epoch (float)
|
|
1329
|
+
- param_errors: dict of estimated 1-sigma uncertainties for each optimised parameter in the display parameterisation (or None if uncertainty estimation failed)
|
|
1330
|
+
- covariance_result: CovarianceResult or None: full covariance information needed for sampling, or None if estimation failed. Fields:
|
|
1331
|
+
- covariance : 2D array of covariance matrix in physical/log units
|
|
1332
|
+
- opt_keys: list of parameter keys corresponding to covariance_matrix rows/columns
|
|
1333
|
+
- best_opt_params: dict of best-fit optimised parameters (physical/log units)
|
|
1334
|
+
- fixed_params: dict of fixed parameters (physical/log units)
|
|
1335
|
+
- param_errors: dict of 1-sigma parameter uncertainties (physical/log units)
|
|
1336
|
+
- transformed_cov: dict of Jacobian-transformed covariance when 'rc'/'omega' was substutied by 'mu', keys are 'keys', 'cov', 'errors'
|
|
1337
|
+
"""
|
|
1338
|
+
# lazy imports to avoid circular imports
|
|
1339
|
+
from . import outputs
|
|
1340
|
+
from . import errors
|
|
1341
|
+
# Initialize parameters
|
|
1342
|
+
loss_method = check_loss_method(loss_method)
|
|
1343
|
+
|
|
1344
|
+
opt_params, fixed_params = sanitize_param_partition(
|
|
1345
|
+
initial_opt_params,
|
|
1346
|
+
fixed_params,
|
|
1347
|
+
require_nonempty_opt=True,
|
|
1348
|
+
)
|
|
1349
|
+
|
|
1350
|
+
|
|
1351
|
+
param_bounds = standardise_param_bounds(param_bounds)
|
|
1352
|
+
param_bounds = convert_and_strip_bound_units(param_bounds)
|
|
1353
|
+
|
|
1354
|
+
# we perform optimisation in mu-space when either rc or omega is present. conversion is here
|
|
1355
|
+
# rotation_key records which of 'rc', 'omega', or 'mu' is input as rotation parameter by user,
|
|
1356
|
+
# so we know which one to convert back to at the end
|
|
1357
|
+
opt_params, fixed_params, param_bounds, rotation_key = with_mu_substituted(opt_params, fixed_params, param_bounds)
|
|
1358
|
+
|
|
1359
|
+
# add bounds for angles if they are being optimised
|
|
1360
|
+
param_bounds = auto_fill_angle_bounds(opt_params, param_bounds)
|
|
1361
|
+
|
|
1362
|
+
# check priors are valid and match optimised parameters
|
|
1363
|
+
validated_priors = validate_priors(priors, opt_params, fixed_params)
|
|
1364
|
+
priors_keys = tuple(validated_priors.keys())
|
|
1365
|
+
priors_means = tuple(float(v[0]) for v in validated_priors.values())
|
|
1366
|
+
priors_sigmas = tuple(float(v[1]) for v in validated_priors.values())
|
|
1367
|
+
|
|
1368
|
+
|
|
1369
|
+
opt_param_keys = list(opt_params.keys())
|
|
1370
|
+
data = make_data_tuple_float64(streamer.data)
|
|
1371
|
+
uncertainties = make_data_tuple_float64(streamer.uncertainties)
|
|
1372
|
+
distance_pc = to_float64(distance_pc)
|
|
1373
|
+
learning_rate = to_float64(learning_rate)
|
|
1374
|
+
if not bool(jnp.isfinite(learning_rate)):
|
|
1375
|
+
raise ValueError(f'learning_rate must be finite. Got {learning_rate}.')
|
|
1376
|
+
if not bool(learning_rate > 0):
|
|
1377
|
+
raise ValueError(f'learning_rate must be > 0. Got {float(learning_rate)}.')
|
|
1378
|
+
normalisation_spec = build_normalisation_spec(opt_params, param_bounds)
|
|
1379
|
+
|
|
1380
|
+
# Keep optimisation variables in normalised coordinates; convert back to
|
|
1381
|
+
# physical/log units only when evaluating the forward model and diagnostics.
|
|
1382
|
+
opt_params_norm = normalise_opt_params(opt_params, normalisation_spec)
|
|
1383
|
+
|
|
1384
|
+
# Use one global learning rate on normalised parameters
|
|
1385
|
+
solver = optax.adam(learning_rate=learning_rate, b1=beta1, b2=beta2)
|
|
1386
|
+
|
|
1387
|
+
opt_state = solver.init(opt_params_norm)
|
|
1388
|
+
|
|
1389
|
+
# Precompute data-only quantities once before optimisation loop
|
|
1390
|
+
prepared_data = extract_streamline.prepare_data(data, uncertainties, n_elements=len(data[0]))
|
|
1391
|
+
# npoints for forward model evaluation: fixed large number set by max r0 bound and deltar
|
|
1392
|
+
# this is necessary to ensure the forward model has constant array lengths for jax/jit compatibility
|
|
1393
|
+
if 'r0' in param_bounds:
|
|
1394
|
+
max_r0 = param_bounds['r0'][1]
|
|
1395
|
+
deltar = fixed_params['deltar'] if 'deltar' in fixed_params else 1.0
|
|
1396
|
+
npoints = int(jnp.ceil(max_r0 / deltar))
|
|
1397
|
+
else:
|
|
1398
|
+
npoints = 50000
|
|
1399
|
+
|
|
1400
|
+
fixed_params_for_core = fixed_params
|
|
1401
|
+
|
|
1402
|
+
@jax.jit
|
|
1403
|
+
def loss_from_normalised(norm_opt_params):
|
|
1404
|
+
physical_opt_params = denormalise_opt_params(norm_opt_params, normalisation_spec)
|
|
1405
|
+
model_params = {**fixed_params_for_core, **physical_opt_params}
|
|
1406
|
+
chi2_total, loss_trace, err = chi2_loss(
|
|
1407
|
+
model_params,
|
|
1408
|
+
distance_pc,
|
|
1409
|
+
prepared_data,
|
|
1410
|
+
loss_method=loss_method,
|
|
1411
|
+
npoints=npoints,
|
|
1412
|
+
priors_keys=priors_keys,
|
|
1413
|
+
priors_means=priors_means,
|
|
1414
|
+
priors_sigmas=priors_sigmas
|
|
1415
|
+
)
|
|
1416
|
+
return chi2_total, (loss_trace, err)
|
|
1417
|
+
|
|
1418
|
+
|
|
1419
|
+
# Create gradient functions in normalised space.
|
|
1420
|
+
loss_and_grad_fn = value_and_grad(loss_from_normalised, has_aux=True)
|
|
1421
|
+
|
|
1422
|
+
|
|
1423
|
+
# Track loss history
|
|
1424
|
+
loss_history = []
|
|
1425
|
+
initial_loss, (_, initial_err) = loss_from_normalised(opt_params_norm)
|
|
1426
|
+
|
|
1427
|
+
# raise any initial errors
|
|
1428
|
+
initial_error_message = get_checkify_error_message(initial_err)
|
|
1429
|
+
if initial_error_message is not None:
|
|
1430
|
+
raise ValueError(
|
|
1431
|
+
f"Initial loss computation failed with error: {initial_error_message}. "
|
|
1432
|
+
)
|
|
1433
|
+
|
|
1434
|
+
initial_loss = float(initial_loss)
|
|
1435
|
+
loss_history.append(initial_loss) # 'epoch 0' loss (initial state, before any updates)
|
|
1436
|
+
best_loss = initial_loss
|
|
1437
|
+
best_opt_params = opt_params.copy()
|
|
1438
|
+
best_opt_params_norm = opt_params_norm.copy()
|
|
1439
|
+
best_epoch = 0
|
|
1440
|
+
patience_counter = 0
|
|
1441
|
+
loss_threshold_counter = 0
|
|
1442
|
+
gradient_tol_counter = 0
|
|
1443
|
+
ordered_best_opt_params = {k: best_opt_params[k] for k in opt_param_keys}
|
|
1444
|
+
|
|
1445
|
+
if loss_threshold is not None:
|
|
1446
|
+
loss_threshold = float(loss_threshold)
|
|
1447
|
+
if not math.isfinite(loss_threshold):
|
|
1448
|
+
raise ValueError('loss_threshold must be finite when provided.')
|
|
1449
|
+
if loss_threshold_epochs < 1:
|
|
1450
|
+
raise ValueError('loss_threshold_epochs must be >= 1 when loss_threshold is provided.')
|
|
1451
|
+
|
|
1452
|
+
if gradient_tol is not None:
|
|
1453
|
+
gradient_tol = float(gradient_tol)
|
|
1454
|
+
if not math.isfinite(gradient_tol):
|
|
1455
|
+
raise ValueError('gradient_tol must be finite when provided.')
|
|
1456
|
+
if gradient_tol <= 0:
|
|
1457
|
+
raise ValueError('gradient_tol must be positive when provided.')
|
|
1458
|
+
if gradient_tol_epochs < 1:
|
|
1459
|
+
raise ValueError('gradient_tol_epochs must be >= 1 when gradient_tol is provided.')
|
|
1460
|
+
|
|
1461
|
+
# initialise log and trace files if output_folder is provided
|
|
1462
|
+
log_file = None
|
|
1463
|
+
log_writer = None
|
|
1464
|
+
trace_file = None
|
|
1465
|
+
trace_writer = None
|
|
1466
|
+
if save_folder is not None:
|
|
1467
|
+
os.makedirs(save_folder, exist_ok=True)
|
|
1468
|
+
|
|
1469
|
+
log_file = os.path.join(save_folder, 'optimisation_log.csv')
|
|
1470
|
+
log_file = open(log_file, 'w', newline='')
|
|
1471
|
+
# Create header: epoch, loss, then all optimisable params
|
|
1472
|
+
fieldnames = ['epoch', 'loss'] + [log_header(k) for k in opt_param_keys]
|
|
1473
|
+
|
|
1474
|
+
all_param_keys = set(opt_param_keys) | set(fixed_params.keys())
|
|
1475
|
+
if 'mu' in all_param_keys:
|
|
1476
|
+
# also log derived rc and omega when mu is present for convenience
|
|
1477
|
+
if log_header('rc') not in fieldnames:
|
|
1478
|
+
fieldnames.append(log_header('rc'))
|
|
1479
|
+
if log_header('omega') not in fieldnames:
|
|
1480
|
+
fieldnames.append(log_header('omega'))
|
|
1481
|
+
log_writer = csv.DictWriter(log_file, fieldnames=fieldnames)
|
|
1482
|
+
log_writer.writeheader()
|
|
1483
|
+
log_file.flush()
|
|
1484
|
+
|
|
1485
|
+
trace_file = os.path.join(save_folder, 'optimisation_trace.csv')
|
|
1486
|
+
trace_file = open(trace_file, 'w', newline='')
|
|
1487
|
+
trace_writer = csv.DictWriter(
|
|
1488
|
+
trace_file,
|
|
1489
|
+
fieldnames=trace_fieldnames_for_loss_method(loss_method),
|
|
1490
|
+
)
|
|
1491
|
+
trace_writer.writeheader()
|
|
1492
|
+
trace_file.flush()
|
|
1493
|
+
|
|
1494
|
+
print(f"Starting optimisation with {n_epochs} epochs...")
|
|
1495
|
+
print(f"Loss method: {loss_method}")
|
|
1496
|
+
print(f"optimising parameters: {opt_param_keys}")
|
|
1497
|
+
print(f"Fixed parameters: {list(fixed_params.keys())}")
|
|
1498
|
+
if validated_priors:
|
|
1499
|
+
print(f"Priors:")
|
|
1500
|
+
for key, (mu_p, sigma_p) in validated_priors.items():
|
|
1501
|
+
print(f" {key}: mean={format_param(key, mu_p)}, sigma={format_param(key, sigma_p)}")
|
|
1502
|
+
if loss_threshold is not None:
|
|
1503
|
+
print(
|
|
1504
|
+
f"Threshold-based stopping enabled: loss <= {loss_threshold:.6g} "
|
|
1505
|
+
f"for {loss_threshold_epochs} consecutive epochs."
|
|
1506
|
+
)
|
|
1507
|
+
if gradient_tol is not None:
|
|
1508
|
+
print(
|
|
1509
|
+
f"Gradient norm stopping enabled (normalised space): ||grad|| < {gradient_tol:.6g} "
|
|
1510
|
+
f"for {gradient_tol_epochs} consecutive epochs."
|
|
1511
|
+
)
|
|
1512
|
+
print(f"Initial optimisable values:")
|
|
1513
|
+
for key in opt_param_keys:
|
|
1514
|
+
print(f" {key}: {format_param(key, opt_params[key])}")
|
|
1515
|
+
print(f"Initial loss: {initial_loss:.6g}")
|
|
1516
|
+
|
|
1517
|
+
# Log epoch 0: initial state (before any updates)
|
|
1518
|
+
initial_loss = float(initial_loss)
|
|
1519
|
+
if log_writer is not None:
|
|
1520
|
+
row = {'epoch': 0, 'loss': initial_loss}
|
|
1521
|
+
for key in opt_param_keys:
|
|
1522
|
+
row[log_header(key)] = float(opt_params[key])
|
|
1523
|
+
row = add_rc_omega_to_log(row, opt_params, fixed_params, all_param_keys)
|
|
1524
|
+
if 'rc' in row:
|
|
1525
|
+
row[log_header('rc')] = row.pop('rc')
|
|
1526
|
+
if 'omega' in row:
|
|
1527
|
+
row[log_header('omega')] = row.pop('omega')
|
|
1528
|
+
log_writer.writerow(row)
|
|
1529
|
+
log_file.flush()
|
|
1530
|
+
|
|
1531
|
+
# Log epoch 0 trace if trace file is requested
|
|
1532
|
+
if trace_writer is not None:
|
|
1533
|
+
# Compute initial loss and trace
|
|
1534
|
+
(loss_value_trace, (loss_trace_raw, _)), norm_grads_trace = loss_and_grad_fn(opt_params_norm)
|
|
1535
|
+
loss_trace = trace_tree_to_python(loss_trace_raw)
|
|
1536
|
+
grad_norm = float(gradient_l2_norm(norm_grads_trace))
|
|
1537
|
+
|
|
1538
|
+
# Build and write trace row for epoch 0
|
|
1539
|
+
trace_row = build_trace_row(0, float(loss_value_trace), loss_trace, grad_norm, loss_method)
|
|
1540
|
+
trace_writer.writerow(trace_row)
|
|
1541
|
+
trace_file.flush()
|
|
1542
|
+
|
|
1543
|
+
|
|
1544
|
+
try:
|
|
1545
|
+
for epoch in range(1, n_epochs + 1):
|
|
1546
|
+
if epoch % info_every == 0:
|
|
1547
|
+
print(f"\n Starting Epoch {epoch} -------------------------")
|
|
1548
|
+
# Compute loss and gradients at pre-update normalised parameters.
|
|
1549
|
+
loss_trace = None
|
|
1550
|
+
(loss_before, _), norm_grads = loss_and_grad_fn(opt_params_norm)
|
|
1551
|
+
|
|
1552
|
+
|
|
1553
|
+
# Perform Optax Adam step in normalised space (apply update).
|
|
1554
|
+
updates, opt_state = solver.update(norm_grads, opt_state, params=opt_params_norm)
|
|
1555
|
+
opt_params_norm = optax.apply_updates(opt_params_norm, updates)
|
|
1556
|
+
|
|
1557
|
+
# Enforce normalised bounds and map back to physical/log values.
|
|
1558
|
+
for key in opt_param_keys:
|
|
1559
|
+
if key == 'phi0':
|
|
1560
|
+
# phi0 is cyclic; wrap to [0, 1) in normalised space
|
|
1561
|
+
opt_params_norm[key] = jnp.mod(opt_params_norm[key], 1.0)
|
|
1562
|
+
elif key == 'rc':
|
|
1563
|
+
# must be positive >0
|
|
1564
|
+
opt_params_norm[key] = jnp.clip(opt_params_norm[key], to_float64(1e-6), 1.0)
|
|
1565
|
+
elif key == 'omega':
|
|
1566
|
+
# must be positive >0
|
|
1567
|
+
opt_params_norm[key] = jnp.clip(opt_params_norm[key], to_float64(1e-6), 1.0)
|
|
1568
|
+
elif key == 'v_r0':
|
|
1569
|
+
#already dealt with
|
|
1570
|
+
continue
|
|
1571
|
+
else:
|
|
1572
|
+
opt_params_norm[key] = jnp.clip(opt_params_norm[key], 0.0, 1.0)
|
|
1573
|
+
|
|
1574
|
+
# Now materialize physical parameters from the (possibly clamped)
|
|
1575
|
+
# normalised parameters.
|
|
1576
|
+
opt_params = denormalise_opt_params(opt_params_norm, normalisation_spec)
|
|
1577
|
+
|
|
1578
|
+
# Compute loss and gradient at the post-update state S(epoch) for logging
|
|
1579
|
+
(loss_value, (loss_trace_raw, err)), norm_grads = loss_and_grad_fn(opt_params_norm)
|
|
1580
|
+
# L2 norm of the normalised-space gradient, used for trace logging and gradient-tolerance stopping
|
|
1581
|
+
grad_norm = float(gradient_l2_norm(norm_grads))
|
|
1582
|
+
|
|
1583
|
+
# raise any errors
|
|
1584
|
+
error_message = get_checkify_error_message(err)
|
|
1585
|
+
if error_message is not None:
|
|
1586
|
+
print(
|
|
1587
|
+
f"\nStopping at epoch {epoch}: {error_message} "
|
|
1588
|
+
)
|
|
1589
|
+
break
|
|
1590
|
+
|
|
1591
|
+
|
|
1592
|
+
loss_trace = trace_tree_to_python(loss_trace_raw)
|
|
1593
|
+
loss_value = float(loss_value)
|
|
1594
|
+
|
|
1595
|
+
|
|
1596
|
+
# Log post-update state for this epoch
|
|
1597
|
+
if log_writer is not None:
|
|
1598
|
+
row = {'epoch': epoch, 'loss': loss_value}
|
|
1599
|
+
for key in opt_param_keys:
|
|
1600
|
+
row[log_header(key)] = float(opt_params[key])
|
|
1601
|
+
|
|
1602
|
+
row = add_rc_omega_to_log(row, opt_params, fixed_params, all_param_keys)
|
|
1603
|
+
|
|
1604
|
+
if 'rc' in row:
|
|
1605
|
+
row[log_header('rc')] = row.pop('rc')
|
|
1606
|
+
if 'omega' in row:
|
|
1607
|
+
row[log_header('omega')] = row.pop('omega')
|
|
1608
|
+
log_writer.writerow(row)
|
|
1609
|
+
log_file.flush()
|
|
1610
|
+
|
|
1611
|
+
# Track loss (store the loss for the loss_history)
|
|
1612
|
+
loss_history.append(loss_value)
|
|
1613
|
+
|
|
1614
|
+
if trace_writer is not None and loss_trace is not None:
|
|
1615
|
+
# Use the post-update loss value for trace logging
|
|
1616
|
+
trace_row = build_trace_row(epoch, loss_value, loss_trace, grad_norm, loss_method)
|
|
1617
|
+
trace_writer.writerow(trace_row)
|
|
1618
|
+
trace_file.flush()
|
|
1619
|
+
|
|
1620
|
+
# Early stopping checks (use post-update loss)
|
|
1621
|
+
if loss_value < best_loss:
|
|
1622
|
+
best_loss = loss_value
|
|
1623
|
+
best_opt_params = opt_params.copy()
|
|
1624
|
+
best_opt_params_norm = opt_params_norm.copy()
|
|
1625
|
+
best_epoch = epoch
|
|
1626
|
+
patience_counter = 0
|
|
1627
|
+
else:
|
|
1628
|
+
patience_counter += 1
|
|
1629
|
+
|
|
1630
|
+
if loss_threshold is not None:
|
|
1631
|
+
if loss_value <= loss_threshold:
|
|
1632
|
+
loss_threshold_counter += 1
|
|
1633
|
+
else:
|
|
1634
|
+
loss_threshold_counter = 0
|
|
1635
|
+
|
|
1636
|
+
if gradient_tol is not None:
|
|
1637
|
+
if grad_norm < gradient_tol:
|
|
1638
|
+
gradient_tol_counter += 1
|
|
1639
|
+
else:
|
|
1640
|
+
gradient_tol_counter = 0
|
|
1641
|
+
|
|
1642
|
+
# Print progress
|
|
1643
|
+
if epoch % info_every == 0:
|
|
1644
|
+
if gradient_tol is not None:
|
|
1645
|
+
print(f'Epoch {epoch}/{n_epochs}, Loss: {loss_value:.6f}, Best Loss: {best_loss:.6f}, ||grad||: {grad_norm:.6e}')
|
|
1646
|
+
else:
|
|
1647
|
+
print(f'Epoch {epoch}/{n_epochs}, Loss: {loss_value:.6f}, Best Loss: {best_loss:.6f}')
|
|
1648
|
+
|
|
1649
|
+
# Early stopping conditions (any one is sufficient to stop)
|
|
1650
|
+
if loss_threshold is not None and loss_threshold_counter >= loss_threshold_epochs:
|
|
1651
|
+
print(
|
|
1652
|
+
f"\nEarly stopping at epoch {epoch}: loss <= {loss_threshold:.6g} "
|
|
1653
|
+
f"for {loss_threshold_epochs} consecutive epochs"
|
|
1654
|
+
)
|
|
1655
|
+
break
|
|
1656
|
+
|
|
1657
|
+
if gradient_tol is not None and gradient_tol_counter >= gradient_tol_epochs:
|
|
1658
|
+
print(
|
|
1659
|
+
f"\nEarly stopping at epoch {epoch}: normalised gradient norm {grad_norm:.6e} < {gradient_tol:.6e} "
|
|
1660
|
+
f"for {gradient_tol_epochs} consecutive epochs"
|
|
1661
|
+
)
|
|
1662
|
+
break
|
|
1663
|
+
|
|
1664
|
+
if patience_counter >= early_stopping_patience:
|
|
1665
|
+
print(f"\nEarly stopping at epoch {epoch}: no improvement for {early_stopping_patience} epochs")
|
|
1666
|
+
break
|
|
1667
|
+
|
|
1668
|
+
|
|
1669
|
+
# restore canonical parameter order before returning
|
|
1670
|
+
ordered_best_opt_params = {k: best_opt_params[k] for k in opt_param_keys}
|
|
1671
|
+
|
|
1672
|
+
finally:
|
|
1673
|
+
# Always close the CSV file if it was opened
|
|
1674
|
+
if log_file is not None:
|
|
1675
|
+
log_path = log_file.name
|
|
1676
|
+
log_file.close()
|
|
1677
|
+
print(f"Optimisation log saved to: {log_path}")
|
|
1678
|
+
if trace_file is not None:
|
|
1679
|
+
trace_path = trace_file.name
|
|
1680
|
+
trace_file.close()
|
|
1681
|
+
print(f"Matching trace log saved to: {trace_path}")
|
|
1682
|
+
|
|
1683
|
+
print(f"Optimisation complete!")
|
|
1684
|
+
print(f"Best-fit parameters found at epoch: {best_epoch}, with loss: {best_loss:.6f}")
|
|
1685
|
+
|
|
1686
|
+
# compute errors on best-fit parameters
|
|
1687
|
+
print("\nEstimating parameter uncertainties from Hessian...")
|
|
1688
|
+
param_errors = None
|
|
1689
|
+
cov_matrix = None
|
|
1690
|
+
cov_transformed_dict = None
|
|
1691
|
+
# only pass the rotation key if it actually needs transforming back from mu
|
|
1692
|
+
key_needs_transform = rotation_key if rotation_key in ('rc', 'omega') else None
|
|
1693
|
+
try:
|
|
1694
|
+
param_errors, cov_matrix, cov_transformed_dict = errors.estimate_parameter_errors(
|
|
1695
|
+
ordered_best_opt_params,
|
|
1696
|
+
fixed_params,
|
|
1697
|
+
distance_pc,
|
|
1698
|
+
prepared_data,
|
|
1699
|
+
loss_method=loss_method,
|
|
1700
|
+
gradient_tol=gradient_tol,
|
|
1701
|
+
normalisation_spec=normalisation_spec,
|
|
1702
|
+
best_norm_opt_params=best_opt_params_norm,
|
|
1703
|
+
rotation_key=key_needs_transform,
|
|
1704
|
+
npoints=npoints,
|
|
1705
|
+
priors_keys=priors_keys,
|
|
1706
|
+
priors_means=priors_means,
|
|
1707
|
+
priors_sigmas=priors_sigmas
|
|
1708
|
+
)
|
|
1709
|
+
except Exception as e:
|
|
1710
|
+
print(f"\nWarning: parameter uncertainty estimation failed: ({e}).")
|
|
1711
|
+
traceback.print_exc()
|
|
1712
|
+
print("Continuing without error estimates")
|
|
1713
|
+
|
|
1714
|
+
display_opt_params = dict(ordered_best_opt_params)
|
|
1715
|
+
display_fixed_params = dict(fixed_params)
|
|
1716
|
+
display_param_errors = dict(param_errors) if param_errors is not None else None
|
|
1717
|
+
|
|
1718
|
+
if cov_transformed_dict is not None and key_needs_transform is not None and display_param_errors is not None:
|
|
1719
|
+
if key_needs_transform in cov_transformed_dict['keys']:
|
|
1720
|
+
all_params_for_transform = {**display_fixed_params, **display_opt_params}
|
|
1721
|
+
mu_best = float(ordered_best_opt_params['mu'])
|
|
1722
|
+
mass_val = float(all_params_for_transform['mass'])
|
|
1723
|
+
r0_val = float(all_params_for_transform['r0'])
|
|
1724
|
+
display_opt_params[key_needs_transform] = rotation_param_from_mu(key_needs_transform, mu_best, mass_val, r0_val)
|
|
1725
|
+
display_opt_params.pop('mu', None)
|
|
1726
|
+
display_param_errors[key_needs_transform] = cov_transformed_dict['errors'][key_needs_transform]
|
|
1727
|
+
display_param_errors.pop('mu', None)
|
|
1728
|
+
|
|
1729
|
+
print("\nFinal parameters at best-fit:")
|
|
1730
|
+
all_display_params = {**display_fixed_params, **display_opt_params}
|
|
1731
|
+
for key in all_display_params.keys():
|
|
1732
|
+
value = all_display_params[key]
|
|
1733
|
+
if display_param_errors is not None and key in display_param_errors:
|
|
1734
|
+
error = display_param_errors[key]
|
|
1735
|
+
print(f" {key}: {format_param(key, value)} ± {format_param(key, error)}")
|
|
1736
|
+
else:
|
|
1737
|
+
print(f" {key}: {format_param(key, value)}")
|
|
1738
|
+
|
|
1739
|
+
outputs.save_best_fit_params(display_opt_params, display_fixed_params, display_param_errors, save_folder=save_folder)
|
|
1740
|
+
|
|
1741
|
+
# now we will make some plots of the results
|
|
1742
|
+
if save_folder is not None:
|
|
1743
|
+
print("\nMaking diagnostic plots...")
|
|
1744
|
+
outputs.plot_fitting_results(
|
|
1745
|
+
ordered_best_opt_params,
|
|
1746
|
+
opt_param_keys,
|
|
1747
|
+
fixed_params,
|
|
1748
|
+
streamer,
|
|
1749
|
+
distance_pc,
|
|
1750
|
+
loss_history,
|
|
1751
|
+
param_errors=param_errors,
|
|
1752
|
+
cov_matrix=cov_matrix,
|
|
1753
|
+
v_lsr=v_lsr,
|
|
1754
|
+
save_folder=save_folder,
|
|
1755
|
+
show_plots=show_plots,
|
|
1756
|
+
transformed_cov_result=cov_transformed_dict,
|
|
1757
|
+
)
|
|
1758
|
+
|
|
1759
|
+
# save results to CovarianceResult and FitResult namedtuples
|
|
1760
|
+
cov_result = None
|
|
1761
|
+
if cov_matrix is not None:
|
|
1762
|
+
cov_result = CovarianceResult(
|
|
1763
|
+
covariance=cov_matrix,
|
|
1764
|
+
opt_keys=list(ordered_best_opt_params.keys()),
|
|
1765
|
+
best_opt_params=ordered_best_opt_params,
|
|
1766
|
+
fixed_params=fixed_params,
|
|
1767
|
+
param_errors=param_errors,
|
|
1768
|
+
transformed_cov=cov_transformed_dict
|
|
1769
|
+
)
|
|
1770
|
+
|
|
1771
|
+
return FitResult(
|
|
1772
|
+
best_opt_params=display_opt_params,
|
|
1773
|
+
loss_history=loss_history,
|
|
1774
|
+
param_errors=display_param_errors,
|
|
1775
|
+
covariance_result=cov_result,
|
|
1776
|
+
)
|