hillclimber 0.1.0a1__py3-none-any.whl → 0.1.0a3__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hillclimber might be problematic. More details are available on the package's advisory page.

hillclimber/cvs.py CHANGED
@@ -1,5 +1,7 @@
1
1
  # --- IMPORTS ---
2
2
  # Standard library
3
+ from __future__ import annotations
4
+ import dataclasses
3
5
  from dataclasses import dataclass
4
6
  from typing import Dict, List, Literal, Optional, Tuple, Union
5
7
 
@@ -12,13 +14,13 @@ from rdkit.Chem import Draw
12
14
 
13
15
  # Local
14
16
  from hillclimber.interfaces import AtomSelector, CollectiveVariable
17
+ from hillclimber.virtual_atoms import VirtualAtom
15
18
 
16
19
 
17
20
  # --- TYPE HINTS ---
18
21
  GroupReductionStrategyType = Literal[
19
22
  "com", "cog", "first", "all", "com_per_group", "cog_per_group"
20
23
  ]
21
- MultiGroupStrategyType = Literal["first", "all_pairs", "corresponding", "first_to_all"]
22
24
  SiteIdentifier = Union[str, List[int]]
23
25
  ColorTuple = Tuple[float, float, float]
24
26
  AtomHighlightMap = Dict[int, ColorTuple]
@@ -129,22 +131,6 @@ class _BasePlumedCV(CollectiveVariable):
129
131
  if cv_keyword in cmd and cmd.strip().startswith((prefix, f"{prefix}_"))
130
132
  ]
131
133
 
132
- @staticmethod
133
- def _get_index_pairs(
134
- len1: int, len2: int, strategy: MultiGroupStrategyType
135
- ) -> List[Tuple[int, int]]:
136
- """Determines pairs of group indices based on the multi-group strategy."""
137
- if strategy == "first":
138
- return [(0, 0)] if len1 > 0 and len2 > 0 else []
139
- if strategy == "all_pairs":
140
- return [(i, j) for i in range(len1) for j in range(len2)]
141
- if strategy == "corresponding":
142
- n = min(len1, len2)
143
- return [(i, i) for i in range(n)]
144
- if strategy == "first_to_all":
145
- return [(0, j) for j in range(len2)] if len1 > 0 else []
146
- raise ValueError(f"Unknown multi-group strategy: {strategy}")
147
-
148
134
  @staticmethod
149
135
  def _create_virtual_site_command(
150
136
  group: List[int], strategy: Literal["com", "cog"], label: str
@@ -163,61 +149,97 @@ class DistanceCV(_BasePlumedCV):
163
149
  """
164
150
  PLUMED DISTANCE collective variable.
165
151
 
166
- Calculates the distance between two atoms or groups of atoms. This CV supports
167
- various strategies for reducing groups to single points (e.g., center of mass)
168
- and for pairing multiple groups.
169
-
170
- Attributes:
171
- x1: Selector for the first atom/group.
172
- x2: Selector for the second atom/group.
173
- prefix: Label prefix for the generated PLUMED commands.
174
- group_reduction: Strategy to reduce an atom group to a single point.
175
- multi_group: Strategy for handling multiple groups from selectors.
176
- create_virtual_sites: If True, create explicit virtual sites for COM/COG.
177
-
178
- Resources:
179
- - https://www.plumed.org/doc-master/user-doc/html/DISTANCE.html
152
+ Calculates the distance between two atoms, groups of atoms, or virtual sites.
153
+ Supports flexible flattening and pairing strategies for multiple groups.
154
+
155
+ Parameters
156
+ ----------
157
+ x1 : AtomSelector | VirtualAtom
158
+ First atom/group or virtual site.
159
+ x2 : AtomSelector | VirtualAtom
160
+ Second atom/group or virtual site.
161
+ prefix : str
162
+ Label prefix for generated PLUMED commands.
163
+ flatten : bool, default=True
164
+ For AtomSelectors only: If True, flatten all groups into single atom list.
165
+ If False, create PLUMED GROUP for each group. VirtualAtoms are never flattened.
166
+ pairwise : {"all", "diagonal", "none"}, default="all"
167
+ Strategy for pairing multiple groups:
168
+ - "all": Create all N×M pair combinations (can create many CVs!)
169
+ - "diagonal": Pair corresponding indices only (creates min(N,M) CVs)
170
+ - "none": Error if both sides have multiple groups (safety check)
171
+
172
+ Examples
173
+ --------
174
+ >>> # Distance between two specific atoms
175
+ >>> dist = hc.DistanceCV(
176
+ ... x1=ethanol_sel[0][0], # First atom of first ethanol
177
+ ... x2=water_sel[0][0], # First atom of first water
178
+ ... prefix="d_atoms"
179
+ ... )
180
+
181
+ >>> # Distance between molecule COMs
182
+ >>> dist = hc.DistanceCV(
183
+ ... x1=hc.VirtualAtom(ethanol_sel[0], "com"),
184
+ ... x2=hc.VirtualAtom(water_sel[0], "com"),
185
+ ... prefix="d_com"
186
+ ... )
187
+
188
+ >>> # One-to-many: First ethanol COM to all water COMs
189
+ >>> dist = hc.DistanceCV(
190
+ ... x1=hc.VirtualAtom(ethanol_sel[0], "com"),
191
+ ... x2=hc.VirtualAtom(water_sel, "com"),
192
+ ... prefix="d",
193
+ ... pairwise="all" # Creates 3 CVs
194
+ ... )
195
+
196
+ >>> # Diagonal pairing (avoid explosion)
197
+ >>> dist = hc.DistanceCV(
198
+ ... x1=hc.VirtualAtom(water_sel, "com"), # 3 waters
199
+ ... x2=hc.VirtualAtom(ethanol_sel, "com"), # 2 ethanols
200
+ ... prefix="d",
201
+ ... pairwise="diagonal" # Creates only 2 CVs: d_0, d_1
202
+ ... )
203
+
204
+ Resources
205
+ ---------
206
+ - https://www.plumed.org/doc-master/user-doc/html/DISTANCE.html
207
+
208
+ Notes
209
+ -----
210
+ For backwards compatibility, old parameters are still supported but deprecated:
211
+ - `group_reduction` → Use VirtualAtom instead
212
+ - `multi_group` → Use `pairwise` parameter
180
213
  """
181
214
 
182
- x1: AtomSelector
183
- x2: AtomSelector
215
+ x1: AtomSelector | VirtualAtom
216
+ x2: AtomSelector | VirtualAtom
184
217
  prefix: str
185
- group_reduction: GroupReductionStrategyType = "com"
186
- multi_group: MultiGroupStrategyType = "first"
187
- create_virtual_sites: bool = True
218
+ flatten: bool = True
219
+ pairwise: Literal["all", "diagonal", "none"] = "all"
188
220
 
189
221
  def _get_atom_highlights(
190
222
  self, atoms: Atoms, **kwargs
191
223
  ) -> Optional[AtomHighlightMap]:
224
+ """Get atom highlights for visualization."""
225
+ # Skip for VirtualAtom inputs
226
+ if isinstance(self.x1, VirtualAtom) or isinstance(self.x2, VirtualAtom):
227
+ return None
228
+
192
229
  groups1 = self.x1.select(atoms)
193
230
  groups2 = self.x2.select(atoms)
194
231
 
195
232
  if not groups1 or not groups2:
196
233
  return None
197
234
 
198
- index_pairs = self._get_index_pairs(len(groups1), len(groups2), self.multi_group)
199
- if not index_pairs:
200
- return None
201
-
202
- # Correctly select atoms based on the group_reduction strategy
203
- indices1, indices2 = set(), set()
204
- for i, j in index_pairs:
205
- # Handle the 'first' atom case specifically for highlighting
206
- if self.group_reduction == "first":
207
- # Ensure the group is not empty before accessing the first element
208
- if groups1[i]:
209
- indices1.add(groups1[i][0])
210
- if groups2[j]:
211
- indices2.add(groups2[j][0])
212
- # For other strategies (com, cog, all), highlight the whole group
213
- else:
214
- indices1.update(groups1[i])
215
- indices2.update(groups2[j])
235
+ # Highlight all atoms from both selections
236
+ indices1 = {idx for group in groups1 for idx in group}
237
+ indices2 = {idx for group in groups2 for idx in group}
216
238
 
217
239
  if not indices1 and not indices2:
218
240
  return None
219
241
 
220
- # Color atoms based on group membership, with purple for overlaps.
242
+ # Color atoms based on group membership
221
243
  highlights: AtomHighlightMap = {}
222
244
  red, blue, purple = (1.0, 0.2, 0.2), (0.2, 0.2, 1.0), (1.0, 0.2, 1.0)
223
245
  for idx in indices1.union(indices2):
@@ -231,93 +253,366 @@ class DistanceCV(_BasePlumedCV):
231
253
  return highlights
232
254
 
233
255
  def to_plumed(self, atoms: Atoms) -> Tuple[List[str], List[str]]:
256
+ """Generate PLUMED input strings for the DISTANCE CV.
257
+
258
+ Returns
259
+ -------
260
+ labels : list[str]
261
+ List of CV labels generated.
262
+ commands : list[str]
263
+ List of PLUMED command strings.
234
264
  """
235
- Generates PLUMED input strings for the DISTANCE CV.
265
+ commands = []
236
266
 
237
- Returns:
238
- A tuple containing a list of CV labels and a list of PLUMED commands.
267
+ # Process x1
268
+ labels1, cmds1 = self._process_input(self.x1, atoms, "x1")
269
+ commands.extend(cmds1)
270
+
271
+ # Process x2
272
+ labels2, cmds2 = self._process_input(self.x2, atoms, "x2")
273
+ commands.extend(cmds2)
274
+
275
+ # Check for empty selections
276
+ if not labels1 or not labels2:
277
+ raise ValueError(f"Empty selection for distance CV '{self.prefix}'")
278
+
279
+ # Generate distance CVs based on pairwise strategy
280
+ cv_labels, cv_commands = self._generate_distance_cvs(labels1, labels2)
281
+ commands.extend(cv_commands)
282
+
283
+ return cv_labels, commands
284
+
285
+ def _process_input(
286
+ self, input_obj: AtomSelector | VirtualAtom, atoms: Atoms, label_prefix: str
287
+ ) -> Tuple[List[str], List[str]]:
288
+ """Process an input (AtomSelector or VirtualAtom) and return labels and commands.
289
+
290
+ Returns
291
+ -------
292
+ labels : list[str]
293
+ List of labels for this input (either virtual site labels or GROUP labels).
294
+ commands : list[str]
295
+ PLUMED commands to create the labels.
239
296
  """
297
+ if isinstance(input_obj, VirtualAtom):
298
+ # VirtualAtom: set deterministic label if not already set
299
+ if input_obj.label is None:
300
+ # Set label based on prefix and label_prefix (x1 or x2)
301
+ labeled_va = dataclasses.replace(
302
+ input_obj, label=f"{self.prefix}_{label_prefix}"
303
+ )
304
+ return labeled_va.to_plumed(atoms)
305
+ else:
306
+ return input_obj.to_plumed(atoms)
307
+ else:
308
+ # AtomSelector: handle based on flatten parameter
309
+ groups = input_obj.select(atoms)
310
+ if not groups:
311
+ return [], []
312
+
313
+ if self.flatten:
314
+ # Flatten all groups into single list
315
+ flat_atoms = [idx for group in groups for idx in group]
316
+ atom_list = ",".join(str(idx + 1) for idx in flat_atoms)
317
+ # Return as pseudo-label (will be used directly in DISTANCE command)
318
+ return [atom_list], []
319
+ else:
320
+ # Smart GROUP creation: only create GROUP for multi-atom groups
321
+ labels = []
322
+ commands = []
323
+ for i, group in enumerate(groups):
324
+ if len(group) == 1:
325
+ # Single atom: use directly (no GROUP needed)
326
+ labels.append(str(group[0] + 1))
327
+ else:
328
+ # Multi-atom group: create GROUP
329
+ group_label = f"{self.prefix}_{label_prefix}_g{i}"
330
+ atom_list = ",".join(str(idx + 1) for idx in group)
331
+ commands.append(f"{group_label}: GROUP ATOMS={atom_list}")
332
+ labels.append(group_label)
333
+ return labels, commands
334
+
335
+ def _generate_distance_cvs(
336
+ self, labels1: List[str], labels2: List[str]
337
+ ) -> Tuple[List[str], List[str]]:
338
+ """Generate DISTANCE CV commands based on pairwise strategy."""
339
+ n1, n2 = len(labels1), len(labels2)
340
+
341
+ # Determine which pairs to create based on pairwise strategy
342
+ if n1 == 1 and n2 == 1:
343
+ # One-to-one: always create single CV
344
+ pairs = [(0, 0)]
345
+ elif n1 == 1:
346
+ # One-to-many: pair first of x1 with all of x2
347
+ pairs = [(0, j) for j in range(n2)]
348
+ elif n2 == 1:
349
+ # Many-to-one: pair all of x1 with first of x2
350
+ pairs = [(i, 0) for i in range(n1)]
351
+ else:
352
+ # Many-to-many: apply pairwise strategy
353
+ if self.pairwise == "all":
354
+ pairs = [(i, j) for i in range(n1) for j in range(n2)]
355
+ elif self.pairwise == "diagonal":
356
+ n_pairs = min(n1, n2)
357
+ pairs = [(i, i) for i in range(n_pairs)]
358
+ elif self.pairwise == "none":
359
+ raise ValueError(
360
+ f"Both x1 and x2 have multiple groups ({n1} and {n2}). "
361
+ f"Use pairwise='all' or 'diagonal', or select specific groups with indexing."
362
+ )
363
+ else:
364
+ raise ValueError(f"Unknown pairwise strategy: {self.pairwise}")
365
+
366
+ # Generate DISTANCE commands
367
+ cv_labels = []
368
+ commands = []
369
+ for idx, (i, j) in enumerate(pairs):
370
+ if len(pairs) == 1:
371
+ label = self.prefix
372
+ else:
373
+ label = f"{self.prefix}_{idx}"
374
+
375
+ # Create DISTANCE command
376
+ cmd = f"{label}: DISTANCE ATOMS={labels1[i]},{labels2[j]}"
377
+ commands.append(cmd)
378
+ cv_labels.append(label)
379
+
380
+ return cv_labels, commands
381
+
382
+
383
+ @dataclass
384
+ class AngleCV(_BasePlumedCV):
385
+ """
386
+ PLUMED ANGLE collective variable.
387
+
388
+ Calculates the angle formed by three atoms or groups of atoms using the new
389
+ VirtualAtom API. The angle is computed as the angle between the vectors
390
+ (x1-x2) and (x3-x2), where x2 is the vertex of the angle.
391
+
392
+ Parameters
393
+ ----------
394
+ x1 : AtomSelector | VirtualAtom
395
+ First position. Can be an AtomSelector or VirtualAtom.
396
+ x2 : AtomSelector | VirtualAtom
397
+ Vertex position (center of the angle). Can be an AtomSelector or VirtualAtom.
398
+ x3 : AtomSelector | VirtualAtom
399
+ Third position. Can be an AtomSelector or VirtualAtom.
400
+ prefix : str
401
+ Label prefix for the generated PLUMED commands.
402
+ flatten : bool, default=True
403
+ How to handle AtomSelector inputs:
404
+ - True: Flatten all groups into a single list
405
+ - False: Create GROUP for each selector group (not typically used for ANGLE)
406
+ strategy : {"first", "all", "diagonal", "none"}, default="first"
407
+ Strategy for creating multiple angles from multiple groups:
408
+ - "first": Use first group from each selector (1 angle)
409
+ - "all": All combinations (N×M×P angles)
410
+ - "diagonal": Pair by index (min(N,M,P) angles)
411
+ - "none": Raise error if any selector has multiple groups
412
+
413
+ Resources
414
+ ---------
415
+ - https://www.plumed.org/doc-master/user-doc/html/ANGLE/
416
+ """
417
+
418
+ x1: AtomSelector | VirtualAtom
419
+ x2: AtomSelector | VirtualAtom
420
+ x3: AtomSelector | VirtualAtom
421
+ prefix: str
422
+ flatten: bool = True
423
+ strategy: Literal["first", "all", "diagonal", "none"] = "first"
424
+
425
+ def _get_atom_highlights(
426
+ self, atoms: Atoms, **kwargs
427
+ ) -> Optional[AtomHighlightMap]:
428
+ """Get atom highlights for visualization."""
429
+ # Skip for VirtualAtom inputs
430
+ if isinstance(self.x1, VirtualAtom) or isinstance(self.x2, VirtualAtom) or isinstance(self.x3, VirtualAtom):
431
+ return None
432
+
240
433
  groups1 = self.x1.select(atoms)
241
434
  groups2 = self.x2.select(atoms)
435
+ groups3 = self.x3.select(atoms)
242
436
 
243
- if not groups1 or not groups2:
244
- raise ValueError(f"Empty selection for distance CV '{self.prefix}'")
437
+ if not groups1 or not groups2 or not groups3:
438
+ return None
245
439
 
246
- flat1 = {idx for group in groups1 for idx in group}
247
- flat2 = {idx for group in groups2 for idx in group}
248
- if flat1.intersection(flat2) and self.group_reduction not in ["com", "cog"]:
249
- raise ValueError(
250
- "Overlapping atoms found. This is only valid with 'com' or 'cog' reduction."
251
- )
440
+ # Highlight all atoms from all three selections
441
+ indices1 = {idx for group in groups1 for idx in group}
442
+ indices2 = {idx for group in groups2 for idx in group}
443
+ indices3 = {idx for group in groups3 for idx in group}
252
444
 
253
- commands = self._generate_commands(groups1, groups2)
254
- labels = self._extract_labels(commands, self.prefix, "DISTANCE")
255
- return labels, commands
445
+ if not indices1 and not indices2 and not indices3:
446
+ return None
447
+
448
+ # Color atoms: red for x1, green for x2 (vertex), blue for x3
449
+ highlights: AtomHighlightMap = {}
450
+ red, green, blue = (1.0, 0.2, 0.2), (0.2, 1.0, 0.2), (0.2, 0.2, 1.0)
451
+
452
+ # Handle overlaps by prioritizing vertex (x2) coloring
453
+ all_indices = indices1.union(indices2).union(indices3)
454
+ for idx in all_indices:
455
+ in1, in2, in3 = idx in indices1, idx in indices2, idx in indices3
456
+ if in2: # Vertex gets priority
457
+ highlights[idx] = green
458
+ elif in1 and in3: # Overlap between x1 and x3
459
+ highlights[idx] = (0.5, 0.2, 0.6) # Purple
460
+ elif in1:
461
+ highlights[idx] = red
462
+ elif in3:
463
+ highlights[idx] = blue
464
+ return highlights
465
+
466
+ def to_plumed(self, atoms: Atoms) -> Tuple[List[str], List[str]]:
467
+ """Generate PLUMED ANGLE command(s).
468
+
469
+ Returns
470
+ -------
471
+ labels : list[str]
472
+ List of CV labels created.
473
+ commands : list[str]
474
+ List of PLUMED commands.
475
+
476
+ Raises
477
+ ------
478
+ ValueError
479
+ If any selector returns empty selection.
480
+ """
481
+ # Process all three inputs
482
+ labels1, cmds1 = self._process_input(self.x1, atoms, "x1")
483
+ labels2, cmds2 = self._process_input(self.x2, atoms, "x2")
484
+ labels3, cmds3 = self._process_input(self.x3, atoms, "x3")
485
+
486
+ # Check for empty selections
487
+ if not labels1 or not labels2 or not labels3:
488
+ raise ValueError(f"Empty selection for angle CV '{self.prefix}'")
256
489
 
257
- def _generate_commands(
258
- self, groups1: List[List[int]], groups2: List[List[int]]
259
- ) -> List[str]:
260
- """Generates all necessary PLUMED commands."""
261
490
  commands = []
262
- index_pairs = self._get_index_pairs(
263
- len(groups1), len(groups2), self.multi_group
264
- )
491
+ commands.extend(cmds1)
492
+ commands.extend(cmds2)
493
+ commands.extend(cmds3)
265
494
 
266
- # Efficiently create virtual sites only for groups that will be used.
267
- sites1, sites2 = {}, {}
268
- unique_indices1 = sorted({i for i, j in index_pairs})
269
- unique_indices2 = sorted({j for i, j in index_pairs})
270
-
271
- for i in unique_indices1:
272
- site, site_cmds = self._reduce_group(groups1[i], f"{self.prefix}_g1_{i}")
273
- sites1[i] = site
274
- commands.extend(site_cmds)
275
- for j in unique_indices2:
276
- site, site_cmds = self._reduce_group(groups2[j], f"{self.prefix}_g2_{j}")
277
- sites2[j] = site
278
- commands.extend(site_cmds)
279
-
280
- # Create the final DISTANCE commands.
281
- for i, j in index_pairs:
282
- label = self.prefix if len(index_pairs) == 1 else f"{self.prefix}_{i}_{j}"
283
- cmd = self._make_distance_command(sites1[i], sites2[j], label)
284
- commands.append(cmd)
495
+ # Generate ANGLE commands
496
+ cv_labels, cv_commands = self._generate_angle_cvs(labels1, labels2, labels3)
497
+ commands.extend(cv_commands)
285
498
 
286
- return commands
499
+ return cv_labels, commands
287
500
 
288
- def _reduce_group(
289
- self, group: List[int], site_prefix: str
290
- ) -> Tuple[SiteIdentifier, List[str]]:
291
- """Reduces a single atom group to a site identifier based on strategy."""
292
- if len(group) == 1 or self.group_reduction == "first":
293
- return str(group[0] + 1), []
294
- if self.group_reduction == "all":
295
- return group, []
296
-
297
- if self.group_reduction in ["com", "cog"]:
298
- if self.create_virtual_sites:
299
- label = f"{site_prefix}_{self.group_reduction}"
300
- cmd = self._create_virtual_site_command(
301
- group, self.group_reduction, label
302
- )
303
- return label, [cmd]
304
- return group, [] # Use group directly if not creating virtual sites
501
+ def _process_input(
502
+ self, input_obj: AtomSelector | VirtualAtom, atoms: Atoms, label_prefix: str
503
+ ) -> Tuple[List[str], List[str]]:
504
+ """Process input (AtomSelector or VirtualAtom) and return labels and commands.
305
505
 
306
- raise ValueError(f"Unknown group reduction strategy: {self.group_reduction}")
506
+ Same as DistanceCV._process_input() method.
307
507
 
308
- def _make_distance_command(
309
- self, site1: SiteIdentifier, site2: SiteIdentifier, label: str
310
- ) -> str:
311
- """Creates a single PLUMED DISTANCE command string."""
508
+ Returns
509
+ -------
510
+ labels : list[str]
511
+ List of labels for this input (either virtual site labels or atom lists).
512
+ commands : list[str]
513
+ PLUMED commands to create the labels.
514
+ """
515
+ if isinstance(input_obj, VirtualAtom):
516
+ # VirtualAtom: set deterministic label if not already set
517
+ if input_obj.label is None:
518
+ labeled_va = dataclasses.replace(
519
+ input_obj, label=f"{self.prefix}_{label_prefix}"
520
+ )
521
+ return labeled_va.to_plumed(atoms)
522
+ else:
523
+ return input_obj.to_plumed(atoms)
524
+ else:
525
+ # AtomSelector: handle based on flatten parameter
526
+ groups = input_obj.select(atoms)
527
+ if not groups:
528
+ return [], []
529
+
530
+ if self.flatten:
531
+ # Flatten all groups into single list
532
+ flat_atoms = [idx for group in groups for idx in group]
533
+ atom_list = ",".join(str(idx + 1) for idx in flat_atoms)
534
+ # Return as pseudo-label (will be used directly in ANGLE command)
535
+ return [atom_list], []
536
+ else:
537
+ # Smart GROUP creation: only create GROUP for multi-atom groups
538
+ labels = []
539
+ commands = []
540
+ for i, group in enumerate(groups):
541
+ if len(group) == 1:
542
+ # Single atom: use directly (no GROUP needed)
543
+ labels.append(str(group[0] + 1))
544
+ else:
545
+ # Multi-atom group: create GROUP
546
+ group_label = f"{self.prefix}_{label_prefix}_g{i}"
547
+ atom_list = ",".join(str(idx + 1) for idx in group)
548
+ commands.append(f"{group_label}: GROUP ATOMS={atom_list}")
549
+ labels.append(group_label)
550
+ return labels, commands
551
+
552
+ def _generate_angle_cvs(
553
+ self, labels1: List[str], labels2: List[str], labels3: List[str]
554
+ ) -> Tuple[List[str], List[str]]:
555
+ """Generate ANGLE CV commands based on strategy.
556
+
557
+ Parameters
558
+ ----------
559
+ labels1, labels2, labels3 : list[str]
560
+ Labels or atom lists for the three angle positions.
561
+
562
+ Returns
563
+ -------
564
+ cv_labels : list[str]
565
+ Labels for the ANGLE CVs created.
566
+ commands : list[str]
567
+ ANGLE command strings.
568
+ """
569
+ n1, n2, n3 = len(labels1), len(labels2), len(labels3)
570
+
571
+ # Determine which triplets to create based on strategy
572
+ if n1 == 1 and n2 == 1 and n3 == 1:
573
+ # One-to-one-to-one: always create single CV
574
+ triplets = [(0, 0, 0)]
575
+ elif n1 == 1 and n2 == 1:
576
+ # One-one-to-many: pair first of x1/x2 with all of x3
577
+ triplets = [(0, 0, k) for k in range(n3)]
578
+ elif n1 == 1 and n3 == 1:
579
+ # One-many-to-one: pair first of x1/x3 with all of x2
580
+ triplets = [(0, j, 0) for j in range(n2)]
581
+ elif n2 == 1 and n3 == 1:
582
+ # Many-to-one-one: pair all of x1 with first of x2/x3
583
+ triplets = [(i, 0, 0) for i in range(n1)]
584
+ else:
585
+ # Multi-way: apply strategy
586
+ if self.strategy == "first":
587
+ triplets = [(0, 0, 0)] if n1 > 0 and n2 > 0 and n3 > 0 else []
588
+ elif self.strategy == "all":
589
+ triplets = [(i, j, k) for i in range(n1) for j in range(n2) for k in range(n3)]
590
+ elif self.strategy == "diagonal":
591
+ n_triplets = min(n1, n2, n3)
592
+ triplets = [(i, i, i) for i in range(n_triplets)]
593
+ elif self.strategy == "none":
594
+ raise ValueError(
595
+ f"Multiple groups in x1/x2/x3 ({n1}, {n2}, {n3}). "
596
+ f"Use strategy='all' or 'diagonal', or select specific groups with indexing."
597
+ )
598
+ else:
599
+ raise ValueError(f"Unknown strategy: {self.strategy}")
312
600
 
313
- def _format(site):
314
- return ",".join(map(str, (s + 1 for s in site))) if isinstance(site, list) else site
601
+ # Generate ANGLE commands
602
+ cv_labels = []
603
+ commands = []
604
+ for idx, (i, j, k) in enumerate(triplets):
605
+ if len(triplets) == 1:
606
+ label = self.prefix
607
+ else:
608
+ label = f"{self.prefix}_{i}_{j}_{k}"
315
609
 
316
- s1_str, s2_str = _format(site1), _format(site2)
317
- # Use ATOMS for point-like sites, ATOMS1/ATOMS2 for group-based distances
318
- if isinstance(site1, str) and isinstance(site2, str):
319
- return f"{label}: DISTANCE ATOMS={s1_str},{s2_str}"
320
- return f"{label}: DISTANCE ATOMS1={s1_str} ATOMS2={s2_str}"
610
+ # Create ANGLE command (ATOMS=x1,x2,x3 where x2 is vertex)
611
+ cmd = f"{label}: ANGLE ATOMS={labels1[i]},{labels2[j]},{labels3[k]}"
612
+ commands.append(cmd)
613
+ cv_labels.append(label)
614
+
615
+ return cv_labels, commands
321
616
 
322
617
 
323
618
  @dataclass
@@ -325,58 +620,76 @@ class CoordinationNumberCV(_BasePlumedCV):
325
620
  """
326
621
  PLUMED COORDINATION collective variable.
327
622
 
328
- Calculates a coordination number based on a switching function. It supports
329
- complex group definitions, including groups of virtual sites.
330
-
331
- Attributes:
332
- x1, x2: Selectors for the two groups of atoms.
333
- prefix: Label prefix for the generated PLUMED commands.
334
- r_0: The reference distance for the switching function (in Angstroms).
335
- nn, mm, d_0: Parameters for the switching function.
336
- group_reduction_1, group_reduction_2: Reduction strategies for each group.
337
- multi_group: Strategy for handling multiple groups from selectors.
338
- create_virtual_sites: If True, create explicit virtual sites for COM/COG.
339
-
340
- Resources:
341
- - https://www.plumed.org/doc-master/user-doc/html/COORDINATION.html
342
- - https://www.plumed.org/doc-master/user-doc/html/GROUP.html
623
+ Calculates a coordination number based on a switching function using the new
624
+ VirtualAtom API. The coordination number is computed between two groups of atoms
625
+ using a switching function.
626
+
627
+ Parameters
628
+ ----------
629
+ x1 : AtomSelector | VirtualAtom
630
+ First group of atoms. Can be an AtomSelector or VirtualAtom.
631
+ x2 : AtomSelector | VirtualAtom
632
+ Second group of atoms. Can be an AtomSelector or VirtualAtom.
633
+ prefix : str
634
+ Label prefix for the generated PLUMED commands.
635
+ r_0 : float
636
+ Reference distance for the switching function (in Angstroms).
637
+ nn : int, default=6
638
+ Exponent for the switching function numerator.
639
+ mm : int, default=0
640
+ Exponent for the switching function denominator.
641
+ d_0 : float, default=0.0
642
+ Offset for the switching function (in Angstroms).
643
+ flatten : bool, default=True
644
+ How to handle AtomSelector inputs:
645
+ - True: Flatten all groups into a single GROUP
646
+ - False: Create a GROUP for each selector group
647
+ pairwise : {"all", "diagonal", "none"}, default="all"
648
+ Strategy for pairing multiple groups:
649
+ - "all": All pairwise combinations (N×M CVs)
650
+ - "diagonal": Pair by index (min(N,M) CVs)
651
+ - "none": Raise error if both have multiple groups
652
+
653
+ Resources
654
+ ---------
655
+ - https://www.plumed.org/doc-master/user-doc/html/COORDINATION
656
+ - https://www.plumed.org/doc-master/user-doc/html/GROUP
343
657
  """
344
658
 
345
- x1: AtomSelector
346
- x2: AtomSelector
659
+ x1: AtomSelector | VirtualAtom
660
+ x2: AtomSelector | VirtualAtom
347
661
  prefix: str
348
662
  r_0: float
349
663
  nn: int = 6
350
664
  mm: int = 0
351
665
  d_0: float = 0.0
352
- group_reduction_1: GroupReductionStrategyType = "all"
353
- group_reduction_2: GroupReductionStrategyType = "all"
354
- multi_group: MultiGroupStrategyType = "first"
355
- create_virtual_sites: bool = True
666
+ flatten: bool = True
667
+ pairwise: Literal["all", "diagonal", "none"] = "all"
356
668
 
357
669
  def _get_atom_highlights(
358
670
  self, atoms: Atoms, **kwargs
359
671
  ) -> Optional[AtomHighlightMap]:
360
- highlight_hydrogens = kwargs.get("highlight_hydrogens", False)
672
+ """Get atom highlights for visualization."""
673
+ # Skip for VirtualAtom inputs
674
+ if isinstance(self.x1, VirtualAtom) or isinstance(self.x2, VirtualAtom):
675
+ return None
676
+
361
677
  groups1 = self.x1.select(atoms)
362
678
  groups2 = self.x2.select(atoms)
363
679
 
364
680
  if not groups1 or not groups2:
365
681
  return None
366
682
 
367
- # Flatten groups and optionally filter out hydrogens.
368
- indices1 = {idx for g in groups1 for idx in g}
369
- indices2 = {idx for g in groups2 for idx in g}
370
- if not highlight_hydrogens:
371
- indices1 = {i for i in indices1 if atoms[i].symbol != "H"}
372
- indices2 = {i for i in indices2 if atoms[i].symbol != "H"}
683
+ # Highlight all atoms from both selections
684
+ indices1 = {idx for group in groups1 for idx in group}
685
+ indices2 = {idx for group in groups2 for idx in group}
373
686
 
374
687
  if not indices1 and not indices2:
375
688
  return None
376
689
 
377
- # Color atoms based on group membership, with purple for overlaps.
690
+ # Color atoms based on group membership
378
691
  highlights: AtomHighlightMap = {}
379
- red, blue, purple = (1.0, 0.5, 0.5), (0.5, 0.5, 1.0), (1.0, 0.5, 1.0)
692
+ red, blue, purple = (1.0, 0.2, 0.2), (0.2, 0.2, 1.0), (1.0, 0.2, 1.0)
380
693
  for idx in indices1.union(indices2):
381
694
  in1, in2 = idx in indices1, idx in indices2
382
695
  if in1 and in2:
@@ -388,117 +701,179 @@ class CoordinationNumberCV(_BasePlumedCV):
388
701
  return highlights
389
702
 
390
703
  def to_plumed(self, atoms: Atoms) -> Tuple[List[str], List[str]]:
704
+ """Generate PLUMED COORDINATION command(s).
705
+
706
+ Returns
707
+ -------
708
+ labels : list[str]
709
+ List of CV labels created.
710
+ commands : list[str]
711
+ List of PLUMED commands.
391
712
  """
392
- Generates PLUMED input strings for the COORDINATION CV.
713
+ # Process both inputs to get group labels
714
+ labels1, cmds1 = self._process_coordination_input(self.x1, atoms, "x1")
715
+ labels2, cmds2 = self._process_coordination_input(self.x2, atoms, "x2")
393
716
 
394
- Returns:
395
- A tuple containing a list of CV labels and a list of PLUMED commands.
717
+ commands = []
718
+ commands.extend(cmds1)
719
+ commands.extend(cmds2)
720
+
721
+ # Generate COORDINATION commands
722
+ cv_labels, cv_commands = self._generate_coordination_cvs(labels1, labels2)
723
+ commands.extend(cv_commands)
724
+
725
+ return cv_labels, commands
726
+
727
+ def _process_coordination_input(
728
+ self, input_obj: AtomSelector | VirtualAtom, atoms: Atoms, label_prefix: str
729
+ ) -> Tuple[List[str], List[str]]:
730
+ """Process input for COORDINATION and return group labels/commands.
731
+
732
+ For COORDINATION, we need groups (not individual points), so the processing
733
+ is different from DistanceCV:
734
+ - VirtualAtom with multiple sites → create GROUP of those sites
735
+ - VirtualAtom with single site → use site directly
736
+ - AtomSelector with flatten=True → create single group with all atoms
737
+ - AtomSelector with flatten=False → create GROUP for each selector group
738
+
739
+ Returns
740
+ -------
741
+ labels : list[str]
742
+ Group labels that can be used in COORDINATION GROUPA/GROUPB.
743
+ commands : list[str]
744
+ PLUMED commands to create those groups.
396
745
  """
397
- groups1 = self.x1.select(atoms)
398
- groups2 = self.x2.select(atoms)
399
-
400
- if not groups1 or not groups2:
401
- raise ValueError(f"Empty selection for coordination CV '{self.prefix}'")
402
-
403
- commands = self._generate_commands(groups1, groups2)
404
- labels = self._extract_labels(commands, self.prefix, "COORDINATION")
405
- return labels, commands
406
-
407
- def _generate_commands(
408
- self, groups1: List[List[int]], groups2: List[List[int]]
409
- ) -> List[str]:
410
- """Generates all necessary PLUMED commands."""
411
- commands: List[str] = []
412
-
413
- sites1 = self._reduce_groups(
414
- groups1, self.group_reduction_1, f"{self.prefix}_g1", commands
415
- )
416
- sites2 = self._reduce_groups(
417
- groups2, self.group_reduction_2, f"{self.prefix}_g2", commands
418
- )
419
-
420
- # Get site pairs using a simplified helper
421
- site_pairs = []
422
- if self.multi_group == "first":
423
- site_pairs = [(sites1[0], sites2[0])] if sites1 and sites2 else []
424
- elif self.multi_group == "all_pairs":
425
- site_pairs = [(s1, s2) for s1 in sites1 for s2 in sites2]
426
- elif self.multi_group == "corresponding":
427
- n = min(len(sites1), len(sites2))
428
- site_pairs = [(sites1[i], sites2[i]) for i in range(n)]
429
- elif self.multi_group == "first_to_all":
430
- site_pairs = [(sites1[0], s2) for s2 in sites2] if sites1 else []
431
-
432
- for i, (s1, s2) in enumerate(site_pairs):
433
- label = self.prefix if len(site_pairs) == 1 else f"{self.prefix}_{i}"
434
- commands.append(self._make_coordination_command(s1, s2, label))
746
+ if isinstance(input_obj, VirtualAtom):
747
+ # Set deterministic label if not already set
748
+ if input_obj.label is None:
749
+ labeled_va = dataclasses.replace(
750
+ input_obj, label=f"{self.prefix}_{label_prefix}"
751
+ )
752
+ else:
753
+ labeled_va = input_obj
435
754
 
436
- return commands
755
+ # Get virtual site labels
756
+ vsite_labels, vsite_commands = labeled_va.to_plumed(atoms)
437
757
 
438
- def _reduce_groups(
439
- self,
440
- groups: List[List[int]],
441
- strategy: GroupReductionStrategyType,
442
- site_prefix: str,
443
- commands: List[str],
444
- ) -> List[SiteIdentifier]:
445
- """Reduces a list of atom groups into a list of site identifiers."""
446
- if strategy in ["com_per_group", "cog_per_group"]:
447
- if not self.create_virtual_sites:
448
- raise ValueError(f"'{strategy}' requires create_virtual_sites=True")
449
-
450
- reduction_type = "COM" if strategy == "com_per_group" else "CENTER"
451
- vsite_labels = []
452
- for i, group in enumerate(groups):
453
- if not group:
454
- continue
455
- vsite_label = f"{site_prefix}_{i}"
456
- atom_list = ",".join(str(idx + 1) for idx in group)
457
- commands.append(f"{vsite_label}: {reduction_type} ATOMS={atom_list}")
458
- vsite_labels.append(vsite_label)
459
-
460
- group_label = f"{site_prefix}_group"
461
- commands.append(f"{group_label}: GROUP ATOMS={','.join(vsite_labels)}")
462
- return [group_label]
463
-
464
- if strategy == "all":
465
- return [sorted({idx for group in groups for idx in group})]
466
-
467
- # Handle other strategies by reducing each group individually.
468
- sites: List[SiteIdentifier] = []
469
- for i, group in enumerate(groups):
470
- if len(group) == 1 or strategy == "first":
471
- sites.append(str(group[0] + 1))
472
- elif strategy in ["com", "cog"]:
473
- if self.create_virtual_sites:
474
- label = f"{site_prefix}_{i}_{strategy}"
475
- cmd = self._create_virtual_site_command(group, strategy, label)
476
- commands.append(cmd)
477
- sites.append(label)
758
+ # If multiple virtual sites, create a GROUP of them
759
+ if len(vsite_labels) > 1:
760
+ group_label = f"{self.prefix}_{label_prefix}_group"
761
+ group_cmd = f"{group_label}: GROUP ATOMS={','.join(vsite_labels)}"
762
+ return [group_label], vsite_commands + [group_cmd]
763
+ else:
764
+ # Single virtual site, use directly
765
+ return vsite_labels, vsite_commands
766
+ else:
767
+ # AtomSelector: create group(s) based on flatten parameter
768
+ groups = input_obj.select(atoms)
769
+ if not groups:
770
+ return [], []
771
+
772
+ if self.flatten:
773
+ # Flatten all groups into single group
774
+ flat_atoms = [idx for group in groups for idx in group]
775
+ # Return as list of atom indices (will be formatted in COORDINATION command)
776
+ return [flat_atoms], []
777
+ else:
778
+ # Smart GROUP creation: only create GROUP for multi-atom groups
779
+ labels = []
780
+ commands = []
781
+ for i, group in enumerate(groups):
782
+ if len(group) == 1:
783
+ # Single atom: use directly (no GROUP needed)
784
+ labels.append(str(group[0] + 1))
785
+ else:
786
+ # Multi-atom group: create GROUP
787
+ group_label = f"{self.prefix}_{label_prefix}_g{i}"
788
+ atom_list = ",".join(str(idx + 1) for idx in group)
789
+ commands.append(f"{group_label}: GROUP ATOMS={atom_list}")
790
+ labels.append(group_label)
791
+
792
+ # If multiple groups, create a parent GROUP
793
+ if len(labels) > 1:
794
+ parent_label = f"{self.prefix}_{label_prefix}_group"
795
+ parent_cmd = f"{parent_label}: GROUP ATOMS={','.join(labels)}"
796
+ return [parent_label], commands + [parent_cmd]
478
797
  else:
479
- sites.append(group)
798
+ return labels, commands
799
+
800
+ def _generate_coordination_cvs(
801
+ self, labels1: List[str | List[int]], labels2: List[str | List[int]]
802
+ ) -> Tuple[List[str], List[str]]:
803
+ """Generate COORDINATION CV commands.
804
+
805
+ Parameters
806
+ ----------
807
+ labels1, labels2 : list[str | list[int]]
808
+ Group labels or atom index lists for GROUPA and GROUPB.
809
+
810
+ Returns
811
+ -------
812
+ cv_labels : list[str]
813
+ Labels for the COORDINATION CVs created.
814
+ commands : list[str]
815
+ COORDINATION command strings.
816
+ """
817
+ n1, n2 = len(labels1), len(labels2)
818
+
819
+ # Determine which pairs to create based on pairwise strategy
820
+ if n1 == 1 and n2 == 1:
821
+ # One-to-one: always create single CV
822
+ pairs = [(0, 0)]
823
+ elif n1 == 1:
824
+ # One-to-many: pair first of x1 with all of x2
825
+ pairs = [(0, j) for j in range(n2)]
826
+ elif n2 == 1:
827
+ # Many-to-one: pair all of x1 with first of x2
828
+ pairs = [(i, 0) for i in range(n1)]
829
+ else:
830
+ # Many-to-many: apply pairwise strategy
831
+ if self.pairwise == "all":
832
+ pairs = [(i, j) for i in range(n1) for j in range(n2)]
833
+ elif self.pairwise == "diagonal":
834
+ n_pairs = min(n1, n2)
835
+ pairs = [(i, i) for i in range(n_pairs)]
836
+ elif self.pairwise == "none":
837
+ raise ValueError(
838
+ f"Both x1 and x2 have multiple groups ({n1} and {n2}). "
839
+ f"Use pairwise='all' or 'diagonal', or select specific groups with indexing."
840
+ )
480
841
  else:
481
- raise ValueError(f"Unsupported reduction strategy: {strategy}")
482
- return sites
842
+ raise ValueError(f"Unknown pairwise strategy: {self.pairwise}")
483
843
 
484
- def _make_coordination_command(
485
- self, site1: SiteIdentifier, site2: SiteIdentifier, label: str
486
- ) -> str:
487
- """Creates a single PLUMED COORDINATION command string."""
844
+ # Generate COORDINATION commands
845
+ cv_labels = []
846
+ commands = []
847
+ for idx, (i, j) in enumerate(pairs):
848
+ if len(pairs) == 1:
849
+ label = self.prefix
850
+ else:
851
+ label = f"{self.prefix}_{idx}"
852
+
853
+ # Format group labels for COORDINATION
854
+ def format_group(g):
855
+ if isinstance(g, list): # List of atom indices
856
+ return ",".join(str(idx + 1) for idx in g)
857
+ else: # String label
858
+ return g
859
+
860
+ g_a = format_group(labels1[i])
861
+ g_b = format_group(labels2[j])
488
862
 
489
- def _format(site):
490
- return ",".join(map(str, (s + 1 for s in site))) if isinstance(site, list) else site
863
+ # Create COORDINATION command
864
+ cmd = f"{label}: COORDINATION GROUPA={g_a}"
865
+ if g_a != g_b: # Omit GROUPB for self-coordination
866
+ cmd += f" GROUPB={g_b}"
491
867
 
492
- g_a, g_b = _format(site1), _format(site2)
493
- base_cmd = f"{label}: COORDINATION GROUPA={g_a}"
494
- if g_a != g_b: # Omit GROUPB for self-coordination
495
- base_cmd += f" GROUPB={g_b}"
868
+ # Add parameters
869
+ cmd += f" R_0={self.r_0} NN={self.nn} D_0={self.d_0}"
870
+ if self.mm != 0:
871
+ cmd += f" MM={self.mm}"
496
872
 
497
- params = f" R_0={self.r_0} NN={self.nn} D_0={self.d_0}"
498
- if self.mm != 0:
499
- params += f" MM={self.mm}"
873
+ commands.append(cmd)
874
+ cv_labels.append(label)
500
875
 
501
- return base_cmd + params
876
+ return cv_labels, commands
502
877
 
503
878
 
504
879
  @dataclass
@@ -509,18 +884,25 @@ class TorsionCV(_BasePlumedCV):
509
884
  Calculates the torsional (dihedral) angle defined by four atoms. Each group
510
885
  provided by the selector must contain exactly four atoms.
511
886
 
512
- Attributes:
513
- atoms: Selector for one or more groups of 4 atoms.
514
- prefix: Label prefix for the generated PLUMED commands.
515
- multi_group: Strategy for handling multiple groups from the selector.
516
-
517
- Resources:
518
- - https://www.plumed.org/doc-master/user-doc/html/TORSION.html
887
+ Parameters
888
+ ----------
889
+ atoms : AtomSelector
890
+ Selector for one or more groups of 4 atoms. Each group must contain exactly 4 atoms.
891
+ prefix : str
892
+ Label prefix for the generated PLUMED commands.
893
+ strategy : {"first", "all"}, default="first"
894
+ Strategy for handling multiple groups from the selector:
895
+ - "first": Process only the first group (creates 1 CV)
896
+ - "all": Process all groups independently (creates N CVs)
897
+
898
+ Resources
899
+ ---------
900
+ - https://www.plumed.org/doc-master/user-doc/html/TORSION
519
901
  """
520
902
 
521
903
  atoms: AtomSelector
522
904
  prefix: str
523
- multi_group: MultiGroupStrategyType = "first"
905
+ strategy: Literal["first", "all"] = "first"
524
906
 
525
907
  def _get_atom_highlights(
526
908
  self, atoms: Atoms, **kwargs
@@ -563,10 +945,10 @@ class TorsionCV(_BasePlumedCV):
563
945
 
564
946
  def _generate_commands(self, groups: List[List[int]]) -> List[str]:
565
947
  """Generates all necessary PLUMED commands."""
566
- # For torsions, 'multi_group' determines how many groups to process.
567
- if self.multi_group in ["first", "first_to_all"] and groups:
948
+ # Determine which groups to process based on strategy
949
+ if self.strategy == "first" and groups:
568
950
  indices_to_process = [0]
569
- else: # "all_pairs" and "corresponding" imply processing all independent groups.
951
+ else: # "all" - process all groups independently
570
952
  indices_to_process = list(range(len(groups)))
571
953
 
572
954
  commands = []
@@ -586,19 +968,33 @@ class RadiusOfGyrationCV(_BasePlumedCV):
586
968
  Calculates the radius of gyration of a group of atoms. The radius of gyration
587
969
  is a measure of the size of a molecular system.
588
970
 
589
- Attributes:
590
- atoms: Selector for the atoms to include in the gyration calculation.
591
- prefix: Label prefix for the generated PLUMED commands.
592
- multi_group: Strategy for handling multiple groups from the selector.
593
- type: The type of gyration tensor to use ("RADIUS" for scalar Rg, "GTPC_1", etc.)
594
-
595
- Resources:
596
- - https://www.plumed.org/doc-master/user-doc/html/GYRATION/
971
+ Parameters
972
+ ----------
973
+ atoms : AtomSelector
974
+ Selector for the atoms to include in the gyration calculation.
975
+ prefix : str
976
+ Label prefix for the generated PLUMED commands.
977
+ flatten : bool, default=False
978
+ How to handle multiple groups from the selector:
979
+ - True: Combine all groups into one and calculate single Rg (creates 1 CV)
980
+ - False: Keep groups separate, use strategy to determine which to process
981
+ strategy : {"first", "all"}, default="first"
982
+ Strategy for handling multiple groups when flatten=False:
983
+ - "first": Process only the first group (creates 1 CV)
984
+ - "all": Process all groups independently (creates N CVs)
985
+ type : str, default="RADIUS"
986
+ The type of gyration tensor to use.
987
+ Options: "RADIUS", "GTPC_1", "GTPC_2", "GTPC_3", "ASPHERICITY", "ACYLINDRICITY", "KAPPA2", etc.
988
+
989
+ Resources
990
+ ---------
991
+ - https://www.plumed.org/doc-master/user-doc/html/GYRATION/
597
992
  """
598
993
 
599
994
  atoms: AtomSelector
600
995
  prefix: str
601
- multi_group: MultiGroupStrategyType = "first"
996
+ flatten: bool = False
997
+ strategy: Literal["first", "all"] = "first"
602
998
  type: str = "RADIUS" # Options: RADIUS, GTPC_1, GTPC_2, GTPC_3, ASPHERICITY, ACYLINDRICITY, KAPPA2, etc.
603
999
 
604
1000
  def _get_atom_highlights(
@@ -629,18 +1025,29 @@ class RadiusOfGyrationCV(_BasePlumedCV):
629
1025
 
630
1026
  def _generate_commands(self, groups: List[List[int]]) -> List[str]:
631
1027
  """Generates all necessary PLUMED commands."""
632
- # For gyration, 'multi_group' determines how many groups to process.
633
- if self.multi_group in ["first", "first_to_all"] and groups:
634
- indices_to_process = [0]
635
- else: # "all_pairs" and "corresponding" imply processing all independent groups.
636
- indices_to_process = list(range(len(groups)))
637
-
638
1028
  commands = []
639
- for i in indices_to_process:
640
- label = self.prefix if len(indices_to_process) == 1 else f"{self.prefix}_{i}"
641
- atom_list = ",".join(str(idx + 1) for idx in groups[i])
642
- command = f"{label}: GYRATION ATOMS={atom_list}"
1029
+
1030
+ if self.flatten:
1031
+ # Combine all groups into single atom list
1032
+ flat_atoms = [idx for group in groups for idx in group]
1033
+ atom_list = ",".join(str(idx + 1) for idx in flat_atoms)
1034
+ command = f"{self.prefix}: GYRATION ATOMS={atom_list}"
643
1035
  if self.type != "RADIUS":
644
1036
  command += f" TYPE={self.type}"
645
1037
  commands.append(command)
1038
+ else:
1039
+ # Keep groups separate and use strategy to determine which to process
1040
+ if self.strategy == "first" and groups:
1041
+ indices_to_process = [0]
1042
+ else: # "all" - process all groups independently
1043
+ indices_to_process = list(range(len(groups)))
1044
+
1045
+ for i in indices_to_process:
1046
+ label = self.prefix if len(indices_to_process) == 1 else f"{self.prefix}_{i}"
1047
+ atom_list = ",".join(str(idx + 1) for idx in groups[i])
1048
+ command = f"{label}: GYRATION ATOMS={atom_list}"
1049
+ if self.type != "RADIUS":
1050
+ command += f" TYPE={self.type}"
1051
+ commands.append(command)
1052
+
646
1053
  return commands