rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from foundry.utils.rigid import rot_vec_mul
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def centre(X_L, X_exists_L):
    """Centre the observed atoms at their mean and zero out missing ones.

    The mean is taken over all coordinates selected by ``X_exists_L`` and
    subtracted from those coordinates; positions flagged as missing are set
    to zero. The input is not modified — a centred clone is returned.

    Args:
        X_L: atom coordinates, last dim 3 (e.g. [D, L, 3] — assumed from
            callers, confirm).
        X_exists_L: boolean mask selecting valid atoms.

    Returns:
        A new tensor with the same shape as ``X_L``.
    """
    centred = X_L.clone()
    observed = centred[X_exists_L]
    centred[X_exists_L] = observed - observed.mean(dim=-2, keepdim=True)
    centred[~X_exists_L] = 0.0
    return centred
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def get_random_augmentation(X_L, s_trans):
    """Apply an independent random rigid transform to each batch element.

    Each batch element receives its own random rotation and a Gaussian
    global translation with standard deviation ``s_trans``.

    Inputs:
        X_L [D, L, 3]: Batched atom coordinates
        s_trans (float): standard deviation of a global translation to be applied for each
            element in the batch
    """
    batch, _, _ = X_L.shape
    rotations = uniform_random_rotation((batch,)).to(X_L.device)
    translations = s_trans * torch.normal(mean=0, std=1, size=(batch, 1, 3)).to(X_L.device)
    return rot_vec_mul(rotations[:, None], X_L) + translations
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def centre_random_augmentation(X_L, X_exists_L, s_trans):
    """Centre the observed atoms, then apply a random rotation + translation.

    Convenience composition of :func:`centre` and
    :func:`get_random_augmentation`.
    """
    return get_random_augmentation(centre(X_L, X_exists_L), s_trans)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def uniform_random_rotation(size):
    """Sample a batch of random 3D rotation matrices from Euler angles.

    Angles about the X, Y and Z axes are each drawn uniformly from
    [0, 2*pi) and composed as ``Rz @ Ry @ Rx``. Construction is fully
    vectorized — no Python-level loop building one tiny tensor per batch
    element — so the same RNG draws produce the same matrices as before,
    just faster and also valid for an empty batch.

    NOTE(review): three independent uniform Euler angles do NOT yield a
    Haar-uniform (isotropic) distribution over SO(3), despite the name —
    confirm whether callers rely on true uniformity.

    Args:
        size: batch shape (tuple), e.g. ``(D,)``.

    Returns:
        Tensor of shape ``size + (3, 3)`` containing rotation matrices.
    """
    # Sample random angles for rotations around X, Y, and Z axes
    theta_x = torch.rand(size) * 2 * math.pi
    theta_y = torch.rand(size) * 2 * math.pi
    theta_z = torch.rand(size) * 2 * math.pi

    # Calculate the cosines and sines of the angles
    cos_x, sin_x = torch.cos(theta_x), torch.sin(theta_x)
    cos_y, sin_y = torch.cos(theta_y), torch.sin(theta_y)
    cos_z, sin_z = torch.cos(theta_z), torch.sin(theta_z)

    one = torch.ones_like(cos_x)
    zero = torch.zeros_like(cos_x)

    # Build the three elementary rotation matrices with a single stack each.
    rotation_x = torch.stack(
        [one, zero, zero, zero, cos_x, -sin_x, zero, sin_x, cos_x], dim=-1
    ).reshape(*theta_x.shape, 3, 3)
    rotation_y = torch.stack(
        [cos_y, zero, sin_y, zero, one, zero, -sin_y, zero, cos_y], dim=-1
    ).reshape(*theta_y.shape, 3, 3)
    rotation_z = torch.stack(
        [cos_z, -sin_z, zero, sin_z, cos_z, zero, zero, zero, one], dim=-1
    ).reshape(*theta_z.shape, 3, 3)

    # Combine the rotation matrices
    return torch.matmul(rotation_z, torch.matmul(rotation_y, rotation_x))
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
"""Squashfs filesystem mount management.
|
|
2
|
+
|
|
3
|
+
Provides a singleton manager for mounting and caching squashfs filesystems.
|
|
4
|
+
Mounts are lazy (happen on first access) and persist for the lifetime of the program.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import atexit
|
|
8
|
+
import logging
|
|
9
|
+
import os
|
|
10
|
+
import subprocess
|
|
11
|
+
import tempfile
|
|
12
|
+
import threading
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class SquashfsManager:
    """Singleton manager for squashfs filesystem mounts.

    Provides lazy mounting of squashfs files with automatic caching and cleanup.
    Thread-safe for use with multiprocessing/DataLoader workers.

    Examples:
        Get mount directory (mounts automatically if needed)::

            from foundry.utils.squashfs import SquashfsManager

            mount_dir = SquashfsManager.get_mount("/path/to/file.sqfs")
            file_path = os.path.join(mount_dir, "internal/path/to/file")

        Inspect active mounts::

            mounts = SquashfsManager.list_mounts()
            print(f"Active mounts: {len(mounts)}")

        Cleanup (typically for tests)::

            SquashfsManager.unmount_all()
    """

    # Class-level (shared) state: resolved .sqfs path -> mount directory.
    _mounts: dict[str, str] = {}
    # Guards _mounts and _cleanup_registered across threads.
    _lock = threading.Lock()
    # Whether the atexit hook has been installed (done lazily on first mount).
    _cleanup_registered = False

    @classmethod
    def get_mount(cls, sqfs_file: str) -> str:
        """Get mount directory for squashfs file, mounting if needed.

        Args:
            sqfs_file: Path to the .sqfs file.

        Returns:
            Path to the mount directory containing the squashfs contents.

        Raises:
            FileNotFoundError: If the sqfs file doesn't exist.
            RuntimeError: If mounting fails.
        """
        # Canonicalize so the same file reached via different paths shares one mount.
        sqfs_file = str(Path(sqfs_file).resolve())

        # Check file exists
        if not os.path.exists(sqfs_file):
            raise FileNotFoundError(f"Squashfs file not found: {sqfs_file}")

        with cls._lock:
            # Register cleanup on first use
            if not cls._cleanup_registered:
                atexit.register(cls._cleanup_on_exit)
                cls._cleanup_registered = True

            # Return cached mount if exists
            if sqfs_file in cls._mounts:
                return cls._mounts[sqfs_file]

            # Mount and cache
            # NOTE(review): the lock is held across the `squashfuse` subprocess
            # call, so concurrent first-mounts of different files serialize.
            mount_dir = cls._mount(sqfs_file)
            cls._mounts[sqfs_file] = mount_dir
            logger.info(f"Mounted {sqfs_file} at {mount_dir}")
            return mount_dir

    @classmethod
    def _mount(cls, sqfs_file: str) -> str:
        """Internal: Actually perform the squashfs mount.

        Args:
            sqfs_file: Path to the .sqfs file.

        Returns:
            Path to the mount directory.

        Raises:
            RuntimeError: If mounting fails.
        """
        # Fresh, empty mount point; removed again below if the mount fails.
        mount_dir = tempfile.mkdtemp(prefix="sqfs_")
        try:
            subprocess.run(
                ["squashfuse", sqfs_file, mount_dir],
                check=True,
                capture_output=True,
            )
            return mount_dir
        except subprocess.CalledProcessError as e:
            # Cleanup failed mount directory
            try:
                os.rmdir(mount_dir)
            except OSError:
                pass
            raise RuntimeError(
                f"Failed to mount {sqfs_file}: {e.stderr.decode()}"
            ) from e

    @classmethod
    def _unmount(cls, sqfs_file: str) -> None:
        """Internal: Unmount a squashfs filesystem.

        Args:
            sqfs_file: Path to the .sqfs file to unmount.
        """
        # Unknown path: nothing to do (callers may race with unmount_all).
        if sqfs_file not in cls._mounts:
            return

        mount_dir = cls._mounts[sqfs_file]
        try:
            subprocess.run(
                ["fusermount", "-u", mount_dir],
                check=True,
                capture_output=True,
            )
            os.rmdir(mount_dir)
            logger.info(f"Unmounted {sqfs_file}")
        except (subprocess.CalledProcessError, OSError) as e:
            logger.warning(f"Failed to unmount {sqfs_file}: {e}")
        finally:
            # The cache entry is dropped even when unmounting failed.
            # NOTE(review): a still-live FUSE mount would then be forgotten
            # and never retried; confirm this best-effort behavior is intended.
            del cls._mounts[sqfs_file]

    @classmethod
    def list_mounts(cls) -> dict[str, str]:
        """List all active squashfs mounts.

        Returns:
            Dictionary mapping sqfs file paths to mount directories.

        Examples:
            >>> mounts = SquashfsManager.list_mounts()
            >>> print(f"Active mounts: {len(mounts)}")
            >>> for sqfs, mount_dir in mounts.items():
            ...     print(f"{sqfs} -> {mount_dir}")
        """
        with cls._lock:
            # Return a copy so callers cannot mutate internal state.
            return cls._mounts.copy()

    @classmethod
    def unmount_all(cls) -> None:
        """Unmount all squashfs filesystems.

        Useful for cleanup in tests or explicit resource management.
        Normally not needed as cleanup happens automatically on program exit.

        Examples:
            >>> SquashfsManager.unmount_all()  # Clean up in test teardown
        """
        with cls._lock:
            # Snapshot keys first: _unmount deletes from _mounts as we go.
            sqfs_files = list(cls._mounts.keys())
            for sqfs_file in sqfs_files:
                cls._unmount(sqfs_file)

    @classmethod
    def _cleanup_on_exit(cls) -> None:
        """Cleanup hook called on program exit via atexit."""
        logger.debug("Cleaning up squashfs mounts on exit")
        cls.unmount_all()
|
foundry/utils/torch.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
1
|
+
"""General convenience utilities for PyTorch."""
|
|
2
|
+
|
|
3
|
+
__all__ = ["map_to", "assert_no_nans", "assert_shape", "assert_same_shape"]
|
|
4
|
+
|
|
5
|
+
import time
|
|
6
|
+
import warnings
|
|
7
|
+
from contextlib import contextmanager
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import torch
|
|
11
|
+
from beartype.typing import Any, Sequence
|
|
12
|
+
from toolz import valmap
|
|
13
|
+
from torch import Tensor
|
|
14
|
+
from torch._prims_common import DeviceLikeType
|
|
15
|
+
from torch.types import _dtype
|
|
16
|
+
|
|
17
|
+
from foundry import should_check_nans
|
|
18
|
+
from foundry.common import at_least_one_exists, do_nothing
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def map_to(
    x: Any,
    *,
    device: DeviceLikeType | None = None,
    dtype: _dtype | None = None,
    non_blocking: bool = False,
    **to_kwargs,
) -> Any:
    """
    Recursively applies the `.to()` method to all tensors in a nested structure.

    Dicts, lists and tuples are traversed; tensors found inside them are moved
    or cast via `.to()`, and every other value is returned unchanged.

    NOTE: If you are instantiating a new tensor, you should use the `device` and `dtype`
    arguments instead of calling `map_to()` on the tensor.
    (https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#create-tensors-directly-on-the-target-device)

    Args:
        - x (Any): The input structure: tensor, dict, list, tuple, or any other type.
        - device (DeviceLikeType | None): The target device (e.g. 'cpu', 'cuda').
        - dtype (_dtype | None): The target dtype to cast tensors to.
        - non_blocking (bool): Whether to use non-blocking transfers when possible.
        - **to_kwargs: Additional keyword arguments forwarded to `.to()`.

    Returns:
        - Any: The input structure with all contained tensors processed by `.to()`.

    Example:
        >>> data = {"tensor": torch.tensor([1, 2, 3]), "list": [torch.tensor([4, 5]), "string"]}
        >>> map_to(data, device="cuda", dtype=torch.float32)
        {'tensor': tensor([1., 2., 3.], device='cuda:0', dtype=torch.float32),
         'list': [tensor([4., 5.], device='cuda:0', dtype=torch.float32), 'string']}
    """
    # At least one conversion target must be given, otherwise `.to()` is a no-op.
    torch._assert(
        at_least_one_exists(device, dtype),
        "Must provide at least one of `device` or `dtype`",
    )

    if isinstance(x, Tensor):
        return x.to(device=device, dtype=dtype, non_blocking=non_blocking, **to_kwargs)
    if isinstance(x, dict):
        return {
            key: map_to(
                value,
                device=device,
                dtype=dtype,
                non_blocking=non_blocking,
                **to_kwargs,
            )
            for key, value in x.items()
        }
    if isinstance(x, (list, tuple)):
        return type(x)(
            map_to(
                item,
                device=device,
                dtype=dtype,
                non_blocking=non_blocking,
                **to_kwargs,
            )
            for item in x
        )
    # Anything else (str, int, None, ...) passes through untouched.
    return x
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _assert_no_nans(x: Any, *, msg: str = "", fail_if_not_tensor: bool = False) -> None:
|
|
88
|
+
"""Recursively checks for NaN values in tensor-like objects.
|
|
89
|
+
|
|
90
|
+
Args:
|
|
91
|
+
- x (Any): Input to check for NaNs. Can be a tensor, dict, list, tuple, or other type.
|
|
92
|
+
- msg (str): Prefix for error messages.
|
|
93
|
+
- fail_if_not_tensor (bool): If True, raises error for non-tensor types.
|
|
94
|
+
"""
|
|
95
|
+
if isinstance(x, Tensor):
|
|
96
|
+
torch._assert(
|
|
97
|
+
not torch.isnan(x).any(),
|
|
98
|
+
": ".join(filter(bool, [msg, "Tensor contains NaNs!"])),
|
|
99
|
+
)
|
|
100
|
+
elif isinstance(x, np.ndarray):
|
|
101
|
+
torch._assert(
|
|
102
|
+
not np.isnan(x).any(),
|
|
103
|
+
": ".join(filter(bool, [msg, "Numpy array contains NaNs!"])),
|
|
104
|
+
)
|
|
105
|
+
elif isinstance(x, float):
|
|
106
|
+
torch._assert(
|
|
107
|
+
not np.isnan(x),
|
|
108
|
+
": ".join(filter(bool, [msg, "float is NaN!"])),
|
|
109
|
+
)
|
|
110
|
+
elif isinstance(x, dict):
|
|
111
|
+
for k, v in x.items():
|
|
112
|
+
_assert_no_nans(
|
|
113
|
+
v,
|
|
114
|
+
msg=".".join(filter(bool, [msg, k])),
|
|
115
|
+
fail_if_not_tensor=fail_if_not_tensor,
|
|
116
|
+
)
|
|
117
|
+
elif isinstance(x, (list, tuple)):
|
|
118
|
+
for idx, v in enumerate(x):
|
|
119
|
+
_assert_no_nans(
|
|
120
|
+
v,
|
|
121
|
+
msg=".".join(filter(bool, [msg, str(idx)])),
|
|
122
|
+
fail_if_not_tensor=fail_if_not_tensor,
|
|
123
|
+
)
|
|
124
|
+
elif fail_if_not_tensor:
|
|
125
|
+
raise ValueError(f"Unsupported type: {type(x)}")
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
# Import-time toggle: when `should_check_nans` is falsy, NaN checking is
# replaced by a no-op so production code pays no traversal cost.
assert_no_nans = _assert_no_nans if should_check_nans else do_nothing
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
@contextmanager
|
|
132
|
+
def _suppress_tracer_warnings():
|
|
133
|
+
"""
|
|
134
|
+
Context manager to temporarily suppress known warnings in torch.jit.trace().
|
|
135
|
+
Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
|
|
136
|
+
|
|
137
|
+
References:
|
|
138
|
+
- https://github.com/NVlabs/edm2/blob/main/torch_utils/misc.py
|
|
139
|
+
"""
|
|
140
|
+
tracer_warning_filter = ("ignore", None, torch.jit.TracerWarning, None, 0)
|
|
141
|
+
warnings.filters.insert(0, tracer_warning_filter)
|
|
142
|
+
yield
|
|
143
|
+
warnings.filters.remove(tracer_warning_filter)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def assert_shape(tensor: Tensor, ref_shape: Sequence[int | None]):
|
|
147
|
+
"""
|
|
148
|
+
Assert that the shape of a tensor matches the given list of integers.
|
|
149
|
+
None indicates that the size of a dimension is allowed to vary.
|
|
150
|
+
Performs symbolic assertion when used in torch.jit.trace().
|
|
151
|
+
|
|
152
|
+
Args:
|
|
153
|
+
- tensor (Tensor): The tensor to check the shape of.
|
|
154
|
+
- ref_shape (Sequence[int | None]): The expected shape of the tensor.
|
|
155
|
+
|
|
156
|
+
References:
|
|
157
|
+
- https://github.com/NVlabs/edm2/blob/main/torch_utils/misc.py
|
|
158
|
+
"""
|
|
159
|
+
|
|
160
|
+
if tensor.ndim != len(ref_shape):
|
|
161
|
+
raise AssertionError(
|
|
162
|
+
f"Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}"
|
|
163
|
+
)
|
|
164
|
+
|
|
165
|
+
for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
|
|
166
|
+
if tensor.ndim != len(ref_shape):
|
|
167
|
+
raise AssertionError(
|
|
168
|
+
f"Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}"
|
|
169
|
+
)
|
|
170
|
+
|
|
171
|
+
for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
|
|
172
|
+
if ref_size is None:
|
|
173
|
+
pass
|
|
174
|
+
elif isinstance(ref_size, torch.Tensor):
|
|
175
|
+
with (
|
|
176
|
+
_suppress_tracer_warnings()
|
|
177
|
+
): # as_tensor results are registered as constants
|
|
178
|
+
torch._assert(
|
|
179
|
+
torch.equal(torch.as_tensor(size), ref_size),
|
|
180
|
+
f"Wrong size for dimension {idx}",
|
|
181
|
+
)
|
|
182
|
+
elif isinstance(size, torch.Tensor):
|
|
183
|
+
with (
|
|
184
|
+
_suppress_tracer_warnings()
|
|
185
|
+
): # as_tensor results are registered as constants
|
|
186
|
+
torch._assert(
|
|
187
|
+
torch.equal(size, torch.as_tensor(ref_size)),
|
|
188
|
+
f"Wrong size for dimension {idx}: expected {ref_size}",
|
|
189
|
+
)
|
|
190
|
+
elif size != ref_size:
|
|
191
|
+
raise AssertionError(
|
|
192
|
+
f"Wrong size for dimension {idx}: got {size}, expected {ref_size}"
|
|
193
|
+
)
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def assert_same_shape(tensor: Tensor, ref_tensor: Tensor) -> None:
    """Assert that `tensor` has exactly the shape of `ref_tensor`.

    Thin wrapper around :func:`assert_shape` using the reference tensor's
    shape as the expected shape (no wildcard dimensions).
    """
    assert_shape(tensor, ref_tensor.shape)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def device_of(obj: Any) -> torch.device:
    """Get the device of a PyTorch object, e.g. a `nn.Module` or a `Tensor`.

    Tensors (anything exposing `.device`) are handled directly; modules are
    resolved via their first parameter.

    Raises:
        ValueError: if the object exposes neither `.device` nor `.parameters`.
    """
    if hasattr(obj, "device"):
        return obj.device
    if hasattr(obj, "parameters"):
        return next(obj.parameters()).device
    raise ValueError(f"Unsupported type: {type(obj)}")
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
class Timer:
    """
    An accumulating stopwatch for measuring elapsed wall-clock time.

    Supports start/stop/reset cycles; `elapsed()` reports the accumulated
    time. When `use_barrier` is set, CUDA or MPS synchronization is issued
    around each measurement so GPU work is fully accounted for.

    Attributes:
        name_ (str): The name of the timer.
        elapsed_ (float): The total accumulated time in seconds.
        started_ (bool): Whether a timing session is currently running.
        start_time (float): Wall-clock start of the current session.
        use_barrier (bool): Whether to synchronize CUDA/MPS before timestamps.

    Args:
        name (str): The name of the timer.
        use_barrier (bool, optional): Whether to use synchronization barriers. Defaults to True.
    """

    def __init__(self, name, use_barrier: bool = True):
        self.name_ = name
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = time.time()
        self.use_barrier = use_barrier

    def _synchronize(self) -> None:
        """Wait for pending GPU work so timestamps reflect finished work."""
        if not self.use_barrier:
            return
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elif torch.backends.mps.is_available():
            torch.mps.synchronize()

    def start(self) -> None:
        """Start the timer."""
        assert not self.started_, f"timer {self.name_} has already been started"
        self._synchronize()
        self.start_time = time.time()
        self.started_ = True

    def stop(self) -> None:
        """Stop the timer."""
        assert self.started_, f"timer {self.name_} is not started"
        self._synchronize()
        self.elapsed_ += time.time() - self.start_time
        self.started_ = False

    def reset(self) -> None:
        """Reset timer."""
        self.elapsed_ = 0.0
        self.started_ = False

    def elapsed(self, reset: bool = True) -> float:
        """Calculate the elapsed time."""
        was_running = self.started_
        # A running session is closed first so its time is included.
        if was_running:
            self.stop()
        total = self.elapsed_
        if reset:
            self.reset()
        # Resume timing if the caller sampled mid-session.
        if was_running:
            self.start()
        return total
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
class Timers:
    """
    A registry of named Timer objects.

    Calling the instance with a name lazily creates (or fetches) a Timer;
    the remaining methods fan a start/stop/reset/elapsed operation out over
    any number of named timers.

    Attributes:
        timers (dict): A dictionary of Timer objects, keyed by their names.
    """

    def __init__(self):
        self.timers = {}

    def __call__(self, name, use_barrier: bool = True) -> Timer:
        """Get or create a Timer object.

        Note: `use_barrier` only takes effect when the timer is first created.
        """
        timer = self.timers.get(name)
        if timer is None:
            timer = Timer(name, use_barrier=use_barrier)
            self.timers[name] = timer
        return timer

    def start(self, *names) -> None:
        """Start the specified timers (creating them if needed)."""
        for timer_name in names:
            self(timer_name).start()

    def stop(self, *names) -> None:
        """Stop the specified timers."""
        for timer_name in names:
            self.timers[timer_name].stop()

    def reset(self, *names) -> None:
        """Reset the specified timers."""
        for timer_name in names:
            self.timers[timer_name].reset()

    def elapsed(self, *names, reset: bool = True) -> dict[str, float]:
        """Get the elapsed time for the specified timers."""
        return {
            timer_name: self.timers[timer_name].elapsed(reset=reset)
            for timer_name in names
        }
|