rc_foundry-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
foundry/utils/datasets.py ADDED
@@ -0,0 +1,405 @@
+import hydra
+import torch
+from atomworks.ml.samplers import (
+    DistributedMixedSampler,
+    FallbackSamplerWrapper,
+    LazyWeightedRandomSampler,
+    LoadBalancedDistributedSampler,
+    MixedSampler,
+)
+from beartype.typing import Any
+from omegaconf import DictConfig, ListConfig
+from torch.utils.data import (
+    DataLoader,
+    Dataset,
+    RandomSampler,
+    Sampler,
+    SequentialSampler,
+    Subset,
+    WeightedRandomSampler,
+)
+from torch.utils.data.distributed import DistributedSampler
+
+from foundry.hydra.resolvers import register_resolvers
+from foundry.utils.ddp import RankedLogger
+
+ranked_logger = RankedLogger(__name__, rank_zero_only=True)
+try:
+    from atomworks.ml.datasets import (
+        ConcatDatasetWithID,
+        FallbackDatasetWrapper,
+        get_row_and_index_by_example_id,
+    )
+except Exception as e:
+    ranked_logger.warning(
+        f"Failed to import atomworks.ml.datasets: {type(e).__name__}: {e}. "
+        "If training networks, the PDB_MIRROR environment variable must be set."
+    )
+
+
+register_resolvers()
+
+
+def wrap_dataset_and_sampler_with_fallbacks(
+    dataset_to_be_wrapped: Dataset,
+    sampler_to_be_wrapped: Sampler,
+    dataset_to_fallback_to: Dataset,
+    sampler_to_fallback_to: Sampler,
+    n_fallback_retries: int,
+) -> tuple[Dataset, Sampler]:
+    """Wrap the specified dataset and sampler with fallback dataloading.
+
+    If the provided fallback sampler does not have weights (e.g., a MixedSampler), we will use uniform weights.
+
+    Args:
+        dataset_to_be_wrapped (Dataset): The main dataset to be wrapped.
+        sampler_to_be_wrapped (Sampler): The main sampler to be wrapped.
+        dataset_to_fallback_to (Dataset): The fallback dataset. We will sample from this dataset if the main dataset fails.
+        sampler_to_fallback_to (Sampler): The fallback sampler. We will sample from this sampler if the main sampler fails.
+        n_fallback_retries (int): Number of retries for the fallback mechanism before raising an exception.
+
+    Returns:
+        tuple[Dataset, Sampler]: The wrapped dataset and sampler with fallbacks.
+    """
+    # Instantiate a new fallback sampler to avoid scaling issues
+    fallback_sampler = LazyWeightedRandomSampler(
+        weights=sampler_to_fallback_to.weights
+        if hasattr(sampler_to_fallback_to, "weights")
+        else torch.ones(len(dataset_to_fallback_to)),
+        num_samples=int(1e9),
+        replacement=True,  # replacement for fallback dataloading, so we can draw a huge number of samples
+        generator=None,
+        prefetch_buffer_size=4,
+    )
+
+    # Wrap the dataset and sampler with fallback mechanisms
+    wrapped_dataset = FallbackDatasetWrapper(
+        dataset_to_be_wrapped, fallback_dataset=dataset_to_fallback_to
+    )
+    wrapped_sampler = FallbackSamplerWrapper(
+        sampler_to_be_wrapped,
+        fallback_sampler=fallback_sampler,
+        n_fallback_retries=n_fallback_retries,
+    )
+
+    return wrapped_dataset, wrapped_sampler
+
+
+def instantiate_single_dataset_and_sampler(cfg: DictConfig | dict) -> dict[str, Any]:
+    """Instantiate a dataset and its corresponding sampler from a configuration dictionary.
+
+    Args:
+        cfg (DictConfig): Configuration dictionary defining the dataset and its parameters.
+
+    Returns:
+        dict[str, Any]: A dictionary containing the instantiated dataset and sampler.
+    """
+    # ... instantiate the dataset
+    dataset = hydra.utils.instantiate(cfg.dataset)
+
+    # Users may provide only weights, in which case we will use a WeightedRandomSampler,
+    # or they may provide a sampler directly
+
+    if "weights" in cfg and "sampler" not in cfg:
+        # ... instantiate the weights and create a WeightedRandomSampler
+        weights = hydra.utils.instantiate(cfg.weights, dataset_df=dataset.data)
+        sampler = WeightedRandomSampler(
+            weights=weights,
+            num_samples=len(dataset),
+            replacement=True,
+        )
+    elif "sampler" in cfg and "weights" not in cfg:
+        # ... instantiate the sampler directly from the config
+        sampler = hydra.utils.instantiate(cfg.sampler)
+    else:
+        dataset_name = getattr(getattr(cfg.dataset, "dataset", None), "name", None)
+        ranked_logger.warning(
+            f"No weights or sampler provided for dataset: {dataset_name}, using uniform weights with replacement."
+        )
+        sampler = WeightedRandomSampler(
+            weights=torch.ones(len(dataset)),
+            num_samples=len(dataset),
+            replacement=True,
+        )
+
+    return {"dataset": dataset, "sampler": sampler}
+
+
+def recursively_instantiate_datasets_and_samplers(
+    cfg: DictConfig | dict, name: str | None = None
+) -> dict[str, Any]:
+    """Recursively instantiate datasets and samplers from a configuration dictionary.
+
+    We must handle three cases:
+    (1) A single "leaf" dataset (e.g., "distillation"), specified with the "dataset" key
+    (2) Multiple sub-datasets that should be concatenated together with their weights (e.g., "interfaces" and "pn_units"),
+        specified with the "sub_datasets" key
+    (3) Multiple "leaf" datasets that should be sampled from with a certain probability (e.g., "distillation" and "pdb"), specified as a mapping of dataset names to configs that each include a "probability" key
+
+    Args:
+        cfg (DictConfig): Configuration dictionary defining datasets and their parameters.
+        name (str, optional): The name of the dataset, used for reporting. Defaults to None.
+
+    Returns:
+        dict[str, Any]: A dictionary containing the instantiated dataset and sampler.
+    """
+    # ------- Base case (1): A single "leaf" dataset -------#
+    if "dataset" in cfg:
+        return {**instantiate_single_dataset_and_sampler(cfg), "name": name}
+
+    # ------- Recursive case (2): Multiple sub-datasets that must be concatenated together -------#
+    elif "sub_datasets" in cfg:
+        # ... create a list of dictionaries for each sub-dataset
+        datasets_info = []
+        for sub_dataset_name, sub_dataset_cfg in cfg.sub_datasets.items():
+            if sub_dataset_cfg is None:
+                # (Skip any None sub-datasets; e.g., those overridden by the experiment config)
+                continue
+
+            datasets_info.append(
+                recursively_instantiate_datasets_and_samplers(
+                    sub_dataset_cfg, name=sub_dataset_name
+                )
+            )
+
+        # ... concatenate sub-datasets and weights (e.g., "interfaces" and "pn_units" into one ConcatDataset)
+        # NOTE: Order of the weights must match the order of the datasets!
+        concatenated_dataset = ConcatDatasetWithID(
+            datasets=[info["dataset"] for info in datasets_info]
+        )
+        concatenated_weights = torch.cat(
+            [info["sampler"].weights for info in datasets_info]
+        )
+        sampler = WeightedRandomSampler(
+            weights=concatenated_weights,
+            num_samples=len(concatenated_dataset),
+            replacement=True,
+        )
+
+        return {"dataset": concatenated_dataset, "sampler": sampler, "name": name}
+
+    # ------- Recursive case (3): Multiple datasets that must be sampled from with specified probabilities -------#
+    else:
+        datasets_info = []
+        for nested_dataset_name, nested_dataset_cfg in cfg.items():
+            if nested_dataset_cfg is None:
+                # (Skip any None training datasets; e.g., those overridden by the experiment config)
+                continue
+
+            # (To use a MixedSampler, we must provide a "probability" key for each dataset)
+            assert (
+                "probability" in nested_dataset_cfg
+            ), "Expected 'probability' key in dataset configuration"
+            datasets_info.append(
+                {
+                    **recursively_instantiate_datasets_and_samplers(
+                        nested_dataset_cfg, name=nested_dataset_name
+                    ),
+                    "probability": nested_dataset_cfg["probability"],
+                }
+            )
+
+        # ... check that the sum of probabilities of all datasets is 1
+        assert (
+            abs(1 - sum(dataset_info["probability"] for dataset_info in datasets_info))
+            < 1e-5
+        ), "Sum of probabilities must be 1.0"
+
+        # ... compose the list of datasets into a single dataset
+        composed_train_dataset = ConcatDatasetWithID(
+            datasets=[dataset["dataset"] for dataset in datasets_info]
+        )
+
+        composed_train_sampler = MixedSampler(datasets_info=datasets_info, shuffle=True)
+
+        return {
+            "dataset": composed_train_dataset,
+            "sampler": composed_train_sampler,
+            "name": name,
+        }
+
+
+def assemble_distributed_loader(
+    dataset: Dataset,
+    sampler: Sampler | None = None,
+    rank: int | None = None,
+    world_size: int | None = None,
+    n_examples_per_epoch: int | None = None,
+    loader_cfg: DictConfig | dict | None = None,
+    shuffle: bool = True,
+    drop_last: bool = False,
+) -> DataLoader:
+    """Assembles a distributed DataLoader for training or validation.
+
+    Performs the following steps:
+    (1) If not already a distributed sampler, wraps the sampler with a DistributedSampler or DistributedMixedSampler
+    (2) Wraps the dataset and sampler with a fallback mechanism, if needed
+    (3) Assembles the final DataLoader
+
+    Args:
+        dataset (Dataset): The dataset to be used for training or validation.
+        sampler (Sampler): The sampler to be used for training or validation. May already be distributed.
+        rank (int): The rank of the current process in distributed training.
+        world_size (int): The total number of processes participating in the distributed training.
+        n_examples_per_epoch (int): The number of examples to sample per epoch, across all GPUs.
+            For example, if we have 8 GPUs, with 2 gradient accumulation steps and 10 optimizer
+            steps per epoch, we would sample 160 examples per epoch (8 * 2 * 10).
+        loader_cfg (DictConfig or dict, optional): Additional configuration parameters for the
+            DataLoader, such as `batch_size` and `num_workers`. Defaults to an empty dictionary.
+        shuffle (bool, optional): Whether to shuffle the dataset. Defaults to True.
+        drop_last (bool, optional): Whether to drop the last incomplete batch if the dataset size
+            is not divisible by the number of GPUs. Defaults to False.
+
+    Returns:
+        DataLoader: A PyTorch DataLoader configured for distributed training, with datasets
+            concatenated and sampled according to their defined probabilities.
+    """
+    if not loader_cfg:
+        loader_cfg = {}
+
+    if isinstance(sampler, MixedSampler):
+        # (If given a MixedSampler, we must convert to a DistributedMixedSampler)
+        assert (
+            rank is not None
+            and world_size is not None
+            and n_examples_per_epoch is not None
+        ), "Rank, world_size, and n_examples_per_epoch must be provided for MixedSampler"
+        sampler = DistributedMixedSampler(
+            datasets_info=sampler.datasets_info,
+            num_replicas=world_size,
+            rank=rank,
+            n_examples_per_epoch=n_examples_per_epoch,
+            shuffle=shuffle,
+            drop_last=drop_last,
+        )
+    elif isinstance(sampler, (RandomSampler, SequentialSampler)):
+        # (If given a RandomSampler or SequentialSampler, we must convert to a DistributedSampler)
+        assert (
+            rank is not None and world_size is not None
+        ), "Rank and world_size must be provided for RandomSampler or SequentialSampler"
+        sampler = DistributedSampler(
+            dataset=dataset,
+            num_replicas=world_size,
+            rank=rank,
+            shuffle=shuffle,
+            drop_last=drop_last,
+        )
+    elif sampler is None and isinstance(dataset, Subset):
+        # We are subsetting the dataset to a specific set of example IDs
+        ranked_logger.info(f"Subsetting dataset to {len(dataset)} examples!")
+    else:
+        # (We assume we are already given a DistributedSampler or DistributedMixedSampler)
+        assert (
+            rank is None and world_size is None
+        ), "Rank and world_size will have no effect on the provided sampler and should be None"
+        assert isinstance(
+            sampler, (DistributedSampler, DistributedMixedSampler)
+        ), "Invalid sampler type for distributed training."
+
+    # ... wrap the composed dataset and sampler with a fallback mechanism, if needed
+    if (
+        "n_fallback_retries" in loader_cfg
+        and loader_cfg.n_fallback_retries > 0
+        and sampler is not None
+    ):
+        ranked_logger.info(
+            f"Wrapping train dataset and sampler with {loader_cfg.n_fallback_retries} fallbacks..."
+        )
+        dataset, sampler = wrap_dataset_and_sampler_with_fallbacks(
+            dataset_to_be_wrapped=dataset,
+            sampler_to_be_wrapped=sampler,
+            dataset_to_fallback_to=dataset,
+            sampler_to_fallback_to=sampler,
+            n_fallback_retries=loader_cfg.n_fallback_retries,
+        )
+
+    # ... assemble the final loader
+    loader = DataLoader(
+        dataset=dataset,
+        sampler=sampler,
+        collate_fn=lambda x: x,  # No collation
+        **(loader_cfg.dataloader_params if "dataloader_params" in loader_cfg else {}),
+    )
+
+    return loader
+
+
+def subset_dataset_to_example_ids(
+    dataset: Dataset,
+    example_ids: list[str] | ListConfig,
+) -> Dataset:
+    """Subset a dataset to a specific set of example IDs."""
+    indices = []
+    for example_id in example_ids:
+        index = get_row_and_index_by_example_id(dataset, example_id)["index"]
+        indices.append(index)
+
+    return Subset(dataset, indices)
+
+
+def assemble_val_loader_dict(
+    cfg: DictConfig,
+    rank: int = 0,
+    world_size: int = 1,
+    loader_cfg: DictConfig | dict | None = None,
+) -> dict[str, DataLoader]:
+    """Assemble a dictionary of validation loaders for multiple datasets.
+
+    If a key is provided to balance the dataset, we will use a LoadBalancedDistributedSampler
+    rather than a DistributedSampler to maintain a balanced example load across processes
+    (i.e., avoid a situation where one GPU is allocated all small examples and another all large examples).
+
+    Args:
+        cfg (DictConfig): Configuration dictionary defining the validation datasets. Each key should correspond to a dataset name.
+        rank (int, optional): The rank of the current process in distributed training. Defaults to 0.
+        world_size (int, optional): The total number of processes participating in the distributed training. Defaults to 1.
+        loader_cfg (DictConfig, optional): Additional configuration parameters for the DataLoader, such as `batch_size` and `num_workers`. Defaults to None.
+    """
+    # ... loop through the validation datasets and create a DataLoader for each, preserving the dataset name
+    val_loaders = {}
+    for val_dataset_name, val_dataset in cfg.items():
+        if not val_dataset:
+            # (Skip any None validation datasets; e.g., those overridden by the experiment config)
+            continue
+
+        assert (
+            "dataset" in val_dataset
+        ), f"Expected 'dataset' key in validation dataset config for {val_dataset_name}"
+        dataset = hydra.utils.instantiate(
+            val_dataset.dataset
+        )  # directly instantiate the dataset
+
+        if "key_to_balance" in val_dataset and val_dataset.key_to_balance:
+            # (If a key is provided to balance the dataset, we will use a LoadBalancedDistributedSampler)
+            key_to_balance = val_dataset.key_to_balance
+            ranked_logger.info(f"Balancing dataset with key: {key_to_balance}")
+
+            assert (
+                key_to_balance in dataset.data.columns
+            ), f"Key {key_to_balance} not found in dataset columns!"
+
+            sampler = LoadBalancedDistributedSampler(
+                dataset=dataset,
+                num_replicas=world_size,
+                rank=rank,
+                key_to_balance=key_to_balance,
+            )
+        else:
+            # (Otherwise, we will use a DistributedSampler, without regard to sample size)
+            sampler = DistributedSampler(
+                dataset,
+                num_replicas=world_size,
+                rank=rank,
+                shuffle=False,
+                drop_last=False,
+            )
+
+        val_loader = assemble_distributed_loader(
+            dataset=dataset,
+            sampler=sampler,
+            loader_cfg=loader_cfg,
+        )
+
+        val_loaders[val_dataset_name] = val_loader
+
+    return val_loaders
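
For orientation, here is a minimal usage sketch of the module above, following case (3) of recursively_instantiate_datasets_and_samplers (leaf datasets mixed by probability) with one case-(2) branch of concatenated sub-datasets. The my_project.* _target_ paths and the specific dataset names are hypothetical placeholders, not part of rc-foundry; only the structural keys (dataset, weights, sub_datasets, probability) and the function calls come from the code above.

# Hypothetical sketch: the `_target_` classes/functions are placeholders, assumed to
# provide `__len__` and (where weights are derived) a `.data` DataFrame.
from omegaconf import OmegaConf

from foundry.utils.datasets import (
    assemble_distributed_loader,
    recursively_instantiate_datasets_and_samplers,
)

train_cfg = OmegaConf.create(
    {
        # Case (1) leaf dataset, drawn with probability 0.3
        "distillation": {
            "probability": 0.3,
            "dataset": {"_target_": "my_project.data.DistillationDataset"},  # placeholder
        },
        # Case (2): two sub-datasets concatenated, then drawn with probability 0.7
        "pdb": {
            "probability": 0.7,
            "sub_datasets": {
                "interfaces": {
                    "dataset": {"_target_": "my_project.data.InterfaceDataset"},  # placeholder
                    "weights": {"_target_": "my_project.data.cluster_weights"},  # placeholder
                },
                "pn_units": {
                    "dataset": {"_target_": "my_project.data.PnUnitDataset"},  # placeholder
                },
            },
        },
    }
)

# Build the mixed dataset/sampler pair (probabilities must sum to 1.0) ...
train_info = recursively_instantiate_datasets_and_samplers(train_cfg, name="train")

# ... then wrap it into a rank-aware DataLoader; the MixedSampler is converted to a
# DistributedMixedSampler internally.
train_loader = assemble_distributed_loader(
    dataset=train_info["dataset"],
    sampler=train_info["sampler"],
    rank=0,
    world_size=1,
    n_examples_per_epoch=160,  # e.g., 8 GPUs * 2 accumulation steps * 10 optimizer steps
    loader_cfg=OmegaConf.create({"dataloader_params": {"num_workers": 4}}),
)

Validation loaders follow the same pattern through assemble_val_loader_dict, which instantiates each dataset entry directly and chooses between DistributedSampler and LoadBalancedDistributedSampler based on key_to_balance.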
foundry/utils/ddp.py ADDED
@@ -0,0 +1,103 @@
+import logging
+
+import torch
+from beartype.typing import Any
+from lightning_fabric.utilities import rank_zero_only
+from lightning_utilities.core.rank_zero import rank_prefixed_message
+from omegaconf import DictConfig
+
+logger = logging.getLogger(__name__)
+
+
+def get_current_rank() -> int | None:
+    """Returns the rank of the current process."""
+    return getattr(rank_zero_only, "rank", None)
+
+
+def is_rank_zero() -> bool:
+    """Returns whether the current process is rank zero."""
+    return get_current_rank() == 0
+
+
+def set_accelerator_based_on_availability(cfg: dict | DictConfig):
+    """Set training accelerator to CPU if no GPUs are available.
+
+    Args:
+        cfg: Hydra object with trainer settings "accelerator", "devices_per_node", and "num_nodes".
+
+    Returns:
+        The updated config object (which is also modified in place).
+    """
+    if not torch.cuda.is_available():
+        logger.error(
+            "No GPUs available - Setting accelerator to 'cpu'. Are you sure you are using the correct configs?"
+        )
+        assert "trainer" in cfg, "Configuration object must have a 'trainer' key."
+        for key in ["accelerator", "devices_per_node", "num_nodes"]:
+            assert (
+                key in cfg.trainer
+            ), f"Configuration object must have a 'trainer.{key}' key."
+
+        # Override accelerator settings
+        cfg.trainer.accelerator = "cpu"
+        cfg.trainer.devices_per_node = 1
+        cfg.trainer.num_nodes = 1
+    else:
+        cfg.trainer.accelerator = "gpu"
+    return cfg
+
+
+class RankedLogger(logging.LoggerAdapter):
+    """A multi-GPU-friendly python command line logger.
+
+    Modified from https://github.com/ashleve/lightning-hydra-template/blob/main/src/utils/pylogger.py
+    """
+
+    def __init__(
+        self,
+        name: str = __name__,
+        rank_zero_only: bool = False,
+        extra: Any | None = None,
+    ) -> None:
+        """Initializes a multi-GPU-friendly python command line logger that logs on all processes
+        with their rank prefixed in the log message.
+
+        :param name: The name of the logger. Default is ``__name__``.
+        :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
+        :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
+        """
+        logger = logging.getLogger(name)
+        super().__init__(logger=logger, extra=extra)
+        self.rank_zero_only = rank_zero_only
+
+    def log(
+        self, level: int, msg: str, rank: int | None = None, *args, **kwargs
+    ) -> None:
+        """
+        Delegate a log call to the underlying logger, after prefixing its message with the rank
+        of the process it's being logged from. If `rank` is provided, then the log will only
+        occur on that rank/process.
+
+        Args:
+            level (int): The level to log at. Look at `logging.__init__.py` for more information.
+            msg (str): The message to log.
+            rank (Optional[int]): The rank to log at.
+            args: Additional args to pass to the underlying logging function.
+            kwargs: Any additional keyword args to pass to the underlying logging function.
+        """
+        if self.isEnabledFor(level):
+            msg, kwargs = self.process(msg, kwargs)
+            current_rank = getattr(rank_zero_only, "rank", None)
+            if current_rank is None:
+                raise RuntimeError(
+                    "The `rank_zero_only.rank` needs to be set before use"
+                )
+            msg = rank_prefixed_message(msg, current_rank)
+            if self.rank_zero_only:
+                if current_rank == 0:
+                    self.logger.log(level, msg, *args, **kwargs)
+            else:
+                if rank is None:
+                    self.logger.log(level, msg, *args, **kwargs)
+                elif current_rank == rank:
+                    self.logger.log(level, msg, *args, **kwargs)
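
A brief usage sketch of RankedLogger and is_rank_zero: under a Lightning/Fabric launch, rank_zero_only.rank is populated automatically; it is set manually here only so the snippet runs standalone.

import logging

from lightning_fabric.utilities import rank_zero_only

from foundry.utils.ddp import RankedLogger, is_rank_zero

logging.basicConfig(level=logging.INFO)
rank_zero_only.rank = 0  # normally set by the launcher, not by hand

log_all = RankedLogger(__name__)  # logs on every rank, with a rank prefix
log_rank0 = RankedLogger(__name__, rank_zero_only=True)  # logs on rank 0 only

log_all.info("visible on every process")
log_rank0.info("visible on rank 0 only")
log_all.info("visible on rank 2 only", rank=2)  # per-call rank filter

if is_rank_zero():
    log_all.info("rank-zero-only setup work goes here")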
foundry/utils/instantiators.py ADDED
@@ -0,0 +1,72 @@
+import hydra
+from lightning.fabric.loggers import Logger
+from omegaconf import DictConfig
+
+from foundry.callbacks.callback import BaseCallback
+
+
+def _can_be_instantiated(cfg: DictConfig) -> bool:
+    """Checks if a config can be instantiated."""
+    return isinstance(cfg, DictConfig) and "_target_" in cfg
+
+
+class InstantiationError(ValueError):
+    """Raised when a config cannot be instantiated."""
+
+    pass
+
+
+def instantiate_callbacks(callbacks_cfg: DictConfig | None) -> list[BaseCallback]:
+    """Instantiates callbacks from config.
+
+    Args:
+        callbacks_cfg: A DictConfig object containing callback configurations.
+
+    Returns:
+        A list of instantiated callbacks.
+
+    Reference:
+        - Lightning Hydra Template (https://github.com/ashleve/lightning-hydra-template/blob/main/src/utils/instantiators.py#L36)
+    """
+    callbacks: list[BaseCallback] = []
+
+    if not callbacks_cfg:
+        return callbacks
+
+    for _, cb_conf in callbacks_cfg.items():
+        if _can_be_instantiated(cb_conf):
+            callbacks.append(hydra.utils.instantiate(cb_conf))
+        else:
+            raise InstantiationError(
+                f"Cannot instantiate callback <{cb_conf}> - not a DictConfig with a `_target_` key! Please provide a valid `_target_` for instantiation."
+            )
+
+    return callbacks
+
+
+def instantiate_loggers(logger_cfg: DictConfig | None) -> list[Logger]:
+    """Instantiates loggers from config.
+
+    Args:
+        logger_cfg: A DictConfig object containing logger configurations.
+
+    Returns:
+        A list of instantiated loggers.
+
+    Reference:
+        - Lightning Hydra Template (https://github.com/ashleve/lightning-hydra-template/blob/main/src/utils/instantiators.py#L36)
+    """
+    loggers: list[Logger] = []
+
+    if not logger_cfg:
+        return loggers
+
+    for _, lg_conf in logger_cfg.items():
+        if _can_be_instantiated(lg_conf):
+            loggers.append(hydra.utils.instantiate(lg_conf))
+        else:
+            raise InstantiationError(
+                f"Cannot instantiate logger <{lg_conf}> - not a DictConfig with a `_target_` key! Please provide a valid `_target_` for instantiation."
+            )
+
+    return loggers
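
Finally, a short sketch of how the two instantiators are typically driven from a Hydra config. The callback _target_ below is a hypothetical placeholder; lightning.fabric.loggers.CSVLogger is Fabric's built-in CSV logger. Note that entries without a _target_ raise InstantiationError rather than being silently skipped.

# Sketch under assumptions: "my_project.callbacks.TimingCallback" is a placeholder target.
from omegaconf import OmegaConf

from foundry.utils.instantiators import instantiate_callbacks, instantiate_loggers

callbacks_cfg = OmegaConf.create(
    {"timing": {"_target_": "my_project.callbacks.TimingCallback"}}  # placeholder
)
logger_cfg = OmegaConf.create(
    {"csv": {"_target_": "lightning.fabric.loggers.CSVLogger", "root_dir": "./logs"}}
)

callbacks = instantiate_callbacks(callbacks_cfg)  # -> [TimingCallback(...)]
loggers = instantiate_loggers(logger_cfg)  # -> [CSVLogger(...)]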