mlmm-toolkit 0.2.2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (372) hide show
  1. hessian_ff/__init__.py +50 -0
  2. hessian_ff/analytical_hessian.py +609 -0
  3. hessian_ff/constants.py +46 -0
  4. hessian_ff/forcefield.py +339 -0
  5. hessian_ff/loaders.py +608 -0
  6. hessian_ff/native/Makefile +8 -0
  7. hessian_ff/native/__init__.py +28 -0
  8. hessian_ff/native/analytical_hessian.py +88 -0
  9. hessian_ff/native/analytical_hessian_ext.cpp +258 -0
  10. hessian_ff/native/bonded.py +82 -0
  11. hessian_ff/native/bonded_ext.cpp +640 -0
  12. hessian_ff/native/loader.py +349 -0
  13. hessian_ff/native/nonbonded.py +118 -0
  14. hessian_ff/native/nonbonded_ext.cpp +1150 -0
  15. hessian_ff/prmtop_parmed.py +23 -0
  16. hessian_ff/system.py +107 -0
  17. hessian_ff/terms/__init__.py +14 -0
  18. hessian_ff/terms/angle.py +73 -0
  19. hessian_ff/terms/bond.py +44 -0
  20. hessian_ff/terms/cmap.py +406 -0
  21. hessian_ff/terms/dihedral.py +141 -0
  22. hessian_ff/terms/nonbonded.py +209 -0
  23. hessian_ff/tests/__init__.py +0 -0
  24. hessian_ff/tests/conftest.py +75 -0
  25. hessian_ff/tests/data/small/complex.parm7 +1346 -0
  26. hessian_ff/tests/data/small/complex.pdb +125 -0
  27. hessian_ff/tests/data/small/complex.rst7 +63 -0
  28. hessian_ff/tests/test_coords_input.py +44 -0
  29. hessian_ff/tests/test_energy_force.py +49 -0
  30. hessian_ff/tests/test_hessian.py +137 -0
  31. hessian_ff/tests/test_smoke.py +18 -0
  32. hessian_ff/tests/test_validation.py +40 -0
  33. hessian_ff/workflows.py +889 -0
  34. mlmm/__init__.py +36 -0
  35. mlmm/__main__.py +7 -0
  36. mlmm/_version.py +34 -0
  37. mlmm/add_elem_info.py +374 -0
  38. mlmm/advanced_help.py +91 -0
  39. mlmm/align_freeze_atoms.py +601 -0
  40. mlmm/all.py +3535 -0
  41. mlmm/bond_changes.py +231 -0
  42. mlmm/bool_compat.py +223 -0
  43. mlmm/cli.py +574 -0
  44. mlmm/cli_utils.py +166 -0
  45. mlmm/default_group.py +337 -0
  46. mlmm/defaults.py +467 -0
  47. mlmm/define_layer.py +526 -0
  48. mlmm/dft.py +1041 -0
  49. mlmm/energy_diagram.py +253 -0
  50. mlmm/extract.py +2213 -0
  51. mlmm/fix_altloc.py +464 -0
  52. mlmm/freq.py +1406 -0
  53. mlmm/harmonic_constraints.py +140 -0
  54. mlmm/hessian_cache.py +44 -0
  55. mlmm/hessian_calc.py +174 -0
  56. mlmm/irc.py +638 -0
  57. mlmm/mlmm_calc.py +2262 -0
  58. mlmm/mm_parm.py +945 -0
  59. mlmm/oniom_export.py +1983 -0
  60. mlmm/oniom_import.py +457 -0
  61. mlmm/opt.py +1742 -0
  62. mlmm/path_opt.py +1353 -0
  63. mlmm/path_search.py +2299 -0
  64. mlmm/preflight.py +88 -0
  65. mlmm/py.typed +1 -0
  66. mlmm/pysis_runner.py +45 -0
  67. mlmm/scan.py +1047 -0
  68. mlmm/scan2d.py +1226 -0
  69. mlmm/scan3d.py +1265 -0
  70. mlmm/scan_common.py +184 -0
  71. mlmm/summary_log.py +736 -0
  72. mlmm/trj2fig.py +448 -0
  73. mlmm/tsopt.py +2871 -0
  74. mlmm/utils.py +2309 -0
  75. mlmm/xtb_embedcharge_correction.py +475 -0
  76. mlmm_toolkit-0.2.2.dev0.dist-info/METADATA +1159 -0
  77. mlmm_toolkit-0.2.2.dev0.dist-info/RECORD +372 -0
  78. mlmm_toolkit-0.2.2.dev0.dist-info/WHEEL +5 -0
  79. mlmm_toolkit-0.2.2.dev0.dist-info/entry_points.txt +2 -0
  80. mlmm_toolkit-0.2.2.dev0.dist-info/licenses/LICENSE +674 -0
  81. mlmm_toolkit-0.2.2.dev0.dist-info/top_level.txt +4 -0
  82. pysisyphus/Geometry.py +1667 -0
  83. pysisyphus/LICENSE +674 -0
  84. pysisyphus/TableFormatter.py +63 -0
  85. pysisyphus/TablePrinter.py +74 -0
  86. pysisyphus/__init__.py +12 -0
  87. pysisyphus/calculators/AFIR.py +452 -0
  88. pysisyphus/calculators/AnaPot.py +20 -0
  89. pysisyphus/calculators/AnaPot2.py +48 -0
  90. pysisyphus/calculators/AnaPot3.py +12 -0
  91. pysisyphus/calculators/AnaPot4.py +20 -0
  92. pysisyphus/calculators/AnaPotBase.py +337 -0
  93. pysisyphus/calculators/AnaPotCBM.py +25 -0
  94. pysisyphus/calculators/AtomAtomTransTorque.py +154 -0
  95. pysisyphus/calculators/CFOUR.py +250 -0
  96. pysisyphus/calculators/Calculator.py +844 -0
  97. pysisyphus/calculators/CerjanMiller.py +24 -0
  98. pysisyphus/calculators/Composite.py +123 -0
  99. pysisyphus/calculators/ConicalIntersection.py +171 -0
  100. pysisyphus/calculators/DFTBp.py +430 -0
  101. pysisyphus/calculators/DFTD3.py +66 -0
  102. pysisyphus/calculators/DFTD4.py +84 -0
  103. pysisyphus/calculators/Dalton.py +61 -0
  104. pysisyphus/calculators/Dimer.py +681 -0
  105. pysisyphus/calculators/Dummy.py +20 -0
  106. pysisyphus/calculators/EGO.py +76 -0
  107. pysisyphus/calculators/EnergyMin.py +224 -0
  108. pysisyphus/calculators/ExternalPotential.py +264 -0
  109. pysisyphus/calculators/FakeASE.py +35 -0
  110. pysisyphus/calculators/FourWellAnaPot.py +28 -0
  111. pysisyphus/calculators/FreeEndNEBPot.py +39 -0
  112. pysisyphus/calculators/Gaussian09.py +18 -0
  113. pysisyphus/calculators/Gaussian16.py +726 -0
  114. pysisyphus/calculators/HardSphere.py +159 -0
  115. pysisyphus/calculators/IDPPCalculator.py +49 -0
  116. pysisyphus/calculators/IPIClient.py +133 -0
  117. pysisyphus/calculators/IPIServer.py +234 -0
  118. pysisyphus/calculators/LEPSBase.py +24 -0
  119. pysisyphus/calculators/LEPSExpr.py +139 -0
  120. pysisyphus/calculators/LennardJones.py +80 -0
  121. pysisyphus/calculators/MOPAC.py +219 -0
  122. pysisyphus/calculators/MullerBrownSympyPot.py +51 -0
  123. pysisyphus/calculators/MultiCalc.py +85 -0
  124. pysisyphus/calculators/NFK.py +45 -0
  125. pysisyphus/calculators/OBabel.py +87 -0
  126. pysisyphus/calculators/ONIOMv2.py +1129 -0
  127. pysisyphus/calculators/ORCA.py +893 -0
  128. pysisyphus/calculators/ORCA5.py +6 -0
  129. pysisyphus/calculators/OpenMM.py +88 -0
  130. pysisyphus/calculators/OpenMolcas.py +281 -0
  131. pysisyphus/calculators/OverlapCalculator.py +908 -0
  132. pysisyphus/calculators/Psi4.py +218 -0
  133. pysisyphus/calculators/PyPsi4.py +37 -0
  134. pysisyphus/calculators/PySCF.py +341 -0
  135. pysisyphus/calculators/PyXTB.py +73 -0
  136. pysisyphus/calculators/QCEngine.py +106 -0
  137. pysisyphus/calculators/Rastrigin.py +22 -0
  138. pysisyphus/calculators/Remote.py +76 -0
  139. pysisyphus/calculators/Rosenbrock.py +15 -0
  140. pysisyphus/calculators/SocketCalc.py +97 -0
  141. pysisyphus/calculators/TIP3P.py +111 -0
  142. pysisyphus/calculators/TransTorque.py +161 -0
  143. pysisyphus/calculators/Turbomole.py +965 -0
  144. pysisyphus/calculators/VRIPot.py +37 -0
  145. pysisyphus/calculators/WFOWrapper.py +333 -0
  146. pysisyphus/calculators/WFOWrapper2.py +341 -0
  147. pysisyphus/calculators/XTB.py +418 -0
  148. pysisyphus/calculators/__init__.py +81 -0
  149. pysisyphus/calculators/cosmo_data.py +139 -0
  150. pysisyphus/calculators/parser.py +150 -0
  151. pysisyphus/color.py +19 -0
  152. pysisyphus/config.py +133 -0
  153. pysisyphus/constants.py +65 -0
  154. pysisyphus/cos/AdaptiveNEB.py +230 -0
  155. pysisyphus/cos/ChainOfStates.py +725 -0
  156. pysisyphus/cos/FreeEndNEB.py +25 -0
  157. pysisyphus/cos/FreezingString.py +103 -0
  158. pysisyphus/cos/GrowingChainOfStates.py +71 -0
  159. pysisyphus/cos/GrowingNT.py +309 -0
  160. pysisyphus/cos/GrowingString.py +508 -0
  161. pysisyphus/cos/NEB.py +189 -0
  162. pysisyphus/cos/SimpleZTS.py +64 -0
  163. pysisyphus/cos/__init__.py +22 -0
  164. pysisyphus/cos/stiffness.py +199 -0
  165. pysisyphus/drivers/__init__.py +17 -0
  166. pysisyphus/drivers/afir.py +855 -0
  167. pysisyphus/drivers/barriers.py +271 -0
  168. pysisyphus/drivers/birkholz.py +138 -0
  169. pysisyphus/drivers/cluster.py +318 -0
  170. pysisyphus/drivers/diabatization.py +133 -0
  171. pysisyphus/drivers/merge.py +368 -0
  172. pysisyphus/drivers/merge_mol2.py +322 -0
  173. pysisyphus/drivers/opt.py +375 -0
  174. pysisyphus/drivers/perf.py +91 -0
  175. pysisyphus/drivers/pka.py +52 -0
  176. pysisyphus/drivers/precon_pos_rot.py +669 -0
  177. pysisyphus/drivers/rates.py +480 -0
  178. pysisyphus/drivers/replace.py +219 -0
  179. pysisyphus/drivers/scan.py +212 -0
  180. pysisyphus/drivers/spectrum.py +166 -0
  181. pysisyphus/drivers/thermo.py +31 -0
  182. pysisyphus/dynamics/Gaussian.py +103 -0
  183. pysisyphus/dynamics/__init__.py +20 -0
  184. pysisyphus/dynamics/colvars.py +136 -0
  185. pysisyphus/dynamics/driver.py +297 -0
  186. pysisyphus/dynamics/helpers.py +256 -0
  187. pysisyphus/dynamics/lincs.py +105 -0
  188. pysisyphus/dynamics/mdp.py +364 -0
  189. pysisyphus/dynamics/rattle.py +121 -0
  190. pysisyphus/dynamics/thermostats.py +128 -0
  191. pysisyphus/dynamics/wigner.py +266 -0
  192. pysisyphus/elem_data.py +3473 -0
  193. pysisyphus/exceptions.py +2 -0
  194. pysisyphus/filtertrj.py +69 -0
  195. pysisyphus/helpers.py +623 -0
  196. pysisyphus/helpers_pure.py +649 -0
  197. pysisyphus/init_logging.py +50 -0
  198. pysisyphus/intcoords/Bend.py +69 -0
  199. pysisyphus/intcoords/Bend2.py +25 -0
  200. pysisyphus/intcoords/BondedFragment.py +32 -0
  201. pysisyphus/intcoords/Cartesian.py +41 -0
  202. pysisyphus/intcoords/CartesianCoords.py +140 -0
  203. pysisyphus/intcoords/Coords.py +56 -0
  204. pysisyphus/intcoords/DLC.py +197 -0
  205. pysisyphus/intcoords/DistanceFunction.py +34 -0
  206. pysisyphus/intcoords/DummyImproper.py +70 -0
  207. pysisyphus/intcoords/DummyTorsion.py +72 -0
  208. pysisyphus/intcoords/LinearBend.py +105 -0
  209. pysisyphus/intcoords/LinearDisplacement.py +80 -0
  210. pysisyphus/intcoords/OutOfPlane.py +59 -0
  211. pysisyphus/intcoords/PrimTypes.py +286 -0
  212. pysisyphus/intcoords/Primitive.py +137 -0
  213. pysisyphus/intcoords/RedundantCoords.py +659 -0
  214. pysisyphus/intcoords/RobustTorsion.py +59 -0
  215. pysisyphus/intcoords/Rotation.py +147 -0
  216. pysisyphus/intcoords/Stretch.py +31 -0
  217. pysisyphus/intcoords/Torsion.py +101 -0
  218. pysisyphus/intcoords/Torsion2.py +25 -0
  219. pysisyphus/intcoords/Translation.py +45 -0
  220. pysisyphus/intcoords/__init__.py +61 -0
  221. pysisyphus/intcoords/augment_bonds.py +126 -0
  222. pysisyphus/intcoords/derivatives.py +10512 -0
  223. pysisyphus/intcoords/eval.py +80 -0
  224. pysisyphus/intcoords/exceptions.py +37 -0
  225. pysisyphus/intcoords/findiffs.py +48 -0
  226. pysisyphus/intcoords/generate_derivatives.py +414 -0
  227. pysisyphus/intcoords/helpers.py +235 -0
  228. pysisyphus/intcoords/logging_conf.py +10 -0
  229. pysisyphus/intcoords/mp_derivatives.py +10836 -0
  230. pysisyphus/intcoords/setup.py +962 -0
  231. pysisyphus/intcoords/setup_fast.py +176 -0
  232. pysisyphus/intcoords/update.py +272 -0
  233. pysisyphus/intcoords/valid.py +89 -0
  234. pysisyphus/interpolate/Geodesic.py +93 -0
  235. pysisyphus/interpolate/IDPP.py +55 -0
  236. pysisyphus/interpolate/Interpolator.py +116 -0
  237. pysisyphus/interpolate/LST.py +70 -0
  238. pysisyphus/interpolate/Redund.py +152 -0
  239. pysisyphus/interpolate/__init__.py +9 -0
  240. pysisyphus/interpolate/helpers.py +34 -0
  241. pysisyphus/io/__init__.py +22 -0
  242. pysisyphus/io/aomix.py +178 -0
  243. pysisyphus/io/cjson.py +24 -0
  244. pysisyphus/io/crd.py +101 -0
  245. pysisyphus/io/cube.py +220 -0
  246. pysisyphus/io/fchk.py +184 -0
  247. pysisyphus/io/hdf5.py +49 -0
  248. pysisyphus/io/hessian.py +72 -0
  249. pysisyphus/io/mol2.py +146 -0
  250. pysisyphus/io/molden.py +293 -0
  251. pysisyphus/io/orca.py +189 -0
  252. pysisyphus/io/pdb.py +269 -0
  253. pysisyphus/io/psf.py +79 -0
  254. pysisyphus/io/pubchem.py +31 -0
  255. pysisyphus/io/qcschema.py +34 -0
  256. pysisyphus/io/sdf.py +29 -0
  257. pysisyphus/io/xyz.py +61 -0
  258. pysisyphus/io/zmat.py +175 -0
  259. pysisyphus/irc/DWI.py +108 -0
  260. pysisyphus/irc/DampedVelocityVerlet.py +134 -0
  261. pysisyphus/irc/Euler.py +22 -0
  262. pysisyphus/irc/EulerPC.py +345 -0
  263. pysisyphus/irc/GonzalezSchlegel.py +187 -0
  264. pysisyphus/irc/IMKMod.py +164 -0
  265. pysisyphus/irc/IRC.py +878 -0
  266. pysisyphus/irc/IRCDummy.py +10 -0
  267. pysisyphus/irc/Instanton.py +307 -0
  268. pysisyphus/irc/LQA.py +53 -0
  269. pysisyphus/irc/ModeKill.py +136 -0
  270. pysisyphus/irc/ParamPlot.py +53 -0
  271. pysisyphus/irc/RK4.py +36 -0
  272. pysisyphus/irc/__init__.py +31 -0
  273. pysisyphus/irc/initial_displ.py +219 -0
  274. pysisyphus/linalg.py +411 -0
  275. pysisyphus/line_searches/Backtracking.py +88 -0
  276. pysisyphus/line_searches/HagerZhang.py +184 -0
  277. pysisyphus/line_searches/LineSearch.py +232 -0
  278. pysisyphus/line_searches/StrongWolfe.py +108 -0
  279. pysisyphus/line_searches/__init__.py +9 -0
  280. pysisyphus/line_searches/interpol.py +15 -0
  281. pysisyphus/modefollow/NormalMode.py +40 -0
  282. pysisyphus/modefollow/__init__.py +10 -0
  283. pysisyphus/modefollow/davidson.py +199 -0
  284. pysisyphus/modefollow/lanczos.py +95 -0
  285. pysisyphus/optimizers/BFGS.py +99 -0
  286. pysisyphus/optimizers/BacktrackingOptimizer.py +113 -0
  287. pysisyphus/optimizers/ConjugateGradient.py +98 -0
  288. pysisyphus/optimizers/CubicNewton.py +75 -0
  289. pysisyphus/optimizers/FIRE.py +113 -0
  290. pysisyphus/optimizers/HessianOptimizer.py +1176 -0
  291. pysisyphus/optimizers/LBFGS.py +228 -0
  292. pysisyphus/optimizers/LayerOpt.py +411 -0
  293. pysisyphus/optimizers/MicroOptimizer.py +169 -0
  294. pysisyphus/optimizers/NCOptimizer.py +90 -0
  295. pysisyphus/optimizers/Optimizer.py +1084 -0
  296. pysisyphus/optimizers/PreconLBFGS.py +260 -0
  297. pysisyphus/optimizers/PreconSteepestDescent.py +7 -0
  298. pysisyphus/optimizers/QuickMin.py +74 -0
  299. pysisyphus/optimizers/RFOptimizer.py +181 -0
  300. pysisyphus/optimizers/RSA.py +99 -0
  301. pysisyphus/optimizers/StabilizedQNMethod.py +248 -0
  302. pysisyphus/optimizers/SteepestDescent.py +23 -0
  303. pysisyphus/optimizers/StringOptimizer.py +173 -0
  304. pysisyphus/optimizers/__init__.py +41 -0
  305. pysisyphus/optimizers/closures.py +301 -0
  306. pysisyphus/optimizers/cls_map.py +58 -0
  307. pysisyphus/optimizers/exceptions.py +6 -0
  308. pysisyphus/optimizers/gdiis.py +280 -0
  309. pysisyphus/optimizers/guess_hessians.py +311 -0
  310. pysisyphus/optimizers/hessian_updates.py +355 -0
  311. pysisyphus/optimizers/poly_fit.py +285 -0
  312. pysisyphus/optimizers/precon.py +153 -0
  313. pysisyphus/optimizers/restrict_step.py +24 -0
  314. pysisyphus/pack.py +172 -0
  315. pysisyphus/peakdetect.py +948 -0
  316. pysisyphus/plot.py +1031 -0
  317. pysisyphus/run.py +2106 -0
  318. pysisyphus/socket_helper.py +74 -0
  319. pysisyphus/stocastic/FragmentKick.py +132 -0
  320. pysisyphus/stocastic/Kick.py +81 -0
  321. pysisyphus/stocastic/Pipeline.py +303 -0
  322. pysisyphus/stocastic/__init__.py +21 -0
  323. pysisyphus/stocastic/align.py +127 -0
  324. pysisyphus/testing.py +96 -0
  325. pysisyphus/thermo.py +156 -0
  326. pysisyphus/trj.py +824 -0
  327. pysisyphus/tsoptimizers/RSIRFOptimizer.py +56 -0
  328. pysisyphus/tsoptimizers/RSPRFOptimizer.py +182 -0
  329. pysisyphus/tsoptimizers/TRIM.py +59 -0
  330. pysisyphus/tsoptimizers/TSHessianOptimizer.py +463 -0
  331. pysisyphus/tsoptimizers/__init__.py +23 -0
  332. pysisyphus/wavefunction/Basis.py +239 -0
  333. pysisyphus/wavefunction/DIIS.py +76 -0
  334. pysisyphus/wavefunction/__init__.py +25 -0
  335. pysisyphus/wavefunction/build_ext.py +42 -0
  336. pysisyphus/wavefunction/cart2sph.py +190 -0
  337. pysisyphus/wavefunction/diabatization.py +304 -0
  338. pysisyphus/wavefunction/excited_states.py +435 -0
  339. pysisyphus/wavefunction/gen_ints.py +1811 -0
  340. pysisyphus/wavefunction/helpers.py +104 -0
  341. pysisyphus/wavefunction/ints/__init__.py +0 -0
  342. pysisyphus/wavefunction/ints/boys.py +193 -0
  343. pysisyphus/wavefunction/ints/boys_table_N_64_xasym_27.1_step_0.01.npy +0 -0
  344. pysisyphus/wavefunction/ints/cart_gto3d.py +176 -0
  345. pysisyphus/wavefunction/ints/coulomb3d.py +25928 -0
  346. pysisyphus/wavefunction/ints/diag_quadrupole3d.py +10036 -0
  347. pysisyphus/wavefunction/ints/dipole3d.py +8762 -0
  348. pysisyphus/wavefunction/ints/int2c2e3d.py +7198 -0
  349. pysisyphus/wavefunction/ints/int3c2e3d_sph.py +65040 -0
  350. pysisyphus/wavefunction/ints/kinetic3d.py +8240 -0
  351. pysisyphus/wavefunction/ints/ovlp3d.py +3777 -0
  352. pysisyphus/wavefunction/ints/quadrupole3d.py +15054 -0
  353. pysisyphus/wavefunction/ints/self_ovlp3d.py +198 -0
  354. pysisyphus/wavefunction/localization.py +458 -0
  355. pysisyphus/wavefunction/multipole.py +159 -0
  356. pysisyphus/wavefunction/normalization.py +36 -0
  357. pysisyphus/wavefunction/pop_analysis.py +134 -0
  358. pysisyphus/wavefunction/shells.py +1171 -0
  359. pysisyphus/wavefunction/wavefunction.py +504 -0
  360. pysisyphus/wrapper/__init__.py +11 -0
  361. pysisyphus/wrapper/exceptions.py +2 -0
  362. pysisyphus/wrapper/jmol.py +120 -0
  363. pysisyphus/wrapper/mwfn.py +169 -0
  364. pysisyphus/wrapper/packmol.py +71 -0
  365. pysisyphus/xyzloader.py +168 -0
  366. pysisyphus/yaml_mods.py +45 -0
  367. thermoanalysis/LICENSE +674 -0
  368. thermoanalysis/QCData.py +244 -0
  369. thermoanalysis/__init__.py +0 -0
  370. thermoanalysis/config.py +3 -0
  371. thermoanalysis/constants.py +20 -0
  372. thermoanalysis/thermo.py +1011 -0
mlmm/freq.py ADDED
@@ -0,0 +1,1406 @@
1
+ """
2
+ ML/MM vibrational frequency analysis with PHVA support and thermochemistry.
3
+
4
+ Example:
5
+ mlmm freq -i pocket.pdb --parm real.parm7 --model-pdb ml_region.pdb -q 0
6
+
7
+ For detailed documentation, see: docs/freq.md
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import gc
13
+ import logging
14
+ import sys
15
+ import textwrap
16
+ import time
17
+
18
+ logger = logging.getLogger(__name__)
19
+ from copy import deepcopy
20
+ from pathlib import Path
21
+ from typing import Any, Dict, List, Optional, Tuple
22
+
23
+ import click
24
+ import numpy as np
25
+ import torch
26
+ import ase.units as units
27
+ import yaml
28
+ from ase import Atoms
29
+ from ase.data import atomic_masses
30
+ from ase.io import write
31
+
32
+ from pysisyphus.constants import AMU2AU, ANG2BOHR, AU2EV, BOHR2ANG
33
+ from pysisyphus.helpers import geom_loader
34
+
35
+ from .mlmm_calc import mlmm
36
+ from .defaults import FREQ_KW, THERMO_KW
37
+ from .opt import (
38
+ CALC_KW as OPT_CALC_KW,
39
+ GEOM_KW as OPT_GEOM_KW,
40
+ _normalize_geom_freeze as _normalize_geom_freeze_opt,
41
+ _parse_freeze_atoms as _parse_freeze_atoms_opt,
42
+ )
43
+ from .utils import (
44
+ apply_ref_pdb_override,
45
+ apply_layer_freeze_constraints,
46
+ apply_yaml_overrides,
47
+ convert_xyz_to_pdb,
48
+ set_convert_file_enabled,
49
+ is_convert_file_enabled,
50
+ convert_xyz_like_outputs,
51
+ deep_update,
52
+ filter_calc_for_echo,
53
+ format_elapsed,
54
+ format_freeze_atoms_for_echo,
55
+ load_yaml_dict,
56
+ merge_freeze_atom_indices,
57
+ prepare_input_structure,
58
+ pretty_block,
59
+ resolve_charge_spin_or_raise,
60
+ parse_indices_string,
61
+ build_model_pdb_from_bfactors,
62
+ build_model_pdb_from_indices,
63
+ strip_inherited_keys,
64
+ yaml_section_has_key,
65
+ )
66
+ from .cli_utils import resolve_yaml_sources, load_merged_yaml_cfg, make_is_param_explicit
67
+
68
+
69
def _safe_masses_amu(atomic_numbers) -> np.ndarray:
    """
    Return the atomic masses (amu) for ``atomic_numbers`` as a NumPy array.

    Masses come from ASE's ``atomic_masses`` table.

    Raises:
        ValueError: if any atomic number is negative, lies beyond the end of
            the mass table, or maps to a zero mass. The message lists the
            distinct offending numbers in sorted order.
    """
    upper = len(atomic_masses) - 1

    def _invalid(z) -> bool:
        # Check the range first so a bad index never reaches the table lookup.
        return z < 0 or z > upper or atomic_masses[z] == 0.0

    offenders = [z for z in atomic_numbers if _invalid(z)]
    if not offenders:
        return np.array([atomic_masses[z] for z in atomic_numbers])
    raise ValueError(
        f"Unknown or unsupported atomic number(s): {sorted(set(offenders))}. "
        "Check that all elements in the input structure are valid."
    )
79
+
80
+
81
def _torch_device(auto: str = "auto") -> torch.device:
    """
    Resolve a device specification to a ``torch.device``.

    The special value ``"auto"`` selects CUDA when it is available and falls
    back to the CPU otherwise. Any other string is passed to ``torch.device``
    unchanged.
    """
    if auto == "auto":
        auto = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(auto)
85
+
86
+
87
+ # ===================================================================
88
+ # Mass-weighted TR projection & vibrational analysis
89
+ # ===================================================================
90
+
91
def _build_tr_basis(coords_bohr_t: torch.Tensor,
                    masses_au_t: torch.Tensor) -> torch.Tensor:
    """
    Build the raw mass-weighted translation/rotation basis.

    Columns are ordered (Tx, Ty, Tz, Rx, Ry, Rz). Each column is a flattened
    (3N,) vector in mass-weighted Cartesian space, so the result has shape
    (3N, 6).

    The columns are neither normalized nor orthogonalized. Rotation columns
    can be zero or linearly dependent for degenerate geometries (a single
    atom, or collinear atoms).

    Args:
        coords_bohr_t: (N, 3) Cartesian coordinates.
        masses_au_t: (N,) atomic masses. They are cast to the coordinates'
            dtype and device.
    """
    dtype = coords_bohr_t.dtype
    device = coords_bohr_t.device
    n_atoms = coords_bohr_t.shape[0]
    masses = masses_au_t.to(dtype=dtype, device=device)
    sqrt_m = torch.sqrt(masses).reshape(-1, 1)  # (N, 1)

    # Rotations are taken about the centre of mass.
    center = (masses.reshape(-1, 1) * coords_bohr_t).sum(0) / masses.sum()
    rel = coords_bohr_t - center

    axes = torch.eye(3, dtype=dtype, device=device)
    translations = [
        (axes[k].repeat(n_atoms, 1) * sqrt_m).reshape(-1, 1)
        for k in range(3)
    ]
    rotations = [
        (torch.cross(rel, axes[k].expand_as(rel), dim=1) * sqrt_m).reshape(-1, 1)
        for k in range(3)
    ]
    return torch.cat(translations + rotations, dim=1)
112
+
113
+
114
def _tr_orthonormal_basis(coords_bohr_t: torch.Tensor,
                          masses_au_t: torch.Tensor,
                          rtol: float = 1e-12) -> Tuple[torch.Tensor, int]:
    """
    Return an orthonormal basis of the mass-weighted TR space and its rank.

    The raw basis from ``_build_tr_basis`` is decomposed with a thin SVD. A
    left singular vector is kept when its singular value exceeds ``rtol``
    times the largest one. As a result, rank-deficient geometries yield fewer
    than six columns.

    Returns:
        (Q, rank): Q has shape (3N, rank) with orthonormal columns.
    """
    raw = _build_tr_basis(coords_bohr_t, masses_au_t)
    left, sing, right_h = torch.linalg.svd(raw, full_matrices=False)
    # Singular values come back in descending order, so the leading columns are kept.
    rank = int((sing > rtol * sing.max()).sum().item())
    basis = left[:, :rank]
    del raw, sing, right_h, left
    return basis, rank
126
+
127
+
128
def _mw_projected_hessian(H_t: torch.Tensor,
                          coords_bohr_t: torch.Tensor,
                          masses_au_t: torch.Tensor) -> torch.Tensor:
    """
    Mass-weight a Cartesian Hessian and project out rigid-body motion.

    The function computes ``Hmw = M^{-1/2} H M^{-1/2}`` (masses in amu). It
    then forms ``P Hmw P`` with ``P = I - Q Q^T``, where ``Q`` is an
    orthonormal basis of the mass-weighted translations and rotations.
    Finally it symmetrizes the result as ``(H + H^T) / 2``.

    Memory behaviour:
        * If ``H_t`` is already float64, it is updated in place and returned.
        * Otherwise a float64 copy is made first and the caller's tensor is
          left untouched.
    """
    if H_t.dtype != torch.float64:
        H_t = H_t.to(dtype=torch.float64)
    dtype = H_t.dtype
    device = H_t.device
    with torch.no_grad():
        # A single per-DOF weight vector serves both the row and column scaling.
        m_amu = (masses_au_t / AMU2AU).to(dtype=dtype, device=device)
        w = torch.sqrt(1.0 / torch.repeat_interleave(m_amu, 3))
        H_t.mul_(w.view(-1, 1))
        H_t.mul_(w.view(1, -1))

        basis, _ = _tr_orthonormal_basis(coords_bohr_t, masses_au_t)  # (3N, r)
        basis = basis.to(dtype=dtype, device=device)
        basis_T = basis.T

        # P H P = H - Q(Q^T H) - (H Q)Q^T + Q(Q^T H Q)Q^T.
        # All three terms use the H from before any subtraction. H Q is
        # formed as (Q^T H)^T, which assumes H is symmetric.
        proj_left = basis_T @ H_t  # (r, 3N)
        H_t.addmm_(basis, proj_left, beta=1.0, alpha=-1.0)
        H_t.addmm_(proj_left.T, basis_T, beta=1.0, alpha=-1.0)
        core = proj_left @ basis  # (r, r)
        H_t.addmm_(basis @ core, basis_T, beta=1.0, alpha=1.0)

        # Explicit symmetrization: H <- (H + H^T) / 2
        H_t.add_(H_t.T.clone()).mul_(0.5)

        del m_amu, w, basis, basis_T, proj_left, core

    if torch.cuda.is_available() and device.type == "cuda":
        torch.cuda.empty_cache()
    return H_t
178
+
179
+
180
+ # ---- PHVA helper: mass-weighted Hessian without TR projection (for active subspace) ----
181
def _mass_weighted_hessian(H_t: torch.Tensor,
                           masses_au_t: torch.Tensor) -> torch.Tensor:
    """
    Mass-weight a Cartesian Hessian in place: ``H <- M^{-1/2} H M^{-1/2}``.

    Masses are given in atomic units and converted to amu before weighting.
    No TR projection and no symmetrization are applied.

    ``H_t`` keeps its dtype and device. It is modified in place and the same
    tensor is returned.
    """
    with torch.no_grad():
        amu = (masses_au_t / AMU2AU).to(dtype=H_t.dtype, device=H_t.device)
        scale = torch.sqrt(1.0 / torch.repeat_interleave(amu, 3))
        H_t.mul_(scale.view(-1, 1))  # scale rows
        H_t.mul_(scale.view(1, -1))  # scale columns
        del amu, scale
    return H_t
198
+
199
+
200
def _phva_active_dof_mask(n_atoms: int, frozen_set, device: torch.device) -> torch.Tensor:
    """Return a boolean (3N,) mask that is False on the x/y/z DOFs of every frozen atom."""
    mask = torch.ones(3 * n_atoms, dtype=torch.bool, device=device)
    for i in frozen_set:
        mask[3 * i:3 * i + 3] = False
    return mask


def _phva_project_tr_inplace(H_act: torch.Tensor,
                             coords_act: torch.Tensor,
                             masses_act: torch.Tensor) -> torch.Tensor:
    """Project active-space TR modes out of a mass-weighted Hessian in place, then symmetrize it."""
    Q, _ = _tr_orthonormal_basis(coords_act, masses_act)  # (3N_act, r)
    # Q is built on the coords/masses device. Align it with the Hessian, as
    # _mw_projected_hessian does, so a Hessian on another device still works.
    Q = Q.to(dtype=H_act.dtype, device=H_act.device)
    Qt = Q.T
    QtH = Qt @ H_act  # (r, 3N_act), taken from the unprojected H
    H_act.addmm_(Q, QtH, beta=1.0, alpha=-1.0)       # - Q Q^T H
    H_act.addmm_(QtH.T, Qt, beta=1.0, alpha=-1.0)    # - H Q Q^T (H assumed symmetric)
    QtHQ = QtH @ Q
    H_act.addmm_(Q @ QtHQ, Qt, beta=1.0, alpha=1.0)  # + Q Q^T H Q Q^T
    del Q, Qt, QtH, QtHQ
    # Explicit symmetrization before eigendecomposition
    H_act.add_(H_act.T.clone()).mul_(0.5)
    return H_act


def _phva_embed_modes(V_act: torch.Tensor, mask_dof: torch.Tensor) -> torch.Tensor:
    """Scatter active-space eigenvectors (columns of V_act) into zero-filled (nmode, 3N) rows."""
    modes = torch.zeros((V_act.shape[1], mask_dof.numel()), dtype=V_act.dtype, device=V_act.device)
    modes[:, mask_dof] = V_act.T
    return modes


def _frequencies_cm_and_modes(H_t: torch.Tensor,
                              atomic_numbers: List[int],
                              coords_bohr: np.ndarray,
                              device: torch.device,
                              tol: float = 1e-6,
                              freeze_idx: Optional[List[int]] = None) -> Tuple[np.ndarray, torch.Tensor]:
    """
    Diagonalize a (possibly PHVA/active-subspace) TR-projected mass-weighted Hessian
    to obtain frequencies (cm^-1) and mass-weighted eigenvectors (modes).

    If `freeze_idx` is provided (list of 0-based atom indices), perform
    Partial Hessian Vibrational Analysis (PHVA). Indices outside [0, N) are
    ignored. Two cases are supported:

    A) Full Hessian given (3N×3N):
       1) build Hmw = M^{-1/2} H M^{-1/2}
       2) take the active subspace by removing DOF of frozen atoms
       3) perform TR projection **only in the active subspace** (always applied)
       4) diagonalize and embed eigenvectors back to 3N by zero-filling frozen DOF

    B) Already-reduced (active-block) Hessian given (3N_act×3N_act), e.g.
       when UMA is called with return_partial_hessian=True:
       1) mass-weight with **active** masses only
       2) TR projection in the active space
       3) diagonalize and embed back to 3N by zero-filling frozen DOF

    Without `freeze_idx`, the full Hessian is TR-projected in all 3N DOF and
    diagonalized.

    Eigenvalues with |omega^2| <= `tol` are discarded.

    Note: a float64 `H_t` may be modified in place (mass-weighting is in-place).

    Returns:
        freqs_cm : (nmode,) numpy, negatives are imaginary
        modes    : (nmode, 3N) torch (mass-weighted eigenvectors)

    Raises:
        ValueError: if an atomic number is unknown, or if the Hessian shape
            matches neither the full nor the active-block dimension.
    """
    with torch.no_grad():
        if H_t.dtype != torch.float64:
            H_t = H_t.to(dtype=torch.float64)
        Z = np.array(atomic_numbers, dtype=int)
        N = int(len(Z))
        full_dim = 3 * N
        # Fail loudly on unknown elements rather than propagate zero masses into inf/NaN.
        masses_amu = _safe_masses_amu(Z)
        masses_au_t = torch.as_tensor(masses_amu * AMU2AU, dtype=H_t.dtype, device=device)
        coords_bohr_t = torch.as_tensor(coords_bohr.reshape(-1, 3), dtype=H_t.dtype, device=device)

        if freeze_idx is not None and len(freeze_idx) > 0:
            # --------------------------------------------
            # PHVA path (active DOF subspace with TR-proj)
            # --------------------------------------------
            frozen_set = {int(i) for i in freeze_idx if 0 <= int(i) < N}
            active_idx = [i for i in range(N) if i not in frozen_set]
            n_active = len(active_idx)
            if n_active == 0:
                # All atoms are frozen -> no modes
                freqs_cm = np.zeros((0,), dtype=float)
                modes = torch.zeros((0, full_dim), dtype=H_t.dtype, device=H_t.device)
                return freqs_cm, modes

            act_dim = 3 * n_active
            masses_act = masses_au_t[active_idx]
            coords_act = coords_bohr_t[active_idx, :]
            mask_dof = _phva_active_dof_mask(N, frozen_set, H_t.device)

            if tuple(H_t.shape) == (act_dim, act_dim):
                # Case B: the active block was supplied; weight with active masses only.
                H_t = _mass_weighted_hessian(H_t, masses_act)
            elif tuple(H_t.shape) == (full_dim, full_dim):
                # Case A: weight the full Hessian, then carve out the active block.
                H_t = _mass_weighted_hessian(H_t, masses_au_t)
                H_act = H_t[mask_dof][:, mask_dof]
                # Drop the full matrix right away so only the reduced one stays resident.
                del H_t
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                H_t = H_act
                del H_act
            else:
                raise ValueError(
                    f"Hessian shape {tuple(H_t.shape)} matches neither the full "
                    f"({full_dim}x{full_dim}) nor the active-block ({act_dim}x{act_dim}) "
                    "dimensions."
                )

            H_t = _phva_project_tr_inplace(H_t, coords_act, masses_act)
            omega2, V = torch.linalg.eigh(H_t, UPLO="U")

            # Free the (only) Hessian ASAP
            del H_t
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            sel = torch.abs(omega2) > tol
            omega2 = omega2[sel]
            modes = _phva_embed_modes(V[:, sel], mask_dof)
            del V, mask_dof

        else:
            if tuple(H_t.shape) != (full_dim, full_dim):
                raise ValueError(
                    f"Expected a full {full_dim}x{full_dim} Hessian when no atoms are "
                    f"frozen, got shape {tuple(H_t.shape)}."
                )
            # _mw_projected_hessian already returns an exactly symmetric matrix,
            # so a second (H + H^T)/2 here would only cost another 3N×3N copy.
            H_t = _mw_projected_hessian(H_t, coords_bohr_t, masses_au_t)
            omega2, V = torch.linalg.eigh(H_t, UPLO="U")

            # Free the (only) Hessian ASAP
            del H_t
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            sel = torch.abs(omega2) > tol
            omega2 = omega2[sel]
            modes = V[:, sel].T
            del V

        # Convert eigenvalues (Eh / (Bohr^2 amu)) to frequencies in cm^-1;
        # negative eigenvalues are reported as negative (imaginary) frequencies.
        s_new = (units._hbar * 1e10 / np.sqrt(units._e * units._amu) * np.sqrt(AU2EV) / BOHR2ANG)
        hnu = s_new * torch.sqrt(torch.abs(omega2))
        hnu = torch.where(omega2 < 0, -hnu, hnu)
        freqs_cm = (hnu / units.invcm).detach().cpu().numpy()

        del omega2, hnu, sel
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return freqs_cm, modes
378
+
379
+
380
def _mw_mode_to_cart(mode_mw_3N_t: torch.Tensor,
                     masses_au_t: torch.Tensor) -> np.ndarray:
    """
    Turn one mass-weighted normal mode into a unit-length Cartesian displacement.

    Args:
        mode_mw_3N_t: mass-weighted eigenvector of length 3N.
        masses_au_t: per-atom masses in atomic units (converted to amu here).

    Returns:
        (3N,) NumPy array equal to ``M^{-1/2} v`` rescaled to unit L2 norm,
        copied to the CPU.
    """
    dtype, dev = mode_mw_3N_t.dtype, mode_mw_3N_t.device
    with torch.no_grad():
        dof_masses = torch.repeat_interleave((masses_au_t / AMU2AU).to(dtype=dtype, device=dev), 3)
        cart = torch.sqrt(1.0 / dof_masses) * mode_mw_3N_t
        cart = cart / torch.linalg.norm(cart)
        result = cart.detach().cpu().numpy()
        del dof_masses, cart
    return result
393
+
394
+
395
def _calc_full_hessian_torch(
    geom,
    calc_kwargs: Dict[str, Any],
    device: torch.device,
    *,
    refresh_geom_meta: bool = False,
) -> Tuple[torch.Tensor, float]:
    """
    Return (Hessian torch tensor, energy Hartree) from the ML/MM calculator.

    A fresh ``mlmm`` calculator is built from a copy of ``calc_kwargs``, with
    ``out_hess_torch=True`` forced on. Its ``get_hessian`` is then called on
    ``geom.atoms`` / ``geom.coords``.

    Side effects on ``geom``:
        * If ``refresh_geom_meta`` is true, ``geom.within_partial_hessian`` is
          taken from the result. When the result lacks it and
          ``return_partial_hessian`` was requested, it is rebuilt from the
          calculator core. If it is still unavailable, it is reset to None
          whenever the result contains a ``"hessian"`` key.
        * When the calculator core exposes ``hess_active_atoms``, the
          attributes ``geom._hess_active_atoms_last`` and
          ``geom._hess_active_dofs_last`` (DOFs 3a, 3a+1, 3a+2 per atom a)
          are set. This is best effort; failures are logged at DEBUG level.

    Args:
        geom: geometry object providing ``atoms`` and ``coords``.
        calc_kwargs: keyword arguments for ``mlmm``. This dict is not mutated.
        device: device the returned Hessian is moved to.
        refresh_geom_meta: enable the ``within_partial_hessian`` refresh above.

    Returns:
        (H, energy): the Hessian tensor on ``device`` and the result's
        ``"energy"`` value as a float (0.0 when absent).
    """
    kw = dict(calc_kwargs or {})
    kw["out_hess_torch"] = True
    calc = mlmm(**kw)
    result = calc.get_hessian(geom.atoms, geom.coords)
    core = None

    if refresh_geom_meta:
        within = result.get("within_partial_hessian")
        if within is None and kw.get("return_partial_hessian"):
            try:
                core = getattr(calc, "core", None)
                if core is not None and hasattr(core, "_build_within_partial_hessian"):
                    within = core._build_within_partial_hessian()
            except Exception:
                # Best effort: leave the metadata unset rather than abort, but keep a trace.
                logger.debug("Failed to rebuild within_partial_hessian from calculator", exc_info=True)
                within = None
        if within is not None:
            geom.within_partial_hessian = within
        elif "hessian" in result:
            geom.within_partial_hessian = None

    try:
        core = getattr(calc, "core", None)
        if core is not None and hasattr(core, "hess_active_atoms"):
            active_atoms = np.asarray(core.hess_active_atoms, dtype=int)
            geom._hess_active_atoms_last = active_atoms
            # Atom a owns Cartesian DOFs 3a, 3a+1, 3a+2. Broadcasting yields an
            # empty int array when there are no active atoms.
            geom._hess_active_dofs_last = (
                3 * active_atoms.reshape(-1, 1) + np.arange(3)
            ).reshape(-1)
    except Exception:
        logger.debug("Failed to extract active DOF info from calculator", exc_info=True)

    H = result["hessian"]
    if not isinstance(H, torch.Tensor):
        H = torch.as_tensor(H)
    H = H.to(device=device)
    energy = float(result.get("energy", 0.0))

    # Drop every reference into the calculator (including `core`) so that
    # empty_cache() can actually return its GPU memory.
    core = None
    del calc, result, core
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return H, energy
450
+
451
+
452
def _collect_layer_atom_sets(calc_cfg: Dict[str, Any]) -> Dict[str, set[int]]:
    """Collect ML/MM layer index sets from a temporary calculator instance.

    A throwaway ``mlmm`` calculator is built from a copy of ``calc_cfg``. Its
    core (or the calculator itself when it has no ``core``) is queried for
    the following layer attributes:

    * ``ml_indices``
    * ``hess_mm_indices``
    * ``movable_mm_indices``
    * ``frozen_layer_indices``

    Args:
        calc_cfg: Keyword arguments for ``mlmm``. This dict is not mutated.

    Returns:
        Dict with keys ``"ml"``, ``"hess_mm"``, ``"movable_mm"`` and
        ``"frozen_mm"``, each mapping to a set of atom indices. A missing or
        ``None`` attribute yields an empty set. If building or querying the
        calculator raises, a warning is logged and all four sets are empty.
    """
    layer_attrs = (
        ("ml", "ml_indices"),
        ("hess_mm", "hess_mm_indices"),
        ("movable_mm", "movable_mm_indices"),
        ("frozen_mm", "frozen_layer_indices"),
    )

    def _as_index_set(values) -> set:
        # Explicit None check: `values or []` raises ValueError for
        # multi-element NumPy arrays and silently drops a lone index 0.
        if values is None:
            return set()
        return set(values)

    try:
        temp_calc = mlmm(**dict(calc_cfg))
        calc_core = getattr(temp_calc, "core", temp_calc)
        layer_sets = {
            key: _as_index_set(getattr(calc_core, attr, None))
            for key, attr in layer_attrs
        }
    except Exception:
        # Callers fall back to "all atoms" on empty sets; make that visible.
        logger.warning(
            "Could not determine ML/MM layer atom sets; treating all layers as empty.",
            exc_info=True,
        )
        return {key: set() for key, _ in layer_attrs}

    # Release both references: calc_core alone would keep the core (and any
    # GPU-resident model state) alive, defeating empty_cache().
    del temp_calc, calc_core
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return layer_sets
470
+
471
+
472
def _align_three_layer_hessian_targets(
    calc_cfg: Dict[str, Any],
    *,
    echo_fn=None,
) -> bool:
    """Default the Hessian targets to MovableMM in 3-layer detect-layer mode.

    The policy applies only when all of the following hold:

    * B-factor layer detection is enabled (``use_bfactor_layers`` is truthy).
    * None of ``movable_cutoff``, ``hess_cutoff``, ``hess_mm_atoms``,
      ``movable_mm_atoms`` or ``frozen_mm_atoms`` is set (non-``None``).

    When it applies, ``calc_cfg["hess_cutoff"]`` is set to infinity (mutating
    ``calc_cfg`` in place). If ``echo_fn`` is given, a notice is passed to it.

    Returns:
        True if the default policy was applied, False otherwise.
    """
    if not calc_cfg.get("use_bfactor_layers", False):
        return False

    # Any user-supplied cutoff or explicit layer list takes precedence.
    user_choice_keys = (
        "movable_cutoff",
        "hess_cutoff",
        "hess_mm_atoms",
        "movable_mm_atoms",
        "frozen_mm_atoms",
    )
    if any(calc_cfg.get(name) is not None for name in user_choice_keys):
        return False

    calc_cfg["hess_cutoff"] = float("inf")
    if echo_fn is not None:
        echo_fn("[layer] 3-layer mode: using MovableMM atoms as Hessian targets.")
    return True
500
+
501
+
502
def _resolve_active_atom_indices(
    calc_cfg: Dict[str, Any],
    n_atoms: int,
    active_dof_mode: str,
) -> Tuple[Optional[set[int]], Dict[str, set[int]]]:
    """Resolve which atoms keep active DOFs for frequency analysis.

    Layer sets come from ``_collect_layer_atom_sets(calc_cfg)``.

    ``active_dof_mode`` is case-insensitive:

    * ``"all"``: no restriction.
    * ``"ml-only"``: ML atoms.
    * ``"unfrozen"``: ML, Hessian-target MM and MovableMM atoms when any of
      those layers is non-empty. Otherwise, every index in
      ``range(n_atoms)`` outside the frozen-MM set. Otherwise, no restriction.
    * ``"partial"`` and any unrecognized value: ML atoms plus the
      Hessian-target MM atoms, or plus MovableMM atoms when there are no
      Hessian-target MM atoms.

    Returns:
        ``(active_indices, layer_sets)``. ``active_indices`` is ``None`` when
        no restriction applies or the selection turns out empty.
    """
    layer_sets = _collect_layer_atom_sets(calc_cfg)
    mode = str(active_dof_mode).lower()
    if mode == "all":
        return None, layer_sets

    ml = set(layer_sets["ml"])
    hess_mm = set(layer_sets["hess_mm"])
    movable_mm = set(layer_sets["movable_mm"])
    frozen_mm = set(layer_sets["frozen_mm"])
    # Hessian-target MM atoms take precedence over the broader MovableMM layer.
    partial_mm = hess_mm or movable_mm

    if mode == "ml-only":
        chosen = ml
    elif mode == "unfrozen":
        if ml or hess_mm or movable_mm:
            chosen = ml | hess_mm | movable_mm
        elif frozen_mm:
            chosen = set(range(int(n_atoms))) - frozen_mm
        else:
            return None, layer_sets
    else:
        # "partial" and unrecognized modes share the same selection.
        chosen = ml | partial_mm

    return (chosen or None), layer_sets
536
+
537
+
538
def _mode_displacement_frames(ref_ang: np.ndarray,
                              mode: np.ndarray,
                              amplitude_ang: float,
                              n_frames: int) -> np.ndarray:
    """Return (n_frames, N, 3) Å coordinates: ref + amp * sin(2*pi*i/n) * mode."""
    n = max(int(n_frames), 0)
    # max(n, 1) only guards the divisor; with n == 0 the arange is empty anyway.
    phases = np.sin(2.0 * np.pi * np.arange(n) / max(n, 1))
    return ref_ang[None, :, :] + (phases * amplitude_ang)[:, None, None] * mode[None, :, :]


def _write_xyz_frames(path: Path, atoms, frames: np.ndarray, comments) -> None:
    """Write Å frames as a concatenated multi-frame XYZ file (one title per frame)."""
    chunks = []
    for coords, title in zip(frames, comments):
        chunks.append(f"{len(atoms)}\n{title}\n")
        chunks.extend(
            f"{sym:2s} {x: .8f} {y: .8f} {z: .8f}\n"
            for sym, (x, y, z) in zip(atoms, coords)
        )
    path.write_text("".join(chunks), encoding="utf-8")


def _write_pdb_frames_ase(path: Path, atoms, ref_ang: np.ndarray, frames: np.ndarray) -> None:
    """Write Å frames as a multi-MODEL PDB with ASE in a single call.

    Passing all images at once lets ASE number the MODEL records 1..n. Writing
    them one by one with ``append=True`` restarts the numbering at 1 per frame.
    """
    template = Atoms(atoms, positions=ref_ang, pbc=False)
    images = []
    for coords in frames:
        image = template.copy()
        image.set_positions(coords)
        images.append(image)
    if images:
        write(path, images)


def _write_mode_trj_and_pdb(geom,
                            mode_vec_3N: np.ndarray,
                            out_trj: Path,
                            out_pdb: Path,
                            amplitude_ang: float = 0.8,
                            n_frames: int = 20,
                            comment: str = "mode",
                            ref_pdb: Optional[Path] = None) -> None:
    """Write a single mode animation as _trj.xyz (XYZ-like) and .pdb.

    The geometry (``geom.coords``, converted Bohr -> Å) is displaced along
    the unit-normalized mode by ``amplitude_ang * sin(2*pi*i/n_frames)`` Å
    for ``i = 0 .. n_frames-1``. The _trj.xyz is always written in Å.

    If ``ref_pdb`` is provided and is a .pdb file, the .pdb is generated by
    converting the _trj.xyz using the input PDB as the template (same as
    path_opt). If that conversion raises, ASE is used instead. When file
    conversion is disabled, no .pdb is written on this path. Without a
    ``ref_pdb``, the .pdb is written with ASE.

    Args:
        geom: Geometry providing ``atoms`` (element symbols) and ``coords`` (Bohr).
        mode_vec_3N: Cartesian mode vector with 3N entries. It is renormalized
            here; a zero vector is used as-is instead of producing NaNs.
        out_trj: Destination of the multi-frame XYZ trajectory.
        out_pdb: Destination of the multi-model PDB.
        amplitude_ang: Peak displacement along the unit mode (Å).
        n_frames: Number of animation frames.
        comment: Prefix of each frame's XYZ title line.
        ref_pdb: Optional PDB template for XYZ -> PDB conversion.
    """
    atoms = geom.atoms
    ref_ang = geom.coords.reshape(-1, 3) * BOHR2ANG
    mode = mode_vec_3N.reshape(-1, 3).copy()
    norm = np.linalg.norm(mode)
    if norm > 0.0:
        mode /= norm

    frames = _mode_displacement_frames(ref_ang, mode, amplitude_ang, n_frames)
    comments = [f"{comment} frame={i + 1}/{n_frames}" for i in range(len(frames))]

    if ref_pdb is not None and ref_pdb.suffix.lower() == ".pdb":
        _write_xyz_frames(out_trj, atoms, frames, comments)
        # Generate PDB using the input PDB as template (respects convert-files toggle)
        if is_convert_file_enabled():
            try:
                convert_xyz_to_pdb(out_trj, ref_pdb, out_pdb)
            except Exception:
                _write_pdb_frames_ase(out_pdb, atoms, ref_ang, frames)
        return

    # No PDB template: prefer pysisyphus' trajectory formatter when available.
    try:
        from pysisyphus.xyzloader import make_trj_str  # type: ignore
        # make_trj_str writes coordinates verbatim (no unit conversion), so the
        # frames must be passed in Å; passing Bohr produced a mislabeled XYZ.
        out_trj.write_text(make_trj_str(atoms, frames, comments=comments), encoding="utf-8")
    except Exception:
        _write_xyz_frames(out_trj, atoms, frames, comments)

    _write_pdb_frames_ase(out_pdb, atoms, ref_ang, frames)
605
+
606
+
607
# ===================================================================
# CLI default configuration
# ===================================================================

# Calculator and geometry defaults mirror opt.py. Deep copies ensure that
# in-place edits to these dictionaries never leak back into opt.py's.
CALC_KW: Dict[str, Any] = deepcopy(OPT_CALC_KW)
GEOM_KW: Dict[str, Any] = deepcopy(OPT_GEOM_KW)

# FREQ_KW and THERMO_KW are not redefined here; they are imported from .defaults
618
+
619
+
620
+ # ===================================================================
621
+ # CLI
622
+ # ===================================================================
623
+
624
+ @click.command(
625
+ help="ML/MM vibrational frequency analysis (PHVA-compatible).",
626
+ context_settings={"help_option_names": ["-h", "--help"]},
627
+ )
628
+ @click.option(
629
+ "-i", "--input",
630
+ "input_path",
631
+ type=click.Path(path_type=Path, exists=True, dir_okay=False),
632
+ required=True,
633
+ help="Enzyme complex PDB used by both geom_loader and the ML/MM calculator.",
634
+ )
635
+ @click.option(
636
+ "--parm",
637
+ "real_parm7",
638
+ type=click.Path(path_type=Path, exists=True, dir_okay=False),
639
+ required=True,
640
+ help="Amber parm7 topology for the full enzyme complex.",
641
+ )
642
+ @click.option(
643
+ "--model-pdb",
644
+ "model_pdb",
645
+ type=click.Path(path_type=Path, exists=True, dir_okay=False),
646
+ required=False,
647
+ help="PDB defining atoms belonging to the ML region. Optional when --detect-layer is enabled.",
648
+ )
649
+ @click.option(
650
+ "--model-indices",
651
+ "model_indices_str",
652
+ type=str,
653
+ default=None,
654
+ show_default=False,
655
+ help="Comma-separated atom indices for the ML region (ranges allowed like 1-5). "
656
+ "Used when --model-pdb is omitted.",
657
+ )
658
+ @click.option(
659
+ "--model-indices-one-based/--model-indices-zero-based",
660
+ "model_indices_one_based",
661
+ default=True,
662
+ show_default=True,
663
+ help="Interpret --model-indices as 1-based (default) or 0-based.",
664
+ )
665
+ @click.option(
666
+ "--detect-layer/--no-detect-layer",
667
+ "detect_layer",
668
+ default=True,
669
+ show_default=True,
670
+ help="Detect ML/MM layers from input PDB B-factors (B=0/10/20). "
671
+ "If disabled, you must provide --model-pdb or --model-indices.",
672
+ )
673
+ @click.option("-q", "--charge", type=int, required=False,
674
+ help="ML region charge. Required unless --ligand-charge is provided.")
675
+ @click.option("-l", "--ligand-charge", type=str, default=None, show_default=False,
676
+ help="Total charge or per-resname mapping (e.g., GPP:-3,SAM:1) used to derive "
677
+ "charge when -q is omitted (requires PDB input or --ref-pdb).")
678
+ @click.option(
679
+ "-m",
680
+ "--multiplicity",
681
+ "spin",
682
+ type=int,
683
+ default=None,
684
+ show_default=False,
685
+ help="Spin multiplicity (2S+1) for the ML region. Defaults to 1 when omitted.",
686
+ )
687
+ @click.option(
688
+ "--freeze-atoms",
689
+ "freeze_atoms_text",
690
+ type=str,
691
+ default=None,
692
+ show_default=False,
693
+ help="Comma-separated 1-based atom indices to freeze (e.g., '1,3,5').",
694
+ )
695
+ @click.option(
696
+ "--hess-cutoff",
697
+ "hess_cutoff",
698
+ type=float,
699
+ default=None,
700
+ show_default=False,
701
+ help="Distance cutoff (Å) from ML region for MM atoms to include in Hessian calculation. "
702
+ "Applied to movable MM atoms and can be combined with --detect-layer.",
703
+ )
704
+ @click.option(
705
+ "--movable-cutoff",
706
+ "movable_cutoff",
707
+ type=float,
708
+ default=None,
709
+ show_default=False,
710
+ help="Distance cutoff (Å) from ML region for movable MM atoms. MM atoms beyond this are frozen. "
711
+ "Providing --movable-cutoff disables --detect-layer.",
712
+ )
713
+ @click.option(
714
+ "--hessian-calc-mode",
715
+ type=click.Choice(["Analytical", "FiniteDifference"], case_sensitive=False),
716
+ default=None,
717
+ help="How the ML backend builds the Hessian (Analytical or FiniteDifference); "
718
+ "overrides calc.hessian_calc_mode from YAML. "
719
+ "Default: 'FiniteDifference'. Use 'Analytical' when VRAM is sufficient.",
720
+ )
721
+ @click.option("--max-write", type=int, default=FREQ_KW["max_write"], show_default=True,
722
+ help="Maximum number of modes to export.")
723
+ @click.option("--amplitude-ang", type=float, default=FREQ_KW["amplitude_ang"], show_default=True,
724
+ help="Mode animation amplitude (Å).")
725
+ @click.option("--n-frames", type=int, default=FREQ_KW["n_frames"], show_default=True,
726
+ help="Frames per vibrational mode animation.")
727
+ @click.option(
728
+ "--sort",
729
+ type=click.Choice(["value", "abs"]),
730
+ default=FREQ_KW["sort"],
731
+ show_default=True,
732
+ help="Sort modes by signed value or absolute value.",
733
+ )
734
+ @click.option("--temperature", type=float, default=THERMO_KW["temperature"], show_default=True,
735
+ help="Temperature (K) for thermochemistry summary.")
736
+ @click.option("--pressure", "pressure_atm",
737
+ type=float, default=THERMO_KW["pressure_atm"], show_default=True,
738
+ help="Pressure (atm) for thermochemistry summary.")
739
+ @click.option(
740
+ "--dump/--no-dump",
741
+ default=THERMO_KW["dump"],
742
+ show_default=True,
743
+ help="Write 'thermoanalysis.yaml' alongside the console summary.",
744
+ )
745
+ @click.option("-o", "--out-dir", type=str, default=FREQ_KW["out_dir"], show_default=True, help="Output directory.")
746
+ @click.option(
747
+ "--active-dof-mode",
748
+ type=click.Choice(["all", "ml-only", "partial", "unfrozen"], case_sensitive=False),
749
+ default="partial",
750
+ show_default=True,
751
+ help="Active DOF selection for frequency analysis: "
752
+ "all (all atoms), ml-only (ML only), partial (ML + MovableMM, default), "
753
+ "unfrozen (all non-frozen atoms).",
754
+ )
755
+ @click.option(
756
+ "--config",
757
+ "config_yaml",
758
+ type=click.Path(path_type=Path, exists=True, dir_okay=False),
759
+ default=None,
760
+ help="Base YAML configuration file applied before explicit CLI options.",
761
+ )
762
+ @click.option(
763
+ "--show-config/--no-show-config",
764
+ "show_config",
765
+ default=False,
766
+ show_default=True,
767
+ help="Print resolved configuration and continue execution.",
768
+ )
769
+ @click.option(
770
+ "--dry-run/--no-dry-run",
771
+ "dry_run",
772
+ default=False,
773
+ show_default=True,
774
+ help="Validate options and print the execution plan without running frequency analysis.",
775
+ )
776
+ @click.option(
777
+ "--ref-pdb",
778
+ type=click.Path(path_type=Path, exists=True, dir_okay=False),
779
+ default=None,
780
+ help="Reference PDB topology to use when --input is XYZ (keeps XYZ coordinates).",
781
+ )
782
+ @click.option(
783
+ "--convert-files/--no-convert-files",
784
+ "convert_files",
785
+ default=True,
786
+ show_default=True,
787
+ help="Convert XYZ/TRJ outputs into PDB companions based on the input format.",
788
+ )
789
+ @click.option(
790
+ "--hess-device",
791
+ "hess_device",
792
+ type=click.Choice(["auto", "cuda", "cpu"], case_sensitive=False),
793
+ default="auto",
794
+ show_default=True,
795
+ help="Device for Hessian assembly and diagonalization (auto/cuda/cpu). "
796
+ "Use 'cpu' to avoid VRAM issues with large unfrozen systems. "
797
+ "ML model inference always uses ml_device (typically GPU).",
798
+ )
799
+ @click.option(
800
+ "-b", "--backend",
801
+ type=click.Choice(["uma", "orb", "mace", "aimnet2"], case_sensitive=False),
802
+ default=None,
803
+ show_default=False,
804
+ help="ML backend for the ONIOM high-level region (default: uma).",
805
+ )
806
+ @click.option(
807
+ "--embedcharge/--no-embedcharge",
808
+ "embedcharge",
809
+ default=False,
810
+ show_default=True,
811
+ help="Enable xTB point-charge embedding correction for MM→ML environmental effects.",
812
+ )
813
+ @click.option(
814
+ "--embedcharge-cutoff",
815
+ "embedcharge_cutoff",
816
+ type=float,
817
+ default=None,
818
+ show_default=False,
819
+ help="Distance cutoff (Å) from ML region for MM point charges in xTB embedding. "
820
+ "Default: 12.0 Å when --embedcharge is enabled.",
821
+ )
822
+ @click.pass_context
823
+ def cli(
824
+ ctx: click.Context,
825
+ input_path: Path,
826
+ real_parm7: Path,
827
+ model_pdb: Optional[Path],
828
+ model_indices_str: Optional[str],
829
+ model_indices_one_based: bool,
830
+ detect_layer: bool,
831
+ charge: Optional[int],
832
+ ligand_charge: Optional[str],
833
+ spin: Optional[int],
834
+ freeze_atoms_text: Optional[str],
835
+ hess_cutoff: Optional[float],
836
+ movable_cutoff: Optional[float],
837
+ hessian_calc_mode: Optional[str],
838
+ max_write: int,
839
+ amplitude_ang: float,
840
+ n_frames: int,
841
+ sort: str,
842
+ temperature: float,
843
+ pressure_atm: float,
844
+ dump: bool,
845
+ out_dir: str,
846
+ active_dof_mode: str,
847
+ config_yaml: Optional[Path],
848
+ show_config: bool,
849
+ dry_run: bool,
850
+ ref_pdb: Optional[Path],
851
+ convert_files: bool,
852
+ hess_device: str,
853
+ backend: Optional[str],
854
+ embedcharge: bool,
855
+ embedcharge_cutoff: Optional[float],
856
+ ) -> None:
857
+ set_convert_file_enabled(convert_files)
858
+ time_start = time.perf_counter()
859
+ _is_param_explicit = make_is_param_explicit(ctx)
860
+
861
+ config_yaml, override_yaml, used_legacy_yaml = resolve_yaml_sources(
862
+ config_yaml=config_yaml,
863
+ override_yaml=None,
864
+ args_yaml_legacy=None,
865
+ )
866
+ merged_yaml_cfg, _, _ = load_merged_yaml_cfg(
867
+ config_yaml=config_yaml,
868
+ override_yaml=None,
869
+ )
870
+
871
+ # Validate input format: PDB directly, or XYZ with --ref-pdb
872
+ suffix = input_path.suffix.lower()
873
+ if suffix not in (".pdb", ".xyz"):
874
+ click.echo("ERROR: --input must be a PDB or XYZ file.", err=True)
875
+ sys.exit(1)
876
+ if suffix == ".xyz" and ref_pdb is None:
877
+ click.echo("ERROR: --ref-pdb is required when --input is an XYZ file.", err=True)
878
+ sys.exit(1)
879
+
880
+ prepared_input = prepare_input_structure(input_path)
881
+ try:
882
+ apply_ref_pdb_override(prepared_input, ref_pdb)
883
+ except click.BadParameter as e:
884
+ click.echo(f"ERROR: {e}", err=True)
885
+ prepared_input.cleanup()
886
+ sys.exit(1)
887
+
888
+ geom_input_path = prepared_input.geom_path
889
+ source_path = prepared_input.source_path # PDB topology for output conversion
890
+ charge, spin = resolve_charge_spin_or_raise(
891
+ prepared_input, charge, spin,
892
+ ligand_charge=ligand_charge, prefix="[freq]",
893
+ )
894
+
895
+ try:
896
+ freeze_atoms_cli = _parse_freeze_atoms_opt(freeze_atoms_text)
897
+ except click.BadParameter as e:
898
+ click.echo(f"ERROR: {e}", err=True)
899
+ prepared_input.cleanup()
900
+ sys.exit(1)
901
+
902
+ model_indices: Optional[List[int]] = None
903
+ if model_indices_str:
904
+ try:
905
+ model_indices = parse_indices_string(model_indices_str, one_based=model_indices_one_based)
906
+ except click.BadParameter as e:
907
+ click.echo(f"ERROR: {e}", err=True)
908
+ prepared_input.cleanup()
909
+ sys.exit(1)
910
+
911
+ try:
912
+ config_layer_cfg = load_yaml_dict(config_yaml)
913
+ override_layer_cfg = load_yaml_dict(override_yaml)
914
+ except ValueError as e:
915
+ click.echo(f"ERROR: {e}", err=True)
916
+ prepared_input.cleanup()
917
+ sys.exit(1)
918
+
919
+ geom_cfg = deepcopy(GEOM_KW)
920
+ calc_cfg = deepcopy(CALC_KW)
921
+ freq_cfg = dict(FREQ_KW)
922
+ thermo_cfg = dict(THERMO_KW)
923
+ # Keep command-level default for detect-layer unless YAML/explicit CLI overrides it.
924
+ calc_cfg["use_bfactor_layers"] = bool(detect_layer)
925
+
926
+ apply_yaml_overrides(
927
+ config_layer_cfg,
928
+ [
929
+ (geom_cfg, (("geom",),)),
930
+ (calc_cfg, (("calc",), ("mlmm",))),
931
+ (freq_cfg, (("freq",),)),
932
+ (thermo_cfg, (("thermo",), ("freq", "thermo"))),
933
+ ],
934
+ )
935
+
936
+ # CLI explicit overrides (after config YAML, before override YAML)
937
+ if backend is not None:
938
+ calc_cfg["backend"] = str(backend).lower()
939
+ if _is_param_explicit("embedcharge"):
940
+ calc_cfg["embedcharge"] = bool(embedcharge)
941
+ if _is_param_explicit("embedcharge_cutoff"):
942
+ calc_cfg["embedcharge_cutoff"] = embedcharge_cutoff
943
+
944
+ if _is_param_explicit("hessian_calc_mode") and hessian_calc_mode is not None:
945
+ calc_cfg["hessian_calc_mode"] = str(hessian_calc_mode)
946
+
947
+ if _is_param_explicit("max_write"):
948
+ freq_cfg["max_write"] = int(max_write)
949
+ if _is_param_explicit("amplitude_ang"):
950
+ freq_cfg["amplitude_ang"] = float(amplitude_ang)
951
+ if _is_param_explicit("n_frames"):
952
+ freq_cfg["n_frames"] = int(n_frames)
953
+ if _is_param_explicit("sort"):
954
+ freq_cfg["sort"] = str(sort)
955
+ if _is_param_explicit("out_dir"):
956
+ freq_cfg["out_dir"] = out_dir
957
+ if _is_param_explicit("active_dof_mode"):
958
+ freq_cfg["active_dof_mode"] = str(active_dof_mode)
959
+
960
+ if _is_param_explicit("temperature"):
961
+ thermo_cfg["temperature"] = float(temperature)
962
+ if _is_param_explicit("pressure_atm"):
963
+ thermo_cfg["pressure_atm"] = float(pressure_atm)
964
+ if _is_param_explicit("dump"):
965
+ thermo_cfg["dump"] = bool(dump)
966
+
967
+ if _is_param_explicit("hess_cutoff") and hess_cutoff is not None:
968
+ calc_cfg["hess_cutoff"] = float(hess_cutoff)
969
+ if _is_param_explicit("movable_cutoff") and movable_cutoff is not None:
970
+ calc_cfg["movable_cutoff"] = float(movable_cutoff)
971
+ if _is_param_explicit("detect_layer"):
972
+ calc_cfg["use_bfactor_layers"] = bool(detect_layer)
973
+
974
+ model_charge_value = calc_cfg.get("model_charge", charge)
975
+ if model_charge_value is None:
976
+ model_charge_value = charge
977
+ calc_cfg["model_charge"] = int(model_charge_value)
978
+ if _is_param_explicit("charge"):
979
+ calc_cfg["model_charge"] = int(charge)
980
+
981
+ model_mult_value = calc_cfg.get("model_mult", spin)
982
+ if model_mult_value is None:
983
+ model_mult_value = spin
984
+ calc_cfg["model_mult"] = int(model_mult_value)
985
+ if _is_param_explicit("spin"):
986
+ calc_cfg["model_mult"] = int(spin)
987
+
988
+ calc_cfg["input_pdb"] = str(source_path)
989
+ calc_cfg["real_parm7"] = str(real_parm7)
990
+ if model_pdb is not None:
991
+ calc_cfg["model_pdb"] = str(model_pdb)
992
+
993
+ apply_yaml_overrides(
994
+ override_layer_cfg,
995
+ [
996
+ (geom_cfg, (("geom",),)),
997
+ (calc_cfg, (("calc",), ("mlmm",))),
998
+ (freq_cfg, (("freq",),)),
999
+ (thermo_cfg, (("thermo",), ("freq", "thermo"))),
1000
+ ],
1001
+ )
1002
+ calc_paths = (("calc",), ("mlmm",))
1003
+ partial_explicit = (
1004
+ yaml_section_has_key(config_layer_cfg, calc_paths, "return_partial_hessian")
1005
+ or yaml_section_has_key(override_layer_cfg, calc_paths, "return_partial_hessian")
1006
+ )
1007
+ if not partial_explicit:
1008
+ calc_cfg["return_partial_hessian"] = True
1009
+
1010
+ try:
1011
+ geom_freeze = _normalize_geom_freeze_opt(geom_cfg.get("freeze_atoms"))
1012
+ except click.BadParameter as e:
1013
+ click.echo(f"ERROR: {e}", err=True)
1014
+ prepared_input.cleanup()
1015
+ sys.exit(1)
1016
+ geom_cfg["freeze_atoms"] = geom_freeze
1017
+ if freeze_atoms_cli:
1018
+ merge_freeze_atom_indices(geom_cfg, freeze_atoms_cli)
1019
+ freeze_atoms_final = list(geom_cfg.get("freeze_atoms") or [])
1020
+ calc_cfg["freeze_atoms"] = freeze_atoms_final
1021
+
1022
+ out_dir_path = Path(freq_cfg.get("out_dir", FREQ_KW["out_dir"])).resolve()
1023
+ layer_source_pdb = source_path
1024
+ detect_layer_enabled = bool(calc_cfg.get("use_bfactor_layers", True))
1025
+ model_pdb_cfg = calc_cfg.get("model_pdb")
1026
+ movable_cutoff_value = calc_cfg.get("movable_cutoff")
1027
+ if movable_cutoff_value is not None:
1028
+ if detect_layer_enabled:
1029
+ click.echo("[layer] movable_cutoff is set; disabling detect-layer mode.", err=True)
1030
+ detect_layer_enabled = False
1031
+ calc_cfg["use_bfactor_layers"] = False
1032
+
1033
+ if show_config:
1034
+ click.echo(
1035
+ pretty_block(
1036
+ "yaml_layers",
1037
+ {
1038
+ "config": None if config_yaml is None else str(config_yaml),
1039
+ "override_yaml": None if override_yaml is None else str(override_yaml),
1040
+ "merged_keys": sorted(merged_yaml_cfg.keys()),
1041
+ },
1042
+ )
1043
+ )
1044
+
1045
+ if dry_run:
1046
+ model_region_source = "bfactor"
1047
+ if not detect_layer_enabled:
1048
+ if model_pdb_cfg is not None:
1049
+ model_region_source = "model_pdb"
1050
+ elif model_indices:
1051
+ model_region_source = "model_indices"
1052
+ else:
1053
+ click.echo("ERROR: Provide --model-pdb or --model-indices when --no-detect-layer.", err=True)
1054
+ prepared_input.cleanup()
1055
+ sys.exit(1)
1056
+ if detect_layer_enabled and layer_source_pdb.suffix.lower() != ".pdb":
1057
+ click.echo("ERROR: --detect-layer requires a PDB input (or --ref-pdb).", err=True)
1058
+ prepared_input.cleanup()
1059
+ sys.exit(1)
1060
+ if (
1061
+ not detect_layer_enabled
1062
+ and model_pdb_cfg is None
1063
+ and model_indices
1064
+ and layer_source_pdb.suffix.lower() != ".pdb"
1065
+ ):
1066
+ click.echo("ERROR: --model-indices requires a PDB input (or --ref-pdb).", err=True)
1067
+ prepared_input.cleanup()
1068
+ sys.exit(1)
1069
+ click.echo(
1070
+ pretty_block(
1071
+ "dry_run_plan",
1072
+ {
1073
+ "input_geometry": str(geom_input_path),
1074
+ "output_dir": str(out_dir_path),
1075
+ "detect_layer": bool(detect_layer_enabled),
1076
+ "model_region_source": model_region_source,
1077
+ "model_indices_count": 0 if not model_indices else len(model_indices),
1078
+ "active_dof_mode": str(freq_cfg.get("active_dof_mode", active_dof_mode)),
1079
+ "will_run_frequency_analysis": True,
1080
+ "will_write_modes": True,
1081
+ "will_dump_thermo_yaml": bool(thermo_cfg.get("dump", False)),
1082
+ "backend": calc_cfg.get("backend", "uma"),
1083
+ "embedcharge": bool(calc_cfg.get("embedcharge", False)),
1084
+ },
1085
+ )
1086
+ )
1087
+ click.echo("[dry-run] Validation complete. Frequency execution was skipped.")
1088
+ return
1089
+
1090
+ out_dir_path.mkdir(parents=True, exist_ok=True)
1091
+
1092
+ if detect_layer_enabled and layer_source_pdb.suffix.lower() != ".pdb":
1093
+ click.echo("ERROR: --detect-layer requires a PDB input (or --ref-pdb).", err=True)
1094
+ prepared_input.cleanup()
1095
+ sys.exit(1)
1096
+
1097
+ model_pdb_path: Optional[Path] = None
1098
+ layer_info: Optional[Dict[str, List[int]]] = None
1099
+
1100
+ if detect_layer_enabled:
1101
+ try:
1102
+ model_pdb_path, layer_info = build_model_pdb_from_bfactors(layer_source_pdb, out_dir_path)
1103
+ calc_cfg["use_bfactor_layers"] = True
1104
+ click.echo(
1105
+ f"[layer] Detected B-factor layers: ML={len(layer_info.get('ml_indices', []))}, "
1106
+ f"MovableMM={len(layer_info.get('movable_mm_indices', []))}, "
1107
+ f"FrozenMM={len(layer_info.get('frozen_indices', []))}"
1108
+ )
1109
+ except Exception as e:
1110
+ if model_pdb_cfg is None and not model_indices:
1111
+ click.echo(f"ERROR: {e}", err=True)
1112
+ prepared_input.cleanup()
1113
+ sys.exit(1)
1114
+ click.echo(f"[layer] WARNING: {e} Falling back to explicit ML region.", err=True)
1115
+ detect_layer_enabled = False
1116
+
1117
+ if not detect_layer_enabled:
1118
+ if model_pdb_cfg is None and not model_indices:
1119
+ click.echo("ERROR: Provide --model-pdb or --model-indices when --no-detect-layer.", err=True)
1120
+ prepared_input.cleanup()
1121
+ sys.exit(1)
1122
+ if model_pdb_cfg is not None:
1123
+ model_pdb_path = Path(model_pdb_cfg)
1124
+ else:
1125
+ if layer_source_pdb.suffix.lower() != ".pdb":
1126
+ click.echo("ERROR: --model-indices requires a PDB input (or --ref-pdb).", err=True)
1127
+ prepared_input.cleanup()
1128
+ sys.exit(1)
1129
+ try:
1130
+ model_pdb_path = build_model_pdb_from_indices(layer_source_pdb, out_dir_path, model_indices or [])
1131
+ except Exception as e:
1132
+ click.echo(f"ERROR: {e}", err=True)
1133
+ prepared_input.cleanup()
1134
+ sys.exit(1)
1135
+ calc_cfg["use_bfactor_layers"] = False
1136
+
1137
+ if model_pdb_path is None:
1138
+ click.echo("ERROR: Failed to resolve model PDB for the ML region.", err=True)
1139
+ prepared_input.cleanup()
1140
+ sys.exit(1)
1141
+
1142
+ calc_cfg["model_pdb"] = str(model_pdb_path)
1143
+ freeze_atoms_final = apply_layer_freeze_constraints(
1144
+ geom_cfg,
1145
+ calc_cfg,
1146
+ layer_info,
1147
+ echo_fn=click.echo,
1148
+ )
1149
+
1150
+ for key in ("input_pdb", "real_parm7", "model_pdb", "mm_fd_dir"):
1151
+ val = calc_cfg.get(key)
1152
+ if val:
1153
+ calc_cfg[key] = str(Path(val).expanduser().resolve())
1154
+
1155
+ click.echo(pretty_block("geom", format_freeze_atoms_for_echo(geom_cfg, key="freeze_atoms")))
1156
+ echo_calc = format_freeze_atoms_for_echo(filter_calc_for_echo(calc_cfg), key="freeze_atoms")
1157
+ click.echo(pretty_block("calc", echo_calc))
1158
+ echo_freq = strip_inherited_keys({**freq_cfg, "out_dir": str(out_dir_path)}, FREQ_KW, mode="same")
1159
+ click.echo(pretty_block("freq", echo_freq))
1160
+ echo_thermo = strip_inherited_keys(thermo_cfg, THERMO_KW, mode="same")
1161
+ click.echo(pretty_block("thermo", echo_thermo))
1162
+
1163
+ coord_type = geom_cfg.get("coord_type", "cart")
1164
+ coord_kwargs = dict(geom_cfg)
1165
+ coord_kwargs.pop("coord_type", None)
1166
+ geometry = geom_loader(geom_input_path, coord_type=coord_type, **coord_kwargs)
1167
+
1168
+ masses_amu = np.array([atomic_masses[z] for z in geometry.atomic_numbers])
1169
+ # Resolve Hessian assembly/diagonalization device separately from ML inference device.
1170
+ # --hess-device=cpu allows large Hessians to be assembled on CPU while ML model stays on GPU.
1171
+ if hess_device.lower() == "auto":
1172
+ device = _torch_device(calc_cfg.get("ml_device", "auto"))
1173
+ else:
1174
+ device = _torch_device(hess_device.lower())
1175
+ if device.type == "cpu":
1176
+ click.echo("[device] Hessian assembly and diagonalization will run on CPU.")
1177
+ masses_au_t = torch.as_tensor(masses_amu * AMU2AU, dtype=torch.float32, device=device)
1178
+
1179
+ n_atoms = len(geometry.atoms)
1180
+ all_indices = set(range(n_atoms))
1181
+ three_layer_policy_applied = _align_three_layer_hessian_targets(calc_cfg, echo_fn=click.echo)
1182
+
1183
+ # Determine active atoms based on mode
1184
+ active_dof_mode_lower = str(freq_cfg.get("active_dof_mode", active_dof_mode)).lower()
1185
+ active_indices, layer_sets = _resolve_active_atom_indices(calc_cfg, n_atoms, active_dof_mode_lower)
1186
+ ml_indices = layer_sets["ml"]
1187
+ hess_mm_indices = layer_sets["hess_mm"]
1188
+ movable_mm_indices = layer_sets["movable_mm"]
1189
+ partial_mm_indices = hess_mm_indices if hess_mm_indices else movable_mm_indices
1190
+
1191
+ if active_dof_mode_lower == "all" or active_indices is None:
1192
+ active_indices = all_indices
1193
+ click.echo("[active-dof] Using all atoms for frequency analysis.")
1194
+ elif active_dof_mode_lower == "ml-only":
1195
+ click.echo(f"[active-dof] Using ML atoms only for frequency analysis (n={len(active_indices)}).")
1196
+ elif active_dof_mode_lower == "partial":
1197
+ if three_layer_policy_applied:
1198
+ click.echo(f"[active-dof] Using ML + MovableMM atoms for frequency analysis (n={len(active_indices)}).")
1199
+ elif hess_mm_indices:
1200
+ click.echo(f"[active-dof] Using ML + Hessian-target MM atoms for frequency analysis (n={len(active_indices)}).")
1201
+ else:
1202
+ click.echo(f"[active-dof] Using ML + MovableMM atoms for frequency analysis (n={len(active_indices)}).")
1203
+ elif active_dof_mode_lower == "unfrozen":
1204
+ click.echo(f"[active-dof] Using all non-frozen atoms for frequency analysis (n={len(active_indices)}).")
1205
+ else:
1206
+ active_indices = set(ml_indices) | set(partial_mm_indices)
1207
+ click.echo(f"[active-dof] Defaulting to partial mode (n={len(active_indices)}).")
1208
+
1209
+ # Atoms not in active_indices become frozen for frequency analysis
1210
+ freeze_for_freq = sorted(all_indices - active_indices)
1211
+ # Also include any explicitly frozen atoms from config
1212
+ explicit_freeze = set(calc_cfg.get("freeze_atoms") or [])
1213
+ freeze_list = sorted(set(freeze_for_freq) | explicit_freeze)
1214
+
1215
+ try:
1216
+ from .hessian_cache import load as _hess_load
1217
+ _cached_ts = _hess_load("ts")
1218
+ if _cached_ts is not None:
1219
+ click.echo("[freq] Reusing cached TS Hessian.")
1220
+ H_t = _cached_ts["hessian"]
1221
+ if isinstance(H_t, torch.Tensor):
1222
+ H_t = H_t.to(device=device)
1223
+ else:
1224
+ H_t = torch.as_tensor(H_t, device=device)
1225
+ energy_ha = _cached_ts.get("meta", {}).get("energy_ha")
1226
+ if energy_ha is None:
1227
+ energy_ha = float(geometry.energy)
1228
+ else:
1229
+ H_t, energy_ha = _calc_full_hessian_torch(geometry, calc_cfg, device)
1230
+ coords_bohr = geometry.coords.reshape(-1, 3)
1231
+ freqs_cm, modes_mw = _frequencies_cm_and_modes(
1232
+ H_t,
1233
+ geometry.atomic_numbers,
1234
+ coords_bohr,
1235
+ device,
1236
+ freeze_idx=freeze_list if freeze_list else None,
1237
+ )
1238
+
1239
+ del H_t
1240
+ if torch.cuda.is_available():
1241
+ torch.cuda.empty_cache()
1242
+
1243
+ order = (
1244
+ np.argsort(np.abs(freqs_cm))
1245
+ if freq_cfg["sort"] == "abs"
1246
+ else np.argsort(freqs_cm)
1247
+ )
1248
+ n_write = int(min(freq_cfg["max_write"], len(order)))
1249
+ click.echo(
1250
+ f"[INFO] Total modes: {len(freqs_cm)} → writing {n_write} mode(s) ({freq_cfg['sort']} ordering)."
1251
+ )
1252
+
1253
+ ref_pdb_for_modes = source_path if source_path.suffix.lower() == ".pdb" else None
1254
+ for k, idx in enumerate(order[:n_write], start=1):
1255
+ freq_val = float(freqs_cm[idx])
1256
+ mode_cart_3N = _mw_mode_to_cart(modes_mw[idx], masses_au_t)
1257
+ out_trj = out_dir_path / f"mode_{k:04d}_{freq_val:+.2f}cm-1_trj.xyz"
1258
+ out_pdb = out_dir_path / f"mode_{k:04d}_{freq_val:+.2f}cm-1.pdb"
1259
+ _write_mode_trj_and_pdb(
1260
+ geometry,
1261
+ mode_cart_3N,
1262
+ out_trj,
1263
+ out_pdb,
1264
+ amplitude_ang=freq_cfg["amplitude_ang"],
1265
+ n_frames=freq_cfg["n_frames"],
1266
+ comment=f"mode {k} {freq_val:+.2f} cm-1",
1267
+ ref_pdb=ref_pdb_for_modes,
1268
+ )
1269
+
1270
+ (out_dir_path / "frequencies_cm-1.txt").write_text(
1271
+ "\n".join(f"{i+1:4d} {float(freqs_cm[j]):+12.4f}" for i, j in enumerate(order)),
1272
+ encoding="utf-8",
1273
+ )
1274
+
1275
+ del modes_mw
1276
+ if torch.cuda.is_available():
1277
+ torch.cuda.empty_cache()
1278
+
1279
+ try:
1280
+ from thermoanalysis.QCData import QCData
1281
+ from thermoanalysis.constants import J2AU, J2CAL, NA
1282
+ from thermoanalysis.thermo import thermochemistry
1283
+
1284
+ qc_data = {
1285
+ "coords3d": geometry.coords.reshape(-1, 3) * BOHR2ANG,
1286
+ "wavenumbers": freqs_cm,
1287
+ "scf_energy": float(energy_ha),
1288
+ "masses": masses_amu,
1289
+ "mult": int(calc_cfg["model_mult"]),
1290
+ }
1291
+ qc = QCData(qc_data, point_group="c1", mult=int(calc_cfg["model_mult"]))
1292
+
1293
+ T = float(thermo_cfg["temperature"])
1294
+ p_atm = float(thermo_cfg["pressure_atm"])
1295
+ p_pa = p_atm * 101325.0 # Pa
1296
+
1297
+ tr = thermochemistry(qc, T, pressure=p_pa) # default: QRRHO
1298
+
1299
+ # Converters
1300
+ au2CalMol = (1.0 / J2AU) * NA * J2CAL
1301
+ to_cal_per_mol = lambda x: float(x) * au2CalMol
1302
+ J_per_Kmol_to_cal_per_Kmol = lambda j: float(j) * J2CAL
1303
+
1304
+ # Counts
1305
+ n_imag = int(np.sum(freqs_cm < 0.0))
1306
+
1307
+ # Compose summary
1308
+ EE = float(tr.U_el)
1309
+ ZPE = float(tr.ZPE)
1310
+ dE_therm = float(tr.U_therm) # Thermal correction to Energy (includes ZPE)
1311
+ dH_therm = float(tr.H - tr.U_el) # Thermal correction to Enthalpy (= U_therm + kBT)
1312
+ dG_therm = float(tr.dG) # Thermal correction to Free Energy (= G - EE)
1313
+
1314
+ sum_EE_ZPE = EE + ZPE
1315
+ sum_EE_thermal_E = float(tr.U_tot) # = EE + U_therm
1316
+ sum_EE_thermal_H = float(tr.H) # = H
1317
+ sum_EE_thermal_G = float(tr.G) # = G
1318
+
1319
+ E_thermal_cal = to_cal_per_mol(tr.U_therm) # cal/mol
1320
+ Cv_cal_per_Kmol = J_per_Kmol_to_cal_per_Kmol(tr.c_tot) # cal/(mol*K)
1321
+ S_cal_per_Kmol = to_cal_per_mol(tr.S_tot) # cal/(mol*K)
1322
+
1323
+ # Echo summary (Gaussian-like)
1324
+ click.echo("\nThermochemistry Summary")
1325
+ click.echo("------------------------")
1326
+ click.echo(f"Temperature (K) = {T:.2f}")
1327
+ click.echo(f"Pressure (atm) = {p_atm:.4f}")
1328
+ if freeze_list:
1329
+ click.echo("[NOTE] Thermochemistry uses active DOF (PHVA) due to frozen atoms.")
1330
+ click.echo(f"Number of Imaginary Freq = {n_imag:d}\n")
1331
+
1332
+ def _ha(x): return f"{float(x): .6f} Ha"
1333
+ def _cal(x): return f"{float(x): .2f} cal/mol"
1334
+ def _calK(x): return f"{float(x): .2f} cal/(mol*K)"
1335
+
1336
+ click.echo(f"Electronic Energy (EE) = {_ha(EE)}")
1337
+ click.echo(f"Zero-point Energy Correction = {_ha(ZPE)}")
1338
+ click.echo(f"Thermal Correction to Energy = {_ha(dE_therm)}")
1339
+ click.echo(f"Thermal Correction to Enthalpy = {_ha(dH_therm)}")
1340
+ click.echo(f"Thermal Correction to Free Energy = {_ha(dG_therm)}")
1341
+ click.echo(f"EE + Zero-point Energy = {_ha(sum_EE_ZPE)}")
1342
+ click.echo(f"EE + Thermal Energy Correction = {_ha(sum_EE_thermal_E)}")
1343
+ click.echo(f"EE + Thermal Enthalpy Correction = {_ha(sum_EE_thermal_H)}")
1344
+ click.echo(f"EE + Thermal Free Energy Correction = {_ha(sum_EE_thermal_G)}")
1345
+ click.echo("")
1346
+ click.echo(f"E (Thermal) = {_cal(E_thermal_cal)}")
1347
+ click.echo(f"Heat Capacity (Cv) = {_calK(Cv_cal_per_Kmol)}")
1348
+ click.echo(f"Entropy (S) = {_calK(S_cal_per_Kmol)}")
1349
+ click.echo("")
1350
+
1351
+ # Dump YAML when requested
1352
+ if bool(thermo_cfg["dump"]):
1353
+ out_yaml = out_dir_path / "thermoanalysis.yaml"
1354
+ payload = {
1355
+ "temperature_K": T,
1356
+ "pressure_atm": p_atm,
1357
+ "num_imag_freq": n_imag,
1358
+ "electronic_energy_ha": EE,
1359
+ "zpe_correction_ha": ZPE,
1360
+ "thermal_correction_energy_ha": dE_therm,
1361
+ "thermal_correction_enthalpy_ha": dH_therm,
1362
+ "thermal_correction_free_energy_ha": dG_therm,
1363
+ "sum_EE_and_ZPE_ha": sum_EE_ZPE,
1364
+ "sum_EE_and_thermal_energy_ha": sum_EE_thermal_E,
1365
+ "sum_EE_and_thermal_enthalpy_ha": sum_EE_thermal_H,
1366
+ "sum_EE_and_thermal_free_energy_ha": sum_EE_thermal_G,
1367
+ "E_thermal_cal_per_mol": E_thermal_cal,
1368
+ "Cv_cal_per_mol_K": Cv_cal_per_Kmol,
1369
+ "S_cal_per_mol_K": S_cal_per_Kmol,
1370
+ }
1371
+ with out_yaml.open("w", encoding="utf-8") as f:
1372
+ yaml.safe_dump(payload, f, sort_keys=False, allow_unicode=True)
1373
+ click.echo(f"[dump] Wrote thermoanalysis summary → {out_yaml}")
1374
+
1375
+ except ImportError:
1376
+ click.echo("[thermo] WARNING: 'thermoanalysis' package not found; skipped thermochemistry summary.", err=True)
1377
+ except Exception as e:
1378
+ import traceback
1379
+ tb = "".join(traceback.format_exception(type(e), e, e.__traceback__))
1380
+ click.echo("Unhandled error during thermochemistry summary:\n" + textwrap.indent(tb, " "), err=True)
1381
+
1382
+ # summary.md and key_* outputs are disabled.
1383
+ click.echo(f"[DONE] Wrote modes and list → {out_dir_path}")
1384
+
1385
+ click.echo(format_elapsed("[time] Elapsed Time for Freq", time_start))
1386
+
1387
+ except KeyboardInterrupt:
1388
+ click.echo("\nInterrupted by user.", err=True)
1389
+ sys.exit(130)
1390
+ except Exception as e:
1391
+ import traceback
1392
+ tb = "".join(traceback.format_exception(type(e), e, e.__traceback__))
1393
+ click.echo("Unhandled error during frequency analysis:\n" + textwrap.indent(tb, " "), err=True)
1394
+ sys.exit(1)
1395
+ finally:
1396
+ prepared_input.cleanup()
1397
+ # Release GPU memory so subsequent pipeline stages don't OOM
1398
+ geometry = H_t = modes = None
1399
+ gc.collect() # break cyclic refs inside torch.nn.Module
1400
+ if torch.cuda.is_available():
1401
+ torch.cuda.empty_cache()
1402
+
1403
+
1404
+ # Allow `python -m mlmm.freq` direct execution; `cli()` runs only when this module is the top-level script, not when it is imported.
1405
+ if __name__ == "__main__":
1406
+ cli()