rc_foundry-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
mpnn/pipelines/mpnn.py ADDED
@@ -0,0 +1,162 @@
+ from atomworks.constants import AF3_EXCLUDED_LIGANDS, STANDARD_AA, UNKNOWN_AA
+ from atomworks.enums import ChainTypeInfo
+ from atomworks.ml.transforms.atom_array import AddWithinChainInstanceResIdx
+ from atomworks.ml.transforms.atomize import (
+     AtomizeByCCDName,
+     FlagNonPolymersForAtomization,
+ )
+ from atomworks.ml.transforms.base import (
+     AddData,
+     Compose,
+     ConditionalRoute,
+     ConvertToTorch,
+     Identity,
+     SubsetToKeys,
+ )
+ from atomworks.ml.transforms.covalent_modifications import (
+     FlagAndReassignCovalentModifications,
+ )
+ from atomworks.ml.transforms.featurize_unresolved_residues import (
+     MaskResiduesWithSpecificUnresolvedAtoms,
+ )
+ from atomworks.ml.transforms.filters import (
+     FilterToSpecifiedPNUnits,
+     HandleUndesiredResTokens,
+     RemoveHydrogens,
+     RemoveUnresolvedTokens,
+ )
+ from mpnn.transforms.feature_aggregation.mpnn import (
+     EncodeMPNNNonAtomizedTokens,
+     FeaturizeAtomizedTokens,
+     FeaturizeNonAtomizedTokens,
+ )
+ from mpnn.transforms.feature_aggregation.user_settings import (
+     FeaturizeUserSettings,
+ )
+
+
+ def TrainingRoute(transform):
+     return ConditionalRoute(
+         condition_func=lambda data: data["is_inference"],
+         transform_map={True: Identity(), False: transform},
+     )
+
+
+ def InferenceRoute(transform):
+     return ConditionalRoute(
+         condition_func=lambda data: data["is_inference"],
+         transform_map={False: Identity(), True: transform},
+     )
+
+
+ def ModelTypeRoute(transform, model_type: str):
+     return ConditionalRoute(
+         condition_func=lambda data: data["model_type"] == model_type,
+         transform_map={True: transform, False: Identity()},
+     )
+
+
+ def build_mpnn_transform_pipeline(
+     *,
+     model_type: str | None = None,
+     occupancy_threshold_sidechain: float = 0.5,
+     occupancy_threshold_backbone: float = 0.8,
+     is_inference: bool = False,
+     minimal_return: bool = False,
+     train_structure_noise_default: float = 0.1,
+     undesired_res_names: list[str] = AF3_EXCLUDED_LIGANDS,
+     device=None,
+ ) -> Compose:
+     """Build the MPNN transform pipeline.
+
+     Args:
+         model_type (str): Model type identifier to include in data. Must be
+             provided; defaults to None.
+         occupancy_threshold_sidechain (float): Minimum occupancy to consider
+             sidechain atoms as present in masks. Defaults to 0.5.
+         occupancy_threshold_backbone (float): Minimum occupancy to consider
+             backbone atoms as resolved. Residues with backbone atoms below
+             this threshold will be masked entirely. Defaults to 0.8.
+         is_inference (bool): Whether this is inference mode. Defaults to
+             False (training mode).
+         minimal_return (bool): Whether to return minimal intermediate data.
+             Defaults to False.
+         train_structure_noise_default (float): Default standard deviation of
+             Gaussian noise added to atomic coordinates during training for
+             data augmentation. Defaults to 0.1.
+         undesired_res_names (list[str]): Residue names to treat as undesired
+             and handle accordingly. Defaults to AF3_EXCLUDED_LIGANDS.
+         device (str | torch.device, optional): Device to move tensors to.
+             Defaults to None, which leads to default ConvertToTorch behavior.
+     """
+     if model_type not in ("protein_mpnn", "ligand_mpnn"):
+         raise ValueError(f"Unsupported model_type: {model_type}")
+
+     transforms = [
+         AddData({"model_type": model_type}),
+         AddData({"is_inference": is_inference}),
+         # + --------- Filters --------- +
+         RemoveHydrogens(),
+         # ... during training, filter to non-clashing chains (which are
+         # pre-computed and stored in the "extra_info" key)
+         TrainingRoute(
+             FilterToSpecifiedPNUnits(
+                 extra_info_key_with_pn_unit_iids_to_keep="all_pn_unit_iids_after_processing"
+             ),
+         ),
+         # ... during training, remove undesired residues (e.g., non-biological
+         # crystallization artifacts), mapping to the closest canonical residue
+         # name where possible
+         TrainingRoute(
+             HandleUndesiredResTokens(undesired_res_tokens=undesired_res_names),
+         ),
+         # + --------- Atomization --------- +
+         # ... add within-chain instance res idx
+         AddWithinChainInstanceResIdx(),
+         FlagAndReassignCovalentModifications(),
+         # Atomization: keep standard AA + unknown AA as residues,
+         # atomize everything else
+         FlagNonPolymersForAtomization(),
+         AtomizeByCCDName(
+             atomize_by_default=True,
+             res_names_to_ignore=STANDARD_AA + (UNKNOWN_AA,),
+             move_atomized_part_to_end=False,
+             validate_atomize=False,
+         ),
+         # + --------- Occupancy filtering --------- +
+         MaskResiduesWithSpecificUnresolvedAtoms(
+             chain_type_to_atom_names={
+                 # MPNN needs the backbone atoms plus oxygen
+                 ChainTypeInfo.PROTEINS: ["N", "CA", "C", "O"],
+             },
+             occupancy_threshold=occupancy_threshold_backbone,
+         ),
+         RemoveUnresolvedTokens(),
+         # + --------- Encoding and featurization --------- +
+         AddData({"input_features": dict()}),
+         # Encode and featurize non-atomized tokens
+         EncodeMPNNNonAtomizedTokens(occupancy_threshold=occupancy_threshold_sidechain),
+         FeaturizeNonAtomizedTokens(),
+         # LigandMPNN specific featurization: featurize atomized tokens
+         ModelTypeRoute(
+             transform=FeaturizeAtomizedTokens(),
+             model_type=model_type,
+         ),
+         # Featurize user settings
+         FeaturizeUserSettings(
+             is_inference=is_inference,
+             minimal_return=minimal_return,
+             train_structure_noise_default=train_structure_noise_default,
+         ),
+         # Convert to torch and subset keys
+         ConvertToTorch(
+             keys=["input_features"],
+             **({"device": device} if device is not None else {}),
+         ),
+         SubsetToKeys(keys=["input_features", "atom_array"]),
+     ]
+
+     return Compose(transforms)
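
Note: a minimal usage sketch of build_mpnn_transform_pipeline (not part of the diff). It assumes rc-foundry and its atomworks dependency are installed, and that "example" stands in for the data dict an atomworks StructuralDatasetWrapper would pass to its transform (containing an "atom_array" entry and the other keys the transforms above consume):

    from mpnn.pipelines.mpnn import build_mpnn_transform_pipeline

    # Build the inference-mode pipeline for LigandMPNN.
    pipeline = build_mpnn_transform_pipeline(
        model_type="ligand_mpnn",
        is_inference=True,
    )

    # "example" is hypothetical here; the pipeline is a Compose, assumed to be
    # applied by calling it on the data dict, as the dataset wrappers below do.
    out = pipeline(example)
    input_features = out["input_features"]  # torch tensors (see ConvertToTorch)
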
mpnn/samplers/samplers.py ADDED
@@ -0,0 +1,167 @@
+ from typing import Any, Callable, Iterator, List
+
+ import numpy as np
+ import torch
+ from atomworks.ml.samplers import set_sampler_epoch
+ from torch.utils.data import BatchSampler, Sampler
+
+
+ class PaddedTokenBudgetBatchSampler(BatchSampler):
+     """
+     Token-based batch sampler that wraps an existing sampler and creates
+     batches of similar-token-length samples, respecting a maximum token
+     count (accounting for the fact that batches will be padded to the
+     maximum length in the batch).
+
+     Args:
+         sampler: The underlying sampler to wrap.
+         get_num_tokens: Function that takes an index from the wrapped sampler
+             and returns the number of tokens for that sample.
+         max_tokens_with_padding: Maximum number of tokens allowed per batch,
+             including padding. The constraint is
+             max(batch_lengths) * len(batch) <= max_tokens_with_padding.
+         shuffle_batches: Whether to randomize the order of batches after
+             grouping by length. Defaults to True.
+     """
+
+     def __init__(
+         self,
+         sampler: Sampler,
+         get_num_tokens: Callable[[Any], int],
+         max_tokens_with_padding: int = 6000,
+         shuffle_batches: bool = True,
+     ):
+         # Initialize BatchSampler with a dummy batch_size (we don't use it).
+         super().__init__(sampler, batch_size=1, drop_last=False)
+
+         self.sampler = sampler
+         self.get_num_tokens = get_num_tokens
+         self.max_tokens_with_padding = max_tokens_with_padding
+         self.shuffle_batches = shuffle_batches
+
+         # Add drop_last attribute for DataLoader compatibility
+         self.drop_last = False
+
+         self.epoch = 0
+
+         # Will hold our "one epoch" cache.
+         self._batches: List[List[Any]] | None = None
+
+         # Validate inputs
+         if max_tokens_with_padding <= 0:
+             raise ValueError("max_tokens_with_padding must be greater than 0")
+         if get_num_tokens is None:
+             raise ValueError("get_num_tokens function must be provided")
+
+     def set_epoch(self, epoch: int) -> None:
+         self.epoch = epoch
+         set_sampler_epoch(self.sampler, epoch)
+
+     def _build_batches(self) -> List[List[Any]]:
+         """
+         Compute all batches and cache them for the current epoch.
+
+         Returns:
+             List of batches, where each batch is a list of indices.
+
+         Raises:
+             TypeError: If get_num_tokens returns an invalid type.
+             ValueError: If get_num_tokens returns an invalid value.
+         """
+         # Extract all indices and their token counts
+         sample_indices_and_lengths = []
+         for idx in self.sampler:
+             num_tokens = self.get_num_tokens(idx)
+
+             # Validate num_tokens type
+             if not isinstance(num_tokens, (int, float, np.integer, np.floating)):
+                 raise TypeError(
+                     f"get_num_tokens returned invalid type {type(num_tokens)} "
+                     f"for index {idx}. Expected numeric type."
+                 )
+
+             # Validate num_tokens value.
+             num_tokens = int(num_tokens)
+             if num_tokens <= 0:
+                 raise ValueError(
+                     f"get_num_tokens returned invalid value {num_tokens} "
+                     f"for index {idx}. Expected positive integer."
+                 )
+             if num_tokens > self.max_tokens_with_padding:
+                 raise ValueError(
+                     f"Index {idx} has {num_tokens} tokens, exceeding "
+                     f"max_tokens_with_padding={self.max_tokens_with_padding}."
+                 )
+
+             sample_indices_and_lengths.append((idx, num_tokens))
+
+         # Sort by token length (ascending order).
+         sample_indices_and_lengths.sort(key=lambda x: x[1])
+
+         # Batch by length
+         batches = []
+         current_batch = []
+         current_max_length = 0
+         for idx, length in sample_indices_and_lengths:
+             if current_batch:
+                 # Check if adding this sample would violate the
+                 # max_tokens_with_padding constraint.
+                 potential_max_length = max(length, current_max_length)
+                 new_batch_size = len(current_batch) + 1
+                 if potential_max_length * new_batch_size > self.max_tokens_with_padding:
+                     # Current batch is full; start a new batch
+                     batches.append(current_batch)
+                     current_batch = [idx]
+                     current_max_length = length
+                 else:
+                     # Add to current batch
+                     current_batch.append(idx)
+                     current_max_length = potential_max_length
+             else:
+                 # First sample in batch
+                 current_batch = [idx]
+                 current_max_length = length
+
+         # Add the last batch if it's not empty
+         if current_batch:
+             batches.append(current_batch)
+
+         # Randomize batch order if requested
+         if self.shuffle_batches:
+             # Seed the permutation with the epoch for reproducibility
+             g = torch.Generator()
+             g.manual_seed(self.epoch)
+             perm = torch.randperm(len(batches), generator=g).tolist()
+             batches = [batches[i] for i in perm]
+
+         return batches
+
+     def __iter__(self) -> Iterator[List[Any]]:
+         """
+         Generate batches of indices grouped by token length while respecting
+         the max_tokens_with_padding constraint.
+
+         Returns:
+             Iterator[List[Any]]: Iterator over batches of dataset indices.
+         """
+         # Build the cached batches if not already done.
+         if self._batches is None:
+             self._batches = self._build_batches()
+
+         for batch in self._batches:
+             yield batch
+
+         # End of __iter__: clear the cache so the next call recomputes.
+         self._batches = None
+
+     def __len__(self) -> int:
+         """Return the exact number of batches that will be produced."""
+         # Build the cached batches if not already done.
+         if self._batches is None:
+             self._batches = self._build_batches()
+         return len(self._batches)
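
Note: a small self-contained sketch of the greedy packing behavior (not part of the diff). The lengths and budget are illustrative; it assumes rc-foundry and atomworks are importable:

    from torch.utils.data import SequentialSampler

    from mpnn.samplers.samplers import PaddedTokenBudgetBatchSampler

    lengths = [12, 48, 100, 40, 95, 30, 60]  # tokens per sample
    batch_sampler = PaddedTokenBudgetBatchSampler(
        sampler=SequentialSampler(lengths),  # yields indices 0..6
        get_num_tokens=lambda idx: lengths[idx],
        max_tokens_with_padding=200,
        shuffle_batches=False,  # keep the length-sorted batch order
    )
    # Samples are sorted by token count, then packed greedily so that
    # max(batch_lengths) * len(batch) <= 200 for every batch:
    print(list(batch_sampler))  # [[0, 5, 3, 1], [6, 4], [2]]
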
mpnn/train.py ADDED
@@ -0,0 +1,341 @@
+ #!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/mpnn_exec.sh" "$0" "$@"'
+
+ import sys
+ from pathlib import Path
+
+ import pandas as pd
+ import torch
+ from atomworks.io.parser import STANDARD_PARSER_ARGS
+ from atomworks.ml.datasets.pandas_dataset import PandasDataset, StructuralDatasetWrapper
+ from atomworks.ml.datasets.parsers.default_metadata_row_parsers import GenericDFParser
+ from atomworks.ml.samplers import (
+     DistributedMixedSampler,
+     calculate_weights_for_pdb_dataset_df,
+ )
+ from omegaconf import DictConfig
+ from torch.utils.data import DataLoader, WeightedRandomSampler
+
+ from foundry.callbacks.metrics_logging import StoreValidationMetricsInDFCallback
+ from foundry.utils.datasets import wrap_dataset_and_sampler_with_fallbacks
+ from mpnn.collate.feature_collator import TokenBudgetAwareFeatureCollator
+ from mpnn.pipelines.mpnn import build_mpnn_transform_pipeline
+ from mpnn.samplers.samplers import PaddedTokenBudgetBatchSampler
+ from mpnn.trainers.mpnn import MPNNTrainer
+
+ model_type = sys.argv[1]
+
+ if model_type == "protein_mpnn":
+     batch_size = 10000
+     train_date_cutoff = "2021-08-02"
+     clip_grad_max_norm = None
+     train_structure_noise_default = 0.2
+ elif model_type == "ligand_mpnn":
+     batch_size = 6000
+     train_date_cutoff = "2022-12-16"
+     clip_grad_max_norm = 1.0
+     train_structure_noise_default = 0.1
+ else:
+     raise ValueError(f"Unknown model_type: {model_type}")
+
+
+ def create_noam_scheduler(optimizer, d_model, warmup_steps=4000, factor=2):
+     """
+     Create a NoamOpt-style scheduler using standard PyTorch components.
+
+     Args:
+         optimizer: PyTorch optimizer
+         d_model: Model dimension (for scaling)
+         warmup_steps: Number of warmup steps
+         factor: Scaling factor
+
+     Returns:
+         LambdaLR scheduler that implements the NoamOpt schedule
+     """
+
+     def noam_lambda(step):
+         # NoamOpt formula:
+         # factor * d_model ** (-0.5) * min(step ** (-0.5), step * warmup_steps ** (-1.5))
+         base_lr = factor * (d_model ** (-0.5))
+         if step == 0:
+             return 0.0  # Start with zero learning rate
+
+         # Calculate the schedule component
+         schedule = min(step ** (-0.5), step * warmup_steps ** (-1.5))
+         return base_lr * schedule
+
+     return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam_lambda)
+
+
+ def get_num_tokens(df, idx):
+     """
+     Extract the number of non-atomized tokens for a given index.
+
+     Args:
+         df: DataFrame containing the dataset
+         idx: Index to extract the token count from
+
+     Returns:
+         Number of non-atomized tokens for the sample at idx
+     """
+     if isinstance(idx, (list, tuple)):
+         # If idx is a list/tuple, return the first element's token count
+         idx = idx[0]
+     return df.iloc[idx]["n_non_atomized_tokens"]
+
+
+ # Common filters for MPNN datasets
+ MPNN_FILTERS = [
+     "resolution < 3.5 and ~method.str.contains('NMR')",
+     "n_non_atomized_tokens >= 30",
+     "cluster.notnull() and cluster != 'nan'",
+     "method in ['X-RAY_DIFFRACTION', 'ELECTRON_MICROSCOPY']",
+     f"n_non_atomized_tokens <= {batch_size}",
+     "n_prot == 1",
+ ]
+
+ MPNN_TRAIN_FILTERS = [
+     f"deposition_date < '{train_date_cutoff}'",
+ ] + MPNN_FILTERS
+
+ # Data loading setup
+ train_path = "/projects/ml/datahub/dfs/mpnn/splits/2025_07_13/pn_units_df_train.parquet"
+ val_path = "/projects/ml/datahub/dfs/mpnn/splits/2025_07_13/pn_units_df_val.parquet"
+
+ # Load datasets
+ train_df = pd.read_parquet(train_path)
+ val_df = pd.read_parquet(val_path)
+
+ # Create different pipelines for training and inference
+ train_pipeline = build_mpnn_transform_pipeline(
+     model_type=model_type,
+     is_inference=False,
+     minimal_return=True,
+     train_structure_noise_default=train_structure_noise_default,
+ )
+ inference_pipeline = build_mpnn_transform_pipeline(
+     model_type=model_type, is_inference=True, minimal_return=True
+ )
+
+ # Create train dataset with fallback
+ train_structural_dataset = StructuralDatasetWrapper(
+     dataset=PandasDataset(
+         data=train_df,
+         id_column="example_id",
+         name="pn_units_df_train",
+         filters=MPNN_TRAIN_FILTERS,
+     ),
+     dataset_parser=GenericDFParser(
+         example_id_colname="example_id",
+         path_colname="path",
+         assembly_id_colname="assembly_id",
+     ),
+     transform=train_pipeline,
+     cif_parser_args={
+         **STANDARD_PARSER_ARGS,
+         "load_from_cache": True,
+         "save_to_cache": True,
+         "cache_dir": "/net/tukwila/akubaney/cifutils/cache",
+     },
+ )
+
+ train_fallback_dataset = StructuralDatasetWrapper(
+     dataset=PandasDataset(
+         data=train_df,
+         id_column="example_id",
+         name="pn_units_df_train",
+         filters=MPNN_TRAIN_FILTERS,
+     ),
+     dataset_parser=GenericDFParser(
+         example_id_colname="example_id",
+         path_colname="path",
+         assembly_id_colname="assembly_id",
+     ),
+     transform=train_pipeline,
+     cif_parser_args={
+         **STANDARD_PARSER_ARGS,
+         "load_from_cache": True,
+         "save_to_cache": True,
+         "cache_dir": "/net/tukwila/akubaney/cifutils/cache",
+     },
+ )
+
+ # Calculate weights for train dataset
+ train_weights = calculate_weights_for_pdb_dataset_df(
+     dataset_df=train_structural_dataset.data,
+     beta=1.0,  # For chains
+     alphas={"a_prot": 1.0, "a_nuc": 0, "a_ligand": 0, "a_loi": 0},
+ )
+
+ # Create train sampler with fallback
+ train_sampler = DistributedMixedSampler(
+     datasets_info=[
+         {
+             "sampler": WeightedRandomSampler(
+                 train_weights, len(train_structural_dataset)
+             ),
+             "dataset": train_structural_dataset,
+             "probability": 1.0,
+         }
+     ],
+     num_replicas=1,
+     rank=0,
+     n_examples_per_epoch=20000,
+ )
+
+ train_fallback_sampler = WeightedRandomSampler(
+     train_weights, len(train_structural_dataset)
+ )
+
+ train_dataset_with_fallback, train_sampler_with_fallback = (
+     wrap_dataset_and_sampler_with_fallbacks(
+         dataset_to_be_wrapped=train_structural_dataset,
+         sampler_to_be_wrapped=train_sampler,
+         dataset_to_fallback_to=train_fallback_dataset,
+         sampler_to_fallback_to=train_fallback_sampler,
+         n_fallback_retries=5,
+     )
+ )
+
+ batched_train_sampler = PaddedTokenBudgetBatchSampler(
+     sampler=train_sampler_with_fallback,
+     get_num_tokens=lambda idx: get_num_tokens(train_structural_dataset.data, idx),
+     max_tokens_with_padding=batch_size,
+     shuffle_batches=True,
+ )
+
+ # Create val dataset with fallback
+ val_structural_dataset = StructuralDatasetWrapper(
+     dataset=PandasDataset(
+         data=val_df,
+         id_column="example_id",
+         name="pn_units_df_val",
+         filters=MPNN_FILTERS,
+     ),
+     dataset_parser=GenericDFParser(
+         example_id_colname="example_id",
+         path_colname="path",
+         assembly_id_colname="assembly_id",
+     ),
+     transform=inference_pipeline,
+     cif_parser_args={
+         **STANDARD_PARSER_ARGS,
+         "load_from_cache": True,
+         "save_to_cache": True,
+         "cache_dir": "/net/tukwila/akubaney/cifutils/cache",
+     },
+ )
+
+ # Create val sampler with fallback
+ val_weights = calculate_weights_for_pdb_dataset_df(
+     dataset_df=val_structural_dataset.data,
+     beta=1.0,  # For chains
+     alphas={"a_prot": 1.0, "a_nuc": 0, "a_ligand": 0, "a_loi": 0},
+ )
+
+ val_sampler = DistributedMixedSampler(
+     datasets_info=[
+         {
+             "sampler": WeightedRandomSampler(val_weights, len(val_structural_dataset)),
+             "dataset": val_structural_dataset,
+             "probability": 1.0,
+         }
+     ],
+     num_replicas=1,
+     rank=0,
+     n_examples_per_epoch=100,
+ )
+
+ # Create collator
+ collator = TokenBudgetAwareFeatureCollator(max_tokens_with_padding=batch_size)
+
+ # Create DataLoaders
+ train_loader = DataLoader(
+     train_dataset_with_fallback,
+     batch_sampler=batched_train_sampler,
+     num_workers=12,
+     collate_fn=collator,
+ )
+
+ val_loaders = {
+     "test_val": DataLoader(
+         val_structural_dataset,
+         sampler=val_sampler,
+         num_workers=12,
+         collate_fn=collator,
+     )
+ }
+
+ # Create output directory for logs and checkpoints
+ output_dir = Path(f"./mpnn_output_{model_type}")
+ output_dir.mkdir(exist_ok=True)
+
+ # Create CSV logging callback
+ csv_callback = StoreValidationMetricsInDFCallback(
+     save_dir=output_dir / "val_metrics", metrics_to_save="all"
+ )
+
+ # Create trainer with minimal configuration for testing
+ trainer = MPNNTrainer(
+     model_type=model_type,
+     accelerator="gpu",
+     devices_per_node=1,
+     max_epochs=500,
+     output_dir=output_dir,
+     callbacks=[csv_callback],
+     precision="bf16-mixed",
+     clip_grad_max_norm=clip_grad_max_norm,
+ )
+
+ # Create minimal train_cfg for optimizer and scheduler construction
+ train_cfg = DictConfig(
+     {
+         "model": {
+             "optimizer": {
+                 "_target_": "torch.optim.Adam",
+                 "lr": 1.0,  # This will be overridden by the NoamOpt scheduler
+                 "betas": [0.9, 0.98],  # NoamOpt uses (0.9, 0.98)
+                 "eps": 1e-9,  # NoamOpt uses 1e-9
+                 "weight_decay": 0.0,
+             },
+             "lr_scheduler": {
+                 "_target_": "__main__.create_noam_scheduler",
+                 "d_model": 128,  # Adjust based on your model's hidden dimension
+                 "warmup_steps": 4000,
+                 "factor": 2,
+             },
+         }
+     }
+ )
+
+ # Initialize trainer state with train_cfg
+ trainer.initialize_or_update_trainer_state({"train_cfg": train_cfg})
+
+ # Launch Fabric (this sets up the distributed environment)
+ trainer.fabric.launch()
+
+ # Construct model
+ trainer.construct_model()
+
+ # Construct optimizer and scheduler
+ trainer.construct_optimizer()
+ trainer.construct_scheduler()
+
+
+ class CkptConfig:
+     def __init__(self, path, weight_loading_config=None, reset_optimizer=False):
+         self.path = path
+         self.weight_loading_config = weight_loading_config
+         self.reset_optimizer = reset_optimizer
+
+
+ ckpt_dir = output_dir / "ckpt"
+ if ckpt_dir.exists():
+     ckpt_config = CkptConfig(
+         path=ckpt_dir, weight_loading_config=None, reset_optimizer=False
+     )
+ else:
+     ckpt_config = None
+
+ # Run the full training using the fit method
+ print("Starting training...")
+ trainer.fit(train_loader=train_loader, val_loaders=val_loaders, ckpt_config=ckpt_config)
+ print("Training completed!")
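
Note: a quick standalone check of the schedule that create_noam_scheduler builds (not part of the diff; the dummy Linear module and step count are illustrative only):

    import torch

    # create_noam_scheduler is the function defined in mpnn/train.py above.
    model = torch.nn.Linear(128, 128)  # stand-in module
    optimizer = torch.optim.Adam(
        model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9
    )
    scheduler = create_noam_scheduler(optimizer, d_model=128, warmup_steps=4000, factor=2)

    # The LR grows ~linearly over the 4000 warmup steps, then decays as step ** -0.5.
    # Peak LR = factor * d_model ** -0.5 * warmup_steps ** -0.5
    #         = 2 * 128 ** -0.5 * 4000 ** -0.5 ≈ 2.8e-3.
    for _ in range(4000):
        optimizer.step()
        scheduler.step()
    print(optimizer.param_groups[0]["lr"])  # ≈ 0.0028 at the end of warmup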