flyte 2.0.0b13__py3-none-any.whl → 2.0.0b30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. flyte/__init__.py +18 -2
  2. flyte/_bin/debug.py +38 -0
  3. flyte/_bin/runtime.py +62 -8
  4. flyte/_cache/cache.py +4 -2
  5. flyte/_cache/local_cache.py +216 -0
  6. flyte/_code_bundle/_ignore.py +12 -4
  7. flyte/_code_bundle/_packaging.py +13 -9
  8. flyte/_code_bundle/_utils.py +18 -10
  9. flyte/_code_bundle/bundle.py +17 -9
  10. flyte/_constants.py +1 -0
  11. flyte/_context.py +4 -1
  12. flyte/_custom_context.py +73 -0
  13. flyte/_debug/constants.py +38 -0
  14. flyte/_debug/utils.py +17 -0
  15. flyte/_debug/vscode.py +307 -0
  16. flyte/_deploy.py +235 -61
  17. flyte/_environment.py +20 -6
  18. flyte/_excepthook.py +1 -1
  19. flyte/_hash.py +1 -16
  20. flyte/_image.py +178 -81
  21. flyte/_initialize.py +132 -51
  22. flyte/_interface.py +39 -2
  23. flyte/_internal/controllers/__init__.py +4 -5
  24. flyte/_internal/controllers/_local_controller.py +70 -29
  25. flyte/_internal/controllers/_trace.py +1 -1
  26. flyte/_internal/controllers/remote/__init__.py +0 -2
  27. flyte/_internal/controllers/remote/_action.py +14 -16
  28. flyte/_internal/controllers/remote/_client.py +1 -1
  29. flyte/_internal/controllers/remote/_controller.py +68 -70
  30. flyte/_internal/controllers/remote/_core.py +127 -99
  31. flyte/_internal/controllers/remote/_informer.py +19 -10
  32. flyte/_internal/controllers/remote/_service_protocol.py +7 -7
  33. flyte/_internal/imagebuild/docker_builder.py +181 -69
  34. flyte/_internal/imagebuild/image_builder.py +0 -5
  35. flyte/_internal/imagebuild/remote_builder.py +155 -64
  36. flyte/_internal/imagebuild/utils.py +51 -2
  37. flyte/_internal/resolvers/_task_module.py +5 -38
  38. flyte/_internal/resolvers/default.py +2 -2
  39. flyte/_internal/runtime/convert.py +110 -21
  40. flyte/_internal/runtime/entrypoints.py +27 -1
  41. flyte/_internal/runtime/io.py +21 -8
  42. flyte/_internal/runtime/resources_serde.py +20 -6
  43. flyte/_internal/runtime/reuse.py +1 -1
  44. flyte/_internal/runtime/rusty.py +20 -5
  45. flyte/_internal/runtime/task_serde.py +34 -19
  46. flyte/_internal/runtime/taskrunner.py +22 -4
  47. flyte/_internal/runtime/trigger_serde.py +160 -0
  48. flyte/_internal/runtime/types_serde.py +1 -1
  49. flyte/_keyring/__init__.py +0 -0
  50. flyte/_keyring/file.py +115 -0
  51. flyte/_logging.py +201 -39
  52. flyte/_map.py +111 -14
  53. flyte/_module.py +70 -0
  54. flyte/_pod.py +4 -3
  55. flyte/_resources.py +213 -31
  56. flyte/_run.py +110 -39
  57. flyte/_task.py +75 -16
  58. flyte/_task_environment.py +105 -29
  59. flyte/_task_plugins.py +4 -2
  60. flyte/_trace.py +5 -0
  61. flyte/_trigger.py +1000 -0
  62. flyte/_utils/__init__.py +2 -1
  63. flyte/_utils/asyn.py +3 -1
  64. flyte/_utils/coro_management.py +2 -1
  65. flyte/_utils/docker_credentials.py +173 -0
  66. flyte/_utils/module_loader.py +17 -2
  67. flyte/_version.py +3 -3
  68. flyte/cli/_abort.py +3 -3
  69. flyte/cli/_build.py +3 -6
  70. flyte/cli/_common.py +78 -7
  71. flyte/cli/_create.py +182 -4
  72. flyte/cli/_delete.py +23 -1
  73. flyte/cli/_deploy.py +63 -16
  74. flyte/cli/_get.py +79 -34
  75. flyte/cli/_params.py +26 -10
  76. flyte/cli/_plugins.py +209 -0
  77. flyte/cli/_run.py +151 -26
  78. flyte/cli/_serve.py +64 -0
  79. flyte/cli/_update.py +37 -0
  80. flyte/cli/_user.py +17 -0
  81. flyte/cli/main.py +30 -4
  82. flyte/config/_config.py +10 -6
  83. flyte/config/_internal.py +1 -0
  84. flyte/config/_reader.py +29 -8
  85. flyte/connectors/__init__.py +11 -0
  86. flyte/connectors/_connector.py +270 -0
  87. flyte/connectors/_server.py +197 -0
  88. flyte/connectors/utils.py +135 -0
  89. flyte/errors.py +22 -2
  90. flyte/extend.py +8 -1
  91. flyte/extras/_container.py +6 -1
  92. flyte/git/__init__.py +3 -0
  93. flyte/git/_config.py +21 -0
  94. flyte/io/__init__.py +2 -0
  95. flyte/io/_dataframe/__init__.py +2 -0
  96. flyte/io/_dataframe/basic_dfs.py +17 -8
  97. flyte/io/_dataframe/dataframe.py +98 -132
  98. flyte/io/_dir.py +575 -113
  99. flyte/io/_file.py +582 -139
  100. flyte/io/_hashing_io.py +342 -0
  101. flyte/models.py +74 -15
  102. flyte/remote/__init__.py +6 -1
  103. flyte/remote/_action.py +34 -26
  104. flyte/remote/_client/_protocols.py +39 -4
  105. flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
  106. flyte/remote/_client/auth/_authenticators/pkce.py +1 -1
  107. flyte/remote/_client/auth/_channel.py +10 -6
  108. flyte/remote/_client/controlplane.py +17 -5
  109. flyte/remote/_console.py +3 -2
  110. flyte/remote/_data.py +6 -6
  111. flyte/remote/_logs.py +3 -3
  112. flyte/remote/_run.py +64 -8
  113. flyte/remote/_secret.py +26 -17
  114. flyte/remote/_task.py +75 -33
  115. flyte/remote/_trigger.py +306 -0
  116. flyte/remote/_user.py +33 -0
  117. flyte/report/_report.py +1 -1
  118. flyte/storage/__init__.py +6 -1
  119. flyte/storage/_config.py +5 -1
  120. flyte/storage/_parallel_reader.py +274 -0
  121. flyte/storage/_storage.py +200 -103
  122. flyte/types/__init__.py +16 -0
  123. flyte/types/_interface.py +2 -2
  124. flyte/types/_pickle.py +35 -8
  125. flyte/types/_string_literals.py +8 -9
  126. flyte/types/_type_engine.py +40 -70
  127. flyte/types/_utils.py +1 -1
  128. flyte-2.0.0b30.data/scripts/debug.py +38 -0
  129. {flyte-2.0.0b13.data → flyte-2.0.0b30.data}/scripts/runtime.py +62 -8
  130. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/METADATA +11 -3
  131. flyte-2.0.0b30.dist-info/RECORD +192 -0
  132. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/entry_points.txt +3 -0
  133. flyte/_protos/common/authorization_pb2.py +0 -66
  134. flyte/_protos/common/authorization_pb2.pyi +0 -108
  135. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  136. flyte/_protos/common/identifier_pb2.py +0 -93
  137. flyte/_protos/common/identifier_pb2.pyi +0 -110
  138. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  139. flyte/_protos/common/identity_pb2.py +0 -48
  140. flyte/_protos/common/identity_pb2.pyi +0 -72
  141. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  142. flyte/_protos/common/list_pb2.py +0 -36
  143. flyte/_protos/common/list_pb2.pyi +0 -71
  144. flyte/_protos/common/list_pb2_grpc.py +0 -4
  145. flyte/_protos/common/policy_pb2.py +0 -37
  146. flyte/_protos/common/policy_pb2.pyi +0 -27
  147. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  148. flyte/_protos/common/role_pb2.py +0 -37
  149. flyte/_protos/common/role_pb2.pyi +0 -53
  150. flyte/_protos/common/role_pb2_grpc.py +0 -4
  151. flyte/_protos/common/runtime_version_pb2.py +0 -28
  152. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  153. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  154. flyte/_protos/imagebuilder/definition_pb2.py +0 -59
  155. flyte/_protos/imagebuilder/definition_pb2.pyi +0 -140
  156. flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
  157. flyte/_protos/imagebuilder/payload_pb2.py +0 -32
  158. flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
  159. flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
  160. flyte/_protos/imagebuilder/service_pb2.py +0 -29
  161. flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
  162. flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
  163. flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
  164. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
  165. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  166. flyte/_protos/secret/definition_pb2.py +0 -49
  167. flyte/_protos/secret/definition_pb2.pyi +0 -93
  168. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  169. flyte/_protos/secret/payload_pb2.py +0 -62
  170. flyte/_protos/secret/payload_pb2.pyi +0 -94
  171. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  172. flyte/_protos/secret/secret_pb2.py +0 -38
  173. flyte/_protos/secret/secret_pb2.pyi +0 -6
  174. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  175. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  176. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  177. flyte/_protos/workflow/common_pb2.py +0 -27
  178. flyte/_protos/workflow/common_pb2.pyi +0 -14
  179. flyte/_protos/workflow/common_pb2_grpc.py +0 -4
  180. flyte/_protos/workflow/environment_pb2.py +0 -29
  181. flyte/_protos/workflow/environment_pb2.pyi +0 -12
  182. flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
  183. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  184. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  185. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  186. flyte/_protos/workflow/queue_service_pb2.py +0 -109
  187. flyte/_protos/workflow/queue_service_pb2.pyi +0 -166
  188. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  189. flyte/_protos/workflow/run_definition_pb2.py +0 -121
  190. flyte/_protos/workflow/run_definition_pb2.pyi +0 -327
  191. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  192. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  193. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  194. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  195. flyte/_protos/workflow/run_service_pb2.py +0 -137
  196. flyte/_protos/workflow/run_service_pb2.pyi +0 -185
  197. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -446
  198. flyte/_protos/workflow/state_service_pb2.py +0 -67
  199. flyte/_protos/workflow/state_service_pb2.pyi +0 -76
  200. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  201. flyte/_protos/workflow/task_definition_pb2.py +0 -79
  202. flyte/_protos/workflow/task_definition_pb2.pyi +0 -81
  203. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  204. flyte/_protos/workflow/task_service_pb2.py +0 -60
  205. flyte/_protos/workflow/task_service_pb2.pyi +0 -59
  206. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
  207. flyte-2.0.0b13.dist-info/RECORD +0 -239
  208. /flyte/{_protos → _debug}/__init__.py +0 -0
  209. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/WHEEL +0 -0
  210. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/licenses/LICENSE +0 -0
  211. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/top_level.txt +0 -0
flyte/cli/_params.py CHANGED
@@ -15,9 +15,9 @@ from typing import get_args
15
15
  import rich_click as click
16
16
  import yaml
17
17
  from click import Parameter
18
- from flyteidl.core.interface_pb2 import Variable
19
- from flyteidl.core.literals_pb2 import Literal
20
- from flyteidl.core.types_pb2 import BlobType, LiteralType, SimpleType
18
+ from flyteidl2.core.interface_pb2 import Variable
19
+ from flyteidl2.core.literals_pb2 import Literal
20
+ from flyteidl2.core.types_pb2 import BlobType, LiteralType, SimpleType
21
21
  from google.protobuf.json_format import MessageToDict
22
22
  from mashumaro.codecs.json import JSONEncoder
23
23
 
@@ -283,13 +283,17 @@ class UnionParamType(click.ParamType):
283
283
  A composite type that allows for multiple types to be specified. This is used for union types.
284
284
  """
285
285
 
286
- def __init__(self, types: typing.List[click.ParamType]):
286
+ def __init__(self, types: typing.List[click.ParamType | None]):
287
287
  super().__init__()
288
288
  self._types = self._sort_precedence(types)
289
- self.name = "|".join([t.name for t in self._types])
289
+ self.name = "|".join([t.name for t in self._types if t is not None])
290
+ self.optional = False
291
+ if None in types:
292
+ self.name = f"Optional[{self.name}]"
293
+ self.optional = True
290
294
 
291
295
  @staticmethod
292
- def _sort_precedence(tp: typing.List[click.ParamType]) -> typing.List[click.ParamType]:
296
+ def _sort_precedence(tp: typing.List[click.ParamType | None]) -> typing.List[click.ParamType]:
293
297
  unprocessed = []
294
298
  str_types = []
295
299
  others = []
@@ -311,6 +315,8 @@ class UnionParamType(click.ParamType):
311
315
  """
312
316
  for p in self._types:
313
317
  try:
318
+ if p is None and value is None:
319
+ return None
314
320
  return p.convert(value, param, ctx)
315
321
  except Exception as e:
316
322
  logger.debug(f"Ignoring conversion error for type {p} trying other variants in Union. Error: {e}")
@@ -433,7 +439,10 @@ def literal_type_to_click_type(lt: LiteralType, python_type: typing.Type) -> cli
433
439
  for i in range(len(lt.union_type.variants)):
434
440
  variant = lt.union_type.variants[i]
435
441
  variant_python_type = typing.get_args(python_type)[i]
436
- cts.append(literal_type_to_click_type(variant, variant_python_type))
442
+ if variant_python_type is type(None):
443
+ cts.append(None)
444
+ else:
445
+ cts.append(literal_type_to_click_type(variant, variant_python_type))
437
446
  return UnionParamType(cts)
438
447
 
439
448
  if lt.HasField("enum_type"):
@@ -461,6 +470,9 @@ class FlyteLiteralConverter(object):
461
470
  def is_bool(self) -> bool:
462
471
  return self.click_type == click.BOOL
463
472
 
473
+ def is_optional(self) -> bool:
474
+ return isinstance(self.click_type, UnionParamType) and self.click_type.optional
475
+
464
476
  def convert(
465
477
  self, ctx: click.Context, param: typing.Optional[click.Parameter], value: typing.Any
466
478
  ) -> typing.Union[Literal, typing.Any]:
@@ -493,7 +505,7 @@ def to_click_option(
493
505
  This handles converting workflow input types to supported click parameters with callbacks to initialize
494
506
  the input values to their expected types.
495
507
  """
496
- from flyteidl.core.types_pb2 import SimpleType
508
+ from flyteidl2.core.types_pb2 import SimpleType
497
509
 
498
510
  if input_name != input_name.lower():
499
511
  # Click does not support uppercase option names: https://github.com/pallets/click/issues/837
@@ -519,15 +531,19 @@ def to_click_option(
519
531
  if literal_var.type.metadata:
520
532
  description_extra = f": {MessageToDict(literal_var.type.metadata)}"
521
533
 
522
- # If a query has been specified, the input is never strictly required at this layer
523
534
  required = False if default_val is not None else True
524
535
  is_flag: typing.Optional[bool] = None
536
+ param_decls = [f"--{input_name}"]
525
537
  if literal_converter.is_bool():
526
538
  required = False
527
539
  is_flag = True
540
+ if default_val is True:
541
+ param_decls = [f"--{input_name}/--no-{input_name}"]
542
+ if literal_converter.is_optional():
543
+ required = False
528
544
 
529
545
  return click.Option(
530
- param_decls=[f"--{input_name}"],
546
+ param_decls=param_decls,
531
547
  type=literal_converter.click_type,
532
548
  is_flag=is_flag,
533
549
  default=default_val,
flyte/cli/_plugins.py ADDED
@@ -0,0 +1,209 @@
1
+ """CLI Plugin System for Flyte.
2
+
3
+ This module provides a plugin system that allows external packages to:
4
+ 1. Register new top-level CLI commands (e.g., flyte my-command)
5
+ 2. Register new subcommands in existing groups (e.g., flyte get my-object)
6
+ 3. Modify behavior of existing commands via hooks
7
+
8
+ Plugins are discovered via Python entry points.
9
+
10
+ Entry Point Groups:
11
+ - flyte.plugins.cli.commands: Register new commands
12
+ - Entry point name "foo" -> flyte foo (top-level command)
13
+ - Entry point name "get.bar" -> flyte get bar (adds subcommand to get group)
14
+ - Note: At most one dot is supported. For nested groups, register the entire
15
+ group hierarchy as a top-level command (without dots).
16
+
17
+ - flyte.plugins.cli.hooks: Modify existing commands
18
+ - Entry point name "run" -> modifies flyte run
19
+ - Entry point name "get.project" -> modifies flyte get project
20
+ - Note: At most one dot is supported.
21
+
22
+ Example Plugin Package:
23
+ # In your-plugin/pyproject.toml
24
+ [project.entry-points."flyte.plugins.cli.commands"]
25
+ my-command = "your_plugin.cli:my_command"
26
+ get.my-object = "your_plugin.cli:get_my_object"
27
+
28
+ [project.entry-points."flyte.plugins.cli.hooks"]
29
+ run = "your_plugin.hooks:modify_run"
30
+
31
+ # In your-plugin/your_plugin/cli.py
32
+ import rich_click as click
33
+
34
+ @click.command()
35
+ def my_command():
36
+ '''My custom top-level command.'''
37
+ click.echo("Hello from plugin!")
38
+
39
+ @click.command()
40
+ def get_my_object():
41
+ '''Get my custom object.'''
42
+ click.echo("Getting my object...")
43
+
44
+ # In your-plugin/your_plugin/hooks.py
45
+ def modify_run(command):
46
+ '''Add behavior to flyte run command.'''
47
+ # Wrap invoke() instead of callback to ensure Click's full machinery runs
48
+ original_invoke = command.invoke
49
+
50
+ def wrapper(ctx):
51
+ # Do something before
52
+ click.echo("Plugin: Starting task...")
53
+
54
+ result = original_invoke(ctx)
55
+
56
+ # Do something after
57
+ click.echo("Plugin: Task completed!")
58
+ return result
59
+
60
+ command.invoke = wrapper
61
+ return command
62
+ """
63
+
64
+ from importlib.metadata import entry_points
65
+ from typing import Callable
66
+
67
+ import rich_click as click
68
+
69
+ from flyte._logging import logger
70
+
71
+ # Type alias for command hooks
72
+ CommandHook = Callable[[click.Command], click.Command]
73
+
74
+
75
def discover_and_register_plugins(root_group: click.Group):
    """
    Find every installed CLI plugin and wire it into ``root_group``.

    Command plugins are registered before hook plugins, so a hook can also target
    a command that was contributed by another plugin.

    Args:
        root_group: The root Click command group (the main ``flyte`` CLI group).
    """
    # Order matters: hooks look commands up in root_group, which must already
    # contain the plugin-provided commands.
    for loader in (_load_command_plugins, _load_hook_plugins):
        loader(root_group)
88
+
89
+
90
def _load_command_plugins(root_group: click.Group):
    """
    Register every command exposed through the ``flyte.plugins.cli.commands`` entry-point group.

    An entry point named ``foo`` becomes ``flyte foo``. An entry point named ``group.cmd`` is
    attached as ``flyte group cmd``. Names containing more than one dot are rejected.
    Failures while loading or registering a single plugin are logged, and the
    remaining plugins are still processed.

    Args:
        root_group: The root Click command group to register commands on.
    """
    for entry in entry_points(group="flyte.plugins.cli.commands"):
        try:
            plugin_cmd = entry.load()
            if not isinstance(plugin_cmd, click.Command):
                logger.warning(f"Plugin {entry.name} did not return a click.Command, got {type(plugin_cmd).__name__}")
                continue

            if "." not in entry.name:
                # Plain name: this is a new top-level command.
                root_group.add_command(plugin_cmd, name=entry.name)
                logger.info(f"Registered plugin command: flyte {entry.name}")
                continue

            parent_name, child_name = entry.name.split(".", 1)
            if "." in child_name:
                # Only "group.command" is supported; deeper trees must ship as a top-level group.
                logger.error(
                    f"Plugin {entry.name} uses multiple dots, which is not supported. "
                    f"Use at most one dot (e.g., 'group.command'). "
                    f"For nested groups, register the entire group hierarchy as a top-level command."
                )
                continue

            _add_subcommand_to_group(root_group, parent_name, child_name, plugin_cmd)
        except Exception as exc:
            logger.error(f"Failed to load plugin command {entry.name}: {exc}")
120
+
121
+
122
def _load_hook_plugins(root_group: click.Group):
    """
    Apply every hook exposed through the ``flyte.plugins.cli.hooks`` entry-point group.

    A hook named ``run`` targets ``flyte run``. A hook named ``get.project`` targets
    ``flyte get project``. Names containing more than one dot are rejected. A failure
    in one hook is logged, and the remaining hooks are still processed.

    Args:
        root_group: The root Click command group whose commands may be modified.
    """
    for entry in entry_points(group="flyte.plugins.cli.hooks"):
        try:
            hook_fn = entry.load()
            if not callable(hook_fn):
                logger.warning(f"Plugin hook {entry.name} is not callable")
                continue

            if "." not in entry.name:
                # Plain name: the hook targets a top-level command.
                _apply_hook_to_command(root_group, entry.name, hook_fn)
                continue

            parent_name, child_name = entry.name.split(".", 1)
            if "." in child_name:
                logger.error(
                    f"Plugin hook {entry.name} uses multiple dots, which is not supported. "
                    f"Use at most one dot (e.g., 'group.command')."
                )
                continue

            _apply_hook_to_subcommand(root_group, parent_name, child_name, hook_fn)
        except Exception as exc:
            logger.error(f"Failed to apply hook {entry.name}: {exc}")
150
+
151
+
152
def _add_subcommand_to_group(root_group: click.Group, group_name: str, command_name: str, command: click.Command):
    """
    Attach ``command`` as ``command_name`` under the existing top-level group ``group_name``.

    Nothing is registered, and a warning is logged, if ``group_name`` is missing from
    ``root_group`` or is registered but is not a ``click.Group``.

    Args:
        root_group: The root Click command group.
        group_name: Name of the existing top-level group to extend.
        command_name: Name under which the plugin command is exposed.
        command: The plugin-provided Click command.
    """
    if group_name not in root_group.commands:
        logger.warning(f"Cannot add plugin subcommand '{command_name}' - group '{group_name}' does not exist")
        return

    parent = root_group.commands[group_name]
    if isinstance(parent, click.Group):
        parent.add_command(command, name=command_name)
        # TODO: lower this to debug level.
        logger.info(f"Registered plugin subcommand: flyte {group_name} {command_name}")
    else:
        logger.warning(f"Cannot add plugin subcommand '{command_name}' - '{group_name}' is not a command group")
166
+
167
+
168
def _apply_hook_to_command(root_group: click.Group, command_name: str, hook: CommandHook):
    """
    Replace the top-level command ``command_name`` with the result of ``hook(command)``.

    The hook should return the (possibly wrapped or replaced) command. The result is handled as follows:

    * ``None``: the original command object is kept, which preserves any in-place changes the hook made.
    * Anything other than a ``click.Command``: the problem is logged and the original command stays registered.
    * An exception from the hook: the problem is logged and the original command stays registered.

    Args:
        root_group: The root Click command group.
        command_name: Name of the top-level command to modify.
        hook: Callable that receives the current command and returns its replacement.
    """
    if command_name not in root_group.commands:
        logger.warning(f"Cannot apply hook - command '{command_name}' does not exist")
        return

    original_command = root_group.commands[command_name]
    try:
        modified_command = hook(original_command)
    except Exception as e:
        logger.error(f"Hook failed for command {command_name}: {e}")
        root_group.commands[command_name] = original_command
        return

    if modified_command is None:
        # The hook mutated the command in place without returning it; registering None
        # would break command resolution, so keep the (mutated) original object.
        modified_command = original_command
    elif not isinstance(modified_command, click.Command):
        logger.error(
            f"Hook for command {command_name} returned {type(modified_command).__name__}, "
            f"expected click.Command; keeping the original command"
        )
        return

    root_group.commands[command_name] = modified_command
    # TODO: lower this to debug level.
    logger.info(f"Applied hook to command: flyte {command_name}")
183
+
184
+
185
def _apply_hook_to_subcommand(root_group: click.Group, group_name: str, command_name: str, hook: CommandHook):
    """
    Replace the subcommand ``flyte <group_name> <command_name>`` with the result of ``hook(command)``.

    The command is only passed to the hook; it is never executed here. The hook's result is handled as follows:

    * ``None``: the original command object is kept, which preserves any in-place changes the hook made.
    * Anything other than a ``click.Command``: the problem is logged and the original subcommand stays registered.
    * An exception from the hook: the problem is logged and the original subcommand stays registered.

    Args:
        root_group: The root Click command group.
        group_name: Name of the top-level group that contains the subcommand.
        command_name: Name of the subcommand to modify.
        hook: Callable that receives the current command and returns its replacement.
    """
    if group_name not in root_group.commands:
        logger.warning(f"Cannot apply hook - group '{group_name}' does not exist")
        return

    group = root_group.commands[group_name]
    if not isinstance(group, click.Group):
        logger.warning(f"Cannot apply hook - '{group_name}' is not a command group")
        return

    if command_name not in group.commands:
        logger.warning(f"Cannot apply hook - subcommand '{command_name}' does not exist in group '{group_name}'")
        return

    # Do NOT invoke the command (e.g. ``original_command.callback()``) here: hooks are applied
    # while the CLI is being assembled, so no arguments have been parsed and running it would
    # either fail or trigger the subcommand's side effects.
    original_command = group.commands[command_name]
    try:
        modified_command = hook(original_command)
    except Exception as e:
        logger.error(f"Hook failed for subcommand {group_name}.{command_name}: {e}")
        group.commands[command_name] = original_command
        return

    if modified_command is None:
        # The hook mutated the command in place without returning it; keep that object
        # rather than registering None, which would break command resolution.
        modified_command = original_command
    elif not isinstance(modified_command, click.Command):
        logger.error(
            f"Hook for subcommand {group_name}.{command_name} returned {type(modified_command).__name__}, "
            f"expected click.Command; keeping the original command"
        )
        return

    group.commands[command_name] = modified_command
    logger.info(f"Applied hook to subcommand: flyte {group_name} {command_name}")
flyte/cli/_run.py CHANGED
@@ -8,36 +8,34 @@ from pathlib import Path
8
8
  from types import ModuleType
9
9
  from typing import Any, Dict, List, cast
10
10
 
11
- import click
12
- from click import Context, Parameter
13
- from rich.console import Console
11
+ import rich_click as click
14
12
  from typing_extensions import get_args
15
13
 
16
14
  from .._code_bundle._utils import CopyFiles
17
15
  from .._task import TaskTemplate
18
16
  from ..remote import Run
19
17
  from . import _common as common
20
- from ._common import CLIConfig
18
+ from ._common import CLIConfig, initialize_config
21
19
  from ._params import to_click_option
22
20
 
23
21
  RUN_REMOTE_CMD = "deployed-task"
24
22
 
25
23
 
26
24
  @lru_cache()
27
- def _initialize_config(ctx: Context, project: str, domain: str):
25
+ def _initialize_config(ctx: click.Context, project: str, domain: str, root_dir: str | None = None):
28
26
  obj: CLIConfig | None = ctx.obj
29
27
  if obj is None:
30
28
  import flyte.config
31
29
 
32
30
  obj = CLIConfig(flyte.config.auto(), ctx)
33
31
 
34
- obj.init(project, domain)
32
+ obj.init(project, domain, root_dir)
35
33
  return obj
36
34
 
37
35
 
38
36
  @lru_cache()
39
37
  def _list_tasks(
40
- ctx: Context,
38
+ ctx: click.Context,
41
39
  project: str,
42
40
  domain: str,
43
41
  by_task_name: str | None = None,
@@ -45,7 +43,7 @@ def _list_tasks(
45
43
  ) -> list[str]:
46
44
  import flyte.remote
47
45
 
48
- _initialize_config(ctx, project, domain)
46
+ common.initialize_config(ctx, project, domain)
49
47
  return [task.name for task in flyte.remote.Task.listall(by_task_name=by_task_name, by_task_env=by_task_env)]
50
48
 
51
49
 
@@ -78,6 +76,36 @@ class RunArguments:
78
76
  )
79
77
  },
80
78
  )
79
+ root_dir: str | None = field(
80
+ default=None,
81
+ metadata={
82
+ "click.option": click.Option(
83
+ ["--root-dir"],
84
+ type=str,
85
+ help="Override the root source directory, helpful when working with monorepos.",
86
+ )
87
+ },
88
+ )
89
+ raw_data_path: str | None = field(
90
+ default=None,
91
+ metadata={
92
+ "click.option": click.Option(
93
+ ["--raw-data-path"],
94
+ type=str,
95
+ help="Override the output prefix used to store offloaded data types. e.g. s3://bucket/",
96
+ )
97
+ },
98
+ )
99
+ service_account: str | None = field(
100
+ default=None,
101
+ metadata={
102
+ "click.option": click.Option(
103
+ ["--service-account"],
104
+ type=str,
105
+ help="Kubernetes service account. If not provided, the configured default will be used",
106
+ )
107
+ },
108
+ )
81
109
  name: str | None = field(
82
110
  default=None,
83
111
  metadata={
@@ -100,10 +128,35 @@ class RunArguments:
100
128
  )
101
129
  },
102
130
  )
131
+ image: List[str] = field(
132
+ default_factory=list,
133
+ metadata={
134
+ "click.option": click.Option(
135
+ ["--image"],
136
+ type=str,
137
+ multiple=True,
138
+ help="Image to be used in the run. Format: imagename=imageuri. Can be specified multiple times.",
139
+ )
140
+ },
141
+ )
142
+ no_sync_local_sys_paths: bool = field(
143
+ default=True,
144
+ metadata={
145
+ "click.option": click.Option(
146
+ ["--no-sync-local-sys-paths"],
147
+ is_flag=True,
148
+ flag_value=True,
149
+ default=False,
150
+ help="Disable synchronization of local sys.path entries under the root directory "
151
+ "to the remote container.",
152
+ )
153
+ },
154
+ )
103
155
 
104
156
  @classmethod
105
157
  def from_dict(cls, d: Dict[str, Any]) -> RunArguments:
106
- return cls(**d)
158
+ modified = {k: v for k, v in d.items() if k in {f.name for f in fields(cls)}}
159
+ return cls(**modified)
107
160
 
108
161
  @classmethod
109
162
  def options(cls) -> List[click.Option]:
@@ -113,7 +166,7 @@ class RunArguments:
113
166
  return [common.get_option_from_metadata(f.metadata) for f in fields(cls) if f.metadata]
114
167
 
115
168
 
116
- class RunTaskCommand(click.Command):
169
+ class RunTaskCommand(click.RichCommand):
117
170
  def __init__(self, obj_name: str, obj: Any, run_args: RunArguments, *args, **kwargs):
118
171
  self.obj_name = obj_name
119
172
  self.obj = cast(TaskTemplate, obj)
@@ -121,19 +174,39 @@ class RunTaskCommand(click.Command):
121
174
  kwargs.pop("name", None)
122
175
  super().__init__(obj_name, *args, **kwargs)
123
176
 
124
- def invoke(self, ctx: Context):
125
- obj: CLIConfig = _initialize_config(ctx, self.run_args.project, self.run_args.domain)
177
+ def invoke(self, ctx: click.Context):
178
+ obj: CLIConfig = initialize_config(
179
+ ctx,
180
+ self.run_args.project,
181
+ self.run_args.domain,
182
+ self.run_args.root_dir,
183
+ tuple(self.run_args.image) or None,
184
+ not self.run_args.no_sync_local_sys_paths,
185
+ )
126
186
 
127
187
  async def _run():
128
188
  import flyte
129
189
 
190
+ console = common.get_console()
130
191
  r = await flyte.with_runcontext(
131
192
  copy_style=self.run_args.copy_style,
132
193
  mode="local" if self.run_args.local else "remote",
133
194
  name=self.run_args.name,
195
+ raw_data_path=self.run_args.raw_data_path,
196
+ service_account=self.run_args.service_account,
197
+ log_format=obj.log_format,
134
198
  ).run.aio(self.obj, **ctx.params)
199
+ if self.run_args.local:
200
+ console.print(
201
+ common.get_panel(
202
+ "Local Run",
203
+ f"[green]Completed Local Run, data stored in path: {r.url} [/green] \n"
204
+ f"➡️ Outputs: {r.outputs()}",
205
+ obj.output_format,
206
+ )
207
+ )
208
+ return
135
209
  if isinstance(r, Run) and r.action is not None:
136
- console = Console()
137
210
  console.print(
138
211
  common.get_panel(
139
212
  "Run",
@@ -152,7 +225,7 @@ class RunTaskCommand(click.Command):
152
225
 
153
226
  asyncio.run(_run())
154
227
 
155
- def get_params(self, ctx: Context) -> List[Parameter]:
228
+ def get_params(self, ctx: click.Context) -> List[click.Parameter]:
156
229
  # Note this function may be called multiple times by click.
157
230
  task = self.obj
158
231
  from .._internal.runtime.types_serde import transform_native_to_typed_interface
@@ -162,7 +235,7 @@ class RunTaskCommand(click.Command):
162
235
  return super().get_params(ctx)
163
236
  inputs_interface = task.native_interface.inputs
164
237
 
165
- params: List[Parameter] = []
238
+ params: List[click.Parameter] = []
166
239
  for name, var in interface.inputs.variables.items():
167
240
  default_val = None
168
241
  if inputs_interface[name][1] is not inspect._empty:
@@ -187,6 +260,26 @@ class TaskPerFileGroup(common.ObjectsPerFileGroup):
187
260
  def _filter_objects(self, module: ModuleType) -> Dict[str, Any]:
188
261
  return {k: v for k, v in module.__dict__.items() if isinstance(v, TaskTemplate)}
189
262
 
263
+ def list_commands(self, ctx):
264
+ common.initialize_config(
265
+ ctx,
266
+ self.run_args.project,
267
+ self.run_args.domain,
268
+ self.run_args.root_dir,
269
+ sync_local_sys_paths=not self.run_args.no_sync_local_sys_paths,
270
+ )
271
+ return super().list_commands(ctx)
272
+
273
+ def get_command(self, ctx, obj_name):
274
+ common.initialize_config(
275
+ ctx,
276
+ self.run_args.project,
277
+ self.run_args.domain,
278
+ self.run_args.root_dir,
279
+ sync_local_sys_paths=not self.run_args.no_sync_local_sys_paths,
280
+ )
281
+ return super().get_command(ctx, obj_name)
282
+
190
283
  def _get_command_for_obj(self, ctx: click.Context, obj_name: str, obj: Any) -> click.Command:
191
284
  obj = cast(TaskTemplate, obj)
192
285
  return RunTaskCommand(
@@ -197,7 +290,7 @@ class TaskPerFileGroup(common.ObjectsPerFileGroup):
197
290
  )
198
291
 
199
292
 
200
- class RunReferenceTaskCommand(click.Command):
293
+ class RunReferenceTaskCommand(click.RichCommand):
201
294
  def __init__(self, task_name: str, run_args: RunArguments, version: str | None, *args, **kwargs):
202
295
  self.task_name = task_name
203
296
  self.run_args = run_args
@@ -206,13 +299,20 @@ class RunReferenceTaskCommand(click.Command):
206
299
  super().__init__(*args, **kwargs)
207
300
 
208
301
  def invoke(self, ctx: click.Context):
209
- obj: CLIConfig = _initialize_config(ctx, self.run_args.project, self.run_args.domain)
302
+ obj: CLIConfig = common.initialize_config(
303
+ ctx,
304
+ self.run_args.project,
305
+ self.run_args.domain,
306
+ self.run_args.root_dir,
307
+ tuple(self.run_args.image) or None,
308
+ not self.run_args.no_sync_local_sys_paths,
309
+ )
210
310
 
211
311
  async def _run():
212
- import flyte
213
312
  import flyte.remote
214
313
 
215
314
  task = flyte.remote.Task.get(self.task_name, version=self.version, auto_version="latest")
315
+ console = common.get_console()
216
316
 
217
317
  r = await flyte.with_runcontext(
218
318
  copy_style=self.run_args.copy_style,
@@ -220,7 +320,6 @@ class RunReferenceTaskCommand(click.Command):
220
320
  name=self.run_args.name,
221
321
  ).run.aio(task, **ctx.params)
222
322
  if isinstance(r, Run) and r.action is not None:
223
- console = Console()
224
323
  console.print(
225
324
  common.get_panel(
226
325
  "Run",
@@ -239,12 +338,17 @@ class RunReferenceTaskCommand(click.Command):
239
338
 
240
339
  asyncio.run(_run())
241
340
 
242
- def get_params(self, ctx: Context) -> List[Parameter]:
341
+ def get_params(self, ctx: click.Context) -> List[click.Parameter]:
243
342
  # Note this function may be called multiple times by click.
244
343
  import flyte.remote
245
344
  from flyte._internal.runtime.types_serde import transform_native_to_typed_interface
246
345
 
247
- _initialize_config(ctx, self.run_args.project, self.run_args.domain)
346
+ common.initialize_config(
347
+ ctx,
348
+ self.run_args.project,
349
+ self.run_args.domain,
350
+ sync_local_sys_paths=not self.run_args.no_sync_local_sys_paths,
351
+ )
248
352
 
249
353
  task = flyte.remote.Task.get(self.task_name, auto_version="latest")
250
354
  task_details = task.fetch()
@@ -254,7 +358,7 @@ class RunReferenceTaskCommand(click.Command):
254
358
  return super().get_params(ctx)
255
359
  inputs_interface = task_details.interface.inputs
256
360
 
257
- params: List[Parameter] = []
361
+ params: List[click.Parameter] = []
258
362
  for name, var in interface.inputs.variables.items():
259
363
  default_val = None
260
364
  if inputs_interface[name][1] is not inspect._empty:
@@ -322,7 +426,6 @@ class ReferenceTaskGroup(common.GroupBase):
322
426
 
323
427
  def get_command(self, ctx, name):
324
428
  env, task, version = self._parse_task_name(name)
325
-
326
429
  match env, task, version:
327
430
  case env, None, None:
328
431
  if self._env_is_task(ctx, env):
@@ -383,14 +486,14 @@ class TaskFiles(common.FileGroup):
383
486
  super().__init__(*args, directory=directory, **kwargs)
384
487
 
385
488
  def list_commands(self, ctx):
386
- return [
489
+ v = [
387
490
  RUN_REMOTE_CMD,
388
- *self.files,
491
+ *super().list_commands(ctx),
389
492
  ]
493
+ return v
390
494
 
391
495
  def get_command(self, ctx, cmd_name):
392
496
  run_args = RunArguments.from_dict(ctx.params)
393
-
394
497
  if cmd_name == RUN_REMOTE_CMD:
395
498
  return ReferenceTaskGroup(
396
499
  name=cmd_name,
@@ -437,6 +540,28 @@ Flyte environment:
437
540
  flyte run --local hello.py my_task --arg1 value1 --arg2 value2
438
541
  ```
439
542
 
543
+ You can provide image mappings with `--image` flag. This allows you to specify
544
+ the image URI for the task environment during CLI execution without changing
545
+ the code. Any images defined with `Image.from_ref_name("name")` will resolve to the
546
+ corresponding URIs you specify here.
547
+
548
+ ```bash
549
+ flyte run hello.py my_task --image my_image=ghcr.io/myorg/my-image:v1.0
550
+ ```
551
+
552
+ If the image name is not provided, it is regarded as a default image and will
553
+ be used when no image is specified in TaskEnvironment:
554
+
555
+ ```bash
556
+ flyte run hello.py my_task --image ghcr.io/myorg/default-image:latest
557
+ ```
558
+
559
+ You can specify multiple image arguments:
560
+
561
+ ```bash
562
+ flyte run hello.py my_task --image ghcr.io/org/default:latest --image gpu=ghcr.io/org/gpu:v2.0
563
+ ```
564
+
440
565
  To run tasks that you've already deployed to Flyte, use the {RUN_REMOTE_CMD} command:
441
566
 
442
567
  ```bash