flyte-2.0.0b9-py3-none-any.whl → flyte-2.0.0b14-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. (The link to further details on the original registry page was not preserved.)

Files changed (60):
  1. flyte/__init__.py +55 -31
  2. flyte/_bin/debug.py +38 -0
  3. flyte/_bin/runtime.py +13 -0
  4. flyte/_code_bundle/_utils.py +2 -0
  5. flyte/_code_bundle/bundle.py +4 -4
  6. flyte/_context.py +1 -1
  7. flyte/_debug/__init__.py +0 -0
  8. flyte/_debug/constants.py +39 -0
  9. flyte/_debug/utils.py +17 -0
  10. flyte/_debug/vscode.py +300 -0
  11. flyte/_environment.py +5 -5
  12. flyte/_image.py +34 -19
  13. flyte/_initialize.py +15 -29
  14. flyte/_internal/controllers/remote/_action.py +2 -2
  15. flyte/_internal/controllers/remote/_controller.py +1 -1
  16. flyte/_internal/imagebuild/docker_builder.py +11 -15
  17. flyte/_internal/imagebuild/remote_builder.py +71 -22
  18. flyte/_internal/runtime/entrypoints.py +3 -0
  19. flyte/_internal/runtime/reuse.py +7 -3
  20. flyte/_internal/runtime/task_serde.py +4 -3
  21. flyte/_internal/runtime/taskrunner.py +9 -3
  22. flyte/_logging.py +5 -2
  23. flyte/_protos/common/identifier_pb2.py +25 -19
  24. flyte/_protos/common/identifier_pb2.pyi +10 -0
  25. flyte/_protos/imagebuilder/definition_pb2.py +32 -31
  26. flyte/_protos/imagebuilder/definition_pb2.pyi +25 -12
  27. flyte/_protos/workflow/queue_service_pb2.py +24 -24
  28. flyte/_protos/workflow/queue_service_pb2.pyi +6 -6
  29. flyte/_protos/workflow/run_definition_pb2.py +48 -48
  30. flyte/_protos/workflow/run_definition_pb2.pyi +20 -10
  31. flyte/_reusable_environment.py +41 -19
  32. flyte/_run.py +9 -9
  33. flyte/_secret.py +9 -5
  34. flyte/_task.py +16 -11
  35. flyte/_task_environment.py +11 -13
  36. flyte/_tools.py +0 -13
  37. flyte/_version.py +16 -3
  38. flyte/cli/_build.py +2 -3
  39. flyte/cli/_common.py +16 -5
  40. flyte/cli/_gen.py +10 -1
  41. flyte/cli/_get.py +16 -14
  42. flyte/cli/_run.py +258 -25
  43. flyte/models.py +9 -0
  44. flyte/remote/_client/auth/_authenticators/base.py +8 -2
  45. flyte/remote/_client/auth/_authenticators/device_code.py +1 -1
  46. flyte/remote/_client/auth/_authenticators/pkce.py +1 -1
  47. flyte/remote/_client/auth/_channel.py +0 -6
  48. flyte/remote/_client/auth/_client_config.py +4 -2
  49. flyte/remote/_client/controlplane.py +14 -0
  50. flyte/remote/_task.py +18 -4
  51. flyte/storage/_storage.py +83 -7
  52. flyte/types/_type_engine.py +3 -33
  53. flyte-2.0.0b14.data/scripts/debug.py +38 -0
  54. {flyte-2.0.0b9.data → flyte-2.0.0b14.data}/scripts/runtime.py +13 -0
  55. {flyte-2.0.0b9.dist-info → flyte-2.0.0b14.dist-info}/METADATA +2 -2
  56. {flyte-2.0.0b9.dist-info → flyte-2.0.0b14.dist-info}/RECORD +60 -54
  57. {flyte-2.0.0b9.dist-info → flyte-2.0.0b14.dist-info}/WHEEL +0 -0
  58. {flyte-2.0.0b9.dist-info → flyte-2.0.0b14.dist-info}/entry_points.txt +0 -0
  59. {flyte-2.0.0b9.dist-info → flyte-2.0.0b14.dist-info}/licenses/LICENSE +0 -0
  60. {flyte-2.0.0b9.dist-info → flyte-2.0.0b14.dist-info}/top_level.txt +0 -0
flyte/cli/_gen.py CHANGED
@@ -67,8 +67,13 @@ def markdown(cfg: common.CLIConfig):
67
67
  output_verb_groups: dict[str, list[str]] = {}
68
68
  output_noun_groups: dict[str, list[str]] = {}
69
69
 
70
+ processed = []
70
71
  commands = [*[("flyte", ctx.command)], *walk_commands(ctx)]
71
72
  for cmd_path, cmd in commands:
73
+ if cmd in processed:
74
+ # We already processed this command, skip it
75
+ continue
76
+ processed.append(cmd)
72
77
  output.append("")
73
78
 
74
79
  cmd_path_parts = cmd_path.split(" ")
@@ -136,7 +141,11 @@ def markdown(cfg: common.CLIConfig):
136
141
  output_verb_index.append("| ------ | -- |")
137
142
  for verb, nouns in output_verb_groups.items():
138
143
  entries = [f"[`{noun}`](#flyte-{verb}-{noun})" for noun in nouns]
139
- output_verb_index.append(f"| `{verb}` | {', '.join(entries)} |")
144
+ if len(entries) == 0:
145
+ verb_link = f"[`{verb}`](#flyte-{verb})"
146
+ output_verb_index.append(f"| {verb_link} | - |")
147
+ else:
148
+ output_verb_index.append(f"| `{verb}` | {', '.join(entries)} |")
140
149
 
141
150
  output_noun_index = []
142
151
 
flyte/cli/_get.py CHANGED
@@ -5,8 +5,6 @@ import rich_click as click
5
5
  from rich.console import Console
6
6
  from rich.pretty import pretty_repr
7
7
 
8
- import flyte.remote._action
9
-
10
8
  from . import _common as common
11
9
 
12
10
 
@@ -76,7 +74,7 @@ def run(
76
74
  console = Console()
77
75
  if name:
78
76
  details = RunDetails.get(name=name)
79
- console.print(pretty_repr(details))
77
+ console.print(common.format(f"Run {name}", [details], "json"))
80
78
  else:
81
79
  console.print(common.format("Runs", Run.listall(limit=limit), cfg.output_format))
82
80
 
@@ -110,7 +108,7 @@ def task(
110
108
  if v is None:
111
109
  raise click.BadParameter(f"Task {name} not found.")
112
110
  t = v.fetch()
113
- console.print(pretty_repr(t))
111
+ console.print(common.format(f"Task {name}", [t], "json"))
114
112
  else:
115
113
  console.print(common.format("Tasks", Task.listall(by_task_name=name, limit=limit), cfg.output_format))
116
114
  else:
@@ -131,19 +129,22 @@ def action(
131
129
  """
132
130
  Get all actions for a run or details for a specific action.
133
131
  """
132
+ import flyte.remote as remote
134
133
 
135
134
  cfg.init(project=project, domain=domain)
136
135
 
137
136
  console = Console()
138
137
  if action_name:
139
- console.print(pretty_repr(flyte.remote._action.Action.get(run_name=run_name, name=action_name)))
140
- else:
141
- # List all actions for the run
142
138
  console.print(
143
139
  common.format(
144
- f"Actions for {run_name}", flyte.remote._action.Action.listall(for_run_name=run_name), cfg.output_format
140
+ f"Action {run_name}.{action_name}", [remote.Action.get(run_name=run_name, name=action_name)], "json"
145
141
  )
146
142
  )
143
+ else:
144
+ # List all actions for the run
145
+ console.print(
146
+ common.format(f"Actions for {run_name}", remote.Action.listall(for_run_name=run_name), cfg.output_format)
147
+ )
147
148
 
148
149
 
149
150
  @get.command(cls=common.CommandBase)
@@ -211,7 +212,7 @@ def logs(
211
212
  task.cancel()
212
213
 
213
214
  if action_name:
214
- obj = flyte.remote._action.Action.get(run_name=run_name, name=action_name)
215
+ obj = remote.Action.get(run_name=run_name, name=action_name)
215
216
  else:
216
217
  obj = remote.Run.get(run_name)
217
218
  asyncio.run(_run_log_view(obj))
@@ -235,7 +236,7 @@ def secret(
235
236
 
236
237
  console = Console()
237
238
  if name:
238
- console.print(pretty_repr(remote.Secret.get(name)))
239
+ console.print(common.format("Secret", [remote.Secret.get(name)], "json"))
239
240
  else:
240
241
  console.print(common.format("Secrets", remote.Secret.listall(), cfg.output_format))
241
242
 
@@ -275,24 +276,25 @@ def io(
275
276
  raise click.BadParameter("Cannot use both --inputs-only and --outputs-only")
276
277
 
277
278
  import flyte.remote as remote
279
+ from flyte.remote import ActionDetails, ActionInputs, ActionOutputs
278
280
 
279
281
  cfg.init(project=project, domain=domain)
280
282
  console = Console()
281
283
  if action_name:
282
- obj = flyte.remote._action.ActionDetails.get(run_name=run_name, name=action_name)
284
+ obj = ActionDetails.get(run_name=run_name, name=action_name)
283
285
  else:
284
286
  obj = remote.RunDetails.get(run_name)
285
287
 
286
288
  async def _get_io(
287
- details: Union[remote.RunDetails, flyte.remote._action.ActionDetails],
288
- ) -> Tuple[flyte.remote._action.ActionInputs | None, flyte.remote._action.ActionOutputs | None | str]:
289
+ details: Union[remote.RunDetails, ActionDetails],
290
+ ) -> Tuple[ActionInputs | None, ActionOutputs | None | str]:
289
291
  if inputs_only or outputs_only:
290
292
  if inputs_only:
291
293
  return await details.inputs(), None
292
294
  elif outputs_only:
293
295
  return None, await details.outputs()
294
296
  inputs = await details.inputs()
295
- outputs: flyte.remote._action.ActionOutputs | None | str = None
297
+ outputs: ActionOutputs | None | str = None
296
298
  try:
297
299
  outputs = await details.outputs()
298
300
  except Exception:
flyte/cli/_run.py CHANGED
@@ -3,12 +3,12 @@ from __future__ import annotations
3
3
  import asyncio
4
4
  import inspect
5
5
  from dataclasses import dataclass, field, fields
6
+ from functools import lru_cache
6
7
  from pathlib import Path
7
8
  from types import ModuleType
8
9
  from typing import Any, Dict, List, cast
9
10
 
10
- import click
11
- from click import Context, Parameter
11
+ import rich_click as click
12
12
  from rich.console import Console
13
13
  from typing_extensions import get_args
14
14
 
@@ -19,6 +19,34 @@ from . import _common as common
19
19
  from ._common import CLIConfig
20
20
  from ._params import to_click_option
21
21
 
22
+ RUN_REMOTE_CMD = "deployed-task"
23
+
24
+
25
+ @lru_cache()
26
+ def _initialize_config(ctx: click.Context, project: str, domain: str):
27
+ obj: CLIConfig | None = ctx.obj
28
+ if obj is None:
29
+ import flyte.config
30
+
31
+ obj = CLIConfig(flyte.config.auto(), ctx)
32
+
33
+ obj.init(project, domain)
34
+ return obj
35
+
36
+
37
+ @lru_cache()
38
+ def _list_tasks(
39
+ ctx: click.Context,
40
+ project: str,
41
+ domain: str,
42
+ by_task_name: str | None = None,
43
+ by_task_env: str | None = None,
44
+ ) -> list[str]:
45
+ import flyte.remote
46
+
47
+ _initialize_config(ctx, project, domain)
48
+ return [task.name for task in flyte.remote.Task.listall(by_task_name=by_task_name, by_task_env=by_task_env)]
49
+
22
50
 
23
51
  @dataclass
24
52
  class RunArguments:
@@ -60,7 +88,7 @@ class RunArguments:
60
88
  },
61
89
  )
62
90
  follow: bool = field(
63
- default=True,
91
+ default=False,
64
92
  metadata={
65
93
  "click.option": click.Option(
66
94
  ["--follow", "-f"],
@@ -92,23 +120,17 @@ class RunTaskCommand(click.Command):
92
120
  kwargs.pop("name", None)
93
121
  super().__init__(obj_name, *args, **kwargs)
94
122
 
95
- def invoke(self, ctx: Context):
96
- obj: CLIConfig = ctx.obj
97
- if obj is None:
98
- import flyte.config
99
-
100
- obj = CLIConfig(flyte.config.auto(), ctx)
101
-
102
- obj.init(self.run_args.project, self.run_args.domain)
123
+ def invoke(self, ctx: click.Context):
124
+ obj: CLIConfig = _initialize_config(ctx, self.run_args.project, self.run_args.domain)
103
125
 
104
126
  async def _run():
105
127
  import flyte
106
128
 
107
- r = flyte.with_runcontext(
129
+ r = await flyte.with_runcontext(
108
130
  copy_style=self.run_args.copy_style,
109
131
  mode="local" if self.run_args.local else "remote",
110
132
  name=self.run_args.name,
111
- ).run(self.obj, **ctx.params)
133
+ ).run.aio(self.obj, **ctx.params)
112
134
  if isinstance(r, Run) and r.action is not None:
113
135
  console = Console()
114
136
  console.print(
@@ -129,7 +151,7 @@ class RunTaskCommand(click.Command):
129
151
 
130
152
  asyncio.run(_run())
131
153
 
132
- def get_params(self, ctx: Context) -> List[Parameter]:
154
+ def get_params(self, ctx: click.Context) -> List[click.Parameter]:
133
155
  # Note this function may be called multiple times by click.
134
156
  task = self.obj
135
157
  from .._internal.runtime.types_serde import transform_native_to_typed_interface
@@ -139,7 +161,7 @@ class RunTaskCommand(click.Command):
139
161
  return super().get_params(ctx)
140
162
  inputs_interface = task.native_interface.inputs
141
163
 
142
- params: List[Parameter] = []
164
+ params: List[click.Parameter] = []
143
165
  for name, var in interface.inputs.variables.items():
144
166
  default_val = None
145
167
  if inputs_interface[name][1] is not inspect._empty:
@@ -156,8 +178,9 @@ class TaskPerFileGroup(common.ObjectsPerFileGroup):
156
178
  """
157
179
 
158
180
  def __init__(self, filename: Path, run_args: RunArguments, *args, **kwargs):
159
- args = (filename, *args)
160
- super().__init__(*args, **kwargs)
181
+ if filename.is_absolute():
182
+ filename = filename.relative_to(Path.cwd())
183
+ super().__init__(*(filename, *args), **kwargs)
161
184
  self.run_args = run_args
162
185
 
163
186
  def _filter_objects(self, module: ModuleType) -> Dict[str, Any]:
@@ -173,6 +196,172 @@ class TaskPerFileGroup(common.ObjectsPerFileGroup):
173
196
  )
174
197
 
175
198
 
199
+ class RunReferenceTaskCommand(click.Command):
200
+ def __init__(self, task_name: str, run_args: RunArguments, version: str | None, *args, **kwargs):
201
+ self.task_name = task_name
202
+ self.run_args = run_args
203
+ self.version = version
204
+
205
+ super().__init__(*args, **kwargs)
206
+
207
+ def invoke(self, ctx: click.Context):
208
+ obj: CLIConfig = _initialize_config(ctx, self.run_args.project, self.run_args.domain)
209
+
210
+ async def _run():
211
+ import flyte
212
+ import flyte.remote
213
+
214
+ task = flyte.remote.Task.get(self.task_name, version=self.version, auto_version="latest")
215
+
216
+ r = await flyte.with_runcontext(
217
+ copy_style=self.run_args.copy_style,
218
+ mode="local" if self.run_args.local else "remote",
219
+ name=self.run_args.name,
220
+ ).run.aio(task, **ctx.params)
221
+ if isinstance(r, Run) and r.action is not None:
222
+ console = Console()
223
+ console.print(
224
+ common.get_panel(
225
+ "Run",
226
+ f"[green bold]Created Run: {r.name} [/green bold] "
227
+ f"(Project: {r.action.action_id.run.project}, Domain: {r.action.action_id.run.domain})\n"
228
+ f"➡️ [blue bold][link={r.url}]{r.url}[/link][/blue bold]",
229
+ obj.output_format,
230
+ )
231
+ )
232
+ if self.run_args.follow:
233
+ console.print(
234
+ "[dim]Log streaming enabled, will wait for task to start running "
235
+ "and log stream to be available[/dim]"
236
+ )
237
+ await r.show_logs.aio(max_lines=30, show_ts=True, raw=False)
238
+
239
+ asyncio.run(_run())
240
+
241
+ def get_params(self, ctx: click.Context) -> List[click.Parameter]:
242
+ # Note this function may be called multiple times by click.
243
+ import flyte.remote
244
+ from flyte._internal.runtime.types_serde import transform_native_to_typed_interface
245
+
246
+ _initialize_config(ctx, self.run_args.project, self.run_args.domain)
247
+
248
+ task = flyte.remote.Task.get(self.task_name, auto_version="latest")
249
+ task_details = task.fetch()
250
+
251
+ interface = transform_native_to_typed_interface(task_details.interface)
252
+ if interface is None:
253
+ return super().get_params(ctx)
254
+ inputs_interface = task_details.interface.inputs
255
+
256
+ params: List[click.Parameter] = []
257
+ for name, var in interface.inputs.variables.items():
258
+ default_val = None
259
+ if inputs_interface[name][1] is not inspect._empty:
260
+ default_val = inputs_interface[name][1]
261
+ params.append(to_click_option(name, var, inputs_interface[name][0], default_val))
262
+
263
+ self.params = params
264
+ return super().get_params(ctx)
265
+
266
+
267
+ class ReferenceEnvGroup(common.GroupBase):
268
+ def __init__(self, name: str, *args, run_args, env: str, **kwargs):
269
+ super().__init__(*args, **kwargs)
270
+ self.name = name
271
+ self.env = env
272
+ self.run_args = run_args
273
+
274
+ def list_commands(self, ctx):
275
+ return _list_tasks(ctx, self.run_args.project, self.run_args.domain, by_task_env=self.env)
276
+
277
+ def get_command(self, ctx, name):
278
+ return RunReferenceTaskCommand(
279
+ task_name=name,
280
+ run_args=self.run_args,
281
+ name=name,
282
+ version=None,
283
+ help=f"Run deployed task '{name}' from the Flyte backend",
284
+ )
285
+
286
+
287
+ class ReferenceTaskGroup(common.GroupBase):
288
+ """
289
+ Group that creates a command for each reference task in the current directory that is not __init__.py.
290
+ """
291
+
292
+ def __init__(self, name: str, *args, run_args, tasks: list[str] | None = None, **kwargs):
293
+ super().__init__(*args, **kwargs)
294
+ self.name = name
295
+ self.run_args = run_args
296
+
297
+ def list_commands(self, ctx):
298
+ # list envs of all reference tasks
299
+ envs = []
300
+ for task in _list_tasks(ctx, self.run_args.project, self.run_args.domain):
301
+ env = task.split(".")[0]
302
+ if env not in envs:
303
+ envs.append(env)
304
+ return envs
305
+
306
+ @staticmethod
307
+ def _parse_task_name(task_name: str) -> tuple[str, str | None, str | None]:
308
+ import re
309
+
310
+ pattern = r"^([^.:]+)(?:\.([^:]+))?(?::(.+))?$"
311
+ match = re.match(pattern, task_name)
312
+ if not match:
313
+ raise click.BadParameter(f"Invalid task name format: {task_name}")
314
+ return match.group(1), match.group(2), match.group(3)
315
+
316
+ def _env_is_task(self, ctx: click.Context, env: str) -> bool:
317
+ # check if the env name is the full task name, since sometimes task
318
+ # names don't have an environment prefix
319
+ tasks = [*_list_tasks(ctx, self.run_args.project, self.run_args.domain, by_task_name=env)]
320
+ return len(tasks) > 0
321
+
322
+ def get_command(self, ctx, name):
323
+ env, task, version = self._parse_task_name(name)
324
+ match env, task, version:
325
+ case env, None, None:
326
+ if self._env_is_task(ctx, env):
327
+ # this handles cases where task names do not have a environment prefix
328
+ task_name = env
329
+ return RunReferenceTaskCommand(
330
+ task_name=task_name,
331
+ run_args=self.run_args,
332
+ name=task_name,
333
+ version=None,
334
+ help=f"Run reference task `{task_name}` from the Flyte backend",
335
+ )
336
+ else:
337
+ return ReferenceEnvGroup(
338
+ name=name,
339
+ run_args=self.run_args,
340
+ env=env,
341
+ help=f"Run reference tasks in the `{env}` environment from the Flyte backend",
342
+ )
343
+ case env, task, None:
344
+ task_name = f"{env}.{task}"
345
+ return RunReferenceTaskCommand(
346
+ task_name=task_name,
347
+ run_args=self.run_args,
348
+ name=task_name,
349
+ version=None,
350
+ help=f"Run reference task '{task_name}' from the Flyte backend",
351
+ )
352
+ case env, task, version:
353
+ task_name = f"{env}.{task}"
354
+ return RunReferenceTaskCommand(
355
+ task_name=task_name,
356
+ run_args=self.run_args,
357
+ version=version,
358
+ name=f"{task_name}:{version}",
359
+ help=f"Run reference task '{task_name}' from the Flyte backend",
360
+ )
361
+ case _:
362
+ raise click.BadParameter(f"Invalid task name format: {task_name}")
363
+
364
+
176
365
  class TaskFiles(common.FileGroup):
177
366
  """
178
367
  Group that creates a command for each file in the current directory that is not __init__.py.
@@ -191,25 +380,45 @@ class TaskFiles(common.FileGroup):
191
380
  kwargs["params"].extend(RunArguments.options())
192
381
  super().__init__(*args, directory=directory, **kwargs)
193
382
 
194
- def get_command(self, ctx, filename):
383
+ def list_commands(self, ctx):
384
+ v = [
385
+ RUN_REMOTE_CMD,
386
+ *super().list_commands(ctx),
387
+ ]
388
+ return v
389
+
390
+ def get_command(self, ctx, cmd_name):
195
391
  run_args = RunArguments.from_dict(ctx.params)
196
- fp = Path(filename)
392
+
393
+ if cmd_name == RUN_REMOTE_CMD:
394
+ return ReferenceTaskGroup(
395
+ name=cmd_name,
396
+ run_args=run_args,
397
+ help="Run reference task from the Flyte backend",
398
+ )
399
+
400
+ fp = Path(cmd_name)
197
401
  if not fp.exists():
198
- raise click.BadParameter(f"File {filename} does not exist")
402
+ raise click.BadParameter(f"File {cmd_name} does not exist")
199
403
  if fp.is_dir():
200
- return TaskFiles(directory=fp)
404
+ return TaskFiles(
405
+ directory=fp,
406
+ help=f"Run `*.py` file inside the {fp} directory",
407
+ )
201
408
  return TaskPerFileGroup(
202
409
  filename=fp,
203
410
  run_args=run_args,
204
- name=filename,
205
- help=f"Run, functions decorated with `env.task` in {filename}",
411
+ name=cmd_name,
412
+ help=f"Run functions decorated with `env.task` in {cmd_name}",
206
413
  )
207
414
 
208
415
 
209
416
  run = TaskFiles(
210
417
  name="run",
211
- help="""
212
- Run a task from a python file.
418
+ help=f"""
419
+ Run a task from a python file or deployed task.
420
+
421
+ To run a remote task that already exists in Flyte, use the {RUN_REMOTE_CMD} command:
213
422
 
214
423
  Example usage:
215
424
 
@@ -227,6 +436,30 @@ Flyte environment:
227
436
  flyte run --local hello.py my_task --arg1 value1 --arg2 value2
228
437
  ```
229
438
 
439
+ To run tasks that you've already deployed to Flyte, use the {RUN_REMOTE_CMD} command:
440
+
441
+ ```bash
442
+ flyte run {RUN_REMOTE_CMD} my_env.my_task --arg1 value1 --arg2 value2
443
+ ```
444
+
445
+ To run a specific version of a deployed task, use the `env.task:version` syntax:
446
+
447
+ ```bash
448
+ flyte run {RUN_REMOTE_CMD} my_env.my_task:xyz123 --arg1 value1 --arg2 value2
449
+ ```
450
+
451
+ You can specify the `--config` flag to point to a specific Flyte cluster:
452
+
453
+ ```bash
454
+ flyte run --config my-config.yaml {RUN_REMOTE_CMD} ...
455
+ ```
456
+
457
+ You can discover what deployed tasks are available by running:
458
+
459
+ ```bash
460
+ flyte run {RUN_REMOTE_CMD}
461
+ ```
462
+
230
463
  Other arguments to the run command are listed below.
231
464
 
232
465
  Arguments for the task itself are provided after the task name and can be retrieved using `--help`. For example:
flyte/models.py CHANGED
@@ -166,6 +166,7 @@ class TaskContext:
166
166
  action: ActionID
167
167
  version: str
168
168
  raw_data_path: RawDataPath
169
+ input_path: str | None = None
169
170
  output_path: str
170
171
  run_base_dir: str
171
172
  report: Report
@@ -175,6 +176,7 @@ class TaskContext:
175
176
  compiled_image_cache: ImageCache | None = None
176
177
  data: Dict[str, Any] = field(default_factory=dict)
177
178
  mode: Literal["local", "remote", "hybrid"] = "remote"
179
+ interactive_mode: bool = False
178
180
 
179
181
  def replace(self, **kwargs) -> TaskContext:
180
182
  if "data" in kwargs:
@@ -191,6 +193,13 @@ class TaskContext:
191
193
  def __getitem__(self, key: str) -> Optional[Any]:
192
194
  return self.data.get(key)
193
195
 
196
+ def is_in_cluster(self):
197
+ """
198
+ Check if the task is running in a cluster.
199
+ :return: bool
200
+ """
201
+ return self.mode == "remote"
202
+
194
203
 
195
204
  @rich.repr.auto
196
205
  @dataclass(frozen=True, kw_only=True)
flyte/remote/_client/auth/_authenticators/base.py CHANGED
@@ -34,6 +34,7 @@ class Authenticator(object):
34
34
  http_proxy_url: typing.Optional[str] = None,
35
35
  verify: bool = True,
36
36
  ca_cert_path: typing.Optional[str] = None,
37
+ default_header_key: str = "authorization",
37
38
  **kwargs,
38
39
  ):
39
40
  """
@@ -80,6 +81,7 @@ class Authenticator(object):
80
81
  self._http_session = http_session or get_async_session(**kwargs)
81
82
  # Id for tracking credential refresh state
82
83
  self._creds_id = self._creds.id if self._creds else None
84
+ self._default_header_key = default_header_key
83
85
 
84
86
  async def _resolve_config(self) -> ClientConfig:
85
87
  """
@@ -131,10 +133,14 @@ class Authenticator(object):
131
133
  """
132
134
  creds = self.get_credentials()
133
135
  if creds:
134
- cfg = await self._resolve_config()
136
+ header_key = self._default_header_key
137
+ if self._resolved_config is not None:
138
+ # We only resolve the config during authentication flow, to avoid unnecessary network calls
139
+ # and usually the header_key is consistent.
140
+ header_key = self._resolved_config.header_key
135
141
  return GrpcAuthMetadata(
136
142
  creds_id=creds.id,
137
- pairs=Metadata((cfg.header_key, f"Bearer {creds.access_token}")),
143
+ pairs=Metadata((header_key, f"Bearer {creds.access_token}")),
138
144
  )
139
145
  return None
140
146
 
flyte/remote/_client/auth/_authenticators/device_code.py CHANGED
@@ -81,7 +81,7 @@ class DeviceCodeAuthenticator(Authenticator):
81
81
  for_endpoint=self._endpoint,
82
82
  )
83
83
  except (AuthenticationError, AuthenticationPending):
84
- logger.warning("Failed to refresh token. Kicking off a full authorization flow.")
84
+ logger.warning("Logging in...")
85
85
 
86
86
  """Fall back to device flow"""
87
87
  resp = await token_client.get_device_code(
flyte/remote/_client/auth/_authenticators/pkce.py CHANGED
@@ -123,7 +123,7 @@ class PKCEAuthenticator(Authenticator):
123
123
  try:
124
124
  return await self._auth_client.refresh_access_token(self._creds)
125
125
  except AccessTokenNotFoundError:
126
- logger.warning("Failed to refresh token. Kicking off a full authorization flow.")
126
+ logger.warning("Logging in...")
127
127
 
128
128
  return await self._auth_client.get_creds_from_remote()
129
129
 
flyte/remote/_client/auth/_channel.py CHANGED
@@ -1,4 +1,3 @@
1
- import os
2
1
  import ssl
3
2
  import typing
4
3
 
@@ -16,11 +15,6 @@ from ._authenticators.factory import (
16
15
  get_async_proxy_authenticator,
17
16
  )
18
17
 
19
- # Set environment variables for gRPC, this reduces log spew and avoids unnecessary warnings
20
- if "GRPC_VERBOSITY" not in os.environ:
21
- os.environ["GRPC_VERBOSITY"] = "ERROR"
22
- os.environ["GRPC_CPP_MIN_LOG_LEVEL"] = "ERROR"
23
-
24
18
  # Initialize gRPC AIO early enough so it can be used in the main thread
25
19
  init_grpc_aio()
26
20
 
flyte/remote/_client/auth/_client_config.py CHANGED
@@ -1,3 +1,4 @@
1
+ import asyncio
1
2
  import typing
2
3
  from abc import abstractmethod
3
4
 
@@ -69,8 +70,9 @@ class RemoteClientConfigStore(ClientConfigStore):
69
70
  Retrieves the ClientConfig from the given grpc.Channel assuming AuthMetadataService is available
70
71
  """
71
72
  metadata_service = AuthMetadataServiceStub(self._unauthenticated_channel)
72
- public_client_config = await metadata_service.GetPublicClientConfig(PublicClientAuthConfigRequest())
73
- oauth2_metadata = await metadata_service.GetOAuth2Metadata(OAuth2MetadataRequest())
73
+ oauth2_metadata_task = metadata_service.GetOAuth2Metadata(OAuth2MetadataRequest())
74
+ public_client_config_task = metadata_service.GetPublicClientConfig(PublicClientAuthConfigRequest())
75
+ oauth2_metadata, public_client_config = await asyncio.gather(oauth2_metadata_task, public_client_config_task)
74
76
  return ClientConfig(
75
77
  token_endpoint=oauth2_metadata.token_endpoint,
76
78
  authorization_endpoint=oauth2_metadata.authorization_endpoint,
flyte/remote/_client/controlplane.py CHANGED
@@ -1,5 +1,19 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import os
4
+
5
+ # Set environment variables for gRPC, this reduces log spew and avoids unnecessary warnings
6
+ # before importing grpc
7
+ if "GRPC_VERBOSITY" not in os.environ:
8
+ os.environ["GRPC_VERBOSITY"] = "ERROR"
9
+ os.environ["GRPC_CPP_MIN_LOG_LEVEL"] = "ERROR"
10
+ # Disable fork support (stops "skipping fork() handlers")
11
+ os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "0"
12
+ # Reduce absl/glog verbosity
13
+ os.environ["GLOG_minloglevel"] = "2"
14
+ os.environ["ABSL_LOG"] = "0"
15
+ #### Has to be before grpc
16
+
3
17
  import grpc
4
18
  from flyteidl.service import admin_pb2_grpc, dataproxy_pb2_grpc
5
19