hydraflow 0.5.4__tar.gz → 0.6.0__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (56) hide show
  1. {hydraflow-0.5.4 → hydraflow-0.6.0}/PKG-INFO +1 -1
  2. {hydraflow-0.5.4 → hydraflow-0.6.0}/pyproject.toml +27 -3
  3. {hydraflow-0.5.4 → hydraflow-0.6.0}/src/hydraflow/config.py +3 -0
  4. {hydraflow-0.5.4 → hydraflow-0.6.0}/src/hydraflow/context.py +44 -4
  5. {hydraflow-0.5.4 → hydraflow-0.6.0}/src/hydraflow/mlflow.py +4 -0
  6. {hydraflow-0.5.4 → hydraflow-0.6.0}/src/hydraflow/param.py +4 -0
  7. {hydraflow-0.5.4 → hydraflow-0.6.0}/src/hydraflow/run_collection.py +29 -0
  8. {hydraflow-0.5.4 → hydraflow-0.6.0}/src/hydraflow/run_data.py +1 -0
  9. {hydraflow-0.5.4 → hydraflow-0.6.0}/src/hydraflow/utils.py +5 -0
  10. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/config/overrides.py +1 -2
  11. hydraflow-0.6.0/tests/context/chdir.py +29 -0
  12. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/context/context.py +1 -5
  13. hydraflow-0.6.0/tests/context/logging.py +39 -0
  14. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/context/preemption.py +1 -2
  15. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/context/rerun.py +3 -7
  16. hydraflow-0.6.0/tests/context/test_chdir.py +27 -0
  17. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/context/test_context.py +1 -11
  18. hydraflow-0.6.0/tests/context/test_logging.py +51 -0
  19. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/param/params.py +1 -2
  20. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/param/test_params.py +1 -1
  21. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/run/filter.py +1 -2
  22. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/run/run.py +1 -2
  23. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/run/test_info.py +1 -1
  24. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/run/test_run.py +1 -1
  25. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/utils/utils.py +1 -2
  26. {hydraflow-0.5.4 → hydraflow-0.6.0}/.devcontainer/devcontainer.json +0 -0
  27. {hydraflow-0.5.4 → hydraflow-0.6.0}/.devcontainer/postCreate.sh +0 -0
  28. {hydraflow-0.5.4 → hydraflow-0.6.0}/.devcontainer/starship.toml +0 -0
  29. {hydraflow-0.5.4 → hydraflow-0.6.0}/.gitattributes +0 -0
  30. {hydraflow-0.5.4 → hydraflow-0.6.0}/.gitignore +0 -0
  31. {hydraflow-0.5.4 → hydraflow-0.6.0}/LICENSE +0 -0
  32. {hydraflow-0.5.4 → hydraflow-0.6.0}/README.md +0 -0
  33. {hydraflow-0.5.4 → hydraflow-0.6.0}/apps/quickstart.py +0 -0
  34. {hydraflow-0.5.4 → hydraflow-0.6.0}/mkdocs.yml +0 -0
  35. {hydraflow-0.5.4 → hydraflow-0.6.0}/src/hydraflow/__init__.py +0 -0
  36. {hydraflow-0.5.4 → hydraflow-0.6.0}/src/hydraflow/py.typed +0 -0
  37. {hydraflow-0.5.4 → hydraflow-0.6.0}/src/hydraflow/run_info.py +0 -0
  38. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/__init__.py +0 -0
  39. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/config/__init__.py +0 -0
  40. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/config/test_config.py +0 -0
  41. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/config/test_overrides.py +0 -0
  42. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/config/test_params.py +0 -0
  43. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/conftest.py +0 -0
  44. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/context/__init__.py +0 -0
  45. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/context/test_preemption.py +0 -0
  46. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/context/test_rerun.py +0 -0
  47. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/param/__init__.py +0 -0
  48. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/param/test_param.py +0 -0
  49. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/run/__init__.py +0 -0
  50. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/run/test_collection.py +0 -0
  51. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/run/test_data.py +0 -0
  52. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/run/test_filter.py +0 -0
  53. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/test_mlflow.py +0 -0
  54. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/utils/__init__.py +0 -0
  55. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/utils/test_run.py +0 -0
  56. {hydraflow-0.5.4 → hydraflow-0.6.0}/tests/utils/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hydraflow
3
- Version: 0.5.4
3
+ Version: 0.6.0
4
4
  Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
5
5
  Project-URL: Documentation, https://daizutabi.github.io/hydraflow/
6
6
  Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "hydraflow"
7
- version = "0.5.4"
7
+ version = "0.6.0"
8
8
  description = "Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments."
9
9
  readme = "README.md"
10
10
  license = { file = "LICENSE" }
@@ -68,8 +68,32 @@ target-version = "py310"
68
68
  [tool.ruff.lint]
69
69
  select = ["ALL"]
70
70
  unfixable = ["F401"]
71
- ignore = ["A005", "ANN003", "ANN401", "B904", "D", "EM101", "PGH003", "TRY003"]
71
+ ignore = [
72
+ "A005",
73
+ "ANN003",
74
+ "ANN401",
75
+ "B904",
76
+ "D105",
77
+ "D107",
78
+ "D203",
79
+ "D213",
80
+ "EM101",
81
+ "PGH003",
82
+ "PLR1704",
83
+ "TRY003",
84
+ ]
72
85
 
73
86
  [tool.ruff.lint.per-file-ignores]
74
- "tests/*" = ["A001", "ANN", "ARG", "FBT", "PLR", "PT", "S", "SIM108", "SLF"]
87
+ "tests/*" = [
88
+ "A001",
89
+ "ANN",
90
+ "ARG",
91
+ "D",
92
+ "FBT",
93
+ "PLR",
94
+ "PT",
95
+ "S",
96
+ "SIM108",
97
+ "SLF",
98
+ ]
75
99
  "apps/*.py" = ["D", "G", "INP"]
@@ -22,6 +22,7 @@ def collect_params(config: object) -> dict[str, Any]:
22
22
 
23
23
  Returns:
24
24
  dict[str, Any]: A dictionary of collected parameters.
25
+
25
26
  """
26
27
  return dict(iter_params(config))
27
28
 
@@ -40,6 +41,7 @@ def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
40
41
 
41
42
  Yields:
42
43
  Key-value pairs representing the parameters in the configuration object.
44
+
43
45
  """
44
46
  if config is None:
45
47
  return
@@ -113,6 +115,7 @@ def select_config(config: object, names: list[str]) -> dict[str, Any]:
113
115
 
114
116
  Returns:
115
117
  DictConfig: A new configuration object containing only the selected parameters.
118
+
116
119
  """
117
120
  if not isinstance(config, DictConfig):
118
121
  config = OmegaConf.structured(config)
@@ -48,6 +48,7 @@ def log_run(
48
48
  # Perform operations within the MLflow run context
49
49
  pass
50
50
  ```
51
+
51
52
  """
52
53
  if config:
53
54
  log_params(config, synchronous=synchronous)
@@ -55,7 +56,7 @@ def log_run(
55
56
  hc = HydraConfig.get()
56
57
  output_dir = Path(hc.runtime.output_dir)
57
58
 
58
- # Save '.hydra' config directory first.
59
+ # Save '.hydra' config directory.
59
60
  output_subdir = output_dir / (hc.output_subdir or "")
60
61
  mlflow.log_artifacts(output_subdir.as_posix(), hc.output_subdir)
61
62
 
@@ -68,14 +69,41 @@ def log_run(
68
69
  raise
69
70
 
70
71
  finally:
71
- # Save output_dir including '.hydra' config directory.
72
- mlflow.log_artifacts(output_dir.as_posix())
72
+ log_hydra(output_dir)
73
+
74
+
75
+ def log_hydra(output_dir: Path) -> None:
76
+ """Log hydra logs of the current run as artifacts.
77
+
78
+ Args:
79
+ output_dir (Path): The output directory of the Hydra job.
80
+
81
+ """
82
+ uri = mlflow.get_artifact_uri()
83
+ artifact_dir = Path(mlflow.artifacts.download_artifacts(uri))
84
+
85
+ for file_hydra in output_dir.glob("*.log"):
86
+ if not file_hydra.is_file():
87
+ continue
88
+
89
+ file_artifact = artifact_dir / file_hydra.name
90
+ if file_artifact.exists():
91
+ text = file_artifact.read_text()
92
+ if not text.endswith("\n"):
93
+ text += "\n"
94
+ else:
95
+ text = ""
96
+
97
+ text += file_hydra.read_text()
98
+ mlflow.log_text(text, file_hydra.name)
73
99
 
74
100
 
75
101
  @contextmanager
76
102
  def start_run( # noqa: PLR0913
77
103
  config: object,
78
104
  *,
105
+ chdir: bool = False,
106
+ run: Run | None = None,
79
107
  run_id: str | None = None,
80
108
  experiment_id: str | None = None,
81
109
  run_name: str | None = None,
@@ -93,6 +121,9 @@ def start_run( # noqa: PLR0913
93
121
 
94
122
  Args:
95
123
  config (object): The configuration object to log parameters from.
124
+ chdir (bool): Whether to change the current working directory to the
125
+ artifact directory of the current run. Defaults to False.
126
+ run (Run | None): The existing run. Defaults to None.
96
127
  run_id (str | None): The existing run ID. Defaults to None.
97
128
  experiment_id (str | None): The experiment ID. Defaults to None.
98
129
  run_name (str | None): The name of the run. Defaults to None.
@@ -117,7 +148,11 @@ def start_run( # noqa: PLR0913
117
148
  - `mlflow.start_run`: The MLflow function to start a run directly.
118
149
  - `log_run`: A context manager to log parameters and manage the MLflow
119
150
  run context.
151
+
120
152
  """
153
+ if run:
154
+ run_id = run.info.run_id
155
+
121
156
  with (
122
157
  mlflow.start_run(
123
158
  run_id=run_id,
@@ -131,7 +166,11 @@ def start_run( # noqa: PLR0913
131
166
  ) as run,
132
167
  log_run(config if run_id is None else None, synchronous=synchronous),
133
168
  ):
134
- yield run
169
+ if chdir:
170
+ with chdir_artifact(run):
171
+ yield run
172
+ else:
173
+ yield run
135
174
 
136
175
 
137
176
  @contextmanager
@@ -167,6 +206,7 @@ def chdir_artifact(
167
206
  Args:
168
207
  run (Run): The run to get the artifact directory from.
169
208
  artifact_path (str | None): The artifact path.
209
+
170
210
  """
171
211
  curdir = Path.cwd()
172
212
  path = mlflow.artifacts.download_artifacts(
@@ -54,6 +54,7 @@ def set_experiment(
54
54
  Returns:
55
55
  Experiment: An instance of `mlflow.entities.Experiment` representing
56
56
  the new active experiment.
57
+
57
58
  """
58
59
  if uri is not None:
59
60
  mlflow.set_tracking_uri(uri)
@@ -77,6 +78,7 @@ def log_params(config: object, *, synchronous: bool | None = None) -> None:
77
78
  config (object): The configuration object to log the parameters from.
78
79
  synchronous (bool | None): Whether to log the parameters synchronously.
79
80
  Defaults to None.
81
+
80
82
  """
81
83
  for key, value in iter_params(config):
82
84
  mlflow.log_param(key, value, synchronous=synchronous)
@@ -133,6 +135,7 @@ def search_runs( # noqa: PLR0913
133
135
 
134
136
  Returns:
135
137
  A `RunCollection` object containing the search results.
138
+
136
139
  """
137
140
  runs = mlflow.search_runs(
138
141
  experiment_ids=experiment_ids,
@@ -177,6 +180,7 @@ def list_runs(
177
180
  Returns:
178
181
  RunCollection: A `RunCollection` instance containing the runs for the
179
182
  specified experiments.
183
+
180
184
  """
181
185
  rc = _list_runs(experiment_names, n_jobs)
182
186
  if status is None:
@@ -28,6 +28,7 @@ def match(param: str, value: Any) -> bool: # noqa: PLR0911
28
28
  Returns:
29
29
  True if the parameter matches the specified value,
30
30
  False otherwise.
31
+
31
32
  """
32
33
  if callable(value):
33
34
  return value(param)
@@ -94,6 +95,7 @@ def to_value(param: str | None, type_: type) -> Any:
94
95
 
95
96
  Returns:
96
97
  The converted value.
98
+
97
99
  """
98
100
  if param is None or param == "None":
99
101
  return None
@@ -129,6 +131,7 @@ def get_params(run: Run, *names: str | list[str]) -> tuple[str | None, ...]:
129
131
  Returns:
130
132
  tuple[str | None, ...]: A tuple containing the values of the specified
131
133
  parameters in the order they were provided.
134
+
132
135
  """
133
136
  names_ = []
134
137
  for name in names:
@@ -155,6 +158,7 @@ def get_values(run: Run, names: list[str], types: list[type]) -> tuple[Any, ...]
155
158
  Returns:
156
159
  tuple[Any, ...]: A tuple containing the values of the specified
157
160
  parameters in the order they were provided.
161
+
158
162
  """
159
163
  params = get_params(run, names)
160
164
  it = zip(params, types, strict=True)
@@ -106,6 +106,7 @@ class RunCollection:
106
106
 
107
107
  Returns:
108
108
  A new `RunCollection` instance with the runs from both collections.
109
+
109
110
  """
110
111
  return self.__class__(self._runs + other._runs)
111
112
 
@@ -118,6 +119,7 @@ class RunCollection:
118
119
  Returns:
119
120
  A new `RunCollection` instance with the runs that are in this collection
120
121
  but not in the other.
122
+
121
123
  """
122
124
  runs = [run for run in self._runs if run not in other._runs] # noqa: SLF001
123
125
  return self.__class__(runs)
@@ -150,6 +152,7 @@ class RunCollection:
150
152
  Returns:
151
153
  A new `RunCollection` instance containing the first n runs if n is
152
154
  positive, or the last n runs if n is negative.
155
+
153
156
  """
154
157
  if n < 0:
155
158
  return self.__class__(self._runs[n:])
@@ -164,6 +167,7 @@ class RunCollection:
164
167
 
165
168
  Raises:
166
169
  ValueError: If the collection does not contain exactly one run.
170
+
167
171
  """
168
172
  if len(self._runs) != 1:
169
173
  raise ValueError("The collection does not contain exactly one run.")
@@ -176,6 +180,7 @@ class RunCollection:
176
180
  Returns:
177
181
  The only `Run` instance in the collection, or None if the collection
178
182
  does not contain exactly one run.
183
+
179
184
  """
180
185
  return self._runs[0] if len(self._runs) == 1 else None
181
186
 
@@ -187,6 +192,7 @@ class RunCollection:
187
192
 
188
193
  Raises:
189
194
  ValueError: If the collection is empty.
195
+
190
196
  """
191
197
  if not self._runs:
192
198
  raise ValueError("The collection is empty.")
@@ -199,6 +205,7 @@ class RunCollection:
199
205
  Returns:
200
206
  The first `Run` instance in the collection, or None if the collection
201
207
  is empty.
208
+
202
209
  """
203
210
  return self._runs[0] if self._runs else None
204
211
 
@@ -210,6 +217,7 @@ class RunCollection:
210
217
 
211
218
  Raises:
212
219
  ValueError: If the collection is empty.
220
+
213
221
  """
214
222
  if not self._runs:
215
223
  raise ValueError("The collection is empty.")
@@ -222,6 +230,7 @@ class RunCollection:
222
230
  Returns:
223
231
  The last `Run` instance in the collection, or None if the collection
224
232
  is empty.
233
+
225
234
  """
226
235
  return self._runs[-1] if self._runs else None
227
236
 
@@ -262,6 +271,7 @@ class RunCollection:
262
271
 
263
272
  Returns:
264
273
  A new `RunCollection` object containing the filtered runs.
274
+
265
275
  """
266
276
  return RunCollection(
267
277
  filter_runs(
@@ -294,6 +304,7 @@ class RunCollection:
294
304
 
295
305
  See Also:
296
306
  `filter`: Perform the actual filtering logic.
307
+
297
308
  """
298
309
  try:
299
310
  return self.filter(config, **kwargs).first()
@@ -318,6 +329,7 @@ class RunCollection:
318
329
 
319
330
  See Also:
320
331
  `filter`: Perform the actual filtering logic.
332
+
321
333
  """
322
334
  return self.filter(config, **kwargs).try_first()
323
335
 
@@ -341,6 +353,7 @@ class RunCollection:
341
353
 
342
354
  See Also:
343
355
  `filter`: Perform the actual filtering logic.
356
+
344
357
  """
345
358
  try:
346
359
  return self.filter(config, **kwargs).last()
@@ -365,6 +378,7 @@ class RunCollection:
365
378
 
366
379
  See Also:
367
380
  `filter`: Perform the actual filtering logic.
381
+
368
382
  """
369
383
  return self.filter(config, **kwargs).try_last()
370
384
 
@@ -389,6 +403,7 @@ class RunCollection:
389
403
 
390
404
  See Also:
391
405
  `filter`: Perform the actual filtering logic.
406
+
392
407
  """
393
408
  try:
394
409
  return self.filter(config, **kwargs).one()
@@ -417,6 +432,7 @@ class RunCollection:
417
432
 
418
433
  See Also:
419
434
  `filter`: Perform the actual filtering logic.
435
+
420
436
  """
421
437
  return self.filter(config, **kwargs).try_one()
422
438
 
@@ -429,6 +445,7 @@ class RunCollection:
429
445
 
430
446
  Returns:
431
447
  A list of unique parameter names.
448
+
432
449
  """
433
450
  param_names = set()
434
451
 
@@ -453,6 +470,7 @@ class RunCollection:
453
470
  Returns:
454
471
  A dictionary where the keys are parameter names and the values are
455
472
  lists of parameter values.
473
+
456
474
  """
457
475
  params = {}
458
476
 
@@ -484,6 +502,7 @@ class RunCollection:
484
502
 
485
503
  Yields:
486
504
  Results obtained by applying the function to each run in the collection.
505
+
487
506
  """
488
507
  return (func(run, *args, **kwargs) for run in self)
489
508
 
@@ -504,6 +523,7 @@ class RunCollection:
504
523
  Yields:
505
524
  Results obtained by applying the function to each run id in the
506
525
  collection.
526
+
507
527
  """
508
528
  return (func(run_id, *args, **kwargs) for run_id in self.info.run_id)
509
529
 
@@ -524,6 +544,7 @@ class RunCollection:
524
544
  Yields:
525
545
  Results obtained by applying the function to each run configuration
526
546
  in the collection.
547
+
527
548
  """
528
549
  return (func(load_config(run), *args, **kwargs) for run in self)
529
550
 
@@ -548,6 +569,7 @@ class RunCollection:
548
569
  Yields:
549
570
  Results obtained by applying the function to each artifact URI in the
550
571
  collection.
572
+
551
573
  """
552
574
  return (func(uri, *args, **kwargs) for uri in self.info.artifact_uri)
553
575
 
@@ -571,6 +593,7 @@ class RunCollection:
571
593
  Yields:
572
594
  Results obtained by applying the function to each artifact directory
573
595
  in the collection.
596
+
574
597
  """
575
598
  return (func(dir, *args, **kwargs) for dir in self.info.artifact_dir) # noqa: A001
576
599
 
@@ -594,6 +617,7 @@ class RunCollection:
594
617
  dictionary where the keys are tuples of parameter values and the
595
618
  values are `RunCollection` objects containing the runs that match
596
619
  those parameter values.
620
+
597
621
  """
598
622
  grouped_runs: dict[str | None | tuple[str | None, ...], list[Run]] = {}
599
623
  is_list = isinstance(names, list)
@@ -620,6 +644,7 @@ class RunCollection:
620
644
  key (Callable[[Run], Any] | None): A function that takes a run and returns
621
645
  a value to sort by.
622
646
  reverse (bool): If True, sort in descending order.
647
+
623
648
  """
624
649
  self._runs.sort(key=key or (lambda x: x.info.start_time), reverse=reverse)
625
650
 
@@ -633,6 +658,7 @@ class RunCollection:
633
658
 
634
659
  Returns:
635
660
  A list of values for the specified parameters.
661
+
636
662
  """
637
663
  is_list = isinstance(names, list)
638
664
 
@@ -664,6 +690,7 @@ class RunCollection:
664
690
  This can be a single parameter name or multiple names provided
665
691
  as separate arguments or as a list.
666
692
  reverse (bool): If True, sort in descending order.
693
+
667
694
  """
668
695
  values = self.values(names)
669
696
  index = sorted(range(len(self)), key=lambda i: values[i], reverse=reverse)
@@ -721,6 +748,7 @@ def filter_runs(
721
748
 
722
749
  Returns:
723
750
  A list of runs that match the specified configuration and key-value pairs.
751
+
724
752
  """
725
753
  if override:
726
754
  config = select_overrides(config)
@@ -751,6 +779,7 @@ def filter_runs_by_status(
751
779
 
752
780
  Returns:
753
781
  A list of runs that match the specified status.
782
+
754
783
  """
755
784
  if isinstance(status, str):
756
785
  if status.startswith("!"):
@@ -37,6 +37,7 @@ class RunCollectionData:
37
37
 
38
38
  Returns:
39
39
  A DataFrame containing the runs' configurations.
40
+
40
41
  """
41
42
  return DataFrame(self._runs.map_config(collect_params))
42
43
 
@@ -26,6 +26,7 @@ def get_artifact_dir(run: Run | None = None) -> Path:
26
26
 
27
27
  Returns:
28
28
  The local path to the directory where the artifacts are downloaded.
29
+
29
30
  """
30
31
  uri = mlflow.get_artifact_uri() if run is None else run.info.artifact_uri
31
32
 
@@ -52,6 +53,7 @@ def get_artifact_path(run: Run | None, path: str) -> Path:
52
53
 
53
54
  Returns:
54
55
  The local path to the artifact.
56
+
55
57
  """
56
58
  return get_artifact_dir(run) / path
57
59
 
@@ -74,6 +76,7 @@ def get_hydra_output_dir(run: Run | None = None) -> Path:
74
76
  Raises:
75
77
  FileNotFoundError: If the Hydra configuration file is not found
76
78
  in the artifacts.
79
+
77
80
  """
78
81
  if run is None:
79
82
  hc = HydraConfig.get()
@@ -102,6 +105,7 @@ def load_config(run: Run) -> DictConfig:
102
105
  Returns:
103
106
  The loaded configuration as a DictConfig object. Returns an empty
104
107
  DictConfig if the configuration file is not found.
108
+
105
109
  """
106
110
  path = get_artifact_dir(run) / ".hydra/config.yaml"
107
111
  return OmegaConf.load(path) # type: ignore
@@ -126,6 +130,7 @@ def load_overrides(run: Run) -> list[str]:
126
130
  Returns:
127
131
  The loaded overrides as a list of strings. Returns an empty list
128
132
  if the overrides file is not found.
133
+
129
134
  """
130
135
  path = get_artifact_dir(run) / ".hydra/overrides.yaml"
131
136
  return [str(x) for x in OmegaConf.load(path)]
@@ -16,8 +16,7 @@ class Config:
16
16
  height: float = 1.7
17
17
 
18
18
 
19
- cs = ConfigStore.instance()
20
- cs.store(name="config", node=Config)
19
+ ConfigStore.instance().store(name="config", node=Config)
21
20
 
22
21
 
23
22
  @hydra.main(version_base=None, config_name="config")
@@ -0,0 +1,29 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+
6
+ import hydra
7
+ from hydra.core.config_store import ConfigStore
8
+
9
+ import hydraflow
10
+
11
+
12
+ @dataclass
13
+ class Config:
14
+ count: int = 0
15
+
16
+
17
+ ConfigStore.instance().store(name="config", node=Config)
18
+
19
+
20
+ @hydra.main(version_base=None, config_name="config")
21
+ def app(cfg: Config):
22
+ hydraflow.set_experiment()
23
+
24
+ with hydraflow.start_run(cfg, chdir=True):
25
+ Path("a.txt").write_text(str(cfg.count))
26
+
27
+
28
+ if __name__ == "__main__":
29
+ app()
@@ -15,8 +15,7 @@ class Config:
15
15
  name: str = "a"
16
16
 
17
17
 
18
- cs = ConfigStore.instance()
19
- cs.store(name="config", node=Config)
18
+ ConfigStore.instance().store(name="config", node=Config)
20
19
 
21
20
 
22
21
  @hydra.main(version_base=None, config_name="config")
@@ -24,9 +23,6 @@ def app(cfg: Config):
24
23
  hydraflow.set_experiment()
25
24
 
26
25
  with hydraflow.start_run(cfg) as run:
27
- with hydraflow.chdir_hydra_output():
28
- Path("a.txt").write_text("chdir_hydra_output")
29
-
30
26
  with hydraflow.chdir_artifact(run):
31
27
  Path("b.txt").write_text("chdir_artifact")
32
28
 
@@ -0,0 +1,39 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+
7
+ import hydra
8
+ from hydra.core.config_store import ConfigStore
9
+
10
+ import hydraflow
11
+
12
+ log = logging.getLogger(__name__)
13
+
14
+
15
+ @dataclass
16
+ class Config:
17
+ count: int = 0
18
+
19
+
20
+ ConfigStore.instance().store(name="config", node=Config)
21
+
22
+
23
+ @hydra.main(version_base=None, config_name="config")
24
+ def app(cfg: Config):
25
+ hydraflow.set_experiment()
26
+
27
+ run = hydraflow.list_runs().try_get(cfg, override=True)
28
+
29
+ with hydraflow.start_run(cfg, run=run):
30
+ log.info("second" if run else "first")
31
+ log.info(cfg.count)
32
+
33
+ with hydraflow.chdir_hydra_output():
34
+ Path("text.log").write_text("text\n")
35
+ Path("dir.log").mkdir()
36
+
37
+
38
+ if __name__ == "__main__":
39
+ app()
@@ -17,8 +17,7 @@ class Config:
17
17
  count: int = 0
18
18
 
19
19
 
20
- cs = ConfigStore.instance()
21
- cs.store(name="config", node=Config)
20
+ ConfigStore.instance().store(name="config", node=Config)
22
21
 
23
22
 
24
23
  @hydra.main(version_base=None, config_name="config")
@@ -17,20 +17,16 @@ class Config:
17
17
  count: int = 0
18
18
 
19
19
 
20
- cs = ConfigStore.instance()
21
- cs.store(name="config", node=Config)
20
+ ConfigStore.instance().store(name="config", node=Config)
22
21
 
23
22
 
24
23
  @hydra.main(version_base=None, config_name="config")
25
24
  def app(cfg: Config):
26
25
  hydraflow.set_experiment()
27
26
 
28
- if run := hydraflow.list_runs().try_find(cfg, override=True):
29
- run_id = run.info.run_id
30
- else:
31
- run_id = None
27
+ run = hydraflow.list_runs().try_find(cfg, override=True)
32
28
 
33
- with hydraflow.start_run(cfg, run_id=run_id) as run:
29
+ with hydraflow.start_run(cfg, run=run) as run:
34
30
  log(hydraflow.get_artifact_dir(run))
35
31
 
36
32
 
@@ -0,0 +1,27 @@
1
+ import pytest
2
+ from mlflow.entities import Run
3
+
4
+ from hydraflow.run_collection import RunCollection
5
+
6
+ pytestmark = pytest.mark.xdist_group(name="group4")
7
+
8
+
9
+ @pytest.fixture(scope="module")
10
+ def rc(collect):
11
+ return collect("context/chdir.py", ["-m", "count=1,2"])
12
+
13
+
14
+ def test_rc_len(rc: RunCollection):
15
+ assert len(rc) == 2
16
+
17
+
18
+ @pytest.fixture(scope="module", params=[1, 2])
19
+ def run(rc: RunCollection, request: pytest.FixtureRequest):
20
+ return rc.get(count=request.param)
21
+
22
+
23
+ def test_run_count(run: Run):
24
+ from hydraflow.utils import get_artifact_path
25
+
26
+ text = get_artifact_path(run, "a.txt").read_text()
27
+ assert text == run.data.params["count"]
@@ -2,7 +2,7 @@ import pytest
2
2
  from mlflow.entities import Run
3
3
 
4
4
  from hydraflow.run_collection import RunCollection
5
- from hydraflow.utils import get_artifact_path, get_hydra_output_dir
5
+ from hydraflow.utils import get_artifact_path
6
6
 
7
7
  pytestmark = pytest.mark.xdist_group(name="group2")
8
8
 
@@ -18,16 +18,6 @@ def run(rc: RunCollection, request: pytest.FixtureRequest):
18
18
  return rc[request.param]
19
19
 
20
20
 
21
- def test_chdir_hydra_output(run: Run):
22
- path = get_hydra_output_dir(run)
23
- assert (path / "a.txt").read_text() == "chdir_hydra_output"
24
-
25
-
26
21
  def test_chdir_artifact(run: Run):
27
22
  path = get_artifact_path(run, "b.txt")
28
23
  assert path.read_text() == "chdir_artifact"
29
-
30
-
31
- def test_log_run(run: Run):
32
- path = get_artifact_path(run, "a.txt")
33
- assert path.read_text() == "chdir_hydra_output"
@@ -0,0 +1,51 @@
1
+ import pytest
2
+ from mlflow.entities import Run
3
+
4
+ from hydraflow.run_collection import RunCollection
5
+ from hydraflow.utils import get_artifact_path
6
+
7
+ pytestmark = pytest.mark.xdist_group(name="group6")
8
+
9
+
10
+ @pytest.fixture(scope="module")
11
+ def rc(collect):
12
+ collect("context/logging.py", ["count=100"])
13
+ return collect("context/logging.py", ["count=100"])
14
+
15
+
16
+ def test_rc_len(rc: RunCollection):
17
+ assert len(rc) == 1
18
+
19
+
20
+ @pytest.fixture(scope="module")
21
+ def run(rc: RunCollection):
22
+ return rc[0]
23
+
24
+
25
+ @pytest.fixture(scope="module")
26
+ def hydra_log(run: Run, experiment_name: str):
27
+ path = get_artifact_path(run, f"{experiment_name}.log")
28
+ return path.read_text()
29
+
30
+
31
+ @pytest.mark.parametrize(
32
+ ("i", "suffix"),
33
+ [(0, "] - first"), (1, "] - 100"), (2, "] - second"), (3, "] - 100")],
34
+ )
35
+ def test_hydra_log(hydra_log: str, i: int, suffix: str):
36
+ assert hydra_log.splitlines()[i].endswith(suffix)
37
+
38
+
39
+ def test_text_log(run: Run):
40
+ path = get_artifact_path(run, "text.log")
41
+ assert path.read_text() == "text\ntext\n"
42
+
43
+
44
+ def test_dir_log(run: Run):
45
+ assert not get_artifact_path(run, "dir.log").exists()
46
+
47
+
48
+ def test_config(run: Run):
49
+ path = get_artifact_path(run, ".hydra/config.yaml")
50
+ cfg = path.read_text()
51
+ assert cfg == "count: 100\n"
@@ -21,8 +21,7 @@ class Config:
21
21
  data: Data = field(default_factory=Data)
22
22
 
23
23
 
24
- cs = ConfigStore.instance()
25
- cs.store(name="config", node=Config)
24
+ ConfigStore.instance().store(name="config", node=Config)
26
25
 
27
26
 
28
27
  @hydra.main(version_base=None, config_name="config")
@@ -3,7 +3,7 @@ from mlflow.entities import Run
3
3
 
4
4
  from hydraflow.run_collection import RunCollection
5
5
 
6
- pytestmark = pytest.mark.xdist_group(name="group2")
6
+ pytestmark = pytest.mark.xdist_group(name="group1")
7
7
 
8
8
 
9
9
  @pytest.fixture(scope="module")
@@ -14,8 +14,7 @@ class Config:
14
14
  port: int = 3306
15
15
 
16
16
 
17
- cs = ConfigStore.instance()
18
- cs.store(name="config", node=Config)
17
+ ConfigStore.instance().store(name="config", node=Config)
19
18
 
20
19
 
21
20
  @hydra.main(version_base=None, config_name="config")
@@ -21,8 +21,7 @@ class Config:
21
21
  data: Data = field(default_factory=Data)
22
22
 
23
23
 
24
- cs = ConfigStore.instance()
25
- cs.store(name="config", node=Config)
24
+ ConfigStore.instance().store(name="config", node=Config)
26
25
 
27
26
 
28
27
  @hydra.main(version_base=None, config_name="config")
@@ -4,7 +4,7 @@ from mlflow.entities import Experiment
4
4
 
5
5
  from hydraflow.run_collection import RunCollection
6
6
 
7
- pytestmark = pytest.mark.xdist_group(name="group4")
7
+ pytestmark = pytest.mark.xdist_group(name="group7")
8
8
 
9
9
 
10
10
  @pytest.fixture(scope="module")
@@ -7,7 +7,7 @@ from hydraflow.run_collection import RunCollection
7
7
  if TYPE_CHECKING:
8
8
  from .run import Config
9
9
 
10
- pytestmark = pytest.mark.xdist_group(name="group4")
10
+ pytestmark = pytest.mark.xdist_group(name="group5")
11
11
 
12
12
 
13
13
  @pytest.fixture(scope="module")
@@ -16,8 +16,7 @@ class Config:
16
16
  height: float = 1.7
17
17
 
18
18
 
19
- cs = ConfigStore.instance()
20
- cs.store(name="config", node=Config)
19
+ ConfigStore.instance().store(name="config", node=Config)
21
20
 
22
21
 
23
22
  @hydra.main(version_base=None, config_name="config")
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes