hydraflow 0.14.4__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hydraflow/core/run.py ADDED
@@ -0,0 +1,355 @@
1
+ """Run module for HydraFlow.
2
+
3
+ This module provides the Run class, which represents an MLflow
4
+ Run in HydraFlow. A Run contains three main components:
5
+
6
+ 1. info: Information about the run, such as run directory,
7
+ run ID, and job name.
8
+ 2. cfg: Configuration loaded from the Hydra configuration file.
9
+ 3. impl: Implementation instance created by the provided
10
+ factory function.
11
+
12
+ The Run class allows accessing these components through
13
+ a unified interface, and provides methods for setting default
14
+ configuration values and filtering runs.
15
+
16
+ The implementation instance (impl) can be created using a factory function
17
+ that accepts either just the artifacts directory path, or both the
18
+ artifacts directory path and the configuration instance. This flexibility
19
+ allows implementation classes to be configuration-aware and adjust their
20
+ behavior based on the run's configuration.
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import inspect
26
+ from collections.abc import Callable, Iterable
27
+ from dataclasses import MISSING
28
+ from functools import cached_property
29
+ from pathlib import Path
30
+ from typing import TYPE_CHECKING, cast, overload
31
+
32
+ from omegaconf import DictConfig, ListConfig, OmegaConf
33
+
34
+ from .run_info import RunInfo
35
+
36
+ if TYPE_CHECKING:
37
+ from typing import Any, Self
38
+
39
+ from .run_collection import RunCollection
40
+
41
+
42
+ class Run[C, I = None]:
43
+ """Represent an MLflow Run in HydraFlow.
44
+
45
+ A Run contains information about the run, configuration, and
46
+ implementation. The configuration type C and implementation
47
+ type I are specified as type parameters.
48
+ """
49
+
50
+ info: RunInfo
51
+ """Information about the run, such as run directory, run ID, and job name."""
52
+
53
+ impl_factory: Callable[[Path], I] | Callable[[Path, C], I]
54
+ """Factory function to create the implementation instance.
55
+
56
+ This can be a callable that accepts either:
57
+ - A single Path parameter: the artifacts directory
58
+ - Both a Path and a config parameter: the artifacts directory and
59
+ the configuration instance
60
+
61
+ The implementation dynamically detects the signature and calls the
62
+ factory with the appropriate arguments.
63
+ """
64
+
65
+ def __init__(
66
+ self,
67
+ run_dir: Path,
68
+ impl_factory: Callable[[Path], I] | Callable[[Path, C], I] = lambda _: None,
69
+ ) -> None:
70
+ self.info = RunInfo(run_dir)
71
+ self.impl_factory = impl_factory
72
+
73
+ def __repr__(self) -> str:
74
+ """Return a string representation of the Run."""
75
+ class_name = self.__class__.__name__
76
+ if isinstance(self.impl_factory, type):
77
+ impl_name = f"[{self.impl_factory.__name__}]"
78
+ else:
79
+ impl_name = ""
80
+
81
+ return f"{class_name}{impl_name}({self.info.run_id!r})"
82
+
83
+ @cached_property
84
+ def cfg(self) -> C:
85
+ """The configuration instance loaded from the Hydra configuration file."""
86
+ config_file = self.info.run_dir / "artifacts/.hydra/config.yaml"
87
+ if config_file.exists():
88
+ return OmegaConf.load(config_file) # type: ignore
89
+
90
+ return OmegaConf.create() # type: ignore
91
+
92
+ @cached_property
93
+ def impl(self) -> I:
94
+ """The implementation instance created by the factory function.
95
+
96
+ This property dynamically examines the signature of the impl_factory
97
+ using the inspect module and calls it with the appropriate arguments:
98
+
99
+ - If the factory accepts one parameter: called with just the artifacts
100
+ directory
101
+ - If the factory accepts two parameters: called with the artifacts
102
+ directory and the configuration instance
103
+
104
+ This allows implementation classes to be configuration-aware and
105
+ utilize both the file system and configuration information.
106
+ """
107
+ artifacts_dir = self.info.run_dir / "artifacts"
108
+
109
+ sig = inspect.signature(self.impl_factory)
110
+ params = list(sig.parameters.values())
111
+
112
+ if len(params) == 1:
113
+ impl_factory = cast("Callable[[Path], I]", self.impl_factory)
114
+ return impl_factory(artifacts_dir)
115
+
116
+ impl_factory = cast("Callable[[Path, C], I]", self.impl_factory)
117
+ return impl_factory(artifacts_dir, self.cfg)
118
+
119
+ @overload
120
+ @classmethod
121
+ def load( # type: ignore
122
+ cls,
123
+ run_dir: str | Path,
124
+ impl_factory: Callable[[Path], I] | Callable[[Path, C], I] = lambda _: None, # type: ignore
125
+ ) -> Self: ...
126
+
127
+ @overload
128
+ @classmethod
129
+ def load(
130
+ cls,
131
+ run_dir: Iterable[str | Path],
132
+ impl_factory: Callable[[Path], I] | Callable[[Path, C], I] = lambda _: None, # type: ignore
133
+ *,
134
+ n_jobs: int = 0,
135
+ ) -> RunCollection[Self]: ...
136
+
137
+ @classmethod
138
+ def load(
139
+ cls,
140
+ run_dir: str | Path | Iterable[str | Path],
141
+ impl_factory: Callable[[Path], I] | Callable[[Path, C], I] = lambda _: None, # type: ignore
142
+ *,
143
+ n_jobs: int = 0,
144
+ ) -> Self | RunCollection[Self]:
145
+ """Load a Run from a run directory.
146
+
147
+ Args:
148
+ run_dir (str | Path | Iterable[str | Path]): The directory where the
149
+ MLflow runs are stored, either as a string, a Path instance,
150
+ or an iterable of them.
151
+ impl_factory (Callable[[Path], I] | Callable[[Path, C], I]): A factory
152
+ function that creates the implementation instance. It can accept
153
+ either just the artifacts directory path, or both the path and
154
+ the configuration instance. Defaults to a function that returns
155
+ None.
156
+ n_jobs (int): The number of parallel jobs. If 0 (default), runs
157
+ sequentially. If -1, uses all available CPU cores.
158
+
159
+ Returns:
160
+ Self | RunCollection[Self]: A single Run instance or a RunCollection
161
+ of Run instances.
162
+
163
+ """
164
+ if isinstance(run_dir, str | Path):
165
+ return cls(Path(run_dir), impl_factory)
166
+
167
+ from .run_collection import RunCollection
168
+
169
+ if n_jobs == 0:
170
+ return RunCollection(cls(Path(r), impl_factory) for r in run_dir)
171
+
172
+ from joblib import Parallel, delayed
173
+
174
+ parallel = Parallel(backend="threading", n_jobs=n_jobs)
175
+ runs = parallel(delayed(cls)(Path(r), impl_factory) for r in run_dir)
176
+ return RunCollection(runs) # type: ignore
177
+
178
+ @overload
179
+ def update(
180
+ self,
181
+ key: str,
182
+ value: Any | Callable[[Self], Any],
183
+ *,
184
+ force: bool = False,
185
+ ) -> None: ...
186
+
187
+ @overload
188
+ def update(
189
+ self,
190
+ key: tuple[str, ...],
191
+ value: Iterable[Any] | Callable[[Self], Iterable[Any]],
192
+ *,
193
+ force: bool = False,
194
+ ) -> None: ...
195
+
196
+ def update(
197
+ self,
198
+ key: str | tuple[str, ...],
199
+ value: Any | Callable[[Self], Any],
200
+ *,
201
+ force: bool = False,
202
+ ) -> None:
203
+ """Set default value(s) in the configuration if they don't already exist.
204
+
205
+ This method adds a value or multiple values to the configuration,
206
+ but only if the corresponding keys don't already have values.
207
+ Existing values will not be modified.
208
+
209
+ Args:
210
+ key: Either a string representing a single configuration path
211
+ (can use dot notation like "section.subsection.param"),
212
+ or a tuple of strings to set multiple related configuration
213
+ values at once.
214
+ value: The value to set. This can be:
215
+ - For string keys: Any value, or a callable that returns
216
+ a value
217
+ - For tuple keys: An iterable with the same length as the
218
+ key tuple, or a callable that returns such an iterable
219
+ - For callable values: The callable must accept a single argument
220
+ of type Run (self) and return the appropriate value type
221
+ force: Whether to force the update even if the key already exists.
222
+
223
+ Raises:
224
+ TypeError: If a tuple key is provided but the value is
225
+ not an iterable, or if the callable doesn't return
226
+ an iterable.
227
+
228
+ """
229
+ cfg: DictConfig = self.cfg # type: ignore
230
+
231
+ if isinstance(key, str):
232
+ if force or OmegaConf.select(cfg, key, default=MISSING) is MISSING:
233
+ v = value(self) if callable(value) else value # type: ignore
234
+ OmegaConf.update(cfg, key, v, force_add=True)
235
+ return
236
+
237
+ it = (OmegaConf.select(cfg, k, default=MISSING) is not MISSING for k in key)
238
+ if not force and all(it):
239
+ return
240
+
241
+ if callable(value):
242
+ value = value(self) # type: ignore
243
+
244
+ if not isinstance(value, Iterable) or isinstance(value, str):
245
+ msg = f"{value} is not an iterable"
246
+ raise TypeError(msg)
247
+
248
+ for k, v in zip(key, value, strict=True):
249
+ if force or OmegaConf.select(cfg, k, default=MISSING) is MISSING:
250
+ OmegaConf.update(cfg, k, v, force_add=True)
251
+
252
+ def get(self, key: str, default: Any = MISSING) -> Any:
253
+ """Get a value from the information or configuration.
254
+
255
+ Args:
256
+ key: The key to look for. Can use dot notation for
257
+ nested keys in configuration.
258
+ default: Value to return if the key is not found.
259
+ If not provided, AttributeError will be raised.
260
+
261
+ Returns:
262
+ Any: The value associated with the key, or the
263
+ default value if the key is not found and a default
264
+ is provided.
265
+
266
+ Raises:
267
+ AttributeError: If the key is not found and
268
+ no default is provided.
269
+
270
+ """
271
+ value = OmegaConf.select(self.cfg, key, default=MISSING) # type: ignore
272
+ if value is not MISSING:
273
+ return value
274
+
275
+ if self.impl and hasattr(self.impl, key):
276
+ return getattr(self.impl, key)
277
+
278
+ info = self.info.to_dict()
279
+ if key in info:
280
+ return info[key]
281
+
282
+ if default is not MISSING:
283
+ return default
284
+
285
+ msg = f"No such key: {key}"
286
+ raise AttributeError(msg)
287
+
288
+ def predicate(self, key: str, value: Any) -> bool:
289
+ """Check if a value satisfies a condition for filtering.
290
+
291
+ This method retrieves the attribute specified by the key
292
+ using the get method, and then compares it with the given
293
+ value according to the following rules:
294
+
295
+ - If value is callable: Call it with the attribute and return
296
+ the boolean result
297
+ - If value is a list or set: Check if the attribute is in the list/set
298
+ - If value is a tuple of length 2: Check if the attribute is
299
+ in the range [value[0], value[1]]. Both sides are inclusive
300
+ - Otherwise: Check if the attribute equals the value
301
+
302
+ Args:
303
+ key: The key to get the attribute from.
304
+ value: The value to compare with, or a callable that takes
305
+ the attribute and returns a boolean.
306
+
307
+ Returns:
308
+ bool: True if the attribute satisfies the condition, False otherwise.
309
+
310
+ """
311
+ attr = self.get(key)
312
+ return _predicate(attr, value)
313
+
314
+ def to_dict(self) -> dict[str, Any]:
315
+ """Convert the Run to a dictionary."""
316
+ info = self.info.to_dict()
317
+ cfg = OmegaConf.to_container(self.cfg)
318
+ return info | _flatten_dict(cfg) # type: ignore
319
+
320
+
321
+ def _predicate(attr: Any, value: Any) -> bool:
322
+ if callable(value):
323
+ return bool(value(attr))
324
+
325
+ if isinstance(value, ListConfig):
326
+ value = list(value)
327
+
328
+ if isinstance(value, list | set) and not _is_iterable(attr):
329
+ return attr in value
330
+
331
+ if isinstance(value, tuple) and len(value) == 2 and not _is_iterable(attr):
332
+ return value[0] <= attr <= value[1]
333
+
334
+ if _is_iterable(value):
335
+ value = list(value)
336
+
337
+ if _is_iterable(attr):
338
+ attr = list(attr)
339
+
340
+ return attr == value
341
+
342
+
343
+ def _is_iterable(value: Any) -> bool:
344
+ return isinstance(value, Iterable) and not isinstance(value, str)
345
+
346
+
347
+ def _flatten_dict(d: dict[str, Any], parent_key: str = "") -> dict[str, Any]:
348
+ items = []
349
+ for k, v in d.items():
350
+ key = f"{parent_key}.{k}" if parent_key else k
351
+ if isinstance(v, dict):
352
+ items.extend(_flatten_dict(v, key).items())
353
+ else:
354
+ items.append((key, v))
355
+ return dict(items)