flyteplugins-hydra 2.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,35 @@
1
+ """flyteplugins-hydra — Hydra launcher plugin for Flyte.
2
+
3
+ Provides three entry points for running Flyte tasks via Hydra:
4
+
5
+ 1. **``@hydra.main`` + ``--multirun``** (standard Hydra CLI pattern):
6
+
7
+ .. code-block:: bash
8
+
9
+ python train.py hydra/launcher=flyte hydra.launcher.mode=remote
10
+ python train.py --multirun hydra/launcher=flyte hydra.launcher.mode=remote \\
11
+ optimizer.lr=0.001,0.01,0.1
12
+
13
+ 2. **``flyte hydra run``** (Flyte CLI extension, no ``@hydra.main`` required):
14
+
15
+ .. code-block:: bash
16
+
17
+ flyte hydra run --config-path conf --config-name training --mode remote \\
18
+ train.py pipeline --cfg optimizer.lr=0.01
19
+
20
+ 3. **``hydra_run`` / ``hydra_sweep``** (Python SDK):
21
+
22
+ .. code-block:: python
23
+
24
+ from flyteplugins.hydra import hydra_run, hydra_sweep
25
+
26
+ hydra_run(pipeline, config_path="conf", config_name="training",
27
+ overrides=["optimizer.lr=0.01"], mode="remote")
28
+
29
+ runs = hydra_sweep(pipeline, config_path="conf", config_name="training",
30
+ overrides=["optimizer.lr=0.001,0.01,0.1"], mode="remote")
31
+ """
32
+
33
+ from flyteplugins.hydra._run import apply_task_env, hydra_run, hydra_sweep
34
+
35
+ __all__ = ["apply_task_env", "hydra_run", "hydra_sweep"]
@@ -0,0 +1,568 @@
1
+ """``flyte hydra`` CLI command group.
2
+
3
+ Registered via the ``flyte.plugins.cli.commands`` entry point so that
4
+ ``flyte hydra run`` is available once ``flyteplugins-hydra`` is installed.
5
+
6
+ Inherits the standard ``flyte run`` flags that apply to script execution
7
+ (``--project``, ``--domain``, ``--local``, ``--image``, ``--follow``, etc.).
8
+ Hydra-specific options are:
9
+ ``--config-path``, ``--config-name``, ``--mode``, ``--multirun``,
10
+ ``--wait/--no-wait``, ``--wait-max-workers``, ``--task-env-key``,
11
+ ``--hydra-override``.
12
+ Application config overrides use the task's ``DictConfig`` parameter name,
13
+ for example ``--cfg`` for ``cfg: DictConfig`` or ``--config`` for
14
+ ``config: DictConfig``.
15
+
16
+ Usage
17
+ -----
18
+ Single run (remote by default)::
19
+
20
+ flyte hydra run --config-path conf --config-name training \\
21
+ train.py pipeline \\
22
+ --cfg optimizer.lr=0.01
23
+
24
+ Single run forced local::
25
+
26
+ flyte hydra run --local --config-path conf --config-name training \\
27
+ train.py pipeline
28
+
29
+ Grid sweep (six parallel remote executions)::
30
+
31
+ flyte hydra run --multirun --config-path conf --config-name training \\
32
+ train.py pipeline \\
33
+ --cfg "optimizer.lr=0.001,0.01,0.1" --cfg "training.epochs=10,20"
34
+
35
+ TPE/Bayesian sweep via Optuna sweeper::
36
+
37
+ flyte hydra run --multirun --config-path conf --config-name training \\
38
+ train.py pipeline \\
39
+ --hydra-override hydra/sweeper=optuna \\
40
+ --hydra-override hydra.sweeper.n_trials=20 \\
41
+ --hydra-override hydra.sweeper.n_jobs=4 \\
42
+ --cfg "optimizer.lr=interval(1e-4,1e-1)"
43
+ """
44
+
45
+ from __future__ import annotations
46
+
47
+ import importlib.util
48
+ import inspect
49
+ import sys
50
+ from pathlib import Path
51
+
52
+ import rich_click as click
53
+ from click.shell_completion import CompletionItem
54
+
55
+ _HYDRA_OVERRIDE_OPTION = "--hydra-override"
56
+
57
+
58
def _follow_run_logs(run) -> None:
    """Print recent logs for ``run`` when it is able to show them.

    The object is probed for a ``show_logs`` attribute. When the attribute is
    missing or ``None`` (for example a result that is not a remote run) the
    call is a silent no-op; otherwise it is invoked with a fixed window of
    30 lines, timestamps enabled and non-raw output.
    """
    log_viewer = getattr(run, "show_logs", None)
    if log_viewer is None:
        # Nothing to follow for objects without remote log support.
        return
    log_viewer(max_lines=30, show_ts=True, raw=False)
69
+
70
+
71
def _completed_result_value(run):
    """Pick the printable result out of ``run``, or ``None`` if there is none.

    * An object with a ``value`` attribute yields that attribute.
    * An object without ``value`` and without a non-``None`` ``url`` is
      returned unchanged (it is treated as the result itself).
    * An object that only carries a ``url`` yields ``None`` so the caller
      does not print the run URL a second time.
    """
    if not hasattr(run, "value"):
        has_url = getattr(run, "url", None) is not None
        return None if has_url else run
    return run.value
78
+
79
+
80
def _load_script_task(script: str, task_name: str):
    """Import ``script`` as a module and return its attribute ``task_name``.

    The script's directory is made importable (added to ``sys.path`` once, so
    repeated loads do not pile up duplicate entries) and the module is
    registered in ``sys.modules`` under the file's stem before it is executed.

    If executing the script raises, the ``sys.modules`` entry is rolled back
    (restoring whatever module was registered under that name before) so no
    half-initialised module is left behind, and the original exception is
    re-raised unchanged.

    Args:
        script: Path to the Python file to load.
        task_name: Name of the module-level attribute to return.

    Returns:
        The object bound to ``task_name`` in the loaded module.

    Raises:
        click.ClickException: If no import spec/loader can be built for the
            file, or the module has no (non-``None``) attribute ``task_name``.
    """
    script_path = Path(script).resolve()
    module_name = script_path.stem

    script_dir = str(script_path.parent)
    if script_dir not in sys.path:
        sys.path.append(script_dir)

    spec = importlib.util.spec_from_file_location(module_name, script_path)
    if spec is None or spec.loader is None:
        raise click.ClickException(f"Could not load module from {script}")

    mod = importlib.util.module_from_spec(spec)

    missing = object()
    previous = sys.modules.get(module_name, missing)
    # Register before executing so the script can resolve itself by name.
    sys.modules[module_name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Mirror importlib's own behaviour: never leave a partially executed
        # module importable. Only undo the entry if it is still ours.
        if sys.modules.get(module_name) is mod:
            if previous is missing:
                del sys.modules[module_name]
            else:
                sys.modules[module_name] = previous
        raise

    task = getattr(mod, task_name, None)
    if task is None:
        raise click.ClickException(f"Task '{task_name}' not found in {script}")
    return task
98
+
99
+
100
def _script_task_and_tail(ctx: click.Context) -> tuple[str | None, str | None, list[str]]:
    """Work out SCRIPT, TASK_NAME and the leftover argument tail for ``ctx``.

    Preference order:

    1. If both ``script`` and ``task_name`` are already present (and truthy)
       in ``ctx.params``, use them and treat all of ``ctx.args`` as the tail.
    2. Otherwise, if ``ctx.args`` has at least two entries, take the first two
       as SCRIPT and TASK_NAME and the rest as the tail.
    3. Otherwise return ``(None, None, tail)``.

    The tail is always a fresh list copy of ``ctx.args`` (or a slice of it).
    """
    parsed_script = ctx.params.get("script")
    parsed_task = ctx.params.get("task_name")
    tail = list(ctx.args)

    if parsed_script and parsed_task:
        return parsed_script, parsed_task, tail
    if len(tail) < 2:
        return None, None, tail

    positional_script, positional_task, *rest = tail
    return positional_script, positional_task, rest
110
+
111
+
112
def _extract_config_overrides(task, args: list[str]) -> tuple[list[str], list[str]]:
    """Separate DictConfig override flags from ordinary task arguments.

    Every name reported by ``_config_param_names(task)`` defines a flag
    ``--<name>``. Both spellings are recognised in ``args``:

    * ``--<name> VALUE``  (the following token is consumed as the value)
    * ``--<name>=VALUE``

    Returns:
        ``(overrides, remaining)`` where ``overrides`` holds the extracted
        values in order of appearance and ``remaining`` holds every other
        token, also in order. If the task has no config parameters the result
        is ``([], args)`` with ``args`` passed through as-is.

    Raises:
        click.UsageError: If a ``--<name>`` flag is the last token and so has
            no value after it.
    """
    from flyteplugins.hydra._run import _config_param_names

    config_flags = {f"--{name}" for name in set(_config_param_names(task))}
    if not config_flags:
        return [], args

    hydra_values: list[str] = []
    passthrough: list[str] = []

    tokens = iter(args)
    for token in tokens:
        # Flags are built from parameter names, so they never contain "=";
        # splitting on the first "=" cleanly isolates the flag part.
        flag, equals, inline_value = token.partition("=")
        if flag not in config_flags:
            passthrough.append(token)
        elif equals:
            hydra_values.append(inline_value)
        else:
            following = next(tokens, None)
            if following is None:
                raise click.UsageError(f"Option '{flag}' requires an override value.")
            hydra_values.append(following)

    return hydra_values, passthrough
160
+
161
+
162
def _override_completion_context(
    args: list[str],
    incomplete: str,
    override_options: set[str],
) -> tuple[list[str], str, str] | None:
    """Decide whether the word under the cursor is a Hydra override value.

    Two situations count as "completing an override":

    * ``incomplete`` itself starts with ``<option>=``: the part after the
      ``=`` is the override being typed and ``<option>=`` must be kept as a
      prefix on every suggestion.
    * The last complete argument is one of ``override_options``: the whole of
      ``incomplete`` is the override being typed and no prefix is needed.

    Returns:
        ``(previous_overrides, current_override, replacement_prefix)`` for
        those situations, otherwise ``None``.
    """
    inline_prefix = next(
        (
            candidate
            for candidate in (f"{flag}=" for flag in override_options)
            if incomplete.startswith(candidate)
        ),
        None,
    )
    if inline_prefix is not None:
        return (
            _collect_complete_overrides(args, override_options),
            incomplete[len(inline_prefix) :],
            inline_prefix,
        )

    if args and args[-1] in override_options:
        # The trailing flag is still waiting for its value; exclude it from
        # the list of already-complete overrides.
        return _collect_complete_overrides(args[:-1], override_options), incomplete, ""

    return None
184
+
185
+
186
def _collect_complete_overrides(args: list[str], override_options: set[str]) -> list[str]:
    """Gather the override values already present in a task-argument tail.

    A value is picked up either as the token immediately following one of
    ``override_options`` or as the text after ``<option>=`` in a single token.
    A trailing option with no token after it contributes nothing. Tokens that
    match neither form are ignored.
    """
    collected: list[str] = []
    no_more = object()

    tokens = iter(args)
    for token in tokens:
        if token in override_options:
            # Space-separated form: the very next token is the value.
            value = next(tokens, no_more)
            if value is not no_more:
                collected.append(value)
            continue

        inline_prefix = next(
            (
                candidate
                for candidate in (f"{flag}=" for flag in override_options)
                if token.startswith(candidate)
            ),
            None,
        )
        if inline_prefix is not None:
            collected.append(token[len(inline_prefix) :])

    return collected
208
+
209
+
210
def _complete_hydra_override_values(
    *,
    config_path: str | None,
    config_name: str,
    multirun: bool,
    previous_overrides: list[str],
    incomplete: str,
) -> list[str]:
    """Query Hydra's completion plugin for suggestions for one override.

    A command line is synthesised from ``--multirun`` (when ``multirun`` is
    set), the overrides typed so far and the word being completed, joined by
    single spaces. When ``incomplete`` is empty and the line is not, a single
    trailing space is appended (presumably so Hydra treats the cursor as
    starting a new word rather than extending the last one).

    The config loader is obtained from ``_hydra_init(config_path)`` and handed
    to ``DefaultCompletionPlugin``; its ``_query`` result is returned as-is.
    """
    from flyteplugins.hydra._run import _hydra_init
    from hydra.plugins.completion_plugin import DefaultCompletionPlugin

    words = ["--multirun", *previous_overrides] if multirun else list(previous_overrides)
    words.append(incomplete)

    line = " ".join(words).strip()
    if line and not incomplete:
        line = f"{line} "

    with _hydra_init(config_path) as config_loader:
        return DefaultCompletionPlugin(config_loader)._query(config_name=config_name, line=line)
231
+
232
+
233
def _hydra_override_option_complete(ctx: click.Context, _param, incomplete: str) -> list[CompletionItem]:
    """Shell-completion callback for values of ``--hydra-override``.

    Completion is best effort: if the script/task cannot be loaded the
    task-specific config flags are simply treated as unknown, and if Hydra's
    completion query fails an empty suggestion list is returned.
    """
    script, task_name, task_tail = _script_task_and_tail(ctx)

    app_flags: set[str] = set()
    if script and task_name:
        try:
            task = _load_script_task(script, task_name)
            from flyteplugins.hydra._run import _config_param_names

            app_flags = {f"--{name}" for name in _config_param_names(task)}
        except Exception:
            # Any failure while inspecting the task just disables the
            # task-aware part of the completion.
            app_flags = set()

    # Overrides typed so far: declared --hydra-override values first, then
    # the ones found in the free-form task tail.
    known_overrides = [
        *(ctx.params.get("hydra_overrides") or ()),
        *_collect_complete_overrides(task_tail, app_flags),
    ]

    try:
        suggestions = _complete_hydra_override_values(
            config_path=ctx.params.get("config_path"),
            config_name=ctx.params.get("config_name") or "config",
            multirun=bool(ctx.params.get("multirun")),
            previous_overrides=known_overrides,
            incomplete=incomplete,
        )
    except Exception:
        return []

    completions = []
    for suggestion in suggestions:
        completions.append(CompletionItem(suggestion))
    return completions
261
+
262
+
263
class HydraRunCommand(click.RichCommand):
    """``RichCommand`` whose shell completion also understands Hydra overrides."""

    def shell_complete(self, ctx: click.Context, incomplete: str) -> list[CompletionItem]:
        """Extend the default completions with task-specific Hydra ones.

        On top of what the base class suggests this adds:

        * the task's ``--<DictConfig param>`` flags, when a flag is being
          typed (``incomplete`` starts with ``-``) and the flag is not
          already among the base suggestions;
        * Hydra override values, when the cursor sits on the value of one of
          those flags or of ``--hydra-override``.

        If the script/task is unknown or cannot be loaded, or Hydra's
        completion query fails, whatever has been gathered so far is returned.
        """
        items = super().shell_complete(ctx, incomplete)

        script, task_name, task_tail = _script_task_and_tail(ctx)
        if not (script and task_name):
            return items

        try:
            task = _load_script_task(script, task_name)
        except Exception:
            return items

        from flyteplugins.hydra._run import _config_param_names

        config_options = {f"--{name}" for name in _config_param_names(task)}
        override_options = {*config_options, _HYDRA_OVERRIDE_OPTION}

        if incomplete.startswith("-"):
            already_offered = {item.value for item in items}
            for flag in sorted(config_options):
                if flag.startswith(incomplete) and flag not in already_offered:
                    items.append(CompletionItem(flag, help="Hydra app-level override"))

        context = _override_completion_context(task_tail, incomplete, override_options)
        if context is None:
            return items
        earlier_overrides, partial_value, keep_prefix = context

        try:
            suggestions = _complete_hydra_override_values(
                config_path=ctx.params.get("config_path"),
                config_name=ctx.params.get("config_name") or "config",
                multirun=bool(ctx.params.get("multirun")),
                previous_overrides=earlier_overrides,
                incomplete=partial_value,
            )
        except Exception:
            return items

        # Re-attach "<option>=" so the shell replaces the whole current word.
        for suggestion in suggestions:
            items.append(CompletionItem(f"{keep_prefix}{suggestion}"))
        return items
308
+
309
+
310
def _parse_task_kwargs(task, args: list[str], parent_ctx: click.Context) -> dict:
    """Parse the non-Hydra task flags in ``args`` into keyword arguments.

    A throw-away Click command is assembled with one option per task input,
    each produced by Flyte's ``to_click_option`` from the input's typed
    variable, native type and default (``None`` when the input declares no
    default). Inputs named by ``_config_param_names(task)`` are left out;
    per the surrounding module they are supplied through Hydra instead.

    The command is run in non-standalone mode with ``parent_ctx.obj`` as its
    context object, and the value returned by its ``main`` is passed back to
    the caller. Returns ``{}`` when the task has no typed interface.
    """
    from flyte._internal.runtime.types_serde import transform_native_to_typed_interface
    from flyte.cli._params import to_click_option

    from flyteplugins.hydra._run import _config_param_names

    typed_interface = transform_native_to_typed_interface(task.native_interface)
    if typed_interface is None:
        return {}

    native_inputs = task.native_interface.inputs
    hydra_managed = set(_config_param_names(task))

    options = []
    for entry in typed_interface.inputs.variables:
        input_name, variable = entry.key, entry.value
        if input_name in hydra_managed:
            continue

        declared_default = native_inputs[input_name][1]
        default = None if declared_default is inspect.Parameter.empty else declared_default
        options.append(to_click_option(input_name, variable, native_inputs[input_name][0], default))

    # The callback simply hands back whatever Click parsed and converted, so
    # type conversion and required/default validation are Click's job.
    task_arg_parser = click.Command(
        name="task-args",
        params=options,
        callback=lambda **kwargs: kwargs,
    )
    return task_arg_parser.main(
        args=args,
        prog_name="task arguments",
        standalone_mode=False,
        obj=parent_ctx.obj,
    )
360
+
361
+
362
# Root of the ``hydra`` command group. Sub-commands (``run`` below) attach to
# it via ``@hydra_group.command``. The function body is intentionally empty:
# the docstring is the only content and is left untouched because it doubles
# as the group's user-facing description.
@click.group(name="hydra")
def hydra_group() -> None:
    """Run Flyte tasks via Hydra config composition and sweeping."""
365
+
366
+
367
# ``flyte hydra run`` command.
#
# ``ignore_unknown_options`` / ``allow_extra_args`` keep Click from rejecting
# flags it does not declare; those leftovers land in ``ctx.args`` and are
# split further down into DictConfig override flags and ordinary task inputs.
# Completion for the command is provided by ``HydraRunCommand``.
@hydra_group.command(
    name="run",
    cls=HydraRunCommand,
    context_settings={"ignore_unknown_options": True, "allow_extra_args": True},
)
@click.argument("script", type=click.Path(exists=True, dir_okay=False))
@click.argument("task_name")
@click.option("--config-path", default=None, help="Path to Hydra config directory.")
@click.option(
    "--config-name",
    default="config",
    show_default=True,
    help="Top-level config file name (without .yaml).",
)
@click.option(
    "--mode",
    type=click.Choice(["local", "remote"]),
    default=None,
    help="Execution mode. Defaults to remote for Flyte CLI parity.",
)
@click.option(
    "--multirun",
    is_flag=True,
    default=False,
    help="Expand sweep overrides into a grid of executions.",
)
@click.option(
    "--wait/--no-wait",
    default=True,
    show_default=True,
    help="Wait for remote Flyte runs to reach a terminal phase.",
)
@click.option(
    "--wait-max-workers",
    type=click.IntRange(min=1),
    default=32,
    show_default=True,
    help="Maximum worker threads used while waiting for remote Flyte runs.",
)
@click.option(
    "--task-env-key",
    default="task_env",
    show_default=True,
    help="Config key containing entry-task task.override kwargs by task name.",
)
@click.option(
    "--hydra-override",
    "hydra_overrides",
    multiple=True,
    metavar="KEY=VALUE",
    shell_complete=_hydra_override_option_complete,
    help=("Hydra-namespace override (repeatable), e.g. hydra/sweeper=optuna or hydra.sweeper.n_trials=20."),
)
@click.pass_context
def hydra_run_cmd(
    ctx: click.Context,
    script: str,
    task_name: str,
    config_path: str | None,
    config_name: str,
    mode: str | None,
    multirun: bool,
    wait: bool,
    wait_max_workers: int,
    task_env_key: str,
    hydra_overrides: tuple[str, ...],
    **run_params,
) -> None:
    """Compose a Hydra config and run TASK_NAME from SCRIPT on Flyte.

    SCRIPT is the path to a Python file containing the Flyte task.
    TASK_NAME is the name of the task function to run.

    Use the task's ``DictConfig`` parameter name for app-level overrides
    (for example ``--cfg`` or ``--config``).
    Use ``--hydra-override`` for hydra-namespace settings (hydra/sweeper=optuna).
    """
    from flyte.cli._common import initialize_config
    from flyte.cli._run import RunArguments

    # ``run_params`` receives every option that is not named explicitly in the
    # signature, i.e. the ``flyte run`` options appended to this command by
    # ``_extend_with_run_options`` at import time.
    run_args = RunArguments.from_dict(run_params)

    # Initialise Flyte.
    config = initialize_config(
        ctx,
        run_args.project,
        run_args.domain,
        root_dir=run_args.root_dir,
        images=run_args.image or None,
        sync_local_sys_paths=not run_args.no_sync_local_sys_paths,
    )
    ctx.obj = config.replace(run_args=run_args)

    # Reject contradictory flag combinations before doing any work.
    if run_args.local and mode == "remote":
        raise click.UsageError("Use either --local or --mode remote, not both.")
    if run_args.follow and not wait:
        raise click.UsageError("Use either --follow or --no-wait, not both.")

    # --local wins; otherwise honour --mode, falling back to remote.
    execution_mode = "local" if run_args.local else (mode or "remote")

    # Only forward options that are accepted by flyte.with_runcontext and are
    # meaningful for an in-process script task. Project/domain/image/root-dir
    # are handled by initialize_config above; follow is handled after launch.
    run_options: dict = {
        "log_format": config.log_format,
        "reset_root_logger": config.reset_root_logger,
    }
    # Optional settings are forwarded only when they differ from their
    # "unset" value (None, "loaded_modules", or a falsy debug flag).
    if run_args.service_account is not None:
        run_options["service_account"] = run_args.service_account
    if run_args.name is not None:
        run_options["name"] = run_args.name
    if run_args.raw_data_path is not None:
        run_options["raw_data_path"] = run_args.raw_data_path
    if run_args.copy_style != "loaded_modules":
        run_options["copy_style"] = run_args.copy_style
    if run_args.debug:
        run_options["debug"] = True

    # Load the script as a module so Flyte task decorators have run before we
    # inspect the requested task's typed interface.
    task = _load_script_task(script, task_name)

    # ctx.args contains everything after SCRIPT TASK_NAME that was not consumed
    # by the fixed Hydra/Flyte options. First pull out DictConfig override
    # aliases, then parse the rest as ordinary task inputs.
    param_cfg_overrides, task_args = _extract_config_overrides(task, list(ctx.args))
    task_kwargs = _parse_task_kwargs(task, task_args, ctx)

    # Combine all overrides — hydra_sweep / hydra_run handle separation
    # of hydra-namespace vs app-level overrides internally.
    all_overrides = param_cfg_overrides + list(hydra_overrides)

    # NOTE(review): ``run_options`` always contains the two logging keys set
    # above, so ``run_options or None`` below never evaluates to ``None``.
    if multirun:
        from flyteplugins.hydra._run import hydra_sweep

        runs = hydra_sweep(
            task,
            config_path=config_path,
            config_name=config_name,
            overrides=all_overrides,
            mode=execution_mode,
            wait=wait,
            wait_max_workers=wait_max_workers,
            run_options=run_options or None,
            task_env_key=task_env_key,
            **task_kwargs,
        )
        # NOTE(review): the source this was recovered from had lost its
        # indentation; the follow step is assumed to run once per sweep
        # member (inside the loop), mirroring the single-run branch — confirm.
        for i, run in enumerate(runs):
            value = _completed_result_value(run)
            if value is not None:
                click.echo(f"[{i}] result={value}")
            if run_args.follow and execution_mode == "remote":
                _follow_run_logs(run)
    else:
        from flyteplugins.hydra._run import hydra_run

        run = hydra_run(
            task,
            config_path=config_path,
            config_name=config_name,
            overrides=all_overrides,
            mode=execution_mode,
            wait=wait,
            wait_max_workers=wait_max_workers,
            run_options=run_options or None,
            task_env_key=task_env_key,
            **task_kwargs,
        )
        # Echo a result only when one is available; objects that merely
        # carry a run URL yield None here and are not printed.
        value = _completed_result_value(run)
        if value is not None:
            click.echo(value)
        if run_args.follow and execution_mode == "remote":
            _follow_run_logs(run)
540
+
541
+
542
+ # Dynamically inherit all standard ``flyte run`` options.
543
+ # Reuses RunArguments.options() so that new options added to ``flyte run``
544
+ # are automatically available on ``flyte hydra run`` without duplication.
545
+ # If a future ``flyte run`` option collides with a hydra-specific flag, the
546
+ # import fails immediately with a clear error rather than silently breaking.
547
def _extend_with_run_options() -> None:
    """Append the options from ``RunArguments.options()`` to ``hydra_run_cmd``.

    Options named ``run_project``, ``run_domain`` and ``tui`` are skipped.
    Names are checked against the parameters ``hydra_run_cmd`` declares at
    the time of the call.

    Raises:
        RuntimeError: If an inherited option shares its name with one of the
            command's own parameters. Options processed before the clash
            have already been appended at that point.
    """
    from flyte.cli._run import RunArguments

    skipped = {"run_project", "run_domain", "tui"}
    own_names = {param.name for param in hydra_run_cmd.params}

    for inherited in RunArguments.options():
        option_name = inherited.name
        if option_name in skipped:
            continue
        if option_name not in own_names:
            hydra_run_cmd.params.append(inherited)
            continue
        raise RuntimeError(
            f"flyte run option '{option_name}' conflicts with a hydra-specific "
            "option on 'flyte hydra run'. The flyteplugins-hydra plugin "
            "needs to be updated to resolve this collision."
        )
566
+
567
+
568
+ _extend_with_run_options()