collie-mlops 0.1.1b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- collie/__init__.py +69 -0
- collie/_common/__init__.py +0 -0
- collie/_common/decorator.py +53 -0
- collie/_common/exceptions.py +104 -0
- collie/_common/mlflow_model_io/__init__.py +0 -0
- collie/_common/mlflow_model_io/base_flavor_handler.py +26 -0
- collie/_common/mlflow_model_io/flavor_registry.py +72 -0
- collie/_common/mlflow_model_io/model_flavors.py +259 -0
- collie/_common/mlflow_model_io/model_io.py +65 -0
- collie/_common/utils.py +13 -0
- collie/contracts/__init__.py +0 -0
- collie/contracts/event.py +79 -0
- collie/contracts/mlflow.py +444 -0
- collie/contracts/orchestrator.py +79 -0
- collie/core/__init__.py +41 -0
- collie/core/enums/__init__.py +0 -0
- collie/core/enums/components.py +26 -0
- collie/core/enums/ml_models.py +20 -0
- collie/core/evaluator/__init__.py +0 -0
- collie/core/evaluator/evaluator.py +147 -0
- collie/core/models.py +125 -0
- collie/core/orchestrator/__init__.py +0 -0
- collie/core/orchestrator/orchestrator.py +47 -0
- collie/core/pusher/__init__.py +0 -0
- collie/core/pusher/pusher.py +98 -0
- collie/core/trainer/__init__.py +0 -0
- collie/core/trainer/trainer.py +78 -0
- collie/core/transform/__init__.py +0 -0
- collie/core/transform/transform.py +87 -0
- collie/core/tuner/__init__.py +0 -0
- collie/core/tuner/tuner.py +84 -0
- collie/helper/__init__.py +0 -0
- collie/helper/pytorch/__init__.py +0 -0
- collie/helper/pytorch/callback/__init__.py +0 -0
- collie/helper/pytorch/callback/callback.py +155 -0
- collie/helper/pytorch/callback/earlystop.py +54 -0
- collie/helper/pytorch/callback/model_checkpoint.py +100 -0
- collie/helper/pytorch/model/__init__.py +0 -0
- collie/helper/pytorch/model/loader.py +55 -0
- collie/helper/pytorch/trainer.py +304 -0
- collie_mlops-0.1.1b0.dist-info/LICENSE +21 -0
- collie_mlops-0.1.1b0.dist-info/METADATA +259 -0
- collie_mlops-0.1.1b0.dist-info/RECORD +45 -0
- collie_mlops-0.1.1b0.dist-info/WHEEL +5 -0
- collie_mlops-0.1.1b0.dist-info/top_level.txt +1 -0
collie/__init__.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""
Collie - A Lightweight MLOps Framework for Machine Learning Workflows

Collie provides a modular, event-driven architecture for building ML pipelines
with deep MLflow integration.

Quick Start:
    >>> from collie import Transformer, Trainer, Orchestrator
    >>> # MyTransformer / MyTrainer are your own subclasses of Transformer / Trainer
    >>> orchestrator = Orchestrator(
    ...     components=[MyTransformer(), MyTrainer()],
    ...     tracking_uri="http://localhost:5000",
    ...     registered_model_name="my_model"
    ... )
    >>> orchestrator.run()

For more examples, see: https://github.com/ChingHuanChiu/collie
"""

__author__ = "ChingHuanChiu"
__email__ = "stevenchiou8@gmail.com"
# Keep in sync with the distribution version (collie_mlops-0.1.1b0).
__version__ = "0.1.1b0"
|
|
23
|
+
|
|
24
|
+
# Import all main components for easy access
|
|
25
|
+
from .contracts.event import Event, EventType, PipelineContext
|
|
26
|
+
from .core.transform.transform import Transformer
|
|
27
|
+
from .core.trainer.trainer import Trainer
|
|
28
|
+
from .core.tuner.tuner import Tuner
|
|
29
|
+
from .core.evaluator.evaluator import Evaluator
|
|
30
|
+
from .core.pusher.pusher import Pusher
|
|
31
|
+
from .core.orchestrator.orchestrator import Orchestrator
|
|
32
|
+
|
|
33
|
+
# Import data models
|
|
34
|
+
from .core.models import (
|
|
35
|
+
TransformerPayload,
|
|
36
|
+
TrainerPayload,
|
|
37
|
+
TunerPayload,
|
|
38
|
+
EvaluatorPayload,
|
|
39
|
+
PusherPayload,
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
# Import enums for configuration
|
|
43
|
+
from .core.enums.ml_models import ModelFlavor, MLflowModelStage
|
|
44
|
+
|
|
45
|
+
# Public API of the package, assembled group by group (order is preserved).
__all__ = (
    # Core components - the main classes users interact with
    ["Transformer", "Trainer", "Tuner", "Evaluator", "Pusher", "Orchestrator"]
    # Event system - for building custom workflows
    + ["Event", "EventType", "PipelineContext"]
    # Payload models - for type-safe data passing
    + [
        "TransformerPayload",
        "TrainerPayload",
        "TunerPayload",
        "EvaluatorPayload",
        "PusherPayload",
    ]
    # Configuration enums
    + ["ModelFlavor", "MLflowModelStage"]
)
|
|
File without changes
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
from typing import Tuple, List
|
|
2
|
+
from functools import wraps
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def type_checker(
    typing: Tuple[type, ...],
    error_msg: str
):
    """
    Build a decorator that validates the type of a function's return value.

    Args:
        typing (Tuple[type, ...]): Accepted types, in any form ``isinstance``
            accepts (a tuple of types, or a single type). The parameter keeps its
            original name for keyword callers.
        error_msg (str): Message of the ``TypeError`` raised on a mismatch.

    Returns:
        Callable: A decorator. The wrapped function returns its result unchanged
        when the result is an instance of ``typing``.

    Raises:
        TypeError: When the wrapped function is called and its return value is
            not an instance of ``typing``.
    """
    def closure(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)
            if not isinstance(result, typing):
                raise TypeError(error_msg)
            return result
        return wrapper
    return closure
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def dict_key_checker(keys: List[str]):
    """
    Build a decorator that checks a function returns a dict with required keys.

    Args:
        keys (List[str]): Keys that must all be present in the returned dict. The
            keys are copied when the decorator is created, so a one-shot iterable
            (e.g. a generator) is checked on every call, not just the first.

    Returns:
        Callable: A decorator. The wrapped function returns its result unchanged
        when the check passes.

    Raises:
        TypeError: If the output of the function is not a dictionary.
        KeyError: If the output of the function does not contain all the keys in
            the list. The message names the missing keys.
    """
    # Snapshot once: iterating a generator directly would exhaust it on the first
    # call, after which every later call would silently pass the check.
    required = list(keys)

    def closure(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)
            if not isinstance(result, dict):
                raise TypeError("The output must be a dictionary.")
            missing = [key for key in required if key not in result]
            if missing:
                raise KeyError(
                    f"The following keys must all exist in the output: {required}. "
                    f"Missing: {missing}. Output: {result}"
                )
            return result
        return wrapper
    return closure
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
|
|
2
|
+
class CollieBaseException(Exception):
    """Base exception for all Collie framework errors.

    The rendered message is ``"[<component>] <message>"``. When a non-empty
    ``details`` mapping is given, ``" Details: <details>"`` is appended.

    Attributes:
        message: The message exactly as passed by the caller.
        component: Label shown in brackets. Defaults to the class name with every
            ``"Error"`` substring removed.
        details: Extra context; an empty dict when none (or a falsy value) is given.
    """

    def __init__(self, message: str, component: str = None, details: dict = None):
        default_component = type(self).__name__.replace('Error', '')

        self.message = message
        self.component = component if component else default_component
        self.details = details if details else {}

        parts = [f"[{self.component}] {message}"]
        if self.details:
            parts.append(f"Details: {self.details}")
        super().__init__(" ".join(parts))
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class MLflowConfigurationError(CollieBaseException):
    """Raised when MLflow configuration is invalid.

    Args:
        message: Description of the problem.
        config_param: Name of the offending configuration parameter. When given,
            it is recorded as ``details['config_parameter']``.
        **kwargs: Only ``details`` (a mapping of extra context) is read; any other
            keyword argument is ignored.
    """

    def __init__(self, message: str, config_param: str = None, **kwargs):
        # Copy so the caller's dict is never mutated; tolerate details=None.
        details = dict(kwargs.get('details') or {})
        if config_param:
            details['config_parameter'] = config_param
        super().__init__(message, component="MLflow Config", details=details)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class MLflowOperationError(CollieBaseException):
    """Raised when MLflow operations fail.

    Args:
        message: Description of the failure.
        operation: Name of the failed operation. When given, it is recorded as
            ``details['operation']``.
        **kwargs: Only ``details`` (a mapping of extra context) is read; any other
            keyword argument is ignored.
    """

    def __init__(self, message: str, operation: str = None, **kwargs):
        # Copy so the caller's dict is never mutated; tolerate details=None.
        details = dict(kwargs.get('details') or {})
        if operation:
            details['operation'] = operation
        super().__init__(message, component="MLflow Operation", details=details)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class OrchestratorError(CollieBaseException):
    """Raised for errors in the orchestrator process.

    Args:
        message: Description of the failure.
        pipeline_stage: Pipeline stage in which the error occurred. When given,
            it is recorded as ``details['pipeline_stage']``.
        **kwargs: Only ``details`` (a mapping of extra context) is read; any other
            keyword argument is ignored.
    """

    def __init__(self, message: str, pipeline_stage: str = None, **kwargs):
        # Copy so the caller's dict is never mutated; tolerate details=None.
        details = dict(kwargs.get('details') or {})
        if pipeline_stage:
            details['pipeline_stage'] = pipeline_stage
        super().__init__(message, component="Orchestrator", details=details)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class TransformerError(CollieBaseException):
    """Raised when data transformation fails.

    Args:
        message: Description of the failure.
        data_type: Type of the data being transformed. When given, it is recorded
            as ``details['data_type']``.
        **kwargs: Only ``details`` (a mapping of extra context) is read; any other
            keyword argument is ignored.
    """

    def __init__(self, message: str, data_type: str = None, **kwargs):
        # Copy so the caller's dict is never mutated; tolerate details=None.
        details = dict(kwargs.get('details') or {})
        if data_type:
            details['data_type'] = data_type
        super().__init__(message, component="Transformer", details=details)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class TrainerError(CollieBaseException):
    """Raised when model training fails.

    Args:
        message: Description of the failure.
        model_type: Type of the model being trained. When given, it is recorded
            as ``details['model_type']``.
        **kwargs: Only ``details`` (a mapping of extra context) is read; any other
            keyword argument is ignored.
    """

    def __init__(self, message: str, model_type: str = None, **kwargs):
        # Copy so the caller's dict is never mutated; tolerate details=None.
        details = dict(kwargs.get('details') or {})
        if model_type:
            details['model_type'] = model_type
        super().__init__(message, component="Trainer", details=details)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class TunerError(CollieBaseException):
    """Raised when hyperparameter tuning fails.

    Args:
        message: Description of the failure.
        tuning_method: Tuning method in use. When given, it is recorded as
            ``details['tuning_method']``.
        **kwargs: Only ``details`` (a mapping of extra context) is read; any other
            keyword argument is ignored.
    """

    def __init__(self, message: str, tuning_method: str = None, **kwargs):
        # Copy so the caller's dict is never mutated; tolerate details=None.
        details = dict(kwargs.get('details') or {})
        if tuning_method:
            details['tuning_method'] = tuning_method
        super().__init__(message, component="Tuner", details=details)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class EvaluatorError(CollieBaseException):
    """Raised when model evaluation fails.

    Args:
        message: Description of the failure.
        metric: Metric involved in the failure. When given, it is recorded as
            ``details['metric']``.
        **kwargs: Only ``details`` (a mapping of extra context) is read; any other
            keyword argument is ignored.
    """

    def __init__(self, message: str, metric: str = None, **kwargs):
        # Copy so the caller's dict is never mutated; tolerate details=None.
        details = dict(kwargs.get('details') or {})
        if metric:
            details['metric'] = metric
        super().__init__(message, component="Evaluator", details=details)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class PusherError(CollieBaseException):
    """Raised when model pushing/deployment fails.

    Args:
        message: Description of the failure.
        deployment_target: Target of the push/deployment. When given, it is
            recorded as ``details['deployment_target']``.
        **kwargs: Only ``details`` (a mapping of extra context) is read; any other
            keyword argument is ignored.
    """

    def __init__(self, message: str, deployment_target: str = None, **kwargs):
        # Copy so the caller's dict is never mutated; tolerate details=None.
        details = dict(kwargs.get('details') or {})
        if deployment_target:
            details['deployment_target'] = deployment_target
        super().__init__(message, component="Pusher", details=details)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class ModelFlavorError(CollieBaseException):
    """Raised when model flavor operations fail.

    Args:
        message: Description of the failure.
        flavor: Model flavor involved. When given, it is recorded as
            ``details['flavor']``.
        **kwargs: Only ``details`` (a mapping of extra context) is read; any other
            keyword argument is ignored.
    """

    def __init__(self, message: str, flavor: str = None, **kwargs):
        # Copy so the caller's dict is never mutated; tolerate details=None.
        details = dict(kwargs.get('details') or {})
        if flavor:
            details['flavor'] = flavor
        super().__init__(message, component="Model Flavor", details=details)
|
|
File without changes
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class FlavorHandler(ABC):
    """Abstract interface for one MLflow model flavor.

    A concrete handler reports whether it supports a model object, names its
    flavor, and logs/loads models of that flavor. Every method is abstract; the
    base implementations only raise ``NotImplementedError``.
    """

    @abstractmethod
    def can_handle(self, model: Any) -> bool:
        """Return True when this handler supports ``model``."""
        raise NotImplementedError()

    @abstractmethod
    def flavor(self):
        """Return the flavor identifier served by this handler."""
        raise NotImplementedError()

    @abstractmethod
    def log_model(self, model: Any, name: str, **kwargs: Any) -> None:
        """Log ``model`` under the artifact name ``name``."""
        raise NotImplementedError()

    @abstractmethod
    def load_model(self, model_uri: str) -> Any:
        """Load and return the model stored at ``model_uri``."""
        raise NotImplementedError()
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
from typing import List, Optional
|
|
2
|
+
|
|
3
|
+
from collie._common.mlflow_model_io.base_flavor_handler import FlavorHandler
|
|
4
|
+
from collie._common.mlflow_model_io.model_flavors import (
|
|
5
|
+
SklearnFlavorHandler,
|
|
6
|
+
XGBoostFlavorHandler,
|
|
7
|
+
PyTorchFlavorHandler,
|
|
8
|
+
LightGBMFlavorHandler,
|
|
9
|
+
TransformersFlavorHandler,
|
|
10
|
+
SKLEARN_AVAILABLE,
|
|
11
|
+
XGBOOST_AVAILABLE,
|
|
12
|
+
PYTORCH_AVAILABLE,
|
|
13
|
+
LIGHTGBM_AVAILABLE,
|
|
14
|
+
TRANSFORMERS_AVAILABLE
|
|
15
|
+
)
|
|
16
|
+
from collie._common.exceptions import ModelFlavorError
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class FlavorRegistry:
    """Registry for model flavor handlers with conditional loading.

    A handler is registered only when its ``*_AVAILABLE`` flag is true, i.e.
    its framework (and matching ``mlflow.<flavor>`` module) imported.

    Raises:
        ModelFlavorError: On construction, if no supported framework is available.
    """

    def __init__(self):
        self._handlers: List[FlavorHandler] = []

        # find_handler_by_model() returns the FIRST handler whose can_handle()
        # accepts the model, so specific handlers must precede generic ones.
        # With scikit-learn installed, xgboost.XGBModel and lightgbm.LGBMModel
        # subclass sklearn's BaseEstimator, and PyTorch-based
        # transformers.PreTrainedModel subclasses torch.nn.Module; registering
        # the sklearn/pytorch handlers first would route those models to the
        # wrong flavor.
        candidates = (
            (TRANSFORMERS_AVAILABLE, TransformersFlavorHandler),
            (XGBOOST_AVAILABLE, XGBoostFlavorHandler),
            (LIGHTGBM_AVAILABLE, LightGBMFlavorHandler),
            (PYTORCH_AVAILABLE, PyTorchFlavorHandler),
            (SKLEARN_AVAILABLE, SklearnFlavorHandler),
        )
        for available, handler_cls in candidates:
            if available:
                self._handlers.append(handler_cls())

        if not self._handlers:
            raise ModelFlavorError(
                "No model flavor handlers available. Please install at least one supported ML framework."
            )

    def find_handler_by_model(self, model) -> Optional[FlavorHandler]:
        """Return the first registered handler whose ``can_handle(model)`` is true, else None."""
        for handler in self._handlers:
            if handler.can_handle(model):
                return handler
        return None

    def find_handler_by_flavor(self, flavor: str) -> Optional[FlavorHandler]:
        """Return the handler registered for ``flavor``, or None.

        ``flavor`` may be the object returned by ``handler.flavor()``, its
        ``.value`` (the form listed by ``get_available_flavors()``), or its
        ``str()`` form (e.g. a flavor string read back from an MLflow param).
        """
        for handler in self._handlers:
            handler_flavor = handler.flavor()
            aliases = (
                handler_flavor,
                getattr(handler_flavor, "value", handler_flavor),
                str(handler_flavor),
            )
            if flavor in aliases:
                return handler
        return None

    def get_available_flavors(self) -> List[str]:
        """Return the ``.value`` of each registered handler's flavor, in registration order."""
        return [handler.flavor().value for handler in self._handlers]

    def get_handler_info(self) -> dict:
        """Summarize the registered handlers and the import status of each framework."""
        return {
            "total_handlers": len(self._handlers),
            "available_flavors": self.get_available_flavors(),
            "framework_status": {
                "sklearn": SKLEARN_AVAILABLE,
                "xgboost": XGBOOST_AVAILABLE,
                "pytorch": PYTORCH_AVAILABLE,
                "lightgbm": LIGHTGBM_AVAILABLE,
                "transformers": TRANSFORMERS_AVAILABLE
            }
        }
|
|
@@ -0,0 +1,259 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
import warnings
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
# Optional framework imports. Each framework is imported together with its
# ``mlflow.<flavor>`` module. On failure:
#   - the matching ``*_AVAILABLE`` flag is False,
#   - the framework symbol used by the handlers below is bound to None, so the
#     name always exists at module level,
#   - and a warning is emitted.
try:
    import mlflow.sklearn
    import sklearn.base
    SKLEARN_AVAILABLE = True
except ImportError:
    SKLEARN_AVAILABLE = False
    sklearn = None
    warnings.warn("scikit-learn not available. SklearnFlavorHandler will be disabled.")

try:
    import mlflow.xgboost
    import xgboost as xgb
    XGBOOST_AVAILABLE = True
except ImportError:
    XGBOOST_AVAILABLE = False
    xgb = None
    warnings.warn("XGBoost not available. XGBoostFlavorHandler will be disabled.")

try:
    import mlflow.pytorch
    import torch.nn as nn
    PYTORCH_AVAILABLE = True
except ImportError:
    PYTORCH_AVAILABLE = False
    nn = None
    warnings.warn("PyTorch not available. PyTorchFlavorHandler will be disabled.")

try:
    import mlflow.lightgbm
    import lightgbm as lgb
    LIGHTGBM_AVAILABLE = True
except Exception:
    # Any exception (not only ImportError) disables LightGBM, e.g. an error
    # raised while its native library is being loaded.
    LIGHTGBM_AVAILABLE = False
    lgb = None
    warnings.warn("LightGBM not available. LightGBMFlavorHandler will be disabled.")

try:
    import mlflow.transformers
    from transformers import PreTrainedModel
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    PreTrainedModel = None
    warnings.warn("Transformers not available. TransformersFlavorHandler will be disabled.")
|
|
46
|
+
|
|
47
|
+
from collie._common.mlflow_model_io.base_flavor_handler import FlavorHandler
|
|
48
|
+
from collie.core.enums.ml_models import ModelFlavor
|
|
49
|
+
from collie._common.exceptions import ModelFlavorError
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class SklearnFlavorHandler(FlavorHandler):
    """Log and load scikit-learn estimators through ``mlflow.sklearn``."""

    def can_handle(self, model: Any) -> bool:
        """True when scikit-learn is importable and ``model`` is a ``BaseEstimator``."""
        return SKLEARN_AVAILABLE and isinstance(model, sklearn.base.BaseEstimator)

    def flavor(self) -> ModelFlavor:
        """Return ``ModelFlavor.SKLEARN``."""
        return ModelFlavor.SKLEARN

    @staticmethod
    def _ensure_available(action: str) -> None:
        """Raise ModelFlavorError if scikit-learn is missing; ``action`` is "log" or "load"."""
        if SKLEARN_AVAILABLE:
            return
        raise ModelFlavorError(
            f"scikit-learn is not available. Please install it to {action} sklearn models.",
            flavor="sklearn"
        )

    def log_model(self, model: Any, name: str, **kwargs: Any) -> None:
        """Log ``model`` via ``mlflow.sklearn.log_model`` with ``artifact_path=name``.

        Raises:
            ModelFlavorError: If scikit-learn is unavailable or MLflow logging fails.
        """
        self._ensure_available("log")
        try:
            mlflow.sklearn.log_model(sk_model=model, artifact_path=name, **kwargs)
        except Exception as exc:
            context = {"model_type": type(model).__name__, "artifact_name": name}
            raise ModelFlavorError(
                f"Failed to log sklearn model: {str(exc)}",
                flavor="sklearn",
                details=context
            ) from exc

    def load_model(self, model_uri: str) -> Any:
        """Load a model via ``mlflow.sklearn.load_model(model_uri)``.

        Raises:
            ModelFlavorError: If scikit-learn is unavailable or MLflow loading fails.
        """
        self._ensure_available("load")
        try:
            return mlflow.sklearn.load_model(model_uri)
        except Exception as exc:
            raise ModelFlavorError(
                f"Failed to load sklearn model: {str(exc)}",
                flavor="sklearn",
                details={"model_uri": model_uri}
            ) from exc
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class XGBoostFlavorHandler(FlavorHandler):
    """Log and load XGBoost ``Booster`` / ``XGBModel`` objects through ``mlflow.xgboost``."""

    def can_handle(self, model: Any) -> bool:
        """True when XGBoost is importable and ``model`` is a Booster or XGBModel."""
        return XGBOOST_AVAILABLE and isinstance(model, (xgb.Booster, xgb.XGBModel))

    def flavor(self) -> ModelFlavor:
        """Return ``ModelFlavor.XGBOOST``."""
        return ModelFlavor.XGBOOST

    @staticmethod
    def _require_xgboost(action: str) -> None:
        """Raise ModelFlavorError if XGBoost is missing; ``action`` is "log" or "load"."""
        if not XGBOOST_AVAILABLE:
            raise ModelFlavorError(
                f"XGBoost is not available. Please install it to {action} XGBoost models.",
                flavor="xgboost"
            )

    def log_model(self, model: Any, name: str, **kwargs: Any) -> None:
        """Log ``model`` via ``mlflow.xgboost.log_model`` with ``artifact_path=name``.

        Raises:
            ModelFlavorError: If XGBoost is unavailable or MLflow logging fails.
        """
        self._require_xgboost("log")
        try:
            mlflow.xgboost.log_model(xgb_model=model, artifact_path=name, **kwargs)
        except Exception as err:
            raise ModelFlavorError(
                f"Failed to log XGBoost model: {str(err)}",
                flavor="xgboost",
                details={"model_type": type(model).__name__, "artifact_name": name}
            ) from err

    def load_model(self, model_uri: str) -> Any:
        """Load a model via ``mlflow.xgboost.load_model(model_uri)``.

        Raises:
            ModelFlavorError: If XGBoost is unavailable or MLflow loading fails.
        """
        self._require_xgboost("load")
        try:
            return mlflow.xgboost.load_model(model_uri)
        except Exception as err:
            raise ModelFlavorError(
                f"Failed to load XGBoost model: {str(err)}",
                flavor="xgboost",
                details={"model_uri": model_uri}
            ) from err
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class PyTorchFlavorHandler(FlavorHandler):
    """Log and load ``torch.nn.Module`` models through ``mlflow.pytorch``."""

    def can_handle(self, model: Any) -> bool:
        """True when PyTorch is importable and ``model`` is an ``nn.Module``."""
        return PYTORCH_AVAILABLE and isinstance(model, nn.Module)

    def flavor(self) -> ModelFlavor:
        """Return ``ModelFlavor.PYTORCH``."""
        return ModelFlavor.PYTORCH

    @staticmethod
    def _require_torch(action: str) -> None:
        """Raise ModelFlavorError if PyTorch is missing; ``action`` is "log" or "load"."""
        if not PYTORCH_AVAILABLE:
            raise ModelFlavorError(
                f"PyTorch is not available. Please install it to {action} PyTorch models.",
                flavor="pytorch"
            )

    def log_model(self, model: Any, name: str, **kwargs: Any) -> None:
        """Log ``model`` via ``mlflow.pytorch.log_model`` with ``artifact_path=name``.

        Raises:
            ModelFlavorError: If PyTorch is unavailable or MLflow logging fails.
        """
        self._require_torch("log")
        try:
            mlflow.pytorch.log_model(pytorch_model=model, artifact_path=name, **kwargs)
        except Exception as failure:
            raise ModelFlavorError(
                f"Failed to log PyTorch model: {str(failure)}",
                flavor="pytorch",
                details={"model_type": type(model).__name__, "artifact_name": name}
            ) from failure

    def load_model(self, model_uri: str) -> Any:
        """Load a model via ``mlflow.pytorch.load_model(model_uri)``.

        Raises:
            ModelFlavorError: If PyTorch is unavailable or MLflow loading fails.
        """
        self._require_torch("load")
        try:
            return mlflow.pytorch.load_model(model_uri)
        except Exception as failure:
            raise ModelFlavorError(
                f"Failed to load PyTorch model: {str(failure)}",
                flavor="pytorch",
                details={"model_uri": model_uri}
            ) from failure
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
class LightGBMFlavorHandler(FlavorHandler):
    """Log and load LightGBM ``Booster`` / ``LGBMModel`` objects through ``mlflow.lightgbm``."""

    def can_handle(self, model: Any) -> bool:
        """True when LightGBM imported (``lgb`` is bound) and ``model`` is a Booster or LGBMModel."""
        return (
            LIGHTGBM_AVAILABLE
            and lgb is not None
            and isinstance(model, (lgb.Booster, lgb.LGBMModel))
        )

    def flavor(self) -> ModelFlavor:
        """Return ``ModelFlavor.LIGHTGBM``."""
        return ModelFlavor.LIGHTGBM

    @staticmethod
    def _require_lightgbm(action: str) -> None:
        """Raise ModelFlavorError if LightGBM is missing; ``action`` is "log" or "load"."""
        if not LIGHTGBM_AVAILABLE:
            raise ModelFlavorError(
                f"LightGBM is not available. Please install it to {action} LightGBM models.",
                flavor="lightgbm"
            )

    def log_model(self, model: Any, name: str, **kwargs: Any) -> None:
        """Log ``model`` via ``mlflow.lightgbm.log_model`` with ``artifact_path=name``.

        Raises:
            ModelFlavorError: If LightGBM is unavailable or MLflow logging fails.
        """
        self._require_lightgbm("log")
        try:
            mlflow.lightgbm.log_model(lgb_model=model, artifact_path=name, **kwargs)
        except Exception as cause:
            raise ModelFlavorError(
                f"Failed to log LightGBM model: {str(cause)}",
                flavor="lightgbm",
                details={"model_type": type(model).__name__, "artifact_name": name}
            ) from cause

    def load_model(self, model_uri: str) -> Any:
        """Load a model via ``mlflow.lightgbm.load_model(model_uri)``.

        Raises:
            ModelFlavorError: If LightGBM is unavailable or MLflow loading fails.
        """
        self._require_lightgbm("load")
        try:
            return mlflow.lightgbm.load_model(model_uri)
        except Exception as cause:
            raise ModelFlavorError(
                f"Failed to load LightGBM model: {str(cause)}",
                flavor="lightgbm",
                details={"model_uri": model_uri}
            ) from cause
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
class TransformersFlavorHandler(FlavorHandler):
    """Handler for Hugging Face Transformers models and pipelines."""

    def can_handle(self, model: Any) -> bool:
        """Return True for ``PreTrainedModel`` instances and ``transformers.Pipeline`` objects.

        Always False when Transformers (or ``mlflow.transformers``) could not be imported.
        """
        if not TRANSFORMERS_AVAILABLE:
            return False
        if isinstance(model, PreTrainedModel):
            return True
        # A Pipeline is the primary input documented for
        # mlflow.transformers.log_model. Without this check, pipelines matched
        # no handler at all. Imported lazily because this module only imports
        # PreTrainedModel at load time.
        from transformers import Pipeline
        return isinstance(model, Pipeline)

    def flavor(self) -> ModelFlavor:
        """Return ``ModelFlavor.TRANSFORMERS``."""
        return ModelFlavor.TRANSFORMERS

    def log_model(self, model: Any, name: str, **kwargs: Any) -> None:
        """Log ``model`` via ``mlflow.transformers.log_model`` with ``artifact_path=name``.

        Raises:
            ModelFlavorError: If Transformers is unavailable or MLflow logging fails.
        """
        if not TRANSFORMERS_AVAILABLE:
            raise ModelFlavorError(
                "Transformers is not available. Please install it to log Transformers models.",
                flavor="transformers"
            )
        try:
            # NOTE(review): MLflow documents Pipeline / component-dict / checkpoint
            # path inputs here; confirm a bare PreTrainedModel is accepted by the
            # MLflow version in use.
            mlflow.transformers.log_model(transformers_model=model, artifact_path=name, **kwargs)
        except Exception as e:
            raise ModelFlavorError(
                f"Failed to log Transformers model: {str(e)}",
                flavor="transformers",
                details={"model_type": type(model).__name__, "artifact_name": name}
            ) from e

    def load_model(self, model_uri: str) -> Any:
        """Load a model via ``mlflow.transformers.load_model(model_uri)``.

        Raises:
            ModelFlavorError: If Transformers is unavailable or MLflow loading fails.
        """
        if not TRANSFORMERS_AVAILABLE:
            raise ModelFlavorError(
                "Transformers is not available. Please install it to load Transformers models.",
                flavor="transformers"
            )
        try:
            return mlflow.transformers.load_model(model_uri)
        except Exception as e:
            raise ModelFlavorError(
                f"Failed to load Transformers model: {str(e)}",
                flavor="transformers",
                details={"model_uri": model_uri}
            ) from e
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
from typing import Any, Optional
|
|
2
|
+
|
|
3
|
+
import mlflow
|
|
4
|
+
from mlflow.tracking import MlflowClient
|
|
5
|
+
from collie._common.mlflow_model_io.flavor_registry import FlavorRegistry
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class MLflowModelIO:
    """Log and load models with MLflow, delegating to flavor-specific handlers.

    Handlers come from a ``FlavorRegistry``. Logging picks a handler by the
    model object's type; loading picks one by flavor.
    """

    def __init__(
        self,
        mlflow_client: MlflowClient
    ) -> None:
        """
        Initializes an MLflowModelIO instance.

        Args:
            mlflow_client (MlflowClient): MLflow client, stored as ``self.client``.
                The methods of this class do not use it themselves.

        Raises:
            ModelFlavorError: Propagated from ``FlavorRegistry()`` when no
                supported ML framework is installed.
        """
        self.registry = FlavorRegistry()
        self.client = mlflow_client

    def log_model(
        self,
        model: Any,
        name: str,
        registered_model_name: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """
        Logs a model with MLflow and records its flavor as the ``model_flavor`` param.

        Args:
            model (Any): The model to log.
            name (str): The name to give the logged model.
            registered_model_name (Optional[str], optional): The name to give the registered model. Defaults to None.
            **kwargs (Any): Additional keyword arguments to pass to the flavor handler's log_model method.

        Raises:
            ValueError: If the model type is not supported by any flavor handler.
        """
        handler = self.registry.find_handler_by_model(model)
        if handler is None:
            raise ValueError(f"Unsupported model type: {type(model)}")

        handler.log_model(
            model=model,
            name=name,
            registered_model_name=registered_model_name,
            **kwargs
        )
        flavor = handler.flavor()
        # Log the plain flavor value (the form FlavorRegistry.get_available_flavors()
        # reports) rather than the enum member itself, whose str() form can be
        # "ClassName.MEMBER" and then fails to match a flavor on reload.
        mlflow.log_param("model_flavor", getattr(flavor, "value", flavor))

    def load_model(
        self,
        flavor: str,
        model_uri: str,
    ) -> Any:
        """
        Load a model using the handler registered for ``flavor``.

        Args:
            flavor (str): Flavor identifier, matched by ``FlavorRegistry.find_handler_by_flavor``.
            model_uri (str): MLflow model URI, passed to the handler unchanged.

        Returns:
            Any: Whatever the handler's ``load_model`` returns.

        Raises:
            ValueError: If no registered handler matches ``flavor``.
        """
        handler = self.registry.find_handler_by_flavor(flavor)
        if handler is None:
            raise ValueError(f"Unsupported model flavor: {flavor}")

        return handler.load_model(model_uri)
|
collie/_common/utils.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def get_logger(name: str = __name__) -> logging.Logger:
    """
    Return a named logger after applying a default INFO-level root configuration.

    ``logging.basicConfig(level=logging.INFO)`` runs on every call. It only takes
    effect when the root logger has no handlers yet, so an application's own
    logging configuration (handlers and level) takes precedence.

    Args:
        name (str): Logger name. Defaults to this module's name, which is what
            every caller received before this parameter existed. Pass the calling
            module's ``__name__`` to get a logger named after the caller.

    Returns:
        logging.Logger: The logger registered under ``name``.
    """
    logging.basicConfig(level=logging.INFO)
    return logging.getLogger(name)
|
|
File without changes
|