runnable 0.9.1__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,94 +0,0 @@
1
- import functools
2
- import logging
3
- from typing import Any, Union
4
-
5
- from pydantic import ConfigDict, PrivateAttr
6
-
7
- from runnable import defaults
8
- from runnable.experiment_tracker import BaseExperimentTracker
9
-
10
# Module logger, registered under the package's logger name (``defaults.NAME``).
logger = logging.getLogger(name=defaults.NAME)
11
-
12
-
13
class MLFlowExperimentTracker(BaseExperimentTracker):
    """
    An MLflow experiment tracker.

    On construction, points the ``mlflow`` module at ``server_url`` (and optionally
    enables autologging). Metrics and parameters are logged into a single MLflow run
    per tracker instance: the run is named after ``self._context.run_id`` and lives in
    the experiment named by ``self._context.tag``, or ``"Default"`` when no tag is set.

    TODO: Need to set up credentials from secrets
    """

    service_name: str = "mlflow"

    # Tracking URI handed to ``mlflow.set_tracking_uri``.
    server_url: str
    # When True, ``mlflow.autolog`` is enabled (with model logging turned off).
    autolog: bool = False

    # Experiment name used when the run context carries no tag.
    _default_experiment_name: str = PrivateAttr(default="Default")
    # Id of the run created on the first ``client_context`` access; "" until then.
    _active_run_id: str = PrivateAttr(default="")
    # The imported ``mlflow`` module, assigned in ``model_post_init``.
    _client: Any = PrivateAttr(default=None)

    model_config = ConfigDict(extra="forbid")

    def model_post_init(self, __context: Any) -> None:
        """
        Import mlflow, set the tracking URI and optionally enable autologging.

        Raises:
            Exception: If mlflow is not installed.
        """
        try:
            import mlflow
        except ImportError as err:
            # Chain the ImportError so the underlying cause stays in the traceback.
            raise Exception("You need to install mlflow to use MLFlowExperimentTracker.") from err

        self._client = mlflow
        self._client.set_tracking_uri(self.server_url)

        if self.autolog:
            self._client.autolog(log_models=False)

    @functools.cached_property
    def experiment_id(self):
        """
        Id of the MLflow experiment this tracker logs into, creating it if needed.

        The experiment is named after ``self._context.tag`` when a tag is set,
        otherwise ``self._default_experiment_name``. Computed once per instance.
        """
        # If a tag is provided, we should create that as our experiment
        experiment_name = self._context.tag or self._default_experiment_name

        experiment = self._client.get_experiment_by_name(experiment_name)
        if experiment is not None:
            return experiment.experiment_id

        # create_experiment already returns the new experiment's id; no need to
        # fetch the experiment back just to read that id.
        return self._client.create_experiment(experiment_name)

    @functools.cached_property
    def run_name(self):
        """Name of the MLflow run: the run id of the current execution context."""
        return self._context.run_id

    @property
    def client_context(self):
        """
        Start (first access) or re-attach to (later accesses) this tracker's MLflow run.

        Returns the object produced by ``start_run``. Callers use it as a context
        manager, so the run is ended again when their ``with`` block exits.
        """
        if self._active_run_id:
            # Resume the run created earlier so every value lands in the same run.
            return self._client.start_run(
                run_id=self._active_run_id, experiment_id=self.experiment_id, run_name=self.run_name
            )

        active_run = self._client.start_run(run_name=self.run_name, experiment_id=self.experiment_id)
        self._active_run_id = active_run.info.run_id
        return active_run

    def log_metric(self, key: str, value: Union[int, float], step: int = 0):
        """
        Sets the metric in the experiment tracking.

        Values that are not int/float (bools included) are logged as the parameter
        ``{key}_{step}`` instead, with a warning.

        Args:
            key (str): The key against you want to store the value
            value (int | float): The value of the metric
            step (int): Step of the metric; 0 is passed to MLflow as ``None`` (no explicit step).
        """
        # bool is a subclass of int, but True/False are not meaningful metrics.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            msg = f"Only float/int values are accepted as metrics. Setting the metric {key} as parameter {key}_{step}"
            logger.warning(msg)
            self.log_parameter(key=key, value=value, step=step)
            return

        with self.client_context:
            self._client.log_metric(key, float(value), step=step or None)

    def log_parameter(self, key: str, value: Any, step: int = 0):
        """
        Log ``value`` as the MLflow parameter ``{key}_{step}`` in this tracker's run.

        Args:
            key (str): Base name of the parameter.
            value (Any): The value to record.
            step (int): Suffix appended to the key, so one key can be recorded per step.
        """
        with self.client_context:
            self._client.log_param(f"{key}_{step}", value)