stouputils-1.12.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. stouputils/__init__.py +40 -0
  2. stouputils/__init__.pyi +14 -0
  3. stouputils/__main__.py +81 -0
  4. stouputils/_deprecated.py +37 -0
  5. stouputils/_deprecated.pyi +12 -0
  6. stouputils/all_doctests.py +160 -0
  7. stouputils/all_doctests.pyi +46 -0
  8. stouputils/applications/__init__.py +22 -0
  9. stouputils/applications/__init__.pyi +2 -0
  10. stouputils/applications/automatic_docs.py +634 -0
  11. stouputils/applications/automatic_docs.pyi +106 -0
  12. stouputils/applications/upscaler/__init__.py +39 -0
  13. stouputils/applications/upscaler/__init__.pyi +3 -0
  14. stouputils/applications/upscaler/config.py +128 -0
  15. stouputils/applications/upscaler/config.pyi +18 -0
  16. stouputils/applications/upscaler/image.py +247 -0
  17. stouputils/applications/upscaler/image.pyi +109 -0
  18. stouputils/applications/upscaler/video.py +287 -0
  19. stouputils/applications/upscaler/video.pyi +60 -0
  20. stouputils/archive.py +344 -0
  21. stouputils/archive.pyi +67 -0
  22. stouputils/backup.py +488 -0
  23. stouputils/backup.pyi +109 -0
  24. stouputils/collections.py +244 -0
  25. stouputils/collections.pyi +86 -0
  26. stouputils/continuous_delivery/__init__.py +27 -0
  27. stouputils/continuous_delivery/__init__.pyi +5 -0
  28. stouputils/continuous_delivery/cd_utils.py +243 -0
  29. stouputils/continuous_delivery/cd_utils.pyi +129 -0
  30. stouputils/continuous_delivery/github.py +522 -0
  31. stouputils/continuous_delivery/github.pyi +162 -0
  32. stouputils/continuous_delivery/pypi.py +91 -0
  33. stouputils/continuous_delivery/pypi.pyi +43 -0
  34. stouputils/continuous_delivery/pyproject.py +147 -0
  35. stouputils/continuous_delivery/pyproject.pyi +67 -0
  36. stouputils/continuous_delivery/stubs.py +86 -0
  37. stouputils/continuous_delivery/stubs.pyi +39 -0
  38. stouputils/ctx.py +408 -0
  39. stouputils/ctx.pyi +211 -0
  40. stouputils/data_science/config/get.py +51 -0
  41. stouputils/data_science/config/set.py +125 -0
  42. stouputils/data_science/data_processing/image/__init__.py +66 -0
  43. stouputils/data_science/data_processing/image/auto_contrast.py +79 -0
  44. stouputils/data_science/data_processing/image/axis_flip.py +58 -0
  45. stouputils/data_science/data_processing/image/bias_field_correction.py +74 -0
  46. stouputils/data_science/data_processing/image/binary_threshold.py +73 -0
  47. stouputils/data_science/data_processing/image/blur.py +59 -0
  48. stouputils/data_science/data_processing/image/brightness.py +54 -0
  49. stouputils/data_science/data_processing/image/canny.py +110 -0
  50. stouputils/data_science/data_processing/image/clahe.py +92 -0
  51. stouputils/data_science/data_processing/image/common.py +30 -0
  52. stouputils/data_science/data_processing/image/contrast.py +53 -0
  53. stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -0
  54. stouputils/data_science/data_processing/image/denoise.py +378 -0
  55. stouputils/data_science/data_processing/image/histogram_equalization.py +123 -0
  56. stouputils/data_science/data_processing/image/invert.py +64 -0
  57. stouputils/data_science/data_processing/image/laplacian.py +60 -0
  58. stouputils/data_science/data_processing/image/median_blur.py +52 -0
  59. stouputils/data_science/data_processing/image/noise.py +59 -0
  60. stouputils/data_science/data_processing/image/normalize.py +65 -0
  61. stouputils/data_science/data_processing/image/random_erase.py +66 -0
  62. stouputils/data_science/data_processing/image/resize.py +69 -0
  63. stouputils/data_science/data_processing/image/rotation.py +80 -0
  64. stouputils/data_science/data_processing/image/salt_pepper.py +68 -0
  65. stouputils/data_science/data_processing/image/sharpening.py +55 -0
  66. stouputils/data_science/data_processing/image/shearing.py +64 -0
  67. stouputils/data_science/data_processing/image/threshold.py +64 -0
  68. stouputils/data_science/data_processing/image/translation.py +71 -0
  69. stouputils/data_science/data_processing/image/zoom.py +83 -0
  70. stouputils/data_science/data_processing/image_augmentation.py +118 -0
  71. stouputils/data_science/data_processing/image_preprocess.py +183 -0
  72. stouputils/data_science/data_processing/prosthesis_detection.py +359 -0
  73. stouputils/data_science/data_processing/technique.py +481 -0
  74. stouputils/data_science/dataset/__init__.py +45 -0
  75. stouputils/data_science/dataset/dataset.py +292 -0
  76. stouputils/data_science/dataset/dataset_loader.py +135 -0
  77. stouputils/data_science/dataset/grouping_strategy.py +296 -0
  78. stouputils/data_science/dataset/image_loader.py +100 -0
  79. stouputils/data_science/dataset/xy_tuple.py +696 -0
  80. stouputils/data_science/metric_dictionnary.py +106 -0
  81. stouputils/data_science/metric_utils.py +847 -0
  82. stouputils/data_science/mlflow_utils.py +206 -0
  83. stouputils/data_science/models/abstract_model.py +149 -0
  84. stouputils/data_science/models/all.py +85 -0
  85. stouputils/data_science/models/base_keras.py +765 -0
  86. stouputils/data_science/models/keras/all.py +38 -0
  87. stouputils/data_science/models/keras/convnext.py +62 -0
  88. stouputils/data_science/models/keras/densenet.py +50 -0
  89. stouputils/data_science/models/keras/efficientnet.py +60 -0
  90. stouputils/data_science/models/keras/mobilenet.py +56 -0
  91. stouputils/data_science/models/keras/resnet.py +52 -0
  92. stouputils/data_science/models/keras/squeezenet.py +233 -0
  93. stouputils/data_science/models/keras/vgg.py +42 -0
  94. stouputils/data_science/models/keras/xception.py +38 -0
  95. stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -0
  96. stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -0
  97. stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -0
  98. stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -0
  99. stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -0
  100. stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -0
  101. stouputils/data_science/models/keras_utils/losses/__init__.py +12 -0
  102. stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -0
  103. stouputils/data_science/models/keras_utils/visualizations.py +416 -0
  104. stouputils/data_science/models/model_interface.py +939 -0
  105. stouputils/data_science/models/sandbox.py +116 -0
  106. stouputils/data_science/range_tuple.py +234 -0
  107. stouputils/data_science/scripts/augment_dataset.py +77 -0
  108. stouputils/data_science/scripts/exhaustive_process.py +133 -0
  109. stouputils/data_science/scripts/preprocess_dataset.py +70 -0
  110. stouputils/data_science/scripts/routine.py +168 -0
  111. stouputils/data_science/utils.py +285 -0
  112. stouputils/decorators.py +595 -0
  113. stouputils/decorators.pyi +242 -0
  114. stouputils/image.py +441 -0
  115. stouputils/image.pyi +172 -0
  116. stouputils/installer/__init__.py +18 -0
  117. stouputils/installer/__init__.pyi +5 -0
  118. stouputils/installer/common.py +67 -0
  119. stouputils/installer/common.pyi +39 -0
  120. stouputils/installer/downloader.py +101 -0
  121. stouputils/installer/downloader.pyi +24 -0
  122. stouputils/installer/linux.py +144 -0
  123. stouputils/installer/linux.pyi +39 -0
  124. stouputils/installer/main.py +223 -0
  125. stouputils/installer/main.pyi +57 -0
  126. stouputils/installer/windows.py +136 -0
  127. stouputils/installer/windows.pyi +31 -0
  128. stouputils/io.py +486 -0
  129. stouputils/io.pyi +213 -0
  130. stouputils/parallel.py +453 -0
  131. stouputils/parallel.pyi +211 -0
  132. stouputils/print.py +527 -0
  133. stouputils/print.pyi +146 -0
  134. stouputils/py.typed +1 -0
  135. stouputils-1.12.1.dist-info/METADATA +179 -0
  136. stouputils-1.12.1.dist-info/RECORD +138 -0
  137. stouputils-1.12.1.dist-info/WHEEL +4 -0
  138. stouputils-1.12.1.dist-info/entry_points.txt +3 -0
stouputils/data_science/mlflow_utils.py
@@ -0,0 +1,206 @@
+ """
+ This module contains utility functions for working with MLflow.
+
+ This module contains functions for:
+
+ - Getting the artifact path from the current mlflow run
+ - Getting the weights path
+ - Getting the runs by experiment name
+ - Logging the history of the model to the current mlflow run
+ - Starting a new mlflow run
+ """
+
+ # Imports
+ import os
+ from typing import Any, Literal
+
+ import mlflow
+ from mlflow.entities import Experiment, Run
+
+ from ..decorators import handle_error, LogLevels
+ from ..io import clean_path
+
+
+ # Get artifact path
+ def get_artifact_path(from_string: str = "", os_name: str = os.name) -> str:
+     """ Get the artifact path from the current mlflow run (without the file:// prefix).
+
+     Handles the different path formats for Windows and Unix-based systems.
+
+     Args:
+         from_string (str): Path to the artifact (optional, defaults to the current mlflow run)
+         os_name (str): OS name (optional, defaults to os.name)
+     Returns:
+         str: The artifact path
+     """
+     # Get the artifact path from the current mlflow run or from a string
+     if not from_string:
+         artifact_path: str = mlflow.get_artifact_uri()
+     else:
+         artifact_path: str = from_string
+
+     # Handle the different path formats for Windows and Unix-based systems
+     if os_name == "nt":
+         return artifact_path.replace("file:///", "")
+     else:
+         return artifact_path.replace("file://", "")
+
+ # Get weights path
+ def get_weights_path(from_string: str = "", weights_name: str = "best_model.keras", os_name: str = os.name) -> str:
+     """ Get the weights path from the current mlflow run.
+
+     Args:
+         from_string (str): Path to the artifact (optional, defaults to the current mlflow run)
+         weights_name (str): Name of the weights file (optional, defaults to "best_model.keras")
+         os_name (str): OS name (optional, defaults to os.name)
+     Returns:
+         str: The weights path
+
+     Examples:
+         >>> get_weights_path(from_string="file:///path/to/artifact", weights_name="best_model.keras", os_name="posix")
+         '/path/to/artifact/best_model.keras'
+
+         >>> get_weights_path(from_string="file:///C:/path/to/artifact", weights_name="best_model.keras", os_name="nt")
+         'C:/path/to/artifact/best_model.keras'
+     """
+     return clean_path(f"{get_artifact_path(from_string=from_string, os_name=os_name)}/{weights_name}")
+
+ # Get runs by experiment name
+ def get_runs_by_experiment_name(experiment_name: str, filter_string: str = "", set_experiment: bool = False) -> list[Run]:
+     """ Get the runs by experiment name.
+
+     Args:
+         experiment_name (str): Name of the experiment
+         filter_string (str): Filter string to apply to the runs
+         set_experiment (bool): Whether to set the experiment
+     Returns:
+         list[Run]: List of runs
+     """
+     if set_experiment:
+         mlflow.set_experiment(experiment_name)
+     experiment: Experiment | None = mlflow.get_experiment_by_name(experiment_name)
+     if experiment:
+         return mlflow.search_runs(
+             experiment_ids=[experiment.experiment_id],
+             output_format="list",
+             filter_string=filter_string
+         ) # pyright: ignore [reportReturnType]
+     return []
+
+ def get_runs_by_model_name(experiment_name: str, model_name: str, set_experiment: bool = False) -> list[Run]:
+     """ Get the runs by model name.
+
+     Args:
+         experiment_name (str): Name of the experiment
+         model_name (str): Name of the model
+         set_experiment (bool): Whether to set the experiment
+     Returns:
+         list[Run]: List of runs
+     """
+     return get_runs_by_experiment_name(
+         experiment_name,
+         filter_string=f"tags.model_name = '{model_name}'",
+         set_experiment=set_experiment
+     )
+
+ # Log history
+ def log_history(history: dict[str, list[Any]], prefix: str = "history", **kwargs: Any) -> None:
+     """ Log the history of the model to the current mlflow run.
+
+     Args:
+         history (dict[str, list[Any]]): History of the model
+             (usually from a History object like from a Keras model: history.history)
+         **kwargs (Any): Additional arguments to pass to mlflow.log_metric
+     """
+     for (metric, values) in history.items():
+         for epoch, value in enumerate(values):
+             handle_error(mlflow.log_metric,
+                 message=f"Error logging metric {metric}",
+                 error_log=LogLevels.ERROR_TRACEBACK
+             )(f"{prefix}_{metric}", value, step=epoch, **kwargs)
+
+
+ def start_run(mlflow_uri: str, experiment_name: str, model_name: str, override_run_name: str = "", **kwargs: Any) -> str:
+     """ Start a new mlflow run.
+
+     Args:
+         mlflow_uri (str): MLflow URI
+         experiment_name (str): Name of the experiment
+         model_name (str): Name of the model
+         override_run_name (str): Override the run name (if empty, it will be set automatically)
+         **kwargs (Any): Additional arguments to pass to mlflow.start_run
+     Returns:
+         str: Name of the run (suffixed with the version number)
+     """
+     # Set the mlflow URI
+     mlflow.set_tracking_uri(mlflow_uri)
+
+     # Get the runs and increment the version number
+     runs: list[Run] = get_runs_by_model_name(experiment_name, model_name, set_experiment=True)
+     run_number: int = len(runs) + 1
+     run_name: str = f"{model_name}_v{run_number:02d}" if not override_run_name else override_run_name
+
+     # Start the run
+     mlflow.start_run(run_name=run_name, tags={"model_name": model_name}, log_system_metrics=True, **kwargs)
+     return run_name
+
+ # Get best run by metric
+ def get_best_run_by_metric(
+     experiment_name: str,
+     metric_name: str,
+     model_name: str = "",
+     ascending: bool = False,
+     has_saved_model: bool = True
+ ) -> Run | None:
+     """ Get the best run by a specific metric.
+
+     Args:
+         experiment_name (str): Name of the experiment
+         metric_name (str): Name of the metric to sort by
+         model_name (str): Name of the model (optional, if empty, all models are considered)
+         ascending (bool): Whether to sort in ascending order (default: False, i.e. maximum metric value is best)
+         has_saved_model (bool): Whether the model has been saved (default: True)
+     Returns:
+         Run | None: The best run or None if no runs are found
+     """
+     # Get the runs
+     filter_string: str = f"metrics.`{metric_name}` > 0"
+     if model_name:
+         filter_string += f" AND tags.model_name = '{model_name}'"
+     if has_saved_model:
+         filter_string += " AND tags.has_saved_model = 'True'"
+
+     runs: list[Run] = get_runs_by_experiment_name(
+         experiment_name,
+         filter_string=filter_string,
+         set_experiment=True
+     )
+
+     if not runs:
+         return None
+
+     # Sort the runs by the metric
+     sorted_runs: list[Run] = sorted(
+         runs,
+         key=lambda run: float(run.data.metrics.get(metric_name, 0)), # type: ignore
+         reverse=not ascending
+     )
+
+     return sorted_runs[0] if sorted_runs else None
+
+
+ def load_model(run_id: str, model_type: Literal["keras", "pytorch"] = "keras") -> Any:
+     """ Load a model from MLflow.
+
+     Args:
+         run_id (str): ID of the run to load the model from
+         model_type (Literal["keras", "pytorch"]): Type of model to load (default: "keras")
+     Returns:
+         Any: The loaded model
+     """
+     if model_type == "keras":
+         return mlflow.keras.load_model(f"runs:/{run_id}/best_model") # type: ignore
+     elif model_type == "pytorch":
+         return mlflow.pytorch.load_model(f"runs:/{run_id}/best_model") # type: ignore
+     raise ValueError(f"Model type {model_type} not supported")
+
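Usage sketch (not part of the package diff): a minimal example of how these helpers fit together, assuming a local tracking server at a placeholder URI and placeholder experiment/model names. The history dict stands in for a Keras History.history, and get_best_run_by_metric/load_model additionally assume a run that saved a best_model artifact and the has_saved_model tag, as the package's training routines do.

import mlflow
from stouputils.data_science import mlflow_utils

# Start a versioned run for a model under an experiment (placeholder URI and names)
run_name = mlflow_utils.start_run("http://localhost:5000", "my_experiment", "densenet121")

# Log a Keras-style history dict as per-epoch metrics, then close the run
history = {"loss": [0.9, 0.5, 0.3], "val_accuracy": [0.6, 0.7, 0.8]}
mlflow_utils.log_history(history, prefix="history")
mlflow.end_run()

# Later: pick the best run by a logged metric and reload its saved model
best = mlflow_utils.get_best_run_by_metric("my_experiment", "history_val_accuracy", model_name="densenet121")
if best is not None:
    model = mlflow_utils.load_model(best.info.run_id)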
stouputils/data_science/models/abstract_model.py
@@ -0,0 +1,149 @@
+ """ Abstract base class for all model implementations.
+ Defines the interface that all concrete model classes must implement.
+
+ Provides abstract methods for core model operations including:
+
+ - Class routine management
+ - Model loading
+ - Training procedures
+ - Prediction functionality
+ - Evaluation metrics
+
+ Classes inheriting from AbstractModel must implement all methods.
+ """
+
+ # Imports
+ from __future__ import annotations
+
+ import multiprocessing.queues
+ from collections.abc import Iterable
+ from tempfile import TemporaryDirectory
+ from typing import Any
+
+ from ...decorators import abstract, LogLevels
+
+ from ..dataset import Dataset
+
+
+ # Base class
+ class AbstractModel:
+     """ Abstract class for all models to copy and implement the methods. """
+     # Class methods
+     @abstract(error_log=LogLevels.ERROR_TRACEBACK)
+     def __init__(
+         self, num_classes: int, kfold: int = 0, transfer_learning: str = "imagenet", **override_params: Any
+     ) -> None:
+         pass
+
+
+     ## Public abstract methods
+     @abstract(error_log=LogLevels.ERROR_TRACEBACK)
+     def routine_full(self, dataset: Dataset, verbose: int = 0) -> AbstractModel:
+         return self
+
+     @abstract(error_log=LogLevels.ERROR_TRACEBACK)
+     def class_load(self) -> None:
+         pass
+
+
+     @abstract(error_log=LogLevels.ERROR_TRACEBACK)
+     def class_train(self, dataset: Dataset, verbose: int = 0) -> bool:
+         return False
+
+
+     @abstract(error_log=LogLevels.ERROR_TRACEBACK)
+     def class_predict(self, X_test: Iterable[Any]) -> Iterable[Any]:
+         return []
+
+
+     @abstract(error_log=LogLevels.ERROR_TRACEBACK)
+     def class_evaluate(
+         self,
+         dataset: Dataset,
+         metrics_names: tuple[str, ...] = (),
+         save_model: bool = False,
+         verbose: int = 0
+     ) -> bool:
+         return False
+
+
+     ## Protected abstract methods
+     @abstract(error_log=LogLevels.ERROR_TRACEBACK)
+     def _fit(
+         self,
+         model: Any,
+         x: Any,
+         y: Any | None = None,
+         validation_data: tuple[Any, Any] | None = None,
+         shuffle: bool = True,
+         batch_size: int | None = None,
+         epochs: int = 1,
+         callbacks: list[Any] | None = None,
+         class_weight: dict[int, float] | None = None,
+         verbose: int = 0,
+         *args: Any,
+         **kwargs: Any
+     ) -> Any:
+         pass
+
+     @abstract(error_log=LogLevels.ERROR_TRACEBACK)
+     def _get_callbacks(self) -> list[Any]:
+         return []
+
+     @abstract(error_log=LogLevels.ERROR_TRACEBACK)
+     def _get_metrics(self) -> list[Any]:
+         return []
+
+     @abstract(error_log=LogLevels.ERROR_TRACEBACK)
+     def _get_optimizer(self, learning_rate: float = 0.0) -> Any:
+         pass
+
+     @abstract(error_log=LogLevels.ERROR_TRACEBACK)
+     def _get_loss(self) -> Any:
+         pass
+
+     @abstract(error_log=LogLevels.ERROR_TRACEBACK)
+     def _get_base_model(self) -> Any:
+         pass
+
+     @abstract(error_log=LogLevels.ERROR_TRACEBACK)
+     def _get_architectures(
+         self, optimizer: Any = None, loss: Any = None, metrics: list[Any] | None = None
+     ) -> tuple[Any, Any]:
+         return (None, None)
+
+     @abstract(error_log=LogLevels.ERROR_TRACEBACK)
+     def _find_best_learning_rate(self, dataset: Dataset, verbose: int = 0) -> float:
+         return 0.0
+
+     @abstract(error_log=LogLevels.ERROR_TRACEBACK)
+     def _train_fold(self, dataset: Dataset, fold_number: int = 0, mlflow_prefix: str = "history", verbose: int = 0) -> Any:
+         pass
+
+     @abstract(error_log=LogLevels.ERROR_TRACEBACK)
+     def _log_final_model(self) -> None:
+         pass
+
+     @abstract(error_log=LogLevels.ERROR_TRACEBACK)
+     def _find_best_learning_rate_subprocess(
+         self, dataset: Dataset, queue: multiprocessing.queues.Queue[Any] | None = None, verbose: int = 0
+     ) -> dict[str, Any] | None:
+         pass
+
+     @abstract(error_log=LogLevels.ERROR_TRACEBACK)
+     def _find_best_unfreeze_percentage_subprocess(
+         self, dataset: Dataset, queue: multiprocessing.queues.Queue[Any] | None = None, verbose: int = 0
+     ) -> dict[str, Any] | None:
+         pass
+
+     @abstract(error_log=LogLevels.ERROR_TRACEBACK)
+     def _train_subprocess(
+         self,
+         dataset: Dataset,
+         checkpoint_path: str,
+         temp_dir: TemporaryDirectory[str] | None = None,
+         queue: multiprocessing.queues.Queue[Any] | None = None,
+         verbose: int = 0
+     ) -> dict[str, Any] | None:
+         pass
+
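Illustration (not shipped in the package): a minimal, hypothetical subclass showing the contract this interface defines. Real implementations, such as the Keras base class elsewhere in this wheel, override every hook; the dummy below only overrides construction and prediction, leaving the remaining methods to the decorated no-op defaults.

from collections.abc import Iterable
from typing import Any

from stouputils.data_science.models.abstract_model import AbstractModel


class DummyModel(AbstractModel):
    """ Toy implementation: predicts class 0 for every sample. """

    def __init__(self, num_classes: int, kfold: int = 0, transfer_learning: str = "imagenet", **override_params: Any) -> None:
        self.num_classes: int = num_classes

    def class_predict(self, X_test: Iterable[Any]) -> Iterable[Any]:
        # One constant prediction per input sample
        return [0 for _ in X_test]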
stouputils/data_science/models/all.py
@@ -0,0 +1,85 @@
+
+ # Imports
+ import itertools
+
+ from .keras.all import (
+     VGG16,
+     VGG19,
+     ConvNeXtBase,
+     ConvNeXtLarge,
+     ConvNeXtSmall,
+     ConvNeXtTiny,
+     ConvNeXtXLarge,
+     DenseNet121,
+     DenseNet169,
+     DenseNet201,
+     EfficientNetB0,
+     EfficientNetV2B0,
+     EfficientNetV2L,
+     EfficientNetV2M,
+     EfficientNetV2S,
+     MobileNet,
+     MobileNetV2,
+     MobileNetV3Large,
+     MobileNetV3Small,
+     ResNet50V2,
+     ResNet101V2,
+     ResNet152V2,
+     SqueezeNet,
+     Xception,
+ )
+ from .model_interface import ModelInterface
+
+ # Other models
+ from .sandbox import Sandbox
+
+
+ # Create a custom dictionary class to allow for documentation
+ class ModelClassMap(dict[type[ModelInterface], tuple[str, ...]]):
+     pass
+
+ # Routine map
+ CLASS_MAP: ModelClassMap = ModelClassMap({
+     SqueezeNet: ("squeezenet", "squeezenets", "all", "often"),
+
+     DenseNet121: ("densenet121", "densenets", "all", "often", "good"),
+     DenseNet169: ("densenet169", "densenets", "all", "often", "good"),
+     DenseNet201: ("densenet201", "densenets", "all", "often", "good"),
+
+     EfficientNetB0: ("efficientnetb0", "efficientnets", "all"),
+     EfficientNetV2B0: ("efficientnetv2b0", "efficientnets", "all"),
+     EfficientNetV2S: ("efficientnetv2s", "efficientnets", "all", "often"),
+     EfficientNetV2M: ("efficientnetv2m", "efficientnets", "all", "often"),
+     EfficientNetV2L: ("efficientnetv2l", "efficientnets", "all", "often"),
+
+     ConvNeXtTiny: ("convnexttiny", "convnexts", "all", "often", "good"),
+     ConvNeXtSmall: ("convnextsmall", "convnexts", "all", "often"),
+     ConvNeXtBase: ("convnextbase", "convnexts", "all", "often", "good"),
+     ConvNeXtLarge: ("convnextlarge", "convnexts", "all", "often"),
+     ConvNeXtXLarge: ("convnextxlarge", "convnexts", "all", "often", "good"),
+
+     VGG16: ("vgg16", "vggs", "all"),
+     VGG19: ("vgg19", "vggs", "all"),
+
+     MobileNet: ("mobilenet", "mobilenets", "all"),
+     MobileNetV2: ("mobilenetv2", "mobilenets", "all", "often"),
+     MobileNetV3Small: ("mobilenetv3small", "mobilenets", "all", "often"),
+     MobileNetV3Large: ("mobilenetv3large", "mobilenets", "all", "often", "good"),
+
+     ResNet50V2: ("resnet50v2", "resnetsv2", "resnets", "all", "often"),
+     ResNet101V2: ("resnet101v2", "resnetsv2", "resnets", "all", "often"),
+     ResNet152V2: ("resnet152v2", "resnetsv2", "resnets", "all", "often"),
+
+     Xception: ("xception", "xceptions", "all", "often"),
+     Sandbox: ("sandbox",),
+ })
+
+ # All models names and aliases
+ ALL_MODELS: list[str] = sorted(set(itertools.chain.from_iterable(v for v in CLASS_MAP.values())))
+ """ All models names and aliases found in the `CLASS_MAP` dictionary. """
+
+ # Additional docstring
+ new_docstring: str = "\n\n" + "\n".join(f"- {k.__name__}: {v}" for k, v in CLASS_MAP.items())
+ ModelClassMap.__doc__ = "Dictionary mapping class to their names and aliases. " + new_docstring
+ CLASS_MAP.__doc__ = ModelClassMap.__doc__
+
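Quick illustration of how the alias tuples can be consumed (the resolve_models helper below is hypothetical, not part of the package): each class maps to its own name plus group aliases such as "densenets", "often" or "all", so callers can select models by individual name or by group.

from stouputils.data_science.models.all import ALL_MODELS, CLASS_MAP


def resolve_models(alias: str) -> list[type]:
    """ Return every model class registered under the given name or group alias. """
    return [cls for cls, aliases in CLASS_MAP.items() if alias in aliases]


print(ALL_MODELS)                   # sorted list of every known name and alias
print(resolve_models("densenets"))  # [DenseNet121, DenseNet169, DenseNet201]
print(resolve_models("sandbox"))    # [Sandbox]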