mukip 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mukip/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .microkinetic_model import MicrokineticModel
2
+ from .jvm_manager import initialize
3
+ from .plot_2d import plot_2d
4
+
5
# Package version and the public names exported by ``from mukip import *``.
__version__ = "0.4.0"
__all__ = [
    'initialize',
    'MicrokineticModel',
    'plot_2d',
]
mukip/jvm_manager.py ADDED
@@ -0,0 +1,241 @@
1
+ """
2
+ JVM management module with automatic JRE download and initialization.
3
+ """
4
+ import os
5
+ import sys
6
+ import platform
7
+ import urllib.request
8
+ import ssl
9
+ import zipfile
10
+ import tarfile
11
+ import shutil
12
+ from pathlib import Path
13
+ from typing import Optional, Tuple
14
+ import jpype
15
+ from jpype import JClass
16
+
17
# Directory containing this module; the bundled jar and downloaded JRE live under it.
PACKAGE_DIR = Path(__file__).parent
# Target directory the downloaded JRE is unpacked into.
JRE_DIR = PACKAGE_DIR.joinpath("jre24")
# Bundled MuKiP jar that is placed on the JVM classpath.
JAR_FILE = PACKAGE_DIR.joinpath("resources", "mukip-0.4.0-beta-all.jar")

# Flipped to True by initialize() after the Java-side setup succeeded.
MUKIP_INITIALIZED = False

# JRE download configuration: release tag and the base URL of the mirror.
JRE_VERSION = "24.0.2_12"
JRE_BASE_URL = "https://github.com/WestlakeUniversity-WangLab/jre8-mirror/releases/download"

# Expected JVM shared-library location inside JRE_DIR, keyed by
# platform.system().lower().
jvm_paths = dict(
    windows=JRE_DIR.joinpath('bin', 'server', 'jvm.dll'),
    linux=JRE_DIR.joinpath('lib', 'server', 'libjvm.so'),
    darwin=JRE_DIR.joinpath('Contents', 'Home', 'lib', 'server', 'libjvm.dylib'),
)
33
+
34
def detect_system_arch() -> Tuple[str, str]:
    """
    Work out the (system, arch) pair used in JRE archive file names.

    Known OS and machine names are normalized (e.g. "darwin" -> "mac",
    "amd64" -> "x64", "arm64" -> "aarch64"); anything unrecognized is passed
    through lower-cased. On Linux, the presence of /etc/alpine-release turns
    the system into "alpine-linux".

    :return: Tuple of (system, arch) compatible with JRE downloads
    """
    os_name = platform.system().lower()
    machine = platform.machine().lower()

    # Only macOS is renamed; windows/linux/aix and unknown systems pass through.
    os_aliases = {"darwin": "mac"}
    system = os_aliases.get(os_name, os_name)

    arch_aliases = {
        "x86_64": "x64",
        "amd64": "x64",
        "arm64": "aarch64",
        "aarch64": "aarch64",
        "ppc64le": "ppc64le",
        "powerpc64le": "ppc64le",
        "s390x": "s390x",
        "riscv64": "riscv64",
    }
    if machine in arch_aliases:
        arch = arch_aliases[machine]
    elif machine.startswith("arm"):
        # Remaining 32-bit ARM variants (armv7l, armv6l, ...).
        arch = "arm"
    else:
        arch = machine

    if system == "linux" and os.path.exists("/etc/alpine-release"):
        system = "alpine-linux"

    return system, arch
77
+
78
+
79
def get_jvm_library_path() -> Optional[str]:
    """
    Locate the JVM shared library inside the downloaded JRE.

    :return: The library path as a string, or None when JRE_DIR does not
             exist, the current OS has no entry in jvm_paths, or the library
             file is missing.
    """
    if JRE_DIR.exists():
        candidate = jvm_paths.get(platform.system().lower())
        if candidate is not None and candidate.exists():
            return str(candidate)
    return None
94
+
95
+
96
def _extract_archive(archive: Path, dest: Path, progress_callback=None) -> None:
    """
    Extract a .zip or .tar.gz archive into *dest*, reporting per-member progress.

    :param archive: Archive file; a ".zip" suffix selects zipfile, anything else
                    is opened as a gzip-compressed tar.
    :param dest: Existing directory to extract into.
    :param progress_callback: Optional callback(step, percent) invoked with
                              step 'extract' after each member.
    """
    if archive.suffix == '.zip':
        # zipfile.extract already sanitizes absolute paths and '..' components.
        with zipfile.ZipFile(archive, 'r') as zf:
            members = zf.namelist()
            total = len(members)
            for done, member in enumerate(members, start=1):
                zf.extract(member, dest)
                if progress_callback:
                    progress_callback('extract', int(done * 100 / total))
        return

    # Extraction filters (Python 3.12+, backported to 3.8.17/3.9.17/3.10.12/3.11.4)
    # make tarfile refuse absolute paths, '..' traversal and links escaping dest.
    extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}
    with tarfile.open(archive, 'r:gz') as tf:
        members = tf.getmembers()
        total = len(members)
        for done, member in enumerate(members, start=1):
            tf.extract(member, dest, **extract_kwargs)
            if progress_callback:
                progress_callback('extract', int(done * 100 / total))


def download_jre(progress_callback=None) -> bool:
    """
    Download and extract JRE from GitHub releases.

    The archive is fetched into a scratch directory under PACKAGE_DIR,
    extracted, and its contents moved into JRE_DIR (replacing any existing
    JRE_DIR). The scratch directory is removed whether or not this succeeds.

    :param progress_callback: Optional callback function(step, percent)
                              step: 'download' or 'extract'
                              percent: 0-100
    :return: True if successful, False otherwise
    """
    system, arch = detect_system_arch()
    extension = "zip" if system == "windows" else "tar.gz"
    filename = f"OpenJDK24U-jre_{arch}_{system}_hotspot_{JRE_VERSION}.{extension}"
    download_url = f"{JRE_BASE_URL}/{JRE_VERSION}/{filename}"

    temp_dir = PACKAGE_DIR / "temp_jre"
    local_file = temp_dir / filename

    def download_hook(block_num, block_size, total_size):
        # urlretrieve reports blocks; the last block may overshoot total_size.
        if progress_callback and total_size > 0:
            percent = int((block_num * block_size * 100) / total_size)
            progress_callback('download', min(percent, 100))

    try:
        # Start from a clean scratch directory so leftovers of an interrupted
        # earlier attempt cannot be mixed into the new JRE.
        shutil.rmtree(temp_dir, ignore_errors=True)
        temp_dir.mkdir(parents=True, exist_ok=True)

        # HTTPS certificate verification stays on (urlretrieve uses the default
        # context): the JRE is executable code and must come from an
        # authenticated server.
        urllib.request.urlretrieve(download_url, str(local_file), reporthook=download_hook)

        extract_dir = temp_dir / "extracted"
        extract_dir.mkdir(exist_ok=True)
        _extract_archive(local_file, extract_dir, progress_callback)

        # Move extracted JRE to target directory
        if JRE_DIR.exists():
            shutil.rmtree(JRE_DIR)

        extracted_items = list(extract_dir.iterdir())
        if len(extracted_items) == 1 and extracted_items[0].is_dir():
            # Archive has a single top-level directory: it becomes JRE_DIR.
            shutil.move(str(extracted_items[0]), str(JRE_DIR))
        else:
            # Multiple top-level items: move them all into JRE_DIR.
            JRE_DIR.mkdir(exist_ok=True)
            for item in extracted_items:
                shutil.move(str(item), str(JRE_DIR / item.name))

        return get_jvm_library_path() is not None

    except Exception as e:
        print(f"Failed to download/extract JRE: {e}", file=sys.stderr)
        print("Please check your network connection and retry.", file=sys.stderr)
        print("If the problem persists, please download it manually:", file=sys.stderr)
        print(f"1. Download: {download_url}", file=sys.stderr)
        print(f"2. Extract the contents to: {JRE_DIR}", file=sys.stderr)
        print(f"3. Ensure the file exists: {jvm_paths.get(platform.system().lower())}.", file=sys.stderr)
        return False
    finally:
        # Remove the downloaded archive and scratch extraction on every path.
        shutil.rmtree(temp_dir, ignore_errors=True)
178
+
179
+
180
def ensure_jre() -> bool:
    """
    Make sure a usable JRE exists, downloading it on first use.

    Download and extraction progress is printed on one console line that is
    rewritten in place.

    :return: True if JRE is ready, False otherwise
    """
    if get_jvm_library_path():
        return True

    print("JRE not found. Downloading... (first-time setup only)")

    step_labels = {'download': "Downloading", 'extract': "Extracting"}

    def report(step, percent):
        label = step_labels.get(step)
        if label is not None:
            print(f"\r{label}: {percent}%", end='', flush=True)

    installed = download_jre(report)
    print("\nJRE installation complete!" if installed else "\nJRE installation failed!")
    return installed
204
+
205
+
206
def start_jvm():
    """
    Start the JVM from the bundled JRE with the MuKiP jar on the classpath.

    If the JRE is not present yet, it is downloaded first via ensure_jre().
    Java strings are not auto-converted to Python str (convertStrings=False).

    :return: True if the JVM is running after the call; False if the JRE could
             not be obtained or jpype failed to start the JVM (the error is
             printed to stderr).
    """
    jvm_path = get_jvm_library_path()
    if not jvm_path:
        if not ensure_jre():
            return False
        jvm_path = get_jvm_library_path()
        if not jvm_path:
            return False

    try:
        # Pass the library path by keyword and the jar as a plain string so the
        # call does not depend on jpype accepting pathlib.Path classpath entries.
        jpype.startJVM(jvmpath=jvm_path, classpath=[str(JAR_FILE)], convertStrings=False)
        return jpype.isJVMStarted()
    except Exception as e:
        print(f"Failed to start JVM: {e}", file=sys.stderr)
        return False
221
+
222
+
223
def initialize():
    """
    Initialize JVM and run MuKiP initialization.

    This should be called before any MuKiP functions. After one successful
    call, further calls return immediately. Failures are printed to stderr
    rather than raised.

    :return: True if MuKiP is initialized, False if the JVM could not be
             started or the Java-side initialization raised.
    """
    global MUKIP_INITIALIZED
    if MUKIP_INITIALIZED:
        return True
    try:
        # start_jvm() already reports why it failed; stop here instead of
        # letting the JClass lookup below fail with a secondary error.
        if not jpype.isJVMStarted() and not start_jvm():
            print("MuKiP initialization failed: the JVM could not be started.", file=sys.stderr)
            return False
        JClass("com.wang_lab.mukip.components.ComponentsLoader").initializeComponents("")
        MUKIP_INITIALIZED = True
        return True
    except Exception as e:
        print(f"MuKiP initialization failed: {e}", file=sys.stderr)
        return False
238
+
239
def get_class(classpath: str):
    """
    Look up a Java class by name, initializing MuKiP first.

    :param classpath: Fully qualified Java class name, e.g. "java.util.ArrayList".
    :return: The jpype JClass wrapper for that class.
    """
    initialize()
    java_class = JClass(classpath)
    return java_class
@@ -0,0 +1,154 @@
1
+ import multiprocessing
2
+ from typing import Literal
3
+ import jpype
4
+
5
+ from mukip.plot_2d import plot_2d
6
+ from mukip.jvm_manager import get_class
7
+
8
+
9
+ def _wrap_thermo(thermo):
10
+ return {str(k): float(v) for k, v in thermo.getValues().items()}
11
+
12
+ def _wrap_vector(vector):
13
+ return [float(b.doubleValue()) for b in vector]
14
+
15
+ def _wrap_double_array(array):
16
+ return [float(d) for d in array]
17
+
18
class MicrokineticModel:
    """
    A wrapped class for microkinetic modeling.

    Python facade over the MuKiP Java model. The Java classes below are looked
    up with ``get_class`` while this class body executes (i.e. when the module
    is imported), which initializes the JVM on demand.
    """
    # Static ``Companion`` field of the Java model class (presumably its Kotlin
    # companion object); used below as the factory for reaction models.
    MKM = get_class("com.wang_lab.mukip.components.model.MicrokineticModel").class_.getField("Companion").get(None)
    ArrayList = get_class('java.util.ArrayList')
    LinkedHashMap = get_class("java.util.LinkedHashMap")
    GridPoint = get_class("com.wang_lab.mukip.point.GridPoint")
    MapPoint = get_class("com.wang_lab.mukip.point.MapPoint")
    JUtils = get_class("com.wang_lab.mukip.MuKiPUtils")
    Adsorbate = get_class("com.wang_lab.mukip.species.Adsorbate")
    Gas = get_class("com.wang_lab.mukip.species.Gas")
    Fluid = get_class("com.wang_lab.mukip.species.Fluid")
    Aqua = get_class("com.wang_lab.mukip.species.Aqua")
    # Default thread count for run(); evaluated once when the class is created.
    n_cpu = multiprocessing.cpu_count()

    def __init__(self, setup_file: str):
        """
        Create a new MicrokineticModel object.

        :param setup_file: The path of the setup file for microkinetic model
        """
        self.setup_file = setup_file
        self.model = self.MKM.reactionModel(jpype.JString(self.setup_file))

    def load_data(self):
        """
        Load solutions for each point from the data file.
        """
        self.model.loadData()

    def save_data(self):
        """
        Save solutions on each point to the data file.
        """
        self.model.saveData()

    def _get_grid_point(self, *index: int):
        # Fetch the Java point object located at the given grid indices.
        return self.model.getPoint(self.GridPoint(*index))

    def get_global_thermo(self):
        """
        Get global thermodynamic data.

        :return: A dictionary of all global thermodynamic variables.
        """
        return _wrap_thermo(self.model.getThermo())

    def get_grid_thermo(self, *index: int):
        """
        Get the thermodynamic data of a grid point.

        :param index: The index of the grid point.
        :return: A dictionary of all thermodynamic variables on this point.
        """
        return _wrap_thermo(self._get_grid_point(*index).getThermo())

    def run(self, method: Literal['map_sample', 'map_in_turn', 'map_expand', 'map_check'] = 'map_sample',
            n_thread: int = n_cpu):
        """
        Run the microkinetic model with given method.

        :param method: The method to use to run the microkinetic model.
            - 'map_sample': Use the sampling-expansion method (SEM), which offers high solving
              efficiency and stability and is highly recommended.
            - 'map_in_turn': Solve all points in turn.
            - 'map_expand': Solve all unsolved points starting from points that already have a
              solution. Use it when only part of the points were solved successfully.
            - 'map_check': Check existing solutions on all points in turn. If the solution on a
              point is incorrect, it attempts to solve it again; if that fails, the point's data
              is deleted, so all remaining solutions are correct once the method completes.
        :param n_thread: Number of threads to use (default: number of CPUs).
        """
        self.model.getMapper().map(method, n_thread, self.ArrayList(), None)

    def get_result(self, *index: int):
        """
        Get results on a point.

        Each key is present only when the point has that result:
        "coverage" (per adsorbate), "TOF" (per fluid), "current" (float),
        "pressure" (per gas) and "concentration" (per aqueous species).

        :param index: The index of the grid point.
        :return: Results as a dictionary.
        """
        pt = self._get_grid_point(*index)
        results = {}
        if r := pt.getCoverage():
            adsorbates = self.model.get(self.Adsorbate)
            coverage = _wrap_vector(r)
            results["coverage"] = {str(ads.getName()): c for ads, c in zip(adsorbates, coverage)}
        if r := pt.getTof():
            fluids = self.model.get(self.Fluid)
            tof = _wrap_double_array(r)
            results["TOF"] = {str(fluid.getName()): t for fluid, t in zip(fluids, tof)}
        # Compare with None instead of relying on truthiness: a current of
        # exactly 0.0 is falsy but is still a valid result to report.
        if (r := pt.getCurrent()) is not None:
            results["current"] = float(r)
        if r := pt.getPressure():
            gases = self.model.get(self.Gas)
            pressure = _wrap_vector(r)
            results["pressure"] = {str(gas.getName()): p for gas, p in zip(gases, pressure)}
        if r := pt.getConcentration():
            aquas = self.model.get(self.Aqua)
            concentration = _wrap_vector(r)
            results["concentration"] = {str(aq.getName()): c for aq, c in zip(aquas, concentration)}
        return results

    def get_variables(self, *index: int):
        """
        Get all fine-point variables on a point, plus the solver's expressions evaluated there.

        :param index: The index of the grid point.
        :return: A dictionary mapping variable/expression names (str) to floats.
        """
        pt = self._get_grid_point(*index)
        solver = self.model.getSolver()
        variables = solver.getFinePointValues(pt)
        results = {str(k): float(v.doubleValue()) for k, v in variables.items()}
        expressions = solver.getExpressionDictionary()
        # Scratch map passed to every arithmetic() call.
        buffer = self.LinkedHashMap()
        for name, expression in expressions.items():
            value = self.JUtils.toDouble(expression.arithmetic(variables, expressions, buffer))
            # str() so expression keys are plain Python strings, like the
            # variable keys above (Java strings are not auto-converted).
            results[str(name)] = float(value)
        return results

    def write(self, plot: bool = False, fig_size=None, contour_kw=None, clabel_kw=None, contourf_kw=None):
        """
        Write data with writers defined in the setup file.

        :param plot: Whether to generate and save plots along with data.
            Default: False
        :param fig_size: Figure size in inches as a tuple (width, height).
            Default: (9, 6)
        :param contour_kw: Keyword arguments passed to matplotlib's contour() for contour lines.
            Default: {'levels': 31, 'colors': 'black', 'linewidths': 0.5}
        :param clabel_kw: Keyword arguments passed to matplotlib's clabel() for contour labels.
            Default: {'inline': True, 'fontsize': 8}
        :param contourf_kw: Keyword arguments passed to matplotlib's contourf() for filled contours.
            Default: {'levels': 31, 'cmap': 'jet'}
        """
        for writer in self.model.getWriters():
            writer.output()
            if plot:
                # The check below expects the proxy type's __name__ to be the
                # fully qualified Java class name.
                class_name = writer.__class__.__name__
                csv_path = str(writer.getOutputFile().getAbsolutePath())
                if class_name == "com.wang_lab.mukip.components.writer.CSV2DCustomWriter":
                    plot_2d(csv_path, fig_size, contour_kw, clabel_kw, contourf_kw)
                else:
                    print(f"Plot method for {class_name} not implemented")
154
+
mukip/plot_2d.py ADDED
@@ -0,0 +1,101 @@
1
+ import os
2
+ import csv
3
+ import numpy as np
4
+
5
+ import matplotlib
6
+ matplotlib.use('Agg')
7
+
8
+ import matplotlib.pyplot as plt
9
+
10
def parse_axis_labels(label_str):
    """
    Split a corner-cell label of the form ``'A\\B'`` into axis labels.

    Only the first backslash separates the parts.

    :return: ``(A, B)`` stripped of surrounding whitespace, where A is the
             Y-axis label and B the X-axis label; ``("Y", "X")`` if the
             string contains no backslash.
    """
    head, separator, tail = label_str.strip().partition('\\')
    if not separator:
        return "Y", "X"
    return head.strip(), tail.strip()
18
+
19
def read_csv_matrix(file_path):
    """Return the rows of the UTF-8 CSV file at *file_path* as a list of lists of strings."""
    with open(file_path, newline='', encoding='utf-8') as handle:
        return list(csv.reader(handle))
25
+
26
def to_float_or_nan(s):
    """Parse *s* as a float; strings that are not valid numbers yield ``np.nan``."""
    try:
        parsed = float(s)
    except ValueError:
        parsed = np.nan
    return parsed
32
+
33
def plot_2d(csv_path, fig_size=None, contour_kw=None, clabel_kw=None, contourf_kw=None):
    """
    Draw contour lines and filled contours of log10(Z) for a 2D grid stored in a CSV file.

    Expected layout: the top-left cell holds ``'<y label>\\<x label>'``, the
    rest of the first row holds X values, the rest of the first column holds Y
    values, and the remaining cells hold Z values. Columns/rows whose X/Y
    header is not numeric are dropped, blank CSV lines are skipped, and
    missing, non-numeric or non-positive Z cells (log10 undefined) are left
    blank in the plot.

    The figure is saved next to the CSV, with the file extension replaced by
    ``.png``, and then closed.

    :param csv_path: Path of the CSV file.
    :param fig_size: Figure size in inches (width, height). Default: (9, 6)
    :param contour_kw: Keyword arguments for ``plt.contour``.
        Default: {'levels': 31, 'colors': 'black', 'linewidths': 0.5}
    :param clabel_kw: Keyword arguments for ``plt.clabel``.
        Default: {'inline': True, 'fontsize': 8}
    :param contourf_kw: Keyword arguments for ``plt.contourf``.
        Default: {'levels': 31, 'cmap': 'jet'}
    :return: Path of the saved PNG file.
    :raises ValueError: If the CSV is smaller than 2x2 or has fewer than two
        numeric X or Y values.
    """
    data = read_csv_matrix(csv_path)
    title = os.path.splitext(os.path.basename(csv_path))[0]

    if not data or len(data) < 2 or len(data[0]) < 2:
        raise ValueError("Invalid CSV format: at least a 2x2 grid (including header) is required.")

    # Parse top-left cell (e.g., "A\\B") into (Y label, X label).
    ylabel, xlabel = parse_axis_labels(data[0][0])

    # Parse headers at every position first so Z cells stay aligned with their
    # headers; non-numeric headers are removed afterwards with masks.
    x_all = np.array([to_float_or_nan(x) for x in data[0][1:]], dtype=float)
    body = [row for row in data[1:] if row]  # csv.reader yields [] for blank lines
    y_all = np.array([to_float_or_nan(row[0]) for row in body], dtype=float)

    # Z grid, padded with NaN where a row is shorter than the header row.
    n_x = len(x_all)
    z_all = np.full((len(body), n_x), np.nan)
    for i, row in enumerate(body):
        cells = row[1:n_x + 1]
        z_all[i, :len(cells)] = [to_float_or_nan(c) for c in cells]

    x_mask = ~np.isnan(x_all)
    y_mask = ~np.isnan(y_all)
    x_vals = x_all[x_mask]
    y_vals = y_all[y_mask]
    if len(x_vals) < 2 or len(y_vals) < 2:
        raise ValueError("Invalid CSV format: at least two numeric X and two numeric Y values are required.")
    Z = z_all[y_mask][:, x_mask]

    X, Y = np.meshgrid(x_vals, y_vals)

    # log10 is undefined for Z <= 0: turn those cells into NaN up front (which
    # matplotlib leaves blank) instead of producing -inf/NaN with warnings.
    with np.errstate(invalid='ignore'):
        Z = np.log10(np.where(Z > 0, Z, np.nan))

    if not fig_size:
        fig_size = (9, 6)
    if not contour_kw:
        contour_kw = {'levels': 31, 'colors': 'black', 'linewidths': 0.5}
    if not clabel_kw:
        clabel_kw = {'inline': True, 'fontsize': 8}
    if not contourf_kw:
        contourf_kw = {'levels': 31, 'cmap': 'jet'}

    # Swap only the final extension (str.replace would touch every ".csv").
    output_path = os.path.splitext(csv_path)[0] + '.png'

    fig = plt.figure(figsize=fig_size)
    try:
        contour = plt.contour(X, Y, Z, **contour_kw)
        plt.clabel(contour, **clabel_kw)
        filled = plt.contourf(X, Y, Z, **contourf_kw)
        plt.colorbar(filled)

        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(f'log({title})')
        plt.tight_layout()
        plt.savefig(output_path)
    finally:
        # pyplot keeps every figure alive until closed; without this each call
        # would leak one figure.
        plt.close(fig)
    return output_path
Binary file