torch-ttt 0.0.1 (tar.gz)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_ttt-0.0.1/LICENSE +21 -0
- torch_ttt-0.0.1/PKG-INFO +16 -0
- torch_ttt-0.0.1/README.md +74 -0
- torch_ttt-0.0.1/pyproject.toml +49 -0
- torch_ttt-0.0.1/setup.cfg +4 -0
- torch_ttt-0.0.1/tests/test_imports.py +32 -0
- torch_ttt-0.0.1/torch_ttt/__init__.py +0 -0
- torch_ttt-0.0.1/torch_ttt/engine/__init__.py +5 -0
- torch_ttt-0.0.1/torch_ttt/engine/actmad_engine.py +174 -0
- torch_ttt-0.0.1/torch_ttt/engine/base_engine.py +47 -0
- torch_ttt-0.0.1/torch_ttt/engine/masked_ttt_engine.py +16 -0
- torch_ttt-0.0.1/torch_ttt/engine/tent_engine.py +49 -0
- torch_ttt-0.0.1/torch_ttt/engine/ttt_engine.py +184 -0
- torch_ttt-0.0.1/torch_ttt/engine/ttt_pp_engine.py +291 -0
- torch_ttt-0.0.1/torch_ttt/engine_registry.py +38 -0
- torch_ttt-0.0.1/torch_ttt/loss/__init__.py +2 -0
- torch_ttt-0.0.1/torch_ttt/loss/base_loss.py +8 -0
- torch_ttt-0.0.1/torch_ttt/loss/contrastive_loss.py +94 -0
- torch_ttt-0.0.1/torch_ttt/loss/entropy_loss.py +15 -0
- torch_ttt-0.0.1/torch_ttt/loss/mean_loss.py +12 -0
- torch_ttt-0.0.1/torch_ttt/loss/ttt_loss.py +13 -0
- torch_ttt-0.0.1/torch_ttt/loss/weights_magnitude_loss.py +30 -0
- torch_ttt-0.0.1/torch_ttt/loss/zerot_loss.py +158 -0
- torch_ttt-0.0.1/torch_ttt/loss_registry.py +38 -0
- torch_ttt-0.0.1/torch_ttt/utils/__init__.py +0 -0
- torch_ttt-0.0.1/torch_ttt/utils/augmentations.py +25 -0
- torch_ttt-0.0.1/torch_ttt/utils/math.py +29 -0
- torch_ttt-0.0.1/torch_ttt.egg-info/PKG-INFO +16 -0
- torch_ttt-0.0.1/torch_ttt.egg-info/SOURCES.txt +30 -0
- torch_ttt-0.0.1/torch_ttt.egg-info/dependency_links.txt +1 -0
- torch_ttt-0.0.1/torch_ttt.egg-info/requires.txt +14 -0
- torch_ttt-0.0.1/torch_ttt.egg-info/top_level.txt +1 -0
torch_ttt-0.0.1/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Nikita Durasov
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
torch_ttt-0.0.1/PKG-INFO
ADDED
@@ -0,0 +1,16 @@
+Metadata-Version: 2.4
+Name: torch-ttt
+Version: 0.0.1
+License-File: LICENSE
+Provides-Extra: dev
+Requires-Dist: pytest-cov; extra == "dev"
+Provides-Extra: docs
+Requires-Dist: sphinxawesome-theme==6.0.0b1; extra == "docs"
+Requires-Dist: sphinx-gallery==0.18.0; extra == "docs"
+Requires-Dist: sphinxcontrib-bibtex; extra == "docs"
+Requires-Dist: sphinxcontrib-googleanalytics; extra == "docs"
+Requires-Dist: matplotlib; extra == "docs"
+Requires-Dist: tqdm; extra == "docs"
+Provides-Extra: all
+Requires-Dist: torch_ttt[dev,docs]; extra == "all"
+Dynamic: license-file
torch_ttt-0.0.1/README.md
ADDED
@@ -0,0 +1,74 @@
+<div align="center">
+<img src="docs/source/_static/images/torch-ttt.png" alt="TorchTTT" width="500">
+</div>
+
+<!-- <div style="display: flex; gap: 0px; flex-wrap: wrap; align-items: center;">
+<a href="https://github.com/nikitadurasov/torch-ttt/stargazers" style="margin: 2px;">
+<img src="https://img.shields.io/github/stars/nikitadurasov/torch-ttt.svg?style=social" alt="GitHub stars" style="display: inline-block; margin: 0;">
+</a>
+<a href="https://github.com/nikitadurasov/torch-ttt/network" style="margin: 2px;">
+<img src="https://img.shields.io/github/forks/nikitadurasov/torch-ttt.svg?color=blue" alt="GitHub forks" style="display: inline-block; margin: 0;">
+</a>
+<a href="https://github.com/nikitadurasov/torch-ttt/actions/workflows/deploy-docs.yml" style="margin: 2px;">
+<img src="https://github.com/nikitadurasov/torch-ttt/actions/workflows/deploy-docs.yml/badge.svg" alt="Documentation" style="display: inline-block; margin: 0;">
+</a>
+</div> -->
+
+# torch-ttt
+
+**torch-ttt** is a package designed to work with [Test-Time Training (TTT)](https://arxiv.org/abs/1909.13231) techniques and make your networks more generalizable. It aims to be modular, easy to integrate into existing pipelines, and collaborative, aiming to include as many methods as possible. Reach out to add yours!
+
+<p align="center">
+<strong>You can find our webpage and documentation here:</strong>
+<a href="https://torch-ttt.github.io">torch-ttt.github.io</a>
+</p>
+
+> **torch-ttt** is in its early stages, so changes are expected. Contributions are welcome: feel free to get involved! If you run into any bugs or issues, don't hesitate to submit an issue.
+
+---
+
+This package provides a streamlined API for a variety of TTT methods through *Engines*, which are lightweight wrappers around your model. These Engines are:
+
+- **Easy to use** – The internal logic of TTT methods is fully encapsulated, so you only need to wrap your model with an Engine and you're ready to go.
+- **Highly modular and standardized** – Each Engine follows the same interface, so methods can be used interchangeably, making it easy to find the best fit for your application.
+- **Minimal changes required** – Enabling a TTT method for your model takes only a few additional lines of code.
+
+Check out the [Quick Start](https://torch-ttt.github.io/quickstart.html) guide or the [API reference](https://torch-ttt.github.io/api.html) for a more detailed explanation of how Engines work and their core concepts.
+
+# Installation
+
+**torch-ttt** requires Python 3.10 or greater. Install the desired PyTorch version in your environment.
+
+For the latest development version, you can run:
+
+```console
+pip install git+https://github.com/nikitadurasov/torch-ttt.git
+```
+
+PyPI support is not available yet, but it is expected very soon!
+
+# Quickstart
+
+We provide a **Quick Start** guide to help you get started smoothly. Take a look here: [Quick Start](https://torch-ttt.github.io/quickstart.html) and see how to integrate TTT methods into your project.
+
+<!-- # Implemented TTTs
+
+## Baselines
+
+| TTT-Method | Image | Text | Graph | Audio |
+|-----------------------------------------------|:----------:|:--------------:|:------------:|:---------------------:|
+| [TTT](https://arxiv.org/abs/1909.13231) | ⏳ | ⏳ | ⏳ | ⏳ |
+| [MaskedTTT](https://arxiv.org/abs/2209.07522) | ⏳ | ⏳ | ⏳ | ⏳ |
+| [TTT++](https://proceedings.neurips.cc/paper/2021/hash/b618c3210e934362ac261db280128c22-Abstract.html) | ⏳ | ⏳ | ⏳ | ⏳ |
+| [ActMAD](https://arxiv.org/abs/2211.12870) | ⏳ | ⏳ | ⏳ | ⏳ |
+| [SHOT](https://arxiv.org/abs/2002.08546) | ⏳ | ⏳ | ⏳ | ⏳ |
+| [TENT](https://arxiv.org/abs/2006.10726) | ⏳ | ⏳ | ⏳ | ⏳ | -->
+
+# Tutorials
+
+We offer a variety of tutorials to help users gain a deeper understanding of the implemented methods and see how they can be applied to different tasks. Visit the [Tutorials](https://torch-ttt.github.io/auto_examples/index.html) page to explore them.
+
+# Documentation
+
+Our aim is to provide comprehensive documentation for all included TTT methods, covering their theoretical foundations, practical benefits, and efficient integration into your project. We also offer tutorials that illustrate their applications and guide you through their effective usage.
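The Engine workflow the README describes can be made concrete with a short end-to-end sketch. This is a hedged illustration rather than code from the release: `TinyNet`, the random loaders, and the `0.1` loss weight are placeholders, while the calls themselves follow the `ActMADEngine` docstring later in this diff.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch_ttt.engine.actmad_engine import ActMADEngine

class TinyNet(nn.Module):
    """Placeholder model with named layers to align ("fc1", "fc2")."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 16)
        self.fc2 = nn.Linear(16, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

train_loader = DataLoader(
    TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,))), batch_size=16
)
test_loader = DataLoader(
    TensorDataset(torch.randn(32, 8), torch.randint(0, 2, (32,))), batch_size=16
)

model = TinyNet()
engine = ActMADEngine(model, ["fc1", "fc2"])  # wrap the model; align fc1/fc2 features
optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# Standard supervised training; ActMAD's TTT loss is zero at train time.
engine.train()
for inputs, labels in train_loader:
    optimizer.zero_grad()
    outputs, loss_ttt = engine(inputs)
    loss = criterion(outputs, labels) + 0.1 * loss_ttt
    loss.backward()
    optimizer.step()

# Record reference feature statistics once training is done.
engine.compute_statistics(train_loader)

# At test time, each batch is adapted before prediction.
engine.eval()
for inputs, labels in test_loader:
    outputs, loss_ttt = engine(inputs)
```

The same pattern applies to any registered engine, since every Engine returns the `(outputs, ttt_loss)` pair defined by the shared `BaseEngine` interface.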
torch_ttt-0.0.1/pyproject.toml
ADDED
@@ -0,0 +1,49 @@
+[build-system]
+requires = ["setuptools", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "torch-ttt"
+version = "0.0.1"
+
+[project.optional-dependencies]
+
+dev = [
+    "pytest-cov",
+]
+
+docs = [
+    "sphinxawesome-theme==6.0.0b1",
+    "sphinx-gallery==0.18.0",
+    "sphinxcontrib-bibtex",
+    "sphinxcontrib-googleanalytics",
+    "matplotlib",
+    "tqdm"
+]
+
+all = [
+    "torch_ttt[dev, docs]"
+]
+
+[tool.ruff]
+line-length = 100
+target-version = "py310"
+
+exclude = [
+    "build",        # Ignore the 'build' directory
+    "dist",         # Ignore the 'dist' directory
+    "__pycache__",  # Ignore Python's cache directories
+    "notebooks"
+]
+
+[tool.setuptools.packages.find]
+where = ["."]
+include = ["torch_ttt*"]
+
+[tool.ruff.lint.pydocstyle]
+convention = "google"
+
+[tool.coverage.run]
+branch = true
+source = ["torch_ttt"]
+omit = ["*/tests/*"]
torch_ttt-0.0.1/tests/test_imports.py
ADDED
@@ -0,0 +1,32 @@
+# ruff: noqa: F401
+import unittest
+
+
+class TestImports(unittest.TestCase):
+    def test_import_loss_registry(self):
+        try:
+            from torch_ttt.loss_registry import LossRegistry
+        except ImportError as e:
+            self.fail(f"Failed to import LossRegistry: {e}")
+
+    def test_import_engine_registry(self):
+        try:
+            from torch_ttt.engine_registry import EngineRegistry
+        except ImportError as e:
+            self.fail(f"Failed to import EngineRegistry: {e}")
+
+    def test_import_base_loss(self):
+        try:
+            from torch_ttt.loss.base_loss import BaseLoss
+        except ImportError as e:
+            self.fail(f"Failed to import BaseLoss: {e}")
+
+    def test_import_base_engine(self):
+        try:
+            from torch_ttt.engine.base_engine import BaseEngine
+        except ImportError as e:
+            self.fail(f"Failed to import BaseEngine: {e}")
+
+
+if __name__ == "__main__":
+    unittest.main()
torch_ttt-0.0.1/torch_ttt/__init__.py
File without changes
torch_ttt-0.0.1/torch_ttt/engine/actmad_engine.py
ADDED
@@ -0,0 +1,174 @@
+import torch
+from typing import List, Dict, Any, Tuple, Union
+from contextlib import contextmanager
+
+from torch.utils.data import DataLoader
+from torch_ttt.engine.base_engine import BaseEngine
+from torch_ttt.engine_registry import EngineRegistry
+
+__all__ = ["ActMADEngine"]
+
+
+# TODO: add cuda support
+@EngineRegistry.register("actmad")
+class ActMADEngine(BaseEngine):
+    """**ActMAD** approach: multi-level pixel-wise feature alignment.
+
+    Args:
+        model (torch.nn.Module): Model to be trained with TTT.
+        features_layer_names (List[str] | str): List of layer names to be used for feature alignment.
+        optimization_parameters (dict): The optimization parameters for the engine.
+
+    :Example:
+
+    .. code-block:: python
+
+        from torch_ttt.engine.actmad_engine import ActMADEngine
+
+        model = MyModel()
+        engine = ActMADEngine(model, ["fc1", "fc2"])
+        optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)
+
+        # Training
+        engine.train()
+        for inputs, labels in train_loader:
+            optimizer.zero_grad()
+            outputs, loss_ttt = engine(inputs)
+            loss = criterion(outputs, labels) + alpha * loss_ttt
+            loss.backward()
+            optimizer.step()
+
+        # Compute statistics for feature alignment
+        engine.compute_statistics(train_loader)
+
+        # Inference
+        engine.eval()
+        for inputs, labels in test_loader:
+            output, loss_ttt = engine(inputs)
+
+    Reference:
+
+        "ActMAD: Activation Matching to Align Distributions for Test-Time Training", M. Jehanzeb Mirza, Pol Jane Soneira, Wei Lin, Mateusz Kozinski, Horst Possegger, Horst Bischof
+
+        Paper link: `PDF <https://arxiv.org/abs/2211.12870>`_
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        features_layer_names: Union[List[str], str],
+        optimization_parameters: Dict[str, Any] = {},
+    ):
+        super().__init__()
+        self.model = model
+        self.features_layer_names = features_layer_names
+        self.optimization_parameters = optimization_parameters
+
+        if isinstance(features_layer_names, str):
+            self.features_layer_names = [features_layer_names]
+
+        # TODO: rewrite this
+        self.target_modules = []
+        for layer_name in self.features_layer_names:
+            layer_exists = False
+            for name, module in model.named_modules():
+                if name == layer_name:
+                    layer_exists = True
+                    self.target_modules.append(module)
+                    break
+            if not layer_exists:
+                raise ValueError(f"Layer {layer_name} does not exist in the model.")
+
+        self.reference_mean = None
+        self.reference_var = None
+
+    def ttt_forward(self, inputs) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward pass of the model.
+
+        Args:
+            inputs (torch.Tensor): Input tensor.
+
+        Returns:
+            Returns the current model prediction and feature-alignment loss value.
+        """
+        with self.__capture_hook() as features_hooks:
+            outputs = self.model(inputs)
+            features = [hook.output for hook in features_hooks]
+
+        # don't compute loss during training
+        if self.training:
+            return outputs, 0
+
+        if self.reference_var is None or self.reference_mean is None:
+            raise ValueError(
+                "Reference statistics are not computed. Please call `compute_statistics` method."
+            )
+
+        l1_loss = torch.nn.L1Loss(reduction="mean")
+        features_means = [torch.mean(feature, dim=0) for feature in features]
+        features_vars = [torch.var(feature, dim=0) for feature in features]
+
+        loss = 0
+        for i in range(len(self.target_modules)):
+            loss += l1_loss(features_means[i], self.reference_mean[i])
+            loss += l1_loss(features_vars[i], self.reference_var[i])
+
+        return outputs, loss
+
+    def compute_statistics(self, dataloader: DataLoader) -> None:
+        """Extract and compute reference statistics for features.
+
+        Args:
+            dataloader (DataLoader): The dataloader used for extracting features. It can return tuples of tensors, with the first element expected to be the input tensor.
+
+        Raises:
+            ValueError: If the dataloader is empty or features have mismatched dimensions.
+        """
+        self.model.eval()
+        feat_stack = [[] for _ in self.target_modules]
+
+        # TODO: compute variance in a more memory-efficient way
+        with torch.no_grad():
+            device = next(self.model.parameters()).device
+            for sample in dataloader:
+                if len(sample) < 1:
+                    raise ValueError("Dataloader returned an empty batch.")
+
+                inputs = sample[0].to(device)
+                with self.__capture_hook() as features_hooks:
+                    _ = self.model(inputs)
+                    features = [hook.output.cpu() for hook in features_hooks]
+
+                for i, feature in enumerate(features):
+                    feat_stack[i].append(feature)
+
+        # Compute mean and variance
+        self.reference_mean = [torch.mean(torch.cat(feat), dim=0).to(device) for feat in feat_stack]
+        self.reference_var = [torch.var(torch.cat(feat), dim=0).to(device) for feat in feat_stack]
+
+    @contextmanager
+    def __capture_hook(self):
+        """Context manager to capture features via a forward hook."""
+
+        class OutputHook:
+            def __init__(self):
+                self.output = None
+
+            def hook(self, module, input, output):
+                self.output = output
+
+        hook_handles = []
+        features_hooks = []
+        for module in self.target_modules:
+            hook = OutputHook()
+            features_hooks.append(hook)
+            hook_handle = module.register_forward_hook(hook.hook)
+            hook_handles.append(hook_handle)
+
+        try:
+            yield features_hooks
+        finally:
+            for hook in hook_handles:
+                hook.remove()
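The `TODO: compute variance in a more memory-efficient way` above points at a real cost: `compute_statistics` stacks every feature batch in memory before calling `torch.var`. A standard fix is Welford's online algorithm, which needs only running accumulators per layer. The sketch below is a hypothetical replacement, not part of this release; `streaming_mean_var` is an illustrative name.

```python
# Hypothetical streaming alternative to the stacked-statistics loop above,
# using Welford's online algorithm over per-sample feature rows.
import torch

def streaming_mean_var(batches):
    """Per-feature mean/variance over an iterable of (batch, features) tensors."""
    count, mean, m2 = 0, None, None
    for batch in batches:
        for row in batch:  # one sample at a time
            count += 1
            if mean is None:
                mean = row.clone()
                m2 = torch.zeros_like(row)
                continue
            delta = row - mean
            mean += delta / count
            m2 += delta * (row - mean)  # uses the updated mean
    var = m2 / (count - 1)  # unbiased, matching torch.var's default
    return mean, var

# Sanity check against the stacked computation used in compute_statistics.
batches = [torch.randn(16, 8) for _ in range(4)]
mean, var = streaming_mean_var(batches)
stacked = torch.cat(batches)
assert torch.allclose(mean, stacked.mean(dim=0), atol=1e-4)
assert torch.allclose(var, stacked.var(dim=0), atol=1e-3)
```

Replacing the per-sample Python loop with batched merges (Chan-style parallel updates) would make this fast enough to drop into `compute_statistics` directly.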
torch_ttt-0.0.1/torch_ttt/engine/base_engine.py
ADDED
@@ -0,0 +1,47 @@
+from abc import ABC, abstractmethod
+from typing import Tuple
+import torch.nn as nn
+import torch
+from copy import deepcopy
+
+
+class BaseEngine(nn.Module, ABC):
+    def __init__(self):
+        nn.Module.__init__(self)
+        self.optimization_parameters = {}
+
+    @abstractmethod
+    def ttt_forward(self, inputs) -> Tuple[torch.Tensor, torch.Tensor]:
+        pass
+
+    def forward(self, inputs):
+        if self.training:
+            return self.ttt_forward(inputs)
+
+        # TODO: optimization pipeline should be more flexible and
+        # user-defined, need some special structure for that
+        optimization_parameters = self.optimization_parameters or {}
+
+        optimizer_name = optimization_parameters.get("optimizer_name", "adam")
+        num_steps = optimization_parameters.get("num_steps", 3)
+        lr = optimization_parameters.get("lr", 1e-2)
+        copy_model = optimization_parameters.get("copy_model", True)
+
+        running_engine = deepcopy(self) if copy_model else self
+
+        parameters = filter(lambda p: p.requires_grad, running_engine.model.parameters())
+        if optimizer_name == "adam":
+            optimizer = torch.optim.Adam(parameters, lr=lr)
+        else:
+            raise ValueError(f"Unsupported optimizer: {optimizer_name}")
+
+        loss = 0  # default value
+        for _ in range(num_steps):
+            optimizer.zero_grad()
+            _, loss = running_engine.ttt_forward(inputs)
+            loss.backward()
+            optimizer.step()
+
+        with torch.no_grad():
+            final_outputs = running_engine.model(inputs)
+
+        return final_outputs, loss
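`base_engine.py` concentrates the whole test-time loop in `forward`: in train mode it defers to `ttt_forward`, and in eval mode it reads `optimization_parameters`, optionally deep-copies the engine, and runs `num_steps` of Adam on the TTT loss before predicting. A toy subclass makes the contract concrete; `NormEngine` and its L2-norm "loss" are illustrative only, not part of the package.

```python
# Toy subclass illustrating the BaseEngine contract sketched above.
import torch
import torch.nn as nn
from torch_ttt.engine.base_engine import BaseEngine

class NormEngine(BaseEngine):
    """Illustrative engine whose 'self-supervised' loss is the output L2 norm."""
    def __init__(self, model, optimization_parameters=None):
        super().__init__()
        self.model = model  # BaseEngine.forward expects a `model` attribute
        self.optimization_parameters = optimization_parameters or {}

    def ttt_forward(self, inputs):
        outputs = self.model(inputs)
        return outputs, outputs.norm()

model = nn.Linear(4, 2)
engine = NormEngine(model, {"num_steps": 2, "lr": 1e-3, "copy_model": True})

engine.train()
outputs, loss_ttt = engine(torch.randn(8, 4))  # plain ttt_forward, no adaptation

engine.eval()
outputs, loss_ttt = engine(torch.randn(8, 4))  # 2 Adam steps on a deepcopy, then predict
```

Note the `copy_model` default: adaptation happens on a deepcopy, so the wrapped `model` keeps its original weights across test batches.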
torch_ttt-0.0.1/torch_ttt/engine/masked_ttt_engine.py
ADDED
@@ -0,0 +1,16 @@
+from torch_ttt.engine.base_engine import BaseEngine
+from torch_ttt.engine_registry import EngineRegistry
+
+__all__ = ["MaskedTTTEngine"]
+
+
+# TODO: add cuda support
+@EngineRegistry.register("masked_ttt")
+class MaskedTTTEngine(BaseEngine):
+    """Masked autoencoders-based **test-time training** approach."""
+
+    def __init__(self):
+        super().__init__()
+
+    def __call__(self, inputs):
+        pass
torch_ttt-0.0.1/torch_ttt/engine/tent_engine.py
ADDED
@@ -0,0 +1,49 @@
+import torch
+from typing import Dict, Any, Tuple
+from torch_ttt.engine.base_engine import BaseEngine
+from torch_ttt.engine_registry import EngineRegistry
+
+__all__ = ["TentEngine"]
+
+@EngineRegistry.register("tent")
+class TentEngine(BaseEngine):
+    """**TENT**: Fully test-time adaptation by entropy minimization.
+
+    Args:
+        model (torch.nn.Module): The model to adapt.
+        optimization_parameters (dict): Hyperparameters for adaptation.
+
+    Reference:
+        "TENT: Fully Test-Time Adaptation by Entropy Minimization"
+        Dequan Wang, Evan Shelhamer, et al.
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        optimization_parameters: Dict[str, Any] = {},
+    ):
+        super().__init__()
+        self.model = model
+        self.optimization_parameters = optimization_parameters
+
+        # Tent adapts only affine parameters in BatchNorm
+        self.model.train()
+        self._configure_bn()
+
+    def _configure_bn(self):
+        for module in self.model.modules():
+            if isinstance(module, torch.nn.BatchNorm2d):
+                module.requires_grad_(True)
+                module.track_running_stats = False
+            else:
+                for param in module.parameters(recurse=False):
+                    param.requires_grad = False
+
+    def ttt_forward(self, inputs) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward pass and entropy loss computation."""
+        outputs = self.model(inputs)
+        probs = torch.nn.functional.softmax(outputs, dim=1)
+        log_probs = torch.nn.functional.log_softmax(outputs, dim=1)
+        entropy = -torch.sum(probs * log_probs, dim=1).mean()
+        return outputs, entropy
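Since `TentEngine` freezes everything except the BatchNorm affine parameters at construction time, only those parameters are updated by the inherited eval-mode adaptation loop. Below is a hedged usage sketch; the conv/BN model, batch shapes, and hyperparameter values are illustrative placeholders, not part of the package.

```python
# Usage sketch for TentEngine with an illustrative conv/BN classifier.
import torch
import torch.nn as nn
from torch_ttt.engine.tent_engine import TentEngine

model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)

# After construction, only BatchNorm affine parameters remain trainable.
engine = TentEngine(model, {"num_steps": 3, "lr": 1e-3, "copy_model": True})

engine.eval()  # eval mode triggers BaseEngine's entropy-minimization loop
for _ in range(2):  # stand-in for a test loader
    batch = torch.randn(16, 3, 32, 32)
    outputs, entropy = engine(batch)  # adapted logits and final entropy loss
```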