microlux 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- microlux/__init__.py +26 -0
- microlux/basic_function.py +236 -0
- microlux/countour.py +548 -0
- microlux/error_estimator.py +186 -0
- microlux/limb_darkening.py +51 -0
- microlux/linear_sum_assignment.py +333 -0
- microlux/model.py +274 -0
- microlux/polynomial_solver.py +320 -0
- microlux/solution.py +506 -0
- microlux/utils.py +136 -0
- microlux-0.1.0.dist-info/METADATA +86 -0
- microlux-0.1.0.dist-info/RECORD +15 -0
- microlux-0.1.0.dist-info/WHEEL +5 -0
- microlux-0.1.0.dist-info/licenses/LICENSE +21 -0
- microlux-0.1.0.dist-info/top_level.txt +1 -0
microlux/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# # -*- coding: utf-8 -*-
|
|
2
|
+
# Public API of the package; each name is re-exported by the imports below.
# It must be bound to ``__all__`` (not ``all``): only ``__all__`` controls what
# ``from microlux import *`` exports, and a plain module-level ``all`` would be
# star-exported itself, shadowing the builtin ``all()`` in the importer.
__all__ = [
    "point_light_curve",
    "extended_light_curve",
    "contour_integral",
    "binary_mag",
    "Iterative_State",
    "Error_State",
    "to_lowmass",
    "to_centroid",
]
|
|
12
|
+
|
|
13
|
+
from .basic_function import (
|
|
14
|
+
to_centroid as to_centroid,
|
|
15
|
+
to_lowmass as to_lowmass,
|
|
16
|
+
)
|
|
17
|
+
from .countour import contour_integral as contour_integral
|
|
18
|
+
from .model import (
|
|
19
|
+
binary_mag as binary_mag,
|
|
20
|
+
extended_light_curve as extended_light_curve,
|
|
21
|
+
point_light_curve as point_light_curve,
|
|
22
|
+
)
|
|
23
|
+
from .utils import (
|
|
24
|
+
Error_State as Error_State,
|
|
25
|
+
Iterative_State as Iterative_State,
|
|
26
|
+
)
|
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import jax.numpy as jnp
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
# Switch JAX to 64-bit floats / complex128. This is a global JAX configuration
# flag, so it is applied as a side effect of importing this module and affects
# every JAX computation in the process afterwards.
jax.config.update("jax_enable_x64", True)
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def to_centroid(s, q, x):
    """
    Map a coordinate from the low-mass-object frame into the centroid frame.

    The map conjugates ``x``, flips its sign and shifts it along the real axis
    by ``s / (1 + q)``. Applying it twice returns the original coordinate.

    Parameters:
        s (float): The projected separation between the two objects.
        q (float): The planet to host mass ratio.
        x (complex): The original coordinate.

    Returns:
        complex: The transformed coordinate in the centroid system.
    """
    offset = s / (1 + q)
    mirrored = jnp.conj(x)
    return -(mirrored - offset)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def to_lowmass(s, q, x):
    """
    Map a centroid-frame coordinate into the frame centred on the lower-mass object.

    The result is ``s / (1 + q) - conj(x)``, i.e. a mirror through the real
    axis, a sign flip, and a shift along the real axis.

    Parameters:
        s (float): The separation between the two components.
        q (float): The mass ratio of the two components.
        x (complex): The original centroid coordinate.

    Returns:
        complex: The transformed coordinate in the low mass component coordinate system.
    """
    shift = s / (1 + q)
    return shift - jnp.conj(x)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def Quadrupole_test(rho, s, q, zeta, z, cond, tol=1e-2):
    """
    Decide, per source position, whether the point-source approximation is accurate enough.

    Combines the quadrupole (+ cusp) test, the ghost-image test and the
    planetary-caustic test proposed by Bozza 2010 (as used in VBBL). The
    coefficients cQ, cG and cP are fine-tuned in our implementation.

    Parameters:
        rho (float): Source radius.
        s (float): Lens separation. In this frame the primary, of mass
            m1 = 1/(1+q), sits at ``s`` and the companion, of mass
            m2 = q/(1+q), sits at the origin.
        q (float): Mass ratio m2/m1.
        zeta (complex array): Source position(s). Only column 0 of the
            planetary test is kept (``[:, 0]``), so presumably shape (n, 1)
            broadcast against ``z``. TODO confirm with caller.
        z (complex array): Candidate image positions (polynomial roots),
            reduced over ``axis=1``. Presumably shape (n, 5); confirm.
        cond (bool array): Broadcastable against ``z``. True marks roots that
            are treated as real images; the other roots are treated as ghosts.
        tol (float): Threshold of the quadrupole test.

    Returns:
        tuple: ``(passed, mag)``.
            ``passed`` is a boolean array, True where all three tests pass.
            ``mag`` is the point-source magnification, i.e. the sum of
            ``|1/J|`` over the roots flagged by ``cond``.
    """
    m1 = 1 / (1 + q)
    m2 = q / (1 + q)
    cQ = 2
    cG = 3
    cP = 4  # tunable coefficients; VBBL (2018, v3.6.2) uses cQ=3, cG=3 (its source has a typo writing cG=miu_G), cP=4

    # derivatives of f(z) = -m1/(z - s) - m2/z, the deflection term of the conjugated lens equation
    fz0 = lambda z: -m1 / (z - s) - m2 / z
    fz1 = lambda z: m1 / (z - s) ** 2 + m2 / z**2
    fz2 = lambda z: -2 * m1 / (z - s) ** 3 - 2 * m2 / z**3
    fz3 = lambda z: 6 * m1 / (z - s) ** 4 + 6 * m2 / z**4
    # Jacobian determinant of the lens map: 1 - |f'(z)|^2
    J = lambda z: 1 - fz1(z) * jnp.conj(fz1(z))

    ####Quadrupole test
    miu_Q = jnp.abs(
        -2
        * jnp.real(
            3 * jnp.conj(fz1(z)) ** 3 * fz2(z) ** 2
            - (3 - 3 * J(z) + J(z) ** 2 / 2) * jnp.abs(fz2(z)) ** 2
            + J(z) * jnp.conj(fz1(z)) ** 2 * fz3(z)
        )
        / (J(z) ** 5)
    )

    # cusp test
    miu_C = jnp.abs(jnp.imag(3 * jnp.conj(fz1(z)) ** 3 * fz2(z) ** 2) / (J(z) ** 5))
    # point-source magnification: sum of |1/J| over the real images only
    mag = jnp.sum(jnp.where(cond, jnp.abs(1 / J(z)), 0), axis=1)
    # the 1e-4 * tol floor keeps the test from passing trivially when rho is ~0
    cond1 = (
        jnp.sum(jnp.where(cond, (miu_Q + miu_C), 0), axis=1)
        * cQ
        * (rho**2 + 1e-4 * tol)
        < tol
    )

    ####ghost image test
    # conj(z) implied by the conjugated lens equation; equals conj(z) for true images
    zwave = jnp.conj(zeta) - fz0(z)
    J_wave = 1 - fz1(z) * fz1(zwave)
    J3 = J_wave * fz2(jnp.conj(z)) * fz1(z)
    miu_G = jnp.abs((J3 - jnp.conj(J3) * fz1(zwave)) / (J(z) * J_wave**2))
    # only ghost roots (cond False) take part in the ghost test
    miu_G = jnp.where(cond, 0, miu_G)
    cond2 = ((rho + 1e-3) * miu_G * cG < 1).all(axis=1)  # all() matches the VBBL code

    #####planet test # in our frame primary is at s, the planet is at 0, so the position of the planetary caustic is 1/s
    # automatically passes for q > 1e-2; otherwise the source must stay away from the planetary caustic
    cond3 = (
        (q > 1e-2)
        | (
            jnp.abs(zeta - 1 / s) ** 2
            > cP * (rho**2 + 9 * q / s**2)  # the extra rho**2*s**2<q condition is commented out in vbbl 3.6.2
        )
    )[:, 0]

    return cond1 & cond2 & cond3, mag
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def get_poly_coff(zeta_l, s, m2):
    """
    Get the coefficients of the fifth-order polynomial equivalent to the lens equation.

    Frame: the low-mass object is at the origin and the primary is at ``s``.
    Only ``m2`` is passed; the primary mass is implicit. It is presumably
    1 - m2, matching the unit total mass used elsewhere in this module; confirm.

    Parameters:
        zeta_l: Source position(s) in the low-mass frame. Should have shape
            (n, 1) so that the per-power columns broadcast and concatenate.
        s: Separation of the two lenses.
        m2: Mass of the object at the origin.

    Returns:
        Array of shape (n, 6): coefficients ordered from the z**5 term down to
        the constant term.
    """
    zb = jnp.conj(zeta_l)

    k5 = (s - zb) * zb
    k4 = zb * (-1 + 2 * s * zb + zb * zeta_l) - s * (
        -1 + 2 * s * zb + zb * zeta_l + m2
    )
    k3 = (
        s**3 * zb
        + 2 * zeta_l * zb
        + s**2 * (-1 + 2 * zb * zeta_l - zb**2 + m2)
        - s * (zeta_l + 2 * zeta_l * zb**2 - 2 * zb * m2)
    )
    k2 = (
        zeta_l
        - s**3 * zeta_l * zb
        + s * (-1 + m2 - 2 * zb * zeta_l * (1 + m2))
        + s**2 * (zb - 2 * zb * m2 + zeta_l * (1 + zb**2 + m2))
    )
    k1 = -s * m2 * (2 * zeta_l + s * (-1 + s * zeta_l - 2 * zeta_l * zb + m2))
    k0 = s**2 * zeta_l * m2**2

    return jnp.concatenate([k5, k4, k3, k2, k1, k0], axis=1)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def get_zeta_l(rho, trajectory_centroid_l, theta):
    """
    Get the source-boundary points sampled for the contour integration.

    Returns ``trajectory_centroid_l + rho * exp(1j * theta)``: points on a
    circle of radius ``rho`` around the source centre, at angle(s) ``theta``.
    """
    boundary_offset = rho * jnp.exp(1j * theta)
    return trajectory_centroid_l + boundary_offset
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def verify(zeta_l, z_l, s, m1, m2):
    """
    Verify whether ``z_l`` is a root of the lens equation.

    Returns the absolute residual
    ``|z_l - m1 / (conj(z_l) - s) - m2 / conj(z_l) - zeta_l|``;
    a value near zero means ``z_l`` is a true image of ``zeta_l``.
    """
    z_bar = jnp.conj(z_l)
    mapped = z_l - m1 / (z_bar - s) - m2 / z_bar
    return jnp.abs(mapped - zeta_l)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def get_parity(z, s, m1, m2):
    """
    Get the parity of each root.

    Computes the sign of ``1 - |m1 / (conj(z) - s)**2 + m2 / conj(z)**2|**2``,
    i.e. of the Jacobian determinant of the lens map at ``z``.
    """
    z_bar = jnp.conj(z)
    shear = m1 / (z_bar - s) ** 2 + m2 / z_bar**2
    jacobian = 1 - jnp.abs(shear) ** 2
    return jnp.sign(jacobian)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def get_parity_error(z, s, m1, m2):
    """
    Magnitude of the Jacobian determinant at each root.

    Returns ``|1 - |m1 / (conj(z) - s)**2 + m2 / conj(z)**2|**2|``; small values
    indicate roots close to a critical curve, where the parity is ill-determined.
    """
    z_bar = jnp.conj(z)
    shear = m1 / (z_bar - s) ** 2 + m2 / z_bar**2
    jacobian = 1 - jnp.abs(shear) ** 2
    return jnp.abs(jacobian)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def dot_product(a, b):
    """Euclidean dot product of two complex numbers viewed as 2-D vectors."""
    ax, ay = jnp.real(a), jnp.imag(a)
    bx, by = jnp.real(b), jnp.imag(b)
    return ax * bx + ay * by
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def basic_partial(z, theta, rho, q, s, caustic_crossing):
    """
    Basic partial derivatives of the lens equation with respect to zeta, z and theta, used in the error estimation.

    The source boundary is a circle of radius ``rho`` parameterised by
    ``theta``, so d zeta / d theta = 1j * rho * exp(1j * theta). Mass/frame
    convention: the primary (mass 1/(1+q)) is at ``s`` and the companion
    (mass q/(1+q)) is at the origin.

    Parameters:
        z: Image positions on the boundary. Presumably broadcastable against
            ``theta``; confirm with caller.
        theta: Boundary angle(s).
        rho: Source radius.
        q: Mass ratio.
        s: Lens separation.
        caustic_crossing: Predicate passed to ``jax.lax.cond``, so it must be a
            scalar boolean. When False, the third output is all zeros.

    Returns:
        tuple: ``(deXProde2X, de_z, de_deXPro_de2X)``.
            ``deXProde2X`` is x' ^ x'', the 2-D cross product of the first and
            second theta-derivatives of the image position. It is intended to
            equal Im(conj(z') * z''); see the commented-out check below.
            ``de_z`` is dz/dtheta.
            ``de_deXPro_de2X`` is d(x' ^ x'')/dtheta when ``caustic_crossing``
            is True, otherwise ``zeros_like(deXProde2X)``.
    """
    z_c = jnp.conj(z)
    # d zeta / d conj(z)
    parZetaConZ = 1 / (1 + q) * (1 / (z_c - s) ** 2 + q / z_c**2)
    # second derivative of conj(zeta) with respect to z
    par2ConZetaZ = -2 / (1 + q) * (1 / (z - s) ** 3 + q / (z) ** 3)
    # d zeta / d theta along the circular source boundary
    de_zeta = 1j * rho * jnp.exp(1j * theta)
    detJ = 1 - jnp.abs(parZetaConZ) ** 2
    # dz/dtheta from implicit differentiation of the lens equation
    de_z = (de_zeta - parZetaConZ * jnp.conj(de_zeta)) / detJ
    deXProde2X = (rho**2 + jnp.imag(de_z**2 * de_zeta * par2ConZetaZ)) / detJ

    def get_de_deXPro_de2X(carry):
        # Only calculate the derivative of x'^x'' with respect to theta when caustic_crossing is True; it is used in the e4 calculation.
        # Still needs testing whether this is robust enough when the source is very close to the caustic but not crossing it.
        # ``carry`` is the unused operand required by jax.lax.cond.

        de2_zeta = -rho * jnp.exp(1j * theta)
        de2_zetaConj = -rho * jnp.exp(-1j * theta)
        # third derivative of conj(zeta) with respect to z
        par3ConZetaZ = 6 / (1 + q) * (1 / (z - s) ** 4 + q / (z) ** 4)
        # d^2 z / d theta^2
        de2_z = (
            de2_zeta
            - jnp.conj(par2ConZetaZ) * jnp.conj(de_z) ** 2
            - parZetaConZ * (de2_zetaConj - par2ConZetaZ * de_z**2)
        ) / detJ
        # deXProde2X_test = 1/(2*1j)*(de2_z*jnp.conj(de_z)-de_z*jnp.conj(de2_z))
        # jax.debug.print('deXProde2X_test error is {}',jnp.nansum(jnp.abs(deXProde2X_test-deXProde2X)))
        de_deXPro_de2X = (
            1
            / detJ**2
            * jnp.imag(
                detJ
                * (
                    de2_zeta * par2ConZetaZ * de_z**2
                    + de_zeta * par3ConZetaZ * de_z**3
                    + de_zeta * par2ConZetaZ * 2 * de_z * de2_z
                )
                + (
                    jnp.conj(par2ConZetaZ) * jnp.conj(de_z) * jnp.conj(parZetaConZ)
                    + parZetaConZ * par2ConZetaZ * de_z
                )
                * de_zeta
                * par2ConZetaZ
                * de_z**2
            )
        )
        return de_deXPro_de2X

    # skip the third-derivative work entirely when no caustic crossing is flagged
    de_deXPro_de2X = jax.lax.cond(
        caustic_crossing, get_de_deXPro_de2X, lambda x: jnp.zeros_like(deXProde2X), None
    )
    # deXProde2X = jax.lax.stop_gradient(deXProde2X)
    return deXProde2X, de_z, de_deXPro_de2X
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
@jax.custom_jvp
def refine_gradient(zeta_l, q, s, z):
    """
    Identity on the image positions ``z`` with a hand-written derivative.

    The primal result is just ``z``. Its tangent is supplied by
    ``refine_gradient_jvp``, which differentiates the lens equation implicitly
    instead of differentiating through the root solver.
    """
    return z


@refine_gradient.defjvp
def refine_gradient_jvp(primals, tangents):
    """
    JVP rule refining the gradient of roots with respect to zeta_l, q and s.

    Based on V. Bozza 2010, eq. 20 (see also our paper). For the lens equation
        zeta = z - m1 / (conj(z) - s) - m2 / conj(z),  m1 = 1/(1+q), m2 = q/(1+q),
    write A = d zeta / d conj(z), B = d zeta / dq and C = d zeta / ds. Implicit
    differentiation then gives
        dz = (dzeta - A conj(dzeta) - dq (B - conj(B) A) - ds (C - conj(C) A)) / (1 - |A|^2).
    The incoming tangent of ``z`` is ignored. This simplifies the
    computational graph and accelerates the gradient calculation.
    """
    zeta, q, s, z = primals
    dot_zeta, dot_q, dot_s, _ignored_dot_z = tangents

    zbar = jnp.conj(z)
    # A = d zeta / d conj(z), and the Jacobian determinant 1 - |A|^2
    dzeta_dzbar = 1 / (1 + q) * (1 / (zbar - s) ** 2 + q / zbar**2)
    jac_det = 1 - jnp.abs(dzeta_dzbar) ** 2

    # explicit dependence of zeta on the mass ratio q
    dzeta_dq = 1 / (1 + q) ** 2 * (1 / (zbar - s) - 1 / zbar)
    q_term = dot_q * (dzeta_dq - jnp.conj(dzeta_dq) * dzeta_dzbar)

    # explicit dependence of zeta on the separation s
    dzeta_ds = -1 / (1 + q) / (zbar - s) ** 2
    s_term = dot_s * (dzeta_ds - jnp.conj(dzeta_ds) * dzeta_dzbar)

    dot_z = (dot_zeta - dzeta_dzbar * jnp.conj(dot_zeta) - q_term - s_term) / jac_det
    return z, dot_z
|