mlmm-toolkit 0.2.2.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hessian_ff/__init__.py +50 -0
- hessian_ff/analytical_hessian.py +609 -0
- hessian_ff/constants.py +46 -0
- hessian_ff/forcefield.py +339 -0
- hessian_ff/loaders.py +608 -0
- hessian_ff/native/Makefile +8 -0
- hessian_ff/native/__init__.py +28 -0
- hessian_ff/native/analytical_hessian.py +88 -0
- hessian_ff/native/analytical_hessian_ext.cpp +258 -0
- hessian_ff/native/bonded.py +82 -0
- hessian_ff/native/bonded_ext.cpp +640 -0
- hessian_ff/native/loader.py +349 -0
- hessian_ff/native/nonbonded.py +118 -0
- hessian_ff/native/nonbonded_ext.cpp +1150 -0
- hessian_ff/prmtop_parmed.py +23 -0
- hessian_ff/system.py +107 -0
- hessian_ff/terms/__init__.py +14 -0
- hessian_ff/terms/angle.py +73 -0
- hessian_ff/terms/bond.py +44 -0
- hessian_ff/terms/cmap.py +406 -0
- hessian_ff/terms/dihedral.py +141 -0
- hessian_ff/terms/nonbonded.py +209 -0
- hessian_ff/tests/__init__.py +0 -0
- hessian_ff/tests/conftest.py +75 -0
- hessian_ff/tests/data/small/complex.parm7 +1346 -0
- hessian_ff/tests/data/small/complex.pdb +125 -0
- hessian_ff/tests/data/small/complex.rst7 +63 -0
- hessian_ff/tests/test_coords_input.py +44 -0
- hessian_ff/tests/test_energy_force.py +49 -0
- hessian_ff/tests/test_hessian.py +137 -0
- hessian_ff/tests/test_smoke.py +18 -0
- hessian_ff/tests/test_validation.py +40 -0
- hessian_ff/workflows.py +889 -0
- mlmm/__init__.py +36 -0
- mlmm/__main__.py +7 -0
- mlmm/_version.py +34 -0
- mlmm/add_elem_info.py +374 -0
- mlmm/advanced_help.py +91 -0
- mlmm/align_freeze_atoms.py +601 -0
- mlmm/all.py +3535 -0
- mlmm/bond_changes.py +231 -0
- mlmm/bool_compat.py +223 -0
- mlmm/cli.py +574 -0
- mlmm/cli_utils.py +166 -0
- mlmm/default_group.py +337 -0
- mlmm/defaults.py +467 -0
- mlmm/define_layer.py +526 -0
- mlmm/dft.py +1041 -0
- mlmm/energy_diagram.py +253 -0
- mlmm/extract.py +2213 -0
- mlmm/fix_altloc.py +464 -0
- mlmm/freq.py +1406 -0
- mlmm/harmonic_constraints.py +140 -0
- mlmm/hessian_cache.py +44 -0
- mlmm/hessian_calc.py +174 -0
- mlmm/irc.py +638 -0
- mlmm/mlmm_calc.py +2262 -0
- mlmm/mm_parm.py +945 -0
- mlmm/oniom_export.py +1983 -0
- mlmm/oniom_import.py +457 -0
- mlmm/opt.py +1742 -0
- mlmm/path_opt.py +1353 -0
- mlmm/path_search.py +2299 -0
- mlmm/preflight.py +88 -0
- mlmm/py.typed +1 -0
- mlmm/pysis_runner.py +45 -0
- mlmm/scan.py +1047 -0
- mlmm/scan2d.py +1226 -0
- mlmm/scan3d.py +1265 -0
- mlmm/scan_common.py +184 -0
- mlmm/summary_log.py +736 -0
- mlmm/trj2fig.py +448 -0
- mlmm/tsopt.py +2871 -0
- mlmm/utils.py +2309 -0
- mlmm/xtb_embedcharge_correction.py +475 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/METADATA +1159 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/RECORD +372 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/WHEEL +5 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/entry_points.txt +2 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/licenses/LICENSE +674 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/top_level.txt +4 -0
- pysisyphus/Geometry.py +1667 -0
- pysisyphus/LICENSE +674 -0
- pysisyphus/TableFormatter.py +63 -0
- pysisyphus/TablePrinter.py +74 -0
- pysisyphus/__init__.py +12 -0
- pysisyphus/calculators/AFIR.py +452 -0
- pysisyphus/calculators/AnaPot.py +20 -0
- pysisyphus/calculators/AnaPot2.py +48 -0
- pysisyphus/calculators/AnaPot3.py +12 -0
- pysisyphus/calculators/AnaPot4.py +20 -0
- pysisyphus/calculators/AnaPotBase.py +337 -0
- pysisyphus/calculators/AnaPotCBM.py +25 -0
- pysisyphus/calculators/AtomAtomTransTorque.py +154 -0
- pysisyphus/calculators/CFOUR.py +250 -0
- pysisyphus/calculators/Calculator.py +844 -0
- pysisyphus/calculators/CerjanMiller.py +24 -0
- pysisyphus/calculators/Composite.py +123 -0
- pysisyphus/calculators/ConicalIntersection.py +171 -0
- pysisyphus/calculators/DFTBp.py +430 -0
- pysisyphus/calculators/DFTD3.py +66 -0
- pysisyphus/calculators/DFTD4.py +84 -0
- pysisyphus/calculators/Dalton.py +61 -0
- pysisyphus/calculators/Dimer.py +681 -0
- pysisyphus/calculators/Dummy.py +20 -0
- pysisyphus/calculators/EGO.py +76 -0
- pysisyphus/calculators/EnergyMin.py +224 -0
- pysisyphus/calculators/ExternalPotential.py +264 -0
- pysisyphus/calculators/FakeASE.py +35 -0
- pysisyphus/calculators/FourWellAnaPot.py +28 -0
- pysisyphus/calculators/FreeEndNEBPot.py +39 -0
- pysisyphus/calculators/Gaussian09.py +18 -0
- pysisyphus/calculators/Gaussian16.py +726 -0
- pysisyphus/calculators/HardSphere.py +159 -0
- pysisyphus/calculators/IDPPCalculator.py +49 -0
- pysisyphus/calculators/IPIClient.py +133 -0
- pysisyphus/calculators/IPIServer.py +234 -0
- pysisyphus/calculators/LEPSBase.py +24 -0
- pysisyphus/calculators/LEPSExpr.py +139 -0
- pysisyphus/calculators/LennardJones.py +80 -0
- pysisyphus/calculators/MOPAC.py +219 -0
- pysisyphus/calculators/MullerBrownSympyPot.py +51 -0
- pysisyphus/calculators/MultiCalc.py +85 -0
- pysisyphus/calculators/NFK.py +45 -0
- pysisyphus/calculators/OBabel.py +87 -0
- pysisyphus/calculators/ONIOMv2.py +1129 -0
- pysisyphus/calculators/ORCA.py +893 -0
- pysisyphus/calculators/ORCA5.py +6 -0
- pysisyphus/calculators/OpenMM.py +88 -0
- pysisyphus/calculators/OpenMolcas.py +281 -0
- pysisyphus/calculators/OverlapCalculator.py +908 -0
- pysisyphus/calculators/Psi4.py +218 -0
- pysisyphus/calculators/PyPsi4.py +37 -0
- pysisyphus/calculators/PySCF.py +341 -0
- pysisyphus/calculators/PyXTB.py +73 -0
- pysisyphus/calculators/QCEngine.py +106 -0
- pysisyphus/calculators/Rastrigin.py +22 -0
- pysisyphus/calculators/Remote.py +76 -0
- pysisyphus/calculators/Rosenbrock.py +15 -0
- pysisyphus/calculators/SocketCalc.py +97 -0
- pysisyphus/calculators/TIP3P.py +111 -0
- pysisyphus/calculators/TransTorque.py +161 -0
- pysisyphus/calculators/Turbomole.py +965 -0
- pysisyphus/calculators/VRIPot.py +37 -0
- pysisyphus/calculators/WFOWrapper.py +333 -0
- pysisyphus/calculators/WFOWrapper2.py +341 -0
- pysisyphus/calculators/XTB.py +418 -0
- pysisyphus/calculators/__init__.py +81 -0
- pysisyphus/calculators/cosmo_data.py +139 -0
- pysisyphus/calculators/parser.py +150 -0
- pysisyphus/color.py +19 -0
- pysisyphus/config.py +133 -0
- pysisyphus/constants.py +65 -0
- pysisyphus/cos/AdaptiveNEB.py +230 -0
- pysisyphus/cos/ChainOfStates.py +725 -0
- pysisyphus/cos/FreeEndNEB.py +25 -0
- pysisyphus/cos/FreezingString.py +103 -0
- pysisyphus/cos/GrowingChainOfStates.py +71 -0
- pysisyphus/cos/GrowingNT.py +309 -0
- pysisyphus/cos/GrowingString.py +508 -0
- pysisyphus/cos/NEB.py +189 -0
- pysisyphus/cos/SimpleZTS.py +64 -0
- pysisyphus/cos/__init__.py +22 -0
- pysisyphus/cos/stiffness.py +199 -0
- pysisyphus/drivers/__init__.py +17 -0
- pysisyphus/drivers/afir.py +855 -0
- pysisyphus/drivers/barriers.py +271 -0
- pysisyphus/drivers/birkholz.py +138 -0
- pysisyphus/drivers/cluster.py +318 -0
- pysisyphus/drivers/diabatization.py +133 -0
- pysisyphus/drivers/merge.py +368 -0
- pysisyphus/drivers/merge_mol2.py +322 -0
- pysisyphus/drivers/opt.py +375 -0
- pysisyphus/drivers/perf.py +91 -0
- pysisyphus/drivers/pka.py +52 -0
- pysisyphus/drivers/precon_pos_rot.py +669 -0
- pysisyphus/drivers/rates.py +480 -0
- pysisyphus/drivers/replace.py +219 -0
- pysisyphus/drivers/scan.py +212 -0
- pysisyphus/drivers/spectrum.py +166 -0
- pysisyphus/drivers/thermo.py +31 -0
- pysisyphus/dynamics/Gaussian.py +103 -0
- pysisyphus/dynamics/__init__.py +20 -0
- pysisyphus/dynamics/colvars.py +136 -0
- pysisyphus/dynamics/driver.py +297 -0
- pysisyphus/dynamics/helpers.py +256 -0
- pysisyphus/dynamics/lincs.py +105 -0
- pysisyphus/dynamics/mdp.py +364 -0
- pysisyphus/dynamics/rattle.py +121 -0
- pysisyphus/dynamics/thermostats.py +128 -0
- pysisyphus/dynamics/wigner.py +266 -0
- pysisyphus/elem_data.py +3473 -0
- pysisyphus/exceptions.py +2 -0
- pysisyphus/filtertrj.py +69 -0
- pysisyphus/helpers.py +623 -0
- pysisyphus/helpers_pure.py +649 -0
- pysisyphus/init_logging.py +50 -0
- pysisyphus/intcoords/Bend.py +69 -0
- pysisyphus/intcoords/Bend2.py +25 -0
- pysisyphus/intcoords/BondedFragment.py +32 -0
- pysisyphus/intcoords/Cartesian.py +41 -0
- pysisyphus/intcoords/CartesianCoords.py +140 -0
- pysisyphus/intcoords/Coords.py +56 -0
- pysisyphus/intcoords/DLC.py +197 -0
- pysisyphus/intcoords/DistanceFunction.py +34 -0
- pysisyphus/intcoords/DummyImproper.py +70 -0
- pysisyphus/intcoords/DummyTorsion.py +72 -0
- pysisyphus/intcoords/LinearBend.py +105 -0
- pysisyphus/intcoords/LinearDisplacement.py +80 -0
- pysisyphus/intcoords/OutOfPlane.py +59 -0
- pysisyphus/intcoords/PrimTypes.py +286 -0
- pysisyphus/intcoords/Primitive.py +137 -0
- pysisyphus/intcoords/RedundantCoords.py +659 -0
- pysisyphus/intcoords/RobustTorsion.py +59 -0
- pysisyphus/intcoords/Rotation.py +147 -0
- pysisyphus/intcoords/Stretch.py +31 -0
- pysisyphus/intcoords/Torsion.py +101 -0
- pysisyphus/intcoords/Torsion2.py +25 -0
- pysisyphus/intcoords/Translation.py +45 -0
- pysisyphus/intcoords/__init__.py +61 -0
- pysisyphus/intcoords/augment_bonds.py +126 -0
- pysisyphus/intcoords/derivatives.py +10512 -0
- pysisyphus/intcoords/eval.py +80 -0
- pysisyphus/intcoords/exceptions.py +37 -0
- pysisyphus/intcoords/findiffs.py +48 -0
- pysisyphus/intcoords/generate_derivatives.py +414 -0
- pysisyphus/intcoords/helpers.py +235 -0
- pysisyphus/intcoords/logging_conf.py +10 -0
- pysisyphus/intcoords/mp_derivatives.py +10836 -0
- pysisyphus/intcoords/setup.py +962 -0
- pysisyphus/intcoords/setup_fast.py +176 -0
- pysisyphus/intcoords/update.py +272 -0
- pysisyphus/intcoords/valid.py +89 -0
- pysisyphus/interpolate/Geodesic.py +93 -0
- pysisyphus/interpolate/IDPP.py +55 -0
- pysisyphus/interpolate/Interpolator.py +116 -0
- pysisyphus/interpolate/LST.py +70 -0
- pysisyphus/interpolate/Redund.py +152 -0
- pysisyphus/interpolate/__init__.py +9 -0
- pysisyphus/interpolate/helpers.py +34 -0
- pysisyphus/io/__init__.py +22 -0
- pysisyphus/io/aomix.py +178 -0
- pysisyphus/io/cjson.py +24 -0
- pysisyphus/io/crd.py +101 -0
- pysisyphus/io/cube.py +220 -0
- pysisyphus/io/fchk.py +184 -0
- pysisyphus/io/hdf5.py +49 -0
- pysisyphus/io/hessian.py +72 -0
- pysisyphus/io/mol2.py +146 -0
- pysisyphus/io/molden.py +293 -0
- pysisyphus/io/orca.py +189 -0
- pysisyphus/io/pdb.py +269 -0
- pysisyphus/io/psf.py +79 -0
- pysisyphus/io/pubchem.py +31 -0
- pysisyphus/io/qcschema.py +34 -0
- pysisyphus/io/sdf.py +29 -0
- pysisyphus/io/xyz.py +61 -0
- pysisyphus/io/zmat.py +175 -0
- pysisyphus/irc/DWI.py +108 -0
- pysisyphus/irc/DampedVelocityVerlet.py +134 -0
- pysisyphus/irc/Euler.py +22 -0
- pysisyphus/irc/EulerPC.py +345 -0
- pysisyphus/irc/GonzalezSchlegel.py +187 -0
- pysisyphus/irc/IMKMod.py +164 -0
- pysisyphus/irc/IRC.py +878 -0
- pysisyphus/irc/IRCDummy.py +10 -0
- pysisyphus/irc/Instanton.py +307 -0
- pysisyphus/irc/LQA.py +53 -0
- pysisyphus/irc/ModeKill.py +136 -0
- pysisyphus/irc/ParamPlot.py +53 -0
- pysisyphus/irc/RK4.py +36 -0
- pysisyphus/irc/__init__.py +31 -0
- pysisyphus/irc/initial_displ.py +219 -0
- pysisyphus/linalg.py +411 -0
- pysisyphus/line_searches/Backtracking.py +88 -0
- pysisyphus/line_searches/HagerZhang.py +184 -0
- pysisyphus/line_searches/LineSearch.py +232 -0
- pysisyphus/line_searches/StrongWolfe.py +108 -0
- pysisyphus/line_searches/__init__.py +9 -0
- pysisyphus/line_searches/interpol.py +15 -0
- pysisyphus/modefollow/NormalMode.py +40 -0
- pysisyphus/modefollow/__init__.py +10 -0
- pysisyphus/modefollow/davidson.py +199 -0
- pysisyphus/modefollow/lanczos.py +95 -0
- pysisyphus/optimizers/BFGS.py +99 -0
- pysisyphus/optimizers/BacktrackingOptimizer.py +113 -0
- pysisyphus/optimizers/ConjugateGradient.py +98 -0
- pysisyphus/optimizers/CubicNewton.py +75 -0
- pysisyphus/optimizers/FIRE.py +113 -0
- pysisyphus/optimizers/HessianOptimizer.py +1176 -0
- pysisyphus/optimizers/LBFGS.py +228 -0
- pysisyphus/optimizers/LayerOpt.py +411 -0
- pysisyphus/optimizers/MicroOptimizer.py +169 -0
- pysisyphus/optimizers/NCOptimizer.py +90 -0
- pysisyphus/optimizers/Optimizer.py +1084 -0
- pysisyphus/optimizers/PreconLBFGS.py +260 -0
- pysisyphus/optimizers/PreconSteepestDescent.py +7 -0
- pysisyphus/optimizers/QuickMin.py +74 -0
- pysisyphus/optimizers/RFOptimizer.py +181 -0
- pysisyphus/optimizers/RSA.py +99 -0
- pysisyphus/optimizers/StabilizedQNMethod.py +248 -0
- pysisyphus/optimizers/SteepestDescent.py +23 -0
- pysisyphus/optimizers/StringOptimizer.py +173 -0
- pysisyphus/optimizers/__init__.py +41 -0
- pysisyphus/optimizers/closures.py +301 -0
- pysisyphus/optimizers/cls_map.py +58 -0
- pysisyphus/optimizers/exceptions.py +6 -0
- pysisyphus/optimizers/gdiis.py +280 -0
- pysisyphus/optimizers/guess_hessians.py +311 -0
- pysisyphus/optimizers/hessian_updates.py +355 -0
- pysisyphus/optimizers/poly_fit.py +285 -0
- pysisyphus/optimizers/precon.py +153 -0
- pysisyphus/optimizers/restrict_step.py +24 -0
- pysisyphus/pack.py +172 -0
- pysisyphus/peakdetect.py +948 -0
- pysisyphus/plot.py +1031 -0
- pysisyphus/run.py +2106 -0
- pysisyphus/socket_helper.py +74 -0
- pysisyphus/stocastic/FragmentKick.py +132 -0
- pysisyphus/stocastic/Kick.py +81 -0
- pysisyphus/stocastic/Pipeline.py +303 -0
- pysisyphus/stocastic/__init__.py +21 -0
- pysisyphus/stocastic/align.py +127 -0
- pysisyphus/testing.py +96 -0
- pysisyphus/thermo.py +156 -0
- pysisyphus/trj.py +824 -0
- pysisyphus/tsoptimizers/RSIRFOptimizer.py +56 -0
- pysisyphus/tsoptimizers/RSPRFOptimizer.py +182 -0
- pysisyphus/tsoptimizers/TRIM.py +59 -0
- pysisyphus/tsoptimizers/TSHessianOptimizer.py +463 -0
- pysisyphus/tsoptimizers/__init__.py +23 -0
- pysisyphus/wavefunction/Basis.py +239 -0
- pysisyphus/wavefunction/DIIS.py +76 -0
- pysisyphus/wavefunction/__init__.py +25 -0
- pysisyphus/wavefunction/build_ext.py +42 -0
- pysisyphus/wavefunction/cart2sph.py +190 -0
- pysisyphus/wavefunction/diabatization.py +304 -0
- pysisyphus/wavefunction/excited_states.py +435 -0
- pysisyphus/wavefunction/gen_ints.py +1811 -0
- pysisyphus/wavefunction/helpers.py +104 -0
- pysisyphus/wavefunction/ints/__init__.py +0 -0
- pysisyphus/wavefunction/ints/boys.py +193 -0
- pysisyphus/wavefunction/ints/boys_table_N_64_xasym_27.1_step_0.01.npy +0 -0
- pysisyphus/wavefunction/ints/cart_gto3d.py +176 -0
- pysisyphus/wavefunction/ints/coulomb3d.py +25928 -0
- pysisyphus/wavefunction/ints/diag_quadrupole3d.py +10036 -0
- pysisyphus/wavefunction/ints/dipole3d.py +8762 -0
- pysisyphus/wavefunction/ints/int2c2e3d.py +7198 -0
- pysisyphus/wavefunction/ints/int3c2e3d_sph.py +65040 -0
- pysisyphus/wavefunction/ints/kinetic3d.py +8240 -0
- pysisyphus/wavefunction/ints/ovlp3d.py +3777 -0
- pysisyphus/wavefunction/ints/quadrupole3d.py +15054 -0
- pysisyphus/wavefunction/ints/self_ovlp3d.py +198 -0
- pysisyphus/wavefunction/localization.py +458 -0
- pysisyphus/wavefunction/multipole.py +159 -0
- pysisyphus/wavefunction/normalization.py +36 -0
- pysisyphus/wavefunction/pop_analysis.py +134 -0
- pysisyphus/wavefunction/shells.py +1171 -0
- pysisyphus/wavefunction/wavefunction.py +504 -0
- pysisyphus/wrapper/__init__.py +11 -0
- pysisyphus/wrapper/exceptions.py +2 -0
- pysisyphus/wrapper/jmol.py +120 -0
- pysisyphus/wrapper/mwfn.py +169 -0
- pysisyphus/wrapper/packmol.py +71 -0
- pysisyphus/xyzloader.py +168 -0
- pysisyphus/yaml_mods.py +45 -0
- thermoanalysis/LICENSE +674 -0
- thermoanalysis/QCData.py +244 -0
- thermoanalysis/__init__.py +0 -0
- thermoanalysis/config.py +3 -0
- thermoanalysis/constants.py +20 -0
- thermoanalysis/thermo.py +1011 -0
mlmm/mlmm_calc.py
ADDED
|
@@ -0,0 +1,2262 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ONIOM-like ML/MM calculator coupling MLIP backends (UMA, ORB, MACE, AIMNet2)
|
|
3
|
+
with hessian_ff (MM).
|
|
4
|
+
|
|
5
|
+
Example:
|
|
6
|
+
calc = mlmm(input_pdb="input.pdb", real_parm7="real.parm7", model_pdb="model.pdb", charge=0)
|
|
7
|
+
|
|
8
|
+
For detailed documentation, see: docs/mlmm_calc.md
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import abc
|
|
14
|
+
import logging
|
|
15
|
+
import os
|
|
16
|
+
import warnings
|
|
17
|
+
import shutil
|
|
18
|
+
import tempfile
|
|
19
|
+
import time
|
|
20
|
+
from dataclasses import dataclass
|
|
21
|
+
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
|
22
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
import click
|
|
27
|
+
import numpy as np
|
|
28
|
+
import torch
|
|
29
|
+
import torch.nn as nn
|
|
30
|
+
|
|
31
|
+
from ase import Atoms
|
|
32
|
+
from ase.io import read
|
|
33
|
+
from ase.calculators.calculator import Calculator, all_changes
|
|
34
|
+
from ase.constraints import FixAtoms
|
|
35
|
+
|
|
36
|
+
import parmed as pmd
|
|
37
|
+
from hessian_ff import ForceFieldTorch, load_coords, load_system
|
|
38
|
+
from hessian_ff.analytical_hessian import build_analytical_hessian
|
|
39
|
+
|
|
40
|
+
# Optional OpenMM import
|
|
41
|
+
try:
|
|
42
|
+
import openmm as mm
|
|
43
|
+
from openmm import app, unit, Platform
|
|
44
|
+
from openmm.unit import ScaledUnit, joule
|
|
45
|
+
HAS_OPENMM = True
|
|
46
|
+
except ImportError:
|
|
47
|
+
HAS_OPENMM = False
|
|
48
|
+
|
|
49
|
+
# Optional fairchem import (UMA backend)
|
|
50
|
+
try:
|
|
51
|
+
from fairchem.core import pretrained_mlip
|
|
52
|
+
from fairchem.core.datasets.atomic_data import AtomicData
|
|
53
|
+
from fairchem.core.datasets import data_list_collater
|
|
54
|
+
HAS_FAIRCHEM = True
|
|
55
|
+
except ImportError:
|
|
56
|
+
HAS_FAIRCHEM = False
|
|
57
|
+
|
|
58
|
+
# Optional ORB backend
|
|
59
|
+
try:
|
|
60
|
+
import orb_models # noqa: F401
|
|
61
|
+
HAS_ORB = True
|
|
62
|
+
except ImportError:
|
|
63
|
+
HAS_ORB = False
|
|
64
|
+
|
|
65
|
+
# Optional MACE backend
|
|
66
|
+
try:
|
|
67
|
+
import mace # noqa: F401
|
|
68
|
+
HAS_MACE = True
|
|
69
|
+
except ImportError:
|
|
70
|
+
HAS_MACE = False
|
|
71
|
+
|
|
72
|
+
# Optional AIMNet2 backend
|
|
73
|
+
try:
|
|
74
|
+
import aimnet # noqa: F401
|
|
75
|
+
HAS_AIMNET2 = True
|
|
76
|
+
except ImportError:
|
|
77
|
+
HAS_AIMNET2 = False
|
|
78
|
+
|
|
79
|
+
# ---------- PySisyphus unit constants ----------
|
|
80
|
+
from pysisyphus.constants import BOHR2ANG, ANG2BOHR, AU2EV, AU2KCALPERMOL
|
|
81
|
+
# Derived unit conversion factors built on top of the pysisyphus constants.
EV2AU = 1.0 / AU2EV  # eV → Hartree (inverse of the Hartree → eV factor AU2EV)
KCALMOL2EV = AU2EV / AU2KCALPERMOL  # kcal/mol -> eV (Hartree→eV divided by Hartree→kcal/mol)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
# ======================================================================
|
|
86
|
+
# ML Backend Abstraction
|
|
87
|
+
# ======================================================================
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class _MLBackend(abc.ABC):
|
|
91
|
+
"""Internal abstraction for the ML part of the ONIOM ML/MM coupling.
|
|
92
|
+
|
|
93
|
+
Each backend must provide energy/force evaluation and Hessian computation.
|
|
94
|
+
All quantities are in eV and Angstrom.
|
|
95
|
+
"""
|
|
96
|
+
|
|
97
|
+
@abc.abstractmethod
|
|
98
|
+
def eval(
|
|
99
|
+
self, atoms: Atoms, need_grad: bool = True
|
|
100
|
+
) -> Tuple[float, np.ndarray, Any]:
|
|
101
|
+
"""Evaluate energy and forces.
|
|
102
|
+
|
|
103
|
+
Returns
|
|
104
|
+
-------
|
|
105
|
+
E : float
|
|
106
|
+
Energy in eV.
|
|
107
|
+
F : ndarray (N, 3)
|
|
108
|
+
Forces in eV/Å.
|
|
109
|
+
opaque : Any
|
|
110
|
+
Backend-specific data needed for analytical Hessian (e.g., batch).
|
|
111
|
+
"""
|
|
112
|
+
|
|
113
|
+
@abc.abstractmethod
|
|
114
|
+
def hessian_analytical(self, opaque: Any, n_atoms: int, *, dtype: torch.dtype) -> torch.Tensor:
|
|
115
|
+
"""Compute analytical Hessian from the opaque batch returned by eval().
|
|
116
|
+
|
|
117
|
+
Returns Hessian as a (N, 3, N, 3) torch Tensor in eV/Ų.
|
|
118
|
+
"""
|
|
119
|
+
|
|
120
|
+
def hessian_fd(
|
|
121
|
+
self,
|
|
122
|
+
atoms: Atoms,
|
|
123
|
+
freeze_model: Sequence[int],
|
|
124
|
+
*,
|
|
125
|
+
eps_ang: float = 1.0e-3,
|
|
126
|
+
dtype: torch.dtype = torch.float32,
|
|
127
|
+
device: torch.device = torch.device("cpu"),
|
|
128
|
+
) -> torch.Tensor:
|
|
129
|
+
"""Compute Hessian via finite differences (central difference).
|
|
130
|
+
|
|
131
|
+
Generic implementation that works for all backends.
|
|
132
|
+
"""
|
|
133
|
+
n_atoms = len(atoms)
|
|
134
|
+
dof = n_atoms * 3
|
|
135
|
+
|
|
136
|
+
frozen_set = set(int(i) for i in freeze_model)
|
|
137
|
+
active_atoms = [i for i in range(n_atoms) if i not in frozen_set]
|
|
138
|
+
active_dof_idx = [3 * i + j for i in active_atoms for j in range(3)]
|
|
139
|
+
|
|
140
|
+
H = torch.zeros((dof, dof), device=device, dtype=dtype)
|
|
141
|
+
coord0 = atoms.get_positions().copy()
|
|
142
|
+
for k in active_dof_idx:
|
|
143
|
+
a = k // 3
|
|
144
|
+
c = k % 3
|
|
145
|
+
|
|
146
|
+
atoms.positions = coord0.copy()
|
|
147
|
+
atoms.positions[a, c] = coord0[a, c] + eps_ang
|
|
148
|
+
_, Fp, _ = self.eval(atoms, need_grad=False)
|
|
149
|
+
|
|
150
|
+
atoms.positions = coord0.copy()
|
|
151
|
+
atoms.positions[a, c] = coord0[a, c] - eps_ang
|
|
152
|
+
_, Fm, _ = self.eval(atoms, need_grad=False)
|
|
153
|
+
|
|
154
|
+
col = -(torch.from_numpy(Fp.reshape(-1)) - torch.from_numpy(Fm.reshape(-1))) / (2.0 * eps_ang)
|
|
155
|
+
H[:, k] = col.to(device, dtype=dtype)
|
|
156
|
+
|
|
157
|
+
atoms.positions = coord0
|
|
158
|
+
return H.view(n_atoms, 3, n_atoms, 3)
|
|
159
|
+
|
|
160
|
+
@property
|
|
161
|
+
@abc.abstractmethod
|
|
162
|
+
def supports_analytical_hessian(self) -> bool:
|
|
163
|
+
"""Whether this backend supports analytical Hessian."""
|
|
164
|
+
|
|
165
|
+
@property
|
|
166
|
+
@abc.abstractmethod
|
|
167
|
+
def device(self) -> torch.device:
|
|
168
|
+
"""The torch device this backend uses."""
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
class _UMABackend(_MLBackend):
    """UMA (FAIR-Chem) ML backend.

    Wraps a fairchem ``pretrained_mlip`` predict unit. ``eval`` returns the
    collated batch as its opaque object, and ``hessian_analytical`` reuses
    that batch to compute the Hessian by double autograd.
    """

    def __init__(
        self,
        *,
        uma_model: str = "uma-s-1p1",
        uma_task_name: str = "omol",
        model_charge: int = 0,
        model_mult: int = 1,
        ml_device: torch.device,
    ):
        """Load the UMA predictor on ``ml_device``.

        Raises
        ------
        ImportError
            If fairchem-core could not be imported at module load time.
        """
        if not HAS_FAIRCHEM:
            raise ImportError(
                "fairchem-core is required for the UMA backend. "
                "Install with `pip install fairchem-core` "
                "and ensure Hugging Face authentication is configured."
            )
        self._device = ml_device
        device_str = "cuda" if ml_device.type == "cuda" else "cpu"
        self._AtomicData = AtomicData
        self._data_list_collater = data_list_collater
        self.predictor = pretrained_mlip.get_predict_unit(uma_model, device=device_str)
        self.predictor.model.eval()
        # Zero every dropout probability so that temporarily switching the
        # model to train() mode in hessian_analytical stays deterministic.
        for m in self.predictor.model.modules():
            if isinstance(m, nn.Dropout):
                m.p = 0.0
        self.uma_task_name = uma_task_name
        self.model_charge = model_charge
        self.model_mult = model_mult
        # The model may be wrapped (e.g. exposing ``.module``); read graph
        # construction settings from the backbone, or None if it lacks them.
        backbone = getattr(self.predictor.model, "module", self.predictor.model).backbone
        self._uma_max_neigh = getattr(backbone, "max_neighbors", None)
        self._uma_radius = getattr(backbone, "cutoff", None)

    @property
    def supports_analytical_hessian(self) -> bool:
        return True

    @property
    def device(self) -> torch.device:
        return self._device

    def eval(self, atoms: Atoms, need_grad: bool = True) -> Tuple[float, np.ndarray, Any]:
        """Predict energy (eV) and forces (eV/Å); return the batch as opaque data.

        Side effect: writes ``charge`` and ``spin`` into ``atoms.info`` of the
        caller's Atoms object (it is not copied).
        When ``need_grad`` is False the prediction runs under ``torch.no_grad()``.
        """
        # NOTE(review): fairchem's UMA examples appear to set info["spin"] to the
        # spin multiplicity (1 = singlet, 3 = triplet), whereas this passes
        # model_mult - 1 (i.e. 2S). Confirm which convention the installed
        # fairchem version expects for the chosen task.
        atoms.info.update({"charge": self.model_charge, "spin": self.model_mult - 1})
        data = self._AtomicData.from_ase(
            atoms,
            max_neigh=self._uma_max_neigh,
            radius=self._uma_radius,
            r_edges=False,
        ).to(self._device)
        data.dataset = self.uma_task_name
        batch = self._data_list_collater([data], otf_graph=True).to(self._device)
        # Replace positions with a fresh leaf tensor so gradient tracking on it
        # is controlled solely by need_grad.
        pos = batch.pos.detach().clone().to(self._device)
        pos.requires_grad_(need_grad)
        batch.pos = pos
        if need_grad:
            res = self.predictor.predict(batch)
        else:
            with torch.no_grad():
                res = self.predictor.predict(batch)
        E = float(res["energy"].squeeze().detach().item())
        F = res["forces"].detach().cpu().numpy()
        return E, F, batch

    def hessian_analytical(self, opaque: Any, n_atoms: int, *, dtype: torch.dtype) -> torch.Tensor:
        """Compute the (n_atoms, 3, n_atoms, 3) Hessian of the predicted energy.

        ``opaque`` must be the batch returned by :meth:`eval`. Its ``pos``
        attribute is overwritten during differentiation. Parameter
        ``requires_grad`` flags and the model's eval() mode are restored in
        ``finally``; the CUDA cache is emptied afterwards when on GPU.
        """
        batch = opaque
        # Freeze model parameters so autograd only differentiates w.r.t.
        # positions; remember the original flags to restore them later.
        p_flags = [p.requires_grad for p in self.predictor.model.parameters()]
        for p in self.predictor.model.parameters():
            p.requires_grad_(False)

        # Run in train() mode while differentiating (dropout was zeroed in
        # __init__). NOTE(review): presumably needed so the predictor keeps a
        # differentiable graph — confirm against the fairchem implementation.
        self.predictor.model.train()
        try:
            pos = batch.pos

            def energy_fn(flat_pos: torch.Tensor):
                # Scalar energy as a function of the flattened (3N,) coordinates.
                batch.pos = flat_pos.view(-1, 3)
                return self.predictor.predict(batch)["energy"].squeeze()

            H_flat = torch.autograd.functional.hessian(energy_fn, pos.view(-1), vectorize=False)
            H = H_flat.view(n_atoms, 3, n_atoms, 3).to(dtype).detach()
        finally:
            self.predictor.model.eval()
            for p, flag in zip(self.predictor.model.parameters(), p_flags):
                p.requires_grad_(flag)
            if self._device.type == "cuda":
                torch.cuda.empty_cache()
        return H
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
class _ASEMLBackend(_MLBackend):
    """Base class for ASE-calculator-based ML backends (ORB, MACE, AIMNet2).

    Subclasses must set ``self._ase_calc`` (an ASE Calculator) and
    ``self._device``. They may override ``_model_charge`` / ``_model_mult``,
    which are written into ``atoms.info`` before every evaluation.
    Only finite-difference Hessians are available for these backends.
    """

    _ase_calc: Calculator
    _device: torch.device
    _model_charge: int = 0
    _model_mult: int = 1

    @property
    def supports_analytical_hessian(self) -> bool:
        return False

    @property
    def device(self) -> torch.device:
        return self._device

    def eval(self, atoms: Atoms, need_grad: bool = True) -> Tuple[float, np.ndarray, Any]:
        """Evaluate raw model energy (eV) and forces (eV/Å) with the ASE calculator.

        A copy of ``atoms`` is used, so the caller's calculator and ``info``
        are untouched. ``need_grad`` is accepted for interface compatibility
        and ignored. The opaque third return value is always ``None``.

        Constraints copied from the caller's Atoms (e.g. ``FixAtoms``) are
        deliberately NOT applied. This matches the UMA backend, which returns
        raw model forces, and keeps finite-difference Hessian rows of
        constrained atoms from being zeroed.
        """
        atoms_copy = atoms.copy()
        atoms_copy.calc = self._ase_calc
        # Propagate charge/spin to ASE Atoms info for backends that use them
        # (e.g. AIMNet2 reads atoms.info['charge'] and atoms.info['mult'])
        atoms_copy.info["charge"] = self._model_charge
        atoms_copy.info["mult"] = self._model_mult
        E = float(atoms_copy.get_potential_energy(apply_constraint=False))
        F = np.array(atoms_copy.get_forces(apply_constraint=False), dtype=np.float64)
        return E, F, None

    def hessian_analytical(self, opaque: Any, n_atoms: int, *, dtype: torch.dtype) -> torch.Tensor:
        """Always raises: ASE-based backends only support finite differences."""
        raise NotImplementedError(
            f"Analytical Hessian is not supported by {self.__class__.__name__}. "
            "Use hessian_calc_mode='FiniteDifference'."
        )
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
class _OrbBackend(_ASEMLBackend):
    """ORB (Orbital Materials) ML backend.

    ``orb_model`` names a factory function in ``orb_models.forcefield.pretrained``.
    Hyphenated spellings (e.g. ``"orb-v3-conservative-omol"``) are accepted
    and mapped to the underscore factory name (``orb_v3_conservative_omol``).
    """

    def __init__(
        self,
        *,
        orb_model: str = "orb_v3_conservative_omol",
        model_charge: int = 0,
        model_mult: int = 1,
        ml_device: torch.device,
        **_kwargs,  # absorb unused keys (e.g. orb_precision)
    ):
        if not HAS_ORB:
            raise ImportError(
                "orb-models is required for the ORB backend. "
                "Install with `pip install orb-models`."
            )
        from orb_models.forcefield import pretrained
        from orb_models.forcefield.calculator import ORBCalculator

        device_str = "cuda" if ml_device.type == "cuda" else "cpu"
        # Python attribute names cannot contain '-', so map the hyphenated
        # model names to the corresponding factory function name.
        factory_name = orb_model.strip().replace("-", "_")
        orbff = getattr(pretrained, factory_name)(device=device_str)
        self._ase_calc = ORBCalculator(orbff, device=device_str)
        self._device = ml_device
        self._model_charge = model_charge
        self._model_mult = model_mult
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
class _MACEBackend(_ASEMLBackend):
    """MACE ML backend.

    The ``mace_model`` string selects a mace-torch calculator factory
    (matched case-insensitively on its prefix):

    * ``mp:<name>`` / ``mace-mp...``  -> ``mace_mp`` (text after the first ':'
      is used as the model name when a ':' is present)
    * ``off:<name>`` / ``mace-off...`` -> ``mace_off`` (same ':' handling)
    * ``anicc...`` / ``mace-anicc...`` -> ``mace_anicc``
    * anything else (including ``omol...`` / ``mace-omol...``) -> ``mace_off``
      with the string passed through unchanged (e.g. a local model file)
    """

    def __init__(
        self,
        *,
        mace_model: str = "MACE-OMOL-0",
        mace_dtype: str = "float64",
        model_charge: int = 0,
        model_mult: int = 1,
        ml_device: torch.device,
    ):
        if not HAS_MACE:
            raise ImportError(
                "mace-torch is required for the MACE backend. "
                "Install with `pip install mace-torch`."
            )
        from mace.calculators import mace_off, mace_mp, mace_anicc

        target = "cuda" if ml_device.type == "cuda" else "cpu"
        lowered = mace_model.lower()
        # Strip an explicit "family:" prefix, keeping the rest of the name.
        bare_name = mace_model.split(":", 1)[-1] if ":" in mace_model else mace_model

        if lowered.startswith(("mp:", "mace-mp")):
            calc = mace_mp(model=bare_name, device=target, default_dtype=mace_dtype)
        elif lowered.startswith(("off:", "mace-off")):
            calc = mace_off(model=bare_name, device=target, default_dtype=mace_dtype)
        elif lowered.startswith(("anicc", "mace-anicc")):
            calc = mace_anicc(device=target, default_dtype=mace_dtype)
        else:
            # Covers both the MACE-OMOL names and local model files.
            # NOTE(review): mace_off may only recognise its own size keywords
            # or a file path; confirm that a name such as "MACE-OMOL-0" is
            # resolvable this way with the installed mace-torch (newer
            # releases may provide a dedicated OMOL factory).
            calc = mace_off(model=mace_model, device=target, default_dtype=mace_dtype)
        self._ase_calc = calc

        self._device = ml_device
        self._model_charge = model_charge
        self._model_mult = model_mult
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
class _AIMNet2Backend(_ASEMLBackend):
    """AIMNet2 ML backend.

    Builds ``aimnet.calculators.AIMNet2Calculator`` for ``aimnet2_model`` on
    the CUDA or CPU device implied by ``ml_device``. Charge and multiplicity
    are stored and later written into ``atoms.info`` by the base class.
    """

    def __init__(
        self,
        *,
        aimnet2_model: str = "aimnet2",
        model_charge: int = 0,
        model_mult: int = 1,
        ml_device: torch.device,
    ):
        if not HAS_AIMNET2:
            raise ImportError(
                "aimnet is required for the AIMNet2 backend. "
                "Install with `pip install aimnet`."
            )
        from aimnet.calculators import AIMNet2Calculator

        on_gpu = ml_device.type == "cuda"
        # NOTE(review): the base class assigns this object as ``atoms.calc``;
        # confirm AIMNet2Calculator is an ASE Calculator in the installed
        # aimnet version (the package may expose a separate ASE wrapper).
        self._ase_calc = AIMNet2Calculator(
            model=aimnet2_model, device="cuda" if on_gpu else "cpu"
        )
        self._device = ml_device
        self._model_charge = model_charge
        self._model_mult = model_mult
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
def _create_ml_backend(
|
|
403
|
+
backend: str,
|
|
404
|
+
*,
|
|
405
|
+
uma_model: str = "uma-s-1p1",
|
|
406
|
+
uma_task_name: str = "omol",
|
|
407
|
+
orb_model: str = "orb_v3_conservative_omol",
|
|
408
|
+
mace_model: str = "MACE-OMOL-0",
|
|
409
|
+
mace_dtype: str = "float64",
|
|
410
|
+
aimnet2_model: str = "aimnet2",
|
|
411
|
+
model_charge: int = 0,
|
|
412
|
+
model_mult: int = 1,
|
|
413
|
+
ml_device: torch.device,
|
|
414
|
+
) -> _MLBackend:
|
|
415
|
+
"""Factory function to create the appropriate ML backend."""
|
|
416
|
+
backend = backend.strip().lower()
|
|
417
|
+
if backend == "uma":
|
|
418
|
+
return _UMABackend(
|
|
419
|
+
uma_model=uma_model,
|
|
420
|
+
uma_task_name=uma_task_name,
|
|
421
|
+
model_charge=model_charge,
|
|
422
|
+
model_mult=model_mult,
|
|
423
|
+
ml_device=ml_device,
|
|
424
|
+
)
|
|
425
|
+
elif backend == "orb":
|
|
426
|
+
return _OrbBackend(
|
|
427
|
+
orb_model=orb_model,
|
|
428
|
+
model_charge=model_charge,
|
|
429
|
+
model_mult=model_mult,
|
|
430
|
+
ml_device=ml_device,
|
|
431
|
+
)
|
|
432
|
+
elif backend == "mace":
|
|
433
|
+
return _MACEBackend(
|
|
434
|
+
mace_model=mace_model,
|
|
435
|
+
mace_dtype=mace_dtype,
|
|
436
|
+
model_charge=model_charge,
|
|
437
|
+
model_mult=model_mult,
|
|
438
|
+
ml_device=ml_device,
|
|
439
|
+
)
|
|
440
|
+
elif backend == "aimnet2":
|
|
441
|
+
return _AIMNet2Backend(
|
|
442
|
+
aimnet2_model=aimnet2_model,
|
|
443
|
+
model_charge=model_charge,
|
|
444
|
+
model_mult=model_mult,
|
|
445
|
+
ml_device=ml_device,
|
|
446
|
+
)
|
|
447
|
+
else:
|
|
448
|
+
raise ValueError(
|
|
449
|
+
f"Unknown ML backend '{backend}'. Choose from: uma, orb, mace, aimnet2."
|
|
450
|
+
)
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
# ======================================================================
|
|
454
|
+
# xTB Point-Charge Embedding Correction
|
|
455
|
+
# ======================================================================
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
class _EmbedChargeCorrection:
|
|
459
|
+
"""xTB-based point-charge embedding correction for ONIOM ML/MM.
|
|
460
|
+
|
|
461
|
+
Computes the electrostatic interaction between the ML region and
|
|
462
|
+
the MM point charges via xTB:
|
|
463
|
+
|
|
464
|
+
dE = E_xTB(ML + MM_charges) - E_xTB(ML_only)
|
|
465
|
+
dF = F_xTB(ML + MM_charges) - F_xTB(ML_only)
|
|
466
|
+
|
|
467
|
+
This accounts for the environmental electrostatic effect of MM
|
|
468
|
+
atoms on the ML region, which is not captured by the subtractive
|
|
469
|
+
ONIOM scheme alone.
|
|
470
|
+
"""
|
|
471
|
+
|
|
472
|
+
def __init__(
|
|
473
|
+
self,
|
|
474
|
+
*,
|
|
475
|
+
xtb_cmd: str = "xtb",
|
|
476
|
+
xtb_acc: float = 0.2,
|
|
477
|
+
xtb_workdir: str = "tmp",
|
|
478
|
+
xtb_keep_files: bool = False,
|
|
479
|
+
xtb_ncores: int = 4,
|
|
480
|
+
hessian_step: float = 1.0e-3,
|
|
481
|
+
):
|
|
482
|
+
self.xtb_cmd = xtb_cmd
|
|
483
|
+
self.xtb_acc = xtb_acc
|
|
484
|
+
self.xtb_workdir = xtb_workdir
|
|
485
|
+
self.xtb_keep_files = xtb_keep_files
|
|
486
|
+
self.xtb_ncores = xtb_ncores
|
|
487
|
+
self.hessian_step = hessian_step
|
|
488
|
+
|
|
489
|
+
def compute_correction(
|
|
490
|
+
self,
|
|
491
|
+
symbols: List[str],
|
|
492
|
+
coords_ml_ang: np.ndarray,
|
|
493
|
+
mm_coords_ang: np.ndarray,
|
|
494
|
+
mm_charges: np.ndarray,
|
|
495
|
+
charge: int,
|
|
496
|
+
multiplicity: int,
|
|
497
|
+
*,
|
|
498
|
+
need_forces: bool = False,
|
|
499
|
+
need_hessian: bool = False,
|
|
500
|
+
) -> Tuple[float, Optional[np.ndarray], Optional[np.ndarray]]:
|
|
501
|
+
"""Compute point-charge embedding correction.
|
|
502
|
+
|
|
503
|
+
Parameters
|
|
504
|
+
----------
|
|
505
|
+
symbols : list of str
|
|
506
|
+
Element symbols for ML atoms.
|
|
507
|
+
coords_ml_ang : ndarray (N_ML, 3)
|
|
508
|
+
Coordinates of ML atoms in Angstrom.
|
|
509
|
+
mm_coords_ang : ndarray (N_MM, 3)
|
|
510
|
+
Coordinates of MM point charges in Angstrom.
|
|
511
|
+
mm_charges : ndarray (N_MM,)
|
|
512
|
+
Charges of MM point charges in atomic units.
|
|
513
|
+
charge : int
|
|
514
|
+
Total charge of the ML region.
|
|
515
|
+
multiplicity : int
|
|
516
|
+
Spin multiplicity of the ML region.
|
|
517
|
+
need_forces : bool
|
|
518
|
+
Whether to compute force corrections.
|
|
519
|
+
need_hessian : bool
|
|
520
|
+
Whether to compute Hessian corrections.
|
|
521
|
+
|
|
522
|
+
Returns
|
|
523
|
+
-------
|
|
524
|
+
dE : float
|
|
525
|
+
Energy correction in eV.
|
|
526
|
+
dF_ml : ndarray (N_ML, 3) or None
|
|
527
|
+
Force corrections for ML atoms in eV/Å.
|
|
528
|
+
dH_ml : ndarray (3*N_ML, 3*N_ML) or None
|
|
529
|
+
Hessian correction for ML atoms in eV/Ų.
|
|
530
|
+
"""
|
|
531
|
+
from .xtb_embedcharge_correction import delta_embedcharge_minus_noembed
|
|
532
|
+
|
|
533
|
+
n_ml = len(symbols)
|
|
534
|
+
mm_coords = np.asarray(mm_coords_ang, dtype=np.float64).reshape(-1, 3)
|
|
535
|
+
mm_q = np.asarray(mm_charges, dtype=np.float64).reshape(-1)
|
|
536
|
+
n_mm = mm_q.shape[0]
|
|
537
|
+
|
|
538
|
+
if n_mm == 0:
|
|
539
|
+
dF = np.zeros((n_ml, 3), dtype=np.float64) if need_forces else None
|
|
540
|
+
dH = np.zeros((3 * n_ml, 3 * n_ml), dtype=np.float64) if need_hessian else None
|
|
541
|
+
return 0.0, dF, dH
|
|
542
|
+
|
|
543
|
+
dE_ev, dF_full_ev, dH_full_ev = delta_embedcharge_minus_noembed(
|
|
544
|
+
symbols=symbols,
|
|
545
|
+
coords_q_ang=np.asarray(coords_ml_ang, dtype=np.float64).reshape(-1, 3),
|
|
546
|
+
mm_coords_ang=mm_coords,
|
|
547
|
+
mm_charges=mm_q,
|
|
548
|
+
charge=charge,
|
|
549
|
+
multiplicity=multiplicity,
|
|
550
|
+
need_forces=need_forces or need_hessian,
|
|
551
|
+
need_hessian=need_hessian,
|
|
552
|
+
xtb_cmd=self.xtb_cmd,
|
|
553
|
+
xtb_acc=self.xtb_acc,
|
|
554
|
+
xtb_workdir=self.xtb_workdir,
|
|
555
|
+
xtb_keep_files=self.xtb_keep_files,
|
|
556
|
+
ncores=self.xtb_ncores,
|
|
557
|
+
hessian_step=self.hessian_step,
|
|
558
|
+
)
|
|
559
|
+
|
|
560
|
+
dF_ml = None
|
|
561
|
+
if dF_full_ev is not None:
|
|
562
|
+
# Extract only the ML-atom forces (first n_ml rows)
|
|
563
|
+
dF_ml = np.asarray(dF_full_ev, dtype=np.float64).reshape(-1, 3)[:n_ml]
|
|
564
|
+
|
|
565
|
+
dH_ml = None
|
|
566
|
+
if dH_full_ev is not None:
|
|
567
|
+
# Extract only the ML-atom Hessian block
|
|
568
|
+
dof_ml = 3 * n_ml
|
|
569
|
+
dH_full = np.asarray(dH_full_ev, dtype=np.float64)
|
|
570
|
+
dH_ml = dH_full[:dof_ml, :dof_ml]
|
|
571
|
+
|
|
572
|
+
return float(dE_ev), dF_ml, dH_ml
|
|
573
|
+
|
|
574
|
+
|
|
575
|
+
# ======================================================================
|
|
576
|
+
# Utilities
|
|
577
|
+
# ======================================================================
|
|
578
|
+
|
|
579
|
+
def _fixed_indices_from_constraints(atoms: Atoms) -> set[int]:
|
|
580
|
+
fixed: set[int] = set()
|
|
581
|
+
for c in atoms.constraints or []:
|
|
582
|
+
if isinstance(c, FixAtoms):
|
|
583
|
+
fixed.update(int(i) for i in c.get_indices())
|
|
584
|
+
return fixed
|
|
585
|
+
|
|
586
|
+
|
|
587
|
+
def _normalize_prmtop_lj_tables(parm7_path: str) -> None:
    """Normalize LJ table lengths in parm7 files generated from sliced structures.

    ParmEd slicing can leave ``LENNARD_JONES_*COEF`` longer than the ``POINTERS``
    ``NTYPES`` expectation. Trim only the trailing unused tail when detected.

    The file is left untouched when it already loads with ``AmberParm``. Any
    load error that does not mention an LJ coefficient flag is re-raised.
    After trimming, the file is rewritten in place and re-validated.

    Raises
    ------
    ValueError
        If ``POINTERS`` is malformed or an LJ table is shorter than
        ``NTYPES * (NTYPES + 1) / 2``.
    """
    from parmed.amber import AmberFormat, AmberParm

    lj_markers = ("FLAG LENNARD_JONES_ACOEF", "FLAG LENNARD_JONES_BCOEF")
    try:
        AmberParm(parm7_path)
    except Exception as exc:
        reason = str(exc)
        if not any(marker in reason for marker in lj_markers):
            raise
    else:
        return  # Topology already loads cleanly; nothing to normalize.

    raw = AmberFormat(parm7_path)
    pointer_values = list(raw.parm_data.get("POINTERS", []))
    if len(pointer_values) < 2:
        raise ValueError(f"Invalid POINTERS section in parm7: {parm7_path}")
    ntypes = int(pointer_values[1])
    # Packed upper-triangle size of an NTYPES x NTYPES pair table.
    n_pairs = ntypes * (ntypes + 1) // 2

    rewritten = False
    for table in ("LENNARD_JONES_ACOEF", "LENNARD_JONES_BCOEF"):
        coeffs = list(raw.parm_data.get(table, []))
        n_coeffs = len(coeffs)
        if n_coeffs < n_pairs:
            raise ValueError(
                f"{table} has {n_coeffs} entries but expected at least {n_pairs} "
                f"from NTYPES={ntypes} in {parm7_path}."
            )
        if n_coeffs > n_pairs:
            raw.parm_data[table] = coeffs[:n_pairs]
            rewritten = True

    if rewritten:
        raw.write_parm(parm7_path)

    # Validate normalized topology immediately.
    AmberParm(parm7_path)
|
|
631
|
+
|
|
632
|
+
|
|
633
|
+
# ======================================================================
|
|
634
|
+
# hessian_ff (MM) -> ASE calculator
|
|
635
|
+
# ======================================================================
|
|
636
|
+
|
|
637
|
+
def _expand_partial_hessian(
|
|
638
|
+
h_sub: np.ndarray,
|
|
639
|
+
active_atoms: np.ndarray,
|
|
640
|
+
n_atoms: int,
|
|
641
|
+
*,
|
|
642
|
+
dtype: np.dtype,
|
|
643
|
+
) -> np.ndarray:
|
|
644
|
+
h_full = np.zeros((3 * n_atoms, 3 * n_atoms), dtype=dtype)
|
|
645
|
+
for i_local, i_atom in enumerate(active_atoms):
|
|
646
|
+
i0 = 3 * int(i_atom)
|
|
647
|
+
for j_local, j_atom in enumerate(active_atoms):
|
|
648
|
+
j0 = 3 * int(j_atom)
|
|
649
|
+
h_full[i0:i0 + 3, j0:j0 + 3] = h_sub[
|
|
650
|
+
3 * i_local:3 * i_local + 3,
|
|
651
|
+
3 * j_local:3 * j_local + 3,
|
|
652
|
+
]
|
|
653
|
+
return h_full
|
|
654
|
+
|
|
655
|
+
|
|
656
|
+
class hessianffCalculator(Calculator):
    """Calculator for MM. hessian_ff-backed.

    CPU-only ASE calculator: energies/forces come from
    ``ForceFieldTorch.energy_force`` and the Hessian from
    ``build_analytical_hessian``. Raw results are multiplied by
    ``KCALMOL2EV`` before being returned.
    """

    implemented_properties = ["energy", "forces"]

    def __init__(
        self,
        parm7: str,
        rst7: Optional[str] = None,
        *,
        device: str = "auto",
        cuda_idx: int = 0,
        threads: int = 16,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if str(device).lower() not in ("auto", "cpu"):
            raise ValueError(
                "MM backend 'hessian_ff' is CPU-only. "
                f"Got device={device!r}. Use mm_device='cpu' or 'auto'."
            )

        self.device = "cpu"
        self.cuda_idx = int(cuda_idx)  # stored only; not used by this class
        self.threads = int(threads)
        wanted_threads = self.threads
        # Note: this changes torch's process-wide intra-op thread count.
        if wanted_threads > 0 and torch.get_num_threads() != wanted_threads:
            torch.set_num_threads(wanted_threads)

        self.system = load_system(parm7, device="cpu").to(dtype=torch.float64)
        self.ff = ForceFieldTorch(self.system)
        self.natom = int(self.system.natom)
        self._coords_dtype = torch.float64
        self._coords_device = torch.device("cpu")
        # Reusable coordinate buffer, overwritten on every evaluation.
        self._coord_buf = torch.empty(
            (self.natom, 3), dtype=self._coords_dtype, device=self._coords_device
        )

        if rst7 is not None:
            initial = load_coords(
                rst7, natom=self.natom, device=self._coords_device, dtype=self._coords_dtype
            )
            self._coord_buf.copy_(initial)

    def _positions_to_tensor(self, positions_ang: np.ndarray) -> torch.Tensor:
        """Copy ``(natom, 3)`` Å positions into the shared buffer and return it."""
        positions = np.asarray(positions_ang, dtype=np.float64)
        if positions.shape != (self.natom, 3):
            raise ValueError(
                f"Coordinate shape mismatch for '{type(self).__name__}': "
                f"got {positions.shape}, expected ({self.natom}, 3)."
            )
        staged = torch.as_tensor(positions, dtype=self._coords_dtype, device=self._coords_device)
        self._coord_buf.copy_(staged)
        return self._coord_buf

    def _energy_forces_from_positions(self, positions_ang: np.ndarray) -> Tuple[float, np.ndarray]:
        """Evaluate energy and forces, both scaled by ``KCALMOL2EV``."""
        coords = self._positions_to_tensor(positions_ang)
        terms, raw_forces = self.ff.energy_force(coords, force_calc_mode="Analytical")
        energy_ev = KCALMOL2EV * float(terms["E_total"].detach().cpu())
        forces_np = raw_forces.detach().cpu().numpy().astype(np.float64, copy=False)
        return energy_ev, forces_np * KCALMOL2EV

    def calculate(self, atoms: Atoms = None, properties=None, system_changes=all_changes):
        super().calculate(atoms, properties, system_changes)
        if atoms is None:
            raise ValueError("ASE Atoms is required for MM evaluation.")
        e_ev, f_ev = self._energy_forces_from_positions(atoms.get_positions())
        self.results = {"energy": e_ev, "forces": f_ev}

    def analytical_hessian(
        self,
        atoms: Atoms,
        *,
        info_path: Optional[str] = None,
        dtype: np.dtype = np.float64,
        return_partial_hessian: bool = False,
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Analytical Hessian over the atoms not fixed by ``FixAtoms``.

        Returns ``(H_sub, active_atoms)`` when ``return_partial_hessian`` is
        true, otherwise ``(H_full, None)`` with zeros for frozen atoms.
        Optionally writes a short header to ``info_path``.
        """
        n_total = len(atoms)
        frozen = _fixed_indices_from_constraints(atoms)
        active_atoms = np.asarray([i for i in range(n_total) if i not in frozen], dtype=int)

        if active_atoms.size == 0:
            if return_partial_hessian:
                return np.zeros((0, 0), dtype=dtype), active_atoms
            return np.zeros((3 * n_total, 3 * n_total), dtype=dtype), None

        if info_path is not None:
            parent = os.path.dirname(info_path)
            if parent:
                os.makedirs(parent, exist_ok=True)
            with open(info_path, "w", encoding="utf-8") as log:
                log.writelines(
                    [
                        "Analytical Hessian (hessian_ff)\n",
                        "--------------------------------\n",
                        f"n_active_atoms = {active_atoms.size}\n",
                    ]
                )
                log.flush()

        coords = self._positions_to_tensor(atoms.get_positions())
        hess_local, _ = build_analytical_hessian(
            system=self.system,
            coords=coords,
            active_atoms=active_atoms.tolist(),
        )
        h_sub = hess_local.detach().cpu().numpy().astype(np.float64, copy=False) * KCALMOL2EV
        h_sub = np.asarray(h_sub, dtype=dtype)

        if return_partial_hessian:
            return h_sub, active_atoms
        return _expand_partial_hessian(h_sub, active_atoms, n_total, dtype=dtype), None

    def finite_difference_hessian(
        self,
        atoms: Atoms,
        *,
        delta: float = 1e-3,
        info_path: Optional[str] = None,
        dtype: np.dtype = np.float64,
        return_partial_hessian: bool = False,
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Backward-compatible alias of :meth:`analytical_hessian`.

        ``delta`` is only checked to be float-convertible; it does not
        affect the result.
        """
        float(delta)  # Kept for backward-compatible signature.
        return self.analytical_hessian(
            atoms,
            info_path=info_path,
            dtype=dtype,
            return_partial_hessian=return_partial_hessian,
        )
|
|
778
|
+
|
|
779
|
+
|
|
780
|
+
|
|
781
|
+
# ======================================================================
|
|
782
|
+
# OpenMM Calculator
|
|
783
|
+
# ======================================================================
|
|
784
|
+
class OpenMMCalculator(Calculator):
    """
    ASE Calculator wrapper for OpenMM backend (finite-difference Hessian).

    This calculator uses OpenMM for MM force field evaluation and supports
    CUDA/CPU platforms. Unlike hessianffCalculator, it computes Hessians
    via numerical finite differences.

    Parameters
    ----------
    parm7 : str
        Path to Amber parm7 topology file.
    rst7 : str
        Path to Amber rst7 coordinate file.
    device : str, default "auto"
        Platform selection: "auto", "cuda", or "cpu" (case-insensitive).
        "auto" picks CUDA when ``torch.cuda.is_available()``. Any other
        value raises ``ValueError``.
    cuda_idx : int, default 0
        CUDA device index when device="cuda".
    threads : int, default 16
        Number of CPU threads when device="cpu".
    """

    implemented_properties = ["energy", "forces"]

    def __init__(
        self,
        parm7: str,
        rst7: str,
        *,
        device: str = "auto",
        cuda_idx: int = 0,
        threads: int = 16,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if not HAS_OPENMM:
            raise ImportError(
                "OpenMM is required for OpenMMCalculator. "
                "Install with: conda install -c conda-forge openmm"
            )

        # Normalize so "CUDA"/" cpu " behave like their lowercase forms
        # instead of silently falling through to the CPU platform.
        requested = str(device).strip().lower()
        if requested == "auto":
            requested = "cuda" if torch.cuda.is_available() else "cpu"

        # Platform selection
        if requested == "cuda":
            platform = Platform.getPlatformByName("CUDA")
            properties = {
                "CudaDeviceIndex": str(cuda_idx),
                "CudaPrecision": "double",
                "DeterministicForces": "true",
                "CudaUseBlockingSync": "true",
            }
        elif requested == "cpu":
            platform = Platform.getPlatformByName("CPU")
            properties = {"Threads": str(threads)}
        else:
            raise ValueError(
                f"Unknown OpenMM device {device!r}. Use 'auto', 'cuda', or 'cpu'."
            )

        # eV-based output units, built once instead of on every calculate().
        ev_unit = unit.Unit(
            {ScaledUnit(1.602176634e-19, joule, "electron volt", "eV"): 1.0}
        )
        self._energy_unit = ev_unit / unit.item
        self._force_unit = ev_unit / unit.angstrom / unit.item

        # Load Amber topology and coordinates
        self.prmtop = app.AmberPrmtopFile(parm7)
        inpcrd = app.AmberInpcrdFile(rst7)

        # Create OpenMM system and context
        self.system = self.prmtop.createSystem(
            nonbondedMethod=app.NoCutoff,
            rigidWater=False
        )
        self.integrator = mm.VerletIntegrator(0 * unit.femtoseconds)
        self.context = mm.Context(self.system, self.integrator, platform, properties)
        self.context.setPositions(inpcrd.positions)

    def calculate(self, atoms: Atoms = None, properties=None, system_changes=all_changes):
        """Compute energy (eV) and forces (eV/Å) for the given atoms."""
        super().calculate(atoms, properties, system_changes)
        if atoms is None:
            raise ValueError("ASE Atoms is required for MM evaluation.")

        # Update positions and get state
        self.context.setPositions(atoms.get_positions() * unit.angstrom)
        state = self.context.getState(getEnergy=True, getForces=True)

        # Per-molecule quantities: dividing by unit.item removes the per-mole factor.
        energy = state.getPotentialEnergy().value_in_unit(self._energy_unit)
        forces = state.getForces(asNumpy=True).value_in_unit(self._force_unit)

        self.results = {"energy": energy, "forces": forces}

    def finite_difference_hessian(
        self,
        atoms: Atoms,
        *,
        delta: float = 0.01,
        info_path: Optional[str] = None,
        dtype: np.dtype = np.float64,
        return_partial_hessian: bool = False,
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Compute Hessian via finite differences using hessian_calc utility.

        Parameters
        ----------
        atoms : Atoms
            Structure to differentiate.
        delta : float, default 0.01
            Displacement size in Angstrom.
        info_path : str | None
            Progress log file path.
        dtype : numpy dtype, default float64
            Data type for the Hessian matrix.
        return_partial_hessian : bool, default False
            If True, return only the active sub-Hessian and active atom indices.

        Returns
        -------
        H : ndarray
            Full (3N, 3N) Hessian in eV/Å², or the (3K, 3K) sub-Hessian of the
            K atoms not fixed by ``FixAtoms`` when ``return_partial_hessian``.
        active_atoms : ndarray of int | None
            Active atom indices (only if return_partial_hessian=True); an empty
            int array when every atom is frozen.
        """
        from .hessian_calc import hessian_calc

        H_full = hessian_calc(atoms, self, delta=delta, info_path=info_path, dtype=dtype)
        if not return_partial_hessian:
            return H_full, None

        fixed = _fixed_indices_from_constraints(atoms)
        # dtype=int keeps the all-frozen (empty) case indexable and matches
        # the hessian_ff convention.
        active_atoms = np.asarray(
            [i for i in range(len(atoms)) if i not in fixed], dtype=int
        )
        # Sorted Cartesian DOF indices (3a, 3a+1, 3a+2) of the active atoms.
        idx3 = (3 * active_atoms[:, None] + np.arange(3)).ravel()
        H_sub = H_full[np.ix_(idx3, idx3)]
        return H_sub, active_atoms
|
|
920
|
+
|
|
921
|
+
|
|
922
|
+
# ======================================================================
|
|
923
|
+
# ML/MM Core (Multi-Backend)
|
|
924
|
+
# ======================================================================
|
|
925
|
+
|
|
926
|
+
@dataclass(frozen=True)
|
|
927
|
+
class _MLHighOut:
|
|
928
|
+
E: float
|
|
929
|
+
F: np.ndarray
|
|
930
|
+
H: Optional[torch.Tensor]
|
|
931
|
+
timing: Dict[str, float | str]
|
|
932
|
+
|
|
933
|
+
|
|
934
|
+
@dataclass(frozen=True)
|
|
935
|
+
class _MMLowOut:
|
|
936
|
+
E_real: float
|
|
937
|
+
F_real: np.ndarray
|
|
938
|
+
E_model: float
|
|
939
|
+
F_model: np.ndarray
|
|
940
|
+
H_real: Optional[np.ndarray]
|
|
941
|
+
H_model: Optional[np.ndarray]
|
|
942
|
+
active_atoms_from_fd: Optional[np.ndarray]
|
|
943
|
+
timing: Dict[str, float | str]
|
|
944
|
+
|
|
945
|
+
|
|
946
|
+
class MLMMCore:
|
|
947
|
+
"""ONIOM-like ML/MM engine supporting multiple MLIP backends.
|
|
948
|
+
|
|
949
|
+
Supported ML backends: UMA (default), ORB, MACE, AIMNet2.
|
|
950
|
+
Supported MM backends: hessian_ff (analytical), OpenMM (FD).
|
|
951
|
+
Optional xTB point-charge embedding correction for environmental effects.
|
|
952
|
+
"""
|
|
953
|
+
|
|
954
|
+
def __init__(
|
|
955
|
+
self,
|
|
956
|
+
*,
|
|
957
|
+
input_pdb: str = None,
|
|
958
|
+
real_parm7: str = None,
|
|
959
|
+
model_pdb: str = None,
|
|
960
|
+
model_charge: Optional[int] = 0,
|
|
961
|
+
model_mult: int = 1,
|
|
962
|
+
link_mlmm: List[Tuple[str, str]] | None = None,
|
|
963
|
+
# ML backend selection
|
|
964
|
+
backend: str = "uma",
|
|
965
|
+
uma_model: str = "uma-s-1p1",
|
|
966
|
+
uma_task_name: str = "omol",
|
|
967
|
+
orb_model: str = "orb_v3_conservative_omol",
|
|
968
|
+
orb_precision: str = "float32",
|
|
969
|
+
mace_model: str = "MACE-OMOL-0",
|
|
970
|
+
mace_dtype: str = "float64",
|
|
971
|
+
aimnet2_model: str = "aimnet2",
|
|
972
|
+
# MM settings
|
|
973
|
+
mm_fd: bool = True,
|
|
974
|
+
mm_fd_dir: Optional[str] = None,
|
|
975
|
+
mm_fd_delta: float = 1e-3,
|
|
976
|
+
symmetrize_hessian: bool = True,
|
|
977
|
+
print_timing: bool = True,
|
|
978
|
+
print_vram: bool = True,
|
|
979
|
+
H_double: bool = False,
|
|
980
|
+
ml_device: str = "auto",
|
|
981
|
+
ml_cuda_idx: int = 0,
|
|
982
|
+
mm_backend: str = "hessian_ff",
|
|
983
|
+
mm_device: str = "cpu",
|
|
984
|
+
mm_cuda_idx: int = 0,
|
|
985
|
+
mm_threads: int = 16,
|
|
986
|
+
freeze_atoms: List[int] | None = None,
|
|
987
|
+
ml_hessian_mode: str = "FiniteDifference",
|
|
988
|
+
hessian_calc_mode: Optional[str] = None,
|
|
989
|
+
return_partial_hessian: bool = True,
|
|
990
|
+
hess_cutoff: Optional[float] = None,
|
|
991
|
+
movable_cutoff: Optional[float] = None,
|
|
992
|
+
use_bfactor_layers: bool = False,
|
|
993
|
+
hess_mm_atoms: Optional[List[int]] = None,
|
|
994
|
+
movable_mm_atoms: Optional[List[int]] = None,
|
|
995
|
+
frozen_mm_atoms: Optional[List[int]] = None,
|
|
996
|
+
# Point-charge embedding correction
|
|
997
|
+
embedcharge: bool = False,
|
|
998
|
+
embedcharge_step: float = 1.0e-3,
|
|
999
|
+
embedcharge_cutoff: Optional[float] = None,
|
|
1000
|
+
xtb_cmd: str = "xtb",
|
|
1001
|
+
xtb_acc: float = 0.2,
|
|
1002
|
+
xtb_workdir: str = "tmp",
|
|
1003
|
+
xtb_keep_files: bool = False,
|
|
1004
|
+
xtb_ncores: int = 4,
|
|
1005
|
+
**kwargs,
|
|
1006
|
+
):
|
|
1007
|
+
# --- v0.1.x backward compatibility aliases ---
|
|
1008
|
+
if "real_pdb" in kwargs:
|
|
1009
|
+
warnings.warn("'real_pdb' is deprecated; use 'input_pdb'.", DeprecationWarning, stacklevel=2)
|
|
1010
|
+
if input_pdb is None:
|
|
1011
|
+
input_pdb = kwargs.pop("real_pdb")
|
|
1012
|
+
else:
|
|
1013
|
+
kwargs.pop("real_pdb")
|
|
1014
|
+
for _old_name in ("real_rst7", "vib_run", "vib_dir"):
|
|
1015
|
+
if _old_name in kwargs:
|
|
1016
|
+
warnings.warn(f"'{_old_name}' is no longer used and will be ignored.", DeprecationWarning, stacklevel=2)
|
|
1017
|
+
kwargs.pop(_old_name)
|
|
1018
|
+
if kwargs:
|
|
1019
|
+
raise TypeError(f"MLMMCore.__init__() got unexpected keyword arguments: {', '.join(kwargs)}")
|
|
1020
|
+
if input_pdb is None:
|
|
1021
|
+
raise TypeError("MLMMCore.__init__() missing required keyword argument: 'input_pdb'")
|
|
1022
|
+
|
|
1023
|
+
self._tmpdir_obj = tempfile.TemporaryDirectory()
|
|
1024
|
+
self.tmpdir: str = self._tmpdir_obj.name
|
|
1025
|
+
for src, dst in [(input_pdb, "input.pdb"), (real_parm7, "real.parm7"), (model_pdb, "model.pdb")]:
|
|
1026
|
+
shutil.copy(src, os.path.join(self.tmpdir, dst))
|
|
1027
|
+
|
|
1028
|
+
self.input_pdb = os.path.join(self.tmpdir, "input.pdb")
|
|
1029
|
+
self.real_parm7 = os.path.join(self.tmpdir, "real.parm7")
|
|
1030
|
+
self.real_rst7 = os.path.join(self.tmpdir, "real.rst7")
|
|
1031
|
+
self.model_pdb = os.path.join(self.tmpdir, "model.pdb")
|
|
1032
|
+
self.model_parm7 = os.path.join(self.tmpdir, "model.parm7")
|
|
1033
|
+
self.model_rst7 = os.path.join(self.tmpdir, "model.rst7")
|
|
1034
|
+
|
|
1035
|
+
real_top = pmd.load_file(self.real_parm7)
|
|
1036
|
+
start_struct = pmd.load_file(self.input_pdb)
|
|
1037
|
+
real_n_atoms = int(len(real_top.atoms))
|
|
1038
|
+
start_n_atoms = int(len(start_struct.atoms))
|
|
1039
|
+
if start_n_atoms != real_n_atoms:
|
|
1040
|
+
raise ValueError(
|
|
1041
|
+
"Atom-count mismatch between input structure and real topology: "
|
|
1042
|
+
f"input_pdb='{input_pdb}' has {start_n_atoms} atoms, "
|
|
1043
|
+
f"real_parm7='{real_parm7}' expects {real_n_atoms} atoms. "
|
|
1044
|
+
"Provide a full-system input structure consistent with the parm7."
|
|
1045
|
+
)
|
|
1046
|
+
real_top.coordinates = start_struct.coordinates
|
|
1047
|
+
real_top.box = None
|
|
1048
|
+
real_top.save(self.real_parm7, overwrite=True)
|
|
1049
|
+
real_top.save(self.real_rst7, overwrite=True)
|
|
1050
|
+
|
|
1051
|
+
self.link_mlmm = link_mlmm
|
|
1052
|
+
self.ml_ID, self.mlmm_links = self._ml_prep()
|
|
1053
|
+
self.selection_indices = self._mk_model_parm7()
|
|
1054
|
+
|
|
1055
|
+
self.hess_cutoff = hess_cutoff
|
|
1056
|
+
self.movable_cutoff = movable_cutoff
|
|
1057
|
+
self.use_bfactor_layers = use_bfactor_layers
|
|
1058
|
+
self._original_input_pdb = input_pdb
|
|
1059
|
+
self._explicit_hess_mm_atoms = hess_mm_atoms
|
|
1060
|
+
self._explicit_movable_mm_atoms = movable_mm_atoms
|
|
1061
|
+
self._explicit_frozen_mm_atoms = frozen_mm_atoms
|
|
1062
|
+
self._compute_layer_indices(real_top.coordinates)
|
|
1063
|
+
|
|
1064
|
+
self.freeze_atoms = [] if freeze_atoms is None else list(freeze_atoms)
|
|
1065
|
+
if self.frozen_layer_indices:
|
|
1066
|
+
self.freeze_atoms = sorted(set(self.freeze_atoms) | set(self.frozen_layer_indices))
|
|
1067
|
+
|
|
1068
|
+
hess_set = set(self.hess_indices)
|
|
1069
|
+
all_atoms = set(range(len(real_top.atoms)))
|
|
1070
|
+
self.hess_freeze_atoms = sorted(all_atoms - hess_set)
|
|
1071
|
+
|
|
1072
|
+
self.return_partial_hessian = bool(return_partial_hessian)
|
|
1073
|
+
|
|
1074
|
+
self._n_real = len(real_top.atoms)
|
|
1075
|
+
self._idx_map_real_to_model = {idx: pos for pos, idx in enumerate(self.selection_indices)}
|
|
1076
|
+
self._update_active_dof_mappings()
|
|
1077
|
+
|
|
1078
|
+
self.H_double = bool(H_double)
|
|
1079
|
+
self.H_dtype = torch.float64 if self.H_double else torch.float32
|
|
1080
|
+
self.H_np_dtype = np.float64 if self.H_double else np.float32
|
|
1081
|
+
|
|
1082
|
+
self.mm_fd = mm_fd
|
|
1083
|
+
self.mm_fd_dir = mm_fd_dir
|
|
1084
|
+
self.mm_fd_delta = mm_fd_delta
|
|
1085
|
+
self.symmetrize_hessian = symmetrize_hessian
|
|
1086
|
+
self.print_timing = bool(print_timing)
|
|
1087
|
+
self.print_vram = bool(print_vram)
|
|
1088
|
+
if self.mm_fd_dir and not os.path.exists(self.mm_fd_dir):
|
|
1089
|
+
os.makedirs(self.mm_fd_dir, exist_ok=True)
|
|
1090
|
+
|
|
1091
|
+
if ml_device == "auto":
|
|
1092
|
+
ml_device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
1093
|
+
self.device_str = ml_device
|
|
1094
|
+
self.ml_device = torch.device(f"cuda:{ml_cuda_idx}" if ml_device == "cuda" else "cpu")
|
|
1095
|
+
|
|
1096
|
+
self.model_charge = int(0 if model_charge is None else model_charge)
|
|
1097
|
+
self.model_mult = int(model_mult)
|
|
1098
|
+
self.backend_name = str(backend).strip().lower() if backend is not None else "uma"
|
|
1099
|
+
|
|
1100
|
+
# Create ML backend via factory
|
|
1101
|
+
self._ml_backend = _create_ml_backend(
|
|
1102
|
+
self.backend_name,
|
|
1103
|
+
uma_model=uma_model,
|
|
1104
|
+
uma_task_name=uma_task_name,
|
|
1105
|
+
orb_model=orb_model,
|
|
1106
|
+
mace_model=mace_model,
|
|
1107
|
+
mace_dtype=mace_dtype,
|
|
1108
|
+
aimnet2_model=aimnet2_model,
|
|
1109
|
+
model_charge=self.model_charge,
|
|
1110
|
+
model_mult=self.model_mult,
|
|
1111
|
+
ml_device=self.ml_device,
|
|
1112
|
+
)
|
|
1113
|
+
|
|
1114
|
+
# Point-charge embedding correction
|
|
1115
|
+
self.embedcharge = bool(embedcharge)
|
|
1116
|
+
self.embedcharge_cutoff = embedcharge_cutoff
|
|
1117
|
+
self._embed_correction: Optional[_EmbedChargeCorrection] = None
|
|
1118
|
+
if self.embedcharge:
|
|
1119
|
+
self._embed_correction = _EmbedChargeCorrection(
|
|
1120
|
+
xtb_cmd=xtb_cmd,
|
|
1121
|
+
xtb_acc=xtb_acc,
|
|
1122
|
+
xtb_workdir=xtb_workdir,
|
|
1123
|
+
xtb_keep_files=xtb_keep_files,
|
|
1124
|
+
xtb_ncores=xtb_ncores,
|
|
1125
|
+
hessian_step=embedcharge_step,
|
|
1126
|
+
)
|
|
1127
|
+
|
|
1128
|
+
# MM backend selection: hessian_ff or openmm
|
|
1129
|
+
self.mm_backend = str(mm_backend).strip().lower()
|
|
1130
|
+
if self.mm_backend == "openmm":
|
|
1131
|
+
self.calc_real_low = OpenMMCalculator(
|
|
1132
|
+
parm7=self.real_parm7, rst7=self.real_rst7,
|
|
1133
|
+
device=mm_device, cuda_idx=mm_cuda_idx, threads=mm_threads
|
|
1134
|
+
)
|
|
1135
|
+
self.calc_model_low = OpenMMCalculator(
|
|
1136
|
+
parm7=self.model_parm7, rst7=self.model_rst7,
|
|
1137
|
+
device=mm_device, cuda_idx=mm_cuda_idx, threads=mm_threads
|
|
1138
|
+
)
|
|
1139
|
+
elif self.mm_backend == "hessian_ff":
|
|
1140
|
+
self.calc_real_low = hessianffCalculator(
|
|
1141
|
+
parm7=self.real_parm7, rst7=None,
|
|
1142
|
+
device=mm_device, cuda_idx=mm_cuda_idx, threads=mm_threads
|
|
1143
|
+
)
|
|
1144
|
+
self.calc_model_low = hessianffCalculator(
|
|
1145
|
+
parm7=self.model_parm7, rst7=None,
|
|
1146
|
+
device=mm_device, cuda_idx=mm_cuda_idx, threads=mm_threads
|
|
1147
|
+
)
|
|
1148
|
+
else:
|
|
1149
|
+
raise ValueError(
|
|
1150
|
+
f"Unknown mm_backend '{mm_backend}'. Choose 'hessian_ff' or 'openmm'."
|
|
1151
|
+
)
|
|
1152
|
+
|
|
1153
|
+
mode_in = hessian_calc_mode if hessian_calc_mode is not None else ml_hessian_mode
|
|
1154
|
+
mode = (mode_in or "Analytical").strip().lower()
|
|
1155
|
+
self._ml_hessian_mode = "analytical" if mode.startswith("analyt") else "fd"
|
|
1156
|
+
|
|
1157
|
+
self._atoms_real_tpl = read(self.input_pdb)
|
|
1158
|
+
self._atoms_model_tpl = read(self.model_pdb)
|
|
1159
|
+
tmp = self._atoms_model_tpl.copy()
|
|
1160
|
+
for _ in self.mlmm_links:
|
|
1161
|
+
tmp += Atoms("H", positions=[[0.0, 0.0, 0.0]])
|
|
1162
|
+
self._atoms_model_LH_tpl = tmp
|
|
1163
|
+
|
|
1164
|
+
def cleanup(self):
    """Clean up the temporary directory held in ``_tmpdir_obj``, if any.

    Safe to call repeatedly (e.g. explicitly and again from ``__del__``):
    the handle is dropped after the first attempt, so later calls are
    no-ops. A failure during cleanup is logged at debug level, not raised.
    """
    tmpdir_obj = getattr(self, "_tmpdir_obj", None)
    if tmpdir_obj is None:
        return
    try:
        tmpdir_obj.cleanup()
    except Exception:
        logger.debug("Failed to clean up tmpdir", exc_info=True)
    # Forget the handle so a second call does not try to clean up again.
    self._tmpdir_obj = None
|
|
1171
|
+
|
|
1172
|
+
def __del__(self):
    """Best-effort finalizer: release the temporary directory via ``cleanup()``.

    Exceptions cannot propagate out of ``__del__``; Python would only print
    them as "Exception ignored in ..." (typically at interpreter shutdown,
    when module globals used by ``cleanup`` may already be gone), so they
    are swallowed here on purpose.
    """
    try:
        self.cleanup()
    except Exception:
        pass
|
|
1174
|
+
|
|
1175
|
+
@staticmethod
def _pdb_atom_key(line: str) -> str:
    """Return the ``"<atom name> <residue name> <residue number>"`` key of a PDB record.

    The three fields are read from the fixed PDB columns 13-16 (atom name),
    18-20 (residue name) and 23-26 (residue sequence number), 1-based, and
    each is stripped of surrounding blanks. Chain ID (column 22) and insertion
    code (column 27) are not part of the key, so atoms that differ only in
    those fields produce the same key.
    """
    atom_name = line[12:16]
    residue_name = line[17:20]
    residue_number = line[22:26]
    return " ".join(field.strip() for field in (atom_name, residue_name, residue_number))
|
|
1178
|
+
|
|
1179
|
+
def _ml_prep(self) -> Tuple[List[str], List[Tuple[int, int]]]:
    """Locate the ML region inside the full system and determine ML/MM link pairs.

    Atoms of ``self.model_pdb`` (the ML region) are matched against
    ``self.input_pdb`` (the full system) by ``self._pdb_atom_key``. Only
    ``ATOM``/``HETATM`` records are read.

    Link pairs come from one of two sources:

    * ``self.link_mlmm`` (explicit): ``(ml_spec, mm_spec)`` string pairs, of
      which only the first three whitespace-separated tokens are used, so
      they must read like the atom key (``"<atom> <resname> <resnum>"``).
      Each pair is resolved on its own, so every ML atom is linked to the MM
      atom it was listed with, regardless of file order.
    * automatic detection (``self.link_mlmm`` empty/None): every non-ML
      atom closer than 1.7 to an ML atom (PDB coordinate units) where the
      (MM, ML) element pair is C-C, N-C or C-N.

    Returns:
        ml_ID: PDB serial numbers (as strings) of the ML atoms, in
            ``input_pdb`` order.
        mlmm_links: ``(ml_serial, mm_serial)`` integer pairs.

    Raises:
        ValueError: if an explicit link spec matches no ML atom or a
            different number of ML and MM atoms, or if any ML or MM atom
            ends up in more than one link pair.
    """
    ml_region = set()
    with open(self.model_pdb) as fh:
        for ln in fh:
            if ln.startswith(("ATOM", "HETATM")):
                ml_region.add(self._pdb_atom_key(ln))

    leap_atoms: List[Dict] = []
    with open(self.input_pdb) as fh:
        for ln in fh:
            if not ln.startswith(("ATOM", "HETATM")):
                continue
            leap_atoms.append(
                {
                    "idx": int(ln[6:11]),
                    "id": self._pdb_atom_key(ln),
                    "elem": ln[76:78].strip(),
                    "coord": np.array([float(ln[30:38]), float(ln[38:46]), float(ln[46:54])]),
                }
            )

    ml_ID = [str(a["idx"]) for a in leap_atoms if a["id"] in ml_region]

    ml_indices: List[int] = []
    mm_indices: List[int] = []
    if self.link_mlmm:
        # All serials carrying each atom key, in file order.
        serials_by_key: Dict[str, List[int]] = {}
        for a in leap_atoms:
            serials_by_key.setdefault(a["id"], []).append(a["idx"])

        for q_spec, m_spec in self.link_mlmm:
            q_hits = serials_by_key.get(" ".join(q_spec.split()[:3]), [])
            m_hits = serials_by_key.get(" ".join(m_spec.split()[:3]), [])
            # Pair within one spec only; equal multiplicities (e.g. identical
            # chains sharing a key) are paired in file order.
            if not q_hits or len(q_hits) != len(m_hits):
                raise ValueError(
                    f"Link specification ({q_spec!r}, {m_spec!r}) matched "
                    f"{len(q_hits)} ML atom(s) and {len(m_hits)} MM atom(s) in "
                    "input_pdb; expected the same non-zero number."
                )
            ml_indices.extend(q_hits)
            mm_indices.extend(m_hits)

        if len(set(ml_indices)) != len(ml_indices) or len(set(mm_indices)) != len(mm_indices):
            raise ValueError("Duplicated ML or MM indices in link specification.")
        # Order pairs by ML serial so the link order follows the PDB numbering.
        mlmm_links = sorted(zip(ml_indices, mm_indices))
    else:
        threshold = 1.7
        # (MM element, ML element) combinations accepted as a link bond.
        linkable = {("C", "C"), ("N", "C"), ("C", "N")}
        ml_set = {a["idx"] for a in leap_atoms if a["id"] in ml_region}
        coords = {a["idx"]: a["coord"] for a in leap_atoms}
        elem = {a["idx"]: a["elem"] for a in leap_atoms}

        for qidx in ml_set:
            for a in leap_atoms:
                midx = a["idx"]
                if midx in ml_set:
                    continue
                if (
                    np.linalg.norm(coords[midx] - coords[qidx]) < threshold
                    and (elem[midx], elem[qidx]) in linkable
                ):
                    ml_indices.append(qidx)
                    mm_indices.append(midx)

        if len(set(ml_indices)) != len(ml_indices) or len(set(mm_indices)) != len(mm_indices):
            raise ValueError(
                "Automatic link detection produced duplicate pairs. Specify 'link_mlmm' manually."
            )
        mlmm_links = list(zip(ml_indices, mm_indices))

    return ml_ID, mlmm_links
|
|
1248
|
+
|
|
1249
|
+
def _mk_model_parm7(self) -> List[int]:
    """Write the ML-region topology/coordinates (``model_parm7`` / ``model_rst7``).

    The real system is loaded with ``pmd.load_file`` from ``real_parm7`` and
    ``real_rst7``, and its box is removed. The atoms whose 1-based serials are
    listed in ``self.ml_ID`` are then selected.

    * If that selection covers every atom, the real input files are copied
      verbatim to the model paths.
    * Otherwise the sub-structure is saved (box removed), its prmtop LJ tables
      are normalised via ``_normalize_prmtop_lj_tables``, and the coordinates
      are saved.

    Returns:
        The ``idx`` of each selected atom (presumably ``int(serial) - 1``).
    """
    real = pmd.load_file(self.real_parm7, self.real_rst7)
    real.box = None
    selection = [real.atoms[int(serial) - 1].idx for serial in self.ml_ID]

    if len(selection) != len(real.atoms):
        model = real[selection]
        model.box = None
        model.save(self.model_parm7, overwrite=True)
        _normalize_prmtop_lj_tables(self.model_parm7)
        model.save(self.model_rst7, overwrite=True)
        return selection

    # The ML region is the whole system: reuse the input files unchanged.
    shutil.copy(self.real_parm7, self.model_parm7)
    shutil.copy(self.real_rst7, self.model_rst7)
    return selection
|
|
1266
|
+
|
|
1267
|
+
def _compute_layer_indices(self, coords: np.ndarray) -> None:
    """Partition atoms into ML / Hessian-target MM / movable MM / frozen MM layers.

    The method sets the following attributes, each a sorted list of 0-based
    atom indices:

    * ``ml_indices``
    * ``hess_mm_indices``, ``movable_mm_indices``, ``frozen_layer_indices``
    * ``hess_indices`` (ML + Hessian-target MM)
    * ``movable_indices`` (ML + Hessian-target MM + movable MM)

    The first applicable rule is used:

    1. Explicit lists (``_explicit_hess_mm_atoms`` /
       ``_explicit_movable_mm_atoms`` / ``_explicit_frozen_mm_atoms``) are
       taken as given. Any MM atom left unassigned becomes movable.
    2. B-factor layers apply when ``use_bfactor_layers`` is set and the
       original input PDB has valid layer B-factors. Unassigned MM atoms
       become movable. Hessian targets are chosen from the movable pool by
       ``hess_cutoff`` if it is set, otherwise taken from the Hessian layer.
    3. Otherwise, distance to the nearest ML atom decides:
       ``d <= hess_cutoff`` gives a Hessian target, ``d <= movable_cutoff``
       gives movable, anything farther is frozen.
       * With neither cutoff set, every MM atom is a Hessian target.
       * With only ``movable_cutoff`` set, every MM atom inside it is a
         Hessian target and the rest are frozen.

    Args:
        coords: per-atom positions indexed by atom (presumably
            ``(n_atoms, 3)``, same units as the cutoffs).

    Raises:
        ValueError: if an explicit index is out of range or belongs to the
            ML region.
    """
    self.ml_indices = sorted(self.selection_indices)

    n_atoms = int(coords.shape[0])
    ml_set = set(self.ml_indices)
    mm_indices = set(range(n_atoms)) - ml_set

    def min_dist_to_ml(atom_idx: int, ml_coords: np.ndarray) -> float:
        """Distance from atom ``atom_idx`` to its nearest ML atom."""
        dists = np.linalg.norm(ml_coords - coords[atom_idx], axis=1)
        return float(np.min(dists))

    def store(hess_mm, movable_mm, frozen_mm) -> None:
        """Write the MM layer assignment and the derived unions onto ``self``."""
        self.hess_mm_indices = sorted(hess_mm)
        self.movable_mm_indices = sorted(movable_mm)
        self.frozen_layer_indices = sorted(frozen_mm)
        self.hess_indices = sorted(self.ml_indices + self.hess_mm_indices)
        self.movable_indices = sorted(
            self.ml_indices + self.hess_mm_indices + self.movable_mm_indices
        )

    # --- Rule 1: explicit per-atom assignment ---------------------------------
    has_explicit = (
        self._explicit_hess_mm_atoms is not None
        or self._explicit_movable_mm_atoms is not None
        or self._explicit_frozen_mm_atoms is not None
    )
    if has_explicit:
        explicit_hess = set(self._explicit_hess_mm_atoms or [])
        explicit_movable = set(self._explicit_movable_mm_atoms or [])
        explicit_frozen = set(self._explicit_frozen_mm_atoms or [])

        for idx_set, name in [
            (explicit_hess, "hess_mm_atoms"),
            (explicit_movable, "movable_mm_atoms"),
            (explicit_frozen, "frozen_mm_atoms"),
        ]:
            for idx in idx_set:
                if idx < 0 or idx >= n_atoms:
                    raise ValueError(f"Invalid atom index {idx} in {name}: must be 0 <= idx < {n_atoms}")
                if idx in ml_set:
                    raise ValueError(f"Atom index {idx} in {name} is also in ML region (model_pdb)")

        unassigned_mm = mm_indices - (explicit_hess | explicit_movable | explicit_frozen)
        store(
            explicit_hess & mm_indices,
            (explicit_movable & mm_indices) | unassigned_mm,
            explicit_frozen & mm_indices,
        )
        return

    # --- Rule 2: layers encoded in the input PDB B-factors --------------------
    if self.use_bfactor_layers:
        from .utils import read_bfactors_from_pdb, parse_layer_indices_from_bfactors, has_valid_layer_bfactors
        from pathlib import Path

        bfactors = read_bfactors_from_pdb(Path(self._original_input_pdb))
        if has_valid_layer_bfactors(bfactors):
            layer_info = parse_layer_indices_from_bfactors(bfactors)

            movable_from_layer = set(layer_info["movable_mm_indices"]) & mm_indices
            frozen_from_layer = set(layer_info["frozen_indices"]) & mm_indices
            hess_from_layer = set(layer_info["hess_mm_indices"]) & mm_indices

            # Unassigned MM atoms default to movable.
            unassigned_mm = mm_indices - (movable_from_layer | frozen_from_layer | hess_from_layer)
            movable_pool = movable_from_layer | unassigned_mm

            # Hessian targets: distance-filter the movable pool when hess_cutoff
            # is set, otherwise keep the Layer-2 (Hessian) assignment.
            if self.hess_cutoff is not None:
                ml_coords = coords[self.ml_indices]
                hess_cut = float(self.hess_cutoff)
                hess_mm = {idx for idx in movable_pool if min_dist_to_ml(idx, ml_coords) <= hess_cut}
            else:
                hess_mm = set(hess_from_layer)

            store(hess_mm, movable_pool - hess_mm, frozen_from_layer)
            return

    # --- Rule 3: distance cutoffs ---------------------------------------------
    if self.hess_cutoff is None and self.movable_cutoff is None:
        store(mm_indices, [], [])
        return

    ml_coords = coords[self.ml_indices]
    mov_cut = self.movable_cutoff if self.movable_cutoff is not None else float("inf")
    # Hessian targets are a subset of the movable atoms. Without an explicit
    # hess_cutoff every movable MM atom is a Hessian target (as in the
    # no-cutoff case above); falling back to +inf here would instead make
    # every MM atom a Hessian target and silently ignore movable_cutoff.
    hess_cut = self.hess_cutoff if self.hess_cutoff is not None else mov_cut

    hess_mm: List[int] = []
    movable_mm: List[int] = []
    frozen_mm: List[int] = []
    for idx in mm_indices:
        d = min_dist_to_ml(idx, ml_coords)
        if d <= hess_cut:
            hess_mm.append(idx)
        elif d <= mov_cut:
            movable_mm.append(idx)
        else:
            frozen_mm.append(idx)

    store(hess_mm, movable_mm, frozen_mm)
|
|
1387
|
+
|
|
1388
|
+
def _update_active_dof_mappings(self) -> None:
    """Rebuild the index maps between full-system atoms and their active subsets.

    Two independent subsets of ``range(self._n_real)`` are tracked:

    * optimizer-active atoms (not in ``freeze_atoms``): ``active_atoms_real``,
      ``n_active_real``, ``full_to_active_real``, ``active_to_full_real``;
    * Hessian-active atoms (not in ``hess_freeze_atoms``): ``hess_active_atoms``,
      ``n_hess_active``, ``full_to_hess_active``, ``hess_active_to_full``.

    Also refreshes ``ml_hess_active_indices`` (Hessian-active positions of the
    atoms in ``selection_indices``) and ``freeze_model`` (the frozen atoms
    re-indexed into the model system via ``_idx_map_real_to_model``).
    """
    every_atom = range(self._n_real)

    frozen = set(self.freeze_atoms)
    self.active_atoms_real = [atom for atom in every_atom if atom not in frozen]
    self.n_active_real = len(self.active_atoms_real)
    self.full_to_active_real = {atom: pos for pos, atom in enumerate(self.active_atoms_real)}
    self.active_to_full_real = dict(enumerate(self.active_atoms_real))

    hess_frozen = set(self.hess_freeze_atoms)
    self.hess_active_atoms = [atom for atom in every_atom if atom not in hess_frozen]
    self.n_hess_active = len(self.hess_active_atoms)
    self.full_to_hess_active = {atom: pos for pos, atom in enumerate(self.hess_active_atoms)}
    self.hess_active_to_full = dict(enumerate(self.hess_active_atoms))

    to_hess = self.full_to_hess_active
    self.ml_hess_active_indices = [to_hess[atom] for atom in self.selection_indices if atom in to_hess]

    to_model = self._idx_map_real_to_model
    self.freeze_model = [to_model[atom] for atom in self.freeze_atoms if atom in to_model]
|
|
1408
|
+
|
|
1409
|
+
def _build_within_partial_hessian(self) -> Dict[str, np.ndarray | int | str]:
|
|
1410
|
+
"""Build metadata for a partial (Hessian-target-only) Hessian."""
|
|
1411
|
+
n_real = int(self._n_real)
|
|
1412
|
+
active_atoms = np.asarray(self.hess_active_atoms, dtype=int)
|
|
1413
|
+
active_n_atoms = int(active_atoms.size)
|
|
1414
|
+
|
|
1415
|
+
active_dofs = np.empty(active_n_atoms * 3, dtype=int)
|
|
1416
|
+
for i, a in enumerate(active_atoms):
|
|
1417
|
+
base = 3 * int(a)
|
|
1418
|
+
active_dofs[3 * i:3 * i + 3] = (base, base + 1, base + 2)
|
|
1419
|
+
|
|
1420
|
+
full_to_active = -np.ones(n_real, dtype=int)
|
|
1421
|
+
if active_n_atoms:
|
|
1422
|
+
full_to_active[active_atoms] = np.arange(active_n_atoms, dtype=int)
|
|
1423
|
+
|
|
1424
|
+
return {
|
|
1425
|
+
"kind": "hess-target-only",
|
|
1426
|
+
"active_atoms": active_atoms,
|
|
1427
|
+
"active_dofs": active_dofs,
|
|
1428
|
+
"active_to_full": active_atoms.copy(),
|
|
1429
|
+
"full_to_active": full_to_active,
|
|
1430
|
+
"full_n_atoms": n_real,
|
|
1431
|
+
"full_n_dof": int(3 * n_real),
|
|
1432
|
+
"active_n_atoms": active_n_atoms,
|
|
1433
|
+
"active_n_dof": int(3 * active_n_atoms),
|
|
1434
|
+
}
|
|
1435
|
+
|
|
1436
|
+
def _prep_3_layer_atoms(self, real_coord: np.ndarray):
    """Build fresh real / model / link-capped-model structures for ``real_coord``.

    * ``atoms_real``: copy of the real-system template at ``real_coord``.
    * ``atoms_model`` / ``atoms_model_LH``: copies of the model templates whose
      first ``len(ml_ID)`` atoms take the current positions of the ML atoms
      (``ml_ID`` holds 1-based serials into the real system).
    * For each ``(ml_serial, mm_serial)`` in ``mlmm_links`` (1-based), the
      capping H at index ``len(model template) + k`` of ``atoms_model_LH`` is
      placed on the ML->MM bond at 1.09 (C parent) or 1.01 (N parent) from the
      ML atom. Links whose two atoms (nearly) coincide are skipped.
    * ``freeze_atoms`` are fixed on ``atoms_real``; those that map into the
      model via ``_idx_map_real_to_model`` are fixed on both model copies.

    Returns:
        ``(atoms_real, atoms_model, atoms_model_LH, added_link_atoms,
        freeze_model)`` where ``added_link_atoms`` holds
        ``(h_index_in_model_LH, ml_index, mm_index, bond_length)`` with
        0-based real-system indices.

    Raises:
        ValueError: if a link's ML parent is neither C nor N.
    """
    # Capping X-H bond length keyed by the element of the ML parent atom.
    link_bond_length = {"C": 1.09, "N": 1.01}

    atoms_real = self._atoms_real_tpl.copy()
    atoms_real.set_positions(real_coord)
    atoms_model = self._atoms_model_tpl.copy()
    atoms_model_LH = self._atoms_model_LH_tpl.copy()

    for model_pos, serial in enumerate(self.ml_ID):
        xyz = atoms_real[int(serial) - 1].position
        atoms_model[model_pos].position = xyz
        atoms_model_LH[model_pos].position = xyz

    n_model = len(self._atoms_model_tpl)
    added_link_atoms = []
    for k, (ml_serial, mm_serial) in enumerate(self.mlmm_links):
        ml_i = ml_serial - 1
        mm_i = mm_serial - 1
        parent = atoms_real[ml_i].symbol
        if parent not in link_bond_length:
            raise ValueError(
                f"Unsupported link parent element: {parent}. Only C and N are supported."
            )
        bond = link_bond_length[parent]

        direction = atoms_real[mm_i].position - atoms_real[ml_i].position
        length = np.linalg.norm(direction)
        if length < 1e-6:
            continue
        h_index = n_model + k
        atoms_model_LH[h_index].position = atoms_real[ml_i].position + (direction / length) * bond
        added_link_atoms.append((h_index, ml_i, mm_i, bond))

    freeze_model: List[int] = []
    if self.freeze_atoms:
        atoms_real.set_constraint(FixAtoms(indices=self.freeze_atoms))
        to_model = self._idx_map_real_to_model
        freeze_model = [to_model[i] for i in self.freeze_atoms if i in to_model]
        if freeze_model:
            for target in (atoms_model, atoms_model_LH):
                target.set_constraint(FixAtoms(indices=freeze_model))

    return atoms_real, atoms_model, atoms_model_LH, added_link_atoms, freeze_model
|
|
1482
|
+
|
|
1483
|
+
@staticmethod
def _jacobian_blocks_numpy(r_ml: np.ndarray, r_mm: np.ndarray, dist: float) -> Optional[np.ndarray]:
    """Jacobian of the link-H position with respect to its two parent atoms.

    The link atom sits at ``L = Q + dist * (M - Q) / |M - Q|`` with Q = ``r_ml``
    and M = ``r_mm``. The result has shape (6, 3): rows 0-2 are dL/dQ and
    rows 3-5 are dL/dM (each transposed), so ``J @ f_L`` gives the [Q, M]
    share of a vector ``f_L`` acting on the link atom.

    Returns ``None`` when the two parents coincide (``|M - Q| < 1e-12``).
    """
    bond = r_mm - r_ml
    length = np.linalg.norm(bond)
    if length < 1e-12:
        return None
    u_hat = bond / length
    # Derivative of the unit bond vector w.r.t. M (and minus that w.r.t. Q).
    proj = (np.eye(3) - np.outer(u_hat, u_hat)) / length
    d_link_d_mm = dist * proj
    d_link_d_ml = np.eye(3) - d_link_d_mm
    return np.vstack([d_link_d_ml.T, d_link_d_mm.T])
|
|
1496
|
+
|
|
1497
|
+
@staticmethod
def _jacobian_blocks_torch(
    r_ml: torch.Tensor,
    r_mm: torch.Tensor,
    dist: float,
    *,
    dtype: torch.dtype,
    device: torch.device,
) -> Optional[torch.Tensor]:
    """Torch version of the link-H Jacobian, returned as K with shape (3, 6).

    For ``L = Q + dist * (M - Q) / |M - Q|`` (Q = ``r_ml``, M = ``r_mm``),
    rows index the components of L and columns are ``[Q_xyz, M_xyz]``, i.e.
    ``K = [dL/dQ | dL/dM]``. The identity is built with the given
    ``dtype``/``device``.

    Returns ``None`` when the two parents coincide (``|M - Q| < 1e-12``).
    """
    bond = r_mm - r_ml
    length = torch.norm(bond)
    if float(length) < 1e-12:
        return None
    u_hat = bond / length
    eye3 = torch.eye(3, dtype=dtype, device=device)
    # Derivative of the unit bond vector w.r.t. M (and minus that w.r.t. Q).
    proj = (eye3 - torch.outer(u_hat, u_hat)) / length
    d_link_d_mm = dist * proj
    d_link_d_ml = eye3 - d_link_d_mm
    return torch.cat([d_link_d_ml, d_link_d_mm], dim=1)
|
|
1517
|
+
|
|
1518
|
+
def _get_mm_charges(self, atom_indices: Sequence[int]) -> np.ndarray:
    """Retrieve MM partial charges for the given atom indices.

    Works with both hessian_ff (AmberSystem) and OpenMM backends:

    * hessian_ff: values of ``calc.system.charge`` at ``atom_indices``,
      gathered in one indexing operation. For a torch tensor this is a single
      device-to-host transfer instead of one ``.item()`` sync per atom.
      Values are returned as stored.
    * OpenMM: charge parameter of the first ``NonbondedForce`` in
      ``calc.system``, converted to elementary charges.

    If neither source is available, a ``RuntimeWarning`` is issued and zeros
    are returned.

    Args:
        atom_indices: 0-based indices into the real (full) system.

    Returns:
        float64 array with one charge per entry of ``atom_indices``.
    """
    calc = self.calc_real_low
    # hessian_ff: AmberSystem with .charge tensor
    if isinstance(calc, hessianffCalculator) and hasattr(calc, "system"):
        idx = [int(i) for i in atom_indices]
        if not idx:
            return np.zeros(0, dtype=np.float64)
        picked = calc.system.charge[idx]
        if hasattr(picked, "detach"):
            # torch tensor (possibly on GPU): one transfer for the whole selection.
            picked = picked.detach().cpu().double().numpy()
        return np.asarray(picked, dtype=np.float64).reshape(-1)
    # OpenMM: extract charges from NonbondedForce
    if isinstance(calc, OpenMMCalculator) and HAS_OPENMM:
        sys_omm = calc.system
        for fi in range(sys_omm.getNumForces()):
            force = sys_omm.getForce(fi)
            if force.__class__.__name__ == "NonbondedForce":
                return np.array(
                    [
                        force.getParticleParameters(i)[0].value_in_unit(unit.elementary_charge)
                        for i in atom_indices
                    ],
                    dtype=np.float64,
                )
    # Fallback: zero charges
    warnings.warn(
        "Could not extract MM charges from the calculator; returning zeros. "
        "Embedcharge correction will have no effect.",
        RuntimeWarning,
    )
    return np.zeros(len(atom_indices), dtype=np.float64)
|
|
1550
|
+
|
|
1551
|
+
def _eval_ml_high(self, atoms_model_LH: Atoms, freeze_model: Sequence[int], *, return_hessian: bool) -> _MLHighOut:
    """Run the high-level (ML) backend on the link-capped model structure.

    Energy and forces come from ``self._ml_backend.eval(..., need_grad=True)``.

    When ``return_hessian`` is set, a Hessian of the same structure is also
    built:

    * analytically, if ``self._ml_hessian_mode == "analytical"`` and the
      backend reports ``supports_analytical_hessian``;
    * otherwise by finite differences (``eps_ang=1.0e-3``, passing
      ``freeze_model``).

    Returns:
        ``_MLHighOut`` with ``E``, ``F``, ``H`` (``None`` unless requested), and
        ``timing`` holding the backend name plus, when a Hessian was built, the
        mode used and its wall time in seconds.
    """
    timing: Dict[str, float | str] = {}
    backend = self._ml_backend

    energy, forces, opaque = backend.eval(atoms_model_LH, need_grad=True)
    timing["ml_backend"] = self.backend_name

    hessian = None
    if return_hessian:
        if self._ml_hessian_mode == "analytical" and backend.supports_analytical_hessian:
            started = time.perf_counter()
            hessian = backend.hessian_analytical(opaque, len(atoms_model_LH), dtype=self.H_dtype)
            timing["ml_hessian_mode"] = "Analytical"
        else:
            started = time.perf_counter()
            hessian = backend.hessian_fd(
                atoms_model_LH,
                freeze_model,
                eps_ang=1.0e-3,
                dtype=self.H_dtype,
                device=self.ml_device,
            )
            timing["ml_hessian_mode"] = "FiniteDifference"
        timing["ml_hessian_s"] = time.perf_counter() - started

    return _MLHighOut(E=energy, F=forces, H=hessian, timing=timing)
|
|
1574
|
+
|
|
1575
|
+
def _eval_mm_low(self, atoms_real: Atoms, atoms_model: Atoms, *, return_hessian: bool) -> _MMLowOut:
    """Evaluate the low-level (MM) calculators on the real and model systems.

    Attaches ``calc_real_low`` / ``calc_model_low`` to the given structures
    and evaluates energy and forces for each; forces are converted with
    ``np.double``.

    When ``return_hessian`` is set and ``self.mm_fd is True``, finite-difference
    Hessians are added:

    * a partial one for the real system, on a copy of ``atoms_real`` whose only
      constraint is ``hess_freeze_atoms``;
    * a full one for the model system.

    Their wall times go into ``timing`` as ``mm_fd_real_s``, ``mm_fd_model_s``
    and ``mm_fd_total_s``. Otherwise the Hessian fields are ``None``.
    """
    timing: Dict[str, float | str] = {}

    atoms_real.calc = self.calc_real_low
    atoms_model.calc = self.calc_model_low

    e_real = atoms_real.get_potential_energy()
    f_real = np.double(atoms_real.get_forces())
    e_model = atoms_model.get_potential_energy()
    f_model = np.double(atoms_model.get_forces())

    h_real = h_model = fd_active_atoms = None
    if return_hessian and self.mm_fd is True:
        log_dir = self.mm_fd_dir
        real_log = os.path.join(log_dir, "real.log") if log_dir else None
        model_log = os.path.join(log_dir, "model.log") if log_dir else None

        # Work on a copy stripped of inherited constraints so that only the
        # Hessian-specific freeze list applies.
        hess_atoms = atoms_real.copy()
        hess_atoms.set_constraint()
        hess_atoms.calc = self.calc_real_low
        if self.hess_freeze_atoms:
            hess_atoms.set_constraint(FixAtoms(indices=self.hess_freeze_atoms))

        started = time.perf_counter()
        h_real, fd_active_atoms = self.calc_real_low.finite_difference_hessian(
            hess_atoms,
            delta=self.mm_fd_delta,
            info_path=real_log,
            dtype=self.H_np_dtype,
            return_partial_hessian=True,
        )
        real_elapsed = time.perf_counter() - started

        started = time.perf_counter()
        h_model, _ = self.calc_model_low.finite_difference_hessian(
            atoms_model,
            delta=self.mm_fd_delta,
            info_path=model_log,
            dtype=self.H_np_dtype,
            return_partial_hessian=False,
        )
        model_elapsed = time.perf_counter() - started

        timing["mm_fd_real_s"] = real_elapsed
        timing["mm_fd_model_s"] = model_elapsed
        timing["mm_fd_total_s"] = float(real_elapsed) + float(model_elapsed)

    return _MMLowOut(
        E_real=e_real,
        F_real=f_real,
        E_model=e_model,
        F_model=f_model,
        H_real=h_real,
        H_model=h_model,
        active_atoms_from_fd=fd_active_atoms,
        timing=timing,
    )
|
|
1633
|
+
|
|
1634
|
+
def compute(
|
|
1635
|
+
self,
|
|
1636
|
+
coord_ang: np.ndarray,
|
|
1637
|
+
*,
|
|
1638
|
+
return_forces: bool = False,
|
|
1639
|
+
return_hessian: bool = False,
|
|
1640
|
+
) -> Dict:
|
|
1641
|
+
timing: Dict[str, float | str] = {}
|
|
1642
|
+
hess_total_start: Optional[float] = time.perf_counter() if return_hessian else None
|
|
1643
|
+
hess_vram_base_alloc: Optional[float] = None
|
|
1644
|
+
hess_vram_base_reserved: Optional[float] = None
|
|
1645
|
+
hess_vram_total: Optional[float] = None
|
|
1646
|
+
if return_hessian and self.print_vram and self.ml_device.type == "cuda":
|
|
1647
|
+
torch.cuda.synchronize(device=self.ml_device)
|
|
1648
|
+
hess_vram_base_alloc = float(torch.cuda.memory_allocated(device=self.ml_device))
|
|
1649
|
+
hess_vram_base_reserved = float(torch.cuda.memory_reserved(device=self.ml_device))
|
|
1650
|
+
hess_vram_total = float(torch.cuda.get_device_properties(self.ml_device).total_memory)
|
|
1651
|
+
torch.cuda.reset_peak_memory_stats(device=self.ml_device)
|
|
1652
|
+
|
|
1653
|
+
atoms_real, atoms_model, atoms_model_LH, added_link_atoms, freeze_model = self._prep_3_layer_atoms(coord_ang)
|
|
1654
|
+
atoms_real.set_pbc(False)
|
|
1655
|
+
atoms_model.set_pbc(False)
|
|
1656
|
+
atoms_model_LH.set_pbc(False)
|
|
1657
|
+
|
|
1658
|
+
use_parallel = (self.ml_device.type == "cuda") and (getattr(self.calc_real_low, "device", None) == "cpu")
|
|
1659
|
+
if use_parallel:
|
|
1660
|
+
with ThreadPoolExecutor(max_workers=2) as executor:
|
|
1661
|
+
fut_ml = executor.submit(self._eval_ml_high, atoms_model_LH, freeze_model, return_hessian=return_hessian)
|
|
1662
|
+
fut_mm = executor.submit(self._eval_mm_low, atoms_real, atoms_model, return_hessian=return_hessian)
|
|
1663
|
+
ml_out = fut_ml.result()
|
|
1664
|
+
mm_out = fut_mm.result()
|
|
1665
|
+
else:
|
|
1666
|
+
ml_out = self._eval_ml_high(atoms_model_LH, freeze_model, return_hessian=return_hessian)
|
|
1667
|
+
mm_out = self._eval_mm_low(atoms_real, atoms_model, return_hessian=return_hessian)
|
|
1668
|
+
|
|
1669
|
+
timing.update(ml_out.timing)
|
|
1670
|
+
timing.update(mm_out.timing)
|
|
1671
|
+
|
|
1672
|
+
total_E = mm_out.E_real + ml_out.E - mm_out.E_model
|
|
1673
|
+
results: Dict = {"energy": total_E}
|
|
1674
|
+
|
|
1675
|
+
if return_forces or return_hessian:
|
|
1676
|
+
F_combined = np.copy(mm_out.F_real)
|
|
1677
|
+
for i, ridx in enumerate(self.selection_indices):
|
|
1678
|
+
F_combined[ridx] += ml_out.F[i] - mm_out.F_model[i]
|
|
1679
|
+
|
|
1680
|
+
real_to_model = self._idx_map_real_to_model
|
|
1681
|
+
for link_idx, ml_idx, mm_idx, dist in added_link_atoms:
|
|
1682
|
+
ml_model_idx = real_to_model[ml_idx]
|
|
1683
|
+
r_ml = atoms_model_LH[ml_model_idx].position
|
|
1684
|
+
r_mm = atoms_real[mm_idx].position
|
|
1685
|
+
grad_link = ml_out.F[link_idx]
|
|
1686
|
+
J = self._jacobian_blocks_numpy(r_ml, r_mm, dist)
|
|
1687
|
+
if J is None:
|
|
1688
|
+
continue
|
|
1689
|
+
redistributed = J @ grad_link
|
|
1690
|
+
F_combined[ml_idx] += redistributed[:3]
|
|
1691
|
+
F_combined[mm_idx] += redistributed[3:]
|
|
1692
|
+
results["forces"] = F_combined
|
|
1693
|
+
|
|
1694
|
+
# Point-charge embedding correction (optional)
|
|
1695
|
+
embed_dH = None
|
|
1696
|
+
if self.embedcharge and self._embed_correction is not None:
|
|
1697
|
+
t0_embed = time.perf_counter()
|
|
1698
|
+
# ML atom symbols and coordinates
|
|
1699
|
+
ml_symbols = [atoms_model_LH[i].symbol for i in range(len(self._atoms_model_tpl))]
|
|
1700
|
+
ml_coords = np.array([atoms_model_LH[i].position for i in range(len(self._atoms_model_tpl))])
|
|
1701
|
+
# MM atom coordinates and charges from the real topology
|
|
1702
|
+
ml_set = set(self.selection_indices)
|
|
1703
|
+
mm_atom_indices = [i for i in range(len(atoms_real)) if i not in ml_set]
|
|
1704
|
+
if mm_atom_indices and self.embedcharge_cutoff is not None:
|
|
1705
|
+
from scipy.spatial.distance import cdist
|
|
1706
|
+
ml_coords = atoms_real.get_positions()[sorted(ml_set)]
|
|
1707
|
+
mm_coords_all = atoms_real.get_positions()[mm_atom_indices]
|
|
1708
|
+
dists = cdist(mm_coords_all, ml_coords).min(axis=1)
|
|
1709
|
+
n_before = len(mm_atom_indices)
|
|
1710
|
+
mask = dists <= self.embedcharge_cutoff
|
|
1711
|
+
mm_atom_indices = [mm_atom_indices[j] for j in range(n_before) if mask[j]]
|
|
1712
|
+
if self.print_timing and not getattr(self, '_embedcharge_logged', False):
|
|
1713
|
+
print(f"[embedcharge] {len(mm_atom_indices)}/{n_before} MM atoms within {self.embedcharge_cutoff:.1f} Å cutoff.")
|
|
1714
|
+
self._embedcharge_logged = True
|
|
1715
|
+
if mm_atom_indices:
|
|
1716
|
+
mm_coords = atoms_real.get_positions()[mm_atom_indices]
|
|
1717
|
+
# Get MM partial charges from the topology
|
|
1718
|
+
mm_charges = self._get_mm_charges(mm_atom_indices)
|
|
1719
|
+
|
|
1720
|
+
dE_embed, dF_embed, dH_embed = self._embed_correction.compute_correction(
|
|
1721
|
+
symbols=ml_symbols,
|
|
1722
|
+
coords_ml_ang=ml_coords,
|
|
1723
|
+
mm_coords_ang=mm_coords,
|
|
1724
|
+
mm_charges=mm_charges,
|
|
1725
|
+
charge=self.model_charge,
|
|
1726
|
+
multiplicity=self.model_mult,
|
|
1727
|
+
need_forces=return_forces or return_hessian,
|
|
1728
|
+
need_hessian=return_hessian,
|
|
1729
|
+
)
|
|
1730
|
+
|
|
1731
|
+
# Add energy correction
|
|
1732
|
+
results["energy"] += dE_embed
|
|
1733
|
+
|
|
1734
|
+
# Add force corrections on ML atoms
|
|
1735
|
+
if dF_embed is not None and (return_forces or return_hessian):
|
|
1736
|
+
for i, ridx in enumerate(self.selection_indices):
|
|
1737
|
+
if i < len(dF_embed):
|
|
1738
|
+
results["forces"][ridx] += dF_embed[i]
|
|
1739
|
+
|
|
1740
|
+
# Store Hessian correction for later assembly
|
|
1741
|
+
if dH_embed is not None:
|
|
1742
|
+
embed_dH = dH_embed
|
|
1743
|
+
|
|
1744
|
+
timing["embedcharge_s"] = time.perf_counter() - t0_embed
|
|
1745
|
+
|
|
1746
|
+
if return_hessian:
|
|
1747
|
+
n_real = len(atoms_real)
|
|
1748
|
+
n_ml = len(self.selection_indices)
|
|
1749
|
+
n_hess_active = self.n_hess_active
|
|
1750
|
+
|
|
1751
|
+
if self.mm_fd is True:
|
|
1752
|
+
if mm_out.H_real is None or mm_out.H_model is None:
|
|
1753
|
+
raise RuntimeError("MM Hessians were not computed as expected.")
|
|
1754
|
+
|
|
1755
|
+
if mm_out.active_atoms_from_fd is not None:
|
|
1756
|
+
expected = set(self.hess_active_atoms)
|
|
1757
|
+
got = set(mm_out.active_atoms_from_fd.tolist())
|
|
1758
|
+
if expected != got:
|
|
1759
|
+
raise RuntimeError(
|
|
1760
|
+
f"Hessian active atoms mismatch: expected {len(expected)} atoms, got {len(got)}"
|
|
1761
|
+
)
|
|
1762
|
+
|
|
1763
|
+
H = torch.from_numpy(mm_out.H_real).to(self.ml_device, self.H_dtype)
|
|
1764
|
+
H = H.view(n_hess_active, 3, n_hess_active, 3)
|
|
1765
|
+
|
|
1766
|
+
H_model = torch.from_numpy(mm_out.H_model).to(self.ml_device, self.H_dtype)
|
|
1767
|
+
H_model = H_model.view(n_ml, 3, n_ml, 3)
|
|
1768
|
+
else:
|
|
1769
|
+
H = torch.zeros((n_hess_active, 3, n_hess_active, 3), dtype=self.H_dtype, device=self.ml_device)
|
|
1770
|
+
H_model = torch.zeros((n_ml, 3, n_ml, 3), dtype=self.H_dtype, device=self.ml_device)
|
|
1771
|
+
|
|
1772
|
+
H_high = ml_out.H
|
|
1773
|
+
ml_pairs = [
|
|
1774
|
+
(i, self.full_to_hess_active[gi_real])
|
|
1775
|
+
for i, gi_real in enumerate(self.selection_indices)
|
|
1776
|
+
if gi_real in self.full_to_hess_active
|
|
1777
|
+
]
|
|
1778
|
+
if ml_pairs:
|
|
1779
|
+
ml_sel_idx = torch.as_tensor([p[0] for p in ml_pairs], dtype=torch.long, device=self.ml_device)
|
|
1780
|
+
ml_active_idx = torch.as_tensor([p[1] for p in ml_pairs], dtype=torch.long, device=self.ml_device)
|
|
1781
|
+
else:
|
|
1782
|
+
ml_sel_idx = torch.empty((0,), dtype=torch.long, device=self.ml_device)
|
|
1783
|
+
ml_active_idx = torch.empty((0,), dtype=torch.long, device=self.ml_device)
|
|
1784
|
+
|
|
1785
|
+
if H_high is not None and ml_sel_idx.numel() > 0:
|
|
1786
|
+
t_asm = time.perf_counter()
|
|
1787
|
+
H_high_mm = H_high.index_select(0, ml_sel_idx).index_select(2, ml_sel_idx)
|
|
1788
|
+
H_model_mm = H_model.index_select(0, ml_sel_idx).index_select(2, ml_sel_idx)
|
|
1789
|
+
delta_mm = H_high_mm - H_model_mm
|
|
1790
|
+
H[ml_active_idx[:, None], :, ml_active_idx[None, :], :] += delta_mm.permute(0, 2, 1, 3)
|
|
1791
|
+
timing["hess_asm_mlml_s"] = time.perf_counter() - t_asm
|
|
1792
|
+
del H_model
|
|
1793
|
+
|
|
1794
|
+
real_to_model = self._idx_map_real_to_model
|
|
1795
|
+
link_data: List[Tuple[int, int, int, int, int, float, torch.Tensor]] = []
|
|
1796
|
+
for link_idx, ml_idx, mm_idx, dist in added_link_atoms:
|
|
1797
|
+
ml_model_idx = real_to_model[ml_idx]
|
|
1798
|
+
r_ml_t = torch.tensor(atoms_model_LH[ml_model_idx].position, dtype=self.H_dtype, device=self.ml_device)
|
|
1799
|
+
r_mm_t = torch.tensor(atoms_real[mm_idx].position, dtype=self.H_dtype, device=self.ml_device)
|
|
1800
|
+
K = self._jacobian_blocks_torch(r_ml_t, r_mm_t, dist, dtype=self.H_dtype, device=self.ml_device)
|
|
1801
|
+
if K is None:
|
|
1802
|
+
continue
|
|
1803
|
+
ml_active = self.full_to_hess_active.get(ml_idx)
|
|
1804
|
+
mm_active = self.full_to_hess_active.get(mm_idx)
|
|
1805
|
+
if ml_active is None or mm_active is None:
|
|
1806
|
+
continue
|
|
1807
|
+
link_data.append((link_idx, ml_idx, mm_idx, ml_active, mm_active, dist, K))
|
|
1808
|
+
|
|
1809
|
+
F_high_t = torch.as_tensor(ml_out.F, dtype=self.H_dtype, device=self.ml_device)
|
|
1810
|
+
has_link_force = bool((F_high_t.abs() > 1e-12).any().item())
|
|
1811
|
+
if link_data and (H_high is not None or has_link_force):
|
|
1812
|
+
t_asm = time.perf_counter()
|
|
1813
|
+
I3 = torch.eye(3, dtype=self.H_dtype, device=self.ml_device)
|
|
1814
|
+
for link_idx, ml_idx, mm_idx, ml_active, mm_active, dist, K in link_data:
|
|
1815
|
+
|
|
1816
|
+
if H_high is not None:
|
|
1817
|
+
H_l = H_high[link_idx, :, link_idx, :]
|
|
1818
|
+
H_self = K.T @ H_l @ K
|
|
1819
|
+
H[ml_active, :, ml_active, :].add_(H_self[0:3, 0:3])
|
|
1820
|
+
H[ml_active, :, mm_active, :].add_(H_self[0:3, 3:6])
|
|
1821
|
+
H[mm_active, :, ml_active, :].add_(H_self[3:6, 0:3])
|
|
1822
|
+
H[mm_active, :, mm_active, :].add_(H_self[3:6, 3:6])
|
|
1823
|
+
|
|
1824
|
+
f_L = -F_high_t[link_idx]
|
|
1825
|
+
|
|
1826
|
+
r_ml_t = torch.as_tensor(atoms_model_LH[real_to_model[ml_idx]].position,
|
|
1827
|
+
dtype=self.H_dtype, device=self.ml_device)
|
|
1828
|
+
r_mm_t = torch.as_tensor(atoms_real[mm_idx].position,
|
|
1829
|
+
dtype=self.H_dtype, device=self.ml_device)
|
|
1830
|
+
v = r_mm_t - r_ml_t
|
|
1831
|
+
R_sq = torch.dot(v, v)
|
|
1832
|
+
inv_R = torch.rsqrt(torch.clamp(R_sq, min=1.0e-24))
|
|
1833
|
+
inv_R2 = inv_R * inv_R
|
|
1834
|
+
u = v * inv_R
|
|
1835
|
+
|
|
1836
|
+
alpha = torch.dot(u, f_L)
|
|
1837
|
+
uuT = torch.outer(u, u)
|
|
1838
|
+
ufT = torch.outer(u, f_L)
|
|
1839
|
+
fTu = torch.outer(f_L, u)
|
|
1840
|
+
B = (alpha * (3.0 * uuT - I3) - (ufT + fTu)) * inv_R2
|
|
1841
|
+
|
|
1842
|
+
H_corr6 = torch.zeros((6, 6), dtype=self.H_dtype, device=self.ml_device)
|
|
1843
|
+
H_corr6[0:3, 0:3] = B
|
|
1844
|
+
H_corr6[3:6, 3:6] = B
|
|
1845
|
+
H_corr6[0:3, 3:6] = -B
|
|
1846
|
+
H_corr6[3:6, 0:3] = -B
|
|
1847
|
+
H_corr6.mul_(dist)
|
|
1848
|
+
|
|
1849
|
+
H[ml_active, :, ml_active, :].add_(H_corr6[0:3, 0:3])
|
|
1850
|
+
H[ml_active, :, mm_active, :].add_(H_corr6[0:3, 3:6])
|
|
1851
|
+
H[mm_active, :, ml_active, :].add_(H_corr6[3:6, 0:3])
|
|
1852
|
+
H[mm_active, :, mm_active, :].add_(H_corr6[3:6, 3:6])
|
|
1853
|
+
timing["hess_asm_link_self_s"] = time.perf_counter() - t_asm
|
|
1854
|
+
|
|
1855
|
+
if H_high is not None and link_data and ml_sel_idx.numel() > 0:
|
|
1856
|
+
t_asm = time.perf_counter()
|
|
1857
|
+
for link_idx, _ml_idx, _mm_idx, ml_active, mm_active, _dist, K in link_data:
|
|
1858
|
+
H_coup = H_high[link_idx].index_select(1, ml_sel_idx).permute(1, 0, 2).contiguous() # (K,3,3)
|
|
1859
|
+
H_row = torch.einsum("ac,bcd->bad", K.T, H_coup) # (K,6,3)
|
|
1860
|
+
H_col = torch.einsum("bca,cd->bad", H_coup, K) # (K,3,6)
|
|
1861
|
+
|
|
1862
|
+
# Mixed scalar/tensor indexing in PyTorch returns (3, K, 3) for
|
|
1863
|
+
# H[scalar, :, tensor, :], so align H_row blocks explicitly.
|
|
1864
|
+
H[ml_active, :, ml_active_idx, :].add_(H_row[:, 0:3, :].permute(1, 0, 2))
|
|
1865
|
+
H[mm_active, :, ml_active_idx, :].add_(H_row[:, 3:6, :].permute(1, 0, 2))
|
|
1866
|
+
H[ml_active_idx, :, ml_active, :].add_(H_col[:, :, 0:3])
|
|
1867
|
+
H[ml_active_idx, :, mm_active, :].add_(H_col[:, :, 3:6])
|
|
1868
|
+
timing["hess_asm_link_ml_s"] = time.perf_counter() - t_asm
|
|
1869
|
+
|
|
1870
|
+
if H_high is not None and link_data:
|
|
1871
|
+
t_asm = time.perf_counter()
|
|
1872
|
+
n_links = len(link_data)
|
|
1873
|
+
for a in range(n_links):
|
|
1874
|
+
link_idx_a, _ml_a, _mm_a, ml_a_active, mm_a_active, _dist_a, K_a = link_data[a]
|
|
1875
|
+
|
|
1876
|
+
for b in range(a + 1, n_links):
|
|
1877
|
+
link_idx_b, _ml_b, _mm_b, ml_b_active, mm_b_active, _dist_b, K_b = link_data[b]
|
|
1878
|
+
|
|
1879
|
+
H_ab = H_high[link_idx_a, :, link_idx_b, :]
|
|
1880
|
+
HAB = K_a.T @ H_ab @ K_b
|
|
1881
|
+
|
|
1882
|
+
H[ml_a_active, :, ml_b_active, :].add_(HAB[0:3, 0:3])
|
|
1883
|
+
H[ml_a_active, :, mm_b_active, :].add_(HAB[0:3, 3:6])
|
|
1884
|
+
H[mm_a_active, :, ml_b_active, :].add_(HAB[3:6, 0:3])
|
|
1885
|
+
H[mm_a_active, :, mm_b_active, :].add_(HAB[3:6, 3:6])
|
|
1886
|
+
|
|
1887
|
+
HBA = HAB.T
|
|
1888
|
+
H[ml_b_active, :, ml_a_active, :].add_(HBA[0:3, 0:3])
|
|
1889
|
+
H[ml_b_active, :, mm_a_active, :].add_(HBA[0:3, 3:6])
|
|
1890
|
+
H[mm_b_active, :, ml_a_active, :].add_(HBA[3:6, 0:3])
|
|
1891
|
+
H[mm_b_active, :, mm_a_active, :].add_(HBA[3:6, 3:6])
|
|
1892
|
+
timing["hess_asm_link_link_s"] = time.perf_counter() - t_asm
|
|
1893
|
+
|
|
1894
|
+
# Add point-charge embedding Hessian correction
|
|
1895
|
+
if embed_dH is not None:
|
|
1896
|
+
t_asm = time.perf_counter()
|
|
1897
|
+
n_model_atoms = len(self.selection_indices)
|
|
1898
|
+
dH_t = torch.from_numpy(embed_dH).to(self.ml_device, self.H_dtype)
|
|
1899
|
+
dH_t = dH_t.view(n_model_atoms, 3, n_model_atoms, 3)
|
|
1900
|
+
if ml_sel_idx.numel() > 0:
|
|
1901
|
+
dH_sub = dH_t.index_select(0, ml_sel_idx).index_select(2, ml_sel_idx)
|
|
1902
|
+
H[ml_active_idx[:, None], :, ml_active_idx[None, :], :] += dH_sub.permute(0, 2, 1, 3)
|
|
1903
|
+
timing["hess_asm_embed_s"] = time.perf_counter() - t_asm
|
|
1904
|
+
|
|
1905
|
+
if self.symmetrize_hessian:
|
|
1906
|
+
t_asm = time.perf_counter()
|
|
1907
|
+
H_flat = H.view(3 * n_hess_active, 3 * n_hess_active)
|
|
1908
|
+
H_flat = (H_flat + H_flat.t()).mul_(0.5)
|
|
1909
|
+
H = H_flat.view(n_hess_active, 3, n_hess_active, 3)
|
|
1910
|
+
timing["hess_asm_sym_s"] = time.perf_counter() - t_asm
|
|
1911
|
+
|
|
1912
|
+
if self.return_partial_hessian:
|
|
1913
|
+
results["hessian"] = H.detach()
|
|
1914
|
+
results["within_partial_hessian"] = self._build_within_partial_hessian()
|
|
1915
|
+
else:
|
|
1916
|
+
t_asm = time.perf_counter()
|
|
1917
|
+
H_full = torch.zeros((n_real, 3, n_real, 3), dtype=self.H_dtype, device=self.ml_device)
|
|
1918
|
+
active_idx = torch.as_tensor(self.hess_active_atoms, dtype=torch.long, device=self.ml_device)
|
|
1919
|
+
if active_idx.numel() > 0:
|
|
1920
|
+
H_full[active_idx[:, None], :, active_idx[None, :], :] = H.permute(0, 2, 1, 3).contiguous()
|
|
1921
|
+
results["hessian"] = H_full.detach()
|
|
1922
|
+
timing["hess_asm_full_expand_s"] = time.perf_counter() - t_asm
|
|
1923
|
+
del H_full
|
|
1924
|
+
|
|
1925
|
+
if hess_total_start is not None:
|
|
1926
|
+
timing["hessian_total_s"] = time.perf_counter() - hess_total_start
|
|
1927
|
+
results["timing"] = timing
|
|
1928
|
+
if self.print_timing:
|
|
1929
|
+
ml_mode = timing.get("ml_hessian_mode")
|
|
1930
|
+
ml_time = timing.get("ml_hessian_s")
|
|
1931
|
+
if ml_mode is not None and ml_time is not None:
|
|
1932
|
+
click.echo(f"[HessianTiming] ML Hessian ({ml_mode}): {ml_time:.2f} s")
|
|
1933
|
+
if "mm_fd_total_s" in timing:
|
|
1934
|
+
click.echo(
|
|
1935
|
+
f"[HessianTiming] MM Hessian: REAL {timing['mm_fd_real_s']:.2f} s | "
|
|
1936
|
+
f"MODEL {timing['mm_fd_model_s']:.2f} s | "
|
|
1937
|
+
f"total {timing['mm_fd_total_s']:.2f} s"
|
|
1938
|
+
)
|
|
1939
|
+
asm_parts = []
|
|
1940
|
+
for key, label in (
|
|
1941
|
+
("hess_asm_mlml_s", "ML-ML"),
|
|
1942
|
+
("hess_asm_link_self_s", "link-self"),
|
|
1943
|
+
("hess_asm_link_ml_s", "link-ML"),
|
|
1944
|
+
("hess_asm_link_link_s", "link-link"),
|
|
1945
|
+
("hess_asm_sym_s", "sym"),
|
|
1946
|
+
("hess_asm_full_expand_s", "full-expand"),
|
|
1947
|
+
):
|
|
1948
|
+
if key in timing:
|
|
1949
|
+
asm_parts.append(f"{label} {float(timing[key]):.2f} s")
|
|
1950
|
+
if asm_parts:
|
|
1951
|
+
click.echo(f"[HessianTiming] Assembly: {' | '.join(asm_parts)}")
|
|
1952
|
+
click.echo(f"[HessianTiming] Hessian total: {timing['hessian_total_s']:.2f} s")
|
|
1953
|
+
if self.print_vram and self.ml_device.type == "cuda":
|
|
1954
|
+
torch.cuda.synchronize(device=self.ml_device)
|
|
1955
|
+
base_alloc = float(hess_vram_base_alloc or 0.0)
|
|
1956
|
+
base_reserved = float(hess_vram_base_reserved or 0.0)
|
|
1957
|
+
peak_alloc = max(
|
|
1958
|
+
float(torch.cuda.max_memory_allocated(device=self.ml_device)) - base_alloc,
|
|
1959
|
+
0.0,
|
|
1960
|
+
) / 1e9
|
|
1961
|
+
peak_reserved_abs = float(torch.cuda.max_memory_reserved(device=self.ml_device))
|
|
1962
|
+
peak_reserved = max(
|
|
1963
|
+
peak_reserved_abs - base_reserved,
|
|
1964
|
+
0.0,
|
|
1965
|
+
) / 1e9
|
|
1966
|
+
total_vram = float(hess_vram_total or torch.cuda.get_device_properties(self.ml_device).total_memory) / 1e9
|
|
1967
|
+
remaining_vram = max((total_vram * 1e9) - peak_reserved_abs, 0.0) / 1e9
|
|
1968
|
+
click.echo(
|
|
1969
|
+
f"[HessianVRAM] total={total_vram:.3f} GB | "
|
|
1970
|
+
f"peak_allocated={peak_alloc:.3f} GB | "
|
|
1971
|
+
f"peak_reserved={peak_reserved:.3f} GB | "
|
|
1972
|
+
f"remaining={remaining_vram:.3f} GB"
|
|
1973
|
+
)
|
|
1974
|
+
|
|
1975
|
+
del H, H_high
|
|
1976
|
+
if self.ml_device.type == "cuda":
|
|
1977
|
+
torch.cuda.empty_cache()
|
|
1978
|
+
|
|
1979
|
+
return results
|
|
1980
|
+
|
|
1981
|
+
|
|
1982
|
+
# ======================================================================
|
|
1983
|
+
# ASE Calculator wrapper for ML/MM (ONIOM)
|
|
1984
|
+
# ======================================================================
|
|
1985
|
+
|
|
1986
|
+
class MLMMASECalculator(Calculator):
    """ASE Calculator wrapping MLMMCore for use with DMF and other ASE-based methods.

    The underlying MLMMCore takes full-system coordinates (Angstrom) and
    returns energy in eV and forces in eV/Angstrom, which matches ASE conventions.
    """

    implemented_properties = ["energy", "forces"]

    def __init__(self, core: "MLMMCore", **kwargs):
        """Wrap an existing *core*; extra keyword arguments go to ``Calculator``."""
        super().__init__(**kwargs)
        self.core = core

    def calculate(self, atoms=None, properties=("energy",), system_changes=all_changes):
        """Evaluate the ML/MM energy (and forces when requested).

        Parameters
        ----------
        atoms : ase.Atoms or None
            Structure to evaluate.  When ``None``, the Atoms object already
            attached to this calculator (``self.atoms``) is used.
        properties : sequence of str
            Requested properties; forces are computed only if ``"forces"``
            is among them.
        system_changes : list of str
            Forwarded unchanged to the base-class ``calculate``.

        Results are stored in ``self.results`` (``"energy"`` as a float,
        ``"forces"`` as an ``(n_atoms, 3)`` array when requested).
        """
        super().calculate(atoms, properties, system_changes)
        # ASE's base ``calculate`` keeps a copy of the given Atoms in
        # ``self.atoms``; fall back to it so ``calculate(atoms=None)`` reuses
        # the attached structure instead of failing on ``None.get_positions``.
        if atoms is None:
            atoms = self.atoms
        coord_ang = atoms.get_positions().astype(float)
        want_forces = "forces" in properties
        res = self.core.compute(coord_ang, return_forces=want_forces, return_hessian=False)
        self.results = {
            "energy": float(res["energy"]),
        }
        if want_forces:
            self.results["forces"] = res["forces"].reshape(-1, 3)
|
|
2009
|
+
|
|
2010
|
+
|
|
2011
|
+
# ======================================================================
|
|
2012
|
+
# PySisyphus Calculator (ML/MM)
|
|
2013
|
+
# ======================================================================
|
|
2014
|
+
|
|
2015
|
+
from pysisyphus.calculators.Calculator import Calculator as PySiCalc
|
|
2016
|
+
|
|
2017
|
+
|
|
2018
|
+
class mlmm(PySiCalc):
    """PySisyphus calculator for the ML/MM (ONIOM-style) model.

    Constructs an ``MLMMCore`` from the structure/topology inputs and exposes
    its energy, forces and Hessian through the PySisyphus calculator API.
    Incoming coordinates are in Bohr; energies are returned in Hartree,
    forces in Hartree/Bohr and Hessians in Hartree/Bohr**2 (converted from
    the core's eV / Angstrom based results).
    """

    implemented_properties = ["energy", "forces", "hessian"]

    def __init__(
        self,
        input_pdb: Optional[str] = None,
        real_parm7: Optional[str] = None,
        model_pdb: Optional[str] = None,
        *,
        model_charge: int = 0,
        model_mult: int = 1,
        link_mlmm: List[Tuple[str, str]] | None = None,
        # ML backend selection
        backend: str = "uma",
        uma_model: str = "uma-s-1p1",
        uma_task_name: str = "omol",
        orb_model: str = "orb_v3_conservative_omol",
        orb_precision: str = "float32",
        mace_model: str = "MACE-OMOL-0",
        mace_dtype: str = "float64",
        aimnet2_model: str = "aimnet2",
        # MM settings
        mm_fd: bool = True,
        mm_fd_dir: Optional[str] = None,
        mm_fd_delta: float = 1e-3,
        symmetrize_hessian: bool = True,
        out_hess_torch: bool = True,
        H_double: bool = False,
        ml_hessian_mode: str = "FiniteDifference",
        hessian_calc_mode: Optional[str] = None,
        ml_device: str = "auto",
        ml_cuda_idx: int = 0,
        mm_device: str = "cpu",
        mm_cuda_idx: int = 0,
        mm_threads: int = 16,
        mm_backend: str = "hessian_ff",
        freeze_atoms: List[int] | None = None,
        return_partial_hessian: bool = True,
        print_timing: bool = True,
        print_vram: bool = True,
        hess_cutoff: Optional[float] = None,
        movable_cutoff: Optional[float] = None,
        use_bfactor_layers: bool = False,
        hess_mm_atoms: Optional[List[int]] = None,
        movable_mm_atoms: Optional[List[int]] = None,
        frozen_mm_atoms: Optional[List[int]] = None,
        # Point-charge embedding correction
        embedcharge: bool = False,
        embedcharge_step: float = 1.0e-3,
        embedcharge_cutoff: Optional[float] = None,
        xtb_cmd: str = "xtb",
        xtb_acc: float = 0.2,
        xtb_workdir: str = "tmp",
        xtb_keep_files: bool = False,
        xtb_ncores: int = 4,
        **kwargs,
    ):
        # --- v0.1.x backward-compatibility keyword handling ---
        # ``real_pdb`` is the old name of ``input_pdb``; an explicit
        # ``input_pdb`` wins when both are supplied.
        if "real_pdb" in kwargs:
            warnings.warn("'real_pdb' is deprecated; use 'input_pdb'.", DeprecationWarning, stacklevel=2)
            legacy_pdb = kwargs.pop("real_pdb")
            if input_pdb is None:
                input_pdb = legacy_pdb
        # Retired options are dropped (with a warning) so they never reach
        # the PySisyphus base-class constructor.
        for retired in ("real_rst7", "vib_run", "vib_dir"):
            if retired not in kwargs:
                continue
            warnings.warn(
                f"'{retired}' is no longer used and will be ignored.",
                DeprecationWarning,
                stacklevel=2,
            )
            del kwargs[retired]

        self._freeze_atoms = list(freeze_atoms) if freeze_atoms is not None else []
        super().__init__(charge=model_charge, mult=model_mult, **kwargs)

        # NOTE(review): ``orb_precision`` is accepted here but is not
        # forwarded to MLMMCore below -- confirm whether MLMMCore supports it.
        core_options = dict(
            # structure / topology
            input_pdb=input_pdb,
            real_parm7=real_parm7,
            model_pdb=model_pdb,
            model_charge=model_charge,
            model_mult=model_mult,
            link_mlmm=link_mlmm,
            # ML backend
            backend=backend,
            uma_model=uma_model,
            uma_task_name=uma_task_name,
            orb_model=orb_model,
            mace_model=mace_model,
            mace_dtype=mace_dtype,
            aimnet2_model=aimnet2_model,
            # MM / Hessian settings
            mm_fd=mm_fd,
            mm_fd_dir=mm_fd_dir,
            mm_fd_delta=mm_fd_delta,
            symmetrize_hessian=symmetrize_hessian,
            H_double=H_double,
            ml_device=ml_device,
            ml_cuda_idx=ml_cuda_idx,
            mm_device=mm_device,
            mm_cuda_idx=mm_cuda_idx,
            mm_threads=mm_threads,
            mm_backend=mm_backend,
            freeze_atoms=self._freeze_atoms,
            ml_hessian_mode=ml_hessian_mode,
            hessian_calc_mode=hessian_calc_mode,
            return_partial_hessian=return_partial_hessian,
            print_timing=print_timing,
            print_vram=print_vram,
            # layer selection
            hess_cutoff=hess_cutoff,
            movable_cutoff=movable_cutoff,
            use_bfactor_layers=use_bfactor_layers,
            hess_mm_atoms=hess_mm_atoms,
            movable_mm_atoms=movable_mm_atoms,
            frozen_mm_atoms=frozen_mm_atoms,
            # point-charge embedding
            embedcharge=embedcharge,
            embedcharge_step=embedcharge_step,
            embedcharge_cutoff=embedcharge_cutoff,
            xtb_cmd=xtb_cmd,
            xtb_acc=xtb_acc,
            xtb_workdir=xtb_workdir,
            xtb_keep_files=xtb_keep_files,
            xtb_ncores=xtb_ncores,
        )
        self.core = MLMMCore(**core_options)

        self.out_hess_torch = bool(out_hess_torch)
        self.hess_torch_double = bool(H_double)
        # eV/Angstrom**2 -> Hartree/Bohr**2
        self._hess_scale = EV2AU / ANG2BOHR / ANG2BOHR

    @property
    def freeze_atoms(self) -> List[int] | None:
        """Frozen-atom indices as currently held by the shared core."""
        return self.core.freeze_atoms

    @freeze_atoms.setter
    def freeze_atoms(self, indices: List[int] | None):
        """Replace the frozen-atom list and rebuild the core's active-DOF maps."""
        self._freeze_atoms = list(indices) if indices is not None else []
        self.core.freeze_atoms = self._freeze_atoms
        self.core._update_active_dof_mappings()

    def _run_core(self, coords, *, want_forces: bool, want_hessian: bool):
        """Run the core on Bohr coordinates and convert results to atomic units.

        Forces are always computed when a Hessian is requested.  The Hessian
        is flattened from the core's 4-D ``(n, 3, n, 3)`` layout to
        ``(3n, 3n)`` and returned as a torch tensor (float32, or float64 when
        ``H_double``) or, if ``out_hess_torch`` is false, as a numpy array.
        """
        need_grad = want_forces or want_hessian
        positions_ang = np.asarray(coords).reshape(-1, 3) * BOHR2ANG
        raw = self.core.compute(positions_ang, return_forces=need_grad, return_hessian=want_hessian)

        result = {"energy": raw["energy"] * EV2AU}
        if need_grad:
            result["forces"] = (raw["forces"] * (EV2AU / ANG2BOHR)).flatten()
        if not want_hessian:
            return result

        hess = raw.pop("hessian")
        n_rows = 3 * hess.size(0)
        n_cols = 3 * hess.size(2)
        hess = hess.view(n_rows, n_cols)
        hess.mul_(self._hess_scale)
        if not self.out_hess_torch:
            result["hessian"] = hess.detach().cpu().numpy()
        else:
            dtype = torch.float64 if self.hess_torch_double else torch.float32
            result["hessian"] = hess.to(dtype).detach().requires_grad_(False)
        if "within_partial_hessian" in raw:
            result["within_partial_hessian"] = raw["within_partial_hessian"]
        return result

    def get_energy(self, elem, coords):
        """Return the energy (Hartree); *elem* is unused."""
        return self._run_core(coords, want_hessian=False, want_forces=False)

    def get_forces(self, elem, coords):
        """Return energy and flattened forces (Hartree/Bohr); *elem* is unused."""
        return self._run_core(coords, want_hessian=False, want_forces=True)

    def get_hessian(self, elem, coords):
        """Return energy, forces and Hessian (Hartree/Bohr**2); *elem* is unused."""
        return self._run_core(coords, want_hessian=True, want_forces=True)
|
|
2178
|
+
|
|
2179
|
+
|
|
2180
|
+
# ======================================================================
|
|
2181
|
+
# PySisyphus Calculator (MM-only)
|
|
2182
|
+
# ======================================================================
|
|
2183
|
+
|
|
2184
|
+
|
|
2185
|
+
class mlmm_mm_only(PySiCalc):
    """PySisyphus calculator that returns MM-only energy and forces (F_real_mm).

    Used for microiteration: relaxes the MM region without ML computation.
    Shares the MLMMCore from an existing ``mlmm`` calculator to avoid
    re-initializing topology and force field objects.

    Coordinates are taken in Bohr.  Energies are returned in Hartree and
    forces in Hartree/Bohr, converted from the eV-based values produced by
    the core's ``calc_real_low`` calculator.
    """

    implemented_properties = ["energy", "forces"]

    def __init__(self, core: "MLMMCore", *, freeze_atoms: list[int] | None = None, **kwargs):
        """Bind to an existing *core*.

        Parameters
        ----------
        core : MLMMCore
            Shared core; its ``model_charge``/``model_mult`` are passed to the
            PySisyphus base class.
        freeze_atoms : iterable of int or None
            Atom indices whose forces are zeroed.  Any iterable (including
            numpy index arrays) is accepted.
        **kwargs
            Forwarded to the PySisyphus ``Calculator`` constructor.
        """
        super().__init__(charge=core.model_charge, mult=core.model_mult, **kwargs)
        self.core = core
        # Explicit ``is None`` check (matching ``mlmm``): a truthiness test
        # would raise ValueError for multi-element numpy arrays.
        self._freeze_atoms = [] if freeze_atoms is None else list(freeze_atoms)

    def _run_core(self, coords, *, want_forces: bool):
        """Evaluate the full-system MM energy (and optionally forces).

        Parameters
        ----------
        coords : array-like
            Cartesian coordinates in Bohr; reshaped to ``(n_atoms, 3)``.
        want_forces : bool
            Also compute forces; rows of frozen atoms are set to zero.

        Returns
        -------
        dict
            ``{"energy": float}`` in Hartree, plus ``"forces"`` as a flat
            array in Hartree/Bohr when requested.
        """
        coord_ang = np.asarray(coords).reshape(-1, 3) * BOHR2ANG
        # Work on a copy of the template so the shared core's template is untouched.
        atoms_real = self.core._atoms_real_tpl.copy()
        atoms_real.set_positions(coord_ang)
        atoms_real.set_pbc(False)
        atoms_real.calc = self.core.calc_real_low
        E_real = float(atoms_real.get_potential_energy())
        out = {"energy": E_real * EV2AU}
        if want_forces:
            # Explicit float64 copy so the in-place zeroing below is safe.
            F_real = np.array(atoms_real.get_forces(), dtype=np.float64)
            n_atoms = F_real.shape[0]
            for i in self._freeze_atoms:
                # Indices outside [0, n_atoms) (including negative) are ignored.
                if 0 <= i < n_atoms:
                    F_real[i, :] = 0.0
            out["forces"] = (F_real * (EV2AU / ANG2BOHR)).flatten()
        return out

    def get_energy(self, elem, coords):
        """Return the MM energy (Hartree); *elem* is unused."""
        return self._run_core(coords, want_forces=False)

    def get_forces(self, elem, coords):
        """Return the MM energy and flattened forces (Hartree/Bohr); *elem* is unused."""
        return self._run_core(coords, want_forces=True)

    def get_hessian(self, elem, coords):
        """Not supported for the MM-only calculator.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("MM-only calculator does not support Hessian computation.")
|
|
2225
|
+
|
|
2226
|
+
|
|
2227
|
+
# ======================================================================
|
|
2228
|
+
# v0.1.x compatibility: mlmm_ase() factory
|
|
2229
|
+
# ======================================================================
|
|
2230
|
+
|
|
2231
|
+
|
|
2232
|
+
def mlmm_ase(**kwargs):
    """Build an ASE calculator from MLMMCore keyword arguments (deprecated).

    Kept for v0.1.x compatibility.  Every keyword argument is passed
    unchanged to ``MLMMCore`` and the resulting core is wrapped, i.e. this is
    equivalent to::

        MLMMASECalculator(MLMMCore(**kwargs))

    A ``DeprecationWarning`` is emitted on every call.
    """
    message = "mlmm_ase() is deprecated; use MLMMASECalculator(MLMMCore(...)) instead."
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return MLMMASECalculator(MLMMCore(**kwargs))
|
|
2247
|
+
|
|
2248
|
+
|
|
2249
|
+
# ======================================================================
|
|
2250
|
+
# CLI registration
|
|
2251
|
+
# ======================================================================
|
|
2252
|
+
|
|
2253
|
+
from pysisyphus import run as _run
|
|
2254
|
+
|
|
2255
|
+
|
|
2256
|
+
def run_pysis_mlmm():
    """Register the ``mlmm`` calculator with PySisyphus, then invoke ``_run.run()``."""
    calc_registry = _run.CALC_DICT
    calc_registry["mlmm"] = mlmm
    _run.run()


if __name__ == "__main__":
    run_pysis_mlmm()
|