sting 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sting/outputs.py ADDED
@@ -0,0 +1,1705 @@
1
+ '''
2
+ This file contains functions related to outputs from streamfit optimisation,
3
+ such as saving logs and plotting results.
4
+
5
+ Last updated: 03-06-26
6
+ '''
7
+ import json
8
+ import math
9
+ from matplotlib.patches import Patch
10
+ import numpy as np
11
+ import jax.numpy as jnp
12
+ import matplotlib.pyplot as plt
13
+ import matplotlib as mpl
14
+ mpl.rcParams["font.family"] = "serif"
15
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
16
+ import pandas as pd
17
+ import os
18
+
19
+
20
+ from . import gradient_descent
21
+ from . import extract_streamline
22
+
23
def param_for_display(key, value):
    """
    Convert a raw parameter into the form written to outputs.

    Keys listed in ``gradient_descent.ANGLE_KEYS`` are treated as angles in
    radians: their value is converted to degrees and labelled 'deg'. Every
    other key keeps its value (cast to float) and takes its unit string from
    ``gradient_descent.DISPLAY_UNITS``, or '' when no unit is registered.

    Returns
    -------
    tuple
        (display_key, display_value, unit_str). display_key is ``key`` unchanged.
    """
    if key not in gradient_descent.ANGLE_KEYS:
        unit_label = gradient_descent.DISPLAY_UNITS.get(key, '')
        return key, float(value), unit_label
    return key, math.degrees(float(value)), 'deg'
33
+
34
+
35
def evaluate_best_fit(
    best_opt_params,
    fixed_params,
    data,
    distance_pc,
    by_eye_params=None,
):
    """
    Evaluate the forward model at the best-fit parameters and match it to the data.

    The best-fit parameters are merged with the fixed ones through
    ``gradient_descent.prepare_model_params`` and run through
    ``gradient_descent.forward_model``. The resulting curve is matched to the
    observed RA/Dec positions with
    ``gradient_descent.checked_match_model_to_data_curve``. The error object
    returned by the best-fit forward-model run is not checked here.

    If ``by_eye_params`` is given, that parameter set is also run through the
    forward model (without matching), and its error object is thrown.

    Parameters
    ----------
    best_opt_params : dict
        Best-fit optimised parameters.
    fixed_params : dict
        Fixed model parameters.
    data : tuple of arrays (ra_data, dec_data, v_data)
        Observed streamline. Only RA and Dec are used for matching.
    distance_pc : float
        Distance passed to the forward model.
    by_eye_params : dict or None
        Optional by-eye parameter set.

    Returns
    -------
    dict with keys:
        ra_model, dec_model, v_model : full forward-model arrays
        ra_model_interp, dec_model_interp, v_model_interp : model matched to the data positions
        valid : retention mask returned by the matching step
        by_eye : (ra, dec, v) tuple of by-eye model arrays, or None
    """
    ra_obs, dec_obs, _ = data

    def _run(params):
        # Merge optimised + fixed params, then evaluate the forward model.
        merged, _, _ = gradient_descent.prepare_model_params(params, fixed_params)
        return gradient_descent.forward_model(merged, distance_pc)

    ra_fit, dec_fit, v_fit, fit_mask, _fit_err = _run(best_opt_params)
    matched = gradient_descent.checked_match_model_to_data_curve(
        ra_fit,
        dec_fit,
        v_fit,
        fit_mask.astype(bool),
        jnp.asarray(ra_obs, dtype=jnp.float64),
        jnp.asarray(dec_obs, dtype=jnp.float64),
    )
    ra_interp, dec_interp, v_interp, keep, _, _, _ = matched

    eye_curve = None
    if by_eye_params is not None:
        ra_eye, dec_eye, v_eye, _, eye_err = _run(by_eye_params)
        eye_err.throw()
        eye_curve = (ra_eye, dec_eye, v_eye)

    return {
        'ra_model': ra_fit,
        'dec_model': dec_fit,
        'v_model': v_fit,
        'ra_model_interp': ra_interp,
        'dec_model_interp': dec_interp,
        'v_model_interp': v_interp,
        'valid': keep,
        'by_eye': eye_curve,
    }
94
+
95
+
96
def plot_fitting_results(
    ordered_best_opt_params,
    opt_param_keys,
    fixed_params,
    streamer,
    distance_pc,
    loss_history,
    param_errors,
    cov_matrix,
    v_lsr,
    save_folder,
    show_plots=False,
    transformed_cov_result=None,
    by_eye_params=None,
):
    """
    Generate and save the following best-fit diagnostic plots to save_folder after optimisation:
        - loss_history.png : loss vs epoch
        - best_fit_morphology.png : RA/Dec best fit
        - best_fit_vel_radius.png : velocity-radius best fit
        - parameter_uncertainties.png : sizes of error bars for each optimised param (if param_errors given)
        - parameter_correlation_matrix.png : parameter correlation heatmap (if cov_matrix given)

    Parameters
    ----------
    ordered_best_opt_params : dict, best-fit optimised parameters
    opt_param_keys : list of str, list of optimised parameter names
    fixed_params : dict, fixed model parameters.
    streamer : NamedTuple with fields:
        pc_coords, ra_data, dec_data, v_data, ra_sigma, dec_sigma, v_sigma, data, uncertainties
    distance_pc : float
    loss_history : list of float, loss value at each epoch.
    param_errors : dict or None, 1-sigma parameter uncertainties keyed by parameter name,
        or None if uncertainty estimation failed
    cov_matrix : array or None, parameter covariance matrix, or None if uncertainty estimation failed.
    v_lsr : float or None, km/s
    save_folder : str, directory to write figures into.
    show_plots : bool, whether to display plots (in addition to saving). Default False
    transformed_cov_result : dict or None. Keys expected: 'keys', 'cov', 'errors'.
        When given, its keys/covariance/errors are plotted instead of
        opt_param_keys/cov_matrix/param_errors.
    by_eye_params : dict or None, optional by-eye parameter guess (with the same params as
        ordered_best_opt_params). If provided, it is plotted alongside the best-fit model
        in the morphology and velocity-radius plots.
    """
    # Loss curve
    plot_loss(loss_history, save_folder=save_folder, show=show_plots)

    # Evaluate the best fit (and optional by-eye model) once; reused by both plots below.
    best_fit = evaluate_best_fit(
        ordered_best_opt_params,
        fixed_params,
        streamer.data,
        distance_pc,
        by_eye_params=by_eye_params,
    )

    plot_morphology(
        streamer=streamer,
        ra_model=best_fit['ra_model'],
        dec_model=best_fit['dec_model'],
        ra_model_interp=best_fit['ra_model_interp'],
        dec_model_interp=best_fit['dec_model_interp'],
        valid=best_fit['valid'],
        by_eye=best_fit['by_eye'],
        save_folder=save_folder,
        save_name='best_fit_morphology',
        show=show_plots,
    )

    plot_vel_radius(
        streamer=streamer,
        ra_model=best_fit['ra_model'],
        dec_model=best_fit['dec_model'],
        v_model=best_fit['v_model'],
        ra_model_interp=best_fit['ra_model_interp'],
        dec_model_interp=best_fit['dec_model_interp'],
        v_model_interp=best_fit['v_model_interp'],
        valid=best_fit['valid'],
        by_eye=best_fit['by_eye'],
        velocity_reference=v_lsr,
        save_folder=save_folder,
        save_name='best_fit_vel_radius',
        show=show_plots,
    )

    # Uncertainty plots (only if error estimation succeeded)
    if param_errors is None or cov_matrix is None:
        return

    if transformed_cov_result is not None:
        plot_keys = transformed_cov_result['keys']      # 'mu' replaced by 'rc'/'omega'
        plot_cov = transformed_cov_result['cov']        # Jacobian-transformed covariance
        plot_errors = transformed_cov_result['errors']  # 'mu' error transformed to 'rc'/'omega' error
    else:
        plot_keys = opt_param_keys
        plot_cov = cov_matrix
        plot_errors = param_errors

    # Values must be indexed by the same keys as the errors/covariance being plotted,
    # otherwise a transformed key set (e.g. 'mu' -> 'rc'/'omega') misaligns values and
    # errors. Keys absent from the best-fit dict fall back to 0.0.
    param_vals = np.array(
        [float(ordered_best_opt_params.get(k, 0.0)) for k in plot_keys], dtype=float
    )
    param_errs = np.array([float(plot_errors[k]) for k in plot_keys], dtype=float)
    plot_param_uncertainties(plot_keys, param_vals, param_errs, save_folder=save_folder, show=show_plots)
    plot_param_correlations(plot_keys, plot_cov, save_folder=save_folder, show=show_plots)
189
+
190
+
191
def save_best_fit_params(best_opt_params, fixed_params, param_errors, save_folder='sting_results'):
    """
    Save the best-fit (lowest-loss) parameters to ``<save_folder>/best_fit_params.json``.

    The JSON has two sections:
    - 'optimised_parameters': {name: {'value', 'unit'[, 'sigma']}}. 'sigma' is added
      when ``param_errors`` is given and contains the key.
    - 'fixed_parameters': {name: {'value', 'unit'}}. Fixed params have no uncertainty;
      a None value is written as null.

    Values (and sigmas) go through the same display conversion as
    ``param_for_display``: angle parameters are converted from radians to degrees.

    Parameters
    ----------
    best_opt_params : dict
        Best-fit optimised parameters (raw units).
    fixed_params : dict
        Fixed parameters (raw units); values may be None.
    param_errors : dict or None
        1-sigma uncertainties keyed by raw parameter name.
    save_folder : str
        Output directory, created if missing. '' means the current directory.
    """
    output_path = os.path.join(save_folder, 'best_fit_params.json')
    output_dir = os.path.dirname(output_path)
    if output_dir:
        # os.makedirs('') raises, which would happen for save_folder='' (current dir).
        os.makedirs(output_dir, exist_ok=True)

    # Parameters that were optimised
    optimised_section = {}
    for raw_key, raw_val in best_opt_params.items():
        display_key, display_val, unit = param_for_display(raw_key, raw_val)
        entry = {
            'value': display_val,
            'unit': unit,
        }
        if param_errors is not None and raw_key in param_errors:
            # Errors get the same unit conversion as their values
            # (radians -> degrees for angles, unchanged otherwise).
            raw_err = float(param_errors[raw_key])
            if raw_key in gradient_descent.ANGLE_KEYS:
                entry['sigma'] = math.degrees(raw_err)
            else:
                entry['sigma'] = raw_err
        optimised_section[display_key] = entry

    # Parameters that were fixed (no uncertainties)
    fixed_section = {}
    for raw_key, raw_val in fixed_params.items():
        if raw_val is None:
            fixed_section[raw_key] = {
                'value': None,
                'unit': gradient_descent.DISPLAY_UNITS.get(raw_key, ''),
            }
            continue
        display_key, display_val, unit = param_for_display(raw_key, raw_val)
        fixed_section[display_key] = {
            'value': display_val,
            'unit': unit,
        }

    output = {
        'optimised_parameters': optimised_section,
        'fixed_parameters': fixed_section,
    }

    with open(output_path, 'w', encoding='utf-8') as file:
        json.dump(output, file, indent=4)
239
+
240
+
241
def _ensure_clean_dir(path):
    """
    Make sure ``path`` exists as a directory and delete the regular files in it.

    Sub-directories (and whatever they contain) are left untouched. Files that
    cannot be removed are skipped silently, so stale epoch frames are cleared
    on a best-effort basis before new ones are written.
    """
    os.makedirs(path, exist_ok=True)
    candidates = [os.path.join(path, name) for name in os.listdir(path)]
    for file_path in candidates:
        if not os.path.isfile(file_path):
            continue
        try:
            os.remove(file_path)
        except OSError:
            pass  # best-effort cleanup: leave files we cannot delete
254
+
255
def _opt_params_from_log(optimisation_log):
    """
    List the optimised-parameter names recorded in an optimisation log.

    The 'epoch' and 'loss' columns are ignored, and any ' [unit]' suffix is
    removed from the remaining column headers. When 'mu' is among the names,
    'rc' and 'omega' are dropped as well.
    """
    names = []
    for column in optimisation_log.columns:
        if column in ('epoch', 'loss'):
            continue
        names.append(column.split(' [')[0])
    if 'mu' not in names:
        return names
    return [name for name in names if name not in ('rc', 'omega')]
264
+
265
+
266
def create_video_from_images(save_folder, input_pattern, output_name, fps=5):
    """Call ffmpeg to make a video from numbered image frames.

    Parameters
    ----------
    save_folder : str
        Directory the video is written to.
    input_pattern : str
        printf-style frame pattern passed to ffmpeg's ``-i`` (e.g. ``frame_%03d.png``).
    output_name : str
        File name of the video inside ``save_folder``.
    fps : int
        Input frame rate.

    If ffmpeg is not on PATH, or ffmpeg fails or cannot be launched, a message
    is printed and the function returns None instead of raising.
    """
    import subprocess
    import shutil

    ffmpeg_exe = shutil.which("ffmpeg")
    if ffmpeg_exe is None:
        print(
            "ffmpeg not found, can't create video.\n"
            "To install: see https://ffmpeg.org/download.html and ensure ffmpeg is in your system PATH."
        )
        return

    output_video = os.path.join(save_folder, output_name)
    ffmpeg_cmd = [
        ffmpeg_exe,  # run exactly the binary located above
        "-y",
        "-loglevel", "error",
        "-framerate", str(fps),
        "-i", input_pattern,
        "-vf",
        # setpts rescales frame timestamps by 1/(1+0.01*N); pad makes width and
        # height even, which yuv420p output requires.
        "setpts='PTS/(1+0.01*N)',pad=ceil(iw/2)*2:ceil(ih/2)*2",
        "-pix_fmt", "yuv420p",
        output_video,
    ]
    try:
        subprocess.run(ffmpeg_cmd, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # CalledProcessError: ffmpeg ran but failed; OSError: it could not be launched.
        print(f"Error creating video: {e}")
        return
    print(f"Video saved to {output_video}")
298
+
299
def plot_loss(loss_history, save_folder='sting_results', show=False):
    '''Plot the loss against epoch on a logarithmic y axis.

    Epoch 0 is the initial state and epoch i (i >= 1) the state after update i,
    so loss_history[i] is drawn at x = i. Unless save_folder is None, the figure
    is saved as <save_folder>/loss_history.png (folder created if needed). The
    figure is then shown if ``show`` is True, and closed otherwise.
    '''
    # matplotlib serif font
    plt.rcParams['font.family'] = 'serif'
    fig, ax = plt.subplots(figsize=(12, 3))
    ax.plot(range(len(loss_history)), loss_history)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Optimisation Progress')
    ax.set_yscale('log')
    ax.grid(True, alpha=0.5)

    if save_folder is not None:
        os.makedirs(save_folder, exist_ok=True)
        fig.savefig(f'{save_folder}/loss_history.png', dpi=300, bbox_inches='tight')

    if not show:
        plt.close(fig)
        return
    plt.show()
322
+
323
+
324
def make_morphology_background(pc_coords, metric_boundaries, ra_lim, dec_lim, figsize=(6.5, 7)):
    """
    Render the static layer of the morphology plots to an RGBA image.

    The point cloud (pc_coords[0] vs pc_coords[1]) and, if given, the metric
    boundaries are drawn in translucent grey. They are drawn on an axes that
    fills the whole figure with axis decorations turned off, so the image
    covers exactly ra_lim x dec_lim. plot_morphology_by_epoch renders this once
    and reuses it for every epoch frame.

    Returns
    -------
    bg_rgba : ndarray, shape (H, W, 4)
        Copy of the rendered canvas at the figure's resolution.
    extent : list [left, right, bottom, top]
        [ra_lim[0], ra_lim[1], dec_lim[0], dec_lim[1]], to pass as ax.imshow(extent=...).
    """
    fig, ax = plt.subplots(figsize=figsize)
    cloud = np.asarray(pc_coords, dtype=float)
    grey = dict(color='gray', alpha=0.3)

    ax.scatter(cloud[0], cloud[1], s=1, **grey)
    if metric_boundaries is not None:
        extract_streamline.plot_metric_boundaries(ax, cloud, metric_boundaries, linewidth=1, **grey)

    ax.set_xlim(ra_lim)
    ax.set_ylim(dec_lim)
    ax.axis('off')
    # Remove all margins so the rendered pixels map 1:1 onto the data extent.
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)

    fig.canvas.draw()
    image = np.array(fig.canvas.buffer_rgba())  # np.array copies the canvas buffer
    plt.close(fig)
    return image, [ra_lim[0], ra_lim[1], dec_lim[0], dec_lim[1]]
352
+
353
def plot_morphology_by_epoch(
    gradient_descent,
    fixed_params,
    initial_opt_params,
    distance,
    streamer=None,
    n_points=None,
    save_folder="sting_results",
    make_video=False
):
    """
    Create and save one streamline morphology plot per optimisation epoch, in
    save_folder/epochs/morphology, and optionally compile them into a video.

    Each epoch's optimised parameter values are read from the optimisation log
    in ``save_folder`` and merged with the fixed parameters. The model is run
    and matched to the data. All frames share the same axis limits (model,
    data and point cloud extents, padded by 5%) and the same pre-rendered
    point-cloud / metric-boundary background.

    Parameters
    ----------
    gradient_descent : object providing sanitize_param_partition, forward_model
        and checked_match_model_to_data_curve (as used below).
    fixed_params, initial_opt_params : dict
        Parameter partition; only the cleaned fixed params are used.
    distance : float
        Distance passed to the forward model.
    streamer : required despite the None default; supplies ra_data, dec_data,
        pc_coords and the uncertainties used by plot_morphology.
    n_points : passed to extract_streamline.get_metric_partitions.
    save_folder : str
        Folder containing the optimisation log; frames are written beneath it.
    make_video : bool
        If True, compile the frames with ffmpeg.

    Raises
    ------
    ValueError
        If ``streamer`` is None (checked after the optimisation log is found).
    """
    try:
        optimisation_log = load_optimisation_log(save_folder)
    except FileNotFoundError:
        print(f"Error: Could not find 'optimisation_log.csv' in {save_folder}")
        return
    if streamer is None:
        # Previously this failed later with an AttributeError inside the epoch loop.
        raise ValueError("plot_morphology_by_epoch requires `streamer` (data and point cloud); got None")

    column_map = {c.split(' [')[0]: c for c in optimisation_log.columns}
    param_names = _opt_params_from_log(optimisation_log)
    # Only the cleaned fixed params are needed: each epoch's optimised values come from the log.
    fixed_params_clean, _ = gradient_descent.sanitize_param_partition(
        fixed_params, initial_opt_params, require_nonempty_opt=False
    )

    epochs = optimisation_log['epoch'].values

    # Create the models
    epoch_models = []
    for idx, epoch in enumerate(epochs):
        row = optimisation_log.iloc[idx]
        opt_params_epoch = {param: float(row[column_map[param]]) for param in param_names}
        model_params_epoch = {**fixed_params_clean, **opt_params_epoch}
        # The forward model's error object is not checked here.
        ra_model, dec_model, v_model, valid_mask_model, _ = gradient_descent.forward_model(
            model_params_epoch, distance
        )
        ra_model_interp, dec_model_interp, _, valid, model_keep, _, _ = (
            gradient_descent.checked_match_model_to_data_curve(
                ra_model,
                dec_model,
                v_model,
                valid_mask_model.astype(bool),
                streamer.ra_data,
                streamer.dec_data,
            )
        )

        epoch_models.append(
            dict(
                epoch=epoch,
                opt_params_epoch=opt_params_epoch,
                ra_model=np.asarray(ra_model),
                dec_model=np.asarray(dec_model),
                ra_model_interp=np.asarray(ra_model_interp),
                dec_model_interp=np.asarray(dec_model_interp),
                valid=np.asarray(valid),
                model_keep=np.asarray(model_keep),
            )
        )

    # Constant axis limits across all frames (finite values only)
    all_ra = np.concatenate([
        *[e['ra_model'] for e in epoch_models],
        np.asarray(streamer.ra_data),
        np.asarray(streamer.pc_coords[0]),
    ])
    all_dec = np.concatenate([
        *[e['dec_model'] for e in epoch_models],
        np.asarray(streamer.dec_data),
        np.asarray(streamer.pc_coords[1]),
    ])
    mask = np.isfinite(all_ra) & np.isfinite(all_dec)
    all_ra = all_ra[mask]
    all_dec = all_dec[mask]

    pad_ra = 0.05 * (all_ra.max() - all_ra.min())
    pad_dec = 0.05 * (all_dec.max() - all_dec.min())
    ra_lim = (all_ra.min() - pad_ra, all_ra.max() + pad_ra)
    dec_lim = (all_dec.min() - pad_dec, all_dec.max() + pad_dec)

    # Prepare clean output folder for epoch frames
    output_dir = os.path.join(save_folder, "epochs", "morphology")
    _ensure_clean_dir(output_dir)

    partitions = extract_streamline.get_metric_partitions(streamer.pc_coords, n_points)
    metric_boundaries, _ = extract_streamline.sample_metric_boundaries(streamer.pc_coords, partitions)

    # Pre-render the static background once (point cloud and metric boundaries)
    bg_rgba, bg_extent = make_morphology_background(streamer.pc_coords, metric_boundaries, ra_lim, dec_lim)

    # Plot and save each epoch
    for model in epoch_models:
        plot_morphology(
            ra_model=model["ra_model"],
            dec_model=model["dec_model"],
            streamer=streamer,
            ra_model_interp=model["ra_model_interp"],
            dec_model_interp=model["dec_model_interp"],
            valid=model["valid"],
            bg_rgba=bg_rgba,
            bg_extent=bg_extent,
            title=f"Epoch: {int(model['epoch'])}",
            xlim=ra_lim,
            ylim=dec_lim,
            save_folder=output_dir,
            save_name=f"morphology_epoch_{int(model['epoch']):03d}",
            show=False
        )

    if make_video:
        input_pattern = os.path.join(output_dir, "morphology_epoch_%03d.png")
        create_video_from_images(output_dir, input_pattern, "streamline_morphology_evolution.mp4", fps=5)
463
+
464
def _finite_padded_limits(arrays, pad_frac=0.05):
    """
    Return (lo, hi) spanning the finite values of ``arrays``, padded by pad_frac of the range.

    None entries are skipped and NaN/inf values are dropped, so non-finite model
    points cannot produce NaN axis limits. Returns None if no finite value exists.
    """
    chunks = [np.ravel(np.asarray(a, dtype=float)) for a in arrays if a is not None]
    if not chunks:
        return None
    values = np.concatenate(chunks)
    values = values[np.isfinite(values)]
    if values.size == 0:
        return None
    lo, hi = values.min(), values.max()
    pad = pad_frac * (hi - lo)
    return (lo - pad, hi + pad)


def plot_morphology(
    ra_model=None,
    dec_model=None,
    streamer=None,
    ra_model_interp=None,
    dec_model_interp=None,
    valid=None,
    by_eye=None,
    metric_boundaries=None,
    bg_rgba=None,
    bg_extent=None,
    title=None,
    xlim=None,
    ylim=None,
    legend_loc='lower right',
    save_folder='sting_results',
    save_name='streamline_morphology',
    show=True,
):
    '''Plot offsets in RA/Dec. Each layer is drawn only when its inputs are given.

    Layers:
    - background: a pre-rendered RGBA image (bg_rgba + bg_extent). Otherwise the
      streamer's point cloud is drawn live, plus metric_boundaries if given.
    - ra_model/dec_model: model curve.
    - ra_model_interp/dec_model_interp + valid: model points at the data positions.
    - by_eye: (ra, dec, v) by-eye model curve.
    - streamer data points, with error bars when sigmas are available.
    - the star, always drawn at (0, 0).

    For a single plot, pass ``streamer`` (its pc_coords provide the point cloud) and,
    optionally, metric_boundaries. For per-epoch plotting use plot_morphology_by_epoch,
    which calls this function with a pre-rendered background for speed.

    When xlim/ylim are not given, limits span the finite values of the model, data,
    point cloud and star, padded by 5%. The RA axis is inverted. The figure is saved
    to <save_folder>/<save_name>.png unless save_folder is None. It is then shown
    if ``show`` is True, and closed otherwise.
    '''
    ra_data = dec_data = ra_sigma = dec_sigma = pc_coords = None
    if streamer is not None:
        ra_data = streamer.ra_data
        dec_data = streamer.dec_data
        ra_sigma = streamer.ra_sigma
        dec_sigma = streamer.dec_sigma
        pc_coords = streamer.pc_coords

    fig, ax = plt.subplots(figsize=(6.5, 7))
    if valid is not None:
        valid = np.asarray(valid, dtype=bool)

    # Static background: prefer pre-rendered image, fall back to live drawing
    if bg_rgba is not None and bg_extent is not None:
        ax.imshow(
            bg_rgba,
            extent=bg_extent,
            aspect='auto',
            origin='upper',
            zorder=1,
        )
    elif pc_coords is not None:
        pc_coords_np = np.asarray(pc_coords, dtype=float)
        ax.scatter(pc_coords_np[0], pc_coords_np[1], s=1, color='gray',
                   alpha=0.3, label='Point cloud', zorder=4)
        if metric_boundaries is not None:
            # Restore the point-cloud limits after drawing the boundaries.
            ax_limits = ax.get_xlim(), ax.get_ylim()
            extract_streamline.plot_metric_boundaries(
                ax, pc_coords_np, metric_boundaries,
                color='gray', linewidth=1, alpha=0.3,
            )
            ax.set_xlim(ax_limits[0])
            ax.set_ylim(ax_limits[1])

    # Model curve if given
    if ra_model is not None and dec_model is not None:
        ax.plot(ra_model, dec_model, color='blue', linewidth=2, label='STING', zorder=7)

    # Model points at the data positions if given (converted to numpy before masking)
    if ra_model_interp is not None and dec_model_interp is not None and valid is not None:
        ax.scatter(
            np.asarray(ra_model_interp)[valid],
            np.asarray(dec_model_interp)[valid],
            s=25,
            color='blue',
            zorder=7,
        )

    # By-eye model if given
    if by_eye is not None:
        ra_by_eye, dec_by_eye, _ = by_eye
        ax.plot(
            np.asarray(ra_by_eye, dtype=float),
            np.asarray(dec_by_eye, dtype=float),
            color='tab:green',
            linewidth=2,
            label='By-eye',
            zorder=8,
        )

    # Extracted streamline data points if given
    if ra_data is not None and dec_data is not None:
        if ra_sigma is not None and dec_sigma is not None:
            ax.errorbar(
                ra_data,
                dec_data,
                xerr=ra_sigma,
                yerr=dec_sigma,
                fmt='o-',
                label='Extracted 1D Streamline',
                color='red',
                zorder=5,
            )
        else:
            ax.plot(
                ra_data,
                dec_data,
                'o-',
                label='Extracted 1D Streamline',
                color='red',
                zorder=5,
            )

    star_ra = 0
    star_dec = 0
    ax.scatter(star_ra, star_dec, marker='*', s=100, color='yellow', edgecolor='black', zorder=10)
    ax.set_xlabel('RA Offset (arcsec)')
    ax.set_ylabel('Dec Offset (arcsec)')

    if xlim is None:
        xlim = _finite_padded_limits(
            [ra_model, ra_data, None if pc_coords is None else pc_coords[0], [star_ra]]
        )
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is None:
        ylim = _finite_padded_limits(
            [dec_model, dec_data, None if pc_coords is None else pc_coords[1], [star_dec]]
        )
    if ylim is not None:
        ax.set_ylim(ylim)

    ax.invert_xaxis()
    ax.set_title(title)
    # Avoid an empty legend (and matplotlib's warning) when nothing is labelled.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(loc=legend_loc)
    if save_folder is not None:
        os.makedirs(save_folder, exist_ok=True)
        fig.savefig(os.path.join(save_folder, f'{save_name}.png'), dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close(fig)
613
+
614
def plot_ra_vel_by_epoch(
    gradient_descent,
    fixed_params,
    initial_opt_params,
    distance,
    streamer=None,
    save_folder="sting_results",
    make_video=False,
):
    """
    Create RA–velocity plots for every epoch in save_folder/epochs/ra_vel,
    and optionally compile them into a video.

    Each epoch's optimised parameter values are read from the optimisation log
    in ``save_folder`` and merged with the fixed parameters. The model is run
    and matched to the data. Velocity and RA limits (min/max over all epochs'
    models and the data) are shared by every frame.

    Parameters
    ----------
    gradient_descent : object providing sanitize_param_partition, forward_model
        and checked_match_model_to_data_curve (as used below).
    fixed_params, initial_opt_params : dict
        Parameter partition; only the cleaned fixed params are used.
    distance : float
        Distance passed to the forward model.
    streamer : required despite the None default; supplies ra_data, dec_data, v_data.
    save_folder : str
        Folder containing the optimisation log; frames are written beneath it.
    make_video : bool
        If True, compile the frames with ffmpeg.

    Raises
    ------
    ValueError
        If ``streamer`` is None (checked after the optimisation log is found).
    """
    try:
        optimisation_log = load_optimisation_log(save_folder)
    except FileNotFoundError:
        print(f"Error: Could not find 'optimisation_log.csv' in {save_folder}")
        return
    if streamer is None:
        # Previously this failed with an AttributeError inside the epoch loop.
        raise ValueError("plot_ra_vel_by_epoch requires `streamer` (data points); got None")

    column_map = {c.split(' [')[0]: c for c in optimisation_log.columns}
    param_names = _opt_params_from_log(optimisation_log)
    # Only the cleaned fixed params are needed: each epoch's optimised values come from the log.
    fixed_params_clean, _ = gradient_descent.sanitize_param_partition(
        fixed_params, initial_opt_params, require_nonempty_opt=False
    )

    epochs = optimisation_log['epoch'].values

    # Make models
    epoch_models = []
    for idx, epoch in enumerate(epochs):
        row = optimisation_log.iloc[idx]
        opt_params_epoch = {param: float(row[column_map[param]]) for param in param_names}
        model_params_epoch = {**fixed_params_clean, **opt_params_epoch}
        # The forward model's error object is not checked here.
        ra_model, dec_model, v_model, valid_mask_model, _ = gradient_descent.forward_model(
            model_params_epoch, distance
        )
        ra_model_interp, _, v_model_interp, valid, model_keep, _, _ = (
            gradient_descent.checked_match_model_to_data_curve(
                ra_model, dec_model, v_model, valid_mask_model.astype(bool),
                streamer.ra_data, streamer.dec_data,
            )
        )

        epoch_models.append({
            "epoch": epoch,
            "ra_model": ra_model,
            "v_model": v_model,
            "ra_model_interp": ra_model_interp,
            "v_model_interp": v_model_interp,
            "valid": valid,
            "model_keep": model_keep,
        })

    # Global velocity and RA limits over all epochs plus the data (NaNs ignored)
    all_v = np.concatenate([m["v_model"] for m in epoch_models] + [streamer.v_data])
    vlim = (np.nanmin(all_v), np.nanmax(all_v))
    all_ra = np.concatenate([m["ra_model"] for m in epoch_models] + [streamer.ra_data])
    ralim = (np.nanmin(all_ra), np.nanmax(all_ra))

    # Make clean output folder
    output_dir = os.path.join(save_folder, "epochs", "ra_vel")
    _ensure_clean_dir(output_dir)

    # Make the plots
    for model in epoch_models:
        plot_ra_vel(
            ra_model=model["ra_model"],
            v_model=model["v_model"],
            streamer=streamer,
            ra_model_interp=model["ra_model_interp"],
            v_model_interp=model["v_model_interp"],
            valid=model["valid"],
            model_keep=model["model_keep"],
            title=f"Epoch: {int(model['epoch'])}",
            vlim=vlim,
            ralim=ralim,
            save_folder=output_dir,
            save_name=f"ra_vel_epoch_{int(model['epoch']):03d}",
        )

    if make_video:
        input_pattern = os.path.join(output_dir, "ra_vel_epoch_%03d.png")
        create_video_from_images(output_dir, input_pattern, "streamline_ra_vel_evolution.mp4", fps=5)
700
+
701
+
702
def plot_ra_vel(
    ra_model,
    v_model,
    *,
    streamer=None,
    ra_model_interp=None,
    v_model_interp=None,
    valid=None,
    model_keep=None,
    title=None,
    vlim=None,
    ralim=None,
    legend_loc='lower right',
    save_folder='sting_results',
    save_name='streamline_ra_vel',
    show=False,
):
    """
    Plot velocity against RA offset for the streamline model and data.

    Each layer is drawn only when its inputs are given:
    - the streamer's point cloud (pc_coords[0] vs pc_coords[2]);
    - the data points, with error bars when ra_sigma/v_sigma are available;
    - the model curve (ra_model, v_model);
    - the model at the data positions (ra_model_interp, v_model_interp), masked by ``valid``.

    ``model_keep`` is accepted for call compatibility but is not used.
    vlim/ralim fix the y/x limits. The RA axis is inverted. The figure is
    saved to <save_folder>/<save_name>.png unless save_folder is None. It is
    then shown if ``show`` is True, and closed otherwise.
    """
    ra_data = v_data = ra_sigma = v_sigma = pc_coords = None
    if streamer is not None:
        ra_data = streamer.ra_data
        v_data = streamer.v_data
        ra_sigma = streamer.ra_sigma
        v_sigma = streamer.v_sigma
        pc_coords = streamer.pc_coords

    if valid is not None:
        valid = np.asarray(valid, dtype=bool)
    fig, ax = plt.subplots(figsize=(6, 5))

    # Point cloud (RA vs velocity)
    if pc_coords is not None:
        ax.scatter(pc_coords[0], pc_coords[2], s=1, alpha=0.3, color='grey', label='Point cloud')

    # Data
    if ra_data is not None and v_data is not None:
        if ra_sigma is not None and v_sigma is not None:
            ax.errorbar(ra_data, v_data, xerr=ra_sigma, yerr=v_sigma, fmt='o', color='red',
                        ecolor='red', ms=4, alpha=0.9, label='Data')
        else:
            ax.plot(ra_data, v_data, 'o', color='red', label='Data')

    # Model curve. Check for None before converting: np.asarray(None, dtype=float) is a
    # NaN scalar, which would otherwise add a spurious 'Model Streamline' legend entry.
    if ra_model is not None and v_model is not None:
        ax.plot(np.asarray(ra_model, dtype=float), np.asarray(v_model, dtype=float),
                color='blue', linewidth=2, label='Model Streamline', zorder=7)

    # Model evaluated at the data positions
    if ra_model_interp is not None and v_model_interp is not None and valid is not None:
        ax.scatter(np.asarray(ra_model_interp)[valid], np.asarray(v_model_interp)[valid],
                   s=25, color='blue', zorder=5, label='Model at data positions')

    ax.set_xlabel("RA Offset (arcsec)")
    ax.set_ylabel("Velocity (km/s)")
    ax.set_title(title or "RA vs Velocity")

    if vlim is not None:
        ax.set_ylim(vlim)
    if ralim is not None:
        ax.set_xlim(ralim)

    # Flip RA axis to match astronomical convention
    ax.invert_xaxis()

    ax.legend(loc=legend_loc)
    if save_folder is not None:
        os.makedirs(save_folder, exist_ok=True)
        save_path = os.path.join(save_folder, f"{save_name}.png")
        fig.savefig(save_path, bbox_inches='tight', dpi=300)
    if show:
        plt.show()
    else:
        plt.close(fig)
780
+
781
+
782
+ #########
783
def plot_dec_vel_by_epoch(
    gradient_descent,
    fixed_params,
    initial_opt_params,
    distance,
    streamer=None,
    save_folder="sting_results",
    make_video=False,
):
    """
    Create DEC–velocity plots for every epoch in save_folder/epochs/dec_vel,
    and optionally compile them into a video.

    Each epoch's optimised parameter values are read from the optimisation log
    in ``save_folder`` and merged with the fixed parameters. The model is run
    and matched to the data. Velocity and Dec limits (min/max over all epochs'
    models and the data) are shared by every frame.

    Parameters
    ----------
    gradient_descent : object providing sanitize_param_partition, forward_model
        and checked_match_model_to_data_curve (as used below).
    fixed_params, initial_opt_params : dict
        Parameter partition; only the cleaned fixed params are used.
    distance : float
        Distance passed to the forward model.
    streamer : required despite the None default; supplies ra_data, dec_data, v_data.
    save_folder : str
        Folder containing the optimisation log; frames are written beneath it.
    make_video : bool
        If True, compile the frames with ffmpeg.

    Raises
    ------
    ValueError
        If ``streamer`` is None (checked after the optimisation log is found).
    """
    try:
        optimisation_log = load_optimisation_log(save_folder)
    except FileNotFoundError:
        print(f"Error: Could not find 'optimisation_log.csv' in {save_folder}")
        return
    if streamer is None:
        # Previously this failed with an AttributeError inside the epoch loop.
        raise ValueError("plot_dec_vel_by_epoch requires `streamer` (data points); got None")

    column_map = {c.split(' [')[0]: c for c in optimisation_log.columns}
    param_names = _opt_params_from_log(optimisation_log)
    # Only the cleaned fixed params are needed: each epoch's optimised values come from the log.
    fixed_params_clean, _ = gradient_descent.sanitize_param_partition(
        fixed_params, initial_opt_params, require_nonempty_opt=False
    )

    epochs = optimisation_log['epoch'].values

    # Make models
    epoch_models = []
    for idx, epoch in enumerate(epochs):
        row = optimisation_log.iloc[idx]
        opt_params_epoch = {param: float(row[column_map[param]]) for param in param_names}
        model_params_epoch = {**fixed_params_clean, **opt_params_epoch}
        # The forward model's error object is not checked here.
        ra_model, dec_model, v_model, valid_mask_model, _ = gradient_descent.forward_model(
            model_params_epoch, distance
        )
        _, dec_model_interp, v_model_interp, valid, model_keep, _, _ = (
            gradient_descent.checked_match_model_to_data_curve(
                ra_model, dec_model, v_model, valid_mask_model.astype(bool),
                streamer.ra_data, streamer.dec_data,
            )
        )

        epoch_models.append({
            "epoch": epoch,
            "dec_model": dec_model,
            "v_model": v_model,
            "dec_model_interp": dec_model_interp,
            "v_model_interp": v_model_interp,
            "valid": valid,
            "model_keep": model_keep,
        })

    # Global velocity and Dec limits over all epochs plus the data (NaNs ignored)
    all_v = np.concatenate([m["v_model"] for m in epoch_models] + [streamer.v_data])
    vlim = (np.nanmin(all_v), np.nanmax(all_v))
    all_dec = np.concatenate([m["dec_model"] for m in epoch_models] + [streamer.dec_data])
    declim = (np.nanmin(all_dec), np.nanmax(all_dec))

    # Make clean output folder
    output_dir = os.path.join(save_folder, "epochs", "dec_vel")
    _ensure_clean_dir(output_dir)

    # Make the plots
    for model in epoch_models:
        plot_dec_vel(
            dec_model=model["dec_model"],
            v_model=model["v_model"],
            streamer=streamer,
            dec_model_interp=model["dec_model_interp"],
            v_model_interp=model["v_model_interp"],
            valid=model["valid"],
            model_keep=model["model_keep"],
            title=f"Epoch: {int(model['epoch'])}",
            vlim=vlim,
            declim=declim,
            save_folder=output_dir,
            save_name=f"dec_vel_epoch_{int(model['epoch']):03d}",
        )

    if make_video:
        input_pattern = os.path.join(output_dir, "dec_vel_epoch_%03d.png")
        create_video_from_images(output_dir, input_pattern, "streamline_dec_vel_evolution.mp4", fps=5)
871
+
872
+
873
def plot_dec_vel(
    dec_model,
    v_model,
    *,
    streamer=None,
    dec_model_interp=None,
    v_model_interp=None,
    valid=None,
    model_keep=None,
    title=None,
    vlim=None,
    declim=None,
    legend_loc='lower right',
    save_folder='sting_results',
    save_name='streamline_dec_vel',
    show=False,
):
    """Plot Dec offset against velocity for one model streamline.

    Parameters
    ----------
    dec_model, v_model : array-like
        Model Dec offsets (arcsec) and velocities (km/s) along the streamline.
    streamer : object, optional
        If given, its ``dec_data``, ``v_data``, ``dec_sigma``, ``v_sigma`` and
        ``pc_coords`` attributes are drawn as data points / point cloud.
    dec_model_interp, v_model_interp : array-like, optional
        Model evaluated at the data positions; drawn only when ``valid`` is
        also given, at the entries where ``valid`` is True.
    valid : array-like of bool, optional
        Mask over the interpolated model points.
    model_keep : array-like of bool, optional
        Mask over the model points; only kept points are drawn in the model
        curve (the same convention ``plot_vel_radius`` uses).
    title : str, optional
        Axes title (default "DEC vs Velocity").
    vlim, declim : tuple, optional
        Velocity (y) and Dec (x) axis limits.
    legend_loc : str, optional
        Matplotlib legend location.
    save_folder : str or None, optional
        If not None, the figure is saved to ``<save_folder>/<save_name>.png``
        (folder created if needed).
    save_name : str, optional
        File stem for the saved figure.
    show : bool, optional
        If True the figure is displayed (after saving, when saving);
        otherwise it is closed.
    """
    dec_data = v_data = dec_sigma = v_sigma = pc_coords = None
    if streamer is not None:
        dec_data = streamer.dec_data
        v_data = streamer.v_data
        dec_sigma = streamer.dec_sigma
        v_sigma = streamer.v_sigma
        pc_coords = streamer.pc_coords

    dec_model = np.asarray(dec_model, dtype=float)
    v_model = np.asarray(v_model, dtype=float)
    if valid is not None:
        valid = np.asarray(valid, dtype=bool)
    if model_keep is not None:
        # Apply the keep-mask to the drawn curve, consistent with plot_vel_radius.
        model_keep = np.asarray(model_keep, dtype=bool)
        dec_model = dec_model[model_keep]
        v_model = v_model[model_keep]

    fig, ax = plt.subplots(figsize=(6, 5))

    # point cloud (DEC vs velocity)
    if pc_coords is not None:
        # pc_coords layout: [ra, dec, velocity]
        ax.scatter(pc_coords[1], pc_coords[2], s=1, alpha=0.3, color='grey', label='Point cloud')

    # data
    if dec_data is not None and v_data is not None:
        if dec_sigma is not None and v_sigma is not None:
            ax.errorbar(dec_data, v_data, xerr=dec_sigma, yerr=v_sigma, fmt='o', color='red', ecolor='red', ms=4, alpha=0.9, label='Data')
        else:
            ax.plot(dec_data, v_data, 'o', color='red', label='Data')

    # model curve (dec_model / v_model are always arrays after np.asarray above)
    ax.plot(dec_model, v_model, color='blue', linewidth=2, label='Model Streamline')

    # interpolated points
    if dec_model_interp is not None and v_model_interp is not None and valid is not None:
        ax.scatter(np.asarray(dec_model_interp)[valid], np.asarray(v_model_interp)[valid], s=25, color='blue', zorder=5, label='Model at data positions')

    ax.set_xlabel("DEC Offset (arcsec)")
    ax.set_ylabel("Velocity (km/s)")
    ax.set_title(title or "DEC vs Velocity")

    if vlim is not None:
        ax.set_ylim(vlim)
    if declim is not None:
        ax.set_xlim(declim)

    ax.legend(loc=legend_loc)

    if save_folder is not None:
        os.makedirs(save_folder, exist_ok=True)
        fig.savefig(os.path.join(save_folder, f"{save_name}.png"), bbox_inches='tight', dpi=300)
    # Honour `show` regardless of saving (previously ignored whenever save_folder was set).
    if show:
        plt.show()
    else:
        plt.close(fig)
951
+
952
+
953
def build_velocity_radius_kde(
    ra_data,
    dec_data,
    vlos_data,
    xmin=None,
    xmax=None,
    ymin=None,
    ymax=None,
    grid_size=100,
    sigma_levels=None,
):
    """Build a KDE background grid for projected radius vs velocity plots.

    The projected radius of each sample is ``sqrt(ra**2 + dec**2)``.

    Parameters
    ----------
    ra_data, dec_data : array-like
        RA / Dec offsets of the samples (arcsec).
    vlos_data : array-like
        Line-of-sight velocity samples (km/s).
    xmin, xmax, ymin, ymax : float, optional
        Grid limits. If omitted, the finite data range padded by 1 on each
        side is used.
    grid_size : int, optional
        Number of grid points per axis.
    sigma_levels : array-like, optional
        Sigma values; each maps to the contour level ``exp(-sigma**2 / 2)``
        on the peak-normalised density. Defaults to (1, 1.5, 2).

    Returns
    -------
    dict
        Keys "xx", "yy", "zz", "levels", "xlim", "ylim". ``zz`` is normalised
        to a peak of 1 (when the peak is finite and positive); ``levels`` is
        strictly increasing and ends at 1.0, as ``contourf`` requires.

    Raises
    ------
    ValueError
        If fewer than 3 samples have a finite radius and velocity.
    """
    from scipy import stats

    ra = np.asarray(ra_data, dtype=float)
    dec = np.asarray(dec_data, dtype=float)
    vlos = np.asarray(vlos_data, dtype=float)
    rproj = np.sqrt(ra**2 + dec**2)

    finite = np.isfinite(rproj) & np.isfinite(vlos)
    if np.sum(finite) < 3:
        raise ValueError("Need at least 3 finite samples to build KDE background.")
    rproj = rproj[finite]
    vlos = vlos[finite]

    if xmin is None:
        xmin = float(np.nanmin(rproj) - 1)
    if xmax is None:
        xmax = float(np.nanmax(rproj) + 1)
    if ymin is None:
        ymin = float(np.nanmin(vlos) - 1)
    if ymax is None:
        ymax = float(np.nanmax(vlos) + 1)

    # complex step -> number of points (inclusive of the end point)
    xx, yy = np.mgrid[xmin:xmax:complex(grid_size), ymin:ymax:complex(grid_size)]
    positions = np.vstack([xx.ravel(), yy.ravel()])
    values = np.vstack([rproj, vlos])

    kernel = stats.gaussian_kde(values)
    zz = np.reshape(kernel(positions).T, xx.shape)
    zmax = np.nanmax(zz)
    if np.isfinite(zmax) and zmax > 0:
        zz = zz / zmax

    if sigma_levels is None:
        sigma_levels = np.arange(1.0, 2.1, 0.5)
    sigma_levels = np.asarray(sigma_levels, dtype=float)
    # np.unique sorts and de-duplicates, so unordered or repeated sigmas (or a
    # sigma of 0, which maps to 1.0) still give strictly increasing levels.
    levels = np.unique(np.append(np.exp(-0.5 * sigma_levels**2), 1.0))

    return {
        "xx": xx,
        "yy": yy,
        "zz": zz,
        "levels": levels,
        "xlim": (xmin, xmax),
        "ylim": (ymin, ymax),
    }
1030
+
1031
+
1032
def plot_vel_radius(
    ra_model,
    dec_model,
    v_model,
    streamer,
    *,
    ra_model_interp=None,
    dec_model_interp=None,
    v_model_interp=None,
    valid=None,
    by_eye=None,
    model_keep=None,
    kde_background=None,
    velocity_reference=None,
    title=None,
    xlim=None,
    ylim=None,
    legend_loc='lower right',
    save_folder='sting_results',
    save_name=None,
    show=False,
):
    """Plot velocity vs projected radius for one model (optionally with KDE background).

    Projected radius is ``sqrt(ra**2 + dec**2)`` for the model, data,
    interpolated and by-eye points.

    Parameters
    ----------
    ra_model, dec_model, v_model : array-like
        Model streamline. If ``model_keep`` is given only those points are
        drawn; the curve is drawn in order of increasing projected radius.
    streamer : object
        Provides ``ra_data``, ``dec_data``, ``v_data``, ``ra_sigma``,
        ``dec_sigma``, ``v_sigma`` and ``pc_coords``.
    ra_model_interp, dec_model_interp, v_model_interp : array-like, optional
        Model at the data positions; drawn where ``valid`` is True.
    valid : array-like of bool, optional
        Mask over the interpolated points.
    by_eye : tuple (ra, dec, v), optional
        Extra curve, drawn in the order given.
    model_keep : array-like of bool, optional
        Mask over the model points.
    kde_background : dict, optional
        Output of ``build_velocity_radius_kde``. If None and
        ``streamer.pc_coords`` is not None, one is built from the streamer data.
    velocity_reference : float, optional
        Adds a central-source marker at (0, v) and a dashed horizontal line.
    title : str, optional
        Axes title.
    xlim, ylim : tuple, optional
        Axis limits; fall back to the KDE grid limits when a background exists.
    legend_loc : str, optional
        Matplotlib legend location.
    save_folder : str or None, optional
        If not None the figure is saved to ``<save_folder>/<save_name>.png``.
    save_name : str or None, optional
        File stem; defaults to ``'streamline_vel_radius'``.
    show : bool, optional
        Display the figure (after saving); otherwise it is closed.
    """
    if save_name is None:
        # Without a default the figure would be written as "None.png".
        save_name = 'streamline_vel_radius'

    ra_data = streamer.ra_data
    dec_data = streamer.dec_data
    v_data = streamer.v_data
    ra_sigma = streamer.ra_sigma
    dec_sigma = streamer.dec_sigma
    v_sigma = streamer.v_sigma
    pc_coords = streamer.pc_coords

    ra_model = np.asarray(ra_model, dtype=float)
    dec_model = np.asarray(dec_model, dtype=float)
    v_model = np.asarray(v_model, dtype=float)
    if valid is not None:
        valid = np.asarray(valid, dtype=bool)
    if model_keep is not None:
        model_keep = np.asarray(model_keep, dtype=bool)
        ra_model = ra_model[model_keep]
        dec_model = dec_model[model_keep]
        v_model = v_model[model_keep]

    rproj_model = np.sqrt(ra_model**2 + dec_model**2)
    order_model = np.argsort(rproj_model)

    fig, ax = plt.subplots(figsize=(6.5 * 1.3, 4 * 1.3))
    data_handle = None
    background_handle = None
    by_eye_handle = None

    if kde_background is None and pc_coords is not None:
        # make the kde background
        kde_background = build_velocity_radius_kde(
            ra_data=ra_data,
            dec_data=dec_data,
            vlos_data=v_data,
        )

    if kde_background is not None:
        ax.contourf(
            kde_background["xx"],
            kde_background["yy"],
            kde_background["zz"],
            levels=kde_background["levels"],
            cmap='Greys',
            vmin=0,
            vmax=1.2,
            zorder=1,
        )
        background_handle = Patch(
            facecolor='lightgray',
            edgecolor='none',
            label='Data KDE',
        )

    # Central source marker in this projection (r=0, v=v_lsr)
    if velocity_reference is not None:
        ax.scatter(
            0,
            float(velocity_reference),
            marker='*',
            s=100,
            color='yellow',
            edgecolor='black',
            zorder=10,
            label='Central Source',
        )

    if ra_data is not None and dec_data is not None and v_data is not None:
        ra_data = np.asarray(ra_data, dtype=float)
        dec_data = np.asarray(dec_data, dtype=float)
        v_data = np.asarray(v_data, dtype=float)
        rproj_data = np.sqrt(ra_data**2 + dec_data**2)

        if ra_sigma is not None and dec_sigma is not None and v_sigma is not None:
            ra_sigma = np.asarray(ra_sigma, dtype=float)
            dec_sigma = np.asarray(dec_sigma, dtype=float)
            v_sigma = np.asarray(v_sigma, dtype=float)
            # first-order error propagation of r = sqrt(ra^2 + dec^2); floor r to avoid /0
            denom = np.maximum(rproj_data, 1e-8)
            rproj_sigma = np.sqrt((ra_data * ra_sigma) ** 2 + (dec_data * dec_sigma) ** 2) / denom
            data_handle = ax.errorbar(
                rproj_data,
                v_data,
                xerr=rproj_sigma,
                yerr=v_sigma,
                fmt='o',
                color='red',
                ecolor='red',
                ms=4,
                alpha=0.9,
                label='Extracted 1D Streamline',
                zorder=6,
            )
        else:
            data_handle = ax.plot(
                rproj_data,
                v_data,
                'o',
                color='red',
                label='Extracted 1D Streamline',
                zorder=6,
            )[0]

    model_handle, = ax.plot(
        rproj_model[order_model],
        v_model[order_model],
        color='blue',
        linewidth=2,
        label='STING',
        zorder=7,
    )

    if (
        ra_model_interp is not None
        and dec_model_interp is not None
        and v_model_interp is not None
        and valid is not None
    ):
        ra_model_interp = np.asarray(ra_model_interp, dtype=float)
        dec_model_interp = np.asarray(dec_model_interp, dtype=float)
        v_model_interp = np.asarray(v_model_interp, dtype=float)
        rproj_interp = np.sqrt(ra_model_interp**2 + dec_model_interp**2)
        ax.scatter(
            rproj_interp[valid],
            v_model_interp[valid],
            s=25,
            color='blue',
            label='Model at retained data arc lengths',
            zorder=8,
        )

    if velocity_reference is not None:
        ax.axhline(
            float(velocity_reference),
            color='black',
            linestyle='--',
            label='Systemic Velocity',
            zorder=4,
        )

    if by_eye is not None:
        ra_by_eye, dec_by_eye, v_by_eye = by_eye
        ra_by_eye = np.asarray(ra_by_eye, dtype=float)
        dec_by_eye = np.asarray(dec_by_eye, dtype=float)
        v_by_eye = np.asarray(v_by_eye, dtype=float)
        rproj_by_eye = np.sqrt(ra_by_eye**2 + dec_by_eye**2)
        by_eye_handle, = ax.plot(
            rproj_by_eye,
            v_by_eye,
            color='tab:green',
            linewidth=2,
            label='By-eye',
            zorder=9,
        )

    ax.set_xlabel('Projected Distance from Source (arcsec)')
    ax.set_ylabel('Velocity (km/s)')
    ax.set_title(title or 'Velocity vs Projected Radius')

    if xlim is not None:
        ax.set_xlim(xlim)
    elif kde_background is not None:
        ax.set_xlim(kde_background["xlim"])

    if ylim is not None:
        ax.set_ylim(ylim)
    elif kde_background is not None:
        ax.set_ylim(kde_background["ylim"])

    # The model line is always present, so the explicit handle list is never empty.
    handles = [h for h in (data_handle, model_handle, by_eye_handle, background_handle) if h is not None]
    ax.legend(handles=handles, loc=legend_loc)

    if save_folder is not None:
        os.makedirs(save_folder, exist_ok=True)
        fig.savefig(os.path.join(save_folder, f'{save_name}.png'), dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close(fig)
1238
+
1239
+
1240
def plot_vel_radius_by_epoch(
    gradient_descent,
    fixed_params,
    initial_opt_params,
    distance,
    streamer,
    *,
    grid_size=100,
    levels=None,
    velocity_reference=None,
    save_folder="sting_results",
    make_video=False,
):
    """Create velocity vs projected radius plots for every epoch.

    Reads ``optimisation_log.csv`` from ``save_folder``, re-runs the forward
    model with the optimised parameters logged at each epoch, and writes one
    PNG per epoch into ``<save_folder>/epochs/vel_radius`` (prepared with
    ``_ensure_clean_dir``). All frames share axis limits and one KDE
    background built from the streamer data.

    Parameters
    ----------
    gradient_descent : module-like
        Provides ``sanitize_param_partition``, ``forward_model`` and
        ``checked_match_model_to_data_curve``.
        NOTE(review): this parameter shadows the module-level
        ``gradient_descent`` import; callers presumably pass that module -
        confirm.
    fixed_params, initial_opt_params : dict
        Parameter partition passed to ``sanitize_param_partition``.
    distance : float
        Passed to ``forward_model``.
    streamer : object
        Must provide ``ra_data``, ``dec_data`` and ``v_data`` (plus whatever
        ``plot_vel_radius`` reads).
    grid_size : int, optional
        KDE grid size.
    levels : array-like or None, optional
        Sigma levels forwarded to ``build_velocity_radius_kde``.
    velocity_reference : float or None, optional
        Forwarded to ``plot_vel_radius``.
    save_folder : str, optional
        Folder containing the optimisation log and receiving the frames.
    make_video : bool, optional
        If True, stitch the frames into ``streamline_vel_radius_evolution.mp4``.

    Raises
    ------
    ValueError
        If ``streamer`` is None (every epoch is matched against its data).
    """
    try:
        optimisation_log = load_optimisation_log(save_folder)
    except FileNotFoundError:
        print(f"Error: Could not find 'optimisation_log.csv' in {save_folder}")
        return

    if streamer is None:
        # Each epoch's model is matched against streamer.ra_data/dec_data below;
        # fail clearly instead of with an AttributeError mid-loop.
        raise ValueError("plot_vel_radius_by_epoch requires a streamer with extracted data.")

    column_map = {c.split(' [')[0]: c for c in optimisation_log.columns}
    param_names = _opt_params_from_log(optimisation_log)
    fixed_params_clean, initial_opt_params = gradient_descent.sanitize_param_partition(
        fixed_params, initial_opt_params, require_nonempty_opt=False
    )

    kde_background = build_velocity_radius_kde(
        ra_data=streamer.ra_data,
        dec_data=streamer.dec_data,
        vlos_data=streamer.v_data,
        grid_size=grid_size,
        sigma_levels=levels,
    )

    epoch_models = []
    for idx, epoch in enumerate(optimisation_log['epoch'].values):
        row = optimisation_log.iloc[idx]
        opt_params_epoch = {param: float(row[column_map[param]]) for param in param_names}
        model_params_epoch = {**fixed_params_clean, **opt_params_epoch}
        ra_model, dec_model, v_model, valid_mask_model, _ = gradient_descent.forward_model(model_params_epoch, distance)

        valid_mask_model = valid_mask_model.astype(bool)

        ra_model_interp, dec_model_interp, v_model_interp, valid, model_keep, _, _ = (
            gradient_descent.checked_match_model_to_data_curve(
                ra_model,
                dec_model,
                v_model,
                valid_mask_model,
                streamer.ra_data,
                streamer.dec_data,
            )
        )

        if model_keep is not None:
            model_keep = model_keep.astype(bool)

        epoch_models.append({
            "epoch": epoch,
            "ra_model": ra_model,
            "dec_model": dec_model,
            "v_model": v_model,
            "ra_model_interp": ra_model_interp,
            "dec_model_interp": dec_model_interp,
            "v_model_interp": v_model_interp,
            "valid": valid,
            "model_keep": model_keep,
        })

    # Consistent axis limits across epochs (all model points plus the data).
    rproj_list = [
        np.sqrt(np.asarray(m["ra_model"], dtype=float) ** 2 + np.asarray(m["dec_model"], dtype=float) ** 2)
        for m in epoch_models
    ]
    rproj_list.append(
        np.sqrt(np.asarray(streamer.ra_data, dtype=float) ** 2 + np.asarray(streamer.dec_data, dtype=float) ** 2)
    )
    v_list = [np.asarray(m["v_model"], dtype=float) for m in epoch_models]
    if streamer.v_data is not None:
        v_list.append(np.asarray(streamer.v_data, dtype=float))

    all_rproj = np.concatenate(rproj_list)
    all_v = np.concatenate(v_list)
    xlim = (np.nanmin(all_rproj), np.nanmax(all_rproj))
    ylim = (np.nanmin(all_v), np.nanmax(all_v))

    # make or clean output folder
    output_dir = os.path.join(save_folder, "epochs", "vel_radius")
    _ensure_clean_dir(output_dir)

    for model in epoch_models:
        plot_vel_radius(
            ra_model=model["ra_model"],
            dec_model=model["dec_model"],
            v_model=model["v_model"],
            streamer=streamer,
            ra_model_interp=model["ra_model_interp"],
            dec_model_interp=model["dec_model_interp"],
            v_model_interp=model["v_model_interp"],
            valid=model["valid"],
            model_keep=model["model_keep"],
            kde_background=kde_background,
            velocity_reference=velocity_reference,
            title=f"Epoch: {int(model['epoch'])}",
            xlim=xlim,
            ylim=ylim,
            save_folder=output_dir,
            save_name=f"vel_radius_epoch_{int(model['epoch']):03d}",
        )

    if make_video:
        input_pattern = os.path.join(output_dir, "vel_radius_epoch_%03d.png")
        create_video_from_images(
            output_dir,
            input_pattern,
            "streamline_vel_radius_evolution.mp4",
            fps=5,
        )
1361
+
1362
+
1363
def plot_param_uncertainties(opt_keys, opt_params, opt_sigmas, save_folder=None, show=False):
    """Horizontal bar chart of relative parameter uncertainties ``|sigma| / |x|``.

    Parameters
    ----------
    opt_keys : sequence of str
        Parameter names, one bar each.
    opt_params, opt_sigmas : array-like
        Best-fit values and 1-sigma uncertainties, ordered like ``opt_keys``.
    save_folder : str or None, optional
        If not None, the figure is saved as ``parameter_uncertainties.png`` in
        this folder (created if needed).
    show : bool, optional
        Display the figure; otherwise it is closed.
    """
    eps = 1e-12
    params = np.asarray(opt_params, dtype=float)
    sigmas = np.asarray(opt_sigmas, dtype=float)
    # |x| + eps (rather than x + eps): matches the sigma/|x| axis label and
    # cannot produce a ~zero denominator for x close to -eps. np.asarray also
    # lets callers pass plain lists.
    norm_errs = np.abs(sigmas) / (np.abs(params) + eps)

    fig, ax = plt.subplots(figsize=(8, 4.5))
    ypos = np.arange(len(opt_keys))
    ax.barh(
        ypos,
        norm_errs,
        color='tab:blue',
        alpha=0.8
    )
    ax.set_yticks(ypos)
    ax.set_yticklabels(opt_keys)
    ax.set_xlabel('Relative uncertainty ($\\sigma / |x|$)')
    ax.set_title('Normalized Parameter Uncertainties')
    ax.grid(True, alpha=0.25)
    plt.tight_layout()
    if save_folder is not None:
        os.makedirs(save_folder, exist_ok=True)
        plt.savefig(f'{save_folder}/parameter_uncertainties.png', dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close(fig)
1388
+
1389
def plot_param_correlations(param_names, covariance, annotate=True, save_folder=None, show=False):
    '''
    Plot the parameter correlation matrix, derived from a covariance matrix, as a heatmap.

    Parameters
    ----------
    param_names : list of str
        Parameter names, ordered like the rows/columns of ``covariance``.
    covariance : 2D array-like
        Covariance matrix of the parameters.
    annotate : bool, optional
        If True, write each correlation coefficient (2 decimals) in its cell.
    save_folder : str or None, optional
        If not None, the figure is saved as ``parameter_correlation_matrix.png``
        in this folder (created if needed).
    show : bool, optional
        Display the figure; otherwise it is closed.
    '''
    cov_np = np.array(covariance, dtype=float)

    # Floor the variances so zero/negative diagonal entries cannot divide by zero.
    stds = np.sqrt(np.clip(np.diag(cov_np), 1e-30, None))
    corr = np.clip(cov_np / np.outer(stds, stds), -1.0, 1.0)

    n_params = len(param_names)
    tick_positions = np.arange(n_params)

    fig, ax = plt.subplots(figsize=(6.5, 5.5))
    heatmap = ax.imshow(corr, vmin=-1, vmax=1, cmap='coolwarm_r')

    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(param_names, rotation=45, ha='right', fontsize=11)
    ax.set_yticklabels(param_names, fontsize=11)
    ax.set_title('Parameter Correlation Matrix')

    # Colourbar on its own axis so its height matches the heatmap.
    colorbar_ax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.08)
    colorbar = fig.colorbar(heatmap, cax=colorbar_ax)
    colorbar.set_label('Correlation coefficient')

    if annotate:
        for row in range(n_params):
            for col in range(n_params):
                ax.text(col, row, f'{corr[row, col]:.2f}', ha='center', va='center', fontsize=10, color='black')

    plt.tight_layout()
    if save_folder is not None:
        os.makedirs(save_folder, exist_ok=True)
        plt.savefig(f'{save_folder}/parameter_correlation_matrix.png', dpi=300, bbox_inches='tight')
    if not show:
        plt.close(fig)
    else:
        plt.show()
1445
+
1446
def plot_streamline_covariance_samples(best_opt_params,
                                       fixed_params,
                                       data,
                                       uncertainties,
                                       distance,
                                       covariance_result,
                                       v_lsr=None,
                                       n_samples=100,
                                       save_folder=None,
                                       show=True):
    """
    Sample parameter sets from the covariance matrix, evaluate streamlines from those sets, and plot them all together

    Parameters
    ----------
    best_opt_params : dict
        Best-fit optimised parameters in the original user-supplied parameterisation
        (e.g. with 'rc' or 'omega'). Used only to evaluate and plot the best-fit streamline.
    fixed_params : dict
        Fixed model parameters in the original user-supplied parameterisation.
        Used only to evaluate and plot the best-fit streamline.
    data : tuple of arrays (ra_data, dec_data, v_data)
    uncertainties : tuple of arrays (ra_sigma, dec_sigma, v_sigma)
    distance : float, distance in pc
    covariance_result : namedtuple from
        result = fit_streamline(...)
        covariance_result = result.cov_result
        Its ``best_opt_params``, ``opt_keys``, ``fixed_params`` and
        ``covariance`` drive the sampled streamlines.
    v_lsr : float or None, km/s
        If given, drawn as a dashed horizontal line in the velocity panel.
    n_samples : int
        Number of covariance samples drawn.
    save_folder : str or None
        If given, the figure is saved as ``streamline_covariance_samples.png``
        in this folder (created if needed).
    show : bool, default True
        Display the figure; if False it is closed instead (useful in batch runs).
    """
    _, streamline_samples = generate_streamline_samples(
        best_opt_params=covariance_result.best_opt_params,
        covariance=covariance_result.covariance,
        opt_keys=covariance_result.opt_keys,
        fixed_params=covariance_result.fixed_params,
        distance=distance,
        param_bounds=None,
        n_samples=n_samples
    )

    fig, (ax_sky, ax_v) = plt.subplots(1, 2, figsize=(10, 5))
    # np.asarray so plain lists / jax arrays support the arithmetic below
    ra_data, dec_data, v_data = (np.asarray(a, dtype=float) for a in data)
    ra_sigma, dec_sigma, v_sigma = (np.asarray(a, dtype=float) for a in uncertainties)

    # sampled streamlines
    for streamline in streamline_samples:
        ra = streamline['ra']
        dec = streamline['dec']
        vel = streamline['v']
        rproj = np.sqrt(ra**2 + dec**2)
        order = np.argsort(rproj)
        ax_sky.plot(ra, dec, color='tab:blue', alpha=0.1, lw=1)
        ax_v.plot(rproj[order], vel[order], color='tab:blue', alpha=0.1, lw=1)

    # best-fit streamline (valid model points only)
    best_full_params, _, _ = gradient_descent.prepare_model_params(best_opt_params, fixed_params)
    ra_best, dec_best, v_best, valid_mask_best, _ = gradient_descent.forward_model(best_full_params, distance)
    valid_mask_best = valid_mask_best.astype(bool)
    ra_best = np.asarray(ra_best, dtype=float)[valid_mask_best]
    dec_best = np.asarray(dec_best, dtype=float)[valid_mask_best]
    v_best = np.asarray(v_best, dtype=float)[valid_mask_best]
    rproj_best = np.sqrt(ra_best**2 + dec_best**2)
    order_best = np.argsort(rproj_best)
    ax_sky.plot(ra_best, dec_best, color='blue', lw=2, label='Best-fit')
    ax_v.plot(rproj_best[order_best], v_best[order_best], color='blue', lw=2, label='Best-fit')

    # data
    ax_sky.errorbar(
        ra_data, dec_data, xerr=ra_sigma, yerr=dec_sigma,
        fmt='o', color='red', ecolor='red', ms=4, alpha=0.9, label='Data'
    )
    rproj_data = np.sqrt(ra_data**2 + dec_data**2)
    # first-order error propagation of r = sqrt(ra^2 + dec^2); floor r to avoid /0 at the origin
    rproj_sigma = np.sqrt((ra_data * ra_sigma)**2 + (dec_data * dec_sigma)**2) / np.maximum(rproj_data, 1e-8)
    order_data = np.argsort(rproj_data)
    ax_v.errorbar(
        rproj_data[order_data], v_data[order_data],
        yerr=np.broadcast_to(v_sigma, v_data.shape)[order_data],
        xerr=np.broadcast_to(rproj_sigma, rproj_data.shape)[order_data],
        fmt='o', color='red', ecolor='red', ms=4, alpha=0.9, label='Data'
    )
    if v_lsr is not None:
        xmin, xmax = ax_v.get_xlim()
        ax_v.hlines(v_lsr, xmin=xmin, xmax=xmax, colors='k', linestyles='--', alpha=0.6,)
        ax_v.set_xlim(xmin, xmax)

    # finalise plots
    ax_sky.invert_xaxis()
    ax_sky.set_xlabel('RA Offset (arcsec)')
    ax_sky.set_ylabel('Dec Offset (arcsec)')
    ax_sky.set_title('Covariance Sampling')
    ax_sky.legend()

    ax_v.set_xlabel('Projected distance(arcsec)')
    ax_v.set_ylabel('Velocity (km/s)')
    ax_v.set_title('Covariance Sampling')
    ax_v.legend()

    plt.tight_layout()

    if save_folder is not None:
        os.makedirs(save_folder, exist_ok=True)
        fig.savefig(os.path.join(save_folder, 'streamline_covariance_samples.png'), dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close(fig)
1558
+
1559
+
1560
def generate_streamline_samples(best_opt_params, covariance, opt_keys, fixed_params, distance, param_bounds=None, n_samples=100, seed=42):
    """
    Draw parameter sets from a covariance matrix and evaluate a streamline for each.

    Thin wrapper around ``sample_parameter_sets_from_covariance`` followed by
    ``evaluate_streamlines_samples``.

    Parameters
    ----------
    best_opt_params : dict
        Mean of the sampling distribution, keyed by ``opt_keys``.
    covariance : 2D array-like
        Covariance matrix ordered like ``opt_keys``.
    opt_keys : sequence of str
        Names of the sampled parameters.
    fixed_params : dict
        Parameters held fixed for every sample.
    distance : float
        Passed to the forward model.
    param_bounds : dict or None, optional
        Bounds used to clip the samples.
    n_samples : int, optional
        Number of samples to draw.
    seed : int, optional
        RNG seed forwarded to the sampler (same default as the sampler, so
        existing callers get identical draws).

    Returns
    -------
    samples : ndarray, shape (n_samples, len(opt_keys))
    streamlines : list of dict
    """
    samples = sample_parameter_sets_from_covariance(
        best_opt_params,
        covariance,
        opt_keys,
        param_bounds=param_bounds,
        n_samples=n_samples,
        seed=seed,
    )
    streamlines = evaluate_streamlines_samples(
        samples,
        opt_keys,
        fixed_params,
        distance,
    )
    return samples, streamlines
1578
+
1579
def evaluate_streamlines_samples(param_samples, opt_keys, fixed_params, distance):
    """
    Run the forward model once per sampled parameter vector.

    Parameters
    ----------
    param_samples : iterable of array-like
        Each entry holds values for ``opt_keys``, in the same order.
    opt_keys : sequence of str
        Names of the sampled parameters.
    fixed_params : dict
        Parameters shared by every sample.
    distance : float
        Passed to the forward model.

    Returns
    -------
    streamlines : list of dict
        One dict per sample with keys "ra", "dec", "v" (float arrays restricted
        to the forward model's valid mask) and "dmetric" (from
        ``extract_streamline.get_distance_metric`` on those points).
    """
    def _evaluate_one(sample):
        params = dict(zip(opt_keys, map(float, sample)))
        full_params, _, _ = gradient_descent.prepare_model_params(params, fixed_params)
        ra_raw, dec_raw, v_raw, mask, _ = gradient_descent.forward_model(full_params, distance)
        keep = mask.astype(bool)
        ra_ok = np.asarray(ra_raw, dtype=float)[keep]
        dec_ok = np.asarray(dec_raw, dtype=float)[keep]
        v_ok = np.asarray(v_raw, dtype=float)[keep]
        metric, _ = extract_streamline.get_distance_metric(ra_ok, dec_ok)
        return {
            "ra": ra_ok,
            "dec": dec_ok,
            "v": v_ok,
            "dmetric": np.asarray(metric, dtype=float),
        }

    return [_evaluate_one(sample) for sample in param_samples]
1622
+
1623
def sample_parameter_sets_from_covariance(best_params, covariance, opt_keys, param_bounds=None, n_samples=100, seed=42):
    """
    Draw parameter vectors from a multivariate normal centred on the best fit.

    Parameters
    ----------
    best_params : dict
        Mean values, looked up by each key in ``opt_keys``.
    covariance : 2D array-like
        Covariance matrix ordered like ``opt_keys``.
    opt_keys : sequence of str
        Parameter names; defines the column order of the result.
    param_bounds : dict or None, optional
        Bounds passed through ``gradient_descent.auto_fill_angle_bounds`` and
        ``gradient_descent.convert_and_strip_bound_units``; columns whose key
        appears in the result are clipped to ``(low, high)``.
    n_samples : int, optional
        Number of draws.
    seed : int, optional
        Seed for ``np.random.default_rng``.

    Returns
    -------
    samples : ndarray, shape (n_samples, n_params)
        Sampled parameter vectors
    """
    generator = np.random.default_rng(seed)
    mean = np.array([best_params[name] for name in opt_keys], dtype=float)
    draws = generator.multivariate_normal(mean, np.asarray(covariance, dtype=float), size=n_samples)

    bounds = gradient_descent.convert_and_strip_bound_units(
        gradient_descent.auto_fill_angle_bounds(set(opt_keys), param_bounds)
    )
    for column, name in enumerate(opt_keys):
        if name not in bounds:
            continue
        lower, upper = bounds[name]
        draws[:, column] = np.clip(draws[:, column], lower, upper)

    return draws
1650
+
1651
def plot_param_optimisation_history(save_folder='sting_results', show=True):
    '''Plot the history of parameter optimisation from logs saved during optimisation.

    save_folder should be the same as the one used during optimisation, and
    should contain "optimisation_log.csv". The top panel shows the loss (via
    ``plot_loss_panel``); every other column of the log except "epoch" and
    "loss" gets its own panel below.

    Parameters
    ----------
    save_folder : str, optional
        Folder holding the log; the figure is saved there as
        ``parameter_optimisation_history.png``.
    show : bool, default True
        Display the figure; if False it is closed instead.
    '''
    try:
        optimisation_log = load_optimisation_log(save_folder)
    except FileNotFoundError:
        print(f"Error: Could not find 'optimisation_log.csv' and/or 'optimisation_trace.csv' in {save_folder}")
        return

    epochs = optimisation_log["epoch"].values
    loss = optimisation_log["loss"].values
    param_names = [c for c in optimisation_log.columns if c not in ("epoch", "loss")]

    # One row for the loss plus one per parameter. Previously only
    # len(param_names) rows were created, so the last parameter was silently
    # dropped by zip(); squeeze=False keeps `axes` an array even with one row.
    n_rows = len(param_names) + 1
    fig, axes = plt.subplots(n_rows, 1, figsize=(8, 2 * n_rows), sharex=True, squeeze=False)
    axes = axes[:, 0]

    plot_loss_panel(axes[0], epochs, loss)

    for ax, param in zip(axes[1:], param_names):
        ax.plot(epochs, optimisation_log[param].values)
        ax.set_ylabel(param)
        ax.grid(True)

    axes[-1].set_xlabel("Epoch")

    plt.tight_layout()
    if save_folder is not None:
        os.makedirs(save_folder, exist_ok=True)
        fig.savefig(os.path.join(save_folder, 'parameter_optimisation_history.png'), dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close(fig)
1687
+
1688
def plot_loss_panel(ax, epochs, loss):
    """Draw the loss curve on ``ax`` (log y-scale) and mark the lowest-loss epoch.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    epochs : array-like
        Epoch numbers (x values).
    loss : array-like
        Loss per epoch; the first minimum is highlighted and labelled.
    """
    best_index = np.argmin(loss)
    best_value = np.min(loss)
    epoch_at_best = epochs[best_index]

    ax.plot(epochs, loss, color="black")
    ax.scatter(epoch_at_best, best_value, color="green", label=f"Best Epoch: {epoch_at_best}")
    ax.set_yscale("log")
    ax.set_ylabel("Loss")
    ax.grid(True, alpha=0.3)
    ax.legend()
1698
+
1699
+
1700
def load_optimisation_log(logs_dir):
    """Read ``optimisation_log.csv`` from ``logs_dir`` into a pandas DataFrame.

    Raises ``FileNotFoundError`` (from ``pd.read_csv``) when the file is missing.
    """
    return pd.read_csv(os.path.join(logs_dir, "optimisation_log.csv"))
1704
+
1705
+