rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,324 @@
+import os
+from copy import deepcopy
+from pathlib import Path
+
+import pandas as pd
+from atomworks.ml.utils import nested_dict
+from beartype.typing import Any, Literal
+from omegaconf import ListConfig
+
+from foundry.callbacks.callback import BaseCallback
+from foundry.utils.ddp import RankedLogger
+from foundry.utils.logging import (
+    condense_count_columns_of_grouped_df,
+    print_df_as_table,
+)
+
+ranked_logger = RankedLogger(__name__, rank_zero_only=True)
+
+
+class StoreValidationMetricsInDFCallback(BaseCallback):
+    """Saves the validation outputs in a DataFrame for each rank and concatenates them at the end of the validation epoch."""
+
+    def __init__(
+        self,
+        save_dir: os.PathLike,
+        metrics_to_save: list[str] | Literal["all"] = "all",
+    ):
+        self.save_dir = Path(save_dir)
+        self.metrics_to_save = metrics_to_save
+
+    def _save_dataframe_for_rank(self, rank: int, epoch: int):
+        """Saves per-GPU output dataframe of metrics to a rank-specific CSV."""
+        self.save_dir.mkdir(parents=True, exist_ok=True)
+        file_path = self.save_dir / f"validation_output_rank_{rank}_epoch_{epoch}.csv"
+
+        # Flush explicitly to ensure the file is written to disk
+        with open(file_path, "w") as f:
+            self.per_gpu_outputs_df.to_csv(f, index=False)
+            f.flush()
+            os.fsync(f.fileno())
+
+        ranked_logger.info(
+            f"Saved validation outputs to {file_path} for rank {rank}, epoch {epoch}"
+        )
+
+    def on_validation_epoch_start(self, trainer):
+        self.per_gpu_outputs_df = pd.DataFrame()
+
+    def on_validation_batch_end(
+        self,
+        trainer,
+        outputs: dict,
+        batch: Any,
+        batch_idx: int,
+        num_batches: int,
+        dataset_name: str | None = None,
+    ):
+        """Build a flattened DataFrame from the metrics output and accumulate with the prior batches"""
+        assert "metrics_output" in outputs, "Validation outputs must contain metrics."
+        metrics_output = deepcopy(outputs["metrics_output"])
+
+        # ... assemble a flat DataFrame from the metrics output
+        example_id = metrics_output.pop("example_id")
+        metrics_as_list_of_dicts = []
+
+        # ... remove metrics that are not in the save list
+        if self.metrics_to_save != "all" and isinstance(
+            self.metrics_to_save, list | ListConfig
+        ):
+            metrics_output = {
+                k: v
+                for k, v in metrics_output.items()
+                if any(k.startswith(prefix) for prefix in self.metrics_to_save)
+            }
+
+        def _build_row_from_flattened_dict(
+            dict_to_flatten: dict, prefix: str, example_id: str
+        ):
+            """Helper function to build a DataFrame row"""
+            flattened_dict = nested_dict.flatten(dict_to_flatten, fuse_keys=".")
+            row_data = {"example_id": example_id}
+            for sub_k, sub_v in flattened_dict.items():
+                # Convert lists to tuples so that they are hashable
+                if isinstance(sub_v, list):
+                    sub_v = tuple(sub_v)
+                row_data[f"{prefix}.{sub_k}"] = sub_v
+            return row_data
+
+        scalar_metrics = {"example_id": example_id}
+        for key, value in metrics_output.items():
+            if isinstance(value, dict):
+                # Flatten once for this dict => 1 row.
+                metrics_as_list_of_dicts.append(
+                    _build_row_from_flattened_dict(value, key, example_id)
+                )
+            elif isinstance(value, list) and all(isinstance(x, dict) for x in value):
+                # Flatten each dict in the list => multiple rows.
+                for subdict in value:
+                    metrics_as_list_of_dicts.append(
+                        _build_row_from_flattened_dict(subdict, key, example_id)
+                    )
+            else:
+                # Scalar (string, float, int, or list that isn't list-of-dicts)
+                assert key not in scalar_metrics, f"Duplicate key: {key}"
+                scalar_metrics[key] = value
+
+        metrics_as_list_of_dicts.append(scalar_metrics)
+
+        # ... convert the list of dicts to a DataFrame and add epoch and dataset columns
+        batch_df = pd.DataFrame(metrics_as_list_of_dicts)
+        batch_df["epoch"] = trainer.state["current_epoch"]
+        batch_df["dataset"] = dataset_name
+
+        # Assert no duplicate rows
+        assert (
+            batch_df.duplicated().sum() == 0
+        ), "Duplicate rows found in the metrics DataFrame!"
+
+        # Accumulate into the per-rank DataFrame
+        self.per_gpu_outputs_df = pd.concat(
+            [self.per_gpu_outputs_df, batch_df], ignore_index=True
+        )
+
+        ranked_logger.info(
+            f"Validation Progress: {100 * (batch_idx + 1) / num_batches:.0f}% for {dataset_name}"
+        )
+
+    def on_validation_epoch_end(self, trainer):
+        """Aggregate and log the validation metrics at the end of the epoch.
+
+        Each rank writes out its partial CSV. Then rank 0 aggregates them, logs grouped metrics by dataset,
+        and appends them to a master file containing data from all epochs.
+        """
+
+        # ... write out partial CSV for this rank
+        rank = trainer.fabric.global_rank
+        epoch = trainer.state["current_epoch"]
+        self._save_dataframe_for_rank(rank, epoch)
+
+        # Synchronize all processes
+        ranked_logger.info(
+            "Synchronizing all processes before concatenating DataFrames..."
+        )
+        trainer.fabric.barrier()
+
+        # Only rank 0 loads and concatenates the DataFrames
+        ranked_logger.info("Loading and concatenating DataFrames...")
+        if trainer.fabric.is_global_zero:
+            # ... load all partial CSVs
+            merged_df = self._load_and_concatenate_csvs(epoch)
+
+            # ... append to master CSV for all epochs
+            master_path = self.save_dir / "validation_output_all_epochs.csv"
+            if master_path.exists():
+                old_df = pd.read_csv(master_path)
+                merged_df = pd.concat(
+                    [old_df, merged_df], ignore_index=True, sort=False
+                )
+            merged_df.to_csv(master_path, index=False)
+            ranked_logger.info(f"Appended epoch={epoch} results to {master_path}")
+
+            # Store the path to the master CSV in the Trainer
+            trainer.validation_results_path = master_path
+
+            # Cleanup
+            self._cleanup_temp_files()
+
+    def _load_and_concatenate_csvs(self, epoch: int) -> pd.DataFrame:
+        """Load rank-specific CSVs for the given epoch and concatenate them without duplicating examples."""
+        pattern = f"validation_output_rank_*_epoch_{epoch}.csv"
+        files = list(self.save_dir.glob(pattern))
+
+        # Track which example_id + dataset combinations we've already seen
+        seen_examples = set()
+        final_dataframes = []
+
+        for f in files:
+            try:
+                df = pd.read_csv(f)
+
+                # Create a filter for rows with new example_id + dataset combinations
+                if not df.empty:
+                    # Create a unique identifier for each example_id + dataset combination
+                    df["_example_key"] = (
+                        df["example_id"].astype(str) + "|" + df["dataset"].astype(str)
+                    )
+
+                    # Filter out rows with example_id + dataset combinations we've already seen
+                    new_examples_mask = ~df["_example_key"].isin(seen_examples)
+
+                    # If there are any new examples, add them to our final list
+                    if new_examples_mask.any():
+                        new_examples_df = df[new_examples_mask].copy()
+
+                        # Update our set of seen examples
+                        seen_examples.update(new_examples_df["_example_key"].tolist())
+
+                        # Remove the temporary column before adding to final list
+                        new_examples_df.drop("_example_key", axis=1, inplace=True)
+                        final_dataframes.append(new_examples_df)
+
+            except pd.errors.EmptyDataError:
+                ranked_logger.warning(f"Skipping empty CSV: {f}")
+
+        # Concatenate dataframes, filling missing columns with NaN
+        return pd.concat(final_dataframes, axis=0, ignore_index=True, sort=False)
+
+    def _cleanup_temp_files(self):
+        """Remove temporary files used to store individual rank outputs."""
+        all_files = list(self.save_dir.rglob("validation_output_rank_*_epoch_*.csv"))
+        for file in all_files:
+            try:
+                file.unlink()  # Remove the file
+            except Exception as e:
+                ranked_logger.warning(f"Failed to delete file {file}: {e}")
+
+
+class LogAF3ValidationMetricsCallback(BaseCallback):
+    def __init__(
+        self,
+        metrics_to_log: list[str] | Literal["all"] = "all",
+    ):
+        self.metrics_to_log = metrics_to_log
+
+    def on_validation_epoch_end(self, trainer):
+        # Only log metrics to disk if this is the global zero rank
+        if not trainer.fabric.is_global_zero:
+            return
+
+        assert hasattr(
+            trainer, "validation_results_path"
+        ), "Results path not found! Ensure that StoreValidationMetricsInDFCallback is called first."
+        df = pd.read_csv(trainer.validation_results_path)
+
+        # ... filter to most recent epoch, drop epoch column
+        df = df[df["epoch"] == df["epoch"].max()]
+        df.drop(columns=["epoch", "example_id"], inplace=True)
+
+        # ... filter to columns that start with the metrics_to_log prefixes (and "dataset")
+        if self.metrics_to_log != "all" and isinstance(
+            self.metrics_to_log, list | ListConfig
+        ):
+            df = df[
+                [
+                    col
+                    for col in df.columns
+                    if any(col.startswith(prefix) for prefix in self.metrics_to_log)
+                ]
+                + ["dataset"]
+            ]
+
+        for dataset in df["dataset"].unique():
+            dataset_df = df[df["dataset"] == dataset].copy()
+            dataset_df.drop(columns=["dataset"], inplace=True)
+
+            print(f"\n+{' ' + dataset + ' ':-^150}+\n")
+
+            # +------------- LDDT by type (chain, interface) -------------+
+            by_type_lddt_cols = [
+                col for col in df.columns if col.startswith("by_type_lddt")
+            ]
+            if by_type_lddt_cols:
+                # ... build by-type DataFrame
+                by_type_df = dataset_df[by_type_lddt_cols].copy()
+                by_type_df = by_type_df.dropna(how="all")
+
+                # ... remove the "by_type_lddt." prefix
+                by_type_df.columns = by_type_df.columns.str.replace("by_type_lddt.", "")
+                numeric_cols = by_type_df.select_dtypes(include="number").columns
+
+                # ... group by type
+                grouped = by_type_df.groupby("type")[numeric_cols].agg(
+                    ["mean", "count"]
+                )
+                print_df_as_table(
+                    condense_count_columns_of_grouped_df(grouped).reset_index(),
+                    f"{dataset} — Epoch {trainer.state['current_epoch']} — Validation Metrics: LDDT by Type",
+                )
+
+                # Log the grouped metrics (aggregated from all ranks) with Fabric
+                if trainer.fabric:
+                    for _, row in grouped.reset_index().iterrows():
+                        trainer.fabric.log_dict(
+                            {
+                                f"val/{dataset}/{row['type'].iloc[0]}/{col}": row[col][
+                                    "mean"
+                                ]
+                                for col in numeric_cols
+                            },
+                            step=trainer.state["current_epoch"],
+                        )
+
+            # +----------------- Other metrics -----------------+
+            remaining_cols = list(set(dataset_df.columns) - set(by_type_lddt_cols))
+            remaining_df = dataset_df[remaining_cols].copy()
+            remaining_df = remaining_df.dropna(how="all", axis=0)
+            remaining_df = remaining_df.dropna(
+                how="all", axis=1
+            )  # If a Metric is all NaNs for this dataset, drop it
+            numeric_cols = remaining_df.select_dtypes(include="number").columns
+
+            # Compute means and non-NaN counts for numeric columns
+            final_means = remaining_df[numeric_cols].mean()
+            non_nan_counts = remaining_df[numeric_cols].count()
+
+            # Convert the Series to a DataFrame and add the count as a new column
+            final_means_df = final_means.to_frame(name="mean")
+            final_means_df["Count"] = non_nan_counts
+
+            # ... sort, so the rows are alphabetical
+            final_means_df.sort_index(inplace=True)
+
+            print_df_as_table(
+                final_means_df.reset_index(),
+                f"{dataset} — {trainer.state['current_epoch']} — General Validation Metrics",
+                console_width=150,
+            )
+
+            if trainer.fabric:
+                for col in numeric_cols:
+                    trainer.fabric.log_dict(
+                        {f"val/{dataset}/{col}": final_means[col]},
+                        step=trainer.state["current_epoch"],
+                    )
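
Usage sketch (not part of the packaged file). Judging by the +324 line count, the hunk above appears to be rf3/callbacks/metrics_logging.py. The two callbacks are designed to be chained: StoreValidationMetricsInDFCallback writes one CSV per rank, merges them on rank 0 into validation_output_all_epochs.csv, and records the path on trainer.validation_results_path; LogAF3ValidationMetricsCallback then reads that path, so it must be registered after the store callback (its assert enforces this). The sketch below assumes that module path, and the save_dir and metric-prefix values are illustrative only:

    from rf3.callbacks.metrics_logging import (
        LogAF3ValidationMetricsCallback,
        StoreValidationMetricsInDFCallback,
    )

    callbacks = [
        # Runs first: writes per-rank CSVs, merges them on rank 0, and sets
        # trainer.validation_results_path to the master CSV.
        StoreValidationMetricsInDFCallback(
            save_dir="outputs/validation_metrics",  # illustrative path
            metrics_to_save="all",
        ),
        # Runs second: reads trainer.validation_results_path and logs per-dataset
        # means; the prefixes here are illustrative, not required values.
        LogAF3ValidationMetricsCallback(metrics_to_log=["lddt", "by_type_lddt"]),
    ]
    # This callback list would then be passed to the Fabric-based trainer shipped in
    # foundry/trainers/fabric.py; its exact constructor signature is not shown in this diff.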