rc-foundry 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/trainers/fabric.py +18 -2
- foundry/utils/components.py +3 -3
- foundry/utils/ddp.py +15 -12
- foundry/utils/xpu/__init__.py +27 -0
- foundry/utils/xpu/single_xpu_strategy.py +47 -0
- foundry/utils/xpu/xpu_accelerator.py +91 -0
- foundry/utils/xpu/xpu_precision.py +72 -0
- foundry/version.py +2 -2
- mpnn/inference_engines/mpnn.py +6 -2
- {rc_foundry-0.1.9.dist-info → rc_foundry-0.1.11.dist-info}/METADATA +11 -1
- {rc_foundry-0.1.9.dist-info → rc_foundry-0.1.11.dist-info}/RECORD +33 -27
- rf3/configs/inference.yaml +5 -0
- rf3/configs/train.yaml +5 -0
- rf3/configs/trainer/xpu.yaml +6 -0
- rf3/configs/validate.yaml +5 -0
- rf3/inference.py +4 -1
- rfd3/configs/dev.yaml +1 -0
- rfd3/configs/inference.yaml +1 -0
- rfd3/configs/train.yaml +2 -1
- rfd3/configs/trainer/xpu.yaml +6 -0
- rfd3/configs/validate.yaml +1 -0
- rfd3/engine.py +14 -7
- rfd3/inference/input_parsing.py +4 -2
- rfd3/inference/symmetry/atom_array.py +9 -78
- rfd3/inference/symmetry/frames.py +0 -248
- rfd3/inference/symmetry/symmetry_utils.py +2 -3
- rfd3/model/inference_sampler.py +3 -1
- rfd3/transforms/hbonds_hbplus.py +52 -49
- rfd3/transforms/symmetry.py +7 -16
- rfd3/utils/inference.py +7 -6
- {rc_foundry-0.1.9.dist-info → rc_foundry-0.1.11.dist-info}/WHEEL +0 -0
- {rc_foundry-0.1.9.dist-info → rc_foundry-0.1.11.dist-info}/entry_points.txt +0 -0
- {rc_foundry-0.1.9.dist-info → rc_foundry-0.1.11.dist-info}/licenses/LICENSE.md +0 -0
foundry/trainers/fabric.py
CHANGED
|
@@ -36,6 +36,11 @@ from foundry.utils.weights import (
|
|
|
36
36
|
freeze_parameters_with_config,
|
|
37
37
|
load_weights_with_policies,
|
|
38
38
|
)
|
|
39
|
+
from foundry.utils.xpu import (
|
|
40
|
+
SingleXPUStrategy,
|
|
41
|
+
XPUAccelerator,
|
|
42
|
+
XPUMixedPrecision,
|
|
43
|
+
)
|
|
39
44
|
|
|
40
45
|
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
|
41
46
|
|
|
@@ -122,8 +127,19 @@ class FabricTrainer(ABC):
|
|
|
122
127
|
(3) Fabric Loggers (https://lightning.ai/docs/fabric/2.4.0/api/loggers.html)
|
|
123
128
|
(4) Efficient Gradient Accumulation (https://lightning.ai/docs/fabric/2.4.0/advanced/gradient_accumulation.html)
|
|
124
129
|
"""
|
|
125
|
-
#
|
|
126
|
-
|
|
130
|
+
# Handle XPU accelerator
|
|
131
|
+
is_xpu = hasattr(torch, "xpu") and torch.xpu.is_available()
|
|
132
|
+
if accelerator == "xpu" or (
|
|
133
|
+
accelerator == "auto" and is_xpu and not torch.cuda.is_available()
|
|
134
|
+
):
|
|
135
|
+
accelerator = XPUAccelerator()
|
|
136
|
+
precision_plugin = None
|
|
137
|
+
if precision in ("16-mixed", "bf16-mixed"):
|
|
138
|
+
precision_plugin = XPUMixedPrecision(precision=precision)
|
|
139
|
+
precision = None # Handled by plugin
|
|
140
|
+
strategy = SingleXPUStrategy(precision_plugin=precision_plugin)
|
|
141
|
+
ranked_logger.info("Using Intel XPU with SingleXPUStrategy")
|
|
142
|
+
elif (
|
|
127
143
|
strategy == "ddp"
|
|
128
144
|
and not is_interactive_environment()
|
|
129
145
|
and not (num_nodes == 1 and devices_per_node == 1)
|
foundry/utils/components.py
CHANGED
|
@@ -5,11 +5,11 @@ from typing import List
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
|
|
7
7
|
from biotite.structure import AtomArray
|
|
8
|
-
from rfd3.constants import (
|
|
9
|
-
TIP_BY_RESTYPE,
|
|
10
|
-
)
|
|
11
8
|
|
|
12
9
|
from foundry.common import exists
|
|
10
|
+
from foundry.constants import (
|
|
11
|
+
TIP_BY_RESTYPE,
|
|
12
|
+
)
|
|
13
13
|
from foundry.utils.ddp import RankedLogger
|
|
14
14
|
|
|
15
15
|
global_logger = RankedLogger(__name__, rank_zero_only=False)
|
foundry/utils/ddp.py
CHANGED
|
@@ -20,7 +20,7 @@ def is_rank_zero() -> bool:
|
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
def set_accelerator_based_on_availability(cfg: dict | DictConfig):
|
|
23
|
-
"""Set training accelerator
|
|
23
|
+
"""Set training accelerator based on available hardware.
|
|
24
24
|
|
|
25
25
|
Args:
|
|
26
26
|
cfg: Hydra object with trainer settings "accelerator", "devices_per_node", and "num_nodes".
|
|
@@ -28,22 +28,25 @@ def set_accelerator_based_on_availability(cfg: dict | DictConfig):
|
|
|
28
28
|
Returns:
|
|
29
29
|
None; modifies the input `cfg` object in place.
|
|
30
30
|
"""
|
|
31
|
-
|
|
31
|
+
assert "trainer" in cfg, "Configuration object must have a 'trainer' key."
|
|
32
|
+
for key in ["accelerator", "devices_per_node", "num_nodes"]:
|
|
33
|
+
assert (
|
|
34
|
+
key in cfg.trainer
|
|
35
|
+
), f"Configuration object must have a 'trainer.{key}' key."
|
|
36
|
+
|
|
37
|
+
if torch.cuda.is_available():
|
|
38
|
+
cfg.trainer.accelerator = "gpu"
|
|
39
|
+
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
|
40
|
+
logger.info("Intel XPU detected - using XPU accelerator")
|
|
41
|
+
cfg.trainer.accelerator = "xpu"
|
|
42
|
+
else:
|
|
32
43
|
logger.error(
|
|
33
|
-
"No GPUs available - Setting accelerator to 'cpu'. Are you sure you are using the correct configs?"
|
|
44
|
+
"No GPUs/XPUs available - Setting accelerator to 'cpu'. Are you sure you are using the correct configs?"
|
|
34
45
|
)
|
|
35
|
-
assert "trainer" in cfg, "Configuration object must have a 'trainer' key."
|
|
36
|
-
for key in ["accelerator", "devices_per_node", "num_nodes"]:
|
|
37
|
-
assert (
|
|
38
|
-
key in cfg.trainer
|
|
39
|
-
), f"Configuration object must have a 'trainer.{key}' key."
|
|
40
|
-
|
|
41
|
-
# Override accelerator settings
|
|
42
46
|
cfg.trainer.accelerator = "cpu"
|
|
43
47
|
cfg.trainer.devices_per_node = 1
|
|
44
48
|
cfg.trainer.num_nodes = 1
|
|
45
|
-
|
|
46
|
-
cfg.trainer.accelerator = "gpu"
|
|
49
|
+
|
|
47
50
|
return cfg
|
|
48
51
|
|
|
49
52
|
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""XPU utilities for Intel GPU support.
|
|
2
|
+
|
|
3
|
+
XPU support in PyTorch is now native (torch.xpu.is_available()), but Lightning Fabric
|
|
4
|
+
requires custom Accelerator, Strategy, and Precision plugins for proper XPU handling.
|
|
5
|
+
|
|
6
|
+
These components are used directly (not via registry) when XPU is detected:
|
|
7
|
+
- XPUAccelerator: Custom accelerator for XPU devices
|
|
8
|
+
- SingleXPUStrategy: Strategy for single-device XPU training/inference
|
|
9
|
+
- XPUMixedPrecision: Precision plugin with proper XPU autocast support
|
|
10
|
+
|
|
11
|
+
Usage:
|
|
12
|
+
from foundry.utils.xpu import XPUAccelerator, SingleXPUStrategy, XPUMixedPrecision
|
|
13
|
+
|
|
14
|
+
# Check availability
|
|
15
|
+
if XPUAccelerator.is_available():
|
|
16
|
+
strategy = SingleXPUStrategy(precision_plugin=XPUMixedPrecision("bf16-mixed"))
|
|
17
|
+
|
|
18
|
+
Note:
|
|
19
|
+
The FabricTrainer automatically uses these components when XPU is detected.
|
|
20
|
+
You typically don't need to use them directly unless customizing behavior.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from .single_xpu_strategy import SingleXPUStrategy
|
|
24
|
+
from .xpu_accelerator import XPUAccelerator
|
|
25
|
+
from .xpu_precision import XPUMixedPrecision
|
|
26
|
+
|
|
27
|
+
__all__ = ["SingleXPUStrategy", "XPUAccelerator", "XPUMixedPrecision"]
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
"""Lightning Fabric strategy for single XPU device.
|
|
2
|
+
|
|
3
|
+
https://docs.pytorch.org/docs/stable/notes/get_start_xpu.html
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from lightning.fabric.plugins import CheckpointIO
|
|
8
|
+
from lightning.fabric.plugins.precision import Precision
|
|
9
|
+
from lightning.fabric.strategies import SingleDeviceStrategy
|
|
10
|
+
from lightning.fabric.utilities.types import _DEVICE
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class SingleXPUStrategy(SingleDeviceStrategy):
|
|
14
|
+
"""Strategy for training/inference on a single Intel XPU device.
|
|
15
|
+
|
|
16
|
+
This strategy extends SingleDeviceStrategy to properly handle XPU devices.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
strategy_name = "xpu_single"
|
|
20
|
+
|
|
21
|
+
def __init__(
|
|
22
|
+
self,
|
|
23
|
+
device: _DEVICE = "xpu:0",
|
|
24
|
+
checkpoint_io: CheckpointIO | None = None,
|
|
25
|
+
precision_plugin: Precision | None = None,
|
|
26
|
+
) -> None:
|
|
27
|
+
"""Initialize the single XPU strategy.
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
device: The XPU device to use. Defaults to "xpu:0".
|
|
31
|
+
checkpoint_io: Plugin for checkpoint I/O.
|
|
32
|
+
precision_plugin: Plugin for precision handling (set via _precision property).
|
|
33
|
+
|
|
34
|
+
Raises:
|
|
35
|
+
RuntimeError: If XPU devices are not available.
|
|
36
|
+
"""
|
|
37
|
+
if not (hasattr(torch, "xpu") and torch.xpu.is_available()):
|
|
38
|
+
msg = "`SingleXPUStrategy` requires XPU devices to run"
|
|
39
|
+
raise RuntimeError(msg)
|
|
40
|
+
|
|
41
|
+
super().__init__(
|
|
42
|
+
device=device,
|
|
43
|
+
checkpoint_io=checkpoint_io,
|
|
44
|
+
)
|
|
45
|
+
# Precision is handled via the _precision property in newer Lightning versions
|
|
46
|
+
if precision_plugin is not None:
|
|
47
|
+
self._precision = precision_plugin
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
"""XPU Accelerator for Intel XPU devices."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from lightning.fabric.accelerators import Accelerator
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class XPUAccelerator(Accelerator):
|
|
10
|
+
"""Accelerator for Intel XPU devices.
|
|
11
|
+
|
|
12
|
+
This accelerator enables training and inference on Intel GPUs using
|
|
13
|
+
PyTorch's native XPU support (torch.xpu).
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
@property
|
|
17
|
+
def name(self) -> str:
|
|
18
|
+
"""Return the name of this accelerator."""
|
|
19
|
+
return "xpu"
|
|
20
|
+
|
|
21
|
+
@staticmethod
|
|
22
|
+
def setup_device(device: torch.device) -> None:
|
|
23
|
+
"""Set up the specified XPU device.
|
|
24
|
+
|
|
25
|
+
Args:
|
|
26
|
+
device: The torch device to set up.
|
|
27
|
+
|
|
28
|
+
Raises:
|
|
29
|
+
RuntimeError: If device is not an XPU device.
|
|
30
|
+
"""
|
|
31
|
+
if device.type != "xpu":
|
|
32
|
+
msg = f"Device should be xpu, got {device} instead"
|
|
33
|
+
raise RuntimeError(msg)
|
|
34
|
+
|
|
35
|
+
torch.xpu.set_device(device)
|
|
36
|
+
|
|
37
|
+
@staticmethod
|
|
38
|
+
def parse_devices(devices: str | list | torch.device) -> list:
|
|
39
|
+
"""Parse devices specification for XPU training.
|
|
40
|
+
|
|
41
|
+
Args:
|
|
42
|
+
devices: Device specification (int, list of ints, or string).
|
|
43
|
+
|
|
44
|
+
Returns:
|
|
45
|
+
List of device indices.
|
|
46
|
+
"""
|
|
47
|
+
if isinstance(devices, list):
|
|
48
|
+
return devices
|
|
49
|
+
return [devices]
|
|
50
|
+
|
|
51
|
+
@staticmethod
|
|
52
|
+
def get_parallel_devices(devices: list) -> list[torch.device]:
|
|
53
|
+
"""Generate a list of parallel XPU devices.
|
|
54
|
+
|
|
55
|
+
Args:
|
|
56
|
+
devices: List of device indices.
|
|
57
|
+
|
|
58
|
+
Returns:
|
|
59
|
+
List of torch.device objects for XPU.
|
|
60
|
+
"""
|
|
61
|
+
return [torch.device("xpu", idx) for idx in devices]
|
|
62
|
+
|
|
63
|
+
@staticmethod
|
|
64
|
+
def auto_device_count() -> int:
|
|
65
|
+
"""Return the number of available XPU devices."""
|
|
66
|
+
return torch.xpu.device_count()
|
|
67
|
+
|
|
68
|
+
@staticmethod
|
|
69
|
+
def is_available() -> bool:
|
|
70
|
+
"""Check if XPU is available."""
|
|
71
|
+
return hasattr(torch, "xpu") and torch.xpu.is_available()
|
|
72
|
+
|
|
73
|
+
@staticmethod
|
|
74
|
+
def get_device_stats(device: str | torch.device) -> dict[str, Any]:
|
|
75
|
+
"""Return XPU device statistics.
|
|
76
|
+
|
|
77
|
+
Currently returns an empty dict as XPU stats API may vary.
|
|
78
|
+
|
|
79
|
+
Args:
|
|
80
|
+
device: The device to get stats for.
|
|
81
|
+
|
|
82
|
+
Returns:
|
|
83
|
+
Dictionary of device statistics.
|
|
84
|
+
"""
|
|
85
|
+
del device # Unused
|
|
86
|
+
return {}
|
|
87
|
+
|
|
88
|
+
def teardown(self) -> None:
|
|
89
|
+
"""Clean up XPU accelerator resources."""
|
|
90
|
+
# Empty implementation required by base class
|
|
91
|
+
pass
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
"""XPU Precision Plugin for Lightning Fabric."""
|
|
2
|
+
|
|
3
|
+
from contextlib import contextmanager
|
|
4
|
+
from typing import Any, Generator, Literal
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from lightning.fabric.plugins.precision import MixedPrecision
|
|
8
|
+
from torch import Tensor
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class XPUMixedPrecision(MixedPrecision):
|
|
12
|
+
"""Mixed precision plugin for Intel XPU devices.
|
|
13
|
+
|
|
14
|
+
This overrides the default MixedPrecision plugin to use 'xpu' as the
|
|
15
|
+
device type for torch.autocast instead of 'cuda'.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
def __init__(
|
|
19
|
+
self,
|
|
20
|
+
precision: Literal["16-mixed", "bf16-mixed"] = "bf16-mixed",
|
|
21
|
+
) -> None:
|
|
22
|
+
"""Initialize XPU mixed precision.
|
|
23
|
+
|
|
24
|
+
Args:
|
|
25
|
+
precision: The precision mode. "16-mixed" uses float16,
|
|
26
|
+
"bf16-mixed" uses bfloat16. Defaults to "bf16-mixed".
|
|
27
|
+
|
|
28
|
+
Raises:
|
|
29
|
+
ValueError: If precision is not "16-mixed" or "bf16-mixed".
|
|
30
|
+
"""
|
|
31
|
+
# Determine dtype from precision string
|
|
32
|
+
if precision == "16-mixed":
|
|
33
|
+
dtype = torch.float16
|
|
34
|
+
elif precision == "bf16-mixed":
|
|
35
|
+
dtype = torch.bfloat16
|
|
36
|
+
else:
|
|
37
|
+
msg = f"Invalid precision: {precision}. Must be '16-mixed' or 'bf16-mixed'"
|
|
38
|
+
raise ValueError(msg)
|
|
39
|
+
|
|
40
|
+
# Initialize with xpu device type
|
|
41
|
+
super().__init__(precision=precision, device="xpu")
|
|
42
|
+
self._desired_input_dtype = dtype
|
|
43
|
+
|
|
44
|
+
@contextmanager
|
|
45
|
+
def forward_context(self) -> Generator[None, None, None]:
|
|
46
|
+
"""Context manager for forward pass with XPU autocast."""
|
|
47
|
+
with torch.autocast(device_type="xpu", dtype=self._desired_input_dtype):
|
|
48
|
+
yield
|
|
49
|
+
|
|
50
|
+
def convert_input(self, data: Any) -> Any:
|
|
51
|
+
"""Convert input data to the appropriate precision.
|
|
52
|
+
|
|
53
|
+
Args:
|
|
54
|
+
data: Input data to convert.
|
|
55
|
+
|
|
56
|
+
Returns:
|
|
57
|
+
Converted data.
|
|
58
|
+
"""
|
|
59
|
+
return self._convert_fp_tensor(data)
|
|
60
|
+
|
|
61
|
+
def _convert_fp_tensor(self, data: Any) -> Any:
|
|
62
|
+
"""Convert floating point tensors to the desired dtype.
|
|
63
|
+
|
|
64
|
+
Args:
|
|
65
|
+
data: Data to convert.
|
|
66
|
+
|
|
67
|
+
Returns:
|
|
68
|
+
Converted data if it's a floating point tensor, otherwise unchanged.
|
|
69
|
+
"""
|
|
70
|
+
if isinstance(data, Tensor) and data.is_floating_point():
|
|
71
|
+
return data.to(self._desired_input_dtype)
|
|
72
|
+
return data
|
foundry/version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.1.9'
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 1, 9)
|
|
31
|
+
__version__ = version = '0.1.11'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 1, 11)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
mpnn/inference_engines/mpnn.py
CHANGED
|
@@ -67,11 +67,15 @@ class MPNNInferenceEngine:
|
|
|
67
67
|
else checkpoint_path
|
|
68
68
|
)
|
|
69
69
|
|
|
70
|
-
# Determine the device.
|
|
70
|
+
# Determine the device (supports XPU, CUDA, and CPU).
|
|
71
71
|
if device is not None:
|
|
72
72
|
self.device = torch.device(device)
|
|
73
|
+
elif torch.cuda.is_available():
|
|
74
|
+
self.device = torch.device("cuda")
|
|
75
|
+
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
|
76
|
+
self.device = torch.device("xpu")
|
|
73
77
|
else:
|
|
74
|
-
self.device = torch.device("
|
|
78
|
+
self.device = torch.device("cpu")
|
|
75
79
|
|
|
76
80
|
# Set up allowed model types.
|
|
77
81
|
self.allowed_model_types = {"protein_mpnn", "ligand_mpnn"}
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: rc-foundry
|
|
3
|
-
Version: 0.1.9
|
|
3
|
+
Version: 0.1.11
|
|
4
4
|
Summary: Shared utilities and training infrastructure for biomolecular structure prediction models.
|
|
5
5
|
Author-email: Institute for Protein Design <contact@ipd.uw.edu>
|
|
6
6
|
License: BSD 3-Clause License
|
|
@@ -108,6 +108,16 @@ All models within Foundry rely on [AtomWorks](https://github.com/RosettaCommons/
|
|
|
108
108
|
pip install "rc-foundry[all]"
|
|
109
109
|
```
|
|
110
110
|
|
|
111
|
+
**Intel XPU Installation**
|
|
112
|
+
|
|
113
|
+
For Intel XPU devices, install PyTorch with XPU support first, then install Foundry.
|
|
114
|
+
```bash
|
|
115
|
+
pip install torch --index-url https://download.pytorch.org/whl/xpu
|
|
116
|
+
pip install "rc-foundry[all]"
|
|
117
|
+
```
|
|
118
|
+
> [!NOTE]
|
|
119
|
+
> Use `pip` (not `uv`) for XPU installs, since `uv` re-resolves dependencies and may replace your XPU build of torch with the standard PyPI version.
|
|
120
|
+
|
|
111
121
|
**Downloading weights** Models can be downloaded to a target folder with:
|
|
112
122
|
```
|
|
113
123
|
foundry install base-models --checkpoint-dir <path/to/ckpt/dir>
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
foundry/__init__.py,sha256=H8S1nl5v6YeW8ggn1jKy4GdtH7c-FGS-j7CqUCAEnAU,1926
|
|
2
2
|
foundry/common.py,sha256=Aur8mH-CNmcUqSsw7VgaCQSW5sH1Bqf8Da91jzxPV1Y,3035
|
|
3
3
|
foundry/constants.py,sha256=0n1wBKCvNuw3QaQehSbmsHYkIdaGn3tLeRFItBrdeHY,913
|
|
4
|
-
foundry/version.py,sha256=
|
|
4
|
+
foundry/version.py,sha256=0-Ruc52ECccw_8Ef0d7jMkzrb8fkobUkZLqGGvcm1ik,706
|
|
5
5
|
foundry/callbacks/__init__.py,sha256=VsRT1e4sqlJHPcTCsfupMEx82Iz-LoOAGPpwvf_OJeE,126
|
|
6
6
|
foundry/callbacks/callback.py,sha256=xZBo_suP4bLrP6gl5uJPbaXm00DXigePa6dMeDxucgg,3890
|
|
7
7
|
foundry/callbacks/health_logging.py,sha256=tEtkByOlaAA7nnelxb7PbM9_dcIgOsdbxCdQY3K5pMc,16664
|
|
@@ -18,14 +18,14 @@ foundry/model/layers/blocks.py,sha256=ihbbP_1fOlrkrcrQSk9thCrNWjK8mtxD3WxcBng9Ht
|
|
|
18
18
|
foundry/testing/__init__.py,sha256=BnrU7fZ4l0Dm1vrGcNPQYTAw83PW4DGYz7TGhGqgrfQ,223
|
|
19
19
|
foundry/testing/fixtures.py,sha256=j27a8CAonygjlWsUjZ-95M5MF4Rjp9nw7JskqiZlweI,486
|
|
20
20
|
foundry/testing/pytest_hooks.py,sha256=5Ebw1GXYO2XqS9Jvpzty7g3gCXIdXu16jqg53XcuUx4,450
|
|
21
|
-
foundry/trainers/fabric.py,sha256=
|
|
21
|
+
foundry/trainers/fabric.py,sha256=qZVO3C48-Won57TbxMO9qb8VOh8JPNG8kqXSW6nh-YU,41057
|
|
22
22
|
foundry/training/EMA.py,sha256=3OWA9Pz7XuDr-SRxbz24tZf55DmhSa2fKy9r5v2IXqA,2651
|
|
23
23
|
foundry/training/checkpoint.py,sha256=mUiObg-qcF3tvMfVu77sD9m3yVRp71czv07ccliU7qQ,1791
|
|
24
24
|
foundry/training/schedulers.py,sha256=StmXegPfIdLAv31FreCTrDh9dsOvNUfzG4YGa61Y4oE,3647
|
|
25
25
|
foundry/utils/alignment.py,sha256=2anqy0mn9zeFEiVWS_EG7zHiyPk1C_gbUu-SRvQ5mAM,2502
|
|
26
|
-
foundry/utils/components.py,sha256=
|
|
26
|
+
foundry/utils/components.py,sha256=gCRjmZXqltOqU6tV_jISX6T-BX2z9zzr-H0Xgo7oojw,15203
|
|
27
27
|
foundry/utils/datasets.py,sha256=pLBxVezm-TSrYuC5gFnJZdGnNWV7aPH2QiWIVE2hkdQ,16629
|
|
28
|
-
foundry/utils/ddp.py,sha256=
|
|
28
|
+
foundry/utils/ddp.py,sha256=gJY3w0pwelr6UoxOkHsttgcZsTN97kfHFtfGN2VKLM8,4075
|
|
29
29
|
foundry/utils/instantiators.py,sha256=oGCp6hrmY-QPPPEjxKxe5uVFL125fH1RaLxjMKWCD_8,2169
|
|
30
30
|
foundry/utils/logging.py,sha256=ywV75MBlQsothV0IBvqoAQTNg6pjo2-Cib7Uo080nzQ,9312
|
|
31
31
|
foundry/utils/rigid.py,sha256=_Z1pmitb6xgxyguLj_TukKscUBJjQsU4bsBD24GVS84,44444
|
|
@@ -33,13 +33,17 @@ foundry/utils/rotation_augmentation.py,sha256=7q1WEX2iJ0i7-2aV-M97nEaEdpqexDTaZn
|
|
|
33
33
|
foundry/utils/squashfs.py,sha256=QlcwuJyVe-QVfIOS7o1QfLhaCQPNzzox7ln4n8dcYEg,5234
|
|
34
34
|
foundry/utils/torch.py,sha256=OLsqoxw4CTXbGzWUHernLUT7uQjLu0tVPtD8h8747DI,11211
|
|
35
35
|
foundry/utils/weights.py,sha256=btz4S02xff2vgiq4xMfiXuhK1ERafqQPtmimo1DmoWY,10381
|
|
36
|
+
foundry/utils/xpu/__init__.py,sha256=i8_uiVj5b0-_Z4bX4RZmaE01e0NHbo-eTNXMYF1M0mM,1114
|
|
37
|
+
foundry/utils/xpu/single_xpu_strategy.py,sha256=28iEo-6e0gbpApmFxZdfhLO-RZb2QlOf9XKlRdTeXXc,1592
|
|
38
|
+
foundry/utils/xpu/xpu_accelerator.py,sha256=LPVXvgBKY3OE1ghyhZvotP6aWNVZXMup1Swi1v6U5GY,2467
|
|
39
|
+
foundry/utils/xpu/xpu_precision.py,sha256=zuFfLkBIDLbF-omJ6mTiGoVU0UgXg84A2j3ZjYyoAOc,2285
|
|
36
40
|
foundry_cli/__init__.py,sha256=0BxY2RUKJLaMXUGgypPCwlTskTEFdVnkhTR4C4ft2Kw,52
|
|
37
41
|
foundry_cli/download_checkpoints.py,sha256=CxU9dKBa1vAkVd450tfH5aZAlQIUTrHsDGTbmxzd_JQ,8922
|
|
38
42
|
mpnn/__init__.py,sha256=hgQcXFaCbAxFrhydVAy0xj8yC7UJF-GCCFhqD0sZ7I4,57
|
|
39
43
|
mpnn/inference.py,sha256=wPtGR325eVRVeesXoWtBK6b_-VcU8BZae5IfQN3-mvA,1669
|
|
40
44
|
mpnn/train.py,sha256=9eQGBd3rdNF5Zr2w8oUgETbqxBavNBajtA6Vbc5zESE,10239
|
|
41
45
|
mpnn/collate/feature_collator.py,sha256=LpzAFWo1VMa06dJLmfUWZsKe4xvLZjHbx4RICg2lgbQ,10510
|
|
42
|
-
mpnn/inference_engines/mpnn.py,sha256=
|
|
46
|
+
mpnn/inference_engines/mpnn.py,sha256=Wg7FMI5KKdq14qQzKs6TkX-rvUi1fB6lSnJMdwWuKKw,21912
|
|
43
47
|
mpnn/loss/nll_loss.py,sha256=KmdNe-BCzGYtijjappzBArQcT1gHVlJnKdY1PYQ4mhU,5947
|
|
44
48
|
mpnn/metrics/nll.py,sha256=T6oMeUOEeHZzOMTH8NHFtsH9vUwLAsHQDPszzj4YKXI,15299
|
|
45
49
|
mpnn/metrics/sequence_recovery.py,sha256=YDw_LmH-a3ajBYWK0mucJEQvw0_VEyxvrBN7da4vX8Q,19034
|
|
@@ -64,7 +68,7 @@ rf3/_version.py,sha256=fCfpbI5aeA6yHqjo3tK78-l2dPGxhp-AyKSoCXp34Nc,739
|
|
|
64
68
|
rf3/alignment.py,sha256=BvvwMqQGCVxV20xIsTighD1kXMadXXL2SkckLjTerx0,2102
|
|
65
69
|
rf3/chemical.py,sha256=VECnRPgVm-icXbZeUG4svcENzdUiIupP6dhka_8zCrg,26572
|
|
66
70
|
rf3/cli.py,sha256=jxjq8u77J8itIK4aNTfIpnMsNgg8brW1A3NfVjlgE0s,2743
|
|
67
|
-
rf3/inference.py,sha256=
|
|
71
|
+
rf3/inference.py,sha256=R_bzJe2K6Ek9uhrkZ97IiVaGKhBaObWc5SWh6MVDLLg,2530
|
|
68
72
|
rf3/kinematics.py,sha256=V3yjalPupu1X2FEp7l3XZR-qzLKrhWLZyECk6RgIkcs,10901
|
|
69
73
|
rf3/scoring.py,sha256=dTllswE-6Fgli2eLiNzLFc2Rhz4ouDT4WL-sVbvLTGU,41541
|
|
70
74
|
rf3/train.py,sha256=V4nqCC_1JKLI3WQ-nErNa8sqFpvb1mFhXSe6ZPpEheM,7945
|
|
@@ -119,18 +123,18 @@ rfd3/__init__.py,sha256=2Wto2IsUIj2lGag9m_gqgdCwBNl5p21-Xnr7W_RpU3c,348
|
|
|
119
123
|
rfd3/callbacks.py,sha256=Zjt8RiaYWquoKOwRmC_wCUbRbov-V4zd2_73zjhgDHE,2783
|
|
120
124
|
rfd3/cli.py,sha256=ka3K5H117fzDYIDXFpOpJV21w_XBrHYJZdFE0thsGBI,1644
|
|
121
125
|
rfd3/constants.py,sha256=wLvDzrThpOrK8T3wGFNQeGrhAXOJQze8l3v_7pjIdMM,13141
|
|
122
|
-
rfd3/engine.py,sha256=
|
|
126
|
+
rfd3/engine.py,sha256=cj7BIP35HMycWV3rBp9n9_ik5t7eNMK1wkVfJ58xFmI,21903
|
|
123
127
|
rfd3/run_inference.py,sha256=HfRMQ30_SAHfc-VFzBV52F-aLaNdG6PW8VkdMyB__wE,1264
|
|
124
128
|
rfd3/train.py,sha256=rHswffIUhOae3_iYyvAiQ3jALoFuzrcRUgMlbJLinlI,7947
|
|
125
129
|
rfd3/inference/datasets.py,sha256=9VLbzl7dpG8mk_pjs0R5C2wFYUoRIgXXoZcS9IohSy0,6510
|
|
126
|
-
rfd3/inference/input_parsing.py,sha256=
|
|
130
|
+
rfd3/inference/input_parsing.py,sha256=qUMsMyeh-kW757xX10zi07kKg4RS0wM8jSoIxu7UsBI,46794
|
|
127
131
|
rfd3/inference/legacy_input_parsing.py,sha256=G2XxkrjdIpL6i1YY7xEmkFitVv__Pc45ow6IKKPHw64,28855
|
|
128
132
|
rfd3/inference/parsing.py,sha256=ktAMUuZE3Pe4bKAjjV3zjqcEDmGlMZ-cotIUhJsEQQA,5402
|
|
129
|
-
rfd3/inference/symmetry/atom_array.py,sha256=
|
|
133
|
+
rfd3/inference/symmetry/atom_array.py,sha256=HfFagFUB5yB-Y4IfUM5nuVGWHC5AEkyHqt0JcIqTQ_E,10922
|
|
130
134
|
rfd3/inference/symmetry/checks.py,sha256=ZWpC1JjrAjXY__xDt8EFYb5WUdSF0kobZyxxmacFU7U,10076
|
|
131
135
|
rfd3/inference/symmetry/contigs.py,sha256=6OvbZ2dJg-a0mvvKAC0VkzUH5HpUDxOJvkByIst_roU,2127
|
|
132
|
-
rfd3/inference/symmetry/frames.py,sha256=
|
|
133
|
-
rfd3/inference/symmetry/symmetry_utils.py,sha256=
|
|
136
|
+
rfd3/inference/symmetry/frames.py,sha256=aEwkmlUsYexERX9hu09JMhisC8QTpHPVhfITbL80-EE,10819
|
|
137
|
+
rfd3/inference/symmetry/symmetry_utils.py,sha256=ibYeOT-4z_Pxr5i8cKbNAAqjRf4YfKOQCpSkuZQqG0I,14554
|
|
134
138
|
rfd3/metrics/design_metrics.py,sha256=O1RqZdjQPNlAWYRg6UJTERYg_gUI1_hVleKsm9xbWBY,16836
|
|
135
139
|
rfd3/metrics/hbonds_hbplus_metrics.py,sha256=Sewy9KzmrA1OnfkasN-fmWrQ9IRx9G7Yyhe2ua0mk28,11518
|
|
136
140
|
rfd3/metrics/hbonds_metrics.py,sha256=SIR4BnDhYdpVSqwXXRYpQ_tB-M0_fVyugGl08WivCmE,15257
|
|
@@ -140,7 +144,7 @@ rfd3/metrics/sidechain_metrics.py,sha256=EGZuFuWQ0cCe83EVPAf4eysN8vP9ifNjfnmE0o5
|
|
|
140
144
|
rfd3/model/RFD3.py,sha256=95aKzye-XzuDyLGgost-Wsfu8eT635zHIRky-pNoHSA,3569
|
|
141
145
|
rfd3/model/RFD3_diffusion_module.py,sha256=BPjKGyQpbnqdzii3gXMKLhhijNqV8Xh4bSosmfDBt8w,12094
|
|
142
146
|
rfd3/model/cfg_utils.py,sha256=XPBLyoB_bQRLmdrJ1Z0hCjcVvgUMGIPuw4rxTlHjB_s,2575
|
|
143
|
-
rfd3/model/inference_sampler.py,sha256=
|
|
147
|
+
rfd3/model/inference_sampler.py,sha256=zyA4pU2MIr2PPTuQm9k70UXUm3PRp644va6XgmB3aJ8,25210
|
|
144
148
|
rfd3/model/layers/attention.py,sha256=XuNA7WyFlRfLnAgky1PtGvXFCnDGv7GeEcXz8hodTBo,19472
|
|
145
149
|
rfd3/model/layers/block_utils.py,sha256=oN0aD-vZiH4JbIFs2CzDmb2B74GNPKzdFurmGd-dirE,21244
|
|
146
150
|
rfd3/model/layers/blocks.py,sha256=MOjJ53THxM2MMM27Ap7xiIXRCdI_SHzqKzLLQVX6FEc,24888
|
|
@@ -161,21 +165,21 @@ rfd3/transforms/conditioning_utils.py,sha256=9Pn9AFbih2FCzp5OOM9y7z6KH7HPxVibxTr
|
|
|
161
165
|
rfd3/transforms/design_transforms.py,sha256=ePvnLsuKUOsE4LLcmF0bbkx1vf2AiD-35rzF4zUEcEE,30944
|
|
162
166
|
rfd3/transforms/dna_crop.py,sha256=JeOsG0tXghJvgzEimfzBvlFN_lVd9TrvjnC929Abz5A,18214
|
|
163
167
|
rfd3/transforms/hbonds.py,sha256=ijfJapFlhsh3JktpDoT3VFqKTTg6ynrqMlD7dU2xFsA,16415
|
|
164
|
-
rfd3/transforms/hbonds_hbplus.py,sha256=
|
|
168
|
+
rfd3/transforms/hbonds_hbplus.py,sha256=i-57ZGZjofxWvygDidwWeKrkrALz5LcwJCCcw6rCOfQ,8654
|
|
165
169
|
rfd3/transforms/ncaa_transforms.py,sha256=Lz4L8OGuOOG53sKJHcLSdV7WPQ3YzOzwd5tJG4CHqP0,4983
|
|
166
170
|
rfd3/transforms/pipelines.py,sha256=FGH-XH3taTWQ6k1zpDO_d-097EQdXmL6uqXZXw4HIMs,22086
|
|
167
171
|
rfd3/transforms/ppi_transforms.py,sha256=7rXyf-tn2TLz6ybYR_YVDtSDG7hOgqhYY4shNviA_Sw,23493
|
|
168
172
|
rfd3/transforms/rasa.py,sha256=a4IPFvVMMxldoGLyJQiSlGg7IyUkcBASbRZLWmguAKk,4156
|
|
169
|
-
rfd3/transforms/symmetry.py,sha256=
|
|
173
|
+
rfd3/transforms/symmetry.py,sha256=GSnMF7oAnUxPozfafsRuHEv0yKXW0BpLTI6wsKGZrbc,2658
|
|
170
174
|
rfd3/transforms/training_conditions.py,sha256=UXiUPjDwrNKM95tRe0eXrMeRN8XlTPc_MXUvo6UpePo,19510
|
|
171
175
|
rfd3/transforms/util_transforms.py,sha256=2AcLkzx-73ZFgcWD1cIHv7NyniRPI4_zThHK8azyQaY,18119
|
|
172
176
|
rfd3/transforms/virtual_atoms.py,sha256=UpmxzPPd5FaJigcRoxgLSHHrLLOqsCvZ5PPZfQSGqII,12547
|
|
173
|
-
rfd3/utils/inference.py,sha256=
|
|
177
|
+
rfd3/utils/inference.py,sha256=A0yGOsO3cXiTqwLYP0uTGTB9MR6vsmnYMEiroYgvnDE,26578
|
|
174
178
|
rfd3/utils/io.py,sha256=wbdjUTQkDc3RCSM7gdogA-XOKR68HeQ-cfvyN4pP90w,9849
|
|
175
179
|
rfd3/utils/vizualize.py,sha256=HPlczrA3zkOuxV5X05eOvy_Oga9e3cPnFUXOEP4RR_g,11046
|
|
176
|
-
rf3/configs/inference.yaml,sha256=
|
|
177
|
-
rf3/configs/train.yaml,sha256=
|
|
178
|
-
rf3/configs/validate.yaml,sha256=
|
|
180
|
+
rf3/configs/inference.yaml,sha256=uvg-ppF7-ANMeecb60Vodx6GNlEx0W3HPiGgiajUZ5Q,375
|
|
181
|
+
rf3/configs/train.yaml,sha256=aQWrr2RXsFo97CYSFOyTMu7C8F6Nz7ymZU3BuX1gkec,1639
|
|
182
|
+
rf3/configs/validate.yaml,sha256=6FUaD-kLbu0jfUSgFIXfAuWhHtVH-4w_swGJSf2AQKc,1647
|
|
179
183
|
rf3/configs/callbacks/default.yaml,sha256=MkxOj7dMXh4jJRIE62gLjoOYecGuZLWiJrr780_nubA,89
|
|
180
184
|
rf3/configs/callbacks/dump_validation_structures.yaml,sha256=EYEibR25v7KZJtadvCFLFMEPTf0FvKFNW2ocx4wm57A,259
|
|
181
185
|
rf3/configs/callbacks/metrics_logging.yaml,sha256=MNm4OpvOHxvDJofVUA27NVaiDkp1NzqOYCzl6l_7ceo,432
|
|
@@ -224,6 +228,7 @@ rf3/configs/trainer/cpu.yaml,sha256=J1WbK2SQ_VMoEOOqv0XTg0FYKPaec1TNySvWklMaE4k,
|
|
|
224
228
|
rf3/configs/trainer/ddp.yaml,sha256=uClrdTzEMNxgq4IQhMgm8okC16wUS2I5i3rKnm5SktU,65
|
|
225
229
|
rf3/configs/trainer/rf3.yaml,sha256=WBjnaYofmEV7OfJisvFGN6y_UJGNEskAdmbV9oO4ICI,500
|
|
226
230
|
rf3/configs/trainer/rf3_with_confidence.yaml,sha256=meDTw0S2nTcuAj5tefGwJteo2715x6kO143FMd3db14,346
|
|
231
|
+
rf3/configs/trainer/xpu.yaml,sha256=GlUmQcnWtZOcp7wHip5DikKWfOVZTwrTX_HU05A4kc8,73
|
|
227
232
|
rf3/configs/trainer/loss/structure_prediction.yaml,sha256=XPH2RcIo6m1BDrIoBvd2xBgJ0c4KEOnbWc5oEcBKVAM,164
|
|
228
233
|
rf3/configs/trainer/loss/structure_prediction_with_confidence.yaml,sha256=OZxtrQIqYGDFlhp3-1fFVZ2Ek6QRMP74ImsrNOJcelo,52
|
|
229
234
|
rf3/configs/trainer/loss/losses/confidence_loss.yaml,sha256=c9xI7Q32Ddl5R_GpoWq4RS7AxiJf4hsnGeMZxWs9mBU,501
|
|
@@ -231,10 +236,10 @@ rf3/configs/trainer/loss/losses/diffusion_loss.yaml,sha256=ktBy3bUjslklPuRhGWuHb
|
|
|
231
236
|
rf3/configs/trainer/loss/losses/distogram_loss.yaml,sha256=-dZWPeQ1mWwi9pUtB7xJW57qKgUNMQ4vao1PXo07RC0,56
|
|
232
237
|
rf3/configs/trainer/metrics/structure_prediction.yaml,sha256=xWp2DqoqlofbgRLMNi0LKuAaDSCbW1tAEuEk83fmc0w,439
|
|
233
238
|
rfd3/configs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
234
|
-
rfd3/configs/dev.yaml,sha256=
|
|
235
|
-
rfd3/configs/inference.yaml,sha256=
|
|
236
|
-
rfd3/configs/train.yaml,sha256=
|
|
237
|
-
rfd3/configs/validate.yaml,sha256=
|
|
239
|
+
rfd3/configs/dev.yaml,sha256=ulil1Am_romDNUOGeQY3HLDRyPMiVKHOZrh59H4uAMA,176
|
|
240
|
+
rfd3/configs/inference.yaml,sha256=zzrhijUPplk--wjKk5d_Ht3Nnv-qLav5DQcfHcmLSNY,144
|
|
241
|
+
rfd3/configs/train.yaml,sha256=BJ3owsuFDVWiVgZdmA0VaQkPtMS9ws8FEyvBZeI4YQ8,476
|
|
242
|
+
rfd3/configs/validate.yaml,sha256=E78w-K6bE0Kl27V-wz9BuR9QyZVXTupxjT0UpYIA7Ao,562
|
|
238
243
|
rfd3/configs/callbacks/design_callbacks.yaml,sha256=JWgE1-v_spzUy7JH3_6dHct_-oX4DevozRH-pM5Ds2k,196
|
|
239
244
|
rfd3/configs/callbacks/metrics_logging.yaml,sha256=pZPePSYGKEV560e3WatuLvJiHlz1CIGFOaOWoRmBh8g,694
|
|
240
245
|
rfd3/configs/callbacks/train_logging.yaml,sha256=Z25GVLTHo1HvQUjnBdayaozmNww8UhAh9DczStYaZig,1050
|
|
@@ -301,11 +306,12 @@ rfd3/configs/paths/data/default.yaml,sha256=jfs1dbbcOqHja4_6lXheyRg4t0YExqVn2w0r
|
|
|
301
306
|
rfd3/configs/trainer/cpu.yaml,sha256=rJf5LHf6x5fN5EKg8mFEn2SwfGW5dV1JdYaHqWMfpXc,74
|
|
302
307
|
rfd3/configs/trainer/ddp.yaml,sha256=uClrdTzEMNxgq4IQhMgm8okC16wUS2I5i3rKnm5SktU,65
|
|
303
308
|
rfd3/configs/trainer/rfd3_base.yaml,sha256=R3lZxdyjUirjlLU31qWlnZgHaz4GcWTGGIz4fUl7AyM,1016
|
|
309
|
+
rfd3/configs/trainer/xpu.yaml,sha256=GlUmQcnWtZOcp7wHip5DikKWfOVZTwrTX_HU05A4kc8,73
|
|
304
310
|
rfd3/configs/trainer/loss/losses/diffusion_loss.yaml,sha256=FE4FCEfurE0ekwZ4YfS6wCvPSNqxClwg_kc73cPql5Y,323
|
|
305
311
|
rfd3/configs/trainer/loss/losses/sequence_loss.yaml,sha256=kezbQcqwAZ0VKQPUBr2MsNr9DcDL3ENIP1i-j7h-6Co,64
|
|
306
312
|
rfd3/configs/trainer/metrics/design_metrics.yaml,sha256=xVDpClhHqSHvsf-8StL26z51Vn-iuWMDG9KMB-kqOI0,719
|
|
307
|
-
rc_foundry-0.1.
|
|
308
|
-
rc_foundry-0.1.9.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
309
|
-
rc_foundry-0.1.9.dist-info/entry_points.txt,sha256=BmiWCbWGtrd_lSOFMuCLBXyo84B7Nco-alj7hB0Yw9A,130
|
|
310
|
-
rc_foundry-0.1.9.dist-info/licenses/LICENSE.md,sha256=NKtPCJ7QMysFmzeDg56ZfUStvgzbq5sOvRQv7_ddZOs,1533
|
|
311
|
-
rc_foundry-0.1.9.dist-info/RECORD,,
|
|
313
|
+
rc_foundry-0.1.11.dist-info/METADATA,sha256=lK8MYr2naArWrcTOenHGERoQTliCfzPDgDave3bpWJI,11873
|
|
314
|
+
rc_foundry-0.1.11.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
315
|
+
rc_foundry-0.1.11.dist-info/entry_points.txt,sha256=BmiWCbWGtrd_lSOFMuCLBXyo84B7Nco-alj7hB0Yw9A,130
|
|
316
|
+
rc_foundry-0.1.11.dist-info/licenses/LICENSE.md,sha256=NKtPCJ7QMysFmzeDg56ZfUStvgzbq5sOvRQv7_ddZOs,1533
|
|
317
|
+
rc_foundry-0.1.11.dist-info/RECORD,,
|
rf3/configs/inference.yaml
CHANGED
|
@@ -2,6 +2,11 @@
|
|
|
2
2
|
# ^ The "package" determines where the content of the config is placed in the output config
|
|
3
3
|
# For more information about overriding configs, see: https://hydra.cc/docs/advanced/overriding_packages/#overriding-packages-using-the-defaults-list
|
|
4
4
|
|
|
5
|
+
hydra:
|
|
6
|
+
searchpath:
|
|
7
|
+
- pkg://rf3.configs
|
|
8
|
+
- pkg://configs
|
|
9
|
+
|
|
5
10
|
defaults:
|
|
6
11
|
- inference_engine: rf3
|
|
7
12
|
- _self_
|
rf3/configs/train.yaml
CHANGED
|
@@ -2,6 +2,11 @@
|
|
|
2
2
|
# ^ The "package" determines where the content of the config is placed in the output config
|
|
3
3
|
# For more information about overriding configs, see: https://hydra.cc/docs/advanced/overriding_packages/#overriding-packages-using-the-defaults-list
|
|
4
4
|
|
|
5
|
+
hydra:
|
|
6
|
+
searchpath:
|
|
7
|
+
- pkg://rf3.configs
|
|
8
|
+
- pkg://configs
|
|
9
|
+
|
|
5
10
|
# NOTE: order of defaults determines the order in which configs override each other (higher up items are overridden by lower items)
|
|
6
11
|
defaults:
|
|
7
12
|
- callbacks: default
|
rf3/configs/validate.yaml
CHANGED
|
@@ -2,6 +2,11 @@
|
|
|
2
2
|
# ^ The "package" determines where the content of the config is placed in the output config
|
|
3
3
|
# For more information about overriding configs, see: https://hydra.cc/docs/advanced/overriding_packages/#overriding-packages-using-the-defaults-list
|
|
4
4
|
|
|
5
|
+
hydra:
|
|
6
|
+
searchpath:
|
|
7
|
+
- pkg://rf3.configs
|
|
8
|
+
- pkg://configs
|
|
9
|
+
|
|
5
10
|
# NOTE: order of defaults determines the order in which configs override each other (higher up items are overridden by lower items)
|
|
6
11
|
defaults:
|
|
7
12
|
- callbacks: default
|
rf3/inference.py
CHANGED
|
@@ -12,7 +12,10 @@ from foundry.utils.logging import suppress_warnings
|
|
|
12
12
|
|
|
13
13
|
# Setup root dir and environment variables (more info: https://github.com/ashleve/rootutils)
|
|
14
14
|
# NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located)
|
|
15
|
-
|
|
15
|
+
try:
|
|
16
|
+
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
|
17
|
+
except Exception:
|
|
18
|
+
pass
|
|
16
19
|
|
|
17
20
|
load_dotenv(override=True)
|
|
18
21
|
|
rfd3/configs/dev.yaml
CHANGED
rfd3/configs/inference.yaml
CHANGED
rfd3/configs/train.yaml
CHANGED