rootstock 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rootstock/__init__.py +34 -0
- rootstock/calculator.py +194 -0
- rootstock/cli.py +426 -0
- rootstock/clusters.py +41 -0
- rootstock/environment.py +238 -0
- rootstock/pep723.py +172 -0
- rootstock/protocol.py +309 -0
- rootstock/server.py +287 -0
- rootstock/worker.py +273 -0
- rootstock-0.5.0.dist-info/METADATA +210 -0
- rootstock-0.5.0.dist-info/RECORD +14 -0
- rootstock-0.5.0.dist-info/WHEEL +4 -0
- rootstock-0.5.0.dist-info/entry_points.txt +2 -0
- rootstock-0.5.0.dist-info/licenses/LICENSE.md +7 -0
rootstock/protocol.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Minimal i-PI protocol implementation for Rootstock.
|
|
3
|
+
|
|
4
|
+
This is adapted from ASE's ase.calculators.socketio module.
|
|
5
|
+
The i-PI protocol is a simple binary protocol for communicating atomic
|
|
6
|
+
simulation data over sockets.
|
|
7
|
+
|
|
8
|
+
Protocol Overview:
|
|
9
|
+
- Commands are 12-byte ASCII strings (e.g., "STATUS", "POSDATA", "GETFORCE")
|
|
10
|
+
- Data is transmitted as raw numpy arrays (no JSON/msgpack serialization)
|
|
11
|
+
- Units are atomic units (Bohr, Hartree) - we convert to ASE units (Å, eV)
|
|
12
|
+
|
|
13
|
+
Reference: https://docs.ipi-code.org/
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import socket
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
# Unit conversion factors between ASE units (Angstrom, eV) and the atomic
# units carried on the i-PI wire (Bohr, Hartree).
#
# NOTE(review): the Bohr value below is the CODATA 2014 figure while the
# Hartree value is the CODATA 2018 figure. Both ends of a Rootstock
# connection use these same factors, so the mix does not affect traffic
# between them, but it is visible when talking to a foreign i-PI code.
# The values are kept unchanged so peers built from other releases agree.
BOHR_TO_ANGSTROM = 0.52917721067
HARTREE_TO_EV = 27.211386245988

# Inverse factors are derived from the primary ones so both directions agree.
ANGSTROM_TO_BOHR = 1.0 / BOHR_TO_ANGSTROM
EV_TO_HARTREE = 1.0 / HARTREE_TO_EV
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class SocketClosed(Exception):
    """Signal that the peer closed the connection before a read completed."""
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class IPIProtocol:
    """
    Communication using the i-PI protocol.

    This handles the low-level socket communication, including:
    - Sending/receiving fixed-length command strings
    - Sending/receiving numpy arrays as raw bytes
    - Unit conversions between i-PI (atomic units) and ASE (Å, eV)

    Protocol violations are reported with explicit exceptions rather than
    ``assert`` statements, so the checks still run under ``python -O``.
    """

    # Every i-PI command is an ASCII string space-padded to this many bytes.
    MSG_LEN = 12

    def __init__(self, sock: socket.socket, log=None):
        """
        Initialize protocol handler.

        Args:
            sock: Connected socket object
            log: Optional file object for logging (useful for debugging)
        """
        self.socket = sock
        self.log = log

    def _log(self, *args):
        """Write to log if logging is enabled."""
        if self.log is not None:
            print(*args, file=self.log, flush=True)

    # -------------------------------------------------------------------------
    # Low-level send/receive
    # -------------------------------------------------------------------------

    def sendmsg(self, msg: str):
        """
        Send a command string, space-padded to the 12-byte header.

        Raises:
            ValueError: if ``msg`` encodes to more than 12 bytes. ``bytes.ljust``
                never truncates, so an over-long command would shift every
                following field of the stream as seen by the receiver.
        """
        self._log(f" sendmsg: {msg!r}")
        encoded = msg.encode("ascii")
        if len(encoded) > self.MSG_LEN:
            raise ValueError(f"i-PI command {msg!r} is longer than {self.MSG_LEN} bytes")
        self.socket.sendall(encoded.ljust(self.MSG_LEN))

    def recvmsg(self) -> str:
        """Receive a 12-byte command string with trailing whitespace stripped."""
        data = self._recvall(self.MSG_LEN)
        msg = data.rstrip().decode("ascii")
        self._log(f" recvmsg: {msg!r}")
        return msg

    def _recvall(self, nbytes: int) -> bytes:
        """
        Receive exactly nbytes, handling partial reads.

        Raises:
            SocketClosed: if the peer closes the connection (recv returns no
                data) before ``nbytes`` bytes have arrived.
        """
        chunks = []
        remaining = nbytes
        while remaining > 0:
            chunk = self.socket.recv(remaining)
            if len(chunk) == 0:
                raise SocketClosed("Socket closed while receiving data")
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def send_array(self, arr, dtype):
        """Send ``arr`` converted to ``dtype`` as raw bytes."""
        buf = np.asarray(arr, dtype=dtype).tobytes()
        self._log(f" send: {len(buf)} bytes ({dtype})")
        self.socket.sendall(buf)

    def recv_array(self, shape, dtype) -> np.ndarray:
        """Receive an array of the given ``shape`` and ``dtype`` from raw bytes."""
        arr = np.empty(shape, dtype=dtype)
        nbytes = arr.nbytes
        buf = self._recvall(nbytes)
        arr.flat[:] = np.frombuffer(buf, dtype=dtype)
        self._log(f" recv: {nbytes} bytes ({dtype})")
        return arr

    # -------------------------------------------------------------------------
    # High-level protocol messages
    # -------------------------------------------------------------------------

    def send_status(self):
        """Send STATUS request."""
        self._log(" send_status")
        self.sendmsg("STATUS")

    def recv_status(self) -> str:
        """Receive status response (READY, HAVEDATA, or NEEDINIT)."""
        return self.recvmsg()

    def send_init(self, bead_index: int = 0, init_string: bytes = b"\x00"):
        """
        Send an INIT message: header, bead index, payload length, payload.

        Args:
            bead_index: Bead index sent as int32
            init_string: Raw initialization payload
        """
        self._log(" send_init")
        self.sendmsg("INIT")
        self.send_array([bead_index], np.int32)
        self.send_array([len(init_string)], np.int32)
        self.send_array(np.frombuffer(init_string, dtype=np.byte), np.byte)

    def recv_init(self) -> tuple[int, bytes]:
        """
        Receive the body of an INIT message.

        The "INIT" header itself must already have been consumed with
        recvmsg(); this reads only the fields that follow it.

        Returns:
            bead_index: Bead index as a Python int
            init_bytes: Raw initialization payload
        """
        bead_index = int(self.recv_array(1, np.int32)[0])
        nbytes = int(self.recv_array(1, np.int32)[0])
        init_bytes = self.recv_array(nbytes, np.byte).tobytes()
        return bead_index, init_bytes

    def send_posdata(self, cell: np.ndarray, positions: np.ndarray):
        """
        Send atomic positions and cell.

        The payload is computed before the POSDATA header is sent, so a bad
        input (wrong shape, non-finite cell) raises without leaving a dangling
        header on the stream.

        Args:
            cell: 3x3 cell matrix in Angstrom (array-like)
            positions: Nx3 positions in Angstrom (array-like)
        """
        cell = np.asarray(cell, dtype=np.float64)
        positions = np.asarray(positions, dtype=np.float64)

        # Convert to atomic units and transpose (i-PI convention)
        cell_bohr = cell.T * ANGSTROM_TO_BOHR
        icell_bohr = np.linalg.pinv(cell).T / ANGSTROM_TO_BOHR
        positions_bohr = positions * ANGSTROM_TO_BOHR

        self._log(" send_posdata")
        self.sendmsg("POSDATA")
        self.send_array(cell_bohr, np.float64)
        self.send_array(icell_bohr, np.float64)
        self.send_array([len(positions)], np.int32)
        self.send_array(positions_bohr, np.float64)

    def recv_posdata(self) -> tuple[np.ndarray, np.ndarray]:
        """
        Receive atomic positions and cell (the body following a POSDATA header).

        Returns:
            cell: 3x3 cell matrix in Angstrom
            positions: Nx3 positions in Angstrom
        """
        cell_bohr = self.recv_array((3, 3), np.float64).T.copy()
        # The inverse cell must be consumed to keep the stream aligned, but it
        # is not needed here.
        self.recv_array((3, 3), np.float64)
        natoms = int(self.recv_array(1, np.int32)[0])
        positions_bohr = self.recv_array((natoms, 3), np.float64)

        # Convert to ASE units
        cell = cell_bohr * BOHR_TO_ANGSTROM
        positions = positions_bohr * BOHR_TO_ANGSTROM
        return cell, positions

    def send_getforce(self):
        """Send GETFORCE request."""
        self._log(" send_getforce")
        self.sendmsg("GETFORCE")

    def recv_forceready(self) -> tuple[float, np.ndarray, np.ndarray, bytes]:
        """
        Receive force data after GETFORCE.

        Returns:
            energy: Potential energy in eV
            forces: Nx3 forces in eV/Angstrom
            virial: 3x3 virial tensor in eV
            extra: Extra bytes (often empty)

        Raises:
            RuntimeError: if the peer answers with anything other than
                FORCEREADY.
        """
        msg = self.recvmsg()
        if msg != "FORCEREADY":
            raise RuntimeError(f"Expected FORCEREADY, got {msg}")

        energy_hartree = self.recv_array(1, np.float64)[0]
        natoms = int(self.recv_array(1, np.int32)[0])
        forces_au = self.recv_array((natoms, 3), np.float64)
        virial_au = self.recv_array((3, 3), np.float64).T.copy()
        nextra = int(self.recv_array(1, np.int32)[0])
        extra = self.recv_array(nextra, np.byte).tobytes() if nextra > 0 else b""

        # Convert to ASE units
        energy = energy_hartree * HARTREE_TO_EV
        forces = forces_au * (HARTREE_TO_EV / BOHR_TO_ANGSTROM)
        virial = virial_au * HARTREE_TO_EV

        return energy, forces, virial, extra

    def send_forceready(
        self, energy: float, forces: np.ndarray, virial: np.ndarray, extra: bytes = b"\x00"
    ):
        """
        Send force data in response to GETFORCE.

        The payload is computed before the FORCEREADY header is sent, so a bad
        input raises without leaving a dangling header on the stream.

        Args:
            energy: Potential energy in eV
            forces: Nx3 forces in eV/Angstrom (array-like)
            virial: 3x3 virial tensor in eV (array-like)
            extra: Extra bytes to send (minimum 1 byte)
        """
        forces = np.asarray(forces, dtype=np.float64)
        virial = np.asarray(virial, dtype=np.float64)

        # Convert to atomic units
        energy_hartree = energy * EV_TO_HARTREE
        forces_au = forces / (HARTREE_TO_EV / BOHR_TO_ANGSTROM)
        virial_au = virial * EV_TO_HARTREE

        # Always send at least 1 byte to avoid confusion with closed socket
        if len(extra) == 0:
            extra = b"\x00"

        self._log(" send_forceready")
        self.sendmsg("FORCEREADY")
        self.send_array([energy_hartree], np.float64)
        self.send_array([len(forces)], np.int32)
        self.send_array(forces_au, np.float64)
        self.send_array(virial_au.T, np.float64)
        self.send_array([len(extra)], np.int32)
        self.send_array(np.frombuffer(extra, dtype=np.byte), np.byte)

    def send_exit(self):
        """Send EXIT message to terminate connection."""
        self._log(" send_exit")
        self.sendmsg("EXIT")
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
# -----------------------------------------------------------------------------
|
|
240
|
+
# Socket creation helpers
|
|
241
|
+
# -----------------------------------------------------------------------------
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def create_unix_socket_path(name: str) -> str:
    """
    Build the Unix domain socket path for ``name`` per the i-PI convention.

    i-PI places its sockets at /tmp/ipi_<name>. This helper only formats the
    string; it does not touch the filesystem.
    """
    prefix = "/tmp/ipi_"
    return "{}{}".format(prefix, name)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def create_server_socket(socket_path: str, timeout: float = None) -> socket.socket:
|
|
254
|
+
"""
|
|
255
|
+
Create and bind a Unix domain socket server.
|
|
256
|
+
|
|
257
|
+
Args:
|
|
258
|
+
socket_path: Path for the socket file
|
|
259
|
+
timeout: Optional timeout in seconds
|
|
260
|
+
|
|
261
|
+
Returns:
|
|
262
|
+
Bound socket ready for listen()
|
|
263
|
+
"""
|
|
264
|
+
import os
|
|
265
|
+
|
|
266
|
+
# Remove stale socket file if it exists
|
|
267
|
+
if os.path.exists(socket_path):
|
|
268
|
+
os.unlink(socket_path)
|
|
269
|
+
|
|
270
|
+
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
|
271
|
+
sock.bind(socket_path)
|
|
272
|
+
if timeout is not None:
|
|
273
|
+
sock.settimeout(timeout)
|
|
274
|
+
|
|
275
|
+
return sock
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def connect_unix_socket(
|
|
279
|
+
socket_path: str, timeout: float = None, max_retries: int = 50, retry_delay: float = 0.1
|
|
280
|
+
) -> socket.socket:
|
|
281
|
+
"""
|
|
282
|
+
Connect to a Unix domain socket server with retries.
|
|
283
|
+
|
|
284
|
+
Args:
|
|
285
|
+
socket_path: Path to the socket file
|
|
286
|
+
timeout: Optional timeout for the connection
|
|
287
|
+
max_retries: Maximum connection attempts
|
|
288
|
+
retry_delay: Delay between retries in seconds
|
|
289
|
+
|
|
290
|
+
Returns:
|
|
291
|
+
Connected socket
|
|
292
|
+
"""
|
|
293
|
+
import time
|
|
294
|
+
|
|
295
|
+
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
|
296
|
+
if timeout is not None:
|
|
297
|
+
sock.settimeout(timeout)
|
|
298
|
+
|
|
299
|
+
for attempt in range(max_retries):
|
|
300
|
+
try:
|
|
301
|
+
sock.connect(socket_path)
|
|
302
|
+
return sock
|
|
303
|
+
except (FileNotFoundError, ConnectionRefusedError):
|
|
304
|
+
if attempt < max_retries - 1:
|
|
305
|
+
time.sleep(retry_delay)
|
|
306
|
+
else:
|
|
307
|
+
raise
|
|
308
|
+
|
|
309
|
+
raise RuntimeError(f"Failed to connect to {socket_path} after {max_retries} attempts")
|
rootstock/server.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Socket server for Rootstock.
|
|
3
|
+
|
|
4
|
+
This runs in the main process and acts as an i-PI server,
|
|
5
|
+
sending atomic positions and receiving forces from a worker process.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
import os
|
|
10
|
+
import socket
|
|
11
|
+
import subprocess
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
|
|
16
|
+
from .protocol import (
|
|
17
|
+
IPIProtocol,
|
|
18
|
+
SocketClosed,
|
|
19
|
+
create_server_socket,
|
|
20
|
+
create_unix_socket_path,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class RootstockServer:
|
|
25
|
+
"""
|
|
26
|
+
Server that communicates with an MLIP worker process via i-PI protocol.
|
|
27
|
+
|
|
28
|
+
The server:
|
|
29
|
+
1. Creates a Unix domain socket
|
|
30
|
+
2. Launches a worker subprocess using pre-built environment
|
|
31
|
+
3. Accepts the worker's connection
|
|
32
|
+
4. Sends positions, receives forces
|
|
33
|
+
|
|
34
|
+
Example:
|
|
35
|
+
with RootstockServer(
|
|
36
|
+
env_name="mace_env",
|
|
37
|
+
model="medium",
|
|
38
|
+
device="cuda",
|
|
39
|
+
root=Path("/vol/rootstock"),
|
|
40
|
+
) as server:
|
|
41
|
+
energy, forces, virial = server.calculate(positions, cell, numbers)
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
def __init__(
|
|
45
|
+
self,
|
|
46
|
+
env_name: str,
|
|
47
|
+
model: str,
|
|
48
|
+
device: str = "cuda",
|
|
49
|
+
socket_name: str = "rootstock",
|
|
50
|
+
root: Path | None = None,
|
|
51
|
+
log=None,
|
|
52
|
+
timeout: float = 60.0,
|
|
53
|
+
):
|
|
54
|
+
"""
|
|
55
|
+
Initialize the server.
|
|
56
|
+
|
|
57
|
+
Args:
|
|
58
|
+
env_name: Name of pre-built environment (e.g., "mace_env")
|
|
59
|
+
model: Model identifier to pass to setup()
|
|
60
|
+
device: Device string to pass to setup()
|
|
61
|
+
socket_name: Name for the Unix socket (will be /tmp/ipi_<name>)
|
|
62
|
+
root: Root directory for environments and cache (required)
|
|
63
|
+
log: Optional file object for protocol logging
|
|
64
|
+
timeout: Socket timeout in seconds
|
|
65
|
+
"""
|
|
66
|
+
if root is None:
|
|
67
|
+
raise ValueError("root is required for pre-built environments")
|
|
68
|
+
|
|
69
|
+
self.socket_name = socket_name
|
|
70
|
+
self.socket_path = create_unix_socket_path(socket_name)
|
|
71
|
+
self.log = log
|
|
72
|
+
self.timeout = timeout
|
|
73
|
+
|
|
74
|
+
self.env_name = env_name
|
|
75
|
+
self.model = model
|
|
76
|
+
self.device = device
|
|
77
|
+
self.root = Path(root)
|
|
78
|
+
|
|
79
|
+
self._server_socket: socket.socket | None = None
|
|
80
|
+
self._client_socket: socket.socket | None = None
|
|
81
|
+
self._protocol: IPIProtocol | None = None
|
|
82
|
+
self._process: subprocess.Popen | None = None
|
|
83
|
+
self._connected = False
|
|
84
|
+
|
|
85
|
+
# Track INIT state
|
|
86
|
+
self._init_sent = False
|
|
87
|
+
self._init_numbers: list[int] | None = None
|
|
88
|
+
self._init_pbc: list[bool] | None = None
|
|
89
|
+
|
|
90
|
+
# Environment manager
|
|
91
|
+
self._env_manager = None
|
|
92
|
+
self._wrapper_path: Path | None = None
|
|
93
|
+
|
|
94
|
+
def start(self):
|
|
95
|
+
"""Start the server and launch the worker process."""
|
|
96
|
+
# Create server socket
|
|
97
|
+
self._server_socket = create_server_socket(self.socket_path, timeout=self.timeout)
|
|
98
|
+
self._server_socket.listen(1)
|
|
99
|
+
|
|
100
|
+
if self.log:
|
|
101
|
+
print(f"Server listening on {self.socket_path}", file=self.log, flush=True)
|
|
102
|
+
|
|
103
|
+
# Launch worker process
|
|
104
|
+
self._start_worker()
|
|
105
|
+
|
|
106
|
+
if self.log:
|
|
107
|
+
print(f"Launched worker process (PID {self._process.pid})", file=self.log, flush=True)
|
|
108
|
+
|
|
109
|
+
# Wait for worker to connect
|
|
110
|
+
self._accept_connection()
|
|
111
|
+
|
|
112
|
+
def _start_worker(self):
|
|
113
|
+
"""Start worker using pre-built environment."""
|
|
114
|
+
from .environment import EnvironmentManager
|
|
115
|
+
|
|
116
|
+
# Create environment manager
|
|
117
|
+
self._env_manager = EnvironmentManager(root=self.root)
|
|
118
|
+
|
|
119
|
+
# Generate wrapper script
|
|
120
|
+
self._wrapper_path = self._env_manager.generate_wrapper(
|
|
121
|
+
env_name=self.env_name,
|
|
122
|
+
model=self.model,
|
|
123
|
+
device=self.device,
|
|
124
|
+
socket_path=self.socket_path,
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
# Get spawn command and environment
|
|
128
|
+
cmd = self._env_manager.get_spawn_command(self.env_name, self._wrapper_path)
|
|
129
|
+
env = self._env_manager.get_environment_variables()
|
|
130
|
+
|
|
131
|
+
if self.log:
|
|
132
|
+
print(f"Spawning worker: {' '.join(cmd)}", file=self.log, flush=True)
|
|
133
|
+
|
|
134
|
+
self._process = subprocess.Popen(
|
|
135
|
+
cmd,
|
|
136
|
+
env=env,
|
|
137
|
+
stdout=subprocess.PIPE if not self.log else None,
|
|
138
|
+
stderr=subprocess.PIPE if not self.log else None,
|
|
139
|
+
)
|
|
140
|
+
|
|
141
|
+
def _accept_connection(self):
|
|
142
|
+
"""Accept connection from worker process."""
|
|
143
|
+
# Use short timeout for accept so we can check if process died
|
|
144
|
+
self._server_socket.settimeout(1.0)
|
|
145
|
+
|
|
146
|
+
while True:
|
|
147
|
+
try:
|
|
148
|
+
self._client_socket, addr = self._server_socket.accept()
|
|
149
|
+
break
|
|
150
|
+
except TimeoutError:
|
|
151
|
+
# Check if process died
|
|
152
|
+
if self._process.poll() is not None:
|
|
153
|
+
stdout, stderr = self._process.communicate()
|
|
154
|
+
raise RuntimeError(
|
|
155
|
+
f"Worker process died with code {self._process.returncode}.\n"
|
|
156
|
+
f"stdout: {stdout}\nstderr: {stderr}"
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
# Restore original timeout
|
|
160
|
+
self._server_socket.settimeout(self.timeout)
|
|
161
|
+
self._client_socket.settimeout(self.timeout)
|
|
162
|
+
|
|
163
|
+
self._protocol = IPIProtocol(self._client_socket, log=self.log)
|
|
164
|
+
self._connected = True
|
|
165
|
+
|
|
166
|
+
if self.log:
|
|
167
|
+
print("Worker connected", file=self.log, flush=True)
|
|
168
|
+
|
|
169
|
+
def calculate(
|
|
170
|
+
self,
|
|
171
|
+
positions: np.ndarray,
|
|
172
|
+
cell: np.ndarray,
|
|
173
|
+
atomic_numbers: np.ndarray | None = None,
|
|
174
|
+
pbc: list[bool] | None = None,
|
|
175
|
+
) -> tuple[float, np.ndarray, np.ndarray]:
|
|
176
|
+
"""
|
|
177
|
+
Calculate energy and forces for given atomic configuration.
|
|
178
|
+
|
|
179
|
+
Args:
|
|
180
|
+
positions: Nx3 array of atomic positions in Angstrom
|
|
181
|
+
cell: 3x3 cell matrix in Angstrom
|
|
182
|
+
atomic_numbers: Atomic numbers array (sent in INIT on first call)
|
|
183
|
+
pbc: Periodic boundary conditions [x, y, z] (sent in INIT on first call)
|
|
184
|
+
|
|
185
|
+
Returns:
|
|
186
|
+
energy: Potential energy in eV
|
|
187
|
+
forces: Nx3 forces in eV/Angstrom
|
|
188
|
+
virial: 3x3 virial tensor in eV
|
|
189
|
+
"""
|
|
190
|
+
if not self._connected:
|
|
191
|
+
raise RuntimeError("Server not connected. Call start() first.")
|
|
192
|
+
|
|
193
|
+
# Check worker status
|
|
194
|
+
self._protocol.send_status()
|
|
195
|
+
status = self._protocol.recv_status()
|
|
196
|
+
|
|
197
|
+
if status == "NEEDINIT":
|
|
198
|
+
# Send INIT with atomic species info
|
|
199
|
+
init_data = {
|
|
200
|
+
"numbers": atomic_numbers.tolist() if atomic_numbers is not None else None,
|
|
201
|
+
"pbc": [bool(p) for p in pbc] if pbc is not None else [True, True, True],
|
|
202
|
+
}
|
|
203
|
+
init_bytes = json.dumps(init_data).encode("utf-8")
|
|
204
|
+
self._protocol.send_init(bead_index=0, init_string=init_bytes)
|
|
205
|
+
|
|
206
|
+
# Track what we sent
|
|
207
|
+
self._init_sent = True
|
|
208
|
+
self._init_numbers = init_data["numbers"]
|
|
209
|
+
self._init_pbc = init_data["pbc"]
|
|
210
|
+
|
|
211
|
+
self._protocol.send_status()
|
|
212
|
+
status = self._protocol.recv_status()
|
|
213
|
+
|
|
214
|
+
if status != "READY":
|
|
215
|
+
raise RuntimeError(f"Worker not ready, status: {status}")
|
|
216
|
+
|
|
217
|
+
# Send positions
|
|
218
|
+
self._protocol.send_posdata(cell, positions)
|
|
219
|
+
|
|
220
|
+
# Check status - worker should now be calculating
|
|
221
|
+
self._protocol.send_status()
|
|
222
|
+
status = self._protocol.recv_status()
|
|
223
|
+
|
|
224
|
+
if status != "HAVEDATA":
|
|
225
|
+
raise RuntimeError(f"Worker failed to calculate, status: {status}")
|
|
226
|
+
|
|
227
|
+
# Get results
|
|
228
|
+
self._protocol.send_getforce()
|
|
229
|
+
energy, forces, virial, extra = self._protocol.recv_forceready()
|
|
230
|
+
|
|
231
|
+
return energy, forces, virial
|
|
232
|
+
|
|
233
|
+
def stop(self):
|
|
234
|
+
"""Stop the server and terminate the worker process."""
|
|
235
|
+
if self._protocol is not None:
|
|
236
|
+
try:
|
|
237
|
+
self._protocol.send_exit()
|
|
238
|
+
except (BrokenPipeError, SocketClosed):
|
|
239
|
+
pass
|
|
240
|
+
|
|
241
|
+
if self._client_socket is not None:
|
|
242
|
+
self._client_socket.close()
|
|
243
|
+
self._client_socket = None
|
|
244
|
+
|
|
245
|
+
if self._server_socket is not None:
|
|
246
|
+
self._server_socket.close()
|
|
247
|
+
self._server_socket = None
|
|
248
|
+
|
|
249
|
+
if self._process is not None:
|
|
250
|
+
self._process.terminate()
|
|
251
|
+
try:
|
|
252
|
+
self._process.wait(timeout=5.0)
|
|
253
|
+
except subprocess.TimeoutExpired:
|
|
254
|
+
self._process.kill()
|
|
255
|
+
self._process.wait()
|
|
256
|
+
self._process = None
|
|
257
|
+
|
|
258
|
+
# Clean up socket file
|
|
259
|
+
if os.path.exists(self.socket_path):
|
|
260
|
+
os.unlink(self.socket_path)
|
|
261
|
+
|
|
262
|
+
# Clean up wrapper script
|
|
263
|
+
if self._wrapper_path is not None:
|
|
264
|
+
try:
|
|
265
|
+
self._wrapper_path.unlink(missing_ok=True)
|
|
266
|
+
except Exception:
|
|
267
|
+
pass
|
|
268
|
+
self._wrapper_path = None
|
|
269
|
+
|
|
270
|
+
# Clean up environment manager
|
|
271
|
+
if self._env_manager is not None:
|
|
272
|
+
self._env_manager.cleanup()
|
|
273
|
+
self._env_manager = None
|
|
274
|
+
|
|
275
|
+
self._connected = False
|
|
276
|
+
self._protocol = None
|
|
277
|
+
|
|
278
|
+
if self.log:
|
|
279
|
+
print("Server stopped", file=self.log, flush=True)
|
|
280
|
+
|
|
281
|
+
def __enter__(self):
|
|
282
|
+
self.start()
|
|
283
|
+
return self
|
|
284
|
+
|
|
285
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
286
|
+
self.stop()
|
|
287
|
+
return False
|