collie-mlops 0.1.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of collie-mlops might be problematic. Click here for more details.

Files changed (45) hide show
  1. collie/__init__.py +69 -0
  2. collie/_common/__init__.py +0 -0
  3. collie/_common/decorator.py +53 -0
  4. collie/_common/exceptions.py +104 -0
  5. collie/_common/mlflow_model_io/__init__.py +0 -0
  6. collie/_common/mlflow_model_io/base_flavor_handler.py +26 -0
  7. collie/_common/mlflow_model_io/flavor_registry.py +72 -0
  8. collie/_common/mlflow_model_io/model_flavors.py +259 -0
  9. collie/_common/mlflow_model_io/model_io.py +65 -0
  10. collie/_common/utils.py +13 -0
  11. collie/contracts/__init__.py +0 -0
  12. collie/contracts/event.py +79 -0
  13. collie/contracts/mlflow.py +444 -0
  14. collie/contracts/orchestrator.py +79 -0
  15. collie/core/__init__.py +41 -0
  16. collie/core/enums/__init__.py +0 -0
  17. collie/core/enums/components.py +26 -0
  18. collie/core/enums/ml_models.py +20 -0
  19. collie/core/evaluator/__init__.py +0 -0
  20. collie/core/evaluator/evaluator.py +147 -0
  21. collie/core/models.py +125 -0
  22. collie/core/orchestrator/__init__.py +0 -0
  23. collie/core/orchestrator/orchestrator.py +47 -0
  24. collie/core/pusher/__init__.py +0 -0
  25. collie/core/pusher/pusher.py +98 -0
  26. collie/core/trainer/__init__.py +0 -0
  27. collie/core/trainer/trainer.py +78 -0
  28. collie/core/transform/__init__.py +0 -0
  29. collie/core/transform/transform.py +87 -0
  30. collie/core/tuner/__init__.py +0 -0
  31. collie/core/tuner/tuner.py +84 -0
  32. collie/helper/__init__.py +0 -0
  33. collie/helper/pytorch/__init__.py +0 -0
  34. collie/helper/pytorch/callback/__init__.py +0 -0
  35. collie/helper/pytorch/callback/callback.py +155 -0
  36. collie/helper/pytorch/callback/earlystop.py +54 -0
  37. collie/helper/pytorch/callback/model_checkpoint.py +100 -0
  38. collie/helper/pytorch/model/__init__.py +0 -0
  39. collie/helper/pytorch/model/loader.py +55 -0
  40. collie/helper/pytorch/trainer.py +304 -0
  41. collie_mlops-0.1.0b0.dist-info/METADATA +217 -0
  42. collie_mlops-0.1.0b0.dist-info/RECORD +45 -0
  43. collie_mlops-0.1.0b0.dist-info/WHEEL +5 -0
  44. collie_mlops-0.1.0b0.dist-info/licenses/LICENSE +21 -0
  45. collie_mlops-0.1.0b0.dist-info/top_level.txt +1 -0
collie/__init__.py ADDED
@@ -0,0 +1,69 @@
1
+ """
2
+ Collie - A Lightweight MLOps Framework for Machine Learning Workflows
3
+
4
+ Collie provides a modular, event-driven architecture for building ML pipelines
5
+ with deep MLflow integration.
6
+
7
+ Quick Start:
8
+ >>> from collie import Transformer, Trainer, Orchestrator
9
+ >>> # Define your components
10
+ >>> orchestrator = Orchestrator(
11
+ ... components=[MyTransformer(), MyTrainer()],
12
+ ... tracking_uri="http://localhost:5000",
13
+ ... registered_model_name="my_model"
14
+ ... )
15
+ >>> orchestrator.run()
16
+
17
+ For more examples, see: https://github.com/ChingHuanChiu/collie
18
+ """
19
+
20
+ __author__ = "ChingHuanChiu"
21
+ __email__ = "stevenchiou8@gmail.com"
22
+ __version__ = "0.1.0b0"
23
+
24
+ # Import all main components for easy access
25
+ from .contracts.event import Event, EventType, PipelineContext
26
+ from .core.transform.transform import Transformer
27
+ from .core.trainer.trainer import Trainer
28
+ from .core.tuner.tuner import Tuner
29
+ from .core.evaluator.evaluator import Evaluator
30
+ from .core.pusher.pusher import Pusher
31
+ from .core.orchestrator.orchestrator import Orchestrator
32
+
33
+ # Import data models
34
+ from .core.models import (
35
+ TransformerPayload,
36
+ TrainerPayload,
37
+ TunerPayload,
38
+ EvaluatorPayload,
39
+ PusherPayload,
40
+ )
41
+
42
+ # Import enums for configuration
43
+ from .core.enums.ml_models import ModelFlavor, MLflowModelStage
44
+
45
+ __all__ = [
46
+ # Core components - the main classes users interact with
47
+ "Transformer",
48
+ "Trainer",
49
+ "Tuner",
50
+ "Evaluator",
51
+ "Pusher",
52
+ "Orchestrator",
53
+
54
+ # Event system - for building custom workflows
55
+ "Event",
56
+ "EventType",
57
+ "PipelineContext",
58
+
59
+ # Payload models - for type-safe data passing
60
+ "TransformerPayload",
61
+ "TrainerPayload",
62
+ "TunerPayload",
63
+ "EvaluatorPayload",
64
+ "PusherPayload",
65
+
66
+ # Configuration enums
67
+ "ModelFlavor",
68
+ "MLflowModelStage",
69
+ ]
File without changes
@@ -0,0 +1,53 @@
1
+ from typing import Tuple, List
2
+ from functools import wraps
3
+
4
+
5
def type_checker(
    typing: Tuple[type, ...],
    error_msg: str
):
    """
    A decorator factory that checks the type of the output of a function.

    Args:
        typing (Tuple[type, ...]): Accepted return types. (Annotation fixed:
            ``Tuple[type]`` would mean a tuple of exactly one type.)
        error_msg (str): The error message to be raised if the type does not match.

    Returns:
        A decorator wrapping the target function with a return-type check.

    Raises:
        TypeError: If the return value of the decorated function is not an
            instance of any of the given types.
    """

    def closure(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)
            if not isinstance(result, typing):
                raise TypeError(error_msg)
            return result
        return wrapper
    return closure
29
+
30
+
31
def dict_key_checker(keys: List[str]):
    """
    A decorator factory ensuring a function returns a dict containing *keys*.

    Args:
        keys (List[str]): Keys that must all be present in the returned dict.

    Raises:
        TypeError: If the decorated function does not return a dictionary.
        KeyError: If any required key is missing from the returned dict.
    """
    def closure(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            output = func(*args, **kwargs)
            if not isinstance(output, dict):
                raise TypeError("The output must be a dictionary.")
            # A single missing key is enough to reject the output.
            if any(key not in output for key in keys):
                raise KeyError(f"The following keys must all exist in the output: {keys}. Output: {output}")
            return output
        return wrapper
    return closure
@@ -0,0 +1,104 @@
1
+
2
class CollieBaseException(Exception):
    """Base exception for all Collie framework errors.

    Stores the raw ``message``, the originating ``component`` and a
    ``details`` dict, and renders them into the exception text as
    ``[component] message Details: {...}``.
    """

    def __init__(self, message: str, component: str = None, details: dict = None):
        self.message = message
        # Default component: class name with any 'Error' suffix stripped.
        self.component = component or self.__class__.__name__.replace('Error', '')
        self.details = details or {}

        parts = [f"[{self.component}] {message}"]
        if self.details:
            parts.append(f"Details: {self.details}")
        super().__init__(" ".join(parts))
15
+
16
+
17
class MLflowConfigurationError(CollieBaseException):
    """Raised when MLflow configuration is invalid.

    Args:
        message: Human-readable description of the problem.
        config_param: Name of the offending configuration parameter, if known.
        **kwargs: May include ``details`` (dict) with extra context.
    """

    def __init__(self, message: str, config_param: str = None, **kwargs):
        # Copy so we never mutate a details dict owned by the caller.
        details = dict(kwargs.get('details', {}))
        if config_param:
            details['config_parameter'] = config_param
        super().__init__(message, component="MLflow Config", details=details)
25
+
26
+
27
class MLflowOperationError(CollieBaseException):
    """Raised when MLflow operations fail.

    Args:
        message: Human-readable description of the failure.
        operation: Name of the MLflow operation that failed, if known.
        **kwargs: May include ``details`` (dict) with extra context.
    """

    def __init__(self, message: str, operation: str = None, **kwargs):
        # Copy so we never mutate a details dict owned by the caller.
        details = dict(kwargs.get('details', {}))
        if operation:
            details['operation'] = operation
        super().__init__(message, component="MLflow Operation", details=details)
35
+
36
+
37
class OrchestratorError(CollieBaseException):
    """Raised for errors in the orchestrator process.

    Args:
        message: Human-readable description of the failure.
        pipeline_stage: Pipeline stage in which the error occurred, if known.
        **kwargs: May include ``details`` (dict) with extra context.
    """

    def __init__(self, message: str, pipeline_stage: str = None, **kwargs):
        # Copy so we never mutate a details dict owned by the caller.
        details = dict(kwargs.get('details', {}))
        if pipeline_stage:
            details['pipeline_stage'] = pipeline_stage
        super().__init__(message, component="Orchestrator", details=details)
45
+
46
+
47
class TransformerError(CollieBaseException):
    """Raised when data transformation fails.

    Args:
        message: Human-readable description of the failure.
        data_type: Type of data being transformed, if known.
        **kwargs: May include ``details`` (dict) with extra context.
    """

    def __init__(self, message: str, data_type: str = None, **kwargs):
        # Copy so we never mutate a details dict owned by the caller.
        details = dict(kwargs.get('details', {}))
        if data_type:
            details['data_type'] = data_type
        super().__init__(message, component="Transformer", details=details)
55
+
56
+
57
class TrainerError(CollieBaseException):
    """Raised when model training fails.

    Args:
        message: Human-readable description of the failure.
        model_type: Type of the model being trained, if known.
        **kwargs: May include ``details`` (dict) with extra context.
    """

    def __init__(self, message: str, model_type: str = None, **kwargs):
        # Copy so we never mutate a details dict owned by the caller.
        details = dict(kwargs.get('details', {}))
        if model_type:
            details['model_type'] = model_type
        super().__init__(message, component="Trainer", details=details)
65
+
66
+
67
class TunerError(CollieBaseException):
    """Raised when hyperparameter tuning fails.

    Args:
        message: Human-readable description of the failure.
        tuning_method: Tuning method in use, if known.
        **kwargs: May include ``details`` (dict) with extra context.
    """

    def __init__(self, message: str, tuning_method: str = None, **kwargs):
        # Copy so we never mutate a details dict owned by the caller.
        details = dict(kwargs.get('details', {}))
        if tuning_method:
            details['tuning_method'] = tuning_method
        super().__init__(message, component="Tuner", details=details)
75
+
76
+
77
class EvaluatorError(CollieBaseException):
    """Raised when model evaluation fails.

    Args:
        message: Human-readable description of the failure.
        metric: Metric being evaluated, if known.
        **kwargs: May include ``details`` (dict) with extra context.
    """

    def __init__(self, message: str, metric: str = None, **kwargs):
        # Copy so we never mutate a details dict owned by the caller.
        details = dict(kwargs.get('details', {}))
        if metric:
            details['metric'] = metric
        super().__init__(message, component="Evaluator", details=details)
85
+
86
+
87
class PusherError(CollieBaseException):
    """Raised when model pushing/deployment fails.

    Args:
        message: Human-readable description of the failure.
        deployment_target: Target of the deployment, if known.
        **kwargs: May include ``details`` (dict) with extra context.
    """

    def __init__(self, message: str, deployment_target: str = None, **kwargs):
        # Copy so we never mutate a details dict owned by the caller.
        details = dict(kwargs.get('details', {}))
        if deployment_target:
            details['deployment_target'] = deployment_target
        super().__init__(message, component="Pusher", details=details)
95
+
96
+
97
class ModelFlavorError(CollieBaseException):
    """Raised when model flavor operations fail.

    Args:
        message: Human-readable description of the failure.
        flavor: Name of the model flavor involved, if known.
        **kwargs: May include ``details`` (dict) with extra context.
    """

    def __init__(self, message: str, flavor: str = None, **kwargs):
        # Copy so we never mutate a details dict owned by the caller.
        details = dict(kwargs.get('details', {}))
        if flavor:
            details['flavor'] = flavor
        super().__init__(message, component="Model Flavor", details=details)
File without changes
@@ -0,0 +1,26 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any
3
+
4
+
5
class FlavorHandler(ABC):
    """Abstract interface for MLflow model-flavor handlers.

    A concrete handler recognises models of one ML framework
    (``can_handle``), identifies which flavor it represents (``flavor``),
    and knows how to log and load such models.
    """

    @abstractmethod
    def can_handle(self, model: Any) -> bool:
        """Return True if this handler supports *model*'s type."""
        raise NotImplementedError

    @abstractmethod
    def flavor(self):
        """Return the flavor identifier for this handler.

        NOTE(review): concrete handlers elsewhere in this package return a
        ``ModelFlavor`` enum member — confirm and add a return annotation.
        """
        raise NotImplementedError

    @abstractmethod
    def log_model(
        self,
        model: Any,
        name: str,
        **kwargs: Any
    ) -> None:
        """Log *model* with MLflow under artifact name *name*."""
        raise NotImplementedError

    @abstractmethod
    def load_model(self, model_uri: str) -> Any:
        """Load and return the model stored at *model_uri*."""
        raise NotImplementedError
@@ -0,0 +1,72 @@
1
+ from typing import List, Optional
2
+
3
+ from collie._common.mlflow_model_io.base_flavor_handler import FlavorHandler
4
+ from collie._common.mlflow_model_io.model_flavors import (
5
+ SklearnFlavorHandler,
6
+ XGBoostFlavorHandler,
7
+ PyTorchFlavorHandler,
8
+ LightGBMFlavorHandler,
9
+ TransformersFlavorHandler,
10
+ SKLEARN_AVAILABLE,
11
+ XGBOOST_AVAILABLE,
12
+ PYTORCH_AVAILABLE,
13
+ LIGHTGBM_AVAILABLE,
14
+ TRANSFORMERS_AVAILABLE
15
+ )
16
+ from collie._common.exceptions import ModelFlavorError
17
+
18
+
19
class FlavorRegistry:
    """Registry for model flavor handlers with conditional loading."""

    def __init__(self):
        # (availability flag, handler class) pairs, in registration order;
        # only frameworks that imported successfully get a handler.
        candidates = (
            (SKLEARN_AVAILABLE, SklearnFlavorHandler),
            (XGBOOST_AVAILABLE, XGBoostFlavorHandler),
            (PYTORCH_AVAILABLE, PyTorchFlavorHandler),
            (LIGHTGBM_AVAILABLE, LightGBMFlavorHandler),
            (TRANSFORMERS_AVAILABLE, TransformersFlavorHandler),
        )
        self._handlers: List[FlavorHandler] = [
            handler_cls() for available, handler_cls in candidates if available
        ]

        if not self._handlers:
            raise ModelFlavorError(
                "No model flavor handlers available. Please install at least one supported ML framework."
            )

    def find_handler_by_model(self, model) -> Optional[FlavorHandler]:
        """Return the first handler that can handle *model*, else None."""
        return next(
            (handler for handler in self._handlers if handler.can_handle(model)),
            None,
        )

    def find_handler_by_flavor(self, flavor: str) -> Optional[FlavorHandler]:
        """Return the handler whose flavor equals *flavor*, else None.

        NOTE(review): ``handler.flavor()`` returns a ``ModelFlavor`` member
        while *flavor* is annotated ``str`` — this only matches if the enum
        compares equal to plain strings; confirm against ``ModelFlavor``.
        """
        return next(
            (handler for handler in self._handlers if handler.flavor() == flavor),
            None,
        )

    def get_available_flavors(self) -> List[str]:
        """Return the ``.value`` strings of all registered handler flavors."""
        return [handler.flavor().value for handler in self._handlers]

    def get_handler_info(self) -> dict:
        """Return a summary of registered handlers and framework availability."""
        return {
            "total_handlers": len(self._handlers),
            "available_flavors": self.get_available_flavors(),
            "framework_status": {
                "sklearn": SKLEARN_AVAILABLE,
                "xgboost": XGBOOST_AVAILABLE,
                "pytorch": PYTORCH_AVAILABLE,
                "lightgbm": LIGHTGBM_AVAILABLE,
                "transformers": TRANSFORMERS_AVAILABLE
            }
        }
@@ -0,0 +1,259 @@
1
from typing import Any
import warnings


# Import with better error handling: each framework is optional; a module-level
# *_AVAILABLE flag records whether it (and its mlflow integration) imported.
try:
    import mlflow.sklearn
    import sklearn.base
    SKLEARN_AVAILABLE = True
except ImportError:
    SKLEARN_AVAILABLE = False
    warnings.warn("scikit-learn not available. SklearnFlavorHandler will be disabled.")

try:
    import mlflow.xgboost
    import xgboost as xgb
    XGBOOST_AVAILABLE = True
except ImportError:
    XGBOOST_AVAILABLE = False
    warnings.warn("XGBoost not available. XGBoostFlavorHandler will be disabled.")

try:
    import mlflow.pytorch
    import torch.nn as nn
    PYTORCH_AVAILABLE = True
except ImportError:
    PYTORCH_AVAILABLE = False
    warnings.warn("PyTorch not available. PyTorchFlavorHandler will be disabled.")

try:
    import mlflow.lightgbm
    import lightgbm as lgb
    LIGHTGBM_AVAILABLE = True
# NOTE(review): '(ImportError, Exception)' is equivalent to bare 'Exception';
# presumably intended to also swallow native-library load failures (e.g.
# OSError from a missing libomp) — confirm and narrow if possible.
except (ImportError, Exception):
    LIGHTGBM_AVAILABLE = False
    # Only lightgbm clears its alias; other flavors rely solely on their flag.
    lgb = None
    warnings.warn("LightGBM not available. LightGBMFlavorHandler will be disabled.")

try:
    import mlflow.transformers
    from transformers import PreTrainedModel
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    warnings.warn("Transformers not available. TransformersFlavorHandler will be disabled.")

from collie._common.mlflow_model_io.base_flavor_handler import FlavorHandler
from collie.core.enums.ml_models import ModelFlavor
from collie._common.exceptions import ModelFlavorError
50
+
51
+
52
class SklearnFlavorHandler(FlavorHandler):
    """Handler for scikit-learn models."""

    @staticmethod
    def _require(action: str) -> None:
        # Shared availability guard; *action* is "log" or "load".
        if not SKLEARN_AVAILABLE:
            raise ModelFlavorError(
                f"scikit-learn is not available. Please install it to {action} sklearn models.",
                flavor="sklearn"
            )

    def can_handle(self, model: Any) -> bool:
        """True when sklearn is installed and *model* is a BaseEstimator."""
        return SKLEARN_AVAILABLE and isinstance(model, sklearn.base.BaseEstimator)

    def flavor(self) -> ModelFlavor:
        """Identify this handler as the sklearn flavor."""
        return ModelFlavor.SKLEARN

    def log_model(self, model: Any, name: str, **kwargs: Any) -> None:
        """Log *model* via ``mlflow.sklearn``, wrapping failures in ModelFlavorError."""
        self._require("log")
        try:
            mlflow.sklearn.log_model(sk_model=model, artifact_path=name, **kwargs)
        except Exception as exc:
            raise ModelFlavorError(
                f"Failed to log sklearn model: {exc}",
                flavor="sklearn",
                details={"model_type": type(model).__name__, "artifact_name": name}
            ) from exc

    def load_model(self, model_uri: str) -> Any:
        """Load a sklearn model from *model_uri*, wrapping failures in ModelFlavorError."""
        self._require("load")
        try:
            return mlflow.sklearn.load_model(model_uri)
        except Exception as exc:
            raise ModelFlavorError(
                f"Failed to load sklearn model: {exc}",
                flavor="sklearn",
                details={"model_uri": model_uri}
            ) from exc
92
+
93
+
94
class XGBoostFlavorHandler(FlavorHandler):
    """Handler for XGBoost models."""

    @staticmethod
    def _require(action: str) -> None:
        # Shared availability guard; *action* is "log" or "load".
        if not XGBOOST_AVAILABLE:
            raise ModelFlavorError(
                f"XGBoost is not available. Please install it to {action} XGBoost models.",
                flavor="xgboost"
            )

    def can_handle(self, model: Any) -> bool:
        """True when XGBoost is installed and *model* is a Booster or XGBModel."""
        return XGBOOST_AVAILABLE and isinstance(model, (xgb.Booster, xgb.XGBModel))

    def flavor(self) -> ModelFlavor:
        """Identify this handler as the XGBoost flavor."""
        return ModelFlavor.XGBOOST

    def log_model(self, model: Any, name: str, **kwargs: Any) -> None:
        """Log *model* via ``mlflow.xgboost``, wrapping failures in ModelFlavorError."""
        self._require("log")
        try:
            mlflow.xgboost.log_model(xgb_model=model, artifact_path=name, **kwargs)
        except Exception as exc:
            raise ModelFlavorError(
                f"Failed to log XGBoost model: {exc}",
                flavor="xgboost",
                details={"model_type": type(model).__name__, "artifact_name": name}
            ) from exc

    def load_model(self, model_uri: str) -> Any:
        """Load an XGBoost model from *model_uri*, wrapping failures in ModelFlavorError."""
        self._require("load")
        try:
            return mlflow.xgboost.load_model(model_uri)
        except Exception as exc:
            raise ModelFlavorError(
                f"Failed to load XGBoost model: {exc}",
                flavor="xgboost",
                details={"model_uri": model_uri}
            ) from exc
134
+
135
+
136
class PyTorchFlavorHandler(FlavorHandler):
    """Handler for PyTorch models."""

    @staticmethod
    def _require(action: str) -> None:
        # Shared availability guard; *action* is "log" or "load".
        if not PYTORCH_AVAILABLE:
            raise ModelFlavorError(
                f"PyTorch is not available. Please install it to {action} PyTorch models.",
                flavor="pytorch"
            )

    def can_handle(self, model: Any) -> bool:
        """True when PyTorch is installed and *model* is an nn.Module."""
        return PYTORCH_AVAILABLE and isinstance(model, nn.Module)

    def flavor(self) -> ModelFlavor:
        """Identify this handler as the PyTorch flavor."""
        return ModelFlavor.PYTORCH

    def log_model(self, model: Any, name: str, **kwargs: Any) -> None:
        """Log *model* via ``mlflow.pytorch``, wrapping failures in ModelFlavorError."""
        self._require("log")
        try:
            mlflow.pytorch.log_model(pytorch_model=model, artifact_path=name, **kwargs)
        except Exception as exc:
            raise ModelFlavorError(
                f"Failed to log PyTorch model: {exc}",
                flavor="pytorch",
                details={"model_type": type(model).__name__, "artifact_name": name}
            ) from exc

    def load_model(self, model_uri: str) -> Any:
        """Load a PyTorch model from *model_uri*, wrapping failures in ModelFlavorError."""
        self._require("load")
        try:
            return mlflow.pytorch.load_model(model_uri)
        except Exception as exc:
            raise ModelFlavorError(
                f"Failed to load PyTorch model: {exc}",
                flavor="pytorch",
                details={"model_uri": model_uri}
            ) from exc
176
+
177
+
178
class LightGBMFlavorHandler(FlavorHandler):
    """Handler for LightGBM models."""

    @staticmethod
    def _require(action: str) -> None:
        # Shared availability guard; *action* is "log" or "load".
        if not LIGHTGBM_AVAILABLE:
            raise ModelFlavorError(
                f"LightGBM is not available. Please install it to {action} LightGBM models.",
                flavor="lightgbm"
            )

    def can_handle(self, model: Any) -> bool:
        """True when LightGBM imported cleanly and *model* is a Booster or LGBMModel."""
        # lgb may have been set to None on a failed import, so check both.
        return LIGHTGBM_AVAILABLE and lgb is not None and isinstance(model, (lgb.Booster, lgb.LGBMModel))

    def flavor(self) -> ModelFlavor:
        """Identify this handler as the LightGBM flavor."""
        return ModelFlavor.LIGHTGBM

    def log_model(self, model: Any, name: str, **kwargs: Any) -> None:
        """Log *model* via ``mlflow.lightgbm``, wrapping failures in ModelFlavorError."""
        self._require("log")
        try:
            mlflow.lightgbm.log_model(lgb_model=model, artifact_path=name, **kwargs)
        except Exception as exc:
            raise ModelFlavorError(
                f"Failed to log LightGBM model: {exc}",
                flavor="lightgbm",
                details={"model_type": type(model).__name__, "artifact_name": name}
            ) from exc

    def load_model(self, model_uri: str) -> Any:
        """Load a LightGBM model from *model_uri*, wrapping failures in ModelFlavorError."""
        self._require("load")
        try:
            return mlflow.lightgbm.load_model(model_uri)
        except Exception as exc:
            raise ModelFlavorError(
                f"Failed to load LightGBM model: {exc}",
                flavor="lightgbm",
                details={"model_uri": model_uri}
            ) from exc
218
+
219
+
220
class TransformersFlavorHandler(FlavorHandler):
    """Handler for Hugging Face Transformers models."""

    @staticmethod
    def _require(action: str) -> None:
        # Shared availability guard; *action* is "log" or "load".
        if not TRANSFORMERS_AVAILABLE:
            raise ModelFlavorError(
                f"Transformers is not available. Please install it to {action} Transformers models.",
                flavor="transformers"
            )

    def can_handle(self, model: Any) -> bool:
        """True when transformers is installed and *model* is a PreTrainedModel."""
        return TRANSFORMERS_AVAILABLE and isinstance(model, PreTrainedModel)

    def flavor(self) -> ModelFlavor:
        """Identify this handler as the Transformers flavor."""
        return ModelFlavor.TRANSFORMERS

    def log_model(self, model: Any, name: str, **kwargs: Any) -> None:
        """Log *model* via ``mlflow.transformers``, wrapping failures in ModelFlavorError."""
        self._require("log")
        try:
            mlflow.transformers.log_model(transformers_model=model, artifact_path=name, **kwargs)
        except Exception as exc:
            raise ModelFlavorError(
                f"Failed to log Transformers model: {exc}",
                flavor="transformers",
                details={"model_type": type(model).__name__, "artifact_name": name}
            ) from exc

    def load_model(self, model_uri: str) -> Any:
        """Load a Transformers model from *model_uri*, wrapping failures in ModelFlavorError."""
        self._require("load")
        try:
            return mlflow.transformers.load_model(model_uri)
        except Exception as exc:
            raise ModelFlavorError(
                f"Failed to load Transformers model: {exc}",
                flavor="transformers",
                details={"model_uri": model_uri}
            ) from exc
@@ -0,0 +1,65 @@
1
+ from typing import Any, Optional
2
+
3
+ import mlflow
4
+ from mlflow.tracking import MlflowClient
5
+ from collie._common.mlflow_model_io.flavor_registry import FlavorRegistry
6
+
7
+
8
+ class MLflowModelIO:
9
+ def __init__(
10
+ self,
11
+ mlflow_client: MlflowClient
12
+ ) -> None:
13
+ """
14
+ Initializes an MLflowModelIO instance.
15
+
16
+ Args:
17
+ mlflow_client (MlflowClient): The MLflowClient instance to use for logging models.
18
+
19
+ """
20
+ self.registry = FlavorRegistry()
21
+ self.client = mlflow_client
22
+
23
+ def log_model(
24
+ self,
25
+ model: Any,
26
+ name: str,
27
+ registered_model_name: Optional[str] = None,
28
+ **kwargs: Any,
29
+ ) -> None:
30
+ """
31
+ Logs a model with MLflow.
32
+
33
+ Args:
34
+ model (Any): The model to log.
35
+ name (str): The name to give the logged model.
36
+ registered_model_name (Optional[str], optional): The name to give the registered model. Defaults to None.
37
+ **kwargs (Any): Additional keyword arguments to pass to the flavor handler's log_model method.
38
+
39
+ Raises:
40
+ ValueError: If the model type is not supported by any flavor handler.
41
+
42
+ """
43
+ handler = self.registry.find_handler_by_model(model)
44
+ if handler is None:
45
+ raise ValueError(f"Unsupported model type: {type(model)}")
46
+
47
+ handler.log_model(
48
+ model=model,
49
+ name=name,
50
+ registered_model_name=registered_model_name,
51
+ **kwargs
52
+ )
53
+ mlflow.log_param("model_flavor", handler.flavor())
54
+
55
+ def load_model(
56
+ self,
57
+ flavor: str,
58
+ model_uri: str,
59
+ ) -> Any:
60
+
61
+ handler = self.registry.find_handler_by_flavor(flavor)
62
+ if handler is None:
63
+ raise ValueError(f"Unsupported model flavor: {flavor}")
64
+
65
+ return handler.load_model(model_uri)
@@ -0,0 +1,13 @@
1
+ import logging
2
+
3
+
4
def get_logger() -> logging.Logger:
    """
    Return a logger that logs messages with severity level info or higher.

    Returns:
        A logger that logs messages with severity level info or higher.
    """
    # NOTE(review): basicConfig only takes effect if the root logger has no
    # handlers yet; repeated calls are no-ops. Calling it from library code
    # is questionable practice — confirm this is intended.
    logging.basicConfig(level=logging.INFO)
    return logging.getLogger(__name__)
File without changes