mlmm-toolkit 0.2.2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (372) hide show
  1. hessian_ff/__init__.py +50 -0
  2. hessian_ff/analytical_hessian.py +609 -0
  3. hessian_ff/constants.py +46 -0
  4. hessian_ff/forcefield.py +339 -0
  5. hessian_ff/loaders.py +608 -0
  6. hessian_ff/native/Makefile +8 -0
  7. hessian_ff/native/__init__.py +28 -0
  8. hessian_ff/native/analytical_hessian.py +88 -0
  9. hessian_ff/native/analytical_hessian_ext.cpp +258 -0
  10. hessian_ff/native/bonded.py +82 -0
  11. hessian_ff/native/bonded_ext.cpp +640 -0
  12. hessian_ff/native/loader.py +349 -0
  13. hessian_ff/native/nonbonded.py +118 -0
  14. hessian_ff/native/nonbonded_ext.cpp +1150 -0
  15. hessian_ff/prmtop_parmed.py +23 -0
  16. hessian_ff/system.py +107 -0
  17. hessian_ff/terms/__init__.py +14 -0
  18. hessian_ff/terms/angle.py +73 -0
  19. hessian_ff/terms/bond.py +44 -0
  20. hessian_ff/terms/cmap.py +406 -0
  21. hessian_ff/terms/dihedral.py +141 -0
  22. hessian_ff/terms/nonbonded.py +209 -0
  23. hessian_ff/tests/__init__.py +0 -0
  24. hessian_ff/tests/conftest.py +75 -0
  25. hessian_ff/tests/data/small/complex.parm7 +1346 -0
  26. hessian_ff/tests/data/small/complex.pdb +125 -0
  27. hessian_ff/tests/data/small/complex.rst7 +63 -0
  28. hessian_ff/tests/test_coords_input.py +44 -0
  29. hessian_ff/tests/test_energy_force.py +49 -0
  30. hessian_ff/tests/test_hessian.py +137 -0
  31. hessian_ff/tests/test_smoke.py +18 -0
  32. hessian_ff/tests/test_validation.py +40 -0
  33. hessian_ff/workflows.py +889 -0
  34. mlmm/__init__.py +36 -0
  35. mlmm/__main__.py +7 -0
  36. mlmm/_version.py +34 -0
  37. mlmm/add_elem_info.py +374 -0
  38. mlmm/advanced_help.py +91 -0
  39. mlmm/align_freeze_atoms.py +601 -0
  40. mlmm/all.py +3535 -0
  41. mlmm/bond_changes.py +231 -0
  42. mlmm/bool_compat.py +223 -0
  43. mlmm/cli.py +574 -0
  44. mlmm/cli_utils.py +166 -0
  45. mlmm/default_group.py +337 -0
  46. mlmm/defaults.py +467 -0
  47. mlmm/define_layer.py +526 -0
  48. mlmm/dft.py +1041 -0
  49. mlmm/energy_diagram.py +253 -0
  50. mlmm/extract.py +2213 -0
  51. mlmm/fix_altloc.py +464 -0
  52. mlmm/freq.py +1406 -0
  53. mlmm/harmonic_constraints.py +140 -0
  54. mlmm/hessian_cache.py +44 -0
  55. mlmm/hessian_calc.py +174 -0
  56. mlmm/irc.py +638 -0
  57. mlmm/mlmm_calc.py +2262 -0
  58. mlmm/mm_parm.py +945 -0
  59. mlmm/oniom_export.py +1983 -0
  60. mlmm/oniom_import.py +457 -0
  61. mlmm/opt.py +1742 -0
  62. mlmm/path_opt.py +1353 -0
  63. mlmm/path_search.py +2299 -0
  64. mlmm/preflight.py +88 -0
  65. mlmm/py.typed +1 -0
  66. mlmm/pysis_runner.py +45 -0
  67. mlmm/scan.py +1047 -0
  68. mlmm/scan2d.py +1226 -0
  69. mlmm/scan3d.py +1265 -0
  70. mlmm/scan_common.py +184 -0
  71. mlmm/summary_log.py +736 -0
  72. mlmm/trj2fig.py +448 -0
  73. mlmm/tsopt.py +2871 -0
  74. mlmm/utils.py +2309 -0
  75. mlmm/xtb_embedcharge_correction.py +475 -0
  76. mlmm_toolkit-0.2.2.dev0.dist-info/METADATA +1159 -0
  77. mlmm_toolkit-0.2.2.dev0.dist-info/RECORD +372 -0
  78. mlmm_toolkit-0.2.2.dev0.dist-info/WHEEL +5 -0
  79. mlmm_toolkit-0.2.2.dev0.dist-info/entry_points.txt +2 -0
  80. mlmm_toolkit-0.2.2.dev0.dist-info/licenses/LICENSE +674 -0
  81. mlmm_toolkit-0.2.2.dev0.dist-info/top_level.txt +4 -0
  82. pysisyphus/Geometry.py +1667 -0
  83. pysisyphus/LICENSE +674 -0
  84. pysisyphus/TableFormatter.py +63 -0
  85. pysisyphus/TablePrinter.py +74 -0
  86. pysisyphus/__init__.py +12 -0
  87. pysisyphus/calculators/AFIR.py +452 -0
  88. pysisyphus/calculators/AnaPot.py +20 -0
  89. pysisyphus/calculators/AnaPot2.py +48 -0
  90. pysisyphus/calculators/AnaPot3.py +12 -0
  91. pysisyphus/calculators/AnaPot4.py +20 -0
  92. pysisyphus/calculators/AnaPotBase.py +337 -0
  93. pysisyphus/calculators/AnaPotCBM.py +25 -0
  94. pysisyphus/calculators/AtomAtomTransTorque.py +154 -0
  95. pysisyphus/calculators/CFOUR.py +250 -0
  96. pysisyphus/calculators/Calculator.py +844 -0
  97. pysisyphus/calculators/CerjanMiller.py +24 -0
  98. pysisyphus/calculators/Composite.py +123 -0
  99. pysisyphus/calculators/ConicalIntersection.py +171 -0
  100. pysisyphus/calculators/DFTBp.py +430 -0
  101. pysisyphus/calculators/DFTD3.py +66 -0
  102. pysisyphus/calculators/DFTD4.py +84 -0
  103. pysisyphus/calculators/Dalton.py +61 -0
  104. pysisyphus/calculators/Dimer.py +681 -0
  105. pysisyphus/calculators/Dummy.py +20 -0
  106. pysisyphus/calculators/EGO.py +76 -0
  107. pysisyphus/calculators/EnergyMin.py +224 -0
  108. pysisyphus/calculators/ExternalPotential.py +264 -0
  109. pysisyphus/calculators/FakeASE.py +35 -0
  110. pysisyphus/calculators/FourWellAnaPot.py +28 -0
  111. pysisyphus/calculators/FreeEndNEBPot.py +39 -0
  112. pysisyphus/calculators/Gaussian09.py +18 -0
  113. pysisyphus/calculators/Gaussian16.py +726 -0
  114. pysisyphus/calculators/HardSphere.py +159 -0
  115. pysisyphus/calculators/IDPPCalculator.py +49 -0
  116. pysisyphus/calculators/IPIClient.py +133 -0
  117. pysisyphus/calculators/IPIServer.py +234 -0
  118. pysisyphus/calculators/LEPSBase.py +24 -0
  119. pysisyphus/calculators/LEPSExpr.py +139 -0
  120. pysisyphus/calculators/LennardJones.py +80 -0
  121. pysisyphus/calculators/MOPAC.py +219 -0
  122. pysisyphus/calculators/MullerBrownSympyPot.py +51 -0
  123. pysisyphus/calculators/MultiCalc.py +85 -0
  124. pysisyphus/calculators/NFK.py +45 -0
  125. pysisyphus/calculators/OBabel.py +87 -0
  126. pysisyphus/calculators/ONIOMv2.py +1129 -0
  127. pysisyphus/calculators/ORCA.py +893 -0
  128. pysisyphus/calculators/ORCA5.py +6 -0
  129. pysisyphus/calculators/OpenMM.py +88 -0
  130. pysisyphus/calculators/OpenMolcas.py +281 -0
  131. pysisyphus/calculators/OverlapCalculator.py +908 -0
  132. pysisyphus/calculators/Psi4.py +218 -0
  133. pysisyphus/calculators/PyPsi4.py +37 -0
  134. pysisyphus/calculators/PySCF.py +341 -0
  135. pysisyphus/calculators/PyXTB.py +73 -0
  136. pysisyphus/calculators/QCEngine.py +106 -0
  137. pysisyphus/calculators/Rastrigin.py +22 -0
  138. pysisyphus/calculators/Remote.py +76 -0
  139. pysisyphus/calculators/Rosenbrock.py +15 -0
  140. pysisyphus/calculators/SocketCalc.py +97 -0
  141. pysisyphus/calculators/TIP3P.py +111 -0
  142. pysisyphus/calculators/TransTorque.py +161 -0
  143. pysisyphus/calculators/Turbomole.py +965 -0
  144. pysisyphus/calculators/VRIPot.py +37 -0
  145. pysisyphus/calculators/WFOWrapper.py +333 -0
  146. pysisyphus/calculators/WFOWrapper2.py +341 -0
  147. pysisyphus/calculators/XTB.py +418 -0
  148. pysisyphus/calculators/__init__.py +81 -0
  149. pysisyphus/calculators/cosmo_data.py +139 -0
  150. pysisyphus/calculators/parser.py +150 -0
  151. pysisyphus/color.py +19 -0
  152. pysisyphus/config.py +133 -0
  153. pysisyphus/constants.py +65 -0
  154. pysisyphus/cos/AdaptiveNEB.py +230 -0
  155. pysisyphus/cos/ChainOfStates.py +725 -0
  156. pysisyphus/cos/FreeEndNEB.py +25 -0
  157. pysisyphus/cos/FreezingString.py +103 -0
  158. pysisyphus/cos/GrowingChainOfStates.py +71 -0
  159. pysisyphus/cos/GrowingNT.py +309 -0
  160. pysisyphus/cos/GrowingString.py +508 -0
  161. pysisyphus/cos/NEB.py +189 -0
  162. pysisyphus/cos/SimpleZTS.py +64 -0
  163. pysisyphus/cos/__init__.py +22 -0
  164. pysisyphus/cos/stiffness.py +199 -0
  165. pysisyphus/drivers/__init__.py +17 -0
  166. pysisyphus/drivers/afir.py +855 -0
  167. pysisyphus/drivers/barriers.py +271 -0
  168. pysisyphus/drivers/birkholz.py +138 -0
  169. pysisyphus/drivers/cluster.py +318 -0
  170. pysisyphus/drivers/diabatization.py +133 -0
  171. pysisyphus/drivers/merge.py +368 -0
  172. pysisyphus/drivers/merge_mol2.py +322 -0
  173. pysisyphus/drivers/opt.py +375 -0
  174. pysisyphus/drivers/perf.py +91 -0
  175. pysisyphus/drivers/pka.py +52 -0
  176. pysisyphus/drivers/precon_pos_rot.py +669 -0
  177. pysisyphus/drivers/rates.py +480 -0
  178. pysisyphus/drivers/replace.py +219 -0
  179. pysisyphus/drivers/scan.py +212 -0
  180. pysisyphus/drivers/spectrum.py +166 -0
  181. pysisyphus/drivers/thermo.py +31 -0
  182. pysisyphus/dynamics/Gaussian.py +103 -0
  183. pysisyphus/dynamics/__init__.py +20 -0
  184. pysisyphus/dynamics/colvars.py +136 -0
  185. pysisyphus/dynamics/driver.py +297 -0
  186. pysisyphus/dynamics/helpers.py +256 -0
  187. pysisyphus/dynamics/lincs.py +105 -0
  188. pysisyphus/dynamics/mdp.py +364 -0
  189. pysisyphus/dynamics/rattle.py +121 -0
  190. pysisyphus/dynamics/thermostats.py +128 -0
  191. pysisyphus/dynamics/wigner.py +266 -0
  192. pysisyphus/elem_data.py +3473 -0
  193. pysisyphus/exceptions.py +2 -0
  194. pysisyphus/filtertrj.py +69 -0
  195. pysisyphus/helpers.py +623 -0
  196. pysisyphus/helpers_pure.py +649 -0
  197. pysisyphus/init_logging.py +50 -0
  198. pysisyphus/intcoords/Bend.py +69 -0
  199. pysisyphus/intcoords/Bend2.py +25 -0
  200. pysisyphus/intcoords/BondedFragment.py +32 -0
  201. pysisyphus/intcoords/Cartesian.py +41 -0
  202. pysisyphus/intcoords/CartesianCoords.py +140 -0
  203. pysisyphus/intcoords/Coords.py +56 -0
  204. pysisyphus/intcoords/DLC.py +197 -0
  205. pysisyphus/intcoords/DistanceFunction.py +34 -0
  206. pysisyphus/intcoords/DummyImproper.py +70 -0
  207. pysisyphus/intcoords/DummyTorsion.py +72 -0
  208. pysisyphus/intcoords/LinearBend.py +105 -0
  209. pysisyphus/intcoords/LinearDisplacement.py +80 -0
  210. pysisyphus/intcoords/OutOfPlane.py +59 -0
  211. pysisyphus/intcoords/PrimTypes.py +286 -0
  212. pysisyphus/intcoords/Primitive.py +137 -0
  213. pysisyphus/intcoords/RedundantCoords.py +659 -0
  214. pysisyphus/intcoords/RobustTorsion.py +59 -0
  215. pysisyphus/intcoords/Rotation.py +147 -0
  216. pysisyphus/intcoords/Stretch.py +31 -0
  217. pysisyphus/intcoords/Torsion.py +101 -0
  218. pysisyphus/intcoords/Torsion2.py +25 -0
  219. pysisyphus/intcoords/Translation.py +45 -0
  220. pysisyphus/intcoords/__init__.py +61 -0
  221. pysisyphus/intcoords/augment_bonds.py +126 -0
  222. pysisyphus/intcoords/derivatives.py +10512 -0
  223. pysisyphus/intcoords/eval.py +80 -0
  224. pysisyphus/intcoords/exceptions.py +37 -0
  225. pysisyphus/intcoords/findiffs.py +48 -0
  226. pysisyphus/intcoords/generate_derivatives.py +414 -0
  227. pysisyphus/intcoords/helpers.py +235 -0
  228. pysisyphus/intcoords/logging_conf.py +10 -0
  229. pysisyphus/intcoords/mp_derivatives.py +10836 -0
  230. pysisyphus/intcoords/setup.py +962 -0
  231. pysisyphus/intcoords/setup_fast.py +176 -0
  232. pysisyphus/intcoords/update.py +272 -0
  233. pysisyphus/intcoords/valid.py +89 -0
  234. pysisyphus/interpolate/Geodesic.py +93 -0
  235. pysisyphus/interpolate/IDPP.py +55 -0
  236. pysisyphus/interpolate/Interpolator.py +116 -0
  237. pysisyphus/interpolate/LST.py +70 -0
  238. pysisyphus/interpolate/Redund.py +152 -0
  239. pysisyphus/interpolate/__init__.py +9 -0
  240. pysisyphus/interpolate/helpers.py +34 -0
  241. pysisyphus/io/__init__.py +22 -0
  242. pysisyphus/io/aomix.py +178 -0
  243. pysisyphus/io/cjson.py +24 -0
  244. pysisyphus/io/crd.py +101 -0
  245. pysisyphus/io/cube.py +220 -0
  246. pysisyphus/io/fchk.py +184 -0
  247. pysisyphus/io/hdf5.py +49 -0
  248. pysisyphus/io/hessian.py +72 -0
  249. pysisyphus/io/mol2.py +146 -0
  250. pysisyphus/io/molden.py +293 -0
  251. pysisyphus/io/orca.py +189 -0
  252. pysisyphus/io/pdb.py +269 -0
  253. pysisyphus/io/psf.py +79 -0
  254. pysisyphus/io/pubchem.py +31 -0
  255. pysisyphus/io/qcschema.py +34 -0
  256. pysisyphus/io/sdf.py +29 -0
  257. pysisyphus/io/xyz.py +61 -0
  258. pysisyphus/io/zmat.py +175 -0
  259. pysisyphus/irc/DWI.py +108 -0
  260. pysisyphus/irc/DampedVelocityVerlet.py +134 -0
  261. pysisyphus/irc/Euler.py +22 -0
  262. pysisyphus/irc/EulerPC.py +345 -0
  263. pysisyphus/irc/GonzalezSchlegel.py +187 -0
  264. pysisyphus/irc/IMKMod.py +164 -0
  265. pysisyphus/irc/IRC.py +878 -0
  266. pysisyphus/irc/IRCDummy.py +10 -0
  267. pysisyphus/irc/Instanton.py +307 -0
  268. pysisyphus/irc/LQA.py +53 -0
  269. pysisyphus/irc/ModeKill.py +136 -0
  270. pysisyphus/irc/ParamPlot.py +53 -0
  271. pysisyphus/irc/RK4.py +36 -0
  272. pysisyphus/irc/__init__.py +31 -0
  273. pysisyphus/irc/initial_displ.py +219 -0
  274. pysisyphus/linalg.py +411 -0
  275. pysisyphus/line_searches/Backtracking.py +88 -0
  276. pysisyphus/line_searches/HagerZhang.py +184 -0
  277. pysisyphus/line_searches/LineSearch.py +232 -0
  278. pysisyphus/line_searches/StrongWolfe.py +108 -0
  279. pysisyphus/line_searches/__init__.py +9 -0
  280. pysisyphus/line_searches/interpol.py +15 -0
  281. pysisyphus/modefollow/NormalMode.py +40 -0
  282. pysisyphus/modefollow/__init__.py +10 -0
  283. pysisyphus/modefollow/davidson.py +199 -0
  284. pysisyphus/modefollow/lanczos.py +95 -0
  285. pysisyphus/optimizers/BFGS.py +99 -0
  286. pysisyphus/optimizers/BacktrackingOptimizer.py +113 -0
  287. pysisyphus/optimizers/ConjugateGradient.py +98 -0
  288. pysisyphus/optimizers/CubicNewton.py +75 -0
  289. pysisyphus/optimizers/FIRE.py +113 -0
  290. pysisyphus/optimizers/HessianOptimizer.py +1176 -0
  291. pysisyphus/optimizers/LBFGS.py +228 -0
  292. pysisyphus/optimizers/LayerOpt.py +411 -0
  293. pysisyphus/optimizers/MicroOptimizer.py +169 -0
  294. pysisyphus/optimizers/NCOptimizer.py +90 -0
  295. pysisyphus/optimizers/Optimizer.py +1084 -0
  296. pysisyphus/optimizers/PreconLBFGS.py +260 -0
  297. pysisyphus/optimizers/PreconSteepestDescent.py +7 -0
  298. pysisyphus/optimizers/QuickMin.py +74 -0
  299. pysisyphus/optimizers/RFOptimizer.py +181 -0
  300. pysisyphus/optimizers/RSA.py +99 -0
  301. pysisyphus/optimizers/StabilizedQNMethod.py +248 -0
  302. pysisyphus/optimizers/SteepestDescent.py +23 -0
  303. pysisyphus/optimizers/StringOptimizer.py +173 -0
  304. pysisyphus/optimizers/__init__.py +41 -0
  305. pysisyphus/optimizers/closures.py +301 -0
  306. pysisyphus/optimizers/cls_map.py +58 -0
  307. pysisyphus/optimizers/exceptions.py +6 -0
  308. pysisyphus/optimizers/gdiis.py +280 -0
  309. pysisyphus/optimizers/guess_hessians.py +311 -0
  310. pysisyphus/optimizers/hessian_updates.py +355 -0
  311. pysisyphus/optimizers/poly_fit.py +285 -0
  312. pysisyphus/optimizers/precon.py +153 -0
  313. pysisyphus/optimizers/restrict_step.py +24 -0
  314. pysisyphus/pack.py +172 -0
  315. pysisyphus/peakdetect.py +948 -0
  316. pysisyphus/plot.py +1031 -0
  317. pysisyphus/run.py +2106 -0
  318. pysisyphus/socket_helper.py +74 -0
  319. pysisyphus/stocastic/FragmentKick.py +132 -0
  320. pysisyphus/stocastic/Kick.py +81 -0
  321. pysisyphus/stocastic/Pipeline.py +303 -0
  322. pysisyphus/stocastic/__init__.py +21 -0
  323. pysisyphus/stocastic/align.py +127 -0
  324. pysisyphus/testing.py +96 -0
  325. pysisyphus/thermo.py +156 -0
  326. pysisyphus/trj.py +824 -0
  327. pysisyphus/tsoptimizers/RSIRFOptimizer.py +56 -0
  328. pysisyphus/tsoptimizers/RSPRFOptimizer.py +182 -0
  329. pysisyphus/tsoptimizers/TRIM.py +59 -0
  330. pysisyphus/tsoptimizers/TSHessianOptimizer.py +463 -0
  331. pysisyphus/tsoptimizers/__init__.py +23 -0
  332. pysisyphus/wavefunction/Basis.py +239 -0
  333. pysisyphus/wavefunction/DIIS.py +76 -0
  334. pysisyphus/wavefunction/__init__.py +25 -0
  335. pysisyphus/wavefunction/build_ext.py +42 -0
  336. pysisyphus/wavefunction/cart2sph.py +190 -0
  337. pysisyphus/wavefunction/diabatization.py +304 -0
  338. pysisyphus/wavefunction/excited_states.py +435 -0
  339. pysisyphus/wavefunction/gen_ints.py +1811 -0
  340. pysisyphus/wavefunction/helpers.py +104 -0
  341. pysisyphus/wavefunction/ints/__init__.py +0 -0
  342. pysisyphus/wavefunction/ints/boys.py +193 -0
  343. pysisyphus/wavefunction/ints/boys_table_N_64_xasym_27.1_step_0.01.npy +0 -0
  344. pysisyphus/wavefunction/ints/cart_gto3d.py +176 -0
  345. pysisyphus/wavefunction/ints/coulomb3d.py +25928 -0
  346. pysisyphus/wavefunction/ints/diag_quadrupole3d.py +10036 -0
  347. pysisyphus/wavefunction/ints/dipole3d.py +8762 -0
  348. pysisyphus/wavefunction/ints/int2c2e3d.py +7198 -0
  349. pysisyphus/wavefunction/ints/int3c2e3d_sph.py +65040 -0
  350. pysisyphus/wavefunction/ints/kinetic3d.py +8240 -0
  351. pysisyphus/wavefunction/ints/ovlp3d.py +3777 -0
  352. pysisyphus/wavefunction/ints/quadrupole3d.py +15054 -0
  353. pysisyphus/wavefunction/ints/self_ovlp3d.py +198 -0
  354. pysisyphus/wavefunction/localization.py +458 -0
  355. pysisyphus/wavefunction/multipole.py +159 -0
  356. pysisyphus/wavefunction/normalization.py +36 -0
  357. pysisyphus/wavefunction/pop_analysis.py +134 -0
  358. pysisyphus/wavefunction/shells.py +1171 -0
  359. pysisyphus/wavefunction/wavefunction.py +504 -0
  360. pysisyphus/wrapper/__init__.py +11 -0
  361. pysisyphus/wrapper/exceptions.py +2 -0
  362. pysisyphus/wrapper/jmol.py +120 -0
  363. pysisyphus/wrapper/mwfn.py +169 -0
  364. pysisyphus/wrapper/packmol.py +71 -0
  365. pysisyphus/xyzloader.py +168 -0
  366. pysisyphus/yaml_mods.py +45 -0
  367. thermoanalysis/LICENSE +674 -0
  368. thermoanalysis/QCData.py +244 -0
  369. thermoanalysis/__init__.py +0 -0
  370. thermoanalysis/config.py +3 -0
  371. thermoanalysis/constants.py +20 -0
  372. thermoanalysis/thermo.py +1011 -0
mlmm/utils.py ADDED
@@ -0,0 +1,2309 @@
1
+ # mlmm/utils.py
2
+
3
+ """
4
+ utils — concise utilities for configuration, plotting, and coordinates
5
+ ====================================================================
6
+
7
+ Usage (API)
8
+ -----
9
+ from mlmm.utils import (
10
+ build_energy_diagram,
11
+ convert_xyz_to_pdb,
12
+ merge_freeze_atom_indices,
13
+ pretty_block,
14
+ )
15
+
16
+ Examples::
17
+ >>> from pathlib import Path
18
+ >>> block = pretty_block("Geometry", {"freeze_atoms": [0, 1, 5]})
19
+ >>> diagram = build_energy_diagram([0.0, 12.3, 5.4], ["R", "TS", "P"])
20
+
21
+ Description
22
+ -----
23
+ - **Generic helpers**
24
+ - `pretty_block(title, content)`: Return a YAML-formatted block with an underlined title. Uses `yaml.safe_dump` with `allow_unicode=True`, `sort_keys=False`. Renders `{}` when `content` is empty.
25
+ - `format_freeze_atoms_for_echo(cfg, key="freeze_atoms")`: Normalize geometry configuration for CLI echo. If the key is an iterable (but not a string), summarize to a compact single-line form like `"11073 atoms [0,1,2,3,4,...,12981,12982,12983,12984,12985]"`.
26
+ - `format_elapsed(prefix, start_time, end_time=None)`: Format a wall-clock duration (HH:MM:SS.sss) given a start time and optional end time, using `time.perf_counter()` when the end time is omitted.
27
+ - `merge_freeze_atom_indices(geom_cfg, *indices)`: Merge one or more iterables of atom indices into `geom_cfg["freeze_atoms"]`. Preserve existing entries, de-duplicate, sort numerically, and return the updated list (in place).
28
+ - `apply_layer_freeze_constraints(geom_cfg, calc_cfg, layer_info, echo_fn=None)`: Merge layer-detected frozen indices (`layer_info["frozen_indices"]`) into both `geom_cfg["freeze_atoms"]` and `calc_cfg["freeze_atoms"]`, then optionally emit a concise summary line.
29
+ - `deep_update(dst, src)`: Recursively update mapping `dst` with `src`. Nested dicts are merged, non-dicts overwrite; returns `dst`.
30
+ - `_get_mapping_section(cfg, path)`: Internal helper to resolve a nested mapping section. Returns a `dict` or `None`.
31
+ - `apply_yaml_overrides(yaml_cfg, overrides)`: For each target dictionary and its candidate key paths, find the first existing path in `yaml_cfg` and apply it via `deep_update`. Centralizes repeated `yaml_cfg.get(...)`-style merging.
32
+ - `load_yaml_dict(path)`: Load a YAML file whose root must be a mapping. Returns `{}` when `path` is `None`. Raises `ValueError` if the YAML root is not a mapping.
33
+
34
+ - **Plotly: Energy diagram builder**
35
+ - `build_energy_diagram(energies, labels, ylabel="ΔE", baseline=False, showgrid=False)`:
36
+ Render an energy diagram where each state is a thick horizontal segment and adjacent states are connected by dotted diagonals (right end of left state → left end of right state). Segment length shrinks as the number of states grows to keep gaps readable. X ticks are centered on states and labeled by `labels`. Optional dotted baseline at the first state’s energy; optional grid. Energies are plotted as provided (no unit conversion). Returns a `plotly.graph_objs.Figure`. Validates equal lengths for `energies`/`labels` and non-empty input.
37
+
38
+ - **Coordinate conversion utilities**
39
+ - `convert_xyz_to_pdb(xyz_path, ref_pdb_path, out_pdb_path)`:
40
+ Overlay coordinates from an XYZ file (single or multi-frame) onto the atom ordering/topology of a reference PDB and write to `out_pdb_path`. The first frame creates/overwrites; subsequent frames append using `MODEL`/`ENDMDL`. Implemented with ASE (`ase.io.read`/`write`). Raises `ValueError` if no frames are found in the XYZ.
41
+
42
+ Outputs (& Directory Layout)
43
+ -----
44
+ - This module does not create directories.
45
+ - Functions primarily return Python objects or mutate dictionaries in place.
46
+ - On-disk output occurs only when explicitly requested by the caller:
47
+ - `convert_xyz_to_pdb` writes a PDB file to `out_pdb_path` (first frame create/overwrite; subsequent frames append with `MODEL`/`ENDMDL` blocks).
48
+ - `build_energy_diagram` returns a Plotly `Figure`; it does not write files unless the caller saves/exports the figure.
49
+
50
+ Notes:
51
+ -----
52
+ - Energy units in `build_energy_diagram` are passed through unchanged; ensure consistent units across states.
53
+ - Axis/line styling in `build_energy_diagram` is fixed-width with automatic padding; segment length adapts to the number of states.
54
+ - `load_yaml_dict` uses `yaml.safe_load` and enforces a mapping at the YAML root; empty files yield `{}`.
55
+ - `apply_yaml_overrides` tries candidate key paths in order and applies only the first existing mapping section per target.
56
+ - Dependencies: PyYAML, ASE (`ase.io.read`/`write`), Plotly (graph objects).
57
+ """
58
+
59
+ import ast
60
+ import logging
61
+ import math
62
+ import os
63
+ import re
64
+ import time
65
+ import tempfile
66
+ from collections.abc import Iterable as _Iterable, Mapping, Sequence as _Sequence
67
+ from dataclasses import dataclass
68
+ from numbers import Real, Integral
69
+ from pathlib import Path
70
+ from typing import Any, Callable, Dict, Optional, Sequence, List, Tuple
71
+
72
+ import click
73
+ import numpy as np
74
+ import yaml
75
+ from ase.io import read, write
76
+ import plotly.graph_objs as go
77
+
78
+ from pysisyphus.helpers import geom_loader
79
+ from pysisyphus.constants import ANG2BOHR
80
+
81
+ from .add_elem_info import guess_element
82
+
83
+ logger = logging.getLogger(__name__)
84
+
85
+ # =============================================================================
86
+ # Generic helpers
87
+ # =============================================================================
88
+
89
+
90
def ensure_dir(path: Path) -> None:
    """Make sure *path* exists as a directory, creating parents as needed.

    A no-op when the directory is already present.
    """
    path.mkdir(parents=True, exist_ok=True)
93
+
94
+
95
def read_xyz_as_blocks(path: Path, *, strict: bool = False) -> List[List[str]]:
    """Read an XYZ-style trajectory into per-frame blocks of raw lines.

    Each returned block is ``[header, comment, atom lines...]`` exactly as
    found in the file.  Blank lines between frames are skipped.

    Parameters
    ----------
    path : Path
        File to read (UTF-8 text).
    strict : bool
        When True, a malformed header or a truncated frame raises
        ``click.ClickException``; otherwise parsing stops silently at the
        first problem and the frames read so far are returned.

    Returns
    -------
    List[List[str]]
        One list of raw lines per frame.

    Raises
    ------
    click.ClickException
        When the file cannot be read, or (in strict mode) on malformed input.
    """
    # Fix: removed the redundant function-local ``import click`` statements —
    # the module already imports click at the top level.
    try:
        lines = path.read_text(encoding="utf-8").splitlines()
    except Exception as e:
        raise click.ClickException(f"Failed to read {path}: {e}")

    blocks: List[List[str]] = []
    i = 0
    while i < len(lines):
        if not lines[i].strip():
            i += 1
            continue
        try:
            n_atoms = int(lines[i].strip().split()[0])
        except Exception:
            if strict:
                raise click.ClickException(f"[xyz] Malformed XYZ/TRJ header at line {i+1} of {path}")
            break
        # A frame is: header line + comment line + n_atoms atom lines.
        end = i + n_atoms + 2
        if end > len(lines):
            if strict:
                raise click.ClickException(f"[xyz] Incomplete XYZ frame at line {i+1} of {path}")
            break
        blocks.append(lines[i:end])
        i = end
    return blocks
128
+
129
+
130
def parse_xyz_block(
    block: Sequence[str],
    *,
    path: Path,
    frame_idx: int,
) -> Tuple[List[str], "np.ndarray"]:
    """Parse one XYZ frame block into ``(elements, coords_angstrom)``.

    Parameters
    ----------
    block : Sequence[str]
        Raw frame lines: header (atom count), comment, then atom lines.
    path : Path
        Source file; used only in error messages.
    frame_idx : int
        1-based frame index; used only in error messages.

    Returns
    -------
    Tuple[List[str], np.ndarray]
        Element symbols and an ``(n_atoms, 3)`` float array of coordinates
        (units as stored in the file, conventionally Å).

    Raises
    ------
    click.ClickException
        On an empty block, malformed header, truncated frame, or malformed
        atom line.
    """
    # Fix: removed the redundant function-local ``import click`` — the module
    # already imports click at the top level.
    if not block:
        raise click.ClickException(f"[xyz] Empty XYZ frame in {path}")
    try:
        nat = int(block[0].strip().split()[0])
    except Exception:
        raise click.ClickException(
            f"[xyz] Malformed XYZ/TRJ header in frame {frame_idx} of {path}"
        )
    if len(block) < 2 + nat:
        raise click.ClickException(
            f"[xyz] Incomplete XYZ frame {frame_idx} in {path} (expected {nat} atoms)."
        )
    elems: List[str] = []
    coords: List[List[float]] = []
    for k in range(nat):
        parts = block[2 + k].split()
        if len(parts) < 4:
            raise click.ClickException(
                f"[xyz] Malformed atom line in frame {frame_idx} of {path}"
            )
        elems.append(parts[0])
        coords.append([float(parts[1]), float(parts[2]), float(parts[3])])
    return elems, np.array(coords, dtype=float)
162
+
163
+
164
def xyz_blocks_first_last(
    blocks: Sequence[Sequence[str]],
    *,
    path: Path,
) -> Tuple[List[str], "np.ndarray", "np.ndarray"]:
    """Return ``(elements, first_coords_ang, last_coords_ang)`` from pre-parsed XYZ blocks.

    Parameters
    ----------
    blocks : Sequence[Sequence[str]]
        Frame blocks as produced by ``read_xyz_as_blocks``.
    path : Path
        Source file; used only in error messages.

    Raises
    ------
    click.ClickException
        When *blocks* is empty or the element list differs between the first
        and last frame.
    """
    # Fix: removed the redundant function-local ``import click`` — the module
    # already imports click at the top level.
    if not blocks:
        raise click.ClickException(f"[xyz] No frames found in {path}")
    first_elems, first_coords = parse_xyz_block(blocks[0], path=path, frame_idx=1)
    last_elems, last_coords = parse_xyz_block(blocks[-1], path=path, frame_idx=len(blocks))
    if first_elems != last_elems:
        raise click.ClickException(f"[xyz] Element list changed across frames in {path}")
    return first_elems, first_coords, last_coords
179
+
180
+
181
def read_xyz_first_last(trj_path: Path) -> Tuple[List[str], "np.ndarray", "np.ndarray"]:
    """Read an XYZ trajectory and return ``(elements, first_coords[Å], last_coords[Å])``.

    Strict parsing: malformed or truncated frames raise ``click.ClickException``.
    """
    frame_blocks = read_xyz_as_blocks(trj_path, strict=True)
    return xyz_blocks_first_last(frame_blocks, path=trj_path)
185
+
186
+
187
def close_matplotlib_figures() -> None:
    """Close every open matplotlib figure, best-effort.

    Any failure (matplotlib not installed, backend issues) is ignored so
    callers can use this as unconditional cleanup.
    """
    try:
        import matplotlib.pyplot as plt
    except Exception:
        return
    try:
        plt.close("all")
    except Exception:
        pass
194
+
195
+
196
def distance_A_from_coords(coords_bohr: "np.ndarray", i: int, j: int) -> float:
    """Return the i–j interatomic distance in Å for coordinates given in Bohr."""
    # The norm is in Bohr; dividing by ANG2BOHR (Bohr per Å) converts to Å.
    displacement = coords_bohr[i] - coords_bohr[j]
    return float(np.linalg.norm(displacement) / ANG2BOHR)
200
+
201
+
202
def distance_tag(value_A: float, *, digits: int = 2, pad: int = 3) -> str:
    """Format a distance in Å as a zero-padded integer tag.

    The value is scaled by ``10**digits`` (default: ×100), rounded to the
    nearest integer, and left-padded with zeros to at least *pad* characters.
    """
    scaled = value_A * (10 ** digits)
    return format(int(round(scaled)), f"0{pad}d")
206
+
207
+
208
def values_from_bounds(low: float, high: float, h: float) -> "np.ndarray":
    """Return evenly spaced values from *low* to *high* (inclusive).

    The number of intervals is chosen so that no step exceeds *h*.  A
    degenerate interval (|high - low| < 1e-12) yields a single point.

    Raises click.BadParameter when h is not positive.
    """
    if h <= 0.0:
        raise click.BadParameter("--max-step-size must be > 0.")
    span = abs(high - low)
    if span < 1e-12:
        return np.array([low], dtype=float)
    n_steps = int(math.ceil(span / h))
    return np.linspace(low, high, n_steps + 1, dtype=float)
217
+
218
+
219
def geom_from_xyz_string(
    xyz_text: str,
    *,
    coord_type: str,
    freeze_atoms: Optional[Sequence[int]] = None,
) -> Any:
    """Load a pysisyphus Geometry from an XYZ text string.

    The text is written to a temporary ``.xyz`` file (deleted afterwards) and
    handed to ``geom_loader``.  The requested freeze indices are also attached
    directly to the resulting geometry as a sorted, de-duplicated int array;
    a failure there only emits a warning.
    """
    text = xyz_text if xyz_text.endswith("\n") else f"{xyz_text}\n"
    frozen = [] if freeze_atoms is None else list(freeze_atoms)
    handle = tempfile.NamedTemporaryFile("w+", suffix=".xyz", delete=False)
    try:
        handle.write(text)
        handle.flush()
        handle.close()

        geometry = geom_loader(
            Path(handle.name),
            coord_type=coord_type,
            freeze_atoms=frozen,
        )
        try:
            geometry.freeze_atoms = np.array(sorted(set(map(int, frozen))), dtype=int)
        except Exception:
            click.echo(
                "[geom] WARNING: Failed to attach freeze_atoms to geometry.",
                err=True,
            )
        return geometry
    finally:
        try:
            os.unlink(handle.name)
        except Exception:
            logger.debug("Failed to unlink temp file %s", handle.name, exc_info=True)
252
+
253
+
254
def append_xyz_trajectory(dst_path: Path, src_path: Path, *, reset: bool = False) -> bool:
    """Append an XYZ trajectory segment from *src_path* onto *dst_path*.

    Returns False (and does nothing) when the source file does not exist.
    With ``reset=True`` the destination is truncated before writing.
    """
    if not src_path.exists():
        return False
    open_mode = "w" if reset else "a"
    with src_path.open("r", encoding="utf-8") as fin, dst_path.open(open_mode, encoding="utf-8") as fout:
        # Stream in 1 MiB chunks so large trajectories stay memory-bounded.
        for chunk in iter(lambda: fin.read(1024 * 1024), ""):
            fout.write(chunk)
    return True
266
+
267
+
268
def snapshot_geometry(geom: Any, *, coord_type_default: str) -> Any:
    """Return an independent pysisyphus Geometry copy of *geom*.

    The geometry is round-tripped through its XYZ representation; coord_type
    and freeze_atoms are carried over (falling back to *coord_type_default*
    and an empty list when absent).
    """
    xyz_repr = geom.as_xyz()
    return geom_from_xyz_string(
        xyz_repr,
        coord_type=getattr(geom, "coord_type", coord_type_default),
        freeze_atoms=getattr(geom, "freeze_atoms", []),
    )
276
+
277
+
278
def unbiased_energy_hartree(geom, base_calc) -> float:
    """Evaluate the base calculator's energy (Hartree) without harmonic bias.

    Returns NaN when the geometry exposes no element list or when the
    calculator call fails for any reason (best-effort evaluation).
    """
    positions_bohr = np.asarray(geom.coords)
    elements = getattr(geom, "atoms", None)
    if elements is None:
        return float("nan")
    try:
        result = base_calc.get_energy(elements, positions_bohr)
        return float(result["energy"])
    except Exception:
        return float("nan")
288
+
289
+
290
def pretty_block(title: str, content: Dict[str, Any]) -> str:
    """Render *content* as a YAML block beneath an underlined *title*.

    The content is first made YAML-safe via ``_to_yaml_safe``; an empty dump
    is shown as ``(empty)``.
    """
    underline = "-" * len(title)
    dumped = yaml.safe_dump(_to_yaml_safe(content), sort_keys=False, allow_unicode=True).strip()
    return "\n".join([title, underline, dumped or "(empty)"]) + "\n"
296
+
297
+
298
+ def _to_yaml_safe(value: Any) -> Any:
299
+ """Recursively convert NumPy values/containers into YAML-safe builtins."""
300
+ if isinstance(value, np.generic):
301
+ return value.item()
302
+ if isinstance(value, np.ndarray):
303
+ return [_to_yaml_safe(v) for v in value.tolist()]
304
+ if isinstance(value, Mapping):
305
+ out: Dict[Any, Any] = {}
306
+ for k, v in value.items():
307
+ nk = _to_yaml_safe(k)
308
+ if isinstance(nk, (list, tuple, set, dict)):
309
+ nk = str(nk)
310
+ out[nk] = _to_yaml_safe(v)
311
+ return out
312
+ if isinstance(value, tuple):
313
+ return [_to_yaml_safe(v) for v in value]
314
+ if isinstance(value, list):
315
+ return [_to_yaml_safe(v) for v in value]
316
+ if isinstance(value, set):
317
+ return [_to_yaml_safe(v) for v in sorted(value, key=lambda x: str(x))]
318
+ return value
319
+
320
+
321
# Backend-specific key prefixes in MLMM_CALC_KW.
# Keys with these prefixes are only relevant when the corresponding backend is active.
# Maps backend name -> config keys that apply only to that backend; consumed by
# filter_calc_for_echo() to hide keys belonging to inactive backends.
_BACKEND_KEY_PREFIXES: Dict[str, tuple] = {
    "uma": ("uma_model", "uma_task_name"),
    "orb": ("orb_model", "orb_precision"),
    "mace": ("mace_model", "mace_dtype"),
    "aimnet2": ("aimnet2_model",),
}
329
+
330
+
331
def filter_calc_for_echo(calc_cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of *calc_cfg* trimmed for CLI echo.

    Drops keys belonging to ML backends other than the active one (per
    ``_BACKEND_KEY_PREFIXES``; the backend defaults to "uma").  When
    embedcharge is disabled, also removes all ``xtb_*`` keys and the
    embedcharge step/cutoff settings.
    """
    echo_cfg = dict(calc_cfg)
    active_backend = echo_cfg.get("backend", "uma")

    # Strip keys that only matter for the inactive backends.
    for backend_name, backend_keys in _BACKEND_KEY_PREFIXES.items():
        if backend_name == active_backend:
            continue
        for backend_key in backend_keys:
            echo_cfg.pop(backend_key, None)

    # With embedcharge off, the xTB-related settings are irrelevant noise.
    if not echo_cfg.get("embedcharge"):
        for cfg_key in [k for k in echo_cfg if k.startswith("xtb_")]:
            del echo_cfg[cfg_key]
        echo_cfg.pop("embedcharge_step", None)
        echo_cfg.pop("embedcharge_cutoff", None)

    return echo_cfg
354
+
355
+
356
def strip_inherited_keys(
    child_cfg: Dict[str, Any],
    base_cfg: Dict[str, Any],
    *,
    mode: str = "present",
) -> Dict[str, Any]:
    """Return child_cfg without inherited keys (for concise logs).

    Parameters
    ----------
    child_cfg : Dict[str, Any]
        The child configuration dictionary to trim.
    base_cfg : Dict[str, Any]
        The base configuration dictionary to compare against.
    mode : str
        - "present": Remove keys that exist in base_cfg regardless of value.
        - "same": Remove keys only when the value matches base_cfg.

    Returns
    -------
    Dict[str, Any]
        A new dictionary with inherited keys removed.

    Raises
    ------
    ValueError
        If *mode* is not "present" or "same".
    """
    if mode not in {"present", "same"}:
        raise ValueError(f"Unknown strip_inherited_keys mode: {mode}")

    def _inherited(key: Any, value: Any) -> bool:
        # A key is inherited when base_cfg defines it (mode="present"),
        # or defines it with an equal value (mode="same").
        if key not in base_cfg:
            return False
        return mode == "present" or base_cfg.get(key) == value

    return {k: v for k, v in child_cfg.items() if not _inherited(k, v)}
390
+
391
+
392
+ def _summarize_atom_indices(items: Sequence[Any]) -> str:
393
+ """Return a compact single-line summary for atom indices."""
394
+ if not items:
395
+ return ""
396
+
397
+ count = len(items)
398
+ if count <= 64:
399
+ return f"{count} atoms [{','.join(map(str, items))}]"
400
+
401
+ head = ",".join(map(str, items[:5]))
402
+ tail = ",".join(map(str, items[-5:]))
403
+ return f"{count} atoms [{head},...,{tail}]"
404
+
405
+
406
def format_freeze_atoms_for_echo(
    cfg: Dict[str, Any],
    *,
    key: str = "freeze_atoms",
) -> Dict[str, Any]:
    """
    Normalize freeze-atoms fields for concise CLI echo output.
    """
    echoed = dict(cfg)
    value = echoed.get(key)

    # Nothing to summarize, or already a human-readable string.
    if value is None or isinstance(value, str):
        return echoed

    try:
        as_list = list(value)
    except TypeError:
        # Non-iterable values are left untouched.
        return echoed

    echoed[key] = _summarize_atom_indices(as_list)
    return echoed
429
+
430
+
431
def format_elapsed(prefix: str, start_time: float, end_time: Optional[float] = None) -> str:
    """Return a formatted elapsed-time string with the provided ``prefix`` label."""
    stop = time.perf_counter() if end_time is None else end_time
    # Clamp to zero so clock skew never yields a negative duration.
    total_seconds = max(0.0, stop - start_time)
    whole_hours, remainder = divmod(total_seconds, 3600)
    whole_minutes, seconds = divmod(remainder, 60)
    # HH:MM:SS.mmm with zero-padded fields.
    return f"{prefix}: {int(whole_hours):02d}:{int(whole_minutes):02d}:{seconds:06.3f}"
438
+
439
+
440
def normalize_freeze_atoms(raw: Any) -> List[int]:
    """Normalize freeze_atoms values (string/list/iterable) into a list of integers.

    Parameters
    ----------
    raw : Any
        Input value that can be a string (e.g., "1,2,3" or "1 2 3"),
        a list of integers, or any iterable of numeric values.

    Returns
    -------
    List[int]
        List of integer indices.

    Examples
    --------
    >>> normalize_freeze_atoms("1, 2, 3")
    [1, 2, 3]
    >>> normalize_freeze_atoms([1, 2, 3])
    [1, 2, 3]
    >>> normalize_freeze_atoms(None)
    []
    """
    # NOTE: the redundant function-local ``import re`` was removed; the module
    # already imports ``re`` at top level (used by resolve_atom_spec_index).
    if raw is None:
        return []
    if isinstance(raw, str):
        # Accept any separator by pulling out signed integer tokens.
        return [int(tok) for tok in re.findall(r"-?\d+", raw)]
    try:
        return [int(i) for i in raw]
    except Exception:
        # Non-iterable / non-numeric input is treated as "no atoms".
        return []
474
+
475
+
476
def merge_freeze_atom_indices(
    geom_cfg: Dict[str, Any],
    *indices: _Iterable[int],
) -> List[int]:
    """Merge one or more iterables of indices into ``geom_cfg['freeze_atoms']``.

    Existing entries are preserved, duplicates removed, and the result sorted.
    The updated list is returned.
    """
    combined: set[int] = set(normalize_freeze_atoms(geom_cfg.get("freeze_atoms", None)))
    for extra in indices:
        combined.update(normalize_freeze_atoms(extra))
    ordered = sorted(combined)
    # Write the union back so downstream consumers see the merged list.
    geom_cfg["freeze_atoms"] = ordered
    return ordered
493
+
494
+
495
+ # =============================================================================
496
+ # Link-freezing helpers
497
+ # =============================================================================
498
+
499
+
500
def parse_pdb_coords(pdb_path):
    """Parse ATOM/HETATM records from *pdb_path* and separate link hydrogen (HL) atoms.

    Returns:
        A tuple (others, lkhs) where:
        - others: list of tuples (index, x, y, z, line) for all atoms except the
          'HL' atom of residue 'LKH'. ``index`` is the 0-based position in the
          atom sequence as loaded from the *first* MODEL (or the full file if no
          MODEL records are present).
        - lkhs: list of tuples (x, y, z, line) for atoms where residue name is
          'LKH' and atom name is 'HL' in the same MODEL selection.

    Notes
    -----
    - Coordinates are read from standard PDB columns:
      X: columns 31-38, Y: 39-46, Z: 47-54 (1-based indexing).
    - If multiple MODEL blocks are present, only the first model is considered,
      matching typical geom_loader behavior.
    """
    with open(pdb_path, "r") as f:
        lines = f.readlines()

    others = []
    lkhs = []
    # Track whether MODEL records exist and whether we are still inside the
    # first one; files without MODEL records are read in full.
    model_seen = False
    in_first_model = True
    atom_index = 0
    for line in lines:
        if line.startswith("MODEL"):
            # First MODEL opens the selection; any later MODEL closes it.
            if not model_seen:
                model_seen = True
                in_first_model = True
            else:
                in_first_model = False
            continue
        if line.startswith("ENDMDL"):
            # End of the first model: nothing further needs to be read.
            if model_seen and in_first_model:
                break
            continue
        if model_seen and not in_first_model:
            continue
        if not (line.startswith("ATOM") or line.startswith("HETATM")):
            continue

        # NOTE: the running index counts every ATOM/HETATM record, including
        # records skipped below for unparseable coordinates.
        current_index = atom_index
        atom_index += 1

        name = line[12:16].strip()
        resname = line[17:20].strip()
        try:
            x = float(line[30:38])
            y = float(line[38:46])
            z = float(line[46:54])
        except ValueError:
            # Malformed coordinate fields: skip this record entirely.
            continue

        if resname == "LKH" and name == "HL":
            lkhs.append((x, y, z, line))
        else:
            others.append((current_index, x, y, z, line))
    return others, lkhs
561
+
562
+
563
def nearest_index(point, pool):
    """Find the nearest point in *pool* to *point* using Euclidean distance.

    Args:
        point: Tuple (x, y, z) representing the query coordinate.
        pool: Iterable of tuples (index, x, y, z, line) to search.

    Returns:
        A tuple (index, distance) where:
        - index is the 0-based index of the nearest entry in *pool* (or -1 if *pool* is empty).
        - distance is the Euclidean distance to that entry (``inf`` if *pool* is empty).
    """
    px, py, pz = point
    nearest = -1
    min_sq = float("inf")
    # Compare squared distances; take the sqrt only once at the end.
    for entry_idx, qx, qy, qz, _line in pool:
        sq = (qx - px) ** 2 + (qy - py) ** 2 + (qz - pz) ** 2
        if sq < min_sq:
            min_sq = sq
            nearest = entry_idx
    return nearest, math.sqrt(min_sq)
584
+
585
+
586
def detect_freeze_links(pdb_path):
    """Identify link-parent atom indices for 'LKH'/'HL' link hydrogens.

    For each 'HL' atom in residue 'LKH', find the nearest atom among all other
    ATOM/HETATM records and return the indices of those nearest neighbors in the
    same atom ordering used by geometry loading (first MODEL if present).

    Args:
        pdb_path: Path to the input PDB file.

    Returns:
        List of 0-based indices into the full atom sequence (including any link H atoms)
        corresponding to the nearest neighbors (link parents). Returns an empty list if
        no LKH/HL atoms are present or if link hydrogens exist without any other atoms.
    """
    others, link_hs = parse_pdb_coords(pdb_path)
    if not link_hs or not others:
        return []

    parents = []
    for hx, hy, hz, _line in link_hs:
        # The nearest non-link atom is taken as the link parent.
        parent_idx, _dist = nearest_index((hx, hy, hz), others)
        if parent_idx >= 0:
            parents.append(parent_idx)
    return parents
612
+
613
+
614
def detect_freeze_links_logged(pdb_path: Path) -> List[int]:
    """Return link-parent indices and raise a user-facing error on failure."""
    try:
        detected = list(detect_freeze_links(pdb_path))
    except Exception as exc:
        # Surface detection failures as a CLI-friendly error.
        raise click.ClickException(
            f"[freeze-links] Failed to detect link parents for '{pdb_path.name}': {exc}"
        ) from exc
    return detected
622
+
623
+
624
def merge_detected_freeze_links(
    geom_cfg: Dict[str, Any],
    pdb_path: Path,
    *,
    prefix: str = "[freeze-links]",
) -> List[int]:
    """Detect link-parent atoms and merge them into ``geom_cfg['freeze_atoms']``."""
    detected = detect_freeze_links_logged(pdb_path)
    merged = merge_freeze_atom_indices(geom_cfg, detected)
    if merged:
        # Echo the merged union (may include pre-existing freeze atoms).
        click.echo(f"{prefix} Freeze atoms (0-based): {','.join(map(str, merged))}")
    return merged
636
+
637
+
638
def apply_layer_freeze_constraints(
    geom_cfg: Dict[str, Any],
    calc_cfg: Dict[str, Any],
    layer_info: Optional[Dict[str, Sequence[int]]],
    *,
    echo_fn: Optional[Callable[[str], None]] = None,
) -> List[int]:
    """Merge frozen-layer atoms into geometry/calculator freeze lists."""
    existing = set(normalize_freeze_atoms(geom_cfg.get("freeze_atoms")))
    layer_atoms = normalize_freeze_atoms((layer_info or {}).get("frozen_indices", []))

    combined = sorted(existing | set(layer_atoms))
    if layer_atoms:
        newly_added = len(set(combined) - existing)
        if echo_fn is not None:
            echo_fn(
                f"[layer] Applied freeze constraints from frozen layer: "
                f"total={len(combined)} (added_from_layer={newly_added})"
            )

    # Keep geometry and calculator freeze lists in sync.
    geom_cfg["freeze_atoms"] = combined
    calc_cfg["freeze_atoms"] = combined
    return combined
664
+
665
+
666
def deep_update(dst: Dict[str, Any], src: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Recursively update mapping *dst* with *src*, returning *dst*.
    """
    if src:
        for key, new_val in src.items():
            old_val = dst.get(key)
            if isinstance(new_val, dict) and isinstance(old_val, dict):
                # Both sides are dicts: merge recursively instead of replacing.
                deep_update(old_val, new_val)
            else:
                dst[key] = new_val
    return dst
676
+
677
+
678
def collect_single_option_values(
    argv: _Sequence[str],
    names: _Sequence[str],
    label: str,
) -> List[str]:
    """Collect values following a flag that must appear at most once."""
    values: List[str] = []
    occurrences = 0
    pos = 0
    total = len(argv)
    while pos < total:
        if argv[pos] in names:
            occurrences += 1
            # Consume every following token up to the next flag-like token.
            pos += 1
            while pos < total and not argv[pos].startswith("-"):
                values.append(argv[pos])
                pos += 1
        else:
            pos += 1
    if occurrences > 1:
        raise click.BadParameter(
            f"Use a single {label} followed by multiple values; repeated flags are not accepted."
        )
    return values
703
+
704
+
705
def load_pdb_atom_metadata(pdb_path: Path) -> List[Dict[str, Any]]:
    """Return per-atom metadata (serial, name, resname, resseq, element) in file order."""
    atoms: List[Dict[str, Any]] = []
    with open(pdb_path, "r") as f:
        for line in f:
            if not (line.startswith("ATOM") or line.startswith("HETATM")):
                continue

            # Fixed PDB columns (1-based): serial 7-11, name 13-16,
            # resname 18-20, resseq 23-26, element 77-78.
            serial_txt = line[6:11].strip()
            resseq_txt = line[22:26].strip()
            atom_name = line[12:16].strip()
            res_name = line[17:20].strip()
            element_txt = line[76:78].strip()
            is_hetatm = line.startswith("HETATM")

            # Unparseable numeric fields are stored as None rather than failing.
            try:
                serial = int(serial_txt) if serial_txt else None
            except ValueError:
                serial = None
            try:
                resseq = int(resseq_txt) if resseq_txt else None
            except ValueError:
                resseq = None

            if not element_txt:
                # Missing element column: fall back to inference from the
                # atom/residue names via guess_element (defined elsewhere).
                inferred = guess_element(atom_name, res_name, is_hetatm)
                element_txt = inferred or ""

            atoms.append(
                {
                    "serial": serial,
                    "name": atom_name,
                    "resname": res_name,
                    "resseq": resseq,
                    "element": element_txt,
                }
            )
    return atoms
743
+
744
+
745
def resolve_atom_spec_index(spec: str, atom_meta: _Sequence[Dict[str, Any]]) -> int:
    """Resolve an atom selector string into a 0-based atom index using PDB metadata."""
    # Tokenize on commas/whitespace/slashes/backticks/backslashes; spaces are
    # first replaced by commas so mixed separators all collapse the same way.
    tokens = [t for t in re.split(r"[\s/`,\\]+", spec.strip().replace(" ", ",")) if t]
    if len(tokens) != 3:
        raise ValueError(
            f"Atom spec '{spec}' must have exactly 3 fields (resname, resseq, atomname)."
        )

    # Phase 1: order-insensitive match — each token must appear among the
    # atom's {resname, resseq, atomname} field set.
    tokens_upper = [t.upper() for t in tokens]
    matches: List[int] = []
    for idx, meta in enumerate(atom_meta):
        resname = (meta.get("resname") or "").strip().upper()
        resseq = meta.get("resseq")
        atom = (meta.get("name") or "").strip().upper()
        if resseq is None:
            continue
        fields = {resname, str(resseq), atom}
        if all(tok in fields for tok in tokens_upper):
            matches.append(idx)

    if len(matches) == 1:
        return matches[0]
    if len(matches) > 1:
        # Ambiguous specs are rejected rather than guessed at.
        raise ValueError(
            f"Atom spec '{spec}' matches {len(matches)} atoms; use an explicit atom index."
        )

    # Phase 2: ordered fallback — interpret tokens strictly as
    # (resname, resseq, atomname) when the unordered match found nothing.
    resname, resseq_str, atom = tokens_upper
    if not resseq_str.isdigit():
        raise ValueError(
            f"Atom spec '{spec}' could not be resolved and residue number '{tokens[1]}' is not numeric."
        )
    resseq_int = int(resseq_str)
    ordered_matches = [
        idx
        for idx, meta in enumerate(atom_meta)
        if (meta.get("resname") or "").strip().upper() == resname
        and meta.get("resseq") == resseq_int
        and (meta.get("name") or "").strip().upper() == atom
    ]
    if len(ordered_matches) == 1:
        return ordered_matches[0]
    if len(ordered_matches) > 1:
        raise ValueError(
            f"Atom spec '{spec}' matches {len(ordered_matches)} atoms after ordered fallback; "
            "use an explicit atom index."
        )

    raise ValueError(f"Atom spec '{spec}' did not match any atom.")
794
+
795
+
796
def atom_label_from_meta(atom_meta: _Sequence[Dict[str, Any]], index: int) -> str:
    """Return a 'RESNAME-RESSEQ-ATOM' label for *index*, or 'idxN' if out of range."""
    if not (0 <= index < len(atom_meta)):
        return f"idx{index}"
    entry = atom_meta[index]

    def _field(raw: Any) -> str:
        # Missing/blank fields are rendered as '?'.
        return (raw or "?").strip() or "?"

    residue = _field(entry.get("resname"))
    atom_name = _field(entry.get("name"))
    seq = entry.get("resseq")
    seq_text = str(seq) if seq is not None else "?"
    return f"{residue}-{seq_text}-{atom_name}"
805
+
806
+
807
def axis_label_csv(
    axis_name: str,
    i_idx: int,
    j_idx: int,
    one_based: bool,
    atom_meta: Optional[_Sequence[Dict[str, Any]]] = None,
    pair_raw: Optional[Tuple[Any, Any, float, float]] = None,
) -> str:
    """Build an axis column label; atom-spec inputs get name-based labels."""
    used_spec_strings = bool(pair_raw) and (
        isinstance(pair_raw[0], str) or isinstance(pair_raw[1], str)
    )
    if used_spec_strings and atom_meta:
        # The user selected atoms by spec string: label with residue/atom names.
        left = atom_label_from_meta(atom_meta, i_idx)
        right = atom_label_from_meta(atom_meta, j_idx)
        return f"{axis_name}_{left}_{right}_A"

    # Plain integer selection: display in the caller's index base.
    offset = 1 if one_based else 0
    return f"{axis_name}_{i_idx + offset}_{j_idx + offset}_A"
822
+
823
+
824
def axis_label_html(label: str) -> str:
    """Convert an '<axis>_<i>_<j>_A' CSV label into a display string with Å units."""
    pieces = label.split("_")
    # Anything that does not follow the CSV label convention passes through.
    if len(pieces) < 4 or pieces[-1] != "A":
        return label
    axis, left, right = pieces[0], pieces[1], pieces[2]
    return f"{axis} ({left},{right}) (Å)"
832
+
833
+
834
def resolve_scan_index(
    value: Any,
    *,
    one_based: bool,
    atom_meta: Optional[_Sequence[Dict[str, Any]]],
    context: str,
) -> int:
    """Resolve an index or atom-spec string for scan lists with consistent errors."""
    if isinstance(value, Integral):
        resolved = int(value) - (1 if one_based else 0)
        if resolved < 0:
            raise click.BadParameter(
                f"Negative atom index after base conversion in {context}: {resolved} (0-based expected)."
            )
        return resolved

    if isinstance(value, str):
        if not atom_meta:
            raise click.BadParameter(
                f"{context} uses a string atom spec, but no PDB metadata is available."
            )
        try:
            # Spec resolution always yields a 0-based index.
            return resolve_atom_spec_index(value, atom_meta)
        except ValueError as exc:
            raise click.BadParameter(f"{context} {exc}")

    raise click.BadParameter(f"{context} must be an int index or atom spec string.")
861
+
862
+
863
def parse_scan_list_triples(
    raw: str,
    *,
    one_based: bool,
    atom_meta: Optional[_Sequence[Dict[str, Any]]],
    option_name: str,
    return_one_based: bool = False,
) -> Tuple[List[Tuple[int, int, float]], List[Tuple[Any, Any, float]]]:
    """Parse --scan-lists entries into indices (0-based by default).

    Accepts both 3-tuples ``(i, j, target)`` and 4-tuples
    ``(i, j, start, end)`` for bidirectional scans. 4-tuples are
    expanded into two 3-tuple stages (initial→start, then initial→end)
    by the caller in scan.py.

    The returned *parsed* list contains tuples of length 3 **or** 4:
    ``(i, j, target)`` or ``(i, j, start, end)``.
    """
    # The CLI value is a Python literal, parsed safely (no eval).
    try:
        obj = ast.literal_eval(raw)
    except Exception as e:
        raise click.BadParameter(f"Invalid literal for {option_name}: {e}")

    if not isinstance(obj, (list, tuple)):
        raise click.BadParameter(f"{option_name} must be a list/tuple of (i,j,target) or (i,j,start,end).")

    parsed: list = []
    for entry_idx, t in enumerate(obj, start=1):
        # Shape checks: trailing element(s) must be numeric distances.
        is_3 = (
            isinstance(t, (list, tuple))
            and len(t) == 3
            and isinstance(t[2], Real)
        )
        is_4 = (
            isinstance(t, (list, tuple))
            and len(t) == 4
            and isinstance(t[2], Real)
            and isinstance(t[3], Real)
        )
        if not (is_3 or is_4):
            raise click.BadParameter(
                f"{option_name} entry {entry_idx} must be (i,j,target) or (i,j,start,end): got {t}"
            )

        # Atom fields may be int indices or atom-spec strings; both resolve
        # to 0-based indices via resolve_scan_index.
        i = resolve_scan_index(
            t[0],
            one_based=one_based,
            atom_meta=atom_meta,
            context=f"{option_name} entry {entry_idx} (i)",
        )
        j = resolve_scan_index(
            t[1],
            one_based=one_based,
            atom_meta=atom_meta,
            context=f"{option_name} entry {entry_idx} (j)",
        )
        if return_one_based:
            # Optionally shift back to 1-based for display/serialization.
            i += 1
            j += 1
        if is_4:
            parsed.append((i, j, float(t[2]), float(t[3])))
        else:
            parsed.append((i, j, float(t[2])))

    return parsed, list(obj)
928
+
929
+
930
def parse_dist_freeze_list(
    raw: str,
    *,
    one_based: bool,
    atom_meta: Optional[_Sequence[Dict[str, Any]]],
    option_name: str = "--dist-freeze",
) -> List[Tuple[int, int, Optional[float]]]:
    """Parse ``--dist-freeze`` entries: ``(i,j)`` or ``(i,j,target_A)``.

    Uses the same :func:`resolve_scan_index` as ``--scan-lists``, so string
    atom specs (e.g. ``'A:SER123:OG'``) are supported when PDB metadata is
    available.
    """
    # The CLI value is a Python literal, parsed safely (no eval).
    try:
        obj = ast.literal_eval(raw)
    except Exception as e:
        raise click.BadParameter(f"Invalid literal for {option_name}: {e}")

    if not isinstance(obj, (list, tuple)):
        raise click.BadParameter(f"{option_name} must be a list/tuple of (i,j) or (i,j,target).")

    # Single tuple → wrap in list
    if obj and not isinstance(obj[0], (list, tuple)):
        obj = [obj]

    parsed: List[Tuple[int, int, Optional[float]]] = []
    for entry_idx, t in enumerate(obj, start=1):
        if not (isinstance(t, (list, tuple)) and len(t) in (2, 3)):
            raise click.BadParameter(
                f"{option_name} entry {entry_idx} must be (i,j) or (i,j,target): got {t}"
            )
        i = resolve_scan_index(
            t[0], one_based=one_based, atom_meta=atom_meta,
            context=f"{option_name} entry {entry_idx} (i)",
        )
        j = resolve_scan_index(
            t[1], one_based=one_based, atom_meta=atom_meta,
            context=f"{option_name} entry {entry_idx} (j)",
        )
        # Without a third element, target stays None (freeze at current distance).
        target: Optional[float] = None
        if len(t) == 3:
            if not isinstance(t[2], Real):
                raise click.BadParameter(
                    f"Target distance must be numeric in {option_name} entry {entry_idx}: {t}"
                )
            target = float(t[2])
            if target <= 0.0:
                raise click.BadParameter(
                    f"Target distance must be > 0 in {option_name} entry {entry_idx}: {t}"
                )
        parsed.append((i, j, target))
    return parsed
982
+
983
+
984
def parse_dist_freeze_spec(
    spec_path: Path,
    *,
    one_based_default: bool,
    atom_meta: Optional[_Sequence[Dict[str, Any]]],
    option_name: str = "--dist-freeze",
) -> List[Tuple[int, int, Optional[float]]]:
    """Parse a YAML/JSON dist-freeze spec file.

    Expected format::

        constraints:          # or "pairs" / "stages"
          - [1, 5, 1.4]       # (i, j, target_A) — target optional
          - [2, 6]            # freeze at current distance
        one_based: true       # optional, defaults to CLI value
    """
    spec_cfg = _load_scan_spec_root(spec_path, option_name=option_name)
    # Accept any of three field names; the first one present wins.
    key, raw_list = _first_spec_field(spec_cfg, ("constraints", "pairs", "stages"))
    if key is None:
        raise click.BadParameter(
            f"{option_name} spec must define 'constraints', 'pairs', or 'stages'."
        )
    if not isinstance(raw_list, (list, tuple)) or len(raw_list) == 0:
        raise click.BadParameter(
            f"{option_name} field '{key}' must be a non-empty list."
        )

    one_based = _spec_one_based(
        spec_cfg.get("one_based"), default=one_based_default, option_name=option_name,
    )
    # Re-serialize through repr() so the shared string parser handles the
    # YAML-loaded list exactly like a CLI literal.
    return parse_dist_freeze_list(
        repr(list(raw_list)),
        one_based=one_based,
        atom_meta=atom_meta,
        option_name=f"{option_name} {key}",
    )
1020
+
1021
+
1022
def parse_scan_list_quads(
    raw: str,
    *,
    expected_len: int,
    one_based: bool,
    atom_meta: Optional[_Sequence[Dict[str, Any]]],
    option_name: str,
) -> Tuple[List[Tuple[int, int, float, float]], List[Tuple[Any, Any, float, float]]]:
    """Parse --scan-lists quadruples into 0-based indices."""
    # The CLI value is a Python literal, parsed safely (no eval).
    try:
        obj = ast.literal_eval(raw)
    except Exception as e:
        raise click.BadParameter(f"Invalid literal for {option_name}: {e}")

    # The quad form requires an exact count (one quadruple per scan axis).
    if not (isinstance(obj, (list, tuple)) and len(obj) == expected_len):
        quads = ",".join([f"(i{n},j{n},low{n},high{n})" for n in range(1, expected_len + 1)])
        raise click.BadParameter(
            f"{option_name} must contain exactly {expected_len} quadruples: [{quads}]"
        )

    parsed: List[Tuple[int, int, float, float]] = []
    for entry_idx, q in enumerate(obj, start=1):
        if not (
            isinstance(q, (list, tuple))
            and len(q) == 4
            and isinstance(q[2], Real)
            and isinstance(q[3], Real)
        ):
            raise click.BadParameter(f"{option_name} entry must be (i,j,low,high): got {q}")

        # Atom fields may be int indices or atom-spec strings; both resolve
        # to 0-based indices via resolve_scan_index.
        i = resolve_scan_index(
            q[0],
            one_based=one_based,
            atom_meta=atom_meta,
            context=f"{option_name} entry {entry_idx} (i)",
        )
        j = resolve_scan_index(
            q[1],
            one_based=one_based,
            atom_meta=atom_meta,
            context=f"{option_name} entry {entry_idx} (j)",
        )
        parsed.append((i, j, float(q[2]), float(q[3])))

    # Distance ranges must be strictly positive.
    for i, j, low, high in parsed:
        if low <= 0.0 or high <= 0.0:
            raise click.BadParameter(f"Distances must be positive: {(i, j, low, high)}")

    return parsed, list(obj)
1071
+
1072
+
1073
def _load_scan_spec_root(
    spec_path: Path,
    *,
    option_name: str = "--scan-lists",
) -> Mapping[str, Any]:
    """Load a scan spec (YAML/JSON) and ensure mapping root."""
    # yaml.safe_load also accepts JSON, so one loader covers both formats.
    try:
        with open(spec_path, "r", encoding="utf-8") as handle:
            data = yaml.safe_load(handle)
    except Exception as exc:
        raise click.BadParameter(
            f"Failed to parse {option_name} file '{spec_path}': {exc}"
        )

    if data is None:
        raise click.BadParameter(f"{option_name} file '{spec_path}' is empty.")
    if not isinstance(data, Mapping):
        raise click.BadParameter(
            f"{option_name} file '{spec_path}' must have a mapping at the YAML/JSON root."
        )
    return data
1094
+
1095
+
1096
+ def _spec_one_based(
1097
+ value: Any,
1098
+ *,
1099
+ default: bool,
1100
+ option_name: str = "--scan-lists",
1101
+ ) -> bool:
1102
+ """Resolve one_based value from spec with CLI fallback."""
1103
+ if value is None:
1104
+ return bool(default)
1105
+ if isinstance(value, bool):
1106
+ return value
1107
+ if isinstance(value, str):
1108
+ key = value.strip().lower()
1109
+ if key in {"1", "true", "yes", "y", "on"}:
1110
+ return True
1111
+ if key in {"0", "false", "no", "n", "off"}:
1112
+ return False
1113
+ raise click.BadParameter(
1114
+ f"{option_name} field 'one_based' must be a boolean (true/false)."
1115
+ )
1116
+
1117
+
1118
+ def _first_spec_field(
1119
+ spec_cfg: Mapping[str, Any],
1120
+ candidates: _Sequence[str],
1121
+ ) -> Tuple[Optional[str], Any]:
1122
+ for key in candidates:
1123
+ if key in spec_cfg:
1124
+ return key, spec_cfg[key]
1125
+ return None, None
1126
+
1127
+
1128
def is_scan_spec_file(value: str) -> bool:
    """Return True if *value* looks like an existing YAML/JSON scan spec file."""
    candidate = Path(value)
    if not candidate.is_file():
        return False
    # Only YAML/JSON extensions qualify as spec files.
    return candidate.suffix.lower() in {".yaml", ".yml", ".json"}
1132
+
1133
+
1134
def parse_scan_spec_stages(
    spec_path: Path,
    *,
    one_based_default: bool,
    atom_meta: Optional[_Sequence[Dict[str, Any]]],
    option_name: str = "--scan-lists",
) -> Tuple[List[List[Tuple[int, int, float]]], bool]:
    """Parse staged 1D scan spec into 0-based stage triples."""
    spec_cfg = _load_scan_spec_root(spec_path, option_name=option_name)
    stages_key, stages_raw = _first_spec_field(spec_cfg, ("stages",))
    if stages_key is None:
        raise click.BadParameter(f"{option_name} must define 'stages'.")
    if not isinstance(stages_raw, (list, tuple)) or len(stages_raw) == 0:
        raise click.BadParameter(f"{option_name} field '{stages_key}' must be a non-empty list.")

    one_based = _spec_one_based(
        spec_cfg.get("one_based"), default=one_based_default, option_name=option_name
    )
    stages: List[List[Tuple[int, int, float]]] = []
    for stage_idx, stage_raw in enumerate(stages_raw, start=1):
        if not isinstance(stage_raw, (list, tuple)):
            raise click.BadParameter(
                f"{option_name} {stages_key}[{stage_idx}] must be a list of (i,j,target) entries."
            )
        # Re-serialize through repr() so the shared string parser handles the
        # YAML-loaded stage exactly like a CLI literal.
        parsed, _ = parse_scan_list_triples(
            repr(list(stage_raw)),
            one_based=one_based,
            atom_meta=atom_meta,
            option_name=f"{option_name} {stages_key}[{stage_idx}]",
        )
        if not parsed:
            raise click.BadParameter(
                f"{option_name} {stages_key}[{stage_idx}] must contain at least one (i,j,target) triple."
            )
        # Stage targets must be strictly positive distances.
        for i, j, target in parsed:
            if target <= 0.0:
                raise click.BadParameter(
                    f"Non-positive target distance in {option_name} {stages_key}[{stage_idx}]: {(i, j, target)}."
                )
        stages.append(parsed)
    return stages, one_based
1175
+
1176
+
1177
def parse_scan_spec_quads(
    spec_path: Path,
    *,
    expected_len: int,
    one_based_default: bool,
    atom_meta: Optional[_Sequence[Dict[str, Any]]],
    option_name: str = "--scan-lists",
) -> Tuple[List[Tuple[int, int, float, float]], List[Tuple[Any, Any, float, float]], bool]:
    """Parse 2D/3D scan spec into 0-based quad tuples."""
    spec_cfg = _load_scan_spec_root(spec_path, option_name=option_name)
    pairs_key, pairs_raw = _first_spec_field(spec_cfg, ("pairs",))
    if pairs_key is None:
        raise click.BadParameter(f"{option_name} must define 'pairs'.")
    if not isinstance(pairs_raw, (list, tuple)):
        raise click.BadParameter(f"{option_name} field '{pairs_key}' must be a list.")

    one_based = _spec_one_based(
        spec_cfg.get("one_based"), default=one_based_default, option_name=option_name
    )
    # Re-serialize through repr() so the shared string parser handles the
    # YAML-loaded pairs exactly like a CLI literal.
    parsed, raw_pairs = parse_scan_list_quads(
        repr(list(pairs_raw)),
        expected_len=expected_len,
        one_based=one_based,
        atom_meta=atom_meta,
        option_name=f"{option_name} {pairs_key}",
    )
    return parsed, raw_pairs, one_based
1204
+
1205
+
1206
# Column header matching the aligned row layout of format_pdb_atom_metadata().
PDB_ATOM_META_HEADER = f"{'id':>5} {'atom':<4} {'res':<4} {'resid':>4} {'el':<2}"
1207
+
1208
+
1209
def format_pdb_atom_metadata(atom_meta: _Sequence[Dict[str, Any]], index: int) -> str:
    """Format metadata for atom *index* as aligned text: serial name resname resseq element."""
    default_serial = index + 1
    if not (0 <= index < len(atom_meta)):
        # Unknown atom: placeholder row with a positional serial.
        return f"{default_serial:>5} {'?':<4} {'?':<4} {'?':>4} {'?':<2}"

    entry = atom_meta[index]
    serial_val = entry.get("serial") or default_serial
    atom_name = entry.get("name") or "?"
    residue = entry.get("resname") or "?"
    seq = entry.get("resseq")
    seq_text = str(seq) if seq is not None else "?"
    element_text = (entry.get("element") or "?").strip() or "?"

    return f"{serial_val:>5} {atom_name:<4} {residue:<4} {seq_text:>4} {element_text:<2}"
1224
+
1225
+
1226
def normalize_choice(
    value: str,
    *,
    param: str,
    alias_groups: Sequence[Tuple[Sequence[str], str]],
    allowed_hint: str,
) -> str:
    """Normalize a mode choice using alias groups and raise error on failure.

    Parameters
    ----------
    value : str
        The value to normalize (matched case-insensitively, whitespace-trimmed).
    param : str
        Parameter name for error messages.
    alias_groups : Sequence[Tuple[Sequence[str], str]]
        Sequence of (aliases, canonical) pairs where aliases is a sequence of strings.
    allowed_hint : str
        Description of allowed values for error messages.

    Returns
    -------
    str
        The canonical value corresponding to the matched alias.

    Raises
    ------
    click.BadParameter
        If the value does not match any alias.
    """
    needle = (value or "").strip().lower()
    for aliases, canonical in alias_groups:
        if needle in {alias.lower() for alias in aliases}:
            return canonical

    hint = allowed_hint.strip()
    detail = f" Allowed: {hint}." if hint else ""
    raise click.BadParameter(f"Unknown value for {param} '{value}'.{detail}")
1264
+
1265
+
1266
+ def _get_mapping_section(cfg: Mapping[str, Any], path: _Sequence[str]) -> Optional[Dict[str, Any]]:
1267
+ cur: Any = cfg
1268
+ for key in path:
1269
+ if not isinstance(cur, Mapping):
1270
+ return None
1271
+ cur = cur.get(key)
1272
+ if cur is None:
1273
+ return None
1274
+ return cur if isinstance(cur, dict) else None
1275
+
1276
+
1277
def apply_yaml_overrides(
    yaml_cfg: Mapping[str, Any],
    overrides: _Sequence[Tuple[Dict[str, Any], _Sequence[_Sequence[str]]]],
) -> None:
    """Apply YAML overrides to multiple target dictionaries.

    Parameters
    ----------
    yaml_cfg : Mapping[str, Any]
        Parsed YAML configuration (root-level mapping).
    overrides : Sequence[Tuple[Dict[str, Any], Sequence[Sequence[str]]]]
        Each entry consists of the target dictionary to update followed by one
        or more candidate key paths. The first existing path is used, e.g.::

            apply_yaml_overrides(
                yaml_cfg,
                [
                    (geom_cfg, (("geom",),)),
                    (lbfgs_cfg, (("stopt", "lbfgs"), ("lbfgs",))),
                ],
            )

    This mirrors the previous ``deep_update(..., yaml_cfg.get(...))`` pattern
    while centralizing the shared logic.
    """
    for target, candidate_paths in overrides:
        # Use the first candidate path that resolves to a mapping section.
        for candidate in candidate_paths:
            section = _get_mapping_section(yaml_cfg, tuple(candidate))
            if section is None:
                continue
            deep_update(target, section)
            break
1309
+
1310
+
1311
def yaml_section_has_key(
    yaml_cfg: Mapping[str, Any],
    paths: _Sequence[_Sequence[str]],
    key: str,
) -> bool:
    """Return ``True`` when any candidate YAML section explicitly defines *key*."""
    sections = (_get_mapping_section(yaml_cfg, tuple(p)) for p in paths)
    return any(
        isinstance(section, Mapping) and key in section for section in sections
    )
1322
+
1323
+
1324
def load_yaml_dict(path: Optional[Path]) -> Dict[str, Any]:
    """Load a YAML file whose root must be a mapping.

    Parameters
    ----------
    path : Optional[Path]
        File to read. A falsy value (``None``, empty path) yields ``{}``.

    Returns
    -------
    Dict[str, Any]
        Parsed mapping; empty when *path* is falsy or the file is empty.

    Raises
    ------
    ValueError
        If the YAML root is not a mapping.
    """
    if not path:
        return {}

    # Explicit encoding avoids platform-dependent text decoding of YAML files.
    with open(path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f) or {}

    if not isinstance(data, dict):
        raise ValueError(f"YAML root must be a mapping, got: {type(data)}")

    return data
1338
+
1339
+
1340
+ # =============================================================================
1341
+ # Plotly: Energy diagram builder
1342
+ # =============================================================================
1343
def build_energy_diagram(
    energies: Sequence[float],
    labels: Sequence[str],
    ylabel: str = "ΔE",
    baseline: bool = False,
    showgrid: bool = False,
) -> go.Figure:
    """Render a reaction energy diagram as a Plotly figure.

    Parameters
    ----------
    energies : Sequence[float]
        One energy per state (same unit); plotted without conversion.
    labels : Sequence[str]
        One label per state, e.g. ``["R", "TS1", "IM1", "TS2", "P"]``.
        Must match ``energies`` in length.
    ylabel : str, optional
        Y-axis title, e.g. "ΔE" or "ΔG". Defaults to ``"ΔE"``.
    baseline : bool, optional
        When ``True``, draw a dotted horizontal line at the first state's
        energy across the plot.
    showgrid : bool, optional
        When ``True``, show grid lines on both axes.

    Returns
    -------
    plotly.graph_objs.Figure
        The assembled diagram.

    Raises
    ------
    ValueError
        If ``energies`` is empty or the two sequences differ in length.

    Notes
    -----
    Each state is a thick horizontal segment; neighbouring states are joined
    by dotted diagonals from the right end of one segment to the left end of
    the next. Segment width shrinks as the number of states grows so gaps
    remain; X ticks are centered on each state and labeled via ``labels``.
    """
    if len(energies) == 0:
        raise ValueError("`energies` must contain at least one value.")
    if len(energies) != len(labels):
        raise ValueError("`energies` and `labels` must have the same length.")

    levels = [float(e) for e in energies]
    n_states = len(levels)

    # Shared style constants.
    AXIS_WIDTH = 3
    FONT_SIZE = 18
    AXIS_TITLE_SIZE = 20
    HLINE_WIDTH = 6        # thick horizontal state segments
    CONNECTOR_WIDTH = 2    # dotted diagonal connectors
    LINE_COLOR = "#1C1C1C"
    GRID_COLOR = "lightgrey"

    # X geometry: segment centers at 0.5, 1.5, 2.5, ...; segment width shrinks
    # with the state count, clamped to [0.35, 0.85]
    # (e.g. n=5 -> 0.7, n=10 -> 0.5, n>=20 -> 0.35).
    centers = [i + 0.5 for i in range(n_states)]
    seg_width = min(0.85, max(0.35, 0.90 - 0.04 * n_states))
    half_seg = seg_width / 2.0
    left_edges = [c - half_seg for c in centers]
    right_edges = [c + half_seg for c in centers]

    fig = go.Figure()

    # Optional dotted reference line at the first state's level.
    if baseline:
        fig.add_trace(
            go.Scatter(
                x=[left_edges[0], right_edges[-1]],
                y=[levels[0], levels[0]],
                mode="lines",
                line=dict(color=GRID_COLOR, dash="dot", width=2),
                hoverinfo="skip",
                showlegend=False,
            )
        )

    # One thick horizontal segment per state.
    for idx, (level, label) in enumerate(zip(levels, labels)):
        fig.add_trace(
            go.Scatter(
                x=[left_edges[idx], right_edges[idx]],
                y=[level, level],
                mode="lines",
                line=dict(color=LINE_COLOR, width=HLINE_WIDTH),
                hovertemplate=f"{label}: %{{y:.6f}}<extra></extra>",
                showlegend=False,
            )
        )

    # Dotted diagonals joining neighbouring states.
    for idx in range(n_states - 1):
        fig.add_trace(
            go.Scatter(
                x=[right_edges[idx], left_edges[idx + 1]],
                y=[levels[idx], levels[idx + 1]],
                mode="lines",
                line=dict(color=LINE_COLOR, width=CONNECTOR_WIDTH, dash="dot"),
                hoverinfo="skip",
                showlegend=False,
            )
        )

    # X range: small margin beyond the outermost segments.
    xpad = max(0.08, 0.15 * (1.0 - seg_width))
    # Y range: asymmetric padding (extra headroom above the highest state).
    lo, hi = min(levels), max(levels)
    span = max(1e-6, hi - lo)  # guard against a zero span when all equal

    def _axis(**extra):
        # Shared axis styling; axis-specific keys arrive via **extra.
        base = dict(
            showline=True,
            linewidth=AXIS_WIDTH,
            linecolor=LINE_COLOR,
            mirror=True,
            ticks="inside",
            tickwidth=AXIS_WIDTH,
            tickcolor=LINE_COLOR,
            tickfont=dict(size=FONT_SIZE, color=LINE_COLOR),
            showgrid=showgrid,
            gridcolor=GRID_COLOR,
            gridwidth=0.5,
            zeroline=False,
        )
        base.update(extra)
        return base

    fig.update_layout(
        xaxis=_axis(
            range=[left_edges[0] - xpad, right_edges[-1] + xpad],
            tickmode="array",
            tickvals=centers,
            ticktext=list(labels),
            title=dict(text="", font=dict(size=AXIS_TITLE_SIZE, color=LINE_COLOR)),
        ),
        yaxis=_axis(
            range=[lo - 0.10 * span, hi + 0.20 * span],
            title=dict(text=ylabel, font=dict(size=AXIS_TITLE_SIZE, color=LINE_COLOR)),
        ),
        plot_bgcolor="white",
        paper_bgcolor="white",
        margin=dict(l=80, r=40, t=40, b=80),
    )

    return fig
1520
+
1521
+
1522
+ # =============================================================================
1523
+ # Coordinate conversion utilities
1524
+ # =============================================================================
1525
def convert_xyz_to_pdb(xyz_path: Path, ref_pdb_path: Path, out_pdb_path: Path) -> None:
    """Write the frames of *xyz_path* as a PDB using *ref_pdb_path*'s topology.

    The reference PDB supplies atom ordering/topology; coordinates come from
    the XYZ file. Multi-frame inputs are written as MODEL/ENDMDL blocks: the
    first frame creates/overwrites *out_pdb_path*, later frames are appended.
    When element symbols disagree but atom counts match, the XYZ symbols are
    trusted (pysisyphus uses proper element detection) and patched onto the
    reference atoms so the PDB topology is preserved.

    Args:
        xyz_path: Path to an XYZ file (single or multi-frame).
        ref_pdb_path: Path to a reference PDB providing atom ordering/topology.
        out_pdb_path: Destination PDB file to write.

    Raises:
        ValueError: If the XYZ contains no frames or the atom counts differ.
    """
    reference = read(ref_pdb_path)                    # single-frame topology
    frames = read(xyz_path, index=":", format="xyz")  # every XYZ frame
    if not frames:
        raise ValueError(f"No frames found in {xyz_path}.")

    ref_symbols = reference.get_chemical_symbols()

    for frame_no, frame in enumerate(frames):
        symbols = frame.get_chemical_symbols()
        positions = frame.get_positions()

        if symbols != ref_symbols:
            if len(symbols) != len(ref_symbols):
                raise ValueError(
                    "Atom count mismatch between XYZ and PDB; "
                    f"XYZ has {len(symbols)} atoms, PDB has {len(ref_symbols)}."
                )
            # Same atom count: the PDB element columns are likely missing or
            # wrong, so overwrite them with the XYZ symbols.
            reference.set_chemical_symbols(symbols)
            ref_symbols = symbols

        snapshot = reference.copy()
        snapshot.set_positions(positions)
        # First frame creates/overwrites the file; subsequent frames are
        # appended as MODEL/ENDMDL blocks.
        write(out_pdb_path, snapshot, append=frame_no > 0)
1572
+
1573
+
1574
+ # =============================================================================
1575
+ # Global toggle for XYZ/TRJ → PDB conversion
1576
+ # =============================================================================
1577
# Module-level default: XYZ/TRJ -> PDB conversions stay enabled until a caller
# disables them via set_convert_file_enabled().
_CONVERT_FILES_ENABLED: bool = True
1578
+
1579
+
1580
def set_convert_file_enabled(enabled: bool) -> None:
    """Toggle XYZ/TRJ -> PDB output conversion for the whole process.

    The value is coerced to ``bool`` and stored in the module-level
    ``_CONVERT_FILES_ENABLED`` flag.
    """
    global _CONVERT_FILES_ENABLED
    _CONVERT_FILES_ENABLED = True if enabled else False
1584
+
1585
+
1586
def is_convert_file_enabled() -> bool:
    """Check if convert-files is globally enabled.

    Returns the current value of the module-level ``_CONVERT_FILES_ENABLED``
    flag, which is toggled via ``set_convert_file_enabled``.
    """
    return _CONVERT_FILES_ENABLED
1589
+
1590
+
1591
def convert_xyz_like_outputs(
    xyz_path: Path,
    ref_pdb_path: Optional[Path],
    out_pdb_path: Optional[Path] = None,
    *,
    context: str = "outputs",
    on_error: str = "raise",
) -> bool:
    """Best-effort XYZ -> PDB conversion that honours the global toggle.

    Returns ``True`` on success. Returns ``False`` when conversion is
    globally disabled, when either PDB path is missing, or when the
    conversion fails with ``on_error="warn"`` (a warning is echoed to
    stderr). With the default ``on_error="raise"``, failures surface as
    ``click.ClickException``.
    """
    if not _CONVERT_FILES_ENABLED:
        return False
    if ref_pdb_path is None or out_pdb_path is None:
        return False
    try:
        convert_xyz_to_pdb(xyz_path, ref_pdb_path, out_pdb_path)
    except Exception as exc:
        if on_error != "warn":
            raise click.ClickException(
                f"[convert] Failed to convert {context}: {exc}"
            ) from exc
        click.echo(f"[convert] WARNING: Failed to convert {context}: {exc}", err=True)
        return False
    return True
1616
+
1617
+
1618
def pdb_keys_from_line(line: str) -> Tuple[Tuple, Tuple]:
    """Build two lookup keys from a PDB ATOM/HETATM record.

    Returns:
        key_full: (chain, resseq, icode, resname, atomname, altloc)
        key_simple: (chain, resseq, icode, atomname)
    """
    atom_name = line[12:16].strip()
    alt_loc = line[16:17].strip()
    res_name = line[17:20].strip()
    chain_id = line[21:22].strip()
    insert_code = line[26:27].strip()
    # The residue number may be blank or garbled; fall back to an
    # out-of-band sentinel so keys still compare deterministically.
    try:
        res_seq = int(line[22:26])
    except ValueError:
        res_seq = -10**9
    key_full = (chain_id, res_seq, insert_code, res_name, atom_name, alt_loc)
    key_simple = (chain_id, res_seq, insert_code, atom_name)
    return key_full, key_simple
1638
+
1639
+
1640
def collect_ml_atom_keys(model_pdb: Path) -> Tuple[set, set]:
    """Gather ML-region atom keys from *model_pdb*.

    Returns:
        keys_full: Set of (chain, resseq, icode, resname, atomname, altloc)
        keys_simple: Set of (chain, resseq, icode, atomname)

    Any read/parse error yields two empty sets so callers can degrade
    gracefully.
    """
    full_keys: set = set()
    simple_keys: set = set()
    try:
        with model_pdb.open("r") as handle:
            for record in handle:
                if not record.startswith(("ATOM", "HETATM")):
                    continue
                full_key, simple_key = pdb_keys_from_line(record)
                full_keys.add(full_key)
                simple_keys.add(simple_key)
    except Exception:
        # Best effort: an unreadable file simply produces empty key sets.
        pass
    return full_keys, simple_keys
1661
+
1662
+
1663
def format_pdb_with_bfactor(line: str, b: float) -> str:
    """Return *line* with the tempFactor field (columns 61-66) set to *b* (6.2f).

    Records shorter than 66 columns are padded (and newline-terminated) so
    the field exists; the occupancy field (columns 55-60) is untouched.
    """
    if len(line) < 66:
        stripped = line.rstrip("\n")
        stripped += " " * max(0, 66 - len(stripped))
        line = stripped + "\n"
    return f"{line[:60]}{b:6.2f}{line[66:]}"
1673
+
1674
+
1675
def annotate_pdb_bfactors_inplace(
    pdb_path: Path,
    model_pdb: Path,
    freeze_indices_0based: Sequence[int],
    beta_ml: float = 0.0,
    beta_frz: float = 20.0,
    beta_both: float = 0.0,
) -> None:
    """Rewrite *pdb_path* in place, encoding layers in the B-factor column.

    Encoding (3-layer: ML=0, MovableMM=10, FrozenMM=20):
    - ML-region atoms (found in *model_pdb*): ``beta_ml`` (default 0.00)
    - frozen atoms: ``beta_frz`` (default 20.00)
    - ML ∩ frozen: ``beta_both`` (default 0.00 — ML takes precedence)
    - all remaining atoms: 10.00 (movable MM)

    Frozen indexing is 0-based and restarts at every MODEL record. Read or
    write failures are silently ignored (best-effort annotation).
    """
    ml_full, ml_simple = collect_ml_atom_keys(model_pdb)
    frozen = {int(i) for i in (freeze_indices_0based or [])}

    try:
        src_lines = pdb_path.read_text().splitlines(keepends=True)
    except Exception:
        return

    result: List[str] = []
    counter = 0  # per-MODEL atom counter

    for raw in src_lines:
        tag = raw[:6]
        if tag.startswith("MODEL"):
            counter = 0  # trajectories: restart indexing for each model
            result.append(raw)
            continue
        if not (tag.startswith("ATOM ") or tag.startswith("HETATM")):
            result.append(raw)
            continue
        full_key, simple_key = pdb_keys_from_line(raw)
        in_ml = (full_key in ml_full) or (simple_key in ml_simple)
        in_frz = counter in frozen
        if in_ml:
            beta = beta_both if in_frz else beta_ml
        elif in_frz:
            beta = beta_frz
        else:
            beta = 10.0  # default layer: movable MM
        result.append(format_pdb_with_bfactor(raw, beta))
        counter += 1

    try:
        pdb_path.write_text("".join(result))
    except Exception:
        # Conversion outputs already exist; annotation is best effort only.
        pass
1730
+
1731
+
1732
def convert_and_annotate_xyz_to_pdb(
    src_xyz_or_trj: Path,
    ref_pdb: Path,
    dst_pdb: Path,
    model_pdb: Path,
    freeze_indices_0based: Sequence[int],
) -> None:
    """Convert an XYZ/TRJ file to PDB and annotate B-factors to mark layers.

    Uses the defaults of ``annotate_pdb_bfactors_inplace``, i.e. the 3-layer
    B-factor encoding:
        - ML-region atoms: 0.00
        - movable MM atoms: 10.00
        - frozen atoms: 20.00
        - ML ∩ frozen: 0.00 (ML takes precedence)

    Any failure (conversion or annotation) is reported as a warning on
    stderr instead of raising, so output generation never aborts a run.
    """
    try:
        convert_xyz_to_pdb(src_xyz_or_trj, ref_pdb, dst_pdb)
        annotate_pdb_bfactors_inplace(
            dst_pdb,
            model_pdb=model_pdb,
            freeze_indices_0based=freeze_indices_0based,
        )
    except Exception as exc:
        click.echo(
            f"[convert] WARNING: Failed to convert '{src_xyz_or_trj}' to PDB: {exc}",
            err=True,
        )
1758
+
1759
+
1760
+ # =============================================================================
1761
+ # Input preparation helpers
1762
+ # =============================================================================
1763
+
1764
+
1765
@dataclass
class PreparedInputStructure:
    """Thin handle pairing a structure's source path with its geometry path.

    Also usable as a (no-op) context manager, so call sites can uniformly
    write ``with prepare_input_structure(p) as prep: ...`` whether or not
    temporary files are involved.
    """

    # Path supplying topology/residue information.
    source_path: Path
    # Path supplying coordinates for geometry loading.
    geom_path: Path

    def cleanup(self) -> None:
        """Nothing to release: no temporary files are created."""
        return None

    def __enter__(self) -> "PreparedInputStructure":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.cleanup()
1779
+
1780
+
1781
def prepare_input_structure(path: Path) -> PreparedInputStructure:
    """Wrap *path* so topology and geometry both resolve to the same file."""
    return PreparedInputStructure(source_path=path, geom_path=path)
1784
+
1785
+
1786
+ def _count_atoms_in_file(path: Path) -> int:
1787
+ """Count atoms in a structure file (PDB or XYZ)."""
1788
+ suffix = path.suffix.lower()
1789
+ if suffix == ".pdb":
1790
+ count = 0
1791
+ with open(path, "r") as f:
1792
+ for line in f:
1793
+ if line.startswith(("ATOM ", "HETATM")):
1794
+ count += 1
1795
+ return count
1796
+ elif suffix == ".xyz":
1797
+ # XYZ format: first line is atom count
1798
+ with open(path, "r") as f:
1799
+ first_line = f.readline().strip()
1800
+ try:
1801
+ return int(first_line)
1802
+ except ValueError:
1803
+ return 0
1804
+ return 0
1805
+
1806
+
1807
def apply_ref_pdb_override(
    prepared_input: PreparedInputStructure,
    ref_pdb: Optional[Path],
) -> Optional[Path]:
    """Point the prepared input's topology at *ref_pdb*, keeping XYZ coords.

    When --ref-pdb is provided:
    - ``geom_path`` stays the original input (high-precision coordinates)
    - ``source_path`` switches to the reference PDB (topology/residues)

    Returns the resolved reference path, or ``None`` when no override was
    requested.

    Raises
    ------
    click.BadParameter
        If *ref_pdb* is not a ``.pdb`` file, or its atom count differs from
        the geometry file's.
    """
    import click

    if ref_pdb is None:
        return None
    resolved = Path(ref_pdb).resolve()
    if resolved.suffix.lower() != ".pdb":
        raise click.BadParameter("--ref-pdb must be a .pdb file.")
    n_geom = _count_atoms_in_file(prepared_input.geom_path)
    n_ref = _count_atoms_in_file(resolved)
    if n_geom != n_ref:
        raise click.BadParameter(
            f"Atom count mismatch: {prepared_input.geom_path.name} has {n_geom} atoms, "
            f"but --ref-pdb {resolved.name} has {n_ref} atoms."
        )
    prepared_input.source_path = resolved
    return resolved
1832
+
1833
+
1834
+ def _round_charge_with_note(q: float, prefix: str = "") -> int:
1835
+ """Round a float charge to the nearest integer, with a note if not exact."""
1836
+ if not math.isfinite(q):
1837
+ raise click.BadParameter(f"Computed total charge is non-finite: {q!r}")
1838
+ q_int = int(round(q))
1839
+ if abs(float(q) - q_int) > 1e-6:
1840
+ click.echo(
1841
+ f"{prefix} NOTE: total charge = {q:+g} → rounded to integer {q_int:+d}."
1842
+ )
1843
+ return q_int
1844
+
1845
+
1846
def _derive_charge_from_ligand_charge(
    pdb_path: Path,
    ligand_charge: Optional[str],
    *,
    prefix: str = "",
) -> Optional[int]:
    """Derive total system charge from a PDB file using ``--ligand-charge`` metadata.

    Parameters
    ----------
    pdb_path : Path
        PDB whose residues (optionally layered via B-factors) define the system.
    ligand_charge : Optional[str]
        ``--ligand-charge`` specification forwarded to ``compute_charge_summary``.
    prefix : str
        Prefix prepended to every echoed message.

    Returns
    -------
    Optional[int]
        The rounded total charge, or ``None`` when *ligand_charge* is ``None``
        or derivation fails for any reason (a NOTE is echoed to stderr).
    """
    if ligand_charge is None:
        return None
    try:
        # Heavy optional dependencies are imported lazily so this module
        # stays importable without Bio installed.
        from Bio import PDB as BioPDB
        from .extract import compute_charge_summary, log_charge_summary

        parser = BioPDB.PDBParser(QUIET=True)
        complex_struct = parser.get_structure("complex", str(pdb_path))

        # Use only ML-region residues (B-factor ≈ 0) when layered PDB is available.
        # A residue is included if ANY of its atoms has B-factor < 1.0 (ML layer).
        ml_residue_ids = set()
        all_residue_ids = set()
        for res in complex_struct.get_residues():
            fid = res.get_full_id()
            all_residue_ids.add(fid)
            for atom in res.get_atoms():
                if atom.get_bfactor() < 1.0:
                    ml_residue_ids.add(fid)
                    break
        # Fall back to all residues if no B-factor layering is present
        # (i.e. every residue has B=0 means unlayered PDB).
        selected_ids = ml_residue_ids if ml_residue_ids != all_residue_ids else all_residue_ids
        summary = compute_charge_summary(
            complex_struct, selected_ids, set(), ligand_charge
        )
        log_charge_summary(prefix, summary)
        q_total = float(summary.get("total_charge", 0.0))
        click.echo(
            f"{prefix} Charge summary (--ligand-charge):"
        )
        click.echo(
            f"  Protein: {summary.get('protein_charge', 0.0):+g}, "
            f"Ligand: {summary.get('ligand_total_charge', 0.0):+g}, "
            f"Ions: {summary.get('ion_total_charge', 0.0):+g}, "
            f"Total: {q_total:+g}"
        )
        return _round_charge_with_note(q_total, prefix)
    except Exception as e:
        # Best effort: any failure downgrades to a NOTE so the caller can
        # fall back to an explicit charge or a default.
        click.echo(
            f"{prefix} NOTE: failed to derive charge from --ligand-charge: {e}",
            err=True,
        )
        return None
1900
+
1901
+
1902
def resolve_charge_spin_or_raise(
    prepared: PreparedInputStructure,
    charge: Optional[int],
    spin: Optional[int],
    *,
    spin_default: int = 1,
    charge_default: Optional[int] = None,
    ligand_charge: Optional[str] = None,
    prefix: str = "",
) -> Tuple[int, int]:
    """Resolve the (charge, spin) pair for a calculation.

    Charge resolution order: explicit ``-q/--charge`` value, then derivation
    from ``--ligand-charge`` metadata, then ``charge_default``. Spin falls
    back to ``spin_default`` when unset.

    Raises
    ------
    click.ClickException
        When no charge can be resolved by any of the above.
    """
    resolved_charge = charge
    if resolved_charge is None and ligand_charge is not None:
        resolved_charge = _derive_charge_from_ligand_charge(
            prepared.source_path, ligand_charge, prefix=prefix,
        )
    if resolved_charge is None:
        if charge_default is None:
            raise click.ClickException(
                "Total charge is unresolved. Provide -q/--charge or --ligand-charge."
            )
        resolved_charge = charge_default
    resolved_spin = spin_default if spin is None else spin
    return int(resolved_charge), int(resolved_spin)
1931
+
1932
+
1933
+ # -----------------------------------------------
1934
+ # B-factor based 3-layer ML/MM system utilities
1935
+ # -----------------------------------------------
1936
+
1937
def read_bfactors_from_pdb(pdb_path: Path) -> List[float]:
    """Collect temperature factors from every ATOM/HETATM record, in order.

    Returns 0-indexed B-factors; malformed or missing fields (columns 61-66)
    are recorded as 0.0 rather than raising.
    """
    values: List[float] = []
    with open(pdb_path, "r") as handle:
        for record in handle:
            if not record.startswith(("ATOM  ", "HETATM")):
                continue
            # B-factor occupies columns 61-66 (1-indexed) == slice [60:66].
            try:
                values.append(float(record[60:66].strip()))
            except (ValueError, IndexError):
                values.append(0.0)
    return values
1955
+
1956
+
1957
def parse_layer_indices_from_bfactors(
    bfactors: List[float],
    tolerance: float = 1.0,
) -> Dict[str, List[int]]:
    """Classify atoms into 3-layer ML/MM regions by their B-factor.

    Encoding (matched within *tolerance*):
        0.0  -> ML atoms
        10.0 -> movable MM atoms
        20.0 -> frozen MM atoms
    A distinct ``BFACTOR_HESS_MM`` value (when it differs from the movable-MM
    value) maps to "hess_mm_indices"; anything else is "unassigned_indices".

    Parameters
    ----------
    bfactors : List[float]
        B-factor values per atom, 0-indexed.
    tolerance : float
        Allowed deviation from each reference value (default: 1.0).

    Returns
    -------
    Dict[str, List[int]]
        Keys: "ml_indices", "hess_mm_indices", "movable_mm_indices",
        "frozen_indices", "unassigned_indices" — each a list of atom indices.
    """
    from .defaults import BFACTOR_ML, BFACTOR_HESS_MM, BFACTOR_MOVABLE_MM, BFACTOR_FROZEN

    buckets: Dict[str, List[int]] = {
        "ml_indices": [],
        "hess_mm_indices": [],
        "movable_mm_indices": [],
        "frozen_indices": [],
        "unassigned_indices": [],
    }
    # The hess-MM bucket only exists when its marker differs from movable MM.
    hess_is_distinct = BFACTOR_HESS_MM != BFACTOR_MOVABLE_MM

    for idx, value in enumerate(bfactors):
        if abs(value - BFACTOR_ML) <= tolerance:
            buckets["ml_indices"].append(idx)
        elif abs(value - BFACTOR_FROZEN) <= tolerance:
            buckets["frozen_indices"].append(idx)
        elif abs(value - BFACTOR_MOVABLE_MM) <= tolerance:
            buckets["movable_mm_indices"].append(idx)
        elif hess_is_distinct and abs(value - BFACTOR_HESS_MM) <= tolerance:
            buckets["hess_mm_indices"].append(idx)
        else:
            buckets["unassigned_indices"].append(idx)

    return buckets
2016
+
2017
+
2018
def has_valid_layer_bfactors(bfactors: List[float], tolerance: float = 1.0) -> bool:
    """Heuristically decide whether B-factors carry the 3-layer encoding.

    True iff (1) at least one atom matches the ML value and (2) at least
    80% of atoms lie within *tolerance* of a recognised layer value
    (0 / 10 / 20, plus the hess-MM marker when distinct).
    """
    from .defaults import BFACTOR_ML, BFACTOR_HESS_MM, BFACTOR_MOVABLE_MM, BFACTOR_FROZEN

    recognised = {BFACTOR_ML, BFACTOR_MOVABLE_MM, BFACTOR_FROZEN, BFACTOR_HESS_MM}
    found_ml = False
    n_valid = 0

    for value in bfactors:
        if all(abs(value - ref) > tolerance for ref in recognised):
            continue
        n_valid += 1
        if abs(value - BFACTOR_ML) <= tolerance:
            found_ml = True

    # max(..., 1) keeps the ratio well-defined for an empty list.
    return found_ml and (n_valid / max(len(bfactors), 1) >= 0.8)
2043
+
2044
+
2045
def parse_indices_string(indices_str: str, one_based: bool = True) -> List[int]:
    """Parse comma- or space-separated indices into sorted 0-based ints.

    Inclusive ranges like ``"1-5"`` are supported. Inputs are 1-based by
    default; the result is sorted and de-duplicated.

    Raises
    ------
    click.BadParameter
        On non-numeric tokens, malformed/out-of-order ranges, or
        non-positive indices.
    """
    import click

    if indices_str is None:
        return []
    tokens = [
        t.strip()
        for t in str(indices_str).replace(" ", ",").split(",")
        if t.strip()
    ]
    shift = 1 if one_based else 0
    collected: List[int] = []
    for token in tokens:
        # Range form "a-b" (a leading '-' means a negative number, not a range).
        if "-" in token and not token.startswith("-"):
            parts = token.split("-")
            if len(parts) == 2 and parts[0] and parts[1]:
                try:
                    lo = int(parts[0]) - shift
                    hi = int(parts[1]) - shift
                except ValueError as exc:
                    raise click.BadParameter(
                        f"Invalid range token in --model-indices: '{token}'"
                    ) from exc
                if lo < 0 or hi < 0 or lo > hi:
                    raise click.BadParameter(f"Invalid range in --model-indices: '{token}'")
                collected.extend(range(lo, hi + 1))
                continue
        # Single-index form.
        try:
            idx = int(token) - shift
        except ValueError as exc:
            raise click.BadParameter(f"Invalid index in --model-indices: '{token}'") from exc
        if idx < 0:
            raise click.BadParameter(
                f"--model-indices expects positive indices; got {idx + shift}"
            )
        collected.append(idx)
    return sorted(set(collected))
2082
+
2083
+
2084
def write_model_pdb_from_indices(
    input_pdb_path: Path,
    output_pdb_path: Path,
    indices: Sequence[int],
) -> None:
    """Extract the atoms at the given 0-based indices into a new PDB.

    Non-atom records are copied through unchanged; blank element columns
    (77-78) on kept atoms are filled via ``guess_element``. The output is
    terminated with an END record.

    Raises
    ------
    ValueError
        If *indices* is empty, the input has no atoms, or the selection
        would produce an empty file.
    click.BadParameter
        If any index is out of range.
    """
    import click

    if not indices:
        raise ValueError("No indices provided to build model PDB.")
    total_atoms = _count_atoms_in_file(input_pdb_path)
    if total_atoms <= 0:
        raise ValueError(f"No atoms found in input PDB: {input_pdb_path}")
    for idx in indices:
        if not (0 <= idx < total_atoms):
            raise click.BadParameter(
                f"model index out of range: {idx} (valid: 0 <= idx < {total_atoms})"
            )

    wanted = {int(i) for i in indices}
    output: List[str] = []
    cursor = 0
    with open(input_pdb_path, "r") as handle:
        for row in handle:
            if not row.startswith(("ATOM  ", "HETATM")):
                output.append(row)
                continue
            if cursor in wanted:
                trimmed = row.rstrip("\n")
                elem_field = trimmed[76:78].strip() if len(trimmed) >= 78 else ""
                if elem_field:
                    output.append(row)
                else:
                    # Element column missing: infer it from the atom/residue
                    # names and pad the record to column 78.
                    symbol = guess_element(
                        trimmed[12:16].strip(),
                        trimmed[17:20].strip(),
                        trimmed.startswith("HETATM"),
                    )
                    if symbol:
                        output.append(trimmed.ljust(76) + f"{symbol:>2}" + "\n")
                    else:
                        output.append(row)
            cursor += 1

    if not output:
        raise ValueError("Model PDB would be empty; check indices and input PDB.")
    if not output[-1].endswith("\n"):
        output[-1] += "\n"
    output.append("END\n")
    with open(output_pdb_path, "w") as handle:
        handle.writelines(output)
2134
+
2135
+
2136
def build_model_pdb_from_indices(
    input_pdb_path: Path,
    out_dir: Path,
    indices: Sequence[int],
    *,
    label: str = "model_from_indices",
) -> Path:
    """Write a model PDB for *indices* into a uniquely-named file in *out_dir*.

    The target file is reserved with ``NamedTemporaryFile(delete=False)`` so
    repeated calls never collide; the resulting path is returned for the
    caller to manage/clean up.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    with tempfile.NamedTemporaryFile(
        mode="w",
        suffix=".pdb",
        prefix=f"{label}_",
        dir=out_dir,
        delete=False,
    ) as reserved:
        target = Path(reserved.name)
    write_model_pdb_from_indices(input_pdb_path, target, indices)
    return target
2157
+
2158
+
2159
def build_model_pdb_from_bfactors(
    input_pdb_path: Path,
    out_dir: Path,
    *,
    tolerance: Optional[float] = None,
    label: str = "model_from_bfactor",
) -> Tuple[Path, Dict[str, List[int]]]:
    """
    Create a model PDB using ML indices derived from B-factors.

    Parameters
    ----------
    input_pdb_path : Path
        Layered PDB whose B-factors encode ML (~0), movable MM (~10) and
        frozen (~20) atoms.
    out_dir : Path
        Directory for the generated model PDB (created if missing).
    tolerance : Optional[float]
        B-factor matching tolerance; ``None`` uses ``BFACTOR_TOLERANCE``.
    label : str
        Basename (without extension) of the generated model PDB.

    Returns
    -------
    Tuple[Path, Dict[str, List[int]]]
        (model_pdb_path, layer_info), where layer_info maps layer names to
        0-based atom indices (see ``parse_layer_indices_from_bfactors``).

    Raises
    ------
    ValueError
        If the PDB contains no atoms, lacks valid layer B-factors, or has
        no ML atoms.
    """
    # Local import mirrors the module's deferred-import pattern for defaults.
    from .defaults import BFACTOR_TOLERANCE
    tol = BFACTOR_TOLERANCE if tolerance is None else float(tolerance)
    bfactors = read_bfactors_from_pdb(input_pdb_path)
    if not bfactors:
        raise ValueError(f"No ATOM/HETATM records found in {input_pdb_path}.")
    if not has_valid_layer_bfactors(bfactors, tolerance=tol):
        raise ValueError(
            "Invalid or missing layer B-factors (expected ~0/10/20). "
            "Provide --no-detect-layer with --model-pdb/--model-indices."
        )
    layer_info = parse_layer_indices_from_bfactors(bfactors, tolerance=tol)
    ml_indices = layer_info.get("ml_indices") or []
    if not ml_indices:
        raise ValueError("No ML atoms detected from B-factors (value ~0).")
    out_dir.mkdir(parents=True, exist_ok=True)
    tmp_path = out_dir / f"{label}.pdb"
    write_model_pdb_from_indices(input_pdb_path, tmp_path, ml_indices)
    return tmp_path, layer_info
2189
+
2190
+
2191
def write_layer_bfactors_to_pdb(
    input_pdb_path: Path,
    output_pdb_path: Path,
    ml_indices: List[int],
    hess_mm_indices: Optional[List[int]] = None,
    movable_mm_indices: Optional[List[int]] = None,
    frozen_indices: Optional[List[int]] = None,
) -> None:
    """
    Write a copy of a PDB file whose B-factor column encodes the 3-layer
    assignment of every atom.

    Encoding (values taken from ``defaults``):
        ML atoms          -> BFACTOR_ML
        Hessian MM atoms  -> BFACTOR_HESS_MM
        Movable MM atoms  -> BFACTOR_MOVABLE_MM
        Frozen atoms      -> BFACTOR_FROZEN
    Atoms listed nowhere default to the movable-MM value (layer 3).
    Membership is tested in the order above, so an index appearing in
    several lists gets the first matching value.

    Parameters
    ----------
    input_pdb_path : Path
        PDB file whose atom records are read.
    output_pdb_path : Path
        Destination PDB file.
    ml_indices : List[int]
        0-based indices of ML region atoms.
    hess_mm_indices : Optional[List[int]]
        0-based indices of MM atoms carrying a Hessian.
    movable_mm_indices : Optional[List[int]]
        0-based indices of movable MM atoms without a Hessian.
    frozen_indices : Optional[List[int]]
        0-based indices of frozen atoms.

    Notes
    -----
    Multi-MODEL files (trajectories) are supported: the running atom
    index restarts at every MODEL record. The B-factor occupies PDB
    columns 61-66 (1-indexed) and is written as %6.2f; short records are
    padded out to column 66 first.
    """
    from .defaults import BFACTOR_ML, BFACTOR_HESS_MM, BFACTOR_MOVABLE_MM, BFACTOR_FROZEN

    # Priority-ordered membership table: ML wins over hess-MM, and so on.
    layers = (
        (set(ml_indices or []), BFACTOR_ML),
        (set(hess_mm_indices or []), BFACTOR_HESS_MM),
        (set(movable_mm_indices or []), BFACTOR_MOVABLE_MM),
        (set(frozen_indices or []), BFACTOR_FROZEN),
    )

    def _bfactor_for(idx: int) -> float:
        # First matching layer wins; unlisted atoms fall through to
        # movable MM (layer 3).
        for members, value in layers:
            if idx in members:
                return value
        return BFACTOR_MOVABLE_MM

    out_lines: List[str] = []
    serial = 0  # 0-based atom counter, reset per MODEL

    with open(input_pdb_path, "r") as src:
        for raw in src:
            record = raw[:6]
            if record.startswith("MODEL"):
                # Trajectory files: indices restart in every model.
                serial = 0
                out_lines.append(raw)
            elif raw.startswith(("ATOM ", "HETATM")):
                bfac = _bfactor_for(serial)
                # Splice the new B-factor into columns 61-66 (1-indexed),
                # padding records that end before column 66.
                if len(raw) >= 66:
                    out_lines.append(raw[:60] + f"{bfac:6.2f}" + raw[66:])
                else:
                    padded = raw.rstrip("\n").ljust(66)
                    out_lines.append(padded[:60] + f"{bfac:6.2f}" + "\n")
                serial += 1
            else:
                out_lines.append(raw)

    with open(output_pdb_path, "w") as dst:
        dst.writelines(out_lines)
2277
+
2278
+
2279
def update_pdb_bfactors_from_layers(
    pdb_path: Path,
    ml_indices: List[int],
    hess_mm_indices: Optional[List[int]] = None,
    movable_mm_indices: Optional[List[int]] = None,
    frozen_indices: Optional[List[int]] = None,
) -> None:
    """
    Update B-factors in a PDB file in-place based on layer assignments.

    Convenience wrapper around ``write_layer_bfactors_to_pdb``: the new
    content is written to a temporary file first, then swapped into
    place, so the original file is never left half-written.

    Parameters
    ----------
    pdb_path : Path
        PDB file updated in place.
    ml_indices : List[int]
        0-based indices of ML region atoms.
    hess_mm_indices : Optional[List[int]]
        0-based indices of MM atoms carrying a Hessian.
    movable_mm_indices : Optional[List[int]]
        0-based indices of movable MM atoms without a Hessian.
    frozen_indices : Optional[List[int]]
        0-based indices of frozen atoms.
    """
    import tempfile

    # Create the temp file NEXT TO the target so the final swap is a
    # same-filesystem rename. (The previous implementation used the
    # system temp dir + shutil.move, which degrades to a non-atomic
    # copy when the temp dir lives on a different device.)
    target = Path(pdb_path)
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".pdb", dir=target.parent, delete=False
    ) as tmp:
        tmp_path = Path(tmp.name)

    try:
        write_layer_bfactors_to_pdb(
            target,
            tmp_path,
            ml_indices,
            hess_mm_indices,
            movable_mm_indices,
            frozen_indices,
        )
        # Atomic on POSIX; also overwrites an existing file on Windows.
        tmp_path.replace(target)
    finally:
        # Still present only if the write or the replace above failed.
        if tmp_path.exists():
            tmp_path.unlink()