runnable 0.9.1__py3-none-any.whl → 0.11.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,94 +0,0 @@
1
- import functools
2
- import logging
3
- from typing import Any, Union
4
-
5
- from pydantic import ConfigDict, PrivateAttr
6
-
7
- from runnable import defaults
8
- from runnable.experiment_tracker import BaseExperimentTracker
9
-
10
# Module logger, named after the package-wide NAME constant from runnable.defaults.
logger: logging.Logger = logging.getLogger(name=defaults.NAME)
11
-
12
-
13
class MLFlowExperimentTracker(BaseExperimentTracker):
    """
    An experiment tracker that records metrics and parameters in MLflow.

    On construction the ``mlflow`` module is imported lazily, pointed at
    ``server_url`` and, if ``autolog`` is set, autologging is enabled
    (without model logging).

    TODO: Need to set up credentials from secrets
    """

    service_name: str = "mlflow"

    # Tracking URI handed to mlflow.set_tracking_uri.
    server_url: str
    # When True, mlflow.autolog(log_models=False) is called at construction.
    autolog: bool = False

    # Experiment used when the run context carries no tag.
    _default_experiment_name: str = PrivateAttr(default="Default")
    # Id of the MLflow run started on first access of client_context;
    # empty until then, afterwards that run is resumed.
    _active_run_id: str = PrivateAttr(default="")
    # The imported ``mlflow`` module, set in model_post_init.
    _client: Any = PrivateAttr(default=None)

    model_config = ConfigDict(extra="forbid")

    def model_post_init(self, __context: Any) -> None:
        """
        Import mlflow and configure the tracking URI (and autologging).

        Raises:
            Exception: If mlflow is not installed.
        """
        try:
            import mlflow
        except ImportError as e:
            # Chain the original ImportError so the root cause stays visible.
            raise Exception("You need to install mlflow to use MLFlowExperimentTracker.") from e

        self._client = mlflow

        self._client.set_tracking_uri(self.server_url)

        if self.autolog:
            self._client.autolog(log_models=False)

    @functools.cached_property
    def experiment_id(self):
        """
        Return the id of the MLflow experiment to log into, creating it if absent.

        The experiment is named after the run context's tag when one is set,
        otherwise after ``_default_experiment_name``. Computed once per instance.
        """
        # If a tag is provided, we should use that as our experiment.
        experiment_name = self._context.tag or self._default_experiment_name

        experiment = self._client.get_experiment_by_name(experiment_name)
        if experiment is not None:
            return experiment.experiment_id

        # create_experiment already returns the new experiment's id, so no
        # second lookup is needed.
        return self._client.create_experiment(experiment_name)

    @functools.cached_property
    def run_name(self):
        """Name MLflow runs after the run context's run_id (computed once)."""
        return self._context.run_id

    @property
    def client_context(self):
        """
        Start or resume the MLflow run associated with this tracker.

        The first access starts a new run in ``experiment_id`` named
        ``run_name`` and remembers its id. Later accesses pass that id back to
        ``start_run`` so the same run is reused. The log_* methods use the
        returned object as a context manager.
        """
        if self._active_run_id:
            return self._client.start_run(
                run_id=self._active_run_id, experiment_id=self.experiment_id, run_name=self.run_name
            )

        active_run = self._client.start_run(run_name=self.run_name, experiment_id=self.experiment_id)
        self._active_run_id = active_run.info.run_id
        return active_run

    def log_metric(self, key: str, value: Union[int, float], step: int = 0):
        """
        Sets the metric in the experiment tracking.

        Values that are not int/float are logged as the parameter
        ``{key}_{step}`` instead, with a warning.

        Args:
            key (str): The key against which to store the value
            value (Union[int, float]): The value of the metric
            step (int): The step for the metric; 0 is sent to MLflow as None.
        """
        # Accept both ints and floats. The previous check
        # `not isinstance(value, float) or isinstance(value, int)` bound as
        # `(not float) or int` and diverted every int to log_parameter.
        if not isinstance(value, (int, float)):
            msg = f"Only float/int values are accepted as metrics. Setting the metric {key} as parameter {key}_{step}"
            logger.warning(msg)
            self.log_parameter(key=key, value=value, step=step)
            return

        with self.client_context as _:
            self._client.log_metric(key, float(value), step=step or None)

    def log_parameter(self, key: str, value: Any, step: int = 0):
        """
        Log ``value`` as the MLflow parameter named ``{key}_{step}``.

        Args:
            key (str): Base name of the parameter.
            value (Any): The value to record.
            step (int): Suffix appended to the key to keep per-step values apart.
        """
        with self.client_context as _:
            self._client.log_param(f"{key}_{step}", value)