aimnet 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,222 @@
1
+ import numba
2
+ import numpy as np
3
+
4
+
5
@numba.njit(cache=True, parallel=False, fastmath=True)
def _nbmat_dual_cpu(
    coord: np.ndarray,  # float, (N, 3)
    cutoff1_squared: float,
    cutoff2_squared: float,
    mol_idx: np.ndarray,  # int, (N,)
    mol_end_idx: np.ndarray,  # int, (M,)
    nbmat1: np.ndarray,  # int, (N, maxnb1)
    nbmat2: np.ndarray,  # int, (N, maxnb2)
    nnb1: np.ndarray,  # int, zeros, (N,)
    nnb2: np.ndarray,  # int, zeros, (N,)
):
    """Fill two neighbor matrices (one per squared cutoff) in place.

    For every atom ``a`` only partners ``b`` in ``a + 1 .. mol_end_idx[mol_idx[a]] - 1``
    are tested, so each pair is visited once; the reverse entries are added
    afterwards by ``_expand_nb``.  ``nnb1``/``nnb2`` keep counting even when a
    row is full, which lets the caller detect overflow.
    """
    capacity1 = nbmat1.shape[1]
    capacity2 = nbmat2.shape[1]
    n_atoms = coord.shape[0]

    for a in range(n_atoms):
        xa = coord[a, 0]
        ya = coord[a, 1]
        za = coord[a, 2]
        # Partners are limited to the remainder of the molecule that owns `a`.
        stop = mol_end_idx[mol_idx[a]]
        for b in range(a + 1, stop):
            dx = xa - coord[b, 0]
            dy = ya - coord[b, 1]
            dz = za - coord[b, 2]
            dist2 = dx * dx + dy * dy + dz * dz

            if dist2 < cutoff1_squared:
                slot = nnb1[a]
                nnb1[a] = slot + 1
                if slot < capacity1:
                    nbmat1[a, slot] = b

            # Tested independently of the first cutoff: neither has to be the larger one.
            if dist2 < cutoff2_squared:
                slot = nnb2[a]
                nnb2[a] = slot + 1
                if slot < capacity2:
                    nbmat2[a, slot] = b

    # Turn the half lists (a -> b with b > a) into full, symmetric lists.
    _expand_nb(nnb1, nbmat1)
    _expand_nb(nnb2, nbmat2)
41
+
42
+
43
@numba.njit(cache=True, parallel=False, fastmath=True)
def _nbmat_cpu(
    coord: np.ndarray,  # float, (N, 3)
    cutoff1_squared: float,
    mol_idx: np.ndarray,  # int, (N,)
    mol_end_idx: np.ndarray,  # int, (M,)
    nbmat1: np.ndarray,  # int, (N, maxnb1)
    nnb1: np.ndarray,  # int, zeros, (N,)
):
    """Fill a single neighbor matrix in place for one squared cutoff.

    Each atom ``a`` is compared with the atoms ``b > a`` up to the end index of
    its molecule; ``_expand_nb`` then mirrors the stored pairs.  ``nnb1`` is
    incremented for every hit, even once the row of ``nbmat1`` is full.
    """
    capacity = nbmat1.shape[1]
    n_atoms = coord.shape[0]

    for a in range(n_atoms):
        xa = coord[a, 0]
        ya = coord[a, 1]
        za = coord[a, 2]
        stop = mol_end_idx[mol_idx[a]]
        for b in range(a + 1, stop):
            dx = xa - coord[b, 0]
            dy = ya - coord[b, 1]
            dz = za - coord[b, 2]
            dist2 = dx * dx + dy * dy + dz * dz
            if dist2 < cutoff1_squared:
                slot = nnb1[a]
                nnb1[a] = slot + 1
                if slot < capacity:
                    nbmat1[a, slot] = b

    # Add the reverse (b -> a) entries.
    _expand_nb(nnb1, nbmat1)
69
+
70
+
71
@numba.njit(cache=True, inline="always")
def _expand_nb(nnb, nbmat):
    """Mirror a half neighbor list in place so that every pair is stored both ways.

    ``nnb`` holds per-atom neighbor counts and ``nbmat`` the neighbor indices.
    For every stored entry ``src -> dst`` the entry ``dst -> src`` is appended
    to row ``dst`` (if there is room) and ``nnb[dst]`` is incremented.
    """
    n_atoms = nnb.shape[0]
    capacity = nbmat.shape[1]
    # Snapshot of the counts before mirroring: only the original entries are read.
    half_counts = nnb.copy()

    for src in range(n_atoms):
        # Counts may exceed the row capacity on overflow; only stored columns exist.
        n_stored = half_counts[src]
        if n_stored > capacity:
            n_stored = capacity
        for col in range(n_stored):
            dst = nbmat[src, col]
            if dst >= n_atoms:
                # Index outside the atom range (e.g. a padding value): nothing to mirror.
                continue
            slot = nnb[dst]
            nnb[dst] = slot + 1
            if slot < capacity:
                nbmat[dst, slot] = src
85
+
86
+
87
@numba.njit(cache=True, inline="always")
def _expand_nb_pbc(nnb, nbmat, shifts):
    """Mirror a half neighbor list with periodic shifts, in place.

    For every stored entry ``src -> dst`` with shift ``s`` the reverse entry
    ``dst -> src`` is appended with shift ``-s``.  ``nnb[dst]`` is incremented
    even when row ``dst`` is already full.
    """
    n_atoms = nnb.shape[0]
    capacity = nbmat.shape[1]
    n_comp = shifts.shape[2]
    # Only entries present before mirroring are expanded.
    half_counts = nnb.copy()

    for src in range(n_atoms):
        n_stored = half_counts[src]
        if n_stored > capacity:
            n_stored = capacity
        for col in range(n_stored):
            dst = nbmat[src, col]
            if dst >= n_atoms:
                continue
            slot = nnb[dst]
            nnb[dst] = slot + 1
            if slot < capacity:
                nbmat[dst, slot] = src
                # Seen from `dst`, the image of `src` lies in the opposite direction.
                for d in range(n_comp):
                    shifts[dst, slot, d] = -shifts[src, col, d]
103
+
104
+
105
@numba.njit(cache=True)
def _expand_shifts(nshift):
    """Enumerate one half of the integer shift vectors within ``+-nshift``.

    A shift ``(k1, k2, k3)`` is kept when it is lexicographically
    non-negative, i.e. ``k1 > 0``, or ``k1 == 0 and k2 > 0``, or
    ``k1 == k2 == 0 and k3 >= 0``.  The zero shift is therefore included and
    no kept shift is the negative of another.  Returns a float32 array of
    shape (S, 3).
    """
    n0, n1, n2 = nshift[0], nshift[1], nshift[2]
    # Upper bound on the number of kept shifts; the table is trimmed before returning.
    capacity = (n0 + 1) * (2 * n1 + 1) * (2 * n2 + 1)
    table = np.zeros((capacity, 3), dtype=np.float32)

    count = 0
    # Negative k1 never satisfies the selection rule, so start at zero.
    for k1 in range(n0 + 1):
        for k2 in range(-n1, n1 + 1):
            if k1 == 0 and k2 < 0:
                continue
            for k3 in range(-n2, n2 + 1):
                if k1 == 0 and k2 == 0 and k3 < 0:
                    continue
                table[count, 0] = k1
                table[count, 1] = k2
                table[count, 2] = k3
                count += 1
    return table[:count]
120
+
121
+
122
@numba.njit(cache=True, parallel=False, fastmath=True)
def shift_coords(coord, cell, shifts):
    """Return every atom translated by every shift.

    The result has shape (N, S, 3) and the dtype of ``coord``; element
    ``[a, k]`` equals ``coord[a] + shifts[k] @ cell`` (rows of ``cell`` are
    scaled by the shift components).
    """
    n_atoms = coord.shape[0]
    n_shifts = shifts.shape[0]
    out = np.empty((n_atoms, n_shifts, 3), dtype=coord.dtype)

    for a in range(n_atoms):
        for k in range(n_shifts):
            s0 = shifts[k, 0]
            s1 = shifts[k, 1]
            s2 = shifts[k, 2]
            for d in range(3):
                out[a, k, d] = coord[a, d] + s0 * cell[0, d] + s1 * cell[1, d] + s2 * cell[2, d]
    return out
136
+
137
+
138
@numba.njit(cache=True, parallel=False, fastmath=True)
def _nbmat_pbc_cpu(
    coord: np.ndarray,  # float, (N, 3)
    cell: np.ndarray,  # float, (3, 3)
    cutoff1_squared: float,
    shifts: np.ndarray,  # float, (S, 3)
    nnb1: np.ndarray,  # int, zeros, (N,)
    nbmat1: np.ndarray,  # int, (N, M)
    shifts1: np.ndarray,  # int, (N, M, 3)
):
    """Fill a periodic neighbor matrix and the matching shift table in place.

    Atom ``a`` is compared with the image ``coord[b] + shifts[k] @ cell`` of
    every atom ``b``.  For the zero shift only ``b < a`` is visited, so a pair
    inside the home cell is found once and an atom is never paired with
    itself; for any other shift all ``b`` (including ``a``) are visited.  The
    reverse entries are added by ``_expand_nb_pbc``.
    """
    capacity = nbmat1.shape[1]
    n_atoms = coord.shape[0]
    n_shifts = shifts.shape[0]

    images = shift_coords(coord, cell, shifts)

    for a in range(n_atoms):
        xa = coord[a, 0]
        ya = coord[a, 1]
        za = coord[a, 2]
        for k in range(n_shifts):
            is_home = shifts[k, 0] == 0 and shifts[k, 1] == 0 and shifts[k, 2] == 0
            if is_home:
                stop = a
            else:
                stop = n_atoms
            for b in range(stop):
                dx = xa - images[b, k, 0]
                dy = ya - images[b, k, 1]
                dz = za - images[b, k, 2]
                r2 = dx * dx + dy * dy + dz * dz
                if r2 < cutoff1_squared:
                    slot = nnb1[a]
                    nnb1[a] = slot + 1
                    if slot < capacity:
                        nbmat1[a, slot] = b
                        shifts1[a, slot] = shifts[k]

    _expand_nb_pbc(nnb1, nbmat1, shifts1)
173
+
174
+
175
@numba.njit(cache=True, parallel=False, fastmath=True)
def _nbmat_dual_pbc_cpu(
    coord: np.ndarray,  # float, (N, 3)
    cell: np.ndarray,  # float, (3, 3)
    cutoff1_squared: float,
    cutoff2_squared: float,
    shifts: np.ndarray,  # float, (S, 3)
    nnb1: np.ndarray,  # int, zeros, (N,)
    nnb2: np.ndarray,  # int, zeros, (N,)
    nbmat1: np.ndarray,  # int, (N, M)
    nbmat2: np.ndarray,  # int, (N, K)
    shifts1: np.ndarray,  # int, (N, M, 3)
    shifts2: np.ndarray,  # int, (N, K, 3)
):
    """Fill two periodic neighbor matrices (one per squared cutoff) in place.

    Same traversal as the single-cutoff periodic kernel: for the zero shift
    only ``b < a`` is visited, for every other shift all atoms are visited.
    Each hit is recorded with the shift that produced it; the reverse entries
    are added by ``_expand_nb_pbc``.
    """
    capacity1 = nbmat1.shape[1]
    capacity2 = nbmat2.shape[1]
    n_atoms = coord.shape[0]
    n_shifts = shifts.shape[0]

    images = shift_coords(coord, cell, shifts)

    for a in range(n_atoms):
        xa = coord[a, 0]
        ya = coord[a, 1]
        za = coord[a, 2]
        for k in range(n_shifts):
            is_home = shifts[k, 0] == 0 and shifts[k, 1] == 0 and shifts[k, 2] == 0
            if is_home:
                stop = a
            else:
                stop = n_atoms
            for b in range(stop):
                dx = xa - images[b, k, 0]
                dy = ya - images[b, k, 1]
                dz = za - images[b, k, 2]
                r2 = dx * dx + dy * dy + dz * dz

                if r2 < cutoff1_squared:
                    slot = nnb1[a]
                    nnb1[a] = slot + 1
                    if slot < capacity1:
                        nbmat1[a, slot] = b
                        shifts1[a, slot] = shifts[k]

                if r2 < cutoff2_squared:
                    slot = nnb2[a]
                    nnb2[a] = slot + 1
                    if slot < capacity2:
                        nbmat2[a, slot] = b
                        shifts2[a, slot] = shifts[k]

    _expand_nb_pbc(nnb1, nbmat1, shifts1)
    _expand_nb_pbc(nnb2, nbmat2, shifts2)
@@ -0,0 +1,217 @@
1
+ # type: ignore
2
+ import numba
3
+ import numba.cuda
4
+ from numba.core import config
5
+
6
# Disable Numba's low-occupancy performance warnings for the CUDA kernels below.
# NOTE(review): presumably because small systems are launched on very small
# grids, which would otherwise warn on every call — confirm.
config.CUDA_LOW_OCCUPANCY_WARNINGS = 0
7
+
8
+
9
@numba.cuda.jit(fastmath=True, cache=True)
def _nbmat_dual_cuda(coord, cutoff1_squared, cutoff2_squared, mol_idx, mol_end_idx, nbmat1, nbmat2, nnb1, nnb2):
    """CUDA kernel: build two neighbor matrices, one thread per atom.

    Thread ``tid`` compares atom ``tid`` with the atoms after it up to the end
    index of its molecule and writes both directions of every pair.  Slots
    are reserved with atomic increments of the neighbor counters, so the
    order of entries within a row is not deterministic.
    """
    tid = numba.cuda.grid(1)
    if tid >= coord.shape[0]:
        return

    capacity1 = nbmat1.shape[1]
    capacity2 = nbmat2.shape[1]

    x0 = coord[tid, 0]
    y0 = coord[tid, 1]
    z0 = coord[tid, 2]

    last = mol_end_idx[mol_idx[tid]]
    for other in range(tid + 1, last):
        dx = x0 - coord[other, 0]
        dy = y0 - coord[other, 1]
        dz = z0 - coord[other, 2]
        dist_squared = dx * dx + dy * dy + dz * dz

        if dist_squared < cutoff1_squared:
            # atomic.add returns the previous count, i.e. the slot to fill.
            slot = numba.cuda.atomic.add(nnb1, tid, 1)
            if slot < capacity1:
                nbmat1[tid, slot] = other
            slot = numba.cuda.atomic.add(nnb1, other, 1)
            if slot < capacity1:
                nbmat1[other, slot] = tid

        if dist_squared < cutoff2_squared:
            slot = numba.cuda.atomic.add(nnb2, tid, 1)
            if slot < capacity2:
                nbmat2[tid, slot] = other
            slot = numba.cuda.atomic.add(nnb2, other, 1)
            if slot < capacity2:
                nbmat2[other, slot] = tid
47
+
48
+
49
@numba.cuda.jit(fastmath=True, cache=True)
def _nbmat_cuda(coord, cutoff1_squared, mol_idx, mol_end_idx, nbmat1, nnb1):
    """CUDA kernel: build one neighbor matrix, one thread per atom.

    Thread ``tid`` tests atom ``tid`` against the following atoms of the same
    molecule and stores each pair in both rows, reserving slots with atomic
    increments of ``nnb1``.
    """
    tid = numba.cuda.grid(1)
    if tid >= coord.shape[0]:
        return

    capacity = nbmat1.shape[1]

    x0 = coord[tid, 0]
    y0 = coord[tid, 1]
    z0 = coord[tid, 2]

    last = mol_end_idx[mol_idx[tid]]
    for other in range(tid + 1, last):
        dx = x0 - coord[other, 0]
        dy = y0 - coord[other, 1]
        dz = z0 - coord[other, 2]
        dist_squared = dx * dx + dy * dy + dz * dz
        if dist_squared < cutoff1_squared:
            # The counter keeps growing past capacity so overflow stays detectable.
            slot = numba.cuda.atomic.add(nnb1, tid, 1)
            if slot < capacity:
                nbmat1[tid, slot] = other
            slot = numba.cuda.atomic.add(nnb1, other, 1)
            if slot < capacity:
                nbmat1[other, slot] = tid
79
+
80
+
81
@numba.cuda.jit(cache=True, fastmath=True)
def _nbmat_pbc_dual_cuda(
    coord,  # N, 3
    cell,  # 3, 3
    cutoff1_squared: float,
    cutoff2_squared: float,
    shifts,  # S, 3
    nnb1,  # N
    nnb2,  # N
    nbmat1,  # N, M
    nbmat2,  # N, K
    shifts1,  # N, M, 3
    shifts2,  # N, K, 3
):
    """CUDA kernel: two periodic neighbor matrices, one thread per (shift, atom).

    The flat thread index is split into a shift index and an atom index.  The
    thread translates its atom by its shift and compares the image with every
    atom ``other``; for the zero shift only ``other < atom`` is visited.  Both
    directions of a hit are written, with opposite shift signs.
    """
    flat = numba.cuda.grid(1)

    n_atoms = coord.shape[0]
    n_shifts = shifts.shape[0]

    k = flat // n_atoms
    atom = flat % n_atoms
    if k >= n_shifts:
        return

    capacity1 = nbmat1.shape[1]
    capacity2 = nbmat2.shape[1]

    sx = shifts[k, 0]
    sy = shifts[k, 1]
    sz = shifts[k, 2]

    # Decide on the raw values, before the float32 conversion below.
    is_home = sx == 0 and sy == 0 and sz == 0

    sx = numba.float32(sx)
    sy = numba.float32(sy)
    sz = numba.float32(sz)

    img_x = coord[atom, 0] + sx * cell[0, 0] + sy * cell[1, 0] + sz * cell[2, 0]
    img_y = coord[atom, 1] + sx * cell[0, 1] + sy * cell[1, 1] + sz * cell[2, 1]
    img_z = coord[atom, 2] + sx * cell[0, 2] + sy * cell[1, 2] + sz * cell[2, 2]

    for other in range(n_atoms):
        if is_home and other >= atom:
            continue

        dx = img_x - coord[other, 0]
        dy = img_y - coord[other, 1]
        dz = img_z - coord[other, 2]

        r2 = dx * dx + dy * dy + dz * dz

        if r2 < cutoff1_squared:
            # `other` sees the image of `atom` at +shift ...
            slot = numba.cuda.atomic.add(nnb1, other, 1)
            if slot < capacity1:
                nbmat1[other, slot] = atom
                shifts1[other, slot, 0] = sx
                shifts1[other, slot, 1] = sy
                shifts1[other, slot, 2] = sz
            # ... and `atom` sees `other` at -shift.
            slot = numba.cuda.atomic.add(nnb1, atom, 1)
            if slot < capacity1:
                nbmat1[atom, slot] = other
                shifts1[atom, slot, 0] = -sx
                shifts1[atom, slot, 1] = -sy
                shifts1[atom, slot, 2] = -sz

        if r2 < cutoff2_squared:
            slot = numba.cuda.atomic.add(nnb2, other, 1)
            if slot < capacity2:
                nbmat2[other, slot] = atom
                shifts2[other, slot, 0] = sx
                shifts2[other, slot, 1] = sy
                shifts2[other, slot, 2] = sz
            slot = numba.cuda.atomic.add(nnb2, atom, 1)
            if slot < capacity2:
                nbmat2[atom, slot] = other
                shifts2[atom, slot, 0] = -sx
                shifts2[atom, slot, 1] = -sy
                shifts2[atom, slot, 2] = -sz
158
+
159
+
160
@numba.cuda.jit(cache=True, fastmath=True)
def _nbmat_pbc_cuda(
    coord,  # N, 3
    cell,  # 3, 3
    cutoff1_squared: float,
    shifts,  # S, 3
    nnb1,  # N
    nbmat1,  # N, M
    shifts1,  # N, M, 3
):
    """CUDA kernel: one periodic neighbor matrix, one thread per (shift, atom).

    The flat thread index is split into a shift index and an atom index.  The
    thread's atom is translated by its shift and compared with every atom
    ``other`` (only ``other < atom`` for the zero shift).  Each hit is stored
    in both rows with opposite shift signs.
    """
    flat = numba.cuda.grid(1)

    n_atoms = coord.shape[0]
    n_shifts = shifts.shape[0]

    k = flat // n_atoms
    atom = flat % n_atoms
    if k >= n_shifts:
        return

    capacity = nbmat1.shape[1]

    sx = shifts[k, 0]
    sy = shifts[k, 1]
    sz = shifts[k, 2]

    # Decide on the raw values, before the float32 conversion below.
    is_home = sx == 0 and sy == 0 and sz == 0

    sx = numba.float32(sx)
    sy = numba.float32(sy)
    sz = numba.float32(sz)

    img_x = coord[atom, 0] + sx * cell[0, 0] + sy * cell[1, 0] + sz * cell[2, 0]
    img_y = coord[atom, 1] + sx * cell[0, 1] + sy * cell[1, 1] + sz * cell[2, 1]
    img_z = coord[atom, 2] + sx * cell[0, 2] + sy * cell[1, 2] + sz * cell[2, 2]

    for other in range(n_atoms):
        if is_home and other >= atom:
            continue

        dx = img_x - coord[other, 0]
        dy = img_y - coord[other, 1]
        dz = img_z - coord[other, 2]

        r2 = dx * dx + dy * dy + dz * dz

        if r2 < cutoff1_squared:
            # atomic.add returns the previous count, i.e. the slot to fill.
            slot = numba.cuda.atomic.add(nnb1, other, 1)
            if slot < capacity:
                nbmat1[other, slot] = atom
                shifts1[other, slot, 0] = sx
                shifts1[other, slot, 1] = sy
                shifts1[other, slot, 2] = sz
            slot = numba.cuda.atomic.add(nnb1, atom, 1)
            if slot < capacity:
                nbmat1[atom, slot] = other
                shifts1[atom, slot, 0] = -sx
                shifts1[atom, slot, 1] = -sy
                shifts1[atom, slot, 2] = -sz
@@ -0,0 +1,220 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+ from .nb_kernel_cpu import _expand_shifts
7
+
8
+
9
class TooManyNeighborsError(Exception):
    """Raised when an atom has more neighbors than the neighbor matrix can hold."""
11
+
12
+
13
# The kernel back-end is selected once, at import time: the CUDA kernels when
# PyTorch reports a CUDA device, the CPU (njit) kernels otherwise.
if torch.cuda.is_available():
    import numba.cuda

    # PyTorch CUDA without Numba CUDA is treated as a broken installation.
    if not numba.cuda.is_available():
        raise ImportError("PyTorch CUDA is available, but Numba CUDA is not available.")
    _numba_cuda_available = True
    from .nb_kernel_cuda import _nbmat_cuda, _nbmat_dual_cuda, _nbmat_pbc_cuda, _nbmat_pbc_dual_cuda

    _kernel_nbmat = _nbmat_cuda
    _kernel_nbmat_dual = _nbmat_dual_cuda
    _kernel_nbmat_pbc = _nbmat_pbc_cuda
    _kernel_nbmat_pbc_dual = _nbmat_pbc_dual_cuda
else:
    _numba_cuda_available = False
    from .nb_kernel_cpu import _nbmat_cpu, _nbmat_dual_cpu, _nbmat_dual_pbc_cpu, _nbmat_pbc_cpu

    _kernel_nbmat = _nbmat_cpu
    _kernel_nbmat_dual = _nbmat_dual_cpu
    _kernel_nbmat_pbc = _nbmat_pbc_cpu
    _kernel_nbmat_pbc_dual = _nbmat_dual_pbc_cpu
33
+
34
+
35
def calc_nbmat(
    coord: Tensor,
    cutoffs: Tuple[float, Optional[float]],
    maxnb: Tuple[int, Optional[int]],
    cell: Optional[Tensor] = None,
    mol_idx: Optional[Tensor] = None,
):
    """Build padded neighbor matrices for one or two distance cutoffs.

    Args:
        coord: Atomic coordinates, shape (N, 3).
        cutoffs: ``(cutoff1, cutoff2)``; ``cutoff2=None`` disables the second list.
        maxnb: Maximum number of neighbors per atom for each list; ``maxnb[1]``
            is required when ``cutoffs[1]`` is given.
        cell: Optional (3, 3) cell whose rows are the lattice vectors (images are
            generated as ``coord + shift @ cell``).  Enables periodic boundaries.
        mol_idx: Optional per-atom molecule index.  The kernels pair atom ``i``
            only with atoms ``i + 1 .. end_of_molecule``, so atoms are assumed to
            be grouped by molecule with indices ascending from 0 — confirm with
            the caller.  Not supported together with ``cell``.

    Returns:
        ``(nbmat1, nbmat2, shifts1, shifts2)``.  Each neighbor matrix has
        ``N + 1`` rows (the last one is padding) and unused slots hold ``N``;
        columns are trimmed to the largest neighbor count found.  ``nbmat2`` is
        ``None`` without a second cutoff, and the shift tensors are ``None``
        without ``cell``.

    Raises:
        ValueError: For multiple molecules with PBC, a missing ``maxnb[1]``, or
            non-CUDA input when the CUDA kernels were selected at import time.
        TooManyNeighborsError: If any atom has more neighbors than ``maxnb`` allows.
    """
    device = coord.device
    N = coord.shape[0]

    _pbc = cell is not None
    if _pbc and mol_idx is not None and mol_idx[-1] > 0:
        raise ValueError("Multiple molecules are not supported with PBC.")

    if mol_idx is None:
        mol_idx = torch.zeros(N, dtype=torch.long, device=device)
        mol_end_idx = torch.tensor([N], dtype=torch.long, device=device)
    else:
        _, mol_size = torch.unique(mol_idx, return_counts=True)
        mol_end_idx = mol_size.cumsum(0)

    if _numba_cuda_available and device.type != "cuda":
        raise ValueError("Numba CUDA is available, but the input tensors are not on CUDA.")

    _cuda = device.type == "cuda" and _numba_cuda_available
    _dual_cutoff = cutoffs[1] is not None
    if _dual_cutoff and maxnb[1] is None:
        raise ValueError("maxnb[1] must be specified for dual cutoff.")

    # Counters start at zero; matrices are pre-filled with the padding index N.
    nnb1 = torch.zeros(N, dtype=torch.long, device=device)
    nbmat1 = torch.full((N + 1, maxnb[0]), N, dtype=torch.long, device=device)

    if _dual_cutoff:
        nnb2 = torch.zeros(N, dtype=torch.long, device=device)
        nbmat2 = torch.full((N + 1, maxnb[1]), N, dtype=torch.long, device=device)  # type: ignore

    if _pbc:
        cell_inv = torch.inverse(cell)  # type: ignore[arg-type]
        cutoff = max(cutoffs) if _dual_cutoff else cutoffs[0]  # type: ignore
        # Rows of `cell` are lattice vectors, so fractional coordinates are
        # `r @ cell_inv` and the spacing between lattice planes along vector i
        # is 1 / ||cell_inv[:, i]||.  The number of images needed along i is
        # therefore ceil(cutoff * column norm); row norms are only equivalent
        # for orthorhombic cells.
        nshift = torch.ceil(cutoff * cell_inv.norm(dim=0)).to(torch.long).cpu().numpy()
        shifts = _expand_shifts(nshift)
        S = shifts.shape[0]
        shifts = torch.from_numpy(shifts).to(device)
        shifts1 = torch.zeros(N + 1, maxnb[0], 3, dtype=torch.long, device=device)
        if _dual_cutoff:
            shifts2 = torch.zeros(N + 1, maxnb[1], 3, dtype=torch.long, device=device)  # type: ignore
    else:
        S = 1

    # convert tensors and launch the kernel
    if _cuda:
        _coord = numba.cuda.as_cuda_array(coord)
        _mol_idx = numba.cuda.as_cuda_array(mol_idx)
        _mol_end_idx = numba.cuda.as_cuda_array(mol_end_idx)
        _nnb1 = numba.cuda.as_cuda_array(nnb1)
        _nbmat1 = numba.cuda.as_cuda_array(nbmat1)
        if _dual_cutoff:
            _nnb2 = numba.cuda.as_cuda_array(nnb2)
            _nbmat2 = numba.cuda.as_cuda_array(nbmat2)
        if _pbc:
            _cell = numba.cuda.as_cuda_array(cell)
            _shifts = numba.cuda.as_cuda_array(shifts)
            _shifts1 = numba.cuda.as_cuda_array(shifts1)
            if _dual_cutoff:
                _shifts2 = numba.cuda.as_cuda_array(shifts2)
        # One thread per atom (non-PBC, S == 1) or per (shift, atom) pair (PBC).
        threads_per_block = 32
        blocks_per_grid = (N * S + (threads_per_block - 1)) // threads_per_block

        if _pbc:
            if _dual_cutoff:
                _kernel_nbmat_pbc_dual[blocks_per_grid, threads_per_block](  # type: ignore
                    _coord,
                    _cell,
                    cutoffs[0] ** 2,
                    cutoffs[1] ** 2,  # type: ignore
                    _shifts,
                    _nnb1,
                    _nnb2,
                    _nbmat1,
                    _nbmat2,
                    _shifts1,
                    _shifts2,
                )
            else:
                _kernel_nbmat_pbc[blocks_per_grid, threads_per_block](  # type: ignore
                    _coord,
                    _cell,
                    cutoffs[0] ** 2,
                    _shifts,
                    _nnb1,
                    _nbmat1,
                    _shifts1,
                )
        else:
            if _dual_cutoff:
                _kernel_nbmat_dual[blocks_per_grid, threads_per_block](  # type: ignore
                    _coord,
                    cutoffs[0] ** 2,
                    cutoffs[1] ** 2,  # type: ignore
                    _mol_idx,
                    _mol_end_idx,
                    _nbmat1,
                    _nbmat2,
                    _nnb1,
                    _nnb2,
                )
            else:
                _kernel_nbmat[blocks_per_grid, threads_per_block](  # type: ignore
                    _coord,
                    cutoffs[0] ** 2,
                    _mol_idx,
                    _mol_end_idx,
                    _nbmat1,
                    _nnb1,
                )

    else:
        # The NumPy arrays share memory with the tensors, so the kernels fill
        # nnb*/nbmat*/shifts* in place.
        _coord = coord.numpy()
        _mol_idx = mol_idx.numpy()
        _mol_end_idx = mol_end_idx.numpy()
        _nnb1 = nnb1.numpy()
        _nbmat1 = nbmat1.numpy()
        if _dual_cutoff:
            _nnb2 = nnb2.numpy()
            _nbmat2 = nbmat2.numpy()
        if _pbc:
            _cell = cell.numpy()  # type: ignore[union-attr]
            _shifts = shifts.numpy()

        if _pbc:
            _shifts1 = shifts1.numpy()
            if _dual_cutoff:
                _shifts2 = shifts2.numpy()
                _kernel_nbmat_pbc_dual(
                    _coord,
                    _cell,
                    cutoffs[0] ** 2,
                    cutoffs[1] ** 2,  # type: ignore
                    _shifts,
                    _nnb1,
                    _nnb2,
                    _nbmat1,
                    _nbmat2,
                    _shifts1,  # type: ignore
                    _shifts2,  # type: ignore
                )  # type: ignore
            else:
                _kernel_nbmat_pbc(_coord, _cell, cutoffs[0] ** 2, _shifts, _nnb1, _nbmat1, _shifts1)  # type: ignore
        else:
            if _dual_cutoff:
                _kernel_nbmat_dual(
                    _coord,
                    cutoffs[0] ** 2,
                    cutoffs[1] ** 2,  # type: ignore
                    _mol_idx,
                    _mol_end_idx,
                    _nbmat1,
                    _nbmat2,
                    _nnb1,
                    _nnb2,
                )  # type: ignore
            else:
                _kernel_nbmat(_coord, cutoffs[0] ** 2, _mol_idx, _mol_end_idx, _nbmat1, _nnb1)

    if not _pbc:
        shifts1 = None  # type: ignore[assignment]
        shifts2 = None  # type: ignore[assignment]

    # Counters keep growing past the row capacity, so a maximum above maxnb
    # means at least one row was truncated.
    nnb1_max = nnb1.max().item()
    if nnb1_max > maxnb[0]:
        raise TooManyNeighborsError(f"maxnb is too small: {nnb1_max=}, {maxnb=}")
    nbmat1 = nbmat1[:, :nnb1_max]  # type: ignore
    if _pbc:
        shifts1 = shifts1[:, :nnb1_max]  # type: ignore
    if _dual_cutoff:
        nnb2_max = nnb2.max().item()
        if nnb2_max > maxnb[1]:  # type: ignore
            raise TooManyNeighborsError(f"maxnb is too small: {nnb1_max=}, {nnb2_max=}, {maxnb=}")
        nbmat2 = nbmat2[:, :nnb2_max]
        if _pbc:
            shifts2 = shifts2[:, :nnb2_max]  # type: ignore
    else:
        nbmat2 = None
        if _pbc:
            shifts2 = None
    return nbmat1, nbmat2, shifts1, shifts2
aimnet/cli.py ADDED
@@ -0,0 +1,22 @@
1
+ import click
2
+
3
+ from .train.calc_sae import calc_sae
4
+ from .train.pt2jpt import jitcompile
5
+ from .train.train import train
6
+
7
+
8
# Root click command group; the sub-commands are attached to it with add_command.
@click.group()
def cli():
    """AIMNet2 command line tool"""
11
+
12
+
13
# Register the sub-commands imported from the train package on the root group.
cli.add_command(train, name="train")
cli.add_command(jitcompile, name="jitcompile")
cli.add_command(calc_sae, name="calc_sae")


if __name__ == "__main__":
    import logging

    # When run as a script, enable INFO-level logging before dispatching to click.
    logging.basicConfig(level=logging.INFO)
    cli()