hydraflow 0.17.2__py3-none-any.whl → 0.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydraflow/core/collection.py +320 -16
- hydraflow/core/main.py +18 -1
- hydraflow/core/run.py +33 -6
- hydraflow/core/run_collection.py +2 -2
- hydraflow/utils/__init__.py +0 -0
- hydraflow/utils/progress.py +90 -0
- {hydraflow-0.17.2.dist-info → hydraflow-0.18.0.dist-info}/METADATA +1 -2
- {hydraflow-0.17.2.dist-info → hydraflow-0.18.0.dist-info}/RECORD +11 -9
- {hydraflow-0.17.2.dist-info → hydraflow-0.18.0.dist-info}/WHEEL +0 -0
- {hydraflow-0.17.2.dist-info → hydraflow-0.18.0.dist-info}/entry_points.txt +0 -0
- {hydraflow-0.17.2.dist-info → hydraflow-0.18.0.dist-info}/licenses/LICENSE +0 -0
hydraflow/core/collection.py
CHANGED
@@ -2,6 +2,8 @@
|
|
2
2
|
|
3
3
|
from __future__ import annotations
|
4
4
|
|
5
|
+
import random
|
6
|
+
import re
|
5
7
|
from collections.abc import Hashable, Iterable, Sequence
|
6
8
|
from dataclasses import MISSING
|
7
9
|
from typing import TYPE_CHECKING, Concatenate, overload
|
@@ -417,6 +419,7 @@ class Collection[I](Sequence[I]):
|
|
417
419
|
function: Callable[Concatenate[I, P], R],
|
418
420
|
n_jobs: int = -1,
|
419
421
|
backend: str = "multiprocessing",
|
422
|
+
progress: bool = False,
|
420
423
|
*args: P.args,
|
421
424
|
**kwargs: P.kwargs,
|
422
425
|
) -> list[R]:
|
@@ -431,6 +434,7 @@ class Collection[I](Sequence[I]):
|
|
431
434
|
n_jobs (int): Number of jobs to run in parallel. -1 means using all
|
432
435
|
processors.
|
433
436
|
backend (str): Parallelization backend.
|
437
|
+
progress (bool): Whether to display a progress bar.
|
434
438
|
*args: Additional positional arguments to pass to the function.
|
435
439
|
**kwargs: Additional keyword arguments to pass to the function.
|
436
440
|
|
@@ -448,40 +452,82 @@ class Collection[I](Sequence[I]):
|
|
448
452
|
|
449
453
|
"""
|
450
454
|
parallel = Parallel(n_jobs=n_jobs, backend=backend, return_as="list")
|
451
|
-
|
455
|
+
it = (delayed(function)(i, *args, **kwargs) for i in self)
|
456
|
+
|
457
|
+
if not progress:
|
458
|
+
return parallel(it) # type: ignore
|
459
|
+
|
460
|
+
from hydraflow.utils.progress import Progress
|
461
|
+
|
462
|
+
with Progress(*Progress.get_default_columns()) as p:
|
463
|
+
p.add_task("", total=len(self))
|
464
|
+
return parallel(it) # type: ignore
|
452
465
|
|
453
466
|
def to_frame(
    self,
    *keys: str | tuple[str, Any | Callable[[I], Any]],
    defaults: dict[str, Any | Callable[[I], Any]] | None = None,
    n_jobs: int = 0,
    backend: str = "multiprocessing",
    progress: bool = False,
    **kwargs: Callable[[I], Any],
) -> DataFrame:
    """Convert the collection to a Polars DataFrame.

    Each key becomes one column whose values are obtained with
    `self.to_list(key, default)`. A key may be given as a plain string or
    as a `(key, default)` tuple; a tuple default takes precedence over an
    entry for the same key in `defaults`.

    Args:
        *keys (str | tuple[str, Any | Callable[[I], Any]]): The keys to include
            as columns in the DataFrame. If a tuple is provided, the first
            element is the key and the second element is the default value.
        defaults (dict[str, Any | Callable[[I], Any]] | None): Default values
            for the keys. If a callable, it will be called with the item and the
            value returned will be used as the default. The dictionary passed
            by the caller is never modified.
        n_jobs (int): Number of jobs to run in parallel. 0 means no
            parallelization. Defaults to 0.
        backend (str): Parallelization backend. Only used when `n_jobs`
            is not 0.
        progress (bool): Whether to display a progress bar. Only used when
            `n_jobs` is not 0.
        **kwargs (Callable[[I], Any]): Additional columns to compute using
            callables that take an item and return a value.

    Returns:
        DataFrame: A Polars DataFrame containing the specified data from the items.

    Examples:
        ```python
        # Convert to DataFrame with single keys
        df = collection.to_frame("name", "age")

        # Convert to DataFrame with keys and default values
        df = collection.to_frame(("name", "Unknown"), ("age", 0))
        ```

    """
    # Work on a private copy: tuple keys add entries below, and writing them
    # into the caller's dictionary would leak state between calls.
    defaults = dict(defaults) if defaults else {}

    keys_ = []
    for k in keys:
        if isinstance(k, tuple):
            keys_.append(k[0])
            defaults[k[0]] = k[1]
        else:
            keys_.append(k)

    data = {k: self.to_list(k, defaults.get(k, MISSING)) for k in keys_}
    df = DataFrame(data)

    if not kwargs:
        return df

    kv = kwargs.items()
    if n_jobs == 0:
        # Sequential path: evaluate each callable with `self.map`.
        return df.with_columns(Series(k, self.map(v)) for k, v in kv)

    # Parallel path: one `pmap` call per computed column.
    columns = [Series(k, self.pmap(v, n_jobs, backend, progress)) for k, v in kv]
    return df.with_columns(*columns)
|
486
532
|
|
487
533
|
def group_by(self, *by: str) -> GroupBy[Self, I]:
|
@@ -518,6 +564,266 @@ class Collection[I](Sequence[I]):
|
|
518
564
|
|
519
565
|
return GroupBy(by, groups)
|
520
566
|
|
567
|
+
def sample(self, k: int, seed: int | None = None) -> Self:
    """Sample a random subset of items from the collection.

    This method returns a new collection containing a random sample
    of items from the original collection. The sample is drawn without
    replacement, meaning each item can only appear once in the sample.

    Args:
        k (int): The number of items to sample. Must satisfy
            `1 <= k <= len(self)`.
        seed (int | None): The seed for the random number generator.
            If provided, the sample will be reproducible. Seeding is
            local to this call and leaves the state of the global
            `random` module untouched.

    Returns:
        Self: A new collection containing a random sample of items.

    Raises:
        ValueError: If the sample size is smaller than 1 or greater than
            the collection size.

    """
    n = len(self)
    if k < 1 or k > n:
        msg = f"Sample size ({k}) must be between 1 and {n}"
        raise ValueError(msg)

    # A dedicated generator yields the same sample as `random.seed(seed)`
    # followed by `random.sample(...)`, without reseeding the process-wide
    # generator that unrelated code may rely on.
    rng = random if seed is None else random.Random(seed)

    return self.__class__(rng.sample(self._items, k), self._get)
|
595
|
+
|
596
|
+
def shuffle(self, seed: int | None = None) -> Self:
    """Shuffle the items in the collection.

    This method returns a new collection with the items in random order.
    An empty collection yields a new empty collection.

    Args:
        seed (int | None): The seed for the random number generator.
            If provided, the resulting order will be reproducible.

    Returns:
        Self: A new collection containing the items in random order.

    """
    n = len(self)
    if n == 0:
        # `sample` rejects a sample size below 1, so an empty collection
        # cannot be shuffled through it; there is nothing to reorder anyway.
        return self.__class__([], self._get)

    return self.sample(n, seed)
|
610
|
+
|
611
|
+
def eq(
    self,
    left: str,
    right: str,
    *,
    default: Any | Callable[[I], Any] = MISSING,
) -> Callable[[I], bool]:
    """Build a predicate that tests whether two attributes of an item are equal.

    Both values are looked up through the collection's getter with the
    same `default`.

    Args:
        left (str): Name of the attribute on the left-hand side.
        right (str): Name of the attribute on the right-hand side.
        default (Any | Callable[[I], Any], optional): Fallback handed to
            the getter for both lookups when an attribute is not found.
            If callable, it will be called with the item.

    Returns:
        Callable[[I], bool]: A function that takes an item and returns
        True if the two attribute values compare equal.

    Examples:
        ```python
        # Keep items whose 'a' equals their 'b'
        equal_items = collection.filter(collection.eq('a', 'b'))
        ```

    """

    def predicate(item: I) -> bool:
        lhs = self._get(item, left, default)
        rhs = self._get(item, right, default)
        return lhs == rhs

    return predicate
|
639
|
+
|
640
|
+
def ne(
    self,
    left: str,
    right: str,
    *,
    default: Any | Callable[[I], Any] = MISSING,
) -> Callable[[I], bool]:
    """Build a predicate that tests whether two attributes of an item differ.

    Both values are looked up through the collection's getter with the
    same `default`.

    Args:
        left (str): Name of the attribute on the left-hand side.
        right (str): Name of the attribute on the right-hand side.
        default (Any | Callable[[I], Any], optional): Fallback handed to
            the getter for both lookups when an attribute is not found.
            If callable, it will be called with the item.

    Returns:
        Callable[[I], bool]: A function that takes an item and returns
        True if the two attribute values compare unequal.

    Examples:
        ```python
        # Keep items whose 'a' differs from their 'b'
        unequal_items = collection.filter(collection.ne('a', 'b'))
        ```

    """

    def predicate(item: I) -> bool:
        lhs = self._get(item, left, default)
        rhs = self._get(item, right, default)
        return lhs != rhs

    return predicate
|
668
|
+
|
669
|
+
def gt(
    self,
    left: str,
    right: str,
    *,
    default: Any | Callable[[I], Any] = MISSING,
) -> Callable[[I], bool]:
    """Build a predicate that tests whether the left attribute > the right.

    Both values are looked up through the collection's getter with the
    same `default`.

    Args:
        left (str): Name of the attribute on the left-hand side.
        right (str): Name of the attribute on the right-hand side.
        default (Any | Callable[[I], Any], optional): Fallback handed to
            the getter for both lookups when an attribute is not found.
            If callable, it will be called with the item.

    Returns:
        Callable[[I], bool]: A function that takes an item and returns
        True if the left value is greater than the right value.

    Examples:
        ```python
        # Keep items whose 'a' is greater than their 'b'
        items = collection.filter(collection.gt('a', 'b'))
        ```

    """

    def predicate(item: I) -> bool:
        lhs = self._get(item, left, default)
        rhs = self._get(item, right, default)
        return lhs > rhs

    return predicate
|
697
|
+
|
698
|
+
def lt(
    self,
    left: str,
    right: str,
    *,
    default: Any | Callable[[I], Any] = MISSING,
) -> Callable[[I], bool]:
    """Build a predicate that tests whether the left attribute < the right.

    Both values are looked up through the collection's getter with the
    same `default`.

    Args:
        left (str): Name of the attribute on the left-hand side.
        right (str): Name of the attribute on the right-hand side.
        default (Any | Callable[[I], Any], optional): Fallback handed to
            the getter for both lookups when an attribute is not found.
            If callable, it will be called with the item.

    Returns:
        Callable[[I], bool]: A function that takes an item and returns
        True if the left value is less than the right value.

    Examples:
        ```python
        # Keep items whose 'a' is less than their 'b'
        items = collection.filter(collection.lt('a', 'b'))
        ```

    """

    def predicate(item: I) -> bool:
        lhs = self._get(item, left, default)
        rhs = self._get(item, right, default)
        return lhs < rhs

    return predicate
|
726
|
+
|
727
|
+
def ge(
    self,
    left: str,
    right: str,
    *,
    default: Any | Callable[[I], Any] = MISSING,
) -> Callable[[I], bool]:
    """Build a predicate that tests whether the left attribute >= the right.

    Both values are looked up through the collection's getter with the
    same `default`.

    Args:
        left (str): Name of the attribute on the left-hand side.
        right (str): Name of the attribute on the right-hand side.
        default (Any | Callable[[I], Any], optional): Fallback handed to
            the getter for both lookups.

    Returns:
        Callable[[I], bool]: A predicate function for filtering.

    """

    def predicate(item: I) -> bool:
        lhs = self._get(item, left, default)
        rhs = self._get(item, right, default)
        return lhs >= rhs

    return predicate
|
746
|
+
|
747
|
+
def le(
    self,
    left: str,
    right: str,
    *,
    default: Any | Callable[[I], Any] = MISSING,
) -> Callable[[I], bool]:
    """Build a predicate that tests whether the left attribute <= the right.

    Both values are looked up through the collection's getter with the
    same `default`.

    Args:
        left (str): Name of the attribute on the left-hand side.
        right (str): Name of the attribute on the right-hand side.
        default (Any | Callable[[I], Any], optional): Fallback handed to
            the getter for both lookups.

    Returns:
        Callable[[I], bool]: A predicate function for filtering.

    """

    def predicate(item: I) -> bool:
        lhs = self._get(item, left, default)
        rhs = self._get(item, right, default)
        return lhs <= rhs

    return predicate
|
766
|
+
|
767
|
+
def startswith(
    self,
    key: str,
    prefix: str,
    *,
    default: Any | Callable[[I], Any] = MISSING,
) -> Callable[[I], bool]:
    """Build a predicate that tests whether an attribute starts with a prefix.

    The attribute value is converted with `str()` before the check, so
    non-string values are compared by their string form.

    Args:
        key (str): Name of the attribute to inspect.
        prefix (str): The prefix to look for.
        default (Any | Callable[[I], Any], optional): Fallback handed to
            the getter for the lookup.

    Returns:
        Callable[[I], bool]: A predicate function for filtering.

    """

    def predicate(item: I) -> bool:
        text = str(self._get(item, key, default))
        return text.startswith(prefix)

    return predicate
|
786
|
+
|
787
|
+
def endswith(
    self,
    key: str,
    suffix: str,
    *,
    default: Any | Callable[[I], Any] = MISSING,
) -> Callable[[I], bool]:
    """Build a predicate that tests whether an attribute ends with a suffix.

    The attribute value is converted with `str()` before the check, so
    non-string values are compared by their string form.

    Args:
        key (str): Name of the attribute to inspect.
        suffix (str): The suffix to look for.
        default (Any | Callable[[I], Any], optional): Fallback handed to
            the getter for the lookup.

    Returns:
        Callable[[I], bool]: A predicate function for filtering.

    """

    def predicate(item: I) -> bool:
        text = str(self._get(item, key, default))
        return text.endswith(suffix)

    return predicate
|
806
|
+
|
807
|
+
def match(
    self,
    key: str,
    pattern: str | re.Pattern,
    *,
    default: Any | Callable[[I], Any] = MISSING,
) -> Callable[[I], bool]:
    """Create a predicate function that checks if an attribute matches a pattern.

    The attribute value is converted with `str()` and matched from the
    start of the string (`re.match` semantics, not a search).

    Args:
        key (str): The name of the attribute to check.
        pattern (str | re.Pattern): The regular expression to match. A
            string is compiled once, when the predicate is created.
        default (Any | Callable[[I], Any], optional): The default value
            handed to the getter for the lookup.

    Returns:
        Callable[[I], bool]: A predicate function for filtering.

    Raises:
        re.error: If `pattern` is a string that is not a valid regular
            expression.

    Examples:
        ```python
        # Find items whose 'name' begins with "run_" followed by digits
        items = collection.filter(collection.match('name', r'run_\\d+'))
        ```

    """
    # Compile up front so the pattern is resolved once rather than on every
    # item, and so an invalid pattern fails here instead of mid-filter.
    regex = re.compile(pattern)

    return lambda i: regex.match(str(self._get(i, key, default))) is not None
|
826
|
+
|
521
827
|
|
522
828
|
def to_hashable(value: Any) -> Hashable:
|
523
829
|
"""Convert a value to a hashable instance.
|
@@ -565,9 +871,7 @@ def matches(value: Any, criterion: Any) -> bool:
|
|
565
871
|
|
566
872
|
Args:
|
567
873
|
value: The value to be compared with the criterion.
|
568
|
-
criterion: The criterion to match against.
|
569
|
-
Can be:
|
570
|
-
|
874
|
+
criterion: The criterion to match against. Can be:
|
571
875
|
- A callable that takes the value and returns a boolean
|
572
876
|
- A list or set to check membership
|
573
877
|
- A tuple of length 2 to check range inclusion
|
hydraflow/core/main.py
CHANGED
@@ -35,6 +35,8 @@ Example:
|
|
35
35
|
|
36
36
|
from __future__ import annotations
|
37
37
|
|
38
|
+
import logging
|
39
|
+
import sys
|
38
40
|
from functools import wraps
|
39
41
|
from pathlib import Path
|
40
42
|
from typing import TYPE_CHECKING
|
@@ -53,6 +55,8 @@ if TYPE_CHECKING:
|
|
53
55
|
|
54
56
|
from mlflow.entities import Run
|
55
57
|
|
58
|
+
log = logging.getLogger("hydraflow")
|
59
|
+
|
56
60
|
|
57
61
|
def main[C](
|
58
62
|
node: C | type[C],
|
@@ -62,6 +66,7 @@ def main[C](
|
|
62
66
|
force_new_run: bool = False,
|
63
67
|
match_overrides: bool = False,
|
64
68
|
rerun_finished: bool = False,
|
69
|
+
dry_run: bool = False,
|
65
70
|
update: Callable[[C], C | None] | None = None,
|
66
71
|
):
|
67
72
|
"""Decorator for configuring and running MLflow experiments with Hydra.
|
@@ -81,6 +86,9 @@ def main[C](
|
|
81
86
|
instead of full config. Defaults to False.
|
82
87
|
rerun_finished: If True, allows rerunning completed runs. Defaults to
|
83
88
|
False.
|
89
|
+
dry_run: If True, starts the hydra job but does not run the application
|
90
|
+
itself. This allows users to preview the configuration and
|
91
|
+
settings without executing the actual run. Defaults to False.
|
84
92
|
update: A function that takes a configuration and returns a new
|
85
93
|
configuration or None. The function can modify the configuration in-place
|
86
94
|
and/or return it. If the function returns None, the original (potentially
|
@@ -93,6 +101,10 @@ def main[C](
|
|
93
101
|
import mlflow
|
94
102
|
from mlflow.entities import RunStatus
|
95
103
|
|
104
|
+
if "--dry-run" in sys.argv:
|
105
|
+
dry_run = True
|
106
|
+
sys.argv.remove("--dry-run")
|
107
|
+
|
96
108
|
finished = RunStatus.to_string(RunStatus.FINISHED)
|
97
109
|
|
98
110
|
def decorator(app: Callable[[Run, C], None]) -> Callable[[], None]:
|
@@ -102,7 +114,6 @@ def main[C](
|
|
102
114
|
@wraps(app)
|
103
115
|
def inner_decorator(cfg: C) -> None:
|
104
116
|
hc = HydraConfig.get()
|
105
|
-
experiment = mlflow.set_experiment(hc.job.name)
|
106
117
|
|
107
118
|
if update:
|
108
119
|
if cfg_ := update(cfg):
|
@@ -112,6 +123,12 @@ def main[C](
|
|
112
123
|
cfg_path = hydra_dir.joinpath("config.yaml")
|
113
124
|
OmegaConf.save(cfg, cfg_path)
|
114
125
|
|
126
|
+
if dry_run:
|
127
|
+
log.info("Dry run:\n%s", OmegaConf.to_yaml(cfg).rstrip())
|
128
|
+
return
|
129
|
+
|
130
|
+
experiment = mlflow.set_experiment(hc.job.name)
|
131
|
+
|
115
132
|
if force_new_run:
|
116
133
|
run_id = None
|
117
134
|
else:
|
hydraflow/core/run.py
CHANGED
@@ -31,6 +31,7 @@ from functools import cached_property
|
|
31
31
|
from pathlib import Path
|
32
32
|
from typing import TYPE_CHECKING, cast, overload
|
33
33
|
|
34
|
+
import polars as pl
|
34
35
|
from omegaconf import DictConfig, OmegaConf
|
35
36
|
|
36
37
|
from .run_info import RunInfo
|
@@ -39,6 +40,9 @@ if TYPE_CHECKING:
|
|
39
40
|
from collections.abc import Iterator
|
40
41
|
from typing import Any, Self
|
41
42
|
|
43
|
+
from polars import Expr
|
44
|
+
from polars._typing import PolarsDataType
|
45
|
+
|
42
46
|
from .run_collection import RunCollection
|
43
47
|
|
44
48
|
|
@@ -216,9 +220,7 @@ class Run[C, I = None]:
|
|
216
220
|
(can use dot notation like "section.subsection.param"),
|
217
221
|
or a tuple of strings to set multiple related configuration
|
218
222
|
values at once.
|
219
|
-
value: The value to set.
|
220
|
-
This can be:
|
221
|
-
|
223
|
+
value: The value to set. This can be:
|
222
224
|
- For string keys: Any value, or a callable that returns
|
223
225
|
a value
|
224
226
|
- For tuple keys: An iterable with the same length as the
|
@@ -264,9 +266,7 @@ class Run[C, I = None]:
|
|
264
266
|
|
265
267
|
Args:
|
266
268
|
key: The key to look for. Can use dot notation for
|
267
|
-
nested keys in configuration.
|
268
|
-
Special keys:
|
269
|
-
|
269
|
+
nested keys in configuration. Special keys:
|
270
270
|
- "cfg": Returns the configuration object
|
271
271
|
- "impl": Returns the implementation object
|
272
272
|
- "info": Returns the run information object
|
@@ -314,6 +314,33 @@ class Run[C, I = None]:
|
|
314
314
|
msg = f"No such key: {key}"
|
315
315
|
raise AttributeError(msg)
|
316
316
|
|
317
|
+
def lit(
    self,
    key: str,
    default: Any | Callable[[Self], Any] = MISSING,
    *,
    dtype: PolarsDataType | None = None,
) -> Expr:
    """Build a Polars literal expression holding the value of a run key.

    The value is fetched with `self.get(key, default)`, wrapped with
    `pl.lit`, and the resulting expression is aliased to `key`.

    Args:
        key (str): The key to look up on the run.
        default (Any | Callable[[Run], Any], optional): Default value to
            use if the key is missing. If a callable is provided, it will be
            called with the Run instance.
        dtype (PolarsDataType | None): Explicit data type for the literal
            expression.

    Returns:
        Expr: A Polars literal expression aliased to the provided key.

    Raises:
        AttributeError: If the key is not found and no default is provided.

    """
    literal = pl.lit(self.get(key, default), dtype)
    return literal.alias(key)
|
343
|
+
|
317
344
|
def to_dict(self, flatten: bool = True) -> dict[str, Any]:
|
318
345
|
"""Convert the Run to a dictionary.
|
319
346
|
|
hydraflow/core/run_collection.py
CHANGED
@@ -194,7 +194,7 @@ class RunCollection[R: Run[Any, Any], I = None](Collection[R]):
|
|
194
194
|
|
195
195
|
"""
|
196
196
|
for run in self:
|
197
|
-
yield from run.
|
197
|
+
yield from run.iterdir(relative_dir)
|
198
198
|
|
199
199
|
def glob(self, pattern: str, relative_dir: str = "") -> Iterator[Path]:
|
200
200
|
"""Glob the artifact directories for all runs in the collection.
|
@@ -212,4 +212,4 @@ class RunCollection[R: Run[Any, Any], I = None](Collection[R]):
|
|
212
212
|
|
213
213
|
"""
|
214
214
|
for run in self:
|
215
|
-
yield from run.
|
215
|
+
yield from run.glob(pattern, relative_dir)
|
File without changes
|
@@ -0,0 +1,90 @@
|
|
1
|
+
"""Provide a progress bar for parallel task execution.
|
2
|
+
|
3
|
+
This module defines the `Progress` class, which provides a visual
|
4
|
+
progress bar for tracking the execution of parallel tasks. It integrates
|
5
|
+
with the `joblib` library to display the progress of tasks being executed
|
6
|
+
in parallel, allowing users to monitor the completion status in real-time.
|
7
|
+
|
8
|
+
The `Progress` class can be customized to show different columns of
|
9
|
+
information, such as the elapsed time and the number of completed tasks.
|
10
|
+
It also provides methods to start and stop the progress display, as well
|
11
|
+
as to update the progress based on the number of completed tasks.
|
12
|
+
|
13
|
+
Example:
|
14
|
+
```python
|
15
|
+
from hydraflow.utils.progress import Progress
|
16
|
+
from joblib import Parallel, delayed
|
17
|
+
|
18
|
+
with Progress(*Progress.get_default_columns()) as progress:
|
19
|
+
Parallel(n_jobs=4)(delayed(function)(x) for x in iterable)
|
20
|
+
```
|
21
|
+
|
22
|
+
"""
|
23
|
+
|
24
|
+
from __future__ import annotations
|
25
|
+
|
26
|
+
from typing import TYPE_CHECKING
|
27
|
+
|
28
|
+
from joblib.parallel import Parallel
|
29
|
+
from rich.progress import MofNCompleteColumn, SpinnerColumn, TimeElapsedColumn
|
30
|
+
from rich.progress import Progress as Super
|
31
|
+
|
32
|
+
if TYPE_CHECKING:
|
33
|
+
from collections.abc import Callable
|
34
|
+
|
35
|
+
from rich.progress import ProgressColumn
|
36
|
+
|
37
|
+
|
38
|
+
# https://github.com/jonghwanhyeon/joblib-progress/blob/main/joblib_progress/__init__.py
|
39
|
+
class Progress(Super):
    """A progress bar for tracking parallel task execution.

    This class extends the `rich.progress.Progress` class to provide
    a visual progress bar specifically designed for monitoring the
    execution of tasks running in parallel using the `joblib` library.

    While the display is started, `joblib.parallel.Parallel.print_progress`
    is replaced at class level by a hook that forwards to the module-level
    `update` function; the original method is restored by `stop`.
    """

    # The original `Parallel.print_progress`, saved while the hook is
    # installed. `None` means the hook is not currently installed.
    _print_progress: Callable[[Parallel], None] | None = None

    def start(self) -> None:
        """Start the progress display and install the joblib hook.

        Calling `start` again before `stop` leaves the installed hook
        unchanged.
        """
        super().start()

        if self._print_progress is not None:
            # Already installed. Saving `Parallel.print_progress` again
            # would record our own hook as the "original", which makes
            # `update` call itself forever and leaves `Parallel` patched
            # after `stop`.
            return

        self._print_progress = Parallel.print_progress

        def _update(parallel: Parallel) -> None:
            update(self, parallel)

        Parallel.print_progress = _update  # type: ignore

    def stop(self) -> None:
        """Restore the original joblib method and stop the progress display."""
        if self._print_progress:
            Parallel.print_progress = self._print_progress  # type: ignore
            # Mark the hook as removed so a later `start` installs it afresh.
            self._print_progress = None

        super().stop()

    @classmethod
    def get_default_columns(cls) -> tuple[ProgressColumn, ...]:
        """Get the default columns used for a new Progress instance.

        Returns:
            tuple[ProgressColumn, ...]: A spinner and elapsed-time column,
            followed by the parent class defaults and an "M of N" column.
        """
        return (
            SpinnerColumn(),
            TimeElapsedColumn(),
            *Super.get_default_columns(),
            MofNCompleteColumn(),
        )
|
78
|
+
|
79
|
+
|
80
|
+
def update(progress: Progress, parallel: Parallel) -> None:
    """Synchronise the progress bar with the state of a `Parallel` run.

    The most recently added task is updated; when the progress instance
    has no task yet, one with an unknown total is created first. The
    saved original `print_progress` callable, if any, is then invoked.

    Args:
        progress (Progress): The progress display to refresh.
        parallel (Parallel): The running `Parallel` instance whose
            `n_completed_tasks` is reported.

    """
    known_ids = progress.task_ids
    target = known_ids[-1] if known_ids else progress.add_task("", total=None)

    progress.update(target, completed=parallel.n_completed_tasks, refresh=True)

    original = progress._print_progress  # noqa: SLF001
    if original:
        original(parallel)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: hydraflow
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.18.0
|
4
4
|
Summary: HydraFlow seamlessly integrates Hydra and MLflow to streamline ML experiment management, combining Hydra's configuration management with MLflow's tracking capabilities.
|
5
5
|
Project-URL: Documentation, https://daizutabi.github.io/hydraflow/
|
6
6
|
Project-URL: Source, https://github.com/daizutabi/hydraflow
|
@@ -47,7 +47,6 @@ Requires-Dist: omegaconf>=2.3
|
|
47
47
|
Requires-Dist: polars>=1.26
|
48
48
|
Requires-Dist: python-ulid>=3.0.0
|
49
49
|
Requires-Dist: rich>=13.9
|
50
|
-
Requires-Dist: ruff>=0.11
|
51
50
|
Requires-Dist: typer>=0.15
|
52
51
|
Description-Content-Type: text/markdown
|
53
52
|
|
@@ -2,13 +2,13 @@ hydraflow/__init__.py,sha256=_cLLokEv0pUlwvG8RMnjOwCTtDQBs0-RgGbtDk5m_Xg,794
|
|
2
2
|
hydraflow/cli.py,sha256=3rGr___wwp8KazjLGQ7JO_IgAMqLyMlcVSs_QJK7g0Y,3135
|
3
3
|
hydraflow/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
4
|
hydraflow/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
5
|
-
hydraflow/core/collection.py,sha256=
|
5
|
+
hydraflow/core/collection.py,sha256=wiwbzv8p2psdoV2U4K3ROUIm8jS6fg7b_R34qBl0BaY,30249
|
6
6
|
hydraflow/core/context.py,sha256=6vpwe0Xfl6mzh2hHLE-4uB9Hjew-CK4pA0KFihQ80U8,4168
|
7
7
|
hydraflow/core/group_by.py,sha256=Pnw-oA5aXHeRG9lMLz-bKc8drqQ8LIRsWzvVn153iyQ,5488
|
8
8
|
hydraflow/core/io.py,sha256=B3-jPuJWttRgpbIpy_XA-Z2qpXzNF1ATwyYEwA7Pv3w,5172
|
9
|
-
hydraflow/core/main.py,sha256=
|
10
|
-
hydraflow/core/run.py,sha256=
|
11
|
-
hydraflow/core/run_collection.py,sha256=
|
9
|
+
hydraflow/core/main.py,sha256=6X58M-xpJPql7xqLR3gpa4XMePvZ6Q1diMSqTZf2Jrw,6542
|
10
|
+
hydraflow/core/run.py,sha256=rt97EXJ72fdMsuh4AKL-dxRfnQQ4Yzo2ZpTocTR5Wr8,15034
|
11
|
+
hydraflow/core/run_collection.py,sha256=JgCdu0_hvgAoeScdDuXQrEWKQhCzoL-v-g53cR9Sm_c,6759
|
12
12
|
hydraflow/core/run_info.py,sha256=SMOTZXEa7OBV_XjTyctk5gJGrggmYwhePvRF8CLF1kU,1616
|
13
13
|
hydraflow/executor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
14
14
|
hydraflow/executor/aio.py,sha256=xXsmBPIPdBlopv_1h0FdtOvoKUcuW7PQeKCV2d_lN9I,2122
|
@@ -16,8 +16,10 @@ hydraflow/executor/conf.py,sha256=8Xq4UAenRKJIl1NBgNbSfv6VUTJhdwPLayZIEAsiBR0,41
|
|
16
16
|
hydraflow/executor/io.py,sha256=18wnHpCMQRGYL-oN2841h9W2aSW_X2SmO68Lx-3FIbU,1043
|
17
17
|
hydraflow/executor/job.py,sha256=6QeJ18OMeocXeM04rCYL46GgArfX1SvZs9_4HTomTgE,5436
|
18
18
|
hydraflow/executor/parser.py,sha256=RxP8qpDaJ8VLqZ51VlPFyVitWctObhkE_3iPIsY66Cs,14610
|
19
|
-
hydraflow
|
20
|
-
hydraflow
|
21
|
-
hydraflow-0.
|
22
|
-
hydraflow-0.
|
23
|
-
hydraflow-0.
|
19
|
+
hydraflow/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
20
|
+
hydraflow/utils/progress.py,sha256=a-CHvioyGCeiUKawqPcV8i1nhzunm5-r5AlLbzd5epw,3048
|
21
|
+
hydraflow-0.18.0.dist-info/METADATA,sha256=Hd6O1Ah70Gitxtfav0w40d18RgCnIvZ7J0e1OjDBczY,7509
|
22
|
+
hydraflow-0.18.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
23
|
+
hydraflow-0.18.0.dist-info/entry_points.txt,sha256=XI0khPbpCIUo9UPqkNEpgh-kqK3Jy8T7L2VCWOdkbSM,48
|
24
|
+
hydraflow-0.18.0.dist-info/licenses/LICENSE,sha256=IGdDrBPqz1O0v_UwCW-NJlbX9Hy9b3uJ11t28y2srmY,1062
|
25
|
+
hydraflow-0.18.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|