hydraflow-0.1.5-py3-none-any.whl → hydraflow-0.2.1-py3-none-any.whl
This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- hydraflow/__init__.py +0 -10
- hydraflow/config.py +29 -17
- hydraflow/context.py +31 -19
- hydraflow/mlflow.py +23 -0
- hydraflow/runs.py +213 -303
- {hydraflow-0.1.5.dist-info → hydraflow-0.2.1.dist-info}/METADATA +1 -1
- hydraflow-0.2.1.dist-info/RECORD +9 -0
- hydraflow/util.py +0 -24
- hydraflow-0.1.5.dist-info/RECORD +0 -10
- {hydraflow-0.1.5.dist-info → hydraflow-0.2.1.dist-info}/WHEEL +0 -0
- {hydraflow-0.1.5.dist-info → hydraflow-0.2.1.dist-info}/licenses/LICENSE +0 -0
    
        hydraflow/__init__.py
    CHANGED
    
    | @@ -3,15 +3,10 @@ from .mlflow import set_experiment | |
| 3 3 | 
             
            from .runs import (
         | 
| 4 4 | 
             
                Run,
         | 
| 5 5 | 
             
                Runs,
         | 
| 6 | 
            -
                drop_unique_params,
         | 
| 7 6 | 
             
                filter_runs,
         | 
| 8 | 
            -
                get_artifact_dir,
         | 
| 9 | 
            -
                get_artifact_path,
         | 
| 10 | 
            -
                get_artifact_uri,
         | 
| 11 7 | 
             
                get_param_dict,
         | 
| 12 8 | 
             
                get_param_names,
         | 
| 13 9 | 
             
                get_run,
         | 
| 14 | 
            -
                get_run_id,
         | 
| 15 10 | 
             
                load_config,
         | 
| 16 11 | 
             
            )
         | 
| 17 12 |  | 
| @@ -20,15 +15,10 @@ __all__ = [ | |
| 20 15 | 
             
                "Run",
         | 
| 21 16 | 
             
                "Runs",
         | 
| 22 17 | 
             
                "chdir_artifact",
         | 
| 23 | 
            -
                "drop_unique_params",
         | 
| 24 18 | 
             
                "filter_runs",
         | 
| 25 | 
            -
                "get_artifact_dir",
         | 
| 26 | 
            -
                "get_artifact_path",
         | 
| 27 | 
            -
                "get_artifact_uri",
         | 
| 28 19 | 
             
                "get_param_dict",
         | 
| 29 20 | 
             
                "get_param_names",
         | 
| 30 21 | 
             
                "get_run",
         | 
| 31 | 
            -
                "get_run_id",
         | 
| 32 22 | 
             
                "load_config",
         | 
| 33 23 | 
             
                "log_run",
         | 
| 34 24 | 
             
                "set_experiment",
         | 
    
        hydraflow/config.py
    CHANGED
    
    | @@ -16,39 +16,51 @@ if TYPE_CHECKING: | |
| 16 16 |  | 
| 17 17 | 
             
            def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
         | 
| 18 18 | 
             
                """
         | 
| 19 | 
            -
                 | 
| 19 | 
            +
                Recursively iterate over the parameters in the given configuration object.
         | 
| 20 20 |  | 
| 21 | 
            -
                This function  | 
| 22 | 
            -
                 | 
| 21 | 
            +
                This function traverses the configuration object and yields key-value pairs
         | 
| 22 | 
            +
                representing the parameters. The keys are prefixed with the provided prefix.
         | 
| 23 23 |  | 
| 24 24 | 
             
                Args:
         | 
| 25 | 
            -
                    config | 
| 26 | 
            -
             | 
| 27 | 
            -
             | 
| 25 | 
            +
                    config: The configuration object to iterate over. This can be a dictionary,
         | 
| 26 | 
            +
                        list, DictConfig, or ListConfig.
         | 
| 27 | 
            +
                    prefix: The prefix to prepend to the parameter keys.
         | 
| 28 | 
            +
                        Defaults to an empty string.
         | 
| 28 29 |  | 
| 29 30 | 
             
                Yields:
         | 
| 30 | 
            -
                    Key-value pairs representing the parameters.
         | 
| 31 | 
            +
                    Key-value pairs representing the parameters in the configuration object.
         | 
| 31 32 | 
             
                """
         | 
| 32 33 | 
             
                if not isinstance(config, (DictConfig, ListConfig)):
         | 
| 33 34 | 
             
                    config = OmegaConf.create(config)  # type: ignore
         | 
| 34 35 |  | 
| 36 | 
            +
                yield from _iter_params(config, prefix)
         | 
| 37 | 
            +
             | 
| 38 | 
            +
             | 
| 39 | 
            +
            def _iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
         | 
| 35 40 | 
             
                if isinstance(config, DictConfig):
         | 
| 36 41 | 
             
                    for key, value in config.items():
         | 
| 37 | 
            -
                        if  | 
| 38 | 
            -
                            isinstance(v, (DictConfig, ListConfig)) for v in value
         | 
| 39 | 
            -
                        ):
         | 
| 42 | 
            +
                        if _is_param(value):
         | 
| 40 43 | 
             
                            yield f"{prefix}{key}", value
         | 
| 41 44 |  | 
| 42 | 
            -
                        elif isinstance(value, (DictConfig, ListConfig)):
         | 
| 43 | 
            -
                            yield from iter_params(value, f"{prefix}{key}.")
         | 
| 44 | 
            -
             | 
| 45 45 | 
             
                        else:
         | 
| 46 | 
            -
                            yield f"{prefix}{key}", value
         | 
| 46 | 
            +
                            yield from _iter_params(value, f"{prefix}{key}.")
         | 
| 47 47 |  | 
| 48 48 | 
             
                elif isinstance(config, ListConfig):
         | 
| 49 49 | 
             
                    for index, value in enumerate(config):
         | 
| 50 | 
            -
                        if  | 
| 51 | 
            -
                            yield  | 
| 50 | 
            +
                        if _is_param(value):
         | 
| 51 | 
            +
                            yield f"{prefix}{index}", value
         | 
| 52 52 |  | 
| 53 53 | 
             
                        else:
         | 
| 54 | 
            -
                            yield f"{prefix}{index}", value
         | 
| 54 | 
            +
                            yield from _iter_params(value, f"{prefix}{index}.")
         | 
| 55 | 
            +
             | 
| 56 | 
            +
             | 
| 57 | 
            +
            def _is_param(value: object) -> bool:
         | 
| 58 | 
            +
                """Check if the given value is a parameter."""
         | 
| 59 | 
            +
                if isinstance(value, DictConfig):
         | 
| 60 | 
            +
                    return False
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                if isinstance(value, ListConfig):
         | 
| 63 | 
            +
                    if any(isinstance(v, (DictConfig, ListConfig)) for v in value):
         | 
| 64 | 
            +
                        return False
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                return True
         | 
    
        hydraflow/context.py
    CHANGED
    
    | @@ -5,6 +5,7 @@ run context. | |
| 5 5 |  | 
| 6 6 | 
             
            from __future__ import annotations
         | 
| 7 7 |  | 
| 8 | 
            +
            import logging
         | 
| 8 9 | 
             
            import os
         | 
| 9 10 | 
             
            import time
         | 
| 10 11 | 
             
            from contextlib import contextmanager
         | 
| @@ -17,15 +18,14 @@ from hydra.core.hydra_config import HydraConfig | |
| 17 18 | 
             
            from watchdog.events import FileModifiedEvent, FileSystemEventHandler
         | 
| 18 19 | 
             
            from watchdog.observers import Observer
         | 
| 19 20 |  | 
| 20 | 
            -
            from hydraflow.mlflow import log_params
         | 
| 21 | 
            -
            from hydraflow.runs import get_artifact_path
         | 
| 22 | 
            -
            from hydraflow.util import uri_to_path
         | 
| 21 | 
            +
            from hydraflow.mlflow import get_artifact_dir, log_params
         | 
| 23 22 |  | 
| 24 23 | 
             
            if TYPE_CHECKING:
         | 
| 25 24 | 
             
                from collections.abc import Callable, Iterator
         | 
| 26 25 |  | 
| 27 26 | 
             
                from mlflow.entities.run import Run
         | 
| 28 | 
            -
             | 
| 27 | 
            +
             | 
| 28 | 
            +
            log = logging.getLogger(__name__)
         | 
| 29 29 |  | 
| 30 30 |  | 
| 31 31 | 
             
            @dataclass
         | 
| @@ -66,8 +66,7 @@ def log_run( | |
| 66 66 |  | 
| 67 67 | 
             
                hc = HydraConfig.get()
         | 
| 68 68 | 
             
                output_dir = Path(hc.runtime.output_dir)
         | 
| 69 | 
            -
                 | 
| 70 | 
            -
                info = Info(output_dir, uri_to_path(uri))
         | 
| 69 | 
            +
                info = Info(output_dir, get_artifact_dir())
         | 
| 71 70 |  | 
| 72 71 | 
             
                # Save '.hydra' config directory first.
         | 
| 73 72 | 
             
                output_subdir = output_dir / (hc.output_subdir or "")
         | 
| @@ -81,13 +80,21 @@ def log_run( | |
| 81 80 | 
             
                    with watch(log_artifact, output_dir):
         | 
| 82 81 | 
             
                        yield info
         | 
| 83 82 |  | 
| 83 | 
            +
                except Exception as e:
         | 
| 84 | 
            +
                    log.error(f"Error during log_run: {e}")
         | 
| 85 | 
            +
                    raise
         | 
| 86 | 
            +
             | 
| 84 87 | 
             
                finally:
         | 
| 85 88 | 
             
                    # Save output_dir including '.hydra' config directory.
         | 
| 86 89 | 
             
                    mlflow.log_artifacts(output_dir.as_posix())
         | 
| 87 90 |  | 
| 88 91 |  | 
| 89 92 | 
             
            @contextmanager
         | 
| 90 | 
            -
            def watch(func: Callable[[Path], None], dir: Path | str = "", timeout: int = 60) -> Iterator[None]:
         | 
| 93 | 
            +
            def watch(
         | 
| 94 | 
            +
                func: Callable[[Path], None],
         | 
| 95 | 
            +
                dir: Path | str = "",
         | 
| 96 | 
            +
                timeout: int = 60,
         | 
| 97 | 
            +
            ) -> Iterator[None]:
         | 
| 91 98 | 
             
                """
         | 
| 92 99 | 
             
                Watch the given directory for changes and call the provided function
         | 
| 93 100 | 
             
                when a change is detected.
         | 
| @@ -98,25 +105,23 @@ def watch(func: Callable[[Path], None], dir: Path | str = "", timeout: int = 60) | |
| 98 105 | 
             
                period or until the context is exited.
         | 
| 99 106 |  | 
| 100 107 | 
             
                Args:
         | 
| 101 | 
            -
                    func | 
| 108 | 
            +
                    func: The function to call when a change is
         | 
| 102 109 | 
             
                        detected. It should accept a single argument of type `Path`,
         | 
| 103 110 | 
             
                        which is the path of the modified file.
         | 
| 104 | 
            -
                    dir | 
| 111 | 
            +
                    dir: The directory to watch. If not specified,
         | 
| 105 112 | 
             
                        the current MLflow artifact URI is used. Defaults to "".
         | 
| 106 | 
            -
                    timeout | 
| 113 | 
            +
                    timeout: The timeout period in seconds for the watcher
         | 
| 107 114 | 
             
                        to run after the context is exited. Defaults to 60.
         | 
| 108 115 |  | 
| 109 116 | 
             
                Yields:
         | 
| 110 | 
            -
                    None | 
| 117 | 
            +
                    None
         | 
| 111 118 |  | 
| 112 119 | 
             
                Example:
         | 
| 113 120 | 
             
                    with watch(log_artifact, "/path/to/dir"):
         | 
| 114 121 | 
             
                        # Perform operations while watching the directory for changes
         | 
| 115 122 | 
             
                        pass
         | 
| 116 123 | 
             
                """
         | 
| 117 | 
            -
                 | 
| 118 | 
            -
                    uri = mlflow.get_artifact_uri()
         | 
| 119 | 
            -
                    dir = uri_to_path(uri)
         | 
| 124 | 
            +
                dir = dir or get_artifact_dir()
         | 
| 120 125 |  | 
| 121 126 | 
             
                handler = Handler(func)
         | 
| 122 127 | 
             
                observer = Observer()
         | 
| @@ -126,6 +131,10 @@ def watch(func: Callable[[Path], None], dir: Path | str = "", timeout: int = 60) | |
| 126 131 | 
             
                try:
         | 
| 127 132 | 
             
                    yield
         | 
| 128 133 |  | 
| 134 | 
            +
                except Exception as e:
         | 
| 135 | 
            +
                    log.error(f"Error during watch: {e}")
         | 
| 136 | 
            +
                    raise
         | 
| 137 | 
            +
             | 
| 129 138 | 
             
                finally:
         | 
| 130 139 | 
             
                    elapsed = 0
         | 
| 131 140 | 
             
                    while not observer.event_queue.empty():
         | 
| @@ -150,7 +159,7 @@ class Handler(FileSystemEventHandler): | |
| 150 159 |  | 
| 151 160 | 
             
            @contextmanager
         | 
| 152 161 | 
             
            def chdir_artifact(
         | 
| 153 | 
            -
                run: Run | 
| 162 | 
            +
                run: Run,
         | 
| 154 163 | 
             
                artifact_path: str | None = None,
         | 
| 155 164 | 
             
            ) -> Iterator[Path]:
         | 
| 156 165 | 
             
                """
         | 
| @@ -166,11 +175,14 @@ def chdir_artifact( | |
| 166 175 | 
             
                    artifact_path: The artifact path.
         | 
| 167 176 | 
             
                """
         | 
| 168 177 | 
             
                curdir = Path.cwd()
         | 
| 178 | 
            +
                path = mlflow.artifacts.download_artifacts(
         | 
| 179 | 
            +
                    run_id=run.info.run_id,
         | 
| 180 | 
            +
                    artifact_path=artifact_path,
         | 
| 181 | 
            +
                )
         | 
| 169 182 |  | 
| 170 | 
            -
                 | 
| 171 | 
            -
             | 
| 172 | 
            -
                os.chdir(artifact_dir)
         | 
| 183 | 
            +
                os.chdir(path)
         | 
| 173 184 | 
             
                try:
         | 
| 174 | 
            -
                    yield  | 
| 185 | 
            +
                    yield Path(path)
         | 
| 186 | 
            +
             | 
| 175 187 | 
             
                finally:
         | 
| 176 188 | 
             
                    os.chdir(curdir)
         | 
    
        hydraflow/mlflow.py
    CHANGED
    
    | @@ -5,6 +5,8 @@ configuration objects and set up experiments using MLflow. | |
| 5 5 |  | 
| 6 6 | 
             
            from __future__ import annotations
         | 
| 7 7 |  | 
| 8 | 
            +
            from pathlib import Path
         | 
| 9 | 
            +
             | 
| 8 10 | 
             
            import mlflow
         | 
| 9 11 | 
             
            from hydra.core.hydra_config import HydraConfig
         | 
| 10 12 |  | 
| @@ -47,3 +49,24 @@ def log_params(config: object, *, synchronous: bool | None = None) -> None: | |
| 47 49 | 
             
                """
         | 
| 48 50 | 
             
                for key, value in iter_params(config):
         | 
| 49 51 | 
             
                    mlflow.log_param(key, value, synchronous=synchronous)
         | 
| 52 | 
            +
             | 
| 53 | 
            +
             | 
| 54 | 
            +
            def get_artifact_dir(artifact_path: str | None = None) -> Path:
         | 
| 55 | 
            +
                """
         | 
| 56 | 
            +
                Get the artifact directory for the given artifact path.
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                This function retrieves the artifact URI for the specified artifact path
         | 
| 59 | 
            +
                using MLflow, downloads the artifacts to a local directory, and returns
         | 
| 60 | 
            +
                the path to that directory.
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                Args:
         | 
| 63 | 
            +
                    artifact_path: The artifact path for which to get the directory.
         | 
| 64 | 
            +
                        Defaults to None.
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                Returns:
         | 
| 67 | 
            +
                    The local path to the directory where the artifacts are downloaded.
         | 
| 68 | 
            +
                """
         | 
| 69 | 
            +
                uri = mlflow.get_artifact_uri(artifact_path)
         | 
| 70 | 
            +
                dir = mlflow.artifacts.download_artifacts(artifact_uri=uri)
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                return Path(dir)
         | 
    
        hydraflow/runs.py
    CHANGED
    
    | @@ -1,30 +1,85 @@ | |
| 1 1 | 
             
            """
         | 
| 2 2 | 
             
            This module provides functionality for managing and interacting with MLflow runs.
         | 
| 3 | 
            -
            It includes  | 
| 4 | 
            -
            log artifacts and configurations.
         | 
| 3 | 
            +
            It includes the `Runs` class and various methods to filter runs, retrieve run information,
         | 
| 4 | 
            +
            log artifacts, and load configurations.
         | 
| 5 5 | 
             
            """
         | 
| 6 6 |  | 
| 7 7 | 
             
            from __future__ import annotations
         | 
| 8 8 |  | 
| 9 9 | 
             
            from dataclasses import dataclass
         | 
| 10 10 | 
             
            from functools import cache
         | 
| 11 | 
            -
            from  | 
| 11 | 
            +
            from itertools import chain
         | 
| 12 12 | 
             
            from typing import TYPE_CHECKING, Any
         | 
| 13 13 |  | 
| 14 14 | 
             
            import mlflow
         | 
| 15 | 
            -
             | 
| 16 | 
            -
            from mlflow.entities.run import Run | 
| 17 | 
            -
            from mlflow.tracking import  | 
| 15 | 
            +
            from mlflow.entities import ViewType
         | 
| 16 | 
            +
            from mlflow.entities.run import Run
         | 
| 17 | 
            +
            from mlflow.tracking.fluent import SEARCH_MAX_RESULTS_PANDAS
         | 
| 18 18 | 
             
            from omegaconf import DictConfig, OmegaConf
         | 
| 19 | 
            -
            from pandas import DataFrame, Series
         | 
| 20 19 |  | 
| 21 20 | 
             
            from hydraflow.config import iter_params
         | 
| 22 | 
            -
            from hydraflow.util import uri_to_path
         | 
| 23 21 |  | 
| 24 22 | 
             
            if TYPE_CHECKING:
         | 
| 25 23 | 
             
                from typing import Any
         | 
| 26 24 |  | 
| 27 25 |  | 
| 26 | 
            +
            def search_runs(
         | 
| 27 | 
            +
                experiment_ids: list[str] | None = None,
         | 
| 28 | 
            +
                filter_string: str = "",
         | 
| 29 | 
            +
                run_view_type: int = ViewType.ACTIVE_ONLY,
         | 
| 30 | 
            +
                max_results: int = SEARCH_MAX_RESULTS_PANDAS,
         | 
| 31 | 
            +
                order_by: list[str] | None = None,
         | 
| 32 | 
            +
                search_all_experiments: bool = False,
         | 
| 33 | 
            +
                experiment_names: list[str] | None = None,
         | 
| 34 | 
            +
            ) -> Runs:
         | 
| 35 | 
            +
                """
         | 
| 36 | 
            +
                Search for Runs that fit the specified criteria.
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                This function wraps the `mlflow.search_runs` function and returns the results
         | 
| 39 | 
            +
                as a `Runs` object. It allows for flexible searching of MLflow runs based on
         | 
| 40 | 
            +
                various criteria.
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                Args:
         | 
| 43 | 
            +
                    experiment_ids: List of experiment IDs. Search can work with experiment IDs or
         | 
| 44 | 
            +
                        experiment names, but not both in the same call. Values other than
         | 
| 45 | 
            +
                        ``None`` or ``[]`` will result in error if ``experiment_names`` is
         | 
| 46 | 
            +
                        also not ``None`` or ``[]``. ``None`` will default to the active
         | 
| 47 | 
            +
                        experiment if ``experiment_names`` is ``None`` or ``[]``.
         | 
| 48 | 
            +
                    filter_string: Filter query string, defaults to searching all runs.
         | 
| 49 | 
            +
                    run_view_type: one of enum values ``ACTIVE_ONLY``, ``DELETED_ONLY``, or ``ALL`` runs
         | 
| 50 | 
            +
                        defined in :py:class:`mlflow.entities.ViewType`.
         | 
| 51 | 
            +
                    max_results: The maximum number of runs to put in the dataframe. Default is 100,000
         | 
| 52 | 
            +
                        to avoid causing out-of-memory issues on the user's machine.
         | 
| 53 | 
            +
                    order_by: List of columns to order by (e.g., "metrics.rmse"). The ``order_by`` column
         | 
| 54 | 
            +
                        can contain an optional ``DESC`` or ``ASC`` value. The default is ``ASC``.
         | 
| 55 | 
            +
                        The default ordering is to sort by ``start_time DESC``, then ``run_id``.
         | 
| 56 | 
            +
                    output_format: The output format to be returned. If ``pandas``, a ``pandas.DataFrame``
         | 
| 57 | 
            +
                        is returned and, if ``list``, a list of :py:class:`mlflow.entities.Run`
         | 
| 58 | 
            +
                        is returned.
         | 
| 59 | 
            +
                    search_all_experiments: Boolean specifying whether all experiments should be searched.
         | 
| 60 | 
            +
                        Only honored if ``experiment_ids`` is ``[]`` or ``None``.
         | 
| 61 | 
            +
                    experiment_names: List of experiment names. Search can work with experiment IDs or
         | 
| 62 | 
            +
                        experiment names, but not both in the same call. Values other
         | 
| 63 | 
            +
                        than ``None`` or ``[]`` will result in error if ``experiment_ids``
         | 
| 64 | 
            +
                        is also not ``None`` or ``[]``. ``None`` will default to the active
         | 
| 65 | 
            +
                        experiment if ``experiment_ids`` is ``None`` or ``[]``.
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                Returns:
         | 
| 68 | 
            +
                    A `Runs` object containing the search results.
         | 
| 69 | 
            +
                """
         | 
| 70 | 
            +
                runs = mlflow.search_runs(
         | 
| 71 | 
            +
                    experiment_ids=experiment_ids,
         | 
| 72 | 
            +
                    filter_string=filter_string,
         | 
| 73 | 
            +
                    run_view_type=run_view_type,
         | 
| 74 | 
            +
                    max_results=max_results,
         | 
| 75 | 
            +
                    order_by=order_by,
         | 
| 76 | 
            +
                    output_format="list",
         | 
| 77 | 
            +
                    search_all_experiments=search_all_experiments,
         | 
| 78 | 
            +
                    experiment_names=experiment_names,
         | 
| 79 | 
            +
                )
         | 
| 80 | 
            +
                return Runs(runs)  # type: ignore
         | 
| 81 | 
            +
             | 
| 82 | 
            +
             | 
| 28 83 | 
             
            @dataclass
         | 
| 29 84 | 
             
            class Runs:
         | 
| 30 85 | 
             
                """
         | 
| @@ -34,7 +89,7 @@ class Runs: | |
| 34 89 | 
             
                retrieving specific runs, and accessing run information.
         | 
| 35 90 | 
             
                """
         | 
| 36 91 |  | 
| 37 | 
            -
                runs: list[ | 
| 92 | 
            +
                runs: list[Run]
         | 
| 38 93 |  | 
| 39 94 | 
             
                def __repr__(self) -> str:
         | 
| 40 95 | 
             
                    return f"{self.__class__.__name__}({len(self)})"
         | 
| @@ -53,115 +108,95 @@ class Runs: | |
| 53 108 | 
             
                    be included in the returned `Runs` object.
         | 
| 54 109 |  | 
| 55 110 | 
             
                    Args:
         | 
| 56 | 
            -
                        config | 
| 57 | 
            -
                            This object should contain key-value pairs representing
         | 
| 58 | 
            -
                            the parameters to filter by.
         | 
| 111 | 
            +
                        config: The configuration object to filter the runs.
         | 
| 59 112 |  | 
| 60 113 | 
             
                    Returns:
         | 
| 61 | 
            -
                         | 
| 114 | 
            +
                        A new `Runs` object containing the filtered runs.
         | 
| 62 115 | 
             
                    """
         | 
| 63 116 | 
             
                    return Runs(filter_runs(self.runs, config))
         | 
| 64 117 |  | 
| 65 | 
            -
                def get(self, config: object) -> Run:
         | 
| 118 | 
            +
                def get(self, config: object) -> Run | None:
         | 
| 66 119 | 
             
                    """
         | 
| 67 120 | 
             
                    Retrieve a specific run based on the provided configuration.
         | 
| 68 121 |  | 
| 69 122 | 
             
                    This method filters the runs in the collection according to the
         | 
| 70 123 | 
             
                    specified configuration object and returns the run that matches
         | 
| 71 124 | 
             
                    the provided parameters. If more than one run matches the criteria,
         | 
| 72 | 
            -
                     | 
| 125 | 
            +
                    a `ValueError` is raised.
         | 
| 73 126 |  | 
| 74 127 | 
             
                    Args:
         | 
| 75 | 
            -
                        config | 
| 128 | 
            +
                        config: The configuration object to identify the run.
         | 
| 76 129 |  | 
| 77 130 | 
             
                    Returns:
         | 
| 78 131 | 
             
                        Run: The run object that matches the provided configuration.
         | 
| 132 | 
            +
                        None, if the runs are not in a DataFrame format.
         | 
| 79 133 |  | 
| 80 134 | 
             
                    Raises:
         | 
| 81 135 | 
             
                        ValueError: If the number of filtered runs is not exactly one.
         | 
| 82 136 | 
             
                    """
         | 
| 83 | 
            -
                    return  | 
| 137 | 
            +
                    return get_run(self.runs, config)
         | 
| 84 138 |  | 
| 85 | 
            -
                def  | 
| 139 | 
            +
                def get_earliest_run(self, config: object | None = None, **kwargs) -> Run | None:
         | 
| 86 140 | 
             
                    """
         | 
| 87 | 
            -
                     | 
| 141 | 
            +
                    Get the earliest run from the list of runs based on the start time.
         | 
| 88 142 |  | 
| 89 | 
            -
                    This method  | 
| 90 | 
            -
                     | 
| 91 | 
            -
             | 
| 143 | 
            +
                    This method filters the runs based on the configuration if provided
         | 
| 144 | 
            +
                    and returns the run with the earliest start time.
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                    Args:
         | 
| 147 | 
            +
                        config: The configuration object to filter the runs.
         | 
| 148 | 
            +
                            If None, no filtering is applied.
         | 
| 149 | 
            +
                        **kwargs: Additional key-value pairs to filter the runs.
         | 
| 92 150 |  | 
| 93 151 | 
             
                    Returns:
         | 
| 94 | 
            -
                         | 
| 152 | 
            +
                        The run with the earliest start time, or None if no runs match the criteria.
         | 
| 153 | 
            +
                    """
         | 
| 154 | 
            +
                    return get_earliest_run(self.runs, config, **kwargs)
         | 
| 95 155 |  | 
| 96 | 
            -
             | 
| 97 | 
            -
                        NotImplementedError: If the runs are not in a DataFrame format.
         | 
| 156 | 
            +
                def get_latest_run(self, config: object | None = None, **kwargs) -> Run | None:
         | 
| 98 157 | 
             
                    """
         | 
| 99 | 
            -
                     | 
| 100 | 
            -
                        return Runs(drop_unique_params(self.runs))
         | 
| 158 | 
            +
                    Get the latest run from the list of runs based on the start time.
         | 
| 101 159 |  | 
| 102 | 
            -
                     | 
| 160 | 
            +
                    Args:
         | 
| 161 | 
            +
                        config: The configuration object to filter the runs.
         | 
| 162 | 
            +
                            If None, no filtering is applied.
         | 
| 163 | 
            +
                        **kwargs: Additional key-value pairs to filter the runs.
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                    Returns:
         | 
| 166 | 
            +
                        The run with the latest start time, or None if no runs match the criteria.
         | 
| 167 | 
            +
                    """
         | 
| 168 | 
            +
                    return get_latest_run(self.runs, config, **kwargs)
         | 
| 103 169 |  | 
| 104 170 | 
             
                def get_param_names(self) -> list[str]:
         | 
| 105 171 | 
             
                    """
         | 
| 106 172 | 
             
                    Get the parameter names from the runs.
         | 
| 107 173 |  | 
| 108 | 
            -
                    This method extracts the parameter names from the  | 
| 109 | 
            -
                     | 
| 110 | 
            -
                     | 
| 174 | 
            +
                    This method extracts the unique parameter names from the provided list of runs.
         | 
| 175 | 
            +
                    It iterates through each run and collects the parameter names into a set to
         | 
| 176 | 
            +
                    ensure uniqueness.
         | 
| 111 177 |  | 
| 112 178 | 
             
                    Returns:
         | 
| 113 | 
            -
                         | 
| 114 | 
            -
             | 
| 115 | 
            -
                    Raises:
         | 
| 116 | 
            -
                        NotImplementedError: If the runs are not in a DataFrame format.
         | 
| 179 | 
            +
                        A list of unique parameter names.
         | 
| 117 180 | 
             
                    """
         | 
| 118 | 
            -
                     | 
| 119 | 
            -
                        return get_param_names(self.runs)
         | 
| 120 | 
            -
             | 
| 121 | 
            -
                    raise NotImplementedError
         | 
| 181 | 
            +
                    return get_param_names(self.runs)
         | 
| 122 182 |  | 
| 123 183 | 
             
                def get_param_dict(self) -> dict[str, list[str]]:
         | 
| 124 184 | 
             
                    """
         | 
| 125 | 
            -
                    Get the parameter dictionary from the runs.
         | 
| 185 | 
            +
                    Get the parameter dictionary from the list of runs.
         | 
| 126 186 |  | 
| 127 187 | 
             
                    This method extracts the parameter names and their corresponding values
         | 
| 128 | 
            -
                    from the  | 
| 129 | 
            -
                     | 
| 130 | 
            -
             | 
| 188 | 
            +
                    from the provided list of runs. It iterates through each run and collects
         | 
| 189 | 
            +
                    the parameter values into a dictionary where the keys are parameter names
         | 
| 190 | 
            +
                    and the values are lists of parameter values.
         | 
| 131 191 |  | 
| 132 192 | 
             
                    Returns:
         | 
| 133 | 
            -
                         | 
| 134 | 
            -
                         | 
| 135 | 
            -
             | 
| 136 | 
            -
                    Raises:
         | 
| 137 | 
            -
                        NotImplementedError: If the runs are not in a DataFrame format.
         | 
| 193 | 
            +
                        A dictionary where the keys are parameter names and the values are lists
         | 
| 194 | 
            +
                        of parameter values.
         | 
| 138 195 | 
             
                    """
         | 
| 139 | 
            -
                     | 
| 140 | 
            -
                        return get_param_dict(self.runs)
         | 
| 141 | 
            -
             | 
| 142 | 
            -
                    raise NotImplementedError
         | 
| 143 | 
            -
             | 
| 144 | 
            -
             | 
| 145 | 
            -
            def search_runs(*args, **kwargs) -> Runs:
         | 
| 146 | 
            -
                """
         | 
| 147 | 
            -
                Search for runs that match the specified criteria.
         | 
| 148 | 
            -
             | 
| 149 | 
            -
                This function wraps the `mlflow.search_runs` function and returns the results
         | 
| 150 | 
            -
                as a `Runs` object.  It allows for flexible searching of MLflow runs based on
         | 
| 151 | 
            -
                various criteria.
         | 
| 152 | 
            -
             | 
| 153 | 
            -
                Args:
         | 
| 154 | 
            -
                    *args: Positional arguments to pass to `mlflow.search_runs`.
         | 
| 155 | 
            -
                    **kwargs: Keyword arguments to pass to `mlflow.search_runs`.
         | 
| 156 | 
            -
             | 
| 157 | 
            -
                Returns:
         | 
| 158 | 
            -
                    Runs: A `Runs` object containing the search results.
         | 
| 159 | 
            -
                """
         | 
| 160 | 
            -
                runs = mlflow.search_runs(*args, **kwargs)
         | 
| 161 | 
            -
                return Runs(runs)
         | 
| 196 | 
            +
                    return get_param_dict(self.runs)
         | 
| 162 197 |  | 
| 163 198 |  | 
| 164 | 
            -
            def filter_runs(runs: list[ | 
| 199 | 
            +
            def filter_runs(runs: list[Run], config: object, **kwargs) -> list[Run]:
         | 
| 165 200 | 
             
                """
         | 
| 166 201 | 
             
                Filter the runs based on the provided configuration.
         | 
| 167 202 |  | 
| @@ -169,22 +204,26 @@ def filter_runs(runs: list[Run_] | DataFrame, config: object) -> list[Run_] | Da | |
| 169 204 | 
             
                specified configuration object. The configuration object should
         | 
| 170 205 | 
             
                contain key-value pairs that correspond to the parameters of the
         | 
| 171 206 | 
             
                runs. Only the runs that match all the specified parameters will
         | 
| 172 | 
            -
                be included in the returned  | 
| 207 | 
            +
                be included in the returned list of runs.
         | 
| 173 208 |  | 
| 174 209 | 
             
                Args:
         | 
| 175 210 | 
             
                    runs: The runs to filter.
         | 
| 176 211 | 
             
                    config: The configuration object to filter the runs.
         | 
| 212 | 
            +
                    **kwargs: Additional key-value pairs to filter the runs.
         | 
| 177 213 |  | 
| 178 214 | 
             
                Returns:
         | 
| 179 | 
            -
                     | 
| 215 | 
            +
                    A filtered list of runs.
         | 
| 180 216 | 
             
                """
         | 
| 181 | 
            -
                 | 
| 182 | 
            -
                     | 
| 217 | 
            +
                for key, value in chain(iter_params(config), kwargs.items()):
         | 
| 218 | 
            +
                    runs = [run for run in runs if _is_equal(run, key, value)]
         | 
| 183 219 |  | 
| 184 | 
            -
             | 
| 220 | 
            +
                    if len(runs) == 0:
         | 
| 221 | 
            +
                        return []
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                return runs
         | 
| 185 224 |  | 
| 186 225 |  | 
| 187 | 
            -
            def _is_equal(run:  | 
| 226 | 
            +
            def _is_equal(run: Run, key: str, value: Any) -> bool:
         | 
| 188 227 | 
             
                param = run.data.params.get(key, value)
         | 
| 189 228 |  | 
| 190 229 | 
             
                if param is None:
         | 
| @@ -193,275 +232,146 @@ def _is_equal(run: Run_, key: str, value: Any) -> bool: | |
| 193 232 | 
             
                return type(value)(param) == value
         | 
| 194 233 |  | 
| 195 234 |  | 
| 196 | 
            -
            def  | 
| 197 | 
            -
                for key, value in iter_params(config):
         | 
| 198 | 
            -
                    runs = [run for run in runs if _is_equal(run, key, value)]
         | 
| 199 | 
            -
             | 
| 200 | 
            -
                return runs
         | 
| 201 | 
            -
             | 
| 202 | 
            -
             | 
| 203 | 
            -
            def _filter_runs_dataframe(runs: DataFrame, config: object) -> DataFrame:
         | 
| 204 | 
            -
                index = np.ones(len(runs), dtype=bool)
         | 
| 205 | 
            -
             | 
| 206 | 
            -
                for key, value in iter_params(config):
         | 
| 207 | 
            -
                    name = f"params.{key}"
         | 
| 208 | 
            -
             | 
| 209 | 
            -
                    if name in runs:
         | 
| 210 | 
            -
                        series = runs[name]
         | 
| 211 | 
            -
                        is_value = -series.isna()
         | 
| 212 | 
            -
                        param = series.fillna(value).astype(type(value))
         | 
| 213 | 
            -
                        index &= is_value & (param == value)
         | 
| 214 | 
            -
             | 
| 215 | 
            -
                return runs[index]
         | 
| 216 | 
            -
             | 
| 217 | 
            -
             | 
| 218 | 
            -
            def get_run(runs: list[Run_] | DataFrame, config: object) -> Run_ | Series:
         | 
| 235 | 
            +
            def get_run(runs: list[Run], config: object, **kwargs) -> Run | None:
         | 
| 219 236 | 
             
                """
         | 
| 220 237 | 
             
                Retrieve a specific run based on the provided configuration.
         | 
| 221 238 |  | 
| 222 239 | 
             
                This method filters the runs in the collection according to the
         | 
| 223 240 | 
             
                specified configuration object and returns the run that matches
         | 
| 224 241 | 
             
                the provided parameters. If more than one run matches the criteria,
         | 
| 225 | 
            -
                 | 
| 242 | 
            +
                a `ValueError` is raised.
         | 
| 226 243 |  | 
| 227 244 | 
             
                Args:
         | 
| 228 245 | 
             
                    runs: The runs to filter.
         | 
| 229 246 | 
             
                    config: The configuration object to identify the run.
         | 
| 247 | 
            +
                    **kwargs: Additional key-value pairs to filter the runs.
         | 
| 230 248 |  | 
| 231 249 | 
             
                Returns:
         | 
| 232 | 
            -
                     | 
| 250 | 
            +
                    The run object that matches the provided configuration, or None
         | 
| 251 | 
            +
                    if no runs match the criteria.
         | 
| 252 | 
            +
             | 
| 253 | 
            +
                Raises:
         | 
| 254 | 
            +
                    ValueError: If more than one run matches the criteria.
         | 
| 233 255 | 
             
                """
         | 
| 234 | 
            -
                runs = filter_runs(runs, config)
         | 
| 256 | 
            +
                runs = filter_runs(runs, config, **kwargs)
         | 
| 257 | 
            +
             | 
| 258 | 
            +
                if len(runs) == 0:
         | 
| 259 | 
            +
                    return None
         | 
| 235 260 |  | 
| 236 261 | 
             
                if len(runs) == 1:
         | 
| 237 | 
            -
                    return runs[0] | 
| 262 | 
            +
                    return runs[0]
         | 
| 238 263 |  | 
| 239 | 
            -
                msg = f"number of  | 
| 264 | 
            +
                msg = f"Multiple runs were filtered. Expected number of runs is 1, but found {len(runs)} runs."
         | 
| 240 265 | 
             
                raise ValueError(msg)
         | 
| 241 266 |  | 
| 242 267 |  | 
| 243 | 
            -
            def  | 
| 268 | 
            +
            def get_earliest_run(runs: list[Run], config: object | None = None, **kwargs) -> Run | None:
         | 
| 244 269 | 
             
                """
         | 
| 245 | 
            -
                 | 
| 270 | 
            +
                Get the earliest run from the list of runs based on the start time.
         | 
| 246 271 |  | 
| 247 | 
            -
                This method  | 
| 248 | 
            -
                 | 
| 249 | 
            -
                that are shared among multiple runs.
         | 
| 272 | 
            +
                This method filters the runs based on the configuration if provided
         | 
| 273 | 
            +
                and returns the run with the earliest start time.
         | 
| 250 274 |  | 
| 251 275 | 
             
                Args:
         | 
| 252 | 
            -
                    runs: The  | 
| 276 | 
            +
                    runs: The list of runs.
         | 
| 277 | 
            +
                    config: The configuration object to filter the runs.
         | 
| 278 | 
            +
                        If None, no filtering is applied.
         | 
| 279 | 
            +
                    **kwargs: Additional key-value pairs to filter the runs.
         | 
| 253 280 |  | 
| 254 281 | 
             
                Returns:
         | 
| 255 | 
            -
                     | 
| 282 | 
            +
                    The run with the earliest start time, or None if no runs match the criteria.
         | 
| 256 283 | 
             
                """
         | 
| 284 | 
            +
                if config is not None or kwargs:
         | 
| 285 | 
            +
                    runs = filter_runs(runs, config or {}, **kwargs)
         | 
| 257 286 |  | 
| 258 | 
            -
                 | 
| 259 | 
            -
                    return not column.startswith("params.") or len(runs[column].unique()) > 1
         | 
| 287 | 
            +
                return min(runs, key=lambda run: run.info.start_time, default=None)
         | 
| 260 288 |  | 
| 261 | 
            -
                columns = [select(column) for column in runs.columns]
         | 
| 262 | 
            -
                return runs.iloc[:, columns]
         | 
| 263 289 |  | 
| 264 | 
            -
             | 
| 265 | 
            -
            def get_param_names(runs: DataFrame) -> list[str]:
         | 
| 290 | 
            +
            def get_latest_run(runs: list[Run], config: object | None = None, **kwargs) -> Run | None:
         | 
| 266 291 | 
             
                """
         | 
| 267 | 
            -
                Get the  | 
| 292 | 
            +
                Get the latest run from the list of runs based on the start time.
         | 
| 268 293 |  | 
| 269 | 
            -
                This method  | 
| 270 | 
            -
                 | 
| 271 | 
            -
                that correspond to the parameters.
         | 
| 294 | 
            +
                This method filters the runs based on the configuration if provided
         | 
| 295 | 
            +
                and returns the run with the latest start time.
         | 
| 272 296 |  | 
| 273 297 | 
             
                Args:
         | 
| 274 | 
            -
                    runs: The  | 
| 298 | 
            +
                    runs: The list of runs.
         | 
| 299 | 
            +
                    config: The configuration object to filter the runs.
         | 
| 300 | 
            +
                        If None, no filtering is applied.
         | 
| 301 | 
            +
                    **kwargs: Additional key-value pairs to filter the runs.
         | 
| 275 302 |  | 
| 276 303 | 
             
                Returns:
         | 
| 277 | 
            -
                     | 
| 304 | 
            +
                    The run with the latest start time, or None if no runs match the criteria.
         | 
| 278 305 | 
             
                """
         | 
| 306 | 
            +
                if config is not None or kwargs:
         | 
| 307 | 
            +
                    runs = filter_runs(runs, config or {}, **kwargs)
         | 
| 279 308 |  | 
| 280 | 
            -
                 | 
| 281 | 
            -
                    if column.startswith("params."):
         | 
| 282 | 
            -
                        return column.split(".", maxsplit=1)[-1]
         | 
| 283 | 
            -
             | 
| 284 | 
            -
                    return ""
         | 
| 285 | 
            -
             | 
| 286 | 
            -
                columns = [get_name(column) for column in runs.columns]
         | 
| 287 | 
            -
                return [column for column in columns if column]
         | 
| 309 | 
            +
                return max(runs, key=lambda run: run.info.start_time, default=None)
         | 
| 288 310 |  | 
| 289 311 |  | 
| 290 | 
            -
            def  | 
| 312 | 
            +
            def get_param_names(runs: list[Run]) -> list[str]:
         | 
| 291 313 | 
             
                """
         | 
| 292 | 
            -
                Get the parameter  | 
| 293 | 
            -
             | 
| 294 | 
            -
                This method extracts the parameter names and their corresponding values
         | 
| 295 | 
            -
                from the runs in the collection. If the runs are stored in a DataFrame,
         | 
| 296 | 
            -
                it retrieves the unique values for each parameter.
         | 
| 297 | 
            -
             | 
| 298 | 
            -
                Args:
         | 
| 299 | 
            -
                    runs: The DataFrame containing the runs.
         | 
| 300 | 
            -
             | 
| 301 | 
            -
                Returns:
         | 
| 302 | 
            -
                    dict[str, list[str]]: A dictionary of parameter names and
         | 
| 303 | 
            -
                    their corresponding values.
         | 
| 304 | 
            -
                """
         | 
| 305 | 
            -
                params = {}
         | 
| 306 | 
            -
                for name in get_param_names(runs):
         | 
| 307 | 
            -
                    params[name] = list(runs[f"params.{name}"].unique())
         | 
| 308 | 
            -
             | 
| 309 | 
            -
                return params
         | 
| 310 | 
            -
             | 
| 311 | 
            -
             | 
| 312 | 
            -
            @dataclass
         | 
| 313 | 
            -
            class Run:
         | 
| 314 | 
            -
                """
         | 
| 315 | 
            -
                A class to represent a specific MLflow run.
         | 
| 316 | 
            -
             | 
| 317 | 
            -
                This class provides methods to interact with the run, such as retrieving
         | 
| 318 | 
            -
                the run ID, artifact URI, and configuration. It also includes properties
         | 
| 319 | 
            -
                to access the artifact directory, artifact path, and Hydra output directory.
         | 
| 320 | 
            -
                """
         | 
| 321 | 
            -
             | 
| 322 | 
            -
                run: Run_ | Series | str
         | 
| 323 | 
            -
             | 
| 324 | 
            -
                def __repr__(self) -> str:
         | 
| 325 | 
            -
                    return f"{self.__class__.__name__}({self.run_id!r})"
         | 
| 326 | 
            -
             | 
| 327 | 
            -
                @property
         | 
| 328 | 
            -
                def run_id(self) -> str:
         | 
| 329 | 
            -
                    """
         | 
| 330 | 
            -
                    Get the run ID.
         | 
| 331 | 
            -
             | 
| 332 | 
            -
                    Returns:
         | 
| 333 | 
            -
                        str: The run ID.
         | 
| 334 | 
            -
                    """
         | 
| 335 | 
            -
                    return get_run_id(self.run)
         | 
| 336 | 
            -
             | 
| 337 | 
            -
                def artifact_uri(self, artifact_path: str | None = None) -> str:
         | 
| 338 | 
            -
                    """
         | 
| 339 | 
            -
                    Get the artifact URI.
         | 
| 340 | 
            -
             | 
| 341 | 
            -
                    Args:
         | 
| 342 | 
            -
                        artifact_path (str | None): The artifact path.
         | 
| 343 | 
            -
             | 
| 344 | 
            -
                    Returns:
         | 
| 345 | 
            -
                        str: The artifact URI.
         | 
| 346 | 
            -
                    """
         | 
| 347 | 
            -
                    return get_artifact_uri(self.run, artifact_path)
         | 
| 348 | 
            -
             | 
| 349 | 
            -
                @property
         | 
| 350 | 
            -
                def artifact_dir(self) -> Path:
         | 
| 351 | 
            -
                    """
         | 
| 352 | 
            -
                    Get the artifact directory.
         | 
| 353 | 
            -
             | 
| 354 | 
            -
                    Returns:
         | 
| 355 | 
            -
                        Path: The artifact directory.
         | 
| 356 | 
            -
                    """
         | 
| 357 | 
            -
                    return get_artifact_dir(self.run)
         | 
| 358 | 
            -
             | 
| 359 | 
            -
                def artifact_path(self, artifact_path: str | None = None) -> Path:
         | 
| 360 | 
            -
                    """
         | 
| 361 | 
            -
                    Get the artifact path.
         | 
| 362 | 
            -
             | 
| 363 | 
            -
                    Args:
         | 
| 364 | 
            -
                        artifact_path: The artifact path.
         | 
| 365 | 
            -
             | 
| 366 | 
            -
                    Returns:
         | 
| 367 | 
            -
                        Path: The artifact path.
         | 
| 368 | 
            -
                    """
         | 
| 369 | 
            -
                    return get_artifact_path(self.run, artifact_path)
         | 
| 370 | 
            -
             | 
| 371 | 
            -
                @property
         | 
| 372 | 
            -
                def config(self) -> DictConfig:
         | 
| 373 | 
            -
                    """
         | 
| 374 | 
            -
                    Get the configuration.
         | 
| 375 | 
            -
             | 
| 376 | 
            -
                    Returns:
         | 
| 377 | 
            -
                        DictConfig: The configuration.
         | 
| 378 | 
            -
                    """
         | 
| 379 | 
            -
                    return load_config(self.run)
         | 
| 380 | 
            -
             | 
| 381 | 
            -
                def log_hydra_output_dir(self) -> None:
         | 
| 382 | 
            -
                    """
         | 
| 383 | 
            -
                    Log the Hydra output directory.
         | 
| 384 | 
            -
             | 
| 385 | 
            -
                    Returns:
         | 
| 386 | 
            -
                        None
         | 
| 387 | 
            -
                    """
         | 
| 388 | 
            -
                    log_hydra_output_dir(self.run)
         | 
| 314 | 
            +
                Get the parameter names from the runs.
         | 
| 389 315 |  | 
| 390 | 
            -
             | 
| 391 | 
            -
             | 
| 392 | 
            -
                 | 
| 393 | 
            -
                Get the run ID.
         | 
| 316 | 
            +
                This method extracts the unique parameter names from the provided list of runs.
         | 
| 317 | 
            +
                It iterates through each run and collects the parameter names into a set to
         | 
| 318 | 
            +
                ensure uniqueness.
         | 
| 394 319 |  | 
| 395 320 | 
             
                Args:
         | 
| 396 | 
            -
                     | 
| 321 | 
            +
                    runs: The list of runs from which to extract parameter names.
         | 
| 397 322 |  | 
| 398 323 | 
             
                Returns:
         | 
| 399 | 
            -
                     | 
| 324 | 
            +
                    A list of unique parameter names.
         | 
| 400 325 | 
             
                """
         | 
| 401 | 
            -
                 | 
| 402 | 
            -
                    return run
         | 
| 403 | 
            -
             | 
| 404 | 
            -
                if isinstance(run, Run_):
         | 
| 405 | 
            -
                    return run.info.run_id
         | 
| 326 | 
            +
                param_names = set()
         | 
| 406 327 |  | 
| 407 | 
            -
                 | 
| 328 | 
            +
                for run in runs:
         | 
| 329 | 
            +
                    for param in run.data.params.keys():
         | 
| 330 | 
            +
                        param_names.add(param)
         | 
| 408 331 |  | 
| 332 | 
            +
                return list(param_names)
         | 
| 409 333 |  | 
| 410 | 
            -
            def get_artifact_uri(run: Run_ | Series | str, artifact_path: str | None = None) -> str:
         | 
| 411 | 
            -
                """
         | 
| 412 | 
            -
                Get the artifact URI.
         | 
| 413 | 
            -
             | 
| 414 | 
            -
                Args:
         | 
| 415 | 
            -
                    run: The run object.
         | 
| 416 | 
            -
                    artifact_path: The artifact path.
         | 
| 417 334 |  | 
| 418 | 
            -
             | 
| 419 | 
            -
                    str: The artifact URI.
         | 
| 335 | 
            +
            def get_param_dict(runs: list[Run]) -> dict[str, list[str]]:
         | 
| 420 336 | 
             
                """
         | 
| 421 | 
            -
                 | 
| 422 | 
            -
                return artifact_utils.get_artifact_uri(run_id, artifact_path)
         | 
| 337 | 
            +
                Get the parameter dictionary from the list of runs.
         | 
| 423 338 |  | 
| 424 | 
            -
             | 
| 425 | 
            -
             | 
| 426 | 
            -
                 | 
| 427 | 
            -
                 | 
| 339 | 
            +
                This method extracts the parameter names and their corresponding values
         | 
| 340 | 
            +
                from the provided list of runs. It iterates through each run and collects
         | 
| 341 | 
            +
                the parameter values into a dictionary where the keys are parameter names
         | 
| 342 | 
            +
                and the values are lists of parameter values.
         | 
| 428 343 |  | 
| 429 344 | 
             
                Args:
         | 
| 430 | 
            -
                     | 
| 345 | 
            +
                    runs: The list of runs from which to extract parameter names and values.
         | 
| 431 346 |  | 
| 432 347 | 
             
                Returns:
         | 
| 433 | 
            -
                     | 
| 348 | 
            +
                    A dictionary where the keys are parameter names and the values are lists
         | 
| 349 | 
            +
                    of parameter values.
         | 
| 434 350 | 
             
                """
         | 
| 435 | 
            -
                 | 
| 436 | 
            -
                return uri_to_path(uri)
         | 
| 351 | 
            +
                params = {}
         | 
| 437 352 |  | 
| 353 | 
            +
                for name in get_param_names(runs):
         | 
| 354 | 
            +
                    it = (run.data.params[name] for run in runs if name in run.data.params)
         | 
| 355 | 
            +
                    params[name] = sorted(set(it))
         | 
| 438 356 |  | 
| 439 | 
            -
             | 
| 440 | 
            -
                """
         | 
| 441 | 
            -
                Get the artifact path.
         | 
| 357 | 
            +
                return params
         | 
| 442 358 |  | 
| 443 | 
            -
                Args:
         | 
| 444 | 
            -
                    run: The run object.
         | 
| 445 | 
            -
                    artifact_path: The artifact path.
         | 
| 446 359 |  | 
| 447 | 
            -
             | 
| 448 | 
            -
                    Path: The artifact path.
         | 
| 360 | 
            +
            def load_config(run: Run) -> DictConfig:
         | 
| 449 361 | 
             
                """
         | 
| 450 | 
            -
                 | 
| 451 | 
            -
                return artifact_dir / artifact_path if artifact_path else artifact_dir
         | 
| 362 | 
            +
                Load the configuration for a given run.
         | 
| 452 363 |  | 
| 453 | 
            -
             | 
| 454 | 
            -
             | 
| 455 | 
            -
                 | 
| 456 | 
            -
                Load the configuration.
         | 
| 364 | 
            +
                This function loads the configuration for the provided Run instance
         | 
| 365 | 
            +
                by downloading the configuration file from the MLflow artifacts and
         | 
| 366 | 
            +
                loading it using OmegaConf.
         | 
| 457 367 |  | 
| 458 368 | 
             
                Args:
         | 
| 459 | 
            -
                    run: The  | 
| 369 | 
            +
                    run: The Run instance to load the configuration for.
         | 
| 460 370 |  | 
| 461 371 | 
             
                Returns:
         | 
| 462 | 
            -
                     | 
| 372 | 
            +
                    The loaded configuration.
         | 
| 463 373 | 
             
                """
         | 
| 464 | 
            -
                run_id =  | 
| 374 | 
            +
                run_id = run.info.run_id
         | 
| 465 375 | 
             
                return _load_config(run_id)
         | 
| 466 376 |  | 
| 467 377 |  | 
| @@ -478,35 +388,35 @@ def _load_config(run_id: str) -> DictConfig: | |
| 478 388 | 
             
                return OmegaConf.load(path)  # type: ignore
         | 
| 479 389 |  | 
| 480 390 |  | 
| 481 | 
            -
            def get_hydra_output_dir(run: Run_ | Series | str) -> Path:
         | 
| 482 | 
            -
             | 
| 483 | 
            -
             | 
| 391 | 
            +
            # def get_hydra_output_dir(run: Run_ | Series | str) -> Path:
         | 
| 392 | 
            +
            #     """
         | 
| 393 | 
            +
            #     Get the Hydra output directory.
         | 
| 484 394 |  | 
| 485 | 
            -
             | 
| 486 | 
            -
             | 
| 395 | 
            +
            #     Args:
         | 
| 396 | 
            +
            #         run: The run object.
         | 
| 487 397 |  | 
| 488 | 
            -
             | 
| 489 | 
            -
             | 
| 490 | 
            -
             | 
| 491 | 
            -
             | 
| 398 | 
            +
            #     Returns:
         | 
| 399 | 
            +
            #         Path: The Hydra output directory.
         | 
| 400 | 
            +
            #     """
         | 
| 401 | 
            +
            #     path = get_artifact_dir(run) / ".hydra/hydra.yaml"
         | 
| 492 402 |  | 
| 493 | 
            -
             | 
| 494 | 
            -
             | 
| 495 | 
            -
             | 
| 403 | 
            +
            #     if path.exists():
         | 
| 404 | 
            +
            #         hc = OmegaConf.load(path)
         | 
| 405 | 
            +
            #         return Path(hc.hydra.runtime.output_dir)
         | 
| 496 406 |  | 
| 497 | 
            -
             | 
| 407 | 
            +
            #     raise FileNotFoundError
         | 
| 498 408 |  | 
| 499 409 |  | 
| 500 | 
            -
            def log_hydra_output_dir(run: Run_ | Series | str) -> None:
         | 
| 501 | 
            -
             | 
| 502 | 
            -
             | 
| 410 | 
            +
            # def log_hydra_output_dir(run: Run_ | Series | str) -> None:
         | 
| 411 | 
            +
            #     """
         | 
| 412 | 
            +
            #     Log the Hydra output directory.
         | 
| 503 413 |  | 
| 504 | 
            -
             | 
| 505 | 
            -
             | 
| 414 | 
            +
            #     Args:
         | 
| 415 | 
            +
            #         run: The run object.
         | 
| 506 416 |  | 
| 507 | 
            -
             | 
| 508 | 
            -
             | 
| 509 | 
            -
             | 
| 510 | 
            -
             | 
| 511 | 
            -
             | 
| 512 | 
            -
             | 
| 417 | 
            +
            #     Returns:
         | 
| 418 | 
            +
            #         None
         | 
| 419 | 
            +
            #     """
         | 
| 420 | 
            +
            #     output_dir = get_hydra_output_dir(run)
         | 
| 421 | 
            +
            #     run_id = run if isinstance(run, str) else run.info.run_id
         | 
| 422 | 
            +
            #     mlflow.log_artifacts(output_dir.as_posix(), run_id=run_id)
         | 
| @@ -1,6 +1,6 @@ | |
| 1 1 | 
             
            Metadata-Version: 2.3
         | 
| 2 2 | 
             
            Name: hydraflow
         | 
| 3 | 
            -
            Version: 0.1.5
         | 
| 3 | 
            +
            Version: 0.2.1
         | 
| 4 4 | 
             
            Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
         | 
| 5 5 | 
             
            Project-URL: Documentation, https://github.com/daizutabi/hydraflow
         | 
| 6 6 | 
             
            Project-URL: Source, https://github.com/daizutabi/hydraflow
         | 
| @@ -0,0 +1,9 @@ | |
| 1 | 
            +
            hydraflow/__init__.py,sha256=PzziOG9RnGAVbl9Yz4ScvsL8nfkjsuN0alMKRvZT-_Y,442
         | 
| 2 | 
            +
            hydraflow/config.py,sha256=wI8uNuD2D-hIf4BAhEYJaMC6EyO-erKopy_ia_b1pYA,2048
         | 
| 3 | 
            +
            hydraflow/context.py,sha256=MqkEhKEZL_N3eb3v5u9D4EqKkiSmiPyXXafhPkALRlg,5129
         | 
| 4 | 
            +
            hydraflow/mlflow.py,sha256=_Los9E38eG8sTiN8bGwZmvjCrS0S-wSGiA4fyhQM3Zw,2251
         | 
| 5 | 
            +
            hydraflow/runs.py,sha256=NT7IzE-Pf7T2Ey-eWEPZzQQaX4Gt_RKDKSn2pj2yzGc,14304
         | 
| 6 | 
            +
            hydraflow-0.2.1.dist-info/METADATA,sha256=4C_hnw1gMb8WUQXyqj4q8eA1IVbp0wZuLGGthIk1G7U,4224
         | 
| 7 | 
            +
            hydraflow-0.2.1.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
         | 
| 8 | 
            +
            hydraflow-0.2.1.dist-info/licenses/LICENSE,sha256=IGdDrBPqz1O0v_UwCW-NJlbX9Hy9b3uJ11t28y2srmY,1062
         | 
| 9 | 
            +
            hydraflow-0.2.1.dist-info/RECORD,,
         | 
    
        hydraflow/util.py
    DELETED
    
    | @@ -1,24 +0,0 @@ | |
| 1 | 
            -
            import platform
         | 
| 2 | 
            -
            from pathlib import Path
         | 
| 3 | 
            -
            from urllib.parse import urlparse
         | 
| 4 | 
            -
             | 
| 5 | 
            -
             | 
| 6 | 
            -
            def uri_to_path(uri: str) -> Path:
         | 
| 7 | 
            -
                """
         | 
| 8 | 
            -
                Convert a URI to a path.
         | 
| 9 | 
            -
             | 
| 10 | 
            -
                This function parses the given URI and converts it to a local file system
         | 
| 11 | 
            -
                path. On Windows, if the path starts with a forward slash, it is removed
         | 
| 12 | 
            -
                to ensure the path is correctly formatted.
         | 
| 13 | 
            -
             | 
| 14 | 
            -
                Args:
         | 
| 15 | 
            -
                    uri (str): The URI to convert.
         | 
| 16 | 
            -
             | 
| 17 | 
            -
                Returns:
         | 
| 18 | 
            -
                    Path: The path corresponding to the URI.
         | 
| 19 | 
            -
                """
         | 
| 20 | 
            -
                path = urlparse(uri).path
         | 
| 21 | 
            -
                if platform.system() == "Windows" and path.startswith("/"):
         | 
| 22 | 
            -
                    path = path[1:]
         | 
| 23 | 
            -
             | 
| 24 | 
            -
                return Path(path)
         | 
    
        hydraflow-0.1.5.dist-info/RECORD
    DELETED
    
    | @@ -1,10 +0,0 @@ | |
| 1 | 
            -
            hydraflow/__init__.py,sha256=e1Q0Sskx39jaU2zkGNXjFWNC5xugEz_hDERTN_6Mzy8,666
         | 
| 2 | 
            -
            hydraflow/config.py,sha256=WARa5u1F0n3wCOi65v8v8rUO78ME-mtzMeeeE2Yc1I8,1728
         | 
| 3 | 
            -
            hydraflow/context.py,sha256=NYjIMepLtaKyvw1obpE8gR1qu1OBpSB_uc6-5So2tg8,5139
         | 
| 4 | 
            -
            hydraflow/mlflow.py,sha256=2YWOYpv8eRB_ROD2yFh6ksKDXHvAPDYb86hrUi9zv6E,1558
         | 
| 5 | 
            -
            hydraflow/runs.py,sha256=vH-hrlcoTo8HRmgUWam9gtLXAl_wDzX26HEZGWckdMs,14038
         | 
| 6 | 
            -
            hydraflow/util.py,sha256=qdUGtBgY7qOF4Yr4PibJHImbLPf-6WYFVuIKu6zbNbY,614
         | 
| 7 | 
            -
            hydraflow-0.1.5.dist-info/METADATA,sha256=8mCKAA9KjcJAUiqP-DPdMl4Gcp3MSXxOF34VYKA2P8I,4224
         | 
| 8 | 
            -
            hydraflow-0.1.5.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
         | 
| 9 | 
            -
            hydraflow-0.1.5.dist-info/licenses/LICENSE,sha256=IGdDrBPqz1O0v_UwCW-NJlbX9Hy9b3uJ11t28y2srmY,1062
         | 
| 10 | 
            -
            hydraflow-0.1.5.dist-info/RECORD,,
         | 
| 
            File without changes
         | 
| 
            File without changes
         |