stouputils 1.14.0__py3-none-any.whl → 1.14.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stouputils/__init__.pyi +15 -0
- stouputils/_deprecated.pyi +12 -0
- stouputils/all_doctests.pyi +46 -0
- stouputils/applications/__init__.pyi +2 -0
- stouputils/applications/automatic_docs.py +3 -0
- stouputils/applications/automatic_docs.pyi +106 -0
- stouputils/applications/upscaler/__init__.pyi +3 -0
- stouputils/applications/upscaler/config.pyi +18 -0
- stouputils/applications/upscaler/image.pyi +109 -0
- stouputils/applications/upscaler/video.pyi +60 -0
- stouputils/archive.pyi +67 -0
- stouputils/backup.pyi +109 -0
- stouputils/collections.pyi +86 -0
- stouputils/continuous_delivery/__init__.pyi +5 -0
- stouputils/continuous_delivery/cd_utils.pyi +129 -0
- stouputils/continuous_delivery/github.pyi +162 -0
- stouputils/continuous_delivery/pypi.pyi +52 -0
- stouputils/continuous_delivery/pyproject.pyi +67 -0
- stouputils/continuous_delivery/stubs.pyi +39 -0
- stouputils/ctx.pyi +211 -0
- stouputils/data_science/config/get.py +51 -51
- stouputils/data_science/data_processing/image/__init__.py +66 -66
- stouputils/data_science/data_processing/image/auto_contrast.py +79 -79
- stouputils/data_science/data_processing/image/axis_flip.py +58 -58
- stouputils/data_science/data_processing/image/bias_field_correction.py +74 -74
- stouputils/data_science/data_processing/image/binary_threshold.py +73 -73
- stouputils/data_science/data_processing/image/blur.py +59 -59
- stouputils/data_science/data_processing/image/brightness.py +54 -54
- stouputils/data_science/data_processing/image/canny.py +110 -110
- stouputils/data_science/data_processing/image/clahe.py +92 -92
- stouputils/data_science/data_processing/image/common.py +30 -30
- stouputils/data_science/data_processing/image/contrast.py +53 -53
- stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -74
- stouputils/data_science/data_processing/image/denoise.py +378 -378
- stouputils/data_science/data_processing/image/histogram_equalization.py +123 -123
- stouputils/data_science/data_processing/image/invert.py +64 -64
- stouputils/data_science/data_processing/image/laplacian.py +60 -60
- stouputils/data_science/data_processing/image/median_blur.py +52 -52
- stouputils/data_science/data_processing/image/noise.py +59 -59
- stouputils/data_science/data_processing/image/normalize.py +65 -65
- stouputils/data_science/data_processing/image/random_erase.py +66 -66
- stouputils/data_science/data_processing/image/resize.py +69 -69
- stouputils/data_science/data_processing/image/rotation.py +80 -80
- stouputils/data_science/data_processing/image/salt_pepper.py +68 -68
- stouputils/data_science/data_processing/image/sharpening.py +55 -55
- stouputils/data_science/data_processing/image/shearing.py +64 -64
- stouputils/data_science/data_processing/image/threshold.py +64 -64
- stouputils/data_science/data_processing/image/translation.py +71 -71
- stouputils/data_science/data_processing/image/zoom.py +83 -83
- stouputils/data_science/data_processing/image_augmentation.py +118 -118
- stouputils/data_science/data_processing/image_preprocess.py +183 -183
- stouputils/data_science/data_processing/prosthesis_detection.py +359 -359
- stouputils/data_science/data_processing/technique.py +481 -481
- stouputils/data_science/dataset/__init__.py +45 -45
- stouputils/data_science/dataset/dataset.py +292 -292
- stouputils/data_science/dataset/dataset_loader.py +135 -135
- stouputils/data_science/dataset/grouping_strategy.py +296 -296
- stouputils/data_science/dataset/image_loader.py +100 -100
- stouputils/data_science/dataset/xy_tuple.py +696 -696
- stouputils/data_science/metric_dictionnary.py +106 -106
- stouputils/data_science/mlflow_utils.py +206 -206
- stouputils/data_science/models/abstract_model.py +149 -149
- stouputils/data_science/models/all.py +85 -85
- stouputils/data_science/models/keras/all.py +38 -38
- stouputils/data_science/models/keras/convnext.py +62 -62
- stouputils/data_science/models/keras/densenet.py +50 -50
- stouputils/data_science/models/keras/efficientnet.py +60 -60
- stouputils/data_science/models/keras/mobilenet.py +56 -56
- stouputils/data_science/models/keras/resnet.py +52 -52
- stouputils/data_science/models/keras/squeezenet.py +233 -233
- stouputils/data_science/models/keras/vgg.py +42 -42
- stouputils/data_science/models/keras/xception.py +38 -38
- stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -20
- stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -219
- stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -148
- stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -31
- stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -249
- stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -66
- stouputils/data_science/models/keras_utils/losses/__init__.py +12 -12
- stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -56
- stouputils/data_science/models/keras_utils/visualizations.py +416 -416
- stouputils/data_science/models/sandbox.py +116 -116
- stouputils/data_science/range_tuple.py +234 -234
- stouputils/data_science/utils.py +285 -285
- stouputils/decorators.pyi +242 -0
- stouputils/image.pyi +172 -0
- stouputils/installer/__init__.py +18 -18
- stouputils/installer/__init__.pyi +5 -0
- stouputils/installer/common.pyi +39 -0
- stouputils/installer/downloader.pyi +24 -0
- stouputils/installer/linux.py +144 -144
- stouputils/installer/linux.pyi +39 -0
- stouputils/installer/main.py +223 -223
- stouputils/installer/main.pyi +57 -0
- stouputils/installer/windows.py +136 -136
- stouputils/installer/windows.pyi +31 -0
- stouputils/io.pyi +213 -0
- stouputils/parallel.py +12 -10
- stouputils/parallel.pyi +211 -0
- stouputils/print.pyi +136 -0
- stouputils/py.typed +1 -1
- stouputils/stouputils/parallel.pyi +4 -4
- stouputils/version_pkg.pyi +15 -0
- {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/METADATA +1 -1
- stouputils-1.14.2.dist-info/RECORD +171 -0
- stouputils-1.14.0.dist-info/RECORD +0 -140
- {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/WHEEL +0 -0
- {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/entry_points.txt +0 -0
|
@@ -1,106 +1,106 @@
|
|
|
1
|
-
"""
|
|
2
|
-
This module contains the MetricDictionnary class, which provides a dictionary of metric names.
|
|
3
|
-
|
|
4
|
-
This is often used to log metrics to MLflow and to display them in the console easily.
|
|
5
|
-
|
|
6
|
-
This class contains the following metrics:
|
|
7
|
-
|
|
8
|
-
1. Main metrics:
|
|
9
|
-
|
|
10
|
-
- Area Under the Curve (AUC)
|
|
11
|
-
- Area Under the Precision-Recall Curve (AUPRC)
|
|
12
|
-
- Area Under the NPV-Specificity Curve (NEGATIVE_AUPRC)
|
|
13
|
-
- Specificity (True Negative Rate)
|
|
14
|
-
- Recall/Sensitivity (True Positive Rate)
|
|
15
|
-
- Precision (Positive Predictive Value)
|
|
16
|
-
- Negative Predictive Value (NPV)
|
|
17
|
-
- Accuracy
|
|
18
|
-
- F1 Score
|
|
19
|
-
- Precision-Recall Average
|
|
20
|
-
- Precision-Recall Average for Negative Class
|
|
21
|
-
|
|
22
|
-
2. Confusion matrix metrics:
|
|
23
|
-
|
|
24
|
-
- True Negatives (TN)
|
|
25
|
-
- False Positives (FP)
|
|
26
|
-
- False Negatives (FN)
|
|
27
|
-
- True Positives (TP)
|
|
28
|
-
- False Positive Rate
|
|
29
|
-
- False Negative Rate
|
|
30
|
-
- False Discovery Rate
|
|
31
|
-
- False Omission Rate
|
|
32
|
-
- Critical Success Index (Threat Score)
|
|
33
|
-
|
|
34
|
-
3. F-scores:
|
|
35
|
-
|
|
36
|
-
- F-beta Score (where beta is configurable)
|
|
37
|
-
|
|
38
|
-
4. Matthews correlation coefficient:
|
|
39
|
-
|
|
40
|
-
- Matthews Correlation Coefficient (MCC)
|
|
41
|
-
|
|
42
|
-
5. Optimal thresholds for binary classification:
|
|
43
|
-
|
|
44
|
-
- Youden's J statistic
|
|
45
|
-
- Cost-based threshold
|
|
46
|
-
- F1 Score threshold
|
|
47
|
-
- F1 Score threshold for the negative class
|
|
48
|
-
|
|
49
|
-
6. Average metrics across folds:
|
|
50
|
-
|
|
51
|
-
- Mean value of any metric across k-fold cross validation
|
|
52
|
-
|
|
53
|
-
7. Standard deviation metrics across folds:
|
|
54
|
-
|
|
55
|
-
- Standard deviation of any metric across k-fold cross validation
|
|
56
|
-
"""
|
|
57
|
-
|
|
58
|
-
class MetricDictionnary:
    """ Namespace of the display names given to metrics.

    Each attribute is a plain string. Every name begins with a "<digit>:" group
    prefix (1 to 8), one digit per family of metrics declared below.

    Names that contain an uppercase placeholder ("X", "METRIC_NAME", "TITLE",
    "PARAMETER_NAME") look like templates: presumably the caller substitutes
    the placeholder before using the name (to be confirmed against callers).
    """

    # Group 1: main metrics
    AUC: str = "1: Area Under the ROC Curve: AUC / AUROC"
    AUPRC: str = "1: Area Under the Precision-Recall Curve: AUPRC / PR AUC"
    NEGATIVE_AUPRC: str = "1: Area Under the NPV-Specificity Curve: AUNPRC / NPR AUC"
    SPECIFICITY: str = "1: Specificity: True Negative Rate"
    RECALL: str = "1: Recall/Sensitivity: True Positive Rate"
    PRECISION: str = "1: Precision: Positive Predictive Value"
    NPV: str = "1: NPV: Negative Predictive Value"
    ACCURACY: str = "1: Accuracy"
    BALANCED_ACCURACY: str = "1: Balanced Accuracy"
    F1_SCORE: str = "1: F1 Score"
    F1_SCORE_NEGATIVE: str = "1: F1 Score for Negative Class"
    PR_AVERAGE: str = "1: Precision-Recall Average"
    PR_AVERAGE_NEGATIVE: str = "1: Precision-Recall Average for Negative Class"

    # Group 2: confusion matrix counts and the rates derived from them
    CONFUSION_MATRIX_TN: str = "2: Confusion Matrix: TN"
    CONFUSION_MATRIX_FP: str = "2: Confusion Matrix: FP"
    CONFUSION_MATRIX_FN: str = "2: Confusion Matrix: FN"
    CONFUSION_MATRIX_TP: str = "2: Confusion Matrix: TP"
    FALSE_POSITIVE_RATE: str = "2: False Positive Rate"
    FALSE_NEGATIVE_RATE: str = "2: False Negative Rate"
    FALSE_DISCOVERY_RATE: str = "2: False Discovery Rate"
    FALSE_OMISSION_RATE: str = "2: False Omission Rate"
    CRITICAL_SUCCESS_INDEX: str = "2: Critical Success Index: Threat Score"

    # Group 3: F-scores ("X" stands for the beta value)
    F_SCORE_X: str = "3: F-X Score"

    # Group 4: Matthews correlation coefficient
    MATTHEWS_CORRELATION_COEFFICIENT: str = "4: Matthews Correlation Coefficient: MCC"

    # Group 5: optimal decision thresholds
    OPTIMAL_THRESHOLD_YOUDEN: str = "5: Optimal Threshold: Youden"
    OPTIMAL_THRESHOLD_COST: str = "5: Optimal Threshold: Cost"
    OPTIMAL_THRESHOLD_F1: str = "5: Optimal Threshold: F1"
    OPTIMAL_THRESHOLD_F1_NEGATIVE: str = "5: Optimal Threshold: F1 for Negative Class"

    # Group 6: average of a metric across folds
    AVERAGE_METRIC: str = "6: Average METRIC_NAME across folds"

    # Group 7: standard deviation of a metric across folds
    STANDARD_DEVIATION_METRIC: str = "7: Standard deviation METRIC_NAME across folds"

    # Group 8: parameter finder
    PARAMETER_FINDER: str = "8: TITLE: PARAMETER_NAME"
|
|
106
|
-
|
|
1
|
+
"""
|
|
2
|
+
This module contains the MetricDictionnary class, which provides a dictionary of metric names.
|
|
3
|
+
|
|
4
|
+
This is often used to log metrics to MLflow and to display them in the console easily.
|
|
5
|
+
|
|
6
|
+
This class contains the following metrics:
|
|
7
|
+
|
|
8
|
+
1. Main metrics:
|
|
9
|
+
|
|
10
|
+
- Area Under the Curve (AUC)
|
|
11
|
+
- Area Under the Precision-Recall Curve (AUPRC)
|
|
12
|
+
- Area Under the NPV-Specificity Curve (NEGATIVE_AUPRC)
|
|
13
|
+
- Specificity (True Negative Rate)
|
|
14
|
+
- Recall/Sensitivity (True Positive Rate)
|
|
15
|
+
- Precision (Positive Predictive Value)
|
|
16
|
+
- Negative Predictive Value (NPV)
|
|
17
|
+
- Accuracy
|
|
18
|
+
- F1 Score
|
|
19
|
+
- Precision-Recall Average
|
|
20
|
+
- Precision-Recall Average for Negative Class
|
|
21
|
+
|
|
22
|
+
2. Confusion matrix metrics:
|
|
23
|
+
|
|
24
|
+
- True Negatives (TN)
|
|
25
|
+
- False Positives (FP)
|
|
26
|
+
- False Negatives (FN)
|
|
27
|
+
- True Positives (TP)
|
|
28
|
+
- False Positive Rate
|
|
29
|
+
- False Negative Rate
|
|
30
|
+
- False Discovery Rate
|
|
31
|
+
- False Omission Rate
|
|
32
|
+
- Critical Success Index (Threat Score)
|
|
33
|
+
|
|
34
|
+
3. F-scores:
|
|
35
|
+
|
|
36
|
+
- F-beta Score (where beta is configurable)
|
|
37
|
+
|
|
38
|
+
4. Matthews correlation coefficient:
|
|
39
|
+
|
|
40
|
+
- Matthews Correlation Coefficient (MCC)
|
|
41
|
+
|
|
42
|
+
5. Optimal thresholds for binary classification:
|
|
43
|
+
|
|
44
|
+
- Youden's J statistic
|
|
45
|
+
- Cost-based threshold
|
|
46
|
+
- F1 Score threshold
|
|
47
|
+
- F1 Score threshold for the negative class
|
|
48
|
+
|
|
49
|
+
6. Average metrics across folds:
|
|
50
|
+
|
|
51
|
+
- Mean value of any metric across k-fold cross validation
|
|
52
|
+
|
|
53
|
+
7. Standard deviation metrics across folds:
|
|
54
|
+
|
|
55
|
+
- Standard deviation of any metric across k-fold cross validation
|
|
56
|
+
"""
|
|
57
|
+
|
|
58
|
+
class MetricDictionnary:
    """ Namespace of the display names given to metrics.

    Each attribute is a plain string. Every name begins with a "<digit>:" group
    prefix (1 to 8), one digit per family of metrics declared below.

    Names that contain an uppercase placeholder ("X", "METRIC_NAME", "TITLE",
    "PARAMETER_NAME") look like templates: presumably the caller substitutes
    the placeholder before using the name (to be confirmed against callers).
    """

    # Group 1: main metrics
    AUC: str = "1: Area Under the ROC Curve: AUC / AUROC"
    AUPRC: str = "1: Area Under the Precision-Recall Curve: AUPRC / PR AUC"
    NEGATIVE_AUPRC: str = "1: Area Under the NPV-Specificity Curve: AUNPRC / NPR AUC"
    SPECIFICITY: str = "1: Specificity: True Negative Rate"
    RECALL: str = "1: Recall/Sensitivity: True Positive Rate"
    PRECISION: str = "1: Precision: Positive Predictive Value"
    NPV: str = "1: NPV: Negative Predictive Value"
    ACCURACY: str = "1: Accuracy"
    BALANCED_ACCURACY: str = "1: Balanced Accuracy"
    F1_SCORE: str = "1: F1 Score"
    F1_SCORE_NEGATIVE: str = "1: F1 Score for Negative Class"
    PR_AVERAGE: str = "1: Precision-Recall Average"
    PR_AVERAGE_NEGATIVE: str = "1: Precision-Recall Average for Negative Class"

    # Group 2: confusion matrix counts and the rates derived from them
    CONFUSION_MATRIX_TN: str = "2: Confusion Matrix: TN"
    CONFUSION_MATRIX_FP: str = "2: Confusion Matrix: FP"
    CONFUSION_MATRIX_FN: str = "2: Confusion Matrix: FN"
    CONFUSION_MATRIX_TP: str = "2: Confusion Matrix: TP"
    FALSE_POSITIVE_RATE: str = "2: False Positive Rate"
    FALSE_NEGATIVE_RATE: str = "2: False Negative Rate"
    FALSE_DISCOVERY_RATE: str = "2: False Discovery Rate"
    FALSE_OMISSION_RATE: str = "2: False Omission Rate"
    CRITICAL_SUCCESS_INDEX: str = "2: Critical Success Index: Threat Score"

    # Group 3: F-scores ("X" stands for the beta value)
    F_SCORE_X: str = "3: F-X Score"

    # Group 4: Matthews correlation coefficient
    MATTHEWS_CORRELATION_COEFFICIENT: str = "4: Matthews Correlation Coefficient: MCC"

    # Group 5: optimal decision thresholds
    OPTIMAL_THRESHOLD_YOUDEN: str = "5: Optimal Threshold: Youden"
    OPTIMAL_THRESHOLD_COST: str = "5: Optimal Threshold: Cost"
    OPTIMAL_THRESHOLD_F1: str = "5: Optimal Threshold: F1"
    OPTIMAL_THRESHOLD_F1_NEGATIVE: str = "5: Optimal Threshold: F1 for Negative Class"

    # Group 6: average of a metric across folds
    AVERAGE_METRIC: str = "6: Average METRIC_NAME across folds"

    # Group 7: standard deviation of a metric across folds
    STANDARD_DEVIATION_METRIC: str = "7: Standard deviation METRIC_NAME across folds"

    # Group 8: parameter finder
    PARAMETER_FINDER: str = "8: TITLE: PARAMETER_NAME"
|
|
106
|
+
|
|
@@ -1,206 +1,206 @@
|
|
|
1
|
-
"""
|
|
2
|
-
This module contains utility functions for working with MLflow.
|
|
3
|
-
|
|
4
|
-
This module contains functions for:
|
|
5
|
-
|
|
6
|
-
- Getting the artifact path from the current mlflow run
|
|
7
|
-
- Getting the weights path
|
|
8
|
-
- Getting the runs by experiment name
|
|
9
|
-
- Logging the history of the model to the current mlflow run
|
|
10
|
-
- Starting a new mlflow run
|
|
11
|
-
"""
|
|
12
|
-
|
|
13
|
-
# Imports
|
|
14
|
-
import os
|
|
15
|
-
from typing import Any, Literal
|
|
16
|
-
|
|
17
|
-
import mlflow
|
|
18
|
-
from mlflow.entities import Experiment, Run
|
|
19
|
-
|
|
20
|
-
from ..decorators import handle_error, LogLevels
|
|
21
|
-
from ..io import clean_path
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
# Get artifact path
|
|
25
|
-
def get_artifact_path(from_string: str = "", os_name: str = os.name) -> str:
    """ Get the artifact path from the current mlflow run (without the file:// prefix).

    Handles the different path formats for Windows and Unix-based systems:
    on Windows (os_name == "nt") the whole "file:///" scheme is dropped so the
    result starts at the drive letter; on any other system only "file://" is
    dropped so the leading "/" of the absolute path is kept.

    Only a scheme located at the very start of the string is removed. A later
    occurrence of "file://" belongs to the path itself and is left untouched.

    Args:
        from_string (str): Artifact URI or path (optional, defaults to the artifact URI of the current mlflow run)
        os_name (str): OS name (optional, defaults to os.name)
    Returns:
        str: The artifact path

    Examples:
        >>> get_artifact_path(from_string="file:///path/to/artifact", os_name="posix")
        '/path/to/artifact'

        >>> get_artifact_path(from_string="file:///C:/path/to/artifact", os_name="nt")
        'C:/path/to/artifact'
    """
    # Use the given string when there is one, otherwise ask mlflow for the current run
    artifact_path: str = from_string if from_string else mlflow.get_artifact_uri()

    # Windows paths must not keep the third slash, Unix paths need it as their root
    scheme: str = "file:///" if os_name == "nt" else "file://"

    # removeprefix (not replace) so that only the leading scheme is stripped
    return artifact_path.removeprefix(scheme)
|
|
47
|
-
|
|
48
|
-
# Get weights path
|
|
49
|
-
def get_weights_path(from_string: str = "", weights_name: str = "best_model.keras", os_name: str = os.name) -> str:
    """ Build the path of a weights file located in the artifact directory.

    The artifact directory is resolved with get_artifact_path, the file name is
    appended after a "/" and the result is passed through clean_path.

    Args:
        from_string (str): Path to the artifact (optional, defaults to the current mlflow run)
        weights_name (str): Name of the weights file (optional, defaults to "best_model.keras")
        os_name (str): OS name (optional, defaults to os.name)
    Returns:
        str: The weights path

    Examples:
        >>> get_weights_path(from_string="file:///path/to/artifact", weights_name="best_model.keras", os_name="posix")
        '/path/to/artifact/best_model.keras'

        >>> get_weights_path(from_string="file:///C:/path/to/artifact", weights_name="best_model.keras", os_name="nt")
        'C:/path/to/artifact/best_model.keras'
    """
    artifact_dir: str = get_artifact_path(from_string=from_string, os_name=os_name)
    return clean_path(f"{artifact_dir}/{weights_name}")
|
|
67
|
-
|
|
68
|
-
# Get runs by experiment name
|
|
69
|
-
def get_runs_by_experiment_name(experiment_name: str, filter_string: str = "", set_experiment: bool = False) -> list[Run]:
    """ List the runs of an experiment looked up by its name.

    Args:
        experiment_name (str): Name of the experiment
        filter_string (str): Filter string to apply to the runs
        set_experiment (bool): Whether to call mlflow.set_experiment with this name before the lookup
    Returns:
        list[Run]: Runs returned by mlflow.search_runs, or an empty list when the experiment is not found
    """
    if set_experiment:
        mlflow.set_experiment(experiment_name)

    # Guard clause: nothing to search when the experiment lookup fails
    found: Experiment | None = mlflow.get_experiment_by_name(experiment_name)
    if not found:
        return []

    # output_format="list" asks mlflow for Run entities
    return mlflow.search_runs(
        experiment_ids=[found.experiment_id],
        filter_string=filter_string,
        output_format="list",
    )  # pyright: ignore [reportReturnType]
|
|
89
|
-
|
|
90
|
-
def get_runs_by_model_name(experiment_name: str, model_name: str, set_experiment: bool = False) -> list[Run]:
    """ List the runs of an experiment whose "model_name" tag equals the given model name.

    Args:
        experiment_name (str): Name of the experiment
        model_name (str): Name of the model (inserted as-is between single quotes in the filter)
        set_experiment (bool): Whether to set the experiment
    Returns:
        list[Run]: List of runs
    """
    model_filter: str = f"tags.model_name = '{model_name}'"
    return get_runs_by_experiment_name(experiment_name, filter_string=model_filter, set_experiment=set_experiment)
|
|
105
|
-
|
|
106
|
-
# Log history
|
|
107
|
-
def log_history(history: dict[str, list[Any]], prefix: str = "history", **kwargs: Any) -> None:
    """ Log the history of the model to the current mlflow run.

    Every value is logged with mlflow.log_metric under the name
    "<prefix>_<metric>", using its position in the list as the step.
    Each call goes through handle_error with LogLevels.ERROR_TRACEBACK.

    Args:
        history (dict[str, list[Any]]): History of the model
            (usually from a History object like from a Keras model: history.history)
        prefix (str): Prefix prepended to every metric name (optional, defaults to "history")
        **kwargs (Any): Additional arguments to pass to mlflow.log_metric
    """
    for metric, values in history.items():
        # The wrapper only depends on the metric name: build it once per metric
        # instead of once per logged value
        safe_log_metric = handle_error(
            mlflow.log_metric,
            message=f"Error logging metric {metric}",
            error_log=LogLevels.ERROR_TRACEBACK,
        )
        metric_key: str = f"{prefix}_{metric}"

        # The index of the value in the list is used as the step (epoch)
        for epoch, value in enumerate(values):
            safe_log_metric(metric_key, value, step=epoch, **kwargs)
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
def start_run(mlflow_uri: str, experiment_name: str, model_name: str, override_run_name: str = "", **kwargs: Any) -> str:
    """ Start a new mlflow run.

    The tracking URI is set, the experiment is set while listing the runs
    already tagged with this model name, and the new run is started with the
    tag "model_name" set to ``model_name``.

    Args:
        mlflow_uri (str): MLflow URI
        experiment_name (str): Name of the experiment
        model_name (str): Name of the model
        override_run_name (str): Override the run name (if empty, it will be set automatically)
        **kwargs (Any): Additional arguments to pass to mlflow.start_run.
            A "tags" dict is merged with the "model_name" tag (the model_name
            argument wins on conflict) and "log_system_metrics" replaces the
            default of True. Use override_run_name rather than "run_name".
    Returns:
        str: Name of the run: override_run_name when given, otherwise
            "<model_name>_vNN" where NN is the number of existing runs for this model plus one
    """
    # Set the mlflow URI
    mlflow.set_tracking_uri(mlflow_uri)

    # Always list the runs: this call is also what sets the experiment
    runs: list[Run] = get_runs_by_model_name(experiment_name, model_name, set_experiment=True)
    run_name: str = override_run_name or f"{model_name}_v{len(runs) + 1:02d}"

    # Merge caller-provided tags instead of failing on a duplicate "tags" keyword;
    # model_name is applied last because run lookups filter on that tag
    tags: dict[str, Any] = {**(kwargs.pop("tags", None) or {}), "model_name": model_name}

    # System metrics stay enabled by default but the caller may now override it
    kwargs.setdefault("log_system_metrics", True)

    # Start the run
    mlflow.start_run(run_name=run_name, tags=tags, **kwargs)
    return run_name
|
|
146
|
-
|
|
147
|
-
# Get best run by metric
|
|
148
|
-
def get_best_run_by_metric(
    experiment_name: str,
    metric_name: str,
    model_name: str = "",
    ascending: bool = False,
    has_saved_model: bool = True
) -> Run | None:
    """ Pick the run of an experiment that ranks first for a given metric.

    Args:
        experiment_name (str): Name of the experiment
        metric_name (str): Name of the metric to sort by
        model_name (str): Name of the model (optional, if empty, all models are considered)
        ascending (bool): Whether to sort in ascending order (default: False, i.e. maximum metric value is best)
        has_saved_model (bool): Whether to keep only runs tagged has_saved_model = 'True' (default: True)
    Returns:
        Run | None: The best run or None if no runs are found
    """
    # Build the search filter. The first condition keeps only runs whose metric
    # is strictly positive.
    # NOTE(review): runs with a zero or negative value are therefore never
    # candidates, even with ascending=True - confirm this is intended.
    conditions: list[str] = [f"metrics.`{metric_name}` > 0"]
    if model_name:
        conditions.append(f"tags.model_name = '{model_name}'")
    if has_saved_model:
        conditions.append("tags.has_saved_model = 'True'")

    candidates: list[Run] = get_runs_by_experiment_name(
        experiment_name,
        filter_string=" AND ".join(conditions),
        set_experiment=True
    )
    if not candidates:
        return None

    def metric_value(run: Run) -> float:
        # A run without the metric is ranked as 0
        return float(run.data.metrics.get(metric_name, 0))  # type: ignore

    # Best run first: highest value unless ascending is requested
    ranked: list[Run] = sorted(candidates, key=metric_value, reverse=not ascending)
    return ranked[0]
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
def load_model(run_id: str, model_type: Literal["keras", "pytorch"] = "keras") -> Any:
|
|
193
|
-
""" Load a model from MLflow.
|
|
194
|
-
|
|
195
|
-
Args:
|
|
196
|
-
run_id (str): ID of the run to load the model from
|
|
197
|
-
model_type (Literal["keras", "pytorch"]): Type of model to load (default: "keras")
|
|
198
|
-
Returns:
|
|
199
|
-
Any: The loaded model
|
|
200
|
-
"""
|
|
201
|
-
if model_type == "keras":
|
|
202
|
-
return mlflow.keras.load_model(f"runs:/{run_id}/best_model") # type: ignore
|
|
203
|
-
elif model_type == "pytorch":
|
|
204
|
-
return mlflow.pytorch.load_model(f"runs:/{run_id}/best_model") # type: ignore
|
|
205
|
-
raise ValueError(f"Model type {model_type} not supported")
|
|
206
|
-
|
|
1
|
+
"""
|
|
2
|
+
This module contains utility functions for working with MLflow.
|
|
3
|
+
|
|
4
|
+
This module contains functions for:
|
|
5
|
+
|
|
6
|
+
- Getting the artifact path from the current mlflow run
|
|
7
|
+
- Getting the weights path
|
|
8
|
+
- Getting the runs by experiment name
|
|
9
|
+
- Logging the history of the model to the current mlflow run
|
|
10
|
+
- Starting a new mlflow run
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
# Imports
|
|
14
|
+
import os
|
|
15
|
+
from typing import Any, Literal
|
|
16
|
+
|
|
17
|
+
import mlflow
|
|
18
|
+
from mlflow.entities import Experiment, Run
|
|
19
|
+
|
|
20
|
+
from ..decorators import handle_error, LogLevels
|
|
21
|
+
from ..io import clean_path
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Get artifact path
|
|
25
|
+
def get_artifact_path(from_string: str = "", os_name: str = os.name) -> str:
    """ Get the artifact path from the current mlflow run (without the file:// prefix).

    Handles the different path formats for Windows and Unix-based systems:
    on Windows (os_name == "nt") the whole "file:///" scheme is dropped so the
    result starts at the drive letter; on any other system only "file://" is
    dropped so the leading "/" of the absolute path is kept.

    Only a scheme located at the very start of the string is removed. A later
    occurrence of "file://" belongs to the path itself and is left untouched.

    Args:
        from_string (str): Artifact URI or path (optional, defaults to the artifact URI of the current mlflow run)
        os_name (str): OS name (optional, defaults to os.name)
    Returns:
        str: The artifact path

    Examples:
        >>> get_artifact_path(from_string="file:///path/to/artifact", os_name="posix")
        '/path/to/artifact'

        >>> get_artifact_path(from_string="file:///C:/path/to/artifact", os_name="nt")
        'C:/path/to/artifact'
    """
    # Use the given string when there is one, otherwise ask mlflow for the current run
    artifact_path: str = from_string if from_string else mlflow.get_artifact_uri()

    # Windows paths must not keep the third slash, Unix paths need it as their root
    scheme: str = "file:///" if os_name == "nt" else "file://"

    # removeprefix (not replace) so that only the leading scheme is stripped
    return artifact_path.removeprefix(scheme)
|
|
47
|
+
|
|
48
|
+
# Get weights path
|
|
49
|
+
def get_weights_path(from_string: str = "", weights_name: str = "best_model.keras", os_name: str = os.name) -> str:
    """ Build the path of a weights file located in the artifact directory.

    The artifact directory is resolved with get_artifact_path, the file name is
    appended after a "/" and the result is passed through clean_path.

    Args:
        from_string (str): Path to the artifact (optional, defaults to the current mlflow run)
        weights_name (str): Name of the weights file (optional, defaults to "best_model.keras")
        os_name (str): OS name (optional, defaults to os.name)
    Returns:
        str: The weights path

    Examples:
        >>> get_weights_path(from_string="file:///path/to/artifact", weights_name="best_model.keras", os_name="posix")
        '/path/to/artifact/best_model.keras'

        >>> get_weights_path(from_string="file:///C:/path/to/artifact", weights_name="best_model.keras", os_name="nt")
        'C:/path/to/artifact/best_model.keras'
    """
    artifact_dir: str = get_artifact_path(from_string=from_string, os_name=os_name)
    return clean_path(f"{artifact_dir}/{weights_name}")
|
|
67
|
+
|
|
68
|
+
# Get runs by experiment name
|
|
69
|
+
def get_runs_by_experiment_name(experiment_name: str, filter_string: str = "", set_experiment: bool = False) -> list[Run]:
    """ List the runs of an experiment looked up by its name.

    Args:
        experiment_name (str): Name of the experiment
        filter_string (str): Filter string to apply to the runs
        set_experiment (bool): Whether to call mlflow.set_experiment with this name before the lookup
    Returns:
        list[Run]: Runs returned by mlflow.search_runs, or an empty list when the experiment is not found
    """
    if set_experiment:
        mlflow.set_experiment(experiment_name)

    # Guard clause: nothing to search when the experiment lookup fails
    found: Experiment | None = mlflow.get_experiment_by_name(experiment_name)
    if not found:
        return []

    # output_format="list" asks mlflow for Run entities
    return mlflow.search_runs(
        experiment_ids=[found.experiment_id],
        filter_string=filter_string,
        output_format="list",
    )  # pyright: ignore [reportReturnType]
|
|
89
|
+
|
|
90
|
+
def get_runs_by_model_name(experiment_name: str, model_name: str, set_experiment: bool = False) -> list[Run]:
    """ List the runs of an experiment whose "model_name" tag equals the given model name.

    Args:
        experiment_name (str): Name of the experiment
        model_name (str): Name of the model (inserted as-is between single quotes in the filter)
        set_experiment (bool): Whether to set the experiment
    Returns:
        list[Run]: List of runs
    """
    model_filter: str = f"tags.model_name = '{model_name}'"
    return get_runs_by_experiment_name(experiment_name, filter_string=model_filter, set_experiment=set_experiment)
|
|
105
|
+
|
|
106
|
+
# Log history
|
|
107
|
+
def log_history(history: dict[str, list[Any]], prefix: str = "history", **kwargs: Any) -> None:
    """ Log the history of the model to the current mlflow run.

    Every value is logged with mlflow.log_metric under the name
    "<prefix>_<metric>", using its position in the list as the step.
    Each call goes through handle_error with LogLevels.ERROR_TRACEBACK.

    Args:
        history (dict[str, list[Any]]): History of the model
            (usually from a History object like from a Keras model: history.history)
        prefix (str): Prefix prepended to every metric name (optional, defaults to "history")
        **kwargs (Any): Additional arguments to pass to mlflow.log_metric
    """
    for metric, values in history.items():
        # The wrapper only depends on the metric name: build it once per metric
        # instead of once per logged value
        safe_log_metric = handle_error(
            mlflow.log_metric,
            message=f"Error logging metric {metric}",
            error_log=LogLevels.ERROR_TRACEBACK,
        )
        metric_key: str = f"{prefix}_{metric}"

        # The index of the value in the list is used as the step (epoch)
        for epoch, value in enumerate(values):
            safe_log_metric(metric_key, value, step=epoch, **kwargs)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def start_run(mlflow_uri: str, experiment_name: str, model_name: str, override_run_name: str = "", **kwargs: Any) -> str:
    """ Start a new mlflow run.

    The tracking URI is set, the experiment is set while listing the runs
    already tagged with this model name, and the new run is started with the
    tag "model_name" set to ``model_name``.

    Args:
        mlflow_uri (str): MLflow URI
        experiment_name (str): Name of the experiment
        model_name (str): Name of the model
        override_run_name (str): Override the run name (if empty, it will be set automatically)
        **kwargs (Any): Additional arguments to pass to mlflow.start_run.
            A "tags" dict is merged with the "model_name" tag (the model_name
            argument wins on conflict) and "log_system_metrics" replaces the
            default of True. Use override_run_name rather than "run_name".
    Returns:
        str: Name of the run: override_run_name when given, otherwise
            "<model_name>_vNN" where NN is the number of existing runs for this model plus one
    """
    # Set the mlflow URI
    mlflow.set_tracking_uri(mlflow_uri)

    # Always list the runs: this call is also what sets the experiment
    runs: list[Run] = get_runs_by_model_name(experiment_name, model_name, set_experiment=True)
    run_name: str = override_run_name or f"{model_name}_v{len(runs) + 1:02d}"

    # Merge caller-provided tags instead of failing on a duplicate "tags" keyword;
    # model_name is applied last because run lookups filter on that tag
    tags: dict[str, Any] = {**(kwargs.pop("tags", None) or {}), "model_name": model_name}

    # System metrics stay enabled by default but the caller may now override it
    kwargs.setdefault("log_system_metrics", True)

    # Start the run
    mlflow.start_run(run_name=run_name, tags=tags, **kwargs)
    return run_name
|
|
146
|
+
|
|
147
|
+
# Get best run by metric
|
|
148
|
+
def get_best_run_by_metric(
    experiment_name: str,
    metric_name: str,
    model_name: str = "",
    ascending: bool = False,
    has_saved_model: bool = True
) -> Run | None:
    """ Pick the run of an experiment that ranks first for a given metric.

    Args:
        experiment_name (str): Name of the experiment
        metric_name (str): Name of the metric to sort by
        model_name (str): Name of the model (optional, if empty, all models are considered)
        ascending (bool): Whether to sort in ascending order (default: False, i.e. maximum metric value is best)
        has_saved_model (bool): Whether to keep only runs tagged has_saved_model = 'True' (default: True)
    Returns:
        Run | None: The best run or None if no runs are found
    """
    # Build the search filter. The first condition keeps only runs whose metric
    # is strictly positive.
    # NOTE(review): runs with a zero or negative value are therefore never
    # candidates, even with ascending=True - confirm this is intended.
    conditions: list[str] = [f"metrics.`{metric_name}` > 0"]
    if model_name:
        conditions.append(f"tags.model_name = '{model_name}'")
    if has_saved_model:
        conditions.append("tags.has_saved_model = 'True'")

    candidates: list[Run] = get_runs_by_experiment_name(
        experiment_name,
        filter_string=" AND ".join(conditions),
        set_experiment=True
    )
    if not candidates:
        return None

    def metric_value(run: Run) -> float:
        # A run without the metric is ranked as 0
        return float(run.data.metrics.get(metric_name, 0))  # type: ignore

    # Best run first: highest value unless ascending is requested
    ranked: list[Run] = sorted(candidates, key=metric_value, reverse=not ascending)
    return ranked[0]
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def load_model(run_id: str, model_type: Literal["keras", "pytorch"] = "keras") -> Any:
|
|
193
|
+
""" Load a model from MLflow.
|
|
194
|
+
|
|
195
|
+
Args:
|
|
196
|
+
run_id (str): ID of the run to load the model from
|
|
197
|
+
model_type (Literal["keras", "pytorch"]): Type of model to load (default: "keras")
|
|
198
|
+
Returns:
|
|
199
|
+
Any: The loaded model
|
|
200
|
+
"""
|
|
201
|
+
if model_type == "keras":
|
|
202
|
+
return mlflow.keras.load_model(f"runs:/{run_id}/best_model") # type: ignore
|
|
203
|
+
elif model_type == "pytorch":
|
|
204
|
+
return mlflow.pytorch.load_model(f"runs:/{run_id}/best_model") # type: ignore
|
|
205
|
+
raise ValueError(f"Model type {model_type} not supported")
|
|
206
|
+
|