rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,279 @@
1
+ import logging
2
+ import warnings
3
+ from contextlib import contextmanager
4
+
5
+ import pandas as pd
6
+ from beartype.typing import Any
7
+ from lightning_fabric.utilities import rank_zero_only
8
+ from omegaconf import DictConfig, OmegaConf
9
+ from rich.console import Console
10
+ from rich.syntax import Syntax
11
+ from rich.table import Table
12
+ from rich.tree import Tree
13
+ from torch import nn
14
+
15
+ from foundry.utils.ddp import RankedLogger
16
+
17
+ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
18
+
19
+
20
class CachedDataFilter(logging.Filter):
    """Logging filter that drops atomworks "Cached data not found" records."""

    def filter(self, record):
        """Decide whether ``record`` should be emitted.

        Args:
            record: The ``logging.LogRecord`` being considered.

        Returns:
            bool: ``False`` when the formatted message contains
                "Cached data not found" (the record is dropped), ``True`` otherwise.
        """
        # getMessage() applies %-style args, so the check sees the final text.
        return "Cached data not found" not in record.getMessage()
28
+
29
+
30
def silence_warnings():
    """Register "ignore" filters for common warnings seen during foundry execution.

    The filters are added to the current ``warnings`` filter list, so they stay
    active until that list is reset (for example when an enclosing
    ``warnings.catch_warnings()`` block exits).
    """
    # One entry per filter: keyword arguments forwarded to
    # ``warnings.filterwarnings("ignore", ...)``. Registration order is kept
    # as listed; ``message`` and ``module`` are regular expressions.
    ignored_warnings = (
        {"message": "All-NaN slice encountered", "category": RuntimeWarning},
        {
            "message": "Category 'chem_comp_bond' not found. No bonds will be parsed",
            "category": UserWarning,
            "module": "biotite.structure.io.pdbx.convert",
        },
        {
            "message": "torch.get_autocast_gpu_dtype\\(\\) is deprecated.*",
            "category": DeprecationWarning,
            "module": "cuequivariance_ops_torch.triangle_attention",
        },
        {
            "message": ".*multi-threaded.*fork.*may lead to deadlocks.*",
            "category": DeprecationWarning,
        },
        {
            "message": ".*is_pyramidine.*deprecated.*Use.*is_pyrimidine.*",
            "category": DeprecationWarning,
        },
        {
            "message": ".*index_reduce.*is in beta.*API may change.*",
            "category": UserWarning,
        },
    )

    for filter_kwargs in ignored_warnings:
        warnings.filterwarnings("ignore", **filter_kwargs)
67
+
68
+
69
@contextmanager
def suppress_warnings(is_inference: bool = False):
    """Suppress the known-noisy warnings for the duration of a ``with`` block.

    Required to suppress warnings within multiprocessing contexts; e.g.,
    `torch.multiprocessing.spawn`.

    Args:
        is_inference: If True, also suppress inference-specific logging messages
            (e.g., atomworks cached data warnings) by attaching a
            ``CachedDataFilter`` to the "atomworks.ml" logger while the block runs.
    """
    # catch_warnings() restores the previous warning filters on exit, so the
    # filters installed by silence_warnings() only live inside this scope.
    with warnings.catch_warnings():
        silence_warnings()

        if not is_inference:
            yield
            return

        # Inference mode: temporarily attach the cached-data log filter.
        log_filter = CachedDataFilter()
        ml_logger = logging.getLogger("atomworks.ml")
        ml_logger.addFilter(log_filter)
        try:
            yield
        finally:
            # Always detach the filter, even if the wrapped block raised.
            ml_logger.removeFilter(log_filter)
96
+
97
+
98
def configure_minimal_inference_logging() -> None:
    """Put logging into a quiet mode suitable for inference runs.

    Sets the root logger to WARNING, forces a list of noisy loggers to ERROR,
    keeps the model-specific inference engine loggers at INFO so progress
    messages still show, and ignores all Python warnings.
    """
    # Loggers that are too chatty for quiet mode.
    # NOTE(review): "root" is in this list; on Python versions where
    # getLogger("root") returns the root logger, this overrides the WARNING
    # level set just below — confirm that is intended.
    noisy_loggers = (
        "atomworks",
        "transforms",
        "foundry.metrics",
        "foundry.trainers",
        "foundry.inference_engines.base",
        "rf3.trainers",
        "rf3.utils.inference",
        "lightning",
        "lightning_fabric",
        "root",
    )
    # Inference engines whose INFO messages report progress to the user.
    progress_loggers = ("rf3.inference_engines", "rfd3.engine")

    # Default: only warnings and above.
    logging.getLogger().setLevel(logging.WARNING)

    for name in noisy_loggers:
        logging.getLogger(name).setLevel(logging.ERROR)

    for name in progress_loggers:
        logging.getLogger(name).setLevel(logging.INFO)

    # Blanket-ignore every warning category.
    warnings.filterwarnings("ignore")
124
+
125
+
126
@rank_zero_only
def print_config_tree(
    cfg: DictConfig,
    resolve: bool = False,
    console_width: int = 100,
    title: str = "CONFIG",
) -> None:
    """Render a DictConfig as a Rich tree and print it to the console.

    Each top-level key of ``cfg`` becomes one branch. Values that are
    themselves ``DictConfig`` objects are dumped as YAML; anything else is
    shown via ``str()``.

    Args:
        cfg (DictConfig): A DictConfig composed by Hydra.
        resolve (bool): Whether to resolve reference fields of DictConfig. Default is False.
        console_width (int): The width of the console for printing. Default is 100.
        title (str): Label shown at the root of the tree. Default is "CONFIG".
    """
    console = Console(width=console_width)
    dim = "dim"
    root = Tree(title, style=dim, guide_style=dim)

    # Walk the config in its natural (insertion) order.
    for key in cfg:
        branch = root.add(key, style=dim, guide_style=dim)

        group = cfg[key]
        rendered = (
            OmegaConf.to_yaml(group, resolve=resolve)
            if isinstance(group, DictConfig)
            else str(group)
        )
        branch.add(Syntax(rendered, "yaml", word_wrap=True))

    # Printed directly with Rich: this call happens before other loggers are
    # instantiated, so the output is not captured.
    console.print(root)
159
+
160
+
161
@rank_zero_only
def print_model_parameters(model: nn.Module, name: str = "") -> None:
    """Print a small table with the total and trainable parameter counts of a model.

    Args:
        model (nn.Module): The PyTorch model to analyze.
        name (str): Optional model name appended to the table title.
    """
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        size = param.numel()
        total_params += size
        if param.requires_grad:
            trainable_params += size

    summary = pd.DataFrame(
        {
            "Type": ["Total Parameters", "Trainable Parameters"],
            "Count": [total_params, trainable_params],
        }
    )

    heading = "Model Parameters" if not name else f"Model Parameters: {name}"
    print_df_as_table(summary, title=heading)
177
+
178
+
179
def log_hyperparameters_with_all_loggers(
    trainer: Any, cfg: dict | DictConfig, model: Any
):
    """Send the run's hyperparameters to every logger attached to the trainer.

    Args:
        trainer: The training object containing loggers (read from ``trainer.fabric.loggers``).
        cfg: Configuration dictionary containing hyperparameters. A ``DictConfig``
            is converted to a plain container with references resolved.
        model: The model to be tracked by loggers like WandbLogger.
            (Not used by the current body.)
    """
    # Loggers receive a plain container, never a DictConfig.
    hparams = (
        OmegaConf.to_container(cfg, resolve=True)
        if isinstance(cfg, DictConfig)
        else cfg
    )

    # For the Abstract Base Class of Fabric `Loggers`, see:
    # https://lightning.ai/docs/fabric/stable/_modules/lightning/fabric/loggers/logger.html#Logger
    for logger in trainer.fabric.loggers:
        assert hasattr(
            logger, "log_hyperparams"
        ), f"Logger {logger} does not have a `log_hyperparams` method. Ensure that the logger is a subclass of Fabric's ABC `Logger`."

        try:
            logger.log_hyperparams(hparams)
        except NotImplementedError:
            # Best effort: loggers that do not implement hyperparameter logging are skipped.
            continue
203
+
204
+
205
def condense_count_columns_of_grouped_df(df: pd.DataFrame) -> pd.DataFrame:
    """Returns modified DF with single Count column if valid, otherwise original DF.

    Helpful to avoid repeating count columns in a DataFrame with multi-level columns.

    The input is expected to have two-level columns whose second level contains
    "mean" and "count" entries. The condensed result keeps one ``"<name> (mean)"``
    column per mean column plus a single integer ``"Count"`` column.

    Args:
        df: DataFrame to condense.

    Returns:
        The condensed DataFrame, or ``df`` itself (unchanged) when the columns
        are not a MultiIndex, the "count"/"mean" entries are missing, any count
        is missing (NaN/NA), or the counts differ within a row.
    """
    if not isinstance(df.columns, pd.MultiIndex):
        return df

    try:
        # Validate count structure
        count_cols = df.xs("count", level=1, axis=1)
        mean_cols = df.xs("mean", level=1, axis=1)

        # A missing count cannot be represented in the integer "Count" column
        # (and `nunique` below silently ignores it), so treat it as invalid
        # instead of letting `.astype(int)` raise.
        if count_cols.isna().to_numpy().any():
            return df

        # Check count consistency per row
        if not (count_cols.nunique(axis=1) == 1).all():
            return df

        # Build condensed dataframe
        condensed_df = mean_cols.rename(columns=lambda c: f"{c} (mean)")
        condensed_df["Count"] = count_cols.iloc[:, 0].astype(int)
        return condensed_df

    except (KeyError, IndexError):
        # Expected level or labels are absent: nothing to condense.
        return df
229
+
230
+
231
def _cell_formatter(dtype):
    """Return a callable that renders one value of a column with the given dtype."""
    if pd.api.types.is_integer_dtype(dtype):
        # Thousands separators for integers; missing values fall back to str().
        return lambda value: str(value) if pd.isna(value) else f"{int(value):,}"
    if pd.api.types.is_float_dtype(dtype):
        # Four decimals with thousands separators; missing values fall back to str().
        return lambda value: str(value) if pd.isna(value) else f"{float(value):,.4f}"
    return str


def table_from_df(df: pd.DataFrame, title: str) -> Table:
    """Create a Rich Table from a DataFrame.

    Integer columns are rendered with thousands separators, float columns with
    four decimals and thousands separators, and every other column via ``str()``.

    Args:
        df: DataFrame whose columns become table columns and whose rows become table rows.
        title: Title displayed above the table.

    Returns:
        Table: The populated Rich table (not printed).
    """
    table = Table(title=title, show_header=True, header_style="bold cyan")

    # Format column by column. The formatter is chosen once per column, and
    # positional access keeps each column's own dtype (row-wise iteration
    # upcasts mixed numeric rows) and tolerates duplicate labels.
    formatted_columns = []
    for position, label in enumerate(df.columns):
        # Headers are passed as strings so non-string labels (ints, tuples) still render.
        table.add_column(
            str(label), justify="right", style="magenta", overflow="fold"
        )
        series = df.iloc[:, position]
        render = _cell_formatter(series.dtype)
        formatted_columns.append([render(value) for value in series])

    # Emit the pre-formatted cells row by row.
    for row_position in range(len(df)):
        table.add_row(*(cells[row_position] for cells in formatted_columns))

    return table
259
+
260
+
261
def safe_print(obj: Any, console_width=100, logger: Any | None = None) -> None:
    """Render a Rich object to plain text and emit it through a logger.

    Args:
        obj: Any object Rich can print (e.g. a Table or Tree).
        console_width: Width used when rendering the object. Default is 100.
        logger: Logger whose ``info`` method receives the rendered text. When
            omitted (or falsy), the module-level ranked logger is used.
    """
    # No terminal control codes or colors: the text is destined for a log.
    plain_console = Console(force_terminal=False, color_system=None, width=console_width)

    # Capture the rendering instead of writing it to stdout.
    with plain_console.capture() as captured:
        plain_console.print(obj)

    target = logger if logger else ranked_logger
    target.info(f"\n{captured.get()}")
275
+
276
+
277
def print_df_as_table(df: pd.DataFrame, title: str, console_width: int = 100) -> None:
    """Pretty-print a DataFrame using Rich Table.

    Args:
        df: DataFrame to display.
        title: Title shown above the table.
        console_width: Width used when rendering the table. Default is 100.
    """
    rich_table = table_from_df(df=df, title=title)
    safe_print(rich_table, console_width=console_width)