collie-mlops 0.1.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of collie-mlops might be problematic. Click here for more details.

Files changed (45) hide show
  1. collie/__init__.py +69 -0
  2. collie/_common/__init__.py +0 -0
  3. collie/_common/decorator.py +53 -0
  4. collie/_common/exceptions.py +104 -0
  5. collie/_common/mlflow_model_io/__init__.py +0 -0
  6. collie/_common/mlflow_model_io/base_flavor_handler.py +26 -0
  7. collie/_common/mlflow_model_io/flavor_registry.py +72 -0
  8. collie/_common/mlflow_model_io/model_flavors.py +259 -0
  9. collie/_common/mlflow_model_io/model_io.py +65 -0
  10. collie/_common/utils.py +13 -0
  11. collie/contracts/__init__.py +0 -0
  12. collie/contracts/event.py +79 -0
  13. collie/contracts/mlflow.py +444 -0
  14. collie/contracts/orchestrator.py +79 -0
  15. collie/core/__init__.py +41 -0
  16. collie/core/enums/__init__.py +0 -0
  17. collie/core/enums/components.py +26 -0
  18. collie/core/enums/ml_models.py +20 -0
  19. collie/core/evaluator/__init__.py +0 -0
  20. collie/core/evaluator/evaluator.py +147 -0
  21. collie/core/models.py +125 -0
  22. collie/core/orchestrator/__init__.py +0 -0
  23. collie/core/orchestrator/orchestrator.py +47 -0
  24. collie/core/pusher/__init__.py +0 -0
  25. collie/core/pusher/pusher.py +98 -0
  26. collie/core/trainer/__init__.py +0 -0
  27. collie/core/trainer/trainer.py +78 -0
  28. collie/core/transform/__init__.py +0 -0
  29. collie/core/transform/transform.py +87 -0
  30. collie/core/tuner/__init__.py +0 -0
  31. collie/core/tuner/tuner.py +84 -0
  32. collie/helper/__init__.py +0 -0
  33. collie/helper/pytorch/__init__.py +0 -0
  34. collie/helper/pytorch/callback/__init__.py +0 -0
  35. collie/helper/pytorch/callback/callback.py +155 -0
  36. collie/helper/pytorch/callback/earlystop.py +54 -0
  37. collie/helper/pytorch/callback/model_checkpoint.py +100 -0
  38. collie/helper/pytorch/model/__init__.py +0 -0
  39. collie/helper/pytorch/model/loader.py +55 -0
  40. collie/helper/pytorch/trainer.py +304 -0
  41. collie_mlops-0.1.0b0.dist-info/METADATA +217 -0
  42. collie_mlops-0.1.0b0.dist-info/RECORD +45 -0
  43. collie_mlops-0.1.0b0.dist-info/WHEEL +5 -0
  44. collie_mlops-0.1.0b0.dist-info/licenses/LICENSE +21 -0
  45. collie_mlops-0.1.0b0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,79 @@
1
+ from typing import (
2
+ Dict,
3
+ Any,
4
+ Optional,
5
+ TypeVar
6
+ )
7
+ from abc import abstractmethod, ABC
8
+ from pydantic import (
9
+ Field,
10
+ BaseModel,
11
+ ConfigDict
12
+ )
13
+ from enum import Enum, auto
14
+
15
+ from collie._common.decorator import type_checker
16
+
17
+
18
class PipelineContext:
    """Mutable key/value store carried along with pipeline events.

    Args:
        data: Optional initial mapping. The dict is stored as-is (it is not
            copied), so the caller and the context share the same object.
            When omitted, the context starts with a fresh empty dict.
    """

    def __init__(self, data: Optional[Dict[str, Any]] = None):
        # Test against None explicitly: `data or {}` would replace a
        # caller-supplied *empty* dict with a new one, so later `set()` calls
        # would be invisible to the caller, unlike the non-empty case.
        self.data = {} if data is None else data

    def get(self, key: str, default=None):
        """Return the value stored under *key*, or *default* when absent."""
        return self.data.get(key, default)

    def set(self, key: str, value: Any):
        """Store *value* under *key*, overwriting any previous value."""
        self.data[key] = value

    def to_dict(self):
        """Return the underlying dict (the live object, not a copy)."""
        return self.data
30
+
31
class EventType(Enum):
    """Tag identifying what kind of event is being passed along.

    Member values come from ``auto()``, i.e. consecutive integers starting
    at 1 in declaration order, so the order below must not be changed if
    the numeric values matter to any consumer.
    """

    INITIALIZE = auto()
    DATA_READY = auto()
    TRAINING_DONE = auto()
    TUNING_DONE = auto()
    EVALUATION_DONE = auto()
    PUSHER_DONE = auto()
    ERROR = auto()
39
+
40
+
41
# Placeholder type for the event payload. `Event` below does not inherit from
# `Generic`, so this TypeVar only documents intent; it cannot be parametrized.
# NOTE(review): how pydantic validates a bare TypeVar field should be
# confirmed against the pydantic version in use.
P = TypeVar("P")


class Event(BaseModel):
    """Message exchanged between pipeline components.

    Attributes are declared as pydantic fields:
        type: Optional `EventType` tag; defaults to None.
        payload: Required payload of the event.
        context: `PipelineContext` for the event; a new one is created per
            event when none is supplied.
    """

    # `arbitrary_types_allowed` is required because `PipelineContext` is a
    # plain Python class rather than a pydantic model.
    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)

    type: Optional[EventType] = None
    payload: P
    context: PipelineContext = Field(default_factory=PipelineContext)
48
+
49
+
50
class EventHandler(ABC):
    """Abstract base for anything that consumes an `Event` and returns one."""

    @abstractmethod
    def handle(self, event: Event) -> Event:
        """
        Process the incoming event and return a new event.

        Concrete handlers must override this method. The override should
        work on the event payload and return an event whose type and payload
        suit the next component.

        Args:
            event (Event): The incoming event to process

        Returns:
            Event: A new event with processed payload

        Raises:
            NotImplementedError: When the base implementation is invoked
                (e.g. via ``super().handle(event)``)
        """
        owner = self.__class__.__name__
        message = (
            f"{owner} must implement the 'handle' method. "
            "This method should process the incoming event and return a new event "
            "with the appropriate payload for the next pipeline component."
        )
        raise NotImplementedError(message)

    # NOTE(review): `type_checker` is defined in collie._common.decorator;
    # presumably it validates the return value against the given types and
    # reports the supplied message on mismatch. Confirm there.
    @type_checker((Event,), "The return type of *handle* method must be 'Event'.")
    def _handle(self, event: Event) -> Event:
        """Call `handle` through the `type_checker` decorator."""
        result = self.handle(event)
        return result
@@ -0,0 +1,444 @@
1
+ import tempfile
2
+ import threading
3
+ from abc import ABC, abstractmethod
4
+ from typing import (
5
+ Any,
6
+ Optional,
7
+ Dict,
8
+ Literal,
9
+ List,
10
+ Union,
11
+ Generator,
12
+ overload
13
+ )
14
+ from contextlib import contextmanager
15
+
16
+ import mlflow
17
+ import mlflow.data
18
+ import pandas as pd
19
+ from mlflow.tracking import MlflowClient
20
+ from mlflow import ActiveRun
21
+ from mlflow.exceptions import MlflowException
22
+
23
+ from collie._common.utils import get_logger
24
+ from collie._common.mlflow_model_io.model_io import MLflowModelIO
25
+ from collie._common.exceptions import (
26
+ MLflowConfigurationError,
27
+ MLflowOperationError
28
+ )
29
+ from collie.core.enums.ml_models import (
30
+ MLflowModelStage,
31
+ ModelFlavor
32
+ )
33
+
34
+
35
+ logger = get_logger()
36
+
37
+
38
class MLflowConfig:
    """Singleton holding the MLflow tracking settings for the process.

    Every call to ``MLflowConfig(...)`` returns the same object. Only the
    first call stores its arguments and builds the `MlflowClient`; later
    calls leave the stored settings untouched and log a warning if they were
    given different values.

    Args:
        tracking_uri: Tracking URI handed to `MlflowClient` and, in
            `configure`, to ``mlflow.set_tracking_uri``.
        experiment_name: Experiment name handed to ``mlflow.set_experiment``
            in `configure`.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    # Finish preparing the instance *before* publishing it on
                    # the class: a thread that sees `_instance` set skips the
                    # lock and goes straight to `__init__`, which reads
                    # `_initialized`.
                    instance = super().__new__(cls)
                    instance._initialized = False
                    cls._instance = instance
        return cls._instance

    def __init__(
        self,
        tracking_uri: str,
        experiment_name: str,
    ) -> None:
        # The lock makes the check-and-set of `_initialized` atomic, so two
        # threads cannot both run the first-time initialisation.
        with self._lock:
            if self._initialized:
                requested = (tracking_uri, experiment_name)
                current = (self.tracking_uri, self.experiment_name)
                if requested != current:
                    # Keep the first configuration (singleton semantics) but
                    # make the ignored arguments visible instead of silent.
                    logger.warning(
                        "MLflowConfig is already initialized with "
                        f"tracking_uri={self.tracking_uri!r}, "
                        f"experiment_name={self.experiment_name!r}; ignoring "
                        f"tracking_uri={tracking_uri!r}, "
                        f"experiment_name={experiment_name!r}"
                    )
                return

            self.tracking_uri = tracking_uri
            self.experiment_name = experiment_name
            self.mlflow_client = MlflowClient(tracking_uri=tracking_uri)
            self._initialized = True

    def configure(self) -> None:
        """Apply the stored tracking URI and experiment name to `mlflow`."""
        mlflow.set_tracking_uri(self.tracking_uri)
        mlflow.set_experiment(self.experiment_name)
68
+
69
+
70
class _MLflowModelManager:
    """Wraps MLflow model logging, loading and model-registry calls.

    Args:
        mlflow_client: Client used for registry queries; it is also handed
            to `MLflowModelIO`, which performs the actual model I/O.
    """

    def __init__(
        self,
        mlflow_client: MlflowClient
    ) -> None:
        self._mlflow_client = mlflow_client
        self._model_io = MLflowModelIO(mlflow_client)

    def log_model(self, model: Any, name: Optional[str] = None) -> None:
        """
        Logs a model with MLflow.

        Args:
            model (Any): The model to log.
            name (Optional[str], optional): The name to give the logged model. Defaults to None.

        Raises:
            MLflowOperationError: If logging the model fails.
        """
        try:
            self._model_io.log_model(model, name)
            logger.info(f"Logged model: {name or 'unnamed'}")
        except Exception as e:
            raise MLflowOperationError(f"Failed to log model '{name}': {e}") from e

    def load_model(
        self,
        flavor: ModelFlavor,
        model_uri: Optional[str] = None
    ) -> Any:
        """
        Loads a model through `MLflowModelIO`.

        Args:
            flavor (ModelFlavor): The flavor of the model to load.
            model_uri (Optional[str], optional): The URI of the model to load. Defaults to None.

        Returns:
            Any: The loaded model.

        Raises:
            MLflowOperationError: If there is no active MLflow run, or if
                loading the model fails.
        """
        # Check the precondition outside the try block so its error is raised
        # as-is rather than being caught below and re-wrapped with a
        # misleading "Failed to load model" prefix.
        if mlflow.active_run() is None:
            raise MLflowOperationError("No active run found")

        try:
            return self._model_io.load_model(flavor, model_uri)
        except Exception as e:
            raise MLflowOperationError(f"Failed to load model '{model_uri}': {e}") from e

    def load_latest_model(
        self,
        model_name: str,
        stage: MLflowModelStage,
        flavor: ModelFlavor
    ) -> Any:
        """
        Loads the latest version of a model from a given stage.

        Args:
            model_name (str): The name of the model to load.
            stage (MLflowModelStage): The stage from which to load the model.
            flavor (ModelFlavor): The flavor of the model to load.

        Returns:
            Any: The loaded model. None when the stage holds no version or
            when the registry lookup itself fails (the failure is logged).

        Raises:
            MLflowOperationError: If a version was found but loading it fails.
        """
        try:
            latest_versions = self._mlflow_client.get_latest_versions(model_name, stages=[stage])
            if not latest_versions:
                logger.warning(f"No model found in stage '{stage}' for model '{model_name}'")
                return None

            model_uri = latest_versions[0].source
        except Exception as e:
            # Best effort: a failed lookup is reported as "no model", not as
            # an error, so callers only need to handle None.
            logger.error(f"Failed to get latest model version for '{model_name}' in stage '{stage}': {e}")
            return None

        return self.load_model(flavor=flavor, model_uri=model_uri)

    def register_model(
        self,
        model_name: str,
        model_uri: str
    ) -> int:
        """
        Registers a model with MLflow.

        Args:
            model_name (str): The name to give the registered model.
            model_uri (str): The URI of the model to register.

        Returns:
            int: The version of the registered model, exactly as reported by
            MLflow. NOTE(review): MLflow may report the version as a string;
            confirm before relying on the annotated type.

        Raises:
            MLflowOperationError: If registering the model fails.
        """
        try:
            registered_model = mlflow.register_model(model_uri, model_name)
            logger.info(f"Registered model '{model_name}' version {registered_model.version}")
            return registered_model.version
        except MlflowException as e:
            raise MLflowOperationError(
                f"Failed to register model '{model_name}' with URI '{model_uri}': {e}"
            ) from e

    def transition_model_version(
        self,
        registered_model_name: str,
        version: str,
        desired_stage: str,
        archive_existing_versions_at_stage: bool = False,
    ) -> None:
        """
        Transitions a model version from one stage to another.

        Args:
            registered_model_name (str): The name of the registered model to transition.
            version (str): The version of the model to transition.
            desired_stage (str): The desired stage to transition the model to.
            archive_existing_versions_at_stage (bool, optional): Whether to archive existing versions at the target stage. Defaults to False.

        Raises:
            MLflowOperationError: If transitioning the model version fails.
        """
        try:
            self._mlflow_client.transition_model_version_stage(
                name=registered_model_name,
                version=version,
                stage=desired_stage,
                archive_existing_versions=archive_existing_versions_at_stage,
            )
            logger.info(f"Transitioned model '{registered_model_name}' v{version} to {desired_stage}")
        except MlflowException as e:
            raise MLflowOperationError(
                f"Failed to transition model '{registered_model_name}' v{version} to {desired_stage}: {e}"
            ) from e

    def get_latest_version(
        self,
        model_name: str,
        stages
    ) -> Optional[str]:
        """
        Retrieves the latest version number of a model in the specified stages.

        Args:
            model_name (str): The name of the model to retrieve the latest version for.
            stages: The stages in which to search for the latest version.

        Returns:
            Optional[str]: The highest version among those returned for the
            given stages (compared numerically), or None if there is none.

        Raises:
            MLflowOperationError: If retrieving the latest version fails.
        """
        try:
            latest_versions = self._mlflow_client.get_latest_versions(model_name, stages=stages)
            if not latest_versions:
                logger.warning(f"No versions found for model '{model_name}' in stages {stages}")
                return None

            # Compare as integers: as strings, "10" would sort before "9".
            newest = max(latest_versions, key=lambda v: int(v.version))
            return newest.version

        except MlflowException as e:
            raise MLflowOperationError(
                f"Failed to get latest version for model '{model_name}': {e}"
            ) from e
244
+
245
+
246
class MLFlowComponentABC(ABC):
    """
    Abstract base class for MLflow components with separated concerns.

    Subclasses implement `run`. The instance needs an `MLflowConfig`
    (assigned through the `mlflow_config` property) before any method that
    talks to MLflow is used; model operations are delegated to a lazily
    created `_MLflowModelManager`.
    """

    def __init__(self) -> None:
        super().__init__()
        self._mlflow_config = None
        self._model_manager = None
        # Keep a handle on the mlflow module so methods go through
        # `self.mlflow` rather than the global name.
        self.mlflow = mlflow

    @abstractmethod
    def run(self, *args, **kwargs) -> Any:
        """Execute the component; must be provided by subclasses."""
        raise NotImplementedError("Please implement the **run** method.")

    @property
    def mlflow_config(self) -> MLflowConfig:
        """The configured `MLflowConfig`.

        Raises:
            MLflowConfigurationError: If no configuration has been assigned.
        """
        if self._mlflow_config is None:
            raise MLflowConfigurationError("MLflow client not set")
        return self._mlflow_config

    @mlflow_config.setter
    def mlflow_config(self, mlflow_config: MLflowConfig):
        if not isinstance(mlflow_config, MLflowConfig):
            raise MLflowConfigurationError("mlflow_config must be an instance of MLflowConfig")
        self._mlflow_config = mlflow_config

    @property
    def model_manager(self) -> _MLflowModelManager:
        """The model manager, created on first access and then reused.

        Raises:
            MLflowConfigurationError: If no configuration has been assigned.
        """
        if self._model_manager is None:
            # Go through the `mlflow_config` property (not the private
            # attribute) so a missing configuration is reported as
            # MLflowConfigurationError instead of an AttributeError on None.
            mlflow_client = self.mlflow_config.mlflow_client
            self._model_manager = _MLflowModelManager(mlflow_client)
        return self._model_manager

    def log_pd_data(
        self,
        data: pd.DataFrame,
        context: str,
        source: str
    ) -> None:
        """
        Logs a DataFrame as an MLflow dataset input and as a CSV artifact.

        Args:
            data (pd.DataFrame): The data to log.
            context (str): The context passed to ``mlflow.log_input``.
            source (str): The source passed to ``mlflow.data.from_pandas``.

        Raises:
            ValueError: If *data* is not a pandas DataFrame.
            MLflowOperationError: If any logging step fails.
        """
        if not isinstance(data, pd.DataFrame):
            raise ValueError("Data must be a pandas DataFrame.")

        try:
            ds = self.mlflow.data.from_pandas(data, source=source)
            self.mlflow.log_input(ds, context=context)

            # Write the CSV into a private temporary directory. The temp file
            # is closed before pandas reopens it by name, because reopening
            # a still-open NamedTemporaryFile is not portable. The directory
            # (and the file in it) is removed when the block exits.
            with tempfile.TemporaryDirectory() as tmp_dir:
                tmp = tempfile.NamedTemporaryFile(dir=tmp_dir, suffix=".csv", delete=False)
                tmp.close()
                data.to_csv(tmp.name, index=False)
                self.mlflow.log_artifact(tmp.name)

            logger.debug(f"Logged pandas data for context: {context}")
        except Exception as e:
            raise MLflowOperationError(f"Failed to log pandas data: {e}") from e

    def log_model(
        self,
        model: Any,
        name: Optional[str] = None
    ) -> None:
        """Log *model* under *name* via the model manager."""
        self.model_manager.log_model(model, name)

    def load_model(
        self,
        name: Optional[str] = None,
        *,
        flavor: Optional[ModelFlavor] = None
    ) -> Any:
        """
        Loads a model via the model manager.

        Args:
            name (Optional[str], optional): When *flavor* is given, the model
                URI to load. When *flavor* is omitted, this value is
                forwarded as the manager's first positional argument exactly
                as before (legacy behaviour); that slot is the manager's
                ``flavor`` parameter.
            flavor (Optional[ModelFlavor], optional): The flavor of the model
                to load. Defaults to None.

        Returns:
            Any: The loaded model.

        Raises:
            MLflowOperationError: If loading the model fails.
        """
        if flavor is None:
            # Legacy call shape, kept unchanged for existing callers.
            return self.model_manager.load_model(name)
        return self.model_manager.load_model(flavor, name)

    def load_latest_model(
        self,
        model_name: str,
        flavor: ModelFlavor,
        stage: MLflowModelStage = MLflowModelStage.PRODUCTION,
    ) -> Any:
        """Load the latest version of *model_name* in *stage* via the model manager."""
        return self.model_manager.load_latest_model(
            model_name,
            stage,
            flavor
        )

    def register_model(
        self,
        model_name: str,
        model_uri: str
    ) -> int:
        """Register the model at *model_uri* as *model_name*; returns the manager's result."""
        return self.model_manager.register_model(model_name, model_uri)

    def transition_model_version(
        self,
        registered_model_name: str,
        version: str,
        desired_stage: str,
        archive_existing_versions_at_stage: bool = False
    ) -> None:
        """Move *version* of *registered_model_name* to *desired_stage* via the model manager."""
        self.model_manager.transition_model_version(
            registered_model_name,
            version,
            desired_stage,
            archive_existing_versions_at_stage
        )

    def get_latest_version(
        self,
        model_name: str,
        stages: List[Literal["None", "Staging", "Production", "Archived"]]
    ) -> Optional[str]:
        """Return the latest version of *model_name* in *stages*, or None if there is none."""
        return self.model_manager.get_latest_version(model_name, stages)

    @overload
    def get_experiment(
        self,
        return_id: Literal[True]
    ) -> Optional[str]: ...

    @overload
    def get_experiment(
        self,
        return_id: Literal[False] = False
    ) -> Optional[mlflow.entities.Experiment]: ...

    def get_experiment(
        self,
        return_id: bool = False
    ) -> Optional[Union[mlflow.entities.Experiment, str]]:
        """
        Retrieves the MLflow experiment corresponding to the configured experiment name.

        Args:
            return_id (bool, optional): If True, returns the experiment ID instead of the experiment object.
                Defaults to False.

        Returns:
            Optional[Union[mlflow.entities.Experiment, str]]: The experiment object or experiment ID if return_id is True,
                or None if no experiment name is configured, the experiment does not exist, or the lookup fails.
        """
        experiment_name = self.mlflow_config.experiment_name
        if not experiment_name:
            return None

        try:
            experiment = self.mlflow.get_experiment_by_name(experiment_name)
        except MlflowException as e:
            logger.error(f"Failed to get experiment '{experiment_name}': {e}")
            return None

        # The lookup yields None for an unknown experiment; return that
        # instead of failing on `.experiment_id`.
        if experiment is None:
            return None
        return experiment.experiment_id if return_id else experiment

    @contextmanager
    def start_run(
        self,
        tags: Optional[Dict[str, str]] = None,
        run_name: Optional[str] = None,
        nested: bool = False,
        log_system_metrics: Optional[bool] = None,
        description: Optional[str] = None,
    ) -> Generator[ActiveRun, None, None]:
        """
        Starts an MLflow run and returns the active run object.

        Args:
            tags (Optional[Dict[str, str]], optional): A dictionary of string key-value pairs to store as run tags. Defaults to None.
            run_name (Optional[str], optional): Name for the run. Defaults to None.
            nested (bool, optional): If True, nested runs are enabled. Defaults to False.
            log_system_metrics (Optional[bool], optional): If True, system metrics are logged. Defaults to None.
            description (Optional[str], optional): A string description for the run. Defaults to None.

        Yields:
            Generator[ActiveRun, None, None]: The active run object.
        Raises:
            MLflowOperationError: If the MLflow run cannot be started.
        """
        # NOTE(review): the try block encloses the `yield`, so an
        # MlflowException raised inside the caller's `with` body is also
        # re-raised as MLflowOperationError with the "Failed to start"
        # message. Kept as-is because callers may rely on that type.
        try:
            self.mlflow_config.configure()
            experiment_id = self.get_experiment(return_id=True)

            with self.mlflow.start_run(
                experiment_id=experiment_id,
                run_name=run_name,
                nested=nested,
                tags=tags,
                log_system_metrics=log_system_metrics,
                description=description,
            ) as active_run:
                logger.info(f"Started MLflow run: {active_run.info.run_id}")
                yield active_run

        except MlflowException as e:
            raise MLflowOperationError(f"Failed to start MLflow run: {e}") from e
@@ -0,0 +1,79 @@
1
+ from typing import (
2
+ Optional,
3
+ Dict,
4
+ Any
5
+ )
6
+ from abc import abstractmethod
7
+
8
+ from collie.contracts.event import Event, EventType
9
+ from collie.core.enums.components import CollieComponentType
10
+ from collie.contracts.mlflow import (
11
+ MLFlowComponentABC,
12
+ MLflowConfig
13
+ )
14
+ from collie._common.exceptions import (
15
+ OrchestratorError,
16
+ TrainerError,
17
+ TunerError,
18
+ EvaluatorError,
19
+ PusherError,
20
+ TransformerError,
21
+ )
22
+
23
+
24
class OrchestratorBase(MLFlowComponentABC):
    """Base orchestrator: runs `orchestrate_pipeline` inside one MLflow run.

    Args:
        components: The pipeline component(s) to orchestrate.
        tracking_uri: Tracking URI used to build the `MLflowConfig` in `run`.
        mlflow_tags: Tags passed to the MLflow run.
        experiment_name: Experiment name used to build the `MLflowConfig`.
        description: Description passed to the MLflow run.
    """

    def __init__(
        self,
        components: CollieComponentType,
        tracking_uri: Optional[str] = None,
        mlflow_tags: Optional[Dict[str, str]] = None,
        experiment_name: Optional[str] = None,
        description: Optional[str] = None
    ) -> None:
        super().__init__()
        self.components = components
        self.tracking_uri = tracking_uri
        self.experiment_name = experiment_name
        self.mlflow_tags = mlflow_tags
        self.description = description

    @abstractmethod
    def orchestrate_pipeline(self) -> Any:
        """Run the pipeline itself; must be provided by subclasses."""
        raise NotImplementedError

    def run(self) -> Any:
        """Configure MLflow, open the "Orchestrator" run and execute the pipeline.

        Returns:
            Any: Whatever `orchestrate_pipeline` returns.

        Raises:
            OrchestratorError: Wraps any exception raised while starting the
                run or executing the pipeline; component errors and all other
                errors get different message prefixes.
        """
        self.mlflow_config = MLflowConfig(
            tracking_uri=self.tracking_uri,
            experiment_name=self.experiment_name,
        )

        # Errors raised by the individual pipeline components.
        component_errors = (
            TrainerError,
            TunerError,
            EvaluatorError,
            PusherError,
            TransformerError,
        )
        run_options = {
            "tags": self.mlflow_tags,
            "run_name": "Orchestrator",
            "description": self.description,
        }

        try:
            with self.start_run(**run_options):
                return self.orchestrate_pipeline()
        except component_errors as e:
            raise OrchestratorError(f"Component error in orchestration: {e}") from e
        except Exception as e:
            raise OrchestratorError(f"Unexpected orchestration error: {e}") from e

    def initialize_event(self) -> Event:
        """Build the first event of a pipeline: type INITIALIZE, no payload."""
        first_event = Event(type=EventType.INITIALIZE, payload=None)
        return first_event
@@ -0,0 +1,41 @@
1
+ from .transform.transform import Transformer
2
+ from .trainer.trainer import Trainer
3
+ from .tuner.tuner import Tuner
4
+ from .evaluator.evaluator import Evaluator
5
+ from .pusher.pusher import Pusher
6
+ from .orchestrator.orchestrator import Orchestrator
7
+ from .models import (
8
+ TransformerPayload,
9
+ TrainerPayload,
10
+ TunerPayload,
11
+ EvaluatorPayload,
12
+ PusherPayload,
13
+ TrainerArtifact,
14
+ TransformerArtifact,
15
+ TunerArtifact,
16
+ EvaluatorArtifact
17
+ )
18
+ from .enums.components import CollieComponentType, CollieComponents
19
+ from .enums.ml_models import ModelFlavor
20
+
21
+
22
+ __all__ = [
23
+ "Transformer",
24
+ "Trainer",
25
+ "Tuner",
26
+ "Evaluator",
27
+ "Pusher",
28
+ "Orchestrator",
29
+ "TransformerPayload",
30
+ "TrainerPayload",
31
+ "TunerPayload",
32
+ "EvaluatorPayload",
33
+ "PusherPayload",
34
+ "CollieComponentType",
35
+ "CollieComponents",
36
+ "TrainerArtifact",
37
+ "TransformerArtifact",
38
+ "TunerArtifact",
39
+ "EvaluatorArtifact",
40
+ "ModelFlavor"
41
+ ]
File without changes
@@ -0,0 +1,26 @@
1
+ from typing import Union
2
+ from enum import Enum
3
+
4
+ from collie.core.transform.transform import Transformer
5
+ from collie.core.tuner.tuner import Tuner
6
+ from collie.core.trainer.trainer import Trainer
7
+ from collie.core.evaluator.evaluator import Evaluator
8
+ from collie.core.pusher.pusher import Pusher
9
+
10
+
11
class CollieComponents(Enum):
    """Enumeration whose member values are the pipeline component classes."""

    TRAINER = Trainer
    TRANSFORMER = Transformer
    TUNER = Tuner
    EVALUATOR = Evaluator
    PUSHER = Pusher
18
+
19
+
20
# Type alias: an instance of any one of the pipeline component classes.
CollieComponentType = Union[Trainer, Transformer, Tuner, Evaluator, Pusher]
@@ -0,0 +1,20 @@
1
+ from enum import Enum
2
+
3
+
4
class ModelFlavor(str, Enum):
    """Model flavor identifiers; each member is also a plain ``str``."""

    SKLEARN = "sklearn"
    XGBOOST = "xgboost"
    LIGHTGBM = "lightgbm"
    PYTORCH = "pytorch"
    TRANSFORMERS = "transformers"

    def __str__(self) -> str:
        # Return the raw value, as MLflowModelStage does, so that str() and
        # f-string formatting give the same text on every Python version.
        return self.value
10
+
11
+
12
class MLflowModelStage(str, Enum):
    """Model stage names; each member is also a plain ``str``."""

    NONE = "None"
    STAGING = "Staging"
    PRODUCTION = "Production"
    ARCHIVED = "Archived"

    def __str__(self) -> str:
        """Render the member as its raw stage name (e.g. ``"Production"``)."""
        stage_name = self.value
        return stage_name
File without changes