flyte 0.2.0b4__py3-none-any.whl → 0.2.0b5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. See the package's registry page for more details.

flyte/remote/_project.py CHANGED
@@ -1,14 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
- import typing
4
3
  from dataclasses import dataclass
5
- from typing import AsyncGenerator, Literal, Tuple
4
+ from typing import AsyncIterator, Iterator, Literal, Tuple, Union
6
5
 
7
6
  import rich.repr
8
7
  from flyteidl.admin import common_pb2, project_pb2
9
8
 
10
- from flyte._api_commons import syncer
11
- from flyte._initialize import get_client, get_common_config, requires_client
9
+ from flyte._initialize import ensure_client, get_client, get_common_config
10
+ from flyte.syncify import syncify
12
11
 
13
12
 
14
13
  @dataclass
@@ -19,9 +18,8 @@ class Project:
19
18
 
20
19
  _pb2: project_pb2.Project
21
20
 
21
+ @syncify
22
22
  @classmethod
23
- @requires_client
24
- @syncer.wrap
25
23
  async def get(cls, name: str, org: str | None = None) -> Project:
26
24
  """
27
25
  Get a run by its ID or name. If both are provided, the ID will take precedence.
@@ -29,6 +27,7 @@ class Project:
29
27
  :param name: The name of the project.
30
28
  :param org: The organization of the project (if applicable).
31
29
  """
30
+ ensure_client()
32
31
  service = get_client().project_domain_service # type: ignore
33
32
  resp = await service.GetProject(
34
33
  project_pb2.ProjectGetRequest(
@@ -38,14 +37,13 @@ class Project:
38
37
  )
39
38
  return cls(resp)
40
39
 
40
+ @syncify
41
41
  @classmethod
42
- @requires_client
43
- @syncer.wrap
44
42
  async def listall(
45
43
  cls,
46
44
  filters: str | None = None,
47
45
  sort_by: Tuple[str, Literal["asc", "desc"]] | None = None,
48
- ) -> typing.Union[typing.Iterator[Project], AsyncGenerator[Project, None]]:
46
+ ) -> Union[AsyncIterator[Project], Iterator[Project]]:
49
47
  """
50
48
  Get a run by its ID or name. If both are provided, the ID will take precedence.
51
49
 
@@ -53,6 +51,7 @@ class Project:
53
51
  :param sort_by: The sorting criteria for the project list, in the format (field, order).
54
52
  :return: An iterator of projects.
55
53
  """
54
+ ensure_client()
56
55
  token = None
57
56
  sort_by = sort_by or ("created_at", "asc")
58
57
  sort_pb2 = common_pb2.Sort(
flyte/remote/_run.py CHANGED
@@ -11,16 +11,16 @@ from google.protobuf import timestamp
11
11
  from rich.console import Console
12
12
  from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
13
13
 
14
- from flyte._api_commons import syncer
15
- from flyte._initialize import get_client, get_common_config, requires_client
14
+ from flyte._initialize import ensure_client, get_client, get_common_config
16
15
  from flyte._protos.common import identifier_pb2, list_pb2
17
16
  from flyte._protos.workflow import run_definition_pb2, run_service_pb2
17
+ from flyte.syncify import syncify
18
18
 
19
19
  from .._protos.workflow.run_service_pb2 import WatchActionDetailsResponse
20
20
  from ._console import get_run_url
21
21
  from ._logs import Logs
22
22
 
23
- WaitFor = Literal["terminal", "running"]
23
+ WaitFor = Literal["terminal", "running", "logs-ready"]
24
24
 
25
25
 
26
26
  def _action_time_phase(action: run_definition_pb2.Action | run_definition_pb2.ActionDetails) -> rich.repr.Result:
@@ -106,14 +106,13 @@ class Run:
106
106
  raise RuntimeError("Run does not have an action")
107
107
  self.action = Action(self.pb2.action)
108
108
 
109
+ @syncify
109
110
  @classmethod
110
- @requires_client
111
- @syncer.wrap
112
111
  async def listall(
113
112
  cls,
114
113
  filters: str | None = None,
115
114
  sort_by: Tuple[str, Literal["asc", "desc"]] | None = None,
116
- ) -> Union[Iterator[Run], AsyncGenerator[Run, None]]:
115
+ ) -> AsyncIterator[Run]:
117
116
  """
118
117
  Get all runs for the current project and domain.
119
118
 
@@ -121,6 +120,7 @@ class Run:
121
120
  :param sort_by: The sorting criteria for the project list, in the format (field, order).
122
121
  :return: An iterator of runs.
123
122
  """
123
+ ensure_client()
124
124
  token = None
125
125
  sort_by = sort_by or ("created_at", "asc")
126
126
  sort_pb2 = list_pb2.Sort(
@@ -150,16 +150,16 @@ class Run:
150
150
  if not token:
151
151
  break
152
152
 
153
+ @syncify
153
154
  @classmethod
154
- @requires_client
155
- @syncer.wrap
156
155
  async def get(cls, name: str) -> Run:
157
156
  """
158
157
  Get the current run.
159
158
 
160
159
  :return: The current run.
161
160
  """
162
- run_details: RunDetails = await RunDetails.get.aio(RunDetails, name=name)
161
+ ensure_client()
162
+ run_details: RunDetails = await RunDetails.get.aio(name=name)
163
163
  run = run_definition_pb2.Run(
164
164
  action=run_definition_pb2.Action(
165
165
  id=run_details.action_id,
@@ -181,9 +181,16 @@ class Run:
181
181
  """
182
182
  Get the phase of the run.
183
183
  """
184
- return run_definition_pb2.Phase.Name(self.action.phase)
184
+ return self.action.phase
185
185
 
186
- @syncer.wrap
186
+ @property
187
+ def raw_phase(self) -> run_definition_pb2.Phase:
188
+ """
189
+ Get the raw phase of the run.
190
+ """
191
+ return self.action.raw_phase
192
+
193
+ @syncify
187
194
  async def wait(self, quiet: bool = False, wait_for: Literal["terminal", "running"] = "terminal") -> None:
188
195
  """
189
196
  Wait for the run to complete, displaying a rich progress panel with status transitions,
@@ -216,7 +223,6 @@ class Run:
216
223
  return self._details
217
224
 
218
225
  @property
219
- @requires_client
220
226
  def url(self) -> str:
221
227
  """
222
228
  Get the URL of the run.
@@ -230,7 +236,7 @@ class Run:
230
236
  run_name=self.name,
231
237
  )
232
238
 
233
- @syncer.wrap
239
+ @syncify
234
240
  async def abort(self):
235
241
  """
236
242
  Aborts / Terminates the run.
@@ -291,13 +297,13 @@ class RunDetails:
291
297
  """
292
298
  self.action_details = ActionDetails(self.pb2.action)
293
299
 
300
+ @syncify
294
301
  @classmethod
295
- @requires_client
296
- @syncer.wrap
297
302
  async def get_details(cls, run_id: run_definition_pb2.RunIdentifier) -> RunDetails:
298
303
  """
299
304
  Get the details of the run. This is a placeholder for getting the run details.
300
305
  """
306
+ ensure_client()
301
307
  resp = await get_client().run_service.GetRunDetails(
302
308
  run_service_pb2.GetRunDetailsRequest(
303
309
  run_id=run_id,
@@ -305,9 +311,8 @@ class RunDetails:
305
311
  )
306
312
  return cls(resp.details)
307
313
 
314
+ @syncify
308
315
  @classmethod
309
- @requires_client
310
- @syncer.wrap
311
316
  async def get(cls, name: str | None = None) -> RunDetails:
312
317
  """
313
318
  Get a run by its ID or name. If both are provided, the ID will take precedence.
@@ -315,9 +320,9 @@ class RunDetails:
315
320
  :param uri: The URI of the run.
316
321
  :param name: The name of the run.
317
322
  """
323
+ ensure_client()
318
324
  cfg = get_common_config()
319
325
  return await RunDetails.get_details.aio(
320
- cls,
321
326
  run_id=run_definition_pb2.RunIdentifier(
322
327
  org=cfg.org,
323
328
  project=cfg.project,
@@ -390,15 +395,14 @@ class Action:
390
395
  pb2: run_definition_pb2.Action
391
396
  _details: ActionDetails | None = None
392
397
 
398
+ @syncify
393
399
  @classmethod
394
- @requires_client
395
- @syncer.wrap
396
400
  async def listall(
397
401
  cls,
398
402
  for_run_name: str,
399
403
  filters: str | None = None,
400
404
  sort_by: Tuple[str, Literal["asc", "desc"]] | None = None,
401
- ) -> Union[Iterator[Action], AsyncGenerator[Action, None]]:
405
+ ) -> Union[Iterator[Action], AsyncIterator[Action]]:
402
406
  """
403
407
  Get all actions for a given run.
404
408
 
@@ -407,6 +411,7 @@ class Action:
407
411
  :param sort_by: The sorting criteria for the project list, in the format (field, order).
408
412
  :return: An iterator of projects.
409
413
  """
414
+ ensure_client()
410
415
  token = None
411
416
  sort_by = sort_by or ("created_at", "asc")
412
417
  sort_pb2 = list_pb2.Sort(
@@ -436,9 +441,8 @@ class Action:
436
441
  if not token:
437
442
  break
438
443
 
444
+ @syncify
439
445
  @classmethod
440
- @requires_client
441
- @syncer.wrap
442
446
  async def get(cls, uri: str | None = None, /, run_name: str | None = None, name: str | None = None) -> Action:
443
447
  """
444
448
  Get a run by its ID or name. If both are provided, the ID will take precedence.
@@ -447,9 +451,9 @@ class Action:
447
451
  :param run_name: The name of the action.
448
452
  :param name: The name of the action.
449
453
  """
454
+ ensure_client()
450
455
  cfg = get_common_config()
451
456
  details: ActionDetails = await ActionDetails.get_details.aio(
452
- cls,
453
457
  run_definition_pb2.ActionIdentifier(
454
458
  run=run_definition_pb2.RunIdentifier(
455
459
  org=cfg.org,
@@ -476,6 +480,13 @@ class Action:
476
480
  """
477
481
  return run_definition_pb2.Phase.Name(self.pb2.status.phase)
478
482
 
483
+ @property
484
+ def raw_phase(self) -> run_definition_pb2.Phase:
485
+ """
486
+ Get the raw phase of the action.
487
+ """
488
+ return self.pb2.status.phase
489
+
479
490
  @property
480
491
  def name(self) -> str:
481
492
  """
@@ -515,10 +526,12 @@ class Action:
515
526
  filter_system: bool = False,
516
527
  ):
517
528
  details = await self.details()
529
+ if not details.is_running and not details.done():
530
+ # TODO we can short circuit here if the attempt is not the last one and it is done!
531
+ await self.wait(wait_for="logs-ready")
532
+ details = await self.details()
518
533
  if not attempt:
519
534
  attempt = details.attempts
520
- if not details.is_running:
521
- await self.wait(wait_for="running")
522
535
  return await Logs.create_viewer(
523
536
  action_id=self.action_id,
524
537
  attempt=attempt,
@@ -533,7 +546,7 @@ class Action:
533
546
  Get the details of the action. This is a placeholder for getting the action details.
534
547
  """
535
548
  if not self._details:
536
- self._details = await ActionDetails.get_details.aio(ActionDetails, self.action_id)
549
+ self._details = await ActionDetails.get_details.aio(self.action_id)
537
550
  return cast(ActionDetails, self._details)
538
551
 
539
552
  async def watch(
@@ -543,14 +556,16 @@ class Action:
543
556
  Watch the action for updates. This is a placeholder for watching the action.
544
557
  """
545
558
  ad = None
546
- async for ad in ActionDetails.watch.aio(ActionDetails, self.action_id):
559
+ async for ad in ActionDetails.watch.aio(self.action_id):
547
560
  if ad is None:
548
561
  return
549
562
  self._details = ad
550
563
  yield ad
551
- if wait_for == "running" and ad.phase == run_definition_pb2.PHASE_RUNNING:
564
+ if wait_for == "running" and ad.is_running:
552
565
  break
553
- elif wait_for == "terminal" and _action_done_check(ad.phase):
566
+ elif wait_for == "logs-ready" and ad.logs_available():
567
+ break
568
+ if ad.done():
554
569
  break
555
570
  if cache_data_on_done and ad and ad.done():
556
571
  await cast(ActionDetails, self._details).outputs()
@@ -597,6 +612,10 @@ class Action:
597
612
  progress.start_task(task_id)
598
613
  break
599
614
 
615
+ if ad.logs_available() and wait_for == "logs-ready":
616
+ progress.start_task(task_id)
617
+ break
618
+
600
619
  # Update progress description with the current phase
601
620
  progress.update(
602
621
  task_id,
@@ -626,7 +645,7 @@ class Action:
626
645
  """
627
646
  Check if the action is done.
628
647
  """
629
- return _action_done_check(self.pb2.status.phase)
648
+ return _action_done_check(self.raw_phase)
630
649
 
631
650
  async def sync(self) -> Action:
632
651
  """
@@ -659,13 +678,13 @@ class ActionDetails:
659
678
  _inputs: ActionInputs | None = None
660
679
  _outputs: ActionOutputs | None = None
661
680
 
681
+ @syncify
662
682
  @classmethod
663
- @requires_client
664
- @syncer.wrap
665
683
  async def get_details(cls, action_id: run_definition_pb2.ActionIdentifier) -> ActionDetails:
666
684
  """
667
685
  Get the details of the action. This is a placeholder for getting the action details.
668
686
  """
687
+ ensure_client()
669
688
  resp = await get_client().run_service.GetActionDetails(
670
689
  run_service_pb2.GetActionDetailsRequest(
671
690
  action_id=action_id,
@@ -673,9 +692,8 @@ class ActionDetails:
673
692
  )
674
693
  return ActionDetails(resp.details)
675
694
 
695
+ @syncify
676
696
  @classmethod
677
- @requires_client
678
- @syncer.wrap
679
697
  async def get(
680
698
  cls, uri: str | None = None, /, run_name: str | None = None, name: str | None = None
681
699
  ) -> ActionDetails:
@@ -686,11 +704,11 @@ class ActionDetails:
686
704
  :param name: The name of the action.
687
705
  :param run_name: The name of the run.
688
706
  """
707
+ ensure_client()
689
708
  if not uri:
690
709
  assert name is not None and run_name is not None, "Either uri or name and run_name must be provided"
691
710
  cfg = get_common_config()
692
711
  return await cls.get_details.aio(
693
- cls,
694
712
  run_definition_pb2.ActionIdentifier(
695
713
  run=run_definition_pb2.RunIdentifier(
696
714
  org=cfg.org,
@@ -702,13 +720,13 @@ class ActionDetails:
702
720
  ),
703
721
  )
704
722
 
723
+ @syncify
705
724
  @classmethod
706
- @requires_client
707
- @syncer.wrap
708
- async def watch(cls, action_id: run_definition_pb2.ActionIdentifier) -> AsyncGenerator[ActionDetails, None]:
725
+ async def watch(cls, action_id: run_definition_pb2.ActionIdentifier) -> AsyncIterator[ActionDetails]:
709
726
  """
710
727
  Watch the action for updates. This is a placeholder for watching the action.
711
728
  """
729
+ ensure_client()
712
730
  if not action_id:
713
731
  raise ValueError("Action ID is required")
714
732
 
@@ -740,7 +758,7 @@ class ActionDetails:
740
758
  break
741
759
 
742
760
  if cache_data_on_done and self.done():
743
- await self._cache_data.aio(self)
761
+ await self._cache_data.aio()
744
762
 
745
763
  @property
746
764
  def phase(self) -> str:
@@ -749,6 +767,13 @@ class ActionDetails:
749
767
  """
750
768
  return run_definition_pb2.Phase.Name(self.status.phase)
751
769
 
770
+ @property
771
+ def raw_phase(self) -> run_definition_pb2.Phase:
772
+ """
773
+ Get the raw phase of the action.
774
+ """
775
+ return self.status.phase
776
+
752
777
  @property
753
778
  def is_running(self) -> bool:
754
779
  """
@@ -824,7 +849,19 @@ class ActionDetails:
824
849
  """
825
850
  return self.pb2.status.attempts
826
851
 
827
- @syncer.wrap
852
+ def logs_available(self, attempt: int | None = None) -> bool:
853
+ """
854
+ Check if logs are available for the action, optionally for a specific attempt.
855
+ If attempt is None, it checks for the latest attempt.
856
+ """
857
+ if attempt is None:
858
+ attempt = self.pb2.status.attempts
859
+ attempts = self.pb2.attempts
860
+ if attempts and len(attempts) >= attempt:
861
+ return attempts[attempt - 1].logs_available
862
+ return False
863
+
864
+ @syncify
828
865
  async def _cache_data(self) -> bool:
829
866
  """
830
867
  Cache the inputs and outputs of the action.
@@ -848,7 +885,7 @@ class ActionDetails:
848
885
  Placeholder for inputs. This can be extended to handle inputs from the run context.
849
886
  """
850
887
  if not self._inputs:
851
- await self._cache_data.aio(self)
888
+ await self._cache_data.aio()
852
889
  return cast(ActionInputs, self._inputs)
853
890
 
854
891
  async def outputs(self) -> ActionOutputs:
@@ -856,7 +893,7 @@ class ActionDetails:
856
893
  Placeholder for outputs. This can be extended to handle outputs from the run context.
857
894
  """
858
895
  if not self._outputs:
859
- if not await self._cache_data.aio(self):
896
+ if not await self._cache_data.aio():
860
897
  raise RuntimeError(
861
898
  "Action is not in a terminal state, outputs are not available. "
862
899
  "Please wait for the action to complete."
@@ -868,7 +905,7 @@ class ActionDetails:
868
905
  Check if the action is in a terminal state (completed or failed). This is a placeholder for checking the
869
906
  action state.
870
907
  """
871
- return _action_done_check(self.pb2.status.phase)
908
+ return _action_done_check(self.raw_phase)
872
909
 
873
910
  def __rich_repr__(self) -> rich.repr.Result:
874
911
  """
flyte/remote/_secret.py CHANGED
@@ -1,13 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import AsyncGenerator, Iterator, Literal, Union
4
+ from typing import AsyncIterator, Literal, Union
5
5
 
6
6
  import rich.repr
7
7
 
8
- from flyte._api_commons import syncer
9
- from flyte._initialize import get_client, get_common_config, requires_client
8
+ from flyte._initialize import ensure_client, get_client, get_common_config
10
9
  from flyte._protos.secret import definition_pb2, payload_pb2
10
+ from flyte.syncify import syncify
11
11
 
12
12
  SecretTypes = Literal["regular", "image_pull"]
13
13
 
@@ -16,10 +16,10 @@ SecretTypes = Literal["regular", "image_pull"]
16
16
  class Secret:
17
17
  pb2: definition_pb2.Secret
18
18
 
19
+ @syncify
19
20
  @classmethod
20
- @requires_client
21
- @syncer.wrap
22
21
  async def create(cls, name: str, value: Union[str, bytes], type: SecretTypes = "regular"):
22
+ ensure_client()
23
23
  cfg = get_common_config()
24
24
  secret_type = (
25
25
  definition_pb2.SecretType.SECRET_TYPE_GENERIC
@@ -49,10 +49,10 @@ class Secret:
49
49
  ),
50
50
  )
51
51
 
52
+ @syncify
52
53
  @classmethod
53
- @requires_client
54
- @syncer.wrap
55
54
  async def get(cls, name: str) -> Secret:
55
+ ensure_client()
56
56
  cfg = get_common_config()
57
57
  resp = await get_client().secrets_service.GetSecret(
58
58
  request=payload_pb2.GetSecretRequest(
@@ -66,10 +66,10 @@ class Secret:
66
66
  )
67
67
  return Secret(pb2=resp.secret)
68
68
 
69
+ @syncify
69
70
  @classmethod
70
- @requires_client
71
- @syncer.wrap
72
- async def listall(cls, limit: int = 100) -> Union[Iterator[Secret], AsyncGenerator[Secret, None]]:
71
+ async def listall(cls, limit: int = 100) -> AsyncIterator[Secret]:
72
+ ensure_client()
73
73
  cfg = get_common_config()
74
74
  token = None
75
75
  while True:
@@ -88,10 +88,10 @@ class Secret:
88
88
  if not token:
89
89
  break
90
90
 
91
+ @syncify
91
92
  @classmethod
92
- @requires_client
93
- @syncer.wrap
94
93
  async def delete(cls, name):
94
+ ensure_client()
95
95
  cfg = get_common_config()
96
96
  await get_client().secrets_service.DeleteSecret( # type: ignore
97
97
  request=payload_pb2.DeleteSecretRequest(
flyte/remote/_task.py CHANGED
@@ -9,11 +9,11 @@ import rich.repr
9
9
 
10
10
  import flyte
11
11
  import flyte.errors
12
- from flyte._api_commons import syncer
13
12
  from flyte._context import internal_ctx
14
13
  from flyte._initialize import get_client, get_common_config
15
14
  from flyte._protos.workflow import task_definition_pb2, task_service_pb2
16
15
  from flyte.models import NativeInterface
16
+ from flyte.syncify import syncify
17
17
 
18
18
 
19
19
  class LazyEntity:
@@ -32,7 +32,7 @@ class LazyEntity:
32
32
  def name(self) -> str:
33
33
  return self._name
34
34
 
35
- @syncer.wrap
35
+ @syncify
36
36
  async def fetch(self) -> Task:
37
37
  """
38
38
  Forwards all other attributes to task, causing the task to be fetched!
@@ -48,7 +48,7 @@ class LazyEntity:
48
48
  """
49
49
  Forwards the call to the underlying task. The entity will be fetched if not already present
50
50
  """
51
- tk = await self.fetch.aio(self)
51
+ tk = await self.fetch.aio()
52
52
  return await tk(*args, **kwargs)
53
53
 
54
54
  def __repr__(self) -> str:
flyte/report/_report.py CHANGED
@@ -4,10 +4,10 @@ import string
4
4
  from dataclasses import dataclass, field
5
5
  from typing import TYPE_CHECKING, Dict, List, Union
6
6
 
7
- from flyte._api_commons import syncer
8
7
  from flyte._internal.runtime import io
9
8
  from flyte._logging import logger
10
9
  from flyte._tools import ipython_check
10
+ from flyte.syncify import syncify
11
11
 
12
12
  if TYPE_CHECKING:
13
13
  from IPython.core.display import HTML
@@ -112,7 +112,7 @@ def get_tab(name: str, /, create_if_missing: bool = True) -> Tab:
112
112
  return report.get_tab(name, create_if_missing=create_if_missing)
113
113
 
114
114
 
115
- @syncer.wrap
115
+ @syncify
116
116
  async def log(content: str, do_flush: bool = False):
117
117
  """
118
118
  Log content to the main tab. The content should be a valid HTML string, but not a complete HTML document,
@@ -126,7 +126,7 @@ async def log(content: str, do_flush: bool = False):
126
126
  await flush.aio()
127
127
 
128
128
 
129
- @syncer.wrap
129
+ @syncify
130
130
  async def flush():
131
131
  """
132
132
  Flush the report.
@@ -149,7 +149,7 @@ async def flush():
149
149
  logger.debug(f"Report flushed to {final_path}")
150
150
 
151
151
 
152
- @syncer.wrap
152
+ @syncify
153
153
  async def replace(content: str, do_flush: bool = False):
154
154
  """
155
155
  Get the report. Replaces the content of the main tab.
@@ -0,0 +1,5 @@
1
+ from flyte.syncify._api import Syncify
2
+
3
+ syncify = Syncify()
4
+
5
+ __all__ = ["Syncify", "syncify"]