bayesianflow-for-chem 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bayesianflow-for-chem might be problematic. Click here for more details.
- bayesianflow_for_chem/__init__.py +11 -0
- bayesianflow_for_chem/data.py +250 -0
- bayesianflow_for_chem/model.py +927 -0
- bayesianflow_for_chem/scorer.py +134 -0
- bayesianflow_for_chem/tool.py +470 -0
- bayesianflow_for_chem/train.py +243 -0
- bayesianflow_for_chem/vocab.txt +246 -0
- bayesianflow_for_chem-1.2.0.dist-info/METADATA +162 -0
- bayesianflow_for_chem-1.2.0.dist-info/RECORD +11 -0
- bayesianflow_for_chem-1.2.0.dist-info/WHEEL +5 -0
- bayesianflow_for_chem-1.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,927 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Author: Nianze A. Tao (Omozawa Sueno)
|
|
3
|
+
"""
|
|
4
|
+
Define Bayesian Flow Network for Chemistry (ChemBFN) model.
|
|
5
|
+
"""
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import List, Tuple, Optional, Union
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn as nn
|
|
10
|
+
from torch import Tensor
|
|
11
|
+
from torch.nn.functional import softmax, linear, dropout
|
|
12
|
+
from typing_extensions import Self
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Linear(nn.Linear):
    # Modified from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
    # We made it simpler and compatible with both `loralib` and `TorchScript`.
    def __init__(self, in_features: int, out_features: int, bias: bool = True, **kargs):
        """
        LoRA implemented in a dense layer.

        LoRA is disabled after construction; call :meth:`enable_lora` to add the
        trainable low-rank update and freeze ``weight``.

        :param in_features: number of input features
        :param out_features: number of output features
        :param bias: whether to use additional bias
        :param device: device
        :param dtype: PyTorch data type
        :type in_features: int
        :type out_features: int
        :type bias: bool
        :type device: torch.device | str | None
        :type dtype: torch.dtype
        """
        nn.Linear.__init__(self, in_features, out_features, bias, **kargs)
        self.lora_enabled: bool = False
        self.lora_A: Optional[nn.Parameter] = None
        self.lora_B: Optional[nn.Parameter] = None
        self.scaling: Optional[float] = None
        self.lora_dropout: Optional[float] = None
        nn.Linear.reset_parameters(self)

    def enable_lora(
        self, r: int = 8, lora_alpha: int = 1, lora_dropout: float = 0.0
    ) -> None:
        """
        Enable LoRA parameters.

        Creates ``lora_A`` (Kaiming-uniform) and ``lora_B`` (zeros) so the layer
        output is unchanged right after enabling, and freezes ``weight``.

        :param r: rank
        :param lora_alpha: LoRA alpha value; the update is scaled by ``lora_alpha / r``
        :param lora_dropout: dropout frequency in LoRA layer
        :type r: int
        :type lora_alpha: int | float
        :type lora_dropout: float
        :return:
        :rtype: None
        """
        assert r > 0, "Rank should be larger than 0."
        self.lora_A = nn.Parameter(self.weight.new_zeros((r, self.in_features)))
        self.lora_B = nn.Parameter(self.weight.new_zeros((self.out_features, r)))
        self.scaling = lora_alpha / r
        # Store a real float: `forward` only applies the LoRA branch when
        # `lora_dropout` is a float, so an int such as `0` would otherwise
        # silently disable LoRA.
        self.lora_dropout = float(lora_dropout)
        self.lora_enabled = True
        nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)
        nn.init.zeros_(self.lora_B)
        self.weight.requires_grad_(False)

    def forward(self, x: Tensor) -> Tensor:
        """
        :param x: input tensor; last dimension: in_features
        :type x: torch.Tensor
        :return: dense output plus the scaled low-rank update when LoRA is enabled
        :rtype: torch.Tensor
        """
        result = linear(x, self.weight, self.bias)
        # the isinstance check also lets TorchScript refine Optional[float]
        if self.lora_enabled and isinstance(self.lora_dropout, float):
            result += (
                dropout(x, self.lora_dropout, self.training)
                @ self.lora_A.transpose(0, 1)
                @ self.lora_B.transpose(0, 1)
            ) * self.scaling
        return result
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
    """
    Adaptive modulation: multiply ``x`` by ``1 + scale`` and add ``shift``.

    :param x: tensor to modulate
    :param shift: additive offset (broadcast against ``x``)
    :param scale: multiplicative offset around 1 (broadcast against ``x``)
    :type x: torch.Tensor
    :type shift: torch.Tensor
    :type scale: torch.Tensor
    :return: modulated tensor
    :rtype: torch.Tensor
    """
    gain = scale + 1
    return shift + x * gain
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class RoPE(nn.Module):
    def __init__(self, channel: int = 512, num_head: int = 8) -> None:
        """
        Rotary position embedding block with XPOS method.

        Registers two buffers of shape (1, head_dim): ``theta`` (rotation
        frequencies) and ``zeta`` (XPOS decay bases). Each value is shared by a
        consecutive (even, odd) channel pair.

        :param channel: hidden layer features
        :param num_head: number of heads
        :type channel: int
        :type num_head: int
        """
        super().__init__()
        head_dim = channel // num_head
        assert head_dim % 2 == 0
        self.channel = channel
        idx = torch.arange(0, head_dim, 2)[None, :] / head_dim
        # duplicate every entry so channels 2j and 2j+1 share one value
        theta = torch.pow(10000, -idx).repeat_interleave(2, dim=-1)
        zeta = ((idx + 0.4) / 1.4).repeat_interleave(2, dim=-1)
        self.register_buffer("theta", theta)
        self.register_buffer("zeta", zeta)

    def forward(self, size: int) -> Tuple[Tensor, Tensor, Tensor]:
        """
        :param size: maximum length of sequence in the batch
        :type size: int
        :return: cos part of position encoding; shape: (1, 1, n_t, head_dim) \n
                 sin part of position encoding; shape: (1, 1, n_t, head_dim) \n
                 scaling coefficients; shape: (1, 1, n_t, head_dim)
        :rtype: tuple
        """
        positions = torch.arange(size, device=self.theta.device).unsqueeze(-1)
        angle = positions * self.theta
        decay = self.zeta.pow(positions / self.channel)
        return angle.cos()[None, None], angle.sin()[None, None], decay[None, None]
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class Attention(nn.Module):
    def __init__(self, channel: int = 512, num_head: int = 8) -> None:
        """
        Multi-head self-attention block.

        :param channel: hidden layer features
        :param num_head: number of heads
        :type channel: int
        :type num_head: int
        """
        super().__init__()
        assert channel % num_head == 0
        self.d = channel // num_head  # head dimension
        self.nh = num_head  # number of heads
        self.tp = (2 * self.d) ** 0.5  # attention temperature
        self.qkv = Linear(channel, channel * 3)

    @staticmethod
    def _rotate(
        q: Tensor, k: Tensor, pe: Tuple[Tensor, Tensor, Tensor]
    ) -> Tuple[Tensor, Tensor]:
        """
        Apply the rotary/XPOS position encoding to queries and keys.

        Queries are multiplied by the scaling coefficients and keys divided by
        them.

        :param q: queries; last dimension: head dimension
        :param k: keys; last dimension: head dimension
        :param pe: (cos, sin, scaling) tensors broadcastable to ``q`` and ``k``
        :return: rotated queries and keys
        """
        q_rotate, k_rotate = torch.zeros_like(q), torch.zeros_like(k)
        # pairwise rotation: (a, b) -> (-b, a)
        q_rotate[..., 0::2] = -q[..., 1::2]
        q_rotate[..., 1::2] = q[..., 0::2]
        q = (q * pe[0] + q_rotate * pe[1]) * pe[2]
        k_rotate[..., 0::2] = -k[..., 1::2]
        k_rotate[..., 1::2] = k[..., 0::2]
        k = (k * pe[0] + k_rotate * pe[1]) / pe[2]
        return q, k

    def forward(
        self, x: Tensor, pe: Tuple[Tensor, Tensor, Tensor], mask: Optional[Tensor]
    ) -> Tensor:
        """
        :param x: input tensor; shape: (n_b, n_t, n_f)
        :param pe: position encoding; shape: (1, 1, n_t, n_h) * 3
        :param mask: attention mask; shape: (1, n_b, n_t, n_t). Passed straight to
                     ``scaled_dot_product_attention``: a boolean mask marks with
                     ``True`` the positions that may be attended to, while a
                     float mask is added to the attention scores.
        :type x: torch.Tensor
        :type pe: tuple
        :type mask: torch.Tensor | None
        :return: attentioned output; shape: (n_b, n_t, n_f)
        :rtype: torch.Tensor
        """
        n_b, n_a, _ = shape = x.shape
        split = (n_b, n_a, self.nh, self.d)
        q, k, v = self.qkv(x).chunk(3, -1)
        # (n_b, n_t, n_head, d) -> (n_head, n_b, n_t, d); the leading 1 of the
        # mask therefore broadcasts over heads.
        q = q.view(split).permute(2, 0, 1, 3).contiguous()
        k = k.view(split).permute(2, 0, 1, 3).contiguous()
        v = v.view(split).permute(2, 0, 1, 3).contiguous()
        q, k = self._rotate(q, k, pe)  # position embedding
        atten_out = nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=mask,
            dropout_p=0.0,
            is_causal=False,
            scale=1 / self.tp,
        )
        atten_out = atten_out.permute(1, 2, 0, 3).contiguous().view(shape)
        return atten_out

    def enable_lora(
        self, r: int = 4, lora_alpha: int = 1, lora_dropout: float = 0.0
    ) -> None:
        """
        Enable LoRA parameters on the fused query/key/value projection.

        :param r: rank
        :param lora_alpha: LoRA alpha value
        :param lora_dropout: dropout frequency in LoRA layer
        :type r: int
        :type lora_alpha: float
        :type lora_dropout: float
        :return:
        :rtype: None
        """
        self.qkv.enable_lora(r, lora_alpha, lora_dropout)
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class TransformerLayer(nn.Module):
    def __init__(
        self, channel: int = 512, num_head: int = 8, dropout: float = 0.01
    ) -> None:
        """
        Transformer layer block conditioned through adaptive layer-norm modulation.

        :param channel: hidden layer features
        :param num_head: number of attention heads
        :param dropout: dropout frequency
        :type channel: int
        :type num_head: int
        :type dropout: float
        """
        super().__init__()
        hidden = channel * 4
        self.norm1 = nn.LayerNorm(channel, eps=1e-6, elementwise_affine=False)
        self.attention = Attention(channel, num_head)
        self.norm2 = nn.LayerNorm(channel, eps=1e-6, elementwise_affine=False)
        self.ffn = nn.Sequential(
            nn.Linear(channel, hidden),
            nn.SELU(),
            nn.Linear(hidden, channel),
            nn.Dropout(dropout),
        )
        modulation = Linear(channel, 6 * channel)
        # zero-out adaLN layer: all gates start at 0, so both residual branches
        # initially contribute nothing
        nn.init.zeros_(modulation.weight)
        nn.init.zeros_(modulation.bias)
        self.adaln_modulation = nn.Sequential(nn.SELU(), modulation)

    def forward(
        self,
        x: Tensor,
        pe: Tuple[Tensor, Tensor, Tensor],
        c: Tensor,
        mask: Optional[Tensor],
    ) -> Tensor:
        """
        :param x: input tensor; shape: (n_b, n_t, n_f)
        :param pe: position encoding; shape: (1, 1, n_t, n_h) * 3
        :param c: conditioning; shape: (n_b, 1, n_f)
        :param mask: attention mask; shape: (1, n_b, n_t, n_t)
        :type x: torch.Tensor
        :type pe: tuple
        :type c: torch.Tensor
        :type mask: torch.Tensor | None
        :return: output tensor; shape: (n_b, n_t, n_f)
        :rtype: torch.Tensor
        """
        (
            shift_attn,
            scale_attn,
            gate_attn,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = self.adaln_modulation(c).chunk(6, -1)
        h = modulate(self.norm1(x), shift_attn, scale_attn)
        x = x + gate_attn * self.attention(h, pe, mask)
        h = modulate(self.norm2(x), shift_mlp, scale_mlp)
        return x + gate_mlp * self.ffn(h)

    def enable_lora(
        self, r: int = 4, lora_alpha: int = 1, lora_dropout: float = 0.0
    ) -> None:
        """
        Enable LoRA parameters in the attention projection and the adaLN layer.

        :param r: rank
        :param lora_alpha: LoRA alpha value
        :param lora_dropout: dropout frequency in LoRA layer
        :type r: int
        :type lora_alpha: float
        :type lora_dropout: float
        :return:
        :rtype: None
        """
        for module in (self.attention, self.adaln_modulation[1]):
            module.enable_lora(r, lora_alpha, lora_dropout)
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
class FinalLayer(nn.Module):
    def __init__(self, num_vocab: int, channel: int = 512) -> None:
        """
        The final layer of model: adaLN-modulated layer norm followed by a
        projection onto the vocabulary.

        :param num_vocab: number of vocabulary
        :param channel: hidden layer features
        :type num_vocab: int
        :type channel: int
        """
        super().__init__()
        self.norm_final = nn.LayerNorm(channel, eps=1e-6, elementwise_affine=False)
        self.linear = Linear(channel, num_vocab)
        self.adaln_modulation = nn.Sequential(nn.SELU(), Linear(channel, 2 * channel))
        # zero-out the output projection and the modulation layer
        for layer in (self.linear, self.adaln_modulation[1]):
            nn.init.zeros_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x: Tensor, c: Tensor, return_logits: bool = True) -> Tensor:
        """
        :param x: input tensor; shape: (n_b, n_t, n_f)
        :param c: conditioning; shape: (n_b, 1, n_f)
        :param return_logits: whether to return unnormalised output logits
        :type x: torch.Tensor
        :type c: torch.Tensor
        :type return_logits: bool
        :return: output logits (unnormalised); shape: (n_b, n_t, n_vocab)
                 or token embeddings; shape: (n_b, n_t, n_f)
        :rtype: torch.Tensor
        """
        shift, scale = self.adaln_modulation(c).chunk(2, -1)
        h = modulate(self.norm_final(x), shift, scale)
        return self.linear(h) if return_logits else h

    def enable_lora(
        self, r: int = 4, lora_alpha: int = 1, lora_dropout: float = 0.0
    ) -> None:
        """
        Enable LoRA parameters in the output projection and the adaLN layer.

        :param r: rank
        :param lora_alpha: LoRA alpha value
        :param lora_dropout: dropout frequency in LoRA layer
        :type r: int
        :type lora_alpha: float
        :type lora_dropout: float
        :return:
        :rtype: None
        """
        for module in (self.linear, self.adaln_modulation[1]):
            module.enable_lora(r, lora_alpha, lora_dropout)
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
class ChemBFN(nn.Module):
|
|
336
|
+
def __init__(
    self,
    num_vocab: int,
    channel: int = 512,
    num_layer: int = 12,
    num_head: int = 8,
    dropout: float = 0.01,
) -> None:
    r"""
    Bayesian Flow Network for Chemistry model representation.

    Enable semi-autoregressive sampling by setting
    `ChemBFN(...).semi_autoregressive = True`.

    :param num_vocab: number of vocabulary
    :param channel: hidden layer features
    :param num_layer: number of transformer layers
    :param num_head: number of heads
    :param dropout: dropout frequency
    :type num_vocab: int
    :type channel: int
    :type num_layer: int
    :type num_head: int
    :type dropout: float
    """
    super().__init__()
    self.K = num_vocab
    self.lora_enabled: bool = False
    self.semi_autoregressive: bool = False
    # submodules are built in a fixed order so parameter initialisation
    # consumes the RNG in the same sequence
    self.embedding = Linear(num_vocab, channel)
    half_channel = channel // 2
    self.time_embed = nn.Sequential(
        nn.Linear(1, half_channel), nn.SELU(), nn.Linear(half_channel, channel)
    )
    self.position = RoPE(channel, num_head)
    self.encoder_layers = nn.ModuleList(
        TransformerLayer(channel, num_head, dropout) for _ in range(num_layer)
    )
    self.final_layer = FinalLayer(num_vocab, channel)
    # beta(1) of the accuracy schedule used by `calc_beta`
    self.register_buffer("beta", torch.scalar_tensor(20.4054 / self.K))
    # constructor arguments
    self.hparam = {
        "num_vocab": num_vocab,
        "channel": channel,
        "num_layer": num_layer,
        "num_head": num_head,
        "dropout": dropout,
    }
    self.lora_param = {}
|
|
383
|
+
|
|
384
|
+
def enable_lora(
    self, r: int = 4, lora_alpha: int = 1, lora_dropout: float = 0.0
) -> None:
    """
    Enable LoRA parameters in the embedding, every encoder layer and the final
    layer, and record the LoRA settings in ``lora_param``.

    :param r: rank
    :param lora_alpha: LoRA alpha value
    :param lora_dropout: dropout frequency in LoRA layer
    :type r: int
    :type lora_alpha: float
    :type lora_dropout: float
    :return:
    :rtype: None
    """
    self.lora_enabled = True
    self.lora_param = {"r": r, "lora_alpha": lora_alpha, "lora_dropout": lora_dropout}
    for module in (self.embedding, *self.encoder_layers, self.final_layer):
        module.enable_lora(r, lora_alpha, lora_dropout)
|
|
405
|
+
|
|
406
|
+
def forward(
    self,
    x: Tensor,
    t: Tensor,
    mask: Optional[Tensor] = None,
    y: Optional[Tensor] = None,
) -> Tensor:
    """
    :param x: input probabilities; shape: (n_b, n_t, n_vocab)
    :param t: time; shape: (n_b, 1, 1)
    :param mask: input mask; shape: (n_b, n_t, 1)
    :param y: conditioning vector; shape: (n_b, 1, n_f)
    :type x: torch.Tensor
    :type t: torch.Tensor
    :type mask: torch.Tensor | None
    :type y: torch.Tensor | None
    :return: probability distribution (before softmax); shape: (n_b, n_t, n_vocab)
             or token embeddings; shape: (n_b, n_t, n_f) when ``mask`` is given
    :rtype: torch.Tensor
    """
    n_b, n_t, _ = x.shape
    c = self.time_embed(t)
    if y is not None:
        # out-of-place so that `c` may broadcast against `y`
        c = c + y
    pe = self.position(n_t)
    x = self.embedding(x)
    attn_mask: Optional[Tensor] = None
    if self.semi_autoregressive:
        # Lower-triangular mask: token i may attend only to tokens <= i. It must
        # be boolean, because scaled_dot_product_attention *adds* a float mask
        # to the attention scores instead of masking them out.
        attn_mask = torch.tril(
            torch.ones((1, n_b, n_t, n_t), dtype=torch.bool, device=self.beta.device),
            diagonal=0,
        )
    elif mask is not None:
        # True where the key token is valid (mask != 0), repeated for every query
        attn_mask = mask.transpose(-2, -1).repeat(1, n_t, 1)[None, ...] != 0
    for layer in self.encoder_layers:
        x = layer(x, pe, c, attn_mask)
    return self.final_layer(x, c, mask is None)
|
|
448
|
+
|
|
449
|
+
def calc_beta(self, t: Tensor) -> Tensor:
    r"""
    Calculate the accuracy schedule beta(t).

    .. math::

        \beta(t) = -\frac{4}{K}\ln\left(1 - t + t e^{-\frac{K}{4}\beta(1)}\right)

    where :math:`\beta(1)` is the ``beta`` buffer, so ``beta(0) = 0`` and
    ``beta(1) = self.beta``.

    :param t: continuous time in [0, 1]; shape: (n_b, 1, 1)
    :type t: torch.Tensor
    :return: beta(t); shape: (n_b, 1, 1)
    :rtype: torch.Tensor
    """
    decay = torch.exp(-self.K * self.beta / 4)
    return -4 * torch.log(1 - t + t * decay) / self.K
|
|
467
|
+
|
|
468
|
+
def calc_discrete_alpha(self, t1: Tensor, t2: Tensor) -> Tensor:
    r"""
    Calculate the discrete-time accuracy increment alpha(i).

    .. math::

        \alpha(i) = \beta(t_{i}) - \beta(t_{i - 1})

    :param t1: discrete time (i - 1) / n; shape: (n_b, 1, 1)
    :param t2: discrete time i / n; shape: (n_b, 1, 1)
    :type t1: torch.Tensor
    :type t2: torch.Tensor
    :return: alpha(i); shape: (n_b, 1, 1)
    :rtype: torch.Tensor
    """
    # expected: t2 > t1
    beta_end = self.calc_beta(t2)
    beta_start = self.calc_beta(t1)
    return beta_end - beta_start
|
|
483
|
+
|
|
484
|
+
def calc_cts_alpha(self, t: Tensor) -> Tensor:
    r"""
    Calculate alpha(t) / 2, half the time-derivative of beta(t).

    .. math::

        \alpha(t) = \frac{d\beta(t)}{dt}
                  = \frac{4}{K}\,
                    \frac{1 - e^{-\frac{K}{4}\beta(1)}}{1 - t + t e^{-\frac{K}{4}\beta(1)}}

    :param t: continuous time in [0, 1]; shape: (n_b, 1, 1)
    :type t: torch.Tensor
    :return: alpha(t) / 2; shape: (n_b, 1, 1)
    :rtype: torch.Tensor
    """
    decay = (-self.K * self.beta / 4).exp()
    numerator = 1 - decay
    denominator = 1 - t + t * decay
    return 2 * numerator / denominator / self.K
|
|
507
|
+
|
|
508
|
+
def discrete_output_distribution(
    self, theta: Tensor, t: Tensor, y: Optional[Tensor], w: Optional[float]
) -> Tensor:
    """
    Compute the output distribution, using classifier-free guidance when both
    a conditioning vector and a guidance strength are supplied.

    :param theta: input distribution; shape: (n_b, n_t, n_vocab)
    :param t: continuous time in [0, 1]; shape: (n_b, 1, 1)
    :param y: conditioning vector; shape: (n_b, 1, n_f)
    :param w: guidance strength controlling the conditional generation
    :type theta: torch.Tensor
    :type t: torch.Tensor
    :type y: torch.Tensor | None
    :type w: float | None
    :return: output distribution; shape: (n_b, n_t, n_vocab)
    :rtype: torch.Tensor
    """
    theta = 2 * theta - 1  # rescale to [-1, 1]
    if w is not None and y is not None:
        # guided: (1 + w) * conditional - w * unconditional logits
        cond_logits = self.forward(theta, t, None, y)
        uncond_logits = self.forward(theta, t, None, None)
        return softmax((1 + w) * cond_logits - w * uncond_logits, -1)
    # unguided: conditional if `y` is given, otherwise unconditional
    return softmax(self.forward(theta, t, None, y), -1)
|
|
532
|
+
|
|
533
|
+
def cts_loss(
    self,
    x: Tensor,
    t: Tensor,
    y: Optional[Tensor],
    mask: Optional[Tensor] = None,
    return_output_dist: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
    """
    Compute continuous-time loss.

    :param x: target data; shape: (n_b, n_t)
    :param t: continuous time in [0, 1); shape: (n_b, 1, 1)
    :param y: conditioning vector; shape: (n_b, 1, n_f)
    :param mask: in-text mask; shape: (n_b, n_t). Positions where the mask is
                 1/True receive their exact one-hot distribution as input.
                 Float, integer and boolean masks are accepted.
    :param return_output_dist: whether to return the output distribution
    :type x: torch.Tensor
    :type t: torch.Tensor
    :type y: torch.Tensor | None
    :type mask: torch.Tensor | None
    :type return_output_dist: bool
    :returns: continuous-time loss; shape: () \n
              output distribution; shape: (n_b, n_t, n_vocab) or `None`
    :rtype: tuple
    """
    beta = self.calc_beta(t)  # shape: (n_b, 1, 1)
    e_x = nn.functional.one_hot(x, self.K).float()
    # sample the input distribution from the Bayesian flow at time t
    mu = beta * (self.K * e_x - 1)
    sigma = (beta * self.K).sqrt()
    theta = softmax(mu + sigma * torch.randn_like(mu), -1)
    if mask is not None:
        # cast so boolean masks work too (`1 - mask` is unsupported for bool)
        mask = mask[..., None].to(theta.dtype)
        theta = e_x * mask + (1 - mask) * theta
    e_hat = self.discrete_output_distribution(theta, t, y, None)
    loss = self.K * (e_x - e_hat).pow(2) * self.calc_cts_alpha(t)
    if return_output_dist:
        return loss.mean(), e_hat
    return loss.mean(), None
|
|
571
|
+
|
|
572
|
+
@torch.inference_mode()
def reconstruction_loss(self, x: Tensor, t: Tensor, y: Optional[Tensor]) -> Tensor:
    """
    Compute reconstruction loss, i.e. the mean negative log-probability that
    the output distribution assigns to the target tokens.

    :param x: target data; shape: (n_b, n_t)
    :param t: continuous time in [0, 1]; shape: (n_b, 1, 1)
    :param y: conditioning vector; shape: (n_b, 1, n_f)
    :type x: torch.Tensor
    :type t: torch.Tensor
    :type y: torch.Tensor | None
    :return: reconstruction loss; shape: ()
    :rtype: torch.Tensor
    """
    beta = self.calc_beta(t)
    mu = beta * (self.K * nn.functional.one_hot(x, self.K).float() - 1)
    sigma = (beta * self.K).sqrt()
    theta = softmax(mu + sigma * torch.randn_like(mu), -1)
    logits = self.forward(2 * theta - 1, t, None, y)
    # The network returns unnormalised logits; normalise them so the gathered
    # values are genuine log-probabilities.
    log_prob = logits.log_softmax(-1)
    # compute negative log probability
    x, log_prob = torch.broadcast_tensors(x[..., None], log_prob)
    return (-log_prob.gather(-1, x[..., :1]).squeeze(-1)).mean()
|
|
594
|
+
|
|
595
|
+
@torch.jit.export
def sample(
    self,
    batch_size: int,
    sequence_size: int,
    y: Optional[Tensor],
    sample_step: int = 100,
    guidance_strength: float = 4.0,
    token_mask: Optional[Tensor] = None,
) -> Tensor:
    """
    Sample from a prior distribution.

    Starts from a uniform input distribution and runs ``sample_step`` discrete
    Bayesian-update steps before taking the arg-max of the final output
    distribution.

    :param batch_size: batch size
    :param sequence_size: max sequence length
    :param y: conditioning vector; shape: (n_b, 1, n_f)
    :param sample_step: number of sampling steps
    :param guidance_strength: strength of conditional generation. It is not used if y is null.
    :param token_mask: token mask; shape: (1, 1, n_vocab); True entries are never sampled
    :type batch_size: int
    :type sequence_size: int
    :type y: torch.Tensor | None
    :type sample_step: int
    :type guidance_strength: float
    :type token_mask: torch.Tensor | None
    :return: sampled token indices; shape: (n_b, n_t)
    :rtype: torch.Tensor
    """
    device = self.beta.device
    theta = (
        torch.ones((batch_size, sequence_size, self.K), device=device)
        / self.K
    )
    if y is not None:
        assert y.dim() == 3  # this doesn't work if the model is frozen in JIT.
        if y.shape[0] == 1:
            y = y.repeat(batch_size, 1, 1)
    dt = 1 / sample_step
    for step in torch.linspace(1, sample_step, sample_step, device=device):
        t = (step - 1).view(1, 1, 1).repeat(batch_size, 1, 1) / sample_step
        probs = self.discrete_output_distribution(theta, t, y, guidance_strength)
        if token_mask is not None:
            probs = probs.masked_fill_(token_mask, 0.0)
        alpha = self.calc_discrete_alpha(t, t + dt)
        e_k = nn.functional.one_hot(torch.argmax(probs, -1), self.K).float()
        # Gaussian sample around the arg-max token with mean alpha * (K e_k - 1)
        # and standard deviation sqrt(alpha * K)
        mean = alpha * (self.K * e_k - 1)
        std = (alpha * self.K).sqrt()
        noise = torch.randn_like(mean)
        # multiplicative update of theta followed by renormalisation
        theta = torch.exp(mean + std * noise) * theta
        theta = theta / theta.sum(dim=-1, keepdim=True)
    t_final = torch.ones((batch_size, 1, 1), device=device)
    probs = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
    if token_mask is not None:
        probs = probs.masked_fill_(token_mask, 0.0)
    return torch.argmax(probs, -1)
|
|
647
|
+
|
|
648
|
+
@torch.jit.export
def ode_sample(
    self,
    batch_size: int,
    sequence_size: int,
    y: Optional[Tensor],
    sample_step: int = 100,
    guidance_strength: float = 4.0,
    token_mask: Optional[Tensor] = None,
    temperature: float = 0.5,
) -> Tensor:
    """
    ODE-based sampling.

    Keeps unnormalised logits ``z`` (initially zero, i.e. a uniform ``theta``).
    At each step it rebuilds ``z`` from the output distribution, scaled by
    beta(t + dt), plus temperature-scaled Gaussian noise.

    :param batch_size: batch size
    :param sequence_size: max sequence length
    :param y: conditioning vector; shape: (n_b, 1, n_f)
    :param sample_step: number of sampling steps
    :param guidance_strength: strength of conditional generation. It is not used if y is null.
    :param token_mask: token mask; shape: (1, 1, n_vocab); True entries are never sampled
    :param temperature: sampling temperature
    :type batch_size: int
    :type sequence_size: int
    :type y: torch.Tensor | None
    :type sample_step: int
    :type guidance_strength: float
    :type token_mask: torch.Tensor | None
    :type temperature: float
    :return: sampled token indices; shape: (n_b, n_t)
    :rtype: torch.Tensor
    """
    device = self.beta.device
    z = torch.zeros((batch_size, sequence_size, self.K), device=device)
    if y is not None:
        assert y.dim() == 3  # this doesn't work if the model is frozen in JIT.
        if y.shape[0] == 1:
            y = y.repeat(batch_size, 1, 1)
    dt = 1 / sample_step
    for step in torch.linspace(1, sample_step, sample_step, device=device):
        t = (step - 1).view(1, 1, 1).repeat(batch_size, 1, 1) / sample_step
        theta = torch.softmax(z, -1)
        beta_next = self.calc_beta(t + dt)
        probs = self.discrete_output_distribution(theta, t, y, guidance_strength)
        if token_mask is not None:
            probs = probs.masked_fill_(token_mask, 0.0)
        noise = torch.randn_like(z)
        drift = (self.K * probs - 1) * beta_next
        spread = (self.K * beta_next * temperature).sqrt()
        z = drift + spread * noise
    t_final = torch.ones((batch_size, 1, 1), device=device)
    theta = torch.softmax(z, -1)
    probs = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
    if token_mask is not None:
        probs = probs.masked_fill_(token_mask, 0.0)
    return torch.argmax(probs, -1)
|
|
699
|
+
|
|
700
|
+
@torch.jit.export
def inpaint(
    self,
    x: Tensor,
    y: Optional[Tensor] = None,
    sample_step: int = 100,
    guidance_strength: float = 4.0,
    token_mask: Optional[Tensor] = None,
) -> Tensor:
    """
    Molecule inpaint functionality.

    Tokens of ``x`` with a non-zero index are kept fixed as one-hot input
    distributions at every step; positions holding index 0 are generated with
    the same discrete Bayesian-update loop as :meth:`sample`.

    :param x: categorical indices of scaffold; shape: (n_b, n_t)
    :param y: conditioning vector; shape: (n_b, 1, n_f)
    :param sample_step: number of sampling steps
    :param guidance_strength: strength of conditional generation. It is not used if y is null.
    :param token_mask: token mask; shape: (1, 1, n_vocab); True entries are never sampled
    :type x: torch.Tensor
    :type y: torch.Tensor | None
    :type sample_step: int
    :type guidance_strength: float
    :type token_mask: torch.Tensor | None
    :return: sampled token indices; shape: (n_b, n_t)
    :rtype: torch.Tensor
    """
    n_b, n_t = x.shape
    device = x.device
    # 1.0 at known scaffold tokens (non-zero index), 0.0 at positions to fill
    known = (x != 0).float()[..., None]
    prior = torch.ones((n_b, n_t, self.K), device=device) / self.K
    x_onehot = nn.functional.one_hot(x, self.K) * known
    theta = x_onehot + (1 - known) * prior
    if y is not None:
        assert y.dim() == 3  # this doesn't work if the model is frozen in JIT.
        if y.shape[0] == 1:
            y = y.repeat(n_b, 1, 1)
    dt = 1 / sample_step
    for step in torch.linspace(1, sample_step, sample_step, device=device):
        t = (step - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
        probs = self.discrete_output_distribution(theta, t, y, guidance_strength)
        if token_mask is not None:
            probs = probs.masked_fill_(token_mask, 0.0)
        alpha = self.calc_discrete_alpha(t, t + dt)
        e_k = nn.functional.one_hot(torch.argmax(probs, -1), self.K).float()
        mean = alpha * (self.K * e_k - 1)
        std = (alpha * self.K).sqrt()
        noise = torch.randn_like(mean)
        theta = torch.exp(mean + std * noise) * theta
        theta = theta / theta.sum(dim=-1, keepdim=True)
        # re-impose the known scaffold tokens
        theta = x_onehot + (1 - known) * theta
    t_final = torch.ones((n_b, 1, 1), device=device)
    probs = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
    if token_mask is not None:
        probs = probs.masked_fill_(token_mask, 0.0)
    return torch.argmax(probs, -1)
|
|
751
|
+
|
|
752
|
+
@torch.jit.export
def ode_inpaint(
    self,
    x: Tensor,
    y: Optional[Tensor] = None,
    sample_step: int = 100,
    guidance_strength: float = 4.0,
    token_mask: Optional[Tensor] = None,
    temperature: float = 0.5,
) -> Tensor:
    """
    ODE inpainting.

    Positions where ``x == 0`` are generated; every other position is clamped
    to its one-hot token at each step, so scaffold tokens are kept.

    :param x: categorical indices of scaffold; shape: (n_b, n_t)
    :param y: conditioning vector; shape: (n_b, 1, n_f); a batch of 1 is repeated n_b times
    :param sample_step: number of sampling steps
    :param guidance_strength: strength of conditional generation. It is not used if y is null.
    :param token_mask: token mask; shape: (1, 1, n_vocab); entries set to True get probability 0
    :param temperature: sampling temperature (scales the noise variance)
    :type x: torch.Tensor
    :type y: torch.Tensor | None
    :type sample_step: int
    :type guidance_strength: float
    :type token_mask: torch.Tensor | None
    :type temperature: float
    :return: sampled token indices (argmax of the final output distribution);
             shape: (n_b, n_t)
    :rtype: torch.Tensor
    """
    n_b, n_t = x.shape
    # 1.0 on scaffold (non-zero) tokens, 0.0 on positions to be generated.
    mask = (x != 0).float()[..., None]
    x_onehot = nn.functional.one_hot(x, self.K) * mask
    z = torch.zeros((n_b, n_t, self.K), device=self.beta.device)
    if y is not None:
        assert y.dim() == 3  # this doesn't work if the model is frozen in JIT.
        if y.shape[0] == 1:
            y = y.repeat(n_b, 1, 1)
    for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
        t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
        theta = torch.softmax(z, -1)
        # Re-clamp the scaffold tokens so only masked positions evolve.
        theta = x_onehot + (1 - mask) * theta
        beta = self.calc_beta(t + 1 / sample_step)
        p = self.discrete_output_distribution(theta, t, y, guidance_strength)
        if token_mask is not None:
            p = p.masked_fill_(token_mask, 0.0)
        u = torch.randn_like(z)
        # The noise variance is scaled by `temperature`.
        z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
    t_final = torch.ones((n_b, 1, 1), device=self.beta.device)
    theta = torch.softmax(z, -1)
    theta = x_onehot + (1 - mask) * theta
    p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
    if token_mask is not None:
        p = p.masked_fill_(token_mask, 0.0)
    return torch.argmax(p, -1)
|
|
805
|
+
|
|
806
|
+
def inference(self, x: Tensor, mlp: nn.Module) -> Tensor:
    """
    Predict from SMILES tokens.

    The tokens are turned into a fully resolved input ``2 * one_hot(x) - 1``
    at time t = 1 and passed once through ``self.forward``. One feature vector
    per sequence is then fed to ``mlp``. When ``self.semi_autoregressive`` is
    set, that vector comes from the position(s) holding token 2. Otherwise it
    comes from the first position.

    :param x: input tokens; shape: (n_b, n_t)
    :param mlp: MLP module
    :type x: torch.Tensor
    :type mlp: torch.nn.Module
    :return: output values; shape: (n_b, n_task)
    :rtype: torch.Tensor
    """
    n_batch = x.shape[0]
    padding_mask = (x != 0).float()[..., None]
    one_hot = torch.nn.functional.one_hot(x, self.K).float()
    time = torch.ones((n_batch, 1, 1), device=x.device)
    features = self.forward(2 * one_hot - 1, time, padding_mask, None)
    if not self.semi_autoregressive:
        return mlp.forward(features[:, 0])
    picked = features[x == 2]
    return mlp.forward(picked.view(features.shape[0], -1))
|
|
824
|
+
|
|
825
|
+
@classmethod
def from_checkpoint(
    cls, ckpt: Union[str, Path], ckpt_lora: Union[str, Path, None] = None
) -> Self:
    """
    Load model weight from a checkpoint.

    The checkpoint must hold ``"nn"`` (a state dict) and ``"hparam"`` (a dict
    with ``num_vocab``, ``channel``, ``num_layer``, ``num_head`` and
    ``dropout``). Weights are loaded with ``strict=False``, so missing or
    unexpected keys are not reported. If ``ckpt_lora`` is given, LoRA is
    enabled with the stored ``"lora_param"`` keyword arguments and the
    ``"lora_nn"`` weights are loaded on top, also non-strictly.

    :param ckpt: checkpoint file
    :param ckpt_lora: LoRA checkpoint file which is optional
    :type ckpt: str | pathlib.Path
    :type ckpt_lora: str | pathlib.Path | None
    :return: Bayesian Flow Network for Chemistry model
             (an instance of the class this method is called on)
    :rtype: bayesianflow_for_chem.model.ChemBFN
    """
    with open(ckpt, "rb") as f:
        state = torch.load(f, "cpu", weights_only=True)
    # `weights` rather than `nn`, so the `torch.nn` alias is not shadowed.
    weights, hparam = state["nn"], state["hparam"]
    # Build through `cls` so subclasses get an instance of themselves (`-> Self`).
    model = cls(
        hparam["num_vocab"],
        hparam["channel"],
        hparam["num_layer"],
        hparam["num_head"],
        hparam["dropout"],
    )
    model.load_state_dict(weights, False)
    if ckpt_lora:
        with open(ckpt_lora, "rb") as g:
            lora_state = torch.load(g, "cpu", weights_only=True)
        lora_weights, lora_param = lora_state["lora_nn"], lora_state["lora_param"]
        model.enable_lora(**lora_param)
        model.load_state_dict(lora_weights, False)
    return model
|
|
857
|
+
|
|
858
|
+
|
|
859
|
+
class MLP(nn.Module):
    def __init__(
        self, size: List[int], class_input: bool = False, dropout: float = 0.0
    ) -> None:
        """
        MLP module.
        e.g.

        ```python
        mlp = MLP(size=[512, 256, 1])
        mlp = MLP(size=[10, 256, 512], class_input=True)  # embedding 10 classes
        ```

        Consecutive entries of ``size`` define the layers; all but the last
        layer are followed by SELU. With ``class_input=True`` the first layer
        is an ``nn.Embedding(size[0], size[1])``, dropout is disabled, and
        ``size`` must have at least 3 entries.

        :param size: hidden feature sizes
        :param class_input: whether the input is class indices
        :param dropout: dropout frequency
        :type size: list
        :type class_input: bool
        :type dropout: float
        :raises ValueError: if ``class_input`` is True and ``len(size) < 3``
        """
        super().__init__()
        assert len(size) >= 2
        if class_input and len(size) < 3:
            # Without this check `self.layers[0] = ...` below fails with an
            # opaque IndexError on an empty ModuleList.
            raise ValueError(
                "`class_input=True` requires at least 3 sizes: "
                "[n_class, n_embedding, ..., n_output]"
            )
        self.class_input = class_input
        self.dropout = nn.Dropout(dropout if not class_input else 0.0)
        # Hidden layers size[k] -> size[k + 1] for every k except the last pair.
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(size[:-2], size[1:-1])]
        )
        if class_input:
            self.layers[0] = nn.Embedding(size[0], size[1])
        self.layers.append(nn.Linear(size[-2], size[-1]))
        # Copy `size` so later mutation of the caller's list cannot alter hparam.
        self.hparam = dict(size=list(size), class_input=class_input, dropout=dropout)

    def forward(self, x: Tensor) -> Tensor:
        """
        :param x: input tensor; shape: (n_b, n_input)
        :return: output tensor; shape: (n_b, n_output) if not class_input;
                 (n_b, 1, n_output) if class_input
        :type x: torch.Tensor
        :rtype: torch.Tensor
        """
        x = self.dropout(x)
        if self.class_input:
            # Embedding lookups need integer indices.
            x = x.to(dtype=torch.long)
        for layer in self.layers[:-1]:
            x = torch.selu(layer(x))
        return self.layers[-1](x)

    @classmethod
    def from_checkpoint(cls, ckpt: Union[str, Path], strict: bool = True) -> Self:
        """
        Load model weight from a checkpoint.

        The checkpoint must hold ``"nn"`` (a state dict) and ``"hparam"``
        (the ``size``, ``class_input`` and ``dropout`` arguments).

        :param ckpt: checkpoint file
        :param strict: whether to strictly match `state_dict`
        :type ckpt: str | pathlib.Path
        :type strict: bool
        :return: MLP (an instance of the class this method is called on)
        :rtype: bayesianflow_for_chem.model.MLP
        """
        with open(ckpt, "rb") as f:
            state = torch.load(f, "cpu", weights_only=True)
        # `weights` rather than `nn`, so the `torch.nn` alias is not shadowed.
        weights, hparam = state["nn"], state["hparam"]
        model = cls(hparam["size"], hparam["class_input"], hparam["dropout"])
        model.load_state_dict(weights, strict)
        return model
|
|
924
|
+
|
|
925
|
+
|
|
926
|
+
if __name__ == "__main__":
    pass  # No script behaviour is defined for running this module directly.
|