flyte 0.2.0b35__py3-none-any.whl → 0.2.0b37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic; consult the package registry page for more details.

Files changed (39):
  1. flyte/_image.py +1 -1
  2. flyte/_internal/controllers/_local_controller.py +3 -2
  3. flyte/_internal/controllers/_trace.py +14 -10
  4. flyte/_internal/controllers/remote/_action.py +37 -7
  5. flyte/_internal/controllers/remote/_controller.py +43 -21
  6. flyte/_internal/controllers/remote/_core.py +32 -16
  7. flyte/_internal/controllers/remote/_informer.py +18 -7
  8. flyte/_internal/runtime/task_serde.py +17 -6
  9. flyte/_protos/common/identifier_pb2.py +23 -1
  10. flyte/_protos/common/identifier_pb2.pyi +28 -0
  11. flyte/_protos/workflow/queue_service_pb2.py +33 -29
  12. flyte/_protos/workflow/queue_service_pb2.pyi +34 -16
  13. flyte/_protos/workflow/run_definition_pb2.py +64 -71
  14. flyte/_protos/workflow/run_definition_pb2.pyi +44 -31
  15. flyte/_protos/workflow/run_logs_service_pb2.py +10 -10
  16. flyte/_protos/workflow/run_logs_service_pb2.pyi +3 -3
  17. flyte/_protos/workflow/run_service_pb2.py +54 -46
  18. flyte/_protos/workflow/run_service_pb2.pyi +32 -18
  19. flyte/_protos/workflow/run_service_pb2_grpc.py +34 -0
  20. flyte/_protos/workflow/state_service_pb2.py +20 -19
  21. flyte/_protos/workflow/state_service_pb2.pyi +13 -12
  22. flyte/_run.py +11 -6
  23. flyte/_trace.py +4 -10
  24. flyte/_version.py +2 -2
  25. flyte/migrate/__init__.py +1 -0
  26. flyte/migrate/dynamic.py +13 -0
  27. flyte/migrate/task.py +99 -0
  28. flyte/migrate/workflow.py +13 -0
  29. flyte/remote/_action.py +56 -25
  30. flyte/remote/_logs.py +4 -3
  31. flyte/remote/_run.py +5 -4
  32. flyte-0.2.0b37.dist-info/METADATA +371 -0
  33. {flyte-0.2.0b35.dist-info → flyte-0.2.0b37.dist-info}/RECORD +38 -33
  34. flyte-0.2.0b37.dist-info/licenses/LICENSE +201 -0
  35. flyte-0.2.0b35.dist-info/METADATA +0 -249
  36. {flyte-0.2.0b35.data → flyte-0.2.0b37.data}/scripts/runtime.py +0 -0
  37. {flyte-0.2.0b35.dist-info → flyte-0.2.0b37.dist-info}/WHEEL +0 -0
  38. {flyte-0.2.0b35.dist-info → flyte-0.2.0b37.dist-info}/entry_points.txt +0 -0
  39. {flyte-0.2.0b35.dist-info → flyte-0.2.0b37.dist-info}/top_level.txt +0 -0
flyte/_protos/workflow/state_service_pb2.pyi CHANGED
@@ -1,3 +1,4 @@
1
+ from flyte._protos.common import identifier_pb2 as _identifier_pb2
1
2
  from flyteidl.core import execution_pb2 as _execution_pb2
2
3
  from google.rpc import status_pb2 as _status_pb2
3
4
  from flyte._protos.validate.validate import validate_pb2 as _validate_pb2
@@ -13,40 +14,40 @@ class PutRequest(_message.Message):
13
14
  ACTION_ID_FIELD_NUMBER: _ClassVar[int]
14
15
  PARENT_ACTION_NAME_FIELD_NUMBER: _ClassVar[int]
15
16
  STATE_FIELD_NUMBER: _ClassVar[int]
16
- action_id: _run_definition_pb2.ActionIdentifier
17
+ action_id: _identifier_pb2.ActionIdentifier
17
18
  parent_action_name: str
18
19
  state: str
19
- def __init__(self, action_id: _Optional[_Union[_run_definition_pb2.ActionIdentifier, _Mapping]] = ..., parent_action_name: _Optional[str] = ..., state: _Optional[str] = ...) -> None: ...
20
+ def __init__(self, action_id: _Optional[_Union[_identifier_pb2.ActionIdentifier, _Mapping]] = ..., parent_action_name: _Optional[str] = ..., state: _Optional[str] = ...) -> None: ...
20
21
 
21
22
  class PutResponse(_message.Message):
22
23
  __slots__ = ["action_id", "status"]
23
24
  ACTION_ID_FIELD_NUMBER: _ClassVar[int]
24
25
  STATUS_FIELD_NUMBER: _ClassVar[int]
25
- action_id: _run_definition_pb2.ActionIdentifier
26
+ action_id: _identifier_pb2.ActionIdentifier
26
27
  status: _status_pb2.Status
27
- def __init__(self, action_id: _Optional[_Union[_run_definition_pb2.ActionIdentifier, _Mapping]] = ..., status: _Optional[_Union[_status_pb2.Status, _Mapping]] = ...) -> None: ...
28
+ def __init__(self, action_id: _Optional[_Union[_identifier_pb2.ActionIdentifier, _Mapping]] = ..., status: _Optional[_Union[_status_pb2.Status, _Mapping]] = ...) -> None: ...
28
29
 
29
30
  class GetRequest(_message.Message):
30
31
  __slots__ = ["action_id"]
31
32
  ACTION_ID_FIELD_NUMBER: _ClassVar[int]
32
- action_id: _run_definition_pb2.ActionIdentifier
33
- def __init__(self, action_id: _Optional[_Union[_run_definition_pb2.ActionIdentifier, _Mapping]] = ...) -> None: ...
33
+ action_id: _identifier_pb2.ActionIdentifier
34
+ def __init__(self, action_id: _Optional[_Union[_identifier_pb2.ActionIdentifier, _Mapping]] = ...) -> None: ...
34
35
 
35
36
  class GetResponse(_message.Message):
36
37
  __slots__ = ["action_id", "status", "state"]
37
38
  ACTION_ID_FIELD_NUMBER: _ClassVar[int]
38
39
  STATUS_FIELD_NUMBER: _ClassVar[int]
39
40
  STATE_FIELD_NUMBER: _ClassVar[int]
40
- action_id: _run_definition_pb2.ActionIdentifier
41
+ action_id: _identifier_pb2.ActionIdentifier
41
42
  status: _status_pb2.Status
42
43
  state: str
43
- def __init__(self, action_id: _Optional[_Union[_run_definition_pb2.ActionIdentifier, _Mapping]] = ..., status: _Optional[_Union[_status_pb2.Status, _Mapping]] = ..., state: _Optional[str] = ...) -> None: ...
44
+ def __init__(self, action_id: _Optional[_Union[_identifier_pb2.ActionIdentifier, _Mapping]] = ..., status: _Optional[_Union[_status_pb2.Status, _Mapping]] = ..., state: _Optional[str] = ...) -> None: ...
44
45
 
45
46
  class WatchRequest(_message.Message):
46
47
  __slots__ = ["parent_action_id"]
47
48
  PARENT_ACTION_ID_FIELD_NUMBER: _ClassVar[int]
48
- parent_action_id: _run_definition_pb2.ActionIdentifier
49
- def __init__(self, parent_action_id: _Optional[_Union[_run_definition_pb2.ActionIdentifier, _Mapping]] = ...) -> None: ...
49
+ parent_action_id: _identifier_pb2.ActionIdentifier
50
+ def __init__(self, parent_action_id: _Optional[_Union[_identifier_pb2.ActionIdentifier, _Mapping]] = ...) -> None: ...
50
51
 
51
52
  class WatchResponse(_message.Message):
52
53
  __slots__ = ["action_update", "control_message"]
@@ -68,8 +69,8 @@ class ActionUpdate(_message.Message):
68
69
  PHASE_FIELD_NUMBER: _ClassVar[int]
69
70
  ERROR_FIELD_NUMBER: _ClassVar[int]
70
71
  OUTPUT_URI_FIELD_NUMBER: _ClassVar[int]
71
- action_id: _run_definition_pb2.ActionIdentifier
72
+ action_id: _identifier_pb2.ActionIdentifier
72
73
  phase: _run_definition_pb2.Phase
73
74
  error: _execution_pb2.ExecutionError
74
75
  output_uri: str
75
- def __init__(self, action_id: _Optional[_Union[_run_definition_pb2.ActionIdentifier, _Mapping]] = ..., phase: _Optional[_Union[_run_definition_pb2.Phase, str]] = ..., error: _Optional[_Union[_execution_pb2.ExecutionError, _Mapping]] = ..., output_uri: _Optional[str] = ...) -> None: ...
76
+ def __init__(self, action_id: _Optional[_Union[_identifier_pb2.ActionIdentifier, _Mapping]] = ..., phase: _Optional[_Union[_run_definition_pb2.Phase, str]] = ..., error: _Optional[_Union[_execution_pb2.ExecutionError, _Mapping]] = ..., output_uri: _Optional[str] = ...) -> None: ...
flyte/_run.py CHANGED
@@ -17,6 +17,7 @@ from flyte._initialize import (
17
17
  requires_storage,
18
18
  )
19
19
  from flyte._logging import logger
20
+ from flyte._protos.common import identifier_pb2
20
21
  from flyte._task import P, R, TaskTemplate
21
22
  from flyte._tools import ipython_check
22
23
  from flyte.errors import InitializationError
@@ -185,7 +186,7 @@ class _Runner:
185
186
  run_id = None
186
187
  project_id = None
187
188
  if self._name:
188
- run_id = run_definition_pb2.RunIdentifier(
189
+ run_id = identifier_pb2.RunIdentifier(
189
190
  project=project,
190
191
  domain=domain,
191
192
  org=cfg.org,
@@ -260,9 +261,9 @@ class _Runner:
260
261
  super().__init__(
261
262
  pb2=run_definition_pb2.Run(
262
263
  action=run_definition_pb2.Action(
263
- id=run_definition_pb2.ActionIdentifier(
264
+ id=identifier_pb2.ActionIdentifier(
264
265
  name="a0",
265
- run=run_definition_pb2.RunIdentifier(name="dry-run"),
266
+ run=identifier_pb2.RunIdentifier(name="dry-run"),
266
267
  )
267
268
  )
268
269
  )
@@ -422,9 +423,9 @@ class _Runner:
422
423
  super().__init__(
423
424
  pb2=run_definition_pb2.Run(
424
425
  action=run_definition_pb2.Action(
425
- id=run_definition_pb2.ActionIdentifier(
426
+ id=identifier_pb2.ActionIdentifier(
426
427
  name="a0",
427
- run=run_definition_pb2.RunIdentifier(name="dry-run"),
428
+ run=identifier_pb2.RunIdentifier(name="dry-run"),
428
429
  )
429
430
  )
430
431
  )
@@ -434,7 +435,11 @@ class _Runner:
434
435
  def url(self) -> str:
435
436
  return "local-run"
436
437
 
437
- def wait(self, quiet: bool = False, wait_for: Literal["terminal", "running"] = "terminal"):
438
+ def wait(
439
+ self,
440
+ quiet: bool = False,
441
+ wait_for: Literal["terminal", "running"] = "terminal",
442
+ ):
438
443
  pass
439
444
 
440
445
  def outputs(self) -> R:
flyte/_trace.py CHANGED
@@ -1,7 +1,6 @@
1
1
  import functools
2
2
  import inspect
3
3
  import time
4
- from datetime import timedelta
5
4
  from typing import Any, AsyncGenerator, AsyncIterator, Awaitable, Callable, TypeGuard, TypeVar, Union, cast
6
5
 
7
6
  from flyte.models import NativeInterface
@@ -43,14 +42,12 @@ def trace(func: Callable[..., T]) -> Callable[..., T]:
43
42
  # Cast to Awaitable to satisfy mypy
44
43
  coroutine_result = cast(Awaitable[Any], func(*args, **kwargs))
45
44
  results = await coroutine_result
46
- duration = time.time() - start_time
47
- info.add_outputs(results, timedelta(seconds=duration))
45
+ info.add_outputs(results, start_time=start_time, end_time=time.time())
48
46
  await controller.record_trace(info)
49
47
  return results
50
48
  except Exception as e:
51
49
  # If there is an error, we need to record it
52
- duration = time.time() - start_time
53
- info.add_error(e, timedelta(seconds=duration))
50
+ info.add_error(e, start_time=start_time, end_time=time.time())
54
51
  await controller.record_trace(info)
55
52
  raise e
56
53
  else:
@@ -93,14 +90,11 @@ def trace(func: Callable[..., T]) -> Callable[..., T]:
93
90
  async for item in async_iter:
94
91
  items.append(item)
95
92
  yield item
96
- duration = time.time() - start_time
97
- info.add_outputs(items, timedelta(seconds=duration))
93
+ info.add_outputs(items, start_time=start_time, end_time=time.time())
98
94
  await controller.record_trace(info)
99
95
  return
100
96
  except Exception as e:
101
- end_time = time.time()
102
- duration = end_time - start_time
103
- info.add_error(e, timedelta(seconds=duration))
97
+ info.add_error(e, start_time=start_time, end_time=time.time())
104
98
  await controller.record_trace(info)
105
99
  raise e
106
100
  else:
flyte/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.2.0b35'
21
- __version_tuple__ = version_tuple = (0, 2, 0, 'b35')
20
+ __version__ = version = '0.2.0b37'
21
+ __version_tuple__ = version_tuple = (0, 2, 0, 'b37')
flyte/migrate/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from flyte.migrate import dynamic, task, workflow # noqa: F401
flyte/migrate/dynamic.py ADDED
@@ -0,0 +1,13 @@
1
+ from typing import Callable, Union
2
+
3
+ import flytekit
4
+
5
+ import flyte.migrate
6
+ from flyte._task import AsyncFunctionTaskTemplate, P, R
7
+
8
+
9
+ def dynamic_shim(**kwargs) -> Union[AsyncFunctionTaskTemplate, Callable[P, R]]:
10
+ return flyte.migrate.task.task_shim(**kwargs)
11
+
12
+
13
+ flytekit.dynamic = dynamic_shim
flyte/migrate/task.py ADDED
@@ -0,0 +1,99 @@
1
+ import datetime
2
+ from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
3
+
4
+ import flytekit
5
+ from flytekit.core import launch_plan, workflow
6
+ from flytekit.core.base_task import T, TaskResolverMixin
7
+ from flytekit.core.python_function_task import PythonFunctionTask
8
+ from flytekit.core.task import FuncOut
9
+ from flytekit.deck import DeckField
10
+ from flytekit.extras.accelerators import BaseAccelerator
11
+
12
+ import flyte
13
+ from flyte import Image, Resources, TaskEnvironment
14
+ from flyte._doc import Documentation
15
+ from flyte._task import AsyncFunctionTaskTemplate, P, R
16
+
17
+
18
+ def task_shim(
19
+ _task_function: Optional[Callable[P, FuncOut]] = None,
20
+ task_config: Optional[T] = None,
21
+ cache: Union[bool, flytekit.Cache] = False,
22
+ retries: int = 0,
23
+ interruptible: Optional[bool] = None,
24
+ deprecated: str = "",
25
+ timeout: Union[datetime.timedelta, int] = 0,
26
+ container_image: Optional[Union[str, flytekit.ImageSpec]] = None,
27
+ environment: Optional[Dict[str, str]] = None,
28
+ requests: Optional[flytekit.Resources] = None,
29
+ limits: Optional[flytekit.Resources] = None,
30
+ secret_requests: Optional[List[flytekit.Secret]] = None,
31
+ execution_mode: PythonFunctionTask.ExecutionBehavior = PythonFunctionTask.ExecutionBehavior.DEFAULT,
32
+ node_dependency_hints: Optional[
33
+ Iterable[
34
+ Union[
35
+ flytekit.PythonFunctionTask,
36
+ launch_plan.LaunchPlan,
37
+ workflow.WorkflowBase,
38
+ ]
39
+ ]
40
+ ] = None,
41
+ task_resolver: Optional[TaskResolverMixin] = None,
42
+ docs: Optional[flytekit.Documentation] = None,
43
+ disable_deck: Optional[bool] = None,
44
+ enable_deck: Optional[bool] = None,
45
+ deck_fields: Optional[Tuple[DeckField, ...]] = (
46
+ DeckField.SOURCE_CODE,
47
+ DeckField.DEPENDENCIES,
48
+ DeckField.TIMELINE,
49
+ DeckField.INPUT,
50
+ DeckField.OUTPUT,
51
+ ),
52
+ pod_template: Optional[flytekit.PodTemplate] = None,
53
+ pod_template_name: Optional[str] = None,
54
+ accelerator: Optional[BaseAccelerator] = None,
55
+ pickle_untyped: bool = False,
56
+ shared_memory: Optional[Union[Literal[True], str]] = None,
57
+ resources: Optional[Resources] = None,
58
+ labels: Optional[dict[str, str]] = None,
59
+ annotations: Optional[dict[str, str]] = None,
60
+ **kwargs,
61
+ ) -> Union[AsyncFunctionTaskTemplate, Callable[P, R]]:
62
+ plugin_config = task_config
63
+ pod_template = (
64
+ flyte.PodTemplate(
65
+ pod_spec=pod_template.pod_spec,
66
+ primary_container_name=pod_template.primary_container_name,
67
+ labels=pod_template.labels,
68
+ annotations=pod_template.annotations,
69
+ )
70
+ if pod_template
71
+ else None
72
+ )
73
+
74
+ if isinstance(container_image, flytekit.ImageSpec):
75
+ image = Image.from_debian_base()
76
+ if container_image.apt_packages:
77
+ image = image.with_apt_packages(*container_image.apt_packages)
78
+ pip_packages = ["flytekit"]
79
+ if container_image.packages:
80
+ pip_packages.extend(container_image.packages)
81
+ image = image.with_pip_packages(*pip_packages)
82
+ elif isinstance(container_image, str):
83
+ image = Image.from_base(container_image).with_pip_packages("flyte")
84
+ else:
85
+ image = Image.from_debian_base().with_pip_packages("flytekit")
86
+
87
+ docs = Documentation(description=docs.short_description) if docs else None
88
+
89
+ env = TaskEnvironment(
90
+ name="flytekit",
91
+ resources=Resources(cpu=0.8, memory="800Mi"),
92
+ image=image,
93
+ cache="enabled" if cache else "disable",
94
+ plugin_config=plugin_config,
95
+ )
96
+ return env.task(retries=retries, pod_template=pod_template_name or pod_template, docs=docs)
97
+
98
+
99
+ flytekit.task = task_shim
flyte/migrate/workflow.py ADDED
@@ -0,0 +1,13 @@
1
+ import flytekit
2
+
3
+ from flyte import Image, Resources, TaskEnvironment
4
+
5
+ env = TaskEnvironment(
6
+ name="flytekit",
7
+ resources=Resources(cpu=0.8, memory="800Mi"),
8
+ image=Image.from_debian_base().with_apt_packages("vim").with_pip_packages("flytekit", "pandas"),
9
+ )
10
+
11
+ # TODO: Build subtask's image
12
+
13
+ flytekit.workflow = env.task
flyte/remote/_action.py CHANGED
@@ -4,18 +4,29 @@ import asyncio
4
4
  from collections import UserDict
5
5
  from dataclasses import dataclass
6
6
  from datetime import datetime, timedelta, timezone
7
- from typing import Any, AsyncGenerator, AsyncIterator, Dict, Iterator, List, Literal, Tuple, Union, cast
7
+ from typing import (
8
+ Any,
9
+ AsyncGenerator,
10
+ AsyncIterator,
11
+ Dict,
12
+ Iterator,
13
+ List,
14
+ Literal,
15
+ Tuple,
16
+ Union,
17
+ cast,
18
+ )
8
19
 
9
20
  import grpc
10
21
  import rich.pretty
11
22
  import rich.repr
12
- from google.protobuf import timestamp
23
+ from google.protobuf import timestamp_pb2
13
24
  from rich.console import Console
14
25
  from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
15
26
 
16
27
  from flyte import types
17
28
  from flyte._initialize import ensure_client, get_client, get_common_config
18
- from flyte._protos.common import list_pb2
29
+ from flyte._protos.common import identifier_pb2, list_pb2
19
30
  from flyte._protos.workflow import run_definition_pb2, run_service_pb2
20
31
  from flyte._protos.workflow.run_service_pb2 import WatchActionDetailsResponse
21
32
  from flyte.remote._logs import Logs
@@ -24,11 +35,13 @@ from flyte.syncify import syncify
24
35
  WaitFor = Literal["terminal", "running", "logs-ready"]
25
36
 
26
37
 
27
- def _action_time_phase(action: run_definition_pb2.Action | run_definition_pb2.ActionDetails) -> rich.repr.Result:
38
+ def _action_time_phase(
39
+ action: run_definition_pb2.Action | run_definition_pb2.ActionDetails,
40
+ ) -> rich.repr.Result:
28
41
  """
29
42
  Rich representation of the action time and phase.
30
43
  """
31
- start_time = timestamp.to_datetime(action.status.start_time, timezone.utc)
44
+ start_time = action.status.start_time.ToDatetime().replace(tzinfo=timezone.utc)
32
45
  yield "start_time", start_time.isoformat()
33
46
  if action.status.phase in [
34
47
  run_definition_pb2.PHASE_FAILED,
@@ -36,7 +49,7 @@ def _action_time_phase(action: run_definition_pb2.Action | run_definition_pb2.Ac
36
49
  run_definition_pb2.PHASE_ABORTED,
37
50
  run_definition_pb2.PHASE_TIMED_OUT,
38
51
  ]:
39
- end_time = timestamp.to_datetime(action.status.end_time, timezone.utc)
52
+ end_time = action.status.end_time.ToDatetime().replace(tzinfo=timezone.utc)
40
53
  yield "end_time", end_time.isoformat()
41
54
  yield "run_time", f"{(end_time - start_time).seconds} secs"
42
55
  else:
@@ -46,7 +59,7 @@ def _action_time_phase(action: run_definition_pb2.Action | run_definition_pb2.Ac
46
59
  if isinstance(action, run_definition_pb2.ActionDetails):
47
60
  yield (
48
61
  "error",
49
- f"{action.error_info.kind}: {action.error_info.message}" if action.HasField("error_info") else "NA",
62
+ (f"{action.error_info.kind}: {action.error_info.message}" if action.HasField("error_info") else "NA"),
50
63
  )
51
64
 
52
65
 
@@ -57,7 +70,10 @@ def _action_rich_repr(action: run_definition_pb2.Action) -> rich.repr.Result:
57
70
  yield "run", action.id.run.name
58
71
  if action.metadata.HasField("task"):
59
72
  yield "task", action.metadata.task.id.name
60
- yield "type", "task"
73
+ yield "type", action.metadata.task.task_type
74
+ elif action.metadata.HasField("trace"):
75
+ yield "trace", action.metadata.trace.name
76
+ yield "type", "trace"
61
77
  yield "name", action.id.name
62
78
  yield from _action_time_phase(action)
63
79
  yield "group", action.metadata.group
@@ -65,14 +81,18 @@ def _action_rich_repr(action: run_definition_pb2.Action) -> rich.repr.Result:
65
81
  yield "attempts", action.status.attempts
66
82
 
67
83
 
68
- def _attempt_rich_repr(action: List[run_definition_pb2.ActionAttempt]) -> rich.repr.Result:
84
+ def _attempt_rich_repr(
85
+ action: List[run_definition_pb2.ActionAttempt],
86
+ ) -> rich.repr.Result:
69
87
  for attempt in action:
70
88
  yield "attempt", attempt.attempt
71
89
  yield "phase", run_definition_pb2.Phase.Name(attempt.phase)
72
90
  yield "logs_available", attempt.logs_available
73
91
 
74
92
 
75
- def _action_details_rich_repr(action: run_definition_pb2.ActionDetails) -> rich.repr.Result:
93
+ def _action_details_rich_repr(
94
+ action: run_definition_pb2.ActionDetails,
95
+ ) -> rich.repr.Result:
76
96
  """
77
97
  Rich representation of the action details.
78
98
  """
@@ -82,7 +102,7 @@ def _action_details_rich_repr(action: run_definition_pb2.ActionDetails) -> rich.
82
102
  yield "task_type", action.resolved_task_spec.task_template.type
83
103
  yield "task_version", action.resolved_task_spec.task_template.id.version
84
104
  yield "attempts", action.attempts
85
- yield "error", f"{action.error_info.kind}: {action.error_info.message}" if action.HasField("error_info") else "NA"
105
+ yield "error", (f"{action.error_info.kind}: {action.error_info.message}" if action.HasField("error_info") else "NA")
86
106
  yield "phase", run_definition_pb2.Phase.Name(action.status.phase)
87
107
  yield "group", action.metadata.group
88
108
  yield "parent", action.metadata.parent
@@ -129,7 +149,8 @@ class Action:
129
149
  token = None
130
150
  sort_by = sort_by or ("created_at", "asc")
131
151
  sort_pb2 = list_pb2.Sort(
132
- key=sort_by[0], direction=list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING
152
+ key=sort_by[0],
153
+ direction=(list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING),
133
154
  )
134
155
  cfg = get_common_config()
135
156
  while True:
@@ -141,7 +162,7 @@ class Action:
141
162
  resp = await get_client().run_service.ListActions(
142
163
  run_service_pb2.ListActionsRequest(
143
164
  request=req,
144
- run_id=run_definition_pb2.RunIdentifier(
165
+ run_id=identifier_pb2.RunIdentifier(
145
166
  org=cfg.org,
146
167
  project=cfg.project,
147
168
  domain=cfg.domain,
@@ -157,7 +178,13 @@ class Action:
157
178
 
158
179
  @syncify
159
180
  @classmethod
160
- async def get(cls, uri: str | None = None, /, run_name: str | None = None, name: str | None = None) -> Action:
181
+ async def get(
182
+ cls,
183
+ uri: str | None = None,
184
+ /,
185
+ run_name: str | None = None,
186
+ name: str | None = None,
187
+ ) -> Action:
161
188
  """
162
189
  Get a run by its ID or name. If both are provided, the ID will take precedence.
163
190
 
@@ -168,8 +195,8 @@ class Action:
168
195
  ensure_client()
169
196
  cfg = get_common_config()
170
197
  details: ActionDetails = await ActionDetails.get_details.aio(
171
- run_definition_pb2.ActionIdentifier(
172
- run=run_definition_pb2.RunIdentifier(
198
+ identifier_pb2.ActionIdentifier(
199
+ run=identifier_pb2.RunIdentifier(
173
200
  org=cfg.org,
174
201
  project=cfg.project,
175
202
  domain=cfg.domain,
@@ -225,7 +252,7 @@ class Action:
225
252
  return None
226
253
 
227
254
  @property
228
- def action_id(self) -> run_definition_pb2.ActionIdentifier:
255
+ def action_id(self) -> identifier_pb2.ActionIdentifier:
229
256
  """
230
257
  Get the action ID.
231
258
  """
@@ -396,7 +423,7 @@ class ActionDetails:
396
423
 
397
424
  @syncify
398
425
  @classmethod
399
- async def get_details(cls, action_id: run_definition_pb2.ActionIdentifier) -> ActionDetails:
426
+ async def get_details(cls, action_id: identifier_pb2.ActionIdentifier) -> ActionDetails:
400
427
  """
401
428
  Get the details of the action. This is a placeholder for getting the action details.
402
429
  """
@@ -411,7 +438,11 @@ class ActionDetails:
411
438
  @syncify
412
439
  @classmethod
413
440
  async def get(
414
- cls, uri: str | None = None, /, run_name: str | None = None, name: str | None = None
441
+ cls,
442
+ uri: str | None = None,
443
+ /,
444
+ run_name: str | None = None,
445
+ name: str | None = None,
415
446
  ) -> ActionDetails:
416
447
  """
417
448
  Get a run by its ID or name. If both are provided, the ID will take precedence.
@@ -425,8 +456,8 @@ class ActionDetails:
425
456
  assert name is not None and run_name is not None, "Either uri or name and run_name must be provided"
426
457
  cfg = get_common_config()
427
458
  return await cls.get_details.aio(
428
- run_definition_pb2.ActionIdentifier(
429
- run=run_definition_pb2.RunIdentifier(
459
+ identifier_pb2.ActionIdentifier(
460
+ run=identifier_pb2.RunIdentifier(
430
461
  org=cfg.org,
431
462
  project=cfg.project,
432
463
  domain=cfg.domain,
@@ -438,7 +469,7 @@ class ActionDetails:
438
469
 
439
470
  @syncify
440
471
  @classmethod
441
- async def watch(cls, action_id: run_definition_pb2.ActionIdentifier) -> AsyncIterator[ActionDetails]:
472
+ async def watch(cls, action_id: identifier_pb2.ActionIdentifier) -> AsyncIterator[ActionDetails]:
442
473
  """
443
474
  Watch the action for updates. This is a placeholder for watching the action.
444
475
  """
@@ -521,7 +552,7 @@ class ActionDetails:
521
552
  return None
522
553
 
523
554
  @property
524
- def action_id(self) -> run_definition_pb2.ActionIdentifier:
555
+ def action_id(self) -> identifier_pb2.ActionIdentifier:
525
556
  """
526
557
  Get the action ID.
527
558
  """
@@ -552,9 +583,9 @@ class ActionDetails:
552
583
  """
553
584
  Get the runtime of the action.
554
585
  """
555
- start_time = timestamp.to_datetime(self.pb2.status.start_time, timezone.utc)
586
+ start_time = self.pb2.status.start_time.ToDatetime().replace(tzinfo=timezone.utc)
556
587
  if self.pb2.status.HasField("end_time"):
557
- end_time = timestamp.to_datetime(self.pb2.status.end_time, timezone.utc)
588
+ end_time = self.pb2.status.end_time.ToDatetime().replace(tzinfo=timezone.utc)
558
589
  return end_time - start_time
559
590
  return datetime.now(timezone.utc) - start_time
560
591
 
flyte/remote/_logs.py CHANGED
@@ -11,8 +11,9 @@ from rich.text import Text
11
11
 
12
12
  from flyte._initialize import ensure_client, get_client
13
13
  from flyte._logging import logger
14
+ from flyte._protos.common import identifier_pb2
14
15
  from flyte._protos.logs.dataplane import payload_pb2
15
- from flyte._protos.workflow import run_definition_pb2, run_logs_service_pb2
16
+ from flyte._protos.workflow import run_logs_service_pb2
16
17
  from flyte._tools import ipython_check, ipywidgets_check
17
18
  from flyte.errors import LogsNotYetAvailableError
18
19
  from flyte.syncify import syncify
@@ -98,7 +99,7 @@ class Logs:
98
99
  @classmethod
99
100
  async def tail(
100
101
  cls,
101
- action_id: run_definition_pb2.ActionIdentifier,
102
+ action_id: identifier_pb2.ActionIdentifier,
102
103
  attempt: int = 1,
103
104
  retry: int = 3,
104
105
  ) -> AsyncGenerator[payload_pb2.LogLine, None]:
@@ -139,7 +140,7 @@ class Logs:
139
140
  @classmethod
140
141
  async def create_viewer(
141
142
  cls,
142
- action_id: run_definition_pb2.ActionIdentifier,
143
+ action_id: identifier_pb2.ActionIdentifier,
143
144
  attempt: int = 1,
144
145
  max_lines: int = 30,
145
146
  show_ts: bool = False,
flyte/remote/_run.py CHANGED
@@ -53,7 +53,8 @@ class Run:
53
53
  token = None
54
54
  sort_by = sort_by or ("created_at", "asc")
55
55
  sort_pb2 = list_pb2.Sort(
56
- key=sort_by[0], direction=list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING
56
+ key=sort_by[0],
57
+ direction=(list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING),
57
58
  )
58
59
  cfg = get_common_config()
59
60
  while True:
@@ -229,7 +230,7 @@ class RunDetails:
229
230
 
230
231
  @syncify
231
232
  @classmethod
232
- async def get_details(cls, run_id: run_definition_pb2.RunIdentifier) -> RunDetails:
233
+ async def get_details(cls, run_id: identifier_pb2.RunIdentifier) -> RunDetails:
233
234
  """
234
235
  Get the details of the run. This is a placeholder for getting the run details.
235
236
  """
@@ -253,7 +254,7 @@ class RunDetails:
253
254
  ensure_client()
254
255
  cfg = get_common_config()
255
256
  return await RunDetails.get_details.aio(
256
- run_id=run_definition_pb2.RunIdentifier(
257
+ run_id=identifier_pb2.RunIdentifier(
257
258
  org=cfg.org,
258
259
  project=cfg.project,
259
260
  domain=cfg.domain,
@@ -276,7 +277,7 @@ class RunDetails:
276
277
  return self.action_details.task_name
277
278
 
278
279
  @property
279
- def action_id(self) -> run_definition_pb2.ActionIdentifier:
280
+ def action_id(self) -> identifier_pb2.ActionIdentifier:
280
281
  """
281
282
  Get the action ID.
282
283
  """