rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
mpnn/pipelines/mpnn.py
ADDED
@@ -0,0 +1,162 @@
from atomworks.constants import AF3_EXCLUDED_LIGANDS, STANDARD_AA, UNKNOWN_AA
from atomworks.enums import ChainTypeInfo
from atomworks.ml.transforms.atom_array import AddWithinChainInstanceResIdx
from atomworks.ml.transforms.atomize import (
    AtomizeByCCDName,
    FlagNonPolymersForAtomization,
)
from atomworks.ml.transforms.base import (
    AddData,
    Compose,
    ConditionalRoute,
    ConvertToTorch,
    Identity,
    SubsetToKeys,
)
from atomworks.ml.transforms.covalent_modifications import (
    FlagAndReassignCovalentModifications,
)
from atomworks.ml.transforms.featurize_unresolved_residues import (
    MaskResiduesWithSpecificUnresolvedAtoms,
)
from atomworks.ml.transforms.filters import (
    FilterToSpecifiedPNUnits,
    HandleUndesiredResTokens,
    RemoveHydrogens,
    RemoveUnresolvedTokens,
)
from mpnn.transforms.feature_aggregation.mpnn import (
    EncodeMPNNNonAtomizedTokens,
    FeaturizeAtomizedTokens,
    FeaturizeNonAtomizedTokens,
)
from mpnn.transforms.feature_aggregation.user_settings import (
    FeaturizeUserSettings,
)


def TrainingRoute(transform):
    return ConditionalRoute(
        condition_func=lambda data: data["is_inference"],
        transform_map={True: Identity(), False: transform},
    )


def InferenceRoute(transform):
    return ConditionalRoute(
        condition_func=lambda data: data["is_inference"],
        transform_map={False: Identity(), True: transform},
    )


def ModelTypeRoute(transform, model_type: str):
    return ConditionalRoute(
        condition_func=lambda data: data["model_type"] == model_type,
        transform_map={True: transform, False: Identity()},
    )


def build_mpnn_transform_pipeline(
    *,
    model_type: str = None,
    occupancy_threshold_sidechain: float = 0.5,
    occupancy_threshold_backbone: float = 0.8,
    is_inference: bool = False,
    minimal_return: bool = False,
    train_structure_noise_default: float = 0.1,
    undesired_res_names: list[str] = AF3_EXCLUDED_LIGANDS,
    device=None,
) -> Compose:
    """Build the MPNN transform pipeline.

    Args:
        model_type (str): Model type identifier to include in data. Must be
            provided. Defaults to None.
        occupancy_threshold_sidechain (float): Minimum occupancy to consider
            sidechain atoms as present in masks. Defaults to 0.5.
        occupancy_threshold_backbone (float): Minimum occupancy to consider
            backbone atoms as resolved. Residues with backbone atoms below this
            threshold will be masked entirely. Defaults to 0.8.
        train_structure_noise_default (float): Default standard deviation of
            Gaussian noise to add to atomic coordinates during training for data
            augmentation. Defaults to 0.1.
        is_inference (bool): Whether this is inference mode. Defaults to
            False (training mode).
        minimal_return (bool): Whether to return minimal intermediate data.
            Defaults to False.
        undesired_res_names (list[str]): List of residue names to treat as
            undesired and handle accordingly. Defaults to AF3_EXCLUDED_LIGANDS.
        device (str | torch.device, optional): Device to move tensors to.
            Defaults to None, which leads to default ConvertToTorch behavior.
    """
    if model_type not in ("protein_mpnn", "ligand_mpnn"):
        raise ValueError(f"Unsupported model_type: {model_type}")

    transforms = [
        AddData({"model_type": model_type}),
        AddData({"is_inference": is_inference}),
        # + --------- Filters --------- +
        RemoveHydrogens(),
        # ... during training, filter to non-clashing chains (which are
        # pre-computed and stored in the "extra_info" key)
        TrainingRoute(
            FilterToSpecifiedPNUnits(
                extra_info_key_with_pn_unit_iids_to_keep="all_pn_unit_iids_after_processing"
            ),
        ),
        # ... during training, remove undesired residues (e.g., non-biological
        # crystallization artifacts), mapping to the closest canonical residue
        # name where possible
        TrainingRoute(
            HandleUndesiredResTokens(undesired_res_tokens=undesired_res_names),
        ),
        # + --------- Atomization --------- +
        # ... add within-chain instance res idx
        AddWithinChainInstanceResIdx(),
        FlagAndReassignCovalentModifications(),
        # Atomization: keep standard AA + unknown AA as residues,
        # atomize everything else
        FlagNonPolymersForAtomization(),
        AtomizeByCCDName(
            atomize_by_default=True,
            res_names_to_ignore=STANDARD_AA + (UNKNOWN_AA,),
            move_atomized_part_to_end=False,
            validate_atomize=False,
        ),
        # + --------- Occupancy filtering --------- +
        MaskResiduesWithSpecificUnresolvedAtoms(
            chain_type_to_atom_names={
                ChainTypeInfo.PROTEINS: [
                    "N",
                    "CA",
                    "C",
                    "O",
                ],  # MPNN needs backbone + oxygen
            },
            occupancy_threshold=occupancy_threshold_backbone,
        ),
        RemoveUnresolvedTokens(),
        # + -------- Encoding and featurization --------- +
        AddData({"input_features": dict()}),
        # Encode and featurize non-atomized tokens
        EncodeMPNNNonAtomizedTokens(occupancy_threshold=occupancy_threshold_sidechain),
        FeaturizeNonAtomizedTokens(),
        # LigandMPNN specific featurization: featurize atomized tokens
        ModelTypeRoute(
            transform=FeaturizeAtomizedTokens(),
            model_type=model_type,
        ),
        # Featurize user settings
        FeaturizeUserSettings(
            is_inference=is_inference,
            minimal_return=minimal_return,
            train_structure_noise_default=train_structure_noise_default,
        ),
        # Convert to torch and subset keys
        ConvertToTorch(
            keys=["input_features"],
            **({"device": device} if device is not None else {}),
        ),
        SubsetToKeys(keys=["input_features", "atom_array"]),
    ]

    return Compose(transforms)
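For orientation, here is a minimal usage sketch of this factory outside of mpnn/train.py. It is not part of the package; the `parsed_example` dict and its keys are assumptions, standing in for the example produced by the atomworks parsing/dataset layer, which is not shown in this diff.

# Hypothetical usage sketch, not part of the package.
from mpnn.pipelines.mpnn import build_mpnn_transform_pipeline

pipeline = build_mpnn_transform_pipeline(
    model_type="ligand_mpnn",  # or "protein_mpnn"
    is_inference=True,
    minimal_return=True,
)

# `parsed_example` stands in for the dict emitted by the atomworks
# parsing/dataset layer (it must at least carry an "atom_array").
# processed = pipeline(parsed_example)
# features = processed["input_features"]   # torch tensors (see ConvertToTorch above)
# atom_array = processed["atom_array"]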
mpnn/samplers/samplers.py
ADDED
@@ -0,0 +1,167 @@
from typing import Any, Callable, Iterator, List

import numpy as np
import torch
from atomworks.ml.samplers import set_sampler_epoch
from torch.utils.data import BatchSampler, Sampler


class PaddedTokenBudgetBatchSampler(BatchSampler):
    """
    Token-based batch sampler that wraps existing samplers and creates batches
    of similar-token length samples, respecting a maximum token count
    constraint (considering that the batches will be padded to the maximum
    length in the batch).

    Args:
        sampler: The underlying sampler to wrap around.
        get_num_tokens: Function that takes an index from the previous sampler
            and returns the number of tokens for that sample.
        max_tokens_with_padding: Maximum number of tokens allowed per batch,
            including padding. The constraint is
            max(batch_lengths) * len(batch) <= max_tokens.
        shuffle_batches: Whether to randomize the order of batches after
            grouping by length. Defaults to True.
    """

    def __init__(
        self,
        sampler: Sampler,
        get_num_tokens: Callable[[Any], int],
        max_tokens_with_padding: int = 6000,
        shuffle_batches: bool = True,
    ):
        # Initialize BatchSampler with a dummy batch_size (we don't use it).
        super().__init__(sampler, batch_size=1, drop_last=False)

        self.sampler = sampler
        self.get_num_tokens = get_num_tokens
        self.max_tokens_with_padding = max_tokens_with_padding
        self.shuffle_batches = shuffle_batches

        # Add drop_last attribute for DataLoader compatibility
        self.drop_last = False

        self.epoch = 0

        # Will hold our "one epoch" cache.
        self._batches: List[List[Any]] | None = None

        # Validate inputs
        if max_tokens_with_padding <= 0:
            raise ValueError("max_tokens_with_padding must be greater than 0")
        if get_num_tokens is None:
            raise ValueError("get_num_tokens function must be provided")

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch
        set_sampler_epoch(self.sampler, epoch)

    def _build_batches(self) -> List[List[Any]]:
        """
        Compute all batches and cache them for the current epoch.

        Returns:
            List of batches, where each batch is a list of indices.

        Raises:
            TypeError: If get_num_tokens returns invalid types.
            ValueError: If get_num_tokens returns invalid values.
        """
        # Extract all indices and their token counts
        sample_indices_and_lengths = []
        for idx in self.sampler:
            num_tokens = self.get_num_tokens(idx)

            # Validate num_tokens type
            if not isinstance(num_tokens, (int, float, np.integer, np.floating)):
                raise TypeError(
                    f"get_num_tokens returned invalid type {type(num_tokens)} "
                    f"for index {idx}. Expected numeric type."
                )

            # Validate num_tokens value.
            num_tokens = int(num_tokens)
            if num_tokens <= 0:
                raise ValueError(
                    f"get_num_tokens returned invalid value {num_tokens} "
                    f"for index {idx}. Expected positive integer."
                )
            if num_tokens > self.max_tokens_with_padding:
                raise ValueError(
                    f"Index {idx} has {num_tokens} tokens, exceeding "
                    f"max_tokens_with_padding={self.max_tokens_with_padding}."
                )

            sample_indices_and_lengths.append((idx, num_tokens))

        # Sort by token length (ascending order).
        sample_indices_and_lengths.sort(key=lambda x: x[1])

        # Batch by length
        batches = []
        current_batch = []
        current_max_length = 0
        for idx, length in sample_indices_and_lengths:
            # Check if adding this sample would violate max_tokens_with_padding
            # constraint.
            if current_batch:
                potential_max_length = max(length, current_max_length)
                new_batch_size = len(current_batch) + 1

                if potential_max_length * new_batch_size > self.max_tokens_with_padding:
                    # Current batch is full, start a new batch
                    batches.append(current_batch)

                    current_batch = [idx]
                    current_max_length = length
                else:
                    # Add to current batch
                    current_batch.append(idx)
                    current_max_length = potential_max_length
            else:
                # First sample in batch
                current_batch = [idx]
                current_max_length = length

        # Add the last batch if it's not empty
        if current_batch:
            batches.append(current_batch)

        # Randomize batch order if requested
        if self.shuffle_batches:
            # Set the seed based on the epoch
            g = torch.Generator()
            g.manual_seed(self.epoch)

            perm = torch.randperm(len(batches), generator=g).tolist()

            batches = [batches[i] for i in perm]

        return batches

    def __iter__(self) -> Iterator[List[Any]]:
        """
        Generate batches of indices grouped by token length while respecting
        the max_tokens_with_padding constraint.

        Returns:
            Iterator[List[Any]]: Iterator over batches of dataset indices.
        """
        # Build/reference cached batches if not already done.
        if self._batches is None:
            self._batches = self._build_batches()

        for batch in self._batches:
            yield batch

        # End of __iter__: clear cache so the next call recomputes.
        self._batches = None

    def __len__(self) -> int:
        """
        Return the exact number of batches that will be produced.
        """
        # Build/reference cached batches if not already done.
        if self._batches is None:
            self._batches = self._build_batches()
        return len(self._batches)
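The batching rule documented above (max(batch_lengths) * len(batch) <= max_tokens_with_padding) can be exercised in isolation. The following sketch is not part of the package; the token counts are made up for illustration.

# Hypothetical sanity check, not part of the package.
from torch.utils.data import SequentialSampler

from mpnn.samplers.samplers import PaddedTokenBudgetBatchSampler

lengths = [120, 450, 80, 300, 990, 60]  # tokens per sample (made up)
sampler = SequentialSampler(range(len(lengths)))

batch_sampler = PaddedTokenBudgetBatchSampler(
    sampler=sampler,
    get_num_tokens=lambda i: lengths[i],
    max_tokens_with_padding=1000,
    shuffle_batches=False,
)

for batch in batch_sampler:
    # Padded size of each batch stays within the budget.
    assert max(lengths[i] for i in batch) * len(batch) <= 1000

With these lengths the sampler groups the three shortest samples into one batch, pairs the 300- and 450-token samples, and leaves the 990-token sample alone, since pairing it with anything would exceed the 1000-token budget.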
mpnn/train.py
ADDED
@@ -0,0 +1,341 @@
#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/mpnn_exec.sh" "$0" "$@"'

import sys
from pathlib import Path

import pandas as pd
import torch
from atomworks.io.parser import STANDARD_PARSER_ARGS
from atomworks.ml.datasets.pandas_dataset import PandasDataset, StructuralDatasetWrapper
from atomworks.ml.datasets.parsers.default_metadata_row_parsers import GenericDFParser
from atomworks.ml.samplers import (
    DistributedMixedSampler,
    calculate_weights_for_pdb_dataset_df,
)
from omegaconf import DictConfig
from torch.utils.data import DataLoader, WeightedRandomSampler

from foundry.callbacks.metrics_logging import StoreValidationMetricsInDFCallback
from foundry.utils.datasets import wrap_dataset_and_sampler_with_fallbacks
from mpnn.collate.feature_collator import TokenBudgetAwareFeatureCollator
from mpnn.pipelines.mpnn import build_mpnn_transform_pipeline
from mpnn.samplers.samplers import PaddedTokenBudgetBatchSampler
from mpnn.trainers.mpnn import MPNNTrainer

model_type = sys.argv[1]

if model_type == "protein_mpnn":
    batch_size = 10000
    train_date_cutoff = "2021-08-02"
    clip_grad_max_norm = None
    train_structure_noise_default = 0.2
elif model_type == "ligand_mpnn":
    batch_size = 6000
    train_date_cutoff = "2022-12-16"
    clip_grad_max_norm = 1.0
    train_structure_noise_default = 0.1
else:
    raise ValueError(f"Unknown model_type: {model_type}")


def create_noam_scheduler(optimizer, d_model, warmup_steps=4000, factor=2):
    """
    Create a NoamOpt-style scheduler using standard PyTorch components.

    Args:
        optimizer: PyTorch optimizer
        d_model: Model dimension (for scaling)
        warmup_steps: Number of warmup steps
        factor: Scaling factor

    Returns:
        LambdaLR scheduler that implements NoamOpt schedule
    """

    def noam_lambda(step):
        # NoamOpt formula: factor * (d_model ** (-0.5)) * min(step ** (-0.5), step * warmup ** (-1.5))
        base_lr = factor * (d_model ** (-0.5))
        if step == 0:
            return 0.0  # Start with zero learning rate

        # Calculate the schedule component
        schedule = min(step ** (-0.5), step * warmup_steps ** (-1.5))
        return base_lr * schedule

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam_lambda)


def get_num_tokens(df, idx):
    """
    Extract the number of non-atomized tokens for a given index.

    Args:
        df: DataFrame containing the dataset
        idx: Index to extract token count from

    Returns:
        Number of non-atomized tokens for the sample at idx
    """
    if isinstance(idx, (list, tuple)):
        # If idx is a list/tuple, return the first element's token count
        idx = idx[0]
    return df.iloc[idx]["n_non_atomized_tokens"]


# Common filters for MPNN datasets
MPNN_FILTERS = [
    "resolution < 3.5 and ~method.str.contains('NMR')",
    "n_non_atomized_tokens >= 30",
    "cluster.notnull() and cluster != 'nan'",
    "method in ['X-RAY_DIFFRACTION', 'ELECTRON_MICROSCOPY']",
    f"n_non_atomized_tokens <= {batch_size}",
    "n_prot == 1",
]

MPNN_TRAIN_FILTERS = [
    f"deposition_date < '{train_date_cutoff}'",
] + MPNN_FILTERS

# Data loading setup
train_path = "/projects/ml/datahub/dfs/mpnn/splits/2025_07_13/pn_units_df_train.parquet"
val_path = "/projects/ml/datahub/dfs/mpnn/splits/2025_07_13/pn_units_df_val.parquet"

# Load datasets
train_df = pd.read_parquet(train_path)
val_df = pd.read_parquet(val_path)

# Create different pipelines for training and inference
train_pipeline = build_mpnn_transform_pipeline(
    model_type=model_type,
    is_inference=False,
    minimal_return=True,
    train_structure_noise_default=train_structure_noise_default,
)
inference_pipeline = build_mpnn_transform_pipeline(
    model_type=model_type, is_inference=True, minimal_return=True
)

# Create train dataset with fallback
train_structural_dataset = StructuralDatasetWrapper(
    dataset=PandasDataset(
        data=train_df,
        id_column="example_id",
        name="pn_units_df_train",
        filters=MPNN_TRAIN_FILTERS,
    ),
    dataset_parser=GenericDFParser(
        example_id_colname="example_id",
        path_colname="path",
        assembly_id_colname="assembly_id",
    ),
    transform=train_pipeline,
    cif_parser_args={
        **STANDARD_PARSER_ARGS,
        "load_from_cache": True,
        "save_to_cache": True,
        "cache_dir": "/net/tukwila/akubaney/cifutils/cache",
    },
)

train_fallback_dataset = StructuralDatasetWrapper(
    dataset=PandasDataset(
        data=train_df,
        id_column="example_id",
        name="pn_units_df_train",
        filters=MPNN_TRAIN_FILTERS,
    ),
    dataset_parser=GenericDFParser(
        example_id_colname="example_id",
        path_colname="path",
        assembly_id_colname="assembly_id",
    ),
    transform=train_pipeline,
    cif_parser_args={
        **STANDARD_PARSER_ARGS,
        "load_from_cache": True,
        "save_to_cache": True,
        "cache_dir": "/net/tukwila/akubaney/cifutils/cache",
    },
)

# Calculate weights for train dataset
train_weights = calculate_weights_for_pdb_dataset_df(
    dataset_df=train_structural_dataset.data,
    beta=1.0,  # For chains
    alphas={"a_prot": 1.0, "a_nuc": 0, "a_ligand": 0, "a_loi": 0},
)

# Create train sampler with fallback
train_sampler = DistributedMixedSampler(
    datasets_info=[
        {
            "sampler": WeightedRandomSampler(
                train_weights, len(train_structural_dataset)
            ),
            "dataset": train_structural_dataset,
            "probability": 1.0,
        }
    ],
    num_replicas=1,
    rank=0,
    n_examples_per_epoch=20000,
)

train_fallback_sampler = WeightedRandomSampler(
    train_weights, len(train_structural_dataset)
)

train_dataset_with_fallback, train_sampler_with_fallback = (
    wrap_dataset_and_sampler_with_fallbacks(
        dataset_to_be_wrapped=train_structural_dataset,
        sampler_to_be_wrapped=train_sampler,
        dataset_to_fallback_to=train_fallback_dataset,
        sampler_to_fallback_to=train_fallback_sampler,
        n_fallback_retries=5,
    )
)

batched_train_sampler = PaddedTokenBudgetBatchSampler(
    sampler=train_sampler_with_fallback,
    get_num_tokens=lambda idx: get_num_tokens(train_structural_dataset.data, idx),
    max_tokens_with_padding=batch_size,
    shuffle_batches=True,
)

# Create val dataset with fallback
val_structural_dataset = StructuralDatasetWrapper(
    dataset=PandasDataset(
        data=val_df,
        id_column="example_id",
        name="pn_units_df_val",
        filters=MPNN_FILTERS,
    ),
    dataset_parser=GenericDFParser(
        example_id_colname="example_id",
        path_colname="path",
        assembly_id_colname="assembly_id",
    ),
    transform=inference_pipeline,
    cif_parser_args={
        **STANDARD_PARSER_ARGS,
        "load_from_cache": True,
        "save_to_cache": True,
        "cache_dir": "/net/tukwila/akubaney/cifutils/cache",
    },
)

# Create val sampler with fallback
val_weights = calculate_weights_for_pdb_dataset_df(
    dataset_df=val_structural_dataset.data,
    beta=1.0,  # For chains
    alphas={"a_prot": 1.0, "a_nuc": 0, "a_ligand": 0, "a_loi": 0},
)

val_sampler = DistributedMixedSampler(
    datasets_info=[
        {
            "sampler": WeightedRandomSampler(val_weights, len(val_structural_dataset)),
            "dataset": val_structural_dataset,
            "probability": 1.0,
        }
    ],
    num_replicas=1,
    rank=0,
    n_examples_per_epoch=100,
)

# Create collator
collator = TokenBudgetAwareFeatureCollator(max_tokens_with_padding=batch_size)

# Create DataLoaders
train_loader = DataLoader(
    train_dataset_with_fallback,
    batch_sampler=batched_train_sampler,
    num_workers=12,
    collate_fn=collator,
)

val_loaders = {
    "test_val": DataLoader(
        val_structural_dataset,
        sampler=val_sampler,
        num_workers=12,
        collate_fn=collator,
    )
}

# Create output directory for logs and checkpoints
output_dir = Path(f"./mpnn_output_{model_type}")
output_dir.mkdir(exist_ok=True)

# Create CSV logging callback
csv_callback = StoreValidationMetricsInDFCallback(
    save_dir=output_dir / "val_metrics", metrics_to_save="all"
)

# Create trainer with minimal configuration for testing
trainer = MPNNTrainer(
    model_type=model_type,
    accelerator="gpu",
    devices_per_node=1,
    max_epochs=500,
    output_dir=output_dir,
    callbacks=[csv_callback],
    precision="bf16-mixed",
    clip_grad_max_norm=clip_grad_max_norm,
)

# Create minimal train_cfg for optimizer and scheduler construction
train_cfg = DictConfig(
    {
        "model": {
            "optimizer": {
                "_target_": "torch.optim.Adam",
                "lr": 1.0,  # This will be overridden by the NoamOpt scheduler
                "betas": [0.9, 0.98],  # NoamOpt uses (0.9, 0.98)
                "eps": 1e-9,  # NoamOpt uses 1e-9
                "weight_decay": 0.0,
            },
            "lr_scheduler": {
                "_target_": "__main__.create_noam_scheduler",
                "d_model": 128,  # Adjust based on your model's hidden dimension
                "warmup_steps": 4000,
                "factor": 2,
            },
        }
    }
)

# Initialize trainer state with train_cfg
trainer.initialize_or_update_trainer_state({"train_cfg": train_cfg})

# Launch Fabric (this sets up the distributed environment)
trainer.fabric.launch()

# Construct model
trainer.construct_model()

# Construct optimizer and scheduler
trainer.construct_optimizer()
trainer.construct_scheduler()


class CkptConfig:
    def __init__(self, path, weight_loading_config=None, reset_optimizer=False):
        self.path = path
        self.weight_loading_config = weight_loading_config
        self.reset_optimizer = reset_optimizer


ckpt_dir = output_dir / "ckpt"
if ckpt_dir.exists():
    ckpt_config = CkptConfig(
        path=ckpt_dir, weight_loading_config=None, reset_optimizer=False
    )
else:
    ckpt_config = None

# Run the full training using fit method
print("Starting training...")
trainer.fit(train_loader=train_loader, val_loaders=val_loaders, ckpt_config=ckpt_config)
print("Training completed!")
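Since the script does all of its work at import time, it is run directly with the model type as its first argument (protein_mpnn or ligand_mpnn), either through the mpnn_exec.sh shebang wrapper or an equivalent Python invocation. The schedule produced by create_noam_scheduler can be sanity-checked in isolation; the following sketch assumes the function is in scope (e.g. copied from the script above) and is not part of the package.

# Hypothetical sanity check of the Noam schedule, not part of the package.
import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=1.0, betas=(0.9, 0.98), eps=1e-9)
sched = create_noam_scheduler(opt, d_model=128, warmup_steps=4000, factor=2)

for _ in range(4000):
    opt.step()    # no-op here (no gradients), but keeps the expected call order
    sched.step()

# With lr=1.0 in the optimizer, the effective rate is
# factor * d_model**-0.5 * min(step**-0.5, step * warmup**-1.5),
# which peaks near 2 * 128**-0.5 * 4000**-0.5 ≈ 2.8e-3 at the end of warmup
# and decays as step**-0.5 afterwards.
print(opt.param_groups[0]["lr"])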