rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,421 @@
1
+ from dataclasses import dataclass
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from atomworks.ml.utils.token import get_af3_token_representative_idxs
8
+ from beartype.typing import Any, Literal
9
+ from biotite.structure import AtomArrayStack
10
+ from einops import rearrange, repeat
11
+ from jaxtyping import Bool, Float
12
+ from rf3.loss.af3_losses import distogram_loss
13
+
14
+ from foundry.metrics.metric import Metric
15
+ from foundry.utils.torch import assert_no_nans
16
+
17
+
18
@dataclass
class ComparisonConfig:
    """Describes one subset of token pairs to evaluate in distogram metrics.

    A pair (i, j) is selected when token i matches `token_a` and token j
    matches `token_b` (or vice versa -- pair selection is symmetric), and the
    two tokens satisfy the chain `relationship` constraint.
    """

    token_a: Literal["all", "atomized", "non_atomized"] = "all"
    token_b: Literal["all", "atomized", "non_atomized"] = "all"
    relationship: Literal["all", "inter", "intra"] = "all"

    def __eq__(self, other):
        """Symmetric equality: swapping token_a and token_b yields an equal config."""
        if not isinstance(other, type(self)):
            return False

        same_relationship = self.relationship == other.relationship
        same_token_types = {self.token_a, self.token_b} == {other.token_a, other.token_b}
        return same_relationship and same_token_types

    def __hash__(self):
        """Hash consistent with the order-insensitive equality above."""
        return hash((frozenset((self.token_a, self.token_b)), self.relationship))

    def __str__(self):
        """Short human-readable identifier, e.g. ``atomized_by_non_atomized_inter``."""
        parts = [f"{self.token_a}_by_{self.token_b}"]
        if self.relationship != "all":
            parts.append(self.relationship)
        return "_".join(parts)

    def create_distogram_mask(
        self, token_rep_atom_array: AtomArrayStack
    ) -> Bool[np.ndarray, "I I"]:
        """Create a token-by-token mask indicating which 2D pairs satisfy this config."""
        per_type_masks = {
            "all": np.ones(len(token_rep_atom_array), dtype=bool),
            "atomized": token_rep_atom_array.atomize,
            "non_atomized": ~token_rep_atom_array.atomize,
        }
        mask_a = per_type_masks[self.token_a]
        mask_b = per_type_masks[self.token_b]

        # Pairwise type selection; symmetrize when the two token types differ.
        token_pair_mask = np.outer(mask_a, mask_b)
        if self.token_a != self.token_b:
            token_pair_mask = token_pair_mask | np.outer(mask_b, mask_a)

        # Restrict to same-chain ("intra") or cross-chain ("inter") pairs.
        if self.relationship != "all":
            same_chain = np.equal.outer(
                token_rep_atom_array.pn_unit_iid, token_rep_atom_array.pn_unit_iid
            )
            if self.relationship == "intra":
                token_pair_mask = token_pair_mask & same_chain
            else:
                token_pair_mask = token_pair_mask & (~same_chain)

        return token_pair_mask
81
+
82
+
83
class DistogramLoss(Metric):
    """Distogram cross-entropy loss metric that respects the coordinate mask."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Per-element cross-entropy; any reduction happens inside `distogram_loss`.
        self.cce_loss = nn.CrossEntropyLoss(reduction="none")

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        # Maps each `compute` keyword argument to its path in the metric inputs.
        return {
            "pred_distogram": ("network_output", "distogram"),
            "X_rep_atoms_I": ("extra_info", "coord_token_lvl"),
            "crd_mask_rep_atoms_I": ("extra_info", "mask_token_lvl"),
        }

    def compute(
        self,
        pred_distogram: Float[torch.Tensor, "I I n_bins"],
        X_rep_atoms_I: Float[torch.Tensor, "I 3"],
        crd_mask_rep_atoms_I: Float[torch.Tensor, "I"],
    ) -> dict[str, Any]:
        """Compute the distogram loss as a python scalar.

        Args:
            pred_distogram: The predicted distogram. Shape: [I, I, n_bins], where n_bins is the number of bins (64 + 1 = 65).
            X_rep_atoms_I: The ground-truth coordinates of the representative atoms for each token. Shape: [I, 3].
            crd_mask_rep_atoms_I: A boolean mask indicating which representative atoms are present. Shape: [I].
        """
        value = distogram_loss(
            pred_distogram, X_rep_atoms_I, crd_mask_rep_atoms_I, self.cce_loss
        )
        return {"distogram_loss": value.detach().item()}
115
+
116
+
117
def bin_distances(
    coords: Float[torch.Tensor, "... L 3"],
    min_distance: float = 2,
    max_distance: float = 22,
    n_bins: int = 64,
) -> Float[torch.Tensor, "... L L {n_bins}+1"]:
    # TODO: Refactor loss to use this function instead (more re-usable)
    """Converts coordinates into binned pairwise distances according to the given parameters.

    NOTE: Our returned number of bins will be n_bins + 1, as torch.bucketize adds an additional bin for values greater than the maximum.

    Args:
        coords (torch.Tensor): The input tensor of coordinates. May be batched.
        min_distance (float): The minimum distance for binning.
        max_distance (float): The maximum distance for binning.
        n_bins (int): The number of bin boundaries to use.

    Returns:
        torch.Tensor: Integer bin indices per pair of points. Shape: [..., L, L], values in [0, n_bins].
    """
    # Compute pairwise distances
    distance_map = torch.cdist(coords, coords)

    # Replace NaN's with a large sentinel so they fall into the overflow bucket
    # instead of corrupting bucketize.
    distance_map = torch.nan_to_num(distance_map, nan=9999.0)

    # Bin the distances. (Named `bin_boundaries` -- previously this tensor
    # shadowed the `n_bins` parameter, which was confusing.)
    bin_boundaries = torch.linspace(min_distance, max_distance, n_bins).to(coords.device)
    binned_distances = torch.bucketize(distance_map, bin_boundaries)

    return binned_distances
148
+
149
+
150
def masked_distogram_cross_entropy_loss(
    input: Float[torch.Tensor, "D I I n_bins"],
    target: Float[torch.Tensor, "D I I"],
    mask: Float[torch.Tensor, "I I"] | None = None,
) -> Float[torch.Tensor, "D"]:
    # TODO: Refactor loss to use this function instead (more re-usable)
    """Computes the masked cross-entropy between two distograms.

    Note that the cross-entropy loss is not symmetric; that is, H(x, y) != H(y, x).

    Args:
        input: Predicted distogram logits. Shape: [D, I, I, n_bins].
        target: Integer bin indices in [0, n_bins). Shape: [D, I, I].
        mask: Optional pair mask; when None, all pairs contribute. Shape: [I, I].

    Returns:
        One loss value per batch element. Shape: [D].
    """
    # From the PyTorch documentation (where C = number of classes, N = batch size):
    # > Input: Shape: (C), (N, C) or (N, C, d1, d2, ..., dk)
    # > Target: Shape: (N) or (N, d1, d2, ..., dk) where each value should be between [0, C)
    input = input.permute(0, 3, 1, 2)  # d i j n_bins -> d n_bins i j
    loss = F.cross_entropy(input, target, reduction="none")

    # Apply mask and normalize. With no mask every pair counts equally
    # (previously `mask.sum()` raised AttributeError when mask was None).
    if mask is None:
        masked_loss = loss
        denominator = loss.new_tensor(float(loss.shape[-1] * loss.shape[-2]))
    else:
        masked_loss = loss * mask
        denominator = mask.sum()

    # NOTE(review): the epsilon is added to the quotient (matching the original
    # behavior), not to the denominator -- it does NOT guard against a zero mask.
    normalized_loss = masked_loss.sum(dim=(-1, -2)) / denominator + 1e-4  # [D]

    return normalized_loss
171
+
172
+
173
class DistogramComparisons(Metric):
    """Compares model distogram representations.

    Namely:
    - The representation from the TRUNK vs. GROUND TRUTH
    - The representation from the TRUNK vs. PREDICTED COORDINATES

    We subset to specific token pairs based on the provided ComparisonConfig.
    """

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        # Maps each `compute` keyword argument to its path in the metric inputs.
        return {
            "X_L": ("network_output", "X_L"),  # [D, L, 3]
            "trunk_pred_distogram": (
                "network_output",
                "distogram",
            ),  # [I, I, 65], where 65 is the number of bins (64 + 1)
            "ground_truth_atom_array_stack": "ground_truth_atom_array_stack",
            "X_rep_atoms_I": ("extra_info", "coord_token_lvl"),  # [D, I, 3]
            "crd_mask_rep_atoms_I": ("extra_info", "mask_token_lvl"),  # [D, I]
        }

    def __init__(
        self, comparison_configs: list[ComparisonConfig] | None = None, **kwargs
    ):
        """
        Args:
            comparison_configs: List of ComparisonConfig objects defining which comparisons to compute.
                If None, a default set of atomized/non-atomized subsets is used.
        """
        super().__init__(**kwargs)

        if comparison_configs is None:
            # Default comparisons
            comparison_configs = [
                ComparisonConfig("atomized", "atomized", "intra"),
                ComparisonConfig("atomized", "non_atomized", "inter"),
                ComparisonConfig("non_atomized", "non_atomized", "intra"),
                ComparisonConfig("all", "all", "all"),
            ]

        # Deduplicate (handle symmetries in token_a/token_b; ComparisonConfig's
        # __eq__/__hash__ treat swapped token_a/token_b as the same config)
        self.comparison_configs = list(set(comparison_configs))

    def compute(
        self,
        X_L: Float[torch.Tensor, "D L 3"],
        trunk_pred_distogram: Float[torch.Tensor, "I I n_bins"],
        ground_truth_atom_array_stack: AtomArrayStack,
        X_rep_atoms_I: Float[torch.Tensor, "D I 3"] | None = None,
        crd_mask_rep_atoms_I: Float[torch.Tensor, "D I"] | None = None,
    ) -> dict[str, Any]:
        """Computes the distogram loss for the trunk vs. ground truth and trunk vs. predicted coordinates.

        Optionally, we also subset to intra-ligand (atomized) distances.

        Args:
            X_L: The predicted coordinates. Shape: [D, L, 3]
            trunk_pred_distogram: The prediction from the DistogramHead, which linearly projects the trunk features. Shape: [I, I, n_bins]
            ground_truth_atom_array_stack: The ground-truth atom array stack, one model per diffusion sample. Shape: [D, L]
            X_rep_atoms_I: The ground-truth coordinates of the representative atoms for each token. Shape: [D, I, 3]. If None, will be inferred from the ground_truth_atom_array_stack.
            crd_mask_rep_atoms_I: A boolean mask indicating which representative atoms are present. Shape: [D, I]. If None, will be inferred from the ground_truth_atom_array_stack.

        Returns:
            A dict mapping metric names (one entry per comparison config, plus one
            per diffusion sample for the predicted-coordinate comparison) to scalars.
        """
        # Skip comparisons with fewer valid pairs than this, so sparse subsets
        # do not dilute the averaged metrics.
        MIN_PAIRS = 15
        results = {}

        # ... choose the first model, as we only care about 2D distance (frame-invariant)
        ground_truth_atom_array = ground_truth_atom_array_stack[0]

        _token_rep_idxs = get_af3_token_representative_idxs(ground_truth_atom_array)
        token_rep_idxs = torch.from_numpy(_token_rep_idxs).to(X_L.device)
        token_rep_atom_array = ground_truth_atom_array[_token_rep_idxs]

        # Create 2D coordinate mask for valid pairs of representative atoms
        if crd_mask_rep_atoms_I is None:
            # (If not provided, we will use the occupancy mask)
            crd_mask_rep_atoms_I = torch.from_numpy(
                token_rep_atom_array.occupancy > 0
            ).to(X_L.device)

        # Outer product lifts the per-token mask to a pair mask: a pair is valid
        # only when BOTH representative atoms are present.
        crd_mask_rep_atom_II = crd_mask_rep_atoms_I.unsqueeze(
            -1
        ) * crd_mask_rep_atoms_I.unsqueeze(-2)

        # Prepare distograms
        # (From the ground truth)
        if X_rep_atoms_I is None:
            # (If not provided, we will use the coordinates of the representative atoms)
            X_rep_atoms_I = torch.from_numpy(token_rep_atom_array.coord).to(X_L.device)
        binned_distogram_from_ground_truth = bin_distances(X_rep_atoms_I, n_bins=64)
        # (Predicted coordinates are batched, so we build the distogram for each predicted structure)
        binned_distogram_from_pred_coords = bin_distances(
            X_L[:, token_rep_idxs], n_bins=64
        )

        for config in self.comparison_configs:
            # ... create a token-by-token mask for this config, specifying which 2D pairs to compare
            token_pair_mask = config.create_distogram_mask(token_rep_atom_array)
            mask = (
                torch.from_numpy(token_pair_mask).to(X_L.device) & crd_mask_rep_atom_II
            )
            if mask.sum() < MIN_PAIRS:
                # (Skip if not enough pairs so we do not dilute our average)
                continue

            # ... generate a descriptive name for this config
            name = str(config)

            # Compute trunk vs. ground truth
            results[f"trunk_vs_ground_truth_cce_{name}"] = (
                masked_distogram_cross_entropy_loss(
                    trunk_pred_distogram.unsqueeze(0),
                    binned_distogram_from_ground_truth.unsqueeze(0),
                    mask,
                )
                .detach()
                .item()
            )

            # Compute trunk vs. predicted coordinates
            # (The single trunk distogram is broadcast against the D per-sample
            # distograms built from the predicted coordinates.)
            losses = masked_distogram_cross_entropy_loss(
                repeat(
                    trunk_pred_distogram,
                    "i j n_bins -> d i j n_bins",
                    d=binned_distogram_from_pred_coords.shape[0],
                ),
                binned_distogram_from_pred_coords,
                mask,
            )
            results.update(
                {
                    f"trunk_vs_pred_coords_cce_{name}_{i}": loss.detach().item()
                    for i, loss in enumerate(losses)
                }
            )

        return results
310
+
311
+
312
class DistogramEntropy(Metric):
    """Computes the entropy of the predicted distogram, subset to specific token pairs."""

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        # Maps each `compute` keyword argument to its path in the metric inputs.
        return {
            "trunk_pred_distogram": (
                "network_output",
                "distogram",
            ),  # [I, I, 65], where 65 is the number of bins (64 + 1)
            "ground_truth_atom_array_stack": "ground_truth_atom_array_stack",
            "crd_mask_rep_atoms_I": ("extra_info", "mask_token_lvl"),  # [D, I]
        }

    def __init__(
        self, comparison_configs: list[ComparisonConfig] | None = None, **kwargs
    ):
        """
        Args:
            comparison_configs: List of ComparisonConfig objects defining which comparisons to compute.
                If None, uses predefined configurations for atomized and non-atomized pairs.
        """
        super().__init__(**kwargs)

        if comparison_configs is None:
            # Default comparisons
            self.comparison_configs = [
                ComparisonConfig(
                    token_a="atomized", token_b="atomized", relationship="intra"
                ),  # Atomized-Atomized Intra
                ComparisonConfig(
                    token_a="non_atomized", token_b="non_atomized", relationship="intra"
                ),  # Non-Atomized-Non-Atomized Intra
                ComparisonConfig(
                    token_a="all", token_b="all", relationship="inter"
                ),  # All-All Inter
                ComparisonConfig(
                    token_a="all", token_b="all", relationship="all"
                ),  # All-All (everything)
            ]
        else:
            # Use provided comparison configurations
            self.comparison_configs = comparison_configs

    def compute(
        self,
        trunk_pred_distogram: Float[torch.Tensor, "I I n_bins"],
        ground_truth_atom_array_stack: AtomArrayStack,
        crd_mask_rep_atoms_I: Float[torch.Tensor, "D I"] | None = None,
    ) -> dict[str, Any]:
        """Computes the entropy of the predicted distogram distributions for different token pair subsets.

        Args:
            trunk_pred_distogram: Distogram logits from the trunk's DistogramHead. Shape: [I, I, n_bins].
            ground_truth_atom_array_stack: Ground-truth atom array stack; only the first model is used
                (the trunk distogram does not depend on the diffusion sample).
            crd_mask_rep_atoms_I: Boolean mask of which representative atoms are present. If None,
                inferred from the occupancy of the ground-truth representative atoms.

        Returns:
            A dict mapping `distogram_entropy_<config>` names to average-entropy scalars.
        """
        MIN_PAIRS = 15  # skip subsets with fewer valid pairs to avoid diluting averages
        results = {}

        # Get the first model from the atom array stack
        ground_truth_atom_array = ground_truth_atom_array_stack[0]
        token_rep_atom_array = ground_truth_atom_array[
            get_af3_token_representative_idxs(ground_truth_atom_array)
        ]

        # Create 2D coordinate mask for valid pairs of representative atoms
        if crd_mask_rep_atoms_I is None:
            crd_mask_rep_atoms_I = torch.from_numpy(
                token_rep_atom_array.occupancy > 0
            ).to(trunk_pred_distogram.device)
        crd_mask_rep_atom_II = crd_mask_rep_atoms_I.unsqueeze(
            -1
        ) * crd_mask_rep_atoms_I.unsqueeze(-2)

        # Convert logits to probabilities and compute the per-pair entropy ONCE;
        # neither depends on the comparison config (previously both were
        # recomputed on every loop iteration).
        trunk_pred_distogram_probs = torch.nn.functional.softmax(
            trunk_pred_distogram, dim=-1
        )
        # Compute entropy: -sum(p * log(p)) for each distribution
        # Add small epsilon to avoid log(0)
        epsilon = 1e-10
        entropy = -torch.sum(
            trunk_pred_distogram_probs
            * torch.log(trunk_pred_distogram_probs + epsilon),
            dim=-1,
        )  # [I, I]

        # Compute average entropy for each comparison configuration
        for config in self.comparison_configs:
            # Create a token-by-token mask for this config, specifying which 2D pairs to analyze
            token_pair_mask = config.create_distogram_mask(
                token_rep_atom_array
            )  # [I, I]
            mask = (
                torch.from_numpy(token_pair_mask).to(trunk_pred_distogram.device)
                & crd_mask_rep_atom_II
            )  # [I, I]

            if mask.sum() < MIN_PAIRS:
                # Skip if not enough pairs to avoid diluting our average
                continue

            # Generate a descriptive name for this config
            name = str(config)

            # Apply mask and compute average entropy
            masked_entropy = entropy * mask
            assert_no_nans(masked_entropy)

            avg_entropy = masked_entropy.sum() / (mask.sum() + 1e-6)

            results[f"distogram_entropy_{name}"] = avg_entropy.detach().item()

        return results