hydraflow 0.16.2__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,205 @@
1
+ """GroupBy module for organizing and aggregating collections of items.
2
+
3
+ This module provides the GroupBy class, which represents the result of a
4
+ group_by operation on a Collection. It organizes items into groups based on
5
+ specified keys and enables aggregation operations across those groups.
6
+
7
+ The GroupBy class implements a dictionary-like interface, allowing access to
8
+ individual groups through key lookup, iteration, and standard dictionary
9
+ methods like keys(), values(), and items().
10
+
11
+ Example:
12
+ ```python
13
+ # Group runs by model type
14
+ grouped = runs.group_by("model.type")
15
+
16
+ # Access a specific group
17
+ transformer_runs = grouped["transformer"]
18
+
19
+ # Iterate through groups
20
+ for model_type, group in grouped.items():
21
+ print(f"Model: {model_type}, Runs: {len(group)}")
22
+
23
+ # Perform aggregations
24
+ stats = grouped.agg(
25
+ "accuracy",
26
+ "loss",
27
+ avg_time=lambda g: sum(r.get("runtime") for r in g) / len(g)
28
+ )
29
+ ```
30
+
31
+ The GroupBy class supports aggregation through the agg() method, which can
32
+ compute both predefined metrics from the grouped items and custom aggregations
33
+ specified as callables.
34
+
35
+ """
36
+
37
+ from __future__ import annotations
38
+
39
+ from dataclasses import MISSING
40
+ from typing import TYPE_CHECKING, Any
41
+
42
+ from polars import DataFrame, Series
43
+
44
+ if TYPE_CHECKING:
45
+ from collections.abc import (
46
+ Callable,
47
+ ItemsView,
48
+ Iterator,
49
+ KeysView,
50
+ Sequence,
51
+ ValuesView,
52
+ )
53
+
54
+ from .collection import Collection
55
+
56
+
57
+ class GroupBy[C: Collection[Any], I]:
58
+ """Represents the result of a group_by operation on a Collection.
59
+
60
+ The GroupBy class organizes items from a Collection into groups based on
61
+ specified keys. It provides a dictionary-like interface for accessing the
62
+ groups and methods for aggregating data across the groups.
63
+
64
+ Attributes:
65
+ by: The keys used for grouping.
66
+ groups: A dictionary mapping group keys to Collection instances.
67
+
68
+ """
69
+
70
+ by: tuple[str, ...]
71
+ groups: dict[Any, C]
72
+
73
+ def __init__(self, by: tuple[str, ...], groups: dict[Any, C]) -> None:
74
+ """Initialize a GroupBy instance.
75
+
76
+ Args:
77
+ by: The keys used for grouping.
78
+ groups: A dictionary mapping group keys to Collection instances.
79
+
80
+ """
81
+ self.by = by
82
+ self.groups = groups
83
+
84
+ def __getitem__(self, key: Any) -> C:
85
+ """Get a group by its key.
86
+
87
+ Args:
88
+ key: The group key to look up.
89
+
90
+ Returns:
91
+ The Collection corresponding to the key.
92
+
93
+ Raises:
94
+ KeyError: If the key is not found in the groups.
95
+
96
+ """
97
+ return self.groups[key]
98
+
99
+ def __iter__(self) -> Iterator[Any]:
100
+ """Iterate over group keys.
101
+
102
+ Returns:
103
+ An iterator over the group keys.
104
+
105
+ """
106
+ return iter(self.groups)
107
+
108
+ def __len__(self) -> int:
109
+ """Get the number of groups.
110
+
111
+ Returns:
112
+ The number of groups.
113
+
114
+ """
115
+ return len(self.groups)
116
+
117
+ def __contains__(self, key: Any) -> bool:
118
+ """Check if a key is in the groups.
119
+
120
+ Args:
121
+ key: The key to check for.
122
+
123
+ Returns:
124
+ True if the key is in the groups, False otherwise.
125
+
126
+ """
127
+ return key in self.groups
128
+
129
+ def keys(self) -> KeysView[Any]:
130
+ """Get the keys of the groups.
131
+
132
+ Returns:
133
+ A view of the group keys.
134
+
135
+ """
136
+ return self.groups.keys()
137
+
138
+ def values(self) -> ValuesView[C]:
139
+ """Get the values (Collections) of the groups.
140
+
141
+ Returns:
142
+ A view of the group values.
143
+
144
+ """
145
+ return self.groups.values()
146
+
147
+ def items(self) -> ItemsView[Any, C]:
148
+ """Get the (key, value) pairs of the groups.
149
+
150
+ Returns:
151
+ A view of the (key, value) pairs.
152
+
153
+ """
154
+ return self.groups.items()
155
+
156
+ def agg(
157
+ self,
158
+ *aggs: str,
159
+ **named_aggs: Callable[[C | Sequence[I]], Any],
160
+ ) -> DataFrame:
161
+ """Aggregate data across groups.
162
+
163
+ This method computes aggregations for each group and returns the results
164
+ as a DataFrame. There are two ways to specify aggregations:
165
+
166
+ 1. String keys: These are interpreted as attributes to extract from each
167
+ item in the group.
168
+ 2. Callables: Functions that take a Collection or Sequence of items and
169
+ return an aggregated value.
170
+
171
+ Args:
172
+ *aggs: String keys to aggregate.
173
+ **named_aggs: Named aggregation functions.
174
+
175
+ Returns:
176
+ A DataFrame with group keys and aggregated values.
177
+
178
+ Example:
179
+ ```python
180
+ # Aggregate by accuracy and loss, and compute average runtime
181
+ stats = grouped.agg(
182
+ "accuracy",
183
+ "loss",
184
+ avg_runtime=lambda g: sum(r.get("runtime") for r in g) / len(g)
185
+ )
186
+ ```
187
+
188
+ """
189
+ gp = self.groups
190
+
191
+ if len(self.by) == 1:
192
+ df = DataFrame({self.by[0]: list(gp)})
193
+ else:
194
+ df = DataFrame(dict(zip(self.by, k, strict=True)) for k in gp)
195
+
196
+ columns = []
197
+
198
+ for agg in aggs:
199
+ values = [[c._get(i, agg, MISSING) for i in c] for c in gp.values()] # noqa: SLF001
200
+ columns.append(Series(agg, values))
201
+
202
+ for k, v in named_aggs.items():
203
+ columns.append(Series(k, [v(r) for r in gp.values()]))
204
+
205
+ return df.with_columns(columns)
hydraflow/core/run.py CHANGED
@@ -29,7 +29,7 @@ from functools import cached_property
29
29
  from pathlib import Path
30
30
  from typing import TYPE_CHECKING, cast, overload
31
31
 
32
- from omegaconf import DictConfig, ListConfig, OmegaConf
32
+ from omegaconf import DictConfig, OmegaConf
33
33
 
34
34
  from .run_info import RunInfo
35
35
 
@@ -54,6 +54,7 @@ class Run[C, I = None]:
54
54
  """Factory function to create the implementation instance.
55
55
 
56
56
  This can be a callable that accepts either:
57
+
57
58
  - A single Path parameter: the artifacts directory
58
59
  - Both a Path and a config parameter: the artifacts directory and
59
60
  the configuration instance
@@ -65,10 +66,10 @@ class Run[C, I = None]:
65
66
  def __init__(
66
67
  self,
67
68
  run_dir: Path,
68
- impl_factory: Callable[[Path], I] | Callable[[Path, C], I] = lambda _: None,
69
+ impl_factory: Callable[[Path], I] | Callable[[Path, C], I] | None = None,
69
70
  ) -> None:
70
71
  self.info = RunInfo(run_dir)
71
- self.impl_factory = impl_factory
72
+ self.impl_factory = impl_factory or (lambda _: None) # type: ignore
72
73
 
73
74
  def __repr__(self) -> str:
74
75
  """Return a string representation of the Run."""
@@ -132,7 +133,7 @@ class Run[C, I = None]:
132
133
  impl_factory: Callable[[Path], I] | Callable[[Path, C], I] = lambda _: None, # type: ignore
133
134
  *,
134
135
  n_jobs: int = 0,
135
- ) -> RunCollection[Self]: ...
136
+ ) -> RunCollection[Self, I]: ...
136
137
 
137
138
  @classmethod
138
139
  def load(
@@ -141,7 +142,7 @@ class Run[C, I = None]:
141
142
  impl_factory: Callable[[Path], I] | Callable[[Path, C], I] = lambda _: None, # type: ignore
142
143
  *,
143
144
  n_jobs: int = 0,
144
- ) -> Self | RunCollection[Self]:
145
+ ) -> Self | RunCollection[Self, I]:
145
146
  """Load a Run from a run directory.
146
147
 
147
148
  Args:
@@ -167,13 +168,14 @@ class Run[C, I = None]:
167
168
  from .run_collection import RunCollection
168
169
 
169
170
  if n_jobs == 0:
170
- return RunCollection(cls(Path(r), impl_factory) for r in run_dir)
171
+ runs = (cls(Path(r), impl_factory) for r in run_dir)
172
+ return RunCollection(runs, cls.get) # type: ignore
171
173
 
172
174
  from joblib import Parallel, delayed
173
175
 
174
176
  parallel = Parallel(backend="threading", n_jobs=n_jobs)
175
177
  runs = parallel(delayed(cls)(Path(r), impl_factory) for r in run_dir)
176
- return RunCollection(runs) # type: ignore
178
+ return RunCollection(runs, cls.get) # type: ignore
177
179
 
178
180
  @overload
179
181
  def update(
@@ -211,7 +213,9 @@ class Run[C, I = None]:
211
213
  (can use dot notation like "section.subsection.param"),
212
214
  or a tuple of strings to set multiple related configuration
213
215
  values at once.
214
- value: The value to set. This can be:
216
+ value: The value to set.
217
+ This can be:
218
+
215
219
  - For string keys: Any value, or a callable that returns
216
220
  a value
217
221
  - For tuple keys: An iterable with the same length as the
@@ -258,6 +262,12 @@ class Run[C, I = None]:
258
262
  Args:
259
263
  key: The key to look for. Can use dot notation for
260
264
  nested keys in configuration.
265
+ Special keys:
266
+
267
+ - "cfg": Returns the configuration object
268
+ - "impl": Returns the implementation object
269
+ - "info": Returns the run information object
270
+
261
271
  default: Value to return if the key is not found.
262
272
  If a callable, it will be called with the Run instance
263
273
  and the value returned will be used as the default.
@@ -272,6 +282,13 @@ class Run[C, I = None]:
272
282
  AttributeError: If the key is not found and
273
283
  no default is provided.
274
284
 
285
+ Note:
286
+ The search order for keys is:
287
+ 1. Configuration (cfg)
288
+ 2. Implementation (impl)
289
+ 3. Run information (info)
290
+ 4. Run object itself (self)
291
+
275
292
  """
276
293
  key = key.replace("__", ".")
277
294
 
@@ -279,12 +296,10 @@ class Run[C, I = None]:
279
296
  if value is not MISSING:
280
297
  return value
281
298
 
282
- if self.impl and hasattr(self.impl, key):
283
- return getattr(self.impl, key)
284
-
285
- info = self.info.to_dict()
286
- if key in info:
287
- return info[key]
299
+ for attr in [self.impl, self.info, self]:
300
+ value = getattr(attr, key, MISSING)
301
+ if value is not MISSING:
302
+ return value
288
303
 
289
304
  if default is not MISSING:
290
305
  if callable(default):
@@ -295,71 +310,37 @@ class Run[C, I = None]:
295
310
  msg = f"No such key: {key}"
296
311
  raise AttributeError(msg)
297
312
 
298
- def predicate(self, key: str, value: Any) -> bool:
299
- """Check if a value satisfies a condition for filtering.
300
-
301
- This method retrieves the attribute specified by the key
302
- using the get method, and then compares it with the given
303
- value according to the following rules:
304
-
305
- - If value is callable: Call it with the attribute and return
306
- the boolean result
307
- - If value is a list or set: Check if the attribute is in the list/set
308
- - If value is a tuple of length 2: Check if the attribute is
309
- in the range [value[0], value[1]]. Both sides are inclusive
310
- - Otherwise: Check if the attribute equals the value
313
+ def to_dict(self, flatten: bool = True) -> dict[str, Any]:
314
+ """Convert the Run to a dictionary.
311
315
 
312
316
  Args:
313
- key: The key to get the attribute from.
314
- value: The value to compare with, or a callable that takes
315
- the attribute and returns a boolean.
317
+ flatten (bool, optional): If True, flattens nested dictionaries.
318
+ Defaults to True.
316
319
 
317
320
  Returns:
318
- bool: True if the attribute satisfies the condition, False otherwise.
321
+ dict[str, Any]: A dictionary representation of the Run's configuration.
319
322
 
320
323
  """
321
- attr = self.get(key)
322
- return _predicate(attr, value)
323
-
324
- def to_dict(self) -> dict[str, Any]:
325
- """Convert the Run to a dictionary."""
326
- info = self.info.to_dict()
327
324
  cfg = OmegaConf.to_container(self.cfg)
328
- return info | _flatten_dict(cfg) # type: ignore
325
+ if not isinstance(cfg, dict):
326
+ raise TypeError("Configuration must be a dictionary")
329
327
 
328
+ standard_dict: dict[str, Any] = {str(k): v for k, v in cfg.items()}
330
329
 
331
- def _predicate(attr: Any, value: Any) -> bool:
332
- if callable(value):
333
- return bool(value(attr))
330
+ if flatten:
331
+ return _flatten_dict(standard_dict)
334
332
 
335
- if isinstance(value, ListConfig):
336
- value = list(value)
337
-
338
- if isinstance(value, list | set) and not _is_iterable(attr):
339
- return attr in value
340
-
341
- if isinstance(value, tuple) and len(value) == 2 and not _is_iterable(attr):
342
- return value[0] <= attr <= value[1]
343
-
344
- if _is_iterable(value):
345
- value = list(value)
346
-
347
- if _is_iterable(attr):
348
- attr = list(attr)
349
-
350
- return attr == value
351
-
352
-
353
- def _is_iterable(value: Any) -> bool:
354
- return isinstance(value, Iterable) and not isinstance(value, str)
333
+ return standard_dict
355
334
 
356
335
 
357
336
  def _flatten_dict(d: dict[str, Any], parent_key: str = "") -> dict[str, Any]:
358
337
  items = []
338
+
359
339
  for k, v in d.items():
360
340
  key = f"{parent_key}.{k}" if parent_key else k
361
341
  if isinstance(v, dict):
362
342
  items.extend(_flatten_dict(v, key).items())
363
343
  else:
364
344
  items.append((key, v))
345
+
365
346
  return dict(items)