rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
foundry/training/EMA.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from collections import OrderedDict
|
|
2
|
+
from copy import deepcopy
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class EMA(nn.Module):
    """Wrap a model together with an exponentially averaged copy of it.

    The wrapper holds two sub-modules: ``model`` (the network being trained)
    and ``shadow`` (a deep copy whose weights trail the trained weights).
    Calling the wrapper dispatches to ``model`` in training mode and to
    ``shadow`` in evaluation mode.
    """

    # TODO: Rename shadow to `ema_model` to better match convention
    def __init__(self, model: nn.Module, decay: float):
        """Build the EMA wrapper around ``model``.

        Args:
            model: The original model. It is stored by reference, not copied.
            decay: The decay rate of the EMA. Each call to :meth:`update`
                applies, per trainable parameter:
                    shadow_variable -= (1 - decay) * (shadow_variable - variable)
        """
        super().__init__()
        self.decay = decay

        self.model = model
        self.shadow = deepcopy(self.model)

        # The averaged copy is never optimised directly, so cut every one of
        # its parameters loose from the autograd graph.
        for shadow_param in self.shadow.parameters():
            shadow_param.detach_()

    @torch.no_grad()
    def update(self):
        """Move the shadow weights one EMA step towards the current model weights.

        Parameters with ``requires_grad=False`` are left untouched in the
        shadow; buffers are copied over verbatim rather than averaged.

        Raises:
            RuntimeError: If the wrapper is not in training mode.
        """
        if not self.training:
            raise RuntimeError("EMA update should only be called during training")

        # ... collect the parameters of both copies by name
        live_params = OrderedDict(self.model.named_parameters())
        averaged_params = OrderedDict(self.shadow.named_parameters())

        # ... both copies must expose exactly the same parameter names
        assert live_params.keys() == averaged_params.keys()

        # Fraction of the gap (shadow - model) that is closed per update.
        # Reference: https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
        step_fraction = 1.0 - self.decay
        for key, live in live_params.items():
            if not live.requires_grad:
                continue
            averaged = averaged_params[key]
            averaged.sub_(step_fraction * (averaged - live))

        # ... buffers (module state that is not a trainable parameter) are
        # mirrored directly from the model into the shadow
        live_buffers = OrderedDict(self.model.named_buffers())
        averaged_buffers = OrderedDict(self.shadow.named_buffers())

        assert live_buffers.keys() == averaged_buffers.keys()

        for key, live_buffer in live_buffers.items():
            averaged_buffers[key].copy_(live_buffer)

    def forward(self, *args, **kwargs):
        """Run the trained model in training mode, the shadow model otherwise."""
        active = self.model if self.training else self.shadow
        return active(*args, **kwargs)
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""Utilities for gradient checkpointing.
|
|
2
|
+
|
|
3
|
+
References:
|
|
4
|
+
* `PyTorch Checkpoint Documentation`_
|
|
5
|
+
|
|
6
|
+
.. _PyTorch Checkpoint Documentation: https://pytorch.org/docs/stable/checkpoint.html
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
from torch.utils.checkpoint import checkpoint
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def create_custom_forward(module, **kwargs):
    """Bind keyword arguments to ``module`` and expose a positional-only callable.

    Useful together with PyTorch's ``checkpoint`` helper when the checkpointed
    callable should only receive positional inputs: the keyword arguments are
    captured here once and re-applied on every call.

    Args:
        module: The callable (typically a nn.Module) to wrap.
        **kwargs: Keyword arguments forwarded to ``module`` on every call.

    Returns:
        A callable taking only positional arguments; it invokes ``module``
        with those arguments plus the captured keyword arguments.
    """

    def positional_only_forward(*positional):
        return module(*positional, **kwargs)

    return positional_only_forward
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def activation_checkpointing(function):
    """Decorator that runs ``function`` under gradient checkpointing when gradients are enabled.

    When ``torch.is_grad_enabled()`` is true at call time, the call is routed
    through ``torch.utils.checkpoint.checkpoint``; otherwise ``function`` is
    called directly, with no checkpointing overhead.

    Args:
        function: The function to apply gradient checkpointing to.

    Returns:
        Wrapped function that conditionally applies checkpointing based on
        gradient state. The wrapper keeps the name, docstring and other
        metadata of ``function``.

    Examples:
        Apply to a forward pass method::

            @activation_checkpointing
            def forward(self, x, mask=None):
                return self.layer(x, mask)

    Notes:
        Uses ``use_reentrant=False`` for compatibility with recent PyTorch versions.
    """
    # Local import keeps this block self-contained (no new module-level import).
    from functools import wraps

    @wraps(function)  # keep __name__/__doc__/signature of the decorated function
    def wrapper(*args, **kwargs):
        if torch.is_grad_enabled():
            # Keyword arguments are bound into a positional-only closure;
            # positional arguments are handed to `checkpoint` directly.
            return checkpoint(
                create_custom_forward(function, **kwargs), *args, use_reentrant=False
            )
        return function(*args, **kwargs)

    return wrapper
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from torch.optim.lr_scheduler import LRScheduler, _LRScheduler
|
|
4
|
+
from torch.optim.optimizer import Optimizer
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class AF3Scheduler(_LRScheduler):
    """Two-phase learning rate schedule in the style of AF-3.

    1. Warmup: the learning rate rises linearly from 0 to the base learning
       rate (1.8 · 10^−3 by default) over the first 1,000 steps.
    2. Decay: afterwards the learning rate is multiplied by 0.95 every
       50,000 steps.

    From the AF-3 Supplement, Section 5.4:
    > "For training we use the Adam optimizer with parameters β1 = 0.9, β2 = 0.95, ϵ = 10^−8. The base learning rate
    is 1.8 · 10^−3, which is linearly increased from 0 over the first 1,000 steps. The learning rate is then decreased
    by a factor of 0.95 every 5 · 10^4 steps."

    References:
        - AF-3 Supplement
    """

    def __init__(
        self,
        optimizer: Optimizer,
        base_lr: float = 1.8e-3,
        warmup_steps: int = 1000,
        decay_factor: float = 0.95,
        decay_steps: int = 50000,
        last_epoch: int = -1,
    ) -> None:
        """Create an AF3Scheduler.

        ``last_epoch`` advances by one on every ``scheduler.step()`` call, so
        it counts scheduler steps; the name "epoch" only follows the PyTorch
        convention.

        Args:
            optimizer (Optimizer): Wrapped optimizer.
            base_lr (float): The base learning rate after warmup (which will then be decayed).
            warmup_steps (int): Number of steps for linear warmup.
            decay_factor (float): Factor by which the learning rate is multiplied every decay_steps.
            decay_steps (int): Number of steps between each decay.
            last_epoch (int): The index of the last epoch. Default: -1.
        """
        # Schedule hyper-parameters are stored before the base constructor
        # runs, since `get_lr` reads them.
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.decay_factor = decay_factor
        self.decay_steps = decay_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> list[float]:
        """Return the current learning rate, repeated once per optimizer param group."""
        step = self.last_epoch
        if step < self.warmup_steps:
            # Linear warmup
            lr = self.base_lr * (step / self.warmup_steps)
        else:
            # Decay after warmup: one multiplication per completed decay period
            num_decays = (step - self.warmup_steps) // self.decay_steps
            lr = self.base_lr * (self.decay_factor**num_decays)
        return [lr for _ in self.optimizer.param_groups]
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@dataclass
class SchedulerConfig:
    """Bundle a learning rate scheduler with the cadence at which it is stepped.

    Modeled on the PyTorch Lightning scheduler configuration.

    Attributes:
        scheduler (LRScheduler): The learning rate scheduler instance. Must inherit from `torch.optim.lr_scheduler.LRScheduler`.
        interval (str): The interval at which to apply the scheduler, typically "epoch" or "step". Defaults to "step".
        frequency (int): The frequency of applying the scheduler. For example, a frequency of 1 means the scheduler is applied every epoch. Defaults to 1.
    """

    scheduler: LRScheduler = None
    interval: str = "step"
    frequency: int = 1

    def state_dict(self) -> dict:
        """Return the scheduler's own state dict together with ``interval`` and ``frequency``."""
        snapshot = dict(
            scheduler=self.scheduler.state_dict(),
            interval=self.interval,
            frequency=self.frequency,
        )
        return snapshot

    def load_state_dict(self, state_dict: dict) -> None:
        """Restore the wrapped scheduler, ``interval`` and ``frequency`` from ``state_dict``."""
        # The wrapped scheduler is restored first; the cadence fields follow.
        self.scheduler.load_state_dict(state_dict["scheduler"])
        self.interval, self.frequency = (
            state_dict["interval"],
            state_dict["frequency"],
        )
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
logger = logging.getLogger(__name__)
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def weighted_rigid_align(
    X_L,  # [B, L, 3]
    X_gt_L,  # [B, L, 3]
    X_exists_L=None,  # [L]
    w_L=None,  # [B, L]
):
    """
    Weighted rigid body alignment of X_gt_L onto X_L with weights w_L
    Allows for "moving target" ground truth that is se3 invariant
    Following algorithm 28 in AF3 paper

    Only positions selected by `X_exists_L` contribute to the centroids and
    the rotation; the resulting transform is then applied to every position
    of `X_gt_L`.

    Args:
        X_L: Coordinates to align onto, [B, L, 3].
        X_gt_L: Coordinates to be moved, [B, L, 3].
        X_exists_L: Boolean mask [L] of positions used to fit the transform.
            Defaults to all positions.
        w_L: Per-position weights [B, L]. Defaults to uniform weights.

    Returns:
        X_align_L: [B, L, 3], the aligned copy of `X_gt_L`, detached from the
        autograd graph.

    Raises:
        AssertionError: If the shapes are inconsistent or `X_exists_L` is not
            a boolean tensor.
    """
    assert X_L.shape == X_gt_L.shape

    if X_exists_L is None:
        # Build the default mask on the same device as the coordinates.
        X_exists_L = torch.ones((X_L.shape[-2]), dtype=torch.bool, device=X_L.device)
    if w_L is None:
        w_L = torch.ones_like(X_L[..., 0])
    else:
        # The shape check lives here so that the documented default
        # (`w_L=None`) no longer fails on `None.shape`.
        assert X_L.shape[:-1] == w_L.shape
        w_L = w_L.to(torch.float32)

    # Assert `X_exists_L` is a boolean mask
    assert (
        X_exists_L.dtype == torch.bool
    ), "X_exists_L should be a boolean mask! Otherwise, the alignment will be incorrect (silent failure)!"

    # Restrict the fit to the positions that exist.
    X_resolved = X_L[:, X_exists_L]
    X_gt_resolved = X_gt_L[:, X_exists_L]
    w_resolved = w_L[:, X_exists_L]

    # Weighted centroids of both point sets, [B, 3].
    w_total = torch.sum(w_resolved, dim=-1, keepdim=True)
    u_X = torch.sum(X_resolved * w_resolved.unsqueeze(-1), dim=-2) / w_total
    u_X_gt = torch.sum(X_gt_resolved * w_resolved.unsqueeze(-1), dim=-2) / w_total

    X_resolved = X_resolved - u_X.unsqueeze(-2)
    X_gt_resolved = X_gt_resolved - u_X_gt.unsqueeze(-2)

    # Computation of the covariance matrix
    C = torch.einsum("bji,bjk->bik", w_resolved[..., None] * X_gt_resolved, X_resolved)

    # NOTE: `torch.linalg.svd` returns the third factor already transposed
    # (Vh), so `U @ V` below is the candidate rotation.
    U, S, V = torch.linalg.svd(C)

    R = U @ V
    B, _, _ = X_L.shape
    # Identity with the last diagonal entry replaced by sign(det(R)): flips
    # one axis when the SVD solution is a reflection. It takes the dtype of
    # `R` so that non-float32 inputs (e.g. float64) do not hit a dtype
    # mismatch in the matmul.
    F = torch.eye(3, 3, dtype=R.dtype, device=X_L.device)[None].tile(
        (
            B,
            1,
            1,
        )
    )

    F[..., -1, -1] = torch.sign(torch.linalg.det(R))
    R = U @ F @ V

    # Apply the fitted transform to all positions, not only the resolved ones.
    X_gt_L = X_gt_L - u_X_gt.unsqueeze(-2)
    X_align_L = X_gt_L @ R + u_X.unsqueeze(-2)

    return X_align_L.detach()
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def get_rmsd(xyz1, xyz2, eps=1e-4):
    """Root-mean-square deviation between two coordinate sets.

    The squared differences are summed over the last two axes, divided by the
    size of the second-to-last axis of `xyz1`, and `eps` is added under the
    square root.

    Args:
        xyz1: First coordinate tensor.
        xyz2: Second coordinate tensor, subtracted element-wise against `xyz1`.
        eps: Constant added before taking the square root.

    Returns:
        Tensor of RMSD values with the last two axes reduced away.
    """
    n_points = xyz1.shape[-2]
    delta = xyz2 - xyz1
    mean_sq_dev = torch.sum(delta * delta, axis=(-1, -2)) / n_points
    return torch.sqrt(mean_sq_dev + eps)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def superimpose(xyz1, xyz2, mask, eps=1e-4):
|
|
80
|
+
"""
|
|
81
|
+
Superimpose xyz1 onto xyz2 using mask
|
|
82
|
+
"""
|
|
83
|
+
L = xyz1.shape[-2]
|
|
84
|
+
assert mask.shape == (L,)
|
|
85
|
+
assert xyz1.shape == xyz2.shape
|
|
86
|
+
assert mask.dtype == torch.bool
|