rootstock 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rootstock/protocol.py ADDED
@@ -0,0 +1,309 @@
1
+ """
2
+ Minimal i-PI protocol implementation for Rootstock.
3
+
4
+ This is adapted from ASE's ase.calculators.socketio module.
5
+ The i-PI protocol is a simple binary protocol for communicating atomic
6
+ simulation data over sockets.
7
+
8
+ Protocol Overview:
9
+ - Commands are 12-byte ASCII strings (e.g., "STATUS", "POSDATA", "GETFORCE")
10
+ - Data is transmitted as raw numpy arrays (no JSON/msgpack serialization)
11
+ - Units are atomic units (Bohr, Hartree) - we convert to ASE units (Å, eV)
12
+
13
+ Reference: https://docs.ipi-code.org/
14
+ """
15
+
16
+ import socket
17
+
18
+ import numpy as np
19
+
20
# Unit conversion factors between i-PI atomic units (Bohr, Hartree) and the
# ASE-style units (Angstrom, eV) used by the public API of this module.
# NOTE(review): the two reference values look like they come from different
# CODATA releases (2014 Bohr radius, 2018 Hartree energy); confirm against
# ase.units if bit-for-bit agreement with ASE matters.

# Length: Angstrom per Bohr, and its reciprocal.
BOHR_TO_ANGSTROM: float = 0.52917721067
ANGSTROM_TO_BOHR: float = 1.0 / BOHR_TO_ANGSTROM

# Energy: eV per Hartree, and its reciprocal.
HARTREE_TO_EV: float = 27.211386245988
EV_TO_HARTREE: float = 1.0 / HARTREE_TO_EV
25
+
26
+
27
class SocketClosed(Exception):
    """Signal that the peer closed its end of the connection mid-read."""
31
+
32
+
33
class IPIProtocol:
    """
    Communication using the i-PI protocol.

    This handles the low-level socket communication, including:
    - Sending/receiving fixed-length command strings
    - Sending/receiving numpy arrays as raw bytes
    - Unit conversions between i-PI (atomic units) and ASE (Angstrom, eV)

    Arrays go on the wire as ``ndarray.tobytes()`` (C order, native byte
    order) and are read back with ``np.frombuffer``, so both peers must use
    the same byte order.
    """

    # Length in bytes of every i-PI command header.
    MSG_LEN = 12

    def __init__(self, sock: socket.socket, log=None):
        """
        Initialize protocol handler.

        Args:
            sock: Connected socket object
            log: Optional file object for logging (useful for debugging)
        """
        self.socket = sock
        self.log = log

    def _log(self, *args):
        """Write to log if logging is enabled."""
        if self.log is not None:
            print(*args, file=self.log, flush=True)

    # -------------------------------------------------------------------------
    # Low-level send/receive
    # -------------------------------------------------------------------------

    def sendmsg(self, msg: str):
        """
        Send a command string, space-padded to the 12-byte header length.

        Raises:
            ValueError: if ``msg`` encodes to more than 12 bytes. Sending it
                anyway would desynchronize the peer, which always reads
                exactly 12 bytes per header.
            UnicodeEncodeError: if ``msg`` is not ASCII.
        """
        self._log(f" sendmsg: {msg!r}")
        encoded = msg.encode("ascii")
        if len(encoded) > self.MSG_LEN:
            raise ValueError(f"i-PI message {msg!r} is longer than {self.MSG_LEN} bytes")
        self.socket.sendall(encoded.ljust(self.MSG_LEN))

    def recvmsg(self) -> str:
        """Receive a 12-byte command string with trailing whitespace stripped."""
        data = self._recvall(self.MSG_LEN)
        msg = data.rstrip().decode("ascii")
        self._log(f" recvmsg: {msg!r}")
        return msg

    def _recvall(self, nbytes: int) -> bytes:
        """
        Receive exactly ``nbytes``, looping over partial reads.

        Raises:
            SocketClosed: if the peer closes the connection before
                ``nbytes`` bytes have arrived.
        """
        chunks = []
        remaining = nbytes
        while remaining > 0:
            chunk = self.socket.recv(remaining)
            if not chunk:
                raise SocketClosed("Socket closed while receiving data")
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def send_array(self, arr, dtype):
        """Send ``arr`` converted to ``dtype`` as raw C-ordered bytes."""
        buf = np.asarray(arr, dtype=dtype).tobytes()
        self._log(f" send: {len(buf)} bytes ({dtype})")
        self.socket.sendall(buf)

    def recv_array(self, shape, dtype) -> np.ndarray:
        """Receive exactly enough raw bytes to fill a new array of ``shape``/``dtype``."""
        arr = np.empty(shape, dtype=dtype)
        nbytes = arr.nbytes
        buf = self._recvall(nbytes)
        arr.flat[:] = np.frombuffer(buf, dtype=dtype)
        self._log(f" recv: {nbytes} bytes ({dtype})")
        return arr

    # -------------------------------------------------------------------------
    # High-level protocol messages
    # -------------------------------------------------------------------------

    def send_status(self):
        """Send STATUS request."""
        self._log(" send_status")
        self.sendmsg("STATUS")

    def recv_status(self) -> str:
        """Receive status response (READY, HAVEDATA, or NEEDINIT)."""
        return self.recvmsg()

    def send_init(self, bead_index: int = 0, init_string: bytes = b"\x00"):
        """
        Send an INIT message: header, bead index, payload length, payload.

        Args:
            bead_index: Bead index (int32 on the wire)
            init_string: Raw payload bytes
        """
        self._log(" send_init")
        self.sendmsg("INIT")
        self.send_array([bead_index], np.int32)
        self.send_array([len(init_string)], np.int32)
        self.send_array(np.frombuffer(init_string, dtype=np.byte), np.byte)

    def recv_init(self) -> tuple[int, bytes]:
        """
        Receive the body of an INIT message.

        The "INIT" header itself is not read here; the caller must already
        have consumed it (e.g. with recvmsg()).

        Returns:
            bead_index: Bead index as a Python int
            init_bytes: Raw payload bytes
        """
        bead_index = int(self.recv_array(1, np.int32)[0])
        nbytes = int(self.recv_array(1, np.int32)[0])
        init_bytes = self.recv_array(nbytes, np.byte).tobytes()
        return bead_index, init_bytes

    def send_posdata(self, cell: np.ndarray, positions: np.ndarray):
        """
        Send atomic positions and cell.

        All conversions are done before the POSDATA header is sent. A local
        failure (bad input, pinv failure) therefore never leaves the peer
        waiting for a half-sent message.

        Args:
            cell: 3x3 cell matrix in Angstrom (any array-like)
            positions: Nx3 positions in Angstrom (any array-like)
        """
        cell = np.asarray(cell, dtype=np.float64)
        positions = np.asarray(positions, dtype=np.float64)

        # Convert to atomic units and transpose (i-PI convention). The second
        # matrix is the inverse of the transposed cell, in 1/Bohr.
        cell_bohr = cell.T * ANGSTROM_TO_BOHR
        icell_bohr = np.linalg.pinv(cell).T / ANGSTROM_TO_BOHR
        positions_bohr = positions * ANGSTROM_TO_BOHR

        self._log(" send_posdata")
        self.sendmsg("POSDATA")
        self.send_array(cell_bohr, np.float64)
        self.send_array(icell_bohr, np.float64)
        self.send_array([len(positions)], np.int32)
        self.send_array(positions_bohr, np.float64)

    def recv_posdata(self) -> tuple[np.ndarray, np.ndarray]:
        """
        Receive the body of a POSDATA message (header already consumed).

        Returns:
            cell: 3x3 cell matrix in Angstrom
            positions: Nx3 positions in Angstrom
        """
        cell_bohr = self.recv_array((3, 3), np.float64).T.copy()
        # The inverse cell must be read to keep the stream in sync, but it is
        # not used: it can be recomputed from the cell if ever needed.
        self.recv_array((3, 3), np.float64)
        natoms = int(self.recv_array(1, np.int32)[0])
        positions_bohr = self.recv_array((natoms, 3), np.float64)

        # Convert to ASE units
        cell = cell_bohr * BOHR_TO_ANGSTROM
        positions = positions_bohr * BOHR_TO_ANGSTROM
        return cell, positions

    def send_getforce(self):
        """Send GETFORCE request."""
        self._log(" send_getforce")
        self.sendmsg("GETFORCE")

    def recv_forceready(self) -> tuple[float, np.ndarray, np.ndarray, bytes]:
        """
        Receive force data after GETFORCE.

        Returns:
            energy: Potential energy in eV
            forces: Nx3 forces in eV/Angstrom
            virial: 3x3 virial tensor in eV
            extra: Extra bytes (empty if the peer sent none)

        Raises:
            RuntimeError: if the peer replies with anything but FORCEREADY.
                This is an explicit check rather than ``assert``, so it still
                runs under ``python -O``.
        """
        msg = self.recvmsg()
        if msg != "FORCEREADY":
            raise RuntimeError(f"Expected FORCEREADY, got {msg}")

        energy_hartree = self.recv_array(1, np.float64)[0]
        natoms = int(self.recv_array(1, np.int32)[0])
        forces_au = self.recv_array((natoms, 3), np.float64)
        virial_au = self.recv_array((3, 3), np.float64).T.copy()
        nextra = int(self.recv_array(1, np.int32)[0])
        extra = self.recv_array(nextra, np.byte).tobytes() if nextra > 0 else b""

        # Convert to ASE units
        energy = energy_hartree * HARTREE_TO_EV
        forces = forces_au * (HARTREE_TO_EV / BOHR_TO_ANGSTROM)
        virial = virial_au * HARTREE_TO_EV

        return energy, forces, virial, extra

    def send_forceready(
        self, energy: float, forces: np.ndarray, virial: np.ndarray, extra: bytes = b"\x00"
    ):
        """
        Send force data in response to GETFORCE.

        Conversions happen before the FORCEREADY header is sent, so a local
        failure cannot leave a partial message on the wire.

        Args:
            energy: Potential energy in eV
            forces: Nx3 forces in eV/Angstrom (any array-like)
            virial: 3x3 virial tensor in eV (any array-like)
            extra: Extra bytes to send; an empty payload is sent as one NUL byte
        """
        self._log(" send_forceready")

        # Convert to atomic units
        energy_hartree = energy * EV_TO_HARTREE
        forces_au = np.asarray(forces, dtype=np.float64) / (HARTREE_TO_EV / BOHR_TO_ANGSTROM)
        virial_au = np.asarray(virial, dtype=np.float64) * EV_TO_HARTREE

        # Always send at least 1 byte to avoid confusion with closed socket
        if not extra:
            extra = b"\x00"

        self.sendmsg("FORCEREADY")
        self.send_array([energy_hartree], np.float64)
        self.send_array([len(forces_au)], np.int32)
        self.send_array(forces_au, np.float64)
        self.send_array(virial_au.T, np.float64)
        self.send_array([len(extra)], np.int32)
        self.send_array(np.frombuffer(extra, dtype=np.byte), np.byte)

    def send_exit(self):
        """Send EXIT message to terminate connection."""
        self._log(" send_exit")
        self.sendmsg("EXIT")
237
+
238
+
239
+ # -----------------------------------------------------------------------------
240
+ # Socket creation helpers
241
+ # -----------------------------------------------------------------------------
242
+
243
+
244
def create_unix_socket_path(name: str) -> str:
    """
    Return the Unix-domain socket path used for the server called ``name``.

    The layout mirrors i-PI's own naming scheme: ``/tmp/ipi_<name>``.
    """
    return "/tmp/ipi_{}".format(name)
251
+
252
+
253
+ def create_server_socket(socket_path: str, timeout: float = None) -> socket.socket:
254
+ """
255
+ Create and bind a Unix domain socket server.
256
+
257
+ Args:
258
+ socket_path: Path for the socket file
259
+ timeout: Optional timeout in seconds
260
+
261
+ Returns:
262
+ Bound socket ready for listen()
263
+ """
264
+ import os
265
+
266
+ # Remove stale socket file if it exists
267
+ if os.path.exists(socket_path):
268
+ os.unlink(socket_path)
269
+
270
+ sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
271
+ sock.bind(socket_path)
272
+ if timeout is not None:
273
+ sock.settimeout(timeout)
274
+
275
+ return sock
276
+
277
+
278
+ def connect_unix_socket(
279
+ socket_path: str, timeout: float = None, max_retries: int = 50, retry_delay: float = 0.1
280
+ ) -> socket.socket:
281
+ """
282
+ Connect to a Unix domain socket server with retries.
283
+
284
+ Args:
285
+ socket_path: Path to the socket file
286
+ timeout: Optional timeout for the connection
287
+ max_retries: Maximum connection attempts
288
+ retry_delay: Delay between retries in seconds
289
+
290
+ Returns:
291
+ Connected socket
292
+ """
293
+ import time
294
+
295
+ sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
296
+ if timeout is not None:
297
+ sock.settimeout(timeout)
298
+
299
+ for attempt in range(max_retries):
300
+ try:
301
+ sock.connect(socket_path)
302
+ return sock
303
+ except (FileNotFoundError, ConnectionRefusedError):
304
+ if attempt < max_retries - 1:
305
+ time.sleep(retry_delay)
306
+ else:
307
+ raise
308
+
309
+ raise RuntimeError(f"Failed to connect to {socket_path} after {max_retries} attempts")
rootstock/server.py ADDED
@@ -0,0 +1,287 @@
1
+ """
2
+ Socket server for Rootstock.
3
+
4
+ This runs in the main process and acts as an i-PI server,
5
+ sending atomic positions and receiving forces from a worker process.
6
+ """
7
+
8
+ import json
9
+ import os
10
+ import socket
11
+ import subprocess
12
+ from pathlib import Path
13
+
14
+ import numpy as np
15
+
16
+ from .protocol import (
17
+ IPIProtocol,
18
+ SocketClosed,
19
+ create_server_socket,
20
+ create_unix_socket_path,
21
+ )
22
+
23
+
24
+ class RootstockServer:
25
+ """
26
+ Server that communicates with an MLIP worker process via i-PI protocol.
27
+
28
+ The server:
29
+ 1. Creates a Unix domain socket
30
+ 2. Launches a worker subprocess using pre-built environment
31
+ 3. Accepts the worker's connection
32
+ 4. Sends positions, receives forces
33
+
34
+ Example:
35
+ with RootstockServer(
36
+ env_name="mace_env",
37
+ model="medium",
38
+ device="cuda",
39
+ root=Path("/vol/rootstock"),
40
+ ) as server:
41
+ energy, forces, virial = server.calculate(positions, cell, numbers)
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ env_name: str,
47
+ model: str,
48
+ device: str = "cuda",
49
+ socket_name: str = "rootstock",
50
+ root: Path | None = None,
51
+ log=None,
52
+ timeout: float = 60.0,
53
+ ):
54
+ """
55
+ Initialize the server.
56
+
57
+ Args:
58
+ env_name: Name of pre-built environment (e.g., "mace_env")
59
+ model: Model identifier to pass to setup()
60
+ device: Device string to pass to setup()
61
+ socket_name: Name for the Unix socket (will be /tmp/ipi_<name>)
62
+ root: Root directory for environments and cache (required)
63
+ log: Optional file object for protocol logging
64
+ timeout: Socket timeout in seconds
65
+ """
66
+ if root is None:
67
+ raise ValueError("root is required for pre-built environments")
68
+
69
+ self.socket_name = socket_name
70
+ self.socket_path = create_unix_socket_path(socket_name)
71
+ self.log = log
72
+ self.timeout = timeout
73
+
74
+ self.env_name = env_name
75
+ self.model = model
76
+ self.device = device
77
+ self.root = Path(root)
78
+
79
+ self._server_socket: socket.socket | None = None
80
+ self._client_socket: socket.socket | None = None
81
+ self._protocol: IPIProtocol | None = None
82
+ self._process: subprocess.Popen | None = None
83
+ self._connected = False
84
+
85
+ # Track INIT state
86
+ self._init_sent = False
87
+ self._init_numbers: list[int] | None = None
88
+ self._init_pbc: list[bool] | None = None
89
+
90
+ # Environment manager
91
+ self._env_manager = None
92
+ self._wrapper_path: Path | None = None
93
+
94
+ def start(self):
95
+ """Start the server and launch the worker process."""
96
+ # Create server socket
97
+ self._server_socket = create_server_socket(self.socket_path, timeout=self.timeout)
98
+ self._server_socket.listen(1)
99
+
100
+ if self.log:
101
+ print(f"Server listening on {self.socket_path}", file=self.log, flush=True)
102
+
103
+ # Launch worker process
104
+ self._start_worker()
105
+
106
+ if self.log:
107
+ print(f"Launched worker process (PID {self._process.pid})", file=self.log, flush=True)
108
+
109
+ # Wait for worker to connect
110
+ self._accept_connection()
111
+
112
+ def _start_worker(self):
113
+ """Start worker using pre-built environment."""
114
+ from .environment import EnvironmentManager
115
+
116
+ # Create environment manager
117
+ self._env_manager = EnvironmentManager(root=self.root)
118
+
119
+ # Generate wrapper script
120
+ self._wrapper_path = self._env_manager.generate_wrapper(
121
+ env_name=self.env_name,
122
+ model=self.model,
123
+ device=self.device,
124
+ socket_path=self.socket_path,
125
+ )
126
+
127
+ # Get spawn command and environment
128
+ cmd = self._env_manager.get_spawn_command(self.env_name, self._wrapper_path)
129
+ env = self._env_manager.get_environment_variables()
130
+
131
+ if self.log:
132
+ print(f"Spawning worker: {' '.join(cmd)}", file=self.log, flush=True)
133
+
134
+ self._process = subprocess.Popen(
135
+ cmd,
136
+ env=env,
137
+ stdout=subprocess.PIPE if not self.log else None,
138
+ stderr=subprocess.PIPE if not self.log else None,
139
+ )
140
+
141
+ def _accept_connection(self):
142
+ """Accept connection from worker process."""
143
+ # Use short timeout for accept so we can check if process died
144
+ self._server_socket.settimeout(1.0)
145
+
146
+ while True:
147
+ try:
148
+ self._client_socket, addr = self._server_socket.accept()
149
+ break
150
+ except TimeoutError:
151
+ # Check if process died
152
+ if self._process.poll() is not None:
153
+ stdout, stderr = self._process.communicate()
154
+ raise RuntimeError(
155
+ f"Worker process died with code {self._process.returncode}.\n"
156
+ f"stdout: {stdout}\nstderr: {stderr}"
157
+ )
158
+
159
+ # Restore original timeout
160
+ self._server_socket.settimeout(self.timeout)
161
+ self._client_socket.settimeout(self.timeout)
162
+
163
+ self._protocol = IPIProtocol(self._client_socket, log=self.log)
164
+ self._connected = True
165
+
166
+ if self.log:
167
+ print("Worker connected", file=self.log, flush=True)
168
+
169
+ def calculate(
170
+ self,
171
+ positions: np.ndarray,
172
+ cell: np.ndarray,
173
+ atomic_numbers: np.ndarray | None = None,
174
+ pbc: list[bool] | None = None,
175
+ ) -> tuple[float, np.ndarray, np.ndarray]:
176
+ """
177
+ Calculate energy and forces for given atomic configuration.
178
+
179
+ Args:
180
+ positions: Nx3 array of atomic positions in Angstrom
181
+ cell: 3x3 cell matrix in Angstrom
182
+ atomic_numbers: Atomic numbers array (sent in INIT on first call)
183
+ pbc: Periodic boundary conditions [x, y, z] (sent in INIT on first call)
184
+
185
+ Returns:
186
+ energy: Potential energy in eV
187
+ forces: Nx3 forces in eV/Angstrom
188
+ virial: 3x3 virial tensor in eV
189
+ """
190
+ if not self._connected:
191
+ raise RuntimeError("Server not connected. Call start() first.")
192
+
193
+ # Check worker status
194
+ self._protocol.send_status()
195
+ status = self._protocol.recv_status()
196
+
197
+ if status == "NEEDINIT":
198
+ # Send INIT with atomic species info
199
+ init_data = {
200
+ "numbers": atomic_numbers.tolist() if atomic_numbers is not None else None,
201
+ "pbc": [bool(p) for p in pbc] if pbc is not None else [True, True, True],
202
+ }
203
+ init_bytes = json.dumps(init_data).encode("utf-8")
204
+ self._protocol.send_init(bead_index=0, init_string=init_bytes)
205
+
206
+ # Track what we sent
207
+ self._init_sent = True
208
+ self._init_numbers = init_data["numbers"]
209
+ self._init_pbc = init_data["pbc"]
210
+
211
+ self._protocol.send_status()
212
+ status = self._protocol.recv_status()
213
+
214
+ if status != "READY":
215
+ raise RuntimeError(f"Worker not ready, status: {status}")
216
+
217
+ # Send positions
218
+ self._protocol.send_posdata(cell, positions)
219
+
220
+ # Check status - worker should now be calculating
221
+ self._protocol.send_status()
222
+ status = self._protocol.recv_status()
223
+
224
+ if status != "HAVEDATA":
225
+ raise RuntimeError(f"Worker failed to calculate, status: {status}")
226
+
227
+ # Get results
228
+ self._protocol.send_getforce()
229
+ energy, forces, virial, extra = self._protocol.recv_forceready()
230
+
231
+ return energy, forces, virial
232
+
233
+ def stop(self):
234
+ """Stop the server and terminate the worker process."""
235
+ if self._protocol is not None:
236
+ try:
237
+ self._protocol.send_exit()
238
+ except (BrokenPipeError, SocketClosed):
239
+ pass
240
+
241
+ if self._client_socket is not None:
242
+ self._client_socket.close()
243
+ self._client_socket = None
244
+
245
+ if self._server_socket is not None:
246
+ self._server_socket.close()
247
+ self._server_socket = None
248
+
249
+ if self._process is not None:
250
+ self._process.terminate()
251
+ try:
252
+ self._process.wait(timeout=5.0)
253
+ except subprocess.TimeoutExpired:
254
+ self._process.kill()
255
+ self._process.wait()
256
+ self._process = None
257
+
258
+ # Clean up socket file
259
+ if os.path.exists(self.socket_path):
260
+ os.unlink(self.socket_path)
261
+
262
+ # Clean up wrapper script
263
+ if self._wrapper_path is not None:
264
+ try:
265
+ self._wrapper_path.unlink(missing_ok=True)
266
+ except Exception:
267
+ pass
268
+ self._wrapper_path = None
269
+
270
+ # Clean up environment manager
271
+ if self._env_manager is not None:
272
+ self._env_manager.cleanup()
273
+ self._env_manager = None
274
+
275
+ self._connected = False
276
+ self._protocol = None
277
+
278
+ if self.log:
279
+ print("Server stopped", file=self.log, flush=True)
280
+
281
+ def __enter__(self):
282
+ self.start()
283
+ return self
284
+
285
+ def __exit__(self, exc_type, exc_val, exc_tb):
286
+ self.stop()
287
+ return False