hillclimber 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hillclimber might be problematic; see the package's registry page for more details.

hillclimber/cvs.py ADDED
@@ -0,0 +1,646 @@
1
+ # --- IMPORTS ---
2
+ # Standard library
3
+ from dataclasses import dataclass
4
+ from typing import Dict, List, Literal, Optional, Tuple, Union
5
+
6
+ # Third-party
7
+ import rdkit2ase
8
+ from ase import Atoms
9
+ from PIL import Image
10
+ from rdkit import Chem
11
+ from rdkit.Chem import Draw
12
+
13
+ # Local
14
+ from hillclimber.interfaces import AtomSelector, CollectiveVariable
15
+
16
+
17
# --- TYPE HINTS ---
# Ways of collapsing an atom group into the site(s) a CV operates on.
GroupReductionStrategyType = Literal[
    "com",
    "cog",
    "first",
    "all",
    "com_per_group",
    "cog_per_group",
]
# Ways of pairing the groups returned by two selectors.
MultiGroupStrategyType = Literal[
    "first",
    "all_pairs",
    "corresponding",
    "first_to_all",
]
# A reduced site: either a string (a 1-based atom number or a PLUMED label)
# or a list of 0-based atom indices.
SiteIdentifier = Union[str, List[int]]
# An RGB color triple used for highlighting.
ColorTuple = Tuple[float, float, float]
# Atom index -> highlight color.
AtomHighlightMap = Dict[int, ColorTuple]
25
+
26
+
27
# --- BASE CLASS FOR SHARED LOGIC ---
class _BasePlumedCV(CollectiveVariable):
    """An abstract base class for PLUMED CVs providing shared utilities.

    Subclasses implement ``_get_atom_highlights`` (consumed by ``get_img``)
    and assemble their PLUMED input with the static helpers defined here.
    """

    prefix: str

    def _get_atom_highlights(
        self, atoms: Atoms, **kwargs
    ) -> Optional[AtomHighlightMap]:
        """
        Get atom indices and colors for visualization.

        This abstract method must be implemented by subclasses to define which
        atoms to highlight and with which colors.

        Args:
            atoms: The ASE Atoms object.
            **kwargs: Additional keyword arguments for specific implementations.

        Returns:
            A dictionary mapping global atom indices to their RGB highlight
            color, or None if selection fails.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def get_img(self, atoms: Atoms, **kwargs) -> Image.Image:
        """
        Generate an image of the molecule(s) with selected atoms highlighted.

        The structure is converted with ``rdkit2ase.ase2rdkit`` and split into
        fragments. Only fragments containing at least one highlighted atom are
        drawn, side by side in one row. A fragment whose canonical SMILES and
        highlighted local indices match an already drawn one is skipped. When
        nothing is highlighted, the whole structure is drawn without
        highlights.

        Args:
            atoms: The ASE Atoms object to visualize.
            **kwargs: Additional arguments passed to ``_get_atom_highlights``.

        Returns:
            A PIL Image object of the visualization.
        """

        def _draw_unhighlighted(molecule) -> Image.Image:
            # Fallback rendering used whenever there is nothing to highlight.
            return Draw.MolsToGridImage(
                [molecule],
                molsPerRow=1,
                subImgSize=(400, 400),
                useSVG=False,
            )

        highlight_map = self._get_atom_highlights(atoms, **kwargs)
        mol = rdkit2ase.ase2rdkit(atoms)

        if not highlight_map:
            return _draw_unhighlighted(mol)

        mol_frags = Chem.GetMolFrags(mol, asMols=True)
        frag_indices_list = Chem.GetMolFrags(mol, asMols=False)

        mols_to_draw, highlights_to_draw, colors_to_draw = [], [], []
        seen_molecules = set()

        for frag_mol, frag_indices in zip(mol_frags, frag_indices_list):
            # Translate global atom indices into fragment-local indices.
            # NOTE(review): assumes frag_mol's atom order follows frag_indices;
            # confirm against RDKit's GetMolFrags documentation.
            local_idx_map = {
                global_idx: local_idx
                for local_idx, global_idx in enumerate(frag_indices)
            }
            current_highlights = {
                local_idx_map[g_idx]: color
                for g_idx, color in highlight_map.items()
                if g_idx in local_idx_map
            }
            if not current_highlights:
                continue

            # Identify a fragment by canonical SMILES + highlighted local indices
            # so that identical, identically highlighted copies are drawn once.
            molecule_signature = (
                Chem.MolToSmiles(frag_mol),
                tuple(sorted(current_highlights)),
            )
            if molecule_signature in seen_molecules:
                continue
            seen_molecules.add(molecule_signature)

            mols_to_draw.append(frag_mol)
            highlights_to_draw.append(list(current_highlights))
            colors_to_draw.append(current_highlights)

        if not mols_to_draw:
            return _draw_unhighlighted(mol)

        return Draw.MolsToGridImage(
            mols_to_draw,
            molsPerRow=len(mols_to_draw),
            subImgSize=(400, 400),
            highlightAtomLists=highlights_to_draw,
            highlightAtomColors=colors_to_draw,
            useSVG=False,
        )

    @staticmethod
    def _extract_labels(
        commands: List[str], prefix: str, cv_keyword: str
    ) -> List[str]:
        """
        Extract the generated CV labels from a list of PLUMED commands.

        A command ``"<label>: <ACTION> ..."`` contributes its label when its
        action token is exactly ``cv_keyword`` and the label equals ``prefix``
        or starts with ``f"{prefix}_"``. Matching the action token, rather
        than searching for ``cv_keyword`` anywhere in the line, keeps auxiliary
        commands (virtual sites, groups) out of the result. This holds even
        when the prefix itself contains the keyword.

        Args:
            commands: PLUMED command strings.
            prefix: Label prefix used by the CV.
            cv_keyword: PLUMED action name, e.g. "DISTANCE".

        Returns:
            The matching labels, in command order.
        """
        labels = []
        for cmd in commands:
            label, sep, body = cmd.partition(":")
            if not sep:
                continue
            tokens = body.split()
            if not tokens or tokens[0] != cv_keyword:
                continue
            label = label.strip()
            if label == prefix or label.startswith(f"{prefix}_"):
                labels.append(label)
        return labels

    @staticmethod
    def _get_index_pairs(
        len1: int, len2: int, strategy: MultiGroupStrategyType
    ) -> List[Tuple[int, int]]:
        """
        Determine pairs of group indices based on the multi-group strategy.

        Args:
            len1: Number of groups on the first side.
            len2: Number of groups on the second side.
            strategy: "first" -> (0, 0); "all_pairs" -> full Cartesian product;
                "corresponding" -> (i, i) up to the shorter side;
                "first_to_all" -> (0, j) for every j.

        Returns:
            The index pairs (empty if a required side has no groups).

        Raises:
            ValueError: If the strategy is unknown.
        """
        if strategy == "first":
            return [(0, 0)] if len1 > 0 and len2 > 0 else []
        if strategy == "all_pairs":
            return [(i, j) for i in range(len1) for j in range(len2)]
        if strategy == "corresponding":
            return [(i, i) for i in range(min(len1, len2))]
        if strategy == "first_to_all":
            return [(0, j) for j in range(len2)] if len1 > 0 else []
        raise ValueError(f"Unknown multi-group strategy: {strategy}")

    @staticmethod
    def _create_virtual_site_command(
        group: List[int], strategy: Literal["com", "cog"], label: str
    ) -> str:
        """
        Create a PLUMED COM (for "com") or CENTER (otherwise) virtual-site command.

        Atom indices are converted from 0-based to PLUMED's 1-based numbering.

        Raises:
            ValueError: If ``group`` is empty.
        """
        if not group:
            raise ValueError("Cannot create a virtual site for an empty group.")
        atom_list = ",".join(str(idx + 1) for idx in group)
        cmd_keyword = "COM" if strategy == "com" else "CENTER"
        return f"{label}: {cmd_keyword} ATOMS={atom_list}"
159
+
160
# --- REFACTORED CV CLASSES ---
@dataclass
class DistanceCV(_BasePlumedCV):
    """
    PLUMED DISTANCE collective variable.

    Calculates the distance between two atoms or groups of atoms. This CV supports
    various strategies for reducing groups to single points (e.g., center of mass)
    and for pairing multiple groups.

    Attributes:
        x1: Selector for the first atom/group.
        x2: Selector for the second atom/group.
        prefix: Label prefix for the generated PLUMED commands.
        group_reduction: Strategy to reduce an atom group to a single point.
            "com", "cog", "first" and "all" are handled here; other values
            raise ValueError when commands are generated.
        multi_group: Strategy for handling multiple groups from selectors.
        create_virtual_sites: If True, create explicit virtual sites for COM/COG.

    Resources:
        - https://www.plumed.org/doc-master/user-doc/html/DISTANCE.html
    """

    x1: AtomSelector
    x2: AtomSelector
    prefix: str
    group_reduction: GroupReductionStrategyType = "com"
    multi_group: MultiGroupStrategyType = "first"
    create_virtual_sites: bool = True

    def _get_atom_highlights(
        self, atoms: Atoms, **kwargs
    ) -> Optional[AtomHighlightMap]:
        """
        Select the atoms entering the distance(s) for visualization.

        Only groups paired under ``multi_group`` are considered. With
        ``group_reduction="first"`` just the first atom of each paired group is
        highlighted; for every other strategy whole groups are highlighted.

        Returns:
            Red for atoms from ``x1``, blue for atoms from ``x2``, purple for
            atoms in both, or None if nothing could be selected.
        """
        groups1 = self.x1.select(atoms)
        groups2 = self.x2.select(atoms)

        if not groups1 or not groups2:
            return None

        index_pairs = self._get_index_pairs(len(groups1), len(groups2), self.multi_group)
        if not index_pairs:
            return None

        indices1, indices2 = set(), set()
        for i, j in index_pairs:
            if self.group_reduction == "first":
                # Guard against empty groups before taking their first atom.
                if groups1[i]:
                    indices1.add(groups1[i][0])
                if groups2[j]:
                    indices2.add(groups2[j][0])
            else:
                # com, cog and all involve every atom of the group.
                indices1.update(groups1[i])
                indices2.update(groups2[j])

        if not indices1 and not indices2:
            return None

        highlights: AtomHighlightMap = {}
        red, blue, purple = (1.0, 0.2, 0.2), (0.2, 0.2, 1.0), (1.0, 0.2, 1.0)
        for idx in indices1 | indices2:
            in1, in2 = idx in indices1, idx in indices2
            if in1 and in2:
                highlights[idx] = purple
            elif in1:
                highlights[idx] = red
            else:
                highlights[idx] = blue
        return highlights

    def to_plumed(self, atoms: Atoms) -> Tuple[List[str], List[str]]:
        """
        Generates PLUMED input strings for the DISTANCE CV.

        Returns:
            A tuple containing a list of CV labels and a list of PLUMED commands.

        Raises:
            ValueError: If either selector returns no groups; if the selections
                share atoms while ``group_reduction`` is not "com"/"cog"; if a
                paired group is empty; or if a strategy is unknown.
        """
        groups1 = self.x1.select(atoms)
        groups2 = self.x2.select(atoms)

        if not groups1 or not groups2:
            raise ValueError(f"Empty selection for distance CV '{self.prefix}'")

        # Overlapping selections are rejected unless groups are reduced to a center.
        flat1 = {idx for group in groups1 for idx in group}
        flat2 = {idx for group in groups2 for idx in group}
        if flat1 & flat2 and self.group_reduction not in ["com", "cog"]:
            raise ValueError(
                "Overlapping atoms found. This is only valid with 'com' or 'cog' reduction."
            )

        commands = self._generate_commands(groups1, groups2)
        labels = self._extract_labels(commands, self.prefix, "DISTANCE")
        return labels, commands

    def _generate_commands(
        self, groups1: List[List[int]], groups2: List[List[int]]
    ) -> List[str]:
        """
        Generate all PLUMED commands: virtual sites first, then DISTANCE lines.

        Each group is reduced at most once, and only if it takes part in a pair.
        A single pair is labelled ``prefix``. Several pairs are labelled
        ``{prefix}_{i}_{j}`` using their group indices.
        """
        commands = []
        index_pairs = self._get_index_pairs(
            len(groups1), len(groups2), self.multi_group
        )

        sites1, sites2 = {}, {}
        for i in sorted({i for i, _ in index_pairs}):
            site, site_cmds = self._reduce_group(groups1[i], f"{self.prefix}_g1_{i}")
            sites1[i] = site
            commands.extend(site_cmds)
        for j in sorted({j for _, j in index_pairs}):
            site, site_cmds = self._reduce_group(groups2[j], f"{self.prefix}_g2_{j}")
            sites2[j] = site
            commands.extend(site_cmds)

        for i, j in index_pairs:
            label = self.prefix if len(index_pairs) == 1 else f"{self.prefix}_{i}_{j}"
            commands.append(self._make_distance_command(sites1[i], sites2[j], label))

        return commands

    def _reduce_group(
        self, group: List[int], site_prefix: str
    ) -> Tuple[SiteIdentifier, List[str]]:
        """
        Reduce one atom group to a site identifier based on ``group_reduction``.

        Returns:
            ``(site, commands)``. ``site`` is either a string (a 1-based atom
            number, or the label of a COM/CENTER virtual site) or a list of
            0-based atom indices. ``commands`` holds any PLUMED lines needed to
            define that site.

        Raises:
            ValueError: If the group is empty or the strategy is unsupported.
        """
        if not group:
            # An empty group has no atom to reference; without this check it
            # would raise IndexError or produce an empty ATOMS list.
            raise ValueError(
                f"Cannot reduce empty atom group '{site_prefix}' "
                f"for distance CV '{self.prefix}'"
            )
        if len(group) == 1 or self.group_reduction == "first":
            return str(group[0] + 1), []
        if self.group_reduction == "all":
            return group, []

        if self.group_reduction in ["com", "cog"]:
            if self.create_virtual_sites:
                label = f"{site_prefix}_{self.group_reduction}"
                cmd = self._create_virtual_site_command(
                    group, self.group_reduction, label
                )
                return label, [cmd]
            return group, []  # Use group directly if not creating virtual sites

        raise ValueError(f"Unknown group reduction strategy: {self.group_reduction}")

    def _make_distance_command(
        self, site1: SiteIdentifier, site2: SiteIdentifier, label: str
    ) -> str:
        """
        Create a single PLUMED DISTANCE command string.

        String sites are written verbatim. List sites are 0-based indices and
        are converted to 1-based numbers.
        """

        def _format(site: SiteIdentifier) -> str:
            if isinstance(site, list):
                return ",".join(str(s + 1) for s in site)
            return site

        s1_str, s2_str = _format(site1), _format(site2)
        # Two point-like sites form a single ATOMS pair.
        if isinstance(site1, str) and isinstance(site2, str):
            return f"{label}: DISTANCE ATOMS={s1_str},{s2_str}"
        # NOTE(review): PLUMED's numbered ATOMSn keywords each describe one atom
        # pair; confirm this group-based form yields the intended quantity.
        return f"{label}: DISTANCE ATOMS1={s1_str} ATOMS2={s2_str}"
321
+
322
+
323
@dataclass
class CoordinationNumberCV(_BasePlumedCV):
    """
    PLUMED COORDINATION collective variable.

    Calculates a coordination number based on a switching function. It supports
    complex group definitions, including groups of virtual sites.

    Attributes:
        x1, x2: Selectors for the two groups of atoms.
        prefix: Label prefix for the generated PLUMED commands.
        r_0: The reference distance for the switching function (in Angstroms).
            NOTE(review): the value is written verbatim as R_0. PLUMED reads it
            in its configured length unit (nm unless UNITS is set), so confirm
            that Angstrom units are configured elsewhere.
        nn, mm, d_0: Parameters for the switching function. MM is left out of
            the command when ``mm == 0``.
        group_reduction_1, group_reduction_2: Reduction strategies for each group.
        multi_group: Strategy for handling multiple groups from selectors.
        create_virtual_sites: If True, create explicit virtual sites for COM/COG.
            The "com_per_group"/"cog_per_group" strategies require it.

    Resources:
        - https://www.plumed.org/doc-master/user-doc/html/COORDINATION.html
        - https://www.plumed.org/doc-master/user-doc/html/GROUP.html
    """

    x1: AtomSelector
    x2: AtomSelector
    prefix: str
    r_0: float
    nn: int = 6
    mm: int = 0
    d_0: float = 0.0
    group_reduction_1: GroupReductionStrategyType = "all"
    group_reduction_2: GroupReductionStrategyType = "all"
    multi_group: MultiGroupStrategyType = "first"
    create_virtual_sites: bool = True

    def _get_atom_highlights(
        self, atoms: Atoms, **kwargs
    ) -> Optional[AtomHighlightMap]:
        """
        Select all atoms of both selections for visualization.

        Hydrogens are dropped unless ``highlight_hydrogens=True`` is passed.

        Returns:
            Light red for ``x1`` atoms, light blue for ``x2`` atoms, purple for
            atoms in both, or None if nothing remains to highlight.
        """
        highlight_hydrogens = kwargs.get("highlight_hydrogens", False)
        groups1 = self.x1.select(atoms)
        groups2 = self.x2.select(atoms)

        if not groups1 or not groups2:
            return None

        indices1 = {idx for g in groups1 for idx in g}
        indices2 = {idx for g in groups2 for idx in g}
        if not highlight_hydrogens:
            indices1 = {i for i in indices1 if atoms[i].symbol != "H"}
            indices2 = {i for i in indices2 if atoms[i].symbol != "H"}

        if not indices1 and not indices2:
            return None

        highlights: AtomHighlightMap = {}
        red, blue, purple = (1.0, 0.5, 0.5), (0.5, 0.5, 1.0), (1.0, 0.5, 1.0)
        for idx in indices1 | indices2:
            in1, in2 = idx in indices1, idx in indices2
            if in1 and in2:
                highlights[idx] = purple
            elif in1:
                highlights[idx] = red
            else:
                highlights[idx] = blue
        return highlights

    def to_plumed(self, atoms: Atoms) -> Tuple[List[str], List[str]]:
        """
        Generates PLUMED input strings for the COORDINATION CV.

        Returns:
            A tuple containing a list of CV labels and a list of PLUMED commands.

        Raises:
            ValueError: If a selection is empty, a group cannot be reduced, or
                a strategy is unknown.
        """
        groups1 = self.x1.select(atoms)
        groups2 = self.x2.select(atoms)

        if not groups1 or not groups2:
            raise ValueError(f"Empty selection for coordination CV '{self.prefix}'")

        commands = self._generate_commands(groups1, groups2)
        labels = self._extract_labels(commands, self.prefix, "COORDINATION")
        return labels, commands

    def _generate_commands(
        self, groups1: List[List[int]], groups2: List[List[int]]
    ) -> List[str]:
        """
        Generate all PLUMED commands: site definitions, then COORDINATION lines.

        A single pair is labelled ``prefix``. Several pairs are labelled
        ``{prefix}_{n}``, where ``n`` is the position of the pair in the pairing
        order.
        """
        commands: List[str] = []

        sites1 = self._reduce_groups(
            groups1, self.group_reduction_1, f"{self.prefix}_g1", commands
        )
        sites2 = self._reduce_groups(
            groups2, self.group_reduction_2, f"{self.prefix}_g2", commands
        )

        # The shared pairing helper raises ValueError for an unknown strategy.
        # Without it, such a strategy would silently yield no COORDINATION line.
        index_pairs = self._get_index_pairs(len(sites1), len(sites2), self.multi_group)
        for n, (i, j) in enumerate(index_pairs):
            label = self.prefix if len(index_pairs) == 1 else f"{self.prefix}_{n}"
            commands.append(
                self._make_coordination_command(sites1[i], sites2[j], label)
            )

        return commands

    def _reduce_groups(
        self,
        groups: List[List[int]],
        strategy: GroupReductionStrategyType,
        site_prefix: str,
        commands: List[str],
    ) -> List[SiteIdentifier]:
        """
        Reduce a list of atom groups into a list of site identifiers.

        Any PLUMED commands needed to define the sites are appended to
        ``commands`` in place.

        Strategies:
            - "com_per_group"/"cog_per_group": one COM/CENTER virtual site per
              non-empty group, collected into a single GROUP (one site).
            - "all": the union of all groups as one sorted index list (one site).
            - "first"/"com"/"cog": one site per group.

        Raises:
            ValueError: If a per-group strategy lacks virtual sites, a group (or
                the whole selection) is empty, or the strategy is unsupported.
        """
        if strategy in ["com_per_group", "cog_per_group"]:
            if not self.create_virtual_sites:
                raise ValueError(f"'{strategy}' requires create_virtual_sites=True")

            reduction_type = "COM" if strategy == "com_per_group" else "CENTER"
            vsite_labels = []
            for i, group in enumerate(groups):
                if not group:
                    continue  # empty groups contribute no virtual site
                vsite_label = f"{site_prefix}_{i}"
                atom_list = ",".join(str(idx + 1) for idx in group)
                commands.append(f"{vsite_label}: {reduction_type} ATOMS={atom_list}")
                vsite_labels.append(vsite_label)

            if not vsite_labels:
                # A GROUP with no members would be invalid PLUMED input.
                raise ValueError(
                    f"'{strategy}' found no non-empty groups for '{site_prefix}'"
                )

            group_label = f"{site_prefix}_group"
            commands.append(f"{group_label}: GROUP ATOMS={','.join(vsite_labels)}")
            return [group_label]

        if strategy == "all":
            all_atoms = sorted({idx for group in groups for idx in group})
            if not all_atoms:
                raise ValueError(f"Strategy 'all' found no atoms for '{site_prefix}'")
            return [all_atoms]

        sites: List[SiteIdentifier] = []
        for i, group in enumerate(groups):
            if not group:
                # Would otherwise raise IndexError or emit an empty GROUPA/GROUPB.
                raise ValueError(
                    f"Cannot reduce empty atom group {i} for '{site_prefix}'"
                )
            if len(group) == 1 or strategy == "first":
                sites.append(str(group[0] + 1))
            elif strategy in ["com", "cog"]:
                if self.create_virtual_sites:
                    label = f"{site_prefix}_{i}_{strategy}"
                    commands.append(
                        self._create_virtual_site_command(group, strategy, label)
                    )
                    sites.append(label)
                else:
                    sites.append(group)
            else:
                raise ValueError(f"Unsupported reduction strategy: {strategy}")
        return sites

    def _make_coordination_command(
        self, site1: SiteIdentifier, site2: SiteIdentifier, label: str
    ) -> str:
        """
        Create a single PLUMED COORDINATION command string.

        List sites are converted from 0-based to 1-based numbering. GROUPB is
        omitted when both sides format to the same string (self-coordination).
        MM is written only when non-zero.
        """

        def _format(site: SiteIdentifier) -> str:
            if isinstance(site, list):
                return ",".join(str(s + 1) for s in site)
            return site

        g_a, g_b = _format(site1), _format(site2)
        base_cmd = f"{label}: COORDINATION GROUPA={g_a}"
        if g_a != g_b:
            base_cmd += f" GROUPB={g_b}"

        params = f" R_0={self.r_0} NN={self.nn} D_0={self.d_0}"
        if self.mm != 0:
            params += f" MM={self.mm}"

        return base_cmd + params
502
+
503
+
504
@dataclass
class TorsionCV(_BasePlumedCV):
    """
    PLUMED TORSION collective variable.

    Calculates the torsional (dihedral) angle defined by four atoms. Each group
    provided by the selector must contain exactly four atoms.

    Attributes:
        atoms: Selector for one or more groups of 4 atoms.
        prefix: Label prefix for the generated PLUMED commands.
        multi_group: Strategy for handling multiple groups from the selector.
            "first" and "first_to_all" use only the first group. "all_pairs"
            and "corresponding" emit one TORSION per group.

    Resources:
        - https://www.plumed.org/doc-master/user-doc/html/TORSION.html
    """

    atoms: AtomSelector
    prefix: str
    multi_group: MultiGroupStrategyType = "first"

    def _get_atom_highlights(
        self, atoms: Atoms, **kwargs
    ) -> Optional[AtomHighlightMap]:
        """
        Color the atoms of the first group red, orange, yellow and green.

        Returns:
            The highlight map, or None (after emitting a UserWarning) when the
            selection is empty or its first group does not have four atoms.
        """
        import warnings

        groups = self.atoms.select(atoms)
        if not groups or len(groups[0]) != 4:
            # Use the warnings machinery instead of print() so callers can
            # filter or escalate this like any other library warning.
            warnings.warn(
                "Torsion CV requires a group of 4 atoms for visualization.",
                stacklevel=2,
            )
            return None

        colors = [
            (1.0, 0.2, 0.2),  # Red
            (1.0, 0.6, 0.2),  # Orange
            (1.0, 1.0, 0.2),  # Yellow
            (0.2, 1.0, 0.2),  # Green
        ]
        return dict(zip(groups[0], colors))

    def to_plumed(self, atoms: Atoms) -> Tuple[List[str], List[str]]:
        """
        Generates PLUMED input strings for the TORSION CV.

        Returns:
            A tuple containing a list of CV labels and a list of PLUMED commands.

        Raises:
            ValueError: If the selection is empty or any group does not have
                exactly four atoms.
        """
        groups = self.atoms.select(atoms)
        if not groups:
            raise ValueError(f"Empty selection for torsion CV '{self.prefix}'")

        for i, group in enumerate(groups):
            if len(group) != 4:
                raise ValueError(
                    f"Torsion CV requires 4 atoms per group, but group {i} has {len(group)}."
                )

        commands = self._generate_commands(groups)
        labels = self._extract_labels(commands, self.prefix, "TORSION")
        return labels, commands

    def _generate_commands(self, groups: List[List[int]]) -> List[str]:
        """
        Generate one TORSION command per processed group.

        A single command is labelled ``prefix``. Several commands are labelled
        ``{prefix}_{i}`` by group index. Atom indices are converted to 1-based.
        """
        if self.multi_group in ["first", "first_to_all"] and groups:
            indices_to_process = [0]
        else:  # "all_pairs" and "corresponding" process every group independently.
            indices_to_process = list(range(len(groups)))

        commands = []
        for i in indices_to_process:
            label = self.prefix if len(indices_to_process) == 1 else f"{self.prefix}_{i}"
            atom_list = ",".join(str(idx + 1) for idx in groups[i])
            commands.append(f"{label}: TORSION ATOMS={atom_list}")
        return commands
578
+
579
+
580
# TODO: we might need to set weights because plumed does not know about the
# atomistic weights? The GYRATION command below passes no weights or masses.
@dataclass
class RadiusOfGyrationCV(_BasePlumedCV):
    """
    PLUMED GYRATION collective variable.

    Calculates the radius of gyration of a group of atoms. The radius of gyration
    is a measure of the size of a molecular system.

    Attributes:
        atoms: Selector for the atoms to include in the gyration calculation.
        prefix: Label prefix for the generated PLUMED commands.
        multi_group: Strategy for handling multiple groups from the selector.
            "first" and "first_to_all" use only the first group. "all_pairs"
            and "corresponding" emit one GYRATION per group.
        type: The type of gyration tensor to use ("RADIUS" for scalar Rg, "GTPC_1", etc.)
            TYPE is written only when it is not "RADIUS".

    Resources:
        - https://www.plumed.org/doc-master/user-doc/html/GYRATION/
    """

    atoms: AtomSelector
    prefix: str
    multi_group: MultiGroupStrategyType = "first"
    # Options: RADIUS, GTPC_1, GTPC_2, GTPC_3, ASPHERICITY, ACYLINDRICITY, KAPPA2, etc.
    # (Name shadows the builtin `type` inside the class; kept for compatibility.)
    type: str = "RADIUS"

    def _get_atom_highlights(
        self, atoms: Atoms, **kwargs
    ) -> Optional[AtomHighlightMap]:
        """
        Highlight every atom of the first selected group in green.

        Returns:
            The highlight map, or None if the selection or its first group is empty.
        """
        groups = self.atoms.select(atoms)
        if not groups or not groups[0]:
            return None

        return {atom_idx: (0.2, 0.8, 0.2) for atom_idx in groups[0]}  # Green

    def to_plumed(self, atoms: Atoms) -> Tuple[List[str], List[str]]:
        """
        Generates PLUMED input strings for the GYRATION CV.

        Returns:
            A tuple containing a list of CV labels and a list of PLUMED commands.

        Raises:
            ValueError: If the selection is empty or a processed group is empty.
        """
        groups = self.atoms.select(atoms)
        if not groups:
            raise ValueError(f"Empty selection for gyration CV '{self.prefix}'")

        commands = self._generate_commands(groups)
        labels = self._extract_labels(commands, self.prefix, "GYRATION")
        return labels, commands

    def _generate_commands(self, groups: List[List[int]]) -> List[str]:
        """
        Generate one GYRATION command per processed group.

        A single command is labelled ``prefix``. Several commands are labelled
        ``{prefix}_{i}`` by group index. Atom indices are converted to 1-based.

        Raises:
            ValueError: If a processed group is empty.
        """
        if self.multi_group in ["first", "first_to_all"] and groups:
            indices_to_process = [0]
        else:  # "all_pairs" and "corresponding" process every group independently.
            indices_to_process = list(range(len(groups)))

        commands = []
        for i in indices_to_process:
            group = groups[i]
            if not group:
                # Otherwise an invalid "GYRATION ATOMS=" line would be emitted.
                raise ValueError(
                    f"Gyration CV '{self.prefix}' received an empty atom group at index {i}."
                )
            label = self.prefix if len(indices_to_process) == 1 else f"{self.prefix}_{i}"
            atom_list = ",".join(str(idx + 1) for idx in group)
            command = f"{label}: GYRATION ATOMS={atom_list}"
            if self.type != "RADIUS":
                command += f" TYPE={self.type}"
            commands.append(command)
        return commands