rc_foundry-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
foundry/__init__.py ADDED
@@ -0,0 +1,57 @@
+ import logging
+ import os
+
+ import torch
+ from beartype.claw import beartype_this_package
+ from environs import Env
+ from jaxtyping import install_import_hook
+
+ # Load environment variables from `.env` file
+ _env = Env()
+ _env.read_env()
+ should_typecheck = _env.bool("TYPE_CHECK", default=False)
+ should_debug = _env.bool("DEBUG", default=False)
+ should_check_nans = _env.bool("NAN_CHECK", default=True)
+
+ # Set up logger
+ logger = logging.getLogger("foundry")
+ # ... set logging level based on `DEBUG` environment variable
+ logger.setLevel(logging.DEBUG if should_debug else logging.INFO)
+ # ... log the current mode
+ logger.debug("Debug mode: %s", should_debug)
+ logger.debug("Type checking mode: %s", should_typecheck)
+ logger.debug("NAN checking mode: %s", should_check_nans)
+
+ # Enable runtime type checking if `TYPE_CHECK` environment variable is set to `True`
+ if should_typecheck:
+     beartype_this_package()
+     install_import_hook("foundry", "beartype.beartype")
+
+ # Global flag for cuEquivariance availability
+ SHOULD_USE_CUEQUIVARIANCE = False
+
+ try:
+     if torch.cuda.is_available():
+         if _env.bool("DISABLE_CUEQUIVARIANCE", default=False):
+             logger.info("cuEquivariance usage disabled via DISABLE_CUEQUIVARIANCE")
+         else:
+             import cuequivariance_torch as cuet  # noqa: I001, F401
+
+             SHOULD_USE_CUEQUIVARIANCE = True
+             os.environ["CUEQ_DISABLE_AOT_TUNING"] = _env.str(
+                 "CUEQ_DISABLE_AOT_TUNING", default="1"
+             )
+             os.environ["CUEQ_DEFAULT_CONFIG"] = _env.str(
+                 "CUEQ_DEFAULT_CONFIG", default="1"
+             )
+             logger.info("cuEquivariance is available and will be used.")
+
+ except ImportError:
+     logger.debug("cuEquivariance unavailable: import failed")
+
+
+ # Whether to disable checkpointing globally
+ DISABLE_CHECKPOINTING = False
+
+ # Export for easy access
+ __all__ = ["SHOULD_USE_CUEQUIVARIANCE", "DISABLE_CHECKPOINTING"]
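The module above wires package-level behavior to environment variables that are read once at import time (optionally from a `.env` file via `environs`). A minimal, illustrative sketch of configuring it before the first import; the variable names match the ones read above, but the session itself is hypothetical:

import os

# Must be set before `foundry` is imported, since the flags are evaluated at import time.
os.environ["TYPE_CHECK"] = "true"           # enable beartype/jaxtyping runtime checks
os.environ["DEBUG"] = "true"                # switch the "foundry" logger to DEBUG
os.environ["DISABLE_CUEQUIVARIANCE"] = "1"  # skip the cuequivariance_torch import even on GPU

import foundry

print(foundry.SHOULD_USE_CUEQUIVARIANCE)  # False here, because cuEquivariance was disabled above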
foundry/callbacks/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """Callbacks for training and validation."""
+
+ from foundry.callbacks.callback import BaseCallback
+
+ __all__ = ["BaseCallback"]
foundry/callbacks/callback.py ADDED
@@ -0,0 +1,116 @@
+ from abc import ABC
+
+ from beartype.typing import Any
+ from lightning.fabric.wrappers import (
+     _FabricOptimizer,
+ )
+
+
+ class BaseCallback(ABC):
+     """Abstract base class used to build new callbacks.
+
+     Callbacks receive the trainer as the first argument to all hook methods, following
+     PyTorch Lightning's convention. This allows callbacks to access trainer.state,
+     trainer.fabric, etc.
+
+     NOTE: on_after_optimizer_step is called internally by Fabric and does NOT receive the trainer.
+     Use on_before_optimizer_step for logic that requires trainer access.
+
+     Where possible, use names consistent with PyTorch Lightning's callback names (see references below).
+     Note that if using any callbacks directly within a Model, they must also adhere to this schema.
+
+     References:
+         - PyTorch Lightning Hooks (https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#hooks)
+         - Callbacks Flow (https://pytorch-lightning.readthedocs.io/en/0.10.0/callbacks.html#callbacks)
+     """
+
+     # Epoch loops
+     def on_fit_start(self, trainer: Any):
+         """Called at the start of training"""
+         pass
+
+     def on_fit_end(self, trainer: Any):
+         """Called at the end of training"""
+         pass
+
+     # Training loop
+     def on_train_epoch_start(self, trainer: Any):
+         """Called at the start of each training epoch"""
+         pass
+
+     def on_after_train_loader_iter(self, trainer: Any, **kwargs):
+         """Called after 'iter(train_loader)' is called, but before the first batch is yielded"""
+         pass
+
+     def on_before_train_loader_next(self, trainer: Any, **kwargs):
+         """Called before each batch is yielded from the train_loader via 'next(train_iter)'"""
+         pass
+
+     def on_train_batch_start(self, trainer: Any, batch: Any, batch_idx: int):
+         """Called at the start of each training batch"""
+         pass
+
+     def on_train_batch_end(
+         self, trainer: Any, outputs: Any, batch: Any, batch_idx: int
+     ):
+         """Called after each training batch, but before the optimizer.step"""
+         pass
+
+     def on_before_optimizer_step(self, trainer: Any, optimizer: _FabricOptimizer):
+         """Called before each optimizer.step"""
+         pass
+
+     def on_after_optimizer_step(self, optimizer: _FabricOptimizer, **kwargs):
+         """Called after each optimizer.step.
+
+         NOTE: This hook is called internally by Fabric when optimizer.step() executes.
+         Trainer is NOT available here. Use optimizer_step for logic requiring trainer.
+         """
+         pass
+
+     def optimizer_step(self, trainer: Any, optimizer: _FabricOptimizer):
+         """Called after optimizer.step completes. Unlike on_after_optimizer_step,
+         this hook is called explicitly by the trainer and receives trainer access.
+         """
+         pass
+
+     def on_train_epoch_end(self, trainer: Any):
+         """Called at the end of each training epoch"""
+         pass
+
+     # Validation loop
+     def on_validation_epoch_start(self, trainer: Any):
+         """Called at the start of each validation epoch"""
+         pass
+
+     def on_validation_batch_start(
+         self,
+         trainer: Any,
+         batch: Any,
+         batch_idx: int,
+         num_batches: int,
+         dataset_name: str | None = None,
+     ):
+         """Called at the start of each validation batch"""
+         pass
+
+     def on_validation_batch_end(
+         self,
+         trainer: Any,
+         outputs: Any,
+         batch: Any,
+         batch_idx: int,
+         num_batches: int,
+         dataset_name: str | None = None,
+     ):
+         """Called after each validation batch"""
+         pass
+
+     def on_validation_epoch_end(self, trainer: Any):
+         """Called at the end of each validation epoch"""
+         pass
+
+     # Saving and Loading
+     def on_save_checkpoint(self, trainer: Any, state: dict[str, Any]):
+         """Called when saving a checkpoint"""
+         pass
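Every hook on BaseCallback is a no-op, so a concrete callback only overrides the hooks it needs and receives the trainer as its first argument. A small illustrative subclass, assuming only that the trainer exposes the `state["global_step"]` entry used elsewhere in this package:

from typing import Any

from foundry.callbacks.callback import BaseCallback


class StepPrinter(BaseCallback):
    """Illustrative callback: reports progress every N training batches."""

    def __init__(self, every_n_batches: int = 100):
        self.every_n_batches = every_n_batches

    def on_train_batch_end(self, trainer: Any, outputs: Any, batch: Any, batch_idx: int):
        if batch_idx % self.every_n_batches == 0:
            # `trainer.state` is assumed to be the state dict these callbacks read from
            print(f"global_step={trainer.state['global_step']} batch_idx={batch_idx}")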
foundry/callbacks/health_logging.py ADDED
@@ -0,0 +1,419 @@
+ import gc
+ import types
+ from collections import defaultdict
+ from typing import Any
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from jaxtyping import Float, Int
+ from lightning.fabric.utilities.rank_zero import rank_zero_only
+ from lightning.fabric.wrappers import (
+     _FabricOptimizer,
+ )
+ from torch import Tensor
+
+ from foundry.callbacks.callback import BaseCallback
+
+ _DEFAULT_STATISTICS = types.MappingProxyType(
+     {
+         "mean": torch.mean,
+         "std": torch.std,
+         "norm": torch.norm,
+         "max": torch.amax,
+         "min": torch.amin,
+     }
+ )
+ """Summary statistics to log for gradients, weights, and activations."""
+
+ _DEFAULT_HISTOGRAMS = types.MappingProxyType(
+     {
+         "activations": lambda x: np.histogram(
+             x.abs().to(torch.float32).cpu(), bins=40, range=(0, 10)
+         ),
+         "grads": lambda x: np.histogram(
+             x.abs().to(torch.float32).cpu(), bins=40, range=(0, 1)
+         ),
+         "weights": lambda x: np.histogram(
+             x.abs().to(torch.float32).cpu(), bins=40, range=(0, 1)
+         ),
+     }
+ )
+ """Default histograms to log for activations, gradients, and weights."""
+
+
+ class ActivationsGradientsWeightsTracker(BaseCallback):
+     """Fabric callback to track gradients, activations, and weights during training.
+
+     This callback logs gradient, weight, and activation statistics at specified intervals.
+     Integrates with FabricTrainer through the BaseCallback interface.
+
+     Args:
+         log_freq (int): Frequency of logging (every N steps). Defaults to 100.
+         log_grads (dict): Mapping of statistic name -> function used to summarize gradients. Defaults to _DEFAULT_STATISTICS.
+         log_weights (dict): Mapping of statistic name -> function used to summarize weights. Defaults to _DEFAULT_STATISTICS.
+         log_activations (dict): Mapping of statistic name -> function used to summarize activations. Defaults to _DEFAULT_STATISTICS.
+         log_histograms (dict): Mapping of tensor kind ("activations", "grads", "weights") -> histogram function. Defaults to _DEFAULT_HISTOGRAMS.
+         keep_cache (bool): Whether to keep a local cache of all logged stats. Defaults to False.
+         filter_grads (callable): Function (name) -> bool to filter gradient tracking. None means all parameters.
+         filter_weights (callable): Function (name) -> bool to filter weight tracking. None means all parameters.
+         filter_activations (callable): Function (name, module_type) -> bool to filter activation tracking.
+             None means all modules.
+     """
+
+     def __init__(
+         self,
+         log_freq: int = 100,
+         log_grads: dict[str, callable] = _DEFAULT_STATISTICS,
+         log_weights: dict[str, callable] = _DEFAULT_STATISTICS,
+         log_activations: dict[str, callable] = _DEFAULT_STATISTICS,
+         log_histograms: dict[str, callable] = _DEFAULT_HISTOGRAMS,
+         keep_cache: bool = False,
+         filter_grads: callable = None,
+         filter_weights: callable = None,
+         filter_activations: callable = None,
+     ):
+         super().__init__()
+         self.log_freq = log_freq
+         self.log_grads = log_grads
+         self.log_weights = log_weights
+         self.log_activations = log_activations
+         self.log_histograms = log_histograms
+         self.keep_cache = keep_cache
+         self.filter_grads = filter_grads
+         self.filter_weights = filter_weights
+         self.filter_activations = filter_activations
+
+         self._hooks = []  # Store activation hooks for cleanup
+         self._temp_cache = {"scalars": {}, "histograms": {}}
+         self._cache = defaultdict(list)
+         if not self.keep_cache:
+             # Histograms are only stored in the local cache, so skip computing them when not caching
+             self.log_histograms = {}
+
+     @rank_zero_only
+     def on_fit_start(self, trainer):
+         """Initialize the callback and register activation hooks."""
+         # Check that we either have loggers attached or keep_cache is True, otherwise the
+         # data will be computed but not logged.
+         if not self.keep_cache and not trainer.fabric.loggers:
+             raise ValueError(
+                 "ActivationsGradientsWeightsTracker requires loggers or keep_cache=True. "
+                 "Otherwise the data will be computed but not logged."
+             )
+
+     @rank_zero_only
+     def on_train_batch_start(self, trainer, batch: Any, batch_idx: int):
+         step = trainer.state["global_step"]
+         model = trainer.state["model"]
+         if (self.log_activations or "activations" in self.log_histograms) and (
+             step % self.log_freq == 0
+         ):
+             self._register_activation_hooks(model, step)
+
+     @rank_zero_only
+     def on_before_optimizer_step(self, trainer, optimizer: _FabricOptimizer, **kwargs):
+         """Log gradients, weights, and activations before optimizer step."""
+         step = trainer.state["global_step"]
+
+         if step % self.log_freq == 0:
+             model = trainer.state["model"]
+
+             # Collect weight & gradient stats
+             _should_log_some_grads = self.log_grads or ("grads" in self.log_histograms)
+             _should_log_some_weights = self.log_weights or (
+                 "weights" in self.log_histograms
+             )
+             if _should_log_some_grads or _should_log_some_weights:
+                 self._collect_parameter_stats(model, step)
+
+             # Log all collected stats at once using trainer's fabric instance
+             if len(self._temp_cache["scalars"]) > 0 and trainer.fabric.loggers:
+                 trainer.fabric.log_dict(
+                     self._temp_cache["scalars"],
+                     step=step,
+                 )
+
+             if self.keep_cache:
+                 self._cache["step"].append(torch.tensor(step))
+                 for key, value in self._temp_cache["scalars"].items():
+                     self._cache[key].append(value)
+                 for key, value in self._temp_cache["histograms"].items():
+                     if key.endswith("hist"):
+                         self._cache[key].append(value)
+
+     def on_train_batch_end(self, trainer, **kwargs):
+         """Called at the end of a training batch - clear temporary cache."""
+         self._temp_cache["scalars"].clear()
+         self._temp_cache["histograms"].clear()
+         self._remove_activation_hooks()
+
+     def on_fit_end(self, trainer, **kwargs):
+         """Clean up activation hooks at the end of training."""
+         self._remove_activation_hooks()
+
+     def on_validation_epoch_start(self, trainer):
+         # Temporarily remove any hooks for validation
+         self._remove_activation_hooks()
+
+     @rank_zero_only
+     def on_save_checkpoint(self, trainer, state: dict[str, Any]):
+         self._remove_activation_hooks()
+
+     def _collect_parameter_stats(self, model, step: int):
+         """Collect gradient and weight statistics in a single parameter iteration."""
+         cache = self._temp_cache  # alias
+
+         for name, param in model.named_parameters():
+             # Gradient stats
+             if (
+                 (self.log_grads or "grads" in self.log_histograms)
+                 and param.grad is not None
+                 and self._should_track_grad(name)
+             ):
+                 grad = param.grad.detach()
+                 for stat_name, stat_fn in self.log_grads.items():
+                     cache["scalars"]["grads/" + name + "/" + stat_name] = stat_fn(grad)
+                 if "grads" in self.log_histograms:
+                     counts, bin_edges = self.log_histograms["grads"](grad)
+                     cache["histograms"]["grads/" + name + "/hist"] = counts
+                     cache["histograms"]["grads/" + name + "/hist_bin_edges"] = bin_edges
+
+             # Weight stats
+             if (
+                 self.log_weights or "weights" in self.log_histograms
+             ) and self._should_track_weight(name):
+                 for stat_name, stat_fn in self.log_weights.items():
+                     cache["scalars"]["weights/" + name + "/" + stat_name] = stat_fn(
+                         param.data
+                     )
+                 if "weights" in self.log_histograms:
+                     counts, bin_edges = self.log_histograms["weights"](param.data)
+                     cache["histograms"]["weights/" + name + "/hist"] = counts
+                     cache["histograms"]["weights/" + name + "/hist_bin_edges"] = (
+                         bin_edges
+                     )
+
+     def _should_track_grad(self, name: str) -> bool:
+         """Check if we should track gradients for this parameter."""
+         if self.filter_grads is None:
+             return True
+         return self.filter_grads(name)
+
+     def _should_track_weight(self, name: str) -> bool:
+         """Check if we should track weights for this parameter."""
+         if self.filter_weights is None:
+             return True
+         return self.filter_weights(name)
+
+     def _should_track_activation(self, name: str, module_type: type[nn.Module]) -> bool:
+         """Check if we should track activations for this module."""
+         if self.filter_activations is None:
+             return True
+         return self.filter_activations(name, module_type)
+
+     def _register_activation_hooks(self, model, step: int):
+         """Register forward hooks to accumulate activations."""
+         cache = self._temp_cache  # alias
+
+         def create_activation_hook(name):
+             def hook(module, input, output):
+                 if isinstance(output, torch.Tensor) and (step % self.log_freq == 0):
+                     output = output.detach()
+                     for stat_name, stat_fn in self.log_activations.items():
+                         cache["scalars"]["activations/" + name + "/" + stat_name] = stat_fn(output)
+                     if "activations" in self.log_histograms:
+                         counts, bin_edges = self.log_histograms["activations"](output)
+                         cache["histograms"]["activations/" + name + "/hist"] = counts
+                         cache["histograms"][
+                             "activations/" + name + "/hist_bin_edges"
+                         ] = bin_edges
+
+             return hook
+
+         # Register hooks for filtered modules
+         for name, module in model.named_modules():
+             if self._should_track_activation(name, type(module)):
+                 hook = module.register_forward_hook(create_activation_hook(name))
+                 self._hooks.append(hook)
+
+     def _remove_activation_hooks(self):
+         """Remove activation hooks."""
+         for hook in self._hooks:
+             hook.remove()
+         self._hooks.clear()
+
+     def __del__(self):
+         self._remove_activation_hooks()
+         del self._temp_cache
+         del self._cache
+         gc.collect()
+
+
+ def plot_tensor_hist(
+     hist_values: Float[Tensor, "N M"],
+     name: str = "",
+     norms: Float[Tensor, "N"] = None,
+     steps: Int[Tensor, "N"] = None,
+     log_scale: bool = True,
+ ) -> plt.Figure:
+     """
+     Plot a histogram of tensor values over time, optionally including norm values.
+
+     Args:
+         hist_values: Tensor of shape (N, M) containing histogram values for N steps and M bins.
+         name: Title for the plot, usually the name of the parameter being plotted.
+         norms: Optional tensor of shape (N,) containing norm values for each step.
+         steps: Optional tensor of shape (N,) containing step indices. If None, uses range(N).
+         log_scale: If True, applies log1p to histogram values before plotting.
+
+     Returns:
+         A matplotlib Figure object containing the plotted histogram.
+
+     Example:
+         >>> hist_values = torch.randn(100, 50)  # 100 steps, 50 bins
+         >>> norms = torch.norm(hist_values, dim=1)
+         >>> fig = plot_tensor_hist(hist_values, name="Weight Distribution", norms=norms)
+         >>> plt.show()
+     """
+     font_size = 8
+     with plt.rc_context({"font.size": font_size}):
+         n_steps, n_bins = hist_values.shape  # (N, M)
+         if log_scale:
+             hist_values = np.log1p(hist_values)
+         if steps is None:
+             steps = np.arange(n_steps)
+         fig, ax = plt.subplots(
+             figsize=(6, 2), constrained_layout=True
+         )  # Added constrained_layout
+         mat = ax.matshow(hist_values.T, aspect="auto")
+         ax.set_xlabel("step")
+
+         # Get the automatically determined tick positions from matplotlib
+         locs = ax.get_xticks()
+         valid_locs = locs[(locs >= 0) & (locs < n_steps)].astype(int)
+         ax.set_xticks(valid_locs)
+         ax.set_xticklabels(steps[valid_locs])
+         ax.set_ylabel("bins")
+
+         # Create twin axis
+         if norms is not None:
+             ax2 = ax.twinx()
+             ax2.plot(np.arange(len(norms)), norms, color="black")
+             ax2.set_ylabel("norm")
+             ax2.set_xlim(0, n_steps - 1)
+             ax2.set_ylim(min(norms), max(norms))  # Independent scaling
+             ax2.set_xticks(valid_locs)
+             ax2.set_xticklabels(steps[valid_locs])
+
+         # Add colorbar - constrained_layout will handle spacing automatically
+         cbar = plt.colorbar(mat, ax=ax)
+         cbar.ax.set_ylabel("log(1+count)" if log_scale else "count")
+
+         ax.set_xlim(0, n_steps - 1)
+         ax.set_ylim(0, n_bins - 1)
+         if name:
+             ax.set_title(name, pad=20, fontsize=8)
+
+     return fig
+
+
+ def plot_tensor_stats(
+     steps: Int[Tensor, "N"],
+     mean: Float[Tensor, "N"] = None,
+     std: Float[Tensor, "N"] = None,
+     min_val: Float[Tensor, "N"] = None,
+     max_val: Float[Tensor, "N"] = None,
+     norm: Float[Tensor, "N"] = None,
+     name: str = "",
+     height_ratios: tuple[float, float] = (5, 1),
+ ):
+     """
+     Plot comprehensive statistics with mean/std/min/max in top panel and norm in bottom panel.
+
+     Args:
+         steps: Training step indices
+         mean: Mean values over time (optional)
+         std: Standard deviation values over time (optional, requires mean)
+         min_val: Minimum values over time (optional)
+         max_val: Maximum values over time (optional)
+         norm: Norm values over time (optional)
+         name: Title for the plot, usually the name of the parameter being plotted.
+         height_ratios: Relative heights of (stats_panel, norm_panel)
+
+     Returns:
+         matplotlib Figure object
+     """
+     # Determine what to plot
+     has_stats = any([mean is not None, min_val is not None, max_val is not None])
+     has_norm = norm is not None
+
+     if not has_stats and not has_norm:
+         raise ValueError(
+             "At least one of mean, min_val, max_val, or norm must be provided"
+         )
+
+     # Create subplot layout based on available data
+     if has_stats and has_norm:
+         fig, (ax1, ax2) = plt.subplots(
+             2,
+             1,
+             figsize=(5, 3),
+             gridspec_kw={"height_ratios": height_ratios},
+             sharex=True,
+             constrained_layout=True,
+         )
+         norm_ax = ax2
+         stats_ax = ax1
+     elif has_stats:
+         fig, ax1 = plt.subplots(figsize=(5, 3))
+         stats_ax = ax1
+         norm_ax = None
+     else:  # only norm
+         fig, ax2 = plt.subplots(figsize=(5, 3))
+         norm_ax = ax2
+         stats_ax = None
+
+     # Top panel: statistics (if available)
+     if has_stats and stats_ax is not None:
+         if mean is not None:
+             stats_ax.plot(steps, mean, label="mean", color="C0")
+             if std is not None:
+                 stats_ax.fill_between(
+                     steps, mean - std, mean + std, alpha=0.2, color="C0", label="±1 std"
+                 )
+
+         if min_val is not None and max_val is not None:
+             stats_ax.plot(
+                 steps, min_val, "--", color="gray", alpha=0.7, label="min/max"
+             )
+             stats_ax.plot(steps, max_val, "--", color="gray", alpha=0.7)
+         elif min_val is not None:
+             stats_ax.plot(steps, min_val, "--", color="gray", alpha=0.7, label="min")
+         elif max_val is not None:
+             stats_ax.plot(steps, max_val, "--", color="gray", alpha=0.7, label="max")
+
+         stats_ax.ticklabel_format(style="plain", useOffset=False)
+         stats_ax.set_ylabel("Stats", labelpad=0)
+         if name:
+             stats_ax.set_title(name, pad=5, fontsize=9)
+         stats_ax.grid(True, alpha=0.3)
+         stats_ax.legend(loc="upper right", bbox_to_anchor=(1, 1), ncol=2)
+
+         # Set xlabel only if this is the only panel
+         if not has_norm:
+             stats_ax.set_xlabel("Step")
+
+     # Bottom panel: norm (if available)
+     if has_norm and norm_ax is not None:
+         norm_ax.plot(steps, norm, label="norm", color="C1")
+         norm_ax.set_ylabel("Norm", labelpad=0)
+         norm_ax.set_xlabel("Step")
+         norm_ax.grid(True, alpha=0.3)
+         norm_ax.legend(loc="upper right", bbox_to_anchor=(1, 1))
+         norm_ax.ticklabel_format(style="plain", useOffset=False)
+
+         # Set title if this is the only panel and no stats panel exists
+         if not has_stats and name:
+             norm_ax.set_title(name, pad=5, fontsize=9)
+
+     plt.tight_layout(pad=0.5, h_pad=0.5, w_pad=0.5)
+     return fig
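A hedged usage sketch of the tracker: the filter signatures follow `_should_track_grad` / `_should_track_activation` above, while the trainer wiring (for example a callback list on the Fabric trainer) is an assumption and is not shown in this file:

import torch.nn as nn

from foundry.callbacks.health_logging import ActivationsGradientsWeightsTracker

tracker = ActivationsGradientsWeightsTracker(
    log_freq=500,      # log every 500 global steps
    keep_cache=True,   # keep a local cache so histograms are retained for later plotting
    filter_grads=lambda name: "norm" in name,        # gradients: norm-layer parameters only
    filter_weights=lambda name: "norm" in name,      # weights: same subset
    filter_activations=lambda name, module_type: issubclass(module_type, nn.Linear),  # activations: Linear modules
)
# The tracker is then passed to the trainer's callback list; cached histogram counts can
# later be stacked along the step dimension and visualized with plot_tensor_hist.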