rc_foundry-0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
foundry/__init__.py
ADDED
@@ -0,0 +1,57 @@
+import logging
+import os
+
+import torch
+from beartype.claw import beartype_this_package
+from environs import Env
+from jaxtyping import install_import_hook
+
+# Load environment variables from `.env` file
+_env = Env()
+_env.read_env()
+should_typecheck = _env.bool("TYPE_CHECK", default=False)
+should_debug = _env.bool("DEBUG", default=False)
+should_check_nans = _env.bool("NAN_CHECK", default=True)
+
+# Set up logger
+logger = logging.getLogger("foundry")
+# ... set logging level based on `DEBUG` environment variable
+logger.setLevel(logging.DEBUG if should_debug else logging.INFO)
+# ... log the current mode
+logger.debug("Debug mode: %s", should_debug)
+logger.debug("Type checking mode: %s", should_typecheck)
+logger.debug("NAN checking mode: %s", should_check_nans)
+
+# Enable runtime type checking if `TYPE_CHECK` environment variable is set to `True`
+if should_typecheck:
+    beartype_this_package()
+    install_import_hook("foundry", "beartype.beartype")
+
+# Global flag for cuEquivariance availability
+SHOULD_USE_CUEQUIVARIANCE = False
+
+try:
+    if torch.cuda.is_available():
+        if _env.bool("DISABLE_CUEQUIVARIANCE", default=False):
+            logger.info("cuEquivariance usage disabled via DISABLE_CUEQUIVARIANCE")
+        else:
+            import cuequivariance_torch as cuet  # noqa: I001, F401
+
+            SHOULD_USE_CUEQUIVARIANCE = True
+            os.environ["CUEQ_DISABLE_AOT_TUNING"] = _env.str(
+                "CUEQ_DISABLE_AOT_TUNING", default="1"
+            )
+            os.environ["CUEQ_DEFAULT_CONFIG"] = _env.str(
+                "CUEQ_DEFAULT_CONFIG", default="1"
+            )
+            logger.info("cuEquivariance is available and will be used.")
+
+except ImportError:
+    logger.debug("cuEquivariance unavailable: import failed")
+
+
+# Whether to disable checkpointing globally
+DISABLE_CHECKPOINTING = False
+
+# Export for easy access
+__all__ = ["SHOULD_USE_CUEQUIVARIANCE", "DISABLE_CHECKPOINTING"]
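The module's import-time behavior is driven entirely by environment variables read through `environs`, either from the process environment or a `.env` file. A minimal usage sketch (not part of the wheel) of toggling those flags before the first import, assuming the package is installed and importable as `foundry`; the variable names and defaults are the ones defined in the file above:

# Hypothetical usage sketch: the flags must be set before `import foundry`,
# because foundry/__init__.py reads them at import time.
import os

os.environ["DEBUG"] = "true"                # switch the "foundry" logger to DEBUG
os.environ["TYPE_CHECK"] = "true"           # enable beartype/jaxtyping runtime checks
os.environ["DISABLE_CUEQUIVARIANCE"] = "1"  # skip the cuequivariance_torch import

import foundry  # noqa: E402

print(foundry.SHOULD_USE_CUEQUIVARIANCE)    # False when disabled or unavailable
print(foundry.DISABLE_CHECKPOINTING)        # False by default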
foundry/callbacks/callback.py
ADDED
@@ -0,0 +1,116 @@
+from abc import ABC
+
+from beartype.typing import Any
+from lightning.fabric.wrappers import (
+    _FabricOptimizer,
+)
+
+
+class BaseCallback(ABC):
+    """Abstract base class used to build new callbacks.
+
+    Callbacks receive the trainer as the first argument to all hook methods, following
+    PyTorch Lightning's convention. This allows callbacks to access trainer.state,
+    trainer.fabric, etc.
+
+    NOTE: on_after_optimizer_step is called internally by Fabric and does NOT receive trainer.
+    Use on_before_optimizer_step for logic that requires trainer access.
+
+    Where possible, use names consistent with PyTorch Lightning's callback names (see references below).
+    Note that if using any callbacks directly within a Model, they must also adhere to this schema.
+
+    References:
+        - Pytorch Lightning Hooks (https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#hooks)
+        - Callbacks Flow (https://pytorch-lightning.readthedocs.io/en/0.10.0/callbacks.html#callbacks)
+    """
+
+    # Epoch loops
+    def on_fit_start(self, trainer: Any):
+        """Called at the start of the training"""
+        pass
+
+    def on_fit_end(self, trainer: Any):
+        """Called at the end of the training"""
+        pass
+
+    # Training loop
+    def on_train_epoch_start(self, trainer: Any):
+        """Called at the start of each training epoch"""
+        pass
+
+    def on_after_train_loader_iter(self, trainer: Any, **kwargs):
+        """Called after 'iter(train_loader)' is called, but before the first batch is yielded"""
+        pass
+
+    def on_before_train_loader_next(self, trainer: Any, **kwargs):
+        """Called after each batch is yielded from the train_loader 'next(train_iter)' call"""
+        pass
+
+    def on_train_batch_start(self, trainer: Any, batch: Any, batch_idx: int):
+        """Called at the start of each training batch"""
+        pass
+
+    def on_train_batch_end(
+        self, trainer: Any, outputs: Any, batch: Any, batch_idx: int
+    ):
+        """Called after each training batch, but before the optimizer.step"""
+        pass
+
+    def on_before_optimizer_step(self, trainer: Any, optimizer: _FabricOptimizer):
+        """Called before each optimizer.step"""
+        pass
+
+    def on_after_optimizer_step(self, optimizer: _FabricOptimizer, **kwargs):
+        """Called after each optimizer.step.
+
+        NOTE: This hook is called internally by Fabric when optimizer.step() executes.
+        Trainer is NOT available here. Use optimizer_step for logic requiring trainer.
+        """
+        pass
+
+    def optimizer_step(self, trainer: Any, optimizer: _FabricOptimizer):
+        """Called after optimizer.step completes. Unlike on_after_optimizer_step,
+        this hook is called explicitly by the trainer and receives trainer access.
+        """
+        pass
+
+    def on_train_epoch_end(self, trainer: Any):
+        """Called at the end of each training epoch"""
+        pass
+
+    # Validation loop
+    def on_validation_epoch_start(self, trainer: Any):
+        """Called at the start of each validation epoch"""
+        pass
+
+    def on_validation_batch_start(
+        self,
+        trainer: Any,
+        batch: Any,
+        batch_idx: int,
+        num_batches: int,
+        dataset_name: str | None = None,
+    ):
+        """Called at the start of each validation batch"""
+        pass
+
+    def on_validation_batch_end(
+        self,
+        trainer: Any,
+        outputs: Any,
+        batch: Any,
+        batch_idx: int,
+        num_batches: int,
+        dataset_name: str | None = None,
+    ):
+        """Called after each validation batch"""
+        pass
+
+    def on_validation_epoch_end(self, trainer: Any):
+        """Called at the end of each validation epoch"""
+        pass
+
+    # Saving and Loading
+    def on_save_checkpoint(self, trainer: Any, state: dict[str, Any]):
+        """Called when saving a checkpoint"""
+        pass
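Every hook on `BaseCallback` is a no-op, so a concrete callback only overrides the hooks it needs. Below is a minimal sketch of a subclass (not part of the wheel); the trainer attribute it touches, `trainer.state["global_step"]`, mirrors how the health-logging callback in the next file accesses the trainer and is an assumption about the FabricTrainer, not something defined in callback.py:

# Hypothetical example: log the training loss every N steps via the
# BaseCallback hook interface shown above.
from typing import Any

from foundry.callbacks.callback import BaseCallback


class LossPrinter(BaseCallback):
    def __init__(self, every_n_steps: int = 50):
        self.every_n_steps = every_n_steps

    def on_train_batch_end(self, trainer: Any, outputs: Any, batch: Any, batch_idx: int):
        # `outputs` is whatever the training step returned; a dict containing a
        # "loss" entry is assumed here purely for illustration.
        step = trainer.state["global_step"]
        if step % self.every_n_steps == 0 and isinstance(outputs, dict):
            print(f"step {step}: loss={outputs.get('loss')}")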
foundry/callbacks/health_logging.py
ADDED
@@ -0,0 +1,419 @@
+import gc
+from collections import defaultdict
+from typing import Any, types
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn as nn
+from jaxtyping import Float, Int
+from lightning.fabric.utilities.rank_zero import rank_zero_only
+from lightning.fabric.wrappers import (
+    _FabricOptimizer,
+)
+from torch import Tensor
+
+from foundry.callbacks.callback import BaseCallback
+
+_DEFAULT_STATISTICS = types.MappingProxyType(
+    {
+        "mean": torch.mean,
+        "std": torch.std,
+        "norm": torch.norm,
+        "max": torch.amax,
+        "min": torch.amin,
+    }
+)
+"""Summary statistics to log for gradients, weights, and activations."""
+
+_DEFAULT_HISTOGRAMS = types.MappingProxyType(
+    {
+        "activations": lambda x: np.histogram(
+            x.abs().to(torch.float32).cpu(), bins=40, range=(0, 10)
+        ),
+        "grads": lambda x: np.histogram(
+            x.abs().to(torch.float32).cpu(), bins=40, range=(0, 1)
+        ),
+        "weights": lambda x: np.histogram(
+            x.abs().to(torch.float32).cpu(), bins=40, range=(0, 1)
+        ),
+    }
+)
+"""Default histograms to log for activations, gradients, and weights."""
+
+
+class ActivationsGradientsWeightsTracker(BaseCallback):
+    """Fabric callback to track gradients, activations, and weights during training.
+
+    This callback logs gradient, weight, and activation statistics at specified intervals.
+    Integrates with FabricTrainer through the BaseCallback interface.
+
+    Args:
+        log_freq (int): Frequency of logging (every N steps). Defaults to 100.
+        log_grads (dict[str, callable]): Statistics to log for gradients. Defaults to _DEFAULT_STATISTICS.
+        log_weights (dict[str, callable]): Statistics to log for weights. Defaults to _DEFAULT_STATISTICS.
+        log_activations (dict[str, callable]): Statistics to log for activations. Defaults to _DEFAULT_STATISTICS.
+        log_histograms (dict[str, callable]): Histogram functions keyed by "grads"/"weights"/"activations". Defaults to _DEFAULT_HISTOGRAMS.
+        keep_cache (bool): Whether to keep a local cache of all logged stats. Defaults to False.
+        filter_grads (callable): Function (name) -> bool to filter gradient tracking. None means all.
+        filter_weights (callable): Function (name) -> bool to filter weight tracking. None means all.
+        filter_activations (callable): Function (name, module_type) -> bool to filter activation tracking. None means all.
+    """
+
+    def __init__(
+        self,
+        log_freq: int = 100,
+        log_grads: dict[str, callable] = _DEFAULT_STATISTICS,
+        log_weights: dict[str, callable] = _DEFAULT_STATISTICS,
+        log_activations: dict[str, callable] = _DEFAULT_STATISTICS,
+        log_histograms: dict[str, callable] = _DEFAULT_HISTOGRAMS,
+        keep_cache: bool = False,
+        filter_grads: callable = None,
+        filter_weights: callable = None,
+        filter_activations: callable = None,
+    ):
+        super().__init__()
+        self.log_freq = log_freq
+        self.log_grads = log_grads
+        self.log_weights = log_weights
+        self.log_activations = log_activations
+        self.log_histograms = log_histograms
+        self.keep_cache = keep_cache
+        self.filter_grads = filter_grads
+        self.filter_weights = filter_weights
+        self.filter_activations = filter_activations
+
+        self._hooks = []  # Store activation hooks for cleanup
+        self._temp_cache = {"scalars": {}, "histograms": {}}
+        self._cache = defaultdict(list)
+        if not self.keep_cache:
+            self.log_histograms = {}
+
+    @rank_zero_only
+    def on_fit_start(self, trainer):
+        """Initialize the callback and register activation hooks."""
+        # Check that we either have loggers attached or keep_cache is True, otherwise the
+        # data will be computed but not logged.
+        if not self.keep_cache and not trainer.fabric.loggers:
+            raise ValueError(
+                "TrainingHealthTracker requires loggers or keep_cache=True. "
+                "Otherwise the data will be computed but not logged."
+            )
+
+    @rank_zero_only
+    def on_train_batch_start(self, trainer, batch: Any, batch_idx: int):
+        step = trainer.state["global_step"]
+        model = trainer.state["model"]
+        if (self.log_activations or "activations" in self.log_histograms) and (
+            step % self.log_freq == 0
+        ):
+            self._register_activation_hooks(model, step)
+
+    @rank_zero_only
+    def on_before_optimizer_step(self, trainer, optimizer: _FabricOptimizer, **kwargs):
+        """Log gradients, weights, and activations before optimizer step."""
+        step = trainer.state["global_step"]
+
+        if step % self.log_freq == 0:
+            model = trainer.state["model"]
+
+            # Collect weight & gradient stats
+            _should_log_some_grads = self.log_grads or ("grads" in self.log_histograms)
+            _should_log_some_weights = self.log_weights or (
+                "weights" in self.log_histograms
+            )
+            if _should_log_some_grads or _should_log_some_weights:
+                self._collect_parameter_stats(model, step)
+
+            # Log all collected stats at once using trainer's fabric instance
+            if len(self._temp_cache["scalars"]) > 0 and trainer.fabric.loggers:
+                trainer.fabric.log_dict(
+                    self._temp_cache["scalars"],
+                    step=step,
+                )
+
+            if self.keep_cache:
+                self._cache["step"].append(torch.tensor(step))
+                for key, value in self._temp_cache["scalars"].items():
+                    self._cache[key].append(value)
+                for key, value in self._temp_cache["histograms"].items():
+                    if key.endswith("hist"):
+                        self._cache[key].append(value)
+
+    def on_train_batch_end(self, trainer, **kwargs):
+        """Called at the end of a training batch - clear temporary cache."""
+        self._temp_cache["scalars"].clear()
+        self._temp_cache["histograms"].clear()
+        self._remove_activation_hooks()
+
+    def on_fit_end(self, trainer, **kwargs):
+        """Clean up activation hooks at the end of training."""
+        self._remove_activation_hooks()
+
+    def on_validation_epoch_start(self, trainer):
+        # Temporarily remove any hooks for validation
+        self._remove_activation_hooks()
+
+    @rank_zero_only
+    def on_save_checkpoint(self, trainer, state: dict[str, Any]):
+        self._remove_activation_hooks()
+
+    def _collect_parameter_stats(self, model, step: int):
+        """Collect gradient and weight statistics in a single parameter iteration."""
+        cache = self._temp_cache  # alias
+
+        for name, param in model.named_parameters():
+            # Gradient stats
+            if (
+                (self.log_grads or "grads" in self.log_histograms)
+                and param.grad is not None
+                and self._should_track_grad(name)
+            ):
+                grad = param.grad.detach()
+                for stat_name, stat_fn in self.log_grads.items():
+                    cache["scalars"]["grads/" + name + "/" + stat_name] = stat_fn(grad)
+                if "grads" in self.log_histograms:
+                    counts, bin_edges = self.log_histograms["grads"](grad)
+                    cache["histograms"]["grads/" + name + "/hist"] = counts
+                    cache["histograms"]["grads/" + name + "/hist_bin_edges"] = bin_edges
+
+            # Weight stats
+            if (
+                self.log_weights or "weights" in self.log_histograms
+            ) and self._should_track_weight(name):
+                for stat_name, stat_fn in self.log_weights.items():
+                    cache["scalars"]["weights/" + name + "/" + stat_name] = stat_fn(
+                        param.data
+                    )
+                if "weights" in self.log_histograms:
+                    counts, bin_edges = self.log_histograms["weights"](param.data)
+                    cache["histograms"]["weights/" + name + "/hist"] = counts
+                    cache["histograms"]["weights/" + name + "/hist_bin_edges"] = (
+                        bin_edges
+                    )
+
+    def _should_track_grad(self, name: str) -> bool:
+        """Check if we should track gradients for this parameter."""
+        if self.filter_grads is None:
+            return True
+        return self.filter_grads(name)
+
+    def _should_track_weight(self, name: str) -> bool:
+        """Check if we should track weights for this parameter."""
+        if self.filter_weights is None:
+            return True
+        return self.filter_weights(name)
+
+    def _should_track_activation(self, name: str, module_type: type[nn.Module]) -> bool:
+        """Check if we should track activations for this module."""
+        if self.filter_activations is None:
+            return True
+        return self.filter_activations(name, module_type)
+
+    def _register_activation_hooks(self, model, step: int):
+        """Register forward hooks to accumulate activations."""
+        cache = self._temp_cache  # alias
+
+        def create_activation_hook(name):
+            def hook(module, input, output):
+                if isinstance(output, torch.Tensor) and (step % self.log_freq == 0):
+                    output = output.detach()
+                    for stat_name, stat_fn in self.log_activations.items():
+                        cache["scalars"]["activations/" + name + "/" + stat_name] = stat_fn(output)
+                    if "activations" in self.log_histograms:
+                        counts, bin_edges = self.log_histograms["activations"](output)
+                        cache["histograms"]["activations/" + name + "/hist"] = counts
+                        cache["histograms"][
+                            "activations/" + name + "/hist_bin_edges"
+                        ] = bin_edges
+
+            return hook
+
+        # Register hooks for filtered modules
+        for name, module in model.named_modules():
+            if self._should_track_activation(name, type(module)):
+                hook = module.register_forward_hook(create_activation_hook(name))
+                self._hooks.append(hook)
+
+    def _remove_activation_hooks(self):
+        """Remove activation hooks."""
+        for hook in self._hooks:
+            hook.remove()
+        self._hooks.clear()
+
+    def __del__(self):
+        self._remove_activation_hooks()
+        del self._temp_cache
+        del self._cache
+        gc.collect()
+
+
+def plot_tensor_hist(
+    hist_values: Float[Tensor, "N M"],
+    name: str = "",
+    norms: Float[Tensor, "N"] = None,
+    steps: Int[Tensor, "N"] = None,
+    log_scale: bool = True,
+) -> plt.Figure:
+    """
+    Plot a histogram of tensor values over time, optionally including norm values.
+
+    Args:
+        hist_values: Tensor of shape (N, M) containing histogram values for N steps and M bins.
+        name: Title for the plot, usually the name of the parameter being plotted.
+        norms: Optional tensor of shape (N,) containing norm values for each step.
+        steps: Optional tensor of shape (N,) containing step indices. If None, uses range(N).
+        log_scale: If True, applies log1p to histogram values before plotting.
+
+    Returns:
+        A matplotlib Figure object containing the plotted histogram.
+
+    Example:
+        >>> hist_values = torch.randn(100, 50)  # 100 steps, 50 bins
+        >>> norms = torch.norm(hist_values, dim=1)
+        >>> fig = plot_tensor_hist(hist_values, name="Weight Distribution", norms=norms)
+        >>> plt.show()
+    """
+    font_size = 8
+    with plt.rc_context({"font.size": font_size}):
+        n_steps, n_bins = hist_values.shape  # (N, M)
+        if log_scale:
+            hist_values = np.log1p(hist_values)
+        if steps is None:
+            steps = np.arange(n_steps)
+        fig, ax = plt.subplots(
+            figsize=(6, 2), constrained_layout=True
+        )  # Added constrained_layout
+        mat = ax.matshow(hist_values.T, aspect="auto")
+        ax.set_xlabel("step")
+
+        # Get the automatically determined tick positions from matplotlib
+        locs = ax.get_xticks()
+        valid_locs = locs[(locs >= 0) & (locs < n_steps)].astype(int)
+        ax.set_xticks(valid_locs)
+        ax.set_xticklabels(steps[valid_locs])
+        ax.set_ylabel("bins")
+
+        # Create twin axis
+        if norms is not None:
+            ax2 = ax.twinx()
+            ax2.plot(np.arange(len(norms)), norms, color="black")
+            ax2.set_ylabel("norm")
+            ax2.set_xlim(0, n_steps - 1)
+            ax2.set_ylim(min(norms), max(norms))  # Independent scaling
+            ax2.set_xticks(valid_locs)
+            ax2.set_xticklabels(steps[valid_locs])
+
+        # Add colorbar - constrained_layout will handle spacing automatically
+        cbar = plt.colorbar(mat, ax=ax)
+        cbar.ax.set_ylabel("log(1+count)" if log_scale else "count")
+
+        ax.set_xlim(0, n_steps - 1)
+        ax.set_ylim(0, n_bins - 1)
+        if name:
+            ax.set_title(name, pad=20, fontsize=8)
+
+    return fig
+
+
+def plot_tensor_stats(
+    steps: Int[Tensor, "N"],
+    mean: Float[Tensor, "N"] = None,
+    std: Float[Tensor, "N"] = None,
+    min_val: Float[Tensor, "N"] = None,
+    max_val: Float[Tensor, "N"] = None,
+    norm: Float[Tensor, "N"] = None,
+    name: str = "",
+    height_ratios: tuple[float, float] = (5, 1),
+):
+    """
+    Plot comprehensive statistics with mean/std/min/max in top panel and norm in bottom panel.
+
+    Args:
+        steps: Training step indices
+        mean: Mean values over time (optional)
+        std: Standard deviation values over time (optional, requires mean)
+        min_val: Minimum values over time (optional)
+        max_val: Maximum values over time (optional)
+        norm: Norm values over time (optional)
+        name: Title for the plot, usually the name of the parameter being plotted.
+        height_ratios: Relative heights of (stats_panel, norm_panel)
+
+    Returns:
+        matplotlib Figure object
+    """
+    # Determine what to plot
+    has_stats = any([mean is not None, min_val is not None, max_val is not None])
+    has_norm = norm is not None
+
+    if not has_stats and not has_norm:
+        raise ValueError(
+            "At least one of mean, min_val, max_val, or norm must be provided"
+        )
+
+    # Create subplot layout based on available data
+    if has_stats and has_norm:
+        fig, (ax1, ax2) = plt.subplots(
+            2,
+            1,
+            figsize=(5, 3),
+            gridspec_kw={"height_ratios": height_ratios},
+            sharex=True,
+            constrained_layout=True,
+        )
+        norm_ax = ax2
+        stats_ax = ax1
+    elif has_stats:
+        fig, ax1 = plt.subplots(figsize=(5, 3))
+        stats_ax = ax1
+        norm_ax = None
+    else:  # only norm
+        fig, ax2 = plt.subplots(figsize=(5, 3))
+        norm_ax = ax2
+        stats_ax = None
+
+    # Top panel: statistics (if available)
+    if has_stats and stats_ax is not None:
+        if mean is not None:
+            stats_ax.plot(steps, mean, label="mean", color="C0")
+            if std is not None:
+                stats_ax.fill_between(
+                    steps, mean - std, mean + std, alpha=0.2, color="C0", label="±1 std"
+                )
+
+        if min_val is not None and max_val is not None:
+            stats_ax.plot(
+                steps, min_val, "--", color="gray", alpha=0.7, label="min/max"
+            )
+            stats_ax.plot(steps, max_val, "--", color="gray", alpha=0.7)
+        elif min_val is not None:
+            stats_ax.plot(steps, min_val, "--", color="gray", alpha=0.7, label="min")
+        elif max_val is not None:
+            stats_ax.plot(steps, max_val, "--", color="gray", alpha=0.7, label="max")
+
+        stats_ax.ticklabel_format(style="plain", useOffset=False)
+        stats_ax.set_ylabel("Stats", labelpad=0)
+        if name:
+            stats_ax.set_title(name, pad=5, fontsize=9)
+        stats_ax.grid(True, alpha=0.3)
+        stats_ax.legend(loc="upper right", bbox_to_anchor=(1, 1), ncol=2)
+
+        # Set xlabel only if this is the only panel
+        if not has_norm:
+            stats_ax.set_xlabel("Step")
+
+    # Bottom panel: norm (if available)
+    if has_norm and norm_ax is not None:
+        norm_ax.plot(steps, norm, label="norm", color="C1")
+        norm_ax.set_ylabel("Norm", labelpad=0)
+        norm_ax.set_xlabel("Step")
+        norm_ax.grid(True, alpha=0.3)
+        norm_ax.legend(loc="upper right", bbox_to_anchor=(1, 1))
+        norm_ax.ticklabel_format(style="plain", useOffset=False)
+
+        # Set title if this is the only panel and no stats panel exists
+        if not has_stats and name:
+            norm_ax.set_title(name, pad=5, fontsize=9)
+
+    plt.tight_layout(pad=0.5, h_pad=0.5, w_pad=0.5)
+    return fig
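The two plotting helpers consume the per-step series that accumulate in the tracker's `_cache` when `keep_cache=True`, but they are plain functions and work on any tensors with the documented shapes. A small self-contained sketch with synthetic data (all values below are made up for illustration and are not part of the wheel):

# Hypothetical usage sketch for plot_tensor_stats / plot_tensor_hist with
# synthetic per-step statistics; shapes follow the docstrings above.
import matplotlib.pyplot as plt
import torch

from foundry.callbacks.health_logging import plot_tensor_hist, plot_tensor_stats

steps = torch.arange(0, 1000, 100)        # (N,) steps at which stats were logged
mean = torch.randn(10).cumsum(0) * 0.01   # (N,) running mean of some parameter
std = torch.rand(10) * 0.05               # (N,) running std
norm = torch.linspace(1.0, 3.0, 10)       # (N,) parameter norm

fig = plot_tensor_stats(
    steps, mean=mean, std=std, norm=norm, name="weights/linear.weight"
)
fig.savefig("weight_stats.png")

hist = torch.rand(10, 40)                 # (N, M) histogram counts per step
fig = plot_tensor_hist(hist, name="grads/linear.weight", norms=norm)
plt.close("all")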