oodeel 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. oodeel/__init__.py +28 -0
  2. oodeel/aggregator/__init__.py +26 -0
  3. oodeel/aggregator/base.py +70 -0
  4. oodeel/aggregator/fisher.py +259 -0
  5. oodeel/aggregator/mean.py +72 -0
  6. oodeel/aggregator/std.py +86 -0
  7. oodeel/datasets/__init__.py +24 -0
  8. oodeel/datasets/data_handler.py +334 -0
  9. oodeel/datasets/deprecated/DEPRECATED_data_handler.py +236 -0
  10. oodeel/datasets/deprecated/DEPRECATED_ooddataset.py +330 -0
  11. oodeel/datasets/deprecated/DEPRECATED_tf_data_handler.py +671 -0
  12. oodeel/datasets/deprecated/DEPRECATED_torch_data_handler.py +769 -0
  13. oodeel/datasets/deprecated/__init__.py +31 -0
  14. oodeel/datasets/tf_data_handler.py +600 -0
  15. oodeel/datasets/torch_data_handler.py +672 -0
  16. oodeel/eval/__init__.py +22 -0
  17. oodeel/eval/metrics.py +218 -0
  18. oodeel/eval/plots/__init__.py +27 -0
  19. oodeel/eval/plots/features.py +345 -0
  20. oodeel/eval/plots/metrics.py +118 -0
  21. oodeel/eval/plots/plotly.py +162 -0
  22. oodeel/extractor/__init__.py +35 -0
  23. oodeel/extractor/feature_extractor.py +187 -0
  24. oodeel/extractor/hf_torch_feature_extractor.py +184 -0
  25. oodeel/extractor/keras_feature_extractor.py +409 -0
  26. oodeel/extractor/torch_feature_extractor.py +506 -0
  27. oodeel/methods/__init__.py +47 -0
  28. oodeel/methods/base.py +570 -0
  29. oodeel/methods/dknn.py +185 -0
  30. oodeel/methods/energy.py +119 -0
  31. oodeel/methods/entropy.py +113 -0
  32. oodeel/methods/gen.py +113 -0
  33. oodeel/methods/gram.py +274 -0
  34. oodeel/methods/mahalanobis.py +209 -0
  35. oodeel/methods/mls.py +113 -0
  36. oodeel/methods/odin.py +109 -0
  37. oodeel/methods/rmds.py +172 -0
  38. oodeel/methods/she.py +159 -0
  39. oodeel/methods/vim.py +273 -0
  40. oodeel/preprocess/__init__.py +31 -0
  41. oodeel/preprocess/tf_preprocess.py +95 -0
  42. oodeel/preprocess/torch_preprocess.py +97 -0
  43. oodeel/types/__init__.py +75 -0
  44. oodeel/utils/__init__.py +38 -0
  45. oodeel/utils/general_utils.py +97 -0
  46. oodeel/utils/operator.py +253 -0
  47. oodeel/utils/tf_operator.py +269 -0
  48. oodeel/utils/tf_training_tools.py +219 -0
  49. oodeel/utils/torch_operator.py +292 -0
  50. oodeel/utils/torch_training_tools.py +303 -0
  51. oodeel-0.4.0.dist-info/METADATA +409 -0
  52. oodeel-0.4.0.dist-info/RECORD +63 -0
  53. oodeel-0.4.0.dist-info/WHEEL +5 -0
  54. oodeel-0.4.0.dist-info/licenses/LICENSE +21 -0
  55. oodeel-0.4.0.dist-info/top_level.txt +2 -0
  56. tests/__init__.py +22 -0
  57. tests/tests_tensorflow/__init__.py +37 -0
  58. tests/tests_tensorflow/tf_methods_utils.py +140 -0
  59. tests/tests_tensorflow/tools_tf.py +86 -0
  60. tests/tests_torch/__init__.py +38 -0
  61. tests/tests_torch/tools_torch.py +151 -0
  62. tests/tests_torch/torch_methods_utils.py +148 -0
  63. tests/tools_operator.py +153 -0
@@ -0,0 +1,97 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
3
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
4
+ # CRIAQ and ANITI - https://www.deel.ai/
5
+ #
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+ #
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+ #
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+ from ..types import Any
24
+ from ..types import Callable
25
+
26
+
27
def is_from(model_or_tensor: Any, framework: str) -> bool:
    """Check whether a model or tensor belongs to a specific framework.

    The check is purely name based: every class in the object's MRO is turned
    into its dotted path (``module.qualname``), the path is split on ``"."``
    and ``framework`` is looked up among the resulting components.

    Args:
        model_or_tensor (Any): Neural network or Tensor
        framework (str): Model or tensor framework
            ("torch" | "keras" | "tensorflow" | "huggingface"). "huggingface"
            is treated as an alias of "transformers".

    Returns:
        bool: Whether the model belongs to specified framework or not
    """
    # "huggingface" is matched through the "transformers" path component.
    if framework == "huggingface":
        framework = "transformers"

    for parent in model_or_tensor.__class__.__mro__:
        # Read the module / qualified name directly instead of parsing
        # ``str(cls)``: a metaclass may override ``__repr__`` and drop the
        # quotes the string parsing relied on.
        qualname = getattr(parent, "__qualname__", None) or parent.__name__
        components = qualname.split(".")

        module = getattr(parent, "__module__", None)
        # The default class repr omits the "builtins" module; do the same so
        # that the set of keywords is unchanged for builtin types.
        if isinstance(module, str) and module != "builtins":
            components = module.split(".") + components

        if framework in components:
            return True

    return False
48
+
49
+
50
def import_backend_specific_stuff(model: Callable):
    """Get backend specific data handler, operator and feature extractor class.

    The backend modules are imported lazily, inside the matching branch, so
    only the modules of the detected backend are imported.

    Args:
        model (Callable): a model (Keras or PyTorch) used to identify the backend.

    Returns:
        str: backend name ("tensorflow" or "torch")
        DataHandler: torch or tf data handler
        Operator: torch or tf operator
        FeatureExtractor: torch or tf feature extractor class

    Raises:
        NotImplementedError: if the model is not recognized as a Keras,
            Hugging Face or PyTorch model.
    """
    if is_from(model, "keras"):
        from ..extractor.keras_feature_extractor import KerasFeatureExtractor
        from ..datasets.tf_data_handler import TFDataHandler
        from ..utils import TFOperator

        return "tensorflow", TFDataHandler(), TFOperator(), KerasFeatureExtractor

    # For huggingface models, is_from(model, "torch") will also return True so
    # it has to be checked before torch
    is_huggingface = is_from(model, "huggingface")
    if is_huggingface or is_from(model, "torch"):
        if is_huggingface:
            from ..extractor.hf_torch_feature_extractor import (
                HFTorchFeatureExtractor as FeatureExtractorClass,
            )
        else:
            from ..extractor.torch_feature_extractor import (
                TorchFeatureExtractor as FeatureExtractorClass,
            )
        # Both torch flavours share the same data handler and operator.
        from ..datasets.torch_data_handler import TorchDataHandler
        from ..utils import TorchOperator

        return "torch", TorchDataHandler(), TorchOperator(model), FeatureExtractorClass

    raise NotImplementedError(
        f"Unsupported model of type {type(model).__name__!r}: expected a Keras, "
        "Hugging Face (transformers) or PyTorch model."
    )
@@ -0,0 +1,253 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
3
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
4
+ # CRIAQ and ANITI - https://www.deel.ai/
5
+ #
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+ #
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+ #
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+ from abc import ABC
24
+ from abc import abstractmethod
25
+
26
+ import numpy as np
27
+
28
+ from ..types import Callable
29
+ from ..types import List
30
+ from ..types import Optional
31
+ from ..types import TensorType
32
+ from ..types import Union
33
+
34
+
35
class Operator(ABC):
    """Class to handle tensorflow and torch operations with a unified API.

    Every method is a static placeholder raising ``NotImplementedError``;
    backend specific subclasses provide the implementations. Methods decorated
    with ``@abstractmethod`` must be overridden for a subclass to be
    instantiable.
    """

    @staticmethod
    @abstractmethod
    def softmax(tensor: TensorType) -> TensorType:
        """Softmax function along the last dimension"""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def argmax(tensor: TensorType, dim: Optional[int] = None) -> TensorType:
        """Argmax function"""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def max(
        tensor: TensorType, dim: Optional[int] = None, keepdim: bool = False
    ) -> TensorType:
        """Max function"""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def min(
        tensor: TensorType, dim: Optional[int] = None, keepdim: bool = False
    ) -> TensorType:
        """Min function"""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def one_hot(tensor: TensorType, num_classes: int) -> TensorType:
        """One hot function"""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def sign(tensor: TensorType) -> TensorType:
        """Sign function"""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def CrossEntropyLoss(reduction: str = "mean"):
        """Cross Entropy Loss from logits"""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def norm(tensor: TensorType, dim: Optional[int] = None) -> TensorType:
        """Norm function"""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def matmul(tensor_1: TensorType, tensor_2: TensorType) -> TensorType:
        """Matmul operation"""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def convert_to_numpy(tensor: TensorType) -> np.ndarray:
        """Convert a tensor to a NumPy array"""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def gradient(func: Callable, inputs: TensorType) -> TensorType:
        """Compute gradients for a batch of samples.

        Args:
            func (Callable): Function used for computing gradient. Must be built with
                differentiable operations only, and return a scalar.
            inputs (Any): Input tensor wrt which the gradients are computed

        Returns:
            Gradients computed, with the same shape as the inputs.
        """
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def stack(tensors: List[TensorType], dim: int = 0) -> TensorType:
        """Stack tensors along a new dimension"""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def cat(tensors: List[TensorType], dim: int = 0) -> TensorType:
        """Concatenate tensors in a given dimension"""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def mean(tensor: TensorType, dim: Optional[int] = None) -> TensorType:
        """Mean function"""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def flatten(tensor: TensorType) -> TensorType:
        """Flatten to 2D tensor (batch_size, -1)"""
        # Flatten the features to 2D (n_batch, n_features)
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def from_numpy(arr: np.ndarray) -> TensorType:
        """Convert a NumPy array to a tensor"""
        # TODO change dtype
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def t(tensor: TensorType) -> TensorType:
        """Transpose function for tensor of rank 2"""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def permute(tensor: TensorType, dims: Optional[List[int]] = None) -> TensorType:
        """Permute the dimensions of a tensor following the order given by dims.

        ``dims`` is part of the signature so that this declaration agrees with
        backend implementations that take the permutation as second argument.
        """
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def diag(tensor: TensorType) -> TensorType:
        """Diagonal function: return the diagonal of a 2D tensor"""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def reshape(tensor: TensorType, shape: List[int]) -> TensorType:
        """Reshape function"""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def equal(tensor: TensorType, other: Union[TensorType, int, float]) -> TensorType:
        """Computes element-wise equality"""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def pinv(tensor: TensorType) -> TensorType:
        """Computes the pseudoinverse (Moore-Penrose inverse) of a matrix."""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def eigh(tensor: TensorType) -> TensorType:
        """Computes the eigen decomposition of a self-adjoint matrix."""
        raise NotImplementedError()

    @staticmethod
    def quantile(
        tensor: TensorType, q: float, dim: Optional[int] = None
    ) -> TensorType:
        """Computes the quantile of a tensor's components"""
        raise NotImplementedError()

    @staticmethod
    def relu(tensor: TensorType) -> TensorType:
        """Apply relu to a tensor"""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def einsum(equation: str, *tensors: TensorType) -> TensorType:
        """Computes the einsum between tensors following equation"""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def tril(tensor: TensorType, diagonal: int = 0) -> TensorType:
        """Set the upper triangle of the matrix formed by the last two dimensions
        of tensor to zero"""
        raise NotImplementedError()

    @staticmethod
    def sum(
        tensor: TensorType, dim: Optional[Union[tuple, list, int]] = None
    ) -> TensorType:
        """sum along dim"""
        raise NotImplementedError()

    @staticmethod
    def unsqueeze(tensor: TensorType, dim: int) -> TensorType:
        """unsqueeze/expand_dim along dim"""
        raise NotImplementedError()

    @staticmethod
    def squeeze(tensor: TensorType, dim: Optional[int] = None) -> TensorType:
        """squeeze along dim"""
        raise NotImplementedError()

    @staticmethod
    def abs(tensor: TensorType) -> TensorType:
        """compute absolute value"""
        raise NotImplementedError()

    @staticmethod
    def where(
        condition: TensorType,
        input: Union[TensorType, float],
        other: Union[TensorType, float],
    ) -> TensorType:
        """Applies where function to condition"""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def avg_pool_2d(tensor: TensorType) -> TensorType:
        """Perform avg pool in 2d as in torch.nn.functional.adaptive_avg_pool2d"""
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def log(tensor: TensorType) -> TensorType:
        """Perform log"""
        raise NotImplementedError()
@@ -0,0 +1,269 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
3
+ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
4
+ # CRIAQ and ANITI - https://www.deel.ai/
5
+ #
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+ #
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+ #
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+ import numpy as np
24
+ import tensorflow as tf
25
+ import tensorflow_probability as tfp
26
+
27
+ from ..types import Callable
28
+ from ..types import List
29
+ from ..types import Optional
30
+ from ..types import TensorType
31
+ from ..types import Union
32
+ from .general_utils import is_from
33
+ from .operator import Operator
34
+
35
+
36
def sanitize_input(tensor_arg_func: Callable):
    """Ensure the decorated function receives a tf.Tensor.

    The wrapped callable is invoked as ``tensor_arg_func(obj, tensor, *args,
    **kwargs)``; only ``tensor`` (the second positional argument) is converted.

    Args:
        tensor_arg_func (Callable): function whose second positional argument
            must be a ``tf.Tensor``.

    Returns:
        Callable: wrapper converting ``tensor`` before delegating.
    """
    import functools

    # Preserve the name / docstring of the decorated function.
    @functools.wraps(tensor_arg_func)
    def wrapper(obj, tensor, *args, **kwargs):
        if not isinstance(tensor, tf.Tensor):
            if is_from(tensor, "torch"):
                # A torch tensor that requires grad or lives on an accelerator
                # cannot be exported with a bare ``.numpy()``; detach it and
                # move it to host memory first.
                tensor = tf.convert_to_tensor(tensor.detach().cpu().numpy())
            else:
                tensor = tf.convert_to_tensor(tensor)

        return tensor_arg_func(obj, tensor, *args, **kwargs)

    return wrapper
50
+
51
+
52
class TFOperator(Operator):
    """Class to handle tensorflow operations with a unified API"""

    @staticmethod
    def softmax(tensor: TensorType) -> tf.Tensor:
        """Softmax function along the last dimension"""
        return tf.keras.activations.softmax(tensor, axis=-1)

    @staticmethod
    def argmax(tensor: TensorType, dim: Optional[int] = None) -> tf.Tensor:
        """Argmax function. When dim is None the tensor is flattened first."""
        if dim is None:
            return tf.argmax(tf.reshape(tensor, [-1]))
        return tf.argmax(tensor, axis=dim)

    @staticmethod
    def max(
        tensor: TensorType, dim: Optional[int] = None, keepdim: bool = False
    ) -> tf.Tensor:
        """Max function"""
        return tf.reduce_max(tensor, axis=dim, keepdims=keepdim)

    @staticmethod
    def min(
        tensor: TensorType, dim: Optional[int] = None, keepdim: bool = False
    ) -> tf.Tensor:
        """Min function"""
        return tf.reduce_min(tensor, axis=dim, keepdims=keepdim)

    @staticmethod
    def one_hot(tensor: TensorType, num_classes: int) -> tf.Tensor:
        """One hot function"""
        return tf.one_hot(tensor, num_classes)

    @staticmethod
    def sign(tensor: TensorType) -> tf.Tensor:
        """Sign function"""
        return tf.sign(tensor)

    @staticmethod
    def CrossEntropyLoss(reduction: str = "mean"):
        """Cross Entropy Loss from logits.

        Args:
            reduction (str): "mean", "sum" or "none". Defaults to "mean".

        Returns:
            Callable: ``loss(inputs, targets)`` where ``inputs`` are logits and
                ``targets`` are sparse (integer) labels.

        Raises:
            KeyError: if ``reduction`` is not one of the supported values.
        """
        tf_reduction = {
            "mean": "sum_over_batch_size",
            "sum": "sum",
            "none": "none",
        }[reduction]

        # Build the loss object once instead of on every call.
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf_reduction
        )

        def sanitized_ce_loss(inputs, targets):
            # Keras expects (y_true, y_pred); the unified API is (inputs, targets).
            return loss_fn(targets, inputs)

        return sanitized_ce_loss

    @staticmethod
    def norm(tensor: TensorType, dim: Optional[int] = None) -> tf.Tensor:
        """Tensor Norm"""
        return tf.norm(tensor, axis=dim)

    @staticmethod
    @tf.function
    def matmul(tensor_1: TensorType, tensor_2: TensorType) -> tf.Tensor:
        """Matmul operation"""
        return tf.matmul(tensor_1, tensor_2)

    @staticmethod
    def convert_to_numpy(tensor: TensorType) -> np.ndarray:
        """Convert tensor into a np.ndarray (NumPy arrays are returned as is)"""
        if isinstance(tensor, np.ndarray):
            return tensor
        return tensor.numpy()

    @staticmethod
    def gradient(func: Callable, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
        """Compute gradients for a batch of samples.

        Args:
            func (Callable): Function used for computing gradient. Must be built with
                tensorflow differentiable operations only, and return a scalar.
            inputs (tf.Tensor): Input tensor wrt which the gradients are computed
            *args: Additional Args for func.
            **kwargs: Additional Kwargs for func.

        Returns:
            tf.Tensor: Gradients computed, with the same shape as the inputs.
        """
        # Only the inputs are watched explicitly.
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(inputs)
            outputs = func(inputs, *args, **kwargs)
        return tape.gradient(outputs, inputs)

    @staticmethod
    def stack(tensors: List[TensorType], dim: int = 0) -> tf.Tensor:
        """Stack tensors along a new dimension"""
        return tf.stack(tensors, dim)

    @staticmethod
    def cat(tensors: List[TensorType], dim: int = 0) -> tf.Tensor:
        """Concatenate tensors in a given dimension"""
        return tf.concat(tensors, dim)

    @staticmethod
    def mean(tensor: TensorType, dim: Optional[int] = None) -> tf.Tensor:
        """Mean function"""
        return tf.reduce_mean(tensor, dim)

    @staticmethod
    def flatten(tensor: TensorType) -> tf.Tensor:
        """Flatten to 2D tensor of shape (tensor.shape[0], -1)"""
        # Flatten the features to 2D (n_batch, n_features)
        return tf.reshape(tensor, shape=[tf.shape(tensor)[0], -1])

    @staticmethod
    def from_numpy(arr: np.ndarray) -> tf.Tensor:
        """Convert a NumPy array to a tensor"""
        # TODO change dtype
        return tf.convert_to_tensor(arr)

    @staticmethod
    def t(tensor: TensorType) -> tf.Tensor:
        """Transpose function for tensor of rank 2"""
        return tf.transpose(tensor)

    @staticmethod
    def permute(tensor: TensorType, dims) -> tf.Tensor:
        """Permute the dimensions of a tensor following the order given by dims"""
        return tf.transpose(tensor, dims)

    @staticmethod
    def diag(tensor: TensorType) -> tf.Tensor:
        """Diagonal function: return the diagonal of a 2D tensor"""
        return tf.linalg.diag_part(tensor)

    @staticmethod
    def reshape(tensor: TensorType, shape: List[int]) -> tf.Tensor:
        """Reshape function"""
        return tf.reshape(tensor, shape)

    @staticmethod
    def equal(tensor: TensorType, other: Union[TensorType, int, float]) -> tf.Tensor:
        """Computes element-wise equality"""
        return tf.math.equal(tensor, other)

    @staticmethod
    def pinv(tensor: TensorType) -> tf.Tensor:
        """Computes the pseudoinverse (Moore-Penrose inverse) of a matrix."""
        return tf.linalg.pinv(tensor)

    @staticmethod
    def eigh(tensor: TensorType) -> tf.Tensor:
        """Computes the eigen decomposition of a self-adjoint matrix.

        Returns the pair ``(eigenvalues, eigenvectors)`` produced by
        ``tf.linalg.eigh``.
        """
        eigval, eigvec = tf.linalg.eigh(tensor)
        return eigval, eigvec

    @staticmethod
    def quantile(tensor: TensorType, q: float, dim: Optional[int] = None) -> tf.Tensor:
        """Computes the quantile of a tensor's components. q in (0,1)

        A Python float is returned when dim is None, a tensor otherwise.
        """
        # tfp works with percentiles, hence the rescaling of q.
        value = tfp.stats.percentile(tensor, q * 100, axis=dim)
        return float(value) if dim is None else value

    @staticmethod
    def relu(tensor: TensorType) -> tf.Tensor:
        """Apply relu to a tensor"""
        return tf.nn.relu(tensor)

    @staticmethod
    def einsum(equation: str, *tensors: TensorType) -> tf.Tensor:
        """Computes the einsum between tensors following equation"""
        return tf.einsum(equation, *tensors)

    @staticmethod
    def tril(tensor: TensorType, diagonal: int = 0) -> tf.Tensor:
        """Set the upper triangle of the matrix formed by the last two dimensions
        of tensor to zero"""
        return tf.experimental.numpy.tril(tensor, k=diagonal)

    @staticmethod
    def sum(
        tensor: TensorType, dim: Optional[Union[tuple, list, int]] = None
    ) -> tf.Tensor:
        """sum along dim"""
        # Boolean tensors are cast to float32 before being summed.
        if tensor.dtype == tf.bool:
            tensor = tf.cast(tensor, tf.float32)
        return tf.reduce_sum(tensor, axis=dim)

    @staticmethod
    def unsqueeze(tensor: TensorType, dim: int) -> tf.Tensor:
        """expand_dim along dim"""
        return tf.expand_dims(tensor, dim)

    @staticmethod
    def squeeze(tensor: TensorType, dim: Optional[int] = None) -> tf.Tensor:
        """squeeze along dim"""
        return tf.squeeze(tensor, dim)

    @staticmethod
    def abs(tensor: TensorType) -> tf.Tensor:
        """compute absolute value"""
        return tf.abs(tensor)

    @staticmethod
    def where(
        condition: TensorType,
        input: Union[TensorType, float],
        other: Union[TensorType, float],
    ) -> tf.Tensor:
        """Applies where function to condition"""
        return tf.where(condition, input, other)

    @staticmethod
    def percentile(x, q):
        """Compute the q-th percentile of x with tfp.stats.percentile"""
        return tfp.stats.percentile(x, q)

    @staticmethod
    def avg_pool_2d(tensor: TensorType) -> tf.Tensor:
        """Perform avg pool in 2d as in torch.nn.functional.adaptive_avg_pool2d"""
        # Averages over axes -3 and -2
        # (presumably the spatial axes of a channels-last tensor; TODO confirm).
        return tf.reduce_mean(tensor, axis=(-3, -2))

    @staticmethod
    def log(tensor: TensorType) -> tf.Tensor:
        """Perform log"""
        return tf.math.log(tensor)