rc-foundry 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/utils/ddp.py +1 -1
- foundry/utils/logging.py +1 -1
- foundry/version.py +2 -2
- {rc_foundry-0.1.7.dist-info → rc_foundry-0.1.9.dist-info}/METADATA +6 -2
- {rc_foundry-0.1.7.dist-info → rc_foundry-0.1.9.dist-info}/RECORD +22 -22
- rf3/cli.py +13 -4
- rf3/inference.py +3 -1
- rfd3/engine.py +11 -3
- rfd3/inference/datasets.py +1 -1
- rfd3/inference/input_parsing.py +31 -0
- rfd3/inference/symmetry/atom_array.py +78 -9
- rfd3/inference/symmetry/checks.py +12 -4
- rfd3/inference/symmetry/frames.py +248 -0
- rfd3/inference/symmetry/symmetry_utils.py +5 -5
- rfd3/model/inference_sampler.py +11 -1
- rfd3/model/layers/block_utils.py +33 -33
- rfd3/model/layers/chunked_pairwise.py +84 -82
- rfd3/transforms/symmetry.py +16 -7
- rfd3/utils/inference.py +4 -28
- {rc_foundry-0.1.7.dist-info → rc_foundry-0.1.9.dist-info}/WHEEL +0 -0
- {rc_foundry-0.1.7.dist-info → rc_foundry-0.1.9.dist-info}/entry_points.txt +0 -0
- {rc_foundry-0.1.7.dist-info → rc_foundry-0.1.9.dist-info}/licenses/LICENSE.md +0 -0
foundry/utils/ddp.py
CHANGED
|
@@ -2,7 +2,7 @@ import logging
|
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
4
|
from beartype.typing import Any
|
|
5
|
-
from
|
|
5
|
+
from lightning.fabric.utilities import rank_zero_only
|
|
6
6
|
from lightning_utilities.core.rank_zero import rank_prefixed_message
|
|
7
7
|
from omegaconf import DictConfig
|
|
8
8
|
|
foundry/utils/logging.py
CHANGED
|
@@ -4,7 +4,7 @@ from contextlib import contextmanager
|
|
|
4
4
|
|
|
5
5
|
import pandas as pd
|
|
6
6
|
from beartype.typing import Any
|
|
7
|
-
from
|
|
7
|
+
from lightning.fabric.utilities import rank_zero_only
|
|
8
8
|
from omegaconf import DictConfig, OmegaConf
|
|
9
9
|
from rich.console import Console
|
|
10
10
|
from rich.syntax import Syntax
|
foundry/version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.1.
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 1,
|
|
31
|
+
__version__ = version = '0.1.9'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 1, 9)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: rc-foundry
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.9
|
|
4
4
|
Summary: Shared utilities and training infrastructure for biomolecular structure prediction models.
|
|
5
5
|
Author-email: Institute for Protein Design <contact@ipd.uw.edu>
|
|
6
6
|
License: BSD 3-Clause License
|
|
@@ -97,11 +97,15 @@ Foundry provides tooling and infrastructure for using and training all classes o
|
|
|
97
97
|
|
|
98
98
|
All models within Foundry rely on [AtomWorks](https://github.com/RosettaCommons/atomworks) - a unified framework for manipulating and processing biomolecular structures - for both training and inference.
|
|
99
99
|
|
|
100
|
+
|
|
101
|
+
> [!NOTE]
|
|
102
|
+
> We have a slack now! Join for updates and to get your questions answered [here](https://join.slack.com/t/proteinmodelfoundry/shared_invite/zt-3kpwru8c6-nrmTW6LNHnSE7h16GNnfLA).
|
|
103
|
+
|
|
100
104
|
## Getting Started
|
|
101
105
|
### Quickstart guide
|
|
102
106
|
**Installation**
|
|
103
107
|
```bash
|
|
104
|
-
pip install rc-foundry[all]
|
|
108
|
+
pip install "rc-foundry[all]"
|
|
105
109
|
```
|
|
106
110
|
|
|
107
111
|
**Downloading weights** Models can be downloaded to a target folder with:
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
foundry/__init__.py,sha256=H8S1nl5v6YeW8ggn1jKy4GdtH7c-FGS-j7CqUCAEnAU,1926
|
|
2
2
|
foundry/common.py,sha256=Aur8mH-CNmcUqSsw7VgaCQSW5sH1Bqf8Da91jzxPV1Y,3035
|
|
3
3
|
foundry/constants.py,sha256=0n1wBKCvNuw3QaQehSbmsHYkIdaGn3tLeRFItBrdeHY,913
|
|
4
|
-
foundry/version.py,sha256=
|
|
4
|
+
foundry/version.py,sha256=ib8ckvf-NNDfacXd8unW0p5cf-gl57XyQvjoEMc_pvc,704
|
|
5
5
|
foundry/callbacks/__init__.py,sha256=VsRT1e4sqlJHPcTCsfupMEx82Iz-LoOAGPpwvf_OJeE,126
|
|
6
6
|
foundry/callbacks/callback.py,sha256=xZBo_suP4bLrP6gl5uJPbaXm00DXigePa6dMeDxucgg,3890
|
|
7
7
|
foundry/callbacks/health_logging.py,sha256=tEtkByOlaAA7nnelxb7PbM9_dcIgOsdbxCdQY3K5pMc,16664
|
|
@@ -25,9 +25,9 @@ foundry/training/schedulers.py,sha256=StmXegPfIdLAv31FreCTrDh9dsOvNUfzG4YGa61Y4o
|
|
|
25
25
|
foundry/utils/alignment.py,sha256=2anqy0mn9zeFEiVWS_EG7zHiyPk1C_gbUu-SRvQ5mAM,2502
|
|
26
26
|
foundry/utils/components.py,sha256=Piw2TfQF26uuxC3hXG3iv_4rgud1lVO-cv6N-p05EDY,15200
|
|
27
27
|
foundry/utils/datasets.py,sha256=pLBxVezm-TSrYuC5gFnJZdGnNWV7aPH2QiWIVE2hkdQ,16629
|
|
28
|
-
foundry/utils/ddp.py,sha256=
|
|
28
|
+
foundry/utils/ddp.py,sha256=202_7qqm4ihPjpB5Q9NhUjDl4u22pu5JvY0ui0UkRUQ,3970
|
|
29
29
|
foundry/utils/instantiators.py,sha256=oGCp6hrmY-QPPPEjxKxe5uVFL125fH1RaLxjMKWCD_8,2169
|
|
30
|
-
foundry/utils/logging.py,sha256=
|
|
30
|
+
foundry/utils/logging.py,sha256=ywV75MBlQsothV0IBvqoAQTNg6pjo2-Cib7Uo080nzQ,9312
|
|
31
31
|
foundry/utils/rigid.py,sha256=_Z1pmitb6xgxyguLj_TukKscUBJjQsU4bsBD24GVS84,44444
|
|
32
32
|
foundry/utils/rotation_augmentation.py,sha256=7q1WEX2iJ0i7-2aV-M97nEaEdpqexDTaZn5JquYpkUk,1927
|
|
33
33
|
foundry/utils/squashfs.py,sha256=QlcwuJyVe-QVfIOS7o1QfLhaCQPNzzox7ln4n8dcYEg,5234
|
|
@@ -63,8 +63,8 @@ rf3/__init__.py,sha256=XBb5hF2RqBPHODGRmjiRbfTXgOGfOzdY91GbS4Vex00,70
|
|
|
63
63
|
rf3/_version.py,sha256=fCfpbI5aeA6yHqjo3tK78-l2dPGxhp-AyKSoCXp34Nc,739
|
|
64
64
|
rf3/alignment.py,sha256=BvvwMqQGCVxV20xIsTighD1kXMadXXL2SkckLjTerx0,2102
|
|
65
65
|
rf3/chemical.py,sha256=VECnRPgVm-icXbZeUG4svcENzdUiIupP6dhka_8zCrg,26572
|
|
66
|
-
rf3/cli.py,sha256=
|
|
67
|
-
rf3/inference.py,sha256=
|
|
66
|
+
rf3/cli.py,sha256=jxjq8u77J8itIK4aNTfIpnMsNgg8brW1A3NfVjlgE0s,2743
|
|
67
|
+
rf3/inference.py,sha256=yUjSqCjIqsGErUezq42-8K9570Opm_2eYa-bNnwltwA,2494
|
|
68
68
|
rf3/kinematics.py,sha256=V3yjalPupu1X2FEp7l3XZR-qzLKrhWLZyECk6RgIkcs,10901
|
|
69
69
|
rf3/scoring.py,sha256=dTllswE-6Fgli2eLiNzLFc2Rhz4ouDT4WL-sVbvLTGU,41541
|
|
70
70
|
rf3/train.py,sha256=V4nqCC_1JKLI3WQ-nErNa8sqFpvb1mFhXSe6ZPpEheM,7945
|
|
@@ -119,18 +119,18 @@ rfd3/__init__.py,sha256=2Wto2IsUIj2lGag9m_gqgdCwBNl5p21-Xnr7W_RpU3c,348
|
|
|
119
119
|
rfd3/callbacks.py,sha256=Zjt8RiaYWquoKOwRmC_wCUbRbov-V4zd2_73zjhgDHE,2783
|
|
120
120
|
rfd3/cli.py,sha256=ka3K5H117fzDYIDXFpOpJV21w_XBrHYJZdFE0thsGBI,1644
|
|
121
121
|
rfd3/constants.py,sha256=wLvDzrThpOrK8T3wGFNQeGrhAXOJQze8l3v_7pjIdMM,13141
|
|
122
|
-
rfd3/engine.py,sha256=
|
|
122
|
+
rfd3/engine.py,sha256=viXzVGYkHPyzv5d0Ifg8zFZ6Wqg-U4I8Y4ldQLPe9x4,21536
|
|
123
123
|
rfd3/run_inference.py,sha256=HfRMQ30_SAHfc-VFzBV52F-aLaNdG6PW8VkdMyB__wE,1264
|
|
124
124
|
rfd3/train.py,sha256=rHswffIUhOae3_iYyvAiQ3jALoFuzrcRUgMlbJLinlI,7947
|
|
125
|
-
rfd3/inference/datasets.py,sha256=
|
|
126
|
-
rfd3/inference/input_parsing.py,sha256=
|
|
125
|
+
rfd3/inference/datasets.py,sha256=9VLbzl7dpG8mk_pjs0R5C2wFYUoRIgXXoZcS9IohSy0,6510
|
|
126
|
+
rfd3/inference/input_parsing.py,sha256=pocqnnhE3-szeBbCL9gy9E3kZJSP_CGHXd6FFRxfv0c,46563
|
|
127
127
|
rfd3/inference/legacy_input_parsing.py,sha256=G2XxkrjdIpL6i1YY7xEmkFitVv__Pc45ow6IKKPHw64,28855
|
|
128
128
|
rfd3/inference/parsing.py,sha256=ktAMUuZE3Pe4bKAjjV3zjqcEDmGlMZ-cotIUhJsEQQA,5402
|
|
129
|
-
rfd3/inference/symmetry/atom_array.py,sha256=
|
|
130
|
-
rfd3/inference/symmetry/checks.py,sha256=
|
|
129
|
+
rfd3/inference/symmetry/atom_array.py,sha256=j8yhtZAjRNm4d06KyS4gk2XCquHPCwR0k9WiNmxz7WA,12941
|
|
130
|
+
rfd3/inference/symmetry/checks.py,sha256=ZWpC1JjrAjXY__xDt8EFYb5WUdSF0kobZyxxmacFU7U,10076
|
|
131
131
|
rfd3/inference/symmetry/contigs.py,sha256=6OvbZ2dJg-a0mvvKAC0VkzUH5HpUDxOJvkByIst_roU,2127
|
|
132
|
-
rfd3/inference/symmetry/frames.py,sha256=
|
|
133
|
-
rfd3/inference/symmetry/symmetry_utils.py,sha256=
|
|
132
|
+
rfd3/inference/symmetry/frames.py,sha256=kwog5jU_wgv6ACcUER2iU5qz-mdfdOe0cpi2OTEDKMU,18894
|
|
133
|
+
rfd3/inference/symmetry/symmetry_utils.py,sha256=CGUzMI5CKVIcNi5_l2-YRu-ExroZX54GndxLb5P7RtY,14680
|
|
134
134
|
rfd3/metrics/design_metrics.py,sha256=O1RqZdjQPNlAWYRg6UJTERYg_gUI1_hVleKsm9xbWBY,16836
|
|
135
135
|
rfd3/metrics/hbonds_hbplus_metrics.py,sha256=Sewy9KzmrA1OnfkasN-fmWrQ9IRx9G7Yyhe2ua0mk28,11518
|
|
136
136
|
rfd3/metrics/hbonds_metrics.py,sha256=SIR4BnDhYdpVSqwXXRYpQ_tB-M0_fVyugGl08WivCmE,15257
|
|
@@ -140,11 +140,11 @@ rfd3/metrics/sidechain_metrics.py,sha256=EGZuFuWQ0cCe83EVPAf4eysN8vP9ifNjfnmE0o5
|
|
|
140
140
|
rfd3/model/RFD3.py,sha256=95aKzye-XzuDyLGgost-Wsfu8eT635zHIRky-pNoHSA,3569
|
|
141
141
|
rfd3/model/RFD3_diffusion_module.py,sha256=BPjKGyQpbnqdzii3gXMKLhhijNqV8Xh4bSosmfDBt8w,12094
|
|
142
142
|
rfd3/model/cfg_utils.py,sha256=XPBLyoB_bQRLmdrJ1Z0hCjcVvgUMGIPuw4rxTlHjB_s,2575
|
|
143
|
-
rfd3/model/inference_sampler.py,sha256=
|
|
143
|
+
rfd3/model/inference_sampler.py,sha256=5k_UIkJCL6QO4BEqAxkR9GKxNAOz-NRq5Jh51Wm5MU8,25152
|
|
144
144
|
rfd3/model/layers/attention.py,sha256=XuNA7WyFlRfLnAgky1PtGvXFCnDGv7GeEcXz8hodTBo,19472
|
|
145
|
-
rfd3/model/layers/block_utils.py,sha256=
|
|
145
|
+
rfd3/model/layers/block_utils.py,sha256=oN0aD-vZiH4JbIFs2CzDmb2B74GNPKzdFurmGd-dirE,21244
|
|
146
146
|
rfd3/model/layers/blocks.py,sha256=MOjJ53THxM2MMM27Ap7xiIXRCdI_SHzqKzLLQVX6FEc,24888
|
|
147
|
-
rfd3/model/layers/chunked_pairwise.py,sha256=
|
|
147
|
+
rfd3/model/layers/chunked_pairwise.py,sha256=B8oUxXbdm9akwdbrmjA-7HtcNQBTppMB_gAMAz9VbvY,14712
|
|
148
148
|
rfd3/model/layers/encoders.py,sha256=CqByjHNSbtMIaNP_h2iEJZdTbm-N8SGo1bZgvRNpMJ8,15207
|
|
149
149
|
rfd3/model/layers/layer_utils.py,sha256=UPYo-DYa__93KONSEj2YZWLtBqvYNSA9_wHDDPhVrIc,5710
|
|
150
150
|
rfd3/model/layers/pairformer_layers.py,sha256=uimskhN-Ec0apEXAU6JqomyKX5-6ormrEsCFJotkBtM,3991
|
|
@@ -166,11 +166,11 @@ rfd3/transforms/ncaa_transforms.py,sha256=Lz4L8OGuOOG53sKJHcLSdV7WPQ3YzOzwd5tJG4
|
|
|
166
166
|
rfd3/transforms/pipelines.py,sha256=FGH-XH3taTWQ6k1zpDO_d-097EQdXmL6uqXZXw4HIMs,22086
|
|
167
167
|
rfd3/transforms/ppi_transforms.py,sha256=7rXyf-tn2TLz6ybYR_YVDtSDG7hOgqhYY4shNviA_Sw,23493
|
|
168
168
|
rfd3/transforms/rasa.py,sha256=a4IPFvVMMxldoGLyJQiSlGg7IyUkcBASbRZLWmguAKk,4156
|
|
169
|
-
rfd3/transforms/symmetry.py,sha256=
|
|
169
|
+
rfd3/transforms/symmetry.py,sha256=9I9gzAZkk5vMUJm7x8XCDSHtNPYYLAHt4meXxOczGT0,2970
|
|
170
170
|
rfd3/transforms/training_conditions.py,sha256=UXiUPjDwrNKM95tRe0eXrMeRN8XlTPc_MXUvo6UpePo,19510
|
|
171
171
|
rfd3/transforms/util_transforms.py,sha256=2AcLkzx-73ZFgcWD1cIHv7NyniRPI4_zThHK8azyQaY,18119
|
|
172
172
|
rfd3/transforms/virtual_atoms.py,sha256=UpmxzPPd5FaJigcRoxgLSHHrLLOqsCvZ5PPZfQSGqII,12547
|
|
173
|
-
rfd3/utils/inference.py,sha256
|
|
173
|
+
rfd3/utils/inference.py,sha256=Yf3aAUk_YZi58uIJr5Y2wfVnQ-2bh3S5GHLBPzCRjUs,26448
|
|
174
174
|
rfd3/utils/io.py,sha256=wbdjUTQkDc3RCSM7gdogA-XOKR68HeQ-cfvyN4pP90w,9849
|
|
175
175
|
rfd3/utils/vizualize.py,sha256=HPlczrA3zkOuxV5X05eOvy_Oga9e3cPnFUXOEP4RR_g,11046
|
|
176
176
|
rf3/configs/inference.yaml,sha256=JmEZdkAnbnOrX79lGS5xrYYho9aBFfVxfUp-8KjJV5I,309
|
|
@@ -304,8 +304,8 @@ rfd3/configs/trainer/rfd3_base.yaml,sha256=R3lZxdyjUirjlLU31qWlnZgHaz4GcWTGGIz4f
|
|
|
304
304
|
rfd3/configs/trainer/loss/losses/diffusion_loss.yaml,sha256=FE4FCEfurE0ekwZ4YfS6wCvPSNqxClwg_kc73cPql5Y,323
|
|
305
305
|
rfd3/configs/trainer/loss/losses/sequence_loss.yaml,sha256=kezbQcqwAZ0VKQPUBr2MsNr9DcDL3ENIP1i-j7h-6Co,64
|
|
306
306
|
rfd3/configs/trainer/metrics/design_metrics.yaml,sha256=xVDpClhHqSHvsf-8StL26z51Vn-iuWMDG9KMB-kqOI0,719
|
|
307
|
-
rc_foundry-0.1.
|
|
308
|
-
rc_foundry-0.1.
|
|
309
|
-
rc_foundry-0.1.
|
|
310
|
-
rc_foundry-0.1.
|
|
311
|
-
rc_foundry-0.1.
|
|
307
|
+
rc_foundry-0.1.9.dist-info/METADATA,sha256=SBHGrkr8RsLTCTOhJnbusBs8G7C8HxJUPajCbmU6OyE,11502
|
|
308
|
+
rc_foundry-0.1.9.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
309
|
+
rc_foundry-0.1.9.dist-info/entry_points.txt,sha256=BmiWCbWGtrd_lSOFMuCLBXyo84B7Nco-alj7hB0Yw9A,130
|
|
310
|
+
rc_foundry-0.1.9.dist-info/licenses/LICENSE.md,sha256=NKtPCJ7QMysFmzeDg56ZfUStvgzbq5sOvRQv7_ddZOs,1533
|
|
311
|
+
rc_foundry-0.1.9.dist-info/RECORD,,
|
rf3/cli.py
CHANGED
|
@@ -23,10 +23,19 @@ def fold(
|
|
|
23
23
|
configure_minimal_inference_logging()
|
|
24
24
|
|
|
25
25
|
# Find the RF3 configs directory relative to this file
|
|
26
|
-
#
|
|
27
|
-
#
|
|
28
|
-
|
|
29
|
-
|
|
26
|
+
# In development: models/rf3/src/rf3/cli.py -> models/rf3/configs/
|
|
27
|
+
# When installed: site-packages/rf3/cli.py -> site-packages/rf3/configs/
|
|
28
|
+
rf3_file_dir = Path(__file__).parent
|
|
29
|
+
|
|
30
|
+
# Check if we're in installed mode (configs are sibling to this file)
|
|
31
|
+
# or development mode (configs are ../../../configs)
|
|
32
|
+
if (rf3_file_dir / "configs").exists():
|
|
33
|
+
# Installed mode
|
|
34
|
+
config_path = str(rf3_file_dir / "configs")
|
|
35
|
+
else:
|
|
36
|
+
# Development mode
|
|
37
|
+
rf3_package_dir = rf3_file_dir.parent.parent # Go up to models/rf3/
|
|
38
|
+
config_path = str(rf3_package_dir / "configs")
|
|
30
39
|
|
|
31
40
|
# Get all arguments
|
|
32
41
|
args = ctx.params.get("args", []) + ctx.args
|
rf3/inference.py
CHANGED
|
@@ -16,7 +16,9 @@ rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
|
|
16
16
|
|
|
17
17
|
load_dotenv(override=True)
|
|
18
18
|
|
|
19
|
-
_config_path = os.path.join(
|
|
19
|
+
_config_path = os.path.join(
|
|
20
|
+
os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "configs"
|
|
21
|
+
)
|
|
20
22
|
|
|
21
23
|
|
|
22
24
|
@hydra.main(
|
rfd3/engine.py
CHANGED
|
@@ -21,11 +21,13 @@ from rfd3.constants import SAVED_CONDITIONING_ANNOTATIONS
|
|
|
21
21
|
from rfd3.inference.datasets import (
|
|
22
22
|
assemble_distributed_inference_loader_from_json,
|
|
23
23
|
)
|
|
24
|
-
from rfd3.inference.input_parsing import
|
|
24
|
+
from rfd3.inference.input_parsing import (
|
|
25
|
+
DesignInputSpecification,
|
|
26
|
+
ensure_input_is_abspath,
|
|
27
|
+
)
|
|
25
28
|
from rfd3.model.inference_sampler import SampleDiffusionConfig
|
|
26
29
|
from rfd3.utils.inference import (
|
|
27
30
|
ensure_inference_sampler_matches_design_spec,
|
|
28
|
-
ensure_input_is_abspath,
|
|
29
31
|
)
|
|
30
32
|
from rfd3.utils.io import (
|
|
31
33
|
CIF_LIKE_EXTENSIONS,
|
|
@@ -391,7 +393,13 @@ class RFD3InferenceEngine(BaseInferenceEngine):
|
|
|
391
393
|
design_specifications = {}
|
|
392
394
|
for prefix, example_spec in inputs.items():
|
|
393
395
|
# Record task name in the specification
|
|
394
|
-
example_spec
|
|
396
|
+
if isinstance(example_spec, DesignInputSpecification):
|
|
397
|
+
example_spec.extra = example_spec.extra or {}
|
|
398
|
+
example_spec.extra["task_name"] = prefix
|
|
399
|
+
else:
|
|
400
|
+
if "extra" not in example_spec:
|
|
401
|
+
example_spec["extra"] = {}
|
|
402
|
+
example_spec["extra"]["task_name"] = prefix
|
|
395
403
|
|
|
396
404
|
# ... Create n_batches for example
|
|
397
405
|
for batch_id in range((n_batches) if exists(n_batches) else 1):
|
rfd3/inference/datasets.py
CHANGED
|
@@ -14,8 +14,8 @@ from atomworks.ml.transforms.base import Compose, Transform
|
|
|
14
14
|
from omegaconf import DictConfig, OmegaConf
|
|
15
15
|
from rfd3.inference.input_parsing import (
|
|
16
16
|
DesignInputSpecification,
|
|
17
|
+
ensure_input_is_abspath,
|
|
17
18
|
)
|
|
18
|
-
from rfd3.utils.inference import ensure_input_is_abspath
|
|
19
19
|
from torch.utils.data import (
|
|
20
20
|
DataLoader,
|
|
21
21
|
SequentialSampler,
|
rfd3/inference/input_parsing.py
CHANGED
|
@@ -5,6 +5,7 @@ import os
|
|
|
5
5
|
import time
|
|
6
6
|
import warnings
|
|
7
7
|
from contextlib import contextmanager
|
|
8
|
+
from os import PathLike
|
|
8
9
|
from typing import Any, Dict, List, Optional, Union
|
|
9
10
|
|
|
10
11
|
import numpy as np
|
|
@@ -1121,3 +1122,33 @@ def accumulate_components(
|
|
|
1121
1122
|
if atom_array_accum.bonds is None:
|
|
1122
1123
|
atom_array_accum.bonds = BondList(atom_array_accum.array_length())
|
|
1123
1124
|
return atom_array_accum
|
|
1125
|
+
|
|
1126
|
+
|
|
1127
|
+
def ensure_input_is_abspath(args: Dict[str, Any], path: PathLike | None):
|
|
1128
|
+
"""
|
|
1129
|
+
Ensures the input source is an absolute path if exists, if not it will convert
|
|
1130
|
+
|
|
1131
|
+
args:
|
|
1132
|
+
args: Inference specification for atom array
|
|
1133
|
+
path: None or file to which the input is relative to.
|
|
1134
|
+
"""
|
|
1135
|
+
if isinstance(args, str):
|
|
1136
|
+
raise ValueError(
|
|
1137
|
+
"Expected args to be a dictionary, got a string: {}. If you are using an input JSON ensure it contains dictionaries of arguments".format(
|
|
1138
|
+
args
|
|
1139
|
+
)
|
|
1140
|
+
)
|
|
1141
|
+
if "input" not in args or not exists(args["input"]):
|
|
1142
|
+
return args
|
|
1143
|
+
input = str(args["input"])
|
|
1144
|
+
if not os.path.isabs(input):
|
|
1145
|
+
if path is None:
|
|
1146
|
+
raise ValueError(
|
|
1147
|
+
"Input path is relative, but no base path was provided to resolve it against."
|
|
1148
|
+
)
|
|
1149
|
+
input = os.path.abspath(os.path.join(os.path.dirname(str(path)), input))
|
|
1150
|
+
logger.info(
|
|
1151
|
+
f"Input source path is relative, converted to absolute path: {input}"
|
|
1152
|
+
)
|
|
1153
|
+
args["input"] = input
|
|
1154
|
+
return args
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
import string
|
|
2
|
+
|
|
1
3
|
import numpy as np
|
|
2
4
|
from rfd3.inference.symmetry.frames import (
|
|
3
5
|
decompose_symmetry_frame,
|
|
@@ -7,6 +9,68 @@ from rfd3.inference.symmetry.frames import (
|
|
|
7
9
|
FIXED_TRANSFORM_ID = -1
|
|
8
10
|
FIXED_ENTITY_ID = -1
|
|
9
11
|
|
|
12
|
+
# Chain IDs are built from uppercase Latin letters only (wwPDB convention).
_CHAIN_ALPHABET = string.ascii_uppercase


def index_to_chain_id(index: int) -> str:
    """
    Map a zero-based chain index to a wwPDB-style chain ID.

    IDs are enumerated like spreadsheet column labels (bijective base-26):
    all one-letter IDs first, then all two-letter IDs, and so on.

        0 -> 'A', 25 -> 'Z', 26 -> 'AA', 701 -> 'ZZ', 702 -> 'AAA', ...

    Arguments:
        index: zero-based chain index; must be non-negative.
    Returns:
        chain_id: the corresponding chain identifier.
    Raises:
        ValueError: if ``index`` is negative.
    """
    if index < 0:
        raise ValueError(f"Chain index must be non-negative, got {index}")

    base = len(_CHAIN_ALPHABET)
    letters = []
    # Work 1-based so each bijective digit lies in 1..base; subtracting one
    # before divmod yields the 0-based alphabet position of that digit.
    n = index + 1
    while n > 0:
        n, digit = divmod(n - 1, base)
        letters.append(_CHAIN_ALPHABET[digit])
    return "".join(reversed(letters))


def chain_id_to_index(chain_id: str) -> int:
    """
    Map a wwPDB-style chain ID back to its zero-based index.

    This is the inverse of ``index_to_chain_id``.

    Arguments:
        chain_id: chain identifier made of uppercase letters (e.g. 'A', 'AB').
    Returns:
        index: zero-based chain index.
    Raises:
        ValueError: if ``chain_id`` is empty or contains characters outside A-Z.
    """
    if not chain_id or any(ch not in _CHAIN_ALPHABET for ch in chain_id):
        raise ValueError(f"Invalid chain ID: {chain_id}")

    base = len(_CHAIN_ALPHABET)
    # Accumulate the bijective base-26 value (each letter worth 1..base), which
    # already counts every shorter ID; shift back to zero-based at the end.
    total = 0
    for ch in chain_id:
        total = total * base + _CHAIN_ALPHABET.index(ch) + 1
    return total - 1
|
|
72
|
+
|
|
73
|
+
|
|
10
74
|
########################################################
|
|
11
75
|
# Symmetry annotations
|
|
12
76
|
########################################################
|
|
@@ -247,11 +311,13 @@ def reset_chain_ids(atom_array, start_id):
|
|
|
247
311
|
Reset the chain ids and pn_unit_iids of an atom array to start from the given id.
|
|
248
312
|
Arguments:
|
|
249
313
|
atom_array: atom array with chain_ids and pn_unit_iids annotated
|
|
314
|
+
start_id: starting chain ID (e.g., 'A')
|
|
250
315
|
"""
|
|
251
316
|
chain_ids = np.unique(atom_array.chain_id)
|
|
252
|
-
|
|
253
|
-
for
|
|
254
|
-
|
|
317
|
+
start_index = chain_id_to_index(start_id)
|
|
318
|
+
for i, old_id in enumerate(chain_ids):
|
|
319
|
+
new_id = index_to_chain_id(start_index + i)
|
|
320
|
+
atom_array.chain_id[atom_array.chain_id == old_id] = new_id
|
|
255
321
|
atom_array.pn_unit_iid = atom_array.chain_id
|
|
256
322
|
return atom_array
|
|
257
323
|
|
|
@@ -259,15 +325,18 @@ def reset_chain_ids(atom_array, start_id):
|
|
|
259
325
|
def reannotate_chain_ids(atom_array, offset, multiplier=0):
    """
    Reannotate the chain ids and pn_unit_iids of an atom array.

    Uses wwPDB-style chain IDs (A-Z, AA-ZZ, AAA-ZZZ, ...) to support
    any number of chains. Every chain ID is moved ``offset * multiplier``
    positions forward in that ordering, and ``pn_unit_iid`` is set to the new
    chain ID.

    Arguments:
        atom_array: protein atom array with chain_ids and pn_unit_iids annotated
        offset: offset to add to the chain ids (typically num_chains in ASU)
        multiplier: multiplier for the offset (typically transform index)
    Returns:
        atom_array: the same atom array, with chain_id / pn_unit_iid reassigned.
    Raises:
        ValueError: if an existing chain ID is not made of uppercase letters A-Z.
    """
    # Convert each distinct chain ID once and broadcast back to atoms via the
    # inverse index, instead of a Python-level conversion for every atom.
    unique_ids, atom_to_unique = np.unique(atom_array.chain_id, return_inverse=True)
    shift = offset * multiplier
    # NOTE: "U4" holds IDs up to four letters (indices 0..475253); longer IDs
    # would be silently truncated by numpy.
    new_unique_ids = np.array(
        [index_to_chain_id(chain_id_to_index(c) + shift) for c in unique_ids],
        dtype="U4",
    )
    chain_ids = new_unique_ids[atom_to_unique.reshape(-1)]
    atom_array.chain_id = chain_ids
    atom_array.pn_unit_iid = chain_ids
    return atom_array
|
|
@@ -24,7 +24,16 @@ def check_symmetry_config(
|
|
|
24
24
|
assert sym_conf.id, "symmetry_id is required. e.g. {'id': 'C2'}"
|
|
25
25
|
# if unsym motif is provided, check that each motif name is in the atom array
|
|
26
26
|
|
|
27
|
+
is_motif_atom = get_motif_features(atom_array)["is_motif_atom"]
|
|
27
28
|
is_unsym_motif = np.zeros(atom_array.shape[0], dtype=bool)
|
|
29
|
+
|
|
30
|
+
if not is_motif_atom.any():
|
|
31
|
+
sym_conf.is_symmetric_motif = None
|
|
32
|
+
ranked_logger.warning(
|
|
33
|
+
"No motifs found in atom array. Setting is_symmetric_motif to None."
|
|
34
|
+
)
|
|
35
|
+
return sym_conf
|
|
36
|
+
|
|
28
37
|
if sym_conf.is_unsym_motif:
|
|
29
38
|
assert (
|
|
30
39
|
src_atom_array is not None
|
|
@@ -36,21 +45,20 @@ def check_symmetry_config(
|
|
|
36
45
|
if (sm and n not in sm.split(",")) and (n not in atom_array.src_component):
|
|
37
46
|
raise ValueError(f"Unsym motif {n} not found in atom_array")
|
|
38
47
|
|
|
39
|
-
is_motif_token = get_motif_features(atom_array)["is_motif_token"]
|
|
40
48
|
if (
|
|
41
|
-
|
|
49
|
+
is_motif_atom[~is_unsym_motif].any()
|
|
42
50
|
and not sym_conf.is_symmetric_motif
|
|
43
51
|
and not has_dist_cond
|
|
44
52
|
):
|
|
45
53
|
raise ValueError(
|
|
46
|
-
"Asymmetric motif inputs
|
|
47
|
-
"Use atomwise_fixed_dist to constrain the distance between the motif atoms."
|
|
54
|
+
"Asymmetric motif inputs are not supported yet. Please provide a symmetric motif."
|
|
48
55
|
)
|
|
49
56
|
|
|
50
57
|
if partial and not sym_conf.is_symmetric_motif:
|
|
51
58
|
raise ValueError(
|
|
52
59
|
"Partial diffusion with symmetry is only supported for symmetric inputs."
|
|
53
60
|
)
|
|
61
|
+
return sym_conf
|
|
54
62
|
|
|
55
63
|
|
|
56
64
|
def check_atom_array_is_symmetric(atom_array):
|