rc-foundry 0.1.1 (rc_foundry-0.1.1-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
foundry/callbacks/metrics_logging.py
ADDED

@@ -0,0 +1,211 @@
import os
from copy import deepcopy
from pathlib import Path

import pandas as pd
from atomworks.ml.utils import nested_dict
from beartype.typing import Any, Literal
from omegaconf import ListConfig

from foundry.callbacks.callback import BaseCallback
from foundry.utils.ddp import RankedLogger

ranked_logger = RankedLogger(__name__, rank_zero_only=True)


class StoreValidationMetricsInDFCallback(BaseCallback):
    """Saves the validation outputs in a DataFrame for each rank and concatenates them at the end of the validation epoch."""

    def __init__(
        self,
        save_dir: os.PathLike,
        metrics_to_save: list[str] | Literal["all"] = "all",
    ):
        self.save_dir = Path(save_dir)
        self.metrics_to_save = metrics_to_save

    def _save_dataframe_for_rank(self, rank: int, epoch: int):
        """Saves per-GPU output dataframe of metrics to a rank-specific CSV."""
        self.save_dir.mkdir(parents=True, exist_ok=True)
        file_path = self.save_dir / f"validation_output_rank_{rank}_epoch_{epoch}.csv"

        # Flush explicitly to ensure the file is written to disk
        with open(file_path, "w") as f:
            self.per_gpu_outputs_df.to_csv(f, index=False)
            f.flush()
            os.fsync(f.fileno())

        ranked_logger.info(
            f"Saved validation outputs to {file_path} for rank {rank}, epoch {epoch}"
        )

    def on_validation_epoch_start(self, trainer):
        self.per_gpu_outputs_df = pd.DataFrame()

    def on_validation_batch_end(
        self,
        trainer,
        outputs: dict,
        batch: Any,
        batch_idx: int,
        num_batches: int,
        dataset_name: str | None = None,
    ):
        """Build a flattened DataFrame from the metrics output and accumulate with the prior batches."""
        assert "metrics_output" in outputs, "Validation outputs must contain metrics."
        metrics_output = deepcopy(outputs["metrics_output"])

        # ... assemble a flat DataFrame from the metrics output
        example_id = metrics_output.pop("example_id")
        metrics_as_list_of_dicts = []

        # ... remove metrics that are not in the save list
        if self.metrics_to_save != "all" and isinstance(
            self.metrics_to_save, list | ListConfig
        ):
            metrics_output = {
                k: v
                for k, v in metrics_output.items()
                if any(k.startswith(prefix) for prefix in self.metrics_to_save)
            }

        def _build_row_from_flattened_dict(
            dict_to_flatten: dict, prefix: str, example_id: str
        ):
            """Helper function to build a DataFrame row"""
            flattened_dict = nested_dict.flatten(dict_to_flatten, fuse_keys=".")
            row_data = {"example_id": example_id}
            for sub_k, sub_v in flattened_dict.items():
                # Convert lists to tuples so that they are hashable
                if isinstance(sub_v, list):
                    sub_v = tuple(sub_v)
                row_data[f"{prefix}.{sub_k}"] = sub_v
            return row_data

        scalar_metrics = {"example_id": example_id}
        for key, value in metrics_output.items():
            if isinstance(value, dict):
                # Flatten once for this dict => 1 row.
                metrics_as_list_of_dicts.append(
                    _build_row_from_flattened_dict(value, key, example_id)
                )
            elif isinstance(value, list) and all(isinstance(x, dict) for x in value):
                # Flatten each dict in the list => multiple rows.
                for subdict in value:
                    metrics_as_list_of_dicts.append(
                        _build_row_from_flattened_dict(subdict, key, example_id)
                    )
            else:
                # Scalar (string, float, int, or list that isn't list-of-dicts)
                assert key not in scalar_metrics, f"Duplicate key: {key}"
                scalar_metrics[key] = value

        metrics_as_list_of_dicts.append(scalar_metrics)

        # ... convert the list of dicts to a DataFrame and add epoch and dataset columns
        batch_df = pd.DataFrame(metrics_as_list_of_dicts)
        batch_df["epoch"] = trainer.state["current_epoch"]
        batch_df["dataset"] = dataset_name

        # Assert no duplicate rows
        assert (
            batch_df.duplicated().sum() == 0
        ), "Duplicate rows found in the metrics DataFrame!"

        # Accumulate into the per-rank DataFrame
        self.per_gpu_outputs_df = pd.concat(
            [self.per_gpu_outputs_df, batch_df], ignore_index=True
        )

        ranked_logger.info(
            f"Validation Progress: {100 * (batch_idx + 1) / num_batches:.0f}% for {dataset_name}"
        )

    def on_validation_epoch_end(self, trainer):
        """Aggregate and log the validation metrics at the end of the epoch.

        Each rank writes out its partial CSV. Then rank 0 aggregates them, logs grouped metrics by dataset,
        and appends them to a master file containing data from all epochs.
        """

        # ... write out partial CSV for this rank
        rank = trainer.fabric.global_rank
        epoch = trainer.state["current_epoch"]
        self._save_dataframe_for_rank(rank, epoch)

        # Synchronize all processes
        ranked_logger.info(
            "Synchronizing all processes before concatenating DataFrames..."
        )
        trainer.fabric.barrier()

        # Only rank 0 loads and concatenates the DataFrames
        ranked_logger.info("Loading and concatenating DataFrames...")
        if trainer.fabric.is_global_zero:
            # ... load all partial CSVs
            merged_df = self._load_and_concatenate_csvs(epoch)

            # ... append to master CSV for all epochs
            master_path = self.save_dir / "validation_output_all_epochs.csv"
            if master_path.exists():
                old_df = pd.read_csv(master_path)
                merged_df = pd.concat(
                    [old_df, merged_df], ignore_index=True, sort=False
                )
            merged_df.to_csv(master_path, index=False)
            ranked_logger.info(f"Appended epoch={epoch} results to {master_path}")

            # Store the path to the master CSV in the Trainer
            trainer.validation_results_path = master_path

            # Cleanup
            self._cleanup_temp_files()

    def _load_and_concatenate_csvs(self, epoch: int) -> pd.DataFrame:
        """Load rank-specific CSVs for the given epoch and concatenate them without duplicating examples."""
        pattern = f"validation_output_rank_*_epoch_{epoch}.csv"
        files = list(self.save_dir.glob(pattern))

        # Track which example_id + dataset combinations we've already seen
        seen_examples = set()
        final_dataframes = []

        for f in files:
            try:
                df = pd.read_csv(f)

                # Create a filter for rows with new example_id + dataset combinations
                if not df.empty:
                    # Create a unique identifier for each example_id + dataset combination
                    df["_example_key"] = (
                        df["example_id"].astype(str) + "|" + df["dataset"].astype(str)
                    )

                    # Filter out rows with example_id + dataset combinations we've already seen
                    new_examples_mask = ~df["_example_key"].isin(seen_examples)

                    # If there are any new examples, add them to our final list
                    if new_examples_mask.any():
                        new_examples_df = df[new_examples_mask].copy()

                        # Update our set of seen examples
                        seen_examples.update(new_examples_df["_example_key"].tolist())

                        # Remove the temporary column before adding to final list
                        new_examples_df.drop("_example_key", axis=1, inplace=True)
                        final_dataframes.append(new_examples_df)

            except pd.errors.EmptyDataError:
                ranked_logger.warning(f"Skipping empty CSV: {f}")

        # Concatenate dataframes, filling missing columns with NaN
        return pd.concat(final_dataframes, axis=0, ignore_index=True, sort=False)

    def _cleanup_temp_files(self):
        """Remove temporary files used to store individual rank outputs."""
        all_files = list(self.save_dir.rglob("validation_output_rank_*_epoch_*.csv"))
        for file in all_files:
            try:
                file.unlink()  # Remove the file
            except Exception as e:
                ranked_logger.warning(f"Failed to delete file {file}: {e}")
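For orientation, the sketch below illustrates the flattening idea this callback applies per batch: dict-valued metrics become wide rows with fused column names, scalar metrics collect into one shared row, and the rows form a per-rank DataFrame. It is a minimal, standalone sketch that does not use the package; the `flatten` helper is a stand-in for `atomworks.ml.utils.nested_dict.flatten(..., fuse_keys=".")`, and the metric names are illustrative only.

```python
import pandas as pd


def flatten(d: dict, parent: str = "", sep: str = ".") -> dict:
    """Recursively fuse nested keys, e.g. {"a": {"b": 1}} -> {"a.b": 1} (stand-in helper)."""
    out = {}
    for k, v in d.items():
        key = f"{parent}{sep}{k}" if parent else k
        if isinstance(v, dict):
            out.update(flatten(v, key, sep))
        else:
            out[key] = v
    return out


# Toy per-example metrics dict (illustrative names and values only).
metrics_output = {
    "example_id": "example_0",
    "lddt": {"protein": 0.81, "ligand": 0.66},  # dict-valued metric -> one wide row
    "num_clashes": 2,                           # scalar metric -> shared scalar row
}

example_id = metrics_output.pop("example_id")
rows = []
scalars = {"example_id": example_id}

for key, value in metrics_output.items():
    if isinstance(value, dict):
        row = {"example_id": example_id}
        row.update({f"{key}.{k}": v for k, v in flatten(value).items()})
        rows.append(row)
    else:
        scalars[key] = value

rows.append(scalars)
batch_df = pd.DataFrame(rows)
print(batch_df)
# Columns: example_id, lddt.protein, lddt.ligand, num_clashes (NaN where a row has no value)
```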
foundry/callbacks/timing_logging.py
ADDED

@@ -0,0 +1,67 @@
import pandas as pd
from lightning.fabric.utilities.rank_zero import rank_zero_only

from foundry.callbacks.callback import BaseCallback
from foundry.utils.logging import print_df_as_table
from foundry.utils.torch import Timers


class TimingCallback(BaseCallback):
    """Fabric callback to print timing metrics."""

    def __init__(self, log_every_n: int = 100):
        super().__init__()
        self.log_every_n = log_every_n
        self.timers = Timers()
        self.n_steps_since_last_log = 0

    @rank_zero_only
    def on_train_epoch_start(self, trainer, **kwargs):
        self.timers.start("train_loader_iter")

    @rank_zero_only
    def on_after_train_loader_iter(self, trainer, **kwargs):
        self.timers.stop("train_loader_iter")

    @rank_zero_only
    def on_before_train_loader_next(self, trainer, **kwargs):
        self.timers.start("train_step", "train_loader_next")

    @rank_zero_only
    def on_train_batch_start(self, trainer, **kwargs):
        self.timers.start("forward_loss_backward")
        self.timers.stop("train_loader_next")

    @rank_zero_only
    def on_train_batch_end(self, trainer, **kwargs):
        self.timers.stop("forward_loss_backward")
        self.timers.stop("train_step")

    @rank_zero_only
    def on_before_optimizer_step(self, trainer, **kwargs):
        self.timers.start("optimizer_step")

    @rank_zero_only
    def on_after_optimizer_step(self, optimizer, **kwargs):
        self.timers.stop("optimizer_step")

    @rank_zero_only
    def optimizer_step(self, trainer, optimizer):
        step = trainer.state["global_step"]
        self.n_steps_since_last_log += 1
        if step % self.log_every_n == 0:
            timings = self.timers.elapsed(*self.timers.timers.keys(), reset=True)
            timings = {
                f"timings/{k}": v / self.n_steps_since_last_log
                for k, v in timings.items()
            }
            trainer.fabric.log_dict(timings, step=step)
            if trainer.fabric.is_global_zero:
                self._print_timings(timings)

    def _print_timings(self, timings: dict[str, float]):
        df = pd.DataFrame(timings.items(), columns=["Step", "Time (s)"])
        print_df_as_table(
            df, title=f"Timing stats (over {self.n_steps_since_last_log} steps)"
        )
        self.n_steps_since_last_log = 0
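`TimingCallback` relies on a `Timers` utility from `foundry.utils.torch` that exposes `start(*names)`, `stop(*names)`, `elapsed(*names, reset=...)`, and a `timers` mapping. The package's actual implementation is not shown in this diff; the following is a minimal stand-in sketch of that assumed interface, useful only to make the callback's accumulate-then-average pattern concrete.

```python
import time


class SimpleTimers:
    """Minimal stand-in (assumed interface, not foundry.utils.torch.Timers itself)."""

    def __init__(self):
        self.timers: dict[str, float] = {}   # name -> accumulated seconds
        self._started: dict[str, float] = {}  # name -> wall-clock start time

    def start(self, *names: str) -> None:
        now = time.perf_counter()
        for name in names:
            self.timers.setdefault(name, 0.0)
            self._started[name] = now

    def stop(self, *names: str) -> None:
        now = time.perf_counter()
        for name in names:
            self.timers[name] += now - self._started.pop(name)

    def elapsed(self, *names: str, reset: bool = False) -> dict[str, float]:
        out = {name: self.timers.get(name, 0.0) for name in names}
        if reset:
            for name in names:
                self.timers[name] = 0.0
        return out


timers = SimpleTimers()
timers.start("train_step", "forward_loss_backward")
time.sleep(0.01)  # pretend work for one batch
timers.stop("forward_loss_backward")
timers.stop("train_step")
print(timers.elapsed(*timers.timers.keys(), reset=True))
```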
foundry/callbacks/train_logging.py
ADDED

@@ -0,0 +1,278 @@
import time
from collections import defaultdict

import pandas as pd
from atomworks.ml.example_id import parse_example_id
from beartype.typing import Any
from lightning.fabric.wrappers import (
    _FabricOptimizer,
)
from rf3.utils.loss import convert_batched_losses_to_list_of_dicts, mean_losses
from rich.console import Group
from rich.panel import Panel
from rich.table import Table
from torchmetrics.aggregation import MeanMetric

from foundry.callbacks.callback import BaseCallback
from foundry.utils.ddp import RankedLogger
from foundry.utils.logging import (
    print_df_as_table,
    print_model_parameters,
    safe_print,
    table_from_df,
)


class LogModelParametersCallback(BaseCallback):
    """Print a table of the total and trainable parameters of the model at the start of training."""

    def on_fit_start(self, trainer):
        print_model_parameters(trainer.state["model"])


class PrintExampleIDBeforeForwardPassCallback(BaseCallback):
    """Print the example ID for each rank at the start of the forward pass for each batch.

    WARNING: Spams the console. Use only for debugging purposes.
    """

    def __init__(self, rank_zero_only: bool = True):
        self.logger = RankedLogger(__name__, rank_zero_only=rank_zero_only)

    def on_train_batch_start(self, trainer, batch: Any, batch_idx: int):
        example_id = batch[0]["example_id"]

        # Prepare the formatted strings with colors
        rank_info = f"[grey]<Rank {trainer.fabric.global_rank}>[/grey]"
        epoch_batch_info = (
            f"[blue]Epoch {trainer.state['current_epoch']} Batch {batch_idx}[/blue]"
        )
        example_id_info = f"[bold yellow]Example ID: {example_id}[/bold yellow]"

        safe_print(
            f"{rank_info} {epoch_batch_info} - {example_id_info}",
            logger=self.logger,
        )


class LogDatasetSamplingRatiosCallback(BaseCallback):
    """Monitor the sampling ratios of the datasets and log after each epoch."""

    def on_fit_start(self, trainer):
        self.dataset_sampling_counts = defaultdict(int)

    def on_train_batch_start(self, trainer, batch, batch_idx):
        example_id = batch[0]["example_id"]

        if trainer.fabric.is_global_zero:
            dataset_string = "/".join(parse_example_id(example_id)["datasets"])
            self.dataset_sampling_counts[dataset_string] += 1

    def on_train_epoch_end(self, trainer):
        if trainer.fabric.is_global_zero:
            total_samples = sum(self.dataset_sampling_counts.values())

            data = {
                "Dataset": list(self.dataset_sampling_counts.keys()),
                "Count": list(self.dataset_sampling_counts.values()),
                "Percentage": [
                    f"{(count / total_samples) * 100:.2f}%"
                    for count in self.dataset_sampling_counts.values()
                ],
            }

            print_df_as_table(
                df=pd.DataFrame(data),
                title=f"Epoch {trainer.state['current_epoch']}: Dataset Sampling Ratios",
            )

            # Reset the counts for the next epoch
            self.dataset_sampling_counts.clear()


class LogLearningRateCallback(BaseCallback):
    """Monitor the learning rate of the optimizer.

    Args:
        log_every_n: Log the learning rate every n optimizer steps.
    """

    def __init__(self, log_every_n: int):
        self.log_every_n = log_every_n

    def optimizer_step(self, trainer, optimizer: _FabricOptimizer):
        # Get the current global step
        current_step = trainer.state["global_step"]

        # Log the learning rate only every `log_every_n` steps
        if current_step % self.log_every_n == 0:
            trainer.fabric.log(
                "train/learning_rate",
                optimizer.param_groups[0]["lr"],
                step=current_step,
            )


class LogAF3TrainingLossesCallback(BaseCallback):
    """Log the primary model losses for AF3.

    Includes:
    - The mean training losses every `log_every_n` batches
    - The mean training losses at the end of each epoch
    - The time taken to complete each epoch
    - (Optionally) The full batch losses for each structure in the diffusion batch

    Args:
        log_every_n (int): Print the training loss after every n batches.
    """

    def __init__(
        self,
        log_full_batch_losses: bool = False,
        log_every_n: int = 10,
    ):
        """
        Args:
            log_full_batch_losses (bool): Log losses for every structure within the diffusion batch.
            log_every_n (int): Print the training loss after every n batches.
        """
        self.log_every_n = log_every_n
        self.log_full_batch_losses = log_full_batch_losses

        self.start_time = None
        self.logger = RankedLogger(__name__, rank_zero_only=True)

        # This dict will store key -> MeanMetric() for each loss
        self.loss_trackers = {}

    def on_train_epoch_start(self, trainer):
        # Record the start time of the epoch
        self.start_time = time.time()

    def on_train_batch_end(self, trainer, outputs: Any, batch: Any, batch_idx: int):
        mean_loss_dict = {}
        if "loss_dict" in outputs:
            mean_loss_dict.update(mean_losses(outputs["loss_dict"]))

        for key, val in mean_loss_dict.items():
            if key not in self.loss_trackers:
                self.loss_trackers[key] = trainer.fabric.to_device(MeanMetric())
            self.loss_trackers[key].update(val)

        if trainer.fabric.is_global_zero and batch_idx % self.log_every_n == 0:
            # ... log losses for each structure in the batch
            if self.log_full_batch_losses:
                full_batch_loss_dicts = convert_batched_losses_to_list_of_dicts(
                    outputs["loss_dict"]
                )
                for loss_dict in full_batch_loss_dicts:
                    loss_dict = {
                        f"train/per_structure/{k}": v for k, v in loss_dict.items()
                    }
                    trainer.fabric.log_dict(
                        loss_dict, step=trainer.state["global_step"]
                    )

            # ... log losses meaned across the batch
            # (Prepend "train/batch_mean" to the keys in the loss dictionary)
            mean_loss_dict_for_logging = {
                f"train/batch_mean/{k}": v for k, v in mean_loss_dict.items()
            }
            trainer.fabric.log_dict(
                mean_loss_dict_for_logging, step=trainer.state["global_step"]
            )

            # ... print the mean losses in a table
            df_losses = pd.DataFrame(
                {
                    "Train Loss Name": [
                        k.replace("_", " ").title() for k in mean_loss_dict.keys()
                    ],
                    "Value": [v for v in mean_loss_dict.values()],
                }
            )
            table = table_from_df(df_losses, title="Training Losses")

            # (percentage of batch count)
            percentage_complete = (batch_idx / trainer.n_batches_per_epoch) * 100

            # Simple progress bar using Unicode blocks
            progress_bar_length = 10  # Length of the progress bar
            filled_length = int(progress_bar_length * percentage_complete // 100)
            progress_bar = "█" * filled_length + "░" * (
                progress_bar_length - filled_length
            )
            percentage_str = f"[bold magenta]{percentage_complete:.2f}%[/bold magenta]"

            # Create a panel for the epoch and batch info with a progress bar
            epoch_batch_info = (
                f"[grey]<Rank {trainer.fabric.global_rank}>[/grey] "
                f"Epoch {trainer.state['current_epoch']} Batch {batch_idx} "
                f"[{progress_bar}] {percentage_str}"
            )

            epoch_batch_panel = Panel(
                epoch_batch_info,
                border_style="bold blue",
            )

            # Create a panel for the example ID
            example_id = batch[0]["example_id"]
            example_id_str = f"[bold yellow]{example_id}[/bold yellow]"
            example_id_panel = Panel(
                example_id_str,
                border_style="bold green",
            )

            # Combine all components vertically
            combined_content = Group(epoch_batch_panel, example_id_panel, table)

            safe_print(combined_content)

    def on_train_epoch_end(self, trainer):
        # Gather final epoch means (must be run on all ranks)
        final_means = {
            k: tracker.compute().item() for k, tracker in self.loss_trackers.items()
        }

        # Calculate elapsed time and number of batches (from the total_loss tracker, if available)
        elapsed_time = time.time() - self.start_time
        num_batches = (
            self.loss_trackers["total_loss"].update_count
            if "total_loss" in self.loss_trackers
            else trainer.n_batches_per_epoch
        )

        if trainer.fabric.is_global_zero:
            # Create a summary table
            table = Table(
                title=f"Epoch {trainer.state['current_epoch']} Summary",
                show_header=False,
                header_style="bold magenta",
            )
            table.add_column("Loss Name", style="bold cyan", justify="left")
            table.add_column("Value", style="green", justify="right")

            for k, v in final_means.items():
                table.add_row(f"<Train> Mean {k}", f"{v:.4f}")

            table.add_section()
            table.add_row("Total Optimizer Steps", str(trainer.state["global_step"]))
            table.add_row("Number of Batches", str(num_batches))
            table.add_row("Elapsed Time (s)", f"{elapsed_time:.2f}")
            table.add_row(
                "Mean Time per Batch (s)", f"{elapsed_time / num_batches:.2f}"
            )

            safe_print(table)

            # Log these final epoch means (prepend "train/per_epoch_" to each key)
            trainer.fabric.log_dict(
                {f"train/per_epoch_{k}": v for k, v in final_means.items()},
                step=trainer.state["current_epoch"],
            )

        # Reset the trackers for the next epoch
        for metric in self.loss_trackers.values():
            metric.reset()
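The per-epoch loss averaging in `LogAF3TrainingLossesCallback` is driven by `torchmetrics.aggregation.MeanMetric`: one running mean per loss name, updated every batch and computed then reset at the end of the epoch. A minimal, self-contained sketch of that pattern (with made-up loss names and values, and no trainer or Fabric device handling) follows.

```python
import torch
from torchmetrics.aggregation import MeanMetric

# One running mean per loss name, mirroring the callback's `loss_trackers` dict.
loss_trackers: dict[str, MeanMetric] = {}

# Pretend per-batch mean losses (illustrative values only).
batches = [
    {"total_loss": 3.2, "distogram_loss": 1.1},
    {"total_loss": 2.9, "distogram_loss": 0.9},
    {"total_loss": 2.7, "distogram_loss": 0.8},
]

for mean_loss_dict in batches:
    for key, val in mean_loss_dict.items():
        if key not in loss_trackers:
            loss_trackers[key] = MeanMetric()  # the callback additionally moves this to the fabric device
        loss_trackers[key].update(torch.tensor(val))

# End of epoch: compute the running means, then reset for the next epoch.
final_means = {k: tracker.compute().item() for k, tracker in loss_trackers.items()}
print({f"train/per_epoch_{k}": round(v, 4) for k, v in final_means.items()})
for tracker in loss_trackers.values():
    tracker.reset()
```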
foundry/common.py
ADDED
@@ -0,0 +1,108 @@
from functools import wraps

import torch
from beartype.typing import Any, Callable, Iterable
from toolz import merge_with


def run_once(fn: Callable) -> Callable:
    """Decorator to ensure a function is only executed once per process.

    Args:
        fn (Callable): The function to decorate.

    Returns:
        Callable: A wrapped function that only executes once.
    """

    @wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if getattr(wrapper, "_has_run", False):
            return
        wrapper._has_run = True
        return fn(*args, **kwargs)

    return wrapper


def do_nothing(*args: Any, **kwargs: Any) -> None:
    """Does nothing, just returns None"""
    pass


def exists(obj: Any) -> bool:
    """True iff object is not None"""
    return obj is not None


def default(obj: Any, default: Any) -> Any:
    """Return obj if it exists, otherwise return default"""
    return obj if exists(obj) else default


def exactly_one_exists(*args: object) -> bool:
    """True iff exactly one of the arguments exists"""
    return sum(exists(arg) for arg in args) == 1


def at_least_one_exists(*args: object) -> bool:
    """True iff at least one of the arguments exists"""
    return any(exists(arg) for arg in args)


def concat_dicts(*dicts: dict) -> dict:
    """
    Concatenate a list of dicts with the same keys into a single dict.

    Example:
        >>> d1 = {"a": 1, "b": 2}
        >>> d2 = {"a": 3, "b": 4}
        >>> concat_dicts(d1, d2)
        {'a': [1, 3], 'b': [2, 4]}
    """
    return merge_with(list, *dicts)


def listmap(fn: Callable, lst: Iterable[Any]) -> list:
    """
    Apply a function to each element of a single list.

    Args:
        - fn (Callable): Function to apply to each element
        - lst (list): Input list

    Returns:
        - list: Result of applying fn to each element

    Example:
        >>> listmap(lambda x: x + 1, [1, 2, 3])
        [2, 3, 4]
    """
    return [fn(x) for x in lst]


def listmap_with_idx(fn: Callable[[int, Any], Any], lst: Iterable[Any]) -> list:
    """Maps a function over a list while providing both index and value to the function.

    A convenience wrapper around listmap that allows the mapping function to access both the index and value
    of each element in the input list.

    Args:
        - fn (Callable[[int, Any], Any]): Function that takes two arguments (index, value) and returns a transformed value.
        - lst (list): Input list to map over.

    Returns:
        - list: New list containing the results of applying fn to each (index, value) pair.

    Example:
        >>> def add_index(i, x):
        ...     return f"{i}_{x}"
        >>> listmap_with_idx(add_index, ["a", "b", "c"])
        ['0_a', '1_b', '2_c']
    """
    return [fn(idx, x) for idx, x in enumerate(lst)]


def ensure_dtype(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Convert tensor to target dtype if it's not already that dtype."""
    return tensor if tensor.dtype == dtype else tensor.to(dtype)
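Assuming the wheel is installed and `foundry.common` is importable as listed above, these helpers compose as in the short sketch below (behavior follows directly from the definitions; `announce` is just an illustrative function).

```python
from foundry.common import concat_dicts, default, exactly_one_exists, run_once

# `default` falls back only when the value is None (not merely falsy).
assert default(None, 10) == 10
assert default(0, 10) == 0

# `exactly_one_exists` is handy for validating mutually exclusive arguments.
assert exactly_one_exists("config.yaml", None, None)
assert not exactly_one_exists(None, None)

# `concat_dicts` zips per-example dicts into lists per key.
assert concat_dicts({"a": 1, "b": 2}, {"a": 3, "b": 4}) == {"a": [1, 3], "b": [2, 4]}


@run_once
def announce() -> None:
    print("runs once per process")


announce()  # prints
announce()  # silently skipped on subsequent calls
```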
foundry/constants.py
ADDED
@@ -0,0 +1,28 @@
# fmt: off
# ... For convenience, define BKBN, or TIP to be used as a shortcut | TIP is the largest set of fixed atoms given at least 2 tip atoms
TIP_BY_RESTYPE = {
    "TRP": ["CG","CD1","CD2","NE1","CE2","CE3","CZ2","CZ3","CH2"],  # fix both rings
    "HIS": ["CG","ND1","CD2","CE1","NE2"],  # fixed ring
    "TYR": ["CZ","OH"],  # keeps ring dihedral flexible
    "PHE": ["CG","CD1","CD2","CE1","CE2","CZ"],
    "ASN": ["CB", "CG","OD1","ND2"],
    "ASP": ["CB", "CG","OD1","OD2"],
    "GLN": ["CG", "CD","OE1","NE2"],
    "GLU": ["CG", "CD","OE1","OE2"],
    "CYS": ["CB", "SG"],
    "SER": ["CB", "OG"],
    "THR": ["CB", "OG1"],
    "LEU": ["CB", "CG", "CD1", "CD2"],
    "VAL": ["CG1", "CG2"],
    "ILE": ["CB", "CG2"],
    "MET": ["SD", "CE"],
    "LYS": ["CE","NZ"],
    "ARG": ["CD","NE","CZ","NH1","NH2"],
    "PRO": None,
    "ALA": None,
    "GLY": None,
    "UNK": None,
    "MSK": None
}

# fmt: on
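Since several residue types map to `None` (no fixed tip atoms defined), lookups benefit from a small guard. The helper below is illustrative only and not part of the package; it assumes the wheel is installed so `foundry.constants` is importable.

```python
from foundry.constants import TIP_BY_RESTYPE


def tip_atoms(res_name: str) -> list[str]:
    """Return the fixed 'tip' atom names for a residue, or [] when none are defined (e.g. GLY, ALA)."""
    atoms = TIP_BY_RESTYPE.get(res_name)
    return list(atoms) if atoms is not None else []


print(tip_atoms("TYR"))  # ['CZ', 'OH']
print(tip_atoms("GLY"))  # []
```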