rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rf3/diffusion_samplers/inference_sampler.py
ADDED
@@ -0,0 +1,222 @@
import torch
from beartype.typing import Any, Literal
from jaxtyping import Float

from foundry.utils.ddp import RankedLogger
from foundry.utils.rotation_augmentation import centre_random_augmentation

ranked_logger = RankedLogger(__name__, rank_zero_only=True)


class SampleDiffusion:
    """Algorithm 18"""

    def __init__(
        self,
        *,
        # Hyperparameters
        num_timesteps: int,  # AF-3: 200
        min_t: int,  # AF-3: 0
        max_t: int,  # AF-3: 1
        sigma_data: int,  # AF-3: 16
        s_min: float,  # AF-3: 4e-4
        s_max: int,  # AF-3: 160
        p: int,  # AF-3: 7
        gamma_0: float,  # AF-3: 0.8
        gamma_min: float,  # AF-3: 1.0
        noise_scale: float,  # AF-3: 1.003
        step_scale: float,  # AF-3: 1.5
        solver: Literal["af3"],
    ):
        """Initialize the diffusion sampler, to perform a complete diffusion roll-out with the given recycling outputs.

        We do not use default values for the parameters, so that the Hydra configuration remains the single source of truth and silent failures are avoided.

        Args:
            num_timesteps (int): The number of timesteps for which the noise schedule is constructed. Default is 200, per AF3.
            min_t (float): The minimum value of t in the schedule. Default is 0, per AF3.
            max_t (float): The maximum value of t in the schedule. Default is 1, per AF3.
            sigma_data (int): A constant determined by the variance of the data. Default is 16, as defined in the AlphaFold 3 Supplement (Algorithm 20, Diffusion Module).
            s_min (float): The minimum value of the noise schedule. Default is 4e-4, per AF3.
            s_max (float): The maximum value of the noise schedule. Default is 160, per AF3.
            p (int): A constant that determines the shape of the noise schedule. Default is 7, per AF3.
            gamma_0 (float): The value of gamma when t > gamma_min. Default is 0.8, per AF3.
            solver (str): The solver to use for the diffusion process. Default is "af3".

        TODO: Continue documentation of the remaining parameters.
        """
        self.num_timesteps = num_timesteps
        self.min_t = min_t
        self.max_t = max_t
        self.sigma_data = sigma_data
        self.s_min = s_min
        self.s_max = s_max
        self.p = p
        self.gamma_0 = gamma_0
        self.gamma_min = gamma_min
        self.noise_scale = noise_scale
        self.step_scale = step_scale
        self.solver = solver

    def _construct_inference_noise_schedule(self, device: torch.device) -> torch.Tensor:
        """Constructs a noise schedule for use during inference.

        The inference noise schedule is defined in the AF-3 supplement as:

            t_hat = sigma_data * (s_max**(1/p) + t * (s_min**(1/p) - s_max**(1/p)))**p

        Returns:
            torch.Tensor: A tensor representing the noise schedule `t_hat`.

        Reference:
            AlphaFold 3 Supplement, Section 3.7.1.
        """
        # Create a linearly spaced tensor of timesteps between min_t and max_t
        t = torch.linspace(self.min_t, self.max_t, self.num_timesteps, device=device)

        # Construct the noise schedule, using the formula provided in the reference
        t_hat = (
            self.sigma_data
            * (
                (self.s_max) ** (1 / self.p)
                + t * (self.s_min ** (1 / self.p) - self.s_max ** (1 / self.p))
            )
            ** self.p
        )

        return t_hat

    def _get_initial_structure(
        self,
        c0: torch.Tensor,
        D: int,
        L: int,
        coord_atom_lvl_to_be_noised: torch.Tensor,
    ) -> torch.Tensor:
        """Sample initial point cloud from a normal distribution.

        Args:
            c0 (torch.Tensor): A scalar tensor used to scale the initial point cloud; effectively the same as
                directly changing the standard deviation of the normal distribution. Derived from noise_schedule[0].
            D (int): The number of structures to sample.
            L (int): The number of atoms in the structure.
            coord_atom_lvl_to_be_noised (torch.Tensor): The atom-level coordinates to be noised (either completely or partially).
        """
        noise = c0 * torch.normal(mean=0.0, std=1.0, size=(D, L, 3), device=c0.device)
        X_L = noise + coord_atom_lvl_to_be_noised

        return X_L

    def sample_diffusion_like_af3(
        self,
        *,
        S_inputs_I: Float[torch.Tensor, "I c_s_inputs"],
        S_trunk_I: Float[torch.Tensor, "I c_s"],
        Z_trunk_II: Float[torch.Tensor, "I I c_z"],
        f: dict[str, Any],
        diffusion_module: torch.nn.Module,
        diffusion_batch_size: int,
        coord_atom_lvl_to_be_noised: Float[torch.Tensor, "D L 3"],
    ) -> dict[str, Any]:
        """Perform a complete diffusion roll-out with the given recycling outputs.

        Args:
            diffusion_module (torch.nn.Module): The diffusion module to use for denoising. If using EMA and performing validation or inference,
                this model should be the EMA model.
        """
        # Construct the noise schedule t_hat for inference on the appropriate device
        noise_schedule = self._construct_inference_noise_schedule(
            device=S_inputs_I.device
        )

        # Infer number of atoms from any atom-level feature
        L = f["ref_element"].shape[0]
        D = diffusion_batch_size

        # Initial X_L is drawn from a normal distribution with a mean vector of 0 and a
        # covariance matrix equal to the 3x3 identity matrix, scaled by the noise schedule
        X_L = self._get_initial_structure(
            c0=noise_schedule[0],
            D=D,
            L=L,
            coord_atom_lvl_to_be_noised=coord_atom_lvl_to_be_noised,
        )  # (D, L, 3)

        X_noisy_L_traj = []
        X_denoised_L_traj = []
        t_hats = []

        for c_t_minus_1, c_t in zip(noise_schedule, noise_schedule[1:]):
            # (All predicted atoms exist)
            X_exists_L = torch.ones((D, L)).bool()  # (D, L)

            # Apply a random rotation and translation to the structure
            # TODO: Make s_trans a hyperparameter
            s_trans = 1.0
            X_L = centre_random_augmentation(X_L, X_exists_L, s_trans)

            # Update gamma
            gamma = self.gamma_0 if c_t > self.gamma_min else 0

            # Compute the value of t_hat
            t_hat = c_t_minus_1 * (gamma + 1)

            # Noise the coordinates with scaled Gaussian noise
            epsilon_L = (
                self.noise_scale
                * torch.sqrt(torch.square(t_hat) - torch.square(c_t_minus_1))
                * torch.normal(mean=0.0, std=1.0, size=X_L.shape, device=X_L.device)
            )
            X_noisy_L = X_L + epsilon_L

            # Denoise the coordinates
            X_denoised_L = diffusion_module(
                X_noisy_L=X_noisy_L,
                t=t_hat.tile(D),
                f=f,
                S_inputs_I=S_inputs_I,
                S_trunk_I=S_trunk_I,
                Z_trunk_II=Z_trunk_II,
            )

            # Compute the delta between the noisy and denoised coordinates, scaled by t_hat
            delta_L = (X_noisy_L - X_denoised_L) / t_hat
            d_t = c_t - t_hat

            # Update the coordinates, scaled by the step size
            X_L = X_noisy_L + self.step_scale * d_t * delta_L

            X_noisy_L_scaled = (
                X_noisy_L
                / (torch.sqrt(t_hat[..., None, None] ** 2 + self.sigma_data**2))
            ) * self.sigma_data
            # Append the results to the trajectory (for visualization of the diffusion process)
            X_noisy_L_traj.append(X_noisy_L_scaled)
            X_denoised_L_traj.append(X_denoised_L)
            t_hats.append(t_hat)

        return dict(
            X_L=X_L,  # (D, L, 3)
            X_noisy_L_traj=X_noisy_L_traj,  # list[Tensor[D, L, 3]]
            X_denoised_L_traj=X_denoised_L_traj,  # list[Tensor[D, L, 3]]
            t_hats=t_hats,  # list[Tensor[D]], where D is shared across all diffusion batches
        )


class SamplePartialDiffusion(SampleDiffusion):
    def __init__(self, partial_t: int, **kwargs):
        super().__init__(**kwargs)
        self.partial_t = partial_t

    def _construct_inference_noise_schedule(self, device: torch.device) -> torch.Tensor:
        """Constructs a noise schedule for use during inference with partial t."""
        t_hat_full = super()._construct_inference_noise_schedule(device)

        assert (
            self.partial_t < self.num_timesteps
        ), f"Partial t ({self.partial_t}) must be less than num_timesteps ({self.num_timesteps})"
        ranked_logger.info(
            f"Using partial t index: {self.partial_t} [e.g., {t_hat_full[self.partial_t]:.4}], or {self.partial_t / (self.num_timesteps):.2%}, by index (100% is data, 0% is noise)"
        )

        return t_hat_full[self.partial_t :]
rf3/inference.py
ADDED
@@ -0,0 +1,65 @@
#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/rf3_exec.sh" "$0" "$@"'

import os

import hydra
import rootutils
from dotenv import load_dotenv
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

from foundry.utils.logging import suppress_warnings

# Setup root dir and environment variables (more info: https://github.com/ashleve/rootutils)
# NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located)
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

load_dotenv(override=True)

_config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rf3/configs")


@hydra.main(
    config_path=_config_path,
    config_name="inference",
    version_base="1.3",
)
def run_inference(cfg: DictConfig) -> None:
    """Execute RF3 inference pipeline."""

    # Extract run() parameters from config
    # Preserve string inputs, convert other sequence-like inputs to a Python list (None -> [])
    inputs_param = cfg.inputs if isinstance(cfg.inputs, str) else list(cfg.inputs or [])

    run_params = {
        "inputs": inputs_param,
        "out_dir": str(cfg.out_dir) if cfg.get("out_dir") else None,
        "dump_predictions": cfg.get("dump_predictions", True),
        "dump_trajectories": cfg.get("dump_trajectories", False),
        "one_model_per_file": cfg.get("one_model_per_file", False),
        "annotate_b_factor_with_plddt": cfg.get("annotate_b_factor_with_plddt", False),
        "sharding_pattern": cfg.get("sharding_pattern", None),
        "skip_existing": cfg.get("skip_existing", False),
        "template_selection": cfg.get("template_selection", None),
        "ground_truth_conformer_selection": cfg.get(
            "ground_truth_conformer_selection", None
        ),
        "cyclic_chains": cfg.get("cyclic_chains", []),
    }

    # Create init config with only __init__ params
    cfg_dict = OmegaConf.to_container(cfg, resolve=True)
    run_param_keys = set(run_params.keys())
    init_cfg_dict = {k: v for k, v in cfg_dict.items() if k not in run_param_keys}
    init_cfg = OmegaConf.create(init_cfg_dict)

    # Instantiate engine (only __init__ params)
    inference_engine = instantiate(init_cfg, _convert_="partial", _recursive_=False)

    # Run inference
    with suppress_warnings(is_inference=True):
        inference_engine.run(**run_params)


if __name__ == "__main__":
    run_inference()