stouputils 1.14.0__py3-none-any.whl → 1.14.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (108)
  1. stouputils/__init__.pyi +15 -0
  2. stouputils/_deprecated.pyi +12 -0
  3. stouputils/all_doctests.pyi +46 -0
  4. stouputils/applications/__init__.pyi +2 -0
  5. stouputils/applications/automatic_docs.py +3 -0
  6. stouputils/applications/automatic_docs.pyi +106 -0
  7. stouputils/applications/upscaler/__init__.pyi +3 -0
  8. stouputils/applications/upscaler/config.pyi +18 -0
  9. stouputils/applications/upscaler/image.pyi +109 -0
  10. stouputils/applications/upscaler/video.pyi +60 -0
  11. stouputils/archive.pyi +67 -0
  12. stouputils/backup.pyi +109 -0
  13. stouputils/collections.pyi +86 -0
  14. stouputils/continuous_delivery/__init__.pyi +5 -0
  15. stouputils/continuous_delivery/cd_utils.pyi +129 -0
  16. stouputils/continuous_delivery/github.pyi +162 -0
  17. stouputils/continuous_delivery/pypi.pyi +52 -0
  18. stouputils/continuous_delivery/pyproject.pyi +67 -0
  19. stouputils/continuous_delivery/stubs.pyi +39 -0
  20. stouputils/ctx.pyi +211 -0
  21. stouputils/data_science/config/get.py +51 -51
  22. stouputils/data_science/data_processing/image/__init__.py +66 -66
  23. stouputils/data_science/data_processing/image/auto_contrast.py +79 -79
  24. stouputils/data_science/data_processing/image/axis_flip.py +58 -58
  25. stouputils/data_science/data_processing/image/bias_field_correction.py +74 -74
  26. stouputils/data_science/data_processing/image/binary_threshold.py +73 -73
  27. stouputils/data_science/data_processing/image/blur.py +59 -59
  28. stouputils/data_science/data_processing/image/brightness.py +54 -54
  29. stouputils/data_science/data_processing/image/canny.py +110 -110
  30. stouputils/data_science/data_processing/image/clahe.py +92 -92
  31. stouputils/data_science/data_processing/image/common.py +30 -30
  32. stouputils/data_science/data_processing/image/contrast.py +53 -53
  33. stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -74
  34. stouputils/data_science/data_processing/image/denoise.py +378 -378
  35. stouputils/data_science/data_processing/image/histogram_equalization.py +123 -123
  36. stouputils/data_science/data_processing/image/invert.py +64 -64
  37. stouputils/data_science/data_processing/image/laplacian.py +60 -60
  38. stouputils/data_science/data_processing/image/median_blur.py +52 -52
  39. stouputils/data_science/data_processing/image/noise.py +59 -59
  40. stouputils/data_science/data_processing/image/normalize.py +65 -65
  41. stouputils/data_science/data_processing/image/random_erase.py +66 -66
  42. stouputils/data_science/data_processing/image/resize.py +69 -69
  43. stouputils/data_science/data_processing/image/rotation.py +80 -80
  44. stouputils/data_science/data_processing/image/salt_pepper.py +68 -68
  45. stouputils/data_science/data_processing/image/sharpening.py +55 -55
  46. stouputils/data_science/data_processing/image/shearing.py +64 -64
  47. stouputils/data_science/data_processing/image/threshold.py +64 -64
  48. stouputils/data_science/data_processing/image/translation.py +71 -71
  49. stouputils/data_science/data_processing/image/zoom.py +83 -83
  50. stouputils/data_science/data_processing/image_augmentation.py +118 -118
  51. stouputils/data_science/data_processing/image_preprocess.py +183 -183
  52. stouputils/data_science/data_processing/prosthesis_detection.py +359 -359
  53. stouputils/data_science/data_processing/technique.py +481 -481
  54. stouputils/data_science/dataset/__init__.py +45 -45
  55. stouputils/data_science/dataset/dataset.py +292 -292
  56. stouputils/data_science/dataset/dataset_loader.py +135 -135
  57. stouputils/data_science/dataset/grouping_strategy.py +296 -296
  58. stouputils/data_science/dataset/image_loader.py +100 -100
  59. stouputils/data_science/dataset/xy_tuple.py +696 -696
  60. stouputils/data_science/metric_dictionnary.py +106 -106
  61. stouputils/data_science/mlflow_utils.py +206 -206
  62. stouputils/data_science/models/abstract_model.py +149 -149
  63. stouputils/data_science/models/all.py +85 -85
  64. stouputils/data_science/models/keras/all.py +38 -38
  65. stouputils/data_science/models/keras/convnext.py +62 -62
  66. stouputils/data_science/models/keras/densenet.py +50 -50
  67. stouputils/data_science/models/keras/efficientnet.py +60 -60
  68. stouputils/data_science/models/keras/mobilenet.py +56 -56
  69. stouputils/data_science/models/keras/resnet.py +52 -52
  70. stouputils/data_science/models/keras/squeezenet.py +233 -233
  71. stouputils/data_science/models/keras/vgg.py +42 -42
  72. stouputils/data_science/models/keras/xception.py +38 -38
  73. stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -20
  74. stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -219
  75. stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -148
  76. stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -31
  77. stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -249
  78. stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -66
  79. stouputils/data_science/models/keras_utils/losses/__init__.py +12 -12
  80. stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -56
  81. stouputils/data_science/models/keras_utils/visualizations.py +416 -416
  82. stouputils/data_science/models/sandbox.py +116 -116
  83. stouputils/data_science/range_tuple.py +234 -234
  84. stouputils/data_science/utils.py +285 -285
  85. stouputils/decorators.pyi +242 -0
  86. stouputils/image.pyi +172 -0
  87. stouputils/installer/__init__.py +18 -18
  88. stouputils/installer/__init__.pyi +5 -0
  89. stouputils/installer/common.pyi +39 -0
  90. stouputils/installer/downloader.pyi +24 -0
  91. stouputils/installer/linux.py +144 -144
  92. stouputils/installer/linux.pyi +39 -0
  93. stouputils/installer/main.py +223 -223
  94. stouputils/installer/main.pyi +57 -0
  95. stouputils/installer/windows.py +136 -136
  96. stouputils/installer/windows.pyi +31 -0
  97. stouputils/io.pyi +213 -0
  98. stouputils/parallel.py +12 -10
  99. stouputils/parallel.pyi +211 -0
  100. stouputils/print.pyi +136 -0
  101. stouputils/py.typed +1 -1
  102. stouputils/stouputils/parallel.pyi +4 -4
  103. stouputils/version_pkg.pyi +15 -0
  104. {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/METADATA +1 -1
  105. stouputils-1.14.2.dist-info/RECORD +171 -0
  106. stouputils-1.14.0.dist-info/RECORD +0 -140
  107. {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/WHEEL +0 -0
  108. {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/entry_points.txt +0 -0
stouputils/data_science/metric_dictionnary.py
@@ -1,106 +1,106 @@
- """
- This module contains the MetricDictionnary class, which provides a dictionary of metric names.
-
- This is often used to log metrics to MLflow and to display them in the console easily.
-
- This class contains the following metrics:
-
- 1. Main metrics:
-
- - Area Under the Curve (AUC)
- - Area Under the Precision-Recall Curve (AUPRC)
- - Area Under the NPV-Specificity Curve (NEGATIVE_AUPRC)
- - Specificity (True Negative Rate)
- - Recall/Sensitivity (True Positive Rate)
- - Precision (Positive Predictive Value)
- - Negative Predictive Value (NPV)
- - Accuracy
- - F1 Score
- - Precision-Recall Average
- - Precision-Recall Average for Negative Class
-
- 2. Confusion matrix metrics:
-
- - True Negatives (TN)
- - False Positives (FP)
- - False Negatives (FN)
- - True Positives (TP)
- - False Positive Rate
- - False Negative Rate
- - False Discovery Rate
- - False Omission Rate
- - Critical Success Index (Threat Score)
-
- 3. F-scores:
-
- - F-beta Score (where beta is configurable)
-
- 4. Matthews correlation coefficient:
-
- - Matthews Correlation Coefficient (MCC)
-
- 5. Optimal thresholds for binary classification:
-
- - Youden's J statistic
- - Cost-based threshold
- - F1 Score threshold
- - F1 Score threshold for the negative class
-
- 6. Average metrics across folds:
-
- - Mean value of any metric across k-fold cross validation
-
- 7. Standard deviation metrics across folds:
-
- - Standard deviation of any metric across k-fold cross validation
- """
-
- class MetricDictionnary:
-
-     # Main metrics (starting with '1:')
-     AUC: str = "1: Area Under the ROC Curve: AUC / AUROC"
-     AUPRC: str = "1: Area Under the Precision-Recall Curve: AUPRC / PR AUC"
-     NEGATIVE_AUPRC: str = "1: Area Under the NPV-Specificity Curve: AUNPRC / NPR AUC"
-     SPECIFICITY: str = "1: Specificity: True Negative Rate"
-     RECALL: str = "1: Recall/Sensitivity: True Positive Rate"
-     PRECISION: str = "1: Precision: Positive Predictive Value"
-     NPV: str = "1: NPV: Negative Predictive Value"
-     ACCURACY: str = "1: Accuracy"
-     BALANCED_ACCURACY: str = "1: Balanced Accuracy"
-     F1_SCORE: str = "1: F1 Score"
-     F1_SCORE_NEGATIVE: str = "1: F1 Score for Negative Class"
-     PR_AVERAGE: str = "1: Precision-Recall Average"
-     PR_AVERAGE_NEGATIVE: str = "1: Precision-Recall Average for Negative Class"
-
-     # Confusion matrix metrics (starting with '2:')
-     CONFUSION_MATRIX_TN: str = "2: Confusion Matrix: TN"
-     CONFUSION_MATRIX_FP: str = "2: Confusion Matrix: FP"
-     CONFUSION_MATRIX_FN: str = "2: Confusion Matrix: FN"
-     CONFUSION_MATRIX_TP: str = "2: Confusion Matrix: TP"
-     FALSE_POSITIVE_RATE: str = "2: False Positive Rate"
-     FALSE_NEGATIVE_RATE: str = "2: False Negative Rate"
-     FALSE_DISCOVERY_RATE: str = "2: False Discovery Rate"
-     FALSE_OMISSION_RATE: str = "2: False Omission Rate"
-     CRITICAL_SUCCESS_INDEX: str = "2: Critical Success Index: Threat Score"
-
-     # F-scores (starting with '3:')
-     F_SCORE_X: str = "3: F-X Score" # X is the beta value
-
-     # Matthews correlation coefficient (starting with '4:')
-     MATTHEWS_CORRELATION_COEFFICIENT: str = "4: Matthews Correlation Coefficient: MCC"
-
-     # Optimal thresholds (starting with '5:')
-     OPTIMAL_THRESHOLD_YOUDEN: str = "5: Optimal Threshold: Youden"
-     OPTIMAL_THRESHOLD_COST: str = "5: Optimal Threshold: Cost"
-     OPTIMAL_THRESHOLD_F1: str = "5: Optimal Threshold: F1"
-     OPTIMAL_THRESHOLD_F1_NEGATIVE: str = "5: Optimal Threshold: F1 for Negative Class"
-
-     # Average metrics across folds (starting with '6:')
-     AVERAGE_METRIC: str = "6: Average METRIC_NAME across folds"
-
-     # Standard deviation metrics across folds (starting with '7:')
-     STANDARD_DEVIATION_METRIC: str = "7: Standard deviation METRIC_NAME across folds"
-
-     # Parameter finder (starting with '8:')
-     PARAMETER_FINDER: str = "8: TITLE: PARAMETER_NAME"
-
+ """
+ This module contains the MetricDictionnary class, which provides a dictionary of metric names.
+
+ This is often used to log metrics to MLflow and to display them in the console easily.
+
+ This class contains the following metrics:
+
+ 1. Main metrics:
+
+ - Area Under the Curve (AUC)
+ - Area Under the Precision-Recall Curve (AUPRC)
+ - Area Under the NPV-Specificity Curve (NEGATIVE_AUPRC)
+ - Specificity (True Negative Rate)
+ - Recall/Sensitivity (True Positive Rate)
+ - Precision (Positive Predictive Value)
+ - Negative Predictive Value (NPV)
+ - Accuracy
+ - F1 Score
+ - Precision-Recall Average
+ - Precision-Recall Average for Negative Class
+
+ 2. Confusion matrix metrics:
+
+ - True Negatives (TN)
+ - False Positives (FP)
+ - False Negatives (FN)
+ - True Positives (TP)
+ - False Positive Rate
+ - False Negative Rate
+ - False Discovery Rate
+ - False Omission Rate
+ - Critical Success Index (Threat Score)
+
+ 3. F-scores:
+
+ - F-beta Score (where beta is configurable)
+
+ 4. Matthews correlation coefficient:
+
+ - Matthews Correlation Coefficient (MCC)
+
+ 5. Optimal thresholds for binary classification:
+
+ - Youden's J statistic
+ - Cost-based threshold
+ - F1 Score threshold
+ - F1 Score threshold for the negative class
+
+ 6. Average metrics across folds:
+
+ - Mean value of any metric across k-fold cross validation
+
+ 7. Standard deviation metrics across folds:
+
+ - Standard deviation of any metric across k-fold cross validation
+ """
+
+ class MetricDictionnary:
+
+     # Main metrics (starting with '1:')
+     AUC: str = "1: Area Under the ROC Curve: AUC / AUROC"
+     AUPRC: str = "1: Area Under the Precision-Recall Curve: AUPRC / PR AUC"
+     NEGATIVE_AUPRC: str = "1: Area Under the NPV-Specificity Curve: AUNPRC / NPR AUC"
+     SPECIFICITY: str = "1: Specificity: True Negative Rate"
+     RECALL: str = "1: Recall/Sensitivity: True Positive Rate"
+     PRECISION: str = "1: Precision: Positive Predictive Value"
+     NPV: str = "1: NPV: Negative Predictive Value"
+     ACCURACY: str = "1: Accuracy"
+     BALANCED_ACCURACY: str = "1: Balanced Accuracy"
+     F1_SCORE: str = "1: F1 Score"
+     F1_SCORE_NEGATIVE: str = "1: F1 Score for Negative Class"
+     PR_AVERAGE: str = "1: Precision-Recall Average"
+     PR_AVERAGE_NEGATIVE: str = "1: Precision-Recall Average for Negative Class"
+
+     # Confusion matrix metrics (starting with '2:')
+     CONFUSION_MATRIX_TN: str = "2: Confusion Matrix: TN"
+     CONFUSION_MATRIX_FP: str = "2: Confusion Matrix: FP"
+     CONFUSION_MATRIX_FN: str = "2: Confusion Matrix: FN"
+     CONFUSION_MATRIX_TP: str = "2: Confusion Matrix: TP"
+     FALSE_POSITIVE_RATE: str = "2: False Positive Rate"
+     FALSE_NEGATIVE_RATE: str = "2: False Negative Rate"
+     FALSE_DISCOVERY_RATE: str = "2: False Discovery Rate"
+     FALSE_OMISSION_RATE: str = "2: False Omission Rate"
+     CRITICAL_SUCCESS_INDEX: str = "2: Critical Success Index: Threat Score"
+
+     # F-scores (starting with '3:')
+     F_SCORE_X: str = "3: F-X Score" # X is the beta value
+
+     # Matthews correlation coefficient (starting with '4:')
+     MATTHEWS_CORRELATION_COEFFICIENT: str = "4: Matthews Correlation Coefficient: MCC"
+
+     # Optimal thresholds (starting with '5:')
+     OPTIMAL_THRESHOLD_YOUDEN: str = "5: Optimal Threshold: Youden"
+     OPTIMAL_THRESHOLD_COST: str = "5: Optimal Threshold: Cost"
+     OPTIMAL_THRESHOLD_F1: str = "5: Optimal Threshold: F1"
+     OPTIMAL_THRESHOLD_F1_NEGATIVE: str = "5: Optimal Threshold: F1 for Negative Class"
+
+     # Average metrics across folds (starting with '6:')
+     AVERAGE_METRIC: str = "6: Average METRIC_NAME across folds"
+
+     # Standard deviation metrics across folds (starting with '7:')
+     STANDARD_DEVIATION_METRIC: str = "7: Standard deviation METRIC_NAME across folds"
+
+     # Parameter finder (starting with '8:')
+     PARAMETER_FINDER: str = "8: TITLE: PARAMETER_NAME"
+
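Both sides of this hunk are textually identical, so the change to metric_dictionnary.py is apparently whitespace- or line-ending-only. For context, here is a minimal usage sketch (not from the package; the metric values are made up for illustration): the class is a flat namespace of display-name constants rather than an actual dict, the "N:" prefixes give metrics a stable sort order, and the fold-level entries carry a METRIC_NAME placeholder for the caller to substitute.

from stouputils.data_science.metric_dictionnary import MetricDictionnary

# Hypothetical metric values, keyed by their display names
metrics: dict[str, float] = {
    MetricDictionnary.AUC: 0.93,
    MetricDictionnary.ACCURACY: 0.88,
    # Fill in the METRIC_NAME placeholder for a fold-averaged metric
    MetricDictionnary.AVERAGE_METRIC.replace("METRIC_NAME", "AUC"): 0.91,
}

# Sorting by name groups metrics by their "N:" category prefix
for name, value in sorted(metrics.items()):
    print(f"{name} = {value:.2f}")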
stouputils/data_science/mlflow_utils.py
@@ -1,206 +1,206 @@
- """
- This module contains utility functions for working with MLflow.
-
- This module contains functions for:
-
- - Getting the artifact path from the current mlflow run
- - Getting the weights path
- - Getting the runs by experiment name
- - Logging the history of the model to the current mlflow run
- - Starting a new mlflow run
- """
-
- # Imports
- import os
- from typing import Any, Literal
-
- import mlflow
- from mlflow.entities import Experiment, Run
-
- from ..decorators import handle_error, LogLevels
- from ..io import clean_path
-
-
- # Get artifact path
- def get_artifact_path(from_string: str = "", os_name: str = os.name) -> str:
-     """ Get the artifact path from the current mlflow run (without the file:// prefix).
-
-     Handles the different path formats for Windows and Unix-based systems.
-
-     Args:
-         from_string (str): Path to the artifact (optional, defaults to the current mlflow run)
-         os_name (str): OS name (optional, defaults to os.name)
-     Returns:
-         str: The artifact path
-     """
-     # Get the artifact path from the current mlflow run or from a string
-     if not from_string:
-         artifact_path: str = mlflow.get_artifact_uri()
-     else:
-         artifact_path: str = from_string
-
-     # Handle the different path formats for Windows and Unix-based systems
-     if os_name == "nt":
-         return artifact_path.replace("file:///", "")
-     else:
-         return artifact_path.replace("file://", "")
-
- # Get weights path
- def get_weights_path(from_string: str = "", weights_name: str = "best_model.keras", os_name: str = os.name) -> str:
-     """ Get the weights path from the current mlflow run.
-
-     Args:
-         from_string (str): Path to the artifact (optional, defaults to the current mlflow run)
-         weights_name (str): Name of the weights file (optional, defaults to "best_model.keras")
-         os_name (str): OS name (optional, defaults to os.name)
-     Returns:
-         str: The weights path
-
-     Examples:
-         >>> get_weights_path(from_string="file:///path/to/artifact", weights_name="best_model.keras", os_name="posix")
-         '/path/to/artifact/best_model.keras'
-
-         >>> get_weights_path(from_string="file:///C:/path/to/artifact", weights_name="best_model.keras", os_name="nt")
-         'C:/path/to/artifact/best_model.keras'
-     """
-     return clean_path(f"{get_artifact_path(from_string=from_string, os_name=os_name)}/{weights_name}")
-
- # Get runs by experiment name
- def get_runs_by_experiment_name(experiment_name: str, filter_string: str = "", set_experiment: bool = False) -> list[Run]:
-     """ Get the runs by experiment name.
-
-     Args:
-         experiment_name (str): Name of the experiment
-         filter_string (str): Filter string to apply to the runs
-         set_experiment (bool): Whether to set the experiment
-     Returns:
-         list[Run]: List of runs
-     """
-     if set_experiment:
-         mlflow.set_experiment(experiment_name)
-     experiment: Experiment | None = mlflow.get_experiment_by_name(experiment_name)
-     if experiment:
-         return mlflow.search_runs(
-             experiment_ids=[experiment.experiment_id],
-             output_format="list",
-             filter_string=filter_string
-         ) # pyright: ignore [reportReturnType]
-     return []
-
- def get_runs_by_model_name(experiment_name: str, model_name: str, set_experiment: bool = False) -> list[Run]:
-     """ Get the runs by model name.
-
-     Args:
-         experiment_name (str): Name of the experiment
-         model_name (str): Name of the model
-         set_experiment (bool): Whether to set the experiment
-     Returns:
-         list[Run]: List of runs
-     """
-     return get_runs_by_experiment_name(
-         experiment_name,
-         filter_string=f"tags.model_name = '{model_name}'",
-         set_experiment=set_experiment
-     )
-
- # Log history
- def log_history(history: dict[str, list[Any]], prefix: str = "history", **kwargs: Any) -> None:
-     """ Log the history of the model to the current mlflow run.
-
-     Args:
-         history (dict[str, list[Any]]): History of the model
-             (usually from a History object like from a Keras model: history.history)
-         **kwargs (Any): Additional arguments to pass to mlflow.log_metric
-     """
-     for (metric, values) in history.items():
-         for epoch, value in enumerate(values):
-             handle_error(mlflow.log_metric,
-                 message=f"Error logging metric {metric}",
-                 error_log=LogLevels.ERROR_TRACEBACK
-             )(f"{prefix}_{metric}", value, step=epoch, **kwargs)
-
-
- def start_run(mlflow_uri: str, experiment_name: str, model_name: str, override_run_name: str = "", **kwargs: Any) -> str:
-     """ Start a new mlflow run.
-
-     Args:
-         mlflow_uri (str): MLflow URI
-         experiment_name (str): Name of the experiment
-         model_name (str): Name of the model
-         override_run_name (str): Override the run name (if empty, it will be set automatically)
-         **kwargs (Any): Additional arguments to pass to mlflow.start_run
-     Returns:
-         str: Name of the run (suffixed with the version number)
-     """
-     # Set the mlflow URI
-     mlflow.set_tracking_uri(mlflow_uri)
-
-     # Get the runs and increment the version number
-     runs: list[Run] = get_runs_by_model_name(experiment_name, model_name, set_experiment=True)
-     run_number: int = len(runs) + 1
-     run_name: str = f"{model_name}_v{run_number:02d}" if not override_run_name else override_run_name
-
-     # Start the run
-     mlflow.start_run(run_name=run_name, tags={"model_name": model_name}, log_system_metrics=True, **kwargs)
-     return run_name
-
- # Get best run by metric
- def get_best_run_by_metric(
-     experiment_name: str,
-     metric_name: str,
-     model_name: str = "",
-     ascending: bool = False,
-     has_saved_model: bool = True
- ) -> Run | None:
-     """ Get the best run by a specific metric.
-
-     Args:
-         experiment_name (str): Name of the experiment
-         metric_name (str): Name of the metric to sort by
-         model_name (str): Name of the model (optional, if empty, all models are considered)
-         ascending (bool): Whether to sort in ascending order (default: False, i.e. maximum metric value is best)
-         has_saved_model (bool): Whether the model has been saved (default: True)
-     Returns:
-         Run | None: The best run or None if no runs are found
-     """
-     # Get the runs
-     filter_string: str = f"metrics.`{metric_name}` > 0"
-     if model_name:
-         filter_string += f" AND tags.model_name = '{model_name}'"
-     if has_saved_model:
-         filter_string += " AND tags.has_saved_model = 'True'"
-
-     runs: list[Run] = get_runs_by_experiment_name(
-         experiment_name,
-         filter_string=filter_string,
-         set_experiment=True
-     )
-
-     if not runs:
-         return None
-
-     # Sort the runs by the metric
-     sorted_runs: list[Run] = sorted(
-         runs,
-         key=lambda run: float(run.data.metrics.get(metric_name, 0)), # type: ignore
-         reverse=not ascending
-     )
-
-     return sorted_runs[0] if sorted_runs else None
-
-
- def load_model(run_id: str, model_type: Literal["keras", "pytorch"] = "keras") -> Any:
-     """ Load a model from MLflow.
-
-     Args:
-         run_id (str): ID of the run to load the model from
-         model_type (Literal["keras", "pytorch"]): Type of model to load (default: "keras")
-     Returns:
-         Any: The loaded model
-     """
-     if model_type == "keras":
-         return mlflow.keras.load_model(f"runs:/{run_id}/best_model") # type: ignore
-     elif model_type == "pytorch":
-         return mlflow.pytorch.load_model(f"runs:/{run_id}/best_model") # type: ignore
-     raise ValueError(f"Model type {model_type} not supported")
-
+ """
+ This module contains utility functions for working with MLflow.
+
+ This module contains functions for:
+
+ - Getting the artifact path from the current mlflow run
+ - Getting the weights path
+ - Getting the runs by experiment name
+ - Logging the history of the model to the current mlflow run
+ - Starting a new mlflow run
+ """
+
+ # Imports
+ import os
+ from typing import Any, Literal
+
+ import mlflow
+ from mlflow.entities import Experiment, Run
+
+ from ..decorators import handle_error, LogLevels
+ from ..io import clean_path
+
+
+ # Get artifact path
+ def get_artifact_path(from_string: str = "", os_name: str = os.name) -> str:
+     """ Get the artifact path from the current mlflow run (without the file:// prefix).
+
+     Handles the different path formats for Windows and Unix-based systems.
+
+     Args:
+         from_string (str): Path to the artifact (optional, defaults to the current mlflow run)
+         os_name (str): OS name (optional, defaults to os.name)
+     Returns:
+         str: The artifact path
+     """
+     # Get the artifact path from the current mlflow run or from a string
+     if not from_string:
+         artifact_path: str = mlflow.get_artifact_uri()
+     else:
+         artifact_path: str = from_string
+
+     # Handle the different path formats for Windows and Unix-based systems
+     if os_name == "nt":
+         return artifact_path.replace("file:///", "")
+     else:
+         return artifact_path.replace("file://", "")
+
+ # Get weights path
+ def get_weights_path(from_string: str = "", weights_name: str = "best_model.keras", os_name: str = os.name) -> str:
+     """ Get the weights path from the current mlflow run.
+
+     Args:
+         from_string (str): Path to the artifact (optional, defaults to the current mlflow run)
+         weights_name (str): Name of the weights file (optional, defaults to "best_model.keras")
+         os_name (str): OS name (optional, defaults to os.name)
+     Returns:
+         str: The weights path
+
+     Examples:
+         >>> get_weights_path(from_string="file:///path/to/artifact", weights_name="best_model.keras", os_name="posix")
+         '/path/to/artifact/best_model.keras'
+
+         >>> get_weights_path(from_string="file:///C:/path/to/artifact", weights_name="best_model.keras", os_name="nt")
+         'C:/path/to/artifact/best_model.keras'
+     """
+     return clean_path(f"{get_artifact_path(from_string=from_string, os_name=os_name)}/{weights_name}")
+
+ # Get runs by experiment name
+ def get_runs_by_experiment_name(experiment_name: str, filter_string: str = "", set_experiment: bool = False) -> list[Run]:
+     """ Get the runs by experiment name.
+
+     Args:
+         experiment_name (str): Name of the experiment
+         filter_string (str): Filter string to apply to the runs
+         set_experiment (bool): Whether to set the experiment
+     Returns:
+         list[Run]: List of runs
+     """
+     if set_experiment:
+         mlflow.set_experiment(experiment_name)
+     experiment: Experiment | None = mlflow.get_experiment_by_name(experiment_name)
+     if experiment:
+         return mlflow.search_runs(
+             experiment_ids=[experiment.experiment_id],
+             output_format="list",
+             filter_string=filter_string
+         ) # pyright: ignore [reportReturnType]
+     return []
+
+ def get_runs_by_model_name(experiment_name: str, model_name: str, set_experiment: bool = False) -> list[Run]:
+     """ Get the runs by model name.
+
+     Args:
+         experiment_name (str): Name of the experiment
+         model_name (str): Name of the model
+         set_experiment (bool): Whether to set the experiment
+     Returns:
+         list[Run]: List of runs
+     """
+     return get_runs_by_experiment_name(
+         experiment_name,
+         filter_string=f"tags.model_name = '{model_name}'",
+         set_experiment=set_experiment
+     )
+
+ # Log history
+ def log_history(history: dict[str, list[Any]], prefix: str = "history", **kwargs: Any) -> None:
+     """ Log the history of the model to the current mlflow run.
+
+     Args:
+         history (dict[str, list[Any]]): History of the model
+             (usually from a History object like from a Keras model: history.history)
+         **kwargs (Any): Additional arguments to pass to mlflow.log_metric
+     """
+     for (metric, values) in history.items():
+         for epoch, value in enumerate(values):
+             handle_error(mlflow.log_metric,
+                 message=f"Error logging metric {metric}",
+                 error_log=LogLevels.ERROR_TRACEBACK
+             )(f"{prefix}_{metric}", value, step=epoch, **kwargs)
+
+
+ def start_run(mlflow_uri: str, experiment_name: str, model_name: str, override_run_name: str = "", **kwargs: Any) -> str:
+     """ Start a new mlflow run.
+
+     Args:
+         mlflow_uri (str): MLflow URI
+         experiment_name (str): Name of the experiment
+         model_name (str): Name of the model
+         override_run_name (str): Override the run name (if empty, it will be set automatically)
+         **kwargs (Any): Additional arguments to pass to mlflow.start_run
+     Returns:
+         str: Name of the run (suffixed with the version number)
+     """
+     # Set the mlflow URI
+     mlflow.set_tracking_uri(mlflow_uri)
+
+     # Get the runs and increment the version number
+     runs: list[Run] = get_runs_by_model_name(experiment_name, model_name, set_experiment=True)
+     run_number: int = len(runs) + 1
+     run_name: str = f"{model_name}_v{run_number:02d}" if not override_run_name else override_run_name
+
+     # Start the run
+     mlflow.start_run(run_name=run_name, tags={"model_name": model_name}, log_system_metrics=True, **kwargs)
+     return run_name
+
+ # Get best run by metric
+ def get_best_run_by_metric(
+     experiment_name: str,
+     metric_name: str,
+     model_name: str = "",
+     ascending: bool = False,
+     has_saved_model: bool = True
+ ) -> Run | None:
+     """ Get the best run by a specific metric.
+
+     Args:
+         experiment_name (str): Name of the experiment
+         metric_name (str): Name of the metric to sort by
+         model_name (str): Name of the model (optional, if empty, all models are considered)
+         ascending (bool): Whether to sort in ascending order (default: False, i.e. maximum metric value is best)
+         has_saved_model (bool): Whether the model has been saved (default: True)
+     Returns:
+         Run | None: The best run or None if no runs are found
+     """
+     # Get the runs
+     filter_string: str = f"metrics.`{metric_name}` > 0"
+     if model_name:
+         filter_string += f" AND tags.model_name = '{model_name}'"
+     if has_saved_model:
+         filter_string += " AND tags.has_saved_model = 'True'"
+
+     runs: list[Run] = get_runs_by_experiment_name(
+         experiment_name,
+         filter_string=filter_string,
+         set_experiment=True
+     )
+
+     if not runs:
+         return None
+
+     # Sort the runs by the metric
+     sorted_runs: list[Run] = sorted(
+         runs,
+         key=lambda run: float(run.data.metrics.get(metric_name, 0)), # type: ignore
+         reverse=not ascending
+     )
+
+     return sorted_runs[0] if sorted_runs else None
+
+
+ def load_model(run_id: str, model_type: Literal["keras", "pytorch"] = "keras") -> Any:
+     """ Load a model from MLflow.
+
+     Args:
+         run_id (str): ID of the run to load the model from
+         model_type (Literal["keras", "pytorch"]): Type of model to load (default: "keras")
+     Returns:
+         Any: The loaded model
+     """
+     if model_type == "keras":
+         return mlflow.keras.load_model(f"runs:/{run_id}/best_model") # type: ignore
+     elif model_type == "pytorch":
+         return mlflow.pytorch.load_model(f"runs:/{run_id}/best_model") # type: ignore
+     raise ValueError(f"Model type {model_type} not supported")
+
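As with the previous file, both sides of this mlflow_utils.py hunk are textually identical, so the change is again apparently whitespace- or line-ending-only. A short usage sketch of the API shown above (illustrative only: the tracking URI, experiment name, and model name are placeholders, and this assumes mlflow and the stouputils data_science extras are installed):

import mlflow
from stouputils.data_science import mlflow_utils

# Start a run named e.g. "my_model_v01"; the version suffix counts
# existing runs tagged with this model name.
run_name: str = mlflow_utils.start_run(
    mlflow_uri="file:./mlruns",
    experiment_name="demo_experiment",
    model_name="my_model",
)

# Log a Keras-style history dict as per-epoch metrics. Note the default
# prefix: "accuracy" is stored as "history_accuracy".
mlflow_utils.log_history({"loss": [0.9, 0.5, 0.3], "accuracy": [0.5, 0.7, 0.8]})
mlflow.end_run()

# Retrieve the best run by a metric (highest value wins by default);
# has_saved_model=False because this demo never sets that tag.
best = mlflow_utils.get_best_run_by_metric(
    "demo_experiment", metric_name="history_accuracy", has_saved_model=False
)
if best is not None:
    print(best.info.run_id, best.data.metrics)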