rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
foundry/utils/logging.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import warnings
|
|
3
|
+
from contextlib import contextmanager
|
|
4
|
+
|
|
5
|
+
import pandas as pd
|
|
6
|
+
from beartype.typing import Any
|
|
7
|
+
from lightning_fabric.utilities import rank_zero_only
|
|
8
|
+
from omegaconf import DictConfig, OmegaConf
|
|
9
|
+
from rich.console import Console
|
|
10
|
+
from rich.syntax import Syntax
|
|
11
|
+
from rich.table import Table
|
|
12
|
+
from rich.tree import Tree
|
|
13
|
+
from torch import nn
|
|
14
|
+
|
|
15
|
+
from foundry.utils.ddp import RankedLogger
|
|
16
|
+
|
|
17
|
+
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class CachedDataFilter(logging.Filter):
    """Logging filter that drops atomworks "Cached data not found" records."""

    def filter(self, record):
        # Keep every record except the noisy cache-miss messages.
        return "Cached data not found" not in record.getMessage()
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def silence_warnings():
    """Silence common warnings that appear during foundry execution."""
    # Each entry holds keyword arguments forwarded verbatim to
    # warnings.filterwarnings("ignore", ...); insertion order matches the
    # original call order so filter precedence is unchanged.
    suppressed = [
        dict(message="All-NaN slice encountered", category=RuntimeWarning),
        dict(
            message="Category 'chem_comp_bond' not found. No bonds will be parsed",
            category=UserWarning,
            module="biotite.structure.io.pdbx.convert",
        ),
        dict(
            message="torch.get_autocast_gpu_dtype\\(\\) is deprecated.*",
            category=DeprecationWarning,
            module="cuequivariance_ops_torch.triangle_attention",
        ),
        dict(
            message=".*multi-threaded.*fork.*may lead to deadlocks.*",
            category=DeprecationWarning,
        ),
        dict(
            message=".*is_pyramidine.*deprecated.*Use.*is_pyrimidine.*",
            category=DeprecationWarning,
        ),
        dict(
            message=".*index_reduce.*is in beta.*API may change.*",
            category=UserWarning,
        ),
    ]
    for spec in suppressed:
        warnings.filterwarnings("ignore", **spec)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@contextmanager
def suppress_warnings(is_inference: bool = False):
    """Context manager to suppress specific warnings within its scope.

    Args:
        is_inference: If True, also suppress inference-specific logging messages
            (e.g., atomworks cached data warnings).

    Required to suppress warnings within multiprocessing contexts; e.g., `torch.multiprocessing.spawn`.
    """
    cache_filter = None
    try:
        # catch_warnings restores the global warning filters on exit, so the
        # ignores installed by silence_warnings() stay scoped to this block.
        with warnings.catch_warnings():
            silence_warnings()
            if is_inference:
                # Hide the "Cached data not found" records emitted by atomworks.
                cache_filter = CachedDataFilter()
                logging.getLogger("atomworks.ml").addFilter(cache_filter)
            yield
    finally:
        # Detach the logging filter even if the body raised.
        if cache_filter is not None:
            logging.getLogger("atomworks.ml").removeFilter(cache_filter)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def configure_minimal_inference_logging() -> None:
|
|
99
|
+
"""Configure minimal logging for inference (quiet mode)."""
|
|
100
|
+
# Suppress most logging by default
|
|
101
|
+
logging.getLogger().setLevel(logging.WARNING)
|
|
102
|
+
|
|
103
|
+
# Suppress specific noisy loggers
|
|
104
|
+
for logger_name in [
|
|
105
|
+
"atomworks",
|
|
106
|
+
"transforms",
|
|
107
|
+
"foundry.metrics",
|
|
108
|
+
"foundry.trainers",
|
|
109
|
+
"foundry.inference_engines.base",
|
|
110
|
+
"rf3.trainers",
|
|
111
|
+
"rf3.utils.inference",
|
|
112
|
+
"lightning",
|
|
113
|
+
"lightning_fabric",
|
|
114
|
+
"root",
|
|
115
|
+
]:
|
|
116
|
+
logging.getLogger(logger_name).setLevel(logging.ERROR)
|
|
117
|
+
|
|
118
|
+
# Keep model-specific inference engine loggers at INFO for progress messages
|
|
119
|
+
logging.getLogger("rf3.inference_engines").setLevel(logging.INFO)
|
|
120
|
+
logging.getLogger("rfd3.engine").setLevel(logging.INFO)
|
|
121
|
+
|
|
122
|
+
# Suppress warnings
|
|
123
|
+
warnings.filterwarnings("ignore")
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
@rank_zero_only
def print_config_tree(
    cfg: DictConfig,
    resolve: bool = False,
    console_width: int = 100,
    title: str = "CONFIG",
) -> None:
    """Prints the contents of a DictConfig as a tree structure using the Rich library.

    Args:
        cfg (DictConfig): A DictConfig composed by Hydra.
        resolve (bool): Whether to resolve reference fields of DictConfig. Default is False.
        console_width (int): The width of the console for printing. Default is 100.
    """
    style = "dim"
    tree = Tree(title, style=style, guide_style=style)

    # One branch per top-level field, in the config's natural order.
    for field in cfg:
        group = cfg[field]
        rendered = (
            OmegaConf.to_yaml(group, resolve=resolve)
            if isinstance(group, DictConfig)
            else str(group)
        )
        branch = tree.add(field, style=style, guide_style=style)
        branch.add(Syntax(rendered, "yaml", word_wrap=True))

    # Print config tree using Rich's Console.
    # (This call happens before instantiating other loggers, so we don't try to capture the output)
    Console(width=console_width).print(tree)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
@rank_zero_only
def print_model_parameters(model: nn.Module, name: str = "") -> None:
    """Prints the total and trainable parameters of a PyTorch model.

    Args:
        model (nn.Module): The PyTorch model to analyze.
    """
    # Single pass over the parameters; count trainable ones alongside.
    total = 0
    trainable = 0
    for param in model.parameters():
        n = param.numel()
        total += n
        if param.requires_grad:
            trainable += n

    summary = pd.DataFrame(
        {
            "Type": ["Total Parameters", "Trainable Parameters"],
            "Count": [total, trainable],
        }
    )
    table_title = f"Model Parameters: {name}" if name else "Model Parameters"
    print_df_as_table(summary, title=table_title)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def log_hyperparameters_with_all_loggers(
    trainer: Any, cfg: dict | DictConfig, model: Any
):
    """Logs hyperparameters using all loggers in the trainer.

    Args:
        trainer: The training object containing loggers.
        cfg: Configuration dictionary containing hyperparameters.
        model: The model to be tracked by loggers like WandbLogger.
    """
    # Normalize DictConfig input into a plain dictionary.
    if isinstance(cfg, DictConfig):
        cfg = OmegaConf.to_container(cfg, resolve=True)

    # Every Fabric logger must implement `log_hyperparams`.
    # For Abstract Base Class of Fabric `Loggers`, see: https://lightning.ai/docs/fabric/stable/_modules/lightning/fabric/loggers/logger.html#Logger
    for fabric_logger in trainer.fabric.loggers:
        assert hasattr(
            fabric_logger, "log_hyperparams"
        ), f"Logger {fabric_logger} does not have a `log_hyperparams` method. Ensure that the logger is a subclass of Fabric's ABC `Logger`."
        try:
            fabric_logger.log_hyperparams(cfg)
        except NotImplementedError:
            # Some loggers deliberately opt out of hyperparameter logging.
            pass
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def condense_count_columns_of_grouped_df(df: pd.DataFrame) -> pd.DataFrame:
    """Returns modified DF with single Count column if valid, otherwise original DF.

    Helpful to avoid repeating count columns in a DataFrame with multi-level columns.
    """
    # Only applies to grouped frames with (metric, statistic) column pairs.
    if not isinstance(df.columns, pd.MultiIndex):
        return df

    try:
        counts = df.xs("count", level=1, axis=1)
        means = df.xs("mean", level=1, axis=1)

        # Every row must carry the same count across all metrics, else keep as-is.
        if not (counts.nunique(axis=1) == 1).all():
            return df

        condensed = means.rename(columns=lambda c: f"{c} (mean)")
        condensed["Count"] = counts.iloc[:, 0].astype(int)
        return condensed
    except (KeyError, IndexError):
        # Missing "count"/"mean" level or empty selection: leave input untouched.
        return df
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def table_from_df(df: pd.DataFrame, title: str) -> Table:
    """Create a Rich Table from a DataFrame."""
    table = Table(title=title, show_header=True, header_style="bold cyan")

    for col in df.columns:
        table.add_column(col, justify="right", style="magenta", overflow="fold")

    # Choose one formatter per column based on its dtype (ints get thousands
    # separators, floats get 4 decimal places, everything else plain str).
    def _formatter(series):
        if pd.api.types.is_integer_dtype(series):
            return lambda v: f"{int(v):,}"
        if pd.api.types.is_float_dtype(series):
            return lambda v: f"{float(v):,.4f}"
        return str

    formatters = {col: _formatter(df[col]) for col in df.columns}

    for _, row in df.iterrows():
        table.add_row(*(formatters[col](row[col]) for col in df.columns))

    return table
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def safe_print(obj: Any, console_width=100, logger: Any | None = None) -> None:
    """Print a Rich object in a console- and logger-safe manner."""
    # Render without terminal escapes so the text survives plain log files.
    console = Console(force_terminal=False, color_system=None, width=console_width)

    with console.capture() as capture:
        console.print(obj)

    # Route through the caller's logger when given, otherwise the module's
    # rank-zero logger.
    target = logger if logger else ranked_logger
    target.info(f"\n{capture.get()}")
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def print_df_as_table(df: pd.DataFrame, title: str, console_width: int = 100) -> None:
    """Pretty-print a DataFrame using Rich Table"""
    rich_table = table_from_df(df=df, title=title)
    safe_print(rich_table, console_width=console_width)
|