hydraflow 0.16.2__py3-none-any.whl → 0.17.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydraflow/core/collection.py +613 -0
- hydraflow/core/context.py +3 -4
- hydraflow/core/group_by.py +205 -0
- hydraflow/core/run.py +111 -62
- hydraflow/core/run_collection.py +66 -483
- hydraflow/core/run_info.py +0 -9
- {hydraflow-0.16.2.dist-info → hydraflow-0.17.1.dist-info}/METADATA +1 -1
- {hydraflow-0.16.2.dist-info → hydraflow-0.17.1.dist-info}/RECORD +11 -9
- {hydraflow-0.16.2.dist-info → hydraflow-0.17.1.dist-info}/WHEEL +0 -0
- {hydraflow-0.16.2.dist-info → hydraflow-0.17.1.dist-info}/entry_points.txt +0 -0
- {hydraflow-0.16.2.dist-info → hydraflow-0.17.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,613 @@
|
|
1
|
+
"""Provide a collection of items that implements the Sequence protocol."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
from collections.abc import Hashable, Iterable, Sequence
|
6
|
+
from dataclasses import MISSING
|
7
|
+
from typing import TYPE_CHECKING, Concatenate, overload
|
8
|
+
|
9
|
+
import numpy as np
|
10
|
+
from joblib.parallel import Parallel, delayed
|
11
|
+
from omegaconf import ListConfig, OmegaConf
|
12
|
+
from polars import DataFrame, Series
|
13
|
+
|
14
|
+
from .group_by import GroupBy
|
15
|
+
|
16
|
+
if TYPE_CHECKING:
|
17
|
+
from collections.abc import Callable, Iterator
|
18
|
+
from typing import Any, Self
|
19
|
+
|
20
|
+
from numpy.typing import NDArray
|
21
|
+
|
22
|
+
|
23
|
+
class Collection[I](Sequence[I]):
|
24
|
+
"""A collection of items that implements the Sequence protocol."""
|
25
|
+
|
26
|
+
_items: list[I]
|
27
|
+
_get: Callable[[I, str, Any | Callable[[I], Any]], Any]
|
28
|
+
|
29
|
+
def __init__(
|
30
|
+
self,
|
31
|
+
items: Iterable[I],
|
32
|
+
get: Callable[[I, str, Any | Callable[[I], Any]], Any] | None = None,
|
33
|
+
) -> None:
|
34
|
+
self._items = list(items)
|
35
|
+
self._get = get or getattr
|
36
|
+
|
37
|
+
def __repr__(self) -> str:
|
38
|
+
class_name = self.__class__.__name__
|
39
|
+
if not self:
|
40
|
+
return f"{class_name}(empty)"
|
41
|
+
|
42
|
+
type_name = repr(self[0])
|
43
|
+
if "(" in type_name:
|
44
|
+
type_name = type_name.split("(", 1)[0]
|
45
|
+
return f"{class_name}({type_name}, n={len(self)})"
|
46
|
+
|
47
|
+
def __len__(self) -> int:
|
48
|
+
return len(self._items)
|
49
|
+
|
50
|
+
def __bool__(self) -> bool:
|
51
|
+
return bool(self._items)
|
52
|
+
|
53
|
+
@overload
|
54
|
+
def __getitem__(self, index: int) -> I: ...
|
55
|
+
|
56
|
+
@overload
|
57
|
+
def __getitem__(self, index: slice) -> Self: ...
|
58
|
+
|
59
|
+
@overload
|
60
|
+
def __getitem__(self, index: Iterable[int]) -> Self: ...
|
61
|
+
|
62
|
+
def __getitem__(self, index: int | slice | Iterable[int]) -> I | Self:
|
63
|
+
if isinstance(index, int):
|
64
|
+
return self._items[index]
|
65
|
+
|
66
|
+
if isinstance(index, slice):
|
67
|
+
return self.__class__(self._items[index], self._get)
|
68
|
+
|
69
|
+
return self.__class__([self._items[i] for i in index], self._get)
|
70
|
+
|
71
|
+
def __iter__(self) -> Iterator[I]:
|
72
|
+
return iter(self._items)
|
73
|
+
|
74
|
+
def filter(
|
75
|
+
self,
|
76
|
+
*criteria: Callable[[I], bool] | tuple[str, Any],
|
77
|
+
**kwargs: Any,
|
78
|
+
) -> Self:
|
79
|
+
"""Filter items based on criteria.
|
80
|
+
|
81
|
+
This method allows filtering items using various criteria:
|
82
|
+
|
83
|
+
- Callable criteria that take an item and return a boolean
|
84
|
+
- Key-value tuples where the key is a string and the value
|
85
|
+
is compared using the `matches` function
|
86
|
+
- Keyword arguments, where the key is a string and the value
|
87
|
+
is compared using the `matches` function
|
88
|
+
|
89
|
+
The `matches` function supports the following comparison types:
|
90
|
+
|
91
|
+
- Callable: The predicate function is called with the value
|
92
|
+
- List/Set: Checks if the value is in the list/set
|
93
|
+
- Tuple of length 2: Checks if the value is in the range [min, max]
|
94
|
+
- Other: Checks for direct equality
|
95
|
+
|
96
|
+
Args:
|
97
|
+
*criteria: Callable criteria or (key, value) tuples
|
98
|
+
for filtering.
|
99
|
+
**kwargs: Additional key-value pairs for filtering.
|
100
|
+
|
101
|
+
Returns:
|
102
|
+
Self: A new Collection containing only the items that
|
103
|
+
match all criteria.
|
104
|
+
|
105
|
+
Examples:
|
106
|
+
```python
|
107
|
+
# Filter using a callable
|
108
|
+
filtered = collection.filter(lambda x: x > 5)
|
109
|
+
|
110
|
+
# Filter using a key-value tuple
|
111
|
+
filtered = collection.filter(("age", 25))
|
112
|
+
|
113
|
+
# Filter using keyword arguments
|
114
|
+
filtered = collection.filter(age=25, name="John")
|
115
|
+
|
116
|
+
# Filter using range
|
117
|
+
filtered = collection.filter(("age", (20, 30)))
|
118
|
+
|
119
|
+
# Filter using list membership
|
120
|
+
filtered = collection.filter(("name", ["John", "Jane"]))
|
121
|
+
```
|
122
|
+
|
123
|
+
"""
|
124
|
+
items = self._items
|
125
|
+
|
126
|
+
for c in criteria:
|
127
|
+
if callable(c):
|
128
|
+
items = [i for i in items if c(i)]
|
129
|
+
else:
|
130
|
+
items = [i for i in items if matches(self._get(i, c[0], MISSING), c[1])]
|
131
|
+
|
132
|
+
for key, value in kwargs.items():
|
133
|
+
items = [i for i in items if matches(self._get(i, key, MISSING), value)]
|
134
|
+
|
135
|
+
return self.__class__(items, self._get)
|
136
|
+
|
137
|
+
def try_get(
|
138
|
+
self,
|
139
|
+
*criteria: Callable[[I], bool] | tuple[str, Any],
|
140
|
+
**kwargs: Any,
|
141
|
+
) -> I | None:
|
142
|
+
"""Try to get a single item matching the specified criteria.
|
143
|
+
|
144
|
+
This method applies filters and returns a single matching
|
145
|
+
item if exactly one is found, None if no items are found,
|
146
|
+
or raises ValueError if multiple items match.
|
147
|
+
|
148
|
+
Args:
|
149
|
+
*criteria: Callable criteria or (key, value) tuples
|
150
|
+
for filtering.
|
151
|
+
**kwargs: Additional key-value pairs for filtering.
|
152
|
+
|
153
|
+
Returns:
|
154
|
+
I | None: A single item that matches the criteria, or None if
|
155
|
+
no matches are found.
|
156
|
+
|
157
|
+
Raises:
|
158
|
+
ValueError: If multiple items match the criteria.
|
159
|
+
|
160
|
+
"""
|
161
|
+
items = self.filter(*criteria, **kwargs)
|
162
|
+
|
163
|
+
n = len(items)
|
164
|
+
if n == 0:
|
165
|
+
return None
|
166
|
+
|
167
|
+
if n == 1:
|
168
|
+
return items[0]
|
169
|
+
|
170
|
+
msg = f"Multiple items ({n}) found matching the criteria, "
|
171
|
+
msg += "expected exactly one"
|
172
|
+
raise ValueError(msg)
|
173
|
+
|
174
|
+
def get(
|
175
|
+
self,
|
176
|
+
*criteria: Callable[[I], bool] | tuple[str, Any],
|
177
|
+
**kwargs: Any,
|
178
|
+
) -> I:
|
179
|
+
"""Get a single item matching the specified criteria.
|
180
|
+
|
181
|
+
This method applies filters and returns a single matching item,
|
182
|
+
or raises ValueError if no items or multiple items match.
|
183
|
+
|
184
|
+
Args:
|
185
|
+
*criteria: Callable criteria or (key, value) tuples
|
186
|
+
for filtering.
|
187
|
+
**kwargs: Additional key-value pairs for filtering.
|
188
|
+
|
189
|
+
Returns:
|
190
|
+
I: A single item that matches the criteria.
|
191
|
+
|
192
|
+
Raises:
|
193
|
+
ValueError: If no items match or if multiple items match
|
194
|
+
the criteria.
|
195
|
+
|
196
|
+
"""
|
197
|
+
if item := self.try_get(*criteria, **kwargs):
|
198
|
+
return item
|
199
|
+
|
200
|
+
raise _value_error()
|
201
|
+
|
202
|
+
def first(
|
203
|
+
self,
|
204
|
+
*criteria: Callable[[I], bool] | tuple[str, Any],
|
205
|
+
**kwargs: Any,
|
206
|
+
) -> I:
|
207
|
+
"""Get the first item matching the specified criteria.
|
208
|
+
|
209
|
+
This method applies filters and returns the first matching item,
|
210
|
+
or raises ValueError if no items match.
|
211
|
+
|
212
|
+
Args:
|
213
|
+
*criteria: Callable criteria or (key, value) tuples
|
214
|
+
for filtering.
|
215
|
+
**kwargs: Additional key-value pairs for filtering.
|
216
|
+
|
217
|
+
Returns:
|
218
|
+
I: The first item that matches the criteria.
|
219
|
+
|
220
|
+
Raises:
|
221
|
+
ValueError: If no items match the criteria.
|
222
|
+
|
223
|
+
"""
|
224
|
+
if items := self.filter(*criteria, **kwargs):
|
225
|
+
return items[0]
|
226
|
+
|
227
|
+
raise _value_error()
|
228
|
+
|
229
|
+
def last(
|
230
|
+
self,
|
231
|
+
*criteria: Callable[[I], bool] | tuple[str, Any],
|
232
|
+
**kwargs: Any,
|
233
|
+
) -> I:
|
234
|
+
"""Get the last item matching the specified criteria.
|
235
|
+
|
236
|
+
This method applies filters and returns the last matching item,
|
237
|
+
or raises ValueError if no items match.
|
238
|
+
|
239
|
+
Args:
|
240
|
+
*criteria: Callable criteria or (key, value) tuples
|
241
|
+
for filtering.
|
242
|
+
**kwargs: Additional key-value pairs for filtering.
|
243
|
+
|
244
|
+
Returns:
|
245
|
+
I: The last item that matches the criteria.
|
246
|
+
|
247
|
+
Raises:
|
248
|
+
ValueError: If no items match the criteria.
|
249
|
+
|
250
|
+
"""
|
251
|
+
if items := self.filter(*criteria, **kwargs):
|
252
|
+
return items[-1]
|
253
|
+
|
254
|
+
raise _value_error()
|
255
|
+
|
256
|
+
def to_list(
|
257
|
+
self,
|
258
|
+
key: str,
|
259
|
+
default: Any | Callable[[I], Any] = MISSING,
|
260
|
+
) -> list[Any]:
|
261
|
+
"""Extract a list of values for a specific key from all items.
|
262
|
+
|
263
|
+
Args:
|
264
|
+
key: The key to extract from each item.
|
265
|
+
default: The default value to return if the key is not found.
|
266
|
+
If a callable, it will be called with the item
|
267
|
+
and the value returned will be used as the default.
|
268
|
+
|
269
|
+
Returns:
|
270
|
+
list[Any]: A list containing the values for the
|
271
|
+
specified key from each item.
|
272
|
+
|
273
|
+
"""
|
274
|
+
return [self._get(i, key, default) for i in self]
|
275
|
+
|
276
|
+
def to_numpy(
|
277
|
+
self,
|
278
|
+
key: str,
|
279
|
+
default: Any | Callable[[I], Any] = MISSING,
|
280
|
+
) -> NDArray:
|
281
|
+
"""Extract values for a specific key from all items as a NumPy array.
|
282
|
+
|
283
|
+
Args:
|
284
|
+
key: The key to extract from each item.
|
285
|
+
default: The default value to return if the key is not found.
|
286
|
+
If a callable, it will be called with the item
|
287
|
+
and the value returned will be used as the default.
|
288
|
+
|
289
|
+
Returns:
|
290
|
+
NDArray: A NumPy array containing the values for the
|
291
|
+
specified key from each item.
|
292
|
+
|
293
|
+
"""
|
294
|
+
return np.array(self.to_list(key, default))
|
295
|
+
|
296
|
+
def to_series(
|
297
|
+
self,
|
298
|
+
key: str,
|
299
|
+
default: Any = MISSING,
|
300
|
+
*,
|
301
|
+
name: str | None = None,
|
302
|
+
) -> Series:
|
303
|
+
"""Extract values for a specific key from all items as a Polars series.
|
304
|
+
|
305
|
+
Args:
|
306
|
+
key: The key to extract from each item.
|
307
|
+
default: The default value to return if the key is not found.
|
308
|
+
If a callable, it will be called with the item
|
309
|
+
and the value returned will be used as the default.
|
310
|
+
name: The name of the series. If not provided, the key will be used.
|
311
|
+
|
312
|
+
Returns:
|
313
|
+
Series: A Polars series containing the values for the
|
314
|
+
specified key from each item.
|
315
|
+
|
316
|
+
"""
|
317
|
+
return Series(name or key, self.to_list(key, default))
|
318
|
+
|
319
|
+
def unique(
|
320
|
+
self,
|
321
|
+
key: str,
|
322
|
+
default: Any | Callable[[I], Any] = MISSING,
|
323
|
+
) -> NDArray:
|
324
|
+
"""Get the unique values for a specific key across all items.
|
325
|
+
|
326
|
+
Args:
|
327
|
+
key: The key to extract unique values for.
|
328
|
+
default: The default value to return if the key is not found.
|
329
|
+
If a callable, it will be called with the item
|
330
|
+
and the value returned will be used as the default.
|
331
|
+
|
332
|
+
Returns:
|
333
|
+
NDArray: A NumPy array containing the unique values for the
|
334
|
+
specified key.
|
335
|
+
|
336
|
+
"""
|
337
|
+
return np.unique(self.to_numpy(key, default), axis=0)
|
338
|
+
|
339
|
+
def n_unique(
|
340
|
+
self,
|
341
|
+
key: str,
|
342
|
+
default: Any | Callable[[I], Any] = MISSING,
|
343
|
+
) -> int:
|
344
|
+
"""Count the number of unique values for a specific key across all items.
|
345
|
+
|
346
|
+
Args:
|
347
|
+
key: The key to count unique values for.
|
348
|
+
default: The default value to return if the key is not found.
|
349
|
+
If a callable, it will be called with the item
|
350
|
+
and the value returned will be used as the default.
|
351
|
+
|
352
|
+
Returns:
|
353
|
+
int: The number of unique values for the specified key.
|
354
|
+
|
355
|
+
"""
|
356
|
+
return len(self.unique(key, default))
|
357
|
+
|
358
|
+
def sort(self, *keys: str, reverse: bool = False) -> Self:
|
359
|
+
"""Sort items based on one or more keys.
|
360
|
+
|
361
|
+
Args:
|
362
|
+
*keys: The keys to sort by, in order of priority.
|
363
|
+
reverse: Whether to sort in descending order (default is
|
364
|
+
ascending).
|
365
|
+
|
366
|
+
Returns:
|
367
|
+
Self: A new Collection with the items sorted according to
|
368
|
+
the specified keys.
|
369
|
+
|
370
|
+
"""
|
371
|
+
if not keys:
|
372
|
+
return self
|
373
|
+
|
374
|
+
arrays = [self.to_numpy(key) for key in keys]
|
375
|
+
index = np.lexsort(arrays[::-1])
|
376
|
+
|
377
|
+
if reverse:
|
378
|
+
index = index[::-1]
|
379
|
+
|
380
|
+
return self[index]
|
381
|
+
|
382
|
+
def map[**P, R](
|
383
|
+
self,
|
384
|
+
function: Callable[Concatenate[I, P], R],
|
385
|
+
*args: P.args,
|
386
|
+
**kwargs: P.kwargs,
|
387
|
+
) -> Iterator[R]:
|
388
|
+
"""Apply a function to each item and return an iterator of results.
|
389
|
+
|
390
|
+
This is a memory-efficient mapping operation that lazily evaluates results.
|
391
|
+
Ideal for large collections where memory usage is a concern.
|
392
|
+
|
393
|
+
Args:
|
394
|
+
function: Function to apply to each item. The item is passed
|
395
|
+
as the first argument.
|
396
|
+
*args: Additional positional arguments to pass to the function.
|
397
|
+
**kwargs: Additional keyword arguments to pass to the function.
|
398
|
+
|
399
|
+
Returns:
|
400
|
+
Iterator[R]: An iterator of the function's results.
|
401
|
+
|
402
|
+
Examples:
|
403
|
+
```python
|
404
|
+
# Process results one at a time
|
405
|
+
for result in collection.map(process_item, additional_arg):
|
406
|
+
handle_result(result)
|
407
|
+
|
408
|
+
# Convert to list if needed
|
409
|
+
results = list(collection.map(transform_item))
|
410
|
+
```
|
411
|
+
|
412
|
+
"""
|
413
|
+
yield from (function(i, *args, **kwargs) for i in self)
|
414
|
+
|
415
|
+
def pmap[**P, R](
|
416
|
+
self,
|
417
|
+
function: Callable[Concatenate[I, P], R],
|
418
|
+
n_jobs: int = -1,
|
419
|
+
backend: str = "multiprocessing",
|
420
|
+
*args: P.args,
|
421
|
+
**kwargs: P.kwargs,
|
422
|
+
) -> list[R]:
|
423
|
+
"""Apply a function to each item in parallel and return a list of results.
|
424
|
+
|
425
|
+
This method processes items concurrently for improved performance on
|
426
|
+
CPU-bound or I/O-bound operations, depending on the backend.
|
427
|
+
|
428
|
+
Args:
|
429
|
+
function: Function to apply to each item. The item is passed
|
430
|
+
as the first argument.
|
431
|
+
n_jobs (int): Number of jobs to run in parallel. -1 means using all
|
432
|
+
processors.
|
433
|
+
backend (str): Parallelization backend.
|
434
|
+
*args: Additional positional arguments to pass to the function.
|
435
|
+
**kwargs: Additional keyword arguments to pass to the function.
|
436
|
+
|
437
|
+
Returns:
|
438
|
+
list[R]: A list containing all results of the function applications.
|
439
|
+
|
440
|
+
Examples:
|
441
|
+
```python
|
442
|
+
# Process all items in parallel using all cores
|
443
|
+
results = collection.pmap(heavy_computation)
|
444
|
+
|
445
|
+
# Specify number of parallel jobs and backend
|
446
|
+
results = collection.pmap(process_files, n_jobs=4, backend="threading")
|
447
|
+
```
|
448
|
+
|
449
|
+
"""
|
450
|
+
parallel = Parallel(n_jobs=n_jobs, backend=backend, return_as="list")
|
451
|
+
return parallel(delayed(function)(i, *args, **kwargs) for i in self) # type: ignore
|
452
|
+
|
453
|
+
def to_frame(
|
454
|
+
self,
|
455
|
+
*keys: str,
|
456
|
+
defaults: dict[str, Any | Callable[[I], Any]] | None = None,
|
457
|
+
**kwargs: Callable[[I], Any],
|
458
|
+
) -> DataFrame:
|
459
|
+
"""Convert the collection to a Polars DataFrame.
|
460
|
+
|
461
|
+
Args:
|
462
|
+
*keys (str): The keys to include as columns in the DataFrame.
|
463
|
+
defaults (dict[str, Any | Callable[[T], Any]] | None): Default
|
464
|
+
values for the keys. If a callable, it will be called with
|
465
|
+
the item and the value returned will be used as the
|
466
|
+
default.
|
467
|
+
**kwargs (Callable[[I], Any]): Additional columns to compute
|
468
|
+
using callables that take an item and return a value.
|
469
|
+
|
470
|
+
Returns:
|
471
|
+
DataFrame: A Polars DataFrame containing the specified data
|
472
|
+
from the items.
|
473
|
+
|
474
|
+
"""
|
475
|
+
if defaults is None:
|
476
|
+
defaults = {}
|
477
|
+
|
478
|
+
data = {k: self.to_list(k, defaults.get(k, MISSING)) for k in keys}
|
479
|
+
df = DataFrame(data)
|
480
|
+
|
481
|
+
if not kwargs:
|
482
|
+
return df
|
483
|
+
|
484
|
+
columns = [Series(k, self.map(v)) for k, v in kwargs.items()]
|
485
|
+
return df.with_columns(*columns)
|
486
|
+
|
487
|
+
def group_by(self, *by: str) -> GroupBy[Self, I]:
|
488
|
+
"""Group items by one or more keys and return a GroupBy instance.
|
489
|
+
|
490
|
+
This method organizes items into groups based on the specified
|
491
|
+
keys and returns a GroupBy instance that contains the grouped
|
492
|
+
collections. The GroupBy instance behaves like a dictionary,
|
493
|
+
allowing access to collections for each group key.
|
494
|
+
|
495
|
+
Args:
|
496
|
+
*by: The keys to group by. If a single key is provided,
|
497
|
+
its value will be used as the group key.
|
498
|
+
If multiple keys are provided, a tuple of their
|
499
|
+
values will be used as the group key.
|
500
|
+
Keys can use dot notation (e.g., "model.type")
|
501
|
+
to access nested configuration values.
|
502
|
+
|
503
|
+
Returns:
|
504
|
+
GroupBy[Self, I]: A GroupBy instance containing the grouped items.
|
505
|
+
Each group is a collection of the same type as the original.
|
506
|
+
|
507
|
+
"""
|
508
|
+
groups: dict[Any, Self] = {}
|
509
|
+
|
510
|
+
for item in self:
|
511
|
+
keys = [to_hashable(self._get(item, key, MISSING)) for key in by]
|
512
|
+
key = keys[0] if len(by) == 1 else tuple(keys)
|
513
|
+
|
514
|
+
if key not in groups:
|
515
|
+
groups[key] = self.__class__([], self._get)
|
516
|
+
|
517
|
+
groups[key]._items.append(item) # noqa: SLF001
|
518
|
+
|
519
|
+
return GroupBy(by, groups)
|
520
|
+
|
521
|
+
|
522
|
+
def to_hashable(value: Any) -> Hashable:
    """Convert a value to a hashable instance.

    The conversion rules are tried in this order:

    - OmegaConf list: converted to a tuple of its elements
    - Already hashable: returned unchanged
    - NumPy array: converted through ``tolist()``; nested lists become
      nested tuples and a 0-d array yields its scalar
    - Any other iterable: converted to a tuple, recursively converting
      only those elements that are not themselves hashable
    - Anything else: its ``str()`` representation

    Note:
        "Already hashable" is decided with ``collections.abc.Hashable``,
        which only checks that the type defines ``__hash__``. A tuple
        containing a list therefore passes through unchanged.

    Args:
        value: The value to convert to a hashable instance.

    Returns:
        A hashable version of the input value.

    """
    if OmegaConf.is_list(value):  # Is ListConfig hashable?
        return tuple(value)

    if isinstance(value, Hashable):
        return value

    if isinstance(value, np.ndarray):
        # tolist() gives (possibly nested) lists, or a plain scalar for a
        # 0-d array; the recursive call handles both.
        return to_hashable(value.tolist())

    try:
        # Hashable elements are kept as they are so that results which were
        # already valid keys do not change; only unhashable elements (nested
        # lists, arrays, ...) are converted, which is what makes the
        # returned tuple usable as a dictionary key.
        return tuple(
            item if isinstance(item, Hashable) else to_hashable(item)
            for item in value
        )
    except TypeError:
        # Not iterable: fall back to the string representation.
        return str(value)
|
545
|
+
|
546
|
+
|
547
|
+
def _value_error() -> ValueError:
|
548
|
+
msg = "No item found matching the specified criteria"
|
549
|
+
return ValueError(msg)
|
550
|
+
|
551
|
+
|
552
|
+
def matches(value: Any, criterion: Any) -> bool:
    """Check if a value matches the given criterion.

    The rules are tried in this order:

    - Callable criterion: it is called with the value and the truthiness
      of its result is returned
    - ``ListConfig`` criterion: treated as a plain list from here on
    - List or set criterion, non-iterable value: membership test
    - Tuple of length 2, non-iterable value: inclusive range test
      ``criterion[0] <= value <= criterion[1]``
    - Otherwise: equality, after converting an iterable criterion and/or
      an iterable value to a list

    Strings are never considered iterable for these rules.

    Args:
        value: The value to be compared with the criterion.
        criterion: The criterion to match against (callable, list, set,
            2-tuple range, or any other value for equality).

    Returns:
        bool: True if the value matches the criterion according to the rules above,
        False otherwise.

    Examples:
        >>> matches(5, lambda x: x > 3)
        True
        >>> matches(2, [1, 2, 3])
        True
        >>> matches(4, (1, 5))
        True
        >>> matches(3, 3)
        True

    """
    if callable(criterion):
        return bool(criterion(value))

    if isinstance(criterion, ListConfig):
        criterion = list(criterion)

    value_is_scalar = not _is_iterable(value)

    if value_is_scalar:
        # Membership and range tests only make sense for a single value.
        if isinstance(criterion, (list, set)):
            return value in criterion

        if isinstance(criterion, tuple) and len(criterion) == 2:
            lower, upper = criterion
            return lower <= value <= upper

    # Equality: normalise iterables on both sides to lists so that, for
    # example, a tuple and a list with the same elements compare equal.
    expected = list(criterion) if _is_iterable(criterion) else criterion
    actual = value if value_is_scalar else list(value)

    return actual == expected
|
610
|
+
|
611
|
+
|
612
|
+
def _is_iterable(value: Any) -> bool:
|
613
|
+
return isinstance(value, Iterable) and not isinstance(value, str)
|
hydraflow/core/context.py
CHANGED
@@ -128,13 +128,12 @@ def chdir_artifact(run: Run) -> Iterator[Path]:
|
|
128
128
|
run (Run | None): The run to get the artifact directory from.
|
129
129
|
|
130
130
|
"""
|
131
|
-
|
131
|
+
current_dir = Path.cwd()
|
132
132
|
artifact_dir = get_artifact_dir(run)
|
133
133
|
|
134
|
-
os.chdir(artifact_dir)
|
135
|
-
|
136
134
|
try:
|
135
|
+
os.chdir(artifact_dir)
|
137
136
|
yield artifact_dir
|
138
137
|
|
139
138
|
finally:
|
140
|
-
os.chdir(
|
139
|
+
os.chdir(current_dir)
|