torch_ttt-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
torch_ttt/__init__.py ADDED
File without changes
torch_ttt/engine/__init__.py ADDED
@@ -0,0 +1,5 @@
+ # ruff: noqa: F401
+ # from .base_engine import BaseEngine
+ # from .ttt_engine import TTTEngine
+ # from .ttt_pp_engine import TTTPPEngine
+ # from .masked_ttt_engine import MaskedTTTEngine
torch_ttt/engine/actmad_engine.py ADDED
@@ -0,0 +1,173 @@
+ import torch
+ from typing import List, Dict, Any, Tuple, Union
+ from contextlib import contextmanager
+
+ from torch.utils.data import DataLoader
+ from torch_ttt.engine.base_engine import BaseEngine
+ from torch_ttt.engine_registry import EngineRegistry
+
+ __all__ = ["ActMADEngine"]
+
+
+ # TODO: add cuda support
+ @EngineRegistry.register("actmad")
+ class ActMADEngine(BaseEngine):
+     """**ActMAD** approach: multi-level pixel-wise feature alignment.
+
+     Args:
+         model (torch.nn.Module): Model to be trained with TTT.
+         features_layer_names (List[str] | str): List of layer names to be used for feature alignment.
+         optimization_parameters (dict): The optimization parameters for the engine.
+
+     :Example:
+
+     .. code-block:: python
+
+         from torch_ttt.engine.actmad_engine import ActMADEngine
+
+         model = MyModel()
+         engine = ActMADEngine(model, ["fc1", "fc2"])
+         optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)
+
+         # Training
+         engine.train()
+         for inputs, labels in train_loader:
+             optimizer.zero_grad()
+             outputs, loss_ttt = engine(inputs)
+             loss = criterion(outputs, labels) + alpha * loss_ttt
+             loss.backward()
+             optimizer.step()
+
+         # Compute statistics for feature alignment
+         engine.compute_statistics(train_loader)
+
+         # Inference
+         engine.eval()
+         for inputs, labels in test_loader:
+             output, loss_ttt = engine(inputs)
+
+     Reference:
+
+         "ActMAD: Activation Matching to Align Distributions for Test-Time Training", M. Jehanzeb Mirza, Pol Jane Soneira, Wei Lin, Mateusz Kozinski, Horst Possegger, Horst Bischof
+
+         Paper link: `PDF <https://arxiv.org/abs/2211.12870>`_
+     """
+
+     def __init__(
+         self,
+         model: torch.nn.Module,
+         features_layer_names: Union[List[str], str],
+         optimization_parameters: Dict[str, Any] = {},
+     ):
+         super().__init__()
+         self.model = model
+         self.features_layer_names = features_layer_names
+         self.optimization_parameters = optimization_parameters
+
+         if isinstance(features_layer_names, str):
+             self.features_layer_names = [features_layer_names]
+
+         # TODO: rewrite this
+         self.target_modules = []
+         for layer_name in self.features_layer_names:
+             layer_exists = False
+             for name, module in model.named_modules():
+                 if name == layer_name:
+                     layer_exists = True
+                     self.target_modules.append(module)
+                     break
+             if not layer_exists:
+                 raise ValueError(f"Layer {layer_name} does not exist in the model.")
+
+         self.reference_mean = None
+         self.reference_var = None
+
+     def ttt_forward(self, inputs) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Forward pass of the model.
+
+         Args:
+             inputs (torch.Tensor): Input tensor.
+
+         Returns:
+             Returns the current model prediction and the feature alignment loss value.
+         """
+         with self.__capture_hook() as features_hooks:
+             outputs = self.model(inputs)
+             features = [hook.output for hook in features_hooks]
+
+         # don't compute the alignment loss during training
+         if self.training:
+             return outputs, 0
+
+         if self.reference_var is None or self.reference_mean is None:
+             raise ValueError(
+                 "Reference statistics are not computed. Please call `compute_statistics` method."
+             )
+
+         l1_loss = torch.nn.L1Loss(reduction="mean")
+         features_means = [torch.mean(feature, dim=0) for feature in features]
+         features_vars = [torch.var(feature, dim=0) for feature in features]
+
+         loss = 0
+         for i in range(len(self.target_modules)):
+             loss += l1_loss(features_means[i], self.reference_mean[i])
+             loss += l1_loss(features_vars[i], self.reference_var[i])
+
+         return outputs, loss
+
+     def compute_statistics(self, dataloader: DataLoader) -> None:
+         """Extract and compute reference statistics for features.
+
+         Args:
+             dataloader (DataLoader): The dataloader used for extracting features. It can return tuples of tensors, with the first element expected to be the input tensor.
+
+         Raises:
+             ValueError: If the dataloader is empty or features have mismatched dimensions.
+         """
+
+         self.model.eval()
+         feat_stack = [[] for _ in self.target_modules]
+
+         # TODO: compute variance in a more memory-efficient way
+         with torch.no_grad():
+             device = next(self.model.parameters()).device
+             for sample in dataloader:
+                 if len(sample) < 1:
+                     raise ValueError("Dataloader returned an empty batch.")
+
+                 inputs = sample[0].to(device)
+                 with self.__capture_hook() as features_hooks:
+                     _ = self.model(inputs)
+                     features = [hook.output.cpu() for hook in features_hooks]
+
+                 for i, feature in enumerate(features):
+                     feat_stack[i].append(feature)
+
+         # Compute mean and variance
+         self.reference_mean = [torch.mean(torch.cat(feat), dim=0).to(device) for feat in feat_stack]
+         self.reference_var = [torch.var(torch.cat(feat), dim=0).to(device) for feat in feat_stack]
+
+     @contextmanager
+     def __capture_hook(self):
+         """Context manager to capture features via a forward hook."""
+
+         class OutputHook:
+             def __init__(self):
+                 self.output = None
+
+             def hook(self, module, input, output):
+                 self.output = output
+
+         hook_handles = []
+         features_hooks = []
+         for module in self.target_modules:
+             hook = OutputHook()
+             features_hooks.append(hook)
+             hook_handle = module.register_forward_hook(hook.hook)
+             hook_handles.append(hook_handle)
+
+         try:
+             yield features_hooks
+         finally:
+             for hook in hook_handles:
+                 hook.remove()
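
The alignment objective above is just an L1 match between test-batch feature means/variances and the reference statistics from `compute_statistics`. A minimal end-to-end sketch, assuming a toy two-layer model (the `TinyNet`, `fc1`, and `fc2` names are illustrative, not part of the package):

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
    from torch_ttt.engine.actmad_engine import ActMADEngine

    class TinyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(8, 16)
            self.fc2 = nn.Linear(16, 4)

        def forward(self, x):
            return self.fc2(torch.relu(self.fc1(x)))

    engine = ActMADEngine(TinyNet(), ["fc1", "fc2"])

    # Reference statistics come from clean ("source") data
    source_loader = DataLoader(TensorDataset(torch.randn(64, 8)), batch_size=16)
    engine.compute_statistics(source_loader)

    # At test time the engine adapts and returns (prediction, alignment loss)
    engine.eval()
    outputs, loss_align = engine(torch.randn(16, 8))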
torch_ttt/engine/base_engine.py ADDED
@@ -0,0 +1,49 @@
+ from abc import ABC, abstractmethod
+ from typing import Tuple
+ import torch.nn as nn
+ import torch
+ from copy import deepcopy
+
+ class BaseEngine(nn.Module, ABC):
+
+     def __init__(self):
+         nn.Module.__init__(self)
+         self.optimization_parameters = {}
+
+     @abstractmethod
+     def ttt_forward(self, inputs) -> Tuple[torch.Tensor, torch.Tensor]:
+         pass
+
+     def forward(self, inputs):
+
+         if self.training:
+             return self.ttt_forward(inputs)
+
+         # TODO: the optimization pipeline should be more flexible and
+         # user-defined; this needs some special structure
+         optimization_parameters = self.optimization_parameters or {}
+
+         optimizer_name = optimization_parameters.get("optimizer_name", "adam")
+         num_steps = optimization_parameters.get("num_steps", 3)
+         lr = optimization_parameters.get("lr", 1e-2)
+         copy_model = optimization_parameters.get("copy_model", True)
+
+         running_engine = deepcopy(self) if copy_model else self
+
+         parameters = filter(lambda p: p.requires_grad, running_engine.model.parameters())
+         if optimizer_name == "adam":
+             optimizer = torch.optim.Adam(parameters, lr=lr)
+         else:
+             raise ValueError(f"Unknown optimizer: '{optimizer_name}'.")
+
+         loss = 0  # default value
+         for _ in range(num_steps):
+             optimizer.zero_grad()
+             _, loss = running_engine.ttt_forward(inputs)
+             loss.backward()
+             optimizer.step()
+
+         with torch.no_grad():
+             final_outputs = running_engine.model(inputs)
+
+         return final_outputs, loss
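
During evaluation, `forward` runs a small test-time optimization loop before producing the final prediction. The keys it reads from `optimization_parameters`, with the defaults shown in the code above:

.. code-block:: python

    optimization_parameters = {
        "optimizer_name": "adam",  # only "adam" is handled above
        "num_steps": 3,            # gradient steps per test batch
        "lr": 1e-2,                # inner-loop learning rate
        "copy_model": True,        # adapt a deepcopy, keep the original weights intact
    }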
torch_ttt/engine/masked_ttt_engine.py ADDED
@@ -0,0 +1,16 @@
+ from torch_ttt.engine.base_engine import BaseEngine
+ from torch_ttt.engine_registry import EngineRegistry
+
+ __all__ = ["MaskedTTTEngine"]
+
+
+ # TODO: add cuda support
+ @EngineRegistry.register("masked_ttt")
+ class MaskedTTTEngine(BaseEngine):
+     """Masked autoencoders-based **test-time training** approach (not implemented yet)."""
+
+     def __init__(self):
+         super().__init__()
+
+     def ttt_forward(self, inputs):
+         pass
torch_ttt/engine/tent_engine.py ADDED
@@ -0,0 +1,51 @@
+ import torch
+ from typing import Dict, Any, Tuple
+ from torch_ttt.engine.base_engine import BaseEngine
+ from torch_ttt.engine_registry import EngineRegistry
+
+ __all__ = ["TentEngine"]
+
+ @EngineRegistry.register("tent")
+ class TentEngine(BaseEngine):
+     """**TENT**: Fully test-time adaptation by entropy minimization.
+
+     Args:
+         model (torch.nn.Module): The model to adapt.
+         optimization_parameters (dict): Hyperparameters for adaptation.
+
+     Reference:
+
+         "Tent: Fully Test-Time Adaptation by Entropy Minimization", Dequan Wang, Evan Shelhamer, Shaoteng Liu, Bruno Olshausen, Trevor Darrell
+
+         Paper link: `PDF <https://arxiv.org/abs/2006.10726>`_
+     """
+
+     def __init__(
+         self,
+         model: torch.nn.Module,
+         optimization_parameters: Dict[str, Any] = {},
+     ):
+         super().__init__()
+         self.model = model
+         self.optimization_parameters = optimization_parameters
+
+         # Tent adapts only affine parameters in BatchNorm
+         self.model.train()
+         self._configure_bn()
+
+     def _configure_bn(self):
+         for module in self.model.modules():
+             if isinstance(module, torch.nn.BatchNorm2d):
+                 module.requires_grad_(True)
+                 module.track_running_stats = False
+             else:
+                 for param in module.parameters(recurse=False):
+                     param.requires_grad = False
+
+     def ttt_forward(self, inputs) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Forward pass and entropy loss computation."""
+         outputs = self.model(inputs)
+         probs = torch.nn.functional.softmax(outputs, dim=1)
+         log_probs = torch.nn.functional.log_softmax(outputs, dim=1)
+         entropy = -torch.sum(probs * log_probs, dim=1).mean()
+         return outputs, entropy
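
A hedged usage sketch: since Tent only updates BatchNorm affine parameters, any BatchNorm-based classifier works; the tiny network below is illustrative only, not part of the package:

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch_ttt.engine.tent_engine import TentEngine

    model = nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 10),
    )
    engine = TentEngine(model)

    # After _configure_bn, only the BatchNorm affine terms stay trainable
    trainable = [n for n, p in engine.model.named_parameters() if p.requires_grad]
    # -> ['1.weight', '1.bias']

    engine.eval()  # adaptation happens inside the eval-mode forward
    outputs, entropy = engine(torch.randn(4, 3, 32, 32))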
torch_ttt/engine/ttt_engine.py ADDED
@@ -0,0 +1,184 @@
+ import torch
+ from contextlib import contextmanager
+ from typing import Tuple, Dict, Any
+ from torchvision.transforms import functional as F
+ from torch_ttt.engine.base_engine import BaseEngine
+ from torch_ttt.engine_registry import EngineRegistry
+
+ __all__ = ["TTTEngine"]
+
+
+ @EngineRegistry.register("ttt")
+ class TTTEngine(BaseEngine):
+     r"""Original image rotation-based **test-time training** approach.
+
+     Args:
+         model (torch.nn.Module): Model to be trained with TTT.
+         features_layer_name (str): The name of the layer from which the features are extracted.
+         angle_head (torch.nn.Module, optional): The head that predicts the rotation angles.
+         angle_criterion (torch.nn.Module, optional): The loss function for the rotation angles.
+         optimization_parameters (dict): The optimization parameters for the engine.
+
+     Warning:
+         The module with the name :attr:`features_layer_name` should be present in the model.
+
+     Note:
+         :attr:`angle_head` and :attr:`angle_criterion` are optional arguments and can be user-defined. If not provided, the default shallow head and the :class:`torch.nn.CrossEntropyLoss` loss function are used.
+
+     Note:
+         The original `TTT <https://github.com/yueatsprograms/ttt_cifar_release/blob/acac817fb7615850d19a8f8e79930240c9afe8b5/utils/rotation.py#L27>`_ implementation uses a four-class classification task, corresponding to image rotations of 0°, 90°, 180°, and 270°.
+
+     :Example:
+
+     .. code-block:: python
+
+         from torch_ttt.engine.ttt_engine import TTTEngine
+
+         model = MyModel()
+         engine = TTTEngine(model, "fc1")
+         optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)
+
+         # Training
+         engine.train()
+         for inputs, labels in train_loader:
+             optimizer.zero_grad()
+             outputs, loss_ttt = engine(inputs)
+             loss = criterion(outputs, labels) + alpha * loss_ttt
+             loss.backward()
+             optimizer.step()
+
+         # Inference
+         engine.eval()
+         for inputs, labels in test_loader:
+             output, loss_ttt = engine(inputs)
+
+     Reference:
+
+         "Test-Time Training with Self-Supervision for Generalization under Distribution Shifts", Yu Sun, Xiaolong Wang, Zhuang Liu, John Miller, Alexei A. Efros, Moritz Hardt
+
+         Paper link: `PDF <http://proceedings.mlr.press/v119/sun20b/sun20b.pdf>`_
+     """
+
+     def __init__(
+         self,
+         model: torch.nn.Module,
+         features_layer_name: str,
+         angle_head: torch.nn.Module = None,
+         angle_criterion: torch.nn.Module = None,
+         optimization_parameters: Dict[str, Any] = {},
+     ) -> None:
+         super().__init__()
+         self.model = model
+         self.angle_head = angle_head
+         self.angle_criterion = angle_criterion if angle_criterion else torch.nn.CrossEntropyLoss()
+         self.features_layer_name = features_layer_name
+         self.optimization_parameters = optimization_parameters
+
+         # Locate and store the reference to the target module
+         self.target_module = None
+         for name, module in model.named_modules():
+             if name == features_layer_name:
+                 self.target_module = module
+                 break
+
+         if self.target_module is None:
+             raise ValueError(f"Module '{features_layer_name}' not found in the model.")
+
+     def ttt_forward(self, inputs) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Forward pass of the model.
+
+         Args:
+             inputs (torch.Tensor): Input tensor.
+
+         Returns:
+             Returns the current model prediction and rotation loss value.
+         """
+
+         # has to dynamically register a hook to get the features and then remove it;
+         # need this for deepcopying the engine, see https://github.com/pytorch/pytorch/pull/103001
+         with self.__capture_hook() as features_hook:
+             # Original forward pass, intact
+             outputs = self.model(inputs)
+
+             # See original code: https://github.com/yueatsprograms/ttt_cifar_release/blob/acac817fb7615850d19a8f8e79930240c9afe8b5/main.py#L69
+             rotated_inputs, rotation_labels = self.__rotate_inputs(inputs)
+             _ = self.model(rotated_inputs)
+             features = features_hook.output
+
+         # Build angle head if not already built
+         if self.angle_head is None:
+             self.angle_head = self.__build_angle_head(features)
+
+         # move angle head to the same device as the features
+         self.angle_head.to(features.device)
+         angles = self.angle_head(features)
+
+         # Compute rotation loss
+         rotation_loss = self.angle_criterion(angles, rotation_labels)
+         return outputs, rotation_loss
+
+     # Follow this code (expand case): https://github.com/yueatsprograms/ttt_cifar_release/blob/acac817fb7615850d19a8f8e79930240c9afe8b5/utils/rotation.py#L27
+     def __rotate_inputs(self, inputs) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Rotate the input images by 0, 90, 180, and 270 degrees."""
+         device = next(self.model.parameters()).device
+         rotated_image_90 = F.rotate(inputs, 90)
+         rotated_image_180 = F.rotate(inputs, 180)
+         rotated_image_270 = F.rotate(inputs, 270)
+         batch_size = inputs.shape[0]
+         inputs = torch.cat([inputs, rotated_image_90, rotated_image_180, rotated_image_270], dim=0)
+         labels = [0] * batch_size + [1] * batch_size + [2] * batch_size + [3] * batch_size
+         return inputs.to(device), torch.tensor(labels, dtype=torch.long).to(device)
+
+     def __build_angle_head(self, features) -> torch.nn.Module:
+         """Build the angle head."""
+         device = next(self.model.parameters()).device
+
+         # See original implementation: https://github.com/yueatsprograms/ttt_cifar_release/blob/acac817fb7615850d19a8f8e79930240c9afe8b5/utils/test_helpers.py#L33C10-L33C39
+         if len(features.shape) == 2:
+             return torch.nn.Sequential(
+                 torch.nn.Linear(features.shape[1], 16),
+                 torch.nn.ReLU(),
+                 torch.nn.Linear(16, 8),
+                 torch.nn.ReLU(),
+                 torch.nn.Linear(8, 4),
+             ).to(device)
+
+         # See original implementation: https://github.com/yueatsprograms/ttt_cifar_release/blob/acac817fb7615850d19a8f8e79930240c9afe8b5/models/SSHead.py#L29
+         elif len(features.shape) == 4:
+             return torch.nn.Sequential(
+                 torch.nn.Conv2d(features.shape[1], 16, 3),
+                 torch.nn.ReLU(),
+                 torch.nn.Conv2d(16, 4, 3),
+                 torch.nn.AdaptiveAvgPool2d((1, 1)),  # Global Average Pooling
+                 torch.nn.Flatten(),
+             ).to(device)
+
+         elif len(features.shape) == 5:  # For 3D inputs (batch, channels, depth, height, width)
+             return torch.nn.Sequential(
+                 torch.nn.Conv3d(features.shape[1], 16, kernel_size=3),
+                 torch.nn.ReLU(),
+                 torch.nn.Conv3d(16, 4, kernel_size=3),
+                 torch.nn.AdaptiveAvgPool3d((1, 1, 1)),  # Global Average Pooling
+                 torch.nn.Flatten(),
+             ).to(device)
+
+         raise ValueError("Invalid input tensor shape.")
+
+     @contextmanager
+     def __capture_hook(self):
+         """Context manager to capture features via a forward hook."""
+
+         class OutputHook:
+             def __init__(self):
+                 self.output = None
+
+             def hook(self, module, input, output):
+                 self.output = output
+
+         features_hook = OutputHook()
+         hook_handle = self.target_module.register_forward_hook(features_hook.hook)
+
+         try:
+             yield features_hook
+         finally:
+             hook_handle.remove()
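
For reference, the rotation pretext batch built by `__rotate_inputs` can be reproduced standalone: the batch grows fourfold and labels 0-3 encode the 0°, 90°, 180°, and 270° rotations:

.. code-block:: python

    import torch
    from torchvision.transforms import functional as F

    x = torch.randn(2, 3, 32, 32)  # a batch of two images
    rotated = torch.cat([x, F.rotate(x, 90), F.rotate(x, 180), F.rotate(x, 270)], dim=0)
    labels = torch.tensor([0] * 2 + [1] * 2 + [2] * 2 + [3] * 2)
    assert rotated.shape[0] == 4 * x.shape[0]  # one rotation class per copy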
torch_ttt/engine/ttt_pp_engine.py ADDED
@@ -0,0 +1,292 @@
+ import torch
+ from typing import Tuple, Optional, Callable, Dict, Any
+ from contextlib import contextmanager
+
+ from torchvision import transforms
+ from torch.utils.data import DataLoader
+ from torch_ttt.engine.base_engine import BaseEngine
+ from torch_ttt.engine_registry import EngineRegistry
+ from torch_ttt.loss.contrastive_loss import ContrastiveLoss
+ from torch_ttt.utils.augmentations import RandomResizedCrop
+
+ __all__ = ["TTTPPEngine"]
+
+
+ # TODO: finish this class
+ @EngineRegistry.register("ttt_pp")
+ class TTTPPEngine(BaseEngine):
+     """**TTT++** approach: feature alignment-based + SimCLR loss.
+
+     Args:
+         model (torch.nn.Module): Model to be trained with TTT.
+         features_layer_name (str): The name of the layer from which the features are extracted.
+         contrastive_head (torch.nn.Module, optional): The head that is used for SimCLR's loss.
+         contrastive_criterion (torch.nn.Module, optional): The loss function used for SimCLR.
+         contrastive_transform (callable): A transformation or a composition of transformations applied to the input images to generate augmented views for contrastive learning.
+         scale_cov (float): The scale factor for the covariance loss.
+         scale_mu (float): The scale factor for the mean loss.
+         scale_c_cov (float): The scale factor for the contrastive covariance loss.
+         scale_c_mu (float): The scale factor for the contrastive mean loss.
+         optimization_parameters (dict): The optimization parameters for the engine.
+
+     Warning:
+         The module with the name :attr:`features_layer_name` should be present in the model.
+
+     :Example:
+
+     .. code-block:: python
+
+         from torch_ttt.engine.ttt_pp_engine import TTTPPEngine
+
+         model = MyModel()
+         engine = TTTPPEngine(model, "fc1")
+         optimizer = torch.optim.Adam(engine.parameters(), lr=1e-4)
+
+         # Training
+         engine.train()
+         for inputs, labels in train_loader:
+             optimizer.zero_grad()
+             outputs, loss_ttt = engine(inputs)
+             loss = criterion(outputs, labels) + alpha * loss_ttt
+             loss.backward()
+             optimizer.step()
+
+         # Compute statistics for feature alignment
+         engine.compute_statistics(train_loader)
+
+         # Inference
+         engine.eval()
+         for inputs, labels in test_loader:
+             output, loss_ttt = engine(inputs)
+
+     Reference:
+
+         "TTT++: When Does Self-Supervised Test-Time Training Fail or Thrive?", Yuejiang Liu, Parth Kothari, Bastien van Delft, Baptiste Bellot-Gurlet, Taylor Mordan, Alexandre Alahi
+
+         Paper link: `PDF <https://proceedings.neurips.cc/paper/2021/hash/b618c3210e934362ac261db280128c22-Abstract.html>`_
+     """
+
+     def __init__(
+         self,
+         model: torch.nn.Module,
+         features_layer_name: str,
+         contrastive_head: torch.nn.Module = None,
+         contrastive_criterion: Optional[torch.nn.Module] = None,
+         contrastive_transform: Optional[Callable] = None,
+         scale_cov: float = 0.1,
+         scale_mu: float = 0.1,
+         scale_c_cov: float = 0.1,
+         scale_c_mu: float = 0.1,
+         optimization_parameters: Dict[str, Any] = {},
+     ) -> None:
+         super().__init__()
+         self.model = model
+         self.features_layer_name = features_layer_name
+         self.contrastive_head = contrastive_head
+         self.contrastive_criterion = (
+             contrastive_criterion if contrastive_criterion else ContrastiveLoss()
+         )
+         self.scale_cov = scale_cov
+         self.scale_mu = scale_mu
+         self.scale_c_cov = scale_c_cov
+         self.scale_c_mu = scale_c_mu
+         self.contrastive_transform = contrastive_transform
+
+         self.reference_cov = None
+         self.reference_mean = None
+         self.reference_c_cov = None
+         self.reference_c_mean = None
+
+         self.optimization_parameters = optimization_parameters
+
+         # Locate and store the reference to the target module
+         self.target_module = None
+         for name, module in model.named_modules():
+             if name == features_layer_name:
+                 self.target_module = module
+                 break
+
+         if self.target_module is None:
+             raise ValueError(f"Module '{features_layer_name}' not found in the model.")
+
+         # Validate that the target module is a Linear layer
+         if not isinstance(self.target_module, torch.nn.Linear):
+             raise TypeError(
+                 f"Module '{features_layer_name}' is expected to be of type 'torch.nn.Linear', "
+                 f"but found type '{type(self.target_module).__name__}'."
+             )
+
+         if contrastive_transform is None:
+             # default SimCLR augmentation
+             self.contrastive_transform = transforms.Compose(
+                 [
+                     RandomResizedCrop(scale=(0.2, 1.0)),
+                     transforms.RandomGrayscale(p=0.2),
+                     transforms.RandomApply([transforms.GaussianBlur(5)], p=0.3),
+                     transforms.RandomHorizontalFlip(),
+                 ]
+             )
+
+     def __build_contrastive_head(self, features) -> torch.nn.Module:
+         """Build the contrastive head."""
+         device = next(self.model.parameters()).device
+         if len(features.shape) == 2:
+             return torch.nn.Sequential(
+                 torch.nn.Linear(features.shape[1], 16),
+                 torch.nn.ReLU(),
+                 torch.nn.Linear(16, 16),
+                 torch.nn.ReLU(),
+                 torch.nn.Linear(16, 16),
+             ).to(device)
+
+         raise ValueError("Features should be a 2D tensor.")
+
+     def ttt_forward(self, inputs) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Forward pass of the model.
+
+         Args:
+             inputs (torch.Tensor): Input tensor.
+
+         Returns:
+             Returns the model prediction and TTT++ loss value.
+         """
+
+         # reset reference statistics during training
+         if self.training:
+             self.reference_cov = None
+             self.reference_mean = None
+             self.reference_c_cov = None
+             self.reference_c_mean = None
+
+         contrastive_inputs = torch.cat(
+             [self.contrastive_transform(inputs), self.contrastive_transform(inputs)], dim=0
+         )
+
+         # extract features for contrastive loss
+         with self.__capture_hook() as features_hook:
+             _ = self.model(contrastive_inputs)
+             features = features_hook.output
+
+         # Build the contrastive head if not already built
+         if self.contrastive_head is None:
+             self.contrastive_head = self.__build_contrastive_head(features)
+
+         contrastive_features = self.contrastive_head(features)
+         contrastive_features = contrastive_features.view(2, len(inputs), -1).transpose(0, 1)
+         loss = self.contrastive_criterion(contrastive_features)
+
+         # make inference for a final prediction
+         with self.__capture_hook() as features_hook:
+             outputs = self.model(inputs)
+             features = features_hook.output
+
+         # compute alignment loss only during test
+         if not self.training:
+             if (
+                 self.reference_cov is None
+                 or self.reference_mean is None
+                 or self.reference_c_cov is None
+                 or self.reference_c_mean is None
+             ):
+                 raise ValueError(
+                     "Reference statistics are not computed. Please call `compute_statistics` method."
+                 )
+
+             # compute features alignment loss
+             cov_ext = self.__covariance(features)
+             mu_ext = features.mean(dim=0)
+
+             d = self.reference_cov.shape[0]
+
+             loss += self.scale_cov * (self.reference_cov - cov_ext).pow(2).sum() / (4.0 * d**2)
+             loss += self.scale_mu * (self.reference_mean - mu_ext).pow(2).mean()
+
+             # compute contrastive features alignment loss
+             c_features = self.contrastive_head(features)
+
+             cov_ext = self.__covariance(c_features)
+             mu_ext = c_features.mean(dim=0)
+
+             d = self.reference_c_cov.shape[0]
+             loss += self.scale_c_cov * (self.reference_c_cov - cov_ext).pow(2).sum() / (4.0 * d**2)
+             loss += self.scale_c_mu * (self.reference_c_mean - mu_ext).pow(2).mean()
+
+         return outputs, loss
+
+     @staticmethod
+     def __covariance(features):
+         """Legacy wrapper to maintain compatibility in the engine."""
+         from torch_ttt.utils.math import compute_covariance
+
+         return compute_covariance(features, dim=0)
+
+     def compute_statistics(self, dataloader: DataLoader) -> None:
+         """Extract and compute reference statistics for features and contrastive features.
+
+         Args:
+             dataloader (DataLoader): The dataloader used for extracting features. It can return tuples of tensors, with the first element expected to be the input tensor.
+
+         Raises:
+             ValueError: If the dataloader is empty or features have mismatched dimensions.
+         """
+
+         self.model.eval()
+
+         feat_stack = []
+         c_feat_stack = []
+
+         with torch.no_grad():
+             device = next(self.model.parameters()).device
+             for sample in dataloader:
+                 if len(sample) < 1:
+                     raise ValueError("Dataloader returned an empty batch.")
+
+                 inputs = sample[0].to(device)
+                 with self.__capture_hook() as features_hook:
+                     _ = self.model(inputs)
+                     feat = features_hook.output
+
+                 # Initialize contrastive head if not already initialized
+                 if self.contrastive_head is None:
+                     self.contrastive_head = self.__build_contrastive_head(feat)
+
+                 # Compute contrastive features
+                 contrastive_feat = self.contrastive_head(feat)
+
+                 feat_stack.append(feat.cpu())
+                 c_feat_stack.append(contrastive_feat.cpu())
+
+         # compute features statistics
+         feat_all = torch.cat(feat_stack)
+         feat_cov = self.__covariance(feat_all)
+         feat_mean = feat_all.mean(dim=0)
+
+         self.reference_cov = feat_cov.to(device)
+         self.reference_mean = feat_mean.to(device)
+
+         # compute contrastive features statistics
+         feat_all = torch.cat(c_feat_stack)
+         feat_cov = self.__covariance(feat_all)
+         feat_mean = feat_all.mean(dim=0)
+
+         self.reference_c_cov = feat_cov.to(device)
+         self.reference_c_mean = feat_mean.to(device)
+
+     @contextmanager
+     def __capture_hook(self):
+         """Context manager to capture features via a forward hook."""
+
+         class OutputHook:
+             def __init__(self):
+                 self.output = None
+
+             def hook(self, module, input, output):
+                 self.output = output
+
+         features_hook = OutputHook()
+         hook_handle = self.target_module.register_forward_hook(features_hook.hook)
+
+         try:
+             yield features_hook
+         finally:
+             hook_handle.remove()
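
The two alignment penalties above are plain squared distances between test-batch statistics and the stored references; a toy rendering of the same arithmetic (random stand-in features; the 0.1 scales mirror the constructor defaults):

.. code-block:: python

    import torch
    from torch_ttt.utils.math import compute_covariance

    ref = torch.randn(128, 16)   # stand-in for source features
    test = torch.randn(32, 16)   # stand-in for test features

    cov_ref, mu_ref = compute_covariance(ref), ref.mean(dim=0)
    cov_test, mu_test = compute_covariance(test), test.mean(dim=0)

    d = cov_ref.shape[0]
    loss = 0.1 * (cov_ref - cov_test).pow(2).sum() / (4.0 * d**2)
    loss = loss + 0.1 * (mu_ref - mu_test).pow(2).mean()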
torch_ttt/engine_registry.py ADDED
@@ -0,0 +1,38 @@
+ import importlib
+ import os
+ from torch_ttt.engine.base_engine import BaseEngine
+
+
+ class EngineRegistry:
+     _registry = {}
+
+     @classmethod
+     def register(cls, name):
+         def decorator(engine_class):
+             if not issubclass(engine_class, BaseEngine):
+                 raise TypeError(f"Engine class '{name}' must inherit from BaseEngine.")
+             if name in cls._registry:
+                 raise ValueError(f"Engine '{name}' is already registered.")
+
+             cls._registry[name] = engine_class
+             return engine_class
+
+         return decorator
+
+     @classmethod
+     def get_engine(cls, name):
+         if name not in cls._registry:
+             raise ValueError(f"Engine '{name}' is not registered.")
+         return cls._registry[name]
+
+
+ # Dynamically import all engines in the engine directory
+ def register_all_engines():
+     engines_dir = os.path.dirname(__file__) + "/engine"
+     for file in os.listdir(engines_dir):
+         if file.endswith("_engine.py") and not file.startswith("__"):
+             module_name = f"torch_ttt.engine.{file[:-3]}"
+             importlib.import_module(module_name)
+
+
+ register_all_engines()
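
Engines register themselves via the decorator, and `register_all_engines()` runs when `torch_ttt.engine_registry` is first imported, so lookups by name work immediately:

.. code-block:: python

    from torch_ttt.engine_registry import EngineRegistry

    EngineCls = EngineRegistry.get_engine("ttt")
    print(EngineCls.__name__)  # -> "TTTEngine"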
torch_ttt/loss/__init__.py ADDED
@@ -0,0 +1,2 @@
+ # ruff: noqa: F401
+ from .base_loss import BaseLoss
torch_ttt/loss/base_loss.py ADDED
@@ -0,0 +1,8 @@
+ from abc import ABC, abstractmethod
+ import torch.nn as nn
+
+
+ class BaseLoss(nn.Module, ABC):
+     @abstractmethod
+     def forward(self, model, inputs):
+         pass
torch_ttt/loss/contrastive_loss.py ADDED
@@ -0,0 +1,94 @@
+ import torch
+ from torch_ttt.loss.base_loss import BaseLoss
+ from torch_ttt.loss_registry import LossRegistry
+
+
+ # TODO: make it more readable
+ @LossRegistry.register("contrastive")
+ class ContrastiveLoss(BaseLoss):
+     """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
+     It also supports the unsupervised contrastive loss in SimCLR."""
+
+     def __init__(self, temperature=0.07, contrast_mode="all", base_temperature=0.07) -> None:
+         super(ContrastiveLoss, self).__init__()
+         self.temperature = temperature
+         self.contrast_mode = contrast_mode
+         self.base_temperature = base_temperature
+
+     def forward(self, features, labels=None, mask=None):
+         """Compute loss for model. If both `labels` and `mask` are None,
+         it degenerates to the SimCLR unsupervised loss:
+         https://arxiv.org/pdf/2002.05709.pdf
+
+         Args:
+             features: hidden vector of shape [bsz, n_views, ...].
+             labels: ground truth of shape [bsz].
+             mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
+                 has the same class as sample i. Can be asymmetric.
+         Returns:
+             A loss scalar.
+         """
+         device = torch.device("cuda") if features.is_cuda else torch.device("cpu")
+
+         if len(features.shape) < 3:
+             raise ValueError(
+                 "`features` needs to be [bsz, n_views, ...], at least 3 dimensions are required"
+             )
+         if len(features.shape) > 3:
+             features = features.view(features.shape[0], features.shape[1], -1)
+
+         batch_size = features.shape[0]
+         if labels is not None and mask is not None:
+             raise ValueError("Cannot define both `labels` and `mask`")
+         elif labels is None and mask is None:
+             mask = torch.eye(batch_size, dtype=torch.float32).to(device)
+         elif labels is not None:
+             labels = labels.contiguous().view(-1, 1)
+             if labels.shape[0] != batch_size:
+                 raise ValueError("Num of labels does not match num of features")
+             mask = torch.eq(labels, labels.T).float().to(device)
+         else:
+             mask = mask.float().to(device)
+
+         contrast_count = features.shape[1]
+         contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
+         if self.contrast_mode == "one":
+             anchor_feature = features[:, 0]
+             anchor_count = 1
+         elif self.contrast_mode == "all":
+             anchor_feature = contrast_feature
+             anchor_count = contrast_count
+         else:
+             raise ValueError("Unknown mode: {}".format(self.contrast_mode))
+
+         # compute logits
+         anchor_dot_contrast = torch.div(
+             torch.matmul(anchor_feature, contrast_feature.T), self.temperature
+         )
+         # for numerical stability
+         logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
+         logits = anchor_dot_contrast - logits_max.detach()
+
+         # tile mask
+         mask = mask.repeat(anchor_count, contrast_count)
+         # mask-out self-contrast cases
+         logits_mask = torch.scatter(
+             torch.ones_like(mask),
+             1,
+             torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
+             0,
+         )
+         mask = mask * logits_mask
+
+         # compute log_prob
+         exp_logits = torch.exp(logits) * logits_mask
+         log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
+
+         # compute mean of log-likelihood over positive
+         mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
+
+         # loss
+         loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
+         loss = loss.view(anchor_count, batch_size).mean()
+
+         return loss
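
Expected input is [bsz, n_views, dim]; with no labels or mask it behaves as the SimCLR loss, with labels as SupCon. A small sketch on random features (L2-normalized, as is customary for cosine-similarity logits):

.. code-block:: python

    import torch
    import torch.nn.functional as F
    from torch_ttt.loss.contrastive_loss import ContrastiveLoss

    criterion = ContrastiveLoss(temperature=0.07)
    feats = F.normalize(torch.randn(8, 2, 16), dim=-1)  # [bsz=8, n_views=2, dim=16]

    loss_unsup = criterion(feats)                                  # SimCLR mode
    loss_sup = criterion(feats, labels=torch.randint(0, 4, (8,)))  # SupCon mode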
torch_ttt/loss/entropy_loss.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+ import torch.nn.functional as F
+ from torch_ttt.loss.base_loss import BaseLoss
+ from torch_ttt.loss_registry import LossRegistry
+
+
+ @LossRegistry.register("entropy")
+ class EntropyLoss(BaseLoss):
+     def __init__(self):
+         super().__init__()
+
+     def __call__(self, model, inputs):
+         logits = model(inputs)
+         probs = F.softmax(logits, dim=1)
+         log_probs = F.log_softmax(logits, dim=1)  # numerically stable log-probabilities
+         return -torch.sum(probs * log_probs, dim=1).mean()
torch_ttt/loss/mean_loss.py ADDED
@@ -0,0 +1,12 @@
+ from torch_ttt.loss.base_loss import BaseLoss
+ from torch_ttt.loss_registry import LossRegistry
+
+
+ @LossRegistry.register("mean")
+ class MeanLoss(BaseLoss):
+     def __init__(self):
+         super().__init__()
+
+     def __call__(self, model, inputs):
+         outputs = model(inputs)
+         return outputs.mean()
torch_ttt/loss/ttt_loss.py ADDED
@@ -0,0 +1,13 @@
+ from torch_ttt.loss.base_loss import BaseLoss
+ from torch_ttt.loss_registry import LossRegistry
+
+
+ @LossRegistry.register("ttt")
+ class TTTLoss(BaseLoss):
+     def __init__(self):
+         super().__init__()
+
+     def __call__(self, model, inputs):
+         # TODO [P2]: check that model is TTTEngine
+         _, loss = model(inputs)
+         return loss
torch_ttt/loss/weights_magnitude_loss.py ADDED
@@ -0,0 +1,30 @@
+ import torch
+ from torch_ttt.loss.base_loss import BaseLoss
+ from torch_ttt.loss_registry import LossRegistry
+
+
+ @LossRegistry.register("weights_magnitude")
+ class WeightsMagnitudeLoss(BaseLoss):
+     def __init__(self):
+         super().__init__()
+         self.quantile = 0.95
+
+     def __call__(self, model, inputs):
+         # Collect all trainable model weights
+         all_weights = []
+         for param in model.parameters():
+             if param.requires_grad:  # Focus only on trainable parameters
+                 all_weights.append(param.view(-1))  # Flatten weights into a 1D tensor
+
+         # Concatenate all weights into a single tensor
+         all_weights = torch.cat(all_weights)
+
+         # Keep the weights at or above the `quantile` threshold
+         quantile_value = torch.quantile(all_weights.abs(), self.quantile)
+         top_quantile_weights = all_weights[all_weights.abs() >= quantile_value]
+
+         # Average the top-quantile weights
+         weight_loss = top_quantile_weights.mean()
+
+         # Return the loss
+         return weight_loss
torch_ttt/loss/zerot_loss.py ADDED
@@ -0,0 +1,158 @@
+ import torch
+ from torch_ttt.loss.base_loss import BaseLoss
+ from torch_ttt.loss_registry import LossRegistry
+
+
+ @LossRegistry.register("zerot")
+ class ZeroTrainLoss(BaseLoss):
+     def __init__(self):
+         super().__init__()
+         self.quantile = 0.95
+
+     def __call__(self, model, inputs):
+         N = len(list(model.named_parameters()))
+
+         importance_dict = compute_weight_importance(model, inputs, N)
+
+         # Calculate the top s% mean importance as the loss
+         loss = top_s_percent_mean(importance_dict, self.quantile)
+
+         # Return the loss
+         return loss
+
+
+ def top_s_percent_mean(importance_dict, s):
+     """
+     Compute the mean of the top s% of the importance values.
+
+     Args:
+         importance_dict (dict): A dictionary mapping layer names to weight importance tensors.
+         s (float): The percentage (0-100) of top importance values to consider.
+
+     Returns:
+         torch.Tensor: The mean of the top s% importance values, differentiable with respect to original weights.
+     """
+     # Concatenate all importance values into a single tensor
+     all_importances = torch.cat([importance.view(-1) for importance in importance_dict.values()])
+
+     # Determine the number of top elements to keep
+     k = max(
+         1, int(s / 100.0 * all_importances.numel())
+     )  # Ensure at least one element is considered
+
+     # Sort the importance values in descending order
+     top_importances, _ = torch.topk(all_importances, k)
+
+     # Compute the mean of the top s% importance values
+     top_mean = top_importances.mean()
+
+     return top_mean
+
+
+ def optimize_last_n_layers(m, x, N, s, M, lr=1e-3, optimizer=None, verbose=False):
+     """
+     Optimize the last N layers of the model for M steps using the top s% of weight importance as a TTT loss function.
+
+     Args:
+         m (torch.nn.Module): The model.
+         x (torch.Tensor): The input batch.
+         N (int): The number of layers from the end to optimize.
+         s (float): The percentage (0-100) of top importance values to consider for the loss.
+         M (int): The number of optimization steps.
+         lr (float, optional): Learning rate for the optimizer. Default is 1e-3.
+         optimizer (torch.optim.Optimizer, optional): Optimizer to reuse across calls; created if None.
+         verbose (bool, optional): If True, print the loss at each step.
+
+     Returns:
+         The model's predictions for the input batch after optimization, and the optimizer used.
+     """
+     # Set the model to training mode
+     m.train()
+
+     # Collect parameters of the last N layers to optimize
+     layers = list(m.named_parameters())[-2 * N :]
+     params_to_optimize = [param for name, param in layers if "weight" in name]
+
+     # Ensure that the parameters of the last N layers have requires_grad=True
+     for param in params_to_optimize:
+         param.requires_grad = True
+
+     # Set up the optimizer for the last N layers' parameters
+     if optimizer is None:
+         optimizer = torch.optim.Adam(params_to_optimize, lr=lr)
+
+     # Optimization loop for M steps
+     for step in range(M):
+         # Compute weight importance
+         importance_dict = compute_weight_importance(m, x, N)
+
+         # Calculate the top s% mean importance as the loss
+         loss = top_s_percent_mean(importance_dict, s)
+
+         # Zero the gradients
+         optimizer.zero_grad()
+
+         # Backward pass to compute gradients
+         loss.backward()
+
+         # Step the optimizer to update parameters
+         optimizer.step()
+
+         # Optionally, print loss for monitoring
+         if verbose:
+             print(f"Step {step+1}/{M}, Loss: {loss.item()}")
+
+     # After optimization, run the model on the input x to get predictions
+     m.eval()  # Set the model to evaluation mode
+     with torch.no_grad():  # Disable gradient computation for inference
+         predictions = m(x)
+
+     return predictions, optimizer
+
+
+ def compute_weight_importance(m, x, N):
+     """
+     Compute the weight importance for the last N layers of the model efficiently.
+
+     Args:
+         m (torch.nn.Module): The model.
+         x (torch.Tensor): The input batch.
+         N (int): The number of layers from the end to consider.
+
+     Returns:
+         dict: A dictionary mapping layer names to the weight importance tensor.
+     """
+     # Ensure the model is in evaluation mode
+     m.eval()
+
+     # Disable gradients for all parameters initially
+     for param in m.parameters():
+         param.requires_grad = False
+
+     # Enable gradients only for the last N layers
+     layers = list(m.named_parameters())[-2 * N :]  # Get the last N layers (both weights and biases)
+     for name, param in layers:
+         if "weight" in name:
+             param.requires_grad = True
+
+     # Forward pass to compute the output
+     output = m(x)
+
+     # Use a simple loss function (mean of outputs) to create a scalar output
+     loss = output.mean()
+
+     # Compute gradients with respect to the last N layers' parameters
+     gradients = torch.autograd.grad(
+         loss, [param for name, param in layers if "weight" in name], create_graph=True
+     )
+
+     # Calculate importance: w * \nabla{w} for the last N layers
+     importance_dict = {}
+     grad_idx = 0
+     for name, param in layers:
+         if "weight" in name:
+             importance = (param * gradients[grad_idx]).abs()  # w * \nabla{w}
+             importance_dict[name] = importance  # Don't detach the importance
+             grad_idx += 1
+
+     return importance_dict
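
The helpers above can be exercised directly; a hedged sketch on a throwaway model (`N=1` restricts the importance computation to the final layer's weights):

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch_ttt.loss.zerot_loss import compute_weight_importance, top_s_percent_mean

    m = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
    x = torch.randn(4, 8)

    importance = compute_weight_importance(m, x, N=1)  # {'2.weight': |w * grad_w|}
    loss = top_s_percent_mean(importance, s=5.0)       # mean of the top 5% values
    loss.backward()                                    # differentiable (create_graph=True)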
torch_ttt/loss_registry.py ADDED
@@ -0,0 +1,38 @@
+ import importlib
+ import os
+ from torch_ttt.loss.base_loss import BaseLoss
+
+
+ class LossRegistry:
+     _registry = {}
+
+     @classmethod
+     def register(cls, name):
+         def decorator(loss_class):
+             if not issubclass(loss_class, BaseLoss):
+                 raise TypeError(f"Loss class '{name}' must inherit from BaseLoss.")
+             if name in cls._registry:
+                 raise ValueError(f"Loss '{name}' is already registered.")
+
+             cls._registry[name] = loss_class
+             return loss_class
+
+         return decorator
+
+     @classmethod
+     def get_loss(cls, name):
+         if name not in cls._registry:
+             raise ValueError(f"Loss '{name}' is not registered.")
+         return cls._registry[name]
+
+
+ # Dynamically import all losses in the loss directory
+ def register_all_losses():
+     losses_dir = os.path.dirname(__file__) + "/loss"
+     for file in os.listdir(losses_dir):
+         if file.endswith("_loss.py") and not file.startswith("__"):
+             module_name = f"torch_ttt.loss.{file[:-3]}"
+             importlib.import_module(module_name)
+
+
+ register_all_losses()
torch_ttt/utils/__init__.py ADDED
File without changes
torch_ttt/utils/augmentations.py ADDED
@@ -0,0 +1,25 @@
+ import random
+ from torchvision.transforms import functional as F
+
+
+ class RandomResizedCrop:
+     def __init__(self, scale=(0.2, 1.0)):
+         self.scale = scale
+
+     def __call__(self, img):
+         # Dynamically compute the crop size
+         original_size = img.shape[-2:]  # H × W
+         area = original_size[0] * original_size[1]
+         target_area = random.uniform(self.scale[0], self.scale[1]) * area
+         aspect_ratio = random.uniform(3 / 4, 4 / 3)
+
+         h = int(round((target_area * aspect_ratio) ** 0.5))
+         w = int(round((target_area / aspect_ratio) ** 0.5))
+
+         if random.random() < 0.5:  # Randomly swap h and w
+             h, w = w, h
+
+         h = min(h, original_size[0])
+         w = min(w, original_size[1])
+
+         return F.resized_crop(img, top=0, left=0, height=h, width=w, size=original_size)
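
Note that this variant randomizes the crop's size and aspect ratio but anchors it at the top-left corner, then resizes it back to the original resolution. Usage on a [C, H, W] (or batched [B, C, H, W]) tensor:

.. code-block:: python

    import torch
    from torch_ttt.utils.augmentations import RandomResizedCrop

    crop = RandomResizedCrop(scale=(0.2, 1.0))
    img = torch.rand(3, 64, 64)
    out = crop(img)
    assert out.shape[-2:] == img.shape[-2:]  # resized back to the input H x W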
torch_ttt/utils/math.py ADDED
@@ -0,0 +1,32 @@
+ import torch
+
+
+ def compute_covariance(features: torch.Tensor, dim: int = 0) -> torch.Tensor:
+     """Compute the covariance matrix of the given features along a specific dimension.
+
+     Args:
+         features (torch.Tensor): Input tensor of shape [N, D].
+         dim (int): The dimension that indexes the samples.
+
+     Returns:
+         torch.Tensor: Covariance matrix of shape [D, D].
+
+     Raises:
+         ValueError: If the input tensor has fewer than 2 dimensions.
+     """
+     if features.ndim < 2:
+         raise ValueError("Input tensor must have at least 2 dimensions to compute covariance.")
+
+     if features.size(dim) <= 1:
+         raise ValueError(
+             f"Cannot compute covariance with less than 2 samples along dimension {dim}."
+         )
+
+     # Move the sample dimension to the front so that rows index samples
+     features = features.transpose(0, dim)
+
+     n = features.shape[0]
+     tmp = torch.ones((1, n), device=features.device) @ features
+     cov = (features.t() @ features - (tmp.t() @ tmp) / n) / (n - 1)
+
+     return cov
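
A quick sanity check: for a 2-D [N, D] input this matches PyTorch's built-in unbiased estimator, which expects variables in rows:

.. code-block:: python

    import torch
    from torch_ttt.utils.math import compute_covariance

    x = torch.randn(100, 5)      # 100 samples, 5 features
    cov = compute_covariance(x)  # [5, 5]
    assert torch.allclose(cov, torch.cov(x.T), atol=1e-5)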
torch_ttt-0.0.1.dist-info/METADATA ADDED
@@ -0,0 +1,16 @@
+ Metadata-Version: 2.4
+ Name: torch-ttt
+ Version: 0.0.1
+ License-File: LICENSE
+ Provides-Extra: dev
+ Requires-Dist: pytest-cov; extra == "dev"
+ Provides-Extra: docs
+ Requires-Dist: sphinxawesome-theme==6.0.0b1; extra == "docs"
+ Requires-Dist: sphinx-gallery==0.18.0; extra == "docs"
+ Requires-Dist: sphinxcontrib-bibtex; extra == "docs"
+ Requires-Dist: sphinxcontrib-googleanalytics; extra == "docs"
+ Requires-Dist: matplotlib; extra == "docs"
+ Requires-Dist: tqdm; extra == "docs"
+ Provides-Extra: all
+ Requires-Dist: torch_ttt[dev,docs]; extra == "all"
+ Dynamic: license-file
torch_ttt-0.0.1.dist-info/RECORD ADDED
@@ -0,0 +1,26 @@
+ torch_ttt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ torch_ttt/engine_registry.py,sha256=FzjHxBDCF5y8vkToKDyrLcvVNAtw2LUu4Wl-EViNHw4,1144
+ torch_ttt/loss_registry.py,sha256=PGkrOFzeTKzXFoDLHUf-rH8aBAtyxrh4xzyowhVQtv4,1108
+ torch_ttt/engine/__init__.py,sha256=GlxEg_jHxKmyy-uLMIrZfbsPeF9hPDclV68fTIio3zU,183
+ torch_ttt/engine/actmad_engine.py,sha256=tiQLBO2SzYljUzBxdHywzrN188AsfPAqV5X6CQy7dcs,6281
+ torch_ttt/engine/base_engine.py,sha256=sQZq_NDzNlQnBERlRKD43Wo8Z7FJnblwHn0wANe6Q0o,1527
+ torch_ttt/engine/masked_ttt_engine.py,sha256=X-SApdbzzcViOk7lijPca3IMV1wP-8QDGWqWXow5RHo,404
+ torch_ttt/engine/tent_engine.py,sha256=w5JIL5eeQMO_NFlzEV92xL0MwOWExN2-8t7GdRP63Jk,1686
+ torch_ttt/engine/ttt_engine.py,sha256=pzdP4vgxAz-MQwP3KJUF0xFcUpm_hf0I_YzWYx-6q60,7675
+ torch_ttt/engine/ttt_pp_engine.py,sha256=W_6vP7Oc1eeVUBAix2qT3-Hz0JcV11uOdQQsfSMLJC8,11014
+ torch_ttt/loss/__init__.py,sha256=rV6RxTWgU_BtdUX0d7UXhtlTexhIiqxVNm4eDO9HUOU,51
+ torch_ttt/loss/base_loss.py,sha256=O7UhmrK8NGVI8Dd8zCgh_MKh5HNAkYnyq9_aAesbbpg,163
+ torch_ttt/loss/contrastive_loss.py,sha256=uxlDfSqrJEmg8v8TqJgK3LhjcRsF9euGIEGG2sYS1iY,3753
+ torch_ttt/loss/entropy_loss.py,sha256=bM9LetJXic5FSs3kJ6prjXkAkc6xEBaJezfCKx4-bPg,434
+ torch_ttt/loss/mean_loss.py,sha256=iZZ8DvU3-2nT3D0kgXBaSA7oO74rXrp9jrS-QbLUOe0,307
+ torch_ttt/loss/ttt_loss.py,sha256=B83lwl5vOZKn_iKL7OznD3BOR9bf3X4OfCz2jzKXF3Q,346
+ torch_ttt/loss/weights_magnitude_loss.py,sha256=MpyW7fqvoFzwT-nQMZN_uisf1n956knzpZWd8yXmn7s,1067
+ torch_ttt/loss/zerot_loss.py,sha256=k46u3RrDdF44-f_FpIyAS-r-ZsAHOU_qKp733lDn_VI,5388
+ torch_ttt/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ torch_ttt/utils/augmentations.py,sha256=y-thszixcgGvGFtCzMeG5pWvNVr1D_37kqyNJfd9ApE,831
+ torch_ttt/utils/math.py,sha256=mmgsyCOWGXP21gqQyrHhj-axmI4GCQ1AxFCPvvrO5Uk,952
+ torch_ttt-0.0.1.dist-info/licenses/LICENSE,sha256=ppPMBj3EeJB4NHrRgxWoRSoVIxx3apK0BFsIB2-jcco,1072
+ torch_ttt-0.0.1.dist-info/METADATA,sha256=kWAuJPGHukr9eKsvgeDvCIZAZkIgIG2c1OiNkr-gnBQ,562
+ torch_ttt-0.0.1.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
+ torch_ttt-0.0.1.dist-info/top_level.txt,sha256=WP1U2F1ULE-SjEdozpFfUhCR2ZNqIi8lKvUoeQ5G5Zw,10
+ torch_ttt-0.0.1.dist-info/RECORD,,
torch_ttt-0.0.1.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
+ Wheel-Version: 1.0
+ Generator: setuptools (79.0.0)
+ Root-Is-Purelib: true
+ Tag: py3-none-any
+
torch_ttt-0.0.1.dist-info/licenses/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Nikita Durasov
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
torch_ttt-0.0.1.dist-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ torch_ttt