hydraflow 0.5.3__tar.gz → 0.5.4__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (52) hide show
  1. {hydraflow-0.5.3 → hydraflow-0.5.4}/.devcontainer/postCreate.sh +1 -4
  2. {hydraflow-0.5.3 → hydraflow-0.5.4}/PKG-INFO +1 -1
  3. {hydraflow-0.5.3 → hydraflow-0.5.4}/pyproject.toml +9 -33
  4. {hydraflow-0.5.3 → hydraflow-0.5.4}/src/hydraflow/config.py +0 -3
  5. {hydraflow-0.5.3 → hydraflow-0.5.4}/src/hydraflow/context.py +0 -3
  6. {hydraflow-0.5.3 → hydraflow-0.5.4}/src/hydraflow/mlflow.py +0 -4
  7. {hydraflow-0.5.3 → hydraflow-0.5.4}/src/hydraflow/param.py +4 -5
  8. {hydraflow-0.5.3 → hydraflow-0.5.4}/src/hydraflow/run_collection.py +24 -31
  9. {hydraflow-0.5.3 → hydraflow-0.5.4}/src/hydraflow/run_data.py +0 -1
  10. {hydraflow-0.5.3 → hydraflow-0.5.4}/src/hydraflow/utils.py +0 -5
  11. hydraflow-0.5.3/tests/config/test_hydra.py → hydraflow-0.5.4/tests/config/test_overrides.py +3 -1
  12. {hydraflow-0.5.3 → hydraflow-0.5.4}/tests/conftest.py +1 -1
  13. hydraflow-0.5.4/tests/context/preemption.py +49 -0
  14. hydraflow-0.5.4/tests/context/rerun.py +44 -0
  15. hydraflow-0.5.3/tests/context/test_hydra.py → hydraflow-0.5.4/tests/context/test_context.py +2 -0
  16. hydraflow-0.5.4/tests/context/test_preemption.py +41 -0
  17. hydraflow-0.5.4/tests/context/test_rerun.py +31 -0
  18. {hydraflow-0.5.3 → hydraflow-0.5.4}/tests/param/test_param.py +10 -0
  19. hydraflow-0.5.3/tests/param/test_hydra.py → hydraflow-0.5.4/tests/param/test_params.py +3 -1
  20. {hydraflow-0.5.3 → hydraflow-0.5.4}/tests/run/test_collection.py +2 -0
  21. {hydraflow-0.5.3 → hydraflow-0.5.4}/tests/run/test_data.py +2 -0
  22. {hydraflow-0.5.3 → hydraflow-0.5.4}/tests/run/test_filter.py +2 -0
  23. {hydraflow-0.5.3 → hydraflow-0.5.4}/tests/run/test_info.py +2 -0
  24. hydraflow-0.5.3/tests/run/test_hydra.py → hydraflow-0.5.4/tests/run/test_run.py +2 -0
  25. {hydraflow-0.5.3 → hydraflow-0.5.4}/tests/test_mlflow.py +6 -6
  26. {hydraflow-0.5.3 → hydraflow-0.5.4}/tests/utils/test_run.py +4 -0
  27. hydraflow-0.5.3/tests/utils/test_hydra.py → hydraflow-0.5.4/tests/utils/test_utils.py +7 -0
  28. {hydraflow-0.5.3 → hydraflow-0.5.4}/.devcontainer/devcontainer.json +0 -0
  29. {hydraflow-0.5.3 → hydraflow-0.5.4}/.devcontainer/starship.toml +0 -0
  30. {hydraflow-0.5.3 → hydraflow-0.5.4}/.gitattributes +0 -0
  31. {hydraflow-0.5.3 → hydraflow-0.5.4}/.gitignore +0 -0
  32. {hydraflow-0.5.3 → hydraflow-0.5.4}/LICENSE +0 -0
  33. {hydraflow-0.5.3 → hydraflow-0.5.4}/README.md +0 -0
  34. {hydraflow-0.5.3 → hydraflow-0.5.4}/apps/quickstart.py +0 -0
  35. {hydraflow-0.5.3 → hydraflow-0.5.4}/mkdocs.yml +0 -0
  36. {hydraflow-0.5.3 → hydraflow-0.5.4}/src/hydraflow/__init__.py +0 -0
  37. {hydraflow-0.5.3 → hydraflow-0.5.4}/src/hydraflow/py.typed +0 -0
  38. {hydraflow-0.5.3 → hydraflow-0.5.4}/src/hydraflow/run_info.py +0 -0
  39. {hydraflow-0.5.3 → hydraflow-0.5.4}/tests/__init__.py +0 -0
  40. {hydraflow-0.5.3 → hydraflow-0.5.4}/tests/config/__init__.py +0 -0
  41. /hydraflow-0.5.3/tests/config/config.py → /hydraflow-0.5.4/tests/config/overrides.py +0 -0
  42. {hydraflow-0.5.3 → hydraflow-0.5.4}/tests/config/test_config.py +0 -0
  43. {hydraflow-0.5.3 → hydraflow-0.5.4}/tests/config/test_params.py +0 -0
  44. {hydraflow-0.5.3 → hydraflow-0.5.4}/tests/context/__init__.py +0 -0
  45. {hydraflow-0.5.3 → hydraflow-0.5.4}/tests/context/context.py +0 -0
  46. {hydraflow-0.5.3 → hydraflow-0.5.4}/tests/param/__init__.py +0 -0
  47. /hydraflow-0.5.3/tests/param/param.py → /hydraflow-0.5.4/tests/param/params.py +0 -0
  48. {hydraflow-0.5.3 → hydraflow-0.5.4}/tests/run/__init__.py +0 -0
  49. {hydraflow-0.5.3 → hydraflow-0.5.4}/tests/run/filter.py +0 -0
  50. {hydraflow-0.5.3 → hydraflow-0.5.4}/tests/run/run.py +0 -0
  51. {hydraflow-0.5.3 → hydraflow-0.5.4}/tests/utils/__init__.py +0 -0
  52. {hydraflow-0.5.3 → hydraflow-0.5.4}/tests/utils/utils.py +0 -0
@@ -5,7 +5,4 @@ mkdir -p ~/.config
5
5
  cp .devcontainer/starship.toml ~/.config
6
6
 
7
7
  curl -LsSf https://astral.sh/uv/install.sh | sh
8
- echo 'eval "$(uv generate-shell-completion bash)"' >> ~/.bashrc
9
- uv tool install ruff@latest
10
- uv python install
11
- uv sync -U
8
+ echo 'eval "$(uv generate-shell-completion bash)"' >> ~/.bashrc
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hydraflow
3
- Version: 0.5.3
3
+ Version: 0.5.4
4
4
  Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
5
5
  Project-URL: Documentation, https://daizutabi.github.io/hydraflow/
6
6
  Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "hydraflow"
7
- version = "0.5.3"
7
+ version = "0.5.4"
8
8
  description = "Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments."
9
9
  readme = "README.md"
10
10
  license = { file = "LICENSE" }
@@ -33,6 +33,9 @@ dev-dependencies = [
33
33
  "mkdocs-material",
34
34
  "mkdocs>=1.6",
35
35
  "pytest-cov",
36
+ "pytest-order",
37
+ "pytest-randomly",
38
+ "pytest-xdist",
36
39
  ]
37
40
 
38
41
  [tool.hatch.build.targets.sdist]
@@ -43,11 +46,11 @@ packages = ["src/hydraflow"]
43
46
 
44
47
  [tool.pytest.ini_options]
45
48
  addopts = [
46
- "--doctest-modules",
47
49
  "--cov=hydraflow",
48
50
  "--cov-report=lcov:lcov.info",
51
+ "-n8",
52
+ "--dist=loadgroup",
49
53
  ]
50
- doctest_optionflags = ["NORMALIZE_WHITESPACE", "IGNORE_EXCEPTION_DETAIL"]
51
54
  filterwarnings = [
52
55
  "ignore:pkg_resources is deprecated:DeprecationWarning",
53
56
  "ignore:Support for class-based `config` is deprecated",
@@ -65,35 +68,8 @@ target-version = "py310"
65
68
  [tool.ruff.lint]
66
69
  select = ["ALL"]
67
70
  unfixable = ["F401"]
68
- ignore = [
69
- "A005",
70
- "ANN003",
71
- "ANN401",
72
- "ARG002",
73
- "B904",
74
- "D105",
75
- "D107",
76
- "D203",
77
- "D213",
78
- "EM101",
79
- "PGH003",
80
- "TRY003",
81
- ]
82
- exclude = ["tests/scripts/*.py"]
71
+ ignore = ["A005", "ANN003", "ANN401", "B904", "D", "EM101", "PGH003", "TRY003"]
83
72
 
84
73
  [tool.ruff.lint.per-file-ignores]
85
- "tests/*" = [
86
- "A001",
87
- "ANN",
88
- "ARG",
89
- "D",
90
- "FBT",
91
- "PD",
92
- "PLR",
93
- "PT",
94
- "S",
95
- "SIM117",
96
- "TID",
97
- "SLF",
98
- ]
99
- "apps/*.py" = ["INP", "D", "G", "T"]
74
+ "tests/*" = ["A001", "ANN", "ARG", "FBT", "PLR", "PT", "S", "SIM108", "SLF"]
75
+ "apps/*.py" = ["D", "G", "INP"]
@@ -22,7 +22,6 @@ def collect_params(config: object) -> dict[str, Any]:
22
22
 
23
23
  Returns:
24
24
  dict[str, Any]: A dictionary of collected parameters.
25
-
26
25
  """
27
26
  return dict(iter_params(config))
28
27
 
@@ -41,7 +40,6 @@ def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
41
40
 
42
41
  Yields:
43
42
  Key-value pairs representing the parameters in the configuration object.
44
-
45
43
  """
46
44
  if config is None:
47
45
  return
@@ -115,7 +113,6 @@ def select_config(config: object, names: list[str]) -> dict[str, Any]:
115
113
 
116
114
  Returns:
117
115
  DictConfig: A new configuration object containing only the selected parameters.
118
-
119
116
  """
120
117
  if not isinstance(config, DictConfig):
121
118
  config = OmegaConf.structured(config)
@@ -48,7 +48,6 @@ def log_run(
48
48
  # Perform operations within the MLflow run context
49
49
  pass
50
50
  ```
51
-
52
51
  """
53
52
  if config:
54
53
  log_params(config, synchronous=synchronous)
@@ -118,7 +117,6 @@ def start_run( # noqa: PLR0913
118
117
  - `mlflow.start_run`: The MLflow function to start a run directly.
119
118
  - `log_run`: A context manager to log parameters and manage the MLflow
120
119
  run context.
121
-
122
120
  """
123
121
  with (
124
122
  mlflow.start_run(
@@ -169,7 +167,6 @@ def chdir_artifact(
169
167
  Args:
170
168
  run (Run): The run to get the artifact directory from.
171
169
  artifact_path (str | None): The artifact path.
172
-
173
170
  """
174
171
  curdir = Path.cwd()
175
172
  path = mlflow.artifacts.download_artifacts(
@@ -54,7 +54,6 @@ def set_experiment(
54
54
  Returns:
55
55
  Experiment: An instance of `mlflow.entities.Experiment` representing
56
56
  the new active experiment.
57
-
58
57
  """
59
58
  if uri is not None:
60
59
  mlflow.set_tracking_uri(uri)
@@ -78,7 +77,6 @@ def log_params(config: object, *, synchronous: bool | None = None) -> None:
78
77
  config (object): The configuration object to log the parameters from.
79
78
  synchronous (bool | None): Whether to log the parameters synchronously.
80
79
  Defaults to None.
81
-
82
80
  """
83
81
  for key, value in iter_params(config):
84
82
  mlflow.log_param(key, value, synchronous=synchronous)
@@ -135,7 +133,6 @@ def search_runs( # noqa: PLR0913
135
133
 
136
134
  Returns:
137
135
  A `RunCollection` object containing the search results.
138
-
139
136
  """
140
137
  runs = mlflow.search_runs(
141
138
  experiment_ids=experiment_ids,
@@ -180,7 +177,6 @@ def list_runs(
180
177
  Returns:
181
178
  RunCollection: A `RunCollection` instance containing the runs for the
182
179
  specified experiments.
183
-
184
180
  """
185
181
  rc = _list_runs(experiment_names, n_jobs)
186
182
  if status is None:
@@ -18,7 +18,7 @@ if TYPE_CHECKING:
18
18
  from mlflow.entities import Run
19
19
 
20
20
 
21
- def match(param: str, value: Any) -> bool:
21
+ def match(param: str, value: Any) -> bool: # noqa: PLR0911
22
22
  """Check if the string matches the specified value.
23
23
 
24
24
  Args:
@@ -28,8 +28,10 @@ def match(param: str, value: Any) -> bool:
28
28
  Returns:
29
29
  True if the parameter matches the specified value,
30
30
  False otherwise.
31
-
32
31
  """
32
+ if callable(value):
33
+ return value(param)
34
+
33
35
  if any(value is x for x in [None, True, False]):
34
36
  return param == str(value)
35
37
 
@@ -92,7 +94,6 @@ def to_value(param: str | None, type_: type) -> Any:
92
94
 
93
95
  Returns:
94
96
  The converted value.
95
-
96
97
  """
97
98
  if param is None or param == "None":
98
99
  return None
@@ -128,7 +129,6 @@ def get_params(run: Run, *names: str | list[str]) -> tuple[str | None, ...]:
128
129
  Returns:
129
130
  tuple[str | None, ...]: A tuple containing the values of the specified
130
131
  parameters in the order they were provided.
131
-
132
132
  """
133
133
  names_ = []
134
134
  for name in names:
@@ -155,7 +155,6 @@ def get_values(run: Run, names: list[str], types: list[type]) -> tuple[Any, ...]
155
155
  Returns:
156
156
  tuple[Any, ...]: A tuple containing the values of the specified
157
157
  parameters in the order they were provided.
158
-
159
158
  """
160
159
  params = get_params(run, names)
161
160
  it = zip(params, types, strict=True)
@@ -106,7 +106,6 @@ class RunCollection:
106
106
 
107
107
  Returns:
108
108
  A new `RunCollection` instance with the runs from both collections.
109
-
110
109
  """
111
110
  return self.__class__(self._runs + other._runs)
112
111
 
@@ -119,7 +118,6 @@ class RunCollection:
119
118
  Returns:
120
119
  A new `RunCollection` instance with the runs that are in this collection
121
120
  but not in the other.
122
-
123
121
  """
124
122
  runs = [run for run in self._runs if run not in other._runs] # noqa: SLF001
125
123
  return self.__class__(runs)
@@ -152,7 +150,6 @@ class RunCollection:
152
150
  Returns:
153
151
  A new `RunCollection` instance containing the first n runs if n is
154
152
  positive, or the last n runs if n is negative.
155
-
156
153
  """
157
154
  if n < 0:
158
155
  return self.__class__(self._runs[n:])
@@ -167,7 +164,6 @@ class RunCollection:
167
164
 
168
165
  Raises:
169
166
  ValueError: If the collection does not contain exactly one run.
170
-
171
167
  """
172
168
  if len(self._runs) != 1:
173
169
  raise ValueError("The collection does not contain exactly one run.")
@@ -180,7 +176,6 @@ class RunCollection:
180
176
  Returns:
181
177
  The only `Run` instance in the collection, or None if the collection
182
178
  does not contain exactly one run.
183
-
184
179
  """
185
180
  return self._runs[0] if len(self._runs) == 1 else None
186
181
 
@@ -192,7 +187,6 @@ class RunCollection:
192
187
 
193
188
  Raises:
194
189
  ValueError: If the collection is empty.
195
-
196
190
  """
197
191
  if not self._runs:
198
192
  raise ValueError("The collection is empty.")
@@ -205,7 +199,6 @@ class RunCollection:
205
199
  Returns:
206
200
  The first `Run` instance in the collection, or None if the collection
207
201
  is empty.
208
-
209
202
  """
210
203
  return self._runs[0] if self._runs else None
211
204
 
@@ -217,7 +210,6 @@ class RunCollection:
217
210
 
218
211
  Raises:
219
212
  ValueError: If the collection is empty.
220
-
221
213
  """
222
214
  if not self._runs:
223
215
  raise ValueError("The collection is empty.")
@@ -230,11 +222,18 @@ class RunCollection:
230
222
  Returns:
231
223
  The last `Run` instance in the collection, or None if the collection
232
224
  is empty.
233
-
234
225
  """
235
226
  return self._runs[-1] if self._runs else None
236
227
 
237
- def filter(self, config: object | None = None, **kwargs) -> RunCollection:
228
+ def filter(
229
+ self,
230
+ config: object | None = None,
231
+ *,
232
+ override: bool = False,
233
+ select: list[str] | None = None,
234
+ status: str | list[str] | int | list[int] | None = None,
235
+ **kwargs,
236
+ ) -> RunCollection:
238
237
  """Filter the `Run` instances based on the provided configuration.
239
238
 
240
239
  This method filters the runs in the collection according to the
@@ -254,13 +253,26 @@ class RunCollection:
254
253
  config (object | None): The configuration object to filter the runs.
255
254
  This can be any object that provides key-value pairs through
256
255
  the `iter_params` function.
256
+ override (bool): If True, override the configuration object with the
257
+ provided key-value pairs.
258
+ select (list[str] | None): The list of parameters to select.
259
+ status (str | list[str] | int | list[int] | None): The status of the
260
+ runs to filter.
257
261
  **kwargs: Additional key-value pairs to filter the runs.
258
262
 
259
263
  Returns:
260
264
  A new `RunCollection` object containing the filtered runs.
261
-
262
265
  """
263
- return RunCollection(filter_runs(self._runs, config, **kwargs))
266
+ return RunCollection(
267
+ filter_runs(
268
+ self._runs,
269
+ config,
270
+ override=override,
271
+ select=select,
272
+ status=status,
273
+ **kwargs,
274
+ ),
275
+ )
264
276
 
265
277
  def find(self, config: object | None = None, **kwargs) -> Run:
266
278
  """Find the first `Run` instance based on the provided configuration.
@@ -282,7 +294,6 @@ class RunCollection:
282
294
 
283
295
  See Also:
284
296
  `filter`: Perform the actual filtering logic.
285
-
286
297
  """
287
298
  try:
288
299
  return self.filter(config, **kwargs).first()
@@ -307,7 +318,6 @@ class RunCollection:
307
318
 
308
319
  See Also:
309
320
  `filter`: Perform the actual filtering logic.
310
-
311
321
  """
312
322
  return self.filter(config, **kwargs).try_first()
313
323
 
@@ -331,7 +341,6 @@ class RunCollection:
331
341
 
332
342
  See Also:
333
343
  `filter`: Perform the actual filtering logic.
334
-
335
344
  """
336
345
  try:
337
346
  return self.filter(config, **kwargs).last()
@@ -356,7 +365,6 @@ class RunCollection:
356
365
 
357
366
  See Also:
358
367
  `filter`: Perform the actual filtering logic.
359
-
360
368
  """
361
369
  return self.filter(config, **kwargs).try_last()
362
370
 
@@ -381,7 +389,6 @@ class RunCollection:
381
389
 
382
390
  See Also:
383
391
  `filter`: Perform the actual filtering logic.
384
-
385
392
  """
386
393
  try:
387
394
  return self.filter(config, **kwargs).one()
@@ -410,7 +417,6 @@ class RunCollection:
410
417
 
411
418
  See Also:
412
419
  `filter`: Perform the actual filtering logic.
413
-
414
420
  """
415
421
  return self.filter(config, **kwargs).try_one()
416
422
 
@@ -423,7 +429,6 @@ class RunCollection:
423
429
 
424
430
  Returns:
425
431
  A list of unique parameter names.
426
-
427
432
  """
428
433
  param_names = set()
429
434
 
@@ -448,7 +453,6 @@ class RunCollection:
448
453
  Returns:
449
454
  A dictionary where the keys are parameter names and the values are
450
455
  lists of parameter values.
451
-
452
456
  """
453
457
  params = {}
454
458
 
@@ -480,7 +484,6 @@ class RunCollection:
480
484
 
481
485
  Yields:
482
486
  Results obtained by applying the function to each run in the collection.
483
-
484
487
  """
485
488
  return (func(run, *args, **kwargs) for run in self)
486
489
 
@@ -501,7 +504,6 @@ class RunCollection:
501
504
  Yields:
502
505
  Results obtained by applying the function to each run id in the
503
506
  collection.
504
-
505
507
  """
506
508
  return (func(run_id, *args, **kwargs) for run_id in self.info.run_id)
507
509
 
@@ -522,7 +524,6 @@ class RunCollection:
522
524
  Yields:
523
525
  Results obtained by applying the function to each run configuration
524
526
  in the collection.
525
-
526
527
  """
527
528
  return (func(load_config(run), *args, **kwargs) for run in self)
528
529
 
@@ -547,7 +548,6 @@ class RunCollection:
547
548
  Yields:
548
549
  Results obtained by applying the function to each artifact URI in the
549
550
  collection.
550
-
551
551
  """
552
552
  return (func(uri, *args, **kwargs) for uri in self.info.artifact_uri)
553
553
 
@@ -571,7 +571,6 @@ class RunCollection:
571
571
  Yields:
572
572
  Results obtained by applying the function to each artifact directory
573
573
  in the collection.
574
-
575
574
  """
576
575
  return (func(dir, *args, **kwargs) for dir in self.info.artifact_dir) # noqa: A001
577
576
 
@@ -595,7 +594,6 @@ class RunCollection:
595
594
  dictionary where the keys are tuples of parameter values and the
596
595
  values are `RunCollection` objects containing the runs that match
597
596
  those parameter values.
598
-
599
597
  """
600
598
  grouped_runs: dict[str | None | tuple[str | None, ...], list[Run]] = {}
601
599
  is_list = isinstance(names, list)
@@ -622,7 +620,6 @@ class RunCollection:
622
620
  key (Callable[[Run], Any] | None): A function that takes a run and returns
623
621
  a value to sort by.
624
622
  reverse (bool): If True, sort in descending order.
625
-
626
623
  """
627
624
  self._runs.sort(key=key or (lambda x: x.info.start_time), reverse=reverse)
628
625
 
@@ -636,7 +633,6 @@ class RunCollection:
636
633
 
637
634
  Returns:
638
635
  A list of values for the specified parameters.
639
-
640
636
  """
641
637
  is_list = isinstance(names, list)
642
638
 
@@ -668,7 +664,6 @@ class RunCollection:
668
664
  This can be a single parameter name or multiple names provided
669
665
  as separate arguments or as a list.
670
666
  reverse (bool): If True, sort in descending order.
671
-
672
667
  """
673
668
  values = self.values(names)
674
669
  index = sorted(range(len(self)), key=lambda i: values[i], reverse=reverse)
@@ -726,7 +721,6 @@ def filter_runs(
726
721
 
727
722
  Returns:
728
723
  A list of runs that match the specified configuration and key-value pairs.
729
-
730
724
  """
731
725
  if override:
732
726
  config = select_overrides(config)
@@ -757,7 +751,6 @@ def filter_runs_by_status(
757
751
 
758
752
  Returns:
759
753
  A list of runs that match the specified status.
760
-
761
754
  """
762
755
  if isinstance(status, str):
763
756
  if status.startswith("!"):
@@ -37,7 +37,6 @@ class RunCollectionData:
37
37
 
38
38
  Returns:
39
39
  A DataFrame containing the runs' configurations.
40
-
41
40
  """
42
41
  return DataFrame(self._runs.map_config(collect_params))
43
42
 
@@ -26,7 +26,6 @@ def get_artifact_dir(run: Run | None = None) -> Path:
26
26
 
27
27
  Returns:
28
28
  The local path to the directory where the artifacts are downloaded.
29
-
30
29
  """
31
30
  uri = mlflow.get_artifact_uri() if run is None else run.info.artifact_uri
32
31
 
@@ -53,7 +52,6 @@ def get_artifact_path(run: Run | None, path: str) -> Path:
53
52
 
54
53
  Returns:
55
54
  The local path to the artifact.
56
-
57
55
  """
58
56
  return get_artifact_dir(run) / path
59
57
 
@@ -76,7 +74,6 @@ def get_hydra_output_dir(run: Run | None = None) -> Path:
76
74
  Raises:
77
75
  FileNotFoundError: If the Hydra configuration file is not found
78
76
  in the artifacts.
79
-
80
77
  """
81
78
  if run is None:
82
79
  hc = HydraConfig.get()
@@ -105,7 +102,6 @@ def load_config(run: Run) -> DictConfig:
105
102
  Returns:
106
103
  The loaded configuration as a DictConfig object. Returns an empty
107
104
  DictConfig if the configuration file is not found.
108
-
109
105
  """
110
106
  path = get_artifact_dir(run) / ".hydra/config.yaml"
111
107
  return OmegaConf.load(path) # type: ignore
@@ -130,7 +126,6 @@ def load_overrides(run: Run) -> list[str]:
130
126
  Returns:
131
127
  The loaded overrides as a list of strings. Returns an empty list
132
128
  if the overrides file is not found.
133
-
134
129
  """
135
130
  path = get_artifact_dir(run) / ".hydra/overrides.yaml"
136
131
  return [str(x) for x in OmegaConf.load(path)]
@@ -6,11 +6,13 @@ from mlflow.entities import Run
6
6
 
7
7
  from hydraflow.run_collection import RunCollection
8
8
 
9
+ pytestmark = pytest.mark.xdist_group(name="group1")
10
+
9
11
 
10
12
  @pytest.fixture(scope="module")
11
13
  def rc(collect):
12
14
  args = ["-m", "name=a,b", "height=3"]
13
- return collect("config/config.py", args)
15
+ return collect("config/overrides.py", args)
14
16
 
15
17
 
16
18
  @pytest.fixture(scope="module")
@@ -15,7 +15,7 @@ def experiment_name(tmp_path_factory: pytest.TempPathFactory):
15
15
  cwd = Path.cwd()
16
16
  name = str(uuid.uuid4())
17
17
 
18
- os.chdir(tmp_path_factory.mktemp(name))
18
+ os.chdir(tmp_path_factory.mktemp(name, numbered=False))
19
19
 
20
20
  yield name
21
21
 
@@ -0,0 +1,49 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import TYPE_CHECKING
5
+
6
+ import hydra
7
+ from hydra.core.config_store import ConfigStore
8
+
9
+ import hydraflow
10
+
11
+ if TYPE_CHECKING:
12
+ from pathlib import Path
13
+
14
+
15
+ @dataclass
16
+ class Config:
17
+ count: int = 0
18
+
19
+
20
+ cs = ConfigStore.instance()
21
+ cs.store(name="config", node=Config)
22
+
23
+
24
+ @hydra.main(version_base=None, config_name="config")
25
+ def app(cfg: Config):
26
+ hydraflow.set_experiment()
27
+
28
+ rc = hydraflow.list_runs()
29
+
30
+ if rc.filter(cfg, status="finished", override=True):
31
+ return
32
+
33
+ if run := rc.try_find(cfg, override=True):
34
+ run_id = run.info.run_id
35
+ else:
36
+ run_id = None
37
+
38
+ with hydraflow.start_run(cfg, run_id=run_id) as run:
39
+ log(hydraflow.get_artifact_dir(run))
40
+
41
+
42
+ def log(path: Path):
43
+ file = path / "a.txt"
44
+ text = file.read_text() if file.exists() else ""
45
+ file.write_text(text + "a")
46
+
47
+
48
+ if __name__ == "__main__":
49
+ app()
@@ -0,0 +1,44 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import TYPE_CHECKING
5
+
6
+ import hydra
7
+ from hydra.core.config_store import ConfigStore
8
+
9
+ import hydraflow
10
+
11
+ if TYPE_CHECKING:
12
+ from pathlib import Path
13
+
14
+
15
+ @dataclass
16
+ class Config:
17
+ count: int = 0
18
+
19
+
20
+ cs = ConfigStore.instance()
21
+ cs.store(name="config", node=Config)
22
+
23
+
24
+ @hydra.main(version_base=None, config_name="config")
25
+ def app(cfg: Config):
26
+ hydraflow.set_experiment()
27
+
28
+ if run := hydraflow.list_runs().try_find(cfg, override=True):
29
+ run_id = run.info.run_id
30
+ else:
31
+ run_id = None
32
+
33
+ with hydraflow.start_run(cfg, run_id=run_id) as run:
34
+ log(hydraflow.get_artifact_dir(run))
35
+
36
+
37
+ def log(path: Path):
38
+ file = path / "a.txt"
39
+ text = file.read_text() if file.exists() else ""
40
+ file.write_text(text + "a")
41
+
42
+
43
+ if __name__ == "__main__":
44
+ app()
@@ -4,6 +4,8 @@ from mlflow.entities import Run
4
4
  from hydraflow.run_collection import RunCollection
5
5
  from hydraflow.utils import get_artifact_path, get_hydra_output_dir
6
6
 
7
+ pytestmark = pytest.mark.xdist_group(name="group2")
8
+
7
9
 
8
10
  @pytest.fixture(scope="module")
9
11
  def rc(collect):
@@ -0,0 +1,41 @@
1
+ import pytest
2
+ from mlflow.entities import Run, RunStatus
3
+ from mlflow.tracking import MlflowClient
4
+
5
+ from hydraflow.run_collection import RunCollection
6
+
7
+ pytestmark = pytest.mark.xdist_group(name="group4")
8
+
9
+
10
+ @pytest.fixture(scope="module")
11
+ def rc(collect):
12
+ client = MlflowClient()
13
+ running = RunStatus.to_string(RunStatus.RUNNING)
14
+
15
+ filename = "context/preemption.py"
16
+ args = ["-m", "count=1,2,3"]
17
+
18
+ rc = collect(filename, args)
19
+ client.set_terminated(rc.get(count=2).info.run_id, status=running)
20
+ client.set_terminated(rc.get(count=3).info.run_id, status=running)
21
+ rc = collect(filename, args)
22
+ client.set_terminated(rc.get(count=3).info.run_id, status=running)
23
+ return collect(filename, args)
24
+
25
+
26
+ def test_rc_len(rc: RunCollection):
27
+ assert len(rc) == 3
28
+
29
+
30
+ @pytest.fixture(scope="module", params=[1, 2, 3])
31
+ def run(rc: RunCollection, request: pytest.FixtureRequest):
32
+ return rc.get(count=request.param)
33
+
34
+
35
+ def test_run_count(run: Run):
36
+ from hydraflow.utils import get_artifact_path
37
+
38
+ count = int(run.data.params["count"])
39
+ path = get_artifact_path(run, "a.txt")
40
+ text = path.read_text()
41
+ assert len(text) == count
@@ -0,0 +1,31 @@
1
+ import pytest
2
+ from mlflow.entities import Run
3
+
4
+ from hydraflow.run_collection import RunCollection
5
+
6
+ pytestmark = pytest.mark.xdist_group(name="group3")
7
+
8
+
9
+ @pytest.fixture(scope="module")
10
+ def rc(collect):
11
+ collect("context/rerun.py", ["-m", "count=1,2,3"])
12
+ collect("context/rerun.py", ["-m", "count=2,3"])
13
+ return collect("context/rerun.py", ["-m", "count=3"])
14
+
15
+
16
+ def test_rc_len(rc: RunCollection):
17
+ assert len(rc) == 3
18
+
19
+
20
+ @pytest.fixture(scope="module", params=[1, 2, 3])
21
+ def run(rc: RunCollection, request: pytest.FixtureRequest):
22
+ return rc.get(count=request.param)
23
+
24
+
25
+ def test_run_count(run: Run):
26
+ from hydraflow.utils import get_artifact_path
27
+
28
+ count = int(run.data.params["count"])
29
+ path = get_artifact_path(run, "a.txt")
30
+ text = path.read_text()
31
+ assert len(text) == count
@@ -5,6 +5,8 @@ import pytest
5
5
 
6
6
  from hydraflow.param import match
7
7
 
8
+ pytestmark = pytest.mark.xdist_group(name="group2")
9
+
8
10
 
9
11
  @pytest.fixture
10
12
  def param(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
@@ -51,6 +53,14 @@ def test_param(param, x, y):
51
53
  assert match(p, x)
52
54
 
53
55
 
56
+ @pytest.mark.parametrize(
57
+ ("param", "value"),
58
+ [("1.0", lambda x: float(x) > 0), ("-1.0", lambda x: float(x) < 0)],
59
+ )
60
+ def test_match_callable(param, value):
61
+ assert match(param, value)
62
+
63
+
54
64
  @pytest.mark.parametrize(
55
65
  ("param", "value"),
56
66
  [("1.0", 1.0), ("1.0", 1), ("0.0", 0), ("0.0", 0.0)],
@@ -3,11 +3,13 @@ from mlflow.entities import Run
3
3
 
4
4
  from hydraflow.run_collection import RunCollection
5
5
 
6
+ pytestmark = pytest.mark.xdist_group(name="group2")
7
+
6
8
 
7
9
  @pytest.fixture(scope="module")
8
10
  def rc(collect):
9
11
  args = ["host=a"]
10
- return collect("param/param.py", args)
12
+ return collect("param/params.py", args)
11
13
 
12
14
 
13
15
  @pytest.fixture(scope="module")
@@ -7,6 +7,8 @@ from mlflow.entities import Experiment, Run, RunStatus
7
7
  from hydraflow.mlflow import list_runs
8
8
  from hydraflow.run_collection import RunCollection, filter_runs
9
9
 
10
+ pytestmark = pytest.mark.xdist_group(name="group3")
11
+
10
12
 
11
13
  @pytest.fixture(scope="module")
12
14
  def experiment(experiment_name: str):
@@ -4,6 +4,8 @@ from mlflow.entities import Experiment
4
4
 
5
5
  from hydraflow.run_collection import RunCollection
6
6
 
7
+ pytestmark = pytest.mark.xdist_group(name="group3")
8
+
7
9
 
8
10
  @pytest.fixture(scope="module")
9
11
  def experiment(experiment_name: str):
@@ -2,6 +2,8 @@ import pytest
2
2
 
3
3
  from hydraflow.run_collection import RunCollection
4
4
 
5
+ pytestmark = pytest.mark.xdist_group(name="group7")
6
+
5
7
 
6
8
  @pytest.fixture(scope="module")
7
9
  def rc(collect):
@@ -4,6 +4,8 @@ from mlflow.entities import Experiment
4
4
 
5
5
  from hydraflow.run_collection import RunCollection
6
6
 
7
+ pytestmark = pytest.mark.xdist_group(name="group4")
8
+
7
9
 
8
10
  @pytest.fixture(scope="module")
9
11
  def experiment(experiment_name: str):
@@ -7,6 +7,8 @@ from hydraflow.run_collection import RunCollection
7
7
  if TYPE_CHECKING:
8
8
  from .run import Config
9
9
 
10
+ pytestmark = pytest.mark.xdist_group(name="group4")
11
+
10
12
 
11
13
  @pytest.fixture(scope="module")
12
14
  def rc(collect):
@@ -4,12 +4,14 @@ import mlflow
4
4
  import pytest
5
5
  from mlflow.entities import Experiment, Run, RunStatus
6
6
 
7
+ pytestmark = pytest.mark.xdist_group(name="group0")
8
+
7
9
 
8
10
  @pytest.fixture(scope="module")
9
11
  def experiment(experiment_name: str):
10
12
  from hydraflow.mlflow import log_params, set_experiment
11
13
 
12
- experiment = set_experiment(uri="mlruns", name="e")
14
+ experiment = set_experiment(uri="test_mlflow", name="e")
13
15
 
14
16
  with mlflow.start_run():
15
17
  log_params({"name": experiment_name})
@@ -20,13 +22,11 @@ def experiment(experiment_name: str):
20
22
  mlflow.start_run()
21
23
  mlflow.end_run(status=RunStatus.to_string(RunStatus.FAILED))
22
24
 
23
- yield experiment
24
-
25
- mlflow.set_tracking_uri("")
25
+ return experiment
26
26
 
27
27
 
28
28
  def test_set_experiment_uri(experiment: Experiment):
29
- assert mlflow.get_tracking_uri() == "mlruns"
29
+ assert mlflow.get_tracking_uri() == "test_mlflow"
30
30
 
31
31
 
32
32
  def test_set_experiment_location(experiment: Experiment):
@@ -35,7 +35,7 @@ def test_set_experiment_location(experiment: Experiment):
35
35
  if loc.startswith("file:"): # for windows
36
36
  loc = loc[loc.index("C:") :]
37
37
 
38
- path = Path.cwd() / "mlruns" / experiment.experiment_id
38
+ path = Path.cwd() / "test_mlflow" / experiment.experiment_id
39
39
  assert path == Path(loc)
40
40
 
41
41
 
@@ -4,6 +4,8 @@ from mlflow.entities import Experiment, Run
4
4
 
5
5
  from hydraflow.run_collection import RunCollection
6
6
 
7
+ pytestmark = pytest.mark.xdist_group(name="group5")
8
+
7
9
 
8
10
  @pytest.fixture(scope="module")
9
11
  def experiment(experiment_name: str):
@@ -28,6 +30,7 @@ def run(rc: RunCollection):
28
30
  return rc.first()
29
31
 
30
32
 
33
+ @pytest.mark.order(0)
31
34
  def test_hydra_output_dir(run: Run):
32
35
  from hydraflow.utils import get_hydra_output_dir
33
36
 
@@ -35,6 +38,7 @@ def test_hydra_output_dir(run: Run):
35
38
  get_hydra_output_dir(run)
36
39
 
37
40
 
41
+ @pytest.mark.order(1)
38
42
  def test_remove_run(rc: RunCollection):
39
43
  from hydraflow.utils import get_artifact_dir, remove_run
40
44
 
@@ -8,13 +8,20 @@ from hydraflow.run_collection import RunCollection
8
8
  if TYPE_CHECKING:
9
9
  from .utils import Config
10
10
 
11
+ pytestmark = pytest.mark.xdist_group(name="group6")
12
+
11
13
 
12
14
  @pytest.fixture(scope="module")
13
15
  def rc(collect):
14
16
  args = ["-m", "name=a,b", "age=10"]
17
+
15
18
  return collect("utils/utils.py", args)
16
19
 
17
20
 
21
+ def test_rc_len(rc: RunCollection):
22
+ assert len(rc) == 2
23
+
24
+
18
25
  @pytest.fixture(scope="module")
19
26
  def run(rc: RunCollection):
20
27
  return rc.first()
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes