aimnet 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aimnet/modules/__init__.py CHANGED
@@ -1,3 +1,3 @@
 from .aev import AEVSV, ConvSV  # noqa: F401
 from .core import MLP, AtomicShift, AtomicSum, Dipole, Embedding, Forces, Output, Quadrupole  # noqa: F401
-from .lr import D3TS, DFTD3, LRCoulomb  # noqa: F401
+from .lr import D3TS, DFTD3, LRCoulomb, SRCoulomb  # noqa: F401
aimnet/modules/aev.py CHANGED
@@ -1,10 +1,10 @@
 import math
-from typing import Dict, List, Optional, Tuple

 import torch
 from torch import Tensor, nn

 from aimnet import nbops, ops
+from aimnet.kernels import conv_sv_2d_sp


 class AEVSV(nn.Module):
@@ -37,12 +37,12 @@ class AEVSV(nn.Module):
         rmin: float = 0.8,
         rc_s: float = 5.0,
         nshifts_s: int = 16,
-        eta_s: Optional[float] = None,
-        rc_v: Optional[float] = None,
-        nshifts_v: Optional[int] = None,
-        eta_v: Optional[float] = None,
-        shifts_s: Optional[List[float]] = None,
-        shifts_v: Optional[List[float]] = None,
+        eta_s: float | None = None,
+        rc_v: float | None = None,
+        nshifts_v: int | None = None,
+        eta_v: float | None = None,
+        shifts_s: list[float] | None = None,
+        shifts_v: list[float] | None = None,
     ):
         super().__init__()

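The new annotations use PEP 604 union syntax (`float | None` instead of `Optional[float]`), which is evaluated at runtime only on Python 3.10+ (or under `from __future__ import annotations`). A minimal check of the equivalence, independent of this package:

from typing import Optional

# PEP 604: the builtin union operator on types produces an object that
# compares equal to the corresponding typing alias on CPython 3.10+.
print(float | None)                        # float | None
print((float | None) == Optional[float])   # True
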
@@ -78,31 +78,31 @@ class AEVSV(nn.Module):
             shifts = torch.as_tensor(shifts, dtype=torch.float)
             self.register_parameter("shifts" + mod, nn.Parameter(shifts, requires_grad=False))

-    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
-        # shapes (..., m) and (..., m, 3)
+    def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
+        # d_ij: distances to neighbors (..., m)
+        # r_ij: displacement vectors to neighbors (..., m, 3)
         d_ij, r_ij = ops.calc_distances(data)
         data["d_ij"] = d_ij
-        # shapes (..., nshifts, m) and (..., nshifts, 3, m)
-        u_ij, gs, gv = self._calc_aev(r_ij, d_ij, data)  # pylint: disable=unused-variable
-        # for now, do not save u_ij
-        data["gs"], data["gv"] = gs, gv
+        # Atomic environment vectors: (..., m, g, 4)
+        # 4 components = 1 scalar (radial) + 3 vector (directional)
+        g_sv = self._calc_aev(r_ij, d_ij, data)
+        data["g_sv"] = g_sv
         return data

-    def _calc_aev(self, r_ij: Tensor, d_ij: Tensor, data: Dict[str, Tensor]) -> Tuple[Tensor, Tensor, Tensor]:
+    def _calc_aev(self, r_ij: Tensor, d_ij: Tensor, data: dict[str, Tensor]) -> Tensor:
         fc_ij = ops.cosine_cutoff(d_ij, self.rc_s)  # (..., m)
         fc_ij = nbops.mask_ij_(fc_ij, data, 0.0)
-        gs = ops.exp_expand(d_ij, self.shifts_s, self.eta_s) * fc_ij.unsqueeze(
-            -1
-        )  # (..., m, nshifts) * (..., m, 1) -> (..., m, shitfs)
-        u_ij = r_ij / d_ij.unsqueeze(-1)  # (..., m, 3) / (..., m, 1) -> (..., m, 3)
-        if self._dual_basis:
-            fc_ij = ops.cosine_cutoff(d_ij, self.rc_v)
-            gsv = ops.exp_expand(d_ij, self.shifts_v, self.eta_v) * fc_ij.unsqueeze(-1)
-            gv = gsv.unsqueeze(-2) * u_ij.unsqueeze(-1)
-        else:
-            # (..., m, 1, shifts), (..., m, 3, 1) -> (..., m, 3, shifts)
-            gv = gs.unsqueeze(-2) * u_ij.unsqueeze(-1)
-        return u_ij, gs, gv
+        # Apply cutoff envelope to Gaussian-expanded distances
+        # Shape: (..., m, nshifts) * (..., m, 1) -> (..., m, nshifts)
+        gs = ops.exp_expand(d_ij, self.shifts_s, self.eta_s) * fc_ij.unsqueeze(-1)
+        # Normalize displacement vectors to unit direction vectors
+        # Shape: (..., m, 3) / (..., m, 1) -> (..., m, 3)
+        u_ij = r_ij / d_ij.unsqueeze(-1)
+        # Combine radial basis with directional info for vector features
+        # Shape: (..., m, nshifts, 1) * (..., m, 1, 3) -> (..., m, nshifts, 3)
+        gv = gs.unsqueeze(-1) * u_ij.unsqueeze(-2)
+        g_sv = torch.cat([gs.unsqueeze(-1), gv], dim=-1)
+        return g_sv


 class ConvSV(nn.Module):
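The refactored AEVSV packs the scalar and vector environment features into a single `g_sv` tensor instead of the separate `gs`/`gv` outputs of 0.0.1. A minimal shape sketch of the new layout, mirroring the diff's own expressions (the dummy sizes are assumptions):

import torch

m, g = 8, 16                       # neighbors, radial shifts (dummy sizes)
gs = torch.randn(m, g)             # scalar (radial) features
u_ij = torch.randn(m, 3)
u_ij = u_ij / u_ij.norm(dim=-1, keepdim=True)   # unit direction vectors

# As in the new _calc_aev: broadcast radial basis against directions,
# then concatenate the scalar channel in front of the 3 vector channels.
gv = gs.unsqueeze(-1) * u_ij.unsqueeze(-2)         # (m, g, 3)
g_sv = torch.cat([gs.unsqueeze(-1), gv], dim=-1)   # (m, g, 4)
assert g_sv.shape == (m, g, 4)

# The two views recovered downstream by ConvSV.forward:
avf_s, avf_v = g_sv.split([1, 3], dim=-1)          # (m, g, 1) and (m, g, 3)
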
@@ -116,8 +116,6 @@ class ConvSV(nn.Module):
         Number of feature channels for atomic features.
     d2features : bool, optional
         Flag indicating whether to use 2D features. Default is False.
-    do_vector : bool, optional
-        Flag indicating whether to perform vector convolution. Default is True.
     nshifts_v : Optional[int], optional
         Number of shifts for vector convolution. If not provided, defaults to the value of nshifts_s.
     ncomb_v : Optional[int], optional
@@ -129,16 +127,15 @@ class ConvSV(nn.Module):
         nshifts_s: int,
         nchannel: int,
         d2features: bool = False,
-        do_vector: bool = True,
-        nshifts_v: Optional[int] = None,
-        ncomb_v: Optional[int] = None,
+        nshifts_v: int | None = None,
+        ncomb_v: int | None = None,
     ):
         super().__init__()
         nshifts_v = nshifts_v or nshifts_s
         ncomb_v = ncomb_v or nshifts_v
         agh = _init_ahg(nchannel, nshifts_v, ncomb_v)
         self.register_parameter("agh", nn.Parameter(agh, requires_grad=True))
-        self.do_vector = do_vector
+        self.do_vector = True
         self.nchannel = nchannel
         self.d2features = d2features
         self.nshifts_s = nshifts_s
@@ -151,33 +148,37 @@ class ConvSV(nn.Module):
             n += self.nchannel * self.ncomb_v
         return n

-    def forward(self, a: Tensor, gs: Tensor, gv: Optional[Tensor] = None) -> Tensor:
-        avf = []
+    def forward(self, data: dict[str, Tensor], a: Tensor) -> Tensor:
+        g_sv = data["g_sv"]
+        mode = nbops.get_nb_mode(data)
         if self.d2features:
-            avf_s = torch.einsum("...mag,...mg->...ag", a, gs)
+            if mode > 0 and a.device.type == "cuda":
+                avf_sv = conv_sv_2d_sp(a, data["nbmat"], g_sv)
+            elif mode > 0:
+                a_j = a.index_select(0, data["nbmat"].flatten()).unflatten(0, data["nbmat"].shape)
+                avf_sv = torch.einsum("...mag,...mgd->...agd", a_j, g_sv)
+            else:
+                avf_sv = torch.einsum("...mag,...mgd->...agd", a.unsqueeze(1), g_sv)
         else:
-            avf_s = torch.einsum("...mg,...ma->...ag", gs, a)
-        avf.append(avf_s.flatten(-2, -1))
-        if self.do_vector:
-            assert gv is not None
-            agh = self.agh
-            if self.d2features:
-                avf_v = torch.einsum("...mag,...mdg,agh->...ahd", a, gv, agh)
+            if mode > 0:
+                a_j = a.index_select(0, data["nbmat"].flatten()).unflatten(0, data["nbmat"].shape)
+                avf_sv = torch.einsum("...ma,...mgd->...agd", a_j, g_sv)
             else:
-                avf_v = torch.einsum("...ma,...mdg,agh->...ahd", a, gv, agh)
-            avf.append(avf_v.pow(2).sum(-1).flatten(-2, -1))
-        return torch.cat(avf, dim=-1)
+                avf_sv = torch.einsum("...ma,...mgd->...agd", a.unsqueeze(1), g_sv)
+        avf_s, avf_v = avf_sv.split([1, 3], dim=-1)
+        avf_v = torch.einsum("agh,...agd->...ahd", self.agh, avf_v).pow(2).sum(-1)
+        return torch.cat([avf_s.squeeze(-1).flatten(-2, -1), avf_v.flatten(-2, -1)], dim=-1)


 def _init_ahg(b: int, m: int, n: int):
     ret = torch.zeros(b, m, n)
     for i in range(b):
-        ret[i] = _init_ahg_one(m, n)  # pylinit: disable-arguments-out-of-order
+        ret[i] = _init_ahg_one(m, n)  # pylint: disable=arguments-out-of-order
     return ret


 def _init_ahg_one(m: int, n: int):
-    # make x8 times more vectors to select most diverse
+    # Oversample by 8x to select most orthogonal vectors via maxmin algorithm
     x = torch.arange(m).unsqueeze(0)
     a1, a2, a3, a4 = torch.randn(8 * n, 4).unsqueeze(-2).unbind(-1)
     y = a1 * torch.sin(a2 * 2 * x * math.pi / m) + a3 * torch.cos(a4 * 2 * x * math.pi / m)
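For reference, the dense (`mode == 0`, non-d2features) path of the new `ConvSV.forward` reduces to two einsum contractions plus a split. A self-contained sketch with dummy sizes; the standalone variable names and sizes are assumptions, and `a` here is taken as already-gathered per-neighbor features:

import torch

n_atoms, m, a_ch, g, h = 4, 8, 32, 16, 16   # dummy sizes
a = torch.randn(n_atoms, m, a_ch)           # neighbor atomic features
g_sv = torch.randn(n_atoms, m, g, 4)        # packed scalar+vector AEV
agh = torch.randn(a_ch, g, h)               # learned combination weights

# "...ma,...mgd->...agd": sum over neighbors m for every (channel a, shift g, component d)
avf_sv = torch.einsum("...ma,...mgd->...agd", a, g_sv)   # (n_atoms, a_ch, g, 4)

avf_s, avf_v = avf_sv.split([1, 3], dim=-1)              # scalar / vector parts
# Vector part: mix shifts with agh, then take the rotation-invariant squared norm
avf_v = torch.einsum("agh,...agd->...ahd", agh, avf_v).pow(2).sum(-1)  # (n_atoms, a_ch, h)
out = torch.cat([avf_s.squeeze(-1).flatten(-2, -1), avf_v.flatten(-2, -1)], dim=-1)
assert out.shape == (n_atoms, a_ch * g + a_ch * h)
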
@@ -192,7 +193,7 @@ def _init_ahg_one(m: int, n: int):
     ret[0] = y[i]
     mask[i] = False

-    # simple maxmin impementation
+    # simple maxmin implementation
     for j in range(1, n):
         mindist, _ = torch.cdist(ret[:j], y).min(dim=0)
         maxidx = torch.argsort(mindist)[mask][-1]
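`_init_ahg_one` draws 8x more random sinusoidal rows than needed and keeps the `n` most mutually distant ones. The selection step is a plain maxmin (farthest-point) loop; a standalone sketch of the same idea (the function name and the arbitrary seeding are illustrative, not the package's API):

import torch

def maxmin_select(y: torch.Tensor, n: int) -> torch.Tensor:
    """Greedy farthest-point selection: pick n rows of y that are
    maximally spread out under Euclidean distance."""
    selected = [0]                           # seed with an arbitrary row
    for _ in range(1, n):
        # distance from every candidate to its nearest already-selected row
        mindist, _ = torch.cdist(y[selected], y).min(dim=0)
        mindist[selected] = -1.0             # never re-pick a selected row
        selected.append(int(mindist.argmax()))
    return y[selected]

pool = torch.randn(8 * 16, 16)               # 8x oversampled candidates
basis = maxmin_select(pool, 16)              # 16 most diverse rows
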
aimnet/modules/core.py CHANGED
@@ -1,4 +1,5 @@
-from typing import Any, Callable, Dict, List, Optional
+from collections.abc import Callable
+from typing import Any

 import torch
 from torch import Tensor, nn
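The `Callable` import moves from `typing` to `collections.abc`, following PEP 585, which deprecated the `typing` aliases once the `collections.abc` versions became subscriptable in Python 3.9. The two are interchangeable in annotations:

# Since Python 3.9, collections.abc.Callable supports subscription,
# so it can replace typing.Callable in annotations.
from collections.abc import Callable

def apply(fn: Callable[[int], int], x: int) -> int:
    return fn(x)

print(apply(lambda v: v * 2, 21))  # 42
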
@@ -10,9 +11,9 @@ from aimnet.config import get_init_module, get_module
 def MLP(
     n_in: int,
     n_out: int,
-    hidden: Optional[List[int]] = None,
+    hidden: list[int] | None = None,
     activation_fn: Callable | str = "torch.nn.GELU",
-    activation_kwargs: Optional[Dict[str, Any]] = None,
+    activation_kwargs: dict[str, Any] | None = None,
     weight_init_fn: Callable | str = "torch.nn.init.xavier_normal_",
     bias: bool = True,
     last_linear: bool = True,
@@ -44,7 +45,7 @@ def MLP(


 class Embedding(nn.Embedding):
-    def __init__(self, init: Optional[Dict[int, Any]] = None, **kwargs):
+    def __init__(self, init: dict[int, Any] | None = None, **kwargs):
         super().__init__(**kwargs)
         with torch.no_grad():
             if init is not None:
@@ -70,7 +71,7 @@ class DSequential(nn.Module):
         super().__init__()
         self.module = nn.ModuleList(modules)

-    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
         for m in self.module:
             data = m(data)
         return data
@@ -97,7 +98,7 @@ class AtomicShift(nn.Module):
     def extra_repr(self) -> str:
         return f"key_in: {self.key_in}, key_out: {self.key_out}"

-    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
         shifts = self.shifts(data["numbers"]).squeeze(-1)
         if self.reduce_sum:
             shifts = nbops.mol_sum(shifts, data)
@@ -114,13 +115,13 @@ class AtomicSum(nn.Module):
     def extra_repr(self) -> str:
         return f"key_in: {self.key_in}, key_out: {self.key_out}"

-    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
         data[self.key_out] = nbops.mol_sum(data[self.key_in], data)
         return data


 class Output(nn.Module):
-    def __init__(self, mlp: Dict | nn.Module, n_in: int, n_out: int, key_in: str, key_out: str):
+    def __init__(self, mlp: dict | nn.Module, n_in: int, n_out: int, key_in: str, key_out: str):
         super().__init__()
         self.key_in = key_in
         self.key_out = key_out
@@ -131,7 +132,7 @@ class Output(nn.Module):
     def extra_repr(self) -> str:
         return f"key_in: {self.key_in}, key_out: {self.key_out}"

-    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
         v = self.mlp(data[self.key_in]).squeeze(-1)
         if data["_input_padded"].item():
             v = nbops.mask_i_(v, data, mask_value=0.0)
@@ -147,7 +148,7 @@ class Forces(nn.Module):
         self.y = y
         self.key_out = key_out

-    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
         prev = torch.is_grad_enabled()
         torch.set_grad_enabled(True)
         data[self.x].requires_grad_(True)
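The Forces module differentiates an energy-like output with respect to coordinates inside a temporarily grad-enabled region. The underlying autograd pattern, sketched with a toy energy (the quadratic energy is an assumption standing in for the model):

import torch

coord = torch.randn(5, 3, requires_grad=True)   # atomic coordinates

# Toy energy: sum of squared distances from the origin.
energy = coord.pow(2).sum()

# Forces are the negative gradient of the energy w.r.t. coordinates.
forces = -torch.autograd.grad(energy, coord)[0]
assert forces.shape == coord.shape              # here forces == -2 * coord
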
@@ -171,7 +172,7 @@ class Dipole(nn.Module):
     def extra_repr(self) -> str:
         return f"key_in: {self.key_in}, key_out: {self.key_out}, center_coord: {self.center_coord}"

-    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
         q = data[self.key_in]
         r = data["coord"]
         if self.center_coord:
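Dipole implements the charge-weighted position sum mu = sum_i q_i * r_i (optionally about a charge center). A sketch for a single toy molecule (the charges and geometry are made up):

import torch

q = torch.tensor([0.4, -0.2, -0.2])           # per-atom charges
r = torch.tensor([[0.0, 0.0, 0.0],
                  [0.9, 0.0, 0.3],
                  [-0.9, 0.0, 0.3]])          # coordinates, (n_atoms, 3)

# Dipole moment: charge-weighted sum of positions
mu = (q.unsqueeze(-1) * r).sum(dim=0)         # (3,)
print(mu)                                     # tensor([0.0000, 0.0000, -0.1200])
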
@@ -184,7 +185,7 @@ class Quadrupole(Dipole):
     def __init__(self, key_in: str = "charges", key_out: str = "quadrupole", center_coord: bool = False):
         super().__init__(key_in=key_in, key_out=key_out, center_coord=center_coord)

-    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
         q = data[self.key_in]
         r = data["coord"]
         if self.center_coord:
@@ -215,7 +216,7 @@ class SRRep(nn.Module):
         self.params = nn.Embedding(87, 2, padding_idx=0, _weight=weight)
         self.params.weight.requires_grad_(False)

-    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
         p = self.params(data["numbers"])
         p_i, p_j = nbops.get_ij(p, data)
         p_ij = p_i * p_j