hillclimber 0.1.0a1__py3-none-any.whl → 0.1.0a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hillclimber might be problematic; more details are linked from the original diff page.

hillclimber/cvs.py CHANGED
@@ -1,5 +1,7 @@
1
1
  # --- IMPORTS ---
2
2
  # Standard library
3
+ from __future__ import annotations
4
+ import dataclasses
3
5
  from dataclasses import dataclass
4
6
  from typing import Dict, List, Literal, Optional, Tuple, Union
5
7
 
@@ -12,6 +14,7 @@ from rdkit.Chem import Draw
12
14
 
13
15
  # Local
14
16
  from hillclimber.interfaces import AtomSelector, CollectiveVariable
17
+ from hillclimber.virtual_atoms import VirtualAtom
15
18
 
16
19
 
17
20
  # --- TYPE HINTS ---
@@ -163,61 +166,97 @@ class DistanceCV(_BasePlumedCV):
163
166
  """
164
167
  PLUMED DISTANCE collective variable.
165
168
 
166
- Calculates the distance between two atoms or groups of atoms. This CV supports
167
- various strategies for reducing groups to single points (e.g., center of mass)
168
- and for pairing multiple groups.
169
-
170
- Attributes:
171
- x1: Selector for the first atom/group.
172
- x2: Selector for the second atom/group.
173
- prefix: Label prefix for the generated PLUMED commands.
174
- group_reduction: Strategy to reduce an atom group to a single point.
175
- multi_group: Strategy for handling multiple groups from selectors.
176
- create_virtual_sites: If True, create explicit virtual sites for COM/COG.
177
-
178
- Resources:
179
- - https://www.plumed.org/doc-master/user-doc/html/DISTANCE.html
169
+ Calculates the distance between two atoms, groups of atoms, or virtual sites.
170
+ Supports flexible flattening and pairing strategies for multiple groups.
171
+
172
+ Parameters
173
+ ----------
174
+ x1 : AtomSelector | VirtualAtom
175
+ First atom/group or virtual site.
176
+ x2 : AtomSelector | VirtualAtom
177
+ Second atom/group or virtual site.
178
+ prefix : str
179
+ Label prefix for generated PLUMED commands.
180
+ flatten : bool, default=True
181
+ For AtomSelectors only: If True, flatten all groups into single atom list.
182
+ If False, create PLUMED GROUP for each group. VirtualAtoms are never flattened.
183
+ pairwise : {"all", "diagonal", "none"}, default="all"
184
+ Strategy for pairing multiple groups:
185
+ - "all": Create all N×M pair combinations (can create many CVs!)
186
+ - "diagonal": Pair corresponding indices only (creates min(N,M) CVs)
187
+ - "none": Error if both sides have multiple groups (safety check)
188
+
189
+ Examples
190
+ --------
191
+ >>> # Distance between two specific atoms
192
+ >>> dist = hc.DistanceCV(
193
+ ... x1=ethanol_sel[0][0], # First atom of first ethanol
194
+ ... x2=water_sel[0][0], # First atom of first water
195
+ ... prefix="d_atoms"
196
+ ... )
197
+
198
+ >>> # Distance between molecule COMs
199
+ >>> dist = hc.DistanceCV(
200
+ ... x1=hc.VirtualAtom(ethanol_sel[0], "com"),
201
+ ... x2=hc.VirtualAtom(water_sel[0], "com"),
202
+ ... prefix="d_com"
203
+ ... )
204
+
205
+ >>> # One-to-many: First ethanol COM to all water COMs
206
+ >>> dist = hc.DistanceCV(
207
+ ... x1=hc.VirtualAtom(ethanol_sel[0], "com"),
208
+ ... x2=hc.VirtualAtom(water_sel, "com"),
209
+ ... prefix="d",
210
+ ... pairwise="all" # Creates 3 CVs
211
+ ... )
212
+
213
+ >>> # Diagonal pairing (avoid explosion)
214
+ >>> dist = hc.DistanceCV(
215
+ ... x1=hc.VirtualAtom(water_sel, "com"), # 3 waters
216
+ ... x2=hc.VirtualAtom(ethanol_sel, "com"), # 2 ethanols
217
+ ... prefix="d",
218
+ ... pairwise="diagonal" # Creates only 2 CVs: d_0, d_1
219
+ ... )
220
+
221
+ Resources
222
+ ---------
223
+ - https://www.plumed.org/doc-master/user-doc/html/DISTANCE.html
224
+
225
+ Notes
226
+ -----
227
+ For backwards compatibility, old parameters are still supported but deprecated:
228
+ - `group_reduction` → Use VirtualAtom instead
229
+ - `multi_group` → Use `pairwise` parameter
180
230
  """
181
231
 
182
- x1: AtomSelector
183
- x2: AtomSelector
232
+ x1: AtomSelector | VirtualAtom
233
+ x2: AtomSelector | VirtualAtom
184
234
  prefix: str
185
- group_reduction: GroupReductionStrategyType = "com"
186
- multi_group: MultiGroupStrategyType = "first"
187
- create_virtual_sites: bool = True
235
+ flatten: bool = True
236
+ pairwise: Literal["all", "diagonal", "none"] = "all"
188
237
 
189
238
  def _get_atom_highlights(
190
239
  self, atoms: Atoms, **kwargs
191
240
  ) -> Optional[AtomHighlightMap]:
241
+ """Get atom highlights for visualization."""
242
+ # Skip for VirtualAtom inputs
243
+ if isinstance(self.x1, VirtualAtom) or isinstance(self.x2, VirtualAtom):
244
+ return None
245
+
192
246
  groups1 = self.x1.select(atoms)
193
247
  groups2 = self.x2.select(atoms)
194
248
 
195
249
  if not groups1 or not groups2:
196
250
  return None
197
251
 
198
- index_pairs = self._get_index_pairs(len(groups1), len(groups2), self.multi_group)
199
- if not index_pairs:
200
- return None
201
-
202
- # Correctly select atoms based on the group_reduction strategy
203
- indices1, indices2 = set(), set()
204
- for i, j in index_pairs:
205
- # Handle the 'first' atom case specifically for highlighting
206
- if self.group_reduction == "first":
207
- # Ensure the group is not empty before accessing the first element
208
- if groups1[i]:
209
- indices1.add(groups1[i][0])
210
- if groups2[j]:
211
- indices2.add(groups2[j][0])
212
- # For other strategies (com, cog, all), highlight the whole group
213
- else:
214
- indices1.update(groups1[i])
215
- indices2.update(groups2[j])
252
+ # Highlight all atoms from both selections
253
+ indices1 = {idx for group in groups1 for idx in group}
254
+ indices2 = {idx for group in groups2 for idx in group}
216
255
 
217
256
  if not indices1 and not indices2:
218
257
  return None
219
258
 
220
- # Color atoms based on group membership, with purple for overlaps.
259
+ # Color atoms based on group membership
221
260
  highlights: AtomHighlightMap = {}
222
261
  red, blue, purple = (1.0, 0.2, 0.2), (0.2, 0.2, 1.0), (1.0, 0.2, 1.0)
223
262
  for idx in indices1.union(indices2):
@@ -231,93 +270,366 @@ class DistanceCV(_BasePlumedCV):
231
270
  return highlights
232
271
 
233
272
  def to_plumed(self, atoms: Atoms) -> Tuple[List[str], List[str]]:
273
+ """Generate PLUMED input strings for the DISTANCE CV.
274
+
275
+ Returns
276
+ -------
277
+ labels : list[str]
278
+ List of CV labels generated.
279
+ commands : list[str]
280
+ List of PLUMED command strings.
234
281
  """
235
- Generates PLUMED input strings for the DISTANCE CV.
282
+ commands = []
236
283
 
237
- Returns:
238
- A tuple containing a list of CV labels and a list of PLUMED commands.
284
+ # Process x1
285
+ labels1, cmds1 = self._process_input(self.x1, atoms, "x1")
286
+ commands.extend(cmds1)
287
+
288
+ # Process x2
289
+ labels2, cmds2 = self._process_input(self.x2, atoms, "x2")
290
+ commands.extend(cmds2)
291
+
292
+ # Check for empty selections
293
+ if not labels1 or not labels2:
294
+ raise ValueError(f"Empty selection for distance CV '{self.prefix}'")
295
+
296
+ # Generate distance CVs based on pairwise strategy
297
+ cv_labels, cv_commands = self._generate_distance_cvs(labels1, labels2)
298
+ commands.extend(cv_commands)
299
+
300
+ return cv_labels, commands
301
+
302
+ def _process_input(
303
+ self, input_obj: AtomSelector | VirtualAtom, atoms: Atoms, label_prefix: str
304
+ ) -> Tuple[List[str], List[str]]:
305
+ """Process an input (AtomSelector or VirtualAtom) and return labels and commands.
306
+
307
+ Returns
308
+ -------
309
+ labels : list[str]
310
+ List of labels for this input (either virtual site labels or GROUP labels).
311
+ commands : list[str]
312
+ PLUMED commands to create the labels.
239
313
  """
314
+ if isinstance(input_obj, VirtualAtom):
315
+ # VirtualAtom: set deterministic label if not already set
316
+ if input_obj.label is None:
317
+ # Set label based on prefix and label_prefix (x1 or x2)
318
+ labeled_va = dataclasses.replace(
319
+ input_obj, label=f"{self.prefix}_{label_prefix}"
320
+ )
321
+ return labeled_va.to_plumed(atoms)
322
+ else:
323
+ return input_obj.to_plumed(atoms)
324
+ else:
325
+ # AtomSelector: handle based on flatten parameter
326
+ groups = input_obj.select(atoms)
327
+ if not groups:
328
+ return [], []
329
+
330
+ if self.flatten:
331
+ # Flatten all groups into single list
332
+ flat_atoms = [idx for group in groups for idx in group]
333
+ atom_list = ",".join(str(idx + 1) for idx in flat_atoms)
334
+ # Return as pseudo-label (will be used directly in DISTANCE command)
335
+ return [atom_list], []
336
+ else:
337
+ # Smart GROUP creation: only create GROUP for multi-atom groups
338
+ labels = []
339
+ commands = []
340
+ for i, group in enumerate(groups):
341
+ if len(group) == 1:
342
+ # Single atom: use directly (no GROUP needed)
343
+ labels.append(str(group[0] + 1))
344
+ else:
345
+ # Multi-atom group: create GROUP
346
+ group_label = f"{self.prefix}_{label_prefix}_g{i}"
347
+ atom_list = ",".join(str(idx + 1) for idx in group)
348
+ commands.append(f"{group_label}: GROUP ATOMS={atom_list}")
349
+ labels.append(group_label)
350
+ return labels, commands
351
+
352
+ def _generate_distance_cvs(
353
+ self, labels1: List[str], labels2: List[str]
354
+ ) -> Tuple[List[str], List[str]]:
355
+ """Generate DISTANCE CV commands based on pairwise strategy."""
356
+ n1, n2 = len(labels1), len(labels2)
357
+
358
+ # Determine which pairs to create based on pairwise strategy
359
+ if n1 == 1 and n2 == 1:
360
+ # One-to-one: always create single CV
361
+ pairs = [(0, 0)]
362
+ elif n1 == 1:
363
+ # One-to-many: pair first of x1 with all of x2
364
+ pairs = [(0, j) for j in range(n2)]
365
+ elif n2 == 1:
366
+ # Many-to-one: pair all of x1 with first of x2
367
+ pairs = [(i, 0) for i in range(n1)]
368
+ else:
369
+ # Many-to-many: apply pairwise strategy
370
+ if self.pairwise == "all":
371
+ pairs = [(i, j) for i in range(n1) for j in range(n2)]
372
+ elif self.pairwise == "diagonal":
373
+ n_pairs = min(n1, n2)
374
+ pairs = [(i, i) for i in range(n_pairs)]
375
+ elif self.pairwise == "none":
376
+ raise ValueError(
377
+ f"Both x1 and x2 have multiple groups ({n1} and {n2}). "
378
+ f"Use pairwise='all' or 'diagonal', or select specific groups with indexing."
379
+ )
380
+ else:
381
+ raise ValueError(f"Unknown pairwise strategy: {self.pairwise}")
382
+
383
+ # Generate DISTANCE commands
384
+ cv_labels = []
385
+ commands = []
386
+ for idx, (i, j) in enumerate(pairs):
387
+ if len(pairs) == 1:
388
+ label = self.prefix
389
+ else:
390
+ label = f"{self.prefix}_{idx}"
391
+
392
+ # Create DISTANCE command
393
+ cmd = f"{label}: DISTANCE ATOMS={labels1[i]},{labels2[j]}"
394
+ commands.append(cmd)
395
+ cv_labels.append(label)
396
+
397
+ return cv_labels, commands
398
+
399
+
400
+ @dataclass
401
+ class AngleCV(_BasePlumedCV):
402
+ """
403
+ PLUMED ANGLE collective variable.
404
+
405
+ Calculates the angle formed by three atoms or groups of atoms using the new
406
+ VirtualAtom API. The angle is computed as the angle between the vectors
407
+ (x1-x2) and (x3-x2), where x2 is the vertex of the angle.
408
+
409
+ Parameters
410
+ ----------
411
+ x1 : AtomSelector | VirtualAtom
412
+ First position. Can be an AtomSelector or VirtualAtom.
413
+ x2 : AtomSelector | VirtualAtom
414
+ Vertex position (center of the angle). Can be an AtomSelector or VirtualAtom.
415
+ x3 : AtomSelector | VirtualAtom
416
+ Third position. Can be an AtomSelector or VirtualAtom.
417
+ prefix : str
418
+ Label prefix for the generated PLUMED commands.
419
+ flatten : bool, default=True
420
+ How to handle AtomSelector inputs:
421
+ - True: Flatten all groups into a single list
422
+ - False: Create GROUP for each selector group (not typically used for ANGLE)
423
+ strategy : {"first", "all", "diagonal", "none"}, default="first"
424
+ Strategy for creating multiple angles from multiple groups:
425
+ - "first": Use first group from each selector (1 angle)
426
+ - "all": All combinations (N×M×P angles)
427
+ - "diagonal": Pair by index (min(N,M,P) angles)
428
+ - "none": Raise error if any selector has multiple groups
429
+
430
+ Resources
431
+ ---------
432
+ - https://www.plumed.org/doc-master/user-doc/html/ANGLE/
433
+ """
434
+
435
+ x1: AtomSelector | VirtualAtom
436
+ x2: AtomSelector | VirtualAtom
437
+ x3: AtomSelector | VirtualAtom
438
+ prefix: str
439
+ flatten: bool = True
440
+ strategy: Literal["first", "all", "diagonal", "none"] = "first"
441
+
442
+ def _get_atom_highlights(
443
+ self, atoms: Atoms, **kwargs
444
+ ) -> Optional[AtomHighlightMap]:
445
+ """Get atom highlights for visualization."""
446
+ # Skip for VirtualAtom inputs
447
+ if isinstance(self.x1, VirtualAtom) or isinstance(self.x2, VirtualAtom) or isinstance(self.x3, VirtualAtom):
448
+ return None
449
+
240
450
  groups1 = self.x1.select(atoms)
241
451
  groups2 = self.x2.select(atoms)
452
+ groups3 = self.x3.select(atoms)
242
453
 
243
- if not groups1 or not groups2:
244
- raise ValueError(f"Empty selection for distance CV '{self.prefix}'")
454
+ if not groups1 or not groups2 or not groups3:
455
+ return None
245
456
 
246
- flat1 = {idx for group in groups1 for idx in group}
247
- flat2 = {idx for group in groups2 for idx in group}
248
- if flat1.intersection(flat2) and self.group_reduction not in ["com", "cog"]:
249
- raise ValueError(
250
- "Overlapping atoms found. This is only valid with 'com' or 'cog' reduction."
251
- )
457
+ # Highlight all atoms from all three selections
458
+ indices1 = {idx for group in groups1 for idx in group}
459
+ indices2 = {idx for group in groups2 for idx in group}
460
+ indices3 = {idx for group in groups3 for idx in group}
252
461
 
253
- commands = self._generate_commands(groups1, groups2)
254
- labels = self._extract_labels(commands, self.prefix, "DISTANCE")
255
- return labels, commands
462
+ if not indices1 and not indices2 and not indices3:
463
+ return None
464
+
465
+ # Color atoms: red for x1, green for x2 (vertex), blue for x3
466
+ highlights: AtomHighlightMap = {}
467
+ red, green, blue = (1.0, 0.2, 0.2), (0.2, 1.0, 0.2), (0.2, 0.2, 1.0)
468
+
469
+ # Handle overlaps by prioritizing vertex (x2) coloring
470
+ all_indices = indices1.union(indices2).union(indices3)
471
+ for idx in all_indices:
472
+ in1, in2, in3 = idx in indices1, idx in indices2, idx in indices3
473
+ if in2: # Vertex gets priority
474
+ highlights[idx] = green
475
+ elif in1 and in3: # Overlap between x1 and x3
476
+ highlights[idx] = (0.5, 0.2, 0.6) # Purple
477
+ elif in1:
478
+ highlights[idx] = red
479
+ elif in3:
480
+ highlights[idx] = blue
481
+ return highlights
482
+
483
+ def to_plumed(self, atoms: Atoms) -> Tuple[List[str], List[str]]:
484
+ """Generate PLUMED ANGLE command(s).
485
+
486
+ Returns
487
+ -------
488
+ labels : list[str]
489
+ List of CV labels created.
490
+ commands : list[str]
491
+ List of PLUMED commands.
492
+
493
+ Raises
494
+ ------
495
+ ValueError
496
+ If any selector returns empty selection.
497
+ """
498
+ # Process all three inputs
499
+ labels1, cmds1 = self._process_input(self.x1, atoms, "x1")
500
+ labels2, cmds2 = self._process_input(self.x2, atoms, "x2")
501
+ labels3, cmds3 = self._process_input(self.x3, atoms, "x3")
502
+
503
+ # Check for empty selections
504
+ if not labels1 or not labels2 or not labels3:
505
+ raise ValueError(f"Empty selection for angle CV '{self.prefix}'")
256
506
 
257
- def _generate_commands(
258
- self, groups1: List[List[int]], groups2: List[List[int]]
259
- ) -> List[str]:
260
- """Generates all necessary PLUMED commands."""
261
507
  commands = []
262
- index_pairs = self._get_index_pairs(
263
- len(groups1), len(groups2), self.multi_group
264
- )
508
+ commands.extend(cmds1)
509
+ commands.extend(cmds2)
510
+ commands.extend(cmds3)
265
511
 
266
- # Efficiently create virtual sites only for groups that will be used.
267
- sites1, sites2 = {}, {}
268
- unique_indices1 = sorted({i for i, j in index_pairs})
269
- unique_indices2 = sorted({j for i, j in index_pairs})
270
-
271
- for i in unique_indices1:
272
- site, site_cmds = self._reduce_group(groups1[i], f"{self.prefix}_g1_{i}")
273
- sites1[i] = site
274
- commands.extend(site_cmds)
275
- for j in unique_indices2:
276
- site, site_cmds = self._reduce_group(groups2[j], f"{self.prefix}_g2_{j}")
277
- sites2[j] = site
278
- commands.extend(site_cmds)
279
-
280
- # Create the final DISTANCE commands.
281
- for i, j in index_pairs:
282
- label = self.prefix if len(index_pairs) == 1 else f"{self.prefix}_{i}_{j}"
283
- cmd = self._make_distance_command(sites1[i], sites2[j], label)
284
- commands.append(cmd)
512
+ # Generate ANGLE commands
513
+ cv_labels, cv_commands = self._generate_angle_cvs(labels1, labels2, labels3)
514
+ commands.extend(cv_commands)
285
515
 
286
- return commands
516
+ return cv_labels, commands
287
517
 
288
- def _reduce_group(
289
- self, group: List[int], site_prefix: str
290
- ) -> Tuple[SiteIdentifier, List[str]]:
291
- """Reduces a single atom group to a site identifier based on strategy."""
292
- if len(group) == 1 or self.group_reduction == "first":
293
- return str(group[0] + 1), []
294
- if self.group_reduction == "all":
295
- return group, []
296
-
297
- if self.group_reduction in ["com", "cog"]:
298
- if self.create_virtual_sites:
299
- label = f"{site_prefix}_{self.group_reduction}"
300
- cmd = self._create_virtual_site_command(
301
- group, self.group_reduction, label
302
- )
303
- return label, [cmd]
304
- return group, [] # Use group directly if not creating virtual sites
518
+ def _process_input(
519
+ self, input_obj: AtomSelector | VirtualAtom, atoms: Atoms, label_prefix: str
520
+ ) -> Tuple[List[str], List[str]]:
521
+ """Process input (AtomSelector or VirtualAtom) and return labels and commands.
305
522
 
306
- raise ValueError(f"Unknown group reduction strategy: {self.group_reduction}")
523
+ Same as DistanceCV._process_input() method.
307
524
 
308
- def _make_distance_command(
309
- self, site1: SiteIdentifier, site2: SiteIdentifier, label: str
310
- ) -> str:
311
- """Creates a single PLUMED DISTANCE command string."""
525
+ Returns
526
+ -------
527
+ labels : list[str]
528
+ List of labels for this input (either virtual site labels or atom lists).
529
+ commands : list[str]
530
+ PLUMED commands to create the labels.
531
+ """
532
+ if isinstance(input_obj, VirtualAtom):
533
+ # VirtualAtom: set deterministic label if not already set
534
+ if input_obj.label is None:
535
+ labeled_va = dataclasses.replace(
536
+ input_obj, label=f"{self.prefix}_{label_prefix}"
537
+ )
538
+ return labeled_va.to_plumed(atoms)
539
+ else:
540
+ return input_obj.to_plumed(atoms)
541
+ else:
542
+ # AtomSelector: handle based on flatten parameter
543
+ groups = input_obj.select(atoms)
544
+ if not groups:
545
+ return [], []
546
+
547
+ if self.flatten:
548
+ # Flatten all groups into single list
549
+ flat_atoms = [idx for group in groups for idx in group]
550
+ atom_list = ",".join(str(idx + 1) for idx in flat_atoms)
551
+ # Return as pseudo-label (will be used directly in ANGLE command)
552
+ return [atom_list], []
553
+ else:
554
+ # Smart GROUP creation: only create GROUP for multi-atom groups
555
+ labels = []
556
+ commands = []
557
+ for i, group in enumerate(groups):
558
+ if len(group) == 1:
559
+ # Single atom: use directly (no GROUP needed)
560
+ labels.append(str(group[0] + 1))
561
+ else:
562
+ # Multi-atom group: create GROUP
563
+ group_label = f"{self.prefix}_{label_prefix}_g{i}"
564
+ atom_list = ",".join(str(idx + 1) for idx in group)
565
+ commands.append(f"{group_label}: GROUP ATOMS={atom_list}")
566
+ labels.append(group_label)
567
+ return labels, commands
568
+
569
+ def _generate_angle_cvs(
570
+ self, labels1: List[str], labels2: List[str], labels3: List[str]
571
+ ) -> Tuple[List[str], List[str]]:
572
+ """Generate ANGLE CV commands based on strategy.
573
+
574
+ Parameters
575
+ ----------
576
+ labels1, labels2, labels3 : list[str]
577
+ Labels or atom lists for the three angle positions.
578
+
579
+ Returns
580
+ -------
581
+ cv_labels : list[str]
582
+ Labels for the ANGLE CVs created.
583
+ commands : list[str]
584
+ ANGLE command strings.
585
+ """
586
+ n1, n2, n3 = len(labels1), len(labels2), len(labels3)
587
+
588
+ # Determine which triplets to create based on strategy
589
+ if n1 == 1 and n2 == 1 and n3 == 1:
590
+ # One-to-one-to-one: always create single CV
591
+ triplets = [(0, 0, 0)]
592
+ elif n1 == 1 and n2 == 1:
593
+ # One-one-to-many: pair first of x1/x2 with all of x3
594
+ triplets = [(0, 0, k) for k in range(n3)]
595
+ elif n1 == 1 and n3 == 1:
596
+ # One-many-to-one: pair first of x1/x3 with all of x2
597
+ triplets = [(0, j, 0) for j in range(n2)]
598
+ elif n2 == 1 and n3 == 1:
599
+ # Many-to-one-one: pair all of x1 with first of x2/x3
600
+ triplets = [(i, 0, 0) for i in range(n1)]
601
+ else:
602
+ # Multi-way: apply strategy
603
+ if self.strategy == "first":
604
+ triplets = [(0, 0, 0)] if n1 > 0 and n2 > 0 and n3 > 0 else []
605
+ elif self.strategy == "all":
606
+ triplets = [(i, j, k) for i in range(n1) for j in range(n2) for k in range(n3)]
607
+ elif self.strategy == "diagonal":
608
+ n_triplets = min(n1, n2, n3)
609
+ triplets = [(i, i, i) for i in range(n_triplets)]
610
+ elif self.strategy == "none":
611
+ raise ValueError(
612
+ f"Multiple groups in x1/x2/x3 ({n1}, {n2}, {n3}). "
613
+ f"Use strategy='all' or 'diagonal', or select specific groups with indexing."
614
+ )
615
+ else:
616
+ raise ValueError(f"Unknown strategy: {self.strategy}")
617
+
618
+ # Generate ANGLE commands
619
+ cv_labels = []
620
+ commands = []
621
+ for idx, (i, j, k) in enumerate(triplets):
622
+ if len(triplets) == 1:
623
+ label = self.prefix
624
+ else:
625
+ label = f"{self.prefix}_{i}_{j}_{k}"
312
626
 
313
- def _format(site):
314
- return ",".join(map(str, (s + 1 for s in site))) if isinstance(site, list) else site
627
+ # Create ANGLE command (ATOMS=x1,x2,x3 where x2 is vertex)
628
+ cmd = f"{label}: ANGLE ATOMS={labels1[i]},{labels2[j]},{labels3[k]}"
629
+ commands.append(cmd)
630
+ cv_labels.append(label)
315
631
 
316
- s1_str, s2_str = _format(site1), _format(site2)
317
- # Use ATOMS for point-like sites, ATOMS1/ATOMS2 for group-based distances
318
- if isinstance(site1, str) and isinstance(site2, str):
319
- return f"{label}: DISTANCE ATOMS={s1_str},{s2_str}"
320
- return f"{label}: DISTANCE ATOMS1={s1_str} ATOMS2={s2_str}"
632
+ return cv_labels, commands
321
633
 
322
634
 
323
635
  @dataclass
@@ -325,58 +637,76 @@ class CoordinationNumberCV(_BasePlumedCV):
325
637
  """
326
638
  PLUMED COORDINATION collective variable.
327
639
 
328
- Calculates a coordination number based on a switching function. It supports
329
- complex group definitions, including groups of virtual sites.
330
-
331
- Attributes:
332
- x1, x2: Selectors for the two groups of atoms.
333
- prefix: Label prefix for the generated PLUMED commands.
334
- r_0: The reference distance for the switching function (in Angstroms).
335
- nn, mm, d_0: Parameters for the switching function.
336
- group_reduction_1, group_reduction_2: Reduction strategies for each group.
337
- multi_group: Strategy for handling multiple groups from selectors.
338
- create_virtual_sites: If True, create explicit virtual sites for COM/COG.
339
-
340
- Resources:
341
- - https://www.plumed.org/doc-master/user-doc/html/COORDINATION.html
342
- - https://www.plumed.org/doc-master/user-doc/html/GROUP.html
640
+ Calculates a coordination number based on a switching function using the new
641
+ VirtualAtom API. The coordination number is computed between two groups of atoms
642
+ using a switching function.
643
+
644
+ Parameters
645
+ ----------
646
+ x1 : AtomSelector | VirtualAtom
647
+ First group of atoms. Can be an AtomSelector or VirtualAtom.
648
+ x2 : AtomSelector | VirtualAtom
649
+ Second group of atoms. Can be an AtomSelector or VirtualAtom.
650
+ prefix : str
651
+ Label prefix for the generated PLUMED commands.
652
+ r_0 : float
653
+ Reference distance for the switching function (in Angstroms).
654
+ nn : int, default=6
655
+ Exponent for the switching function numerator.
656
+ mm : int, default=0
657
+ Exponent for the switching function denominator.
658
+ d_0 : float, default=0.0
659
+ Offset for the switching function (in Angstroms).
660
+ flatten : bool, default=True
661
+ How to handle AtomSelector inputs:
662
+ - True: Flatten all groups into a single GROUP
663
+ - False: Create a GROUP for each selector group
664
+ pairwise : {"all", "diagonal", "none"}, default="all"
665
+ Strategy for pairing multiple groups:
666
+ - "all": All pairwise combinations (N×M CVs)
667
+ - "diagonal": Pair by index (min(N,M) CVs)
668
+ - "none": Raise error if both have multiple groups
669
+
670
+ Resources
671
+ ---------
672
+ - https://www.plumed.org/doc-master/user-doc/html/COORDINATION
673
+ - https://www.plumed.org/doc-master/user-doc/html/GROUP
343
674
  """
344
675
 
345
- x1: AtomSelector
346
- x2: AtomSelector
676
+ x1: AtomSelector | VirtualAtom
677
+ x2: AtomSelector | VirtualAtom
347
678
  prefix: str
348
679
  r_0: float
349
680
  nn: int = 6
350
681
  mm: int = 0
351
682
  d_0: float = 0.0
352
- group_reduction_1: GroupReductionStrategyType = "all"
353
- group_reduction_2: GroupReductionStrategyType = "all"
354
- multi_group: MultiGroupStrategyType = "first"
355
- create_virtual_sites: bool = True
683
+ flatten: bool = True
684
+ pairwise: Literal["all", "diagonal", "none"] = "all"
356
685
 
357
686
  def _get_atom_highlights(
358
687
  self, atoms: Atoms, **kwargs
359
688
  ) -> Optional[AtomHighlightMap]:
360
- highlight_hydrogens = kwargs.get("highlight_hydrogens", False)
689
+ """Get atom highlights for visualization."""
690
+ # Skip for VirtualAtom inputs
691
+ if isinstance(self.x1, VirtualAtom) or isinstance(self.x2, VirtualAtom):
692
+ return None
693
+
361
694
  groups1 = self.x1.select(atoms)
362
695
  groups2 = self.x2.select(atoms)
363
696
 
364
697
  if not groups1 or not groups2:
365
698
  return None
366
699
 
367
- # Flatten groups and optionally filter out hydrogens.
368
- indices1 = {idx for g in groups1 for idx in g}
369
- indices2 = {idx for g in groups2 for idx in g}
370
- if not highlight_hydrogens:
371
- indices1 = {i for i in indices1 if atoms[i].symbol != "H"}
372
- indices2 = {i for i in indices2 if atoms[i].symbol != "H"}
700
+ # Highlight all atoms from both selections
701
+ indices1 = {idx for group in groups1 for idx in group}
702
+ indices2 = {idx for group in groups2 for idx in group}
373
703
 
374
704
  if not indices1 and not indices2:
375
705
  return None
376
706
 
377
- # Color atoms based on group membership, with purple for overlaps.
707
+ # Color atoms based on group membership
378
708
  highlights: AtomHighlightMap = {}
379
- red, blue, purple = (1.0, 0.5, 0.5), (0.5, 0.5, 1.0), (1.0, 0.5, 1.0)
709
+ red, blue, purple = (1.0, 0.2, 0.2), (0.2, 0.2, 1.0), (1.0, 0.2, 1.0)
380
710
  for idx in indices1.union(indices2):
381
711
  in1, in2 = idx in indices1, idx in indices2
382
712
  if in1 and in2:
@@ -388,117 +718,179 @@ class CoordinationNumberCV(_BasePlumedCV):
388
718
  return highlights
389
719
 
390
720
  def to_plumed(self, atoms: Atoms) -> Tuple[List[str], List[str]]:
721
+ """Generate PLUMED COORDINATION command(s).
722
+
723
+ Returns
724
+ -------
725
+ labels : list[str]
726
+ List of CV labels created.
727
+ commands : list[str]
728
+ List of PLUMED commands.
391
729
  """
392
- Generates PLUMED input strings for the COORDINATION CV.
730
+ # Process both inputs to get group labels
731
+ labels1, cmds1 = self._process_coordination_input(self.x1, atoms, "x1")
732
+ labels2, cmds2 = self._process_coordination_input(self.x2, atoms, "x2")
393
733
 
394
- Returns:
395
- A tuple containing a list of CV labels and a list of PLUMED commands.
734
+ commands = []
735
+ commands.extend(cmds1)
736
+ commands.extend(cmds2)
737
+
738
+ # Generate COORDINATION commands
739
+ cv_labels, cv_commands = self._generate_coordination_cvs(labels1, labels2)
740
+ commands.extend(cv_commands)
741
+
742
+ return cv_labels, commands
743
+
744
+ def _process_coordination_input(
745
+ self, input_obj: AtomSelector | VirtualAtom, atoms: Atoms, label_prefix: str
746
+ ) -> Tuple[List[str], List[str]]:
747
+ """Process input for COORDINATION and return group labels/commands.
748
+
749
+ For COORDINATION, we need groups (not individual points), so the processing
750
+ is different from DistanceCV:
751
+ - VirtualAtom with multiple sites → create GROUP of those sites
752
+ - VirtualAtom with single site → use site directly
753
+ - AtomSelector with flatten=True → create single group with all atoms
754
+ - AtomSelector with flatten=False → create GROUP for each selector group
755
+
756
+ Returns
757
+ -------
758
+ labels : list[str]
759
+ Group labels that can be used in COORDINATION GROUPA/GROUPB.
760
+ commands : list[str]
761
+ PLUMED commands to create those groups.
396
762
  """
397
- groups1 = self.x1.select(atoms)
398
- groups2 = self.x2.select(atoms)
399
-
400
- if not groups1 or not groups2:
401
- raise ValueError(f"Empty selection for coordination CV '{self.prefix}'")
402
-
403
- commands = self._generate_commands(groups1, groups2)
404
- labels = self._extract_labels(commands, self.prefix, "COORDINATION")
405
- return labels, commands
406
-
407
- def _generate_commands(
408
- self, groups1: List[List[int]], groups2: List[List[int]]
409
- ) -> List[str]:
410
- """Generates all necessary PLUMED commands."""
411
- commands: List[str] = []
412
-
413
- sites1 = self._reduce_groups(
414
- groups1, self.group_reduction_1, f"{self.prefix}_g1", commands
415
- )
416
- sites2 = self._reduce_groups(
417
- groups2, self.group_reduction_2, f"{self.prefix}_g2", commands
418
- )
419
-
420
- # Get site pairs using a simplified helper
421
- site_pairs = []
422
- if self.multi_group == "first":
423
- site_pairs = [(sites1[0], sites2[0])] if sites1 and sites2 else []
424
- elif self.multi_group == "all_pairs":
425
- site_pairs = [(s1, s2) for s1 in sites1 for s2 in sites2]
426
- elif self.multi_group == "corresponding":
427
- n = min(len(sites1), len(sites2))
428
- site_pairs = [(sites1[i], sites2[i]) for i in range(n)]
429
- elif self.multi_group == "first_to_all":
430
- site_pairs = [(sites1[0], s2) for s2 in sites2] if sites1 else []
431
-
432
- for i, (s1, s2) in enumerate(site_pairs):
433
- label = self.prefix if len(site_pairs) == 1 else f"{self.prefix}_{i}"
434
- commands.append(self._make_coordination_command(s1, s2, label))
763
+ if isinstance(input_obj, VirtualAtom):
764
+ # Set deterministic label if not already set
765
+ if input_obj.label is None:
766
+ labeled_va = dataclasses.replace(
767
+ input_obj, label=f"{self.prefix}_{label_prefix}"
768
+ )
769
+ else:
770
+ labeled_va = input_obj
435
771
 
436
- return commands
772
+ # Get virtual site labels
773
+ vsite_labels, vsite_commands = labeled_va.to_plumed(atoms)
437
774
 
438
- def _reduce_groups(
439
- self,
440
- groups: List[List[int]],
441
- strategy: GroupReductionStrategyType,
442
- site_prefix: str,
443
- commands: List[str],
444
- ) -> List[SiteIdentifier]:
445
- """Reduces a list of atom groups into a list of site identifiers."""
446
- if strategy in ["com_per_group", "cog_per_group"]:
447
- if not self.create_virtual_sites:
448
- raise ValueError(f"'{strategy}' requires create_virtual_sites=True")
449
-
450
- reduction_type = "COM" if strategy == "com_per_group" else "CENTER"
451
- vsite_labels = []
452
- for i, group in enumerate(groups):
453
- if not group:
454
- continue
455
- vsite_label = f"{site_prefix}_{i}"
456
- atom_list = ",".join(str(idx + 1) for idx in group)
457
- commands.append(f"{vsite_label}: {reduction_type} ATOMS={atom_list}")
458
- vsite_labels.append(vsite_label)
459
-
460
- group_label = f"{site_prefix}_group"
461
- commands.append(f"{group_label}: GROUP ATOMS={','.join(vsite_labels)}")
462
- return [group_label]
463
-
464
- if strategy == "all":
465
- return [sorted({idx for group in groups for idx in group})]
466
-
467
- # Handle other strategies by reducing each group individually.
468
- sites: List[SiteIdentifier] = []
469
- for i, group in enumerate(groups):
470
- if len(group) == 1 or strategy == "first":
471
- sites.append(str(group[0] + 1))
472
- elif strategy in ["com", "cog"]:
473
- if self.create_virtual_sites:
474
- label = f"{site_prefix}_{i}_{strategy}"
475
- cmd = self._create_virtual_site_command(group, strategy, label)
476
- commands.append(cmd)
477
- sites.append(label)
775
+ # If multiple virtual sites, create a GROUP of them
776
+ if len(vsite_labels) > 1:
777
+ group_label = f"{self.prefix}_{label_prefix}_group"
778
+ group_cmd = f"{group_label}: GROUP ATOMS={','.join(vsite_labels)}"
779
+ return [group_label], vsite_commands + [group_cmd]
780
+ else:
781
+ # Single virtual site, use directly
782
+ return vsite_labels, vsite_commands
783
+ else:
784
+ # AtomSelector: create group(s) based on flatten parameter
785
+ groups = input_obj.select(atoms)
786
+ if not groups:
787
+ return [], []
788
+
789
+ if self.flatten:
790
+ # Flatten all groups into single group
791
+ flat_atoms = [idx for group in groups for idx in group]
792
+ # Return as list of atom indices (will be formatted in COORDINATION command)
793
+ return [flat_atoms], []
794
+ else:
795
+ # Smart GROUP creation: only create GROUP for multi-atom groups
796
+ labels = []
797
+ commands = []
798
+ for i, group in enumerate(groups):
799
+ if len(group) == 1:
800
+ # Single atom: use directly (no GROUP needed)
801
+ labels.append(str(group[0] + 1))
802
+ else:
803
+ # Multi-atom group: create GROUP
804
+ group_label = f"{self.prefix}_{label_prefix}_g{i}"
805
+ atom_list = ",".join(str(idx + 1) for idx in group)
806
+ commands.append(f"{group_label}: GROUP ATOMS={atom_list}")
807
+ labels.append(group_label)
808
+
809
+ # If multiple groups, create a parent GROUP
810
+ if len(labels) > 1:
811
+ parent_label = f"{self.prefix}_{label_prefix}_group"
812
+ parent_cmd = f"{parent_label}: GROUP ATOMS={','.join(labels)}"
813
+ return [parent_label], commands + [parent_cmd]
478
814
  else:
479
- sites.append(group)
815
+ return labels, commands
816
+
817
+ def _generate_coordination_cvs(
818
+ self, labels1: List[str | List[int]], labels2: List[str | List[int]]
819
+ ) -> Tuple[List[str], List[str]]:
820
+ """Generate COORDINATION CV commands.
821
+
822
+ Parameters
823
+ ----------
824
+ labels1, labels2 : list[str | list[int]]
825
+ Group labels or atom index lists for GROUPA and GROUPB.
826
+
827
+ Returns
828
+ -------
829
+ cv_labels : list[str]
830
+ Labels for the COORDINATION CVs created.
831
+ commands : list[str]
832
+ COORDINATION command strings.
833
+ """
834
+ n1, n2 = len(labels1), len(labels2)
835
+
836
+ # Determine which pairs to create based on pairwise strategy
837
+ if n1 == 1 and n2 == 1:
838
+ # One-to-one: always create single CV
839
+ pairs = [(0, 0)]
840
+ elif n1 == 1:
841
+ # One-to-many: pair first of x1 with all of x2
842
+ pairs = [(0, j) for j in range(n2)]
843
+ elif n2 == 1:
844
+ # Many-to-one: pair all of x1 with first of x2
845
+ pairs = [(i, 0) for i in range(n1)]
846
+ else:
847
+ # Many-to-many: apply pairwise strategy
848
+ if self.pairwise == "all":
849
+ pairs = [(i, j) for i in range(n1) for j in range(n2)]
850
+ elif self.pairwise == "diagonal":
851
+ n_pairs = min(n1, n2)
852
+ pairs = [(i, i) for i in range(n_pairs)]
853
+ elif self.pairwise == "none":
854
+ raise ValueError(
855
+ f"Both x1 and x2 have multiple groups ({n1} and {n2}). "
856
+ f"Use pairwise='all' or 'diagonal', or select specific groups with indexing."
857
+ )
480
858
  else:
481
- raise ValueError(f"Unsupported reduction strategy: {strategy}")
482
- return sites
859
+ raise ValueError(f"Unknown pairwise strategy: {self.pairwise}")
483
860
 
484
- def _make_coordination_command(
485
- self, site1: SiteIdentifier, site2: SiteIdentifier, label: str
486
- ) -> str:
487
- """Creates a single PLUMED COORDINATION command string."""
861
+ # Generate COORDINATION commands
862
+ cv_labels = []
863
+ commands = []
864
+ for idx, (i, j) in enumerate(pairs):
865
+ if len(pairs) == 1:
866
+ label = self.prefix
867
+ else:
868
+ label = f"{self.prefix}_{idx}"
488
869
 
489
- def _format(site):
490
- return ",".join(map(str, (s + 1 for s in site))) if isinstance(site, list) else site
870
+ # Format group labels for COORDINATION
871
+ def format_group(g):
872
+ if isinstance(g, list): # List of atom indices
873
+ return ",".join(str(idx + 1) for idx in g)
874
+ else: # String label
875
+ return g
491
876
 
492
- g_a, g_b = _format(site1), _format(site2)
493
- base_cmd = f"{label}: COORDINATION GROUPA={g_a}"
494
- if g_a != g_b: # Omit GROUPB for self-coordination
495
- base_cmd += f" GROUPB={g_b}"
877
+ g_a = format_group(labels1[i])
878
+ g_b = format_group(labels2[j])
496
879
 
497
- params = f" R_0={self.r_0} NN={self.nn} D_0={self.d_0}"
498
- if self.mm != 0:
499
- params += f" MM={self.mm}"
880
+ # Create COORDINATION command
881
+ cmd = f"{label}: COORDINATION GROUPA={g_a}"
882
+ if g_a != g_b: # Omit GROUPB for self-coordination
883
+ cmd += f" GROUPB={g_b}"
884
+
885
+ # Add parameters
886
+ cmd += f" R_0={self.r_0} NN={self.nn} D_0={self.d_0}"
887
+ if self.mm != 0:
888
+ cmd += f" MM={self.mm}"
889
+
890
+ commands.append(cmd)
891
+ cv_labels.append(label)
500
892
 
501
- return base_cmd + params
893
+ return cv_labels, commands
502
894
 
503
895
 
504
896
  @dataclass
@@ -515,7 +907,7 @@ class TorsionCV(_BasePlumedCV):
515
907
  multi_group: Strategy for handling multiple groups from the selector.
516
908
 
517
909
  Resources:
518
- - https://www.plumed.org/doc-master/user-doc/html/TORSION.html
910
+ - https://www.plumed.org/doc-master/user-doc/html/TORSION
519
911
  """
520
912
 
521
913
  atoms: AtomSelector