rc-foundry 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foundry/utils/ddp.py CHANGED
@@ -2,7 +2,7 @@ import logging
2
2
 
3
3
  import torch
4
4
  from beartype.typing import Any
5
- from lightning_fabric.utilities import rank_zero_only
5
+ from lightning.fabric.utilities import rank_zero_only
6
6
  from lightning_utilities.core.rank_zero import rank_prefixed_message
7
7
  from omegaconf import DictConfig
8
8
 
foundry/utils/logging.py CHANGED
@@ -4,7 +4,7 @@ from contextlib import contextmanager
4
4
 
5
5
  import pandas as pd
6
6
  from beartype.typing import Any
7
- from lightning_fabric.utilities import rank_zero_only
7
+ from lightning.fabric.utilities import rank_zero_only
8
8
  from omegaconf import DictConfig, OmegaConf
9
9
  from rich.console import Console
10
10
  from rich.syntax import Syntax
foundry/version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.1.7'
32
- __version_tuple__ = version_tuple = (0, 1, 7)
31
+ __version__ = version = '0.1.9'
32
+ __version_tuple__ = version_tuple = (0, 1, 9)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rc-foundry
3
- Version: 0.1.7
3
+ Version: 0.1.9
4
4
  Summary: Shared utilities and training infrastructure for biomolecular structure prediction models.
5
5
  Author-email: Institute for Protein Design <contact@ipd.uw.edu>
6
6
  License: BSD 3-Clause License
@@ -97,11 +97,15 @@ Foundry provides tooling and infrastructure for using and training all classes o
97
97
 
98
98
  All models within Foundry rely on [AtomWorks](https://github.com/RosettaCommons/atomworks) - a unified framework for manipulating and processing biomolecular structures - for both training and inference.
99
99
 
100
+
101
+ > [!NOTE]
102
+ > We have a Slack workspace now! Join it for updates and to get your questions answered [here](https://join.slack.com/t/proteinmodelfoundry/shared_invite/zt-3kpwru8c6-nrmTW6LNHnSE7h16GNnfLA).
103
+
100
104
  ## Getting Started
101
105
  ### Quickstart guide
102
106
  **Installation**
103
107
  ```bash
104
- pip install rc-foundry[all]
108
+ pip install "rc-foundry[all]"
105
109
  ```
106
110
 
107
111
  **Downloading weights** Models can be downloaded to a target folder with:
@@ -1,7 +1,7 @@
1
1
  foundry/__init__.py,sha256=H8S1nl5v6YeW8ggn1jKy4GdtH7c-FGS-j7CqUCAEnAU,1926
2
2
  foundry/common.py,sha256=Aur8mH-CNmcUqSsw7VgaCQSW5sH1Bqf8Da91jzxPV1Y,3035
3
3
  foundry/constants.py,sha256=0n1wBKCvNuw3QaQehSbmsHYkIdaGn3tLeRFItBrdeHY,913
4
- foundry/version.py,sha256=szvPIs2C82UunpzuvVg3MbF4QhzbBYTsVJ8DmPfq6_E,704
4
+ foundry/version.py,sha256=ib8ckvf-NNDfacXd8unW0p5cf-gl57XyQvjoEMc_pvc,704
5
5
  foundry/callbacks/__init__.py,sha256=VsRT1e4sqlJHPcTCsfupMEx82Iz-LoOAGPpwvf_OJeE,126
6
6
  foundry/callbacks/callback.py,sha256=xZBo_suP4bLrP6gl5uJPbaXm00DXigePa6dMeDxucgg,3890
7
7
  foundry/callbacks/health_logging.py,sha256=tEtkByOlaAA7nnelxb7PbM9_dcIgOsdbxCdQY3K5pMc,16664
@@ -25,9 +25,9 @@ foundry/training/schedulers.py,sha256=StmXegPfIdLAv31FreCTrDh9dsOvNUfzG4YGa61Y4o
25
25
  foundry/utils/alignment.py,sha256=2anqy0mn9zeFEiVWS_EG7zHiyPk1C_gbUu-SRvQ5mAM,2502
26
26
  foundry/utils/components.py,sha256=Piw2TfQF26uuxC3hXG3iv_4rgud1lVO-cv6N-p05EDY,15200
27
27
  foundry/utils/datasets.py,sha256=pLBxVezm-TSrYuC5gFnJZdGnNWV7aPH2QiWIVE2hkdQ,16629
28
- foundry/utils/ddp.py,sha256=ydHrO6peGbRnWAwgH5rmpHuQd55g2gFzzoZJYypn7GU,3970
28
+ foundry/utils/ddp.py,sha256=202_7qqm4ihPjpB5Q9NhUjDl4u22pu5JvY0ui0UkRUQ,3970
29
29
  foundry/utils/instantiators.py,sha256=oGCp6hrmY-QPPPEjxKxe5uVFL125fH1RaLxjMKWCD_8,2169
30
- foundry/utils/logging.py,sha256=jrDgiB_56q_hWDc0jkBFekvqnNWcowJBt4B-S-ipJmM,9312
30
+ foundry/utils/logging.py,sha256=ywV75MBlQsothV0IBvqoAQTNg6pjo2-Cib7Uo080nzQ,9312
31
31
  foundry/utils/rigid.py,sha256=_Z1pmitb6xgxyguLj_TukKscUBJjQsU4bsBD24GVS84,44444
32
32
  foundry/utils/rotation_augmentation.py,sha256=7q1WEX2iJ0i7-2aV-M97nEaEdpqexDTaZn5JquYpkUk,1927
33
33
  foundry/utils/squashfs.py,sha256=QlcwuJyVe-QVfIOS7o1QfLhaCQPNzzox7ln4n8dcYEg,5234
@@ -63,8 +63,8 @@ rf3/__init__.py,sha256=XBb5hF2RqBPHODGRmjiRbfTXgOGfOzdY91GbS4Vex00,70
63
63
  rf3/_version.py,sha256=fCfpbI5aeA6yHqjo3tK78-l2dPGxhp-AyKSoCXp34Nc,739
64
64
  rf3/alignment.py,sha256=BvvwMqQGCVxV20xIsTighD1kXMadXXL2SkckLjTerx0,2102
65
65
  rf3/chemical.py,sha256=VECnRPgVm-icXbZeUG4svcENzdUiIupP6dhka_8zCrg,26572
66
- rf3/cli.py,sha256=dPKJFRHYoV2XS6xc_ZmdLTz6frqa6HZg4qgZU5oJcXU,2356
67
- rf3/inference.py,sha256=_AAJ07AfSeU3xTM2_KH9n_H12EK4qZ23IJuyauOrMaQ,2466
66
+ rf3/cli.py,sha256=jxjq8u77J8itIK4aNTfIpnMsNgg8brW1A3NfVjlgE0s,2743
67
+ rf3/inference.py,sha256=yUjSqCjIqsGErUezq42-8K9570Opm_2eYa-bNnwltwA,2494
68
68
  rf3/kinematics.py,sha256=V3yjalPupu1X2FEp7l3XZR-qzLKrhWLZyECk6RgIkcs,10901
69
69
  rf3/scoring.py,sha256=dTllswE-6Fgli2eLiNzLFc2Rhz4ouDT4WL-sVbvLTGU,41541
70
70
  rf3/train.py,sha256=V4nqCC_1JKLI3WQ-nErNa8sqFpvb1mFhXSe6ZPpEheM,7945
@@ -119,18 +119,18 @@ rfd3/__init__.py,sha256=2Wto2IsUIj2lGag9m_gqgdCwBNl5p21-Xnr7W_RpU3c,348
119
119
  rfd3/callbacks.py,sha256=Zjt8RiaYWquoKOwRmC_wCUbRbov-V4zd2_73zjhgDHE,2783
120
120
  rfd3/cli.py,sha256=ka3K5H117fzDYIDXFpOpJV21w_XBrHYJZdFE0thsGBI,1644
121
121
  rfd3/constants.py,sha256=wLvDzrThpOrK8T3wGFNQeGrhAXOJQze8l3v_7pjIdMM,13141
122
- rfd3/engine.py,sha256=NwATrhYFyqT7C9Bie8mWtUiqqzXgs9x6nOCkmZYPiT4,21224
122
+ rfd3/engine.py,sha256=viXzVGYkHPyzv5d0Ifg8zFZ6Wqg-U4I8Y4ldQLPe9x4,21536
123
123
  rfd3/run_inference.py,sha256=HfRMQ30_SAHfc-VFzBV52F-aLaNdG6PW8VkdMyB__wE,1264
124
124
  rfd3/train.py,sha256=rHswffIUhOae3_iYyvAiQ3jALoFuzrcRUgMlbJLinlI,7947
125
- rfd3/inference/datasets.py,sha256=u-2U7deHXu-iOs7doiKKynewP-NEyJfdORSTDzUSaQI,6538
126
- rfd3/inference/input_parsing.py,sha256=TyEzCzeCaNhuNi0RjMcq9fF2j3Sp36KbuZ1FUjlBTZ8,45442
125
+ rfd3/inference/datasets.py,sha256=9VLbzl7dpG8mk_pjs0R5C2wFYUoRIgXXoZcS9IohSy0,6510
126
+ rfd3/inference/input_parsing.py,sha256=pocqnnhE3-szeBbCL9gy9E3kZJSP_CGHXd6FFRxfv0c,46563
127
127
  rfd3/inference/legacy_input_parsing.py,sha256=G2XxkrjdIpL6i1YY7xEmkFitVv__Pc45ow6IKKPHw64,28855
128
128
  rfd3/inference/parsing.py,sha256=ktAMUuZE3Pe4bKAjjV3zjqcEDmGlMZ-cotIUhJsEQQA,5402
129
- rfd3/inference/symmetry/atom_array.py,sha256=HfFagFUB5yB-Y4IfUM5nuVGWHC5AEkyHqt0JcIqTQ_E,10922
130
- rfd3/inference/symmetry/checks.py,sha256=y-Kq0l5OhEmmxsPBBsMMB0qaAt18FeEicD3-jSMQFa0,9900
129
+ rfd3/inference/symmetry/atom_array.py,sha256=j8yhtZAjRNm4d06KyS4gk2XCquHPCwR0k9WiNmxz7WA,12941
130
+ rfd3/inference/symmetry/checks.py,sha256=ZWpC1JjrAjXY__xDt8EFYb5WUdSF0kobZyxxmacFU7U,10076
131
131
  rfd3/inference/symmetry/contigs.py,sha256=6OvbZ2dJg-a0mvvKAC0VkzUH5HpUDxOJvkByIst_roU,2127
132
- rfd3/inference/symmetry/frames.py,sha256=aEwkmlUsYexERX9hu09JMhisC8QTpHPVhfITbL80-EE,10819
133
- rfd3/inference/symmetry/symmetry_utils.py,sha256=p_PkxU3sw6gYGO2EmZTrbNQdLjz1mdTWEIl5MjQdIuY,14664
132
+ rfd3/inference/symmetry/frames.py,sha256=kwog5jU_wgv6ACcUER2iU5qz-mdfdOe0cpi2OTEDKMU,18894
133
+ rfd3/inference/symmetry/symmetry_utils.py,sha256=CGUzMI5CKVIcNi5_l2-YRu-ExroZX54GndxLb5P7RtY,14680
134
134
  rfd3/metrics/design_metrics.py,sha256=O1RqZdjQPNlAWYRg6UJTERYg_gUI1_hVleKsm9xbWBY,16836
135
135
  rfd3/metrics/hbonds_hbplus_metrics.py,sha256=Sewy9KzmrA1OnfkasN-fmWrQ9IRx9G7Yyhe2ua0mk28,11518
136
136
  rfd3/metrics/hbonds_metrics.py,sha256=SIR4BnDhYdpVSqwXXRYpQ_tB-M0_fVyugGl08WivCmE,15257
@@ -140,11 +140,11 @@ rfd3/metrics/sidechain_metrics.py,sha256=EGZuFuWQ0cCe83EVPAf4eysN8vP9ifNjfnmE0o5
140
140
  rfd3/model/RFD3.py,sha256=95aKzye-XzuDyLGgost-Wsfu8eT635zHIRky-pNoHSA,3569
141
141
  rfd3/model/RFD3_diffusion_module.py,sha256=BPjKGyQpbnqdzii3gXMKLhhijNqV8Xh4bSosmfDBt8w,12094
142
142
  rfd3/model/cfg_utils.py,sha256=XPBLyoB_bQRLmdrJ1Z0hCjcVvgUMGIPuw4rxTlHjB_s,2575
143
- rfd3/model/inference_sampler.py,sha256=qge7BNJttW0NXgerg3msPY3izxQ-6FsvWSTAMhZ4GJs,24696
143
+ rfd3/model/inference_sampler.py,sha256=5k_UIkJCL6QO4BEqAxkR9GKxNAOz-NRq5Jh51Wm5MU8,25152
144
144
  rfd3/model/layers/attention.py,sha256=XuNA7WyFlRfLnAgky1PtGvXFCnDGv7GeEcXz8hodTBo,19472
145
- rfd3/model/layers/block_utils.py,sha256=EZq2qYUeO6_VCLKDVC60cxfBE_EPwvp84FPmqLr28ZQ,21197
145
+ rfd3/model/layers/block_utils.py,sha256=oN0aD-vZiH4JbIFs2CzDmb2B74GNPKzdFurmGd-dirE,21244
146
146
  rfd3/model/layers/blocks.py,sha256=MOjJ53THxM2MMM27Ap7xiIXRCdI_SHzqKzLLQVX6FEc,24888
147
- rfd3/model/layers/chunked_pairwise.py,sha256=de5Qc3P7GEfZlX-QLaKfJxr6Ky5vgLcWWogatCw2UnY,14582
147
+ rfd3/model/layers/chunked_pairwise.py,sha256=B8oUxXbdm9akwdbrmjA-7HtcNQBTppMB_gAMAz9VbvY,14712
148
148
  rfd3/model/layers/encoders.py,sha256=CqByjHNSbtMIaNP_h2iEJZdTbm-N8SGo1bZgvRNpMJ8,15207
149
149
  rfd3/model/layers/layer_utils.py,sha256=UPYo-DYa__93KONSEj2YZWLtBqvYNSA9_wHDDPhVrIc,5710
150
150
  rfd3/model/layers/pairformer_layers.py,sha256=uimskhN-Ec0apEXAU6JqomyKX5-6ormrEsCFJotkBtM,3991
@@ -166,11 +166,11 @@ rfd3/transforms/ncaa_transforms.py,sha256=Lz4L8OGuOOG53sKJHcLSdV7WPQ3YzOzwd5tJG4
166
166
  rfd3/transforms/pipelines.py,sha256=FGH-XH3taTWQ6k1zpDO_d-097EQdXmL6uqXZXw4HIMs,22086
167
167
  rfd3/transforms/ppi_transforms.py,sha256=7rXyf-tn2TLz6ybYR_YVDtSDG7hOgqhYY4shNviA_Sw,23493
168
168
  rfd3/transforms/rasa.py,sha256=a4IPFvVMMxldoGLyJQiSlGg7IyUkcBASbRZLWmguAKk,4156
169
- rfd3/transforms/symmetry.py,sha256=GSnMF7oAnUxPozfafsRuHEv0yKXW0BpLTI6wsKGZrbc,2658
169
+ rfd3/transforms/symmetry.py,sha256=9I9gzAZkk5vMUJm7x8XCDSHtNPYYLAHt4meXxOczGT0,2970
170
170
  rfd3/transforms/training_conditions.py,sha256=UXiUPjDwrNKM95tRe0eXrMeRN8XlTPc_MXUvo6UpePo,19510
171
171
  rfd3/transforms/util_transforms.py,sha256=2AcLkzx-73ZFgcWD1cIHv7NyniRPI4_zThHK8azyQaY,18119
172
172
  rfd3/transforms/virtual_atoms.py,sha256=UpmxzPPd5FaJigcRoxgLSHHrLLOqsCvZ5PPZfQSGqII,12547
173
- rfd3/utils/inference.py,sha256=-8IKzkB9ulhLEJgapvnZSdIaIPQDPMpyPpHTQlFS7r0,27317
173
+ rfd3/utils/inference.py,sha256=Yf3aAUk_YZi58uIJr5Y2wfVnQ-2bh3S5GHLBPzCRjUs,26448
174
174
  rfd3/utils/io.py,sha256=wbdjUTQkDc3RCSM7gdogA-XOKR68HeQ-cfvyN4pP90w,9849
175
175
  rfd3/utils/vizualize.py,sha256=HPlczrA3zkOuxV5X05eOvy_Oga9e3cPnFUXOEP4RR_g,11046
176
176
  rf3/configs/inference.yaml,sha256=JmEZdkAnbnOrX79lGS5xrYYho9aBFfVxfUp-8KjJV5I,309
@@ -304,8 +304,8 @@ rfd3/configs/trainer/rfd3_base.yaml,sha256=R3lZxdyjUirjlLU31qWlnZgHaz4GcWTGGIz4f
304
304
  rfd3/configs/trainer/loss/losses/diffusion_loss.yaml,sha256=FE4FCEfurE0ekwZ4YfS6wCvPSNqxClwg_kc73cPql5Y,323
305
305
  rfd3/configs/trainer/loss/losses/sequence_loss.yaml,sha256=kezbQcqwAZ0VKQPUBr2MsNr9DcDL3ENIP1i-j7h-6Co,64
306
306
  rfd3/configs/trainer/metrics/design_metrics.yaml,sha256=xVDpClhHqSHvsf-8StL26z51Vn-iuWMDG9KMB-kqOI0,719
307
- rc_foundry-0.1.7.dist-info/METADATA,sha256=zlvCxfZ5-Ow7WuGKskfW6P1DGhZB9OfLIIBUBGncFeQ,11309
308
- rc_foundry-0.1.7.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
309
- rc_foundry-0.1.7.dist-info/entry_points.txt,sha256=BmiWCbWGtrd_lSOFMuCLBXyo84B7Nco-alj7hB0Yw9A,130
310
- rc_foundry-0.1.7.dist-info/licenses/LICENSE.md,sha256=NKtPCJ7QMysFmzeDg56ZfUStvgzbq5sOvRQv7_ddZOs,1533
311
- rc_foundry-0.1.7.dist-info/RECORD,,
307
+ rc_foundry-0.1.9.dist-info/METADATA,sha256=SBHGrkr8RsLTCTOhJnbusBs8G7C8HxJUPajCbmU6OyE,11502
308
+ rc_foundry-0.1.9.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
309
+ rc_foundry-0.1.9.dist-info/entry_points.txt,sha256=BmiWCbWGtrd_lSOFMuCLBXyo84B7Nco-alj7hB0Yw9A,130
310
+ rc_foundry-0.1.9.dist-info/licenses/LICENSE.md,sha256=NKtPCJ7QMysFmzeDg56ZfUStvgzbq5sOvRQv7_ddZOs,1533
311
+ rc_foundry-0.1.9.dist-info/RECORD,,
rf3/cli.py CHANGED
@@ -23,10 +23,19 @@ def fold(
23
23
  configure_minimal_inference_logging()
24
24
 
25
25
  # Find the RF3 configs directory relative to this file
26
- # This file is at: models/rf3/src/rf3/cli.py
27
- # Configs are at: models/rf3/configs/
28
- rf3_package_dir = Path(__file__).parent.parent.parent # Go up to models/rf3/
29
- config_path = str(rf3_package_dir / "configs")
26
+ # In development: models/rf3/src/rf3/cli.py -> models/rf3/configs/
27
+ # When installed: site-packages/rf3/cli.py -> site-packages/rf3/configs/
28
+ rf3_file_dir = Path(__file__).parent
29
+
30
+ # Check if we're in installed mode (configs are sibling to this file)
31
+ # or development mode (configs are ../../../configs)
32
+ if (rf3_file_dir / "configs").exists():
33
+ # Installed mode
34
+ config_path = str(rf3_file_dir / "configs")
35
+ else:
36
+ # Development mode
37
+ rf3_package_dir = rf3_file_dir.parent.parent # Go up to models/rf3/
38
+ config_path = str(rf3_package_dir / "configs")
30
39
 
31
40
  # Get all arguments
32
41
  args = ctx.params.get("args", []) + ctx.args
rf3/inference.py CHANGED
@@ -16,7 +16,9 @@ rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
16
16
 
17
17
  load_dotenv(override=True)
18
18
 
19
- _config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rf3/configs")
19
+ _config_path = os.path.join(
20
+ os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "configs"
21
+ )
20
22
 
21
23
 
22
24
  @hydra.main(
rfd3/engine.py CHANGED
@@ -21,11 +21,13 @@ from rfd3.constants import SAVED_CONDITIONING_ANNOTATIONS
21
21
  from rfd3.inference.datasets import (
22
22
  assemble_distributed_inference_loader_from_json,
23
23
  )
24
- from rfd3.inference.input_parsing import DesignInputSpecification
24
+ from rfd3.inference.input_parsing import (
25
+ DesignInputSpecification,
26
+ ensure_input_is_abspath,
27
+ )
25
28
  from rfd3.model.inference_sampler import SampleDiffusionConfig
26
29
  from rfd3.utils.inference import (
27
30
  ensure_inference_sampler_matches_design_spec,
28
- ensure_input_is_abspath,
29
31
  )
30
32
  from rfd3.utils.io import (
31
33
  CIF_LIKE_EXTENSIONS,
@@ -391,7 +393,13 @@ class RFD3InferenceEngine(BaseInferenceEngine):
391
393
  design_specifications = {}
392
394
  for prefix, example_spec in inputs.items():
393
395
  # Record task name in the specification
394
- example_spec["extra"]["task_name"] = prefix
396
+ if isinstance(example_spec, DesignInputSpecification):
397
+ example_spec.extra = example_spec.extra or {}
398
+ example_spec.extra["task_name"] = prefix
399
+ else:
400
+ if "extra" not in example_spec:
401
+ example_spec["extra"] = {}
402
+ example_spec["extra"]["task_name"] = prefix
395
403
 
396
404
  # ... Create n_batches for example
397
405
  for batch_id in range((n_batches) if exists(n_batches) else 1):
@@ -14,8 +14,8 @@ from atomworks.ml.transforms.base import Compose, Transform
14
14
  from omegaconf import DictConfig, OmegaConf
15
15
  from rfd3.inference.input_parsing import (
16
16
  DesignInputSpecification,
17
+ ensure_input_is_abspath,
17
18
  )
18
- from rfd3.utils.inference import ensure_input_is_abspath
19
19
  from torch.utils.data import (
20
20
  DataLoader,
21
21
  SequentialSampler,
@@ -5,6 +5,7 @@ import os
5
5
  import time
6
6
  import warnings
7
7
  from contextlib import contextmanager
8
+ from os import PathLike
8
9
  from typing import Any, Dict, List, Optional, Union
9
10
 
10
11
  import numpy as np
@@ -1121,3 +1122,33 @@ def accumulate_components(
1121
1122
  if atom_array_accum.bonds is None:
1122
1123
  atom_array_accum.bonds = BondList(atom_array_accum.array_length())
1123
1124
  return atom_array_accum
1125
+
1126
+
1127
+ def ensure_input_is_abspath(args: Dict[str, Any], path: PathLike | None):
1128
+ """
1129
+ Ensures the input source is an absolute path if exists, if not it will convert
1130
+
1131
+ args:
1132
+ args: Inference specification for atom array
1133
+ path: None or file to which the input is relative to.
1134
+ """
1135
+ if isinstance(args, str):
1136
+ raise ValueError(
1137
+ "Expected args to be a dictionary, got a string: {}. If you are using an input JSON ensure it contains dictionaries of arguments".format(
1138
+ args
1139
+ )
1140
+ )
1141
+ if "input" not in args or not exists(args["input"]):
1142
+ return args
1143
+ input = str(args["input"])
1144
+ if not os.path.isabs(input):
1145
+ if path is None:
1146
+ raise ValueError(
1147
+ "Input path is relative, but no base path was provided to resolve it against."
1148
+ )
1149
+ input = os.path.abspath(os.path.join(os.path.dirname(str(path)), input))
1150
+ logger.info(
1151
+ f"Input source path is relative, converted to absolute path: {input}"
1152
+ )
1153
+ args["input"] = input
1154
+ return args
@@ -1,3 +1,5 @@
1
+ import string
2
+
1
3
  import numpy as np
2
4
  from rfd3.inference.symmetry.frames import (
3
5
  decompose_symmetry_frame,
@@ -7,6 +9,68 @@ from rfd3.inference.symmetry.frames import (
7
9
  FIXED_TRANSFORM_ID = -1
8
10
  FIXED_ENTITY_ID = -1
9
11
 
12
# Chain IDs are built only from the 26 uppercase Latin letters (wwPDB-style).
_CHAIN_ALPHABET = string.ascii_uppercase


def index_to_chain_id(index: int) -> str:
    """
    Map a zero-based chain index to a wwPDB-style chain identifier.

    Identifiers are enumerated shortest-first, then alphabetically, exactly
    like spreadsheet column labels (bijective base-26):

    - 0 .. 25        -> 'A' .. 'Z'
    - 26 .. 701      -> 'AA' .. 'ZZ'
    - 702 .. 18277   -> 'AAA' .. 'ZZZ'
    - and so on for longer identifiers.

    Arguments:
        index: zero-based index (0 -> 'A', 25 -> 'Z', 26 -> 'AA', ...)
    Returns:
        chain_id: the corresponding chain identifier string
    Raises:
        ValueError: if ``index`` is negative.
    """
    if index < 0:
        raise ValueError(f"Chain index must be non-negative, got {index}")

    # Work with the 1-based position so each step is a plain divmod on
    # (n - 1); letters are produced least-significant first.
    letters = []
    n = index + 1
    while n > 0:
        n, digit = divmod(n - 1, 26)
        letters.append(_CHAIN_ALPHABET[digit])
    return "".join(reversed(letters))


def chain_id_to_index(chain_id: str) -> int:
    """
    Map a wwPDB-style chain identifier back to its zero-based index.

    This is the inverse of ``index_to_chain_id``.

    Arguments:
        chain_id: identifier made only of 'A'-'Z' (e.g. 'A', 'Z', 'AA', 'AB')
    Returns:
        index: zero-based index
    Raises:
        ValueError: if ``chain_id`` is empty or contains any character outside
            'A'-'Z' (lowercase letters and digits are rejected).
    """
    is_valid = bool(chain_id) and all(ch in _CHAIN_ALPHABET for ch in chain_id)
    if not is_valid:
        raise ValueError(f"Invalid chain ID: {chain_id}")

    # Count every identifier shorter than this one:
    # 26 + 26**2 + ... + 26**(len(chain_id) - 1).
    shorter_ids = 0
    group_size = 1
    for _ in range(len(chain_id) - 1):
        group_size *= 26
        shorter_ids += group_size

    # Position among identifiers of the same length (ordinary base-26).
    rank = 0
    for ch in chain_id:
        rank = rank * 26 + _CHAIN_ALPHABET.index(ch)

    return shorter_ids + rank
72
+
73
+
10
74
  ########################################################
11
75
  # Symmetry annotations
12
76
  ########################################################
@@ -247,11 +311,13 @@ def reset_chain_ids(atom_array, start_id):
247
311
  Reset the chain ids and pn_unit_iids of an atom array to start from the given id.
248
312
  Arguments:
249
313
  atom_array: atom array with chain_ids and pn_unit_iids annotated
314
+ start_id: starting chain ID (e.g., 'A')
250
315
  """
251
316
  chain_ids = np.unique(atom_array.chain_id)
252
- new_chain_range = range(ord(start_id), ord(start_id) + len(chain_ids))
253
- for new_id, old_id in zip(new_chain_range, chain_ids):
254
- atom_array.chain_id[atom_array.chain_id == old_id] = chr(new_id)
317
+ start_index = chain_id_to_index(start_id)
318
+ for i, old_id in enumerate(chain_ids):
319
+ new_id = index_to_chain_id(start_index + i)
320
+ atom_array.chain_id[atom_array.chain_id == old_id] = new_id
255
321
  atom_array.pn_unit_iid = atom_array.chain_id
256
322
  return atom_array
257
323
 
@@ -259,15 +325,18 @@ def reset_chain_ids(atom_array, start_id):
259
325
  def reannotate_chain_ids(atom_array, offset, multiplier=0):
260
326
  """
261
327
  Reannotate the chain ids and pn_unit_iids of an atom array.
328
+
329
+ Uses wwPDB-style chain IDs (A-Z, AA-ZZ, AAA-ZZZ, ...) to support
330
+ any number of chains.
331
+
262
332
  Arguments:
263
333
  atom_array: protein atom array with chain_ids and pn_unit_iids annotated
264
- offset: offset to add to the chain ids
265
- multiplier: multiplier to add to the chain ids
334
+ offset: offset to add to the chain ids (typically num_chains in ASU)
335
+ multiplier: multiplier for the offset (typically transform index)
266
336
  """
267
- chain_ids_int = (
268
- np.array([ord(c) for c in atom_array.chain_id]) + offset * multiplier
269
- )
270
- chain_ids = np.array([chr(id) for id in chain_ids_int], dtype=str)
337
+ chain_ids_indices = np.array([chain_id_to_index(c) for c in atom_array.chain_id])
338
+ new_indices = chain_ids_indices + offset * multiplier
339
+ chain_ids = np.array([index_to_chain_id(idx) for idx in new_indices], dtype="U4")
271
340
  atom_array.chain_id = chain_ids
272
341
  atom_array.pn_unit_iid = chain_ids
273
342
  return atom_array
@@ -24,7 +24,16 @@ def check_symmetry_config(
24
24
  assert sym_conf.id, "symmetry_id is required. e.g. {'id': 'C2'}"
25
25
  # if unsym motif is provided, check that each motif name is in the atom array
26
26
 
27
+ is_motif_atom = get_motif_features(atom_array)["is_motif_atom"]
27
28
  is_unsym_motif = np.zeros(atom_array.shape[0], dtype=bool)
29
+
30
+ if not is_motif_atom.any():
31
+ sym_conf.is_symmetric_motif = None
32
+ ranked_logger.warning(
33
+ "No motifs found in atom array. Setting is_symmetric_motif to None."
34
+ )
35
+ return sym_conf
36
+
28
37
  if sym_conf.is_unsym_motif:
29
38
  assert (
30
39
  src_atom_array is not None
@@ -36,21 +45,20 @@ def check_symmetry_config(
36
45
  if (sm and n not in sm.split(",")) and (n not in atom_array.src_component):
37
46
  raise ValueError(f"Unsym motif {n} not found in atom_array")
38
47
 
39
- is_motif_token = get_motif_features(atom_array)["is_motif_token"]
40
48
  if (
41
- is_motif_token[~is_unsym_motif].any()
49
+ is_motif_atom[~is_unsym_motif].any()
42
50
  and not sym_conf.is_symmetric_motif
43
51
  and not has_dist_cond
44
52
  ):
45
53
  raise ValueError(
46
- "Asymmetric motif inputs should be distance constrained."
47
- "Use atomwise_fixed_dist to constrain the distance between the motif atoms."
54
+ "Asymmetric motif inputs are not supported yet. Please provide a symmetric motif."
48
55
  )
49
56
 
50
57
  if partial and not sym_conf.is_symmetric_motif:
51
58
  raise ValueError(
52
59
  "Partial diffusion with symmetry is only supported for symmetric inputs."
53
60
  )
61
+ return sym_conf
54
62
 
55
63
 
56
64
  def check_atom_array_is_symmetric(atom_array):