hillclimber 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hillclimber might be problematic. Click here for more details.
- hillclimber/__init__.py +18 -0
- hillclimber/actions.py +30 -0
- hillclimber/calc.py +15 -0
- hillclimber/cvs.py +646 -0
- hillclimber/interfaces.py +92 -0
- hillclimber/metadynamics.py +249 -0
- hillclimber/nodes.py +6 -0
- hillclimber/selectors.py +96 -0
- hillclimber-0.1.0a1.dist-info/METADATA +412 -0
- hillclimber-0.1.0a1.dist-info/RECORD +13 -0
- hillclimber-0.1.0a1.dist-info/WHEEL +4 -0
- hillclimber-0.1.0a1.dist-info/entry_points.txt +5 -0
- hillclimber-0.1.0a1.dist-info/licenses/LICENSE +251 -0
hillclimber/cvs.py
ADDED
|
@@ -0,0 +1,646 @@
|
|
|
1
|
+
# --- IMPORTS ---
# Standard library
import warnings
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional, Tuple, Union

# Third-party
import rdkit2ase
from ase import Atoms
from PIL import Image
from rdkit import Chem
from rdkit.Chem import Draw

# Local
from hillclimber.interfaces import AtomSelector, CollectiveVariable
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# --- TYPE HINTS ---
|
|
18
|
+
# How a group of atoms is collapsed into the site(s) used by a CV.
GroupReductionStrategyType = Literal[
    "com",
    "cog",
    "first",
    "all",
    "com_per_group",
    "cog_per_group",
]

# How the multiple groups returned by two selectors are paired with each other.
MultiGroupStrategyType = Literal[
    "first",
    "all_pairs",
    "corresponding",
    "first_to_all",
]

# A PLUMED site: either a string (a 1-based atom index or a virtual-site /
# group label) or a raw list of 0-based atom indices.
SiteIdentifier = Union[str, List[int]]

# RGB color used for RDKit atom highlighting.
ColorTuple = Tuple[float, float, float]

# Global atom index -> highlight color.
AtomHighlightMap = Dict[int, ColorTuple]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# --- BASE CLASS FOR SHARED LOGIC ---
|
|
28
|
+
class _BasePlumedCV(CollectiveVariable):
    """An abstract base class for PLUMED CVs providing shared utilities.

    Subclasses implement ``_get_atom_highlights`` (consumed by ``get_img``)
    and reuse the static helpers below to build PLUMED command strings.
    """

    prefix: str

    def _get_atom_highlights(self, atoms: Atoms, **kwargs) -> Optional[AtomHighlightMap]:
        """
        Get atom indices and colors for visualization.

        This abstract method must be implemented by subclasses to define which atoms
        to highlight and with which colors.

        Args:
            atoms: The ASE Atoms object.
            **kwargs: Additional keyword arguments for specific implementations.

        Returns:
            A dictionary mapping global atom indices to their RGB highlight color,
            or None if selection fails.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @staticmethod
    def _draw_unhighlighted(mol) -> Image.Image:
        """Render ``mol`` as a single 400x400 panel without any highlights."""
        return Draw.MolsToGridImage(
            [mol],
            molsPerRow=1,
            subImgSize=(400, 400),
            useSVG=False,
        )

    def get_img(self, atoms: Atoms, **kwargs) -> Image.Image:
        """
        Generates an image of the molecule(s) with selected atoms highlighted.

        This method uses RDKit to render the image. It identifies the molecular
        fragments that contain highlighted atoms and draws them in a single row.
        A fragment is skipped when an already-drawn fragment has the same
        canonical SMILES and the same per-atom highlight colors. If nothing is
        highlighted, the whole system is drawn without highlights.

        Args:
            atoms: The ASE Atoms object to visualize.
            **kwargs: Additional arguments passed to _get_atom_highlights.

        Returns:
            A PIL Image object of the visualization.
        """
        highlight_map = self._get_atom_highlights(atoms, **kwargs)
        mol = rdkit2ase.ase2rdkit(atoms)

        if not highlight_map:
            return self._draw_unhighlighted(mol)

        mol_frags = Chem.GetMolFrags(mol, asMols=True)
        frag_indices_list = Chem.GetMolFrags(mol, asMols=False)

        mols_to_draw, highlights_to_draw, colors_to_draw = [], [], []
        seen_signatures = set()

        for frag_mol, frag_indices in zip(mol_frags, frag_indices_list):
            # Translate global (whole-system) indices into fragment-local ones.
            local_idx_map = {
                global_idx: local_idx
                for local_idx, global_idx in enumerate(frag_indices)
            }
            current_highlights = {
                local_idx_map[g_idx]: color
                for g_idx, color in highlight_map.items()
                if g_idx in local_idx_map
            }
            if not current_highlights:
                continue

            # Colors are part of the signature on purpose: two identical
            # molecules highlighted differently (e.g. the x1 and x2 groups of
            # a distance CV) must both be drawn, not collapsed into one panel.
            signature = (
                Chem.MolToSmiles(frag_mol),
                tuple(sorted(current_highlights.items())),
            )
            if signature in seen_signatures:
                continue
            seen_signatures.add(signature)

            mols_to_draw.append(frag_mol)
            highlights_to_draw.append(list(current_highlights))
            colors_to_draw.append(current_highlights)

        if not mols_to_draw:
            return self._draw_unhighlighted(mol)

        return Draw.MolsToGridImage(
            mols_to_draw,
            molsPerRow=len(mols_to_draw),
            subImgSize=(400, 400),
            highlightAtomLists=highlights_to_draw,
            highlightAtomColors=colors_to_draw,
            useSVG=False,
        )

    @staticmethod
    def _extract_labels(
        commands: List[str], prefix: str, cv_keyword: str
    ) -> List[str]:
        """Extracts generated CV labels from a list of PLUMED commands.

        A command contributes its label when it has the form
        ``<label>: <cv_keyword> ...`` and the label starts with ``prefix``.
        Only the action name (first token after the colon) is compared with
        ``cv_keyword``. This prevents helper commands (COM/CENTER/GROUP
        virtual sites) from being reported as CVs merely because their label,
        built from the prefix, happens to contain the keyword.
        """
        labels = []
        for cmd in commands:
            label, sep, body = cmd.partition(":")
            label = label.strip()
            tokens = body.split()
            if sep and tokens and tokens[0] == cv_keyword and label.startswith(prefix):
                labels.append(label)
        return labels

    @staticmethod
    def _get_index_pairs(
        len1: int, len2: int, strategy: MultiGroupStrategyType
    ) -> List[Tuple[int, int]]:
        """Determines pairs of group indices based on the multi-group strategy.

        Raises:
            ValueError: If ``strategy`` is not a known multi-group strategy.
        """
        if strategy == "first":
            return [(0, 0)] if len1 > 0 and len2 > 0 else []
        if strategy == "all_pairs":
            return [(i, j) for i in range(len1) for j in range(len2)]
        if strategy == "corresponding":
            n = min(len1, len2)
            return [(i, i) for i in range(n)]
        if strategy == "first_to_all":
            return [(0, j) for j in range(len2)] if len1 > 0 else []
        raise ValueError(f"Unknown multi-group strategy: {strategy}")

    @staticmethod
    def _create_virtual_site_command(
        group: List[int], strategy: Literal["com", "cog"], label: str
    ) -> str:
        """Creates a PLUMED command for a COM or CENTER virtual site.

        ``group`` holds 0-based atom indices; they are written 1-based, as
        PLUMED expects.

        Raises:
            ValueError: If ``group`` is empty.
        """
        if not group:
            raise ValueError("Cannot create a virtual site for an empty group.")
        atom_list = ",".join(str(idx + 1) for idx in group)
        cmd_keyword = "COM" if strategy == "com" else "CENTER"
        return f"{label}: {cmd_keyword} ATOMS={atom_list}"
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
# --- COLLECTIVE VARIABLE CLASSES ---
|
|
161
|
+
@dataclass
class DistanceCV(_BasePlumedCV):
    """
    PLUMED DISTANCE collective variable.

    Calculates the distance between two atoms or groups of atoms. This CV supports
    various strategies for reducing groups to single points (e.g., center of mass)
    and for pairing multiple groups.

    Attributes:
        x1: Selector for the first atom/group.
        x2: Selector for the second atom/group.
        prefix: Label prefix for the generated PLUMED commands.
        group_reduction: Strategy to reduce an atom group to a single point.
            "com_per_group" / "cog_per_group" are not supported by this CV.
        multi_group: Strategy for handling multiple groups from selectors.
        create_virtual_sites: If True, create explicit virtual sites for COM/COG.

    Resources:
        - https://www.plumed.org/doc-master/user-doc/html/DISTANCE.html
    """

    x1: AtomSelector
    x2: AtomSelector
    prefix: str
    group_reduction: GroupReductionStrategyType = "com"
    multi_group: MultiGroupStrategyType = "first"
    create_virtual_sites: bool = True

    def _get_atom_highlights(
        self, atoms: Atoms, **kwargs
    ) -> Optional[AtomHighlightMap]:
        """Color the atoms of the paired groups: red for x1, blue for x2, purple for both.

        With ``group_reduction="first"`` only the first atom of each paired group
        is highlighted; otherwise every atom of each paired group is.

        Returns:
            The highlight map, or None if a selection is empty or no atoms remain.
        """
        groups1 = self.x1.select(atoms)
        groups2 = self.x2.select(atoms)

        if not groups1 or not groups2:
            return None

        index_pairs = self._get_index_pairs(len(groups1), len(groups2), self.multi_group)
        if not index_pairs:
            return None

        indices1, indices2 = set(), set()
        for i, j in index_pairs:
            if self.group_reduction == "first":
                # Only the first atom of each (non-empty) group enters the distance.
                if groups1[i]:
                    indices1.add(groups1[i][0])
                if groups2[j]:
                    indices2.add(groups2[j][0])
            else:
                # com, cog, all: the whole group contributes.
                indices1.update(groups1[i])
                indices2.update(groups2[j])

        if not indices1 and not indices2:
            return None

        red, blue, purple = (1.0, 0.2, 0.2), (0.2, 0.2, 1.0), (1.0, 0.2, 1.0)
        highlights: AtomHighlightMap = {}
        for idx in indices1 | indices2:
            in1, in2 = idx in indices1, idx in indices2
            if in1 and in2:
                highlights[idx] = purple
            elif in1:
                highlights[idx] = red
            else:
                highlights[idx] = blue
        return highlights

    def to_plumed(self, atoms: Atoms) -> Tuple[List[str], List[str]]:
        """
        Generates PLUMED input strings for the DISTANCE CV.

        Returns:
            A tuple containing a list of CV labels and a list of PLUMED commands.

        Raises:
            ValueError: If a selection is empty, if the selections share atoms
                while ``group_reduction`` is neither "com" nor "cog", or if a
                group cannot be reduced to a site.
        """
        groups1 = self.x1.select(atoms)
        groups2 = self.x2.select(atoms)

        if not groups1 or not groups2:
            raise ValueError(f"Empty selection for distance CV '{self.prefix}'")

        flat1 = {idx for group in groups1 for idx in group}
        flat2 = {idx for group in groups2 for idx in group}
        if flat1 & flat2 and self.group_reduction not in ("com", "cog"):
            raise ValueError(
                "Overlapping atoms found. This is only valid with 'com' or 'cog' reduction."
            )

        commands = self._generate_commands(groups1, groups2)
        labels = self._extract_labels(commands, self.prefix, "DISTANCE")
        return labels, commands

    def _generate_commands(
        self, groups1: List[List[int]], groups2: List[List[int]]
    ) -> List[str]:
        """Generates all necessary PLUMED commands (virtual sites first, then DISTANCEs)."""
        commands: List[str] = []
        index_pairs = self._get_index_pairs(
            len(groups1), len(groups2), self.multi_group
        )

        # Reduce only the groups that actually take part in a pair.
        sites1, sites2 = {}, {}
        for i in sorted({i for i, _ in index_pairs}):
            site, site_cmds = self._reduce_group(groups1[i], f"{self.prefix}_g1_{i}")
            sites1[i] = site
            commands.extend(site_cmds)
        for j in sorted({j for _, j in index_pairs}):
            site, site_cmds = self._reduce_group(groups2[j], f"{self.prefix}_g2_{j}")
            sites2[j] = site
            commands.extend(site_cmds)

        for i, j in index_pairs:
            label = self.prefix if len(index_pairs) == 1 else f"{self.prefix}_{i}_{j}"
            commands.append(self._make_distance_command(sites1[i], sites2[j], label))

        return commands

    def _reduce_group(
        self, group: List[int], site_prefix: str
    ) -> Tuple[SiteIdentifier, List[str]]:
        """Reduces a single atom group to a site identifier based on strategy.

        Returns:
            ``(site, extra_commands)``. ``site`` is a 1-based atom index string,
            a virtual-site label, or the raw list of 0-based indices.
            ``extra_commands`` holds the PLUMED lines needed to define the site.

        Raises:
            ValueError: If the group is empty or the strategy is unsupported.
        """
        if not group:
            raise ValueError(
                f"Cannot reduce an empty atom group for distance CV "
                f"'{self.prefix}' (site '{site_prefix}')."
            )
        if len(group) == 1 or self.group_reduction == "first":
            return str(group[0] + 1), []
        if self.group_reduction == "all":
            return group, []

        if self.group_reduction in ("com", "cog"):
            if not self.create_virtual_sites:
                return group, []  # Use group directly if not creating virtual sites
            label = f"{site_prefix}_{self.group_reduction}"
            cmd = self._create_virtual_site_command(group, self.group_reduction, label)
            return label, [cmd]

        if self.group_reduction in ("com_per_group", "cog_per_group"):
            raise ValueError(
                f"Group reduction '{self.group_reduction}' is not supported by "
                "DistanceCV; use 'com' or 'cog' instead."
            )
        raise ValueError(f"Unknown group reduction strategy: {self.group_reduction}")

    def _make_distance_command(
        self, site1: SiteIdentifier, site2: SiteIdentifier, label: str
    ) -> str:
        """Creates a single PLUMED DISTANCE command string."""

        def _format(site):
            # Lists hold 0-based indices; string sites are already PLUMED-ready.
            return ",".join(str(s + 1) for s in site) if isinstance(site, list) else site

        s1_str, s2_str = _format(site1), _format(site2)
        if isinstance(site1, str) and isinstance(site2, str):
            return f"{label}: DISTANCE ATOMS={s1_str},{s2_str}"
        # NOTE(review): this branch is reached when a site is a raw atom list
        # (group_reduction="all", or com/cog with create_virtual_sites=False).
        # PLUMED documents numbered ATOMSn keywords as separate atom *pairs*,
        # so confirm this produces the intended group distance.
        return f"{label}: DISTANCE ATOMS1={s1_str} ATOMS2={s2_str}"
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
@dataclass
class CoordinationNumberCV(_BasePlumedCV):
    """
    PLUMED COORDINATION collective variable.

    Calculates a coordination number based on a switching function. It supports
    complex group definitions, including groups of virtual sites.

    Attributes:
        x1, x2: Selectors for the two groups of atoms.
        prefix: Label prefix for the generated PLUMED commands.
        r_0: The reference distance for the switching function (in Angstroms).
        nn, mm, d_0: Parameters for the switching function.
        group_reduction_1, group_reduction_2: Reduction strategies for each group.
        multi_group: Strategy for handling multiple groups from selectors.
        create_virtual_sites: If True, create explicit virtual sites for COM/COG.

    Resources:
        - https://www.plumed.org/doc-master/user-doc/html/COORDINATION.html
        - https://www.plumed.org/doc-master/user-doc/html/GROUP.html
    """

    x1: AtomSelector
    x2: AtomSelector
    prefix: str
    r_0: float
    nn: int = 6
    mm: int = 0
    d_0: float = 0.0
    group_reduction_1: GroupReductionStrategyType = "all"
    group_reduction_2: GroupReductionStrategyType = "all"
    multi_group: MultiGroupStrategyType = "first"
    create_virtual_sites: bool = True

    def _get_atom_highlights(
        self, atoms: Atoms, **kwargs
    ) -> Optional[AtomHighlightMap]:
        """Color every selected atom: red for x1, blue for x2, purple for both.

        Hydrogens are skipped unless ``highlight_hydrogens=True`` is passed.
        Returns None if a selection is empty or nothing remains after filtering.
        """
        highlight_hydrogens = kwargs.get("highlight_hydrogens", False)
        groups1 = self.x1.select(atoms)
        groups2 = self.x2.select(atoms)

        if not groups1 or not groups2:
            return None

        indices1 = {idx for g in groups1 for idx in g}
        indices2 = {idx for g in groups2 for idx in g}
        if not highlight_hydrogens:
            indices1 = {i for i in indices1 if atoms[i].symbol != "H"}
            indices2 = {i for i in indices2 if atoms[i].symbol != "H"}

        if not indices1 and not indices2:
            return None

        red, blue, purple = (1.0, 0.5, 0.5), (0.5, 0.5, 1.0), (1.0, 0.5, 1.0)
        highlights: AtomHighlightMap = {}
        for idx in indices1 | indices2:
            in1, in2 = idx in indices1, idx in indices2
            if in1 and in2:
                highlights[idx] = purple
            elif in1:
                highlights[idx] = red
            else:
                highlights[idx] = blue
        return highlights

    def to_plumed(self, atoms: Atoms) -> Tuple[List[str], List[str]]:
        """
        Generates PLUMED input strings for the COORDINATION CV.

        Returns:
            A tuple containing a list of CV labels and a list of PLUMED commands.

        Raises:
            ValueError: If a selection is empty, a group cannot be reduced, or
                ``multi_group`` is not a known strategy.
        """
        groups1 = self.x1.select(atoms)
        groups2 = self.x2.select(atoms)

        if not groups1 or not groups2:
            raise ValueError(f"Empty selection for coordination CV '{self.prefix}'")

        commands = self._generate_commands(groups1, groups2)
        labels = self._extract_labels(commands, self.prefix, "COORDINATION")
        return labels, commands

    def _generate_commands(
        self, groups1: List[List[int]], groups2: List[List[int]]
    ) -> List[str]:
        """Generates all necessary PLUMED commands (site definitions, then COORDINATIONs)."""
        commands: List[str] = []

        sites1 = self._reduce_groups(
            groups1, self.group_reduction_1, f"{self.prefix}_g1", commands
        )
        sites2 = self._reduce_groups(
            groups2, self.group_reduction_2, f"{self.prefix}_g2", commands
        )

        # Use the shared pairing helper so every CV interprets multi_group the
        # same way, and an unknown strategy raises instead of silently
        # producing no COORDINATION command.
        index_pairs = self._get_index_pairs(len(sites1), len(sites2), self.multi_group)
        for n, (i, j) in enumerate(index_pairs):
            label = self.prefix if len(index_pairs) == 1 else f"{self.prefix}_{n}"
            commands.append(self._make_coordination_command(sites1[i], sites2[j], label))

        return commands

    def _reduce_groups(
        self,
        groups: List[List[int]],
        strategy: GroupReductionStrategyType,
        site_prefix: str,
        commands: List[str],
    ) -> List[SiteIdentifier]:
        """Reduces a list of atom groups into a list of site identifiers.

        Any helper commands (COM/CENTER virtual sites, GROUP definitions) are
        appended to ``commands`` in place.

        Raises:
            ValueError: If the strategy is unsupported, requires virtual sites
                that are disabled, or would reduce an empty group.
        """
        if strategy in ("com_per_group", "cog_per_group"):
            if not self.create_virtual_sites:
                raise ValueError(f"'{strategy}' requires create_virtual_sites=True")

            reduction_type = "COM" if strategy == "com_per_group" else "CENTER"
            vsite_labels = []
            for i, group in enumerate(groups):
                if not group:
                    continue  # Skipped; labels keep the original group numbering.
                vsite_label = f"{site_prefix}_{i}"
                atom_list = ",".join(str(idx + 1) for idx in group)
                commands.append(f"{vsite_label}: {reduction_type} ATOMS={atom_list}")
                vsite_labels.append(vsite_label)

            if not vsite_labels:
                # An empty GROUP ATOMS= line would be invalid PLUMED input.
                raise ValueError(
                    f"'{strategy}' found no non-empty groups for '{site_prefix}'"
                )

            group_label = f"{site_prefix}_group"
            commands.append(f"{group_label}: GROUP ATOMS={','.join(vsite_labels)}")
            return [group_label]

        if strategy == "all":
            merged = sorted({idx for group in groups for idx in group})
            if not merged:
                raise ValueError(f"Empty atom selection for '{site_prefix}'")
            return [merged]

        # Other strategies reduce each group individually.
        sites: List[SiteIdentifier] = []
        for i, group in enumerate(groups):
            if not group:
                raise ValueError(f"Cannot reduce empty group {i} for '{site_prefix}'")
            if len(group) == 1 or strategy == "first":
                sites.append(str(group[0] + 1))
            elif strategy in ("com", "cog"):
                if self.create_virtual_sites:
                    label = f"{site_prefix}_{i}_{strategy}"
                    commands.append(self._create_virtual_site_command(group, strategy, label))
                    sites.append(label)
                else:
                    sites.append(group)
            else:
                raise ValueError(f"Unsupported reduction strategy: {strategy}")
        return sites

    def _make_coordination_command(
        self, site1: SiteIdentifier, site2: SiteIdentifier, label: str
    ) -> str:
        """Creates a single PLUMED COORDINATION command string."""

        def _format(site):
            # Lists hold 0-based indices; string sites are already PLUMED-ready.
            return ",".join(str(s + 1) for s in site) if isinstance(site, list) else site

        g_a, g_b = _format(site1), _format(site2)
        base_cmd = f"{label}: COORDINATION GROUPA={g_a}"
        if g_a != g_b:  # Omit GROUPB for self-coordination
            base_cmd += f" GROUPB={g_b}"

        # NOTE(review): r_0/d_0 are written verbatim. PLUMED reads them in its
        # input length units, so confirm a UNITS setting elsewhere matches the
        # Angstrom convention documented for r_0.
        params = f" R_0={self.r_0} NN={self.nn} D_0={self.d_0}"
        if self.mm != 0:
            params += f" MM={self.mm}"

        return base_cmd + params
|
|
502
|
+
|
|
503
|
+
|
|
504
|
+
@dataclass
class TorsionCV(_BasePlumedCV):
    """
    PLUMED TORSION collective variable.

    Calculates the torsional (dihedral) angle defined by four atoms. Each group
    provided by the selector must contain exactly four atoms.

    Attributes:
        atoms: Selector for one or more groups of 4 atoms.
        prefix: Label prefix for the generated PLUMED commands.
        multi_group: Strategy for handling multiple groups from the selector.
            "first" and "first_to_all" use only the first group; "all_pairs"
            and "corresponding" emit one TORSION per group.

    Resources:
        - https://www.plumed.org/doc-master/user-doc/html/TORSION.html
    """

    atoms: AtomSelector
    prefix: str
    multi_group: MultiGroupStrategyType = "first"

    def _get_atom_highlights(
        self, atoms: Atoms, **kwargs
    ) -> Optional[AtomHighlightMap]:
        """Color the first 4-atom group red, orange, yellow, green in order.

        Emits a ``UserWarning`` and returns None when the selection is empty
        or its first group does not have exactly four atoms.
        """
        groups = self.atoms.select(atoms)
        if not groups or len(groups[0]) != 4:
            # warnings (not print) so library users can filter or capture it.
            warnings.warn(
                "Warning: Torsion CV requires a group of 4 atoms for visualization.",
                stacklevel=2,
            )
            return None

        torsion_atoms = groups[0]
        colors = [
            (1.0, 0.2, 0.2),  # Red
            (1.0, 0.6, 0.2),  # Orange
            (1.0, 1.0, 0.2),  # Yellow
            (0.2, 1.0, 0.2),  # Green
        ]
        return dict(zip(torsion_atoms, colors))

    def to_plumed(self, atoms: Atoms) -> Tuple[List[str], List[str]]:
        """
        Generates PLUMED input strings for the TORSION CV.

        Returns:
            A tuple containing a list of CV labels and a list of PLUMED commands.

        Raises:
            ValueError: If the selection is empty or any group does not have
                exactly four atoms.
        """
        groups = self.atoms.select(atoms)
        if not groups:
            raise ValueError(f"Empty selection for torsion CV '{self.prefix}'")

        for i, group in enumerate(groups):
            if len(group) != 4:
                raise ValueError(
                    f"Torsion CV requires 4 atoms per group, but group {i} has {len(group)}."
                )

        commands = self._generate_commands(groups)
        labels = self._extract_labels(commands, self.prefix, "TORSION")
        return labels, commands

    def _generate_commands(self, groups: List[List[int]]) -> List[str]:
        """Generates one TORSION command per processed group (1-based atom indices)."""
        # For torsions, 'multi_group' only determines how many groups to process.
        if self.multi_group in ("first", "first_to_all") and groups:
            indices_to_process = [0]
        else:  # "all_pairs" and "corresponding" imply processing all independent groups.
            indices_to_process = list(range(len(groups)))

        commands = []
        for i in indices_to_process:
            label = self.prefix if len(indices_to_process) == 1 else f"{self.prefix}_{i}"
            atom_list = ",".join(str(idx + 1) for idx in groups[i])
            commands.append(f"{label}: TORSION ATOMS={atom_list}")
        return commands
|
|
578
|
+
|
|
579
|
+
|
|
580
|
+
# TODO: Check whether mass weighting must be configured explicitly for GYRATION; PLUMED may not know the atomic masses.
|
|
581
|
+
@dataclass
class RadiusOfGyrationCV(_BasePlumedCV):
    """
    PLUMED GYRATION collective variable.

    Calculates the radius of gyration of a group of atoms. The radius of gyration
    is a measure of the size of a molecular system.

    Attributes:
        atoms: Selector for the atoms to include in the gyration calculation.
        prefix: Label prefix for the generated PLUMED commands.
        multi_group: Strategy for handling multiple groups from the selector.
            "first" and "first_to_all" use only the first group; "all_pairs"
            and "corresponding" emit one GYRATION per group.
        type: The type of gyration tensor to use ("RADIUS" for scalar Rg, "GTPC_1", etc.).
            "RADIUS" is emitted without an explicit TYPE keyword.

    Resources:
        - https://www.plumed.org/doc-master/user-doc/html/GYRATION/
    """

    atoms: AtomSelector
    prefix: str
    multi_group: MultiGroupStrategyType = "first"
    type: str = "RADIUS"  # Options: RADIUS, GTPC_1, GTPC_2, GTPC_3, ASPHERICITY, ACYLINDRICITY, KAPPA2, etc.

    def _get_atom_highlights(
        self, atoms: Atoms, **kwargs
    ) -> Optional[AtomHighlightMap]:
        """Highlight every atom of the first selected group in green.

        Returns None if the selection or its first group is empty.
        """
        groups = self.atoms.select(atoms)
        if not groups or not groups[0]:
            return None

        return {atom_idx: (0.2, 0.8, 0.2) for atom_idx in groups[0]}  # Green

    def to_plumed(self, atoms: Atoms) -> Tuple[List[str], List[str]]:
        """
        Generates PLUMED input strings for the GYRATION CV.

        Returns:
            A tuple containing a list of CV labels and a list of PLUMED commands.

        Raises:
            ValueError: If the selection is empty or a processed group is empty.
        """
        groups = self.atoms.select(atoms)
        if not groups:
            raise ValueError(f"Empty selection for gyration CV '{self.prefix}'")

        commands = self._generate_commands(groups)
        labels = self._extract_labels(commands, self.prefix, "GYRATION")
        return labels, commands

    def _generate_commands(self, groups: List[List[int]]) -> List[str]:
        """Generates one GYRATION command per processed group (1-based atom indices).

        Raises:
            ValueError: If a processed group is empty, since that would produce
                an invalid ``GYRATION ATOMS=`` line.
        """
        # For gyration, 'multi_group' only determines how many groups to process.
        if self.multi_group in ("first", "first_to_all") and groups:
            indices_to_process = [0]
        else:  # "all_pairs" and "corresponding" imply processing all independent groups.
            indices_to_process = list(range(len(groups)))

        commands = []
        for i in indices_to_process:
            group = groups[i]
            if not group:
                raise ValueError(
                    f"Empty atom group {i} for gyration CV '{self.prefix}'"
                )
            label = self.prefix if len(indices_to_process) == 1 else f"{self.prefix}_{i}"
            atom_list = ",".join(str(idx + 1) for idx in group)
            command = f"{label}: GYRATION ATOMS={atom_list}"
            if self.type != "RADIUS":
                command += f" TYPE={self.type}"
            commands.append(command)
        return commands
|