MultiOptPy 1.20.7__py3-none-any.whl → 1.20.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,70 +2,482 @@ from multioptpy.Parameters.parameter import UnitValueLib
2
2
  from multioptpy.Utils.calc_tools import torch_calc_angle_from_vec
3
3
 
4
4
  import torch
5
-
5
+ import math
6
6
 
7
7
  class StructKeepAnglePotential:
8
+ """
9
+ Harmonic angle potential implementation ensuring C1 continuity and numerical stability at linear boundaries.
10
+
11
+ Calculates the potential energy:
12
+ E = 0.5 * k * (theta - theta_0)^2
13
+
14
+ Standard implementations using `acos(cos_theta)` suffer from gradient instabilities and loss of precision
15
+ near theta = 0 (cos_theta = 1) and theta = pi (cos_theta = -1). This class addresses these issues by
16
+ switching to a 5th-order Taylor expansion of the arc-cosine function squared in these singular regions.
17
+
18
+ Attributes:
19
+ config (dict): Configuration parameters including spring constants and target angles.
20
+ THETA_CUT (float): The angular threshold (in radians) defining the boundary regions near 0 and 180 degrees.
21
+ Inside this threshold, the Taylor expansion is used. Default is 1e-3.
22
+ EPSILON_PARAM (float): Small epsilon for floating-point comparisons (e.g., checking if theta_0 is 0).
23
+ """
24
+
8
25
  def __init__(self, **kwarg):
26
+ """
27
+ Initialize the potential parameters.
28
+
29
+ Args:
30
+ **kwarg: Keyword arguments containing:
31
+ - keep_angle_spring_const (float): The force constant (k).
32
+ - keep_angle_angle (float): The equilibrium angle (theta_0) in degrees.
33
+ - keep_angle_atom_pairs (list of int): Indices of the three atoms forming the angle (1-based).
34
+ """
9
35
  self.config = kwarg
10
- UVL = UnitValueLib()
11
- self.hartree2kcalmol = UVL.hartree2kcalmol
12
- self.bohr2angstroms = UVL.bohr2angstroms
13
- self.hartree2kjmol = UVL.hartree2kjmol
14
- return
15
-
36
+
37
+ # Unit conversion placeholders (assuming UnitValueLib is defined elsewhere)
38
+ # UVL = UnitValueLib()
39
+ # self.hartree2kcalmol = UVL.hartree2kcalmol
40
+ # self.bohr2angstroms = UVL.bohr2angstroms
41
+ # self.hartree2kjmol = UVL.hartree2kjmol
42
+
43
+ # Numerical stability parameters
44
+ # 1e-3 rad is approximately 0.05 degrees.
45
+ # This threshold is sufficiently large to avoid 'acos' catastrophic cancellation
46
+ # while being small enough for the Taylor approximation to remain highly accurate.
47
+ self.THETA_CUT = 1e-3
48
+ self.EPSILON_PARAM = 1e-8
49
+
50
+ self.COEFFS = [
51
+ 128.0/1575.0,
52
+ 4.0/35.0,
53
+ 8.0/45.0,
54
+ 1.0/3.0,
55
+ 2.0
56
+ ]
57
+
16
58
  def calc_energy(self, geom_num_list, bias_pot_params=[]):
17
59
  """
18
- # required variables: self.config["keep_angle_atom_pairs"],
19
- self.config["keep_angle_spring_const"]
20
- self.config["keep_angle_angle"]
21
- bias_pot_params[0] : keep_angle_spring_const
22
- bias_pot_params[1] : keep_angle_angle
60
+ Compute the potential energy for the current geometry.
61
+
62
+ Args:
63
+ geom_num_list (torch.Tensor): Tensor of atomic coordinates (shape: [N_atoms, 3]).
64
+ bias_pot_params (list, optional): Overriding parameters [k, theta_0_deg].
65
+ If empty, uses values from self.config.
66
+
67
+ Returns:
68
+ torch.Tensor: The calculated potential energy (scalar).
23
69
  """
24
-
25
- vector1 = geom_num_list[self.config["keep_angle_atom_pairs"][0]-1] - geom_num_list[self.config["keep_angle_atom_pairs"][1]-1]
26
- vector2 = geom_num_list[self.config["keep_angle_atom_pairs"][2]-1] - geom_num_list[self.config["keep_angle_atom_pairs"][1]-1]
27
- theta = torch_calc_angle_from_vec(vector1, vector2)
70
+ # 1. Parse parameters
28
71
  if len(bias_pot_params) == 0:
29
- energy = 0.5 * self.config["keep_angle_spring_const"] * (theta - torch.deg2rad(torch.tensor(self.config["keep_angle_angle"]))) ** 2
72
+ k = self.config["keep_angle_spring_const"]
73
+ theta_0_deg = torch.tensor(self.config["keep_angle_angle"])
74
+ theta_0 = torch.deg2rad(theta_0_deg)
30
75
  else:
31
- energy = 0.5 * bias_pot_params[0] * (theta - torch.deg2rad(bias_pot_params[1])) ** 2
32
- return energy #hartree
76
+ k = bias_pot_params[0]
77
+ theta_0_deg = bias_pot_params[1]
78
+ if isinstance(theta_0_deg, torch.Tensor):
79
+ theta_0 = torch.deg2rad(theta_0_deg)
80
+ else:
81
+ theta_0 = torch.deg2rad(torch.tensor(theta_0_deg))
33
82
 
83
+ # 2. Setup device and constants
84
+ device = geom_num_list.device if isinstance(geom_num_list, torch.Tensor) else torch.device("cpu")
85
+ dtype = geom_num_list.dtype
86
+
87
+ PI = torch.tensor(math.pi, device=device, dtype=dtype)
88
+ theta_0 = theta_0.to(device=device, dtype=dtype)
89
+ theta_cut_val = torch.tensor(self.THETA_CUT, device=device, dtype=dtype)
90
+ epsilon_param = torch.tensor(self.EPSILON_PARAM, device=device, dtype=dtype)
91
+
92
+ # 3. Calculate cosine of the angle (u)
93
+ idx1 = self.config["keep_angle_atom_pairs"][0] - 1
94
+ idx2 = self.config["keep_angle_atom_pairs"][1] - 1
95
+ idx3 = self.config["keep_angle_atom_pairs"][2] - 1
96
+
97
+ vec1 = geom_num_list[idx1] - geom_num_list[idx2]
98
+ vec2 = geom_num_list[idx3] - geom_num_list[idx2]
99
+
100
+ norm1 = torch.linalg.norm(vec1)
101
+ norm2 = torch.linalg.norm(vec2)
102
+
103
+ # Avoid division by zero
104
+ norm1_2 = torch.clamp(norm1 * norm2, min=1e-12)
105
+
106
+ u = torch.dot(vec1, vec2) / norm1_2
107
+ u = torch.clamp(u, -1.0, 1.0) # Numerical clamp to stay within valid acos domain
108
+
109
+ # 4. Define Thresholds for Taylor Expansion
110
+ u_cut_pos = torch.cos(theta_cut_val) # Threshold near 0 degrees (u ~ 1)
111
+ u_cut_neg = torch.cos(PI - theta_cut_val) # Threshold near 180 degrees (u ~ -1)
112
+
113
+ # ==============================================
114
+ # Helper Functions: High-Order Taylor Expansions
115
+ # ==============================================
116
+
117
+ def theta_sq_taylor_at_0(u_val):
118
+ """
119
+ Compute theta^2 using a 5th-order Taylor expansion near u=1 (theta=0).
120
+ Expansion of (arccos(1-x))^2:
121
+ theta^2 approx 2x + x^2/3 + 8x^3/45 + 4x^4/35 + 128x^5/1575 + ...
122
+ where x = 1 - u.
123
+ """
124
+ delta = 1.0 - u_val
125
+ # Horner's method for efficiency and precision
126
+ term = self.COEFFS[0]
127
+ term = self.COEFFS[1] + delta * term
128
+ term = self.COEFFS[2] + delta * term
129
+ term = self.COEFFS[3] + delta * term
130
+ term = self.COEFFS[4] + delta * term
131
+ return delta * term
132
+
133
+ def theta_sq_taylor_at_pi(u_val):
134
+ """
135
+ Compute (pi - theta)^2 using a 5th-order Taylor expansion near u=-1 (theta=pi).
136
+ The coefficients are identical to the u=1 case, but delta = 1 + u.
137
+ """
138
+ delta = 1.0 + u_val
139
+ # Horner's method
140
+ term = self.COEFFS[0]
141
+ term = self.COEFFS[1] + delta * term
142
+ term = self.COEFFS[2] + delta * term
143
+ term = self.COEFFS[3] + delta * term
144
+ term = self.COEFFS[4] + delta * term
145
+ return delta * term
34
146
 
147
+ # ==============================================
148
+ # Energy Calculation Logic (3 Branches)
149
+ # ==============================================
150
+
151
+ # --- BRANCH A: Equilibrium Angle is approx 0 degrees ---
152
+ if torch.abs(theta_0) < epsilon_param:
153
+ # Region 1 (u > u_cut_pos): Near 0. Use Taylor expansion of theta^2.
154
+ theta_sq = theta_sq_taylor_at_0(u)
155
+ E_taylor_0 = 0.5 * k * theta_sq
156
+
157
+ # Region 2 (u < u_cut_neg): Near pi.
158
+ # theta approx pi - sqrt((pi-theta)^2).
159
+ diff_sq = theta_sq_taylor_at_pi(u)
160
+ sqrt_diff_sq = torch.sqrt(torch.clamp(diff_sq, min=1e-30))
161
+ theta_approx = PI - sqrt_diff_sq
162
+ E_taylor_pi = 0.5 * k * theta_approx**2
163
+
164
+ # Region 3: Normal region. Use acos.
165
+ u_safe = torch.clamp(u, u_cut_neg, u_cut_pos)
166
+ theta_exact = torch.acos(u_safe)
167
+ E_exact = 0.5 * k * theta_exact**2
168
+
169
+ return torch.where(
170
+ u > u_cut_pos,
171
+ E_taylor_0,
172
+ torch.where(u < u_cut_neg, E_taylor_pi, E_exact)
173
+ )
174
+
175
+ # --- BRANCH B: Equilibrium Angle is approx 180 degrees ---
176
+ elif torch.abs(theta_0 - PI) < epsilon_param:
177
+ # Region 1 (u < u_cut_neg): Near pi. Use Taylor expansion of (pi-theta)^2.
178
+ diff_sq = theta_sq_taylor_at_pi(u)
179
+ E_taylor_pi = 0.5 * k * diff_sq
180
+
181
+ # Region 2 (u > u_cut_pos): Near 0.
182
+ # theta approx sqrt(theta^2).
183
+ theta_sq = theta_sq_taylor_at_0(u)
184
+ sqrt_theta_sq = torch.sqrt(torch.clamp(theta_sq, min=1e-30))
185
+ theta_approx = sqrt_theta_sq
186
+ E_taylor_0 = 0.5 * k * (theta_approx - PI)**2
187
+
188
+ # Region 3: Normal region. Use acos.
189
+ u_safe = torch.clamp(u, u_cut_neg, u_cut_pos)
190
+ theta_exact = torch.acos(u_safe)
191
+ E_exact = 0.5 * k * (theta_exact - PI)**2
192
+
193
+ return torch.where(
194
+ u < u_cut_neg,
195
+ E_taylor_pi,
196
+ torch.where(u > u_cut_pos, E_taylor_0, E_exact)
197
+ )
198
+
199
+ # --- BRANCH C: General Equilibrium Angle (e.g., 109.5 deg) ---
200
+ else:
201
+ is_singular_0 = (u > u_cut_pos)
202
+ is_singular_pi = (u < u_cut_neg)
203
+
204
+ # Normal Region
205
+ theta_safe = torch.acos(u)
206
+ E_safe = 0.5 * k * (theta_safe - theta_0)**2
207
+
208
+ # Singular Region 0 (u ~ 1)
209
+ # Use Taylor to get high-precision theta approx, then compute energy.
210
+ theta_sq = theta_sq_taylor_at_0(u)
211
+ sqrt_theta_sq = torch.sqrt(torch.clamp(theta_sq, min=1e-30))
212
+ E_taylor_0 = 0.5 * k * (sqrt_theta_sq - theta_0)**2
213
+
214
+ # Singular Region Pi (u ~ -1)
215
+ diff_sq = theta_sq_taylor_at_pi(u)
216
+ sqrt_diff_sq = torch.sqrt(torch.clamp(diff_sq, min=1e-30))
217
+ theta_approx_pi = PI - sqrt_diff_sq
218
+ E_taylor_pi = 0.5 * k * (theta_approx_pi - theta_0)**2
219
+
220
+ return torch.where(
221
+ is_singular_0,
222
+ E_taylor_0,
223
+ torch.where(is_singular_pi, E_taylor_pi, E_safe)
224
+ )
35
225
 
36
226
  class StructKeepAnglePotentialv2:
227
+ r"""
228
+ Angle restraint potential operating on fragment centroids with robust singularity handling.
229
+
230
+ This class calculates the angle potential :math:`E = 0.5 \cdot k \cdot (\theta - \theta_0)^2`
231
+ defined by the geometric centers (centroids) of three atom fragments.
232
+
233
+ It implements a hybrid singularity handling strategy to ensure numerical stability
234
+ (C1 continuity) and physical accuracy across all configurations, particularly
235
+ when the angle approaches 0 or 180 degrees.
236
+
237
+ **Singularity Handling Strategies:**
238
+
239
+ 1. **High-Order Taylor Expansion (Physical Accuracy)**:
240
+ Applied when the angle approaches a singularity (0 or 180) that **COINCIDES** with
241
+ the equilibrium angle :math:`\\theta_0`.
242
+ * *Context:* Linear (:math:`\\theta_0=180`) or hypothetical collapsed (:math:`\\theta_0=0`) molecules.
243
+ * *Method:* Uses a 5th-order Taylor expansion of :math:`\\arccos(u)^2` to strictly preserve
244
+ the physical curvature (Hessian) without precision loss from `acos`.
245
+
246
+ 2. **Quadratic Extrapolation (Numerical Robustness)**:
247
+ Applied when the angle approaches a singularity (0 or 180) that is **FAR** from
248
+ the equilibrium angle.
249
+ * *Context:* A linear molecule (:math:`\\theta_0=180`) being bent towards 0 degrees.
250
+ * *Method:* Replaces the physically diverging force of the harmonic potential (where
251
+ force -> infinity as sin(theta) -> 0) with a finite quadratic barrier in cos-space.
252
+ This prevents gradient explosions in MD simulations.
253
+
254
+ Attributes:
255
+ config (dict): Configuration dictionary containing potential parameters.
256
+ THETA_CUT (float): The angular threshold (in radians) defining the boundary regions.
257
+ Set to 1e-3 (~0.05 deg) to ensure stable transition before `acos` precision loss.
258
+ EPSILON_PARAM (float): Tolerance for determining if theta_0 is exactly 0 or 180.
259
+ """
260
+
37
261
  def __init__(self, **kwarg):
262
+ """
263
+ Initialize the potential parameters.
264
+
265
+ Args:
266
+ **kwarg: Keyword arguments. Must include:
267
+ - keep_angle_v2_spring_const (float): Force constant (k).
268
+ - keep_angle_v2_angle (float): Equilibrium angle in degrees.
269
+ - keep_angle_v2_fragm1 (list[int]): Atom indices for fragment 1.
270
+ - keep_angle_v2_fragm2 (list[int]): Atom indices for fragment 2 (vertex).
271
+ - keep_angle_v2_fragm3 (list[int]): Atom indices for fragment 3.
272
+ """
38
273
  self.config = kwarg
39
- UVL = UnitValueLib()
40
- self.hartree2kcalmol = UVL.hartree2kcalmol
41
- self.bohr2angstroms = UVL.bohr2angstroms
42
- self.hartree2kjmol = UVL.hartree2kjmol
43
- return
274
+
275
+ # Unit conversion placeholders (assuming UnitValueLib context)
276
+ # UVL = UnitValueLib()
277
+ # self.hartree2kcalmol = UVL.hartree2kcalmol
278
+ # self.bohr2angstroms = UVL.bohr2angstroms
279
+ # self.hartree2kjmol = UVL.hartree2kjmol
280
+
281
+ # Thresholds
282
+ self.THETA_CUT = 1e-3
283
+ self.EPSILON_PARAM = 1e-8
284
+
285
+ self.COEFFS = [
286
+ 128.0/1575.0,
287
+ 4.0/35.0,
288
+ 8.0/45.0,
289
+ 1.0/3.0,
290
+ 2.0
291
+ ]
292
+
44
293
  def calc_energy(self, geom_num_list, bias_pot_params=[]):
45
294
  """
46
- # required variables: self.config["keep_angle_v2_spring_const"],
47
- self.config["keep_angle_v2_angle"],
48
- self.config["keep_angle_v2_fragm1"],
49
- self.config["keep_angle_v2_fragm2"],
50
- self.config["keep_angle_v2_fragm3"],
51
- bias_pot_params[0] : keep_angle_v2_spring_const
52
- bias_pot_params[1] : keep_angle_v2_angle
53
-
295
+ Compute the potential energy based on fragment centroids.
296
+
297
+ Args:
298
+ geom_num_list (torch.Tensor): Atomic coordinates (N_atoms, 3).
299
+ bias_pot_params (list, optional): [k, theta_0] override.
300
+
301
+ Returns:
302
+ torch.Tensor: Potential energy.
54
303
  """
55
- fragm_1_center = torch.mean(geom_num_list[torch.tensor(self.config["keep_angle_v2_fragm1"]) - 1], dim=0)
56
- fragm_2_center = torch.mean(geom_num_list[torch.tensor(self.config["keep_angle_v2_fragm2"]) - 1], dim=0)
57
- fragm_3_center = torch.mean(geom_num_list[torch.tensor(self.config["keep_angle_v2_fragm3"]) - 1], dim=0)
58
-
59
- vector1 = fragm_1_center - fragm_2_center
60
- vector2 = fragm_3_center - fragm_2_center
61
- theta = torch_calc_angle_from_vec(vector1, vector2)
304
+ # 1. Parameter Retrieval
62
305
  if len(bias_pot_params) == 0:
63
- energy = 0.5 * self.config["keep_angle_v2_spring_const"] * (theta - torch.deg2rad(torch.tensor(self.config["keep_angle_v2_angle"]))) ** 2
306
+ k = self.config["keep_angle_v2_spring_const"]
307
+ theta_0_deg = torch.tensor(self.config["keep_angle_v2_angle"])
308
+ theta_0 = torch.deg2rad(theta_0_deg)
64
309
  else:
65
- energy = 0.5 * bias_pot_params[0] * (theta - torch.deg2rad(bias_pot_params[1])) ** 2
66
- return energy #hartree
310
+ k = bias_pot_params[0]
311
+ theta_0_deg = bias_pot_params[1]
312
+ if isinstance(theta_0_deg, torch.Tensor):
313
+ theta_0 = torch.deg2rad(theta_0_deg)
314
+ else:
315
+ theta_0 = torch.deg2rad(torch.tensor(theta_0_deg))
316
+
317
+ device = geom_num_list.device if isinstance(geom_num_list, torch.Tensor) else torch.device("cpu")
318
+ dtype = geom_num_list.dtype
319
+
320
+ PI = torch.tensor(math.pi, device=device, dtype=dtype)
321
+ theta_0 = theta_0.to(device=device, dtype=dtype)
322
+ theta_cut_val = torch.tensor(self.THETA_CUT, device=device, dtype=dtype)
323
+ epsilon_param = torch.tensor(self.EPSILON_PARAM, device=device, dtype=dtype)
67
324
 
325
+ # 2. Centroid & Vector Calculation
326
+ def get_centroid(key):
327
+ # Convert 1-based config indices to 0-based
328
+ indices = torch.tensor(self.config[key], device=device, dtype=torch.long) - 1
329
+ return torch.mean(geom_num_list[indices], dim=0)
330
+
331
+ fragm_1_center = get_centroid("keep_angle_v2_fragm1")
332
+ fragm_2_center = get_centroid("keep_angle_v2_fragm2") # Vertex
333
+ fragm_3_center = get_centroid("keep_angle_v2_fragm3")
334
+
335
+ vec1 = fragm_1_center - fragm_2_center
336
+ vec2 = fragm_3_center - fragm_2_center
337
+
338
+ norm1 = torch.linalg.norm(vec1)
339
+ norm2 = torch.linalg.norm(vec2)
340
+
341
+ # Prevent division by zero if centroids overlap
342
+ norm1_2 = torch.clamp(norm1 * norm2, min=1e-12)
343
+
344
+ u = torch.dot(vec1, vec2) / norm1_2
345
+ u = torch.clamp(u, -1.0, 1.0) # Numerical clamp for acos domain
346
+
347
+ # ========================================
348
+ # 3. Expansion Helper Functions
349
+ # ========================================
350
+
351
+ # Thresholds in cosine space
352
+ u_cut_pos = torch.cos(theta_cut_val) # Near 0 deg
353
+ u_cut_neg = torch.cos(PI - theta_cut_val) # Near 180 deg
354
+
355
+ def theta_sq_taylor_at_0(u_val):
356
+ """5th-order Taylor expansion of theta^2 near 0 (u=1)."""
357
+ delta = 1.0 - u_val
358
+ # Horner's method
359
+ term = self.COEFFS[0]
360
+ term = self.COEFFS[1] + delta * term
361
+ term = self.COEFFS[2] + delta * term
362
+ term = self.COEFFS[3] + delta * term
363
+ term = self.COEFFS[4] + delta * term
364
+ return delta * term
365
+
366
+ def theta_sq_taylor_at_pi(u_val):
367
+ """5th-order Taylor expansion of (pi-theta)^2 near 180 (u=-1)."""
368
+ delta = 1.0 + u_val
369
+ # Horner's method
370
+ term = self.COEFFS[0]
371
+ term = self.COEFFS[1] + delta * term
372
+ term = self.COEFFS[2] + delta * term
373
+ term = self.COEFFS[3] + delta * term
374
+ term = self.COEFFS[4] + delta * term
375
+ return delta * term
376
+
377
+ def get_quad_params(th_cut):
378
+ """
379
+ Compute parameters for Quadratic Extrapolation V(u) = A + B*du + 0.5*C*du^2.
380
+ This matches the Value (V) and Slope (dV/du) at the cutoff point.
381
+ """
382
+ sin_cut = torch.sin(th_cut)
383
+ # Chain rule: d(theta)/du = -1/sin(theta)
384
+ dth_du = -1.0 / sin_cut
385
+
386
+ # Value at cutoff
387
+ val = 0.5 * k * (th_cut - theta_0)**2
388
+
389
+ # First derivative wrt theta
390
+ dE_dth = k * (th_cut - theta_0)
391
+
392
+ # First derivative wrt u (Slope B)
393
+ d1 = dE_dth * dth_du
394
+
395
+ # Second derivative approximation (Gauss-Newton style: H ~ J^T J)
396
+ # We approximate d2E/du2 to ensure positive curvature and stability.
397
+ d2 = k * (dth_du**2)
398
+
399
+ return val, d1, d2
400
+
401
+ # ========================================
402
+ # 4. Energy Calculation Branches
403
+ # ========================================
404
+
405
+ # --- BRANCH A: EXACTLY Linear Equilibrium (theta_0 ~ 0) ---
406
+ if torch.abs(theta_0) < epsilon_param:
407
+ # Region 1 (u ~ 1): Equilibrium Singularity -> Taylor Expansion
408
+ # We want V = 0.5*k*theta^2.
409
+ theta_sq = theta_sq_taylor_at_0(u)
410
+ E_taylor = 0.5 * k * theta_sq
411
+
412
+ # Region 2 (u ~ -1): Antipodal Singularity -> Quadratic Extrapolation
413
+ # Molecule is bent to 180, but wants to be 0. Force is huge.
414
+ theta_cut_pi = PI - theta_cut_val
415
+ val_pi, d1_pi, d2_pi = get_quad_params(theta_cut_pi)
416
+ diff_pi = u - u_cut_neg
417
+ E_quad_pi = val_pi + d1_pi * diff_pi + 0.5 * d2_pi * (diff_pi**2)
418
+
419
+ # Region 3: Normal
420
+ u_safe = torch.clamp(u, -1.0, u_cut_pos)
421
+ theta_exact = torch.acos(u_safe)
422
+ E_exact = 0.5 * k * (theta_exact ** 2)
423
+
424
+ return torch.where(
425
+ u > u_cut_pos,
426
+ E_taylor,
427
+ torch.where(u < u_cut_neg, E_quad_pi, E_exact)
428
+ )
429
+
430
+ # --- BRANCH B: EXACTLY Planar Equilibrium (theta_0 ~ 180) ---
431
+ elif torch.abs(theta_0 - PI) < epsilon_param:
432
+ # Region 1 (u ~ -1): Equilibrium Singularity -> Taylor Expansion
433
+ # We want V = 0.5*k*(pi - theta)^2.
434
+ diff_sq_taylor = theta_sq_taylor_at_pi(u)
435
+ E_taylor = 0.5 * k * diff_sq_taylor
436
+
437
+ # Region 2 (u ~ 1): Antipodal Singularity -> Quadratic Extrapolation
438
+ # Molecule is bent to 0, but wants to be 180.
439
+ val_0, d1_0, d2_0 = get_quad_params(theta_cut_val)
440
+ diff_0 = u - u_cut_pos
441
+ E_quad_0 = val_0 + d1_0 * diff_0 + 0.5 * d2_0 * (diff_0**2)
442
+
443
+ # Region 3: Normal
444
+ u_safe = torch.clamp(u, u_cut_neg, 1.0)
445
+ theta_exact = torch.acos(u_safe)
446
+ E_exact = 0.5 * k * (theta_exact - theta_0) ** 2
447
+
448
+ return torch.where(
449
+ u < u_cut_neg,
450
+ E_taylor,
451
+ torch.where(u > u_cut_pos, E_quad_0, E_exact)
452
+ )
453
+
454
+ # --- BRANCH C: General Angle (e.g., 109.5) ---
455
+ else:
456
+ is_singular_0 = (u > u_cut_pos)
457
+ is_singular_pi = (u < u_cut_neg)
458
+
459
+ # Normal calculation
460
+ theta_safe = torch.acos(u)
461
+ E_safe = 0.5 * k * (theta_safe - theta_0) ** 2
462
+
463
+ # Extrapolation near 0
464
+ val_0, d1_0, d2_0 = get_quad_params(theta_cut_val)
465
+ diff_0 = u - u_cut_pos
466
+ E_quad_0 = val_0 + d1_0 * diff_0 + 0.5 * d2_0 * (diff_0**2)
467
+
468
+ # Extrapolation near 180
469
+ theta_cut_pi = PI - theta_cut_val
470
+ val_pi, d1_pi, d2_pi = get_quad_params(theta_cut_pi)
471
+ diff_pi = u - u_cut_neg
472
+ E_quad_pi = val_pi + d1_pi * diff_pi + 0.5 * d2_pi * (diff_pi**2)
68
473
 
474
+ return torch.where(
475
+ is_singular_0,
476
+ E_quad_0,
477
+ torch.where(is_singular_pi, E_quad_pi, E_safe)
478
+ )
479
+
480
+
69
481
  class StructKeepAnglePotentialAtomDistDependent:
70
482
  def __init__(self, **kwarg):
71
483
  self.config = kwarg