flyteplugins-hydra 2.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flyteplugins/hydra/__init__.py +35 -0
- flyteplugins/hydra/_cli.py +568 -0
- flyteplugins/hydra/_launcher.py +228 -0
- flyteplugins/hydra/_run.py +587 -0
- flyteplugins_hydra-2.1.9.dist-info/METADATA +504 -0
- flyteplugins_hydra-2.1.9.dist-info/RECORD +10 -0
- flyteplugins_hydra-2.1.9.dist-info/WHEEL +5 -0
- flyteplugins_hydra-2.1.9.dist-info/entry_points.txt +2 -0
- flyteplugins_hydra-2.1.9.dist-info/top_level.txt +2 -0
- hydra_plugins/hydra_flyte_launcher/__init__.py +37 -0
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""flyteplugins-hydra — Hydra launcher plugin for Flyte.
|
|
2
|
+
|
|
3
|
+
Provides three entry points for running Flyte tasks via Hydra:
|
|
4
|
+
|
|
5
|
+
1. **``@hydra.main`` + ``--multirun``** (standard Hydra CLI pattern):
|
|
6
|
+
|
|
7
|
+
.. code-block:: bash
|
|
8
|
+
|
|
9
|
+
python train.py hydra/launcher=flyte hydra.launcher.mode=remote
|
|
10
|
+
python train.py --multirun hydra/launcher=flyte hydra.launcher.mode=remote \\
|
|
11
|
+
optimizer.lr=0.001,0.01,0.1
|
|
12
|
+
|
|
13
|
+
2. **``flyte hydra run``** (Flyte CLI extension, no ``@hydra.main`` required):
|
|
14
|
+
|
|
15
|
+
.. code-block:: bash
|
|
16
|
+
|
|
17
|
+
flyte hydra run --config-path conf --config-name training --mode remote \\
|
|
18
|
+
train.py pipeline --cfg optimizer.lr=0.01
|
|
19
|
+
|
|
20
|
+
3. **``hydra_run`` / ``hydra_sweep``** (Python SDK):
|
|
21
|
+
|
|
22
|
+
.. code-block:: python
|
|
23
|
+
|
|
24
|
+
from flyteplugins.hydra import hydra_run, hydra_sweep
|
|
25
|
+
|
|
26
|
+
hydra_run(pipeline, config_path="conf", config_name="training",
|
|
27
|
+
overrides=["optimizer.lr=0.01"], mode="remote")
|
|
28
|
+
|
|
29
|
+
runs = hydra_sweep(pipeline, config_path="conf", config_name="training",
|
|
30
|
+
overrides=["optimizer.lr=0.001,0.01,0.1"], mode="remote")
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
from flyteplugins.hydra._run import apply_task_env, hydra_run, hydra_sweep
|
|
34
|
+
|
|
35
|
+
# Public API of the package: names re-exported from flyteplugins.hydra._run.
__all__ = ["apply_task_env", "hydra_run", "hydra_sweep"]
|
|
@@ -0,0 +1,568 @@
|
|
|
1
|
+
"""``flyte hydra`` CLI command group.
|
|
2
|
+
|
|
3
|
+
Registered via the ``flyte.plugins.cli.commands`` entry point so that
|
|
4
|
+
``flyte hydra run`` is available once ``flyteplugins-hydra`` is installed.
|
|
5
|
+
|
|
6
|
+
Inherits the standard ``flyte run`` flags that apply to script execution
|
|
7
|
+
(``--project``, ``--domain``, ``--local``, ``--image``, ``--follow``, etc.).
|
|
8
|
+
Hydra-specific options are:
|
|
9
|
+
``--config-path``, ``--config-name``, ``--mode``, ``--multirun``,
|
|
10
|
+
``--wait/--no-wait``, ``--wait-max-workers``, ``--task-env-key``,
|
|
11
|
+
``--hydra-override``.
|
|
12
|
+
Application config overrides use the task's ``DictConfig`` parameter name,
|
|
13
|
+
for example ``--cfg`` for ``cfg: DictConfig`` or ``--config`` for
|
|
14
|
+
``config: DictConfig``.
|
|
15
|
+
|
|
16
|
+
Usage
|
|
17
|
+
-----
|
|
18
|
+
Single run (remote by default)::
|
|
19
|
+
|
|
20
|
+
flyte hydra run --config-path conf --config-name training \\
|
|
21
|
+
train.py pipeline \\
|
|
22
|
+
--cfg optimizer.lr=0.01
|
|
23
|
+
|
|
24
|
+
Single run forced local::
|
|
25
|
+
|
|
26
|
+
flyte hydra run --local --config-path conf --config-name training \\
|
|
27
|
+
train.py pipeline
|
|
28
|
+
|
|
29
|
+
Grid sweep (six parallel remote executions)::
|
|
30
|
+
|
|
31
|
+
flyte hydra run --multirun --config-path conf --config-name training \\
|
|
32
|
+
train.py pipeline \\
|
|
33
|
+
--cfg "optimizer.lr=0.001,0.01,0.1" --cfg "training.epochs=10,20"
|
|
34
|
+
|
|
35
|
+
TPE/Bayesian sweep via Optuna sweeper::
|
|
36
|
+
|
|
37
|
+
flyte hydra run --multirun --config-path conf --config-name training \\
|
|
38
|
+
train.py pipeline \\
|
|
39
|
+
--hydra-override hydra/sweeper=optuna \\
|
|
40
|
+
--hydra-override hydra.sweeper.n_trials=20 \\
|
|
41
|
+
--hydra-override hydra.sweeper.n_jobs=4 \\
|
|
42
|
+
--cfg "optimizer.lr=interval(1e-4,1e-1)"
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
from __future__ import annotations
|
|
46
|
+
|
|
47
|
+
import importlib.util
|
|
48
|
+
import inspect
|
|
49
|
+
import sys
|
|
50
|
+
from pathlib import Path
|
|
51
|
+
|
|
52
|
+
import rich_click as click
|
|
53
|
+
from click.shell_completion import CompletionItem
|
|
54
|
+
|
|
55
|
+
# Flag of the declared ``--hydra-override`` Click option. Shell completion
# treats its values as Hydra overrides, alongside the task's DictConfig flags
# (e.g. ``--cfg``).
_HYDRA_OVERRIDE_OPTION = "--hydra-override"
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _follow_run_logs(run) -> None:
    """Print the tail of a run's logs when the run object can provide them.

    ``hydra_run`` / ``hydra_sweep`` hand back whatever ``flyte.run`` produced.
    Only objects exposing a non-``None`` ``show_logs`` attribute are acted on.
    Anything else (e.g. a local-mode result) is skipped without error.
    """
    try:
        display = run.show_logs
    except AttributeError:
        return
    if display is None:
        return
    display(max_lines=30, show_ts=True, raw=False)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _completed_result_value(run):
    """Pick a printable result from a finished run without echoing its URL.

    - An object with a ``value`` attribute yields that value.
    - An object whose ``url`` is missing or ``None`` is itself the result.
    - Anything else has a URL but no value (e.g. a remote run handle) and
      yields ``None``, so the caller prints nothing.
    """
    if hasattr(run, "value"):
        return run.value
    carries_url = getattr(run, "url", None) is not None
    return None if carries_url else run
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _load_script_task(script: str, task_name: str):
    """Import *script* as a module and return its attribute named *task_name*.

    The script's parent directory is appended to ``sys.path`` (only if it is
    not already present) so the script can import sibling modules. The module
    is registered in ``sys.modules`` under the script's stem before it runs.

    Raises:
        click.ClickException: if no import spec can be built for the file, or
            the module has no (non-``None``) attribute named *task_name*.
        Exception: anything raised while executing the script propagates
            unchanged. Before re-raising, ``sys.modules`` is restored to its
            previous state so the half-initialised module is not left behind.
    """
    script_path = Path(script).resolve()
    module_name = script_path.stem

    # Shell completion can load the same script more than once per process;
    # don't grow sys.path with duplicate entries each time.
    script_dir = str(script_path.parent)
    if script_dir not in sys.path:
        sys.path.append(script_dir)

    spec = importlib.util.spec_from_file_location(module_name, script_path)
    if spec is None or spec.loader is None:
        raise click.ClickException(f"Could not load module from {script}")

    mod = importlib.util.module_from_spec(spec)
    previous = sys.modules.get(module_name)
    sys.modules[module_name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Mirror the regular import system: a failed import must not leave a
        # partially-executed module registered under its name.
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise

    task = getattr(mod, task_name, None)
    if task is None:
        raise click.ClickException(f"Task '{task_name}' not found in {script}")
    return task
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _script_task_and_tail(ctx: click.Context) -> tuple[str | None, str | None, list[str]]:
    """Split a Click context into ``(SCRIPT, TASK_NAME, task_tail)``.

    If both ``script`` and ``task_name`` were parsed into ``ctx.params``, they
    are returned together with every leftover arg. Otherwise the first two
    leftover args stand in for them. When fewer than two leftovers exist,
    ``(None, None, leftovers)`` is returned.
    """
    leftovers = list(ctx.args)
    script, task_name = ctx.params.get("script"), ctx.params.get("task_name")
    if not (script and task_name):
        if len(leftovers) < 2:
            return None, None, leftovers
        script, task_name, *leftovers = leftovers
    return script, task_name, leftovers
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _extract_config_overrides(task, args: list[str]) -> tuple[list[str], list[str]]:
    """Pull DictConfig override flags out of the task-argument tail.

    App-level config overrides on ``flyte hydra run`` are spelled with the
    name of the task's ``DictConfig`` input: ``--cfg optimizer.lr=0.01`` for
    ``cfg: DictConfig``, ``--config optimizer.lr=0.01`` for
    ``config: DictConfig``. They appear after ``SCRIPT TASK_NAME``, mixed in
    with ordinary task arguments, so Click cannot declare them as fixed
    options. Both ``--name VALUE`` and ``--name=VALUE`` spellings are accepted.

    Returns:
        ``(overrides, remaining)``: the extracted override strings in the
        order they appeared, and every other token (order preserved) for
        normal task-input parsing. If the task has no ``DictConfig`` inputs,
        ``([], args)`` is returned with *args* itself.

    Raises:
        click.UsageError: a bare override flag is the last token, with no value.
    """
    from flyteplugins.hydra._run import _config_param_names

    param_names = set(_config_param_names(task))
    if not param_names:
        return [], args

    flags = {f"--{param}" for param in param_names}
    overrides: list[str] = []
    remaining: list[str] = []
    total = len(args)
    position = 0

    while position < total:
        token = args[position]
        for flag in flags:
            # Separate-value form: "--flag value".
            if token == flag:
                if position + 1 >= total:
                    raise click.UsageError(f"Option '{flag}' requires an override value.")
                overrides.append(args[position + 1])
                position += 2
                break
            # Inline form: "--flag=value".
            if token.startswith(f"{flag}="):
                overrides.append(token.split("=", 1)[1])
                position += 1
                break
        else:
            remaining.append(token)
            position += 1

    return overrides, remaining
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _override_completion_context(
    args: list[str],
    incomplete: str,
    override_options: set[str],
) -> tuple[list[str], str, str] | None:
    """Decide whether the word being completed is a Hydra override value.

    ``flyte hydra run`` carries Hydra overrides as values of dynamic options
    such as ``--cfg`` / ``--config`` (and ``--hydra-override``). During shell
    completion, Click supplies only the already-typed tail plus the
    incomplete word, so this function scans the tail itself.

    Returns:
        ``None`` when the incomplete word is not an override value.
        Otherwise ``(previous_overrides, current_override, replacement_prefix)``:

        * ``--opt=<partial>``: previous overrides come from all of *args*,
          current is the text after ``=``, and the prefix is ``--opt=``.
        * ``--opt <partial>`` (the last arg is an override option): previous
          overrides come from *args* minus that trailing option, current is
          *incomplete*, and the prefix is ``""``.
    """
    inline_prefix = next(
        (f"{opt}=" for opt in override_options if incomplete.startswith(f"{opt}=")),
        None,
    )
    if inline_prefix is not None:
        return (
            _collect_complete_overrides(args, override_options),
            incomplete[len(inline_prefix) :],
            inline_prefix,
        )

    awaiting_value = bool(args) and args[-1] in override_options
    if not awaiting_value:
        return None
    return _collect_complete_overrides(args[:-1], override_options), incomplete, ""
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _collect_complete_overrides(args: list[str], override_options: set[str]) -> list[str]:
    """Gather finished Hydra override values from a task-argument tail.

    Two spellings are recognised:

    * ``--opt VALUE``: the token after a bare option is taken verbatim, even
      if it looks like another option.
    * ``--opt=VALUE``.

    A bare option at the very end, with no value yet, contributes nothing.
    Every other token is ignored.
    """
    collected: list[str] = []
    expecting_value = False
    for token in args:
        if expecting_value:
            collected.append(token)
            expecting_value = False
        elif token in override_options:
            expecting_value = True
        else:
            inline_prefix = next(
                (f"{opt}=" for opt in override_options if token.startswith(f"{opt}=")),
                None,
            )
            if inline_prefix is not None:
                collected.append(token[len(inline_prefix) :])
    return collected
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def _complete_hydra_override_values(
    *,
    config_path: str | None,
    config_name: str,
    multirun: bool,
    previous_overrides: list[str],
    incomplete: str,
) -> list[str]:
    """Ask Hydra's built-in completion plugin for override-value suggestions.

    A command line is synthesised for Hydra's completer: an optional
    ``--multirun`` flag, the overrides already typed, then the partial word.
    When *incomplete* is empty and the line is otherwise non-empty, a trailing
    space is appended. Presumably this makes Hydra treat the cursor as
    starting a new word (TODO confirm against Hydra's ``_query``).

    Note: this relies on the private ``DefaultCompletionPlugin._query`` method.
    """
    from flyteplugins.hydra._run import _hydra_init
    from hydra.plugins.completion_plugin import DefaultCompletionPlugin

    words = [*(["--multirun"] if multirun else []), *previous_overrides, incomplete]
    line = " ".join(words).strip()
    if not incomplete and line:
        line += " "

    with _hydra_init(config_path) as loader:
        return DefaultCompletionPlugin(loader)._query(config_name=config_name, line=line)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def _hydra_override_option_complete(ctx: click.Context, _param, incomplete: str) -> list[CompletionItem]:
    """Shell-completion callback for values of the ``--hydra-override`` option.

    Overrides already on the command line are passed to Hydra as context:
    earlier ``--hydra-override`` values, plus any values of the task's
    ``--<dictconfig-param>`` flags in the task tail.

    Failures are tolerated:

    * If the script or its config params cannot be resolved, only the
      ``--hydra-override`` values are used as context.
    * If Hydra's completer fails, no suggestions are offered.
    """
    script, task_name, task_tail = _script_task_and_tail(ctx)

    dictconfig_flags: set[str] = set()
    if script and task_name:
        try:
            loaded_task = _load_script_task(script, task_name)
            from flyteplugins.hydra._run import _config_param_names

            dictconfig_flags = {f"--{param}" for param in _config_param_names(loaded_task)}
        except Exception:
            dictconfig_flags = set()

    earlier_overrides = [
        *(ctx.params.get("hydra_overrides") or ()),
        *_collect_complete_overrides(task_tail, dictconfig_flags),
    ]

    try:
        candidates = _complete_hydra_override_values(
            config_path=ctx.params.get("config_path"),
            config_name=ctx.params.get("config_name") or "config",
            multirun=bool(ctx.params.get("multirun")),
            previous_overrides=earlier_overrides,
            incomplete=incomplete,
        )
    except Exception:
        return []

    return list(map(CompletionItem, candidates))
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
class HydraRunCommand(click.RichCommand):
    """Click command that adds Hydra override completions after SCRIPT TASK."""

    def shell_complete(self, ctx: click.Context, incomplete: str) -> list[CompletionItem]:
        """Extend Click's completions with task-specific Hydra override support.

        On top of the standard completions this:

        * offers the task's ``DictConfig`` parameter flags (e.g. ``--cfg``)
          when the incomplete word starts with ``-``; and
        * asks Hydra's completer for suggestions when the cursor is on an
          override value: ``--cfg <partial>``, ``--cfg=<partial>``, or the
          same two forms with ``--hydra-override``.

        Completion is best-effort. If the script cannot be loaded, its config
        parameters cannot be determined, or Hydra fails, the completions
        gathered so far are returned.
        """
        results = super().shell_complete(ctx, incomplete)
        script, task_name, task_tail = _script_task_and_tail(ctx)
        if not script or not task_name:
            return results

        try:
            task = _load_script_task(script, task_name)
            from flyteplugins.hydra._run import _config_param_names

            # Kept inside the try, as in _hydra_override_option_complete:
            # an exception escaping shell_complete would break the user's
            # shell completion instead of just yielding fewer suggestions.
            config_options = {f"--{name}" for name in _config_param_names(task)}
        except Exception:
            return results

        override_options = {*config_options, _HYDRA_OVERRIDE_OPTION}

        if incomplete.startswith("-"):
            existing = {item.value for item in results}
            results.extend(
                CompletionItem(option, help="Hydra app-level override")
                for option in sorted(config_options)
                if option.startswith(incomplete) and option not in existing
            )

        completion_context = _override_completion_context(task_tail, incomplete, override_options)
        if completion_context is None:
            return results

        previous_overrides, current_override, replacement_prefix = completion_context
        try:
            suggestions = _complete_hydra_override_values(
                config_path=ctx.params.get("config_path"),
                config_name=ctx.params.get("config_name") or "config",
                multirun=bool(ctx.params.get("multirun")),
                previous_overrides=previous_overrides,
                incomplete=current_override,
            )
        except Exception:
            return results

        # For the inline "--opt=<partial>" form the completed word must keep
        # its "--opt=" prefix; for the separate-value form the prefix is "".
        results.extend(CompletionItem(f"{replacement_prefix}{suggestion}") for suggestion in suggestions)
        return results
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def _parse_task_kwargs(task, args: list[str], parent_ctx: click.Context) -> dict:
    """Convert ordinary task CLI flags into Python kwargs for ``flyte.run``.

    The Hydra command has to load the user script before it can know the
    task interface. Once the task is available, this function builds a
    temporary Click command from Flyte's typed interface, using the same
    ``to_click_option`` converters as ``flyte run``. ``DictConfig`` inputs
    are intentionally skipped because Hydra composes them and ``hydra_run`` /
    ``hydra_sweep`` inject them.

    Args:
        task: the loaded Flyte task. Its ``native_interface.inputs`` supplies
            each input's Python type and default.
        args: the task-argument tail, with DictConfig override flags already
            removed.
        parent_ctx: the ``flyte hydra run`` context. Its ``obj`` is forwarded
            to the temporary parser.

    Returns:
        Mapping of task input name to converted value. The mapping is empty
        when the task has no typed interface.

    Raises:
        click.ClickException: (e.g. ``UsageError``) for invalid or missing
            task arguments.
        click.exceptions.Exit: if the temporary parser ends early (for
            example after handling ``--help``).
    """
    from flyte._internal.runtime.types_serde import transform_native_to_typed_interface
    from flyte.cli._params import to_click_option

    from flyteplugins.hydra._run import _config_param_names

    interface = transform_native_to_typed_interface(task.native_interface)
    if interface is None:
        return {}

    inputs_interface = task.native_interface.inputs
    config_param_names = set(_config_param_names(task))
    params = []

    for entry in interface.inputs.variables:
        name, var = entry.key, entry.value
        if name in config_param_names:
            continue

        input_spec = inputs_interface[name]
        python_type, declared_default = input_spec[0], input_spec[1]
        # inspect.Parameter.empty marks "no default"; Click wants None there.
        default_val = None if declared_default is inspect.Parameter.empty else declared_default

        params.append(to_click_option(name, var, python_type, default_val))

    def _collect(**kwargs):
        return kwargs

    # Let Click apply Flyte's normal type conversion and required/default
    # validation for the remaining task inputs.
    parser = click.Command(
        name="task-args",
        params=params,
        callback=_collect,
    )
    result = parser.main(
        args=args,
        prog_name="task arguments",
        standalone_mode=False,
        obj=parent_ctx.obj,
    )
    if not isinstance(result, dict):
        # With standalone_mode=False, Click returns the exit code instead of
        # the callback's value when parsing stops early (e.g. --help was
        # handled). Propagate that exit rather than handing callers an int
        # they would try to **-unpack.
        parent_ctx.exit(result if isinstance(result, int) else 0)
    return result
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
# Root ``flyte hydra`` command group; subcommands such as ``run`` attach via
# ``@hydra_group.command``. Click uses the function docstring below as the
# group's help text, so it is user-visible and should be edited with care.
@click.group(name="hydra")
def hydra_group() -> None:
    """Run Flyte tasks via Hydra config composition and sweeping."""
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
@hydra_group.command(
    name="run",
    cls=HydraRunCommand,
    context_settings={"ignore_unknown_options": True, "allow_extra_args": True},
)
@click.argument("script", type=click.Path(exists=True, dir_okay=False))
@click.argument("task_name")
@click.option("--config-path", default=None, help="Path to Hydra config directory.")
@click.option(
    "--config-name",
    default="config",
    show_default=True,
    help="Top-level config file name (without .yaml).",
)
@click.option(
    "--mode",
    type=click.Choice(["local", "remote"]),
    default=None,
    help="Execution mode. Defaults to remote for Flyte CLI parity.",
)
@click.option(
    "--multirun",
    is_flag=True,
    default=False,
    help="Expand sweep overrides into a grid of executions.",
)
@click.option(
    "--wait/--no-wait",
    default=True,
    show_default=True,
    help="Wait for remote Flyte runs to reach a terminal phase.",
)
@click.option(
    "--wait-max-workers",
    type=click.IntRange(min=1),
    default=32,
    show_default=True,
    help="Maximum worker threads used while waiting for remote Flyte runs.",
)
@click.option(
    "--task-env-key",
    default="task_env",
    show_default=True,
    help="Config key containing entry-task task.override kwargs by task name.",
)
@click.option(
    "--hydra-override",
    "hydra_overrides",
    multiple=True,
    metavar="KEY=VALUE",
    shell_complete=_hydra_override_option_complete,
    help=("Hydra-namespace override (repeatable), e.g. hydra/sweeper=optuna or hydra.sweeper.n_trials=20."),
)
@click.pass_context
def hydra_run_cmd(
    ctx: click.Context,
    script: str,
    task_name: str,
    config_path: str | None,
    config_name: str,
    mode: str | None,
    multirun: bool,
    wait: bool,
    wait_max_workers: int,
    task_env_key: str,
    hydra_overrides: tuple[str, ...],
    **run_params,
) -> None:
    """Compose a Hydra config and run TASK_NAME from SCRIPT on Flyte.

    SCRIPT is the path to a Python file containing the Flyte task.
    TASK_NAME is the name of the task function to run.

    Use the task's ``DictConfig`` parameter name for app-level overrides
    (for example ``--cfg`` or ``--config``).
    Use ``--hydra-override`` for hydra-namespace settings (hydra/sweeper=optuna).
    """
    # NOTE: the docstring above is this command's --help text; keep it stable.
    from flyte.cli._common import initialize_config
    from flyte.cli._run import RunArguments

    # Inherited ``flyte run`` options arrive in **run_params (see
    # _extend_with_run_options).
    run_args = RunArguments.from_dict(run_params)

    # Reject contradictory flag combinations before initialize_config runs,
    # so a plain usage error is reported without any Flyte setup first.
    if run_args.local and mode == "remote":
        raise click.UsageError("Use either --local or --mode remote, not both.")
    if run_args.follow and not wait:
        raise click.UsageError("Use either --follow or --no-wait, not both.")

    # --local wins; otherwise an explicit --mode, defaulting to remote.
    execution_mode = "local" if run_args.local else (mode or "remote")

    # Initialise Flyte.
    config = initialize_config(
        ctx,
        run_args.project,
        run_args.domain,
        root_dir=run_args.root_dir,
        images=run_args.image or None,
        sync_local_sys_paths=not run_args.no_sync_local_sys_paths,
    )
    ctx.obj = config.replace(run_args=run_args)

    # Only forward options that are accepted by flyte.with_runcontext and are
    # meaningful for an in-process script task. Project/domain/image/root-dir
    # are handled by initialize_config above; follow is handled after launch.
    run_options: dict = {
        "log_format": config.log_format,
        "reset_root_logger": config.reset_root_logger,
    }
    if run_args.service_account is not None:
        run_options["service_account"] = run_args.service_account
    if run_args.name is not None:
        run_options["name"] = run_args.name
    if run_args.raw_data_path is not None:
        run_options["raw_data_path"] = run_args.raw_data_path
    if run_args.copy_style != "loaded_modules":
        run_options["copy_style"] = run_args.copy_style
    if run_args.debug:
        run_options["debug"] = True

    # Load the script as a module so Flyte task decorators have run before we
    # inspect the requested task's typed interface.
    task = _load_script_task(script, task_name)

    # ctx.args contains everything after SCRIPT TASK_NAME that was not consumed
    # by the fixed Hydra/Flyte options. First pull out DictConfig override
    # aliases, then parse the rest as ordinary task inputs.
    param_cfg_overrides, task_args = _extract_config_overrides(task, list(ctx.args))
    task_kwargs = _parse_task_kwargs(task, task_args, ctx)

    # Combine all overrides; hydra_sweep / hydra_run separate hydra-namespace
    # from app-level overrides internally.
    all_overrides = param_cfg_overrides + list(hydra_overrides)

    # Shared launch settings. They are unpacked separately from task_kwargs so
    # a task input with a clashing name still raises TypeError, not an
    # override.
    launch_kwargs = {
        "config_path": config_path,
        "config_name": config_name,
        "overrides": all_overrides,
        "mode": execution_mode,
        "wait": wait,
        "wait_max_workers": wait_max_workers,
        "run_options": run_options,
        "task_env_key": task_env_key,
    }
    follow_logs = run_args.follow and execution_mode == "remote"

    if multirun:
        from flyteplugins.hydra._run import hydra_sweep

        runs = hydra_sweep(task, **launch_kwargs, **task_kwargs)
        for i, run in enumerate(runs):
            value = _completed_result_value(run)
            if value is not None:
                click.echo(f"[{i}] result={value}")
            if follow_logs:
                _follow_run_logs(run)
    else:
        from flyteplugins.hydra._run import hydra_run

        run = hydra_run(task, **launch_kwargs, **task_kwargs)
        value = _completed_result_value(run)
        if value is not None:
            click.echo(value)
        if follow_logs:
            _follow_run_logs(run)
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
# Attach every standard ``flyte run`` option to ``flyte hydra run``, so options
# later added to RunArguments.options() appear here without duplication.
# A fixed set (run_project, run_domain, tui) is skipped. If an inherited
# option shares a parameter *name* with an option already declared on
# hydra_run_cmd, a RuntimeError is raised at import time rather than letting
# one silently shadow the other. Flag strings are not compared.
def _extend_with_run_options() -> None:
    from flyte.cli._run import RunArguments

    reserved_names = {param.name for param in hydra_run_cmd.params}
    skipped_names = frozenset({"run_project", "run_domain", "tui"})

    for inherited in RunArguments.options():
        if inherited.name in skipped_names:
            continue
        if inherited.name in reserved_names:
            raise RuntimeError(
                f"flyte run option '{inherited.name}' conflicts with a hydra-specific "
                f"option on 'flyte hydra run'. The flyteplugins-hydra plugin "
                f"needs to be updated to resolve this collision."
            )
        hydra_run_cmd.params.append(inherited)
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
# Import-time side effect: append the inherited ``flyte run`` options to
# hydra_run_cmd.params (raises RuntimeError on a parameter-name collision).
_extend_with_run_options()
|