autogluon.core 1.2.1b20250205__tar.gz → 1.2.1b20250207__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/PKG-INFO +1 -1
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/_setup_utils.py +1 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/models/abstract/abstract_model.py +6 -6
- autogluon.core-1.2.1b20250207/src/autogluon/core/trainer/abstract_trainer.py +198 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/version.py +1 -1
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon.core.egg-info/PKG-INFO +1 -1
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon.core.egg-info/requires.txt +5 -5
- autogluon.core-1.2.1b20250205/src/autogluon/core/trainer/abstract_trainer.py +0 -4910
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/setup.cfg +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/setup.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/__init__.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/augmentation/__init__.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/augmentation/distill_utils.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/calibrate/__init__.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/calibrate/_decision_threshold.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/calibrate/conformity_score.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/calibrate/temperature_scaling.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/callbacks/__init__.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/callbacks/_abstract_callback.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/callbacks/_early_stopping_callback.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/callbacks/_early_stopping_ensemble_callback.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/callbacks/_example_callback.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/constants.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/data/__init__.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/data/cleaner.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/data/label_cleaner.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/hpo/__init__.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/hpo/constants.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/hpo/exceptions.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/hpo/executors.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/hpo/ray_hpo.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/hpo/ray_tune_constants.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/hpo/ray_tune_scheduler.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/hpo/ray_tune_scheduler_factory.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/hpo/ray_tune_searcher_factory.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/hpo/space_converter.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/learner/__init__.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/learner/abstract_learner.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/learning_curves/plot_curves.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/metrics/__init__.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/metrics/classification_metrics.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/metrics/quantile_metrics.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/metrics/score_func.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/metrics/softclass_metrics.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/models/__init__.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/models/_utils.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/models/abstract/__init__.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/models/abstract/_tags.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/models/abstract/abstract_nn_model.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/models/abstract/model_trial.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/models/dummy/__init__.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/models/dummy/_dummy_quantile_regressor.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/models/dummy/dummy_model.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/models/ensemble/__init__.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/models/ensemble/bagged_ensemble_model.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/models/ensemble/fold_fitting_strategy.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/models/ensemble/ray_parallel_fold_fitting_strategy.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/models/ensemble/stacker_ensemble_model.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/models/ensemble/weighted_ensemble_model.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/models/greedy_ensemble/__init__.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/models/greedy_ensemble/ensemble_selection.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/models/greedy_ensemble/greedy_weighted_ensemble_model.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/problem_type.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/pseudolabeling/__init__.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/pseudolabeling/pseudolabeling.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/ray/__init__.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/ray/distributed_jobs_managers.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/ray/resources_calculator.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/scheduler/__init__.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/scheduler/reporter.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/scheduler/scheduler_factory.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/scheduler/seq_scheduler.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/searcher/__init__.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/searcher/dummy_searcher.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/searcher/exceptions.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/searcher/local_grid_searcher.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/searcher/local_random_searcher.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/searcher/local_searcher.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/searcher/searcher_factory.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/stacked_overfitting/utils.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/trainer/__init__.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/trainer/utils.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/utils/__init__.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/utils/decorators.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/utils/early_stopping.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/utils/exceptions.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/utils/feature_selection.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/utils/files.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/utils/infer_utils.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/utils/loaders/__init__.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/utils/miscs.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/utils/plots.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/utils/savers/__init__.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/utils/time.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/utils/utils.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/utils/version_utils.py +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon.core.egg-info/SOURCES.txt +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon.core.egg-info/dependency_links.txt +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon.core.egg-info/namespace_packages.txt +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon.core.egg-info/top_level.txt +0 -0
- {autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon.core.egg-info/zip-safe +0 -0
{autogluon.core-1.2.1b20250205 → autogluon.core-1.2.1b20250207}/src/autogluon/core/_setup_utils.py
RENAMED
@@ -32,6 +32,7 @@ DEPENDENT_PACKAGES = {
|
|
32
32
|
"async_timeout": ">=4.0,<6", # Major version cap
|
33
33
|
"transformers[sentencepiece]": ">=4.38.0,<5",
|
34
34
|
"accelerate": ">=0.34.0,<1.0",
|
35
|
+
"typing-extensions": ">=4.0,<5",
|
35
36
|
}
|
36
37
|
if LITE_MODE:
|
37
38
|
DEPENDENT_PACKAGES = {package: version for package, version in DEPENDENT_PACKAGES.items() if package not in ["psutil", "Pillow", "timm"]}
|
@@ -96,11 +96,11 @@ class AbstractModel:
|
|
96
96
|
|
97
97
|
def __init__(
|
98
98
|
self,
|
99
|
-
path: str = None,
|
100
|
-
name: str = None,
|
101
|
-
problem_type: str = None,
|
102
|
-
eval_metric:
|
103
|
-
hyperparameters: dict = None,
|
99
|
+
path: str | None = None,
|
100
|
+
name: str | None = None,
|
101
|
+
problem_type: str | None = None,
|
102
|
+
eval_metric: str | metrics.Scorer | None = None,
|
103
|
+
hyperparameters: dict | None = None,
|
104
104
|
):
|
105
105
|
if name is None:
|
106
106
|
self.name = self.__class__.__name__
|
@@ -1176,7 +1176,7 @@ class AbstractModel:
|
|
1176
1176
|
quantile_levels=self.quantile_levels,
|
1177
1177
|
)
|
1178
1178
|
|
1179
|
-
def save(self, path: str = None, verbose: bool = True) -> str:
|
1179
|
+
def save(self, path: str | None = None, verbose: bool = True) -> str:
|
1180
1180
|
"""
|
1181
1181
|
Saves the model to disk.
|
1182
1182
|
|
@@ -0,0 +1,198 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import os
|
4
|
+
from typing import Any, Generic, Type, TypeVar
|
5
|
+
|
6
|
+
import networkx as nx
|
7
|
+
from typing_extensions import Self
|
8
|
+
|
9
|
+
from autogluon.core.models import AbstractModel
|
10
|
+
from autogluon.core.utils.loaders import load_pkl
|
11
|
+
from autogluon.core.utils.savers import save_json, save_pkl
|
12
|
+
|
13
|
+
#: Type variable for the model type a trainer manages; restricted to AbstractModel and its subclasses.
ModelTypeT = TypeVar("ModelTypeT", bound=AbstractModel)
|
14
|
+
|
15
|
+
|
16
|
+
class AbstractTrainer(Generic[ModelTypeT]):
    """Base class for trainers that manage a collection of models of type ``ModelTypeT``.

    Models are registered as nodes of ``self.model_graph`` (a ``networkx.DiGraph``), and per-model
    metadata (for example ``"path"`` and ``"type"``) is stored as node attributes. Model objects
    themselves are cached in ``self.models`` only when ``low_memory`` is False.

    Several methods raise ``NotImplementedError`` here and must be provided by subclasses:
    ``fit``, ``predict``, ``save``, ``get_info``, ``get_model_best``,
    ``construct_model_templates`` and ``get_models_attribute_dict``.
    """

    #: File name (relative to ``path``) of the pickled trainer object read by ``load``.
    trainer_file_name = "trainer.pkl"
    #: File name (relative to ``path``) of the pickled info dict written by ``save_info``.
    trainer_info_name = "info.pkl"
    #: File name (relative to ``path``) of the JSON copy of the info dict written by ``save_info``.
    trainer_info_json_name = "info.json"

    def __init__(self, path: str, *, low_memory: bool, save_data: bool):
        """
        Parameters
        ----------
        path : str
            Trainer directory. Model ``"path"`` attributes are resolved relative to it.
        low_memory : bool
            If True, ``save_model`` does not keep saved models in ``self.models``.
        save_data : bool
            Stored as ``self.save_data`` (not read by this base class).
        """
        self.path = path
        #: Set to True by ``load(..., reset_paths=True)``; forwarded to ``model_type.load`` in ``load_model``.
        self.reset_paths = False

        self.low_memory: bool = low_memory
        self.save_data: bool = save_data

        #: dict of model name -> model object. A key, value pair only exists if a model is persisted in memory.
        self.models: dict[str, Any] = {}

        #: Directed Acyclic Graph (DAG) of model interactions. Describes how certain models depend on the predictions of certain
        #: other models. Contains numerous metadata regarding each model.
        self.model_graph = nx.DiGraph()
        self.model_best: str | None = None

        #: Names which are banned but are not used by a trained model.
        self._extra_banned_names: set[str] = set()

    def _get_banned_model_names(self) -> list[str]:
        """Gets all model names which would cause model files to be overwritten if a new model
        was trained with the name.

        This is every model registered in ``model_graph`` plus ``self._extra_banned_names``.
        """
        return self.get_model_names() + list(self._extra_banned_names)

    @property
    def path_root(self) -> str:
        """Parent directory of ``self.path`` (the directory containing learner.pkl)."""
        return os.path.dirname(self.path)

    @property
    def path_utils(self) -> str:
        """The ``utils`` directory under ``path_root``."""
        return os.path.join(self.path_root, "utils")

    @property
    def path_data(self) -> str:
        """The ``data`` directory under ``path_utils``."""
        return os.path.join(self.path_utils, "data")

    def set_contexts(self, path_context: str) -> None:
        """Re-point the trainer at ``path_context``; ``load`` calls this when ``reset_paths=True``."""
        self.path = self.create_contexts(path_context)

    def create_contexts(self, path_context: str) -> str:
        """Return the trainer path for ``path_context``. The base implementation returns it unchanged."""
        return path_context

    def save_model(self, model: ModelTypeT) -> None:
        """Persist ``model`` via ``model.save()`` and, unless ``low_memory``, keep it in ``self.models``."""
        model.save()
        if not self.low_memory:
            self.models[model.name] = model

    def get_models_attribute_dict(self, attribute: str, models: list[str] | None = None) -> dict[str, Any]:
        """Return a mapping of model name -> value of ``attribute``. Must be implemented by subclasses."""
        raise NotImplementedError

    def get_model_attribute(self, model: str | ModelTypeT, attribute: str, **kwargs) -> Any:
        """Return model attribute value.

        If `default` is specified, return default value if attribute does not exist.
        If `default` is not specified, raise ValueError if attribute does not exist.

        Parameters
        ----------
        model : str | ModelTypeT
            Model name or model object (its ``name`` is used).
        attribute : str
            Node attribute in ``model_graph`` to read. For ``"path"``, the stored path
            components are joined with ``os.path.join`` before being returned.
        **kwargs
            Only ``default`` is consulted.

        Raises
        ------
        ValueError
            If the model is not in ``model_graph``, or the attribute is missing and no
            ``default`` was given.
        """
        if not isinstance(model, str):
            model = model.name
        if model not in self.model_graph.nodes:
            raise ValueError(f"Model does not exist: (model={model})")
        if attribute not in self.model_graph.nodes[model]:
            # `default` is looked up in kwargs so that an explicit default of None is still honored.
            if "default" in kwargs:
                return kwargs["default"]
            else:
                raise ValueError(f"Model does not contain attribute: (model={model}, attribute={attribute})")
        if attribute == "path":
            # The path is stored as a sequence of components.
            return os.path.join(*self.model_graph.nodes[model][attribute])
        return self.model_graph.nodes[model][attribute]

    def set_model_attribute(self, model: str | ModelTypeT, attribute: str, val: Any) -> None:
        """Set node attribute ``attribute`` of ``model`` in ``model_graph`` to ``val``.

        The model is expected to already be registered in ``model_graph``.
        """
        if not isinstance(model, str):
            model = model.name
        self.model_graph.nodes[model][attribute] = val

    def get_minimum_model_set(self, model: str | ModelTypeT, include_self: bool = True) -> list:
        """Gets the minimum set of models that the provided model depends on, including itself
        Returns a list of model names

        Set ``include_self=False`` to exclude ``model`` itself from the result.
        """
        if not isinstance(model, str):
            model = model.name
        # Reverse BFS walks predecessor edges, i.e. the models whose outputs `model` depends on.
        minimum_model_set = list(nx.bfs_tree(self.model_graph, model, reverse=True))
        if not include_self:
            minimum_model_set = [m for m in minimum_model_set if m != model]
        return minimum_model_set

    def get_model_info(self, model: str | ModelTypeT) -> dict[str, Any]:
        """Return the info dict of ``model``.

        An in-memory model (passed directly or cached in ``self.models``) supplies ``get_info()``.
        Otherwise the info is read from disk via ``model_type.load_info`` using the model's
        ``"type"`` and ``"path"`` graph attributes.
        """
        if isinstance(model, str) and model in self.models:
            model = self.models[model]
        if isinstance(model, str):
            model_type = self.get_model_attribute(model=model, attribute="type")
            model_path = self.get_model_attribute(model=model, attribute="path")
            model_info = model_type.load_info(path=os.path.join(self.path, model_path))
        else:
            model_info = model.get_info()
        return model_info

    def get_model_names(self) -> list[str]:
        """Get all model names that are registered in the model graph, in no particular order."""
        return list(self.model_graph.nodes)

    def get_models_info(self, models: list[str | ModelTypeT] | None = None) -> dict[str, dict[str, Any]]:
        """Return a mapping of model name -> info dict for ``models`` (default: all registered models)."""
        models_ = self.get_model_names() if models is None else models
        model_info_dict = dict()
        for model in models_:
            model_name = model if isinstance(model, str) else model.name
            model_info_dict[model_name] = self.get_model_info(model=model)
        return model_info_dict

    # TODO: model_name change to model in params
    def load_model(
        self, model_name: str | ModelTypeT, path: str | None = None, model_type: Type[ModelTypeT] | None = None
    ) -> ModelTypeT:
        """Return the model object for ``model_name``.

        A model object passed in is returned as-is, and a model cached in ``self.models`` is
        returned from memory. Otherwise the model is loaded from ``self.path`` joined with ``path``
        (default: the model's ``"path"`` attribute) using ``model_type`` (default: the model's
        ``"type"`` attribute).
        """
        if isinstance(model_name, AbstractModel):
            return model_name
        if model_name in self.models:
            return self.models[model_name]
        else:
            if path is None:
                path = self.get_model_attribute(
                    model=model_name, attribute="path"
                )  # get relative location of the model to the trainer
            assert path is not None
            if model_type is None:
                model_type = self.get_model_attribute(model=model_name, attribute="type")
            assert model_type is not None
            return model_type.load(path=os.path.join(self.path, path), reset_paths=self.reset_paths)

    @classmethod
    def load_info(cls, path: str, reset_paths: bool = False, load_model_if_required: bool = True) -> dict[str, Any]:
        """Load the trainer info dict stored in ``path``.

        Reads ``trainer_info_name`` from ``path``. If that fails and ``load_model_if_required`` is
        True, the full trainer is loaded and its ``get_info()`` result is returned instead.
        Otherwise the original error is re-raised.
        """
        load_path = os.path.join(path, cls.trainer_info_name)
        try:
            return load_pkl.load(path=load_path)
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are not swallowed
            # and do not trigger an expensive fallback load of the whole trainer.
            if load_model_if_required:
                trainer = cls.load(path=path, reset_paths=reset_paths)
                return trainer.get_info()
            else:
                raise

    def save_info(self, include_model_info: bool = False) -> dict[str, Any]:
        """Compute ``get_info`` and write it under ``self.path`` as both pickle and JSON; return the dict."""
        info = self.get_info(include_model_info=include_model_info)

        save_pkl.save(path=os.path.join(self.path, self.trainer_info_name), object=info)
        save_json.save(path=os.path.join(self.path, self.trainer_info_json_name), obj=info)
        return info

    def construct_model_templates(
        self, hyperparameters: dict[str, Any]
    ) -> tuple[list[ModelTypeT], dict] | list[ModelTypeT]:
        """Build model templates from ``hyperparameters``. Must be implemented by subclasses."""
        raise NotImplementedError

    def get_model_best(self) -> str:
        """Return the name of the best model. Must be implemented by subclasses."""
        raise NotImplementedError

    def get_info(self, include_model_info: bool = False) -> dict[str, Any]:
        """Return a dict describing the trainer. Must be implemented by subclasses."""
        raise NotImplementedError

    def save(self) -> None:
        """Persist the trainer. Must be implemented by subclasses."""
        raise NotImplementedError

    @classmethod
    def load(cls, path: str, reset_paths: bool = False) -> Self:
        """Unpickle the trainer from ``trainer_file_name`` inside ``path``.

        When ``reset_paths`` is True, the loaded trainer is re-pointed at ``path`` via
        ``set_contexts`` and its ``reset_paths`` flag is set.
        """
        load_path = os.path.join(path, cls.trainer_file_name)
        obj = load_pkl.load(path=load_path)
        if reset_paths:
            obj.set_contexts(path)
            obj.reset_paths = reset_paths
        return obj

    def fit(self, *args, **kwargs):
        """Train models. Must be implemented by subclasses."""
        raise NotImplementedError

    def predict(self, *args, **kwargs) -> Any:
        """Generate predictions. Must be implemented by subclasses."""
        raise NotImplementedError
|
@@ -7,12 +7,12 @@ tqdm<5,>=4.38
|
|
7
7
|
requests
|
8
8
|
matplotlib<3.11,>=3.7.0
|
9
9
|
boto3<2,>=1.10
|
10
|
-
autogluon.common==1.2.
|
10
|
+
autogluon.common==1.2.1b20250207
|
11
11
|
|
12
12
|
[all]
|
13
|
-
ray[default,tune]<2.41,>=2.10.0
|
14
13
|
pyarrow>=15.0.0
|
15
14
|
ray[default]<2.41,>=2.10.0
|
15
|
+
ray[default,tune]<2.41,>=2.10.0
|
16
16
|
hyperopt<0.2.8,>=0.2.7
|
17
17
|
|
18
18
|
[ray]
|
@@ -24,8 +24,8 @@ ray[default,tune]<2.41,>=2.10.0
|
|
24
24
|
hyperopt<0.2.8,>=0.2.7
|
25
25
|
|
26
26
|
[tests]
|
27
|
-
pytest
|
27
|
+
pytest-mypy
|
28
28
|
types-setuptools
|
29
|
-
flake8
|
30
29
|
types-requests
|
31
|
-
pytest
|
30
|
+
pytest
|
31
|
+
flake8
|