flyte 2.0.0b6__py3-none-any.whl → 2.0.0b8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

flyte/cli/_deploy.py CHANGED
@@ -109,8 +109,8 @@ class DeployEnvCommand(click.Command):
109
109
  version=self.deploy_args.version,
110
110
  )
111
111
 
112
- console.print(common.get_table("Environments", deployment[0].env_repr(), simple=obj.simple))
113
- console.print(common.get_table("Tasks", deployment[0].task_repr(), simple=obj.simple))
112
+ console.print(common.format("Environments", deployment[0].env_repr(), obj.output_format))
113
+ console.print(common.format("Tasks", deployment[0].task_repr(), obj.output_format))
114
114
 
115
115
 
116
116
  class DeployEnvRecursiveCommand(click.Command):
@@ -139,7 +139,7 @@ class DeployEnvRecursiveCommand(click.Command):
139
139
  if failed_paths:
140
140
  console.print(f"Loaded {len(loaded_modules)} modules with, but failed to load {len(failed_paths)} paths:")
141
141
  console.print(
142
- common.get_table("Modules", [[("Path", p), ("Err", e)] for p, e in failed_paths], simple=obj.simple)
142
+ common.format("Modules", [[("Path", p), ("Err", e)] for p, e in failed_paths], obj.output_format)
143
143
  )
144
144
  else:
145
145
  console.print(f"Loaded {len(loaded_modules)} modules")
@@ -149,9 +149,7 @@ class DeployEnvRecursiveCommand(click.Command):
149
149
  if not all_envs:
150
150
  console.print("No environments found to deploy")
151
151
  return
152
- console.print(
153
- common.get_table("Loaded Environments", [[("name", e.name)] for e in all_envs], simple=obj.simple)
154
- )
152
+ console.print(common.format("Loaded Environments", [[("name", e.name)] for e in all_envs], obj.output_format))
155
153
 
156
154
  if not self.deploy_args.ignore_load_errors and len(failed_paths) > 0:
157
155
  raise click.ClickException(
@@ -168,11 +166,9 @@ class DeployEnvRecursiveCommand(click.Command):
168
166
  )
169
167
 
170
168
  console.print(
171
- common.get_table("Environments", [env for d in deployments for env in d.env_repr()], simple=obj.simple)
172
- )
173
- console.print(
174
- common.get_table("Tasks", [task for d in deployments for task in d.task_repr()], simple=obj.simple)
169
+ common.format("Environments", [env for d in deployments for env in d.env_repr()], obj.output_format)
175
170
  )
171
+ console.print(common.format("Tasks", [task for d in deployments for task in d.task_repr()], obj.output_format))
176
172
 
177
173
 
178
174
  class EnvPerFileGroup(common.ObjectsPerFileGroup):
flyte/cli/_get.py CHANGED
@@ -48,13 +48,20 @@ def project(cfg: common.CLIConfig, name: str | None = None):
48
48
  if name:
49
49
  console.print(pretty_repr(Project.get(name)))
50
50
  else:
51
- console.print(common.get_table("Projects", Project.listall(), simple=cfg.simple))
51
+ console.print(common.format("Projects", Project.listall(), cfg.output_format))
52
52
 
53
53
 
54
54
  @get.command(cls=common.CommandBase)
55
55
  @click.argument("name", type=str, required=False)
56
+ @click.option("--limit", type=int, default=100, help="Limit the number of runs to fetch when listing.")
56
57
  @click.pass_obj
57
- def run(cfg: common.CLIConfig, name: str | None = None, project: str | None = None, domain: str | None = None):
58
+ def run(
59
+ cfg: common.CLIConfig,
60
+ name: str | None = None,
61
+ project: str | None = None,
62
+ domain: str | None = None,
63
+ limit: int = 100,
64
+ ):
58
65
  """
59
66
  Get a list of all runs, or details of a specific run by name.
60
67
 
@@ -71,13 +78,13 @@ def run(cfg: common.CLIConfig, name: str | None = None, project: str | None = No
71
78
  details = RunDetails.get(name=name)
72
79
  console.print(pretty_repr(details))
73
80
  else:
74
- console.print(common.get_table("Runs", Run.listall(), simple=cfg.simple))
81
+ console.print(common.format("Runs", Run.listall(limit=limit), cfg.output_format))
75
82
 
76
83
 
77
84
  @get.command(cls=common.CommandBase)
78
85
  @click.argument("name", type=str, required=False)
79
86
  @click.argument("version", type=str, required=False)
80
- @click.option("--limit", type=int, default=100, help="Limit the number of tasks to show.")
87
+ @click.option("--limit", type=int, default=100, help="Limit the number of tasks to fetch.")
81
88
  @click.pass_obj
82
89
  def task(
83
90
  cfg: common.CLIConfig,
@@ -105,9 +112,9 @@ def task(
105
112
  t = v.fetch()
106
113
  console.print(pretty_repr(t))
107
114
  else:
108
- console.print(common.get_table("Tasks", Task.listall(by_task_name=name, limit=limit), simple=cfg.simple))
115
+ console.print(common.format("Tasks", Task.listall(by_task_name=name, limit=limit), cfg.output_format))
109
116
  else:
110
- console.print(common.get_table("Tasks", Task.listall(limit=limit), simple=cfg.simple))
117
+ console.print(common.format("Tasks", Task.listall(limit=limit), cfg.output_format))
111
118
 
112
119
 
113
120
  @get.command(cls=common.CommandBase)
@@ -133,8 +140,8 @@ def action(
133
140
  else:
134
141
  # List all actions for the run
135
142
  console.print(
136
- common.get_table(
137
- f"Actions for {run_name}", flyte.remote._action.Action.listall(for_run_name=run_name), simple=cfg.simple
143
+ common.format(
144
+ f"Actions for {run_name}", flyte.remote._action.Action.listall(for_run_name=run_name), cfg.output_format
138
145
  )
139
146
  )
140
147
 
@@ -194,7 +201,7 @@ def logs(
194
201
 
195
202
  async def _run_log_view(_obj):
196
203
  task = asyncio.create_task(
197
- _obj.show_logs(
204
+ _obj.show_logs.aio(
198
205
  max_lines=lines, show_ts=show_ts, raw=not pretty, attempt=attempt, filter_system=filter_system
199
206
  )
200
207
  )
@@ -230,7 +237,7 @@ def secret(
230
237
  if name:
231
238
  console.print(pretty_repr(remote.Secret.get(name)))
232
239
  else:
233
- console.print(common.get_table("Secrets", remote.Secret.listall(), simple=cfg.simple))
240
+ console.print(common.format("Secrets", remote.Secret.listall(), cfg.output_format))
234
241
 
235
242
 
236
243
  @get.command(cls=common.CommandBase)
@@ -299,7 +306,7 @@ def io(
299
306
  common.get_panel(
300
307
  "Inputs & Outputs",
301
308
  f"[green bold]Inputs[/green bold]\n{inputs}\n\n[blue bold]Outputs[/blue bold]\n{outputs}",
302
- simple=cfg.simple,
309
+ cfg.output_format,
303
310
  )
304
311
  )
305
312
 
flyte/cli/_option.py ADDED
@@ -0,0 +1,33 @@
1
+ from click import Option, UsageError
2
+
3
+
4
+ class MutuallyExclusiveMixin:
5
+ def __init__(self, *args, **kwargs):
6
+ self.mutually_exclusive = set(kwargs.pop("mutually_exclusive", []))
7
+ self.error_format = kwargs.pop(
8
+ "error_msg", "Illegal usage: options '{name}' and '{invalid}' are mutually exclusive"
9
+ )
10
+ super().__init__(*args, **kwargs)
11
+
12
+ def handle_parse_result(self, ctx, opts, args):
13
+ self_present = self.name in opts and opts[self.name] is not None
14
+ others_intersect = self.mutually_exclusive.intersection(opts)
15
+ others_present = others_intersect and any(opts[value] is not None for value in others_intersect)
16
+
17
+ if others_present:
18
+ if self_present:
19
+ raise UsageError(self.error_format.format(name=self.name, invalid=", ".join(self.mutually_exclusive)))
20
+ else:
21
+ self.prompt = None
22
+
23
+ return super().handle_parse_result(ctx, opts, args)
24
+
25
+
26
+ # See https://stackoverflow.com/a/37491504/499285 and https://stackoverflow.com/a/44349292/499285
27
+ class MutuallyExclusiveOption(MutuallyExclusiveMixin, Option):
28
+ def __init__(self, *args, **kwargs):
29
+ mutually_exclusive = kwargs.get("mutually_exclusive", [])
30
+ help = kwargs.get("help", "")
31
+ if mutually_exclusive:
32
+ kwargs["help"] = help + f" Mutually exclusive with {', '.join(mutually_exclusive)}."
33
+ super().__init__(*args, **kwargs)
flyte/cli/_run.py CHANGED
@@ -116,8 +116,8 @@ class RunTaskCommand(click.Command):
116
116
  "Run",
117
117
  f"[green bold]Created Run: {r.name} [/green bold] "
118
118
  f"(Project: {r.action.action_id.run.project}, Domain: {r.action.action_id.run.domain})\n"
119
- f"➡️ [blue bold]{r.url}[/blue bold]",
120
- simple=obj.simple,
119
+ f"➡️ [blue bold][link={r.url}]{r.url}[/link][/blue bold]",
120
+ obj.output_format,
121
121
  )
122
122
  )
123
123
  if self.run_args.follow:
@@ -125,7 +125,7 @@ class RunTaskCommand(click.Command):
125
125
  "[dim]Log streaming enabled, will wait for task to start running "
126
126
  "and log stream to be available[/dim]"
127
127
  )
128
- await r.show_logs(max_lines=30, show_ts=True, raw=False)
128
+ await r.show_logs.aio(max_lines=30, show_ts=True, raw=False)
129
129
 
130
130
  asyncio.run(_run())
131
131
 
@@ -212,21 +212,27 @@ run = TaskFiles(
212
212
  Run a task from a python file.
213
213
 
214
214
  Example usage:
215
+
215
216
  ```bash
216
- flyte run --name examples/basics/hello.py my_task --arg1 value1 --arg2 value2
217
+ flyte run --project my-project --domain development hello.py my_task --arg1 value1 --arg2 value2
217
218
  ```
218
- Note: all arguments for the run command are provided right after the `run` command and before the file name.
219
219
 
220
- You can also specify the project and domain using the `--project` and `--domain` options, respectively. These
221
- options can be set in the config file or passed as command line arguments.
220
+ Arguments to the run command are provided right after the `run` command and before the file name.
221
+ For example, the command above specifies the project and domain.
222
+
223
+ To run a task locally, use the `--local` flag. This will run the task in the local environment instead of the remote
224
+ Flyte environment:
222
225
 
223
- Note: The arguments for the task are provided after the task name and can be retrieved using `--help`
224
- Example:
225
226
  ```bash
226
- flyte run --name examples/basics/hello.py my_task --help
227
+ flyte run --local hello.py my_task --arg1 value1 --arg2 value2
227
228
  ```
228
229
 
229
- To run a task locally, use the `--local` flag. This will run the task in the local environment instead of the remote
230
- Flyte environment.
230
+ Other arguments to the run command are listed below.
231
+
232
+ Arguments for the task itself are provided after the task name and can be retrieved using `--help`. For example:
233
+
234
+ ```bash
235
+ flyte run hello.py my_task --help
236
+ ```
231
237
  """,
232
238
  )
flyte/cli/main.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import rich_click as click
2
+ from typing_extensions import get_args
2
3
 
3
4
  from flyte._logging import initialize_logger, logger
4
5
 
@@ -107,10 +108,13 @@ def _verbosity_to_loglevel(verbosity: int) -> int | None:
107
108
  help="Path to the configuration file to use. If not specified, the default configuration file is used.",
108
109
  )
109
110
  @click.option(
110
- "--simple",
111
- is_flag=True,
112
- default=False,
113
- help="Use a simple output format for commands that support it. This is useful for copying, pasting, and scripting.",
111
+ "--output-format",
112
+ "-of",
113
+ type=click.Choice(get_args(common.OutputFormat), case_sensitive=False),
114
+ default="table",
115
+ help="Output format for commands that support it. Defaults to 'table'.",
116
+ show_default=True,
117
+ required=False,
114
118
  )
115
119
  @click.rich_config(help_config=help_config)
116
120
  @click.pass_context
@@ -121,8 +125,8 @@ def main(
121
125
  verbose: int,
122
126
  org: str | None,
123
127
  config_file: str | None,
124
- simple: bool = False,
125
128
  auth_type: str | None = None,
129
+ output_format: common.OutputFormat = "table",
126
130
  ):
127
131
  """
128
132
  The Flyte CLI is the command line interface for working with the Flyte SDK and backend.
@@ -176,8 +180,8 @@ def main(
176
180
  org=org,
177
181
  config=cfg,
178
182
  ctx=ctx,
179
- simple=simple,
180
183
  auth_type=auth_type,
184
+ output_format=output_format,
181
185
  )
182
186
 
183
187
 
flyte/errors.py CHANGED
@@ -10,6 +10,16 @@ from typing import Literal
10
10
  ErrorKind = Literal["system", "unknown", "user"]
11
11
 
12
12
 
13
+ def silence_grpc_polling_error(loop, context):
14
+ """
15
+ Suppress specific gRPC polling errors in the event loop.
16
+ """
17
+ exc = context.get("exception")
18
+ if isinstance(exc, BlockingIOError):
19
+ return # suppress
20
+ loop.default_exception_handler(context)
21
+
22
+
13
23
  class BaseRuntimeError(RuntimeError):
14
24
  """
15
25
  Base class for all Union runtime errors. These errors are raised when the underlying task execution fails, either
@@ -86,6 +96,9 @@ class TaskTimeoutError(RuntimeUserError):
86
96
  This error is raised when the underlying task execution runs for longer than the specified timeout.
87
97
  """
88
98
 
99
+ def __init__(self, message: str):
100
+ super().__init__("TaskTimeoutError", message, "user")
101
+
89
102
 
90
103
  class RetriesExhaustedError(RuntimeUserError):
91
104
  """
@@ -199,3 +212,12 @@ class InlineIOMaxBytesBreached(RuntimeUserError):
199
212
 
200
213
  def __init__(self, message: str):
201
214
  super().__init__("InlineIOMaxBytesBreached", message, "user")
215
+
216
+
217
+ class RunAbortedError(RuntimeUserError):
218
+ """
219
+ This error is raised when the run is aborted by the user.
220
+ """
221
+
222
+ def __init__(self, message: str):
223
+ super().__init__("RunAbortedError", message, "user")
flyte/remote/_action.py CHANGED
@@ -28,6 +28,7 @@ from flyte._initialize import ensure_client, get_client, get_common_config
28
28
  from flyte._protos.common import identifier_pb2, list_pb2
29
29
  from flyte._protos.workflow import run_definition_pb2, run_service_pb2
30
30
  from flyte._protos.workflow.run_service_pb2 import WatchActionDetailsResponse
31
+ from flyte.remote._common import ToJSONMixin
31
32
  from flyte.remote._logs import Logs
32
33
  from flyte.syncify import syncify
33
34
 
@@ -120,7 +121,7 @@ def _action_done_check(phase: run_definition_pb2.Phase) -> bool:
120
121
 
121
122
 
122
123
  @dataclass
123
- class Action:
124
+ class Action(ToJSONMixin):
124
125
  """
125
126
  A class representing an action. It is used to manage the run of a task and its state on the remote Union API.
126
127
  """
@@ -257,6 +258,7 @@ class Action:
257
258
  """
258
259
  return self.pb2.id
259
260
 
261
+ @syncify
260
262
  async def show_logs(
261
263
  self,
262
264
  attempt: int | None = None,
@@ -411,7 +413,7 @@ class Action:
411
413
 
412
414
 
413
415
  @dataclass
414
- class ActionDetails:
416
+ class ActionDetails(ToJSONMixin):
415
417
  """
416
418
  A class representing an action. It is used to manage the run of a task and its state on the remote Union API.
417
419
  """
@@ -692,7 +694,7 @@ class ActionDetails:
692
694
 
693
695
 
694
696
  @dataclass
695
- class ActionInputs(UserDict):
697
+ class ActionInputs(UserDict, ToJSONMixin):
696
698
  """
697
699
  A class representing the inputs of an action. It is used to manage the inputs of a task and its state on the
698
700
  remote Union API.
@@ -709,7 +711,7 @@ class ActionInputs(UserDict):
709
711
  return rich.pretty.pretty_repr(types.literal_string_repr(self.pb2))
710
712
 
711
713
 
712
- class ActionOutputs(tuple):
714
+ class ActionOutputs(tuple, ToJSONMixin):
713
715
  """
714
716
  A class representing the outputs of an action. It is used to manage the outputs of a task and its state on the
715
717
  remote Union API.
@@ -0,0 +1,30 @@
1
+ import json
2
+
3
+ from google.protobuf.json_format import MessageToDict, MessageToJson
4
+
5
+
6
+ class ToJSONMixin:
7
+ """
8
+ A mixin class that provides a method to convert an object to a JSON-serializable dictionary.
9
+ """
10
+
11
+ def to_dict(self) -> dict:
12
+ """
13
+ Convert the object to a JSON-serializable dictionary.
14
+
15
+ Returns:
16
+ dict: A dictionary representation of the object.
17
+ """
18
+ if hasattr(self, "pb2"):
19
+ return MessageToDict(self.pb2)
20
+ else:
21
+ return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
22
+
23
+ def to_json(self) -> str:
24
+ """
25
+ Convert the object to a JSON string.
26
+
27
+ Returns:
28
+ str: A JSON string representation of the object.
29
+ """
30
+ return MessageToJson(self.pb2) if hasattr(self, "pb2") else json.dumps(self.to_dict())
flyte/remote/_logs.py CHANGED
@@ -30,7 +30,7 @@ def _format_line(logline: payload_pb2.LogLine, show_ts: bool, filter_system: boo
30
30
  if logline.originator == payload_pb2.LogLineOriginator.SYSTEM:
31
31
  return None
32
32
  style = style_map.get(logline.originator, "")
33
- if "flyte" in logline.message and "flyte.errors" not in logline.message:
33
+ if "[flyte]" in logline.message and "flyte.errors" not in logline.message:
34
34
  if filter_system:
35
35
  return None
36
36
  style = "dim"
@@ -101,7 +101,7 @@ class Logs:
101
101
  cls,
102
102
  action_id: identifier_pb2.ActionIdentifier,
103
103
  attempt: int = 1,
104
- retry: int = 3,
104
+ retry: int = 5,
105
105
  ) -> AsyncGenerator[payload_pb2.LogLine, None]:
106
106
  """
107
107
  Tail the logs for a given action ID and attempt.
@@ -135,7 +135,7 @@ class Logs:
135
135
  f"Log stream not available for action {action_id.name} in run {action_id.run.name}."
136
136
  )
137
137
  else:
138
- await asyncio.sleep(1)
138
+ await asyncio.sleep(2)
139
139
 
140
140
  @classmethod
141
141
  async def create_viewer(
flyte/remote/_project.py CHANGED
@@ -9,15 +9,17 @@ from flyteidl.admin import common_pb2, project_pb2
9
9
  from flyte._initialize import ensure_client, get_client
10
10
  from flyte.syncify import syncify
11
11
 
12
+ from ._common import ToJSONMixin
13
+
12
14
 
13
15
  # TODO Add support for orgs again
14
16
  @dataclass
15
- class Project:
17
+ class Project(ToJSONMixin):
16
18
  """
17
19
  A class representing a project in the Union API.
18
20
  """
19
21
 
20
- _pb2: project_pb2.Project
22
+ pb2: project_pb2.Project
21
23
 
22
24
  @syncify
23
25
  @classmethod
@@ -76,11 +78,11 @@ class Project:
76
78
  break
77
79
 
78
80
  def __rich_repr__(self) -> rich.repr.Result:
79
- yield "name", self._pb2.name
80
- yield "id", self._pb2.id
81
- yield "description", self._pb2.description
82
- yield "state", project_pb2.Project.ProjectState.Name(self._pb2.state)
81
+ yield "name", self.pb2.name
82
+ yield "id", self.pb2.id
83
+ yield "description", self.pb2.description
84
+ yield "state", project_pb2.Project.ProjectState.Name(self.pb2.state)
83
85
  yield (
84
86
  "labels",
85
- ", ".join([f"{k}: {v}" for k, v in self._pb2.labels.values.items()]) if self._pb2.labels else None,
87
+ ", ".join([f"{k}: {v}" for k, v in self.pb2.labels.values.items()]) if self.pb2.labels else None,
86
88
  )
flyte/remote/_run.py CHANGED
@@ -13,11 +13,12 @@ from flyte.syncify import syncify
13
13
 
14
14
  from . import Action, ActionDetails, ActionInputs, ActionOutputs
15
15
  from ._action import _action_details_rich_repr, _action_rich_repr
16
+ from ._common import ToJSONMixin
16
17
  from ._console import get_run_url
17
18
 
18
19
 
19
20
  @dataclass
20
- class Run:
21
+ class Run(ToJSONMixin):
21
22
  """
22
23
  A class representing a run of a task. It is used to manage the run of a task and its state on the remote
23
24
  Union API.
@@ -41,12 +42,14 @@ class Run:
41
42
  cls,
42
43
  filters: str | None = None,
43
44
  sort_by: Tuple[str, Literal["asc", "desc"]] | None = None,
45
+ limit: int = 100,
44
46
  ) -> AsyncIterator[Run]:
45
47
  """
46
48
  Get all runs for the current project and domain.
47
49
 
48
50
  :param filters: The filters to apply to the project list.
49
51
  :param sort_by: The sorting criteria for the project list, in the format (field, order).
52
+ :param limit: The maximum number of runs to return.
50
53
  :return: An iterator of runs.
51
54
  """
52
55
  ensure_client()
@@ -57,9 +60,10 @@ class Run:
57
60
  direction=(list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING),
58
61
  )
59
62
  cfg = get_common_config()
63
+ i = 0
60
64
  while True:
61
65
  req = list_pb2.ListRequest(
62
- limit=100,
66
+ limit=min(100, limit),
63
67
  token=token,
64
68
  sort_by=sort_pb2,
65
69
  )
@@ -76,6 +80,9 @@ class Run:
76
80
  )
77
81
  token = resp.token
78
82
  for r in resp.runs:
83
+ i += 1
84
+ if i > limit:
85
+ return
79
86
  yield cls(r)
80
87
  if not token:
81
88
  break
@@ -134,6 +141,7 @@ class Run:
134
141
  """
135
142
  return self.action.watch(cache_data_on_done=cache_data_on_done)
136
143
 
144
+ @syncify
137
145
  async def show_logs(
138
146
  self,
139
147
  attempt: int | None = None,
@@ -142,7 +150,7 @@ class Run:
142
150
  raw: bool = False,
143
151
  filter_system: bool = False,
144
152
  ):
145
- await self.action.show_logs(attempt, max_lines, show_ts, raw, filter_system=filter_system)
153
+ await self.action.show_logs.aio(attempt, max_lines, show_ts, raw, filter_system=filter_system)
146
154
 
147
155
  @syncify
148
156
  async def details(self) -> RunDetails:
@@ -213,7 +221,7 @@ class Run:
213
221
 
214
222
 
215
223
  @dataclass
216
- class RunDetails:
224
+ class RunDetails(ToJSONMixin):
217
225
  """
218
226
  A class representing a run of a task. It is used to manage the run of a task and its state on the remote
219
227
  Union API.
flyte/remote/_secret.py CHANGED
@@ -7,13 +7,14 @@ import rich.repr
7
7
 
8
8
  from flyte._initialize import ensure_client, get_client, get_common_config
9
9
  from flyte._protos.secret import definition_pb2, payload_pb2
10
+ from flyte.remote._common import ToJSONMixin
10
11
  from flyte.syncify import syncify
11
12
 
12
13
  SecretTypes = Literal["regular", "image_pull"]
13
14
 
14
15
 
15
16
  @dataclass
16
- class Secret:
17
+ class Secret(ToJSONMixin):
17
18
  pb2: definition_pb2.Secret
18
19
 
19
20
  @syncify
flyte/remote/_task.py CHANGED
@@ -3,21 +3,27 @@ from __future__ import annotations
3
3
  import functools
4
4
  from dataclasses import dataclass
5
5
  from threading import Lock
6
- from typing import Any, AsyncIterator, Callable, Coroutine, Dict, Iterator, Literal, Optional, Tuple, Union
6
+ from typing import Any, AsyncIterator, Callable, Coroutine, Dict, Iterator, Literal, Optional, Tuple, Union, cast
7
7
 
8
8
  import rich.repr
9
+ from flyteidl.core import literals_pb2
9
10
  from google.protobuf import timestamp
10
11
 
11
12
  import flyte
12
13
  import flyte.errors
14
+ from flyte._cache.cache import CacheBehavior
13
15
  from flyte._context import internal_ctx
14
16
  from flyte._initialize import ensure_client, get_client, get_common_config
17
+ from flyte._internal.runtime.resources_serde import get_proto_resources
18
+ from flyte._internal.runtime.task_serde import get_proto_retry_strategy, get_proto_timeout, get_security_context
15
19
  from flyte._logging import logger
16
20
  from flyte._protos.common import identifier_pb2, list_pb2
17
21
  from flyte._protos.workflow import task_definition_pb2, task_service_pb2
18
22
  from flyte.models import NativeInterface
19
23
  from flyte.syncify import syncify
20
24
 
25
+ from ._common import ToJSONMixin
26
+
21
27
 
22
28
  def _repr_task_metadata(metadata: task_definition_pb2.TaskMetadata) -> rich.repr.Result:
23
29
  """
@@ -61,6 +67,15 @@ class LazyEntity:
61
67
  raise RuntimeError(f"Error downloading the task {self._name}, (check original exception...)")
62
68
  return self._task
63
69
 
70
+ @syncify
71
+ async def override(
72
+ self,
73
+ **kwargs: Any,
74
+ ) -> LazyEntity:
75
+ task_details = cast(TaskDetails, await self.fetch.aio())
76
+ task_details.override(**kwargs)
77
+ return self
78
+
64
79
  async def __call__(self, *args, **kwargs):
65
80
  """
66
81
  Forwards the call to the underlying task. The entity will be fetched if not already present
@@ -79,7 +94,7 @@ AutoVersioning = Literal["latest", "current"]
79
94
 
80
95
 
81
96
  @dataclass
82
- class TaskDetails:
97
+ class TaskDetails(ToJSONMixin):
83
98
  pb2: task_definition_pb2.TaskDetails
84
99
  max_inline_io_bytes: int = 10 * 1024 * 1024 # 10 MB
85
100
 
@@ -201,11 +216,20 @@ class TaskDetails:
201
216
  """
202
217
  The cache policy of the task.
203
218
  """
219
+ metadata = self.pb2.spec.task_template.metadata
220
+ behavior: CacheBehavior
221
+ if not metadata.discoverable:
222
+ behavior = "disable"
223
+ elif metadata.discovery_version:
224
+ behavior = "override"
225
+ else:
226
+ behavior = "auto"
227
+
204
228
  return flyte.Cache(
205
- behavior="enabled" if self.pb2.spec.task_template.metadata.discoverable else "disable",
206
- version_override=self.pb2.spec.task_template.metadata.discovery_version,
207
- serialize=self.pb2.spec.task_template.metadata.cache_serializable,
208
- ignored_inputs=tuple(self.pb2.spec.task_template.metadata.cache_ignore_input_vars),
229
+ behavior=behavior,
230
+ version_override=metadata.discovery_version if metadata.discovery_version else None,
231
+ serialize=metadata.cache_serializable,
232
+ ignored_inputs=tuple(metadata.cache_ignore_input_vars),
209
233
  )
210
234
 
211
235
  @property
@@ -259,19 +283,33 @@ class TaskDetails:
259
283
  def override(
260
284
  self,
261
285
  *,
262
- local: Optional[bool] = None,
263
- ref: Optional[bool] = None,
264
286
  resources: Optional[flyte.Resources] = None,
265
- cache: flyte.CacheRequest = "auto",
266
287
  retries: Union[int, flyte.RetryStrategy] = 0,
267
288
  timeout: Optional[flyte.TimeoutType] = None,
268
- reusable: Union[flyte.ReusePolicy, Literal["auto"], None] = None,
269
289
  env: Optional[Dict[str, str]] = None,
270
290
  secrets: Optional[flyte.SecretRequest] = None,
271
- max_inline_io_bytes: int | None = None,
272
291
  **kwargs: Any,
273
292
  ) -> TaskDetails:
274
- raise NotImplementedError
293
+ if len(kwargs) > 0:
294
+ raise ValueError(
295
+ f"ReferenceTasks [{self.name}] do not support overriding with kwargs: {kwargs}, "
296
+ f"Check the parameters for override method."
297
+ )
298
+ template = self.pb2.spec.task_template
299
+ if secrets:
300
+ template.security_context.CopyFrom(get_security_context(secrets))
301
+ if template.HasField("container"):
302
+ if env:
303
+ template.container.env.clear()
304
+ template.container.env.extend([literals_pb2.KeyValuePair(key=k, value=v) for k, v in env.items()])
305
+ if resources:
306
+ template.container.resources.CopyFrom(get_proto_resources(resources))
307
+ if retries:
308
+ template.metadata.retries.CopyFrom(get_proto_retry_strategy(retries))
309
+ if timeout:
310
+ template.metadata.timeout.CopyFrom(get_proto_timeout(timeout))
311
+
312
+ return self
275
313
 
276
314
  def __rich_repr__(self) -> rich.repr.Result:
277
315
  """
@@ -294,7 +332,7 @@ class TaskDetails:
294
332
 
295
333
 
296
334
  @dataclass
297
- class Task:
335
+ class Task(ToJSONMixin):
298
336
  pb2: task_definition_pb2.Task
299
337
 
300
338
  def __init__(self, pb2: task_definition_pb2.Task):