py-growth-RHEED 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,11 @@
1
+ Copyright (c) 2026, Andrzej Daniluk & Bartek Daniluk
2
+
3
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
4
+
5
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
6
+
7
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
8
+
9
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
10
+
11
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@@ -0,0 +1,153 @@
1
+ Metadata-Version: 2.4
2
+ Name: py-growth-RHEED
3
+ Version: 1.0.0
4
+ Summary: A Python package for simulations of RHEED intensity oscillations within the kinematical approximation
5
+ Author: Bartek Daniluk
6
+ License: BSD 3-Clause
7
+ Classifier: License :: OSI Approved :: BSD License
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: Topic :: Scientific/Engineering :: Physics
10
+ Classifier: Intended Audience :: Science/Research
11
+ Requires-Python: >=3.10
12
+ Description-Content-Type: text/markdown
13
+ License-File: LICENSE
14
+ Requires-Dist: numpy<2.5.0,>=2.0.0
15
+ Requires-Dist: numba>=0.60.0
16
+ Requires-Dist: jax>=0.4.30
17
+ Requires-Dist: jaxlib>=0.4.30
18
+ Requires-Dist: plotly>=5.24.0
19
+ Requires-Dist: jupyter
20
+ Dynamic: author
21
+ Dynamic: classifier
22
+ Dynamic: description
23
+ Dynamic: description-content-type
24
+ Dynamic: license
25
+ Dynamic: license-file
26
+ Dynamic: requires-dist
27
+ Dynamic: requires-python
28
+ Dynamic: summary
29
+
30
+ # PY_GROWTH Simulation Package
31
+
32
+ <div align="center">
33
+ <img src="https://img.shields.io/badge/Python-3.10+-blue.svg" alt="Python Version"/>
34
+ <img src="https://img.shields.io/badge/JAX-Accelerated-orange.svg" alt="JAX Accelerated"/>
35
+ <img src="https://img.shields.io/badge/Numba-JIT-green.svg" alt="Numba JIT"/>
36
+ <img src="https://img.shields.io/badge/Status-Production-success.svg" alt="Production Ready"/>
37
+ </div>
38
+
39
+ **PY_GROWTH** is a modern, high-performance Python package for simulating layer coverages during the growth of thin epitaxial films and the corresponding **RHEED (Reflection High-Energy Electron Diffraction)** intensities within the kinematical approximation.
40
+
41
+ This package translates, refactors, and massively accelerates legacy C++ simulation algorithms using state-of-the-art numerical engines (`NumPy`, `Numba LLVM`, and `JAX/XLA`), enabling rapid programmatic experimentation and data plotting.
42
+
43
+ ---
44
+
45
+ ## 1. Quick Start & Tutorial
46
+
47
+ Installing the package takes a single `pip` command, and running a simulation and plotting the result takes only a few lines of code.
48
+
49
+ ### Installation
50
+
51
+ The package can be installed from PyPI into any virtual environment or Jupyter environment:
52
+ ```bash
53
+ # Install the package globally via PyPI
54
+ pip install py-growth-RHEED
55
+ ```
56
+ This automatically resolves all physics dependencies like `jax`, `numba`, and visualization libraries like `plotly`.
57
+
58
+ ### Interactive Jupyter Workflow
59
+
60
+ 1. Prepare your input constraints inside the `exampleData/inputData.dat` file.
61
+ 2. Open Jupyter: `python -m notebook`
62
+ 3. Run the following code in a notebook cell:
63
+
64
+ ```python
65
+ import plotly.graph_objects as go
66
+ from PY_GROWTH.numpy import run_simulation
67
+ from PY_GROWTH.numpy.io import load_input_data
68
+ import os
69
+
70
+ # 1. Load the simulation parameters from the input file
71
+ file_path = os.path.join("exampleData", "inputData.dat")
72
+ model_type, num_layers, t_max, num_intervals, linear_rel_an_kn, linear_rel_grn, an_kn, growth_rates = load_input_data(file_path)
73
+
74
+ # 2. Execute the Adaptive Runge-Kutta solver
75
+ results = run_simulation(model_type, num_layers, t_max, num_intervals, an_kn, growth_rates)
76
+
77
+ # 3. Visualize RHEED Oscillations immediately
78
+ fig = go.Figure()
79
+ fig.add_trace(go.Scatter(x=results['time'], y=results['intensity'], mode='lines', name='Intensity'))
80
+ fig.update_layout(title="RHEED Intensity Oscillations", xaxis_title="Time (s)", yaxis_title="RHEED Intensity")
81
+ fig.show()
82
+
83
+ # 4. Optional: export the raw time series to CSV (pandas is not installed automatically: `pip install pandas`)
84
+ import pandas as pd
85
+ import numpy as np
86
+ df = pd.DataFrame(np.column_stack([results['time'], results['intensity'], results['coverage']]))
87
+ df.to_csv("RHEED_Oscillations_Export.csv", index=False)
88
+ ```
89
+
90
+ See the `examples/` directory for full script files using `JAX` and `Numba` compilers!
91
+
92
+ ---
93
+
94
+ ## 2. Theory & Background
95
+
96
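+ PY_GROWTH numerically solves the initial value problem for a system of coupled nonlinear ordinary differential equations describing the coverage $\theta_n(t)$ of each layer, starting from an uncovered surface ($\theta_n(0) = 0$). The user supplies the growth parameters ($A_n$ or $k_n$) and the growth rates, together with flags controlling how these parameters are specified across the layers (see the input schema below).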
97
+
98
+ The software computes three primary physical models:
99
+ * **Model 0 (Diffusive Growth):** Simulates standard continuous step-flow and localized layer progression.
100
+ * **Model 1 (Distributed Growth Type 1):** Introduces probabilistic dispersion mechanics across surface coverages.
101
+ * **Model 2 (Distributed Growth Type 2):** Alters distribution bounds to simulate chaotic island generation.
102
+
103
+ ### The Input Data Schema
104
+
105
+ The `inputData.dat` file follows the schema of the original C++ program. It expects numeric values defining the system in exactly this order:
106
+ 1. `modelType`: the growth model to simulate (0, 1, or 2; see the list above).
107
+ 2. `numLayers`: the number of atomic layers whose coverages are simulated.
108
+ 3. `tMax`: Upper limit of the growth time interval.
109
+ 4. `num_intervals`: the number of output sampling intervals into which $[0, t_{\max}]$ is divided.
110
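+ 5. Two Boolean flags (returned as `linear_rel_an_kn` and `linear_rel_grn` by `load_input_data`) that control how the $A_n$ (or $k_n$) values and the growth rates are specified across the layers.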
111
+ 6. The $A_n$ (or $k_n$) coefficients, one per layer.
112
+ 7. The growth rates $1/\tau_n$ (in ML s$^{-1}$), one per layer.
113
+
114
+ ---
115
+
116
+ ## 3. Academic Citations
117
+
118
+ If you use **PY_GROWTH** in published research, please consider citing the numerical method and the acceleration libraries it is built on:
119
+
120
+ * **The ODE Solver (Cash-Karp Runge-Kutta 5th-Order):**
121
+ Cash, J. R., & Karp, A. H. (1990). *A variable order Runge-Kutta method for initial value problems with rapidly varying right-hand sides.* ACM Transactions on Mathematical Software (TOMS), 16(3), 201-222.
122
+ * **Numba (LLVM-based JIT compiler):**
123
+ Lam, S. K., Pitrou, A., & Seibert, S. (2015). *Numba: A LLVM-based Python JIT compiler.* In Proceedings of the Second Workshop on the LLVM Compiler Infrastructure in HPC (pp. 1-6).
124
+ * **JAX (composable transformations compiled with XLA):**
125
+ Bradbury, J., Frostig, R., Hawkins, P., Johnson, M. J., Leary, C., Maclaurin, D., Necula, G., Paszke, A., VanderPlas, J., Wanderman-Milne, S., & Zhang, Q. (2018). *JAX: composable transformations of Python+NumPy programs.* http://github.com/google/jax
126
+ ---
127
+
128
+ ## 4. PyPI Publication (Developer Notes)
129
+
130
+ Because PY_GROWTH utilizes a standard setup.py build configuration, it is fully capable of being hosted on the global **Python Package Index (PyPI)**.
131
+
132
+ To publish an update to PyPI (so that users can run `pip install py-growth-RHEED`), authenticate with your PyPI account and run:
133
+ ```bash
134
+ # 1. Install build tools
135
+ pip install build twine
136
+
137
+ # 2. Compile the package into a wheel distribution
138
+ python -m build
139
+
140
+ # 3. Securely upload to PyPI servers
141
+ twine upload dist/*
142
+ ```
143
+
144
+ ---
145
+
146
+ ## 5. License
147
+ This project is licensed under the **BSD 3-Clause License** - see the [LICENSE](LICENSE) file for details.
148
+ Free for academic, research, and commercial use.
@@ -0,0 +1,3 @@
1
+ """
2
+ PY_GROWTH - RHEED Intensity Oscillation Simulator.
3
+ """
@@ -0,0 +1,8 @@
1
+ from .core import run_simulation
2
+ from .ode_solver import solve_ode_adaptive
3
+ from .models import derivs0, derivs1, derivs2, compute_intensity, compute_rms_roughness
4
+
5
+ # Re-export IO utilities from numpy module to maintain strict subpackage structural symmetry
6
+ from PY_GROWTH.numpy.io import load_input_data, save_history_to_json
7
+
8
+
@@ -0,0 +1,105 @@
1
+ """
2
+ core.py
3
+
4
+ Defines the centralized pipeline that binds the physics models with the ODE
5
+ solver to produce and return calculation history dictionaries.
6
+ Accelerated end-to-end using JAX.
7
+ """
8
+
9
+ from typing import Dict, Any, Literal
10
+ import time
11
+ import jax
12
+ import jax.numpy as jnp
13
+
14
# Enable 64-bit precision to match NumPy/C++ for identical math tracing.
# NOTE: jax.config is process-wide, so importing this module also enables x64
# for any other JAX code running in the same interpreter.
jax.config.update("jax_enable_x64", True)
16
+
17
+ from functools import partial
18
+
19
+ from .ode_solver import solve_ode_adaptive
20
+ from .models import derivs0, derivs1, derivs2, compute_intensity, compute_rms_roughness
21
+
22
@partial(jax.jit, static_argnames=['model_type', 'num_layers', 'num_intervals'])
def _run_sim_jitted(
    model_type: int,
    num_layers: int,
    t_max: float,
    num_intervals: int,
    an_kn: jnp.ndarray,
    growth_rates: jnp.ndarray
):
    """
    Core JIT-compiled simulation trace.

    Args:
        model_type: Static model selector; 0, 1 and 2 pick derivs0/1/2.
        num_layers: Static number of layer coverages (length of y0).
        t_max: End of the integration interval [0, t_max].
        num_intervals: Static number of output sampling intervals.
        an_kn: Per-layer model coefficients forwarded to the derivs function.
        growth_rates: Per-layer growth rates forwarded to the derivs function
            and to the roughness computation.

    Returns:
        Tuple ``(t_history, y_history, intensity, roughness, valid_count)``.
        The history arrays are fixed-size buffers; only the first
        ``valid_count`` rows hold simulation data.

    Raises:
        ValueError: If ``model_type`` is not 0, 1 or 2. Because
            ``model_type`` is static, this is raised at trace time.
    """
    derivs_by_model = {0: derivs0, 1: derivs1, 2: derivs2}
    if model_type not in derivs_by_model:
        # Fail loudly instead of silently simulating a different model.
        raise ValueError(f"Unknown model_type integer: {model_type}")
    model_derivs = derivs_by_model[model_type]

    def derivs_closure(x, y):
        # Bind the model parameters so the solver sees derivs(t, y).
        return model_derivs(x, y, an_kn, growth_rates)

    y0 = jnp.zeros(num_layers, dtype=jnp.float64)

    results = solve_ode_adaptive(
        derivs=derivs_closure,
        y0=y0,
        t_span=(0.0, t_max),
        num_intervals=num_intervals
    )

    t_history = results["t"]
    y_history = results["y"]
    valid_count = results["valid_count"]

    # Post-processing runs over the whole fixed-size buffer; rows at or past
    # valid_count are padding and are sliced off by run_simulation.
    intensity = compute_intensity(y_history)
    roughness = compute_rms_roughness(t_history, y_history, growth_rates)

    return t_history, y_history, intensity, roughness, valid_count
61
+
62
def run_simulation(
    model_type: Literal[0, 1, 2],
    num_layers: int,
    t_max: float,
    num_intervals: int,
    an_kn: jnp.ndarray,
    growth_rates: jnp.ndarray
) -> Dict[str, Any]:
    """
    Run a RHEED growth simulation with the JAX backend.

    Integrates the selected model from zero initial coverage over
    ``[0, t_max]`` and returns NumPy arrays trimmed to the number of stored
    samples.

    Args:
        model_type: Growth model selector (0, 1 or 2).
        num_layers: Number of layer coverages to simulate.
        t_max: End of the integration interval.
        num_intervals: Number of output sampling intervals.
        an_kn: Per-layer model coefficients (array-like, converted to float64).
        growth_rates: Per-layer growth rates (array-like, converted to float64).

    Returns:
        Dict with keys ``"time"``, ``"coverage"``, ``"intensity"``,
        ``"rms_roughness"`` (NumPy arrays), ``"model_type"`` (as passed in) and
        ``"execution_time_seconds"`` (wall time of the call, which includes
        JIT compilation whenever a new static-argument combination is traced).

    Raises:
        ValueError: If ``model_type`` is not 0, 1 or 2.
    """
    import numpy as np  # For output conversion

    # Same check as the Numba backend; done before any tracing so an invalid
    # model fails fast instead of being compiled.
    if model_type not in (0, 1, 2):
        raise ValueError(f"Unknown model_type integer: {model_type}")

    start_time = time.perf_counter()

    an_kn_jnp = jnp.array(an_kn, dtype=jnp.float64)
    gR_jnp = jnp.array(growth_rates, dtype=jnp.float64)

    t_hist_jnp, y_hist_jnp, int_jnp, rough_jnp, valid_count = _run_sim_jitted(
        model_type=int(model_type),
        num_layers=int(num_layers),
        t_max=float(t_max),
        num_intervals=int(num_intervals),
        an_kn=an_kn_jnp,
        growth_rates=gR_jnp
    )

    # Wait for the asynchronous device computation so the timing covers the
    # actual execution, not just the dispatch.
    t_hist_jnp.block_until_ready()

    calc_time = time.perf_counter() - start_time

    v = int(valid_count)

    return {
        "time": np.array(t_hist_jnp[:v]),
        "coverage": np.array(y_hist_jnp[:v, :]),
        "intensity": np.array(int_jnp[:v]),
        "rms_roughness": np.array(rough_jnp[:v]),
        "model_type": model_type,
        "execution_time_seconds": calc_time
    }
@@ -0,0 +1,169 @@
1
+ """
2
+ models.py
3
+
4
+ Vectorized mathematical models translating C++ loop calculations representing
5
+ diffusive and distributed RHEED epitaxial growth dynamics.
6
+ Accelerated using JAX.
7
+ """
8
+
9
+ import jax
10
+ import jax.numpy as jnp
11
+ from functools import partial
12
+
13
+ @jax.jit
14
+ def _th_padded(theta: jnp.ndarray) -> jnp.ndarray:
15
+ """
16
+ Pads the coverage array for generic layer interaction without out-of-bounds errors.
17
+ Upper boundary condition (n<=0) treats coverage as 1.0 (fully covered substrate).
18
+ Lower boundary condition (n>nMax) treats coverage as 0.0 (no coverage yet).
19
+ """
20
+ return jnp.concatenate((jnp.array([1.0, 1.0]), theta, jnp.array([0.0, 0.0])))
21
+
22
+ @jax.jit
23
+ def _dn1(theta_array: jnp.ndarray) -> jnp.ndarray:
24
+ """Distribution function 1 for Distributed model 1."""
25
+ th_c = jnp.clip(theta_array, 0.0, 1.0)
26
+ return th_c * jnp.sqrt(1.0 - th_c)
27
+
28
+ @jax.jit
29
+ def _dn2(theta_array: jnp.ndarray) -> jnp.ndarray:
30
+ """Distribution function 2 for Distributed model 2."""
31
+ th_c = jnp.clip(theta_array, 0.0, 1.0)
32
+ return jnp.where(th_c < 0.5, jnp.sqrt(th_c), jnp.sqrt(1.0 - th_c))
33
+
34
@jax.jit
def derivs0(t: float, theta: jnp.ndarray, C: jnp.ndarray, gR: jnp.ndarray) -> jnp.ndarray:
    """
    Right-hand side of the diffusive growth model (model 0).

    Neighbour coverages come from the padded array: θ = 1 for positions
    before the first layer and θ = 0 past the last one. For every layer n:

        dθ_n/dt = gR_n (θ_{n-1} − θ_n)
                  + C_n (θ_{n+1} − θ_{n+2}) (θ_{n-1} − θ_n)
                  − C_n (θ_n − θ_{n+1}) (θ_{n-2} − θ_{n-1})

    ``t`` is not used by the formula. It is accepted so that the function
    fits the solver's ``derivs(t, y)`` convention once ``C`` and ``gR`` are
    bound.
    """
    padded = _th_padded(theta)
    layers = theta.shape[0]

    def neighbour(offset):
        # θ_{n+offset} for n = 0 .. layers-1; the padding adds 2 cells per side.
        start = 2 + offset
        return padded[start:start + layers]

    lower2 = neighbour(-2)
    lower1 = neighbour(-1)
    current = neighbour(0)
    upper1 = neighbour(1)
    upper2 = neighbour(2)

    d_theta = lower1 - current

    return (d_theta * gR +
            C * (upper1 - upper2) * d_theta -
            C * (current - upper1) * (lower2 - lower1))
52
+
53
def _distributed_derivs(theta: jnp.ndarray, C: jnp.ndarray, gR: jnp.ndarray, dist_fn) -> jnp.ndarray:
    """
    Shared right-hand side of the distributed growth models 1 and 2.

    With d_k = dist_fn(θ_k), neighbour coverages taken from the padded array,
    and C_{-1} = 0:

        dθ_n/dt = gR_n [ (θ_{n-1} − θ_n) − T⁻_n + T⁺_n ]
        T⁻_n = C_{n-1} (θ_{n-1} − θ_n) d_{n-1} / (d_{n-1} + d_n)   if d_{n-1} != 0, else 0
        T⁺_n = C_n (θ_n − θ_{n+1}) d_n / (d_n + d_{n+1})            if d_n > 0,     else 0
    """
    th = _th_padded(theta)

    th_n = th[2:-2]
    th_n_minus_1 = th[1:-3]
    th_n_plus_1 = th[3:-1]

    d_n = dist_fn(th_n)
    d_n_minus_1 = dist_fn(th_n_minus_1)
    d_n_plus_1 = dist_fn(th_n_plus_1)

    C_n = C
    # C shifted down one layer (C_{n-1}), with 0 for the first layer.
    C_n_minus_1 = jnp.pad(C, (1, 0), constant_values=0.0)[:-1]

    dTheta = th_n_minus_1 - th_n

    # jnp.where evaluates both branches, so the masked-out division would be
    # 0/0 wherever both d values vanish. That NaN is discarded in the forward
    # pass but poisons gradients under jax.grad. Using 1.0 as the denominator
    # in masked lanes leaves every forward value unchanged.
    mask_m1 = d_n_minus_1 != 0.0
    den_m1 = jnp.where(mask_m1, d_n_minus_1 + d_n, 1.0)
    term_m1 = jnp.where(mask_m1,
                        (C_n_minus_1 * dTheta * d_n_minus_1) / den_m1,
                        0.0)

    mask_p1 = d_n > 0.0
    den_p1 = jnp.where(mask_p1, d_n + d_n_plus_1, 1.0)
    term_p1 = jnp.where(mask_p1,
                        (C_n * (th_n - th_n_plus_1) * d_n) / den_p1,
                        0.0)

    return (dTheta - term_m1 + term_p1) * gR


@jax.jit
def derivs1(t: float, theta: jnp.ndarray, C: jnp.ndarray, gR: jnp.ndarray) -> jnp.ndarray:
    """
    Distributed growth model 1 (distribution function ``_dn1`` = θ·sqrt(1 − θ)).

    ``t`` is unused; it keeps the solver's ``derivs(t, y)`` call convention.
    """
    return _distributed_derivs(theta, C, gR, _dn1)


@jax.jit
def derivs2(t: float, theta: jnp.ndarray, C: jnp.ndarray, gR: jnp.ndarray) -> jnp.ndarray:
    """
    Distributed growth model 2 (distribution function ``_dn2``).

    ``t`` is unused; it keeps the solver's ``derivs(t, y)`` call convention.
    """
    return _distributed_derivs(theta, C, gR, _dn2)
112
+
113
@jax.jit
def compute_intensity(coverage_history: jnp.ndarray) -> jnp.ndarray:
    """
    Kinematical diffracted intensity for every stored time step.

    Let θ_k be column k of ``coverage_history`` and L its number of columns.
    The amplitude is

        A = (1 − θ_0) + Σ_{n=1}^{L-1} (θ_{n-1} − θ_n) cos(nπ)

    and the intensity is A². The sum stops at n = L-1.
    NOTE(review): if θ_{n-1} − θ_n is the exposed fraction of level n, the
    topmost level (θ_{L-1}) contributes no term. This presumably mirrors the
    original C++ loop bounds; confirm.

    Args:
        coverage_history: Array of shape (num_time_steps, L).

    Returns:
        Array of shape (num_time_steps,) holding A² for each row.
    """
    layer_count = coverage_history.shape[1]
    amplitude = 1.0 - coverage_history[:, 0]

    # Shapes are static under jit, so this branch is resolved at trace time.
    if layer_count <= 1:
        return amplitude ** 2

    phase = jnp.cos(jnp.arange(1, layer_count) * jnp.pi)
    layer_diff = coverage_history[:, :-1] - coverage_history[:, 1:]
    amplitude += jnp.sum(layer_diff * phase, axis=1)
    return amplitude ** 2
140
+
141
@jax.jit
def compute_rms_roughness(growth_time: jnp.ndarray, coverage_history: jnp.ndarray, gR: jnp.ndarray) -> jnp.ndarray:
    """
    RMS roughness for every stored time step.

    Let θ_k be column k of ``coverage_history`` and L its number of columns.
    Then

        σ² = (t·gR_0)² (1 − θ_0) + Σ_{n=1}^{L-1} (n − t·gR_{n-1})² (θ_{n-1} − θ_n)

    and the function returns sqrt(σ²).

    Args:
        growth_time: Times of the stored samples, shape (num_time_steps,).
        coverage_history: Coverages, shape (num_time_steps, L).
        gR: Per-layer growth rates; gR[0] and gR[:L-1] are used.

    Returns:
        Array of shape (num_time_steps,) with the RMS roughness.
    """
    layer_count = coverage_history.shape[1]
    variance = (growth_time * gR[0]) ** 2 * (1.0 - coverage_history[:, 0])

    # Shapes are static under jit, so this branch is resolved at trace time.
    if layer_count <= 1:
        return jnp.sqrt(variance)

    times = growth_time[:, jnp.newaxis]
    levels = jnp.arange(1, layer_count)
    layer_diff = coverage_history[:, :-1] - coverage_history[:, 1:]
    deviation_sq = (levels - times * gR[:layer_count - 1]) ** 2
    variance += jnp.sum(deviation_sq * layer_diff, axis=1)
    return jnp.sqrt(variance)
@@ -0,0 +1,139 @@
1
+ """
2
+ ode_solver.py
3
+
4
+ Provides a generalized, pure JAX implementation of the Cash-Karp
5
+ fifth-order Runge-Kutta adaptive ODE solver. Designed using
6
+ jax.lax.while_loop to allow end-to-end JIT compilation.
7
+ """
8
+
9
+ import jax
10
+ import jax.numpy as jnp
11
+ from typing import Callable, Tuple, Dict
12
+
13
def rkck(y: jnp.ndarray, dydx: jnp.ndarray, x: float, h: float,
         derivs: Callable[[float, jnp.ndarray], jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Take one Cash-Karp fifth-order Runge-Kutta step from ``x`` to ``x + h``.

    This is a port of Numerical Recipes ``rkck``.

    Args:
        y: State at ``x``.
        dydx: Derivative at the start of the step (first stage), i.e. the
            value of ``derivs(x, y)``.
        x: Start of the step.
        h: Step size.
        derivs: Right-hand side, called as ``derivs(x, y)``.

    Returns:
        ``(yout, yerr)``:
        - ``yout`` is the fifth-order estimate at ``x + h``.
        - ``yerr`` is the difference between the fifth-order and the embedded
          fourth-order estimates, used as the local truncation-error estimate.
    """
    # Exact tableau nodes. Sums of the rounded b-coefficients such as
    # 0.3 - 0.9 + 1.2 are off by a few ulps and would shift the stage times
    # seen by x-dependent right-hand sides.
    a2 = 0.2; a3 = 0.3; a4 = 0.6; a5 = 1.0; a6 = 0.875

    b21 = 0.2; b31 = 3.0/40.0; b32 = 9.0/40.0; b41 = 0.3; b42 = -0.9; b43 = 1.2
    b51 = -11.0/54.0; b52 = 2.5; b53 = -70.0/27.0; b54 = 35.0/27.0
    b61 = 1631.0/55296.0; b62 = 175.0/512.0; b63 = 575.0/13824.0
    b64 = 44275.0/110592.0; b65 = 253.0/4096.0

    # Fifth-order weights (c2 = c5 = 0) ...
    c1 = 37.0/378.0; c3 = 250.0/621.0; c4 = 125.0/594.0; c6 = 512.0/1771.0
    # ... and their differences from the embedded fourth-order weights.
    dc5 = -277.00/14336.0
    dc1 = c1 - 2825.0/27648.0; dc3 = c3 - 18575.0/48384.0
    dc4 = c4 - 13525.0/55296.0; dc6 = c6 - 0.25

    ak2 = derivs(x + a2*h, y + b21*h*dydx)
    ak3 = derivs(x + a3*h, y + h*(b31*dydx + b32*ak2))
    ak4 = derivs(x + a4*h, y + h*(b41*dydx + b42*ak2 + b43*ak3))
    ak5 = derivs(x + a5*h, y + h*(b51*dydx + b52*ak2 + b53*ak3 + b54*ak4))
    ak6 = derivs(x + a6*h, y + h*(b61*dydx + b62*ak2 + b63*ak3 + b64*ak4 + b65*ak5))

    yout = y + h*(c1*dydx + c3*ak3 + c4*ak4 + c6*ak6)
    yerr = h*(dc1*dydx + dc3*ak3 + dc4*ak4 + dc5*ak5 + dc6*ak6)

    return yout, yerr
37
+
38
def rkqs(y: jnp.ndarray, dydx: jnp.ndarray, x: float, htry: float, eps: float,
         yscal: jnp.ndarray, derivs: Callable[[float, jnp.ndarray], jnp.ndarray]):
    """
    Adaptive fifth-order Runge-Kutta step with local error control.

    This is a port of Numerical Recipes ``rkqs``. Starting from ``htry``,
    ``rkck`` is retried with a shrinking step until the scaled error
    ``max|yerr / yscal| / eps`` is at most 1. A step size for the next call
    is then proposed.

    Returns:
        ``(ytemp, x_new, hnext, h_used)``:
        - ``ytemp`` is the accepted state at ``x_new = x + h_used``.
        - ``hnext`` is the suggested next step size.
    """
    SAFETY = 0.9
    PGROW = -0.2
    PSHRNK = -0.25
    ERRCON = 1.89e-4  # == (5 / SAFETY) ** (1 / PGROW): caps step growth at 5x

    def _needs_retry(carry):
        return carry[3] > 1.0

    def _attempt(carry):
        step = carry[0]
        y_trial, y_err = rkck(y, dydx, x, step, derivs)
        err_ratio = jnp.max(jnp.abs(y_err / yscal)) / eps

        proposal = SAFETY * step * (err_ratio ** PSHRNK)
        # Never shrink by more than a factor of 10 per retry, in either direction.
        shrunk = jnp.where(step >= 0.0,
                           jnp.maximum(proposal, 0.1 * step),
                           jnp.minimum(proposal, 0.1 * step))
        next_step = jnp.where(err_ratio > 1.0, shrunk, step)
        return (next_step, y_trial, y_err, err_ratio)

    # An initial error ratio > 1 forces at least one attempt.
    start = (htry, y, jnp.zeros_like(y), jnp.array(2.0, dtype=y.dtype))
    h_used, ytemp, _, errmax = jax.lax.while_loop(_needs_retry, _attempt, start)

    hnext = jnp.where(errmax > ERRCON,
                      SAFETY * h_used * (errmax ** PGROW),
                      5.0 * h_used)

    return ytemp, x + h_used, hnext, h_used
70
+
71
def solve_ode_adaptive(derivs: Callable, y0: jnp.ndarray, t_span: Tuple[float, float],
                       num_intervals: int, eps: float = 1.0e-7, h1: float = 1.0e-3) -> Dict[str, jnp.ndarray]:
    """
    Integrate ``dy/dt = derivs(t, y)`` over ``t_span`` with adaptive
    Cash-Karp steps, entirely inside ``jax.lax.while_loop``.

    This is a port of the Numerical Recipes ``odeint`` driver.
    - A snapshot of ``(t, y)`` is stored whenever ``t`` has moved more than
      ``(t_end - t_start) / num_intervals`` past the previous snapshot, with
      at most ``num_intervals`` such snapshots.
    - The last accepted point is always stored when the loop ends.
    - The loop stops once ``t_end`` is reached or after 300000 steps.

    Args:
        derivs: Right-hand side, called as ``derivs(t, y)``.
        y0: Initial state vector.
        t_span: ``(t_start, t_end)``.
        num_intervals: Number of output sampling intervals. Must be a Python
            int >= 1 because it fixes the history buffer size.
        eps: Error tolerance handed to ``rkqs``.
        h1: Initial trial step size.

    Returns:
        Dict with:
        - ``"t"`` of shape ``(num_intervals + 5,)``;
        - ``"y"`` of shape ``(num_intervals + 5, len(y0))``;
        - ``"valid_count"``.
        Only the first ``valid_count`` rows of ``"t"`` and ``"y"`` hold data;
        the remaining rows stay zero.

    Raises:
        ValueError: If ``num_intervals`` is smaller than 1.
    """
    if num_intervals < 1:
        raise ValueError(f"num_intervals must be >= 1, got {num_intervals}")

    t0, t_max = t_span
    TINY = 1.0e-30
    MAX_STEPS = 300000

    # Snapshot spacing over the actual span; t_max / num_intervals is only
    # correct for spans that start at t = 0.
    d_x_sav = (t_max - t0) / num_intervals

    buffer_len = num_intervals + 5
    history_t = jnp.zeros(buffer_len, dtype=jnp.float64)
    history_y = jnp.zeros((buffer_len, len(y0)), dtype=jnp.float64)

    def cond_fun(state):
        # State: y, x, h, x_sav, history_t, history_y, out_idx, done, steps
        return jnp.logical_not(state[7])

    def body_fun(state):
        y, x, h, x_sav, hist_t, hist_y, out_idx, done, steps = state

        dydx = derivs(x, y)
        yscal = jnp.abs(y) + jnp.abs(dydx * h) + TINY

        save_cond = jnp.logical_and(jnp.abs(x - x_sav) > jnp.abs(d_x_sav), out_idx < num_intervals)
        # Touch only the row at out_idx. When save_cond is False the row is
        # rewritten with its current contents, so nothing changes. This avoids
        # selecting between two full copies of the buffers on every step.
        hist_t = hist_t.at[out_idx].set(jnp.where(save_cond, x, hist_t[out_idx]))
        hist_y = hist_y.at[out_idx].set(jnp.where(save_cond, y, hist_y[out_idx]))
        new_x_sav = jnp.where(save_cond, x, x_sav)
        new_out_idx = out_idx + jnp.where(save_cond, 1, 0)

        # Shorten the step if it would carry x past t_max.
        overshoot = (x + h - t_max) * (x + h - t0) > 0.0
        h_adjusted = jnp.where(overshoot, t_max - x, h)

        y_next, x_next, hnext, _ = rkqs(y, dydx, x, h_adjusted, eps, yscal, derivs)

        finished = jnp.logical_or((x_next - t_max) * (t_max - t0) >= 0.0, steps >= MAX_STEPS)

        # On the final iteration, also record the last accepted point.
        save_end_cond = jnp.logical_and(finished, new_out_idx < buffer_len)
        hist_t = hist_t.at[new_out_idx].set(jnp.where(save_end_cond, x_next, hist_t[new_out_idx]))
        hist_y = hist_y.at[new_out_idx].set(jnp.where(save_end_cond, y_next, hist_y[new_out_idx]))
        final_out_idx = new_out_idx + jnp.where(save_end_cond, 1, 0)

        return (y_next, x_next, hnext, new_x_sav, hist_t, hist_y, final_out_idx, finished, steps + 1)

    init_state = (
        y0,
        jnp.array(t0, dtype=jnp.float64),
        jnp.array(h1, dtype=jnp.float64),
        # Placing the previous snapshot two intervals before t0 makes the
        # starting point the first one saved.
        jnp.array(t0 - d_x_sav * 2.0, dtype=jnp.float64),
        history_t,
        history_y,
        jnp.array(0, dtype=jnp.int32),
        jnp.array(False),
        jnp.array(0, dtype=jnp.int32)
    )

    final_state = jax.lax.while_loop(cond_fun, body_fun, init_state)

    _, _, _, _, final_history_t, final_history_y, final_out_idx, _, _ = final_state

    return {
        "t": final_history_t,
        "y": final_history_y,
        "valid_count": final_out_idx
    }
@@ -0,0 +1,8 @@
1
+ from .core import run_simulation
2
+ from .ode_solver import solve_ode_adaptive
3
+ from .models import derivs0, derivs1, derivs2, compute_intensity, compute_rms_roughness
4
+
5
+ # Re-export IO utilities from numpy module to maintain strict subpackage structural symmetry
6
+ from PY_GROWTH.numpy.io import load_input_data, save_history_to_json
7
+
8
+
@@ -0,0 +1,49 @@
1
+ import time
2
+ import numpy as np
3
+ from numba import njit
4
+ from .ode_solver import solve_ode_adaptive
5
+ from .models import derivs0, derivs1, derivs2, compute_intensity, compute_rms_roughness
6
+
7
@njit(cache=True)
def _jitted_pipeline(model_type, num_layers, t_max, num_intervals, an_kn, growth_rates):
    """
    Run the ODE integration and post-processing in one Numba-compiled call.

    Steps:
    1. Select derivs0 / derivs1 / derivs2 by ``model_type``.
    2. Integrate from a zero initial coverage vector of length ``num_layers``
       over [0, t_max] with ``solve_ode_adaptive``, passing
       ``(an_kn, growth_rates)`` as the model parameters.
    3. Compute the RHEED intensity and RMS roughness from the resulting
       history.

    Any model_type other than 0 or 1 falls through to model 2. Validation is
    done by the caller, run_simulation.

    Returns:
        ``(hist_t, hist_y, intensity, rms)`` as produced by the Numba-side
        ``solve_ode_adaptive`` and ``compute_*`` helpers.
    """
    y0 = np.zeros(num_layers, dtype=np.float64)

    # Each derivs function is passed in its own branch instead of being chosen
    # into a variable first. NOTE(review): presumably because Numba types each
    # jitted function separately and cannot unify them into one variable;
    # confirm before restructuring.
    if model_type == 0:
        hist_t, hist_y = solve_ode_adaptive(derivs0, y0, (0.0, t_max), (an_kn, growth_rates), num_intervals)
    elif model_type == 1:
        hist_t, hist_y = solve_ode_adaptive(derivs1, y0, (0.0, t_max), (an_kn, growth_rates), num_intervals)
    else:
        hist_t, hist_y = solve_ode_adaptive(derivs2, y0, (0.0, t_max), (an_kn, growth_rates), num_intervals)

    intensity = compute_intensity(hist_y)
    rms = compute_rms_roughness(hist_t, hist_y, growth_rates)

    return hist_t, hist_y, intensity, rms
26
+
27
def run_simulation(model_type: int, num_layers: int, t_max: float, num_intervals: int,
                   an_kn: np.ndarray, growth_rates: np.ndarray) -> dict:
    """
    Run a RHEED growth simulation with the Numba backend.

    Args:
        model_type: Growth model selector (0, 1 or 2).
        num_layers: Number of layer coverages to simulate.
        t_max: End of the integration interval [0, t_max].
        num_intervals: Number of output sampling intervals.
        an_kn: Per-layer model coefficients (array-like, converted to float64).
        growth_rates: Per-layer growth rates (array-like, converted to float64).

    Returns:
        Dict with keys ``"time"``, ``"coverage"``, ``"intensity"``,
        ``"rms_roughness"``, ``"model_type"`` (as passed in) and
        ``"execution_time_seconds"`` (wall time of the compiled call).

    Raises:
        ValueError: If ``model_type`` is not 0, 1 or 2.
    """
    if model_type not in (0, 1, 2):
        raise ValueError(f"Unknown model_type integer: {model_type}")

    # Normalise argument types, as the JAX backend does. This lets values
    # read from an input file as floats, or parameters given as lists,
    # type-check in the compiled pipeline and reuse one cached specialisation.
    an_kn_arr = np.asarray(an_kn, dtype=np.float64)
    growth_rates_arr = np.asarray(growth_rates, dtype=np.float64)

    start_time = time.perf_counter()

    t, y, intensity, rms = _jitted_pipeline(int(model_type), int(num_layers), float(t_max),
                                            int(num_intervals), an_kn_arr, growth_rates_arr)

    calc_time = time.perf_counter() - start_time

    return {
        "time": t,
        "coverage": y,
        "intensity": intensity,
        "rms_roughness": rms,
        "model_type": model_type,
        "execution_time_seconds": calc_time
    }
+ }