flyte 2.0.0b22__py3-none-any.whl → 2.0.0b23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. See the package registry's advisory page for more details.

Files changed (88)
  1. flyte/__init__.py +5 -0
  2. flyte/_bin/runtime.py +35 -5
  3. flyte/_cache/cache.py +4 -2
  4. flyte/_cache/local_cache.py +215 -0
  5. flyte/_code_bundle/bundle.py +1 -0
  6. flyte/_debug/constants.py +0 -1
  7. flyte/_debug/vscode.py +6 -1
  8. flyte/_deploy.py +193 -52
  9. flyte/_environment.py +5 -0
  10. flyte/_excepthook.py +1 -1
  11. flyte/_image.py +101 -72
  12. flyte/_initialize.py +23 -0
  13. flyte/_internal/controllers/_local_controller.py +64 -24
  14. flyte/_internal/controllers/remote/_action.py +4 -1
  15. flyte/_internal/controllers/remote/_controller.py +5 -2
  16. flyte/_internal/controllers/remote/_core.py +6 -3
  17. flyte/_internal/controllers/remote/_informer.py +1 -1
  18. flyte/_internal/imagebuild/docker_builder.py +92 -28
  19. flyte/_internal/imagebuild/image_builder.py +7 -13
  20. flyte/_internal/imagebuild/remote_builder.py +6 -1
  21. flyte/_internal/runtime/io.py +13 -1
  22. flyte/_internal/runtime/rusty.py +17 -2
  23. flyte/_internal/runtime/task_serde.py +14 -20
  24. flyte/_internal/runtime/taskrunner.py +1 -1
  25. flyte/_internal/runtime/trigger_serde.py +153 -0
  26. flyte/_logging.py +1 -1
  27. flyte/_protos/common/identifier_pb2.py +19 -1
  28. flyte/_protos/common/identifier_pb2.pyi +22 -0
  29. flyte/_protos/workflow/common_pb2.py +14 -3
  30. flyte/_protos/workflow/common_pb2.pyi +49 -0
  31. flyte/_protos/workflow/queue_service_pb2.py +41 -35
  32. flyte/_protos/workflow/queue_service_pb2.pyi +26 -12
  33. flyte/_protos/workflow/queue_service_pb2_grpc.py +34 -0
  34. flyte/_protos/workflow/run_definition_pb2.py +38 -38
  35. flyte/_protos/workflow/run_definition_pb2.pyi +4 -2
  36. flyte/_protos/workflow/run_service_pb2.py +60 -50
  37. flyte/_protos/workflow/run_service_pb2.pyi +24 -6
  38. flyte/_protos/workflow/run_service_pb2_grpc.py +34 -0
  39. flyte/_protos/workflow/task_definition_pb2.py +15 -11
  40. flyte/_protos/workflow/task_definition_pb2.pyi +19 -2
  41. flyte/_protos/workflow/task_service_pb2.py +18 -17
  42. flyte/_protos/workflow/task_service_pb2.pyi +5 -2
  43. flyte/_protos/workflow/trigger_definition_pb2.py +66 -0
  44. flyte/_protos/workflow/trigger_definition_pb2.pyi +117 -0
  45. flyte/_protos/workflow/trigger_definition_pb2_grpc.py +4 -0
  46. flyte/_protos/workflow/trigger_service_pb2.py +96 -0
  47. flyte/_protos/workflow/trigger_service_pb2.pyi +110 -0
  48. flyte/_protos/workflow/trigger_service_pb2_grpc.py +281 -0
  49. flyte/_run.py +42 -15
  50. flyte/_task.py +35 -4
  51. flyte/_task_environment.py +60 -15
  52. flyte/_trigger.py +382 -0
  53. flyte/_version.py +3 -3
  54. flyte/cli/_abort.py +3 -3
  55. flyte/cli/_build.py +1 -3
  56. flyte/cli/_common.py +15 -2
  57. flyte/cli/_create.py +74 -0
  58. flyte/cli/_delete.py +23 -1
  59. flyte/cli/_deploy.py +5 -9
  60. flyte/cli/_get.py +75 -34
  61. flyte/cli/_params.py +4 -2
  62. flyte/cli/_run.py +12 -3
  63. flyte/cli/_update.py +36 -0
  64. flyte/cli/_user.py +17 -0
  65. flyte/cli/main.py +9 -1
  66. flyte/errors.py +9 -0
  67. flyte/io/_dir.py +513 -115
  68. flyte/io/_file.py +495 -135
  69. flyte/models.py +32 -0
  70. flyte/remote/__init__.py +6 -1
  71. flyte/remote/_client/_protocols.py +36 -2
  72. flyte/remote/_client/controlplane.py +19 -3
  73. flyte/remote/_run.py +42 -2
  74. flyte/remote/_task.py +14 -1
  75. flyte/remote/_trigger.py +308 -0
  76. flyte/remote/_user.py +33 -0
  77. flyte/storage/__init__.py +6 -1
  78. flyte/storage/_storage.py +119 -101
  79. flyte/types/_pickle.py +16 -3
  80. {flyte-2.0.0b22.data → flyte-2.0.0b23.data}/scripts/runtime.py +35 -5
  81. {flyte-2.0.0b22.dist-info → flyte-2.0.0b23.dist-info}/METADATA +3 -1
  82. {flyte-2.0.0b22.dist-info → flyte-2.0.0b23.dist-info}/RECORD +87 -75
  83. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  84. {flyte-2.0.0b22.data → flyte-2.0.0b23.data}/scripts/debug.py +0 -0
  85. {flyte-2.0.0b22.dist-info → flyte-2.0.0b23.dist-info}/WHEEL +0 -0
  86. {flyte-2.0.0b22.dist-info → flyte-2.0.0b23.dist-info}/entry_points.txt +0 -0
  87. {flyte-2.0.0b22.dist-info → flyte-2.0.0b23.dist-info}/licenses/LICENSE +0 -0
  88. {flyte-2.0.0b22.dist-info → flyte-2.0.0b23.dist-info}/top_level.txt +0 -0
flyte/_initialize.py CHANGED
@@ -182,6 +182,28 @@ async def init(
182
182
  """
183
183
  from flyte._utils import get_cwd_editable_install, org_from_endpoint, sanitize_endpoint
184
184
 
185
+ if endpoint or api_key:
186
+ if project is None:
187
+ raise ValueError(
188
+ "Project must be provided to initialize the client. "
189
+ "Please set 'project' in the 'task' section of your config file, "
190
+ "or pass it directly to flyte.init(project='your-project-name')."
191
+ )
192
+
193
+ if domain is None:
194
+ raise ValueError(
195
+ "Domain must be provided to initialize the client. "
196
+ "Please set 'domain' in the 'task' section of your config file, "
197
+ "or pass it directly to flyte.init(domain='your-domain-name')."
198
+ )
199
+
200
+ if org is None and org_from_endpoint(endpoint) is None:
201
+ raise ValueError(
202
+ "Organization must be provided to initialize the client. "
203
+ "Please set 'org' in the 'task' section of your config file, "
204
+ "or pass it directly to flyte.init(org='your-org-name')."
205
+ )
206
+
185
207
  _initialize_logger(log_level=log_level)
186
208
 
187
209
  global _init_config # noqa: PLW0603
@@ -278,6 +300,7 @@ async def init_from_config(
278
300
  _initialize_logger(log_level=log_level)
279
301
 
280
302
  logger.info(f"Flyte config initialized as {cfg}", extra={"highlighter": ReprHighlighter()})
303
+
281
304
  await init.aio(
282
305
  org=cfg.task.org,
283
306
  project=cfg.task.project,
@@ -2,16 +2,20 @@ import asyncio
2
2
  import atexit
3
3
  import concurrent.futures
4
4
  import os
5
+ import pathlib
5
6
  import threading
6
7
  from typing import Any, Callable, Tuple, TypeVar
7
8
 
8
9
  import flyte.errors
10
+ from flyte._cache.cache import VersionParameters, cache_from_request
11
+ from flyte._cache.local_cache import LocalTaskCache
9
12
  from flyte._context import internal_ctx
10
13
  from flyte._internal.controllers import TraceInfo
11
14
  from flyte._internal.runtime import convert
12
15
  from flyte._internal.runtime.entrypoints import direct_dispatch
16
+ from flyte._internal.runtime.types_serde import transform_native_to_typed_interface
13
17
  from flyte._logging import log, logger
14
- from flyte._task import TaskTemplate
18
+ from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
15
19
  from flyte._utils.helpers import _selector_policy
16
20
  from flyte.models import ActionID, NativeInterface
17
21
  from flyte.remote._task import TaskDetails
@@ -81,31 +85,67 @@ class LocalController:
81
85
  raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
82
86
 
83
87
  inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
84
- serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
88
+ inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)
89
+ task_interface = transform_native_to_typed_interface(_task.interface)
85
90
 
86
91
  sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
87
- tctx, _task.name, serialized_inputs, 0
92
+ tctx, _task.name, inputs_hash, 0
88
93
  )
89
94
  sub_action_raw_data_path = tctx.raw_data_path
90
-
91
- out, err = await direct_dispatch(
92
- _task,
93
- controller=self,
94
- action=sub_action_id,
95
- raw_data_path=sub_action_raw_data_path,
96
- inputs=inputs,
97
- version=tctx.version,
98
- checkpoints=tctx.checkpoints,
99
- code_bundle=tctx.code_bundle,
100
- output_path=sub_action_output_path,
101
- run_base_dir=tctx.run_base_dir,
95
+ # Make sure the output path exists
96
+ pathlib.Path(sub_action_output_path).mkdir(parents=True, exist_ok=True)
97
+ pathlib.Path(sub_action_raw_data_path.path).mkdir(parents=True, exist_ok=True)
98
+
99
+ task_cache = cache_from_request(_task.cache)
100
+ cache_enabled = task_cache.is_enabled()
101
+ if isinstance(_task, AsyncFunctionTaskTemplate):
102
+ version_parameters = VersionParameters(func=_task.func, image=_task.image)
103
+ else:
104
+ version_parameters = VersionParameters(func=None, image=_task.image)
105
+ cache_version = task_cache.get_version(version_parameters)
106
+ cache_key = convert.generate_cache_key_hash(
107
+ _task.name,
108
+ inputs_hash,
109
+ task_interface,
110
+ cache_version,
111
+ list(task_cache.get_ignored_inputs()),
112
+ inputs.proto_inputs,
102
113
  )
103
- if err:
104
- exc = convert.convert_error_to_native(err)
105
- if exc:
106
- raise exc
107
- else:
108
- raise flyte.errors.RuntimeSystemError("BadError", "Unknown error")
114
+
115
+ out = None
116
+ # We only get output from cache if the cache behavior is set to auto
117
+ if task_cache.behavior == "auto":
118
+ out = await LocalTaskCache.get(cache_key)
119
+ if out is not None:
120
+ logger.info(
121
+ f"Cache hit for task '{_task.name}' (version: {cache_version}), getting result from cache..."
122
+ )
123
+
124
+ if out is None:
125
+ out, err = await direct_dispatch(
126
+ _task,
127
+ controller=self,
128
+ action=sub_action_id,
129
+ raw_data_path=sub_action_raw_data_path,
130
+ inputs=inputs,
131
+ version=cache_version,
132
+ checkpoints=tctx.checkpoints,
133
+ code_bundle=tctx.code_bundle,
134
+ output_path=sub_action_output_path,
135
+ run_base_dir=tctx.run_base_dir,
136
+ )
137
+
138
+ if err:
139
+ exc = convert.convert_error_to_native(err)
140
+ if exc:
141
+ raise exc
142
+ else:
143
+ raise flyte.errors.RuntimeSystemError("BadError", "Unknown error")
144
+
145
+ # store into cache
146
+ if cache_enabled and out is not None:
147
+ await LocalTaskCache.set(cache_key, out)
148
+
109
149
  if _task.native_interface.outputs:
110
150
  if out is None:
111
151
  raise flyte.errors.RuntimeSystemError("BadOutput", "Task output not captured.")
@@ -129,7 +169,7 @@ class LocalController:
129
169
  pass
130
170
 
131
171
  async def stop(self):
132
- pass
172
+ await LocalTaskCache.close()
133
173
 
134
174
  async def watch_for_errors(self):
135
175
  pass
@@ -151,11 +191,11 @@ class LocalController:
151
191
  converted_inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
152
192
  assert converted_inputs
153
193
 
154
- serialized_inputs = converted_inputs.proto_inputs.SerializeToString(deterministic=True)
194
+ inputs_hash = convert.generate_inputs_hash_from_proto(converted_inputs.proto_inputs)
155
195
  action_id, action_output_path = convert.generate_sub_action_id_and_output_path(
156
196
  tctx,
157
197
  _func.__name__,
158
- serialized_inputs,
198
+ inputs_hash,
159
199
  0,
160
200
  )
161
201
  assert action_output_path
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import Literal
4
+ from typing import Literal, Optional
5
5
 
6
6
  from flyteidl.core import execution_pb2, interface_pb2
7
7
  from google.protobuf import timestamp_pb2
@@ -39,6 +39,7 @@ class Action:
39
39
  phase: run_definition_pb2.Phase | None = None
40
40
  started: bool = False
41
41
  retries: int = 0
42
+ queue: Optional[str] = None # The queue to which this action was submitted.
42
43
  client_err: Exception | None = None # This error is set when something goes wrong in the controller.
43
44
  cache_key: str | None = None # None means no caching, otherwise it is the version of the cache.
44
45
 
@@ -122,6 +123,7 @@ class Action:
122
123
  inputs_uri: str,
123
124
  run_output_base: str,
124
125
  cache_key: str | None = None,
126
+ queue: Optional[str] = None,
125
127
  ) -> Action:
126
128
  return cls(
127
129
  action_id=sub_action_id,
@@ -132,6 +134,7 @@ class Action:
132
134
  inputs_uri=inputs_uri,
133
135
  run_output_base=run_output_base,
134
136
  cache_key=cache_key,
137
+ queue=queue,
135
138
  )
136
139
 
137
140
  @classmethod
@@ -126,7 +126,7 @@ class RemoteController(Controller):
126
126
  workers=workers,
127
127
  max_system_retries=max_system_retries,
128
128
  )
129
- default_parent_concurrency = int(os.getenv("_F_P_CNC", "100"))
129
+ default_parent_concurrency = int(os.getenv("_F_P_CNC", "1000"))
130
130
  self._default_parent_concurrency = default_parent_concurrency
131
131
  self._parent_action_semaphore: DefaultDict[str, asyncio.Semaphore] = defaultdict(
132
132
  lambda: asyncio.Semaphore(default_parent_concurrency)
@@ -238,6 +238,7 @@ class RemoteController(Controller):
238
238
  inputs_uri=inputs_uri,
239
239
  run_output_base=tctx.run_base_dir,
240
240
  cache_key=cache_key,
241
+ queue=_task.queue,
241
242
  )
242
243
 
243
244
  try:
@@ -377,9 +378,10 @@ class RemoteController(Controller):
377
378
  invoke_seq_num = self.generate_task_call_sequence(_func, current_action_id)
378
379
  inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
379
380
  serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
381
+ inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)
380
382
 
381
383
  sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
382
- tctx, func_name, serialized_inputs, invoke_seq_num
384
+ tctx, func_name, inputs_hash, invoke_seq_num
383
385
  )
384
386
 
385
387
  inputs_uri = io.inputs_path(sub_action_output_path)
@@ -539,6 +541,7 @@ class RemoteController(Controller):
539
541
  inputs_uri=inputs_uri,
540
542
  run_output_base=tctx.run_base_dir,
541
543
  cache_key=cache_key,
544
+ queue=None,
542
545
  )
543
546
 
544
547
  try:
@@ -118,13 +118,14 @@ class Controller:
118
118
  raise RuntimeError("Failure event not initialized")
119
119
  self._failure_event.set()
120
120
  except asyncio.CancelledError:
121
- pass
121
+ raise
122
122
 
123
123
  async def _bg_watch_for_errors(self):
124
124
  if self._failure_event is None:
125
125
  raise RuntimeError("Failure event not initialized")
126
126
  await self._failure_event.wait()
127
127
  logger.warning(f"Failure event received: {self._failure_event}, cleaning up informers and exiting.")
128
+ self._running = False
128
129
 
129
130
  async def watch_for_errors(self):
130
131
  """Watch for errors in the background thread"""
@@ -351,6 +352,7 @@ class Controller:
351
352
  ),
352
353
  spec=action.task,
353
354
  cache_key=cache_key,
355
+ cluster=action.queue,
354
356
  )
355
357
  elif action.type == "trace":
356
358
  trace = action.trace
@@ -440,10 +442,11 @@ class Controller:
440
442
  logger.warning(f"[{worker_id}] Retrying action {action.name} after backoff")
441
443
  await self._shared_queue.put(action)
442
444
  except Exception as e:
443
- logger.error(f"[{worker_id}] Error in controller loop: {e}")
445
+ logger.error(f"[{worker_id}] Error in controller loop for {action.name}: {e}")
444
446
  err = flyte.errors.RuntimeSystemError(
445
447
  code=type(e).__name__,
446
- message=f"Controller failed, system retries {action.retries} crossed threshold {self._max_retries}",
448
+ message=f"Controller failed, system retries {action.retries} / {self._max_retries} "
449
+ f"crossed threshold, for action {action.name}: {e}",
447
450
  worker=worker_id,
448
451
  )
449
452
  err.__cause__ = e
@@ -270,7 +270,7 @@ class Informer:
270
270
  logger.warning("Informer already running")
271
271
  return cast(asyncio.Task, self._watch_task)
272
272
  self._running = True
273
- self._watch_task = asyncio.create_task(self.watch())
273
+ self._watch_task = asyncio.create_task(self.watch(), name=f"InformerWatch-{self.parent_action_name}")
274
274
  await self.wait_for_cache_sync(timeout=timeout)
275
275
  return self._watch_task
276
276
 
@@ -22,6 +22,7 @@ from flyte._image import (
22
22
  Layer,
23
23
  PipOption,
24
24
  PipPackages,
25
+ PoetryProject,
25
26
  PythonWheels,
26
27
  Requirements,
27
28
  UVProject,
@@ -46,44 +47,71 @@ FLYTE_DOCKER_BUILDER_CACHE_TO = "FLYTE_DOCKER_BUILDER_CACHE_TO"
46
47
 
47
48
  UV_LOCK_WITHOUT_PROJECT_INSTALL_TEMPLATE = Template("""\
48
49
  RUN --mount=type=cache,sharing=locked,mode=0777,target=/root/.cache/uv,id=uv \
49
- --mount=type=bind,target=uv.lock,src=$UV_LOCK_PATH \
50
- --mount=type=bind,target=pyproject.toml,src=$PYPROJECT_PATH \
51
- $SECRET_MOUNT \
52
- uv sync --active $PIP_INSTALL_ARGS
50
+ --mount=type=bind,target=uv.lock,src=$UV_LOCK_PATH \
51
+ --mount=type=bind,target=pyproject.toml,src=$PYPROJECT_PATH \
52
+ $SECRET_MOUNT \
53
+ uv sync --active --inexact $PIP_INSTALL_ARGS
53
54
  """)
54
55
 
55
56
  UV_LOCK_INSTALL_TEMPLATE = Template("""\
56
- COPY $PYPROJECT_PATH $PYPROJECT_PATH
57
57
  RUN --mount=type=cache,sharing=locked,mode=0777,target=/root/.cache/uv,id=uv \
58
- $SECRET_MOUNT \
59
- uv sync --active $PIP_INSTALL_ARGS --project $PYPROJECT_PATH
58
+ --mount=type=bind,target=/root/.flyte/$PYPROJECT_PATH,src=$PYPROJECT_PATH,rw \
59
+ $SECRET_MOUNT \
60
+ uv sync --active --inexact --no-editable $PIP_INSTALL_ARGS --project /root/.flyte/$PYPROJECT_PATH
61
+ """)
62
+
63
+ POETRY_LOCK_WITHOUT_PROJECT_INSTALL_TEMPLATE = Template("""\
64
+ RUN --mount=type=cache,sharing=locked,mode=0777,target=/root/.cache/uv,id=uv \
65
+ uv pip install poetry
66
+
67
+ ENV POETRY_CACHE_DIR=/tmp/poetry_cache \
68
+ POETRY_VIRTUALENVS_IN_PROJECT=true
69
+
70
+ RUN --mount=type=cache,sharing=locked,mode=0777,target=/tmp/poetry_cache,id=poetry \
71
+ --mount=type=bind,target=poetry.lock,src=$POETRY_LOCK_PATH \
72
+ --mount=type=bind,target=pyproject.toml,src=$PYPROJECT_PATH \
73
+ $SECRET_MOUNT \
74
+ poetry install $POETRY_INSTALL_ARGS
75
+ """)
76
+
77
+ POETRY_LOCK_INSTALL_TEMPLATE = Template("""\
78
+ RUN --mount=type=cache,sharing=locked,mode=0777,target=/root/.cache/uv,id=uv \
79
+ uv pip install poetry
80
+
81
+ ENV POETRY_CACHE_DIR=/tmp/poetry_cache \
82
+ POETRY_VIRTUALENVS_IN_PROJECT=true
83
+
84
+ RUN --mount=type=cache,sharing=locked,mode=0777,target=/tmp/poetry_cache,id=poetry \
85
+ --mount=type=bind,target=/root/.flyte/$PYPROJECT_PATH,src=$PYPROJECT_PATH,rw \
86
+ $SECRET_MOUNT \
87
+ poetry install $POETRY_INSTALL_ARGS
60
88
  """)
61
89
 
62
90
  UV_PACKAGE_INSTALL_COMMAND_TEMPLATE = Template("""\
63
91
  RUN --mount=type=cache,sharing=locked,mode=0777,target=/root/.cache/uv,id=uv \
64
- $REQUIREMENTS_MOUNT \
65
- $SECRET_MOUNT \
66
- uv pip install --python $$UV_PYTHON $PIP_INSTALL_ARGS
92
+ $REQUIREMENTS_MOUNT \
93
+ $SECRET_MOUNT \
94
+ uv pip install --python $$UV_PYTHON $PIP_INSTALL_ARGS
67
95
  """)
68
96
 
69
97
  UV_WHEEL_INSTALL_COMMAND_TEMPLATE = Template("""\
70
98
  RUN --mount=type=cache,sharing=locked,mode=0777,target=/root/.cache/uv,id=wheel \
71
- --mount=source=/dist,target=/dist,type=bind \
72
- $SECRET_MOUNT \
73
- uv pip install --python $$UV_PYTHON $PIP_INSTALL_ARGS
99
+ --mount=source=/dist,target=/dist,type=bind \
100
+ $SECRET_MOUNT \
101
+ uv pip install --python $$UV_PYTHON $PIP_INSTALL_ARGS
74
102
  """)
75
103
 
76
104
  APT_INSTALL_COMMAND_TEMPLATE = Template("""\
77
105
  RUN --mount=type=cache,sharing=locked,mode=0777,target=/var/cache/apt,id=apt \
78
- $SECRET_MOUNT \
79
- apt-get update && apt-get install -y --no-install-recommends \
80
- $APT_PACKAGES
106
+ $SECRET_MOUNT \
107
+ apt-get update && apt-get install -y --no-install-recommends \
108
+ $APT_PACKAGES
81
109
  """)
82
110
 
83
111
  UV_PYTHON_INSTALL_COMMAND = Template("""\
84
112
  RUN --mount=type=cache,sharing=locked,mode=0777,target=/root/.cache/uv,id=uv \
85
- $SECRET_MOUNT \
86
- uv pip install $PIP_INSTALL_ARGS
113
+ $SECRET_MOUNT \
114
+ uv pip install $PIP_INSTALL_ARGS
87
115
  """)
88
116
 
89
117
  # uv pip install --python /root/env/bin/python
@@ -93,24 +121,29 @@ DOCKER_FILE_UV_BASE_TEMPLATE = Template("""\
93
121
  FROM ghcr.io/astral-sh/uv:0.8.13 AS uv
94
122
  FROM $BASE_IMAGE
95
123
 
124
+
96
125
  USER root
97
126
 
127
+
98
128
  # Copy in uv so that later commands don't have to mount it in
99
129
  COPY --from=uv /uv /usr/bin/uv
100
130
 
131
+
101
132
  # Configure default envs
102
133
  ENV UV_COMPILE_BYTECODE=1 \
103
- UV_LINK_MODE=copy \
104
- VIRTUALENV=/opt/venv \
105
- UV_PYTHON=/opt/venv/bin/python \
106
- PATH="/opt/venv/bin:$$PATH"
134
+ UV_LINK_MODE=copy \
135
+ VIRTUALENV=/opt/venv \
136
+ UV_PYTHON=/opt/venv/bin/python \
137
+ PATH="/opt/venv/bin:$$PATH"
138
+
107
139
 
108
140
  # Create a virtualenv with the user specified python version
109
141
  RUN uv venv $$VIRTUALENV --python=$PYTHON_VERSION
110
142
 
143
+
111
144
  # Adds nvidia just in case it exists
112
145
  ENV PATH="$$PATH:/usr/local/nvidia/bin:/usr/local/cuda/bin" \
113
- LD_LIBRARY_PATH="/usr/local/nvidia/lib64"
146
+ LD_LIBRARY_PATH="/usr/local/nvidia/lib64"
114
147
  """)
115
148
 
116
149
  # This gets added on to the end of the dockerfile
@@ -258,6 +291,32 @@ class UVProjectHandler:
258
291
  return dockerfile
259
292
 
260
293
 
294
+ class PoetryProjectHandler:
295
+ @staticmethod
296
+ async def handel(layer: PoetryProject, context_path: Path, dockerfile: str) -> str:
297
+ secret_mounts = _get_secret_mounts_layer(layer.secret_mounts)
298
+ if layer.extra_args and "--no-root" in layer.extra_args:
299
+ # Only Copy pyproject.yaml and poetry.lock.
300
+ pyproject_dst = copy_files_to_context(layer.pyproject, context_path)
301
+ poetry_lock_dst = copy_files_to_context(layer.poetry_lock, context_path)
302
+ delta = POETRY_LOCK_WITHOUT_PROJECT_INSTALL_TEMPLATE.substitute(
303
+ POETRY_LOCK_PATH=poetry_lock_dst.relative_to(context_path),
304
+ PYPROJECT_PATH=pyproject_dst.relative_to(context_path),
305
+ POETRY_INSTALL_ARGS=layer.extra_args or "",
306
+ SECRET_MOUNT=secret_mounts,
307
+ )
308
+ else:
309
+ # Copy the entire project.
310
+ pyproject_dst = copy_files_to_context(layer.pyproject.parent, context_path)
311
+ delta = POETRY_LOCK_INSTALL_TEMPLATE.substitute(
312
+ PYPROJECT_PATH=pyproject_dst.relative_to(context_path),
313
+ POETRY_INSTALL_ARGS=layer.extra_args or "",
314
+ SECRET_MOUNT=secret_mounts,
315
+ )
316
+ dockerfile += delta
317
+ return dockerfile
318
+
319
+
261
320
  class DockerIgnoreHandler:
262
321
  @staticmethod
263
322
  async def handle(layer: DockerIgnore, context_path: Path, _: str):
@@ -332,8 +391,9 @@ class CommandsHandler:
332
391
  @staticmethod
333
392
  async def handle(layer: Commands, _: Path, dockerfile: str) -> str:
334
393
  # Append raw commands to the dockerfile
394
+ secret_mounts = _get_secret_mounts_layer(layer.secret_mounts)
335
395
  for command in layer.commands:
336
- dockerfile += f"\nRUN {command}\n"
396
+ dockerfile += f"\nRUN {secret_mounts} {command}\n"
337
397
 
338
398
  return dockerfile
339
399
 
@@ -355,9 +415,8 @@ def _get_secret_commands(layers: typing.Tuple[Layer, ...]) -> typing.List[str]:
355
415
  secret = Secret(key=secret)
356
416
  secret_id = hash(secret)
357
417
  secret_env_key = "_".join([k.upper() for k in filter(None, (secret.group, secret.key))])
358
- secret_env = os.getenv(secret_env_key)
359
- if secret_env:
360
- return ["--secret", f"id={secret_id},env={secret_env}"]
418
+ if os.getenv(secret_env_key):
419
+ return ["--secret", f"id={secret_id},env={secret_env_key}"]
361
420
  secret_file_name = "_".join(list(filter(None, (secret.group, secret.key))))
362
421
  secret_file_path = f"/etc/secrets/{secret_file_name}"
363
422
  if not os.path.exists(secret_file_path):
@@ -365,7 +424,7 @@ def _get_secret_commands(layers: typing.Tuple[Layer, ...]) -> typing.List[str]:
365
424
  return ["--secret", f"id={secret_id},src={secret_file_path}"]
366
425
 
367
426
  for layer in layers:
368
- if isinstance(layer, (PipOption, AptPackages)):
427
+ if isinstance(layer, (PipOption, AptPackages, Commands)):
369
428
  if layer.secret_mounts:
370
429
  for secret_mount in layer.secret_mounts:
371
430
  commands.extend(_get_secret_command(secret_mount))
@@ -426,6 +485,10 @@ async def _process_layer(
426
485
  # Handle UV project
427
486
  dockerfile = await UVProjectHandler.handle(layer, context_path, dockerfile)
428
487
 
488
+ case PoetryProject():
489
+ # Handle Poetry project
490
+ dockerfile = await PoetryProjectHandler.handel(layer, context_path, dockerfile)
491
+
429
492
  case CopyConfig():
430
493
  # Handle local files and folders
431
494
  dockerfile = await CopyConfigHandler.handle(layer, context_path, dockerfile, docker_ignore_file_path)
@@ -572,6 +635,7 @@ class DockerImageBuilder(ImageBuilder):
572
635
  - start from the base image
573
636
  - use python to create a default venv and export variables
574
637
 
638
+
575
639
  Then for the layers
576
640
  - for each layer
577
641
  - find the appropriate layer handler
@@ -135,11 +135,6 @@ class ImageBuildEngine:
135
135
 
136
136
  ImageBuilderType = typing.Literal["local", "remote"]
137
137
 
138
- _SEEN_IMAGES: typing.ClassVar[typing.Dict[str, str]] = {
139
- # Set default for the auto container. See Image._identifier_override for more info.
140
- "auto": Image.from_debian_base().uri,
141
- }
142
-
143
138
  @staticmethod
144
139
  @alru_cache
145
140
  async def image_exists(image: Image) -> Optional[str]:
@@ -235,7 +230,7 @@ class ImageBuildEngine:
235
230
 
236
231
 
237
232
  class ImageCache(BaseModel):
238
- image_lookup: Dict[str, Dict[str, str]]
233
+ image_lookup: Dict[str, str]
239
234
  serialized_form: str | None = None
240
235
 
241
236
  @property
@@ -273,11 +268,10 @@ class ImageCache(BaseModel):
273
268
  """
274
269
  tuples = []
275
270
  for k, v in self.image_lookup.items():
276
- for py_version, image_uri in v.items():
277
- tuples.append(
278
- [
279
- ("Name", f"{k} (py{py_version})"),
280
- ("image", image_uri),
281
- ]
282
- )
271
+ tuples.append(
272
+ [
273
+ ("Name", k),
274
+ ("image", v),
275
+ ]
276
+ )
283
277
  return tuples
@@ -182,6 +182,11 @@ async def _validate_configuration(image: Image) -> Tuple[str, Optional[str]]:
182
182
  def _get_layers_proto(image: Image, context_path: Path) -> "image_definition_pb2.ImageSpec":
183
183
  from flyte._protos.imagebuilder import definition_pb2 as image_definition_pb2
184
184
 
185
+ if image.dockerfile is not None:
186
+ raise flyte.errors.ImageBuildError(
187
+ "Custom Dockerfile is not supported with remote image builder.You can use local image builder instead."
188
+ )
189
+
185
190
  layers = []
186
191
  for layer in image._layers:
187
192
  secret_mounts = None
@@ -251,7 +256,7 @@ def _get_layers_proto(image: Image, context_path: Path) -> "image_definition_pb2
251
256
  if "tool.uv.index" in line:
252
257
  raise ValueError("External sources are not supported in pyproject.toml")
253
258
 
254
- if layer.extra_index_urls and "--no-install-project" in layer.extra_index_urls:
259
+ if layer.extra_args and "--no-install-project" in layer.extra_args:
255
260
  # Copy pyproject itself
256
261
  pyproject_dst = copy_files_to_context(layer.pyproject, context_path)
257
262
  else:
@@ -9,6 +9,7 @@ from flyteidl.core import errors_pb2, execution_pb2
9
9
 
10
10
  import flyte.storage as storage
11
11
  from flyte._protos.workflow import run_definition_pb2
12
+ from flyte.models import PathRewrite
12
13
 
13
14
  from .convert import Inputs, Outputs, _clean_error_code
14
15
 
@@ -90,10 +91,11 @@ async def upload_error(err: execution_pb2.ExecutionError, output_prefix: str):
90
91
 
91
92
 
92
93
  # ------------------------------- DOWNLOAD Methods ------------------------------- #
93
- async def load_inputs(path: str, max_bytes: int = -1) -> Inputs:
94
+ async def load_inputs(path: str, max_bytes: int = -1, path_rewrite_config: PathRewrite | None = None) -> Inputs:
94
95
  """
95
96
  :param path: Input file to be downloaded
96
97
  :param max_bytes: Maximum number of bytes to read from the input file. Default is -1, which means no limit.
98
+ :param path_rewrite_config: If provided, rewrites paths in the input blobs according to the configuration.
97
99
  :return: Inputs object
98
100
  """
99
101
  lm = run_definition_pb2.Inputs()
@@ -115,6 +117,16 @@ async def load_inputs(path: str, max_bytes: int = -1) -> Inputs:
115
117
  proto_str = b"".join(proto_bytes)
116
118
 
117
119
  lm.ParseFromString(proto_str)
120
+
121
+ if path_rewrite_config is not None:
122
+ for inp in lm.literals:
123
+ if inp.value.HasField("scalar") and inp.value.scalar.HasField("blob"):
124
+ scalar_blob = inp.value.scalar.blob
125
+ if scalar_blob.uri.startswith(path_rewrite_config.old_prefix):
126
+ scalar_blob.uri = scalar_blob.uri.replace(
127
+ path_rewrite_config.old_prefix, path_rewrite_config.new_prefix, 1
128
+ )
129
+
118
130
  return Inputs(proto_inputs=lm)
119
131
 
120
132
 
@@ -11,7 +11,7 @@ from flyte._internal.runtime.entrypoints import download_code_bundle, load_pkl_t
11
11
  from flyte._internal.runtime.taskrunner import extract_download_run_upload
12
12
  from flyte._logging import logger
13
13
  from flyte._task import TaskTemplate
14
- from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath
14
+ from flyte.models import ActionID, Checkpoints, CodeBundle, PathRewrite, RawDataPath
15
15
 
16
16
 
17
17
  async def download_tgz(destination: str, version: str, tgz: str) -> CodeBundle:
@@ -115,6 +115,7 @@ async def run_task(
115
115
  prev_checkpoint: str | None = None,
116
116
  code_bundle: CodeBundle | None = None,
117
117
  input_path: str | None = None,
118
+ path_rewrite_cfg: str | None = None,
118
119
  ):
119
120
  """
120
121
  Runs the task with the provided parameters.
@@ -134,6 +135,7 @@ async def run_task(
134
135
  :param controller: The controller to use for the task.
135
136
  :param code_bundle: Optional code bundle for the task.
136
137
  :param input_path: Optional input path for the task.
138
+ :param path_rewrite_cfg: Optional path rewrite configuration.
137
139
  :return: The loaded task template.
138
140
  """
139
141
  start_time = time.time()
@@ -144,6 +146,19 @@ async def run_task(
144
146
  f" at {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))}"
145
147
  )
146
148
 
149
+ path_rewrite = PathRewrite.from_str(path_rewrite_cfg) if path_rewrite_cfg else None
150
+ if path_rewrite:
151
+ import flyte.storage as storage
152
+
153
+ if not await storage.exists(path_rewrite.new_prefix):
154
+ logger.error(
155
+ f"[rusty] Path rewrite failed for path {path_rewrite.new_prefix}, "
156
+ f"not found, reverting to original path {path_rewrite.old_prefix}"
157
+ )
158
+ path_rewrite = None
159
+ else:
160
+ logger.info(f"[rusty] Using path rewrite: {path_rewrite}")
161
+
147
162
  try:
148
163
  await contextual_run(
149
164
  extract_download_run_upload,
@@ -151,7 +166,7 @@ async def run_task(
151
166
  action=ActionID(name=name, org=org, project=project, domain=domain, run_name=run_name),
152
167
  version=version,
153
168
  controller=controller,
154
- raw_data_path=RawDataPath(path=raw_data_path),
169
+ raw_data_path=RawDataPath(path=raw_data_path, path_rewrite=path_rewrite),
155
170
  output_path=output_path,
156
171
  run_base_dir=run_base_dir,
157
172
  checkpoints=Checkpoints(prev_checkpoint_path=prev_checkpoint, checkpoint_path=checkpoint_path),