hydraflow 0.16.2__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,541 @@
1
+ """Provide a collection of items that implements the Sequence protocol."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Hashable, Iterable, Sequence
6
+ from dataclasses import MISSING
7
+ from typing import TYPE_CHECKING, overload
8
+
9
+ import numpy as np
10
+ from omegaconf import ListConfig, OmegaConf
11
+ from polars import DataFrame, Series
12
+
13
+ from .group_by import GroupBy
14
+
15
+ if TYPE_CHECKING:
16
+ from collections.abc import Callable, Iterator
17
+ from typing import Any, Self
18
+
19
+ from numpy.typing import NDArray
20
+
21
+
22
+ class Collection[I](Sequence[I]):
23
+ """A collection of items that implements the Sequence protocol."""
24
+
25
+ _items: list[I]
26
+ _get: Callable[[I, str, Any | Callable[[I], Any]], Any]
27
+
28
+ def __init__(
29
+ self,
30
+ items: Iterable[I],
31
+ get: Callable[[I, str, Any | Callable[[I], Any]], Any] | None = None,
32
+ ) -> None:
33
+ self._items = list(items)
34
+ self._get = get or getattr
35
+
36
+ def __repr__(self) -> str:
37
+ class_name = self.__class__.__name__
38
+ if not self:
39
+ return f"{class_name}(empty)"
40
+
41
+ type_name = repr(self[0])
42
+ if "(" in type_name:
43
+ type_name = type_name.split("(", 1)[0]
44
+ return f"{class_name}({type_name}, n={len(self)})"
45
+
46
+ def __len__(self) -> int:
47
+ return len(self._items)
48
+
49
+ def __bool__(self) -> bool:
50
+ return bool(self._items)
51
+
52
+ @overload
53
+ def __getitem__(self, index: int) -> I: ...
54
+
55
+ @overload
56
+ def __getitem__(self, index: slice) -> Self: ...
57
+
58
+ @overload
59
+ def __getitem__(self, index: Iterable[int]) -> Self: ...
60
+
61
+ def __getitem__(self, index: int | slice | Iterable[int]) -> I | Self:
62
+ if isinstance(index, int):
63
+ return self._items[index]
64
+
65
+ if isinstance(index, slice):
66
+ return self.__class__(self._items[index], self._get)
67
+
68
+ return self.__class__([self._items[i] for i in index], self._get)
69
+
70
+ def __iter__(self) -> Iterator[I]:
71
+ return iter(self._items)
72
+
73
+ def filter(
74
+ self,
75
+ *criteria: Callable[[I], bool] | tuple[str, Any],
76
+ **kwargs: Any,
77
+ ) -> Self:
78
+ """Filter items based on criteria.
79
+
80
+ This method allows filtering items using various criteria:
81
+
82
+ - Callable criteria that take an item and return a boolean
83
+ - Key-value tuples where the key is a string and the value
84
+ is compared using the `matches` function
85
+ - Keyword arguments, where the key is a string and the value
86
+ is compared using the `matches` function
87
+
88
+ The `matches` function supports the following comparison types:
89
+
90
+ - Callable: The predicate function is called with the value
91
+ - List/Set: Checks if the value is in the list/set
92
+ - Tuple of length 2: Checks if the value is in the range [min, max]
93
+ - Other: Checks for direct equality
94
+
95
+ Args:
96
+ *criteria: Callable criteria or (key, value) tuples
97
+ for filtering.
98
+ **kwargs: Additional key-value pairs for filtering.
99
+
100
+ Returns:
101
+ Self: A new Collection containing only the items that
102
+ match all criteria.
103
+
104
+ Examples:
105
+ ```python
106
+ # Filter using a callable
107
+ filtered = collection.filter(lambda x: x > 5)
108
+
109
+ # Filter using a key-value tuple
110
+ filtered = collection.filter(("age", 25))
111
+
112
+ # Filter using keyword arguments
113
+ filtered = collection.filter(age=25, name="John")
114
+
115
+ # Filter using range
116
+ filtered = collection.filter(("age", (20, 30)))
117
+
118
+ # Filter using list membership
119
+ filtered = collection.filter(("name", ["John", "Jane"]))
120
+ ```
121
+
122
+ """
123
+ items = self._items
124
+
125
+ for c in criteria:
126
+ if callable(c):
127
+ items = [i for i in items if c(i)]
128
+ else:
129
+ items = [i for i in items if matches(self._get(i, c[0], MISSING), c[1])]
130
+
131
+ for key, value in kwargs.items():
132
+ items = [i for i in items if matches(self._get(i, key, MISSING), value)]
133
+
134
+ return self.__class__(items, self._get)
135
+
136
+ def try_get(
137
+ self,
138
+ *criteria: Callable[[I], bool] | tuple[str, Any],
139
+ **kwargs: Any,
140
+ ) -> I | None:
141
+ """Try to get a single item matching the specified criteria.
142
+
143
+ This method applies filters and returns a single matching
144
+ item if exactly one is found, None if no items are found,
145
+ or raises ValueError if multiple items match.
146
+
147
+ Args:
148
+ *criteria: Callable criteria or (key, value) tuples
149
+ for filtering.
150
+ **kwargs: Additional key-value pairs for filtering.
151
+
152
+ Returns:
153
+ I | None: A single item that matches the criteria, or None if
154
+ no matches are found.
155
+
156
+ Raises:
157
+ ValueError: If multiple items match the criteria.
158
+
159
+ """
160
+ items = self.filter(*criteria, **kwargs)
161
+
162
+ n = len(items)
163
+ if n == 0:
164
+ return None
165
+
166
+ if n == 1:
167
+ return items[0]
168
+
169
+ msg = f"Multiple items ({n}) found matching the criteria, "
170
+ msg += "expected exactly one"
171
+ raise ValueError(msg)
172
+
173
+ def get(
174
+ self,
175
+ *criteria: Callable[[I], bool] | tuple[str, Any],
176
+ **kwargs: Any,
177
+ ) -> I:
178
+ """Get a single item matching the specified criteria.
179
+
180
+ This method applies filters and returns a single matching item,
181
+ or raises ValueError if no items or multiple items match.
182
+
183
+ Args:
184
+ *criteria: Callable criteria or (key, value) tuples
185
+ for filtering.
186
+ **kwargs: Additional key-value pairs for filtering.
187
+
188
+ Returns:
189
+ I: A single item that matches the criteria.
190
+
191
+ Raises:
192
+ ValueError: If no items match or if multiple items match
193
+ the criteria.
194
+
195
+ """
196
+ if item := self.try_get(*criteria, **kwargs):
197
+ return item
198
+
199
+ raise _value_error()
200
+
201
+ def first(
202
+ self,
203
+ *criteria: Callable[[I], bool] | tuple[str, Any],
204
+ **kwargs: Any,
205
+ ) -> I:
206
+ """Get the first item matching the specified criteria.
207
+
208
+ This method applies filters and returns the first matching item,
209
+ or raises ValueError if no items match.
210
+
211
+ Args:
212
+ *criteria: Callable criteria or (key, value) tuples
213
+ for filtering.
214
+ **kwargs: Additional key-value pairs for filtering.
215
+
216
+ Returns:
217
+ I: The first item that matches the criteria.
218
+
219
+ Raises:
220
+ ValueError: If no items match the criteria.
221
+
222
+ """
223
+ if items := self.filter(*criteria, **kwargs):
224
+ return items[0]
225
+
226
+ raise _value_error()
227
+
228
+ def last(
229
+ self,
230
+ *criteria: Callable[[I], bool] | tuple[str, Any],
231
+ **kwargs: Any,
232
+ ) -> I:
233
+ """Get the last item matching the specified criteria.
234
+
235
+ This method applies filters and returns the last matching item,
236
+ or raises ValueError if no items match.
237
+
238
+ Args:
239
+ *criteria: Callable criteria or (key, value) tuples
240
+ for filtering.
241
+ **kwargs: Additional key-value pairs for filtering.
242
+
243
+ Returns:
244
+ I: The last item that matches the criteria.
245
+
246
+ Raises:
247
+ ValueError: If no items match the criteria.
248
+
249
+ """
250
+ if items := self.filter(*criteria, **kwargs):
251
+ return items[-1]
252
+
253
+ raise _value_error()
254
+
255
+ def to_list(
256
+ self,
257
+ key: str,
258
+ default: Any | Callable[[I], Any] = MISSING,
259
+ ) -> list[Any]:
260
+ """Extract a list of values for a specific key from all items.
261
+
262
+ Args:
263
+ key: The key to extract from each item.
264
+ default: The default value to return if the key is not found.
265
+ If a callable, it will be called with the item
266
+ and the value returned will be used as the default.
267
+
268
+ Returns:
269
+ list[Any]: A list containing the values for the
270
+ specified key from each item.
271
+
272
+ """
273
+ return [self._get(i, key, default) for i in self]
274
+
275
+ def to_numpy(
276
+ self,
277
+ key: str,
278
+ default: Any | Callable[[I], Any] = MISSING,
279
+ ) -> NDArray:
280
+ """Extract values for a specific key from all items as a NumPy array.
281
+
282
+ Args:
283
+ key: The key to extract from each item.
284
+ default: The default value to return if the key is not found.
285
+ If a callable, it will be called with the item
286
+ and the value returned will be used as the default.
287
+
288
+ Returns:
289
+ NDArray: A NumPy array containing the values for the
290
+ specified key from each item.
291
+
292
+ """
293
+ return np.array(self.to_list(key, default))
294
+
295
+ def to_series(
296
+ self,
297
+ key: str,
298
+ default: Any = MISSING,
299
+ *,
300
+ name: str | None = None,
301
+ ) -> Series:
302
+ """Extract values for a specific key from all items as a Polars series.
303
+
304
+ Args:
305
+ key: The key to extract from each item.
306
+ default: The default value to return if the key is not found.
307
+ If a callable, it will be called with the item
308
+ and the value returned will be used as the default.
309
+ name: The name of the series. If not provided, the key will be used.
310
+
311
+ Returns:
312
+ Series: A Polars series containing the values for the
313
+ specified key from each item.
314
+
315
+ """
316
+ return Series(name or key, self.to_list(key, default))
317
+
318
+ def unique(
319
+ self,
320
+ key: str,
321
+ default: Any | Callable[[I], Any] = MISSING,
322
+ ) -> NDArray:
323
+ """Get the unique values for a specific key across all items.
324
+
325
+ Args:
326
+ key: The key to extract unique values for.
327
+ default: The default value to return if the key is not found.
328
+ If a callable, it will be called with the item
329
+ and the value returned will be used as the default.
330
+
331
+ Returns:
332
+ NDArray: A NumPy array containing the unique values for the
333
+ specified key.
334
+
335
+ """
336
+ return np.unique(self.to_numpy(key, default), axis=0)
337
+
338
+ def n_unique(
339
+ self,
340
+ key: str,
341
+ default: Any | Callable[[I], Any] = MISSING,
342
+ ) -> int:
343
+ """Count the number of unique values for a specific key across all items.
344
+
345
+ Args:
346
+ key: The key to count unique values for.
347
+ default: The default value to return if the key is not found.
348
+ If a callable, it will be called with the item
349
+ and the value returned will be used as the default.
350
+
351
+ Returns:
352
+ int: The number of unique values for the specified key.
353
+
354
+ """
355
+ return len(self.unique(key, default))
356
+
357
+ def sort(self, *keys: str, reverse: bool = False) -> Self:
358
+ """Sort items based on one or more keys.
359
+
360
+ Args:
361
+ *keys: The keys to sort by, in order of priority.
362
+ reverse: Whether to sort in descending order (default is
363
+ ascending).
364
+
365
+ Returns:
366
+ Self: A new Collection with the items sorted according to
367
+ the specified keys.
368
+
369
+ """
370
+ if not keys:
371
+ return self
372
+
373
+ arrays = [self.to_numpy(key) for key in keys]
374
+ index = np.lexsort(arrays[::-1])
375
+
376
+ if reverse:
377
+ index = index[::-1]
378
+
379
+ return self[index]
380
+
381
+ def to_frame(
382
+ self,
383
+ *keys: str,
384
+ defaults: dict[str, Any | Callable[[I], Any]] | None = None,
385
+ **kwargs: Callable[[I], Any],
386
+ ) -> DataFrame:
387
+ """Convert the collection to a Polars DataFrame.
388
+
389
+ Args:
390
+ *keys (str): The keys to include as columns in the DataFrame.
391
+ defaults (dict[str, Any | Callable[[T], Any]] | None): Default
392
+ values for the keys. If a callable, it will be called with
393
+ the item and the value returned will be used as the
394
+ default.
395
+ **kwargs (Callable[[I], Any]): Additional columns to compute
396
+ using callables that take an item and return a value.
397
+
398
+ Returns:
399
+ DataFrame: A Polars DataFrame containing the specified data
400
+ from the items.
401
+
402
+ """
403
+ if defaults is None:
404
+ defaults = {}
405
+
406
+ data = {k: self.to_list(k, defaults.get(k, MISSING)) for k in keys}
407
+ df = DataFrame(data)
408
+
409
+ if not kwargs:
410
+ return df
411
+
412
+ columns = [Series(k, [v(r) for r in self]) for k, v in kwargs.items()]
413
+ return df.with_columns(*columns)
414
+
415
+ def group_by(self, *by: str) -> GroupBy[Self, I]:
416
+ """Group items by one or more keys and return a GroupBy instance.
417
+
418
+ This method organizes items into groups based on the specified
419
+ keys and returns a GroupBy instance that contains the grouped
420
+ collections. The GroupBy instance behaves like a dictionary,
421
+ allowing access to collections for each group key.
422
+
423
+ Args:
424
+ *by: The keys to group by. If a single key is provided,
425
+ its value will be used as the group key.
426
+ If multiple keys are provided, a tuple of their
427
+ values will be used as the group key.
428
+ Keys can use dot notation (e.g., "model.type")
429
+ to access nested configuration values.
430
+
431
+ Returns:
432
+ GroupBy[Self, I]: A GroupBy instance containing the grouped items.
433
+ Each group is a collection of the same type as the original.
434
+
435
+ """
436
+ groups: dict[Any, Self] = {}
437
+
438
+ for item in self:
439
+ keys = [to_hashable(self._get(item, key, MISSING)) for key in by]
440
+ key = keys[0] if len(by) == 1 else tuple(keys)
441
+
442
+ if key not in groups:
443
+ groups[key] = self.__class__([], self._get)
444
+
445
+ groups[key]._items.append(item) # noqa: SLF001
446
+
447
+ return GroupBy(by, groups)
448
+
449
+
450
def to_hashable(value: Any) -> Hashable:
    """Convert a value to a hashable instance.

    This function handles various types of values and converts them to
    hashable equivalents for use in dictionaries and sets. Conversion is
    recursive, so nested lists/arrays become nested tuples.

    - OmegaConf lists and NumPy arrays become (nested) tuples; a 0-d array
      becomes its Python scalar.
    - Values that actually hash are returned unchanged.
    - Other iterables (including tuples that contain unhashable elements)
      become tuples of converted elements.
    - Anything else falls back to ``str(value)``.

    Args:
        value: The value to convert to a hashable instance.

    Returns:
        A hashable version of the input value.

    """
    if OmegaConf.is_list(value):
        return tuple(to_hashable(v) for v in value)

    if isinstance(value, np.ndarray):
        # ``tolist`` returns a Python scalar for 0-d arrays and nested lists
        # otherwise; both are handled by the recursive call.
        return to_hashable(value.tolist())

    if isinstance(value, Hashable):
        # ``isinstance(..., Hashable)`` only checks for ``__hash__``; a tuple
        # holding a list passes the check but cannot be hashed.
        try:
            hash(value)
        except TypeError:
            pass
        else:
            return value

    try:
        return tuple(to_hashable(v) for v in value)
    except TypeError:
        return str(value)
473
+
474
+
475
+ def _value_error() -> ValueError:
476
+ msg = "No item found matching the specified criteria"
477
+ return ValueError(msg)
478
+
479
+
480
def matches(value: Any, criterion: Any) -> bool:
    """Check if a value matches the given criterion.

    This function compares the value with the given criteria according
    to the following rules:

    - If criterion is callable: Call it with the value and return
      the boolean result
    - If criterion is a list or set: Check if the value is in the list/set
    - If criterion is a tuple of length 2: Check if the value is
      in the range [criterion[0], criterion[1]]. Both sides are
      inclusive. A value that cannot be ordered against the bounds
      (e.g. a missing key) does not match.
    - Otherwise: Check if the value equals the criteria

    Args:
        value: The value to be compared with the criterion.
        criterion: The criterion to match against.
            Can be:

            - A callable that takes the value and returns a boolean
            - A list or set to check membership
            - A tuple of length 2 to check range inclusion
            - Any other value for direct equality comparison

    Returns:
        bool: True if the value matches the criterion according to the rules above,
        False otherwise.

    Examples:
        >>> matches(5, lambda x: x > 3)
        True
        >>> matches(2, [1, 2, 3])
        True
        >>> matches(4, (1, 5))
        True
        >>> matches(3, 3)
        True

    """
    if callable(criterion):
        return bool(criterion(value))

    if isinstance(criterion, ListConfig):
        criterion = list(criterion)

    if isinstance(criterion, list | set) and not _is_iterable(value):
        return value in criterion

    if isinstance(criterion, tuple) and len(criterion) == 2 and not _is_iterable(value):
        try:
            return bool(criterion[0] <= value <= criterion[1])
        except TypeError:
            # Unorderable against the bounds (e.g. MISSING for an absent key,
            # or None): treat as "no match", like the equality branch does,
            # instead of aborting the whole filter.
            return False

    # Compare iterables element-wise as lists so that e.g. a tuple value and
    # a list criterion with the same elements are considered equal.
    if _is_iterable(criterion):
        criterion = list(criterion)

    if _is_iterable(value):
        value = list(value)

    return value == criterion
538
+
539
+
540
+ def _is_iterable(value: Any) -> bool:
541
+ return isinstance(value, Iterable) and not isinstance(value, str)