sting 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,448 @@
1
+ '''
2
+ Contains all functions needed for a forward model of a streamline.
3
+ All are fully differentiable using JAX
4
+
5
+ Streamline implementation is based on Mendoza et al. (2009) doi:10.1111/j.1365-2966.2008.14210.x
6
+
7
+ The assumed input units are:
8
+ - distance on sky: au
9
+ - velocity: km/s
10
+ - mass: solar masses
11
+ - angles (PA, i, theta, phi...): radians
12
+ - rc (centrifugal radius): au (alternative to omega - either can be used to calculate mu=rc/r0)
13
+ - Omega (angular velocity): 1/s (alternative to rc - either can be used to calculate mu=rc/r0)
14
+ - distance to source: pc
15
+
16
+ Last updated: 19-06-2026
17
+ '''
18
+
19
+
20
+ import astropy.units as u
21
+ import jax
22
+ import jax.numpy as jnp
23
+ from jax.experimental import checkify
24
+ jax.config.update("jax_enable_x64", True)
25
+ from typing import NamedTuple
26
+
27
+
28
## constants
# Small value used to avoid division by zero.
eps = 1e-8
# Floating-point dtype used throughout (jax_enable_x64 is switched on at import).
FLOAT_DTYPE = jnp.float64

# Gravitational constant converted from SI (m^3 kg^-1 s^-2 = m (m/s)^2 kg^-1)
# to au (km/s)^2 Msun^-1:
#   (m/s)^2 -> (km/s)^2 : multiply by (1e-3)**2
#   kg^-1   -> Msun^-1  : multiply by the solar mass in kg
#   m       -> au       : divide by the au in m
_G_SI = 6.67430e-11          # m^3 kg^-1 s^-2
_KM_PER_M_SQ = (1e-3)**2     # (km/m)^2
_MSUN_KG = 1.988416e30       # kg
_AU_M = 1.4959787e11         # m
G = _G_SI * _KM_PER_M_SQ * _MSUN_KG / _AU_M  # au (km/s)^2 Msun^-1

# Kilometres per astronomical unit.
au_to_km = 1.4959787e8


## important streamline quantities (for easy reuse)
class StreamState(NamedTuple):
    '''
    Per-streamline quantities computed once by build_stream_quantities()
    and passed to the trajectory / velocity functions.
    '''
    rc: jnp.ndarray       # centrifugal radius, mu * r0 (au)
    mu: jnp.ndarray       # rc / r0 (dimensionless)
    nu: jnp.ndarray       # v_r0 * sqrt(rc / (G * mass)) (dimensionless)
    epsilon: jnp.ndarray  # nu^2 + mu^2 sin^2(theta0) - 2 mu
    ecc: jnp.ndarray      # sqrt(1 + epsilon sin^2(theta0)), orbit eccentricity
    vk0: jnp.ndarray      # v_k evaluated at rc (km/s)


@jax.jit
def to_float64(value):
    '''
    Cast a number or array-like to a JAX array of FLOAT_DTYPE.

    :param value: number or array-like
    :return: jnp array with dtype FLOAT_DTYPE
    '''
    as_array = jnp.asarray(value)
    return as_array.astype(FLOAT_DTYPE)


@jax.jit
def v_k(radius, mass=0.5):
    '''
    Velocity scale sqrt(G * mass / radius) that multiplies every velocity
    component (v_k in Mendoza+ 2009).

    :param radius: au
    :param mass: Msun
    :return: v_k, km/s
    '''
    return jnp.sqrt(G * mass / radius)


@jax.jit
def r_cent(mass, omega=1e-14, r0=1e4):
    '''
    Centrifugal radius (disk radius) in the Ulrich (1976) model,
    r_u in Mendoza's nomenclature: rc = r0^4 omega^2 / (G mass).

    :param mass: Central mass of the protostar, Msun
    :param omega: Angular speed at the r0 radius, 1/s
    :param r0: Initial radius of the streamline, au
    :return: r_cent, au
    '''
    # With G in au (km/s)^2 Msun^-1 this ratio comes out in au^3 km^-2 ...
    rc_au3_per_km2 = jnp.power(r0, 4) * jnp.power(omega, 2) / (G * mass)
    # ... and multiplying by (km per au)^2 leaves au.
    return rc_au3_per_km2 * jnp.power(au_to_km, 2)


@jax.jit
def omega_from_mu(mu, mass, r0):
    '''
    Angular speed at r0 that yields a centrifugal radius rc = mu * r0
    (algebraic inverse of mu_from_omega / r_cent).

    :param mu: rc / r0, dimensionless
    :param mass: Central mass, Msun
    :param r0: Initial radius of the streamline, au
    :return: omega, 1/s
    '''
    # mu * G * mass / r0^3 is in km^2 s^-2 au^-2, so its sqrt is km s^-1 au^-1
    omega_km_per_au_s = jnp.sqrt(mu * G * mass / jnp.power(r0, 3))
    # divide by km-per-au to obtain 1/s
    return omega_km_per_au_s / au_to_km


@jax.jit
def mu_from_omega(omega, mass, r0):
    '''
    Dimensionless centrifugal radius mu = rc / r0 for a given angular speed.

    :param omega: Angular speed at r0, 1/s
    :param mass: Central mass, Msun
    :param r0: Initial radius of the streamline, au
    :return: mu, dimensionless
    '''
    return r_cent(mass=mass, omega=omega, r0=r0) / r0


@jax.jit
def build_stream_quantities(mass, r0, theta0, mu, v_r0):
    '''
    Precompute the streamline quantities reused throughout this module
    and bundle them into a StreamState.

    :param mass: Central mass, Msun
    :param r0: Initial radius of the streamline, au
    :param theta0: Initial polar angle, radians
    :param mu: rc / r0, dimensionless
    :param v_r0: Initial radial velocity, km/s (either sign)
    :return: StreamState(rc, mu, nu, epsilon, ecc, vk0)
    '''
    # Protect near-zero v_r0 from creating singularities in the nu calculation.
    # Values with |v_r0| < eps are replaced by eps carrying the sign of v_r0,
    # so the stream keeps the direction it was going. An exact zero (of either
    # sign) has no direction and becomes +eps.
    threshold = to_float64(eps)
    v_r0 = to_float64(v_r0)
    direction = jnp.where(v_r0 < 0.0, -1.0, 1.0)
    v_r0_protected = jnp.where(jnp.abs(v_r0) < threshold, direction * threshold, v_r0)

    mu = to_float64(mu)
    rc = mu * r0
    nu = v_r0_protected * jnp.sqrt(rc / (G * mass))
    sin_theta0_sq = jnp.power(jnp.sin(theta0), 2)
    epsilon = jnp.power(nu, 2) + jnp.power(mu, 2) * sin_theta0_sq - 2 * mu
    # 1 + epsilon*sin^2 = (1 - mu*sin^2)^2 + nu^2*sin^2 >= 0, so the sqrt is real
    ecc = jnp.sqrt(1.0 + epsilon * sin_theta0_sq)
    vk0 = v_k(rc, mass=mass)

    return StreamState(rc=rc, mu=mu, nu=nu, epsilon=epsilon, ecc=ecc, vk0=vk0)


@jax.jit
def safe_arccos(x, eps=1e-10):
    '''
    arccos with its input clipped strictly inside [-1, 1], so that the
    derivative of arccos stays finite near the boundaries.

    :param x: input value(s); cast to FLOAT_DTYPE
    :param eps: requested distance from +/-1; never smaller than
        32 machine epsilons of FLOAT_DTYPE
    :return: arccos of the clipped input, radians
    '''
    x = jnp.asarray(x).astype(FLOAT_DTYPE)
    dtype = x.dtype

    requested = jnp.asarray(eps, dtype=dtype)
    floor = jnp.asarray(32.0 * jnp.finfo(dtype).eps, dtype=dtype)
    margin = jnp.maximum(requested, floor)

    lower = -1.0 + margin
    upper = 1.0 - margin
    return jnp.arccos(jnp.clip(x, lower, upper))


@jax.jit
def get_theta(theta0, orb_ang, orb_ang0):
    '''
    Polar angle theta from cos(theta) = cos(theta0) cos(orb_ang - orb_ang0),
    Eqn (8) in Mendoza+2009.

    :param theta0: radians
    :param orb_ang: radians
    :param orb_ang0: radians
    :return theta: radians
    '''
    return safe_arccos(jnp.cos(theta0) * jnp.cos(orb_ang - orb_ang0))


@jax.jit
def get_orb_ang(r_to_rc, theta0, ecc):
    '''
    Orbital angle orb_ang (varphi in Mendoza+2009), in radians, obtained by
    inverting r / rc = sin^2(theta0) / (1 - ecc cos(orb_ang)).
    For the initial orb_ang, pass r_to_rc = r0/rc = 1/mu.

    :param r_to_rc: radius divided by centrifugal radius
    :param theta0: initial polar angle, radians
    :param ecc: eccentricity
    :return orb_ang: radians
    '''
    sin_sq_theta0 = jnp.power(jnp.sin(theta0), 2)
    cos_orb_ang = (1 / ecc) * (1 - sin_sq_theta0 / r_to_rc)
    return safe_arccos(cos_orb_ang)


@jax.jit
def get_dphi(theta, theta0=jnp.radians(30)):
    '''
    Azimuthal offset phi - phi0 from cos(dphi) = tan(theta0) / tan(theta).

    :param theta: radians
    :param theta0: radians
    :return: difference in phi angle, radians, in (0, pi)
    '''
    # tan(theta) can be tiny near theta = 0 or pi, so a very small floor is used
    small_eps = to_float64(1e-12)
    tan_theta = jnp.tan(theta)

    signed_floor = jnp.sign(tan_theta) * small_eps
    tan_theta_safe = jnp.where(jnp.abs(tan_theta) > small_eps, tan_theta, signed_floor)
    # sign(0) == 0 leaves an exact zero behind; push it to +small_eps
    tan_theta_safe = jnp.where(tan_theta_safe == 0.0, small_eps, tan_theta_safe)

    return safe_arccos(jnp.tan(theta0) / tan_theta_safe, eps=small_eps)


@jax.jit
def stream_line(r, r_mask, stream_state, theta0=jnp.radians(30), phi0=jnp.radians(15)):
    '''
    Trajectory of the streamline following Mendoza et al. (2009), for r < r0
    only (the point r = r0 is handled by the caller).

    :param r: radii, au
    :param r_mask: boolean mask for valid r values (r < r0 and r > 0.5*rc)
    :param stream_state: StreamState with precomputed rc, mu, ecc
    :param theta0: initial polar angle, radians
    :param phi0: initial azimuthal angle, radians
    :return: (orb_ang, theta, phi, valid_mask). Angles are in radians.
        Where valid_mask is False they hold finite placeholder values
        (pi/4, min(theta0 + 0.1, pi - eps) and phi0 respectively).
    '''
    r = jnp.asarray(r, dtype=FLOAT_DTYPE)
    rc, mu, ecc = stream_state.rc, stream_state.mu, stream_state.ecc

    # orbital angle (varphi) at the start, where r / rc = r0 / rc = 1 / mu
    orb_ang0 = get_orb_ang(r_to_rc=1/mu, theta0=theta0, ecc=ecc)

    # masked-out radii use the placeholder r/rc = 0.6 so the trigonometry
    # below is evaluated on a harmless value (presumably chosen > 0.5 for that)
    r_to_rc = jnp.where(r_mask, r / rc, to_float64(0.6))

    orb_ang = get_orb_ang(r_to_rc=r_to_rc, theta0=theta0, ecc=ecc)
    theta = get_theta(theta0, orb_ang, orb_ang0)
    phi = phi0 + get_dphi(theta, theta0=theta0)

    # valid = inside the mask and not inside 0.5 * rc
    valid_mask = r_mask & (r_to_rc >= 0.5)

    # finite placeholders at invalid points keep gradients finite;
    # callers mask these points out of the final output
    theta_fill = jnp.minimum(theta0 + to_float64(0.1), to_float64(jnp.pi) - to_float64(eps))
    return (
        jnp.where(valid_mask, orb_ang, jnp.pi/4),
        jnp.where(valid_mask, theta, theta_fill),
        jnp.where(valid_mask, phi, phi0),
        valid_mask,
    )


@jax.jit
def stream_line_vel(
    r,
    theta,
    orb_ang,
    stream_state,
    theta0=jnp.radians(30),
    r_mask=None
):
    '''
    Spherical velocity components along the streamline (Mendoza+ 2009),
    each a dimensionless expression multiplied by vk0 = v_k(rc).

    :param r: radii, au
    :param theta: polar angle at each radius, radians
    :param orb_ang: orbital angle (varphi) at each radius, radians
    :param stream_state: StreamState with precomputed rc, ecc, vk0
    :param theta0: initial polar angle, radians
    :param r_mask: optional boolean mask; where False, r/rc is replaced by the
        same 0.6 placeholder that stream_line uses
    :return: v_r, v_theta, v_phi in units of km/s
    '''
    rc, ecc, vk0 = stream_state.rc, stream_state.ecc, stream_state.vk0

    ratio = r / rc
    if r_mask is not None:
        ratio = jnp.where(r_mask, ratio, to_float64(0.6))

    sin_theta0 = jnp.sin(theta0)
    sin_theta = jnp.sin(theta)

    v_r_scaled = -ecc * sin_theta0 * jnp.sin(orb_ang) / ratio / (1 - ecc * jnp.cos(orb_ang))

    # floor the radicand at eps so the sqrt (and its gradient) stays finite
    # when theta ~ theta0 or rounding drives the difference to <= 0
    radicand = jnp.maximum(jnp.power(jnp.cos(theta0), 2) - jnp.power(jnp.cos(theta), 2), eps)
    v_theta_scaled = sin_theta0 / sin_theta / ratio * jnp.sqrt(radicand)

    v_phi_scaled = jnp.power(sin_theta0, 2) / (sin_theta * ratio)

    return v_r_scaled * vk0, v_theta_scaled * vk0, v_phi_scaled * vk0


@jax.jit
def build_rotation_matrix(inc, pa):
    '''
    Combined position-angle / inclination rotation matrix, equal to
    R_pa @ R_inc with
        R_inc = [[1, 0, 0], [0, cos i, sin i], [0, -sin i, cos i]]
        R_pa  = [[cos pa, 0, -sin pa], [0, 1, 0], [sin pa, 0, cos pa]]

    :param inc: inclination, radians
    :param pa: position angle, radians
    :return: 3x3 FLOAT_DTYPE array
    '''
    inc = jnp.asarray(inc, dtype=FLOAT_DTYPE)
    pa = jnp.asarray(pa, dtype=FLOAT_DTYPE)

    cos_i, sin_i = jnp.cos(inc), jnp.sin(inc)
    cos_p, sin_p = jnp.cos(pa), jnp.sin(pa)

    row_x = [cos_p, sin_p * sin_i, -sin_p * cos_i]
    row_y = [0.0, cos_i, sin_i]
    row_z = [sin_p, -cos_p * sin_i, cos_p * cos_i]
    return jnp.array([row_x, row_y, row_z], dtype=FLOAT_DTYPE)


@jax.jit
def rotate_xyz(x, y, z, rotation_matrix):
    '''
    Rotate coordinates (or vector components) by inclination and PA.

    Per the parameter conventions below, x and z lie in the plane of the sky
    and y points along the line of sight. Inclination is a rotation around x;
    PA is a rotation around y.

    Using example matrices as described in:
    https://en.wikipedia.org/wiki/3D_projection

    :param x: cartesian x-coordinate, in the direction of decreasing RA
    :param y: cartesian y-coordinate, in the direction away from the observer
    :param z: cartesian z-coordinate, in the direction of increasing Dec.
    :param rotation_matrix: 3x3 rotation matrix combining inclination and PA
        rotations (see build_rotation_matrix).
    :return: new x, y, and z-coordinates as observed on the sky, with the
        same units as the input ones.
    '''
    coords = jnp.stack(
        [jnp.asarray(c, dtype=FLOAT_DTYPE) for c in (x, y, z)],
        axis=0,
    )
    rotated = rotation_matrix @ coords
    return rotated[0], rotated[1], rotated[2]


def check_rc_r0(rc, r0):
    '''
    Register a checkify check that the centrifugal radius lies inside the
    streamline's initial radius (rc < r0); otherwise the model is not valid.

    :param rc: centrifugal radius, au
    :param r0: initial radius of the streamline, au
    '''
    rc_inside_r0 = rc < r0
    checkify.check(
        rc_inside_r0,
        "Centrifugal radius is larger than start of streamline. Model is not valid."
    )


def check_r_array(r, r_low):
    '''
    Register a checkify check that at least one sampled radius reaches
    r_low; otherwise npoints/deltar are too small to cover the streamline.

    :param r: sampled radii, au
    :param r_low: innermost radius that must be reached, au
    '''
    reaches_r_low = jnp.any(r <= r_low)
    checkify.check(
        reaches_r_low,
        "Radius points do not extend down to rlow. Increase npoints and/or deltar"
    )


def xyz_stream(mass=0.5, r0=1e4, theta0=jnp.radians(30),
               phi0=jnp.radians(15), mu=0.1, v_r0=0,
               inc=0, pa=0, rmin=None, deltar=1, npoints=10000):
    '''
    Get xyz coordinates and velocities along a streamline, rotated in PA and
    inclination into the observer's frame.
    This is a wrapper around stream_line(), stream_line_vel() and rotate_xyz().

    Spherical into cartesian transformation is done for position and velocity
    using:
    https://en.wikipedia.org/wiki/Vector_fields_in_cylindrical_and_spherical_coordinates

    Validity checks (rc < r0, radius grid reaches the inner radius) are
    emitted with checkify.check; use checked_xyz_stream to collect them.

    :param mass: Central mass (unitless, Msun)
    :param r0: Initial radius of streamline (unitless, au)
    :param theta0: Initial polar angle of streamline (unitless, radians)
    :param phi0: Initial azimuthal angle of streamline (unitless, radians)
    :param mu: dimensionless, rc/r0 in (0, 1)
    :param v_r0: Initial radial velocity of the streamline, (km/s)
    :param inc: inclination with respect of line-of-sight, inc=0 is an edge-on-disk (unitless, radians)
    :param pa: Position angle of the rotation axis, measured due East from North. This is usually estimated from the outflow PA, or the disk PA-90deg., (unitless, radians)
    :param rmin: smallest radius for calculation, (unitless, au). The effective
        inner radius is max(rmin, 0.5*rc); None (default) means 0.5*rc.
    :param deltar: spacing between two consecutive radii in the sampling of the streamer, in (unitless, au)
    :param npoints: number of points returned (the point at r0 plus npoints-1
        points inside it). It fixes array lengths for jax/jit, so it must be
        static under jit. The number of *valid* points is set by r0, rmin, rc
        and deltar; invalid points are returned as zeros and flagged by the mask.
    :return: ((x, y, z) in au, (v_x, v_y, v_z) in km/s, mask), where mask is
        1.0 for valid points and 0.0 for points that were zeroed out.
    '''
    mass = jnp.asarray(mass, dtype=FLOAT_DTYPE)
    r0 = jnp.asarray(r0, dtype=FLOAT_DTYPE)
    theta0 = jnp.asarray(theta0, dtype=FLOAT_DTYPE)
    phi0 = jnp.asarray(phi0, dtype=FLOAT_DTYPE)
    mu = jnp.asarray(mu, dtype=FLOAT_DTYPE)
    v_r0 = jnp.asarray(v_r0, dtype=FLOAT_DTYPE)
    inc = jnp.asarray(inc, dtype=FLOAT_DTYPE)
    pa = jnp.asarray(pa, dtype=FLOAT_DTYPE)
    deltar = jnp.asarray(deltar, dtype=FLOAT_DTYPE)

    stream_state = build_stream_quantities(mass=mass, r0=r0, theta0=theta0, mu=mu, v_r0=v_r0)
    rc = stream_state.rc

    rotation_matrix = build_rotation_matrix(inc, pa)

    check_rc_r0(rc, r0)

    # Innermost radius for the calculation: max(rmin, 0.5*rc).
    # rmin=None means "no limit beyond 0.5*rc" (jnp.maximum cannot take None).
    half_rc = rc * 0.5
    if rmin is None:
        r_low = half_rc
    else:
        r_low = jnp.maximum(jnp.asarray(rmin, dtype=FLOAT_DTYPE), half_rc)

    # Radii strictly inside r0, spaced by deltar. The array length is fixed by
    # npoints for jit; r_mask flags the points inside the valid range.
    r = (r0 - deltar) - jnp.arange(npoints - 1, dtype=FLOAT_DTYPE) * deltar
    check_r_array(r, r_low)
    r_mask = r > r_low

    # Positions and velocities inside r0. Invalid points carry finite
    # sentinel values here and are zeroed out below.
    orb_ang, theta, phi, valid_mask = stream_line(r, r_mask, stream_state=stream_state, theta0=theta0, phi0=phi0)
    v_r, v_theta, v_phi = stream_line_vel(r, theta, orb_ang, stream_state=stream_state, theta0=theta0, r_mask=r_mask)

    # Prepend the initial point at r0. There v_theta = 0 and
    # v_phi = vk0 * sin(theta0) * mu (stream_line_vel's v_phi at r = r0).
    valid_mask_full = jnp.concatenate((jnp.asarray([True], dtype=bool), valid_mask))
    r_full = jnp.concatenate((jnp.asarray([r0], dtype=FLOAT_DTYPE), r))
    theta_full = jnp.concatenate((jnp.asarray([theta0], dtype=FLOAT_DTYPE), theta))
    phi_full = jnp.concatenate((jnp.asarray([phi0], dtype=FLOAT_DTYPE), phi))
    v_r_full = jnp.concatenate((jnp.asarray([v_r0], dtype=FLOAT_DTYPE), v_r))
    v_theta_full = jnp.concatenate((jnp.asarray([0.0], dtype=FLOAT_DTYPE), v_theta))
    v_phi0 = stream_state.vk0 * jnp.sin(theta0) * stream_state.mu
    v_phi_full = jnp.concatenate((jnp.asarray([v_phi0], dtype=FLOAT_DTYPE), v_phi))

    # Spherical -> cartesian for velocities and positions.
    sin_t, cos_t = jnp.sin(theta_full), jnp.cos(theta_full)
    sin_p, cos_p = jnp.sin(phi_full), jnp.cos(phi_full)
    v_x = v_r_full * sin_t * cos_p + v_theta_full * cos_t * cos_p - v_phi_full * sin_p
    v_y = v_r_full * sin_t * sin_p + v_theta_full * cos_t * sin_p + v_phi_full * cos_p
    v_z = v_r_full * cos_t - v_theta_full * sin_t
    x = r_full * sin_t * cos_p
    y = r_full * sin_t * sin_p
    z = r_full * cos_t

    # Keep points outside the inner radius (plus the initial point at r0);
    # everything else is set to zero.
    keep = jnp.logical_or(r_full > r_low, valid_mask_full)
    gd_rlow = keep.astype(x.dtype)

    positions = tuple(
        jnp.where(keep, c, 0.0)
        for c in rotate_xyz(x, y, z, rotation_matrix=rotation_matrix)
    )
    velocities = tuple(
        jnp.where(keep, c, 0.0)
        for c in rotate_xyz(v_x, v_y, v_z, rotation_matrix=rotation_matrix)
    )
    return positions, velocities, gd_rlow


# Jitted, checkify-instrumented xyz_stream: returns (error, result) so the
# checkify.check validity checks can be inspected. npoints is static because
# it sets the length of the radius array (jnp.arange(npoints - 1)).
checked_xyz_stream = jax.jit(
    checkify.checkify(xyz_stream),
    static_argnames=('npoints',),
)