flyte 0.1.0__py3-none-any.whl → 0.2.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (219) hide show
  1. flyte/__init__.py +78 -2
  2. flyte/_bin/__init__.py +0 -0
  3. flyte/_bin/runtime.py +152 -0
  4. flyte/_build.py +26 -0
  5. flyte/_cache/__init__.py +12 -0
  6. flyte/_cache/cache.py +145 -0
  7. flyte/_cache/defaults.py +9 -0
  8. flyte/_cache/policy_function_body.py +42 -0
  9. flyte/_code_bundle/__init__.py +8 -0
  10. flyte/_code_bundle/_ignore.py +113 -0
  11. flyte/_code_bundle/_packaging.py +187 -0
  12. flyte/_code_bundle/_utils.py +323 -0
  13. flyte/_code_bundle/bundle.py +209 -0
  14. flyte/_context.py +152 -0
  15. flyte/_deploy.py +243 -0
  16. flyte/_doc.py +29 -0
  17. flyte/_docstring.py +32 -0
  18. flyte/_environment.py +84 -0
  19. flyte/_excepthook.py +37 -0
  20. flyte/_group.py +32 -0
  21. flyte/_hash.py +23 -0
  22. flyte/_image.py +762 -0
  23. flyte/_initialize.py +492 -0
  24. flyte/_interface.py +84 -0
  25. flyte/_internal/__init__.py +3 -0
  26. flyte/_internal/controllers/__init__.py +128 -0
  27. flyte/_internal/controllers/_local_controller.py +193 -0
  28. flyte/_internal/controllers/_trace.py +41 -0
  29. flyte/_internal/controllers/remote/__init__.py +60 -0
  30. flyte/_internal/controllers/remote/_action.py +146 -0
  31. flyte/_internal/controllers/remote/_client.py +47 -0
  32. flyte/_internal/controllers/remote/_controller.py +494 -0
  33. flyte/_internal/controllers/remote/_core.py +410 -0
  34. flyte/_internal/controllers/remote/_informer.py +361 -0
  35. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  36. flyte/_internal/imagebuild/__init__.py +11 -0
  37. flyte/_internal/imagebuild/docker_builder.py +427 -0
  38. flyte/_internal/imagebuild/image_builder.py +246 -0
  39. flyte/_internal/imagebuild/remote_builder.py +0 -0
  40. flyte/_internal/resolvers/__init__.py +0 -0
  41. flyte/_internal/resolvers/_task_module.py +54 -0
  42. flyte/_internal/resolvers/common.py +31 -0
  43. flyte/_internal/resolvers/default.py +28 -0
  44. flyte/_internal/runtime/__init__.py +0 -0
  45. flyte/_internal/runtime/convert.py +342 -0
  46. flyte/_internal/runtime/entrypoints.py +135 -0
  47. flyte/_internal/runtime/io.py +136 -0
  48. flyte/_internal/runtime/resources_serde.py +138 -0
  49. flyte/_internal/runtime/task_serde.py +330 -0
  50. flyte/_internal/runtime/taskrunner.py +191 -0
  51. flyte/_internal/runtime/types_serde.py +54 -0
  52. flyte/_logging.py +135 -0
  53. flyte/_map.py +215 -0
  54. flyte/_pod.py +19 -0
  55. flyte/_protos/__init__.py +0 -0
  56. flyte/_protos/common/authorization_pb2.py +66 -0
  57. flyte/_protos/common/authorization_pb2.pyi +108 -0
  58. flyte/_protos/common/authorization_pb2_grpc.py +4 -0
  59. flyte/_protos/common/identifier_pb2.py +71 -0
  60. flyte/_protos/common/identifier_pb2.pyi +82 -0
  61. flyte/_protos/common/identifier_pb2_grpc.py +4 -0
  62. flyte/_protos/common/identity_pb2.py +48 -0
  63. flyte/_protos/common/identity_pb2.pyi +72 -0
  64. flyte/_protos/common/identity_pb2_grpc.py +4 -0
  65. flyte/_protos/common/list_pb2.py +36 -0
  66. flyte/_protos/common/list_pb2.pyi +71 -0
  67. flyte/_protos/common/list_pb2_grpc.py +4 -0
  68. flyte/_protos/common/policy_pb2.py +37 -0
  69. flyte/_protos/common/policy_pb2.pyi +27 -0
  70. flyte/_protos/common/policy_pb2_grpc.py +4 -0
  71. flyte/_protos/common/role_pb2.py +37 -0
  72. flyte/_protos/common/role_pb2.pyi +53 -0
  73. flyte/_protos/common/role_pb2_grpc.py +4 -0
  74. flyte/_protos/common/runtime_version_pb2.py +28 -0
  75. flyte/_protos/common/runtime_version_pb2.pyi +24 -0
  76. flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
  77. flyte/_protos/logs/dataplane/payload_pb2.py +100 -0
  78. flyte/_protos/logs/dataplane/payload_pb2.pyi +177 -0
  79. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
  80. flyte/_protos/secret/definition_pb2.py +49 -0
  81. flyte/_protos/secret/definition_pb2.pyi +93 -0
  82. flyte/_protos/secret/definition_pb2_grpc.py +4 -0
  83. flyte/_protos/secret/payload_pb2.py +62 -0
  84. flyte/_protos/secret/payload_pb2.pyi +94 -0
  85. flyte/_protos/secret/payload_pb2_grpc.py +4 -0
  86. flyte/_protos/secret/secret_pb2.py +38 -0
  87. flyte/_protos/secret/secret_pb2.pyi +6 -0
  88. flyte/_protos/secret/secret_pb2_grpc.py +198 -0
  89. flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
  90. flyte/_protos/validate/validate/validate_pb2.py +76 -0
  91. flyte/_protos/workflow/common_pb2.py +27 -0
  92. flyte/_protos/workflow/common_pb2.pyi +14 -0
  93. flyte/_protos/workflow/common_pb2_grpc.py +4 -0
  94. flyte/_protos/workflow/environment_pb2.py +29 -0
  95. flyte/_protos/workflow/environment_pb2.pyi +12 -0
  96. flyte/_protos/workflow/environment_pb2_grpc.py +4 -0
  97. flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
  98. flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
  99. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
  100. flyte/_protos/workflow/queue_service_pb2.py +105 -0
  101. flyte/_protos/workflow/queue_service_pb2.pyi +146 -0
  102. flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
  103. flyte/_protos/workflow/run_definition_pb2.py +128 -0
  104. flyte/_protos/workflow/run_definition_pb2.pyi +314 -0
  105. flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
  106. flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
  107. flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
  108. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
  109. flyte/_protos/workflow/run_service_pb2.py +129 -0
  110. flyte/_protos/workflow/run_service_pb2.pyi +171 -0
  111. flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
  112. flyte/_protos/workflow/state_service_pb2.py +66 -0
  113. flyte/_protos/workflow/state_service_pb2.pyi +75 -0
  114. flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
  115. flyte/_protos/workflow/task_definition_pb2.py +79 -0
  116. flyte/_protos/workflow/task_definition_pb2.pyi +81 -0
  117. flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
  118. flyte/_protos/workflow/task_service_pb2.py +60 -0
  119. flyte/_protos/workflow/task_service_pb2.pyi +59 -0
  120. flyte/_protos/workflow/task_service_pb2_grpc.py +138 -0
  121. flyte/_resources.py +226 -0
  122. flyte/_retry.py +32 -0
  123. flyte/_reusable_environment.py +25 -0
  124. flyte/_run.py +482 -0
  125. flyte/_secret.py +61 -0
  126. flyte/_task.py +449 -0
  127. flyte/_task_environment.py +183 -0
  128. flyte/_timeout.py +47 -0
  129. flyte/_tools.py +27 -0
  130. flyte/_trace.py +120 -0
  131. flyte/_utils/__init__.py +26 -0
  132. flyte/_utils/asyn.py +119 -0
  133. flyte/_utils/async_cache.py +139 -0
  134. flyte/_utils/coro_management.py +23 -0
  135. flyte/_utils/file_handling.py +72 -0
  136. flyte/_utils/helpers.py +134 -0
  137. flyte/_utils/lazy_module.py +54 -0
  138. flyte/_utils/org_discovery.py +57 -0
  139. flyte/_utils/uv_script_parser.py +49 -0
  140. flyte/_version.py +21 -0
  141. flyte/cli/__init__.py +3 -0
  142. flyte/cli/_abort.py +28 -0
  143. flyte/cli/_common.py +337 -0
  144. flyte/cli/_create.py +145 -0
  145. flyte/cli/_delete.py +23 -0
  146. flyte/cli/_deploy.py +152 -0
  147. flyte/cli/_gen.py +163 -0
  148. flyte/cli/_get.py +310 -0
  149. flyte/cli/_params.py +538 -0
  150. flyte/cli/_run.py +231 -0
  151. flyte/cli/main.py +166 -0
  152. flyte/config/__init__.py +3 -0
  153. flyte/config/_config.py +216 -0
  154. flyte/config/_internal.py +64 -0
  155. flyte/config/_reader.py +207 -0
  156. flyte/connectors/__init__.py +0 -0
  157. flyte/errors.py +172 -0
  158. flyte/extras/__init__.py +5 -0
  159. flyte/extras/_container.py +263 -0
  160. flyte/io/__init__.py +27 -0
  161. flyte/io/_dir.py +448 -0
  162. flyte/io/_file.py +467 -0
  163. flyte/io/_structured_dataset/__init__.py +129 -0
  164. flyte/io/_structured_dataset/basic_dfs.py +219 -0
  165. flyte/io/_structured_dataset/structured_dataset.py +1061 -0
  166. flyte/models.py +391 -0
  167. flyte/remote/__init__.py +26 -0
  168. flyte/remote/_client/__init__.py +0 -0
  169. flyte/remote/_client/_protocols.py +133 -0
  170. flyte/remote/_client/auth/__init__.py +12 -0
  171. flyte/remote/_client/auth/_auth_utils.py +14 -0
  172. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  173. flyte/remote/_client/auth/_authenticators/base.py +397 -0
  174. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  175. flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
  176. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  177. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  178. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  179. flyte/remote/_client/auth/_channel.py +215 -0
  180. flyte/remote/_client/auth/_client_config.py +83 -0
  181. flyte/remote/_client/auth/_default_html.py +32 -0
  182. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  183. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  184. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  185. flyte/remote/_client/auth/_keyring.py +143 -0
  186. flyte/remote/_client/auth/_token_client.py +260 -0
  187. flyte/remote/_client/auth/errors.py +16 -0
  188. flyte/remote/_client/controlplane.py +95 -0
  189. flyte/remote/_console.py +18 -0
  190. flyte/remote/_data.py +159 -0
  191. flyte/remote/_logs.py +176 -0
  192. flyte/remote/_project.py +85 -0
  193. flyte/remote/_run.py +970 -0
  194. flyte/remote/_secret.py +132 -0
  195. flyte/remote/_task.py +391 -0
  196. flyte/report/__init__.py +3 -0
  197. flyte/report/_report.py +178 -0
  198. flyte/report/_template.html +124 -0
  199. flyte/storage/__init__.py +29 -0
  200. flyte/storage/_config.py +233 -0
  201. flyte/storage/_remote_fs.py +34 -0
  202. flyte/storage/_storage.py +271 -0
  203. flyte/storage/_utils.py +5 -0
  204. flyte/syncify/__init__.py +56 -0
  205. flyte/syncify/_api.py +371 -0
  206. flyte/types/__init__.py +36 -0
  207. flyte/types/_interface.py +40 -0
  208. flyte/types/_pickle.py +118 -0
  209. flyte/types/_renderer.py +162 -0
  210. flyte/types/_string_literals.py +120 -0
  211. flyte/types/_type_engine.py +2287 -0
  212. flyte/types/_utils.py +80 -0
  213. flyte-0.2.0a0.dist-info/METADATA +249 -0
  214. flyte-0.2.0a0.dist-info/RECORD +218 -0
  215. {flyte-0.1.0.dist-info → flyte-0.2.0a0.dist-info}/WHEEL +2 -1
  216. flyte-0.2.0a0.dist-info/entry_points.txt +3 -0
  217. flyte-0.2.0a0.dist-info/top_level.txt +1 -0
  218. flyte-0.1.0.dist-info/METADATA +0 -6
  219. flyte-0.1.0.dist-info/RECORD +0 -5
@@ -0,0 +1,191 @@
1
+ """
2
+ This module is responsible for running tasks in the V2 runtime. All methods in this file should be
3
+ invoked within a context tree.
4
+ """
5
+
6
+ import pathlib
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+
9
+ import flyte.report
10
+ from flyte._context import internal_ctx
11
+ from flyte._internal.imagebuild.image_builder import ImageCache
12
+ from flyte._logging import log, logger
13
+ from flyte._task import TaskTemplate
14
+ from flyte.errors import CustomError, RuntimeSystemError, RuntimeUnknownError, RuntimeUserError
15
+ from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath, TaskContext
16
+
17
+ from .. import Controller
18
+ from .convert import (
19
+ Error,
20
+ Inputs,
21
+ Outputs,
22
+ convert_from_native_to_error,
23
+ convert_from_native_to_outputs,
24
+ convert_inputs_to_native,
25
+ )
26
+ from .io import load_inputs, upload_error, upload_outputs
27
+
28
+
29
def replace_task_cli(args: List[str], inputs: Inputs, tmp_path: pathlib.Path, action: ActionID) -> List[str]:
    """
    Rewrite a templated task command line so that it can be executed locally from the cli.

    The incoming command is expected to look like
    ```python
    ['urun', '--inputs', '{{.Inputs}}', '--outputs-path', '{{.Outputs}}', '--version', '',
    '--raw-data-path', '{{.rawOutputDataPrefix}}',
    '--checkpoint-path', '{{.checkpointOutputPrefix}}', '--prev-checkpoint', '{{.prevCheckpointPrefix}}',
    '--run-name', '{{.runName}}', '--name', '{{.actionName}}',
    '--tgz', 'some-path', '--dest', '.',
    '--resolver', 'flyte._internal.resolvers.default.DefaultTaskResolver', '--resolver-args',
    'mod', 'test_round_trip', 'instance', 'task1']
    ```
    The value following each of ``--inputs``, ``--outputs-path``, ``--raw-data-path``, ``--checkpoint-path``,
    ``--prev-checkpoint``, ``--run-name`` and ``--name`` is overwritten, and a ``--run-base-dir <tmp_path>`` pair
    is inserted immediately before ``--raw-data-path``.

    :param args: the templated command; it is copied, never modified in place
    :param inputs: converted inputs to the task; the serialized proto is written to ``tmp_path / "inputs.pb"``
    :param tmp_path: temporary directory that all replaced paths point into
    :param action: action id supplying the run name and action name
    :return: a new list holding the rewritten command
    :raises ValueError: if ``--raw-data-path`` is not present in ``args``
    :raises IndexError: if one of the replaced flags is the last element of ``args``
    """
    # The serialized inputs are handed to the command through a file on disk.
    inputs_file = tmp_path / "inputs.pb"
    with open(inputs_file, "wb") as fh:
        fh.write(inputs.proto_inputs.SerializeToString())

    # Flag -> value that must follow it in the rewritten command.
    substitutions = {
        "--inputs": str(inputs_file),
        "--outputs-path": str(tmp_path),
        "--raw-data-path": str(tmp_path / "raw_data_path"),
        "--checkpoint-path": str(tmp_path / "checkpoint_path"),
        "--prev-checkpoint": str(tmp_path / "prev_checkpoint"),
        "--run-name": action.run_name or "",
        "--name": action.name,
    }

    rewritten = list(args)  # copy first because it's a proto container
    for idx, token in enumerate(rewritten):
        if token in substitutions:
            rewritten[idx + 1] = substitutions[token]

    # Placed ahead of --raw-data-path so the pair never lands after the trailing resolver args.
    anchor = rewritten.index("--raw-data-path")
    rewritten[anchor:anchor] = ["--run-base-dir", str(tmp_path)]
    return rewritten
81
+
82
+
83
@log
async def run_task(
    tctx: TaskContext, controller: Controller, task: TaskTemplate, inputs: Dict[str, Any]
) -> Tuple[Any, Optional[Exception]]:
    """
    Execute the task with the given native inputs and report the outcome as a ``(outputs, error)`` pair.

    :param tctx: task context of the running action; only ``tctx.action`` is read here
    :param controller: controller whose ``finalize_parent_action`` is awaited once execution is over
    :param task: the task to execute
    :param inputs: native keyword arguments passed to ``task.execute``
    :return: ``(outputs, None)`` on success, ``({}, error)`` on failure. Runtime system/unknown/user errors are
        returned unchanged, any other exception is wrapped with ``CustomError.from_exception``.
    """
    try:
        logger.info(f"Parent task executing {tctx.action}")
        result = await task.execute(**inputs)
        logger.info(f"Parent task completed successfully, {tctx.action}")
    except (RuntimeSystemError, RuntimeUnknownError, RuntimeUserError) as runtime_err:
        # Already one of the runtime error types: hand it back untouched.
        logger.exception(f"Task failed with error: {runtime_err}")
        return {}, runtime_err
    except Exception as unexpected:
        logger.exception(f"Task failed with error: {unexpected}")
        return {}, CustomError.from_exception(unexpected)
    else:
        return result, None
    finally:
        # Runs on every path, success or failure.
        logger.info(f"Parent task finalized {tctx.action}")
        # reconstruct run id here
        await controller.finalize_parent_action(tctx.action)
108
+
109
+
110
async def convert_and_run(
    *,
    task: TaskTemplate,
    inputs: Inputs,
    action: ActionID,
    controller: Controller,
    raw_data_path: RawDataPath,
    version: str,
    output_path: str,
    run_base_dir: str,
    checkpoints: Checkpoints | None = None,
    code_bundle: CodeBundle | None = None,
    image_cache: ImageCache | None = None,
) -> Tuple[Optional[Outputs], Optional[Error]]:
    """
    Convert the inputs to native values, run the task inside a fresh task context and convert the result back.
    It assumes you are running in a context tree.

    :return: ``(None, error)`` when the task failed, otherwise ``(outputs, None)``.
    """
    ctx = internal_ctx()

    # Inherit the mode of an enclosing task context when there is one, otherwise this is a remote run.
    enclosing_tctx = ctx.data.task_context
    mode = enclosing_tctx.mode if enclosing_tctx else "remote"

    tctx = TaskContext(
        action=action,
        checkpoints=checkpoints,
        code_bundle=code_bundle,
        output_path=output_path,
        run_base_dir=run_base_dir,
        version=version,
        raw_data_path=raw_data_path,
        compiled_image_cache=image_cache,
        report=flyte.report.Report(name=action.name),
        mode=mode,
    )
    with ctx.replace_task_context(tctx):
        native_inputs = await convert_inputs_to_native(inputs, task.native_interface)
        result, failure = await run_task(tctx=tctx, controller=controller, task=task, inputs=native_inputs)
        if failure is not None:
            return None, convert_from_native_to_error(failure)
        # The report is only flushed on the success path.
        if task.report:
            await flyte.report.flush.aio()
        converted = await convert_from_native_to_outputs(result, task.native_interface, task.name)
        return converted, None
149
+
150
+
151
async def extract_download_run_upload(
    task: TaskTemplate,
    *,
    action: ActionID,
    controller: Controller,
    raw_data_path: RawDataPath,
    output_path: str,
    run_base_dir: str,
    version: str,
    checkpoints: Checkpoints | None = None,
    code_bundle: CodeBundle | None = None,
    input_path: str | None = None,
    image_cache: ImageCache | None = None,
):
    """
    Entry point used by the CLI (urun): load the inputs, run the already loaded task and upload the result.
    This assumes that the context tree has already been created.

    On failure the error is uploaded to ``output_path``; on success the outputs (if any) are uploaded there.
    Nothing is returned in either case.
    """
    # Without an input path the task runs with empty inputs.
    if input_path:
        loaded_inputs = await load_inputs(input_path)
    else:
        loaded_inputs = None

    outputs, err = await convert_and_run(
        task=task,
        inputs=loaded_inputs or Inputs.empty(),
        action=action,
        controller=controller,
        raw_data_path=raw_data_path,
        output_path=output_path,
        run_base_dir=run_base_dir,
        version=version,
        checkpoints=checkpoints,
        code_bundle=code_bundle,
        image_cache=image_cache,
    )

    if err is not None:
        error_location = await upload_error(err.err, output_path)
        logger.error(f"Task {task.name} failed with error: {err}. Uploaded error to {error_location}")
        return

    if outputs is None:
        logger.info(f"Task {task.name} completed successfully, no outputs")
        return

    # NOTE(review): the message below is logged even when output_path is empty and nothing was uploaded.
    if output_path:
        await upload_outputs(outputs, output_path)
    logger.info(f"Task {task.name} completed successfully, uploaded outputs to {output_path}")
@@ -0,0 +1,54 @@
1
+ from typing import Dict, Optional, TypeVar
2
+
3
+ from flyteidl.core import interface_pb2
4
+
5
+ from flyte.models import NativeInterface
6
+ from flyte.types._type_engine import TypeEngine
7
+
8
+ T = TypeVar("T")
9
+
10
+
11
def transform_variable_map(
    variable_map: Dict[str, type],
    descriptions: Optional[Dict[str, str]] = None,
) -> Dict[str, interface_pb2.Variable]:
    """
    Build a map from variable name to Flyte Variable for a map of names (inputs, for instance) to their
    Python native types.

    :param variable_map: variable name -> python type; an empty or missing map yields an empty result
    :param descriptions: optional variable name -> description; a variable without an entry is described
        by its own name
    :return: variable name -> Variable carrying the converted type and the description
    """
    if not variable_map:
        return {}

    known_descriptions = descriptions or {}
    return {
        name: transform_type(py_type, known_descriptions.get(name, name)) for name, py_type in variable_map.items()
    }
25
+
26
+
27
def transform_native_to_typed_interface(
    interface: Optional[NativeInterface],
) -> Optional[interface_pb2.TypedInterface]:
    """
    Convert a simple python native interface into FlyteIDL's typed interface.

    :param interface: the native interface, or None
    :return: None when no interface is given, otherwise a TypedInterface whose inputs come from
        ``interface.get_input_types()`` and whose outputs come from ``interface.outputs``
    """
    if interface is None:
        return None

    # Descriptions are meant to be filled in from the docstring in the future; until then both maps
    # stay empty, so every variable ends up described by its own name.
    input_descriptions: Dict[str, str] = {}
    output_descriptions: Dict[str, str] = {}

    typed_inputs = transform_variable_map(interface.get_input_types(), input_descriptions)
    typed_outputs = transform_variable_map(interface.outputs, output_descriptions)
    return interface_pb2.TypedInterface(
        inputs=interface_pb2.VariableMap(variables=typed_inputs),
        outputs=interface_pb2.VariableMap(variables=typed_outputs),
    )
47
+
48
+
49
def transform_type(x: type, description: Optional[str] = None) -> interface_pb2.Variable:
    """
    Wrap a python type into a Flyte Variable.

    :param x: python type, converted with ``TypeEngine.to_literal_type``
    :param description: description stored on the variable
    :return: the resulting Variable
    """
    # add artifact handling eventually
    literal_type = TypeEngine.to_literal_type(x)
    return interface_pb2.Variable(type=literal_type, description=description)
flyte/_logging.py ADDED
@@ -0,0 +1,135 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import os
5
+ from typing import Optional
6
+
7
+ from ._tools import ipython_check, is_in_cluster
8
+
9
+ DEFAULT_LOG_LEVEL = logging.WARNING
10
+
11
+
12
def make_hyperlink(label: str, url: str):
    """
    Render ``label`` as a blue, clickable terminal hyperlink pointing at ``url``.

    :param label: visible text of the link
    :param url: link target, embedded in an OSC 8 escape sequence
    :return: the label wrapped in the colour and hyperlink escape sequences
    """
    # Colour on, then the OSC 8 opener carrying the target.
    opening = f"\033[94m\033]8;;{url}\033\\"
    # Colour reset, then the OSC 8 terminator (an opener with an empty target).
    closing = "\033[0m\033]8;;\033\\"
    return f"{opening}{label}{closing}"
21
+
22
+
23
def is_rich_logging_disabled() -> bool:
    """
    Check if rich logging has been disabled.

    :return: True when the ``DISABLE_RICH_LOGGING`` environment variable is set, whatever its value
        (even an empty string), False otherwise
    """
    return "DISABLE_RICH_LOGGING" in os.environ
28
+
29
+
30
def get_env_log_level() -> int:
    """
    Resolve the log level from the ``LOG_LEVEL`` environment variable.

    The value may be a number (``"10"``) or a standard logging level name (``"DEBUG"``, ``"info"``, ...),
    matched case-insensitively. When the variable is unset or blank, ``DEFAULT_LOG_LEVEL`` is used.

    :return: the numeric log level
    :raises ValueError: if the value is neither an integer nor a known level name
    """
    raw = os.environ.get("LOG_LEVEL")
    if raw is None or not raw.strip():
        return DEFAULT_LOG_LEVEL

    candidate = raw.strip()
    try:
        return int(candidate)
    except ValueError:
        pass

    # getLevelName maps a registered level name to its number; unknown names come back as a string.
    named_level = logging.getLevelName(candidate.upper())
    if isinstance(named_level, int):
        return named_level
    raise ValueError(f"Invalid LOG_LEVEL {raw!r}: expected an integer or a logging level name such as 'DEBUG'")
32
+
33
+
34
def log_format_from_env() -> str:
    """
    Read the log format from the ``LOG_FORMAT`` environment variable.

    :return: the value of ``LOG_FORMAT``, or ``"json"`` when the variable is not set
    """
    try:
        return os.environ["LOG_FORMAT"]
    except KeyError:
        return "json"
39
+
40
+
41
def get_rich_handler(log_level: int) -> Optional[logging.Handler]:
    """
    Build a Rich based logging handler, when rich output is appropriate.

    :param log_level: level set on the returned handler
    :return: None when running in a cluster, or when rich logging is disabled outside of IPython;
        otherwise a configured ``RichHandler``
    """
    if is_in_cluster() or (not ipython_check() and is_rich_logging_disabled()):
        return None

    # Imported here so that click and rich are only loaded when a rich handler is actually built.
    import click
    from rich.console import Console
    from rich.logging import RichHandler

    try:
        columns = os.get_terminal_size().columns
    except Exception as e:
        # Fall back to a fixed width when the terminal size cannot be determined.
        logger.debug(f"Failed to get terminal size: {e}")
        columns = 160

    rich_handler = RichHandler(
        tracebacks_suppress=[click],
        rich_tracebacks=True,
        omit_repeated_times=False,
        show_path=False,
        log_time_format="%H:%M:%S.%f",
        console=Console(width=columns),
        level=log_level,
    )
    rich_handler.setFormatter(logging.Formatter(fmt="%(filename)s:%(lineno)d - %(message)s"))
    return rich_handler
73
+
74
+
75
def get_default_handler(log_level: int) -> logging.Handler:
    """
    Build the plain stream handler used when no rich handler is available.

    :param log_level: level set on the returned handler
    :return: a ``StreamHandler`` formatting records as ``[<logger name>] <message>``
    """
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(log_level)
    if log_format_from_env() == "json":
        # Placeholder: a JSON formatter is not wired in yet, so the plain formatter below is used
        # for every log format.
        pass
    stream_handler.setFormatter(logging.Formatter(fmt="[%(name)s] %(message)s"))
    return stream_handler
84
+
85
+
86
def initialize_logger(log_level: int = DEFAULT_LOG_LEVEL, enable_rich: bool = False):
    """
    Rebuild the module level ``logger`` with the given configuration.

    :param log_level: level applied to the logger and its handler
    :param enable_rich: when True, a rich handler is attempted before the plain one
    """
    global logger  # noqa: PLW0603
    logger = _create_logger(name="flyte", log_level=log_level, enable_rich=enable_rich)
92
+
93
+
94
def _create_logger(name: str, log_level: int = DEFAULT_LOG_LEVEL, enable_rich: bool = False) -> logging.Logger:
    """
    Configure the named logger with the given level and exactly one handler.

    :param name: name of the logger to configure
    :param log_level: level set on both the logger and its handler
    :param enable_rich: when True, a rich handler is tried first; the plain stream handler is the fallback
    :return: the configured logger
    """
    target = logging.getLogger(name)
    target.setLevel(log_level)
    # Drop whatever handlers were attached before so repeated calls do not stack handlers.
    target.handlers = []

    chosen_handler = get_rich_handler(log_level) if enable_rich else None
    if chosen_handler is None:
        chosen_handler = get_default_handler(log_level)
    target.addHandler(chosen_handler)
    return target
108
+
109
+
110
def log(fn=None, *, level=logging.DEBUG, entry=True, exit=True):
    """
    Decorator to log function calls.

    Usable bare (``@log``) or with options (``@log(level=logging.INFO)``). Both plain functions and
    ``async def`` functions are supported: for coroutine functions the wrapper is itself a coroutine
    function, so the exit message is logged once the awaited call has finished rather than when the
    coroutine object is created.

    The check whether ``level`` is enabled on the module logger happens once, when the decorator is
    applied; when it is not enabled the function is returned undecorated.

    :param fn: the function to decorate when used as a bare decorator
    :param level: log level used for the entry and exit messages
    :param entry: log a message with the call arguments before the call
    :param exit: log a message after the call, also when it raised
    :return: the decorated function, or a decorator when ``fn`` is None
    """
    # Local imports keep this block self-contained; both modules are part of the standard library.
    import functools
    import inspect

    def decorator(func):
        if not logger.isEnabledFor(level):
            return func

        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs):
                if entry:
                    logger.log(level, f"[{func.__name__}] with args: {args} and kwargs: {kwargs}")
                try:
                    return await func(*args, **kwargs)
                finally:
                    if exit:
                        logger.log(level, f"[{func.__name__}] completed")

            return async_wrapper

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if entry:
                logger.log(level, f"[{func.__name__}] with args: {args} and kwargs: {kwargs}")
            try:
                return func(*args, **kwargs)
            finally:
                if exit:
                    logger.log(level, f"[{func.__name__}] completed")

        return wrapper

    if fn is None:
        return decorator
    return decorator(fn)
133
+
134
+
135
+ logger = _create_logger("flyte", get_env_log_level())
flyte/_map.py ADDED
@@ -0,0 +1,215 @@
1
+ import asyncio
2
+ from typing import Any, AsyncGenerator, AsyncIterator, Generic, Iterable, Iterator, List, Union, cast
3
+
4
+ from flyte.syncify import syncify
5
+
6
+ from ._group import group
7
+ from ._logging import logger
8
+ from ._task import P, R, TaskTemplate
9
+
10
+
11
class MapAsyncIterator(Generic[P, R]):
    """
    AsyncIterator implementation for the map function results.

    One asyncio task is created per tuple obtained by zipping ``args``. The tasks are created lazily on
    the first ``__anext__`` call and their results are returned in submission order.
    """

    def __init__(self, func: TaskTemplate[P, R], args: tuple, name: str, concurrency: int, return_exceptions: bool):
        """
        :param func: task whose ``aio`` method is invoked once per zipped argument tuple
        :param args: iterables that are zipped to produce the per-call positional arguments
        :param name: group name, used in log messages
        :param concurrency: maximum number of calls running at the same time; 0 (or less) means unlimited
        :param return_exceptions: when True a failed call yields its exception, otherwise it is raised
        """
        self.func = func
        self.args = args
        self.name = name
        self.concurrency = concurrency
        self.return_exceptions = return_exceptions
        self._tasks: List[asyncio.Task] = []
        self._current_index = 0
        self._completed_count = 0
        self._exception_count = 0
        self._task_count = 0
        self._initialized = False

    def __aiter__(self) -> AsyncIterator[Union[R, Exception]]:
        """Return self as the async iterator"""
        return self

    async def __anext__(self) -> Union[R, Exception]:
        """
        Get the next result, in submission order.

        :raises StopAsyncIteration: when every task has been consumed
        :raises Exception: the failure of the current task when ``return_exceptions`` is False; all tasks
            that have not been consumed yet are cancelled first
        """
        # Initialize on first call
        if not self._initialized:
            await self._initialize()

        # Check if we've exhausted all tasks
        if self._current_index >= self._task_count:
            raise StopAsyncIteration

        # Get the next task result
        index = self._current_index
        task = self._tasks[index]
        self._current_index += 1

        try:
            result = await task
        except Exception as e:
            self._exception_count += 1
            logger.debug(f"Task {index} failed with exception: {e}")
            if self.return_exceptions:
                return e
            # _current_index already points at the first unconsumed task, so the slice must start
            # there; starting one later would leave that task running.
            for remaining_task in self._tasks[self._current_index :]:
                remaining_task.cancel()
            # Mark the iterator as exhausted so a later __anext__ stops instead of awaiting a
            # cancelled task.
            self._current_index = self._task_count
            raise

        self._completed_count += 1
        logger.debug(f"Task {index} completed successfully")
        return result

    async def _initialize(self):
        """Initialize the tasks - called lazily on first iteration"""
        if self.concurrency > 0:
            # All tasks are still created up front, but at most `concurrency` of them are inside the
            # call to the mapped function at any time.
            limiter = asyncio.Semaphore(self.concurrency)

            async def _bounded(arg_tuple):
                async with limiter:
                    return await self.func.aio(*arg_tuple)

            tasks = [asyncio.create_task(_bounded(arg_tuple)) for arg_tuple in zip(*self.args)]
            limit_description = f"concurrency limited to {self.concurrency}"
        else:
            tasks = [asyncio.create_task(self.func.aio(*arg_tuple)) for arg_tuple in zip(*self.args)]
            limit_description = "unlimited concurrency"

        if tasks:
            logger.info(f"Starting {len(tasks)} tasks in group '{self.name}' with {limit_description}")
        else:
            logger.info(f"Group '{self.name}' has no tasks to process")

        self._tasks = tasks
        self._task_count = len(tasks)
        self._initialized = True

    async def collect(self) -> List[Union[R, Exception]]:
        """Convenience method to collect all results into a list"""
        return [result async for result in self]

    def __repr__(self):
        return f"MapAsyncIterator(group_name='{self.name}', concurrency={self.concurrency})"
92
+
93
+
94
class _Mapper(Generic[P, R]):
    """
    Internal mapper class to handle the mapping logic

    NOTE: The reason why we do not use the `@syncify` decorator here is that, in `syncify`, we cannot use
    context managers like `group` directly in the function body. This is because the `__exit__` method of the
    context manager is called after the function returns. And for `_context` the `__exit__` method releases the
    token (for contextvar), which was created in a separate thread; releasing it there leads to an exception.
    """

    @classmethod
    def _get_name(cls, task_name: str, group_name: str | None) -> str:
        """Get the name of the group: the task name suffixed with the group name, or 'map' if not provided."""
        return f"{task_name}_{group_name or 'map'}"

    def __call__(
        self,
        func: TaskTemplate[P, R],
        *args: Iterable[Any],
        group_name: str | None = None,
        concurrency: int = 0,
        return_exceptions: bool = True,
    ) -> Iterator[Union[R, Exception]]:
        """
        Map a function over the provided arguments with concurrent execution.

        :param func: The async function to map.
        :param args: Positional arguments to pass to the function (iterables that will be zipped).
        :param group_name: The name of the group for the mapped tasks.
        :param concurrency: The maximum number of concurrent tasks to run. If 0, run all tasks concurrently.
        :param return_exceptions: If True, yield exceptions instead of raising them.
        :return: Iterator yielding results in order. Nothing is yielded when no iterables are given.
        """
        if not args:
            return

        name = self._get_name(func.name, group_name)
        logger.debug(f"Blocking Map for {name}")
        with group(name):
            import flyte

            tctx = flyte.ctx()
            if tctx is None or tctx.mode == "local":
                logger.warning("Running map in local mode, which will run every task sequentially.")
                for v in zip(*args):
                    try:
                        yield func(*v)  # type: ignore
                    except Exception as e:
                        if return_exceptions:
                            yield e
                        else:
                            raise e
                return

            mapped = cast(
                Iterator[R],
                _map(
                    func,
                    *args,
                    name=name,
                    concurrency=concurrency,
                    # Forward the caller's choice, as `aio` and the local path above do, instead of
                    # always collecting exceptions.
                    return_exceptions=return_exceptions,
                ),
            )
            for i, x in enumerate(mapped):
                logger.debug(f"Mapped {x}, task {i}")
                yield x

    async def aio(
        self,
        func: TaskTemplate[P, R],
        *args: Iterable[Any],
        group_name: str | None = None,
        concurrency: int = 0,
        return_exceptions: bool = True,
    ) -> AsyncGenerator[Union[R, Exception], None]:
        """
        Async variant of the map call; takes the same parameters as ``__call__``.

        :return: AsyncGenerator yielding results in order. Nothing is yielded when no iterables are given.
        """
        if not args:
            return
        name = self._get_name(func.name, group_name)
        with group(name):
            import flyte

            tctx = flyte.ctx()
            if tctx is None or tctx.mode == "local":
                logger.warning("Running map in local mode, which will run every task sequentially.")
                for v in zip(*args):
                    try:
                        yield func(*v)  # type: ignore
                    except Exception as e:
                        if return_exceptions:
                            yield e
                        else:
                            raise e
                return
            async for x in _map.aio(
                func,
                *args,
                name=name,
                concurrency=concurrency,
                return_exceptions=return_exceptions,
            ):
                yield cast(Union[R, Exception], x)
198
+
199
+
200
@syncify
async def _map(
    func: TaskTemplate[P, R],
    *args: Iterable[Any],
    name: str = "map",
    concurrency: int = 0,
    return_exceptions: bool = True,
) -> AsyncIterator[Union[R, Exception]]:
    """
    Yield the results of mapping ``func`` over the zipped ``args``, in order, by draining a
    ``MapAsyncIterator`` built from the given arguments.
    """
    result_iterator = MapAsyncIterator(
        func=func,
        args=args,
        name=name,
        concurrency=concurrency,
        return_exceptions=return_exceptions,
    )
    async for item in result_iterator:
        yield item
213
+
214
+
215
+ map: _Mapper = _Mapper()
flyte/_pod.py ADDED
@@ -0,0 +1,19 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import TYPE_CHECKING, Dict, Optional
3
+
4
+ if TYPE_CHECKING:
5
+ from kubernetes.client import V1PodSpec
6
+
7
+
8
+ _PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"
9
+ _PRIMARY_CONTAINER_DEFAULT_NAME = "primary"
10
+
11
+
12
def _default_pod_spec() -> "V1PodSpec":
    """Create an empty pod spec, importing kubernetes only when a default is actually needed."""
    # V1PodSpec is imported at module level under TYPE_CHECKING only, so the name does not exist at
    # runtime and has to be imported here.
    from kubernetes.client import V1PodSpec

    return V1PodSpec()


@dataclass(init=True, repr=True, eq=True, frozen=False)
class PodTemplate(object):
    """Custom PodTemplate specification for a Task."""

    # Pod spec for the task; defaults to an empty V1PodSpec built on instantiation.
    pod_spec: Optional["V1PodSpec"] = field(default_factory=_default_pod_spec)
    # Name of the primary container.
    primary_container_name: str = _PRIMARY_CONTAINER_DEFAULT_NAME
    # Optional labels and annotations.
    labels: Optional[Dict[str, str]] = None
    annotations: Optional[Dict[str, str]] = None
File without changes
@@ -0,0 +1,66 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # source: common/authorization.proto
4
+ """Generated protocol buffer code."""
5
+ from google.protobuf import descriptor as _descriptor
6
+ from google.protobuf import descriptor_pool as _descriptor_pool
7
+ from google.protobuf import symbol_database as _symbol_database
8
+ from google.protobuf.internal import builder as _builder
9
+ # @@protoc_insertion_point(imports)
10
+
11
+ _sym_db = _symbol_database.Default()
12
+
13
+
14
+ from flyte._protos.common import identifier_pb2 as common_dot_identifier__pb2
15
+ from flyte._protos.validate.validate import validate_pb2 as validate_dot_validate__pb2
16
+
17
+
18
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1a\x63ommon/authorization.proto\x12\x0f\x63loudidl.common\x1a\x17\x63ommon/identifier.proto\x1a\x17validate/validate.proto\"+\n\x0cOrganization\x12\x1b\n\x04name\x18\x01 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x04name\"i\n\x06\x44omain\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12K\n\x0corganization\x18\x02 \x01(\x0b\x32\x1d.cloudidl.common.OrganizationB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x0corganization\"a\n\x07Project\x12\x1b\n\x04name\x18\x01 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x04name\x12\x39\n\x06\x64omain\x18\x02 \x01(\x0b\x32\x17.cloudidl.common.DomainB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x06\x64omain\"e\n\x08Workflow\x12\x1b\n\x04name\x18\x01 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x04name\x12<\n\x07project\x18\x02 \x01(\x0b\x32\x18.cloudidl.common.ProjectB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x07project\"g\n\nLaunchPlan\x12\x1b\n\x04name\x18\x01 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x04name\x12<\n\x07project\x18\x02 \x01(\x0b\x32\x18.cloudidl.common.ProjectB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x07project\"\xfd\x02\n\x08Resource\x12\x43\n\x0corganization\x18\x01 \x01(\x0b\x32\x1d.cloudidl.common.OrganizationH\x00R\x0corganization\x12\x31\n\x06\x64omain\x18\x02 \x01(\x0b\x32\x17.cloudidl.common.DomainH\x00R\x06\x64omain\x12\x34\n\x07project\x18\x03 \x01(\x0b\x32\x18.cloudidl.common.ProjectH\x00R\x07project\x12\x37\n\x08workflow\x18\x04 \x01(\x0b\x32\x19.cloudidl.common.WorkflowH\x00R\x08workflow\x12>\n\x0blaunch_plan\x18\x05 \x01(\x0b\x32\x1b.cloudidl.common.LaunchPlanH\x00R\nlaunchPlan\x12>\n\x07\x63luster\x18\x06 \x01(\x0b\x32\".cloudidl.common.ClusterIdentifierH\x00R\x07\x63lusterB\n\n\x08resource\"v\n\nPermission\x12\x35\n\x08resource\x18\x01 \x01(\x0b\x32\x19.cloudidl.common.ResourceR\x08resource\x12\x31\n\x07\x61\x63tions\x18\x02 
\x03(\x0e\x32\x17.cloudidl.common.ActionR\x07\x61\x63tions*\x94\x04\n\x06\x41\x63tion\x12\x0f\n\x0b\x41\x43TION_NONE\x10\x00\x12\x15\n\rACTION_CREATE\x10\x01\x1a\x02\x08\x01\x12\x13\n\x0b\x41\x43TION_READ\x10\x02\x1a\x02\x08\x01\x12\x15\n\rACTION_UPDATE\x10\x03\x1a\x02\x08\x01\x12\x15\n\rACTION_DELETE\x10\x04\x1a\x02\x08\x01\x12\x1f\n\x1b\x41\x43TION_VIEW_FLYTE_INVENTORY\x10\x05\x12 \n\x1c\x41\x43TION_VIEW_FLYTE_EXECUTIONS\x10\x06\x12#\n\x1f\x41\x43TION_REGISTER_FLYTE_INVENTORY\x10\x07\x12\"\n\x1e\x41\x43TION_CREATE_FLYTE_EXECUTIONS\x10\x08\x12\x1d\n\x19\x41\x43TION_ADMINISTER_PROJECT\x10\t\x12\x1d\n\x19\x41\x43TION_MANAGE_PERMISSIONS\x10\n\x12\x1d\n\x19\x41\x43TION_ADMINISTER_ACCOUNT\x10\x0b\x12\x19\n\x15\x41\x43TION_MANAGE_CLUSTER\x10\x0c\x12,\n(ACTION_EDIT_EXECUTION_RELATED_ATTRIBUTES\x10\r\x12*\n&ACTION_EDIT_CLUSTER_RELATED_ATTRIBUTES\x10\x0e\x12!\n\x1d\x41\x43TION_EDIT_UNUSED_ATTRIBUTES\x10\x0f\x12\x1e\n\x1a\x41\x43TION_SUPPORT_SYSTEM_LOGS\x10\x10\x42\xb3\x01\n\x13\x63om.cloudidl.commonB\x12\x41uthorizationProtoH\x02P\x01Z)github.com/unionai/cloud/gen/pb-go/common\xa2\x02\x03\x43\x43X\xaa\x02\x0f\x43loudidl.Common\xca\x02\x0f\x43loudidl\\Common\xe2\x02\x1b\x43loudidl\\Common\\GPBMetadata\xea\x02\x10\x43loudidl::Commonb\x06proto3')
19
+
20
+ _globals = globals()
21
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
22
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'common.authorization_pb2', _globals)
23
+ if _descriptor._USE_C_DESCRIPTORS == False:
24
+ DESCRIPTOR._options = None
25
+ DESCRIPTOR._serialized_options = b'\n\023com.cloudidl.commonB\022AuthorizationProtoH\002P\001Z)github.com/unionai/cloud/gen/pb-go/common\242\002\003CCX\252\002\017Cloudidl.Common\312\002\017Cloudidl\\Common\342\002\033Cloudidl\\Common\\GPBMetadata\352\002\020Cloudidl::Common'
26
+ _ACTION.values_by_name["ACTION_CREATE"]._options = None
27
+ _ACTION.values_by_name["ACTION_CREATE"]._serialized_options = b'\010\001'
28
+ _ACTION.values_by_name["ACTION_READ"]._options = None
29
+ _ACTION.values_by_name["ACTION_READ"]._serialized_options = b'\010\001'
30
+ _ACTION.values_by_name["ACTION_UPDATE"]._options = None
31
+ _ACTION.values_by_name["ACTION_UPDATE"]._serialized_options = b'\010\001'
32
+ _ACTION.values_by_name["ACTION_DELETE"]._options = None
33
+ _ACTION.values_by_name["ACTION_DELETE"]._serialized_options = b'\010\001'
34
+ _ORGANIZATION.fields_by_name['name']._options = None
35
+ _ORGANIZATION.fields_by_name['name']._serialized_options = b'\372B\004r\002\020\001'
36
+ _DOMAIN.fields_by_name['organization']._options = None
37
+ _DOMAIN.fields_by_name['organization']._serialized_options = b'\372B\005\212\001\002\020\001'
38
+ _PROJECT.fields_by_name['name']._options = None
39
+ _PROJECT.fields_by_name['name']._serialized_options = b'\372B\004r\002\020\001'
40
+ _PROJECT.fields_by_name['domain']._options = None
41
+ _PROJECT.fields_by_name['domain']._serialized_options = b'\372B\005\212\001\002\020\001'
42
+ _WORKFLOW.fields_by_name['name']._options = None
43
+ _WORKFLOW.fields_by_name['name']._serialized_options = b'\372B\004r\002\020\001'
44
+ _WORKFLOW.fields_by_name['project']._options = None
45
+ _WORKFLOW.fields_by_name['project']._serialized_options = b'\372B\005\212\001\002\020\001'
46
+ _LAUNCHPLAN.fields_by_name['name']._options = None
47
+ _LAUNCHPLAN.fields_by_name['name']._serialized_options = b'\372B\004r\002\020\001'
48
+ _LAUNCHPLAN.fields_by_name['project']._options = None
49
+ _LAUNCHPLAN.fields_by_name['project']._serialized_options = b'\372B\005\212\001\002\020\001'
50
+ _globals['_ACTION']._serialized_start=1061
51
+ _globals['_ACTION']._serialized_end=1593
52
+ _globals['_ORGANIZATION']._serialized_start=97
53
+ _globals['_ORGANIZATION']._serialized_end=140
54
+ _globals['_DOMAIN']._serialized_start=142
55
+ _globals['_DOMAIN']._serialized_end=247
56
+ _globals['_PROJECT']._serialized_start=249
57
+ _globals['_PROJECT']._serialized_end=346
58
+ _globals['_WORKFLOW']._serialized_start=348
59
+ _globals['_WORKFLOW']._serialized_end=449
60
+ _globals['_LAUNCHPLAN']._serialized_start=451
61
+ _globals['_LAUNCHPLAN']._serialized_end=554
62
+ _globals['_RESOURCE']._serialized_start=557
63
+ _globals['_RESOURCE']._serialized_end=938
64
+ _globals['_PERMISSION']._serialized_start=940
65
+ _globals['_PERMISSION']._serialized_end=1058
66
+ # @@protoc_insertion_point(module_scope)