MultiOptPy 1.20.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- multioptpy/Calculator/__init__.py +0 -0
- multioptpy/Calculator/ase_calculation_tools.py +424 -0
- multioptpy/Calculator/ase_tools/__init__.py +0 -0
- multioptpy/Calculator/ase_tools/fairchem.py +28 -0
- multioptpy/Calculator/ase_tools/gamess.py +19 -0
- multioptpy/Calculator/ase_tools/gaussian.py +165 -0
- multioptpy/Calculator/ase_tools/mace.py +28 -0
- multioptpy/Calculator/ase_tools/mopac.py +19 -0
- multioptpy/Calculator/ase_tools/nwchem.py +31 -0
- multioptpy/Calculator/ase_tools/orca.py +22 -0
- multioptpy/Calculator/ase_tools/pygfn0.py +37 -0
- multioptpy/Calculator/dxtb_calculation_tools.py +344 -0
- multioptpy/Calculator/emt_calculation_tools.py +458 -0
- multioptpy/Calculator/gpaw_calculation_tools.py +183 -0
- multioptpy/Calculator/lj_calculation_tools.py +314 -0
- multioptpy/Calculator/psi4_calculation_tools.py +334 -0
- multioptpy/Calculator/pwscf_calculation_tools.py +189 -0
- multioptpy/Calculator/pyscf_calculation_tools.py +327 -0
- multioptpy/Calculator/sqm1_calculation_tools.py +611 -0
- multioptpy/Calculator/sqm2_calculation_tools.py +376 -0
- multioptpy/Calculator/tblite_calculation_tools.py +352 -0
- multioptpy/Calculator/tersoff_calculation_tools.py +818 -0
- multioptpy/Constraint/__init__.py +0 -0
- multioptpy/Constraint/constraint_condition.py +834 -0
- multioptpy/Coordinate/__init__.py +0 -0
- multioptpy/Coordinate/polar_coordinate.py +199 -0
- multioptpy/Coordinate/redundant_coordinate.py +638 -0
- multioptpy/IRC/__init__.py +0 -0
- multioptpy/IRC/converge_criteria.py +28 -0
- multioptpy/IRC/dvv.py +544 -0
- multioptpy/IRC/euler.py +439 -0
- multioptpy/IRC/hpc.py +564 -0
- multioptpy/IRC/lqa.py +540 -0
- multioptpy/IRC/modekill.py +662 -0
- multioptpy/IRC/rk4.py +579 -0
- multioptpy/Interpolation/__init__.py +0 -0
- multioptpy/Interpolation/adaptive_interpolation.py +283 -0
- multioptpy/Interpolation/binomial_interpolation.py +179 -0
- multioptpy/Interpolation/geodesic_interpolation.py +785 -0
- multioptpy/Interpolation/interpolation.py +156 -0
- multioptpy/Interpolation/linear_interpolation.py +473 -0
- multioptpy/Interpolation/savitzky_golay_interpolation.py +252 -0
- multioptpy/Interpolation/spline_interpolation.py +353 -0
- multioptpy/MD/__init__.py +0 -0
- multioptpy/MD/thermostat.py +185 -0
- multioptpy/MEP/__init__.py +0 -0
- multioptpy/MEP/pathopt_bneb_force.py +443 -0
- multioptpy/MEP/pathopt_dmf_force.py +448 -0
- multioptpy/MEP/pathopt_dneb_force.py +130 -0
- multioptpy/MEP/pathopt_ewbneb_force.py +207 -0
- multioptpy/MEP/pathopt_gpneb_force.py +512 -0
- multioptpy/MEP/pathopt_lup_force.py +113 -0
- multioptpy/MEP/pathopt_neb_force.py +225 -0
- multioptpy/MEP/pathopt_nesb_force.py +205 -0
- multioptpy/MEP/pathopt_om_force.py +153 -0
- multioptpy/MEP/pathopt_qsm_force.py +174 -0
- multioptpy/MEP/pathopt_qsmv2_force.py +304 -0
- multioptpy/ModelFunction/__init__.py +7 -0
- multioptpy/ModelFunction/avoiding_model_function.py +29 -0
- multioptpy/ModelFunction/binary_image_ts_search_model_function.py +47 -0
- multioptpy/ModelFunction/conical_model_function.py +26 -0
- multioptpy/ModelFunction/opt_meci.py +50 -0
- multioptpy/ModelFunction/opt_mesx.py +47 -0
- multioptpy/ModelFunction/opt_mesx_2.py +49 -0
- multioptpy/ModelFunction/seam_model_function.py +27 -0
- multioptpy/ModelHessian/__init__.py +0 -0
- multioptpy/ModelHessian/approx_hessian.py +147 -0
- multioptpy/ModelHessian/calc_params.py +227 -0
- multioptpy/ModelHessian/fischer.py +236 -0
- multioptpy/ModelHessian/fischerd3.py +360 -0
- multioptpy/ModelHessian/fischerd4.py +398 -0
- multioptpy/ModelHessian/gfn0xtb.py +633 -0
- multioptpy/ModelHessian/gfnff.py +709 -0
- multioptpy/ModelHessian/lindh.py +165 -0
- multioptpy/ModelHessian/lindh2007d2.py +707 -0
- multioptpy/ModelHessian/lindh2007d3.py +822 -0
- multioptpy/ModelHessian/lindh2007d4.py +1030 -0
- multioptpy/ModelHessian/morse.py +106 -0
- multioptpy/ModelHessian/schlegel.py +144 -0
- multioptpy/ModelHessian/schlegeld3.py +322 -0
- multioptpy/ModelHessian/schlegeld4.py +559 -0
- multioptpy/ModelHessian/shortrange.py +346 -0
- multioptpy/ModelHessian/swartd2.py +496 -0
- multioptpy/ModelHessian/swartd3.py +706 -0
- multioptpy/ModelHessian/swartd4.py +918 -0
- multioptpy/ModelHessian/tshess.py +40 -0
- multioptpy/Optimizer/QHAdam.py +61 -0
- multioptpy/Optimizer/__init__.py +0 -0
- multioptpy/Optimizer/abc_fire.py +83 -0
- multioptpy/Optimizer/adabelief.py +58 -0
- multioptpy/Optimizer/adabound.py +68 -0
- multioptpy/Optimizer/adadelta.py +65 -0
- multioptpy/Optimizer/adaderivative.py +56 -0
- multioptpy/Optimizer/adadiff.py +68 -0
- multioptpy/Optimizer/adafactor.py +70 -0
- multioptpy/Optimizer/adam.py +65 -0
- multioptpy/Optimizer/adamax.py +62 -0
- multioptpy/Optimizer/adamod.py +83 -0
- multioptpy/Optimizer/adamw.py +65 -0
- multioptpy/Optimizer/adiis.py +523 -0
- multioptpy/Optimizer/afire_neb.py +282 -0
- multioptpy/Optimizer/block_hessian_update.py +709 -0
- multioptpy/Optimizer/c2diis.py +491 -0
- multioptpy/Optimizer/component_wise_scaling.py +405 -0
- multioptpy/Optimizer/conjugate_gradient.py +82 -0
- multioptpy/Optimizer/conjugate_gradient_neb.py +345 -0
- multioptpy/Optimizer/coordinate_locking.py +405 -0
- multioptpy/Optimizer/dic_rsirfo.py +1015 -0
- multioptpy/Optimizer/ediis.py +417 -0
- multioptpy/Optimizer/eve.py +76 -0
- multioptpy/Optimizer/fastadabelief.py +61 -0
- multioptpy/Optimizer/fire.py +77 -0
- multioptpy/Optimizer/fire2.py +249 -0
- multioptpy/Optimizer/fire_neb.py +92 -0
- multioptpy/Optimizer/gan_step.py +486 -0
- multioptpy/Optimizer/gdiis.py +609 -0
- multioptpy/Optimizer/gediis.py +203 -0
- multioptpy/Optimizer/geodesic_step.py +433 -0
- multioptpy/Optimizer/gpmin.py +633 -0
- multioptpy/Optimizer/gpr_step.py +364 -0
- multioptpy/Optimizer/gradientdescent.py +78 -0
- multioptpy/Optimizer/gradientdescent_neb.py +52 -0
- multioptpy/Optimizer/hessian_update.py +433 -0
- multioptpy/Optimizer/hybrid_rfo.py +998 -0
- multioptpy/Optimizer/kdiis.py +625 -0
- multioptpy/Optimizer/lars.py +21 -0
- multioptpy/Optimizer/lbfgs.py +253 -0
- multioptpy/Optimizer/lbfgs_neb.py +355 -0
- multioptpy/Optimizer/linesearch.py +236 -0
- multioptpy/Optimizer/lookahead.py +40 -0
- multioptpy/Optimizer/nadam.py +64 -0
- multioptpy/Optimizer/newton.py +200 -0
- multioptpy/Optimizer/prodigy.py +70 -0
- multioptpy/Optimizer/purtubation.py +16 -0
- multioptpy/Optimizer/quickmin_neb.py +245 -0
- multioptpy/Optimizer/radam.py +75 -0
- multioptpy/Optimizer/rfo_neb.py +302 -0
- multioptpy/Optimizer/ric_rfo.py +842 -0
- multioptpy/Optimizer/rl_step.py +627 -0
- multioptpy/Optimizer/rmspropgrave.py +65 -0
- multioptpy/Optimizer/rsirfo.py +1647 -0
- multioptpy/Optimizer/rsprfo.py +1056 -0
- multioptpy/Optimizer/sadam.py +60 -0
- multioptpy/Optimizer/samsgrad.py +63 -0
- multioptpy/Optimizer/tr_lbfgs.py +678 -0
- multioptpy/Optimizer/trim.py +273 -0
- multioptpy/Optimizer/trust_radius.py +207 -0
- multioptpy/Optimizer/trust_radius_neb.py +121 -0
- multioptpy/Optimizer/yogi.py +60 -0
- multioptpy/OtherMethod/__init__.py +0 -0
- multioptpy/OtherMethod/addf.py +1150 -0
- multioptpy/OtherMethod/dimer.py +895 -0
- multioptpy/OtherMethod/elastic_image_pair.py +629 -0
- multioptpy/OtherMethod/modelfunction.py +456 -0
- multioptpy/OtherMethod/newton_traj.py +454 -0
- multioptpy/OtherMethod/twopshs.py +1095 -0
- multioptpy/PESAnalyzer/__init__.py +0 -0
- multioptpy/PESAnalyzer/calc_irc_curvature.py +125 -0
- multioptpy/PESAnalyzer/cmds_analysis.py +152 -0
- multioptpy/PESAnalyzer/koopman_analysis.py +268 -0
- multioptpy/PESAnalyzer/pca_analysis.py +314 -0
- multioptpy/Parameters/__init__.py +0 -0
- multioptpy/Parameters/atomic_mass.py +20 -0
- multioptpy/Parameters/atomic_number.py +22 -0
- multioptpy/Parameters/covalent_radii.py +44 -0
- multioptpy/Parameters/d2.py +61 -0
- multioptpy/Parameters/d3.py +63 -0
- multioptpy/Parameters/d4.py +103 -0
- multioptpy/Parameters/dreiding.py +34 -0
- multioptpy/Parameters/gfn0xtb_param.py +137 -0
- multioptpy/Parameters/gfnff_param.py +315 -0
- multioptpy/Parameters/gnb.py +104 -0
- multioptpy/Parameters/parameter.py +22 -0
- multioptpy/Parameters/uff.py +72 -0
- multioptpy/Parameters/unit_values.py +20 -0
- multioptpy/Potential/AFIR_potential.py +55 -0
- multioptpy/Potential/LJ_repulsive_potential.py +345 -0
- multioptpy/Potential/__init__.py +0 -0
- multioptpy/Potential/anharmonic_keep_potential.py +28 -0
- multioptpy/Potential/asym_elllipsoidal_potential.py +718 -0
- multioptpy/Potential/electrostatic_potential.py +69 -0
- multioptpy/Potential/flux_potential.py +30 -0
- multioptpy/Potential/gaussian_potential.py +101 -0
- multioptpy/Potential/idpp.py +516 -0
- multioptpy/Potential/keep_angle_potential.py +146 -0
- multioptpy/Potential/keep_dihedral_angle_potential.py +105 -0
- multioptpy/Potential/keep_outofplain_angle_potential.py +70 -0
- multioptpy/Potential/keep_potential.py +99 -0
- multioptpy/Potential/mechano_force_potential.py +74 -0
- multioptpy/Potential/nanoreactor_potential.py +52 -0
- multioptpy/Potential/potential.py +896 -0
- multioptpy/Potential/spacer_model_potential.py +221 -0
- multioptpy/Potential/switching_potential.py +258 -0
- multioptpy/Potential/universal_potential.py +34 -0
- multioptpy/Potential/value_range_potential.py +36 -0
- multioptpy/Potential/void_point_potential.py +25 -0
- multioptpy/SQM/__init__.py +0 -0
- multioptpy/SQM/sqm1/__init__.py +0 -0
- multioptpy/SQM/sqm1/sqm1_core.py +1792 -0
- multioptpy/SQM/sqm2/__init__.py +0 -0
- multioptpy/SQM/sqm2/calc_tools.py +95 -0
- multioptpy/SQM/sqm2/sqm2_basis.py +850 -0
- multioptpy/SQM/sqm2/sqm2_bond.py +119 -0
- multioptpy/SQM/sqm2/sqm2_core.py +303 -0
- multioptpy/SQM/sqm2/sqm2_data.py +1229 -0
- multioptpy/SQM/sqm2/sqm2_disp.py +65 -0
- multioptpy/SQM/sqm2/sqm2_eeq.py +243 -0
- multioptpy/SQM/sqm2/sqm2_overlapint.py +704 -0
- multioptpy/SQM/sqm2/sqm2_qm.py +578 -0
- multioptpy/SQM/sqm2/sqm2_rep.py +66 -0
- multioptpy/SQM/sqm2/sqm2_srb.py +70 -0
- multioptpy/Thermo/__init__.py +0 -0
- multioptpy/Thermo/normal_mode_analyzer.py +865 -0
- multioptpy/Utils/__init__.py +0 -0
- multioptpy/Utils/bond_connectivity.py +264 -0
- multioptpy/Utils/calc_tools.py +884 -0
- multioptpy/Utils/oniom.py +96 -0
- multioptpy/Utils/pbc.py +48 -0
- multioptpy/Utils/riemann_curvature.py +208 -0
- multioptpy/Utils/symmetry_analyzer.py +482 -0
- multioptpy/Visualization/__init__.py +0 -0
- multioptpy/Visualization/visualization.py +156 -0
- multioptpy/WFAnalyzer/MO_analysis.py +104 -0
- multioptpy/WFAnalyzer/__init__.py +0 -0
- multioptpy/Wrapper/__init__.py +0 -0
- multioptpy/Wrapper/autots.py +1239 -0
- multioptpy/Wrapper/ieip_wrapper.py +93 -0
- multioptpy/Wrapper/md_wrapper.py +92 -0
- multioptpy/Wrapper/neb_wrapper.py +94 -0
- multioptpy/Wrapper/optimize_wrapper.py +76 -0
- multioptpy/__init__.py +5 -0
- multioptpy/entrypoints.py +916 -0
- multioptpy/fileio.py +660 -0
- multioptpy/ieip.py +340 -0
- multioptpy/interface.py +1086 -0
- multioptpy/irc.py +529 -0
- multioptpy/moleculardynamics.py +432 -0
- multioptpy/neb.py +1267 -0
- multioptpy/optimization.py +1553 -0
- multioptpy/optimizer.py +709 -0
- multioptpy-1.20.2.dist-info/METADATA +438 -0
- multioptpy-1.20.2.dist-info/RECORD +246 -0
- multioptpy-1.20.2.dist-info/WHEEL +5 -0
- multioptpy-1.20.2.dist-info/entry_points.txt +9 -0
- multioptpy-1.20.2.dist-info/licenses/LICENSE +674 -0
- multioptpy-1.20.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,486 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
import torch.optim as optim
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
from collections import deque
|
|
7
|
+
import random
|
|
8
|
+
import os
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Generator(nn.Module):
    """
    Generator network: generates improved step scaling factors from
    the current molecular structure, gradient, and original step.

    Each input row is one per-coordinate feature vector; the output is a
    single scale factor per row, squashed to (-1, 1) by the final Tanh.

    Parameters
    ----------
    input_dim : int
        Number of features per input row (default 3: coordinate,
        gradient, and original-step components).
    hidden_dims : sequence of int, optional
        Hidden-layer widths. Defaults to (64, 128, 64).
    """

    def __init__(self, input_dim=3, hidden_dims=None):
        super(Generator, self).__init__()

        # Fix: avoid a mutable default argument; fall back to the
        # standard widths when the caller passes nothing.
        hidden_dims = [64, 128, 64] if hidden_dims is None else list(hidden_dims)

        layers = []

        # Input layer
        layers.append(nn.Linear(input_dim, hidden_dims[0]))
        layers.append(nn.LeakyReLU(0.2))

        # Hidden layers; BatchNorm only on inner transitions so the last
        # hidden layer feeds the output head unnormalized.
        for i in range(len(hidden_dims) - 1):
            layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1]))
            layers.append(nn.LeakyReLU(0.2))
            if i < len(hidden_dims) - 2:
                layers.append(nn.BatchNorm1d(hidden_dims[i + 1]))

        # Output layer
        layers.append(nn.Linear(hidden_dims[-1], 1))
        layers.append(nn.Tanh())  # Limit step scale factor to -1 to 1 range

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the (-1, 1) scale factor for each row of *x*."""
        return self.model(x)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class Discriminator(nn.Module):
    """
    Discriminator network: determines whether a molecular structure and
    step vector pair represents a "good" optimization step.

    Each input row is a per-coordinate feature vector plus the observed
    energy change; the output is a probability in (0, 1).

    Parameters
    ----------
    input_dim : int
        Number of features per input row (default 4: coordinate,
        gradient, step, and energy change).
    hidden_dims : sequence of int, optional
        Hidden-layer widths. Defaults to (64, 32).
    """

    def __init__(self, input_dim=4, hidden_dims=None):
        super(Discriminator, self).__init__()

        # Fix: avoid a mutable default argument; fall back to the
        # standard widths when the caller passes nothing.
        hidden_dims = [64, 32] if hidden_dims is None else list(hidden_dims)

        layers = []

        # Input layer
        layers.append(nn.Linear(input_dim, hidden_dims[0]))
        layers.append(nn.LeakyReLU(0.2))

        # Hidden layers
        for i in range(len(hidden_dims) - 1):
            layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1]))
            layers.append(nn.LeakyReLU(0.2))

        # Output layer
        layers.append(nn.Linear(hidden_dims[-1], 1))
        layers.append(nn.Sigmoid())  # Output as probability between 0 and 1

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the "good step" probability for each row of *x*."""
        return self.model(x)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class ReplayBuffer:
    """Fixed-capacity experience store used to feed GAN training batches."""

    def __init__(self, max_size=1000):
        # A bounded deque silently evicts the oldest entry once full.
        self.buffer = deque(maxlen=max_size)

    def add(self, experience):
        """Append one experience tuple (oldest entry is dropped if full)."""
        self.buffer.append(experience)

    def sample(self, batch_size):
        """Draw up to *batch_size* experiences uniformly at random."""
        return random.sample(self.buffer, min(batch_size, len(self.buffer)))

    def __len__(self):
        return len(self.buffer)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class GANStep:
    """
    GAN-Step optimizer:
    Uses a generative adversarial network to modify optimization steps.

    A Generator proposes per-coordinate scale factors for a candidate
    optimization step; a Discriminator is trained to separate steps that
    lowered the energy ("good") from steps that raised it ("bad").
    History and replay buffers accumulate across calls to ``run``.
    """

    def __init__(self):
        # Basic parameters
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE(review): feature_dim=3 implies each geometry/gradient/step row
        # contributes a single scalar (i.e. a flattened (3N, 1) layout) —
        # confirm against callers.
        self.feature_dim = 3  # Feature dimensions for each coordinate, gradient, and original step
        self.batch_size = 32
        self.training_steps = 5  # Number of GAN training steps per iteration
        self.dtype = torch.float32  # Explicitly set data type to float32

        # GAN settings
        self.gen_hidden_dims = [64, 128, 64]
        self.dis_hidden_dims = [64, 32]
        self.gen_learning_rate = 0.0002
        self.dis_learning_rate = 0.0001
        self.beta1 = 0.5  # Beta1 parameter for Adam optimizer

        # Step modification parameters
        self.min_scale = 0.2  # Minimum scaling coefficient
        self.max_scale = 3.0  # Maximum scaling coefficient
        self.step_clip = 0.5  # Maximum step size (norm of the mixed step)
        self.mix_ratio = 0.7  # Mixture ratio with original step

        # Training history and buffers
        self.min_samples_for_training = 10  # Minimum samples required for GAN training
        self.good_buffer = ReplayBuffer(1000)  # Steps that decreased energy
        self.bad_buffer = ReplayBuffer(1000)  # Steps that increased energy

        # Learning curves and state tracking
        self.gen_losses = []
        self.dis_losses = []
        self.geom_history = []
        self.energy_history = []
        self.gradient_history = []
        self.step_history = []
        self.iter = 0

        # Model initialization
        self._init_models()

    def _init_models(self):
        """Initialize generator and discriminator networks plus their
        Adam optimizers and the shared BCE loss."""
        self.generator = Generator(
            input_dim=self.feature_dim,
            hidden_dims=self.gen_hidden_dims
        ).to(self.device)

        self.discriminator = Discriminator(
            input_dim=self.feature_dim + 1,  # Additional feature for energy change
            hidden_dims=self.dis_hidden_dims
        ).to(self.device)

        # Configure optimizers
        self.gen_optimizer = optim.Adam(
            self.generator.parameters(),
            lr=self.gen_learning_rate,
            betas=(self.beta1, 0.999)
        )

        self.dis_optimizer = optim.Adam(
            self.discriminator.parameters(),
            lr=self.dis_learning_rate,
            betas=(self.beta1, 0.999)
        )

        # Loss function
        self.criterion = nn.BCELoss()

        # Ensure models use consistent data type
        self.generator.to(dtype=self.dtype)
        self.discriminator.to(dtype=self.dtype)

    def _update_history(self, geometry, energy, gradient, step=None):
        """Append the current geometry/energy/gradient to the running
        history; record *step* if given, otherwise derive it from the
        previous geometry."""
        self.geom_history.append(geometry.copy())
        self.energy_history.append(energy)
        self.gradient_history.append(gradient.copy())

        if step is not None:
            self.step_history.append(step.copy())
        elif len(self.geom_history) > 1:
            # Calculate step from previous geometry
            calc_step = geometry - self.geom_history[-2]
            self.step_history.append(calc_step)

    def _update_replay_buffers(self):
        """Classify the most recent step by its energy change and add one
        per-coordinate experience tuple to the good or bad buffer."""
        if len(self.energy_history) < 2 or len(self.step_history) < 1:
            return

        # Calculate energy change for the latest step
        energy_change = self.energy_history[-1] - self.energy_history[-2]
        prev_geom = self.geom_history[-2]
        prev_grad = self.gradient_history[-2]
        step = self.step_history[-1]

        # Create features for each coordinate point
        n_coords = prev_geom.shape[0]
        for i in range(n_coords):
            features = np.hstack([
                prev_geom[i],
                prev_grad[i],
                step[i]
            ]).astype(np.float32)  # Explicitly use float32

            # Add energy change as a feature
            features_with_energy = np.append(features, energy_change).astype(np.float32)  # Explicitly use float32

            # Add to appropriate buffer based on energy change
            experience = (features, features_with_energy, energy_change)
            if energy_change <= 0:  # Energy decreased = good step
                self.good_buffer.add(experience)
            else:  # Energy increased = bad step
                self.bad_buffer.add(experience)

    def _train_gan(self):
        """Train the GAN model.

        Runs ``self.training_steps`` alternating discriminator/generator
        updates from the replay buffers. Returns False when the good
        buffer has too few samples, True otherwise.
        """
        # Skip training if insufficient samples
        if len(self.good_buffer) < self.min_samples_for_training:
            return False

        # NOTE(review): self.generator is left in eval() mode by
        # _generate_improved_step and is never switched back to train(),
        # so BatchNorm uses running stats here on later iterations —
        # confirm intended.

        # Sample from buffer
        for _ in range(self.training_steps):
            # Mini-batch sampling
            real_batch_size = min(self.batch_size // 2, len(self.good_buffer))
            fake_batch_size = min(self.batch_size // 2, len(self.bad_buffer))

            if real_batch_size == 0 or fake_batch_size == 0:
                continue

            # Good step samples (real/positive)
            good_samples = self.good_buffer.sample(real_batch_size)
            good_features = torch.tensor([s[0] for s in good_samples],
                                         device=self.device,
                                         dtype=self.dtype)  # Explicitly set dtype
            good_features_with_energy = torch.tensor([s[1] for s in good_samples],
                                                     device=self.device,
                                                     dtype=self.dtype)  # Explicitly set dtype

            # Bad step samples (fake/negative)
            bad_samples = self.bad_buffer.sample(fake_batch_size)
            bad_features = torch.tensor([s[0] for s in bad_samples],
                                        device=self.device,
                                        dtype=self.dtype)  # Explicitly set dtype
            bad_features_with_energy = torch.tensor([s[1] for s in bad_samples],
                                                    device=self.device,
                                                    dtype=self.dtype)  # Explicitly set dtype

            # Create labels
            real_labels = torch.ones(real_batch_size, 1, device=self.device, dtype=self.dtype)
            fake_labels = torch.zeros(fake_batch_size, 1, device=self.device, dtype=self.dtype)

            # --------------------
            # Train Discriminator
            # --------------------
            self.dis_optimizer.zero_grad()

            # Classify good samples
            good_outputs = self.discriminator(good_features_with_energy)
            d_loss_real = self.criterion(good_outputs, real_labels)

            # Classify bad samples
            bad_outputs = self.discriminator(bad_features_with_energy)
            d_loss_fake = self.criterion(bad_outputs, fake_labels)

            # Classify generator-modified steps
            gen_scale = self.generator(bad_features)
            gen_step_features = bad_features.clone()

            # Apply scaling (maintain direction, adjust scale only)
            for i in range(gen_step_features.shape[0]):
                # Get original step vector (last dimension)
                orig_step = gen_step_features[i, -1].item()

                # Calculate scale factor (convert from -1 to 1 range to appropriate scale)
                scale = ((gen_scale[i].item() + 1) / 2) * (self.max_scale - self.min_scale) + self.min_scale

                # Apply scale to compute modified step
                gen_step_features[i, -1] = orig_step * scale

            # Add energy change (zero placeholder: outcome of the modified
            # step is unknown at this point)
            gen_features_with_energy = torch.cat([
                gen_step_features,
                torch.zeros(gen_step_features.shape[0], 1, device=self.device, dtype=self.dtype)  # Explicit dtype
            ], dim=1)

            g_outputs = self.discriminator(gen_features_with_energy)
            d_loss_gen = self.criterion(g_outputs, fake_labels)

            d_loss = d_loss_real + d_loss_fake + d_loss_gen
            d_loss.backward()
            self.dis_optimizer.step()

            # --------------------
            # Train Generator
            # --------------------
            self.gen_optimizer.zero_grad()

            # Generate modified steps
            gen_scale = self.generator(bad_features)
            gen_step_features = bad_features.clone()

            # Apply scaling
            # NOTE(review): gen_scale[i].item() detaches the scale from the
            # autograd graph, so g_loss.backward() below produces no
            # gradients for the generator parameters and gen_optimizer.step()
            # is effectively a no-op — confirm whether this is intended.
            for i in range(gen_step_features.shape[0]):
                orig_step = gen_step_features[i, -1].item()
                scale = ((gen_scale[i].item() + 1) / 2) * (self.max_scale - self.min_scale) + self.min_scale
                gen_step_features[i, -1] = orig_step * scale

            # Add energy change
            gen_features_with_energy = torch.cat([
                gen_step_features,
                torch.zeros(gen_step_features.shape[0], 1, device=self.device, dtype=self.dtype)  # Explicit dtype
            ], dim=1)

            # Get discriminator evaluation
            g_outputs = self.discriminator(gen_features_with_energy)

            # Train generator to produce "good" steps
            g_loss = self.criterion(g_outputs, real_labels)
            g_loss.backward()
            self.gen_optimizer.step()

            # Record losses
            self.gen_losses.append(g_loss.item())
            self.dis_losses.append(d_loss.item())

        if len(self.gen_losses) > 0:
            print(f"GAN training - Gen loss: {np.mean(self.gen_losses[-self.training_steps:]):.4f}, "
                  f"Dis loss: {np.mean(self.dis_losses[-self.training_steps:]):.4f}")

        return True

    def _generate_improved_step(self, geometry, gradient, original_step):
        """Generate an improved step using the GAN.

        Scales each coordinate of *original_step* by a generator-predicted
        factor, mixes the result with the original step, and clips the
        final norm to ``self.step_clip``.
        """
        # Don't modify step if original norm is near zero
        orig_norm = np.linalg.norm(original_step)
        if orig_norm < 1e-10:
            return original_step

        # Create features
        n_coords = geometry.shape[0]
        features = []

        for i in range(n_coords):
            feat = np.hstack([
                geometry[i],
                gradient[i],
                original_step[i]
            ]).astype(np.float32)  # Explicitly use float32
            features.append(feat)

        # Prepare batch for model evaluation
        features_tensor = torch.tensor(features, device=self.device, dtype=self.dtype)  # Explicitly set dtype

        # Generate scaling factors with the generator
        self.generator.eval()
        with torch.no_grad():
            scale_factors = self.generator(features_tensor)

        # Convert from -1 to 1 range to actual scale range
        scale_factors = ((scale_factors + 1) / 2) * (self.max_scale - self.min_scale) + self.min_scale
        scale_factors = scale_factors.cpu().numpy()

        # Generate modified step
        gan_step = original_step.copy()

        for i in range(n_coords):
            # Apply scale to each coordinate point
            gan_step[i] = original_step[i] * scale_factors[i, 0]

        # Mix steps
        mixed_step = self.mix_ratio * gan_step + (1 - self.mix_ratio) * original_step

        # Limit step size
        step_norm = np.linalg.norm(mixed_step)
        if step_norm > self.step_clip:
            mixed_step = mixed_step * (self.step_clip / step_norm)

        # Output norms before and after modification
        print(f"Step norm - Original: {orig_norm:.6f}, GAN: {np.linalg.norm(gan_step):.6f}, "
              f"Mixed: {np.linalg.norm(mixed_step):.6f}")

        return mixed_step

    def run(self, geom_num_list, energy, gradient, original_move_vector):
        """
        Run GAN-Step optimization step.

        Falls back to *original_move_vector* for the first three
        iterations, when training data or training itself fails, or when
        the generated step contains NaN/Inf.

        Parameters:
        -----------
        geom_num_list : numpy.ndarray
            Current molecular geometry
        energy : float
            Current energy value
        gradient : numpy.ndarray
            Current gradient
        original_move_vector : numpy.ndarray
            Original optimization step

        Returns:
        --------
        numpy.ndarray
            Modified optimization step
        """
        print("GAN-Step method")

        # Update history
        self._update_history(geom_num_list, energy, gradient)

        # Update replay buffers
        self._update_replay_buffers()

        # Use original step for initial iterations
        if self.iter < 3:
            print(f"Building history (step {self.iter+1}), using original step")
            self.iter += 1
            return original_move_vector

        # Train GAN model
        if len(self.good_buffer) >= self.min_samples_for_training:
            try:
                gan_trained = self._train_gan()
                if not gan_trained:
                    print("Failed to train GAN, using original step")
                    self.iter += 1
                    return original_move_vector
            except RuntimeError as e:
                print(f"Error during GAN training: {str(e)}")
                print("Using original step due to training error")
                self.iter += 1
                return original_move_vector
        else:
            print(f"Not enough good samples for GAN training ({len(self.good_buffer)}/{self.min_samples_for_training})")
            self.iter += 1
            return original_move_vector

        # Modify step using GAN
        try:
            gan_step = self._generate_improved_step(geom_num_list, gradient, original_move_vector)
        except Exception as e:
            print(f"Error generating improved step: {str(e)}")
            print("Using original step due to generation error")
            self.iter += 1
            return original_move_vector

        # Check for numerical issues
        if np.any(np.isnan(gan_step)) or np.any(np.isinf(gan_step)):
            print("Warning: Numerical issues in GAN step, using original step")
            self.iter += 1
            return original_move_vector

        self.iter += 1
        return gan_step

    def save_model(self, path='gan_model'):
        """Save GAN model state (networks, optimizers, losses, iteration
        counter) to ``<path>/gan_step_model.pt``."""
        if not os.path.exists(path):
            os.makedirs(path)

        torch.save({
            'generator': self.generator.state_dict(),
            'discriminator': self.discriminator.state_dict(),
            'gen_optimizer': self.gen_optimizer.state_dict(),
            'dis_optimizer': self.dis_optimizer.state_dict(),
            'gen_losses': self.gen_losses,
            'dis_losses': self.dis_losses,
            'iter': self.iter
        }, os.path.join(path, 'gan_step_model.pt'))

        print(f"Model saved to {os.path.join(path, 'gan_step_model.pt')}")

    def load_model(self, path='gan_model/gan_step_model.pt'):
        """Load GAN model state from *path*.

        Returns True on success, False when the file is missing or
        loading fails. Replay buffers and geometry history are NOT
        restored — only networks, optimizers, losses, and the counter.
        """
        if not os.path.exists(path):
            print(f"No model file found at {path}")
            return False

        try:
            checkpoint = torch.load(path)
            self.generator.load_state_dict(checkpoint['generator'])
            self.discriminator.load_state_dict(checkpoint['discriminator'])
            self.gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
            self.dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
            self.gen_losses = checkpoint['gen_losses']
            self.dis_losses = checkpoint['dis_losses']
            self.iter = checkpoint['iter']

            print(f"Model loaded from {path}")
            return True
        except Exception as e:
            print(f"Failed to load model: {str(e)}")
            return False
|