rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rf3/metrics/distogram.py
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
1
|
+
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from atomworks.ml.utils.token import get_af3_token_representative_idxs
from beartype.typing import Any, Literal
from biotite.structure import AtomArray, AtomArrayStack
from einops import rearrange, repeat
from jaxtyping import Bool, Float, Int
from rf3.loss.af3_losses import distogram_loss

from foundry.metrics.metric import Metric
from foundry.utils.torch import assert_no_nans
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class ComparisonConfig:
    """Configuration for token-pair comparisons in distogram metrics.

    A config selects a subset of token pairs by token type and by whether the two
    tokens belong to the same ``pn_unit_iid``.

    Attributes:
        token_a: Token type on one side of the pair: "all", "atomized" or "non_atomized".
        token_b: Token type on the other side of the pair (same choices as ``token_a``).
        relationship: "intra" keeps pairs sharing a ``pn_unit_iid`` (same chain/unit),
            "inter" keeps pairs with different ``pn_unit_iid`` values, "all" keeps both.

    Equality and hashing treat ``(token_a, token_b)`` as an unordered pair, so
    ``ComparisonConfig("atomized", "non_atomized")`` equals
    ``ComparisonConfig("non_atomized", "atomized")``.
    """

    token_a: Literal["all", "atomized", "non_atomized"] = "all"
    token_b: Literal["all", "atomized", "non_atomized"] = "all"
    relationship: Literal["all", "inter", "intra"] = "all"

    def __eq__(self, other):
        """Equality that accounts for token_a/token_b symmetry."""
        if not isinstance(other, type(self)):
            # Defer to the other operand (Python falls back to identity -> False).
            return NotImplemented
        return self.relationship == other.relationship and {
            self.token_a,
            self.token_b,
        } == {other.token_a, other.token_b}

    def __hash__(self):
        """Hash function compatible with the symmetric equality definition."""
        return hash((frozenset([self.token_a, self.token_b]), self.relationship))

    def __str__(self):
        """String representation used as the metric-name suffix."""
        name = f"{self.token_a}_by_{self.token_b}"
        if self.relationship != "all":
            name += f"_{self.relationship}"
        return name

    def create_distogram_mask(
        self, token_rep_atom_array: AtomArray
    ) -> Bool[np.ndarray, "I I"]:
        """Create a token-by-token mask indicating which 2D pairs satisfy this config.

        Args:
            token_rep_atom_array: One representative atom per token (an AtomArray, so that
                ``len()`` is the number of tokens I). Must carry the ``atomize`` and
                ``pn_unit_iid`` annotations.

        Returns:
            Boolean array of shape [I, I]. Entry (i, j) is True when tokens i and j match
            ``token_a``/``token_b`` in either order and satisfy ``relationship``.
        """
        # Coerce to bool: `~` on an integer 0/1 array would be a bitwise NOT (-1/-2),
        # silently marking every token as "non_atomized".
        atomize = np.asarray(token_rep_atom_array.atomize, dtype=bool)
        type_masks = {
            "all": np.ones(len(token_rep_atom_array), dtype=bool),
            "atomized": atomize,
            "non_atomized": ~atomize,
        }
        mask_a = type_masks[self.token_a]
        mask_b = type_masks[self.token_b]

        # Symmetric in (token_a, token_b): (i, j) qualifies if i matches token_a and j
        # matches token_b, or vice versa. When token_a == token_b both terms coincide.
        token_pair_mask = np.outer(mask_a, mask_b) | np.outer(mask_b, mask_a)

        # Apply relationship constraint
        if self.relationship != "all":
            same_unit = np.equal.outer(
                token_rep_atom_array.pn_unit_iid, token_rep_atom_array.pn_unit_iid
            )
            # "intra": same chain/unit; "inter": different chains/units.
            token_pair_mask &= same_unit if self.relationship == "intra" else ~same_unit

        return token_pair_mask
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class DistogramLoss(Metric):
    """Distogram loss metric that respects the token-level coordinate mask.

    Delegates to ``rf3.loss.af3_losses.distogram_loss`` and reports the result as a
    plain Python float under the ``"distogram_loss"`` key.
    """

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        # Maps each `compute` argument name to a (source, key) lookup; presumably
        # resolved by the Metric base class.
        return dict(
            pred_distogram=("network_output", "distogram"),
            X_rep_atoms_I=("extra_info", "coord_token_lvl"),
            crd_mask_rep_atoms_I=("extra_info", "mask_token_lvl"),
        )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Unreduced (per-element) cross-entropy, handed to distogram_loss below.
        self.cce_loss = nn.CrossEntropyLoss(reduction="none")

    def compute(
        self,
        pred_distogram: Float[torch.Tensor, "I I n_bins"],
        X_rep_atoms_I: Float[torch.Tensor, "I 3"],
        crd_mask_rep_atoms_I: Float[torch.Tensor, "I"],
    ) -> dict[str, Any]:
        """Evaluate the distogram loss for one example.

        Args:
            pred_distogram: Predicted distogram logits. Shape: [I, I, n_bins], where n_bins
                is the number of bins (64 + 1 = 65).
            X_rep_atoms_I: Ground-truth coordinates of each token's representative atom. Shape: [I, 3].
            crd_mask_rep_atoms_I: Mask of which representative atoms are present. Shape: [I].

        Returns:
            ``{"distogram_loss": float}``.
        """
        loss_value = distogram_loss(
            pred_distogram,
            X_rep_atoms_I,
            crd_mask_rep_atoms_I,
            self.cce_loss,
        )
        return {"distogram_loss": loss_value.detach().item()}
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def bin_distances(
    coords: Float[torch.Tensor, "... L 3"],
    min_distance: float = 2,
    max_distance: float = 22,
    n_bins: int = 64,
) -> Int[torch.Tensor, "... L L"]:
    """Convert coordinates into pairwise distance-bin indices.

    The bin edges are ``n_bins`` evenly spaced values from ``min_distance`` to
    ``max_distance`` (inclusive). ``torch.bucketize`` (``right=False``) maps each
    pairwise distance d to the index i with ``edges[i - 1] < d <= edges[i]``. Distances
    beyond the last edge map to ``n_bins``, so indices lie in ``[0, n_bins]``. That is
    ``n_bins + 1`` classes in total.

    NaN distances (from NaN coordinates) are treated as very large and therefore land
    in the last (overflow) class.

    Args:
        coords: Input coordinates, optionally with leading batch dimensions. Shape [..., L, 3].
        min_distance: Value of the first bin edge.
        max_distance: Value of the last bin edge.
        n_bins: Number of bin edges.

    Returns:
        int64 tensor of bin indices with shape [..., L, L].
    """
    # TODO: Refactor loss to use this function instead (more re-usable)
    distance_map = torch.cdist(coords, coords)

    # NaN coordinates yield NaN distances; push them past the last edge so that
    # bucketize assigns them to the overflow bin.
    distance_map = torch.nan_to_num(distance_map, nan=9999.0)

    # Build the edges on the distances' own device and dtype, so the comparison happens
    # in one precision and no cross-device/dtype mismatch can arise.
    bin_edges = torch.linspace(
        min_distance,
        max_distance,
        n_bins,
        device=distance_map.device,
        dtype=distance_map.dtype,
    )
    return torch.bucketize(distance_map, bin_edges)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def masked_distogram_cross_entropy_loss(
    input: Float[torch.Tensor, "D I I n_bins"],
    target: Int[torch.Tensor, "D I I"],
    mask: Bool[torch.Tensor, "I I"] | None = None,
) -> Float[torch.Tensor, "D"]:
    """Compute the masked, per-sample cross-entropy between distogram logits and target bins.

    Note that the cross-entropy loss is not symmetric; that is, H(x, y) != H(y, x).

    Args:
        input: Unnormalized logits over distance bins. Shape [D, I, I, n_bins].
        target: Integer bin indices in ``[0, n_bins)``. Shape [D, I, I].
        mask: Optional mask of the (i, j) pairs that contribute, broadcastable to
            [D, I, I] (e.g. [I, I] or [D, I, I]). If None, every pair contributes.

    Returns:
        Mean cross-entropy over the selected pairs, one value per sample. Shape [D].
    """
    # TODO: Refactor loss to use this function instead (more re-usable)
    # From the PyTorch documentation (where C = number of classes, N = batch size):
    # > Input: Shape: (C), (N, C) or (N, C, d1, d2, ..., dk)
    # > Target: Shape: (N) or (N, d1, d2, ..., dk) where each value should be between [0, C)
    input = rearrange(input, "d i j n_bins -> d n_bins i j")
    loss = F.cross_entropy(input, target, reduction="none")  # [D, I, I]

    if mask is None:
        masked_loss = loss
        n_pairs = loss.shape[-1] * loss.shape[-2]
    else:
        masked_loss = loss * mask
        # Count pairs per sample so a batched [D, I, I] mask normalizes each sample
        # separately (for an [I, I] mask this equals mask.sum()).
        n_pairs = torch.broadcast_to(mask, loss.shape).sum(dim=(-1, -2))

    # The epsilon guards the division against an empty mask, so it belongs in the
    # denominator (not added to the quotient).
    return masked_loss.sum(dim=(-1, -2)) / (n_pairs + 1e-4)  # [D]
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class DistogramComparisons(Metric):
    """Compares model distogram representations.

    Namely:
        - The representation from the TRUNK vs. GROUND TRUTH
        - The representation from the TRUNK vs. PREDICTED COORDINATES (one value per sample)

    Each comparison is a masked cross-entropy over the token pairs selected by a
    ComparisonConfig.
    """

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        return {
            "X_L": ("network_output", "X_L"),  # [D, L, 3]
            "trunk_pred_distogram": (
                "network_output",
                "distogram",
            ),  # [I, I, 65], where 65 is the number of bins (64 + 1)
            "ground_truth_atom_array_stack": "ground_truth_atom_array_stack",
            # compute() bins this as a single structure, i.e. it expects [I, 3]
            "X_rep_atoms_I": ("extra_info", "coord_token_lvl"),
            # [I]; cast to bool in compute()
            "crd_mask_rep_atoms_I": ("extra_info", "mask_token_lvl"),
        }

    def __init__(
        self, comparison_configs: list[ComparisonConfig] | None = None, **kwargs
    ):
        """
        Args:
            comparison_configs: List of ComparisonConfig objects defining which comparisons to compute.
                If None, a default set (atomized/non-atomized intra/inter subsets plus "all") is used.
        """
        super().__init__(**kwargs)

        if comparison_configs is None:
            # Default comparisons
            comparison_configs = [
                ComparisonConfig("atomized", "atomized", "intra"),
                ComparisonConfig("atomized", "non_atomized", "inter"),
                ComparisonConfig("non_atomized", "non_atomized", "intra"),
                ComparisonConfig("all", "all", "all"),
            ]

        # Deduplicate (ComparisonConfig equality is symmetric in token_a/token_b) while
        # keeping the caller's order. Iterating a set would make the order of the
        # results depend on per-process string-hash randomization.
        self.comparison_configs = list(dict.fromkeys(comparison_configs))

    def compute(
        self,
        X_L: Float[torch.Tensor, "D L 3"],
        trunk_pred_distogram: Float[torch.Tensor, "I I n_bins"],
        ground_truth_atom_array_stack: AtomArrayStack,
        X_rep_atoms_I: Float[torch.Tensor, "D I 3"] | None = None,
        crd_mask_rep_atoms_I: Float[torch.Tensor, "D I"] | None = None,
    ) -> dict[str, Any]:
        """Compute trunk-vs-ground-truth and trunk-vs-predicted-coordinate distogram losses.

        Args:
            X_L: The predicted coordinates. Shape: [D, L, 3]
            trunk_pred_distogram: The prediction from the DistogramHead, which linearly projects the
                trunk features. Shape: [I, I, n_bins]
            ground_truth_atom_array_stack: The ground-truth atom array stack; only the first model is used.
            X_rep_atoms_I: Ground-truth coordinates of each token's representative atom. The ground-truth
                comparison treats this as a single structure, so it is expected to be [I, 3].
                If None, taken from the ground-truth atom array.
            crd_mask_rep_atoms_I: Mask of which representative atoms are present (expected [I]); any
                dtype is accepted and cast to bool. If None, derived from occupancy > 0.

        Returns:
            Dict with ``trunk_vs_ground_truth_cce_{config}`` and ``trunk_vs_pred_coords_cce_{config}_{i}``
            (one per predicted sample i). Configs with fewer than MIN_PAIRS valid pairs are skipped.
        """
        MIN_PAIRS = 15  # Skip sparse subsets so they do not dilute averages
        N_BINS = 64  # Bin edges; bin_distances yields N_BINS + 1 classes (the 65-bin trunk distogram)
        device = X_L.device
        results = {}

        # ... choose the first model, as we only care about 2D distance (frame-invariant)
        ground_truth_atom_array = ground_truth_atom_array_stack[0]

        token_rep_idxs_np = get_af3_token_representative_idxs(ground_truth_atom_array)
        token_rep_idxs = torch.from_numpy(token_rep_idxs_np).to(device)
        token_rep_atom_array = ground_truth_atom_array[token_rep_idxs_np]

        # Create 2D coordinate mask for valid pairs of representative atoms
        if crd_mask_rep_atoms_I is None:
            # (If not provided, we will use the occupancy mask)
            crd_mask_rep_atoms_I = torch.from_numpy(token_rep_atom_array.occupancy > 0)
        # Cast to bool on the right device so the `&` with the config mask below also
        # works for 0/1 float masks.
        crd_mask_rep_atoms_I = crd_mask_rep_atoms_I.to(device=device, dtype=torch.bool)
        crd_mask_rep_atom_II = crd_mask_rep_atoms_I.unsqueeze(
            -1
        ) & crd_mask_rep_atoms_I.unsqueeze(-2)

        # Prepare distograms
        if X_rep_atoms_I is None:
            # (If not provided, we will use the coordinates of the representative atoms)
            X_rep_atoms_I = torch.from_numpy(token_rep_atom_array.coord).to(device)
        binned_distogram_from_ground_truth = bin_distances(X_rep_atoms_I, n_bins=N_BINS)
        # (Predicted coordinates are batched, so we build the distogram for each predicted structure)
        binned_distogram_from_pred_coords = bin_distances(
            X_L[:, token_rep_idxs], n_bins=N_BINS
        )  # [D, I, I]
        n_samples = binned_distogram_from_pred_coords.shape[0]

        for config in self.comparison_configs:
            # ... token-by-token mask for this config, restricted to resolved pairs
            token_pair_mask = torch.from_numpy(
                config.create_distogram_mask(token_rep_atom_array)
            ).to(device)
            mask = token_pair_mask & crd_mask_rep_atom_II
            if mask.sum() < MIN_PAIRS:
                # (Skip if not enough pairs so we do not dilute our average)
                continue

            name = str(config)

            # Compute trunk vs. ground truth (a single "sample")
            results[f"trunk_vs_ground_truth_cce_{name}"] = (
                masked_distogram_cross_entropy_loss(
                    trunk_pred_distogram.unsqueeze(0),
                    binned_distogram_from_ground_truth.unsqueeze(0),
                    mask,
                )
                .detach()
                .item()
            )

            # Compute trunk vs. predicted coordinates (the trunk distogram is shared by all samples)
            losses = masked_distogram_cross_entropy_loss(
                repeat(
                    trunk_pred_distogram,
                    "i j n_bins -> d i j n_bins",
                    d=n_samples,
                ),
                binned_distogram_from_pred_coords,
                mask,
            )
            results.update(
                {
                    f"trunk_vs_pred_coords_cce_{name}_{i}": loss.detach().item()
                    for i, loss in enumerate(losses)
                }
            )

        return results
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
class DistogramEntropy(Metric):
    """Computes the entropy of the predicted distogram, subset to specific token pairs."""

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        return {
            "trunk_pred_distogram": (
                "network_output",
                "distogram",
            ),  # [I, I, 65], where 65 is the number of bins (64 + 1)
            "ground_truth_atom_array_stack": "ground_truth_atom_array_stack",
            "crd_mask_rep_atoms_I": ("extra_info", "mask_token_lvl"),  # [I] or [D, I]
        }

    def __init__(
        self, comparison_configs: list[ComparisonConfig] | None = None, **kwargs
    ):
        """
        Args:
            comparison_configs: List of ComparisonConfig objects defining which comparisons to compute.
                If None, uses predefined configurations for atomized and non-atomized pairs.
                Duplicates (including token_a/token_b swaps) are removed, keeping the first occurrence.
        """
        super().__init__(**kwargs)

        if comparison_configs is None:
            # Default comparisons
            comparison_configs = [
                ComparisonConfig(
                    token_a="atomized", token_b="atomized", relationship="intra"
                ),  # Atomized-Atomized Intra
                ComparisonConfig(
                    token_a="non_atomized", token_b="non_atomized", relationship="intra"
                ),  # Non-Atomized-Non-Atomized Intra
                ComparisonConfig(
                    token_a="all", token_b="all", relationship="inter"
                ),  # All-All Inter
                ComparisonConfig(
                    token_a="all", token_b="all", relationship="all"
                ),  # All-All (everything)
            ]

        # Order-preserving dedup, consistent with DistogramComparisons.
        self.comparison_configs = list(dict.fromkeys(comparison_configs))

    def compute(
        self,
        trunk_pred_distogram: Float[torch.Tensor, "I I n_bins"],
        ground_truth_atom_array_stack: AtomArrayStack,
        crd_mask_rep_atoms_I: Float[torch.Tensor, "D I"] | None = None,
    ) -> dict[str, Any]:
        """Computes the entropy of the predicted distogram distributions for different token pair subsets.

        Args:
            trunk_pred_distogram: Distogram logits from the trunk. Shape [I, I, n_bins].
            ground_truth_atom_array_stack: Ground-truth structure; only the first model is used.
            crd_mask_rep_atoms_I: Mask of which representative atoms are present ([I], or a
                broadcastable [D, I]); any dtype is accepted and cast to bool. If None,
                derived from occupancy > 0.

        Returns:
            Dict mapping ``distogram_entropy_{config}`` to the mean entropy (in nats) over the
            selected pairs. Configs with fewer than MIN_PAIRS valid pairs are skipped.
        """
        MIN_PAIRS = 15  # Skip sparse subsets so they do not dilute averages
        device = trunk_pred_distogram.device
        results = {}

        # Get the first model from the atom array stack
        ground_truth_atom_array = ground_truth_atom_array_stack[0]
        token_rep_atom_array = ground_truth_atom_array[
            get_af3_token_representative_idxs(ground_truth_atom_array)
        ]

        # Create 2D coordinate mask for valid pairs of representative atoms
        if crd_mask_rep_atoms_I is None:
            crd_mask_rep_atoms_I = torch.from_numpy(token_rep_atom_array.occupancy > 0)
        # Cast to bool so the `&` with the config mask below also works for 0/1 float masks.
        crd_mask_rep_atoms_I = crd_mask_rep_atoms_I.to(device=device, dtype=torch.bool)
        crd_mask_rep_atom_II = crd_mask_rep_atoms_I.unsqueeze(
            -1
        ) & crd_mask_rep_atoms_I.unsqueeze(-2)

        # Per-pair entropy does not depend on the config, so compute it once.
        # Work in at least float32, and use torch.special.entr (-p * log(p), defined as
        # 0 at p == 0). Bins whose probability underflows to zero then contribute 0 instead
        # of 0 * log(0) = NaN, which an additive epsilon cannot prevent in float16.
        logits = trunk_pred_distogram.to(
            torch.promote_types(trunk_pred_distogram.dtype, torch.float32)
        )
        trunk_pred_distogram_probs = F.softmax(logits, dim=-1)
        entropy = torch.special.entr(trunk_pred_distogram_probs).sum(dim=-1)  # [I, I]

        # Compute entropy for each comparison configuration
        for config in self.comparison_configs:
            # Token-by-token mask for this config, restricted to resolved pairs
            token_pair_mask = torch.from_numpy(
                config.create_distogram_mask(token_rep_atom_array)
            ).to(device)  # [I, I]
            mask = token_pair_mask & crd_mask_rep_atom_II

            if mask.sum() < MIN_PAIRS:
                # Skip if not enough pairs to avoid diluting our average
                continue

            name = str(config)

            # Apply mask and compute average entropy
            masked_entropy = entropy * mask
            assert_no_nans(masked_entropy)

            avg_entropy = masked_entropy.sum() / (mask.sum() + 1e-6)
            results[f"distogram_entropy_{name}"] = avg_entropy.detach().item()

        return results
|