openmmml 1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
openmmml/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .mlpotential import MLPotential
2
+ from . import models
@@ -0,0 +1,438 @@
1
+ """
2
+ mlpotential.py: Provides a common API for creating OpenMM Systems with ML potentials.
3
+
4
+ This is part of the OpenMM molecular simulation toolkit originating from
5
+ Simbios, the NIH National Center for Physics-Based Simulation of
6
+ Biological Structures at Stanford, funded under the NIH Roadmap for
7
+ Medical Research, grant U54 GM072970. See https://simtk.org.
8
+
9
+ Portions copyright (c) 2021-2024 Stanford University and the Authors.
10
+ Authors: Peter Eastman
11
+ Contributors:
12
+
13
+ Permission is hereby granted, free of charge, to any person obtaining a
14
+ copy of this software and associated documentation files (the "Software"),
15
+ to deal in the Software without restriction, including without limitation
16
+ the rights to use, copy, modify, merge, publish, distribute, sublicense,
17
+ and/or sell copies of the Software, and to permit persons to whom the
18
+ Software is furnished to do so, subject to the following conditions:
19
+
20
+ The above copyright notice and this permission notice shall be included in
21
+ all copies or substantial portions of the Software.
22
+
23
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
24
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
25
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
26
+ THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
27
+ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
28
+ OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
29
+ USE OR OTHER DEALINGS IN THE SOFTWARE.
30
+ """
31
+
32
+ import openmm
33
+ import openmm.app
34
+ import openmm.unit as unit
35
+ from copy import deepcopy
36
+ from typing import Dict, Iterable, Optional
37
+ import sys
38
+ if sys.version_info < (3, 10):
39
+ from importlib_metadata import entry_points
40
+ else:
41
+ from importlib.metadata import entry_points
42
+
43
+
44
class MLPotentialImplFactory:
    """Base class for factories that build MLPotentialImpl objects.

    A new potential function is added by subclassing both MLPotentialImpl and
    MLPotentialImplFactory, then making an instance of the factory known to
    MLPotential.  That can be done explicitly by calling
    MLPotential.registerImplFactory(), or implicitly by having a Python package
    declare an entry point in the group "openmmml.potentials".  For an entry
    point, the entry point name is the name of the potential function and its
    value should be the name of the MLPotentialImplFactory subclass.
    """

    def createImpl(self, name: str, **args) -> "MLPotentialImpl":
        """Build the MLPotentialImpl that backs a MLPotential.

        MLPotential calls this method from its constructor.  This base
        implementation always raises; subclasses override it to return an
        instance of the matching MLPotentialImpl subclass.

        Parameters
        ----------
        name: str
            the potential name that was given to the MLPotential constructor
        args:
            every extra keyword argument given to the MLPotential constructor
            is forwarded here, so subclasses can use them to customize the
            object they create.

        Returns
        -------
        a MLPotentialImpl that implements the potential

        Raises
        ------
        NotImplementedError
            always, unless overridden by a subclass
        """
        message = 'Subclasses must implement createImpl()'
        raise NotImplementedError(message)
77
+
78
+
79
class MLPotentialImpl:
    """Base class for objects that implement a potential function.

    A new potential function is added by subclassing both MLPotentialImpl and
    MLPotentialImplFactory.  When a MLPotential is created with the name of a
    potential, the factory registered under that name is looked up and asked
    to create a MLPotentialImpl of the appropriate subclass.
    """

    def addForces(self,
                  topology: openmm.app.Topology,
                  system: openmm.System,
                  atoms: Optional[Iterable[int]],
                  forceGroup: int,
                  **args):
        """Add the Force objects that implement the potential to a System.

        MLPotential.createSystem() calls this method.  This base
        implementation always raises; subclasses override it to add the
        Forces for the requested potential function.

        Parameters
        ----------
        topology: Topology
            the Topology from which the System is being created
        system: System
            the System that is being created
        atoms: Optional[Iterable[int]]
            the indices of atoms the potential should be applied to, or None if
            it should be applied to the entire System
        forceGroup: int
            the force group that any newly added Forces should be in
        args:
            every extra keyword argument given to createSystem() is forwarded
            here, so subclasses can use them to customize their behavior.

        Raises
        ------
        NotImplementedError
            always, unless overridden by a subclass
        """
        message = 'Subclasses must implement addForces()'
        raise NotImplementedError(message)
117
+
118
+
119
class MLPotential(object):
    """A potential function that can be used in simulations.

    To use this class, create a MLPotential, specifying the name of the potential
    function to use.  You can then call createSystem() to create a System object
    for a simulation.  For example,

    >>> potential = MLPotential('ani2x')
    >>> system = potential.createSystem(topology)

    Alternatively, you can use createMixedSystem() to create a System where part is
    modeled with this potential and the rest is modeled with a conventional force
    field.  As an example, suppose the Topology contains three chains.  Chain 0 is
    a protein, chain 1 is a ligand, and chain 2 is solvent.  The following code
    creates a System in which the internal energy of the ligand is computed with
    ANI2x, while everything else (including interactions between the ligand and the
    rest of the System) is computed with Amber14.

    >>> forcefield = ForceField('amber14-all.xml', 'amber14/tip3pfb.xml')
    >>> mm_system = forcefield.createSystem(topology)
    >>> chains = list(topology.chains())
    >>> ml_atoms = [atom.index for atom in chains[1].atoms()]
    >>> potential = MLPotential('ani2x')
    >>> ml_system = potential.createMixedSystem(topology, mm_system, ml_atoms)
    """

    # Registry shared by all instances: potential name -> factory that creates
    # the MLPotentialImpl for that name.  Filled by registerImplFactory().
    _implFactories: Dict[str, MLPotentialImplFactory] = {}

    def __init__(self, name: str, **args):
        """Create a MLPotential.

        Parameters
        ----------
        name: str
            the name of the potential function to use.  Built in support is currently
            provided for the following: 'ani1ccx', 'ani2x'.  Others may be added by
            calling MLPotential.registerImplFactory().
        args:
            particular potential functions may define additional arguments that can
            be used to customize them.  See the documentation on the specific
            potential functions for more information.

        Raises
        ------
        KeyError
            if no factory has been registered under ``name``
        """
        try:
            factory = MLPotential._implFactories[name]
        except KeyError:
            # Still a KeyError (as before), but tell the user what *is* available
            # instead of only echoing the name that was not found.
            registered = ', '.join(sorted(repr(n) for n in MLPotential._implFactories)) or 'none'
            raise KeyError(f'No potential function named {name!r} has been registered. '
                           f'Registered potentials: {registered}') from None
        self._impl = factory.createImpl(name, **args)

    def createSystem(self, topology: openmm.app.Topology, removeCMMotion: bool = True, **args) -> openmm.System:
        """Create a System for running a simulation with this potential function.

        Parameters
        ----------
        topology: Topology
            the Topology for which to create a System
        removeCMMotion: bool
            if true, a CMMotionRemover will be added to the System.
        args:
            particular potential functions may define additional arguments that can
            be used to customize them.  See the documentation on the specific
            potential functions for more information.

        Returns
        -------
        a newly created System object that uses this potential function to model the Topology
        """
        system = openmm.System()
        if topology.getPeriodicBoxVectors() is not None:
            system.setDefaultPeriodicBoxVectors(*topology.getPeriodicBoxVectors())
        for atom in topology.atoms():
            if atom.element is None:
                # Atoms without an element are added with zero mass.
                system.addParticle(0)
            else:
                system.addParticle(atom.element.mass)
        # atoms=None asks the implementation to cover the whole System; force group 0.
        self._impl.addForces(topology, system, None, 0, **args)
        if removeCMMotion:
            system.addForce(openmm.CMMotionRemover())
        return system

    def createMixedSystem(self,
                          topology: openmm.app.Topology,
                          system: openmm.System,
                          atoms: Iterable[int],
                          removeConstraints: bool = True,
                          forceGroup: int = 0,
                          interpolate: bool = False,
                          **args) -> openmm.System:
        """Create a System that is partly modeled with this potential and partly
        with a conventional force field.

        To use this method, first create a System that is entirely modeled with the
        conventional force field.  Pass it to this method, along with the indices of the
        atoms to model with this potential (the "ML subset").  It returns a new System
        that is identical to the original one except for the following changes.

        1. Removing all bonds, angles, and torsions for which *all* atoms are in the
           ML subset.
        2. For every NonbondedForce and CustomNonbondedForce, adding exceptions/exclusions
           to prevent atoms in the ML subset from interacting with each other.
        3. (Optional) Removing constraints between atoms that are both in the ML subset.
        4. Adding Forces as necessary to compute the internal energy of the ML subset
           with this potential.

        Alternatively, the System can include Forces to compute the energy both with the
        conventional force field and with this potential, and to smoothly interpolate
        between them.  In that case, it creates a CustomCVForce containing the following.

        1. The Forces to compute this potential.
        2. Forces to compute the bonds, angles, and torsions that were removed above.
        3. For every NonbondedForce, a corresponding CustomBondForce to compute the
           nonbonded interactions within the ML subset.

        The CustomCVForce defines a global parameter called "lambda_interpolate" that interpolates
        between the two potentials.  When lambda_interpolate=0, the energy is computed entirely with
        the conventional force field.  When lambda_interpolate=1, the energy is computed entirely with
        the ML potential.  You can set its value by calling setParameter() on the Context.

        Parameters
        ----------
        topology: Topology
            the Topology for which to create a System
        system: System
            a System that models the Topology with a conventional force field
        atoms: Iterable[int]
            the indices of all atoms whose interactions should be computed with
            this potential.  Any iterable is accepted, including a one-shot
            iterator such as a generator; it is read exactly once.
        removeConstraints: bool
            if True, remove constraints between pairs of atoms whose interaction
            will be computed with this potential
        forceGroup: int
            the force group the ML potential's Forces should be placed in
        interpolate: bool
            if True, create a System that can smoothly interpolate between the conventional
            and ML potentials
        args:
            particular potential functions may define additional arguments that can
            be used to customize them.  See the documentation on the specific
            potential functions for more information.

        Returns
        -------
        a newly created System object that uses this potential function to model the Topology
        """
        # Materialise the ML subset once, before anything else touches it.  The
        # indices are needed several times below; iterating `atoms` directly more
        # than once would silently yield nothing on later passes for a generator.

        atomList = list(atoms)

        # Create the new System, removing bonded interactions within the ML subset.

        newSystem = self._removeBonds(system, atomList, True, removeConstraints)

        # Add nonbonded exceptions and exclusions.

        for force in newSystem.getForces():
            if isinstance(force, openmm.NonbondedForce):
                for i in range(len(atomList)):
                    for j in range(i):
                        # chargeProd=0, sigma=1, epsilon=0; the final True replaces
                        # any exception that already exists for this pair.
                        force.addException(atomList[i], atomList[j], 0, 1, 0, True)
            elif isinstance(force, openmm.CustomNonbondedForce):
                # Only add exclusions for pairs that are not already excluded
                # (in either order).
                existing = set(tuple(force.getExclusionParticles(i)) for i in range(force.getNumExclusions()))
                for i in range(len(atomList)):
                    a1 = atomList[i]
                    for j in range(i):
                        a2 = atomList[j]
                        if (a1, a2) not in existing and (a2, a1) not in existing:
                            force.addExclusion(a1, a2)

        # Add the ML potential.

        if not interpolate:
            self._impl.addForces(topology, newSystem, atomList, forceGroup, **args)
        else:
            # Create a CustomCVForce and put the ML forces inside it.  The energy
            # expression is filled in at the end, once all variables are known.
            # NOTE(review): forceGroup is handed to the implementation for the inner
            # forces, but the CustomCVForce itself is left in its default group.

            cv = openmm.CustomCVForce('')
            cv.addGlobalParameter('lambda_interpolate', 1)
            tempSystem = openmm.System()
            self._impl.addForces(topology, tempSystem, atomList, forceGroup, **args)
            mlVarNames = []
            for i, force in enumerate(tempSystem.getForces()):
                name = f'mlForce{i+1}'
                cv.addCollectiveVariable(name, deepcopy(force))
                mlVarNames.append(name)

            # Create Forces for all the bonded interactions within the ML subset and add them to the CustomCVForce.

            bondedSystem = self._removeBonds(system, atomList, False, removeConstraints)
            bondedForces = []
            for force in bondedSystem.getForces():
                # Duck-typed test for "this is a bonded force".
                if hasattr(force, 'addBond') or hasattr(force, 'addAngle') or hasattr(force, 'addTorsion'):
                    bondedForces.append(force)
            mmVarNames = []
            for i, force in enumerate(bondedForces):
                name = f'mmForce{i+1}'
                cv.addCollectiveVariable(name, deepcopy(force))
                mmVarNames.append(name)

            # Create a CustomBondForce that computes all nonbonded interactions within the ML subset.
            # Parameters are read from the *original* system, whose exceptions have not been overwritten.

            for force in system.getForces():
                if isinstance(force, openmm.NonbondedForce):
                    internalNonbonded = openmm.CustomBondForce('138.935456*chargeProd/r + 4*epsilon*((sigma/r)^12-(sigma/r)^6)')
                    internalNonbonded.addPerBondParameter('chargeProd')
                    internalNonbonded.addPerBondParameter('sigma')
                    internalNonbonded.addPerBondParameter('epsilon')
                    numParticles = system.getNumParticles()
                    atomCharge = [0]*numParticles
                    atomSigma = [0]*numParticles
                    atomEpsilon = [0]*numParticles
                    for i in range(numParticles):
                        charge, sigma, epsilon = force.getParticleParameters(i)
                        atomCharge[i] = charge
                        atomSigma[i] = sigma
                        atomEpsilon[i] = epsilon
                    exceptions = {}
                    for i in range(force.getNumExceptions()):
                        p1, p2, chargeProd, sigma, epsilon = force.getExceptionParameters(i)
                        exceptions[(p1, p2)] = (chargeProd, sigma, epsilon)
                    for p1 in atomList:
                        for p2 in atomList:
                            # Stop at the diagonal so each unordered pair is visited once.
                            if p1 == p2:
                                break
                            if (p1, p2) in exceptions:
                                chargeProd, sigma, epsilon = exceptions[(p1, p2)]
                            elif (p2, p1) in exceptions:
                                chargeProd, sigma, epsilon = exceptions[(p2, p1)]
                            else:
                                # No exception for this pair: combine the per-particle parameters
                                # (product of charges, mean sigma, geometric-mean epsilon).
                                chargeProd = atomCharge[p1]*atomCharge[p2]
                                sigma = 0.5*(atomSigma[p1]+atomSigma[p2])
                                epsilon = unit.sqrt(atomEpsilon[p1]*atomEpsilon[p2])
                            # Skip pairs that would contribute no energy.
                            if chargeProd._value != 0 or epsilon._value != 0:
                                internalNonbonded.addBond(p1, p2, [chargeProd, sigma, epsilon])
                    if internalNonbonded.getNumBonds() > 0:
                        name = f'mmForce{len(mmVarNames)+1}'
                        cv.addCollectiveVariable(name, internalNonbonded)
                        mmVarNames.append(name)

            # Configure the CustomCVForce so lambda_interpolate interpolates between the conventional and ML potentials.

            mlSum = '+'.join(mlVarNames) if len(mlVarNames) > 0 else '0'
            mmSum = '+'.join(mmVarNames) if len(mmVarNames) > 0 else '0'
            cv.setEnergyFunction(f'lambda_interpolate*({mlSum}) + (1-lambda_interpolate)*({mmSum})')
            newSystem.addForce(cv)
        return newSystem

    def _removeBonds(self, system: openmm.System, atoms: Iterable[int], removeInSet: bool, removeConstraints: bool) -> openmm.System:
        """Copy a System, removing all bonded interactions between atoms in (or not in) a particular set.

        The copy is made by serializing the System to XML, deleting the unwanted
        elements from the XML tree, and deserializing the result.  The input
        System is not modified.

        Parameters
        ----------
        system: System
            the System to copy
        atoms: Iterable[int]
            a set of atom indices
        removeInSet: bool
            if True, any bonded term connecting atoms in the specified set is removed.  If False,
            any term that does *not* connect atoms in the specified set is removed
        removeConstraints: bool
            if True, remove constraints between pairs of atoms in the set

        Returns
        -------
        a newly created System object in which the specified bonded interactions have been removed
        """
        atomSet = set(atoms)

        # Create an XML representation of the System.

        import xml.etree.ElementTree as ET
        xml = openmm.XmlSerializer.serialize(system)
        root = ET.fromstring(xml)

        # This function decides whether a bonded interaction should be removed.
        # A term counts as "in the set" only when every one of its atoms is in it.

        def shouldRemove(termAtoms):
            return all(a in atomSet for a in termAtoms) == removeInSet

        # Remove bonds, angles, and torsions.  findall() returns a list, so
        # removing children while looping over it is safe.

        for bonds in root.findall('./Forces/Force/Bonds'):
            for bond in bonds.findall('Bond'):
                bondAtoms = [int(bond.attrib[p]) for p in ('p1', 'p2')]
                if shouldRemove(bondAtoms):
                    bonds.remove(bond)
        for angles in root.findall('./Forces/Force/Angles'):
            for angle in angles.findall('Angle'):
                angleAtoms = [int(angle.attrib[p]) for p in ('p1', 'p2', 'p3')]
                if shouldRemove(angleAtoms):
                    angles.remove(angle)
        for torsions in root.findall('./Forces/Force/Torsions'):
            for torsion in torsions.findall('Torsion'):
                # Torsion elements label their atoms either p1..p4 or a1..a4 plus b1..b4.
                torsionLabels = ('p1', 'p2', 'p3', 'p4') if 'p1' in torsion.attrib else ('a1', 'a2', 'a3', 'a4', 'b1', 'b2', 'b3', 'b4')
                torsionAtoms = [int(torsion.attrib[p]) for p in torsionLabels]
                if shouldRemove(torsionAtoms):
                    torsions.remove(torsion)

        # Optionally remove constraints.

        if removeConstraints:
            for constraints in root.findall('./Constraints'):
                for constraint in constraints.findall('Constraint'):
                    constraintAtoms = [int(constraint.attrib[p]) for p in ('p1', 'p2')]
                    if shouldRemove(constraintAtoms):
                        constraints.remove(constraint)

        # Create a new System from it.

        return openmm.XmlSerializer.deserialize(ET.tostring(root, encoding='unicode'))

    @staticmethod
    def registerImplFactory(name: str, factory: MLPotentialImplFactory):
        """Register a new potential function that can be used with MLPotential.

        Registering a name that is already present replaces the earlier factory.

        Parameters
        ----------
        name: str
            the name of the potential function that will be passed to the MLPotential constructor
        factory: MLPotentialImplFactory
            a factory object that will be used to create MLPotentialImpl objects
        """
        MLPotential._implFactories[name] = factory
433
+
434
+
435
# Auto-register potentials contributed by installed packages.  Each entry point
# in the "openmmml.potentials" group is loaded, the loaded object is called with
# no arguments to obtain a factory, and that factory is registered under the
# entry point's name.

for potential in entry_points(group='openmmml.potentials'):
    MLPotential.registerImplFactory(name=potential.name, factory=potential.load()())
@@ -0,0 +1 @@
1
+ from . import anipotential, macepotential, nequippotential, deepmdpotential, aimnet2potential
@@ -0,0 +1,112 @@
1
+ """
2
+ aimnet2potential.py: Implements the AIMNet2 potential function.
3
+
4
+ This is part of the OpenMM molecular simulation toolkit originating from
5
+ Simbios, the NIH National Center for Physics-Based Simulation of
6
+ Biological Structures at Stanford, funded under the NIH Roadmap for
7
+ Medical Research, grant U54 GM072970. See https://simtk.org.
8
+
9
+ Portions copyright (c) 2021-2023 Stanford University and the Authors.
10
+ Authors: Peter Eastman
11
+ Contributors:
12
+
13
+ Permission is hereby granted, free of charge, to any person obtaining a
14
+ copy of this software and associated documentation files (the "Software"),
15
+ to deal in the Software without restriction, including without limitation
16
+ the rights to use, copy, modify, merge, publish, distribute, sublicense,
17
+ and/or sell copies of the Software, and to permit persons to whom the
18
+ Software is furnished to do so, subject to the following conditions:
19
+
20
+ The above copyright notice and this permission notice shall be included in
21
+ all copies or substantial portions of the Software.
22
+
23
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
24
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
25
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
26
+ THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
27
+ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
28
+ OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
29
+ USE OR OTHER DEALINGS IN THE SOFTWARE.
30
+ """
31
+
32
+ from openmmml.mlpotential import MLPotential, MLPotentialImpl, MLPotentialImplFactory
33
+ import openmm
34
+ from openmm import unit
35
+ from typing import Iterable, Optional, Union
36
+
37
class AIMNet2PotentialImplFactory(MLPotentialImplFactory):
    """Factory that produces AIMNet2PotentialImpl objects."""

    def createImpl(self, name: str, **args) -> MLPotentialImpl:
        """Return a new AIMNet2PotentialImpl for ``name``.

        Extra keyword arguments are accepted but not used here.
        """
        impl = AIMNet2PotentialImpl(name)
        return impl
42
+
43
+
44
class AIMNet2PotentialImpl(MLPotentialImpl):
    """This is the MLPotentialImpl implementing the AIMNet2 potential.

    The AIMNet2 potential is constructed using `aimnet` to build a PyTorch model,
    and then integrated into the OpenMM System using a TorchForce.  To use it, specify the model by name:

    >>> potential = MLPotential('aimnet2')

    addForces() reads two optional keyword arguments, ``charge`` (default 0) and
    ``multiplicity`` (default 1), and forwards them to the model as its
    'charge' and 'mult' inputs.
    """

    def __init__(self, name):
        """Store the potential name this implementation was created for."""
        self.name = name

    def addForces(self,
                  topology: openmm.app.Topology,
                  system: openmm.System,
                  atoms: Optional[Iterable[int]],
                  forceGroup: int,
                  **args):
        """Add a TorchForce wrapping the AIMNet2 model to ``system``.

        Parameters
        ----------
        topology: Topology
            the Topology from which the System is being created
        system: System
            the System to which the TorchForce is added
        atoms: Optional[Iterable[int]]
            the indices of atoms the potential should be applied to, or None to
            apply it to every atom in the Topology.  The indices may be given in
            any order; they are processed in ascending order.
        forceGroup: int
            the force group assigned to the new TorchForce
        args:
            ``charge`` (default 0) and ``multiplicity`` (default 1) are read;
            any other keyword arguments are ignored.

        Raises
        ------
        ImportError
            if the ``aimnet`` package cannot be imported
        """
        # Load the AIMNet2 model.

        try:
            from aimnet.calculators import AIMNet2Calculator
        except ImportError as e:
            # Chain the original exception so its traceback is preserved.
            raise ImportError(f"Failed to import aimnet with error: {e}. Install from https://github.com/isayevlab/aimnetcentral.") from e
        import torch
        import openmmtorch
        model = AIMNet2Calculator('aimnet2')

        # Create the PyTorch model that will be invoked by OpenMM.

        includedAtoms = list(topology.atoms())
        if atoms is not None:
            # Sort once and use the same order everywhere.  The force selects
            # positions with the sorted indices, so the atomic numbers must be
            # listed in that same order or elements and coordinates get
            # mismatched.  Doing it here also means `atoms` is only iterated once.
            atoms = sorted(atoms)
            includedAtoms = [includedAtoms[i] for i in atoms]
        numbers = torch.tensor([[atom.element.atomic_number for atom in includedAtoms]])
        charge = torch.tensor([args.get('charge', 0)], dtype=torch.float32)
        multiplicity = torch.tensor([args.get('multiplicity', 1)], dtype=torch.float32)

        class AIMNet2Force(torch.nn.Module):
            """Module evaluated by the TorchForce: positions in, scaled model energy out."""

            def __init__(self, calc, numbers, charge, multiplicity, atoms):
                super(AIMNet2Force, self).__init__()
                self.model = calc.model
                # Stored as non-trainable Parameters so they follow the module
                # when it is moved between devices.
                self.numbers = torch.nn.Parameter(numbers, requires_grad=False)
                self.charge = torch.nn.Parameter(charge, requires_grad=False)
                self.multiplicity = torch.nn.Parameter(multiplicity, requires_grad=False)
                # Factor converting eV per particle to kJ/mol.
                self.energyScale = (unit.ev/unit.item).conversion_factor_to(unit.kilojoules_per_mole)
                if atoms is None:
                    self.indices = None
                else:
                    self.indices = torch.tensor(sorted(atoms), dtype=torch.int64)

            def forward(self, positions: torch.Tensor, boxvectors: Optional[torch.Tensor] = None):
                positions = positions.to(torch.float32).to(self.numbers.device)
                if self.indices is not None:
                    positions = positions[self.indices]
                # The factor of 10 presumably converts nm to Angstrom for the
                # model - TODO confirm against the aimnet documentation.
                args = {'coord': 10.0*positions.unsqueeze(0),
                        'numbers': self.numbers,
                        'charge': self.charge,
                        'mult': self.multiplicity}
                result = self.model(args)
                energy = result["energy"].sum()
                return self.energyScale*energy

        # Create the TorchForce and add it to the System.

        module = torch.jit.script(AIMNet2Force(model, numbers, charge, multiplicity, atoms)).to(torch.device('cpu'))
        force = openmmtorch.TorchForce(module)
        force.setForceGroup(forceGroup)
        system.addForce(force)