flyte 2.0.0b13__py3-none-any.whl → 2.0.0b15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (45) hide show
  1. flyte/_bin/debug.py +38 -0
  2. flyte/_bin/runtime.py +13 -0
  3. flyte/_code_bundle/_utils.py +2 -0
  4. flyte/_code_bundle/bundle.py +4 -4
  5. flyte/_debug/__init__.py +0 -0
  6. flyte/_debug/constants.py +39 -0
  7. flyte/_debug/utils.py +17 -0
  8. flyte/_debug/vscode.py +300 -0
  9. flyte/_image.py +32 -6
  10. flyte/_initialize.py +14 -28
  11. flyte/_internal/controllers/remote/_action.py +1 -1
  12. flyte/_internal/controllers/remote/_controller.py +35 -35
  13. flyte/_internal/imagebuild/docker_builder.py +11 -15
  14. flyte/_internal/imagebuild/remote_builder.py +52 -23
  15. flyte/_internal/runtime/entrypoints.py +3 -0
  16. flyte/_internal/runtime/task_serde.py +1 -2
  17. flyte/_internal/runtime/taskrunner.py +9 -3
  18. flyte/_protos/common/identifier_pb2.py +25 -19
  19. flyte/_protos/common/identifier_pb2.pyi +10 -0
  20. flyte/_protos/imagebuilder/definition_pb2.py +32 -31
  21. flyte/_protos/imagebuilder/definition_pb2.pyi +25 -12
  22. flyte/_protos/workflow/queue_service_pb2.py +26 -24
  23. flyte/_protos/workflow/queue_service_pb2.pyi +6 -4
  24. flyte/_protos/workflow/run_definition_pb2.py +50 -48
  25. flyte/_protos/workflow/run_definition_pb2.pyi +41 -16
  26. flyte/_protos/workflow/task_definition_pb2.py +16 -13
  27. flyte/_protos/workflow/task_definition_pb2.pyi +7 -0
  28. flyte/_task.py +6 -6
  29. flyte/_task_environment.py +4 -4
  30. flyte/_version.py +3 -3
  31. flyte/cli/_build.py +2 -3
  32. flyte/cli/_run.py +11 -12
  33. flyte/models.py +2 -0
  34. flyte/remote/_action.py +5 -2
  35. flyte/remote/_client/auth/_authenticators/device_code.py +1 -1
  36. flyte/remote/_client/auth/_authenticators/pkce.py +1 -1
  37. flyte/remote/_task.py +4 -4
  38. flyte-2.0.0b15.data/scripts/debug.py +38 -0
  39. {flyte-2.0.0b13.data → flyte-2.0.0b15.data}/scripts/runtime.py +13 -0
  40. {flyte-2.0.0b13.dist-info → flyte-2.0.0b15.dist-info}/METADATA +2 -2
  41. {flyte-2.0.0b13.dist-info → flyte-2.0.0b15.dist-info}/RECORD +45 -39
  42. {flyte-2.0.0b13.dist-info → flyte-2.0.0b15.dist-info}/WHEEL +0 -0
  43. {flyte-2.0.0b13.dist-info → flyte-2.0.0b15.dist-info}/entry_points.txt +0 -0
  44. {flyte-2.0.0b13.dist-info → flyte-2.0.0b15.dist-info}/licenses/LICENSE +0 -0
  45. {flyte-2.0.0b13.dist-info → flyte-2.0.0b15.dist-info}/top_level.txt +0 -0
flyte/_bin/debug.py ADDED
@@ -0,0 +1,38 @@
1
+ import click
2
+
3
+
4
# Root click command group for this script; subcommands attach to it through
# the ``@_debug.command(...)`` decorator. The docstring below doubles as the
# group's ``--help`` text, so it is left untouched.
@click.group()
def _debug():
    """Debug commands for Flyte."""
7
+
8
+
9
# ``-p`` is the natural short flag for ``--pid``; ``-m`` is kept as an alias so
# existing invocations keep working.
@_debug.command("resume")
@click.option("--pid", "-p", "-m", type=int, required=True, help="PID of the vscode server.")
def resume(pid):
    """
    Resume a Flyte task for debugging purposes.

    Args:
        pid (int): PID of the vscode server.
    """
    import os
    import signal

    print("Terminating server and resuming task.")
    answer = (
        input(
            "This operation will kill the server. All unsaved data will be lost,"
            " and you will no longer be able to connect to it. Do you really want to terminate? (Y/N): "
        )
        .strip()
        .upper()
    )
    # Anything other than an explicit "Y" aborts without touching the process.
    if answer != "Y":
        print("Operation canceled.")
        return

    try:
        os.kill(pid, signal.SIGTERM)
    except ProcessLookupError as e:
        # Report a stale PID as a CLI error rather than an unhandled traceback,
        # and do not claim the task was resumed.
        raise click.ClickException(f"No process with PID {pid} exists; the server may already have stopped.") from e
    print("The server has been terminated and the task has been resumed.")
35
+
36
+
37
# Invoke the click command group when this file is executed as a script.
if __name__ == "__main__":
    _debug()
flyte/_bin/runtime.py CHANGED
@@ -26,6 +26,7 @@ DOMAIN_NAME = "FLYTE_INTERNAL_TASK_DOMAIN"
26
26
  ORG_NAME = "_U_ORG_NAME"
27
27
  ENDPOINT_OVERRIDE = "_U_EP_OVERRIDE"
28
28
  RUN_OUTPUT_BASE_DIR = "_U_RUN_BASE"
29
+ FLYTE_ENABLE_VSCODE_KEY = "_F_E_VS"
29
30
 
30
31
  # TODO: Remove this after proper auth is implemented
31
32
  _UNION_EAGER_API_KEY_ENV_VAR = "_UNION_EAGER_API_KEY"
@@ -49,6 +50,8 @@ def _pass_through():
49
50
  @click.option("--project", envvar=PROJECT_NAME, required=False)
50
51
  @click.option("--domain", envvar=DOMAIN_NAME, required=False)
51
52
  @click.option("--org", envvar=ORG_NAME, required=False)
53
+ @click.option("--debug", envvar=FLYTE_ENABLE_VSCODE_KEY, type=click.BOOL, required=False)
54
+ @click.option("--interactive-mode", type=click.BOOL, required=False)
52
55
  @click.option("--image-cache", required=False)
53
56
  @click.option("--tgz", required=False)
54
57
  @click.option("--pkl", required=False)
@@ -59,12 +62,16 @@ def _pass_through():
59
62
  type=click.UNPROCESSED,
60
63
  nargs=-1,
61
64
  )
65
+ @click.pass_context
62
66
  def main(
67
+ ctx: click.Context,
63
68
  run_name: str,
64
69
  name: str,
65
70
  project: str,
66
71
  domain: str,
67
72
  org: str,
73
+ debug: bool,
74
+ interactive_mode: bool,
68
75
  image_cache: str,
69
76
  version: str,
70
77
  inputs: str,
@@ -109,6 +116,11 @@ def main(
109
116
  if name.startswith("{{"):
110
117
  name = os.getenv("ACTION_NAME", "")
111
118
 
119
+ if debug and name == "a0":
120
+ from flyte._debug.vscode import _start_vscode_server
121
+
122
+ asyncio.run(_start_vscode_server(ctx))
123
+
112
124
  # Figure out how to connect
113
125
  # This detection of api key is a hack for now.
114
126
  controller_kwargs: dict[str, Any] = {"insecure": False}
@@ -143,6 +155,7 @@ def main(
143
155
  version=version,
144
156
  controller=controller,
145
157
  image_cache=ic,
158
+ interactive_mode=interactive_mode or debug,
146
159
  )
147
160
  # Create a coroutine to watch for errors
148
161
  controller_failure = controller.watch_for_errors()
@@ -240,6 +240,8 @@ def list_imported_modules_as_files(source_path: str, modules: List[ModuleType])
240
240
 
241
241
  if not _file_is_in_directory(mod_file, source_path):
242
242
  # Only upload files where the module file in the source directory
243
+ # print log line for files that have common ancestor with source_path, but not in it.
244
+ logger.debug(f"{mod_file} is not in {source_path}")
243
245
  continue
244
246
 
245
247
  files.append(mod_file)
@@ -178,9 +178,10 @@ async def download_bundle(bundle: CodeBundle) -> pathlib.Path:
178
178
  # TODO make storage apis better to accept pathlib.Path
179
179
  if bundle.tgz:
180
180
  downloaded_bundle = dest / os.path.basename(bundle.tgz)
181
+ if downloaded_bundle.exists():
182
+ return downloaded_bundle.absolute()
181
183
  # Download the tgz file
182
- path = await storage.get(bundle.tgz, str(downloaded_bundle.absolute()))
183
- downloaded_bundle = pathlib.Path(path)
184
+ await storage.get(bundle.tgz, str(downloaded_bundle.absolute()))
184
185
  # NOTE the os.path.join(destination, ''). This is to ensure that the given path is in fact a directory and all
185
186
  # downloaded data should be copied into this directory. We do this to account for a difference in behavior in
186
187
  # fsspec, which requires a trailing slash in case of pre-existing directory.
@@ -204,8 +205,7 @@ async def download_bundle(bundle: CodeBundle) -> pathlib.Path:
204
205
 
205
206
  downloaded_bundle = dest / os.path.basename(bundle.pkl)
206
207
  # Download the tgz file
207
- path = await storage.get(bundle.pkl, str(downloaded_bundle.absolute()))
208
- downloaded_bundle = pathlib.Path(path)
208
+ await storage.get(bundle.pkl, str(downloaded_bundle.absolute()))
209
209
  return downloaded_bundle.absolute()
210
210
  else:
211
211
  raise ValueError("Code bundle should be either tgz or pkl, found neither.")
File without changes
@@ -0,0 +1,39 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
# Name of the code-server executable
EXECUTABLE_NAME = "code-server"
# Where the code-server tar and plugins are downloaded to
DOWNLOAD_DIR = Path.home() / ".code-server"
HOURS_TO_SECONDS = 60 * 60
DEFAULT_UP_SECONDS = 10 * HOURS_TO_SECONDS  # 10 hours
DEFAULT_CODE_SERVER_REMOTE_PATHS = {
    "amd64": "https://github.com/coder/code-server/releases/download/v4.18.0/code-server-4.18.0-linux-amd64.tar.gz",
    "arm64": "https://github.com/coder/code-server/releases/download/v4.18.0/code-server-4.18.0-linux-arm64.tar.gz",
}
DEFAULT_CODE_SERVER_EXTENSIONS = [
    "https://raw.githubusercontent.com/flyteorg/flytetools/master/flytekitplugins/flyin/ms-python.python-2023.20.0.vsix",
    "https://raw.githubusercontent.com/flyteorg/flytetools/master/flytekitplugins/flyin/ms-toolsai.jupyter-2023.9.100.vsix",
]

# Duration to pause the checking of the heartbeat file until the next one
HEARTBEAT_CHECK_SECONDS = 60

# The path is hardcoded by code-server
# https://coder.com/docs/code-server/latest/FAQ#what-is-the-heartbeat-file
HEARTBEAT_PATH = os.path.expanduser("~/.local/share/code-server/heartbeat")

INTERACTIVE_DEBUGGING_FILE_NAME = "flyteinteractive_interactive_entrypoint.py"
RESUME_TASK_FILE_NAME = "flyteinteractive_resume_task.py"
# Config keys to store in task template
VSCODE_TYPE_KEY = "flyteinteractive_type"
VSCODE_PORT_KEY = "flyteinteractive_port"

TASK_FUNCTION_SOURCE_PATH = "TASK_FUNCTION_SOURCE_PATH"

# Default max idle seconds to terminate the flyteinteractive server.
# This module previously also assigned MAX_IDLE_SECONDS = 180 (and HOURS_TO_SECONDS
# a second time) further up; those assignments were immediately overridden by this
# one, so only the effective value is kept.
MAX_IDLE_SECONDS = 10 * HOURS_TO_SECONDS  # 10 hours

# Subprocess constants
EXIT_CODE_SUCCESS = 0
flyte/_debug/utils.py ADDED
@@ -0,0 +1,17 @@
1
+ import asyncio
2
+
3
+ from flyte._debug.constants import EXIT_CODE_SUCCESS
4
+ from flyte._logging import logger
5
+
6
+
7
async def execute_command(cmd: str):
    """
    Run ``cmd`` through the shell and wait for it to exit.

    Both stdout and stderr are captured and logged once the process finishes.

    Args:
        cmd: The shell command line to execute.

    Raises:
        RuntimeError: If the process exits with a code other than EXIT_CODE_SUCCESS.
    """
    pipe = asyncio.subprocess.PIPE
    proc = await asyncio.create_subprocess_shell(cmd, stdout=pipe, stderr=pipe)
    logger.info(f"cmd: {cmd}")
    out, err = await proc.communicate()

    if proc.returncode == EXIT_CODE_SUCCESS:
        logger.info(f"stdout: {out!r}")
        logger.info(f"stderr: {err!r}")
        return

    raise RuntimeError(f"Command {cmd} failed with error: {err!r}")
flyte/_debug/vscode.py ADDED
@@ -0,0 +1,300 @@
1
+ import asyncio
2
+ import json
3
+ import multiprocessing
4
+ import os
5
+ import platform
6
+ import shutil
7
+ import subprocess
8
+ import sys
9
+ import tarfile
10
+ import time
11
+ from typing import List
12
+
13
+ import aiofiles
14
+ import click
15
+ import httpx
16
+
17
+ from flyte import storage
18
+ from flyte._debug.constants import (
19
+ DEFAULT_CODE_SERVER_EXTENSIONS,
20
+ DEFAULT_CODE_SERVER_REMOTE_PATHS,
21
+ DOWNLOAD_DIR,
22
+ EXECUTABLE_NAME,
23
+ EXIT_CODE_SUCCESS,
24
+ HEARTBEAT_PATH,
25
+ MAX_IDLE_SECONDS,
26
+ )
27
+ from flyte._debug.utils import (
28
+ execute_command,
29
+ )
30
+ from flyte._internal.runtime.rusty import download_tgz
31
+ from flyte._logging import logger
32
+
33
+
34
async def download_file(url: str, target_dir: str) -> str:
    """
    Download ``url`` into ``target_dir`` and return the path of the local copy.

    URLs starting with "http" are fetched with HTTPX (following redirects);
    anything else is handed to ``storage.get``. The local file is named after
    the basename of ``url``.

    Args:
        url (str): The URL of the file to download.
        target_dir (str): The directory where the file should be saved.

    Returns:
        str: Path of the downloaded file inside ``target_dir``.

    Raises:
        RuntimeError: If the request fails, the server answers with a 4xx/5xx
            status, or any other error occurs while saving the file. The
            original exception is attached as ``__cause__``.
    """
    try:
        filename = os.path.join(target_dir, os.path.basename(url))
        if url.startswith("http"):
            # Use the async client so that awaiting this coroutine does not
            # block the event loop while the response is in flight.
            async with httpx.AsyncClient(follow_redirects=True) as client:
                response = await client.get(url)
            response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
            async with aiofiles.open(filename, "wb") as f:
                await f.write(response.content)
        else:
            await storage.get(url, filename)
        logger.info(f"File '{filename}' downloaded successfully from '{url}'.")
        return filename

    except httpx.RequestError as e:
        raise RuntimeError(f"An error occurred while requesting '{url}': {e}") from e
    except httpx.HTTPStatusError as e:
        raise RuntimeError(f"HTTP error occurred: {e.response.status_code} - {e.response.text}") from e
    except Exception as e:
        raise RuntimeError(f"An unexpected error occurred: {e}") from e
60
+
61
+
62
def get_default_extensions() -> List[str]:
    """
    Return the extension locations to install into code-server.

    When the ``_F_CS_E`` environment variable is set, it is parsed as a
    comma-separated list; surrounding whitespace is stripped and empty entries
    (from an empty value or stray commas) are dropped so they are never treated
    as something to download. Otherwise DEFAULT_CODE_SERVER_EXTENSIONS is returned.

    Returns:
        List[str]: Extension locations, possibly empty when the variable is set but blank.
    """
    extensions = os.getenv("_F_CS_E")
    if extensions is None:
        return DEFAULT_CODE_SERVER_EXTENSIONS
    return [entry.strip() for entry in extensions.split(",") if entry.strip()]
67
+
68
+
69
def get_code_server_info() -> str:
    """
    Resolve the location of the code-server archive to download.

    The ``_F_CS_RP`` environment variable, when set, is returned as-is.
    Otherwise the machine type reported by ``platform.machine()`` selects an
    entry of DEFAULT_CODE_SERVER_REMOTE_PATHS: "aarch64" picks the "arm64"
    entry and "x86_64" picks the "amd64" entry.

    Returns:
        str: The code server location for this machine.

    Raises:
        ValueError: If the machine type is neither "aarch64" nor "x86_64".
    """
    override = os.getenv("_F_CS_RP")
    if override is not None:
        return override

    machine_info = platform.machine()
    logger.info(f"machine type: {machine_info}")

    # Map platform.machine() values to the keys of the default path table.
    arch_key = {"aarch64": "arm64", "x86_64": "amd64"}.get(machine_info)
    if arch_key is None:
        raise ValueError(
            "Automatic download is only supported on AMD64 and ARM64 architectures."
            " If you are using a different architecture, please visit the code-server official website to"
            " manually download the appropriate version for your image."
        )
    return DEFAULT_CODE_SERVER_REMOTE_PATHS[arch_key]
101
+
102
+
103
def get_installed_extensions() -> List[str]:
    """
    List the extensions code-server reports as installed.

    Runs ``code-server --list-extensions`` and splits its stdout into lines.
    A non-zero exit code is logged and treated as "nothing installed".

    Returns:
        List[str]: One entry per output line, or an empty list on failure.
    """
    result = subprocess.run(
        ["code-server", "--list-extensions"], check=False, capture_output=True, text=True
    )
    if result.returncode == EXIT_CODE_SUCCESS:
        return result.stdout.splitlines()

    logger.info(f"Command code-server --list-extensions failed with error: {result.stderr}")
    return []
118
+
119
+
120
def is_extension_installed(extension: str, installed_extensions: List[str]) -> bool:
    """
    Tell whether ``extension`` is already covered by an installed extension.

    An extension counts as installed when the name of any installed extension
    occurs as a substring of ``extension``.

    Args:
        extension: Location of the extension to check.
        installed_extensions: Names of installed extensions, one per entry.

    Returns:
        bool: True if a non-blank installed name is contained in ``extension``.
    """
    # Blank entries must be skipped: "" is a substring of every string and
    # would otherwise mark every extension as installed.
    names = (raw.strip() for raw in installed_extensions)
    return any(name in extension for name in names if name)
122
+
123
+
124
async def download_vscode():
    """
    Make code-server and its extensions available locally.

    Steps, in order:
      1. Unless the executable is already on $PATH or DOWNLOAD_DIR already exists,
         download the code-server archive into DOWNLOAD_DIR and extract it there.
      2. If DOWNLOAD_DIR exists, prepend the extracted ``bin`` directory to $PATH.
      3. Download every configured extension that is not reported as installed,
         then install the downloaded files with ``code-server --install-extension``.
    """
    # If the code server already exists in the container, skip downloading
    executable_path = shutil.which(EXECUTABLE_NAME)
    # NOTE(review): this branch is also taken when only DOWNLOAD_DIR exists; in that
    # case executable_path is None and the first log line prints "None".
    if executable_path is not None or os.path.exists(DOWNLOAD_DIR):
        logger.info(f"Code server binary already exists at {executable_path}")
        logger.info("Skipping downloading code server...")
    else:
        logger.info("Code server is not in $PATH, start downloading code server...")
        # Create DOWNLOAD_DIR if not exist
        logger.info(f"DOWNLOAD_DIR: {DOWNLOAD_DIR}")
        os.makedirs(DOWNLOAD_DIR)

        logger.info(f"Start downloading files to {DOWNLOAD_DIR}")
        # Download remote file to local
        code_server_remote_path = get_code_server_info()
        code_server_tar_path = await download_file(code_server_remote_path, str(DOWNLOAD_DIR))

        # Extract the tarball
        # NOTE(review): extractall() is called without an extraction filter; confirm the
        # archive source is trusted or pass a filter on Python versions that support it.
        with tarfile.open(code_server_tar_path, "r:gz") as tar:
            tar.extractall(path=DOWNLOAD_DIR)

    if os.path.exists(DOWNLOAD_DIR):
        # Presumably the archive unpacks into a directory named after the tarball
        # without its ".tar.gz" suffix -- TODO confirm for custom archive locations.
        code_server_dir_name = os.path.basename(get_code_server_info()).removesuffix(".tar.gz")
        code_server_bin_dir = os.path.join(DOWNLOAD_DIR, code_server_dir_name, "bin")
        # Add the directory of code-server binary to $PATH
        # (raises KeyError if PATH is not set in the environment)
        os.environ["PATH"] = code_server_bin_dir + os.pathsep + os.environ["PATH"]

    # If the extension already exists in the container, skip downloading
    installed_extensions = get_installed_extensions()
    coros = []

    for extension in get_default_extensions():
        if not is_extension_installed(extension, installed_extensions):
            coros.append(download_file(extension, str(DOWNLOAD_DIR)))
    # Await all pending extension downloads together; the result holds their local paths.
    extension_paths = await asyncio.gather(*coros)

    coros = []
    for p in extension_paths:
        logger.info(f"Execute extension installation command to install extension {p}")
        coros.append(execute_command(f"code-server --install-extension {p}"))

    # Await all install commands together.
    await asyncio.gather(*coros)
169
+
170
+
171
def prepare_launch_json(ctx: click.Context, pid: int):
    """
    Write ``.vscode/launch.json`` and ``.vscode/settings.json`` in the current working directory.

    ``launch.json`` gets two launch configurations: "Interactive Debugging",
    which re-runs ``runtime.py`` from the virtual environment with the
    parameters found on ``ctx``, and "Resume Task", which runs ``debug.py resume``
    against ``pid``. ``settings.json`` points the default interpreter at ``sys.executable``.

    Args:
        ctx: Click context whose ``params`` supply the runtime arguments.
        pid: Process id passed to the "Resume Task" configuration.

    Raises:
        RuntimeError: If the VIRTUAL_ENV environment variable is not set.
    """
    venv_root = os.getenv("VIRTUAL_ENV")
    if venv_root is None:
        raise RuntimeError("VIRTUAL_ENV is not found in environment variables.")

    params = ctx.params
    run_name = params["run_name"]
    name = params["name"]
    # TODO: Executor should pass correct name.
    # Values still holding an unrendered "{{...}}" template fall back to the environment.
    if run_name.startswith("{{"):
        run_name = os.getenv("RUN_NAME", "")
    if name.startswith("{{"):
        name = os.getenv("ACTION_NAME", "")

    runtime_args = [
        "a0",
        "--inputs",
        params["inputs"],
        "--outputs-path",
        params["outputs_path"],
        "--version",
        params["version"],
        "--run-base-dir",
        params["run_base_dir"],
        "--name",
        name,
        "--run-name",
        run_name,
        "--project",
        params["project"],
        "--domain",
        params["domain"],
        "--org",
        params["org"],
        "--image-cache",
        params["image_cache"],
        "--debug",
        "False",
        "--interactive-mode",
        "True",
        "--tgz",
        params["tgz"],
        "--dest",
        params["dest"],
        "--resolver",
        params["resolver"],
        *params["resolver_args"],
    ]
    interactive_config = {
        "name": "Interactive Debugging",
        "type": "python",
        "request": "launch",
        "program": f"{venv_root}/bin/runtime.py",
        "console": "integratedTerminal",
        "justMyCode": True,
        "args": runtime_args,
    }
    resume_config = {
        "name": "Resume Task",
        "type": "python",
        "request": "launch",
        "program": f"{venv_root}/bin/debug.py",
        "console": "integratedTerminal",
        "justMyCode": True,
        "args": ["resume", "--pid", str(pid)],
    }
    launch_json = {"version": "0.2.0", "configurations": [interactive_config, resume_config]}

    vscode_directory = os.path.join(os.getcwd(), ".vscode")
    if not os.path.exists(vscode_directory):
        os.makedirs(vscode_directory)

    def _write_json(file_name, payload):
        # Serialize ``payload`` into the .vscode directory with 4-space indentation.
        with open(os.path.join(vscode_directory, file_name), "w") as handle:
            json.dump(payload, handle, indent=4)

    _write_json("launch.json", launch_json)
    _write_json("settings.json", {"python.defaultInterpreterPath": sys.executable})
255
+
256
+
257
def _run_code_server(cmd: str) -> None:
    """Child-process entry point: run ``cmd`` to completion on its own event loop."""
    asyncio.run(execute_command(cmd))


async def _start_vscode_server(ctx: click.Context):
    """
    Start code-server in a child process and wait until the user resumes the task.

    The code bundle and code-server are downloaded first. code-server is then
    launched in a separate process serving the current working directory, and
    the VS Code launch/settings files are generated with that process's pid.
    The coroutine then polls once per second until the child process is no
    longer alive. Every ``check_interval`` seconds it inspects the heartbeat
    file; if the server has been idle (or never connected) for longer than
    MAX_IDLE_SECONDS, the child is terminated and the interpreter exits via
    ``sys.exit()``.

    Args:
        ctx: Click context of the runtime entrypoint; ``dest``, ``version`` and
            ``tgz`` are read from ``ctx.params`` here.

    Raises:
        RuntimeError: If the child process has no pid after being started.
    """
    await asyncio.gather(download_tgz(ctx.params["dest"], ctx.params["version"], ctx.params["tgz"]), download_vscode())
    # NOTE(review): the server listens on all interfaces with authentication disabled;
    # confirm access is restricted by the surrounding deployment.
    child_process = multiprocessing.Process(
        # A module-level function (not a lambda) so the target stays picklable under
        # the "spawn"/"forkserver" start methods. It calls asyncio.run exactly once:
        # the previous nested asyncio.run(asyncio.run(...)) handed the inner result
        # (None) to the outer call, which raises ValueError once the command finishes.
        target=_run_code_server,
        kwargs={"cmd": f"code-server --bind-addr 0.0.0.0:8080 --disable-workspace-trust --auth none {os.getcwd()}"},
    )
    child_process.start()
    if child_process.pid is None:
        raise RuntimeError("Failed to start vscode server.")
    prepare_launch_json(ctx, child_process.pid)

    start_time = time.time()
    check_interval = 60  # Interval for heartbeat checking in seconds
    # Back-date the last check so the first loop iteration performs one immediately.
    last_heartbeat_check = time.time() - check_interval

    def terminate_process():
        if child_process.is_alive():
            child_process.terminate()
        child_process.join()

    logger.info("waiting for task to resume...")
    while child_process.is_alive():
        current_time = time.time()
        if current_time - last_heartbeat_check >= check_interval:
            last_heartbeat_check = current_time
            if not os.path.exists(HEARTBEAT_PATH):
                # No heartbeat file yet: measure idleness from server start.
                delta = current_time - start_time
                logger.info(f"Code server has not been connected since {delta} seconds ago.")
                logger.info("Please open the browser to connect to the running server.")
            else:
                delta = current_time - os.path.getmtime(HEARTBEAT_PATH)
                logger.info(f"The latest activity on code server is {delta} seconds ago.")

            # If the time from last connection is longer than max idle seconds, terminate the vscode server.
            if delta > MAX_IDLE_SECONDS:
                logger.info(f"VSCode server is idle for more than {MAX_IDLE_SECONDS} seconds. Terminating...")
                terminate_process()
                sys.exit()

        await asyncio.sleep(1)

    logger.info("User has resumed the task.")
    terminate_process()
flyte/_image.py CHANGED
@@ -252,9 +252,15 @@ class AptPackages(Layer):
252
252
  @dataclass(frozen=True, repr=True)
253
253
  class Commands(Layer):
254
254
  commands: Tuple[str, ...]
255
+ secret_mounts: Optional[Tuple[str | Secret, ...]] = None
255
256
 
256
257
  def update_hash(self, hasher: hashlib._Hash):
257
- hasher.update("".join(self.commands).encode("utf-8"))
258
+ hash_input = "".join(self.commands)
259
+
260
+ if self.secret_mounts:
261
+ for secret_mount in self.secret_mounts:
262
+ hash_input += str(secret_mount)
263
+ hasher.update(hash_input.encode("utf-8"))
258
264
 
259
265
 
260
266
  @rich.repr.auto
@@ -782,9 +788,24 @@ class Image:
782
788
 
783
789
  Example:
784
790
  ```python
785
- @flyte.task(image=(flyte.Image
786
- .ubuntu_python()
787
- .with_pip_packages("requests", "numpy")))
791
+ @flyte.task(image=(flyte.Image.from_debian_base().with_pip_packages("requests", "numpy")))
792
+ def my_task(x: int) -> int:
793
+ import numpy as np
794
+ return np.sum([x, 1])
795
+ ```
796
+
797
+ To mount secrets during the build process to download private packages, you can use the `secret_mounts`.
798
+ In the below example, "GITHUB_PAT" will be mounted as env var "GITHUB_PAT",
799
+ and "apt-secret" will be mounted at /etc/apt/apt-secret.
800
+ Example:
801
+ ```python
802
+ private_package = "git+https://$GITHUB_PAT@github.com/flyteorg/flytex.git@2e20a2acebfc3877d84af643fdd768edea41d533"
803
+ @flyte.task(
804
+ image=(
805
+ flyte.Image.from_debian_base()
806
+ .with_pip_packages(private_package, secret_mounts=[Secret(key="GITHUB_PAT")])
807
+ .with_apt_packages("git", secret_mounts=[Secret(key="apt-secret", mount="/etc/apt/apt-secret")])
808
+ ))
788
809
  def my_task(x: int) -> int:
789
810
  import numpy as np
790
811
  return np.sum([x, 1])
@@ -909,16 +930,21 @@ class Image:
909
930
  )
910
931
  return new_image
911
932
 
912
- def with_commands(self, commands: List[str]) -> Image:
933
+ def with_commands(self, commands: List[str], secret_mounts: Optional[SecretRequest] = None) -> Image:
913
934
  """
914
935
  Use this method to create a new image with the specified commands layered on top of the current image
915
936
  Be sure not to use RUN in your command.
916
937
 
917
938
  :param commands: list of commands to run
939
+ :param secret_mounts: list of secret mounts to use for the build process.
918
940
  :return: Image
919
941
  """
920
942
  new_commands: Tuple = _ensure_tuple(commands)
921
- new_image = self.clone(addl_layer=Commands(commands=new_commands))
943
+ new_image = self.clone(
944
+ addl_layer=Commands(
945
+ commands=new_commands, secret_mounts=_ensure_tuple(secret_mounts) if secret_mounts else None
946
+ )
947
+ )
922
948
  return new_image
923
949
 
924
950
  def with_local_v2(self) -> Image:
flyte/_initialize.py CHANGED
@@ -205,6 +205,14 @@ async def init(
205
205
  http_proxy_url=http_proxy_url,
206
206
  )
207
207
 
208
+ if not root_dir:
209
+ editable_root = get_cwd_editable_install()
210
+ if editable_root:
211
+ logger.info(f"Using editable install as root directory: {editable_root}")
212
+ root_dir = editable_root
213
+ else:
214
+ logger.info("No editable install found, using current working directory as root directory.")
215
+ root_dir = Path.cwd()
208
216
  root_dir = root_dir or get_cwd_editable_install() or Path.cwd()
209
217
  _init_config = _InitConfig(
210
218
  root_dir=root_dir,
@@ -242,17 +250,19 @@ async def init_from_config(
242
250
  cfg: config.Config
243
251
  if path_or_config is None or isinstance(path_or_config, str):
244
252
  # If a string is passed, treat it as a path to the config file
245
- if path_or_config:
253
+ if root_dir and path_or_config:
254
+ cfg = config.auto(str(root_dir / path_or_config))
255
+ elif path_or_config:
246
256
  if not Path(path_or_config).exists():
247
257
  raise InitializationError(
248
258
  "ConfigFileNotFoundError",
249
259
  "user",
250
260
  f"Configuration file '{path_or_config}' does not exist., current working directory is {Path.cwd()}",
251
261
  )
252
- if root_dir and path_or_config:
253
- cfg = config.auto(str(root_dir / path_or_config))
254
- else:
255
262
  cfg = config.auto(path_or_config)
263
+ else:
264
+ # If no path is provided, use the default config file
265
+ cfg = config.auto()
256
266
  else:
257
267
  # If a Config object is passed, use it directly
258
268
  cfg = path_or_config
@@ -374,30 +384,6 @@ def ensure_client():
374
384
  )
375
385
 
376
386
 
377
- def requires_client(func: T) -> T:
378
- """
379
- Decorator that checks if the client has been initialized before executing the function.
380
- Raises InitializationError if the client is not initialized.
381
-
382
- :param func: Function to decorate
383
- :return: Decorated function that checks for initialization
384
- """
385
-
386
- @functools.wraps(func)
387
- async def wrapper(*args, **kwargs) -> T:
388
- init_config = _get_init_config()
389
- if init_config is None or init_config.client is None:
390
- raise InitializationError(
391
- "ClientNotInitializedError",
392
- "user",
393
- f"Function '{func.__name__}' requires client to be initialized. "
394
- f"Call flyte.init() with a valid endpoint or api-key before using this function.",
395
- )
396
- return func(*args, **kwargs)
397
-
398
- return typing.cast(T, wrapper)
399
-
400
-
401
387
  def requires_storage(func: T) -> T:
402
388
  """
403
389
  Decorator that checks if the storage has been initialized before executing the function.
@@ -184,7 +184,7 @@ class Action:
184
184
  et.nanos = int((end_time % 1) * 1e9)
185
185
 
186
186
  spec = (
187
- task_definition_pb2.TaskSpec(task_template=tasks_pb2.TaskTemplate(interface=typed_interface))
187
+ task_definition_pb2.TraceSpec(interface=typed_interface)
188
188
  if typed_interface
189
189
  else None
190
190
  )