MultiOptPy 1.20.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (246)
  1. multioptpy/Calculator/__init__.py +0 -0
  2. multioptpy/Calculator/ase_calculation_tools.py +424 -0
  3. multioptpy/Calculator/ase_tools/__init__.py +0 -0
  4. multioptpy/Calculator/ase_tools/fairchem.py +28 -0
  5. multioptpy/Calculator/ase_tools/gamess.py +19 -0
  6. multioptpy/Calculator/ase_tools/gaussian.py +165 -0
  7. multioptpy/Calculator/ase_tools/mace.py +28 -0
  8. multioptpy/Calculator/ase_tools/mopac.py +19 -0
  9. multioptpy/Calculator/ase_tools/nwchem.py +31 -0
  10. multioptpy/Calculator/ase_tools/orca.py +22 -0
  11. multioptpy/Calculator/ase_tools/pygfn0.py +37 -0
  12. multioptpy/Calculator/dxtb_calculation_tools.py +344 -0
  13. multioptpy/Calculator/emt_calculation_tools.py +458 -0
  14. multioptpy/Calculator/gpaw_calculation_tools.py +183 -0
  15. multioptpy/Calculator/lj_calculation_tools.py +314 -0
  16. multioptpy/Calculator/psi4_calculation_tools.py +334 -0
  17. multioptpy/Calculator/pwscf_calculation_tools.py +189 -0
  18. multioptpy/Calculator/pyscf_calculation_tools.py +327 -0
  19. multioptpy/Calculator/sqm1_calculation_tools.py +611 -0
  20. multioptpy/Calculator/sqm2_calculation_tools.py +376 -0
  21. multioptpy/Calculator/tblite_calculation_tools.py +352 -0
  22. multioptpy/Calculator/tersoff_calculation_tools.py +818 -0
  23. multioptpy/Constraint/__init__.py +0 -0
  24. multioptpy/Constraint/constraint_condition.py +834 -0
  25. multioptpy/Coordinate/__init__.py +0 -0
  26. multioptpy/Coordinate/polar_coordinate.py +199 -0
  27. multioptpy/Coordinate/redundant_coordinate.py +638 -0
  28. multioptpy/IRC/__init__.py +0 -0
  29. multioptpy/IRC/converge_criteria.py +28 -0
  30. multioptpy/IRC/dvv.py +544 -0
  31. multioptpy/IRC/euler.py +439 -0
  32. multioptpy/IRC/hpc.py +564 -0
  33. multioptpy/IRC/lqa.py +540 -0
  34. multioptpy/IRC/modekill.py +662 -0
  35. multioptpy/IRC/rk4.py +579 -0
  36. multioptpy/Interpolation/__init__.py +0 -0
  37. multioptpy/Interpolation/adaptive_interpolation.py +283 -0
  38. multioptpy/Interpolation/binomial_interpolation.py +179 -0
  39. multioptpy/Interpolation/geodesic_interpolation.py +785 -0
  40. multioptpy/Interpolation/interpolation.py +156 -0
  41. multioptpy/Interpolation/linear_interpolation.py +473 -0
  42. multioptpy/Interpolation/savitzky_golay_interpolation.py +252 -0
  43. multioptpy/Interpolation/spline_interpolation.py +353 -0
  44. multioptpy/MD/__init__.py +0 -0
  45. multioptpy/MD/thermostat.py +185 -0
  46. multioptpy/MEP/__init__.py +0 -0
  47. multioptpy/MEP/pathopt_bneb_force.py +443 -0
  48. multioptpy/MEP/pathopt_dmf_force.py +448 -0
  49. multioptpy/MEP/pathopt_dneb_force.py +130 -0
  50. multioptpy/MEP/pathopt_ewbneb_force.py +207 -0
  51. multioptpy/MEP/pathopt_gpneb_force.py +512 -0
  52. multioptpy/MEP/pathopt_lup_force.py +113 -0
  53. multioptpy/MEP/pathopt_neb_force.py +225 -0
  54. multioptpy/MEP/pathopt_nesb_force.py +205 -0
  55. multioptpy/MEP/pathopt_om_force.py +153 -0
  56. multioptpy/MEP/pathopt_qsm_force.py +174 -0
  57. multioptpy/MEP/pathopt_qsmv2_force.py +304 -0
  58. multioptpy/ModelFunction/__init__.py +7 -0
  59. multioptpy/ModelFunction/avoiding_model_function.py +29 -0
  60. multioptpy/ModelFunction/binary_image_ts_search_model_function.py +47 -0
  61. multioptpy/ModelFunction/conical_model_function.py +26 -0
  62. multioptpy/ModelFunction/opt_meci.py +50 -0
  63. multioptpy/ModelFunction/opt_mesx.py +47 -0
  64. multioptpy/ModelFunction/opt_mesx_2.py +49 -0
  65. multioptpy/ModelFunction/seam_model_function.py +27 -0
  66. multioptpy/ModelHessian/__init__.py +0 -0
  67. multioptpy/ModelHessian/approx_hessian.py +147 -0
  68. multioptpy/ModelHessian/calc_params.py +227 -0
  69. multioptpy/ModelHessian/fischer.py +236 -0
  70. multioptpy/ModelHessian/fischerd3.py +360 -0
  71. multioptpy/ModelHessian/fischerd4.py +398 -0
  72. multioptpy/ModelHessian/gfn0xtb.py +633 -0
  73. multioptpy/ModelHessian/gfnff.py +709 -0
  74. multioptpy/ModelHessian/lindh.py +165 -0
  75. multioptpy/ModelHessian/lindh2007d2.py +707 -0
  76. multioptpy/ModelHessian/lindh2007d3.py +822 -0
  77. multioptpy/ModelHessian/lindh2007d4.py +1030 -0
  78. multioptpy/ModelHessian/morse.py +106 -0
  79. multioptpy/ModelHessian/schlegel.py +144 -0
  80. multioptpy/ModelHessian/schlegeld3.py +322 -0
  81. multioptpy/ModelHessian/schlegeld4.py +559 -0
  82. multioptpy/ModelHessian/shortrange.py +346 -0
  83. multioptpy/ModelHessian/swartd2.py +496 -0
  84. multioptpy/ModelHessian/swartd3.py +706 -0
  85. multioptpy/ModelHessian/swartd4.py +918 -0
  86. multioptpy/ModelHessian/tshess.py +40 -0
  87. multioptpy/Optimizer/QHAdam.py +61 -0
  88. multioptpy/Optimizer/__init__.py +0 -0
  89. multioptpy/Optimizer/abc_fire.py +83 -0
  90. multioptpy/Optimizer/adabelief.py +58 -0
  91. multioptpy/Optimizer/adabound.py +68 -0
  92. multioptpy/Optimizer/adadelta.py +65 -0
  93. multioptpy/Optimizer/adaderivative.py +56 -0
  94. multioptpy/Optimizer/adadiff.py +68 -0
  95. multioptpy/Optimizer/adafactor.py +70 -0
  96. multioptpy/Optimizer/adam.py +65 -0
  97. multioptpy/Optimizer/adamax.py +62 -0
  98. multioptpy/Optimizer/adamod.py +83 -0
  99. multioptpy/Optimizer/adamw.py +65 -0
  100. multioptpy/Optimizer/adiis.py +523 -0
  101. multioptpy/Optimizer/afire_neb.py +282 -0
  102. multioptpy/Optimizer/block_hessian_update.py +709 -0
  103. multioptpy/Optimizer/c2diis.py +491 -0
  104. multioptpy/Optimizer/component_wise_scaling.py +405 -0
  105. multioptpy/Optimizer/conjugate_gradient.py +82 -0
  106. multioptpy/Optimizer/conjugate_gradient_neb.py +345 -0
  107. multioptpy/Optimizer/coordinate_locking.py +405 -0
  108. multioptpy/Optimizer/dic_rsirfo.py +1015 -0
  109. multioptpy/Optimizer/ediis.py +417 -0
  110. multioptpy/Optimizer/eve.py +76 -0
  111. multioptpy/Optimizer/fastadabelief.py +61 -0
  112. multioptpy/Optimizer/fire.py +77 -0
  113. multioptpy/Optimizer/fire2.py +249 -0
  114. multioptpy/Optimizer/fire_neb.py +92 -0
  115. multioptpy/Optimizer/gan_step.py +486 -0
  116. multioptpy/Optimizer/gdiis.py +609 -0
  117. multioptpy/Optimizer/gediis.py +203 -0
  118. multioptpy/Optimizer/geodesic_step.py +433 -0
  119. multioptpy/Optimizer/gpmin.py +633 -0
  120. multioptpy/Optimizer/gpr_step.py +364 -0
  121. multioptpy/Optimizer/gradientdescent.py +78 -0
  122. multioptpy/Optimizer/gradientdescent_neb.py +52 -0
  123. multioptpy/Optimizer/hessian_update.py +433 -0
  124. multioptpy/Optimizer/hybrid_rfo.py +998 -0
  125. multioptpy/Optimizer/kdiis.py +625 -0
  126. multioptpy/Optimizer/lars.py +21 -0
  127. multioptpy/Optimizer/lbfgs.py +253 -0
  128. multioptpy/Optimizer/lbfgs_neb.py +355 -0
  129. multioptpy/Optimizer/linesearch.py +236 -0
  130. multioptpy/Optimizer/lookahead.py +40 -0
  131. multioptpy/Optimizer/nadam.py +64 -0
  132. multioptpy/Optimizer/newton.py +200 -0
  133. multioptpy/Optimizer/prodigy.py +70 -0
  134. multioptpy/Optimizer/purtubation.py +16 -0
  135. multioptpy/Optimizer/quickmin_neb.py +245 -0
  136. multioptpy/Optimizer/radam.py +75 -0
  137. multioptpy/Optimizer/rfo_neb.py +302 -0
  138. multioptpy/Optimizer/ric_rfo.py +842 -0
  139. multioptpy/Optimizer/rl_step.py +627 -0
  140. multioptpy/Optimizer/rmspropgrave.py +65 -0
  141. multioptpy/Optimizer/rsirfo.py +1647 -0
  142. multioptpy/Optimizer/rsprfo.py +1056 -0
  143. multioptpy/Optimizer/sadam.py +60 -0
  144. multioptpy/Optimizer/samsgrad.py +63 -0
  145. multioptpy/Optimizer/tr_lbfgs.py +678 -0
  146. multioptpy/Optimizer/trim.py +273 -0
  147. multioptpy/Optimizer/trust_radius.py +207 -0
  148. multioptpy/Optimizer/trust_radius_neb.py +121 -0
  149. multioptpy/Optimizer/yogi.py +60 -0
  150. multioptpy/OtherMethod/__init__.py +0 -0
  151. multioptpy/OtherMethod/addf.py +1150 -0
  152. multioptpy/OtherMethod/dimer.py +895 -0
  153. multioptpy/OtherMethod/elastic_image_pair.py +629 -0
  154. multioptpy/OtherMethod/modelfunction.py +456 -0
  155. multioptpy/OtherMethod/newton_traj.py +454 -0
  156. multioptpy/OtherMethod/twopshs.py +1095 -0
  157. multioptpy/PESAnalyzer/__init__.py +0 -0
  158. multioptpy/PESAnalyzer/calc_irc_curvature.py +125 -0
  159. multioptpy/PESAnalyzer/cmds_analysis.py +152 -0
  160. multioptpy/PESAnalyzer/koopman_analysis.py +268 -0
  161. multioptpy/PESAnalyzer/pca_analysis.py +314 -0
  162. multioptpy/Parameters/__init__.py +0 -0
  163. multioptpy/Parameters/atomic_mass.py +20 -0
  164. multioptpy/Parameters/atomic_number.py +22 -0
  165. multioptpy/Parameters/covalent_radii.py +44 -0
  166. multioptpy/Parameters/d2.py +61 -0
  167. multioptpy/Parameters/d3.py +63 -0
  168. multioptpy/Parameters/d4.py +103 -0
  169. multioptpy/Parameters/dreiding.py +34 -0
  170. multioptpy/Parameters/gfn0xtb_param.py +137 -0
  171. multioptpy/Parameters/gfnff_param.py +315 -0
  172. multioptpy/Parameters/gnb.py +104 -0
  173. multioptpy/Parameters/parameter.py +22 -0
  174. multioptpy/Parameters/uff.py +72 -0
  175. multioptpy/Parameters/unit_values.py +20 -0
  176. multioptpy/Potential/AFIR_potential.py +55 -0
  177. multioptpy/Potential/LJ_repulsive_potential.py +345 -0
  178. multioptpy/Potential/__init__.py +0 -0
  179. multioptpy/Potential/anharmonic_keep_potential.py +28 -0
  180. multioptpy/Potential/asym_elllipsoidal_potential.py +718 -0
  181. multioptpy/Potential/electrostatic_potential.py +69 -0
  182. multioptpy/Potential/flux_potential.py +30 -0
  183. multioptpy/Potential/gaussian_potential.py +101 -0
  184. multioptpy/Potential/idpp.py +516 -0
  185. multioptpy/Potential/keep_angle_potential.py +146 -0
  186. multioptpy/Potential/keep_dihedral_angle_potential.py +105 -0
  187. multioptpy/Potential/keep_outofplain_angle_potential.py +70 -0
  188. multioptpy/Potential/keep_potential.py +99 -0
  189. multioptpy/Potential/mechano_force_potential.py +74 -0
  190. multioptpy/Potential/nanoreactor_potential.py +52 -0
  191. multioptpy/Potential/potential.py +896 -0
  192. multioptpy/Potential/spacer_model_potential.py +221 -0
  193. multioptpy/Potential/switching_potential.py +258 -0
  194. multioptpy/Potential/universal_potential.py +34 -0
  195. multioptpy/Potential/value_range_potential.py +36 -0
  196. multioptpy/Potential/void_point_potential.py +25 -0
  197. multioptpy/SQM/__init__.py +0 -0
  198. multioptpy/SQM/sqm1/__init__.py +0 -0
  199. multioptpy/SQM/sqm1/sqm1_core.py +1792 -0
  200. multioptpy/SQM/sqm2/__init__.py +0 -0
  201. multioptpy/SQM/sqm2/calc_tools.py +95 -0
  202. multioptpy/SQM/sqm2/sqm2_basis.py +850 -0
  203. multioptpy/SQM/sqm2/sqm2_bond.py +119 -0
  204. multioptpy/SQM/sqm2/sqm2_core.py +303 -0
  205. multioptpy/SQM/sqm2/sqm2_data.py +1229 -0
  206. multioptpy/SQM/sqm2/sqm2_disp.py +65 -0
  207. multioptpy/SQM/sqm2/sqm2_eeq.py +243 -0
  208. multioptpy/SQM/sqm2/sqm2_overlapint.py +704 -0
  209. multioptpy/SQM/sqm2/sqm2_qm.py +578 -0
  210. multioptpy/SQM/sqm2/sqm2_rep.py +66 -0
  211. multioptpy/SQM/sqm2/sqm2_srb.py +70 -0
  212. multioptpy/Thermo/__init__.py +0 -0
  213. multioptpy/Thermo/normal_mode_analyzer.py +865 -0
  214. multioptpy/Utils/__init__.py +0 -0
  215. multioptpy/Utils/bond_connectivity.py +264 -0
  216. multioptpy/Utils/calc_tools.py +884 -0
  217. multioptpy/Utils/oniom.py +96 -0
  218. multioptpy/Utils/pbc.py +48 -0
  219. multioptpy/Utils/riemann_curvature.py +208 -0
  220. multioptpy/Utils/symmetry_analyzer.py +482 -0
  221. multioptpy/Visualization/__init__.py +0 -0
  222. multioptpy/Visualization/visualization.py +156 -0
  223. multioptpy/WFAnalyzer/MO_analysis.py +104 -0
  224. multioptpy/WFAnalyzer/__init__.py +0 -0
  225. multioptpy/Wrapper/__init__.py +0 -0
  226. multioptpy/Wrapper/autots.py +1239 -0
  227. multioptpy/Wrapper/ieip_wrapper.py +93 -0
  228. multioptpy/Wrapper/md_wrapper.py +92 -0
  229. multioptpy/Wrapper/neb_wrapper.py +94 -0
  230. multioptpy/Wrapper/optimize_wrapper.py +76 -0
  231. multioptpy/__init__.py +5 -0
  232. multioptpy/entrypoints.py +916 -0
  233. multioptpy/fileio.py +660 -0
  234. multioptpy/ieip.py +340 -0
  235. multioptpy/interface.py +1086 -0
  236. multioptpy/irc.py +529 -0
  237. multioptpy/moleculardynamics.py +432 -0
  238. multioptpy/neb.py +1267 -0
  239. multioptpy/optimization.py +1553 -0
  240. multioptpy/optimizer.py +709 -0
  241. multioptpy-1.20.2.dist-info/METADATA +438 -0
  242. multioptpy-1.20.2.dist-info/RECORD +246 -0
  243. multioptpy-1.20.2.dist-info/WHEEL +5 -0
  244. multioptpy-1.20.2.dist-info/entry_points.txt +9 -0
  245. multioptpy-1.20.2.dist-info/licenses/LICENSE +674 -0
  246. multioptpy-1.20.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,486 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ import torch.nn.functional as F
6
+ from collections import deque
7
+ import random
8
+ import os
9
+
10
+
11
class Generator(nn.Module):
    """
    Generator network: Generates improved step scaling factors from
    the current molecular structure, gradient, and original step
    """

    def __init__(self, input_dim=3, hidden_dims=(64, 128, 64)):
        """
        Parameters
        ----------
        input_dim : int
            Number of input features per sample (coordinate, gradient,
            and original-step components concatenated by the caller).
        hidden_dims : sequence of int
            Sizes of the hidden layers.  A tuple default replaces the
            original mutable list default (shared-mutable-default pitfall);
            list arguments from existing callers still work.
        """
        super().__init__()

        hidden_dims = list(hidden_dims)

        # Input features: concatenation of coordinates, gradients, and
        # original step vectors
        layers = []

        # Input layer
        layers.append(nn.Linear(input_dim, hidden_dims[0]))
        layers.append(nn.LeakyReLU(0.2))

        # Hidden layers
        for i in range(len(hidden_dims) - 1):
            layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1]))
            layers.append(nn.LeakyReLU(0.2))
            # BatchNorm on every hidden transition except the last
            if i < len(hidden_dims) - 2:
                layers.append(nn.BatchNorm1d(hidden_dims[i + 1]))

        # Output layer
        layers.append(nn.Linear(hidden_dims[-1], 1))
        layers.append(nn.Tanh())  # Limit step scale factor to -1 to 1 range

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map a (batch, input_dim) tensor to a (batch, 1) scale in [-1, 1]."""
        return self.model(x)
41
+
42
+
43
class Discriminator(nn.Module):
    """
    Discriminator network: Determines whether a molecular structure and
    step vector pair represents a "good" optimization step
    """

    def __init__(self, input_dim=4, hidden_dims=(64, 32)):
        """
        Parameters
        ----------
        input_dim : int
            Number of input features per sample (coordinates, gradients,
            step vectors, and energy change concatenated by the caller).
        hidden_dims : sequence of int
            Sizes of the hidden layers.  A tuple default replaces the
            original mutable list default (shared-mutable-default pitfall);
            list arguments from existing callers still work.
        """
        super().__init__()

        hidden_dims = list(hidden_dims)

        # Input features: coordinates, gradients, step vectors, and
        # energy change
        layers = []

        # Input layer
        layers.append(nn.Linear(input_dim, hidden_dims[0]))
        layers.append(nn.LeakyReLU(0.2))

        # Hidden layers
        for i in range(len(hidden_dims) - 1):
            layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1]))
            layers.append(nn.LeakyReLU(0.2))

        # Output layer
        layers.append(nn.Linear(hidden_dims[-1], 1))
        layers.append(nn.Sigmoid())  # Output as probability between 0 and 1

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map a (batch, input_dim) tensor to a (batch, 1) probability."""
        return self.model(x)
71
+
72
+
73
class ReplayBuffer:
    """Experience replay buffer: Stores data for GAN training"""

    def __init__(self, max_size=1000):
        # A bounded deque silently evicts the oldest entry once full.
        self.buffer = deque(maxlen=max_size)

    def add(self, experience):
        """Append one experience tuple to the buffer."""
        self.buffer.append(experience)

    def sample(self, batch_size):
        """Random sampling from buffer"""
        # Never request more items than the buffer currently holds.
        available = len(self.buffer)
        if batch_size > available:
            batch_size = available
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)
89
+
90
+
91
class GANStep:
    """
    GAN-Step optimizer:
    Uses a generative adversarial network to modify optimization steps
    """

    def __init__(self):
        # --- basic configuration -------------------------------------
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.feature_dim = 3        # per-point features: coordinate, gradient, original step
        self.batch_size = 32
        self.training_steps = 5     # GAN training passes per optimizer iteration
        self.dtype = torch.float32  # keep all tensors in float32

        # --- network / optimizer hyper-parameters --------------------
        self.gen_hidden_dims = [64, 128, 64]
        self.dis_hidden_dims = [64, 32]
        self.gen_learning_rate = 0.0002
        self.dis_learning_rate = 0.0001
        self.beta1 = 0.5            # Adam beta1 parameter

        # --- step-modification parameters ----------------------------
        self.min_scale = 0.2        # smallest allowed scaling coefficient
        self.max_scale = 3.0        # largest allowed scaling coefficient
        self.step_clip = 0.5        # cap on the mixed step norm
        self.mix_ratio = 0.7        # weight of the GAN step in the final mix

        # --- replay buffers ------------------------------------------
        self.min_samples_for_training = 10       # samples required before training
        self.good_buffer = ReplayBuffer(1000)    # steps that decreased energy
        self.bad_buffer = ReplayBuffer(1000)     # steps that increased energy

        # --- learning curves and state tracking ----------------------
        self.gen_losses = []
        self.dis_losses = []
        self.geom_history = []
        self.energy_history = []
        self.gradient_history = []
        self.step_history = []
        self.iter = 0

        # Build the generator/discriminator pair and their optimizers.
        self._init_models()
134
+
135
+ def _init_models(self):
136
+ """Initialize generator and discriminator networks"""
137
+ self.generator = Generator(
138
+ input_dim=self.feature_dim,
139
+ hidden_dims=self.gen_hidden_dims
140
+ ).to(self.device)
141
+
142
+ self.discriminator = Discriminator(
143
+ input_dim=self.feature_dim+1, # Additional feature for energy change
144
+ hidden_dims=self.dis_hidden_dims
145
+ ).to(self.device)
146
+
147
+ # Configure optimizers
148
+ self.gen_optimizer = optim.Adam(
149
+ self.generator.parameters(),
150
+ lr=self.gen_learning_rate,
151
+ betas=(self.beta1, 0.999)
152
+ )
153
+
154
+ self.dis_optimizer = optim.Adam(
155
+ self.discriminator.parameters(),
156
+ lr=self.dis_learning_rate,
157
+ betas=(self.beta1, 0.999)
158
+ )
159
+
160
+ # Loss function
161
+ self.criterion = nn.BCELoss()
162
+
163
+ # Ensure models use consistent data type
164
+ self.generator.to(dtype=self.dtype)
165
+ self.discriminator.to(dtype=self.dtype)
166
+
167
+ def _update_history(self, geometry, energy, gradient, step=None):
168
+ """Update optimization history"""
169
+ self.geom_history.append(geometry.copy())
170
+ self.energy_history.append(energy)
171
+ self.gradient_history.append(gradient.copy())
172
+
173
+ if step is not None:
174
+ self.step_history.append(step.copy())
175
+ elif len(self.geom_history) > 1:
176
+ # Calculate step from previous geometry
177
+ calc_step = geometry - self.geom_history[-2]
178
+ self.step_history.append(calc_step)
179
+
180
+ def _update_replay_buffers(self):
181
+ """Update replay buffers"""
182
+ if len(self.energy_history) < 2 or len(self.step_history) < 1:
183
+ return
184
+
185
+ # Calculate energy change for the latest step
186
+ energy_change = self.energy_history[-1] - self.energy_history[-2]
187
+ prev_geom = self.geom_history[-2]
188
+ prev_grad = self.gradient_history[-2]
189
+ step = self.step_history[-1]
190
+
191
+ # Create features for each coordinate point
192
+ n_coords = prev_geom.shape[0]
193
+ for i in range(n_coords):
194
+ features = np.hstack([
195
+ prev_geom[i],
196
+ prev_grad[i],
197
+ step[i]
198
+ ]).astype(np.float32) # Explicitly use float32
199
+
200
+ # Add energy change as a feature
201
+ features_with_energy = np.append(features, energy_change).astype(np.float32) # Explicitly use float32
202
+
203
+ # Add to appropriate buffer based on energy change
204
+ experience = (features, features_with_energy, energy_change)
205
+ if energy_change <= 0: # Energy decreased = good step
206
+ self.good_buffer.add(experience)
207
+ else: # Energy increased = bad step
208
+ self.bad_buffer.add(experience)
209
+
210
+ def _train_gan(self):
211
+ """Train the GAN model"""
212
+ # Skip training if insufficient samples
213
+ if len(self.good_buffer) < self.min_samples_for_training:
214
+ return False
215
+
216
+ # Sample from buffer
217
+ for _ in range(self.training_steps):
218
+ # Mini-batch sampling
219
+ real_batch_size = min(self.batch_size // 2, len(self.good_buffer))
220
+ fake_batch_size = min(self.batch_size // 2, len(self.bad_buffer))
221
+
222
+ if real_batch_size == 0 or fake_batch_size == 0:
223
+ continue
224
+
225
+ # Good step samples (real/positive)
226
+ good_samples = self.good_buffer.sample(real_batch_size)
227
+ good_features = torch.tensor([s[0] for s in good_samples],
228
+ device=self.device,
229
+ dtype=self.dtype) # Explicitly set dtype
230
+ good_features_with_energy = torch.tensor([s[1] for s in good_samples],
231
+ device=self.device,
232
+ dtype=self.dtype) # Explicitly set dtype
233
+
234
+ # Bad step samples (fake/negative)
235
+ bad_samples = self.bad_buffer.sample(fake_batch_size)
236
+ bad_features = torch.tensor([s[0] for s in bad_samples],
237
+ device=self.device,
238
+ dtype=self.dtype) # Explicitly set dtype
239
+ bad_features_with_energy = torch.tensor([s[1] for s in bad_samples],
240
+ device=self.device,
241
+ dtype=self.dtype) # Explicitly set dtype
242
+
243
+ # Create labels
244
+ real_labels = torch.ones(real_batch_size, 1, device=self.device, dtype=self.dtype)
245
+ fake_labels = torch.zeros(fake_batch_size, 1, device=self.device, dtype=self.dtype)
246
+
247
+ # --------------------
248
+ # Train Discriminator
249
+ # --------------------
250
+ self.dis_optimizer.zero_grad()
251
+
252
+ # Classify good samples
253
+ good_outputs = self.discriminator(good_features_with_energy)
254
+ d_loss_real = self.criterion(good_outputs, real_labels)
255
+
256
+ # Classify bad samples
257
+ bad_outputs = self.discriminator(bad_features_with_energy)
258
+ d_loss_fake = self.criterion(bad_outputs, fake_labels)
259
+
260
+ # Classify generator-modified steps
261
+ gen_scale = self.generator(bad_features)
262
+ gen_step_features = bad_features.clone()
263
+
264
+ # Apply scaling (maintain direction, adjust scale only)
265
+ for i in range(gen_step_features.shape[0]):
266
+ # Get original step vector (last dimension)
267
+ orig_step = gen_step_features[i, -1].item()
268
+
269
+ # Calculate scale factor (convert from -1 to 1 range to appropriate scale)
270
+ scale = ((gen_scale[i].item() + 1) / 2) * (self.max_scale - self.min_scale) + self.min_scale
271
+
272
+ # Apply scale to compute modified step
273
+ gen_step_features[i, -1] = orig_step * scale
274
+
275
+ # Add energy change
276
+ gen_features_with_energy = torch.cat([
277
+ gen_step_features,
278
+ torch.zeros(gen_step_features.shape[0], 1, device=self.device, dtype=self.dtype) # Explicit dtype
279
+ ], dim=1)
280
+
281
+ g_outputs = self.discriminator(gen_features_with_energy)
282
+ d_loss_gen = self.criterion(g_outputs, fake_labels)
283
+
284
+ d_loss = d_loss_real + d_loss_fake + d_loss_gen
285
+ d_loss.backward()
286
+ self.dis_optimizer.step()
287
+
288
+ # --------------------
289
+ # Train Generator
290
+ # --------------------
291
+ self.gen_optimizer.zero_grad()
292
+
293
+ # Generate modified steps
294
+ gen_scale = self.generator(bad_features)
295
+ gen_step_features = bad_features.clone()
296
+
297
+ # Apply scaling
298
+ for i in range(gen_step_features.shape[0]):
299
+ orig_step = gen_step_features[i, -1].item()
300
+ scale = ((gen_scale[i].item() + 1) / 2) * (self.max_scale - self.min_scale) + self.min_scale
301
+ gen_step_features[i, -1] = orig_step * scale
302
+
303
+ # Add energy change
304
+ gen_features_with_energy = torch.cat([
305
+ gen_step_features,
306
+ torch.zeros(gen_step_features.shape[0], 1, device=self.device, dtype=self.dtype) # Explicit dtype
307
+ ], dim=1)
308
+
309
+ # Get discriminator evaluation
310
+ g_outputs = self.discriminator(gen_features_with_energy)
311
+
312
+ # Train generator to produce "good" steps
313
+ g_loss = self.criterion(g_outputs, real_labels)
314
+ g_loss.backward()
315
+ self.gen_optimizer.step()
316
+
317
+ # Record losses
318
+ self.gen_losses.append(g_loss.item())
319
+ self.dis_losses.append(d_loss.item())
320
+
321
+ if len(self.gen_losses) > 0:
322
+ print(f"GAN training - Gen loss: {np.mean(self.gen_losses[-self.training_steps:]):.4f}, "
323
+ f"Dis loss: {np.mean(self.dis_losses[-self.training_steps:]):.4f}")
324
+
325
+ return True
326
+
327
+ def _generate_improved_step(self, geometry, gradient, original_step):
328
+ """Generate improved step using the GAN"""
329
+ # Don't modify step if original norm is near zero
330
+ orig_norm = np.linalg.norm(original_step)
331
+ if orig_norm < 1e-10:
332
+ return original_step
333
+
334
+ # Create features
335
+ n_coords = geometry.shape[0]
336
+ features = []
337
+
338
+ for i in range(n_coords):
339
+ feat = np.hstack([
340
+ geometry[i],
341
+ gradient[i],
342
+ original_step[i]
343
+ ]).astype(np.float32) # Explicitly use float32
344
+ features.append(feat)
345
+
346
+ # Prepare batch for model evaluation
347
+ features_tensor = torch.tensor(features, device=self.device, dtype=self.dtype) # Explicitly set dtype
348
+
349
+ # Generate scaling factors with the generator
350
+ self.generator.eval()
351
+ with torch.no_grad():
352
+ scale_factors = self.generator(features_tensor)
353
+
354
+ # Convert from -1 to 1 range to actual scale range
355
+ scale_factors = ((scale_factors + 1) / 2) * (self.max_scale - self.min_scale) + self.min_scale
356
+ scale_factors = scale_factors.cpu().numpy()
357
+
358
+ # Generate modified step
359
+ gan_step = original_step.copy()
360
+
361
+ for i in range(n_coords):
362
+ # Apply scale to each coordinate point
363
+ gan_step[i] = original_step[i] * scale_factors[i, 0]
364
+
365
+ # Mix steps
366
+ mixed_step = self.mix_ratio * gan_step + (1 - self.mix_ratio) * original_step
367
+
368
+ # Limit step size
369
+ step_norm = np.linalg.norm(mixed_step)
370
+ if step_norm > self.step_clip:
371
+ mixed_step = mixed_step * (self.step_clip / step_norm)
372
+
373
+ # Output norms before and after modification
374
+ print(f"Step norm - Original: {orig_norm:.6f}, GAN: {np.linalg.norm(gan_step):.6f}, "
375
+ f"Mixed: {np.linalg.norm(mixed_step):.6f}")
376
+
377
+ return mixed_step
378
+
379
+ def run(self, geom_num_list, energy, gradient, original_move_vector):
380
+ """
381
+ Run GAN-Step optimization step
382
+
383
+ Parameters:
384
+ -----------
385
+ geom_num_list : numpy.ndarray
386
+ Current molecular geometry
387
+ energy : float
388
+ Current energy value
389
+ gradient : numpy.ndarray
390
+ Current gradient
391
+ original_move_vector : numpy.ndarray
392
+ Original optimization step
393
+
394
+ Returns:
395
+ --------
396
+ numpy.ndarray
397
+ Modified optimization step
398
+ """
399
+ print("GAN-Step method")
400
+
401
+ # Update history
402
+ self._update_history(geom_num_list, energy, gradient)
403
+
404
+ # Update replay buffers
405
+ self._update_replay_buffers()
406
+
407
+ # Use original step for initial iterations
408
+ if self.iter < 3:
409
+ print(f"Building history (step {self.iter+1}), using original step")
410
+ self.iter += 1
411
+ return original_move_vector
412
+
413
+ # Train GAN model
414
+ if len(self.good_buffer) >= self.min_samples_for_training:
415
+ try:
416
+ gan_trained = self._train_gan()
417
+ if not gan_trained:
418
+ print("Failed to train GAN, using original step")
419
+ self.iter += 1
420
+ return original_move_vector
421
+ except RuntimeError as e:
422
+ print(f"Error during GAN training: {str(e)}")
423
+ print("Using original step due to training error")
424
+ self.iter += 1
425
+ return original_move_vector
426
+ else:
427
+ print(f"Not enough good samples for GAN training ({len(self.good_buffer)}/{self.min_samples_for_training})")
428
+ self.iter += 1
429
+ return original_move_vector
430
+
431
+ # Modify step using GAN
432
+ try:
433
+ gan_step = self._generate_improved_step(geom_num_list, gradient, original_move_vector)
434
+ except Exception as e:
435
+ print(f"Error generating improved step: {str(e)}")
436
+ print("Using original step due to generation error")
437
+ self.iter += 1
438
+ return original_move_vector
439
+
440
+ # Check for numerical issues
441
+ if np.any(np.isnan(gan_step)) or np.any(np.isinf(gan_step)):
442
+ print("Warning: Numerical issues in GAN step, using original step")
443
+ self.iter += 1
444
+ return original_move_vector
445
+
446
+ self.iter += 1
447
+ return gan_step
448
+
449
+ def save_model(self, path='gan_model'):
450
+ """Save GAN model"""
451
+ if not os.path.exists(path):
452
+ os.makedirs(path)
453
+
454
+ torch.save({
455
+ 'generator': self.generator.state_dict(),
456
+ 'discriminator': self.discriminator.state_dict(),
457
+ 'gen_optimizer': self.gen_optimizer.state_dict(),
458
+ 'dis_optimizer': self.dis_optimizer.state_dict(),
459
+ 'gen_losses': self.gen_losses,
460
+ 'dis_losses': self.dis_losses,
461
+ 'iter': self.iter
462
+ }, os.path.join(path, 'gan_step_model.pt'))
463
+
464
+ print(f"Model saved to {os.path.join(path, 'gan_step_model.pt')}")
465
+
466
+ def load_model(self, path='gan_model/gan_step_model.pt'):
467
+ """Load GAN model"""
468
+ if not os.path.exists(path):
469
+ print(f"No model file found at {path}")
470
+ return False
471
+
472
+ try:
473
+ checkpoint = torch.load(path)
474
+ self.generator.load_state_dict(checkpoint['generator'])
475
+ self.discriminator.load_state_dict(checkpoint['discriminator'])
476
+ self.gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
477
+ self.dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
478
+ self.gen_losses = checkpoint['gen_losses']
479
+ self.dis_losses = checkpoint['dis_losses']
480
+ self.iter = checkpoint['iter']
481
+
482
+ print(f"Model loaded from {path}")
483
+ return True
484
+ except Exception as e:
485
+ print(f"Failed to load model: {str(e)}")
486
+ return False