sting 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,425 @@
1
+ '''
2
+ This file contains functions to extract streamline emission from a data cube,
3
+ and extract a 1D streamline from that.
4
+ '''
5
+
6
+ import numpy as np
7
+ from astropy import units as u
8
+ from astropy.coordinates import SkyCoord, FK5
9
+ from collections import namedtuple
10
+ import jax.numpy as jnp
11
+ import jax
12
+
13
# Large finite sentinel value. safe_percentile() replaces non-finite entries
# with it (and treats anything >= it as invalid); get_distance_metric() assigns
# it to the radius of points at the origin and to the metric of non-finite
# points, so such entries sort after every genuine value.
BIG = 1.0e30
14
+
15
def to_float64(value):
    '''Return *value* as a JAX array with dtype float64.

    Note: if JAX 64-bit mode is disabled, JAX downcasts the result to float32.
    '''
    target_dtype = jnp.float64
    return jnp.asarray(value, dtype=target_dtype)
17
+
18
# Constant, data-only quantities precomputed by prepare_data():
#   ra_data, dec_data, v_data                    observed values (float64 arrays)
#   ra_sigma_safe, dec_sigma_safe, v_sigma_safe  uncertainties floored at 1e-8
#   dmetric_data, data_finite_mask               distance metric of each data point
#                                                and mask of fully finite points
#   data_min, data_max                           range of the metric over finite points
#   r_proj_data, theta_proj_data                 polar (r, theta) projection of RA/Dec
PreparedData = namedtuple(
    'PreparedData',
    'ra_data dec_data v_data '
    'ra_sigma_safe dec_sigma_safe v_sigma_safe '
    'dmetric_data data_finite_mask '
    'data_min data_max '
    'r_proj_data theta_proj_data',
)
25
+
26
# Result of reduce_to_1D():
#   pc_coords                    full point cloud, shape (3, N): RA off, Dec off, velocity
#   ra_data, dec_data, v_data    per-bin flux-weighted means (arcsec, arcsec, km/s)
#   ra_sigma, dec_sigma, v_sigma per-bin flux-weighted standard deviations
#   data, uncertainties          the means / standard deviations grouped as 3-tuples
StreamerData = namedtuple(
    'StreamerData',
    'pc_coords '
    'ra_data dec_data v_data '
    'ra_sigma dec_sigma v_sigma '
    'data uncertainties',
)
32
+
33
+
34
+
35
#@jax.jit
def wrap_to_pi(angle):
    '''Map angle(s) in radians onto the half-open interval [-pi, pi).'''
    full_turn = 2.0 * jnp.pi
    shifted = angle + jnp.pi
    return shifted % full_turn - jnp.pi
39
+
40
#@jax.jit
def circular_median(theta_vals, weights):
    '''
    Weighted median of angles that is safe across the +/-pi branch cut.

    The angles are unwrapped around their weighted circular mean direction,
    a weighted linear median of the unwrapped values is taken, and the result
    is wrapped back into [-pi, pi). Angles with weight 0 do not contribute to
    the median. theta_vals must not contain NaNs.
    '''
    # normalise the weights; the tiny offset avoids 0/0 for an all-zero weight vector
    norm_weights = weights / (jnp.sum(weights) + 1e-12)

    # weighted circular mean direction, used as the centre of the unwrapping
    mean_sin = jnp.sum(norm_weights * jnp.sin(theta_vals))
    mean_cos = jnp.sum(norm_weights * jnp.cos(theta_vals))
    anchor = jnp.arctan2(mean_sin, mean_cos)
    unwrapped = anchor + wrap_to_pi(theta_vals - anchor)

    # weighted median: first sorted value whose running weight reaches half the total
    order = jnp.argsort(unwrapped)
    ordered_vals = unwrapped[order]
    ordered_weights = norm_weights[order]
    running_total = jnp.cumsum(ordered_weights)
    half_total = 0.5 * jnp.sum(ordered_weights)
    pick = jnp.argmax(running_total >= half_total)
    return wrap_to_pi(ordered_vals[pick])
59
+
60
+
61
def wrap_to_pi_numpy(angle):
    '''NumPy counterpart of wrap_to_pi: map angle(s) in radians onto [-pi, pi).'''
    two_pi = 2.0 * np.pi
    return (angle + np.pi) % two_pi - np.pi
65
+
66
+
67
def extract_streamer_subcube(cube, vmin=None, vmax=None, xmin=None, xmax=None, ymin=None, ymax=None, rms_thresh=None):
    """
    Extract a subcube containing the streamer emission.

    Applies, in order, a spectral slab, a spatial box and a low-SNR mask.
    Each step is skipped unless all of its limits are given.

    Parameters
    ----------
    cube : SpectralCube
        Input data cube.
    vmin, vmax : Quantity, optional
        Spectral limits passed to ``spectral_slab``; both must be given.
    xmin, xmax, ymin, ymax : angular Quantity, optional
        Box limits given as on-sky offsets (longitude/RA offset, latitude/Dec
        offset) from the WCS reference position (CRVAL); all four must be given.
    rms_thresh : float, optional
        Voxels not strictly above ``rms_thresh * cube.mad_std()`` are masked.

    Returns
    -------
    SpectralCube
        The sliced and/or masked cube.

    Raises
    ------
    ValueError
        If the spatial box does not overlap the cube.
    """
    streamer_cube = cube
    if vmin is not None and vmax is not None:
        streamer_cube = streamer_cube.spectral_slab(vmin, vmax)

    if all(limit is not None for limit in (xmin, xmax, ymin, ymax)):
        celestial_wcs = streamer_cube.wcs.celestial
        ny, nx = streamer_cube.shape[1], streamer_cube.shape[2]
        # Reference position (CRVAL) as a SkyCoord in the WCS's own celestial
        # frame, instead of assuming FK5. CRPIX is 1-based; pixel_to_world is 0-based.
        crpix_x, crpix_y = celestial_wcs.wcs.crpix
        ref_coord = celestial_wcs.pixel_to_world(crpix_x - 1, crpix_y - 1)
        # Treat the limits as true on-sky offsets (spherical_offsets_by is the
        # inverse of SkyCoord.spherical_offsets_to), so the RA extent is not
        # shrunk by cos(Dec) as it is when adding offsets to RA directly.
        corner1 = ref_coord.spherical_offsets_by(xmin, ymin)  # 'bottom left'
        corner2 = ref_coord.spherical_offsets_by(xmax, ymax)  # 'top right'
        x1, y1 = celestial_wcs.world_to_pixel(corner1)
        x2, y2 = celestial_wcs.world_to_pixel(corner2)
        # min/max because RA usually increases towards decreasing pixel x
        xmin_pix = max(0, int(np.floor(min(x1, x2))))
        xmax_pix = min(nx, int(np.ceil(max(x1, x2))))
        ymin_pix = max(0, int(np.floor(min(y1, y2))))
        ymax_pix = min(ny, int(np.ceil(max(y1, y2))))
        if xmin_pix >= xmax_pix or ymin_pix >= ymax_pix:
            raise ValueError('the spatial limits do not overlap the cube')
        streamer_cube = streamer_cube[:, ymin_pix:ymax_pix, xmin_pix:xmax_pix]

    if rms_thresh is not None:
        rms_estimate = streamer_cube.mad_std()
        streamer_cube = streamer_cube.with_mask(streamer_cube > rms_thresh * rms_estimate)

    return streamer_cube
93
+
94
def reduce_to_1D(streamer_cube, yso_centre, n_elements=10):
    '''
    Reduce a cube of emission to a 1D 'streamline' by flux-weighted means.

    Every non-NaN voxel becomes a point (RA offset, Dec offset, velocity)
    weighted by its flux. The points are split into n_elements bins at
    percentiles of the distance metric from get_distance_metric, and the
    flux-weighted mean and standard deviation of each coordinate is taken per
    bin. The output runs from the largest metric (far from the star) to the
    smallest.

    Parameters
    ----------
    streamer_cube : SpectralCube object, should contain only streamer emission
    yso_centre : SkyCoord, the coordinates of the star, used to compute RA and Dec offsets in arcsec
    n_elements : int, number of elements to reduce the cube to (must be >= 1)

    Returns
    -------
    StreamerData
        Named tuple with fields:
        - pc_coords : full point cloud array, shape (3, N)
        - ra_data : RA offsets of bin means (arcsec), shape (n_elements,)
        - dec_data : Dec offsets of bin means (arcsec), shape (n_elements,)
        - v_data : velocities of bin means (km/s), shape (n_elements,)
        - ra_sigma : RA standard deviations per bin (arcsec), shape (n_elements,)
        - dec_sigma : Dec standard deviations per bin (arcsec), shape (n_elements,)
        - v_sigma : velocity standard deviations per bin (km/s), shape (n_elements,)
        - data : (ra_data, dec_data, v_data) tuple
        - uncertainties: (ra_sigma, dec_sigma, v_sigma) tuple
        Bins that receive no points, or whose fluxes sum to zero, are NaN.

    Raises
    ------
    ValueError
        If n_elements < 1, or if the cube has no non-NaN voxels.
    '''
    if n_elements < 1:
        raise ValueError('n_elements must be >= 1')
    print('Starting reduction')
    _, ny, nx = streamer_cube.shape
    yso_centre_icrs = yso_centre.icrs

    # RA and Dec offset maps (arcsec) of every spatial pixel relative to the YSO centre
    celestial_wcs = streamer_cube.wcs.celestial
    y_indices, x_indices = np.mgrid[0:ny, 0:nx]
    world_coords = celestial_wcs.pixel_to_world_values(x_indices.ravel(), y_indices.ravel())
    ra_unit = u.Unit(streamer_cube.header.get('CUNIT1', celestial_wcs.world_axis_units[0]))
    dec_unit = u.Unit(streamer_cube.header.get('CUNIT2', celestial_wcs.world_axis_units[1]))
    world_sky = SkyCoord(
        ra=world_coords[0] * ra_unit,
        dec=world_coords[1] * dec_unit,
        frame='icrs'
    )
    dra, ddec = yso_centre_icrs.spherical_offsets_to(world_sky)
    ra_coords = dra.to(u.arcsec).value.reshape(ny, nx)
    dec_coords = ddec.to(u.arcsec).value.reshape(ny, nx)

    # velocity of each channel, in km/s
    spectral_unit = u.Unit(streamer_cube.header.get('CUNIT3', streamer_cube.spectral_axis.unit))
    v_coords = streamer_cube.spectral_axis.to(spectral_unit).to(u.km / u.s).value

    # valid (non-NaN) voxels, their fluxes and their (z, y, x) indices
    pcloud = np.array(streamer_cube)
    rms_mask = ~np.isnan(pcloud)
    if not rms_mask.any():
        raise ValueError('streamer_cube contains no valid (non-NaN) voxels')
    flux = pcloud[rms_mask]
    pc_z, pc_y, pc_x = np.nonzero(rms_mask)  # same (row-major) order as boolean indexing

    pc_coords = np.array([
        ra_coords[pc_y, pc_x],
        dec_coords[pc_y, pc_x],
        v_coords[pc_z],
    ])  # shape (3, n_points)

    distance_metric, _ = get_distance_metric(pc_coords[0], pc_coords[1], n_elements=n_elements)
    pc_means, pc_stds = _weighted_bin_stats(pc_coords, flux, np.asarray(distance_metric), n_elements)

    # flip arrays so that they go from large to small distance (towards star)
    pc_means = pc_means[:, ::-1]
    pc_stds = pc_stds[:, ::-1]

    ra_data, dec_data, v_data = pc_means
    ra_sigma, dec_sigma, v_sigma = pc_stds

    return StreamerData(
        pc_coords=pc_coords,
        ra_data=ra_data, dec_data=dec_data, v_data=v_data,
        ra_sigma=ra_sigma, dec_sigma=dec_sigma, v_sigma=v_sigma,
        data=(ra_data, dec_data, v_data),
        uncertainties=(ra_sigma, dec_sigma, v_sigma),
    )


def _weighted_bin_stats(pc_coords, flux, distance_metric, n_elements):
    '''
    Weighted mean and std of each coordinate in n_elements percentile bins of a metric.

    pc_coords has shape (n_dims, n_points); flux and distance_metric have
    shape (n_points,). Bin edges are the 0, 100/n, ..., 100 percentiles of
    the metric. Bin i holds points with edge[i] < metric <= edge[i+1], and
    bin 0 also includes the minimum itself, so the closest points are not
    dropped. NaN metrics are excluded. Returns (means, stds), each of shape
    (n_dims, n_elements), ordered from the smallest metric to the largest.
    Empty bins, or bins whose weights sum to zero, are NaN.
    '''
    points = np.asarray(pc_coords, dtype=np.float64).T  # (n_points, n_dims)
    flux = np.asarray(flux, dtype=np.float64)
    metric = np.asarray(distance_metric, dtype=np.float64)

    edges = np.percentile(metric, np.linspace(0.0, 100.0, n_elements + 1))
    # searching the interior edges with side='left' maps
    # edge[i] < m <= edge[i+1] -> i, and clamps the ends to bins 0 and n-1
    bin_index = np.searchsorted(edges[1:-1], metric, side='left')
    bin_index = np.where(np.isnan(metric), -1, bin_index)

    n_dims = points.shape[1]
    means = np.full((n_dims, n_elements), np.nan)
    stds = np.full((n_dims, n_elements), np.nan)
    for i in range(n_elements):
        in_bin = bin_index == i
        weights = flux[in_bin]
        # np.average raises on a zero weight sum; leave such bins as NaN
        if not in_bin.any() or np.sum(weights) == 0:
            continue
        mean = np.average(points[in_bin], axis=0, weights=weights)
        means[:, i] = mean
        stds[:, i] = np.sqrt(np.average((points[in_bin] - mean) ** 2, axis=0, weights=weights))
    return means, stds
187
+
188
+
189
#@jax.jit
def safe_percentile(values, percentile):
    """
    Percentile of the valid entries of *values*, usable under jax.jit.

    Non-finite entries are replaced by BIG and sorted to the back, so no
    boolean indexing (which would change the array shape) is needed. Entries
    >= BIG are likewise treated as invalid. The result is the sorted valid
    value at index floor(percentile / 100 * n_valid), clipped to the valid
    range (no interpolation). With no valid entries, the first sorted entry
    (BIG when every entry is non-finite) is returned.
    """
    big = to_float64(BIG)
    padded = jnp.where(jnp.isfinite(values), values, big)
    ordered = jnp.sort(padded)  # genuine values first, padding last

    # number of genuine values (anything at or above BIG counts as padding)
    n_valid = jnp.sum(ordered < big)
    rank = jnp.floor(percentile / 100.0 * n_valid).astype(jnp.int32)
    rank = jnp.clip(rank, 0, jnp.maximum(n_valid - 1, 0))
    return ordered[rank]
211
+
212
+
213
#@jax.jit
def get_distance_metric(ra_coords, dec_coords, n_elements=10):
    '''
    Compute a radial + angular distance metric for point-cloud binning.

    Each point's metric is r * sqrt(1 + (theta_weight * theta_dev)**2), where
    (r, theta) are the polar coordinates of (ra_coords, dec_coords) and
    theta_dev is the circular deviation of theta from a reference angle
    theta_ref. theta_ref is the circular median of the angles of the finite
    points whose radius is within the (100 / n_elements)-th percentile. Using
    a circular deviation avoids branch-cut artifacts at +/-pi.

    Parameters
    ----------
    ra_coords, dec_coords : array
        RA and Dec offsets of the points (same shape).
    n_elements : int
        Number of bins the metric will be split into; sets the radius
        percentile used to select the reference points.

    Returns
    -------
    distance_metric : array
        Metric per point. Non-finite points get BIG, and points at the origin
        (r < 1e-12) get a radius of BIG.
    trace : dict
        Diagnostics: point counts, the radius percentile and threshold,
        theta_ref, theta_weight and the number of close (reference) points.
    '''
    pc_r, pc_theta = cartesian_to_polar(ra_coords, dec_coords)
    # points (numerically) on the origin are pushed far out and are treated
    # as invalid by safe_percentile
    pc_r = jnp.where(jnp.abs(pc_r) < 1e-12, to_float64(BIG), pc_r)

    finite_mask = jnp.isfinite(pc_r) & jnp.isfinite(pc_theta)

    # deal with if there are no valid points
    def empty_case(_):
        distance_metric = jnp.full_like(pc_r, to_float64(BIG))
        trace = {
            "n_points": jnp.array(pc_r.size, dtype=jnp.int32),
            "n_finite_points": jnp.array(0, dtype=jnp.int32),
            "n_reference_points": jnp.array(0, dtype=jnp.int32),
            "r_percentile_thresh": to_float64(0.0),
            "r_thresh": to_float64(BIG),
            "theta_ref": to_float64(0.0),
            "theta_weight": to_float64(1.0),
            "close_point_count": jnp.array(0, dtype=jnp.int32),
        }
        return distance_metric, trace

    def notempty_case(_):
        theta_weight = 1.0  # maybe make this a tunable parameter
        finite_count = jnp.sum(finite_mask)

        percentile = 100.0 / n_elements
        r_thresh = safe_percentile(pc_r, percentile)
        small_enough_r = pc_r <= r_thresh
        # 1 for finite points within the radius threshold, 0 otherwise
        close_points = finite_mask & small_enough_r
        close_mask = close_points.astype(jnp.float64)
        # circular_median cannot take NaNs; masked points have weight 0 anyway
        pc_theta_no_nan = jnp.where(finite_mask, pc_theta, 0.0)
        theta_ref = circular_median(pc_theta_no_nan, weights=close_mask)

        # cyclic angular deviation
        theta_dev = jnp.pi - jnp.abs(
            jnp.pi - jnp.abs(wrap_to_pi(pc_theta - theta_ref))
        )

        distance_metric = pc_r * jnp.sqrt(1.0 + (theta_weight * theta_dev) ** 2)
        distance_metric = jnp.where(finite_mask, distance_metric, to_float64(BIG))
        trace = {
            "n_points": jnp.array(pc_r.size, dtype=jnp.int32),
            "n_finite_points": jnp.array(finite_count, dtype=jnp.int32),
            "n_reference_points": jnp.array(jnp.sum(small_enough_r), dtype=jnp.int32),
            "r_percentile_thresh": to_float64(percentile),
            "r_thresh": r_thresh,
            "theta_ref": theta_ref,
            # same dtype as the empty branch so both lax.cond outputs match
            "theta_weight": to_float64(theta_weight),
            # number of points actually used for theta_ref (was the total size)
            "close_point_count": jnp.sum(close_points).astype(jnp.int32),
        }

        return distance_metric, trace

    return jax.lax.cond(finite_mask.any(), notempty_case, empty_case, operand=None)
277
+
278
#@jax.jit
def cartesian_to_polar(x, y):
    '''
    Convert Cartesian (x, y) to polar (r, theta), e.g. for RA/Dec offsets.

    theta is the angle from the +x axis in radians (jnp.arctan2 range).
    A tiny constant (1e-60) is added under the square root for gradient
    stability. If JAX 64-bit mode is off it underflows to 0 in float32.
    '''
    tiny = to_float64(1e-60)
    radius = jnp.sqrt(x**2 + y**2 + tiny)
    angle = jnp.arctan2(y, x)
    return radius, angle
289
+
290
+
291
def get_metric_partitions(pc_coords, n_elements):
    '''
    Compute percentile partitions for the streamline distance metric.

    Parameters
    ----------
    pc_coords : array
        Point cloud coordinates. Index 0 = RA, Index 1 = Dec, Index 2 = velocity
    n_elements : int
        Number of partitions required

    Returns
    -------
    partitions : ndarray
        The n_elements + 1 evenly spaced percentiles (0 to 100) of the finite
        metric values; all NaN if no metric value is finite.

    Raises
    ------
    ValueError
        If n_elements < 1.
    '''
    if n_elements < 1:
        raise ValueError('n_elements must be >= 1')

    metric, _ = get_distance_metric(pc_coords[0], pc_coords[1], n_elements=n_elements)
    metric = np.asarray(metric)
    usable = metric[np.isfinite(metric)]

    if not usable.size:
        return np.full(n_elements + 1, np.nan, dtype=np.float64)

    percent_levels = np.linspace(0.0, 100.0, n_elements + 1)
    edges = [np.percentile(usable, level) for level in percent_levels]
    return np.asarray(edges, dtype=np.float64)
322
+
323
+
324
def get_metric_reference_trace(pc_coords, n_elements=10):
    '''Return (theta_ref, theta_weight) of the distance metric, as Python floats.

    These are the reference angle and angular weight used for boundary
    sampling; defaults 0.0 and 1.0 apply if the trace lacks either key.
    '''
    _, trace = get_distance_metric(pc_coords[0], pc_coords[1], n_elements=n_elements)
    return float(trace.get('theta_ref', 0.0)), float(trace.get('theta_weight', 1.0))
332
+
333
+
334
def sample_metric_boundary(partition_radius, theta_ref, theta_weight=1.0, n_samples=720):
    '''
    Sample the RA/Dec curve on which the distance metric equals partition_radius.

    Intended for plotting. The curve is sampled at n_samples polar angles
    evenly spaced over [-pi, pi); the first point is not repeated at the end.

    Raises
    ------
    ValueError
        If n_samples < 4.
    '''
    if n_samples < 4:
        raise ValueError('n_samples must be >= 4')

    angles = jnp.linspace(-jnp.pi, jnp.pi, n_samples, endpoint=False)
    theta_dev = jnp.pi - jnp.abs(jnp.pi - jnp.abs(wrap_to_pi(angles - theta_ref)))
    # invert metric = r * sqrt(1 + (w * dev)**2) for r at a fixed metric value
    stretch = jnp.sqrt(1.0 + (theta_weight * theta_dev) ** 2)
    radius = partition_radius / stretch
    return radius * jnp.cos(angles), radius * jnp.sin(angles)
345
+
346
+
347
def sample_metric_boundaries(pc_coords, partitions, n_samples=720):
    '''Sample one constant-metric boundary curve per partition value.

    Returns (curves, trace): curves is a list of (ra, dec) arrays, and trace
    holds the theta_ref and theta_weight used to draw them.
    '''
    n_elements = len(partitions) - 1
    theta_ref, theta_weight = get_metric_reference_trace(pc_coords, n_elements=n_elements)
    curves = [
        sample_metric_boundary(radius, theta_ref, theta_weight=theta_weight, n_samples=n_samples)
        for radius in np.asarray(partitions)
    ]
    return curves, {'theta_ref': theta_ref, 'theta_weight': theta_weight}
359
+
360
+
361
def plot_metric_boundaries(ax, pc_coords, curves, color='lightgrey', linewidth=1, alpha=0.5, n_samples=720, zorder=1):
    '''
    Plot metric boundary curves on an RA/Dec axis.

    Each curve is closed by repeating its first point at the end. The samples
    produced by sample_metric_boundary stop short of theta = +pi
    (endpoint=False), so plotting them as-is would leave a gap.

    Parameters
    ----------
    ax : axis object with a matplotlib-style ``plot(x, y, **kwargs)`` method
    pc_coords : not used
    curves : iterable of (ra, dec) array pairs
    color, linewidth, alpha, zorder : passed to ``ax.plot``
    n_samples : not used

    Returns
    -------
    curves, unchanged
    '''
    for ra, dec in curves:
        ra = np.asarray(ra)
        dec = np.asarray(dec)
        ax.plot(
            np.append(ra, ra[:1]),
            np.append(dec, dec[:1]),
            color=color, linewidth=linewidth, alpha=alpha, zorder=zorder,
        )
    return curves
366
+
367
+
368
def prepare_data(data, uncertainties, n_elements):
    '''
    Precompute the constant, data-only quantities used by the gradient descent,
    so that later iterations do not have to recompute them.

    Parameters
    ----------
    data : tuple of arrays (ra_data, dec_data, v_data)
        Observed RA offset (arcsec), Dec offset (arcsec), velocity (km/s)
    uncertainties : tuple of arrays (ra_sigma, dec_sigma, v_sigma)
        Uncertainties on the data; values below 1e-8 are raised to 1e-8
    n_elements : int
        Number of elements to reduce the cube to; passed to
        get_distance_metric when computing the data's distance metric

    Returns
    -------
    PreparedData
        Container holding the float64 data, the floored uncertainties, the
        data's distance metric with its finite mask and min/max over finite
        points, and the polar (r, theta) projection of RA/Dec.
    '''
    ra_data, dec_data, v_data = [jnp.asarray(data[k], dtype=jnp.float64) for k in range(3)]

    # floor the uncertainties so later divisions by sigma stay finite
    floor_value = jnp.asarray(1e-8, dtype=jnp.float64)
    ra_sigma_safe, dec_sigma_safe, v_sigma_safe = [
        jnp.maximum(jnp.asarray(uncertainties[k], dtype=jnp.float64), floor_value)
        for k in range(3)
    ]

    dmetric_data, _ = get_distance_metric(ra_data, dec_data, n_elements=n_elements)
    data_finite_mask = jnp.isfinite(ra_data) & jnp.isfinite(dec_data) & jnp.isfinite(dmetric_data)

    # masked entries become +inf / -inf so they can never be the min / max
    data_min = jnp.min(jnp.where(data_finite_mask, dmetric_data, jnp.inf))
    data_max = jnp.max(jnp.where(data_finite_mask, dmetric_data, -jnp.inf))

    r_proj_data, theta_proj_data = cartesian_to_polar(ra_data, dec_data)

    return PreparedData(
        ra_data=ra_data,
        dec_data=dec_data,
        v_data=v_data,
        ra_sigma_safe=ra_sigma_safe,
        dec_sigma_safe=dec_sigma_safe,
        v_sigma_safe=v_sigma_safe,
        dmetric_data=dmetric_data,
        data_finite_mask=data_finite_mask,
        data_min=data_min,
        data_max=data_max,
        r_proj_data=r_proj_data,
        theta_proj_data=theta_proj_data,
    )
424
+
425
+