mlmm-toolkit 0.2.2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (372) hide show
  1. hessian_ff/__init__.py +50 -0
  2. hessian_ff/analytical_hessian.py +609 -0
  3. hessian_ff/constants.py +46 -0
  4. hessian_ff/forcefield.py +339 -0
  5. hessian_ff/loaders.py +608 -0
  6. hessian_ff/native/Makefile +8 -0
  7. hessian_ff/native/__init__.py +28 -0
  8. hessian_ff/native/analytical_hessian.py +88 -0
  9. hessian_ff/native/analytical_hessian_ext.cpp +258 -0
  10. hessian_ff/native/bonded.py +82 -0
  11. hessian_ff/native/bonded_ext.cpp +640 -0
  12. hessian_ff/native/loader.py +349 -0
  13. hessian_ff/native/nonbonded.py +118 -0
  14. hessian_ff/native/nonbonded_ext.cpp +1150 -0
  15. hessian_ff/prmtop_parmed.py +23 -0
  16. hessian_ff/system.py +107 -0
  17. hessian_ff/terms/__init__.py +14 -0
  18. hessian_ff/terms/angle.py +73 -0
  19. hessian_ff/terms/bond.py +44 -0
  20. hessian_ff/terms/cmap.py +406 -0
  21. hessian_ff/terms/dihedral.py +141 -0
  22. hessian_ff/terms/nonbonded.py +209 -0
  23. hessian_ff/tests/__init__.py +0 -0
  24. hessian_ff/tests/conftest.py +75 -0
  25. hessian_ff/tests/data/small/complex.parm7 +1346 -0
  26. hessian_ff/tests/data/small/complex.pdb +125 -0
  27. hessian_ff/tests/data/small/complex.rst7 +63 -0
  28. hessian_ff/tests/test_coords_input.py +44 -0
  29. hessian_ff/tests/test_energy_force.py +49 -0
  30. hessian_ff/tests/test_hessian.py +137 -0
  31. hessian_ff/tests/test_smoke.py +18 -0
  32. hessian_ff/tests/test_validation.py +40 -0
  33. hessian_ff/workflows.py +889 -0
  34. mlmm/__init__.py +36 -0
  35. mlmm/__main__.py +7 -0
  36. mlmm/_version.py +34 -0
  37. mlmm/add_elem_info.py +374 -0
  38. mlmm/advanced_help.py +91 -0
  39. mlmm/align_freeze_atoms.py +601 -0
  40. mlmm/all.py +3535 -0
  41. mlmm/bond_changes.py +231 -0
  42. mlmm/bool_compat.py +223 -0
  43. mlmm/cli.py +574 -0
  44. mlmm/cli_utils.py +166 -0
  45. mlmm/default_group.py +337 -0
  46. mlmm/defaults.py +467 -0
  47. mlmm/define_layer.py +526 -0
  48. mlmm/dft.py +1041 -0
  49. mlmm/energy_diagram.py +253 -0
  50. mlmm/extract.py +2213 -0
  51. mlmm/fix_altloc.py +464 -0
  52. mlmm/freq.py +1406 -0
  53. mlmm/harmonic_constraints.py +140 -0
  54. mlmm/hessian_cache.py +44 -0
  55. mlmm/hessian_calc.py +174 -0
  56. mlmm/irc.py +638 -0
  57. mlmm/mlmm_calc.py +2262 -0
  58. mlmm/mm_parm.py +945 -0
  59. mlmm/oniom_export.py +1983 -0
  60. mlmm/oniom_import.py +457 -0
  61. mlmm/opt.py +1742 -0
  62. mlmm/path_opt.py +1353 -0
  63. mlmm/path_search.py +2299 -0
  64. mlmm/preflight.py +88 -0
  65. mlmm/py.typed +1 -0
  66. mlmm/pysis_runner.py +45 -0
  67. mlmm/scan.py +1047 -0
  68. mlmm/scan2d.py +1226 -0
  69. mlmm/scan3d.py +1265 -0
  70. mlmm/scan_common.py +184 -0
  71. mlmm/summary_log.py +736 -0
  72. mlmm/trj2fig.py +448 -0
  73. mlmm/tsopt.py +2871 -0
  74. mlmm/utils.py +2309 -0
  75. mlmm/xtb_embedcharge_correction.py +475 -0
  76. mlmm_toolkit-0.2.2.dev0.dist-info/METADATA +1159 -0
  77. mlmm_toolkit-0.2.2.dev0.dist-info/RECORD +372 -0
  78. mlmm_toolkit-0.2.2.dev0.dist-info/WHEEL +5 -0
  79. mlmm_toolkit-0.2.2.dev0.dist-info/entry_points.txt +2 -0
  80. mlmm_toolkit-0.2.2.dev0.dist-info/licenses/LICENSE +674 -0
  81. mlmm_toolkit-0.2.2.dev0.dist-info/top_level.txt +4 -0
  82. pysisyphus/Geometry.py +1667 -0
  83. pysisyphus/LICENSE +674 -0
  84. pysisyphus/TableFormatter.py +63 -0
  85. pysisyphus/TablePrinter.py +74 -0
  86. pysisyphus/__init__.py +12 -0
  87. pysisyphus/calculators/AFIR.py +452 -0
  88. pysisyphus/calculators/AnaPot.py +20 -0
  89. pysisyphus/calculators/AnaPot2.py +48 -0
  90. pysisyphus/calculators/AnaPot3.py +12 -0
  91. pysisyphus/calculators/AnaPot4.py +20 -0
  92. pysisyphus/calculators/AnaPotBase.py +337 -0
  93. pysisyphus/calculators/AnaPotCBM.py +25 -0
  94. pysisyphus/calculators/AtomAtomTransTorque.py +154 -0
  95. pysisyphus/calculators/CFOUR.py +250 -0
  96. pysisyphus/calculators/Calculator.py +844 -0
  97. pysisyphus/calculators/CerjanMiller.py +24 -0
  98. pysisyphus/calculators/Composite.py +123 -0
  99. pysisyphus/calculators/ConicalIntersection.py +171 -0
  100. pysisyphus/calculators/DFTBp.py +430 -0
  101. pysisyphus/calculators/DFTD3.py +66 -0
  102. pysisyphus/calculators/DFTD4.py +84 -0
  103. pysisyphus/calculators/Dalton.py +61 -0
  104. pysisyphus/calculators/Dimer.py +681 -0
  105. pysisyphus/calculators/Dummy.py +20 -0
  106. pysisyphus/calculators/EGO.py +76 -0
  107. pysisyphus/calculators/EnergyMin.py +224 -0
  108. pysisyphus/calculators/ExternalPotential.py +264 -0
  109. pysisyphus/calculators/FakeASE.py +35 -0
  110. pysisyphus/calculators/FourWellAnaPot.py +28 -0
  111. pysisyphus/calculators/FreeEndNEBPot.py +39 -0
  112. pysisyphus/calculators/Gaussian09.py +18 -0
  113. pysisyphus/calculators/Gaussian16.py +726 -0
  114. pysisyphus/calculators/HardSphere.py +159 -0
  115. pysisyphus/calculators/IDPPCalculator.py +49 -0
  116. pysisyphus/calculators/IPIClient.py +133 -0
  117. pysisyphus/calculators/IPIServer.py +234 -0
  118. pysisyphus/calculators/LEPSBase.py +24 -0
  119. pysisyphus/calculators/LEPSExpr.py +139 -0
  120. pysisyphus/calculators/LennardJones.py +80 -0
  121. pysisyphus/calculators/MOPAC.py +219 -0
  122. pysisyphus/calculators/MullerBrownSympyPot.py +51 -0
  123. pysisyphus/calculators/MultiCalc.py +85 -0
  124. pysisyphus/calculators/NFK.py +45 -0
  125. pysisyphus/calculators/OBabel.py +87 -0
  126. pysisyphus/calculators/ONIOMv2.py +1129 -0
  127. pysisyphus/calculators/ORCA.py +893 -0
  128. pysisyphus/calculators/ORCA5.py +6 -0
  129. pysisyphus/calculators/OpenMM.py +88 -0
  130. pysisyphus/calculators/OpenMolcas.py +281 -0
  131. pysisyphus/calculators/OverlapCalculator.py +908 -0
  132. pysisyphus/calculators/Psi4.py +218 -0
  133. pysisyphus/calculators/PyPsi4.py +37 -0
  134. pysisyphus/calculators/PySCF.py +341 -0
  135. pysisyphus/calculators/PyXTB.py +73 -0
  136. pysisyphus/calculators/QCEngine.py +106 -0
  137. pysisyphus/calculators/Rastrigin.py +22 -0
  138. pysisyphus/calculators/Remote.py +76 -0
  139. pysisyphus/calculators/Rosenbrock.py +15 -0
  140. pysisyphus/calculators/SocketCalc.py +97 -0
  141. pysisyphus/calculators/TIP3P.py +111 -0
  142. pysisyphus/calculators/TransTorque.py +161 -0
  143. pysisyphus/calculators/Turbomole.py +965 -0
  144. pysisyphus/calculators/VRIPot.py +37 -0
  145. pysisyphus/calculators/WFOWrapper.py +333 -0
  146. pysisyphus/calculators/WFOWrapper2.py +341 -0
  147. pysisyphus/calculators/XTB.py +418 -0
  148. pysisyphus/calculators/__init__.py +81 -0
  149. pysisyphus/calculators/cosmo_data.py +139 -0
  150. pysisyphus/calculators/parser.py +150 -0
  151. pysisyphus/color.py +19 -0
  152. pysisyphus/config.py +133 -0
  153. pysisyphus/constants.py +65 -0
  154. pysisyphus/cos/AdaptiveNEB.py +230 -0
  155. pysisyphus/cos/ChainOfStates.py +725 -0
  156. pysisyphus/cos/FreeEndNEB.py +25 -0
  157. pysisyphus/cos/FreezingString.py +103 -0
  158. pysisyphus/cos/GrowingChainOfStates.py +71 -0
  159. pysisyphus/cos/GrowingNT.py +309 -0
  160. pysisyphus/cos/GrowingString.py +508 -0
  161. pysisyphus/cos/NEB.py +189 -0
  162. pysisyphus/cos/SimpleZTS.py +64 -0
  163. pysisyphus/cos/__init__.py +22 -0
  164. pysisyphus/cos/stiffness.py +199 -0
  165. pysisyphus/drivers/__init__.py +17 -0
  166. pysisyphus/drivers/afir.py +855 -0
  167. pysisyphus/drivers/barriers.py +271 -0
  168. pysisyphus/drivers/birkholz.py +138 -0
  169. pysisyphus/drivers/cluster.py +318 -0
  170. pysisyphus/drivers/diabatization.py +133 -0
  171. pysisyphus/drivers/merge.py +368 -0
  172. pysisyphus/drivers/merge_mol2.py +322 -0
  173. pysisyphus/drivers/opt.py +375 -0
  174. pysisyphus/drivers/perf.py +91 -0
  175. pysisyphus/drivers/pka.py +52 -0
  176. pysisyphus/drivers/precon_pos_rot.py +669 -0
  177. pysisyphus/drivers/rates.py +480 -0
  178. pysisyphus/drivers/replace.py +219 -0
  179. pysisyphus/drivers/scan.py +212 -0
  180. pysisyphus/drivers/spectrum.py +166 -0
  181. pysisyphus/drivers/thermo.py +31 -0
  182. pysisyphus/dynamics/Gaussian.py +103 -0
  183. pysisyphus/dynamics/__init__.py +20 -0
  184. pysisyphus/dynamics/colvars.py +136 -0
  185. pysisyphus/dynamics/driver.py +297 -0
  186. pysisyphus/dynamics/helpers.py +256 -0
  187. pysisyphus/dynamics/lincs.py +105 -0
  188. pysisyphus/dynamics/mdp.py +364 -0
  189. pysisyphus/dynamics/rattle.py +121 -0
  190. pysisyphus/dynamics/thermostats.py +128 -0
  191. pysisyphus/dynamics/wigner.py +266 -0
  192. pysisyphus/elem_data.py +3473 -0
  193. pysisyphus/exceptions.py +2 -0
  194. pysisyphus/filtertrj.py +69 -0
  195. pysisyphus/helpers.py +623 -0
  196. pysisyphus/helpers_pure.py +649 -0
  197. pysisyphus/init_logging.py +50 -0
  198. pysisyphus/intcoords/Bend.py +69 -0
  199. pysisyphus/intcoords/Bend2.py +25 -0
  200. pysisyphus/intcoords/BondedFragment.py +32 -0
  201. pysisyphus/intcoords/Cartesian.py +41 -0
  202. pysisyphus/intcoords/CartesianCoords.py +140 -0
  203. pysisyphus/intcoords/Coords.py +56 -0
  204. pysisyphus/intcoords/DLC.py +197 -0
  205. pysisyphus/intcoords/DistanceFunction.py +34 -0
  206. pysisyphus/intcoords/DummyImproper.py +70 -0
  207. pysisyphus/intcoords/DummyTorsion.py +72 -0
  208. pysisyphus/intcoords/LinearBend.py +105 -0
  209. pysisyphus/intcoords/LinearDisplacement.py +80 -0
  210. pysisyphus/intcoords/OutOfPlane.py +59 -0
  211. pysisyphus/intcoords/PrimTypes.py +286 -0
  212. pysisyphus/intcoords/Primitive.py +137 -0
  213. pysisyphus/intcoords/RedundantCoords.py +659 -0
  214. pysisyphus/intcoords/RobustTorsion.py +59 -0
  215. pysisyphus/intcoords/Rotation.py +147 -0
  216. pysisyphus/intcoords/Stretch.py +31 -0
  217. pysisyphus/intcoords/Torsion.py +101 -0
  218. pysisyphus/intcoords/Torsion2.py +25 -0
  219. pysisyphus/intcoords/Translation.py +45 -0
  220. pysisyphus/intcoords/__init__.py +61 -0
  221. pysisyphus/intcoords/augment_bonds.py +126 -0
  222. pysisyphus/intcoords/derivatives.py +10512 -0
  223. pysisyphus/intcoords/eval.py +80 -0
  224. pysisyphus/intcoords/exceptions.py +37 -0
  225. pysisyphus/intcoords/findiffs.py +48 -0
  226. pysisyphus/intcoords/generate_derivatives.py +414 -0
  227. pysisyphus/intcoords/helpers.py +235 -0
  228. pysisyphus/intcoords/logging_conf.py +10 -0
  229. pysisyphus/intcoords/mp_derivatives.py +10836 -0
  230. pysisyphus/intcoords/setup.py +962 -0
  231. pysisyphus/intcoords/setup_fast.py +176 -0
  232. pysisyphus/intcoords/update.py +272 -0
  233. pysisyphus/intcoords/valid.py +89 -0
  234. pysisyphus/interpolate/Geodesic.py +93 -0
  235. pysisyphus/interpolate/IDPP.py +55 -0
  236. pysisyphus/interpolate/Interpolator.py +116 -0
  237. pysisyphus/interpolate/LST.py +70 -0
  238. pysisyphus/interpolate/Redund.py +152 -0
  239. pysisyphus/interpolate/__init__.py +9 -0
  240. pysisyphus/interpolate/helpers.py +34 -0
  241. pysisyphus/io/__init__.py +22 -0
  242. pysisyphus/io/aomix.py +178 -0
  243. pysisyphus/io/cjson.py +24 -0
  244. pysisyphus/io/crd.py +101 -0
  245. pysisyphus/io/cube.py +220 -0
  246. pysisyphus/io/fchk.py +184 -0
  247. pysisyphus/io/hdf5.py +49 -0
  248. pysisyphus/io/hessian.py +72 -0
  249. pysisyphus/io/mol2.py +146 -0
  250. pysisyphus/io/molden.py +293 -0
  251. pysisyphus/io/orca.py +189 -0
  252. pysisyphus/io/pdb.py +269 -0
  253. pysisyphus/io/psf.py +79 -0
  254. pysisyphus/io/pubchem.py +31 -0
  255. pysisyphus/io/qcschema.py +34 -0
  256. pysisyphus/io/sdf.py +29 -0
  257. pysisyphus/io/xyz.py +61 -0
  258. pysisyphus/io/zmat.py +175 -0
  259. pysisyphus/irc/DWI.py +108 -0
  260. pysisyphus/irc/DampedVelocityVerlet.py +134 -0
  261. pysisyphus/irc/Euler.py +22 -0
  262. pysisyphus/irc/EulerPC.py +345 -0
  263. pysisyphus/irc/GonzalezSchlegel.py +187 -0
  264. pysisyphus/irc/IMKMod.py +164 -0
  265. pysisyphus/irc/IRC.py +878 -0
  266. pysisyphus/irc/IRCDummy.py +10 -0
  267. pysisyphus/irc/Instanton.py +307 -0
  268. pysisyphus/irc/LQA.py +53 -0
  269. pysisyphus/irc/ModeKill.py +136 -0
  270. pysisyphus/irc/ParamPlot.py +53 -0
  271. pysisyphus/irc/RK4.py +36 -0
  272. pysisyphus/irc/__init__.py +31 -0
  273. pysisyphus/irc/initial_displ.py +219 -0
  274. pysisyphus/linalg.py +411 -0
  275. pysisyphus/line_searches/Backtracking.py +88 -0
  276. pysisyphus/line_searches/HagerZhang.py +184 -0
  277. pysisyphus/line_searches/LineSearch.py +232 -0
  278. pysisyphus/line_searches/StrongWolfe.py +108 -0
  279. pysisyphus/line_searches/__init__.py +9 -0
  280. pysisyphus/line_searches/interpol.py +15 -0
  281. pysisyphus/modefollow/NormalMode.py +40 -0
  282. pysisyphus/modefollow/__init__.py +10 -0
  283. pysisyphus/modefollow/davidson.py +199 -0
  284. pysisyphus/modefollow/lanczos.py +95 -0
  285. pysisyphus/optimizers/BFGS.py +99 -0
  286. pysisyphus/optimizers/BacktrackingOptimizer.py +113 -0
  287. pysisyphus/optimizers/ConjugateGradient.py +98 -0
  288. pysisyphus/optimizers/CubicNewton.py +75 -0
  289. pysisyphus/optimizers/FIRE.py +113 -0
  290. pysisyphus/optimizers/HessianOptimizer.py +1176 -0
  291. pysisyphus/optimizers/LBFGS.py +228 -0
  292. pysisyphus/optimizers/LayerOpt.py +411 -0
  293. pysisyphus/optimizers/MicroOptimizer.py +169 -0
  294. pysisyphus/optimizers/NCOptimizer.py +90 -0
  295. pysisyphus/optimizers/Optimizer.py +1084 -0
  296. pysisyphus/optimizers/PreconLBFGS.py +260 -0
  297. pysisyphus/optimizers/PreconSteepestDescent.py +7 -0
  298. pysisyphus/optimizers/QuickMin.py +74 -0
  299. pysisyphus/optimizers/RFOptimizer.py +181 -0
  300. pysisyphus/optimizers/RSA.py +99 -0
  301. pysisyphus/optimizers/StabilizedQNMethod.py +248 -0
  302. pysisyphus/optimizers/SteepestDescent.py +23 -0
  303. pysisyphus/optimizers/StringOptimizer.py +173 -0
  304. pysisyphus/optimizers/__init__.py +41 -0
  305. pysisyphus/optimizers/closures.py +301 -0
  306. pysisyphus/optimizers/cls_map.py +58 -0
  307. pysisyphus/optimizers/exceptions.py +6 -0
  308. pysisyphus/optimizers/gdiis.py +280 -0
  309. pysisyphus/optimizers/guess_hessians.py +311 -0
  310. pysisyphus/optimizers/hessian_updates.py +355 -0
  311. pysisyphus/optimizers/poly_fit.py +285 -0
  312. pysisyphus/optimizers/precon.py +153 -0
  313. pysisyphus/optimizers/restrict_step.py +24 -0
  314. pysisyphus/pack.py +172 -0
  315. pysisyphus/peakdetect.py +948 -0
  316. pysisyphus/plot.py +1031 -0
  317. pysisyphus/run.py +2106 -0
  318. pysisyphus/socket_helper.py +74 -0
  319. pysisyphus/stocastic/FragmentKick.py +132 -0
  320. pysisyphus/stocastic/Kick.py +81 -0
  321. pysisyphus/stocastic/Pipeline.py +303 -0
  322. pysisyphus/stocastic/__init__.py +21 -0
  323. pysisyphus/stocastic/align.py +127 -0
  324. pysisyphus/testing.py +96 -0
  325. pysisyphus/thermo.py +156 -0
  326. pysisyphus/trj.py +824 -0
  327. pysisyphus/tsoptimizers/RSIRFOptimizer.py +56 -0
  328. pysisyphus/tsoptimizers/RSPRFOptimizer.py +182 -0
  329. pysisyphus/tsoptimizers/TRIM.py +59 -0
  330. pysisyphus/tsoptimizers/TSHessianOptimizer.py +463 -0
  331. pysisyphus/tsoptimizers/__init__.py +23 -0
  332. pysisyphus/wavefunction/Basis.py +239 -0
  333. pysisyphus/wavefunction/DIIS.py +76 -0
  334. pysisyphus/wavefunction/__init__.py +25 -0
  335. pysisyphus/wavefunction/build_ext.py +42 -0
  336. pysisyphus/wavefunction/cart2sph.py +190 -0
  337. pysisyphus/wavefunction/diabatization.py +304 -0
  338. pysisyphus/wavefunction/excited_states.py +435 -0
  339. pysisyphus/wavefunction/gen_ints.py +1811 -0
  340. pysisyphus/wavefunction/helpers.py +104 -0
  341. pysisyphus/wavefunction/ints/__init__.py +0 -0
  342. pysisyphus/wavefunction/ints/boys.py +193 -0
  343. pysisyphus/wavefunction/ints/boys_table_N_64_xasym_27.1_step_0.01.npy +0 -0
  344. pysisyphus/wavefunction/ints/cart_gto3d.py +176 -0
  345. pysisyphus/wavefunction/ints/coulomb3d.py +25928 -0
  346. pysisyphus/wavefunction/ints/diag_quadrupole3d.py +10036 -0
  347. pysisyphus/wavefunction/ints/dipole3d.py +8762 -0
  348. pysisyphus/wavefunction/ints/int2c2e3d.py +7198 -0
  349. pysisyphus/wavefunction/ints/int3c2e3d_sph.py +65040 -0
  350. pysisyphus/wavefunction/ints/kinetic3d.py +8240 -0
  351. pysisyphus/wavefunction/ints/ovlp3d.py +3777 -0
  352. pysisyphus/wavefunction/ints/quadrupole3d.py +15054 -0
  353. pysisyphus/wavefunction/ints/self_ovlp3d.py +198 -0
  354. pysisyphus/wavefunction/localization.py +458 -0
  355. pysisyphus/wavefunction/multipole.py +159 -0
  356. pysisyphus/wavefunction/normalization.py +36 -0
  357. pysisyphus/wavefunction/pop_analysis.py +134 -0
  358. pysisyphus/wavefunction/shells.py +1171 -0
  359. pysisyphus/wavefunction/wavefunction.py +504 -0
  360. pysisyphus/wrapper/__init__.py +11 -0
  361. pysisyphus/wrapper/exceptions.py +2 -0
  362. pysisyphus/wrapper/jmol.py +120 -0
  363. pysisyphus/wrapper/mwfn.py +169 -0
  364. pysisyphus/wrapper/packmol.py +71 -0
  365. pysisyphus/xyzloader.py +168 -0
  366. pysisyphus/yaml_mods.py +45 -0
  367. thermoanalysis/LICENSE +674 -0
  368. thermoanalysis/QCData.py +244 -0
  369. thermoanalysis/__init__.py +0 -0
  370. thermoanalysis/config.py +3 -0
  371. thermoanalysis/constants.py +20 -0
  372. thermoanalysis/thermo.py +1011 -0
hessian_ff/loaders.py ADDED
@@ -0,0 +1,608 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from pathlib import Path
5
+ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
6
+
7
+ import torch
8
+
9
+ from .prmtop_parmed import read_prmtop_with_parmed
10
+ from .system import AmberSystem
11
+ from .terms.cmap import _calc_map_coefficients
12
+
13
+
14
+ def _chunk(seq: Sequence, n: int) -> Iterable[Tuple]:
15
+ for i in range(0, len(seq), n):
16
+ yield tuple(seq[i : i + n])
17
+
18
+
19
+ def _atom_index_from_coord_index(coord_index: int) -> int:
20
+ """prmtop stores atom references as (atom_index * 3) (0-based) with sign flags."""
21
+ return abs(coord_index) // 3
22
+
23
+
24
+ def _infer_ntypes(nonbonded_parm_index: Sequence[int]) -> int:
25
+ n = len(nonbonded_parm_index)
26
+ rt = int(round(math.sqrt(n)))
27
+ if rt * rt != n:
28
+ raise ValueError(f"Cannot infer NTYPES: NONBONDED_PARM_INDEX length {n} is not a square")
29
+ return rt
30
+
31
+
32
+ def _build_excluded_pair_keys(
33
+ natom: int, num_excluded: Sequence[int], excluded_list: Sequence[int]
34
+ ) -> torch.Tensor:
35
+ """Reconstruct excluded pair keys from NUMBER_EXCLUDED_ATOMS + EXCLUDED_ATOMS_LIST.
36
+
37
+ Returns a 1D tensor of keys = i*natom + j for i<j.
38
+ """
39
+ keys: List[int] = []
40
+ cursor = 0
41
+ for i in range(natom):
42
+ n = int(num_excluded[i])
43
+ partners = excluded_list[cursor : cursor + n]
44
+ cursor += n
45
+ for j1 in partners:
46
+ j1 = int(j1)
47
+ if j1 <= 0:
48
+ continue
49
+ j = j1 - 1 # prmtop stores 1-based atom numbers
50
+ if j == i:
51
+ continue
52
+ a, b = (i, j) if i < j else (j, i)
53
+ keys.append(a * natom + b)
54
+ if not keys:
55
+ return torch.zeros((0,), dtype=torch.int64)
56
+ # unique
57
+ return torch.unique(torch.tensor(keys, dtype=torch.int64))
58
+
59
+
60
def _build_14_pairs(
    natom: int,
    dihedrals_inc_h: Sequence[int],
    dihedrals_wo_h: Sequence[int],
    scee: Sequence[float],
    scnb: Sequence[float],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Extract unique 1-4 pairs and their scaling from dihedral lists.

    Each dihedral list is a flat sequence of 5-tuples
    ``(a1, a2, a3, a4, type)``. The ``a*`` entries are coordinate pointers
    (``atom_index * 3``, possibly negated). ``type`` is a 1-based index into
    ``scee`` / ``scnb``.

    Per AMBER prmtop conventions:
      - If the atom3 pointer is negative => do not compute 1-4 for that torsion
      - If the atom4 pointer is negative => improper torsion (no 1-4 either)

    1-4 pairs are therefore generated only for entries with atom3 >= 0 and
    atom4 >= 0. A scale factor of 0 is treated as "no scaling" (inverse 1.0).

    Returns
    -------
    (i, j, inv_scee, inv_scnb)
        int64 atom indices with ``i < j`` and the matching float64 inverse
        scale factors, ordered by the pair key ``i * natom + j``.

    Raises
    ------
    ValueError
        If a contributing dihedral has a type index outside ``scee``/``scnb``
        (including type 0), or if two dihedrals assign different scaling to
        the same pair.
    """

    def inverse_scale(value: float) -> float:
        # 0 means "unscaled" here rather than a division by zero.
        factor = float(value)
        return 1.0 / factor if factor != 0 else 1.0

    def add_from(dihed_list: Sequence[int], out: Dict[int, Tuple[float, float]]) -> None:
        for a1, _a2, a3, a4, itype in _chunk(dihed_list, 5):
            if int(a3) < 0:
                continue  # explicitly no 1-4
            if int(a4) < 0:
                continue  # improper; do not generate 1-4
            i = _atom_index_from_coord_index(int(a1))
            j = _atom_index_from_coord_index(int(a4))
            if i == j:
                continue
            it = int(itype) - 1  # Fortran 1-based
            # Type 0 would give it == -1, which Python indexing silently
            # wraps to the last entry; reject it and any out-of-range type.
            if it < 0 or it >= len(scee) or it >= len(scnb):
                raise ValueError(
                    f"Invalid dihedral type index {int(itype)} (1-based) for 1-4 pair; "
                    f"have {len(scee)} SCEE and {len(scnb)} SCNB entries"
                )
            a, b = (i, j) if i < j else (j, i)
            key = a * natom + b
            scales = (inverse_scale(scee[it]), inverse_scale(scnb[it]))
            prev = out.get(key)
            if prev is None:
                out[key] = scales
            elif (abs(prev[0] - scales[0]) > 1e-12) or (abs(prev[1] - scales[1]) > 1e-12):
                # If multiple terms point to the same 1-4 pair, they should agree.
                raise ValueError(
                    f"Inconsistent 1-4 scaling for pair ({a},{b}): {prev} vs {scales}"
                )

    mapping: Dict[int, Tuple[float, float]] = {}
    add_from(dihedrals_inc_h, mapping)
    add_from(dihedrals_wo_h, mapping)

    if not mapping:
        # Separate (non-aliased) empty tensors for each output.
        return (
            torch.zeros((0,), dtype=torch.int64),
            torch.zeros((0,), dtype=torch.int64),
            torch.zeros((0,), dtype=torch.float64),
            torch.zeros((0,), dtype=torch.float64),
        )

    keys = sorted(mapping)
    i = torch.tensor([k // natom for k in keys], dtype=torch.int64)
    j = torch.tensor([k % natom for k in keys], dtype=torch.int64)
    inv_scee = torch.tensor([mapping[k][0] for k in keys], dtype=torch.float64)
    inv_scnb = torch.tensor([mapping[k][1] for k in keys], dtype=torch.float64)
    return i, j, inv_scee, inv_scnb
120
+
121
+
122
+ def _build_general_pairs(
123
+ natom: int, excluded_keys: torch.Tensor, keys14: torch.Tensor
124
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
125
+ """Build all i<j pairs excluding excluded_keys and 1-4 keys."""
126
+ ti, tj = torch.triu_indices(natom, natom, offset=1)
127
+ keys = ti * natom + tj
128
+ mask = torch.ones_like(keys, dtype=torch.bool)
129
+ if excluded_keys.numel() > 0:
130
+ mask &= ~torch.isin(keys, excluded_keys)
131
+ if keys14.numel() > 0:
132
+ mask &= ~torch.isin(keys, keys14)
133
+ return ti[mask], tj[mask]
134
+
135
+
136
+ def _parse_cmap(
137
+ raw: Dict[str, List[Union[int, float, str]]]
138
+ ) -> Tuple[
139
+ torch.Tensor,
140
+ torch.Tensor,
141
+ torch.Tensor,
142
+ torch.Tensor,
143
+ torch.Tensor,
144
+ torch.Tensor,
145
+ Tuple[int, ...],
146
+ Tuple[torch.Tensor, ...],
147
+ ]:
148
+ """Parse CMAP flags and return tensors in OpenMM-compatible ordering.
149
+
150
+ Returns
151
+ -------
152
+ cmap_type, cmap_i, cmap_j, cmap_k, cmap_l, cmap_m, cmap_resolution, cmap_maps
153
+ """
154
+
155
+ count_flag = "CMAP_COUNT" if "CMAP_COUNT" in raw else "CHARMM_CMAP_COUNT"
156
+ if count_flag not in raw:
157
+ z = torch.zeros((0,), dtype=torch.int64)
158
+ return z, z, z, z, z, z, tuple(), tuple()
159
+
160
+ count_vals = raw.get(count_flag, [])
161
+ # AMBER stores two ints; OpenMM reads the second.
162
+ nmap = int(count_vals[1]) if len(count_vals) > 1 else int(count_vals[0])
163
+ if nmap <= 0:
164
+ z = torch.zeros((0,), dtype=torch.int64)
165
+ return z, z, z, z, z, z, tuple(), tuple()
166
+
167
+ res_flag = "CMAP_RESOLUTION" if "CMAP_RESOLUTION" in raw else "CHARMM_CMAP_RESOLUTION"
168
+ if res_flag not in raw:
169
+ raise ValueError("CMAP_COUNT is present but CMAP_RESOLUTION is missing")
170
+ res_vals = raw.get(res_flag, [])
171
+ if len(res_vals) < nmap:
172
+ raise ValueError(
173
+ f"CMAP_RESOLUTION too short: need {nmap}, got {len(res_vals)}"
174
+ )
175
+ cmap_resolution = tuple(int(x) for x in res_vals[:nmap])
176
+
177
+ cmap_maps: List[torch.Tensor] = []
178
+ for midx, ngrid in enumerate(cmap_resolution, start=1):
179
+ key = f"CMAP_PARAMETER_{midx:02d}"
180
+ if key not in raw:
181
+ key = f"CHARMM_CMAP_PARAMETER_{midx:02d}"
182
+ if key not in raw:
183
+ raise ValueError(f"Missing {key} for CMAP map {midx}")
184
+ vals = [float(x) for x in raw[key]]
185
+ need = ngrid * ngrid
186
+ if len(vals) < need:
187
+ raise ValueError(
188
+ f"{key} too short: need {need} values for {ngrid}x{ngrid}, got {len(vals)}"
189
+ )
190
+ vals = vals[:need]
191
+
192
+ # Match OpenMM amber_file_parser readAmberSystem() remapping.
193
+ remapped: List[float] = []
194
+ for i in range(ngrid):
195
+ for j in range(ngrid):
196
+ src = ngrid * ((j + ngrid // 2) % ngrid) + ((i + ngrid // 2) % ngrid)
197
+ remapped.append(vals[src])
198
+ cmap_maps.append(torch.tensor(remapped, dtype=torch.float64))
199
+
200
+ index_flag = "CMAP_INDEX" if "CMAP_INDEX" in raw else "CHARMM_CMAP_INDEX"
201
+ if index_flag not in raw:
202
+ raise ValueError("CMAP_COUNT is present but CMAP_INDEX is missing")
203
+ cmap_idx = raw.get(index_flag, [])
204
+ if len(cmap_idx) % 6 != 0:
205
+ raise ValueError(f"{index_flag} length must be multiple of 6, got {len(cmap_idx)}")
206
+
207
+ cmap_type: List[int] = []
208
+ cmap_i: List[int] = []
209
+ cmap_j: List[int] = []
210
+ cmap_k: List[int] = []
211
+ cmap_l: List[int] = []
212
+ cmap_m: List[int] = []
213
+ for a1, a2, a3, a4, a5, itype in _chunk(cmap_idx, 6):
214
+ a1 = int(a1)
215
+ a2 = int(a2)
216
+ a3 = int(a3)
217
+ a4 = int(a4)
218
+ a5 = int(a5)
219
+ t = int(itype) - 1
220
+ if min(a1, a2, a3, a4, a5) <= 0:
221
+ raise ValueError(
222
+ f"Invalid CMAP atom pointer ({a1}, {a2}, {a3}, {a4}, {a5}); expected 1-based positive"
223
+ )
224
+ if t < 0 or t >= nmap:
225
+ raise ValueError(f"Invalid CMAP map type index {int(itype)} (1-based), nmap={nmap}")
226
+ cmap_type.append(t)
227
+ # CMAP_INDEX stores 1-based atom indices (not coord-index*3)
228
+ cmap_i.append(a1 - 1)
229
+ cmap_j.append(a2 - 1)
230
+ cmap_k.append(a3 - 1)
231
+ cmap_l.append(a4 - 1)
232
+ cmap_m.append(a5 - 1)
233
+
234
+ return (
235
+ torch.tensor(cmap_type, dtype=torch.int64),
236
+ torch.tensor(cmap_i, dtype=torch.int64),
237
+ torch.tensor(cmap_j, dtype=torch.int64),
238
+ torch.tensor(cmap_k, dtype=torch.int64),
239
+ torch.tensor(cmap_l, dtype=torch.int64),
240
+ torch.tensor(cmap_m, dtype=torch.int64),
241
+ cmap_resolution,
242
+ tuple(cmap_maps),
243
+ )
244
+
245
+
246
+ def _build_cmap_coeff_data(
247
+ cmap_resolution: Tuple[int, ...],
248
+ cmap_maps: Tuple[torch.Tensor, ...],
249
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
250
+ """Precompute CMAP bicubic coefficient tables once at load time."""
251
+ if len(cmap_resolution) == 0:
252
+ z_i64 = torch.zeros((0,), dtype=torch.int64)
253
+ z_f64 = torch.zeros((0,), dtype=torch.float64)
254
+ z_coeff = torch.zeros((0, 16), dtype=torch.float64)
255
+ return z_i64, z_f64, z_i64, z_coeff
256
+
257
+ map_size = torch.tensor([int(x) for x in cmap_resolution], dtype=torch.int64)
258
+ coeff_parts: List[torch.Tensor] = []
259
+ offsets: List[int] = []
260
+ cursor = 0
261
+ for ngrid, table in zip(cmap_resolution, cmap_maps):
262
+ n = int(ngrid)
263
+ offsets.append(cursor)
264
+ c = _calc_map_coefficients(table, size=n)
265
+ coeff_parts.append(c)
266
+ cursor += n * n
267
+
268
+ map_delta = (torch.full_like(map_size, 2.0 * math.pi, dtype=torch.float64) / map_size.to(torch.float64))
269
+ map_offset = torch.tensor(offsets, dtype=torch.int64)
270
+ map_coeff = torch.cat(coeff_parts, dim=0) if coeff_parts else torch.zeros((0, 16), dtype=torch.float64)
271
+ return map_size, map_delta, map_offset, map_coeff
272
+
273
+
274
+ def load_system(
275
+ prmtop_path: Union[str, Path],
276
+ device: Union[str, torch.device] = "cpu",
277
+ ) -> AmberSystem:
278
+ """Load an AMBER parm7/prmtop into an :class:`~hessian_ff.system.AmberSystem`.
279
+
280
+ Parameters
281
+ ----------
282
+ prmtop_path:
283
+ Path to .prmtop/.parm7
284
+ device:
285
+ Target tensor device.
286
+ """
287
+ prmtop_path = Path(prmtop_path)
288
+
289
+ raw = read_prmtop_with_parmed(prmtop_path)
290
+
291
+ # ---- required sections ----
292
+ charge = raw["CHARGE"]
293
+ natom = len(charge)
294
+ atom_type_index = raw["ATOM_TYPE_INDEX"]
295
+
296
+ nbpi = raw["NONBONDED_PARM_INDEX"]
297
+ ntypes = _infer_ntypes(nbpi)
298
+ # Keep raw signed NONBONDED_PARM_INDEX values:
299
+ # >0 : index into Lennard-Jones A/B tables (Fortran 1-based)
300
+ # <0 : index into HBOND A/B tables (Fortran -1-based)
301
+ nb_index = torch.tensor(nbpi, dtype=torch.int64).reshape(ntypes, ntypes)
302
+
303
+ lj_acoef = torch.tensor(raw["LENNARD_JONES_ACOEF"], dtype=torch.float64)
304
+ lj_bcoef = torch.tensor(raw["LENNARD_JONES_BCOEF"], dtype=torch.float64)
305
+ hb_acoef = torch.tensor(raw.get("HBOND_ACOEF", []), dtype=torch.float64)
306
+ hb_bcoef = torch.tensor(raw.get("HBOND_BCOEF", []), dtype=torch.float64)
307
+ atom_type_t = torch.tensor(atom_type_index, dtype=torch.int64) - 1
308
+
309
+ # ---- bonds ----
310
+ bonds = list(_chunk(raw.get("BONDS_INC_HYDROGEN", []), 3)) + list(
311
+ _chunk(raw.get("BONDS_WITHOUT_HYDROGEN", []), 3)
312
+ )
313
+ bond_k_list = raw.get("BOND_FORCE_CONSTANT", [])
314
+ bond_r0_list = raw.get("BOND_EQUIL_VALUE", [])
315
+ bond_i = []
316
+ bond_j = []
317
+ bond_k = []
318
+ bond_r0 = []
319
+ for a1, a2, itype in bonds:
320
+ i = _atom_index_from_coord_index(int(a1))
321
+ j = _atom_index_from_coord_index(int(a2))
322
+ t = int(itype) - 1
323
+ bond_i.append(i)
324
+ bond_j.append(j)
325
+ bond_k.append(float(bond_k_list[t]))
326
+ bond_r0.append(float(bond_r0_list[t]))
327
+
328
+ # ---- angles ----
329
+ angles = list(_chunk(raw.get("ANGLES_INC_HYDROGEN", []), 4)) + list(
330
+ _chunk(raw.get("ANGLES_WITHOUT_HYDROGEN", []), 4)
331
+ )
332
+ ang_k_list = raw.get("ANGLE_FORCE_CONSTANT", [])
333
+ ang_t0_list = raw.get("ANGLE_EQUIL_VALUE", [])
334
+ angle_i = []
335
+ angle_j = []
336
+ angle_k = []
337
+ angle_k0 = []
338
+ angle_t0 = []
339
+ for a1, a2, a3, itype in angles:
340
+ i = _atom_index_from_coord_index(int(a1))
341
+ j = _atom_index_from_coord_index(int(a2))
342
+ k = _atom_index_from_coord_index(int(a3))
343
+ t = int(itype) - 1
344
+ angle_i.append(i)
345
+ angle_j.append(j)
346
+ angle_k.append(k)
347
+ angle_k0.append(float(ang_k_list[t]))
348
+ angle_t0.append(float(ang_t0_list[t]))
349
+
350
+ # ---- dihedrals ----
351
+ dihedrals = list(_chunk(raw.get("DIHEDRALS_INC_HYDROGEN", []), 5)) + list(
352
+ _chunk(raw.get("DIHEDRALS_WITHOUT_HYDROGEN", []), 5)
353
+ )
354
+ dih_force_list = raw.get("DIHEDRAL_FORCE_CONSTANT", [])
355
+ dih_phase_list = raw.get("DIHEDRAL_PHASE", [])
356
+ dih_per_list = raw.get("DIHEDRAL_PERIODICITY", [])
357
+
358
+ dihed_i = []
359
+ dihed_j = []
360
+ dihed_k = []
361
+ dihed_l = []
362
+ dihed_force = []
363
+ dihed_period = []
364
+ dihed_phase = []
365
+ for a1, a2, a3, a4, itype in dihedrals:
366
+ t = int(itype) - 1
367
+ dihed_i.append(_atom_index_from_coord_index(int(a1)))
368
+ dihed_j.append(_atom_index_from_coord_index(int(a2)))
369
+ dihed_k.append(_atom_index_from_coord_index(int(a3)))
370
+ dihed_l.append(_atom_index_from_coord_index(int(a4)))
371
+ dihed_force.append(float(dih_force_list[t]))
372
+ dihed_period.append(float(dih_per_list[t]))
373
+ dihed_phase.append(float(dih_phase_list[t]))
374
+
375
+ # ---- exclusions & 1-4 ----
376
+ excluded_keys = _build_excluded_pair_keys(
377
+ natom,
378
+ raw.get("NUMBER_EXCLUDED_ATOMS", [0] * natom),
379
+ raw.get("EXCLUDED_ATOMS_LIST", []),
380
+ )
381
+
382
+ scee = raw.get("SCEE_SCALE_FACTOR", [])
383
+ scnb = raw.get("SCNB_SCALE_FACTOR", [])
384
+ if not scee or not scnb:
385
+ # Defaults: AMBER typically uses 1.2 and 2.0 (i.e. scale 1/1.2 and 1/2.0)
386
+ # but in prmtop these should exist. If absent, fall back to "no scaling".
387
+ scee = [1.0] * len(dih_force_list)
388
+ scnb = [1.0] * len(dih_force_list)
389
+
390
+ d_inc = raw.get("DIHEDRALS_INC_HYDROGEN", [])
391
+ d_wo = raw.get("DIHEDRALS_WITHOUT_HYDROGEN", [])
392
+ p14_i, p14_j, inv_scee, inv_scnb = _build_14_pairs(natom, d_inc, d_wo, scee, scnb)
393
+ keys14 = p14_i * natom + p14_j
394
+
395
+ pair_i, pair_j = _build_general_pairs(natom, excluded_keys, keys14)
396
+ (
397
+ cmap_type,
398
+ cmap_i,
399
+ cmap_j,
400
+ cmap_k,
401
+ cmap_l,
402
+ cmap_m,
403
+ cmap_resolution,
404
+ cmap_maps,
405
+ ) = _parse_cmap(raw)
406
+ cmap_size, cmap_delta, cmap_offset, cmap_coeff = _build_cmap_coeff_data(
407
+ cmap_resolution=cmap_resolution,
408
+ cmap_maps=cmap_maps,
409
+ )
410
+
411
+ device = torch.device(device)
412
+
413
+ return AmberSystem(
414
+ natom=natom,
415
+ charge=torch.tensor(charge, dtype=torch.float64, device=device),
416
+ atom_type=atom_type_t.to(device),
417
+ bond_i=torch.tensor(bond_i, dtype=torch.int64, device=device),
418
+ bond_j=torch.tensor(bond_j, dtype=torch.int64, device=device),
419
+ bond_k=torch.tensor(bond_k, dtype=torch.float64, device=device),
420
+ bond_r0=torch.tensor(bond_r0, dtype=torch.float64, device=device),
421
+ angle_i=torch.tensor(angle_i, dtype=torch.int64, device=device),
422
+ angle_j=torch.tensor(angle_j, dtype=torch.int64, device=device),
423
+ angle_k=torch.tensor(angle_k, dtype=torch.int64, device=device),
424
+ angle_k0=torch.tensor(angle_k0, dtype=torch.float64, device=device),
425
+ angle_t0=torch.tensor(angle_t0, dtype=torch.float64, device=device),
426
+ dihed_i=torch.tensor(dihed_i, dtype=torch.int64, device=device),
427
+ dihed_j=torch.tensor(dihed_j, dtype=torch.int64, device=device),
428
+ dihed_k=torch.tensor(dihed_k, dtype=torch.int64, device=device),
429
+ dihed_l=torch.tensor(dihed_l, dtype=torch.int64, device=device),
430
+ dihed_force=torch.tensor(dihed_force, dtype=torch.float64, device=device),
431
+ dihed_period=torch.tensor(dihed_period, dtype=torch.float64, device=device),
432
+ dihed_phase=torch.tensor(dihed_phase, dtype=torch.float64, device=device),
433
+ lj_acoef=lj_acoef.to(device),
434
+ lj_bcoef=lj_bcoef.to(device),
435
+ hb_acoef=hb_acoef.to(device),
436
+ hb_bcoef=hb_bcoef.to(device),
437
+ nb_index=nb_index.to(device),
438
+ pair_i=pair_i.to(device),
439
+ pair_j=pair_j.to(device),
440
+ pair14_i=p14_i.to(device),
441
+ pair14_j=p14_j.to(device),
442
+ pair14_inv_scee=inv_scee.to(device),
443
+ pair14_inv_scnb=inv_scnb.to(device),
444
+ cmap_type=cmap_type.to(device),
445
+ cmap_i=cmap_i.to(device),
446
+ cmap_j=cmap_j.to(device),
447
+ cmap_k=cmap_k.to(device),
448
+ cmap_l=cmap_l.to(device),
449
+ cmap_m=cmap_m.to(device),
450
+ cmap_resolution=cmap_resolution,
451
+ cmap_maps=tuple(x.to(device) for x in cmap_maps),
452
+ cmap_size=cmap_size.to(device),
453
+ cmap_delta=cmap_delta.to(device),
454
+ cmap_offset=cmap_offset.to(device),
455
+ cmap_coeff=cmap_coeff.to(device),
456
+ )
457
+
458
+
459
def load_coords(
    coords: Union[str, Path, torch.Tensor, Any],
    natom: Optional[int] = None,
    device: Union[str, torch.device] = "cpu",
    dtype: torch.dtype = torch.float64,
) -> torch.Tensor:
    """Load coordinates (Angstrom) into an ``[N, 3]`` tensor.

    Accepted ``coords``:
      * ``torch.Tensor`` of shape ``[N, 3]``: detached, then moved/cast to
        ``device``/``dtype``;
      * any object exposing a callable ``get_positions()`` (duck-typed ASE
        ``Atoms``);
      * a file path, dispatched on its lower-cased suffix:
        ``.rst7``/``.inpcrd``/``.crd``/``.restrt`` (Amber ASCII restart/inpcrd),
        ``.pdb`` (ATOM/HETATM X,Y,Z columns) or ``.xyz``.

    Args:
        coords: Coordinate source (see above).
        natom: When given, the number of coordinate rows must equal it.
        device: Device of the returned tensor.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``[N, 3]`` on ``device`` with ``dtype``.

    Raises:
        ValueError: bad in-memory shape, row count differing from ``natom``,
            or an unsupported file suffix.
    """
    target = torch.device(device)

    def _validated(arr: torch.Tensor, what: str) -> torch.Tensor:
        # Shape and atom-count checks shared by the tensor and ASE inputs.
        if arr.ndim != 2 or int(arr.shape[1]) != 3:
            raise ValueError(f"{what} must have shape [N,3], got {tuple(arr.shape)}")
        n_rows = int(arr.shape[0])
        if natom is not None and n_rows != int(natom):
            raise ValueError(f"NATOM mismatch: coords has {n_rows}, expected {natom}")
        return arr

    if torch.is_tensor(coords):
        return _validated(coords.detach().to(device=target, dtype=dtype), "coords tensor")

    positions_getter = getattr(coords, "get_positions", None)
    if callable(positions_getter):
        as_tensor = torch.as_tensor(positions_getter(), dtype=dtype, device=target)
        return _validated(as_tensor, "ASE positions")

    # Everything else is treated as a file path; pick a reader by suffix.
    coords_file = Path(coords)
    readers = {
        suffix: (lambda p: _read_amber_inpcrd(p, natom=natom))
        for suffix in (".rst7", ".inpcrd", ".crd", ".restrt")
    }
    readers[".pdb"] = lambda p: _read_pdb(p)
    readers[".xyz"] = lambda p: _read_xyz(p, natom=natom)

    reader = readers.get(coords_file.suffix.lower())
    if reader is None:
        raise ValueError(
            f"Unsupported coordinate input: {coords!r}. "
            "Supported: rst7/inpcrd/crd/restrt, pdb, xyz, ASE Atoms, torch.Tensor[N,3]."
        )
    rows = reader(coords_file)

    if natom is not None and len(rows) != int(natom):
        raise ValueError(f"NATOM mismatch: coords has {len(rows)}, expected {natom}")
    return torch.tensor(rows, dtype=dtype, device=target)
510
+
511
+
512
+ def _read_amber_inpcrd(path: Path, natom: Optional[int] = None) -> List[List[float]]:
513
+ """Very small reader for Amber ASCII restart/inpcrd coordinate files."""
514
+ raw_lines = path.read_text(errors="replace").splitlines()
515
+ if len(raw_lines) < 2:
516
+ raise ValueError(f"File too short: {path}")
517
+ # 2nd line begins with NATOM
518
+ nat = int(raw_lines[1].split()[0])
519
+ if natom is not None and nat != natom:
520
+ raise ValueError(f"NATOM mismatch: coords has {nat}, expected {natom}")
521
+ # Collect floats from remaining lines.
522
+ # Amber ASCII inpcrd/rst7 commonly uses fixed-width (typically 6E12.7 per line).
523
+ # Parsing as fixed-width is robust even if there are spaces.
524
+ floats: List[float] = []
525
+ for ln in raw_lines[2:]:
526
+ for i in range(0, len(ln), 12):
527
+ field = ln[i : i + 12].strip()
528
+ if not field:
529
+ continue
530
+ try:
531
+ floats.append(float(field))
532
+ except ValueError:
533
+ continue
534
+ need = 3 * nat
535
+ if len(floats) < need:
536
+ raise ValueError(f"Not enough coordinate floats in {path}: got {len(floats)}, need {need}")
537
+ floats = floats[:need]
538
+ return [[floats[3 * i], floats[3 * i + 1], floats[3 * i + 2]] for i in range(nat)]
539
+
540
+
541
+ def _read_pdb(path: Path) -> List[List[float]]:
542
+ xyz: List[List[float]] = []
543
+ for ln in path.read_text(errors="replace").splitlines():
544
+ if ln.startswith("ATOM") or ln.startswith("HETATM"):
545
+ x = float(ln[30:38])
546
+ y = float(ln[38:46])
547
+ z = float(ln[46:54])
548
+ xyz.append([x, y, z])
549
+ if not xyz:
550
+ raise ValueError(f"No ATOM/HETATM coordinates found in {path}")
551
+ return xyz
552
+
553
+
554
+ def _read_xyz(path: Path, natom: Optional[int] = None) -> List[List[float]]:
555
+ """Read first frame from XYZ.
556
+
557
+ Supports standard XYZ (atom count + comment header) and simple
558
+ whitespace-separated rows with either ``sym x y z`` or ``x y z``.
559
+ """
560
+
561
+ def _parse_xyz_row(raw_line: str) -> List[float]:
562
+ toks = raw_line.split()
563
+ if len(toks) < 3:
564
+ raise ValueError(f"Invalid XYZ row: {raw_line!r}")
565
+ for start in (1, 0):
566
+ if len(toks) < start + 3:
567
+ continue
568
+ try:
569
+ x = float(toks[start + 0])
570
+ y = float(toks[start + 1])
571
+ z = float(toks[start + 2])
572
+ return [x, y, z]
573
+ except ValueError:
574
+ continue
575
+ raise ValueError(f"Invalid XYZ row: {raw_line!r}")
576
+
577
+ lines = [ln.rstrip("\n") for ln in path.read_text(errors="replace").splitlines()]
578
+ if not lines:
579
+ raise ValueError(f"Empty XYZ file: {path}")
580
+
581
+ xyz: List[List[float]] = []
582
+ header_natom: Optional[int] = None
583
+ try:
584
+ header_natom = int(lines[0].strip().split()[0])
585
+ except Exception:
586
+ header_natom = None
587
+
588
+ if header_natom is not None:
589
+ if len(lines) < 2 + header_natom:
590
+ raise ValueError(
591
+ f"XYZ header count mismatch in {path}: need {header_natom} atoms, "
592
+ f"but file has {max(0, len(lines) - 2)} rows"
593
+ )
594
+ for ln in lines[2 : 2 + header_natom]:
595
+ if not ln.strip():
596
+ continue
597
+ xyz.append(_parse_xyz_row(ln))
598
+ else:
599
+ for ln in lines:
600
+ if not ln.strip():
601
+ continue
602
+ xyz.append(_parse_xyz_row(ln))
603
+
604
+ if not xyz:
605
+ raise ValueError(f"No coordinates found in XYZ file: {path}")
606
+ if natom is not None and len(xyz) != int(natom):
607
+ raise ValueError(f"NATOM mismatch: coords has {len(xyz)}, expected {natom}")
608
+ return xyz
@@ -0,0 +1,8 @@
1
.PHONY: all clean

# Interpreter used to drive the build. Override per invocation,
# e.g. `make PYTHON=python3` or a virtualenv's interpreter.
PYTHON ?= python

# Verbose, forced rebuild of the native extensions via
# hessian_ff.native.build_native_extensions.
all:
	$(PYTHON) -c "from hessian_ff.native import build_native_extensions; \
	build_native_extensions(verbose=True, force_rebuild=True)"

# Delete the .build_* directories (assumed to be the extensions' build
# directories; confirm against hessian_ff/native/loader.py).
clean:
	rm -rf .build_nonbonded .build_bonded .build_analytical_hessian
@@ -0,0 +1,28 @@
1
+ """Native backend helpers for hessian_ff.
2
+
3
+ The nonbonded term requires the in-tree C++ extension in this release.
4
+ """
5
+
6
+ from .loader import (
7
+ analytical_hessian_extension_status,
8
+ bonded_extension_status,
9
+ build_native_extensions,
10
+ get_analytical_hessian_extension,
11
+ get_bonded_extension,
12
+ get_nonbonded_extension,
13
+ native_backend_status,
14
+ nonbonded_extension_status,
15
+ try_load_native_backend,
16
+ )
17
+
18
+ __all__ = [
19
+ "get_nonbonded_extension",
20
+ "get_analytical_hessian_extension",
21
+ "get_bonded_extension",
22
+ "build_native_extensions",
23
+ "native_backend_status",
24
+ "nonbonded_extension_status",
25
+ "analytical_hessian_extension_status",
26
+ "bonded_extension_status",
27
+ "try_load_native_backend",
28
+ ]