microlux 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
microlux/__init__.py ADDED
@@ -0,0 +1,26 @@
1
# -*- coding: utf-8 -*-
"""Public API of the ``microlux`` package.

The names listed in ``__all__`` are re-exported from the package's submodules.
"""

# The public API must be declared as ``__all__``: binding the bare name ``all``
# neither limits ``from microlux import *`` nor is harmless -- a star-import would
# then copy this list over the built-in ``all()`` in the importing module.
__all__ = [
    "point_light_curve",
    "extended_light_curve",
    "contour_integral",
    "binary_mag",
    "Iterative_State",
    "Error_State",
    "to_lowmass",
    "to_centroid",
]

# ``name as name`` marks each name as an explicit re-export for type checkers.
from .basic_function import (
    to_centroid as to_centroid,
    to_lowmass as to_lowmass,
)
from .countour import contour_integral as contour_integral
from .model import (
    binary_mag as binary_mag,
    extended_light_curve as extended_light_curve,
    point_light_curve as point_light_curve,
)
from .utils import (
    Error_State as Error_State,
    Iterative_State as Iterative_State,
)
@@ -0,0 +1,236 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+
4
+
5
+ jax.config.update("jax_enable_x64", True)
6
+
7
+
8
def to_centroid(s, q, x):
    """
    Map a coordinate into the centroid-centred coordinate system.

    The map is x -> -(conj(x) - s / (1 + q)): mirror about the real axis,
    flip the sign, and shift along the real axis by s / (1 + q). For real
    ``s`` and ``q`` the map is its own inverse.

    Parameters:
    s (float): The projected separation between the two objects.
    q (float): The planet to host mass ratio.
    x (complex): The original coordinate.

    Returns:
    complex: The transformed coordinate in the centroid system.
    """
    shift = s / (1 + q)
    mirrored = jnp.conj(x)
    return -(mirrored - shift)
22
+
23
+
24
def to_lowmass(s, q, x):
    """
    Map a centroid-frame coordinate into the frame centred on the low-mass object.

    The map is x -> -conj(x) + s / (1 + q), the same reflection-plus-shift
    used for the opposite direction, so the two transforms undo each other
    for real ``s`` and ``q``.

    Parameters:
    s (float): The separation between the two components.
    q (float): The mass ratio of the two components.
    x (complex): The original centroid coordinate.

    Returns:
    complex: The transformed coordinate in the low mass component coordinate system.
    """
    offset = s / (1 + q)
    flipped = -jnp.conj(x)
    return flipped + offset
38
+
39
+
40
def Quadrupole_test(rho, s, q, zeta, z, cond, tol=1e-2):
    """
    Decide, per source position, whether the point-source approximation is adequate.

    Implements the quadrupole test, the ghost-image test and the planetary-caustic
    test proposed by Bozza 2010, following the VBBL implementation. The
    coefficients cQ, cG and cP below are fine-tuned for this implementation.

    Frame (from the expressions below): the primary, with mass m1 = 1/(1+q), sits
    at ``s`` and the secondary, with mass m2 = q/(1+q), sits at the origin.

    Parameters:
    rho: source radius (enters the tests as rho**2 and rho + 1e-3).
    s (float): separation of the two lenses.
    q (float): mass ratio.
    zeta: source position(s). The planetary test keeps column 0 of its result, so
        zeta is expected to be 2-D (NOTE(review): presumably shape (n, 1) -- confirm).
    z: candidate image positions. Reductions use axis=1, so presumably shape
        (n, n_roots) -- confirm against the caller.
    cond: boolean mask broadcastable with z. Entries where it is True are summed
        for the magnification and the quadrupole/cusp terms (the physical images).
        Only entries where it is False contribute to the ghost-image test.
    tol (float): accuracy target used by the quadrupole test.

    Returns:
    tuple: (ok, mag). ``ok`` is a boolean array that is True where all three tests
    pass. ``mag`` is the sum of |1/J| over the entries selected by ``cond``
    (the point-source magnification).
    """
    m1 = 1 / (1 + q)
    m2 = q / (1 + q)
    cQ = 2
    cG = 3
    cP = 4  # tunable coefficients; VBBL (2018, version 3.6.2) chooses cQ=3, cG=miu_G (a VBBL typo), cP=4

    # Basic derivatives of the deflection term and the Jacobian determinant J.
    fz0 = lambda z: -m1 / (z - s) - m2 / z
    fz1 = lambda z: m1 / (z - s) ** 2 + m2 / z**2
    fz2 = lambda z: -2 * m1 / (z - s) ** 3 - 2 * m2 / z**3
    fz3 = lambda z: 6 * m1 / (z - s) ** 4 + 6 * m2 / z**4
    J = lambda z: 1 - fz1(z) * jnp.conj(fz1(z))

    # Quadrupole test.
    miu_Q = jnp.abs(
        -2
        * jnp.real(
            3 * jnp.conj(fz1(z)) ** 3 * fz2(z) ** 2
            - (3 - 3 * J(z) + J(z) ** 2 / 2) * jnp.abs(fz2(z)) ** 2
            + J(z) * jnp.conj(fz1(z)) ** 2 * fz3(z)
        )
        / (J(z) ** 5)
    )

    # Cusp test.
    miu_C = jnp.abs(jnp.imag(3 * jnp.conj(fz1(z)) ** 3 * fz2(z) ** 2) / (J(z) ** 5))
    # Point-source magnification: sum of |1/J| over the images selected by cond.
    mag = jnp.sum(jnp.where(cond, jnp.abs(1 / J(z)), 0), axis=1)
    cond1 = (
        jnp.sum(jnp.where(cond, (miu_Q + miu_C), 0), axis=1)
        * cQ
        * (rho**2 + 1e-4 * tol)
        < tol
    )

    # Ghost-image test: applied only to the entries NOT selected by cond.
    zwave = jnp.conj(zeta) - fz0(z)
    J_wave = 1 - fz1(z) * fz1(zwave)
    J3 = J_wave * fz2(jnp.conj(z)) * fz1(z)
    miu_G = jnp.abs((J3 - jnp.conj(J3) * fz1(zwave)) / (J(z) * J_wave**2))
    miu_G = jnp.where(cond, 0, miu_G)
    cond2 = ((rho + 1e-3) * miu_G * cG < 1).all(axis=1)  # all() matches the VBBL code

    # Planetary-caustic test. In this frame the primary is at s and the planet at 0,
    # so the planetary caustic is located near 1/s.
    cond3 = (
        (q > 1e-2)
        | (
            jnp.abs(zeta - 1 / s) ** 2
            > cP * (rho**2 + 9 * q / s**2)  # rho**2*s**2<q is commented out in VBBL 3.6.2
        )
    )[:, 0]

    return cond1 & cond2 & cond3, mag
98
+
99
+
100
def get_poly_coff(zeta_l, s, m2):
    """
    Coefficients of the complex polynomial form of the binary-lens equation.

    The low-mass object is at the origin and the primary is at ``s``. The roots of
    this polynomial include the image positions; some roots may not solve the lens
    equation itself (``verify`` computes that residual). Only ``m2`` is passed, so
    the expressions presumably assume a normalised total mass, m1 = 1 - m2 --
    NOTE(review): confirm.

    Parameters:
    zeta_l: source position(s). Must have shape (n, 1) so that the six coefficient
        columns broadcast correctly and can be concatenated along axis 1.
    s (float): lens separation.
    m2 (float): mass of the object at the origin.

    Returns:
    jnp.ndarray: array of shape (n, 6) with the coefficients ordered c5, c4, ..., c0,
    i.e. highest degree first (the convention of numpy/jax ``roots``).
    """
    zeta_conj = jnp.conj(zeta_l)
    c0 = s**2 * zeta_l * m2**2
    c1 = -s * m2 * (2 * zeta_l + s * (-1 + s * zeta_l - 2 * zeta_l * zeta_conj + m2))
    c2 = (
        zeta_l
        - s**3 * zeta_l * zeta_conj
        + s * (-1 + m2 - 2 * zeta_conj * zeta_l * (1 + m2))
        + s**2 * (zeta_conj - 2 * zeta_conj * m2 + zeta_l * (1 + zeta_conj**2 + m2))
    )
    c3 = (
        s**3 * zeta_conj
        + 2 * zeta_l * zeta_conj
        + s**2 * (-1 + 2 * zeta_conj * zeta_l - zeta_conj**2 + m2)
        - s * (zeta_l + 2 * zeta_l * zeta_conj**2 - 2 * zeta_conj * m2)
    )
    c4 = zeta_conj * (-1 + 2 * s * zeta_conj + zeta_conj * zeta_l) - s * (
        -1 + 2 * s * zeta_conj + zeta_conj * zeta_l + m2
    )
    c5 = (s - zeta_conj) * zeta_conj
    # Highest-degree coefficient first.
    coff = jnp.concatenate((c5, c4, c3, c2, c1, c0), axis=1)
    return coff
126
+
127
+
128
def get_zeta_l(rho, trajectory_centroid_l, theta):
    """
    Sample source positions on the contour (a circle of radius ``rho``).

    Returns trajectory_centroid_l + rho * exp(i * theta) for each angle ``theta``.
    """
    offset = rho * jnp.exp(1j * theta)
    return trajectory_centroid_l + offset
131
+
132
+
133
def verify(zeta_l, z_l, s, m1, m2):
    """
    Residual of the lens equation for candidate roots ``z_l``.

    Returns |z - m1 / (conj(z) - s) - m2 / conj(z) - zeta_l|, which is (close to)
    zero when ``z_l`` is a genuine image of the source position ``zeta_l``.
    """
    z_bar = jnp.conj(z_l)
    mapped = z_l - m1 / (z_bar - s) - m2 / z_bar
    return jnp.abs(mapped - zeta_l)
135
+
136
+
137
def get_parity(z, s, m1, m2):
    """
    Parity of each image: the sign of the Jacobian determinant 1 - |dzeta/dconj(z)|**2.

    Returns +1 for positive-parity images, -1 for negative-parity images and 0 exactly
    on the critical curve.
    """
    z_bar = jnp.conj(z)
    shear = m1 / (z_bar - s) ** 2 + m2 / z_bar**2
    jacobian = 1 - jnp.abs(shear) ** 2
    return jnp.sign(jacobian)
140
+
141
+
142
def get_parity_error(z, s, m1, m2):
    """
    Magnitude of the Jacobian determinant |1 - |dzeta/dconj(z)|**2| at each image.

    Small values indicate images close to the critical curve, where the parity
    returned by ``get_parity`` is least reliable.
    """
    z_bar = jnp.conj(z)
    shear = m1 / (z_bar - s) ** 2 + m2 / z_bar**2
    jacobian = 1 - jnp.abs(shear) ** 2
    return jnp.abs(jacobian)
145
+
146
+
147
def dot_product(a, b):
    """Euclidean dot product of complex numbers viewed as 2-D vectors: Re(a)Re(b) + Im(a)Im(b)."""
    real_part = jnp.real(a) * jnp.real(b)
    imag_part = jnp.imag(a) * jnp.imag(b)
    return real_part + imag_part
149
+
150
+
151
def basic_partial(z, theta, rho, q, s, caustic_crossing):
    """
    Partial derivatives of the lens mapping used in the contour-integration error estimate.

    The lens equation in this frame is
        zeta = z - m1 / (conj(z) - s) - m2 / conj(z),  m1 = 1/(1+q), m2 = q/(1+q),
    and the source boundary point moves with dzeta/dtheta = i * rho * exp(i * theta),
    i.e. a circle of radius ``rho`` parametrised by ``theta``.

    Parameters:
    z: image position(s) of the source boundary point at angle ``theta``.
    theta: polar angle on the source boundary.
    rho: source radius.
    q (float): mass ratio.
    s (float): lens separation.
    caustic_crossing: boolean (may be a traced value). The third output is only
        computed when it is True.

    Returns:
    tuple: ``(deXProde2X, de_z, de_deXPro_de2X)``.
        ``de_z`` is dz/dtheta.
        ``deXProde2X`` is (rho**2 + Im(de_z**2 * dzeta/dtheta * par2ConZetaZ)) / detJ,
        intended to equal the wedge product x' ^ x'' = Im(conj(de_z) * d2z/dtheta2)
        of the image track.
        ``de_deXPro_de2X`` is d(deXProde2X)/dtheta when ``caustic_crossing`` is True,
        and zeros otherwise.
    """
    z_c = jnp.conj(z)
    # d(zeta)/d(conj z)
    parZetaConZ = 1 / (1 + q) * (1 / (z_c - s) ** 2 + q / z_c**2)
    # d^2(conj zeta)/dz^2
    par2ConZetaZ = -2 / (1 + q) * (1 / (z - s) ** 3 + q / (z) ** 3)
    de_zeta = 1j * rho * jnp.exp(1j * theta)
    detJ = 1 - jnp.abs(parZetaConZ) ** 2
    # dz/dtheta, from differentiating the lens equation and its complex conjugate.
    de_z = (de_zeta - parZetaConZ * jnp.conj(de_zeta)) / detJ
    deXProde2X = (rho**2 + jnp.imag(de_z**2 * de_zeta * par2ConZetaZ)) / detJ

    def get_de_deXPro_de2X(_operand):
        # Only evaluated when caustic_crossing is True (used in the e4 estimate).
        # TODO: check robustness when the source is very close to, but not
        # crossing, a caustic.
        de2_zeta = -rho * jnp.exp(1j * theta)
        de2_zetaConj = -rho * jnp.exp(-1j * theta)
        # d^3(conj zeta)/dz^3, i.e. d(par2ConZetaZ)/dz
        par3ConZetaZ = 6 / (1 + q) * (1 / (z - s) ** 4 + q / (z) ** 4)
        de2_z = (
            de2_zeta
            - jnp.conj(par2ConZetaZ) * jnp.conj(de_z) ** 2
            - parZetaConZ * (de2_zetaConj - par2ConZetaZ * de_z**2)
        ) / detJ
        # -d(detJ)/dtheta. Mathematically real: the sum of a complex number and
        # its conjugate.
        minus_de_detJ = (
            jnp.conj(par2ConZetaZ) * jnp.conj(de_z) * jnp.conj(parZetaConZ)
            + parZetaConZ * par2ConZetaZ * de_z
        )
        # deXProde2X = N / detJ with N = rho**2 + Im(A), A = de_z**2 * de_zeta * par2ConZetaZ.
        # The quotient rule gives (Im(A') * detJ + N * (-detJ')) / detJ**2. Both parts of
        # N * (-detJ') are needed: Im(A) * (-detJ') (inside the imag below) and
        # rho**2 * (-detJ').
        return (
            1
            / detJ**2
            * (
                jnp.imag(
                    detJ
                    * (
                        de2_zeta * par2ConZetaZ * de_z**2
                        + de_zeta * par3ConZetaZ * de_z**3
                        + de_zeta * par2ConZetaZ * 2 * de_z * de2_z
                    )
                    + minus_de_detJ * de_zeta * par2ConZetaZ * de_z**2
                )
                + rho**2 * jnp.real(minus_de_detJ)
            )
        )

    de_deXPro_de2X = jax.lax.cond(
        caustic_crossing,
        get_de_deXPro_de2X,
        lambda _operand: jnp.zeros_like(deXProde2X),
        None,
    )
    return deXProde2X, de_z, de_deXPro_de2X
205
+
206
+
207
@jax.custom_jvp
def refine_gradient(zeta_l, q, s, z):
    """
    Return the root ``z`` unchanged, but with an analytic derivative.

    The forward value is simply ``z``. The JVP rule registered below replaces
    whatever derivative would otherwise flow through ``z`` (for instance from a
    root solver) with the implicit derivative of the lens equation with respect
    to ``zeta_l``, ``q`` and ``s``.
    """
    return z


@refine_gradient.defjvp
def refine_gradient_jvp(primals, tangents):
    """
    JVP rule for ``refine_gradient`` based on implicit differentiation of the lens equation.

    With zeta = z - m1 / (conj(z) - s) - m2 / conj(z), m1 = 1/(1+q), m2 = q/(1+q),
    differentiating the equation and its complex conjugate gives

        dz = (dzeta - F * conj(dzeta) - sum_p dp * (P_p - conj(P_p) * F)) / (1 - |F|**2)

    where F = dzeta/dconj(z) and P_p = dzeta/dp for p in (q, s). The original authors
    refer to V. Bozza 2010, eq. 20, for this relation. The tangent of ``z`` itself
    is ignored, which keeps the computational graph small and speeds up gradients.
    """
    _zeta_l, q, s, z = primals
    d_zeta, d_q, d_s, _d_z = tangents

    z_bar = jnp.conj(z)
    dzeta_dzbar = 1 / (1 + q) * (1 / (z_bar - s) ** 2 + q / z_bar**2)
    det_jac = 1 - jnp.abs(dzeta_dzbar) ** 2

    dzeta_dq = 1 / (1 + q) ** 2 * (1 / (z_bar - s) - 1 / z_bar)
    dzeta_ds = -1 / (1 + q) / (z_bar - s) ** 2

    def _param_term(tangent, partial):
        # Contribution of a real parameter's tangent to the numerator.
        return tangent * (partial - jnp.conj(partial) * dzeta_dzbar)

    numerator = (
        d_zeta
        - dzeta_dzbar * jnp.conj(d_zeta)
        - _param_term(d_q, dzeta_dq)
        - _param_term(d_s, dzeta_ds)
    )
    return z, numerator / det_jac