torch-ttt 0.0.1 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_ttt/__init__.py +0 -0
- torch_ttt/engine/__init__.py +5 -0
- torch_ttt/engine/actmad_engine.py +174 -0
- torch_ttt/engine/base_engine.py +47 -0
- torch_ttt/engine/masked_ttt_engine.py +16 -0
- torch_ttt/engine/tent_engine.py +49 -0
- torch_ttt/engine/ttt_engine.py +184 -0
- torch_ttt/engine/ttt_pp_engine.py +291 -0
- torch_ttt/engine_registry.py +38 -0
- torch_ttt/loss/__init__.py +2 -0
- torch_ttt/loss/base_loss.py +8 -0
- torch_ttt/loss/contrastive_loss.py +94 -0
- torch_ttt/loss/entropy_loss.py +15 -0
- torch_ttt/loss/mean_loss.py +12 -0
- torch_ttt/loss/ttt_loss.py +13 -0
- torch_ttt/loss/weights_magnitude_loss.py +30 -0
- torch_ttt/loss/zerot_loss.py +158 -0
- torch_ttt/loss_registry.py +38 -0
- torch_ttt/utils/__init__.py +0 -0
- torch_ttt/utils/augmentations.py +25 -0
- torch_ttt/utils/math.py +29 -0
- torch_ttt-0.0.1.dist-info/METADATA +16 -0
- torch_ttt-0.0.1.dist-info/RECORD +26 -0
- torch_ttt-0.0.1.dist-info/WHEEL +5 -0
- torch_ttt-0.0.1.dist-info/licenses/LICENSE +21 -0
- torch_ttt-0.0.1.dist-info/top_level.txt +1 -0
torch_ttt/__init__.py
ADDED
File without changes
torch_ttt/engine/actmad_engine.py
ADDED
@@ -0,0 +1,174 @@
import torch
from typing import List, Dict, Any, Tuple, Union
from contextlib import contextmanager

from torch.utils.data import DataLoader
from torch_ttt.engine.base_engine import BaseEngine
from torch_ttt.engine_registry import EngineRegistry

__all__ = ["ActMADEngine"]


# TODO: add cuda support
@EngineRegistry.register("actmad")
class ActMADEngine(BaseEngine):
    """**ActMAD** approach: multi-level pixel-wise feature alignment.

    Args:
        model (torch.nn.Module): Model to be trained with TTT.
        features_layer_names (List[str] | str): List of layer names to be used for feature alignment.
        optimization_parameters (dict): The optimization parameters for the engine.

    :Example:

    .. code-block:: python

        from torch_ttt.engine.actmad_engine import ActMADEngine

        model = MyModel()
        engine = ActMADEngine(model, ["fc1", "fc2"])
        optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)

        # Training
        engine.train()
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs, loss_ttt = engine(inputs)
            loss = criterion(outputs, labels) + alpha * loss_ttt
            loss.backward()
            optimizer.step()

        # Compute statistics for features alignment
        engine.compute_statistics(train_loader)

        # Inference
        engine.eval()
        for inputs, labels in test_loader:
            output, loss_ttt = engine(inputs)

    Reference:

        "ActMAD: Activation Matching to Align Distributions for Test-Time Training", M. Jehanzeb Mirza, Pol Jane Soneira, Wei Lin, Mateusz Kozinski, Horst Possegger, Horst Bischof

        Paper link: `PDF <https://proceedings.neurips.cc/paper/2021/hash/b618c3210e934362ac261db280128c22-Abstract.html>`_
    """

    def __init__(
        self,
        model: torch.nn.Module,
        features_layer_names: Union[List[str], str],
        optimization_parameters: Dict[str, Any] = {},
    ):
        super().__init__()
        self.model = model
        self.features_layer_names = features_layer_names
        self.optimization_parameters = optimization_parameters

        if isinstance(features_layer_names, str):
            self.features_layer_names = [features_layer_names]

        # TODO: rewrite this
        self.target_modules = []
        for layer_name in self.features_layer_names:
            layer_exists = False
            for name, module in model.named_modules():
                if name == layer_name:
                    layer_exists = True
                    self.target_modules.append(module)
                    break
            if not layer_exists:
                raise ValueError(f"Layer {layer_name} does not exist in the model.")

        self.reference_mean = None
        self.reference_var = None

    def ttt_forward(self, inputs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the model.

        Args:
            inputs (torch.Tensor): Input tensor.

        Returns:
            Returns the current model prediction and feature alignment loss value.
        """
        with self.__capture_hook() as features_hooks:
            outputs = self.model(inputs)
            features = [hook.output for hook in features_hooks]

        # don't compute loss during training
        if self.training:
            return outputs, 0

        if self.reference_var is None or self.reference_mean is None:
            raise ValueError(
                "Reference statistics are not computed. Please call `compute_statistics` method."
            )

        l1_loss = torch.nn.L1Loss(reduction='mean')
        features_means = [torch.mean(feature, dim=0) for feature in features]
        features_vars = [torch.var(feature, dim=0) for feature in features]

        loss = 0
        for i in range(len(self.target_modules)):
            loss += l1_loss(features_means[i], self.reference_mean[i])
            loss += l1_loss(features_vars[i], self.reference_var[i])

        return outputs, loss

    def compute_statistics(self, dataloader: DataLoader) -> None:
        """Extract and compute reference statistics for features.

        Args:
            dataloader (DataLoader): The dataloader used for extracting features. It can return tuples of tensors, with the first element expected to be the input tensor.

        Raises:
            ValueError: If the dataloader is empty or features have mismatched dimensions.
        """

        self.model.eval()
        feat_stack = [[] for _ in self.target_modules]

        # TODO: compute variance in more memory efficient way
        with torch.no_grad():
            device = next(self.model.parameters()).device
            for sample in dataloader:
                if len(sample) < 1:
                    raise ValueError("Dataloader returned an empty batch.")

                inputs = sample[0].to(device)
                with self.__capture_hook() as features_hooks:
                    _ = self.model(inputs)
                    features = [hook.output.cpu() for hook in features_hooks]

                for i, feature in enumerate(features):
                    feat_stack[i].append(feature)

        # Compute mean and variance
        self.reference_mean = [torch.mean(torch.cat(feat), dim=0).to(device) for feat in feat_stack]
        self.reference_var = [torch.var(torch.cat(feat), dim=0).to(device) for feat in feat_stack]

    @contextmanager
    def __capture_hook(self):
        """Context manager to capture features via a forward hook."""

        class OutputHook:
            def __init__(self):
                self.output = None

            def hook(self, module, input, output):
                self.output = output

        hook_handles = []
        features_hooks = []
        for module in self.target_modules:
            hook = OutputHook()
            features_hooks.append(hook)
            hook_handle = module.register_forward_hook(hook.hook)
            hook_handles.append(hook_handle)

        try:
            yield features_hooks
        finally:
            for hook in hook_handles:
                hook.remove()

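A minimal end-to-end sketch of the flow above (not part of the wheel): statistics are computed from source data first, then eval-mode forward runs a few alignment steps. The toy model, the layer names "0" and "2", and all shapes are illustrative assumptions.

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch_ttt.engine.actmad_engine import ActMADEngine

    model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 4))
    engine = ActMADEngine(model, ["0", "2"], optimization_parameters={"num_steps": 2, "lr": 1e-3})

    # Reference activation statistics come from (clean) source data.
    source_loader = DataLoader(TensorDataset(torch.randn(64, 16)), batch_size=16)
    engine.compute_statistics(source_loader)

    # At eval time, BaseEngine.forward runs a few alignment steps before predicting.
    engine.eval()
    outputs, loss_ttt = engine(torch.randn(8, 16))
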
torch_ttt/engine/base_engine.py
ADDED
@@ -0,0 +1,47 @@
from abc import ABC, abstractmethod
from typing import Tuple
import torch.nn as nn
import torch
from copy import deepcopy

class BaseEngine(nn.Module, ABC):

    def __init__(self):
        nn.Module.__init__(self)
        self.optimization_parameters = {}

    @abstractmethod
    def ttt_forward(self, inputs) -> Tuple[torch.Tensor, torch.Tensor]:
        pass

    def forward(self, inputs):

        if self.training:
            return self.ttt_forward(inputs)

        # TODO: optimization pipeline should be more flexible and
        # user-defined, need some special structure for that
        optimization_parameters = self.optimization_parameters or {}

        optimizer_name = optimization_parameters.get("optimizer_name", "adam")
        num_steps = optimization_parameters.get("num_steps", 3)
        lr = optimization_parameters.get("lr", 1e-2)
        copy_model = optimization_parameters.get("copy_model", True)

        running_engine = deepcopy(self) if copy_model else self

        parameters = filter(lambda p: p.requires_grad, running_engine.model.parameters())
        if optimizer_name == "adam":
            optimizer = torch.optim.Adam(parameters, lr=lr)
        else:
            raise ValueError(f"Unsupported optimizer: '{optimizer_name}'.")

        loss = 0  # default value
        for _ in range(num_steps):
            optimizer.zero_grad()
            _, loss = running_engine.ttt_forward(inputs)
            loss.backward()
            optimizer.step()

        with torch.no_grad():
            final_outputs = running_engine.model(inputs)

        return final_outputs, loss

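A minimal sketch of the contract above (not from the package): a subclass only implements ttt_forward, and the eval-time loop in forward is driven by optimization_parameters. The engine name, its dummy loss, and the shapes are illustrative assumptions.

    import torch
    from torch_ttt.engine.base_engine import BaseEngine

    class OutputMeanEngine(BaseEngine):
        """Hypothetical engine whose self-supervised loss is simply the output mean."""

        def __init__(self, model):
            super().__init__()
            self.model = model
            # Keys mirror those read by BaseEngine.forward.
            self.optimization_parameters = {"optimizer_name": "adam", "num_steps": 1, "lr": 1e-3, "copy_model": True}

        def ttt_forward(self, inputs):
            outputs = self.model(inputs)
            return outputs, outputs.mean()

    engine = OutputMeanEngine(torch.nn.Linear(8, 2))
    engine.eval()                                 # eval mode triggers the adaptation loop in forward()
    outputs, loss = engine(torch.randn(4, 8))     # one Adam step on a copy of the engine, then prediction
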
torch_ttt/engine/masked_ttt_engine.py
ADDED
@@ -0,0 +1,16 @@
from torch_ttt.engine.base_engine import BaseEngine
from torch_ttt.engine_registry import EngineRegistry

__all__ = ["MaskedTTTEngine"]


# TODO: add cuda support
@EngineRegistry.register("masked_ttt")
class MaskedTTTEngine(BaseEngine):
    """Masked autoencoders-based **test-time training** approach."""

    def __init__(self):
        super().__init__()

    def __call__(self, inputs):
        pass

torch_ttt/engine/tent_engine.py
ADDED
@@ -0,0 +1,49 @@
import torch
from typing import Dict, Any, Tuple
from torch_ttt.engine.base_engine import BaseEngine
from torch_ttt.engine_registry import EngineRegistry

__all__ = ["TentEngine"]

@EngineRegistry.register("tent")
class TentEngine(BaseEngine):
    """**TENT**: Fully test-time adaptation by entropy minimization.

    Args:
        model (torch.nn.Module): The model to adapt.
        optimization_parameters (dict): Hyperparameters for adaptation.

    Reference:
        "TENT: Fully Test-Time Adaptation by Entropy Minimization"
        Dequan Wang, Evan Shelhamer, et al.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimization_parameters: Dict[str, Any] = {},
    ):
        super().__init__()
        self.model = model
        self.optimization_parameters = optimization_parameters

        # Tent adapts only affine parameters in BatchNorm
        self.model.train()
        self._configure_bn()

    def _configure_bn(self):
        for module in self.model.modules():
            if isinstance(module, torch.nn.BatchNorm2d):
                module.requires_grad_(True)
                module.track_running_stats = False
            else:
                for param in module.parameters(recurse=False):
                    param.requires_grad = False

    def ttt_forward(self, inputs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass and entropy loss computation."""
        outputs = self.model(inputs)
        probs = torch.nn.functional.softmax(outputs, dim=1)
        log_probs = torch.nn.functional.log_softmax(outputs, dim=1)
        entropy = -torch.sum(probs * log_probs, dim=1).mean()
        return outputs, entropy

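A minimal usage sketch (not from the package), assuming a toy BatchNorm2d classifier; only the BN affine parameters receive gradient updates during the eval-time entropy-minimization steps.

    import torch
    from torch_ttt.engine.tent_engine import TentEngine

    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, 3, padding=1),
        torch.nn.BatchNorm2d(8),
        torch.nn.ReLU(),
        torch.nn.AdaptiveAvgPool2d(1),
        torch.nn.Flatten(),
        torch.nn.Linear(8, 10),
    )
    engine = TentEngine(model, optimization_parameters={"num_steps": 1, "lr": 1e-3})

    engine.eval()
    outputs, entropy = engine(torch.rand(4, 3, 32, 32))   # adapt the BN affine parameters, then predict
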
torch_ttt/engine/ttt_engine.py
ADDED
@@ -0,0 +1,184 @@
import torch
from contextlib import contextmanager
from typing import Tuple, Dict, Any
from torchvision.transforms import functional as F
from torch_ttt.engine.base_engine import BaseEngine
from torch_ttt.engine_registry import EngineRegistry

__all__ = ["TTTEngine"]


@EngineRegistry.register("ttt")
class TTTEngine(BaseEngine):
    r"""Original image rotation-based **test-time training** approach.

    Args:
        model (torch.nn.Module): Model to be trained with TTT.
        features_layer_name (str): The name of the layer from which the features are extracted.
        angle_head (torch.nn.Module, optional): The head that predicts the rotation angles.
        angle_criterion (torch.nn.Module, optional): The loss function for the rotation angles.
        optimization_parameters (dict): The optimization parameters for the engine.

    Warning:
        The module with the name :attr:`features_layer_name` should be present in the model.

    Note:
        :attr:`angle_head` and :attr:`angle_criterion` are optional arguments and can be user-defined. If not provided, the default shallow head and the :class:`torch.nn.CrossEntropyLoss` loss function are used.

    Note:
        The original `TTT <https://github.com/yueatsprograms/ttt_cifar_release/blob/acac817fb7615850d19a8f8e79930240c9afe8b5/utils/rotation.py#L27>`_ implementation uses a four-class classification task, corresponding to image rotations of 0°, 90°, 180°, and 270°.

    :Example:

    .. code-block:: python

        from torch_ttt.engine.ttt_engine import TTTEngine

        model = MyModel()
        engine = TTTEngine(model, "fc1")
        optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)

        # Training
        engine.train()
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs, loss_ttt = engine(inputs)
            loss = criterion(outputs, labels) + alpha * loss_ttt
            loss.backward()
            optimizer.step()

        # Inference
        engine.eval()
        for inputs, labels in test_loader:
            output, loss_ttt = engine(inputs)

    Reference:

        "Test-Time Training with Self-Supervision for Generalization under Distribution Shifts", Yu Sun, Xiaolong Wang, Zhuang Liu, John Miller, Alexei A. Efros, Moritz Hardt

        Paper link: `PDF <http://proceedings.mlr.press/v119/sun20b/sun20b.pdf>`_
    """

    def __init__(
        self,
        model: torch.nn.Module,
        features_layer_name: str,
        angle_head: torch.nn.Module = None,
        angle_criterion: torch.nn.Module = None,
        optimization_parameters: Dict[str, Any] = {},
    ) -> None:
        super().__init__()
        self.model = model
        self.angle_head = angle_head
        self.angle_criterion = angle_criterion if angle_criterion else torch.nn.CrossEntropyLoss()
        self.features_layer_name = features_layer_name
        self.optimization_parameters = optimization_parameters

        # Locate and store the reference to the target module
        self.target_module = None
        for name, module in model.named_modules():
            if name == features_layer_name:
                self.target_module = module
                break

        if self.target_module is None:
            raise ValueError(f"Module '{features_layer_name}' not found in the model.")

    def ttt_forward(self, inputs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the model.

        Args:
            inputs (torch.Tensor): Input tensor.

        Returns:
            Returns the current model prediction and rotation loss value.
        """

        # has to dynamically register a hook to get the features and then remove it
        # need this for deepcopying the engine, see https://github.com/pytorch/pytorch/pull/103001
        with self.__capture_hook() as features_hook:
            # Original forward pass, intact
            outputs = self.model(inputs)

            # See original code: https://github.com/yueatsprograms/ttt_cifar_release/blob/acac817fb7615850d19a8f8e79930240c9afe8b5/main.py#L69
            rotated_inputs, rotation_labels = self.__rotate_inputs(inputs)
            _ = self.model(rotated_inputs)
            features = features_hook.output

        # Build angle head if not already built
        if self.angle_head is None:
            self.angle_head = self.__build_angle_head(features)

        # move angle head to the same device as the features
        self.angle_head.to(features.device)
        angles = self.angle_head(features)

        # Compute rotation loss
        rotation_loss = self.angle_criterion(angles, rotation_labels)
        return outputs, rotation_loss

    # Follow this code (expand case): https://github.com/yueatsprograms/ttt_cifar_release/blob/acac817fb7615850d19a8f8e79930240c9afe8b5/utils/rotation.py#L27
    def __rotate_inputs(self, inputs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate the input images by 0, 90, 180, and 270 degrees."""
        device = next(self.model.parameters()).device
        rotated_image_90 = F.rotate(inputs, 90)
        rotated_image_180 = F.rotate(inputs, 180)
        rotated_image_270 = F.rotate(inputs, 270)
        batch_size = inputs.shape[0]
        inputs = torch.cat([inputs, rotated_image_90, rotated_image_180, rotated_image_270], dim=0)
        labels = [0] * batch_size + [1] * batch_size + [2] * batch_size + [3] * batch_size
        return inputs.to(device), torch.tensor(labels, dtype=torch.long).to(device)

    def __build_angle_head(self, features) -> torch.nn.Module:
        """Build the angle head."""
        device = next(self.model.parameters()).device

        # See original implementation: https://github.com/yueatsprograms/ttt_cifar_release/blob/acac817fb7615850d19a8f8e79930240c9afe8b5/utils/test_helpers.py#L33C10-L33C39
        if len(features.shape) == 2:
            return torch.nn.Sequential(
                torch.nn.Linear(features.shape[1], 16),
                torch.nn.ReLU(),
                torch.nn.Linear(16, 8),
                torch.nn.ReLU(),
                torch.nn.Linear(8, 4),
            ).to(device)

        # See original implementation: https://github.com/yueatsprograms/ttt_cifar_release/blob/acac817fb7615850d19a8f8e79930240c9afe8b5/models/SSHead.py#L29
        elif len(features.shape) == 4:
            return torch.nn.Sequential(
                torch.nn.Conv2d(features.shape[1], 16, 3),
                torch.nn.ReLU(),
                torch.nn.Conv2d(16, 4, 3),
                torch.nn.AdaptiveAvgPool2d((1, 1)),  # Global Average Pooling
                torch.nn.Flatten(),
            ).to(device)

        elif len(features.shape) == 5:  # For 3D inputs (batch, channels, depth, height, width)
            return torch.nn.Sequential(
                torch.nn.Conv3d(features.shape[1], 16, kernel_size=3),
                torch.nn.ReLU(),
                torch.nn.Conv3d(16, 4, kernel_size=3),
                torch.nn.AdaptiveAvgPool3d((1, 1, 1)),  # Global Average Pooling
                torch.nn.Flatten(),
            ).to(device)

        raise ValueError("Invalid input tensor shape.")

    @contextmanager
    def __capture_hook(self):
        """Context manager to capture features via a forward hook."""

        class OutputHook:
            def __init__(self):
                self.output = None

            def hook(self, module, input, output):
                self.output = output

        features_hook = OutputHook()
        hook_handle = self.target_module.register_forward_hook(features_hook.hook)

        try:
            yield features_hook
        finally:
            hook_handle.remove()

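A minimal sketch (not from the package) of passing a custom angle_head instead of the default shallow one; the toy Sequential model, the layer name "1", and the 64-dimensional feature size are illustrative assumptions. The head must map the captured features to 4 rotation logits.

    import torch
    from torch_ttt.engine.ttt_engine import TTTEngine

    model = torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(3 * 32 * 32, 64),   # layer "1": its 64-d output is captured by the hook
        torch.nn.ReLU(),
        torch.nn.Linear(64, 10),
    )
    engine = TTTEngine(model, features_layer_name="1", angle_head=torch.nn.Linear(64, 4))

    engine.train()
    outputs, rotation_loss = engine(torch.rand(8, 3, 32, 32))   # rotation loss over 4 x 8 rotated images
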
torch_ttt/engine/ttt_pp_engine.py
ADDED
@@ -0,0 +1,291 @@
import torch
from typing import Tuple, Optional, Callable, Dict, Any
from contextlib import contextmanager

from torchvision import transforms
from torch.utils.data import DataLoader
from torch_ttt.engine.base_engine import BaseEngine
from torch_ttt.engine_registry import EngineRegistry
from torch_ttt.loss.contrastive_loss import ContrastiveLoss
from torch_ttt.utils.augmentations import RandomResizedCrop

__all__ = ["TTTPPEngine"]


# TODO: finish this class
@EngineRegistry.register("ttt_pp")
class TTTPPEngine(BaseEngine):
    """**TTT++** approach: feature alignment-based + SimCLR loss.

    Args:
        model (torch.nn.Module): Model to be trained with TTT.
        features_layer_name (str): The name of the layer from which the features are extracted.
        contrastive_head (torch.nn.Module, optional): The head that is used for SimCLR's Loss.
        contrastive_criterion (torch.nn.Module, optional): The loss function used for SimCLR.
        contrastive_transform (callable): A transformation or a composition of transformations applied to the input images to generate augmented views for contrastive learning.
        scale_cov (float): The scale factor for the covariance loss.
        scale_mu (float): The scale factor for the mean loss.
        scale_c_cov (float): The scale factor for the contrastive covariance loss.
        scale_c_mu (float): The scale factor for the contrastive mean loss.
        optimization_parameters (dict): The optimization parameters for the engine.

    Warning:
        The module with the name :attr:`features_layer_name` should be present in the model.

    :Example:

    .. code-block:: python

        from torch_ttt.engine.ttt_pp_engine import TTTPPEngine

        model = MyModel()
        engine = TTTPPEngine(model, "fc1")
        optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)

        # Training
        engine.train()
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs, loss_ttt = engine(inputs)
            loss = criterion(outputs, labels) + alpha * loss_ttt
            loss.backward()
            optimizer.step()

        # Compute statistics for features alignment
        engine.compute_statistics(train_loader)

        # Inference
        engine.eval()
        for inputs, labels in test_loader:
            output, loss_ttt = engine(inputs)

    Reference:

        "TTT++: When Does Self-Supervised Test-Time Training Fail or Thrive?", Yuejiang Liu, Parth Kothari, Bastien van Delft, Baptiste Bellot-Gurlet, Taylor Mordan, Alexandre Alahi

        Paper link: `PDF <https://proceedings.neurips.cc/paper/2021/hash/b618c3210e934362ac261db280128c22-Abstract.html>`_
    """

    def __init__(
        self,
        model: torch.nn.Module,
        features_layer_name: str,
        contrastive_head: torch.nn.Module = None,
        contrastive_criterion: torch.nn.Module = ContrastiveLoss(),
        contrastive_transform: Optional[Callable] = None,
        scale_cov: float = 0.1,
        scale_mu: float = 0.1,
        scale_c_cov: float = 0.1,
        scale_c_mu: float = 0.1,
        optimization_parameters: Dict[str, Any] = {},
    ) -> None:
        super().__init__()
        self.model = model
        self.features_layer_name = features_layer_name
        self.contrastive_head = contrastive_head
        self.contrastive_criterion = (
            contrastive_criterion if contrastive_criterion else ContrastiveLoss()
        )
        self.scale_cov = scale_cov
        self.scale_mu = scale_mu
        self.scale_c_cov = scale_c_cov
        self.scale_c_mu = scale_c_mu
        self.contrastive_transform = contrastive_transform

        self.reference_cov = None
        self.reference_mean = None
        self.reference_c_cov = None
        self.reference_c_mean = None

        self.optimization_parameters = optimization_parameters

        # Locate and store the reference to the target module
        self.target_module = None
        for name, module in model.named_modules():
            if name == features_layer_name:
                self.target_module = module
                break

        if self.target_module is None:
            raise ValueError(f"Module '{features_layer_name}' not found in the model.")

        # Validate that the target module is a Linear layer
        if not isinstance(self.target_module, torch.nn.Linear):
            raise TypeError(
                f"Module '{features_layer_name}' is expected to be of type 'torch.nn.Linear', "
                f"but found type '{type(self.target_module).__name__}'."
            )

        if contrastive_transform is None:
            # default SimCLR augmentation
            self.contrastive_transform = transforms.Compose(
                [
                    RandomResizedCrop(scale=(0.2, 1.0)),
                    transforms.RandomGrayscale(p=0.2),
                    transforms.RandomApply([transforms.GaussianBlur(5)], p=0.3),
                    transforms.RandomHorizontalFlip(),
                ]
            )

    def __build_contrastive_head(self, features) -> torch.nn.Module:
        """Build the contrastive head."""
        device = next(self.model.parameters()).device
        if len(features.shape) == 2:
            return torch.nn.Sequential(
                torch.nn.Linear(features.shape[1], 16),
                torch.nn.ReLU(),
                torch.nn.Linear(16, 16),
                torch.nn.ReLU(),
                torch.nn.Linear(16, 16),
            ).to(device)

        raise ValueError("Features should be 2D tensor.")

    def ttt_forward(self, inputs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the model.

        Args:
            inputs (torch.Tensor): Input tensor.

        Returns:
            Returns the model prediction and TTT++ loss value.
        """

        # reset reference statistics during training
        if self.training:
            self.reference_cov = None
            self.reference_mean = None
            self.reference_c_cov = None
            self.reference_c_mean = None

        contrastive_inputs = torch.cat(
            [self.contrastive_transform(inputs), self.contrastive_transform(inputs)], dim=0
        )

        # extract features for contrastive loss
        with self.__capture_hook() as features_hook:
            _ = self.model(contrastive_inputs)
            features = features_hook.output

        # Build contrastive head if not already built
        if self.contrastive_head is None:
            self.contrastive_head = self.__build_contrastive_head(features)

        contrastive_features = self.contrastive_head(features)
        contrastive_features = contrastive_features.view(2, len(inputs), -1).transpose(0, 1)
        loss = self.contrastive_criterion(contrastive_features)

        # make inference for a final prediction
        with self.__capture_hook() as features_hook:
            outputs = self.model(inputs)
            features = features_hook.output

        # compute alignment loss only during test
        if not self.training:
            if (
                self.reference_cov is None
                or self.reference_mean is None
                or self.reference_c_cov is None
                or self.reference_c_mean is None
            ):
                raise ValueError(
                    "Reference statistics are not computed. Please call `compute_statistics` method."
                )

            # compute features alignment loss
            cov_ext = self.__covariance(features)
            mu_ext = features.mean(dim=0)

            d = self.reference_cov.shape[0]

            loss += self.scale_cov * (self.reference_cov - cov_ext).pow(2).sum() / (4.0 * d**2)
            loss += self.scale_mu * (self.reference_mean - mu_ext).pow(2).mean()

            # compute contrastive features alignment loss
            c_features = self.contrastive_head(features)

            cov_ext = self.__covariance(c_features)
            mu_ext = c_features.mean(dim=0)

            d = self.reference_c_cov.shape[0]
            loss += self.scale_c_cov * (self.reference_c_cov - cov_ext).pow(2).sum() / (4.0 * d**2)
            loss += self.scale_c_mu * (self.reference_c_mean - mu_ext).pow(2).mean()

        return outputs, loss

    @staticmethod
    def __covariance(features):
        """Legacy wrapper to maintain compatibility in the engine."""
        from torch_ttt.utils.math import compute_covariance

        return compute_covariance(features, dim=0)

    def compute_statistics(self, dataloader: DataLoader) -> None:
        """Extract and compute reference statistics for features and contrastive features.

        Args:
            dataloader (DataLoader): The dataloader used for extracting features. It can return tuples of tensors, with the first element expected to be the input tensor.

        Raises:
            ValueError: If the dataloader is empty or features have mismatched dimensions.
        """

        self.model.eval()

        feat_stack = []
        c_feat_stack = []

        with torch.no_grad():
            device = next(self.model.parameters()).device
            for sample in dataloader:
                if len(sample) < 1:
                    raise ValueError("Dataloader returned an empty batch.")

                inputs = sample[0].to(device)
                with self.__capture_hook() as features_hook:
                    _ = self.model(inputs)
                    feat = features_hook.output

                # Initialize contrastive head if not already initialized
                if self.contrastive_head is None:
                    self.contrastive_head = self.__build_contrastive_head(feat)

                # Compute contrastive features
                contrastive_feat = self.contrastive_head(feat)

                feat_stack.append(feat.cpu())
                c_feat_stack.append(contrastive_feat.cpu())

        # compute features statistics
        feat_all = torch.cat(feat_stack)
        feat_cov = self.__covariance(feat_all)
        feat_mean = feat_all.mean(dim=0)

        self.reference_cov = feat_cov.to(device)
        self.reference_mean = feat_mean.to(device)

        # compute contrastive features statistics
        feat_all = torch.cat(c_feat_stack)
        feat_cov = self.__covariance(feat_all)
        feat_mean = feat_all.mean(dim=0)

        self.reference_c_cov = feat_cov.to(device)
        self.reference_c_mean = feat_mean.to(device)

    @contextmanager
    def __capture_hook(self):
        """Context manager to capture features via a forward hook."""

        class OutputHook:
            def __init__(self):
                self.output = None

            def hook(self, module, input, output):
                self.output = output

        features_hook = OutputHook()
        hook_handle = self.target_module.register_forward_hook(features_hook.hook)

        try:
            yield features_hook
        finally:
            hook_handle.remove()

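A minimal end-to-end sketch (not from the package), assuming a toy fully connected model with the Linear layer named "1" as the feature layer; shapes and batch sizes are illustrative. Reference statistics must be computed before eval-mode adaptation.

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch_ttt.engine.ttt_pp_engine import TTTPPEngine

    model = torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(3 * 32 * 32, 32),   # layer "1": the Linear whose output is aligned
        torch.nn.ReLU(),
        torch.nn.Linear(32, 10),
    )
    engine = TTTPPEngine(model, features_layer_name="1")

    source_loader = DataLoader(TensorDataset(torch.rand(64, 3, 32, 32)), batch_size=16)
    engine.compute_statistics(source_loader)   # fills reference_{cov,mean} and contrastive counterparts

    engine.eval()
    outputs, loss_ttt = engine(torch.rand(16, 3, 32, 32))   # SimCLR + alignment losses drive the Adam steps
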
torch_ttt/engine_registry.py
ADDED
@@ -0,0 +1,38 @@
import importlib
import os
from torch_ttt.engine.base_engine import BaseEngine


class EngineRegistry:
    _registry = {}

    @classmethod
    def register(cls, name):
        def decorator(engine_class):
            if not issubclass(engine_class, BaseEngine):
                raise TypeError(f"Engine class '{name}' must inherit from BaseEngine.")
            if name in cls._registry:
                raise ValueError(f"Engine '{name}' is already registered.")

            cls._registry[name] = engine_class
            return engine_class

        return decorator

    @classmethod
    def get_engine(cls, name):
        if name not in cls._registry:
            raise ValueError(f"Engine '{name}' is not registered.")
        return cls._registry[name]


# Dynamically import all engines in the engine directory
def register_all_engines():
    engines_dir = os.path.dirname(__file__) + "/engine"
    for file in os.listdir(engines_dir):
        if file.endswith("_engine.py") and not file.startswith("__"):
            module_name = f"torch_ttt.engine.{file[:-3]}"
            importlib.import_module(module_name)


register_all_engines()

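A minimal sketch (not from the package) of resolving an engine by its registered name instead of importing the class directly; the toy model and the layer name "1" are illustrative assumptions.

    import torch
    from torch_ttt.engine_registry import EngineRegistry

    model = torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(3 * 32 * 32, 16),   # named "1" inside the Sequential
        torch.nn.ReLU(),
        torch.nn.Linear(16, 10),
    )
    EngineClass = EngineRegistry.get_engine("ttt")        # resolves to TTTEngine
    engine = EngineClass(model, features_layer_name="1")
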
torch_ttt/loss/contrastive_loss.py
ADDED
@@ -0,0 +1,94 @@
import torch
from torch_ttt.loss.base_loss import BaseLoss
from torch_ttt.loss_registry import LossRegistry


# TODO: make it more readable
@LossRegistry.register("contrastive")
class ContrastiveLoss(BaseLoss):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""

    def __init__(self, temperature=0.07, contrast_mode="all", base_temperature=0.07) -> None:
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        device = torch.device("cuda") if features.is_cuda else torch.device("cpu")

        if len(features.shape) < 3:
            raise ValueError(
                "`features` needs to be [bsz, n_views, ...], at least 3 dimensions are required"
            )
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError("Cannot define both `labels` and `mask`")
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError("Num of labels does not match num of features")
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == "one":
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == "all":
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError("Unknown mode: {}".format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T), self.temperature
        )
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0,
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        # loss
        loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss

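A minimal sketch (not from the package) of the unsupervised SimCLR case used by TTTPPEngine: features are arranged as [bsz, n_views, ...] with two augmented views per sample; the shapes are illustrative assumptions.

    import torch
    from torch_ttt.loss.contrastive_loss import ContrastiveLoss

    criterion = ContrastiveLoss(temperature=0.07)
    # 8 samples, 2 augmented views each, 16-d (L2-normalized) embeddings
    features = torch.nn.functional.normalize(torch.randn(8, 2, 16), dim=-1)
    loss = criterion(features)   # no labels / mask -> unsupervised SimCLR-style loss
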
torch_ttt/loss/entropy_loss.py
ADDED
@@ -0,0 +1,15 @@
import torch
import torch.nn.functional as F
from torch_ttt.loss.base_loss import BaseLoss
from torch_ttt.loss_registry import LossRegistry


@LossRegistry.register("entropy")
class EntropyLoss(BaseLoss):
    def __init__(self):
        super().__init__()

    def __call__(self, model, inputs):
        logits = model(inputs)
        probs = F.softmax(logits, dim=1)
        return -torch.sum(probs * torch.log(probs), dim=1).mean()

torch_ttt/loss/mean_loss.py
ADDED
@@ -0,0 +1,12 @@
from torch_ttt.loss.base_loss import BaseLoss
from torch_ttt.loss_registry import LossRegistry


@LossRegistry.register("mean")
class MeanLoss(BaseLoss):
    def __init__(self):
        super().__init__()

    def __call__(self, model, inputs):
        outputs = model(inputs)
        return outputs.mean()

torch_ttt/loss/ttt_loss.py
ADDED
@@ -0,0 +1,13 @@
from torch_ttt.loss.base_loss import BaseLoss
from torch_ttt.loss_registry import LossRegistry


@LossRegistry.register("ttt")
class TTTLoss(BaseLoss):
    def __init__(self):
        super().__init__()

    def __call__(self, model, inputs):
        # TODO [P2]: check that model is TTTEngine
        _, loss = model(inputs)
        return loss

torch_ttt/loss/weights_magnitude_loss.py
ADDED
@@ -0,0 +1,30 @@
import torch
from torch_ttt.loss.base_loss import BaseLoss
from torch_ttt.loss_registry import LossRegistry


@LossRegistry.register("weights_magnitude")
class WeightsMagnitudeLoss(BaseLoss):
    def __init__(self):
        super().__init__()
        self.quantile = 0.95

    def __call__(self, model, inputs):
        # Collect all trainable model weights
        all_weights = []
        for param in model.parameters():
            if param.requires_grad:  # Focus only on trainable parameters
                all_weights.append(param.view(-1))  # Flatten weights into a 1D tensor

        # Concatenate all weights into a single tensor
        all_weights = torch.cat(all_weights)

        # Compute the top `quantile` values
        quantile_value = torch.quantile(all_weights.abs(), self.quantile)
        top_quantile_weights = all_weights[all_weights.abs() >= quantile_value]

        # Compute the average of the top-quantile weights
        weight_loss = top_quantile_weights.mean()

        # Return the combined loss
        return weight_loss

torch_ttt/loss/zerot_loss.py
ADDED
@@ -0,0 +1,158 @@
import torch
from torch_ttt.loss.base_loss import BaseLoss
from torch_ttt.loss_registry import LossRegistry


@LossRegistry.register("zerot")
class ZeroTrainLoss(BaseLoss):
    def __init__(self):
        super().__init__()
        self.quantile = 0.95

    def __call__(self, model, inputs):
        N = len(list(model.named_parameters()))

        importance_dict = compute_weight_importance(model, inputs, N)

        # Calculate the top s% mean importance as the loss
        loss = top_s_percent_mean(importance_dict, self.quantile)

        # Return the combined loss
        return loss


def top_s_percent_mean(importance_dict, s):
    """
    Compute the mean of the top s% of the importance values.

    Args:
        importance_dict (dict): A dictionary mapping layer names to weight importance tensors.
        s (float): The percentage (0-100) of top importance values to consider.

    Returns:
        torch.Tensor: The mean of the top s% importance values, differentiable with respect to original weights.
    """
    # Concatenate all importance values into a single tensor
    all_importances = torch.cat([importance.view(-1) for importance in importance_dict.values()])

    # Determine the number of top elements to keep
    k = max(
        1, int(s / 100.0 * all_importances.numel())
    )  # Ensure at least one element is considered

    # Sort the importance values in descending order
    top_importances, _ = torch.topk(all_importances, k)

    # Compute the mean of the top s% importance values
    top_mean = top_importances.mean()

    return top_mean


def optimize_last_n_layers(m, x, N, s, M, lr=1e-3, optimizer=None, verbose=False):
    """
    Optimize the last N layers of the model for M steps using the top s% of weight importance as a TTT loss function.

    Args:
        m (torch.nn.Module): The model.
        x (torch.Tensor): The input batch.
        N (int): The number of layers from the end to optimize.
        s (float): The percentage (0-100) of top importance values to consider for the loss.
        M (int): The number of optimization steps.
        lr (float, optional): Learning rate for the optimizer. Default is 1e-3.

    Returns:
        Tuple[torch.Tensor, torch.optim.Optimizer]: The model's predictions for the input batch after optimization, and the optimizer used.
    """
    # Set the model to training mode
    m.train()

    # Collect parameters of the last N layers to optimize
    layers = list(m.named_parameters())[-2 * N :]
    params_to_optimize = [param for name, param in layers if "weight" in name]

    # Ensure that the parameters of the last N layers have requires_grad=True
    for param in params_to_optimize:
        param.requires_grad = True

    # Set up the optimizer for the last N layers' parameters
    if optimizer is None:
        optimizer = torch.optim.Adam(params_to_optimize, lr=lr)

    # Optimization loop for M steps
    for step in range(M):
        # Compute weight importance
        importance_dict = compute_weight_importance(m, x, N)

        # Calculate the top s% mean importance as the loss
        loss = top_s_percent_mean(importance_dict, s)

        # Zero the gradients
        optimizer.zero_grad()

        # Backward pass to compute gradients
        loss.backward()

        # Step the optimizer to update parameters
        optimizer.step()

        # Optionally, print loss for monitoring
        if verbose:
            print(f"Step {step+1}/{M}, Loss: {loss.item()}")

    # After optimization, run the model on the input x to get predictions
    m.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # Disable gradient computation for inference
        predictions = m(x)

    return predictions, optimizer


def compute_weight_importance(m, x, N):
    """
    Compute the weight importance for the last N layers of the model efficiently.

    Args:
        m (torch.nn.Module): The model.
        x (torch.Tensor): The input batch.
        N (int): The number of layers from the end to consider.

    Returns:
        dict: A dictionary mapping layer names to the weight importance tensor.
    """
    # Ensure the model is in evaluation mode
    m.eval()

    # Disable gradients for all parameters initially
    for param in m.parameters():
        param.requires_grad = False

    # Enable gradients only for the last N layers
    layers = list(m.named_parameters())[-2 * N :]  # Get the last N layers (both weights and biases)
    for name, param in layers:
        if "weight" in name:
            param.requires_grad = True

    # Forward pass to compute the output
    output = m(x)

    # Use a simple loss function (mean of outputs) to create a scalar output
    loss = output.mean()

    # Compute gradients with respect to the last N layers' parameters
    gradients = torch.autograd.grad(
        loss, [param for name, param in layers if "weight" in name], create_graph=True
    )

    # Calculate importance: |w * \nabla{w}| for the last N layers
    importance_dict = {}
    grad_idx = 0
    for name, param in layers:
        if "weight" in name:
            importance = (param * gradients[grad_idx]).abs()  # |w * \nabla{w}|
            # importance = param.abs()  # alternative: weight magnitude only
            # importance = gradients[grad_idx].abs()  # alternative: gradient magnitude only
            importance_dict[name] = importance  # Don't detach the importance
            grad_idx += 1

    return importance_dict

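A minimal sketch (not from the package) of the loss defined above on a toy model; it assumes BaseLoss subclasses can be instantiated directly, and the model and shapes are illustrative.

    import torch
    from torch_ttt.loss.zerot_loss import ZeroTrainLoss

    model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))
    criterion = ZeroTrainLoss()
    loss = criterion(model, torch.randn(4, 8))   # mean of the top-quantile |w * grad(w)| values
    loss.backward()                              # differentiable thanks to create_graph=True
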
torch_ttt/loss_registry.py
ADDED
@@ -0,0 +1,38 @@
import importlib
import os
from torch_ttt.loss.base_loss import BaseLoss


class LossRegistry:
    _registry = {}

    @classmethod
    def register(cls, name):
        def decorator(loss_class):
            if not issubclass(loss_class, BaseLoss):
                raise TypeError(f"Loss class '{name}' must inherit from BaseLoss.")
            if name in cls._registry:
                raise ValueError(f"Loss '{name}' is already registered.")

            cls._registry[name] = loss_class
            return loss_class

        return decorator

    @classmethod
    def get_loss(cls, name):
        if name not in cls._registry:
            raise ValueError(f"Loss '{name}' is not registered.")
        return cls._registry[name]


# Dynamically import all losses in the loss directory
def register_all_losses():
    losses_dir = os.path.dirname(__file__) + "/loss"
    for file in os.listdir(losses_dir):
        if file.endswith("_loss.py") and not file.startswith("__"):
            module_name = f"torch_ttt.loss.{file[:-3]}"
            importlib.import_module(module_name)


register_all_losses()

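A minimal sketch (not from the package) of resolving a loss by its registered name; it assumes the registered loss classes can be instantiated without arguments, and the toy model and shapes are illustrative.

    import torch
    from torch_ttt.loss_registry import LossRegistry

    entropy_loss = LossRegistry.get_loss("entropy")()   # resolves to EntropyLoss
    model = torch.nn.Linear(8, 3)                       # toy classifier
    loss = entropy_loss(model, torch.randn(4, 8))       # mean per-sample entropy of the softmax outputs
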
torch_ttt/utils/__init__.py
ADDED
File without changes

torch_ttt/utils/augmentations.py
ADDED
@@ -0,0 +1,25 @@
import random
from torchvision.transforms import functional as F


class RandomResizedCrop:
    def __init__(self, scale=(0.2, 1.0)):
        self.scale = scale

    def __call__(self, img):
        # Dynamically compute the crop size
        original_size = img.shape[-2:]  # H × W
        area = original_size[0] * original_size[1]
        target_area = random.uniform(self.scale[0], self.scale[1]) * area
        aspect_ratio = random.uniform(3 / 4, 4 / 3)

        h = int(round((target_area * aspect_ratio) ** 0.5))
        w = int(round((target_area / aspect_ratio) ** 0.5))

        if random.random() < 0.5:  # Randomly swap h and w
            h, w = w, h

        h = min(h, original_size[0])
        w = min(w, original_size[1])

        return F.resized_crop(img, top=0, left=0, height=h, width=w, size=original_size)

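A minimal sketch (not from the package) showing the transform applied to an image batch, as in TTTPPEngine's default SimCLR pipeline; the tensor shape is an illustrative assumption.

    import torch
    from torch_ttt.utils.augmentations import RandomResizedCrop

    crop = RandomResizedCrop(scale=(0.2, 1.0))
    images = torch.rand(4, 3, 32, 32)       # hypothetical image batch (N, C, H, W)
    augmented = crop(images)                # random crop, resized back to the original H x W
    assert augmented.shape == images.shape
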
torch_ttt/utils/math.py
ADDED
@@ -0,0 +1,29 @@
import torch


def compute_covariance(features: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Compute covariance matrix for given features along a specific dimension.

    Args:
        features (torch.Tensor): Input tensor of shape [N, D] or higher dimensions.
        dim (int): The dimension along which to compute covariance.

    Returns:
        torch.Tensor: Covariance matrix of shape [D, D].

    Raises:
        ValueError: If the input tensor has fewer than 2 dimensions.
    """
    if features.ndim < 2:
        raise ValueError("Input tensor must have at least 2 dimensions to compute covariance.")

    if features.size(dim) <= 1:
        raise ValueError(
            f"Cannot compute covariance with less than 2 samples along dimension {dim}."
        )

    n = features.shape[0]
    tmp = torch.ones((1, n), device=features.device) @ features
    cov = (features.t() @ features - (tmp.t() @ tmp) / n) / (n - 1)

    return cov

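A minimal sanity-check sketch (not from the package) comparing compute_covariance with torch.cov, which is available in recent PyTorch versions and expects variables as rows (hence the transpose); the tolerance is a loose assumption for float32.

    import torch
    from torch_ttt.utils.math import compute_covariance

    features = torch.randn(100, 5)        # 100 samples, 5 features
    cov = compute_covariance(features)    # unbiased sample covariance, shape [5, 5]
    assert torch.allclose(cov, torch.cov(features.T), atol=1e-4)
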
torch_ttt-0.0.1.dist-info/METADATA
ADDED
@@ -0,0 +1,16 @@
Metadata-Version: 2.4
Name: torch-ttt
Version: 0.0.1
License-File: LICENSE
Provides-Extra: dev
Requires-Dist: pytest-cov; extra == "dev"
Provides-Extra: docs
Requires-Dist: sphinxawesome-theme==6.0.0b1; extra == "docs"
Requires-Dist: sphinx-gallery==0.18.0; extra == "docs"
Requires-Dist: sphinxcontrib-bibtex; extra == "docs"
Requires-Dist: sphinxcontrib-googleanalytics; extra == "docs"
Requires-Dist: matplotlib; extra == "docs"
Requires-Dist: tqdm; extra == "docs"
Provides-Extra: all
Requires-Dist: torch_ttt[dev,docs]; extra == "all"
Dynamic: license-file

torch_ttt-0.0.1.dist-info/RECORD
ADDED
@@ -0,0 +1,26 @@
torch_ttt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
torch_ttt/engine_registry.py,sha256=FzjHxBDCF5y8vkToKDyrLcvVNAtw2LUu4Wl-EViNHw4,1144
torch_ttt/loss_registry.py,sha256=PGkrOFzeTKzXFoDLHUf-rH8aBAtyxrh4xzyowhVQtv4,1108
torch_ttt/engine/__init__.py,sha256=GlxEg_jHxKmyy-uLMIrZfbsPeF9hPDclV68fTIio3zU,183
torch_ttt/engine/actmad_engine.py,sha256=tiQLBO2SzYljUzBxdHywzrN188AsfPAqV5X6CQy7dcs,6281
torch_ttt/engine/base_engine.py,sha256=sQZq_NDzNlQnBERlRKD43Wo8Z7FJnblwHn0wANe6Q0o,1527
torch_ttt/engine/masked_ttt_engine.py,sha256=X-SApdbzzcViOk7lijPca3IMV1wP-8QDGWqWXow5RHo,404
torch_ttt/engine/tent_engine.py,sha256=w5JIL5eeQMO_NFlzEV92xL0MwOWExN2-8t7GdRP63Jk,1686
torch_ttt/engine/ttt_engine.py,sha256=pzdP4vgxAz-MQwP3KJUF0xFcUpm_hf0I_YzWYx-6q60,7675
torch_ttt/engine/ttt_pp_engine.py,sha256=W_6vP7Oc1eeVUBAix2qT3-Hz0JcV11uOdQQsfSMLJC8,11014
torch_ttt/loss/__init__.py,sha256=rV6RxTWgU_BtdUX0d7UXhtlTexhIiqxVNm4eDO9HUOU,51
torch_ttt/loss/base_loss.py,sha256=O7UhmrK8NGVI8Dd8zCgh_MKh5HNAkYnyq9_aAesbbpg,163
torch_ttt/loss/contrastive_loss.py,sha256=uxlDfSqrJEmg8v8TqJgK3LhjcRsF9euGIEGG2sYS1iY,3753
torch_ttt/loss/entropy_loss.py,sha256=bM9LetJXic5FSs3kJ6prjXkAkc6xEBaJezfCKx4-bPg,434
torch_ttt/loss/mean_loss.py,sha256=iZZ8DvU3-2nT3D0kgXBaSA7oO74rXrp9jrS-QbLUOe0,307
torch_ttt/loss/ttt_loss.py,sha256=B83lwl5vOZKn_iKL7OznD3BOR9bf3X4OfCz2jzKXF3Q,346
torch_ttt/loss/weights_magnitude_loss.py,sha256=MpyW7fqvoFzwT-nQMZN_uisf1n956knzpZWd8yXmn7s,1067
torch_ttt/loss/zerot_loss.py,sha256=k46u3RrDdF44-f_FpIyAS-r-ZsAHOU_qKp733lDn_VI,5388
torch_ttt/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
torch_ttt/utils/augmentations.py,sha256=y-thszixcgGvGFtCzMeG5pWvNVr1D_37kqyNJfd9ApE,831
torch_ttt/utils/math.py,sha256=mmgsyCOWGXP21gqQyrHhj-axmI4GCQ1AxFCPvvrO5Uk,952
torch_ttt-0.0.1.dist-info/licenses/LICENSE,sha256=ppPMBj3EeJB4NHrRgxWoRSoVIxx3apK0BFsIB2-jcco,1072
torch_ttt-0.0.1.dist-info/METADATA,sha256=kWAuJPGHukr9eKsvgeDvCIZAZkIgIG2c1OiNkr-gnBQ,562
torch_ttt-0.0.1.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
torch_ttt-0.0.1.dist-info/top_level.txt,sha256=WP1U2F1ULE-SjEdozpFfUhCR2ZNqIi8lKvUoeQ5G5Zw,10
torch_ttt-0.0.1.dist-info/RECORD,,

torch_ttt-0.0.1.dist-info/licenses/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Nikita Durasov

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

torch_ttt-0.0.1.dist-info/top_level.txt
ADDED
@@ -0,0 +1 @@
torch_ttt