bayesianflow-for-chem 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bayesianflow-for-chem might be problematic. Click here for more details.

@@ -0,0 +1,927 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Author: Nianze A. Tao (Omozawa Sueno)
3
+ """
4
+ Define Bayesian Flow Network for Chemistry (ChemBFN) model.
5
+ """
6
+ from pathlib import Path
7
+ from typing import List, Tuple, Optional, Union
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch import Tensor
11
+ from torch.nn.functional import softmax, linear, dropout
12
+ from typing_extensions import Self
13
+
14
+
15
class Linear(nn.Linear):
    # Modified from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
    # We made it simpler and compatible with both `loralib` and `TorchScript`.
    def __init__(self, in_features: int, out_features: int, bias: bool = True, **kwargs):
        """
        LoRA implemented in a dense layer.

        The layer behaves exactly like `torch.nn.Linear` until `enable_lora` is
        called, which adds the low-rank adapter and freezes the dense weight.

        :param in_features: number of input features
        :param out_features: number of output features
        :param bias: whether to use additional bias
        :param device: device (keyword argument forwarded to `torch.nn.Linear`)
        :param dtype: PyTorch data type (keyword argument forwarded to `torch.nn.Linear`)
        :type in_features: int
        :type out_features: int
        :type bias: bool
        :type device: torch.device | str | None
        :type dtype: torch.dtype
        """
        nn.Linear.__init__(self, in_features, out_features, bias, **kwargs)
        self.lora_enabled: bool = False
        self.lora_A: Optional[nn.Parameter] = None
        self.lora_B: Optional[nn.Parameter] = None
        self.scaling: Optional[float] = None
        self.lora_dropout: Optional[float] = None
        # `nn.Linear.__init__` has already initialised the parameters; this second
        # call is redundant but kept so seeded initialisation (RNG consumption)
        # stays identical to previous releases.
        nn.Linear.reset_parameters(self)

    def enable_lora(
        self, r: int = 8, lora_alpha: float = 1, lora_dropout: float = 0.0
    ) -> None:
        """
        Enable LoRA parameters.

        Creates `lora_A` (Kaiming-uniform) and `lora_B` (zeros), so the layer's
        output is unchanged right after enabling, and freezes `weight`.

        :param r: rank
        :param lora_alpha: LoRA alpha value; the adapter is scaled by `lora_alpha / r`
        :param lora_dropout: dropout frequency in LoRA layer
        :type r: int
        :type lora_alpha: float
        :type lora_dropout: float
        :return:
        :rtype: None
        """
        assert r > 0, "Rank should be larger than 0."
        self.lora_A = nn.Parameter(self.weight.new_zeros((r, self.in_features)))
        self.lora_B = nn.Parameter(self.weight.new_zeros((self.out_features, r)))
        self.scaling = lora_alpha / r
        # Always store a real float. `forward` checks `isinstance(..., float)`,
        # so an int such as `0` would otherwise silently switch LoRA off.
        self.lora_dropout = float(lora_dropout)
        self.lora_enabled = True
        nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)
        nn.init.zeros_(self.lora_B)
        self.weight.requires_grad_(False)

    def forward(self, x: Tensor) -> Tensor:
        """
        :param x: input tensor; shape: (..., in_features)
        :type x: torch.Tensor
        :return: dense output plus the scaled LoRA update when enabled;
                 shape: (..., out_features)
        :rtype: torch.Tensor
        """
        result = linear(x, self.weight, self.bias)
        # The isinstance check also narrows `Optional[float]` for TorchScript.
        if self.lora_enabled and isinstance(self.lora_dropout, float):
            result += (
                dropout(x, self.lora_dropout, self.training)
                @ self.lora_A.transpose(0, 1)
                @ self.lora_B.transpose(0, 1)
            ) * self.scaling
        return result
75
+
76
+
77
def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
    """
    Shift-and-scale modulation: ``x * (1 + scale) + shift`` (broadcasting).

    :param x: tensor to modulate
    :param shift: additive offset
    :param scale: multiplicative offset relative to 1
    :type x: torch.Tensor
    :type shift: torch.Tensor
    :type scale: torch.Tensor
    :return: modulated tensor
    :rtype: torch.Tensor
    """
    gain = 1 + scale
    return x * gain + shift
79
+
80
+
81
class RoPE(nn.Module):
    def __init__(self, channel: int = 512, num_head: int = 8) -> None:
        """
        Rotary position embedding block with XPOS method.

        The buffers `theta` (rotation frequencies) and `zeta` (XPOS scaling
        bases) both have shape (1, d) with d = channel // num_head. Each value
        is shared by a pair of adjacent features (2j, 2j + 1).

        :param channel: hidden layer features
        :param num_head: number of heads
        :type channel: int
        :type num_head: int
        """
        super().__init__()
        head_dim = channel // num_head
        assert head_dim % 2 == 0
        self.channel = channel
        ratio = torch.arange(0, head_dim, 2)[None, :] / head_dim
        # duplicate every entry: [a, b, ...] -> [a, a, b, b, ...]
        theta = torch.pow(10000, -ratio).repeat_interleave(2, dim=-1)
        zeta = ((ratio + 0.4) / 1.4).repeat_interleave(2, dim=-1)
        self.register_buffer("theta", theta)
        self.register_buffer("zeta", zeta)

    def forward(self, size: int) -> Tuple[Tensor, Tensor, Tensor]:
        """
        :param size: maximum length of sequence in the batch
        :type size: int
        :return: cos part of position encoding; shape: (1, 1, n_t, n_h) \n
                 sin part of position encoding; shape: (1, 1, n_t, n_h) \n
                 scaling coefficients; shape: (1, 1, n_t, n_h)
        :rtype: tuple
        """
        positions = torch.arange(size, device=self.theta.device)[:, None]
        angle = positions * self.theta
        decay = torch.pow(self.zeta, positions / self.channel)
        return (
            angle.cos()[None, None, ...],
            angle.sin()[None, None, ...],
            decay[None, None, ...],
        )
119
+
120
+
121
class Attention(nn.Module):
    def __init__(self, channel: int = 512, num_head: int = 8) -> None:
        """
        Multi-head self-attention block with rotary (XPOS) position encoding.

        :param channel: hidden layer features
        :param num_head: number of heads
        :type channel: int
        :type num_head: int
        """
        super().__init__()
        assert channel % num_head == 0
        self.d = channel // num_head  # features per head
        self.nh = num_head  # number of heads
        # attention temperature: scores are divided by sqrt(2 * d), not sqrt(d)
        self.tp = (2 * self.d) ** 0.5
        # fused projection producing [q | k | v] along the last axis
        self.qkv = Linear(channel, channel * 3)

    @staticmethod
    def _rotate(
        q: Tensor, k: Tensor, pe: Tuple[Tensor, Tensor, Tensor]
    ) -> Tuple[Tensor, Tensor]:
        cos, sin, scale = pe
        # pairwise rotation (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
        q_swapped = torch.stack([-q[..., 1::2], q[..., 0::2]], dim=-1).flatten(-2)
        k_swapped = torch.stack([-k[..., 1::2], k[..., 0::2]], dim=-1).flatten(-2)
        # XPOS: queries are multiplied and keys divided by the same factor
        q_out = (q * cos + q_swapped * sin) * scale
        k_out = (k * cos + k_swapped * sin) / scale
        return q_out, k_out

    def forward(
        self, x: Tensor, pe: Tuple[Tensor, Tensor, Tensor], mask: Optional[Tensor]
    ) -> Tensor:
        """
        :param x: input tensor; shape: (n_b, n_t, n_f)
        :param pe: position encoding; shape: (1, 1, n_t, n_h) * 3
        :param mask: attention mask handed to `scaled_dot_product_attention` as
                     `attn_mask` (boolean: True = may attend; float: added to
                     the scores); shape: (1, n_b, n_t, n_t)
        :type x: torch.Tensor
        :type pe: tuple
        :type mask: torch.Tensor | None
        :return: attentioned output; shape: (n_b, n_t, n_f)
        :rtype: torch.Tensor
        """
        n_b, n_t, n_f = x.shape
        # (n_b, n_t, 3 * n_f) -> (3, n_head, n_b, n_t, d). Heads lead and the
        # batch sits on axis 1, which is why the mask is (1, n_b, n_t, n_t).
        qkv = (
            self.qkv(x)
            .view(n_b, n_t, 3, self.nh, self.d)
            .permute(2, 3, 0, 1, 4)
            .contiguous()
        )
        q, k = self._rotate(qkv[0], qkv[1], pe)  # position embedding
        v = qkv[2]
        out = nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, scale=1 / self.tp
        )
        # (n_head, n_b, n_t, d) -> (n_b, n_t, n_head * d)
        return out.permute(1, 2, 0, 3).contiguous().view(n_b, n_t, n_f)

    def enable_lora(
        self, r: int = 4, lora_alpha: int = 1, lora_dropout: float = 0.0
    ) -> None:
        """
        Enable LoRA parameters on the fused q/k/v projection.

        :param r: rank
        :param lora_alpha: LoRA alpha value
        :param lora_dropout: dropout frequency in LoRA layer
        :type r: int
        :type lora_alpha: float
        :type lora_dropout: float
        :return:
        :rtype: None
        """
        self.qkv.enable_lora(r, lora_alpha, lora_dropout)
203
+
204
+
205
class TransformerLayer(nn.Module):
    def __init__(
        self, channel: int = 512, num_head: int = 8, dropout: float = 0.01
    ) -> None:
        """
        Transformer layer block conditioned through adaptive layer norm.

        :param channel: hidden layer features
        :param num_head: number of attention heads
        :param dropout: dropout frequency
        :type channel: int
        :type num_head: int
        :type dropout: float
        """
        super().__init__()
        hidden = channel * 4
        self.norm1 = nn.LayerNorm(channel, eps=1e-6, elementwise_affine=False)
        self.attention = Attention(channel, num_head)
        self.norm2 = nn.LayerNorm(channel, eps=1e-6, elementwise_affine=False)
        self.ffn = nn.Sequential(
            nn.Linear(channel, hidden),
            nn.SELU(),
            nn.Linear(hidden, channel),
            nn.Dropout(dropout),
        )
        self.adaln_modulation = nn.Sequential(nn.SELU(), Linear(channel, 6 * channel))
        # zero-out adaLN layer: all gates start at 0, so both residual branches
        # initially contribute nothing
        modulation = self.adaln_modulation[1]
        nn.init.zeros_(modulation.weight)
        nn.init.zeros_(modulation.bias)

    def forward(
        self,
        x: Tensor,
        pe: Tuple[Tensor, Tensor, Tensor],
        c: Tensor,
        mask: Optional[Tensor],
    ) -> Tensor:
        """
        :param x: input tensor; shape: (n_b, n_t, n_f)
        :param pe: position encoding; shape: (1, 1, n_t, n_h) * 3
        :param c: conditioning; shape: (n_b, 1, n_f)
        :param mask: attention mask; shape: (1, n_b, n_t, n_t)
        :type x: torch.Tensor
        :type pe: tuple
        :type c: torch.Tensor
        :type mask: torch.Tensor | None
        :return: output tensor; shape: (n_b, n_t, n_f)
        :rtype: torch.Tensor
        """
        mod = self.adaln_modulation(c)
        (
            attn_shift,
            attn_scale,
            attn_gate,
            ffn_shift,
            ffn_scale,
            ffn_gate,
        ) = mod.chunk(6, -1)
        attn_in = modulate(self.norm1(x), attn_shift, attn_scale)
        x = x + attn_gate * self.attention(attn_in, pe, mask)
        ffn_in = modulate(self.norm2(x), ffn_shift, ffn_scale)
        x = x + ffn_gate * self.ffn(ffn_in)
        return x

    def enable_lora(
        self, r: int = 4, lora_alpha: int = 1, lora_dropout: float = 0.0
    ) -> None:
        """
        Enable LoRA parameters on the attention projection and the adaLN layer.

        The feed-forward network uses plain `nn.Linear` and is not adapted.

        :param r: rank
        :param lora_alpha: LoRA alpha value
        :param lora_dropout: dropout frequency in LoRA layer
        :type r: int
        :type lora_alpha: float
        :type lora_dropout: float
        :return:
        :rtype: None
        """
        self.attention.enable_lora(r, lora_alpha, lora_dropout)
        self.adaln_modulation[1].enable_lora(r, lora_alpha, lora_dropout)
276
+
277
+
278
class FinalLayer(nn.Module):
    def __init__(self, num_vocab: int, channel: int = 512) -> None:
        """
        The final layer of model: adaptive layer norm followed by a vocabulary head.

        :param num_vocab: number of vocabulary
        :param channel: hidden layer features
        :type num_vocab: int
        :type channel: int
        """
        super().__init__()
        self.norm_final = nn.LayerNorm(channel, eps=1e-6, elementwise_affine=False)
        self.linear = Linear(channel, num_vocab)
        self.adaln_modulation = nn.Sequential(nn.SELU(), Linear(channel, 2 * channel))
        # zero-out both the output head and the modulation layer
        for layer in (self.linear, self.adaln_modulation[-1]):
            nn.init.zeros_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x: Tensor, c: Tensor, return_logits: bool = True) -> Tensor:
        """
        :param x: input tensor; shape: (n_b, n_t, n_f)
        :param c: conditioning; shape: (n_b, 1, n_f)
        :param return_logits: whether to return unnormalised output logits
        :type x: torch.Tensor
        :type c: torch.Tensor
        :type return_logits: bool
        :return: output logits (unnormalised); shape: (n_b, n_t, n_vocab)
                 or token embeddings; shape: (n_b, n_t, n_f)
        :rtype: torch.Tensor
        """
        shift, scale = self.adaln_modulation(c).chunk(2, -1)
        h = modulate(self.norm_final(x), shift, scale)
        return self.linear(h) if return_logits else h

    def enable_lora(
        self, r: int = 4, lora_alpha: int = 1, lora_dropout: float = 0.0
    ) -> None:
        """
        Enable LoRA parameters on the output head and the adaLN layer.

        :param r: rank
        :param lora_alpha: LoRA alpha value
        :param lora_dropout: dropout frequency in LoRA layer
        :type r: int
        :type lora_alpha: float
        :type lora_dropout: float
        :return:
        :rtype: None
        """
        self.linear.enable_lora(r, lora_alpha, lora_dropout)
        self.adaln_modulation[1].enable_lora(r, lora_alpha, lora_dropout)
333
+
334
+
335
class ChemBFN(nn.Module):
    def __init__(
        self,
        num_vocab: int,
        channel: int = 512,
        num_layer: int = 12,
        num_head: int = 8,
        dropout: float = 0.01,
    ) -> None:
        r"""
        Bayesian Flow Network for Chemistry model representation.

        Enable semi-autoregressive sampling by setting
        `ChemBFN(...).semi_autoregressive = True`; the transformer then uses a
        causal (lower-triangular) attention mask.

        :param num_vocab: number of vocabulary
        :param channel: hidden layer features
        :param num_layer: number of transformer layers
        :param num_head: number of heads
        :param dropout: dropout frequency
        :type num_vocab: int
        :type channel: int
        :type num_layer: int
        :type num_head: int
        :type dropout: float
        """
        super().__init__()
        self.K = num_vocab
        self.lora_enabled: bool = False
        self.semi_autoregressive: bool = False
        self.embedding = Linear(num_vocab, channel)
        self.time_embed = nn.Sequential(
            nn.Linear(1, channel // 2), nn.SELU(), nn.Linear(channel // 2, channel)
        )
        self.position = RoPE(channel, num_head)
        self.encoder_layers = nn.ModuleList(
            [TransformerLayer(channel, num_head, dropout) for _ in range(num_layer)]
        )
        self.final_layer = FinalLayer(num_vocab, channel)
        # beta(1): `calc_beta(1)` evaluates to exactly this value
        self.register_buffer("beta", torch.scalar_tensor(20.4054 / self.K))
        # constructor arguments; `from_checkpoint` rebuilds the model from a
        # dict with these keys (presumably saved from this attribute)
        self.hparam = dict(
            num_vocab=num_vocab,
            channel=channel,
            num_layer=num_layer,
            num_head=num_head,
            dropout=dropout,
        )
        # keyword arguments of the last `enable_lora` call (empty until then)
        self.lora_param = {}

    def enable_lora(
        self, r: int = 4, lora_alpha: float = 1, lora_dropout: float = 0.0
    ) -> None:
        """
        Enable LoRA parameters in the embedding, every transformer layer and the
        final layer.

        :param r: rank
        :param lora_alpha: LoRA alpha value
        :param lora_dropout: dropout frequency in LoRA layer
        :type r: int
        :type lora_alpha: float
        :type lora_dropout: float
        :return:
        :rtype: None
        """
        self.lora_enabled = True
        self.lora_param = dict(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
        self.embedding.enable_lora(r, lora_alpha, lora_dropout)
        for layer in self.encoder_layers:
            layer.enable_lora(r, lora_alpha, lora_dropout)
        self.final_layer.enable_lora(r, lora_alpha, lora_dropout)

    def forward(
        self,
        x: Tensor,
        t: Tensor,
        mask: Optional[Tensor] = None,
        y: Optional[Tensor] = None,
    ) -> Tensor:
        """
        :param x: input probabilities; shape: (n_b, n_t, n_vocab)
        :param t: time; shape: (n_b, 1, 1)
        :param mask: input mask; shape: (n_b, n_t, 1)
        :param y: conditioning vector; shape: (n_b, 1, n_f)
        :type x: torch.Tensor
        :type t: torch.Tensor
        :type mask: torch.Tensor | None
        :type y: torch.Tensor | None
        :return: probability distribution (before softmax); shape: (n_b, n_t, n_vocab)
                 when `mask` is None, otherwise token embeddings; shape: (n_b, n_t, n_f)
        :rtype: torch.Tensor
        """
        n_b, n_t, _ = x.shape
        c = self.time_embed(t)
        if y is not None:
            # out-of-place: never modifies the time embedding in place and lets
            # `y` broadcast over `c`
            c = c + y
        pe = self.position(n_t)
        x = self.embedding(x)
        attn_mask: Optional[Tensor] = None
        if self.semi_autoregressive:
            # Boolean causal mask (True = may attend). `scaled_dot_product_attention`
            # *adds* float masks to the scores, so a 0/1 float tril would not
            # mask anything.
            attn_mask = (
                torch.tril(
                    torch.ones((1, n_b, n_t, n_t), device=self.beta.device), diagonal=0
                )
                != 0
            )
        elif mask is not None:
            # key-padding mask broadcast over the queries:
            # (n_b, n_t, 1) -> (1, n_b, n_t, n_t); True = may attend
            attn_mask = mask.transpose(-2, -1).repeat(1, n_t, 1)[None, ...] != 0
        for layer in self.encoder_layers:
            x = layer(x, pe, c, attn_mask)
        return self.final_layer(x, c, mask is None)

    def calc_beta(self, t: Tensor) -> Tensor:
        r"""
        Calculate beta(t) value.

        .. math::
            ```
            \begin{equation}
                \beta(t) = %
                -\frac{4\ln{(1 - t + te^{-\frac{K}{4}\beta(1)})}}{K}
            \end{equation}
            ```

        :param t: continuous time in [0, 1]; shape: (n_b, 1, 1)
        :type t: torch.Tensor
        :return: beta(t); shape: (n_b, 1, 1)
        :rtype: torch.Tensor
        """
        return -4 * (1 - t + t * (-self.K * self.beta / 4).exp()).log() / self.K

    def calc_discrete_alpha(self, t1: Tensor, t2: Tensor) -> Tensor:
        r"""
        Calculate alpha(i) value.

        .. math:: $\alpha(i) = \beta(t_{i}) - \beta(t_{i - 1})$

        :param t1: discrete time (i - 1) / n; shape: (n_b, 1, 1)
        :param t2: discrete time i / n; shape: (n_b, 1, 1)
        :type t1: torch.Tensor
        :type t2: torch.Tensor
        :return: alpha(i); shape: (n_b, 1, 1)
        :rtype: torch.Tensor
        """
        # assert t2 > t1
        return self.calc_beta(t2) - self.calc_beta(t1)

    def calc_cts_alpha(self, t: Tensor) -> Tensor:
        r"""
        Calculate alpha(t) / 2 value.

        .. math::
            ```
            \begin{equation}
                \alpha(t) = %
                \frac{d\beta(t)}{dt} = %
                \frac{4}{K}%
                \frac{1 - e^{-\frac{K}{4}\beta(1)}}%
                {1 - t + te^{-\frac{K}{4}\beta(1)}}
            \end{equation}
            ```

        :param t: continuous time in [0, 1]; shape: (n_b, 1, 1)
        :type t: torch.Tensor
        :return: alpha(t) / 2, the weight used by `cts_loss`; shape: (n_b, 1, 1)
        :rtype: torch.Tensor
        """
        a = 1 - (-self.K * self.beta / 4).exp()
        b = 1 - t + t * (-self.K * self.beta / 4).exp()
        # 2 / K instead of 4 / K: this returns half of d(beta)/dt
        return 2 * a / b / self.K

    def discrete_output_distribution(
        self, theta: Tensor, t: Tensor, y: Optional[Tensor], w: Optional[float]
    ) -> Tensor:
        """
        Output distribution, optionally with classifier-free guidance.

        With both `y` and `w` given, the conditional and unconditional logits
        are mixed as `(1 + w) * cond - w * uncond` before the softmax.

        :param theta: input distribution; shape: (n_b, n_t, n_vocab)
        :param t: continuous time in [0, 1]; shape: (n_b, 1, 1)
        :param y: conditioning vector; shape: (n_b, 1, n_f)
        :param w: guidance strength controlling the conditional generation
        :type theta: torch.Tensor
        :type t: torch.Tensor
        :type y: torch.Tensor | None
        :type w: float | None
        :return: output distribution; shape: (n_b, n_t, n_vocab)
        :rtype: torch.Tensor
        """
        theta = 2 * theta - 1  # rescale to [-1, 1]
        if w is None:
            return softmax(self.forward(theta, t, None, y), -1)
        elif y is None:
            return softmax(self.forward(theta, t, None, None), -1)
        else:
            p_cond = self.forward(theta, t, None, y)
            p_uncond = self.forward(theta, t, None, None)
            return softmax((1 + w) * p_cond - w * p_uncond, -1)

    def cts_loss(
        self,
        x: Tensor,
        t: Tensor,
        y: Optional[Tensor],
        mask: Optional[Tensor] = None,
        return_output_dist: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Compute continuous-time loss.

        :param x: target data; shape: (n_b, n_t)
        :param t: continuous time in [0, 1); shape: (n_b, 1, 1)
        :param y: conditioning vector; shape: (n_b, 1, n_f)
        :param mask: in-text mask; shape: (n_b, n_t); positions where it is 1
                     receive the ground-truth one-hot instead of a noisy sample
        :param return_output_dist: whether to return the output distribution
        :type x: torch.Tensor
        :type t: torch.Tensor
        :type y: torch.Tensor | None
        :type mask: torch.Tensor | None
        :type return_output_dist: bool
        :returns: continuous-time loss; shape: () \n
                  output distribution; shape: (n_b, n_t, n_vocab) or `None`
        :rtype: tuple
        """
        beta = self.calc_beta(t)  # shape: (n_b, 1, 1)
        e_x = nn.functional.one_hot(x, self.K).float()
        mu = beta * (self.K * e_x - 1)
        sigma = (beta * self.K).sqrt()
        theta = softmax(mu + sigma * torch.randn_like(mu), -1)
        if mask is not None:
            mask = mask[..., None]
            theta = e_x * mask + (1 - mask) * theta
        e_hat = self.discrete_output_distribution(theta, t, y, None)
        cts_loss = self.K * (e_x - e_hat).pow(2) * self.calc_cts_alpha(t)
        if return_output_dist:
            return cts_loss.mean(), e_hat
        return cts_loss.mean(), None

    @torch.inference_mode()
    def reconstruction_loss(self, x: Tensor, t: Tensor, y: Optional[Tensor]) -> Tensor:
        """
        Compute reconstruction loss (mean negative log-likelihood of `x`).

        :param x: target data; shape: (n_b, n_t)
        :param t: continuous time in [0, 1]; shape: (n_b, 1, 1)
        :param y: conditioning vector; shape: (n_b, 1, n_f)
        :type x: torch.Tensor
        :type t: torch.Tensor
        :type y: torch.Tensor | None
        :return: reconstruction loss; shape: ()
        :rtype: torch.Tensor
        """
        beta = self.calc_beta(t)
        mu = beta * (self.K * nn.functional.one_hot(x, self.K).float() - 1)
        sigma = (beta * self.K).sqrt()
        theta = softmax(mu + sigma * torch.randn_like(mu), -1)
        logits = self.forward(2 * theta - 1, t, None, y)
        # `forward` returns unnormalised logits: normalise them before taking
        # the negative log probability of the target tokens
        log_p = logits.log_softmax(-1)
        x, log_p = torch.broadcast_tensors(x[..., None], log_p)
        return (-log_p.gather(-1, x[..., :1]).squeeze(-1)).mean()

    @torch.jit.export
    def sample(
        self,
        batch_size: int,
        sequence_size: int,
        y: Optional[Tensor],
        sample_step: int = 100,
        guidance_strength: float = 4.0,
        token_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Sample from a prior distribution.

        :param batch_size: batch size
        :param sequence_size: max sequence length
        :param y: conditioning vector; shape: (n_b, 1, n_f)
        :param sample_step: number of sampling steps
        :param guidance_strength: strength of conditional generation. It is not used if y is null.
        :param token_mask: token mask (True = forbidden token); shape: (1, 1, n_vocab)
        :type batch_size: int
        :type sequence_size: int
        :type y: torch.Tensor | None
        :type sample_step: int
        :type guidance_strength: float
        :type token_mask: torch.Tensor | None
        :return: sampled token indices; shape: (n_b, n_t)
        :rtype: torch.Tensor
        """
        theta = (
            torch.ones((batch_size, sequence_size, self.K), device=self.beta.device)
            / self.K
        )
        if y is not None:
            assert y.dim() == 3  # this doesn't work if the model is frozen in JIT.
            if y.shape[0] == 1:
                y = y.repeat(batch_size, 1, 1)
        for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
            t = (i - 1).view(1, 1, 1).repeat(batch_size, 1, 1) / sample_step
            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
            if token_mask is not None:
                p = p.masked_fill_(token_mask, 0.0)
            alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
            e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
            mu = alpha * (self.K * e_k - 1)
            sigma = (alpha * self.K).sqrt()
            # Bayesian update of the input distribution with a noisy sample
            theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
            theta = theta / theta.sum(-1, True)
        t_final = torch.ones((batch_size, 1, 1), device=self.beta.device)
        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
        if token_mask is not None:
            p = p.masked_fill_(token_mask, 0.0)
        return torch.argmax(p, -1)

    @torch.jit.export
    def ode_sample(
        self,
        batch_size: int,
        sequence_size: int,
        y: Optional[Tensor],
        sample_step: int = 100,
        guidance_strength: float = 4.0,
        token_mask: Optional[Tensor] = None,
        temperature: float = 0.5,
    ) -> Tensor:
        """
        ODE-based sampling.

        :param batch_size: batch size
        :param sequence_size: max sequence length
        :param y: conditioning vector; shape: (n_b, 1, n_f)
        :param sample_step: number of sampling steps
        :param guidance_strength: strength of conditional generation. It is not used if y is null.
        :param token_mask: token mask (True = forbidden token); shape: (1, 1, n_vocab)
        :param temperature: sampling temperature
        :type batch_size: int
        :type sequence_size: int
        :type y: torch.Tensor | None
        :type sample_step: int
        :type guidance_strength: float
        :type token_mask: torch.Tensor | None
        :type temperature: float
        :return: sampled token indices; shape: (n_b, n_t)
        :rtype: torch.Tensor
        """
        z = torch.zeros((batch_size, sequence_size, self.K), device=self.beta.device)
        if y is not None:
            assert y.dim() == 3  # this doesn't work if the model is frozen in JIT.
            if y.shape[0] == 1:
                y = y.repeat(batch_size, 1, 1)
        for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
            t = (i - 1).view(1, 1, 1).repeat(batch_size, 1, 1) / sample_step
            theta = torch.softmax(z, -1)
            beta = self.calc_beta(t + 1 / sample_step)
            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
            if token_mask is not None:
                p = p.masked_fill_(token_mask, 0.0)
            u = torch.randn_like(z)
            z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
        t_final = torch.ones((batch_size, 1, 1), device=self.beta.device)
        theta = torch.softmax(z, -1)
        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
        if token_mask is not None:
            p = p.masked_fill_(token_mask, 0.0)
        return torch.argmax(p, -1)

    @torch.jit.export
    def inpaint(
        self,
        x: Tensor,
        y: Optional[Tensor] = None,
        sample_step: int = 100,
        guidance_strength: float = 4.0,
        token_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Molecule inpaint functionality.

        Non-zero entries of `x` are kept fixed; positions holding index 0 are
        generated.

        :param x: categorical indices of scaffold; shape: (n_b, n_t)
        :param y: conditioning vector; shape: (n_b, 1, n_f)
        :param sample_step: number of sampling steps
        :param guidance_strength: strength of conditional generation. It is not used if y is null.
        :param token_mask: token mask (True = forbidden token); shape: (1, 1, n_vocab)
        :type x: torch.Tensor
        :type y: torch.Tensor | None
        :type sample_step: int
        :type guidance_strength: float
        :type token_mask: torch.Tensor | None
        :return: sampled token indices; shape: (n_b, n_t)
        :rtype: torch.Tensor
        """
        n_b, n_t = x.shape
        mask = (x != 0).float()[..., None]
        theta = torch.ones((n_b, n_t, self.K), device=x.device) / self.K
        x_onehot = nn.functional.one_hot(x, self.K) * mask
        theta = x_onehot + (1 - mask) * theta
        if y is not None:
            assert y.dim() == 3  # this doesn't work if the model is frozen in JIT.
            if y.shape[0] == 1:
                y = y.repeat(n_b, 1, 1)
        for i in torch.linspace(1, sample_step, sample_step, device=x.device):
            t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
            if token_mask is not None:
                p = p.masked_fill_(token_mask, 0.0)
            alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)
            e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
            mu = alpha * (self.K * e_k - 1)
            sigma = (alpha * self.K).sqrt()
            theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
            theta = theta / theta.sum(-1, True)
            # re-impose the scaffold after every update
            theta = x_onehot + (1 - mask) * theta
        t_final = torch.ones((n_b, 1, 1), device=x.device)
        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
        if token_mask is not None:
            p = p.masked_fill_(token_mask, 0.0)
        return torch.argmax(p, -1)

    @torch.jit.export
    def ode_inpaint(
        self,
        x: Tensor,
        y: Optional[Tensor] = None,
        sample_step: int = 100,
        guidance_strength: float = 4.0,
        token_mask: Optional[Tensor] = None,
        temperature: float = 0.5,
    ) -> Tensor:
        """
        ODE inpainting.

        Non-zero entries of `x` are kept fixed; positions holding index 0 are
        generated.

        :param x: categorical indices of scaffold; shape: (n_b, n_t)
        :param y: conditioning vector; shape: (n_b, 1, n_f)
        :param sample_step: number of sampling steps
        :param guidance_strength: strength of conditional generation. It is not used if y is null.
        :param token_mask: token mask (True = forbidden token); shape: (1, 1, n_vocab)
        :param temperature: sampling temperature
        :type x: torch.Tensor
        :type y: torch.Tensor | None
        :type sample_step: int
        :type guidance_strength: float
        :type token_mask: torch.Tensor | None
        :type temperature: float
        :return: sampled token indices; shape: (n_b, n_t)
        :rtype: torch.Tensor
        """
        n_b, n_t = x.shape
        mask = (x != 0).float()[..., None]
        x_onehot = nn.functional.one_hot(x, self.K) * mask
        z = torch.zeros((n_b, n_t, self.K), device=self.beta.device)
        if y is not None:
            assert y.dim() == 3  # this doesn't work if the model is frozen in JIT.
            if y.shape[0] == 1:
                y = y.repeat(n_b, 1, 1)
        for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
            t = (i - 1).view(1, 1, 1).repeat(n_b, 1, 1) / sample_step
            theta = torch.softmax(z, -1)
            theta = x_onehot + (1 - mask) * theta
            beta = self.calc_beta(t + 1 / sample_step)
            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
            if token_mask is not None:
                p = p.masked_fill_(token_mask, 0.0)
            u = torch.randn_like(z)
            z = (self.K * p - 1) * beta + (self.K * beta * temperature).sqrt() * u
        t_final = torch.ones((n_b, 1, 1), device=self.beta.device)
        theta = torch.softmax(z, -1)
        theta = x_onehot + (1 - mask) * theta
        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
        if token_mask is not None:
            p = p.masked_fill_(token_mask, 0.0)
        return torch.argmax(p, -1)

    def inference(self, x: Tensor, mlp: nn.Module) -> Tensor:
        """
        Predict from SMILES tokens.

        Token embeddings are computed at t = 1 with padding (index 0) masked out.
        A single position per sequence is then fed to `mlp`:
        - the first position, normally;
        - the position(s) where `x == 2` in semi-autoregressive mode
          (presumably the end token; the `view` assumes exactly one such
          position per row — TODO confirm).

        :param x: input tokens; shape: (n_b, n_t)
        :param mlp: MLP module
        :type x: torch.Tensor
        :type mlp: torch.nn.Module
        :return: output values; shape: (n_b, n_task)
        :rtype: torch.Tensor
        """
        t = torch.ones((x.shape[0], 1, 1), device=x.device)
        mask = (x != 0).float()[..., None]
        theta = 2 * torch.nn.functional.one_hot(x, self.K).float() - 1
        z = self.forward(theta, t, mask, None)
        if self.semi_autoregressive:
            return mlp.forward(z[x == 2].view(z.shape[0], -1))
        return mlp.forward(z[::, 0])

    @classmethod
    def from_checkpoint(
        cls, ckpt: Union[str, Path], ckpt_lora: Union[str, Path, None] = None
    ) -> Self:
        """
        Load model weight from a checkpoint.

        :param ckpt: checkpoint file
        :param ckpt_lora: LoRA checkpoint file which is optional
        :type ckpt: str | pathlib.Path
        :type ckpt_lora: str | pathlib.Path | None
        :return: Bayesian Flow Network for Chemistry model (an instance of `cls`)
        :rtype: bayesianflow_for_chem.model.ChemBFN
        """
        with open(ckpt, "rb") as f:
            state = torch.load(f, "cpu", weights_only=True)
        # named `nn_state` so the `torch.nn` module alias is not shadowed
        nn_state, hparam = state["nn"], state["hparam"]
        model = cls(
            hparam["num_vocab"],
            hparam["channel"],
            hparam["num_layer"],
            hparam["num_head"],
            hparam["dropout"],
        )
        # strict=False: missing/unexpected keys are ignored rather than raising
        model.load_state_dict(nn_state, False)
        if ckpt_lora:
            with open(ckpt_lora, "rb") as g:
                lora_state = torch.load(g, "cpu", weights_only=True)
            lora_nn, lora_param = lora_state["lora_nn"], lora_state["lora_param"]
            model.enable_lora(**lora_param)
            model.load_state_dict(lora_nn, False)
        return model
857
+
858
+
859
+ class MLP(nn.Module):
860
+ def __init__(
861
+ self, size: List[int], class_input: bool = False, dropout: float = 0.0
862
+ ) -> None:
863
+ """
864
+ MLP module.
865
+ e.g.
866
+
867
+ ```python
868
+ mlp = MLP(size=[512, 256, 1])
869
+ mlp = MLP(size=[10, 256, 512], True) # embedding 10 classes
870
+ ```
871
+
872
+ :param size: hidden feature sizes
873
+ :param class_input: whether the input is class indices
874
+ :param dropout: dropout frequency
875
+ :type size: list
876
+ :type class_input: bool
877
+ :type dropout: float
878
+ """
879
+ super().__init__()
880
+ assert len(size) >= 2
881
+ self.class_input = class_input
882
+ self.dropout = nn.Dropout(dropout if not class_input else 0.0)
883
+ self.layers = nn.ModuleList(
884
+ [nn.Linear(i, size[key + 1]) for key, i in enumerate(size[:-2])]
885
+ )
886
+ if class_input:
887
+ self.layers[0] = nn.Embedding(size[0], size[1])
888
+ self.layers.append(nn.Linear(size[-2], size[-1]))
889
+ self.hparam = dict(size=size, class_input=class_input, dropout=dropout)
890
+
891
+ def forward(self, x: Tensor) -> Tensor:
892
+ """
893
+ :param x: input tensor; shape: (n_b, n_input)
894
+ :return: output tensor; shape: (n_b, n_output) if not class_input;
895
+ (n_b, 1, n_output) if class_input
896
+ :type x: torch.Tensor
897
+ :rtype: torch.Tensor
898
+ """
899
+ x = self.dropout(x)
900
+ if self.class_input:
901
+ x = x.to(dtype=torch.long)
902
+ for layer in self.layers[:-1]:
903
+ x = torch.selu(layer(x))
904
+ return self.layers[-1](x)
905
+
906
+ @classmethod
907
+ def from_checkpoint(cls, ckpt: Union[str, Path], strict: bool = True) -> Self:
908
+ """
909
+ Load model weight from a checkpoint.
910
+
911
+ :param ckpt: checkpoint file
912
+ :param strict: whether to strictly match `state_dict`
913
+ :type ckpt: str | pathlib.Path
914
+ :type strict: bool
915
+ :return: MLP
916
+ :rtype: bayesianflow_for_chem.model.MLP
917
+ """
918
+ with open(ckpt, "rb") as f:
919
+ state = torch.load(f, "cpu", weights_only=True)
920
+ nn, hparam = state["nn"], state["hparam"]
921
+ model = MLP(hparam["size"], hparam["class_input"], hparam["dropout"])
922
+ model.load_state_dict(nn, strict)
923
+ return model
924
+
925
+
926
if __name__ == "__main__":
    # nothing to run when executed as a script
    pass