rc-foundry 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/inference_engines/checkpoint_registry.py +58 -11
- foundry/utils/alignment.py +10 -2
- foundry/version.py +2 -2
- foundry_cli/download_checkpoints.py +66 -66
- {rc_foundry-0.1.5.dist-info → rc_foundry-0.1.7.dist-info}/METADATA +25 -20
- rc_foundry-0.1.7.dist-info/RECORD +311 -0
- rf3/configs/callbacks/default.yaml +5 -0
- rf3/configs/callbacks/dump_validation_structures.yaml +6 -0
- rf3/configs/callbacks/metrics_logging.yaml +10 -0
- rf3/configs/callbacks/train_logging.yaml +16 -0
- rf3/configs/dataloader/default.yaml +15 -0
- rf3/configs/datasets/base.yaml +31 -0
- rf3/configs/datasets/pdb_and_distillation.yaml +58 -0
- rf3/configs/datasets/pdb_only.yaml +17 -0
- rf3/configs/datasets/train/disorder_distillation.yaml +48 -0
- rf3/configs/datasets/train/domain_distillation.yaml +50 -0
- rf3/configs/datasets/train/monomer_distillation.yaml +49 -0
- rf3/configs/datasets/train/na_complex_distillation.yaml +50 -0
- rf3/configs/datasets/train/pdb/af3_weighted_sampling.yaml +8 -0
- rf3/configs/datasets/train/pdb/base.yaml +32 -0
- rf3/configs/datasets/train/pdb/plinder.yaml +54 -0
- rf3/configs/datasets/train/pdb/train_interface.yaml +51 -0
- rf3/configs/datasets/train/pdb/train_pn_unit.yaml +46 -0
- rf3/configs/datasets/train/rna_monomer_distillation.yaml +56 -0
- rf3/configs/datasets/val/af3_ab_set.yaml +11 -0
- rf3/configs/datasets/val/af3_validation.yaml +11 -0
- rf3/configs/datasets/val/base.yaml +32 -0
- rf3/configs/datasets/val/runs_and_poses.yaml +12 -0
- rf3/configs/debug/default.yaml +66 -0
- rf3/configs/debug/train_specific_examples.yaml +21 -0
- rf3/configs/experiment/pretrained/rf3.yaml +50 -0
- rf3/configs/experiment/pretrained/rf3_with_confidence.yaml +13 -0
- rf3/configs/experiment/quick-rf3-with-confidence.yaml +15 -0
- rf3/configs/experiment/quick-rf3.yaml +61 -0
- rf3/configs/hydra/default.yaml +18 -0
- rf3/configs/hydra/no_logging.yaml +7 -0
- rf3/configs/inference.yaml +7 -0
- rf3/configs/inference_engine/base.yaml +23 -0
- rf3/configs/inference_engine/rf3.yaml +33 -0
- rf3/configs/logger/csv.yaml +6 -0
- rf3/configs/logger/default.yaml +3 -0
- rf3/configs/logger/wandb.yaml +15 -0
- rf3/configs/model/components/ema.yaml +1 -0
- rf3/configs/model/components/rf3_net.yaml +177 -0
- rf3/configs/model/components/rf3_net_with_confidence_head.yaml +45 -0
- rf3/configs/model/optimizers/adam.yaml +5 -0
- rf3/configs/model/rf3.yaml +43 -0
- rf3/configs/model/rf3_with_confidence.yaml +7 -0
- rf3/configs/model/schedulers/af3.yaml +6 -0
- rf3/configs/paths/data/default.yaml +43 -0
- rf3/configs/paths/default.yaml +21 -0
- rf3/configs/train.yaml +42 -0
- rf3/configs/trainer/cpu.yaml +6 -0
- rf3/configs/trainer/ddp.yaml +5 -0
- rf3/configs/trainer/loss/losses/confidence_loss.yaml +29 -0
- rf3/configs/trainer/loss/losses/diffusion_loss.yaml +9 -0
- rf3/configs/trainer/loss/losses/distogram_loss.yaml +2 -0
- rf3/configs/trainer/loss/structure_prediction.yaml +4 -0
- rf3/configs/trainer/loss/structure_prediction_with_confidence.yaml +2 -0
- rf3/configs/trainer/metrics/structure_prediction.yaml +14 -0
- rf3/configs/trainer/rf3.yaml +20 -0
- rf3/configs/trainer/rf3_with_confidence.yaml +13 -0
- rf3/configs/validate.yaml +45 -0
- rfd3/cli.py +10 -4
- rfd3/configs/__init__.py +0 -0
- rfd3/configs/callbacks/design_callbacks.yaml +10 -0
- rfd3/configs/callbacks/metrics_logging.yaml +20 -0
- rfd3/configs/callbacks/train_logging.yaml +24 -0
- rfd3/configs/dataloader/default.yaml +15 -0
- rfd3/configs/dataloader/fast.yaml +11 -0
- rfd3/configs/datasets/conditions/dna_condition.yaml +3 -0
- rfd3/configs/datasets/conditions/island.yaml +28 -0
- rfd3/configs/datasets/conditions/ppi.yaml +2 -0
- rfd3/configs/datasets/conditions/sequence_design.yaml +17 -0
- rfd3/configs/datasets/conditions/tipatom.yaml +28 -0
- rfd3/configs/datasets/conditions/unconditional.yaml +21 -0
- rfd3/configs/datasets/design_base.yaml +97 -0
- rfd3/configs/datasets/train/pdb/af3_train_interface.yaml +46 -0
- rfd3/configs/datasets/train/pdb/af3_train_pn_unit.yaml +42 -0
- rfd3/configs/datasets/train/pdb/base.yaml +14 -0
- rfd3/configs/datasets/train/pdb/base_no_weights.yaml +19 -0
- rfd3/configs/datasets/train/pdb/base_transform_args.yaml +59 -0
- rfd3/configs/datasets/train/pdb/na_complex_distillation.yaml +20 -0
- rfd3/configs/datasets/train/pdb/pdb_base.yaml +11 -0
- rfd3/configs/datasets/train/pdb/rfd3_train_interface.yaml +22 -0
- rfd3/configs/datasets/train/pdb/rfd3_train_pn_unit.yaml +23 -0
- rfd3/configs/datasets/train/rfd3_monomer_distillation.yaml +38 -0
- rfd3/configs/datasets/val/bcov_ppi_easy_medium.yaml +9 -0
- rfd3/configs/datasets/val/design_validation_base.yaml +40 -0
- rfd3/configs/datasets/val/dna_binder_design5.yaml +9 -0
- rfd3/configs/datasets/val/dna_binder_long.yaml +13 -0
- rfd3/configs/datasets/val/dna_binder_short.yaml +13 -0
- rfd3/configs/datasets/val/indexed.yaml +9 -0
- rfd3/configs/datasets/val/mcsa_41.yaml +9 -0
- rfd3/configs/datasets/val/mcsa_41_short_rigid.yaml +10 -0
- rfd3/configs/datasets/val/ppi_inference.yaml +7 -0
- rfd3/configs/datasets/val/sm_binder_hbonds.yaml +13 -0
- rfd3/configs/datasets/val/sm_binder_hbonds_short.yaml +15 -0
- rfd3/configs/datasets/val/unconditional.yaml +9 -0
- rfd3/configs/datasets/val/unconditional_deep.yaml +9 -0
- rfd3/configs/datasets/val/unindexed.yaml +8 -0
- rfd3/configs/datasets/val/val_examples/bcov_ppi_easy_medium_with_ori.yaml +151 -0
- rfd3/configs/datasets/val/val_examples/bcov_ppi_easy_medium_with_ori_spoof_helical_bundle.yaml +7 -0
- rfd3/configs/datasets/val/val_examples/bcov_ppi_easy_medium_with_ori_varying_lengths.yaml +28 -0
- rfd3/configs/datasets/val/val_examples/bpem_ori_hb.yaml +212 -0
- rfd3/configs/debug/default.yaml +64 -0
- rfd3/configs/debug/train_specific_examples.yaml +21 -0
- rfd3/configs/dev.yaml +9 -0
- rfd3/configs/experiment/debug.yaml +14 -0
- rfd3/configs/experiment/pretrain.yaml +31 -0
- rfd3/configs/experiment/test-uncond.yaml +10 -0
- rfd3/configs/experiment/test-unindexed.yaml +21 -0
- rfd3/configs/hydra/default.yaml +18 -0
- rfd3/configs/hydra/no_logging.yaml +7 -0
- rfd3/configs/inference.yaml +9 -0
- rfd3/configs/inference_engine/base.yaml +15 -0
- rfd3/configs/inference_engine/dev.yaml +20 -0
- rfd3/configs/inference_engine/rfdiffusion3.yaml +65 -0
- rfd3/configs/logger/csv.yaml +6 -0
- rfd3/configs/logger/default.yaml +2 -0
- rfd3/configs/logger/wandb.yaml +15 -0
- rfd3/configs/model/components/ema.yaml +1 -0
- rfd3/configs/model/components/rfd3_net.yaml +131 -0
- rfd3/configs/model/optimizers/adam.yaml +5 -0
- rfd3/configs/model/rfd3_base.yaml +8 -0
- rfd3/configs/model/samplers/edm.yaml +21 -0
- rfd3/configs/model/samplers/symmetry.yaml +10 -0
- rfd3/configs/model/schedulers/af3.yaml +6 -0
- rfd3/configs/paths/data/default.yaml +18 -0
- rfd3/configs/paths/default.yaml +22 -0
- rfd3/configs/train.yaml +28 -0
- rfd3/configs/trainer/cpu.yaml +6 -0
- rfd3/configs/trainer/ddp.yaml +5 -0
- rfd3/configs/trainer/loss/losses/diffusion_loss.yaml +12 -0
- rfd3/configs/trainer/loss/losses/sequence_loss.yaml +3 -0
- rfd3/configs/trainer/metrics/design_metrics.yaml +22 -0
- rfd3/configs/trainer/rfd3_base.yaml +35 -0
- rfd3/configs/validate.yaml +34 -0
- rfd3/engine.py +19 -11
- rfd3/inference/input_parsing.py +1 -1
- rfd3/inference/legacy_input_parsing.py +17 -1
- rfd3/inference/parsing.py +1 -0
- rfd3/inference/symmetry/atom_array.py +1 -5
- rfd3/inference/symmetry/checks.py +53 -28
- rfd3/inference/symmetry/frames.py +8 -5
- rfd3/inference/symmetry/symmetry_utils.py +38 -60
- rfd3/run_inference.py +3 -1
- rfd3/utils/inference.py +23 -0
- rc_foundry-0.1.5.dist-info/RECORD +0 -180
- {rc_foundry-0.1.5.dist-info → rc_foundry-0.1.7.dist-info}/WHEEL +0 -0
- {rc_foundry-0.1.5.dist-info → rc_foundry-0.1.7.dist-info}/entry_points.txt +0 -0
- {rc_foundry-0.1.5.dist-info → rc_foundry-0.1.7.dist-info}/licenses/LICENSE.md +0 -0
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# https://hydra.cc/docs/configure_hydra/intro/
|
|
2
|
+
|
|
3
|
+
# enable color logging (requires `colorlog` to be installed)
|
|
4
|
+
# defaults:
|
|
5
|
+
# - override hydra_logging: colorlog
|
|
6
|
+
# - override job_logging: colorlog
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# output directory, generated dynamically on each run
|
|
10
|
+
run:
|
|
11
|
+
dir: ${paths.log_dir}/${task_name}/${name}/${now:%Y-%m-%d}_${now:%H-%M}_JOB_${oc.env:SLURM_JOB_ID,default}
|
|
12
|
+
|
|
13
|
+
# ... this is where the log file is written (i.e. the program's output)
|
|
14
|
+
job_logging:
|
|
15
|
+
handlers:
|
|
16
|
+
file:
|
|
17
|
+
# Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
|
|
18
|
+
filename: ${hydra.runtime.output_dir}/experiment.log
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# @package _global_
|
|
2
|
+
|
|
3
|
+
defaults:
|
|
4
|
+
- /hydra: no_logging
|
|
5
|
+
|
|
6
|
+
# Parameters for RFD3InferenceEngine.__init__()
|
|
7
|
+
ckpt_path: ???
|
|
8
|
+
num_nodes: 1
|
|
9
|
+
devices_per_node: 1
|
|
10
|
+
verbose: false
|
|
11
|
+
seed: null
|
|
12
|
+
|
|
13
|
+
# Parameters for RFD3InferenceEngine.run()
|
|
14
|
+
inputs: ???
|
|
15
|
+
out_dir: ???
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# @package _global_
|
|
2
|
+
defaults:
|
|
3
|
+
- rfdiffusion3
|
|
4
|
+
- _self_
|
|
5
|
+
|
|
6
|
+
diffusion_batch_size: 8
|
|
7
|
+
n_batches: 1
|
|
8
|
+
seed: 42
|
|
9
|
+
|
|
10
|
+
dump_trajectories: True
|
|
11
|
+
verbose: true
|
|
12
|
+
skip_existing: False
|
|
13
|
+
cleanup_guideposts: False
|
|
14
|
+
cleanup_virtual_atoms: False
|
|
15
|
+
output_full_json: True
|
|
16
|
+
|
|
17
|
+
inference_sampler:
|
|
18
|
+
gamma_0: 0.0
|
|
19
|
+
|
|
20
|
+
out_dir: ./logs/benchmark
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
# @package _global_
|
|
2
|
+
defaults:
|
|
3
|
+
- base
|
|
4
|
+
- _self_
|
|
5
|
+
|
|
6
|
+
_target_: rfd3.engine.RFD3InferenceEngine
|
|
7
|
+
|
|
8
|
+
out_dir: ???
|
|
9
|
+
inputs: ??? # null, json, pdb or
|
|
10
|
+
ckpt_path: rfd3
|
|
11
|
+
json_keys_subset: null
|
|
12
|
+
skip_existing: True
|
|
13
|
+
|
|
14
|
+
#########################################################
|
|
15
|
+
# Design spec args: overrides args from input json
|
|
16
|
+
specification: {}
|
|
17
|
+
#########################################################
|
|
18
|
+
|
|
19
|
+
# Diffusion args
|
|
20
|
+
diffusion_batch_size: 8
|
|
21
|
+
n_batches: 1
|
|
22
|
+
|
|
23
|
+
# Inference sampler args | set to None to use the default in the checkpoint's config
|
|
24
|
+
inference_sampler:
|
|
25
|
+
kind: "default" # "default" or "symmetry" to choose the sampler
|
|
26
|
+
# Classifier-free guidance args:
|
|
27
|
+
cfg_features: # set to 0 in the reference CFG step
|
|
28
|
+
- active_donor
|
|
29
|
+
- active_acceptor
|
|
30
|
+
- ref_atomwise_rasa
|
|
31
|
+
|
|
32
|
+
use_classifier_free_guidance: False
|
|
33
|
+
cfg_t_max: null # max t to apply cfg guidance
|
|
34
|
+
cfg_scale: 1.5
|
|
35
|
+
center_option: "all" # Options are ["all", "motif", "diffuse"]
|
|
36
|
+
s_trans: 1.0 # Translational noise scale for augmentation during inference
|
|
37
|
+
inference_noise_scaling_factor: 1.0
|
|
38
|
+
allow_realignment: False
|
|
39
|
+
|
|
40
|
+
# Diffusion args:
|
|
41
|
+
num_timesteps: 200
|
|
42
|
+
step_scale: 1.5 # 1.5 - 1.0 | Higher values lead to less diverse, more designable, structures
|
|
43
|
+
noise_scale: 1.003
|
|
44
|
+
p: 7
|
|
45
|
+
gamma_0: 0.6 # Previously 1.0 | 0.0 for ODE sampling
|
|
46
|
+
gamma_min: 1.0
|
|
47
|
+
s_jitter_origin: 0.0 # Sigma of gaussian noise to jitter the motif offset (equivalent to ORI token Jitter)
|
|
48
|
+
|
|
49
|
+
# Saving args
|
|
50
|
+
cleanup_guideposts: True
|
|
51
|
+
cleanup_virtual_atoms: True
|
|
52
|
+
read_sequence_from_sequence_head: True
|
|
53
|
+
output_full_json: True
|
|
54
|
+
|
|
55
|
+
# Prefix to add to all output samples
|
|
56
|
+
# Default: None -> f'{jsonfilebasename}_{jsonkey}_{batch}_{model}'
|
|
57
|
+
# Otherwise: string -> f'{string}{jsonkey}_{batch}_{model}'
|
|
58
|
+
# e.g. Empty string -> f'{jsonkey}_{batch}_{model}'
|
|
59
|
+
# e.g. Chunk string -> f'{chunkprefix_}{jsonkey}_{batch}_{model}' (pipelines usage)
|
|
60
|
+
global_prefix: null
|
|
61
|
+
dump_prediction_metadata_json: True
|
|
62
|
+
dump_trajectories: False
|
|
63
|
+
align_trajectory_structures: False
|
|
64
|
+
prevalidate_inputs: False
|
|
65
|
+
low_memory_mode: False # False for standard mode, True for memory efficient tokenization mode
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# https://wandb.ai
|
|
2
|
+
|
|
3
|
+
wandb:
|
|
4
|
+
_target_: wandb.integration.lightning.fabric.WandbLogger
|
|
5
|
+
save_dir: ${paths.output_dir}
|
|
6
|
+
name: ${name}
|
|
7
|
+
offline: False
|
|
8
|
+
id: null # pass correct id (along with checkpoint path, and resume='allow' or 'must') to resume a run
|
|
9
|
+
anonymous: null # enable anonymous logging
|
|
10
|
+
project: ${project}
|
|
11
|
+
prefix: "" # a string to put at the beginning of metric keys
|
|
12
|
+
log_model: False # do not upload model checkpoints
|
|
13
|
+
tags: ${tags}
|
|
14
|
+
# (Default resume to "never" to avoid accidentally resuming runs; we want to be explicit about resuming)
|
|
15
|
+
resume: never # never, allow, or must (see: https://docs.wandb.ai/guides/runs/resuming/)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
decay: 0.999 # From AF-3
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
_target_: rfd3.model.RFD3.RFD3
|
|
2
|
+
|
|
3
|
+
c_s: 384
|
|
4
|
+
c_z: 128
|
|
5
|
+
c_atom: 128
|
|
6
|
+
c_atompair: 16
|
|
7
|
+
|
|
8
|
+
token_initializer: # formerly known as the trunk
|
|
9
|
+
relative_position_encoding:
|
|
10
|
+
r_max: 32
|
|
11
|
+
s_max: 2
|
|
12
|
+
|
|
13
|
+
# Attention pair biases without batch dimensions
|
|
14
|
+
n_pairformer_blocks: 2
|
|
15
|
+
pairformer_block:
|
|
16
|
+
use_triangle_attn: false
|
|
17
|
+
use_triangle_mult: false
|
|
18
|
+
attention_pair_bias:
|
|
19
|
+
n_head: 16
|
|
20
|
+
kq_norm: True
|
|
21
|
+
|
|
22
|
+
token_1d_features:
|
|
23
|
+
ref_motif_token_type: 3
|
|
24
|
+
restype: 32
|
|
25
|
+
ref_plddt: 1
|
|
26
|
+
is_non_loopy: 1
|
|
27
|
+
|
|
28
|
+
downcast: ${model.net.diffusion_module.downcast}
|
|
29
|
+
atom_1d_features:
|
|
30
|
+
ref_atom_name_chars: 256
|
|
31
|
+
ref_element: 128
|
|
32
|
+
ref_charge: 1
|
|
33
|
+
ref_mask: 1
|
|
34
|
+
ref_is_motif_atom_with_fixed_coord: 1
|
|
35
|
+
ref_is_motif_atom_unindexed: 1
|
|
36
|
+
has_zero_occupancy: 1
|
|
37
|
+
ref_pos: 3
|
|
38
|
+
|
|
39
|
+
# Guided features
|
|
40
|
+
ref_atomwise_rasa: 3
|
|
41
|
+
active_donor: 1
|
|
42
|
+
active_acceptor: 1
|
|
43
|
+
is_atom_level_hotspot: 1
|
|
44
|
+
|
|
45
|
+
atom_transformer:
|
|
46
|
+
n_blocks: 0
|
|
47
|
+
atom_transformer_block:
|
|
48
|
+
n_head: 4
|
|
49
|
+
kq_norm: True
|
|
50
|
+
no_residual_connection_between_attention_and_transition: False
|
|
51
|
+
dropout: 0.0
|
|
52
|
+
n_attn_seq_neighbours: 4
|
|
53
|
+
n_attn_keys: 128
|
|
54
|
+
|
|
55
|
+
diffusion_module:
|
|
56
|
+
_target_: rfd3.model.RFD3_diffusion_module.RFD3DiffusionModule
|
|
57
|
+
c_token: 768
|
|
58
|
+
c_t_embed: 256 # Time embedding dimension
|
|
59
|
+
sigma_data: 16
|
|
60
|
+
f_pred: edm
|
|
61
|
+
n_attn_seq_neighbours: 2 # include self + n flanking neighbours
|
|
62
|
+
n_attn_keys: 128
|
|
63
|
+
n_recycle: 2
|
|
64
|
+
use_local_token_attention: false
|
|
65
|
+
|
|
66
|
+
# Upcast/downcast mechanisms
|
|
67
|
+
upcast:
|
|
68
|
+
method: cross_attention
|
|
69
|
+
n_split: 3
|
|
70
|
+
cross_attention_block:
|
|
71
|
+
n_head: 4
|
|
72
|
+
c_model: 128
|
|
73
|
+
dropout: 0.0
|
|
74
|
+
kq_norm: True
|
|
75
|
+
|
|
76
|
+
downcast:
|
|
77
|
+
method: cross_attention
|
|
78
|
+
cross_attention_block:
|
|
79
|
+
n_head: 4
|
|
80
|
+
c_model: 128
|
|
81
|
+
dropout: 0.0
|
|
82
|
+
kq_norm: True
|
|
83
|
+
|
|
84
|
+
########################################################################
|
|
85
|
+
# UNet level processing
|
|
86
|
+
########################################################################
|
|
87
|
+
atom_attention_encoder:
|
|
88
|
+
n_blocks: 3
|
|
89
|
+
atom_transformer_block:
|
|
90
|
+
n_head: 4
|
|
91
|
+
kq_norm: True
|
|
92
|
+
no_residual_connection_between_attention_and_transition: False
|
|
93
|
+
dropout: 0.0
|
|
94
|
+
|
|
95
|
+
diffusion_token_encoder: # encodes self conditioning information and distogram
|
|
96
|
+
use_distogram: True
|
|
97
|
+
use_self: True
|
|
98
|
+
use_sinusoidal_distogram_embedder: False
|
|
99
|
+
sigma_data: ${model.net.diffusion_module.sigma_data}
|
|
100
|
+
|
|
101
|
+
n_pairformer_blocks: 2
|
|
102
|
+
pairformer_block:
|
|
103
|
+
use_triangle_attn: false
|
|
104
|
+
use_triangle_mult: false
|
|
105
|
+
attention_pair_bias:
|
|
106
|
+
n_head: 16
|
|
107
|
+
kq_norm: True
|
|
108
|
+
|
|
109
|
+
diffusion_transformer:
|
|
110
|
+
n_block: 18
|
|
111
|
+
n_registers: 0 # 8 Idk if they do anything tbh
|
|
112
|
+
diffusion_transformer_block:
|
|
113
|
+
n_head: 16
|
|
114
|
+
kq_norm: True
|
|
115
|
+
no_residual_connection_between_attention_and_transition: False
|
|
116
|
+
dropout: 0.10
|
|
117
|
+
|
|
118
|
+
atom_attention_decoder:
|
|
119
|
+
n_blocks: 3
|
|
120
|
+
upcast: ${model.net.diffusion_module.upcast}
|
|
121
|
+
downcast: ${model.net.diffusion_module.downcast}
|
|
122
|
+
|
|
123
|
+
atom_transformer_block:
|
|
124
|
+
n_head: 4
|
|
125
|
+
kq_norm: True
|
|
126
|
+
no_residual_connection_between_attention_and_transition: False
|
|
127
|
+
dropout: 0.10
|
|
128
|
+
|
|
129
|
+
########################################################################
|
|
130
|
+
#
|
|
131
|
+
########################################################################
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
|
|
2
|
+
kind: "default" # "default", "symmetry", or "partial" to choose the sampler
|
|
3
|
+
solver: "af3"
|
|
4
|
+
num_timesteps: 100
|
|
5
|
+
min_t: 0
|
|
6
|
+
max_t: 1
|
|
7
|
+
sigma_data: ${model.net.diffusion_module.sigma_data}
|
|
8
|
+
s_min: 4e-4
|
|
9
|
+
s_max: 160
|
|
10
|
+
p: 7
|
|
11
|
+
gamma_0: 0.8
|
|
12
|
+
gamma_min: 1.0
|
|
13
|
+
noise_scale: 1.003
|
|
14
|
+
step_scale: 1.5
|
|
15
|
+
allow_realignment: False
|
|
16
|
+
use_classifier_free_guidance: False
|
|
17
|
+
cfg_scale: 1.5
|
|
18
|
+
cfg_features: # CFG_features will be set to 0 in the unconditional CFG step
|
|
19
|
+
- ref_atomwise_rasa
|
|
20
|
+
- active_donor
|
|
21
|
+
- active_acceptor
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# path to directory with training splits
|
|
2
|
+
pdb_data_dir: /projects/ml/frozen_pdb_copies/2025_07_13_pdb
|
|
3
|
+
pdb_parquet_dir: /projects/ml/datahub/dfs/af3_splits/2024_12_16/ # TODO: uncomment
|
|
4
|
+
|
|
5
|
+
# monomer distillation dataset
|
|
6
|
+
monomer_distillation_data_dir: /squash/af2_distillation_facebook/
|
|
7
|
+
monomer_distillation_parquet_dir: /projects/ml/datahub/dfs/distillation/af2_distillation_facebook
|
|
8
|
+
|
|
9
|
+
# path to save examples that fail during the Transform pipeline (null = do not save)
|
|
10
|
+
failed_examples_dir: null
|
|
11
|
+
|
|
12
|
+
design_benchmark_data_dir: /projects/ml/aa_design/benchmarks
|
|
13
|
+
design_model_weight_dir: /projects/ml/aa_design/models
|
|
14
|
+
|
|
15
|
+
# path to directory with cached residue data
|
|
16
|
+
residue_cache_dir: /net/tukwila/lschaaf/datahub/MACE-Egret-3-noH/mace_embeddings
|
|
17
|
+
|
|
18
|
+
cif_cache_dir: /net/tukwila/ncorley/cifutils/cache
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# NOTE: order of defaults determines the order in which configs override each other (higher up items are overridden by lower items)
|
|
2
|
+
defaults:
|
|
3
|
+
- _self_
|
|
4
|
+
- data: default
|
|
5
|
+
|
|
6
|
+
# path to root directory (requires the `PROJECT_ROOT` environment variable to be set)
|
|
7
|
+
# NOTE: This variable is auto-set upon loading via `rootutils`
|
|
8
|
+
root_dir: ${oc.env:PROJECT_ROOT}
|
|
9
|
+
|
|
10
|
+
# where to store data (checkpoints, logs, etc.) of all experiments in general
|
|
11
|
+
# (this influences the output_dir in the hydra/default.yaml config)
|
|
12
|
+
# change this to e.g. /scratch if you are running larger experiments with lots of logs, checkpoints, etc.
|
|
13
|
+
# log_dir: ${.root_dir}/logs
|
|
14
|
+
log_dir: /net/scratch/${oc.env:USER}/training/logs
|
|
15
|
+
|
|
16
|
+
# path to output directory for this specific run, created dynamically by hydra
|
|
17
|
+
# path generation pattern is specified in `configs/hydra/default.yaml`
|
|
18
|
+
# use it to store all files generated during the run, like ckpts and metrics
|
|
19
|
+
output_dir: ${hydra:runtime.output_dir}
|
|
20
|
+
|
|
21
|
+
# path to working directory (auto-generated by hydra)
|
|
22
|
+
work_dir: ${hydra:runtime.cwd}
|
rfd3/configs/train.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
# @package _global_
|
|
2
|
+
hydra:
|
|
3
|
+
searchpath:
|
|
4
|
+
- pkg://configs
|
|
5
|
+
|
|
6
|
+
defaults:
|
|
7
|
+
- model: rfd3_base
|
|
8
|
+
- trainer: rfd3_base
|
|
9
|
+
- datasets: design_base
|
|
10
|
+
- callbacks: design_callbacks
|
|
11
|
+
- dataloader: fast
|
|
12
|
+
- paths: default
|
|
13
|
+
- hydra: default
|
|
14
|
+
- logger: default
|
|
15
|
+
- _self_
|
|
16
|
+
# Required overrides:
|
|
17
|
+
- experiment: ???
|
|
18
|
+
- debug: null
|
|
19
|
+
|
|
20
|
+
# Definitions:
|
|
21
|
+
task_name: train
|
|
22
|
+
project: aa_design
|
|
23
|
+
seed: 42
|
|
24
|
+
ckpt_path: null
|
|
25
|
+
|
|
26
|
+
# Placeholders
|
|
27
|
+
name: aa_design
|
|
28
|
+
tags: [aa_design]
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
_target_: rfd3.metrics.losses.DiffusionLoss
|
|
2
|
+
sigma_data: ${model.net.diffusion_module.sigma_data}
|
|
3
|
+
weight: 4.0
|
|
4
|
+
lddt_weight: 0.25
|
|
5
|
+
alpha_virtual_atom: 1.0
|
|
6
|
+
alpha_polar_residues: 1.0
|
|
7
|
+
lp_weight: 0.0
|
|
8
|
+
unindexed_norm_p: 1.0
|
|
9
|
+
alpha_unindexed_diffused: 1.0
|
|
10
|
+
unindexed_t_alpha: 0.75
|
|
11
|
+
normalize_virtual_atom_weight: False
|
|
12
|
+
alpha_ligand: 10.0
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
general_metrics:
|
|
2
|
+
_target_: rfd3.metrics.design_metrics.AtomArrayMetrics
|
|
3
|
+
compute_for_diffused_region_only: True
|
|
4
|
+
compute_ss_adherence_if_possible: False
|
|
5
|
+
|
|
6
|
+
backbone_metrics:
|
|
7
|
+
_target_: rfd3.metrics.design_metrics.BackboneMetrics
|
|
8
|
+
|
|
9
|
+
sidechain_metrics:
|
|
10
|
+
_target_: rfd3.metrics.sidechain_metrics.SidechainMetrics
|
|
11
|
+
central_atom: CB
|
|
12
|
+
dist_threshold_min: 1.0 # min distance for identifying a bond
|
|
13
|
+
dist_threshold_max: 2.0 # max distance for identifying a bond
|
|
14
|
+
already_removed_virtual_atoms: ${trainer.cleanup_virtual_atoms}
|
|
15
|
+
|
|
16
|
+
metadata_metrics:
|
|
17
|
+
_target_: rfd3.metrics.design_metrics.MetadataMetrics
|
|
18
|
+
|
|
19
|
+
hbond_metrics:
|
|
20
|
+
_target_: rfd3.metrics.hbonds_hbplus_metrics.HbondMetrics
|
|
21
|
+
cutoff_HA_dist: 3
|
|
22
|
+
cutoff_DA_distance: 3.5
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
defaults:
|
|
2
|
+
- ddp
|
|
3
|
+
- loss/losses/diffusion_loss@loss.diffusion_loss
|
|
4
|
+
- loss/losses/sequence_loss@loss.sequence_loss
|
|
5
|
+
- metrics: design_metrics
|
|
6
|
+
- _self_
|
|
7
|
+
|
|
8
|
+
_target_: rfd3.trainer.rfd3.AADesignTrainer
|
|
9
|
+
|
|
10
|
+
# AADesign specific (atom-array related):
|
|
11
|
+
output_full_json: False # saves additional metadata in the output json
|
|
12
|
+
allow_sequence_outputs: True
|
|
13
|
+
cleanup_guideposts: False
|
|
14
|
+
cleanup_virtual_atoms: False
|
|
15
|
+
read_sequence_from_sequence_head: True
|
|
16
|
+
compute_non_clash_metrics_for_diffused_region_only: False
|
|
17
|
+
association_scheme: ${datasets.global_transform_args.association_scheme}
|
|
18
|
+
|
|
19
|
+
# Other:
|
|
20
|
+
n_examples_per_epoch: 2400 # 24000 # 10x as many epochs
|
|
21
|
+
checkpoint_every_n_epochs: 10 # Less often checkpointing for fewer epochs
|
|
22
|
+
validate_every_n_epochs: 4 # Validate often
|
|
23
|
+
|
|
24
|
+
max_epochs: 100_000
|
|
25
|
+
prevalidate: False
|
|
26
|
+
|
|
27
|
+
clip_grad_max_norm: 10.0
|
|
28
|
+
output_dir: ${paths.output_dir}
|
|
29
|
+
n_recycles_train: 2
|
|
30
|
+
grad_accum_steps: 3 # overridden by launch.sh
|
|
31
|
+
skip_optimizer_loading: True
|
|
32
|
+
|
|
33
|
+
# Precision
|
|
34
|
+
error_if_grad_nonfinite: False
|
|
35
|
+
precision: bf16-mixed
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# @package _global_
|
|
2
|
+
|
|
3
|
+
hydra:
|
|
4
|
+
searchpath:
|
|
5
|
+
- pkg://configs
|
|
6
|
+
|
|
7
|
+
defaults:
|
|
8
|
+
- model: rfd3_base
|
|
9
|
+
- trainer: rfd3_base
|
|
10
|
+
- datasets: design_base
|
|
11
|
+
- callbacks: design_callbacks
|
|
12
|
+
- dataloader: fast
|
|
13
|
+
- paths: default
|
|
14
|
+
- hydra: default
|
|
15
|
+
- logger: csv
|
|
16
|
+
- _self_
|
|
17
|
+
- experiment: ???
|
|
18
|
+
- debug: null
|
|
19
|
+
|
|
20
|
+
name: ???
|
|
21
|
+
tags: ???
|
|
22
|
+
project: aa_design
|
|
23
|
+
# Names:
|
|
24
|
+
task_name: "validate"
|
|
25
|
+
seed: 42
|
|
26
|
+
|
|
27
|
+
callbacks:
|
|
28
|
+
dump_validation_structures_callback:
|
|
29
|
+
dump_predictions: True
|
|
30
|
+
one_model_per_file: True
|
|
31
|
+
dump_trajectories: False
|
|
32
|
+
|
|
33
|
+
# Args:
|
|
34
|
+
ckpt_path: ???
|
rfd3/engine.py
CHANGED
|
@@ -23,7 +23,10 @@ from rfd3.inference.datasets import (
|
|
|
23
23
|
)
|
|
24
24
|
from rfd3.inference.input_parsing import DesignInputSpecification
|
|
25
25
|
from rfd3.model.inference_sampler import SampleDiffusionConfig
|
|
26
|
-
from rfd3.utils.inference import
|
|
26
|
+
from rfd3.utils.inference import (
|
|
27
|
+
ensure_inference_sampler_matches_design_spec,
|
|
28
|
+
ensure_input_is_abspath,
|
|
29
|
+
)
|
|
27
30
|
from rfd3.utils.io import (
|
|
28
31
|
CIF_LIKE_EXTENSIONS,
|
|
29
32
|
build_stack_from_atom_array_and_batched_coords,
|
|
@@ -171,6 +174,7 @@ class RFD3InferenceEngine(BaseInferenceEngine):
|
|
|
171
174
|
)
|
|
172
175
|
# save
|
|
173
176
|
self.specification_overrides = dict(specification or {})
|
|
177
|
+
self.inference_sampler_overrides = dict(inference_sampler or {})
|
|
174
178
|
|
|
175
179
|
# Setup output directories and args
|
|
176
180
|
self.global_prefix = global_prefix
|
|
@@ -210,6 +214,9 @@ class RFD3InferenceEngine(BaseInferenceEngine):
|
|
|
210
214
|
inputs=inputs,
|
|
211
215
|
n_batches=n_batches,
|
|
212
216
|
)
|
|
217
|
+
ensure_inference_sampler_matches_design_spec(
|
|
218
|
+
design_specifications, self.inference_sampler_overrides
|
|
219
|
+
)
|
|
213
220
|
# init before
|
|
214
221
|
self.initialize()
|
|
215
222
|
outputs = self._run_multi(design_specifications)
|
|
@@ -383,6 +390,9 @@ class RFD3InferenceEngine(BaseInferenceEngine):
|
|
|
383
390
|
# Based on inputs, construct the specifications to loop through
|
|
384
391
|
design_specifications = {}
|
|
385
392
|
for prefix, example_spec in inputs.items():
|
|
393
|
+
# Record task name in the specification
|
|
394
|
+
example_spec["extra"]["task_name"] = prefix
|
|
395
|
+
|
|
386
396
|
# ... Create n_batches for example
|
|
387
397
|
for batch_id in range((n_batches) if exists(n_batches) else 1):
|
|
388
398
|
# ... Example ID
|
|
@@ -524,21 +534,19 @@ def process_input(
|
|
|
524
534
|
|
|
525
535
|
|
|
526
536
|
def _reshape_trajectory(traj, align_structures: bool):
|
|
527
|
-
traj = [traj[i] for i in range(len(traj))]
|
|
528
|
-
n_steps = len(traj)
|
|
537
|
+
traj = [traj[i] for i in range(len(traj))] # make list of arrays
|
|
529
538
|
max_frames = 100
|
|
530
|
-
|
|
539
|
+
if len(traj) > max_frames:
|
|
540
|
+
selected_indices = torch.linspace(0, len(traj) - 1, max_frames).long().tolist()
|
|
541
|
+
traj = [traj[i] for i in selected_indices]
|
|
531
542
|
if align_structures:
|
|
532
543
|
# ... align the trajectories on the last prediction
|
|
533
|
-
for step in range(
|
|
544
|
+
for step in range(len(traj) - 1):
|
|
534
545
|
traj[step] = weighted_rigid_align(
|
|
535
|
-
X_L=traj[-1],
|
|
536
|
-
X_gt_L=traj[step],
|
|
537
|
-
)
|
|
546
|
+
X_L=traj[-1][None],
|
|
547
|
+
X_gt_L=traj[step][None],
|
|
548
|
+
).squeeze(0)
|
|
538
549
|
traj = traj[::-1] # reverse to go from noised -> denoised
|
|
539
|
-
if n_steps > max_frames:
|
|
540
|
-
selected_indices = torch.linspace(0, n_steps - 1, max_frames).long().tolist()
|
|
541
|
-
traj = [traj[i] for i in selected_indices]
|
|
542
550
|
|
|
543
551
|
traj = torch.stack(traj).cpu().numpy()
|
|
544
552
|
return traj
|
rfd3/inference/input_parsing.py
CHANGED
|
@@ -696,7 +696,7 @@ class DesignInputSpecification(BaseModel):
|
|
|
696
696
|
# Partial diffusion: use COM, keep all coordinates
|
|
697
697
|
if exists(self.symmetry) and self.symmetry.id:
|
|
698
698
|
# For symmetric structures, avoid COM centering that would collapse chains
|
|
699
|
-
|
|
699
|
+
logger.info(
|
|
700
700
|
"Partial diffusion with symmetry: skipping COM centering to preserve chain spacing"
|
|
701
701
|
)
|
|
702
702
|
else:
|