flyte 2.0.0b13__py3-none-any.whl → 2.0.0b15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/_bin/debug.py +38 -0
- flyte/_bin/runtime.py +13 -0
- flyte/_code_bundle/_utils.py +2 -0
- flyte/_code_bundle/bundle.py +4 -4
- flyte/_debug/__init__.py +0 -0
- flyte/_debug/constants.py +39 -0
- flyte/_debug/utils.py +17 -0
- flyte/_debug/vscode.py +300 -0
- flyte/_image.py +32 -6
- flyte/_initialize.py +14 -28
- flyte/_internal/controllers/remote/_action.py +1 -1
- flyte/_internal/controllers/remote/_controller.py +35 -35
- flyte/_internal/imagebuild/docker_builder.py +11 -15
- flyte/_internal/imagebuild/remote_builder.py +52 -23
- flyte/_internal/runtime/entrypoints.py +3 -0
- flyte/_internal/runtime/task_serde.py +1 -2
- flyte/_internal/runtime/taskrunner.py +9 -3
- flyte/_protos/common/identifier_pb2.py +25 -19
- flyte/_protos/common/identifier_pb2.pyi +10 -0
- flyte/_protos/imagebuilder/definition_pb2.py +32 -31
- flyte/_protos/imagebuilder/definition_pb2.pyi +25 -12
- flyte/_protos/workflow/queue_service_pb2.py +26 -24
- flyte/_protos/workflow/queue_service_pb2.pyi +6 -4
- flyte/_protos/workflow/run_definition_pb2.py +50 -48
- flyte/_protos/workflow/run_definition_pb2.pyi +41 -16
- flyte/_protos/workflow/task_definition_pb2.py +16 -13
- flyte/_protos/workflow/task_definition_pb2.pyi +7 -0
- flyte/_task.py +6 -6
- flyte/_task_environment.py +4 -4
- flyte/_version.py +3 -3
- flyte/cli/_build.py +2 -3
- flyte/cli/_run.py +11 -12
- flyte/models.py +2 -0
- flyte/remote/_action.py +5 -2
- flyte/remote/_client/auth/_authenticators/device_code.py +1 -1
- flyte/remote/_client/auth/_authenticators/pkce.py +1 -1
- flyte/remote/_task.py +4 -4
- flyte-2.0.0b15.data/scripts/debug.py +38 -0
- {flyte-2.0.0b13.data → flyte-2.0.0b15.data}/scripts/runtime.py +13 -0
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b15.dist-info}/METADATA +2 -2
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b15.dist-info}/RECORD +45 -39
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b15.dist-info}/WHEEL +0 -0
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b15.dist-info}/entry_points.txt +0 -0
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b15.dist-info}/licenses/LICENSE +0 -0
- {flyte-2.0.0b13.dist-info → flyte-2.0.0b15.dist-info}/top_level.txt +0 -0
|
@@ -13,6 +13,7 @@ _sym_db = _symbol_database.Default()
|
|
|
13
13
|
|
|
14
14
|
from flyte._protos.common import identifier_pb2 as common_dot_identifier__pb2
|
|
15
15
|
from flyte._protos.common import identity_pb2 as common_dot_identity__pb2
|
|
16
|
+
from flyteidl.core import interface_pb2 as flyteidl_dot_core_dot_interface__pb2
|
|
16
17
|
from flyteidl.core import tasks_pb2 as flyteidl_dot_core_dot_tasks__pb2
|
|
17
18
|
from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2
|
|
18
19
|
from flyte._protos.validate.validate import validate_pb2 as validate_dot_validate__pb2
|
|
@@ -20,7 +21,7 @@ from flyte._protos.workflow import common_pb2 as workflow_dot_common__pb2
|
|
|
20
21
|
from flyte._protos.workflow import environment_pb2 as workflow_dot_environment__pb2
|
|
21
22
|
|
|
22
23
|
|
|
23
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1eworkflow/task_definition.proto\x12\x11\x63loudidl.workflow\x1a\x17\x63ommon/identifier.proto\x1a\x15\x63ommon/identity.proto\x1a\x19\x66lyteidl/core/tasks.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x17validate/validate.proto\x1a\x15workflow/common.proto\x1a\x1aworkflow/environment.proto\"\x8f\x01\n\x08TaskName\x12\x1b\n\x03org\x18\x01 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x03org\x12#\n\x07project\x18\x02 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x07project\x12!\n\x06\x64omain\x18\x03 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x06\x64omain\x12\x1e\n\x04name\x18\x04 \x01(\tB\n\xfa\x42\x07r\x05\x10\x01\x18\xff\x01R\x04name\"\xba\x01\n\x0eTaskIdentifier\x12\x1b\n\x03org\x18\x01 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x03org\x12#\n\x07project\x18\x02 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x07project\x12!\n\x06\x64omain\x18\x03 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x06\x64omain\x12\x1e\n\x04name\x18\x04 \x01(\tB\n\xfa\x42\x07r\x05\x10\x01\x18\xff\x01R\x04name\x12#\n\x07version\x18\x05 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x07version\"\xed\x01\n\x0cTaskMetadata\x12L\n\x0b\x64\x65ployed_by\x18\x01 \x01(\x0b\x32!.cloudidl.common.EnrichedIdentityB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\ndeployedBy\x12\x1d\n\nshort_name\x18\x02 \x01(\tR\tshortName\x12\x45\n\x0b\x64\x65ployed_at\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.TimestampB\x08\xfa\x42\x05\xb2\x01\x02\x08\x01R\ndeployedAt\x12)\n\x10\x65nvironment_name\x18\x04 \x01(\tR\x0f\x65nvironmentName\"\x93\x01\n\x04Task\x12\x44\n\x07task_id\x18\x01 \x01(\x0b\x32!.cloudidl.workflow.TaskIdentifierB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x06taskId\x12\x45\n\x08metadata\x18\x02 \x01(\x0b\x32\x1f.cloudidl.workflow.TaskMetadataB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x08metadata\"\x8a\x02\n\x08TaskSpec\x12J\n\rtask_template\x18\x01 \x01(\x0b\x32\x1b.flyteidl.core.TaskTemplateB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x0ctaskTemplate\x12H\n\x0e\x64\x65\x66\x61ult_inputs\x18\x02 \x03(\x0b\x32!.cloudidl.workflow.NamedParameterR\rdefaultInputs\x12&\n\nshort_name\x18\x03 \x01(\tB\x07\xfa\x42\x04r\x02\x18?R\tshortName\x12@\n\x0b\x65nvironment\x18\x04 \x01(\x0b\x32\x1e.cloudidl.workflow.EnvironmentR\x0b\x65nvironment\"\xd5\x01\n\x0bTaskDetails\x12\x44\n\x07task_id\x18\x01 \x01(\x0b\x32!.cloudidl.workflow.TaskIdentifierB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x06taskId\x12\x45\n\x08metadata\x18\x02 \x01(\x0b\x32\x1f.cloudidl.workflow.TaskMetadataB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x08metadata\x12\x39\n\x04spec\x18\x03 \x01(\x0b\x32\x1b.cloudidl.workflow.TaskSpecB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x04specB\xc0\x01\n\x15\x63om.cloudidl.workflowB\x13TaskDefinitionProtoH\x02P\x01Z+github.com/unionai/cloud/gen/pb-go/workflow\xa2\x02\x03\x43WX\xaa\x02\x11\x43loudidl.Workflow\xca\x02\x11\x43loudidl\\Workflow\xe2\x02\x1d\x43loudidl\\Workflow\\GPBMetadata\xea\x02\x12\x43loudidl::Workflowb\x06proto3')
|
|
24
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1eworkflow/task_definition.proto\x12\x11\x63loudidl.workflow\x1a\x17\x63ommon/identifier.proto\x1a\x15\x63ommon/identity.proto\x1a\x1d\x66lyteidl/core/interface.proto\x1a\x19\x66lyteidl/core/tasks.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x17validate/validate.proto\x1a\x15workflow/common.proto\x1a\x1aworkflow/environment.proto\"\x8f\x01\n\x08TaskName\x12\x1b\n\x03org\x18\x01 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x03org\x12#\n\x07project\x18\x02 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x07project\x12!\n\x06\x64omain\x18\x03 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x06\x64omain\x12\x1e\n\x04name\x18\x04 \x01(\tB\n\xfa\x42\x07r\x05\x10\x01\x18\xff\x01R\x04name\"\xba\x01\n\x0eTaskIdentifier\x12\x1b\n\x03org\x18\x01 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x03org\x12#\n\x07project\x18\x02 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x07project\x12!\n\x06\x64omain\x18\x03 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x06\x64omain\x12\x1e\n\x04name\x18\x04 \x01(\tB\n\xfa\x42\x07r\x05\x10\x01\x18\xff\x01R\x04name\x12#\n\x07version\x18\x05 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x07version\"\xed\x01\n\x0cTaskMetadata\x12L\n\x0b\x64\x65ployed_by\x18\x01 \x01(\x0b\x32!.cloudidl.common.EnrichedIdentityB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\ndeployedBy\x12\x1d\n\nshort_name\x18\x02 \x01(\tR\tshortName\x12\x45\n\x0b\x64\x65ployed_at\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.TimestampB\x08\xfa\x42\x05\xb2\x01\x02\x08\x01R\ndeployedAt\x12)\n\x10\x65nvironment_name\x18\x04 \x01(\tR\x0f\x65nvironmentName\"\x93\x01\n\x04Task\x12\x44\n\x07task_id\x18\x01 \x01(\x0b\x32!.cloudidl.workflow.TaskIdentifierB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x06taskId\x12\x45\n\x08metadata\x18\x02 \x01(\x0b\x32\x1f.cloudidl.workflow.TaskMetadataB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x08metadata\"\x8a\x02\n\x08TaskSpec\x12J\n\rtask_template\x18\x01 \x01(\x0b\x32\x1b.flyteidl.core.TaskTemplateB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x0ctaskTemplate\x12H\n\x0e\x64\x65\x66\x61ult_inputs\x18\x02 \x03(\x0b\x32!.cloudidl.workflow.NamedParameterR\rdefaultInputs\x12&\n\nshort_name\x18\x03 \x01(\tB\x07\xfa\x42\x04r\x02\x18?R\tshortName\x12@\n\x0b\x65nvironment\x18\x04 \x01(\x0b\x32\x1e.cloudidl.workflow.EnvironmentR\x0b\x65nvironment\"H\n\tTraceSpec\x12;\n\tinterface\x18\x01 \x01(\x0b\x32\x1d.flyteidl.core.TypedInterfaceR\tinterface\"\xd5\x01\n\x0bTaskDetails\x12\x44\n\x07task_id\x18\x01 \x01(\x0b\x32!.cloudidl.workflow.TaskIdentifierB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x06taskId\x12\x45\n\x08metadata\x18\x02 \x01(\x0b\x32\x1f.cloudidl.workflow.TaskMetadataB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x08metadata\x12\x39\n\x04spec\x18\x03 \x01(\x0b\x32\x1b.cloudidl.workflow.TaskSpecB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x04specB\xc0\x01\n\x15\x63om.cloudidl.workflowB\x13TaskDefinitionProtoH\x02P\x01Z+github.com/unionai/cloud/gen/pb-go/workflow\xa2\x02\x03\x43WX\xaa\x02\x11\x43loudidl.Workflow\xca\x02\x11\x43loudidl\\Workflow\xe2\x02\x1d\x43loudidl\\Workflow\\GPBMetadata\xea\x02\x12\x43loudidl::Workflowb\x06proto3')
|
|
24
25
|
|
|
25
26
|
_globals = globals()
|
|
26
27
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
@@ -64,16 +65,18 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|
|
64
65
|
_TASKDETAILS.fields_by_name['metadata']._serialized_options = b'\372B\005\212\001\002\020\001'
|
|
65
66
|
_TASKDETAILS.fields_by_name['spec']._options = None
|
|
66
67
|
_TASKDETAILS.fields_by_name['spec']._serialized_options = b'\372B\005\212\001\002\020\001'
|
|
67
|
-
_globals['_TASKNAME']._serialized_start=
|
|
68
|
-
_globals['_TASKNAME']._serialized_end=
|
|
69
|
-
_globals['_TASKIDENTIFIER']._serialized_start=
|
|
70
|
-
_globals['_TASKIDENTIFIER']._serialized_end=
|
|
71
|
-
_globals['_TASKMETADATA']._serialized_start=
|
|
72
|
-
_globals['_TASKMETADATA']._serialized_end=
|
|
73
|
-
_globals['_TASK']._serialized_start=
|
|
74
|
-
_globals['_TASK']._serialized_end=
|
|
75
|
-
_globals['_TASKSPEC']._serialized_start=
|
|
76
|
-
_globals['_TASKSPEC']._serialized_end=
|
|
77
|
-
_globals['
|
|
78
|
-
_globals['
|
|
68
|
+
_globals['_TASKNAME']._serialized_start=269
|
|
69
|
+
_globals['_TASKNAME']._serialized_end=412
|
|
70
|
+
_globals['_TASKIDENTIFIER']._serialized_start=415
|
|
71
|
+
_globals['_TASKIDENTIFIER']._serialized_end=601
|
|
72
|
+
_globals['_TASKMETADATA']._serialized_start=604
|
|
73
|
+
_globals['_TASKMETADATA']._serialized_end=841
|
|
74
|
+
_globals['_TASK']._serialized_start=844
|
|
75
|
+
_globals['_TASK']._serialized_end=991
|
|
76
|
+
_globals['_TASKSPEC']._serialized_start=994
|
|
77
|
+
_globals['_TASKSPEC']._serialized_end=1260
|
|
78
|
+
_globals['_TRACESPEC']._serialized_start=1262
|
|
79
|
+
_globals['_TRACESPEC']._serialized_end=1334
|
|
80
|
+
_globals['_TASKDETAILS']._serialized_start=1337
|
|
81
|
+
_globals['_TASKDETAILS']._serialized_end=1550
|
|
79
82
|
# @@protoc_insertion_point(module_scope)
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from flyte._protos.common import identifier_pb2 as _identifier_pb2
|
|
2
2
|
from flyte._protos.common import identity_pb2 as _identity_pb2
|
|
3
|
+
from flyteidl.core import interface_pb2 as _interface_pb2
|
|
3
4
|
from flyteidl.core import tasks_pb2 as _tasks_pb2
|
|
4
5
|
from google.protobuf import timestamp_pb2 as _timestamp_pb2
|
|
5
6
|
from flyte._protos.validate.validate import validate_pb2 as _validate_pb2
|
|
@@ -70,6 +71,12 @@ class TaskSpec(_message.Message):
|
|
|
70
71
|
environment: _environment_pb2.Environment
|
|
71
72
|
def __init__(self, task_template: _Optional[_Union[_tasks_pb2.TaskTemplate, _Mapping]] = ..., default_inputs: _Optional[_Iterable[_Union[_common_pb2.NamedParameter, _Mapping]]] = ..., short_name: _Optional[str] = ..., environment: _Optional[_Union[_environment_pb2.Environment, _Mapping]] = ...) -> None: ...
|
|
72
73
|
|
|
74
|
+
class TraceSpec(_message.Message):
|
|
75
|
+
__slots__ = ["interface"]
|
|
76
|
+
INTERFACE_FIELD_NUMBER: _ClassVar[int]
|
|
77
|
+
interface: _interface_pb2.TypedInterface
|
|
78
|
+
def __init__(self, interface: _Optional[_Union[_interface_pb2.TypedInterface, _Mapping]] = ...) -> None: ...
|
|
79
|
+
|
|
73
80
|
class TaskDetails(_message.Message):
|
|
74
81
|
__slots__ = ["task_id", "metadata", "spec"]
|
|
75
82
|
TASK_ID_FIELD_NUMBER: _ClassVar[int]
|
flyte/_task.py
CHANGED
|
@@ -85,7 +85,7 @@ class TaskTemplate(Generic[P, R]):
|
|
|
85
85
|
|
|
86
86
|
name: str
|
|
87
87
|
interface: NativeInterface
|
|
88
|
-
|
|
88
|
+
short_name: str = ""
|
|
89
89
|
task_type: str = "python"
|
|
90
90
|
task_type_version: int = 0
|
|
91
91
|
image: Union[str, Image, Literal["auto"]] = "auto"
|
|
@@ -129,9 +129,9 @@ class TaskTemplate(Generic[P, R]):
|
|
|
129
129
|
if isinstance(self.retries, int):
|
|
130
130
|
self.retries = RetryStrategy(count=self.retries)
|
|
131
131
|
|
|
132
|
-
if self.
|
|
133
|
-
# If
|
|
134
|
-
self.
|
|
132
|
+
if self.short_name == "":
|
|
133
|
+
# If short_name is not set, use the name of the task
|
|
134
|
+
self.short_name = self.name
|
|
135
135
|
|
|
136
136
|
def __getstate__(self):
|
|
137
137
|
"""
|
|
@@ -314,7 +314,7 @@ class TaskTemplate(Generic[P, R]):
|
|
|
314
314
|
def override(
|
|
315
315
|
self,
|
|
316
316
|
*,
|
|
317
|
-
|
|
317
|
+
short_name: Optional[str] = None,
|
|
318
318
|
resources: Optional[Resources] = None,
|
|
319
319
|
cache: Optional[CacheRequest] = None,
|
|
320
320
|
retries: Union[int, RetryStrategy] = 0,
|
|
@@ -375,7 +375,7 @@ class TaskTemplate(Generic[P, R]):
|
|
|
375
375
|
|
|
376
376
|
return replace(
|
|
377
377
|
self,
|
|
378
|
-
|
|
378
|
+
short_name=short_name or self.short_name,
|
|
379
379
|
resources=resources,
|
|
380
380
|
cache=cache,
|
|
381
381
|
retries=retries,
|
flyte/_task_environment.py
CHANGED
|
@@ -133,7 +133,7 @@ class TaskEnvironment(Environment):
|
|
|
133
133
|
self,
|
|
134
134
|
_func=None,
|
|
135
135
|
*,
|
|
136
|
-
|
|
136
|
+
short_name: Optional[str] = None,
|
|
137
137
|
cache: CacheRequest | None = None,
|
|
138
138
|
retries: Union[int, RetryStrategy] = 0,
|
|
139
139
|
timeout: Union[timedelta, int] = 0,
|
|
@@ -147,7 +147,7 @@ class TaskEnvironment(Environment):
|
|
|
147
147
|
|
|
148
148
|
:param _func: Optional The function to decorate. If not provided, the decorator will return a callable that
|
|
149
149
|
accepts a function to be decorated.
|
|
150
|
-
:param
|
|
150
|
+
:param short_name: Optional A friendly name for the task (defaults to the function name)
|
|
151
151
|
:param cache: Optional The cache policy for the task, defaults to auto, which will cache the results of the
|
|
152
152
|
task.
|
|
153
153
|
:param retries: Optional The number of retries for the task, defaults to 0, which means no retries.
|
|
@@ -166,7 +166,7 @@ class TaskEnvironment(Environment):
|
|
|
166
166
|
raise ValueError("Cannot set pod_template when environment is reusable.")
|
|
167
167
|
|
|
168
168
|
def decorator(func: FunctionTypes) -> AsyncFunctionTaskTemplate[P, R]:
|
|
169
|
-
|
|
169
|
+
short = short_name or func.__name__
|
|
170
170
|
task_name = self.name + "." + func.__name__
|
|
171
171
|
|
|
172
172
|
if not inspect.iscoroutinefunction(func) and self.reusable is not None:
|
|
@@ -207,7 +207,7 @@ class TaskEnvironment(Environment):
|
|
|
207
207
|
parent_env=weakref.ref(self),
|
|
208
208
|
interface=NativeInterface.from_callable(func),
|
|
209
209
|
report=report,
|
|
210
|
-
|
|
210
|
+
short_name=short,
|
|
211
211
|
plugin_config=self.plugin_config,
|
|
212
212
|
max_inline_io_bytes=max_inline_io_bytes,
|
|
213
213
|
)
|
flyte/_version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '2.0.
|
|
32
|
-
__version_tuple__ = version_tuple = (2, 0, 0, '
|
|
31
|
+
__version__ = version = '2.0.0b15'
|
|
32
|
+
__version_tuple__ = version_tuple = (2, 0, 0, 'b15')
|
|
33
33
|
|
|
34
|
-
__commit_id__ = commit_id = '
|
|
34
|
+
__commit_id__ = commit_id = 'gaa61639b5'
|
flyte/cli/_build.py
CHANGED
|
@@ -3,8 +3,7 @@ from pathlib import Path
|
|
|
3
3
|
from types import ModuleType
|
|
4
4
|
from typing import Any, Dict, List, cast
|
|
5
5
|
|
|
6
|
-
import click
|
|
7
|
-
from click import Context
|
|
6
|
+
import rich_click as click
|
|
8
7
|
|
|
9
8
|
import flyte
|
|
10
9
|
|
|
@@ -44,7 +43,7 @@ class BuildEnvCommand(click.Command):
|
|
|
44
43
|
self.build_args = build_args
|
|
45
44
|
super().__init__(*args, **kwargs)
|
|
46
45
|
|
|
47
|
-
def invoke(self, ctx: Context):
|
|
46
|
+
def invoke(self, ctx: click.Context):
|
|
48
47
|
from rich.console import Console
|
|
49
48
|
|
|
50
49
|
console = Console()
|
flyte/cli/_run.py
CHANGED
|
@@ -8,8 +8,7 @@ from pathlib import Path
|
|
|
8
8
|
from types import ModuleType
|
|
9
9
|
from typing import Any, Dict, List, cast
|
|
10
10
|
|
|
11
|
-
import click
|
|
12
|
-
from click import Context, Parameter
|
|
11
|
+
import rich_click as click
|
|
13
12
|
from rich.console import Console
|
|
14
13
|
from typing_extensions import get_args
|
|
15
14
|
|
|
@@ -24,7 +23,7 @@ RUN_REMOTE_CMD = "deployed-task"
|
|
|
24
23
|
|
|
25
24
|
|
|
26
25
|
@lru_cache()
|
|
27
|
-
def _initialize_config(ctx: Context, project: str, domain: str):
|
|
26
|
+
def _initialize_config(ctx: click.Context, project: str, domain: str):
|
|
28
27
|
obj: CLIConfig | None = ctx.obj
|
|
29
28
|
if obj is None:
|
|
30
29
|
import flyte.config
|
|
@@ -37,7 +36,7 @@ def _initialize_config(ctx: Context, project: str, domain: str):
|
|
|
37
36
|
|
|
38
37
|
@lru_cache()
|
|
39
38
|
def _list_tasks(
|
|
40
|
-
ctx: Context,
|
|
39
|
+
ctx: click.Context,
|
|
41
40
|
project: str,
|
|
42
41
|
domain: str,
|
|
43
42
|
by_task_name: str | None = None,
|
|
@@ -121,7 +120,7 @@ class RunTaskCommand(click.Command):
|
|
|
121
120
|
kwargs.pop("name", None)
|
|
122
121
|
super().__init__(obj_name, *args, **kwargs)
|
|
123
122
|
|
|
124
|
-
def invoke(self, ctx: Context):
|
|
123
|
+
def invoke(self, ctx: click.Context):
|
|
125
124
|
obj: CLIConfig = _initialize_config(ctx, self.run_args.project, self.run_args.domain)
|
|
126
125
|
|
|
127
126
|
async def _run():
|
|
@@ -152,7 +151,7 @@ class RunTaskCommand(click.Command):
|
|
|
152
151
|
|
|
153
152
|
asyncio.run(_run())
|
|
154
153
|
|
|
155
|
-
def get_params(self, ctx: Context) -> List[Parameter]:
|
|
154
|
+
def get_params(self, ctx: click.Context) -> List[click.Parameter]:
|
|
156
155
|
# Note this function may be called multiple times by click.
|
|
157
156
|
task = self.obj
|
|
158
157
|
from .._internal.runtime.types_serde import transform_native_to_typed_interface
|
|
@@ -162,7 +161,7 @@ class RunTaskCommand(click.Command):
|
|
|
162
161
|
return super().get_params(ctx)
|
|
163
162
|
inputs_interface = task.native_interface.inputs
|
|
164
163
|
|
|
165
|
-
params: List[Parameter] = []
|
|
164
|
+
params: List[click.Parameter] = []
|
|
166
165
|
for name, var in interface.inputs.variables.items():
|
|
167
166
|
default_val = None
|
|
168
167
|
if inputs_interface[name][1] is not inspect._empty:
|
|
@@ -239,7 +238,7 @@ class RunReferenceTaskCommand(click.Command):
|
|
|
239
238
|
|
|
240
239
|
asyncio.run(_run())
|
|
241
240
|
|
|
242
|
-
def get_params(self, ctx: Context) -> List[Parameter]:
|
|
241
|
+
def get_params(self, ctx: click.Context) -> List[click.Parameter]:
|
|
243
242
|
# Note this function may be called multiple times by click.
|
|
244
243
|
import flyte.remote
|
|
245
244
|
from flyte._internal.runtime.types_serde import transform_native_to_typed_interface
|
|
@@ -254,7 +253,7 @@ class RunReferenceTaskCommand(click.Command):
|
|
|
254
253
|
return super().get_params(ctx)
|
|
255
254
|
inputs_interface = task_details.interface.inputs
|
|
256
255
|
|
|
257
|
-
params: List[Parameter] = []
|
|
256
|
+
params: List[click.Parameter] = []
|
|
258
257
|
for name, var in interface.inputs.variables.items():
|
|
259
258
|
default_val = None
|
|
260
259
|
if inputs_interface[name][1] is not inspect._empty:
|
|
@@ -322,7 +321,6 @@ class ReferenceTaskGroup(common.GroupBase):
|
|
|
322
321
|
|
|
323
322
|
def get_command(self, ctx, name):
|
|
324
323
|
env, task, version = self._parse_task_name(name)
|
|
325
|
-
|
|
326
324
|
match env, task, version:
|
|
327
325
|
case env, None, None:
|
|
328
326
|
if self._env_is_task(ctx, env):
|
|
@@ -383,10 +381,11 @@ class TaskFiles(common.FileGroup):
|
|
|
383
381
|
super().__init__(*args, directory=directory, **kwargs)
|
|
384
382
|
|
|
385
383
|
def list_commands(self, ctx):
|
|
386
|
-
|
|
384
|
+
v = [
|
|
387
385
|
RUN_REMOTE_CMD,
|
|
388
|
-
*
|
|
386
|
+
*super().list_commands(ctx),
|
|
389
387
|
]
|
|
388
|
+
return v
|
|
390
389
|
|
|
391
390
|
def get_command(self, ctx, cmd_name):
|
|
392
391
|
run_args = RunArguments.from_dict(ctx.params)
|
flyte/models.py
CHANGED
|
@@ -166,6 +166,7 @@ class TaskContext:
|
|
|
166
166
|
action: ActionID
|
|
167
167
|
version: str
|
|
168
168
|
raw_data_path: RawDataPath
|
|
169
|
+
input_path: str | None = None
|
|
169
170
|
output_path: str
|
|
170
171
|
run_base_dir: str
|
|
171
172
|
report: Report
|
|
@@ -175,6 +176,7 @@ class TaskContext:
|
|
|
175
176
|
compiled_image_cache: ImageCache | None = None
|
|
176
177
|
data: Dict[str, Any] = field(default_factory=dict)
|
|
177
178
|
mode: Literal["local", "remote", "hybrid"] = "remote"
|
|
179
|
+
interactive_mode: bool = False
|
|
178
180
|
|
|
179
181
|
def replace(self, **kwargs) -> TaskContext:
|
|
180
182
|
if "data" in kwargs:
|
flyte/remote/_action.py
CHANGED
|
@@ -627,8 +627,11 @@ class ActionDetails(ToJSONMixin):
|
|
|
627
627
|
)
|
|
628
628
|
)
|
|
629
629
|
native_iface = None
|
|
630
|
-
if self.pb2.
|
|
631
|
-
iface = self.pb2.
|
|
630
|
+
if self.pb2.HasField('task'):
|
|
631
|
+
iface = self.pb2.task.task_template.interface
|
|
632
|
+
native_iface = types.guess_interface(iface)
|
|
633
|
+
elif self.pb2.HasField('trace'):
|
|
634
|
+
iface = self.pb2.trace.interface
|
|
632
635
|
native_iface = types.guess_interface(iface)
|
|
633
636
|
|
|
634
637
|
if resp.inputs:
|
|
@@ -81,7 +81,7 @@ class DeviceCodeAuthenticator(Authenticator):
|
|
|
81
81
|
for_endpoint=self._endpoint,
|
|
82
82
|
)
|
|
83
83
|
except (AuthenticationError, AuthenticationPending):
|
|
84
|
-
logger.warning("
|
|
84
|
+
logger.warning("Logging in...")
|
|
85
85
|
|
|
86
86
|
"""Fall back to device flow"""
|
|
87
87
|
resp = await token_client.get_device_code(
|
|
@@ -123,7 +123,7 @@ class PKCEAuthenticator(Authenticator):
|
|
|
123
123
|
try:
|
|
124
124
|
return await self._auth_client.refresh_access_token(self._creds)
|
|
125
125
|
except AccessTokenNotFoundError:
|
|
126
|
-
logger.warning("
|
|
126
|
+
logger.warning("Logging in...")
|
|
127
127
|
|
|
128
128
|
return await self._auth_client.get_creds_from_remote()
|
|
129
129
|
|
flyte/remote/_task.py
CHANGED
|
@@ -283,7 +283,7 @@ class TaskDetails(ToJSONMixin):
|
|
|
283
283
|
def override(
|
|
284
284
|
self,
|
|
285
285
|
*,
|
|
286
|
-
|
|
286
|
+
short_name: Optional[str] = None,
|
|
287
287
|
resources: Optional[flyte.Resources] = None,
|
|
288
288
|
retries: Union[int, flyte.RetryStrategy] = 0,
|
|
289
289
|
timeout: Optional[flyte.TimeoutType] = None,
|
|
@@ -297,8 +297,8 @@ class TaskDetails(ToJSONMixin):
|
|
|
297
297
|
f"Check the parameters for override method."
|
|
298
298
|
)
|
|
299
299
|
template = self.pb2.spec.task_template
|
|
300
|
-
if
|
|
301
|
-
self.pb2.metadata.short_name =
|
|
300
|
+
if short_name:
|
|
301
|
+
self.pb2.metadata.short_name = short_name
|
|
302
302
|
if secrets:
|
|
303
303
|
template.security_context.CopyFrom(get_security_context(secrets))
|
|
304
304
|
if template.HasField("container"):
|
|
@@ -318,7 +318,7 @@ class TaskDetails(ToJSONMixin):
|
|
|
318
318
|
"""
|
|
319
319
|
Rich representation of the task.
|
|
320
320
|
"""
|
|
321
|
-
yield "
|
|
321
|
+
yield "short_name", self.pb2.spec.short_name
|
|
322
322
|
yield "environment", self.pb2.spec.environment
|
|
323
323
|
yield "default_inputs_keys", self.default_input_args
|
|
324
324
|
yield "required_args", self.required_args
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
import click
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
@click.group()
|
|
5
|
+
def _debug():
|
|
6
|
+
"""Debug commands for Flyte."""
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@_debug.command("resume")
|
|
10
|
+
@click.option("--pid", "-m", type=int, required=True, help="PID of the vscode server.")
|
|
11
|
+
def resume(pid):
|
|
12
|
+
"""
|
|
13
|
+
Resume a Flyte task for debugging purposes.
|
|
14
|
+
|
|
15
|
+
Args:
|
|
16
|
+
pid (int): PID of the vscode server.
|
|
17
|
+
"""
|
|
18
|
+
import os
|
|
19
|
+
import signal
|
|
20
|
+
|
|
21
|
+
print("Terminating server and resuming task.")
|
|
22
|
+
answer = (
|
|
23
|
+
input(
|
|
24
|
+
"This operation will kill the server. All unsaved data will be lost,"
|
|
25
|
+
" and you will no longer be able to connect to it. Do you really want to terminate? (Y/N): "
|
|
26
|
+
)
|
|
27
|
+
.strip()
|
|
28
|
+
.upper()
|
|
29
|
+
)
|
|
30
|
+
if answer == "Y":
|
|
31
|
+
os.kill(pid, signal.SIGTERM)
|
|
32
|
+
print("The server has been terminated and the task has been resumed.")
|
|
33
|
+
else:
|
|
34
|
+
print("Operation canceled.")
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
if __name__ == "__main__":
|
|
38
|
+
_debug()
|
|
@@ -26,6 +26,7 @@ DOMAIN_NAME = "FLYTE_INTERNAL_TASK_DOMAIN"
|
|
|
26
26
|
ORG_NAME = "_U_ORG_NAME"
|
|
27
27
|
ENDPOINT_OVERRIDE = "_U_EP_OVERRIDE"
|
|
28
28
|
RUN_OUTPUT_BASE_DIR = "_U_RUN_BASE"
|
|
29
|
+
FLYTE_ENABLE_VSCODE_KEY = "_F_E_VS"
|
|
29
30
|
|
|
30
31
|
# TODO: Remove this after proper auth is implemented
|
|
31
32
|
_UNION_EAGER_API_KEY_ENV_VAR = "_UNION_EAGER_API_KEY"
|
|
@@ -49,6 +50,8 @@ def _pass_through():
|
|
|
49
50
|
@click.option("--project", envvar=PROJECT_NAME, required=False)
|
|
50
51
|
@click.option("--domain", envvar=DOMAIN_NAME, required=False)
|
|
51
52
|
@click.option("--org", envvar=ORG_NAME, required=False)
|
|
53
|
+
@click.option("--debug", envvar=FLYTE_ENABLE_VSCODE_KEY, type=click.BOOL, required=False)
|
|
54
|
+
@click.option("--interactive-mode", type=click.BOOL, required=False)
|
|
52
55
|
@click.option("--image-cache", required=False)
|
|
53
56
|
@click.option("--tgz", required=False)
|
|
54
57
|
@click.option("--pkl", required=False)
|
|
@@ -59,12 +62,16 @@ def _pass_through():
|
|
|
59
62
|
type=click.UNPROCESSED,
|
|
60
63
|
nargs=-1,
|
|
61
64
|
)
|
|
65
|
+
@click.pass_context
|
|
62
66
|
def main(
|
|
67
|
+
ctx: click.Context,
|
|
63
68
|
run_name: str,
|
|
64
69
|
name: str,
|
|
65
70
|
project: str,
|
|
66
71
|
domain: str,
|
|
67
72
|
org: str,
|
|
73
|
+
debug: bool,
|
|
74
|
+
interactive_mode: bool,
|
|
68
75
|
image_cache: str,
|
|
69
76
|
version: str,
|
|
70
77
|
inputs: str,
|
|
@@ -109,6 +116,11 @@ def main(
|
|
|
109
116
|
if name.startswith("{{"):
|
|
110
117
|
name = os.getenv("ACTION_NAME", "")
|
|
111
118
|
|
|
119
|
+
if debug and name == "a0":
|
|
120
|
+
from flyte._debug.vscode import _start_vscode_server
|
|
121
|
+
|
|
122
|
+
asyncio.run(_start_vscode_server(ctx))
|
|
123
|
+
|
|
112
124
|
# Figure out how to connect
|
|
113
125
|
# This detection of api key is a hack for now.
|
|
114
126
|
controller_kwargs: dict[str, Any] = {"insecure": False}
|
|
@@ -143,6 +155,7 @@ def main(
|
|
|
143
155
|
version=version,
|
|
144
156
|
controller=controller,
|
|
145
157
|
image_cache=ic,
|
|
158
|
+
interactive_mode=interactive_mode or debug,
|
|
146
159
|
)
|
|
147
160
|
# Create a coroutine to watch for errors
|
|
148
161
|
controller_failure = controller.watch_for_errors()
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: flyte
|
|
3
|
-
Version: 2.0.
|
|
3
|
+
Version: 2.0.0b15
|
|
4
4
|
Summary: Add your description here
|
|
5
5
|
Author-email: Ketan Umare <kumare3@users.noreply.github.com>
|
|
6
6
|
Requires-Python: >=3.10
|
|
@@ -16,7 +16,7 @@ Requires-Dist: obstore>=0.7.3
|
|
|
16
16
|
Requires-Dist: protobuf>=6.30.1
|
|
17
17
|
Requires-Dist: pydantic>=2.10.6
|
|
18
18
|
Requires-Dist: pyyaml>=6.0.2
|
|
19
|
-
Requires-Dist: rich-click
|
|
19
|
+
Requires-Dist: rich-click==1.8.9
|
|
20
20
|
Requires-Dist: httpx<1.0.0,>=0.28.1
|
|
21
21
|
Requires-Dist: keyring>=25.6.0
|
|
22
22
|
Requires-Dist: msgpack>=1.1.0
|