sting 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sting/__init__.py +8 -0
- sting/_version.py +24 -0
- sting/errors.py +677 -0
- sting/extract_streamline.py +425 -0
- sting/gradient_descent.py +1776 -0
- sting/outputs.py +1705 -0
- sting/stream_lines_grad.py +448 -0
- sting-0.2.0.dist-info/METADATA +251 -0
- sting-0.2.0.dist-info/RECORD +14 -0
- sting-0.2.0.dist-info/WHEEL +5 -0
- sting-0.2.0.dist-info/licenses/LICENCE +21 -0
- sting-0.2.0.dist-info/scm_file_list.json +26 -0
- sting-0.2.0.dist-info/scm_version.json +8 -0
- sting-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,425 @@
|
|
|
1
|
+
'''
|
|
2
|
+
This file contains functions to extract streamer emission from a spectral data cube,
and to reduce that emission to a 1D streamline.
|
|
4
|
+
'''
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from astropy import units as u
|
|
8
|
+
from astropy.coordinates import SkyCoord, FK5
|
|
9
|
+
from collections import namedtuple
|
|
10
|
+
import jax.numpy as jnp
|
|
11
|
+
import jax
|
|
12
|
+
|
|
13
|
+
# Sentinel "effectively infinite" value. safe_percentile replaces non-finite
# entries with BIG so they sort to the end of the array, and
# get_distance_metric uses it for points with a non-finite position or a
# radius below 1e-12 (their metric ends up >= BIG).
BIG = 1e30
|
|
14
|
+
|
|
15
|
+
def to_float64(value):
    """Return ``value`` converted to a JAX array with dtype ``jnp.float64``.

    NOTE(review): JAX silently falls back to float32 unless 64-bit mode
    (``jax_enable_x64``) is enabled -- confirm callers enable it.
    """
    return jnp.asarray(value, jnp.float64)
|
|
17
|
+
|
|
18
|
+
# Container for the data-only quantities precomputed by prepare_data():
#   ra_data, dec_data, v_data                   observed values as float64 arrays
#   ra_sigma_safe, dec_sigma_safe, v_sigma_safe uncertainties floored at a small epsilon
#   dmetric_data, data_finite_mask              distance metric per point and validity mask
#   data_min, data_max                          extent of the metric over valid points
#   r_proj_data, theta_proj_data                projected polar coordinates of the data
PreparedData = namedtuple(
    'PreparedData',
    field_names=(
        'ra_data', 'dec_data', 'v_data',
        'ra_sigma_safe', 'dec_sigma_safe', 'v_sigma_safe',
        'dmetric_data', 'data_finite_mask',
        'data_min', 'data_max',
        'r_proj_data', 'theta_proj_data',
    ),
)
|
|
25
|
+
|
|
26
|
+
# Result of reduce_to_1D():
#   pc_coords                     full point cloud, shape (3, N)
#   ra_data, dec_data, v_data     per-bin flux-weighted means (arcsec, arcsec, km/s)
#   ra_sigma, dec_sigma, v_sigma  per-bin flux-weighted standard deviations
#   data, uncertainties           the same per-bin arrays grouped as 3-tuples
StreamerData = namedtuple(
    'StreamerData',
    field_names=(
        'pc_coords',
        'ra_data', 'dec_data', 'v_data',
        'ra_sigma', 'dec_sigma', 'v_sigma',
        'data', 'uncertainties',
    ),
)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
#@jax.jit
|
|
36
|
+
def wrap_to_pi(angle):
    '''Map angle(s) in radians onto the half-open interval [-pi, pi).'''
    full_turn = 2.0 * jnp.pi
    return (angle + jnp.pi) % full_turn - jnp.pi
|
|
39
|
+
|
|
40
|
+
#@jax.jit
|
|
41
|
+
def circular_median(theta_vals, weights):
    '''Weighted median of angles that is safe across the +/-pi branch cut.

    The angles are unwrapped around their weighted circular mean, a weighted
    linear median of the unwrapped values is taken, and the result is wrapped
    back into [-pi, pi). Entries with weight 0 do not contribute.

    Parameters
    ----------
    theta_vals : array of angles in radians (must not contain NaN)
    weights : array of weights with the same shape as theta_vals
        (get_distance_metric passes a 0/1 mask)

    Returns
    -------
    The weighted median angle, in [-pi, pi).
    '''
    # normalise; the tiny offset avoids 0/0 when every weight is zero
    w = weights / (jnp.sum(weights) + 1e-12)

    # weighted circular mean, used as the anchor for unwrapping
    anchor = jnp.arctan2(
        jnp.sum(w * jnp.sin(theta_vals)),
        jnp.sum(w * jnp.cos(theta_vals)),
    )

    # anchor + deviation in [-pi, pi): all angles now lie on one contiguous
    # interval, so an ordinary sort is meaningful
    unwrapped = anchor + wrap_to_pi(theta_vals - anchor)

    order = jnp.argsort(unwrapped)
    ordered_angles = unwrapped[order]
    ordered_w = w[order]

    # first sorted entry at which the running weight reaches half the total
    running = jnp.cumsum(ordered_w)
    half_total = 0.5 * jnp.sum(ordered_w)
    pick = jnp.argmax(running >= half_total)

    return wrap_to_pi(ordered_angles[pick])
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
#@jax.jit
|
|
62
|
+
def wrap_to_pi_numpy(angle):
    '''NumPy counterpart of wrap_to_pi: map angle(s) in radians into [-pi, pi).'''
    full_turn = 2.0 * np.pi
    return (angle + np.pi) % full_turn - np.pi
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def extract_streamer_subcube(cube, vmin=None, vmax=None, xmin=None, xmax=None, ymin=None, ymax=None, rms_thresh=None):
    """Cut the streamer region out of ``cube`` and optionally mask faint emission.

    Each step runs only when all of its arguments are supplied:

    1. Spectral cut -- ``spectral_slab(vmin, vmax)`` when both ``vmin`` and
       ``vmax`` are given.
    2. Spatial cut -- when ``xmin``, ``xmax``, ``ymin`` and ``ymax`` are all
       given, they are added to the celestial WCS reference value (CRVAL) to
       form two corners, and the cube is cropped to the integer pixel box
       enclosing both corners (clipped to the cube's extent).
    3. Noise mask -- when ``rms_thresh`` is given, the cube is masked to the
       voxels whose value exceeds ``rms_thresh`` times ``cube.mad_std()``.

    Parameters
    ----------
    cube : SpectralCube-like object (uses ``spectral_slab``, ``wcs``,
        ``shape``, ``mad_std`` and ``with_mask``)
    vmin, vmax : spectral limits passed straight to ``spectral_slab``
    xmin, xmax, ymin, ymax : offsets added to the reference RA / Dec
        (presumably angular Quantities -- TODO confirm with callers)
    rms_thresh : threshold in units of the ``mad_std`` noise estimate

    Returns
    -------
    The cropped and/or masked cube (``cube`` itself when no limits are given).
    """
    subcube = cube

    if vmin is not None and vmax is not None:
        subcube = subcube.spectral_slab(vmin, vmax)

    if all(limit is not None for limit in (xmin, xmax, ymin, ymax)):
        wcs2d = subcube.wcs.celestial
        n_y = subcube.shape[1]
        n_x = subcube.shape[2]

        # reference position: CRVAL of the celestial WCS, interpreted as degrees
        crval_ra, crval_dec = wcs2d.wcs.crval
        origin = SkyCoord(crval_ra*u.deg, crval_dec*u.deg, frame=FK5)

        # NOTE(review): the x offsets are added straight to RA with no
        # 1/cos(Dec) factor, so they are coordinate offsets rather than
        # on-sky angular offsets; FK5 is also assumed regardless of the
        # cube's RADESYS. Confirm both are intended.
        lower_left = SkyCoord(origin.ra + xmin, origin.dec + ymin, frame=FK5)
        upper_right = SkyCoord(origin.ra + xmax, origin.dec + ymax, frame=FK5)
        px_a, py_a = wcs2d.world_to_pixel(lower_left)
        px_b, py_b = wcs2d.world_to_pixel(upper_right)

        # smallest integer pixel box containing both corners, clipped to the cube
        x_lo = max(0, int(np.floor(min(px_a, px_b))))
        x_hi = min(n_x, int(np.ceil(max(px_a, px_b))))
        y_lo = max(0, int(np.floor(min(py_a, py_b))))
        y_hi = min(n_y, int(np.ceil(max(py_a, py_b))))
        subcube = subcube[:, y_lo:y_hi, x_lo:x_hi]

    if rms_thresh is not None:
        noise = subcube.mad_std()
        subcube = subcube.with_mask(subcube > rms_thresh*noise)

    return subcube
|
|
93
|
+
|
|
94
|
+
def _weighted_bin_stats(pc_coords, flux, distance_metric, usable, partitions):
    '''Flux-weighted mean and standard deviation of pc_coords in each metric bin.

    Bin i holds the ``usable`` points with
    ``partitions[i] < metric <= partitions[i+1]``; bin 0 also includes its
    lower edge ``partitions[0]`` so the minimum-metric point is kept.

    Returns two arrays of shape (3, len(partitions) - 1). A bin that is empty
    or whose fluxes sum to zero is filled with NaN and a RuntimeWarning is
    issued (np.average would otherwise raise ZeroDivisionError).
    '''
    import warnings

    n_bins = len(partitions) - 1
    points = pc_coords.T  # shape (n_points, 3)
    means = np.full((3, n_bins), np.nan)
    stds = np.full((3, n_bins), np.nan)

    for i in range(n_bins):
        lower, upper = partitions[i], partitions[i + 1]
        above_lower = (distance_metric >= lower) if i == 0 else (distance_metric > lower)
        in_bin = usable & above_lower & (distance_metric <= upper)
        bin_points = points[in_bin]
        bin_flux = flux[in_bin]

        if bin_points.shape[0] == 0 or np.sum(bin_flux, dtype=np.float64) == 0:
            warnings.warn(
                f'distance-metric bin {i} has no points or zero total flux; '
                'its mean and sigma are set to NaN',
                RuntimeWarning,
                stacklevel=3,
            )
            continue

        mean = np.average(bin_points, axis=0, weights=bin_flux)
        means[:, i] = mean
        stds[:, i] = np.sqrt(np.average((bin_points - mean)**2, axis=0, weights=bin_flux))

    return means, stds


def reduce_to_1D(streamer_cube, yso_centre, n_elements=10):
    '''
    Reduce a cube of emission to a 1D 'streamline' by flux-weighted means.

    Every non-NaN voxel of ``streamer_cube`` becomes a point (RA offset,
    Dec offset, velocity), with the offsets measured from ``yso_centre``.
    The points are split into ``n_elements`` bins at the 0, 100/n_elements,
    ..., 100 percentiles of the distance metric from ``get_distance_metric``.
    Each bin is then summarised by its flux-weighted mean and standard
    deviation.

    Parameters
    ----------
    streamer_cube : SpectralCube object, should contain only streamer emission
    yso_centre : SkyCoord, the coordinates of the star, used to compute RA and Dec offsets in arcsec
    n_elements : int, number of elements to reduce the cube to

    Returns
    -------
    StreamerData
        Named tuple with fields:
        - pc_coords : full point cloud array, shape (3, N)
        - ra_data : RA offsets of bin means (arcsec), shape (n_elements,)
        - dec_data : Dec offsets of bin means (arcsec), shape (n_elements,)
        - v_data : velocities of bin means (km/s), shape (n_elements,)
        - ra_sigma : RA standard deviations per bin (arcsec), shape (n_elements,)
        - dec_sigma : Dec standard deviations per bin (arcsec), shape (n_elements,)
        - v_sigma : velocity standard deviations per bin (km/s), shape (n_elements,)
        - data : (ra_data, dec_data, v_data) tuple
        - uncertainties: (ra_sigma, dec_sigma, v_sigma) tuple
        The per-bin arrays run from the largest to the smallest distance
        metric, i.e. towards the star.

    Raises
    ------
    ValueError
        If ``n_elements < 1``, or if no point has a usable distance metric.

    Notes
    -----
    - Points that get_distance_metric flags with a metric >= BIG (non-finite
      offsets, or offsets numerically at the YSO position) stay in
      ``pc_coords`` but are not binned.
    - The first bin includes its lower edge, so the point with the smallest
      metric is not lost.
    - A bin that is empty, or whose fluxes sum to zero, gets NaN for its
      mean and sigma, and a RuntimeWarning is issued.
    - Raw voxel values are used as weights; presumably they should be
      positive (e.g. after masking with a positive rms threshold) -- TODO
      confirm for unmasked cubes.
    '''
    if n_elements < 1:
        raise ValueError('n_elements must be >= 1')

    print('Starting reduction')
    _, ny, nx = streamer_cube.shape
    yso_centre_icrs = yso_centre.icrs

    # RA and Dec offset arrays in arcsec relative to the YSO centre
    y_indices, x_indices = np.mgrid[0:ny, 0:nx]
    celestial_wcs = streamer_cube.wcs.celestial
    world_coords = celestial_wcs.pixel_to_world_values(x_indices.ravel(), y_indices.ravel())
    ra_unit = u.Unit(streamer_cube.header.get('CUNIT1', celestial_wcs.world_axis_units[0]))
    dec_unit = u.Unit(streamer_cube.header.get('CUNIT2', celestial_wcs.world_axis_units[1]))
    # NOTE(review): the cube's celestial frame is assumed to be ICRS here
    world_sky = SkyCoord(
        ra=world_coords[0] * ra_unit,
        dec=world_coords[1] * dec_unit,
        frame='icrs'
    )
    dra, ddec = yso_centre_icrs.spherical_offsets_to(world_sky)
    ra_coords = dra.to(u.arcsec).value.reshape(ny, nx)
    dec_coords = ddec.to(u.arcsec).value.reshape(ny, nx)

    # spectral axis converted to km/s (absolute values, not relative to any channel)
    spectral_unit = u.Unit(streamer_cube.header.get('CUNIT3', streamer_cube.spectral_axis.unit))
    spectral_axis = streamer_cube.spectral_axis.to(spectral_unit)
    v_coords = spectral_axis.to(u.km / u.s).value

    # one point per non-NaN voxel; np.nonzero returns indices in C order,
    # which matches the order of pcloud[valid]
    pcloud = np.array(streamer_cube)
    valid = ~np.isnan(pcloud)
    flux = pcloud[valid]
    pc_z, pc_y, pc_x = np.nonzero(valid)
    pc_coords = np.array([ra_coords[pc_y, pc_x], dec_coords[pc_y, pc_x], v_coords[pc_z]])  # shape (3, n_points)

    # percentile partitions of the distance metric over the usable points only;
    # get_distance_metric marks unusable points with values >= BIG (not NaN)
    distance_metric, _ = get_distance_metric(pc_coords[0], pc_coords[1], n_elements=n_elements)
    distance_metric = np.asarray(distance_metric)
    usable = np.isfinite(distance_metric) & (distance_metric < BIG)
    if not usable.any():
        raise ValueError('streamer_cube has no points with a usable distance metric')
    partitions = np.percentile(distance_metric[usable], np.linspace(0, 100, n_elements + 1))

    pc_means, pc_stds = _weighted_bin_stats(pc_coords, flux, distance_metric, usable, partitions)

    # flip arrays so that they go from large to small distance (towards star)
    ra_data, dec_data, v_data = pc_means[:, ::-1]
    ra_sigma, dec_sigma, v_sigma = pc_stds[:, ::-1]

    return StreamerData(
        pc_coords=pc_coords,
        ra_data=ra_data, dec_data=dec_data, v_data=v_data,
        ra_sigma=ra_sigma, dec_sigma=dec_sigma, v_sigma=v_sigma,
        data=(ra_data, dec_data, v_data),
        uncertainties=(ra_sigma, dec_sigma, v_sigma),
    )
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
#@jax.jit
|
|
190
|
+
def safe_percentile(values, percentile):
    """
    Percentile of the valid entries of ``values``, written only with JAX ops.

    The function uses no boolean indexing, so array shapes never depend on
    the data (the original intent is to be jax/jit-safe).

    Non-finite entries are replaced by the BIG sentinel and sorted to the
    end. Only entries strictly below BIG count as valid. The result is the
    valid entry at sorted position ``floor(percentile / 100 * n_valid)``,
    clipped to ``[0, n_valid - 1]``. This is a lower nearest-rank percentile,
    not the interpolated definition used by ``numpy.percentile``. With no
    valid entries, the smallest sorted value is returned (BIG when every
    entry is non-finite).
    """
    big = to_float64(BIG)
    ordered = jnp.sort(jnp.where(jnp.isfinite(values), values, big))

    # valid entries occupy the front of the sorted array
    n_valid = jnp.sum(ordered < big)

    rank = jnp.floor(percentile / 100.0 * n_valid).astype(jnp.int32)
    rank = jnp.clip(rank, 0, jnp.maximum(n_valid - 1, 0))
    return ordered[rank]
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
#@jax.jit
|
|
214
|
+
def get_distance_metric(ra_coords, dec_coords, n_elements=10):
    '''
    Compute a radial + angular distance metric used to bin a point cloud.

    The points are first converted to polar coordinates (r, theta) with
    cartesian_to_polar. The metric for each point is then

        r * sqrt(1 + (theta_weight * |wrap_to_pi(theta - theta_ref)|)**2)

    ``theta_ref`` is the circular median angle of the "reference" points:
    the finite points whose r is within the ``100 / n_elements`` percentile
    of r (see safe_percentile). The angular deviation is measured
    circularly, which avoids branch-cut artifacts at +/-pi.

    Points with a non-finite r or theta get BIG as their metric. Points with
    r < 1e-12 (numerically at the origin) have r replaced by BIG, so their
    metric is >= BIG.

    Parameters
    ----------
    ra_coords, dec_coords : arrays of RA / Dec offsets of the points
    n_elements : int
        Number of bins the caller intends to use; it sets the percentile of
        r that selects the reference points.

    Returns
    -------
    distance_metric : array with the same shape as the polar radius
    trace : dict of diagnostic scalars
        Keys are n_points, n_finite_points, n_reference_points,
        r_percentile_thresh, r_thresh, theta_ref, theta_weight and
        close_point_count (the number of finite reference points).
    '''
    pc_r, pc_theta = cartesian_to_polar(ra_coords, dec_coords)
    # points at the origin have no meaningful angle: push them out to BIG
    pc_r = jnp.where(jnp.abs(pc_r) < 1e-12, to_float64(BIG), pc_r)

    finite_mask = jnp.isfinite(pc_r) & jnp.isfinite(pc_theta)

    # no finite points: every metric is BIG and the trace holds placeholders
    def empty_case(_):

        distance_metric = jnp.full_like(pc_r, to_float64(BIG))
        trace = {
            "n_points": jnp.array(pc_r.size, dtype=jnp.int32),
            "n_finite_points": jnp.array(0, dtype=jnp.int32),
            "n_reference_points": jnp.array(0, dtype=jnp.int32),
            "r_percentile_thresh": to_float64(0.0),
            "r_thresh": to_float64(BIG),
            "theta_ref": to_float64(0.0),
            "theta_weight": to_float64(1.0),
            "close_point_count": jnp.array(0, dtype=jnp.int32),
        }
        return distance_metric, trace

    def notempty_case(_):

        # same dtype as in empty_case so both lax.cond branches agree
        theta_weight = to_float64(1.0)  # maybe make this a tunable parameter
        finite_count = jnp.sum(finite_mask)

        percentile = 100.0 / n_elements
        r_thresh = safe_percentile(pc_r, percentile)
        small_enough_r = pc_r <= r_thresh
        # reference points: finite and within the radius threshold
        close_points = finite_mask & small_enough_r
        close_mask = close_points.astype(jnp.float64)
        # circular_median cannot take NaNs; non-finite points already have
        # weight 0, so replacing their angle with 0 does not affect the median
        pc_theta_no_nan = jnp.where(finite_mask, pc_theta, 0.0)
        theta_ref = circular_median(pc_theta_no_nan, weights=close_mask)

        # cyclic angular deviation in [0, pi]
        theta_dev = jnp.pi - jnp.abs(
            jnp.pi - jnp.abs(wrap_to_pi(pc_theta - theta_ref))
        )

        distance_metric = pc_r * jnp.sqrt(1.0 + (theta_weight * theta_dev) ** 2)
        distance_metric = jnp.where(finite_mask, distance_metric, to_float64(BIG))
        trace = {
            "n_points": jnp.array(pc_r.size, dtype=jnp.int32),
            "n_finite_points": jnp.array(finite_count, dtype=jnp.int32),
            "n_reference_points": jnp.array(jnp.sum(small_enough_r), dtype=jnp.int32),
            "r_percentile_thresh": to_float64(percentile),
            "r_thresh": r_thresh,
            "theta_ref": theta_ref,
            "theta_weight": theta_weight,
            # number of points that actually contributed to theta_ref
            # (previously this reported close_mask.size, i.e. all points)
            "close_point_count": jnp.sum(close_points, dtype=jnp.int32),
        }

        return distance_metric, trace

    return jax.lax.cond(finite_mask.any(), notempty_case, empty_case, operand=None)
|
|
277
|
+
|
|
278
|
+
#@jax.jit
|
|
279
|
+
def cartesian_to_polar(x, y):
    '''
    Convert cartesian coordinates (x, y), e.g. RA and Dec offsets, to polar
    coordinates.

    Returns
    -------
    (r, theta) : tuple
        ``r = sqrt(x**2 + y**2 + 1e-60)``. The tiny constant is added for
        gradient stability, so sqrt never sees exactly 0 at the origin.
        NOTE(review): 1e-60 underflows to 0 if float64 is unavailable.
        ``theta = arctan2(y, x)`` is the angle from the +x axis, in radians.
    '''
    squared_radius = x**2 + y**2 + to_float64(1e-60)
    radius = jnp.sqrt(squared_radius)
    angle = jnp.arctan2(y, x)
    return radius, angle
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def get_metric_partitions(pc_coords, n_elements):
    '''
    Compute percentile partitions for the streamline distance metric

    Parameters
    ----------
    pc_coords : array
        Point cloud coordinates. Index 0 = RA, Index 1 = Dec, Index 2 = velocity
    n_elements : int
        Number of partitions required

    Returns
    -------
    partitions : ndarray
        The 0, 100/n_elements, ..., 100 percentiles of the distance metric,
        shape (n_elements + 1,), float64. Only points whose metric is finite
        and below the BIG sentinel are used. If no point qualifies, an array
        of NaN is returned.

    Raises
    ------
    ValueError
        If n_elements < 1.
    '''
    if n_elements < 1:
        raise ValueError('n_elements must be >= 1')

    distance_metric, _ = get_distance_metric(pc_coords[0], pc_coords[1], n_elements=n_elements)
    distance_metric = np.asarray(distance_metric)

    # get_distance_metric flags unusable points with values >= BIG rather than
    # NaN, so an isfinite() test alone would keep them and stretch the
    # outermost partition to ~1e30
    usable = np.isfinite(distance_metric) & (distance_metric < BIG)
    usable_metric = distance_metric[usable]

    if usable_metric.size == 0:
        return np.full(n_elements + 1, np.nan, dtype=np.float64)

    b_per = np.linspace(0.0, 100.0, n_elements + 1)
    return np.asarray(np.percentile(usable_metric, b_per), dtype=np.float64)
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def get_metric_reference_trace(pc_coords, n_elements=10):
    '''Return the (theta_ref, theta_weight) that get_distance_metric uses for pc_coords.

    These values are needed to draw boundaries of constant metric (see
    sample_metric_boundary). Both are returned as Python floats.
    '''
    _, trace = get_distance_metric(pc_coords[0], pc_coords[1], n_elements=n_elements)
    return float(trace.get('theta_ref', 0.0)), float(trace.get('theta_weight', 1.0))
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
def sample_metric_boundary(partition_radius, theta_ref, theta_weight=1.0, n_samples=720):
    '''
    Sample the closed RA/Dec curve on which the distance metric equals
    ``partition_radius`` (for plotting partition boundaries).

    This inverts the metric ``r * sqrt(1 + (theta_weight * dtheta)**2)`` used
    by get_distance_metric. For each sampled angle theta, the boundary radius
    is ``partition_radius / sqrt(1 + (theta_weight * dtheta)**2)``, where
    dtheta is the circular deviation of theta from ``theta_ref``.

    Parameters
    ----------
    partition_radius : float, metric value of the boundary
    theta_ref : float, reference angle in radians
    theta_weight : float, angular weight of the metric
    n_samples : int, number of points on the curve (>= 4)

    Returns
    -------
    ra, dec : arrays of length n_samples
        The angles run from -pi to +pi inclusive. The last sample therefore
        coincides with the first, and the curve is closed when plotted.

    Raises
    ------
    ValueError
        If n_samples < 4.
    '''
    if n_samples < 4:
        raise ValueError('n_samples must be >= 4')

    # -pi and +pi are the same direction; including both endpoints closes the
    # curve (with endpoint=False the plotted line had a small gap)
    theta = jnp.linspace(-jnp.pi, jnp.pi, n_samples, endpoint=True)
    theta_dev = jnp.pi - jnp.abs(jnp.pi - jnp.abs(wrap_to_pi(theta - theta_ref)))
    radius = partition_radius / jnp.sqrt(1.0 + (theta_weight * theta_dev) ** 2)
    ra = radius * jnp.cos(theta)
    dec = radius * jnp.sin(theta)
    return ra, dec
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
def sample_metric_boundaries(pc_coords, partitions, n_samples=720):
    '''Sample one constant-metric boundary curve per entry of ``partitions``.

    The reference angle and weight are taken from the point cloud, using
    ``len(partitions) - 1`` as the number of elements.

    Returns
    -------
    curves : list of (ra, dec) pairs, one per partition value
    trace : dict with the 'theta_ref' and 'theta_weight' used for every curve
    '''
    n_bins = len(partitions) - 1
    theta_ref, theta_weight = get_metric_reference_trace(pc_coords, n_elements=n_bins)

    radii = np.asarray(partitions)
    curves = [
        sample_metric_boundary(radius, theta_ref, theta_weight=theta_weight, n_samples=n_samples)
        for radius in radii
    ]
    return curves, {'theta_ref': theta_ref, 'theta_weight': theta_weight}
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
def plot_metric_boundaries(ax, pc_coords, curves, color='lightgrey', linewidth=1, alpha=0.5, n_samples=720, zorder=1):
    '''Draw every (ra, dec) boundary curve in ``curves`` on ``ax``.

    ``ax`` is any object with a matplotlib-like ``plot`` method. ``pc_coords``
    and ``n_samples`` are not used here; presumably they are kept for
    interface compatibility. Returns ``curves`` unchanged.
    '''
    line_style = dict(color=color, linewidth=linewidth, alpha=alpha, zorder=zorder)
    for boundary in curves:
        ra, dec = boundary
        ax.plot(ra, dec, **line_style)
    return curves
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def prepare_data(data, uncertainties, n_elements):
    '''
    Precompute the constant, data-only quantities used by the gradient descent,
    so that later iterations do not have to recompute them.

    Parameters
    ----------
    data : tuple of arrays (ra_data, dec_data, v_data)
        Observed RA offset (arcsec), Dec offset (arcsec), velocity (km/s)
    uncertainties : tuple of arrays (ra_sigma, dec_sigma, v_sigma)
        Uncertainties on the data
    n_elements : int
        Number of elements the cube was reduced to; passed to
        get_distance_metric when computing the data's distance metric

    Returns
    -------
    PreparedData
        Container with the precomputed quantities:
        - ra_data, dec_data, v_data : inputs as float64 arrays
        - ra_sigma_safe, dec_sigma_safe, v_sigma_safe : uncertainties floored at 1e-8
        - dmetric_data : distance metric of each data point
        - data_finite_mask : True where RA, Dec and the metric are finite and
          the metric is below the BIG sentinel
        - data_min, data_max : metric extent over masked-in points
          (+inf / -inf if there are none)
        - r_proj_data, theta_proj_data : polar coordinates of (ra_data, dec_data)
    '''
    ra_data = jnp.asarray(data[0], dtype=jnp.float64)
    dec_data = jnp.asarray(data[1], dtype=jnp.float64)
    v_data = jnp.asarray(data[2], dtype=jnp.float64)

    ra_sigma = jnp.asarray(uncertainties[0], dtype=jnp.float64)
    dec_sigma = jnp.asarray(uncertainties[1], dtype=jnp.float64)
    v_sigma = jnp.asarray(uncertainties[2], dtype=jnp.float64)

    # floor the uncertainties so they can safely be divided by
    eps = jnp.asarray(1e-8, dtype=jnp.float64)
    ra_sigma_safe = jnp.maximum(ra_sigma, eps)
    dec_sigma_safe = jnp.maximum(dec_sigma, eps)
    v_sigma_safe = jnp.maximum(v_sigma, eps)

    dmetric_data, _ = get_distance_metric(ra_data, dec_data, n_elements=n_elements)
    # get_distance_metric flags unusable points (e.g. at the origin) with
    # values >= BIG rather than NaN, so they must be excluded explicitly or
    # data_max would become ~1e30
    data_finite_mask = (
        jnp.isfinite(ra_data)
        & jnp.isfinite(dec_data)
        & jnp.isfinite(dmetric_data)
        & (dmetric_data < to_float64(BIG))
    )

    data_metric_for_min = jnp.where(data_finite_mask, dmetric_data, jnp.inf)
    data_metric_for_max = jnp.where(data_finite_mask, dmetric_data, -jnp.inf)
    data_min = jnp.min(data_metric_for_min)
    data_max = jnp.max(data_metric_for_max)

    r_proj_data, theta_proj_data = cartesian_to_polar(ra_data, dec_data)

    return PreparedData(
        ra_data=ra_data,
        dec_data=dec_data,
        v_data=v_data,
        ra_sigma_safe=ra_sigma_safe,
        dec_sigma_safe=dec_sigma_safe,
        v_sigma_safe=v_sigma_safe,
        dmetric_data=dmetric_data,
        data_finite_mask=data_finite_mask,
        data_min=data_min,
        data_max=data_max,
        r_proj_data=r_proj_data,
        theta_proj_data=theta_proj_data,
    )
|
|
424
|
+
|
|
425
|
+
|