flyte 0.1.0__py3-none-any.whl → 0.2.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (219) hide show
  1. flyte/__init__.py +78 -2
  2. flyte/_bin/__init__.py +0 -0
  3. flyte/_bin/runtime.py +152 -0
  4. flyte/_build.py +26 -0
  5. flyte/_cache/__init__.py +12 -0
  6. flyte/_cache/cache.py +145 -0
  7. flyte/_cache/defaults.py +9 -0
  8. flyte/_cache/policy_function_body.py +42 -0
  9. flyte/_code_bundle/__init__.py +8 -0
  10. flyte/_code_bundle/_ignore.py +113 -0
  11. flyte/_code_bundle/_packaging.py +187 -0
  12. flyte/_code_bundle/_utils.py +323 -0
  13. flyte/_code_bundle/bundle.py +209 -0
  14. flyte/_context.py +152 -0
  15. flyte/_deploy.py +243 -0
  16. flyte/_doc.py +29 -0
  17. flyte/_docstring.py +32 -0
  18. flyte/_environment.py +84 -0
  19. flyte/_excepthook.py +37 -0
  20. flyte/_group.py +32 -0
  21. flyte/_hash.py +23 -0
  22. flyte/_image.py +762 -0
  23. flyte/_initialize.py +492 -0
  24. flyte/_interface.py +84 -0
  25. flyte/_internal/__init__.py +3 -0
  26. flyte/_internal/controllers/__init__.py +128 -0
  27. flyte/_internal/controllers/_local_controller.py +193 -0
  28. flyte/_internal/controllers/_trace.py +41 -0
  29. flyte/_internal/controllers/remote/__init__.py +60 -0
  30. flyte/_internal/controllers/remote/_action.py +146 -0
  31. flyte/_internal/controllers/remote/_client.py +47 -0
  32. flyte/_internal/controllers/remote/_controller.py +494 -0
  33. flyte/_internal/controllers/remote/_core.py +410 -0
  34. flyte/_internal/controllers/remote/_informer.py +361 -0
  35. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  36. flyte/_internal/imagebuild/__init__.py +11 -0
  37. flyte/_internal/imagebuild/docker_builder.py +427 -0
  38. flyte/_internal/imagebuild/image_builder.py +246 -0
  39. flyte/_internal/imagebuild/remote_builder.py +0 -0
  40. flyte/_internal/resolvers/__init__.py +0 -0
  41. flyte/_internal/resolvers/_task_module.py +54 -0
  42. flyte/_internal/resolvers/common.py +31 -0
  43. flyte/_internal/resolvers/default.py +28 -0
  44. flyte/_internal/runtime/__init__.py +0 -0
  45. flyte/_internal/runtime/convert.py +342 -0
  46. flyte/_internal/runtime/entrypoints.py +135 -0
  47. flyte/_internal/runtime/io.py +136 -0
  48. flyte/_internal/runtime/resources_serde.py +138 -0
  49. flyte/_internal/runtime/task_serde.py +330 -0
  50. flyte/_internal/runtime/taskrunner.py +191 -0
  51. flyte/_internal/runtime/types_serde.py +54 -0
  52. flyte/_logging.py +135 -0
  53. flyte/_map.py +215 -0
  54. flyte/_pod.py +19 -0
  55. flyte/_protos/__init__.py +0 -0
  56. flyte/_protos/common/authorization_pb2.py +66 -0
  57. flyte/_protos/common/authorization_pb2.pyi +108 -0
  58. flyte/_protos/common/authorization_pb2_grpc.py +4 -0
  59. flyte/_protos/common/identifier_pb2.py +71 -0
  60. flyte/_protos/common/identifier_pb2.pyi +82 -0
  61. flyte/_protos/common/identifier_pb2_grpc.py +4 -0
  62. flyte/_protos/common/identity_pb2.py +48 -0
  63. flyte/_protos/common/identity_pb2.pyi +72 -0
  64. flyte/_protos/common/identity_pb2_grpc.py +4 -0
  65. flyte/_protos/common/list_pb2.py +36 -0
  66. flyte/_protos/common/list_pb2.pyi +71 -0
  67. flyte/_protos/common/list_pb2_grpc.py +4 -0
  68. flyte/_protos/common/policy_pb2.py +37 -0
  69. flyte/_protos/common/policy_pb2.pyi +27 -0
  70. flyte/_protos/common/policy_pb2_grpc.py +4 -0
  71. flyte/_protos/common/role_pb2.py +37 -0
  72. flyte/_protos/common/role_pb2.pyi +53 -0
  73. flyte/_protos/common/role_pb2_grpc.py +4 -0
  74. flyte/_protos/common/runtime_version_pb2.py +28 -0
  75. flyte/_protos/common/runtime_version_pb2.pyi +24 -0
  76. flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
  77. flyte/_protos/logs/dataplane/payload_pb2.py +100 -0
  78. flyte/_protos/logs/dataplane/payload_pb2.pyi +177 -0
  79. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
  80. flyte/_protos/secret/definition_pb2.py +49 -0
  81. flyte/_protos/secret/definition_pb2.pyi +93 -0
  82. flyte/_protos/secret/definition_pb2_grpc.py +4 -0
  83. flyte/_protos/secret/payload_pb2.py +62 -0
  84. flyte/_protos/secret/payload_pb2.pyi +94 -0
  85. flyte/_protos/secret/payload_pb2_grpc.py +4 -0
  86. flyte/_protos/secret/secret_pb2.py +38 -0
  87. flyte/_protos/secret/secret_pb2.pyi +6 -0
  88. flyte/_protos/secret/secret_pb2_grpc.py +198 -0
  89. flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
  90. flyte/_protos/validate/validate/validate_pb2.py +76 -0
  91. flyte/_protos/workflow/common_pb2.py +27 -0
  92. flyte/_protos/workflow/common_pb2.pyi +14 -0
  93. flyte/_protos/workflow/common_pb2_grpc.py +4 -0
  94. flyte/_protos/workflow/environment_pb2.py +29 -0
  95. flyte/_protos/workflow/environment_pb2.pyi +12 -0
  96. flyte/_protos/workflow/environment_pb2_grpc.py +4 -0
  97. flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
  98. flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
  99. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
  100. flyte/_protos/workflow/queue_service_pb2.py +105 -0
  101. flyte/_protos/workflow/queue_service_pb2.pyi +146 -0
  102. flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
  103. flyte/_protos/workflow/run_definition_pb2.py +128 -0
  104. flyte/_protos/workflow/run_definition_pb2.pyi +314 -0
  105. flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
  106. flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
  107. flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
  108. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
  109. flyte/_protos/workflow/run_service_pb2.py +129 -0
  110. flyte/_protos/workflow/run_service_pb2.pyi +171 -0
  111. flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
  112. flyte/_protos/workflow/state_service_pb2.py +66 -0
  113. flyte/_protos/workflow/state_service_pb2.pyi +75 -0
  114. flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
  115. flyte/_protos/workflow/task_definition_pb2.py +79 -0
  116. flyte/_protos/workflow/task_definition_pb2.pyi +81 -0
  117. flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
  118. flyte/_protos/workflow/task_service_pb2.py +60 -0
  119. flyte/_protos/workflow/task_service_pb2.pyi +59 -0
  120. flyte/_protos/workflow/task_service_pb2_grpc.py +138 -0
  121. flyte/_resources.py +226 -0
  122. flyte/_retry.py +32 -0
  123. flyte/_reusable_environment.py +25 -0
  124. flyte/_run.py +482 -0
  125. flyte/_secret.py +61 -0
  126. flyte/_task.py +449 -0
  127. flyte/_task_environment.py +183 -0
  128. flyte/_timeout.py +47 -0
  129. flyte/_tools.py +27 -0
  130. flyte/_trace.py +120 -0
  131. flyte/_utils/__init__.py +26 -0
  132. flyte/_utils/asyn.py +119 -0
  133. flyte/_utils/async_cache.py +139 -0
  134. flyte/_utils/coro_management.py +23 -0
  135. flyte/_utils/file_handling.py +72 -0
  136. flyte/_utils/helpers.py +134 -0
  137. flyte/_utils/lazy_module.py +54 -0
  138. flyte/_utils/org_discovery.py +57 -0
  139. flyte/_utils/uv_script_parser.py +49 -0
  140. flyte/_version.py +21 -0
  141. flyte/cli/__init__.py +3 -0
  142. flyte/cli/_abort.py +28 -0
  143. flyte/cli/_common.py +337 -0
  144. flyte/cli/_create.py +145 -0
  145. flyte/cli/_delete.py +23 -0
  146. flyte/cli/_deploy.py +152 -0
  147. flyte/cli/_gen.py +163 -0
  148. flyte/cli/_get.py +310 -0
  149. flyte/cli/_params.py +538 -0
  150. flyte/cli/_run.py +231 -0
  151. flyte/cli/main.py +166 -0
  152. flyte/config/__init__.py +3 -0
  153. flyte/config/_config.py +216 -0
  154. flyte/config/_internal.py +64 -0
  155. flyte/config/_reader.py +207 -0
  156. flyte/connectors/__init__.py +0 -0
  157. flyte/errors.py +172 -0
  158. flyte/extras/__init__.py +5 -0
  159. flyte/extras/_container.py +263 -0
  160. flyte/io/__init__.py +27 -0
  161. flyte/io/_dir.py +448 -0
  162. flyte/io/_file.py +467 -0
  163. flyte/io/_structured_dataset/__init__.py +129 -0
  164. flyte/io/_structured_dataset/basic_dfs.py +219 -0
  165. flyte/io/_structured_dataset/structured_dataset.py +1061 -0
  166. flyte/models.py +391 -0
  167. flyte/remote/__init__.py +26 -0
  168. flyte/remote/_client/__init__.py +0 -0
  169. flyte/remote/_client/_protocols.py +133 -0
  170. flyte/remote/_client/auth/__init__.py +12 -0
  171. flyte/remote/_client/auth/_auth_utils.py +14 -0
  172. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  173. flyte/remote/_client/auth/_authenticators/base.py +397 -0
  174. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  175. flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
  176. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  177. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  178. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  179. flyte/remote/_client/auth/_channel.py +215 -0
  180. flyte/remote/_client/auth/_client_config.py +83 -0
  181. flyte/remote/_client/auth/_default_html.py +32 -0
  182. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  183. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  184. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  185. flyte/remote/_client/auth/_keyring.py +143 -0
  186. flyte/remote/_client/auth/_token_client.py +260 -0
  187. flyte/remote/_client/auth/errors.py +16 -0
  188. flyte/remote/_client/controlplane.py +95 -0
  189. flyte/remote/_console.py +18 -0
  190. flyte/remote/_data.py +159 -0
  191. flyte/remote/_logs.py +176 -0
  192. flyte/remote/_project.py +85 -0
  193. flyte/remote/_run.py +970 -0
  194. flyte/remote/_secret.py +132 -0
  195. flyte/remote/_task.py +391 -0
  196. flyte/report/__init__.py +3 -0
  197. flyte/report/_report.py +178 -0
  198. flyte/report/_template.html +124 -0
  199. flyte/storage/__init__.py +29 -0
  200. flyte/storage/_config.py +233 -0
  201. flyte/storage/_remote_fs.py +34 -0
  202. flyte/storage/_storage.py +271 -0
  203. flyte/storage/_utils.py +5 -0
  204. flyte/syncify/__init__.py +56 -0
  205. flyte/syncify/_api.py +371 -0
  206. flyte/types/__init__.py +36 -0
  207. flyte/types/_interface.py +40 -0
  208. flyte/types/_pickle.py +118 -0
  209. flyte/types/_renderer.py +162 -0
  210. flyte/types/_string_literals.py +120 -0
  211. flyte/types/_type_engine.py +2287 -0
  212. flyte/types/_utils.py +80 -0
  213. flyte-0.2.0a0.dist-info/METADATA +249 -0
  214. flyte-0.2.0a0.dist-info/RECORD +218 -0
  215. {flyte-0.1.0.dist-info → flyte-0.2.0a0.dist-info}/WHEEL +2 -1
  216. flyte-0.2.0a0.dist-info/entry_points.txt +3 -0
  217. flyte-0.2.0a0.dist-info/top_level.txt +1 -0
  218. flyte-0.1.0.dist-info/METADATA +0 -6
  219. flyte-0.1.0.dist-info/RECORD +0 -5
flyte/__init__.py CHANGED
@@ -1,2 +1,78 @@
1
- def hello() -> str:
2
- return "Hello from flyte!"
1
+ """
2
+ Flyte SDK for authoring Compound AI applications, services and workflows.
3
+
4
+ ## Environments
5
+
6
+ TaskEnvironment class to define a new environment for a set of tasks.
7
+
8
+ Example usage:
9
+
10
+ ```python
11
+ env = flyte.TaskEnvironment(name="my_env", image="my_image", resources=flyte.Resources(cpu="1", memory="1Gi"))
12
+
13
+ @env.task
14
+ async def my_task():
15
+ pass
16
+ ```
17
+ """
18
+
19
+ __all__ = [
20
+ "GPU",
21
+ "TPU",
22
+ "Cache",
23
+ "CachePolicy",
24
+ "CacheRequest",
25
+ "Device",
26
+ "Environment",
27
+ "Image",
28
+ "PodTemplate",
29
+ "Resources",
30
+ "RetryStrategy",
31
+ "ReusePolicy",
32
+ "Secret",
33
+ "SecretRequest",
34
+ "TaskEnvironment",
35
+ "Timeout",
36
+ "TimeoutType",
37
+ "__version__",
38
+ "ctx",
39
+ "deploy",
40
+ "group",
41
+ "init",
42
+ "init_from_config",
43
+ "map",
44
+ "run",
45
+ "trace",
46
+ "with_runcontext",
47
+ ]
48
+
49
+ import sys
50
+
51
+ from ._cache import Cache, CachePolicy, CacheRequest
52
+ from ._context import ctx
53
+ from ._deploy import deploy
54
+ from ._environment import Environment
55
+ from ._excepthook import custom_excepthook
56
+ from ._group import group
57
+ from ._image import Image
58
+ from ._initialize import init, init_from_config
59
+ from ._map import map
60
+ from ._pod import PodTemplate
61
+ from ._resources import GPU, TPU, Device, Resources
62
+ from ._retry import RetryStrategy
63
+ from ._reusable_environment import ReusePolicy
64
+ from ._run import run, with_runcontext
65
+ from ._secret import Secret, SecretRequest
66
+ from ._task_environment import TaskEnvironment
67
+ from ._timeout import Timeout, TimeoutType
68
+ from ._trace import trace
69
+ from ._version import __version__
70
+
71
+ sys.excepthook = custom_excepthook
72
+
73
+
74
def version() -> str:
    """
    Report which Flyte SDK release is installed.

    :return: The package's ``__version__`` string.
    """
    sdk_version: str = __version__
    return sdk_version
flyte/_bin/__init__.py ADDED
File without changes
flyte/_bin/runtime.py ADDED
@@ -0,0 +1,152 @@
1
+ """
2
+ Flyte runtime module, this is the entrypoint script for the Flyte runtime.
3
+
4
+ Caution: Startup time for this module is very important, as it is the entrypoint for the Flyte runtime.
5
+ Refrain from importing any modules here. If you need to import any modules, do it inside the main function.
6
+ """
7
+
8
+ import asyncio
9
+ import os
10
+ import sys
11
+ from typing import Any, List
12
+
13
+ import click
14
+
15
# Names of the environment variables the runtime entrypoint reads.
# TODO(pvditt): migrate to the `_U_*` naming scheme, i.e. `_U_ACTION_NAME`, `_U_RUN_NAME`,
# `_U_PROJECT_NAME`, `_U_DOMAIN_NAME` and `_U_ORG_NAME`.

# Identity of the action / run being executed.
ACTION_NAME = "ACTION_NAME"
RUN_NAME = "RUN_NAME"

# Placement of the task.
PROJECT_NAME = "FLYTE_INTERNAL_TASK_PROJECT"
DOMAIN_NAME = "FLYTE_INTERNAL_TASK_DOMAIN"
ORG_NAME = "_U_ORG_NAME"

# Connectivity and storage overrides.
ENDPOINT_OVERRIDE = "_U_EP_OVERRIDE"
RUN_OUTPUT_BASE_DIR = "_U_RUN_BASE"

# TODO: remove once proper auth is implemented.
_UNION_EAGER_API_KEY_ENV_VAR = "_UNION_EAGER_API_KEY"
32
+
33
+
34
@click.command("a0")
@click.option("--inputs", "-i", required=True)
@click.option("--outputs-path", "-o", required=True)
@click.option("--version", "-v", required=True)
@click.option("--run-base-dir", envvar=RUN_OUTPUT_BASE_DIR, required=True)
@click.option("--raw-data-path", "-r", required=False)
@click.option("--checkpoint-path", "-c", required=False)
@click.option("--prev-checkpoint", "-p", required=False)
@click.option("--name", envvar=ACTION_NAME, required=False)
@click.option("--run-name", envvar=RUN_NAME, required=False)
@click.option("--project", envvar=PROJECT_NAME, required=False)
@click.option("--domain", envvar=DOMAIN_NAME, required=False)
@click.option("--org", envvar=ORG_NAME, required=False)
@click.option("--image-cache", required=False)
@click.option("--tgz", required=False)
@click.option("--pkl", required=False)
@click.option("--dest", required=False)
@click.option("--resolver", required=False)
@click.argument(
    "resolver-args",
    type=click.UNPROCESSED,
    nargs=-1,
)
def main(
    run_name: str,
    name: str,
    project: str,
    domain: str,
    org: str,
    image_cache: str,
    version: str,
    inputs: str,
    run_base_dir: str,
    outputs_path: str,
    raw_data_path: str,
    checkpoint_path: str,
    prev_checkpoint: str,
    tgz: str,
    pkl: str,
    dest: str,
    resolver: str,
    resolver_args: List[str],
):
    """
    Load a task via ``resolver``/``resolver_args`` and execute it as action ``name`` of run ``run_name``.

    The task coroutine runs alongside the remote controller's error watcher through ``utils.run_coros``.
    The controller is stopped afterwards, including when either coroutine raises.

    :raises click.UsageError: If org, project, domain, run name or action name is missing or empty.
    """
    sys.path.insert(0, ".")

    # Imports are deferred to here on purpose: startup time of this module matters (see module docstring).
    import flyte
    import flyte._utils as utils
    from flyte._initialize import init
    from flyte._internal.controllers import create_controller
    from flyte._internal.imagebuild.image_builder import ImageCache
    from flyte._internal.runtime.entrypoints import load_and_run_task
    from flyte._logging import logger
    from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath

    logger.info(f"Initializing flyte runtime - version {flyte.__version__}")

    # Values that still look like unrendered templates ("{{...}}") presumably were not substituted
    # by the orchestrator; fall back to the corresponding environment variables.
    if run_name and run_name.startswith("{{"):
        run_name = os.getenv(RUN_NAME, "")
    if name and name.startswith("{{"):
        name = os.getenv(ACTION_NAME, "")

    # Explicit checks rather than `assert` (which `python -O` strips). Validating after the
    # template fallback also rejects a fallback env var that is unset or empty.
    if not org:
        raise click.UsageError("Org is required for now")
    if not project:
        raise click.UsageError("Project is required")
    if not domain:
        raise click.UsageError("Domain is required")
    if not run_name:
        raise click.UsageError("Run name is required")
    if not name:
        raise click.UsageError("Action name is required")

    # Figure out how to connect.
    # NOTE: detecting the api key from the environment is a temporary hack.
    controller_kwargs: dict[str, Any] = {"insecure": False}
    if api_key := os.getenv(_UNION_EAGER_API_KEY_ENV_VAR):
        logger.info("Using api key from environment")
        controller_kwargs["api_key"] = api_key
    else:
        ep = os.environ.get(ENDPOINT_OVERRIDE, "host.docker.internal:8090")
        controller_kwargs["endpoint"] = ep
        # Local / docker endpoints are reached without TLS.
        if "localhost" in ep or "docker" in ep:
            controller_kwargs["insecure"] = True
        # Logged only in this branch, where controller_kwargs never contains the api key.
        logger.debug(f"Using controller endpoint: {ep} with kwargs: {controller_kwargs}")

    bundle = CodeBundle(tgz=tgz, pkl=pkl, destination=dest, computed_version=version)
    # The regular client is initialized here so that reference tasks can work.
    # Reference tasks currently do not work with the remote controller: creating two channels on
    # different threads is not supported by grpcio or the auth system and fails with
    #   File "src/python/grpcio/grpc/_cython/_cygrpc/aio/completion_queue.pyx.pxi", line 147,
    #   in grpc._cython.cygrpc.PollerCompletionQueue._handle_events
    #   BlockingIOError: [Errno 11] Resource temporarily unavailable
    # TODO: use a single channel for both the controller and reference tasks (requires a refactor).
    init()
    # NOTE: init() above is called without arguments; only the controller receives controller_kwargs.
    controller = create_controller(ct="remote", **controller_kwargs)

    ic = ImageCache.from_transport(image_cache) if image_cache else None

    # Coroutine that loads the task and runs it.
    task_coroutine = load_and_run_task(
        resolver=resolver,
        resolver_args=resolver_args,
        action=ActionID(name=name, run_name=run_name, project=project, domain=domain, org=org),
        raw_data_path=RawDataPath(path=raw_data_path),
        checkpoints=Checkpoints(checkpoint_path, prev_checkpoint),
        code_bundle=bundle,
        input_path=inputs,
        output_path=outputs_path,
        run_base_dir=run_base_dir,
        version=version,
        controller=controller,
        image_cache=ic,
    )
    # Coroutine that watches the controller for errors.
    controller_failure = controller.watch_for_errors()

    async def _run_and_stop():
        try:
            await utils.run_coros(controller_failure, task_coroutine)
        finally:
            # Stop the controller even if the task or the error watcher raised.
            await controller.stop()

    asyncio.run(_run_and_stop())
flyte/_build.py ADDED
@@ -0,0 +1,26 @@
1
+ from __future__ import annotations
2
+
3
+ from flyte.syncify import syncify
4
+
5
+ from ._image import Image
6
+
7
+
8
@syncify
async def build(image: Image) -> str:
    """
    Build an image with ``ImageBuildEngine``. The existing async context will be used.

    Example:
    ```
    import asyncio

    import flyte
    from flyte._build import build

    image = flyte.Image("example_image")
    if __name__ == "__main__":
        asyncio.run(build.aio(image))
    ```

    :param image: The image to build.
    :return: The image URI.
    """
    # Imported lazily, presumably to keep module import cheap and avoid import cycles.
    from flyte._internal.imagebuild.image_builder import ImageBuildEngine

    return await ImageBuildEngine.build(image)
@@ -0,0 +1,12 @@
1
+ from .cache import Cache, CacheBehavior, CachePolicy, CacheRequest
2
+ from .defaults import get_default_policies
3
+ from .policy_function_body import FunctionBodyPolicy
4
+
5
+ __all__ = [
6
+ "Cache",
7
+ "CacheBehavior",
8
+ "CachePolicy",
9
+ "CacheRequest",
10
+ "FunctionBodyPolicy",
11
+ "get_default_policies",
12
+ ]
flyte/_cache/cache.py ADDED
@@ -0,0 +1,145 @@
1
+ import hashlib
2
+ from dataclasses import dataclass, field
3
+ from typing import (
4
+ Callable,
5
+ Generic,
6
+ List,
7
+ Optional,
8
+ Protocol,
9
+ Tuple,
10
+ Union,
11
+ runtime_checkable,
12
+ )
13
+
14
+ import rich.repr
15
+ from typing_extensions import Literal, ParamSpec, TypeVar, get_args
16
+
17
+ # if TYPE_CHECKING:
18
+ from flyte._image import Image
19
+ from flyte.models import CodeBundle
20
+
21
+ P = ParamSpec("P")
22
+ FuncOut = TypeVar("FuncOut")
23
+
24
+ CacheBehavior = Literal["auto", "override", "disable", "enabled"]
25
+
26
+
27
@dataclass
class VersionParameters(Generic[P, FuncOut]):
    """
    Parameters used for cache version hash generation.

    :param func: The function to generate a version for. It has no default and must always be passed,
        but it may be ``None``; ``FunctionBodyPolicy``, for example, then contributes an empty string.
    :type func: Optional[Callable[P, FuncOut]]
    :param image: The container image to generate a version for. This can be a string representing the
        image name or an Image object.
    :type image: Optional[Union[str, Image]]
    :param code_bundle: The code bundle the task is packaged with. If it carries a pickled bundle (``pkl``),
        ``Cache.get_version`` returns the bundle's ``computed_version`` instead of consulting cache policies.
    :type code_bundle: Optional[CodeBundle]
    """

    func: Callable[P, FuncOut] | None
    image: Optional[Union[str, Image]] = None
    code_bundle: Optional[CodeBundle] = None
42
+
43
+
44
@runtime_checkable
class CachePolicy(Protocol):
    """
    Structural interface for anything that contributes to a task's cache version.

    Any object with a matching ``get_version`` method satisfies it. Because of ``runtime_checkable``,
    ``isinstance(obj, CachePolicy)`` checks for the presence of that method at runtime.
    """

    def get_version(self, salt: str, params: VersionParameters) -> str:
        """
        Return a version string for ``params``, mixed with ``salt``.
        """
        ...
47
+
48
+
49
@rich.repr.auto
@dataclass
class Cache:
    """
    Cache configuration for a task.

    :param behavior: The behavior of the cache: "auto", "override", "disable" or "enabled".
        "auto", "override" and "enabled" turn caching on (see ``is_enabled``); "disable" turns it off.
    :param version_override: The version of the cache. If not provided, the version will be
        generated based on the cache policies.
    :type version_override: Optional[str]
    :param serialize: Boolean that indicates if identical (i.e. same inputs) instances of this task should be
        executed in serial when caching is enabled. This means that given multiple concurrent executions over
        identical inputs, only a single instance executes and the rest wait to reuse the cached results.
    :type serialize: bool
    :param ignored_inputs: Input names (a tuple, or a single name) to ignore when generating the version hash.
    :type ignored_inputs: Union[Tuple[str, ...], str]
    :param salt: A salt used in the hash generation.
    :type salt: str
    :param policies: A list of cache policies (or a single policy) used to generate the version hash.
        When omitted and caching is not disabled, ``get_default_policies()`` is used.
    :type policies: Optional[Union[List[CachePolicy], CachePolicy]]
    :raises ValueError: If ``behavior`` is not a valid ``CacheBehavior``, or if caching is not disabled and
        neither ``version_override`` nor any policy is available.
    """

    behavior: CacheBehavior
    version_override: Optional[str] = None
    serialize: bool = False
    ignored_inputs: Union[Tuple[str, ...], str] = field(default_factory=tuple)
    salt: str = ""
    policies: Optional[Union[List[CachePolicy], CachePolicy]] = None

    def __post_init__(self):
        valid_behaviors = get_args(CacheBehavior)
        if self.behavior not in valid_behaviors:
            raise ValueError(f"Invalid cache behavior: {self.behavior}. Must be one of {list(valid_behaviors)}")

        # Normalize ignored_inputs BEFORE the "disable" early return so get_ignored_inputs() works for
        # every Cache. Previously a disabled Cache never set _ignored_inputs and raised AttributeError.
        if isinstance(self.ignored_inputs, str):
            self._ignored_inputs = (self.ignored_inputs,)
        else:
            self._ignored_inputs = tuple(self.ignored_inputs)

        if self.behavior == "disable":
            return

        # Normalize policies so that self.policies is always a list.
        if self.policies is None:
            # Local import: defaults.py imports from this module, so a top-level import would be circular.
            from flyte._cache.defaults import get_default_policies

            self.policies = get_default_policies()
        elif isinstance(self.policies, CachePolicy):
            self.policies = [self.policies]

        if self.version_override is None and not self.policies:
            raise ValueError("If version is not defined then at least one cache policy needs to be set")

    def is_enabled(self) -> bool:
        """
        Return True if caching is turned on, i.e. ``behavior`` is "auto", "override" or "enabled".
        """
        # "enabled" is a declared CacheBehavior; treating it as disabled would silently turn caching off.
        return self.behavior in ("auto", "override", "enabled")

    def get_ignored_inputs(self) -> Tuple[str, ...]:
        """
        Return the normalized tuple of input names excluded from version hashing.
        """
        return self._ignored_inputs

    def get_version(self, params: Optional[VersionParameters] = None) -> str:
        """
        Compute the cache version for a task.

        Resolution order:

        - ``""`` if caching is not enabled;
        - ``version_override`` if set;
        - the code bundle's ``computed_version`` if ``params.code_bundle.pkl`` is set;
        - otherwise the SHA-256 hex digest of the concatenated policy versions.

        :param params: Version parameters. Required unless caching is disabled or ``version_override`` is set.
        :return: The version string.
        :raises ValueError: If ``params`` is missing when needed, no policies are set, or a policy fails.
        """
        if not self.is_enabled():
            return ""

        if self.version_override is not None:
            return self.version_override

        if params is None:
            raise ValueError("Version parameters must be provided when version_override is not set.")

        if params.code_bundle is not None:
            if params.code_bundle.pkl is not None:
                return params.code_bundle.computed_version

        task_hash = ""
        if self.policies is None:
            raise ValueError("Cache policies are not set.")
        policies = self.policies if isinstance(self.policies, list) else [self.policies]
        for policy in policies:
            try:
                task_hash += policy.get_version(self.salt, params)
            except Exception as e:
                raise ValueError(f"Failed to generate version for cache policy {policy}.") from e

        hash_obj = hashlib.sha256(task_hash.encode())
        return hash_obj.hexdigest()
134
+
135
+
136
# What callers may pass wherever a cache setting is accepted: a bare behavior string or a full Cache.
CacheRequest = CacheBehavior | Cache


def cache_from_request(cache: CacheRequest) -> Cache:
    """
    Normalize a user-supplied cache request into a :class:`Cache`.

    :param cache: A ``Cache`` (returned as-is) or a behavior string (wrapped as ``Cache(behavior=cache)``).
    :return: The corresponding ``Cache`` object.
    """
    return cache if isinstance(cache, Cache) else Cache(behavior=cache)
@@ -0,0 +1,9 @@
1
+ from .cache import CachePolicy
2
+ from .policy_function_body import FunctionBodyPolicy
3
+
4
+
5
def get_default_policies() -> list[CachePolicy]:
    """
    Return the policies a ``Cache`` falls back to when none are given.

    A new list (with a new ``FunctionBodyPolicy``) is built on every call, so callers may mutate it.
    """
    defaults: list[CachePolicy] = [FunctionBodyPolicy()]
    return defaults
@@ -0,0 +1,42 @@
1
+ import ast
2
+ import hashlib
3
+ import inspect
4
+ import textwrap
5
+
6
+ from .cache import CachePolicy, VersionParameters
7
+
8
+
9
class FunctionBodyPolicy(CachePolicy):
    """
    Cache policy that versions a function by the structure of its source code.

    The source is parsed into an AST and dumped without position attributes. Formatting-only edits
    (whitespace, comments, indentation) therefore do not change the version, while code and docstring
    edits do. The dump plus a salt is hashed with SHA-256.
    """

    def get_version(self, salt: str, params: VersionParameters) -> str:
        """
        Hash the source of ``params.func`` together with ``salt``.

        :param salt: String appended to the AST dump before hashing.
        :param params: Version parameters; only ``params.func`` is used.
        :return: Hex SHA-256 digest, or ``""`` when ``params.func`` is ``None``.
        """
        target = params.func
        if target is None:
            return ""

        # Dedent so methods and nested functions parse as top-level code.
        tree = ast.parse(textwrap.dedent(inspect.getsource(target)))
        structure = ast.dump(tree, include_attributes=False)

        digest = hashlib.sha256()
        digest.update(structure.encode("utf-8"))
        digest.update(salt.encode("utf-8"))
        return digest.hexdigest()
@@ -0,0 +1,8 @@
1
+ from ._ignore import GitIgnore, IgnoreGroup, StandardIgnore
2
+ from ._utils import CopyFiles
3
+ from .bundle import build_code_bundle, build_pkl_bundle, download_bundle
4
+
5
+ __all__ = ["CopyFiles", "build_code_bundle", "build_pkl_bundle", "default_ignores", "download_bundle"]
6
+
7
+
8
+ default_ignores = [GitIgnore, StandardIgnore, IgnoreGroup]
@@ -0,0 +1,113 @@
1
+ import os
2
+ import pathlib
3
+ import subprocess
4
+ import tarfile as _tarfile
5
+ from abc import ABC, abstractmethod
6
+ from fnmatch import fnmatch
7
+ from pathlib import Path
8
+ from shutil import which
9
+ from typing import List, Optional, Type
10
+
11
+ from flyte._logging import logger
12
+
13
+
14
class Ignore(ABC):
    """
    Abstract base for ignore rules rooted at a directory.

    Subclasses implement ``_is_ignored``. This base exposes it through the public ``is_ignored``
    and through ``tar_filter``, a callback with the shape ``tarfile`` filters use.
    """

    def __init__(self, root: Path):
        self.root = root

    def is_ignored(self, path: pathlib.Path) -> bool:
        """Return True if ``path`` should be excluded."""
        return self._is_ignored(path)

    def tar_filter(self, tarinfo: _tarfile.TarInfo) -> Optional[_tarfile.TarInfo]:
        """
        Drop ignored tar members (return ``None``) and pass every other member through unchanged.
        Suitable as a ``filter=`` callback, e.g. for ``TarFile.add``.
        """
        return None if self.is_ignored(pathlib.Path(tarinfo.name)) else tarinfo

    @abstractmethod
    def _is_ignored(self, path: pathlib.Path) -> bool:
        """Decide whether ``path`` is ignored; implemented by subclasses."""
31
+
32
+
33
class GitIgnore(Ignore):
    """
    Ignore rule backed by git: uses ``git ls-files`` (if a git executable is available) to list
    ignored, untracked files and directories under ``root`` and compares paths against them.

    If git is not on PATH, or the git command fails (e.g. ``root`` is not inside a repository),
    nothing is ignored and an info message is logged.
    """

    def __init__(self, root: Path):
        super().__init__(root)
        self.has_git = which("git") is not None
        self.ignored_files = self._list_ignored_files()
        self.ignored_dirs = self._list_ignored_dirs()

    def _git_wrapper(self, extra_args: List[str]) -> set[str]:
        """
        Run ``git ls-files`` for ignored, untracked paths in ``root``.

        :param extra_args: Additional ``git ls-files`` arguments (e.g. ``--directory``).
        :return: POSIX paths relative to ``root`` as reported by git; empty if git is unavailable or fails.
        """
        if not self.has_git:
            logger.info("No git executable found, not applying any filters")
            return set()

        # -z gives NUL-terminated, *unquoted* output. Without it git C-quotes paths that contain
        # non-ASCII or special characters (core.quotePath), and such quoted entries never match.
        out = subprocess.run(
            ["git", "ls-files", "-z", "-io", "--exclude-standard", *extra_args],
            cwd=self.root,
            capture_output=True,
            check=False,
        )
        if out.returncode != 0:
            logger.info(f"Could not determine ignored paths due to:\n{out.stderr!r}\nNot applying any filters")
            return set()
        # os.fsdecode maps raw filename bytes the same way Python does for other filesystem paths.
        return {os.fsdecode(entry) for entry in out.stdout.split(b"\0") if entry}

    def _list_ignored_files(self) -> set[str]:
        """Ignored, untracked files under ``root``."""
        return self._git_wrapper([])

    def _list_ignored_dirs(self) -> set[str]:
        """Ignored directories under ``root``; git reports them with a trailing ``/``."""
        return self._git_wrapper(["--directory"])

    def _is_ignored(self, path: pathlib.Path) -> bool:
        if self.ignored_files:
            # git-ls-files uses POSIX paths
            if Path(path).as_posix() in self.ignored_files:
                return True
        # Ignore empty directories
        if os.path.isdir(os.path.join(self.root, path)) and self.ignored_dirs:
            return Path(path).as_posix() + "/" in self.ignored_dirs
        return False
72
+
73
+
74
# fnmatch-style patterns that StandardIgnore applies when no custom patterns are supplied.
STANDARD_IGNORE_PATTERNS = ["*.pyc", ".cache", ".cache/*", "__pycache__", "**/__pycache__"]


class StandardIgnore(Ignore):
    """
    Pattern-based ignore rule: a path is ignored when ``str(path)`` matches any fnmatch pattern.

    Retains the standard ignore behavior that existed before. With no patterns (``None`` or an empty
    list), ``STANDARD_IGNORE_PATTERNS`` is used. Custom patterns could, in principle, come from the CLI.
    """

    def __init__(self, root: Path, patterns: Optional[List[str]] = None):
        super().__init__(root)
        self.patterns = patterns or STANDARD_IGNORE_PATTERNS

    def _is_ignored(self, path: pathlib.Path) -> bool:
        candidate = str(path)
        return any(fnmatch(candidate, pattern) for pattern in self.patterns)
90
+
91
+
92
class IgnoreGroup(Ignore):
    """
    Combines several ignore rules: a path is ignored if any member rule ignores it.

    :param root: Directory all member rules are rooted at.
    :param ignores: Ignore classes to instantiate with ``root``.
    """

    def __init__(self, root: Path, *ignores: Type[Ignore]):
        super().__init__(root)
        self.ignores = [ignore(root) for ignore in ignores]

    def _is_ignored(self, path: pathlib.Path) -> bool:
        return any(ignore.is_ignored(path) for ignore in self.ignores)

    def list_ignored(self) -> List[str]:
        """
        Walk ``root`` and return every file this group ignores, as paths relative to ``root``.
        """
        ignored: List[str] = []
        # os.walk instead of Path.walk: Path.walk only exists on Python 3.12+.
        for dirpath, _, filenames in os.walk(self.root):
            for filename in filenames:
                rel_path = (Path(dirpath) / filename).relative_to(self.root)
                # Check the root-relative path: GitIgnore compares against git's root-relative output
                # and StandardIgnore's patterns are relative, so a root-prefixed path (e.g. with an
                # absolute root) would never match them.
                if self.is_ignored(rel_path):
                    ignored.append(str(rel_path))
        return ignored
+ return ignored