flyte-2.0.0b13-py3-none-any.whl → flyte-2.0.0b15-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic; see the registry's advisory for this release for more details.

Files changed (45):
  1. flyte/_bin/debug.py +38 -0
  2. flyte/_bin/runtime.py +13 -0
  3. flyte/_code_bundle/_utils.py +2 -0
  4. flyte/_code_bundle/bundle.py +4 -4
  5. flyte/_debug/__init__.py +0 -0
  6. flyte/_debug/constants.py +39 -0
  7. flyte/_debug/utils.py +17 -0
  8. flyte/_debug/vscode.py +300 -0
  9. flyte/_image.py +32 -6
  10. flyte/_initialize.py +14 -28
  11. flyte/_internal/controllers/remote/_action.py +1 -1
  12. flyte/_internal/controllers/remote/_controller.py +35 -35
  13. flyte/_internal/imagebuild/docker_builder.py +11 -15
  14. flyte/_internal/imagebuild/remote_builder.py +52 -23
  15. flyte/_internal/runtime/entrypoints.py +3 -0
  16. flyte/_internal/runtime/task_serde.py +1 -2
  17. flyte/_internal/runtime/taskrunner.py +9 -3
  18. flyte/_protos/common/identifier_pb2.py +25 -19
  19. flyte/_protos/common/identifier_pb2.pyi +10 -0
  20. flyte/_protos/imagebuilder/definition_pb2.py +32 -31
  21. flyte/_protos/imagebuilder/definition_pb2.pyi +25 -12
  22. flyte/_protos/workflow/queue_service_pb2.py +26 -24
  23. flyte/_protos/workflow/queue_service_pb2.pyi +6 -4
  24. flyte/_protos/workflow/run_definition_pb2.py +50 -48
  25. flyte/_protos/workflow/run_definition_pb2.pyi +41 -16
  26. flyte/_protos/workflow/task_definition_pb2.py +16 -13
  27. flyte/_protos/workflow/task_definition_pb2.pyi +7 -0
  28. flyte/_task.py +6 -6
  29. flyte/_task_environment.py +4 -4
  30. flyte/_version.py +3 -3
  31. flyte/cli/_build.py +2 -3
  32. flyte/cli/_run.py +11 -12
  33. flyte/models.py +2 -0
  34. flyte/remote/_action.py +5 -2
  35. flyte/remote/_client/auth/_authenticators/device_code.py +1 -1
  36. flyte/remote/_client/auth/_authenticators/pkce.py +1 -1
  37. flyte/remote/_task.py +4 -4
  38. flyte-2.0.0b15.data/scripts/debug.py +38 -0
  39. {flyte-2.0.0b13.data → flyte-2.0.0b15.data}/scripts/runtime.py +13 -0
  40. {flyte-2.0.0b13.dist-info → flyte-2.0.0b15.dist-info}/METADATA +2 -2
  41. {flyte-2.0.0b13.dist-info → flyte-2.0.0b15.dist-info}/RECORD +45 -39
  42. {flyte-2.0.0b13.dist-info → flyte-2.0.0b15.dist-info}/WHEEL +0 -0
  43. {flyte-2.0.0b13.dist-info → flyte-2.0.0b15.dist-info}/entry_points.txt +0 -0
  44. {flyte-2.0.0b13.dist-info → flyte-2.0.0b15.dist-info}/licenses/LICENSE +0 -0
  45. {flyte-2.0.0b13.dist-info → flyte-2.0.0b15.dist-info}/top_level.txt +0 -0
@@ -13,6 +13,7 @@ _sym_db = _symbol_database.Default()
13
13
 
14
14
  from flyte._protos.common import identifier_pb2 as common_dot_identifier__pb2
15
15
  from flyte._protos.common import identity_pb2 as common_dot_identity__pb2
16
+ from flyteidl.core import interface_pb2 as flyteidl_dot_core_dot_interface__pb2
16
17
  from flyteidl.core import tasks_pb2 as flyteidl_dot_core_dot_tasks__pb2
17
18
  from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2
18
19
  from flyte._protos.validate.validate import validate_pb2 as validate_dot_validate__pb2
@@ -20,7 +21,7 @@ from flyte._protos.workflow import common_pb2 as workflow_dot_common__pb2
20
21
  from flyte._protos.workflow import environment_pb2 as workflow_dot_environment__pb2
21
22
 
22
23
 
23
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1eworkflow/task_definition.proto\x12\x11\x63loudidl.workflow\x1a\x17\x63ommon/identifier.proto\x1a\x15\x63ommon/identity.proto\x1a\x19\x66lyteidl/core/tasks.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x17validate/validate.proto\x1a\x15workflow/common.proto\x1a\x1aworkflow/environment.proto\"\x8f\x01\n\x08TaskName\x12\x1b\n\x03org\x18\x01 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x03org\x12#\n\x07project\x18\x02 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x07project\x12!\n\x06\x64omain\x18\x03 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x06\x64omain\x12\x1e\n\x04name\x18\x04 \x01(\tB\n\xfa\x42\x07r\x05\x10\x01\x18\xff\x01R\x04name\"\xba\x01\n\x0eTaskIdentifier\x12\x1b\n\x03org\x18\x01 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x03org\x12#\n\x07project\x18\x02 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x07project\x12!\n\x06\x64omain\x18\x03 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x06\x64omain\x12\x1e\n\x04name\x18\x04 \x01(\tB\n\xfa\x42\x07r\x05\x10\x01\x18\xff\x01R\x04name\x12#\n\x07version\x18\x05 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x07version\"\xed\x01\n\x0cTaskMetadata\x12L\n\x0b\x64\x65ployed_by\x18\x01 \x01(\x0b\x32!.cloudidl.common.EnrichedIdentityB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\ndeployedBy\x12\x1d\n\nshort_name\x18\x02 \x01(\tR\tshortName\x12\x45\n\x0b\x64\x65ployed_at\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.TimestampB\x08\xfa\x42\x05\xb2\x01\x02\x08\x01R\ndeployedAt\x12)\n\x10\x65nvironment_name\x18\x04 \x01(\tR\x0f\x65nvironmentName\"\x93\x01\n\x04Task\x12\x44\n\x07task_id\x18\x01 \x01(\x0b\x32!.cloudidl.workflow.TaskIdentifierB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x06taskId\x12\x45\n\x08metadata\x18\x02 \x01(\x0b\x32\x1f.cloudidl.workflow.TaskMetadataB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x08metadata\"\x8a\x02\n\x08TaskSpec\x12J\n\rtask_template\x18\x01 
\x01(\x0b\x32\x1b.flyteidl.core.TaskTemplateB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x0ctaskTemplate\x12H\n\x0e\x64\x65\x66\x61ult_inputs\x18\x02 \x03(\x0b\x32!.cloudidl.workflow.NamedParameterR\rdefaultInputs\x12&\n\nshort_name\x18\x03 \x01(\tB\x07\xfa\x42\x04r\x02\x18?R\tshortName\x12@\n\x0b\x65nvironment\x18\x04 \x01(\x0b\x32\x1e.cloudidl.workflow.EnvironmentR\x0b\x65nvironment\"\xd5\x01\n\x0bTaskDetails\x12\x44\n\x07task_id\x18\x01 \x01(\x0b\x32!.cloudidl.workflow.TaskIdentifierB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x06taskId\x12\x45\n\x08metadata\x18\x02 \x01(\x0b\x32\x1f.cloudidl.workflow.TaskMetadataB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x08metadata\x12\x39\n\x04spec\x18\x03 \x01(\x0b\x32\x1b.cloudidl.workflow.TaskSpecB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x04specB\xc0\x01\n\x15\x63om.cloudidl.workflowB\x13TaskDefinitionProtoH\x02P\x01Z+github.com/unionai/cloud/gen/pb-go/workflow\xa2\x02\x03\x43WX\xaa\x02\x11\x43loudidl.Workflow\xca\x02\x11\x43loudidl\\Workflow\xe2\x02\x1d\x43loudidl\\Workflow\\GPBMetadata\xea\x02\x12\x43loudidl::Workflowb\x06proto3')
24
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1eworkflow/task_definition.proto\x12\x11\x63loudidl.workflow\x1a\x17\x63ommon/identifier.proto\x1a\x15\x63ommon/identity.proto\x1a\x1d\x66lyteidl/core/interface.proto\x1a\x19\x66lyteidl/core/tasks.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x17validate/validate.proto\x1a\x15workflow/common.proto\x1a\x1aworkflow/environment.proto\"\x8f\x01\n\x08TaskName\x12\x1b\n\x03org\x18\x01 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x03org\x12#\n\x07project\x18\x02 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x07project\x12!\n\x06\x64omain\x18\x03 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x06\x64omain\x12\x1e\n\x04name\x18\x04 \x01(\tB\n\xfa\x42\x07r\x05\x10\x01\x18\xff\x01R\x04name\"\xba\x01\n\x0eTaskIdentifier\x12\x1b\n\x03org\x18\x01 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x03org\x12#\n\x07project\x18\x02 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x07project\x12!\n\x06\x64omain\x18\x03 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x06\x64omain\x12\x1e\n\x04name\x18\x04 \x01(\tB\n\xfa\x42\x07r\x05\x10\x01\x18\xff\x01R\x04name\x12#\n\x07version\x18\x05 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x07version\"\xed\x01\n\x0cTaskMetadata\x12L\n\x0b\x64\x65ployed_by\x18\x01 \x01(\x0b\x32!.cloudidl.common.EnrichedIdentityB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\ndeployedBy\x12\x1d\n\nshort_name\x18\x02 \x01(\tR\tshortName\x12\x45\n\x0b\x64\x65ployed_at\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.TimestampB\x08\xfa\x42\x05\xb2\x01\x02\x08\x01R\ndeployedAt\x12)\n\x10\x65nvironment_name\x18\x04 \x01(\tR\x0f\x65nvironmentName\"\x93\x01\n\x04Task\x12\x44\n\x07task_id\x18\x01 \x01(\x0b\x32!.cloudidl.workflow.TaskIdentifierB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x06taskId\x12\x45\n\x08metadata\x18\x02 \x01(\x0b\x32\x1f.cloudidl.workflow.TaskMetadataB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x08metadata\"\x8a\x02\n\x08TaskSpec\x12J\n\rtask_template\x18\x01 
\x01(\x0b\x32\x1b.flyteidl.core.TaskTemplateB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x0ctaskTemplate\x12H\n\x0e\x64\x65\x66\x61ult_inputs\x18\x02 \x03(\x0b\x32!.cloudidl.workflow.NamedParameterR\rdefaultInputs\x12&\n\nshort_name\x18\x03 \x01(\tB\x07\xfa\x42\x04r\x02\x18?R\tshortName\x12@\n\x0b\x65nvironment\x18\x04 \x01(\x0b\x32\x1e.cloudidl.workflow.EnvironmentR\x0b\x65nvironment\"H\n\tTraceSpec\x12;\n\tinterface\x18\x01 \x01(\x0b\x32\x1d.flyteidl.core.TypedInterfaceR\tinterface\"\xd5\x01\n\x0bTaskDetails\x12\x44\n\x07task_id\x18\x01 \x01(\x0b\x32!.cloudidl.workflow.TaskIdentifierB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x06taskId\x12\x45\n\x08metadata\x18\x02 \x01(\x0b\x32\x1f.cloudidl.workflow.TaskMetadataB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x08metadata\x12\x39\n\x04spec\x18\x03 \x01(\x0b\x32\x1b.cloudidl.workflow.TaskSpecB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x04specB\xc0\x01\n\x15\x63om.cloudidl.workflowB\x13TaskDefinitionProtoH\x02P\x01Z+github.com/unionai/cloud/gen/pb-go/workflow\xa2\x02\x03\x43WX\xaa\x02\x11\x43loudidl.Workflow\xca\x02\x11\x43loudidl\\Workflow\xe2\x02\x1d\x43loudidl\\Workflow\\GPBMetadata\xea\x02\x12\x43loudidl::Workflowb\x06proto3')
24
25
 
25
26
  _globals = globals()
26
27
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -64,16 +65,18 @@ if _descriptor._USE_C_DESCRIPTORS == False:
64
65
  _TASKDETAILS.fields_by_name['metadata']._serialized_options = b'\372B\005\212\001\002\020\001'
65
66
  _TASKDETAILS.fields_by_name['spec']._options = None
66
67
  _TASKDETAILS.fields_by_name['spec']._serialized_options = b'\372B\005\212\001\002\020\001'
67
- _globals['_TASKNAME']._serialized_start=238
68
- _globals['_TASKNAME']._serialized_end=381
69
- _globals['_TASKIDENTIFIER']._serialized_start=384
70
- _globals['_TASKIDENTIFIER']._serialized_end=570
71
- _globals['_TASKMETADATA']._serialized_start=573
72
- _globals['_TASKMETADATA']._serialized_end=810
73
- _globals['_TASK']._serialized_start=813
74
- _globals['_TASK']._serialized_end=960
75
- _globals['_TASKSPEC']._serialized_start=963
76
- _globals['_TASKSPEC']._serialized_end=1229
77
- _globals['_TASKDETAILS']._serialized_start=1232
78
- _globals['_TASKDETAILS']._serialized_end=1445
68
+ _globals['_TASKNAME']._serialized_start=269
69
+ _globals['_TASKNAME']._serialized_end=412
70
+ _globals['_TASKIDENTIFIER']._serialized_start=415
71
+ _globals['_TASKIDENTIFIER']._serialized_end=601
72
+ _globals['_TASKMETADATA']._serialized_start=604
73
+ _globals['_TASKMETADATA']._serialized_end=841
74
+ _globals['_TASK']._serialized_start=844
75
+ _globals['_TASK']._serialized_end=991
76
+ _globals['_TASKSPEC']._serialized_start=994
77
+ _globals['_TASKSPEC']._serialized_end=1260
78
+ _globals['_TRACESPEC']._serialized_start=1262
79
+ _globals['_TRACESPEC']._serialized_end=1334
80
+ _globals['_TASKDETAILS']._serialized_start=1337
81
+ _globals['_TASKDETAILS']._serialized_end=1550
79
82
  # @@protoc_insertion_point(module_scope)
@@ -1,5 +1,6 @@
1
1
  from flyte._protos.common import identifier_pb2 as _identifier_pb2
2
2
  from flyte._protos.common import identity_pb2 as _identity_pb2
3
+ from flyteidl.core import interface_pb2 as _interface_pb2
3
4
  from flyteidl.core import tasks_pb2 as _tasks_pb2
4
5
  from google.protobuf import timestamp_pb2 as _timestamp_pb2
5
6
  from flyte._protos.validate.validate import validate_pb2 as _validate_pb2
@@ -70,6 +71,12 @@ class TaskSpec(_message.Message):
70
71
  environment: _environment_pb2.Environment
71
72
  def __init__(self, task_template: _Optional[_Union[_tasks_pb2.TaskTemplate, _Mapping]] = ..., default_inputs: _Optional[_Iterable[_Union[_common_pb2.NamedParameter, _Mapping]]] = ..., short_name: _Optional[str] = ..., environment: _Optional[_Union[_environment_pb2.Environment, _Mapping]] = ...) -> None: ...
72
73
 
74
+ class TraceSpec(_message.Message):
75
+ __slots__ = ["interface"]
76
+ INTERFACE_FIELD_NUMBER: _ClassVar[int]
77
+ interface: _interface_pb2.TypedInterface
78
+ def __init__(self, interface: _Optional[_Union[_interface_pb2.TypedInterface, _Mapping]] = ...) -> None: ...
79
+
73
80
  class TaskDetails(_message.Message):
74
81
  __slots__ = ["task_id", "metadata", "spec"]
75
82
  TASK_ID_FIELD_NUMBER: _ClassVar[int]
flyte/_task.py CHANGED
@@ -85,7 +85,7 @@ class TaskTemplate(Generic[P, R]):
85
85
 
86
86
  name: str
87
87
  interface: NativeInterface
88
- friendly_name: str = ""
88
+ short_name: str = ""
89
89
  task_type: str = "python"
90
90
  task_type_version: int = 0
91
91
  image: Union[str, Image, Literal["auto"]] = "auto"
@@ -129,9 +129,9 @@ class TaskTemplate(Generic[P, R]):
129
129
  if isinstance(self.retries, int):
130
130
  self.retries = RetryStrategy(count=self.retries)
131
131
 
132
- if self.friendly_name == "":
133
- # If friendly_name is not set, use the name of the task
134
- self.friendly_name = self.name
132
+ if self.short_name == "":
133
+ # If short_name is not set, use the name of the task
134
+ self.short_name = self.name
135
135
 
136
136
  def __getstate__(self):
137
137
  """
@@ -314,7 +314,7 @@ class TaskTemplate(Generic[P, R]):
314
314
  def override(
315
315
  self,
316
316
  *,
317
- friendly_name: Optional[str] = None,
317
+ short_name: Optional[str] = None,
318
318
  resources: Optional[Resources] = None,
319
319
  cache: Optional[CacheRequest] = None,
320
320
  retries: Union[int, RetryStrategy] = 0,
@@ -375,7 +375,7 @@ class TaskTemplate(Generic[P, R]):
375
375
 
376
376
  return replace(
377
377
  self,
378
- friendly_name=friendly_name or self.friendly_name,
378
+ short_name=short_name or self.short_name,
379
379
  resources=resources,
380
380
  cache=cache,
381
381
  retries=retries,
@@ -133,7 +133,7 @@ class TaskEnvironment(Environment):
133
133
  self,
134
134
  _func=None,
135
135
  *,
136
- name: Optional[str] = None,
136
+ short_name: Optional[str] = None,
137
137
  cache: CacheRequest | None = None,
138
138
  retries: Union[int, RetryStrategy] = 0,
139
139
  timeout: Union[timedelta, int] = 0,
@@ -147,7 +147,7 @@ class TaskEnvironment(Environment):
147
147
 
148
148
  :param _func: Optional The function to decorate. If not provided, the decorator will return a callable that
149
149
  accepts a function to be decorated.
150
- :param name: Optional A friendly name for the task (defaults to the function name)
150
+ :param short_name: Optional A friendly name for the task (defaults to the function name)
151
151
  :param cache: Optional The cache policy for the task, defaults to auto, which will cache the results of the
152
152
  task.
153
153
  :param retries: Optional The number of retries for the task, defaults to 0, which means no retries.
@@ -166,7 +166,7 @@ class TaskEnvironment(Environment):
166
166
  raise ValueError("Cannot set pod_template when environment is reusable.")
167
167
 
168
168
  def decorator(func: FunctionTypes) -> AsyncFunctionTaskTemplate[P, R]:
169
- friendly_name = name or func.__name__
169
+ short = short_name or func.__name__
170
170
  task_name = self.name + "." + func.__name__
171
171
 
172
172
  if not inspect.iscoroutinefunction(func) and self.reusable is not None:
@@ -207,7 +207,7 @@ class TaskEnvironment(Environment):
207
207
  parent_env=weakref.ref(self),
208
208
  interface=NativeInterface.from_callable(func),
209
209
  report=report,
210
- friendly_name=friendly_name,
210
+ short_name=short,
211
211
  plugin_config=self.plugin_config,
212
212
  max_inline_io_bytes=max_inline_io_bytes,
213
213
  )
flyte/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '2.0.0b13'
32
- __version_tuple__ = version_tuple = (2, 0, 0, 'b13')
31
+ __version__ = version = '2.0.0b15'
32
+ __version_tuple__ = version_tuple = (2, 0, 0, 'b15')
33
33
 
34
- __commit_id__ = commit_id = 'g07f30e36d'
34
+ __commit_id__ = commit_id = 'gaa61639b5'
flyte/cli/_build.py CHANGED
@@ -3,8 +3,7 @@ from pathlib import Path
3
3
  from types import ModuleType
4
4
  from typing import Any, Dict, List, cast
5
5
 
6
- import click
7
- from click import Context
6
+ import rich_click as click
8
7
 
9
8
  import flyte
10
9
 
@@ -44,7 +43,7 @@ class BuildEnvCommand(click.Command):
44
43
  self.build_args = build_args
45
44
  super().__init__(*args, **kwargs)
46
45
 
47
- def invoke(self, ctx: Context):
46
+ def invoke(self, ctx: click.Context):
48
47
  from rich.console import Console
49
48
 
50
49
  console = Console()
flyte/cli/_run.py CHANGED
@@ -8,8 +8,7 @@ from pathlib import Path
8
8
  from types import ModuleType
9
9
  from typing import Any, Dict, List, cast
10
10
 
11
- import click
12
- from click import Context, Parameter
11
+ import rich_click as click
13
12
  from rich.console import Console
14
13
  from typing_extensions import get_args
15
14
 
@@ -24,7 +23,7 @@ RUN_REMOTE_CMD = "deployed-task"
24
23
 
25
24
 
26
25
  @lru_cache()
27
- def _initialize_config(ctx: Context, project: str, domain: str):
26
+ def _initialize_config(ctx: click.Context, project: str, domain: str):
28
27
  obj: CLIConfig | None = ctx.obj
29
28
  if obj is None:
30
29
  import flyte.config
@@ -37,7 +36,7 @@ def _initialize_config(ctx: Context, project: str, domain: str):
37
36
 
38
37
  @lru_cache()
39
38
  def _list_tasks(
40
- ctx: Context,
39
+ ctx: click.Context,
41
40
  project: str,
42
41
  domain: str,
43
42
  by_task_name: str | None = None,
@@ -121,7 +120,7 @@ class RunTaskCommand(click.Command):
121
120
  kwargs.pop("name", None)
122
121
  super().__init__(obj_name, *args, **kwargs)
123
122
 
124
- def invoke(self, ctx: Context):
123
+ def invoke(self, ctx: click.Context):
125
124
  obj: CLIConfig = _initialize_config(ctx, self.run_args.project, self.run_args.domain)
126
125
 
127
126
  async def _run():
@@ -152,7 +151,7 @@ class RunTaskCommand(click.Command):
152
151
 
153
152
  asyncio.run(_run())
154
153
 
155
- def get_params(self, ctx: Context) -> List[Parameter]:
154
+ def get_params(self, ctx: click.Context) -> List[click.Parameter]:
156
155
  # Note this function may be called multiple times by click.
157
156
  task = self.obj
158
157
  from .._internal.runtime.types_serde import transform_native_to_typed_interface
@@ -162,7 +161,7 @@ class RunTaskCommand(click.Command):
162
161
  return super().get_params(ctx)
163
162
  inputs_interface = task.native_interface.inputs
164
163
 
165
- params: List[Parameter] = []
164
+ params: List[click.Parameter] = []
166
165
  for name, var in interface.inputs.variables.items():
167
166
  default_val = None
168
167
  if inputs_interface[name][1] is not inspect._empty:
@@ -239,7 +238,7 @@ class RunReferenceTaskCommand(click.Command):
239
238
 
240
239
  asyncio.run(_run())
241
240
 
242
- def get_params(self, ctx: Context) -> List[Parameter]:
241
+ def get_params(self, ctx: click.Context) -> List[click.Parameter]:
243
242
  # Note this function may be called multiple times by click.
244
243
  import flyte.remote
245
244
  from flyte._internal.runtime.types_serde import transform_native_to_typed_interface
@@ -254,7 +253,7 @@ class RunReferenceTaskCommand(click.Command):
254
253
  return super().get_params(ctx)
255
254
  inputs_interface = task_details.interface.inputs
256
255
 
257
- params: List[Parameter] = []
256
+ params: List[click.Parameter] = []
258
257
  for name, var in interface.inputs.variables.items():
259
258
  default_val = None
260
259
  if inputs_interface[name][1] is not inspect._empty:
@@ -322,7 +321,6 @@ class ReferenceTaskGroup(common.GroupBase):
322
321
 
323
322
  def get_command(self, ctx, name):
324
323
  env, task, version = self._parse_task_name(name)
325
-
326
324
  match env, task, version:
327
325
  case env, None, None:
328
326
  if self._env_is_task(ctx, env):
@@ -383,10 +381,11 @@ class TaskFiles(common.FileGroup):
383
381
  super().__init__(*args, directory=directory, **kwargs)
384
382
 
385
383
  def list_commands(self, ctx):
386
- return [
384
+ v = [
387
385
  RUN_REMOTE_CMD,
388
- *self.files,
386
+ *super().list_commands(ctx),
389
387
  ]
388
+ return v
390
389
 
391
390
  def get_command(self, ctx, cmd_name):
392
391
  run_args = RunArguments.from_dict(ctx.params)
flyte/models.py CHANGED
@@ -166,6 +166,7 @@ class TaskContext:
166
166
  action: ActionID
167
167
  version: str
168
168
  raw_data_path: RawDataPath
169
+ input_path: str | None = None
169
170
  output_path: str
170
171
  run_base_dir: str
171
172
  report: Report
@@ -175,6 +176,7 @@ class TaskContext:
175
176
  compiled_image_cache: ImageCache | None = None
176
177
  data: Dict[str, Any] = field(default_factory=dict)
177
178
  mode: Literal["local", "remote", "hybrid"] = "remote"
179
+ interactive_mode: bool = False
178
180
 
179
181
  def replace(self, **kwargs) -> TaskContext:
180
182
  if "data" in kwargs:
flyte/remote/_action.py CHANGED
@@ -627,8 +627,11 @@ class ActionDetails(ToJSONMixin):
627
627
  )
628
628
  )
629
629
  native_iface = None
630
- if self.pb2.resolved_task_spec:
631
- iface = self.pb2.resolved_task_spec.task_template.interface
630
+ if self.pb2.HasField('task'):
631
+ iface = self.pb2.task.task_template.interface
632
+ native_iface = types.guess_interface(iface)
633
+ elif self.pb2.HasField('trace'):
634
+ iface = self.pb2.trace.interface
632
635
  native_iface = types.guess_interface(iface)
633
636
 
634
637
  if resp.inputs:
@@ -81,7 +81,7 @@ class DeviceCodeAuthenticator(Authenticator):
81
81
  for_endpoint=self._endpoint,
82
82
  )
83
83
  except (AuthenticationError, AuthenticationPending):
84
- logger.warning("Failed to refresh token. Kicking off a full authorization flow.")
84
+ logger.warning("Logging in...")
85
85
 
86
86
  """Fall back to device flow"""
87
87
  resp = await token_client.get_device_code(
@@ -123,7 +123,7 @@ class PKCEAuthenticator(Authenticator):
123
123
  try:
124
124
  return await self._auth_client.refresh_access_token(self._creds)
125
125
  except AccessTokenNotFoundError:
126
- logger.warning("Failed to refresh token. Kicking off a full authorization flow.")
126
+ logger.warning("Logging in...")
127
127
 
128
128
  return await self._auth_client.get_creds_from_remote()
129
129
 
flyte/remote/_task.py CHANGED
@@ -283,7 +283,7 @@ class TaskDetails(ToJSONMixin):
283
283
  def override(
284
284
  self,
285
285
  *,
286
- friendly_name: Optional[str] = None,
286
+ short_name: Optional[str] = None,
287
287
  resources: Optional[flyte.Resources] = None,
288
288
  retries: Union[int, flyte.RetryStrategy] = 0,
289
289
  timeout: Optional[flyte.TimeoutType] = None,
@@ -297,8 +297,8 @@ class TaskDetails(ToJSONMixin):
297
297
  f"Check the parameters for override method."
298
298
  )
299
299
  template = self.pb2.spec.task_template
300
- if friendly_name:
301
- self.pb2.metadata.short_name = friendly_name
300
+ if short_name:
301
+ self.pb2.metadata.short_name = short_name
302
302
  if secrets:
303
303
  template.security_context.CopyFrom(get_security_context(secrets))
304
304
  if template.HasField("container"):
@@ -318,7 +318,7 @@ class TaskDetails(ToJSONMixin):
318
318
  """
319
319
  Rich representation of the task.
320
320
  """
321
- yield "friendly_name", self.pb2.spec.short_name
321
+ yield "short_name", self.pb2.spec.short_name
322
322
  yield "environment", self.pb2.spec.environment
323
323
  yield "default_inputs_keys", self.default_input_args
324
324
  yield "required_args", self.required_args
@@ -0,0 +1,38 @@
1
+ import click
2
+
3
+
4
+ @click.group()
5
+ def _debug():
6
+ """Debug commands for Flyte."""
7
+
8
+
9
+ @_debug.command("resume")
10
+ @click.option("--pid", "-m", type=int, required=True, help="PID of the vscode server.")
11
+ def resume(pid):
12
+ """
13
+ Resume a Flyte task for debugging purposes.
14
+
15
+ Args:
16
+ pid (int): PID of the vscode server.
17
+ """
18
+ import os
19
+ import signal
20
+
21
+ print("Terminating server and resuming task.")
22
+ answer = (
23
+ input(
24
+ "This operation will kill the server. All unsaved data will be lost,"
25
+ " and you will no longer be able to connect to it. Do you really want to terminate? (Y/N): "
26
+ )
27
+ .strip()
28
+ .upper()
29
+ )
30
+ if answer == "Y":
31
+ os.kill(pid, signal.SIGTERM)
32
+ print("The server has been terminated and the task has been resumed.")
33
+ else:
34
+ print("Operation canceled.")
35
+
36
+
37
+ if __name__ == "__main__":
38
+ _debug()
@@ -26,6 +26,7 @@ DOMAIN_NAME = "FLYTE_INTERNAL_TASK_DOMAIN"
26
26
  ORG_NAME = "_U_ORG_NAME"
27
27
  ENDPOINT_OVERRIDE = "_U_EP_OVERRIDE"
28
28
  RUN_OUTPUT_BASE_DIR = "_U_RUN_BASE"
29
+ FLYTE_ENABLE_VSCODE_KEY = "_F_E_VS"
29
30
 
30
31
  # TODO: Remove this after proper auth is implemented
31
32
  _UNION_EAGER_API_KEY_ENV_VAR = "_UNION_EAGER_API_KEY"
@@ -49,6 +50,8 @@ def _pass_through():
49
50
  @click.option("--project", envvar=PROJECT_NAME, required=False)
50
51
  @click.option("--domain", envvar=DOMAIN_NAME, required=False)
51
52
  @click.option("--org", envvar=ORG_NAME, required=False)
53
+ @click.option("--debug", envvar=FLYTE_ENABLE_VSCODE_KEY, type=click.BOOL, required=False)
54
+ @click.option("--interactive-mode", type=click.BOOL, required=False)
52
55
  @click.option("--image-cache", required=False)
53
56
  @click.option("--tgz", required=False)
54
57
  @click.option("--pkl", required=False)
@@ -59,12 +62,16 @@ def _pass_through():
59
62
  type=click.UNPROCESSED,
60
63
  nargs=-1,
61
64
  )
65
+ @click.pass_context
62
66
  def main(
67
+ ctx: click.Context,
63
68
  run_name: str,
64
69
  name: str,
65
70
  project: str,
66
71
  domain: str,
67
72
  org: str,
73
+ debug: bool,
74
+ interactive_mode: bool,
68
75
  image_cache: str,
69
76
  version: str,
70
77
  inputs: str,
@@ -109,6 +116,11 @@ def main(
109
116
  if name.startswith("{{"):
110
117
  name = os.getenv("ACTION_NAME", "")
111
118
 
119
+ if debug and name == "a0":
120
+ from flyte._debug.vscode import _start_vscode_server
121
+
122
+ asyncio.run(_start_vscode_server(ctx))
123
+
112
124
  # Figure out how to connect
113
125
  # This detection of api key is a hack for now.
114
126
  controller_kwargs: dict[str, Any] = {"insecure": False}
@@ -143,6 +155,7 @@ def main(
143
155
  version=version,
144
156
  controller=controller,
145
157
  image_cache=ic,
158
+ interactive_mode=interactive_mode or debug,
146
159
  )
147
160
  # Create a coroutine to watch for errors
148
161
  controller_failure = controller.watch_for_errors()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flyte
3
- Version: 2.0.0b13
3
+ Version: 2.0.0b15
4
4
  Summary: Add your description here
5
5
  Author-email: Ketan Umare <kumare3@users.noreply.github.com>
6
6
  Requires-Python: >=3.10
@@ -16,7 +16,7 @@ Requires-Dist: obstore>=0.7.3
16
16
  Requires-Dist: protobuf>=6.30.1
17
17
  Requires-Dist: pydantic>=2.10.6
18
18
  Requires-Dist: pyyaml>=6.0.2
19
- Requires-Dist: rich-click>=1.8.9
19
+ Requires-Dist: rich-click==1.8.9
20
20
  Requires-Dist: httpx<1.0.0,>=0.28.1
21
21
  Requires-Dist: keyring>=25.6.0
22
22
  Requires-Dist: msgpack>=1.1.0