rc-foundry 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (139) hide show
  1. foundry/version.py +2 -2
  2. {rc_foundry-0.1.4.dist-info → rc_foundry-0.1.6.dist-info}/METADATA +1 -1
  3. {rc_foundry-0.1.4.dist-info → rc_foundry-0.1.6.dist-info}/RECORD +139 -8
  4. rf3/configs/callbacks/default.yaml +5 -0
  5. rf3/configs/callbacks/dump_validation_structures.yaml +6 -0
  6. rf3/configs/callbacks/metrics_logging.yaml +10 -0
  7. rf3/configs/callbacks/train_logging.yaml +16 -0
  8. rf3/configs/dataloader/default.yaml +15 -0
  9. rf3/configs/datasets/base.yaml +31 -0
  10. rf3/configs/datasets/pdb_and_distillation.yaml +58 -0
  11. rf3/configs/datasets/pdb_only.yaml +17 -0
  12. rf3/configs/datasets/train/disorder_distillation.yaml +48 -0
  13. rf3/configs/datasets/train/domain_distillation.yaml +50 -0
  14. rf3/configs/datasets/train/monomer_distillation.yaml +49 -0
  15. rf3/configs/datasets/train/na_complex_distillation.yaml +50 -0
  16. rf3/configs/datasets/train/pdb/af3_weighted_sampling.yaml +8 -0
  17. rf3/configs/datasets/train/pdb/base.yaml +32 -0
  18. rf3/configs/datasets/train/pdb/plinder.yaml +54 -0
  19. rf3/configs/datasets/train/pdb/train_interface.yaml +51 -0
  20. rf3/configs/datasets/train/pdb/train_pn_unit.yaml +46 -0
  21. rf3/configs/datasets/train/rna_monomer_distillation.yaml +56 -0
  22. rf3/configs/datasets/val/af3_ab_set.yaml +11 -0
  23. rf3/configs/datasets/val/af3_validation.yaml +11 -0
  24. rf3/configs/datasets/val/base.yaml +32 -0
  25. rf3/configs/datasets/val/runs_and_poses.yaml +12 -0
  26. rf3/configs/debug/default.yaml +66 -0
  27. rf3/configs/debug/train_specific_examples.yaml +21 -0
  28. rf3/configs/experiment/pretrained/rf3.yaml +50 -0
  29. rf3/configs/experiment/pretrained/rf3_with_confidence.yaml +13 -0
  30. rf3/configs/experiment/quick-rf3-with-confidence.yaml +15 -0
  31. rf3/configs/experiment/quick-rf3.yaml +61 -0
  32. rf3/configs/hydra/default.yaml +18 -0
  33. rf3/configs/hydra/no_logging.yaml +7 -0
  34. rf3/configs/inference.yaml +7 -0
  35. rf3/configs/inference_engine/base.yaml +23 -0
  36. rf3/configs/inference_engine/rf3.yaml +33 -0
  37. rf3/configs/logger/csv.yaml +6 -0
  38. rf3/configs/logger/default.yaml +3 -0
  39. rf3/configs/logger/wandb.yaml +15 -0
  40. rf3/configs/model/components/ema.yaml +1 -0
  41. rf3/configs/model/components/rf3_net.yaml +177 -0
  42. rf3/configs/model/components/rf3_net_with_confidence_head.yaml +45 -0
  43. rf3/configs/model/optimizers/adam.yaml +5 -0
  44. rf3/configs/model/rf3.yaml +43 -0
  45. rf3/configs/model/rf3_with_confidence.yaml +7 -0
  46. rf3/configs/model/schedulers/af3.yaml +6 -0
  47. rf3/configs/paths/data/default.yaml +43 -0
  48. rf3/configs/paths/default.yaml +21 -0
  49. rf3/configs/train.yaml +42 -0
  50. rf3/configs/trainer/cpu.yaml +6 -0
  51. rf3/configs/trainer/ddp.yaml +5 -0
  52. rf3/configs/trainer/loss/losses/confidence_loss.yaml +29 -0
  53. rf3/configs/trainer/loss/losses/diffusion_loss.yaml +9 -0
  54. rf3/configs/trainer/loss/losses/distogram_loss.yaml +2 -0
  55. rf3/configs/trainer/loss/structure_prediction.yaml +4 -0
  56. rf3/configs/trainer/loss/structure_prediction_with_confidence.yaml +2 -0
  57. rf3/configs/trainer/metrics/structure_prediction.yaml +14 -0
  58. rf3/configs/trainer/rf3.yaml +20 -0
  59. rf3/configs/trainer/rf3_with_confidence.yaml +13 -0
  60. rf3/configs/validate.yaml +45 -0
  61. rfd3/cli.py +10 -4
  62. rfd3/configs/__init__.py +0 -0
  63. rfd3/configs/callbacks/design_callbacks.yaml +10 -0
  64. rfd3/configs/callbacks/metrics_logging.yaml +20 -0
  65. rfd3/configs/callbacks/train_logging.yaml +24 -0
  66. rfd3/configs/dataloader/default.yaml +15 -0
  67. rfd3/configs/dataloader/fast.yaml +11 -0
  68. rfd3/configs/datasets/conditions/dna_condition.yaml +3 -0
  69. rfd3/configs/datasets/conditions/island.yaml +28 -0
  70. rfd3/configs/datasets/conditions/ppi.yaml +2 -0
  71. rfd3/configs/datasets/conditions/sequence_design.yaml +17 -0
  72. rfd3/configs/datasets/conditions/tipatom.yaml +28 -0
  73. rfd3/configs/datasets/conditions/unconditional.yaml +21 -0
  74. rfd3/configs/datasets/design_base.yaml +97 -0
  75. rfd3/configs/datasets/train/pdb/af3_train_interface.yaml +46 -0
  76. rfd3/configs/datasets/train/pdb/af3_train_pn_unit.yaml +42 -0
  77. rfd3/configs/datasets/train/pdb/base.yaml +14 -0
  78. rfd3/configs/datasets/train/pdb/base_no_weights.yaml +19 -0
  79. rfd3/configs/datasets/train/pdb/base_transform_args.yaml +59 -0
  80. rfd3/configs/datasets/train/pdb/na_complex_distillation.yaml +20 -0
  81. rfd3/configs/datasets/train/pdb/pdb_base.yaml +11 -0
  82. rfd3/configs/datasets/train/pdb/rfd3_train_interface.yaml +22 -0
  83. rfd3/configs/datasets/train/pdb/rfd3_train_pn_unit.yaml +23 -0
  84. rfd3/configs/datasets/train/rfd3_monomer_distillation.yaml +38 -0
  85. rfd3/configs/datasets/val/bcov_ppi_easy_medium.yaml +9 -0
  86. rfd3/configs/datasets/val/design_validation_base.yaml +40 -0
  87. rfd3/configs/datasets/val/dna_binder_design5.yaml +9 -0
  88. rfd3/configs/datasets/val/dna_binder_long.yaml +13 -0
  89. rfd3/configs/datasets/val/dna_binder_short.yaml +13 -0
  90. rfd3/configs/datasets/val/indexed.yaml +9 -0
  91. rfd3/configs/datasets/val/mcsa_41.yaml +9 -0
  92. rfd3/configs/datasets/val/mcsa_41_short_rigid.yaml +10 -0
  93. rfd3/configs/datasets/val/ppi_inference.yaml +7 -0
  94. rfd3/configs/datasets/val/sm_binder_hbonds.yaml +13 -0
  95. rfd3/configs/datasets/val/sm_binder_hbonds_short.yaml +15 -0
  96. rfd3/configs/datasets/val/unconditional.yaml +9 -0
  97. rfd3/configs/datasets/val/unconditional_deep.yaml +9 -0
  98. rfd3/configs/datasets/val/unindexed.yaml +8 -0
  99. rfd3/configs/datasets/val/val_examples/bcov_ppi_easy_medium_with_ori.yaml +151 -0
  100. rfd3/configs/datasets/val/val_examples/bcov_ppi_easy_medium_with_ori_spoof_helical_bundle.yaml +7 -0
  101. rfd3/configs/datasets/val/val_examples/bcov_ppi_easy_medium_with_ori_varying_lengths.yaml +28 -0
  102. rfd3/configs/datasets/val/val_examples/bpem_ori_hb.yaml +212 -0
  103. rfd3/configs/debug/default.yaml +64 -0
  104. rfd3/configs/debug/train_specific_examples.yaml +21 -0
  105. rfd3/configs/dev.yaml +9 -0
  106. rfd3/configs/experiment/debug.yaml +14 -0
  107. rfd3/configs/experiment/pretrain.yaml +31 -0
  108. rfd3/configs/experiment/test-uncond.yaml +10 -0
  109. rfd3/configs/experiment/test-unindexed.yaml +21 -0
  110. rfd3/configs/hydra/default.yaml +18 -0
  111. rfd3/configs/hydra/no_logging.yaml +7 -0
  112. rfd3/configs/inference.yaml +9 -0
  113. rfd3/configs/inference_engine/base.yaml +15 -0
  114. rfd3/configs/inference_engine/dev.yaml +20 -0
  115. rfd3/configs/inference_engine/rfdiffusion3.yaml +65 -0
  116. rfd3/configs/logger/csv.yaml +6 -0
  117. rfd3/configs/logger/default.yaml +2 -0
  118. rfd3/configs/logger/wandb.yaml +15 -0
  119. rfd3/configs/model/components/ema.yaml +1 -0
  120. rfd3/configs/model/components/rfd3_net.yaml +131 -0
  121. rfd3/configs/model/optimizers/adam.yaml +5 -0
  122. rfd3/configs/model/rfd3_base.yaml +8 -0
  123. rfd3/configs/model/samplers/edm.yaml +21 -0
  124. rfd3/configs/model/samplers/symmetry.yaml +10 -0
  125. rfd3/configs/model/schedulers/af3.yaml +6 -0
  126. rfd3/configs/paths/data/default.yaml +18 -0
  127. rfd3/configs/paths/default.yaml +22 -0
  128. rfd3/configs/train.yaml +28 -0
  129. rfd3/configs/trainer/cpu.yaml +6 -0
  130. rfd3/configs/trainer/ddp.yaml +5 -0
  131. rfd3/configs/trainer/loss/losses/diffusion_loss.yaml +12 -0
  132. rfd3/configs/trainer/loss/losses/sequence_loss.yaml +3 -0
  133. rfd3/configs/trainer/metrics/design_metrics.yaml +22 -0
  134. rfd3/configs/trainer/rfd3_base.yaml +35 -0
  135. rfd3/configs/validate.yaml +34 -0
  136. rfd3/run_inference.py +3 -7
  137. {rc_foundry-0.1.4.dist-info → rc_foundry-0.1.6.dist-info}/WHEEL +0 -0
  138. {rc_foundry-0.1.4.dist-info → rc_foundry-0.1.6.dist-info}/entry_points.txt +0 -0
  139. {rc_foundry-0.1.4.dist-info → rc_foundry-0.1.6.dist-info}/licenses/LICENSE.md +0 -0
@@ -0,0 +1,177 @@
1
+ # Model architecture
2
+ _target_: rf3.model.RF3.RF3
3
+
4
+ # +---------- Channel dimensions ----------+
5
+ c_s: 384
6
+ c_z: 128
7
+ c_atom: 128
8
+ c_atompair: 16
9
+ c_s_inputs: 449
10
+
11
+ # +---------- Feature embedding ----------+
12
+ feature_initializer:
13
+ # InputFeatureEmbedder
14
+ input_feature_embedder:
15
+ features:
16
+ - restype
17
+ - profile
18
+ - deletion_mean
19
+ atom_attention_encoder:
20
+ c_token: 384
21
+ c_atom_1d_features: 389
22
+ c_tokenpair: ${model.net.c_z}
23
+ use_inv_dist_squared: true
24
+ atom_1d_features:
25
+ - ref_pos
26
+ - ref_charge
27
+ - ref_mask
28
+ - ref_element
29
+ - ref_atom_name_chars
30
+ atom_transformer:
31
+ n_queries: 32
32
+ n_keys: 128
33
+ diffusion_transformer:
34
+ n_block: 3
35
+ diffusion_transformer_block:
36
+ n_head: 4
37
+ no_residual_connection_between_attention_and_transition: true
38
+ kq_norm: true
39
+
40
+ # RelativePositionEncoding
41
+ relative_position_encoding:
42
+ r_max: 32
43
+ s_max: 2
44
+
45
+ # +---------- Recycler ----------+
46
+ recycler:
47
+ # Pairformer
48
+ n_pairformer_blocks: 48
49
+ pairformer_block:
50
+ p_drop: 0.25
51
+ triangle_multiplication:
52
+ d_hidden: 128
53
+ triangle_attention:
54
+ n_head: 4
55
+ d_hidden: 32
56
+ attention_pair_bias:
57
+ n_head: 16
58
+
59
+ # TemplateEmbedder
60
+ template_embedder:
61
+ n_block: 2
62
+ raw_template_dim: 108
63
+ c: 64
64
+ p_drop: 0.25
65
+
66
+ # MSA module
67
+ msa_module:
68
+ n_block: 4
69
+ c_m: 64
70
+ p_drop_msa: 0.15
71
+ p_drop_pair: 0.25
72
+ msa_subsample_embedder:
73
+ num_sequences: 1024
74
+ dim_raw_msa: 34
75
+ c_s_inputs: ${model.net.c_s_inputs}
76
+ c_msa_embed: ${model.net.recycler.msa_module.c_m}
77
+ outer_product:
78
+ c_msa_embed: ${model.net.recycler.msa_module.c_m}
79
+ c_outer_product: 32
80
+ c_out: ${model.net.c_z}
81
+ msa_pair_weighted_averaging:
82
+ n_heads: 8
83
+ c_weighted_average: 32
84
+ c_msa_embed: ${model.net.recycler.msa_module.c_m}
85
+ c_z: ${model.net.c_z}
86
+ separate_gate_for_every_channel: true
87
+ msa_transition:
88
+ n: 4
89
+ c: ${model.net.recycler.msa_module.c_m}
90
+ triangle_multiplication_outgoing:
91
+ d_pair: ${model.net.c_z}
92
+ d_hidden: 128
93
+ bias: True
94
+ triangle_multiplication_incoming:
95
+ d_pair: ${model.net.c_z}
96
+ d_hidden: 128
97
+ bias: True
98
+ triangle_attention_starting:
99
+ d_pair: ${model.net.c_z}
100
+ n_head: 4
101
+ d_hidden: 32
102
+ p_drop: 0.0 # This does not do anything; TODO: Remove
103
+ triangle_attention_ending:
104
+ d_pair: ${model.net.c_z}
105
+ n_head: 4
106
+ d_hidden: 32
107
+ p_drop: 0.0 # This does not do anything; TODO: Remove
108
+ pair_transition:
109
+ n: 4
110
+ c: ${model.net.c_z}
111
+
112
+ # +---------- Diffusion module ----------+
113
+ diffusion_module:
114
+ sigma_data: 16
115
+ c_token: 768
116
+ f_pred: edm
117
+ diffusion_conditioning:
118
+ c_s_inputs: ${model.net.c_s_inputs}
119
+ c_t_embed: 256
120
+ relative_position_encoding:
121
+ r_max: 32
122
+ s_max: 2
123
+ atom_attention_encoder:
124
+ c_tokenpair: ${model.net.c_z}
125
+ c_atom_1d_features: 389
126
+ use_inv_dist_squared: true
127
+ atom_1d_features:
128
+ - ref_pos
129
+ - ref_charge
130
+ - ref_mask
131
+ - ref_element
132
+ - ref_atom_name_chars
133
+ atom_transformer:
134
+ n_queries: 32
135
+ n_keys: 128
136
+ diffusion_transformer:
137
+ n_block: 3
138
+ diffusion_transformer_block:
139
+ n_head: 4
140
+ no_residual_connection_between_attention_and_transition: true
141
+ kq_norm: true
142
+ broadcast_trunk_feats_on_1dim_old: false
143
+ use_chiral_features: true
144
+ no_grad_on_chiral_center: false
145
+ diffusion_transformer:
146
+ n_block: 24
147
+ diffusion_transformer_block:
148
+ n_head: 16
149
+ no_residual_connection_between_attention_and_transition: true
150
+ kq_norm: true
151
+ atom_attention_decoder:
152
+ atom_transformer:
153
+ n_queries: 32
154
+ n_keys: 128
155
+ diffusion_transformer:
156
+ n_block: 3
157
+ diffusion_transformer_block:
158
+ n_head: 4
159
+ no_residual_connection_between_attention_and_transition: true
160
+ kq_norm: true
161
+ distogram_head:
162
+ bins: 65
163
+
164
+ # +---------- Inference sampler ----------+
165
+ inference_sampler:
166
+ solver: "af3"
167
+ num_timesteps: 200
168
+ min_t: 0
169
+ max_t: 1
170
+ sigma_data: ${model.net.diffusion_module.sigma_data}
171
+ s_min: 4e-4
172
+ s_max: 160
173
+ p: 7
174
+ gamma_0: 0.8
175
+ gamma_min: 1.0
176
+ noise_scale: 1.003
177
+ step_scale: 1.5
@@ -0,0 +1,45 @@
1
+ defaults:
2
+ - rf3_net
3
+
4
+ # Model architecture
5
+ _target_: rf3.model.RF3.RF3WithConfidence
6
+
7
+ # +---------- Mini rollout sampler ----------+
8
+ # From the AF-3 main text:
9
+ # > ... To remedy this, we developed a diffusion ‘rollout’ procedure for the full-structure prediction generation during training (using a larger step size than normal)
10
+ # They do not further elaborate on how they adjusted the step size during diffusion rollout, but this may be a fruitful area of exploration moving forwards
11
+ mini_rollout_sampler:
12
+ solver: "af3"
13
+ num_timesteps: 20 # 20 timesteps for the mini-rollout (vs. 200 for the full rollout during inference)
14
+ min_t: 0
15
+ max_t: 1
16
+ sigma_data: ${model.net.diffusion_module.sigma_data}
17
+ s_min: 4e-4
18
+ s_max: 160
19
+ p: 7
20
+ gamma_0: 0.8
21
+ gamma_min: 1.0
22
+ noise_scale: 1.003
23
+ step_scale: 1.5
24
+
25
+ # +---------- Confidence head architecture ----------+
26
+ confidence_head:
27
+ c_s: ${model.net.c_s}
28
+ c_z: ${model.net.c_z}
29
+ n_pairformer_layers: 4
30
+ pairformer:
31
+ p_drop: 0.25
32
+ triangle_multiplication:
33
+ d_hidden: 128
34
+ triangle_attention:
35
+ n_head: 4
36
+ d_hidden: 32
37
+ attention_pair_bias:
38
+ n_head: 16
39
+ n_bins_pae: 64
40
+ n_bins_pde: 64
41
+ n_bins_plddt: 50
42
+ n_bins_exp_resolved: 2
43
+ use_Cb_distances: False
44
+ use_af3_style_binning_and_final_layer_norms: True
45
+ symmetrize_Cb_logits: True
@@ -0,0 +1,5 @@
1
+ # Optimizer
2
+ _target_: torch.optim.Adam
3
+ lr: 0 # Will be set by the scheduler (starts at 0, increasing to `base_lr`)
4
+ betas: [0.9, 0.95]
5
+ eps: 1.0e-8
@@ -0,0 +1,43 @@
1
+ defaults:
2
+ - optimizers/adam@optimizer
3
+ - schedulers/af3@lr_scheduler
4
+ - components/ema@ema
5
+ - components/rf3_net@net
6
+ - _self_
7
+
8
+ net:
9
+ feature_initializer:
10
+ input_feature_embedder:
11
+ atom_attention_encoder:
12
+ c_atom_1d_features: 393 # 392 + 1 has_atom_level_embedding = 393
13
+ atom_1d_features:
14
+ - ref_pos
15
+ - ref_charge
16
+ - ref_mask
17
+ - ref_element
18
+ - ref_atom_name_chars
19
+ - ref_pos_ground_truth
20
+ - has_atom_level_embedding
21
+ use_atom_level_embedding: true
22
+ atom_level_embedding_dim: 384
23
+
24
+ recycler:
25
+ msa_module:
26
+ msa_subsample_embedder:
27
+ dim_raw_msa: 35
28
+ template_embedder:
29
+ raw_template_dim: 66
30
+
31
+ diffusion_module:
32
+ atom_attention_encoder:
33
+ c_atom_1d_features: 393 # 392 + 1 has_atom_level_embedding = 393
34
+ atom_1d_features:
35
+ - ref_pos
36
+ - ref_charge
37
+ - ref_mask
38
+ - ref_element
39
+ - ref_atom_name_chars
40
+ - ref_pos_ground_truth
41
+ - has_atom_level_embedding
42
+ use_atom_level_embedding: true
43
+ atom_level_embedding_dim: 384
@@ -0,0 +1,7 @@
1
+ defaults:
2
+ - components/rf3_net_with_confidence_head@net
3
+ - rf3
4
+ - _self_
5
+
6
+ net:
7
+ _target_: rf3.model.RF3.RF3WithConfidence
@@ -0,0 +1,6 @@
1
+ # Learning rate scheduler
2
+ _target_: foundry.training.schedulers.AF3Scheduler
3
+ base_lr: 1.8e-3
4
+ warmup_steps: 1000
5
+ decay_factor: 0.95
6
+ decay_steps: 50000
@@ -0,0 +1,43 @@
1
+ ########################
2
+ # Datasets
3
+ ########################
4
+
5
+ # path to directory with training splits
6
+ pdb_data_dir: /projects/ml/datahub/dfs/af3_splits/2025_07_13
7
+
8
+ # fb monomer distillation dataset
9
+ monomer_distillation_data_dir: /squash/af2_distillation_facebook
10
+ monomer_distillation_parquet_dir: /projects/ml/datahub/dfs/distillation/af2_distillation_facebook
11
+
12
+ mgnify_distillation_data_dir: /squash/mgnify_distill_rf3/
13
+ mgnify_distillation_parquet_dir: /home/dimaio/MGnify/
14
+
15
+ # na complex distill set
16
+ na_complex_distillation_data_dir: /projects/ml/prot_dna/rf3_newDL
17
+ na_complex_distillation_parquet_dir: /projects/ml/prot_dna
18
+
19
+ # disorder distill set
20
+ disorder_distill_parquet_dir: /projects/ml/disorder_distill
21
+
22
+ ########################
23
+ # MSAs
24
+ ########################
25
+
26
+ # path(s) to search for protein MSAs (for PDB datasets)
27
+ # e.g., {"dir": "/path/to/msas", "extension": ".a3m.gz", "directory_depth": 2}
28
+ protein_msa_dirs:
29
+ - {"dir": "/projects/msa/hhblits", "extension": ".a3m.gz", "directory_depth": 2}
30
+ - {"dir": "/projects/msa/mmseqs_gpu", "extension": ".a3m.gz", "directory_depth": 2}
31
+ - {"dir": "/projects/msa/lab", "extension": ".a3m.gz", "directory_depth": 1}
32
+
33
+ # path(s) to search for RNA MSAs
34
+ # e.g., {"dir": "/path/to/msas", "extension": ".afa", "directory_depth": 0}
35
+ rna_msa_dirs:
36
+ - {"dir": "/projects/msa/rna", "extension": ".afa", "directory_depth": 0}
37
+
38
+ ########################
39
+ # Misc
40
+ ########################
41
+
42
+ # path to save examples that fail during the Transform pipeline (null = do not save)
43
+ failed_examples_dir: null
@@ -0,0 +1,21 @@
1
+ # NOTE: order of defaults determines the order in which configs override each other (higher up items are overridden by lower items)
2
+ defaults:
3
+ - _self_
4
+ - data: default
5
+
6
+ # path to root directory (requires the `PROJECT_ROOT` environment variable to be set)
7
+ #  NOTE: This variable is auto-set upon loading via `rootutils`
8
+ root_dir: ${oc.env:PROJECT_ROOT}
9
+
10
+ # where to store data (checkpoints, logs, etc.) of all experiments in general
11
+ # (this influences the output_dir in the hydra/default.yaml config)
12
+ # change this to e.g. /scratch if you are running larger experiments with lots of logs, checkpoints, etc.
13
+ log_dir: /net/scratch/${oc.env:USER}/training/logs
14
+
15
+ # path to output directory for this specific run, created dynamically by hydra
16
+ # path generation pattern is specified in `configs/hydra/default.yaml`
17
+ # use it to store all files generated during the run, like ckpts and metrics
18
+ output_dir: ${hydra:runtime.output_dir}
19
+
20
+ # path to working directory (auto-generated by hydra)
21
+ work_dir: ${hydra:runtime.cwd}
rf3/configs/train.yaml ADDED
@@ -0,0 +1,42 @@
1
+ # @package _global_
2
+ # ^ The "package" determines where the content of the config is placed in the output config
3
+ # For more information about overriding configs, see: https://hydra.cc/docs/advanced/overriding_packages/#overriding-packages-using-the-defaults-list
4
+
5
+ # NOTE: order of defaults determines the order in which configs override each other (higher up items are overridden by lower items)
6
+ defaults:
7
+ - callbacks: default
8
+ - logger: csv
9
+ - trainer: ???
10
+ - paths: default
11
+ - datasets: ???
12
+ - dataloader: default
13
+ - hydra: default
14
+ - model: ???
15
+ # We must keep _self_ before experiment and debug to ensure that the experiment and debug configs can override
16
+ - _self_
17
+
18
+ # experiment configs allow for version control of specific hyperparameters
19
+ # e.g. best hyperparameters for given model and datamodule
20
+ - experiment: ???
21
+
22
+ # debug configs to add onto any experiment for quickly testing or debugging code
23
+ - debug: null
24
+
25
+
26
+ # DO NOT set these here. Set them in the relevant experiment config file.
27
+ # ... these are just here to ensure users always specify these fields in their experiment configs.
28
+ name: ???
29
+ tags: ???
30
+
31
+ # NOTE: These values will be overwritten by the experiment config if they are set there. They are just provided as defaults
32
+ # here.
33
+ # ... task name (determines the output directory path)
34
+ task_name: "train"
35
+
36
+ project: ??? # required for W&B logging
37
+
38
+ seed: 1
39
+
40
+ # Provide checkpoint path to resume training from a checkpoint
41
+ # NOTE: If using W&B, must also set the `id` and `resume` fields in the `logger/wandb` config
42
+ ckpt_path: null
@@ -0,0 +1,6 @@
1
+ defaults:
2
+ - af3
3
+
4
+ accelerator: cpu
5
+ devices_per_node: 1
6
+ num_nodes: 1
@@ -0,0 +1,5 @@
1
+ strategy: ddp
2
+
3
+ accelerator: gpu
4
+ devices_per_node: 1
5
+ num_nodes: 1
@@ -0,0 +1,29 @@
1
+ _target_: rf3.loss.af3_confidence_loss.ConfidenceLoss
2
+ weight: 1.0
3
+
4
+ plddt:
5
+ weight: 1.0
6
+ n_bins: 50
7
+ max_value: 1.0
8
+
9
+ pae:
10
+ weight: 1.0
11
+ n_bins: 64
12
+ max_value: 32
13
+
14
+ pde:
15
+ weight: 1.0
16
+ n_bins: 64
17
+ max_value: 32
18
+
19
+ exp_resolved:
20
+ weight: 1.0
21
+ n_bins: 2
22
+ max_value: 1
23
+
24
+ # Adds to loss_dict true and predicted average plddt, pae, and pde per batch, also info about the spread and correlation of those values within a batch
25
+ log_statistics: True
26
+
27
+ rank_loss:
28
+ use_listnet_loss: False
29
+ weight: 0.0
@@ -0,0 +1,9 @@
1
+ _target_: rf3.loss.af3_losses.DiffusionLoss
2
+ weight: 4.0
3
+ sigma_data: ${model.net.diffusion_module.sigma_data}
4
+ alpha_dna: 5
5
+ alpha_rna: 5
6
+ alpha_ligand: 10
7
+ edm_lambda: True
8
+ se3_invariant_loss: True
9
+ clamp_diffusion_loss: False
@@ -0,0 +1,2 @@
1
+ _target_: rf3.loss.af3_losses.DistogramLoss
2
+ weight: 3e-2
@@ -0,0 +1,4 @@
1
+ defaults:
2
+ # Note that the SmoothedLDDTLoss is included within the DiffusionLoss
3
+ - losses/diffusion_loss@diffusion_loss
4
+ - losses/distogram_loss@distogram_loss
@@ -0,0 +1,2 @@
1
+ defaults:
2
+ - losses/confidence_loss@confidence_loss
@@ -0,0 +1,14 @@
1
+ by_type_lddt:
2
+ _target_: rf3.metrics.lddt.ByTypeLDDT
3
+ all_atom_lddt:
4
+ _target_: rf3.metrics.lddt.AllAtomLDDT
5
+ distogram:
6
+ _target_: rf3.metrics.distogram.DistogramLoss
7
+ distogram_comparisons:
8
+ _target_: rf3.metrics.distogram.DistogramComparisons
9
+ distogram_entropy:
10
+ _target_: rf3.metrics.distogram.DistogramEntropy
11
+ chiral_loss:
12
+ _target_: rf3.metrics.chiral.ChiralLoss
13
+ unresolved_rasa:
14
+ _target_: rf3.metrics.rasa.UnresolvedRegionRASA
@@ -0,0 +1,20 @@
1
+ defaults:
2
+ - ddp
3
+ - loss: structure_prediction
4
+ - metrics: structure_prediction
5
+
6
+ _target_: rf3.trainers.rf3.RF3Trainer
7
+ validate_every_n_epochs: 1
8
+ max_epochs: 10_000
9
+ n_examples_per_epoch: 24000
10
+ prevalidate: True
11
+
12
+ # We must pre-specify the number of recycles during training so we can pre-sample recycles per batch consistently for each GPU
13
+ n_recycles_train: ${datasets.n_recycles_train}
14
+
15
+ clip_grad_max_norm: 10.0
16
+
17
+ output_dir: ${paths.output_dir}
18
+ checkpoint_every_n_epochs: 1
19
+
20
+ precision: bf16-mixed
@@ -0,0 +1,13 @@
1
+ defaults:
2
+ - rf3
3
+ - override loss: structure_prediction_with_confidence
4
+
5
+ _target_: rf3.trainers.rf3.RF3TrainerWithConfidence
6
+
7
+ metrics:
8
+ ptm:
9
+ _target_: rf3.metrics.predicted_error.ComputePTM
10
+ iptm:
11
+ _target_: rf3.metrics.predicted_error.ComputeIPTM
12
+ count_clashing_chains:
13
+ _target_: rf3.metrics.clashing_chains.CountClashingChains
@@ -0,0 +1,45 @@
1
+ # @package _global_
2
+ # ^ The "package" determines where the content of the config is placed in the output config
3
+ # For more information about overriding configs, see: https://hydra.cc/docs/advanced/overriding_packages/#overriding-packages-using-the-defaults-list
4
+
5
+ # NOTE: order of defaults determines the order in which configs override each other (higher up items are overridden by lower items)
6
+ defaults:
7
+ - callbacks: default
8
+ - logger: csv
9
+ - trainer: ???
10
+ - paths: default
11
+ - datasets: ???
12
+ - dataloader: default
13
+ - hydra: default
14
+ - model: ???
15
+ # We must keep _self_ before experiment and debug to ensure that the experiment and debug configs can override
16
+ - _self_
17
+
18
+ # experiment configs allow for version control of specific hyperparameters
19
+ # e.g. best hyperparameters for given model and datamodule
20
+ - experiment: ???
21
+
22
+ # debug configs to add onto any experiment for quickly testing or debugging code
23
+ - debug: null
24
+
25
+
26
+ # DO NOT set these here. Set them in the relevant experiment config file.
27
+ # ... these are just here to ensure users always specify these fields in their experiment configs.
28
+ name: ???
29
+ tags: ???
30
+
31
+ # NOTE: These values will be overwritten by the experiment config if they are set there. They are just provided as defaults
32
+ # here.
33
+ # ... task name (determines the output directory path)
34
+ task_name: "validate"
35
+
36
+ project: ??? # required for W&B logging
37
+
38
+ seed: 1
39
+
40
+ # Dump CIF files for validation structures
41
+ callbacks:
42
+ dump_validation_structures_callback:
43
+ dump_predictions: True
44
+ one_model_per_file: False
45
+ dump_trajectories: False
rfd3/cli.py CHANGED
@@ -12,10 +12,16 @@ app = typer.Typer()
12
12
  def design(ctx: typer.Context):
13
13
  """Run design using hydra config overrides and input files."""
14
14
  # Find the RFD3 configs directory relative to this file
15
- # This file is at: models/rfd3/src/rfd3/cli.py
16
- # Configs are at: models/rfd3/configs/
17
- rfd3_package_dir = Path(__file__).parent.parent.parent # Go up to models/rfd3/
18
- config_path = str(rfd3_package_dir / "configs")
15
+ # Development: models/rfd3/src/rfd3/cli.py -> models/rfd3/configs/
16
+ # Installed: site-packages/rfd3/cli.py -> site-packages/rfd3/configs/
17
+
18
+ # Try development location first
19
+ dev_config_path = Path(__file__).parent.parent.parent / "configs"
20
+ if dev_config_path.exists():
21
+ config_path = str(dev_config_path)
22
+ else:
23
+ # Fall back to installed package location
24
+ config_path = str(Path(__file__).parent / "configs")
19
25
 
20
26
  # Get all arguments
21
27
  args = ctx.params.get("args", []) + ctx.args
File without changes
@@ -0,0 +1,10 @@
1
+ defaults:
2
+ - train_logging
3
+ - _self_
4
+
5
+ log_learning_rate_callback:
6
+ log_every_n: 25 # default 10
7
+
8
+ log_af3_training_losses_callback:
9
+ log_full_batch_losses: False
10
+ log_every_n: 25 # default 10
@@ -0,0 +1,20 @@
1
+ store_validation_metrics_in_df_callback:
2
+ _target_: foundry.callbacks.metrics_logging.StoreValidationMetricsInDFCallback
3
+ save_dir: ${paths.output_dir}/val_metrics
4
+ metrics_to_save: "all"
5
+
6
+ dump_validation_structures_callback:
7
+ _target_: rfd3.trainer.dump_validation_structures.DumpValidationStructuresCallback
8
+ save_dir: ${paths.output_dir}/val_structures
9
+ dump_predictions: True
10
+ dump_prediction_metadata_json: True
11
+ dump_trajectories: False
12
+ dump_denoised_trajectories_only: False
13
+
14
+ one_model_per_file: True
15
+ dump_every_n: 4
16
+ align_trajectories: False
17
+ verbose: False
18
+
19
+ log_design_validation_metrics_callback:
20
+ _target_: rfd3.callbacks.LogDesignValidationMetricsCallback
@@ -0,0 +1,24 @@
1
+ log_af3_training_losses_callback:
2
+ _target_: foundry.callbacks.train_logging.LogAF3TrainingLossesCallback
3
+ log_every_n: 10
4
+ log_full_batch_losses: true
5
+
6
+ log_learning_rate_callback:
7
+ _target_: foundry.callbacks.train_logging.LogLearningRateCallback
8
+ log_every_n: 10
9
+
10
+ log_model_parameters_callback:
11
+ _target_: foundry.callbacks.train_logging.LogModelParametersCallback
12
+
13
+ log_dataset_sampling_ratios_callback:
14
+ _target_: foundry.callbacks.train_logging.LogDatasetSamplingRatiosCallback
15
+
16
+ # Optional health logging
17
+ # activations_tracking_callback:
18
+ # _target_: foundry.callbacks.health_logging.ActivationsGradientsWeightsTracker
19
+ # log_freq: 100
20
+ # keep_cache: True # --> WARNING: Do not run this in a production run, this will lead to a memory leak! Meant for debugging.
21
+ # activations_tracking_callback:
22
+ # _target_: foundry.callbacks.health_logging.ActivationsGradientsWeightsTracker
23
+ # log_freq: 100
24
+ # keep_cache: True # --> WARNING: Do not run this in a production run, this will lead to a memory leak! Meant for debugging.
@@ -0,0 +1,15 @@
1
+ train:
2
+ dataloader_params:
3
+ # These parameters will be unpacked as kwargs for the DataLoader
4
+ batch_size: 1
5
+ num_workers: 2
6
+ prefetch_factor: 3
7
+ n_fallback_retries: 4
8
+
9
+ val:
10
+ dataloader_params:
11
+ # These parameters will be unpacked as kwargs for the DataLoader
12
+ batch_size: 1
13
+ num_workers: 2
14
+ prefetch_factor: 3
15
+ n_fallback_retries: 0 # Disable fallback retries for validation
@@ -0,0 +1,11 @@
1
+ defaults:
2
+ - default
3
+
4
+ train:
5
+ dataloader_params:
6
+ num_workers: 2
7
+ prefetch_factor: 6
8
+ val:
9
+ dataloader_params:
10
+ num_workers: 2
11
+ prefetch_factor: 6