rc-foundry 0.1.1 (rc_foundry-0.1.1-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
foundry/utils/datasets.py
ADDED

@@ -0,0 +1,405 @@
import hydra
import torch
from atomworks.ml.samplers import (
    DistributedMixedSampler,
    FallbackSamplerWrapper,
    LazyWeightedRandomSampler,
    LoadBalancedDistributedSampler,
    MixedSampler,
)
from beartype.typing import Any
from omegaconf import DictConfig, ListConfig
from torch.utils.data import (
    DataLoader,
    Dataset,
    RandomSampler,
    Sampler,
    SequentialSampler,
    Subset,
    WeightedRandomSampler,
)
from torch.utils.data.distributed import DistributedSampler

from foundry.hydra.resolvers import register_resolvers
from foundry.utils.ddp import RankedLogger

ranked_logger = RankedLogger(__name__, rank_zero_only=True)
try:
    from atomworks.ml.datasets import (
        ConcatDatasetWithID,
        FallbackDatasetWrapper,
        get_row_and_index_by_example_id,
    )
except Exception as e:
    ranked_logger.warning(
        f"Failed to import atomworks.ml.datasets: {type(e).__name__}: {e}. "
        "If training networks, the PDB_MIRROR environment variable must be set."
    )


register_resolvers()


def wrap_dataset_and_sampler_with_fallbacks(
    dataset_to_be_wrapped: Dataset,
    sampler_to_be_wrapped: Sampler,
    dataset_to_fallback_to: Dataset,
    sampler_to_fallback_to: Sampler,
    n_fallback_retries: int,
) -> tuple[Dataset, Sampler]:
    """Wrap the specified dataset and sampler with fallback dataloading.

    If the provided fallback sampler does not have weights (e.g., a MixedSampler), we will use uniform weights.

    Args:
        dataset_to_be_wrapped (Dataset): The main dataset to be wrapped.
        sampler_to_be_wrapped (Sampler): The main sampler to be wrapped.
        dataset_to_fallback_to (Dataset): The fallback dataset. We will sample from this dataset if the main dataset fails.
        sampler_to_fallback_to (Sampler): The fallback sampler. We will sample from this sampler if the main sampler fails.
        n_fallback_retries (int): Number of retries for the fallback mechanism before raising an exception.

    Returns:
        tuple[Dataset, Sampler]: The wrapped dataset and sampler with fallbacks.
    """
    # Instantiate a new fallback sampler to avoid scaling issues
    fallback_sampler = LazyWeightedRandomSampler(
        weights=sampler_to_fallback_to.weights
        if "weights" in sampler_to_fallback_to
        else torch.ones(len(dataset_to_fallback_to)),
        num_samples=int(1e9),
        replacement=True,  # replacement for fallback dataloading, so we can draw a huge number of samples
        generator=None,
        prefetch_buffer_size=4,
    )

    # Wrap the dataset and sampler with fallback mechanisms
    wrapped_dataset = FallbackDatasetWrapper(
        dataset_to_be_wrapped, fallback_dataset=dataset_to_fallback_to
    )
    wrapped_sampler = FallbackSamplerWrapper(
        sampler_to_be_wrapped,
        fallback_sampler=fallback_sampler,
        n_fallback_retries=n_fallback_retries,
    )

    return wrapped_dataset, wrapped_sampler


def instantiate_single_dataset_and_sampler(cfg: DictConfig | dict) -> dict[str, Any]:
    """Instantiate a dataset and its corresponding sampler from a configuration dictionary.

    Args:
        cfg (DictConfig): Configuration dictionary defining the dataset and its parameters.

    Returns:
        dict[str, Any]: A dictionary containing the instantiated dataset and sampler.
    """
    # ... instantiate the dataset
    dataset = hydra.utils.instantiate(cfg.dataset)

    # Users may provide only weights, in which case we will use a WeightedRandomSampler,
    # or they may provide a sampler directly

    if "weights" in cfg and "sampler" not in cfg:
        # ... instantiate the weights and create a WeightedRandomSampler
        weights = hydra.utils.instantiate(cfg.weights, dataset_df=dataset.data)
        sampler = WeightedRandomSampler(
            weights=weights,
            num_samples=len(dataset),
            replacement=True,
        )
    elif "sampler" in cfg and "weights" not in cfg:
        # ... instantiate the sampler with the number of samples
        sampler = hydra.utils.instantiate(cfg.sampler)
    else:
        dataset_name = getattr(getattr(cfg.dataset, "dataset", None), "name", None)
        ranked_logger.warning(
            f"No weights or sampler provided for dataset: {dataset_name}, using uniform weights with replacement."
        )
        sampler = WeightedRandomSampler(
            weights=torch.ones(len(dataset)),
            num_samples=len(dataset),
            replacement=True,
        )

    return {"dataset": dataset, "sampler": sampler}


def recursively_instantiate_datasets_and_samplers(
    cfg: DictConfig | dict, name: str | None = None
) -> dict[str, Any]:
    """Recursively instantiate datasets and samplers from a configuration dictionary.

    We must handle three cases:
    (1) A single "leaf" dataset (e.g., "distillation"), specified with the "dataset" key
    (2) Multiple sub-datasets that should be concatenated together with their weights (e.g., "interfaces" and "pn_units"),
        specified with the "sub_datasets" key
    (3) Multiple "leaf" datasets that should be sampled from with a certain probability (e.g., "distillation" and "pdb"),
        specified as sibling entries that each carry a "probability" key

    Args:
        cfg (DictConfig): Configuration dictionary defining datasets and their parameters.
        name (str, optional): The name of the dataset, used for reporting. Defaults to None.

    Returns:
        dict[str, Any]: A dictionary containing the instantiated dataset and sampler.
    """
    # ------- Base case (1): A single "leaf" dataset -------#
    if "dataset" in cfg:
        return {**instantiate_single_dataset_and_sampler(cfg), "name": name}

    # ------- Recursive case (2): Multiple sub-datasets that must be concatenated together -------#
    elif "sub_datasets" in cfg:
        # ... create a list of dictionaries for each sub-dataset
        datasets_info = []
        for sub_dataset_name, sub_dataset_cfg in cfg.sub_datasets.items():
            if sub_dataset_cfg is None:
                # (Skip any None sub-datasets; e.g., those overridden by the experiment config)
                continue

            datasets_info.append(
                recursively_instantiate_datasets_and_samplers(
                    sub_dataset_cfg, name=sub_dataset_name
                )
            )

        # ... concatenate sub-datasets and weights (e.g., "interfaces" and "pn_units" into one ConcatDataset)
        # NOTE: Order of the weights must match the order of the datasets!
        concatenated_dataset = ConcatDatasetWithID(
            datasets=[info["dataset"] for info in datasets_info]
        )
        concatenated_weights = torch.cat(
            [info["sampler"].weights for info in datasets_info]
        )
        sampler = WeightedRandomSampler(
            weights=concatenated_weights,
            num_samples=len(concatenated_dataset),
            replacement=True,
        )

        return {"dataset": concatenated_dataset, "sampler": sampler, "name": name}

    # ------- Recursive case (3): Multiple datasets that must be sampled from with specified probabilities -------#
    else:
        datasets_info = []
        for nested_dataset_name, nested_dataset_cfg in cfg.items():
            if nested_dataset_cfg is None:
                # (Skip any None training datasets; e.g., those overridden by the experiment config)
                continue

            # (To use a MixedSampler, we must provide a "probability" key for each dataset)
            assert (
                "probability" in nested_dataset_cfg
            ), "Expected 'probability' key in dataset configuration"
            datasets_info.append(
                {
                    **recursively_instantiate_datasets_and_samplers(
                        nested_dataset_cfg, name=nested_dataset_name
                    ),
                    "probability": nested_dataset_cfg["probability"],
                }
            )

        # ... check that the sum of probabilities of all datasets is 1
        assert (
            abs(1 - sum(dataset_info["probability"] for dataset_info in datasets_info))
            < 1e-5
        ), "Sum of probabilities must be 1.0"

        # ... compose the list of datasets into a single dataset
        composed_train_dataset = ConcatDatasetWithID(
            datasets=[dataset["dataset"] for dataset in datasets_info]
        )

        composed_train_sampler = MixedSampler(datasets_info=datasets_info, shuffle=True)

        return {
            "dataset": composed_train_dataset,
            "sampler": composed_train_sampler,
            "name": name,
        }


def assemble_distributed_loader(
    dataset: Dataset,
    sampler: Sampler | None = None,
    rank: int | None = None,
    world_size: int | None = None,
    n_examples_per_epoch: int | None = None,
    loader_cfg: DictConfig | dict | None = None,
    shuffle: bool = True,
    drop_last: bool = False,
) -> DataLoader:
    """Assembles a distributed DataLoader for training or validation.

    Performs the following steps:
    (1) If not already a distributed sampler, wraps the sampler with a DistributedSampler or DistributedMixedSampler
    (2) Wraps the dataset and sampler with a fallback mechanism, if needed
    (3) Assembles the final DataLoader

    Args:
        dataset (Dataset): The dataset to be used for training or validation.
        sampler (Sampler): The sampler to be used for training or validation. May already be distributed.
        rank (int): The rank of the current process in distributed training.
        world_size (int): The total number of processes participating in the distributed training.
        n_examples_per_epoch (int): The number of examples to sample per epoch, across all GPUs.
            For example, if we have 8 GPUs, with 2 gradient accumulation steps and 10 optimizer
            steps per epoch, we would sample 160 examples per epoch (8 * 2 * 10).
        loader_cfg (DictConfig or dict, optional): Additional configuration parameters for the
            DataLoader, such as `batch_size` and `num_workers`. Defaults to an empty dictionary.
        shuffle (bool, optional): Whether to shuffle the dataset. Defaults to True.
        drop_last (bool, optional): Whether to drop the last incomplete batch if the dataset size
            is not divisible by the number of GPUs. Defaults to False.

    Returns:
        DataLoader: A PyTorch DataLoader configured for distributed training, with datasets
            concatenated and sampled according to their defined probabilities.
    """
    if not loader_cfg:
        loader_cfg = {}

    if isinstance(sampler, MixedSampler):
        # (If given a MixedSampler, we must convert to a DistributedMixedSampler)
        assert (
            rank is not None
            and world_size is not None
            and n_examples_per_epoch is not None
        ), "Rank, world_size, and n_examples_per_epoch must be provided for MixedSampler"
        sampler = DistributedMixedSampler(
            datasets_info=sampler.datasets_info,
            num_replicas=world_size,
            rank=rank,
            n_examples_per_epoch=n_examples_per_epoch,
            shuffle=shuffle,
            drop_last=drop_last,
        )
    elif isinstance(sampler, (RandomSampler, SequentialSampler)):
        # (If given a RandomSampler or SequentialSampler, we must convert to a DistributedSampler)
        assert (
            rank is not None and world_size is not None
        ), "Rank and world_size must be provided for RandomSampler or SequentialSampler"
        sampler = DistributedSampler(
            dataset=dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=shuffle,
            drop_last=drop_last,
        )
    elif sampler is None and isinstance(dataset, Subset):
        # We are subsetting the dataset to a specific set of example IDs
        ranked_logger.info(f"Subsetting dataset to {len(dataset)} examples!")
    else:
        # (We assume we are already given a DistributedSampler or DistributedMixedSampler)
        assert (
            rank is None and world_size is None
        ), "Rank and world_size will have no effect on the provided sampler and should be None"
        assert isinstance(
            sampler, (DistributedSampler, DistributedMixedSampler)
        ), "Invalid sampler type for distributed training."

    # ... wrap the composed dataset and sampler with a fallback mechanism, if needed
    if (
        "n_fallback_retries" in loader_cfg
        and loader_cfg.n_fallback_retries > 0
        and sampler is not None
    ):
        ranked_logger.info(
            f"Wrapping train dataset and sampler with {loader_cfg.n_fallback_retries} fallbacks..."
        )
        dataset, sampler = wrap_dataset_and_sampler_with_fallbacks(
            dataset_to_be_wrapped=dataset,
            sampler_to_be_wrapped=sampler,
            dataset_to_fallback_to=dataset,
            sampler_to_fallback_to=sampler,
            n_fallback_retries=loader_cfg.n_fallback_retries,
        )

    # ... assemble the final loader
    loader = DataLoader(
        dataset=dataset,
        sampler=sampler,
        collate_fn=lambda x: x,  # No collation
        **loader_cfg.dataloader_params if "dataloader_params" in loader_cfg else {},
    )

    return loader


def subset_dataset_to_example_ids(
    dataset: Dataset,
    example_ids: list[str] | ListConfig,
) -> Dataset:
    """Subset a dataset to a specific set of example IDs."""
    indices = []
    for example_id in example_ids:
        index = get_row_and_index_by_example_id(dataset, example_id)["index"]
        indices.append(index)

    return Subset(dataset, indices)


def assemble_val_loader_dict(
    cfg: DictConfig,
    rank: int = 0,
    world_size: int = 1,
    loader_cfg: DictConfig | dict | None = None,
) -> dict[str, DataLoader]:
    """Assemble a dictionary of validation loaders for multiple datasets.

    If a key is provided to balance the dataset, we will use a LoadBalancedDistributedSampler
    rather than a DistributedSampler to maintain a balanced example load across processes
    (i.e., avoid a situation where one GPU is allocated all small examples and another all large examples).

    Args:
        cfg (DictConfig): Configuration dictionary defining the validation datasets. Each key should correspond to a dataset name.
        rank (int, optional): The rank of the current process in distributed training. Defaults to 0.
        world_size (int, optional): The total number of processes participating in the distributed training. Defaults to 1.
        loader_cfg (DictConfig, optional): Additional configuration parameters for the DataLoader, such as `batch_size` and `num_workers`. Defaults to None.
    """
    # ... loop through the validation datasets and create a DataLoader for each, preserving the dataset name
    val_loaders = {}
    for val_dataset_name, val_dataset in cfg.items():
        if not val_dataset:
            # (Skip any None validation datasets; e.g., those overridden by the experiment config)
            continue

        assert (
            "dataset" in val_dataset
        ), f"Expected 'dataset' key in validation dataset config for {val_dataset_name}"
        dataset = hydra.utils.instantiate(
            val_dataset.dataset
        )  # directly instantiate the dataset

        if "key_to_balance" in val_dataset and val_dataset.key_to_balance:
            # (If a key is provided to balance the dataset, we will use a LoadBalancedDistributedSampler)
            key_to_balance = val_dataset.key_to_balance
            ranked_logger.info(f"Balancing dataset with key: {key_to_balance}")

            assert (
                key_to_balance in dataset.data.columns
            ), f"Key {key_to_balance} not found in dataset columns!"

            sampler = LoadBalancedDistributedSampler(
                dataset=dataset,
                num_replicas=world_size,
                rank=rank,
                key_to_balance=key_to_balance,
            )
        else:
            # (Otherwise, we will use a DistributedSampler, without regard to sample size)
            sampler = DistributedSampler(
                dataset,
                num_replicas=world_size,
                rank=rank,
                shuffle=False,
                drop_last=False,
            )

        val_loader = assemble_distributed_loader(
            dataset=dataset,
            sampler=sampler,
            loader_cfg=loader_cfg,
        )

        val_loaders[val_dataset_name] = val_loader

    return val_loaders
foundry/utils/ddp.py
ADDED
@@ -0,0 +1,103 @@
import logging

import torch
from beartype.typing import Any
from lightning_fabric.utilities import rank_zero_only
from lightning_utilities.core.rank_zero import rank_prefixed_message
from omegaconf import DictConfig

logger = logging.getLogger(__name__)


def get_current_rank() -> int:
    """Returns the rank of the current process."""
    return getattr(rank_zero_only, "rank", None)


def is_rank_zero() -> bool:
    """Returns whether the current process is rank zero."""
    return get_current_rank() == 0


def set_accelerator_based_on_availability(cfg: dict | DictConfig):
    """Set training accelerator to CPU if no GPUs are available.

    Args:
        cfg: Hydra object with trainer settings "accelerator", "devices_per_node", and "num_nodes".

    Returns:
        None; modifies the input `cfg` object in place.
    """
    if not torch.cuda.is_available():
        logger.error(
            "No GPUs available - Setting accelerator to 'cpu'. Are you sure you are using the correct configs?"
        )
        assert "trainer" in cfg, "Configuration object must have a 'trainer' key."
        for key in ["accelerator", "devices_per_node", "num_nodes"]:
            assert (
                key in cfg.trainer
            ), f"Configuration object must have a 'trainer.{key}' key."

        # Override accelerator settings
        cfg.trainer.accelerator = "cpu"
        cfg.trainer.devices_per_node = 1
        cfg.trainer.num_nodes = 1
    else:
        cfg.trainer.accelerator = "gpu"
    return cfg


class RankedLogger(logging.LoggerAdapter):
    """A multi-GPU-friendly python command line logger.

    Modified from https://github.com/ashleve/lightning-hydra-template/blob/main/src/utils/pylogger.py
    """

    def __init__(
        self,
        name: str = __name__,
        rank_zero_only: bool = False,
        extra: Any | None = None,
    ) -> None:
        """Initializes a multi-GPU-friendly python command line logger that logs on all processes
        with their rank prefixed in the log message.

        :param name: The name of the logger. Default is ``__name__``.
        :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
        :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
        """
        logger = logging.getLogger(name)
        super().__init__(logger=logger, extra=extra)
        self.rank_zero_only = rank_zero_only

    def log(
        self, level: int, msg: str, rank: int | None = None, *args, **kwargs
    ) -> None:
        """
        Delegate a log call to the underlying logger, after prefixing its message with the rank
        of the process it's being logged from. If `'rank'` is provided, then the log will only
        occur on that rank/process.

        Args:
            level (int): The level to log at. Look at `logging.__init__.py` for more information.
            msg (str): The message to log.
            rank (Optional[int]): The rank to log at.
            args: Additional args to pass to the underlying logging function.
            kwargs: Any additional keyword args to pass to the underlying logging function.
        """
        if self.isEnabledFor(level):
            msg, kwargs = self.process(msg, kwargs)
            current_rank = getattr(rank_zero_only, "rank", None)
            if current_rank is None:
                raise RuntimeError(
                    "The `rank_zero_only.rank` needs to be set before use"
                )
            msg = rank_prefixed_message(msg, current_rank)
            if self.rank_zero_only:
                if current_rank == 0:
                    self.logger.log(level, msg, *args, **kwargs)
            else:
                if rank is None:
                    self.logger.log(level, msg, *args, **kwargs)
                elif current_rank == rank:
                    self.logger.log(level, msg, *args, **kwargs)
foundry/utils/instantiators.py
ADDED

@@ -0,0 +1,72 @@
import hydra
from lightning.fabric.loggers import Logger
from omegaconf import DictConfig

from foundry.callbacks.callback import BaseCallback


def _can_be_instantiated(cfg: DictConfig) -> bool:
    """Checks if a config can be instantiated."""
    return isinstance(cfg, DictConfig) and "_target_" in cfg


class InstantiationError(ValueError):
    """Raised when a config cannot be instantiated."""

    pass


def instantiate_callbacks(callbacks_cfg: DictConfig | None) -> list[BaseCallback]:
    """Instantiates callbacks from config.

    Args:
        callbacks_cfg: A DictConfig object containing callback configurations.

    Returns:
        A list of instantiated callbacks.

    Reference:
        - Lightning Hydra Template (https://github.com/ashleve/lightning-hydra-template/blob/main/src/utils/instantiators.py#L36)
    """
    callbacks: list[BaseCallback] = []

    if not callbacks_cfg:
        return callbacks

    for _, cb_conf in callbacks_cfg.items():
        if _can_be_instantiated(cb_conf):
            callbacks.append(hydra.utils.instantiate(cb_conf))
        else:
            raise InstantiationError(
                f"Skipping callback <{cb_conf}> - Not a DictConfig with `_target_` key! Please provide a valid `_target_` for instantiation."
            )

    return callbacks


def instantiate_loggers(logger_cfg: DictConfig | None) -> list[Logger]:
    """Instantiates loggers from config.

    Args:
        logger_cfg: A DictConfig object containing logger configurations.

    Returns:
        A list of instantiated loggers.

    Reference:
        - Lightning Hydra Template (https://github.com/ashleve/lightning-hydra-template/blob/main/src/utils/instantiators.py#L36)
    """
    loggers: list[Logger] = []

    if not logger_cfg:
        return loggers

    for _, lg_conf in logger_cfg.items():
        if _can_be_instantiated(lg_conf):
            loggers.append(hydra.utils.instantiate(lg_conf))
        else:
            raise InstantiationError(
                f"Skipping logger <{lg_conf}> - Not a DictConfig with `_target_` key! Please provide a valid `_target_` for instantiation."
            )

    return loggers
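
A brief usage sketch of the instantiators (assumptions: this hunk is foundry/utils/instantiators.py, inferred from the +72 entry in the file list; lightning.fabric.loggers.CSVLogger is used only as a convenient concrete Logger; the config keys are illustrative).

# Hedged usage sketch, not part of the package.
from omegaconf import OmegaConf

from foundry.utils.instantiators import InstantiationError, instantiate_loggers  # inferred module path

logger_cfg = OmegaConf.create(
    {
        "csv": {
            "_target_": "lightning.fabric.loggers.CSVLogger",
            "root_dir": "logs/",
            "name": "demo_run",
        }
    }
)
print(instantiate_loggers(logger_cfg))  # [<CSVLogger ...>]

# Entries without a _target_ key are rejected rather than silently skipped.
try:
    instantiate_loggers(OmegaConf.create({"csv": {"root_dir": "logs/"}}))
except InstantiationError as err:
    print(err)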