py-growth-RHEED 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- PY_GROWTH/__init__.py +3 -0
- PY_GROWTH/jax/__init__.py +8 -0
- PY_GROWTH/jax/core.py +105 -0
- PY_GROWTH/jax/models.py +169 -0
- PY_GROWTH/jax/ode_solver.py +139 -0
- PY_GROWTH/numba/__init__.py +8 -0
- PY_GROWTH/numba/core.py +49 -0
- PY_GROWTH/numba/models.py +142 -0
- PY_GROWTH/numba/ode_solver.py +100 -0
- PY_GROWTH/numpy/__init__.py +18 -0
- PY_GROWTH/numpy/core.py +75 -0
- PY_GROWTH/numpy/io.py +95 -0
- PY_GROWTH/numpy/models.py +157 -0
- PY_GROWTH/numpy/ode_solver.py +155 -0
- PY_GROWTH/plotting/__init__.py +218 -0
- py_growth_rheed-1.0.2.dist-info/METADATA +168 -0
- py_growth_rheed-1.0.2.dist-info/RECORD +20 -0
- py_growth_rheed-1.0.2.dist-info/WHEEL +5 -0
- py_growth_rheed-1.0.2.dist-info/licenses/LICENSE +11 -0
- py_growth_rheed-1.0.2.dist-info/top_level.txt +1 -0
PY_GROWTH/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
from .core import run_simulation
|
|
2
|
+
from .ode_solver import solve_ode_adaptive
|
|
3
|
+
from .models import derivs0, derivs1, derivs2, compute_intensity, compute_rms_roughness
|
|
4
|
+
|
|
5
|
+
# Re-export IO utilities from numpy module to maintain strict subpackage structural symmetry
|
|
6
|
+
from PY_GROWTH.numpy.io import load_input_data, save_history_to_json
|
|
7
|
+
|
|
8
|
+
|
PY_GROWTH/jax/core.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
"""
|
|
2
|
+
core.py
|
|
3
|
+
|
|
4
|
+
Defines the centralized pipeline that binds the physics models with the ODE
|
|
5
|
+
solver to produce and return calculation history dictionaries.
|
|
6
|
+
Accelerated end-to-end using JAX.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Dict, Any, Literal
|
|
10
|
+
import time
|
|
11
|
+
import jax
|
|
12
|
+
import jax.numpy as jnp
|
|
13
|
+
|
|
14
|
+
# Enable 64-bit precision to match NumPy/C++ for identical math tracing
|
|
15
|
+
jax.config.update("jax_enable_x64", True)
|
|
16
|
+
|
|
17
|
+
from functools import partial
|
|
18
|
+
|
|
19
|
+
from .ode_solver import solve_ode_adaptive
|
|
20
|
+
from .models import derivs0, derivs1, derivs2, compute_intensity, compute_rms_roughness
|
|
21
|
+
|
|
22
|
+
@partial(jax.jit, static_argnames=['model_type', 'num_layers', 'num_intervals'])
def _run_sim_jitted(
    model_type: int,
    num_layers: int,
    t_max: float,
    num_intervals: int,
    an_kn: jnp.ndarray,
    growth_rates: jnp.ndarray
):
    """JIT-compiled simulation trace: dispatch the model, integrate, post-process.

    ``model_type``, ``num_layers`` and ``num_intervals`` are static arguments,
    so each distinct combination compiles once and the dispatch below is
    resolved entirely at trace time.
    """
    # Static dispatch table; any value other than 1 or 2 falls back to the
    # diffusive model 0 (unreachable in practice: model_type is static and
    # validated by callers).
    rhs_by_model = {1: derivs1, 2: derivs2}
    rhs = rhs_by_model.get(model_type, derivs0)

    def derivs_closure(x, y):
        # Bind the per-layer coefficient arrays into the solver's (t, y) RHS.
        return rhs(x, y, an_kn, growth_rates)

    # All layers start empty.
    initial_coverage = jnp.zeros(num_layers, dtype=jnp.float64)

    solution = solve_ode_adaptive(
        derivs=derivs_closure,
        y0=initial_coverage,
        t_span=(0.0, t_max),
        num_intervals=num_intervals
    )

    times = solution["t"]
    coverages = solution["y"]
    n_valid = solution["valid_count"]

    return (
        times,
        coverages,
        compute_intensity(coverages),
        compute_rms_roughness(times, coverages, growth_rates),
        n_valid,
    )
|
|
61
|
+
|
|
62
|
+
def run_simulation(
    model_type: Literal[0, 1, 2],
    num_layers: int,
    t_max: float,
    num_intervals: int,
    an_kn: jnp.ndarray,
    growth_rates: jnp.ndarray
) -> Dict[str, Any]:
    """Run one RHEED growth simulation with the JAX backend.

    Orchestrates the JIT-compiled pipeline and converts the fixed-size
    device buffers into NumPy arrays trimmed to the valid sample count.

    Returns:
        Dict with "time", "coverage", "intensity", "rms_roughness",
        "model_type" and "execution_time_seconds".
    """
    import numpy as np  # host-side conversion of device arrays

    tic = time.perf_counter()

    t_hist, y_hist, intens, rough, valid_count = _run_sim_jitted(
        model_type=int(model_type),
        num_layers=int(num_layers),
        t_max=float(t_max),
        num_intervals=int(num_intervals),
        an_kn=jnp.array(an_kn, dtype=jnp.float64),
        growth_rates=jnp.array(growth_rates, dtype=jnp.float64)
    )

    # Synchronize with the device so the timing reflects real execution,
    # not just JAX's asynchronous dispatch.
    t_hist.block_until_ready()

    elapsed = time.perf_counter() - tic

    # Only the first `valid_count` rows of the fixed-size buffers are real.
    n = int(valid_count)

    return {
        "time": np.array(t_hist[:n]),
        "coverage": np.array(y_hist[:n, :]),
        "intensity": np.array(intens[:n]),
        "rms_roughness": np.array(rough[:n]),
        "model_type": model_type,
        "execution_time_seconds": elapsed
    }
|
PY_GROWTH/jax/models.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
"""
|
|
2
|
+
models.py
|
|
3
|
+
|
|
4
|
+
Vectorized mathematical models translating C++ loop calculations representing
|
|
5
|
+
diffusive and distributed RHEED epitaxial growth dynamics.
|
|
6
|
+
Accelerated using JAX.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import jax
|
|
10
|
+
import jax.numpy as jnp
|
|
11
|
+
from functools import partial
|
|
12
|
+
|
|
13
|
+
@jax.jit
|
|
14
|
+
def _th_padded(theta: jnp.ndarray) -> jnp.ndarray:
|
|
15
|
+
"""
|
|
16
|
+
Pads the coverage array for generic layer interaction without out-of-bounds errors.
|
|
17
|
+
Upper boundary condition (n<=0) treats coverage as 1.0 (fully covered substrate).
|
|
18
|
+
Lower boundary condition (n>nMax) treats coverage as 0.0 (no coverage yet).
|
|
19
|
+
"""
|
|
20
|
+
return jnp.concatenate((jnp.array([1.0, 1.0]), theta, jnp.array([0.0, 0.0])))
|
|
21
|
+
|
|
22
|
+
@jax.jit
|
|
23
|
+
def _dn1(theta_array: jnp.ndarray) -> jnp.ndarray:
|
|
24
|
+
"""Distribution function 1 for Distributed model 1."""
|
|
25
|
+
th_c = jnp.clip(theta_array, 0.0, 1.0)
|
|
26
|
+
return th_c * jnp.sqrt(1.0 - th_c)
|
|
27
|
+
|
|
28
|
+
@jax.jit
|
|
29
|
+
def _dn2(theta_array: jnp.ndarray) -> jnp.ndarray:
|
|
30
|
+
"""Distribution function 2 for Distributed model 2."""
|
|
31
|
+
th_c = jnp.clip(theta_array, 0.0, 1.0)
|
|
32
|
+
return jnp.where(th_c < 0.5, jnp.sqrt(th_c), jnp.sqrt(1.0 - th_c))
|
|
33
|
+
|
|
34
|
+
@jax.jit
def derivs0(t: float, theta: jnp.ndarray, C: jnp.ndarray, gR: jnp.ndarray) -> jnp.ndarray:
    """Diffusive growth (Model 0): d(theta)/dt for every layer at once.

    Vectorized translation of the per-layer C++ loop. *t* is unused because
    the system is autonomous, but it is kept for the solver's (t, y)
    signature.
    """
    padded = _th_padded(theta)

    # Shifted views of the padded coverage; offset k maps to layer n + (k - 2).
    here = padded[2:-2]
    below = padded[1:-3]
    below2 = padded[:-4]
    above = padded[3:-1]
    above2 = padded[4:]

    # Exposed fraction of the layer underneath (available landing area).
    exposed = below - here

    return (exposed * gR +
            C * (above - above2) * exposed -
            C * (here - above) * (below2 - below))
|
|
52
|
+
|
|
53
|
+
@jax.jit
def derivs1(t: float, theta: jnp.ndarray, C: jnp.ndarray, gR: jnp.ndarray) -> jnp.ndarray:
    """Distributed growth model 1 (Model 1): d(theta)/dt for every layer.

    Args:
        t: Current time (unused: the system is autonomous, but the solver
            passes it).
        theta: Per-layer coverage, shape (num_layers,).
        C: Per-layer coupling coefficients, shape (num_layers,).
        gR: Per-layer growth rates, shape (num_layers,).

    Returns:
        jnp.ndarray: Time derivative of the coverage, shape (num_layers,).
    """
    th = _th_padded(theta)

    th_n = th[2:-2]
    th_n_minus_1 = th[1:-3]
    th_n_plus_1 = th[3:-1]

    dn1_n = _dn1(th_n)
    dn1_n_minus_1 = _dn1(th_n_minus_1)
    dn1_n_plus_1 = _dn1(th_n_plus_1)

    C_n = C
    # C shifted one layer down; layer 0 has no layer below it, hence 0.0.
    C_n_minus_1 = jnp.pad(C, (1, 0), constant_values=0.0)[:-1]

    dTheta = th_n_minus_1 - th_n

    # NOTE: jnp.where evaluates BOTH branches, so a zero denominator in the
    # discarded branch would still compute 0/0 -> NaN, which poisons
    # jax.grad and trips jax_debug_nans.  Mask the denominator to a
    # harmless 1.0 wherever the term is zeroed out; primal values are
    # unchanged (when active, dn1 terms are non-negative and the
    # denominator is strictly positive).
    m1_active = dn1_n_minus_1 != 0.0
    den_m1 = jnp.where(m1_active, dn1_n_minus_1 + dn1_n, 1.0)
    term_m1 = jnp.where(m1_active,
                        (C_n_minus_1 * dTheta * dn1_n_minus_1) / den_m1,
                        0.0)

    p1_active = dn1_n > 0.0
    den_p1 = jnp.where(p1_active, dn1_n + dn1_n_plus_1, 1.0)
    term_p1 = jnp.where(p1_active,
                        (C_n * (th_n - th_n_plus_1) * dn1_n) / den_p1,
                        0.0)

    return (dTheta - term_m1 + term_p1) * gR
|
|
82
|
+
|
|
83
|
+
@jax.jit
def derivs2(t: float, theta: jnp.ndarray, C: jnp.ndarray, gR: jnp.ndarray) -> jnp.ndarray:
    """Distributed growth model 2 (Model 2): d(theta)/dt for every layer.

    Args:
        t: Current time (unused: the system is autonomous, but the solver
            passes it).
        theta: Per-layer coverage, shape (num_layers,).
        C: Per-layer coupling coefficients, shape (num_layers,).
        gR: Per-layer growth rates, shape (num_layers,).

    Returns:
        jnp.ndarray: Time derivative of the coverage, shape (num_layers,).
    """
    th = _th_padded(theta)

    th_n = th[2:-2]
    th_n_minus_1 = th[1:-3]
    th_n_plus_1 = th[3:-1]

    dn2_n = _dn2(th_n)
    dn2_n_minus_1 = _dn2(th_n_minus_1)
    dn2_n_plus_1 = _dn2(th_n_plus_1)

    C_n = C
    # C shifted one layer down; layer 0 has no layer below it, hence 0.0.
    C_n_minus_1 = jnp.pad(C, (1, 0), constant_values=0.0)[:-1]

    dTheta = th_n_minus_1 - th_n

    # NOTE: jnp.where evaluates BOTH branches, so a zero denominator in the
    # discarded branch would still compute 0/0 -> NaN, which poisons
    # jax.grad and trips jax_debug_nans.  Mask the denominator to a
    # harmless 1.0 wherever the term is zeroed out; primal values are
    # unchanged (when active, dn2 terms are non-negative and the
    # denominator is strictly positive).
    m1_active = dn2_n_minus_1 != 0.0
    den_m1 = jnp.where(m1_active, dn2_n_minus_1 + dn2_n, 1.0)
    term_m1 = jnp.where(m1_active,
                        (C_n_minus_1 * dTheta * dn2_n_minus_1) / den_m1,
                        0.0)

    p1_active = dn2_n > 0.0
    den_p1 = jnp.where(p1_active, dn2_n + dn2_n_plus_1, 1.0)
    term_p1 = jnp.where(p1_active,
                        (C_n * (th_n - th_n_plus_1) * dn2_n) / den_p1,
                        0.0)

    return (dTheta - term_m1 + term_p1) * gR
|
|
112
|
+
|
|
113
|
+
@jax.jit
def compute_intensity(coverage_history: jnp.ndarray) -> jnp.ndarray:
    """Compute the kinematical diffracted RHEED intensity per time step.

    Args:
        coverage_history: History matrix of shape (num_time_steps, numLayers).

    Returns:
        jnp.ndarray: Intensity per time step, shape (num_time_steps,).
    """
    num_layers = coverage_history.shape[1]

    # The amplitude starts from the exposed substrate fraction.
    amplitude = 1.0 - coverage_history[:, 0]

    # num_layers comes from a static shape, so this Python branch is
    # resolved once at trace time.
    if num_layers > 1:
        layer_idx = jnp.arange(1, num_layers)
        phase = jnp.cos(layer_idx * jnp.pi)  # alternating (+/-)1 phase factors
        # Exposed fraction of each buried layer: cov[n-1] - cov[n].
        step_density = coverage_history[:, :-1] - coverage_history[:, 1:]
        amplitude = amplitude + jnp.sum(step_density * phase, axis=1)

    return amplitude ** 2
|
|
140
|
+
|
|
141
|
+
@jax.jit
def compute_rms_roughness(growth_time: jnp.ndarray, coverage_history: jnp.ndarray, gR: jnp.ndarray) -> jnp.ndarray:
    """Compute the surface RMS roughness for each recorded time step.

    Args:
        growth_time: Elapsed times, shape (num_time_steps,).
        coverage_history: Shape (num_time_steps, numLayers).
        gR: Per-layer growth rates.

    Returns:
        jnp.ndarray: RMS roughness per time step, shape (num_time_steps,).
    """
    num_layers = coverage_history.shape[1]

    # Substrate-level contribution: squared mean height weighted by the
    # exposed substrate fraction.
    variance = (growth_time * gR[0]) ** 2 * (1.0 - coverage_history[:, 0])

    if num_layers > 1:  # static shape -> decided at trace time
        t_col = growth_time[:, jnp.newaxis]
        layer_idx = jnp.arange(1, num_layers)
        # Exposed fraction of each buried layer: cov[n-1] - cov[n].
        step_density = coverage_history[:, :-1] - coverage_history[:, 1:]
        deviation = (layer_idx - t_col * gR[:num_layers - 1]) ** 2
        variance = variance + jnp.sum(deviation * step_density, axis=1)

    return jnp.sqrt(variance)
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ode_solver.py
|
|
3
|
+
|
|
4
|
+
Provides a generalized, pure JAX implementation of the Cash-Karp
|
|
5
|
+
fifth-order Runge-Kutta adaptive ODE solver. Designed using
|
|
6
|
+
jax.lax.while_loop to allow end-to-end JIT compilation.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import jax
|
|
10
|
+
import jax.numpy as jnp
|
|
11
|
+
from typing import Callable, Tuple, Dict
|
|
12
|
+
|
|
13
|
+
def rkck(y: jnp.ndarray, dydx: jnp.ndarray, x: float, h: float,
         derivs: Callable[[float, jnp.ndarray], jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Take a single Cash-Karp embedded Runge-Kutta step of size *h*.

    Returns the fifth-order solution estimate together with the embedded
    error estimate (difference between the fifth- and fourth-order
    solutions).  Coefficients follow the RKCK tableau from Numerical
    Recipes.
    """
    # Runge-Kutta matrix (stage weights).
    b21 = 0.2
    b31, b32 = 3.0/40.0, 9.0/40.0
    b41, b42, b43 = 0.3, -0.9, 1.2
    b51, b52, b53, b54 = -11.0/54.0, 2.5, -70.0/27.0, 35.0/27.0
    b61, b62, b63 = 1631.0/55296.0, 175.0/512.0, 575.0/13824.0
    b64, b65 = 44275.0/110592.0, 253.0/4096.0

    # Fifth-order solution weights and (5th minus 4th order) error weights.
    c1, c3, c4, c6 = 37.0/378.0, 250.0/621.0, 125.0/594.0, 512.0/1771.0
    dc1 = c1 - 2825.0/27648.0
    dc3 = c3 - 18575.0/48384.0
    dc4 = c4 - 13525.0/55296.0
    dc5 = -277.00/14336.0
    dc6 = c6 - 0.25

    k2 = derivs(x + b21*h, y + b21*h*dydx)
    k3 = derivs(x + (b31+b32)*h, y + h*(b31*dydx + b32*k2))
    k4 = derivs(x + (b41+b42+b43)*h, y + h*(b41*dydx + b42*k2 + b43*k3))
    k5 = derivs(x + (b51+b52+b53+b54)*h,
                y + h*(b51*dydx + b52*k2 + b53*k3 + b54*k4))
    k6 = derivs(x + (b61+b62+b63+b64+b65)*h,
                y + h*(b61*dydx + b62*k2 + b63*k3 + b64*k4 + b65*k5))

    y_out = y + h*(c1*dydx + c3*k3 + c4*k4 + c6*k6)
    y_err = h*(dc1*dydx + dc3*k3 + dc4*k4 + dc5*k5 + dc6*k6)

    return y_out, y_err
|
|
37
|
+
|
|
38
|
+
def rkqs(y: jnp.ndarray, dydx: jnp.ndarray, x: float, htry: float, eps: float,
         yscal: jnp.ndarray, derivs: Callable[[float, jnp.ndarray], jnp.ndarray]):
    """Take one quality-controlled fifth-order RK step.

    Retries with a shrinking step size until the scaled truncation error is
    below 1, then proposes a (possibly larger) step for the next call.
    Returns (y_new, x_new, h_next, h_used).
    """
    SAFETY = 0.9
    PGROW = -0.2
    PSHRNK = -0.25
    ERRCON = 1.89e-4  # below this errmax, grow by the maximum factor of 5

    # Loop state: (h, y_candidate, y_error, scaled_errmax).
    def error_too_large(state):
        return state[3] > 1.0

    def try_step(state):
        h_cur = state[0]
        y_cand, y_err = rkck(y, dydx, x, h_cur, derivs)
        errmax = jnp.max(jnp.abs(y_err / yscal)) / eps

        # Shrink, but never by more than a factor of 10 (sign-aware for
        # backwards integration).
        h_shrunk = SAFETY * h_cur * (errmax ** PSHRNK)
        h_capped = jnp.where(h_cur >= 0.0,
                             jnp.maximum(h_shrunk, 0.1 * h_cur),
                             jnp.minimum(h_shrunk, 0.1 * h_cur))

        # Keep h unchanged when the step is accepted (the loop exits then).
        h_next = jnp.where(errmax > 1.0, h_capped, h_cur)
        return (h_next, y_cand, y_err, errmax)

    # errmax = 2.0 forces at least one rkck evaluation.
    start = (htry, y, jnp.zeros_like(y), jnp.array(2.0, dtype=y.dtype))

    h_used, y_new, _, errmax = jax.lax.while_loop(error_too_large, try_step, start)

    # Propose the next step: grow by errmax**PGROW, capped at a factor of 5.
    h_next = jnp.where(errmax > ERRCON,
                       SAFETY * h_used * (errmax ** PGROW),
                       5.0 * h_used)

    return y_new, x + h_used, h_next, h_used
|
|
70
|
+
|
|
71
|
+
def solve_ode_adaptive(derivs: Callable, y0: jnp.ndarray, t_span: Tuple[float, float],
                       num_intervals: int, eps: float = 1.0e-7, h1: float = 1.0e-3) -> Dict[str, jnp.ndarray]:
    """Main driver using jax.lax.while_loop.

    Integrates dy/dt = derivs(t, y) from t_span[0] to t_span[1] with the
    adaptive Cash-Karp stepper (rkqs), sampling the state roughly every
    t_max / num_intervals time units -- a JAX port of Numerical Recipes'
    ``odeint`` driver.

    Args:
        derivs: Right-hand side f(t, y) -> dy/dt.
        y0: Initial state vector.
        t_span: (t0, t_max) integration interval.
        num_intervals: Approximate number of output sampling intervals.
        eps: Relative accuracy target passed to rkqs.
        h1: Initial trial step size.

    Returns:
        Dict with "t" (sample times), "y" (sampled states) and
        "valid_count" (number of meaningful leading rows in the
        fixed-size history buffers).
    """
    t0, t_max = t_span
    TINY = 1.0e-30  # keeps yscal strictly positive for zero components

    # Minimum time separation between two consecutive saved samples.
    d_x_sav = t_max / num_intervals

    # Fixed-size output buffers: lax.while_loop forbids dynamic shapes, so
    # a few spare rows absorb the final sample and save-cadence rounding.
    history_t = jnp.zeros(num_intervals + 5, dtype=jnp.float64)
    history_y = jnp.zeros((num_intervals + 5, len(y0)), dtype=jnp.float64)

    def cond_fun(state):
        # State: y, x, h, x_sav, history_t, history_y, out_idx, done, steps
        return jnp.logical_not(state[7])

    def body_fun(state):
        y, x, h, x_sav, history_t, history_y, out_idx, done, steps = state

        dydx = derivs(x, y)
        # Error-scaling vector (Numerical Recipes' standard choice).
        yscal = jnp.abs(y) + jnp.abs(dydx * h) + TINY

        # Save the current point if it is far enough from the last saved
        # one and intermediate slots remain.
        save_cond = jnp.logical_and(jnp.abs(x - x_sav) > jnp.abs(d_x_sav), out_idx < num_intervals)

        # Branchless write: target out_idx when saving, otherwise rewrite
        # slot 0 and discard via the jnp.where below.
        idx_to_update = jnp.where(save_cond, out_idx, 0)

        new_hist_t = jnp.where(save_cond, history_t.at[idx_to_update].set(x), history_t)
        new_hist_y = jnp.where(save_cond, history_y.at[idx_to_update].set(y), history_y)

        new_x_sav = jnp.where(save_cond, x, x_sav)
        new_out_idx = out_idx + jnp.where(save_cond, 1, 0)

        # Clamp the trial step so it cannot run past t_max (sign-agnostic).
        overshoot = (x + h - t_max) * (x + h - t0) > 0.0
        h_adjusted = jnp.where(overshoot, t_max - x, h)

        y_next, x_next, hnext, h_used = rkqs(y, dydx, x, h_adjusted, eps, yscal, derivs)

        # Done when t_max is reached or after a hard cap of 300000 steps.
        finished = jnp.logical_or((x_next - t_max) * (t_max - t0) >= 0.0, steps >= 300000)

        # On the final step, also record the terminal state if a slot is left.
        save_end_cond = jnp.logical_and(finished, new_out_idx < len(new_hist_t))

        final_idx = jnp.where(save_end_cond, new_out_idx, 0)
        new_hist_t_final = jnp.where(save_end_cond, new_hist_t.at[final_idx].set(x_next), new_hist_t)
        new_hist_y_final = jnp.where(save_end_cond, new_hist_y.at[final_idx].set(y_next), new_hist_y)

        final_out_idx = new_out_idx + jnp.where(save_end_cond, 1, 0)

        return (y_next, x_next, hnext, new_x_sav, new_hist_t_final, new_hist_y_final, final_out_idx, finished, steps + 1)

    # x_sav starts two sampling intervals before t0 so the very first loop
    # iteration always records the initial sample.
    init_state = (
        y0,
        jnp.array(t0, dtype=jnp.float64),
        jnp.array(h1, dtype=jnp.float64),
        jnp.array(t0 - d_x_sav * 2.0, dtype=jnp.float64),
        history_t,
        history_y,
        jnp.array(0, dtype=jnp.int32),
        jnp.array(False),
        jnp.array(0, dtype=jnp.int32)
    )

    final_state = jax.lax.while_loop(cond_fun, body_fun, init_state)

    _, _, _, _, final_history_t, final_history_y, final_out_idx, _, _ = final_state

    return {
        "t": final_history_t,
        "y": final_history_y,
        "valid_count": final_out_idx
    }
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
from .core import run_simulation
|
|
2
|
+
from .ode_solver import solve_ode_adaptive
|
|
3
|
+
from .models import derivs0, derivs1, derivs2, compute_intensity, compute_rms_roughness
|
|
4
|
+
|
|
5
|
+
# Re-export IO utilities from numpy module to maintain strict subpackage structural symmetry
|
|
6
|
+
from PY_GROWTH.numpy.io import load_input_data, save_history_to_json
|
|
7
|
+
|
|
8
|
+
|
PY_GROWTH/numba/core.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
import time
|
|
2
|
+
import numpy as np
|
|
3
|
+
from numba import njit
|
|
4
|
+
from .ode_solver import solve_ode_adaptive
|
|
5
|
+
from .models import derivs0, derivs1, derivs2, compute_intensity, compute_rms_roughness
|
|
6
|
+
|
|
7
|
+
@njit(cache=True)
def _jitted_pipeline(model_type, num_layers, t_max, num_intervals, an_kn, growth_rates):
    """Run the RKCK solver plus post-processing inside one nopython region.

    The explicit if/elif dispatch (rather than passing a Python function
    object in) keeps every call target resolvable by Numba's type
    inference.
    """
    y0 = np.zeros(num_layers, dtype=np.float64)
    span = (0.0, t_max)
    params = (an_kn, growth_rates)

    if model_type == 0:
        hist_t, hist_y = solve_ode_adaptive(derivs0, y0, span, params, num_intervals)
    elif model_type == 1:
        hist_t, hist_y = solve_ode_adaptive(derivs1, y0, span, params, num_intervals)
    else:
        hist_t, hist_y = solve_ode_adaptive(derivs2, y0, span, params, num_intervals)

    intensity = compute_intensity(hist_y)
    rms = compute_rms_roughness(hist_t, hist_y, growth_rates)

    return hist_t, hist_y, intensity, rms
|
|
26
|
+
|
|
27
|
+
def run_simulation(model_type: int, num_layers: int, t_max: float, num_intervals: int,
                   an_kn: np.ndarray, growth_rates: np.ndarray) -> dict:
    """Run one RHEED growth simulation with the Numba backend.

    Validates the model selector, times the jitted pipeline, and packages
    the results into a plain dictionary.

    Raises:
        ValueError: If *model_type* is not one of 0, 1 or 2.
    """
    if model_type not in (0, 1, 2):
        raise ValueError(f"Unknown model_type integer: {model_type}")

    tic = time.perf_counter()
    t, y, intensity, rms = _jitted_pipeline(model_type, num_layers, t_max, num_intervals, an_kn, growth_rates)
    elapsed = time.perf_counter() - tic

    return {
        "time": t,
        "coverage": y,
        "intensity": intensity,
        "rms_roughness": rms,
        "model_type": model_type,
        "execution_time_seconds": elapsed
    }
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from numba import njit
|
|
3
|
+
import math
|
|
4
|
+
|
|
5
|
+
@njit(cache=True)
def th(n: int, theta: np.ndarray) -> float:
    """Coverage of layer *n* with virtual boundary padding.

    Layers below the surface (n < 0) are treated as fully covered (1.0);
    layers above the topmost simulated one (n > len(theta) - 1) carry no
    coverage yet (0.0).
    """
    top = len(theta) - 1
    if n < 0:
        return 1.0
    if n > top:
        return 0.0
    return theta[n]
|
|
19
|
+
|
|
20
|
+
@njit(cache=True)
def dn1(n: int, theta: np.ndarray) -> float:
    """Step-density distribution 1: th * sqrt(1 - th), th clamped to [0, 1]."""
    cov = th(n, theta)
    cov = min(max(cov, 0.0), 1.0)
    return cov * math.sqrt(1.0 - cov)
|
|
27
|
+
|
|
28
|
+
@njit(cache=True)
def dn2(n: int, theta: np.ndarray) -> float:
    """Step-density distribution 2: sqrt(th) below half coverage, sqrt(1 - th) above."""
    cov = min(max(th(n, theta), 0.0), 1.0)
    if cov < 0.5:
        return math.sqrt(cov)
    return math.sqrt(1.0 - cov)
|
|
38
|
+
|
|
39
|
+
@njit(cache=True)
def derivs0(t: float, theta: np.ndarray, C: np.ndarray, gR: np.ndarray) -> np.ndarray:
    """Diffusive growth (Model 0): per-layer coverage time derivative.

    *t* is unused (autonomous system) but kept for the solver's (t, y)
    call signature.
    """
    num_layers = len(theta)
    out = np.zeros(num_layers)
    for n in range(num_layers):
        # Exposed fraction of the layer underneath (available landing area).
        exposed = th(n - 1, theta) - th(n, theta)
        out[n] = (exposed * gR[n] +
                  C[n] * (th(n + 1, theta) - th(n + 2, theta)) * exposed -
                  C[n] * (th(n, theta) - th(n + 1, theta)) * (th(n - 2, theta) - th(n - 1, theta)))
    return out
|
|
50
|
+
|
|
51
|
+
@njit(cache=True)
def derivs1(t: float, theta: np.ndarray, C: np.ndarray, gR: np.ndarray) -> np.ndarray:
    """Distributed growth model 1: per-layer coverage time derivative."""
    num_layers = len(theta)
    out = np.zeros(num_layers)
    for n in range(num_layers):
        rate = th(n - 1, theta) - th(n, theta)

        d_below = dn1(n - 1, theta)
        d_here = dn1(n, theta)
        d_above = dn1(n + 1, theta)

        # Layer 0 has no coupling coefficient below it.
        c_below = 0.0 if n == 0 else C[n - 1]

        # Redistribution from the layer below (guard the zero denominator).
        if d_below != 0.0:
            rate -= c_below * rate * d_below / (d_below + d_here)

        # Redistribution toward the layer above.
        if d_here > 0.0:
            rate += C[n] * (th(n, theta) - th(n + 1, theta)) * d_here / (d_here + d_above)

        out[n] = rate * gR[n]
    return out
|
|
73
|
+
|
|
74
|
+
@njit(cache=True)
def derivs2(t: float, theta: np.ndarray, C: np.ndarray, gR: np.ndarray) -> np.ndarray:
    """Distributed growth model 2: per-layer coverage time derivative."""
    num_layers = len(theta)
    out = np.zeros(num_layers)
    for n in range(num_layers):
        rate = th(n - 1, theta) - th(n, theta)

        d_below = dn2(n - 1, theta)
        d_here = dn2(n, theta)
        d_above = dn2(n + 1, theta)

        # Layer 0 has no coupling coefficient below it.
        c_below = 0.0 if n == 0 else C[n - 1]

        # Redistribution from the layer below (guard the zero denominator).
        if d_below != 0.0:
            rate -= c_below * rate * d_below / (d_below + d_here)

        # Redistribution toward the layer above.
        if d_here > 0.0:
            rate += C[n] * (th(n, theta) - th(n + 1, theta)) * d_here / (d_here + d_above)

        out[n] = rate * gR[n]
    return out
|
|
96
|
+
|
|
97
|
+
@njit(cache=True)
def compute_intensity(coverage_history: np.ndarray) -> np.ndarray:
    """Compute the kinematical diffracted intensity per recorded time step.

    Args:
        coverage_history (np.ndarray): History matrix of shape
            (num_time_steps, numLayers).

    Returns:
        np.ndarray: Intensity per time step.
    """
    num_t, num_layers = coverage_history.shape
    out = np.zeros(num_t)

    for i in range(num_t):
        # Amplitude starts from the exposed substrate fraction.
        amp = 1.0 - coverage_history[i, 0]
        for n in range(1, num_layers):
            # Exposed fraction of layer n-1 with alternating phase cos(n*pi).
            amp += (coverage_history[i, n - 1] - coverage_history[i, n]) * math.cos(n * math.pi)
        out[i] = amp * amp
    return out
|
|
119
|
+
|
|
120
|
+
@njit(cache=True)
def compute_rms_roughness(growthTime: np.ndarray, coverage_history: np.ndarray, gR: np.ndarray) -> np.ndarray:
    """Compute the surface RMS roughness per recorded time step.

    Args:
        growthTime (np.ndarray): Elapsed time per recorded step.
        coverage_history (np.ndarray): Shape (num_time_steps, numLayers).
        gR (np.ndarray): Per-layer growth rates.

    Returns:
        np.ndarray: RMS roughness per time step.
    """
    num_t, num_layers = coverage_history.shape
    out = np.zeros(num_t)

    for i in range(num_t):
        # Substrate-level term: squared mean height times exposed substrate.
        variance = (growthTime[i] * gR[0]) ** 2 * (1.0 - coverage_history[i, 0])
        for n in range(1, num_layers):
            # Squared deviation of layer n from the mean height, weighted by
            # the exposed fraction of layer n-1.
            variance += (n - growthTime[i] * gR[n - 1]) ** 2 * (coverage_history[i, n - 1] - coverage_history[i, n])
        out[i] = math.sqrt(variance)
    return out
|