rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
foundry/__init__.py,sha256=H8S1nl5v6YeW8ggn1jKy4GdtH7c-FGS-j7CqUCAEnAU,1926
|
|
2
|
+
foundry/common.py,sha256=Aur8mH-CNmcUqSsw7VgaCQSW5sH1Bqf8Da91jzxPV1Y,3035
|
|
3
|
+
foundry/constants.py,sha256=0n1wBKCvNuw3QaQehSbmsHYkIdaGn3tLeRFItBrdeHY,913
|
|
4
|
+
foundry/version.py,sha256=m8HxkqoKGw_wAJtc4ZokpJKNLXqp4zwnNhbnfDtro7w,704
|
|
5
|
+
foundry/callbacks/__init__.py,sha256=VsRT1e4sqlJHPcTCsfupMEx82Iz-LoOAGPpwvf_OJeE,126
|
|
6
|
+
foundry/callbacks/callback.py,sha256=xZBo_suP4bLrP6gl5uJPbaXm00DXigePa6dMeDxucgg,3890
|
|
7
|
+
foundry/callbacks/health_logging.py,sha256=tEtkByOlaAA7nnelxb7PbM9_dcIgOsdbxCdQY3K5pMc,16664
|
|
8
|
+
foundry/callbacks/metrics_logging.py,sha256=Vekzs831d-HE7TfLJZnQ45iPeG9ziQWLQaMBGaymfQM,8696
|
|
9
|
+
foundry/callbacks/timing_logging.py,sha256=u-r0hKp7fWOY3mLk7CcuIwHgZbhte13m5M09xNgatZA,2343
|
|
10
|
+
foundry/callbacks/train_logging.py,sha256=Xs3tmZA88qLxmdSOwt-x8YKN4NKb1kVm59uptNXl4Qo,10399
|
|
11
|
+
foundry/hydra/resolvers.py,sha256=xyJzo6OeWAc_LOu8RiHhX7_CRNoLZ22626AvYHXYl4U,2186
|
|
12
|
+
foundry/inference_engines/base.py,sha256=qv5Gnk6NIxMpIxZ3oeOJurqMMUzBCZgfzHckb7SSzmU,8227
|
|
13
|
+
foundry/inference_engines/checkpoint_registry.py,sha256=xRN3PjtmcnN7aEEhDR0MKhp1yaMDXHRXdLAGTDxi_Yk,2563
|
|
14
|
+
foundry/metrics/__init__.py,sha256=qL4wwaiQ7EtR30pmZ9MCknqx909BJcNvHVmNJUaz_WM,236
|
|
15
|
+
foundry/metrics/losses.py,sha256=2CLUmf7oCdFUCvgJukdNkff0FVG3BlATI-NI60TtpVY,903
|
|
16
|
+
foundry/metrics/metric.py,sha256=23pKh_Ra0EcHGo5cSzYQQrUGr5zWRxeufKSJ58tfXXo,12687
|
|
17
|
+
foundry/model/layers/blocks.py,sha256=ihbbP_1fOlrkrcrQSk9thCrNWjK8mtxD3WxcBng9Htk,1403
|
|
18
|
+
foundry/testing/__init__.py,sha256=BnrU7fZ4l0Dm1vrGcNPQYTAw83PW4DGYz7TGhGqgrfQ,223
|
|
19
|
+
foundry/testing/fixtures.py,sha256=j27a8CAonygjlWsUjZ-95M5MF4Rjp9nw7JskqiZlweI,486
|
|
20
|
+
foundry/testing/pytest_hooks.py,sha256=5Ebw1GXYO2XqS9Jvpzty7g3gCXIdXu16jqg53XcuUx4,450
|
|
21
|
+
foundry/trainers/fabric.py,sha256=cjaTHbGuJEQwaGBvIAXD_il4bHtY-crsTY14Xn77uXA,40401
|
|
22
|
+
foundry/training/EMA.py,sha256=3OWA9Pz7XuDr-SRxbz24tZf55DmhSa2fKy9r5v2IXqA,2651
|
|
23
|
+
foundry/training/checkpoint.py,sha256=mUiObg-qcF3tvMfVu77sD9m3yVRp71czv07ccliU7qQ,1791
|
|
24
|
+
foundry/training/schedulers.py,sha256=StmXegPfIdLAv31FreCTrDh9dsOvNUfzG4YGa61Y4oE,3647
|
|
25
|
+
foundry/utils/alignment.py,sha256=OAN7H2TqraGxP1uMXUpwLO7g0qS0cxUVjuV33pY16z0,2316
|
|
26
|
+
foundry/utils/components.py,sha256=Piw2TfQF26uuxC3hXG3iv_4rgud1lVO-cv6N-p05EDY,15200
|
|
27
|
+
foundry/utils/datasets.py,sha256=pLBxVezm-TSrYuC5gFnJZdGnNWV7aPH2QiWIVE2hkdQ,16629
|
|
28
|
+
foundry/utils/ddp.py,sha256=ydHrO6peGbRnWAwgH5rmpHuQd55g2gFzzoZJYypn7GU,3970
|
|
29
|
+
foundry/utils/instantiators.py,sha256=oGCp6hrmY-QPPPEjxKxe5uVFL125fH1RaLxjMKWCD_8,2169
|
|
30
|
+
foundry/utils/logging.py,sha256=jrDgiB_56q_hWDc0jkBFekvqnNWcowJBt4B-S-ipJmM,9312
|
|
31
|
+
foundry/utils/rigid.py,sha256=_Z1pmitb6xgxyguLj_TukKscUBJjQsU4bsBD24GVS84,44444
|
|
32
|
+
foundry/utils/rotation_augmentation.py,sha256=7q1WEX2iJ0i7-2aV-M97nEaEdpqexDTaZn5JquYpkUk,1927
|
|
33
|
+
foundry/utils/squashfs.py,sha256=QlcwuJyVe-QVfIOS7o1QfLhaCQPNzzox7ln4n8dcYEg,5234
|
|
34
|
+
foundry/utils/torch.py,sha256=OLsqoxw4CTXbGzWUHernLUT7uQjLu0tVPtD8h8747DI,11211
|
|
35
|
+
foundry/utils/weights.py,sha256=btz4S02xff2vgiq4xMfiXuhK1ERafqQPtmimo1DmoWY,10381
|
|
36
|
+
foundry_cli/__init__.py,sha256=0BxY2RUKJLaMXUGgypPCwlTskTEFdVnkhTR4C4ft2Kw,52
|
|
37
|
+
foundry_cli/download_checkpoints.py,sha256=2PDKw-dWCht_mD6fRYTiOtVlk4P2CQqTPSkN-7s19mk,8474
|
|
38
|
+
mpnn/__init__.py,sha256=hgQcXFaCbAxFrhydVAy0xj8yC7UJF-GCCFhqD0sZ7I4,57
|
|
39
|
+
mpnn/inference.py,sha256=wPtGR325eVRVeesXoWtBK6b_-VcU8BZae5IfQN3-mvA,1669
|
|
40
|
+
mpnn/train.py,sha256=9eQGBd3rdNF5Zr2w8oUgETbqxBavNBajtA6Vbc5zESE,10239
|
|
41
|
+
mpnn/collate/feature_collator.py,sha256=LpzAFWo1VMa06dJLmfUWZsKe4xvLZjHbx4RICg2lgbQ,10510
|
|
42
|
+
mpnn/inference_engines/mpnn.py,sha256=985Ce84dWKNZk_5_dk1eJScYIgBGUKUhy-xDVCu15PA,21631
|
|
43
|
+
mpnn/loss/nll_loss.py,sha256=KmdNe-BCzGYtijjappzBArQcT1gHVlJnKdY1PYQ4mhU,5947
|
|
44
|
+
mpnn/metrics/nll.py,sha256=T6oMeUOEeHZzOMTH8NHFtsH9vUwLAsHQDPszzj4YKXI,15299
|
|
45
|
+
mpnn/metrics/sequence_recovery.py,sha256=YDw_LmH-a3ajBYWK0mucJEQvw0_VEyxvrBN7da4vX8Q,19034
|
|
46
|
+
mpnn/model/mpnn.py,sha256=vhkair2tYoId_akRP2qEq5O0IMZv6wsv9Q-V9plKgV8,131144
|
|
47
|
+
mpnn/model/layers/graph_embeddings.py,sha256=aEtd7iorMh8DxNH0eZVrK_zOo8HDLM5FRJyIJ8Cfz6k,99795
|
|
48
|
+
mpnn/model/layers/message_passing.py,sha256=TUkG9pXuo4Rtz5Bcij-OB7T4gSKmLt1KgxNmjJYPcMY,13051
|
|
49
|
+
mpnn/model/layers/position_wise_feed_forward.py,sha256=FATM8oveWy2XW-PDaaF9XLPIiWbehOHxG715E60n_8g,1602
|
|
50
|
+
mpnn/model/layers/positional_encoding.py,sha256=f-YpH1xvPGFC75U2-sOHrK13XtA9ZAWxjxxH1GrDt1M,4876
|
|
51
|
+
mpnn/pipelines/mpnn.py,sha256=SukwxEcAzaCgUZKcA1_KusodvcCg3_buN1dAZU-Udas,6185
|
|
52
|
+
mpnn/samplers/samplers.py,sha256=LDpetPMVklMboj1tucgnNvSHRUaQuehBmR2jFl4VWIE,6129
|
|
53
|
+
mpnn/trainers/mpnn.py,sha256=waXLQ-7pFD8MJRlnK37mHWcvqD6uOjTXVP6910tB6cw,6586
|
|
54
|
+
mpnn/transforms/polymer_ligand_interface.py,sha256=lipKDt_NFrpM-GiOXtvnTAvMpISOO4eHwilCgxnISJU,6106
|
|
55
|
+
mpnn/transforms/feature_aggregation/mpnn.py,sha256=jkhyMCqJipKQ2PvjqPkvvClhoiXx_I8e03lnDeH9__M,6324
|
|
56
|
+
mpnn/transforms/feature_aggregation/polymer_ligand_interface.py,sha256=gDdt9RZd0PO0YJdouNr0qsHFZV1i-5ewU6XuJrwPY54,2870
|
|
57
|
+
mpnn/transforms/feature_aggregation/token_encodings.py,sha256=qVlUky4HcDSU5drrZpZBnUvTSGdT6C7MN8f_owa81Bw,2227
|
|
58
|
+
mpnn/transforms/feature_aggregation/user_settings.py,sha256=uKyIDXz-QG0-KWQO1kqPlMj6i7RoVM6yH4iGNXFStoU,15007
|
|
59
|
+
mpnn/utils/inference.py,sha256=QLeukqLpedMNmvjbYgvLwDS5k7Q__NWILDSEbETkoCI,96539
|
|
60
|
+
mpnn/utils/probability.py,sha256=EYisliXNGXjuSPbzZwcIKjlhyINikGsqQndGBEbQoPI,990
|
|
61
|
+
mpnn/utils/weights.py,sha256=VsaIcOWTv8G-WJ9denxLRm3FQ9l6L66AVQN08E9BMSg,16411
|
|
62
|
+
rf3/__init__.py,sha256=XBb5hF2RqBPHODGRmjiRbfTXgOGfOzdY91GbS4Vex00,70
|
|
63
|
+
rf3/_version.py,sha256=fCfpbI5aeA6yHqjo3tK78-l2dPGxhp-AyKSoCXp34Nc,739
|
|
64
|
+
rf3/alignment.py,sha256=BvvwMqQGCVxV20xIsTighD1kXMadXXL2SkckLjTerx0,2102
|
|
65
|
+
rf3/chemical.py,sha256=VECnRPgVm-icXbZeUG4svcENzdUiIupP6dhka_8zCrg,26572
|
|
66
|
+
rf3/cli.py,sha256=dPKJFRHYoV2XS6xc_ZmdLTz6frqa6HZg4qgZU5oJcXU,2356
|
|
67
|
+
rf3/inference.py,sha256=_AAJ07AfSeU3xTM2_KH9n_H12EK4qZ23IJuyauOrMaQ,2466
|
|
68
|
+
rf3/kinematics.py,sha256=V3yjalPupu1X2FEp7l3XZR-qzLKrhWLZyECk6RgIkcs,10901
|
|
69
|
+
rf3/scoring.py,sha256=dTllswE-6Fgli2eLiNzLFc2Rhz4ouDT4WL-sVbvLTGU,41541
|
|
70
|
+
rf3/train.py,sha256=V4nqCC_1JKLI3WQ-nErNa8sqFpvb1mFhXSe6ZPpEheM,7945
|
|
71
|
+
rf3/util_module.py,sha256=ltc7QXJDb5Z184wxYQuT_-Z68YWXuPmOEBMLSzS6Pes,1428
|
|
72
|
+
rf3/validate.py,sha256=wXZLTWiOdTsCKiKK2_Dfnj5PInDiv5KxojOZwaUJjuo,5832
|
|
73
|
+
rf3/callbacks/dump_validation_structures.py,sha256=j_pDfPETyI10ZtsUlvf16-zpJdaUcC5w8TEYCk--Xbo,3909
|
|
74
|
+
rf3/callbacks/metrics_logging.py,sha256=MYcM_ZYKsBTJKx2xi9H0QPr5Lh1o80_bTZXH8kfV5y8,13429
|
|
75
|
+
rf3/data/cyclic_transform.py,sha256=Cs4x_qooCUXKNiFeVdfrXAGhZ0z_yyedscWRsGmEpwM,3351
|
|
76
|
+
rf3/data/extra_xforms.py,sha256=Pxv4Nt4Ir_Ca92XWx_c6mdaCl7fS3_gFCWG775WWVSQ,1197
|
|
77
|
+
rf3/data/ground_truth_template.py,sha256=dct1bGQ7AMjiNyLIotKJZzbbUI546mc-UDdRpKMU7vU,19460
|
|
78
|
+
rf3/data/paired_msa.py,sha256=aso9awdKthzsm1ITHHO8r-1O5vFDrDL-ot7FhXB7Css,7765
|
|
79
|
+
rf3/data/pipeline_utils.py,sha256=qLNcy2iTmZxwCJJk9etXCn_vRljUYhGqSn0uNsRDGO0,5008
|
|
80
|
+
rf3/data/pipelines.py,sha256=3yy8f4pUSiV9mnhDmboO2RAsIlIJ0WYCzTzXT3OGQ5M,20913
|
|
81
|
+
rf3/diffusion_samplers/inference_sampler.py,sha256=dNhVGQujfcMhIlFdlLjZHiWXBnaLjOjRBd58eM6SJv4,9014
|
|
82
|
+
rf3/inference_engines/__init__.py,sha256=tFyTBEMsQ2oR-x2TE9-7B3HIlkHFmoOP2d6hfomj_hc,121
|
|
83
|
+
rf3/inference_engines/rf3.py,sha256=rPJaY355JUvqq_UOAM7gSHzYBE6keN0zPG57bjUb0qU,29761
|
|
84
|
+
rf3/loss/af3_confidence_loss.py,sha256=X8TLudvIFxD2GHlNLsBHoysO0qeWGdaKdW8It6G7nhE,19697
|
|
85
|
+
rf3/loss/af3_losses.py,sha256=oD-sFIJgNR11OAFK_K-6OeKdeIauEabWNMJX8M9mE6k,24492
|
|
86
|
+
rf3/loss/loss.py,sha256=3aHB8FiA0WkeTacfMUnG9mnoILyjX8AE9345eToubag,6184
|
|
87
|
+
rf3/metrics/chiral.py,sha256=xeZjBv9XFa8Tmo7aFjevwLuRk6GGpRU4CIkT41SvnFI,7139
|
|
88
|
+
rf3/metrics/clashing_chains.py,sha256=El3CILpTMSD-U-5pgUugXigNao9AZ17KxOW1-NX4N38,2518
|
|
89
|
+
rf3/metrics/distogram.py,sha256=lXLPfNtWnt4d3_Vc9A1AsAv3c0eZDRdK5tixgO9rIj8,16978
|
|
90
|
+
rf3/metrics/lddt.py,sha256=VjObCWvjAhElgGFckx_sRx8toUq4TVfr5ip8ThW06Qw,21412
|
|
91
|
+
rf3/metrics/metadata.py,sha256=hFEJ4thQiV8vs3fzj5dxK4BS3nb4PSnhCoBrfQXwD2I,1437
|
|
92
|
+
rf3/metrics/metric_utils.py,sha256=rdTY3Uc4at-Y7jLaDfaEhp17Z2KYqLGETuuCFSN59bY,7125
|
|
93
|
+
rf3/metrics/predicted_error.py,sha256=tsUFyW6Jv8m4REKui0mVQkfACB8PTVOzMstFW1d5pAA,4973
|
|
94
|
+
rf3/metrics/rasa.py,sha256=mQ6ZQdroC4CY3XDiUVWKvtHFCGTxyhDASqYF8SjnQGU,4525
|
|
95
|
+
rf3/metrics/selected_distances.py,sha256=uyTlbPHnf6PpZ6JMRkfDAAY-GkxerjpDkurTOSt3EV0,3620
|
|
96
|
+
rf3/model/RF3.py,sha256=fAJu8FG54tdo7wKcwkDLhorgBu4aBjqNAg4KGJIjgmc,22383
|
|
97
|
+
rf3/model/RF3_blocks.py,sha256=FZliymoYsYpDz_YqPsuXILDZaiE6IGcX1wJoeohueko,3218
|
|
98
|
+
rf3/model/RF3_structure.py,sha256=EnZuYk8dJWgLubp9Mui1f_e6hE0z6XuLn1HX7N_LfCo,9652
|
|
99
|
+
rf3/model/layers/af3_auxiliary_heads.py,sha256=kJqsT0_plh9_TXGN4HpPB1NzGIarCZyAIG5NiB5RN3Q,9937
|
|
100
|
+
rf3/model/layers/af3_diffusion_transformer.py,sha256=uU3OsubqsrGWjTHzEsOb6hzlc1gB5-o7GcK8ElRiLs8,19143
|
|
101
|
+
rf3/model/layers/attention.py,sha256=ofz4LR74oFBZ64RBAcbFlksV2E2NkYM69LoMDc4qedo,11189
|
|
102
|
+
rf3/model/layers/layer_utils.py,sha256=dzrYwvgdKS2ouDnRiseU2x7VBcRcXa1zeHA_EKLyO78,3382
|
|
103
|
+
rf3/model/layers/mlff.py,sha256=SFHskP18xB7zuZAzEeubw7kKwMEGZYlypGS_zZ-02yk,5017
|
|
104
|
+
rf3/model/layers/outer_product.py,sha256=OYam3gsxJa7JBet71_eFZ7D7mXQghluLXMRVHv2xUSs,2037
|
|
105
|
+
rf3/model/layers/pairformer_layers.py,sha256=nV5zwN6CLIqfCK7UTaNX_n67ASgfWp58KJMlj3TBbpg,28724
|
|
106
|
+
rf3/model/layers/structure_bias.py,sha256=xhu_KNRlE_FNfB0dwfHY47xhdFMmIPgV2qeYb_WyepI,1757
|
|
107
|
+
rf3/symmetry/resolve.py,sha256=odQF32aM4BqaOrmfEQJtwQxKgdJL5J0_Htb1I5KEDPA,10843
|
|
108
|
+
rf3/trainers/rf3.py,sha256=P9zLTMu7YaxllBiLgyrn6FnP1vKCOaerA7ciDgmqBrA,24141
|
|
109
|
+
rf3/utils/frames.py,sha256=6LVuV2XbODKNRU_ggGkd0EBXBT7F0q-HXFad4eTUOVs,3745
|
|
110
|
+
rf3/utils/inference.py,sha256=MjNkhoMzwPUb5rCouvEjtW-6XRa3yb471DQtpxzrhdk,24772
|
|
111
|
+
rf3/utils/io.py,sha256=xwOjzWviYKphuMbJgj18dylI72n0oF7mdUp8V5qlsaQ,8051
|
|
112
|
+
rf3/utils/loss.py,sha256=llaiL-5VaNTDMwh0TK_nIzniYPg0zDhIzzM8i8fYCqY,2757
|
|
113
|
+
rf3/utils/predict_and_score.py,sha256=VzRZohertYLMfnT9SwRs1gMGEhspV0LPBKmUjpU5WAY,6209
|
|
114
|
+
rf3/utils/predicted_error.py,sha256=5gIRjsD7bWvYYM30_wq9827porc1oj_Qq-cjyetYJ_s,25438
|
|
115
|
+
rf3/utils/recycling.py,sha256=nRvv0vWMsMG0Ods83XKkxdgmqKMXTw-w02n_BuZOYoo,1491
|
|
116
|
+
rfd3/.gitignore,sha256=935nLWJz_oi5h-UjxP4L_ulsMpkbRIVsl0dgGCwTCbc,109
|
|
117
|
+
rfd3/Makefile,sha256=_O87r1eIN7AmWWIqur3z0tLn1kgAPGEAGX2fcddarMs,2224
|
|
118
|
+
rfd3/__init__.py,sha256=2Wto2IsUIj2lGag9m_gqgdCwBNl5p21-Xnr7W_RpU3c,348
|
|
119
|
+
rfd3/callbacks.py,sha256=Zjt8RiaYWquoKOwRmC_wCUbRbov-V4zd2_73zjhgDHE,2783
|
|
120
|
+
rfd3/cli.py,sha256=TZpZouXGmwAMFaH8hp4r3q9tbUi1xlcN8n_r8hO2q8c,1424
|
|
121
|
+
rfd3/constants.py,sha256=wLvDzrThpOrK8T3wGFNQeGrhAXOJQze8l3v_7pjIdMM,13141
|
|
122
|
+
rfd3/engine.py,sha256=HCIHgB_4xFBYz7d7DGnqibnFcdC7RGXQ6qHgUJEvm2M,20889
|
|
123
|
+
rfd3/run_inference.py,sha256=dubMwEFkNPOq_yYf0ny37qvvEkRjNNPRFksZgmEFkUc,1520
|
|
124
|
+
rfd3/train.py,sha256=rHswffIUhOae3_iYyvAiQ3jALoFuzrcRUgMlbJLinlI,7947
|
|
125
|
+
rfd3/inference/datasets.py,sha256=u-2U7deHXu-iOs7doiKKynewP-NEyJfdORSTDzUSaQI,6538
|
|
126
|
+
rfd3/inference/input_parsing.py,sha256=mk3HBvo7MPTFEET7NagCo5TSjb47w-hxUDoeQxUW_h4,45449
|
|
127
|
+
rfd3/inference/legacy_input_parsing.py,sha256=1wf_KF7qWnGLaVM8IXDl8fIsWCmxtOi2YlAiHEVELqw,28046
|
|
128
|
+
rfd3/inference/parsing.py,sha256=Nq8CYmimnql4RM-5ZfPAvOFvCae4_CC2pYDzE6iCpWU,5290
|
|
129
|
+
rfd3/inference/symmetry/atom_array.py,sha256=HH50Z07bTUnNUgCwAGslADbvMYHgsXn9s-fqwx6BvKw,11034
|
|
130
|
+
rfd3/inference/symmetry/checks.py,sha256=wb7K327GnMwGG9bgOvvDAbaPsFj4nGZpEAolICUapNc,8908
|
|
131
|
+
rfd3/inference/symmetry/contigs.py,sha256=6OvbZ2dJg-a0mvvKAC0VkzUH5HpUDxOJvkByIst_roU,2127
|
|
132
|
+
rfd3/inference/symmetry/frames.py,sha256=G55p-aOXqEYG4kCyKxrgWAsS-gW9-gOTlBME6nhbKyU,10716
|
|
133
|
+
rfd3/inference/symmetry/symmetry_utils.py,sha256=KwgxrdfO766RCEwF3VElAE85oEKiopPGRQDhJbKZaUA,15810
|
|
134
|
+
rfd3/metrics/design_metrics.py,sha256=O1RqZdjQPNlAWYRg6UJTERYg_gUI1_hVleKsm9xbWBY,16836
|
|
135
|
+
rfd3/metrics/hbonds_hbplus_metrics.py,sha256=Sewy9KzmrA1OnfkasN-fmWrQ9IRx9G7Yyhe2ua0mk28,11518
|
|
136
|
+
rfd3/metrics/hbonds_metrics.py,sha256=SIR4BnDhYdpVSqwXXRYpQ_tB-M0_fVyugGl08WivCmE,15257
|
|
137
|
+
rfd3/metrics/losses.py,sha256=GDz0uO2XyYCX1kvKJ1DR5s7wWlELIqqI2PhoCnue8IM,12705
|
|
138
|
+
rfd3/metrics/metrics_utils.py,sha256=o8zmjLq4i4LfoGiJ51rZU7KnH9LX4xEVLqbH0xBIoeI,4501
|
|
139
|
+
rfd3/metrics/sidechain_metrics.py,sha256=EGZuFuWQ0cCe83EVPAf4eysN8vP9ifNjfnmE0o5aIeA,12223
|
|
140
|
+
rfd3/model/RFD3.py,sha256=95aKzye-XzuDyLGgost-Wsfu8eT635zHIRky-pNoHSA,3569
|
|
141
|
+
rfd3/model/RFD3_diffusion_module.py,sha256=BPjKGyQpbnqdzii3gXMKLhhijNqV8Xh4bSosmfDBt8w,12094
|
|
142
|
+
rfd3/model/cfg_utils.py,sha256=XPBLyoB_bQRLmdrJ1Z0hCjcVvgUMGIPuw4rxTlHjB_s,2575
|
|
143
|
+
rfd3/model/inference_sampler.py,sha256=qge7BNJttW0NXgerg3msPY3izxQ-6FsvWSTAMhZ4GJs,24696
|
|
144
|
+
rfd3/model/layers/attention.py,sha256=XuNA7WyFlRfLnAgky1PtGvXFCnDGv7GeEcXz8hodTBo,19472
|
|
145
|
+
rfd3/model/layers/block_utils.py,sha256=EZq2qYUeO6_VCLKDVC60cxfBE_EPwvp84FPmqLr28ZQ,21197
|
|
146
|
+
rfd3/model/layers/blocks.py,sha256=MOjJ53THxM2MMM27Ap7xiIXRCdI_SHzqKzLLQVX6FEc,24888
|
|
147
|
+
rfd3/model/layers/chunked_pairwise.py,sha256=de5Qc3P7GEfZlX-QLaKfJxr6Ky5vgLcWWogatCw2UnY,14582
|
|
148
|
+
rfd3/model/layers/encoders.py,sha256=CqByjHNSbtMIaNP_h2iEJZdTbm-N8SGo1bZgvRNpMJ8,15207
|
|
149
|
+
rfd3/model/layers/layer_utils.py,sha256=UPYo-DYa__93KONSEj2YZWLtBqvYNSA9_wHDDPhVrIc,5710
|
|
150
|
+
rfd3/model/layers/pairformer_layers.py,sha256=uimskhN-Ec0apEXAU6JqomyKX5-6ormrEsCFJotkBtM,3991
|
|
151
|
+
rfd3/testing/debug.py,sha256=JqpSKbSp1l9V_3trLNzpdt3gazqrSOSq7NrmcuGjpJQ,4059
|
|
152
|
+
rfd3/testing/debug_utils.py,sha256=i_GjrsRjeaREv6hlX2sEmeztpo9w9rg7Ne3VT5-YILA,2170
|
|
153
|
+
rfd3/testing/testing_utils.py,sha256=CtpTDxePbCluzuvd6jBfJNI2a3_8Ry2Whbgcf-5upiM,12202
|
|
154
|
+
rfd3/trainer/dump_validation_structures.py,sha256=qY8s2hPBflJTXPiIUnqFFE9g36y_7s39MEcMRrxZUmA,6027
|
|
155
|
+
rfd3/trainer/fabric_trainer.py,sha256=8dcyDSJFviyFU9fp6Ez02CmucKi9-DOEEwHIRcB6kQU,40074
|
|
156
|
+
rfd3/trainer/recycling.py,sha256=nRvv0vWMsMG0Ods83XKkxdgmqKMXTw-w02n_BuZOYoo,1491
|
|
157
|
+
rfd3/trainer/rfd3.py,sha256=9B_FgvTNvTDpZhRVXD1ufIRNrXOnERkFJosxe7Zy8-E,21181
|
|
158
|
+
rfd3/trainer/trainer_utils.py,sha256=1m331JI86uQvBrapLHjjEliGjU3qxafp-v47bTjsx-I,20528
|
|
159
|
+
rfd3/transforms/conditioning_base.py,sha256=A0Z2-v7ttvNa6xArpBdV8srH58gSaMI1J48ULXvQJTg,19517
|
|
160
|
+
rfd3/transforms/conditioning_utils.py,sha256=9Pn9AFbih2FCzp5OOM9y7z6KH7HPxVibxTrfuXiitMs,7498
|
|
161
|
+
rfd3/transforms/design_transforms.py,sha256=ePvnLsuKUOsE4LLcmF0bbkx1vf2AiD-35rzF4zUEcEE,30944
|
|
162
|
+
rfd3/transforms/dna_crop.py,sha256=JeOsG0tXghJvgzEimfzBvlFN_lVd9TrvjnC929Abz5A,18214
|
|
163
|
+
rfd3/transforms/hbonds.py,sha256=ijfJapFlhsh3JktpDoT3VFqKTTg6ynrqMlD7dU2xFsA,16415
|
|
164
|
+
rfd3/transforms/hbonds_hbplus.py,sha256=xyDP-CyVl2OsUY90HsrPoKw1VycBXUrq00WfrX8HJVM,8364
|
|
165
|
+
rfd3/transforms/ncaa_transforms.py,sha256=Lz4L8OGuOOG53sKJHcLSdV7WPQ3YzOzwd5tJG4CHqP0,4983
|
|
166
|
+
rfd3/transforms/pipelines.py,sha256=FGH-XH3taTWQ6k1zpDO_d-097EQdXmL6uqXZXw4HIMs,22086
|
|
167
|
+
rfd3/transforms/ppi_transforms.py,sha256=7rXyf-tn2TLz6ybYR_YVDtSDG7hOgqhYY4shNviA_Sw,23493
|
|
168
|
+
rfd3/transforms/rasa.py,sha256=a4IPFvVMMxldoGLyJQiSlGg7IyUkcBASbRZLWmguAKk,4156
|
|
169
|
+
rfd3/transforms/symmetry.py,sha256=GSnMF7oAnUxPozfafsRuHEv0yKXW0BpLTI6wsKGZrbc,2658
|
|
170
|
+
rfd3/transforms/training_conditions.py,sha256=UXiUPjDwrNKM95tRe0eXrMeRN8XlTPc_MXUvo6UpePo,19510
|
|
171
|
+
rfd3/transforms/util_transforms.py,sha256=2AcLkzx-73ZFgcWD1cIHv7NyniRPI4_zThHK8azyQaY,18119
|
|
172
|
+
rfd3/transforms/virtual_atoms.py,sha256=UpmxzPPd5FaJigcRoxgLSHHrLLOqsCvZ5PPZfQSGqII,12547
|
|
173
|
+
rfd3/utils/inference.py,sha256=RQp5CCy6Z6uHVZ2Mx0zmmGluYEOrASke4bABtfRjpy0,26448
|
|
174
|
+
rfd3/utils/io.py,sha256=wbdjUTQkDc3RCSM7gdogA-XOKR68HeQ-cfvyN4pP90w,9849
|
|
175
|
+
rfd3/utils/vizualize.py,sha256=cumsMtSFQ8lfVtJDDQ9mpfMf9GSyM55RYmKl9H61_TM,11147
|
|
176
|
+
rc_foundry-0.1.1.dist-info/METADATA,sha256=Jblzbw1ZsPICTEN0JXCPRzbP83w76Ari5SQe4S3TlRE,10578
|
|
177
|
+
rc_foundry-0.1.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
178
|
+
rc_foundry-0.1.1.dist-info/entry_points.txt,sha256=BmiWCbWGtrd_lSOFMuCLBXyo84B7Nco-alj7hB0Yw9A,130
|
|
179
|
+
rc_foundry-0.1.1.dist-info/licenses/LICENSE.md,sha256=NKtPCJ7QMysFmzeDg56ZfUStvgzbq5sOvRQv7_ddZOs,1533
|
|
180
|
+
rc_foundry-0.1.1.dist-info/RECORD,,
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
BSD 3-Clause License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025, Institute for Protein Design, University of Washington
|
|
4
|
+
|
|
5
|
+
Redistribution and use in source and binary forms, with or without
|
|
6
|
+
modification, are permitted provided that the following conditions are met:
|
|
7
|
+
|
|
8
|
+
* Redistributions of source code must retain the above copyright notice, this
|
|
9
|
+
list of conditions and the following disclaimer.
|
|
10
|
+
|
|
11
|
+
* Redistributions in binary form must reproduce the above copyright notice,
|
|
12
|
+
this list of conditions and the following disclaimer in the documentation
|
|
13
|
+
and/or other materials provided with the distribution.
|
|
14
|
+
|
|
15
|
+
* Neither the name of the copyright holder nor the names of its
|
|
16
|
+
contributors may be used to endorse or promote products derived from
|
|
17
|
+
this software without specific prior written permission.
|
|
18
|
+
|
|
19
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
20
|
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
21
|
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
22
|
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
23
|
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
24
|
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
25
|
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
26
|
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
27
|
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
28
|
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
rf3/__init__.py
ADDED
rf3/_version.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
# file generated by setuptools-scm
|
|
2
|
+
# don't change, don't track in version control
|
|
3
|
+
|
|
4
|
+
__all__ = [
|
|
5
|
+
"__version__",
|
|
6
|
+
"__version_tuple__",
|
|
7
|
+
"version",
|
|
8
|
+
"version_tuple",
|
|
9
|
+
"__commit_id__",
|
|
10
|
+
"commit_id",
|
|
11
|
+
]
|
|
12
|
+
|
|
13
|
+
TYPE_CHECKING = False
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from typing import Tuple, Union
|
|
16
|
+
|
|
17
|
+
VERSION_TUPLE = Tuple[Union[int, str], ...]
|
|
18
|
+
COMMIT_ID = Union[str, None]
|
|
19
|
+
else:
|
|
20
|
+
VERSION_TUPLE = object
|
|
21
|
+
COMMIT_ID = object
|
|
22
|
+
|
|
23
|
+
version: str
|
|
24
|
+
__version__: str
|
|
25
|
+
__version_tuple__: VERSION_TUPLE
|
|
26
|
+
version_tuple: VERSION_TUPLE
|
|
27
|
+
commit_id: COMMIT_ID
|
|
28
|
+
__commit_id__: COMMIT_ID
|
|
29
|
+
|
|
30
|
+
__version__ = version = "0.1.dev917+gcbbe4c6a6.d20251001"
|
|
31
|
+
__version_tuple__ = version_tuple = (0, 1, "dev917", "gcbbe4c6a6.d20251001")
|
|
32
|
+
|
|
33
|
+
__commit_id__ = commit_id = None
|
rf3/alignment.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
logger = logging.getLogger(__name__)
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def weighted_rigid_align(
    X_L,  # [B, L, 3]
    X_gt_L,  # [B, L, 3]
    X_exists_L,  # [L]
    w_L,  # [B, L]
):
    """
    Weighted rigid body alignment of X_gt_L onto X_L with weights w_L.

    Allows for "moving target" ground truth that is se3 invariant.
    Following algorithm 28 in AF3 paper.

    The rotation and the two weighted centroids are estimated only from the
    positions selected by ``X_exists_L``; the resulting transform is then
    applied to every position of ``X_gt_L``.

    Args:
        X_L: Coordinates to align onto, [B, L, 3].
        X_gt_L: Coordinates that get moved, [B, L, 3] (same shape as ``X_L``).
        X_exists_L: Boolean mask, [L], selecting the positions used to fit the
            transform. Must have dtype ``torch.bool``.
        w_L: Per-position weights, [B, L].

    Returns:
        X_align_L: [B, L, 3], ``X_gt_L`` after rotation and translation,
        detached from the autograd graph.

    Raises:
        AssertionError: If the shapes are inconsistent or ``X_exists_L`` is not
            a boolean mask.
    """
    assert X_L.shape == X_gt_L.shape
    assert X_L.shape[:-1] == w_L.shape

    # Assert `X_exists_L` is a boolean mask
    assert (
        X_exists_L.dtype == torch.bool
    ), "X_exists_L should be a boolean mask! Otherwise, the alignment will be incorrect (silent failure)!"

    X_resolved = X_L[:, X_exists_L]
    X_gt_resolved = X_gt_L[:, X_exists_L]
    w_resolved = w_L[:, X_exists_L]

    # Weighted centroids of both point sets, each [B, 3]
    w_col = w_resolved.unsqueeze(-1)
    w_total = torch.sum(w_resolved, dim=-1, keepdim=True)
    u_X = torch.sum(X_resolved * w_col, dim=-2) / w_total
    u_X_gt = torch.sum(X_gt_resolved * w_col, dim=-2) / w_total

    X_resolved = X_resolved - u_X.unsqueeze(-2)
    X_gt_resolved = X_gt_resolved - u_X_gt.unsqueeze(-2)

    # Computation of the covariance matrix
    C = torch.einsum("bji,bjk->bik", w_col * X_gt_resolved, X_resolved)

    # torch.linalg.svd returns the right factor already transposed (Vh)
    U, _, Vh = torch.linalg.svd(C)

    # Reflection correction: flip the last axis wherever U @ Vh has det -1.
    # The identity is created with the dtype/device of the SVD factors because
    # matmul does not promote dtypes; a float32 identity breaks float64 inputs.
    F = torch.eye(3, dtype=U.dtype, device=U.device)[None].tile((U.shape[0], 1, 1))
    F[..., -1, -1] = torch.sign(torch.linalg.det(U @ Vh))
    R = U @ F @ Vh

    # Apply the fitted transform to all positions, not just the resolved ones
    X_gt_L = X_gt_L - u_X_gt.unsqueeze(-2)
    X_align_L = X_gt_L @ R + u_X.unsqueeze(-2)

    return X_align_L.detach()
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def get_rmsd(xyz1, xyz2, eps=1e-4):
    """
    Root-mean-square deviation between two coordinate tensors.

    The squared differences are summed over the last two dimensions, divided
    by the size of the second-to-last dimension of ``xyz1``, and ``eps`` is
    added under the square root.

    Args:
        xyz1: First coordinate tensor.
        xyz2: Second coordinate tensor.
        eps: Constant added before taking the square root.

    Returns:
        Tensor with the last two dimensions reduced away.
    """
    n_points = xyz1.shape[-2]
    delta = xyz2 - xyz1
    sq_dev_total = torch.sum(delta * delta, axis=(-1, -2))
    return torch.sqrt(sq_dev_total / n_points + eps)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def superimpose(xyz1, xyz2, mask, eps=1e-4):
    """
    Superimpose xyz1 onto xyz2 using mask.

    NOTE(review): as present in this file the body only validates its inputs
    and then returns ``None`` implicitly; ``eps`` is not used. Confirm whether
    the remainder of the implementation is missing.

    Args:
        xyz1: Coordinate tensor; its second-to-last dimension gives the
            expected mask length.
        xyz2: Coordinate tensor with the same shape as ``xyz1``.
        mask: Boolean tensor of shape ``(xyz1.shape[-2],)``.
        eps: Unused in the visible body.

    Raises:
        AssertionError: If the mask shape, the coordinate shapes, or the mask
            dtype do not satisfy the checks.
    """
    n_points = xyz1.shape[-2]
    assert mask.shape == (n_points,)
    assert xyz1.shape == xyz2.shape
    assert mask.dtype == torch.bool
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
from os import PathLike
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
from atomworks.ml.example_id import parse_example_id
|
|
5
|
+
from beartype.typing import Any
|
|
6
|
+
from rf3.utils.io import (
|
|
7
|
+
build_stack_from_atom_array_and_batched_coords,
|
|
8
|
+
dump_structures,
|
|
9
|
+
dump_trajectories,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
from foundry.callbacks.callback import BaseCallback
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class DumpValidationStructuresCallback(BaseCallback):
    """Dump predicted structures and/or diffusion trajectories during validation"""

    def __init__(
        self,
        save_dir: PathLike,
        dump_predictions: bool = False,
        one_model_per_file: bool = False,
        dump_trajectories: bool = False,
        compress_outputs: bool = False,
    ):
        """
        Args:
            save_dir: Root directory under which all dumped files are placed.
            dump_predictions: Whether to dump structures (CIF files) after validation batches.
            one_model_per_file: If True, write each structure within a diffusion batch to its own CIF files. If False,
                include each structure within a diffusion batch as a separate model within one CIF file.
            dump_trajectories: Whether to dump denoising trajectories after validation batches.
            compress_outputs: Whether to gzip output files. Defaults to ``False``.
        """
        super().__init__()
        self.save_dir = Path(save_dir)
        self.dump_predictions = dump_predictions
        self.dump_trajectories = dump_trajectories
        self.one_model_per_file = one_model_per_file
        self.compress_outputs = compress_outputs

    def on_validation_batch_end(
        self,
        trainer,
        outputs: dict,
        batch: Any,
        dataset_name: str | None,
        **_,
    ):
        """
        Write predictions and/or trajectories for the first example of a validation batch.

        Output paths have the form
        ``<save_dir>/<kind>/epoch_<current_epoch>[/<dataset_name>]/<identifier><suffix>``,
        where ``<kind>`` is ``predictions`` or ``trajectories``.

        Args:
            trainer: Object whose ``state['current_epoch']`` names the epoch directory.
            outputs: Validation outputs; must contain the key ``network_output``.
            batch: Validation batch; only ``batch[0]`` is used.
            dataset_name: Name of the dataset sub-directory. When ``None`` the
                sub-directory is omitted.
            **_: Additional hook arguments, ignored.

        Raises:
            AssertionError: If dumping is enabled and ``outputs`` has no ``network_output``.
        """
        if (not self.dump_predictions) and (not self.dump_trajectories):
            return  # Nothing to do

        assert (
            "network_output" in outputs
        ), "Validation outputs must contain `network_output` to dump structures!"

        network_output = outputs["network_output"]
        example = batch[0]  # Assume batch size = 1

        try:
            # ... try to extract the PDB ID and assembly ID from the example ID
            parsed_id = parse_example_id(example["example_id"])
            identifier = f"{parsed_id['pdb_id']}_{parsed_id['assembly_id']}"
        except (KeyError, ValueError):
            # ... if parsing fails, fall back to the original example ID
            identifier = example["example_id"]

        def _build_path_from_example_id(subdir: str, extra: str = "") -> Path:
            """Helper function to build a path from a training or validation example_id."""
            path = self.save_dir / subdir / f"epoch_{trainer.state['current_epoch']}"

            # `dataset_name` may be None per the signature; `Path / None` raises
            # TypeError, so only add the component when a name is given.
            if dataset_name is not None:
                path = path / dataset_name

            return path / f"{identifier}{extra}"

        # Determine file type based on compression setting
        file_type = "cif.gz" if self.compress_outputs else "cif"

        if self.dump_predictions:
            atom_array_stack = build_stack_from_atom_array_and_batched_coords(
                network_output["X_L"], example["atom_array"]
            )
            dump_structures(
                atom_arrays=atom_array_stack,
                base_path=_build_path_from_example_id("predictions"),
                one_model_per_file=self.one_model_per_file,
                file_type=file_type,
            )

        if self.dump_trajectories:
            # Here `dump_trajectories` is the imported module-level function;
            # the boolean flag lives on `self.dump_trajectories`.
            dump_trajectories(
                trajectory_list=network_output["X_denoised_L_traj"],
                atom_array=example["atom_array"],
                base_path=_build_path_from_example_id("trajectories", "_denoised"),
                file_type=file_type,
            )
            dump_trajectories(
                trajectory_list=network_output["X_noisy_L_traj"],
                atom_array=example["atom_array"],
                base_path=_build_path_from_example_id("trajectories", "_noisy"),
                file_type=file_type,
            )
|