oodeel 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. oodeel/__init__.py +28 -0
  2. oodeel/aggregator/__init__.py +26 -0
  3. oodeel/aggregator/base.py +70 -0
  4. oodeel/aggregator/fisher.py +259 -0
  5. oodeel/aggregator/mean.py +72 -0
  6. oodeel/aggregator/std.py +86 -0
  7. oodeel/datasets/__init__.py +24 -0
  8. oodeel/datasets/data_handler.py +334 -0
  9. oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
  10. oodeel/datasets/deprecated/DEPRECATED_ooddataset.py +330 -0
  11. oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
  12. oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
  13. oodeel/datasets/deprecated/__init__.py +31 -0
  14. oodeel/datasets/tf_data_handler.py +600 -0
  15. oodeel/datasets/torch_data_handler.py +672 -0
  16. oodeel/eval/__init__.py +22 -0
  17. oodeel/eval/metrics.py +218 -0
  18. oodeel/eval/plots/__init__.py +27 -0
  19. oodeel/eval/plots/features.py +345 -0
  20. oodeel/eval/plots/metrics.py +118 -0
  21. oodeel/eval/plots/plotly.py +162 -0
  22. oodeel/extractor/__init__.py +35 -0
  23. oodeel/extractor/feature_extractor.py +187 -0
  24. oodeel/extractor/hf_torch_feature_extractor.py +184 -0
  25. oodeel/extractor/keras_feature_extractor.py +409 -0
  26. oodeel/extractor/torch_feature_extractor.py +506 -0
  27. oodeel/methods/__init__.py +47 -0
  28. oodeel/methods/base.py +570 -0
  29. oodeel/methods/dknn.py +185 -0
  30. oodeel/methods/energy.py +119 -0
  31. oodeel/methods/entropy.py +113 -0
  32. oodeel/methods/gen.py +113 -0
  33. oodeel/methods/gram.py +274 -0
  34. oodeel/methods/mahalanobis.py +209 -0
  35. oodeel/methods/mls.py +113 -0
  36. oodeel/methods/odin.py +109 -0
  37. oodeel/methods/rmds.py +172 -0
  38. oodeel/methods/she.py +159 -0
  39. oodeel/methods/vim.py +273 -0
  40. oodeel/preprocess/__init__.py +31 -0
  41. oodeel/preprocess/tf_preprocess.py +95 -0
  42. oodeel/preprocess/torch_preprocess.py +97 -0
  43. oodeel/types/__init__.py +75 -0
  44. oodeel/utils/__init__.py +38 -0
  45. oodeel/utils/general_utils.py +97 -0
  46. oodeel/utils/operator.py +253 -0
  47. oodeel/utils/tf_operator.py +269 -0
  48. oodeel/utils/tf_training_tools.py +219 -0
  49. oodeel/utils/torch_operator.py +292 -0
  50. oodeel/utils/torch_training_tools.py +303 -0
  51. oodeel-0.4.0.dist-info/METADATA +409 -0
  52. oodeel-0.4.0.dist-info/RECORD +63 -0
  53. oodeel-0.4.0.dist-info/WHEEL +5 -0
  54. oodeel-0.4.0.dist-info/licenses/LICENSE +21 -0
  55. oodeel-0.4.0.dist-info/top_level.txt +2 -0
  56. tests/__init__.py +22 -0
  57. tests/tests_tensorflow/__init__.py +37 -0
  58. tests/tests_tensorflow/tf_methods_utils.py +140 -0
  59. tests/tests_tensorflow/tools_tf.py +86 -0
  60. tests/tests_torch/__init__.py +38 -0
  61. tests/tests_torch/tools_torch.py +151 -0
  62. tests/tests_torch/torch_methods_utils.py +148 -0
  63. tests/tools_operator.py +153 -0
@@ -0,0 +1,219 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
3
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
4
+ # CRIAQ and ANITI - https://www.deel.ai/
5
+ #
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+ #
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+ #
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+ import numpy as np
24
+ import tensorflow as tf
25
+ from tensorflow.keras.layers import Conv2D
26
+ from tensorflow.keras.layers import Dense
27
+ from tensorflow.keras.layers import Dropout
28
+ from tensorflow.keras.layers import Flatten
29
+ from tensorflow.keras.layers import MaxPooling2D
30
+ from tensorflow.keras.models import Sequential
31
+
32
+ from ..datasets.tf_data_handler import TFDataHandler
33
+ from ..types import List
34
+ from ..types import Optional
35
+ from ..types import Union
36
+
37
+
38
def get_toy_mlp(input_shape: tuple, num_classes: int) -> tf.keras.Model:
    """Build a small fully-connected Keras classifier for toy datasets.

    The network flattens its input, applies two ReLU hidden layers of 300 and
    150 units, and ends with a softmax layer over ``num_classes`` outputs.

    Args:
        input_shape (tuple): Shape of one input sample, given to the Keras
            ``Input`` layer.
        num_classes (int): Number of classes for the classification task.

    Returns:
        tf.keras.Model: Uncompiled Sequential model.
    """
    keras_layers = tf.keras.layers
    stack = [keras_layers.Input(shape=input_shape), keras_layers.Flatten()]
    for hidden_units in (300, 150):
        stack.append(keras_layers.Dense(hidden_units, activation="relu"))
    stack.append(keras_layers.Dense(num_classes, activation="softmax"))
    return tf.keras.models.Sequential(stack)
57
+
58
+
59
def get_toy_keras_convnet(num_classes: int) -> tf.keras.Model:
    """Build a small convolutional Keras classifier for toy image datasets.

    Architecture: two stages of 3x3 ReLU ``Conv2D`` (32 then 64 filters), each
    followed by 2x2 max-pooling, then ``Flatten``, 50% ``Dropout`` and a
    softmax ``Dense`` head. No ``Input`` layer is declared, so Keras infers the
    input shape when the model is first built or called.

    Args:
        num_classes (int): Number of classes for the classification task.

    Returns:
        tf.keras.Model: Uncompiled Sequential model.
    """
    feature_stages = []
    for n_filters in (32, 64):
        feature_stages.append(Conv2D(n_filters, kernel_size=(3, 3), activation="relu"))
        feature_stages.append(MaxPooling2D(pool_size=(2, 2)))
    classifier_head = [
        Flatten(),
        Dropout(0.5),
        Dense(num_classes, activation="softmax"),
    ]
    return Sequential(feature_stages + classifier_head)
79
+
80
+
81
def train_tf_model(
    train_data: tf.data.Dataset,
    model: Union[tf.keras.Model, str],
    input_shape: tuple = None,
    num_classes: int = None,
    batch_size: int = 128,
    epochs: int = 50,
    loss: str = "sparse_categorical_crossentropy",
    optimizer: str = "adam",
    lr_scheduler: Optional[str] = None,
    learning_rate: float = 1e-3,
    metrics: Optional[List[str]] = None,
    imagenet_pretrained: bool = False,
    validation_data: Optional[tf.data.Dataset] = None,
    save_dir: Optional[str] = None,
    save_best_only: bool = True,
) -> tf.keras.Model:
    """Build (if needed), compile and train a Keras classification model.

    ``model`` can be an existing ``tf.keras.Model``, one of the toy
    architectures ``"toy_convnet"`` / ``"toy_mlp"``, or the name of a
    constructor in ``tf.keras.applications``. In the latter case the backbone
    is created with ``include_top=False`` (optionally with ImageNet weights)
    and a ``Flatten`` + softmax ``Dense`` head is appended. The model is then
    always trained on ``train_data``, whatever ``imagenet_pretrained`` is.

    Args:
        train_data (tf.data.Dataset): Training dataset whose elements are
            either dicts with "image"/"label" keys or tuple-like (input first,
            label last).
        model (Union[tf.keras.Model, str]): Model instance or model name (see
            above).
        input_shape (tuple, optional): Shape of the input images. Inferred
            from ``train_data`` when None.
        num_classes (int, optional): Number of output classes. Inferred from
            the distinct labels of ``train_data`` when None.
        batch_size (int, optional): Only used to compute the number of
            scheduler decay steps; ``train_data`` is passed to ``fit`` as is.
            Defaults to 128.
        epochs (int, optional): Defaults to 50.
        loss (str, optional): Defaults to "sparse_categorical_crossentropy".
        optimizer (str, optional): Keras optimizer name. Defaults to "adam".
        lr_scheduler (str, optional): ("cosine" | "steps" | None). Defaults to
            None (constant learning rate).
        learning_rate (float, optional): Initial learning rate. Defaults to 1e-3.
        metrics (List[str], optional): Metrics passed to ``compile``. None
            means ["accuracy"].
        imagenet_pretrained (bool, optional): Load ImageNet weights for a
            ``tf.keras.applications`` backbone. Defaults to False.
        validation_data (Optional[tf.data.Dataset], optional): Defaults to None.
        save_dir (Optional[str], optional): Path used both for weight
            checkpoints during training and for the final saved model.
            Defaults to None.
        save_best_only (bool): If False, the saved model is the last one.
            Without ``validation_data`` "best" is judged on the training loss.
            Defaults to True.

    Returns:
        tf.keras.Model: Trained model.

    Raises:
        TypeError: If ``model`` is neither a ``tf.keras.Model`` nor a string.
    """
    # avoid a shared mutable default argument
    if metrics is None:
        metrics = ["accuracy"]

    # get data infos from dataset
    if isinstance(train_data.element_spec, dict):
        input_id, label_id = "image", "label"
    else:
        input_id, label_id = 0, -1
    if input_shape is None:
        input_shape = TFDataHandler.get_feature_shape(train_data, input_id)
    if num_classes is None:
        classes = TFDataHandler.get_feature(train_data, label_id).unique()
        num_classes = len(list(classes.as_numpy_iterator()))

    # prepare model
    if isinstance(model, str):
        if model == "toy_convnet":
            model = get_toy_keras_convnet(num_classes)
        elif model == "toy_mlp":
            model = get_toy_mlp(input_shape, num_classes)
        else:
            weights = "imagenet" if imagenet_pretrained else None
            backbone = getattr(tf.keras.applications, model)(
                include_top=False, weights=weights, input_shape=input_shape
            )
            features = tf.keras.layers.Flatten()(backbone.layers[-1].output)
            output = tf.keras.layers.Dense(num_classes, activation="softmax")(
                features
            )
            model = tf.keras.Model(backbone.layers[0].input, output)
    elif not isinstance(model, tf.keras.Model):
        raise TypeError(
            "model must be a tf.keras.Model or a model name, got "
            f"{type(model).__name__}"
        )

    n_samples = TFDataHandler.get_dataset_length(train_data)

    # Prepare callbacks
    callbacks = []
    if save_dir is not None:
        # "val_accuracy" is only logged when validation data is given. Without
        # it, ModelCheckpoint(save_best_only=True) skips every save and the
        # load_weights(save_dir) below would fail, so fall back to the
        # training loss, which is always logged.
        if validation_data is not None:
            monitor, mode = "val_accuracy", "max"
        else:
            monitor, mode = "loss", "min"
        callbacks.append(
            tf.keras.callbacks.ModelCheckpoint(
                filepath=save_dir,
                save_weights_only=True,
                monitor=monitor,
                mode=mode,
                save_best_only=save_best_only,
            )
        )

    # learning rate (schedule)
    # NOTE(review): assumes n_samples counts samples (not batches), so that this
    # is the total number of optimizer steps — confirm against get_dataset_length.
    decay_steps = int(epochs * n_samples / batch_size)
    if lr_scheduler == "cosine":
        # prefer the stable schedules API; older TF only has the experimental alias
        cosine_decay_cls = (
            getattr(tf.keras.optimizers.schedules, "CosineDecay", None)
            or tf.keras.experimental.CosineDecay
        )
        learning_rate_fn = cosine_decay_cls(learning_rate, decay_steps=decay_steps)
    elif lr_scheduler == "steps":
        # divide the learning rate by 10 after 1/3 and again after 2/3 of training
        values = list(learning_rate * np.array([1, 0.1, 0.01]))
        boundaries = list(np.round(decay_steps * np.array([1 / 3, 2 / 3])).astype(int))
        learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
            boundaries, values
        )
    else:
        learning_rate_fn = learning_rate

    # optimizer
    config = {
        "class_name": optimizer,
        "config": {
            "learning_rate": learning_rate_fn,
        },
    }
    if optimizer == "SGD":
        config["config"]["momentum"] = 0.9
        # NOTE(review): `decay` is the legacy time-based learning-rate decay
        # argument and newer Keras optimizers may reject it; if L2 weight decay
        # was intended, `weight_decay` may be the right key — confirm target
        # TF version.
        config["config"]["decay"] = 5e-4
    keras_optimizer = tf.keras.optimizers.get(config)

    model.compile(loss=loss, optimizer=keras_optimizer, metrics=metrics)

    model.fit(
        train_data,
        validation_data=validation_data,
        epochs=epochs,
        callbacks=callbacks or None,
    )

    if save_dir is not None:
        # restore the checkpointed weights (best, or last), then save the model
        model.load_weights(save_dir)
        model.save(save_dir)
    return model
@@ -0,0 +1,292 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
3
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
4
+ # CRIAQ and ANITI - https://www.deel.ai/
5
+ #
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+ #
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+ #
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+ from typing import List
24
+ from typing import Optional
25
+
26
+ import numpy as np
27
+ import torch
28
+
29
+ from ..types import Callable
30
+ from ..types import TensorType
31
+ from ..types import Union
32
+ from .general_utils import is_from
33
+ from .operator import Operator
34
+
35
+
36
def sanitize_input(tensor_arg_func: Callable):
    """Decorator ensuring the decorated method receives a ``torch.Tensor``.

    The wrapped callable must have the signature
    ``(obj, tensor, *args, **kwargs)``. Only ``tensor`` is converted:

    - a ``torch.Tensor`` is passed through unchanged;
    - a tensor detected as TensorFlow by ``is_from(tensor, "tensorflow")`` is
      converted through its ``.numpy()`` value;
    - anything else is handed directly to the ``torch.Tensor`` constructor.

    ``torch.Tensor(...)`` creates a tensor of torch's default tensor type, so
    converted inputs are not guaranteed to keep their original dtype.

    Args:
        tensor_arg_func (Callable): Function or method to wrap.

    Returns:
        Callable: The wrapper, carrying the wrapped function's name and
        docstring.
    """
    import functools

    @functools.wraps(tensor_arg_func)
    def wrapper(obj, tensor, *args, **kwargs):
        if not isinstance(tensor, torch.Tensor):
            if is_from(tensor, "tensorflow"):
                tensor = tensor.numpy()
            tensor = torch.Tensor(tensor)
        return tensor_arg_func(obj, tensor, *args, **kwargs)

    return wrapper
50
+
51
+
52
class TorchOperator(Operator):
    """Class to handle torch operations with a unified API.

    Every method except ``__init__`` and :meth:`from_numpy` is a
    ``staticmethod`` wrapping a ``torch`` function. :meth:`from_numpy` places
    its result on the device chosen at construction time.
    """

    def __init__(self, model: Optional[torch.nn.Module] = None):
        """Select the device used by :meth:`from_numpy`.

        Args:
            model (Optional[torch.nn.Module]): If given and it has at least one
                parameter, the device of its first parameter is used. Otherwise
                CUDA is used when available, else CPU.
        """
        # next(..., None) avoids a StopIteration on parameter-free models
        first_param = None if model is None else next(model.parameters(), None)
        if first_param is not None:
            self._device = first_param.device
        else:
            self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def softmax(tensor: TensorType) -> torch.Tensor:
        """Softmax function along the last dimension"""
        return torch.nn.functional.softmax(tensor, dim=-1)

    @staticmethod
    def argmax(tensor: TensorType, dim: Optional[int] = None) -> torch.Tensor:
        """Argmax function (over the flattened tensor when dim is None)"""
        return torch.argmax(tensor, dim=dim)

    @staticmethod
    def max(
        tensor: TensorType, dim: Optional[int] = None, keepdim: bool = False
    ) -> torch.Tensor:
        """Max function: global max if dim is None, else max values along dim"""
        if dim is None:
            return torch.max(tensor)
        # torch.max(tensor, dim) returns (values, indices); keep the values
        return torch.max(tensor, dim, keepdim=keepdim)[0]

    @staticmethod
    def min(
        tensor: TensorType, dim: Optional[int] = None, keepdim: bool = False
    ) -> torch.Tensor:
        """Min function: global min if dim is None, else min values along dim"""
        if dim is None:
            return torch.min(tensor)
        # torch.min(tensor, dim) returns (values, indices); keep the values
        return torch.min(tensor, dim, keepdim=keepdim)[0]

    @staticmethod
    def one_hot(tensor: TensorType, num_classes: int) -> torch.Tensor:
        """One hot function"""
        return torch.nn.functional.one_hot(tensor, num_classes)

    @staticmethod
    def sign(tensor: TensorType) -> torch.Tensor:
        """Sign function"""
        return torch.sign(tensor)

    @staticmethod
    def CrossEntropyLoss(reduction: str = "mean"):
        """Return a cross entropy loss function computed from logits.

        Args:
            reduction (str): Forwarded to ``torch.nn.CrossEntropyLoss``.

        Returns:
            Callable: ``(inputs, targets) -> loss``.
        """
        # build the loss module once instead of on every call
        ce_loss = torch.nn.CrossEntropyLoss(reduction=reduction)

        def sanitized_ce_loss(inputs, targets):
            return ce_loss(inputs, targets)

        return sanitized_ce_loss

    @staticmethod
    def norm(tensor: TensorType, dim: Optional[int] = None) -> torch.Tensor:
        """Tensor Norm"""
        return torch.norm(tensor, dim=dim)

    @staticmethod
    def matmul(tensor_1: TensorType, tensor_2: TensorType) -> torch.Tensor:
        """Matmul operation"""
        return torch.matmul(tensor_1, tensor_2)

    @staticmethod
    def convert_from_tensorflow(tensor: TensorType) -> torch.Tensor:
        """Convert a tensorflow tensor into a torch tensor

        Used when using a pytorch model on a dataset loaded from tensorflow datasets
        """
        return torch.Tensor(tensor.numpy())

    @staticmethod
    def convert_to_numpy(tensor: TensorType) -> np.ndarray:
        """Convert tensor into a np.ndarray (NumPy arrays are returned as is)"""
        if isinstance(tensor, np.ndarray):
            return tensor
        # move to CPU (no-op if already there) and drop the autograd graph
        return tensor.detach().cpu().numpy()

    @staticmethod
    def gradient(func: Callable, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Compute gradients for a batch of samples.

        Args:
            func (Callable): Function used for computing gradient. Must be built with
                torch differentiable operations only, and return a scalar.
            inputs (torch.Tensor): Input tensor wrt which the gradients are computed
            *args: Additional Args for func.
            **kwargs: Additional Kwargs for func.

        Returns:
            torch.Tensor: Gradients computed, with the same shape as the inputs.
        """
        inputs.requires_grad_(True)
        try:
            outputs = func(inputs, *args, **kwargs)
            gradients = torch.autograd.grad(outputs, inputs)
        finally:
            # reset the flag even if func or autograd raises
            inputs.requires_grad_(False)
        return gradients[0]

    @staticmethod
    def stack(tensors: List[TensorType], dim: int = 0) -> torch.Tensor:
        "Stack tensors along a new dimension"
        return torch.stack(tensors, dim)

    @staticmethod
    def cat(tensors: List[TensorType], dim: int = 0) -> torch.Tensor:
        "Concatenate tensors in a given dimension"
        return torch.cat(tensors, dim)

    @staticmethod
    def mean(tensor: TensorType, dim: Optional[int] = None) -> torch.Tensor:
        "Mean function (global mean if dim is None)"
        if dim is None:
            return torch.mean(tensor)
        return torch.mean(tensor, dim)

    @staticmethod
    def flatten(tensor: TensorType) -> torch.Tensor:
        "Flatten the features to 2D (n_batch, n_features)"
        # reshape (unlike view) also works on non-contiguous tensors
        return tensor.reshape(tensor.size(0), -1)

    def from_numpy(self, arr: np.ndarray) -> torch.Tensor:
        "Convert a NumPy array to a tensor on the operator's device"
        # TODO change dtype
        return torch.tensor(arr).to(self._device)

    @staticmethod
    def t(tensor: TensorType) -> torch.Tensor:
        "Transpose function for tensor of rank 2"
        return tensor.t()

    @staticmethod
    def permute(tensor: TensorType, dims) -> torch.Tensor:
        "Permute the dimensions of a tensor according to dims"
        return torch.permute(tensor, dims)

    @staticmethod
    def diag(tensor: TensorType) -> torch.Tensor:
        "Diagonal function: return the diagonal of a 2D tensor"
        return tensor.diag()

    @staticmethod
    def reshape(tensor: TensorType, shape: List[int]) -> torch.Tensor:
        "Reshape function"
        # reshape (unlike view) also works on non-contiguous tensors
        return tensor.reshape(*shape)

    @staticmethod
    def equal(tensor: TensorType, other: Union[TensorType, int, float]) -> torch.Tensor:
        "Computes element-wise equality"
        return torch.eq(tensor, other)

    @staticmethod
    def pinv(tensor: TensorType) -> torch.Tensor:
        "Computes the pseudoinverse (Moore-Penrose inverse) of a matrix."
        return torch.linalg.pinv(tensor)

    @staticmethod
    def eigh(tensor: TensorType) -> tuple:
        """Computes the eigen decomposition of a self-adjoint matrix.

        Returns:
            tuple: (eigenvalues, eigenvectors) as returned by torch.linalg.eigh.
        """
        eigval, eigvec = torch.linalg.eigh(tensor)
        return eigval, eigvec

    @staticmethod
    def quantile(tensor: TensorType, q: float, dim: int = None) -> torch.Tensor:
        """Computes the quantile of a tensor's components. q in (0,1)

        When dim is None, the quantile of all components is returned as a
        Python float.
        """
        if dim is None:
            # keep the 16 millions first elements (see torch.quantile issue:
            # https://github.com/pytorch/pytorch/issues/64947)
            # reshape (unlike view) also works on non-contiguous tensors
            tensor_flatten = tensor.reshape(-1)[:16_000_000]
            return torch.quantile(tensor_flatten, q).item()
        return torch.quantile(tensor, q, dim)

    @staticmethod
    def relu(tensor: TensorType) -> torch.Tensor:
        "Apply relu to a tensor"
        return torch.nn.functional.relu(tensor)

    @staticmethod
    def einsum(equation: str, *tensors: TensorType) -> torch.Tensor:
        "Computes the einsum between tensors following equation"
        return torch.einsum(equation, *tensors)

    @staticmethod
    def tril(tensor: TensorType, diagonal: int = 0) -> torch.Tensor:
        """Set the upper triangle of the matrix formed by the last two dimensions
        of tensor to zero (elements above the given diagonal)."""
        return torch.tril(tensor, diagonal)

    @staticmethod
    def sum(tensor: TensorType, dim: Union[tuple, list, int] = None) -> torch.Tensor:
        "sum along dim (global sum if dim is None)"
        if dim is None:
            return torch.sum(tensor)
        return torch.sum(tensor, dim)

    @staticmethod
    def unsqueeze(tensor: TensorType, dim: int) -> torch.Tensor:
        "unsqueeze along dim"
        return torch.unsqueeze(tensor, dim)

    @staticmethod
    def squeeze(tensor: TensorType, dim: int = None) -> torch.Tensor:
        "squeeze along dim (all size-1 dims if dim is None)"
        if dim is None:
            return torch.squeeze(tensor)
        return torch.squeeze(tensor, dim)

    @staticmethod
    def abs(tensor: TensorType) -> torch.Tensor:
        "compute absolute value"
        return torch.abs(tensor)

    @staticmethod
    def where(
        condition: TensorType,
        input: Union[TensorType, float],
        other: Union[TensorType, float],
    ) -> torch.Tensor:
        "Select elements from input where condition holds, else from other"
        return torch.where(condition, input, other)

    @staticmethod
    def avg_pool_2d(tensor: TensorType) -> torch.Tensor:
        """Perform avg pool in 2d as in torch.nn.functional.adaptive_avg_pool2d

        Averages over the last two dimensions, which are removed from the output.
        """
        return torch.mean(tensor, dim=(-2, -1))

    @staticmethod
    def log(tensor: TensorType) -> torch.Tensor:
        """Perform log"""
        return torch.log(tensor)