rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,319 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from functools import cached_property
|
|
4
|
+
|
|
5
|
+
import hydra
|
|
6
|
+
from atomworks.common import exists
|
|
7
|
+
from atomworks.ml.utils import error, nested_dict
|
|
8
|
+
from beartype.typing import Any
|
|
9
|
+
from omegaconf import DictConfig
|
|
10
|
+
from toolz import keymap
|
|
11
|
+
|
|
12
|
+
from foundry.utils.ddp import RankedLogger
|
|
13
|
+
|
|
14
|
+
# Module-level logger; rank_zero_only=True suggests it emits only on the rank-0
# process in distributed runs — NOTE(review): confirm against RankedLogger.
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def instantiate_metric_manager(
    metrics_cfg: dict[str, Any] | DictConfig,
) -> "MetricManager":
    """Instantiate a MetricManager from a dictionary of metric configurations.

    Args:
        metrics_cfg: A dictionary where keys are metric names and values are Hydra
            configurations for the metrics.

    Returns:
        A MetricManager wrapping the instantiated metrics.

    Raises:
        TypeError: If an instantiated object is not a ``Metric`` instance.
    """
    metrics = {}
    for name, cfg in metrics_cfg.items():
        metric = hydra.utils.instantiate(cfg)
        # Fail fast on misconfigured entries rather than at compute time.
        if not isinstance(metric, Metric):
            raise TypeError(f"{name} must be a Metric instance")
        ranked_logger.info(f"Adding metric {name} to the validation metrics...")
        metrics[name] = metric
    return MetricManager(metrics)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class MetricInputError(Exception):
    """Raised when a metric cannot be computed from the inputs it was given."""
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class MetricManager:
|
|
40
|
+
"""Manages and computes a set of Metrics, where each Metric inherits from the Metric class.
|
|
41
|
+
|
|
42
|
+
For model validation, additional metrics can be added through the Hydra configuration; they
|
|
43
|
+
will be computed with the __call__ method automatically.
|
|
44
|
+
|
|
45
|
+
For example, during AF-3, Metrics will receive `network_input`, `network_output`, `extra_info`,
|
|
46
|
+
`ground_truth_atom_array_stack`, and `predicted_atom_array_stack` as input arguments.
|
|
47
|
+
|
|
48
|
+
Example:
|
|
49
|
+
>>> class ExampleMetric(Metric):
|
|
50
|
+
... @cached_property
|
|
51
|
+
... def kwargs_to_compute_args(self):
|
|
52
|
+
... return {"x": "x", "y": "y", "extra_info": "extra_info"}
|
|
53
|
+
...
|
|
54
|
+
... def compute(self, x, y, extra_info):
|
|
55
|
+
... return {"value": x + y}
|
|
56
|
+
>>> metric = ExampleMetric()
|
|
57
|
+
>>> manager = MetricManager({"my_metric": metric}, raise_errors=True)
|
|
58
|
+
>>> manager(x=1, y=2, extra_info={"example_id": "123"})
|
|
59
|
+
{'example_id': '123', 'my_metric.value': 3}
|
|
60
|
+
"""
|
|
61
|
+
|
|
62
|
+
def __init__(
|
|
63
|
+
self,
|
|
64
|
+
metrics: dict[str, "Metric"] = {},
|
|
65
|
+
*,
|
|
66
|
+
raise_errors: bool = True,
|
|
67
|
+
):
|
|
68
|
+
"""Initialize the MetricManager with a set of metrics.
|
|
69
|
+
|
|
70
|
+
Args:
|
|
71
|
+
raise_errors: Whether to raise errors when a metric fails to compute.
|
|
72
|
+
metrics: A dictionary where keys are metric names and values are Metric instances.
|
|
73
|
+
"""
|
|
74
|
+
self.raise_errors = raise_errors
|
|
75
|
+
self.metrics = {}
|
|
76
|
+
for name, metric in metrics.items():
|
|
77
|
+
assert isinstance(
|
|
78
|
+
metric, Metric
|
|
79
|
+
), f"{name} must be a Metric instance, not {type(metric)}"
|
|
80
|
+
self.metrics[name] = metric
|
|
81
|
+
|
|
82
|
+
@classmethod
|
|
83
|
+
def instantiate_from_hydra(
|
|
84
|
+
cls, metrics_cfg: dict[str, Any] | DictConfig
|
|
85
|
+
) -> "MetricManager":
|
|
86
|
+
"""Instantiate a MetricManager from a dictionary of metrics.
|
|
87
|
+
|
|
88
|
+
Args:
|
|
89
|
+
metrics_cfg: A dictionary where keys are metric names and values are Hydra configurations for the metrics.
|
|
90
|
+
"""
|
|
91
|
+
return instantiate_metric_manager(metrics_cfg)
|
|
92
|
+
|
|
93
|
+
@classmethod
|
|
94
|
+
def from_metrics(
|
|
95
|
+
cls,
|
|
96
|
+
metrics: dict[str, "Metric"] | list[tuple[str, "Metric"]],
|
|
97
|
+
raise_errors: bool = True,
|
|
98
|
+
) -> "MetricManager":
|
|
99
|
+
"""Create MetricManager from metric objects.
|
|
100
|
+
|
|
101
|
+
Args:
|
|
102
|
+
metrics: Either dict mapping names to Metric objects,
|
|
103
|
+
or list of (name, Metric) tuples.
|
|
104
|
+
raise_errors: Whether to raise errors on metric failures. Defaults to ``True``.
|
|
105
|
+
"""
|
|
106
|
+
if isinstance(metrics, list):
|
|
107
|
+
# Convert list of tuples to dict
|
|
108
|
+
metrics = dict(metrics)
|
|
109
|
+
|
|
110
|
+
# Validate all are Metric instances
|
|
111
|
+
for name, metric in metrics.items():
|
|
112
|
+
if not isinstance(metric, Metric):
|
|
113
|
+
raise TypeError(
|
|
114
|
+
f"Metric '{name}' must be a Metric instance, got {type(metric)}"
|
|
115
|
+
)
|
|
116
|
+
|
|
117
|
+
return cls(metrics, raise_errors=raise_errors)
|
|
118
|
+
|
|
119
|
+
def __repr__(self) -> str:
|
|
120
|
+
"""Return a string representation of the MetricManager."""
|
|
121
|
+
return f"MetricManager({', '.join(self.metrics.keys())})"
|
|
122
|
+
|
|
123
|
+
def __call__(self, **kwargs: Any) -> dict[str, Any]:
|
|
124
|
+
"""Compute all metrics and return their results as a dictionary mapping metric names to their computed values."""
|
|
125
|
+
|
|
126
|
+
# Extract example_id if it exists
|
|
127
|
+
example_id = nested_dict.get(
|
|
128
|
+
kwargs, key=("extra_info", "example_id"), default=None
|
|
129
|
+
)
|
|
130
|
+
|
|
131
|
+
# Initialize results dictionary
|
|
132
|
+
results = {"example_id": example_id}
|
|
133
|
+
|
|
134
|
+
for name, metric in self.metrics.items():
|
|
135
|
+
assert name not in results, f"Duplicate metric name: {name}"
|
|
136
|
+
|
|
137
|
+
# Add some nice error handling context in case metrics fail
|
|
138
|
+
example_msg = (
|
|
139
|
+
f" for example '{example_id}'" if example_id is not None else ""
|
|
140
|
+
)
|
|
141
|
+
|
|
142
|
+
# Decide whether the Metric applies to this batch
|
|
143
|
+
if "extra_info" in kwargs and "metrics_tags" in kwargs["extra_info"]:
|
|
144
|
+
metrics_tags = kwargs["extra_info"]["metrics_tags"]
|
|
145
|
+
metrics_tags = set(metrics_tags)
|
|
146
|
+
if metric.required_tags_all and not metrics_tags.issuperset(
|
|
147
|
+
metric.required_tags_all
|
|
148
|
+
):
|
|
149
|
+
continue
|
|
150
|
+
if metric.required_tags_any and not metrics_tags.intersection(
|
|
151
|
+
metric.required_tags_any
|
|
152
|
+
):
|
|
153
|
+
continue
|
|
154
|
+
if metric.prohibited_tags and metrics_tags.intersection(
|
|
155
|
+
metric.prohibited_tags
|
|
156
|
+
):
|
|
157
|
+
continue
|
|
158
|
+
|
|
159
|
+
with error.context(
|
|
160
|
+
msg=f"Computing '{name}' ({type(metric).__name__}){example_msg}",
|
|
161
|
+
raise_error=self.raise_errors,
|
|
162
|
+
exc_types=(MetricInputError, ValueError, TypeError, AttributeError),
|
|
163
|
+
):
|
|
164
|
+
# ... compute the metric
|
|
165
|
+
metric_result = metric.compute_from_kwargs(**kwargs)
|
|
166
|
+
|
|
167
|
+
# ... append 'name' to the keys of the metric result to ensure uniqueness
|
|
168
|
+
if isinstance(metric_result, dict):
|
|
169
|
+
metric_result = keymap(lambda k: f"{name}.{k}", metric_result)
|
|
170
|
+
results.update(metric_result)
|
|
171
|
+
elif isinstance(metric_result, list):
|
|
172
|
+
results[name] = metric_result
|
|
173
|
+
else:
|
|
174
|
+
raise ValueError("Unexpected result type: expected dict or list.")
|
|
175
|
+
|
|
176
|
+
return results
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
class Metric(ABC):
|
|
180
|
+
"""Abstract base class for Foundry metrics.
|
|
181
|
+
|
|
182
|
+
Defines a framework for computing metrics based on arbitrary keyword arguments.
|
|
183
|
+
A Metric can specify which batches a `MetricManager` applies it to, based on `tags` in the input batch.
|
|
184
|
+
|
|
185
|
+
Args:
|
|
186
|
+
required_tags_all: A set of tags that must all be present in the input batch for this metric to be computed.
|
|
187
|
+
required_tags_any: A set of tags where at least one must be present in the input batch for this metric to be computed.
|
|
188
|
+
prohibited_tags: A set of tags that must not be present in the input batch for this metric to be computed.
|
|
189
|
+
|
|
190
|
+
To implement a new metric, subclass this class and implement the `compute` method, at a minimum.
|
|
191
|
+
"""
|
|
192
|
+
|
|
193
|
+
def __init__(
|
|
194
|
+
self,
|
|
195
|
+
required_tags_all: list[str] | set[str] | None = None,
|
|
196
|
+
required_tags_any: list[str] | set[str] | None = None,
|
|
197
|
+
prohibited_tags: list[str] | set[str] | None = None,
|
|
198
|
+
*args: Any,
|
|
199
|
+
**kwargs: Any,
|
|
200
|
+
) -> None:
|
|
201
|
+
# Set required and prohibited tags
|
|
202
|
+
self.required_tags_all = (
|
|
203
|
+
set(required_tags_all) if exists(required_tags_all) else set()
|
|
204
|
+
)
|
|
205
|
+
self.required_tags_any = (
|
|
206
|
+
set(required_tags_any) if exists(required_tags_any) else set()
|
|
207
|
+
)
|
|
208
|
+
self.prohibited_tags = (
|
|
209
|
+
set(prohibited_tags) if exists(prohibited_tags) else set()
|
|
210
|
+
)
|
|
211
|
+
required_tags = self.required_tags_all.union(self.required_tags_any)
|
|
212
|
+
if required_tags.intersection(self.prohibited_tags):
|
|
213
|
+
raise ValueError(
|
|
214
|
+
"Conflicting tags found: required tags and prohibited tags must be disjoint."
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
# Check that the 'keys' of the kwargs_to_compute_args are a subset of the 'compute' method signature
|
|
218
|
+
if self.kwargs_to_compute_args:
|
|
219
|
+
assert self.kwargs_to_compute_args.keys() <= self.required_compute_args, (
|
|
220
|
+
f"The keys of kwargs_to_compute_args must be a subset of the 'compute' method signature. "
|
|
221
|
+
f"{self.kwargs_to_compute_args.keys()} is not a subset of {self.required_compute_args}"
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
# Check that optional_kwargs are also in the kwargs_to_compute_args
|
|
225
|
+
if self.kwargs_to_compute_args and self.optional_kwargs:
|
|
226
|
+
assert self.optional_kwargs <= set(self.kwargs_to_compute_args.keys()), (
|
|
227
|
+
f"All optional_kwargs must be defined in kwargs_to_compute_args. "
|
|
228
|
+
f"{self.optional_kwargs} is not a subset of {set(self.kwargs_to_compute_args.keys())}"
|
|
229
|
+
)
|
|
230
|
+
|
|
231
|
+
@cached_property
|
|
232
|
+
def required_compute_args(self) -> frozenset[str]:
|
|
233
|
+
"""Required input keys for this metric"""
|
|
234
|
+
return frozenset(inspect.signature(self.compute).parameters.keys())
|
|
235
|
+
|
|
236
|
+
@cached_property
|
|
237
|
+
def required_kwargs(self) -> frozenset[str]:
|
|
238
|
+
"""Required input keys for this metric"""
|
|
239
|
+
if self.kwargs_to_compute_args is None:
|
|
240
|
+
return frozenset()
|
|
241
|
+
|
|
242
|
+
return frozenset(self.kwargs_to_compute_args.values())
|
|
243
|
+
|
|
244
|
+
def compute_from_kwargs(self, **kwargs: Any) -> dict[str, Any]:
|
|
245
|
+
"""Run compute with an arbitrary dictionary of input keys and values.
|
|
246
|
+
|
|
247
|
+
The 'kwargs_to_compute_args' property here will determine
|
|
248
|
+
where in the kwargs we will look for the values to pass to the compute method.
|
|
249
|
+
|
|
250
|
+
Parameters marked in 'optional_kwargs' will only be passed if present in kwargs.
|
|
251
|
+
"""
|
|
252
|
+
if self.kwargs_to_compute_args:
|
|
253
|
+
compute_inputs = {}
|
|
254
|
+
for compute_arg, kwargs_key in self.kwargs_to_compute_args.items():
|
|
255
|
+
if compute_arg in self.optional_kwargs:
|
|
256
|
+
# Optional parameter - only add if present
|
|
257
|
+
try:
|
|
258
|
+
compute_inputs[compute_arg] = nested_dict.getitem(
|
|
259
|
+
kwargs, key=kwargs_key
|
|
260
|
+
)
|
|
261
|
+
except KeyError:
|
|
262
|
+
pass # Don't pass this parameter to compute()
|
|
263
|
+
else:
|
|
264
|
+
# Required parameter - use getitem (will raise if missing)
|
|
265
|
+
compute_inputs[compute_arg] = nested_dict.getitem(
|
|
266
|
+
kwargs, key=kwargs_key
|
|
267
|
+
)
|
|
268
|
+
else:
|
|
269
|
+
# If kwargs_to_compute_args is not defined, use kwargs directly
|
|
270
|
+
compute_inputs = kwargs
|
|
271
|
+
return self.compute(**compute_inputs)
|
|
272
|
+
|
|
273
|
+
@property
|
|
274
|
+
def kwargs_to_compute_args(self) -> dict[str, Any]:
|
|
275
|
+
"""Map input keys to a flat dictionary.
|
|
276
|
+
|
|
277
|
+
If not implemented, we return None, and pass the kwargs directly to the compute method.
|
|
278
|
+
|
|
279
|
+
Override e.g. as:
|
|
280
|
+
```python
|
|
281
|
+
@cached_property
|
|
282
|
+
def kwargs_to_compute_args(self) -> dict[str, Any]:
|
|
283
|
+
return {
|
|
284
|
+
"y_true": ("network_input", "coords_unnoised"),
|
|
285
|
+
"y_pred": ("network_output", "coords_pred"),
|
|
286
|
+
"extra_info": ("extra_info",),
|
|
287
|
+
}
|
|
288
|
+
```
|
|
289
|
+
"""
|
|
290
|
+
return None
|
|
291
|
+
|
|
292
|
+
@property
|
|
293
|
+
def optional_kwargs(self) -> frozenset[str]:
|
|
294
|
+
"""Set of compute argument names that are optional.
|
|
295
|
+
|
|
296
|
+
Optional parameters will only be passed to compute() if present in kwargs.
|
|
297
|
+
The compute() method should have sensible defaults for these parameters.
|
|
298
|
+
|
|
299
|
+
Override e.g. as:
|
|
300
|
+
```python
|
|
301
|
+
@property
|
|
302
|
+
def optional_kwargs(self) -> frozenset[str]:
|
|
303
|
+
return frozenset(["confidence_indices", "interfaces_to_score"])
|
|
304
|
+
```
|
|
305
|
+
"""
|
|
306
|
+
return frozenset()
|
|
307
|
+
|
|
308
|
+
@abstractmethod
|
|
309
|
+
def compute(self, **kwargs: Any) -> dict[str, Any] | list[dict[str, Any]]:
|
|
310
|
+
"""Implement actual metric calculation here
|
|
311
|
+
|
|
312
|
+
Override e.g. as:
|
|
313
|
+
```python
|
|
314
|
+
def compute(self, y_true, y_pred, extra_info):
|
|
315
|
+
print(extra_info)
|
|
316
|
+
return lddt(y_true, y_pred, thres=self.custom_thresholds)
|
|
317
|
+
```
|
|
318
|
+
"""
|
|
319
|
+
raise NotImplementedError
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
# pi computed as 2 * arccos(0) via torch (float32 then .item()); value is the
# float32 rounding of pi, not math.pi exactly.
pi = torch.acos(torch.zeros(1)).item() * 2
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class FourierEmbedding(nn.Module):
    """Embed a scalar input as cos(2*pi*(t*w + b)) with fixed random w and b.

    The frequencies ``w`` and phases ``b`` are registered as buffers (saved with
    the module, never trained) and drawn from a standard normal distribution.
    """

    def __init__(self, c):
        super().__init__()
        self.c = c
        # Buffers are created zeroed, then randomized by reset_parameters().
        self.register_buffer("w", torch.zeros(c, dtype=torch.float32))
        self.register_buffer("b", torch.zeros(c, dtype=torch.float32))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Redraw frequencies and phases from N(0, 1)."""
        nn.init.normal_(self.w)
        nn.init.normal_(self.b)

    def forward(
        self,
        t,  # [D]
    ):
        # Broadcast t against the c frequencies, then take the cosine: [..., c].
        phase = t[..., None] * self.w + self.b
        return torch.cos(2 * pi * phase)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class Dropout(nn.Module):
    """Dropout that can zero an entire row or column at once.

    When ``broadcast_dim`` is set, a single Bernoulli draw is shared along that
    dimension, so the whole slice is kept or dropped together.
    """

    def __init__(self, broadcast_dim=None, p_drop=0.15):
        super().__init__()
        # Bernoulli(1 - p_drop): draws 1 (keep) with prob 1-p_drop, 0 (drop) with p_drop.
        self.sampler = torch.distributions.bernoulli.Bernoulli(
            torch.tensor([1 - p_drop])
        )
        self.broadcast_dim = broadcast_dim
        self.p_drop = p_drop

    def forward(self, x):
        # Dropout is a no-op outside training mode.
        if not self.training:
            return x

        mask_shape = list(x.shape)
        if self.broadcast_dim is not None:
            # Collapse the broadcast dim so one draw covers the whole slice.
            mask_shape[self.broadcast_dim] = 1

        mask = self.sampler.sample(mask_shape).to(x.device).view(mask_shape)
        # Inverted dropout: rescale kept values so the expectation is unchanged.
        return mask * x / (1.0 - self.p_drop)
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""Shared pytest fixtures for foundry tests."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@pytest.fixture(scope="session")
def gpu():
    """Session fixture: skip the requesting test when no CUDA device is present."""
    if torch.cuda.is_available():
        return True
    pytest.skip("GPU not available")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def get_test_data_dir(conftest_file: str) -> Path:
    """Return the ``data`` directory that sits beside the given conftest file."""
    conftest_path = Path(conftest_file).resolve()
    return conftest_path.parent / "data"
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Shared pytest configuration hooks for foundry tests."""
|
|
2
|
+
|
|
3
|
+
import rootutils
|
|
4
|
+
from dotenv import load_dotenv
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def configure_pytest(config, conftest_file: str) -> None:
    """Configure pytest for foundry tests.

    Sets up project root and environment variables.

    Args:
        config: The pytest config object; currently unused in the body, kept for
            hook-signature compatibility.
        conftest_file: Path of the calling conftest.py, used to locate the
            project root via the ``.project-root`` indicator file.
    """
    # Setup the project root
    rootutils.setup_root(conftest_file, indicator=".project-root", pythonpath=True)

    # Setup environment variables
    load_dotenv(override=True)
|