aisp 0.1.34__py3-none-any.whl → 0.1.40__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aisp/__init__.py +4 -0
- aisp/base/__init__.py +4 -0
- aisp/base/_classifier.py +90 -0
- aisp/exceptions.py +42 -0
- aisp/nsa/__init__.py +11 -0
- aisp/nsa/_base.py +118 -0
- aisp/nsa/_negative_selection.py +682 -0
- aisp/nsa/_ns_core.py +153 -0
- aisp/utils/__init__.py +2 -1
- aisp/utils/_multiclass.py +16 -30
- aisp/utils/distance.py +215 -0
- aisp/utils/metrics.py +22 -43
- aisp/utils/sanitizers.py +55 -0
- {aisp-0.1.34.dist-info → aisp-0.1.40.dist-info}/METADATA +11 -111
- aisp-0.1.40.dist-info/RECORD +18 -0
- {aisp-0.1.34.dist-info → aisp-0.1.40.dist-info}/WHEEL +1 -1
- aisp/NSA/__init__.py +0 -18
- aisp/NSA/_base.py +0 -281
- aisp/NSA/_negative_selection.py +0 -1115
- aisp-0.1.34.dist-info/RECORD +0 -11
- {aisp-0.1.34.dist-info → aisp-0.1.40.dist-info}/licenses/LICENSE +0 -0
- {aisp-0.1.34.dist-info → aisp-0.1.40.dist-info}/top_level.txt +0 -0
aisp/nsa/_ns_core.py
ADDED
@@ -0,0 +1,153 @@
|
|
1
|
+
"""ns: Negative Selection
|
2
|
+
|
3
|
+
The functions perform detector checks and utilize Numba decorators for Just-In-Time compilation
|
4
|
+
"""
|
5
|
+
|
6
|
+
import numpy.typing as npt
|
7
|
+
from numba import njit, types
|
8
|
+
|
9
|
+
from ..utils.distance import compute_metric_distance, hamming
|
10
|
+
|
11
|
+
|
12
|
+
@njit([(types.boolean[:, :], types.boolean[:], types.float64)], cache=True)
def check_detector_bnsa_validity(
    x_class: npt.NDArray,
    vector_x: npt.NDArray,
    aff_thresh: float
) -> bool:
    """
    Decide whether a binary candidate detector may be kept for a class.

    The candidate is rejected as soon as one sample of ``x_class`` lies within the
    affinity threshold of it, i.e. when the normalized Hamming distance between that
    sample and ``vector_x`` is less than or equal to ``aff_thresh``. A candidate whose
    length differs from the number of columns of ``x_class`` is rejected as well.

    Parameters
    ----------
    * x_class (``npt.NDArray``): Boolean matrix of class samples.
        Expected shape: (n_samples, n_features).
    * vector_x (``npt.NDArray``): Boolean candidate detector. Expected shape: (n_features,).
    * aff_thresh (``float``): Affinity threshold compared with the normalized Hamming
        distance.

    Returns
    ----------
    * True if the detector is valid (no sample within the threshold), False otherwise.
    """
    n_samples, n_features = x_class.shape
    if vector_x.shape[0] != n_features:
        # Mismatched dimensions: the candidate cannot be compared, so discard it.
        return False

    for row in range(n_samples):
        too_close = hamming(x_class[row], vector_x) <= aff_thresh
        if too_close:
            return False
    return True
|
50
|
+
|
51
|
+
|
52
|
+
@njit(
    [(
        types.boolean[:],
        types.boolean[:, :, :],
        types.float64
    )],
    cache=True
)
def bnsa_class_prediction(
    features: npt.NDArray,
    class_detectors: npt.NDArray,
    aff_thresh: float
) -> int:
    """
    Defines the class of a sample from the non-self detectors.

    For each class, the sample is compared with every detector of that class using the
    normalized Hamming distance. If any detector is within ``aff_thresh`` of the sample
    (distance <= ``aff_thresh``), that detector recognizes the sample as non-self and the
    class is discarded. Among the remaining classes (those for which the sample is self),
    the class whose detectors are, on average, farthest from the sample is returned; on a
    tie the lowest class index wins.

    Parameters
    ----------
    * features (``npt.NDArray``): binary sample to be classified (shape: [n_features]).
    * class_detectors (``npt.NDArray``): Array containing the detectors of all classes
        (shape: [n_classes, n_detectors, n_features]).
    * aff_thresh (``float``): Affinity threshold that determines whether a detector recognizes the
        sample as non-self.

    Returns
    ----------
    * int: Index of the predicted class. Returns -1 if the sample is non-self for all
        classes, or if ``class_detectors`` holds no detectors per class.
    """
    n_classes, n_detectors, _ = class_detectors.shape
    # With zero detectors per class the average below would divide by zero; there is no
    # evidence to prefer any class, so report "no class".
    if n_detectors == 0:
        return -1

    best_class_idx = -1
    # Hamming distances are never negative, so -1.0 guarantees that the first class for
    # which the sample is self is accepted even when its average distance is 0.0
    # (possible when aff_thresh < 0). A float literal also keeps the numba type stable.
    best_avg_distance = -1.0

    for class_index in range(n_classes):
        total_distance = 0.0
        class_found = True

        for detector_index in range(n_detectors):
            # Normalized Hamming distance between the sample and this detector.
            distance = hamming(features, class_detectors[class_index][detector_index])

            # A detector within the threshold recognizes the sample as non-self, so this
            # class cannot be the sample's class.
            if distance <= aff_thresh:
                class_found = False
                break
            total_distance += distance

        # The sample is self for this class: keep the class with the largest average
        # distance to its detectors.
        if class_found:
            avg_distance = total_distance / n_detectors
            if avg_distance > best_avg_distance:
                best_avg_distance = avg_distance
                best_class_idx = class_index

    return best_class_idx
|
109
|
+
|
110
|
+
|
111
|
+
@njit(
    [(
        types.float64[:, :], types.float64[:],
        types.float64, types.int32, types.float64
    )],
    cache=True
)
def check_detector_rnsa_validity(
    x_class: npt.NDArray,
    vector_x: npt.NDArray,
    threshold: float,
    metric: int,
    p: float
) -> bool:
    """
    Checks the validity of a real-valued candidate detector (vector_x) against samples
    from a class (x_class) using the distance selected by ``metric``. A detector is
    considered INVALID if its distance to any sample in ``x_class`` is less than or equal
    to ``threshold``, or if its length differs from the number of features of ``x_class``.

    Parameters
    ----------
    * x_class (``npt.NDArray``): Array containing the class samples. Expected shape:
        (n_samples, n_features).
    * vector_x (``npt.NDArray``): Array representing the detector. Expected shape: (n_features,).
    * threshold (``float``): Minimum distance the detector must keep from every sample;
        a distance less than or equal to it invalidates the detector.
    * metric (``int``): Integer code of the distance metric, as accepted by
        ``compute_metric_distance``: [0 (Euclidean), 1 (Manhattan), 2 (Minkowski)].
        Any other code falls back to the Euclidean distance.
    * p (``float``): Parameter for the Minkowski distance (used only if ``metric``
        is 2, Minkowski).

    Returns
    ----------
    * True if the detector is valid, False otherwise.
    """
    n = x_class.shape[1]
    if n != vector_x.shape[0]:
        return False

    for i in range(x_class.shape[0]):
        distance = compute_metric_distance(vector_x, x_class[i], metric, p)
        # Too close to a self sample: the detector would recognize it as non-self.
        if distance <= threshold:
            return False
    return True
|
aisp/utils/__init__.py
CHANGED
aisp/utils/_multiclass.py
CHANGED
@@ -1,3 +1,5 @@
|
|
1
|
+
"""Utility functions for handling classes with multiple categories."""
|
2
|
+
|
1
3
|
from typing import Union
|
2
4
|
import numpy as np
|
3
5
|
import numpy.typing as npt
|
@@ -5,37 +7,21 @@ import numpy.typing as npt
|
|
5
7
|
|
6
8
|
def slice_index_list_by_class(classes: Union[npt.NDArray, list], y: npt.NDArray) -> dict:
|
7
9
|
"""
|
8
|
-
The function ``
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
Parameters
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
returns
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
---
|
23
|
-
|
24
|
-
A função ``__slice_index_list_by_class(...)``, separa os índices das linhas conforme a \
|
25
|
-
classe de saída, para percorrer o array de amostra, apenas nas posições que a saída for \
|
26
|
-
a classe que está sendo treinada.
|
27
|
-
|
28
|
-
Parameters:
|
29
|
-
---
|
30
|
-
* classes (``list or npt.NDArray``): lista com classes únicas.
|
31
|
-
* y (npt.NDArray): Recebe um array ``y``[``N amostra``] com as classes de saída do \
|
32
|
-
array de amostra ``X``.
|
33
|
-
|
34
|
-
Returns:
|
35
|
-
---
|
36
|
-
* dict: Um dicionário com a lista de posições do array(``y``), com as classes como chave.
|
10
|
+
The function ``slice_index_list_by_class(...)``, separates the indices of the lines according
|
11
|
+
to the output class, to loop through the sample array, only in positions where the output is the
|
12
|
+
class being trained.
|
13
|
+
|
14
|
+
Parameters
|
15
|
+
----------
|
16
|
+
* classes (``list or npt.NDArray``): list with unique classes.
|
17
|
+
* y (``npt.NDArray``): Receives a ``y``[``N sample``] array with the output classes of the
|
18
|
+
``X`` sample array.
|
19
|
+
|
20
|
+
returns
|
21
|
+
----------
|
22
|
+
* dict: A dictionary with the list of array positions(``y``), with the classes as key.
|
37
23
|
"""
|
38
|
-
position_samples =
|
24
|
+
position_samples = {}
|
39
25
|
for _class_ in classes:
|
40
26
|
# Gets the sample positions by class from y.
|
41
27
|
position_samples[_class_] = list(np.nonzero(y == _class_)[0])
|
aisp/utils/distance.py
ADDED
@@ -0,0 +1,215 @@
|
|
1
|
+
"""Utility functions for normalized distance between arrays with numba decorators."""
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
import numpy.typing as npt
|
5
|
+
from numba import njit, types
|
6
|
+
|
7
|
+
# Integer codes identifying the distance metrics. The numba-compiled functions receive
# these codes instead of metric names; ``get_metric_code`` maps names to codes.
EUCLIDEAN, MANHATTAN, MINKOWSKI, HAMMING = 0, 1, 2, 3
|
11
|
+
|
12
|
+
|
13
|
+
@njit([(types.boolean[:], types.boolean[:])], cache=True)
def hamming(u: npt.NDArray[np.bool_], v: npt.NDArray[np.bool_]) -> np.float64:
    """
    Function to calculate the normalized Hamming distance between two points.

        ((u₁ ≠ v₁) + (u₂ ≠ v₂) + ... + (uₙ ≠ vₙ)) / n

    i.e. the fraction of positions at which the two boolean vectors differ, a value
    in [0, 1].

    Parameters
    ----------
    * u (``npt.NDArray``): Coordinates of the first point (boolean vector).
    * v (``npt.NDArray``): Coordinates of the second point (boolean vector); expected to
        have the same length as ``u``.

    returns
    ----------
    * Distance (``float``) between the two points. Returns 0.0 for empty vectors.
    """
    n = len(u)
    if n == 0:
        # Two empty vectors have no differing positions.
        return 0.0

    return np.sum(u != v) / n
|
34
|
+
|
35
|
+
|
36
|
+
@njit()
def euclidean(u: npt.NDArray[np.float64], v: npt.NDArray[np.float64]) -> np.float64:
    """
    Function to calculate the Euclidean distance between two points.

        √((u₁ − v₁)² + (u₂ − v₂)² + ... + (uₙ − vₙ)²)

    Note: unlike ``cityblock`` and ``minkowski`` in this module, the result is NOT
    divided by the number of dimensions ``n``, i.e. it is not normalized.

    Parameters
    ----------
    * u (``npt.NDArray``): Coordinates of the first point.
    * v (``npt.NDArray``): Coordinates of the second point.

    returns
    ----------
    * Distance (``float``) between the two points.
    """
    return np.linalg.norm(u - v)
|
53
|
+
|
54
|
+
|
55
|
+
@njit()
def cityblock(u: npt.NDArray[np.float64], v: npt.NDArray[np.float64]) -> np.float64:
    """
    Function to calculate the normalized Manhattan distance between two points.

        (|u₁ − v₁| + |u₂ − v₂| + ... + |uₙ − vₙ|) / n

    i.e. the mean absolute difference between the coordinates.

    Parameters
    ----------
    * u (``npt.NDArray``): Coordinates of the first point.
    * v (``npt.NDArray``): Coordinates of the second point.

    returns
    ----------
    * Distance (``float``) between the two points. Returns -1.0 for empty vectors;
        this is a sentinel, not a distance.
    """
    n = len(u)
    if n == 0:
        # Sentinel for empty input (avoids dividing by zero below).
        return -1.0

    return np.sum(np.abs(u - v)) / n
|
76
|
+
|
77
|
+
|
78
|
+
@njit()
def minkowski(u: npt.NDArray[np.float64], v: npt.NDArray[np.float64], p: float = 2.0):
    """
    Function to calculate the normalized Minkowski distance between two points.

        ((|u₁ − v₁|ᵖ + |u₂ − v₂|ᵖ + ... + |uₙ − vₙ|ᵖ) ^ (1/p)) / n

    Parameters
    ----------
    * u (``npt.NDArray``): Coordinates of the first point.
    * v (``npt.NDArray``): Coordinates of the second point.
    * p (``float``): Order of the distance; must be non-zero because the sum is raised
        to ``1 / p``. Because the result is divided by ``n``:
        - p = 1: equals ``cityblock`` (mean absolute difference).
        - p = 2: Euclidean norm divided by ``n``. This differs from ``euclidean``,
          which is not normalized.
        - p > 2: larger coordinate differences are weighted increasingly heavily.

    returns
    ----------
    * Distance (``float``) between the two points. Returns -1.0 for empty vectors;
        this is a sentinel, not a distance.
    """
    n = len(u)
    if n == 0:
        # Sentinel for empty input (avoids dividing by zero below).
        return -1.0

    return (np.sum(np.abs(u - v) ** p) ** (1 / p)) / n
|
103
|
+
|
104
|
+
|
105
|
+
@njit([(types.float64[:], types.float64[:], types.int32, types.float64)], cache=True)
def compute_metric_distance(
    u: npt.NDArray[np.float64],
    v: npt.NDArray[np.float64],
    metric: int,
    p: np.float64 = 2.0
) -> np.float64:
    """
    Dispatch to the distance function selected by the integer ``metric`` code.

    Parameters
    ----------
    * u (``npt.NDArray``): Coordinates of the first point.
    * v (``npt.NDArray``): Coordinates of the second point.
    * metric (``int``): Metric code. Supported options:
        [0 (Euclidean), 1 (Manhattan), 2 (Minkowski)]. Every other code, including
        3 (Hamming), falls back to the Euclidean distance.
    * p (``float``): Minkowski order, only used when ``metric`` is 2 (Minkowski).

    returns
    ----------
    * Distance (``double``) between the two points under the selected metric.
    """
    if metric == MINKOWSKI:
        return minkowski(u, v, p)
    elif metric == MANHATTAN:
        return cityblock(u, v)
    else:
        # EUCLIDEAN and any unrecognised code use the Euclidean distance.
        return euclidean(u, v)
|
140
|
+
|
141
|
+
|
142
|
+
@njit(
    [(
        types.float64[:, :], types.float64[:],
        types.int32, types.float64
    )],
    cache=True
)
def min_distance_to_class_vectors(
    x_class: npt.NDArray[np.float64],
    vector_x: npt.NDArray[np.float64],
    metric: int,
    p: float = 2.0
) -> float:
    """
    Calculates the minimum distance between an input vector and the vectors of a class.

    Parameters
    ----------
    * x_class (``npt.NDArray``): Array containing the class vectors to be compared
        with the input vector. Expected shape: (n_samples, n_features).
    * vector_x (``npt.NDArray``): Vector to be compared with the class vectors.
        Expected shape: (n_features,).
    * metric (``int``): Integer code of the distance metric, as accepted by
        ``compute_metric_distance``: [0 (Euclidean), 1 (Manhattan), 2 (Minkowski)].
        Any other code falls back to the Euclidean distance.
    * p (``float``): Parameter for the Minkowski distance (used only if ``metric``
        is 2, Minkowski).

    Returns
    ----------
    * float: The minimum distance calculated between the input vector and the class vectors.
    * Returns -1.0 if the input dimensions are incompatible.
    * Returns ``np.inf`` if ``x_class`` contains no vectors.
    """
    n = x_class.shape[1]
    if n != vector_x.shape[0]:
        return -1.0

    # Start from +inf so the first computed distance always replaces it.
    min_distance = np.inf
    for i in range(x_class.shape[0]):
        distance = compute_metric_distance(vector_x, x_class[i], metric, p)
        min_distance = min(min_distance, distance)

    return min_distance
|
184
|
+
|
185
|
+
|
186
|
+
def get_metric_code(metric: str) -> int:
    """
    Returns the numeric code associated with a distance metric.

    Parameters
    ----------
    * metric (str): Name of the metric. Can be "euclidean", "manhattan", "minkowski" or
        "hamming". Matching ignores letter case and surrounding whitespace.

    Raises
    ----------
    * ValueError: If the metric provided is not supported.

    Returns
    ----------
    * int: Numeric code corresponding to the metric.
    """
    metric_map = {
        "euclidean": EUCLIDEAN,
        "manhattan": MANHATTAN,
        "minkowski": MINKOWSKI,
        "hamming": HAMMING
    }

    # Accept inputs such as " Euclidean " by normalizing case and whitespace.
    normalized_metric = metric.strip().lower()

    if normalized_metric not in metric_map:
        # Quote every name individually: "'euclidean', 'manhattan', ...".
        supported = ", ".join(f"'{name}'" for name in metric_map)
        raise ValueError(f"Unknown metric: '{metric}'. Supported: {supported}")

    return metric_map[normalized_metric]
|
aisp/utils/metrics.py
CHANGED
@@ -1,61 +1,40 @@
|
|
1
|
+
"""Utility functions for measuring accuracy and performance."""
|
1
2
|
from typing import Union
|
3
|
+
|
2
4
|
import numpy as np
|
3
5
|
import numpy.typing as npt
|
4
6
|
|
5
7
|
|
6
8
|
def accuracy_score(
|
7
|
-
|
8
|
-
|
9
|
+
y_true: Union[npt.NDArray, list],
|
10
|
+
y_pred: Union[npt.NDArray, list]
|
9
11
|
) -> float:
|
10
12
|
"""
|
11
|
-
Function to calculate
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
Raises:
|
28
|
-
---
|
29
|
-
* ValueError: If `y_true` or `y_pred` are empty or if they do not have the same length.
|
30
|
-
|
31
|
-
---
|
32
|
-
|
33
|
-
Função para calcular a acurácia de precisão com base em listas de rótulos
|
34
|
-
verdadeiros e nos rótulos previstos.
|
35
|
-
|
36
|
-
Parâmetros:
|
37
|
-
---
|
38
|
-
* y_true (``Union[npt.NDArray, list]``): Rótulos verdadeiros (corretos)..
|
39
|
-
* y_pred (``Union[npt.NDArray, list]``): Rótulos previstos.
|
40
|
-
|
41
|
-
Retornos:
|
42
|
-
---
|
43
|
-
* Precisão (``float``): A proporção de previsões corretas em relação
|
44
|
-
ao número total de previsões.
|
45
|
-
|
46
|
-
Lança:
|
47
|
-
---
|
48
|
-
* ValueError: Se `y_true` ou `y_pred` estiverem vazios ou se não
|
49
|
-
tiverem o mesmo tamanho.
|
13
|
+
Function to calculate the accuracy score based on true and predicted labels.
|
14
|
+
|
15
|
+
Parameters
|
16
|
+
----------
|
17
|
+
* y_true (``Union[npt.NDArray, list]``):
|
18
|
+
Ground truth (correct) labels. Expected to be of the same length as `y_pred`.
|
19
|
+
* y_pred (``Union[npt.NDArray, list]``):
|
20
|
+
Predicted labels. Expected to be of the same length as `y_true`.
|
21
|
+
|
22
|
+
Returns
|
23
|
+
----------
|
24
|
+
* float: The ratio of correct predictions to the total number of predictions.
|
25
|
+
|
26
|
+
Raises
|
27
|
+
----------
|
28
|
+
* ValueError: If `y_true` or `y_pred` are empty or if they do not have the same length.
|
50
29
|
"""
|
51
30
|
n = len(y_true)
|
52
31
|
if n == 0:
|
53
32
|
raise ValueError(
|
54
33
|
"Division by zero: y_true cannot be an empty list or array."
|
55
34
|
)
|
56
|
-
|
35
|
+
if n != len(y_pred):
|
57
36
|
raise ValueError(
|
58
37
|
f"Error: The arrays must have the same size. Size of y_true: "
|
59
38
|
f"{len(y_true)}, Size of y_pred: {len(y_pred)}"
|
60
39
|
)
|
61
|
-
return np.sum(np.
|
40
|
+
return np.sum(np.array(y_true) == np.array(y_pred)) / n
|
aisp/utils/sanitizers.py
ADDED
@@ -0,0 +1,55 @@
|
|
1
|
+
"""Utility functions for validation and treatment of parameters."""
|
2
|
+
import numbers
from typing import TypeVar, Iterable, Callable, Any, Optional
|
3
|
+
|
4
|
+
# Generic type variable shared by the sanitizer helpers: it ties the type of the
# checked value to the type of the fallback default and of the returned object.
T = TypeVar("T")
|
5
|
+
|
6
|
+
|
7
|
+
def sanitize_choice(value: T, valid_choices: Iterable[T], default: T) -> T:
    """
    Keep ``value`` only when it is one of the accepted options.

    Parameters
    ----------
    * value (``T``): Candidate value.
    * valid_choices (``Iterable[T]``): Accepted options; membership is tested with ``in``.
    * default (``T``): Fallback returned when ``value`` is not among ``valid_choices``.

    Returns
    ----------
    * ``T``: ``value`` when it belongs to ``valid_choices``, otherwise ``default``.
    """
    if value not in valid_choices:
        return default
    return value
|
23
|
+
|
24
|
+
|
25
|
+
def sanitize_param(value: T, default: T, condition: Callable[[T], bool]) -> T:
    """
    Keep ``value`` only when it passes a validation predicate.

    Parameters
    ----------
    * value (``T``): Candidate value.
    * default (``T``): Fallback returned when the predicate rejects ``value``.
    * condition (``Callable[[T], bool]``): Predicate called once with ``value``; a truthy
        result accepts it.

    Returns
    ----------
    * ``T``: ``value`` when ``condition(value)`` is truthy, otherwise ``default``.
    """
    if condition(value):
        return value
    return default
|
41
|
+
|
42
|
+
|
43
|
+
def sanitize_seed(seed: Any) -> Optional[int]:
    """
    Returns the seed as a Python ``int`` if it is a non-negative integer; otherwise, returns None.

    Any integer type registered as ``numbers.Integral`` is accepted. This includes ``int``
    and NumPy integer scalars such as ``np.int64``, which are converted to ``int``.
    ``bool`` is rejected even though it subclasses ``int``, because ``True``/``False``
    are not meaningful seeds.

    Parameters
    ----------
    * seed (``Any``): The seed value to be validated.

    Returns
    ----------
    * Optional[int]: The seed as an ``int`` if it is a non-negative integer, or None if it
        is invalid.
    """
    if isinstance(seed, bool) or not isinstance(seed, numbers.Integral):
        return None
    return int(seed) if seed >= 0 else None
|