hydraflow 0.16.2__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydraflow/core/collection.py +541 -0
- hydraflow/core/group_by.py +205 -0
- hydraflow/core/run.py +42 -61
- hydraflow/core/run_collection.py +37 -494
- hydraflow/core/run_info.py +0 -9
- {hydraflow-0.16.2.dist-info → hydraflow-0.17.0.dist-info}/METADATA +1 -1
- {hydraflow-0.16.2.dist-info → hydraflow-0.17.0.dist-info}/RECORD +10 -8
- {hydraflow-0.16.2.dist-info → hydraflow-0.17.0.dist-info}/WHEEL +0 -0
- {hydraflow-0.16.2.dist-info → hydraflow-0.17.0.dist-info}/entry_points.txt +0 -0
- {hydraflow-0.16.2.dist-info → hydraflow-0.17.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,205 @@
|
|
1
|
+
"""GroupBy module for organizing and aggregating collections of items.
|
2
|
+
|
3
|
+
This module provides the GroupBy class, which represents the result of a
|
4
|
+
group_by operation on a Collection. It organizes items into groups based on
|
5
|
+
specified keys and enables aggregation operations across those groups.
|
6
|
+
|
7
|
+
The GroupBy class implements a dictionary-like interface, allowing access to
|
8
|
+
individual groups through key lookup, iteration, and standard dictionary
|
9
|
+
methods like keys(), values(), and items().
|
10
|
+
|
11
|
+
Example:
|
12
|
+
```python
|
13
|
+
# Group runs by model type
|
14
|
+
grouped = runs.group_by("model.type")
|
15
|
+
|
16
|
+
# Access a specific group
|
17
|
+
transformer_runs = grouped["transformer"]
|
18
|
+
|
19
|
+
# Iterate through groups
|
20
|
+
for model_type, group in grouped.items():
|
21
|
+
print(f"Model: {model_type}, Runs: {len(group)}")
|
22
|
+
|
23
|
+
# Perform aggregations
|
24
|
+
stats = grouped.agg(
|
25
|
+
"accuracy",
|
26
|
+
"loss",
|
27
|
+
avg_time=lambda g: sum(r.get("runtime") for r in g) / len(g)
|
28
|
+
)
|
29
|
+
```
|
30
|
+
|
31
|
+
The GroupBy class supports aggregation through the agg() method, which can
|
32
|
+
collect per-item values for named keys within each group and compute custom aggregations
|
33
|
+
specified as callables.
|
34
|
+
|
35
|
+
"""
|
36
|
+
|
37
|
+
from __future__ import annotations
|
38
|
+
|
39
|
+
from dataclasses import MISSING
|
40
|
+
from typing import TYPE_CHECKING, Any
|
41
|
+
|
42
|
+
from polars import DataFrame, Series
|
43
|
+
|
44
|
+
if TYPE_CHECKING:
|
45
|
+
from collections.abc import (
|
46
|
+
Callable,
|
47
|
+
ItemsView,
|
48
|
+
Iterator,
|
49
|
+
KeysView,
|
50
|
+
Sequence,
|
51
|
+
ValuesView,
|
52
|
+
)
|
53
|
+
|
54
|
+
from .collection import Collection
|
55
|
+
|
56
|
+
|
57
|
+
class GroupBy[C: Collection[Any], I]:
|
58
|
+
"""Represents the result of a group_by operation on a Collection.
|
59
|
+
|
60
|
+
The GroupBy class organizes items from a Collection into groups based on
|
61
|
+
specified keys. It provides a dictionary-like interface for accessing the
|
62
|
+
groups and methods for aggregating data across the groups.
|
63
|
+
|
64
|
+
Attributes:
|
65
|
+
by: The keys used for grouping.
|
66
|
+
groups: A dictionary mapping group keys to Collection instances.
|
67
|
+
|
68
|
+
"""
|
69
|
+
|
70
|
+
by: tuple[str, ...]
|
71
|
+
groups: dict[Any, C]
|
72
|
+
|
73
|
+
def __init__(self, by: tuple[str, ...], groups: dict[Any, C]) -> None:
|
74
|
+
"""Initialize a GroupBy instance.
|
75
|
+
|
76
|
+
Args:
|
77
|
+
by: The keys used for grouping.
|
78
|
+
groups: A dictionary mapping group keys to Collection instances.
|
79
|
+
|
80
|
+
"""
|
81
|
+
self.by = by
|
82
|
+
self.groups = groups
|
83
|
+
|
84
|
+
def __getitem__(self, key: Any) -> C:
|
85
|
+
"""Get a group by its key.
|
86
|
+
|
87
|
+
Args:
|
88
|
+
key: The group key to look up.
|
89
|
+
|
90
|
+
Returns:
|
91
|
+
The Collection corresponding to the key.
|
92
|
+
|
93
|
+
Raises:
|
94
|
+
KeyError: If the key is not found in the groups.
|
95
|
+
|
96
|
+
"""
|
97
|
+
return self.groups[key]
|
98
|
+
|
99
|
+
def __iter__(self) -> Iterator[Any]:
|
100
|
+
"""Iterate over group keys.
|
101
|
+
|
102
|
+
Returns:
|
103
|
+
An iterator over the group keys.
|
104
|
+
|
105
|
+
"""
|
106
|
+
return iter(self.groups)
|
107
|
+
|
108
|
+
def __len__(self) -> int:
|
109
|
+
"""Get the number of groups.
|
110
|
+
|
111
|
+
Returns:
|
112
|
+
The number of groups.
|
113
|
+
|
114
|
+
"""
|
115
|
+
return len(self.groups)
|
116
|
+
|
117
|
+
def __contains__(self, key: Any) -> bool:
|
118
|
+
"""Check if a key is in the groups.
|
119
|
+
|
120
|
+
Args:
|
121
|
+
key: The key to check for.
|
122
|
+
|
123
|
+
Returns:
|
124
|
+
True if the key is in the groups, False otherwise.
|
125
|
+
|
126
|
+
"""
|
127
|
+
return key in self.groups
|
128
|
+
|
129
|
+
def keys(self) -> KeysView[Any]:
|
130
|
+
"""Get the keys of the groups.
|
131
|
+
|
132
|
+
Returns:
|
133
|
+
A view of the group keys.
|
134
|
+
|
135
|
+
"""
|
136
|
+
return self.groups.keys()
|
137
|
+
|
138
|
+
def values(self) -> ValuesView[C]:
|
139
|
+
"""Get the values (Collections) of the groups.
|
140
|
+
|
141
|
+
Returns:
|
142
|
+
A view of the group values.
|
143
|
+
|
144
|
+
"""
|
145
|
+
return self.groups.values()
|
146
|
+
|
147
|
+
def items(self) -> ItemsView[Any, C]:
|
148
|
+
"""Get the (key, value) pairs of the groups.
|
149
|
+
|
150
|
+
Returns:
|
151
|
+
A view of the (key, value) pairs.
|
152
|
+
|
153
|
+
"""
|
154
|
+
return self.groups.items()
|
155
|
+
|
156
|
+
def agg(
|
157
|
+
self,
|
158
|
+
*aggs: str,
|
159
|
+
**named_aggs: Callable[[C | Sequence[I]], Any],
|
160
|
+
) -> DataFrame:
|
161
|
+
"""Aggregate data across groups.
|
162
|
+
|
163
|
+
This method computes aggregations for each group and returns the results
|
164
|
+
as a DataFrame. There are two ways to specify aggregations:
|
165
|
+
|
166
|
+
1. String keys: These are interpreted as attributes to extract from each
|
167
|
+
item in the group.
|
168
|
+
2. Callables: Functions that take a Collection or Sequence of items and
|
169
|
+
return an aggregated value.
|
170
|
+
|
171
|
+
Args:
|
172
|
+
*aggs: String keys to aggregate.
|
173
|
+
**named_aggs: Named aggregation functions.
|
174
|
+
|
175
|
+
Returns:
|
176
|
+
A DataFrame with group keys and aggregated values.
|
177
|
+
|
178
|
+
Example:
|
179
|
+
```python
|
180
|
+
# Aggregate by accuracy and loss, and compute average runtime
|
181
|
+
stats = grouped.agg(
|
182
|
+
"accuracy",
|
183
|
+
"loss",
|
184
|
+
avg_runtime=lambda g: sum(r.get("runtime") for r in g) / len(g)
|
185
|
+
)
|
186
|
+
```
|
187
|
+
|
188
|
+
"""
|
189
|
+
gp = self.groups
|
190
|
+
|
191
|
+
if len(self.by) == 1:
|
192
|
+
df = DataFrame({self.by[0]: list(gp)})
|
193
|
+
else:
|
194
|
+
df = DataFrame(dict(zip(self.by, k, strict=True)) for k in gp)
|
195
|
+
|
196
|
+
columns = []
|
197
|
+
|
198
|
+
for agg in aggs:
|
199
|
+
values = [[c._get(i, agg, MISSING) for i in c] for c in gp.values()] # noqa: SLF001
|
200
|
+
columns.append(Series(agg, values))
|
201
|
+
|
202
|
+
for k, v in named_aggs.items():
|
203
|
+
columns.append(Series(k, [v(r) for r in gp.values()]))
|
204
|
+
|
205
|
+
return df.with_columns(columns)
|
hydraflow/core/run.py
CHANGED
@@ -29,7 +29,7 @@ from functools import cached_property
|
|
29
29
|
from pathlib import Path
|
30
30
|
from typing import TYPE_CHECKING, cast, overload
|
31
31
|
|
32
|
-
from omegaconf import DictConfig,
|
32
|
+
from omegaconf import DictConfig, OmegaConf
|
33
33
|
|
34
34
|
from .run_info import RunInfo
|
35
35
|
|
@@ -54,6 +54,7 @@ class Run[C, I = None]:
|
|
54
54
|
"""Factory function to create the implementation instance.
|
55
55
|
|
56
56
|
This can be a callable that accepts either:
|
57
|
+
|
57
58
|
- A single Path parameter: the artifacts directory
|
58
59
|
- Both a Path and a config parameter: the artifacts directory and
|
59
60
|
the configuration instance
|
@@ -65,10 +66,10 @@ class Run[C, I = None]:
|
|
65
66
|
def __init__(
|
66
67
|
self,
|
67
68
|
run_dir: Path,
|
68
|
-
impl_factory: Callable[[Path], I] | Callable[[Path, C], I]
|
69
|
+
impl_factory: Callable[[Path], I] | Callable[[Path, C], I] | None = None,
|
69
70
|
) -> None:
|
70
71
|
self.info = RunInfo(run_dir)
|
71
|
-
self.impl_factory = impl_factory
|
72
|
+
self.impl_factory = impl_factory or (lambda _: None) # type: ignore
|
72
73
|
|
73
74
|
def __repr__(self) -> str:
|
74
75
|
"""Return a string representation of the Run."""
|
@@ -132,7 +133,7 @@ class Run[C, I = None]:
|
|
132
133
|
impl_factory: Callable[[Path], I] | Callable[[Path, C], I] = lambda _: None, # type: ignore
|
133
134
|
*,
|
134
135
|
n_jobs: int = 0,
|
135
|
-
) -> RunCollection[Self]: ...
|
136
|
+
) -> RunCollection[Self, I]: ...
|
136
137
|
|
137
138
|
@classmethod
|
138
139
|
def load(
|
@@ -141,7 +142,7 @@ class Run[C, I = None]:
|
|
141
142
|
impl_factory: Callable[[Path], I] | Callable[[Path, C], I] = lambda _: None, # type: ignore
|
142
143
|
*,
|
143
144
|
n_jobs: int = 0,
|
144
|
-
) -> Self | RunCollection[Self]:
|
145
|
+
) -> Self | RunCollection[Self, I]:
|
145
146
|
"""Load a Run from a run directory.
|
146
147
|
|
147
148
|
Args:
|
@@ -167,13 +168,14 @@ class Run[C, I = None]:
|
|
167
168
|
from .run_collection import RunCollection
|
168
169
|
|
169
170
|
if n_jobs == 0:
|
170
|
-
|
171
|
+
runs = (cls(Path(r), impl_factory) for r in run_dir)
|
172
|
+
return RunCollection(runs, cls.get) # type: ignore
|
171
173
|
|
172
174
|
from joblib import Parallel, delayed
|
173
175
|
|
174
176
|
parallel = Parallel(backend="threading", n_jobs=n_jobs)
|
175
177
|
runs = parallel(delayed(cls)(Path(r), impl_factory) for r in run_dir)
|
176
|
-
return RunCollection(runs) # type: ignore
|
178
|
+
return RunCollection(runs, cls.get) # type: ignore
|
177
179
|
|
178
180
|
@overload
|
179
181
|
def update(
|
@@ -211,7 +213,9 @@ class Run[C, I = None]:
|
|
211
213
|
(can use dot notation like "section.subsection.param"),
|
212
214
|
or a tuple of strings to set multiple related configuration
|
213
215
|
values at once.
|
214
|
-
value: The value to set.
|
216
|
+
value: The value to set.
|
217
|
+
This can be:
|
218
|
+
|
215
219
|
- For string keys: Any value, or a callable that returns
|
216
220
|
a value
|
217
221
|
- For tuple keys: An iterable with the same length as the
|
@@ -258,6 +262,12 @@ class Run[C, I = None]:
|
|
258
262
|
Args:
|
259
263
|
key: The key to look for. Can use dot notation for
|
260
264
|
nested keys in configuration.
|
265
|
+
Special keys:
|
266
|
+
|
267
|
+
- "cfg": Returns the configuration object
|
268
|
+
- "impl": Returns the implementation object
|
269
|
+
- "info": Returns the run information object
|
270
|
+
|
261
271
|
default: Value to return if the key is not found.
|
262
272
|
If a callable, it will be called with the Run instance
|
263
273
|
and the value returned will be used as the default.
|
@@ -272,6 +282,13 @@ class Run[C, I = None]:
|
|
272
282
|
AttributeError: If the key is not found and
|
273
283
|
no default is provided.
|
274
284
|
|
285
|
+
Note:
|
286
|
+
The search order for keys is:
|
287
|
+
1. Configuration (cfg)
|
288
|
+
2. Implementation (impl)
|
289
|
+
3. Run information (info)
|
290
|
+
4. Run object itself (self)
|
291
|
+
|
275
292
|
"""
|
276
293
|
key = key.replace("__", ".")
|
277
294
|
|
@@ -279,12 +296,10 @@ class Run[C, I = None]:
|
|
279
296
|
if value is not MISSING:
|
280
297
|
return value
|
281
298
|
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
if key in info:
|
287
|
-
return info[key]
|
299
|
+
for attr in [self.impl, self.info, self]:
|
300
|
+
value = getattr(attr, key, MISSING)
|
301
|
+
if value is not MISSING:
|
302
|
+
return value
|
288
303
|
|
289
304
|
if default is not MISSING:
|
290
305
|
if callable(default):
|
@@ -295,71 +310,37 @@ class Run[C, I = None]:
|
|
295
310
|
msg = f"No such key: {key}"
|
296
311
|
raise AttributeError(msg)
|
297
312
|
|
298
|
-
def
|
299
|
-
"""
|
300
|
-
|
301
|
-
This method retrieves the attribute specified by the key
|
302
|
-
using the get method, and then compares it with the given
|
303
|
-
value according to the following rules:
|
304
|
-
|
305
|
-
- If value is callable: Call it with the attribute and return
|
306
|
-
the boolean result
|
307
|
-
- If value is a list or set: Check if the attribute is in the list/set
|
308
|
-
- If value is a tuple of length 2: Check if the attribute is
|
309
|
-
in the range [value[0], value[1]]. Both sides are inclusive
|
310
|
-
- Otherwise: Check if the attribute equals the value
|
313
|
+
def to_dict(self, flatten: bool = True) -> dict[str, Any]:
|
314
|
+
"""Convert the Run to a dictionary.
|
311
315
|
|
312
316
|
Args:
|
313
|
-
|
314
|
-
|
315
|
-
the attribute and returns a boolean.
|
317
|
+
flatten (bool, optional): If True, flattens nested dictionaries.
|
318
|
+
Defaults to True.
|
316
319
|
|
317
320
|
Returns:
|
318
|
-
|
321
|
+
dict[str, Any]: A dictionary representation of the Run's configuration.
|
319
322
|
|
320
323
|
"""
|
321
|
-
attr = self.get(key)
|
322
|
-
return _predicate(attr, value)
|
323
|
-
|
324
|
-
def to_dict(self) -> dict[str, Any]:
|
325
|
-
"""Convert the Run to a dictionary."""
|
326
|
-
info = self.info.to_dict()
|
327
324
|
cfg = OmegaConf.to_container(self.cfg)
|
328
|
-
|
325
|
+
if not isinstance(cfg, dict):
|
326
|
+
raise TypeError("Configuration must be a dictionary")
|
329
327
|
|
328
|
+
standard_dict: dict[str, Any] = {str(k): v for k, v in cfg.items()}
|
330
329
|
|
331
|
-
|
332
|
-
|
333
|
-
return bool(value(attr))
|
330
|
+
if flatten:
|
331
|
+
return _flatten_dict(standard_dict)
|
334
332
|
|
335
|
-
|
336
|
-
value = list(value)
|
337
|
-
|
338
|
-
if isinstance(value, list | set) and not _is_iterable(attr):
|
339
|
-
return attr in value
|
340
|
-
|
341
|
-
if isinstance(value, tuple) and len(value) == 2 and not _is_iterable(attr):
|
342
|
-
return value[0] <= attr <= value[1]
|
343
|
-
|
344
|
-
if _is_iterable(value):
|
345
|
-
value = list(value)
|
346
|
-
|
347
|
-
if _is_iterable(attr):
|
348
|
-
attr = list(attr)
|
349
|
-
|
350
|
-
return attr == value
|
351
|
-
|
352
|
-
|
353
|
-
def _is_iterable(value: Any) -> bool:
|
354
|
-
return isinstance(value, Iterable) and not isinstance(value, str)
|
333
|
+
return standard_dict
|
355
334
|
|
356
335
|
|
357
336
|
def _flatten_dict(d: dict[str, Any], parent_key: str = "") -> dict[str, Any]:
|
358
337
|
items = []
|
338
|
+
|
359
339
|
for k, v in d.items():
|
360
340
|
key = f"{parent_key}.{k}" if parent_key else k
|
361
341
|
if isinstance(v, dict):
|
362
342
|
items.extend(_flatten_dict(v, key).items())
|
363
343
|
else:
|
364
344
|
items.append((key, v))
|
345
|
+
|
365
346
|
return dict(items)
|