flyte 0.2.0b3__py3-none-any.whl → 0.2.0b5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. See the registry's page for this release for more details.

Files changed (50)
  1. flyte/_build.py +3 -2
  2. flyte/_deploy.py +4 -4
  3. flyte/_initialize.py +17 -3
  4. flyte/_internal/controllers/remote/_core.py +5 -4
  5. flyte/_internal/controllers/remote/_service_protocol.py +6 -6
  6. flyte/_protos/logs/dataplane/payload_pb2.py +28 -24
  7. flyte/_protos/logs/dataplane/payload_pb2.pyi +11 -2
  8. flyte/_protos/workflow/common_pb2.py +27 -0
  9. flyte/_protos/workflow/common_pb2.pyi +14 -0
  10. flyte/_protos/workflow/common_pb2_grpc.py +4 -0
  11. flyte/_protos/workflow/queue_service_pb2.py +39 -41
  12. flyte/_protos/workflow/queue_service_pb2.pyi +30 -28
  13. flyte/_protos/workflow/queue_service_pb2_grpc.py +15 -15
  14. flyte/_protos/workflow/run_definition_pb2.py +14 -14
  15. flyte/_protos/workflow/run_definition_pb2.pyi +4 -2
  16. flyte/_protos/workflow/task_definition_pb2.py +14 -13
  17. flyte/_protos/workflow/task_definition_pb2.pyi +7 -3
  18. flyte/_run.py +7 -5
  19. flyte/_trace.py +1 -6
  20. flyte/_version.py +2 -2
  21. flyte/cli/__init__.py +10 -0
  22. flyte/cli/_abort.py +26 -0
  23. flyte/{_cli → cli}/_common.py +2 -0
  24. flyte/{_cli → cli}/_create.py +1 -1
  25. flyte/{_cli → cli}/_delete.py +1 -1
  26. flyte/{_cli → cli}/_get.py +12 -3
  27. flyte/{_cli → cli}/_run.py +49 -16
  28. flyte/{_cli → cli}/main.py +10 -1
  29. flyte/config/_config.py +2 -0
  30. flyte/errors.py +9 -0
  31. flyte/io/_dir.py +2 -2
  32. flyte/io/_file.py +1 -4
  33. flyte/remote/_data.py +3 -3
  34. flyte/remote/_logs.py +80 -27
  35. flyte/remote/_project.py +8 -9
  36. flyte/remote/_run.py +194 -107
  37. flyte/remote/_secret.py +12 -12
  38. flyte/remote/_task.py +3 -3
  39. flyte/report/_report.py +4 -4
  40. flyte/syncify/__init__.py +5 -0
  41. flyte/syncify/_api.py +277 -0
  42. {flyte-0.2.0b3.dist-info → flyte-0.2.0b5.dist-info}/METADATA +2 -3
  43. {flyte-0.2.0b3.dist-info → flyte-0.2.0b5.dist-info}/RECORD +48 -43
  44. {flyte-0.2.0b3.dist-info → flyte-0.2.0b5.dist-info}/entry_points.txt +1 -1
  45. flyte/_api_commons.py +0 -3
  46. flyte/_cli/__init__.py +0 -0
  47. flyte/{_cli → cli}/_deploy.py +0 -0
  48. flyte/{_cli → cli}/_params.py +0 -0
  49. {flyte-0.2.0b3.dist-info → flyte-0.2.0b5.dist-info}/WHEEL +0 -0
  50. {flyte-0.2.0b3.dist-info → flyte-0.2.0b5.dist-info}/top_level.txt +0 -0
flyte/io/_file.py CHANGED
@@ -22,7 +22,6 @@ from fsspec.asyn import AsyncFileSystem
22
22
  from fsspec.utils import get_protocol
23
23
  from mashumaro.types import SerializableType
24
24
  from pydantic import BaseModel, model_validator
25
- from synchronicity import Synchronizer
26
25
 
27
26
  import flyte.storage as storage
28
27
  from flyte._context import internal_ctx
@@ -33,8 +32,6 @@ from flyte.types import TypeEngine, TypeTransformer, TypeTransformerFailedError
33
32
  # Type variable for the file format
34
33
  T = TypeVar("T")
35
34
 
36
- synced = Synchronizer()
37
-
38
35
 
39
36
  class File(BaseModel, Generic[T], SerializableType):
40
37
  """
@@ -314,7 +311,7 @@ class File(BaseModel, Generic[T], SerializableType):
314
311
  with fs.open(self.path, **open_kwargs) as f:
315
312
  yield f
316
313
 
317
- # @synced.wrap - enabling this did not work - synchronicity/pydantic issue
314
+ # TODO sync needs to be implemented
318
315
  async def download(self, local_path: Optional[Union[str, Path]] = None) -> str:
319
316
  """
320
317
  Asynchronously download the file to a local path.
flyte/remote/_data.py CHANGED
@@ -15,7 +15,7 @@ import httpx
15
15
  from flyteidl.service import dataproxy_pb2
16
16
  from google.protobuf import duration_pb2
17
17
 
18
- from flyte._initialize import CommonInit, get_client, get_common_config, requires_client
18
+ from flyte._initialize import CommonInit, ensure_client, get_client, get_common_config
19
19
  from flyte.errors import RuntimeSystemError
20
20
 
21
21
  _UPLOAD_EXPIRES_IN = timedelta(seconds=60)
@@ -109,7 +109,6 @@ async def _upload_single_file(
109
109
  return str_digest, resp.native_url
110
110
 
111
111
 
112
- @requires_client
113
112
  async def upload_file(fp: Path, verify: bool = True) -> Tuple[str, str]:
114
113
  """
115
114
  Uploads a file to a remote location and returns the remote URI.
@@ -119,13 +118,13 @@ async def upload_file(fp: Path, verify: bool = True) -> Tuple[str, str]:
119
118
  :return: A tuple containing the MD5 digest and the remote URI.
120
119
  """
121
120
  # This is a placeholder implementation. Replace with actual upload logic.
121
+ ensure_client()
122
122
  cfg = get_common_config()
123
123
  if not fp.is_file():
124
124
  raise ValueError(f"{fp} is not a single file, upload arg must be a single file.")
125
125
  return await _upload_single_file(cfg, fp, verify=verify)
126
126
 
127
127
 
128
- @requires_client
129
128
  async def upload_dir(dir_path: Path, verify: bool = True) -> str:
130
129
  """
131
130
  Uploads a directory to a remote location and returns the remote URI.
@@ -135,6 +134,7 @@ async def upload_dir(dir_path: Path, verify: bool = True) -> str:
135
134
  :return: The remote URI of the uploaded directory.
136
135
  """
137
136
  # This is a placeholder implementation. Replace with actual upload logic.
137
+ ensure_client()
138
138
  cfg = get_common_config()
139
139
  if not dir_path.is_dir():
140
140
  raise ValueError(f"{dir_path} is not a directory, upload arg must be a directory.")
flyte/remote/_logs.py CHANGED
@@ -3,25 +3,33 @@ from collections import deque
3
3
  from dataclasses import dataclass
4
4
  from typing import AsyncGenerator, AsyncIterator
5
5
 
6
+ import grpc
6
7
  from rich.console import Console
7
8
  from rich.live import Live
8
9
  from rich.panel import Panel
9
10
  from rich.text import Text
10
11
 
11
- from flyte._api_commons import syncer
12
- from flyte._initialize import get_client, requires_client
12
+ from flyte._initialize import ensure_client, get_client
13
13
  from flyte._protos.logs.dataplane import payload_pb2
14
14
  from flyte._protos.workflow import run_definition_pb2, run_logs_service_pb2
15
+ from flyte.errors import LogsNotYetAvailableError
16
+ from flyte.syncify import syncify
15
17
 
18
+ style_map = {
19
+ payload_pb2.LogLineOriginator.SYSTEM: "bold magenta",
20
+ payload_pb2.LogLineOriginator.USER: "cyan",
21
+ payload_pb2.LogLineOriginator.UNKNOWN: "light red",
22
+ }
16
23
 
17
- def _format_line(logline: payload_pb2.LogLine, show_ts: bool) -> Text:
18
- style_map = {
19
- payload_pb2.LogLineOriginator.SYSTEM: "bold magenta",
20
- payload_pb2.LogLineOriginator.USER: "cyan",
21
- payload_pb2.LogLineOriginator.UNKNOWN: "light red",
22
- }
24
+
25
+ def _format_line(logline: payload_pb2.LogLine, show_ts: bool, filter_system: bool) -> Text | None:
26
+ if filter_system:
27
+ if logline.originator == payload_pb2.LogLineOriginator.SYSTEM:
28
+ return None
23
29
  style = style_map.get(logline.originator, "")
24
30
  if "flyte" in logline.message and "flyte.errors" not in logline.message:
31
+ if filter_system:
32
+ return None
25
33
  style = "dim"
26
34
  ts = ""
27
35
  if show_ts:
@@ -34,7 +42,14 @@ class AsyncLogViewer:
34
42
  A class to view logs asynchronously in the console or terminal or jupyter notebook.
35
43
  """
36
44
 
37
- def __init__(self, log_source: AsyncIterator, max_lines: int = 30, name: str = "Logs", show_ts: bool = False):
45
+ def __init__(
46
+ self,
47
+ log_source: AsyncIterator,
48
+ max_lines: int = 30,
49
+ name: str = "Logs",
50
+ show_ts: bool = False,
51
+ filter_system: bool = False,
52
+ ):
38
53
  self.console = Console()
39
54
  self.log_source = log_source
40
55
  self.max_lines = max_lines
@@ -42,47 +57,78 @@ class AsyncLogViewer:
42
57
  self.name = name
43
58
  self.show_ts = show_ts
44
59
  self.total_lines = 0
60
+ self.filter_flyte = filter_system
45
61
 
46
- def _render(self):
62
+ def _render(self) -> Panel:
47
63
  log_text = Text()
48
64
  for line in self.lines:
49
65
  log_text.append(line)
50
66
  return Panel(log_text, title=self.name, border_style="yellow")
51
67
 
52
68
  async def run(self):
53
- with Live(self._render(), refresh_per_second=10, console=self.console) as live:
69
+ with Live(self._render(), refresh_per_second=20, console=self.console) as live:
54
70
  try:
55
71
  async for logline in self.log_source:
56
- formatted = _format_line(logline, show_ts=self.show_ts)
57
- self.lines.append(formatted)
72
+ formatted = _format_line(logline, show_ts=self.show_ts, filter_system=self.filter_flyte)
73
+ if formatted:
74
+ self.lines.append(formatted)
58
75
  self.total_lines += 1
59
76
  live.update(self._render())
60
77
  except asyncio.CancelledError:
61
78
  pass
79
+ except KeyboardInterrupt:
80
+ pass
81
+ except StopAsyncIteration:
82
+ self.console.print("[dim]Log stream ended.[/dim]")
83
+ except LogsNotYetAvailableError as e:
84
+ self.console.print(f"[red]Error:[/red] {e}")
85
+ live.update("")
62
86
  self.console.print(f"Scrolled {self.total_lines} lines of logs.")
63
87
 
64
88
 
65
89
  @dataclass
66
90
  class Logs:
91
+ @syncify
67
92
  @classmethod
68
- @requires_client
69
- @syncer.wrap
70
93
  async def tail(
71
- cls, action_id: run_definition_pb2.ActionIdentifier, attempt: int = 1
94
+ cls,
95
+ action_id: run_definition_pb2.ActionIdentifier,
96
+ attempt: int = 1,
97
+ retry: int = 3,
72
98
  ) -> AsyncGenerator[payload_pb2.LogLine, None]:
73
99
  """
74
100
  Tail the logs for a given action ID and attempt.
75
101
  :param action_id: The action ID to tail logs for.
76
102
  :param attempt: The attempt number (default is 0).
77
103
  """
78
- resp = get_client().logs_service.TailLogs(
79
- run_logs_service_pb2.TailLogsRequest(action_id=action_id, attempt=attempt)
80
- )
81
- async for log_set in resp:
82
- if log_set.logs:
83
- for log in log_set.logs:
84
- for line in log.lines:
85
- yield line
104
+ ensure_client()
105
+ retries = 0
106
+ while True:
107
+ try:
108
+ resp = get_client().logs_service.TailLogs(
109
+ run_logs_service_pb2.TailLogsRequest(action_id=action_id, attempt=attempt)
110
+ )
111
+ async for log_set in resp:
112
+ if log_set.logs:
113
+ for log in log_set.logs:
114
+ for line in log.lines:
115
+ yield line
116
+ return
117
+ except asyncio.CancelledError:
118
+ return
119
+ except KeyboardInterrupt:
120
+ return
121
+ except StopAsyncIteration:
122
+ return
123
+ except grpc.aio.AioRpcError as e:
124
+ retries += 1
125
+ if retries >= retry:
126
+ if e.code() == grpc.StatusCode.NOT_FOUND:
127
+ raise LogsNotYetAvailableError(
128
+ f"Log stream not available for action {action_id.name} in run {action_id.run.name}."
129
+ )
130
+ else:
131
+ await asyncio.sleep(1)
86
132
 
87
133
  @classmethod
88
134
  async def create_viewer(
@@ -92,6 +138,7 @@ class Logs:
92
138
  max_lines: int = 30,
93
139
  show_ts: bool = False,
94
140
  raw: bool = False,
141
+ filter_system: bool = False,
95
142
  ):
96
143
  """
97
144
  Create a log viewer for a given action ID and attempt.
@@ -101,16 +148,22 @@ class Logs:
101
148
  and keep only max_lines in view.
102
149
  :param show_ts: Whether to show timestamps in the logs.
103
150
  :param raw: if True, return the raw log lines instead of a viewer.
151
+ :param filter_system: Whether to filter log lines based on system logs.
104
152
  """
153
+ if attempt < 1:
154
+ raise ValueError("Attempt number must be greater than 0.")
105
155
  if raw:
106
156
  console = Console()
107
- async for line in cls.tail.aio(cls, action_id=action_id, attempt=attempt):
108
- console.print(_format_line(line, show_ts=show_ts), end="")
157
+ async for line in cls.tail.aio(action_id=action_id, attempt=attempt):
158
+ line_text = _format_line(line, show_ts=show_ts, filter_system=filter_system)
159
+ if line_text:
160
+ console.print(line_text, end="")
109
161
  return
110
162
  viewer = AsyncLogViewer(
111
- log_source=cls.tail.aio(cls, action_id=action_id, attempt=attempt),
163
+ log_source=cls.tail.aio(action_id=action_id, attempt=attempt),
112
164
  max_lines=max_lines,
113
165
  show_ts=show_ts,
114
166
  name=f"{action_id.run.name}:{action_id.name} ({attempt})",
167
+ filter_system=filter_system,
115
168
  )
116
169
  await viewer.run()
flyte/remote/_project.py CHANGED
@@ -1,14 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
- import typing
4
3
  from dataclasses import dataclass
5
- from typing import AsyncGenerator, Literal, Tuple
4
+ from typing import AsyncIterator, Iterator, Literal, Tuple, Union
6
5
 
7
6
  import rich.repr
8
7
  from flyteidl.admin import common_pb2, project_pb2
9
8
 
10
- from flyte._api_commons import syncer
11
- from flyte._initialize import get_client, get_common_config, requires_client
9
+ from flyte._initialize import ensure_client, get_client, get_common_config
10
+ from flyte.syncify import syncify
12
11
 
13
12
 
14
13
  @dataclass
@@ -19,9 +18,8 @@ class Project:
19
18
 
20
19
  _pb2: project_pb2.Project
21
20
 
21
+ @syncify
22
22
  @classmethod
23
- @requires_client
24
- @syncer.wrap
25
23
  async def get(cls, name: str, org: str | None = None) -> Project:
26
24
  """
27
25
  Get a run by its ID or name. If both are provided, the ID will take precedence.
@@ -29,6 +27,7 @@ class Project:
29
27
  :param name: The name of the project.
30
28
  :param org: The organization of the project (if applicable).
31
29
  """
30
+ ensure_client()
32
31
  service = get_client().project_domain_service # type: ignore
33
32
  resp = await service.GetProject(
34
33
  project_pb2.ProjectGetRequest(
@@ -38,14 +37,13 @@ class Project:
38
37
  )
39
38
  return cls(resp)
40
39
 
40
+ @syncify
41
41
  @classmethod
42
- @requires_client
43
- @syncer.wrap
44
42
  async def listall(
45
43
  cls,
46
44
  filters: str | None = None,
47
45
  sort_by: Tuple[str, Literal["asc", "desc"]] | None = None,
48
- ) -> typing.Union[typing.Iterator[Project], AsyncGenerator[Project, None]]:
46
+ ) -> Union[AsyncIterator[Project], Iterator[Project]]:
49
47
  """
50
48
  Get a run by its ID or name. If both are provided, the ID will take precedence.
51
49
 
@@ -53,6 +51,7 @@ class Project:
53
51
  :param sort_by: The sorting criteria for the project list, in the format (field, order).
54
52
  :return: An iterator of projects.
55
53
  """
54
+ ensure_client()
56
55
  token = None
57
56
  sort_by = sort_by or ("created_at", "asc")
58
57
  sort_pb2 = common_pb2.Sort(