flyte 0.2.0b9__py3-none-any.whl → 0.2.0b11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (41) hide show
  1. flyte/__init__.py +4 -2
  2. flyte/_bin/runtime.py +6 -3
  3. flyte/_deploy.py +3 -0
  4. flyte/_initialize.py +30 -6
  5. flyte/_internal/controllers/_local_controller.py +4 -3
  6. flyte/_internal/controllers/_trace.py +1 -0
  7. flyte/_internal/controllers/remote/_action.py +1 -1
  8. flyte/_internal/controllers/remote/_informer.py +1 -1
  9. flyte/_internal/runtime/convert.py +7 -4
  10. flyte/_internal/runtime/task_serde.py +80 -10
  11. flyte/_internal/runtime/taskrunner.py +1 -1
  12. flyte/_logging.py +1 -1
  13. flyte/_pod.py +19 -0
  14. flyte/_run.py +84 -39
  15. flyte/_task.py +2 -13
  16. flyte/_utils/org_discovery.py +31 -0
  17. flyte/_version.py +2 -2
  18. flyte/cli/_common.py +6 -6
  19. flyte/cli/_create.py +16 -8
  20. flyte/cli/_params.py +2 -2
  21. flyte/cli/_run.py +1 -1
  22. flyte/cli/main.py +4 -8
  23. flyte/errors.py +11 -0
  24. flyte/extras/_container.py +29 -39
  25. flyte/io/__init__.py +17 -1
  26. flyte/io/_file.py +2 -0
  27. flyte/io/{structured_dataset → _structured_dataset}/basic_dfs.py +1 -1
  28. flyte/io/{structured_dataset → _structured_dataset}/structured_dataset.py +1 -1
  29. flyte/models.py +1 -0
  30. flyte/remote/_data.py +2 -1
  31. flyte/types/__init__.py +23 -0
  32. flyte/{io/pickle/transformer.py → types/_pickle.py} +2 -1
  33. flyte/types/_type_engine.py +7 -5
  34. {flyte-0.2.0b9.dist-info → flyte-0.2.0b11.dist-info}/METADATA +5 -6
  35. {flyte-0.2.0b9.dist-info → flyte-0.2.0b11.dist-info}/RECORD +39 -39
  36. flyte/io/_dataframe.py +0 -0
  37. flyte/io/pickle/__init__.py +0 -0
  38. /flyte/io/{structured_dataset → _structured_dataset}/__init__.py +0 -0
  39. {flyte-0.2.0b9.dist-info → flyte-0.2.0b11.dist-info}/WHEEL +0 -0
  40. {flyte-0.2.0b9.dist-info → flyte-0.2.0b11.dist-info}/entry_points.txt +0 -0
  41. {flyte-0.2.0b9.dist-info → flyte-0.2.0b11.dist-info}/top_level.txt +0 -0
flyte/__init__.py CHANGED
@@ -25,6 +25,7 @@ __all__ = [
25
25
  "Device",
26
26
  "Environment",
27
27
  "Image",
28
+ "PodTemplate",
28
29
  "Resources",
29
30
  "RetryStrategy",
30
31
  "ReusePolicy",
@@ -38,7 +39,7 @@ __all__ = [
38
39
  "deploy",
39
40
  "group",
40
41
  "init",
41
- "init_auto_from_config",
42
+ "init_from_config",
42
43
  "map",
43
44
  "run",
44
45
  "trace",
@@ -51,8 +52,9 @@ from ._deploy import deploy
51
52
  from ._environment import Environment
52
53
  from ._group import group
53
54
  from ._image import Image
54
- from ._initialize import init, init_auto_from_config
55
+ from ._initialize import init, init_from_config
55
56
  from ._map import map
57
+ from ._pod import PodTemplate
56
58
  from ._resources import GPU, TPU, Device, Resources
57
59
  from ._retry import RetryStrategy
58
60
  from ._reusable_environment import ReusePolicy
flyte/_bin/runtime.py CHANGED
@@ -76,13 +76,17 @@ def main(
76
76
  ):
77
77
  sys.path.insert(0, ".")
78
78
 
79
+ import flyte
79
80
  import flyte._utils as utils
80
81
  from flyte._initialize import initialize_in_cluster
81
82
  from flyte._internal.controllers import create_controller
82
83
  from flyte._internal.imagebuild.image_builder import ImageCache
83
84
  from flyte._internal.runtime.entrypoints import load_and_run_task
85
+ from flyte._logging import logger
84
86
  from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath
85
87
 
88
+ logger.info(f"Initializing flyte runtime - version {flyte.__version__}")
89
+
86
90
  assert org, "Org is required for now"
87
91
  assert project, "Project is required"
88
92
  assert domain, "Domain is required"
@@ -98,15 +102,14 @@ def main(
98
102
  # This detection of api key is a hack for now.
99
103
  controller_kwargs: dict[str, Any] = {"insecure": False}
100
104
  if api_key := os.getenv(_UNION_EAGER_API_KEY_ENV_VAR):
101
- from flyte._logging import logger
102
-
103
- logger.warning(f"Using api key {api_key}")
105
+ logger.info("Using api key from environment")
104
106
  controller_kwargs["api_key"] = api_key
105
107
  else:
106
108
  ep = os.environ.get(ENDPOINT_OVERRIDE, "host.docker.internal:8090")
107
109
  controller_kwargs["endpoint"] = ep
108
110
  if "localhost" in ep or "docker" in ep:
109
111
  controller_kwargs["insecure"] = True
112
+ logger.debug(f"Using controller endpoint: {ep} with kwargs: {controller_kwargs}")
110
113
 
111
114
  bundle = CodeBundle(tgz=tgz, pkl=pkl, destination=dest, computed_version=version)
112
115
  initialize_in_cluster()
flyte/_deploy.py CHANGED
@@ -128,6 +128,9 @@ async def apply(deployment: DeploymentPlan, copy_style: CopyFiles, dryrun: bool
128
128
  else:
129
129
  code_bundle = await build_code_bundle(from_dir=cfg.root_dir, dryrun=dryrun, copy_style=copy_style)
130
130
  deployment.version = code_bundle.computed_version
131
+ # TODO we should update the version to include the image cache digest and code bundle digest. This is
132
+ # to ensure that changes in image dependencies, cause an update to the deployment version.
133
+ # TODO Also hash the environment and tasks to ensure that changes in the environment or tasks
131
134
 
132
135
  sc = SerializationContext(
133
136
  project=cfg.project,
flyte/_initialize.py CHANGED
@@ -138,14 +138,13 @@ async def init(
138
138
  :param project: Optional project name (not used in this implementation)
139
139
  :param domain: Optional domain name (not used in this implementation)
140
140
  :param root_dir: Optional root directory from which to determine how to load files, and find paths to files.
141
+ This is useful for determining the root directory for the current project, and for locating files like config etc.
142
+ also use to determine all the code that needs to be copied to the remote location.
141
143
  defaults to the editable install directory if the cwd is in a Python editable install, else just the cwd.
142
144
  :param log_level: Optional logging level for the logger, default is set using the default initialization policies
143
145
  :param api_key: Optional API key for authentication
144
146
  :param endpoint: Optional API endpoint URL
145
147
  :param headless: Optional Whether to run in headless mode
146
- :param mode: Optional execution model (local, remote). Default is local. When local is used,
147
- the execution will be done locally. When remote is used, the execution will be sent to a remote server,
148
- In the remote case, the endpoint or api_key must be set.
149
148
  :param insecure_skip_verify: Whether to skip SSL certificate verification
150
149
  :param auth_client_config: Optional client configuration for authentication
151
150
  :param auth_type: The authentication type to use (Pkce, ClientSecret, ExternalCommand, DeviceFlow)
@@ -169,6 +168,8 @@ async def init(
169
168
  """
170
169
  from flyte._utils import get_cwd_editable_install
171
170
 
171
+ from ._utils.org_discovery import org_from_endpoint
172
+
172
173
  interactive_mode = ipython_check()
173
174
 
174
175
  initialize_logger(enable_rich=interactive_mode)
@@ -177,6 +178,9 @@ async def init(
177
178
 
178
179
  global _init_config # noqa: PLW0603
179
180
 
181
+ if endpoint and "://" not in endpoint:
182
+ endpoint = f"dns:///{endpoint}"
183
+
180
184
  with _init_lock:
181
185
  client = None
182
186
  if endpoint or api_key:
@@ -204,17 +208,25 @@ async def init(
204
208
  domain=domain,
205
209
  client=client,
206
210
  storage=storage,
207
- org=org,
211
+ org=org or org_from_endpoint(endpoint),
208
212
  )
209
213
 
210
214
 
211
215
  @syncify
212
- async def init_auto_from_config(path_or_config: str | Config | None = None) -> None:
216
+ async def init_from_config(
217
+ path_or_config: str | Config | None = None, root_dir: Path | None = None, log_level: int | None = None
218
+ ) -> None:
213
219
  """
214
220
  Initialize the Flyte system using a configuration file or Config object. This method should be called before any
215
221
  other Flyte remote API methods are called. Thread-safe implementation.
216
222
 
217
223
  :param path_or_config: Path to the configuration file or Config object
224
+ :param root_dir: Optional root directory from which to determine how to load files, and find paths to
225
+ files like config etc. For example if one uses the copy-style=="all", it is essential to determine the
226
+ root directory for the current project. If not provided, it defaults to the editable install directory or
227
+ if not available, the current working directory.
228
+ :param log_level: Optional logging level for the framework logger,
229
+ default is set using the default initialization policies
218
230
  :return: None
219
231
  """
220
232
  import flyte.config as config
@@ -222,7 +234,17 @@ async def init_auto_from_config(path_or_config: str | Config | None = None) -> N
222
234
  cfg: config.Config
223
235
  if path_or_config is None or isinstance(path_or_config, str):
224
236
  # If a string is passed, treat it as a path to the config file
225
- cfg = config.auto(path_or_config)
237
+ if path_or_config:
238
+ if not Path(path_or_config).exists():
239
+ raise InitializationError(
240
+ "ConfigFileNotFoundError",
241
+ "user",
242
+ f"Configuration file '{path_or_config}' does not exist., current working directory is {Path.cwd()}",
243
+ )
244
+ if root_dir and path_or_config:
245
+ cfg = config.auto(str(root_dir / path_or_config))
246
+ else:
247
+ cfg = config.auto(path_or_config)
226
248
  else:
227
249
  # If a Config object is passed, use it directly
228
250
  cfg = path_or_config
@@ -241,6 +263,8 @@ async def init_auto_from_config(path_or_config: str | Config | None = None) -> N
241
263
  proxy_command=cfg.platform.proxy_command,
242
264
  client_id=cfg.platform.client_id,
243
265
  client_credentials_secret=cfg.platform.client_credentials_secret,
266
+ root_dir=root_dir,
267
+ log_level=log_level,
244
268
  )
245
269
 
246
270
 
@@ -14,7 +14,7 @@ from flyte._logging import log, logger
14
14
  from flyte._protos.workflow import task_definition_pb2
15
15
  from flyte._task import TaskTemplate
16
16
  from flyte._utils.helpers import _selector_policy
17
- from flyte.models import ActionID, NativeInterface, RawDataPath
17
+ from flyte.models import ActionID, NativeInterface
18
18
 
19
19
  R = TypeVar("R")
20
20
 
@@ -86,7 +86,7 @@ class LocalController:
86
86
  sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
87
87
  tctx, _task.name, serialized_inputs, 0
88
88
  )
89
- sub_action_raw_data_path = RawDataPath(path=sub_action_output_path)
89
+ sub_action_raw_data_path = tctx.raw_data_path
90
90
 
91
91
  out, err = await direct_dispatch(
92
92
  _task,
@@ -162,6 +162,7 @@ class LocalController:
162
162
  action=action_id,
163
163
  interface=_interface,
164
164
  inputs_path=action_output_path,
165
+ name=_func.__name__,
165
166
  ),
166
167
  True,
167
168
  )
@@ -179,7 +180,7 @@ class LocalController:
179
180
 
180
181
  if info.interface.outputs and info.output:
181
182
  # If the result is not an AsyncGenerator, convert it directly
182
- converted_outputs = await convert.convert_from_native_to_outputs(info.output, info.interface)
183
+ converted_outputs = await convert.convert_from_native_to_outputs(info.output, info.interface, info.name)
183
184
  assert converted_outputs
184
185
  elif info.error:
185
186
  # If there is an error, convert it to a native error
@@ -18,6 +18,7 @@ class TraceInfo:
18
18
  duration: Optional[timedelta] = None
19
19
  output: Optional[Any] = None
20
20
  error: Optional[Exception] = None
21
+ name: str = ""
21
22
 
22
23
  def add_outputs(self, output: Any, duration: timedelta):
23
24
  """
@@ -130,7 +130,7 @@ class Action:
130
130
  """
131
131
  from flyte._logging import logger
132
132
 
133
- logger.info(f"In Action from_state {obj.action_id} {obj.phase} {obj.output_uri}")
133
+ logger.debug(f"In Action from_state {obj.action_id} {obj.phase} {obj.output_uri}")
134
134
  return cls(
135
135
  action_id=obj.action_id,
136
136
  parent_action_name=parent_action_name,
@@ -235,7 +235,7 @@ class Informer:
235
235
  await self._shared_queue.put(node)
236
236
  # hack to work in the absence of sentinel
237
237
  except asyncio.CancelledError:
238
- logger.warning(f"Watch cancelled: {self.name}")
238
+ logger.info(f"Watch cancelled: {self.name}")
239
239
  return
240
240
  except asyncio.TimeoutError as e:
241
241
  logger.error(f"Watch timeout: {self.name}", exc_info=e)
@@ -11,7 +11,7 @@ import flyte.errors
11
11
  import flyte.storage as storage
12
12
  from flyte._protos.workflow import run_definition_pb2, task_definition_pb2
13
13
  from flyte.models import ActionID, NativeInterface, TaskContext
14
- from flyte.types import TypeEngine
14
+ from flyte.types import TypeEngine, TypeTransformerFailedError
15
15
 
16
16
 
17
17
  @dataclass(frozen=True)
@@ -80,7 +80,7 @@ async def convert_from_native_to_inputs(interface: NativeInterface, *args, **kwa
80
80
  )
81
81
 
82
82
 
83
- async def convert_from_native_to_outputs(o: Any, interface: NativeInterface) -> Outputs:
83
+ async def convert_from_native_to_outputs(o: Any, interface: NativeInterface, task_name: str = "") -> Outputs:
84
84
  # Always make it a tuple even if it's just one item to simplify logic below
85
85
  if not isinstance(o, tuple):
86
86
  o = (o,)
@@ -90,8 +90,11 @@ async def convert_from_native_to_outputs(o: Any, interface: NativeInterface) ->
90
90
  )
91
91
  named = []
92
92
  for (output_name, python_type), v in zip(interface.outputs.items(), o):
93
- lit = await TypeEngine.to_literal(v, python_type, TypeEngine.to_literal_type(python_type))
94
- named.append(run_definition_pb2.NamedLiteral(name=output_name, value=lit))
93
+ try:
94
+ lit = await TypeEngine.to_literal(v, python_type, TypeEngine.to_literal_type(python_type))
95
+ named.append(run_definition_pb2.NamedLiteral(name=output_name, value=lit))
96
+ except TypeTransformerFailedError as e:
97
+ raise flyte.errors.RuntimeDataValidationError(output_name, e, task_name)
95
98
 
96
99
  return Outputs(proto_outputs=run_definition_pb2.Outputs(literals=named))
97
100
 
@@ -3,9 +3,11 @@ This module provides functionality to serialize and deserialize tasks to and fro
3
3
  It includes a Resolver interface for loading tasks, and functions to load classes and tasks.
4
4
  """
5
5
 
6
+ import copy
6
7
  import importlib
8
+ import typing
7
9
  from datetime import timedelta
8
- from typing import Optional, Type
10
+ from typing import Optional, Type, cast
9
11
 
10
12
  from flyteidl.core import identifier_pb2, literals_pb2, security_pb2, tasks_pb2
11
13
  from google.protobuf import duration_pb2, wrappers_pb2
@@ -13,6 +15,7 @@ from google.protobuf import duration_pb2, wrappers_pb2
13
15
  import flyte.errors
14
16
  from flyte._cache.cache import VersionParameters, cache_from_request
15
17
  from flyte._logging import logger
18
+ from flyte._pod import _PRIMARY_CONTAINER_NAME_FIELD, PodTemplate
16
19
  from flyte._protos.workflow import task_definition_pb2
17
20
  from flyte._secret import SecretRequest, secrets_from_request
18
21
  from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
@@ -121,17 +124,18 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
121
124
  # if task.parent_env is None:
122
125
  # raise ValueError(f"Task {task.name} must have a parent environment")
123
126
 
124
- #
125
- # This pod will be incorrect when doing fast serialize
126
- #
127
- container = _get_urun_container(serialize_context, task)
127
+ # TODO Add support for SQL, extra_config, custom
128
+ extra_config: typing.Dict[str, str] = {}
129
+ custom = {} # type: ignore
128
130
 
129
- # TODO Add support for SQL, Pod, extra_config, custom
130
- pod = None
131
131
  sql = None
132
- # pod = task.get_k8s_pod(serialize_context)
133
- extra_config = {} # type: ignore
134
- custom = {} # type: ignore
132
+ if task.pod_template and not isinstance(task.pod_template, str):
133
+ container = None
134
+ pod = _get_k8s_pod(_get_urun_container(serialize_context, task), task.pod_template)
135
+ extra_config[_PRIMARY_CONTAINER_NAME_FIELD] = task.pod_template.primary_container_name
136
+ else:
137
+ container = _get_urun_container(serialize_context, task)
138
+ pod = None
135
139
 
136
140
  # -------------- CACHE HANDLING ----------------------
137
141
  task_cache = cache_from_request(task.cache)
@@ -210,6 +214,72 @@ def _get_urun_container(
210
214
  )
211
215
 
212
216
 
217
+ def _sanitize_resource_name(resource: tasks_pb2.Resources.ResourceEntry) -> str:
218
+ return tasks_pb2.Resources.ResourceName.Name(resource.name).lower().replace("_", "-")
219
+
220
+
221
+ def _get_k8s_pod(primary_container: tasks_pb2.Container, pod_template: PodTemplate) -> Optional[tasks_pb2.K8sPod]:
222
+ """
223
+ Get the K8sPod representation of the task template.
224
+ :param task: The task to convert.
225
+ :return: The K8sPod representation of the task template.
226
+ """
227
+ from kubernetes.client import ApiClient, V1PodSpec
228
+ from kubernetes.client.models import V1EnvVar, V1ResourceRequirements
229
+
230
+ pod_template = copy.deepcopy(pod_template)
231
+ containers = cast(V1PodSpec, pod_template.pod_spec).containers
232
+ primary_exists = False
233
+
234
+ for container in containers:
235
+ if container.name == pod_template.primary_container_name:
236
+ primary_exists = True
237
+ break
238
+
239
+ if not primary_exists:
240
+ raise ValueError(
241
+ "No primary container defined in the pod spec."
242
+ f" You must define a primary container with the name '{pod_template.primary_container_name}'."
243
+ )
244
+ final_containers = []
245
+
246
+ for container in containers:
247
+ # We overwrite the primary container attributes with the values given to ContainerTask.
248
+ # The attributes include: image, command, args, resource, and env (env is unioned)
249
+
250
+ if container.name == pod_template.primary_container_name:
251
+ if container.image is None:
252
+ # Copy the image from primary_container only if the image is not specified in the pod spec.
253
+ container.image = primary_container.image
254
+
255
+ container.command = list(primary_container.command)
256
+ container.args = list(primary_container.args)
257
+
258
+ limits, requests = {}, {}
259
+ for resource in primary_container.resources.limits:
260
+ limits[_sanitize_resource_name(resource)] = resource.value
261
+ for resource in primary_container.resources.requests:
262
+ requests[_sanitize_resource_name(resource)] = resource.value
263
+
264
+ resource_requirements = V1ResourceRequirements(limits=limits, requests=requests)
265
+ if len(limits) > 0 or len(requests) > 0:
266
+ # Important! Only copy over resource requirements if they are non-empty.
267
+ container.resources = resource_requirements
268
+
269
+ if primary_container.env is not None:
270
+ container.env = [V1EnvVar(name=e.key, value=e.value) for e in primary_container.env] + (
271
+ container.env or []
272
+ )
273
+
274
+ final_containers.append(container)
275
+
276
+ cast(V1PodSpec, pod_template.pod_spec).containers = final_containers
277
+ pod_spec = ApiClient().sanitize_for_serialization(pod_template.pod_spec)
278
+
279
+ metadata = tasks_pb2.K8sObjectMetadata(labels=pod_template.labels, annotations=pod_template.annotations)
280
+ return tasks_pb2.K8sPod(pod_spec=pod_spec, metadata=metadata)
281
+
282
+
213
283
  def extract_code_bundle(task_spec: task_definition_pb2.TaskSpec) -> Optional[CodeBundle]:
214
284
  """
215
285
  Extract the code bundle from the task spec.
@@ -145,7 +145,7 @@ async def convert_and_run(
145
145
  return None, convert_from_native_to_error(err)
146
146
  if task.report:
147
147
  await flyte.report.flush.aio()
148
- return await convert_from_native_to_outputs(out, task.native_interface), None
148
+ return await convert_from_native_to_outputs(out, task.native_interface, task.name), None
149
149
 
150
150
 
151
151
  async def extract_download_run_upload(
flyte/_logging.py CHANGED
@@ -6,7 +6,7 @@ from typing import Optional
6
6
 
7
7
  from ._tools import ipython_check, is_in_cluster
8
8
 
9
- DEFAULT_LOG_LEVEL = logging.INFO
9
+ DEFAULT_LOG_LEVEL = logging.WARNING
10
10
 
11
11
 
12
12
  def is_rich_logging_disabled() -> bool:
flyte/_pod.py ADDED
@@ -0,0 +1,19 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import TYPE_CHECKING, Dict, Optional
3
+
4
+ if TYPE_CHECKING:
5
+ from kubernetes.client import V1PodSpec
6
+
7
+
8
+ _PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"
9
+ _PRIMARY_CONTAINER_DEFAULT_NAME = "primary"
10
+
11
+
12
+ @dataclass(init=True, repr=True, eq=True, frozen=False)
13
+ class PodTemplate(object):
14
+ """Custom PodTemplate specification for a Task."""
15
+
16
+ pod_spec: Optional["V1PodSpec"] = field(default_factory=lambda: V1PodSpec())
17
+ primary_container_name: str = _PRIMARY_CONTAINER_DEFAULT_NAME
18
+ labels: Optional[Dict[str, str]] = None
19
+ annotations: Optional[Dict[str, str]] = None
flyte/_run.py CHANGED
@@ -5,6 +5,7 @@ import pathlib
5
5
  import uuid
6
6
  from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple, Union, cast
7
7
 
8
+ import flyte.errors
8
9
  from flyte.errors import InitializationError
9
10
  from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath, SerializationContext, TaskContext
10
11
  from flyte.syncify import syncify
@@ -25,6 +26,7 @@ from ._tools import ipython_check
25
26
 
26
27
  if TYPE_CHECKING:
27
28
  from flyte.remote import Run
29
+ from flyte.remote._task import LazyEntity
28
30
 
29
31
  from ._code_bundle import CopyFiles
30
32
 
@@ -81,8 +83,11 @@ class _Runner:
81
83
  self._run_base_dir = run_base_dir or "/tmp/base"
82
84
 
83
85
  @requires_initialization
84
- async def _run_remote(self, obj: TaskTemplate[P, R], *args: P.args, **kwargs: P.kwargs) -> Run:
86
+ async def _run_remote(self, obj: TaskTemplate[P, R] | LazyEntity, *args: P.args, **kwargs: P.kwargs) -> Run:
87
+ import grpc
88
+
85
89
  from flyte.remote import Run
90
+ from flyte.remote._task import LazyEntity
86
91
 
87
92
  from ._code_bundle import build_code_bundle, build_pkl_bundle
88
93
  from ._deploy import build_images, plan_deploy
@@ -93,37 +98,47 @@ class _Runner:
93
98
 
94
99
  cfg = get_common_config()
95
100
 
96
- if obj.parent_env is None:
97
- raise ValueError("Task is not attached to an environment. Please attach the task to an environment.")
101
+ if isinstance(obj, LazyEntity):
102
+ task = await obj.fetch.aio()
103
+ task_spec = task.pb2.spec
104
+ inputs = await convert_from_native_to_inputs(task.interface, *args, **kwargs)
105
+ version = task.pb2.task_id.version
106
+ code_bundle = None
107
+ else:
108
+ if obj.parent_env is None:
109
+ raise ValueError("Task is not attached to an environment. Please attach the task to an environment")
98
110
 
99
- deploy_plan = plan_deploy(cast(Environment, obj.parent_env()))
100
- image_cache = await build_images(deploy_plan)
111
+ deploy_plan = plan_deploy(cast(Environment, obj.parent_env()))
112
+ image_cache = await build_images(deploy_plan)
101
113
 
102
- if self._interactive_mode:
103
- code_bundle = await build_pkl_bundle(
104
- obj, upload_to_controlplane=not self._dry_run, copy_bundle_to=self._copy_bundle_to
105
- )
106
- else:
107
- if self._copy_files != "none":
108
- code_bundle = await build_code_bundle(
109
- from_dir=cfg.root_dir, dryrun=self._dry_run, copy_bundle_to=self._copy_bundle_to
114
+ if self._interactive_mode:
115
+ code_bundle = await build_pkl_bundle(
116
+ obj, upload_to_controlplane=not self._dry_run, copy_bundle_to=self._copy_bundle_to
110
117
  )
111
118
  else:
112
- code_bundle = None
119
+ if self._copy_files != "none":
120
+ code_bundle = await build_code_bundle(
121
+ from_dir=cfg.root_dir,
122
+ dryrun=self._dry_run,
123
+ copy_bundle_to=self._copy_bundle_to,
124
+ copy_style=self._copy_files,
125
+ )
126
+ else:
127
+ code_bundle = None
113
128
 
114
- version = self._version or (
115
- code_bundle.computed_version if code_bundle and code_bundle.computed_version else None
116
- )
117
- if not version:
118
- raise ValueError("Version is required when running a task")
119
- s_ctx = SerializationContext(
120
- code_bundle=code_bundle,
121
- version=version,
122
- image_cache=image_cache,
123
- root_dir=cfg.root_dir,
124
- )
125
- task_spec = translate_task_to_wire(obj, s_ctx)
126
- inputs = await convert_from_native_to_inputs(obj.native_interface, *args, **kwargs)
129
+ version = self._version or (
130
+ code_bundle.computed_version if code_bundle and code_bundle.computed_version else None
131
+ )
132
+ if not version:
133
+ raise ValueError("Version is required when running a task")
134
+ s_ctx = SerializationContext(
135
+ code_bundle=code_bundle,
136
+ version=version,
137
+ image_cache=image_cache,
138
+ root_dir=cfg.root_dir,
139
+ )
140
+ task_spec = translate_task_to_wire(obj, s_ctx)
141
+ inputs = await convert_from_native_to_inputs(obj.native_interface, *args, **kwargs)
127
142
 
128
143
  if not self._dry_run:
129
144
  if get_client() is None:
@@ -160,15 +175,35 @@ class _Runner:
160
175
  if task_spec.task_template.id.version == "":
161
176
  task_spec.task_template.id.version = version
162
177
 
163
- resp = await get_client().run_service.CreateRun(
164
- run_service_pb2.CreateRunRequest(
165
- run_id=run_id,
166
- project_id=project_id,
167
- task_spec=task_spec,
168
- inputs=inputs.proto_inputs,
169
- ),
170
- )
171
- return Run(pb2=resp.run)
178
+ try:
179
+ resp = await get_client().run_service.CreateRun(
180
+ run_service_pb2.CreateRunRequest(
181
+ run_id=run_id,
182
+ project_id=project_id,
183
+ task_spec=task_spec,
184
+ inputs=inputs.proto_inputs,
185
+ ),
186
+ )
187
+ return Run(pb2=resp.run)
188
+ except grpc.aio.AioRpcError as e:
189
+ if e.code() == grpc.StatusCode.UNAVAILABLE:
190
+ raise flyte.errors.RuntimeSystemError(
191
+ "SystemUnavailableError",
192
+ "Flyte system is currently unavailable. check your configuration, or the service status.",
193
+ ) from e
194
+ elif e.code() == grpc.StatusCode.INVALID_ARGUMENT:
195
+ raise flyte.errors.RuntimeUserError("InvalidArgumentError", e.details())
196
+ elif e.code() == grpc.StatusCode.ALREADY_EXISTS:
197
+ # TODO maybe this should be a pass and return existing run?
198
+ raise flyte.errors.RuntimeUserError(
199
+ "RunAlreadyExistsError",
200
+ f"A run with the name '{self._name}' already exists. Please choose a different name.",
201
+ )
202
+ else:
203
+ raise flyte.errors.RuntimeSystemError(
204
+ "RunCreationError",
205
+ f"Failed to create run: {e.details()}",
206
+ ) from e
172
207
 
173
208
  class DryRun(Run):
174
209
  def __init__(self, _task_spec, _inputs, _code_bundle):
@@ -225,7 +260,10 @@ class _Runner:
225
260
  else:
226
261
  if self._copy_files != "none":
227
262
  code_bundle = await build_code_bundle(
228
- from_dir=cfg.root_dir, dryrun=self._dry_run, copy_bundle_to=self._copy_bundle_to
263
+ from_dir=cfg.root_dir,
264
+ dryrun=self._dry_run,
265
+ copy_bundle_to=self._copy_bundle_to,
266
+ copy_style=self._copy_files,
229
267
  )
230
268
  else:
231
269
  code_bundle = None
@@ -313,7 +351,7 @@ class _Runner:
313
351
  report=Report(name=action.name),
314
352
  mode="local",
315
353
  )
316
- async with ctx.replace_task_context(tctx):
354
+ with ctx.replace_task_context(tctx):
317
355
  # make the local version always runs on a different thread, returns a wrapped future.
318
356
  if obj._call_as_synchronous:
319
357
  fut = controller.submit_sync(obj, *args, **kwargs)
@@ -323,7 +361,9 @@ class _Runner:
323
361
  return await controller.submit(obj, *args, **kwargs)
324
362
 
325
363
  @syncify
326
- async def run(self, task: TaskTemplate[P, Union[R, Run]], *args: P.args, **kwargs: P.kwargs) -> Union[R, Run]:
364
+ async def run(
365
+ self, task: TaskTemplate[P, Union[R, Run]] | LazyEntity, *args: P.args, **kwargs: P.kwargs
366
+ ) -> Union[R, Run]:
327
367
  """
328
368
  Run an async `@env.task` or `TaskTemplate` instance. The existing async context will be used.
329
369
 
@@ -345,8 +385,13 @@ class _Runner:
345
385
  :param kwargs: Keyword arguments to pass to the Task
346
386
  :return: Run instance or the result of the task
347
387
  """
388
+ from flyte.remote._task import LazyEntity
389
+
390
+ if isinstance(task, LazyEntity) and self._mode != "remote":
391
+ raise ValueError("Remote task can only be run in remote mode.")
348
392
  if self._mode == "remote":
349
393
  return await self._run_remote(task, *args, **kwargs)
394
+ task = cast(TaskTemplate, task)
350
395
  if self._mode == "hybrid":
351
396
  return await self._run_hybrid(task, *args, **kwargs)
352
397
 
flyte/_task.py CHANGED
@@ -23,6 +23,7 @@ from typing import (
23
23
 
24
24
  from flyteidl.core.tasks_pb2 import DataLoadingConfig
25
25
 
26
+ from flyte._pod import PodTemplate
26
27
  from flyte.errors import RuntimeSystemError, RuntimeUserError
27
28
 
28
29
  from ._cache import Cache, CacheRequest
@@ -37,8 +38,6 @@ from ._timeout import TimeoutType
37
38
  from .models import NativeInterface, SerializationContext
38
39
 
39
40
  if TYPE_CHECKING:
40
- from kubernetes.client import V1PodTemplate
41
-
42
41
  from ._task_environment import TaskEnvironment
43
42
 
44
43
  P = ParamSpec("P") # capture the function's parameters
@@ -96,8 +95,7 @@ class TaskTemplate(Generic[P, R]):
96
95
  env: Optional[Dict[str, str]] = None
97
96
  secrets: Optional[SecretRequest] = None
98
97
  timeout: Optional[TimeoutType] = None
99
- primary_container_name: str = "primary"
100
- pod_template: Optional[Union[str, V1PodTemplate]] = None
98
+ pod_template: Optional[Union[str, PodTemplate]] = None
101
99
  report: bool = False
102
100
 
103
101
  parent_env: Optional[weakref.ReferenceType[TaskEnvironment]] = None
@@ -108,15 +106,6 @@ class TaskTemplate(Generic[P, R]):
108
106
  _call_as_synchronous: bool = False
109
107
 
110
108
  def __post_init__(self):
111
- # If pod_template is set to a pod, verify
112
- if self.pod_template is not None and not isinstance(self.pod_template, str):
113
- try:
114
- from kubernetes.client import V1PodTemplate # noqa: F401
115
- except ImportError as e:
116
- raise ImportError(
117
- "kubernetes is not installed, please install kubernetes package to use pod_template"
118
- ) from e
119
-
120
109
  # Auto set the image based on the image request
121
110
  if self.image == "auto":
122
111
  self.image = Image.auto()