sting 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sting/__init__.py +8 -0
- sting/_version.py +24 -0
- sting/errors.py +677 -0
- sting/extract_streamline.py +425 -0
- sting/gradient_descent.py +1776 -0
- sting/outputs.py +1705 -0
- sting/stream_lines_grad.py +448 -0
- sting-0.2.0.dist-info/METADATA +251 -0
- sting-0.2.0.dist-info/RECORD +14 -0
- sting-0.2.0.dist-info/WHEEL +5 -0
- sting-0.2.0.dist-info/licenses/LICENCE +21 -0
- sting-0.2.0.dist-info/scm_file_list.json +26 -0
- sting-0.2.0.dist-info/scm_version.json +8 -0
- sting-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,448 @@
|
|
|
1
|
+
'''
|
|
2
|
+
Contains all functions needed for a forward model of a streamline.
|
|
3
|
+
All are fully differentiable using JAX
|
|
4
|
+
|
|
5
|
+
Streamline implementation is based on Mendoza et al. (2009) doi:10.1111/j.1365-2966.2008.14210.x
|
|
6
|
+
|
|
7
|
+
The assumed input units are:
|
|
8
|
+
- distance on sky: au
|
|
9
|
+
- velocity: km/s
|
|
10
|
+
- mass: solar masses
|
|
11
|
+
- angles (PA, i, theta, phi...): radians
|
|
12
|
+
- rc (centrifugal radius): au (alternative to omega - either can be used to calculate mu=rc/r0)
|
|
13
|
+
- Omega (angular velocity): 1/s (alternative to rc - either can be used to calculate mu=rc/r0)
|
|
14
|
+
- distance to source: pc
|
|
15
|
+
|
|
16
|
+
Last updated: 19-06-2026
|
|
17
|
+
'''
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
import astropy.units as u
|
|
21
|
+
import jax
|
|
22
|
+
import jax.numpy as jnp
|
|
23
|
+
from jax.experimental import checkify
|
|
24
|
+
jax.config.update("jax_enable_x64", True)
|
|
25
|
+
from typing import NamedTuple
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
## constants

# Generic floor used to keep denominators / square-root arguments off zero.
eps = 1.0e-8

# Working precision for every array built in this module.
FLOAT_DTYPE = jnp.float64

# Gravitational constant in au (km/s)^2 Msun^-1:
#   G[m^3 kg^-1 s^-2] * (km per m)^2 * (kg per Msun) / (m per au)
G = 6.67430e-11 * 1e-3 ** 2 * 1.988416e30 / 1.4959787e11

# Kilometres in one astronomical unit.
au_to_km = 149_597_870.0
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
## important streamline quantities (for easy reuse)
class StreamState(NamedTuple):
    '''
    Bundle of per-streamline quantities that build_stream_quantities
    precomputes once and the other functions in this module read back.
    '''
    rc: jnp.ndarray  # centrifugal radius, au (rc = mu * r0)
    mu: jnp.ndarray  # dimensionless ratio rc / r0
    nu: jnp.ndarray  # dimensionless initial radial velocity, v_r0 * sqrt(rc / (G * mass))
    epsilon: jnp.ndarray  # nu^2 + mu^2 * sin^2(theta0) - 2 * mu
    ecc: jnp.ndarray  # orbit eccentricity, sqrt(1 + epsilon * sin^2(theta0))
    vk0: jnp.ndarray  # v_k evaluated at rc, km/s
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@jax.jit
def to_float64(value):
    '''
    Convert a number or array-like to a JAX array of dtype FLOAT_DTYPE.

    :param value: scalar or array-like input
    :return: jnp array with dtype FLOAT_DTYPE
    '''
    as_array = jnp.asarray(value)
    return as_array.astype(FLOAT_DTYPE)
|
|
49
|
+
|
|
50
|
+
@jax.jit
def v_k(radius, mass=0.5):
    '''
    Keplerian velocity scale sqrt(G * M / r), the factor shared by every
    velocity component (v_k in Mendoza+2009).

    :param radius: au
    :param mass: Msun
    :return: v_k, km/s
    '''
    # G is in au (km/s)^2 Msun^-1, so G * M / r is already in (km/s)^2
    return jnp.sqrt(G * mass / radius)
|
|
61
|
+
|
|
62
|
+
@jax.jit
def r_cent(mass, omega=1e-14, r0=1e4):
    '''
    Centrifugal (disk) radius of the Ulrich (1976) model, r_u in Mendoza's
    notation: rc = r0^4 * omega^2 / (G * M).

    :param mass: Central mass for the protostar, Msun
    :param omega: Angular speed at the r0 radius, 1/s
    :param r0: Initial radius of the streamline, au
    :return: r_cent, au
    '''
    # with r0 in au, omega in 1/s and G in au (km/s)^2 Msun^-1 this ratio
    # comes out in au^3 km^-2 ...
    rc_mixed_units = jnp.power(r0, 4) * jnp.power(omega, 2) / (G * mass)
    # ... and multiplying by (km per au)^2 leaves plain au
    return rc_mixed_units * jnp.power(au_to_km, 2)
|
|
76
|
+
|
|
77
|
+
@jax.jit
def omega_from_mu(mu, mass, r0):
    '''
    Angular speed at r0 that yields a given mu = rc / r0 (inverse of
    mu_from_omega): omega^2 = mu * G * M / r0^3.

    :param mu: dimensionless rc / r0
    :param mass: central mass, Msun
    :param r0: initial radius of the streamline, au
    :return: omega, 1/s
    '''
    # G in au (km/s)^2 Msun^-1 and r0 in au give omega^2 in km^2 s^-2 au^-2
    omega_sq = mu * G * mass / jnp.power(r0, 3)
    # sqrt -> km s^-1 au^-1; dividing by km-per-au leaves 1/s
    return jnp.sqrt(omega_sq) / au_to_km
|
|
82
|
+
|
|
83
|
+
@jax.jit
def mu_from_omega(omega, mass, r0):
    '''
    Dimensionless mu = rc / r0 for a given angular speed at r0.

    :param omega: angular speed at r0, 1/s
    :param mass: central mass, Msun
    :param r0: initial radius of the streamline, au
    :return: mu (dimensionless)
    '''
    return r_cent(mass=mass, omega=omega, r0=r0) / r0
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@jax.jit
def build_stream_quantities(mass, r0, theta0, mu, v_r0):
    '''
    Precompute the streamline quantities reused throughout this module and
    store them in a StreamState.

    :param mass: central mass, Msun
    :param r0: initial radius of the streamline, au
    :param theta0: initial polar angle, radians
    :param mu: dimensionless rc / r0
    :param v_r0: initial radial velocity, km/s (negative values allowed)
    :return: StreamState(rc, mu, nu, epsilon, ecc, vk0)
    '''
    # Protect near-zero v_r0 from creating singularities in the nu calculation.
    # |v_r0| is floored at eps while its sign is kept, so a tiny velocity keeps
    # going in the direction it was going. An exact zero has no sign
    # (jnp.sign(0) == 0), so it is mapped to +eps explicitly.
    threshold = to_float64(eps)
    v_r0_protected = jnp.sign(v_r0) * jnp.maximum(jnp.abs(v_r0), threshold)
    v_r0_protected = jnp.where(v_r0 == 0.0, threshold, v_r0_protected)

    mu = to_float64(mu)
    rc = mu * r0
    # nu = v_r0 * sqrt(rc / (G M)), i.e. v_r0 in units of v_k(rc)
    nu = v_r0_protected * jnp.sqrt(rc / (G * mass))
    sin_theta0_sq = jnp.power(jnp.sin(theta0), 2)
    epsilon = jnp.power(nu, 2) + jnp.power(mu, 2) * sin_theta0_sq - 2 * mu
    ecc = jnp.sqrt(1.0 + epsilon * sin_theta0_sq)
    vk0 = v_k(rc, mass=mass)

    return StreamState(rc=rc, mu=mu, nu=nu, epsilon=epsilon, ecc=ecc, vk0=vk0)
|
|
118
|
+
|
|
119
|
+
@jax.jit
def safe_arccos(x, eps=1e-10):
    '''
    arccos with its argument clipped strictly inside [-1, 1].

    The clip margin is the larger of ``eps`` and 32 times the machine epsilon
    of the working dtype. This keeps arccos (whose derivative diverges at +/-1)
    away from the boundaries, so gradients stay finite.

    :param x: input value
    :param eps: requested margin from +/-1
    :return: arccos of clipped input
    '''
    x = jnp.asarray(x).astype(FLOAT_DTYPE)
    margin = jnp.maximum(
        jnp.asarray(eps, dtype=x.dtype),
        jnp.asarray(32.0 * jnp.finfo(x.dtype).eps, dtype=x.dtype),
    )
    lower, upper = -1.0 + margin, 1.0 - margin
    return jnp.arccos(jnp.clip(x, lower, upper))
|
|
139
|
+
|
|
140
|
+
@jax.jit
def get_theta(theta0, orb_ang, orb_ang0):
    '''
    Polar angle along the streamline, eqn (8) in Mendoza+2009:
    cos(theta) = cos(theta0) * cos(orb_ang - orb_ang0).

    :param theta0: initial polar angle, radians
    :param orb_ang: current orbital angle, radians
    :param orb_ang0: initial orbital angle, radians
    :return theta: radians
    '''
    return safe_arccos(jnp.cos(theta0) * jnp.cos(orb_ang - orb_ang0))
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
@jax.jit
def get_orb_ang(r_to_rc, theta0, ecc):
    '''
    Orbital angle (varphi in Mendoza+2009), in radians.

    Solves r / rc = sin^2(theta0) / (1 - ecc * cos(varphi)) for varphi.
    To get the initial orbital angle, pass r_to_rc = r0 / rc = 1 / mu.

    :param r_to_rc: radius divided by centrifugal radius
    :param theta0: initial polar angle, radians
    :param ecc: eccentricity
    :return orb_ang: radians
    '''
    sin_sq_theta0 = jnp.power(jnp.sin(theta0), 2)
    return safe_arccos((1/ecc) * (1 - (sin_sq_theta0 / r_to_rc)))
|
|
170
|
+
|
|
171
|
+
@jax.jit
def get_dphi(theta, theta0=jnp.radians(30)):
    '''
    Azimuthal displacement phi - phi0 along the streamline, from
    cos(phi - phi0) = tan(theta0) / tan(theta).

    :param theta: radians
    :param theta0: radians
    :return: difference in Phi angle, radians
    '''
    # tan(theta) can be tiny near theta = 0 or pi, so its magnitude is floored
    # at a very small value (keeping its sign) before it is used as a divisor
    tiny = to_float64(1e-12)
    tan_theta = jnp.tan(theta)
    tan_theta = jnp.where(jnp.abs(tan_theta) > tiny, tan_theta, jnp.sign(tan_theta) * tiny)
    # sign(0) == 0, so an exact zero survives the floor above; replace it
    tan_theta = jnp.where(tan_theta == 0.0, tiny, tan_theta)
    return safe_arccos(jnp.tan(theta0) / tan_theta, eps=tiny)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
@jax.jit
def stream_line(r, r_mask, stream_state, theta0=jnp.radians(30), phi0=jnp.radians(15)):
    '''
    Angular position along a Mendoza et al. (2009) streamline, only for
    r < r0 (the point r = r0 is handled outside the function).

    :param r: au
    :param r_mask: boolean mask for valid r values (r < r0 and r > 0.5*rc)
    :param stream_state: StreamState named tuple containing precomputed quantities for the streamline
    :param theta0: radians
    :param phi0: radians
    :return: (orb_ang, theta, phi, valid_mask). Angles are in radians;
        entries where valid_mask is False hold finite placeholder values.
    '''
    r = jnp.asarray(r, dtype=FLOAT_DTYPE)
    rc, mu, ecc = stream_state.rc, stream_state.mu, stream_state.ecc

    # orbital angle (varphi in Mendoza+2009) at the start, where r / rc = 1 / mu
    orb_ang0 = get_orb_ang(r_to_rc=1/mu, theta0=theta0, ecc=ecc)

    # masked-out radii are replaced by 0.6 rc so every downstream expression
    # (and its gradient) stays finite
    r_to_rc = jnp.where(r_mask, r / rc, to_float64(0.6))

    orb_ang = get_orb_ang(r_to_rc=r_to_rc, theta0=theta0, ecc=ecc)
    theta = get_theta(theta0, orb_ang, orb_ang0)
    phi = phi0 + get_dphi(theta, theta0=theta0)

    # valid = unmasked and not inside 0.5 * rc
    valid_mask = r_mask & (r_to_rc >= 0.5)

    # finite placeholder values for invalid points (masked out in the final output)
    theta_fill = jnp.minimum(theta0 + to_float64(0.1), to_float64(jnp.pi) - to_float64(eps))
    orb_ang = jnp.where(valid_mask, orb_ang, jnp.pi/4)
    theta = jnp.where(valid_mask, theta, theta_fill)
    phi = jnp.where(valid_mask, phi, phi0)

    return orb_ang, theta, phi, valid_mask
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
@jax.jit
def stream_line_vel(
    r,
    theta,
    orb_ang,
    stream_state,
    theta0=jnp.radians(30),
    r_mask=None
):
    '''
    Spherical velocity components along a Mendoza+(2009) streamline.

    :param r: au
    :param theta: polar angle, radians
    :param orb_ang: orbital angle (varphi in Mendoza+2009), radians
    :param stream_state: StreamState named tuple containing precomputed quantities for the streamline
    :param theta0: radians
    :param r_mask: optional boolean mask; masked-out radii use the same
        0.6 * rc placeholder as stream_line
    :return: v_r, v_theta, v_phi in units of km/s
    '''
    rc = stream_state.rc
    ecc = stream_state.ecc
    vk0 = stream_state.vk0

    r_to_rc = r / rc
    if r_mask is not None:
        r_to_rc = jnp.where(r_mask, r_to_rc, to_float64(0.6))

    sin_theta0 = jnp.sin(theta0)
    sin_theta = jnp.sin(theta)

    v_r = -ecc * sin_theta0 * jnp.sin(orb_ang) / r_to_rc / (1 - ecc * jnp.cos(orb_ang))

    # the radicand is exactly 0 at theta == theta0, where d(sqrt)/dx is
    # infinite; floor it at eps to keep gradients finite
    radicand = jnp.maximum(
        jnp.power(jnp.cos(theta0), 2) - jnp.power(jnp.cos(theta), 2),
        eps,
    )
    v_theta = sin_theta0 / sin_theta / r_to_rc * jnp.sqrt(radicand)
    v_phi = jnp.power(sin_theta0, 2) / (sin_theta * r_to_rc)

    return v_r * vk0, v_theta * vk0, v_phi * vk0
|
|
279
|
+
|
|
280
|
+
@jax.jit
def build_rotation_matrix(inc, pa):
    '''
    Combined inclination / position-angle rotation matrix.

    Equals R_pa @ R_inc with
        R_inc = [[1, 0, 0], [0, cos i, sin i], [0, -sin i, cos i]]        (about x)
        R_pa  = [[cos pa, 0, -sin pa], [0, 1, 0], [sin pa, 0, cos pa]]    (about y)

    :param inc: inclination, radians
    :param pa: position angle, radians
    :return: 3x3 FLOAT_DTYPE array
    '''
    inc = jnp.asarray(inc, dtype=FLOAT_DTYPE)
    pa = jnp.asarray(pa, dtype=FLOAT_DTYPE)

    ci, si = jnp.cos(inc), jnp.sin(inc)
    cp, sp = jnp.cos(pa), jnp.sin(pa)

    rows = (
        (cp, sp * si, -sp * ci),
        (0.0, ci, si),
        (sp, -cp * si, cp * ci),
    )
    return jnp.array(rows, dtype=FLOAT_DTYPE)
|
|
297
|
+
|
|
298
|
+
@jax.jit
def rotate_xyz(x, y, z, rotation_matrix):
    '''
    Rotate cartesian coordinates (or vectors) by inclination and PA.

    Following the parameter conventions below, the x- and z-axes lie in the
    plane of the sky and the y-axis is along the line of sight. Rotation
    around x is the inclination angle; rotation around y is the PA angle.

    Using example matrices as described in:
    https://en.wikipedia.org/wiki/3D_projection

    :param x: cartesian x-coordinate, in the direction of decreasing RA
    :param y: cartesian y-coordinate, in the direction away from the observer
    :param z: cartesian z-coordinate, in the direction of increasing Dec.
    :param rotation_matrix: 3x3 rotation matrix combining inclination and PA rotations.
    :return: new x, y, and z-coordinates as observed on the sky, with the
        same units as the input ones.
    '''
    coords = jnp.stack(
        [jnp.asarray(component, dtype=FLOAT_DTYPE) for component in (x, y, z)],
        axis=0,
    )
    rotated = rotation_matrix @ coords
    return tuple(rotated[k] for k in range(3))
|
|
328
|
+
|
|
329
|
+
def check_rc_r0(rc, r0):
    '''
    Register a checkify check that the centrifugal radius rc is strictly
    smaller than the initial streamline radius r0; otherwise the model is not valid.
    '''
    model_is_valid = rc < r0
    checkify.check(
        model_is_valid,
        "Centrifugal radius is larger than start of streamline. Model is not valid."
    )
|
|
335
|
+
|
|
336
|
+
def check_r_array(r, r_low):
    '''
    Register a checkify check that at least one sampled radius reaches down
    to r_low; otherwise npoints * deltar is too short to cover the streamline.
    '''
    reaches_r_low = jnp.any(r <= r_low)
    checkify.check(
        reaches_r_low,
        "Radius points do not extend down to rlow. Increase npoints and/or deltar"
    )
|
|
343
|
+
|
|
344
|
+
def xyz_stream(mass=0.5, r0=1e4, theta0=jnp.radians(30),
               phi0=jnp.radians(15), mu=0.1, v_r0=0,
               inc=0, pa=0, rmin=None, deltar=1, npoints=10000):
    '''
    It gets xyz coordinates and velocities for a stream line.
    They are also rotated in PA and inclination along the line of sight.
    This is a wrapper around stream_line() and rotate_xyz()

    Spherical into cartesian transformation is done for position and velocity
    using:
    https://en.wikipedia.org/wiki/Vector_fields_in_cylindrical_and_spherical_coordinates

    Runtime checks (rc < r0, and the sampled radii reaching the lower radius
    limit) are registered with checkify.check, so the function must be run
    under checkify.checkify.

    :param mass: Central mass (unitless, Msun)
    :param r0: Initial radius of streamline (unitless, au)
    :param theta0: Initial polar angle of streamline (unitless, radians)
    :param phi0: Initial azimuthal angle of streamline (unitless, radians)
    :param mu: dimensionless, rc/r0 in (0, 1)
    :param v_r0: Initial radial velocity of the streamline, (km/s)
    :param inc: inclination with respect of line-of-sight, inc=0 is an edge-on-disk (unitless, radians)
    :param pa: Position angle of the rotation axis, measured due East from North. This is usually estimated from the outflow PA, or the disk PA-90deg., (unitless, radians)
    :param rmin: smallest radius for calculation, (unitless, au). If None
        (default), only the 0.5*rc limit is applied.
    :param deltar: spacing between two consecutive radii in the sampling of the streamer, in (unitless, au)
    :param npoints: number of points to sample along the streamer
        This is just so that arrays are fixed length for jax/jit compatibility,
        but the actual number of valid points is determined by r0, rmin, rc, deltar,
        so some of the returned points are set to zero (and flagged 0 in the
        returned mask) if npoints is larger than the number of valid points
    :return: ((x, y, z) in au, (v_x, v_y, v_z) in km/s, mask), where mask is
        1.0 for points that are kept and 0.0 for points zeroed out
    '''

    mass = jnp.asarray(mass, dtype=FLOAT_DTYPE)
    r0 = jnp.asarray(r0, dtype=FLOAT_DTYPE)
    theta0 = jnp.asarray(theta0, dtype=FLOAT_DTYPE)
    phi0 = jnp.asarray(phi0, dtype=FLOAT_DTYPE)
    mu = jnp.asarray(mu, dtype=FLOAT_DTYPE)
    v_r0 = jnp.asarray(v_r0, dtype=FLOAT_DTYPE)
    inc = jnp.asarray(inc, dtype=FLOAT_DTYPE)
    pa = jnp.asarray(pa, dtype=FLOAT_DTYPE)
    deltar = jnp.asarray(deltar, dtype=FLOAT_DTYPE)
    stream_state = build_stream_quantities(mass=mass, r0=r0, theta0=theta0, mu=mu, v_r0=v_r0)
    rc = stream_state.rc

    rotation_matrix = build_rotation_matrix(inc, pa)

    check_rc_r0(rc, r0)

    # find the smallest radius for calculation:
    # the maximum between rmin and 0.5*rc. rmin defaults to None, which
    # jnp.maximum cannot take, so then only the 0.5*rc limit applies.
    if rmin is None:
        r_low = rc * 0.5
    else:
        r_low = jnp.maximum(rmin, rc * 0.5)

    # r is values internal to the initial radius r0 for computation
    # r_mask is used to mask out points that are outside the valid range, but we still need to compute them for jax/jit compatibility
    r = (r0 - deltar) - jnp.arange(npoints - 1, dtype=FLOAT_DTYPE) * deltar
    check_r_array(r, r_low)
    r_mask = r > r_low

    # calculate positions and velocities inside r0
    # invalid points carry finite placeholder values here; they are zeroed below
    orb_ang, theta, phi, valid_mask = stream_line(r, r_mask, stream_state=stream_state, theta0=theta0, phi0=phi0)
    v_r, v_theta, v_phi = stream_line_vel(r, theta, orb_ang, stream_state=stream_state, theta0=theta0, r_mask=r_mask)

    # prepend initial positions and velocities at r0
    valid_mask_full = jnp.concatenate((jnp.asarray([True], dtype=bool), valid_mask))
    r_full = jnp.concatenate((jnp.asarray([r0], dtype=FLOAT_DTYPE), r))
    theta_full = jnp.concatenate((jnp.asarray([theta0], dtype=FLOAT_DTYPE), theta))
    phi_full = jnp.concatenate((jnp.asarray([phi0], dtype=FLOAT_DTYPE), phi))
    v_r_full = jnp.concatenate((jnp.asarray([v_r0], dtype=FLOAT_DTYPE), v_r))
    # at r0, theta == theta0 so v_theta vanishes
    v_theta_full = jnp.concatenate((jnp.asarray([0.0], dtype=FLOAT_DTYPE), v_theta))
    # v_phi at r0: vk0 * sin^2(theta0) / (sin(theta0) * r0/rc) = vk0 * sin(theta0) * mu
    v_phi0 = stream_state.vk0 * jnp.sin(theta0) * stream_state.mu
    v_phi_full = jnp.concatenate((jnp.asarray([v_phi0], dtype=FLOAT_DTYPE), v_phi))

    # convert from spherical into cartesian coordinates
    v_x = v_r_full * jnp.sin(theta_full) * jnp.cos(phi_full) \
        + v_theta_full * jnp.cos(theta_full) * jnp.cos(phi_full) \
        - v_phi_full * jnp.sin(phi_full)
    v_y = v_r_full * jnp.sin(theta_full) * jnp.sin(phi_full) \
        + v_theta_full * jnp.cos(theta_full) * jnp.sin(phi_full) \
        + v_phi_full * jnp.cos(phi_full)
    v_z = v_r_full * jnp.cos(theta_full) \
        - v_theta_full * jnp.sin(theta_full)
    x = r_full * jnp.sin(theta_full) * jnp.cos(phi_full)
    y = r_full * jnp.sin(theta_full) * jnp.sin(phi_full)
    z = r_full * jnp.cos(theta_full)
    rotated_x, rotated_y, rotated_z = rotate_xyz(x, y, z, rotation_matrix=rotation_matrix)
    rotated_v_x, rotated_v_y, rotated_v_z = rotate_xyz(v_x, v_y, v_z, rotation_matrix=rotation_matrix)

    # get mask from smallest radius for calculation
    gd_rlow = (r_full > r_low)
    gd_rlow = jnp.logical_or(gd_rlow, valid_mask_full)
    gd_rlow = gd_rlow.astype(x.dtype)
    # apply mask to set invalid points to zero
    rotated_x = jnp.where(gd_rlow, rotated_x, 0.0)
    rotated_y = jnp.where(gd_rlow, rotated_y, 0.0)
    rotated_z = jnp.where(gd_rlow, rotated_z, 0.0)
    rotated_v_x = jnp.where(gd_rlow, rotated_v_x, 0.0)
    rotated_v_y = jnp.where(gd_rlow, rotated_v_y, 0.0)
    rotated_v_z = jnp.where(gd_rlow, rotated_v_z, 0.0)
    return (rotated_x, rotated_y, rotated_z), \
        (rotated_v_x, rotated_v_y, rotated_v_z), \
        gd_rlow
|
|
446
|
+
|
|
447
|
+
|
|
448
|
+
# xyz_stream with its checkify.check calls functionalised and the whole thing
# jitted; npoints is static because it fixes the length of the radius array.
checked_xyz_stream = jax.jit(
    checkify.checkify(xyz_stream),
    static_argnames=("npoints",),
)
|