flyte 2.0.0b32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of flyte might be problematic.

Files changed (204)
  1. flyte/__init__.py +108 -0
  2. flyte/_bin/__init__.py +0 -0
  3. flyte/_bin/debug.py +38 -0
  4. flyte/_bin/runtime.py +195 -0
  5. flyte/_bin/serve.py +178 -0
  6. flyte/_build.py +26 -0
  7. flyte/_cache/__init__.py +12 -0
  8. flyte/_cache/cache.py +147 -0
  9. flyte/_cache/defaults.py +9 -0
  10. flyte/_cache/local_cache.py +216 -0
  11. flyte/_cache/policy_function_body.py +42 -0
  12. flyte/_code_bundle/__init__.py +8 -0
  13. flyte/_code_bundle/_ignore.py +121 -0
  14. flyte/_code_bundle/_packaging.py +218 -0
  15. flyte/_code_bundle/_utils.py +347 -0
  16. flyte/_code_bundle/bundle.py +266 -0
  17. flyte/_constants.py +1 -0
  18. flyte/_context.py +155 -0
  19. flyte/_custom_context.py +73 -0
  20. flyte/_debug/__init__.py +0 -0
  21. flyte/_debug/constants.py +38 -0
  22. flyte/_debug/utils.py +17 -0
  23. flyte/_debug/vscode.py +307 -0
  24. flyte/_deploy.py +408 -0
  25. flyte/_deployer.py +109 -0
  26. flyte/_doc.py +29 -0
  27. flyte/_docstring.py +32 -0
  28. flyte/_environment.py +122 -0
  29. flyte/_excepthook.py +37 -0
  30. flyte/_group.py +32 -0
  31. flyte/_hash.py +8 -0
  32. flyte/_image.py +1055 -0
  33. flyte/_initialize.py +628 -0
  34. flyte/_interface.py +119 -0
  35. flyte/_internal/__init__.py +3 -0
  36. flyte/_internal/controllers/__init__.py +129 -0
  37. flyte/_internal/controllers/_local_controller.py +239 -0
  38. flyte/_internal/controllers/_trace.py +48 -0
  39. flyte/_internal/controllers/remote/__init__.py +58 -0
  40. flyte/_internal/controllers/remote/_action.py +211 -0
  41. flyte/_internal/controllers/remote/_client.py +47 -0
  42. flyte/_internal/controllers/remote/_controller.py +583 -0
  43. flyte/_internal/controllers/remote/_core.py +465 -0
  44. flyte/_internal/controllers/remote/_informer.py +381 -0
  45. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  46. flyte/_internal/imagebuild/__init__.py +3 -0
  47. flyte/_internal/imagebuild/docker_builder.py +706 -0
  48. flyte/_internal/imagebuild/image_builder.py +277 -0
  49. flyte/_internal/imagebuild/remote_builder.py +386 -0
  50. flyte/_internal/imagebuild/utils.py +78 -0
  51. flyte/_internal/resolvers/__init__.py +0 -0
  52. flyte/_internal/resolvers/_task_module.py +21 -0
  53. flyte/_internal/resolvers/common.py +31 -0
  54. flyte/_internal/resolvers/default.py +28 -0
  55. flyte/_internal/runtime/__init__.py +0 -0
  56. flyte/_internal/runtime/convert.py +486 -0
  57. flyte/_internal/runtime/entrypoints.py +204 -0
  58. flyte/_internal/runtime/io.py +188 -0
  59. flyte/_internal/runtime/resources_serde.py +152 -0
  60. flyte/_internal/runtime/reuse.py +125 -0
  61. flyte/_internal/runtime/rusty.py +193 -0
  62. flyte/_internal/runtime/task_serde.py +362 -0
  63. flyte/_internal/runtime/taskrunner.py +209 -0
  64. flyte/_internal/runtime/trigger_serde.py +160 -0
  65. flyte/_internal/runtime/types_serde.py +54 -0
  66. flyte/_keyring/__init__.py +0 -0
  67. flyte/_keyring/file.py +115 -0
  68. flyte/_logging.py +300 -0
  69. flyte/_map.py +312 -0
  70. flyte/_module.py +72 -0
  71. flyte/_pod.py +30 -0
  72. flyte/_resources.py +473 -0
  73. flyte/_retry.py +32 -0
  74. flyte/_reusable_environment.py +102 -0
  75. flyte/_run.py +724 -0
  76. flyte/_secret.py +96 -0
  77. flyte/_task.py +550 -0
  78. flyte/_task_environment.py +316 -0
  79. flyte/_task_plugins.py +47 -0
  80. flyte/_timeout.py +47 -0
  81. flyte/_tools.py +27 -0
  82. flyte/_trace.py +119 -0
  83. flyte/_trigger.py +1000 -0
  84. flyte/_utils/__init__.py +30 -0
  85. flyte/_utils/asyn.py +121 -0
  86. flyte/_utils/async_cache.py +139 -0
  87. flyte/_utils/coro_management.py +27 -0
  88. flyte/_utils/docker_credentials.py +173 -0
  89. flyte/_utils/file_handling.py +72 -0
  90. flyte/_utils/helpers.py +134 -0
  91. flyte/_utils/lazy_module.py +54 -0
  92. flyte/_utils/module_loader.py +104 -0
  93. flyte/_utils/org_discovery.py +57 -0
  94. flyte/_utils/uv_script_parser.py +49 -0
  95. flyte/_version.py +34 -0
  96. flyte/app/__init__.py +22 -0
  97. flyte/app/_app_environment.py +157 -0
  98. flyte/app/_deploy.py +125 -0
  99. flyte/app/_input.py +160 -0
  100. flyte/app/_runtime/__init__.py +3 -0
  101. flyte/app/_runtime/app_serde.py +347 -0
  102. flyte/app/_types.py +101 -0
  103. flyte/app/extras/__init__.py +3 -0
  104. flyte/app/extras/_fastapi.py +151 -0
  105. flyte/cli/__init__.py +12 -0
  106. flyte/cli/_abort.py +28 -0
  107. flyte/cli/_build.py +114 -0
  108. flyte/cli/_common.py +468 -0
  109. flyte/cli/_create.py +371 -0
  110. flyte/cli/_delete.py +45 -0
  111. flyte/cli/_deploy.py +293 -0
  112. flyte/cli/_gen.py +176 -0
  113. flyte/cli/_get.py +370 -0
  114. flyte/cli/_option.py +33 -0
  115. flyte/cli/_params.py +554 -0
  116. flyte/cli/_plugins.py +209 -0
  117. flyte/cli/_run.py +597 -0
  118. flyte/cli/_serve.py +64 -0
  119. flyte/cli/_update.py +37 -0
  120. flyte/cli/_user.py +17 -0
  121. flyte/cli/main.py +221 -0
  122. flyte/config/__init__.py +3 -0
  123. flyte/config/_config.py +248 -0
  124. flyte/config/_internal.py +73 -0
  125. flyte/config/_reader.py +225 -0
  126. flyte/connectors/__init__.py +11 -0
  127. flyte/connectors/_connector.py +270 -0
  128. flyte/connectors/_server.py +197 -0
  129. flyte/connectors/utils.py +135 -0
  130. flyte/errors.py +243 -0
  131. flyte/extend.py +19 -0
  132. flyte/extras/__init__.py +5 -0
  133. flyte/extras/_container.py +286 -0
  134. flyte/git/__init__.py +3 -0
  135. flyte/git/_config.py +21 -0
  136. flyte/io/__init__.py +29 -0
  137. flyte/io/_dataframe/__init__.py +131 -0
  138. flyte/io/_dataframe/basic_dfs.py +223 -0
  139. flyte/io/_dataframe/dataframe.py +1026 -0
  140. flyte/io/_dir.py +910 -0
  141. flyte/io/_file.py +914 -0
  142. flyte/io/_hashing_io.py +342 -0
  143. flyte/models.py +479 -0
  144. flyte/py.typed +0 -0
  145. flyte/remote/__init__.py +35 -0
  146. flyte/remote/_action.py +738 -0
  147. flyte/remote/_app.py +57 -0
  148. flyte/remote/_client/__init__.py +0 -0
  149. flyte/remote/_client/_protocols.py +189 -0
  150. flyte/remote/_client/auth/__init__.py +12 -0
  151. flyte/remote/_client/auth/_auth_utils.py +14 -0
  152. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  153. flyte/remote/_client/auth/_authenticators/base.py +403 -0
  154. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  155. flyte/remote/_client/auth/_authenticators/device_code.py +117 -0
  156. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  157. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  158. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  159. flyte/remote/_client/auth/_channel.py +213 -0
  160. flyte/remote/_client/auth/_client_config.py +85 -0
  161. flyte/remote/_client/auth/_default_html.py +32 -0
  162. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  163. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  164. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  165. flyte/remote/_client/auth/_keyring.py +152 -0
  166. flyte/remote/_client/auth/_token_client.py +260 -0
  167. flyte/remote/_client/auth/errors.py +16 -0
  168. flyte/remote/_client/controlplane.py +128 -0
  169. flyte/remote/_common.py +30 -0
  170. flyte/remote/_console.py +19 -0
  171. flyte/remote/_data.py +161 -0
  172. flyte/remote/_logs.py +185 -0
  173. flyte/remote/_project.py +88 -0
  174. flyte/remote/_run.py +386 -0
  175. flyte/remote/_secret.py +142 -0
  176. flyte/remote/_task.py +527 -0
  177. flyte/remote/_trigger.py +306 -0
  178. flyte/remote/_user.py +33 -0
  179. flyte/report/__init__.py +3 -0
  180. flyte/report/_report.py +182 -0
  181. flyte/report/_template.html +124 -0
  182. flyte/storage/__init__.py +36 -0
  183. flyte/storage/_config.py +237 -0
  184. flyte/storage/_parallel_reader.py +274 -0
  185. flyte/storage/_remote_fs.py +34 -0
  186. flyte/storage/_storage.py +456 -0
  187. flyte/storage/_utils.py +5 -0
  188. flyte/syncify/__init__.py +56 -0
  189. flyte/syncify/_api.py +375 -0
  190. flyte/types/__init__.py +52 -0
  191. flyte/types/_interface.py +40 -0
  192. flyte/types/_pickle.py +145 -0
  193. flyte/types/_renderer.py +162 -0
  194. flyte/types/_string_literals.py +119 -0
  195. flyte/types/_type_engine.py +2254 -0
  196. flyte/types/_utils.py +80 -0
  197. flyte-2.0.0b32.data/scripts/debug.py +38 -0
  198. flyte-2.0.0b32.data/scripts/runtime.py +195 -0
  199. flyte-2.0.0b32.dist-info/METADATA +351 -0
  200. flyte-2.0.0b32.dist-info/RECORD +204 -0
  201. flyte-2.0.0b32.dist-info/WHEEL +5 -0
  202. flyte-2.0.0b32.dist-info/entry_points.txt +7 -0
  203. flyte-2.0.0b32.dist-info/licenses/LICENSE +201 -0
  204. flyte-2.0.0b32.dist-info/top_level.txt +1 -0
flyte/remote/_client/controlplane.py ADDED
@@ -0,0 +1,128 @@
+ from __future__ import annotations
+
+ import os
+
+ # Set environment variables for gRPC, this reduces log spew and avoids unnecessary warnings
+ # before importing grpc
+ if "GRPC_VERBOSITY" not in os.environ:
+     os.environ["GRPC_VERBOSITY"] = "ERROR"
+     os.environ["GRPC_CPP_MIN_LOG_LEVEL"] = "ERROR"
+     # Disable fork support (stops "skipping fork() handlers")
+     os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "0"
+     # Reduce absl/glog verbosity
+     os.environ["GLOG_minloglevel"] = "2"
+     os.environ["ABSL_LOG"] = "0"
+ #### Has to be before grpc
+
+ import grpc
+ from flyteidl.service import admin_pb2_grpc, dataproxy_pb2_grpc, identity_pb2_grpc
+ from flyteidl2.app import app_service_pb2_grpc
+ from flyteidl2.secret import secret_pb2_grpc
+ from flyteidl2.task import task_service_pb2_grpc
+ from flyteidl2.trigger import trigger_service_pb2_grpc
+ from flyteidl2.workflow import run_logs_service_pb2_grpc, run_service_pb2_grpc
+
+ from ._protocols import (
+     AppService,
+     DataProxyService,
+     IdentityService,
+     MetadataServiceProtocol,
+     ProjectDomainService,
+     RunLogsService,
+     RunService,
+     SecretService,
+     TaskService,
+     TriggerService,
+ )
+ from .auth import create_channel
+
+
+ class ClientSet:
+     def __init__(
+         self,
+         channel: grpc.aio.Channel,
+         endpoint: str,
+         insecure: bool = False,
+         **kwargs,
+     ):
+         self.endpoint = endpoint
+         self.insecure = insecure
+         self._channel = channel
+         self._admin_client = admin_pb2_grpc.AdminServiceStub(channel=channel)
+         self._task_service = task_service_pb2_grpc.TaskServiceStub(channel=channel)
+         self._app_service = app_service_pb2_grpc.AppServiceStub(channel=channel)
+         self._run_service = run_service_pb2_grpc.RunServiceStub(channel=channel)
+         self._dataproxy = dataproxy_pb2_grpc.DataProxyServiceStub(channel=channel)
+         self._log_service = run_logs_service_pb2_grpc.RunLogsServiceStub(channel=channel)
+         self._secrets_service = secret_pb2_grpc.SecretServiceStub(channel=channel)
+         self._identity_service = identity_pb2_grpc.IdentityServiceStub(channel=channel)
+         self._trigger_service = trigger_service_pb2_grpc.TriggerServiceStub(channel=channel)
+
+     @classmethod
+     async def for_endpoint(cls, endpoint: str, *, insecure: bool = False, **kwargs) -> ClientSet:
+         return cls(
+             await create_channel(endpoint, None, insecure=insecure, **kwargs), endpoint, insecure=insecure, **kwargs
+         )
+
+     @classmethod
+     async def for_api_key(cls, api_key: str, *, insecure: bool = False, **kwargs) -> ClientSet:
+         from flyte.remote._client.auth._auth_utils import decode_api_key
+
+         # Parsing the API key is done in create_channel, but cleaner to redo it here rather than getting create_channel
+         # to return the endpoint
+         endpoint, _, _, _ = decode_api_key(api_key)
+
+         return cls(
+             await create_channel(None, api_key, insecure=insecure, **kwargs), endpoint, insecure=insecure, **kwargs
+         )
+
+     @classmethod
+     async def for_serverless(cls) -> ClientSet:
+         raise NotImplementedError
+
+     @classmethod
+     async def from_env(cls) -> ClientSet:
+         raise NotImplementedError
+
+     @property
+     def metadata_service(self) -> MetadataServiceProtocol:
+         return self._admin_client
+
+     @property
+     def project_domain_service(self) -> ProjectDomainService:
+         return self._admin_client
+
+     @property
+     def task_service(self) -> TaskService:
+         return self._task_service
+
+     @property
+     def app_service(self) -> AppService:
+         return self._app_service
+
+     @property
+     def run_service(self) -> RunService:
+         return self._run_service
+
+     @property
+     def dataproxy_service(self) -> DataProxyService:
+         return self._dataproxy
+
+     @property
+     def logs_service(self) -> RunLogsService:
+         return self._log_service
+
+     @property
+     def secrets_service(self) -> SecretService:
+         return self._secrets_service
+
+     @property
+     def identity_service(self) -> IdentityService:
+         return self._identity_service
+
+     @property
+     def trigger_service(self) -> TriggerService:
+         return self._trigger_service
+
+     async def close(self, grace: float | None = None):
+         return await self._channel.close(grace=grace)
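Editor's note: for orientation, a minimal usage sketch of the ClientSet above (not part of the package diff). It relies only on the constructors and properties shown; the endpoint is a placeholder and error handling is omitted.

import asyncio

from flyte.remote._client.controlplane import ClientSet


async def main() -> None:
    # Placeholder endpoint; for_endpoint() builds the gRPC channel via create_channel().
    clients = await ClientSet.for_endpoint("dns:///flyte.example.com", insecure=False)
    try:
        # Each property returns one of the generated gRPC service stubs.
        run_service = clients.run_service
        _ = run_service  # call the generated stub methods here as needed
    finally:
        await clients.close(grace=1.0)


asyncio.run(main())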
flyte/remote/_common.py ADDED
@@ -0,0 +1,30 @@
+ import json
+
+ from google.protobuf.json_format import MessageToDict, MessageToJson
+
+
+ class ToJSONMixin:
+     """
+     A mixin class that provides a method to convert an object to a JSON-serializable dictionary.
+     """
+
+     def to_dict(self) -> dict:
+         """
+         Convert the object to a JSON-serializable dictionary.
+
+         Returns:
+             dict: A dictionary representation of the object.
+         """
+         if hasattr(self, "pb2"):
+             return MessageToDict(self.pb2)
+         else:
+             return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
+
+     def to_json(self) -> str:
+         """
+         Convert the object to a JSON string.
+
+         Returns:
+             str: A JSON string representation of the object.
+         """
+         return MessageToJson(self.pb2) if hasattr(self, "pb2") else json.dumps(self.to_dict())
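Editor's note: a short illustration of the two paths to_dict()/to_json() take, using a hypothetical plain dataclass with no pb2 attribute so the __dict__ fallback is exercised (not part of the package).

from dataclasses import dataclass

from flyte.remote._common import ToJSONMixin


@dataclass
class Endpoint(ToJSONMixin):
    # Hypothetical example type: no `pb2` attribute, so to_dict() falls back to __dict__.
    host: str
    port: int
    _secret: str = "not serialized"  # leading underscore, filtered out of to_dict()


e = Endpoint(host="localhost", port=8080)
assert e.to_dict() == {"host": "localhost", "port": 8080}
print(e.to_json())  # json.dumps(...) path, since there is no pb2 message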
flyte/remote/_console.py ADDED
@@ -0,0 +1,19 @@
+ from urllib.parse import urlparse
+
+
+ def _get_http_domain(endpoint: str, insecure: bool) -> str:
+     scheme = "http" if insecure else "https"
+     parsed = urlparse(endpoint)
+     if parsed.scheme == "dns":
+         domain = parsed.path.lstrip("/")
+     else:
+         domain = parsed.netloc or parsed.path
+     # TODO: make console url configurable
+     domain_split = domain.split(":")
+     if domain_split[0] == "localhost":
+         domain = domain if len(domain_split) > 1 else f"{domain}:8080"
+     return f"{scheme}://{domain}"
+
+
+ def get_run_url(endpoint: str, insecure: bool, project: str, domain: str, run_name: str) -> str:
+     return f"{_get_http_domain(endpoint, insecure)}/v2/runs/project/{project}/domain/{domain}/{run_name}"
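Editor's note: a quick illustration of the URL shapes these helpers produce (not part of the diff; the endpoints are made up).

from flyte.remote._console import get_run_url

# dns:// endpoints are unwrapped to their host, and https is used unless insecure=True.
print(get_run_url("dns:///flyte.example.com", False, "flytesnacks", "development", "r123"))
# -> https://flyte.example.com/v2/runs/project/flytesnacks/domain/development/r123

# A bare "localhost" gets the default :8080 console port appended.
print(get_run_url("localhost", True, "flytesnacks", "development", "r123"))
# -> http://localhost:8080/v2/runs/project/flytesnacks/domain/development/r123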
flyte/remote/_data.py ADDED
@@ -0,0 +1,161 @@
+ import asyncio
+ import hashlib
+ import os
+ import typing
+ import uuid
+ from base64 import b64encode
+ from datetime import timedelta
+ from functools import lru_cache
+ from pathlib import Path
+ from typing import Tuple
+
+ import aiofiles
+ import grpc
+ import httpx
+ from flyteidl.service import dataproxy_pb2
+ from google.protobuf import duration_pb2
+
+ from flyte._initialize import CommonInit, ensure_client, get_client, get_init_config, require_project_and_domain
+ from flyte.errors import InitializationError, RuntimeSystemError
+ from flyte.syncify import syncify
+
+ _UPLOAD_EXPIRES_IN = timedelta(seconds=60)
+
+
+ def get_extra_headers_for_protocol(native_url: str) -> typing.Dict[str, str]:
+     """
+     For Azure Blob Storage, we need to set certain headers for http request.
+     This is used when we work with signed urls.
+     :param native_url:
+     :return:
+     """
+     if native_url.startswith("abfs://"):
+         return {"x-ms-blob-type": "BlockBlob"}
+     return {}
+
+
+ @lru_cache
+ def hash_file(file_path: typing.Union[os.PathLike, str]) -> Tuple[bytes, str, int]:
+     """
+     Hash a file and produce a digest to be used as a version
+     """
+     h = hashlib.md5()
+     size = 0
+
+     with open(file_path, "rb") as file:
+         while True:
+             # Reading is buffered, so we can read smaller chunks.
+             chunk = file.read(h.block_size)
+             if not chunk:
+                 break
+             h.update(chunk)
+             size += len(chunk)
+
+     return h.digest(), h.hexdigest(), size
+
+
+ @require_project_and_domain
+ async def _upload_single_file(
+     cfg: CommonInit, fp: Path, verify: bool = True, basedir: str | None = None
+ ) -> Tuple[str, str]:
+     md5_bytes, str_digest, _ = hash_file(fp)
+     from flyte._logging import logger
+
+     try:
+         expires_in_pb = duration_pb2.Duration()
+         expires_in_pb.FromTimedelta(_UPLOAD_EXPIRES_IN)
+         client = get_client()
+         resp = await client.dataproxy_service.CreateUploadLocation(  # type: ignore
+             dataproxy_pb2.CreateUploadLocationRequest(
+                 project=cfg.project,
+                 domain=cfg.domain,
+                 content_md5=md5_bytes,
+                 filename=fp.name,
+                 expires_in=expires_in_pb,
+                 filename_root=basedir,
+                 add_content_md5_metadata=True,
+             )
+         )
+     except grpc.aio.AioRpcError as e:
+         if e.code() == grpc.StatusCode.NOT_FOUND:
+             raise RuntimeSystemError(
+                 "NotFound", f"Failed to get signed url for {fp}, please check your project and domain: {e.details()}"
+             )
+         elif e.code() == grpc.StatusCode.PERMISSION_DENIED:
+             raise RuntimeSystemError(
+                 "PermissionDenied", f"Failed to get signed url for {fp}, please check your permissions: {e.details()}"
+             )
+         elif e.code() == grpc.StatusCode.UNAVAILABLE:
+             raise InitializationError("EndpointUnavailable", "user", "Service is unavailable.")
+         else:
+             raise RuntimeSystemError(e.code().value, f"Failed to get signed url for {fp}: {e.details()}")
+     except Exception as e:
+         raise RuntimeSystemError(type(e).__name__, f"Failed to get signed url for {fp}.") from e
+     logger.debug(f"Uploading to [link={resp.signed_url}]signed url[/link] for [link=file://{fp}]{fp}[/link]")
+     extra_headers = get_extra_headers_for_protocol(resp.native_url)
+     extra_headers.update(resp.headers)
+     encoded_md5 = b64encode(md5_bytes)
+     content_length = fp.stat().st_size
+
+     async with aiofiles.open(str(fp), "rb") as file:
+         extra_headers.update({"Content-Length": str(content_length), "Content-MD5": encoded_md5.decode("utf-8")})
+         async with httpx.AsyncClient(verify=verify) as aclient:
+             put_resp = await aclient.put(resp.signed_url, headers=extra_headers, content=file)
+             if put_resp.status_code not in [200, 201, 204]:
+                 raise RuntimeSystemError(
+                     "UploadFailed",
+                     f"Failed to upload {fp} to {resp.signed_url}, status code: {put_resp.status_code}, "
+                     f"response: {put_resp.text}",
+                 )
+     # TODO in old code we did this
+     # if self._config.platform.insecure_skip_verify is True
+     # else self._config.platform.ca_cert_file_path,
+     logger.debug(f"Uploaded with digest {str_digest}, blob location is {resp.native_url}")
+     return str_digest, resp.native_url
+
+
+ @syncify
+ async def upload_file(fp: Path, verify: bool = True) -> Tuple[str, str]:
+     """
+     Uploads a file to a remote location and returns the remote URI.
+
+     :param fp: The file path to upload.
+     :param verify: Whether to verify the certificate for HTTPS requests.
+     :return: A tuple containing the MD5 digest and the remote URI.
+     """
+     # This is a placeholder implementation. Replace with actual upload logic.
+     ensure_client()
+     cfg = get_init_config()
+     if not fp.is_file():
+         raise ValueError(f"{fp} is not a single file, upload arg must be a single file.")
+     return await _upload_single_file(cfg, fp, verify=verify)
+
+
+ async def upload_dir(dir_path: Path, verify: bool = True) -> str:
+     """
+     Uploads a directory to a remote location and returns the remote URI.
+
+     :param dir_path: The directory path to upload.
+     :param verify: Whether to verify the certificate for HTTPS requests.
+     :return: The remote URI of the uploaded directory.
+     """
+     # This is a placeholder implementation. Replace with actual upload logic.
+     ensure_client()
+     cfg = get_init_config()
+     if not dir_path.is_dir():
+         raise ValueError(f"{dir_path} is not a directory, upload arg must be a directory.")
+
+     prefix = uuid.uuid4().hex
+
+     files = dir_path.rglob("*")
+     uploaded_files = []
+     for file in files:
+         if file.is_file():
+             uploaded_files.append(_upload_single_file(cfg, file, verify=verify, basedir=prefix))
+
+     urls = await asyncio.gather(*uploaded_files)
+     native_url = urls[0][1]  # Assuming all files are uploaded to the same prefix
+     # native_url is of the form s3://my-s3-bucket/flytesnacks/development/{prefix}/source/empty.md
+     uri = native_url.split(prefix)[0] + "/" + prefix
+
+     return uri
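Editor's note: a minimal usage sketch for these upload helpers (not part of the diff), assuming the flyte client and project/domain configuration have already been initialized elsewhere; paths are placeholders. upload_file is wrapped with @syncify, so the async form is assumed to be reachable via .aio (the same pattern used elsewhere in this package), while upload_dir is a plain coroutine.

import asyncio
from pathlib import Path

from flyte.remote._data import upload_dir, upload_file


async def main() -> None:
    # Single file: returns (md5 hex digest, remote native URL).
    digest, remote_url = await upload_file.aio(Path("model.tar.gz"))
    print(digest, remote_url)

    # Whole directory: each file is uploaded via a signed URL under one random prefix,
    # and the common prefix URI is returned.
    uri = await upload_dir(Path("./configs"))
    print(uri)


asyncio.run(main())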
flyte/remote/_logs.py ADDED
@@ -0,0 +1,185 @@
+ import asyncio
+ from collections import deque
+ from dataclasses import dataclass
+ from typing import AsyncGenerator, AsyncIterator
+
+ import grpc
+ from flyteidl2.common import identifier_pb2
+ from flyteidl2.logs.dataplane import payload_pb2
+ from flyteidl2.workflow import run_logs_service_pb2
+ from rich.console import Console
+ from rich.live import Live
+ from rich.panel import Panel
+ from rich.text import Text
+
+ from flyte._initialize import ensure_client, get_client
+ from flyte._logging import logger
+ from flyte._tools import ipython_check, ipywidgets_check
+ from flyte.errors import LogsNotYetAvailableError
+ from flyte.syncify import syncify
+
+ style_map = {
+     payload_pb2.LogLineOriginator.SYSTEM: "bold magenta",
+     payload_pb2.LogLineOriginator.USER: "cyan",
+     payload_pb2.LogLineOriginator.UNKNOWN: "light red",
+ }
+
+
+ def _format_line(logline: payload_pb2.LogLine, show_ts: bool, filter_system: bool) -> Text | None:
+     if filter_system:
+         if logline.originator == payload_pb2.LogLineOriginator.SYSTEM:
+             return None
+     style = style_map.get(logline.originator, "")
+     if "[flyte]" in logline.message and "flyte.errors" not in logline.message:
+         if filter_system:
+             return None
+         style = "dim"
+     ts = ""
+     if show_ts:
+         ts = f"[{logline.timestamp.ToDatetime().isoformat()}]"
+     return Text(f"{ts} {logline.message}", style=style)
+
+
+ class AsyncLogViewer:
+     """
+     A class to view logs asynchronously in the console or terminal or jupyter notebook.
+     """
+
+     def __init__(
+         self,
+         log_source: AsyncIterator,
+         max_lines: int = 30,
+         name: str = "Logs",
+         show_ts: bool = False,
+         filter_system: bool = False,
+         panel: bool = False,
+     ):
+         self.console = Console()
+         self.log_source = log_source
+         self.max_lines = max_lines
+         self.lines: deque = deque(maxlen=max_lines + 1)
+         self.name = name
+         self.show_ts = show_ts
+         self.total_lines = 0
+         self.filter_flyte = filter_system
+         self.panel = panel
+
+     def _render(self) -> Panel | Text:
+         log_text = Text()
+         for line in self.lines:
+             log_text.append(line)
+         if self.panel:
+             return Panel(log_text, title=self.name, border_style="yellow")
+         return log_text
+
+     async def run(self):
+         with Live(self._render(), refresh_per_second=20, console=self.console) as live:
+             try:
+                 async for logline in self.log_source:
+                     formatted = _format_line(logline, show_ts=self.show_ts, filter_system=self.filter_flyte)
+                     if formatted:
+                         self.lines.append(formatted)
+                         self.total_lines += 1
+                         live.update(self._render())
+             except asyncio.CancelledError:
+                 pass
+             except KeyboardInterrupt:
+                 pass
+             except StopAsyncIteration:
+                 self.console.print("[dim]Log stream ended.[/dim]")
+             except LogsNotYetAvailableError as e:
+                 self.console.print(f"[red]Error:[/red] {e}")
+             live.update("")
+         self.console.print(f"Scrolled {self.total_lines} lines of logs.")
+
+
+ @dataclass
+ class Logs:
+     @syncify
+     @classmethod
+     async def tail(
+         cls,
+         action_id: identifier_pb2.ActionIdentifier,
+         attempt: int = 1,
+         retry: int = 5,
+     ) -> AsyncGenerator[payload_pb2.LogLine, None]:
+         """
+         Tail the logs for a given action ID and attempt.
+         :param action_id: The action ID to tail logs for.
+         :param attempt: The attempt number (default is 0).
+         """
+         ensure_client()
+         retries = 0
+         while True:
+             try:
+                 resp = get_client().logs_service.TailLogs(
+                     run_logs_service_pb2.TailLogsRequest(action_id=action_id, attempt=attempt)
+                 )
+                 async for log_set in resp:
+                     if log_set.logs:
+                         for log in log_set.logs:
+                             for line in log.lines:
+                                 yield line
+                 return
+             except asyncio.CancelledError:
+                 return
+             except KeyboardInterrupt:
+                 return
+             except StopAsyncIteration:
+                 return
+             except grpc.aio.AioRpcError as e:
+                 retries += 1
+                 if retries >= retry:
+                     if e.code() == grpc.StatusCode.NOT_FOUND:
+                         raise LogsNotYetAvailableError(
+                             f"Log stream not available for action {action_id.name} in run {action_id.run.name}."
+                         )
+                 else:
+                     await asyncio.sleep(2)
+
+     @classmethod
+     async def create_viewer(
+         cls,
+         action_id: identifier_pb2.ActionIdentifier,
+         attempt: int = 1,
+         max_lines: int = 30,
+         show_ts: bool = False,
+         raw: bool = False,
+         filter_system: bool = False,
+         panel: bool = False,
+     ):
+         """
+         Create a log viewer for a given action ID and attempt.
+         :param action_id: Action ID to view logs for.
+         :param attempt: Attempt number (default is 1).
+         :param max_lines: Maximum number of lines to show if using the viewer. The logger will scroll
+             and keep only max_lines in view.
+         :param show_ts: Whether to show timestamps in the logs.
+         :param raw: if True, return the raw log lines instead of a viewer.
+         :param filter_system: Whether to filter log lines based on system logs.
+         :param panel: Whether to use a panel for the log viewer. only applicable if raw is False.
+         """
+         if attempt < 1:
+             raise ValueError("Attempt number must be greater than 0.")
+
+         if ipython_check():
+             if not ipywidgets_check():
+                 logger.warning("IPython widgets is not available, defaulting to console output.")
+                 raw = True
+
+         if raw:
+             console = Console()
+             async for line in cls.tail.aio(action_id=action_id, attempt=attempt):
+                 line_text = _format_line(line, show_ts=show_ts, filter_system=filter_system)
+                 if line_text:
+                     console.print(line_text, end="")
+             return
+         viewer = AsyncLogViewer(
+             log_source=cls.tail.aio(action_id=action_id, attempt=attempt),
+             max_lines=max_lines,
+             show_ts=show_ts,
+             name=f"{action_id.run.name}:{action_id.name} ({attempt})",
+             filter_system=filter_system,
+             panel=panel,
+         )
+         await viewer.run()
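Editor's note: a small sketch of how the log viewer might be driven (not part of the diff). The client is assumed to be initialized already, the run/action names are placeholders, and the exact identifier message layout (e.g. a RunIdentifier with project/domain fields) is an assumption beyond what the diff shows, which only uses action_id.name and action_id.run.name.

import asyncio

from flyteidl2.common import identifier_pb2

from flyte.remote._logs import Logs


async def main() -> None:
    # Hypothetical identifiers; in practice these come from an existing run/action.
    action_id = identifier_pb2.ActionIdentifier(
        name="a0",
        run=identifier_pb2.RunIdentifier(name="r123", project="flytesnacks", domain="development"),
    )
    # raw=True streams formatted lines straight to the console instead of the Live panel viewer.
    await Logs.create_viewer(action_id=action_id, attempt=1, show_ts=True, raw=True)


asyncio.run(main())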
flyte/remote/_project.py ADDED
@@ -0,0 +1,88 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import AsyncIterator, Iterator, Literal, Tuple, Union
+
+ import rich.repr
+ from flyteidl.admin import common_pb2, project_pb2
+
+ from flyte._initialize import ensure_client, get_client
+ from flyte.syncify import syncify
+
+ from ._common import ToJSONMixin
+
+
+ # TODO Add support for orgs again
+ @dataclass
+ class Project(ToJSONMixin):
+     """
+     A class representing a project in the Union API.
+     """
+
+     pb2: project_pb2.Project
+
+     @syncify
+     @classmethod
+     async def get(cls, name: str, org: str | None = None) -> Project:
+         """
+         Get a run by its ID or name. If both are provided, the ID will take precedence.
+
+         :param name: The name of the project.
+         :param org: The organization of the project (if applicable).
+         """
+         ensure_client()
+         service = get_client().project_domain_service  # type: ignore
+         resp = await service.GetProject(
+             project_pb2.ProjectGetRequest(
+                 id=name,
+                 # org=org,
+             )
+         )
+         return cls(resp)
+
+     @syncify
+     @classmethod
+     async def listall(
+         cls,
+         filters: str | None = None,
+         sort_by: Tuple[str, Literal["asc", "desc"]] | None = None,
+     ) -> Union[AsyncIterator[Project], Iterator[Project]]:
+         """
+         Get a run by its ID or name. If both are provided, the ID will take precedence.
+
+         :param filters: The filters to apply to the project list.
+         :param sort_by: The sorting criteria for the project list, in the format (field, order).
+         :return: An iterator of projects.
+         """
+         ensure_client()
+         token = None
+         sort_by = sort_by or ("created_at", "asc")
+         sort_pb2 = common_pb2.Sort(
+             key=sort_by[0], direction=common_pb2.Sort.ASCENDING if sort_by[1] == "asc" else common_pb2.Sort.DESCENDING
+         )
+         # org = get_common_config().org
+         while True:
+             resp = await get_client().project_domain_service.ListProjects(  # type: ignore
+                 project_pb2.ProjectListRequest(
+                     limit=100,
+                     token=token,
+                     filters=filters,
+                     sort_by=sort_pb2,
+                     # org=org,
+                 )
+             )
+             token = resp.token
+             for p in resp.projects:
+                 yield cls(p)
+             if not token:
+                 break
+
+     def __rich_repr__(self) -> rich.repr.Result:
+         yield "name", self.pb2.name
+         yield "id", self.pb2.id
+         yield "description", self.pb2.description
+         yield "state", project_pb2.Project.ProjectState.Name(self.pb2.state)
+         yield (
+             "labels",
+             ", ".join([f"{k}: {v}" for k, v in self.pb2.labels.values.items()]) if self.pb2.labels else None,
+         )
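Editor's note: a short sketch of the Project helpers above (not part of the diff), assuming the client has been initialized. get and listall are @syncify-wrapped classmethods, so the async forms are assumed to be reachable via .aio, matching the pattern used in flyte/remote/_logs.py; the project name is a placeholder.

import asyncio

from flyte.remote._project import Project


async def main() -> None:
    # Fetch one project by name.
    p = await Project.get.aio("flytesnacks")
    print(p.to_json())  # ToJSONMixin serializes the underlying project_pb2.Project

    # Page through all projects; created_at ascending is the default sort order.
    async for project in Project.listall.aio():
        print(project.pb2.id, project.pb2.name)


asyncio.run(main())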