hydraflow 0.16.2__py3-none-any.whl → 0.17.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,613 @@
1
+ """Provide a collection of items that implements the Sequence protocol."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Hashable, Iterable, Sequence
6
+ from dataclasses import MISSING
7
+ from typing import TYPE_CHECKING, Concatenate, overload
8
+
9
+ import numpy as np
10
+ from joblib.parallel import Parallel, delayed
11
+ from omegaconf import ListConfig, OmegaConf
12
+ from polars import DataFrame, Series
13
+
14
+ from .group_by import GroupBy
15
+
16
+ if TYPE_CHECKING:
17
+ from collections.abc import Callable, Iterator
18
+ from typing import Any, Self
19
+
20
+ from numpy.typing import NDArray
21
+
22
+
23
+ class Collection[I](Sequence[I]):
24
+ """A collection of items that implements the Sequence protocol."""
25
+
26
+ _items: list[I]
27
+ _get: Callable[[I, str, Any | Callable[[I], Any]], Any]
28
+
29
+ def __init__(
30
+ self,
31
+ items: Iterable[I],
32
+ get: Callable[[I, str, Any | Callable[[I], Any]], Any] | None = None,
33
+ ) -> None:
34
+ self._items = list(items)
35
+ self._get = get or getattr
36
+
37
+ def __repr__(self) -> str:
38
+ class_name = self.__class__.__name__
39
+ if not self:
40
+ return f"{class_name}(empty)"
41
+
42
+ type_name = repr(self[0])
43
+ if "(" in type_name:
44
+ type_name = type_name.split("(", 1)[0]
45
+ return f"{class_name}({type_name}, n={len(self)})"
46
+
47
+ def __len__(self) -> int:
48
+ return len(self._items)
49
+
50
+ def __bool__(self) -> bool:
51
+ return bool(self._items)
52
+
53
+ @overload
54
+ def __getitem__(self, index: int) -> I: ...
55
+
56
+ @overload
57
+ def __getitem__(self, index: slice) -> Self: ...
58
+
59
+ @overload
60
+ def __getitem__(self, index: Iterable[int]) -> Self: ...
61
+
62
+ def __getitem__(self, index: int | slice | Iterable[int]) -> I | Self:
63
+ if isinstance(index, int):
64
+ return self._items[index]
65
+
66
+ if isinstance(index, slice):
67
+ return self.__class__(self._items[index], self._get)
68
+
69
+ return self.__class__([self._items[i] for i in index], self._get)
70
+
71
+ def __iter__(self) -> Iterator[I]:
72
+ return iter(self._items)
73
+
74
+ def filter(
75
+ self,
76
+ *criteria: Callable[[I], bool] | tuple[str, Any],
77
+ **kwargs: Any,
78
+ ) -> Self:
79
+ """Filter items based on criteria.
80
+
81
+ This method allows filtering items using various criteria:
82
+
83
+ - Callable criteria that take an item and return a boolean
84
+ - Key-value tuples where the key is a string and the value
85
+ is compared using the `matches` function
86
+ - Keyword arguments, where the key is a string and the value
87
+ is compared using the `matches` function
88
+
89
+ The `matches` function supports the following comparison types:
90
+
91
+ - Callable: The predicate function is called with the value
92
+ - List/Set: Checks if the value is in the list/set
93
+ - Tuple of length 2: Checks if the value is in the range [min, max]
94
+ - Other: Checks for direct equality
95
+
96
+ Args:
97
+ *criteria: Callable criteria or (key, value) tuples
98
+ for filtering.
99
+ **kwargs: Additional key-value pairs for filtering.
100
+
101
+ Returns:
102
+ Self: A new Collection containing only the items that
103
+ match all criteria.
104
+
105
+ Examples:
106
+ ```python
107
+ # Filter using a callable
108
+ filtered = collection.filter(lambda x: x > 5)
109
+
110
+ # Filter using a key-value tuple
111
+ filtered = collection.filter(("age", 25))
112
+
113
+ # Filter using keyword arguments
114
+ filtered = collection.filter(age=25, name="John")
115
+
116
+ # Filter using range
117
+ filtered = collection.filter(("age", (20, 30)))
118
+
119
+ # Filter using list membership
120
+ filtered = collection.filter(("name", ["John", "Jane"]))
121
+ ```
122
+
123
+ """
124
+ items = self._items
125
+
126
+ for c in criteria:
127
+ if callable(c):
128
+ items = [i for i in items if c(i)]
129
+ else:
130
+ items = [i for i in items if matches(self._get(i, c[0], MISSING), c[1])]
131
+
132
+ for key, value in kwargs.items():
133
+ items = [i for i in items if matches(self._get(i, key, MISSING), value)]
134
+
135
+ return self.__class__(items, self._get)
136
+
137
+ def try_get(
138
+ self,
139
+ *criteria: Callable[[I], bool] | tuple[str, Any],
140
+ **kwargs: Any,
141
+ ) -> I | None:
142
+ """Try to get a single item matching the specified criteria.
143
+
144
+ This method applies filters and returns a single matching
145
+ item if exactly one is found, None if no items are found,
146
+ or raises ValueError if multiple items match.
147
+
148
+ Args:
149
+ *criteria: Callable criteria or (key, value) tuples
150
+ for filtering.
151
+ **kwargs: Additional key-value pairs for filtering.
152
+
153
+ Returns:
154
+ I | None: A single item that matches the criteria, or None if
155
+ no matches are found.
156
+
157
+ Raises:
158
+ ValueError: If multiple items match the criteria.
159
+
160
+ """
161
+ items = self.filter(*criteria, **kwargs)
162
+
163
+ n = len(items)
164
+ if n == 0:
165
+ return None
166
+
167
+ if n == 1:
168
+ return items[0]
169
+
170
+ msg = f"Multiple items ({n}) found matching the criteria, "
171
+ msg += "expected exactly one"
172
+ raise ValueError(msg)
173
+
174
+ def get(
175
+ self,
176
+ *criteria: Callable[[I], bool] | tuple[str, Any],
177
+ **kwargs: Any,
178
+ ) -> I:
179
+ """Get a single item matching the specified criteria.
180
+
181
+ This method applies filters and returns a single matching item,
182
+ or raises ValueError if no items or multiple items match.
183
+
184
+ Args:
185
+ *criteria: Callable criteria or (key, value) tuples
186
+ for filtering.
187
+ **kwargs: Additional key-value pairs for filtering.
188
+
189
+ Returns:
190
+ I: A single item that matches the criteria.
191
+
192
+ Raises:
193
+ ValueError: If no items match or if multiple items match
194
+ the criteria.
195
+
196
+ """
197
+ if item := self.try_get(*criteria, **kwargs):
198
+ return item
199
+
200
+ raise _value_error()
201
+
202
+ def first(
203
+ self,
204
+ *criteria: Callable[[I], bool] | tuple[str, Any],
205
+ **kwargs: Any,
206
+ ) -> I:
207
+ """Get the first item matching the specified criteria.
208
+
209
+ This method applies filters and returns the first matching item,
210
+ or raises ValueError if no items match.
211
+
212
+ Args:
213
+ *criteria: Callable criteria or (key, value) tuples
214
+ for filtering.
215
+ **kwargs: Additional key-value pairs for filtering.
216
+
217
+ Returns:
218
+ I: The first item that matches the criteria.
219
+
220
+ Raises:
221
+ ValueError: If no items match the criteria.
222
+
223
+ """
224
+ if items := self.filter(*criteria, **kwargs):
225
+ return items[0]
226
+
227
+ raise _value_error()
228
+
229
+ def last(
230
+ self,
231
+ *criteria: Callable[[I], bool] | tuple[str, Any],
232
+ **kwargs: Any,
233
+ ) -> I:
234
+ """Get the last item matching the specified criteria.
235
+
236
+ This method applies filters and returns the last matching item,
237
+ or raises ValueError if no items match.
238
+
239
+ Args:
240
+ *criteria: Callable criteria or (key, value) tuples
241
+ for filtering.
242
+ **kwargs: Additional key-value pairs for filtering.
243
+
244
+ Returns:
245
+ I: The last item that matches the criteria.
246
+
247
+ Raises:
248
+ ValueError: If no items match the criteria.
249
+
250
+ """
251
+ if items := self.filter(*criteria, **kwargs):
252
+ return items[-1]
253
+
254
+ raise _value_error()
255
+
256
+ def to_list(
257
+ self,
258
+ key: str,
259
+ default: Any | Callable[[I], Any] = MISSING,
260
+ ) -> list[Any]:
261
+ """Extract a list of values for a specific key from all items.
262
+
263
+ Args:
264
+ key: The key to extract from each item.
265
+ default: The default value to return if the key is not found.
266
+ If a callable, it will be called with the item
267
+ and the value returned will be used as the default.
268
+
269
+ Returns:
270
+ list[Any]: A list containing the values for the
271
+ specified key from each item.
272
+
273
+ """
274
+ return [self._get(i, key, default) for i in self]
275
+
276
+ def to_numpy(
277
+ self,
278
+ key: str,
279
+ default: Any | Callable[[I], Any] = MISSING,
280
+ ) -> NDArray:
281
+ """Extract values for a specific key from all items as a NumPy array.
282
+
283
+ Args:
284
+ key: The key to extract from each item.
285
+ default: The default value to return if the key is not found.
286
+ If a callable, it will be called with the item
287
+ and the value returned will be used as the default.
288
+
289
+ Returns:
290
+ NDArray: A NumPy array containing the values for the
291
+ specified key from each item.
292
+
293
+ """
294
+ return np.array(self.to_list(key, default))
295
+
296
+ def to_series(
297
+ self,
298
+ key: str,
299
+ default: Any = MISSING,
300
+ *,
301
+ name: str | None = None,
302
+ ) -> Series:
303
+ """Extract values for a specific key from all items as a Polars series.
304
+
305
+ Args:
306
+ key: The key to extract from each item.
307
+ default: The default value to return if the key is not found.
308
+ If a callable, it will be called with the item
309
+ and the value returned will be used as the default.
310
+ name: The name of the series. If not provided, the key will be used.
311
+
312
+ Returns:
313
+ Series: A Polars series containing the values for the
314
+ specified key from each item.
315
+
316
+ """
317
+ return Series(name or key, self.to_list(key, default))
318
+
319
+ def unique(
320
+ self,
321
+ key: str,
322
+ default: Any | Callable[[I], Any] = MISSING,
323
+ ) -> NDArray:
324
+ """Get the unique values for a specific key across all items.
325
+
326
+ Args:
327
+ key: The key to extract unique values for.
328
+ default: The default value to return if the key is not found.
329
+ If a callable, it will be called with the item
330
+ and the value returned will be used as the default.
331
+
332
+ Returns:
333
+ NDArray: A NumPy array containing the unique values for the
334
+ specified key.
335
+
336
+ """
337
+ return np.unique(self.to_numpy(key, default), axis=0)
338
+
339
+ def n_unique(
340
+ self,
341
+ key: str,
342
+ default: Any | Callable[[I], Any] = MISSING,
343
+ ) -> int:
344
+ """Count the number of unique values for a specific key across all items.
345
+
346
+ Args:
347
+ key: The key to count unique values for.
348
+ default: The default value to return if the key is not found.
349
+ If a callable, it will be called with the item
350
+ and the value returned will be used as the default.
351
+
352
+ Returns:
353
+ int: The number of unique values for the specified key.
354
+
355
+ """
356
+ return len(self.unique(key, default))
357
+
358
+ def sort(self, *keys: str, reverse: bool = False) -> Self:
359
+ """Sort items based on one or more keys.
360
+
361
+ Args:
362
+ *keys: The keys to sort by, in order of priority.
363
+ reverse: Whether to sort in descending order (default is
364
+ ascending).
365
+
366
+ Returns:
367
+ Self: A new Collection with the items sorted according to
368
+ the specified keys.
369
+
370
+ """
371
+ if not keys:
372
+ return self
373
+
374
+ arrays = [self.to_numpy(key) for key in keys]
375
+ index = np.lexsort(arrays[::-1])
376
+
377
+ if reverse:
378
+ index = index[::-1]
379
+
380
+ return self[index]
381
+
382
+ def map[**P, R](
383
+ self,
384
+ function: Callable[Concatenate[I, P], R],
385
+ *args: P.args,
386
+ **kwargs: P.kwargs,
387
+ ) -> Iterator[R]:
388
+ """Apply a function to each item and return an iterator of results.
389
+
390
+ This is a memory-efficient mapping operation that lazily evaluates results.
391
+ Ideal for large collections where memory usage is a concern.
392
+
393
+ Args:
394
+ function: Function to apply to each item. The item is passed
395
+ as the first argument.
396
+ *args: Additional positional arguments to pass to the function.
397
+ **kwargs: Additional keyword arguments to pass to the function.
398
+
399
+ Returns:
400
+ Iterator[R]: An iterator of the function's results.
401
+
402
+ Examples:
403
+ ```python
404
+ # Process results one at a time
405
+ for result in collection.map(process_item, additional_arg):
406
+ handle_result(result)
407
+
408
+ # Convert to list if needed
409
+ results = list(collection.map(transform_item))
410
+ ```
411
+
412
+ """
413
+ yield from (function(i, *args, **kwargs) for i in self)
414
+
415
+ def pmap[**P, R](
416
+ self,
417
+ function: Callable[Concatenate[I, P], R],
418
+ n_jobs: int = -1,
419
+ backend: str = "multiprocessing",
420
+ *args: P.args,
421
+ **kwargs: P.kwargs,
422
+ ) -> list[R]:
423
+ """Apply a function to each item in parallel and return a list of results.
424
+
425
+ This method processes items concurrently for improved performance on
426
+ CPU-bound or I/O-bound operations, depending on the backend.
427
+
428
+ Args:
429
+ function: Function to apply to each item. The item is passed
430
+ as the first argument.
431
+ n_jobs (int): Number of jobs to run in parallel. -1 means using all
432
+ processors.
433
+ backend (str): Parallelization backend.
434
+ *args: Additional positional arguments to pass to the function.
435
+ **kwargs: Additional keyword arguments to pass to the function.
436
+
437
+ Returns:
438
+ list[R]: A list containing all results of the function applications.
439
+
440
+ Examples:
441
+ ```python
442
+ # Process all items in parallel using all cores
443
+ results = collection.pmap(heavy_computation)
444
+
445
+ # Specify number of parallel jobs and backend
446
+ results = collection.pmap(process_files, n_jobs=4, backend="threading")
447
+ ```
448
+
449
+ """
450
+ parallel = Parallel(n_jobs=n_jobs, backend=backend, return_as="list")
451
+ return parallel(delayed(function)(i, *args, **kwargs) for i in self) # type: ignore
452
+
453
+ def to_frame(
454
+ self,
455
+ *keys: str,
456
+ defaults: dict[str, Any | Callable[[I], Any]] | None = None,
457
+ **kwargs: Callable[[I], Any],
458
+ ) -> DataFrame:
459
+ """Convert the collection to a Polars DataFrame.
460
+
461
+ Args:
462
+ *keys (str): The keys to include as columns in the DataFrame.
463
+ defaults (dict[str, Any | Callable[[T], Any]] | None): Default
464
+ values for the keys. If a callable, it will be called with
465
+ the item and the value returned will be used as the
466
+ default.
467
+ **kwargs (Callable[[I], Any]): Additional columns to compute
468
+ using callables that take an item and return a value.
469
+
470
+ Returns:
471
+ DataFrame: A Polars DataFrame containing the specified data
472
+ from the items.
473
+
474
+ """
475
+ if defaults is None:
476
+ defaults = {}
477
+
478
+ data = {k: self.to_list(k, defaults.get(k, MISSING)) for k in keys}
479
+ df = DataFrame(data)
480
+
481
+ if not kwargs:
482
+ return df
483
+
484
+ columns = [Series(k, self.map(v)) for k, v in kwargs.items()]
485
+ return df.with_columns(*columns)
486
+
487
+ def group_by(self, *by: str) -> GroupBy[Self, I]:
488
+ """Group items by one or more keys and return a GroupBy instance.
489
+
490
+ This method organizes items into groups based on the specified
491
+ keys and returns a GroupBy instance that contains the grouped
492
+ collections. The GroupBy instance behaves like a dictionary,
493
+ allowing access to collections for each group key.
494
+
495
+ Args:
496
+ *by: The keys to group by. If a single key is provided,
497
+ its value will be used as the group key.
498
+ If multiple keys are provided, a tuple of their
499
+ values will be used as the group key.
500
+ Keys can use dot notation (e.g., "model.type")
501
+ to access nested configuration values.
502
+
503
+ Returns:
504
+ GroupBy[Self, I]: A GroupBy instance containing the grouped items.
505
+ Each group is a collection of the same type as the original.
506
+
507
+ """
508
+ groups: dict[Any, Self] = {}
509
+
510
+ for item in self:
511
+ keys = [to_hashable(self._get(item, key, MISSING)) for key in by]
512
+ key = keys[0] if len(by) == 1 else tuple(keys)
513
+
514
+ if key not in groups:
515
+ groups[key] = self.__class__([], self._get)
516
+
517
+ groups[key]._items.append(item) # noqa: SLF001
518
+
519
+ return GroupBy(by, groups)
520
+
521
+
522
def to_hashable(value: Any) -> Hashable:
    """Return a hashable stand-in for ``value``.

    The result is intended for use as a dictionary key or set member.
    The conversions are tried in this order:

    - an OmegaConf list becomes a tuple of its elements
    - a value that is already ``Hashable`` is returned unchanged
    - a NumPy array becomes a tuple of its ``tolist()`` elements
    - any other iterable becomes a tuple
    - anything else becomes its ``str`` representation

    Args:
        value: The value to convert.

    Returns:
        Hashable: The converted value.

    """
    if OmegaConf.is_list(value):  # Is ListConfig hashable?
        converted = tuple(value)
    elif isinstance(value, Hashable):
        converted = value
    elif isinstance(value, np.ndarray):
        # NOTE(review): for a multi-dimensional array ``tolist()`` yields
        # nested lists, so the tuple would contain lists - confirm callers
        # only pass 1-D arrays.
        converted = tuple(value.tolist())
    else:
        try:
            converted = tuple(value)
        except TypeError:
            # Not iterable: fall back to the string representation.
            converted = str(value)
    return converted
545
+
546
+
547
+ def _value_error() -> ValueError:
548
+ msg = "No item found matching the specified criteria"
549
+ return ValueError(msg)
550
+
551
+
552
def matches(value: Any, criterion: Any) -> bool:
    """Decide whether ``value`` satisfies ``criterion``.

    The criterion is interpreted by the first rule that applies:

    - Callable: it is called with the value and the result is
      converted to a boolean
    - List or set (an OmegaConf ``ListConfig`` is treated as a list),
      with a non-iterable value: membership test
    - Tuple of length 2, with a non-iterable value: inclusive range
      test ``criterion[0] <= value <= criterion[1]``
    - Otherwise: equality, where an iterable (non-string) value or
      criterion is converted to a list before comparing

    Args:
        value: The value to be compared with the criterion.
        criterion: The criterion to match against.
            Can be:

            - A callable that takes the value and returns a boolean
            - A list or set to check membership
            - A tuple of length 2 to check range inclusion
            - Any other value for direct equality comparison

    Returns:
        bool: True if the value matches the criterion according to the
        rules above, False otherwise.

    Examples:
        >>> matches(5, lambda x: x > 3)
        True
        >>> matches(2, [1, 2, 3])
        True
        >>> matches(4, (1, 5))
        True
        >>> matches(3, 3)
        True

    """
    if callable(criterion):
        return bool(criterion(value))

    if isinstance(criterion, ListConfig):
        criterion = list(criterion)

    value_is_scalar = not _is_iterable(value)

    # Membership and range tests only apply to non-iterable values.
    if value_is_scalar:
        if isinstance(criterion, (list, set)):
            return value in criterion

        if isinstance(criterion, tuple) and len(criterion) == 2:
            return criterion[0] <= value <= criterion[1]

    # Sequence-like operands are compared element-wise as lists.
    expected = list(criterion) if _is_iterable(criterion) else criterion
    actual = value if value_is_scalar else list(value)

    return actual == expected
610
+
611
+
612
+ def _is_iterable(value: Any) -> bool:
613
+ return isinstance(value, Iterable) and not isinstance(value, str)
hydraflow/core/context.py CHANGED
@@ -128,13 +128,12 @@ def chdir_artifact(run: Run) -> Iterator[Path]:
128
128
  run (Run | None): The run to get the artifact directory from.
129
129
 
130
130
  """
131
- curdir = Path.cwd()
131
+ current_dir = Path.cwd()
132
132
  artifact_dir = get_artifact_dir(run)
133
133
 
134
- os.chdir(artifact_dir)
135
-
136
134
  try:
135
+ os.chdir(artifact_dir)
137
136
  yield artifact_dir
138
137
 
139
138
  finally:
140
- os.chdir(curdir)
139
+ os.chdir(current_dir)