rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
foundry/callbacks/metrics_logging.py ADDED
@@ -0,0 +1,211 @@
+ import os
+ from copy import deepcopy
+ from pathlib import Path
+
+ import pandas as pd
+ from atomworks.ml.utils import nested_dict
+ from beartype.typing import Any, Literal
+ from omegaconf import ListConfig
+
+ from foundry.callbacks.callback import BaseCallback
+ from foundry.utils.ddp import RankedLogger
+
+ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
+
+
+ class StoreValidationMetricsInDFCallback(BaseCallback):
+     """Saves the validation outputs in a DataFrame for each rank and concatenates them at the end of the validation epoch."""
+
+     def __init__(
+         self,
+         save_dir: os.PathLike,
+         metrics_to_save: list[str] | Literal["all"] = "all",
+     ):
+         self.save_dir = Path(save_dir)
+         self.metrics_to_save = metrics_to_save
+
+     def _save_dataframe_for_rank(self, rank: int, epoch: int):
+         """Saves per-GPU output dataframe of metrics to a rank-specific CSV."""
+         self.save_dir.mkdir(parents=True, exist_ok=True)
+         file_path = self.save_dir / f"validation_output_rank_{rank}_epoch_{epoch}.csv"
+
+         # Flush explicitly to ensure the file is written to disk
+         with open(file_path, "w") as f:
+             self.per_gpu_outputs_df.to_csv(f, index=False)
+             f.flush()
+             os.fsync(f.fileno())
+
+         ranked_logger.info(
+             f"Saved validation outputs to {file_path} for rank {rank}, epoch {epoch}"
+         )
+
+     def on_validation_epoch_start(self, trainer):
+         self.per_gpu_outputs_df = pd.DataFrame()
+
+     def on_validation_batch_end(
+         self,
+         trainer,
+         outputs: dict,
+         batch: Any,
+         batch_idx: int,
+         num_batches: int,
+         dataset_name: str | None = None,
+     ):
+         """Build a flattened DataFrame from the metrics output and accumulate with the prior batches"""
+         assert "metrics_output" in outputs, "Validation outputs must contain metrics."
+         metrics_output = deepcopy(outputs["metrics_output"])
+
+         # ... assemble a flat DataFrame from the metrics output
+         example_id = metrics_output.pop("example_id")
+         metrics_as_list_of_dicts = []
+
+         # ... remove metrics that are not in the save list
+         if self.metrics_to_save != "all" and isinstance(
+             self.metrics_to_save, list | ListConfig
+         ):
+             metrics_output = {
+                 k: v
+                 for k, v in metrics_output.items()
+                 if any(k.startswith(prefix) for prefix in self.metrics_to_save)
+             }
+
+         def _build_row_from_flattened_dict(
+             dict_to_flatten: dict, prefix: str, example_id: str
+         ):
+             """Helper function to build a DataFrame row"""
+             flattened_dict = nested_dict.flatten(dict_to_flatten, fuse_keys=".")
+             row_data = {"example_id": example_id}
+             for sub_k, sub_v in flattened_dict.items():
+                 # Convert lists to tuples so that they are hashable
+                 if isinstance(sub_v, list):
+                     sub_v = tuple(sub_v)
+                 row_data[f"{prefix}.{sub_k}"] = sub_v
+             return row_data
+
+         scalar_metrics = {"example_id": example_id}
+         for key, value in metrics_output.items():
+             if isinstance(value, dict):
+                 # Flatten once for this dict => 1 row.
+                 metrics_as_list_of_dicts.append(
+                     _build_row_from_flattened_dict(value, key, example_id)
+                 )
+             elif isinstance(value, list) and all(isinstance(x, dict) for x in value):
+                 # Flatten each dict in the list => multiple rows.
+                 for subdict in value:
+                     metrics_as_list_of_dicts.append(
+                         _build_row_from_flattened_dict(subdict, key, example_id)
+                     )
+             else:
+                 # Scalar (string, float, int, or list that isn't list-of-dicts)
+                 assert key not in scalar_metrics, f"Duplicate key: {key}"
+                 scalar_metrics[key] = value
+
+         metrics_as_list_of_dicts.append(scalar_metrics)
+
+         # ... convert the list of dicts to a DataFrame and add epoch and dataset columns
+         batch_df = pd.DataFrame(metrics_as_list_of_dicts)
+         batch_df["epoch"] = trainer.state["current_epoch"]
+         batch_df["dataset"] = dataset_name
+
+         # Assert no duplicate rows
+         assert (
+             batch_df.duplicated().sum() == 0
+         ), "Duplicate rows found in the metrics DataFrame!"
+
+         # Accumulate into the per-rank DataFrame
+         self.per_gpu_outputs_df = pd.concat(
+             [self.per_gpu_outputs_df, batch_df], ignore_index=True
+         )
+
+         ranked_logger.info(
+             f"Validation Progress: {100 * (batch_idx + 1) / num_batches:.0f}% for {dataset_name}"
+         )
+
+     def on_validation_epoch_end(self, trainer):
+         """Aggregate and log the validation metrics at the end of the epoch.
+
+         Each rank writes out its partial CSV. Then rank 0 aggregates them, logs grouped metrics by dataset,
+         and appends them to a master file containing data from all epochs.
+         """
+
+         # ... write out partial CSV for this rank
+         rank = trainer.fabric.global_rank
+         epoch = trainer.state["current_epoch"]
+         self._save_dataframe_for_rank(rank, epoch)
+
+         # Synchronize all processes
+         ranked_logger.info(
+             "Synchronizing all processes before concatenating DataFrames..."
+         )
+         trainer.fabric.barrier()
+
+         # Only rank 0 loads and concatenates the DataFrames
+         ranked_logger.info("Loading and concatenating DataFrames...")
+         if trainer.fabric.is_global_zero:
+             # ... load all partial CSVs
+             merged_df = self._load_and_concatenate_csvs(epoch)
+
+             # ... append to master CSV for all epochs
+             master_path = self.save_dir / "validation_output_all_epochs.csv"
+             if master_path.exists():
+                 old_df = pd.read_csv(master_path)
+                 merged_df = pd.concat(
+                     [old_df, merged_df], ignore_index=True, sort=False
+                 )
+             merged_df.to_csv(master_path, index=False)
+             ranked_logger.info(f"Appended epoch={epoch} results to {master_path}")
+
+             # Store the path to the master CSV in the Trainer
+             trainer.validation_results_path = master_path
+
+             # Cleanup
+             self._cleanup_temp_files()
+
+     def _load_and_concatenate_csvs(self, epoch: int) -> pd.DataFrame:
+         """Load rank-specific CSVs for the given epoch and concatenate them without duplicating examples."""
+         pattern = f"validation_output_rank_*_epoch_{epoch}.csv"
+         files = list(self.save_dir.glob(pattern))
+
+         # Track which example_id + dataset combinations we've already seen
+         seen_examples = set()
+         final_dataframes = []
+
+         for f in files:
+             try:
+                 df = pd.read_csv(f)
+
+                 # Create a filter for rows with new example_id + dataset combinations
+                 if not df.empty:
+                     # Create a unique identifier for each example_id + dataset combination
+                     df["_example_key"] = (
+                         df["example_id"].astype(str) + "|" + df["dataset"].astype(str)
+                     )
+
+                     # Filter out rows with example_id + dataset combinations we've already seen
+                     new_examples_mask = ~df["_example_key"].isin(seen_examples)
+
+                     # If there are any new examples, add them to our final list
+                     if new_examples_mask.any():
+                         new_examples_df = df[new_examples_mask].copy()
+
+                         # Update our set of seen examples
+                         seen_examples.update(new_examples_df["_example_key"].tolist())
+
+                         # Remove the temporary column before adding to final list
+                         new_examples_df.drop("_example_key", axis=1, inplace=True)
+                         final_dataframes.append(new_examples_df)
+
+             except pd.errors.EmptyDataError:
+                 ranked_logger.warning(f"Skipping empty CSV: {f}")
+
+         # Concatenate dataframes, filling missing columns with NaN
+         return pd.concat(final_dataframes, axis=0, ignore_index=True, sort=False)
+
+     def _cleanup_temp_files(self):
+         """Remove temporary files used to store individual rank outputs."""
+         all_files = list(self.save_dir.rglob("validation_output_rank_*_epoch_*.csv"))
+         for file in all_files:
+             try:
+                 file.unlink()  # Remove the file
+             except Exception as e:
+                 ranked_logger.warning(f"Failed to delete file {file}: {e}")
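For orientation, an illustrative sketch of the metrics_output dictionary that on_validation_batch_end above consumes, and the rows it produces. The metric names and values are hypothetical, not taken from the package:

# Hypothetical outputs["metrics_output"] for one validation example:
metrics_output = {
    "example_id": "7xyz_A",  # popped, then repeated on every row
    "lddt": {"mean": 0.82, "median": 0.84},  # dict -> one row with columns "lddt.mean", "lddt.median"
    "clashes": [{"chain": "A", "count": 0}, {"chain": "B", "count": 2}],  # list of dicts -> one row per dict
    "total_loss": 1.37,  # scalar -> collected into a single trailing "scalar" row
}
# The resulting batch_df would hold 4 rows (1 dict row + 2 list-of-dict rows + 1 scalar row),
# each carrying example_id, with "epoch" and "dataset" columns appended afterwards.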
foundry/callbacks/timing_logging.py ADDED
@@ -0,0 +1,67 @@
+ import pandas as pd
+ from lightning.fabric.utilities.rank_zero import rank_zero_only
+
+ from foundry.callbacks.callback import BaseCallback
+ from foundry.utils.logging import print_df_as_table
+ from foundry.utils.torch import Timers
+
+
+ class TimingCallback(BaseCallback):
+     """Fabric callback to print timing metrics."""
+
+     def __init__(self, log_every_n: int = 100):
+         super().__init__()
+         self.log_every_n = log_every_n
+         self.timers = Timers()
+         self.n_steps_since_last_log = 0
+
+     @rank_zero_only
+     def on_train_epoch_start(self, trainer, **kwargs):
+         self.timers.start("train_loader_iter")
+
+     @rank_zero_only
+     def on_after_train_loader_iter(self, trainer, **kwargs):
+         self.timers.stop("train_loader_iter")
+
+     @rank_zero_only
+     def on_before_train_loader_next(self, trainer, **kwargs):
+         self.timers.start("train_step", "train_loader_next")
+
+     @rank_zero_only
+     def on_train_batch_start(self, trainer, **kwargs):
+         self.timers.start("forward_loss_backward")
+         self.timers.stop("train_loader_next")
+
+     @rank_zero_only
+     def on_train_batch_end(self, trainer, **kwargs):
+         self.timers.stop("forward_loss_backward")
+         self.timers.stop("train_step")
+
+     @rank_zero_only
+     def on_before_optimizer_step(self, trainer, **kwargs):
+         self.timers.start("optimizer_step")
+
+     @rank_zero_only
+     def on_after_optimizer_step(self, optimizer, **kwargs):
+         self.timers.stop("optimizer_step")
+
+     @rank_zero_only
+     def optimizer_step(self, trainer, optimizer):
+         step = trainer.state["global_step"]
+         self.n_steps_since_last_log += 1
+         if step % self.log_every_n == 0:
+             timings = self.timers.elapsed(*self.timers.timers.keys(), reset=True)
+             timings = {
+                 f"timings/{k}": v / self.n_steps_since_last_log
+                 for k, v in timings.items()
+             }
+             trainer.fabric.log_dict(timings, step=step)
+             if trainer.fabric.is_global_zero:
+                 self._print_timings(timings)
+
+     def _print_timings(self, timings: dict[str, float]):
+         df = pd.DataFrame(timings.items(), columns=["Step", "Time (s)"])
+         print_df_as_table(
+             df, title=f"Timing stats (over {self.n_steps_since_last_log} steps)"
+         )
+         self.n_steps_since_last_log = 0
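The timers above bracket distinct phases of a training step; the pairing of start/stop calls implies the hook order sketched below. This is an inferred outline, not code from the wheel (the actual wiring lives in the Fabric trainer):

# Inferred phase bracketing for TimingCallback (one epoch, pseudocode):
timing_cb = TimingCallback(log_every_n=100)
# on_train_epoch_start          -> start "train_loader_iter"
# on_after_train_loader_iter    -> stop  "train_loader_iter"  (building the dataloader iterator)
# per batch:
#   on_before_train_loader_next -> start "train_step" and "train_loader_next"
#   on_train_batch_start        -> stop  "train_loader_next", start "forward_loss_backward"
#   on_train_batch_end          -> stop  "forward_loss_backward" and "train_step"
#   on_before_optimizer_step    -> start "optimizer_step"
#   on_after_optimizer_step     -> stop  "optimizer_step"
#   optimizer_step              -> every log_every_n steps, log per-step averages and print the table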
foundry/callbacks/train_logging.py ADDED
@@ -0,0 +1,278 @@
+ import time
+ from collections import defaultdict
+
+ import pandas as pd
+ from atomworks.ml.example_id import parse_example_id
+ from beartype.typing import Any
+ from lightning.fabric.wrappers import (
+     _FabricOptimizer,
+ )
+ from rf3.utils.loss import convert_batched_losses_to_list_of_dicts, mean_losses
+ from rich.console import Group
+ from rich.panel import Panel
+ from rich.table import Table
+ from torchmetrics.aggregation import MeanMetric
+
+ from foundry.callbacks.callback import BaseCallback
+ from foundry.utils.ddp import RankedLogger
+ from foundry.utils.logging import (
+     print_df_as_table,
+     print_model_parameters,
+     safe_print,
+     table_from_df,
+ )
+
+
+ class LogModelParametersCallback(BaseCallback):
+     """Print a table of the total and trainable parameters of the model at the start of training."""
+
+     def on_fit_start(self, trainer):
+         print_model_parameters(trainer.state["model"])
+
+
+ class PrintExampleIDBeforeForwardPassCallback(BaseCallback):
+     """Print the example ID for each rank at the start of the forward pass for each batch.
+
+     WARNING: Spams the console. Use only for debugging purposes.
+     """
+
+     def __init__(self, rank_zero_only: bool = True):
+         self.logger = RankedLogger(__name__, rank_zero_only=rank_zero_only)
+
+     def on_train_batch_start(self, trainer, batch: Any, batch_idx: int):
+         example_id = batch[0]["example_id"]
+
+         # Prepare the formatted strings with colors
+         rank_info = f"[grey]<Rank {trainer.fabric.global_rank}>[/grey]"
+         epoch_batch_info = (
+             f"[blue]Epoch {trainer.state['current_epoch']} Batch {batch_idx}[/blue]"
+         )
+         example_id_info = f"[bold yellow]Example ID: {example_id}[/bold yellow]"
+
+         safe_print(
+             f"{rank_info} {epoch_batch_info} - {example_id_info}",
+             logger=self.logger,
+         )
+
+
+ class LogDatasetSamplingRatiosCallback(BaseCallback):
+     """Monitor the sampling ratios of the datasets and log after each epoch."""
+
+     def on_fit_start(self, trainer):
+         self.dataset_sampling_counts = defaultdict(int)
+
+     def on_train_batch_start(self, trainer, batch, batch_idx):
+         example_id = batch[0]["example_id"]
+
+         if trainer.fabric.is_global_zero:
+             dataset_string = "/".join(parse_example_id(example_id)["datasets"])
+             self.dataset_sampling_counts[dataset_string] += 1
+
+     def on_train_epoch_end(self, trainer):
+         if trainer.fabric.is_global_zero:
+             total_samples = sum(self.dataset_sampling_counts.values())
+
+             data = {
+                 "Dataset": list(self.dataset_sampling_counts.keys()),
+                 "Count": list(self.dataset_sampling_counts.values()),
+                 "Percentage": [
+                     f"{(count / total_samples) * 100:.2f}%"
+                     for count in self.dataset_sampling_counts.values()
+                 ],
+             }
+
+             print_df_as_table(
+                 df=pd.DataFrame(data),
+                 title=f"Epoch {trainer.state['current_epoch']}: Dataset Sampling Ratios",
+             )
+
+             # Reset the counts for the next epoch
+             self.dataset_sampling_counts.clear()
+
+
+ class LogLearningRateCallback(BaseCallback):
+     """Monitor the learning rate of the optimizer
+
+     Args:
+         log_every_n: Log the learning rate every n optimizer steps.
+     """
+
+     def __init__(self, log_every_n: int):
+         self.log_every_n = log_every_n
+
+     def optimizer_step(self, trainer, optimizer: _FabricOptimizer):
+         # Get the current global step
+         current_step = trainer.state["global_step"]
+
+         # Log the learning rate only every `log_every_n` steps
+         if current_step % self.log_every_n == 0:
+             trainer.fabric.log(
+                 "train/learning_rate",
+                 optimizer.param_groups[0]["lr"],
+                 step=current_step,
+             )
+
+
+ class LogAF3TrainingLossesCallback(BaseCallback):
+     """Log the primary model losses for AF3.
+
+     Includes:
+         - The mean training losses every `log_every_n` batches
+         - The mean training losses at the end of each epoch
+         - The time taken to complete each epoch
+         - (Optionally) The full batch losses for each structure in the diffusion batch
+
+     Args:
+         log_every_n (int): Print the training loss after every n batches.
+     """
+
+     def __init__(
+         self,
+         log_full_batch_losses: bool = False,
+         log_every_n: int = 10,
+     ):
+         """
+         Args:
+             log_full_batch_losses(bool): Log losses for every structure within the diffusion batch.
+             log_every_n (int): Print the training loss after every n batches.
+             console_width (int): Width of the console for printing.
+         """
+         self.log_every_n = log_every_n
+         self.log_full_batch_losses = log_full_batch_losses
+
+         self.start_time = None
+         self.logger = RankedLogger(__name__, rank_zero_only=True)
+
+         # This dict will store key -> MeanMetric() for each loss
+         self.loss_trackers = {}
+
+     def on_train_epoch_start(self, trainer):
+         # Record the start time of the epoch
+         self.start_time = time.time()
+
+     def on_train_batch_end(self, trainer, outputs: Any, batch: Any, batch_idx: int):
+         mean_loss_dict = {}
+         if "loss_dict" in outputs:
+             mean_loss_dict.update(mean_losses(outputs["loss_dict"]))
+
+         for key, val in mean_loss_dict.items():
+             if key not in self.loss_trackers:
+                 self.loss_trackers[key] = trainer.fabric.to_device(MeanMetric())
+             self.loss_trackers[key].update(val)
+
+         if trainer.fabric.is_global_zero and batch_idx % self.log_every_n == 0:
+             # ... log losses for each structure in the batch
+             if self.log_full_batch_losses:
+                 full_batch_loss_dicts = convert_batched_losses_to_list_of_dicts(
+                     outputs["loss_dict"]
+                 )
+                 for loss_dict in full_batch_loss_dicts:
+                     loss_dict = {
+                         f"train/per_structure/{k}": v for k, v in loss_dict.items()
+                     }
+                     trainer.fabric.log_dict(
+                         loss_dict, step=trainer.state["global_step"]
+                     )
+
+             # ... log losses meaned across the batch
+             # (Prepend "train/batch_mean" to the keys in the loss dictionary)
+             mean_loss_dict_for_logging = {
+                 f"train/batch_mean/{k}": v for k, v in mean_loss_dict.items()
+             }
+             trainer.fabric.log_dict(
+                 mean_loss_dict_for_logging, step=trainer.state["global_step"]
+             )
+
+             # ... print the mean losses in a table
+             df_losses = pd.DataFrame(
+                 {
+                     "Train Loss Name": [
+                         k.replace("_", " ").title() for k in mean_loss_dict.keys()
+                     ],
+                     "Value": [v for v in mean_loss_dict.values()],
+                 }
+             )
+             table = table_from_df(df_losses, title="Training Losses")
+
+             # (percentage of batch count)
+             percentage_complete = (batch_idx / trainer.n_batches_per_epoch) * 100
+
+             # Simple progress bar using Unicode blocks
+             progress_bar_length = 10  # Length of the progress bar
+             filled_length = int(progress_bar_length * percentage_complete // 100)
+             progress_bar = "█" * filled_length + "░" * (
+                 progress_bar_length - filled_length
+             )
+             percentage_str = f"[bold magenta]{percentage_complete:.2f}%[/bold magenta]"
+
+             # Create a panel for the epoch and batch info with a progress bar
+             epoch_batch_info = (
+                 f"[grey]<Rank {trainer.fabric.global_rank}>[/grey] "
+                 f"Epoch {trainer.state['current_epoch']} Batch {batch_idx} "
+                 f"[{progress_bar}] {percentage_str}"
+             )
+
+             epoch_batch_panel = Panel(
+                 epoch_batch_info,
+                 border_style="bold blue",
+             )
+
+             # Create a panel for the example ID
+             example_id = batch[0]["example_id"]
+             example_id_str = f"[bold yellow]{example_id}[/bold yellow]"
+             example_id_panel = Panel(
+                 example_id_str,
+                 border_style="bold green",
+             )
+
+             # Combine all components vertically
+             combined_content = Group(epoch_batch_panel, example_id_panel, table)
+
+             safe_print(combined_content)
+
+     def on_train_epoch_end(self, trainer):
+         # Gather final epoch means (must be run on all ranks)
+         final_means = {
+             k: tracker.compute().item() for k, tracker in self.loss_trackers.items()
+         }
+
+         # Calculate elapsed time and number of batches (from the total_loss tracker, if available)
+         elapsed_time = time.time() - self.start_time
+         num_batches = (
+             self.loss_trackers["total_loss"].update_count
+             if "total_loss" in self.loss_trackers
+             else trainer.n_batches_per_epoch
+         )
+
+         if trainer.fabric.is_global_zero:
+             # Create a summary table
+             table = Table(
+                 title=f"Epoch {trainer.state['current_epoch']} Summary",
+                 show_header=False,
+                 header_style="bold magenta",
+             )
+             table.add_column("Loss Name", style="bold cyan", justify="left")
+             table.add_column("Value", style="green", justify="right")
+
+             for k, v in final_means.items():
+                 table.add_row(f"<Train> Mean {k}", f"{v:.4f}")
+
+             table.add_section()
+             table.add_row("Total Optimizer Steps", str(trainer.state["global_step"]))
+             table.add_row("Number of Batches", str(num_batches))
+             table.add_row("Elapsed Time (s)", f"{elapsed_time:.2f}")
+             table.add_row(
+                 "Mean Time per Batch (s)", f"{elapsed_time / num_batches:.2f}"
+             )
+
+             safe_print(table)
+
+             # Log these final epoch means (prepend "train/per_epoch_" to each key)
+             trainer.fabric.log_dict(
+                 {f"train/per_epoch_{k}": v for k, v in final_means.items()},
+                 step=trainer.state["current_epoch"],
+             )
+
+         # Reset the trackers for the next epoch
+         for metric in self.loss_trackers.values():
+             metric.reset()
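A minimal instantiation sketch for the callbacks above. The constructor arguments are the ones defined in this file; how the resulting list is handed to the trainer depends on foundry's trainer and Hydra configuration (not shown here), and the import path assumes this hunk is foundry/callbacks/train_logging.py as the file list suggests:

from foundry.callbacks.train_logging import (
    LogAF3TrainingLossesCallback,
    LogDatasetSamplingRatiosCallback,
    LogLearningRateCallback,
    LogModelParametersCallback,
)

callbacks = [
    LogModelParametersCallback(),                  # parameter-count table at fit start
    LogDatasetSamplingRatiosCallback(),            # per-epoch dataset sampling table
    LogLearningRateCallback(log_every_n=50),       # logs optimizer.param_groups[0]["lr"]
    LogAF3TrainingLossesCallback(log_every_n=10),  # mean (and optional per-structure) loss logging
]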
foundry/common.py ADDED
@@ -0,0 +1,108 @@
+ from functools import wraps
+
+ import torch
+ from beartype.typing import Any, Callable, Iterable
+ from toolz import merge_with
+
+
+ def run_once(fn: Callable) -> Callable:
+     """Decorator to ensure a function is only executed once per process.
+
+     Args:
+         fn (Callable): The function to decorate.
+
+     Returns:
+         Callable: A wrapped function that only executes once.
+     """
+
+     @wraps(fn)
+     def wrapper(*args: Any, **kwargs: Any) -> Any:
+         if getattr(wrapper, "_has_run", False):
+             return
+         wrapper._has_run = True
+         return fn(*args, **kwargs)
+
+     return wrapper
+
+
+ def do_nothing(*args: Any, **kwargs: Any) -> None:
+     """Does nothing, just returns None"""
+     pass
+
+
+ def exists(obj: Any) -> bool:
+     """True iff object is not None"""
+     return obj is not None
+
+
+ def default(obj: Any, default: Any) -> Any:
+     """Return obj if it exists, otherwise return default"""
+     return obj if exists(obj) else default
+
+
+ def exactly_one_exists(*args: object) -> bool:
+     """True iff exactly one of the arguments exists"""
+     return sum(exists(arg) for arg in args) == 1
+
+
+ def at_least_one_exists(*args: object) -> bool:
+     """True iff at least one of the arguments exists"""
+     return any(exists(arg) for arg in args)
+
+
+ def concat_dicts(*dicts: dict) -> dict:
+     """
+     Concatenate a list of dicts with the same keys into a single dict.
+
+     Example:
+         >>> d1 = {"a": 1, "b": 2}
+         >>> d2 = {"a": 3, "b": 4}
+         >>> concat_dicts(d1, d2)
+         {'a': [1, 3], 'b': [2, 4]}
+     """
+     return merge_with(list, *dicts)
+
+
+ def listmap(fn: Callable, lst: Iterable[Any]) -> list:
+     """
+     Apply a function to each element of a single list.
+
+     Args:
+         - fn (Callable): Function to apply to each element
+         - lst (list): Input list
+
+     Returns:
+         - list: Result of applying fn to each element
+
+     Example:
+         >>> listmap(lambda x: x + 1, [1, 2, 3])
+         [2, 3, 4]
+     """
+     return [fn(x) for x in lst]
+
+
+ def listmap_with_idx(fn: Callable[[int, Any], Any], lst: Iterable[Any]) -> list:
+     """Maps a function over a list while providing both index and value to the function.
+
+     A convenience wrapper around listmap that allows the mapping function to access both the index and value
+     of each element in the input list.
+
+     Args:
+         - fn (Callable[[int, Any], Any]): Function that takes two arguments (index, value) and returns a transformed value.
+         - lst (list): Input list to map over.
+
+     Returns:
+         - list: New list containing the results of applying fn to each (index, value) pair.
+
+     Example:
+         >>> def add_index(i, x):
+         ...     return f"{i}_{x}"
+         >>> listmap_with_idx(add_index, ["a", "b", "c"])
+         ['0_a', '1_b', '2_c']
+     """
+     return [fn(idx, x) for idx, x in enumerate(lst)]
+
+
+ def ensure_dtype(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+     """Convert tensor to target dtype if it's not already that dtype."""
+     return tensor if tensor.dtype == dtype else tensor.to(dtype)
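A short usage sketch for the helpers above, importable as foundry.common per the file list; the values are illustrative:

import torch

from foundry.common import concat_dicts, default, ensure_dtype, listmap_with_idx, run_once


@run_once
def announce() -> None:
    print("runs only the first time it is called in this process")


announce()
announce()  # second call is a no-op

default(None, 3)                                       # -> 3
concat_dicts({"a": 1, "b": 2}, {"a": 3, "b": 4})       # -> {"a": [1, 3], "b": [2, 4]}
listmap_with_idx(lambda i, x: f"{i}_{x}", ["a", "b"])  # -> ["0_a", "1_b"]
ensure_dtype(torch.zeros(2, dtype=torch.float16), torch.float32)  # converts only if the dtype differs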
foundry/constants.py ADDED
@@ -0,0 +1,28 @@
+ # fmt: off
+ # ... For convenience, define BKBN, or TIP to be used as a shortcut | TIP is the largest set of fixed atom given at least 2 tip atoms
+ TIP_BY_RESTYPE = {
+     "TRP": ["CG","CD1","CD2","NE1","CE2","CE3","CZ2","CZ3","CH2"], # fix both rings
+     "HIS": ["CG","ND1","CD2","CE1","NE2"], # fixed ring
+     "TYR": ["CZ","OH"], # keeps ring dihedral flexible
+     "PHE": ["CG","CD1","CD2","CE1","CE2","CZ"],
+     "ASN": ["CB", "CG","OD1","ND2"],
+     "ASP": ["CB", "CG","OD1","OD2"],
+     "GLN": ["CG", "CD","OE1","NE2"],
+     "GLU": ["CG", "CD","OE1","OE2"],
+     "CYS": ["CB", "SG"],
+     "SER": ["CB", "OG"],
+     "THR": ["CB", "OG1"],
+     "LEU": ["CB", "CG", "CD1", "CD2"],
+     "VAL": ["CG1", "CG2"],
+     "ILE": ["CB", "CG2"],
+     "MET": ["SD", "CE"],
+     "LYS": ["CE","NZ"],
+     "ARG": ["CD","NE","CZ","NH1","NH2"],
+     "PRO": None,
+     "ALA": None,
+     "GLY": None,
+     "UNK": None,
+     "MSK": None
+ }
+
+ # fmt: on
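An illustrative lookup against the table above; the tip_atoms helper is hypothetical and simply mirrors the table's convention that PRO/ALA/GLY/UNK/MSK define no tip atoms:

from foundry.constants import TIP_BY_RESTYPE


def tip_atoms(res_name: str) -> list[str]:
    """Return the tip-atom names for a residue, or an empty list when none are defined."""
    atoms = TIP_BY_RESTYPE.get(res_name)
    return list(atoms) if atoms is not None else []


tip_atoms("TYR")  # -> ["CZ", "OH"]
tip_atoms("GLY")  # -> []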