flyte 2.0.0b9__py3-none-any.whl → 2.0.0b13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flyte/_tools.py CHANGED
@@ -1,6 +1,3 @@
1
- import os
2
-
3
-
4
1
  def ipython_check() -> bool:
5
2
  """
6
3
  Check if interface is launching from iPython (not colab)
@@ -17,16 +14,6 @@ def ipython_check() -> bool:
17
14
  return is_ipython
18
15
 
19
16
 
20
- def is_in_cluster() -> bool:
21
- """
22
- Check if the task is running in a cluster
23
- :return is_in_cluster (bool): True or False
24
- """
25
- if os.getenv("_UN_CLS"):
26
- return True
27
- return False
28
-
29
-
30
17
  def ipywidgets_check() -> bool:
31
18
  """
32
19
  Check if the interface is running in IPython with ipywidgets support.
flyte/_version.py CHANGED
@@ -1,7 +1,14 @@
1
1
  # file generated by setuptools-scm
2
2
  # don't change, don't track in version control
3
3
 
4
- __all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
4
+ __all__ = [
5
+ "__version__",
6
+ "__version_tuple__",
7
+ "version",
8
+ "version_tuple",
9
+ "__commit_id__",
10
+ "commit_id",
11
+ ]
5
12
 
6
13
  TYPE_CHECKING = False
7
14
  if TYPE_CHECKING:
@@ -9,13 +16,19 @@ if TYPE_CHECKING:
9
16
  from typing import Union
10
17
 
11
18
  VERSION_TUPLE = Tuple[Union[int, str], ...]
19
+ COMMIT_ID = Union[str, None]
12
20
  else:
13
21
  VERSION_TUPLE = object
22
+ COMMIT_ID = object
14
23
 
15
24
  version: str
16
25
  __version__: str
17
26
  __version_tuple__: VERSION_TUPLE
18
27
  version_tuple: VERSION_TUPLE
28
+ commit_id: COMMIT_ID
29
+ __commit_id__: COMMIT_ID
19
30
 
20
- __version__ = version = '2.0.0b9'
21
- __version_tuple__ = version_tuple = (2, 0, 0, 'b9')
31
+ __version__ = version = '2.0.0b13'
32
+ __version_tuple__ = version_tuple = (2, 0, 0, 'b13')
33
+
34
+ __commit_id__ = commit_id = 'g07f30e36d'
flyte/cli/_common.py CHANGED
@@ -316,11 +316,23 @@ class FileGroup(GroupBase):
316
316
  def files(self):
317
317
  if self._files is None:
318
318
  directory = self._dir or Path(".").absolute()
319
- self._files = [os.fspath(p) for p in directory.glob("*.py") if p.name != "__init__.py"]
320
- if not self._files:
321
- self._files = [os.fspath(".")] + [
322
- os.fspath(p.name) for p in directory.iterdir() if not p.name.startswith(("_", ".")) and p.is_dir()
319
+ # add python files
320
+ _files = [os.fspath(p) for p in directory.glob("*.py") if p.name != "__init__.py"]
321
+
322
+ # add directories
323
+ _files.extend(
324
+ [
325
+ os.fspath(directory / p.name)
326
+ for p in directory.iterdir()
327
+ if not p.name.startswith(("_", ".")) and p.is_dir()
323
328
  ]
329
+ )
330
+
331
+ # files that are in the current directory or subdirectories of the
332
+ # current directory should be displayed as relative paths
333
+ self._files = [
334
+ str(Path(f).relative_to(Path.cwd())) if Path(f).is_relative_to(Path.cwd()) else f for f in _files
335
+ ]
324
336
  return self._files
325
337
 
326
338
  def list_commands(self, ctx):
@@ -351,7 +363,6 @@ def format(title: str, vals: Iterable[Any], of: OutputFormat = "table") -> Table
351
363
  """
352
364
  Get a table from a list of values.
353
365
  """
354
-
355
366
  match of:
356
367
  case "table-simple":
357
368
  return _table_format(Table(title, box=None), vals)
flyte/cli/_gen.py CHANGED
@@ -67,8 +67,13 @@ def markdown(cfg: common.CLIConfig):
67
67
  output_verb_groups: dict[str, list[str]] = {}
68
68
  output_noun_groups: dict[str, list[str]] = {}
69
69
 
70
+ processed = []
70
71
  commands = [*[("flyte", ctx.command)], *walk_commands(ctx)]
71
72
  for cmd_path, cmd in commands:
73
+ if cmd in processed:
74
+ # We already processed this command, skip it
75
+ continue
76
+ processed.append(cmd)
72
77
  output.append("")
73
78
 
74
79
  cmd_path_parts = cmd_path.split(" ")
@@ -136,7 +141,11 @@ def markdown(cfg: common.CLIConfig):
136
141
  output_verb_index.append("| ------ | -- |")
137
142
  for verb, nouns in output_verb_groups.items():
138
143
  entries = [f"[`{noun}`](#flyte-{verb}-{noun})" for noun in nouns]
139
- output_verb_index.append(f"| `{verb}` | {', '.join(entries)} |")
144
+ if len(entries) == 0:
145
+ verb_link = f"[`{verb}`](#flyte-{verb})"
146
+ output_verb_index.append(f"| {verb_link} | - |")
147
+ else:
148
+ output_verb_index.append(f"| `{verb}` | {', '.join(entries)} |")
140
149
 
141
150
  output_noun_index = []
142
151
 
flyte/cli/_get.py CHANGED
@@ -5,8 +5,6 @@ import rich_click as click
5
5
  from rich.console import Console
6
6
  from rich.pretty import pretty_repr
7
7
 
8
- import flyte.remote._action
9
-
10
8
  from . import _common as common
11
9
 
12
10
 
@@ -76,7 +74,7 @@ def run(
76
74
  console = Console()
77
75
  if name:
78
76
  details = RunDetails.get(name=name)
79
- console.print(pretty_repr(details))
77
+ console.print(common.format(f"Run {name}", [details], "json"))
80
78
  else:
81
79
  console.print(common.format("Runs", Run.listall(limit=limit), cfg.output_format))
82
80
 
@@ -110,7 +108,7 @@ def task(
110
108
  if v is None:
111
109
  raise click.BadParameter(f"Task {name} not found.")
112
110
  t = v.fetch()
113
- console.print(pretty_repr(t))
111
+ console.print(common.format(f"Task {name}", [t], "json"))
114
112
  else:
115
113
  console.print(common.format("Tasks", Task.listall(by_task_name=name, limit=limit), cfg.output_format))
116
114
  else:
@@ -131,19 +129,22 @@ def action(
131
129
  """
132
130
  Get all actions for a run or details for a specific action.
133
131
  """
132
+ import flyte.remote as remote
134
133
 
135
134
  cfg.init(project=project, domain=domain)
136
135
 
137
136
  console = Console()
138
137
  if action_name:
139
- console.print(pretty_repr(flyte.remote._action.Action.get(run_name=run_name, name=action_name)))
140
- else:
141
- # List all actions for the run
142
138
  console.print(
143
139
  common.format(
144
- f"Actions for {run_name}", flyte.remote._action.Action.listall(for_run_name=run_name), cfg.output_format
140
+ f"Action {run_name}.{action_name}", [remote.Action.get(run_name=run_name, name=action_name)], "json"
145
141
  )
146
142
  )
143
+ else:
144
+ # List all actions for the run
145
+ console.print(
146
+ common.format(f"Actions for {run_name}", remote.Action.listall(for_run_name=run_name), cfg.output_format)
147
+ )
147
148
 
148
149
 
149
150
  @get.command(cls=common.CommandBase)
@@ -211,7 +212,7 @@ def logs(
211
212
  task.cancel()
212
213
 
213
214
  if action_name:
214
- obj = flyte.remote._action.Action.get(run_name=run_name, name=action_name)
215
+ obj = remote.Action.get(run_name=run_name, name=action_name)
215
216
  else:
216
217
  obj = remote.Run.get(run_name)
217
218
  asyncio.run(_run_log_view(obj))
@@ -235,7 +236,7 @@ def secret(
235
236
 
236
237
  console = Console()
237
238
  if name:
238
- console.print(pretty_repr(remote.Secret.get(name)))
239
+ console.print(common.format("Secret", [remote.Secret.get(name)], "json"))
239
240
  else:
240
241
  console.print(common.format("Secrets", remote.Secret.listall(), cfg.output_format))
241
242
 
@@ -275,24 +276,25 @@ def io(
275
276
  raise click.BadParameter("Cannot use both --inputs-only and --outputs-only")
276
277
 
277
278
  import flyte.remote as remote
279
+ from flyte.remote import ActionDetails, ActionInputs, ActionOutputs
278
280
 
279
281
  cfg.init(project=project, domain=domain)
280
282
  console = Console()
281
283
  if action_name:
282
- obj = flyte.remote._action.ActionDetails.get(run_name=run_name, name=action_name)
284
+ obj = ActionDetails.get(run_name=run_name, name=action_name)
283
285
  else:
284
286
  obj = remote.RunDetails.get(run_name)
285
287
 
286
288
  async def _get_io(
287
- details: Union[remote.RunDetails, flyte.remote._action.ActionDetails],
288
- ) -> Tuple[flyte.remote._action.ActionInputs | None, flyte.remote._action.ActionOutputs | None | str]:
289
+ details: Union[remote.RunDetails, ActionDetails],
290
+ ) -> Tuple[ActionInputs | None, ActionOutputs | None | str]:
289
291
  if inputs_only or outputs_only:
290
292
  if inputs_only:
291
293
  return await details.inputs(), None
292
294
  elif outputs_only:
293
295
  return None, await details.outputs()
294
296
  inputs = await details.inputs()
295
- outputs: flyte.remote._action.ActionOutputs | None | str = None
297
+ outputs: ActionOutputs | None | str = None
296
298
  try:
297
299
  outputs = await details.outputs()
298
300
  except Exception:
flyte/cli/_run.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
3
3
  import asyncio
4
4
  import inspect
5
5
  from dataclasses import dataclass, field, fields
6
+ from functools import lru_cache
6
7
  from pathlib import Path
7
8
  from types import ModuleType
8
9
  from typing import Any, Dict, List, cast
@@ -19,6 +20,34 @@ from . import _common as common
19
20
  from ._common import CLIConfig
20
21
  from ._params import to_click_option
21
22
 
23
+ RUN_REMOTE_CMD = "deployed-task"
24
+
25
+
26
+ @lru_cache()
27
+ def _initialize_config(ctx: Context, project: str, domain: str):
28
+ obj: CLIConfig | None = ctx.obj
29
+ if obj is None:
30
+ import flyte.config
31
+
32
+ obj = CLIConfig(flyte.config.auto(), ctx)
33
+
34
+ obj.init(project, domain)
35
+ return obj
36
+
37
+
38
+ @lru_cache()
39
+ def _list_tasks(
40
+ ctx: Context,
41
+ project: str,
42
+ domain: str,
43
+ by_task_name: str | None = None,
44
+ by_task_env: str | None = None,
45
+ ) -> list[str]:
46
+ import flyte.remote
47
+
48
+ _initialize_config(ctx, project, domain)
49
+ return [task.name for task in flyte.remote.Task.listall(by_task_name=by_task_name, by_task_env=by_task_env)]
50
+
22
51
 
23
52
  @dataclass
24
53
  class RunArguments:
@@ -60,7 +89,7 @@ class RunArguments:
60
89
  },
61
90
  )
62
91
  follow: bool = field(
63
- default=True,
92
+ default=False,
64
93
  metadata={
65
94
  "click.option": click.Option(
66
95
  ["--follow", "-f"],
@@ -93,22 +122,16 @@ class RunTaskCommand(click.Command):
93
122
  super().__init__(obj_name, *args, **kwargs)
94
123
 
95
124
  def invoke(self, ctx: Context):
96
- obj: CLIConfig = ctx.obj
97
- if obj is None:
98
- import flyte.config
99
-
100
- obj = CLIConfig(flyte.config.auto(), ctx)
101
-
102
- obj.init(self.run_args.project, self.run_args.domain)
125
+ obj: CLIConfig = _initialize_config(ctx, self.run_args.project, self.run_args.domain)
103
126
 
104
127
  async def _run():
105
128
  import flyte
106
129
 
107
- r = flyte.with_runcontext(
130
+ r = await flyte.with_runcontext(
108
131
  copy_style=self.run_args.copy_style,
109
132
  mode="local" if self.run_args.local else "remote",
110
133
  name=self.run_args.name,
111
- ).run(self.obj, **ctx.params)
134
+ ).run.aio(self.obj, **ctx.params)
112
135
  if isinstance(r, Run) and r.action is not None:
113
136
  console = Console()
114
137
  console.print(
@@ -156,8 +179,9 @@ class TaskPerFileGroup(common.ObjectsPerFileGroup):
156
179
  """
157
180
 
158
181
  def __init__(self, filename: Path, run_args: RunArguments, *args, **kwargs):
159
- args = (filename, *args)
160
- super().__init__(*args, **kwargs)
182
+ if filename.is_absolute():
183
+ filename = filename.relative_to(Path.cwd())
184
+ super().__init__(*(filename, *args), **kwargs)
161
185
  self.run_args = run_args
162
186
 
163
187
  def _filter_objects(self, module: ModuleType) -> Dict[str, Any]:
@@ -173,6 +197,173 @@ class TaskPerFileGroup(common.ObjectsPerFileGroup):
173
197
  )
174
198
 
175
199
 
200
+ class RunReferenceTaskCommand(click.Command):
201
+ def __init__(self, task_name: str, run_args: RunArguments, version: str | None, *args, **kwargs):
202
+ self.task_name = task_name
203
+ self.run_args = run_args
204
+ self.version = version
205
+
206
+ super().__init__(*args, **kwargs)
207
+
208
+ def invoke(self, ctx: click.Context):
209
+ obj: CLIConfig = _initialize_config(ctx, self.run_args.project, self.run_args.domain)
210
+
211
+ async def _run():
212
+ import flyte
213
+ import flyte.remote
214
+
215
+ task = flyte.remote.Task.get(self.task_name, version=self.version, auto_version="latest")
216
+
217
+ r = await flyte.with_runcontext(
218
+ copy_style=self.run_args.copy_style,
219
+ mode="local" if self.run_args.local else "remote",
220
+ name=self.run_args.name,
221
+ ).run.aio(task, **ctx.params)
222
+ if isinstance(r, Run) and r.action is not None:
223
+ console = Console()
224
+ console.print(
225
+ common.get_panel(
226
+ "Run",
227
+ f"[green bold]Created Run: {r.name} [/green bold] "
228
+ f"(Project: {r.action.action_id.run.project}, Domain: {r.action.action_id.run.domain})\n"
229
+ f"➡️ [blue bold][link={r.url}]{r.url}[/link][/blue bold]",
230
+ obj.output_format,
231
+ )
232
+ )
233
+ if self.run_args.follow:
234
+ console.print(
235
+ "[dim]Log streaming enabled, will wait for task to start running "
236
+ "and log stream to be available[/dim]"
237
+ )
238
+ await r.show_logs.aio(max_lines=30, show_ts=True, raw=False)
239
+
240
+ asyncio.run(_run())
241
+
242
+ def get_params(self, ctx: Context) -> List[Parameter]:
243
+ # Note this function may be called multiple times by click.
244
+ import flyte.remote
245
+ from flyte._internal.runtime.types_serde import transform_native_to_typed_interface
246
+
247
+ _initialize_config(ctx, self.run_args.project, self.run_args.domain)
248
+
249
+ task = flyte.remote.Task.get(self.task_name, auto_version="latest")
250
+ task_details = task.fetch()
251
+
252
+ interface = transform_native_to_typed_interface(task_details.interface)
253
+ if interface is None:
254
+ return super().get_params(ctx)
255
+ inputs_interface = task_details.interface.inputs
256
+
257
+ params: List[Parameter] = []
258
+ for name, var in interface.inputs.variables.items():
259
+ default_val = None
260
+ if inputs_interface[name][1] is not inspect._empty:
261
+ default_val = inputs_interface[name][1]
262
+ params.append(to_click_option(name, var, inputs_interface[name][0], default_val))
263
+
264
+ self.params = params
265
+ return super().get_params(ctx)
266
+
267
+
268
+ class ReferenceEnvGroup(common.GroupBase):
269
+ def __init__(self, name: str, *args, run_args, env: str, **kwargs):
270
+ super().__init__(*args, **kwargs)
271
+ self.name = name
272
+ self.env = env
273
+ self.run_args = run_args
274
+
275
+ def list_commands(self, ctx):
276
+ return _list_tasks(ctx, self.run_args.project, self.run_args.domain, by_task_env=self.env)
277
+
278
+ def get_command(self, ctx, name):
279
+ return RunReferenceTaskCommand(
280
+ task_name=name,
281
+ run_args=self.run_args,
282
+ name=name,
283
+ version=None,
284
+ help=f"Run deployed task '{name}' from the Flyte backend",
285
+ )
286
+
287
+
288
+ class ReferenceTaskGroup(common.GroupBase):
289
+ """
290
+ Group that creates a command for each reference task in the current directory that is not __init__.py.
291
+ """
292
+
293
+ def __init__(self, name: str, *args, run_args, tasks: list[str] | None = None, **kwargs):
294
+ super().__init__(*args, **kwargs)
295
+ self.name = name
296
+ self.run_args = run_args
297
+
298
+ def list_commands(self, ctx):
299
+ # list envs of all reference tasks
300
+ envs = []
301
+ for task in _list_tasks(ctx, self.run_args.project, self.run_args.domain):
302
+ env = task.split(".")[0]
303
+ if env not in envs:
304
+ envs.append(env)
305
+ return envs
306
+
307
+ @staticmethod
308
+ def _parse_task_name(task_name: str) -> tuple[str, str | None, str | None]:
309
+ import re
310
+
311
+ pattern = r"^([^.:]+)(?:\.([^:]+))?(?::(.+))?$"
312
+ match = re.match(pattern, task_name)
313
+ if not match:
314
+ raise click.BadParameter(f"Invalid task name format: {task_name}")
315
+ return match.group(1), match.group(2), match.group(3)
316
+
317
+ def _env_is_task(self, ctx: click.Context, env: str) -> bool:
318
+ # check if the env name is the full task name, since sometimes task
319
+ # names don't have an environment prefix
320
+ tasks = [*_list_tasks(ctx, self.run_args.project, self.run_args.domain, by_task_name=env)]
321
+ return len(tasks) > 0
322
+
323
+ def get_command(self, ctx, name):
324
+ env, task, version = self._parse_task_name(name)
325
+
326
+ match env, task, version:
327
+ case env, None, None:
328
+ if self._env_is_task(ctx, env):
329
+ # this handles cases where task names do not have a environment prefix
330
+ task_name = env
331
+ return RunReferenceTaskCommand(
332
+ task_name=task_name,
333
+ run_args=self.run_args,
334
+ name=task_name,
335
+ version=None,
336
+ help=f"Run reference task `{task_name}` from the Flyte backend",
337
+ )
338
+ else:
339
+ return ReferenceEnvGroup(
340
+ name=name,
341
+ run_args=self.run_args,
342
+ env=env,
343
+ help=f"Run reference tasks in the `{env}` environment from the Flyte backend",
344
+ )
345
+ case env, task, None:
346
+ task_name = f"{env}.{task}"
347
+ return RunReferenceTaskCommand(
348
+ task_name=task_name,
349
+ run_args=self.run_args,
350
+ name=task_name,
351
+ version=None,
352
+ help=f"Run reference task '{task_name}' from the Flyte backend",
353
+ )
354
+ case env, task, version:
355
+ task_name = f"{env}.{task}"
356
+ return RunReferenceTaskCommand(
357
+ task_name=task_name,
358
+ run_args=self.run_args,
359
+ version=version,
360
+ name=f"{task_name}:{version}",
361
+ help=f"Run reference task '{task_name}' from the Flyte backend",
362
+ )
363
+ case _:
364
+ raise click.BadParameter(f"Invalid task name format: {task_name}")
365
+
366
+
176
367
  class TaskFiles(common.FileGroup):
177
368
  """
178
369
  Group that creates a command for each file in the current directory that is not __init__.py.
@@ -191,25 +382,44 @@ class TaskFiles(common.FileGroup):
191
382
  kwargs["params"].extend(RunArguments.options())
192
383
  super().__init__(*args, directory=directory, **kwargs)
193
384
 
194
- def get_command(self, ctx, filename):
385
+ def list_commands(self, ctx):
386
+ return [
387
+ RUN_REMOTE_CMD,
388
+ *self.files,
389
+ ]
390
+
391
+ def get_command(self, ctx, cmd_name):
195
392
  run_args = RunArguments.from_dict(ctx.params)
196
- fp = Path(filename)
393
+
394
+ if cmd_name == RUN_REMOTE_CMD:
395
+ return ReferenceTaskGroup(
396
+ name=cmd_name,
397
+ run_args=run_args,
398
+ help="Run reference task from the Flyte backend",
399
+ )
400
+
401
+ fp = Path(cmd_name)
197
402
  if not fp.exists():
198
- raise click.BadParameter(f"File {filename} does not exist")
403
+ raise click.BadParameter(f"File {cmd_name} does not exist")
199
404
  if fp.is_dir():
200
- return TaskFiles(directory=fp)
405
+ return TaskFiles(
406
+ directory=fp,
407
+ help=f"Run `*.py` file inside the {fp} directory",
408
+ )
201
409
  return TaskPerFileGroup(
202
410
  filename=fp,
203
411
  run_args=run_args,
204
- name=filename,
205
- help=f"Run, functions decorated with `env.task` in {filename}",
412
+ name=cmd_name,
413
+ help=f"Run functions decorated with `env.task` in {cmd_name}",
206
414
  )
207
415
 
208
416
 
209
417
  run = TaskFiles(
210
418
  name="run",
211
- help="""
212
- Run a task from a python file.
419
+ help=f"""
420
+ Run a task from a python file or deployed task.
421
+
422
+ To run a remote task that already exists in Flyte, use the {RUN_REMOTE_CMD} command:
213
423
 
214
424
  Example usage:
215
425
 
@@ -227,6 +437,30 @@ Flyte environment:
227
437
  flyte run --local hello.py my_task --arg1 value1 --arg2 value2
228
438
  ```
229
439
 
440
+ To run tasks that you've already deployed to Flyte, use the {RUN_REMOTE_CMD} command:
441
+
442
+ ```bash
443
+ flyte run {RUN_REMOTE_CMD} my_env.my_task --arg1 value1 --arg2 value2
444
+ ```
445
+
446
+ To run a specific version of a deployed task, use the `env.task:version` syntax:
447
+
448
+ ```bash
449
+ flyte run {RUN_REMOTE_CMD} my_env.my_task:xyz123 --arg1 value1 --arg2 value2
450
+ ```
451
+
452
+ You can specify the `--config` flag to point to a specific Flyte cluster:
453
+
454
+ ```bash
455
+ flyte run --config my-config.yaml {RUN_REMOTE_CMD} ...
456
+ ```
457
+
458
+ You can discover what deployed tasks are available by running:
459
+
460
+ ```bash
461
+ flyte run {RUN_REMOTE_CMD}
462
+ ```
463
+
230
464
  Other arguments to the run command are listed below.
231
465
 
232
466
  Arguments for the task itself are provided after the task name and can be retrieved using `--help`. For example:
flyte/models.py CHANGED
@@ -191,6 +191,13 @@ class TaskContext:
191
191
  def __getitem__(self, key: str) -> Optional[Any]:
192
192
  return self.data.get(key)
193
193
 
194
+ def is_in_cluster(self):
195
+ """
196
+ Check if the task is running in a cluster.
197
+ :return: bool
198
+ """
199
+ return self.mode == "remote"
200
+
194
201
 
195
202
  @rich.repr.auto
196
203
  @dataclass(frozen=True, kw_only=True)
@@ -34,6 +34,7 @@ class Authenticator(object):
34
34
  http_proxy_url: typing.Optional[str] = None,
35
35
  verify: bool = True,
36
36
  ca_cert_path: typing.Optional[str] = None,
37
+ default_header_key: str = "authorization",
37
38
  **kwargs,
38
39
  ):
39
40
  """
@@ -80,6 +81,7 @@ class Authenticator(object):
80
81
  self._http_session = http_session or get_async_session(**kwargs)
81
82
  # Id for tracking credential refresh state
82
83
  self._creds_id = self._creds.id if self._creds else None
84
+ self._default_header_key = default_header_key
83
85
 
84
86
  async def _resolve_config(self) -> ClientConfig:
85
87
  """
@@ -131,10 +133,14 @@ class Authenticator(object):
131
133
  """
132
134
  creds = self.get_credentials()
133
135
  if creds:
134
- cfg = await self._resolve_config()
136
+ header_key = self._default_header_key
137
+ if self._resolved_config is not None:
138
+ # We only resolve the config during authentication flow, to avoid unnecessary network calls
139
+ # and usually the header_key is consistent.
140
+ header_key = self._resolved_config.header_key
135
141
  return GrpcAuthMetadata(
136
142
  creds_id=creds.id,
137
- pairs=Metadata((cfg.header_key, f"Bearer {creds.access_token}")),
143
+ pairs=Metadata((header_key, f"Bearer {creds.access_token}")),
138
144
  )
139
145
  return None
140
146
 
@@ -1,4 +1,3 @@
1
- import os
2
1
  import ssl
3
2
  import typing
4
3
 
@@ -16,11 +15,6 @@ from ._authenticators.factory import (
16
15
  get_async_proxy_authenticator,
17
16
  )
18
17
 
19
- # Set environment variables for gRPC, this reduces log spew and avoids unnecessary warnings
20
- if "GRPC_VERBOSITY" not in os.environ:
21
- os.environ["GRPC_VERBOSITY"] = "ERROR"
22
- os.environ["GRPC_CPP_MIN_LOG_LEVEL"] = "ERROR"
23
-
24
18
  # Initialize gRPC AIO early enough so it can be used in the main thread
25
19
  init_grpc_aio()
26
20
 
@@ -1,3 +1,4 @@
1
+ import asyncio
1
2
  import typing
2
3
  from abc import abstractmethod
3
4
 
@@ -69,8 +70,9 @@ class RemoteClientConfigStore(ClientConfigStore):
69
70
  Retrieves the ClientConfig from the given grpc.Channel assuming AuthMetadataService is available
70
71
  """
71
72
  metadata_service = AuthMetadataServiceStub(self._unauthenticated_channel)
72
- public_client_config = await metadata_service.GetPublicClientConfig(PublicClientAuthConfigRequest())
73
- oauth2_metadata = await metadata_service.GetOAuth2Metadata(OAuth2MetadataRequest())
73
+ oauth2_metadata_task = metadata_service.GetOAuth2Metadata(OAuth2MetadataRequest())
74
+ public_client_config_task = metadata_service.GetPublicClientConfig(PublicClientAuthConfigRequest())
75
+ oauth2_metadata, public_client_config = await asyncio.gather(oauth2_metadata_task, public_client_config_task)
74
76
  return ClientConfig(
75
77
  token_endpoint=oauth2_metadata.token_endpoint,
76
78
  authorization_endpoint=oauth2_metadata.authorization_endpoint,
@@ -1,5 +1,19 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import os
4
+
5
+ # Set environment variables for gRPC, this reduces log spew and avoids unnecessary warnings
6
+ # before importing grpc
7
+ if "GRPC_VERBOSITY" not in os.environ:
8
+ os.environ["GRPC_VERBOSITY"] = "ERROR"
9
+ os.environ["GRPC_CPP_MIN_LOG_LEVEL"] = "ERROR"
10
+ # Disable fork support (stops "skipping fork() handlers")
11
+ os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "0"
12
+ # Reduce absl/glog verbosity
13
+ os.environ["GLOG_minloglevel"] = "2"
14
+ os.environ["ABSL_LOG"] = "0"
15
+ #### Has to be before grpc
16
+
3
17
  import grpc
4
18
  from flyteidl.service import admin_pb2_grpc, dataproxy_pb2_grpc
5
19