stouputils 1.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stouputils/__init__.py +40 -0
- stouputils/__main__.py +86 -0
- stouputils/_deprecated.py +37 -0
- stouputils/all_doctests.py +160 -0
- stouputils/applications/__init__.py +22 -0
- stouputils/applications/automatic_docs.py +634 -0
- stouputils/applications/upscaler/__init__.py +39 -0
- stouputils/applications/upscaler/config.py +128 -0
- stouputils/applications/upscaler/image.py +247 -0
- stouputils/applications/upscaler/video.py +287 -0
- stouputils/archive.py +344 -0
- stouputils/backup.py +488 -0
- stouputils/collections.py +244 -0
- stouputils/continuous_delivery/__init__.py +27 -0
- stouputils/continuous_delivery/cd_utils.py +243 -0
- stouputils/continuous_delivery/github.py +522 -0
- stouputils/continuous_delivery/pypi.py +130 -0
- stouputils/continuous_delivery/pyproject.py +147 -0
- stouputils/continuous_delivery/stubs.py +86 -0
- stouputils/ctx.py +408 -0
- stouputils/data_science/config/get.py +51 -0
- stouputils/data_science/config/set.py +125 -0
- stouputils/data_science/data_processing/image/__init__.py +66 -0
- stouputils/data_science/data_processing/image/auto_contrast.py +79 -0
- stouputils/data_science/data_processing/image/axis_flip.py +58 -0
- stouputils/data_science/data_processing/image/bias_field_correction.py +74 -0
- stouputils/data_science/data_processing/image/binary_threshold.py +73 -0
- stouputils/data_science/data_processing/image/blur.py +59 -0
- stouputils/data_science/data_processing/image/brightness.py +54 -0
- stouputils/data_science/data_processing/image/canny.py +110 -0
- stouputils/data_science/data_processing/image/clahe.py +92 -0
- stouputils/data_science/data_processing/image/common.py +30 -0
- stouputils/data_science/data_processing/image/contrast.py +53 -0
- stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -0
- stouputils/data_science/data_processing/image/denoise.py +378 -0
- stouputils/data_science/data_processing/image/histogram_equalization.py +123 -0
- stouputils/data_science/data_processing/image/invert.py +64 -0
- stouputils/data_science/data_processing/image/laplacian.py +60 -0
- stouputils/data_science/data_processing/image/median_blur.py +52 -0
- stouputils/data_science/data_processing/image/noise.py +59 -0
- stouputils/data_science/data_processing/image/normalize.py +65 -0
- stouputils/data_science/data_processing/image/random_erase.py +66 -0
- stouputils/data_science/data_processing/image/resize.py +69 -0
- stouputils/data_science/data_processing/image/rotation.py +80 -0
- stouputils/data_science/data_processing/image/salt_pepper.py +68 -0
- stouputils/data_science/data_processing/image/sharpening.py +55 -0
- stouputils/data_science/data_processing/image/shearing.py +64 -0
- stouputils/data_science/data_processing/image/threshold.py +64 -0
- stouputils/data_science/data_processing/image/translation.py +71 -0
- stouputils/data_science/data_processing/image/zoom.py +83 -0
- stouputils/data_science/data_processing/image_augmentation.py +118 -0
- stouputils/data_science/data_processing/image_preprocess.py +183 -0
- stouputils/data_science/data_processing/prosthesis_detection.py +359 -0
- stouputils/data_science/data_processing/technique.py +481 -0
- stouputils/data_science/dataset/__init__.py +45 -0
- stouputils/data_science/dataset/dataset.py +292 -0
- stouputils/data_science/dataset/dataset_loader.py +135 -0
- stouputils/data_science/dataset/grouping_strategy.py +296 -0
- stouputils/data_science/dataset/image_loader.py +100 -0
- stouputils/data_science/dataset/xy_tuple.py +696 -0
- stouputils/data_science/metric_dictionnary.py +106 -0
- stouputils/data_science/metric_utils.py +847 -0
- stouputils/data_science/mlflow_utils.py +206 -0
- stouputils/data_science/models/abstract_model.py +149 -0
- stouputils/data_science/models/all.py +85 -0
- stouputils/data_science/models/base_keras.py +765 -0
- stouputils/data_science/models/keras/all.py +38 -0
- stouputils/data_science/models/keras/convnext.py +62 -0
- stouputils/data_science/models/keras/densenet.py +50 -0
- stouputils/data_science/models/keras/efficientnet.py +60 -0
- stouputils/data_science/models/keras/mobilenet.py +56 -0
- stouputils/data_science/models/keras/resnet.py +52 -0
- stouputils/data_science/models/keras/squeezenet.py +233 -0
- stouputils/data_science/models/keras/vgg.py +42 -0
- stouputils/data_science/models/keras/xception.py +38 -0
- stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -0
- stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -0
- stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -0
- stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -0
- stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -0
- stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -0
- stouputils/data_science/models/keras_utils/losses/__init__.py +12 -0
- stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -0
- stouputils/data_science/models/keras_utils/visualizations.py +416 -0
- stouputils/data_science/models/model_interface.py +939 -0
- stouputils/data_science/models/sandbox.py +116 -0
- stouputils/data_science/range_tuple.py +234 -0
- stouputils/data_science/scripts/augment_dataset.py +77 -0
- stouputils/data_science/scripts/exhaustive_process.py +133 -0
- stouputils/data_science/scripts/preprocess_dataset.py +70 -0
- stouputils/data_science/scripts/routine.py +168 -0
- stouputils/data_science/utils.py +285 -0
- stouputils/decorators.py +605 -0
- stouputils/image.py +441 -0
- stouputils/installer/__init__.py +18 -0
- stouputils/installer/common.py +67 -0
- stouputils/installer/downloader.py +101 -0
- stouputils/installer/linux.py +144 -0
- stouputils/installer/main.py +223 -0
- stouputils/installer/windows.py +136 -0
- stouputils/io.py +486 -0
- stouputils/parallel.py +483 -0
- stouputils/print.py +482 -0
- stouputils/py.typed +1 -0
- stouputils/stouputils/__init__.pyi +15 -0
- stouputils/stouputils/_deprecated.pyi +12 -0
- stouputils/stouputils/all_doctests.pyi +46 -0
- stouputils/stouputils/applications/__init__.pyi +2 -0
- stouputils/stouputils/applications/automatic_docs.pyi +106 -0
- stouputils/stouputils/applications/upscaler/__init__.pyi +3 -0
- stouputils/stouputils/applications/upscaler/config.pyi +18 -0
- stouputils/stouputils/applications/upscaler/image.pyi +109 -0
- stouputils/stouputils/applications/upscaler/video.pyi +60 -0
- stouputils/stouputils/archive.pyi +67 -0
- stouputils/stouputils/backup.pyi +109 -0
- stouputils/stouputils/collections.pyi +86 -0
- stouputils/stouputils/continuous_delivery/__init__.pyi +5 -0
- stouputils/stouputils/continuous_delivery/cd_utils.pyi +129 -0
- stouputils/stouputils/continuous_delivery/github.pyi +162 -0
- stouputils/stouputils/continuous_delivery/pypi.pyi +53 -0
- stouputils/stouputils/continuous_delivery/pyproject.pyi +67 -0
- stouputils/stouputils/continuous_delivery/stubs.pyi +39 -0
- stouputils/stouputils/ctx.pyi +211 -0
- stouputils/stouputils/decorators.pyi +252 -0
- stouputils/stouputils/image.pyi +172 -0
- stouputils/stouputils/installer/__init__.pyi +5 -0
- stouputils/stouputils/installer/common.pyi +39 -0
- stouputils/stouputils/installer/downloader.pyi +24 -0
- stouputils/stouputils/installer/linux.pyi +39 -0
- stouputils/stouputils/installer/main.pyi +57 -0
- stouputils/stouputils/installer/windows.pyi +31 -0
- stouputils/stouputils/io.pyi +213 -0
- stouputils/stouputils/parallel.pyi +216 -0
- stouputils/stouputils/print.pyi +136 -0
- stouputils/stouputils/version_pkg.pyi +15 -0
- stouputils/version_pkg.py +189 -0
- stouputils-1.14.0.dist-info/METADATA +178 -0
- stouputils-1.14.0.dist-info/RECORD +140 -0
- stouputils-1.14.0.dist-info/WHEEL +4 -0
- stouputils-1.14.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains utility functions for working with MLflow.
|
|
3
|
+
|
|
4
|
+
This module contains functions for:
|
|
5
|
+
|
|
6
|
+
- Getting the artifact path from the current mlflow run
|
|
7
|
+
- Getting the weights path
|
|
8
|
+
- Getting the runs by experiment name or by model name
|
|
9
|
+
- Logging the history of the model to the current mlflow run
|
|
10
|
+
- Starting a new mlflow run
- Getting the best run of an experiment according to a metric
- Loading a model logged in a run
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
# Imports
|
|
14
|
+
import os
|
|
15
|
+
from typing import Any, Literal
|
|
16
|
+
|
|
17
|
+
import mlflow
|
|
18
|
+
from mlflow.entities import Experiment, Run
|
|
19
|
+
|
|
20
|
+
from ..decorators import handle_error, LogLevels
|
|
21
|
+
from ..io import clean_path
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Get artifact path
|
|
25
|
+
def get_artifact_path(from_string: str = "", os_name: str = os.name) -> str:
    """ Get the artifact path from the current mlflow run (without the file:// prefix).

    Handles the different path formats for Windows and Unix-based systems:

    - On Windows (``os_name == "nt"``), the whole ``file:///`` prefix is removed,
      so ``file:///C:/path`` becomes ``C:/path``.
    - On other systems only ``file://`` is removed, so ``file:///path`` keeps its
      leading slash (``/path``).

    Only a leading prefix is stripped; matching text later in the path is left untouched.

    Args:
        from_string (str): Path to the artifact (optional, defaults to the current mlflow run)
        os_name (str): OS name (optional, defaults to os.name)
    Returns:
        str: The artifact path

    Examples:
        >>> get_artifact_path(from_string="file:///path/to/artifact", os_name="posix")
        '/path/to/artifact'
        >>> get_artifact_path(from_string="file:///C:/path/to/artifact", os_name="nt")
        'C:/path/to/artifact'
    """
    # Get the artifact path from the given string, or from the current mlflow run when empty
    artifact_path: str = from_string if from_string else mlflow.get_artifact_uri()

    # Strip only the leading scheme (str.replace would also remove occurrences inside the path)
    prefix: str = "file:///" if os_name == "nt" else "file://"
    return artifact_path.removeprefix(prefix)
|
|
47
|
+
|
|
48
|
+
# Get weights path
|
|
49
|
+
def get_weights_path(from_string: str = "", weights_name: str = "best_model.keras", os_name: str = os.name) -> str:
    """ Get the weights path from the current mlflow run.

    The weights file name is appended to the artifact directory returned by
    get_artifact_path(), and the resulting path is normalized with clean_path().

    Args:
        from_string (str): Path to the artifact (optional, defaults to the current mlflow run)
        weights_name (str): Name of the weights file (optional, defaults to "best_model.keras")
        os_name (str): OS name (optional, defaults to os.name)
    Returns:
        str: The weights path

    Examples:
        >>> get_weights_path(from_string="file:///path/to/artifact", weights_name="best_model.keras", os_name="posix")
        '/path/to/artifact/best_model.keras'

        >>> get_weights_path(from_string="file:///C:/path/to/artifact", weights_name="best_model.keras", os_name="nt")
        'C:/path/to/artifact/best_model.keras'
    """
    artifacts_dir: str = get_artifact_path(from_string=from_string, os_name=os_name)
    weights_path: str = f"{artifacts_dir}/{weights_name}"
    return clean_path(weights_path)
|
|
67
|
+
|
|
68
|
+
# Get runs by experiment name
|
|
69
|
+
def get_runs_by_experiment_name(experiment_name: str, filter_string: str = "", set_experiment: bool = False) -> list[Run]:
    """ Get the runs by experiment name.

    Looks the experiment up by name and searches its runs with the given
    filter. An empty list is returned when no experiment with this name is found.

    Args:
        experiment_name (str): Name of the experiment
        filter_string (str): Filter string to apply to the runs (mlflow search syntax)
        set_experiment (bool): Whether to call mlflow.set_experiment(experiment_name) first
    Returns:
        list[Run]: List of runs
    """
    if set_experiment:
        mlflow.set_experiment(experiment_name)

    experiment: Experiment | None = mlflow.get_experiment_by_name(experiment_name)
    if not experiment:
        return []

    # output_format="list" makes mlflow return Run objects instead of a DataFrame
    return mlflow.search_runs(
        experiment_ids=[experiment.experiment_id],
        filter_string=filter_string,
        output_format="list",
    )  # pyright: ignore [reportReturnType]
|
|
89
|
+
|
|
90
|
+
def get_runs_by_model_name(experiment_name: str, model_name: str, set_experiment: bool = False) -> list[Run]:
    """ Get the runs by model name.

    Returns the runs of the experiment whose ``model_name`` tag equals ``model_name``.

    Args:
        experiment_name (str): Name of the experiment
        model_name (str): Name of the model
        set_experiment (bool): Whether to set the experiment
    Returns:
        list[Run]: List of runs
    """
    # NOTE(review): model_name is inserted verbatim between single quotes;
    # a name containing a single quote would break the filter expression.
    model_filter: str = f"tags.model_name = '{model_name}'"
    return get_runs_by_experiment_name(experiment_name, filter_string=model_filter, set_experiment=set_experiment)
|
|
105
|
+
|
|
106
|
+
# Log history
|
|
107
|
+
def log_history(history: dict[str, list[Any]], prefix: str = "history", **kwargs: Any) -> None:
    """ Log the history of the model to the current mlflow run.

    Every value of every metric is logged as a separate mlflow metric named
    ``"{prefix}_{metric}"``, using its index in the list as the step. Each call
    to mlflow.log_metric is wrapped with handle_error (traceback log level).

    Args:
        history (dict[str, list[Any]]): History of the model
            (usually from a History object like from a Keras model: history.history)
        prefix (str): Prefix prepended (with an underscore) to every metric name
        **kwargs (Any): Additional arguments to pass to mlflow.log_metric
    """
    for metric_name, metric_values in history.items():
        metric_key: str = f"{prefix}_{metric_name}"
        for step, metric_value in enumerate(metric_values):
            safe_log_metric = handle_error(
                mlflow.log_metric,
                message=f"Error logging metric {metric_name}",
                error_log=LogLevels.ERROR_TRACEBACK,
            )
            safe_log_metric(metric_key, metric_value, step=step, **kwargs)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def start_run(mlflow_uri: str, experiment_name: str, model_name: str, override_run_name: str = "", **kwargs: Any) -> str:
    """ Start a new mlflow run.

    Sets the tracking URI, makes ``experiment_name`` the active experiment and
    starts a run tagged with ``model_name``. Unless ``override_run_name`` is given,
    the run is named ``"{model_name}_vXX"``, where XX is one more than the number of
    existing runs tagged with this model name (zero-padded to two digits).

    Args:
        mlflow_uri (str): MLflow URI
        experiment_name (str): Name of the experiment
        model_name (str): Name of the model
        override_run_name (str): Override the run name (if empty, it will be set automatically)
        **kwargs (Any): Additional arguments to pass to mlflow.start_run.
            A ``tags`` dict is merged with the ``model_name`` tag (the ``model_name`` tag set
            here takes precedence), and ``log_system_metrics`` overrides the default of True.
    Returns:
        str: Name of the run (suffixed with the version number unless overridden)
    """
    # Set the mlflow URI
    mlflow.set_tracking_uri(mlflow_uri)

    # Get the runs and increment the version number (also sets the active experiment)
    runs: list[Run] = get_runs_by_model_name(experiment_name, model_name, set_experiment=True)
    run_number: int = len(runs) + 1
    run_name: str = override_run_name or f"{model_name}_v{run_number:02d}"

    # Merge caller-supplied tags / log_system_metrics instead of passing them twice,
    # which previously raised "got multiple values for keyword argument".
    # model_name must win: get_runs_by_model_name relies on this tag for versioning.
    tags: dict[str, Any] = {**(kwargs.pop("tags", None) or {}), "model_name": model_name}
    kwargs.setdefault("log_system_metrics", True)

    # Start the run
    mlflow.start_run(run_name=run_name, tags=tags, **kwargs)
    return run_name
|
|
146
|
+
|
|
147
|
+
# Get best run by metric
|
|
148
|
+
def get_best_run_by_metric(
    experiment_name: str,
    metric_name: str,
    model_name: str = "",
    ascending: bool = False,
    has_saved_model: bool = True
) -> Run | None:
    """ Get the best run by a specific metric.

    Only runs whose ``metric_name`` value is strictly greater than 0 are considered
    (this is part of the search filter). Metrics that can legitimately be zero or
    negative will therefore exclude those runs.

    Args:
        experiment_name (str): Name of the experiment
        metric_name (str): Name of the metric to sort by
        model_name (str): Name of the model (optional, if empty, all models are considered)
        ascending (bool): Whether to sort in ascending order (default: False, i.e. maximum metric value is best)
        has_saved_model (bool): Whether to only consider runs tagged with has_saved_model = 'True' (default: True)
    Returns:
        Run | None: The best run or None if no runs are found
    """
    # Build the search filter
    filter_parts: list[str] = [f"metrics.`{metric_name}` > 0"]
    if model_name:
        filter_parts.append(f"tags.model_name = '{model_name}'")
    if has_saved_model:
        filter_parts.append("tags.has_saved_model = 'True'")

    runs: list[Run] = get_runs_by_experiment_name(
        experiment_name,
        filter_string=" AND ".join(filter_parts),
        set_experiment=True
    )
    if not runs:
        return None

    # min()/max() return the first run reaching the extreme value. This matches the
    # first element of the stable sort used previously, without sorting every run.
    metric_of = lambda run: float(run.data.metrics.get(metric_name, 0))  # type: ignore  # noqa: E731
    return min(runs, key=metric_of) if ascending else max(runs, key=metric_of)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def load_model(run_id: str, model_type: Literal["keras", "pytorch"] = "keras") -> Any:
|
|
193
|
+
""" Load a model from MLflow.
|
|
194
|
+
|
|
195
|
+
Args:
|
|
196
|
+
run_id (str): ID of the run to load the model from
|
|
197
|
+
model_type (Literal["keras", "pytorch"]): Type of model to load (default: "keras")
|
|
198
|
+
Returns:
|
|
199
|
+
Any: The loaded model
|
|
200
|
+
"""
|
|
201
|
+
if model_type == "keras":
|
|
202
|
+
return mlflow.keras.load_model(f"runs:/{run_id}/best_model") # type: ignore
|
|
203
|
+
elif model_type == "pytorch":
|
|
204
|
+
return mlflow.pytorch.load_model(f"runs:/{run_id}/best_model") # type: ignore
|
|
205
|
+
raise ValueError(f"Model type {model_type} not supported")
|
|
206
|
+
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
""" Abstract base class for all model implementations.
|
|
2
|
+
Defines the interface that all concrete model classes must implement.
|
|
3
|
+
|
|
4
|
+
Provides abstract methods for core model operations including:
|
|
5
|
+
|
|
6
|
+
- Class routine management
|
|
7
|
+
- Model loading
|
|
8
|
+
- Training procedures
|
|
9
|
+
- Prediction functionality
|
|
10
|
+
- Evaluation metrics
|
|
11
|
+
|
|
12
|
+
Classes inheriting from AbstractModel must implement all methods.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
# Imports
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import multiprocessing.queues
|
|
19
|
+
from collections.abc import Iterable
|
|
20
|
+
from tempfile import TemporaryDirectory
|
|
21
|
+
from typing import Any
|
|
22
|
+
|
|
23
|
+
from ...decorators import abstract, LogLevels
|
|
24
|
+
|
|
25
|
+
from ..dataset import Dataset
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# Base class
|
|
29
|
+
class AbstractModel:
    """ Abstract class for all models to copy and implement the methods.

    Every method is wrapped with ``abstract(error_log=LogLevels.ERROR_TRACEBACK)``
    and only carries a placeholder body; concrete model classes override them.
    """

    # Class methods
    @abstract(error_log=LogLevels.ERROR_TRACEBACK)
    def __init__(
        self, num_classes: int, kfold: int = 0, transfer_learning: str = "imagenet", **override_params: Any
    ) -> None:
        """ Initialize the model.

        Args:
            num_classes (int): Number of classes
            kfold (int): K-fold setting (default: 0)
            transfer_learning (str): Transfer learning source (default: "imagenet")
            **override_params (Any): Extra keyword parameters (presumably overriding defaults; defined by subclasses)
        """

    ## Public abstract methods
    @abstract(error_log=LogLevels.ERROR_TRACEBACK)
    def routine_full(self, dataset: Dataset, verbose: int = 0) -> AbstractModel:
        """ Run the full routine on a dataset; the placeholder returns the instance itself. """
        return self

    @abstract(error_log=LogLevels.ERROR_TRACEBACK)
    def class_load(self) -> None:
        """ Load the model (placeholder). """

    @abstract(error_log=LogLevels.ERROR_TRACEBACK)
    def class_train(self, dataset: Dataset, verbose: int = 0) -> bool:
        """ Train the model on a dataset; the placeholder returns False. """
        return False

    @abstract(error_log=LogLevels.ERROR_TRACEBACK)
    def class_predict(self, X_test: Iterable[Any]) -> Iterable[Any]:
        """ Predict on X_test; the placeholder returns an empty list. """
        return []

    @abstract(error_log=LogLevels.ERROR_TRACEBACK)
    def class_evaluate(
        self,
        dataset: Dataset,
        metrics_names: tuple[str, ...] = (),
        save_model: bool = False,
        verbose: int = 0
    ) -> bool:
        """ Evaluate the model on a dataset; the placeholder returns False.

        Args:
            dataset (Dataset): Dataset to evaluate on
            metrics_names (tuple[str, ...]): Names of the metrics to compute
            save_model (bool): Whether the model should be saved
            verbose (int): Verbosity level
        """
        return False

    ## Protected abstract methods
    @abstract(error_log=LogLevels.ERROR_TRACEBACK)
    def _fit(
        self,
        model: Any,
        x: Any,
        y: Any | None = None,
        validation_data: tuple[Any, Any] | None = None,
        shuffle: bool = True,
        batch_size: int | None = None,
        epochs: int = 1,
        callbacks: list[Any] | None = None,
        class_weight: dict[int, float] | None = None,
        verbose: int = 0,
        *args: Any,
        **kwargs: Any
    ) -> Any:
        """ Fit ``model`` on the given data (placeholder, returns None). """

    @abstract(error_log=LogLevels.ERROR_TRACEBACK)
    def _get_callbacks(self) -> list[Any]:
        """ Return the training callbacks; the placeholder returns an empty list. """
        return []

    @abstract(error_log=LogLevels.ERROR_TRACEBACK)
    def _get_metrics(self) -> list[Any]:
        """ Return the training metrics; the placeholder returns an empty list. """
        return []

    @abstract(error_log=LogLevels.ERROR_TRACEBACK)
    def _get_optimizer(self, learning_rate: float = 0.0) -> Any:
        """ Build the optimizer for ``learning_rate`` (placeholder, returns None). """

    @abstract(error_log=LogLevels.ERROR_TRACEBACK)
    def _get_loss(self) -> Any:
        """ Build the loss (placeholder, returns None). """

    @abstract(error_log=LogLevels.ERROR_TRACEBACK)
    def _get_base_model(self) -> Any:
        """ Build the base model (placeholder, returns None). """

    @abstract(error_log=LogLevels.ERROR_TRACEBACK)
    def _get_architectures(
        self, optimizer: Any = None, loss: Any = None, metrics: list[Any] | None = None
    ) -> tuple[Any, Any]:
        """ Build the model architectures; the placeholder returns ``(None, None)``. """
        return None, None

    @abstract(error_log=LogLevels.ERROR_TRACEBACK)
    def _find_best_learning_rate(self, dataset: Dataset, verbose: int = 0) -> float:
        """ Search for the best learning rate; the placeholder returns 0.0. """
        return 0.0

    @abstract(error_log=LogLevels.ERROR_TRACEBACK)
    def _train_fold(self, dataset: Dataset, fold_number: int = 0, mlflow_prefix: str = "history", verbose: int = 0) -> Any:
        """ Train a single fold (placeholder, returns None). """

    @abstract(error_log=LogLevels.ERROR_TRACEBACK)
    def _log_final_model(self) -> None:
        """ Log the final model (placeholder). """

    @abstract(error_log=LogLevels.ERROR_TRACEBACK)
    def _find_best_learning_rate_subprocess(
        self, dataset: Dataset, queue: multiprocessing.queues.Queue[Any] | None = None, verbose: int = 0
    ) -> dict[str, Any] | None:
        """ Learning-rate search intended for a subprocess, optionally given a queue (placeholder, returns None). """

    @abstract(error_log=LogLevels.ERROR_TRACEBACK)
    def _find_best_unfreeze_percentage_subprocess(
        self, dataset: Dataset, queue: multiprocessing.queues.Queue[Any] | None = None, verbose: int = 0
    ) -> dict[str, Any] | None:
        """ Unfreeze-percentage search intended for a subprocess, optionally given a queue (placeholder, returns None). """

    @abstract(error_log=LogLevels.ERROR_TRACEBACK)
    def _train_subprocess(
        self,
        dataset: Dataset,
        checkpoint_path: str,
        temp_dir: TemporaryDirectory[str] | None = None,
        queue: multiprocessing.queues.Queue[Any] | None = None,
        verbose: int = 0
    ) -> dict[str, Any] | None:
        """ Training intended for a subprocess, given a checkpoint path (placeholder, returns None). """
|
|
149
|
+
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
|
|
2
|
+
# Imports
|
|
3
|
+
import itertools
|
|
4
|
+
|
|
5
|
+
from .keras.all import (
|
|
6
|
+
VGG16,
|
|
7
|
+
VGG19,
|
|
8
|
+
ConvNeXtBase,
|
|
9
|
+
ConvNeXtLarge,
|
|
10
|
+
ConvNeXtSmall,
|
|
11
|
+
ConvNeXtTiny,
|
|
12
|
+
ConvNeXtXLarge,
|
|
13
|
+
DenseNet121,
|
|
14
|
+
DenseNet169,
|
|
15
|
+
DenseNet201,
|
|
16
|
+
EfficientNetB0,
|
|
17
|
+
EfficientNetV2B0,
|
|
18
|
+
EfficientNetV2L,
|
|
19
|
+
EfficientNetV2M,
|
|
20
|
+
EfficientNetV2S,
|
|
21
|
+
MobileNet,
|
|
22
|
+
MobileNetV2,
|
|
23
|
+
MobileNetV3Large,
|
|
24
|
+
MobileNetV3Small,
|
|
25
|
+
ResNet50V2,
|
|
26
|
+
ResNet101V2,
|
|
27
|
+
ResNet152V2,
|
|
28
|
+
SqueezeNet,
|
|
29
|
+
Xception,
|
|
30
|
+
)
|
|
31
|
+
from .model_interface import ModelInterface
|
|
32
|
+
|
|
33
|
+
# Other models
|
|
34
|
+
from .sandbox import Sandbox
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# Create a custom dictionary class to allow for documentation
|
|
38
|
+
class ModelClassMap(dict[type[ModelInterface], tuple[str, ...]]):
    """ Dictionary mapping model classes to their names and aliases.

    This docstring is replaced later in this module by one listing every entry of CLASS_MAP.
    """
|
|
40
|
+
|
|
41
|
+
# Routine map
|
|
42
|
+
CLASS_MAP: ModelClassMap = ModelClassMap({
    SqueezeNet: ("squeezenet", "squeezenets", "all", "often"),

    DenseNet121: ("densenet121", "densenets", "all", "often", "good"),
    DenseNet169: ("densenet169", "densenets", "all", "often", "good"),
    DenseNet201: ("densenet201", "densenets", "all", "often", "good"),

    EfficientNetB0: ("efficientnetb0", "efficientnets", "all"),
    EfficientNetV2B0: ("efficientnetv2b0", "efficientnets", "all"),
    EfficientNetV2S: ("efficientnetv2s", "efficientnets", "all", "often"),
    EfficientNetV2M: ("efficientnetv2m", "efficientnets", "all", "often"),
    EfficientNetV2L: ("efficientnetv2l", "efficientnets", "all", "often"),

    ConvNeXtTiny: ("convnexttiny", "convnexts", "all", "often", "good"),
    ConvNeXtSmall: ("convnextsmall", "convnexts", "all", "often"),
    ConvNeXtBase: ("convnextbase", "convnexts", "all", "often", "good"),
    ConvNeXtLarge: ("convnextlarge", "convnexts", "all", "often"),
    ConvNeXtXLarge: ("convnextxlarge", "convnexts", "all", "often", "good"),

    VGG16: ("vgg16", "vggs", "all"),
    VGG19: ("vgg19", "vggs", "all"),

    MobileNet: ("mobilenet", "mobilenets", "all"),
    MobileNetV2: ("mobilenetv2", "mobilenets", "all", "often"),
    MobileNetV3Small: ("mobilenetv3small", "mobilenets", "all", "often"),
    MobileNetV3Large: ("mobilenetv3large", "mobilenets", "all", "often", "good"),

    ResNet50V2: ("resnet50v2", "resnetsv2", "resnets", "all", "often"),
    ResNet101V2: ("resnet101v2", "resnetsv2", "resnets", "all", "often"),
    ResNet152V2: ("resnet152v2", "resnetsv2", "resnets", "all", "often"),

    Xception: ("xception", "xceptions", "all", "often"),
    Sandbox: ("sandbox",),
})

# All models names and aliases (deduplicated, sorted): chain the alias tuples directly
ALL_MODELS: list[str] = sorted(set(itertools.chain.from_iterable(CLASS_MAP.values())))
""" All models names and aliases found in the `CLASS_MAP` dictionary. """

# Additional docstring: list every class with its aliases
new_docstring: str = "\n\n" + "\n".join(f"- {k.__name__}: {v}" for k, v in CLASS_MAP.items())
ModelClassMap.__doc__ = "Dictionary mapping class to their names and aliases. " + new_docstring
CLASS_MAP.__doc__ = ModelClassMap.__doc__
|
|
85
|
+
|