rc_foundry-0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rf3/callbacks/metrics_logging.py
@@ -0,0 +1,324 @@
import os
from copy import deepcopy
from pathlib import Path

import pandas as pd
from atomworks.ml.utils import nested_dict
from beartype.typing import Any, Literal
from omegaconf import ListConfig

from foundry.callbacks.callback import BaseCallback
from foundry.utils.ddp import RankedLogger
from foundry.utils.logging import (
    condense_count_columns_of_grouped_df,
    print_df_as_table,
)

ranked_logger = RankedLogger(__name__, rank_zero_only=True)


class StoreValidationMetricsInDFCallback(BaseCallback):
    """Saves the validation outputs in a DataFrame for each rank and concatenates them at the end of the validation epoch."""

    def __init__(
        self,
        save_dir: os.PathLike,
        metrics_to_save: list[str] | Literal["all"] = "all",
    ):
        self.save_dir = Path(save_dir)
        self.metrics_to_save = metrics_to_save

    def _save_dataframe_for_rank(self, rank: int, epoch: int):
        """Saves per-GPU output dataframe of metrics to a rank-specific CSV."""
        self.save_dir.mkdir(parents=True, exist_ok=True)
        file_path = self.save_dir / f"validation_output_rank_{rank}_epoch_{epoch}.csv"

        # Flush explicitly to ensure the file is written to disk
        with open(file_path, "w") as f:
            self.per_gpu_outputs_df.to_csv(f, index=False)
            f.flush()
            os.fsync(f.fileno())

        ranked_logger.info(
            f"Saved validation outputs to {file_path} for rank {rank}, epoch {epoch}"
        )

    def on_validation_epoch_start(self, trainer):
        self.per_gpu_outputs_df = pd.DataFrame()

    def on_validation_batch_end(
        self,
        trainer,
        outputs: dict,
        batch: Any,
        batch_idx: int,
        num_batches: int,
        dataset_name: str | None = None,
    ):
        """Build a flattened DataFrame from the metrics output and accumulate with the prior batches"""
        assert "metrics_output" in outputs, "Validation outputs must contain metrics."
        metrics_output = deepcopy(outputs["metrics_output"])

        # ... assemble a flat DataFrame from the metrics output
        example_id = metrics_output.pop("example_id")
        metrics_as_list_of_dicts = []

        # ... remove metrics that are not in the save list
        if self.metrics_to_save != "all" and isinstance(
            self.metrics_to_save, list | ListConfig
        ):
            metrics_output = {
                k: v
                for k, v in metrics_output.items()
                if any(k.startswith(prefix) for prefix in self.metrics_to_save)
            }

        def _build_row_from_flattened_dict(
            dict_to_flatten: dict, prefix: str, example_id: str
        ):
            """Helper function to build a DataFrame row"""
            flattened_dict = nested_dict.flatten(dict_to_flatten, fuse_keys=".")
            row_data = {"example_id": example_id}
            for sub_k, sub_v in flattened_dict.items():
                # Convert lists to tuples so that they are hashable
                if isinstance(sub_v, list):
                    sub_v = tuple(sub_v)
                row_data[f"{prefix}.{sub_k}"] = sub_v
            return row_data

        scalar_metrics = {"example_id": example_id}
        for key, value in metrics_output.items():
            if isinstance(value, dict):
                # Flatten once for this dict => 1 row.
                metrics_as_list_of_dicts.append(
                    _build_row_from_flattened_dict(value, key, example_id)
                )
            elif isinstance(value, list) and all(isinstance(x, dict) for x in value):
                # Flatten each dict in the list => multiple rows.
                for subdict in value:
                    metrics_as_list_of_dicts.append(
                        _build_row_from_flattened_dict(subdict, key, example_id)
                    )
            else:
                # Scalar (string, float, int, or list that isn't list-of-dicts)
                assert key not in scalar_metrics, f"Duplicate key: {key}"
                scalar_metrics[key] = value

        metrics_as_list_of_dicts.append(scalar_metrics)

        # ... convert the list of dicts to a DataFrame and add epoch and dataset columns
        batch_df = pd.DataFrame(metrics_as_list_of_dicts)
        batch_df["epoch"] = trainer.state["current_epoch"]
        batch_df["dataset"] = dataset_name

        # Assert no duplicate rows
        assert (
            batch_df.duplicated().sum() == 0
        ), "Duplicate rows found in the metrics DataFrame!"

        # Accumulate into the per-rank DataFrame
        self.per_gpu_outputs_df = pd.concat(
            [self.per_gpu_outputs_df, batch_df], ignore_index=True
        )

        ranked_logger.info(
            f"Validation Progress: {100 * (batch_idx + 1) / num_batches:.0f}% for {dataset_name}"
        )

    def on_validation_epoch_end(self, trainer):
        """Aggregate and log the validation metrics at the end of the epoch.

        Each rank writes out its partial CSV. Then rank 0 aggregates them, logs grouped metrics by dataset,
        and appends them to a master file containing data from all epochs.
        """

        # ... write out partial CSV for this rank
        rank = trainer.fabric.global_rank
        epoch = trainer.state["current_epoch"]
        self._save_dataframe_for_rank(rank, epoch)

        # Synchronize all processes
        ranked_logger.info(
            "Synchronizing all processes before concatenating DataFrames..."
        )
        trainer.fabric.barrier()

        # Only rank 0 loads and concatenates the DataFrames
        ranked_logger.info("Loading and concatenating DataFrames...")
        if trainer.fabric.is_global_zero:
            # ... load all partial CSVs
            merged_df = self._load_and_concatenate_csvs(epoch)

            # ... append to master CSV for all epochs
            master_path = self.save_dir / "validation_output_all_epochs.csv"
            if master_path.exists():
                old_df = pd.read_csv(master_path)
                merged_df = pd.concat(
                    [old_df, merged_df], ignore_index=True, sort=False
                )
            merged_df.to_csv(master_path, index=False)
            ranked_logger.info(f"Appended epoch={epoch} results to {master_path}")

            # Store the path to the master CSV in the Trainer
            trainer.validation_results_path = master_path

            # Cleanup
            self._cleanup_temp_files()

    def _load_and_concatenate_csvs(self, epoch: int) -> pd.DataFrame:
        """Load rank-specific CSVs for the given epoch and concatenate them without duplicating examples."""
        pattern = f"validation_output_rank_*_epoch_{epoch}.csv"
        files = list(self.save_dir.glob(pattern))

        # Track which example_id + dataset combinations we've already seen
        seen_examples = set()
        final_dataframes = []

        for f in files:
            try:
                df = pd.read_csv(f)

                # Create a filter for rows with new example_id + dataset combinations
                if not df.empty:
                    # Create a unique identifier for each example_id + dataset combination
                    df["_example_key"] = (
                        df["example_id"].astype(str) + "|" + df["dataset"].astype(str)
                    )

                    # Filter out rows with example_id + dataset combinations we've already seen
                    new_examples_mask = ~df["_example_key"].isin(seen_examples)

                    # If there are any new examples, add them to our final list
                    if new_examples_mask.any():
                        new_examples_df = df[new_examples_mask].copy()

                        # Update our set of seen examples
                        seen_examples.update(new_examples_df["_example_key"].tolist())

                        # Remove the temporary column before adding to final list
                        new_examples_df.drop("_example_key", axis=1, inplace=True)
                        final_dataframes.append(new_examples_df)

            except pd.errors.EmptyDataError:
                ranked_logger.warning(f"Skipping empty CSV: {f}")

        # Concatenate dataframes, filling missing columns with NaN
        return pd.concat(final_dataframes, axis=0, ignore_index=True, sort=False)

    def _cleanup_temp_files(self):
        """Remove temporary files used to store individual rank outputs."""
        all_files = list(self.save_dir.rglob("validation_output_rank_*_epoch_*.csv"))
        for file in all_files:
            try:
                file.unlink()  # Remove the file
            except Exception as e:
                ranked_logger.warning(f"Failed to delete file {file}: {e}")


class LogAF3ValidationMetricsCallback(BaseCallback):
    def __init__(
        self,
        metrics_to_log: list[str] | Literal["all"] = "all",
    ):
        self.metrics_to_log = metrics_to_log

    def on_validation_epoch_end(self, trainer):
        # Only log metrics to disk if this is the global zero rank
        if not trainer.fabric.is_global_zero:
            return

        assert hasattr(
            trainer, "validation_results_path"
        ), "Results path not found! Ensure that StoreValidationMetricsInDFCallback is called first."
        df = pd.read_csv(trainer.validation_results_path)

        # ... filter to most recent epoch, drop epoch column
        df = df[df["epoch"] == df["epoch"].max()]
        df.drop(columns=["epoch", "example_id"], inplace=True)

        # ... filter to columns that start with the metrics_to_log prefixes (and "dataset")
        if self.metrics_to_log != "all" and isinstance(
            self.metrics_to_log, list | ListConfig
        ):
            df = df[
                [
                    col
                    for col in df.columns
                    if any(col.startswith(prefix) for prefix in self.metrics_to_log)
                ]
                + ["dataset"]
            ]

        for dataset in df["dataset"].unique():
            dataset_df = df[df["dataset"] == dataset].copy()
            dataset_df.drop(columns=["dataset"], inplace=True)

            print(f"\n+{' ' + dataset + ' ':-^150}+\n")

            # +------------- LDDT by type (chain, interface) -------------+
            by_type_lddt_cols = [
                col for col in df.columns if col.startswith("by_type_lddt")
            ]
            if by_type_lddt_cols:
                # ... build by-type DataFrame
                by_type_df = dataset_df[by_type_lddt_cols].copy()
                by_type_df = by_type_df.dropna(how="all")

                # ... remove the "by_type_lddt." prefix
                by_type_df.columns = by_type_df.columns.str.replace("by_type_lddt.", "")
                numeric_cols = by_type_df.select_dtypes(include="number").columns

                # ... group by type
                grouped = by_type_df.groupby("type")[numeric_cols].agg(
                    ["mean", "count"]
                )
                print_df_as_table(
                    condense_count_columns_of_grouped_df(grouped).reset_index(),
                    f"{dataset} — Epoch {trainer.state['current_epoch']} — Validation Metrics: LDDT by Type",
                )

                # Log the grouped metrics (aggregated from all ranks) with Fabric
                if trainer.fabric:
                    for _, row in grouped.reset_index().iterrows():
                        trainer.fabric.log_dict(
                            {
                                f"val/{dataset}/{row['type'].iloc[0]}/{col}": row[col][
                                    "mean"
                                ]
                                for col in numeric_cols
                            },
                            step=trainer.state["current_epoch"],
                        )

            # +----------------- Other metrics -----------------+
            remaining_cols = list(set(dataset_df.columns) - set(by_type_lddt_cols))
            remaining_df = dataset_df[remaining_cols].copy()
            remaining_df = remaining_df.dropna(how="all", axis=0)
            remaining_df = remaining_df.dropna(
                how="all", axis=1
            )  # If a Metric is all NaNs for this dataset, drop it
            numeric_cols = remaining_df.select_dtypes(include="number").columns

            # Compute means and non-NaN counts for numeric columns
            final_means = remaining_df[numeric_cols].mean()
            non_nan_counts = remaining_df[numeric_cols].count()

            # Convert the Series to a DataFrame and add the count as a new column
            final_means_df = final_means.to_frame(name="mean")
            final_means_df["Count"] = non_nan_counts

            # ... sort, so the rows are alphabetical
            final_means_df.sort_index(inplace=True)

            print_df_as_table(
                final_means_df.reset_index(),
                f"{dataset} — {trainer.state['current_epoch']} — General Validation Metrics",
                console_width=150,
            )

            if trainer.fabric:
                for col in numeric_cols:
                    trainer.fabric.log_dict(
                        {f"val/{dataset}/{col}": final_means[col]},
                        step=trainer.state["current_epoch"],
                    )
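
Note on usage: the hunk above matches rf3/callbacks/metrics_logging.py, the only file in the listing that adds exactly 324 lines. The two callbacks are order-dependent: StoreValidationMetricsInDFCallback writes per-rank CSVs, merges them on rank 0, and records the merged path as trainer.validation_results_path, which LogAF3ValidationMetricsCallback asserts on before reading. A minimal wiring sketch follows; the callbacks list, save_dir value, and metric prefixes are illustrative assumptions, since the trainer constructor in foundry/trainers/fabric.py is not shown in this diff.

from rf3.callbacks.metrics_logging import (
    LogAF3ValidationMetricsCallback,
    StoreValidationMetricsInDFCallback,
)

# Order matters: the store callback must run first so that
# `trainer.validation_results_path` exists when the logging callback reads it.
# The save_dir and prefix values below are hypothetical.
callbacks = [
    StoreValidationMetricsInDFCallback(
        save_dir="outputs/validation",
        metrics_to_save="all",
    ),
    LogAF3ValidationMetricsCallback(
        metrics_to_log=["lddt", "by_type_lddt"],
    ),
]

Column names follow the flattening convention in _build_row_from_flattened_dict: a nested metric such as {"lddt": {"chain": 0.91}} becomes the column "lddt.chain", so a prefix filter like "lddt" selects the whole family of columns.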