MultiOptPy 1.20.7__py3-none-any.whl → 1.20.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,95 +1,294 @@
1
-
2
1
  from multioptpy.Parameters.parameter import UnitValueLib
3
2
  from multioptpy.Utils.calc_tools import torch_calc_dihedral_angle_from_vec
4
3
  import torch
4
+ import math
5
5
 
6
6
  class StructKeepDihedralAnglePotential:
7
+ """
8
+ Computes the harmonic potential energy for a dihedral angle defined by four atoms.
9
+
10
+ This class incorporates a robust singularity handling mechanism for cases where
11
+ three consecutive atoms become collinear (i.e., the bond angle approaches 0 or 180 degrees).
12
+ In such configurations, the dihedral angle is mathematically undefined, leading to
13
+ numerical instabilities and exploding gradients.
14
+
15
+ To resolve this, a smooth switching function (smoothstep) is applied to the squared norm
16
+ of the cross products of the bond vectors. This function smoothly attenuates the
17
+ potential energy and forces to zero as the geometry enters the collinear region.
18
+
19
+ Energy Function:
20
+ E_total = E_harmonic * S(|n1|^2) * S(|n2|^2)
21
+
22
+ Where:
23
+ E_harmonic = 0.5 * k * (phi - phi_0)^2
24
+ S(x) is a cubic Hermite interpolation function (smoothstep) that takes values in [0, 1].
25
+ """
26
+
7
27
  def __init__(self, **kwarg):
8
28
  self.config = kwarg
9
29
  UVL = UnitValueLib()
10
30
  self.hartree2kcalmol = UVL.hartree2kcalmol
11
31
  self.bohr2angstroms = UVL.bohr2angstroms
12
32
  self.hartree2kjmol = UVL.hartree2kjmol
33
+
34
+ # Thresholds for collinearity smoothing (Squared Norm of Cross Product).
35
+ # These values determine the range over which the potential is switched off.
36
+ # COLLINEAR_CUT_MIN: Below this value, the weighting factor is 0.0.
37
+ # COLLINEAR_CUT_MAX: Above this value, the weighting factor is 1.0.
38
+ # For unit-length bond vectors, a squared norm of 1e-8 corresponds to a sine of approx. 1e-4.
39
+ self.COLLINEAR_CUT_MIN = 1e-10
40
+ self.COLLINEAR_CUT_MAX = 1e-8
41
+
13
42
  return
14
- def calc_energy(self, geom_num_list, bias_pot_params=[]):
43
+
44
+ def _compute_switching(self, val):
15
45
  """
16
- # required variables: self.config["keep_dihedral_angle_spring_const"],
17
- self.config["keep_dihedral_angle_atom_pairs"]
18
- self.config["keep_dihedral_angle_angle"]
19
- bias_pot_params[0] : keep_dihedral_angle_spring_const
20
- bias_pot_params[1] : keep_dihedral_angle_angle
46
+ Computes the smoothstep switching factor.
47
+
48
+ Args:
49
+ val (Tensor): The squared norm of the cross product vector.
50
+
51
+ Returns:
52
+ Tensor: A scalar factor between 0.0 and 1.0 using cubic interpolation.
53
+ Returns 0.0 if val < min, 1.0 if val > max.
54
+ """
55
+ t = (val - self.COLLINEAR_CUT_MIN) / (self.COLLINEAR_CUT_MAX - self.COLLINEAR_CUT_MIN)
56
+ t = torch.clamp(t, 0.0, 1.0)
57
+ # Smoothstep function: 3t^2 - 2t^3
58
+ return t * t * (3.0 - 2.0 * t)
59
+
60
+ def calc_energy(self, geom_num_list, bias_pot_params=[]):
21
61
  """
22
- a1 = geom_num_list[self.config["keep_dihedral_angle_atom_pairs"][1]-1] - geom_num_list[self.config["keep_dihedral_angle_atom_pairs"][0]-1]
23
- a2 = geom_num_list[self.config["keep_dihedral_angle_atom_pairs"][2]-1] - geom_num_list[self.config["keep_dihedral_angle_atom_pairs"][1]-1]
24
- a3 = geom_num_list[self.config["keep_dihedral_angle_atom_pairs"][3]-1] - geom_num_list[self.config["keep_dihedral_angle_atom_pairs"][2]-1]
62
+ Calculates the potential energy with collinearity smoothing.
25
63
 
26
- angle = torch.abs(torch_calc_dihedral_angle_from_vec(a1, a2, a3))
64
+ Args:
65
+ geom_num_list (Tensor): Atomic coordinates tensor of shape (N_atoms, 3).
66
+ bias_pot_params (list, optional): Optional override for parameters [k, phi_0].
67
+
68
+ Returns:
69
+ Tensor: The calculated potential energy (scalar).
70
+ """
71
+
72
+ # ========================================
73
+ # 1. Parameter Retrieval
74
+ # ========================================
27
75
  if len(bias_pot_params) == 0:
28
- energy = 0.5 * self.config["keep_dihedral_angle_spring_const"] * (angle - torch.deg2rad(torch.tensor(self.config["keep_dihedral_angle_angle"]))) ** 2
76
+ k = self.config["keep_dihedral_angle_spring_const"]
77
+ phi_0_deg = torch.tensor(self.config["keep_dihedral_angle_angle"])
78
+ phi_0 = torch.deg2rad(phi_0_deg)
29
79
  else:
30
- energy = 0.5 * bias_pot_params[0] * (angle - torch.deg2rad(bias_pot_params[1])) ** 2
31
-
32
- return energy #hartree
80
+ k = bias_pot_params[0]
81
+ phi_0_deg = bias_pot_params[1]
82
+ if isinstance(phi_0_deg, torch.Tensor):
83
+ phi_0 = torch.deg2rad(phi_0_deg)
84
+ else:
85
+ phi_0 = torch.deg2rad(torch.tensor(phi_0_deg))
86
+
87
+ device = geom_num_list.device if isinstance(geom_num_list, torch.Tensor) else torch.device("cpu")
88
+ dtype = geom_num_list.dtype
89
+
90
+ PI = torch.tensor(math.pi, device=device, dtype=dtype)
91
+ phi_0 = phi_0.to(device=device, dtype=dtype)
92
+
93
+ # ========================================
94
+ # 2. Vector Calculations
95
+ # ========================================
96
+ i1 = self.config["keep_dihedral_angle_atom_pairs"][0] - 1
97
+ i2 = self.config["keep_dihedral_angle_atom_pairs"][1] - 1
98
+ i3 = self.config["keep_dihedral_angle_atom_pairs"][2] - 1
99
+ i4 = self.config["keep_dihedral_angle_atom_pairs"][3] - 1
100
+
101
+ b1 = geom_num_list[i2] - geom_num_list[i1]
102
+ b2 = geom_num_list[i3] - geom_num_list[i2]
103
+ b3 = geom_num_list[i4] - geom_num_list[i3]
104
+
105
+ # Normal vectors to the planes defined by bond pairs
106
+ n1 = torch.linalg.cross(b1, b2)
107
+ n2 = torch.linalg.cross(b2, b3)
108
+
109
+ # ========================================
110
+ # 3. Collinearity Guard (Singularity Smoothing)
111
+ # ========================================
112
+ n1_sq_norm = torch.sum(n1**2, dim=-1)
113
+ n2_sq_norm = torch.sum(n2**2, dim=-1)
114
+
115
+ # Compute switching factors to dampen energy in collinear regions
116
+ switch_1 = self._compute_switching(n1_sq_norm)
117
+ switch_2 = self._compute_switching(n2_sq_norm)
118
+
119
+ # Safe normalization for angle calculation.
120
+ # Clamping prevents NaN in the graph, while the switching factor ensures
121
+ # that the resulting gradients in the clamped region are zeroed out.
122
+ n1_norm = torch.clamp(torch.sqrt(n1_sq_norm), min=1e-12)
123
+ n2_norm = torch.clamp(torch.sqrt(n2_sq_norm), min=1e-12)
124
+
125
+ n1_hat = n1 / (n1_norm.unsqueeze(-1))
126
+ n2_hat = n2 / (n2_norm.unsqueeze(-1))
127
+
128
+ b2_norm = torch.clamp(torch.linalg.norm(b2), min=1e-12)
129
+ b2_hat = b2 / (b2_norm.unsqueeze(-1))
130
+
131
+ # ========================================
132
+ # 4. Angle Calculation
133
+ # ========================================
134
+ x = torch.sum(n1_hat * n2_hat, dim=-1)
135
+ # (n1 x n2) is parallel to b2
136
+ m1 = torch.linalg.cross(n1_hat, n2_hat)
137
+ y = torch.sum(m1 * b2_hat, dim=-1)
138
+
139
+ phi = torch.atan2(y, x)
140
+
141
+ # ========================================
142
+ # 5. Energy Calculation with Smoothing
143
+ # ========================================
144
+ diff = phi - phi_0
145
+ # Wrap difference to [-pi, pi]
146
+ diff = diff - 2.0 * PI * torch.round(diff / (2.0 * PI))
147
+
148
+ raw_energy = 0.5 * k * diff**2
149
+
150
+ # Apply smoothing factors
151
+ energy = raw_energy * switch_1 * switch_2
152
+
153
+ return energy
33
154
 
34
155
 
35
156
  class StructKeepDihedralAnglePotentialv2:
157
+ """
158
+ Computes the dihedral angle potential energy defined by the centroids of four atom fragments.
159
+
160
+ This class is designed for coarse-grained constraints where the dihedral is defined
161
+ by the geometric centers of specified groups of atoms rather than single atoms.
162
+ It includes the same singularity smoothing mechanism as the standard atom-based potential
163
+ to handle cases where the fragment centroids become collinear.
164
+ """
165
+
36
166
  def __init__(self, **kwarg):
37
167
  self.config = kwarg
38
168
  UVL = UnitValueLib()
39
169
  self.hartree2kcalmol = UVL.hartree2kcalmol
40
170
  self.bohr2angstroms = UVL.bohr2angstroms
41
171
  self.hartree2kjmol = UVL.hartree2kjmol
172
+
173
+ # Thresholds for collinearity smoothing
174
+ self.COLLINEAR_CUT_MIN = 1e-10
175
+ self.COLLINEAR_CUT_MAX = 1e-8
176
+
42
177
  return
43
- def calc_energy(self, geom_num_list, bias_pot_params=[]):
178
+
179
+ def _compute_switching(self, val):
44
180
  """
45
- # required variables: self.config["keep_dihedral_angle_v2_spring_const"],
46
- self.config["keep_dihedral_angle_v2_angle"],
47
- self.config["keep_dihedral_angle_v2_fragm1"],
48
- self.config["keep_dihedral_angle_v2_fragm2"],
49
- self.config["keep_dihedral_angle_v2_fragm3"],
50
- self.config["keep_dihedral_angle_v2_fragm4"],
51
- bias_pot_params[0] : keep_dihedral_angle_v2_spring_const
52
- bias_pot_params[1] : keep_dihedral_angle_v2_angle
181
+ Computes the smoothstep switching factor.
53
182
  """
54
- fragm_1_center = torch.mean(geom_num_list[torch.tensor(self.config["keep_dihedral_angle_v2_fragm1"]) - 1], dim=0)
55
- fragm_2_center = torch.mean(geom_num_list[torch.tensor(self.config["keep_dihedral_angle_v2_fragm2"]) - 1], dim=0)
56
- fragm_3_center = torch.mean(geom_num_list[torch.tensor(self.config["keep_dihedral_angle_v2_fragm3"]) - 1], dim=0)
57
- fragm_4_center = torch.mean(geom_num_list[torch.tensor(self.config["keep_dihedral_angle_v2_fragm4"]) - 1], dim=0)
58
-
59
- a1 = fragm_2_center - fragm_1_center
60
- a2 = fragm_3_center - fragm_2_center
61
- a3 = fragm_4_center - fragm_3_center
183
+ t = (val - self.COLLINEAR_CUT_MIN) / (self.COLLINEAR_CUT_MAX - self.COLLINEAR_CUT_MIN)
184
+ t = torch.clamp(t, 0.0, 1.0)
185
+ return t * t * (3.0 - 2.0 * t)
62
186
 
63
- angle = torch.abs(torch_calc_dihedral_angle_from_vec(a1, a2, a3))
187
+ def calc_energy(self, geom_num_list, bias_pot_params=[]):
188
+ """
189
+ Calculates the potential energy for fragment centroids with collinearity smoothing.
190
+ """
64
191
  if len(bias_pot_params) == 0:
65
- energy = 0.5 * self.config["keep_dihedral_angle_v2_spring_const"] * (angle - torch.deg2rad(torch.tensor(self.config["keep_dihedral_angle_v2_angle"]))) ** 2
192
+ k = self.config["keep_dihedral_angle_v2_spring_const"]
193
+ phi_0_deg = torch.tensor(self.config["keep_dihedral_angle_v2_angle"])
194
+ phi_0 = torch.deg2rad(phi_0_deg)
66
195
  else:
67
- energy = 0.5 * bias_pot_params[0] * (angle - torch.deg2rad(bias_pot_params[1])) ** 2
68
- return energy #hartree
196
+ k = bias_pot_params[0]
197
+ phi_0_deg = bias_pot_params[1]
198
+ if isinstance(phi_0_deg, torch.Tensor):
199
+ phi_0 = torch.deg2rad(phi_0_deg)
200
+ else:
201
+ phi_0 = torch.deg2rad(torch.tensor(phi_0_deg))
202
+
203
+ device = geom_num_list.device if isinstance(geom_num_list, torch.Tensor) else torch.device("cpu")
204
+ dtype = geom_num_list.dtype
205
+
206
+ PI = torch.tensor(math.pi, device=device, dtype=dtype)
207
+ phi_0 = phi_0.to(device=device, dtype=dtype)
208
+
209
+ # Vector Calculations for Fragment Centers
210
+ def get_indices(key):
211
+ return torch.tensor(self.config[key], device=device, dtype=torch.long) - 1
212
+
213
+ fragm_1_center = torch.mean(geom_num_list[get_indices("keep_dihedral_angle_v2_fragm1")], dim=0)
214
+ fragm_2_center = torch.mean(geom_num_list[get_indices("keep_dihedral_angle_v2_fragm2")], dim=0)
215
+ fragm_3_center = torch.mean(geom_num_list[get_indices("keep_dihedral_angle_v2_fragm3")], dim=0)
216
+ fragm_4_center = torch.mean(geom_num_list[get_indices("keep_dihedral_angle_v2_fragm4")], dim=0)
217
+
218
+ b1 = fragm_2_center - fragm_1_center
219
+ b2 = fragm_3_center - fragm_2_center
220
+ b3 = fragm_4_center - fragm_3_center
221
+
222
+ n1 = torch.linalg.cross(b1, b2)
223
+ n2 = torch.linalg.cross(b2, b3)
224
+
225
+ # Collinearity Guard (Smoothing)
226
+ n1_sq_norm = torch.sum(n1**2, dim=-1)
227
+ n2_sq_norm = torch.sum(n2**2, dim=-1)
228
+
229
+ switch_1 = self._compute_switching(n1_sq_norm)
230
+ switch_2 = self._compute_switching(n2_sq_norm)
231
+
232
+ n1_norm = torch.clamp(torch.sqrt(n1_sq_norm), min=1e-12)
233
+ n2_norm = torch.clamp(torch.sqrt(n2_sq_norm), min=1e-12)
234
+
235
+ n1_hat = n1 / (n1_norm.unsqueeze(-1))
236
+ n2_hat = n2 / (n2_norm.unsqueeze(-1))
237
+
238
+ b2_norm = torch.clamp(torch.linalg.norm(b2), min=1e-12)
239
+ b2_hat = b2 / (b2_norm.unsqueeze(-1))
240
+
241
+ # Angle Calculation
242
+ x = torch.sum(n1_hat * n2_hat, dim=-1)
243
+ m1 = torch.linalg.cross(n1_hat, n2_hat)
244
+ y = torch.sum(m1 * b2_hat, dim=-1)
245
+
246
+ phi = torch.atan2(y, x)
247
+
248
+ # Energy Calculation
249
+ diff = phi - phi_0
250
+ diff = diff - 2.0 * PI * torch.round(diff / (2.0 * PI))
251
+
252
+ raw_energy = 0.5 * k * diff**2
253
+
254
+ # Apply Smoothing
255
+ energy = raw_energy * switch_1 * switch_2
256
+
257
+ return energy
69
258
 
70
259
  class StructKeepDihedralAnglePotentialCos:
260
+ """
261
+ Computes a cosine-based dihedral potential energy for fragment centroids.
262
+
263
+ Includes singularity smoothing to prevent gradient explosion when fragment centers
264
+ become collinear.
265
+ """
71
266
  def __init__(self, **kwarg):
72
267
  self.config = kwarg
73
268
  UVL = UnitValueLib()
74
269
  self.hartree2kcalmol = UVL.hartree2kcalmol
75
270
  self.bohr2angstroms = UVL.bohr2angstroms
76
271
  self.hartree2kjmol = UVL.hartree2kjmol
272
+
273
+ self.COLLINEAR_CUT_MIN = 1e-10
274
+ self.COLLINEAR_CUT_MAX = 1e-8
77
275
  return
276
+
277
+ def _compute_switching(self, val):
278
+ """
279
+ Computes the smoothstep switching factor.
280
+ """
281
+ t = (val - self.COLLINEAR_CUT_MIN) / (self.COLLINEAR_CUT_MAX - self.COLLINEAR_CUT_MIN)
282
+ t = torch.clamp(t, 0.0, 1.0)
283
+ return t * t * (3.0 - 2.0 * t)
284
+
78
285
  def calc_energy(self, geom_num_list, bias_pot_params=[]):
79
286
  """
80
- # required variables: self.config["keep_dihedral_angle_cos_potential_const"],
81
- self.config["keep_dihedral_angle_cos_angle_const"],
82
- self.config["keep_dihedral_angle_cos_angle"],
83
- self.config["keep_dihedral_angle_cos_fragm1"],
84
- self.config["keep_dihedral_angle_cos_fragm2"],
85
- self.config["keep_dihedral_angle_cos_fragm3"],
86
- self.config["keep_dihedral_angle_cos_fragm4"],
87
-
287
+ Calculates the cosine-based potential energy.
88
288
  """
89
289
  potential_const = float(self.config["keep_dihedral_angle_cos_potential_const"])
90
290
  angle_const = float(self.config["keep_dihedral_angle_cos_angle_const"])
91
291
 
92
-
93
292
  fragm_1_center = torch.mean(geom_num_list[torch.tensor(self.config["keep_dihedral_angle_cos_fragm1"]) - 1], dim=0)
94
293
  fragm_2_center = torch.mean(geom_num_list[torch.tensor(self.config["keep_dihedral_angle_cos_fragm2"]) - 1], dim=0)
95
294
  fragm_3_center = torch.mean(geom_num_list[torch.tensor(self.config["keep_dihedral_angle_cos_fragm3"]) - 1], dim=0)
@@ -99,7 +298,19 @@ class StructKeepDihedralAnglePotentialCos:
99
298
  a2 = fragm_3_center - fragm_2_center
100
299
  a3 = fragm_4_center - fragm_3_center
101
300
 
301
+ # Explicitly compute cross products to determine switching factors
302
+ n1 = torch.linalg.cross(a1, a2)
303
+ n2 = torch.linalg.cross(a2, a3)
304
+ n1_sq_norm = torch.sum(n1**2, dim=-1)
305
+ n2_sq_norm = torch.sum(n2**2, dim=-1)
306
+
307
+ switch_1 = self._compute_switching(n1_sq_norm)
308
+ switch_2 = self._compute_switching(n2_sq_norm)
309
+
102
310
  angle = torch_calc_dihedral_angle_from_vec(a1, a2, a3)
103
- energy = 0.5 * potential_const * (1.0 -1* torch.cos(angle_const * angle - (torch.deg2rad(torch.tensor(self.config["keep_dihedral_angle_cos_angle"])))))
311
+ raw_energy = 0.5 * potential_const * (1.0 - 1 * torch.cos(angle_const * angle - (torch.deg2rad(torch.tensor(self.config["keep_dihedral_angle_cos_angle"])))))
312
+
313
+ # Apply smoothing
314
+ energy = raw_energy * switch_1 * switch_2
104
315
 
105
- return energy #hartree
316
+ return energy
@@ -1,70 +1,277 @@
1
-
2
1
  from multioptpy.Parameters.parameter import UnitValueLib
3
2
  from multioptpy.Utils.calc_tools import torch_calc_outofplain_angle_from_vec
4
3
  import torch
5
-
4
+ import math
6
5
 
7
6
  class StructKeepOutofPlainAnglePotential:
7
+ """
8
+ Class for calculating Out-of-Plane (Wilson) angle potential with robust singularity handling.
9
+
10
+ Singularity Handling:
11
+ The Out-of-Plane angle measures the elevation of vector a1 from the plane defined by a2 and a3.
12
+ A singularity occurs when vectors a2 and a3 are collinear (angle 0 or 180).
13
+ In this case, the reference plane is undefined, and the normal vector vanishes.
14
+
15
+ This implementation applies a 'Collinearity Guard' to force gradients to zero
16
+ when the reference plane is undefined.
17
+ """
18
+
8
19
  def __init__(self, **kwarg):
9
20
  self.config = kwarg
10
21
  UVL = UnitValueLib()
11
22
  self.hartree2kcalmol = UVL.hartree2kcalmol
12
23
  self.bohr2angstroms = UVL.bohr2angstroms
13
24
  self.hartree2kjmol = UVL.hartree2kjmol
25
+
26
+ # Threshold for plane definition stability (Squared Norm of cross product)
27
+ # If the cross product of base vectors is smaller than this,
28
+ # the plane is considered undefined.
29
+ self.COLLINEAR_CUT_SQ = 1e-8
30
+
14
31
  return
32
+
15
33
  def calc_energy(self, geom_num_list, bias_pot_params=[]):
16
34
  """
17
- # required variables: self.config["keep_out_of_plain_angle_spring_const"],
18
- self.config["keep_out_of_plain_angle_atom_pairs"]
19
- self.config["keep_out_of_plain_angle_angle"]
20
- bias_pot_params[0] : keep_out_of_plain_angle_spring_const
21
- bias_pot_params[1] : keep_out_of_plain_angle_angle
22
-
35
+ Calculates Out-of-Plane angle energy.
36
+
37
+ Args:
38
+ geom_num_list: Tensor of atomic coordinates (N_atoms, 3)
39
+ bias_pot_params: Optional bias parameters [k, theta_0]
40
+
41
+ Definition:
42
+ Center atom: i (index 0 in pair list)
43
+ Neighbors: j, k, l (indices 1, 2, 3)
44
+ Vectors: a1 = r_j - r_i
45
+ a2 = r_k - r_i
46
+ a3 = r_l - r_i
47
+
48
+ Angle is the deviation of a1 from the plane spanned by a2 and a3.
23
49
  """
24
- a1 = geom_num_list[self.config["keep_out_of_plain_angle_atom_pairs"][1]-1] - geom_num_list[self.config["keep_out_of_plain_angle_atom_pairs"][0]-1]
25
- a2 = geom_num_list[self.config["keep_out_of_plain_angle_atom_pairs"][2]-1] - geom_num_list[self.config["keep_out_of_plain_angle_atom_pairs"][0]-1]
26
- a3 = geom_num_list[self.config["keep_out_of_plain_angle_atom_pairs"][3]-1] - geom_num_list[self.config["keep_out_of_plain_angle_atom_pairs"][0]-1]
27
-
28
- angle = torch_calc_outofplain_angle_from_vec(a1, a2, a3)
50
+
51
+ # ========================================
52
+ # 1. Parameter Retrieval
53
+ # ========================================
29
54
  if len(bias_pot_params) == 0:
30
- energy = 0.5 * self.config["keep_out_of_plain_angle_spring_const"] * (angle - torch.deg2rad(torch.tensor(self.config["keep_out_of_plain_angle_angle"]))) ** 2
55
+ k = self.config["keep_out_of_plain_angle_spring_const"]
56
+ theta_0_deg = torch.tensor(self.config["keep_out_of_plain_angle_angle"])
57
+ theta_0 = torch.deg2rad(theta_0_deg)
31
58
  else:
32
- energy = 0.5 * bias_pot_params[0] * (angle - torch.deg2rad(bias_pot_params[1])) ** 2
33
- return energy #hartree
59
+ k = bias_pot_params[0]
60
+ theta_0_deg = bias_pot_params[1]
61
+ if isinstance(theta_0_deg, torch.Tensor):
62
+ theta_0 = torch.deg2rad(theta_0_deg)
63
+ else:
64
+ theta_0 = torch.deg2rad(torch.tensor(theta_0_deg))
65
+
66
+ device = geom_num_list.device if isinstance(geom_num_list, torch.Tensor) else torch.device("cpu")
67
+ dtype = geom_num_list.dtype
68
+ theta_0 = theta_0.to(device=device, dtype=dtype)
69
+
70
+ # ========================================
71
+ # 2. Vector Calculations
72
+ # ========================================
73
+ # Indices are 1-based in config
74
+ # The first entry (index 0) of the atom list is the central atom; a1, a2 and a3 all start from it
75
+ c_idx = self.config["keep_out_of_plain_angle_atom_pairs"][0] - 1
76
+ i1_idx = self.config["keep_out_of_plain_angle_atom_pairs"][1] - 1
77
+ i2_idx = self.config["keep_out_of_plain_angle_atom_pairs"][2] - 1
78
+ i3_idx = self.config["keep_out_of_plain_angle_atom_pairs"][3] - 1
79
+
80
+ center_pos = geom_num_list[c_idx]
81
+
82
+ # Vectors from Center to Neighbors
83
+ a1 = geom_num_list[i1_idx] - center_pos # The vector to measure (Probe)
84
+ a2 = geom_num_list[i2_idx] - center_pos # Plane definition vector 1
85
+ a3 = geom_num_list[i3_idx] - center_pos # Plane definition vector 2
86
+
87
+ # ========================================
88
+ # 3. Plane Definition & Singularity Guard
89
+ # ========================================
90
+ # The plane is defined by a2 and a3.
91
+ # Normal vector n = a2 x a3
92
+ n = torch.linalg.cross(a2, a3)
93
+
94
+ # Check squared norm of normal vector
95
+ # If n is near zero, a2 and a3 are collinear -> Plane Undefined
96
+ n_sq_norm = torch.sum(n**2, dim=-1)
97
+
98
+ # Guard Mask
99
+ is_undefined_plane = (n_sq_norm < self.COLLINEAR_CUT_SQ)
100
+
101
+ # Safe normalization
102
+ n_norm = torch.sqrt(n_sq_norm)
103
+ n_hat_demon = torch.clamp(n_norm.unsqueeze(-1), min=1e-12)
104
+ n_hat = n / n_hat_demon
105
+
106
+ # ========================================
107
+ # 4. Angle Calculation (Robust atan2)
108
+ # ========================================
109
+ # We want the angle 'phi' between a1 and the plane (a2, a3).
110
+ # This is equivalent to 90 - angle(a1, n), or simply:
111
+ # sin(phi) = (a1 . n_hat) / |a1|
112
+
113
+ # Height of a1 relative to the plane (Projection onto normal)
114
+ # h = |a1| * sin(phi)
115
+ h = torch.sum(a1 * n_hat, dim=-1)
116
+
117
+ # Length of a1
118
+ a1_norm = torch.linalg.norm(a1) # add epsilon if atom overlap is a concern
119
+
120
+ # Projected length of a1 onto the plane
121
+ # r_proj = |a1| * cos(phi) = sqrt(|a1|^2 - h^2)
122
+ # We clamp inside sqrt to avoid negative values due to numerical noise
123
+ r_proj_sq = a1_norm**2 - h**2
124
+ r_proj = torch.sqrt(torch.clamp(r_proj_sq, min=0.0))
125
+
126
+ # Calculate angle using atan2(y, x) = atan2(height, projected_distance)
127
+ # This is valid for -90 to +90 degrees and stable at 90.
128
+ # Note: If a1 is perpendicular to plane, r_proj is 0, atan2(h, 0) gives +/- 90 correctly.
129
+ angle = torch.atan2(h, r_proj)
130
+
131
+ # ========================================
132
+ # 5. Energy & Clamping
133
+ # ========================================
134
+
135
+ # Difference from equilibrium
136
+ # atan2(h, r_proj) with r_proj >= 0 yields angles limited to -90 to 90 degrees,
137
+ # so no periodic wrapping is applied and a plain subtraction is used.
138
+ diff = angle - theta_0
139
+
140
+ energy_harmonic = 0.5 * k * diff**2
141
+
142
+ # Apply Guard:
143
+ # If plane is undefined (collinear base vectors), force energy/force to 0.
144
+ energy = torch.where(is_undefined_plane, torch.tensor(0.0, device=device, dtype=dtype), energy_harmonic)
145
+
146
+ return energy
34
147
 
35
148
  class StructKeepOutofPlainAnglePotentialv2:
149
+ """
150
+ Class for calculating Out-of-Plane (Wilson) angle potential for fragment centroids
151
+ with robust singularity handling.
152
+
153
+ This class calculates the angle of vector a1 (Frag1->Frag2) out of the plane
154
+ defined by vectors a2 (Frag1->Frag3) and a3 (Frag1->Frag4).
155
+
156
+ Singularity Handling:
157
+ - Plane Undefined: If Frag1, Frag3, and Frag4 are collinear, the reference plane
158
+ cannot be defined (normal vector vanishes). A 'Collinearity Guard' forces
159
+ gradients to zero in this region.
160
+ - Vertical Instability: Uses atan2(h, r_proj) instead of asin(h/r) to maintain
161
+ numerical stability when the angle approaches +/- 90 degrees.
162
+ """
163
+
36
164
  def __init__(self, **kwarg):
37
165
  self.config = kwarg
38
166
  UVL = UnitValueLib()
39
167
  self.hartree2kcalmol = UVL.hartree2kcalmol
40
168
  self.bohr2angstroms = UVL.bohr2angstroms
41
169
  self.hartree2kjmol = UVL.hartree2kjmol
170
+
171
+ # Threshold for plane definition stability (Squared Norm of cross product)
172
+ # If the cross product of plane-defining vectors is smaller than this,
173
+ # the plane is considered undefined.
174
+ self.COLLINEAR_CUT_SQ = 1e-8
175
+
42
176
  return
43
177
 
44
178
  def calc_energy(self, geom_num_list, bias_pot_params=[]):
45
179
  """
46
- # required variables: self.config["keep_out_of_plain_angle_v2_spring_const"],
47
- self.config["keep_out_of_plain_angle_v2_angle"],
48
- self.config["keep_out_of_plain_angle_v2_fragm1"],
49
- self.config["keep_out_of_plain_angle_v2_fragm2"],
50
- self.config["keep_out_of_plain_angle_v2_fragm3"],
51
- self.config["keep_out_of_plain_angle_v2_fragm4"],
52
-
180
+ Calculates Out-of-Plane angle energy for fragments.
181
+
182
+ Args:
183
+ geom_num_list: Tensor of atomic coordinates (N_atoms, 3)
184
+ bias_pot_params: Optional bias parameters [k, theta_0]
53
185
  """
54
- fragm_1_center = torch.mean(geom_num_list[torch.tensor(self.config["keep_out_of_plain_angle_v2_fragm1"]) - 1], dim=0)
55
- fragm_2_center = torch.mean(geom_num_list[torch.tensor(self.config["keep_out_of_plain_angle_v2_fragm2"]) - 1], dim=0)
56
- fragm_3_center = torch.mean(geom_num_list[torch.tensor(self.config["keep_out_of_plain_angle_v2_fragm3"]) - 1], dim=0)
57
- fragm_4_center = torch.mean(geom_num_list[torch.tensor(self.config["keep_out_of_plain_angle_v2_fragm4"]) - 1], dim=0)
186
+
187
+ # ========================================
188
+ # 1. Parameter Retrieval
189
+ # ========================================
190
+ if len(bias_pot_params) == 0:
191
+ k = self.config["keep_out_of_plain_angle_v2_spring_const"]
192
+ theta_0_deg = torch.tensor(self.config["keep_out_of_plain_angle_v2_angle"])
193
+ theta_0 = torch.deg2rad(theta_0_deg)
194
+ else:
195
+ k = bias_pot_params[0]
196
+ theta_0_deg = bias_pot_params[1]
197
+ if isinstance(theta_0_deg, torch.Tensor):
198
+ theta_0 = torch.deg2rad(theta_0_deg)
199
+ else:
200
+ theta_0 = torch.deg2rad(torch.tensor(theta_0_deg))
201
+
202
+ # Device/Dtype handling
203
+ device = geom_num_list.device if isinstance(geom_num_list, torch.Tensor) else torch.device("cpu")
204
+ dtype = geom_num_list.dtype
205
+ theta_0 = theta_0.to(device=device, dtype=dtype)
206
+
207
+ # ========================================
208
+ # 2. Vector Calculations (Fragment Centroids)
209
+ # ========================================
210
+
211
+ # Helper to get indices
212
+ def get_indices(key):
213
+ return torch.tensor(self.config[key], device=device, dtype=torch.long) - 1
58
214
 
59
-
215
+ # Calculate centroids (Frag 1 is the Vertex/Center)
216
+ fragm_1_center = torch.mean(geom_num_list[get_indices("keep_out_of_plain_angle_v2_fragm1")], dim=0)
217
+ fragm_2_center = torch.mean(geom_num_list[get_indices("keep_out_of_plain_angle_v2_fragm2")], dim=0)
218
+ fragm_3_center = torch.mean(geom_num_list[get_indices("keep_out_of_plain_angle_v2_fragm3")], dim=0)
219
+ fragm_4_center = torch.mean(geom_num_list[get_indices("keep_out_of_plain_angle_v2_fragm4")], dim=0)
220
+
221
+ # Define vectors originating from Fragment 1
222
+ # a1: The "Probe" vector (whose angle we are measuring)
223
+ # a2, a3: The "Base" vectors (defining the reference plane)
60
224
  a1 = fragm_2_center - fragm_1_center
61
225
  a2 = fragm_3_center - fragm_1_center
62
226
  a3 = fragm_4_center - fragm_1_center
63
227
 
64
- angle = torch_calc_outofplain_angle_from_vec(a1, a2, a3)
65
- if len(bias_pot_params) == 0:
66
- energy = 0.5 * self.config["keep_out_of_plain_angle_v2_spring_const"] * (angle - torch.deg2rad(torch.tensor(self.config["keep_out_of_plain_angle_v2_angle"]))) ** 2
67
- else:
68
- energy = 0.5 * bias_pot_params[0] * (angle - torch.deg2rad(bias_pot_params[1])) ** 2
69
- return energy #hartree
70
-
228
+ # ========================================
229
+ # 3. Plane Definition & Singularity Guard
230
+ # ========================================
231
+
232
+ # Normal vector to the plane spanned by a2 and a3
233
+ n = torch.linalg.cross(a2, a3)
234
+
235
+ # Check if plane is defined (a2 and a3 are not collinear)
236
+ n_sq_norm = torch.sum(n**2, dim=-1)
237
+ is_undefined_plane = (n_sq_norm < self.COLLINEAR_CUT_SQ)
238
+
239
+ # Safe normalization
240
+ n_norm = torch.sqrt(n_sq_norm)
241
+ n_hat_demon = torch.clamp(n_norm.unsqueeze(-1), min=1e-12)
242
+ n_hat = n / n_hat_demon
243
+
244
+ # ========================================
245
+ # 4. Angle Calculation (Robust atan2)
246
+ # ========================================
247
+
248
+ # Height of a1 relative to the plane (Projection onto normal)
249
+ # h = |a1| * sin(phi)
250
+ h = torch.sum(a1 * n_hat, dim=-1)
251
+
252
+ # Length of a1
253
+ a1_norm = torch.linalg.norm(a1)
254
+
255
+ # Projected length of a1 onto the plane
256
+ # r_proj = sqrt(|a1|^2 - h^2)
257
+ # Clamp to avoid negative sqrt due to numerical noise
258
+ r_proj_sq = a1_norm**2 - h**2
259
+ r_proj = torch.sqrt(torch.clamp(r_proj_sq, min=0.0))
260
+
261
+ # Calculate angle using atan2(height, projected_distance)
262
+ # Valid range: -90 to +90 degrees (or -pi/2 to pi/2)
263
+ angle = torch.atan2(h, r_proj)
264
+
265
+ # ========================================
266
+ # 5. Energy & Clamping
267
+ # ========================================
268
+
269
+ diff = angle - theta_0
270
+
271
+ energy_harmonic = 0.5 * k * diff**2
272
+
273
+ # Apply Guard:
274
+ # If the reference plane is undefined, force energy to 0.0 to prevent gradient explosion.
275
+ energy = torch.where(is_undefined_plane, torch.tensor(0.0, device=device, dtype=dtype), energy_harmonic)
276
+
277
+ return energy