mlmm-toolkit 0.2.2.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hessian_ff/__init__.py +50 -0
- hessian_ff/analytical_hessian.py +609 -0
- hessian_ff/constants.py +46 -0
- hessian_ff/forcefield.py +339 -0
- hessian_ff/loaders.py +608 -0
- hessian_ff/native/Makefile +8 -0
- hessian_ff/native/__init__.py +28 -0
- hessian_ff/native/analytical_hessian.py +88 -0
- hessian_ff/native/analytical_hessian_ext.cpp +258 -0
- hessian_ff/native/bonded.py +82 -0
- hessian_ff/native/bonded_ext.cpp +640 -0
- hessian_ff/native/loader.py +349 -0
- hessian_ff/native/nonbonded.py +118 -0
- hessian_ff/native/nonbonded_ext.cpp +1150 -0
- hessian_ff/prmtop_parmed.py +23 -0
- hessian_ff/system.py +107 -0
- hessian_ff/terms/__init__.py +14 -0
- hessian_ff/terms/angle.py +73 -0
- hessian_ff/terms/bond.py +44 -0
- hessian_ff/terms/cmap.py +406 -0
- hessian_ff/terms/dihedral.py +141 -0
- hessian_ff/terms/nonbonded.py +209 -0
- hessian_ff/tests/__init__.py +0 -0
- hessian_ff/tests/conftest.py +75 -0
- hessian_ff/tests/data/small/complex.parm7 +1346 -0
- hessian_ff/tests/data/small/complex.pdb +125 -0
- hessian_ff/tests/data/small/complex.rst7 +63 -0
- hessian_ff/tests/test_coords_input.py +44 -0
- hessian_ff/tests/test_energy_force.py +49 -0
- hessian_ff/tests/test_hessian.py +137 -0
- hessian_ff/tests/test_smoke.py +18 -0
- hessian_ff/tests/test_validation.py +40 -0
- hessian_ff/workflows.py +889 -0
- mlmm/__init__.py +36 -0
- mlmm/__main__.py +7 -0
- mlmm/_version.py +34 -0
- mlmm/add_elem_info.py +374 -0
- mlmm/advanced_help.py +91 -0
- mlmm/align_freeze_atoms.py +601 -0
- mlmm/all.py +3535 -0
- mlmm/bond_changes.py +231 -0
- mlmm/bool_compat.py +223 -0
- mlmm/cli.py +574 -0
- mlmm/cli_utils.py +166 -0
- mlmm/default_group.py +337 -0
- mlmm/defaults.py +467 -0
- mlmm/define_layer.py +526 -0
- mlmm/dft.py +1041 -0
- mlmm/energy_diagram.py +253 -0
- mlmm/extract.py +2213 -0
- mlmm/fix_altloc.py +464 -0
- mlmm/freq.py +1406 -0
- mlmm/harmonic_constraints.py +140 -0
- mlmm/hessian_cache.py +44 -0
- mlmm/hessian_calc.py +174 -0
- mlmm/irc.py +638 -0
- mlmm/mlmm_calc.py +2262 -0
- mlmm/mm_parm.py +945 -0
- mlmm/oniom_export.py +1983 -0
- mlmm/oniom_import.py +457 -0
- mlmm/opt.py +1742 -0
- mlmm/path_opt.py +1353 -0
- mlmm/path_search.py +2299 -0
- mlmm/preflight.py +88 -0
- mlmm/py.typed +1 -0
- mlmm/pysis_runner.py +45 -0
- mlmm/scan.py +1047 -0
- mlmm/scan2d.py +1226 -0
- mlmm/scan3d.py +1265 -0
- mlmm/scan_common.py +184 -0
- mlmm/summary_log.py +736 -0
- mlmm/trj2fig.py +448 -0
- mlmm/tsopt.py +2871 -0
- mlmm/utils.py +2309 -0
- mlmm/xtb_embedcharge_correction.py +475 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/METADATA +1159 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/RECORD +372 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/WHEEL +5 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/entry_points.txt +2 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/licenses/LICENSE +674 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/top_level.txt +4 -0
- pysisyphus/Geometry.py +1667 -0
- pysisyphus/LICENSE +674 -0
- pysisyphus/TableFormatter.py +63 -0
- pysisyphus/TablePrinter.py +74 -0
- pysisyphus/__init__.py +12 -0
- pysisyphus/calculators/AFIR.py +452 -0
- pysisyphus/calculators/AnaPot.py +20 -0
- pysisyphus/calculators/AnaPot2.py +48 -0
- pysisyphus/calculators/AnaPot3.py +12 -0
- pysisyphus/calculators/AnaPot4.py +20 -0
- pysisyphus/calculators/AnaPotBase.py +337 -0
- pysisyphus/calculators/AnaPotCBM.py +25 -0
- pysisyphus/calculators/AtomAtomTransTorque.py +154 -0
- pysisyphus/calculators/CFOUR.py +250 -0
- pysisyphus/calculators/Calculator.py +844 -0
- pysisyphus/calculators/CerjanMiller.py +24 -0
- pysisyphus/calculators/Composite.py +123 -0
- pysisyphus/calculators/ConicalIntersection.py +171 -0
- pysisyphus/calculators/DFTBp.py +430 -0
- pysisyphus/calculators/DFTD3.py +66 -0
- pysisyphus/calculators/DFTD4.py +84 -0
- pysisyphus/calculators/Dalton.py +61 -0
- pysisyphus/calculators/Dimer.py +681 -0
- pysisyphus/calculators/Dummy.py +20 -0
- pysisyphus/calculators/EGO.py +76 -0
- pysisyphus/calculators/EnergyMin.py +224 -0
- pysisyphus/calculators/ExternalPotential.py +264 -0
- pysisyphus/calculators/FakeASE.py +35 -0
- pysisyphus/calculators/FourWellAnaPot.py +28 -0
- pysisyphus/calculators/FreeEndNEBPot.py +39 -0
- pysisyphus/calculators/Gaussian09.py +18 -0
- pysisyphus/calculators/Gaussian16.py +726 -0
- pysisyphus/calculators/HardSphere.py +159 -0
- pysisyphus/calculators/IDPPCalculator.py +49 -0
- pysisyphus/calculators/IPIClient.py +133 -0
- pysisyphus/calculators/IPIServer.py +234 -0
- pysisyphus/calculators/LEPSBase.py +24 -0
- pysisyphus/calculators/LEPSExpr.py +139 -0
- pysisyphus/calculators/LennardJones.py +80 -0
- pysisyphus/calculators/MOPAC.py +219 -0
- pysisyphus/calculators/MullerBrownSympyPot.py +51 -0
- pysisyphus/calculators/MultiCalc.py +85 -0
- pysisyphus/calculators/NFK.py +45 -0
- pysisyphus/calculators/OBabel.py +87 -0
- pysisyphus/calculators/ONIOMv2.py +1129 -0
- pysisyphus/calculators/ORCA.py +893 -0
- pysisyphus/calculators/ORCA5.py +6 -0
- pysisyphus/calculators/OpenMM.py +88 -0
- pysisyphus/calculators/OpenMolcas.py +281 -0
- pysisyphus/calculators/OverlapCalculator.py +908 -0
- pysisyphus/calculators/Psi4.py +218 -0
- pysisyphus/calculators/PyPsi4.py +37 -0
- pysisyphus/calculators/PySCF.py +341 -0
- pysisyphus/calculators/PyXTB.py +73 -0
- pysisyphus/calculators/QCEngine.py +106 -0
- pysisyphus/calculators/Rastrigin.py +22 -0
- pysisyphus/calculators/Remote.py +76 -0
- pysisyphus/calculators/Rosenbrock.py +15 -0
- pysisyphus/calculators/SocketCalc.py +97 -0
- pysisyphus/calculators/TIP3P.py +111 -0
- pysisyphus/calculators/TransTorque.py +161 -0
- pysisyphus/calculators/Turbomole.py +965 -0
- pysisyphus/calculators/VRIPot.py +37 -0
- pysisyphus/calculators/WFOWrapper.py +333 -0
- pysisyphus/calculators/WFOWrapper2.py +341 -0
- pysisyphus/calculators/XTB.py +418 -0
- pysisyphus/calculators/__init__.py +81 -0
- pysisyphus/calculators/cosmo_data.py +139 -0
- pysisyphus/calculators/parser.py +150 -0
- pysisyphus/color.py +19 -0
- pysisyphus/config.py +133 -0
- pysisyphus/constants.py +65 -0
- pysisyphus/cos/AdaptiveNEB.py +230 -0
- pysisyphus/cos/ChainOfStates.py +725 -0
- pysisyphus/cos/FreeEndNEB.py +25 -0
- pysisyphus/cos/FreezingString.py +103 -0
- pysisyphus/cos/GrowingChainOfStates.py +71 -0
- pysisyphus/cos/GrowingNT.py +309 -0
- pysisyphus/cos/GrowingString.py +508 -0
- pysisyphus/cos/NEB.py +189 -0
- pysisyphus/cos/SimpleZTS.py +64 -0
- pysisyphus/cos/__init__.py +22 -0
- pysisyphus/cos/stiffness.py +199 -0
- pysisyphus/drivers/__init__.py +17 -0
- pysisyphus/drivers/afir.py +855 -0
- pysisyphus/drivers/barriers.py +271 -0
- pysisyphus/drivers/birkholz.py +138 -0
- pysisyphus/drivers/cluster.py +318 -0
- pysisyphus/drivers/diabatization.py +133 -0
- pysisyphus/drivers/merge.py +368 -0
- pysisyphus/drivers/merge_mol2.py +322 -0
- pysisyphus/drivers/opt.py +375 -0
- pysisyphus/drivers/perf.py +91 -0
- pysisyphus/drivers/pka.py +52 -0
- pysisyphus/drivers/precon_pos_rot.py +669 -0
- pysisyphus/drivers/rates.py +480 -0
- pysisyphus/drivers/replace.py +219 -0
- pysisyphus/drivers/scan.py +212 -0
- pysisyphus/drivers/spectrum.py +166 -0
- pysisyphus/drivers/thermo.py +31 -0
- pysisyphus/dynamics/Gaussian.py +103 -0
- pysisyphus/dynamics/__init__.py +20 -0
- pysisyphus/dynamics/colvars.py +136 -0
- pysisyphus/dynamics/driver.py +297 -0
- pysisyphus/dynamics/helpers.py +256 -0
- pysisyphus/dynamics/lincs.py +105 -0
- pysisyphus/dynamics/mdp.py +364 -0
- pysisyphus/dynamics/rattle.py +121 -0
- pysisyphus/dynamics/thermostats.py +128 -0
- pysisyphus/dynamics/wigner.py +266 -0
- pysisyphus/elem_data.py +3473 -0
- pysisyphus/exceptions.py +2 -0
- pysisyphus/filtertrj.py +69 -0
- pysisyphus/helpers.py +623 -0
- pysisyphus/helpers_pure.py +649 -0
- pysisyphus/init_logging.py +50 -0
- pysisyphus/intcoords/Bend.py +69 -0
- pysisyphus/intcoords/Bend2.py +25 -0
- pysisyphus/intcoords/BondedFragment.py +32 -0
- pysisyphus/intcoords/Cartesian.py +41 -0
- pysisyphus/intcoords/CartesianCoords.py +140 -0
- pysisyphus/intcoords/Coords.py +56 -0
- pysisyphus/intcoords/DLC.py +197 -0
- pysisyphus/intcoords/DistanceFunction.py +34 -0
- pysisyphus/intcoords/DummyImproper.py +70 -0
- pysisyphus/intcoords/DummyTorsion.py +72 -0
- pysisyphus/intcoords/LinearBend.py +105 -0
- pysisyphus/intcoords/LinearDisplacement.py +80 -0
- pysisyphus/intcoords/OutOfPlane.py +59 -0
- pysisyphus/intcoords/PrimTypes.py +286 -0
- pysisyphus/intcoords/Primitive.py +137 -0
- pysisyphus/intcoords/RedundantCoords.py +659 -0
- pysisyphus/intcoords/RobustTorsion.py +59 -0
- pysisyphus/intcoords/Rotation.py +147 -0
- pysisyphus/intcoords/Stretch.py +31 -0
- pysisyphus/intcoords/Torsion.py +101 -0
- pysisyphus/intcoords/Torsion2.py +25 -0
- pysisyphus/intcoords/Translation.py +45 -0
- pysisyphus/intcoords/__init__.py +61 -0
- pysisyphus/intcoords/augment_bonds.py +126 -0
- pysisyphus/intcoords/derivatives.py +10512 -0
- pysisyphus/intcoords/eval.py +80 -0
- pysisyphus/intcoords/exceptions.py +37 -0
- pysisyphus/intcoords/findiffs.py +48 -0
- pysisyphus/intcoords/generate_derivatives.py +414 -0
- pysisyphus/intcoords/helpers.py +235 -0
- pysisyphus/intcoords/logging_conf.py +10 -0
- pysisyphus/intcoords/mp_derivatives.py +10836 -0
- pysisyphus/intcoords/setup.py +962 -0
- pysisyphus/intcoords/setup_fast.py +176 -0
- pysisyphus/intcoords/update.py +272 -0
- pysisyphus/intcoords/valid.py +89 -0
- pysisyphus/interpolate/Geodesic.py +93 -0
- pysisyphus/interpolate/IDPP.py +55 -0
- pysisyphus/interpolate/Interpolator.py +116 -0
- pysisyphus/interpolate/LST.py +70 -0
- pysisyphus/interpolate/Redund.py +152 -0
- pysisyphus/interpolate/__init__.py +9 -0
- pysisyphus/interpolate/helpers.py +34 -0
- pysisyphus/io/__init__.py +22 -0
- pysisyphus/io/aomix.py +178 -0
- pysisyphus/io/cjson.py +24 -0
- pysisyphus/io/crd.py +101 -0
- pysisyphus/io/cube.py +220 -0
- pysisyphus/io/fchk.py +184 -0
- pysisyphus/io/hdf5.py +49 -0
- pysisyphus/io/hessian.py +72 -0
- pysisyphus/io/mol2.py +146 -0
- pysisyphus/io/molden.py +293 -0
- pysisyphus/io/orca.py +189 -0
- pysisyphus/io/pdb.py +269 -0
- pysisyphus/io/psf.py +79 -0
- pysisyphus/io/pubchem.py +31 -0
- pysisyphus/io/qcschema.py +34 -0
- pysisyphus/io/sdf.py +29 -0
- pysisyphus/io/xyz.py +61 -0
- pysisyphus/io/zmat.py +175 -0
- pysisyphus/irc/DWI.py +108 -0
- pysisyphus/irc/DampedVelocityVerlet.py +134 -0
- pysisyphus/irc/Euler.py +22 -0
- pysisyphus/irc/EulerPC.py +345 -0
- pysisyphus/irc/GonzalezSchlegel.py +187 -0
- pysisyphus/irc/IMKMod.py +164 -0
- pysisyphus/irc/IRC.py +878 -0
- pysisyphus/irc/IRCDummy.py +10 -0
- pysisyphus/irc/Instanton.py +307 -0
- pysisyphus/irc/LQA.py +53 -0
- pysisyphus/irc/ModeKill.py +136 -0
- pysisyphus/irc/ParamPlot.py +53 -0
- pysisyphus/irc/RK4.py +36 -0
- pysisyphus/irc/__init__.py +31 -0
- pysisyphus/irc/initial_displ.py +219 -0
- pysisyphus/linalg.py +411 -0
- pysisyphus/line_searches/Backtracking.py +88 -0
- pysisyphus/line_searches/HagerZhang.py +184 -0
- pysisyphus/line_searches/LineSearch.py +232 -0
- pysisyphus/line_searches/StrongWolfe.py +108 -0
- pysisyphus/line_searches/__init__.py +9 -0
- pysisyphus/line_searches/interpol.py +15 -0
- pysisyphus/modefollow/NormalMode.py +40 -0
- pysisyphus/modefollow/__init__.py +10 -0
- pysisyphus/modefollow/davidson.py +199 -0
- pysisyphus/modefollow/lanczos.py +95 -0
- pysisyphus/optimizers/BFGS.py +99 -0
- pysisyphus/optimizers/BacktrackingOptimizer.py +113 -0
- pysisyphus/optimizers/ConjugateGradient.py +98 -0
- pysisyphus/optimizers/CubicNewton.py +75 -0
- pysisyphus/optimizers/FIRE.py +113 -0
- pysisyphus/optimizers/HessianOptimizer.py +1176 -0
- pysisyphus/optimizers/LBFGS.py +228 -0
- pysisyphus/optimizers/LayerOpt.py +411 -0
- pysisyphus/optimizers/MicroOptimizer.py +169 -0
- pysisyphus/optimizers/NCOptimizer.py +90 -0
- pysisyphus/optimizers/Optimizer.py +1084 -0
- pysisyphus/optimizers/PreconLBFGS.py +260 -0
- pysisyphus/optimizers/PreconSteepestDescent.py +7 -0
- pysisyphus/optimizers/QuickMin.py +74 -0
- pysisyphus/optimizers/RFOptimizer.py +181 -0
- pysisyphus/optimizers/RSA.py +99 -0
- pysisyphus/optimizers/StabilizedQNMethod.py +248 -0
- pysisyphus/optimizers/SteepestDescent.py +23 -0
- pysisyphus/optimizers/StringOptimizer.py +173 -0
- pysisyphus/optimizers/__init__.py +41 -0
- pysisyphus/optimizers/closures.py +301 -0
- pysisyphus/optimizers/cls_map.py +58 -0
- pysisyphus/optimizers/exceptions.py +6 -0
- pysisyphus/optimizers/gdiis.py +280 -0
- pysisyphus/optimizers/guess_hessians.py +311 -0
- pysisyphus/optimizers/hessian_updates.py +355 -0
- pysisyphus/optimizers/poly_fit.py +285 -0
- pysisyphus/optimizers/precon.py +153 -0
- pysisyphus/optimizers/restrict_step.py +24 -0
- pysisyphus/pack.py +172 -0
- pysisyphus/peakdetect.py +948 -0
- pysisyphus/plot.py +1031 -0
- pysisyphus/run.py +2106 -0
- pysisyphus/socket_helper.py +74 -0
- pysisyphus/stocastic/FragmentKick.py +132 -0
- pysisyphus/stocastic/Kick.py +81 -0
- pysisyphus/stocastic/Pipeline.py +303 -0
- pysisyphus/stocastic/__init__.py +21 -0
- pysisyphus/stocastic/align.py +127 -0
- pysisyphus/testing.py +96 -0
- pysisyphus/thermo.py +156 -0
- pysisyphus/trj.py +824 -0
- pysisyphus/tsoptimizers/RSIRFOptimizer.py +56 -0
- pysisyphus/tsoptimizers/RSPRFOptimizer.py +182 -0
- pysisyphus/tsoptimizers/TRIM.py +59 -0
- pysisyphus/tsoptimizers/TSHessianOptimizer.py +463 -0
- pysisyphus/tsoptimizers/__init__.py +23 -0
- pysisyphus/wavefunction/Basis.py +239 -0
- pysisyphus/wavefunction/DIIS.py +76 -0
- pysisyphus/wavefunction/__init__.py +25 -0
- pysisyphus/wavefunction/build_ext.py +42 -0
- pysisyphus/wavefunction/cart2sph.py +190 -0
- pysisyphus/wavefunction/diabatization.py +304 -0
- pysisyphus/wavefunction/excited_states.py +435 -0
- pysisyphus/wavefunction/gen_ints.py +1811 -0
- pysisyphus/wavefunction/helpers.py +104 -0
- pysisyphus/wavefunction/ints/__init__.py +0 -0
- pysisyphus/wavefunction/ints/boys.py +193 -0
- pysisyphus/wavefunction/ints/boys_table_N_64_xasym_27.1_step_0.01.npy +0 -0
- pysisyphus/wavefunction/ints/cart_gto3d.py +176 -0
- pysisyphus/wavefunction/ints/coulomb3d.py +25928 -0
- pysisyphus/wavefunction/ints/diag_quadrupole3d.py +10036 -0
- pysisyphus/wavefunction/ints/dipole3d.py +8762 -0
- pysisyphus/wavefunction/ints/int2c2e3d.py +7198 -0
- pysisyphus/wavefunction/ints/int3c2e3d_sph.py +65040 -0
- pysisyphus/wavefunction/ints/kinetic3d.py +8240 -0
- pysisyphus/wavefunction/ints/ovlp3d.py +3777 -0
- pysisyphus/wavefunction/ints/quadrupole3d.py +15054 -0
- pysisyphus/wavefunction/ints/self_ovlp3d.py +198 -0
- pysisyphus/wavefunction/localization.py +458 -0
- pysisyphus/wavefunction/multipole.py +159 -0
- pysisyphus/wavefunction/normalization.py +36 -0
- pysisyphus/wavefunction/pop_analysis.py +134 -0
- pysisyphus/wavefunction/shells.py +1171 -0
- pysisyphus/wavefunction/wavefunction.py +504 -0
- pysisyphus/wrapper/__init__.py +11 -0
- pysisyphus/wrapper/exceptions.py +2 -0
- pysisyphus/wrapper/jmol.py +120 -0
- pysisyphus/wrapper/mwfn.py +169 -0
- pysisyphus/wrapper/packmol.py +71 -0
- pysisyphus/xyzloader.py +168 -0
- pysisyphus/yaml_mods.py +45 -0
- thermoanalysis/LICENSE +674 -0
- thermoanalysis/QCData.py +244 -0
- thermoanalysis/__init__.py +0 -0
- thermoanalysis/config.py +3 -0
- thermoanalysis/constants.py +20 -0
- thermoanalysis/thermo.py +1011 -0
mlmm/utils.py
ADDED
|
@@ -0,0 +1,2309 @@
|
|
|
1
|
+
# mlmm/utils.py
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
utils — concise utilities for configuration, plotting, and coordinates
|
|
5
|
+
====================================================================
|
|
6
|
+
|
|
7
|
+
Usage (API)
|
|
8
|
+
-----
|
|
9
|
+
from mlmm.utils import (
|
|
10
|
+
build_energy_diagram,
|
|
11
|
+
convert_xyz_to_pdb,
|
|
12
|
+
merge_freeze_atom_indices,
|
|
13
|
+
pretty_block,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
Examples::
|
|
17
|
+
>>> from pathlib import Path
|
|
18
|
+
>>> block = pretty_block("Geometry", {"freeze_atoms": [0, 1, 5]})
|
|
19
|
+
>>> diagram = build_energy_diagram([0.0, 12.3, 5.4], ["R", "TS", "P"])
|
|
20
|
+
|
|
21
|
+
Description
|
|
22
|
+
-----
|
|
23
|
+
- **Generic helpers**
|
|
24
|
+
- `pretty_block(title, content)`: Return a YAML-formatted block with an underlined title. Uses `yaml.safe_dump` with `allow_unicode=True`, `sort_keys=False`. Renders `{}` when `content` is empty.
|
|
25
|
+
- `format_freeze_atoms_for_echo(cfg, key="freeze_atoms")`: Normalize geometry configuration for CLI echo. If the key is an iterable (but not a string), summarize to a compact single-line form like `"11073 atoms [0,1,2,3,4,...,12981,12982,12983,12984,12985]"`.
|
|
26
|
+
- `format_elapsed(prefix, start_time, end_time=None)`: Format a wall-clock duration (HH:MM:SS.sss) given a start time and optional end time, using `time.perf_counter()` when the end time is omitted.
|
|
27
|
+
- `merge_freeze_atom_indices(geom_cfg, *indices)`: Merge one or more iterables of atom indices into `geom_cfg["freeze_atoms"]`. Preserve existing entries, de-duplicate, sort numerically, and return the updated list (in place).
|
|
28
|
+
- `apply_layer_freeze_constraints(geom_cfg, calc_cfg, layer_info, echo_fn=None)`: Merge layer-detected frozen indices (`layer_info["frozen_indices"]`) into both `geom_cfg["freeze_atoms"]` and `calc_cfg["freeze_atoms"]`, then optionally emit a concise summary line.
|
|
29
|
+
- `deep_update(dst, src)`: Recursively update mapping `dst` with `src`. Nested dicts are merged, non-dicts overwrite; returns `dst`.
|
|
30
|
+
- `_get_mapping_section(cfg, path)`: Internal helper to resolve a nested mapping section. Returns a `dict` or `None`.
|
|
31
|
+
- `apply_yaml_overrides(yaml_cfg, overrides)`: For each target dictionary and its candidate key paths, find the first existing path in `yaml_cfg` and apply it via `deep_update`. Centralizes repeated `yaml_cfg.get(...)`-style merging.
|
|
32
|
+
- `load_yaml_dict(path)`: Load a YAML file whose root must be a mapping. Returns `{}` when `path` is `None`. Raises `ValueError` if the YAML root is not a mapping.
|
|
33
|
+
|
|
34
|
+
- **Plotly: Energy diagram builder**
|
|
35
|
+
- `build_energy_diagram(energies, labels, ylabel="ΔE", baseline=False, showgrid=False)`:
|
|
36
|
+
Render an energy diagram where each state is a thick horizontal segment and adjacent states are connected by dotted diagonals (right end of left state → left end of right state). Segment length shrinks as the number of states grows to keep gaps readable. X ticks are centered on states and labeled by `labels`. Optional dotted baseline at the first state’s energy; optional grid. Energies are plotted as provided (no unit conversion). Returns a `plotly.graph_objs.Figure`. Validates equal lengths for `energies`/`labels` and non-empty input.
|
|
37
|
+
|
|
38
|
+
- **Coordinate conversion utilities**
|
|
39
|
+
- `convert_xyz_to_pdb(xyz_path, ref_pdb_path, out_pdb_path)`:
|
|
40
|
+
Overlay coordinates from an XYZ file (single or multi-frame) onto the atom ordering/topology of a reference PDB and write to `out_pdb_path`. The first frame creates/overwrites; subsequent frames append using `MODEL`/`ENDMDL`. Implemented with ASE (`ase.io.read`/`write`). Raises `ValueError` if no frames are found in the XYZ.
|
|
41
|
+
|
|
42
|
+
Outputs (& Directory Layout)
|
|
43
|
+
-----
|
|
44
|
+
- This module does not create directories.
|
|
45
|
+
- Functions primarily return Python objects or mutate dictionaries in place.
|
|
46
|
+
- On-disk output occurs only when explicitly requested by the caller:
|
|
47
|
+
- `convert_xyz_to_pdb` writes a PDB file to `out_pdb_path` (first frame create/overwrite; subsequent frames append with `MODEL`/`ENDMDL` blocks).
|
|
48
|
+
- `build_energy_diagram` returns a Plotly `Figure`; it does not write files unless the caller saves/exports the figure.
|
|
49
|
+
|
|
50
|
+
Notes:
|
|
51
|
+
-----
|
|
52
|
+
- Energy units in `build_energy_diagram` are passed through unchanged; ensure consistent units across states.
|
|
53
|
+
- Axis/line styling in `build_energy_diagram` is fixed-width with automatic padding; segment length adapts to the number of states.
|
|
54
|
+
- `load_yaml_dict` uses `yaml.safe_load` and enforces a mapping at the YAML root; empty files yield `{}`.
|
|
55
|
+
- `apply_yaml_overrides` tries candidate key paths in order and applies only the first existing mapping section per target.
|
|
56
|
+
- Dependencies: PyYAML, ASE (`ase.io.read`/`write`), Plotly (graph objects).
|
|
57
|
+
"""
|
|
58
|
+
|
|
59
|
+
import ast
|
|
60
|
+
import logging
|
|
61
|
+
import math
|
|
62
|
+
import os
|
|
63
|
+
import re
|
|
64
|
+
import time
|
|
65
|
+
import tempfile
|
|
66
|
+
from collections.abc import Iterable as _Iterable, Mapping, Sequence as _Sequence
|
|
67
|
+
from dataclasses import dataclass
|
|
68
|
+
from numbers import Real, Integral
|
|
69
|
+
from pathlib import Path
|
|
70
|
+
from typing import Any, Callable, Dict, Optional, Sequence, List, Tuple
|
|
71
|
+
|
|
72
|
+
import click
|
|
73
|
+
import numpy as np
|
|
74
|
+
import yaml
|
|
75
|
+
from ase.io import read, write
|
|
76
|
+
import plotly.graph_objs as go
|
|
77
|
+
|
|
78
|
+
from pysisyphus.helpers import geom_loader
|
|
79
|
+
from pysisyphus.constants import ANG2BOHR
|
|
80
|
+
|
|
81
|
+
from .add_elem_info import guess_element
|
|
82
|
+
|
|
83
|
+
logger = logging.getLogger(__name__)
|
|
84
|
+
|
|
85
|
+
# =============================================================================
|
|
86
|
+
# Generic helpers
|
|
87
|
+
# =============================================================================
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def ensure_dir(path: Path) -> None:
    """Make sure *path* exists as a directory, creating parents as needed.

    Idempotent: silently succeeds when the directory is already present.
    """
    path.mkdir(parents=True, exist_ok=True)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def read_xyz_as_blocks(path: Path, *, strict: bool = False) -> List[List[str]]:
    """Read an XYZ-style trajectory into per-frame blocks of raw lines.

    Each block spans the atom-count header line, the comment line, and the
    following ``n_atoms`` coordinate lines.

    Parameters
    ----------
    path
        Trajectory file to read (UTF-8).
    strict
        When True, malformed headers or truncated frames raise
        ``click.ClickException``. When False, parsing stops silently at the
        first problem and the frames collected so far are returned.

    Returns
    -------
    List[List[str]]
        One list of raw lines per frame.

    Raises
    ------
    click.ClickException
        When the file cannot be read, or (with ``strict=True``) on a
        malformed header or an incomplete frame.
    """
    # NOTE: click is imported at module level; the previous function-local
    # re-imports were redundant and have been removed.
    try:
        lines = path.read_text(encoding="utf-8").splitlines()
    except Exception as e:
        raise click.ClickException(f"Failed to read {path}: {e}")

    blocks: List[List[str]] = []
    i = 0
    while i < len(lines):
        if not lines[i].strip():
            # Skip blank separator lines between frames.
            i += 1
            continue
        try:
            n_atoms = int(lines[i].strip().split()[0])
        except Exception:
            if strict:
                raise click.ClickException(f"[xyz] Malformed XYZ/TRJ header at line {i+1} of {path}")
            break
        # Frame = header line + comment line + n_atoms coordinate lines.
        end = i + n_atoms + 2
        if end > len(lines):
            if strict:
                raise click.ClickException(f"[xyz] Incomplete XYZ frame at line {i+1} of {path}")
            break
        blocks.append(lines[i:end])
        i = end
    return blocks
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def parse_xyz_block(
    block: Sequence[str],
    *,
    path: Path,
    frame_idx: int,
) -> Tuple[List[str], "np.ndarray"]:
    """Parse a single XYZ frame block into ``(elements, coords_angstrom)``.

    Parameters
    ----------
    block
        Raw lines of one frame: atom-count header, comment line, then one
        line per atom (``elem x y z`` at minimum).
    path, frame_idx
        Used only to build error messages.

    Returns
    -------
    Tuple[List[str], np.ndarray]
        Element symbols and an ``(n_atoms, 3)`` float array of coordinates
        (same units as the file, conventionally Å).

    Raises
    ------
    click.ClickException
        On an empty block, a malformed header or atom line, or too few
        atom lines for the declared count.
    """
    # click is imported at module level; the redundant local import was removed.
    if not block:
        raise click.ClickException(f"[xyz] Empty XYZ frame in {path}")
    try:
        nat = int(block[0].strip().split()[0])
    except Exception:
        raise click.ClickException(
            f"[xyz] Malformed XYZ/TRJ header in frame {frame_idx} of {path}"
        )
    if len(block) < 2 + nat:
        raise click.ClickException(
            f"[xyz] Incomplete XYZ frame {frame_idx} in {path} (expected {nat} atoms)."
        )
    elems: List[str] = []
    coords: List[List[float]] = []
    for k in range(nat):
        # Atom lines start after the header (index 0) and comment (index 1).
        parts = block[2 + k].split()
        if len(parts) < 4:
            raise click.ClickException(
                f"[xyz] Malformed atom line in frame {frame_idx} of {path}"
            )
        elems.append(parts[0])
        coords.append([float(parts[1]), float(parts[2]), float(parts[3])])
    return elems, np.array(coords, dtype=float)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def xyz_blocks_first_last(
    blocks: Sequence[Sequence[str]],
    *,
    path: Path,
) -> Tuple[List[str], "np.ndarray", "np.ndarray"]:
    """Return ``(elements, first_coords_ang, last_coords_ang)`` from pre-parsed XYZ blocks.

    Parameters
    ----------
    blocks
        Frame blocks as produced by :func:`read_xyz_as_blocks`.
    path
        Used only for error messages.

    Raises
    ------
    click.ClickException
        When no frames are present or the element list differs between the
        first and last frame.
    """
    # click is imported at module level; the redundant local import was removed.
    if not blocks:
        raise click.ClickException(f"[xyz] No frames found in {path}")
    first_elems, first_coords = parse_xyz_block(blocks[0], path=path, frame_idx=1)
    last_elems, last_coords = parse_xyz_block(blocks[-1], path=path, frame_idx=len(blocks))
    # Guard against trajectories whose atom identity changes mid-run.
    if first_elems != last_elems:
        raise click.ClickException(f"[xyz] Element list changed across frames in {path}")
    return first_elems, first_coords, last_coords
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def read_xyz_first_last(trj_path: Path) -> Tuple[List[str], "np.ndarray", "np.ndarray"]:
    """Read an XYZ trajectory and return ``(elements, first_coords[Å], last_coords[Å])``."""
    frame_blocks = read_xyz_as_blocks(trj_path, strict=True)
    return xyz_blocks_first_last(frame_blocks, path=trj_path)
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def close_matplotlib_figures() -> None:
    """Close all open matplotlib figures; silently do nothing on any failure.

    Best effort: matplotlib may not be installed or a backend may misbehave,
    and neither situation should propagate to the caller.
    """
    try:
        import matplotlib.pyplot as plt

        plt.close("all")
    except Exception:
        pass
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def distance_A_from_coords(coords_bohr: "np.ndarray", i: int, j: int) -> float:
    """Interatomic distance between atoms *i* and *j* in Å, given coords in Bohr."""
    separation = coords_bohr[j] - coords_bohr[i]
    # ANG2BOHR converts Å -> Bohr, so dividing the Bohr norm yields Å.
    return float(np.linalg.norm(separation) / ANG2BOHR)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def distance_tag(value_A: float, *, digits: int = 2, pad: int = 3) -> str:
    """Encode a distance in Å as a zero-padded integer tag (value × 10**digits)."""
    scaled = round(value_A * (10 ** digits))
    return f"{int(scaled):0{pad}d}"
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def values_from_bounds(low: float, high: float, h: float) -> "np.ndarray":
    """Evenly spaced values from *low* to *high* (inclusive) with step size at most *h*.

    Raises ``click.BadParameter`` when *h* is not positive.
    """
    if h <= 0.0:
        raise click.BadParameter("--max-step-size must be > 0.")
    span = abs(high - low)
    if span < 1e-12:
        # Degenerate interval: a single point suffices.
        return np.array([low], dtype=float)
    n_steps = int(math.ceil(span / h))
    return np.linspace(low, high, n_steps + 1, dtype=float)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def geom_from_xyz_string(
    xyz_text: str,
    *,
    coord_type: str,
    freeze_atoms: Optional[Sequence[int]] = None,
) -> Any:
    """Build a pysisyphus Geometry from XYZ text, routed through a temp file.

    ``geom_loader`` wants a path, so the text is written to a temporary
    ``.xyz`` file that is removed afterwards (best effort).
    """
    text = xyz_text if xyz_text.endswith("\n") else xyz_text + "\n"
    frozen = [] if freeze_atoms is None else list(freeze_atoms)
    tmp = tempfile.NamedTemporaryFile("w+", suffix=".xyz", delete=False)
    try:
        tmp.write(text)
        tmp.flush()
        tmp.close()

        geom = geom_loader(
            Path(tmp.name),
            coord_type=coord_type,
            freeze_atoms=frozen,
        )
        try:
            # Normalize to a sorted, de-duplicated integer array.
            geom.freeze_atoms = np.array(sorted(set(map(int, frozen))), dtype=int)
        except Exception:
            click.echo(
                "[geom] WARNING: Failed to attach freeze_atoms to geometry.",
                err=True,
            )
        return geom
    finally:
        try:
            os.unlink(tmp.name)
        except Exception:
            logger.debug("Failed to unlink temp file %s", tmp.name, exc_info=True)
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def append_xyz_trajectory(dst_path: Path, src_path: Path, *, reset: bool = False) -> bool:
    """Append an XYZ trajectory segment to a concatenated trajectory file.

    Parameters
    ----------
    dst_path
        Destination trajectory file; truncated first when *reset* is True,
        otherwise appended to.
    src_path
        Segment to copy. When it does not exist, nothing is written.

    Returns
    -------
    bool
        True when the segment was copied, False when *src_path* is missing.
    """
    import shutil  # local import: only needed for the copy itself

    if not src_path.exists():
        return False
    mode = "w" if reset else "a"
    with src_path.open("r", encoding="utf-8") as src, dst_path.open(mode, encoding="utf-8") as dst:
        # copyfileobj streams in bounded chunks, replacing the manual read loop.
        shutil.copyfileobj(src, dst, 1024 * 1024)
    return True
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def snapshot_geometry(geom: Any, *, coord_type_default: str) -> Any:
    """Clone *geom* into an independent pysisyphus Geometry via its XYZ dump."""
    xyz_dump = geom.as_xyz()
    return geom_from_xyz_string(
        xyz_dump,
        coord_type=getattr(geom, "coord_type", coord_type_default),
        freeze_atoms=getattr(geom, "freeze_atoms", []),
    )
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def unbiased_energy_hartree(geom, base_calc) -> float:
    """Evaluate the UMA energy (Hartree) without harmonic bias.

    Returns NaN when the geometry has no element list or the calculator
    call fails for any reason (best-effort probe, never raises from the
    energy evaluation itself).
    """
    positions_bohr = np.asarray(geom.coords)
    elements = getattr(geom, "atoms", None)
    if elements is None:
        return float("nan")
    try:
        result = base_calc.get_energy(elements, positions_bohr)
        return float(result["energy"])
    except Exception:
        return float("nan")
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def pretty_block(title: str, content: Dict[str, Any]) -> str:
    """Render *content* as YAML under an underlined *title*.

    Uses ``yaml.safe_dump`` (insertion order kept, unicode allowed); an
    empty dump is shown as ``(empty)``.
    """
    dumped = yaml.safe_dump(
        _to_yaml_safe(content), sort_keys=False, allow_unicode=True
    ).strip()
    underline = "-" * len(title)
    body = dumped if dumped else "(empty)"
    return f"{title}\n{underline}\n{body}\n"
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def _to_yaml_safe(value: Any) -> Any:
|
|
299
|
+
"""Recursively convert NumPy values/containers into YAML-safe builtins."""
|
|
300
|
+
if isinstance(value, np.generic):
|
|
301
|
+
return value.item()
|
|
302
|
+
if isinstance(value, np.ndarray):
|
|
303
|
+
return [_to_yaml_safe(v) for v in value.tolist()]
|
|
304
|
+
if isinstance(value, Mapping):
|
|
305
|
+
out: Dict[Any, Any] = {}
|
|
306
|
+
for k, v in value.items():
|
|
307
|
+
nk = _to_yaml_safe(k)
|
|
308
|
+
if isinstance(nk, (list, tuple, set, dict)):
|
|
309
|
+
nk = str(nk)
|
|
310
|
+
out[nk] = _to_yaml_safe(v)
|
|
311
|
+
return out
|
|
312
|
+
if isinstance(value, tuple):
|
|
313
|
+
return [_to_yaml_safe(v) for v in value]
|
|
314
|
+
if isinstance(value, list):
|
|
315
|
+
return [_to_yaml_safe(v) for v in value]
|
|
316
|
+
if isinstance(value, set):
|
|
317
|
+
return [_to_yaml_safe(v) for v in sorted(value, key=lambda x: str(x))]
|
|
318
|
+
return value
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
# Backend-specific key prefixes in MLMM_CALC_KW.
# Keys with these prefixes are only relevant when the corresponding backend is active.
# Consumed by filter_calc_for_echo() to drop keys of inactive ML backends.
_BACKEND_KEY_PREFIXES: Dict[str, tuple] = {
    "uma": ("uma_model", "uma_task_name"),
    "orb": ("orb_model", "orb_precision"),
    "mace": ("mace_model", "mace_dtype"),
    "aimnet2": ("aimnet2_model",),
}
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def filter_calc_for_echo(calc_cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of *calc_cfg* trimmed for display.

    Drops keys that belong to inactive ML backends, and hides xTB/embedcharge
    keys when embedcharge is disabled.
    """
    trimmed = dict(calc_cfg)
    active_backend = trimmed.get("backend", "uma")

    # Drop keys that belong to backends other than the active one.
    for name, backend_keys in _BACKEND_KEY_PREFIXES.items():
        if name == active_backend:
            continue
        for key in backend_keys:
            trimmed.pop(key, None)

    # With embedcharge off, xTB settings are irrelevant for display.
    if not trimmed.get("embedcharge"):
        for key in [k for k in trimmed if k.startswith("xtb_")]:
            del trimmed[key]
        trimmed.pop("embedcharge_step", None)
        trimmed.pop("embedcharge_cutoff", None)

    return trimmed
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def strip_inherited_keys(
    child_cfg: Dict[str, Any],
    base_cfg: Dict[str, Any],
    *,
    mode: str = "present",
) -> Dict[str, Any]:
    """Return a copy of *child_cfg* with inherited keys removed (for concise logs).

    Parameters
    ----------
    child_cfg : Dict[str, Any]
        Configuration to trim.
    base_cfg : Dict[str, Any]
        Reference configuration the child inherits from.
    mode : str
        - "present": drop any key that exists in *base_cfg*, regardless of value.
        - "same": drop a key only when its value equals the one in *base_cfg*.

    Returns
    -------
    Dict[str, Any]
        New dictionary containing only the non-inherited entries.
    """
    if mode not in ("present", "same"):
        raise ValueError(f"Unknown strip_inherited_keys mode: {mode}")

    def _inherited(key: str, value: Any) -> bool:
        # A key counts as inherited when present in base and, for "same" mode,
        # carrying an equal value.
        if key not in base_cfg:
            return False
        return mode == "present" or base_cfg.get(key) == value

    return {k: v for k, v in child_cfg.items() if not _inherited(k, v)}
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
def _summarize_atom_indices(items: Sequence[Any]) -> str:
|
|
393
|
+
"""Return a compact single-line summary for atom indices."""
|
|
394
|
+
if not items:
|
|
395
|
+
return ""
|
|
396
|
+
|
|
397
|
+
count = len(items)
|
|
398
|
+
if count <= 64:
|
|
399
|
+
return f"{count} atoms [{','.join(map(str, items))}]"
|
|
400
|
+
|
|
401
|
+
head = ",".join(map(str, items[:5]))
|
|
402
|
+
tail = ",".join(map(str, items[-5:]))
|
|
403
|
+
return f"{count} atoms [{head},...,{tail}]"
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
def format_freeze_atoms_for_echo(
    cfg: Dict[str, Any],
    *,
    key: str = "freeze_atoms",
) -> Dict[str, Any]:
    """Return a copy of *cfg* with the freeze-atoms entry condensed for CLI echo.

    Strings are left untouched; iterables are summarized via
    ``_summarize_atom_indices``; non-iterable values pass through unchanged.
    """
    echoed = dict(cfg)
    raw = echoed.get(key)

    # Nothing to do for missing entries or already-formatted strings.
    if raw is None or isinstance(raw, str):
        return echoed

    try:
        entries = list(raw)
    except TypeError:
        # Not iterable: leave the value as-is.
        return echoed

    echoed[key] = _summarize_atom_indices(entries)
    return echoed
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
def format_elapsed(prefix: str, start_time: float, end_time: Optional[float] = None) -> str:
    """Return '<prefix>: HH:MM:SS.mmm' for the span start_time..end_time.

    When *end_time* is omitted, the current ``time.perf_counter()`` reading is
    used. Negative spans are clamped to zero.
    """
    stop = time.perf_counter() if end_time is None else end_time
    total = max(0.0, stop - start_time)
    hrs, remainder = divmod(total, 3600)
    mins, secs = divmod(remainder, 60)
    return f"{prefix}: {int(hrs):02d}:{int(mins):02d}:{secs:06.3f}"
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
def normalize_freeze_atoms(raw: Any) -> List[int]:
    """Coerce a freeze_atoms value (string/list/iterable/None) to a list of ints.

    Parameters
    ----------
    raw : Any
        A string (e.g. "1,2,3" or "1 2 3"), a list of integers, or any
        iterable of numeric values. Strings are scanned for optionally
        signed integer tokens.

    Returns
    -------
    List[int]
        Integer indices; [] for None or unconvertible input.

    Examples
    --------
    >>> normalize_freeze_atoms("1, 2, 3")
    [1, 2, 3]
    >>> normalize_freeze_atoms([1, 2, 3])
    [1, 2, 3]
    >>> normalize_freeze_atoms(None)
    []
    """
    import re

    if raw is None:
        return []
    if isinstance(raw, str):
        return [int(token) for token in re.findall(r"-?\d+", raw)]
    try:
        return [int(item) for item in raw]
    except Exception:
        # Non-iterable or non-numeric content: treat as "no atoms".
        return []
|
|
474
|
+
|
|
475
|
+
|
|
476
|
+
def merge_freeze_atom_indices(
    geom_cfg: Dict[str, Any],
    *indices: _Iterable[int],
) -> List[int]:
    """Merge iterables of atom indices into ``geom_cfg['freeze_atoms']``.

    Existing entries are preserved, duplicates dropped, and the sorted union
    is written back into *geom_cfg* and returned.
    """
    union: set[int] = set(normalize_freeze_atoms(geom_cfg.get("freeze_atoms", None)))
    for extra in indices:
        union |= set(normalize_freeze_atoms(extra))
    combined = sorted(union)
    geom_cfg["freeze_atoms"] = combined
    return combined
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
# =============================================================================
|
|
496
|
+
# Link-freezing helpers
|
|
497
|
+
# =============================================================================
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
def parse_pdb_coords(pdb_path):
    """Split ATOM/HETATM records of *pdb_path* into link hydrogens and the rest.

    Returns:
        A tuple (others, lkhs) where:
        - others: list of (index, x, y, z, line) for every atom except the
          'HL' atom of residue 'LKH'. ``index`` is the 0-based position in the
          atom sequence of the *first* MODEL (or the whole file if no MODEL
          records are present).
        - lkhs: list of (x, y, z, line) for atoms with residue name 'LKH' and
          atom name 'HL' in the same MODEL selection.

    Notes
    -----
    - Coordinates are taken from the standard PDB columns:
      X: columns 31-38, Y: 39-46, Z: 47-54 (1-based indexing).
    - When multiple MODEL blocks exist, only the first model is read,
      matching typical geom_loader behavior.
    """
    with open(pdb_path, "r") as handle:
        content = handle.readlines()

    regular = []
    link_hydrogens = []
    saw_model = False
    first_model_active = True
    counter = 0

    for record in content:
        if record.startswith("MODEL"):
            # The first MODEL opens the window we read; later MODELs close it.
            if saw_model:
                first_model_active = False
            else:
                saw_model = True
                first_model_active = True
            continue
        if record.startswith("ENDMDL"):
            if saw_model and first_model_active:
                break
            continue
        if saw_model and not first_model_active:
            continue
        if not (record.startswith("ATOM") or record.startswith("HETATM")):
            continue

        # The atom index counts every ATOM/HETATM record, even ones whose
        # coordinates later fail to parse.
        position = counter
        counter += 1

        atom_name = record[12:16].strip()
        residue = record[17:20].strip()
        try:
            cx = float(record[30:38])
            cy = float(record[38:46])
            cz = float(record[46:54])
        except ValueError:
            # Malformed coordinate fields: skip the record entirely.
            continue

        if residue == "LKH" and atom_name == "HL":
            link_hydrogens.append((cx, cy, cz, record))
        else:
            regular.append((position, cx, cy, cz, record))
    return regular, link_hydrogens
|
|
561
|
+
|
|
562
|
+
|
|
563
|
+
def nearest_index(point, pool):
    """Return the pool entry closest to *point* by Euclidean distance.

    Args:
        point: Tuple (x, y, z) representing the query coordinate.
        pool: Iterable of tuples (index, x, y, z, line) to search.

    Returns:
        A tuple (index, distance) where:
        - index is the stored index of the nearest entry (-1 if *pool* is empty).
        - distance is the Euclidean distance to it (``inf`` if *pool* is empty).
    """
    px, py, pz = point
    winner = -1
    min_sq = float("inf")
    for idx, qx, qy, qz, _record in pool:
        # Compare squared distances; sqrt only once at the end.
        sq = (qx - px) ** 2 + (qy - py) ** 2 + (qz - pz) ** 2
        if sq < min_sq:
            min_sq = sq
            winner = idx
    return winner, math.sqrt(min_sq)
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
def detect_freeze_links(pdb_path):
    """Return 0-based indices of the link-parent atoms for 'LKH'/'HL' hydrogens.

    Each 'HL' atom in residue 'LKH' is matched to its nearest neighbor among
    all other ATOM/HETATM records (first MODEL only, matching geometry
    loading); the neighbors' indices in the full atom ordering are returned.

    Args:
        pdb_path: Path to the input PDB file.

    Returns:
        List of 0-based indices into the full atom sequence (including any
        link H atoms). Empty when there are no LKH/HL atoms, or when link
        hydrogens exist without any other atoms to match against.
    """
    candidates, hydrogens = parse_pdb_coords(pdb_path)
    if not hydrogens or not candidates:
        return []

    parents = []
    for hx, hy, hz, _record in hydrogens:
        parent, _distance = nearest_index((hx, hy, hz), candidates)
        if parent >= 0:
            parents.append(parent)
    return parents
|
|
612
|
+
|
|
613
|
+
|
|
614
|
+
def detect_freeze_links_logged(pdb_path: Path) -> List[int]:
    """Return link-parent indices, converting any failure into a ClickException."""
    try:
        parents = list(detect_freeze_links(pdb_path))
    except Exception as exc:
        # Surface a user-facing CLI error instead of a raw traceback.
        raise click.ClickException(
            f"[freeze-links] Failed to detect link parents for '{pdb_path.name}': {exc}"
        ) from exc
    return parents
|
|
622
|
+
|
|
623
|
+
|
|
624
|
+
def merge_detected_freeze_links(
    geom_cfg: Dict[str, Any],
    pdb_path: Path,
    *,
    prefix: str = "[freeze-links]",
) -> List[int]:
    """Detect link-parent atoms and fold them into ``geom_cfg['freeze_atoms']``."""
    parents = detect_freeze_links_logged(pdb_path)
    combined = merge_freeze_atom_indices(geom_cfg, parents)
    if combined:
        joined = ",".join(str(i) for i in combined)
        click.echo(f"{prefix} Freeze atoms (0-based): {joined}")
    return combined
|
|
636
|
+
|
|
637
|
+
|
|
638
|
+
def apply_layer_freeze_constraints(
    geom_cfg: Dict[str, Any],
    calc_cfg: Dict[str, Any],
    layer_info: Optional[Dict[str, Sequence[int]]],
    *,
    echo_fn: Optional[Callable[[str], None]] = None,
) -> List[int]:
    """Union frozen-layer atoms into both geometry and calculator freeze lists."""
    existing = set(normalize_freeze_atoms(geom_cfg.get("freeze_atoms")))
    layer_frozen = set(normalize_freeze_atoms((layer_info or {}).get("frozen_indices", [])))

    if layer_frozen:
        combined = sorted(existing | layer_frozen)
        newly_added = len(set(combined) - existing)
        if echo_fn is not None:
            echo_fn(
                f"[layer] Applied freeze constraints from frozen layer: "
                f"total={len(combined)} (added_from_layer={newly_added})"
            )
    else:
        combined = sorted(existing)

    # Keep geometry and calculator configs in sync on the merged list.
    geom_cfg["freeze_atoms"] = combined
    calc_cfg["freeze_atoms"] = combined
    return combined
|
|
664
|
+
|
|
665
|
+
|
|
666
|
+
def deep_update(dst: Dict[str, Any], src: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Recursively merge *src* into *dst* in place and return *dst*.

    Nested dicts are merged key-by-key; any other value overwrites the
    existing entry. A None *src* is treated as empty.
    """
    if src:
        for key, incoming in src.items():
            current = dst.get(key)
            if isinstance(incoming, dict) and isinstance(current, dict):
                deep_update(current, incoming)
            else:
                dst[key] = incoming
    return dst
|
|
676
|
+
|
|
677
|
+
|
|
678
|
+
def collect_single_option_values(
    argv: _Sequence[str],
    names: _Sequence[str],
    label: str,
) -> List[str]:
    """Collect the values following a flag that may appear at most once.

    Scans *argv* for any token in *names* and gathers the subsequent tokens
    up to the next flag (i.e. a token starting with '-'). Raises
    ``click.BadParameter`` when the flag occurs more than once.
    """
    collected: List[str] = []
    occurrences = 0
    pos = 0
    total = len(argv)
    while pos < total:
        if argv[pos] in names:
            occurrences += 1
            # Consume every value token until the next flag.
            pos += 1
            while pos < total and not argv[pos].startswith("-"):
                collected.append(argv[pos])
                pos += 1
        else:
            pos += 1
    if occurrences > 1:
        raise click.BadParameter(
            f"Use a single {label} followed by multiple values; repeated flags are not accepted."
        )
    return collected
|
|
703
|
+
|
|
704
|
+
|
|
705
|
+
def load_pdb_atom_metadata(pdb_path: Path) -> List[Dict[str, Any]]:
    """Return per-atom metadata (serial, name, resname, resseq, element) in file order.

    Element symbols missing from PDB columns 77-78 are inferred through
    ``guess_element``; blank or malformed serial/resseq columns become None.
    """

    def _maybe_int(text: str) -> Optional[int]:
        # Blank or non-numeric fixed-width columns map to None.
        try:
            return int(text) if text else None
        except ValueError:
            return None

    records: List[Dict[str, Any]] = []
    with open(pdb_path, "r") as handle:
        for row in handle:
            if not (row.startswith("ATOM") or row.startswith("HETATM")):
                continue

            name_field = row[12:16].strip()
            resname_field = row[17:20].strip()
            element_field = row[76:78].strip()
            hetatm = row.startswith("HETATM")

            serial_val = _maybe_int(row[6:11].strip())
            resseq_val = _maybe_int(row[22:26].strip())

            if not element_field:
                # Fall back to name-based inference when the element column is empty.
                element_field = guess_element(name_field, resname_field, hetatm) or ""

            records.append(
                {
                    "serial": serial_val,
                    "name": name_field,
                    "resname": resname_field,
                    "resseq": resseq_val,
                    "element": element_field,
                }
            )
    return records
|
|
743
|
+
|
|
744
|
+
|
|
745
|
+
def resolve_atom_spec_index(spec: str, atom_meta: _Sequence[Dict[str, Any]]) -> int:
    """Resolve a selector like 'SER 123 OG' to a 0-based index in *atom_meta*.

    First attempts an order-independent match of the three fields against
    each atom's (resname, resseq, atomname); when that yields nothing, falls
    back to the strict ordered interpretation resname/resseq/atomname.

    Raises
    ------
    ValueError
        For malformed specs, ambiguous matches, or no match at all.
    """
    tokens = [tok for tok in re.split(r"[\s/`,\\]+", spec.strip().replace(" ", ",")) if tok]
    if len(tokens) != 3:
        raise ValueError(
            f"Atom spec '{spec}' must have exactly 3 fields (resname, resseq, atomname)."
        )

    upper = [tok.upper() for tok in tokens]

    # Pass 1: order-independent — every token must appear among the atom's fields.
    unordered: List[int] = []
    for pos, entry in enumerate(atom_meta):
        seq = entry.get("resseq")
        if seq is None:
            continue
        res = (entry.get("resname") or "").strip().upper()
        nam = (entry.get("name") or "").strip().upper()
        candidate_fields = {res, str(seq), nam}
        if all(tok in candidate_fields for tok in upper):
            unordered.append(pos)

    if len(unordered) == 1:
        return unordered[0]
    if len(unordered) > 1:
        raise ValueError(
            f"Atom spec '{spec}' matches {len(unordered)} atoms; use an explicit atom index."
        )

    # Pass 2: ordered fallback — tokens strictly as (resname, resseq, atomname).
    res_tok, seq_tok, name_tok = upper
    if not seq_tok.isdigit():
        raise ValueError(
            f"Atom spec '{spec}' could not be resolved and residue number '{tokens[1]}' is not numeric."
        )
    seq_val = int(seq_tok)
    ordered = [
        pos
        for pos, entry in enumerate(atom_meta)
        if (entry.get("resname") or "").strip().upper() == res_tok
        and entry.get("resseq") == seq_val
        and (entry.get("name") or "").strip().upper() == name_tok
    ]
    if len(ordered) == 1:
        return ordered[0]
    if len(ordered) > 1:
        raise ValueError(
            f"Atom spec '{spec}' matches {len(ordered)} atoms after ordered fallback; "
            "use an explicit atom index."
        )

    raise ValueError(f"Atom spec '{spec}' did not match any atom.")
|
|
794
|
+
|
|
795
|
+
|
|
796
|
+
def atom_label_from_meta(atom_meta: _Sequence[Dict[str, Any]], index: int) -> str:
    """Return a 'RESNAME-RESSEQ-ATOM' label for atom *index*, or 'idxN' when out of range."""
    if not 0 <= index < len(atom_meta):
        return f"idx{index}"
    entry = atom_meta[index]
    # Missing or blank fields fall back to "?" placeholders.
    res = (entry.get("resname") or "?").strip() or "?"
    seq = entry.get("resseq")
    seq_txt = str(seq) if seq is not None else "?"
    nam = (entry.get("name") or "?").strip() or "?"
    return f"{res}-{seq_txt}-{nam}"
|
|
805
|
+
|
|
806
|
+
|
|
807
|
+
def axis_label_csv(
    axis_name: str,
    i_idx: int,
    j_idx: int,
    one_based: bool,
    atom_meta: Optional[_Sequence[Dict[str, Any]]] = None,
    pair_raw: Optional[Tuple[Any, Any, float, float]] = None,
) -> str:
    """Build a CSV column label like 'dist_3_7_A'.

    When the raw pair used string atom specs and metadata is available,
    human-readable atom labels replace the numeric indices.
    """
    use_labels = bool(
        pair_raw
        and atom_meta
        and (isinstance(pair_raw[0], str) or isinstance(pair_raw[1], str))
    )
    if use_labels:
        left = atom_label_from_meta(atom_meta, i_idx)
        right = atom_label_from_meta(atom_meta, j_idx)
        return f"{axis_name}_{left}_{right}_A"
    offset = 1 if one_based else 0
    return f"{axis_name}_{i_idx + offset}_{j_idx + offset}_A"
|
|
822
|
+
|
|
823
|
+
|
|
824
|
+
def axis_label_html(label: str) -> str:
    """Convert an 'axis_i_j_A' CSV label into 'axis (i,j) (Å)'; others pass through."""
    fields = label.split("_")
    if len(fields) < 4 or fields[-1] != "A":
        return label
    axis, left, right = fields[0], fields[1], fields[2]
    return f"{axis} ({left},{right}) (Å)"
|
|
832
|
+
|
|
833
|
+
|
|
834
|
+
def resolve_scan_index(
    value: Any,
    *,
    one_based: bool,
    atom_meta: Optional[_Sequence[Dict[str, Any]]],
    context: str,
) -> int:
    """Resolve an int index or atom-spec string into a 0-based atom index.

    Parameters
    ----------
    value : Any
        Integer index (interpreted per *one_based*) or an atom-spec string
        resolvable via :func:`resolve_atom_spec_index`.
    one_based : bool
        When True, integer inputs are converted from 1-based to 0-based.
    atom_meta : Optional[_Sequence[Dict[str, Any]]]
        PDB metadata required to resolve string specs.
    context : str
        Label used in error messages (e.g. option/entry name).

    Raises
    ------
    click.BadParameter
        For negative results, unresolvable specs, missing metadata, or
        unsupported value types.
    """
    if isinstance(value, Integral):
        idx_val = int(value)
        if one_based:
            idx_val -= 1
        if idx_val < 0:
            raise click.BadParameter(
                f"Negative atom index after base conversion in {context}: {idx_val} (0-based expected)."
            )
        return idx_val
    if isinstance(value, str):
        if not atom_meta:
            raise click.BadParameter(
                f"{context} uses a string atom spec, but no PDB metadata is available."
            )
        try:
            return resolve_atom_spec_index(value, atom_meta)
        except ValueError as exc:
            # Chain the original error so the root cause stays visible in
            # tracebacks (previously raised without `from`, losing context).
            raise click.BadParameter(f"{context} {exc}") from exc
    raise click.BadParameter(f"{context} must be an int index or atom spec string.")
|
|
861
|
+
|
|
862
|
+
|
|
863
|
+
def parse_scan_list_triples(
    raw: str,
    *,
    one_based: bool,
    atom_meta: Optional[_Sequence[Dict[str, Any]]],
    option_name: str,
    return_one_based: bool = False,
) -> Tuple[List[Tuple[int, int, float]], List[Tuple[Any, Any, float]]]:
    """Parse --scan-lists entries into index tuples (0-based by default).

    Accepts 3-tuples ``(i, j, target)`` and 4-tuples ``(i, j, start, end)``
    for bidirectional scans. 4-tuples are expanded into two 3-tuple stages
    (initial→start, then initial→end) by the caller in scan.py.

    The returned *parsed* list contains tuples of length 3 **or** 4:
    ``(i, j, target)`` or ``(i, j, start, end)``.
    """
    try:
        literal = ast.literal_eval(raw)
    except Exception as e:
        raise click.BadParameter(f"Invalid literal for {option_name}: {e}")

    if not isinstance(literal, (list, tuple)):
        raise click.BadParameter(f"{option_name} must be a list/tuple of (i,j,target) or (i,j,start,end).")

    parsed: list = []
    for pos, entry in enumerate(literal, start=1):
        # A valid entry is a 3- or 4-sequence whose trailing values are numeric.
        if isinstance(entry, (list, tuple)) and len(entry) in (3, 4):
            well_formed = all(isinstance(v, Real) for v in entry[2:])
        else:
            well_formed = False
        if not well_formed:
            raise click.BadParameter(
                f"{option_name} entry {pos} must be (i,j,target) or (i,j,start,end): got {entry}"
            )

        first = resolve_scan_index(
            entry[0],
            one_based=one_based,
            atom_meta=atom_meta,
            context=f"{option_name} entry {pos} (i)",
        )
        second = resolve_scan_index(
            entry[1],
            one_based=one_based,
            atom_meta=atom_meta,
            context=f"{option_name} entry {pos} (j)",
        )
        if return_one_based:
            first += 1
            second += 1
        tail = tuple(float(v) for v in entry[2:])
        parsed.append((first, second) + tail)

    return parsed, list(literal)
|
|
928
|
+
|
|
929
|
+
|
|
930
|
+
def parse_dist_freeze_list(
    raw: str,
    *,
    one_based: bool,
    atom_meta: Optional[_Sequence[Dict[str, Any]]],
    option_name: str = "--dist-freeze",
) -> List[Tuple[int, int, Optional[float]]]:
    """Parse ``--dist-freeze`` entries: ``(i,j)`` or ``(i,j,target_A)``.

    Uses the same :func:`resolve_scan_index` as ``--scan-lists``, so string
    atom specs (e.g. ``'A:SER123:OG'``) are supported when PDB metadata is
    available.
    """
    try:
        literal = ast.literal_eval(raw)
    except Exception as e:
        raise click.BadParameter(f"Invalid literal for {option_name}: {e}")

    if not isinstance(literal, (list, tuple)):
        raise click.BadParameter(f"{option_name} must be a list/tuple of (i,j) or (i,j,target).")

    # A bare (i, j[, target]) tuple is promoted to a one-entry list.
    if literal and not isinstance(literal[0], (list, tuple)):
        literal = [literal]

    result: List[Tuple[int, int, Optional[float]]] = []
    for pos, entry in enumerate(literal, start=1):
        if not (isinstance(entry, (list, tuple)) and len(entry) in (2, 3)):
            raise click.BadParameter(
                f"{option_name} entry {pos} must be (i,j) or (i,j,target): got {entry}"
            )
        left = resolve_scan_index(
            entry[0], one_based=one_based, atom_meta=atom_meta,
            context=f"{option_name} entry {pos} (i)",
        )
        right = resolve_scan_index(
            entry[1], one_based=one_based, atom_meta=atom_meta,
            context=f"{option_name} entry {pos} (j)",
        )
        distance: Optional[float] = None
        if len(entry) == 3:
            if not isinstance(entry[2], Real):
                raise click.BadParameter(
                    f"Target distance must be numeric in {option_name} entry {pos}: {entry}"
                )
            distance = float(entry[2])
            if distance <= 0.0:
                raise click.BadParameter(
                    f"Target distance must be > 0 in {option_name} entry {pos}: {entry}"
                )
        result.append((left, right, distance))
    return result
|
|
982
|
+
|
|
983
|
+
|
|
984
|
+
def parse_dist_freeze_spec(
    spec_path: Path,
    *,
    one_based_default: bool,
    atom_meta: Optional[_Sequence[Dict[str, Any]]],
    option_name: str = "--dist-freeze",
) -> List[Tuple[int, int, Optional[float]]]:
    """Parse a YAML/JSON dist-freeze spec file.

    Expected format::

        constraints:         # or "pairs" / "stages"
          - [1, 5, 1.4]      # (i, j, target_A) — target optional
          - [2, 6]           # freeze at current distance
        one_based: true      # optional, defaults to CLI value
    """
    root = _load_scan_spec_root(spec_path, option_name=option_name)
    field_key, entries = _first_spec_field(root, ("constraints", "pairs", "stages"))
    if field_key is None:
        raise click.BadParameter(
            f"{option_name} spec must define 'constraints', 'pairs', or 'stages'."
        )
    if not isinstance(entries, (list, tuple)) or len(entries) == 0:
        raise click.BadParameter(
            f"{option_name} field '{field_key}' must be a non-empty list."
        )

    resolved_one_based = _spec_one_based(
        root.get("one_based"), default=one_based_default, option_name=option_name,
    )
    # Round-trip through repr() so the shared list parser performs validation.
    return parse_dist_freeze_list(
        repr(list(entries)),
        one_based=resolved_one_based,
        atom_meta=atom_meta,
        option_name=f"{option_name} {field_key}",
    )
|
|
1020
|
+
|
|
1021
|
+
|
|
1022
|
+
def parse_scan_list_quads(
    raw: str,
    *,
    expected_len: int,
    one_based: bool,
    atom_meta: Optional[_Sequence[Dict[str, Any]]],
    option_name: str,
) -> Tuple[List[Tuple[int, int, float, float]], List[Tuple[Any, Any, float, float]]]:
    """Parse --scan-lists quadruples ``(i, j, low, high)`` into 0-based indices."""
    try:
        literal = ast.literal_eval(raw)
    except Exception as e:
        raise click.BadParameter(f"Invalid literal for {option_name}: {e}")

    if not (isinstance(literal, (list, tuple)) and len(literal) == expected_len):
        quads = ",".join([f"(i{n},j{n},low{n},high{n})" for n in range(1, expected_len + 1)])
        raise click.BadParameter(
            f"{option_name} must contain exactly {expected_len} quadruples: [{quads}]"
        )

    result: List[Tuple[int, int, float, float]] = []
    for pos, quad in enumerate(literal, start=1):
        well_formed = (
            isinstance(quad, (list, tuple))
            and len(quad) == 4
            and isinstance(quad[2], Real)
            and isinstance(quad[3], Real)
        )
        if not well_formed:
            raise click.BadParameter(f"{option_name} entry must be (i,j,low,high): got {quad}")

        left = resolve_scan_index(
            quad[0],
            one_based=one_based,
            atom_meta=atom_meta,
            context=f"{option_name} entry {pos} (i)",
        )
        right = resolve_scan_index(
            quad[1],
            one_based=one_based,
            atom_meta=atom_meta,
            context=f"{option_name} entry {pos} (j)",
        )
        result.append((left, right, float(quad[2]), float(quad[3])))

    # Validate the distance bounds after all entries have resolved.
    for left, right, low, high in result:
        if low <= 0.0 or high <= 0.0:
            raise click.BadParameter(f"Distances must be positive: {(left, right, low, high)}")

    return result, list(literal)
|
|
1071
|
+
|
|
1072
|
+
|
|
1073
|
+
def _load_scan_spec_root(
    spec_path: Path,
    *,
    option_name: str = "--scan-lists",
) -> Mapping[str, Any]:
    """Load a YAML/JSON scan spec file and require a mapping at the root."""
    try:
        with open(spec_path, "r", encoding="utf-8") as stream:
            parsed = yaml.safe_load(stream)
    except Exception as exc:
        raise click.BadParameter(
            f"Failed to parse {option_name} file '{spec_path}': {exc}"
        )

    if parsed is None:
        raise click.BadParameter(f"{option_name} file '{spec_path}' is empty.")
    if not isinstance(parsed, Mapping):
        raise click.BadParameter(
            f"{option_name} file '{spec_path}' must have a mapping at the YAML/JSON root."
        )
    return parsed
|
|
1094
|
+
|
|
1095
|
+
|
|
1096
|
+
def _spec_one_based(
|
|
1097
|
+
value: Any,
|
|
1098
|
+
*,
|
|
1099
|
+
default: bool,
|
|
1100
|
+
option_name: str = "--scan-lists",
|
|
1101
|
+
) -> bool:
|
|
1102
|
+
"""Resolve one_based value from spec with CLI fallback."""
|
|
1103
|
+
if value is None:
|
|
1104
|
+
return bool(default)
|
|
1105
|
+
if isinstance(value, bool):
|
|
1106
|
+
return value
|
|
1107
|
+
if isinstance(value, str):
|
|
1108
|
+
key = value.strip().lower()
|
|
1109
|
+
if key in {"1", "true", "yes", "y", "on"}:
|
|
1110
|
+
return True
|
|
1111
|
+
if key in {"0", "false", "no", "n", "off"}:
|
|
1112
|
+
return False
|
|
1113
|
+
raise click.BadParameter(
|
|
1114
|
+
f"{option_name} field 'one_based' must be a boolean (true/false)."
|
|
1115
|
+
)
|
|
1116
|
+
|
|
1117
|
+
|
|
1118
|
+
def _first_spec_field(
|
|
1119
|
+
spec_cfg: Mapping[str, Any],
|
|
1120
|
+
candidates: _Sequence[str],
|
|
1121
|
+
) -> Tuple[Optional[str], Any]:
|
|
1122
|
+
for key in candidates:
|
|
1123
|
+
if key in spec_cfg:
|
|
1124
|
+
return key, spec_cfg[key]
|
|
1125
|
+
return None, None
|
|
1126
|
+
|
|
1127
|
+
|
|
1128
|
+
def is_scan_spec_file(value: str) -> bool:
    """Return True when *value* names an existing file with a YAML/JSON suffix."""
    candidate = Path(value)
    if not candidate.is_file():
        return False
    return candidate.suffix.lower() in {".yaml", ".yml", ".json"}
|
|
1132
|
+
|
|
1133
|
+
|
|
1134
|
+
def parse_scan_spec_stages(
    spec_path: Path,
    *,
    one_based_default: bool,
    atom_meta: Optional[_Sequence[Dict[str, Any]]],
    option_name: str = "--scan-lists",
) -> Tuple[List[List[Tuple[int, int, float]]], bool]:
    """Parse staged 1D scan spec into 0-based stage triples.

    Args:
        spec_path: YAML/JSON file whose root mapping must define ``stages``,
            a non-empty list of stages; each stage is a list of
            ``(i, j, target)`` entries.
        one_based_default: Fallback for the spec's optional ``one_based`` flag.
        atom_meta: Optional per-atom metadata forwarded to
            ``parse_scan_list_triples`` (for richer diagnostics).
        option_name: Option name used as a prefix in all error messages.

    Returns:
        ``(stages, one_based)`` — each stage as a list of 0-based
        ``(i, j, target_distance)`` triples, plus the resolved indexing flag.

    Raises:
        click.BadParameter: on a missing/empty ``stages`` list, a malformed
            stage entry, an empty stage, or a non-positive target distance.
    """
    spec_cfg = _load_scan_spec_root(spec_path, option_name=option_name)
    stages_key, stages_raw = _first_spec_field(spec_cfg, ("stages",))
    if stages_key is None:
        raise click.BadParameter(f"{option_name} must define 'stages'.")
    if not isinstance(stages_raw, (list, tuple)) or len(stages_raw) == 0:
        raise click.BadParameter(f"{option_name} field '{stages_key}' must be a non-empty list.")

    one_based = _spec_one_based(
        spec_cfg.get("one_based"), default=one_based_default, option_name=option_name
    )
    stages: List[List[Tuple[int, int, float]]] = []
    # Stage numbering in messages is 1-based for user-facing clarity.
    for stage_idx, stage_raw in enumerate(stages_raw, start=1):
        if not isinstance(stage_raw, (list, tuple)):
            raise click.BadParameter(
                f"{option_name} {stages_key}[{stage_idx}] must be a list of (i,j,target) entries."
            )
        # Re-serialize the YAML list so the CLI string parser can be reused.
        # NOTE(review): assumes parse_scan_list_triples accepts the
        # Python-literal syntax produced by repr() — confirm against its impl.
        parsed, _ = parse_scan_list_triples(
            repr(list(stage_raw)),
            one_based=one_based,
            atom_meta=atom_meta,
            option_name=f"{option_name} {stages_key}[{stage_idx}]",
        )
        if not parsed:
            raise click.BadParameter(
                f"{option_name} {stages_key}[{stage_idx}] must contain at least one (i,j,target) triple."
            )
        # Distances must be strictly positive to be physically meaningful.
        for i, j, target in parsed:
            if target <= 0.0:
                raise click.BadParameter(
                    f"Non-positive target distance in {option_name} {stages_key}[{stage_idx}]: {(i, j, target)}."
                )
        stages.append(parsed)
    return stages, one_based
|
|
1175
|
+
|
|
1176
|
+
|
|
1177
|
+
def parse_scan_spec_quads(
    spec_path: Path,
    *,
    expected_len: int,
    one_based_default: bool,
    atom_meta: Optional[_Sequence[Dict[str, Any]]],
    option_name: str = "--scan-lists",
) -> Tuple[List[Tuple[int, int, float, float]], List[Tuple[Any, Any, float, float]], bool]:
    """Parse 2D/3D scan spec into 0-based quad tuples.

    Args:
        spec_path: YAML/JSON file whose root mapping must define ``pairs``.
        expected_len: Required number of quads, enforced by
            ``parse_scan_list_quads``.
        one_based_default: Fallback for the spec's optional ``one_based`` flag.
        atom_meta: Optional per-atom metadata forwarded to the quad parser.
        option_name: Option name used as a prefix in error messages.

    Returns:
        ``(parsed, raw_pairs, one_based)`` — parsed 0-based quads, the raw
        entries as given in the spec, and the resolved indexing flag.

    Raises:
        click.BadParameter: when ``pairs`` is missing or not a list (further
            validation happens inside ``parse_scan_list_quads``).
    """
    spec_cfg = _load_scan_spec_root(spec_path, option_name=option_name)
    pairs_key, pairs_raw = _first_spec_field(spec_cfg, ("pairs",))
    if pairs_key is None:
        raise click.BadParameter(f"{option_name} must define 'pairs'.")
    if not isinstance(pairs_raw, (list, tuple)):
        raise click.BadParameter(f"{option_name} field '{pairs_key}' must be a list.")

    one_based = _spec_one_based(
        spec_cfg.get("one_based"), default=one_based_default, option_name=option_name
    )
    # Re-serialize the YAML list so the CLI string parser can be reused
    # (same repr() round-trip as parse_scan_spec_stages).
    parsed, raw_pairs = parse_scan_list_quads(
        repr(list(pairs_raw)),
        expected_len=expected_len,
        one_based=one_based,
        atom_meta=atom_meta,
        option_name=f"{option_name} {pairs_key}",
    )
    return parsed, raw_pairs, one_based
|
|
1204
|
+
|
|
1205
|
+
|
|
1206
|
+
# Column header aligned with format_pdb_atom_metadata() output:
# serial, atom name, residue name, residue number, element.
PDB_ATOM_META_HEADER = f"{'id':>5} {'atom':<4} {'res':<4} {'resid':>4} {'el':<2}"
|
|
1207
|
+
|
|
1208
|
+
|
|
1209
|
+
def format_pdb_atom_metadata(atom_meta: _Sequence[Dict[str, Any]], index: int) -> str:
    """Render atom *index* as an aligned text row: serial name resname resseq element.

    Missing or falsy fields fall back to ``?`` placeholders; an out-of-range
    *index* yields an all-placeholder row with a 1-based fallback serial.
    """
    default_serial = index + 1
    if not (0 <= index < len(atom_meta)):
        return f"{default_serial:>5} {'?':<4} {'?':<4} {'?':>4} {'?':<2}"

    record = atom_meta[index]
    serial = record.get("serial") or default_serial
    atom_name = record.get("name") or "?"
    res_name = record.get("resname") or "?"
    res_seq = record.get("resseq")
    seq_text = "?" if res_seq is None else str(res_seq)
    elem = (record.get("element") or "?").strip() or "?"

    return f"{serial:>5} {atom_name:<4} {res_name:<4} {seq_text:>4} {elem:<2}"
|
|
1224
|
+
|
|
1225
|
+
|
|
1226
|
+
def normalize_choice(
    value: str,
    *,
    param: str,
    alias_groups: Sequence[Tuple[Sequence[str], str]],
    allowed_hint: str,
) -> str:
    """Map *value* onto a canonical mode name via case-insensitive aliases.

    Parameters
    ----------
    value : str
        Raw user input; surrounding whitespace and case are ignored.
    param : str
        Parameter name used when reporting an invalid value.
    alias_groups : Sequence[Tuple[Sequence[str], str]]
        Pairs of (accepted aliases, canonical name).
    allowed_hint : str
        Human-readable list of accepted values for the error message
        (may be empty).

    Returns
    -------
    str
        The canonical name for the matched alias group.

    Raises
    ------
    click.BadParameter
        If *value* matches no alias in any group.
    """
    needle = (value or "").strip().lower()
    for aliases, canonical in alias_groups:
        if needle in {alias.lower() for alias in aliases}:
            return canonical

    hint = allowed_hint.strip()
    detail = f" Allowed: {hint}." if hint else ""
    raise click.BadParameter(f"Unknown value for {param} '{value}'.{detail}")
|
|
1264
|
+
|
|
1265
|
+
|
|
1266
|
+
def _get_mapping_section(cfg: Mapping[str, Any], path: _Sequence[str]) -> Optional[Dict[str, Any]]:
|
|
1267
|
+
cur: Any = cfg
|
|
1268
|
+
for key in path:
|
|
1269
|
+
if not isinstance(cur, Mapping):
|
|
1270
|
+
return None
|
|
1271
|
+
cur = cur.get(key)
|
|
1272
|
+
if cur is None:
|
|
1273
|
+
return None
|
|
1274
|
+
return cur if isinstance(cur, dict) else None
|
|
1275
|
+
|
|
1276
|
+
|
|
1277
|
+
def apply_yaml_overrides(
    yaml_cfg: Mapping[str, Any],
    overrides: _Sequence[Tuple[Dict[str, Any], _Sequence[_Sequence[str]]]],
) -> None:
    """Merge YAML sections into their target dictionaries.

    Parameters
    ----------
    yaml_cfg : Mapping[str, Any]
        Parsed YAML configuration (root-level mapping).
    overrides : Sequence[Tuple[Dict[str, Any], Sequence[Sequence[str]]]]
        Each entry pairs a target dict with candidate key paths; the first
        path that resolves to a mapping wins and is deep-merged into the
        target. Example::

            apply_yaml_overrides(
                yaml_cfg,
                [
                    (geom_cfg, (("geom",),)),
                    (lbfgs_cfg, (("stopt", "lbfgs"), ("lbfgs",))),
                ],
            )

    Centralizes what used to be repeated
    ``deep_update(..., yaml_cfg.get(...))`` calls.
    """
    for target, candidate_paths in overrides:
        for candidate in candidate_paths:
            section = _get_mapping_section(yaml_cfg, tuple(candidate))
            if section is None:
                continue
            deep_update(target, section)
            break
|
|
1309
|
+
|
|
1310
|
+
|
|
1311
|
+
def yaml_section_has_key(
    yaml_cfg: Mapping[str, Any],
    paths: _Sequence[_Sequence[str]],
    key: str,
) -> bool:
    """Return True when any candidate YAML section explicitly carries *key*."""
    sections = (_get_mapping_section(yaml_cfg, tuple(p)) for p in paths)
    return any(
        key in section for section in sections if isinstance(section, Mapping)
    )
|
|
1322
|
+
|
|
1323
|
+
|
|
1324
|
+
def load_yaml_dict(path: Optional[Path]) -> Dict[str, Any]:
    """Load a YAML file whose root must be a mapping.

    Args:
        path: YAML file to read; ``None`` (or any falsy value) yields ``{}``.

    Returns:
        The parsed root mapping; an empty file also yields ``{}``.

    Raises:
        ValueError: when the YAML root is not a mapping.
    """
    if not path:
        return {}

    # Read as UTF-8 explicitly, matching the other YAML reader in this module
    # (the scan-spec loader) instead of relying on the platform default.
    with open(path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f) or {}

    if not isinstance(data, dict):
        raise ValueError(f"YAML root must be a mapping, got: {type(data)}")

    return data
|
|
1338
|
+
|
|
1339
|
+
|
|
1340
|
+
# =============================================================================
|
|
1341
|
+
# Plotly: Energy diagram builder
|
|
1342
|
+
# =============================================================================
|
|
1343
|
+
def build_energy_diagram(
    energies: Sequence[float],
    labels: Sequence[str],
    ylabel: str = "ΔE",
    baseline: bool = False,
    showgrid: bool = False,
) -> go.Figure:
    """
    Plot an energy diagram using Plotly.

    Parameters
    ----------
    energies : Sequence[float]
        Energies for each state (same unit). Values are plotted without conversion.
    labels : Sequence[str]
        Labels corresponding to each state (for example, ["R", "TS1", "IM1", "TS2", "P"]).
        Must be the same length as ``energies``.
    ylabel : str, optional
        Y-axis label (for example, "ΔE" or "ΔG"). Defaults to ``"ΔE"``.
    baseline : bool, optional
        If ``True``, draw a dotted baseline at the energy of the first state across the plot.
    showgrid : bool, optional
        If ``True``, show grid lines on both axes. Defaults to ``False``.

    Returns
    -------
    plotly.graph_objs.Figure
        Figure containing the energy diagram.

    Raises
    ------
    ValueError
        If ``energies`` is empty or its length differs from ``labels``.

    Notes
    -----
    - Each state is rendered as a thick horizontal segment (width ``HLINE_WIDTH``).
    - Adjacent states are connected by dotted diagonal segments from the right end of
      the left state to the left end of the right state.
    - Segment length automatically shrinks with additional states so that gaps remain
      between neighbors.
    - X-axis ticks are centered on each state and labeled using ``labels``.
    """
    if len(energies) == 0:
        raise ValueError("`energies` must contain at least one value.")
    if len(energies) != len(labels):
        raise ValueError("`energies` and `labels` must have the same length.")

    n = len(energies)
    energies = [float(e) for e in energies]

    # -----------------------------
    # Layout/style constants
    # -----------------------------
    AXIS_WIDTH = 3
    FONT_SIZE = 18
    AXIS_TITLE_SIZE = 20
    HLINE_WIDTH = 6  # Width of the horizontal state segments
    CONNECTOR_WIDTH = 2  # Width of the dotted connectors
    LINE_COLOR = "#1C1C1C"
    GRID_COLOR = "lightgrey"

    # -----------------------------
    # Geometry along the X axis (centers and segment lengths)
    # -----------------------------
    # Place segment centers at 0.5, 1.5, 2.5, ... (equally spaced)
    centers = [i + 0.5 for i in range(n)]

    # Shorten the segment as n grows (min 0.35, max 0.85)
    # Examples: n=5 -> 0.7, n=10 -> 0.5, n>=20 -> 0.35
    seg_width = min(0.85, max(0.35, 0.90 - 0.04 * n))
    half = seg_width / 2.0

    lefts = [c - half for c in centers]
    rights = [c + half for c in centers]

    # -----------------------------
    # Assemble the figure
    # -----------------------------
    fig = go.Figure()

    # Baseline (dotted line at the first energy level)
    if baseline:
        fig.add_trace(
            go.Scatter(
                x=[lefts[0], rights[-1]],
                y=[energies[0], energies[0]],
                mode="lines",
                line=dict(color=GRID_COLOR, dash="dot", width=2),
                hoverinfo="skip",
                showlegend=False,
            )
        )

    # Horizontal segments for each state
    for i, (e, lab) in enumerate(zip(energies, labels)):
        fig.add_trace(
            go.Scatter(
                x=[lefts[i], rights[i]],
                y=[e, e],
                mode="lines",
                line=dict(color=LINE_COLOR, width=HLINE_WIDTH),
                # Hover shows the state label and its exact energy.
                hovertemplate=f"{lab}: %{{y:.6f}}<extra></extra>",
                showlegend=False,
            )
        )

    # Dotted diagonals between adjacent states (right end -> left end)
    for i in range(n - 1):
        fig.add_trace(
            go.Scatter(
                x=[rights[i], lefts[i + 1]],
                y=[energies[i], energies[i + 1]],
                mode="lines",
                line=dict(color=LINE_COLOR, width=CONNECTOR_WIDTH, dash="dot"),
                hoverinfo="skip",
                showlegend=False,
            )
        )

    # -----------------------------
    # Axis ranges and styling
    # -----------------------------
    # Add a small margin beyond the first/last segments on X
    xpad = max(0.08, 0.15 * (1.0 - seg_width))
    x_min = lefts[0] - xpad
    x_max = rights[-1] + xpad

    # Add vertical padding above and below (more headroom at the top).
    y_min = min(energies)
    y_max = max(energies)
    span = max(1e-6, y_max - y_min)  # Avoid zero span even if all values match
    ypad_low = 0.10 * span
    ypad_high = 0.20 * span
    y_range = [y_min - ypad_low, y_max + ypad_high]

    xaxis_config = dict(
        range=[x_min, x_max],
        showline=True,
        linewidth=AXIS_WIDTH,
        linecolor=LINE_COLOR,
        mirror=True,
        ticks="inside",
        tickwidth=AXIS_WIDTH,
        tickcolor=LINE_COLOR,
        tickfont=dict(size=FONT_SIZE, color=LINE_COLOR),
        showgrid=showgrid,
        gridcolor=GRID_COLOR,
        gridwidth=0.5,
        zeroline=False,
        # One tick per state, centered on its segment, labeled by `labels`.
        tickmode="array",
        tickvals=centers,
        ticktext=list(labels),
        title=dict(text="", font=dict(size=AXIS_TITLE_SIZE, color=LINE_COLOR)),
    )

    yaxis_config = dict(
        range=y_range,
        showline=True,
        linewidth=AXIS_WIDTH,
        linecolor=LINE_COLOR,
        mirror=True,
        ticks="inside",
        tickwidth=AXIS_WIDTH,
        tickcolor=LINE_COLOR,
        tickfont=dict(size=FONT_SIZE, color=LINE_COLOR),
        showgrid=showgrid,
        gridcolor=GRID_COLOR,
        gridwidth=0.5,
        zeroline=False,
        title=dict(text=ylabel, font=dict(size=AXIS_TITLE_SIZE, color=LINE_COLOR)),
    )

    fig.update_layout(
        xaxis=xaxis_config,
        yaxis=yaxis_config,
        plot_bgcolor="white",
        paper_bgcolor="white",
        margin=dict(l=80, r=40, t=40, b=80),
    )

    return fig
|
|
1520
|
+
|
|
1521
|
+
|
|
1522
|
+
# =============================================================================
|
|
1523
|
+
# Coordinate conversion utilities
|
|
1524
|
+
# =============================================================================
|
|
1525
|
+
def convert_xyz_to_pdb(xyz_path: Path, ref_pdb_path: Path, out_pdb_path: Path) -> None:
    """Overlay coordinates from *xyz_path* onto the topology of *ref_pdb_path* and write to *out_pdb_path*.

    Notes:
        - *xyz_path* may contain one or many frames. For multi‑frame trajectories,
          a MODEL/ENDMDL block is appended for each subsequent frame in the output PDB.
        - On the first frame the output file is created/overwritten; subsequent frames are appended.
        - Validates that atom ordering (element symbols) matches between XYZ and PDB.
        - NOTE(review): `read`/`write` are presumably ASE's ``ase.io`` helpers —
          confirm against the file-level imports.

    Args:
        xyz_path: Path to an XYZ file (single or multi-frame).
        ref_pdb_path: Path to a reference PDB providing atom ordering/topology.
        out_pdb_path: Destination PDB file to write.

    Raises:
        ValueError: If no frames found in XYZ file or atom ordering mismatch.
    """
    ref_atoms = read(ref_pdb_path)  # Reference topology/ordering (single frame)
    traj = read(xyz_path, index=":", format="xyz")  # Load all frames from the XYZ
    if not traj:
        raise ValueError(f"No frames found in {xyz_path}.")

    ref_symbols = ref_atoms.get_chemical_symbols()

    for step, frame in enumerate(traj):
        xyz_symbols = frame.get_chemical_symbols()
        xyz_positions = frame.get_positions()

        if xyz_symbols != ref_symbols:
            # If atom counts match, the PDB likely has missing/wrong element columns.
            # Trust XYZ symbols (from pysisyphus which uses proper element detection)
            # and patch the reference atoms so PDB topology is preserved.
            # Once patched, ref_symbols is updated so later frames compare cleanly.
            if len(xyz_symbols) == len(ref_symbols):
                ref_atoms.set_chemical_symbols(xyz_symbols)
                ref_symbols = xyz_symbols
            else:
                raise ValueError(
                    "Atom count mismatch between XYZ and PDB; "
                    f"XYZ has {len(xyz_symbols)} atoms, PDB has {len(ref_symbols)}."
                )

        # Copy topology, then swap in the frame's coordinates.
        atoms = ref_atoms.copy()
        atoms.set_positions(xyz_positions)
        if step == 0:
            write(out_pdb_path, atoms)  # Create/overwrite on the first frame
        else:
            write(out_pdb_path, atoms, append=True)  # Append subsequent frames using MODEL/ENDMDL
|
|
1572
|
+
|
|
1573
|
+
|
|
1574
|
+
# =============================================================================
|
|
1575
|
+
# Global toggle for XYZ/TRJ → PDB conversion
|
|
1576
|
+
# =============================================================================
|
|
1577
|
+
# Module-level switch consulted by convert_xyz_like_outputs(); toggled via
# set_convert_file_enabled(). Conversions are on by default.
_CONVERT_FILES_ENABLED: bool = True
|
|
1578
|
+
|
|
1579
|
+
|
|
1580
|
+
def set_convert_file_enabled(enabled: bool) -> None:
    """Globally enable or disable XYZ/TRJ conversions to PDB outputs.

    Coerces *enabled* to ``bool`` and stores it in the module-level
    ``_CONVERT_FILES_ENABLED`` flag consulted by ``convert_xyz_like_outputs``.
    """
    global _CONVERT_FILES_ENABLED
    _CONVERT_FILES_ENABLED = bool(enabled)
|
|
1584
|
+
|
|
1585
|
+
|
|
1586
|
+
def is_convert_file_enabled() -> bool:
    """Return the current state of the global convert-files toggle."""
    return _CONVERT_FILES_ENABLED
|
|
1589
|
+
|
|
1590
|
+
|
|
1591
|
+
def convert_xyz_like_outputs(
    xyz_path: Path,
    ref_pdb_path: Optional[Path],
    out_pdb_path: Optional[Path] = None,
    *,
    context: str = "outputs",
    on_error: str = "raise",
) -> bool:
    """Convert *xyz_path* to a PDB using *ref_pdb_path* for topology.

    Honors the module-level ``_CONVERT_FILES_ENABLED`` toggle and skips
    silently when either path argument is missing.

    Returns:
        True on success; False when conversion is disabled, a path is
        missing, or (with ``on_error="warn"``) the conversion failed.

    Raises:
        click.ClickException: on failure when ``on_error`` is not ``"warn"``.
    """
    if not _CONVERT_FILES_ENABLED:
        return False
    if ref_pdb_path is None or out_pdb_path is None:
        return False
    try:
        convert_xyz_to_pdb(xyz_path, ref_pdb_path, out_pdb_path)
    except Exception as e:
        if on_error == "warn":
            click.echo(f"[convert] WARNING: Failed to convert {context}: {e}", err=True)
            return False
        raise click.ClickException(f"[convert] Failed to convert {context}: {e}") from e
    return True
|
|
1616
|
+
|
|
1617
|
+
|
|
1618
|
+
def pdb_keys_from_line(line: str) -> Tuple[Tuple, Tuple]:
    """Build lookup keys from one PDB ATOM/HETATM record.

    Returns:
        key_full: (chain, resseq, icode, resname, atomname, altloc)
        key_simple: (chain, resseq, icode, atomname)
    """
    # Fixed-column PDB fields (0-based slices of the 80-column record).
    atom_name = line[12:16].strip()
    alt_loc = line[16:17].strip()
    res_name = line[17:20].strip()
    chain_id = line[21:22].strip()
    insert_code = line[26:27].strip()
    try:
        res_seq = int(line[22:26].strip())
    except ValueError:
        res_seq = -10**9  # sentinel for a missing/garbled residue number
    key_full = (chain_id, res_seq, insert_code, res_name, atom_name, alt_loc)
    key_simple = (chain_id, res_seq, insert_code, atom_name)
    return key_full, key_simple
|
|
1638
|
+
|
|
1639
|
+
|
|
1640
|
+
def collect_ml_atom_keys(model_pdb: Path) -> Tuple[set, set]:
    """Gather ML-region atom keys from *model_pdb*.

    Returns:
        keys_full: set of (chain, resseq, icode, resname, atomname, altloc)
        keys_simple: set of (chain, resseq, icode, atomname)
    """
    full_keys: set = set()
    simple_keys: set = set()
    try:
        with model_pdb.open("r") as handle:
            for record in handle:
                if record.startswith(("ATOM", "HETATM")):
                    full_key, simple_key = pdb_keys_from_line(record)
                    full_keys.add(full_key)
                    simple_keys.add(simple_key)
    except Exception:
        # Best-effort: on any read/parse failure return (possibly empty) sets;
        # callers handle the empty case gracefully.
        pass
    return full_keys, simple_keys
|
|
1661
|
+
|
|
1662
|
+
|
|
1663
|
+
def format_pdb_with_bfactor(line: str, b: float) -> str:
    """Return *line* with the tempFactor field (columns 61-66) set to *b* (%6.2f)."""
    # Pad short records out to column 66 so the slice below is well-defined.
    if len(line) < 66:
        stripped = line.rstrip("\n")
        line = stripped + " " * max(0, 66 - len(stripped)) + "\n"
    # Occupancy (cols 55-60) is preserved; only the B-factor slot is rewritten.
    return line[:60] + f"{b:6.2f}" + line[66:]
|
|
1673
|
+
|
|
1674
|
+
|
|
1675
|
+
def annotate_pdb_bfactors_inplace(
    pdb_path: Path,
    model_pdb: Path,
    freeze_indices_0based: Sequence[int],
    beta_ml: float = 0.0,
    beta_frz: float = 20.0,
    beta_both: float = 0.0,
) -> None:
    """Overwrite B-factors in-place using 3-layer encoding (ML=0, MovableMM=10, FrozenMM=20).

    - ML-region atoms: beta_ml (default 0.00)
    - frozen atoms: beta_frz (default 20.00)
    - ML ∩ frozen: beta_both (default 0.00, ML takes precedence)
    - all remaining (movable MM) atoms: fixed 10.00

    Indexing for 'frozen' is 0-based and resets at each MODEL.
    Read or write failures are silently ignored (annotation is best-effort).
    """
    # Atom keys identifying the ML region, taken from the model PDB.
    ml_full, ml_simple = collect_ml_atom_keys(model_pdb)
    frozen_set = set(int(i) for i in (freeze_indices_0based or []))

    try:
        lines = pdb_path.read_text().splitlines(keepends=True)
    except Exception:
        return

    out_lines: List[str] = []
    atom_idx = 0  # resets per MODEL

    for line in lines:
        rec = line[:6]
        if rec.startswith("MODEL"):
            # reset atom counter for each model
            atom_idx = 0
            out_lines.append(line)
            continue
        if rec.startswith("ATOM ") or rec.startswith("HETATM"):
            # Match by full key first; fall back to the chain/res/atom key.
            kf, ks = pdb_keys_from_line(line)
            is_ml = (kf in ml_full) or (ks in ml_simple)
            is_frz = (atom_idx in frozen_set)
            if is_ml and is_frz:
                out_lines.append(format_pdb_with_bfactor(line, beta_both))
            elif is_ml:
                out_lines.append(format_pdb_with_bfactor(line, beta_ml))
            elif is_frz:
                out_lines.append(format_pdb_with_bfactor(line, beta_frz))
            else:
                # Movable MM layer: hard-coded middle value of the encoding.
                out_lines.append(format_pdb_with_bfactor(line, 10.0))
            atom_idx += 1
        else:
            out_lines.append(line)

    try:
        pdb_path.write_text("".join(out_lines))
    except Exception:
        # Silently ignore if we cannot write; conversion outputs are still present.
        pass
|
|
1730
|
+
|
|
1731
|
+
|
|
1732
|
+
def convert_and_annotate_xyz_to_pdb(
    src_xyz_or_trj: Path,
    ref_pdb: Path,
    dst_pdb: Path,
    model_pdb: Path,
    freeze_indices_0based: Sequence[int],
) -> None:
    """Convert an XYZ/TRJ file to PDB and annotate B-factors to highlight ML and frozen atoms.

    Uses annotate_pdb_bfactors_inplace() defaults, i.e. the 3-layer encoding:
    - ML-region atoms: 0.00
    - movable MM atoms: 10.00
    - frozen MM atoms: 20.00
    - ML ∩ frozen: 0.00 (ML takes precedence)

    Any failure is reported as a warning on stderr instead of raising, so a
    broken conversion never aborts the calling workflow.
    """
    try:
        convert_xyz_to_pdb(src_xyz_or_trj, ref_pdb, dst_pdb)
        annotate_pdb_bfactors_inplace(
            dst_pdb,
            model_pdb=model_pdb,
            freeze_indices_0based=freeze_indices_0based,
        )
    except Exception as exc:
        click.echo(
            f"[convert] WARNING: Failed to convert '{src_xyz_or_trj}' to PDB: {exc}",
            err=True,
        )
|
|
1758
|
+
|
|
1759
|
+
|
|
1760
|
+
# =============================================================================
|
|
1761
|
+
# Input preparation helpers
|
|
1762
|
+
# =============================================================================
|
|
1763
|
+
|
|
1764
|
+
|
|
1765
|
+
@dataclass
class PreparedInputStructure:
    """Pairs a user-supplied structure file with the path used for geometry loading.

    The two attributes are kept separate so callers may later retarget
    ``source_path`` (e.g. at a reference PDB for topology) while ``geom_path``
    keeps pointing at the coordinate source. Usable as a context manager;
    ``cleanup`` is a no-op retained for interface stability.
    """

    # Original user input; may be retargeted by apply_ref_pdb_override().
    source_path: Path
    # File actually read for coordinates.
    geom_path: Path

    def cleanup(self) -> None:
        """Do nothing: this wrapper creates no temporary files."""
        return None

    def __enter__(self) -> "PreparedInputStructure":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.cleanup()
|
|
1779
|
+
|
|
1780
|
+
|
|
1781
|
+
def prepare_input_structure(path: Path) -> PreparedInputStructure:
    """Return a lightweight wrapper for the provided structure path.

    Both ``source_path`` and ``geom_path`` initially point at *path*;
    ``apply_ref_pdb_override`` may later retarget ``source_path`` at a
    different topology file.
    """
    return PreparedInputStructure(source_path=path, geom_path=path)
|
|
1784
|
+
|
|
1785
|
+
|
|
1786
|
+
def _count_atoms_in_file(path: Path) -> int:
|
|
1787
|
+
"""Count atoms in a structure file (PDB or XYZ)."""
|
|
1788
|
+
suffix = path.suffix.lower()
|
|
1789
|
+
if suffix == ".pdb":
|
|
1790
|
+
count = 0
|
|
1791
|
+
with open(path, "r") as f:
|
|
1792
|
+
for line in f:
|
|
1793
|
+
if line.startswith(("ATOM ", "HETATM")):
|
|
1794
|
+
count += 1
|
|
1795
|
+
return count
|
|
1796
|
+
elif suffix == ".xyz":
|
|
1797
|
+
# XYZ format: first line is atom count
|
|
1798
|
+
with open(path, "r") as f:
|
|
1799
|
+
first_line = f.readline().strip()
|
|
1800
|
+
try:
|
|
1801
|
+
return int(first_line)
|
|
1802
|
+
except ValueError:
|
|
1803
|
+
return 0
|
|
1804
|
+
return 0
|
|
1805
|
+
|
|
1806
|
+
|
|
1807
|
+
def apply_ref_pdb_override(
    prepared_input: PreparedInputStructure,
    ref_pdb: Optional[Path],
) -> Optional[Path]:
    """Adopt a reference PDB topology while keeping XYZ coordinates for geometry.

    When --ref-pdb is given:
    - ``geom_path`` stays the original input (xyz) for high-precision coordinates
    - ``source_path`` is retargeted at the reference PDB for topology/residues

    Returns the resolved reference path, or ``None`` when no override was given.

    Raises:
        click.BadParameter: when the file is not a ``.pdb`` or its atom count
            differs from the geometry file's.
    """
    import click
    if ref_pdb is None:
        return None
    resolved_ref = Path(ref_pdb).resolve()
    if resolved_ref.suffix.lower() != ".pdb":
        raise click.BadParameter("--ref-pdb must be a .pdb file.")
    geom_count = _count_atoms_in_file(prepared_input.geom_path)
    ref_count = _count_atoms_in_file(resolved_ref)
    if geom_count != ref_count:
        raise click.BadParameter(
            f"Atom count mismatch: {prepared_input.geom_path.name} has {geom_count} atoms, "
            f"but --ref-pdb {resolved_ref.name} has {ref_count} atoms."
        )
    prepared_input.source_path = resolved_ref
    return resolved_ref
|
|
1832
|
+
|
|
1833
|
+
|
|
1834
|
+
def _round_charge_with_note(q: float, prefix: str = "") -> int:
|
|
1835
|
+
"""Round a float charge to the nearest integer, with a note if not exact."""
|
|
1836
|
+
if not math.isfinite(q):
|
|
1837
|
+
raise click.BadParameter(f"Computed total charge is non-finite: {q!r}")
|
|
1838
|
+
q_int = int(round(q))
|
|
1839
|
+
if abs(float(q) - q_int) > 1e-6:
|
|
1840
|
+
click.echo(
|
|
1841
|
+
f"{prefix} NOTE: total charge = {q:+g} → rounded to integer {q_int:+d}."
|
|
1842
|
+
)
|
|
1843
|
+
return q_int
|
|
1844
|
+
|
|
1845
|
+
|
|
1846
|
+
def _derive_charge_from_ligand_charge(
    pdb_path: Path,
    ligand_charge: Optional[str],
    *,
    prefix: str = "",
) -> Optional[int]:
    """Derive total system charge from a PDB file using ``--ligand-charge`` metadata.

    Parses *pdb_path* with BioPython, selects the ML region by B-factor layer
    when one is present, and sums protein/ligand/ion charges via the project's
    ``compute_charge_summary`` helper.

    Returns ``None`` when *ligand_charge* is ``None`` or derivation fails
    (any exception is reported as a note on stderr, never raised).
    """
    if ligand_charge is None:
        return None
    try:
        # Deferred imports: BioPython and the extract helpers are only needed
        # on this code path.
        from Bio import PDB as BioPDB
        from .extract import compute_charge_summary, log_charge_summary

        parser = BioPDB.PDBParser(QUIET=True)
        complex_struct = parser.get_structure("complex", str(pdb_path))

        # Use only ML-region residues (B-factor ≈ 0) when layered PDB is available.
        # A residue is included if ANY of its atoms has B-factor < 1.0 (ML layer).
        ml_residue_ids = set()
        all_residue_ids = set()
        for res in complex_struct.get_residues():
            fid = res.get_full_id()
            all_residue_ids.add(fid)
            for atom in res.get_atoms():
                if atom.get_bfactor() < 1.0:
                    ml_residue_ids.add(fid)
                    break
        # Fall back to all residues if no B-factor layering is present
        # (i.e. every residue has B=0 means unlayered PDB).
        selected_ids = ml_residue_ids if ml_residue_ids != all_residue_ids else all_residue_ids
        # NOTE(review): the empty set argument presumably denotes "no excluded
        # residues" — confirm against compute_charge_summary's signature.
        summary = compute_charge_summary(
            complex_struct, selected_ids, set(), ligand_charge
        )
        log_charge_summary(prefix, summary)
        q_total = float(summary.get("total_charge", 0.0))
        click.echo(
            f"{prefix} Charge summary (--ligand-charge):"
        )
        click.echo(
            f" Protein: {summary.get('protein_charge', 0.0):+g}, "
            f"Ligand: {summary.get('ligand_total_charge', 0.0):+g}, "
            f"Ions: {summary.get('ion_total_charge', 0.0):+g}, "
            f"Total: {q_total:+g}"
        )
        return _round_charge_with_note(q_total, prefix)
    except Exception as e:
        click.echo(
            f"{prefix} NOTE: failed to derive charge from --ligand-charge: {e}",
            err=True,
        )
        return None
|
|
1900
|
+
|
|
1901
|
+
|
|
1902
|
+
def resolve_charge_spin_or_raise(
    prepared: PreparedInputStructure,
    charge: Optional[int],
    spin: Optional[int],
    *,
    spin_default: int = 1,
    charge_default: Optional[int] = None,
    ligand_charge: Optional[str] = None,
    prefix: str = "",
) -> Tuple[int, int]:
    """Determine the (charge, spin) pair to use for a calculation.

    Resolution order for the charge: an explicit ``-q/--charge`` value
    wins, then a value derived from ``--ligand-charge``, and finally
    ``charge_default``. Spin falls back to ``spin_default`` when absent.

    Raises
    ------
    click.ClickException
        If no charge source is available.
    """
    resolved_charge = charge
    if resolved_charge is None and ligand_charge is not None:
        # Try to derive the total charge from the --ligand-charge spec.
        resolved_charge = _derive_charge_from_ligand_charge(
            prepared.source_path, ligand_charge, prefix=prefix,
        )
    if resolved_charge is None:
        if charge_default is None:
            raise click.ClickException(
                "Total charge is unresolved. Provide -q/--charge or --ligand-charge."
            )
        resolved_charge = charge_default
    resolved_spin = spin_default if spin is None else spin
    return int(resolved_charge), int(resolved_spin)
|
|
1931
|
+
|
|
1932
|
+
|
|
1933
|
+
# -----------------------------------------------
|
|
1934
|
+
# B-factor based 3-layer ML/MM system utilities
|
|
1935
|
+
# -----------------------------------------------
|
|
1936
|
+
|
|
1937
|
+
def read_bfactors_from_pdb(pdb_path: Path) -> List[float]:
    """
    Extract the B-factor (temperature factor) column from a PDB file.

    Only ATOM and HETATM records contribute; values are returned in atom
    order (0-indexed). A missing or unparsable field yields 0.0.
    """
    values: List[float] = []
    with open(pdb_path, "r") as handle:
        for record in handle:
            if not record.startswith(("ATOM ", "HETATM")):
                continue
            # Temperature factor occupies columns 61-66 (1-based) => [60:66].
            field = record[60:66].strip()
            try:
                values.append(float(field))
            except (ValueError, IndexError):
                values.append(0.0)
    return values
|
|
1955
|
+
|
|
1956
|
+
|
|
1957
|
+
def parse_layer_indices_from_bfactors(
    bfactors: List[float],
    tolerance: float = 1.0,
) -> Dict[str, List[int]]:
    """
    Classify atoms into ML/MM layers from encoded B-factor values.

    Encoding (each matched within ``tolerance``):
    0.0 -> ML atoms, 10.0 -> movable MM atoms, 20.0 -> frozen MM atoms.
    Atoms matching none of the layer values are reported as unassigned.

    Parameters
    ----------
    bfactors : List[float]
        B-factor values for each atom (0-indexed).
    tolerance : float
        Tolerance for B-factor matching (default: 1.0).

    Returns
    -------
    Dict[str, List[int]]
        Keys: "ml_indices", "hess_mm_indices" (compatibility key, empty
        whenever the hess-MM value equals the movable-MM value),
        "movable_mm_indices", "frozen_indices", "unassigned_indices".
    """
    from .defaults import BFACTOR_ML, BFACTOR_HESS_MM, BFACTOR_MOVABLE_MM, BFACTOR_FROZEN

    buckets: Dict[str, List[int]] = {
        "ml_indices": [],
        "hess_mm_indices": [],
        "movable_mm_indices": [],
        "frozen_indices": [],
        "unassigned_indices": [],
    }
    # hess-MM is only a distinct layer when its B-factor differs from movable MM.
    hess_is_distinct = BFACTOR_HESS_MM != BFACTOR_MOVABLE_MM

    for idx, value in enumerate(bfactors):
        if abs(value - BFACTOR_ML) <= tolerance:
            buckets["ml_indices"].append(idx)
        elif abs(value - BFACTOR_FROZEN) <= tolerance:
            buckets["frozen_indices"].append(idx)
        elif abs(value - BFACTOR_MOVABLE_MM) <= tolerance:
            buckets["movable_mm_indices"].append(idx)
        elif hess_is_distinct and abs(value - BFACTOR_HESS_MM) <= tolerance:
            buckets["hess_mm_indices"].append(idx)
        else:
            buckets["unassigned_indices"].append(idx)

    return buckets
|
|
2016
|
+
|
|
2017
|
+
|
|
2018
|
+
def has_valid_layer_bfactors(bfactors: List[float], tolerance: float = 1.0) -> bool:
    """
    Return True when the B-factors look like a valid 3-layer encoding.

    Requires at least one atom near the ML value (~0) and that at least
    80% of all atoms fall near one of the recognized layer values
    (~0/10/20 within ``tolerance``).
    """
    from .defaults import BFACTOR_ML, BFACTOR_HESS_MM, BFACTOR_MOVABLE_MM, BFACTOR_FROZEN

    layer_values = {BFACTOR_ML, BFACTOR_MOVABLE_MM, BFACTOR_FROZEN, BFACTOR_HESS_MM}
    found_ml = False
    n_valid = 0

    for value in bfactors:
        if any(abs(value - ref) <= tolerance for ref in layer_values):
            n_valid += 1
            if abs(value - BFACTOR_ML) <= tolerance:
                found_ml = True

    # Valid iff (1) an ML atom exists and (2) >= 80% of atoms match a layer.
    return found_ml and (n_valid / max(len(bfactors), 1) >= 0.8)
|
|
2043
|
+
|
|
2044
|
+
|
|
2045
|
+
def parse_indices_string(indices_str: str, one_based: bool = True) -> List[int]:
    """
    Parse a comma- or space-separated index string into sorted 0-based ints.

    Ranges such as "1-5" are inclusive. Inputs are interpreted as 1-based
    unless ``one_based`` is False. Duplicates are removed. Raises
    click.BadParameter for malformed tokens, invalid ranges, or
    non-positive indices.
    """
    import click
    if indices_str is None:
        return []
    normalized = str(indices_str).replace(" ", ",")
    tokens = [piece.strip() for piece in normalized.split(",") if piece.strip()]
    collected: List[int] = []
    for tok in tokens:
        # Range token, e.g. "3-7" (a leading "-" means a negative number,
        # not a range; malformed ranges fall through to single-int parsing).
        if "-" in tok and not tok.startswith("-"):
            lo_hi = tok.split("-")
            if len(lo_hi) == 2 and lo_hi[0] and lo_hi[1]:
                try:
                    lo = int(lo_hi[0])
                    hi = int(lo_hi[1])
                except ValueError as exc:
                    raise click.BadParameter(f"Invalid range token in --model-indices: '{tok}'") from exc
                if one_based:
                    lo -= 1
                    hi -= 1
                if lo < 0 or hi < 0 or lo > hi:
                    raise click.BadParameter(f"Invalid range in --model-indices: '{tok}'")
                collected.extend(range(lo, hi + 1))
                continue
        try:
            number = int(tok)
        except ValueError as exc:
            raise click.BadParameter(f"Invalid index in --model-indices: '{tok}'") from exc
        if one_based:
            number -= 1
        if number < 0:
            raise click.BadParameter(f"--model-indices expects positive indices; got {number + (1 if one_based else 0)}")
        collected.append(number)
    return sorted(set(collected))
|
|
2082
|
+
|
|
2083
|
+
|
|
2084
|
+
def write_model_pdb_from_indices(
    input_pdb_path: Path,
    output_pdb_path: Path,
    indices: Sequence[int],
) -> None:
    """
    Extract the atoms at the given 0-based indices into a new model PDB.

    Only the selected ATOM/HETATM records are copied (all other records
    are dropped); a blank element column (77-78) is filled in from the
    atom/residue name via ``guess_element``. The output is terminated
    with an END record. Raises ValueError for an empty input/selection
    and click.BadParameter for out-of-range indices.
    """
    import click
    if not indices:
        raise ValueError("No indices provided to build model PDB.")
    total = _count_atoms_in_file(input_pdb_path)
    if total <= 0:
        raise ValueError(f"No atoms found in input PDB: {input_pdb_path}")
    for i in indices:
        if not (0 <= i < total):
            raise click.BadParameter(
                f"model index out of range: {i} (valid: 0 <= idx < {total})"
            )

    wanted = set(int(i) for i in indices)
    selected: List[str] = []
    counter = 0
    with open(input_pdb_path, "r") as src:
        for line in src:
            if not line.startswith(("ATOM ", "HETATM")):
                continue
            if counter in wanted:
                stripped = line.rstrip("\n")
                element = stripped[76:78].strip() if len(stripped) >= 78 else ""
                if element:
                    selected.append(line)
                else:
                    # Element column missing: infer it and pad out to column 78.
                    name = stripped[12:16].strip()
                    residue = stripped[17:20].strip()
                    inferred = guess_element(name, residue, stripped.startswith("HETATM"))
                    if inferred:
                        selected.append(stripped.ljust(76) + f"{inferred:>2}" + "\n")
                    else:
                        selected.append(line)
            counter += 1
    if not selected:
        raise ValueError("Model PDB would be empty; check indices and input PDB.")
    if not selected[-1].endswith("\n"):
        selected[-1] = selected[-1] + "\n"
    selected.append("END\n")
    with open(output_pdb_path, "w") as dst:
        dst.writelines(selected)
|
|
2134
|
+
|
|
2135
|
+
|
|
2136
|
+
def build_model_pdb_from_indices(
|
|
2137
|
+
input_pdb_path: Path,
|
|
2138
|
+
out_dir: Path,
|
|
2139
|
+
indices: Sequence[int],
|
|
2140
|
+
*,
|
|
2141
|
+
label: str = "model_from_indices",
|
|
2142
|
+
) -> Path:
|
|
2143
|
+
"""
|
|
2144
|
+
Create a temporary model PDB under out_dir using explicit indices.
|
|
2145
|
+
"""
|
|
2146
|
+
out_dir.mkdir(parents=True, exist_ok=True)
|
|
2147
|
+
with tempfile.NamedTemporaryFile(
|
|
2148
|
+
mode="w",
|
|
2149
|
+
suffix=".pdb",
|
|
2150
|
+
prefix=f"{label}_",
|
|
2151
|
+
dir=out_dir,
|
|
2152
|
+
delete=False,
|
|
2153
|
+
) as tmp:
|
|
2154
|
+
tmp_path = Path(tmp.name)
|
|
2155
|
+
write_model_pdb_from_indices(input_pdb_path, tmp_path, indices)
|
|
2156
|
+
return tmp_path
|
|
2157
|
+
|
|
2158
|
+
|
|
2159
|
+
def build_model_pdb_from_bfactors(
    input_pdb_path: Path,
    out_dir: Path,
    *,
    tolerance: Optional[float] = None,  # fixed: default None requires Optional
    label: str = "model_from_bfactor",
) -> Tuple[Path, Dict[str, List[int]]]:
    """
    Create a model PDB using ML indices derived from B-factors.

    Parameters
    ----------
    input_pdb_path : Path
        Layered PDB whose B-factors encode the 3-layer assignment.
    out_dir : Path
        Directory for the generated model PDB (created if missing).
    tolerance : Optional[float]
        B-factor matching tolerance; falls back to BFACTOR_TOLERANCE
        from ``.defaults`` when None.
    label : str
        Stem of the output file name (``{label}.pdb``).

    Returns
    -------
    Tuple[Path, Dict[str, List[int]]]
        (model_pdb_path, layer_info) where layer_info is the dict from
        :func:`parse_layer_indices_from_bfactors`.

    Raises
    ------
    ValueError
        If the input has no atoms, the B-factors do not form a valid
        3-layer encoding, or no ML atoms are found.
    """
    from .defaults import BFACTOR_TOLERANCE
    tol = BFACTOR_TOLERANCE if tolerance is None else float(tolerance)
    bfactors = read_bfactors_from_pdb(input_pdb_path)
    if not bfactors:
        raise ValueError(f"No ATOM/HETATM records found in {input_pdb_path}.")
    if not has_valid_layer_bfactors(bfactors, tolerance=tol):
        raise ValueError(
            "Invalid or missing layer B-factors (expected ~0/10/20). "
            "Provide --no-detect-layer with --model-pdb/--model-indices."
        )
    layer_info = parse_layer_indices_from_bfactors(bfactors, tolerance=tol)
    ml_indices = layer_info.get("ml_indices") or []
    if not ml_indices:
        raise ValueError("No ML atoms detected from B-factors (value ~0).")
    out_dir.mkdir(parents=True, exist_ok=True)
    tmp_path = out_dir / f"{label}.pdb"
    write_model_pdb_from_indices(input_pdb_path, tmp_path, ml_indices)
    return tmp_path, layer_info
|
|
2189
|
+
|
|
2190
|
+
|
|
2191
|
+
def write_layer_bfactors_to_pdb(
    input_pdb_path: Path,
    output_pdb_path: Path,
    ml_indices: List[int],
    hess_mm_indices: Optional[List[int]] = None,
    movable_mm_indices: Optional[List[int]] = None,
    frozen_indices: Optional[List[int]] = None,
) -> None:
    """
    Copy a PDB file, rewriting B-factors to encode the 3-layer assignment.

    Each ATOM/HETATM record gets a B-factor (columns 61-66, %6.2f) based
    on which index set its 0-based position belongs to: ML, Hessian MM,
    movable MM, or frozen, with that precedence when sets overlap. Atoms
    in none of the sets default to the movable-MM value.

    Parameters
    ----------
    input_pdb_path : Path
        Source PDB file to read atom records from.
    output_pdb_path : Path
        Output PDB file path.
    ml_indices : List[int]
        0-based indices of ML region atoms.
    hess_mm_indices : Optional[List[int]]
        0-based indices of MM atoms with Hessian.
    movable_mm_indices : Optional[List[int]]
        0-based indices of movable MM atoms without Hessian.
    frozen_indices : Optional[List[int]]
        0-based indices of frozen atoms.

    Notes
    -----
    Multi-MODEL files (e.g. trajectories) are supported: the atom counter
    restarts at every MODEL record.
    """
    from .defaults import BFACTOR_ML, BFACTOR_HESS_MM, BFACTOR_MOVABLE_MM, BFACTOR_FROZEN

    # Priority-ordered (index set, B-factor) pairs for layer lookup.
    layer_table = (
        (set(ml_indices or []), BFACTOR_ML),
        (set(hess_mm_indices or []), BFACTOR_HESS_MM),
        (set(movable_mm_indices or []), BFACTOR_MOVABLE_MM),
        (set(frozen_indices or []), BFACTOR_FROZEN),
    )

    output_lines: List[str] = []
    counter = 0

    with open(input_pdb_path, "r") as src:
        for line in src:
            if line[:6].startswith("MODEL"):
                # New frame in a trajectory: restart the atom counter.
                counter = 0
                output_lines.append(line)
                continue
            if not line.startswith(("ATOM ", "HETATM")):
                output_lines.append(line)
                continue

            # Unassigned atoms are treated as movable MM (layer 3).
            bfac = BFACTOR_MOVABLE_MM
            for members, value in layer_table:
                if counter in members:
                    bfac = value
                    break

            # Overwrite columns 61-66 (1-based) with %6.2f; pad short lines.
            if len(line) >= 66:
                edited = line[:60] + f"{bfac:6.2f}" + line[66:]
            else:
                padded = line.rstrip("\n").ljust(66)
                edited = padded[:60] + f"{bfac:6.2f}" + "\n"
            output_lines.append(edited)
            counter += 1

    with open(output_pdb_path, "w") as dst:
        dst.writelines(output_lines)
|
|
2277
|
+
|
|
2278
|
+
|
|
2279
|
+
def update_pdb_bfactors_from_layers(
    pdb_path: Path,
    ml_indices: List[int],
    hess_mm_indices: Optional[List[int]] = None,
    movable_mm_indices: Optional[List[int]] = None,
    frozen_indices: Optional[List[int]] = None,
) -> None:
    """
    Rewrite the B-factors of ``pdb_path`` in place from layer assignments.

    Writes the layered copy to a temporary file first, then moves it over
    the original; the temporary file is removed if the write fails.
    """
    import tempfile
    import shutil

    scratch = tempfile.NamedTemporaryFile(mode="w", suffix=".pdb", delete=False)
    scratch.close()
    scratch_path = Path(scratch.name)

    try:
        write_layer_bfactors_to_pdb(
            pdb_path,
            scratch_path,
            ml_indices,
            hess_mm_indices,
            movable_mm_indices,
            frozen_indices,
        )
        shutil.move(str(scratch_path), str(pdb_path))
    finally:
        # After a successful move the scratch file is gone; only clean up
        # the leftover on failure.
        if scratch_path.exists():
            scratch_path.unlink()
|