hydraflow 0.19.1__tar.gz → 0.19.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24):
  1. {hydraflow-0.19.1 → hydraflow-0.19.2}/PKG-INFO +1 -1
  2. {hydraflow-0.19.1 → hydraflow-0.19.2}/pyproject.toml +16 -5
  3. {hydraflow-0.19.1 → hydraflow-0.19.2}/src/hydraflow/cli.py +1 -1
  4. {hydraflow-0.19.1 → hydraflow-0.19.2}/src/hydraflow/core/collection.py +12 -10
  5. {hydraflow-0.19.1 → hydraflow-0.19.2}/src/hydraflow/core/group_by.py +2 -2
  6. {hydraflow-0.19.1 → hydraflow-0.19.2}/src/hydraflow/core/io.py +1 -1
  7. {hydraflow-0.19.1 → hydraflow-0.19.2}/src/hydraflow/core/main.py +3 -3
  8. {hydraflow-0.19.1 → hydraflow-0.19.2}/src/hydraflow/core/run.py +12 -10
  9. {hydraflow-0.19.1 → hydraflow-0.19.2}/src/hydraflow/core/run_collection.py +2 -0
  10. {hydraflow-0.19.1 → hydraflow-0.19.2}/src/hydraflow/executor/aio.py +1 -1
  11. {hydraflow-0.19.1 → hydraflow-0.19.2}/src/hydraflow/executor/io.py +2 -2
  12. {hydraflow-0.19.1 → hydraflow-0.19.2}/src/hydraflow/executor/job.py +3 -3
  13. {hydraflow-0.19.1 → hydraflow-0.19.2}/src/hydraflow/executor/parser.py +3 -3
  14. {hydraflow-0.19.1 → hydraflow-0.19.2}/src/hydraflow/utils/progress.py +4 -4
  15. {hydraflow-0.19.1 → hydraflow-0.19.2}/LICENSE +0 -0
  16. {hydraflow-0.19.1 → hydraflow-0.19.2}/README.md +0 -0
  17. {hydraflow-0.19.1 → hydraflow-0.19.2}/src/hydraflow/__init__.py +0 -0
  18. {hydraflow-0.19.1 → hydraflow-0.19.2}/src/hydraflow/core/__init__.py +0 -0
  19. {hydraflow-0.19.1 → hydraflow-0.19.2}/src/hydraflow/core/context.py +0 -0
  20. {hydraflow-0.19.1 → hydraflow-0.19.2}/src/hydraflow/core/run_info.py +0 -0
  21. {hydraflow-0.19.1 → hydraflow-0.19.2}/src/hydraflow/executor/__init__.py +0 -0
  22. {hydraflow-0.19.1 → hydraflow-0.19.2}/src/hydraflow/executor/conf.py +0 -0
  23. {hydraflow-0.19.1 → hydraflow-0.19.2}/src/hydraflow/py.typed +0 -0
  24. {hydraflow-0.19.1 → hydraflow-0.19.2}/src/hydraflow/utils/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: hydraflow
3
- Version: 0.19.1
3
+ Version: 0.19.2
4
4
  Summary: HydraFlow seamlessly integrates Hydra and MLflow to streamline ML experiment management, combining Hydra's configuration management with MLflow's tracking capabilities.
5
5
  Keywords: machine-learning,mlflow,hydra,experiment-tracking,mlops,ai,deep-learning,research,data-science
6
6
  Author: daizutabi
@@ -1,10 +1,10 @@
1
1
  [build-system]
2
- requires = ["uv_build>=0.7.19,<0.8.0"]
2
+ requires = ["uv_build"]
3
3
  build-backend = "uv_build"
4
4
 
5
5
  [project]
6
6
  name = "hydraflow"
7
- version = "0.19.1"
7
+ version = "0.19.2"
8
8
  description = "HydraFlow seamlessly integrates Hydra and MLflow to streamline ML experiment management, combining Hydra's configuration management with MLflow's tracking capabilities."
9
9
  readme = "README.md"
10
10
  license = { file = "LICENSE" }
@@ -60,7 +60,6 @@ dev = [
60
60
  "pytest-order",
61
61
  "pytest-randomly",
62
62
  "pytest-xdist",
63
- "ruff>=0.12",
64
63
  ]
65
64
  docs = ["markdown-exec[ansi]", "mkapi>=4.4", "mkdocs-material"]
66
65
 
@@ -119,5 +118,17 @@ ignore = [
119
118
  "src/hydraflow/executor/conf.py" = ["ANN", "D"]
120
119
  "tests/*" = ["A001", "ANN", "ARG", "D", "FBT", "PD", "PLR", "PT", "S", "SLF"]
121
120
 
122
- [tool.pyright]
123
- include = ["src", "tests"]
121
+ [tool.basedpyright]
122
+ include = ["src"]
123
+ exclude = ["notebooks", "tests"]
124
+ reportAny = false
125
+ reportExplicitAny = false
126
+ reportImplicitOverride = false
127
+ reportImportCycles = false
128
+ reportIncompatibleVariableOverride = false
129
+ reportMissingTypeStubs = false
130
+ reportUnusedCallResult = false
131
+
132
+ [tool.ty.rules]
133
+ unresolved-import = "ignore"
134
+ possibly-unbound-attribute = "ignore"
@@ -15,7 +15,7 @@ app = typer.Typer(add_completion=False)
15
15
 
16
16
 
17
17
  @app.command("run", context_settings={"ignore_unknown_options": True})
18
- def _run(
18
+ def _run( # pyright: ignore[reportUnusedFunction]
19
19
  name: Annotated[str, Argument(help="Job name.", show_default=False)],
20
20
  *,
21
21
  args: Annotated[
@@ -17,11 +17,13 @@ from .group_by import GroupBy
17
17
 
18
18
  if TYPE_CHECKING:
19
19
  from collections.abc import Callable, Iterator
20
- from re import Pattern, _FlagsType
20
+ from re import Pattern, _FlagsType # pyright: ignore[reportPrivateUsage]
21
21
  from typing import Any, Self
22
22
 
23
23
  from numpy.typing import NDArray
24
24
 
25
+ # pyright: reportUnknownVariableType=false
26
+
25
27
 
26
28
  class Collection[I](Sequence[I]):
27
29
  """A collection of items that implements the Sequence protocol."""
@@ -280,7 +282,7 @@ class Collection[I](Sequence[I]):
280
282
  self,
281
283
  key: str,
282
284
  default: Any | Callable[[I], Any] = MISSING,
283
- ) -> NDArray:
285
+ ) -> NDArray[Any]:
284
286
  """Extract values for a specific key from all items as a NumPy array.
285
287
 
286
288
  Args:
@@ -323,7 +325,7 @@ class Collection[I](Sequence[I]):
323
325
  self,
324
326
  key: str,
325
327
  default: Any | Callable[[I], Any] = MISSING,
326
- ) -> NDArray:
328
+ ) -> NDArray[Any]:
327
329
  """Get the unique values for a specific key across all items.
328
330
 
329
331
  Args:
@@ -456,13 +458,13 @@ class Collection[I](Sequence[I]):
456
458
  it = (delayed(function)(i, *args, **kwargs) for i in self)
457
459
 
458
460
  if not progress:
459
- return parallel(it) # type: ignore
461
+ return parallel(it) # pyright: ignore[reportReturnType]
460
462
 
461
463
  from hydraflow.utils.progress import Progress
462
464
 
463
465
  with Progress(*Progress.get_default_columns()) as p:
464
466
  p.add_task("", total=len(self))
465
- return parallel(it) # type: ignore
467
+ return parallel(it) # pyright: ignore[reportReturnType]
466
468
 
467
469
  def to_frame(
468
470
  self,
@@ -513,12 +515,12 @@ class Collection[I](Sequence[I]):
513
515
  keys_ = []
514
516
  for k in keys:
515
517
  if isinstance(k, tuple):
516
- keys_.append(k[0])
518
+ keys_.append(k[0]) # pyright: ignore[reportUnknownMemberType]
517
519
  defaults[k[0]] = k[1]
518
520
  else:
519
- keys_.append(k)
521
+ keys_.append(k) # pyright: ignore[reportUnknownMemberType]
520
522
 
521
- data = {k: self.to_list(k, defaults.get(k, MISSING)) for k in keys_}
523
+ data = {k: self.to_list(k, defaults.get(k, MISSING)) for k in keys_} # pyright: ignore[reportUnknownArgumentType]
522
524
  df = DataFrame(data)
523
525
 
524
526
  if not kwargs:
@@ -907,11 +909,11 @@ def matches(value: Any, criterion: Any) -> bool:
907
909
  if isinstance(criterion, list | set) and not _is_iterable(value):
908
910
  return value in criterion
909
911
 
910
- if isinstance(criterion, tuple) and len(criterion) == 2 and not _is_iterable(value):
912
+ if isinstance(criterion, tuple) and len(criterion) == 2 and not _is_iterable(value): # pyright: ignore[reportUnknownArgumentType]
911
913
  return criterion[0] <= value <= criterion[1]
912
914
 
913
915
  if _is_iterable(criterion):
914
- criterion = list(criterion)
916
+ criterion = list(criterion) # pyright: ignore[reportUnknownArgumentType]
915
917
 
916
918
  if _is_iterable(value):
917
919
  value = list(value)
@@ -193,10 +193,10 @@ class GroupBy[C: Collection[Any], I]:
193
193
  else:
194
194
  df = DataFrame(dict(zip(self.by, k, strict=True)) for k in gp)
195
195
 
196
- columns = []
196
+ columns: list[Series] = []
197
197
 
198
198
  for agg in aggs:
199
- values = [[c._get(i, agg, MISSING) for i in c] for c in gp.values()] # noqa: SLF001
199
+ values = [[c._get(i, agg, MISSING) for i in c] for c in gp.values()] # noqa: SLF001 # pyright: ignore[reportPrivateUsage]
200
200
  columns.append(Series(agg, values))
201
201
 
202
202
  for k, v in named_aggs.items():
@@ -38,7 +38,7 @@ def get_artifact_dir(run: Run) -> Path:
38
38
  The local path to the directory where the artifacts are downloaded.
39
39
 
40
40
  """
41
- uri = run.info.artifact_uri
41
+ uri = run.info.artifact_uri # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
42
42
 
43
43
  if not isinstance(uri, str):
44
44
  raise NotImplementedError
@@ -105,7 +105,7 @@ def main[C](
105
105
  dry_run = True
106
106
  sys.argv.remove("--dry-run")
107
107
 
108
- finished = RunStatus.to_string(RunStatus.FINISHED)
108
+ finished = RunStatus.to_string(RunStatus.FINISHED) # pyright: ignore[reportUnknownMemberType]
109
109
 
110
110
  def decorator(app: Callable[[Run, C], None]) -> Callable[[], None]:
111
111
  ConfigStore.instance().store(config_name, node)
@@ -134,11 +134,11 @@ def main[C](
134
134
  else:
135
135
  uri = experiment.artifact_location
136
136
  overrides = hc.overrides.task if match_overrides else None
137
- run_id = get_run_id(uri, cfg, overrides)
137
+ run_id = get_run_id(uri, cfg, overrides) # pyright: ignore[reportUnknownArgumentType]
138
138
 
139
139
  if run_id and not rerun_finished:
140
140
  run = mlflow.get_run(run_id)
141
- if run.info.status == finished:
141
+ if run.info.status == finished: # pyright: ignore[reportUnknownMemberType]
142
142
  return
143
143
 
144
144
  with start_run(run_id=run_id, chdir=chdir) as run:
@@ -45,6 +45,8 @@ if TYPE_CHECKING:
45
45
 
46
46
  from .run_collection import RunCollection
47
47
 
48
+ # pyright: reportUnknownVariableType=false
49
+
48
50
 
49
51
  class Run[C, I = None]:
50
52
  """Represent an MLflow Run in HydraFlow.
@@ -76,7 +78,7 @@ class Run[C, I = None]:
76
78
  impl_factory: Callable[[Path], I] | Callable[[Path, C], I] | None = None,
77
79
  ) -> None:
78
80
  self.info = RunInfo(run_dir)
79
- self.impl_factory = impl_factory or (lambda _: None) # type: ignore
81
+ self.impl_factory = impl_factory or (lambda _: None) # pyright: ignore[reportAttributeAccessIssue]
80
82
 
81
83
  def __repr__(self) -> str:
82
84
  """Return a string representation of the Run."""
@@ -93,9 +95,9 @@ class Run[C, I = None]:
93
95
  """The configuration instance loaded from the Hydra configuration file."""
94
96
  config_file = self.info.run_dir / "artifacts/.hydra/config.yaml"
95
97
  if config_file.exists():
96
- return OmegaConf.load(config_file) # type: ignore
98
+ return OmegaConf.load(config_file) # pyright: ignore[reportReturnType]
97
99
 
98
- return OmegaConf.create() # type: ignore
100
+ return OmegaConf.create() # pyright: ignore[reportReturnType]
99
101
 
100
102
  @cached_property
101
103
  def impl(self) -> I:
@@ -176,13 +178,13 @@ class Run[C, I = None]:
176
178
 
177
179
  if n_jobs == 0:
178
180
  runs = (cls(Path(r), impl_factory) for r in run_dir)
179
- return RunCollection(runs, cls.get) # type: ignore
181
+ return RunCollection(runs, cls.get)
180
182
 
181
183
  from joblib import Parallel, delayed
182
184
 
183
185
  parallel = Parallel(backend="threading", n_jobs=n_jobs)
184
186
  runs = parallel(delayed(cls)(Path(r), impl_factory) for r in run_dir)
185
- return RunCollection(runs, cls.get) # type: ignore
187
+ return RunCollection(runs, cls.get) # pyright: ignore[reportArgumentType]
186
188
 
187
189
  @overload
188
190
  def update(
@@ -235,7 +237,7 @@ class Run[C, I = None]:
235
237
  an iterable.
236
238
 
237
239
  """
238
- cfg: DictConfig = self.cfg # type: ignore
240
+ cfg: DictConfig = self.cfg # pyright: ignore[reportAssignmentType]
239
241
 
240
242
  if isinstance(key, str):
241
243
  key = key.replace("__", ".")
@@ -296,7 +298,7 @@ class Run[C, I = None]:
296
298
  """
297
299
  key = key.replace("__", ".")
298
300
 
299
- value = OmegaConf.select(self.cfg, key, default=MISSING) # type: ignore
301
+ value = OmegaConf.select(self.cfg, key, default=MISSING) # pyright: ignore[reportArgumentType]
300
302
  if value is not MISSING:
301
303
  return value
302
304
 
@@ -377,7 +379,7 @@ class Run[C, I = None]:
377
379
  if not isinstance(cfg, dict):
378
380
  raise TypeError("Configuration must be a dictionary")
379
381
 
380
- standard_dict: dict[str, Any] = {str(k): v for k, v in cfg.items()}
382
+ standard_dict: dict[str, Any] = {str(k): v for k, v in cfg.items()} # pyright: ignore[reportUnknownArgumentType]
381
383
 
382
384
  if flatten:
383
385
  return _flatten_dict(standard_dict)
@@ -450,12 +452,12 @@ class Run[C, I = None]:
450
452
 
451
453
 
452
454
  def _flatten_dict(d: dict[str, Any], parent_key: str = "") -> dict[str, Any]:
453
- items = []
455
+ items: list[tuple[str, Any]] = []
454
456
 
455
457
  for k, v in d.items():
456
458
  key = f"{parent_key}.{k}" if parent_key else k
457
459
  if isinstance(v, dict):
458
- items.extend(_flatten_dict(v, key).items())
460
+ items.extend(_flatten_dict(v, key).items()) # pyright: ignore[reportUnknownArgumentType]
459
461
  else:
460
462
  items.append((key, v))
461
463
 
@@ -54,6 +54,8 @@ if TYPE_CHECKING:
54
54
 
55
55
  from polars import DataFrame
56
56
 
57
+ # pyright: reportUnknownVariableType=false
58
+
57
59
 
58
60
  class RunCollection[R: Run[Any, Any]](Collection[R]):
59
61
  """A collection of Run instances that implements the Sequence protocol.
@@ -68,7 +68,7 @@ async def arun(
68
68
  ) -> int | None:
69
69
  """Run a command asynchronously."""
70
70
  process = await asyncio.create_subprocess_exec(*args, stdout=PIPE, stderr=PIPE)
71
- coros = alog(process.stdout, stdout), alog(process.stderr, stderr) # type:ignore
71
+ coros = alog(process.stdout, stdout), alog(process.stderr, stderr) # pyright: ignore[reportArgumentType]
72
72
  await asyncio.gather(*coros)
73
73
  await process.communicate()
74
74
 
@@ -10,7 +10,7 @@ from omegaconf import DictConfig, OmegaConf
10
10
  from .conf import HydraflowConf
11
11
 
12
12
  if TYPE_CHECKING:
13
- from .job import Job
13
+ from .conf import Job
14
14
 
15
15
 
16
16
  def find_config_file() -> Path | None:
@@ -38,7 +38,7 @@ def load_config() -> HydraflowConf:
38
38
  if not isinstance(cfg, DictConfig):
39
39
  return schema
40
40
 
41
- return OmegaConf.merge(schema, cfg) # type: ignore[return-value]
41
+ return OmegaConf.merge(schema, cfg) # pyright: ignore[reportReturnType]
42
42
 
43
43
 
44
44
  def get_job(name: str) -> Job:
@@ -99,7 +99,7 @@ def merge_args(first: list[str], second: list[str]) -> list[str]:
99
99
  list[str]: A merged list of arguments.
100
100
 
101
101
  """
102
- merged = {}
102
+ merged: dict[str, str | None] = {}
103
103
 
104
104
  for item in [*first, *second]:
105
105
  if "=" in item:
@@ -158,7 +158,7 @@ def submit(
158
158
  iterable: Iterable[list[str]],
159
159
  *,
160
160
  dry_run: bool = False,
161
- ) -> CompletedProcess | tuple[list[str], str]:
161
+ ) -> CompletedProcess[bytes] | tuple[list[str], str]:
162
162
  """Submit entire job using a shell command."""
163
163
  executable, *args = args
164
164
  if executable == "python" and sys.platform == "win32":
@@ -181,7 +181,7 @@ def submit(
181
181
  file.unlink(missing_ok=True)
182
182
 
183
183
 
184
- def get_callable(name: str) -> Callable:
184
+ def get_callable(name: str) -> Callable[[list[str]], Any]:
185
185
  """Get a callable from a function name."""
186
186
  if "." not in name:
187
187
  msg = f"Invalid function path: {name}."
@@ -205,7 +205,7 @@ def _arange(start: float, stop: float, step: float) -> list[float]:
205
205
 
206
206
  epsilon = min(abs(start), abs(stop)) * 1e-5
207
207
 
208
- result = []
208
+ result: list[float] = []
209
209
  current = start
210
210
 
211
211
  if step > 0:
@@ -412,8 +412,8 @@ def split(arg: str) -> list[str]:
412
412
  ['(a,b)m', '(1,2:4)k']
413
413
 
414
414
  """
415
- result = []
416
- current = []
415
+ result: list[str] = []
416
+ current: list[str] = []
417
417
  bracket_count = 0
418
418
  paren_count = 0
419
419
  in_single_quote = False
@@ -57,12 +57,12 @@ class Progress(Super):
57
57
  def _update(parallel: Parallel) -> None:
58
58
  update(self, parallel)
59
59
 
60
- Parallel.print_progress = _update # type: ignore
60
+ Parallel.print_progress = _update # pyright: ignore[reportAttributeAccessIssue]
61
61
 
62
62
  def stop(self) -> None:
63
63
  """Stop the progress display."""
64
64
  if self._print_progress:
65
- Parallel.print_progress = self._print_progress # type: ignore
65
+ Parallel.print_progress = self._print_progress # pyright: ignore[reportAttributeAccessIssue]
66
66
 
67
67
  super().stop()
68
68
 
@@ -86,5 +86,5 @@ def update(progress: Progress, parallel: Parallel) -> None:
86
86
 
87
87
  progress.update(task_id, completed=parallel.n_completed_tasks, refresh=True)
88
88
 
89
- if progress._print_progress: # noqa: SLF001
90
- progress._print_progress(parallel) # noqa: SLF001
89
+ if progress._print_progress: # noqa: SLF001 # pyright: ignore[reportPrivateUsage]
90
+ progress._print_progress(parallel) # noqa: SLF001 # pyright: ignore[reportPrivateUsage]
File without changes
File without changes