rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,180 @@
1
+ foundry/__init__.py,sha256=H8S1nl5v6YeW8ggn1jKy4GdtH7c-FGS-j7CqUCAEnAU,1926
2
+ foundry/common.py,sha256=Aur8mH-CNmcUqSsw7VgaCQSW5sH1Bqf8Da91jzxPV1Y,3035
3
+ foundry/constants.py,sha256=0n1wBKCvNuw3QaQehSbmsHYkIdaGn3tLeRFItBrdeHY,913
4
+ foundry/version.py,sha256=m8HxkqoKGw_wAJtc4ZokpJKNLXqp4zwnNhbnfDtro7w,704
5
+ foundry/callbacks/__init__.py,sha256=VsRT1e4sqlJHPcTCsfupMEx82Iz-LoOAGPpwvf_OJeE,126
6
+ foundry/callbacks/callback.py,sha256=xZBo_suP4bLrP6gl5uJPbaXm00DXigePa6dMeDxucgg,3890
7
+ foundry/callbacks/health_logging.py,sha256=tEtkByOlaAA7nnelxb7PbM9_dcIgOsdbxCdQY3K5pMc,16664
8
+ foundry/callbacks/metrics_logging.py,sha256=Vekzs831d-HE7TfLJZnQ45iPeG9ziQWLQaMBGaymfQM,8696
9
+ foundry/callbacks/timing_logging.py,sha256=u-r0hKp7fWOY3mLk7CcuIwHgZbhte13m5M09xNgatZA,2343
10
+ foundry/callbacks/train_logging.py,sha256=Xs3tmZA88qLxmdSOwt-x8YKN4NKb1kVm59uptNXl4Qo,10399
11
+ foundry/hydra/resolvers.py,sha256=xyJzo6OeWAc_LOu8RiHhX7_CRNoLZ22626AvYHXYl4U,2186
12
+ foundry/inference_engines/base.py,sha256=qv5Gnk6NIxMpIxZ3oeOJurqMMUzBCZgfzHckb7SSzmU,8227
13
+ foundry/inference_engines/checkpoint_registry.py,sha256=xRN3PjtmcnN7aEEhDR0MKhp1yaMDXHRXdLAGTDxi_Yk,2563
14
+ foundry/metrics/__init__.py,sha256=qL4wwaiQ7EtR30pmZ9MCknqx909BJcNvHVmNJUaz_WM,236
15
+ foundry/metrics/losses.py,sha256=2CLUmf7oCdFUCvgJukdNkff0FVG3BlATI-NI60TtpVY,903
16
+ foundry/metrics/metric.py,sha256=23pKh_Ra0EcHGo5cSzYQQrUGr5zWRxeufKSJ58tfXXo,12687
17
+ foundry/model/layers/blocks.py,sha256=ihbbP_1fOlrkrcrQSk9thCrNWjK8mtxD3WxcBng9Htk,1403
18
+ foundry/testing/__init__.py,sha256=BnrU7fZ4l0Dm1vrGcNPQYTAw83PW4DGYz7TGhGqgrfQ,223
19
+ foundry/testing/fixtures.py,sha256=j27a8CAonygjlWsUjZ-95M5MF4Rjp9nw7JskqiZlweI,486
20
+ foundry/testing/pytest_hooks.py,sha256=5Ebw1GXYO2XqS9Jvpzty7g3gCXIdXu16jqg53XcuUx4,450
21
+ foundry/trainers/fabric.py,sha256=cjaTHbGuJEQwaGBvIAXD_il4bHtY-crsTY14Xn77uXA,40401
22
+ foundry/training/EMA.py,sha256=3OWA9Pz7XuDr-SRxbz24tZf55DmhSa2fKy9r5v2IXqA,2651
23
+ foundry/training/checkpoint.py,sha256=mUiObg-qcF3tvMfVu77sD9m3yVRp71czv07ccliU7qQ,1791
24
+ foundry/training/schedulers.py,sha256=StmXegPfIdLAv31FreCTrDh9dsOvNUfzG4YGa61Y4oE,3647
25
+ foundry/utils/alignment.py,sha256=OAN7H2TqraGxP1uMXUpwLO7g0qS0cxUVjuV33pY16z0,2316
26
+ foundry/utils/components.py,sha256=Piw2TfQF26uuxC3hXG3iv_4rgud1lVO-cv6N-p05EDY,15200
27
+ foundry/utils/datasets.py,sha256=pLBxVezm-TSrYuC5gFnJZdGnNWV7aPH2QiWIVE2hkdQ,16629
28
+ foundry/utils/ddp.py,sha256=ydHrO6peGbRnWAwgH5rmpHuQd55g2gFzzoZJYypn7GU,3970
29
+ foundry/utils/instantiators.py,sha256=oGCp6hrmY-QPPPEjxKxe5uVFL125fH1RaLxjMKWCD_8,2169
30
+ foundry/utils/logging.py,sha256=jrDgiB_56q_hWDc0jkBFekvqnNWcowJBt4B-S-ipJmM,9312
31
+ foundry/utils/rigid.py,sha256=_Z1pmitb6xgxyguLj_TukKscUBJjQsU4bsBD24GVS84,44444
32
+ foundry/utils/rotation_augmentation.py,sha256=7q1WEX2iJ0i7-2aV-M97nEaEdpqexDTaZn5JquYpkUk,1927
33
+ foundry/utils/squashfs.py,sha256=QlcwuJyVe-QVfIOS7o1QfLhaCQPNzzox7ln4n8dcYEg,5234
34
+ foundry/utils/torch.py,sha256=OLsqoxw4CTXbGzWUHernLUT7uQjLu0tVPtD8h8747DI,11211
35
+ foundry/utils/weights.py,sha256=btz4S02xff2vgiq4xMfiXuhK1ERafqQPtmimo1DmoWY,10381
36
+ foundry_cli/__init__.py,sha256=0BxY2RUKJLaMXUGgypPCwlTskTEFdVnkhTR4C4ft2Kw,52
37
+ foundry_cli/download_checkpoints.py,sha256=2PDKw-dWCht_mD6fRYTiOtVlk4P2CQqTPSkN-7s19mk,8474
38
+ mpnn/__init__.py,sha256=hgQcXFaCbAxFrhydVAy0xj8yC7UJF-GCCFhqD0sZ7I4,57
39
+ mpnn/inference.py,sha256=wPtGR325eVRVeesXoWtBK6b_-VcU8BZae5IfQN3-mvA,1669
40
+ mpnn/train.py,sha256=9eQGBd3rdNF5Zr2w8oUgETbqxBavNBajtA6Vbc5zESE,10239
41
+ mpnn/collate/feature_collator.py,sha256=LpzAFWo1VMa06dJLmfUWZsKe4xvLZjHbx4RICg2lgbQ,10510
42
+ mpnn/inference_engines/mpnn.py,sha256=985Ce84dWKNZk_5_dk1eJScYIgBGUKUhy-xDVCu15PA,21631
43
+ mpnn/loss/nll_loss.py,sha256=KmdNe-BCzGYtijjappzBArQcT1gHVlJnKdY1PYQ4mhU,5947
44
+ mpnn/metrics/nll.py,sha256=T6oMeUOEeHZzOMTH8NHFtsH9vUwLAsHQDPszzj4YKXI,15299
45
+ mpnn/metrics/sequence_recovery.py,sha256=YDw_LmH-a3ajBYWK0mucJEQvw0_VEyxvrBN7da4vX8Q,19034
46
+ mpnn/model/mpnn.py,sha256=vhkair2tYoId_akRP2qEq5O0IMZv6wsv9Q-V9plKgV8,131144
47
+ mpnn/model/layers/graph_embeddings.py,sha256=aEtd7iorMh8DxNH0eZVrK_zOo8HDLM5FRJyIJ8Cfz6k,99795
48
+ mpnn/model/layers/message_passing.py,sha256=TUkG9pXuo4Rtz5Bcij-OB7T4gSKmLt1KgxNmjJYPcMY,13051
49
+ mpnn/model/layers/position_wise_feed_forward.py,sha256=FATM8oveWy2XW-PDaaF9XLPIiWbehOHxG715E60n_8g,1602
50
+ mpnn/model/layers/positional_encoding.py,sha256=f-YpH1xvPGFC75U2-sOHrK13XtA9ZAWxjxxH1GrDt1M,4876
51
+ mpnn/pipelines/mpnn.py,sha256=SukwxEcAzaCgUZKcA1_KusodvcCg3_buN1dAZU-Udas,6185
52
+ mpnn/samplers/samplers.py,sha256=LDpetPMVklMboj1tucgnNvSHRUaQuehBmR2jFl4VWIE,6129
53
+ mpnn/trainers/mpnn.py,sha256=waXLQ-7pFD8MJRlnK37mHWcvqD6uOjTXVP6910tB6cw,6586
54
+ mpnn/transforms/polymer_ligand_interface.py,sha256=lipKDt_NFrpM-GiOXtvnTAvMpISOO4eHwilCgxnISJU,6106
55
+ mpnn/transforms/feature_aggregation/mpnn.py,sha256=jkhyMCqJipKQ2PvjqPkvvClhoiXx_I8e03lnDeH9__M,6324
56
+ mpnn/transforms/feature_aggregation/polymer_ligand_interface.py,sha256=gDdt9RZd0PO0YJdouNr0qsHFZV1i-5ewU6XuJrwPY54,2870
57
+ mpnn/transforms/feature_aggregation/token_encodings.py,sha256=qVlUky4HcDSU5drrZpZBnUvTSGdT6C7MN8f_owa81Bw,2227
58
+ mpnn/transforms/feature_aggregation/user_settings.py,sha256=uKyIDXz-QG0-KWQO1kqPlMj6i7RoVM6yH4iGNXFStoU,15007
59
+ mpnn/utils/inference.py,sha256=QLeukqLpedMNmvjbYgvLwDS5k7Q__NWILDSEbETkoCI,96539
60
+ mpnn/utils/probability.py,sha256=EYisliXNGXjuSPbzZwcIKjlhyINikGsqQndGBEbQoPI,990
61
+ mpnn/utils/weights.py,sha256=VsaIcOWTv8G-WJ9denxLRm3FQ9l6L66AVQN08E9BMSg,16411
62
+ rf3/__init__.py,sha256=XBb5hF2RqBPHODGRmjiRbfTXgOGfOzdY91GbS4Vex00,70
63
+ rf3/_version.py,sha256=fCfpbI5aeA6yHqjo3tK78-l2dPGxhp-AyKSoCXp34Nc,739
64
+ rf3/alignment.py,sha256=BvvwMqQGCVxV20xIsTighD1kXMadXXL2SkckLjTerx0,2102
65
+ rf3/chemical.py,sha256=VECnRPgVm-icXbZeUG4svcENzdUiIupP6dhka_8zCrg,26572
66
+ rf3/cli.py,sha256=dPKJFRHYoV2XS6xc_ZmdLTz6frqa6HZg4qgZU5oJcXU,2356
67
+ rf3/inference.py,sha256=_AAJ07AfSeU3xTM2_KH9n_H12EK4qZ23IJuyauOrMaQ,2466
68
+ rf3/kinematics.py,sha256=V3yjalPupu1X2FEp7l3XZR-qzLKrhWLZyECk6RgIkcs,10901
69
+ rf3/scoring.py,sha256=dTllswE-6Fgli2eLiNzLFc2Rhz4ouDT4WL-sVbvLTGU,41541
70
+ rf3/train.py,sha256=V4nqCC_1JKLI3WQ-nErNa8sqFpvb1mFhXSe6ZPpEheM,7945
71
+ rf3/util_module.py,sha256=ltc7QXJDb5Z184wxYQuT_-Z68YWXuPmOEBMLSzS6Pes,1428
72
+ rf3/validate.py,sha256=wXZLTWiOdTsCKiKK2_Dfnj5PInDiv5KxojOZwaUJjuo,5832
73
+ rf3/callbacks/dump_validation_structures.py,sha256=j_pDfPETyI10ZtsUlvf16-zpJdaUcC5w8TEYCk--Xbo,3909
74
+ rf3/callbacks/metrics_logging.py,sha256=MYcM_ZYKsBTJKx2xi9H0QPr5Lh1o80_bTZXH8kfV5y8,13429
75
+ rf3/data/cyclic_transform.py,sha256=Cs4x_qooCUXKNiFeVdfrXAGhZ0z_yyedscWRsGmEpwM,3351
76
+ rf3/data/extra_xforms.py,sha256=Pxv4Nt4Ir_Ca92XWx_c6mdaCl7fS3_gFCWG775WWVSQ,1197
77
+ rf3/data/ground_truth_template.py,sha256=dct1bGQ7AMjiNyLIotKJZzbbUI546mc-UDdRpKMU7vU,19460
78
+ rf3/data/paired_msa.py,sha256=aso9awdKthzsm1ITHHO8r-1O5vFDrDL-ot7FhXB7Css,7765
79
+ rf3/data/pipeline_utils.py,sha256=qLNcy2iTmZxwCJJk9etXCn_vRljUYhGqSn0uNsRDGO0,5008
80
+ rf3/data/pipelines.py,sha256=3yy8f4pUSiV9mnhDmboO2RAsIlIJ0WYCzTzXT3OGQ5M,20913
81
+ rf3/diffusion_samplers/inference_sampler.py,sha256=dNhVGQujfcMhIlFdlLjZHiWXBnaLjOjRBd58eM6SJv4,9014
82
+ rf3/inference_engines/__init__.py,sha256=tFyTBEMsQ2oR-x2TE9-7B3HIlkHFmoOP2d6hfomj_hc,121
83
+ rf3/inference_engines/rf3.py,sha256=rPJaY355JUvqq_UOAM7gSHzYBE6keN0zPG57bjUb0qU,29761
84
+ rf3/loss/af3_confidence_loss.py,sha256=X8TLudvIFxD2GHlNLsBHoysO0qeWGdaKdW8It6G7nhE,19697
85
+ rf3/loss/af3_losses.py,sha256=oD-sFIJgNR11OAFK_K-6OeKdeIauEabWNMJX8M9mE6k,24492
86
+ rf3/loss/loss.py,sha256=3aHB8FiA0WkeTacfMUnG9mnoILyjX8AE9345eToubag,6184
87
+ rf3/metrics/chiral.py,sha256=xeZjBv9XFa8Tmo7aFjevwLuRk6GGpRU4CIkT41SvnFI,7139
88
+ rf3/metrics/clashing_chains.py,sha256=El3CILpTMSD-U-5pgUugXigNao9AZ17KxOW1-NX4N38,2518
89
+ rf3/metrics/distogram.py,sha256=lXLPfNtWnt4d3_Vc9A1AsAv3c0eZDRdK5tixgO9rIj8,16978
90
+ rf3/metrics/lddt.py,sha256=VjObCWvjAhElgGFckx_sRx8toUq4TVfr5ip8ThW06Qw,21412
91
+ rf3/metrics/metadata.py,sha256=hFEJ4thQiV8vs3fzj5dxK4BS3nb4PSnhCoBrfQXwD2I,1437
92
+ rf3/metrics/metric_utils.py,sha256=rdTY3Uc4at-Y7jLaDfaEhp17Z2KYqLGETuuCFSN59bY,7125
93
+ rf3/metrics/predicted_error.py,sha256=tsUFyW6Jv8m4REKui0mVQkfACB8PTVOzMstFW1d5pAA,4973
94
+ rf3/metrics/rasa.py,sha256=mQ6ZQdroC4CY3XDiUVWKvtHFCGTxyhDASqYF8SjnQGU,4525
95
+ rf3/metrics/selected_distances.py,sha256=uyTlbPHnf6PpZ6JMRkfDAAY-GkxerjpDkurTOSt3EV0,3620
96
+ rf3/model/RF3.py,sha256=fAJu8FG54tdo7wKcwkDLhorgBu4aBjqNAg4KGJIjgmc,22383
97
+ rf3/model/RF3_blocks.py,sha256=FZliymoYsYpDz_YqPsuXILDZaiE6IGcX1wJoeohueko,3218
98
+ rf3/model/RF3_structure.py,sha256=EnZuYk8dJWgLubp9Mui1f_e6hE0z6XuLn1HX7N_LfCo,9652
99
+ rf3/model/layers/af3_auxiliary_heads.py,sha256=kJqsT0_plh9_TXGN4HpPB1NzGIarCZyAIG5NiB5RN3Q,9937
100
+ rf3/model/layers/af3_diffusion_transformer.py,sha256=uU3OsubqsrGWjTHzEsOb6hzlc1gB5-o7GcK8ElRiLs8,19143
101
+ rf3/model/layers/attention.py,sha256=ofz4LR74oFBZ64RBAcbFlksV2E2NkYM69LoMDc4qedo,11189
102
+ rf3/model/layers/layer_utils.py,sha256=dzrYwvgdKS2ouDnRiseU2x7VBcRcXa1zeHA_EKLyO78,3382
103
+ rf3/model/layers/mlff.py,sha256=SFHskP18xB7zuZAzEeubw7kKwMEGZYlypGS_zZ-02yk,5017
104
+ rf3/model/layers/outer_product.py,sha256=OYam3gsxJa7JBet71_eFZ7D7mXQghluLXMRVHv2xUSs,2037
105
+ rf3/model/layers/pairformer_layers.py,sha256=nV5zwN6CLIqfCK7UTaNX_n67ASgfWp58KJMlj3TBbpg,28724
106
+ rf3/model/layers/structure_bias.py,sha256=xhu_KNRlE_FNfB0dwfHY47xhdFMmIPgV2qeYb_WyepI,1757
107
+ rf3/symmetry/resolve.py,sha256=odQF32aM4BqaOrmfEQJtwQxKgdJL5J0_Htb1I5KEDPA,10843
108
+ rf3/trainers/rf3.py,sha256=P9zLTMu7YaxllBiLgyrn6FnP1vKCOaerA7ciDgmqBrA,24141
109
+ rf3/utils/frames.py,sha256=6LVuV2XbODKNRU_ggGkd0EBXBT7F0q-HXFad4eTUOVs,3745
110
+ rf3/utils/inference.py,sha256=MjNkhoMzwPUb5rCouvEjtW-6XRa3yb471DQtpxzrhdk,24772
111
+ rf3/utils/io.py,sha256=xwOjzWviYKphuMbJgj18dylI72n0oF7mdUp8V5qlsaQ,8051
112
+ rf3/utils/loss.py,sha256=llaiL-5VaNTDMwh0TK_nIzniYPg0zDhIzzM8i8fYCqY,2757
113
+ rf3/utils/predict_and_score.py,sha256=VzRZohertYLMfnT9SwRs1gMGEhspV0LPBKmUjpU5WAY,6209
114
+ rf3/utils/predicted_error.py,sha256=5gIRjsD7bWvYYM30_wq9827porc1oj_Qq-cjyetYJ_s,25438
115
+ rf3/utils/recycling.py,sha256=nRvv0vWMsMG0Ods83XKkxdgmqKMXTw-w02n_BuZOYoo,1491
116
+ rfd3/.gitignore,sha256=935nLWJz_oi5h-UjxP4L_ulsMpkbRIVsl0dgGCwTCbc,109
117
+ rfd3/Makefile,sha256=_O87r1eIN7AmWWIqur3z0tLn1kgAPGEAGX2fcddarMs,2224
118
+ rfd3/__init__.py,sha256=2Wto2IsUIj2lGag9m_gqgdCwBNl5p21-Xnr7W_RpU3c,348
119
+ rfd3/callbacks.py,sha256=Zjt8RiaYWquoKOwRmC_wCUbRbov-V4zd2_73zjhgDHE,2783
120
+ rfd3/cli.py,sha256=TZpZouXGmwAMFaH8hp4r3q9tbUi1xlcN8n_r8hO2q8c,1424
121
+ rfd3/constants.py,sha256=wLvDzrThpOrK8T3wGFNQeGrhAXOJQze8l3v_7pjIdMM,13141
122
+ rfd3/engine.py,sha256=HCIHgB_4xFBYz7d7DGnqibnFcdC7RGXQ6qHgUJEvm2M,20889
123
+ rfd3/run_inference.py,sha256=dubMwEFkNPOq_yYf0ny37qvvEkRjNNPRFksZgmEFkUc,1520
124
+ rfd3/train.py,sha256=rHswffIUhOae3_iYyvAiQ3jALoFuzrcRUgMlbJLinlI,7947
125
+ rfd3/inference/datasets.py,sha256=u-2U7deHXu-iOs7doiKKynewP-NEyJfdORSTDzUSaQI,6538
126
+ rfd3/inference/input_parsing.py,sha256=mk3HBvo7MPTFEET7NagCo5TSjb47w-hxUDoeQxUW_h4,45449
127
+ rfd3/inference/legacy_input_parsing.py,sha256=1wf_KF7qWnGLaVM8IXDl8fIsWCmxtOi2YlAiHEVELqw,28046
128
+ rfd3/inference/parsing.py,sha256=Nq8CYmimnql4RM-5ZfPAvOFvCae4_CC2pYDzE6iCpWU,5290
129
+ rfd3/inference/symmetry/atom_array.py,sha256=HH50Z07bTUnNUgCwAGslADbvMYHgsXn9s-fqwx6BvKw,11034
130
+ rfd3/inference/symmetry/checks.py,sha256=wb7K327GnMwGG9bgOvvDAbaPsFj4nGZpEAolICUapNc,8908
131
+ rfd3/inference/symmetry/contigs.py,sha256=6OvbZ2dJg-a0mvvKAC0VkzUH5HpUDxOJvkByIst_roU,2127
132
+ rfd3/inference/symmetry/frames.py,sha256=G55p-aOXqEYG4kCyKxrgWAsS-gW9-gOTlBME6nhbKyU,10716
133
+ rfd3/inference/symmetry/symmetry_utils.py,sha256=KwgxrdfO766RCEwF3VElAE85oEKiopPGRQDhJbKZaUA,15810
134
+ rfd3/metrics/design_metrics.py,sha256=O1RqZdjQPNlAWYRg6UJTERYg_gUI1_hVleKsm9xbWBY,16836
135
+ rfd3/metrics/hbonds_hbplus_metrics.py,sha256=Sewy9KzmrA1OnfkasN-fmWrQ9IRx9G7Yyhe2ua0mk28,11518
136
+ rfd3/metrics/hbonds_metrics.py,sha256=SIR4BnDhYdpVSqwXXRYpQ_tB-M0_fVyugGl08WivCmE,15257
137
+ rfd3/metrics/losses.py,sha256=GDz0uO2XyYCX1kvKJ1DR5s7wWlELIqqI2PhoCnue8IM,12705
138
+ rfd3/metrics/metrics_utils.py,sha256=o8zmjLq4i4LfoGiJ51rZU7KnH9LX4xEVLqbH0xBIoeI,4501
139
+ rfd3/metrics/sidechain_metrics.py,sha256=EGZuFuWQ0cCe83EVPAf4eysN8vP9ifNjfnmE0o5aIeA,12223
140
+ rfd3/model/RFD3.py,sha256=95aKzye-XzuDyLGgost-Wsfu8eT635zHIRky-pNoHSA,3569
141
+ rfd3/model/RFD3_diffusion_module.py,sha256=BPjKGyQpbnqdzii3gXMKLhhijNqV8Xh4bSosmfDBt8w,12094
142
+ rfd3/model/cfg_utils.py,sha256=XPBLyoB_bQRLmdrJ1Z0hCjcVvgUMGIPuw4rxTlHjB_s,2575
143
+ rfd3/model/inference_sampler.py,sha256=qge7BNJttW0NXgerg3msPY3izxQ-6FsvWSTAMhZ4GJs,24696
144
+ rfd3/model/layers/attention.py,sha256=XuNA7WyFlRfLnAgky1PtGvXFCnDGv7GeEcXz8hodTBo,19472
145
+ rfd3/model/layers/block_utils.py,sha256=EZq2qYUeO6_VCLKDVC60cxfBE_EPwvp84FPmqLr28ZQ,21197
146
+ rfd3/model/layers/blocks.py,sha256=MOjJ53THxM2MMM27Ap7xiIXRCdI_SHzqKzLLQVX6FEc,24888
147
+ rfd3/model/layers/chunked_pairwise.py,sha256=de5Qc3P7GEfZlX-QLaKfJxr6Ky5vgLcWWogatCw2UnY,14582
148
+ rfd3/model/layers/encoders.py,sha256=CqByjHNSbtMIaNP_h2iEJZdTbm-N8SGo1bZgvRNpMJ8,15207
149
+ rfd3/model/layers/layer_utils.py,sha256=UPYo-DYa__93KONSEj2YZWLtBqvYNSA9_wHDDPhVrIc,5710
150
+ rfd3/model/layers/pairformer_layers.py,sha256=uimskhN-Ec0apEXAU6JqomyKX5-6ormrEsCFJotkBtM,3991
151
+ rfd3/testing/debug.py,sha256=JqpSKbSp1l9V_3trLNzpdt3gazqrSOSq7NrmcuGjpJQ,4059
152
+ rfd3/testing/debug_utils.py,sha256=i_GjrsRjeaREv6hlX2sEmeztpo9w9rg7Ne3VT5-YILA,2170
153
+ rfd3/testing/testing_utils.py,sha256=CtpTDxePbCluzuvd6jBfJNI2a3_8Ry2Whbgcf-5upiM,12202
154
+ rfd3/trainer/dump_validation_structures.py,sha256=qY8s2hPBflJTXPiIUnqFFE9g36y_7s39MEcMRrxZUmA,6027
155
+ rfd3/trainer/fabric_trainer.py,sha256=8dcyDSJFviyFU9fp6Ez02CmucKi9-DOEEwHIRcB6kQU,40074
156
+ rfd3/trainer/recycling.py,sha256=nRvv0vWMsMG0Ods83XKkxdgmqKMXTw-w02n_BuZOYoo,1491
157
+ rfd3/trainer/rfd3.py,sha256=9B_FgvTNvTDpZhRVXD1ufIRNrXOnERkFJosxe7Zy8-E,21181
158
+ rfd3/trainer/trainer_utils.py,sha256=1m331JI86uQvBrapLHjjEliGjU3qxafp-v47bTjsx-I,20528
159
+ rfd3/transforms/conditioning_base.py,sha256=A0Z2-v7ttvNa6xArpBdV8srH58gSaMI1J48ULXvQJTg,19517
160
+ rfd3/transforms/conditioning_utils.py,sha256=9Pn9AFbih2FCzp5OOM9y7z6KH7HPxVibxTrfuXiitMs,7498
161
+ rfd3/transforms/design_transforms.py,sha256=ePvnLsuKUOsE4LLcmF0bbkx1vf2AiD-35rzF4zUEcEE,30944
162
+ rfd3/transforms/dna_crop.py,sha256=JeOsG0tXghJvgzEimfzBvlFN_lVd9TrvjnC929Abz5A,18214
163
+ rfd3/transforms/hbonds.py,sha256=ijfJapFlhsh3JktpDoT3VFqKTTg6ynrqMlD7dU2xFsA,16415
164
+ rfd3/transforms/hbonds_hbplus.py,sha256=xyDP-CyVl2OsUY90HsrPoKw1VycBXUrq00WfrX8HJVM,8364
165
+ rfd3/transforms/ncaa_transforms.py,sha256=Lz4L8OGuOOG53sKJHcLSdV7WPQ3YzOzwd5tJG4CHqP0,4983
166
+ rfd3/transforms/pipelines.py,sha256=FGH-XH3taTWQ6k1zpDO_d-097EQdXmL6uqXZXw4HIMs,22086
167
+ rfd3/transforms/ppi_transforms.py,sha256=7rXyf-tn2TLz6ybYR_YVDtSDG7hOgqhYY4shNviA_Sw,23493
168
+ rfd3/transforms/rasa.py,sha256=a4IPFvVMMxldoGLyJQiSlGg7IyUkcBASbRZLWmguAKk,4156
169
+ rfd3/transforms/symmetry.py,sha256=GSnMF7oAnUxPozfafsRuHEv0yKXW0BpLTI6wsKGZrbc,2658
170
+ rfd3/transforms/training_conditions.py,sha256=UXiUPjDwrNKM95tRe0eXrMeRN8XlTPc_MXUvo6UpePo,19510
171
+ rfd3/transforms/util_transforms.py,sha256=2AcLkzx-73ZFgcWD1cIHv7NyniRPI4_zThHK8azyQaY,18119
172
+ rfd3/transforms/virtual_atoms.py,sha256=UpmxzPPd5FaJigcRoxgLSHHrLLOqsCvZ5PPZfQSGqII,12547
173
+ rfd3/utils/inference.py,sha256=RQp5CCy6Z6uHVZ2Mx0zmmGluYEOrASke4bABtfRjpy0,26448
174
+ rfd3/utils/io.py,sha256=wbdjUTQkDc3RCSM7gdogA-XOKR68HeQ-cfvyN4pP90w,9849
175
+ rfd3/utils/vizualize.py,sha256=cumsMtSFQ8lfVtJDDQ9mpfMf9GSyM55RYmKl9H61_TM,11147
176
+ rc_foundry-0.1.1.dist-info/METADATA,sha256=Jblzbw1ZsPICTEN0JXCPRzbP83w76Ari5SQe4S3TlRE,10578
177
+ rc_foundry-0.1.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
178
+ rc_foundry-0.1.1.dist-info/entry_points.txt,sha256=BmiWCbWGtrd_lSOFMuCLBXyo84B7Nco-alj7hB0Yw9A,130
179
+ rc_foundry-0.1.1.dist-info/licenses/LICENSE.md,sha256=NKtPCJ7QMysFmzeDg56ZfUStvgzbq5sOvRQv7_ddZOs,1533
180
+ rc_foundry-0.1.1.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.28.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -0,0 +1,5 @@
1
+ [console_scripts]
2
+ foundry = foundry_cli.download_checkpoints:app
3
+ mpnn = mpnn.inference:main
4
+ rf3 = rf3.cli:app
5
+ rfd3 = rfd3.cli:app
@@ -0,0 +1,28 @@
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2025, Institute for Protein Design, University of Washington
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ * Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ * Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ * Neither the name of the copyright holder nor the names of its
16
+ contributors may be used to endorse or promote products derived from
17
+ this software without specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
rf3/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ """RF3 - RosettaFold3 model implementation."""
2
+
3
+ __version__ = "0.1.0"
rf3/_version.py ADDED
@@ -0,0 +1,33 @@
1
+ # file generated by setuptools-scm
2
+ # don't change, don't track in version control
3
+
4
+ __all__ = [
5
+ "__version__",
6
+ "__version_tuple__",
7
+ "version",
8
+ "version_tuple",
9
+ "__commit_id__",
10
+ "commit_id",
11
+ ]
12
+
13
+ TYPE_CHECKING = False
14
+ if TYPE_CHECKING:
15
+ from typing import Tuple, Union
16
+
17
+ VERSION_TUPLE = Tuple[Union[int, str], ...]
18
+ COMMIT_ID = Union[str, None]
19
+ else:
20
+ VERSION_TUPLE = object
21
+ COMMIT_ID = object
22
+
23
+ version: str
24
+ __version__: str
25
+ __version_tuple__: VERSION_TUPLE
26
+ version_tuple: VERSION_TUPLE
27
+ commit_id: COMMIT_ID
28
+ __commit_id__: COMMIT_ID
29
+
30
+ __version__ = version = "0.1.dev917+gcbbe4c6a6.d20251001"
31
+ __version_tuple__ = version_tuple = (0, 1, "dev917", "gcbbe4c6a6.d20251001")
32
+
33
+ __commit_id__ = commit_id = None
rf3/alignment.py ADDED
@@ -0,0 +1,79 @@
1
+ import logging
2
+
3
+ import torch
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+
8
def weighted_rigid_align(
    X_L,  # [B, L, 3]
    X_gt_L,  # [B, L, 3]
    X_exists_L,  # [L]
    w_L,  # [B, L]
):
    """
    Weighted rigid body alignment of X_gt_L onto X_L with weights w_L.

    Allows for a "moving target" ground truth that is SE(3) invariant.
    Follows Algorithm 28 of the AF3 paper. Only atoms where ``X_exists_L`` is
    True contribute to the centroids and the optimal rotation; the resulting
    rigid transform is then applied to *all* atoms of ``X_gt_L``.

    Args:
        X_L: Target coordinates to align onto, shape [B, L, 3].
        X_gt_L: Coordinates to be moved, shape [B, L, 3].
        X_exists_L: Boolean mask over atoms, shape [L].
        w_L: Per-atom weights, shape [B, L].

    Returns:
        X_align_L: [B, L, 3] rigidly transformed copy of X_gt_L, detached from
        the autograd graph. Has the same dtype as the inputs.
    """
    assert X_L.shape == X_gt_L.shape
    assert X_L.shape[:-1] == w_L.shape

    # Assert `X_exists_L` is a boolean mask
    assert (
        X_exists_L.dtype == torch.bool
    ), "X_exists_L should be a boolean mask! Otherwise, the alignment will be incorrect (silent failure)!"

    X_resolved = X_L[:, X_exists_L]
    X_gt_resolved = X_gt_L[:, X_exists_L]
    w_resolved = w_L[:, X_exists_L]

    # Weighted centroids of the resolved atoms, [B, 3]
    w_sum = torch.sum(w_resolved, dim=-1, keepdim=True)
    u_X = torch.sum(X_resolved * w_resolved.unsqueeze(-1), dim=-2) / w_sum
    u_X_gt = torch.sum(X_gt_resolved * w_resolved.unsqueeze(-1), dim=-2) / w_sum

    X_resolved = X_resolved - u_X.unsqueeze(-2)
    X_gt_resolved = X_gt_resolved - u_X_gt.unsqueeze(-2)

    # Weighted cross-covariance matrix, [B, 3, 3]
    C = torch.einsum("bji,bjk->bik", w_resolved[..., None] * X_gt_resolved, X_resolved)

    # torch.linalg.svd returns (U, S, Vh) with C = U @ diag(S) @ Vh
    U, _, Vh = torch.linalg.svd(C)

    # Reflection correction: flip the last singular direction when U @ Vh is an
    # improper rotation (det = -1). F is created in the SVD's dtype/device so that
    # non-default-dtype inputs (e.g. float64) do not hit a matmul dtype mismatch.
    B, _, _ = X_L.shape
    F = torch.eye(3, dtype=U.dtype, device=U.device).repeat(B, 1, 1)
    F[..., -1, -1] = torch.sign(torch.linalg.det(U @ Vh))
    R = U @ F @ Vh

    # Apply the rigid transform to every atom (row-vector convention: x' = x @ R)
    X_align_L = (X_gt_L - u_X_gt.unsqueeze(-2)) @ R + u_X.unsqueeze(-2)

    return X_align_L.detach()
64
+
65
+
66
def get_rmsd(xyz1, xyz2, eps=1e-4):
    """Root-mean-square deviation between two coordinate tensors.

    Squared differences are summed over the last two axes and divided by
    ``xyz1.shape[-2]`` (the number of points). ``eps`` is added inside the
    square root, which keeps the sqrt away from zero, where its gradient is
    infinite. Leading (batch) axes are preserved in the result.
    """
    delta = xyz2 - xyz1
    n_points = xyz1.shape[-2]
    sq_dev = torch.sum(delta * delta, dim=(-1, -2))
    return torch.sqrt(sq_dev / n_points + eps)
70
+
71
+
72
def superimpose(xyz1, xyz2, mask, eps=1e-4):
    """
    Superimpose xyz1 onto xyz2 using mask

    NOTE(review): in the source shown here this function only validates its
    inputs (mask shape/dtype, matching coordinate shapes) and then implicitly
    returns ``None``; no superposition is computed and ``eps`` is unused.
    Confirm whether the implementation was truncated before relying on it.
    """
    n_points = xyz1.shape[-2]
    assert mask.shape == (n_points,)
    assert xyz1.shape == xyz2.shape
    assert mask.dtype == torch.bool
@@ -0,0 +1,101 @@
1
+ from os import PathLike
2
+ from pathlib import Path
3
+
4
+ from atomworks.ml.example_id import parse_example_id
5
+ from beartype.typing import Any
6
+ from rf3.utils.io import (
7
+ build_stack_from_atom_array_and_batched_coords,
8
+ dump_structures,
9
+ dump_trajectories,
10
+ )
11
+
12
+ from foundry.callbacks.callback import BaseCallback
13
+
14
+
15
class DumpValidationStructuresCallback(BaseCallback):
    """Dump predicted structures and/or diffusion trajectories during validation.

    Files are written under
    ``<save_dir>/<predictions|trajectories>/epoch_<current_epoch>/[<dataset_name>/]<identifier>[<suffix>]``,
    where ``identifier`` is ``<pdb_id>_<assembly_id>`` when the example ID can be
    parsed, and the raw example ID otherwise.
    """

    def __init__(
        self,
        save_dir: PathLike,
        dump_predictions: bool = False,
        one_model_per_file: bool = False,
        dump_trajectories: bool = False,
        compress_outputs: bool = False,
    ):
        """
        Args:
            save_dir: Root directory under which all dumped files are written.
            dump_predictions: Whether to dump structures (CIF files) after validation batches.
            one_model_per_file: If True, write each structure within a diffusion batch to its own CIF file. If False,
                include each structure within a diffusion batch as a separate model within one CIF file.
            dump_trajectories: Whether to dump denoising trajectories after validation batches.
            compress_outputs: Whether to gzip output files. Defaults to ``False``.
        """
        super().__init__()
        self.save_dir = Path(save_dir)
        self.dump_predictions = dump_predictions
        self.dump_trajectories = dump_trajectories
        self.one_model_per_file = one_model_per_file
        self.compress_outputs = compress_outputs

    def on_validation_batch_end(
        self,
        trainer,
        outputs: dict,
        batch: Any,
        dataset_name: str | None,
        **_,
    ):
        """Dump predictions and/or trajectories for the example in ``batch``.

        Args:
            trainer: Trainer whose ``state["current_epoch"]`` is used in the output path.
            outputs: Validation step outputs. Must contain ``"network_output"``, which is
                read for ``"X_L"`` (predictions) and ``"X_denoised_L_traj"`` /
                ``"X_noisy_L_traj"`` (trajectories).
            batch: Validation batch; only ``batch[0]`` is used (batch size 1 is assumed).
            dataset_name: Optional dataset name. When given, outputs are nested in a
                subdirectory of that name; when ``None``, that level is omitted.
        """
        if not (self.dump_predictions or self.dump_trajectories):
            return  # Nothing to do

        assert (
            "network_output" in outputs
        ), "Validation outputs must contain `network_output` to dump structures!"

        network_output = outputs["network_output"]
        example = batch[0]  # Assume batch size = 1

        try:
            # ... try to extract the PDB ID and assembly ID from the example ID
            parsed_id = parse_example_id(example["example_id"])
            identifier = f"{parsed_id['pdb_id']}_{parsed_id['assembly_id']}"
        except (KeyError, ValueError):
            # ... if parsing fails, fall back to the original example ID
            identifier = example["example_id"]

        def _build_path_from_example_id(subdir: str, extra: str = "") -> Path:
            """Build ``<save_dir>/<subdir>/epoch_<N>/[<dataset_name>/]<identifier><extra>``."""
            path = self.save_dir / subdir / f"epoch_{trainer.state['current_epoch']}"
            # `dataset_name` may be None per the signature; `Path / None` raises
            # TypeError, so only add the dataset subdirectory when one is given.
            if dataset_name is not None:
                path = path / dataset_name
            return path / f"{identifier}{extra}"

        # Determine file type based on compression setting
        file_type = "cif.gz" if self.compress_outputs else "cif"

        if self.dump_predictions:
            atom_array_stack = build_stack_from_atom_array_and_batched_coords(
                network_output["X_L"], example["atom_array"]
            )
            dump_structures(
                atom_arrays=atom_array_stack,
                base_path=_build_path_from_example_id("predictions"),
                one_model_per_file=self.one_model_per_file,
                file_type=file_type,
            )

        if self.dump_trajectories:
            # Denoised trajectory first, then the noisy one (same order as before)
            for traj_key, suffix in (
                ("X_denoised_L_traj", "_denoised"),
                ("X_noisy_L_traj", "_noisy"),
            ):
                dump_trajectories(
                    trajectory_list=network_output[traj_key],
                    atom_array=example["atom_array"],
                    base_path=_build_path_from_example_id("trajectories", suffix),
                    file_type=file_type,
                )