boltz-vsynthes 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltz/__init__.py +7 -0
- boltz/data/__init__.py +0 -0
- boltz/data/const.py +1184 -0
- boltz/data/crop/__init__.py +0 -0
- boltz/data/crop/affinity.py +164 -0
- boltz/data/crop/boltz.py +296 -0
- boltz/data/crop/cropper.py +45 -0
- boltz/data/feature/__init__.py +0 -0
- boltz/data/feature/featurizer.py +1230 -0
- boltz/data/feature/featurizerv2.py +2208 -0
- boltz/data/feature/symmetry.py +602 -0
- boltz/data/filter/__init__.py +0 -0
- boltz/data/filter/dynamic/__init__.py +0 -0
- boltz/data/filter/dynamic/date.py +76 -0
- boltz/data/filter/dynamic/filter.py +24 -0
- boltz/data/filter/dynamic/max_residues.py +37 -0
- boltz/data/filter/dynamic/resolution.py +34 -0
- boltz/data/filter/dynamic/size.py +38 -0
- boltz/data/filter/dynamic/subset.py +42 -0
- boltz/data/filter/static/__init__.py +0 -0
- boltz/data/filter/static/filter.py +26 -0
- boltz/data/filter/static/ligand.py +37 -0
- boltz/data/filter/static/polymer.py +299 -0
- boltz/data/module/__init__.py +0 -0
- boltz/data/module/inference.py +307 -0
- boltz/data/module/inferencev2.py +429 -0
- boltz/data/module/training.py +684 -0
- boltz/data/module/trainingv2.py +660 -0
- boltz/data/mol.py +900 -0
- boltz/data/msa/__init__.py +0 -0
- boltz/data/msa/mmseqs2.py +235 -0
- boltz/data/pad.py +84 -0
- boltz/data/parse/__init__.py +0 -0
- boltz/data/parse/a3m.py +134 -0
- boltz/data/parse/csv.py +100 -0
- boltz/data/parse/fasta.py +138 -0
- boltz/data/parse/mmcif.py +1239 -0
- boltz/data/parse/mmcif_with_constraints.py +1607 -0
- boltz/data/parse/schema.py +1851 -0
- boltz/data/parse/yaml.py +68 -0
- boltz/data/sample/__init__.py +0 -0
- boltz/data/sample/cluster.py +283 -0
- boltz/data/sample/distillation.py +57 -0
- boltz/data/sample/random.py +39 -0
- boltz/data/sample/sampler.py +49 -0
- boltz/data/tokenize/__init__.py +0 -0
- boltz/data/tokenize/boltz.py +195 -0
- boltz/data/tokenize/boltz2.py +396 -0
- boltz/data/tokenize/tokenizer.py +24 -0
- boltz/data/types.py +777 -0
- boltz/data/write/__init__.py +0 -0
- boltz/data/write/mmcif.py +305 -0
- boltz/data/write/pdb.py +171 -0
- boltz/data/write/utils.py +23 -0
- boltz/data/write/writer.py +330 -0
- boltz/main.py +1292 -0
- boltz/model/__init__.py +0 -0
- boltz/model/layers/__init__.py +0 -0
- boltz/model/layers/attention.py +132 -0
- boltz/model/layers/attentionv2.py +111 -0
- boltz/model/layers/confidence_utils.py +231 -0
- boltz/model/layers/dropout.py +34 -0
- boltz/model/layers/initialize.py +100 -0
- boltz/model/layers/outer_product_mean.py +98 -0
- boltz/model/layers/pair_averaging.py +135 -0
- boltz/model/layers/pairformer.py +337 -0
- boltz/model/layers/relative.py +58 -0
- boltz/model/layers/transition.py +78 -0
- boltz/model/layers/triangular_attention/__init__.py +0 -0
- boltz/model/layers/triangular_attention/attention.py +189 -0
- boltz/model/layers/triangular_attention/primitives.py +409 -0
- boltz/model/layers/triangular_attention/utils.py +380 -0
- boltz/model/layers/triangular_mult.py +212 -0
- boltz/model/loss/__init__.py +0 -0
- boltz/model/loss/bfactor.py +49 -0
- boltz/model/loss/confidence.py +590 -0
- boltz/model/loss/confidencev2.py +621 -0
- boltz/model/loss/diffusion.py +171 -0
- boltz/model/loss/diffusionv2.py +134 -0
- boltz/model/loss/distogram.py +48 -0
- boltz/model/loss/distogramv2.py +105 -0
- boltz/model/loss/validation.py +1025 -0
- boltz/model/models/__init__.py +0 -0
- boltz/model/models/boltz1.py +1286 -0
- boltz/model/models/boltz2.py +1249 -0
- boltz/model/modules/__init__.py +0 -0
- boltz/model/modules/affinity.py +223 -0
- boltz/model/modules/confidence.py +481 -0
- boltz/model/modules/confidence_utils.py +181 -0
- boltz/model/modules/confidencev2.py +495 -0
- boltz/model/modules/diffusion.py +844 -0
- boltz/model/modules/diffusion_conditioning.py +116 -0
- boltz/model/modules/diffusionv2.py +677 -0
- boltz/model/modules/encoders.py +639 -0
- boltz/model/modules/encodersv2.py +565 -0
- boltz/model/modules/transformers.py +322 -0
- boltz/model/modules/transformersv2.py +261 -0
- boltz/model/modules/trunk.py +688 -0
- boltz/model/modules/trunkv2.py +828 -0
- boltz/model/modules/utils.py +303 -0
- boltz/model/optim/__init__.py +0 -0
- boltz/model/optim/ema.py +389 -0
- boltz/model/optim/scheduler.py +99 -0
- boltz/model/potentials/__init__.py +0 -0
- boltz/model/potentials/potentials.py +497 -0
- boltz/model/potentials/schedules.py +32 -0
- boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
- boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
- boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
- boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
- boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
- boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,116 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import torch
|
4
|
+
from torch import nn
|
5
|
+
from torch.nn import Module
|
6
|
+
|
7
|
+
from boltz.model.modules.encodersv2 import (
|
8
|
+
AtomEncoder,
|
9
|
+
PairwiseConditioning,
|
10
|
+
)
|
11
|
+
|
12
|
+
|
13
|
+
class DiffusionConditioning(Module):
    """Precompute the conditioning tensors shared by every diffusion step.

    Runs the pairwise conditioner and the atom encoder once on the trunk
    outputs, and projects the resulting pair representations into per-head
    attention biases for the atom encoder/decoder and token transformer
    layers, so the diffusion module can reuse them across denoising steps.
    """

    def __init__(
        self,
        token_s: int,
        token_z: int,
        atom_s: int,
        atom_z: int,
        atoms_per_window_queries: int = 32,
        atoms_per_window_keys: int = 128,
        atom_encoder_depth: int = 3,
        atom_encoder_heads: int = 4,
        token_transformer_depth: int = 24,
        token_transformer_heads: int = 8,
        atom_decoder_depth: int = 3,
        atom_decoder_heads: int = 4,
        atom_feature_dim: int = 128,
        conditioning_transition_layers: int = 2,
        use_no_atom_char: bool = False,
        use_atom_backbone_feat: bool = False,
        use_residue_feats_atoms: bool = False,
    ) -> None:
        """Initialize the conditioning module.

        Args:
            token_s: Single (token) representation dimension.
            token_z: Pair (token) representation dimension.
            atom_s: Atom single representation dimension.
            atom_z: Atom pair representation dimension.
            atoms_per_window_queries: Query window size for windowed atom attention.
            atoms_per_window_keys: Key window size for windowed atom attention.
            atom_encoder_depth: Number of atom-encoder layers (one bias projection each).
            atom_encoder_heads: Attention heads per atom-encoder layer.
            token_transformer_depth: Number of token-transformer layers.
            token_transformer_heads: Attention heads per token-transformer layer.
            atom_decoder_depth: Number of atom-decoder layers.
            atom_decoder_heads: Attention heads per atom-decoder layer.
            atom_feature_dim: Atom feature embedding dimension.
            conditioning_transition_layers: Transition blocks in the pairwise conditioner.
            use_no_atom_char: Forwarded to AtomEncoder (atom character features toggle).
            use_atom_backbone_feat: Forwarded to AtomEncoder (backbone features toggle).
            use_residue_feats_atoms: Forwarded to AtomEncoder (residue features toggle).
        """
        super().__init__()

        self.pairwise_conditioner = PairwiseConditioning(
            token_z=token_z,
            dim_token_rel_pos_feats=token_z,
            num_transitions=conditioning_transition_layers,
        )

        self.atom_encoder = AtomEncoder(
            atom_s=atom_s,
            atom_z=atom_z,
            token_s=token_s,
            token_z=token_z,
            atoms_per_window_queries=atoms_per_window_queries,
            atoms_per_window_keys=atoms_per_window_keys,
            atom_feature_dim=atom_feature_dim,
            structure_prediction=True,
            use_no_atom_char=use_no_atom_char,
            use_atom_backbone_feat=use_atom_backbone_feat,
            use_residue_feats_atoms=use_residue_feats_atoms,
        )

        # One (LayerNorm -> Linear) projection per transformer layer, mapping
        # the pair representation to a per-head attention bias. Attribute
        # names and module structure are kept identical to the original so
        # state dicts remain compatible.
        self.atom_enc_proj_z = self._bias_projections(
            atom_z, atom_encoder_heads, atom_encoder_depth
        )
        self.atom_dec_proj_z = self._bias_projections(
            atom_z, atom_decoder_heads, atom_decoder_depth
        )
        self.token_trans_proj_z = self._bias_projections(
            token_z, token_transformer_heads, token_transformer_depth
        )

    @staticmethod
    def _bias_projections(dim: int, heads: int, depth: int) -> nn.ModuleList:
        """Build `depth` LayerNorm+Linear head-bias projections over `dim`."""
        return nn.ModuleList(
            nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, heads, bias=False),
            )
            for _ in range(depth)
        )

    def forward(
        self,
        s_trunk,  # Float['b n ts']
        z_trunk,  # Float['b n n tz']
        relative_position_encoding,  # Float['b n n tz']
        feats,
    ):
        """Compute reusable diffusion conditioning tensors.

        Args:
            s_trunk: Trunk single representation.
            z_trunk: Trunk pair representation.
            relative_position_encoding: Relative positional pair features.
            feats: Input feature dictionary consumed by the atom encoder.

        Returns:
            Tuple of (q, c, to_keys, atom_enc_bias, atom_dec_bias,
            token_trans_bias): atom-encoder outputs plus the concatenated
            per-layer attention biases (stacked along the last dim).
        """
        z = self.pairwise_conditioner(
            z_trunk,
            relative_position_encoding,
        )

        q, c, p, to_keys = self.atom_encoder(
            feats=feats,
            s_trunk=s_trunk,  # Float['b n ts']
            z=z,  # Float['b n n tz']
        )

        # Concatenate each layer's head-bias projection along the head axis.
        atom_enc_bias = torch.cat(
            [layer(p) for layer in self.atom_enc_proj_z], dim=-1
        )
        atom_dec_bias = torch.cat(
            [layer(p) for layer in self.atom_dec_proj_z], dim=-1
        )
        token_trans_bias = torch.cat(
            [layer(z) for layer in self.token_trans_proj_z], dim=-1
        )

        return q, c, to_keys, atom_enc_bias, atom_dec_bias, token_trans_bias