rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,319 @@
1
+ import inspect
2
+ from abc import ABC, abstractmethod
3
+ from functools import cached_property
4
+
5
+ import hydra
6
+ from atomworks.common import exists
7
+ from atomworks.ml.utils import error, nested_dict
8
+ from beartype.typing import Any
9
+ from omegaconf import DictConfig
10
+ from toolz import keymap
11
+
12
+ from foundry.utils.ddp import RankedLogger
13
+
14
# Module-level logger (used below to report which metrics get registered).
# NOTE(review): rank_zero_only=True presumably restricts output to the rank-0 process in
# distributed runs — confirm against foundry.utils.ddp.RankedLogger.
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
15
+
16
+
17
def instantiate_metric_manager(
    metrics_cfg: dict[str, Any] | DictConfig,
    *,
    raise_errors: bool = True,
) -> "MetricManager":
    """Instantiate a MetricManager from a dictionary of Hydra metric configurations.

    Args:
        metrics_cfg: Mapping from metric name to a Hydra configuration that
            ``hydra.utils.instantiate`` turns into a ``Metric`` instance.
        raise_errors: Forwarded to ``MetricManager``; whether errors raised while computing
            a metric are re-raised. Defaults to ``True``.

    Returns:
        A ``MetricManager`` holding the instantiated metrics under their configured names,
        in the iteration order of ``metrics_cfg``.

    Raises:
        TypeError: If a configuration does not instantiate to a ``Metric``.
    """
    metrics = {}
    for name, cfg in metrics_cfg.items():
        metric = hydra.utils.instantiate(cfg)
        if not isinstance(metric, Metric):
            raise TypeError(f"{name} must be a Metric instance")
        ranked_logger.info(f"Adding metric {name} to the validation metrics...")
        metrics[name] = metric
    return MetricManager(metrics, raise_errors=raise_errors)
33
+
34
+
35
class MetricInputError(Exception):
    """Error signalling that a metric could not be computed from its inputs.

    ``MetricManager`` lists this type among the exceptions handled by its per-metric error
    context, so metrics may raise it to report a failure.
    """

    pass
37
+
38
+
39
class MetricManager:
    """Manages and computes a set of named Metrics, each a subclass of :class:`Metric`.

    For model validation, additional metrics can be added through the Hydra configuration; they
    will be computed by the ``__call__`` method automatically.

    For example, during AF-3 validation, Metrics receive ``network_input``, ``network_output``,
    ``extra_info``, ``ground_truth_atom_array_stack``, and ``predicted_atom_array_stack`` as
    input arguments.

    Example:
        >>> class ExampleMetric(Metric):
        ...     @cached_property
        ...     def kwargs_to_compute_args(self):
        ...         return {"x": "x", "y": "y", "extra_info": "extra_info"}
        ...
        ...     def compute(self, x, y, extra_info):
        ...         return {"value": x + y}
        >>> metric = ExampleMetric()
        >>> manager = MetricManager({"my_metric": metric}, raise_errors=True)
        >>> manager(x=1, y=2, extra_info={"example_id": "123"})
        {'example_id': '123', 'my_metric.value': 3}
    """

    def __init__(
        self,
        metrics: dict[str, "Metric"] | None = None,
        *,
        raise_errors: bool = True,
    ):
        """Initialize the MetricManager with a set of metrics.

        Args:
            metrics: Mapping from metric name to ``Metric`` instance. ``None`` (the default)
                creates an empty manager.
            raise_errors: Whether to re-raise errors when a metric fails to compute. When
                ``False``, failures are handed to ``atomworks.ml.utils.error.context`` with
                ``raise_error=False`` (presumably logged and skipped — confirm).

        Raises:
            TypeError: If any value in ``metrics`` is not a ``Metric`` instance.
        """
        # ``None`` default instead of ``{}``: avoids a shared mutable default argument.
        if metrics is None:
            metrics = {}
        self.raise_errors = raise_errors
        self.metrics = {}
        for name, metric in metrics.items():
            # Explicit raise (not assert) so the check survives ``python -O`` and matches
            # ``from_metrics``.
            if not isinstance(metric, Metric):
                raise TypeError(f"{name} must be a Metric instance, not {type(metric)}")
            self.metrics[name] = metric

    @classmethod
    def instantiate_from_hydra(
        cls, metrics_cfg: dict[str, Any] | DictConfig
    ) -> "MetricManager":
        """Instantiate a MetricManager from a dictionary of metrics.

        Args:
            metrics_cfg: A dictionary where keys are metric names and values are Hydra
                configurations for the metrics.
        """
        return instantiate_metric_manager(metrics_cfg)

    @classmethod
    def from_metrics(
        cls,
        metrics: dict[str, "Metric"] | list[tuple[str, "Metric"]],
        raise_errors: bool = True,
    ) -> "MetricManager":
        """Create MetricManager from metric objects.

        Args:
            metrics: Either dict mapping names to Metric objects,
                or list of (name, Metric) tuples.
            raise_errors: Whether to raise errors on metric failures. Defaults to ``True``.

        Raises:
            TypeError: If any value is not a ``Metric`` instance.
        """
        if isinstance(metrics, list):
            # Convert list of tuples to dict
            metrics = dict(metrics)

        # Validate all are Metric instances
        for name, metric in metrics.items():
            if not isinstance(metric, Metric):
                raise TypeError(
                    f"Metric '{name}' must be a Metric instance, got {type(metric)}"
                )

        return cls(metrics, raise_errors=raise_errors)

    def __repr__(self) -> str:
        """Return a string representation of the MetricManager."""
        return f"MetricManager({', '.join(self.metrics.keys())})"

    @staticmethod
    def _metric_applies(metric: "Metric", metrics_tags: set) -> bool:
        """Return whether ``metric``'s tag requirements are satisfied by ``metrics_tags``."""
        if metric.required_tags_all and not metrics_tags.issuperset(
            metric.required_tags_all
        ):
            return False
        if metric.required_tags_any and not metrics_tags.intersection(
            metric.required_tags_any
        ):
            return False
        if metric.prohibited_tags and metrics_tags.intersection(metric.prohibited_tags):
            return False
        return True

    def __call__(self, **kwargs: Any) -> dict[str, Any]:
        """Compute every applicable metric and collect the results into one flat dictionary.

        Args:
            **kwargs: Inputs forwarded unchanged to each metric's ``compute_from_kwargs``.
                If ``kwargs["extra_info"]`` contains ``"example_id"``, it is reported in the
                results and in error messages. If it contains ``"metrics_tags"``, metrics whose
                tag requirements are not met are skipped. Without ``"metrics_tags"``, no
                filtering is applied.

        Returns:
            A dict with ``"example_id"`` (``None`` if absent) plus, for each computed metric
            ``name``, either ``"{name}.{key}"`` entries (dict results) or a single ``name``
            entry (list results).

        Raises:
            ValueError: If a metric name collides with an existing result key.
        """
        example_id = nested_dict.get(
            kwargs, key=("extra_info", "example_id"), default=None
        )
        results: dict[str, Any] = {"example_id": example_id}

        # Both of these are identical for every metric, so compute them once per call.
        example_msg = f" for example '{example_id}'" if example_id is not None else ""
        extra_info = kwargs.get("extra_info")
        metrics_tags = (
            set(extra_info["metrics_tags"])
            if extra_info is not None and "metrics_tags" in extra_info
            else None
        )

        for name, metric in self.metrics.items():
            if name in results:
                raise ValueError(f"Duplicate metric name: {name}")

            # Decide whether the Metric applies to this batch
            if metrics_tags is not None and not self._metric_applies(metric, metrics_tags):
                continue

            with error.context(
                msg=f"Computing '{name}' ({type(metric).__name__}){example_msg}",
                raise_error=self.raise_errors,
                exc_types=(MetricInputError, ValueError, TypeError, AttributeError),
            ):
                metric_result = metric.compute_from_kwargs(**kwargs)

                if isinstance(metric_result, dict):
                    # Prefix keys with the metric name so results from different metrics
                    # cannot collide. keymap applies the lambda immediately, so the loop
                    # variable is not late-bound.
                    results.update(keymap(lambda k: f"{name}.{k}", metric_result))
                elif isinstance(metric_result, list):
                    results[name] = metric_result
                else:
                    raise ValueError("Unexpected result type: expected dict or list.")

        return results
177
+
178
+
179
+ class Metric(ABC):
180
+ """Abstract base class for Foundry metrics.
181
+
182
+ Defines a framework for computing metrics based on arbitrary keyword arguments.
183
+ A Metric can specify which batches a `MetricManager` applies it to, based on `tags` in the input batch.
184
+
185
+ Args:
186
+ required_tags_all: A set of tags that must all be present in the input batch for this metric to be computed.
187
+ required_tags_any: A set of tags where at least one must be present in the input batch for this metric to be computed.
188
+ prohibited_tags: A set of tags that must not be present in the input batch for this metric to be computed.
189
+
190
+ To implement a new metric, subclass this class and implement the `compute` method, at a minimum.
191
+ """
192
+
193
+ def __init__(
194
+ self,
195
+ required_tags_all: list[str] | set[str] | None = None,
196
+ required_tags_any: list[str] | set[str] | None = None,
197
+ prohibited_tags: list[str] | set[str] | None = None,
198
+ *args: Any,
199
+ **kwargs: Any,
200
+ ) -> None:
201
+ # Set required and prohibited tags
202
+ self.required_tags_all = (
203
+ set(required_tags_all) if exists(required_tags_all) else set()
204
+ )
205
+ self.required_tags_any = (
206
+ set(required_tags_any) if exists(required_tags_any) else set()
207
+ )
208
+ self.prohibited_tags = (
209
+ set(prohibited_tags) if exists(prohibited_tags) else set()
210
+ )
211
+ required_tags = self.required_tags_all.union(self.required_tags_any)
212
+ if required_tags.intersection(self.prohibited_tags):
213
+ raise ValueError(
214
+ "Conflicting tags found: required tags and prohibited tags must be disjoint."
215
+ )
216
+
217
+ # Check that the 'keys' of the kwargs_to_compute_args are a subset of the 'compute' method signature
218
+ if self.kwargs_to_compute_args:
219
+ assert self.kwargs_to_compute_args.keys() <= self.required_compute_args, (
220
+ f"The keys of kwargs_to_compute_args must be a subset of the 'compute' method signature. "
221
+ f"{self.kwargs_to_compute_args.keys()} is not a subset of {self.required_compute_args}"
222
+ )
223
+
224
+ # Check that optional_kwargs are also in the kwargs_to_compute_args
225
+ if self.kwargs_to_compute_args and self.optional_kwargs:
226
+ assert self.optional_kwargs <= set(self.kwargs_to_compute_args.keys()), (
227
+ f"All optional_kwargs must be defined in kwargs_to_compute_args. "
228
+ f"{self.optional_kwargs} is not a subset of {set(self.kwargs_to_compute_args.keys())}"
229
+ )
230
+
231
+ @cached_property
232
+ def required_compute_args(self) -> frozenset[str]:
233
+ """Required input keys for this metric"""
234
+ return frozenset(inspect.signature(self.compute).parameters.keys())
235
+
236
+ @cached_property
237
+ def required_kwargs(self) -> frozenset[str]:
238
+ """Required input keys for this metric"""
239
+ if self.kwargs_to_compute_args is None:
240
+ return frozenset()
241
+
242
+ return frozenset(self.kwargs_to_compute_args.values())
243
+
244
+ def compute_from_kwargs(self, **kwargs: Any) -> dict[str, Any]:
245
+ """Run compute with an arbitrary dictionary of input keys and values.
246
+
247
+ The 'kwargs_to_compute_args' property here will determine
248
+ where in the kwargs we will look for the values to pass to the compute method.
249
+
250
+ Parameters marked in 'optional_kwargs' will only be passed if present in kwargs.
251
+ """
252
+ if self.kwargs_to_compute_args:
253
+ compute_inputs = {}
254
+ for compute_arg, kwargs_key in self.kwargs_to_compute_args.items():
255
+ if compute_arg in self.optional_kwargs:
256
+ # Optional parameter - only add if present
257
+ try:
258
+ compute_inputs[compute_arg] = nested_dict.getitem(
259
+ kwargs, key=kwargs_key
260
+ )
261
+ except KeyError:
262
+ pass # Don't pass this parameter to compute()
263
+ else:
264
+ # Required parameter - use getitem (will raise if missing)
265
+ compute_inputs[compute_arg] = nested_dict.getitem(
266
+ kwargs, key=kwargs_key
267
+ )
268
+ else:
269
+ # If kwargs_to_compute_args is not defined, use kwargs directly
270
+ compute_inputs = kwargs
271
+ return self.compute(**compute_inputs)
272
+
273
+ @property
274
+ def kwargs_to_compute_args(self) -> dict[str, Any]:
275
+ """Map input keys to a flat dictionary.
276
+
277
+ If not implemented, we return None, and pass the kwargs directly to the compute method.
278
+
279
+ Override e.g. as:
280
+ ```python
281
+ @cached_property
282
+ def kwargs_to_compute_args(self) -> dict[str, Any]:
283
+ return {
284
+ "y_true": ("network_input", "coords_unnoised"),
285
+ "y_pred": ("network_output", "coords_pred"),
286
+ "extra_info": ("extra_info",),
287
+ }
288
+ ```
289
+ """
290
+ return None
291
+
292
+ @property
293
+ def optional_kwargs(self) -> frozenset[str]:
294
+ """Set of compute argument names that are optional.
295
+
296
+ Optional parameters will only be passed to compute() if present in kwargs.
297
+ The compute() method should have sensible defaults for these parameters.
298
+
299
+ Override e.g. as:
300
+ ```python
301
+ @property
302
+ def optional_kwargs(self) -> frozenset[str]:
303
+ return frozenset(["confidence_indices", "interfaces_to_score"])
304
+ ```
305
+ """
306
+ return frozenset()
307
+
308
+ @abstractmethod
309
+ def compute(self, **kwargs: Any) -> dict[str, Any] | list[dict[str, Any]]:
310
+ """Implement actual metric calculation here
311
+
312
+ Override e.g. as:
313
+ ```python
314
+ def compute(self, y_true, y_pred, extra_info):
315
+ print(extra_info)
316
+ return lddt(y_true, y_pred, thres=self.custom_thresholds)
317
+ ```
318
+ """
319
+ raise NotImplementedError
@@ -0,0 +1,47 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
# pi computed as 2 * acos(0) through torch. torch.zeros(1) uses torch's default dtype
# (float32 unless changed via torch.set_default_dtype), so the result carries float32
# precision (about 3.14159274) rather than math.pi's float64 value. Switching to math.pi
# would slightly change FourierEmbedding outputs.
pi = torch.acos(torch.zeros(1)).item() * 2
5
+
6
+
7
class FourierEmbedding(nn.Module):
    """Random Fourier-feature embedding of scalar inputs.

    Each scalar in ``t`` is mapped to ``c`` features ``cos(2 * pi * (t * w + b))``. ``w`` and
    ``b`` are length-``c`` float32 buffers drawn from a standard normal. Being buffers, they
    appear in the state dict but not in ``parameters()``.

    Args:
        c: Number of output features.
    """

    def __init__(self, c):
        super().__init__()
        self.c = c
        for buffer_name in ("w", "b"):
            self.register_buffer(buffer_name, torch.zeros(c, dtype=torch.float32))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Resample ``w`` and then ``b`` in place from N(0, 1)."""
        for buf in (self.w, self.b):
            nn.init.normal_(buf)

    def forward(self, t):
        """Embed ``t`` along a new trailing axis of size ``c``.

        Args:
            t: Tensor of scalars (originally annotated ``[D]``); any leading shape broadcasts.

        Returns:
            Tensor of shape ``t.shape + (c,)``.
        """
        phase = t[..., None] * self.w + self.b
        return torch.cos(2 * pi * phase)
25
+
26
+
27
class Dropout(nn.Module):
    """Inverted dropout whose mask can be shared along one dimension.

    During training, each mask entry is kept with probability ``1 - p_drop`` and kept entries
    are rescaled by ``1 / (1 - p_drop)``. In eval mode the input is returned unchanged.

    Args:
        broadcast_dim: If not ``None``, the mask has size 1 along this dimension, so one
            keep/drop decision is shared across it. For a 2-D input this drops entire rows or
            columns. ``None`` samples an independent decision per element.
        p_drop: Probability of dropping an entry (or a whole slice when ``broadcast_dim`` is
            set).
    """

    def __init__(self, broadcast_dim=None, p_drop=0.15):
        super().__init__()
        # give ones with probability of 1-p_drop / zeros with p_drop
        self.sampler = torch.distributions.bernoulli.Bernoulli(
            torch.tensor([1 - p_drop])
        )
        self.broadcast_dim = broadcast_dim
        self.p_drop = p_drop

    def forward(self, x):
        """Apply dropout in training mode; return ``x`` unchanged in eval mode.

        Returns:
            A tensor with the same shape, device and dtype as ``x``.
        """
        if not self.training:  # no drophead during evaluation mode
            return x
        if self.p_drop >= 1.0:
            # Everything is dropped; skip the 1 / (1 - p_drop) rescale, which would divide by
            # zero and turn the zeroed entries into NaN.
            return torch.zeros_like(x)
        shape = list(x.shape)
        if self.broadcast_dim is not None:
            shape[self.broadcast_dim] = 1
        # The sampler's probs tensor has shape [1], so samples come back as shape + [1];
        # view() drops that trailing axis. Samples are drawn on the device of the sampler's
        # probs (CPU by default) and then moved to x's device. Casting to x.dtype prevents
        # bf16/fp16 inputs from being promoted to float32 by the multiply.
        mask = self.sampler.sample(shape).view(shape).to(device=x.device, dtype=x.dtype)
        return mask * x / (1.0 - self.p_drop)
@@ -0,0 +1,6 @@
1
+ """Shared testing utilities for foundry."""
2
+
3
+ from foundry.testing.fixtures import get_test_data_dir, gpu
4
+ from foundry.testing.pytest_hooks import configure_pytest
5
+
6
# Public API of foundry.testing: re-exports from the fixtures and pytest_hooks submodules.
__all__ = ["configure_pytest", "get_test_data_dir", "gpu"]
@@ -0,0 +1,19 @@
1
+ """Shared pytest fixtures for foundry tests."""
2
+
3
+ from pathlib import Path
4
+
5
+ import pytest
6
+ import torch
7
+
8
+
9
@pytest.fixture(scope="session")
def gpu():
    """Session-scoped fixture that skips the requesting test when CUDA is unavailable.

    Returns:
        ``True`` when a CUDA device is available; otherwise the test is skipped.
    """
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        return True
    # pytest.skip raises, so nothing after this line runs.
    pytest.skip("GPU not available")
15
+
16
+
17
def get_test_data_dir(conftest_file: str) -> Path:
    """Return the ``data`` directory located next to ``conftest_file``.

    ``conftest_file`` is resolved to an absolute path (symlinks followed) before its parent
    directory is taken.
    """
    conftest_dir = Path(conftest_file).resolve().parent
    return conftest_dir.joinpath("data")
@@ -0,0 +1,15 @@
1
+ """Shared pytest configuration hooks for foundry tests."""
2
+
3
+ import rootutils
4
+ from dotenv import load_dotenv
5
+
6
+
7
def configure_pytest(config, conftest_file: str) -> None:
    """Prepare the pytest session for foundry tests.

    Sets up the project root found from ``conftest_file`` (marked by a ``.project-root``
    file) with ``pythonpath=True``, then loads environment variables from a ``.env`` file
    with ``override=True``.

    Args:
        config: Pytest config object; not used by this function (presumably passed through
            from a ``pytest_configure`` hook — confirm).
        conftest_file: Path of the calling ``conftest.py``.
    """
    root_indicator = ".project-root"
    # Locate the project root relative to the calling conftest file.
    rootutils.setup_root(conftest_file, indicator=root_indicator, pythonpath=True)
    # override=True lets .env values replace variables already set in the environment.
    load_dotenv(override=True)