flyte 0.2.0b25__py3-none-any.whl → 0.2.0b27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

flyte/remote/_run.py CHANGED
@@ -1,102 +1,19 @@
1
1
  from __future__ import annotations
2
2
 
3
- import asyncio
4
3
  from dataclasses import dataclass, field
5
- from datetime import datetime, timedelta, timezone
6
- from typing import AsyncGenerator, AsyncIterator, Iterator, List, Literal, Tuple, Union, cast
4
+ from typing import AsyncGenerator, AsyncIterator, Literal, Tuple
7
5
 
8
6
  import grpc
9
7
  import rich.repr
10
- from google.protobuf import timestamp
11
- from rich.console import Console
12
- from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
13
8
 
14
9
  from flyte._initialize import ensure_client, get_client, get_common_config
15
10
  from flyte._protos.common import identifier_pb2, list_pb2
16
11
  from flyte._protos.workflow import run_definition_pb2, run_service_pb2
17
12
  from flyte.syncify import syncify
18
13
 
19
- from .._protos.workflow.run_service_pb2 import WatchActionDetailsResponse
14
+ from . import Action, ActionDetails, ActionInputs, ActionOutputs
15
+ from ._action import _action_details_rich_repr, _action_rich_repr
20
16
  from ._console import get_run_url
21
- from ._logs import Logs
22
-
23
- WaitFor = Literal["terminal", "running", "logs-ready"]
24
-
25
-
26
- def _action_time_phase(action: run_definition_pb2.Action | run_definition_pb2.ActionDetails) -> rich.repr.Result:
27
- """
28
- Rich representation of the action time and phase.
29
- """
30
- start_time = timestamp.to_datetime(action.status.start_time, timezone.utc)
31
- yield "start_time", start_time.isoformat()
32
- if action.status.phase in [
33
- run_definition_pb2.PHASE_FAILED,
34
- run_definition_pb2.PHASE_SUCCEEDED,
35
- run_definition_pb2.PHASE_ABORTED,
36
- run_definition_pb2.PHASE_TIMED_OUT,
37
- ]:
38
- end_time = timestamp.to_datetime(action.status.end_time, timezone.utc)
39
- yield "end_time", end_time.isoformat()
40
- yield "run_time", f"{(end_time - start_time).seconds} secs"
41
- else:
42
- yield "end_time", None
43
- yield "run_time", f"{(datetime.now(timezone.utc) - start_time).seconds} secs"
44
- yield "phase", run_definition_pb2.Phase.Name(action.status.phase)
45
- if isinstance(action, run_definition_pb2.ActionDetails):
46
- yield (
47
- "error",
48
- f"{action.error_info.kind}: {action.error_info.message}" if action.HasField("error_info") else "NA",
49
- )
50
-
51
-
52
- def _action_rich_repr(action: run_definition_pb2.Action) -> rich.repr.Result:
53
- """
54
- Rich representation of the action.
55
- """
56
- yield "run", action.id.run.name
57
- if action.metadata.HasField("task"):
58
- yield "task", action.metadata.task.id.name
59
- yield "type", "task"
60
- yield "name", action.id.name
61
- yield from _action_time_phase(action)
62
- yield "group", action.metadata.group
63
- yield "parent", action.metadata.parent
64
- yield "attempts", action.status.attempts
65
-
66
-
67
- def _attempt_rich_repr(action: List[run_definition_pb2.ActionAttempt]) -> rich.repr.Result:
68
- for attempt in action:
69
- yield "attempt", attempt.attempt
70
- yield "phase", run_definition_pb2.Phase.Name(attempt.phase)
71
- yield "logs_available", attempt.logs_available
72
-
73
-
74
- def _action_details_rich_repr(action: run_definition_pb2.ActionDetails) -> rich.repr.Result:
75
- """
76
- Rich representation of the action details.
77
- """
78
- yield "name", action.id.run.name
79
- yield from _action_time_phase(action)
80
- yield "task", action.resolved_task_spec.task_template.id.name
81
- yield "task_type", action.resolved_task_spec.task_template.type
82
- yield "task_version", action.resolved_task_spec.task_template.id.version
83
- yield "attempts", action.attempts
84
- yield "error", f"{action.error_info.kind}: {action.error_info.message}" if action.HasField("error_info") else "NA"
85
- yield "phase", run_definition_pb2.Phase.Name(action.status.phase)
86
- yield "group", action.metadata.group
87
- yield "parent", action.metadata.parent
88
-
89
-
90
- def _action_done_check(phase: run_definition_pb2.Phase) -> bool:
91
- """
92
- Check if the action is done.
93
- """
94
- return phase in [
95
- run_definition_pb2.PHASE_FAILED,
96
- run_definition_pb2.PHASE_SUCCEEDED,
97
- run_definition_pb2.PHASE_ABORTED,
98
- run_definition_pb2.PHASE_TIMED_OUT,
99
- ]
100
17
 
101
18
 
102
19
  @dataclass
@@ -402,575 +319,3 @@ class RunDetails:
402
319
  import rich.pretty
403
320
 
404
321
  return rich.pretty.pretty_repr(self)
405
-
406
-
407
- @dataclass
408
- class Action:
409
- """
410
- A class representing an action. It is used to manage the run of a task and its state on the remote Union API.
411
- """
412
-
413
- pb2: run_definition_pb2.Action
414
- _details: ActionDetails | None = None
415
-
416
- @syncify
417
- @classmethod
418
- async def listall(
419
- cls,
420
- for_run_name: str,
421
- filters: str | None = None,
422
- sort_by: Tuple[str, Literal["asc", "desc"]] | None = None,
423
- ) -> Union[Iterator[Action], AsyncIterator[Action]]:
424
- """
425
- Get all actions for a given run.
426
-
427
- :param for_run_name: The name of the run.
428
- :param filters: The filters to apply to the project list.
429
- :param sort_by: The sorting criteria for the project list, in the format (field, order).
430
- :return: An iterator of projects.
431
- """
432
- ensure_client()
433
- token = None
434
- sort_by = sort_by or ("created_at", "asc")
435
- sort_pb2 = list_pb2.Sort(
436
- key=sort_by[0], direction=list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING
437
- )
438
- cfg = get_common_config()
439
- while True:
440
- req = list_pb2.ListRequest(
441
- limit=100,
442
- token=token,
443
- sort_by=sort_pb2,
444
- )
445
- resp = await get_client().run_service.ListActions(
446
- run_service_pb2.ListActionsRequest(
447
- request=req,
448
- run_id=run_definition_pb2.RunIdentifier(
449
- org=cfg.org,
450
- project=cfg.project,
451
- domain=cfg.domain,
452
- name=for_run_name,
453
- ),
454
- )
455
- )
456
- token = resp.token
457
- for r in resp.actions:
458
- yield cls(r)
459
- if not token:
460
- break
461
-
462
- @syncify
463
- @classmethod
464
- async def get(cls, uri: str | None = None, /, run_name: str | None = None, name: str | None = None) -> Action:
465
- """
466
- Get a run by its ID or name. If both are provided, the ID will take precedence.
467
-
468
- :param uri: The URI of the action.
469
- :param run_name: The name of the action.
470
- :param name: The name of the action.
471
- """
472
- ensure_client()
473
- cfg = get_common_config()
474
- details: ActionDetails = await ActionDetails.get_details.aio(
475
- run_definition_pb2.ActionIdentifier(
476
- run=run_definition_pb2.RunIdentifier(
477
- org=cfg.org,
478
- project=cfg.project,
479
- domain=cfg.domain,
480
- name=run_name,
481
- ),
482
- name=name,
483
- ),
484
- )
485
- return cls(
486
- pb2=run_definition_pb2.Action(
487
- id=details.action_id,
488
- metadata=details.pb2.metadata,
489
- status=details.pb2.status,
490
- ),
491
- _details=details,
492
- )
493
-
494
- @property
495
- def phase(self) -> str:
496
- """
497
- Get the phase of the action.
498
- """
499
- return run_definition_pb2.Phase.Name(self.pb2.status.phase)
500
-
501
- @property
502
- def raw_phase(self) -> run_definition_pb2.Phase:
503
- """
504
- Get the raw phase of the action.
505
- """
506
- return self.pb2.status.phase
507
-
508
- @property
509
- def name(self) -> str:
510
- """
511
- Get the name of the action.
512
- """
513
- return self.action_id.name
514
-
515
- @property
516
- def run_name(self) -> str:
517
- """
518
- Get the name of the run.
519
- """
520
- return self.action_id.run.name
521
-
522
- @property
523
- def task_name(self) -> str | None:
524
- """
525
- Get the name of the task.
526
- """
527
- if self.pb2.metadata.HasField("task") and self.pb2.metadata.task.HasField("id"):
528
- return self.pb2.metadata.task.id.name
529
- return None
530
-
531
- @property
532
- def action_id(self) -> run_definition_pb2.ActionIdentifier:
533
- """
534
- Get the action ID.
535
- """
536
- return self.pb2.id
537
-
538
- async def show_logs(
539
- self,
540
- attempt: int | None = None,
541
- max_lines: int = 30,
542
- show_ts: bool = False,
543
- raw: bool = False,
544
- filter_system: bool = False,
545
- ):
546
- details = await self.details()
547
- if not details.is_running and not details.done():
548
- # TODO we can short circuit here if the attempt is not the last one and it is done!
549
- await self.wait(wait_for="logs-ready")
550
- details = await self.details()
551
- if not attempt:
552
- attempt = details.attempts
553
- return await Logs.create_viewer(
554
- action_id=self.action_id,
555
- attempt=attempt,
556
- max_lines=max_lines,
557
- show_ts=show_ts,
558
- raw=raw,
559
- filter_system=filter_system,
560
- )
561
-
562
- async def details(self) -> ActionDetails:
563
- """
564
- Get the details of the action. This is a placeholder for getting the action details.
565
- """
566
- if not self._details:
567
- self._details = await ActionDetails.get_details.aio(self.action_id)
568
- return cast(ActionDetails, self._details)
569
-
570
- async def watch(
571
- self, cache_data_on_done: bool = False, wait_for: WaitFor = "terminal"
572
- ) -> AsyncGenerator[ActionDetails, None]:
573
- """
574
- Watch the action for updates. This is a placeholder for watching the action.
575
- """
576
- ad = None
577
- async for ad in ActionDetails.watch.aio(self.action_id):
578
- if ad is None:
579
- return
580
- self._details = ad
581
- yield ad
582
- if wait_for == "running" and ad.is_running:
583
- break
584
- elif wait_for == "logs-ready" and ad.logs_available():
585
- break
586
- if ad.done():
587
- break
588
- if cache_data_on_done and ad and ad.done():
589
- await cast(ActionDetails, self._details).outputs()
590
-
591
- async def wait(self, quiet: bool = False, wait_for: WaitFor = "terminal") -> None:
592
- """
593
- Wait for the run to complete, displaying a rich progress panel with status transitions,
594
- time elapsed, and error details in case of failure.
595
- """
596
- console = Console()
597
- if self.done():
598
- if not quiet:
599
- if self.pb2.status.phase == run_definition_pb2.PHASE_SUCCEEDED:
600
- console.print(
601
- f"[bold green]Action '{self.name}' in Run '{self.run_name}'"
602
- f" completed successfully.[/bold green]"
603
- )
604
- else:
605
- details = await self.details()
606
- console.print(
607
- f"[bold red]Action '{self.name}' in Run '{self.run_name}'"
608
- f" exited unsuccessfully in state {self.phase} with error: {details.error_info}[/bold red]"
609
- )
610
- return
611
-
612
- try:
613
- with Progress(
614
- SpinnerColumn(),
615
- TextColumn("[progress.description]{task.description}"),
616
- TimeElapsedColumn(),
617
- console=console,
618
- transient=True,
619
- disable=quiet,
620
- ) as progress:
621
- task_id = progress.add_task(f"Waiting for run '{self.name}'...", start=False)
622
- progress.start_task(task_id)
623
-
624
- async for ad in self.watch(cache_data_on_done=True, wait_for=wait_for):
625
- if ad is None:
626
- progress.stop_task(task_id)
627
- break
628
-
629
- if ad.is_running and wait_for == "running":
630
- progress.start_task(task_id)
631
- break
632
-
633
- if ad.logs_available() and wait_for == "logs-ready":
634
- progress.start_task(task_id)
635
- break
636
-
637
- # Update progress description with the current phase
638
- progress.update(
639
- task_id,
640
- description=f"Run: {self.run_name} in {ad.phase}, Runtime: {ad.runtime} secs "
641
- f"Attempts[{ad.attempts}]",
642
- )
643
-
644
- # If the action is done, handle the final state
645
- if ad.done():
646
- progress.stop_task(task_id)
647
- if ad.pb2.status.phase == run_definition_pb2.PHASE_SUCCEEDED:
648
- console.print(f"[bold green]Run '{self.run_name}' completed successfully.[/bold green]")
649
- else:
650
- console.print(
651
- f"[bold red]Run '{self.run_name}' exited unsuccessfully in state {ad.phase}"
652
- f"with error: {ad.error_info}[/bold red]"
653
- )
654
- break
655
- except asyncio.CancelledError:
656
- # Handle cancellation gracefully
657
- pass
658
- except KeyboardInterrupt:
659
- # Handle keyboard interrupt gracefully
660
- pass
661
-
662
- def done(self) -> bool:
663
- """
664
- Check if the action is done.
665
- """
666
- return _action_done_check(self.raw_phase)
667
-
668
- async def sync(self) -> Action:
669
- """
670
- Sync the action with the remote server. This is a placeholder for syncing the action.
671
- """
672
- return self
673
-
674
- def __rich_repr__(self) -> rich.repr.Result:
675
- """
676
- Rich representation of the Action object.
677
- """
678
- yield from _action_rich_repr(self.pb2)
679
- if self._details:
680
- yield from self._details.__rich_repr__()
681
-
682
- def __repr__(self) -> str:
683
- """
684
- String representation of the Action object.
685
- """
686
- import rich.pretty
687
-
688
- return rich.pretty.pretty_repr(self)
689
-
690
-
691
- @dataclass
692
- class ActionDetails:
693
- """
694
- A class representing an action. It is used to manage the run of a task and its state on the remote Union API.
695
- """
696
-
697
- pb2: run_definition_pb2.ActionDetails
698
- _inputs: ActionInputs | None = None
699
- _outputs: ActionOutputs | None = None
700
-
701
- @syncify
702
- @classmethod
703
- async def get_details(cls, action_id: run_definition_pb2.ActionIdentifier) -> ActionDetails:
704
- """
705
- Get the details of the action. This is a placeholder for getting the action details.
706
- """
707
- ensure_client()
708
- resp = await get_client().run_service.GetActionDetails(
709
- run_service_pb2.GetActionDetailsRequest(
710
- action_id=action_id,
711
- )
712
- )
713
- return ActionDetails(resp.details)
714
-
715
- @syncify
716
- @classmethod
717
- async def get(
718
- cls, uri: str | None = None, /, run_name: str | None = None, name: str | None = None
719
- ) -> ActionDetails:
720
- """
721
- Get a run by its ID or name. If both are provided, the ID will take precedence.
722
-
723
- :param uri: The URI of the action.
724
- :param name: The name of the action.
725
- :param run_name: The name of the run.
726
- """
727
- ensure_client()
728
- if not uri:
729
- assert name is not None and run_name is not None, "Either uri or name and run_name must be provided"
730
- cfg = get_common_config()
731
- return await cls.get_details.aio(
732
- run_definition_pb2.ActionIdentifier(
733
- run=run_definition_pb2.RunIdentifier(
734
- org=cfg.org,
735
- project=cfg.project,
736
- domain=cfg.domain,
737
- name=run_name,
738
- ),
739
- name=name,
740
- ),
741
- )
742
-
743
- @syncify
744
- @classmethod
745
- async def watch(cls, action_id: run_definition_pb2.ActionIdentifier) -> AsyncIterator[ActionDetails]:
746
- """
747
- Watch the action for updates. This is a placeholder for watching the action.
748
- """
749
- ensure_client()
750
- if not action_id:
751
- raise ValueError("Action ID is required")
752
-
753
- call = cast(
754
- AsyncIterator[WatchActionDetailsResponse],
755
- get_client().run_service.WatchActionDetails(
756
- request=run_service_pb2.WatchActionDetailsRequest(
757
- action_id=action_id,
758
- )
759
- ),
760
- )
761
- try:
762
- async for resp in call:
763
- v = cls(resp.details)
764
- yield v
765
- if v.done():
766
- return
767
- except grpc.aio.AioRpcError as e:
768
- if e.code() == grpc.StatusCode.CANCELLED:
769
- pass
770
- else:
771
- raise e
772
-
773
- async def watch_updates(self, cache_data_on_done: bool = False) -> AsyncGenerator[ActionDetails, None]:
774
- async for d in self.watch.aio(action_id=self.pb2.id):
775
- yield d
776
- if d.done():
777
- self.pb2 = d.pb2
778
- break
779
-
780
- if cache_data_on_done and self.done():
781
- await self._cache_data.aio()
782
-
783
- @property
784
- def phase(self) -> str:
785
- """
786
- Get the phase of the action.
787
- """
788
- return run_definition_pb2.Phase.Name(self.status.phase)
789
-
790
- @property
791
- def raw_phase(self) -> run_definition_pb2.Phase:
792
- """
793
- Get the raw phase of the action.
794
- """
795
- return self.status.phase
796
-
797
- @property
798
- def is_running(self) -> bool:
799
- """
800
- Check if the action is currently running.
801
- """
802
- return self.status.phase == run_definition_pb2.PHASE_RUNNING
803
-
804
- @property
805
- def name(self) -> str:
806
- """
807
- Get the name of the action.
808
- """
809
- return self.action_id.name
810
-
811
- @property
812
- def run_name(self) -> str:
813
- """
814
- Get the name of the run.
815
- """
816
- return self.action_id.run.name
817
-
818
- @property
819
- def task_name(self) -> str | None:
820
- """
821
- Get the name of the task.
822
- """
823
- if self.pb2.metadata.HasField("task") and self.pb2.metadata.task.HasField("id"):
824
- return self.pb2.metadata.task.id.name
825
- return None
826
-
827
- @property
828
- def action_id(self) -> run_definition_pb2.ActionIdentifier:
829
- """
830
- Get the action ID.
831
- """
832
- return self.pb2.id
833
-
834
- @property
835
- def metadata(self) -> run_definition_pb2.ActionMetadata:
836
- return self.pb2.metadata
837
-
838
- @property
839
- def status(self) -> run_definition_pb2.ActionStatus:
840
- return self.pb2.status
841
-
842
- @property
843
- def error_info(self) -> run_definition_pb2.ErrorInfo | None:
844
- if self.pb2.HasField("error_info"):
845
- return self.pb2.error_info
846
- return None
847
-
848
- @property
849
- def abort_info(self) -> run_definition_pb2.AbortInfo | None:
850
- if self.pb2.HasField("abort_info"):
851
- return self.pb2.abort_info
852
- return None
853
-
854
- @property
855
- def runtime(self) -> timedelta:
856
- """
857
- Get the runtime of the action.
858
- """
859
- start_time = timestamp.to_datetime(self.pb2.status.start_time, timezone.utc)
860
- if self.pb2.status.HasField("end_time"):
861
- end_time = timestamp.to_datetime(self.pb2.status.end_time, timezone.utc)
862
- return end_time - start_time
863
- return datetime.now(timezone.utc) - start_time
864
-
865
- @property
866
- def attempts(self) -> int:
867
- """
868
- Get the number of attempts of the action.
869
- """
870
- return self.pb2.status.attempts
871
-
872
- def logs_available(self, attempt: int | None = None) -> bool:
873
- """
874
- Check if logs are available for the action, optionally for a specific attempt.
875
- If attempt is None, it checks for the latest attempt.
876
- """
877
- if attempt is None:
878
- attempt = self.pb2.status.attempts
879
- attempts = self.pb2.attempts
880
- if attempts and len(attempts) >= attempt:
881
- return attempts[attempt - 1].logs_available
882
- return False
883
-
884
- @syncify
885
- async def _cache_data(self) -> bool:
886
- """
887
- Cache the inputs and outputs of the action.
888
- :return: Returns True if Action is terminal and all data is cached else False.
889
- """
890
- if self._inputs and self._outputs:
891
- return True
892
- if self._inputs and not self.done():
893
- return False
894
- resp = await get_client().run_service.GetActionData(
895
- request=run_service_pb2.GetActionDataRequest(
896
- action_id=self.pb2.id,
897
- )
898
- )
899
- self._inputs = ActionInputs(resp.inputs)
900
- self._outputs = ActionOutputs(resp.outputs) if resp.HasField("outputs") else None
901
- return self._outputs is not None
902
-
903
- async def inputs(self) -> ActionInputs:
904
- """
905
- Placeholder for inputs. This can be extended to handle inputs from the run context.
906
- """
907
- if not self._inputs:
908
- await self._cache_data.aio()
909
- return cast(ActionInputs, self._inputs)
910
-
911
- async def outputs(self) -> ActionOutputs:
912
- """
913
- Placeholder for outputs. This can be extended to handle outputs from the run context.
914
- """
915
- if not self._outputs:
916
- if not await self._cache_data.aio():
917
- raise RuntimeError(
918
- "Action is not in a terminal state, outputs are not available. "
919
- "Please wait for the action to complete."
920
- )
921
- return cast(ActionOutputs, self._outputs)
922
-
923
- def done(self) -> bool:
924
- """
925
- Check if the action is in a terminal state (completed or failed). This is a placeholder for checking the
926
- action state.
927
- """
928
- return _action_done_check(self.raw_phase)
929
-
930
- def __rich_repr__(self) -> rich.repr.Result:
931
- """
932
- Rich representation of the Action object.
933
- """
934
- yield from _action_details_rich_repr(self.pb2)
935
-
936
- def __repr__(self) -> str:
937
- """
938
- String representation of the Action object.
939
- """
940
- import rich.pretty
941
-
942
- return rich.pretty.pretty_repr(self)
943
-
944
-
945
- @dataclass
946
- class ActionInputs:
947
- """
948
- A class representing the inputs of an action. It is used to manage the inputs of a task and its state on the
949
- remote Union API.
950
- """
951
-
952
- pb2: run_definition_pb2.Inputs
953
-
954
- def __repr__(self):
955
- import rich.pretty
956
-
957
- import flyte.types as types
958
-
959
- return rich.pretty.pretty_repr(types.literal_string_repr(self.pb2))
960
-
961
-
962
- @dataclass
963
- class ActionOutputs:
964
- """
965
- A class representing the outputs of an action. It is used to manage the outputs of a task and its state on the
966
- remote Union API.
967
- """
968
-
969
- pb2: run_definition_pb2.Outputs
970
-
971
- def __repr__(self):
972
- import rich.pretty
973
-
974
- import flyte.types as types
975
-
976
- return rich.pretty.pretty_repr(types.literal_string_repr(self.pb2))
@@ -16,7 +16,7 @@ import typing
16
16
  from abc import ABC, abstractmethod
17
17
  from collections import OrderedDict
18
18
  from functools import lru_cache
19
- from types import GenericAlias
19
+ from types import GenericAlias, NoneType
20
20
  from typing import Any, Dict, NamedTuple, Optional, Type, cast
21
21
 
22
22
  import msgpack
@@ -307,6 +307,9 @@ class SimpleTransformer(TypeTransformer[T]):
307
307
  expected_python_type = get_underlying_type(expected_python_type)
308
308
 
309
309
  if expected_python_type is not self._type:
310
+ if expected_python_type is None and issubclass(self._type, NoneType):
311
+ # If the expected type is NoneType, we can return None
312
+ return None
310
313
  raise TypeTransformerFailedError(
311
314
  f"Cannot convert to type {expected_python_type}, only {self._type} is supported"
312
315
  )