boltz-vsynthes 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. boltz/__init__.py +7 -0
  2. boltz/data/__init__.py +0 -0
  3. boltz/data/const.py +1184 -0
  4. boltz/data/crop/__init__.py +0 -0
  5. boltz/data/crop/affinity.py +164 -0
  6. boltz/data/crop/boltz.py +296 -0
  7. boltz/data/crop/cropper.py +45 -0
  8. boltz/data/feature/__init__.py +0 -0
  9. boltz/data/feature/featurizer.py +1230 -0
  10. boltz/data/feature/featurizerv2.py +2208 -0
  11. boltz/data/feature/symmetry.py +602 -0
  12. boltz/data/filter/__init__.py +0 -0
  13. boltz/data/filter/dynamic/__init__.py +0 -0
  14. boltz/data/filter/dynamic/date.py +76 -0
  15. boltz/data/filter/dynamic/filter.py +24 -0
  16. boltz/data/filter/dynamic/max_residues.py +37 -0
  17. boltz/data/filter/dynamic/resolution.py +34 -0
  18. boltz/data/filter/dynamic/size.py +38 -0
  19. boltz/data/filter/dynamic/subset.py +42 -0
  20. boltz/data/filter/static/__init__.py +0 -0
  21. boltz/data/filter/static/filter.py +26 -0
  22. boltz/data/filter/static/ligand.py +37 -0
  23. boltz/data/filter/static/polymer.py +299 -0
  24. boltz/data/module/__init__.py +0 -0
  25. boltz/data/module/inference.py +307 -0
  26. boltz/data/module/inferencev2.py +429 -0
  27. boltz/data/module/training.py +684 -0
  28. boltz/data/module/trainingv2.py +660 -0
  29. boltz/data/mol.py +900 -0
  30. boltz/data/msa/__init__.py +0 -0
  31. boltz/data/msa/mmseqs2.py +235 -0
  32. boltz/data/pad.py +84 -0
  33. boltz/data/parse/__init__.py +0 -0
  34. boltz/data/parse/a3m.py +134 -0
  35. boltz/data/parse/csv.py +100 -0
  36. boltz/data/parse/fasta.py +138 -0
  37. boltz/data/parse/mmcif.py +1239 -0
  38. boltz/data/parse/mmcif_with_constraints.py +1607 -0
  39. boltz/data/parse/schema.py +1851 -0
  40. boltz/data/parse/yaml.py +68 -0
  41. boltz/data/sample/__init__.py +0 -0
  42. boltz/data/sample/cluster.py +283 -0
  43. boltz/data/sample/distillation.py +57 -0
  44. boltz/data/sample/random.py +39 -0
  45. boltz/data/sample/sampler.py +49 -0
  46. boltz/data/tokenize/__init__.py +0 -0
  47. boltz/data/tokenize/boltz.py +195 -0
  48. boltz/data/tokenize/boltz2.py +396 -0
  49. boltz/data/tokenize/tokenizer.py +24 -0
  50. boltz/data/types.py +777 -0
  51. boltz/data/write/__init__.py +0 -0
  52. boltz/data/write/mmcif.py +305 -0
  53. boltz/data/write/pdb.py +171 -0
  54. boltz/data/write/utils.py +23 -0
  55. boltz/data/write/writer.py +330 -0
  56. boltz/main.py +1292 -0
  57. boltz/model/__init__.py +0 -0
  58. boltz/model/layers/__init__.py +0 -0
  59. boltz/model/layers/attention.py +132 -0
  60. boltz/model/layers/attentionv2.py +111 -0
  61. boltz/model/layers/confidence_utils.py +231 -0
  62. boltz/model/layers/dropout.py +34 -0
  63. boltz/model/layers/initialize.py +100 -0
  64. boltz/model/layers/outer_product_mean.py +98 -0
  65. boltz/model/layers/pair_averaging.py +135 -0
  66. boltz/model/layers/pairformer.py +337 -0
  67. boltz/model/layers/relative.py +58 -0
  68. boltz/model/layers/transition.py +78 -0
  69. boltz/model/layers/triangular_attention/__init__.py +0 -0
  70. boltz/model/layers/triangular_attention/attention.py +189 -0
  71. boltz/model/layers/triangular_attention/primitives.py +409 -0
  72. boltz/model/layers/triangular_attention/utils.py +380 -0
  73. boltz/model/layers/triangular_mult.py +212 -0
  74. boltz/model/loss/__init__.py +0 -0
  75. boltz/model/loss/bfactor.py +49 -0
  76. boltz/model/loss/confidence.py +590 -0
  77. boltz/model/loss/confidencev2.py +621 -0
  78. boltz/model/loss/diffusion.py +171 -0
  79. boltz/model/loss/diffusionv2.py +134 -0
  80. boltz/model/loss/distogram.py +48 -0
  81. boltz/model/loss/distogramv2.py +105 -0
  82. boltz/model/loss/validation.py +1025 -0
  83. boltz/model/models/__init__.py +0 -0
  84. boltz/model/models/boltz1.py +1286 -0
  85. boltz/model/models/boltz2.py +1249 -0
  86. boltz/model/modules/__init__.py +0 -0
  87. boltz/model/modules/affinity.py +223 -0
  88. boltz/model/modules/confidence.py +481 -0
  89. boltz/model/modules/confidence_utils.py +181 -0
  90. boltz/model/modules/confidencev2.py +495 -0
  91. boltz/model/modules/diffusion.py +844 -0
  92. boltz/model/modules/diffusion_conditioning.py +116 -0
  93. boltz/model/modules/diffusionv2.py +677 -0
  94. boltz/model/modules/encoders.py +639 -0
  95. boltz/model/modules/encodersv2.py +565 -0
  96. boltz/model/modules/transformers.py +322 -0
  97. boltz/model/modules/transformersv2.py +261 -0
  98. boltz/model/modules/trunk.py +688 -0
  99. boltz/model/modules/trunkv2.py +828 -0
  100. boltz/model/modules/utils.py +303 -0
  101. boltz/model/optim/__init__.py +0 -0
  102. boltz/model/optim/ema.py +389 -0
  103. boltz/model/optim/scheduler.py +99 -0
  104. boltz/model/potentials/__init__.py +0 -0
  105. boltz/model/potentials/potentials.py +497 -0
  106. boltz/model/potentials/schedules.py +32 -0
  107. boltz_vsynthes-1.0.0.dist-info/METADATA +151 -0
  108. boltz_vsynthes-1.0.0.dist-info/RECORD +112 -0
  109. boltz_vsynthes-1.0.0.dist-info/WHEEL +5 -0
  110. boltz_vsynthes-1.0.0.dist-info/entry_points.txt +2 -0
  111. boltz_vsynthes-1.0.0.dist-info/licenses/LICENSE +21 -0
  112. boltz_vsynthes-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,116 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import Module
6
+
7
+ from boltz.model.modules.encodersv2 import (
8
+ AtomEncoder,
9
+ PairwiseConditioning,
10
+ )
11
+
12
+
13
class DiffusionConditioning(Module):
    """Pre-compute the conditioning tensors for the diffusion module.

    The pair representation from the trunk is fused with the relative position
    encoding by ``PairwiseConditioning``.  The result, together with the single
    representation and the raw features, is run through an ``AtomEncoder``.
    Three stacks of LayerNorm -> bias-free Linear projections then turn the
    pair-shaped tensors into bias tensors, one projection per layer.

    NOTE(review): judging by the attribute names, the three bias tensors are
    presumably consumed as attention pair-biases by the atom encoder, the atom
    decoder and the token transformer respectively.  Those consumers are not
    visible here, so confirm against the caller.
    """

    def __init__(
        self,
        token_s: int,
        token_z: int,
        atom_s: int,
        atom_z: int,
        atoms_per_window_queries: int = 32,
        atoms_per_window_keys: int = 128,
        atom_encoder_depth: int = 3,
        atom_encoder_heads: int = 4,
        token_transformer_depth: int = 24,
        token_transformer_heads: int = 8,
        atom_decoder_depth: int = 3,
        atom_decoder_heads: int = 4,
        atom_feature_dim: int = 128,
        conditioning_transition_layers: int = 2,
        use_no_atom_char: bool = False,
        use_atom_backbone_feat: bool = False,
        use_residue_feats_atoms: bool = False,
    ) -> None:
        """Build the pairwise conditioner, the atom encoder and the bias heads.

        Args:
            token_s: token single-representation width, forwarded to AtomEncoder.
            token_z: token pair-representation width.  It is also used as the
                relative-position feature width and as the input width of the
                token-transformer bias projections.
            atom_s: atom single-representation width, forwarded to AtomEncoder.
            atom_z: atom pair-representation width.  It is the input width of
                the atom encoder/decoder bias projections.
            atoms_per_window_queries: forwarded unchanged to AtomEncoder.
            atoms_per_window_keys: forwarded unchanged to AtomEncoder.
            atom_encoder_depth: number of projections in ``atom_enc_proj_z``.
            atom_encoder_heads: output channels of each ``atom_enc_proj_z`` projection.
            token_transformer_depth: number of projections in ``token_trans_proj_z``.
            token_transformer_heads: output channels of each ``token_trans_proj_z``
                projection.
            atom_decoder_depth: number of projections in ``atom_dec_proj_z``.
            atom_decoder_heads: output channels of each ``atom_dec_proj_z`` projection.
            atom_feature_dim: forwarded unchanged to AtomEncoder.
            conditioning_transition_layers: ``num_transitions`` of PairwiseConditioning.
            use_no_atom_char: forwarded unchanged to AtomEncoder.
            use_atom_backbone_feat: forwarded unchanged to AtomEncoder.
            use_residue_feats_atoms: forwarded unchanged to AtomEncoder.
        """
        super().__init__()

        # Sub-modules are created in a fixed order: conditioner, encoder, then
        # the encoder / decoder / token bias stacks.  Parameter initialisation
        # draws from the global RNG in this order, so keep it stable.
        self.pairwise_conditioner = PairwiseConditioning(
            token_z=token_z,
            dim_token_rel_pos_feats=token_z,
            num_transitions=conditioning_transition_layers,
        )

        self.atom_encoder = AtomEncoder(
            atom_s=atom_s,
            atom_z=atom_z,
            token_s=token_s,
            token_z=token_z,
            atoms_per_window_queries=atoms_per_window_queries,
            atoms_per_window_keys=atoms_per_window_keys,
            atom_feature_dim=atom_feature_dim,
            structure_prediction=True,  # hard-wired for this module
            use_no_atom_char=use_no_atom_char,
            use_atom_backbone_feat=use_atom_backbone_feat,
            use_residue_feats_atoms=use_residue_feats_atoms,
        )

        self.atom_enc_proj_z = self._bias_projections(
            atom_z, atom_encoder_heads, atom_encoder_depth
        )
        self.atom_dec_proj_z = self._bias_projections(
            atom_z, atom_decoder_heads, atom_decoder_depth
        )
        self.token_trans_proj_z = self._bias_projections(
            token_z, token_transformer_heads, token_transformer_depth
        )

    @staticmethod
    def _bias_projections(in_dim: int, num_heads: int, depth: int) -> nn.ModuleList:
        """Return ``depth`` independent LayerNorm(in_dim) -> Linear(in_dim, num_heads) stacks.

        The Linear layers have no bias.  Each entry is an ``nn.Sequential``, so
        the state_dict keys are ``<i>.0`` (norm) and ``<i>.1`` (linear).
        """
        return nn.ModuleList(
            [
                nn.Sequential(
                    nn.LayerNorm(in_dim),
                    nn.Linear(in_dim, num_heads, bias=False),
                )
                for _ in range(depth)
            ]
        )

    @staticmethod
    def _concat_heads(projections: nn.ModuleList, rep):
        """Apply every projection to ``rep`` and concatenate the results on the last axis."""
        return torch.cat([proj(rep) for proj in projections], dim=-1)

    def forward(
        self,
        s_trunk,  # Float['b n ts']
        z_trunk,  # Float['b n n tz']
        relative_position_encoding,  # Float['b n n tz']
        feats,
    ):
        """Compute the atom conditioning and the per-layer bias tensors.

        Args:
            s_trunk: trunk single representation (annotated ``Float['b n ts']``).
            z_trunk: trunk pair representation (annotated ``Float['b n n tz']``).
            relative_position_encoding: relative position features (annotated
                ``Float['b n n tz']``).
            feats: input features, passed through to ``self.atom_encoder``.
                The expected structure is not visible here.

        Returns:
            A 6-tuple ``(q, c, to_keys, atom_enc_bias, atom_dec_bias,
            token_trans_bias)``:

            - ``q``, ``c`` and ``to_keys`` are the 1st, 2nd and 4th outputs of
              ``self.atom_encoder``.  Its 3rd output, ``p``, is used only to
              build the atom biases and is not returned.
            - ``atom_enc_bias`` is ``p`` projected by each ``atom_enc_proj_z``
              head and concatenated on the last axis.  Its last dim is
              ``atom_encoder_depth * atom_encoder_heads``.
            - ``atom_dec_bias`` is built the same way from ``atom_dec_proj_z``.
              Its last dim is ``atom_decoder_depth * atom_decoder_heads``.
            - ``token_trans_bias`` is the conditioned pair tensor ``z`` projected
              by each ``token_trans_proj_z`` head.  Its last dim is
              ``token_transformer_depth * token_transformer_heads``.
        """
        # Fuse the trunk pair representation with the relative position encoding.
        z = self.pairwise_conditioner(z_trunk, relative_position_encoding)

        q, c, p, to_keys = self.atom_encoder(feats=feats, s_trunk=s_trunk, z=z)

        # Evaluation order matches the module construction order:
        # encoder, decoder, then token biases.
        atom_enc_bias = self._concat_heads(self.atom_enc_proj_z, p)
        atom_dec_bias = self._concat_heads(self.atom_dec_proj_z, p)
        token_trans_bias = self._concat_heads(self.token_trans_proj_z, z)

        return q, c, to_keys, atom_enc_bias, atom_dec_bias, token_trans_bias