rc-foundry 0.1.1 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rfd3/model/inference_sampler.py
@@ -0,0 +1,635 @@
import inspect
from dataclasses import dataclass
from typing import Any, Literal

import torch
from jaxtyping import Float

from foundry.common import exists
from foundry.utils.ddp import RankedLogger
from foundry.utils.rotation_augmentation import (
    rot_vec_mul,
    uniform_random_rotation,
)

ranked_logger = RankedLogger(__name__, rank_zero_only=True)


@dataclass(kw_only=True)
class SampleDiffusionConfig:
    kind: Literal["default", "symmetry"] = "default"

    # Standard EDM args
    num_timesteps: int = 200
    min_t: int = 0
    max_t: int = 1
    sigma_data: int = 16
    s_min: float = 4e-4
    s_max: int = 160
    p: int = 7
    gamma_0: float = 0.6
    gamma_min: float = 1.0
    noise_scale: float = 1.003
    step_scale: float = 1.5
    solver: Literal["af3"] = "af3"

    # RFD3 / design args
    center_option: str = "all"
    s_trans: float = 1.0
    s_jitter_origin: float = 0.0
    fraction_of_steps_to_fix_motif: float = 0.0
    skip_few_diffusion_steps: bool = False
    allow_realignment: bool = False
    insert_motif_at_end: bool = True
    use_classifier_free_guidance: bool = False
    cfg_scale: float = 2.0
    cfg_t_max: float | None = None


class SampleDiffusionWithMotif(SampleDiffusionConfig):
    """Diffusion sampler that supports optional motif alignment."""

    def _construct_inference_noise_schedule(
        self, device: torch.device, partial_t: float = None
    ) -> torch.Tensor:
        """Constructs a noise schedule for use during inference.

        The inference noise schedule is defined in the AF-3 supplement as:

            t_hat = sigma_data * (s_max**(1/p) + t * (s_min**(1/p) - s_max**(1/p)))**p

        Returns:
            torch.Tensor: A tensor representing the noise schedule `t_hat`.

        Reference:
            AlphaFold 3 Supplement, Section 3.7.1.
        """
        # Create a linearly spaced tensor of timesteps between min_t and max_t
        t = torch.linspace(self.min_t, self.max_t, self.num_timesteps, device=device)

        # Construct the noise schedule, using the formula provided in the reference
        t_hat = (
            self.sigma_data
            * (
                (self.s_max) ** (1 / self.p)
                + t * (self.s_min ** (1 / self.p) - self.s_max ** (1 / self.p))
            )
            ** self.p
        )

        if partial_t is not None:
            # For now, partial t is a global parameter
            partial_t = float(partial_t.mean())
            noise_schedule = t_hat
            ranked_logger.info("Using partial diffusion with t={}".format(partial_t))

            # Debug the noise schedule filtering
            original_schedule_len = len(noise_schedule)
            original_max = noise_schedule.max().item()
            original_min = noise_schedule.min().item()

            noise_schedule = noise_schedule[noise_schedule <= partial_t]

            new_schedule_len = len(noise_schedule)
            if new_schedule_len > 0:
                new_max = noise_schedule.max().item()
                new_min = noise_schedule.min().item()
                ranked_logger.info(
                    f"Noise schedule: {original_schedule_len} → {new_schedule_len} steps"
                )
                ranked_logger.info(
                    f"Original range: [{original_min:.3f}, {original_max:.3f}]"
                )
                ranked_logger.info(f"Filtered range: [{new_min:.3f}, {new_max:.3f}]")
            else:
                ranked_logger.warning(
                    f"No noise schedule steps found with t <= {partial_t}!"
                )
                ranked_logger.info(
                    f"Original schedule range: [{original_min:.3f}, {original_max:.3f}]"
                )
                # Fallback to smallest available step
                noise_schedule_original = self._construct_inference_noise_schedule(
                    device=coord_atom_lvl_to_be_noised.device
                )
                noise_schedule = noise_schedule_original[-1:]  # Just use the final step
                ranked_logger.info(
                    f"Using fallback: final step with t={noise_schedule[0].item():.6f}"
                )

        return t_hat
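For orientation, a minimal standalone sketch (editor's addition, not package code) of the schedule defined in the docstring above, evaluated with the defaults declared in `SampleDiffusionConfig`; the schedule decays monotonically from `sigma_data * s_max` to `sigma_data * s_min`:

```python
# Editor's sketch, not package code: the AF-3 style inference noise schedule
# from the docstring above, with SampleDiffusionConfig defaults.
import torch

sigma_data, s_max, s_min, p, num_timesteps = 16, 160, 4e-4, 7, 200
t = torch.linspace(0, 1, num_timesteps)
t_hat = sigma_data * (s_max ** (1 / p) + t * (s_min ** (1 / p) - s_max ** (1 / p))) ** p

print(t_hat[0].item())   # ≈ sigma_data * s_max = 2560.0  (highest noise level, step 0)
print(t_hat[-1].item())  # ≈ sigma_data * s_min = 0.0064  (lowest noise level, final step)
```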
    def _get_initial_structure(
        self,
        c0: torch.Tensor,
        D: int,
        L: int,
        coord_atom_lvl_to_be_noised: torch.Tensor,
        is_motif_atom_with_fixed_coord,
    ) -> torch.Tensor:
        noise = c0 * torch.normal(mean=0.0, std=1.0, size=(D, L, 3), device=c0.device)
        noise[..., is_motif_atom_with_fixed_coord, :] = 0  # Zero out noise going in
        X_L = noise + coord_atom_lvl_to_be_noised
        return X_L

    def sample_diffusion_like_af3(
        self,
        *,
        f: dict[str, Any],
        diffusion_module: torch.nn.Module,
        diffusion_batch_size: int,
        coord_atom_lvl_to_be_noised: Float[torch.Tensor, "D L 3"],
        initializer_outputs,
        ref_initializer_outputs: dict[str, Any] | None,
        f_ref: dict[str, Any] | None,
    ) -> dict[str, Any]:
        # Motif setup to recenter the motif at every step
        is_motif_atom_with_fixed_coord = f["is_motif_atom_with_fixed_coord"]

        # Book-keeping
        noise_schedule = self._construct_inference_noise_schedule(
            device=coord_atom_lvl_to_be_noised.device,
            partial_t=f.get("partial_t", None),
        )

        L = f["ref_element"].shape[0]
        D = diffusion_batch_size

        X_L = self._get_initial_structure(
            c0=noise_schedule[0],
            D=D,
            L=L,
            coord_atom_lvl_to_be_noised=coord_atom_lvl_to_be_noised.clone(),
            is_motif_atom_with_fixed_coord=is_motif_atom_with_fixed_coord,
        )  # (D, L, 3)

        if self.s_jitter_origin > 0.0:
            X_L[:, is_motif_atom_with_fixed_coord, :] += torch.normal(
                mean=0.0,
                std=self.s_jitter_origin,
                size=(D, 1, 3),
                device=X_L.device,
            )

        X_noisy_L_traj = []
        X_denoised_L_traj = []
        sequence_entropy_traj = []
        t_hats = []

        threshold_step = (len(noise_schedule) - 1) * self.fraction_of_steps_to_fix_motif

        for step_num, (c_t_minus_1, c_t) in enumerate(
            zip(noise_schedule, noise_schedule[1:])
        ):
            # Assert no grads on X_L
            assert not torch.is_grad_enabled(), "Computation graph should not be active"
            assert not X_L.requires_grad, "X_L should not require gradients"

            # Apply a random rotation and translation to the structure
            if self.allow_realignment:
                X_L, _ = centre_random_augment_around_motif(
                    X_L,
                    coord_atom_lvl_to_be_noised,
                    is_motif_atom_with_fixed_coord,
                    center_option=self.center_option,
                    # If centering_affects_motif is True, the model's predictions from (step_num-1) might affect the motif
                    centering_affects_motif=(max(step_num - 1, 0)) >= threshold_step,
                    # If keeping the motif position wrt the origin fixed, we can't do translational augmentation
                    # We want to keep this position fixed in the interval where the model is not allowed to change it
                    s_trans=self.s_trans if step_num >= threshold_step else 0.0,
                )

            # Update gamma & step scale
            gamma = self.gamma_0 if c_t > self.gamma_min else 0
            step_scale = self.step_scale

            # Compute the value of t_hat
            t_hat = c_t_minus_1 * (gamma + 1)

            # Noise the coordinates with scaled Gaussian noise
            epsilon_L = (
                self.noise_scale
                * torch.sqrt(torch.square(t_hat) - torch.square(c_t_minus_1))
                * torch.normal(mean=0.0, std=1.0, size=X_L.shape, device=X_L.device)
            )
            epsilon_L[..., is_motif_atom_with_fixed_coord, :] = (
                0  # No noise injection for fixed atoms
            )
            X_noisy_L = X_L + epsilon_L

            # Denoise the coordinates
            # Handle chunked mode vs standard mode
            if "chunked_pairwise_embedder" in initializer_outputs:
                # Chunked mode: explicitly provide P_LL=None
                chunked_embedder = initializer_outputs[
                    "chunked_pairwise_embedder"
                ]  # Don't pop, just get
                other_outputs = {
                    k: v
                    for k, v in initializer_outputs.items()
                    if k != "chunked_pairwise_embedder"
                }
                outs = diffusion_module(
                    X_noisy_L=X_noisy_L,
                    t=t_hat.tile(D),
                    f=f,
                    P_LL=None,  # Not used in chunked mode
                    chunked_pairwise_embedder=chunked_embedder,
                    initializer_outputs=other_outputs,
                    **other_outputs,
                )
            else:
                # Standard mode: P_LL is included in initializer_outputs
                outs = diffusion_module(
                    X_noisy_L=X_noisy_L,
                    t=t_hat.tile(D),
                    f=f,
                    **initializer_outputs,
                )

            X_denoised_L = outs["X_L"] if "X_L" in outs else outs

            # Compute the delta between the noisy and denoised coordinates, scaled by t_hat
            delta_L = (
                X_noisy_L - X_denoised_L
            ) / t_hat  # gradient of x wrt. t at x_t_hat
            d_t = c_t - t_hat

            if self.use_classifier_free_guidance and (
                self.cfg_t_max is None or c_t > self.cfg_t_max
            ):
                X_noisy_L_stripped = strip_X(X_noisy_L, f_ref)

                # unconditional forward pass
                outs_ref = diffusion_module(
                    X_noisy_L=X_noisy_L_stripped,  # modify X
                    t=t_hat.tile(D),
                    f=f_ref,  # modified f
                    **ref_initializer_outputs,
                )

                X_denoised_L_stripped = outs_ref["X_L"]

                delta_L_ref = (
                    X_noisy_L_stripped - X_denoised_L_stripped
                ) / t_hat  # gradient of x wrt. t at x_t_hat

                # pad delta_L_ref with zeros to match delta_L (for the unindexed atoms)
                if delta_L_ref.shape[1] < delta_L.shape[1]:
                    delta_L_ref = torch.cat(
                        [
                            delta_L_ref,
                            torch.zeros_like(delta_L[:, delta_L_ref.shape[1] :, :]),
                        ],
                        dim=1,
                    )

                # apply CFG
                delta_L = delta_L + (self.cfg_scale - 1) * (delta_L - delta_L_ref)

            if exists(outs.get("sequence_logits_I")):
                # Compute confidence
                p = torch.softmax(
                    outs["sequence_logits_I"], dim=-1
                ).cpu()  # shape (D, L, 32)
                seq_entropy = -torch.sum(
                    p * torch.log(p + 1e-10), dim=-1
                )  # shape (D, L,)
                sequence_entropy_traj.append(seq_entropy)

            # Update the coordinates, scaled by the step size
            X_L = X_noisy_L + step_scale * d_t * delta_L

            # Append the results to the trajectory (for visualization of the diffusion process)
            X_noisy_L_scaled = (
                self.sigma_data * X_noisy_L / torch.sqrt(t_hat**2 + self.sigma_data**2)
            )  # Save noisy traj as scaled inputs
            X_noisy_L_traj.append(X_noisy_L_scaled)
            X_denoised_L_traj.append(X_denoised_L)
            t_hats.append(t_hat)

        if torch.any(is_motif_atom_with_fixed_coord) and self.allow_realignment:
            # Insert the gt motif at the end
            X_L, _ = centre_random_augment_around_motif(
                X_L,
                coord_atom_lvl_to_be_noised,
                is_motif_atom_with_fixed_coord,
                reinsert_motif=self.insert_motif_at_end,
            )

            # Align prediction to original motif
            X_L = weighted_rigid_align(
                coord_atom_lvl_to_be_noised,
                X_L,
                X_exists_L=is_motif_atom_with_fixed_coord,
            )

        return dict(
            X_L=X_L,  # (D, L, 3)
            X_noisy_L_traj=X_noisy_L_traj,  # list[Tensor[D, L, 3]]
            X_denoised_L_traj=X_denoised_L_traj,  # list[Tensor[D, L, 3]]
            t_hats=t_hats,  # list[Tensor[D]], where D is shared across all diffusion batches
            sequence_logits_I=outs.get("sequence_logits_I"),  # (D, I, 32)
            sequence_indices_I=outs.get("sequence_indices_I"),  # (D, I, 32)
            sequence_entropy_traj=sequence_entropy_traj,  # list[Tensor[D, I]]
        )
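The loop above is an EDM-style sampler in the AF-3 flavour. As a reading aid, a compact standalone paraphrase of a single update step (editor's sketch, not package code; it omits the motif masking, realignment, and trajectory book-keeping, and all names are local to the sketch):

```python
# Editor's sketch, not package code: one update step of the loop above,
# including the conditional/unconditional mixing used for classifier-free guidance.
import torch

def one_step(x, c_prev, c_next, denoise, *, gamma, noise_scale, step_scale,
             denoise_uncond=None, cfg_scale=2.0):
    # Bump the noise level, then inject matching Gaussian noise.
    t_hat = c_prev * (1 + gamma)
    eps = noise_scale * torch.sqrt(t_hat**2 - c_prev**2) * torch.randn_like(x)
    x_noisy = x + eps
    # One denoiser call gives the descent direction; take an Euler-like step.
    delta = (x_noisy - denoise(x_noisy, t_hat)) / t_hat
    if denoise_uncond is not None:  # classifier-free guidance
        delta_ref = (x_noisy - denoise_uncond(x_noisy, t_hat)) / t_hat
        delta = delta + (cfg_scale - 1) * (delta - delta_ref)
    return x_noisy + step_scale * (c_next - t_hat) * delta
```

In the method above, `gamma` is non-zero only while the current noise level exceeds `gamma_min`, and fixed motif atoms receive no injected noise.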
class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
    """
    This class is a wrapper around the SampleDiffusionWithMotif class.
    It is used to sample diffusion with symmetry.
    """

    def __init__(self, sym_step_frac: float = 0.9, **kwargs):
        assert (
            kwargs.get("gamma_0") > 0.5
        ), "gamma_0 must be greater than 0.5 for symmetry sampling"
        self.sym_step_frac = sym_step_frac
        super().__init__(**kwargs)

    def apply_symmetry_to_X_L(self, X_L, f):
        # check that we are doing symmetric inference
        assert "sym_transform" in f.keys(), "Symmetry transform not found in f"

        # update symmetric frames to correct for change in global frame
        symmetry_feats = {k: v for k, v in f.items() if "sym" in k}

        # apply symmetry frame shift to X_L
        X_L = apply_symmetry_to_xyz_atomwise(
            X_L, symmetry_feats, partial_diffusion=("partial_t" in f)
        )

        return X_L

    def sample_diffusion_like_af3(
        self,
        *,
        f: dict[str, Any],
        diffusion_module: torch.nn.Module,
        diffusion_batch_size: int,
        coord_atom_lvl_to_be_noised: Float[torch.Tensor, "D L 3"],
        initializer_outputs,
        ref_initializer_outputs: dict[str, Any] | None,
        f_ref: dict[str, Any] | None,
        **_,
    ) -> dict[str, Any]:
        # Motif setup to recenter the motif at every step
        is_motif_atom_with_fixed_coord = f["is_motif_atom_with_fixed_coord"]
        # Book-keeping
        noise_schedule = self._construct_inference_noise_schedule(
            device=coord_atom_lvl_to_be_noised.device,
            partial_t=f.get("partial_t", None),
        )

        L = f["ref_element"].shape[0]
        D = diffusion_batch_size
        X_L = self._get_initial_structure(
            c0=noise_schedule[0],
            D=D,
            L=L,
            coord_atom_lvl_to_be_noised=coord_atom_lvl_to_be_noised.clone(),
            is_motif_atom_with_fixed_coord=is_motif_atom_with_fixed_coord,
        )  # (D, L, 3)

        X_noisy_L_traj = []
        X_denoised_L_traj = []
        sequence_entropy_traj = []
        t_hats = []

        # symmetrize X_L until the step gamma = gamma_min_sym
        gamma_min_sym_idx = min(
            int(len(noise_schedule) * self.sym_step_frac), len(noise_schedule) - 1
        )
        gamma_min_sym = noise_schedule[gamma_min_sym_idx]

        ranked_logger.info(f"gamma_min_sym: {gamma_min_sym}")
        ranked_logger.info(f"gamma_min: {self.gamma_min}")
        for step_num, (c_t_minus_1, c_t) in enumerate(
            zip(noise_schedule, noise_schedule[1:])
        ):
            # Assert no grads on X_L
            assert not torch.is_grad_enabled(), "Computation graph should not be active"
            assert not X_L.requires_grad, "X_L should not require gradients"

            # Apply a random rotation and translation to the structure
            if self.allow_realignment:
                X_L, R = centre_random_augment_around_motif(
                    X_L,
                    coord_atom_lvl_to_be_noised,
                    is_motif_atom_with_fixed_coord,
                )

            # Update gamma & step scale
            gamma = self.gamma_0 if c_t > self.gamma_min else 0
            step_scale = self.step_scale

            # Compute the value of t_hat
            t_hat = c_t_minus_1 * (gamma + 1)

            # Noise the coordinates with scaled Gaussian noise
            epsilon_L = (
                self.noise_scale
                * torch.sqrt(torch.square(t_hat) - torch.square(c_t_minus_1))
                * torch.normal(mean=0.0, std=1.0, size=X_L.shape, device=X_L.device)
            )
            epsilon_L[..., is_motif_atom_with_fixed_coord, :] = (
                0  # No noise injection for fixed atoms
            )

            # NOTE: no symmetry applied to the noisy structure
            X_noisy_L = X_L + epsilon_L

            # Denoise the coordinates
            # Handle chunked mode vs standard mode (same as default sampler)
            if "chunked_pairwise_embedder" in initializer_outputs:
                # Chunked mode: explicitly provide P_LL=None
                chunked_embedder = initializer_outputs[
                    "chunked_pairwise_embedder"
                ]  # Don't pop, just get
                other_outputs = {
                    k: v
                    for k, v in initializer_outputs.items()
                    if k != "chunked_pairwise_embedder"
                }
                outs = diffusion_module(
                    X_noisy_L=X_noisy_L,
                    t=t_hat.tile(D),
                    f=f,
                    P_LL=None,  # Not used in chunked mode
                    chunked_pairwise_embedder=chunked_embedder,
                    initializer_outputs=other_outputs,
                    **other_outputs,
                )
            else:
                # Standard mode: P_LL is included in initializer_outputs
                outs = diffusion_module(
                    X_noisy_L=X_noisy_L,
                    t=t_hat.tile(D),
                    f=f,
                    **initializer_outputs,
                )
            # apply symmetry to X_denoised_L
            if "X_L" in outs and c_t > gamma_min_sym:
                # outs["original_X_L"] = outs["X_L"].clone()
                outs["X_L"] = self.apply_symmetry_to_X_L(outs["X_L"], f)

            X_denoised_L = outs["X_L"] if "X_L" in outs else outs

            # Compute the delta between the noisy and denoised coordinates, scaled by t_hat
            delta_L = (
                X_noisy_L - X_denoised_L
            ) / t_hat  # gradient of x wrt. t at x_t_hat
            d_t = c_t - t_hat

            # NOTE: no classifier-free guidance for symmetry

            if exists(outs.get("sequence_logits_I")):
                # Compute confidence
                p = torch.softmax(
                    outs["sequence_logits_I"], dim=-1
                ).cpu()  # shape (D, L, 32)
                seq_entropy = -torch.sum(
                    p * torch.log(p + 1e-10), dim=-1
                )  # shape (D, L,)
                sequence_entropy_traj.append(seq_entropy)

            # Update the coordinates, scaled by the step size
            # delta_L should be symmetric
            X_L = X_noisy_L + step_scale * d_t * delta_L

            # Append the results to the trajectory (for visualization of the diffusion process)
            X_noisy_L_scaled = (
                self.sigma_data * X_noisy_L / torch.sqrt(t_hat**2 + self.sigma_data**2)
            )  # Save noisy traj as scaled inputs
            X_noisy_L_traj.append(X_noisy_L_scaled)
            X_denoised_L_traj.append(X_denoised_L)
            t_hats.append(t_hat)

        if torch.any(is_motif_atom_with_fixed_coord) and self.allow_realignment:
            # Insert the gt motif at the end
            X_L, R = centre_random_augment_around_motif(
                X_L,
                coord_atom_lvl_to_be_noised,
                is_motif_atom_with_fixed_coord,
                reinsert_motif=self.insert_motif_at_end,
            )

            # apply symmetry frame shift to X_L
            X_L = self.apply_symmetry_to_X_L(X_L, f)

            # Align prediction to original motif
            X_L = weighted_rigid_align(
                coord_atom_lvl_to_be_noised,
                X_L,
                X_exists_L=is_motif_atom_with_fixed_coord,
            )

        return dict(
            X_L=X_L,  # (D, L, 3)
            X_noisy_L_traj=X_noisy_L_traj,  # list[Tensor[D, L, 3]]
            X_denoised_L_traj=X_denoised_L_traj,  # list[Tensor[D, L, 3]]
            t_hats=t_hats,  # list[Tensor[D]], where D is shared across all diffusion batches
            sequence_logits_I=outs.get("sequence_logits_I"),  # (D, I, 32)
            sequence_indices_I=outs.get("sequence_indices_I"),  # (D, I, 32)
            sequence_entropy_traj=sequence_entropy_traj,  # list[Tensor[D, I]]
        )
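One detail that is easy to misread in the method above: `gamma_min_sym` is a noise level taken from the monotonically decreasing schedule, so the condition `c_t > gamma_min_sym` holds only for the early, high-noise steps. A small worked example with the defaults (editor's sketch, not package code):

```python
# Editor's sketch, not package code: with num_timesteps=200 and sym_step_frac=0.9,
cutoff_idx = min(int(200 * 0.9), 200 - 1)  # -> 180
# noise_schedule decreases monotonically, so
#   c_t > noise_schedule[cutoff_idx]
# holds for roughly the first 90% of the steps: the denoised prediction is
# symmetrized during the high-noise portion of the trajectory and left free
# for the final ~10% of steps.
```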
class ConditionalDiffusionSampler:
    """
    Conditional diffusion sampler, chooses at construction time which sampler to use,
    then forwards `sample_diffusion_like_af3` to the chosen sampler.
    If you write a new sampler, you best add it to the registry below
    and inference_sampler.kind in inference_engine config.
    """

    _registry = {
        "default": SampleDiffusionWithMotif,
        "symmetry": SampleDiffusionWithSymmetry,
    }

    def __init__(self, kind="default", **kwargs):
        ranked_logger.info(
            f"Initializing ConditionalDiffusionSampler with kind: {kind}"
        )
        try:
            SamplerCls = self._registry[kind]
            # remove kwargs that the sampler cannot take
            init_args = self.get_class_init_args(SamplerCls)
            kwargs = {k: v for k, v in kwargs.items() if k in init_args}
        except KeyError:
            raise ValueError(
                f"Invalid sampler kind: {kind}, must be one of {list(self._registry.keys())}"
            )
        self.sampler = SamplerCls(**kwargs)

    def sample_diffusion_like_af3(self, **kwargs):
        return self.sampler.sample_diffusion_like_af3(**kwargs)

    def get_class_init_args(self, cls):
        arg_names = []
        if hasattr(cls, "__init__") and callable(cls.__init__):
            for p_cls in cls.__mro__:
                if "__init__" in p_cls.__dict__ and p_cls is not object:
                    signature = inspect.signature(p_cls.__init__)
                    arg_names.extend(
                        [param.name for param in signature.parameters.values()]
                    )
        return arg_names
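A hypothetical usage sketch of the registry dispatch described in the class docstring above (editor's addition; in the package the arguments come from the inference_engine config via `inference_sampler.kind`, and the values plus the `some_unrelated_key` kwarg here are illustrative only):

```python
# Hypothetical usage, not from the package: constructing a sampler by kind.
sampler = ConditionalDiffusionSampler(
    kind="symmetry",       # selects SampleDiffusionWithSymmetry from _registry
    gamma_0=0.6,           # the symmetry sampler asserts gamma_0 > 0.5
    num_timesteps=200,
    step_scale=1.5,
    some_unrelated_key=1,  # silently dropped: not returned by get_class_init_args(SamplerCls)
)
# outputs = sampler.sample_diffusion_like_af3(f=..., diffusion_module=..., ...)
```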
def centre_random_augment_around_motif(
    X_L: torch.Tensor,  # (D, L, 3) noisy diffused coordinates
    coord_atom_lvl_to_be_noised: torch.Tensor,  # (D, L, 3) original coordinates
    is_motif_atom_with_fixed_coord: torch.Tensor,  # (D, L) indices in original coordinates to be kept constant
    s_trans: float = 1.0,
    center_option: str = "all",
    centering_affects_motif: bool = True,
    reinsert_motif=True,
):
    D, L, _ = X_L.shape

    if reinsert_motif and torch.any(is_motif_atom_with_fixed_coord):
        # ... Align original coordinates to the prediction
        coords_with_gt_aligned = weighted_rigid_align(
            X_L[..., is_motif_atom_with_fixed_coord, :],
            coord_atom_lvl_to_be_noised[..., is_motif_atom_with_fixed_coord, :],
        )

        # ... Insert original coordinates into X_L
        X_L[..., is_motif_atom_with_fixed_coord, :] = coords_with_gt_aligned

    # ... Centering
    if torch.any(is_motif_atom_with_fixed_coord):
        if center_option == "motif":
            center = torch.mean(
                X_L[..., is_motif_atom_with_fixed_coord, :], dim=-2, keepdim=True
            )  # (D, 1, 3) - COM of motif atoms
        elif center_option == "diffuse":
            center = torch.mean(
                X_L[..., ~is_motif_atom_with_fixed_coord, :], dim=-2, keepdim=True
            )  # (D, 1, 3) - COM of diffused atoms

        else:
            center = torch.mean(X_L, dim=-2, keepdim=True)
    else:
        center = torch.mean(X_L, dim=-2, keepdim=True)

    # ... Center
    if centering_affects_motif:
        X_L = X_L - center
    else:
        X_L[..., ~is_motif_atom_with_fixed_coord, :] = (
            X_L[..., ~is_motif_atom_with_fixed_coord, :] - center
        )

    # ... Random augmentation
    R = uniform_random_rotation((D,)).to(X_L.device)
    noise = (
        torch.normal(mean=0, std=1, size=(D, 1, 3), device=X_L.device) * s_trans
    )  # (D, 1, 3)
    X_L = rot_vec_mul(R[:, None], X_L) + noise

    return X_L, R
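Finally, a hypothetical call to `centre_random_augment_around_motif` (editor's sketch, not from the package; it assumes the module and its foundry imports are available, and uses an all-False motif mask so the motif-alignment branch is skipped):

```python
# Hypothetical usage, not from the package: recenter and randomly re-orient a
# batch of D=8 structures with L=100 atoms and no fixed motif atoms.
import torch

X = torch.randn(8, 100, 3)
motif_mask = torch.zeros(100, dtype=torch.bool)  # no fixed atoms: alignment branch is skipped

X_aug, R = centre_random_augment_around_motif(
    X,
    coord_atom_lvl_to_be_noised=X,  # unused here, since no atom is fixed
    is_motif_atom_with_fixed_coord=motif_mask,
    s_trans=1.0,                    # std of the random global translation
)
# X_aug is centred at the origin (up to the random translation) and rotated by R.
```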