sting 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sting/__init__.py +8 -0
- sting/_version.py +24 -0
- sting/errors.py +677 -0
- sting/extract_streamline.py +425 -0
- sting/gradient_descent.py +1776 -0
- sting/outputs.py +1705 -0
- sting/stream_lines_grad.py +448 -0
- sting-0.2.0.dist-info/METADATA +251 -0
- sting-0.2.0.dist-info/RECORD +14 -0
- sting-0.2.0.dist-info/WHEEL +5 -0
- sting-0.2.0.dist-info/licenses/LICENCE +21 -0
- sting-0.2.0.dist-info/scm_file_list.json +26 -0
- sting-0.2.0.dist-info/scm_version.json +8 -0
- sting-0.2.0.dist-info/top_level.txt +1 -0
sting/outputs.py
ADDED
|
@@ -0,0 +1,1705 @@
|
|
|
1
|
+
'''
|
|
2
|
+
This file contains functions related to outputs from streamfit optimisation,
|
|
3
|
+
such as saving logs and plotting results.
|
|
4
|
+
|
|
5
|
+
Last updated: 03-06-26
|
|
6
|
+
'''
|
|
7
|
+
import json
|
|
8
|
+
import math
|
|
9
|
+
from matplotlib.patches import Patch
|
|
10
|
+
import numpy as np
|
|
11
|
+
import jax.numpy as jnp
|
|
12
|
+
import matplotlib.pyplot as plt
|
|
13
|
+
import matplotlib as mpl
|
|
14
|
+
mpl.rcParams["font.family"] = "serif"
|
|
15
|
+
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
|
16
|
+
import pandas as pd
|
|
17
|
+
import os
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
from . import gradient_descent
|
|
21
|
+
from . import extract_streamline
|
|
22
|
+
|
|
23
|
+
def param_for_display(key, value):
    """
    Convert a raw model parameter into the form used for reporting.

    Parameters listed in gradient_descent.ANGLE_KEYS are converted from
    radians to degrees and tagged 'deg'. Every other parameter is cast to
    float and tagged with its gradient_descent.DISPLAY_UNITS entry, or ''
    when it has none. The key itself is passed through unchanged.

    Returns (display_key, display_value, unit_str).
    """
    is_angle = key in gradient_descent.ANGLE_KEYS
    if not is_angle:
        return key, float(value), gradient_descent.DISPLAY_UNITS.get(key, '')
    return key, math.degrees(float(value)), 'deg'
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def evaluate_best_fit(
    best_opt_params,
    fixed_params,
    data,
    distance_pc,
    by_eye_params=None,
):
    """
    Evaluate the forward model at the best-fit parameters and match it to the data.

    If by_eye_params is given, that parameter set is also run through the
    forward model (without matching to the data) so it can be overplotted.

    Parameters
    ----------
    best_opt_params : dict
        Best-fit optimised parameters.
    fixed_params : dict
        Fixed model parameters, combined with the optimised ones through
        gradient_descent.prepare_model_params.
    data : tuple of arrays (ra_data, dec_data, v_data)
        Only ra_data and dec_data are used for the matching step.
    distance_pc : float
        Passed straight to gradient_descent.forward_model.
    by_eye_params : dict or None
        Optional by-eye parameter set.

    Returns
    -------
    dict with keys:
        ra_model, dec_model, v_model : full forward model arrays
        ra_model_interp, dec_model_interp, v_model_interp : model interpolated at data positions
        valid : mask of retained data points, as returned by the matcher
        by_eye : (ra, dec, v) tuple or None

    Notes
    -----
    The error object returned by forward_model is ignored for the best fit
    but raised (via its .throw() method) for the by-eye model.
    """
    ra_obs, dec_obs, _ = data

    # Best fit: full forward model, then interpolation onto the data positions.
    full_params, _, _ = gradient_descent.prepare_model_params(best_opt_params, fixed_params)
    model_ra, model_dec, model_v, model_mask, _ = gradient_descent.forward_model(full_params, distance_pc)

    matched = gradient_descent.checked_match_model_to_data_curve(
        model_ra,
        model_dec,
        model_v,
        model_mask.astype(bool),
        jnp.asarray(ra_obs, dtype=jnp.float64),
        jnp.asarray(dec_obs, dtype=jnp.float64),
    )
    ra_at_data, dec_at_data, v_at_data, valid_at_data, _, _, _ = matched

    # Optional by-eye curve: forward model only, failures are surfaced.
    by_eye_curve = None
    if by_eye_params is not None:
        eye_params, _, _ = gradient_descent.prepare_model_params(by_eye_params, fixed_params)
        eye_ra, eye_dec, eye_v, _, eye_err = gradient_descent.forward_model(eye_params, distance_pc)
        eye_err.throw()
        by_eye_curve = (eye_ra, eye_dec, eye_v)

    return {
        'ra_model': model_ra,
        'dec_model': model_dec,
        'v_model': model_v,
        'ra_model_interp': ra_at_data,
        'dec_model_interp': dec_at_data,
        'v_model_interp': v_at_data,
        'valid': valid_at_data,
        'by_eye': by_eye_curve,
    }
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def plot_fitting_results(
    ordered_best_opt_params,
    opt_param_keys,
    fixed_params,
    streamer,
    distance_pc,
    loss_history,
    param_errors,
    cov_matrix,
    v_lsr,
    save_folder,
    show_plots=False,
    transformed_cov_result=None,
    by_eye_params=None,
):
    """
    Generate and save the following best-fit diagnostic plots to save_folder after optimisation:
      - loss_history.png : loss vs epoch
      - best_fit_morphology.png : RA/Dec best fit
      - best_fit_vel_radius.png : velocity-radius best fit
      - parameter_uncertainties.png : sizes of error bars for each plotted param
      - parameter_correlation_matrix.png : parameter correlation heatmap
    The two uncertainty plots are only produced when both param_errors and
    cov_matrix are available.

    Parameters
    ----------
    ordered_best_opt_params : dict, best-fit optimised parameters
    opt_param_keys : list of str, list of optimised parameter names
    fixed_params : dict, fixed model parameters.
    streamer : NamedTuple with fields:
        pc_coords, ra_data, dec_data, v_data, ra_sigma, dec_sigma, v_sigma, data, uncertainties
    distance_pc : float
    loss_history : list of float, loss value at each epoch.
    param_errors : dict or None, 1-sigma parameter uncertainties keyed by parameter name, or None if uncertainty estimation failed
    cov_matrix : array or None, parameter covariance matrix, or None if uncertainty estimation failed.
    v_lsr : float or None, km/s
    save_folder : str, directory to write figures into (created if absent).
    show_plots : bool, whether to display plots (in addition to saving). Default False
    transformed_cov_result : dict or None. Keys expected: 'keys', 'cov', 'errors'.
        When given, these replace opt_param_keys / cov_matrix / param_errors in the
        uncertainty plots.
    by_eye_params : dict or None, optional by-eye parameter guess (with the same params as
        ordered_best_opt_params). If provided, it is plotted alongside the best-fit model in
        the morphology and velocity-radius plots.
    """
    # Loss
    plot_loss(loss_history, save_folder=save_folder, show=show_plots)

    # Evaluate best-fit morphology and velocity-radius
    best_fit = evaluate_best_fit(ordered_best_opt_params, fixed_params, streamer.data, distance_pc, by_eye_params=by_eye_params)

    plot_morphology(
        streamer=streamer,
        ra_model=best_fit['ra_model'],
        dec_model=best_fit['dec_model'],
        ra_model_interp=best_fit['ra_model_interp'],
        dec_model_interp=best_fit['dec_model_interp'],
        valid=best_fit['valid'],
        by_eye=best_fit['by_eye'],
        save_folder=save_folder,
        save_name='best_fit_morphology',
        show=show_plots,
    )

    plot_vel_radius(
        streamer=streamer,
        ra_model=best_fit['ra_model'],
        dec_model=best_fit['dec_model'],
        v_model=best_fit['v_model'],
        ra_model_interp=best_fit['ra_model_interp'],
        dec_model_interp=best_fit['dec_model_interp'],
        v_model_interp=best_fit['v_model_interp'],
        valid=best_fit['valid'],
        by_eye=best_fit['by_eye'],
        velocity_reference=v_lsr,
        save_folder=save_folder,
        save_name='best_fit_vel_radius',
        show=show_plots,
    )

    # Uncertainty plots (only if error estimation succeeded)
    if param_errors is None or cov_matrix is None:
        return

    if transformed_cov_result is not None:
        plot_keys = transformed_cov_result['keys']      # 'mu' replaced by 'rc'/'omega'
        plot_cov = transformed_cov_result['cov']        # Jacobian-transformed covariance
        plot_errors = transformed_cov_result['errors']  # 'mu' error transformed to 'rc'/'omega' error
    else:
        plot_keys = opt_param_keys
        plot_cov = cov_matrix
        plot_errors = param_errors

    # Values and errors must be indexed by the SAME key list, so entry i of each
    # array refers to the same parameter. The transformed keys can differ from
    # opt_param_keys, so both arrays are built over plot_keys. Keys missing from
    # ordered_best_opt_params fall back to 0.0, as before.
    param_vals = np.array([float(ordered_best_opt_params.get(k, 0.0)) for k in plot_keys], dtype=float)
    param_errs = np.array([float(plot_errors[k]) for k in plot_keys], dtype=float)
    plot_param_uncertainties(plot_keys, param_vals, param_errs, save_folder=save_folder, show=show_plots)
    plot_param_correlations(plot_keys, plot_cov, save_folder=save_folder, show=show_plots)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def save_best_fit_params(best_opt_params, fixed_params, param_errors, save_folder='sting_results'):
    """
    Save the best-fit (lowest-loss) parameters and their uncertainties to
    <save_folder>/best_fit_params.json.

    Values go through param_for_display, so angle parameters are written in
    degrees with unit 'deg'.

    Optimised parameters get a 'sigma' entry when param_errors has one for
    them. The sigma is converted with the same rule as the value (radians ->
    degrees for angles, plain float otherwise).

    Fixed parameters never carry a sigma. A fixed value of None is written as
    JSON null.

    Parameters
    ----------
    best_opt_params : dict
        Optimised parameter values in raw model units.
    fixed_params : dict
        Fixed parameter values in raw model units; values may be None.
    param_errors : dict or None
        1-sigma uncertainties keyed by raw parameter name, or None.
    save_folder : str
        Output directory, created if missing; '' writes to the working directory.
    """
    output_path = os.path.join(save_folder, 'best_fit_params.json')
    output_dir = os.path.dirname(output_path)
    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when the output path actually contains one.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Parameters that were optimised
    optimised_section = {}
    for raw_key, raw_val in best_opt_params.items():
        display_key, display_val, unit = param_for_display(raw_key, raw_val)
        entry = {
            'value': display_val,
            'unit': unit,
        }
        if param_errors is not None and raw_key in param_errors:
            # Same conversion for the error as for the value (degrees for
            # angle keys, plain float otherwise). No log-space propagation
            # is applied here.
            entry['sigma'] = param_for_display(raw_key, param_errors[raw_key])[1]
        optimised_section[display_key] = entry

    # Parameters that were fixed (no uncertainties)
    fixed_section = {}
    for raw_key, raw_val in fixed_params.items():
        if raw_val is None:
            # Nothing to convert; report the unit param_for_display would use.
            if raw_key in gradient_descent.ANGLE_KEYS:
                unit = 'deg'
            else:
                unit = gradient_descent.DISPLAY_UNITS.get(raw_key, '')
            fixed_section[raw_key] = {
                'value': None,
                'unit': unit,
            }
            continue
        display_key, display_val, unit = param_for_display(raw_key, raw_val)
        fixed_section[display_key] = {
            'value': display_val,
            'unit': unit,
        }

    output = {
        'optimised_parameters': optimised_section,
        'fixed_parameters': fixed_section,
    }

    with open(output_path, 'w', encoding='utf-8') as file:
        json.dump(output, file, indent=4)
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def _ensure_clean_dir(path):
|
|
242
|
+
"""Create directory if missing and remove any files inside it.
|
|
243
|
+
|
|
244
|
+
Keeps behaviour consistent across plotting functions that write epoch frames.
|
|
245
|
+
"""
|
|
246
|
+
os.makedirs(path, exist_ok=True)
|
|
247
|
+
for filename in os.listdir(path):
|
|
248
|
+
fp = os.path.join(path, filename)
|
|
249
|
+
if os.path.isfile(fp):
|
|
250
|
+
try:
|
|
251
|
+
os.remove(fp)
|
|
252
|
+
except OSError:
|
|
253
|
+
pass
|
|
254
|
+
|
|
255
|
+
def _opt_params_from_log(optimisation_log):
|
|
256
|
+
"""Return bare parameter names from the log, stripping any unit suffixes"""
|
|
257
|
+
skip = ['epoch', 'loss']
|
|
258
|
+
cols = [c for c in optimisation_log.columns if c not in skip]
|
|
259
|
+
# Strip unit suffixes to get bare param name
|
|
260
|
+
bare = [c.split(' [')[0] for c in cols]
|
|
261
|
+
if 'mu' in bare:
|
|
262
|
+
bare = [b for b in bare if b not in ('rc', 'omega')]
|
|
263
|
+
return bare
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def create_video_from_images(save_folder, input_pattern, output_name, fps=5):
    """Call ffmpeg to make a video from numbered image frames.

    Parameters
    ----------
    save_folder : str
        Directory the video is written into.
    input_pattern : str
        ffmpeg image-sequence pattern for the frames, e.g. ".../frame_%03d.png".
    output_name : str
        File name of the video inside save_folder.
    fps : int or float
        Input frame rate passed to ffmpeg's -framerate option.

    If ffmpeg is not on the PATH, cannot be started, or exits with an error,
    an error message is printed instead of raising. Always returns None.
    """
    import shutil
    import subprocess

    ffmpeg_exe = shutil.which("ffmpeg")
    if ffmpeg_exe is None:
        print(
            "ffmpeg not found, can't create video.\n"
            "To install: see https://ffmpeg.org/download.html and ensure ffmpeg is in your system PATH."
        )
        return

    output_video = os.path.join(save_folder, output_name)
    ffmpeg_cmd = [
        ffmpeg_exe,  # run the executable we just resolved, not a second PATH lookup
        "-y",
        "-loglevel", "error",
        "-framerate", str(fps),
        "-i", input_pattern,
        "-vf",
        # setpts divides frame N's timestamp by (1 + 0.01*N), so later frames
        # play progressively faster; pad rounds width/height up to even values
        # for yuv420p output.
        "setpts='PTS/(1+0.01*N)',pad=ceil(iw/2)*2:ceil(ih/2)*2",
        "-pix_fmt", "yuv420p",
        output_video,
    ]
    try:
        subprocess.run(ffmpeg_cmd, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # CalledProcessError: ffmpeg ran but failed. OSError: it could not be
        # started at all. Either way, report instead of crashing.
        print(f"Error creating video: {e}")
    else:
        print(f"Video saved to {output_video}")
|
|
298
|
+
|
|
299
|
+
def plot_loss(loss_history, save_folder='sting_results', show=False):
    '''Plot the loss history against epoch on a log-scaled y axis.

    The figure is saved as <save_folder>/loss_history.png (the folder is
    created if needed) unless save_folder is None. It is then shown when
    show is True, otherwise closed.
    '''
    # Serif font for this plot (set globally on matplotlib's rcParams).
    plt.rcParams['font.family'] = 'serif'

    # Epoch 0 is the initial state and epoch i (i >= 1) follows update i, so
    # loss_history[i] is drawn at x = i.
    fig, ax = plt.subplots(figsize=(12, 3))
    ax.plot(range(len(loss_history)), loss_history)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Optimisation Progress')
    ax.set_yscale('log')
    ax.grid(True, alpha=0.5)

    if save_folder is not None:
        os.makedirs(save_folder, exist_ok=True)
        fig.savefig(f'{save_folder}/loss_history.png', dpi=300, bbox_inches='tight')

    if show:
        plt.show()
    else:
        plt.close(fig)
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def make_morphology_background(pc_coords, metric_boundaries, ra_lim, dec_lim, figsize=(6.5, 7)):
    """
    Render the point cloud (and optional metric boundaries) once into an RGBA image.

    plot_morphology_by_epoch calls this so each per-epoch frame can draw the
    static background with a single imshow instead of re-plotting it.

    Parameters
    ----------
    pc_coords : array-like
        Point cloud; rows 0 and 1 are plotted as RA and Dec offsets.
    metric_boundaries : optional
        Passed to extract_streamline.plot_metric_boundaries when not None.
    ra_lim, dec_lim : (lo, hi)
        Axis limits of the rendered image.
    figsize : tuple
        Figure size; should match the figure the image will be drawn into.

    Returns
    -------
    bg_rgba : ndarray, shape (H, W, 4)
        RGBA pixels of the background at the figure's resolution.
    extent : list [left, right, bottom, top]
        Data-space extent to pass to ax.imshow so the image lines up.
    """
    fig, ax = plt.subplots(figsize=figsize)
    cloud = np.asarray(pc_coords, dtype=float)
    ax.scatter(cloud[0], cloud[1], s=1, color='gray', alpha=0.3)
    if metric_boundaries is not None:
        extract_streamline.plot_metric_boundaries(
            ax, cloud, metric_boundaries, color='gray', linewidth=1, alpha=0.3,
        )
    ax.set_xlim(ra_lim)
    ax.set_ylim(dec_lim)

    # No axis decorations, and the axes fill the whole canvas, so the pixel
    # grid spans exactly the [ra_lim, dec_lim] box.
    ax.axis('off')
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)

    # Rasterise now and copy the pixels out before the figure is closed.
    fig.canvas.draw()
    pixels = np.array(fig.canvas.buffer_rgba())
    plt.close(fig)

    return pixels, [ra_lim[0], ra_lim[1], dec_lim[0], dec_lim[1]]
|
|
352
|
+
|
|
353
|
+
def plot_morphology_by_epoch(
    gradient_descent,
    fixed_params,
    initial_opt_params,
    distance,
    streamer=None,
    n_points=None,
    save_folder="sting_results",
    make_video=False
):
    """
    Create and save one streamline morphology (RA/Dec) plot per logged epoch,
    and optionally compile the frames into a video.

    Parameters for each epoch are read back from the optimisation log in
    save_folder (via load_optimisation_log). For each epoch the forward model
    is re-run and matched to the data.

    Frames are written to
    save_folder/epochs/morphology/morphology_epoch_<epoch:03d>.png. Any files
    already in that folder are removed first.

    All frames share the same axis limits. The point cloud and metric
    boundaries are rendered once as a background image.

    Parameters
    ----------
    gradient_descent : module
        Provides sanitize_param_partition, forward_model and
        checked_match_model_to_data_curve (presumably sting.gradient_descent).
    fixed_params, initial_opt_params : dict
        Parameter partition; only the cleaned fixed params are used.
    distance : float
        Passed to forward_model.
    streamer : NamedTuple
        Required despite the None default: supplies pc_coords, ra_data and
        dec_data, and is forwarded to plot_morphology.
    n_points : passed to extract_streamline.get_metric_partitions.
    save_folder : str
        Folder holding the optimisation log; frames go in its epochs/morphology subfolder.
    make_video : bool
        If True, write streamline_morphology_evolution.mp4 next to the frames.

    Raises
    ------
    ValueError
        If streamer is None (and the optimisation log was found).

    If the optimisation log is missing, a message is printed and nothing is plotted.
    """
    try:
        optimisation_log = load_optimisation_log(save_folder)
    except FileNotFoundError:
        print(f"Error: Could not find 'optimisation_log.csv' in {save_folder}")
        return
    if streamer is None:
        raise ValueError("plot_morphology_by_epoch requires a streamer (data and point cloud).")

    # Log columns may carry unit suffixes ("name [unit]"); map bare names back to them.
    column_map = {c.split(' [')[0]: c for c in optimisation_log.columns}
    param_names = _opt_params_from_log(optimisation_log)
    fixed_params_clean, _ = gradient_descent.sanitize_param_partition(
        fixed_params, initial_opt_params, require_nonempty_opt=False
    )

    # Re-run the forward model for every logged epoch.
    epoch_models = []
    for idx, epoch in enumerate(optimisation_log['epoch'].values):
        row = optimisation_log.iloc[idx]
        opt_params_epoch = {param: float(row[column_map[param]]) for param in param_names}
        model_params_epoch = {**fixed_params_clean, **opt_params_epoch}
        ra_model, dec_model, v_model, valid_mask_model, _err = gradient_descent.forward_model(model_params_epoch, distance)

        ra_model_interp, dec_model_interp, _, valid, model_keep, _, _ = (
            gradient_descent.checked_match_model_to_data_curve(
                ra_model,
                dec_model,
                v_model,
                valid_mask_model.astype(bool),
                streamer.ra_data,
                streamer.dec_data,
            )
        )

        epoch_models.append(
            dict(
                epoch=epoch,
                opt_params_epoch=opt_params_epoch,
                ra_model=np.asarray(ra_model),
                dec_model=np.asarray(dec_model),
                ra_model_interp=np.asarray(ra_model_interp),
                dec_model_interp=np.asarray(dec_model_interp),
                valid=np.asarray(valid),
                model_keep=np.asarray(model_keep),
            )
        )

    # Constant axis limits over all epochs, data and point cloud (finite values only).
    all_ra = np.concatenate([
        *[e['ra_model'] for e in epoch_models],
        np.asarray(streamer.ra_data),
        np.asarray(streamer.pc_coords[0]),
    ])
    all_dec = np.concatenate([
        *[e['dec_model'] for e in epoch_models],
        np.asarray(streamer.dec_data),
        np.asarray(streamer.pc_coords[1]),
    ])
    finite = np.isfinite(all_ra) & np.isfinite(all_dec)
    all_ra = all_ra[finite]
    all_dec = all_dec[finite]

    pad_ra = 0.05 * (all_ra.max() - all_ra.min())
    pad_dec = 0.05 * (all_dec.max() - all_dec.min())
    ra_lim = (all_ra.min() - pad_ra, all_ra.max() + pad_ra)
    dec_lim = (all_dec.min() - pad_dec, all_dec.max() + pad_dec)

    # Clean output folder so stale frames never end up in the video.
    output_dir = os.path.join(save_folder, "epochs", "morphology")
    _ensure_clean_dir(output_dir)

    partitions = extract_streamline.get_metric_partitions(streamer.pc_coords, n_points)
    metric_boundaries, _ = extract_streamline.sample_metric_boundaries(streamer.pc_coords, partitions)

    # Pre-render the static background (point cloud and metric boundaries) once.
    bg_rgba, bg_extent = make_morphology_background(streamer.pc_coords, metric_boundaries, ra_lim, dec_lim)

    for model in epoch_models:
        plot_morphology(
            ra_model=model["ra_model"],
            dec_model=model["dec_model"],
            streamer=streamer,
            ra_model_interp=model["ra_model_interp"],
            dec_model_interp=model["dec_model_interp"],
            valid=model["valid"],
            bg_rgba=bg_rgba,
            bg_extent=bg_extent,
            title=f"Epoch: {int(model['epoch'])}",
            xlim=ra_lim,
            ylim=dec_lim,
            save_folder=output_dir,
            save_name=f"morphology_epoch_{int(model['epoch']):03d}",
            show=False
        )

    if make_video:
        input_pattern = os.path.join(output_dir, "morphology_epoch_%03d.png")
        create_video_from_images(output_dir, input_pattern, "streamline_morphology_evolution.mp4", fps=5)
|
|
463
|
+
|
|
464
|
+
def plot_morphology(
|
|
465
|
+
ra_model=None,
|
|
466
|
+
dec_model=None,
|
|
467
|
+
streamer=None,
|
|
468
|
+
ra_model_interp=None,
|
|
469
|
+
dec_model_interp=None,
|
|
470
|
+
valid=None,
|
|
471
|
+
by_eye=None,
|
|
472
|
+
metric_boundaries=None,
|
|
473
|
+
bg_rgba=None,
|
|
474
|
+
bg_extent=None,
|
|
475
|
+
title=None,
|
|
476
|
+
xlim=None,
|
|
477
|
+
ylim=None,
|
|
478
|
+
legend_loc='lower right',
|
|
479
|
+
save_folder='sting_results',
|
|
480
|
+
save_name='streamline_morphology',
|
|
481
|
+
show=True,
|
|
482
|
+
):
|
|
483
|
+
'''Plot offsets in RA/Dec. Optionally include: model, model points, data points, best fit, background overlay, metric partitions.
|
|
484
|
+
|
|
485
|
+
For a single plot, pass pc_corords and metric_boundaries directly. For per-epoch plotting use plot_morphology_by_epoch,
|
|
486
|
+
which will call this function and pass pre-rendered background images for speed.'''
|
|
487
|
+
ra_data = None
|
|
488
|
+
dec_data = None
|
|
489
|
+
ra_sigma = None
|
|
490
|
+
dec_sigma = None
|
|
491
|
+
pc_coords = None
|
|
492
|
+
if streamer is not None:
|
|
493
|
+
ra_data = streamer.ra_data
|
|
494
|
+
dec_data = streamer.dec_data
|
|
495
|
+
ra_sigma = streamer.ra_sigma
|
|
496
|
+
dec_sigma = streamer.dec_sigma
|
|
497
|
+
pc_coords = streamer.pc_coords
|
|
498
|
+
|
|
499
|
+
fig, ax = plt.subplots(figsize=(6.5, 7))
|
|
500
|
+
if valid is not None:
|
|
501
|
+
valid = np.asarray(valid, dtype=bool)
|
|
502
|
+
|
|
503
|
+
# Static background: prefer pre-rendered image, fall back to live drawing
|
|
504
|
+
if bg_rgba is not None and bg_extent is not None:
|
|
505
|
+
ax.imshow(
|
|
506
|
+
bg_rgba,
|
|
507
|
+
extent=bg_extent,
|
|
508
|
+
aspect='auto',
|
|
509
|
+
origin='upper',
|
|
510
|
+
zorder=1,
|
|
511
|
+
)
|
|
512
|
+
elif pc_coords is not None:
|
|
513
|
+
pc_coords_np = np.asarray(pc_coords, dtype=float)
|
|
514
|
+
ax.scatter(pc_coords_np[0], pc_coords_np[1], s=1, color='gray',
|
|
515
|
+
alpha=0.3, label='Point cloud', zorder=4)
|
|
516
|
+
if metric_boundaries is not None:
|
|
517
|
+
ax_limits = ax.get_xlim(), ax.get_ylim()
|
|
518
|
+
extract_streamline.plot_metric_boundaries(
|
|
519
|
+
ax, pc_coords_np, metric_boundaries,
|
|
520
|
+
color='gray', linewidth=1, alpha=0.3,
|
|
521
|
+
)
|
|
522
|
+
ax.set_xlim(ax_limits[0])
|
|
523
|
+
ax.set_ylim(ax_limits[1])
|
|
524
|
+
|
|
525
|
+
|
|
526
|
+
# model curve if given
|
|
527
|
+
if ra_model is not None and dec_model is not None:
|
|
528
|
+
ax.plot(ra_model, dec_model, color='blue', linewidth=2, label='STING', zorder=7)
|
|
529
|
+
|
|
530
|
+
# model points if given
|
|
531
|
+
if ra_model_interp is not None and dec_model_interp is not None and valid is not None:
|
|
532
|
+
if valid is not None:
|
|
533
|
+
ax.scatter(
|
|
534
|
+
ra_model_interp[valid],
|
|
535
|
+
dec_model_interp[valid],
|
|
536
|
+
s=25,
|
|
537
|
+
color='blue',
|
|
538
|
+
zorder=7,
|
|
539
|
+
)
|
|
540
|
+
|
|
541
|
+
# by eye model if given
|
|
542
|
+
if by_eye is not None:
|
|
543
|
+
ra_by_eye, dec_by_eye, _ = by_eye
|
|
544
|
+
ax.plot(
|
|
545
|
+
np.asarray(ra_by_eye, dtype=float),
|
|
546
|
+
np.asarray(dec_by_eye, dtype=float),
|
|
547
|
+
color='tab:green',
|
|
548
|
+
linewidth=2,
|
|
549
|
+
label='By-eye',
|
|
550
|
+
zorder=8,
|
|
551
|
+
)
|
|
552
|
+
|
|
553
|
+
# data points streamline if given
|
|
554
|
+
if ra_data is not None and dec_data is not None:
|
|
555
|
+
if ra_sigma is not None and dec_sigma is not None:
|
|
556
|
+
ax.errorbar(
|
|
557
|
+
ra_data,
|
|
558
|
+
dec_data,
|
|
559
|
+
xerr=ra_sigma,
|
|
560
|
+
yerr=dec_sigma,
|
|
561
|
+
fmt='o-',
|
|
562
|
+
label='Extracted 1D Streamline',
|
|
563
|
+
color='red',
|
|
564
|
+
zorder=5,
|
|
565
|
+
)
|
|
566
|
+
else:
|
|
567
|
+
ax.plot(
|
|
568
|
+
ra_data,
|
|
569
|
+
dec_data,
|
|
570
|
+
'o-',
|
|
571
|
+
label='Extracted 1D Streamline',
|
|
572
|
+
color='red',
|
|
573
|
+
zorder=5,
|
|
574
|
+
)
|
|
575
|
+
|
|
576
|
+
star_ra = 0
|
|
577
|
+
star_dec = 0
|
|
578
|
+
ax.scatter(star_ra, star_dec, marker='*', s=100, color='yellow', edgecolor='black', zorder=10)
|
|
579
|
+
ax.set_xlabel('RA Offset (arcsec)')
|
|
580
|
+
ax.set_ylabel('Dec Offset (arcsec)')
|
|
581
|
+
|
|
582
|
+
if xlim is not None:
|
|
583
|
+
ax.set_xlim(xlim)
|
|
584
|
+
else:
|
|
585
|
+
if pc_coords is not None:
|
|
586
|
+
all_ra = np.concatenate([*[e for e in [ra_model, ra_data, pc_coords[0], np.array([star_ra])] if e is not None]])
|
|
587
|
+
else:
|
|
588
|
+
all_ra = np.concatenate([*[e for e in [ra_model, ra_data, np.array([star_ra])] if e is not None]])
|
|
589
|
+
pad_ra = 0.05 * (all_ra.max() - all_ra.min())
|
|
590
|
+
ra_lim = (all_ra.min() - pad_ra, all_ra.max() + pad_ra)
|
|
591
|
+
ax.set_xlim(ra_lim)
|
|
592
|
+
if ylim is not None:
|
|
593
|
+
ax.set_ylim(ylim)
|
|
594
|
+
else:
|
|
595
|
+
if pc_coords is not None:
|
|
596
|
+
all_dec = np.concatenate([*[e for e in [dec_model, dec_data, pc_coords[1], np.array([star_dec])] if e is not None]])
|
|
597
|
+
else:
|
|
598
|
+
all_dec = np.concatenate([*[e for e in [dec_model, dec_data, np.array([star_dec])] if e is not None]])
|
|
599
|
+
pad_dec = 0.05 * (all_dec.max() - all_dec.min())
|
|
600
|
+
dec_lim = (all_dec.min() - pad_dec, all_dec.max() + pad_dec)
|
|
601
|
+
ax.set_ylim(dec_lim)
|
|
602
|
+
ax.invert_xaxis()
|
|
603
|
+
ax.set_title(title)
|
|
604
|
+
ax.legend(loc=legend_loc)
|
|
605
|
+
if save_folder is not None:
|
|
606
|
+
# make dir if it doesn't exist
|
|
607
|
+
os.makedirs(save_folder, exist_ok=True)
|
|
608
|
+
plt.savefig(f'{save_folder}/{save_name}.png', dpi=300, bbox_inches='tight')
|
|
609
|
+
if show:
|
|
610
|
+
plt.show()
|
|
611
|
+
else:
|
|
612
|
+
plt.close(fig)
|
|
613
|
+
|
|
614
|
+
def plot_ra_vel_by_epoch(
    gradient_descent,
    fixed_params,
    initial_opt_params,
    distance,
    streamer=None,
    save_folder="sting_results",
    make_video=False,
):
    """
    Create RA–velocity plots for every logged epoch, and optionally compile them into a video.

    Parameters for each epoch are read back from the optimisation log in
    save_folder. For each epoch the forward model is re-run and matched to
    the data.

    Frames are written to save_folder/epochs/ra_vel/ra_vel_epoch_<epoch:03d>.png.
    Any files already in that folder are removed first.

    All frames share axis limits. These span the finite model and data
    values, padded by 5% on each side so extreme points are not clipped.

    Parameters
    ----------
    gradient_descent : module
        Provides sanitize_param_partition, forward_model and
        checked_match_model_to_data_curve (presumably sting.gradient_descent).
    fixed_params, initial_opt_params : dict
        Parameter partition; only the cleaned fixed params are used.
    distance : float
        Passed to forward_model.
    streamer : NamedTuple
        Required despite the None default: supplies ra_data, dec_data and
        v_data, and is forwarded to plot_ra_vel.
    save_folder : str
        Folder holding the optimisation log; frames go in its epochs/ra_vel subfolder.
    make_video : bool
        If True, write streamline_ra_vel_evolution.mp4 next to the frames.

    Raises
    ------
    ValueError
        If streamer is None (and the optimisation log was found).
    """
    try:
        optimisation_log = load_optimisation_log(save_folder)
    except FileNotFoundError:
        print(f"Error: Could not find 'optimisation_log.csv' in {save_folder}")
        return
    if streamer is None:
        raise ValueError("plot_ra_vel_by_epoch requires a streamer (the data are matched to every epoch).")

    def _shared_limits(arrays):
        # Finite span across all frames, padded 5% on each side (as for the
        # morphology frames) so points at the extremes are not cut off.
        values = np.concatenate([np.asarray(a, dtype=float).ravel() for a in arrays])
        values = values[np.isfinite(values)]
        lo, hi = values.min(), values.max()
        pad = 0.05 * (hi - lo)
        return (lo - pad, hi + pad)

    # Log columns may carry unit suffixes ("name [unit]"); map bare names back to them.
    column_map = {c.split(' [')[0]: c for c in optimisation_log.columns}
    param_names = _opt_params_from_log(optimisation_log)
    fixed_params_clean, _ = gradient_descent.sanitize_param_partition(
        fixed_params, initial_opt_params, require_nonempty_opt=False
    )

    # Re-run the forward model for every logged epoch.
    epoch_models = []
    for idx, epoch in enumerate(optimisation_log['epoch'].values):
        row = optimisation_log.iloc[idx]
        opt_params_epoch = {param: float(row[column_map[param]]) for param in param_names}
        model_params_epoch = {**fixed_params_clean, **opt_params_epoch}
        ra_model, dec_model, v_model, valid_mask_model, _err = gradient_descent.forward_model(model_params_epoch, distance)
        ra_model_interp, _, v_model_interp, valid, model_keep, _, _ = (
            gradient_descent.checked_match_model_to_data_curve(
                ra_model, dec_model, v_model, valid_mask_model.astype(bool),
                streamer.ra_data, streamer.dec_data,
            )
        )

        epoch_models.append({
            "epoch": epoch,
            "ra_model": ra_model,
            "v_model": v_model,
            "ra_model_interp": ra_model_interp,
            "v_model_interp": v_model_interp,
            "valid": valid,
            "model_keep": model_keep,
        })

    # Global limits shared by every frame
    vlim = _shared_limits([m["v_model"] for m in epoch_models] + [streamer.v_data])
    ralim = _shared_limits([m["ra_model"] for m in epoch_models] + [streamer.ra_data])

    # Clean output folder so stale frames never end up in the video.
    output_dir = os.path.join(save_folder, "epochs", "ra_vel")
    _ensure_clean_dir(output_dir)

    for model in epoch_models:
        plot_ra_vel(
            ra_model=model["ra_model"],
            v_model=model["v_model"],
            streamer=streamer,
            ra_model_interp=model["ra_model_interp"],
            v_model_interp=model["v_model_interp"],
            valid=model["valid"],
            model_keep=model["model_keep"],
            title=f"Epoch: {int(model['epoch'])}",
            vlim=vlim,
            ralim=ralim,
            save_folder=output_dir,
            save_name=f"ra_vel_epoch_{int(model['epoch']):03d}",
        )

    if make_video:
        input_pattern = os.path.join(output_dir, "ra_vel_epoch_%03d.png")
        create_video_from_images(output_dir, input_pattern, "streamline_ra_vel_evolution.mp4", fps=5)
|
|
700
|
+
|
|
701
|
+
|
|
702
|
+
def plot_ra_vel(
    ra_model,
    v_model,
    *,
    streamer=None,
    ra_model_interp=None,
    v_model_interp=None,
    valid=None,
    model_keep=None,
    title=None,
    vlim=None,
    ralim=None,
    legend_loc='lower right',
    save_folder='sting_results',
    save_name='streamline_ra_vel',
    show=False,
):
    """
    Plot velocity against RA offset for a model streamline, optionally with data.

    Parameters
    ----------
    ra_model, v_model : array-like or None
        Model curve, drawn as a line when both are given.
    streamer : NamedTuple or None
        Supplies ra_data, v_data, ra_sigma, v_sigma and pc_coords. The point
        cloud is drawn from pc_coords[0] (RA) against pc_coords[2] (velocity).
    ra_model_interp, v_model_interp, valid : array-like or None
        Model at the data positions and the boolean mask of points to draw;
        drawn only when all three are given.
    model_keep : array-like or None
        Accepted for interface compatibility; not used by this function.
    title : str or None
        Plot title; defaults to "RA vs Velocity".
    vlim, ralim : (lo, hi) or None
        Fixed velocity / RA axis limits.
    legend_loc : str
    save_folder : str or None
        Directory for <save_name>.png (created if missing); None skips saving.
    save_name : str
    show : bool
        If True call plt.show(), otherwise close the figure.
    """
    ra_data = v_data = ra_sigma = v_sigma = pc_coords = None
    if streamer is not None:
        ra_data = streamer.ra_data
        v_data = streamer.v_data
        ra_sigma = streamer.ra_sigma
        v_sigma = streamer.v_sigma
        pc_coords = streamer.pc_coords

    # Convert only real inputs: np.asarray(None, dtype=float) yields a NaN
    # scalar, which would defeat the None checks below.
    if ra_model is not None:
        ra_model = np.asarray(ra_model, dtype=float)
    if v_model is not None:
        v_model = np.asarray(v_model, dtype=float)
    if valid is not None:
        valid = np.asarray(valid, dtype=bool)

    fig, ax = plt.subplots(figsize=(6, 5))

    # Point cloud (RA vs velocity)
    if pc_coords is not None:
        ax.scatter(pc_coords[0], pc_coords[2], s=1, alpha=0.3, color='grey', label='Point cloud')

    # Data
    if ra_data is not None and v_data is not None:
        if ra_sigma is not None and v_sigma is not None:
            ax.errorbar(ra_data, v_data, xerr=ra_sigma, yerr=v_sigma, fmt='o', color='red', ecolor='red', ms=4, alpha=0.9, label='Data')
        else:
            ax.plot(ra_data, v_data, 'o', color='red', label='Data')

    # Model curve
    if ra_model is not None and v_model is not None:
        ax.plot(ra_model, v_model, color='blue', linewidth=2, label='Model Streamline', zorder=7)

    # Model evaluated at the data positions
    if ra_model_interp is not None and v_model_interp is not None and valid is not None:
        ax.scatter(np.asarray(ra_model_interp)[valid], np.asarray(v_model_interp)[valid], s=25, color='blue', zorder=5, label='Model at data positions')

    ax.set_xlabel("RA Offset (arcsec)")
    ax.set_ylabel("Velocity (km/s)")
    ax.set_title(title or "RA vs Velocity")

    if vlim is not None:
        ax.set_ylim(vlim)
    if ralim is not None:
        ax.set_xlim(ralim)

    # Flip RA axis to match astronomical convention
    ax.invert_xaxis()

    ax.legend(loc=legend_loc)
    if save_folder is not None:
        os.makedirs(save_folder, exist_ok=True)
        save_path = os.path.join(save_folder, f"{save_name}.png")
        fig.savefig(save_path, bbox_inches='tight', dpi=300)
    if show:
        plt.show()
    else:
        plt.close(fig)
|
|
780
|
+
|
|
781
|
+
|
|
782
|
+
#########
|
|
783
|
+
def plot_dec_vel_by_epoch(
    gradient_descent,
    fixed_params,
    initial_opt_params,
    distance,
    streamer=None,
    save_folder="sting_results",
    make_video=False,
):
    """
    Create DEC–velocity plots for every epoch recorded in the optimisation log.

    Each row of ``optimisation_log.csv`` (read from ``save_folder``) is turned
    into a full parameter set (optimised values from the log merged over the
    sanitised fixed parameters), evaluated with ``forward_model`` and, when a
    streamer is given, matched to the observed curve. One PNG per epoch is
    written to ``<save_folder>/epochs/dec_vel`` (cleaned first) using
    velocity/DEC limits shared across all epochs.

    Parameters
    ----------
    gradient_descent : module-like
        Object providing ``sanitize_param_partition``, ``forward_model`` and
        ``checked_match_model_to_data_curve``.
    fixed_params : dict
        Parameters held fixed during the optimisation.
    initial_opt_params : dict
        Initial optimised parameters; only used to sanitise the partition.
    distance : float
        Forwarded to ``forward_model``.
    streamer : object, optional
        Data container exposing ``ra_data``, ``dec_data`` and ``v_data``
        (plus the attributes read by ``plot_dec_vel``). When None, only the
        model curves are drawn and no model-to-data matching is attempted.
    save_folder : str
        Folder containing ``optimisation_log.csv`` and receiving the plots.
    make_video : bool
        If True, assemble the frames into ``streamline_dec_vel_evolution.mp4``.
    """
    try:
        optimisation_log = load_optimisation_log(save_folder)
    except FileNotFoundError:
        print(f"Error: Could not find 'optimisation_log.csv' in {save_folder}")
        return
    # Log columns are "name [unit]"; map bare parameter names to column labels.
    column_map = {c.split(' [')[0]: c for c in optimisation_log.columns}
    param_names = _opt_params_from_log(optimisation_log)
    fixed_params_clean, _ = gradient_descent.sanitize_param_partition(
        fixed_params, initial_opt_params, require_nonempty_opt=False
    )

    epochs = optimisation_log['epoch'].values
    epoch_models = []

    # make models
    for idx, epoch in enumerate(epochs):
        row = optimisation_log.iloc[idx]
        opt_params_epoch = {param: float(row[column_map[param]]) for param in param_names}
        model_params_epoch = {**fixed_params_clean, **opt_params_epoch}
        ra_model, dec_model, v_model, valid_mask_model, _ = gradient_descent.forward_model(
            model_params_epoch, distance
        )
        valid_mask_model = valid_mask_model.astype(bool)

        # Matching needs observed positions; without a streamer there is
        # nothing to match against, so the interpolated markers are omitted.
        ra_model_interp = dec_model_interp = v_model_interp = valid = model_keep = None
        if streamer is not None:
            ra_model_interp, dec_model_interp, v_model_interp, valid, model_keep, _, _ = (
                gradient_descent.checked_match_model_to_data_curve(
                    ra_model, dec_model, v_model, valid_mask_model,
                    streamer.ra_data, streamer.dec_data,
                )
            )

        epoch_models.append({
            "epoch": epoch,
            "ra_model": ra_model,
            "dec_model": dec_model,
            "v_model": v_model,
            "ra_model_interp": ra_model_interp,
            "dec_model_interp": dec_model_interp,
            "v_model_interp": v_model_interp,
            "valid": valid,
            "model_keep": model_keep,
        })

    # global velocity limits
    v_list = [m["v_model"] for m in epoch_models]
    if streamer is not None:
        v_list.append(streamer.v_data)
    all_v = np.concatenate(v_list)
    vlim = (np.nanmin(all_v), np.nanmax(all_v))

    # global dec limits
    dec_list = [m["dec_model"] for m in epoch_models]
    if streamer is not None:
        dec_list.append(streamer.dec_data)
    all_dec = np.concatenate(dec_list)
    declim = (np.nanmin(all_dec), np.nanmax(all_dec))

    # make clean output folder
    output_dir = os.path.join(save_folder, "epochs", "dec_vel")
    _ensure_clean_dir(output_dir)

    # make the plots
    for model in epoch_models:
        plot_dec_vel(
            dec_model=model["dec_model"],
            v_model=model["v_model"],
            streamer=streamer,
            dec_model_interp=model["dec_model_interp"],
            v_model_interp=model["v_model_interp"],
            valid=model["valid"],
            model_keep=model["model_keep"],
            title=f"Epoch: {int(model['epoch'])}",
            vlim=vlim,
            declim=declim,
            save_folder=output_dir,
            save_name=f"dec_vel_epoch_{int(model['epoch']):03d}",
        )

    if make_video:
        input_pattern = os.path.join(output_dir, "dec_vel_epoch_%03d.png")
        create_video_from_images(output_dir, input_pattern, "streamline_dec_vel_evolution.mp4", fps=5)
|
|
871
|
+
|
|
872
|
+
|
|
873
|
+
def plot_dec_vel(
    dec_model,
    v_model,
    *,
    streamer=None,
    dec_model_interp=None,
    v_model_interp=None,
    valid=None,
    model_keep=None,
    title=None,
    vlim=None,
    declim=None,
    legend_loc='lower right',
    save_folder='sting_results',
    save_name='streamline_dec_vel',
    show=False,
):
    """Plot DEC offset against velocity for a single model streamline.

    Parameters
    ----------
    dec_model, v_model : array-like
        Model streamline DEC offsets (arcsec) and velocities (km/s).
    streamer : object, optional
        Data container; ``dec_data``, ``v_data``, ``dec_sigma``, ``v_sigma``
        and ``pc_coords`` are read from it. Data are drawn with error bars
        when both sigmas are available.
    dec_model_interp, v_model_interp, valid : array-like, optional
        Model sampled at the data positions and a boolean mask selecting the
        entries to draw; all three are required for the markers to appear.
    model_keep : array-like, optional
        Accepted for signature parity with the other plotting helpers; this
        plot does not use it.
    title, vlim, declim, legend_loc : optional
        Axis decoration.
    save_folder : str or None
        Folder for ``<save_name>.png``; nothing is saved when None.
    save_name : str
        File stem of the saved figure.
    show : bool
        If True the figure is shown (after saving, when a folder is given);
        otherwise it is closed.
    """
    dec_data = None
    v_data = None
    dec_sigma = None
    v_sigma = None
    pc_coords = None
    if streamer is not None:
        dec_data = streamer.dec_data
        v_data = streamer.v_data
        dec_sigma = streamer.dec_sigma
        v_sigma = streamer.v_sigma
        pc_coords = streamer.pc_coords

    dec_model = np.asarray(dec_model, dtype=float)
    v_model = np.asarray(v_model, dtype=float)
    if valid is not None:
        valid = np.asarray(valid, dtype=bool)

    fig, ax = plt.subplots(figsize=(6, 5))

    # point cloud (DEC vs velocity)
    if pc_coords is not None:
        # pc_coords layout: [ra, dec, velocity]
        ax.scatter(pc_coords[1], pc_coords[2], s=1, alpha=0.3, color='grey', label='Point cloud')

    # data
    if dec_data is not None and v_data is not None:
        if dec_sigma is not None and v_sigma is not None:
            ax.errorbar(dec_data, v_data, xerr=dec_sigma, yerr=v_sigma, fmt='o', color='red', ecolor='red', ms=4, alpha=0.9, label='Data')
        else:
            ax.plot(dec_data, v_data, 'o', color='red', label='Data')

    # model curve
    ax.plot(dec_model, v_model, color='blue', linewidth=2, label='Model Streamline')

    # interpolated points
    if dec_model_interp is not None and v_model_interp is not None and valid is not None:
        ax.scatter(np.asarray(dec_model_interp)[valid], np.asarray(v_model_interp)[valid], s=25, color='blue', zorder=5, label='Model at data positions')

    ax.set_xlabel("DEC Offset (arcsec)")
    ax.set_ylabel("Velocity (km/s)")
    ax.set_title(title or "DEC vs Velocity")

    if vlim is not None:
        ax.set_ylim(vlim)
    if declim is not None:
        ax.set_xlim(declim)

    ax.legend(loc=legend_loc)

    # Save first, then show or close, like the sibling plotting helpers;
    # previously a save_folder made `show=True` a silent no-op.
    if save_folder is not None:
        os.makedirs(save_folder, exist_ok=True)
        save_path = os.path.join(save_folder, f"{save_name}.png")
        plt.savefig(save_path, bbox_inches='tight', dpi=300)
    if show:
        plt.show()
    else:
        plt.close(fig)
|
|
951
|
+
|
|
952
|
+
|
|
953
|
+
def build_velocity_radius_kde(
    ra_data,
    dec_data,
    vlos_data,
    xmin=None,
    xmax=None,
    ymin=None,
    ymax=None,
    grid_size=100,
    sigma_levels=None,
):
    """Build a KDE background grid for projected radius vs velocity plots.

    Parameters
    ----------
    ra_data, dec_data : array-like
        RA and DEC offsets of the samples (arcsec). The projected radius is
        ``sqrt(ra**2 + dec**2)``.
    vlos_data : array-like
        Line-of-sight velocity samples (km/s).
    xmin, xmax, ymin, ymax : float, optional
        Limits of the KDE grid. Each omitted limit defaults to the finite
        data minimum minus 1 or maximum plus 1.
    grid_size : int, optional
        Number of grid points per axis.
    sigma_levels : array-like, optional
        Sigma values turned into contour levels ``exp(-0.5 * sigma**2)`` on
        the peak-normalised density. Defaults to (1.0, 1.5, 2.0). Order and
        duplicates do not matter.

    Returns
    -------
    dict
        Keys "xx", "yy", "zz" (grid and peak-normalised density),
        "levels" (strictly increasing, ending at 1.0, suitable for
        ``contourf``), "xlim" and "ylim".

    Raises
    ------
    ValueError
        If fewer than 3 samples have a finite radius and velocity.
    """
    from scipy import stats

    ra = np.asarray(ra_data, dtype=float)
    dec = np.asarray(dec_data, dtype=float)
    rproj = np.sqrt(ra**2 + dec**2)
    vlos = np.asarray(vlos_data, dtype=float)
    finite = np.isfinite(rproj) & np.isfinite(vlos)
    if np.sum(finite) < 3:
        raise ValueError("Need at least 3 finite samples to build KDE background.")

    rproj = rproj[finite]
    vlos = vlos[finite]

    if xmin is None:
        xmin = float(np.nanmin(rproj) - 1)
    if xmax is None:
        xmax = float(np.nanmax(rproj) + 1)
    if ymin is None:
        ymin = float(np.nanmin(vlos) - 1)
    if ymax is None:
        ymax = float(np.nanmax(vlos) + 1)

    xx, yy = np.mgrid[xmin:xmax:complex(grid_size), ymin:ymax:complex(grid_size)]
    positions = np.vstack([xx.ravel(), yy.ravel()])
    values = np.vstack([rproj, vlos])

    kernel = stats.gaussian_kde(values)
    zz = np.reshape(kernel(positions), xx.shape)
    zmax = np.nanmax(zz)
    if np.isfinite(zmax) and zmax > 0:
        zz = zz / zmax

    if sigma_levels is None:
        sigma_levels = np.arange(1.0, 2.1, 0.5)
    sigma_levels = np.asarray(sigma_levels, dtype=float)
    # contourf needs strictly increasing levels: np.unique sorts and drops
    # duplicates (e.g. +/- sigma, or sigma == 0 colliding with the 1.0 cap).
    levels = np.unique(np.append(np.exp(-0.5 * sigma_levels**2), 1.0))

    return {
        "xx": xx,
        "yy": yy,
        "zz": zz,
        "levels": levels,
        "xlim": (xmin, xmax),
        "ylim": (ymin, ymax),
    }
|
|
1030
|
+
|
|
1031
|
+
|
|
1032
|
+
def plot_vel_radius(
    ra_model,
    dec_model,
    v_model,
    streamer,
    *,
    ra_model_interp=None,
    dec_model_interp=None,
    v_model_interp=None,
    valid=None,
    by_eye=None,
    model_keep=None,
    kde_background=None,
    velocity_reference=None,
    title=None,
    xlim=None,
    ylim=None,
    legend_loc='lower right',
    save_folder='sting_results',
    save_name=None,
    show=False,
):
    """Plot velocity vs projected radius for one model (optionally with KDE background).

    Parameters
    ----------
    ra_model, dec_model, v_model : array-like
        Model streamline coordinates. The projected radius is
        ``sqrt(ra**2 + dec**2)`` and the curve is drawn sorted by radius.
    streamer : object
        Data container; ``ra_data``, ``dec_data``, ``v_data``, ``ra_sigma``,
        ``dec_sigma``, ``v_sigma`` and ``pc_coords`` are read from it.
    ra_model_interp, dec_model_interp, v_model_interp, valid : array-like, optional
        Model sampled at the data positions and a boolean mask of entries to
        draw; all four are needed for the markers to appear.
    by_eye : tuple of array-like, optional
        ``(ra, dec, v)`` of a by-eye curve drawn for comparison.
    model_keep : array-like of bool, optional
        Mask applied to the model arrays before plotting.
    kde_background : dict, optional
        Output of ``build_velocity_radius_kde``. When None and the streamer
        carries both a point cloud and data, one is built from the data.
    velocity_reference : float, optional
        Systemic velocity: draws a star at r=0 and a dashed horizontal line.
    title, xlim, ylim, legend_loc : optional
        Axis decoration; the limits fall back to the KDE grid limits when a
        background is drawn.
    save_folder : str or None
        Output folder; nothing is saved when None.
    save_name : str or None
        File stem of the PNG; defaults to ``'streamline_vel_radius'``
        (previously None produced a file literally named ``None.png``).
    show : bool
        Show the figure instead of closing it.
    """
    if save_name is None:
        save_name = 'streamline_vel_radius'

    ra_data = streamer.ra_data
    dec_data = streamer.dec_data
    v_data = streamer.v_data
    ra_sigma = streamer.ra_sigma
    dec_sigma = streamer.dec_sigma
    v_sigma = streamer.v_sigma
    pc_coords = streamer.pc_coords
    has_data = ra_data is not None and dec_data is not None and v_data is not None

    ra_model = np.asarray(ra_model, dtype=float)
    dec_model = np.asarray(dec_model, dtype=float)
    v_model = np.asarray(v_model, dtype=float)
    if valid is not None:
        valid = np.asarray(valid, dtype=bool)
    if model_keep is not None:
        model_keep = np.asarray(model_keep, dtype=bool)
        ra_model = ra_model[model_keep]
        dec_model = dec_model[model_keep]
        v_model = v_model[model_keep]

    rproj_model = np.sqrt(ra_model**2 + dec_model**2)
    order_model = np.argsort(rproj_model)

    fig, ax = plt.subplots(figsize=(6.5 * 1.3, 4 * 1.3))
    data_handle = None
    background_handle = None
    by_eye_handle = None

    # Build the KDE background from the data; without data there is nothing
    # to estimate a density from.
    if kde_background is None and pc_coords is not None and has_data:
        kde_background = build_velocity_radius_kde(
            ra_data=ra_data,
            dec_data=dec_data,
            vlos_data=v_data,
        )

    if kde_background is not None:
        ax.contourf(
            kde_background["xx"],
            kde_background["yy"],
            kde_background["zz"],
            levels=kde_background["levels"],
            cmap='Greys',
            vmin=0,
            vmax=1.2,
            zorder=1,
        )
        # contourf produces no legend entry on its own; use a proxy patch.
        background_handle = Patch(
            facecolor='lightgray',
            edgecolor='none',
            label='Data KDE',
        )

    # Central source marker in this projection (r=0, v=v_lsr)
    if velocity_reference is not None:
        ax.scatter(
            0,
            float(velocity_reference),
            marker='*',
            s=100,
            color='yellow',
            edgecolor='black',
            zorder=10,
            label='Central Source',
        )

    if has_data:
        ra_data = np.asarray(ra_data, dtype=float)
        dec_data = np.asarray(dec_data, dtype=float)
        v_data = np.asarray(v_data, dtype=float)
        rproj_data = np.sqrt(ra_data**2 + dec_data**2)

        if ra_sigma is not None and dec_sigma is not None and v_sigma is not None:
            ra_sigma = np.asarray(ra_sigma, dtype=float)
            dec_sigma = np.asarray(dec_sigma, dtype=float)
            v_sigma = np.asarray(v_sigma, dtype=float)
            # First-order error propagation of r = sqrt(ra^2 + dec^2); the
            # floor avoids dividing by zero for a point at the origin.
            denom = np.maximum(rproj_data, 1e-8)
            rproj_sigma = np.sqrt((ra_data * ra_sigma) ** 2 + (dec_data * dec_sigma) ** 2) / denom
            data_handle = ax.errorbar(
                rproj_data,
                v_data,
                xerr=rproj_sigma,
                yerr=v_sigma,
                fmt='o',
                color='red',
                ecolor='red',
                ms=4,
                alpha=0.9,
                label='Extracted 1D Streamline',
                zorder=6,
            )
        else:
            data_handle = ax.plot(
                rproj_data,
                v_data,
                'o',
                color='red',
                label='Extracted 1D Streamline',
                zorder=6,
            )[0]

    model_handle, = ax.plot(
        rproj_model[order_model],
        v_model[order_model],
        color='blue',
        linewidth=2,
        label='STING',
        zorder=7,
    )

    if (
        ra_model_interp is not None
        and dec_model_interp is not None
        and v_model_interp is not None
        and valid is not None
    ):
        ra_model_interp = np.asarray(ra_model_interp, dtype=float)
        dec_model_interp = np.asarray(dec_model_interp, dtype=float)
        v_model_interp = np.asarray(v_model_interp, dtype=float)
        rproj_interp = np.sqrt(ra_model_interp**2 + dec_model_interp**2)
        ax.scatter(
            rproj_interp[valid],
            v_model_interp[valid],
            s=25,
            color='blue',
            label='Model at retained data arc lengths',
            zorder=8,
        )

    if velocity_reference is not None:
        ax.axhline(
            float(velocity_reference),
            color='black',
            linestyle='--',
            label='Systemic Velocity',
            zorder=4,
        )

    if by_eye is not None:
        ra_by_eye, dec_by_eye, v_by_eye = by_eye
        ra_by_eye = np.asarray(ra_by_eye, dtype=float)
        dec_by_eye = np.asarray(dec_by_eye, dtype=float)
        v_by_eye = np.asarray(v_by_eye, dtype=float)
        rproj_by_eye = np.sqrt(ra_by_eye**2 + dec_by_eye**2)
        by_eye_handle, = ax.plot(
            rproj_by_eye,
            v_by_eye,
            color='tab:green',
            linewidth=2,
            label='By-eye',
            zorder=9,
        )

    ax.set_xlabel('Projected Distance from Source (arcsec)')
    ax.set_ylabel('Velocity (km/s)')
    ax.set_title(title or 'Velocity vs Projected Radius')

    if xlim is not None:
        ax.set_xlim(xlim)
    elif kde_background is not None:
        ax.set_xlim(kde_background["xlim"])

    if ylim is not None:
        ax.set_ylim(ylim)
    elif kde_background is not None:
        ax.set_ylim(kde_background["ylim"])

    # Only the curated handles go in the legend; model_handle always exists.
    handles = [h for h in (data_handle, model_handle, by_eye_handle, background_handle) if h is not None]
    ax.legend(handles=handles, loc=legend_loc)

    if save_folder is not None:
        os.makedirs(save_folder, exist_ok=True)
        plt.savefig(os.path.join(save_folder, f'{save_name}.png'), dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close(fig)
|
|
1238
|
+
|
|
1239
|
+
|
|
1240
|
+
def plot_vel_radius_by_epoch(
    gradient_descent,
    fixed_params,
    initial_opt_params,
    distance,
    streamer,
    *,
    grid_size=100,
    levels=None,
    velocity_reference=None,
    save_folder="sting_results",
    make_video=False,
):
    """Create velocity vs projected radius plots for every epoch.

    The optimisation log in ``save_folder`` is replayed row by row. For each
    epoch:

    - the logged optimised values are merged over the sanitised fixed
      parameters;
    - the result is evaluated with ``gradient_descent.forward_model``;
    - it is matched to the streamer data with
      ``checked_match_model_to_data_curve``.

    One PNG per epoch is written to ``<save_folder>/epochs/vel_radius``
    (cleaned first). All frames share axis limits and a single KDE
    background. The frames are optionally assembled into
    ``streamline_vel_radius_evolution.mp4``.

    Parameters
    ----------
    gradient_descent : module-like
        Provides ``sanitize_param_partition``, ``forward_model`` and
        ``checked_match_model_to_data_curve``.
    fixed_params, initial_opt_params : dict
        Parameter partition used for the fit.
    distance : float
        Forwarded to ``forward_model``.
    streamer : object
        Data container with ``ra_data``, ``dec_data`` and ``v_data``.
    grid_size : int
        KDE grid resolution per axis.
    levels : array-like, optional
        Sigma levels forwarded to ``build_velocity_radius_kde``.
    velocity_reference : float, optional
        Systemic velocity forwarded to ``plot_vel_radius``.
    save_folder : str
        Folder holding ``optimisation_log.csv`` and receiving the output.
    make_video : bool
        Whether to assemble the frames into a video.
    """
    try:
        log_df = load_optimisation_log(save_folder)
    except FileNotFoundError:
        print(f"Error: Could not find 'optimisation_log.csv' in {save_folder}")
        return

    # Log columns look like "name [unit]"; index them by the bare name.
    column_for = {}
    for column in log_df.columns:
        column_for[column.split(' [')[0]] = column
    fitted_names = _opt_params_from_log(log_df)
    fixed_clean, _ = gradient_descent.sanitize_param_partition(
        fixed_params, initial_opt_params, require_nonempty_opt=False
    )
    epoch_values = log_df['epoch'].values

    # A single KDE background is shared by every frame.
    kde_background = (
        None
        if streamer is None
        else build_velocity_radius_kde(
            ra_data=streamer.ra_data,
            dec_data=streamer.dec_data,
            vlos_data=streamer.v_data,
            grid_size=grid_size,
            sigma_levels=levels,
        )
    )

    frames = []
    for position, epoch_value in enumerate(epoch_values):
        log_row = log_df.iloc[position]
        epoch_params = dict(fixed_clean)
        epoch_params.update({name: float(log_row[column_for[name]]) for name in fitted_names})

        ra_m, dec_m, v_m, mask_m, _ = gradient_descent.forward_model(epoch_params, distance)
        mask_m = mask_m.astype(bool)

        ra_i, dec_i, v_i, valid_i, keep_i, _, _ = gradient_descent.checked_match_model_to_data_curve(
            ra_m,
            dec_m,
            v_m,
            mask_m,
            streamer.ra_data,
            streamer.dec_data,
        )
        if keep_i is not None:
            keep_i = keep_i.astype(bool)

        frames.append(dict(
            epoch=epoch_value,
            ra_model=ra_m,
            dec_model=dec_m,
            v_model=v_m,
            ra_model_interp=ra_i,
            dec_model_interp=dec_i,
            v_model_interp=v_i,
            valid=valid_i,
            model_keep=keep_i,
        ))

    # Shared axis limits over every epoch's model and the observed data.
    radius_parts = []
    velocity_parts = []
    for frame in frames:
        ra_arr = np.asarray(frame["ra_model"], dtype=float)
        dec_arr = np.asarray(frame["dec_model"], dtype=float)
        radius_parts.append(np.sqrt(ra_arr**2 + dec_arr**2))
        velocity_parts.append(np.asarray(frame["v_model"], dtype=float))

    if streamer is not None:
        ra_obs = np.asarray(streamer.ra_data, dtype=float)
        dec_obs = np.asarray(streamer.dec_data, dtype=float)
        radius_parts.append(np.sqrt(ra_obs ** 2 + dec_obs ** 2))
        if streamer.v_data is not None:
            velocity_parts.append(np.asarray(streamer.v_data, dtype=float))

    all_radius = np.concatenate(radius_parts)
    all_velocity = np.concatenate(velocity_parts)
    shared_xlim = (np.nanmin(all_radius), np.nanmax(all_radius))
    shared_ylim = (np.nanmin(all_velocity), np.nanmax(all_velocity))

    frame_dir = os.path.join(save_folder, "epochs", "vel_radius")
    _ensure_clean_dir(frame_dir)

    for frame in frames:
        epoch_number = int(frame['epoch'])
        plot_vel_radius(
            ra_model=frame["ra_model"],
            dec_model=frame["dec_model"],
            v_model=frame["v_model"],
            streamer=streamer,
            ra_model_interp=frame["ra_model_interp"],
            dec_model_interp=frame["dec_model_interp"],
            v_model_interp=frame["v_model_interp"],
            valid=frame["valid"],
            model_keep=frame["model_keep"],
            kde_background=kde_background,
            velocity_reference=velocity_reference,
            title=f"Epoch: {epoch_number}",
            xlim=shared_xlim,
            ylim=shared_ylim,
            save_folder=frame_dir,
            save_name=f"vel_radius_epoch_{epoch_number:03d}",
        )

    if not make_video:
        return
    create_video_from_images(
        frame_dir,
        os.path.join(frame_dir, "vel_radius_epoch_%03d.png"),
        "streamline_vel_radius_evolution.mp4",
        fps=5,
    )
|
|
1361
|
+
|
|
1362
|
+
|
|
1363
|
+
def plot_param_uncertainties(opt_keys, opt_params, opt_sigmas, save_folder=None, show=False):
    """
    Plot a horizontal bar chart of relative parameter uncertainties sigma / |x|.

    Parameters
    ----------
    opt_keys : sequence of str
        Parameter names, used as bar labels.
    opt_params : array-like
        Best-fit parameter values, in the same order as ``opt_keys``.
    opt_sigmas : array-like
        1-sigma uncertainties, in the same order as ``opt_keys``.
    save_folder : str or None
        If given, the figure is saved as ``parameter_uncertainties.png`` there.
    show : bool
        Show the figure instead of closing it.
    """
    eps = 1e-12
    # Accept plain lists as well as arrays.
    opt_params = np.asarray(opt_params, dtype=float)
    opt_sigmas = np.asarray(opt_sigmas, dtype=float)
    # Guard on |x|: `x + eps` does not protect negative values
    # (x == -eps divided by exactly zero).
    norm_errs = np.abs(opt_sigmas) / np.maximum(np.abs(opt_params), eps)

    fig, ax = plt.subplots(figsize=(8, 4.5))
    ypos = np.arange(len(opt_keys))
    ax.barh(
        ypos,
        norm_errs,
        color='tab:blue',
        alpha=0.8
    )
    ax.set_yticks(ypos)
    ax.set_yticklabels(opt_keys)
    ax.set_xlabel('Relative uncertainty ($\\sigma / |x|$)')
    ax.set_title('Normalized Parameter Uncertainties')
    ax.grid(True, alpha=0.25)
    plt.tight_layout()
    if save_folder is not None:
        os.makedirs(save_folder, exist_ok=True)
        plt.savefig(f'{save_folder}/parameter_uncertainties.png', dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close(fig)
|
|
1388
|
+
|
|
1389
|
+
def plot_param_correlations(param_names, covariance, annotate=True, save_folder=None, show=False):
    '''
    Plot the parameter correlation matrix derived from a covariance matrix as a heatmap.

    Correlations are cov_ij / (sigma_i * sigma_j). The variances are floored
    at 1e-30 before taking square roots, and the result is clipped to [-1, 1].

    Parameters:
    ----------
    param_names: list of str
        Names of the parameters, in the same order as the covariance matrix.
    covariance: 2D array
        Covariance matrix of the parameters.
    annotate: bool, optional
        Whether to write each correlation value inside its heatmap cell.
    save_folder: str or None, optional
        If given, the figure is saved as parameter_correlation_matrix.png there.
    show: bool, optional
        Show the figure instead of closing it.
    '''
    cov_np = np.array(covariance, dtype=float)
    sigmas = np.sqrt(np.clip(np.diag(cov_np), 1e-30, None))
    corr = np.clip(cov_np / np.outer(sigmas, sigmas), -1.0, 1.0)

    n_params = len(param_names)
    tick_positions = np.arange(n_params)

    fig, ax = plt.subplots(figsize=(6.5, 5.5))
    im = ax.imshow(corr, vmin=-1, vmax=1, cmap='coolwarm_r')

    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(param_names, rotation=45, ha='right', fontsize=11)
    ax.set_yticklabels(param_names, fontsize=11)
    ax.set_title('Parameter Correlation Matrix')

    # Give the colorbar its own axis so it matches the heatmap height.
    cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.08)
    fig.colorbar(im, cax=cax).set_label('Correlation coefficient')

    if annotate:
        for row in range(n_params):
            for col in range(n_params):
                # imshow puts column indices on x and row indices on y.
                ax.text(col, row, f'{corr[row, col]:.2f}',
                        ha='center', va='center', fontsize=10, color='black')

    plt.tight_layout()
    if save_folder is not None:
        os.makedirs(save_folder, exist_ok=True)
        plt.savefig(f'{save_folder}/parameter_correlation_matrix.png', dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close(fig)
|
|
1445
|
+
|
|
1446
|
+
def plot_streamline_covariance_samples(best_opt_params,
                                       fixed_params,
                                       data,
                                       uncertainties,
                                       distance,
                                       covariance_result,
                                       v_lsr=None,
                                       n_samples=100,
                                       save_folder=None,
                                       show=True):
    """
    Sample parameter sets from the covariance matrix, evaluate streamlines from those sets, and plot them all together.

    Parameters
    ----------
    best_opt_params : dict
        Best-fit optimised parameters in the original user-supplied parameterisation
        (e.g. with 'rc' or 'omega'). Used only to evaluate and plot the best-fit streamline.
    fixed_params : dict
        Fixed model parameters in the original user-supplied parameterisation.
        Used only to evaluate and plot the best-fit streamline.
    data : tuple of arrays (ra_data, dec_data, v_data)
    uncertainties : tuple of arrays (ra_sigma, dec_sigma, v_sigma)
    distance : float, distance in pc
    covariance_result : namedtuple from
        result = fit_streamline(...)
        covariance_result = result.cov_result
        Its best_opt_params, opt_keys, fixed_params and covariance drive the sampling.
    v_lsr : float or None, km/s
        If given, drawn as a dashed horizontal line on the velocity panel.
    n_samples : int
    save_folder : str or None
        If given, the figure is saved as streamline_covariance_samples.png there.
    show : bool
        Show the figure (default, as before). When False the figure is closed
        instead, so repeated calls do not accumulate open figures.
    """
    cov_best_params = covariance_result.best_opt_params
    cov_opt_keys = covariance_result.opt_keys
    cov_fixed_params = covariance_result.fixed_params
    cov = covariance_result.covariance
    _, streamline_samples = generate_streamline_samples(
        best_opt_params=cov_best_params,
        covariance=cov,
        opt_keys=cov_opt_keys,
        fixed_params=cov_fixed_params,
        distance=distance,
        param_bounds=None,
        n_samples=n_samples
    )

    fig, (ax_sky, ax_v) = plt.subplots(1, 2, figsize=(10, 5))
    # Accept lists as well as arrays (the arithmetic below needs ndarrays).
    ra_data, dec_data, v_data = (np.asarray(a, dtype=float) for a in data)
    ra_sigma, dec_sigma, v_sigma = (np.asarray(a, dtype=float) for a in uncertainties)

    # plot samples streamlines
    for streamline in streamline_samples:
        ra = streamline['ra']
        dec = streamline['dec']
        vel = streamline['v']

        rproj = np.sqrt(ra**2 + dec**2)
        order = np.argsort(rproj)
        ax_sky.plot(ra, dec, color='tab:blue', alpha=0.1, lw=1)
        ax_v.plot(rproj[order], vel[order], color='tab:blue', alpha=0.1, lw=1)

    # plot best fit streamline
    best_opt_full_params, _, _ = gradient_descent.prepare_model_params(best_opt_params, fixed_params)
    ra_best, dec_best, v_best, valid_mask_best, _ = gradient_descent.forward_model(best_opt_full_params, distance)
    valid_mask_best = np.asarray(valid_mask_best).astype(bool)
    ra_best = np.asarray(ra_best, dtype=float)[valid_mask_best]
    dec_best = np.asarray(dec_best, dtype=float)[valid_mask_best]
    v_best = np.asarray(v_best, dtype=float)[valid_mask_best]
    rproj_best = np.sqrt(ra_best**2 + dec_best**2)
    order_best = np.argsort(rproj_best)
    ax_sky.plot(ra_best, dec_best, color='blue', lw=2, label='Best-fit')
    ax_v.plot(rproj_best[order_best], v_best[order_best], color='blue', lw=2, label='Best-fit')

    # plot data
    ax_sky.errorbar(
        ra_data, dec_data, xerr=ra_sigma, yerr=dec_sigma,
        fmt='o', color='red', ecolor='red', ms=4, alpha=0.9, label='Data'
    )
    rproj_data = np.sqrt(ra_data**2 + dec_data**2)
    # First-order propagation of the RA/Dec errors into r; the floor matches
    # plot_vel_radius and avoids dividing by zero for a point at the origin.
    rproj_sigma = np.sqrt((ra_data * ra_sigma)**2 + (dec_data * dec_sigma)**2) / np.maximum(rproj_data, 1e-8)
    order_data = np.argsort(rproj_data)
    ax_v.errorbar(
        rproj_data[order_data], v_data[order_data], yerr=v_sigma[order_data], xerr=rproj_sigma[order_data],
        fmt='o', color='red', ecolor='red', ms=4, alpha=0.9, label='Data'
    )
    if v_lsr is not None:
        # Draw across the current x range without letting hlines widen it.
        xmin, xmax = ax_v.get_xlim()
        ax_v.hlines(v_lsr, xmin=xmin, xmax=xmax, colors='k', linestyles='--', alpha=0.6,)
        ax_v.set_xlim(xmin, xmax)

    # finalise plots
    ax_sky.invert_xaxis()
    ax_sky.set_xlabel('RA Offset (arcsec)')
    ax_sky.set_ylabel('Dec Offset (arcsec)')
    ax_sky.set_title('Covariance Sampling')
    ax_sky.legend()

    ax_v.set_xlabel('Projected distance(arcsec)')
    ax_v.set_ylabel('Velocity (km/s)')
    ax_v.set_title('Covariance Sampling')
    ax_v.legend()

    plt.tight_layout()

    if save_folder is not None:
        os.makedirs(save_folder, exist_ok=True)
        plt.savefig(f'{save_folder}/streamline_covariance_samples.png', dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close(fig)
|
|
1558
|
+
|
|
1559
|
+
|
|
1560
|
+
def generate_streamline_samples(best_opt_params, covariance, opt_keys, fixed_params, distance, param_bounds=None, n_samples=100):
    """
    Draw parameter vectors from a covariance matrix and evaluate one streamline per draw.

    Thin wrapper chaining sample_parameter_sets_from_covariance() and
    evaluate_streamlines_samples().

    Returns
    -------
    samples : ndarray, shape (n_samples, n_params)
        Parameter vectors from sample_parameter_sets_from_covariance().
    streamlines : list of dict
        The matching output of evaluate_streamlines_samples().
    """
    param_draws = sample_parameter_sets_from_covariance(
        best_opt_params,
        covariance,
        opt_keys,
        param_bounds=param_bounds,
        n_samples=n_samples,
    )
    evaluated = evaluate_streamlines_samples(param_draws, opt_keys, fixed_params, distance)
    return param_draws, evaluated
|
|
1578
|
+
|
|
1579
|
+
def evaluate_streamlines_samples(param_samples, opt_keys, fixed_params, distance):
    """
    Evaluate the streamline forward model for each sampled parameter vector.

    Parameters
    ----------
    param_samples : iterable of sequence of float
        Sampled parameter vectors (e.g. the rows returned by
        sample_parameter_sets_from_covariance()). Entry ``j`` of each vector is
        the value of ``opt_keys[j]``.
    opt_keys : sequence of str
        Names of the sampled parameters.
    fixed_params : dict
        Parameters that are not sampled; combined with each sample through
        gradient_descent.prepare_model_params().
    distance
        Passed unchanged to gradient_descent.forward_model().

    Returns
    -------
    streamlines : list of dict
        One entry per sample. Each entry contains:
            {
                "ra": ...,
                "dec": ...,
                "v": ...,
                "dmetric": ...
            }
        "ra", "dec" and "v" are float NumPy arrays of the forward-model output,
        restricted to the points flagged in the model's valid mask. "dmetric" is
        the distance metric that extract_streamline.get_distance_metric()
        returns for those points, as a float NumPy array.
    """
    streamlines = []
    for sample in param_samples:
        sample_params = {key: float(value) for key, value in zip(opt_keys, sample)}
        sample_params_full, _, _ = gradient_descent.prepare_model_params(sample_params, fixed_params)

        # The fifth output (err) is not used here.
        # NOTE(review): samples are kept regardless of err — confirm that is intended.
        ra, dec, vel, valid_mask, _err = gradient_descent.forward_model(sample_params_full, distance)

        # np.asarray accepts NumPy arrays, JAX arrays and plain sequences alike,
        # and yields a real NumPy boolean index for the NumPy copies below.
        valid = np.asarray(valid_mask, dtype=bool)
        ra = np.asarray(ra, dtype=float)[valid]
        dec = np.asarray(dec, dtype=float)[valid]
        vel = np.asarray(vel, dtype=float)[valid]

        dmetric, _trace = extract_streamline.get_distance_metric(ra, dec)

        streamlines.append(
            {
                "ra": ra,
                "dec": dec,
                "v": vel,
                "dmetric": np.asarray(dmetric, dtype=float),
            }
        )

    return streamlines
|
|
1622
|
+
|
|
1623
|
+
def sample_parameter_sets_from_covariance(best_params, covariance, opt_keys, param_bounds=None, n_samples=100, seed=42):
    """
    Draw parameter samples from a multivariate normal centred on the best fit.

    Parameters
    ----------
    best_params : dict
        Parameter values keyed by name. ``best_params[key]`` for each key in
        ``opt_keys`` forms the mean vector.
    covariance : array_like, shape (n_params, n_params)
        Covariance matrix ordered as ``opt_keys``. It is symmetrised before
        sampling, which is an exact no-op for an already symmetric matrix.
    opt_keys : sequence of str
        Names of the sampled parameters; column ``j`` of the result
        corresponds to ``opt_keys[j]``.
    param_bounds : dict, optional
        Bounds keyed by parameter name. Angle bounds are filled in by
        gradient_descent.auto_fill_angle_bounds() and units are handled by
        gradient_descent.convert_and_strip_bound_units(). Each parameter that
        then has an entry is clipped to its ``(low, high)`` range.
    n_samples : int, default 100
        Number of samples to draw.
    seed : int, numpy.random.Generator or None, default 42
        Passed to numpy.random.default_rng(). The fixed default makes repeated
        calls return identical samples.

    Returns
    -------
    samples : ndarray, shape (n_samples, n_params)
        Sampled parameter vectors.

    Raises
    ------
    ValueError
        If ``covariance`` is not a square matrix of size ``len(opt_keys)``.
    """
    rng = np.random.default_rng(seed)
    mu = np.array([best_params[key] for key in opt_keys], dtype=float)

    cov = np.asarray(covariance, dtype=float)
    n_params = mu.shape[0]
    if cov.shape != (n_params, n_params):
        raise ValueError(
            f"covariance must have shape ({n_params}, {n_params}) to match opt_keys, "
            f"got {cov.shape}"
        )
    # Numerically computed covariances (e.g. from a matrix inverse) can be
    # asymmetric at round-off level. multivariate_normal expects a symmetric
    # PSD matrix and warns otherwise, so symmetrise first.
    cov = 0.5 * (cov + cov.T)

    samples = rng.multivariate_normal(mu, cov, size=n_samples)

    param_bounds = gradient_descent.auto_fill_angle_bounds(set(opt_keys), param_bounds)
    param_bounds = gradient_descent.convert_and_strip_bound_units(param_bounds)
    for j, key in enumerate(opt_keys):
        if key in param_bounds:
            low, high = param_bounds[key]
            samples[:, j] = np.clip(samples[:, j], low, high)

    return samples
|
|
1650
|
+
|
|
1651
|
+
def plot_param_optimisation_history(save_folder='sting_results'):
    '''Plot the history of parameter optimisation from logs saved during optimisation.

    save_folder should be the same as the one used during optimisation and must
    contain "optimisation_log.csv", which is the only file read here. The log
    needs "epoch" and "loss" columns; every other column is treated as a
    parameter and gets its own panel.

    The top panel shows the loss, drawn by plot_loss_panel(). Below it is one
    panel per parameter, and all panels share the epoch axis. The figure is
    saved as "parameter_optimisation_history.png" in save_folder (dpi=300) and
    then shown with plt.show().

    Parameters
    ----------
    save_folder : str or os.PathLike, default 'sting_results'
        Folder containing the log. The figure is written to the same folder.

    Returns
    -------
    None
        If the log file is missing, an error message is printed and the
        function returns without plotting.
    '''
    try:
        optimisation_log = load_optimisation_log(save_folder)
    except FileNotFoundError:
        print(f"Error: Could not find 'optimisation_log.csv' and/or 'optimisation_trace.csv' in {save_folder}")
        return

    epochs = optimisation_log["epoch"].values
    loss = optimisation_log["loss"].values
    param_names = [c for c in optimisation_log.columns if c not in ("epoch", "loss")]

    # One row for the loss panel plus one row per parameter. squeeze=False
    # keeps `axes` 2-D even for a single row, so indexing below is uniform.
    n_rows = len(param_names) + 1
    fig, axes = plt.subplots(n_rows, 1, figsize=(8, 2 * n_rows), sharex=True, squeeze=False)
    axes = axes[:, 0]

    plot_loss_panel(axes[0], epochs, loss)

    for ax, param in zip(axes[1:], param_names):
        ax.plot(epochs, optimisation_log[param].values)
        ax.set_ylabel(param)
        ax.grid(True)

    # Panels share the x-axis, so only the bottom one needs a label.
    axes[-1].set_xlabel("Epoch")

    fig.tight_layout()
    if save_folder is not None:
        os.makedirs(save_folder, exist_ok=True)
        fig.savefig(os.path.join(save_folder, 'parameter_optimisation_history.png'), dpi=300, bbox_inches='tight')
    plt.show()
|
|
1687
|
+
|
|
1688
|
+
def plot_loss_panel(ax, epochs, loss):
    """
    Draw a loss curve on ``ax`` and highlight the epoch with the lowest loss.

    The curve is drawn in black and the y-axis is set to a logarithmic scale.
    The minimum is marked with a green point whose legend label is
    "Best Epoch: <epoch>". If several epochs tie for the minimum, the first
    one (per ``np.argmin``) is marked.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    epochs : array_like
        Epoch number of each loss value, indexable by position.
    loss : array_like
        Loss value per epoch, aligned with ``epochs``.
    """
    best_loss, best_pos = np.min(loss), np.argmin(loss)
    marked_epoch = epochs[best_pos]

    # Curve first, then the highlighted minimum on top of it.
    ax.plot(epochs, loss, color="black")
    ax.scatter(
        marked_epoch,
        best_loss,
        color="green",
        label=f"Best Epoch: {marked_epoch}",
    )

    # Axis cosmetics; legend last so it picks up the scatter label.
    ax.set_yscale("log")
    ax.set_ylabel("Loss")
    ax.grid(True, alpha=0.3)
    ax.legend()
|
|
1698
|
+
|
|
1699
|
+
|
|
1700
|
+
def load_optimisation_log(logs_dir):
    """
    Read the optimisation log CSV from ``logs_dir``.

    Parameters
    ----------
    logs_dir : str or os.PathLike
        Folder expected to contain "optimisation_log.csv".

    Returns
    -------
    pandas.DataFrame
        Contents of ``<logs_dir>/optimisation_log.csv`` as parsed by
        pandas.read_csv.

    Raises
    ------
    FileNotFoundError
        If the file does not exist (raised by pandas.read_csv).
    """
    return pd.read_csv(os.path.join(logs_dir, "optimisation_log.csv"))
|
|
1704
|
+
|
|
1705
|
+
|