ssad 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. ssad/__init__.py +39 -0
  2. ssad/confidence_estimators/__init__.py +15 -0
  3. ssad/confidence_estimators/api.py +16 -0
  4. ssad/confidence_estimators/binary_confidence.py +30 -0
  5. ssad/confidence_estimators/confidence_estimator.py +94 -0
  6. ssad/confidence_estimators/confidence_intervals_configuration.py +67 -0
  7. ssad/confidence_estimators/hybrid_confidence.py +49 -0
  8. ssad/confidence_estimators/supports_confidence_estimation.py +26 -0
  9. ssad/datamodules/__init__.py +10 -0
  10. ssad/datamodules/api.py +27 -0
  11. ssad/datamodules/dataframe_with_labels.py +95 -0
  12. ssad/datamodules/dataset_interfaces.py +47 -0
  13. ssad/datamodules/dataset_with_confidence.py +101 -0
  14. ssad/datamodules/self_supervision_datamodule.py +99 -0
  15. ssad/datamodules/transforms/dataframe_to_tensor.py +34 -0
  16. ssad/datasets/__init__.py +14 -0
  17. ssad/datasets/api.py +29 -0
  18. ssad/datasets/general_tabular_datamodule.py +322 -0
  19. ssad/datasets/pipeline.py +98 -0
  20. ssad/datasets/utils.py +427 -0
  21. ssad/distribution_analyzers/__init__.py +15 -0
  22. ssad/distribution_analyzers/api.py +20 -0
  23. ssad/distribution_analyzers/evt_thresholding.py +162 -0
  24. ssad/distribution_analyzers/supports_distribution_analysis.py +63 -0
  25. ssad/distribution_analyzers/triangular_thresholding.py +295 -0
  26. ssad/loggers/__init__.py +13 -0
  27. ssad/loggers/api.py +14 -0
  28. ssad/loggers/logging_config.py +57 -0
  29. ssad/loggers/mlflow_logger.py +154 -0
  30. ssad/models/__init__.py +10 -0
  31. ssad/models/api.py +14 -0
  32. ssad/models/autoencoder.py +86 -0
  33. ssad/models/variational_autoencoder.py +388 -0
  34. ssad/modules/__init__.py +15 -0
  35. ssad/modules/api.py +22 -0
  36. ssad/modules/cosine_reconstruction_module.py +105 -0
  37. ssad/modules/free_energy_module.py +204 -0
  38. ssad/modules/self_supervision_module.py +528 -0
  39. ssad/modules/supports_self_supervision.py +47 -0
  40. ssad/py.typed +0 -0
  41. ssad-0.1.2.dist-info/METADATA +74 -0
  42. ssad-0.1.2.dist-info/RECORD +45 -0
  43. ssad-0.1.2.dist-info/WHEEL +5 -0
  44. ssad-0.1.2.dist-info/licenses/LICENSE +7 -0
  45. ssad-0.1.2.dist-info/top_level.txt +1 -0
ssad/__init__.py ADDED
@@ -0,0 +1,39 @@
1
+ # Software Name : Self-Supervised Anomaly Detection
2
+ # SPDX-FileCopyrightText: Copyright (c) Orange SA
3
+ # SPDX-License-Identifier: MIT
4
+ #
5
+ # This software is distributed under the MIT License,
6
+ # see the "LICENSE.txt" file for more details or https://spdx.org/licenses/MIT.html
7
+ #
8
+ # Authors: see CONTRIBUTORS
9
+ # Software description: A Python library for autoencoder-based anomaly detection
10
+ # based on self-supervised training with dynamic sample confidence updates.
11
+ """
12
+ Initializes the main package API by exposing public interfaces from submodules.
13
+
14
+ This module re-exports the core components of the package to simplify access and
15
+ maintain a clean and consistent public API.
16
+
17
+ Available namespaces:
18
+ - confidence_estimators
19
+ - datamodules
20
+ - datasets
21
+ - distribution_analyzers
22
+ - models
23
+ - modules
24
+ - loggers
25
+
26
+ Usage:
27
+ from ssad import BinaryConfidence, DatasetWithConfidence, SelfSupervisionDataModule
28
+
29
+ Note:
30
+ This file uses wildcard imports (`*`) to expose only the public symbols defined
31
+ in each submodule's `__all__` list.
32
+ """
33
+ from .confidence_estimators.api import *
34
+ from .datamodules.api import *
35
+ from .datasets.api import *
36
+ from .distribution_analyzers.api import *
37
+ from .models.api import *
38
+ from .modules.api import *
39
+ from .loggers.api import *
@@ -0,0 +1,15 @@
1
+ # Software Name : Self-Supervised Anomaly Detection
2
+ # SPDX-FileCopyrightText: Copyright (c) Orange SA
3
+ # SPDX-License-Identifier: MIT
4
+ #
5
+ # This software is distributed under the MIT License,
6
+ # see the "LICENSE.txt" file for more details or https://spdx.org/licenses/MIT.html
7
+ #
8
+ # Authors: see CONTRIBUTORS
9
+ # Software description: A Python library for autoencoder-based anomaly detection
10
+ # based on self-supervised training with dynamic sample confidence updates.
11
+ r"""
12
+ # What is a confidence estimator?
13
+ A confidence estimator is a module that, given a value such as a reconstruction error,
14
+ or a gradient value, returns a confidence score between -1 and 1.
15
+ """
@@ -0,0 +1,16 @@
1
+ # Software Name : Self-Supervised Anomaly Detection
2
+ # SPDX-FileCopyrightText: Copyright (c) Orange SA
3
+ # SPDX-License-Identifier: MIT
4
+ #
5
+ # This software is distributed under the MIT License,
6
+ # see the "LICENSE.txt" file for more details or https://spdx.org/licenses/MIT.html
7
+ #
8
+ # Authors: see CONTRIBUTORS
9
+ # Software description: A Python library for autoencoder-based anomaly detection
10
+ # based on self-supervised training with dynamic sample confidence updates.
11
+ """Public API for self-supervision confidence estimators."""
12
+ from .binary_confidence import BinaryConfidence
13
+ from .supports_confidence_estimation import SupportsConfidenceEstimation
14
+ from .confidence_intervals_configuration import ConfidenceIntervalsConfiguration
15
+
16
+ __all__ = ["SupportsConfidenceEstimation", "BinaryConfidence", "ConfidenceIntervalsConfiguration"]
@@ -0,0 +1,30 @@
1
+ # Software Name : Self-Supervised Anomaly Detection
2
+ # SPDX-FileCopyrightText: Copyright (c) Orange SA
3
+ # SPDX-License-Identifier: MIT
4
+ #
5
+ # This software is distributed under the MIT License,
6
+ # see the "LICENSE.txt" file for more details or https://spdx.org/licenses/MIT.html
7
+ #
8
+ # Authors: see CONTRIBUTORS
9
+ # Software description: A Python library for autoencoder-based anomaly detection
10
+ # based on self-supervised training with dynamic sample confidence updates.
11
+ """
12
+ Implements a binary confidence estimator.
13
+ A sample is either normal or abnormal, otherwise it is omitted (zero confidence).
14
+ """
15
+
16
+ import torch
17
+ from .confidence_estimator import BaseConfidenceEstimator
18
+
19
+
20
class BinaryConfidence(BaseConfidenceEstimator):
    """Confidence estimator with three constant outputs.

    The score values themselves are never inspected by the conversion
    functions: each one returns a tensor shaped like its input and filled with
    ``1`` (normal), ``-1`` (abnormal) or ``0`` (unknown).
    """

    def _confidence_normal(self, score):
        """Return a tensor of ones shaped like ``score``."""
        return torch.ones_like(score)

    def _confidence_abnormal(self, score):
        """Return a tensor of minus ones shaped like ``score``."""
        return torch.full_like(score, fill_value=-1)

    def _confidence_unknown(self, score):
        """Return a tensor of zeros shaped like ``score``."""
        return torch.zeros_like(score)
@@ -0,0 +1,94 @@
1
+ # Software Name : Self-Supervised Anomaly Detection
2
+ # SPDX-FileCopyrightText: Copyright (c) Orange SA
3
+ # SPDX-License-Identifier: MIT
4
+ #
5
+ # This software is distributed under the MIT License,
6
+ # see the "LICENSE.txt" file for more details or https://spdx.org/licenses/MIT.html
7
+ #
8
+ # Authors: see CONTRIBUTORS
9
+ # Software description: A Python library for autoencoder-based anomaly detection
10
+ # based on self-supervised training with dynamic sample confidence updates.
11
+ """
12
+ Provides the base class for confidence estimators.
13
+ """
14
+
15
+ from abc import ABC, abstractmethod
16
+ from typing import Optional
17
+ import torch
18
+
19
+ from .confidence_intervals_configuration import (
20
+ ConfidenceIntervalsConfiguration,
21
+ )
22
+ from .supports_confidence_estimation import SupportsConfidenceEstimation
23
+
24
+
25
class BaseConfidenceEstimator(ABC, SupportsConfidenceEstimation):
    """Base class for interval-based confidence estimators.

    A confidence estimator converts a criterion score (for instance a
    reconstruction score, or the norm of the gradient of the reconstruction
    error) into a confidence score.

    The conversion is driven by ``self.configuration``:
        - scores inside ``configuration.normal`` are converted with
          ``_confidence_normal``;
        - scores inside ``configuration.abnormal`` are converted with
          ``_confidence_abnormal``;
        - every other score is converted with ``_confidence_unknown``.

    Subclasses implement the three ``_confidence_*`` conversion functions.

    Attributes:
        configuration: intervals used to classify scores; ``None`` until a
            configuration is assigned.
        distribution: optional tensor, initialised to ``None``; it is not read
            by this class.
    """

    def __init__(self):
        super().__init__()
        self.configuration: Optional[ConfidenceIntervalsConfiguration] = None
        self.distribution: Optional[torch.Tensor] = None

    @abstractmethod
    def _confidence_normal(self, score) -> torch.Tensor:
        """Convert scores lying in the normal interval into confidences."""
        raise NotImplementedError()

    @abstractmethod
    def _confidence_abnormal(self, score) -> torch.Tensor:
        """Convert scores lying in the abnormal interval into confidences."""
        raise NotImplementedError()

    @abstractmethod
    def _confidence_unknown(self, score) -> torch.Tensor:
        """Convert scores lying outside both intervals into confidences."""
        raise NotImplementedError()

    @torch.no_grad()
    def estimate_confidence(self, scores_batch: torch.Tensor) -> torch.Tensor:
        """Translate a batch of criterion scores into confidence scores.

        Args:
            scores_batch (torch.Tensor): criterion scores to convert.

        Returns:
            torch.Tensor: confidence scores, one per element of
            ``scores_batch``.

        Raises:
            ValueError: if no configuration has been assigned yet.
        """
        # Fail fast: nothing can be classified without the intervals, so do
        # not run any conversion function before this check.
        configuration = self.configuration
        if configuration is None:
            raise ValueError("Confidence estimator configuration is None")

        # TODO: rework signatures of confidence normal/abnormal/unknown
        # Each conversion is evaluated on the whole batch; the interval masks
        # below select which result is kept for each element.
        confidence = self._confidence_unknown(scores_batch)
        normal_confidences = self._confidence_normal(scores_batch)
        abnormal_confidences = self._confidence_abnormal(scores_batch)

        confidence = torch.where(
            configuration.normal.contains_tensor_mask(scores_batch),
            normal_confidences,
            confidence,
        )
        # Applied last, so the abnormal conversion wins for any score that
        # falls in both intervals.
        confidence = torch.where(
            configuration.abnormal.contains_tensor_mask(scores_batch),
            abnormal_confidences,
            confidence,
        )

        return confidence
@@ -0,0 +1,67 @@
1
+ # Software Name : Self-Supervised Anomaly Detection
2
+ # SPDX-FileCopyrightText: Copyright (c) Orange SA
3
+ # SPDX-License-Identifier: MIT
4
+ #
5
+ # This software is distributed under the MIT License,
6
+ # see the "LICENSE.txt" file for more details or https://spdx.org/licenses/MIT.html
7
+ #
8
+ # Authors: see CONTRIBUTORS
9
+ # Software description: A Python library for autoencoder-based anomaly detection
10
+ # based on self-supervised training with dynamic sample confidence updates.
11
+ """
12
+ A Confidence Intervals configuration consists of a set of three intervals
13
+ for the value of the analyzed distribution (e.g., reconstruction scores).
14
+ These three intervals respectively define the intervals that correspond to
15
+ normal samples, abnormal samples, and samples for which it is impossible to decide,
16
+ labelled "unknown".
17
+ """
18
+
19
+ from dataclasses import dataclass
20
+
21
+ import pandas as pd
22
+ import torch
23
+
24
+
25
class Interval(pd.Interval):
    """pandas ``Interval`` extended with an element-wise membership test
    for tensors.
    """

    def contains_tensor_mask(self, item: torch.Tensor) -> torch.Tensor:
        """Compute the mask of the tensor entries that lie within the interval.

        Args:
            item (torch.Tensor): tensor whose mask is to be computed.

        Returns:
            torch.Tensor: boolean tensor with the same shape as ``item``, True
            where the corresponding value of ``item`` is inside the interval.
            Each bound is inclusive only when the interval is closed on that
            side.
        """
        # Lower bound: inclusive when closed on the left, strict otherwise.
        above_left = item >= self.left if self.closed_left else item > self.left
        # Upper bound: inclusive when closed on the right, strict otherwise.
        # The open case must test `<`: testing `>` would select the values
        # *above* the interval and make the combined mask wrong.
        below_right = item <= self.right if self.closed_right else item < self.right
        return torch.logical_and(above_left, below_right)
45
+
46
+
47
@dataclass
class ConfidenceIntervalsConfiguration:
    """Intervals configuration of a confidence estimator.

    Holds one interval per decision: ``normal``, ``abnormal`` and ``unknown``.
    """

    normal: Interval
    abnormal: Interval
    unknown: Interval

    def as_dict(self) -> dict[str, str]:
        """Provide a dict representation of the intervals configuration.

        Returns:
            dict[str, str]: maps "normal", "abnormal" and "unknown" (in that
            order) to the string representation of the matching interval.
        """
        return {
            name: str(getattr(self, name))
            for name in ("normal", "abnormal", "unknown")
        }
@@ -0,0 +1,49 @@
1
+ # Software Name : Self-Supervised Anomaly Detection
2
+ # SPDX-FileCopyrightText: Copyright (c) Orange SA
3
+ # SPDX-License-Identifier: MIT
4
+ #
5
+ # This software is distributed under the MIT License,
6
+ # see the "LICENSE.txt" file for more details or https://spdx.org/licenses/MIT.html
7
+ #
8
+ # Authors: see CONTRIBUTORS
9
+ # Software description: A Python library for autoencoder-based anomaly detection
10
+ # based on self-supervised training with dynamic sample confidence updates.
11
+ from typing import List, Optional
12
+
13
+ import torch
14
+
15
+ from ssad.datamodules.dataset_with_confidence import DatasetWithConfidence
16
+ from .confidence_estimator import BaseConfidenceEstimator, SupportsConfidenceEstimation
17
+
18
+
19
class HybridEstimator(SupportsConfidenceEstimation):
    """Combines several confidence estimators through a weighted sum."""

    def __init__(
        self,
        training_dataset: DatasetWithConfidence,
        estimators: List[BaseConfidenceEstimator],
        weights: Optional[List[float]] = None,
    ):
        """Build the hybrid estimator.

        Args:
            training_dataset (DatasetWithConfidence): dataset kept on the
                instance as ``training_dataset``.
            estimators (List[BaseConfidenceEstimator]): estimators whose
                confidences are combined; must not be empty.
            weights (Optional[List[float]]): one weight per estimator. When
                omitted, every estimator gets the weight ``1 / len(estimators)``.

        Raises:
            ValueError: if ``estimators`` is ``None`` or empty, or if the
                number of weights differs from the number of estimators.
        """
        # `not estimators` also rejects an empty list, which the error message
        # already announced but a bare `is None` test let through.
        if not estimators:
            raise ValueError(
                "Empty estimator list during HybridEstimator construction."
            )

        if weights is None:
            self.weights = torch.ones(len(estimators)) / len(estimators)
        else:
            self.weights = torch.as_tensor(weights, dtype=torch.float32)

        if len(self.weights) != len(estimators):
            raise ValueError("Different number of weights and confidence estimators")

        self.training_dataset = training_dataset
        self.estimators = estimators
        # The SupportsConfidenceEstimation protocol declares a `configuration`
        # attribute. The hybrid has no interval configuration of its own, so
        # expose None instead of leaving the attribute undefined.
        self.configuration = None

    def estimate_confidence(self, scores_batch: torch.Tensor) -> torch.Tensor:
        """Compute the weighted sum of the estimators' confidences.

        Args:
            scores_batch (torch.Tensor): scores handed to every estimator.

        Returns:
            torch.Tensor: combined confidence, shaped like ``scores_batch``.
        """
        # Size the accumulator from the batch being scored, not from the
        # training dataset: the two only coincide when the whole dataset is
        # scored at once.
        confidence = torch.zeros_like(scores_batch, dtype=self.weights.dtype)
        for weight, estimator in zip(self.weights, self.estimators):
            confidence += weight * estimator.estimate_confidence(scores_batch)
        return confidence
@@ -0,0 +1,26 @@
1
+ # Software Name : Self-Supervised Anomaly Detection
2
+ # SPDX-FileCopyrightText: Copyright (c) Orange SA
3
+ # SPDX-License-Identifier: MIT
4
+ #
5
+ # This software is distributed under the MIT License,
6
+ # see the "LICENSE.txt" file for more details or https://spdx.org/licenses/MIT.html
7
+ #
8
+ # Authors: see CONTRIBUTORS
9
+ # Software description: A Python library for autoencoder-based anomaly detection
10
+ # based on self-supervised training with dynamic sample confidence updates.
11
+ from typing import Optional, Protocol
12
+
13
+ import torch
14
+
15
+ from ssad.confidence_estimators.confidence_intervals_configuration import (
16
+ ConfidenceIntervalsConfiguration,
17
+ )
18
+
19
+
20
class SupportsConfidenceEstimation(Protocol):
    """Structural interface of a confidence estimator, as used in a
    SelfSupervisionCallback.
    """

    # Intervals configuration of the estimator; may be None.
    configuration: Optional[ConfidenceIntervalsConfiguration]

    def estimate_confidence(self, scores_batch: torch.Tensor) -> torch.Tensor:
        """Compute the confidence scores of a batch of scores."""
        ...
@@ -0,0 +1,10 @@
1
+ # Software Name : Self-Supervised Anomaly Detection
2
+ # SPDX-FileCopyrightText: Copyright (c) Orange SA
3
+ # SPDX-License-Identifier: MIT
4
+ #
5
+ # This software is distributed under the MIT License,
6
+ # see the "LICENSE.txt" file for more details or https://spdx.org/licenses/MIT.html
7
+ #
8
+ # Authors: see CONTRIBUTORS
9
+ # Software description: A Python library for autoencoder-based anomaly detection
10
+ # based on self-supervised training with dynamic sample confidence updates.
@@ -0,0 +1,27 @@
1
+ # Software Name : Self-Supervised Anomaly Detection
2
+ # SPDX-FileCopyrightText: Copyright (c) Orange SA
3
+ # SPDX-License-Identifier: MIT
4
+ #
5
+ # This software is distributed under the MIT License,
6
+ # see the "LICENSE.txt" file for more details or https://spdx.org/licenses/MIT.html
7
+ #
8
+ # Authors: see CONTRIBUTORS
9
+ # Software description: A Python library for autoencoder-based anomaly detection
10
+ # based on self-supervised training with dynamic sample confidence updates.
11
+ from .dataframe_with_labels import DataFrameWithLabels
12
+ from .self_supervision_datamodule import SelfSupervisionDataModule
13
+ from .dataset_with_confidence import (
14
+ DatasetWithConfidence,
15
+ init_confidence_from_csv,
16
+ save_confidence_to_csv,
17
+ )
18
+ from.transforms.dataframe_to_tensor import DataFrameToTensor
19
+
20
+ __all__ = [
21
+ "SelfSupervisionDataModule",
22
+ "DatasetWithConfidence",
23
+ "init_confidence_from_csv",
24
+ "save_confidence_to_csv",
25
+ "DataFrameWithLabels",
26
+ "DataFrameToTensor",
27
+ ]
@@ -0,0 +1,95 @@
1
+ # Software Name : Self-Supervised Anomaly Detection
2
+ # SPDX-FileCopyrightText: Copyright (c) Orange SA
3
+ # SPDX-License-Identifier: MIT
4
+ #
5
+ # This software is distributed under the MIT License,
6
+ # see the "LICENSE.txt" file for more details or https://spdx.org/licenses/MIT.html
7
+ #
8
+ # Authors: see CONTRIBUTORS
9
+ # Software description: A Python library for autoencoder-based anomaly detection
10
+ # based on self-supervised training with dynamic sample confidence updates.
11
+ """
12
+ Defines the torch.Dataset specialization for
13
+ pd.DataFrames with a 'label' column
14
+ """
15
+
16
+ import torch
17
+ from torch.utils.data import Dataset
18
+
19
+
20
class DataFrameWithLabels(Dataset):
    """torch ``Dataset`` built from a DataFrame that has one label column.

    At construction the label column is split from the other columns and both
    parts are converted to ``float32`` tensors (``labels`` and ``data``).
    """

    def __init__(
        self,
        data,
        label_column_name: str,
        transform=None,
        target_transform=None,
    ):
        """Split ``data`` into feature and label tensors.

        Args:
            data: DataFrame holding the features and the label column.
            label_column_name (str): name of the column that holds the labels.
            transform: optional callable applied to a sample's features in
                ``__getitem__``.
            target_transform: optional callable applied to a sample's label in
                ``__getitem__``.
        """
        super().__init__()
        labels = data[label_column_name]
        features = data.drop(columns=[label_column_name])

        self.data = torch.tensor(features.values, dtype=torch.float32)
        self.labels = torch.tensor(labels.values, dtype=torch.float32)

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at ``idx``.

        The label gets an extra leading dimension (``unsqueeze(0)``) before the
        optional transforms are applied.
        """
        data = self.data[idx]
        label = self.labels[idx].unsqueeze(0)

        if self.transform:
            data = self.transform(data)
        if self.target_transform:
            label = self.target_transform(label)
        return data, label

    def input_dim(self) -> int:
        """Return the number of features of a sample.

        Returns:
            int: size of the second dimension of the feature tensor.
        """
        return self.data.shape[1]

    def collate(
        self, batch: list[tuple[torch.Tensor, torch.Tensor]]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Collate function, used to transform a batch into tensors.

        Args:
            batch: ``(features, label)`` pairs, as returned by ``__getitem__``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: the stacked features and the
            stacked labels.
        """
        data, labels = zip(*batch)
        data_tensor = torch.stack(data)
        labels_tensor = torch.stack(labels)
        return data_tensor, labels_tensor

    def get_stats(self) -> tuple:
        """Count the samples of the dataset.

        Assumes that labels are 0 for normal samples and 1 for anomalies.

        Returns:
            tuple: (total, normal, anomaly)
        """
        total = len(self.labels)
        normal = int((self.labels == 0).sum().item())
        anomaly = int((self.labels == 1).sum().item())
        return total, normal, anomaly
@@ -0,0 +1,47 @@
1
+ # Software Name : Self-Supervised Anomaly Detection
2
+ # SPDX-FileCopyrightText: Copyright (c) Orange SA
3
+ # SPDX-License-Identifier: MIT
4
+ #
5
+ # This software is distributed under the MIT License,
6
+ # see the "LICENSE.txt" file for more details or https://spdx.org/licenses/MIT.html
7
+ #
8
+ # Authors: see CONTRIBUTORS
9
+ # Software description: A Python library for autoencoder-based anomaly detection
10
+ # based on self-supervised training with dynamic sample confidence updates.
11
+ from typing import Protocol, Any, Tuple, runtime_checkable
12
+
13
+
14
@runtime_checkable
class DatasetWithLabels(Protocol):
    """Structural type of a labelled dataset.

    Attributes:
        data (Any): data used by a model for its task.
        labels (Any): ground truth for the data.

    Besides the two attributes, the dataset exposes ``__len__`` and
    ``__getitem__``; the latter returns a 2-tuple.
    """

    data: Any
    labels: Any

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        ...

    def __len__(self) -> int:
        ...
28
+
29
+
30
@runtime_checkable
class DatasetWithInputDim(Protocol):
    """Structural type of a dataset that provides ``input_dim``.

    The method is necessary to initialize the input layer of an autoencoder.
    """

    # NOTE: a `collate(batch)` member was sketched here by the original author
    # but is disabled; it is not part of this interface.

    def input_dim(self) -> int:
        """Returns the dimension of the data"""
        ...
@@ -0,0 +1,101 @@
1
+ # Software Name : Self-Supervised Anomaly Detection
2
+ # SPDX-FileCopyrightText: Copyright (c) Orange SA
3
+ # SPDX-License-Identifier: MIT
4
+ #
5
+ # This software is distributed under the MIT License,
6
+ # see the "LICENSE.txt" file for more details or https://spdx.org/licenses/MIT.html
7
+ #
8
+ # Authors: see CONTRIBUTORS
9
+ # Software description: A Python library for autoencoder-based anomaly detection
10
+ # based on self-supervised training with dynamic sample confidence updates.
11
+ """
12
+ Defines a torch.Dataset wrapper that attaches a confidence score to each sample,
13
+ and helpers to load and save these confidence scores as csv files.
14
+ """
15
+
16
+ from collections.abc import Sized
17
+ from pathlib import Path
18
+ from typing import Optional
19
+
20
+ import pandas as pd
21
+ import torch
22
+ from torch.utils.data import Dataset
23
+
24
+ from .dataset_interfaces import DatasetWithLabels
25
+
26
+ CONFIDENCE_COLUMN_NAME = "confidence"
27
+
28
+
29
class DatasetWithConfidence(Dataset):
    """torch ``Dataset`` wrapper that attaches a confidence score to each sample.

    Wraps a labelled dataset together with a confidence tensor whose first
    dimension has one entry per sample of the wrapped dataset.
    """

    def __init__(self, dataset: DatasetWithLabels, confidence: torch.Tensor):
        """Store the dataset and its confidence tensor.

        Args:
            dataset (DatasetWithLabels): wrapped dataset; must support ``len()``.
            confidence (torch.Tensor): confidence scores, one per sample along
                dimension 0.

        Raises:
            ValueError: if the dataset has no length, or if its length differs
                from the size of ``confidence`` along dimension 0.
        """
        sizes_match = isinstance(dataset, Sized) and len(dataset) == confidence.size(
            dim=0
        )
        if not sizes_match:
            raise ValueError("Size mismatch between dataset and confidence tensor.")

        self.dataset = dataset
        self.confidence = confidence

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict with the keys
        "data", "label" and "confidence".
        """
        # TODO check if self.dataset[idx] is a tuple or a list
        # in case there is no label
        sample = self.dataset[idx]
        features = sample[0]
        label = sample[1]
        return {
            "data": features,
            "label": label,
            "confidence": self.confidence[idx],
        }
57
+
58
+
59
def init_confidence_from_csv(dataset: Dataset, path: Optional[Path]) -> torch.Tensor:
    """Build the initial confidence tensor of a dataset.

    Args:
        dataset (Dataset): dataset whose samples receive a confidence score;
            must support ``len()``.
        path (Optional[Path]): csv file with a "confidence" column holding one
            score per sample. When ``None``, every sample starts with a
            confidence of 1.

    Returns:
        confidence (torch.Tensor): tensor with initial confidence values.

    Raises:
        TypeError: if ``dataset`` does not support ``len()``.
        ValueError: if the csv file has no "confidence" column, or if its
            number of rows differs from the size of the dataset.
    """
    if not isinstance(dataset, Sized):
        raise TypeError("Provided dataset does not support the len() method.")

    if path is None:
        return torch.ones(len(dataset))

    confidence_df = pd.read_csv(path)

    if CONFIDENCE_COLUMN_NAME not in confidence_df:
        raise ValueError(f"Missing column: {CONFIDENCE_COLUMN_NAME}")

    # Report a stale or truncated file here, where the file is known, rather
    # than later when the tensor is paired with the dataset.
    if len(confidence_df) != len(dataset):
        raise ValueError(
            f"Size mismatch between dataset ({len(dataset)} samples) and "
            f"confidence file ({len(confidence_df)} rows)."
        )

    # torch.tensor copies the values. torch.from_numpy would share memory with
    # the array owned by the DataFrame, which may be read-only and must not be
    # written to by later in-place confidence updates.
    return torch.tensor(confidence_df[CONFIDENCE_COLUMN_NAME].to_numpy())
83
+
84
+
85
def save_confidence_to_csv(confidence_dataset: DatasetWithConfidence, path: Path):
    """Save the confidence scores of a dataset to a csv file.

    The confidence tensor of the dataset is detached, moved to the CPU and
    converted to a numpy array, which is written through a pd.DataFrame with a
    single "confidence" column.

    Args:
        confidence_dataset (DatasetWithConfidence): dataset with confidence scores to save.
        path (Path): path to csv file.
    """
    scores = confidence_dataset.confidence.detach().cpu().numpy()
    frame = pd.DataFrame(scores, columns=[CONFIDENCE_COLUMN_NAME])
    frame.to_csv(path)