rc-foundry 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -36,6 +36,11 @@ from foundry.utils.weights import (
36
36
  freeze_parameters_with_config,
37
37
  load_weights_with_policies,
38
38
  )
39
+ from foundry.utils.xpu import (
40
+ SingleXPUStrategy,
41
+ XPUAccelerator,
42
+ XPUMixedPrecision,
43
+ )
39
44
 
40
45
  ranked_logger = RankedLogger(__name__, rank_zero_only=True)
41
46
 
@@ -122,8 +127,19 @@ class FabricTrainer(ABC):
122
127
  (3) Fabric Loggers (https://lightning.ai/docs/fabric/2.4.0/api/loggers.html)
123
128
  (4) Efficient Gradient Accumulation (https://lightning.ai/docs/fabric/2.4.0/advanced/gradient_accumulation.html)
124
129
  """
125
- # Use custom DDP strategy only for multi-device, non-interactive environments
126
- if (
130
+ # Handle XPU accelerator
131
+ is_xpu = hasattr(torch, "xpu") and torch.xpu.is_available()
132
+ if accelerator == "xpu" or (
133
+ accelerator == "auto" and is_xpu and not torch.cuda.is_available()
134
+ ):
135
+ accelerator = XPUAccelerator()
136
+ precision_plugin = None
137
+ if precision in ("16-mixed", "bf16-mixed"):
138
+ precision_plugin = XPUMixedPrecision(precision=precision)
139
+ precision = None # Handled by plugin
140
+ strategy = SingleXPUStrategy(precision_plugin=precision_plugin)
141
+ ranked_logger.info("Using Intel XPU with SingleXPUStrategy")
142
+ elif (
127
143
  strategy == "ddp"
128
144
  and not is_interactive_environment()
129
145
  and not (num_nodes == 1 and devices_per_node == 1)
@@ -5,11 +5,11 @@ from typing import List
5
5
  import numpy as np
6
6
  from atomworks.ml.encoding_definitions import AF3SequenceEncoding
7
7
  from biotite.structure import AtomArray
8
- from rfd3.constants import (
9
- TIP_BY_RESTYPE,
10
- )
11
8
 
12
9
  from foundry.common import exists
10
+ from foundry.constants import (
11
+ TIP_BY_RESTYPE,
12
+ )
13
13
  from foundry.utils.ddp import RankedLogger
14
14
 
15
15
  global_logger = RankedLogger(__name__, rank_zero_only=False)
foundry/utils/ddp.py CHANGED
@@ -20,7 +20,7 @@ def is_rank_zero() -> bool:
20
20
 
21
21
 
22
22
  def set_accelerator_based_on_availability(cfg: dict | DictConfig):
23
- """Set training accelerator to CPU if no GPUs are available.
23
+ """Set training accelerator based on available hardware.
24
24
 
25
25
  Args:
26
26
  cfg: Hydra object with trainer settings "accelerator", "devices_per_node", and "num_nodes".
@@ -28,22 +28,25 @@ def set_accelerator_based_on_availability(cfg: dict | DictConfig):
28
28
  Returns:
29
29
  None; modifies the input `cfg` object in place.
30
30
  """
31
- if not torch.cuda.is_available():
31
+ assert "trainer" in cfg, "Configuration object must have a 'trainer' key."
32
+ for key in ["accelerator", "devices_per_node", "num_nodes"]:
33
+ assert (
34
+ key in cfg.trainer
35
+ ), f"Configuration object must have a 'trainer.{key}' key."
36
+
37
+ if torch.cuda.is_available():
38
+ cfg.trainer.accelerator = "gpu"
39
+ elif hasattr(torch, "xpu") and torch.xpu.is_available():
40
+ logger.info("Intel XPU detected - using XPU accelerator")
41
+ cfg.trainer.accelerator = "xpu"
42
+ else:
32
43
  logger.error(
33
- "No GPUs available - Setting accelerator to 'cpu'. Are you sure you are using the correct configs?"
44
+ "No GPUs/XPUs available - Setting accelerator to 'cpu'. Are you sure you are using the correct configs?"
34
45
  )
35
- assert "trainer" in cfg, "Configuration object must have a 'trainer' key."
36
- for key in ["accelerator", "devices_per_node", "num_nodes"]:
37
- assert (
38
- key in cfg.trainer
39
- ), f"Configuration object must have a 'trainer.{key}' key."
40
-
41
- # Override accelerator settings
42
46
  cfg.trainer.accelerator = "cpu"
43
47
  cfg.trainer.devices_per_node = 1
44
48
  cfg.trainer.num_nodes = 1
45
- else:
46
- cfg.trainer.accelerator = "gpu"
49
+
47
50
  return cfg
48
51
 
49
52
 
@@ -0,0 +1,27 @@
1
+ """XPU utilities for Intel GPU support.
2
+
3
+ XPU support in PyTorch is now native (torch.xpu.is_available()), but Lightning Fabric
4
+ requires custom Accelerator, Strategy, and Precision plugins for proper XPU handling.
5
+
6
+ These components are used directly (not via registry) when XPU is detected:
7
+ - XPUAccelerator: Custom accelerator for XPU devices
8
+ - SingleXPUStrategy: Strategy for single-device XPU training/inference
9
+ - XPUMixedPrecision: Precision plugin with proper XPU autocast support
10
+
11
+ Usage:
12
+ from foundry.utils.xpu import XPUAccelerator, SingleXPUStrategy, XPUMixedPrecision
13
+
14
+ # Check availability
15
+ if XPUAccelerator.is_available():
16
+ strategy = SingleXPUStrategy(precision_plugin=XPUMixedPrecision("bf16-mixed"))
17
+
18
+ Note:
19
+ The FabricTrainer automatically uses these components when XPU is detected.
20
+ You typically don't need to use them directly unless customizing behavior.
21
+ """
22
+
23
+ from .single_xpu_strategy import SingleXPUStrategy
24
+ from .xpu_accelerator import XPUAccelerator
25
+ from .xpu_precision import XPUMixedPrecision
26
+
27
+ __all__ = ["SingleXPUStrategy", "XPUAccelerator", "XPUMixedPrecision"]
@@ -0,0 +1,47 @@
1
+ """Lightning Fabric strategy for single XPU device.
2
+
3
+ https://docs.pytorch.org/docs/stable/notes/get_start_xpu.html
4
+ """
5
+
6
+ import torch
7
+ from lightning.fabric.plugins import CheckpointIO
8
+ from lightning.fabric.plugins.precision import Precision
9
+ from lightning.fabric.strategies import SingleDeviceStrategy
10
+ from lightning.fabric.utilities.types import _DEVICE
11
+
12
+
13
class SingleXPUStrategy(SingleDeviceStrategy):
    """Lightning Fabric strategy that runs everything on one Intel XPU device.

    A thin specialization of ``SingleDeviceStrategy``. It refuses to start when
    no XPU is present and optionally installs a caller-supplied precision plugin.
    """

    strategy_name = "xpu_single"

    def __init__(
        self,
        device: _DEVICE = "xpu:0",
        checkpoint_io: CheckpointIO | None = None,
        precision_plugin: Precision | None = None,
    ) -> None:
        """Create the strategy for a single XPU device.

        Args:
            device: Target XPU device; ``"xpu:0"`` unless overridden.
            checkpoint_io: Optional checkpoint I/O plugin, forwarded to the base class.
            precision_plugin: Optional precision plugin. When given it is stored on
                ``self._precision``; when ``None`` the base-class state is left as is.

        Raises:
            RuntimeError: If no XPU device is available.
        """
        xpu_ready = hasattr(torch, "xpu") and torch.xpu.is_available()
        if not xpu_ready:
            raise RuntimeError("`SingleXPUStrategy` requires XPU devices to run")

        super().__init__(device=device, checkpoint_io=checkpoint_io)

        # Newer Lightning versions keep the strategy's precision plugin in
        # ``_precision``; only override it when a plugin was actually supplied.
        if precision_plugin is None:
            return
        self._precision = precision_plugin
@@ -0,0 +1,91 @@
1
+ """XPU Accelerator for Intel XPU devices."""
2
+
3
+ from typing import Any
4
+
5
+ import torch
6
+ from lightning.fabric.accelerators import Accelerator
7
+
8
+
9
class XPUAccelerator(Accelerator):
    """Lightning Fabric accelerator for Intel XPU devices.

    Enables training and inference on Intel GPUs through PyTorch's native XPU
    backend (``torch.xpu``).
    """

    @property
    def name(self) -> str:
        """Return the name of this accelerator (``"xpu"``)."""
        return "xpu"

    @staticmethod
    def setup_device(device: torch.device) -> None:
        """Select ``device`` as the current XPU device.

        Args:
            device: The torch device to set up.

        Raises:
            RuntimeError: If ``device`` is not an XPU device.
        """
        if device.type != "xpu":
            msg = f"Device should be xpu, got {device} instead"
            raise RuntimeError(msg)

        torch.xpu.set_device(device)

    @staticmethod
    def parse_devices(devices: int | str | list | tuple | torch.device) -> list:
        """Normalize a ``devices`` specification into a list of XPU device indices.

        Follows Lightning's convention for the ``devices`` flag, where an integer
        is a device *count* rather than a device index:

        - ``list`` / ``tuple``: explicit device indices, returned as a list.
        - ``int`` ``n >= 0``: the first ``n`` devices, i.e. ``[0, ..., n - 1]``.
        - ``-1`` / ``"-1"`` / ``"auto"``: every available XPU device.
        - ``str`` containing commas (e.g. ``"0,2"``): explicit device indices.
        - any other numeric ``str`` (e.g. ``"2"``): a device count, as for ``int``.
        - ``torch.device``: its index (``0`` when the device carries no index).

        Any other value is wrapped in a single-element list (legacy behavior).

        Args:
            devices: Device specification.

        Returns:
            List of device indices.

        Raises:
            ValueError: If an integer count is negative (other than ``-1``), or a
                string cannot be parsed as integers.
        """
        if isinstance(devices, (list, tuple)):
            return list(devices)
        if isinstance(devices, torch.device):
            return [0 if devices.index is None else devices.index]
        if isinstance(devices, str):
            spec = devices.strip()
            if spec in ("auto", "-1"):
                return list(range(XPUAccelerator.auto_device_count()))
            if "," in spec:
                # "0,2" -> [0, 2]; a trailing comma ("1,") selects a single index.
                return [int(part) for part in spec.split(",") if part.strip()]
            devices = int(spec)
        if isinstance(devices, int):
            if devices == -1:
                return list(range(XPUAccelerator.auto_device_count()))
            if devices < 0:
                msg = f"`devices` must be -1 or a non-negative device count, got {devices}"
                raise ValueError(msg)
            # An integer is a device count, not an index: devices=1 -> [0] (xpu:0).
            return list(range(devices))
        # Unrecognized specification: keep the previous single-element behavior.
        return [devices]

    @staticmethod
    def get_parallel_devices(devices: list) -> list[torch.device]:
        """Map device indices to XPU ``torch.device`` objects.

        Args:
            devices: List of device indices.

        Returns:
            One ``torch.device("xpu", idx)`` per index, in the same order.
        """
        return [torch.device("xpu", idx) for idx in devices]

    @staticmethod
    def auto_device_count() -> int:
        """Return the number of available XPU devices (``0`` if XPU is unavailable)."""
        # Guard so torch builds without a ``torch.xpu`` module report zero devices
        # instead of raising AttributeError.
        if not XPUAccelerator.is_available():
            return 0
        return torch.xpu.device_count()

    @staticmethod
    def is_available() -> bool:
        """Return ``True`` if this torch build has XPU support and a device is present."""
        return hasattr(torch, "xpu") and torch.xpu.is_available()

    @staticmethod
    def get_device_stats(device: str | torch.device) -> dict[str, Any]:
        """Return XPU device statistics.

        Currently returns an empty dict because the XPU stats API may vary.

        Args:
            device: The device to get stats for (unused).

        Returns:
            An empty dictionary.
        """
        del device  # Unused
        return {}

    def teardown(self) -> None:
        """Clean up XPU accelerator resources (nothing to release)."""
        # Empty implementation required by the base class.
        pass
@@ -0,0 +1,72 @@
1
+ """XPU Precision Plugin for Lightning Fabric."""
2
+
3
+ from contextlib import contextmanager
4
+ from typing import Any, Generator, Literal
5
+
6
+ import torch
7
+ from lightning.fabric.plugins.precision import MixedPrecision
8
+ from torch import Tensor
9
+
10
+
11
class XPUMixedPrecision(MixedPrecision):
    """Mixed precision plugin for Intel XPU devices.

    Runs the forward pass under ``torch.autocast(device_type="xpu")`` and casts
    floating-point inputs to the autocast dtype: float16 for ``"16-mixed"``,
    bfloat16 for ``"bf16-mixed"``.
    """

    def __init__(
        self,
        precision: Literal["16-mixed", "bf16-mixed"] = "bf16-mixed",
    ) -> None:
        """Initialize XPU mixed precision.

        Args:
            precision: The precision mode. "16-mixed" uses float16,
                "bf16-mixed" uses bfloat16. Defaults to "bf16-mixed".

        Raises:
            ValueError: If precision is not "16-mixed" or "bf16-mixed".
        """
        # Determine dtype from precision string
        if precision == "16-mixed":
            dtype = torch.float16
        elif precision == "bf16-mixed":
            dtype = torch.bfloat16
        else:
            msg = f"Invalid precision: {precision}. Must be '16-mixed' or 'bf16-mixed'"
            raise ValueError(msg)

        # Initialize with xpu device type
        super().__init__(precision=precision, device="xpu")
        self._desired_input_dtype = dtype

    @contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        """Run the enclosed forward pass under XPU autocast with the configured dtype."""
        with torch.autocast(device_type="xpu", dtype=self._desired_input_dtype):
            yield

    def convert_input(self, data: Any) -> Any:
        """Cast floating-point tensors in ``data`` to the autocast dtype.

        ``data`` may be a single tensor or a nested collection of tensors.
        Lightning Fabric's module wrapper calls this with the packed
        ``(args, kwargs)`` of a forward call, so a bare-tensor check alone would
        never convert anything.

        Args:
            data: Input data (tensor or collection) to convert.

        Returns:
            ``data`` with every floating-point tensor cast to the desired dtype;
            non-floating-point tensors and non-tensor values are unchanged.
        """
        # Delegate to MixedPrecision, which walks arbitrary nested collections
        # (lists, tuples, dicts, dataclasses, ...) using the dtype set in __init__.
        # This matches what the stock plugin does on CUDA.
        return super().convert_input(data)

    def _convert_fp_tensor(self, data: Any) -> Any:
        """Convert a single floating-point tensor to the desired dtype.

        Kept for callers that convert individual tensors directly.

        Args:
            data: Data to convert.

        Returns:
            Converted data if it is a floating-point tensor, otherwise unchanged.
        """
        if isinstance(data, Tensor) and data.is_floating_point():
            return data.to(self._desired_input_dtype)
        return data
foundry/version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.1.9'
32
- __version_tuple__ = version_tuple = (0, 1, 9)
31
+ __version__ = version = '0.1.11'
32
+ __version_tuple__ = version_tuple = (0, 1, 11)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -67,11 +67,15 @@ class MPNNInferenceEngine:
67
67
  else checkpoint_path
68
68
  )
69
69
 
70
- # Determine the device.
70
+ # Determine the device (supports XPU, CUDA, and CPU).
71
71
  if device is not None:
72
72
  self.device = torch.device(device)
73
+ elif torch.cuda.is_available():
74
+ self.device = torch.device("cuda")
75
+ elif hasattr(torch, "xpu") and torch.xpu.is_available():
76
+ self.device = torch.device("xpu")
73
77
  else:
74
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
78
+ self.device = torch.device("cpu")
75
79
 
76
80
  # Set up allowed model types.
77
81
  self.allowed_model_types = {"protein_mpnn", "ligand_mpnn"}
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rc-foundry
3
- Version: 0.1.9
3
+ Version: 0.1.11
4
4
  Summary: Shared utilities and training infrastructure for biomolecular structure prediction models.
5
5
  Author-email: Institute for Protein Design <contact@ipd.uw.edu>
6
6
  License: BSD 3-Clause License
@@ -108,6 +108,16 @@ All models within Foundry rely on [AtomWorks](https://github.com/RosettaCommons/
108
108
  pip install "rc-foundry[all]"
109
109
  ```
110
110
 
111
+ **Intel XPU Installation**
112
+
113
+ For Intel XPU devices, install PyTorch with XPU support first, then install Foundry.
114
+ ```bash
115
+ pip install torch --index-url https://download.pytorch.org/whl/xpu
116
+ pip install "rc-foundry[all]"
117
+ ```
118
+ > [!NOTE]
119
+ > Use `pip` (not `uv`) for XPU installs: `uv` re-resolves dependencies and may replace your XPU build of PyTorch with the standard PyPI wheel.
120
+
111
121
  **Downloading weights** Models can be downloaded to a target folder with:
112
122
  ```
113
123
  foundry install base-models --checkpoint-dir <path/to/ckpt/dir>
@@ -1,7 +1,7 @@
1
1
  foundry/__init__.py,sha256=H8S1nl5v6YeW8ggn1jKy4GdtH7c-FGS-j7CqUCAEnAU,1926
2
2
  foundry/common.py,sha256=Aur8mH-CNmcUqSsw7VgaCQSW5sH1Bqf8Da91jzxPV1Y,3035
3
3
  foundry/constants.py,sha256=0n1wBKCvNuw3QaQehSbmsHYkIdaGn3tLeRFItBrdeHY,913
4
- foundry/version.py,sha256=ib8ckvf-NNDfacXd8unW0p5cf-gl57XyQvjoEMc_pvc,704
4
+ foundry/version.py,sha256=0-Ruc52ECccw_8Ef0d7jMkzrb8fkobUkZLqGGvcm1ik,706
5
5
  foundry/callbacks/__init__.py,sha256=VsRT1e4sqlJHPcTCsfupMEx82Iz-LoOAGPpwvf_OJeE,126
6
6
  foundry/callbacks/callback.py,sha256=xZBo_suP4bLrP6gl5uJPbaXm00DXigePa6dMeDxucgg,3890
7
7
  foundry/callbacks/health_logging.py,sha256=tEtkByOlaAA7nnelxb7PbM9_dcIgOsdbxCdQY3K5pMc,16664
@@ -18,14 +18,14 @@ foundry/model/layers/blocks.py,sha256=ihbbP_1fOlrkrcrQSk9thCrNWjK8mtxD3WxcBng9Ht
18
18
  foundry/testing/__init__.py,sha256=BnrU7fZ4l0Dm1vrGcNPQYTAw83PW4DGYz7TGhGqgrfQ,223
19
19
  foundry/testing/fixtures.py,sha256=j27a8CAonygjlWsUjZ-95M5MF4Rjp9nw7JskqiZlweI,486
20
20
  foundry/testing/pytest_hooks.py,sha256=5Ebw1GXYO2XqS9Jvpzty7g3gCXIdXu16jqg53XcuUx4,450
21
- foundry/trainers/fabric.py,sha256=cjaTHbGuJEQwaGBvIAXD_il4bHtY-crsTY14Xn77uXA,40401
21
+ foundry/trainers/fabric.py,sha256=qZVO3C48-Won57TbxMO9qb8VOh8JPNG8kqXSW6nh-YU,41057
22
22
  foundry/training/EMA.py,sha256=3OWA9Pz7XuDr-SRxbz24tZf55DmhSa2fKy9r5v2IXqA,2651
23
23
  foundry/training/checkpoint.py,sha256=mUiObg-qcF3tvMfVu77sD9m3yVRp71czv07ccliU7qQ,1791
24
24
  foundry/training/schedulers.py,sha256=StmXegPfIdLAv31FreCTrDh9dsOvNUfzG4YGa61Y4oE,3647
25
25
  foundry/utils/alignment.py,sha256=2anqy0mn9zeFEiVWS_EG7zHiyPk1C_gbUu-SRvQ5mAM,2502
26
- foundry/utils/components.py,sha256=Piw2TfQF26uuxC3hXG3iv_4rgud1lVO-cv6N-p05EDY,15200
26
+ foundry/utils/components.py,sha256=gCRjmZXqltOqU6tV_jISX6T-BX2z9zzr-H0Xgo7oojw,15203
27
27
  foundry/utils/datasets.py,sha256=pLBxVezm-TSrYuC5gFnJZdGnNWV7aPH2QiWIVE2hkdQ,16629
28
- foundry/utils/ddp.py,sha256=202_7qqm4ihPjpB5Q9NhUjDl4u22pu5JvY0ui0UkRUQ,3970
28
+ foundry/utils/ddp.py,sha256=gJY3w0pwelr6UoxOkHsttgcZsTN97kfHFtfGN2VKLM8,4075
29
29
  foundry/utils/instantiators.py,sha256=oGCp6hrmY-QPPPEjxKxe5uVFL125fH1RaLxjMKWCD_8,2169
30
30
  foundry/utils/logging.py,sha256=ywV75MBlQsothV0IBvqoAQTNg6pjo2-Cib7Uo080nzQ,9312
31
31
  foundry/utils/rigid.py,sha256=_Z1pmitb6xgxyguLj_TukKscUBJjQsU4bsBD24GVS84,44444
@@ -33,13 +33,17 @@ foundry/utils/rotation_augmentation.py,sha256=7q1WEX2iJ0i7-2aV-M97nEaEdpqexDTaZn
33
33
  foundry/utils/squashfs.py,sha256=QlcwuJyVe-QVfIOS7o1QfLhaCQPNzzox7ln4n8dcYEg,5234
34
34
  foundry/utils/torch.py,sha256=OLsqoxw4CTXbGzWUHernLUT7uQjLu0tVPtD8h8747DI,11211
35
35
  foundry/utils/weights.py,sha256=btz4S02xff2vgiq4xMfiXuhK1ERafqQPtmimo1DmoWY,10381
36
+ foundry/utils/xpu/__init__.py,sha256=i8_uiVj5b0-_Z4bX4RZmaE01e0NHbo-eTNXMYF1M0mM,1114
37
+ foundry/utils/xpu/single_xpu_strategy.py,sha256=28iEo-6e0gbpApmFxZdfhLO-RZb2QlOf9XKlRdTeXXc,1592
38
+ foundry/utils/xpu/xpu_accelerator.py,sha256=LPVXvgBKY3OE1ghyhZvotP6aWNVZXMup1Swi1v6U5GY,2467
39
+ foundry/utils/xpu/xpu_precision.py,sha256=zuFfLkBIDLbF-omJ6mTiGoVU0UgXg84A2j3ZjYyoAOc,2285
36
40
  foundry_cli/__init__.py,sha256=0BxY2RUKJLaMXUGgypPCwlTskTEFdVnkhTR4C4ft2Kw,52
37
41
  foundry_cli/download_checkpoints.py,sha256=CxU9dKBa1vAkVd450tfH5aZAlQIUTrHsDGTbmxzd_JQ,8922
38
42
  mpnn/__init__.py,sha256=hgQcXFaCbAxFrhydVAy0xj8yC7UJF-GCCFhqD0sZ7I4,57
39
43
  mpnn/inference.py,sha256=wPtGR325eVRVeesXoWtBK6b_-VcU8BZae5IfQN3-mvA,1669
40
44
  mpnn/train.py,sha256=9eQGBd3rdNF5Zr2w8oUgETbqxBavNBajtA6Vbc5zESE,10239
41
45
  mpnn/collate/feature_collator.py,sha256=LpzAFWo1VMa06dJLmfUWZsKe4xvLZjHbx4RICg2lgbQ,10510
42
- mpnn/inference_engines/mpnn.py,sha256=PmDEsIFipdk2fY57FA-vCp4evoU83DVVuUVmlViUtWk,21725
46
+ mpnn/inference_engines/mpnn.py,sha256=Wg7FMI5KKdq14qQzKs6TkX-rvUi1fB6lSnJMdwWuKKw,21912
43
47
  mpnn/loss/nll_loss.py,sha256=KmdNe-BCzGYtijjappzBArQcT1gHVlJnKdY1PYQ4mhU,5947
44
48
  mpnn/metrics/nll.py,sha256=T6oMeUOEeHZzOMTH8NHFtsH9vUwLAsHQDPszzj4YKXI,15299
45
49
  mpnn/metrics/sequence_recovery.py,sha256=YDw_LmH-a3ajBYWK0mucJEQvw0_VEyxvrBN7da4vX8Q,19034
@@ -64,7 +68,7 @@ rf3/_version.py,sha256=fCfpbI5aeA6yHqjo3tK78-l2dPGxhp-AyKSoCXp34Nc,739
64
68
  rf3/alignment.py,sha256=BvvwMqQGCVxV20xIsTighD1kXMadXXL2SkckLjTerx0,2102
65
69
  rf3/chemical.py,sha256=VECnRPgVm-icXbZeUG4svcENzdUiIupP6dhka_8zCrg,26572
66
70
  rf3/cli.py,sha256=jxjq8u77J8itIK4aNTfIpnMsNgg8brW1A3NfVjlgE0s,2743
67
- rf3/inference.py,sha256=yUjSqCjIqsGErUezq42-8K9570Opm_2eYa-bNnwltwA,2494
71
+ rf3/inference.py,sha256=R_bzJe2K6Ek9uhrkZ97IiVaGKhBaObWc5SWh6MVDLLg,2530
68
72
  rf3/kinematics.py,sha256=V3yjalPupu1X2FEp7l3XZR-qzLKrhWLZyECk6RgIkcs,10901
69
73
  rf3/scoring.py,sha256=dTllswE-6Fgli2eLiNzLFc2Rhz4ouDT4WL-sVbvLTGU,41541
70
74
  rf3/train.py,sha256=V4nqCC_1JKLI3WQ-nErNa8sqFpvb1mFhXSe6ZPpEheM,7945
@@ -119,18 +123,18 @@ rfd3/__init__.py,sha256=2Wto2IsUIj2lGag9m_gqgdCwBNl5p21-Xnr7W_RpU3c,348
119
123
  rfd3/callbacks.py,sha256=Zjt8RiaYWquoKOwRmC_wCUbRbov-V4zd2_73zjhgDHE,2783
120
124
  rfd3/cli.py,sha256=ka3K5H117fzDYIDXFpOpJV21w_XBrHYJZdFE0thsGBI,1644
121
125
  rfd3/constants.py,sha256=wLvDzrThpOrK8T3wGFNQeGrhAXOJQze8l3v_7pjIdMM,13141
122
- rfd3/engine.py,sha256=viXzVGYkHPyzv5d0Ifg8zFZ6Wqg-U4I8Y4ldQLPe9x4,21536
126
+ rfd3/engine.py,sha256=cj7BIP35HMycWV3rBp9n9_ik5t7eNMK1wkVfJ58xFmI,21903
123
127
  rfd3/run_inference.py,sha256=HfRMQ30_SAHfc-VFzBV52F-aLaNdG6PW8VkdMyB__wE,1264
124
128
  rfd3/train.py,sha256=rHswffIUhOae3_iYyvAiQ3jALoFuzrcRUgMlbJLinlI,7947
125
129
  rfd3/inference/datasets.py,sha256=9VLbzl7dpG8mk_pjs0R5C2wFYUoRIgXXoZcS9IohSy0,6510
126
- rfd3/inference/input_parsing.py,sha256=pocqnnhE3-szeBbCL9gy9E3kZJSP_CGHXd6FFRxfv0c,46563
130
+ rfd3/inference/input_parsing.py,sha256=qUMsMyeh-kW757xX10zi07kKg4RS0wM8jSoIxu7UsBI,46794
127
131
  rfd3/inference/legacy_input_parsing.py,sha256=G2XxkrjdIpL6i1YY7xEmkFitVv__Pc45ow6IKKPHw64,28855
128
132
  rfd3/inference/parsing.py,sha256=ktAMUuZE3Pe4bKAjjV3zjqcEDmGlMZ-cotIUhJsEQQA,5402
129
- rfd3/inference/symmetry/atom_array.py,sha256=j8yhtZAjRNm4d06KyS4gk2XCquHPCwR0k9WiNmxz7WA,12941
133
+ rfd3/inference/symmetry/atom_array.py,sha256=HfFagFUB5yB-Y4IfUM5nuVGWHC5AEkyHqt0JcIqTQ_E,10922
130
134
  rfd3/inference/symmetry/checks.py,sha256=ZWpC1JjrAjXY__xDt8EFYb5WUdSF0kobZyxxmacFU7U,10076
131
135
  rfd3/inference/symmetry/contigs.py,sha256=6OvbZ2dJg-a0mvvKAC0VkzUH5HpUDxOJvkByIst_roU,2127
132
- rfd3/inference/symmetry/frames.py,sha256=kwog5jU_wgv6ACcUER2iU5qz-mdfdOe0cpi2OTEDKMU,18894
133
- rfd3/inference/symmetry/symmetry_utils.py,sha256=CGUzMI5CKVIcNi5_l2-YRu-ExroZX54GndxLb5P7RtY,14680
136
+ rfd3/inference/symmetry/frames.py,sha256=aEwkmlUsYexERX9hu09JMhisC8QTpHPVhfITbL80-EE,10819
137
+ rfd3/inference/symmetry/symmetry_utils.py,sha256=ibYeOT-4z_Pxr5i8cKbNAAqjRf4YfKOQCpSkuZQqG0I,14554
134
138
  rfd3/metrics/design_metrics.py,sha256=O1RqZdjQPNlAWYRg6UJTERYg_gUI1_hVleKsm9xbWBY,16836
135
139
  rfd3/metrics/hbonds_hbplus_metrics.py,sha256=Sewy9KzmrA1OnfkasN-fmWrQ9IRx9G7Yyhe2ua0mk28,11518
136
140
  rfd3/metrics/hbonds_metrics.py,sha256=SIR4BnDhYdpVSqwXXRYpQ_tB-M0_fVyugGl08WivCmE,15257
@@ -140,7 +144,7 @@ rfd3/metrics/sidechain_metrics.py,sha256=EGZuFuWQ0cCe83EVPAf4eysN8vP9ifNjfnmE0o5
140
144
  rfd3/model/RFD3.py,sha256=95aKzye-XzuDyLGgost-Wsfu8eT635zHIRky-pNoHSA,3569
141
145
  rfd3/model/RFD3_diffusion_module.py,sha256=BPjKGyQpbnqdzii3gXMKLhhijNqV8Xh4bSosmfDBt8w,12094
142
146
  rfd3/model/cfg_utils.py,sha256=XPBLyoB_bQRLmdrJ1Z0hCjcVvgUMGIPuw4rxTlHjB_s,2575
143
- rfd3/model/inference_sampler.py,sha256=5k_UIkJCL6QO4BEqAxkR9GKxNAOz-NRq5Jh51Wm5MU8,25152
147
+ rfd3/model/inference_sampler.py,sha256=zyA4pU2MIr2PPTuQm9k70UXUm3PRp644va6XgmB3aJ8,25210
144
148
  rfd3/model/layers/attention.py,sha256=XuNA7WyFlRfLnAgky1PtGvXFCnDGv7GeEcXz8hodTBo,19472
145
149
  rfd3/model/layers/block_utils.py,sha256=oN0aD-vZiH4JbIFs2CzDmb2B74GNPKzdFurmGd-dirE,21244
146
150
  rfd3/model/layers/blocks.py,sha256=MOjJ53THxM2MMM27Ap7xiIXRCdI_SHzqKzLLQVX6FEc,24888
@@ -161,21 +165,21 @@ rfd3/transforms/conditioning_utils.py,sha256=9Pn9AFbih2FCzp5OOM9y7z6KH7HPxVibxTr
161
165
  rfd3/transforms/design_transforms.py,sha256=ePvnLsuKUOsE4LLcmF0bbkx1vf2AiD-35rzF4zUEcEE,30944
162
166
  rfd3/transforms/dna_crop.py,sha256=JeOsG0tXghJvgzEimfzBvlFN_lVd9TrvjnC929Abz5A,18214
163
167
  rfd3/transforms/hbonds.py,sha256=ijfJapFlhsh3JktpDoT3VFqKTTg6ynrqMlD7dU2xFsA,16415
164
- rfd3/transforms/hbonds_hbplus.py,sha256=xyDP-CyVl2OsUY90HsrPoKw1VycBXUrq00WfrX8HJVM,8364
168
+ rfd3/transforms/hbonds_hbplus.py,sha256=i-57ZGZjofxWvygDidwWeKrkrALz5LcwJCCcw6rCOfQ,8654
165
169
  rfd3/transforms/ncaa_transforms.py,sha256=Lz4L8OGuOOG53sKJHcLSdV7WPQ3YzOzwd5tJG4CHqP0,4983
166
170
  rfd3/transforms/pipelines.py,sha256=FGH-XH3taTWQ6k1zpDO_d-097EQdXmL6uqXZXw4HIMs,22086
167
171
  rfd3/transforms/ppi_transforms.py,sha256=7rXyf-tn2TLz6ybYR_YVDtSDG7hOgqhYY4shNviA_Sw,23493
168
172
  rfd3/transforms/rasa.py,sha256=a4IPFvVMMxldoGLyJQiSlGg7IyUkcBASbRZLWmguAKk,4156
169
- rfd3/transforms/symmetry.py,sha256=9I9gzAZkk5vMUJm7x8XCDSHtNPYYLAHt4meXxOczGT0,2970
173
+ rfd3/transforms/symmetry.py,sha256=GSnMF7oAnUxPozfafsRuHEv0yKXW0BpLTI6wsKGZrbc,2658
170
174
  rfd3/transforms/training_conditions.py,sha256=UXiUPjDwrNKM95tRe0eXrMeRN8XlTPc_MXUvo6UpePo,19510
171
175
  rfd3/transforms/util_transforms.py,sha256=2AcLkzx-73ZFgcWD1cIHv7NyniRPI4_zThHK8azyQaY,18119
172
176
  rfd3/transforms/virtual_atoms.py,sha256=UpmxzPPd5FaJigcRoxgLSHHrLLOqsCvZ5PPZfQSGqII,12547
173
- rfd3/utils/inference.py,sha256=Yf3aAUk_YZi58uIJr5Y2wfVnQ-2bh3S5GHLBPzCRjUs,26448
177
+ rfd3/utils/inference.py,sha256=A0yGOsO3cXiTqwLYP0uTGTB9MR6vsmnYMEiroYgvnDE,26578
174
178
  rfd3/utils/io.py,sha256=wbdjUTQkDc3RCSM7gdogA-XOKR68HeQ-cfvyN4pP90w,9849
175
179
  rfd3/utils/vizualize.py,sha256=HPlczrA3zkOuxV5X05eOvy_Oga9e3cPnFUXOEP4RR_g,11046
176
- rf3/configs/inference.yaml,sha256=JmEZdkAnbnOrX79lGS5xrYYho9aBFfVxfUp-8KjJV5I,309
177
- rf3/configs/train.yaml,sha256=4KW2fKc9a_gjg8yMoQfOpfkC-nJ5mdQEfoikOKxbnKc,1573
178
- rf3/configs/validate.yaml,sha256=3LkhXyneEuuH-ueFH9FyYY5cCDi1_0KoHNwEceuQPwI,1581
180
+ rf3/configs/inference.yaml,sha256=uvg-ppF7-ANMeecb60Vodx6GNlEx0W3HPiGgiajUZ5Q,375
181
+ rf3/configs/train.yaml,sha256=aQWrr2RXsFo97CYSFOyTMu7C8F6Nz7ymZU3BuX1gkec,1639
182
+ rf3/configs/validate.yaml,sha256=6FUaD-kLbu0jfUSgFIXfAuWhHtVH-4w_swGJSf2AQKc,1647
179
183
  rf3/configs/callbacks/default.yaml,sha256=MkxOj7dMXh4jJRIE62gLjoOYecGuZLWiJrr780_nubA,89
180
184
  rf3/configs/callbacks/dump_validation_structures.yaml,sha256=EYEibR25v7KZJtadvCFLFMEPTf0FvKFNW2ocx4wm57A,259
181
185
  rf3/configs/callbacks/metrics_logging.yaml,sha256=MNm4OpvOHxvDJofVUA27NVaiDkp1NzqOYCzl6l_7ceo,432
@@ -224,6 +228,7 @@ rf3/configs/trainer/cpu.yaml,sha256=J1WbK2SQ_VMoEOOqv0XTg0FYKPaec1TNySvWklMaE4k,
224
228
  rf3/configs/trainer/ddp.yaml,sha256=uClrdTzEMNxgq4IQhMgm8okC16wUS2I5i3rKnm5SktU,65
225
229
  rf3/configs/trainer/rf3.yaml,sha256=WBjnaYofmEV7OfJisvFGN6y_UJGNEskAdmbV9oO4ICI,500
226
230
  rf3/configs/trainer/rf3_with_confidence.yaml,sha256=meDTw0S2nTcuAj5tefGwJteo2715x6kO143FMd3db14,346
231
+ rf3/configs/trainer/xpu.yaml,sha256=GlUmQcnWtZOcp7wHip5DikKWfOVZTwrTX_HU05A4kc8,73
227
232
  rf3/configs/trainer/loss/structure_prediction.yaml,sha256=XPH2RcIo6m1BDrIoBvd2xBgJ0c4KEOnbWc5oEcBKVAM,164
228
233
  rf3/configs/trainer/loss/structure_prediction_with_confidence.yaml,sha256=OZxtrQIqYGDFlhp3-1fFVZ2Ek6QRMP74ImsrNOJcelo,52
229
234
  rf3/configs/trainer/loss/losses/confidence_loss.yaml,sha256=c9xI7Q32Ddl5R_GpoWq4RS7AxiJf4hsnGeMZxWs9mBU,501
@@ -231,10 +236,10 @@ rf3/configs/trainer/loss/losses/diffusion_loss.yaml,sha256=ktBy3bUjslklPuRhGWuHb
231
236
  rf3/configs/trainer/loss/losses/distogram_loss.yaml,sha256=-dZWPeQ1mWwi9pUtB7xJW57qKgUNMQ4vao1PXo07RC0,56
232
237
  rf3/configs/trainer/metrics/structure_prediction.yaml,sha256=xWp2DqoqlofbgRLMNi0LKuAaDSCbW1tAEuEk83fmc0w,439
233
238
  rfd3/configs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
234
- rfd3/configs/dev.yaml,sha256=Y3PiHaCRIdRXf2Ea2aUDJk6hFxyNEljJf-O2VKQRNH0,151
235
- rfd3/configs/inference.yaml,sha256=tUo5-G1-rsCU8RLclSP0ZKu5-O6JC5ofgI6bHb5kwmw,119
236
- rfd3/configs/train.yaml,sha256=SbP5136VjJH_i2fc_4b0U0ZOqz3jeby1YORXSQtiFk0,450
237
- rfd3/configs/validate.yaml,sha256=TSToLqzuP8hecj0K6TAfCywtpTZI1u1-THsX0jxCG1s,537
239
+ rfd3/configs/dev.yaml,sha256=ulil1Am_romDNUOGeQY3HLDRyPMiVKHOZrh59H4uAMA,176
240
+ rfd3/configs/inference.yaml,sha256=zzrhijUPplk--wjKk5d_Ht3Nnv-qLav5DQcfHcmLSNY,144
241
+ rfd3/configs/train.yaml,sha256=BJ3owsuFDVWiVgZdmA0VaQkPtMS9ws8FEyvBZeI4YQ8,476
242
+ rfd3/configs/validate.yaml,sha256=E78w-K6bE0Kl27V-wz9BuR9QyZVXTupxjT0UpYIA7Ao,562
238
243
  rfd3/configs/callbacks/design_callbacks.yaml,sha256=JWgE1-v_spzUy7JH3_6dHct_-oX4DevozRH-pM5Ds2k,196
239
244
  rfd3/configs/callbacks/metrics_logging.yaml,sha256=pZPePSYGKEV560e3WatuLvJiHlz1CIGFOaOWoRmBh8g,694
240
245
  rfd3/configs/callbacks/train_logging.yaml,sha256=Z25GVLTHo1HvQUjnBdayaozmNww8UhAh9DczStYaZig,1050
@@ -301,11 +306,12 @@ rfd3/configs/paths/data/default.yaml,sha256=jfs1dbbcOqHja4_6lXheyRg4t0YExqVn2w0r
301
306
  rfd3/configs/trainer/cpu.yaml,sha256=rJf5LHf6x5fN5EKg8mFEn2SwfGW5dV1JdYaHqWMfpXc,74
302
307
  rfd3/configs/trainer/ddp.yaml,sha256=uClrdTzEMNxgq4IQhMgm8okC16wUS2I5i3rKnm5SktU,65
303
308
  rfd3/configs/trainer/rfd3_base.yaml,sha256=R3lZxdyjUirjlLU31qWlnZgHaz4GcWTGGIz4fUl7AyM,1016
309
+ rfd3/configs/trainer/xpu.yaml,sha256=GlUmQcnWtZOcp7wHip5DikKWfOVZTwrTX_HU05A4kc8,73
304
310
  rfd3/configs/trainer/loss/losses/diffusion_loss.yaml,sha256=FE4FCEfurE0ekwZ4YfS6wCvPSNqxClwg_kc73cPql5Y,323
305
311
  rfd3/configs/trainer/loss/losses/sequence_loss.yaml,sha256=kezbQcqwAZ0VKQPUBr2MsNr9DcDL3ENIP1i-j7h-6Co,64
306
312
  rfd3/configs/trainer/metrics/design_metrics.yaml,sha256=xVDpClhHqSHvsf-8StL26z51Vn-iuWMDG9KMB-kqOI0,719
307
- rc_foundry-0.1.9.dist-info/METADATA,sha256=SBHGrkr8RsLTCTOhJnbusBs8G7C8HxJUPajCbmU6OyE,11502
308
- rc_foundry-0.1.9.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
309
- rc_foundry-0.1.9.dist-info/entry_points.txt,sha256=BmiWCbWGtrd_lSOFMuCLBXyo84B7Nco-alj7hB0Yw9A,130
310
- rc_foundry-0.1.9.dist-info/licenses/LICENSE.md,sha256=NKtPCJ7QMysFmzeDg56ZfUStvgzbq5sOvRQv7_ddZOs,1533
311
- rc_foundry-0.1.9.dist-info/RECORD,,
313
+ rc_foundry-0.1.11.dist-info/METADATA,sha256=lK8MYr2naArWrcTOenHGERoQTliCfzPDgDave3bpWJI,11873
314
+ rc_foundry-0.1.11.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
315
+ rc_foundry-0.1.11.dist-info/entry_points.txt,sha256=BmiWCbWGtrd_lSOFMuCLBXyo84B7Nco-alj7hB0Yw9A,130
316
+ rc_foundry-0.1.11.dist-info/licenses/LICENSE.md,sha256=NKtPCJ7QMysFmzeDg56ZfUStvgzbq5sOvRQv7_ddZOs,1533
317
+ rc_foundry-0.1.11.dist-info/RECORD,,
@@ -2,6 +2,11 @@
2
2
  # ^ The "package" determines where the content of the config is placed in the output config
3
3
  # For more information about overriding configs, see: https://hydra.cc/docs/advanced/overriding_packages/#overriding-packages-using-the-defaults-list
4
4
 
5
+ hydra:
6
+ searchpath:
7
+ - pkg://rf3.configs
8
+ - pkg://configs
9
+
5
10
  defaults:
6
11
  - inference_engine: rf3
7
12
  - _self_
rf3/configs/train.yaml CHANGED
@@ -2,6 +2,11 @@
2
2
  # ^ The "package" determines where the content of the config is placed in the output config
3
3
  # For more information about overriding configs, see: https://hydra.cc/docs/advanced/overriding_packages/#overriding-packages-using-the-defaults-list
4
4
 
5
+ hydra:
6
+ searchpath:
7
+ - pkg://rf3.configs
8
+ - pkg://configs
9
+
5
10
  # NOTE: order of defaults determines the order in which configs override each other (higher up items are overridden by lower items)
6
11
  defaults:
7
12
  - callbacks: default
@@ -0,0 +1,6 @@
1
+ strategy: xpu_single
2
+
3
+ accelerator: xpu
4
+ devices_per_node: 1
5
+ num_nodes: 1
6
+
rf3/configs/validate.yaml CHANGED
@@ -2,6 +2,11 @@
2
2
  # ^ The "package" determines where the content of the config is placed in the output config
3
3
  # For more information about overriding configs, see: https://hydra.cc/docs/advanced/overriding_packages/#overriding-packages-using-the-defaults-list
4
4
 
5
+ hydra:
6
+ searchpath:
7
+ - pkg://rf3.configs
8
+ - pkg://configs
9
+
5
10
  # NOTE: order of defaults determines the order in which configs override each other (higher up items are overridden by lower items)
6
11
  defaults:
7
12
  - callbacks: default
rf3/inference.py CHANGED
@@ -12,7 +12,10 @@ from foundry.utils.logging import suppress_warnings
12
12
 
13
13
  # Setup root dir and environment variables (more info: https://github.com/ashleve/rootutils)
14
14
  # NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located)
15
- rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
15
+ try:
16
+ rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
17
+ except Exception:
18
+ pass
16
19
 
17
20
  load_dotenv(override=True)
18
21
 
rfd3/configs/dev.yaml CHANGED
@@ -2,6 +2,7 @@
2
2
  # Inference engine config for development
3
3
  hydra:
4
4
  searchpath:
5
+ - pkg://rfd3.configs
5
6
  - pkg://configs
6
7
 
7
8
  defaults:
@@ -2,6 +2,7 @@
2
2
 
3
3
  hydra:
4
4
  searchpath:
5
+ - pkg://rfd3.configs
5
6
  - pkg://configs
6
7
 
7
8
  defaults:
rfd3/configs/train.yaml CHANGED
@@ -1,6 +1,7 @@
1
1
  # @package _global_
2
2
  hydra:
3
3
  searchpath:
4
+ - pkg://rfd3.configs
4
5
  - pkg://configs
5
6
 
6
7
  defaults:
@@ -25,4 +26,4 @@ ckpt_path: null
25
26
 
26
27
  # Placeholders
27
28
  name: aa_design
28
- tags: [aa_design]
29
+ tags: [aa_design]
@@ -0,0 +1,6 @@
1
+ strategy: xpu_single
2
+
3
+ accelerator: xpu
4
+ devices_per_node: 1
5
+ num_nodes: 1
6
+