flyte 0.2.0b2__py3-none-any.whl → 0.2.0b4__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import asyncio
3
4
  import inspect
4
5
  from dataclasses import dataclass, field, fields
5
6
  from pathlib import Path
@@ -50,6 +51,28 @@ class RunArguments:
50
51
  )
51
52
  },
52
53
  )
54
+ name: str | None = field(
55
+ default=None,
56
+ metadata={
57
+ "click.option": click.Option(
58
+ ["--name"],
59
+ type=str,
60
+ help="Name of the run. If not provided, a random name will be generated.",
61
+ )
62
+ },
63
+ )
64
+ follow: bool = field(
65
+ default=True,
66
+ metadata={
67
+ "click.option": click.Option(
68
+ ["--follow", "-f"],
69
+ is_flag=True,
70
+ default=False,
71
+ help="Wait and watch logs for the parent action. If not provided, the cli will exit after "
72
+ "successfully launching a remote execution with a link to the UI.",
73
+ )
74
+ },
75
+ )
53
76
 
54
77
  @classmethod
55
78
  def from_dict(cls, d: Dict[str, Any]) -> RunArguments:
@@ -76,21 +99,31 @@ class RunTaskCommand(click.Command):
76
99
  assert obj.endpoint, "CLI Config should have an endpoint"
77
100
  obj.init(self.run_args.project, self.run_args.domain)
78
101
 
79
- r = flyte.with_runcontext(
80
- copy_style=self.run_args.copy_style,
81
- version=self.run_args.copy_style,
82
- mode="local" if self.run_args.local else "remote",
83
- ).run(self.obj, **ctx.params)
84
- if isinstance(r, Run) and r.action is not None:
85
- console = Console()
86
- console.print(
87
- common.get_panel(
88
- "Run",
89
- f"[green bold]Created Run: {r.name} [/green bold] "
90
- f"(Project: {r.action.action_id.run.project}, Domain: {r.action.action_id.run.domain})\n\n"
91
- f"[blue bold]{r.url}[/blue bold]",
102
+ async def _run():
103
+ r = flyte.with_runcontext(
104
+ copy_style=self.run_args.copy_style,
105
+ version=self.run_args.copy_style,
106
+ mode="local" if self.run_args.local else "remote",
107
+ name=self.run_args.name,
108
+ ).run(self.obj, **ctx.params)
109
+ if isinstance(r, Run) and r.action is not None:
110
+ console = Console()
111
+ console.print(
112
+ common.get_panel(
113
+ "Run",
114
+ f"[green bold]Created Run: {r.name} [/green bold] "
115
+ f"(Project: {r.action.action_id.run.project}, Domain: {r.action.action_id.run.domain})\n\n"
116
+ f"[blue bold]{r.url}[/blue bold]",
117
+ )
92
118
  )
93
- )
119
+ if self.run_args.follow:
120
+ console.print(
121
+ "[dim]Log streaming enabled, will wait for task to start running "
122
+ "and log stream to be available[/dim]"
123
+ )
124
+ await r.show_logs(max_lines=30, show_ts=True, raw=False)
125
+
126
+ asyncio.run(_run())
94
127
 
95
128
  def get_params(self, ctx: Context) -> List[Parameter]:
96
129
  # Note this function may be called multiple times by click.
@@ -165,11 +198,11 @@ class TaskFiles(common.FileGroup):
165
198
  filename=Path(filename),
166
199
  run_args=run_args,
167
200
  name=filename,
168
- help=f"Run, functions decorated `env.task` or instances of Tasks in {filename}",
201
+ help=f"Run, functions decorated `env.task` {filename}",
169
202
  )
170
203
 
171
204
 
172
205
  run = TaskFiles(
173
206
  name="run",
174
- help="Run a task from a python file",
207
+ help="Run a task from a python file.",
175
208
  )
@@ -3,6 +3,7 @@ import rich_click as click
3
3
  from flyte._logging import initialize_logger, logger
4
4
 
5
5
  from ..config import Config
6
+ from ._abort import abort
6
7
  from ._common import CLIConfig
7
8
  from ._create import create
8
9
  from ._deploy import deploy
@@ -79,7 +80,14 @@ def main(
79
80
  config_file: str | None,
80
81
  ):
81
82
  """
82
- v2 cli. Root command, please use one of the subcommands.
83
+
84
+ ____ __ _ _ ____ ____ _ _ ____ __
85
+ ( __)( ) ( \\/ )(_ _)( __) / )( \\(___ \\ / \
86
+ ) _) / (_/\\ ) / )( ) _) \\ \\/ / / __/ _( 0 )
87
+ (__) \\____/(__/ (__) (____) \\__/ (____)(_)\\__/
88
+
89
+ The flyte cli follows a simple verb based structure, where the top-level commands are verbs that describe the action
90
+ to be taken, and the subcommands are nouns that describe the object of the action.
83
91
  """
84
92
  log_level = _verbosity_to_loglevel(verbose)
85
93
  if log_level is not None:
@@ -102,3 +110,4 @@ main.add_command(run)
102
110
  main.add_command(deploy)
103
111
  main.add_command(get) # type: ignore
104
112
  main.add_command(create) # type: ignore
113
+ main.add_command(abort) # type: ignore
flyte/config/_config.py CHANGED
@@ -183,6 +183,8 @@ def get_config_file(c: typing.Union[str, ConfigFile, None]) -> ConfigFile | None
183
183
  if isinstance(c, str):
184
184
  logger.debug(f"Using specified config file at {c}")
185
185
  return ConfigFile(c)
186
+ elif isinstance(c, ConfigFile):
187
+ return c
186
188
  config_path = resolve_config_path()
187
189
  if config_path:
188
190
  return ConfigFile(str(config_path))
flyte/errors.py CHANGED
@@ -141,3 +141,12 @@ class ReferenceTaskError(RuntimeUserError):
141
141
 
142
142
  def __init__(self, message: str):
143
143
  super().__init__("ReferenceTaskUsageError", message, "user")
144
+
145
+
146
+ class LogsNotYetAvailableError(BaseRuntimeError):
147
+ """
148
+ This error is raised when the logs are not yet available for a task.
149
+ """
150
+
151
+ def __init__(self, message: str):
152
+ super().__init__("LogsNotYetAvailable", "system", message, None)
@@ -116,6 +116,7 @@ async def create_channel(
116
116
  kwargs["auth_type"] = "ClientSecret"
117
117
  kwargs["client_id"] = client_id
118
118
  kwargs["client_secret"] = client_secret
119
+ kwargs["client_credentials_secret"] = client_secret
119
120
 
120
121
  assert endpoint, "Endpoint must be specified by this point"
121
122
 
flyte/remote/_logs.py CHANGED
@@ -3,6 +3,7 @@ from collections import deque
3
3
  from dataclasses import dataclass
4
4
  from typing import AsyncGenerator, AsyncIterator
5
5
 
6
+ import grpc
6
7
  from rich.console import Console
7
8
  from rich.live import Live
8
9
  from rich.panel import Panel
@@ -12,16 +13,23 @@ from flyte._api_commons import syncer
12
13
  from flyte._initialize import get_client, requires_client
13
14
  from flyte._protos.logs.dataplane import payload_pb2
14
15
  from flyte._protos.workflow import run_definition_pb2, run_logs_service_pb2
16
+ from flyte.errors import LogsNotYetAvailableError
15
17
 
18
+ style_map = {
19
+ payload_pb2.LogLineOriginator.SYSTEM: "bold magenta",
20
+ payload_pb2.LogLineOriginator.USER: "cyan",
21
+ payload_pb2.LogLineOriginator.UNKNOWN: "light red",
22
+ }
16
23
 
17
- def _format_line(logline: payload_pb2.LogLine, show_ts: bool) -> Text:
18
- style_map = {
19
- payload_pb2.LogLineOriginator.SYSTEM: "bold magenta",
20
- payload_pb2.LogLineOriginator.USER: "cyan",
21
- payload_pb2.LogLineOriginator.UNKNOWN: "light red",
22
- }
24
+
25
+ def _format_line(logline: payload_pb2.LogLine, show_ts: bool, filter_system: bool) -> Text | None:
26
+ if filter_system:
27
+ if logline.originator == payload_pb2.LogLineOriginator.SYSTEM:
28
+ return None
23
29
  style = style_map.get(logline.originator, "")
24
30
  if "flyte" in logline.message and "flyte.errors" not in logline.message:
31
+ if filter_system:
32
+ return None
25
33
  style = "dim"
26
34
  ts = ""
27
35
  if show_ts:
@@ -34,7 +42,14 @@ class AsyncLogViewer:
34
42
  A class to view logs asynchronously in the console or terminal or jupyter notebook.
35
43
  """
36
44
 
37
- def __init__(self, log_source: AsyncIterator, max_lines: int = 30, name: str = "Logs", show_ts: bool = False):
45
+ def __init__(
46
+ self,
47
+ log_source: AsyncIterator,
48
+ max_lines: int = 30,
49
+ name: str = "Logs",
50
+ show_ts: bool = False,
51
+ filter_system: bool = False,
52
+ ):
38
53
  self.console = Console()
39
54
  self.log_source = log_source
40
55
  self.max_lines = max_lines
@@ -42,23 +57,30 @@ class AsyncLogViewer:
42
57
  self.name = name
43
58
  self.show_ts = show_ts
44
59
  self.total_lines = 0
60
+ self.filter_flyte = filter_system
45
61
 
46
- def _render(self):
62
+ def _render(self) -> Panel:
47
63
  log_text = Text()
48
64
  for line in self.lines:
49
65
  log_text.append(line)
50
66
  return Panel(log_text, title=self.name, border_style="yellow")
51
67
 
52
68
  async def run(self):
53
- with Live(self._render(), refresh_per_second=10, console=self.console) as live:
69
+ with Live(self._render(), refresh_per_second=20, console=self.console) as live:
54
70
  try:
55
71
  async for logline in self.log_source:
56
- formatted = _format_line(logline, show_ts=self.show_ts)
57
- self.lines.append(formatted)
72
+ formatted = _format_line(logline, show_ts=self.show_ts, filter_system=self.filter_flyte)
73
+ if formatted:
74
+ self.lines.append(formatted)
58
75
  self.total_lines += 1
59
76
  live.update(self._render())
60
77
  except asyncio.CancelledError:
61
78
  pass
79
+ except KeyboardInterrupt:
80
+ pass
81
+ except LogsNotYetAvailableError as e:
82
+ self.console.print(f"[red]Error:[/red] {e}")
83
+ live.update("")
62
84
  self.console.print(f"Scrolled {self.total_lines} lines of logs.")
63
85
 
64
86
 
@@ -75,14 +97,24 @@ class Logs:
75
97
  :param action_id: The action ID to tail logs for.
76
98
  :param attempt: The attempt number (default is 0).
77
99
  """
78
- resp = get_client().logs_service.TailLogs(
79
- run_logs_service_pb2.TailLogsRequest(action_id=action_id, attempt=attempt)
80
- )
81
- async for log_set in resp:
82
- if log_set.logs:
83
- for log in log_set.logs:
84
- for line in log.lines:
85
- yield line
100
+ try:
101
+ resp = get_client().logs_service.TailLogs(
102
+ run_logs_service_pb2.TailLogsRequest(action_id=action_id, attempt=attempt)
103
+ )
104
+ async for log_set in resp:
105
+ if log_set.logs:
106
+ for log in log_set.logs:
107
+ for line in log.lines:
108
+ yield line
109
+ except asyncio.CancelledError:
110
+ pass
111
+ except KeyboardInterrupt:
112
+ pass
113
+ except grpc.aio.AioRpcError as e:
114
+ if e.code() == grpc.StatusCode.NOT_FOUND:
115
+ raise LogsNotYetAvailableError(
116
+ f"Log stream not available for action {action_id.name} in run {action_id.run.name}."
117
+ )
86
118
 
87
119
  @classmethod
88
120
  async def create_viewer(
@@ -92,6 +124,7 @@ class Logs:
92
124
  max_lines: int = 30,
93
125
  show_ts: bool = False,
94
126
  raw: bool = False,
127
+ filter_system: bool = False,
95
128
  ):
96
129
  """
97
130
  Create a log viewer for a given action ID and attempt.
@@ -101,16 +134,20 @@ class Logs:
101
134
  and keep only max_lines in view.
102
135
  :param show_ts: Whether to show timestamps in the logs.
103
136
  :param raw: if True, return the raw log lines instead of a viewer.
137
+ :param filter_system: Whether to filter log lines based on system logs.
104
138
  """
105
139
  if raw:
106
140
  console = Console()
107
141
  async for line in cls.tail.aio(cls, action_id=action_id, attempt=attempt):
108
- console.print(_format_line(line, show_ts=show_ts), end="")
142
+ line_text = _format_line(line, show_ts=show_ts, filter_system=filter_system)
143
+ if line_text:
144
+ console.print(line_text, end="")
109
145
  return
110
146
  viewer = AsyncLogViewer(
111
147
  log_source=cls.tail.aio(cls, action_id=action_id, attempt=attempt),
112
148
  max_lines=max_lines,
113
149
  show_ts=show_ts,
114
150
  name=f"{action_id.run.name}:{action_id.name} ({attempt})",
151
+ filter_system=filter_system,
115
152
  )
116
153
  await viewer.run()
flyte/remote/_run.py CHANGED
@@ -20,6 +20,8 @@ from .._protos.workflow.run_service_pb2 import WatchActionDetailsResponse
20
20
  from ._console import get_run_url
21
21
  from ._logs import Logs
22
22
 
23
+ WaitFor = Literal["terminal", "running"]
24
+
23
25
 
24
26
  def _action_time_phase(action: run_definition_pb2.Action | run_definition_pb2.ActionDetails) -> rich.repr.Result:
25
27
  """
@@ -182,71 +184,28 @@ class Run:
182
184
  return run_definition_pb2.Phase.Name(self.action.phase)
183
185
 
184
186
  @syncer.wrap
185
- async def wait(self, quiet: bool = False) -> None:
187
+ async def wait(self, quiet: bool = False, wait_for: Literal["terminal", "running"] = "terminal") -> None:
186
188
  """
187
189
  Wait for the run to complete, displaying a rich progress panel with status transitions,
188
190
  time elapsed, and error details in case of failure.
189
191
  """
190
- console = Console()
191
- if self.done():
192
- if not quiet:
193
- console.print(f"[bold green]Run '{self.name}' is already completed.[/bold green]")
194
- return
195
-
196
- try:
197
- with Progress(
198
- SpinnerColumn(),
199
- TextColumn("[progress.description]{task.description}"),
200
- TimeElapsedColumn(),
201
- console=console,
202
- transient=True,
203
- disable=quiet,
204
- ) as progress:
205
- task_id = progress.add_task(f"Waiting for run '{self.name}'...", start=False)
206
-
207
- async for ad in self.watch(cache_data_on_done=True):
208
- if ad is None:
209
- break
210
-
211
- # Update progress description with the current phase
212
- progress.update(
213
- task_id,
214
- description=f"Run: {self.name} in {ad.phase}, Runtime: {ad.runtime} secs "
215
- f"Attempts[{ad.attempts}]",
216
- )
217
- progress.start_task(task_id)
218
-
219
- # If the action is done, handle the final state
220
- if ad.done():
221
- progress.stop_task(task_id)
222
- if ad.pb2.status.phase == run_definition_pb2.PHASE_SUCCEEDED:
223
- console.print(f"[bold green]Run '{self.name}' completed successfully.[/bold green]")
224
- else:
225
- console.print(
226
- f"[bold red]Run '{self.name}' exited unsuccessfully in state {ad.phase}"
227
- f"with error: {ad.error_info}[/bold red]"
228
- )
229
- break
230
- except asyncio.CancelledError:
231
- # Handle cancellation gracefully
232
- pass
233
- except KeyboardInterrupt:
234
- # Handle keyboard interrupt gracefully
235
- console.print(f"\n[bold yellow]Run '{self.name}' was interrupted.[/bold yellow]")
192
+ return await self.action.wait(quiet=quiet, wait_for=wait_for)
236
193
 
237
194
  async def watch(self, cache_data_on_done: bool = False) -> AsyncGenerator[ActionDetails, None]:
238
195
  """
239
196
  Get the details of the run. This is a placeholder for getting the run details.
240
197
  """
241
- async for ad in self.action.watch_details(cache_data_on_done=cache_data_on_done):
242
- if ad is None:
243
- return
244
- yield ad
198
+ return self.action.watch(cache_data_on_done=cache_data_on_done)
245
199
 
246
200
  async def show_logs(
247
- self, attempt: int | None = None, max_lines: int = 100, show_ts: bool = False, raw: bool = False
201
+ self,
202
+ attempt: int | None = None,
203
+ max_lines: int = 100,
204
+ show_ts: bool = False,
205
+ raw: bool = False,
206
+ filter_system: bool = False,
248
207
  ):
249
- await self.action.show_logs(attempt, max_lines, show_ts, raw)
208
+ await self.action.show_logs(attempt, max_lines, show_ts, raw, filter_system=filter_system)
250
209
 
251
210
  async def details(self) -> RunDetails:
252
211
  """
@@ -272,15 +231,20 @@ class Run:
272
231
  )
273
232
 
274
233
  @syncer.wrap
275
- async def cancel(self) -> None:
234
+ async def abort(self):
276
235
  """
277
- Cancel the run.
236
+ Aborts / Terminates the run.
278
237
  """
279
- await get_client().run_service.AbortRun(
280
- run_service_pb2.AbortRunRequest(
281
- run_id=self.pb2.action.id.run,
238
+ try:
239
+ await get_client().run_service.AbortRun(
240
+ run_service_pb2.AbortRunRequest(
241
+ run_id=self.pb2.action.id.run,
242
+ )
282
243
  )
283
- )
244
+ except grpc.aio.AioRpcError as e:
245
+ if e.code() == grpc.StatusCode.NOT_FOUND:
246
+ return
247
+ raise
284
248
 
285
249
  def done(self) -> bool:
286
250
  """
@@ -543,19 +507,25 @@ class Action:
543
507
  return self.pb2.id
544
508
 
545
509
  async def show_logs(
546
- self, attempt: int | None = None, max_lines: int = 30, show_ts: bool = False, raw: bool = False
510
+ self,
511
+ attempt: int | None = None,
512
+ max_lines: int = 30,
513
+ show_ts: bool = False,
514
+ raw: bool = False,
515
+ filter_system: bool = False,
547
516
  ):
548
517
  details = await self.details()
549
518
  if not attempt:
550
519
  attempt = details.attempts
551
- if details.phase in [
552
- run_definition_pb2.PHASE_QUEUED,
553
- run_definition_pb2.PHASE_INITIALIZING,
554
- run_definition_pb2.PHASE_WAITING_FOR_RESOURCES,
555
- ]:
556
- raise RuntimeError("Action has not yet started, so logs are not available.")
520
+ if not details.is_running:
521
+ await self.wait(wait_for="running")
557
522
  return await Logs.create_viewer(
558
- action_id=self.action_id, attempt=attempt, max_lines=max_lines, show_ts=show_ts, raw=raw
523
+ action_id=self.action_id,
524
+ attempt=attempt,
525
+ max_lines=max_lines,
526
+ show_ts=show_ts,
527
+ raw=raw,
528
+ filter_system=filter_system,
559
529
  )
560
530
 
561
531
  async def details(self) -> ActionDetails:
@@ -566,7 +536,9 @@ class Action:
566
536
  self._details = await ActionDetails.get_details.aio(ActionDetails, self.action_id)
567
537
  return cast(ActionDetails, self._details)
568
538
 
569
- async def watch_details(self, cache_data_on_done: bool = False) -> AsyncGenerator[ActionDetails, None]:
539
+ async def watch(
540
+ self, cache_data_on_done: bool = False, wait_for: WaitFor = "terminal"
541
+ ) -> AsyncGenerator[ActionDetails, None]:
570
542
  """
571
543
  Watch the action for updates. This is a placeholder for watching the action.
572
544
  """
@@ -576,9 +548,80 @@ class Action:
576
548
  return
577
549
  self._details = ad
578
550
  yield ad
551
+ if wait_for == "running" and ad.phase == run_definition_pb2.PHASE_RUNNING:
552
+ break
553
+ elif wait_for == "terminal" and _action_done_check(ad.phase):
554
+ break
579
555
  if cache_data_on_done and ad and ad.done():
580
556
  await cast(ActionDetails, self._details).outputs()
581
557
 
558
+ async def wait(self, quiet: bool = False, wait_for: WaitFor = "terminal") -> None:
559
+ """
560
+ Wait for the run to complete, displaying a rich progress panel with status transitions,
561
+ time elapsed, and error details in case of failure.
562
+ """
563
+ console = Console()
564
+ if self.done():
565
+ if not quiet:
566
+ if self.pb2.status.phase == run_definition_pb2.PHASE_SUCCEEDED:
567
+ console.print(
568
+ f"[bold green]Action '{self.name}' in Run '{self.run_name}'"
569
+ f" completed successfully.[/bold green]"
570
+ )
571
+ else:
572
+ details = await self.details()
573
+ console.print(
574
+ f"[bold red]Action '{self.name}' in Run '{self.run_name}'"
575
+ f" exited unsuccessfully in state {self.phase} with error: {details.error_info}[/bold red]"
576
+ )
577
+ return
578
+
579
+ try:
580
+ with Progress(
581
+ SpinnerColumn(),
582
+ TextColumn("[progress.description]{task.description}"),
583
+ TimeElapsedColumn(),
584
+ console=console,
585
+ transient=True,
586
+ disable=quiet,
587
+ ) as progress:
588
+ task_id = progress.add_task(f"Waiting for run '{self.name}'...", start=False)
589
+ progress.start_task(task_id)
590
+
591
+ async for ad in self.watch(cache_data_on_done=True, wait_for=wait_for):
592
+ if ad is None:
593
+ progress.stop_task(task_id)
594
+ break
595
+
596
+ if ad.is_running and wait_for == "running":
597
+ progress.start_task(task_id)
598
+ break
599
+
600
+ # Update progress description with the current phase
601
+ progress.update(
602
+ task_id,
603
+ description=f"Run: {self.name} in {ad.phase}, Runtime: {ad.runtime} secs "
604
+ f"Attempts[{ad.attempts}]",
605
+ )
606
+
607
+ # If the action is done, handle the final state
608
+ if ad.done():
609
+ progress.stop_task(task_id)
610
+ if ad.pb2.status.phase == run_definition_pb2.PHASE_SUCCEEDED:
611
+ console.print(f"[bold green]Run '{self.name}' completed successfully.[/bold green]")
612
+ else:
613
+ console.print(
614
+ f"[bold red]Run '{self.name}' exited unsuccessfully in state {ad.phase}"
615
+ f"with error: {ad.error_info}[/bold red]"
616
+ )
617
+ break
618
+ except asyncio.CancelledError:
619
+ # Handle cancellation gracefully
620
+ pass
621
+ except KeyboardInterrupt:
622
+ # Handle keyboard interrupt gracefully
623
+ pass
624
+
582
625
  def done(self) -> bool:
583
626
  """
584
627
  Check if the action is done.
@@ -706,6 +749,13 @@ class ActionDetails:
706
749
  """
707
750
  return run_definition_pb2.Phase.Name(self.status.phase)
708
751
 
752
+ @property
753
+ def is_running(self) -> bool:
754
+ """
755
+ Check if the action is currently running.
756
+ """
757
+ return self.status.phase == run_definition_pb2.PHASE_RUNNING
758
+
709
759
  @property
710
760
  def name(self) -> str:
711
761
  """
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flyte
3
- Version: 0.2.0b2
3
+ Version: 0.2.0b4
4
4
  Summary: Add your description here
5
5
  Author-email: Ketan Umare <kumare3@users.noreply.github.com>
6
6
  Requires-Python: >=3.10
@@ -15,7 +15,7 @@ Requires-Dist: obstore>=0.6.0
15
15
  Requires-Dist: protobuf>=6.30.1
16
16
  Requires-Dist: pydantic>=2.10.6
17
17
  Requires-Dist: pyyaml>=6.0.2
18
- Requires-Dist: rich-click>=1.8.8
18
+ Requires-Dist: rich-click>=1.8.9
19
19
  Requires-Dist: httpx>=0.28.1
20
20
  Requires-Dist: keyring>=25.6.0
21
21
  Requires-Dist: synchronicity>=0.9.11