sawnergy 1.0.3__py3-none-any.whl → 1.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sawnergy/__init__.py +3 -1
- sawnergy/embedding/SGNS_pml.py +324 -51
- sawnergy/embedding/SGNS_torch.py +282 -39
- sawnergy/embedding/__init__.py +26 -1
- sawnergy/embedding/embedder.py +426 -203
- sawnergy/embedding/visualizer.py +251 -0
- sawnergy/logging_util.py +1 -1
- sawnergy/rin/rin_builder.py +4 -4
- sawnergy/visual/visualizer.py +6 -6
- sawnergy/visual/visualizer_util.py +3 -0
- sawnergy/walks/walker.py +43 -22
- {sawnergy-1.0.3.dist-info → sawnergy-1.0.9.dist-info}/METADATA +91 -57
- sawnergy-1.0.9.dist-info/RECORD +23 -0
- sawnergy-1.0.3.dist-info/RECORD +0 -22
- {sawnergy-1.0.3.dist-info → sawnergy-1.0.9.dist-info}/WHEEL +0 -0
- {sawnergy-1.0.3.dist-info → sawnergy-1.0.9.dist-info}/licenses/LICENSE +0 -0
- {sawnergy-1.0.3.dist-info → sawnergy-1.0.9.dist-info}/licenses/NOTICE +0 -0
- {sawnergy-1.0.3.dist-info → sawnergy-1.0.9.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
# third party
|
|
4
|
+
import numpy as np
|
|
5
|
+
import matplotlib as mpl
|
|
6
|
+
|
|
7
|
+
# built-in
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Sequence
|
|
10
|
+
import logging
|
|
11
|
+
|
|
12
|
+
# local
|
|
13
|
+
from ..visual import visualizer_util
|
|
14
|
+
from .. import sawnergy_util
|
|
15
|
+
|
|
16
|
+
# *----------------------------------------------------*
|
|
17
|
+
# GLOBALS
|
|
18
|
+
# *----------------------------------------------------*
|
|
19
|
+
|
|
20
|
+
# Module-level logger, named after this module via ``__name__``.
_logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
# *----------------------------------------------------*
|
|
23
|
+
# HELPERS
|
|
24
|
+
# *----------------------------------------------------*
|
|
25
|
+
|
|
26
|
+
def _safe_svd_pca(X: np.ndarray, k: int, *, row_l2: bool = False) -> tuple[np.ndarray, np.ndarray]:
|
|
27
|
+
"""Compute k principal directions via SVD and project onto them."""
|
|
28
|
+
if X.ndim != 2:
|
|
29
|
+
raise ValueError(f"PCA expects 2D array (N, D); got {X.shape}")
|
|
30
|
+
_, D = X.shape
|
|
31
|
+
if k not in (2, 3):
|
|
32
|
+
raise ValueError(f"PCA dimensionality must be 2 or 3; got {k}")
|
|
33
|
+
if D < k:
|
|
34
|
+
raise ValueError(f"Requested k={k} exceeds feature dim D={D}")
|
|
35
|
+
Xc = X - X.mean(axis=0, keepdims=True)
|
|
36
|
+
if row_l2:
|
|
37
|
+
norms = np.linalg.norm(Xc, axis=1, keepdims=True)
|
|
38
|
+
Xc = Xc / np.clip(norms, 1e-9, None)
|
|
39
|
+
_, _, Vt = np.linalg.svd(Xc, full_matrices=False)
|
|
40
|
+
comps = Vt[:k].copy()
|
|
41
|
+
proj = Xc @ comps.T
|
|
42
|
+
return proj, comps
|
|
43
|
+
|
|
44
|
+
def _set_equal_axes_3d(ax, xyz: np.ndarray, *, padding: float = 0.05) -> None:
|
|
45
|
+
if xyz.size == 0:
|
|
46
|
+
return
|
|
47
|
+
x, y, z = xyz[:, 0], xyz[:, 1], xyz[:, 2]
|
|
48
|
+
xmin, xmax = float(x.min()), float(x.max())
|
|
49
|
+
ymin, ymax = float(y.min()), float(y.max())
|
|
50
|
+
zmin, zmax = float(z.min()), float(z.max())
|
|
51
|
+
xr = xmax - xmin
|
|
52
|
+
yr = ymax - ymin
|
|
53
|
+
zr = zmax - zmin
|
|
54
|
+
r = max(xr, yr, zr)
|
|
55
|
+
pad = padding * (r if r > 0 else 1.0)
|
|
56
|
+
cx, cy, cz = (xmin + xmax) / 2.0, (ymin + ymax) / 2.0, (zmin + zmax) / 2.0
|
|
57
|
+
ax.set_xlim(cx - r / 2 - pad, cx + r / 2 + pad)
|
|
58
|
+
ax.set_ylim(cy - r / 2 - pad, cy + r / 2 + pad)
|
|
59
|
+
ax.set_zlim(cz - r / 2 - pad, cz + r / 2 + pad)
|
|
60
|
+
try:
|
|
61
|
+
ax.set_box_aspect([1, 1, 1])
|
|
62
|
+
except Exception:
|
|
63
|
+
pass
|
|
64
|
+
|
|
65
|
+
# *----------------------------------------------------*
|
|
66
|
+
# CLASS
|
|
67
|
+
# *----------------------------------------------------*
|
|
68
|
+
|
|
69
|
+
class Visualizer:
    """3D PCA visualizer for per-frame embeddings.

    Reads a ``(T, N, D)`` embeddings array (T frames, N nodes, D features; the
    rank is validated on load) from a ``sawnergy_util.ArrayStorage`` archive.
    Each frame is rendered as a 3D scatter of the nodes' PCA coordinates.
    Frame ids and node ids in the public API are 1-based.
    """

    # Class-wide flag: True until the first instance has run the one-time
    # matplotlib warm-up in ``__init__``.
    no_instances: bool = True

    def __init__(
        self,
        EMB_path: str | Path,
        figsize: tuple[int, int] = (9, 7),
        default_node_color: str = visualizer_util.GRAY,
        depthshade: bool = False,
        antialiased: bool = False,
        init_elev: float = 35,
        init_azim: float = 45,
        *,
        show: bool = False,
        normalize_rows: bool = False,
    ) -> None:
        """Load the embeddings and create an (as yet empty) figure.

        Args:
            EMB_path: Path to the embeddings archive. Its
                ``frame_embeddings_name`` attribute names the dataset to read.
            figsize: Matplotlib figure size.
            default_node_color: Color handed to
                ``visualizer_util.map_groups_to_colors`` as the default.
            depthshade: Forwarded to the 3D scatter.
            antialiased: Forwarded to the 3D scatter.
            init_elev: Camera elevation, re-applied on every render.
            init_azim: Camera azimuth, re-applied on every render.
            show: Forwarded to ``visualizer_util.ensure_backend`` before
                pyplot is imported.
            normalize_rows: L2-normalise centred rows before PCA.

        Raises:
            ValueError: if the archive names no embeddings dataset, or the
                dataset is not 3D.
        """
        # Backend selection must happen before pyplot is imported.
        visualizer_util.ensure_backend(show)
        import matplotlib.pyplot as plt
        self._plt = plt

        if Visualizer.no_instances:
            try:
                visualizer_util.warm_start_matplotlib()
            finally:
                # Attempt the warm-up at most once per process, even if it failed.
                Visualizer.no_instances = False

        # Load embeddings archive
        EMB_path = Path(EMB_path)
        with sawnergy_util.ArrayStorage(EMB_path, mode="r") as storage:
            name = storage.get_attr("frame_embeddings_name")
            if name is None:
                raise ValueError(
                    f"{EMB_path} has no 'frame_embeddings_name' attribute; "
                    "is it an embeddings archive?"
                )
            E = np.asarray(storage.read(name, slice(None)))
        if E.ndim != 3:
            raise ValueError(f"Expected embeddings of shape (T,N,D); got {E.shape}")
        self.E = E
        self.T, self.N, self.D = map(int, self.E.shape)
        _logger.info("Loaded embeddings: T=%d, N=%d, D=%d", self.T, self.N, self.D)

        # Coloring normalizer (parity with RIN Visualizer): maps 0-based node
        # indices onto [0, 1] for colormap lookups.
        self._residue_norm = mpl.colors.Normalize(0, max(1, self.N - 1))

        # Figure now; axes and scatter are created lazily by _ensure_axes().
        self._fig = self._plt.figure(figsize=figsize, num="SAWNERGY")
        self._ax = None
        self._scatter = None
        self._labels: list = []  # text artists from the most recent render
        self._marker_size = 30.0
        self._init_elev = init_elev
        self._init_azim = init_azim
        self.default_node_color = default_node_color
        self._antialiased = bool(antialiased)
        self._depthshade = bool(depthshade)
        self._normalize_rows = bool(normalize_rows)

    # ------------------------------ PRIVATE ------------------------------ #

    def _ensure_axes(self) -> None:
        """Create the 3D axes and an empty scatter artist on first use."""
        if self._ax is not None and self._scatter is not None:
            return
        self._fig.clf()
        self._ax = self._fig.add_subplot(111, projection="3d")
        self._ax.view_init(self._init_elev, self._init_azim)
        self._scatter = self._ax.scatter(
            [], [], [],
            s=self._marker_size,
            depthshade=self._depthshade,
            edgecolors="none",
            antialiased=self._antialiased,
        )
        try:
            self._ax.set_axis_off()
        except Exception:
            pass  # purely cosmetic

    def _project3(self, X: np.ndarray) -> np.ndarray:
        """Return a 3D PCA projection of embeddings (always 3 coordinates).

        Args:
            X: Array of shape (n, D) — one embedding per row.

        Returns:
            Array of shape (n, 3). When fewer than 3 informative directions
            exist (D < 3 or n < 3), the missing coordinate(s) are 0. For
            D == 1 the single centred (optionally sign-normalised) feature
            becomes the first coordinate.
        """
        n, d = X.shape
        if n == 0:
            return np.zeros((0, 3), dtype=float)
        if d >= 2:
            P, _ = _safe_svd_pca(X, 3 if d >= 3 else 2, row_l2=self._normalize_rows)
        else:
            # PCA of a single feature is just that feature, centred; mirror
            # _safe_svd_pca's centring and optional row normalisation.
            P = X - X.mean(axis=0, keepdims=True)
            if self._normalize_rows:
                P = P / np.clip(np.abs(P), 1e-9, None)
        # The SVD yields at most min(n, d) directions; pad up to 3 columns.
        missing = 3 - P.shape[1]
        if missing > 0:
            P = np.hstack([P, np.zeros((n, missing), dtype=P.dtype)])
        return P

    def _select_nodes(self, displayed_nodes: Sequence[int] | str | None) -> np.ndarray:
        """Convert 1-based node ids into 0-based indices along the node axis.

        ``None`` or ``"ALL"`` selects every node. Otherwise ``displayed_nodes``
        must be an integer sequence or array of ids in ``[1, N]``; an empty
        sequence selects no nodes.

        Raises:
            TypeError: for any other string, or a non-integer sequence.
            IndexError: if an id lies outside ``[1, N]``.
        """
        # Check the string case explicitly: ``ndarray == "ALL"`` compares
        # elementwise, and its truth value is ambiguous for multi-element arrays.
        if displayed_nodes is None or (
            isinstance(displayed_nodes, str) and displayed_nodes == "ALL"
        ):
            return np.arange(self.N, dtype=np.int64)
        idx = np.asarray(displayed_nodes).reshape(-1)
        if idx.size == 0:
            return np.empty(0, dtype=np.int64)
        if idx.dtype.kind not in "iu":
            raise TypeError("displayed_nodes must be None, 'ALL', or an integer sequence.")
        if idx.min() < 1 or idx.max() > self.N:
            raise IndexError(f"displayed_nodes out of range [1,{self.N}]")
        return idx.astype(np.int64) - 1

    def _apply_colors(self, node_colors, idx: np.ndarray) -> np.ndarray:
        """Resolve per-node colors for the 0-based node indices ``idx``.

        The semantics match the RIN Visualizer:
        - ``str``: name of a colormap. Nodes are colored by their 0-based index
          normalised over ``[0, N-1]``.
        - ``None``: every node gets ``default_node_color`` via
          ``visualizer_util.map_groups_to_colors``.
        - ``(N, 3)`` or ``(N, 4)`` array: explicit per-node RGB(A).
        - anything else: forwarded to ``map_groups_to_colors`` as ``groups``.
        """
        if isinstance(node_colors, str):
            node_cmap = self._plt.get_cmap(node_colors)
            return node_cmap(self._residue_norm(idx))
        if node_colors is None:
            full = visualizer_util.map_groups_to_colors(
                N=self.N, groups=None, default_color=self.default_node_color, one_based=True
            )
            return np.asarray(full)[idx]
        arr = np.asarray(node_colors)
        if arr.ndim == 2 and arr.shape[0] == self.N and arr.shape[1] in (3, 4):
            return arr[idx]
        full = visualizer_util.map_groups_to_colors(
            N=self.N, groups=node_colors, default_color=self.default_node_color, one_based=True
        )
        return np.asarray(full)[idx]

    # ------------------------------ PUBLIC ------------------------------- #

    def build_frame(
        self,
        frame_id: int,
        *,
        node_colors: str | np.ndarray | None = "rainbow",
        displayed_nodes: Sequence[int] | str | None = "ALL",
        show_node_labels: bool = False,
        show: bool = False
    ) -> None:
        """Render a single frame as a PCA **3D** scatter (matches RIN Visualizer API).

        PCA is computed per frame, on the displayed nodes only.

        Args:
            frame_id: 1-based frame index in ``[1, T]``.
            node_colors: Colormap name, ``None``, per-node color array, or
                group spec (see ``_apply_colors``).
            displayed_nodes: ``"ALL"``/``None`` or 1-based node ids to draw.
            show_node_labels: Annotate each point with its 1-based node id.
                Labels from a previous render are always cleared.
            show: Call ``plt.show`` (blocking where supported) after drawing.

        Raises:
            IndexError: if ``frame_id`` or a node id is out of range.
            TypeError: if ``displayed_nodes`` has an unsupported type.
        """
        frame0 = int(frame_id) - 1
        if not (0 <= frame0 < self.T):
            raise IndexError(f"frame_id out of range [1,{self.T}]")
        self._ensure_axes()

        idx = self._select_nodes(displayed_nodes)
        X = self.E[frame0, idx, :]  # (n, D)
        P = self._project3(X)  # (n, 3)
        colors = self._apply_colors(node_colors, idx)

        x, y, z = P[:, 0], P[:, 1], P[:, 2]
        # No public setter exists for 3D scatter positions; _offsets3d is the
        # attribute the Path3DCollection draws from.
        self._scatter._offsets3d = (x, y, z)
        self._scatter.set_facecolors(colors)
        _set_equal_axes_3d(self._ax, P, padding=0.05)
        self._ax.view_init(self._init_elev, self._init_azim)

        # Drop labels from any previous render; otherwise they linger when
        # the frame, node selection, or show_node_labels changes.
        for txt in getattr(self, "_labels", []):
            try:
                txt.remove()
            except Exception:
                pass  # artist may already be detached (e.g. figure cleared)
        self._labels = []
        if show_node_labels:
            for p, nid in zip(P, idx + 1):
                self._labels.append(self._ax.text(p[0], p[1], p[2], str(int(nid)), fontsize=8))

        # Layout and redraw are cosmetic and best-effort.
        try:
            self._fig.tight_layout()
        except Exception:
            try:
                self._fig.subplots_adjust()
            except Exception:
                pass
        try:
            self._fig.canvas.draw_idle()
        except Exception:
            pass

        if show:
            try:
                self._plt.show(block=True)
            except TypeError:
                # Some backends' show() does not accept ``block``.
                self._plt.show()

    # convenience
    def savefig(self, path: str | Path, *, dpi: int = 150) -> None:
        """Save the current figure to ``path`` at the given ``dpi``."""
        self._fig.savefig(path, dpi=dpi)

    def close(self) -> None:
        """Close the underlying figure (best-effort; failures are logged)."""
        try:
            self._plt.close(self._fig)
        except Exception:
            _logger.debug("Visualizer.close(): closing the figure failed", exc_info=True)
|
249
|
+
|
|
250
|
+
|
|
251
|
+
# Public API of this module: ``from ... import *`` exports only ``Visualizer``.
__all__ = ["Visualizer"]
|
sawnergy/logging_util.py
CHANGED
sawnergy/rin/rin_builder.py
CHANGED
|
@@ -669,7 +669,7 @@ class RINBuilder:
|
|
|
669
669
|
molecule_of_interest: int,
|
|
670
670
|
frame_range: tuple[int, int] | None = None,
|
|
671
671
|
frame_batch_size: int = -1,
|
|
672
|
-
prune_low_energies_frac: float = 0.
|
|
672
|
+
prune_low_energies_frac: float = 0.85,
|
|
673
673
|
output_path: str | Path | None = None,
|
|
674
674
|
keep_prenormalized_energies: bool = True,
|
|
675
675
|
*,
|
|
@@ -690,9 +690,9 @@ class RINBuilder:
|
|
|
690
690
|
|
|
691
691
|
2. For each frame batch:
|
|
692
692
|
|
|
693
|
-
a) Run cpptraj
|
|
693
|
+
a) Run cpptraj 'pairwise' on atoms → EMAP + VMAP → sum (atomic matrix).
|
|
694
694
|
|
|
695
|
-
b) Project atomic → residue with
|
|
695
|
+
b) Project atomic → residue with 'R = Pᵀ @ A @ P'.
|
|
696
696
|
|
|
697
697
|
c) Post-process residue matrix:
|
|
698
698
|
split into (attractive, repulsive) channels,
|
|
@@ -700,7 +700,7 @@ class RINBuilder:
|
|
|
700
700
|
remove self-interactions,
|
|
701
701
|
symmetrize.
|
|
702
702
|
|
|
703
|
-
d. Optionally store **pre-normalized energies** (attractive or repulsive or both, depending on
|
|
703
|
+
d. Optionally store **pre-normalized energies** (attractive or repulsive or both, depending on 'include_<kind>').
|
|
704
704
|
|
|
705
705
|
e. Row-wise L1 normalize (directed transition probabilities) and store.
|
|
706
706
|
|
sawnergy/visual/visualizer.py
CHANGED
|
@@ -107,7 +107,7 @@ class Visualizer:
|
|
|
107
107
|
visualizer_util.ensure_backend(show)
|
|
108
108
|
import matplotlib.pyplot as plt
|
|
109
109
|
self._plt = plt
|
|
110
|
-
|
|
110
|
+
# ---------- WARM UP MPL ------------ #
|
|
111
111
|
_logger.debug("Visualizer.__init__ start | RIN_path=%s, figsize=%s, node_size=%s, edge_width=%s, depthshade=%s, antialiased=%s, init_view=(%s,%s)",
|
|
112
112
|
RIN_path, figsize, node_size, edge_width, depthshade, antialiased, init_elev, init_azim)
|
|
113
113
|
if Visualizer.no_instances:
|
|
@@ -116,7 +116,7 @@ class Visualizer:
|
|
|
116
116
|
else:
|
|
117
117
|
_logger.debug("Skipping warm-start (no_instances=False).")
|
|
118
118
|
|
|
119
|
-
|
|
119
|
+
# ---------- LOAD THE DATA ---------- #
|
|
120
120
|
with sawnergy_util.ArrayStorage(RIN_path, mode="r") as storage:
|
|
121
121
|
com_name = storage.get_attr("com_name")
|
|
122
122
|
attr_energies_name = storage.get_attr("attractive_energies_name")
|
|
@@ -135,7 +135,7 @@ class Visualizer:
|
|
|
135
135
|
self.N = np.size(self.COM_coords[0], axis=0)
|
|
136
136
|
_logger.debug("Computed N=%d", self.N)
|
|
137
137
|
|
|
138
|
-
|
|
138
|
+
# - SET UP THE CANVAS AND THE AXES - #
|
|
139
139
|
self._fig = plt.figure(figsize=figsize, num="SAWNERGY")
|
|
140
140
|
self._ax = self._fig.add_subplot(111, projection="3d")
|
|
141
141
|
self._fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
|
|
@@ -145,14 +145,14 @@ class Visualizer:
|
|
|
145
145
|
self._ax.set_axis_off()
|
|
146
146
|
_logger.debug("Figure and 3D axes initialized.")
|
|
147
147
|
|
|
148
|
-
|
|
148
|
+
# ------ SET UP PLOT ELEMENTS ------ #
|
|
149
149
|
self._scatter: PathCollection = self._ax.scatter([], [], [], s=node_size, depthshade=depthshade, edgecolors="none")
|
|
150
150
|
self._attr: Line3DCollection = Line3DCollection(np.empty((0,2,3)), linewidths=edge_width, antialiased=antialiased)
|
|
151
151
|
self._repuls: Line3DCollection = Line3DCollection(np.empty((0,2,3)), linewidths=edge_width, antialiased=antialiased)
|
|
152
152
|
self._ax.add_collection3d(self._attr); self._ax.add_collection3d(self._repuls) # set pointers to the attractive and repulsive collections
|
|
153
153
|
_logger.debug("Artists created | scatter(empty), attr_lines(empty), repuls_lines(empty).")
|
|
154
154
|
|
|
155
|
-
|
|
155
|
+
# ---------- HELPER FIELDS --------- #
|
|
156
156
|
# NOTE: 'under the hood' everything is 0-base indexed,
|
|
157
157
|
# BUT, from the API point of view, the indexing is 1-base,
|
|
158
158
|
# because amino acid residues are 1-base indexed.
|
|
@@ -160,7 +160,7 @@ class Visualizer:
|
|
|
160
160
|
self.default_node_color = default_node_color
|
|
161
161
|
_logger.debug("Helper fields set | residue_norm=[0,%d], default_node_color=%s", self.N-1, self.default_node_color)
|
|
162
162
|
|
|
163
|
-
|
|
163
|
+
# DISALLOW MPL WARM-UP IN THE FUTURE
|
|
164
164
|
Visualizer.no_instances = False
|
|
165
165
|
_logger.debug("Visualizer.no_instances set to False.")
|
|
166
166
|
|
|
@@ -319,6 +319,9 @@ def build_line_segments(
|
|
|
319
319
|
kept = edge_weights >= thresh
|
|
320
320
|
rows, cols = rows[kept], cols[kept]
|
|
321
321
|
|
|
322
|
+
nz = weights[rows, cols] > 0.0
|
|
323
|
+
rows, cols = rows[nz], cols[nz]
|
|
324
|
+
|
|
322
325
|
if rows.size == 0:
|
|
323
326
|
_logger.debug("build_line_segments: no edges kept after threshold; returning empties.")
|
|
324
327
|
return (np.empty((0, 2, 3), dtype=float),
|
sawnergy/walks/walker.py
CHANGED
|
@@ -63,7 +63,7 @@ class Walker:
|
|
|
63
63
|
|
|
64
64
|
# Load numpy arrays from read-only storage
|
|
65
65
|
with sawnergy_util.ArrayStorage(RIN_path, mode="r") as storage:
|
|
66
|
-
attr_name
|
|
66
|
+
attr_name = storage.get_attr("attractive_transitions_name")
|
|
67
67
|
repuls_name = storage.get_attr("repulsive_transitions_name")
|
|
68
68
|
attr_matrices : np.ndarray | None = (
|
|
69
69
|
storage.read(attr_name, slice(None)) if attr_name is not None else None
|
|
@@ -133,6 +133,9 @@ class Walker:
|
|
|
133
133
|
) if repuls_matrices is not None else None
|
|
134
134
|
)
|
|
135
135
|
|
|
136
|
+
self._attr_owner_pid = os.getpid() if self.attr_matrices is not None else None
|
|
137
|
+
self._repuls_owner_pid = os.getpid() if self.repuls_matrices is not None else None
|
|
138
|
+
|
|
136
139
|
_logger.debug(
|
|
137
140
|
"SharedNDArray created | attr name=%r; repuls name=%r",
|
|
138
141
|
getattr(self.attr_matrices, "name", None),
|
|
@@ -159,38 +162,56 @@ class Walker:
|
|
|
159
162
|
|
|
160
163
|
# explicit resource cleanup
|
|
161
164
|
def close(self) -> None:
|
|
162
|
-
"""
|
|
165
|
+
"""Release shared-memory resources used by this Walker.
|
|
166
|
+
|
|
167
|
+
This method:
|
|
168
|
+
- Closes local handles to the shared-memory backed arrays
|
|
169
|
+
(`self.attr_matrices`, `self.repuls_matrices`) in **the current process**.
|
|
170
|
+
- If the current process is the **creator** of a segment (its PID matches
|
|
171
|
+
`_attr_owner_pid` / `_repuls_owner_pid`), it also **unlinks** that segment
|
|
172
|
+
so the OS can reclaim it once all handles are closed.
|
|
173
|
+
|
|
174
|
+
Behavior & guarantees
|
|
175
|
+
---------------------
|
|
176
|
+
- **Idempotent:** safe to call multiple times; subsequent calls are no-ops.
|
|
177
|
+
- **Multi-process aware:** non-creator processes only close their handles;
|
|
178
|
+
creators close **and** unlink. This prevents `resource_tracker` “leaked
|
|
179
|
+
shared_memory” warnings when using `ProcessPoolExecutor`/spawn.
|
|
180
|
+
- **Best-effort unlink:** `FileNotFoundError` during unlink (already unlinked
|
|
181
|
+
elsewhere) is swallowed.
|
|
182
|
+
- Invoked automatically by the context manager (`__exit__`) and destructor
|
|
183
|
+
(`__del__`), but it's fine to call explicitly.
|
|
184
|
+
|
|
185
|
+
After calling `close()`, any operation that relies on the shared arrays may
|
|
186
|
+
fail; treat the instance as finalized.
|
|
163
187
|
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
process (per ``sawnergy_util.is_main_process()``), also attempts to
|
|
167
|
-
unlink the underlying shared-memory segments (best-effort; suppresses
|
|
168
|
-
``FileNotFoundError`` if already unlinked elsewhere).
|
|
188
|
+
Returns:
|
|
189
|
+
None
|
|
169
190
|
"""
|
|
170
191
|
if self._memory_cleaned_up:
|
|
171
192
|
_logger.debug("close(): already cleaned up; returning")
|
|
172
193
|
return
|
|
173
|
-
_logger.debug("Closing Walker resources (
|
|
194
|
+
_logger.debug("Closing Walker resources (pid=%s)", os.getpid())
|
|
174
195
|
try:
|
|
175
196
|
if self.attr_matrices is not None:
|
|
176
197
|
self.attr_matrices.close()
|
|
177
198
|
if self.repuls_matrices is not None:
|
|
178
199
|
self.repuls_matrices.close()
|
|
179
200
|
_logger.debug("SharedNDArray handles closed")
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
201
|
+
|
|
202
|
+
# Unlink in whichever process actually CREATED the segment(s)
|
|
203
|
+
try:
|
|
204
|
+
if self.attr_matrices is not None and getattr(self, "_attr_owner_pid", None) == os.getpid():
|
|
205
|
+
self.attr_matrices.unlink()
|
|
206
|
+
except FileNotFoundError:
|
|
207
|
+
_logger.debug("attr SharedMemory already unlinked elsewhere")
|
|
208
|
+
|
|
209
|
+
try:
|
|
210
|
+
if self.repuls_matrices is not None and getattr(self, "_repuls_owner_pid", None) == os.getpid():
|
|
211
|
+
self.repuls_matrices.unlink()
|
|
212
|
+
except FileNotFoundError:
|
|
213
|
+
_logger.debug("repuls SharedMemory already unlinked elsewhere")
|
|
214
|
+
|
|
194
215
|
finally:
|
|
195
216
|
self._memory_cleaned_up = True
|
|
196
217
|
_logger.debug("Cleanup complete")
|