orbix 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
orbix/__init__.py ADDED
File without changes
orbix/_version.py ADDED
@@ -0,0 +1,16 @@
1
# file generated by setuptools_scm
# don't change, don't track in version control
# TYPE_CHECKING is hard-coded to False, so at runtime the else-branch runs and
# VERSION_TUPLE is the plain `object` placeholder; the typing import only
# matters to static analysers that treat TYPE_CHECKING as true.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple, Union
    VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
    VERSION_TUPLE = object

# Annotations only; the values are assigned below.
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

# Package version as a string and as a tuple, each exposed under two names.
__version__ = version = '0.0.2'
__version_tuple__ = version_tuple = (0, 0, 2)
orbix/eccanom.py ADDED
@@ -0,0 +1,594 @@
1
+ """JAX-compiled functions to compute the eccentric anomaly based on the orvara method.
2
+
3
+ The primary entrance point is the `solve_E` function, which chooses the correct
4
+ method based on the eccentricity value. When repeated calls will be made with
5
+ the same eccentricity value, it is recommended to use the `get_E_solver`
6
+ function which will return a JIT-compiled function with the eccentricity value
7
+ fixed.
8
+
9
+ If the sine and cosine of the eccentric anomaly are also required, such as in
10
+ RV calculations, the `solve_E_trig` function can be used. This function returns
11
+ the eccentric anomaly along with the sine and cosine of the eccentric anomaly.
12
+ Similarly, a JIT-compiled function can be obtained using the `get_E_trig_solver`
13
+ and a specific eccentricity value to speed up repeated calculations.
14
+
15
+ If high precision is not required, the `guess_E` function returns the initial
16
+ guess for the eccentric anomaly calculated with a lookup table of polynomials.
17
+ A compiled form for a specific `e` value can be obtained using the
18
+ `get_E_guesser` function.
19
+ """
20
+
21
+ from functools import partial
22
+
23
+ import jax
24
+ import jax.numpy as jnp
25
+ from jax import lax
26
+
27
+ # Define coefficients used in the shortsin function
28
+ if3 = 1.0 / 6.0
29
+ if5 = 1.0 / (6.0 * 20.0)
30
+ if7 = 1.0 / (6.0 * 20.0 * 42.0)
31
+ if9 = 1.0 / (6.0 * 20.0 * 42.0 * 72.0)
32
+ if11 = 1.0 / (6.0 * 20.0 * 42.0 * 72.0 * 110.0)
33
+ if13 = 1.0 / (6.0 * 20.0 * 42.0 * 72.0 * 110.0 * 156.0)
34
+ if15 = 1.0 / (6.0 * 20.0 * 42.0 * 72.0 * 110.0 * 156.0 * 210.0)
35
+ pi = jnp.pi
36
+ pi_d_12 = pi / 12.0
37
+ pi_d_6 = pi / 6.0
38
+ pi_d_4 = pi / 4.0
39
+ pi_d_3 = pi / 3.0
40
+ fivepi_d_12 = 5.0 * pi / 12.0
41
+ pi_d_2 = pi / 2.0
42
+ sevenpi_d_12 = 7.0 * pi / 12.0
43
+ twopi_d_3 = 2.0 * pi / 3.0
44
+ threepi_d_4 = 3.0 * pi / 4.0
45
+ fivepi_d_6 = 5.0 * pi / 6.0
46
+ elevenpi_d_12 = 11.0 * pi / 12.0
47
+ two_pi = 2.0 * pi
48
+
49
+
50
def create_E_func_multi_e(e_vals, func):
    """Build a dispatcher over one solver per eccentricity value.

    Args:
        e_vals: Iterable of eccentricity values. Each value is used as a
            dictionary key, so it must be hashable.
        func: Factory called once per value as `func(e)`; it must return a
            callable that takes the mean anomaly `M`.

    Returns:
        function: `E_calc(e, M)`, which looks up the solver stored for `e`
        (exact key match, `KeyError` otherwise) and applies it to `M`.
    """
    solvers = {}
    for ecc in e_vals:
        solvers[ecc] = func(ecc)

    def E_calc(e, M):
        solver = solvers[e]
        return solver(M)

    return E_calc
59
+
60
+
61
def get_E_solver(e: float):
    """Return a jitted `solve_E` with the eccentricity bound to `e`.

    Args:
        e (float): Eccentricity to bind as the `e` keyword of `solve_E`.

    Returns:
        function: A `jax.jit`-wrapped callable that takes only `M`.
    """
    return jax.jit(partial(solve_E, e=e))
66
+
67
+
68
@jax.jit
def solve_E(M, e):
    """Compute the eccentric anomaly for mean anomalies `M` at eccentricity `e`.

    `M` is wrapped into [0, 2*pi) and then handed to `le_E` when `e < 0.78`
    or to `he_E` otherwise.
    """
    # NOTE: the e == 0 case is not handled for now. It would look something
    # like this:
    #   e_ind = jnp.select([e == 0, e < 0.78], [0, 1], default=2)
    #   return lax.switch(e_ind, [identity_solver, le_E, he_E], M, e)
    # where identity_solver returns M as E.
    wrapped_M = jnp.mod(M, two_pi)
    is_low_e = e < 0.78
    return lax.cond(is_low_e, le_E, he_E, wrapped_M, e)
76
+
77
+
78
def get_E_trig_solver(e: float):
    """Return a jitted `solve_E_trig` with the eccentricity bound to `e`.

    Args:
        e (float): Eccentricity to bind as the `e` keyword of `solve_E_trig`.

    Returns:
        function: A `jax.jit`-wrapped callable that takes only `M` and
        returns whatever `solve_E_trig` returns (E, sin E, cos E).
    """
    return jax.jit(partial(solve_E_trig, e=e))
83
+
84
+
85
@jax.jit
def solve_E_trig(M, e):
    """Compute the eccentric anomaly together with its sine and cosine.

    Args:
        M (jnp.ndarray): Mean anomalies (rad). Shape: (n,).
        e (float): Eccentricity. Values below 0.78 use `le_E_trig`, all
            others use `he_E_trig`.

    Returns:
        Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: E, sin(E) and cos(E).
    """
    # Wrap M into [0, 2*pi) exactly as solve_E does: the downstream helpers
    # (cut_M and the bounds lookup) assume M lies in that range.
    M = jnp.mod(M, two_pi)
    return lax.cond(e < 0.78, le_E_trig, he_E_trig, M, e)
88
+
89
+
90
def get_E_guesser(e: float):
    """Return a jitted `guess_E` with the eccentricity bound to `e`.

    Args:
        e (float): The eccentricity value, bound as the `e` keyword of
            `guess_E` via `functools.partial`.

    Returns:
        function: A `jax.jit`-wrapped callable that takes `M` as input.
    """
    return jax.jit(partial(guess_E, e=e))
105
+
106
+
107
@jax.jit
def guess_E(M: jnp.ndarray, e: float):
    """Give the initial guess for E based on the constructed polynomials.

    Args:
        M (jnp.ndarray): Mean anomalies (rad). Shape: (n,).
        e (float): Eccentricity.

    Returns:
        jnp.ndarray: Polynomial estimate of the eccentric anomaly, without
        any refinement step.
    """
    # Wrap M into [0, 2*pi) like solve_E does; cut_M and the bounds lookup
    # assume that range. Values already in range are left unchanged.
    M = jnp.mod(M, two_pi)
    bounds, coeffs = getbounds(e)
    Esigns, _M = cut_M(M)
    init_E = init_E_coeffs(_M, bounds, coeffs)
    return jnp.fmod(Esigns * init_E + two_pi, two_pi)
114
+
115
+
116
def get_E_lookup(e, n=1800):
    """Build a jitted table lookup of E on an `n`-point grid of M.

    The table holds `solve_E` evaluated on `n` evenly spaced mean anomalies
    in [0, 2*pi). The returned function scales `M` to index space, casts to
    int32, wraps the index modulo `n` and returns the stored value without
    interpolation.

    Args:
        e: Eccentricity passed to `solve_E`.
        n (int): Number of table entries.

    Returns:
        function: `E_lookup(M)`.
    """
    grid_M = jnp.linspace(0, two_pi, n, endpoint=False)
    table_E = solve_E(grid_M, e)
    step = grid_M[1] - grid_M[0]
    bins_per_rad = 1.0 / step
    n_bins = jnp.int32(n)

    @jax.jit
    def E_lookup(M):
        bin_idx = (M * bins_per_rad).astype(jnp.int32)
        return table_E[bin_idx % n_bins]

    return E_lookup
128
+
129
+
130
def get_E_lookup_d(e, n=1800):
    """Build a jitted first-order lookup of E on an `n`-point grid of M.

    Each table entry stores E from `solve_E` and the slope
    dE/dM = 1 / (1 - e * cos(E)) scaled by the grid step, i.e. the change in
    E per unit of table index. The returned function adds slope times the
    fractional index to the stored value.

    Args:
        e: Eccentricity passed to `solve_E`.
        n (int): Number of table entries.

    Returns:
        function: `interp_E(M)`.
    """
    # Evenly spaced mean anomaly grid and its spacing
    grid_M = jnp.linspace(0, two_pi, n, endpoint=False)
    step = grid_M[1] - grid_M[0]
    bins_per_rad = 1.0 / step
    n_bins = jnp.int32(n)

    # Tabulated E and its derivative with respect to the table index
    table_E = solve_E(grid_M, e)
    slope_per_index = step / (1 - e * jnp.cos(table_E))

    @jax.jit
    def interp_E(M):
        # Position of M in index space, split into whole and fractional parts
        position = M * bins_per_rad
        whole = position.astype(jnp.int32)
        frac = position - whole
        # Wrap the whole part onto the table with modulo n
        slot = whole % n_bins
        return table_E[slot] + slope_per_index[slot] * frac

    return interp_E
157
+
158
+
159
def get_E_lookup_hermite(e, n=1800):
    """Build a jitted cubic-Hermite lookup of E on an `n`-point grid of M.

    The table holds `solve_E` on `n` evenly spaced mean anomalies in
    [0, 2*pi) together with a finite-difference derivative per table index.
    E(M) advances by 2*pi per period of M, so both the derivative and the
    right-hand node of the last interval are unwrapped by 2*pi; without that
    the interpolant in the last interval would fall from about 2*pi to 0.

    Args:
        e: Eccentricity passed to `solve_E`.
        n (int): Number of table entries.

    Returns:
        function: `E_lookup(M)` returning the interpolated eccentric anomaly.
        For M in the last grid interval the value lies just below 2*pi.
    """
    # Generate linearly spaced mean anomaly values
    M_vals = jnp.linspace(0, two_pi, n, endpoint=False)
    # Compute the step size in mean anomaly per index
    dM = M_vals[1] - M_vals[0]
    # Solve for eccentric anomaly for each mean anomaly
    E_vals = solve_E(M_vals, e)
    # Periodic central difference of E with respect to the index. The
    # neighbours that cross the 0 / 2*pi seam are shifted by one period so
    # the difference stays continuous there.
    E_after = jnp.roll(E_vals, -1).at[-1].add(two_pi)
    E_before = jnp.roll(E_vals, 1).at[0].add(-two_pi)
    dE_dind = 0.5 * (E_after - E_before)

    @jax.jit
    def E_lookup(M):
        # Wrap M to the range [0, two_pi)
        _M = jnp.mod(M, two_pi)
        # Determine the fractional index and the integer part
        dind, ind_float = jnp.modf(jnp.true_divide(_M, dM))
        # Rounding in the division can yield n; wrap it back onto the table
        ind_int = ind_float.astype(int) % n
        # Handle wrap-around for the next index
        ind_next = (ind_int + 1) % n

        # Retrieve function values and derivatives at the surrounding indices
        E_i = E_vals[ind_int]
        # When the right-hand node wraps to index 0 it represents E = 2*pi
        E_ip1 = E_vals[ind_next] + jnp.where(ind_next == 0, two_pi, 0.0)
        dE_i = dE_dind[ind_int]
        dE_ip1 = dE_dind[ind_next]

        # Hermite basis functions
        h00 = 2 * dind**3 - 3 * dind**2 + 1
        h10 = dind**3 - 2 * dind**2 + dind
        h01 = -2 * dind**3 + 3 * dind**2
        h11 = dind**3 - dind**2

        # Perform Hermite interpolation
        E_interp = h00 * E_i + h10 * dE_i + h01 * E_ip1 + h11 * dE_ip1

        return E_interp

    return E_lookup
198
+
199
+
200
def shortsin(x):
    """Approximate sin(x) with its Taylor polynomial through the x**15 term.

    The polynomial is evaluated in nested (Horner) form in x**2 using the
    reciprocal-factorial constants if3 ... if15. Per the original note, this
    is only valid between [0, π].
    """
    x2 = x * x
    # Work outwards from the innermost bracket of the nested form
    acc = if15
    for coeff in (if13, if11, if9, if7, if5, if3):
        acc = coeff - x2 * acc
    return x * (1 - x2 * acc)
215
+
216
+
217
def cut_M(M: jnp.ndarray):
    """Fold mean anomalies above pi back below it and record the sign of E.

    Entries with `M > pi` are replaced by `2*pi - M` and get sign -1; all
    other entries are passed through with sign +1.

    Args:
        M (jnp.ndarray):
            Mean anomalies (rad). Shape: (n,).

    Returns:
        Esigns (jnp.ndarray):
            Sign of the eccentric anomaly (+1 or -1). Shape: (n,).
        _M (jnp.ndarray):
            Folded mean anomalies. Shape: (n,).
    """
    past_half = M > pi
    folded = jnp.where(past_half, two_pi - M, M)
    signs = jnp.where(past_half, -1, 1)
    return signs, folded
236
+
237
+
238
def getbounds(e: float):
    """Create bounds and coefficients for the eccentric anomaly polynomial.

    The eccentric anomaly range [0, pi] is split at the 13 nodes
    E_k = k * pi / 12. `bounds` holds the mean anomaly at each node,
    M_k = E_k - e * sin(E_k), and `coeffs` holds one fifth-order polynomial
    in (M - M_k) for each of the 12 intervals between consecutive nodes.

    Args:
        e (float): Eccentricity

    Returns:
        tuple:
            bounds (jnp.ndarray):
                Mean anomaly at each node. Shape: (13,)
            coeffs (jnp.ndarray)
                Lookup table containing the polynomial coefficients, one row
                per interval, ordered from the constant term to the fifth
                order term. Shape: (12, 6)
    """
    # e * sin(k * pi / 12) for k = 1..5 (the literals are sin of 15, 30, 45,
    # 60 and 75 degrees)
    g2s_e = 0.2588190451025207623489 * e
    g3s_e = 0.5 * e
    g4s_e = 0.7071067811865475244008 * e
    g5s_e = 0.8660254037844386467637 * e
    g6s_e = 0.9659258262890682867497 * e

    # e * cos(k * pi / 12), using cos(x) = sin(pi / 2 - x)
    g2c_e = g6s_e
    g3c_e = g5s_e
    g4c_e = g4s_e
    g5c_e = g3s_e
    g6c_e = g2s_e

    # Mean anomaly at each node: M_k = E_k - e * sin(E_k)
    bounds = jnp.array(
        [
            0.0,
            pi_d_12 - g2s_e,
            pi_d_6 - g3s_e,
            pi_d_4 - g4s_e,
            pi_d_3 - g5s_e,
            fivepi_d_12 - g6s_e,
            pi_d_2 - e,
            sevenpi_d_12 - g6s_e,
            twopi_d_3 - g5s_e,
            threepi_d_4 - g4s_e,
            fivepi_d_6 - g3s_e,
            elevenpi_d_12 - g2s_e,
            pi,
        ]
    )

    # First-order coefficient at each of the 13 nodes:
    # dE/dM = 1 / (1 - e * cos(E_k))
    ai1 = jnp.array(
        [
            1.0 / (1.0 - e),
            1.0 / (1.0 - g2c_e),
            1.0 / (1.0 - g3c_e),
            1.0 / (1.0 - g4c_e),
            1.0 / (1.0 - g5c_e),
            1.0 / (1.0 - g6c_e),
            1.0,
            1.0 / (1.0 + g6c_e),
            1.0 / (1.0 + g5c_e),
            1.0 / (1.0 + g4c_e),
            1.0 / (1.0 + g3c_e),
            1.0 / (1.0 + g2c_e),
            1.0 / (1.0 + e),
        ]
    )
    # Second-order coefficient at each node:
    # -0.5 * e * sin(E_k) * (dE/dM)**3, i.e. half the second derivative
    ai2 = (
        jnp.array(
            [
                0,
                -0.5 * g2s_e,
                -0.5 * g3s_e,
                -0.5 * g4s_e,
                -0.5 * g5s_e,
                -0.5 * g6s_e,
                -0.5 * e,
                -0.5 * g6s_e,
                -0.5 * g5s_e,
                -0.5 * g4s_e,
                -0.5 * g3s_e,
                -0.5 * g2s_e,
                0,
            ]
        )
        * ai1**3
    )

    # Index of the lower bound of the interval
    i = jnp.arange(12)
    # Set the 0th coefficient of the polynomials (E at the lower node)
    ai0 = i * pi_d_12

    # Set the 3rd, 4th, and 5th coefficients of the polynomials with array
    # operations since they are solved based on the 1st and 2nd coefficients
    # NOTE(review): presumably chosen so each polynomial also matches the
    # value and first two derivatives at the upper node - confirm against the
    # reference implementation.
    ii = i + 1
    # Reciprocal width of each interval in M
    idx = 1.0 / (bounds[ii] - bounds[i])
    B0 = idx * (-ai2[i] - idx * (ai1[i] - idx * pi_d_12))
    B1 = idx * (-2.0 * ai2[i] - idx * (ai1[i] - ai1[ii]))
    B2 = idx * (ai2[ii] - ai2[i])
    ai3 = B2 - 4.0 * B1 + 10.0 * B0
    ai4 = (-2.0 * B2 + 7.0 * B1 - 15.0 * B0) * idx
    ai5 = (B2 - 3.0 * B1 + 6.0 * B0) * idx**2
    # The last node only closes the final interval, so its ai1/ai2 entries
    # are dropped when assembling the per-interval rows
    coeffs = jnp.stack([ai0, ai1[:-1], ai2[:-1], ai3, ai4, ai5], axis=1)

    return bounds, coeffs
341
+
342
+
343
def init_E_poly(M, e):
    """Closed-form initial estimate of the eccentric anomaly.

    Builds the estimate from `M` scaled by (1 - e)**1.5 through a cube-root
    expression followed by a two-term correction in (1 - e). Translated from
    the C implementation into JAX. Within this file it is selected by `he_E`
    and `he_E_trig` only where `2 * M + (1 - e) <= 0.2`.

    Parameters:
        M (jnp.ndarray):
            Mean anomaly in radians.
        e (float):
            Eccentricity of the orbit.

    Returns:
        jnp.ndarray:
            Initial estimate of the eccentric anomaly in radians.
    """
    ome = 1.0 - e
    root_ome = lax.sqrt(ome)

    # Scaled mean anomaly and the cube-root term derived from it
    chi = M / (root_ome * ome)
    lam = lax.sqrt(8.0 + 9.0 * chi**2)
    cube_root = lax.cbrt(lam + 3.0 * chi)
    cr_sq = cube_root * cube_root
    sigma = 6.0 * chi / (2.0 + cr_sq + 4.0 / cr_sq)

    # Correction series in (1 - e)
    s2 = sigma * sigma
    s2_ome = s2 * ome
    denom = s2 + 2.0
    first_term = (s2 + 20.0) / (60.0 * denom)
    quartic = s2**3 + 25.0 * s2**2 + 340.0 * s2 + 840.0
    second_term = s2_ome * quartic / (1400.0 * denom**3)

    estimate = sigma * (1.0 + s2_ome * (first_term + second_term))
    return estimate * root_ome
381
+
382
+
383
def init_E_coeffs(M: jnp.ndarray, bounds: jnp.ndarray, coeffs: jnp.ndarray):
    """Create the initial guess for the eccentric anomaly using the polynomials.

    Each `M` is assigned to the interval of `bounds` that contains it, and
    that interval's fifth-order polynomial in `M - bounds[j]` is evaluated
    in nested (Horner) form.

    Args:
        M (jnp.ndarray): Mean anomalies, expected within
            [bounds[0], bounds[-1]].
        bounds (jnp.ndarray): Increasing interval edges. Shape: (k + 1,).
        coeffs (jnp.ndarray): Polynomial coefficients per interval, constant
            term first. Shape: (k, 6).

    Returns:
        jnp.ndarray: Polynomial estimate of E for each entry of `M`.
    """
    # j_inds = jnp.searchsorted(bounds, M, side="right") - 1
    # digitize puts M == bounds[-1] (M = pi in this module) one past the last
    # interval. Clamp so that point, and any value outside the bounds, is
    # evaluated with the nearest real interval instead of a mismatched pair
    # of bound and coefficient rows.
    last_interval = coeffs.shape[0] - 1
    j_inds = jnp.clip(jnp.digitize(M, bounds) - 1, 0, last_interval)
    dx = M - bounds[j_inds]
    return coeffs[j_inds, 0] + dx * (
        coeffs[j_inds, 1]
        + dx
        * (
            coeffs[j_inds, 2]
            + dx
            * (coeffs[j_inds, 3] + dx * (coeffs[j_inds, 4] + dx * coeffs[j_inds, 5]))
        )
    )
397
+
398
+
399
def dE_num_denom(M, E, e_inv, sinE, cosE):
    """Return the numerator and denominator shared by the dE updates.

    Both are the Kepler residual and its derivative divided by `e`:
    `(M - E) / e + sin(E)` and `1 / e - cos(E)`, with `e_inv = 1 / e`.
    """
    return (M - E) * e_inv + sinE, e_inv - cosE
404
+
405
+
406
def dE_2nd(M, E, e_inv, sinE, cosE):
    """Compute the second order approximation of dE.

    Uses the numerator and denominator from `dE_num_denom` and adds a
    `0.5 * sin(E) * num` correction to the squared denominator.
    """
    num, denom = dE_num_denom(M, E, e_inv, sinE, cosE)
    corrected = denom * denom + 0.5 * sinE * num
    return num * denom / corrected
410
+
411
+
412
def dE_3rd(M, E, e_inv, sinE, cosE):
    """Compute the third order approximation of dE.

    Extends the second order update with an additional `cos(E)` term,
    weighted by 1/6 (`if3`), in the divisor.
    """
    num, denom = dE_num_denom(M, E, e_inv, sinE, cosE)
    denom_sq = denom * denom
    top = num * (denom_sq + 0.5 * num * sinE)
    bottom = denom_sq * denom + num * (denom * sinE + if3 * num * cosE)
    return top / bottom
421
+
422
+
423
def compute_dE_single(M, init_E_val, e_inv_val, sinE_val, cosE_val):
    """Compute dE for one element, choosing the update order from `M`.

    `dE_2nd` is used when `M > 0.4` and `dE_3rd` otherwise.

    Args:
        M (float): Single element from _M.
        init_E_val (float): Corresponding element from init_E.
        e_inv_val (float): Inverse of eccentricity.
        sinE_val (float): Sine of E.
        cosE_val (float): Cosine of E.

    Returns:
        float: Computed dE for the element.
    """
    operands = (M, init_E_val, e_inv_val, sinE_val, cosE_val)
    use_second_order = M > 0.4
    return jax.lax.cond(use_second_order, dE_2nd, dE_3rd, *operands)


# Element-wise version of compute_dE_single: every argument is mapped over
# its leading axis except the inverse eccentricity, which is shared.
compute_dE_vectorized = jax.vmap(
    compute_dE_single,
    in_axes=(0, 0, None, 0, 0),
    out_axes=0,
)
445
+
446
+
447
def le_E(M: jnp.ndarray, e: float):
    """Inverts Kepler's time equation for elliptical orbits using Orvara's method.

    This is the branch `solve_E` selects for `e < 0.78`: a polynomial
    initial guess followed by one second order correction.

    Args:
        M (jnp.ndarray): Mean anomalies (rad), expected in [0, 2*pi).
            Shape: (n,).
        e (float): Eccentricity. The code divides by `e`, so it must be
            nonzero.

    Returns:
        - E (jnp.ndarray): Eccentric anomalies (rad). Shape: (n,).
    """
    # Fold M onto [0, pi] and remember which half each value came from
    signs, folded_M = cut_M(M)

    # Polynomial starting value on the folded mean anomaly
    edges, table = getbounds(e)
    E0 = init_E_coeffs(folded_M, edges, table)

    # Single second order refinement step
    sin0, cos0 = fast_sinE_cosE(E0)
    step = dE_2nd(folded_M, E0, 1.0 / e, sin0, cos0)

    # Undo the fold and map the result back into [0, 2*pi)
    return jnp.fmod(signs * (E0 + step) + two_pi, two_pi)
468
+
469
+
470
def le_E_trig(M: jnp.ndarray, e: float):
    """Inverts Kepler's time equation for elliptical orbits using Orvara's method.

    Also returns the sine and cosine of the eccentric anomaly, obtained by
    applying a second order angle-addition update to the sine and cosine of
    the initial guess.

    Args:
        M (jnp.ndarray): Mean anomalies (rad), expected in [0, 2*pi).
            Shape: (n,).
        e (float): Eccentricity. The code divides by `e`, so it must be
            nonzero.

    Returns:
        Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
            - E (jnp.ndarray): Eccentric anomalies (rad). Shape: (n,).
            - sinE (jnp.ndarray): Sine of eccentric anomalies. Shape: (n,).
            - cosE (jnp.ndarray): Cosine of eccentric anomalies. Shape: (n,).
    """
    # Get bounds and coeffs
    bounds, coeffs = getbounds(e)
    # Cut M to be between 0 and pi
    Esigns, _M = cut_M(M)

    # Get initial guess
    init_E = init_E_coeffs(_M, bounds, coeffs)
    sinE0, cosE0 = fast_sinE_cosE(init_E)
    dE = dE_2nd(_M, init_E, 1.0 / e, sinE0, cosE0)
    E = jnp.fmod(Esigns * (init_E + dE) + two_pi, two_pi)

    # sin/cos of (init_E + dE) to second order in dE. Both updates must use
    # the sine and cosine of the initial guess, so they are kept under
    # separate names rather than overwritten between the two lines.
    cos_dE = 1.0 - 0.5 * dE * dE
    sinE = Esigns * (sinE0 * cos_dE + dE * cosE0)
    cosE = cosE0 * cos_dE - dE * sinE0
    return E, sinE, cosE
498
+
499
+
500
def he_E(M: jnp.ndarray, e: float):
    """Inverts Kepler's time equation for elliptical orbits at high eccentricity.

    This is the branch `solve_E` selects when `e < 0.78` is false. The
    starting value comes from `init_E_poly` where `2 * M + (1 - e) <= 0.2`
    and from the interval polynomials elsewhere, followed by one correction
    from `compute_dE_vectorized`.

    Args:
        M (jnp.ndarray): Mean anomalies (rad), expected in [0, 2*pi).
            Shape: (n,).
        e (float): Eccentricity. The code divides by `e`, so it must be
            nonzero.

    Returns:
        - E (jnp.ndarray): Eccentric anomalies (rad). Shape: (n,).
    """
    inv_e = 1.0 / e
    edges, table = getbounds(e)
    # Fold M onto [0, pi] and remember which half each value came from
    signs, folded_M = cut_M(M)

    # Starting value
    # TODO: Come up with a way to do this without evaluating both functions
    from_table = init_E_coeffs(folded_M, edges, table)
    from_series = init_E_poly(folded_M, e)
    use_table = (2 * folded_M + (1 - e)) > 0.2
    E0 = jnp.where(use_table, from_table, from_series)

    # One refinement step, then undo the fold and wrap into [0, 2*pi)
    sin0, cos0 = fast_sinE_cosE(E0)
    step = compute_dE_vectorized(folded_M, E0, inv_e, sin0, cos0)
    return jnp.fmod(signs * (E0 + step) + two_pi, two_pi)
524
+
525
+
526
def he_E_trig(M: jnp.ndarray, e: float):
    """Inverts Kepler's time equation for elliptical orbits at high eccentricity.

    Also returns the sine and cosine of the eccentric anomaly, obtained by
    applying a third order angle-addition update to the sine and cosine of
    the initial guess.

    Args:
        M (jnp.ndarray): Mean anomalies (rad), expected in [0, 2*pi).
            Shape: (n,).
        e (float): Eccentricity. The code divides by `e`, so it must be
            nonzero.

    Returns:
        Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
            - E (jnp.ndarray): Eccentric anomalies (rad). Shape: (n,).
            - sinE (jnp.ndarray): Sine of eccentric anomalies. Shape: (n,).
            - cosE (jnp.ndarray): Cosine of eccentric anomalies. Shape: (n,).
    """
    bounds, coeffs = getbounds(e)
    e_inv = 1.0 / e
    # Cut M to be between 0 and pi
    Esigns, _M = cut_M(M)

    # Get initial guess
    # TODO: Come up with a way to do this without evaluating both functions
    cond1 = (2 * _M + (1 - e)) > 0.2
    init_E = jnp.where(cond1, init_E_coeffs(_M, bounds, coeffs), init_E_poly(_M, e))

    sinE0, cosE0 = fast_sinE_cosE(init_E)
    dE = compute_dE_vectorized(_M, init_E, e_inv, sinE0, cosE0)
    dEsq_d6 = dE**2 * if3
    E = jnp.fmod(Esigns * (init_E + dE) + two_pi, two_pi)

    # cos(dE) ~ 1 - dE**2 / 2 and sin(dE) ~ dE * (1 - dE**2 / 6). Both
    # updates must use the sine and cosine of the initial guess, so they are
    # kept under separate names rather than overwritten between the lines.
    cos_dE = 1 - 3 * dEsq_d6
    sin_dE = dE * (1 - dEsq_d6)
    sinE = Esigns * (sinE0 * cos_dE + sin_dE * cosE0)
    cosE = cosE0 * cos_dE - sin_dE * sinE0
    return E, sinE, cosE
556
+
557
+
558
def Etrig_1(E):
    """Sine and cosine for the first branch (selected when E <= pi_d_4).

    The sine comes from `shortsin(E)`; the cosine is the non-negative root
    of 1 - sin**2.
    """
    s = shortsin(E)
    c = lax.sqrt(1.0 - s**2)
    return s, c
563
+
564
+
565
def Etrig_2(E):
    """Sine and cosine for the middle branch (pi_d_4 < E < threepi_d_4).

    The cosine comes from `shortsin(pi_d_2 - E)`; the sine is the
    non-negative root of 1 - cos**2.
    """
    c = shortsin(pi_d_2 - E)
    s = lax.sqrt(1.0 - c**2)
    return s, c
570
+
571
+
572
def Etrig_3(E):
    """Sine and cosine for the last branch (selected when E >= threepi_d_4).

    The sine comes from `shortsin(pi - E)`; the cosine is the negative root
    of 1 - sin**2.
    """
    s = shortsin(pi - E)
    c = -lax.sqrt(1.0 - s**2)
    return s, c
577
+
578
+
579
def Etrig(i, E):
    """Dispatch to Etrig_1, Etrig_2 or Etrig_3 for branch index `i` (0, 1, 2)."""
    branches = (Etrig_1, Etrig_2, Etrig_3)
    return lax.switch(i, branches, E)
582
+
583
+
584
def fast_sinE_cosE(E):
    """Compute the sine and cosine of the eccentric anomaly using shortsin.

    Each element of `E` gets a branch index: 0 when `E <= pi_d_4`, 1 when
    `E < threepi_d_4`, and 2 otherwise. `Etrig` is then mapped over the
    leading axis of both the indices and `E`.

    Returns:
        Tuple of (sinE, cosE) arrays.
    """
    conditions = [E <= pi_d_4, E < threepi_d_4]
    branch_idx = jnp.select(conditions, [0, 1], default=2)
    per_element = jax.vmap(Etrig, in_axes=(0, 0))
    return per_element(branch_idx, E)
590
+
591
+
592
def identity_solver(M, e):
    """Return the mean anomaly unchanged as E (the e == 0 solution); `e` is unused."""
    E = M
    return E
@@ -0,0 +1,43 @@
1
+ Metadata-Version: 2.3
2
+ Name: orbix
3
+ Version: 0.0.2
4
+ Summary: A JAX library of functions useful for exoplanet simulations
5
+ Project-URL: Homepage, https://github.com/CoreySpohn/orbix
6
+ Project-URL: Issues, https://github.com/CoreySpohn/orbix/issues
7
+ License: MIT License
8
+
9
+ Copyright (c) 2024 Corey Spohn
10
+
11
+ Permission is hereby granted, free of charge, to any person obtaining a copy
12
+ of this software and associated documentation files (the "Software"), to deal
13
+ in the Software without restriction, including without limitation the rights
14
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15
+ copies of the Software, and to permit persons to whom the Software is
16
+ furnished to do so, subject to the following conditions:
17
+
18
+ The above copyright notice and this permission notice shall be included in all
19
+ copies or substantial portions of the Software.
20
+
21
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
27
+ SOFTWARE.
28
+ License-File: LICENSE
29
+ Requires-Python: >=3.10
30
+ Requires-Dist: jax>=0.4.34
31
+ Requires-Dist: numpy>=2.1.2
32
+ Provides-Extra: docs
33
+ Requires-Dist: matplotlib; extra == 'docs'
34
+ Requires-Dist: myst-nb; extra == 'docs'
35
+ Requires-Dist: sphinx; extra == 'docs'
36
+ Requires-Dist: sphinx-autoapi; extra == 'docs'
37
+ Requires-Dist: sphinx-autodoc-typehints; extra == 'docs'
38
+ Requires-Dist: sphinx-book-theme; extra == 'docs'
39
+ Description-Content-Type: text/markdown
40
+
41
+ # orbix
42
+
43
+ A set of fast exoplanet functions in JAX.
@@ -0,0 +1,7 @@
1
+ orbix/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ orbix/_version.py,sha256=NDHlyIcJZjLz8wKlmD1-pr6me5FHBAYwO_ynLG-37N8,411
3
+ orbix/eccanom.py,sha256=hlN1_dQEsKHlCjs6KHAQWmWWAXZq5b8xb3U8Dx3aCJw,17913
4
+ orbix-0.0.2.dist-info/METADATA,sha256=hpqrOh4tn16XYE5pUuf50iaE9RFa3QLleHeTM3814is,1968
5
+ orbix-0.0.2.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
6
+ orbix-0.0.2.dist-info/licenses/LICENSE,sha256=F5bToKEdj7k_wULTMxj-j2i1fIk6MVKjXHL0jrqK_y8,1068
7
+ orbix-0.0.2.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.25.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Corey Spohn
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.