oodeel 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- oodeel/__init__.py +28 -0
- oodeel/aggregator/__init__.py +26 -0
- oodeel/aggregator/base.py +70 -0
- oodeel/aggregator/fisher.py +259 -0
- oodeel/aggregator/mean.py +72 -0
- oodeel/aggregator/std.py +86 -0
- oodeel/datasets/__init__.py +24 -0
- oodeel/datasets/data_handler.py +334 -0
- oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
- oodeel/datasets/deprecated/DEPRECATED_ooddataset.py +330 -0
- oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
- oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
- oodeel/datasets/deprecated/__init__.py +31 -0
- oodeel/datasets/tf_data_handler.py +600 -0
- oodeel/datasets/torch_data_handler.py +672 -0
- oodeel/eval/__init__.py +22 -0
- oodeel/eval/metrics.py +218 -0
- oodeel/eval/plots/__init__.py +27 -0
- oodeel/eval/plots/features.py +345 -0
- oodeel/eval/plots/metrics.py +118 -0
- oodeel/eval/plots/plotly.py +162 -0
- oodeel/extractor/__init__.py +35 -0
- oodeel/extractor/feature_extractor.py +187 -0
- oodeel/extractor/hf_torch_feature_extractor.py +184 -0
- oodeel/extractor/keras_feature_extractor.py +409 -0
- oodeel/extractor/torch_feature_extractor.py +506 -0
- oodeel/methods/__init__.py +47 -0
- oodeel/methods/base.py +570 -0
- oodeel/methods/dknn.py +185 -0
- oodeel/methods/energy.py +119 -0
- oodeel/methods/entropy.py +113 -0
- oodeel/methods/gen.py +113 -0
- oodeel/methods/gram.py +274 -0
- oodeel/methods/mahalanobis.py +209 -0
- oodeel/methods/mls.py +113 -0
- oodeel/methods/odin.py +109 -0
- oodeel/methods/rmds.py +172 -0
- oodeel/methods/she.py +159 -0
- oodeel/methods/vim.py +273 -0
- oodeel/preprocess/__init__.py +31 -0
- oodeel/preprocess/tf_preprocess.py +95 -0
- oodeel/preprocess/torch_preprocess.py +97 -0
- oodeel/types/__init__.py +75 -0
- oodeel/utils/__init__.py +38 -0
- oodeel/utils/general_utils.py +97 -0
- oodeel/utils/operator.py +253 -0
- oodeel/utils/tf_operator.py +269 -0
- oodeel/utils/tf_training_tools.py +219 -0
- oodeel/utils/torch_operator.py +292 -0
- oodeel/utils/torch_training_tools.py +303 -0
- oodeel-0.4.0.dist-info/METADATA +409 -0
- oodeel-0.4.0.dist-info/RECORD +63 -0
- oodeel-0.4.0.dist-info/WHEEL +5 -0
- oodeel-0.4.0.dist-info/licenses/LICENSE +21 -0
- oodeel-0.4.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +22 -0
- tests/tests_tensorflow/__init__.py +37 -0
- tests/tests_tensorflow/tf_methods_utils.py +140 -0
- tests/tests_tensorflow/tools_tf.py +86 -0
- tests/tests_torch/__init__.py +38 -0
- tests/tests_torch/tools_torch.py +151 -0
- tests/tests_torch/torch_methods_utils.py +148 -0
- tests/tools_operator.py +153 -0
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
|
|
3
|
+
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
|
|
4
|
+
# CRIAQ and ANITI - https://www.deel.ai/
|
|
5
|
+
#
|
|
6
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
7
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
8
|
+
# in the Software without restriction, including without limitation the rights
|
|
9
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
10
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
11
|
+
# furnished to do so, subject to the following conditions:
|
|
12
|
+
#
|
|
13
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
14
|
+
# copies or substantial portions of the Software.
|
|
15
|
+
#
|
|
16
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
17
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
18
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
19
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
20
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
21
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
22
|
+
# SOFTWARE.
|
|
23
|
+
import numpy as np
|
|
24
|
+
import tensorflow as tf
|
|
25
|
+
from tensorflow.keras.layers import Conv2D
|
|
26
|
+
from tensorflow.keras.layers import Dense
|
|
27
|
+
from tensorflow.keras.layers import Dropout
|
|
28
|
+
from tensorflow.keras.layers import Flatten
|
|
29
|
+
from tensorflow.keras.layers import MaxPooling2D
|
|
30
|
+
from tensorflow.keras.models import Sequential
|
|
31
|
+
|
|
32
|
+
from ..datasets.tf_data_handler import TFDataHandler
|
|
33
|
+
from ..types import List
|
|
34
|
+
from ..types import Optional
|
|
35
|
+
from ..types import Union
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def get_toy_mlp(input_shape: tuple, num_classes: int) -> tf.keras.Model:
    """Build a small fully-connected Keras classifier for toy datasets.

    The network flattens each sample, applies two ReLU hidden layers
    (300 then 150 units) and ends with a softmax layer of ``num_classes``
    units. The returned model is not compiled.

    Args:
        input_shape (tuple): Shape of one input sample, given to
            ``tf.keras.layers.Input(shape=...)``.
        num_classes (int): Number of classes for the classification task.

    Returns:
        tf.keras.Model: the sequential MLP.
    """
    kl = tf.keras.layers
    stack = [kl.Input(shape=input_shape), kl.Flatten()]
    # hidden layers, created in the same order as they appear in the network
    stack += [kl.Dense(width, activation="relu") for width in (300, 150)]
    stack.append(kl.Dense(num_classes, activation="softmax"))
    return tf.keras.models.Sequential(stack)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def get_toy_keras_convnet(num_classes: int) -> tf.keras.Model:
    """Build a small convolutional Keras classifier for toy datasets.

    Architecture: two (3x3 ReLU conv -> 2x2 max-pool) stages with 32 and 64
    filters, then Flatten, Dropout(0.5) and a softmax Dense head. No input
    layer is declared, so Keras builds the model lazily on its first call.

    Args:
        num_classes (int): Number of classes for the classification task.

    Returns:
        tf.keras.Model: the (uncompiled) sequential convnet.
    """
    feature_stages = []
    for n_filters in (32, 64):
        feature_stages.append(Conv2D(n_filters, kernel_size=(3, 3), activation="relu"))
        feature_stages.append(MaxPooling2D(pool_size=(2, 2)))
    classifier_head = [
        Flatten(),
        Dropout(0.5),
        Dense(num_classes, activation="softmax"),
    ]
    return Sequential(feature_stages + classifier_head)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _infer_data_info(
    train_data: tf.data.Dataset,
    input_shape: Optional[tuple],
    num_classes: Optional[int],
):
    """Fill in a missing input shape / class count by inspecting ``train_data``."""
    # dict-structured datasets are addressed by key, tuple-structured ones by index
    # (input first, label last)
    if isinstance(train_data.element_spec, dict):
        input_id, label_id = "image", "label"
    else:
        input_id, label_id = 0, -1
    if input_shape is None:
        input_shape = TFDataHandler.get_feature_shape(train_data, input_id)
    if num_classes is None:
        classes = TFDataHandler.get_feature(train_data, label_id).unique()
        num_classes = len(list(classes.as_numpy_iterator()))
    return input_shape, num_classes


def _build_named_model(
    name: str, input_shape: tuple, num_classes: int, imagenet_pretrained: bool
) -> tf.keras.Model:
    """Instantiate a toy model or a ``tf.keras.applications`` backbone + head."""
    if name == "toy_convnet":
        return get_toy_keras_convnet(num_classes)
    if name == "toy_mlp":
        return get_toy_mlp(input_shape, num_classes)

    weights = "imagenet" if imagenet_pretrained else None
    backbone = getattr(tf.keras.applications, name)(
        include_top=False, weights=weights, input_shape=input_shape
    )
    features = tf.keras.layers.Flatten()(backbone.layers[-1].output)
    output = tf.keras.layers.Dense(num_classes, activation="softmax")(features)
    return tf.keras.Model(backbone.layers[0].input, output)


def _checkpoint_monitor(
    validation_data: Optional[tf.data.Dataset], metrics: List[str]
) -> tuple:
    """Pick a (monitor, mode) pair that Keras will actually log during ``fit``.

    Monitoring a quantity that is never logged (e.g. ``val_accuracy`` without
    validation data) makes ``ModelCheckpoint`` skip every save when
    ``save_best_only`` is True, which later breaks ``load_weights``.
    """
    prefix = "val_" if validation_data is not None else ""
    if "accuracy" in metrics:
        return prefix + "accuracy", "max"
    # the loss is always logged, whatever the requested metrics
    return prefix + "loss", "min"


def _build_optimizer(
    optimizer: str,
    learning_rate: float,
    lr_scheduler: Optional[str],
    train_data: tf.data.Dataset,
    epochs: int,
    batch_size: int,
):
    """Create the Keras optimizer, wrapping the learning rate in a schedule if asked."""
    learning_rate_fn = learning_rate
    if lr_scheduler in ("cosine", "steps"):
        # the dataset length is only needed to size the schedules
        n_samples = TFDataHandler.get_dataset_length(train_data)
        decay_steps = int(epochs * n_samples / batch_size)
        if lr_scheduler == "cosine":
            learning_rate_fn = tf.keras.optimizers.schedules.CosineDecay(
                learning_rate, decay_steps=decay_steps
            )
        else:
            # lr divided by 10 after 1/3 and again after 2/3 of the decay steps
            values = [float(learning_rate * factor) for factor in (1.0, 0.1, 0.01)]
            boundaries = [
                int(step) for step in np.round(decay_steps * np.array([1 / 3, 2 / 3]))
            ]
            learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
                boundaries, values
            )

    config = {"class_name": optimizer, "config": {"learning_rate": learning_rate_fn}}
    if optimizer == "SGD":
        config["config"]["momentum"] = 0.9
        # NOTE(review): recent Keras optimizers reject `decay` (replaced by
        # `weight_decay`); confirm against the supported TensorFlow versions.
        config["config"]["decay"] = 5e-4
    return tf.keras.optimizers.get(config)


def train_tf_model(
    train_data: tf.data.Dataset,
    model: Union[tf.keras.Model, str],
    input_shape: tuple = None,
    num_classes: int = None,
    batch_size: int = 128,
    epochs: int = 50,
    loss: str = "sparse_categorical_crossentropy",
    optimizer: str = "adam",
    lr_scheduler: Optional[str] = None,
    learning_rate: float = 1e-3,
    metrics: Optional[List[str]] = None,
    imagenet_pretrained: bool = False,
    validation_data: Optional[tf.data.Dataset] = None,
    save_dir: Optional[str] = None,
    save_best_only: bool = True,
) -> tf.keras.Model:
    """Build (if needed), compile and train a Keras classifier.

    ``model`` is either a ready ``tf.keras.Model`` (used as-is) or a string:
    ``"toy_convnet"``, ``"toy_mlp"``, or the name of a constructor from
    ``tf.keras.applications``. The application backbone is built without its
    top, and a Flatten + softmax Dense head is appended.

    Args:
        train_data (tf.data.Dataset): training dataset, passed as-is to
            ``model.fit`` (presumably already batched: ``batch_size`` is only
            used to size the learning-rate schedules). Elements are dicts with
            ``"image"``/``"label"`` keys, or tuples with the input at index 0 and
            the label at index -1.
        model (Union[tf.keras.Model, str]): model instance or model name (see above).
        input_shape (tuple): Shape of the input samples. Inferred from
            ``train_data`` when None and a model has to be built.
        num_classes (int): Number of output classes. Inferred from the distinct
            labels of ``train_data`` when None and a model has to be built.
        batch_size (int, optional): Used to compute the scheduler decay steps.
            Defaults to 128.
        epochs (int, optional): Defaults to 50.
        loss (str, optional): Defaults to "sparse_categorical_crossentropy".
        optimizer (str, optional): Keras optimizer name. "SGD" additionally gets
            momentum 0.9 and decay 5e-4. Defaults to "adam".
        lr_scheduler (str, optional): ("cosine" | "steps" | None). Defaults to None.
        learning_rate (float, optional): Defaults to 1e-3.
        metrics (List[str], optional): Metrics given to ``model.compile``.
            Defaults to ["accuracy"].
        imagenet_pretrained (bool, optional): Load imagenet weights for a
            ``tf.keras.applications`` backbone. Defaults to False.
        validation_data (Optional[tf.data.Dataset], optional): Defaults to None.
        save_dir (Optional[str], optional): If set, weights are checkpointed there
            during training. After training, the checkpointed weights are reloaded
            and the full model is saved to the same path. Defaults to None.
        save_best_only (bool): If False, the checkpoint keeps the last epoch
            instead of the best one. The best epoch is judged on
            validation/training accuracy when "accuracy" is a metric, otherwise
            on the loss. Defaults to True.

    Returns:
        tf.keras.Model: Trained model
    """
    if metrics is None:
        metrics = ["accuracy"]

    # prepare model (dataset introspection is only needed to build one)
    if isinstance(model, str):
        input_shape, num_classes = _infer_data_info(
            train_data, input_shape, num_classes
        )
        model = _build_named_model(
            model, input_shape, num_classes, imagenet_pretrained
        )

    # prepare callbacks
    callbacks = None
    if save_dir is not None:
        monitor, mode = _checkpoint_monitor(validation_data, metrics)
        callbacks = [
            tf.keras.callbacks.ModelCheckpoint(
                filepath=save_dir,
                save_weights_only=True,
                monitor=monitor,
                mode=mode,
                save_best_only=save_best_only,
            )
        ]

    keras_optimizer = _build_optimizer(
        optimizer, learning_rate, lr_scheduler, train_data, epochs, batch_size
    )
    model.compile(loss=loss, optimizer=keras_optimizer, metrics=metrics)

    model.fit(
        train_data,
        validation_data=validation_data,
        epochs=epochs,
        callbacks=callbacks,
    )

    if save_dir is not None:
        model.load_weights(save_dir)
        model.save(save_dir)
    return model
|
|
@@ -0,0 +1,292 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
|
|
3
|
+
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
|
|
4
|
+
# CRIAQ and ANITI - https://www.deel.ai/
|
|
5
|
+
#
|
|
6
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
7
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
8
|
+
# in the Software without restriction, including without limitation the rights
|
|
9
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
10
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
11
|
+
# furnished to do so, subject to the following conditions:
|
|
12
|
+
#
|
|
13
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
14
|
+
# copies or substantial portions of the Software.
|
|
15
|
+
#
|
|
16
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
17
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
18
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
19
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
20
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
21
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
22
|
+
# SOFTWARE.
|
|
23
|
+
from typing import List
|
|
24
|
+
from typing import Optional
|
|
25
|
+
|
|
26
|
+
import numpy as np
|
|
27
|
+
import torch
|
|
28
|
+
|
|
29
|
+
from ..types import Callable
|
|
30
|
+
from ..types import TensorType
|
|
31
|
+
from ..types import Union
|
|
32
|
+
from .general_utils import is_from
|
|
33
|
+
from .operator import Operator
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def sanitize_input(tensor_arg_func: Callable):
    """Decorator ensuring the decorated function receives a torch.Tensor.

    The decorated callable must take ``(obj, tensor, *args, **kwargs)``; only the
    second positional argument is converted before the call:

    * a ``torch.Tensor`` is passed through unchanged;
    * a TensorFlow tensor is converted through ``tensor.numpy()``;
    * anything else is given to ``torch.Tensor(...)``.

    Non-torch inputs therefore go through the legacy ``torch.Tensor``
    constructor, which produces a tensor of torch's default floating dtype.

    Args:
        tensor_arg_func (Callable): function or method to wrap.

    Returns:
        Callable: the wrapper, carrying the wrapped function's name and docstring.
    """
    from functools import wraps

    # wraps keeps __name__/__doc__ so decorated methods stay introspectable
    @wraps(tensor_arg_func)
    def wrapper(obj, tensor, *args, **kwargs):
        if isinstance(tensor, torch.Tensor):
            pass
        elif is_from(tensor, "tensorflow"):
            tensor = torch.Tensor(tensor.numpy())
        else:
            tensor = torch.Tensor(tensor)

        return tensor_arg_func(obj, tensor, *args, **kwargs)

    return wrapper
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class TorchOperator(Operator):
    """Class to handle torch operations with a unified API.

    Almost every method is a static wrapper around the matching ``torch``
    function. The only instance state is ``self._device``, which
    :meth:`from_numpy` uses.
    """

    def __init__(self, model: Optional[torch.nn.Module] = None):
        """Select the device on which :meth:`from_numpy` places new tensors.

        Args:
            model (Optional[torch.nn.Module]): if given and it has parameters,
                the device of its first parameter is used. Otherwise (no model,
                or a parameter-less model) CUDA is used when available, else CPU.
        """
        # next(..., None) avoids StopIteration on models without parameters
        first_param = next(model.parameters(), None) if model is not None else None
        if first_param is not None:
            self._device = first_param.device
        else:
            self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def softmax(tensor: TensorType) -> torch.Tensor:
        """Softmax function along the last dimension"""
        return torch.nn.functional.softmax(tensor, dim=-1)

    @staticmethod
    def argmax(tensor: TensorType, dim: Optional[int] = None) -> torch.Tensor:
        """Argmax function (over the flattened tensor when dim is None)"""
        return torch.argmax(tensor, dim=dim)

    @staticmethod
    def max(
        tensor: TensorType, dim: Optional[int] = None, keepdim: Optional[bool] = False
    ) -> torch.Tensor:
        """Max function: global max, or max values along dim (indices dropped)"""
        if dim is None:
            return torch.max(tensor)
        else:
            return torch.max(tensor, dim, keepdim=keepdim)[0]

    @staticmethod
    def min(
        tensor: TensorType, dim: Optional[int] = None, keepdim: bool = False
    ) -> torch.Tensor:
        """Min function: global min, or min values along dim (indices dropped)"""
        if dim is None:
            return torch.min(tensor)
        else:
            return torch.min(tensor, dim, keepdim=keepdim)[0]

    @staticmethod
    def one_hot(tensor: TensorType, num_classes: int) -> torch.Tensor:
        """One hot function"""
        return torch.nn.functional.one_hot(tensor, num_classes)

    @staticmethod
    def sign(tensor: TensorType) -> torch.Tensor:
        """Sign function"""
        return torch.sign(tensor)

    @staticmethod
    def CrossEntropyLoss(reduction: str = "mean"):
        """Cross Entropy Loss from logits.

        Returns a callable ``(inputs, targets) -> loss`` built on
        ``torch.nn.CrossEntropyLoss`` with the given reduction.
        """
        # build the loss module once instead of at every call
        ce_loss = torch.nn.CrossEntropyLoss(reduction=reduction)

        def sanitized_ce_loss(inputs, targets):
            return ce_loss(inputs, targets)

        return sanitized_ce_loss

    @staticmethod
    def norm(tensor: TensorType, dim: Optional[int] = None) -> torch.Tensor:
        """Tensor Norm"""
        return torch.norm(tensor, dim=dim)

    @staticmethod
    def matmul(tensor_1: TensorType, tensor_2: TensorType) -> torch.Tensor:
        """Matmul operation"""
        return torch.matmul(tensor_1, tensor_2)

    @staticmethod
    def convert_from_tensorflow(tensor: TensorType) -> torch.Tensor:
        """Convert a tensorflow tensor into a torch tensor

        Used when using a pytorch model on a dataset loaded from tensorflow datasets
        """
        return torch.Tensor(tensor.numpy())

    @staticmethod
    def convert_to_numpy(tensor: TensorType) -> np.ndarray:
        """Convert tensor into a np.ndarray (numpy arrays are returned as-is)"""
        if isinstance(tensor, np.ndarray):
            return tensor
        # detach from the autograd graph, then move to host memory; .cpu() is a
        # no-op for tensors that already live on the CPU
        return tensor.detach().cpu().numpy()

    @staticmethod
    def gradient(func: Callable, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Compute gradients for a batch of samples.

        The ``requires_grad`` flag of ``inputs`` is temporarily enabled and
        restored to its original value afterwards, even if ``func`` raises.

        Args:
            func (Callable): Function used for computing gradient. Must be built with
                torch differentiable operations only, and return a scalar.
            inputs (torch.Tensor): Input tensor wrt which the gradients are computed
            *args: Additional Args for func.
            **kwargs: Additional Kwargs for func.

        Returns:
            torch.Tensor: Gradients computed, with the same shape as the inputs.
        """
        previous_requires_grad = inputs.requires_grad
        inputs.requires_grad_(True)
        try:
            outputs = func(inputs, *args, **kwargs)
            gradients = torch.autograd.grad(outputs, inputs)
        finally:
            inputs.requires_grad_(previous_requires_grad)
        return gradients[0]

    @staticmethod
    def stack(tensors: List[TensorType], dim: int = 0) -> torch.Tensor:
        "Stack tensors along a new dimension"
        return torch.stack(tensors, dim)

    @staticmethod
    def cat(tensors: List[TensorType], dim: int = 0) -> torch.Tensor:
        "Concatenate tensors in a given dimension"
        return torch.cat(tensors, dim)

    @staticmethod
    def mean(tensor: TensorType, dim: Optional[int] = None) -> torch.Tensor:
        "Mean function (global mean when dim is None)"
        if dim is None:
            return torch.mean(tensor)
        else:
            return torch.mean(tensor, dim)

    @staticmethod
    def flatten(tensor: TensorType) -> torch.Tensor:
        "Flatten function: reshape to 2D (n_batch, n_features)"
        # reshape (unlike view) also works on non-contiguous tensors
        return tensor.reshape(tensor.size(0), -1)

    def from_numpy(self, arr: np.ndarray) -> torch.Tensor:
        "Convert a NumPy array to a tensor placed on the operator's device"
        # TODO change dtype
        return torch.tensor(arr).to(self._device)

    @staticmethod
    def t(tensor: TensorType) -> torch.Tensor:
        "Transpose function for tensor of rank 2"
        return tensor.t()

    @staticmethod
    def permute(tensor: TensorType, dims) -> torch.Tensor:
        "Permute the dimensions of a tensor according to dims"
        return torch.permute(tensor, dims)

    @staticmethod
    def diag(tensor: TensorType) -> torch.Tensor:
        "Diagonal function: return the diagonal of a 2D tensor"
        return tensor.diag()

    @staticmethod
    def reshape(tensor: TensorType, shape: List[int]) -> torch.Tensor:
        "Reshape function (also valid for non-contiguous tensors)"
        return tensor.reshape(*shape)

    @staticmethod
    def equal(tensor: TensorType, other: Union[TensorType, int, float]) -> torch.Tensor:
        "Computes element-wise equality"
        return torch.eq(tensor, other)

    @staticmethod
    def pinv(tensor: TensorType) -> torch.Tensor:
        "Computes the pseudoinverse (Moore-Penrose inverse) of a matrix."
        return torch.linalg.pinv(tensor)

    @staticmethod
    def eigh(tensor: TensorType) -> tuple:
        """Computes the eigen decomposition of a self-adjoint matrix.

        Returns:
            tuple: (eigenvalues, eigenvectors) as returned by torch.linalg.eigh.
        """
        eigval, eigvec = torch.linalg.eigh(tensor)
        return eigval, eigvec

    @staticmethod
    def quantile(
        tensor: TensorType, q: float, dim: int = None
    ) -> Union[float, torch.Tensor]:
        """Computes the quantile of a tensor's components. q in (0,1)

        When dim is None, the quantile of the flattened tensor is returned as a
        Python float; otherwise a tensor of quantiles along dim is returned.
        """
        if dim is None:
            # keep the 16 millions first elements (see torch.quantile issue:
            # https://github.com/pytorch/pytorch/issues/64947)
            tensor_flatten = tensor.reshape(-1)[:16_000_000]
            return torch.quantile(tensor_flatten, q).item()
        else:
            return torch.quantile(tensor, q, dim)

    @staticmethod
    def relu(tensor: TensorType) -> torch.Tensor:
        "Apply relu to a tensor"
        return torch.nn.functional.relu(tensor)

    @staticmethod
    def einsum(equation: str, *tensors: TensorType) -> torch.Tensor:
        "Computes the einsum between tensors following equation"
        return torch.einsum(equation, *tensors)

    @staticmethod
    def tril(tensor: TensorType, diagonal: int = 0) -> torch.Tensor:
        """Set the upper triangle of the matrix formed by the last two dimensions
        of tensor to zero (entries above the given diagonal)."""
        return torch.tril(tensor, diagonal)

    @staticmethod
    def sum(tensor: TensorType, dim: Union[tuple, list, int] = None) -> torch.Tensor:
        "sum along dim (global sum when dim is None)"
        if dim is None:
            return torch.sum(tensor)
        return torch.sum(tensor, dim)

    @staticmethod
    def unsqueeze(tensor: TensorType, dim: int) -> torch.Tensor:
        "unsqueeze along dim"
        return torch.unsqueeze(tensor, dim)

    @staticmethod
    def squeeze(tensor: TensorType, dim: int = None) -> torch.Tensor:
        "squeeze along dim (all size-1 dims when dim is None)"

        if dim is None:
            return torch.squeeze(tensor)

        return torch.squeeze(tensor, dim)

    @staticmethod
    def abs(tensor: TensorType) -> torch.Tensor:
        "compute absolute value"
        return torch.abs(tensor)

    @staticmethod
    def where(
        condition: TensorType,
        input: Union[TensorType, float],
        other: Union[TensorType, float],
    ) -> torch.Tensor:
        "Element-wise select: input where condition is True, other elsewhere"
        return torch.where(condition, input, other)

    @staticmethod
    def avg_pool_2d(tensor: TensorType) -> torch.Tensor:
        """Perform avg pool in 2d as in torch.nn.functional.adaptive_avg_pool2d

        Averages over the last two dimensions, so the result equals
        ``adaptive_avg_pool2d(tensor, 1)`` with those two unit dims squeezed.
        """
        return torch.mean(tensor, dim=(-2, -1))

    @staticmethod
    def log(tensor: TensorType) -> torch.Tensor:
        """Perform log"""
        return torch.log(tensor)
|