rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,65 @@
1
+ import math
2
+
3
+ import torch
4
+
5
+ from foundry.utils.rigid import rot_vec_mul
6
+
7
+
8
def centre(X_L, X_exists_L):
    """Shift coordinates so the observed atoms are centred at the origin.

    The mean is taken over all atoms selected by ``X_exists_L`` (boolean
    indexing flattens the selection to [N, 3], so the mean is over every
    selected atom), and atoms that do not exist are zeroed out.

    Args:
        X_L: Atom coordinates, last dimension of size 3.
        X_exists_L: Boolean mask over atoms that actually exist.

    Returns:
        A new tensor (input is not modified) centred on the observed atoms.
    """
    centred = X_L.clone()
    observed = centred[X_exists_L]
    centred[X_exists_L] = observed - observed.mean(dim=-2, keepdim=True)
    centred[~X_exists_L] = 0.0
    return centred
15
+
16
+
17
def get_random_augmentation(X_L, s_trans):
    """Apply a random global rotation and translation to each batch element.

    Args:
        X_L [D, L, 3]: Batched atom coordinates.
        s_trans (float): Standard deviation of the global translation applied
            to each element in the batch.

    Returns:
        The rotated and translated coordinates, shape [D, L, 3].
    """
    D, L, _ = X_L.shape
    # One rotation matrix per batch element, moved onto the input's device.
    rotations = uniform_random_rotation((D,)).to(X_L.device)
    # One global translation per batch element, broadcast over all L atoms.
    translations = s_trans * torch.normal(mean=0, std=1, size=(D, 1, 3)).to(X_L.device)
    rotated = rot_vec_mul(rotations[:, None], X_L)
    return rotated + translations
28
+
29
+
30
def centre_random_augmentation(X_L, X_exists_L, s_trans):
    """Centre the observed atoms, then apply a random rigid augmentation."""
    return get_random_augmentation(centre(X_L, X_exists_L), s_trans)
33
+
34
+
35
def uniform_random_rotation(size):
    """Sample a batch of random rotation matrices from independent Euler angles.

    Each rotation is composed as Rz @ Ry @ Rx, with the three angles drawn
    independently and uniformly from [0, 2*pi).

    NOTE(review): sampling Euler angles uniformly does NOT produce rotations
    uniform under the Haar measure on SO(3); this intentionally reproduces the
    original sampling scheme.

    Args:
        size: Batch shape (int or tuple) of rotations to sample.

    Returns:
        Tensor of shape ``(*size, 3, 3)`` containing rotation matrices.
    """
    # Sample random angles for rotations around the X, Y, and Z axes.
    theta_x = torch.rand(size) * 2 * math.pi
    theta_y = torch.rand(size) * 2 * math.pi
    theta_z = torch.rand(size) * 2 * math.pi

    # Calculate the cosines and sines of the angles.
    cos_x, sin_x = torch.cos(theta_x), torch.sin(theta_x)
    cos_y, sin_y = torch.cos(theta_y), torch.sin(theta_y)
    cos_z, sin_z = torch.cos(theta_z), torch.sin(theta_z)

    one = torch.ones_like(cos_x)
    zero = torch.zeros_like(cos_x)

    def _as_matrices(entries):
        # Stack 9 batched scalars into [*size, 3, 3] matrices. Fully
        # vectorized: the previous implementation built each matrix with a
        # Python-level loop of `torch.tensor(...)`, which was O(batch) in
        # interpreter time and only supported a 1-D `size`.
        return torch.stack(entries, dim=-1).reshape(*cos_x.shape, 3, 3)

    rotation_x = _as_matrices(
        [one, zero, zero, zero, cos_x, -sin_x, zero, sin_x, cos_x]
    )
    rotation_y = _as_matrices(
        [cos_y, zero, sin_y, zero, one, zero, -sin_y, zero, cos_y]
    )
    rotation_z = _as_matrices(
        [cos_z, -sin_z, zero, sin_z, cos_z, zero, zero, zero, one]
    )

    # Combine the rotation matrices: apply X first, then Y, then Z.
    return torch.matmul(rotation_z, torch.matmul(rotation_y, rotation_x))
@@ -0,0 +1,172 @@
1
+ """Squashfs filesystem mount management.
2
+
3
+ Provides a singleton manager for mounting and caching squashfs filesystems.
4
+ Mounts are lazy (happen on first access) and persist for the lifetime of the program.
5
+ """
6
+
7
+ import atexit
8
+ import logging
9
+ import os
10
+ import subprocess
11
+ import tempfile
12
+ import threading
13
+ from pathlib import Path
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
class SquashfsManager:
    """Singleton manager for squashfs filesystem mounts.

    Provides lazy mounting of squashfs files with automatic caching and cleanup.
    Thread-safe for use with multiprocessing/DataLoader workers.

    Examples:
        Get mount directory (mounts automatically if needed)::

            from foundry.utils.squashfs import SquashfsManager

            mount_dir = SquashfsManager.get_mount("/path/to/file.sqfs")
            file_path = os.path.join(mount_dir, "internal/path/to/file")

        Inspect active mounts::

            mounts = SquashfsManager.list_mounts()
            print(f"Active mounts: {len(mounts)}")

        Cleanup (typically for tests)::

            SquashfsManager.unmount_all()
    """

    # Maps resolved .sqfs paths to their mount directories; guarded by _lock.
    _mounts: dict[str, str] = {}
    _lock = threading.Lock()
    _cleanup_registered = False

    @classmethod
    def get_mount(cls, sqfs_file: str) -> str:
        """Get mount directory for squashfs file, mounting if needed.

        Args:
            sqfs_file: Path to the .sqfs file.

        Returns:
            Path to the mount directory containing the squashfs contents.

        Raises:
            FileNotFoundError: If the sqfs file doesn't exist.
            RuntimeError: If mounting fails.
        """
        # Resolve symlinks/relative paths so the cache key is canonical.
        sqfs_file = str(Path(sqfs_file).resolve())

        # Check file exists
        if not os.path.exists(sqfs_file):
            raise FileNotFoundError(f"Squashfs file not found: {sqfs_file}")

        with cls._lock:
            # Register cleanup on first use
            if not cls._cleanup_registered:
                atexit.register(cls._cleanup_on_exit)
                cls._cleanup_registered = True

            # Return cached mount if exists
            if sqfs_file in cls._mounts:
                return cls._mounts[sqfs_file]

            # Mount and cache. The lock is held across the mount so concurrent
            # callers never double-mount the same file.
            mount_dir = cls._mount(sqfs_file)
            cls._mounts[sqfs_file] = mount_dir
            logger.info(f"Mounted {sqfs_file} at {mount_dir}")
            return mount_dir

    @classmethod
    def _mount(cls, sqfs_file: str) -> str:
        """Internal: Actually perform the squashfs mount.

        Args:
            sqfs_file: Path to the .sqfs file.

        Returns:
            Path to the mount directory.

        Raises:
            RuntimeError: If mounting fails, including when the `squashfuse`
                executable is not installed / not on PATH.
        """
        mount_dir = tempfile.mkdtemp(prefix="sqfs_")
        try:
            subprocess.run(
                ["squashfuse", sqfs_file, mount_dir],
                check=True,
                capture_output=True,
            )
            return mount_dir
        except (subprocess.CalledProcessError, FileNotFoundError) as e:
            # Cleanup failed mount directory. Previously a missing
            # `squashfuse` binary (FileNotFoundError from subprocess.run)
            # escaped raw and leaked the temp directory, violating the
            # documented RuntimeError contract.
            try:
                os.rmdir(mount_dir)
            except OSError:
                pass
            if isinstance(e, subprocess.CalledProcessError):
                detail = e.stderr.decode()
            else:
                detail = "`squashfuse` executable not found on PATH"
            raise RuntimeError(f"Failed to mount {sqfs_file}: {detail}") from e

    @classmethod
    def _unmount(cls, sqfs_file: str) -> None:
        """Internal: Unmount a squashfs filesystem.

        Args:
            sqfs_file: Path to the .sqfs file to unmount.
        """
        if sqfs_file not in cls._mounts:
            return

        mount_dir = cls._mounts[sqfs_file]
        try:
            subprocess.run(
                ["fusermount", "-u", mount_dir],
                check=True,
                capture_output=True,
            )
            os.rmdir(mount_dir)
            logger.info(f"Unmounted {sqfs_file}")
        except (subprocess.CalledProcessError, OSError) as e:
            logger.warning(f"Failed to unmount {sqfs_file}: {e}")
        finally:
            # NOTE(review): the cache entry is dropped even when unmounting
            # fails, so a stale mount directory may be left behind in that
            # case — preserved from the original behavior.
            del cls._mounts[sqfs_file]

    @classmethod
    def list_mounts(cls) -> dict[str, str]:
        """List all active squashfs mounts.

        Returns:
            Dictionary mapping sqfs file paths to mount directories.

        Examples:
            >>> mounts = SquashfsManager.list_mounts()
            >>> print(f"Active mounts: {len(mounts)}")
            >>> for sqfs, mount_dir in mounts.items():
            ...     print(f"{sqfs} -> {mount_dir}")
        """
        with cls._lock:
            # Copy so callers cannot mutate internal state outside the lock.
            return cls._mounts.copy()

    @classmethod
    def unmount_all(cls) -> None:
        """Unmount all squashfs filesystems.

        Useful for cleanup in tests or explicit resource management.
        Normally not needed as cleanup happens automatically on program exit.

        Examples:
            >>> SquashfsManager.unmount_all()  # Clean up in test teardown
        """
        with cls._lock:
            # Snapshot keys: _unmount mutates the dict while we iterate.
            sqfs_files = list(cls._mounts.keys())
            for sqfs_file in sqfs_files:
                cls._unmount(sqfs_file)

    @classmethod
    def _cleanup_on_exit(cls) -> None:
        """Cleanup hook called on program exit via atexit."""
        logger.debug("Cleaning up squashfs mounts on exit")
        cls.unmount_all()
foundry/utils/torch.py ADDED
@@ -0,0 +1,317 @@
1
+ """General convenience utilities for PyTorch."""
2
+
3
+ __all__ = ["map_to", "assert_no_nans", "assert_shape", "assert_same_shape"]
4
+
5
+ import time
6
+ import warnings
7
+ from contextlib import contextmanager
8
+
9
+ import numpy as np
10
+ import torch
11
+ from beartype.typing import Any, Sequence
12
+ from toolz import valmap
13
+ from torch import Tensor
14
+ from torch._prims_common import DeviceLikeType
15
+ from torch.types import _dtype
16
+
17
+ from foundry import should_check_nans
18
+ from foundry.common import at_least_one_exists, do_nothing
19
+
20
+
21
+ def map_to(
22
+ x: Any,
23
+ *,
24
+ device: DeviceLikeType | None = None,
25
+ dtype: _dtype | None = None,
26
+ non_blocking: bool = False,
27
+ **to_kwargs,
28
+ ) -> Any:
29
+ """
30
+ Recursively applies the `.to()` method to all tensors in a nested structure.
31
+
32
+ This function handles nested structures such as dictionaries and lists, applying the `.to()` method
33
+ to any PyTorch tensors while leaving other types unchanged.
34
+
35
+ NOTE: If you are instantiating a new tensor, you should use the `device` and `dtype` arguments
36
+ instead of calling `map_to()` on the tensor.
37
+ (https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#create-tensors-directly-on-the-target-device)
38
+
39
+
40
+ Args:
41
+ - x (Any): The input structure, which can be a tensor, dictionary, list, or any other type.
42
+ - device (DeviceLikeType | None): The target device to move tensors to (e.g., 'cpu', 'cuda').
43
+ - dtype (_dtype | None): The target dtype to cast tensors to.
44
+ - non_blocking (bool): Whether to use non-blocking transfers when possible.
45
+ - **to_kwargs: Additional keyword arguments to pass to the `.to()` method.
46
+
47
+ Returns:
48
+ - Any: The input structure with all contained tensors processed by the `.to()` method.
49
+
50
+ Example:
51
+ >>> data = {"tensor": torch.tensor([1, 2, 3]), "list": [torch.tensor([4, 5]), "string"]}
52
+ >>> map_to(data, device="cuda", dtype=torch.float32)
53
+ {'tensor': tensor([1., 2., 3.], device='cuda:0', dtype=torch.float32),
54
+ 'list': [tensor([4., 5.], device='cuda:0', dtype=torch.float32), 'string']}
55
+ """
56
+ torch._assert(
57
+ at_least_one_exists(device, dtype),
58
+ "Must provide at least one of `device` or `dtype`",
59
+ )
60
+
61
+ if isinstance(x, dict):
62
+ return valmap(
63
+ lambda v: map_to(
64
+ v, device=device, dtype=dtype, non_blocking=non_blocking, **to_kwargs
65
+ ),
66
+ x,
67
+ )
68
+ elif isinstance(x, (list, tuple)):
69
+ return type(x)(
70
+ map(
71
+ lambda v: map_to(
72
+ v,
73
+ device=device,
74
+ dtype=dtype,
75
+ non_blocking=non_blocking,
76
+ **to_kwargs,
77
+ ),
78
+ x,
79
+ )
80
+ )
81
+ elif isinstance(x, Tensor):
82
+ return x.to(device=device, dtype=dtype, non_blocking=non_blocking, **to_kwargs)
83
+ else:
84
+ return x
85
+
86
+
87
+ def _assert_no_nans(x: Any, *, msg: str = "", fail_if_not_tensor: bool = False) -> None:
88
+ """Recursively checks for NaN values in tensor-like objects.
89
+
90
+ Args:
91
+ - x (Any): Input to check for NaNs. Can be a tensor, dict, list, tuple, or other type.
92
+ - msg (str): Prefix for error messages.
93
+ - fail_if_not_tensor (bool): If True, raises error for non-tensor types.
94
+ """
95
+ if isinstance(x, Tensor):
96
+ torch._assert(
97
+ not torch.isnan(x).any(),
98
+ ": ".join(filter(bool, [msg, "Tensor contains NaNs!"])),
99
+ )
100
+ elif isinstance(x, np.ndarray):
101
+ torch._assert(
102
+ not np.isnan(x).any(),
103
+ ": ".join(filter(bool, [msg, "Numpy array contains NaNs!"])),
104
+ )
105
+ elif isinstance(x, float):
106
+ torch._assert(
107
+ not np.isnan(x),
108
+ ": ".join(filter(bool, [msg, "float is NaN!"])),
109
+ )
110
+ elif isinstance(x, dict):
111
+ for k, v in x.items():
112
+ _assert_no_nans(
113
+ v,
114
+ msg=".".join(filter(bool, [msg, k])),
115
+ fail_if_not_tensor=fail_if_not_tensor,
116
+ )
117
+ elif isinstance(x, (list, tuple)):
118
+ for idx, v in enumerate(x):
119
+ _assert_no_nans(
120
+ v,
121
+ msg=".".join(filter(bool, [msg, str(idx)])),
122
+ fail_if_not_tensor=fail_if_not_tensor,
123
+ )
124
+ elif fail_if_not_tensor:
125
+ raise ValueError(f"Unsupported type: {type(x)}")
126
+
127
+
128
# Public alias resolved once at import time: when `should_check_nans` is
# falsy, NaN checking becomes a no-op (`do_nothing`), so callers pay no cost.
# Toggling the flag after this module is imported has no effect on the alias.
assert_no_nans = _assert_no_nans if should_check_nans else do_nothing
129
+
130
+
131
+ @contextmanager
132
+ def _suppress_tracer_warnings():
133
+ """
134
+ Context manager to temporarily suppress known warnings in torch.jit.trace().
135
+ Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
136
+
137
+ References:
138
+ - https://github.com/NVlabs/edm2/blob/main/torch_utils/misc.py
139
+ """
140
+ tracer_warning_filter = ("ignore", None, torch.jit.TracerWarning, None, 0)
141
+ warnings.filters.insert(0, tracer_warning_filter)
142
+ yield
143
+ warnings.filters.remove(tracer_warning_filter)
144
+
145
+
146
+ def assert_shape(tensor: Tensor, ref_shape: Sequence[int | None]):
147
+ """
148
+ Assert that the shape of a tensor matches the given list of integers.
149
+ None indicates that the size of a dimension is allowed to vary.
150
+ Performs symbolic assertion when used in torch.jit.trace().
151
+
152
+ Args:
153
+ - tensor (Tensor): The tensor to check the shape of.
154
+ - ref_shape (Sequence[int | None]): The expected shape of the tensor.
155
+
156
+ References:
157
+ - https://github.com/NVlabs/edm2/blob/main/torch_utils/misc.py
158
+ """
159
+
160
+ if tensor.ndim != len(ref_shape):
161
+ raise AssertionError(
162
+ f"Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}"
163
+ )
164
+
165
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
166
+ if tensor.ndim != len(ref_shape):
167
+ raise AssertionError(
168
+ f"Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}"
169
+ )
170
+
171
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
172
+ if ref_size is None:
173
+ pass
174
+ elif isinstance(ref_size, torch.Tensor):
175
+ with (
176
+ _suppress_tracer_warnings()
177
+ ): # as_tensor results are registered as constants
178
+ torch._assert(
179
+ torch.equal(torch.as_tensor(size), ref_size),
180
+ f"Wrong size for dimension {idx}",
181
+ )
182
+ elif isinstance(size, torch.Tensor):
183
+ with (
184
+ _suppress_tracer_warnings()
185
+ ): # as_tensor results are registered as constants
186
+ torch._assert(
187
+ torch.equal(size, torch.as_tensor(ref_size)),
188
+ f"Wrong size for dimension {idx}: expected {ref_size}",
189
+ )
190
+ elif size != ref_size:
191
+ raise AssertionError(
192
+ f"Wrong size for dimension {idx}: got {size}, expected {ref_size}"
193
+ )
194
+
195
+
196
def assert_same_shape(tensor: Tensor, ref_tensor: Tensor) -> None:
    """Assert that `tensor` has exactly the shape of `ref_tensor`."""
    assert_shape(tensor, tuple(ref_tensor.shape))
199
+
200
+
201
def device_of(obj: Any) -> torch.device:
    """Return the torch device an object lives on.

    Works for anything exposing a `.device` attribute (e.g. a `Tensor`) or a
    `.parameters()` iterator (e.g. an `nn.Module`, via its first parameter).

    Raises:
        ValueError: If the object exposes neither interface.
    """
    if hasattr(obj, "device"):
        return obj.device
    if hasattr(obj, "parameters"):
        # Modules have no single device attribute; use the first parameter's.
        first_param = next(obj.parameters())
        return first_param.device
    raise ValueError(f"Unsupported type: {type(obj)}")
209
+
210
+
211
class Timer:
    """
    A simple timer class for measuring elapsed time.

    This class provides functionality to start, stop, reset, and measure elapsed time.
    It can optionally use CUDA or MPS synchronization barriers for more accurate timing
    when working with GPU operations.

    Attributes:
        name_ (str): The name of the timer.
        elapsed_ (float): The total elapsed time in seconds.
        started_ (bool): Flag indicating if the timer is currently running.
        start_time (float): Start of the current session (from `time.perf_counter()`).
        use_barrier (bool): Whether to use CUDA or MPS synchronization barriers.

    Args:
        name (str): The name of the timer.
        use_barrier (bool, optional): Whether to use synchronization barriers. Defaults to True.
    """

    def __init__(self, name, use_barrier: bool = True):
        self.name_ = name
        self.elapsed_ = 0.0
        self.started_ = False
        # perf_counter is monotonic: intervals are immune to wall-clock
        # adjustments (NTP, DST) that could make time.time() deltas negative.
        self.start_time = time.perf_counter()
        self.use_barrier = use_barrier

    def _synchronize(self) -> None:
        """Wait for pending GPU work so timestamps reflect completed work."""
        if not self.use_barrier:
            return
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elif torch.backends.mps.is_available():
            torch.mps.synchronize()

    def start(self) -> None:
        """Start the timer."""
        assert not self.started_, f"timer {self.name_} has already been started"
        self._synchronize()
        self.start_time = time.perf_counter()
        self.started_ = True

    def stop(self) -> None:
        """Stop the timer and accumulate the current interval."""
        assert self.started_, f"timer {self.name_} is not started"
        self._synchronize()
        self.elapsed_ += time.perf_counter() - self.start_time
        self.started_ = False

    def reset(self) -> None:
        """Reset accumulated time and mark the timer as stopped."""
        self.elapsed_ = 0.0
        self.started_ = False

    def elapsed(self, reset: bool = True) -> float:
        """Return the accumulated elapsed time in seconds.

        Args:
            reset (bool): If True, zero the accumulator afterwards.

        Returns:
            float: The accumulated elapsed time.
        """
        was_running = self.started_
        # If timing is in progress, stop first so the open interval counts.
        if was_running:
            self.stop()
        total = self.elapsed_
        # Reset the elapsed time if requested.
        if reset:
            self.reset()
        # If timing was in progress on entry, resume it.
        if was_running:
            self.start()
        return total
278
+
279
+
280
class Timers:
    """
    A registry of named Timer objects.

    Creates timers on demand and forwards start/stop/reset/elapsed
    operations to them by name.

    Attributes:
        timers (dict): Maps timer names to their Timer instances.
    """

    def __init__(self):
        self.timers = {}

    def __call__(self, name, use_barrier: bool = True) -> Timer:
        """Return the timer registered under `name`, creating it if absent."""
        try:
            return self.timers[name]
        except KeyError:
            timer = Timer(name, use_barrier=use_barrier)
            self.timers[name] = timer
            return timer

    def start(self, *names) -> None:
        """Start every named timer (missing ones are created on the fly)."""
        for timer_name in names:
            self(timer_name).start()

    def stop(self, *names) -> None:
        """Stop every named timer; each must already exist."""
        for timer_name in names:
            self.timers[timer_name].stop()

    def reset(self, *names) -> None:
        """Reset every named timer; each must already exist."""
        for timer_name in names:
            self.timers[timer_name].reset()

    def elapsed(self, *names, reset: bool = True) -> dict[str, float]:
        """Return a {name: elapsed seconds} mapping for the named timers."""
        return {
            timer_name: self.timers[timer_name].elapsed(reset=reset)
            for timer_name in names
        }