flyte-0.2.0b22-py3-none-any.whl → flyte-0.2.0b24-py3-none-any.whl

This diff compares the contents of two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

flyte/__init__.py CHANGED
@@ -1,19 +1,5 @@
1
1
  """
2
- Flyte SDK for authoring Compound AI applications, services and workflows.
3
-
4
- ## Environments
5
-
6
- TaskEnvironment class to define a new environment for a set of tasks.
7
-
8
- Example usage:
9
-
10
- ```python
11
- env = flyte.TaskEnvironment(name="my_env", image="my_image", resources=Resources(cpu="1", memory="1Gi"))
12
-
13
- @env.task
14
- async def my_task():
15
- pass
16
- ```
2
+ Flyte SDK for authoring compound AI applications, services and workflows.
17
3
  """
18
4
 
19
5
  __all__ = [
flyte/_bin/runtime.py CHANGED
@@ -32,7 +32,12 @@ ENABLE_REF_TASKS = "_REF_TASKS" # This is a temporary flag to enable reference
32
32
  _UNION_EAGER_API_KEY_ENV_VAR = "_UNION_EAGER_API_KEY"
33
33
 
34
34
 
35
- @click.command("a0")
35
+ @click.group()
36
+ def _pass_through():
37
+ pass
38
+
39
+
40
+ @_pass_through.command("a0")
36
41
  @click.option("--inputs", "-i", required=True)
37
42
  @click.option("--outputs-path", "-o", required=True)
38
43
  @click.option("--version", "-v", required=True)
@@ -158,3 +163,7 @@ def main(
158
163
  await controller.stop()
159
164
 
160
165
  asyncio.run(_run_and_stop())
166
+
167
+
168
+ if __name__ == "__main__":
169
+ _pass_through()
flyte/_environment.py CHANGED
@@ -49,6 +49,13 @@ class Environment:
49
49
  """
50
50
  Add a dependency to the environment.
51
51
  """
52
+ for e in env:
53
+ if not isinstance(e, Environment):
54
+ raise TypeError(f"Expected Environment, got {type(e)}")
55
+ if e.name == self.name:
56
+ raise ValueError("Cannot add self as a dependency")
57
+ if e in self.depends_on:
58
+ continue
52
59
  self.depends_on.extend(env)
53
60
 
54
61
  def clone_with(
flyte/_image.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
3
3
  import base64
4
4
  import hashlib
5
5
  import sys
6
+ import typing
6
7
  from abc import abstractmethod
7
8
  from dataclasses import asdict, dataclass, field
8
9
  from functools import cached_property
@@ -10,6 +11,7 @@ from pathlib import Path
10
11
  from typing import Callable, ClassVar, Dict, List, Literal, Optional, Tuple, TypeVar, Union
11
12
 
12
13
  import rich.repr
14
+ from packaging.version import Version
13
15
 
14
16
  # Supported Python versions
15
17
  PYTHON_3_10 = (3, 10)
@@ -352,21 +354,29 @@ class Image:
352
354
  layer.validate()
353
355
 
354
356
  @classmethod
355
- def _get_default_image_for(cls, python_version: Tuple[int, int], flyte_version: Optional[str] = None) -> Image:
357
+ def _get_default_image_for(
358
+ cls, python_version: Tuple[int, int], flyte_version: Optional[str] = None, install_flyte: bool = True
359
+ ) -> Image:
356
360
  # Would love a way to move this outside of this class (but still needs to be accessible via Image.auto())
357
361
  # this default image definition may need to be updated once there is a released pypi version
358
362
  from flyte._version import __version__
359
363
 
360
- dev_mode = (cls._is_editable_install() or (__version__ and "dev" in __version__)) and not flyte_version
361
- if flyte_version is None:
362
- flyte_version = __version__.replace("+", "-")
363
- preset_tag = flyte_version if flyte_version.startswith("v") else f"v{flyte_version}"
364
- preset_tag = f"py{python_version[0]}.{python_version[1]}-{preset_tag}"
364
+ dev_mode = (
365
+ (cls._is_editable_install() or (__version__ and "dev" in __version__))
366
+ and not flyte_version
367
+ and install_flyte
368
+ )
369
+ if install_flyte is False:
370
+ preset_tag = f"py{python_version[0]}.{python_version[1]}"
371
+ else:
372
+ if flyte_version is None:
373
+ flyte_version = __version__.replace("+", "-")
374
+ suffix = flyte_version if flyte_version.startswith("v") else f"v{flyte_version}"
375
+ preset_tag = f"py{python_version[0]}.{python_version[1]}-{suffix}"
365
376
  image = Image._new(
366
377
  base_image=f"python:{python_version[0]}.{python_version[1]}-slim-bookworm",
367
378
  registry=_BASE_REGISTRY,
368
379
  name=_DEFAULT_IMAGE_NAME,
369
- # _tag=preset_tag,
370
380
  platform=("linux/amd64", "linux/arm64"),
371
381
  )
372
382
  labels_and_user = _DockerLines(
@@ -389,11 +399,15 @@ class Image:
389
399
  )
390
400
  image = image.with_apt_packages("build-essential", "ca-certificates")
391
401
 
392
- # Add in flyte library
393
- if dev_mode:
394
- image = image._with_local_v2()
395
- else:
396
- image = image.with_pip_packages(f"flyte=={flyte_version}")
402
+ if install_flyte:
403
+ if dev_mode:
404
+ image = image.with_local_v2()
405
+ else:
406
+ flyte_version = typing.cast(str, flyte_version)
407
+ if Version(flyte_version).is_devrelease or Version(flyte_version).is_prerelease:
408
+ image = image.with_pip_packages(f"flyte=={flyte_version}", pre=True)
409
+ else:
410
+ image = image.with_pip_packages(f"flyte=={flyte_version}")
397
411
  object.__setattr__(image, "_tag", preset_tag)
398
412
  # Set this to auto for all auto images because the meaning of "auto" can change (based on logic inside
399
413
  # _get_default_image_for, acts differently in a running task container) so let's make sure it stays auto.
@@ -413,6 +427,7 @@ class Image:
413
427
  cls,
414
428
  python_version: Optional[Tuple[int, int]] = None,
415
429
  flyte_version: Optional[str] = None,
430
+ install_flyte: bool = True,
416
431
  registry: Optional[str] = None,
417
432
  name: Optional[str] = None,
418
433
  ) -> Image:
@@ -422,6 +437,7 @@ class Image:
422
437
 
423
438
  :param python_version: If not specified, will use the current Python version
424
439
  :param flyte_version: Union version to use
440
+ :param install_flyte: If True, will install the flyte library in the image
425
441
  :param registry: Registry to use for the image
426
442
  :param name: Name of the image if you want to override the default name
427
443
 
@@ -430,7 +446,9 @@ class Image:
430
446
  if python_version is None:
431
447
  python_version = _detect_python_version()
432
448
 
433
- base_image = cls._get_default_image_for(python_version=python_version, flyte_version=flyte_version)
449
+ base_image = cls._get_default_image_for(
450
+ python_version=python_version, flyte_version=flyte_version, install_flyte=install_flyte
451
+ )
434
452
 
435
453
  if registry and name:
436
454
  return base_image.clone(registry=registry, name=name)
@@ -534,6 +552,7 @@ class Image:
534
552
  self,
535
553
  registry: Optional[str] = None,
536
554
  name: Optional[str] = None,
555
+ python_version: Optional[Tuple[int, int]] = None,
537
556
  addl_layer: Optional[Layer] = None,
538
557
  ) -> Image:
539
558
  """
@@ -541,17 +560,16 @@ class Image:
541
560
 
542
561
  :param registry: Registry to use for the image
543
562
  :param name: Name of the image
563
+ :param python_version: Python version for the image, if not specified, will use the current Python version
544
564
  :param addl_layer: Additional layer to add to the image. This will be added to the end of the layers.
545
- :param addl_layer: Additional layer to add to the image. This will be added on top of the existing layers.
546
565
 
547
566
  :return:
548
567
  """
549
568
  registry = registry if registry else self.registry
550
569
  name = name if name else self.name
551
- if addl_layer and (not name or not registry):
570
+ if addl_layer and (not name):
552
571
  raise ValueError(
553
- f"Cannot add additional layer {addl_layer} to an image"
554
- f" without a registry and name. Please first clone()."
572
+ f"Cannot add additional layer {addl_layer} to an image without name. Please first clone()."
555
573
  )
556
574
  new_layers = (*self._layers, addl_layer) if addl_layer else self._layers
557
575
  img = Image._new(
@@ -560,7 +578,7 @@ class Image:
560
578
  registry=registry,
561
579
  name=name,
562
580
  platform=self.platform,
563
- python_version=self.python_version,
581
+ python_version=python_version or self.python_version,
564
582
  _layers=new_layers,
565
583
  )
566
584
 
@@ -710,7 +728,7 @@ class Image:
710
728
  :param dst: destination folder in the image
711
729
  :return: Image
712
730
  """
713
- new_image = self.clone(addl_layer=CopyConfig(path_type=1, src=src, dst=dst))
731
+ new_image = self.clone(addl_layer=CopyConfig(path_type=1, src=src, dst=dst, _compute_identifier=lambda x: dst))
714
732
  return new_image
715
733
 
716
734
  def with_source_file(self, src: Path, dst: str = ".") -> Image:
@@ -722,7 +740,7 @@ class Image:
722
740
  :param dst: destination folder in the image
723
741
  :return: Image
724
742
  """
725
- new_image = self.clone(addl_layer=CopyConfig(path_type=0, src=src, dst=dst))
743
+ new_image = self.clone(addl_layer=CopyConfig(path_type=0, src=src, dst=dst, _compute_identifier=lambda x: dst))
726
744
  return new_image
727
745
 
728
746
  def with_uv_project(
@@ -774,7 +792,7 @@ class Image:
774
792
  new_image = self.clone(addl_layer=Commands(commands=new_commands))
775
793
  return new_image
776
794
 
777
- def _with_local_v2(self) -> Image:
795
+ def with_local_v2(self) -> Image:
778
796
  """
779
797
  Use this method to create a new image with the local v2 builder
780
798
  This will override any existing builder
@@ -1,12 +1,15 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
+ from typing import Any, Literal
4
5
 
5
6
  from flyteidl.core import execution_pb2
6
7
 
7
8
  from flyte._protos.workflow import run_definition_pb2, state_service_pb2, task_definition_pb2
8
9
  from flyte.models import GroupData
9
10
 
11
+ ActionType = Literal["task", "trace"]
12
+
10
13
 
11
14
  @dataclass
12
15
  class Action:
@@ -17,6 +20,7 @@ class Action:
17
20
 
18
21
  action_id: run_definition_pb2.ActionIdentifier
19
22
  parent_action_name: str
23
+ type: ActionType = "task" # type of action, task or trace
20
24
  friendly_name: str | None = None
21
25
  group: GroupData | None = None
22
26
  task: task_definition_pb2.TaskSpec | None = None
@@ -144,3 +148,28 @@ class Action:
144
148
  err=obj.error if obj.HasField("error") else None,
145
149
  realized_outputs_uri=obj.output_uri,
146
150
  )
151
+
152
+ @classmethod
153
+ def from_trace(
154
+ cls,
155
+ parent_action_name: str,
156
+ action_id: run_definition_pb2.ActionIdentifier,
157
+ friendly_name: str,
158
+ group_data: GroupData | None,
159
+ trace_spec: Any, # TODO
160
+ inputs_uri: str,
161
+ outputs_uri: str,
162
+ ) -> Action:
163
+ """
164
+ This creates a new action for tracing purposes. It is used to track the execution of a trace.
165
+ """
166
+ return cls(
167
+ action_id=action_id,
168
+ parent_action_name=parent_action_name,
169
+ type="trace",
170
+ friendly_name=friendly_name,
171
+ group=group_data,
172
+ inputs_uri=inputs_uri,
173
+ realized_outputs_uri=outputs_uri,
174
+ phase=run_definition_pb2.Phase.PHASE_SUCCEEDED,
175
+ )
@@ -174,13 +174,12 @@ class RemoteController(Controller):
174
174
  )
175
175
 
176
176
  task_spec = translate_task_to_wire(_task, new_serialization_context)
177
- serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
178
-
179
- inputs_hash = convert.generate_inputs_hash(serialized_inputs)
177
+ inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)
180
178
  sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
181
179
  tctx, task_spec, inputs_hash, _task_call_seq
182
180
  )
183
181
 
182
+ serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
184
183
  inputs_uri = io.inputs_path(sub_action_output_path)
185
184
  await upload_inputs_with_retry(serialized_inputs, inputs_uri)
186
185
 
@@ -391,10 +390,12 @@ class RemoteController(Controller):
391
390
  if tctx is None:
392
391
  raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
393
392
 
393
+ current_action_id = tctx.action
394
394
  current_output_path = tctx.output_path
395
395
  sub_run_output_path = storage.join(current_output_path, info.action.name)
396
396
 
397
397
  if info.interface.has_outputs():
398
+ outputs_file_path: str = ""
398
399
  if info.output:
399
400
  outputs = await convert.convert_from_native_to_outputs(info.output, info.interface)
400
401
  outputs_file_path = io.outputs_path(sub_run_output_path)
@@ -405,6 +406,33 @@ class RemoteController(Controller):
405
406
  await io.upload_error(err.err, error_path)
406
407
  else:
407
408
  raise flyte.errors.RuntimeSystemError("BadTraceInfo", "Trace info does not have output or error")
409
+ trace_action = Action.from_trace(
410
+ parent_action_name=current_action_id.name,
411
+ action_id=run_definition_pb2.ActionIdentifier(
412
+ name=info.action.name,
413
+ run=run_definition_pb2.RunIdentifier(
414
+ name=current_action_id.run_name,
415
+ project=current_action_id.project,
416
+ domain=current_action_id.domain,
417
+ org=current_action_id.org,
418
+ ),
419
+ ),
420
+ inputs_uri=info.inputs_path,
421
+ outputs_uri=outputs_file_path,
422
+ trace_spec="",
423
+ friendly_name=info.name,
424
+ group_data=tctx.group_data,
425
+ )
426
+ try:
427
+ logger.info(
428
+ f"Submitting Trace action Run:[{trace_action.run_name}, Parent:[{trace_action.parent_action_name}],"
429
+ f" Trace fn:[{info.name}], action:[{info.action.name}]"
430
+ )
431
+ await self.submit_action(trace_action)
432
+ logger.info(f"Trace Action for [{info.name}] action id: {info.action.name}, completed!")
433
+ except asyncio.CancelledError:
434
+ # If the action is cancelled, we need to cancel the action on the server as well
435
+ raise
408
436
 
409
437
  async def submit_task_ref(self, _task: task_definition_pb2.TaskDetails, *args, **kwargs) -> Any:
410
438
  ctx = internal_ctx()
@@ -420,12 +448,12 @@ class RemoteController(Controller):
420
448
  _task.spec.task_template.interface, default_inputs=_task.spec.default_inputs
421
449
  )
422
450
  inputs = await convert.convert_from_native_to_inputs(native_interface, *args, **kwargs)
423
- serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
424
- inputs_hash = convert.generate_inputs_hash(serialized_inputs)
451
+ inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)
425
452
  sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
426
453
  tctx, task_name, inputs_hash, invoke_seq_num
427
454
  )
428
455
 
456
+ serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
429
457
  inputs_uri = io.inputs_path(sub_action_output_path)
430
458
  await upload_inputs_with_retry(serialized_inputs, inputs_uri)
431
459
  # cache key - task name, task signature, inputs, cache version
@@ -309,29 +309,43 @@ class Controller:
309
309
  """
310
310
  Attempt to launch an action.
311
311
  """
312
- if not action.is_started() and action.task is not None:
313
- logger.debug(f"Attempting to launch action: {action.name}")
314
- try:
312
+ if not action.is_started():
313
+ task: queue_service_pb2.TaskAction | None = None
314
+ trace: queue_service_pb2.TraceAction | None = None
315
+ if action.type == "task":
316
+ if action.task is None:
317
+ raise flyte.errors.RuntimeSystemError(
318
+ "NoTaskSpec", "Task Spec not found, cannot launch Task Action."
319
+ )
315
320
  cache_key = None
316
321
  logger.info(f"Action {action.name} has cache version {action.cache_key}")
317
322
  if action.cache_key:
318
323
  cache_key = StringValue(value=action.cache_key)
319
324
 
325
+ task = queue_service_pb2.TaskAction(
326
+ id=task_definition_pb2.TaskIdentifier(
327
+ version=action.task.task_template.id.version,
328
+ org=action.task.task_template.id.org,
329
+ project=action.task.task_template.id.project,
330
+ domain=action.task.task_template.id.domain,
331
+ name=action.task.task_template.id.name,
332
+ ),
333
+ spec=action.task,
334
+ cache_key=cache_key,
335
+ )
336
+ else:
337
+ trace = queue_service_pb2.TraceAction(
338
+ name=action.friendly_name,
339
+ )
340
+
341
+ logger.debug(f"Attempting to launch action: {action.name}")
342
+ try:
320
343
  await self._queue_service.EnqueueAction(
321
344
  queue_service_pb2.EnqueueActionRequest(
322
345
  action_id=action.action_id,
323
346
  parent_action_name=action.parent_action_name,
324
- task=queue_service_pb2.TaskAction(
325
- id=task_definition_pb2.TaskIdentifier(
326
- version=action.task.task_template.id.version,
327
- org=action.task.task_template.id.org,
328
- project=action.task.task_template.id.project,
329
- domain=action.task.task_template.id.domain,
330
- name=action.task.task_template.id.name,
331
- ),
332
- spec=action.task,
333
- cache_key=cache_key,
334
- ),
347
+ task=task,
348
+ trace=trace,
335
349
  input_uri=action.inputs_uri,
336
350
  run_output_base=action.run_output_base,
337
351
  group=action.group.name if action.group else None,
@@ -1,15 +1,3 @@
1
- import asyncio
2
- from typing import List
3
-
4
- from flyte._image import Image
5
- from flyte._internal.imagebuild.docker_builder import DockerImageBuilder
6
1
  from flyte._internal.imagebuild.image_builder import ImageBuildEngine
7
- from flyte._internal.imagebuild.remote_builder import RemoteImageBuilder
8
-
9
- __all__ = ["DockerImageBuilder", "ImageBuildEngine", "RemoteImageBuilder"]
10
-
11
2
 
12
- async def build(images: List[Image]) -> List[str]:
13
- builder = DockerImageBuilder()
14
- ts = [asyncio.create_task(builder.build_image(image)) for image in images]
15
- return list(await asyncio.gather(*ts))
3
+ __all__ = ["ImageBuildEngine"]
@@ -72,7 +72,7 @@ RUN --mount=type=cache,sharing=locked,mode=0777,target=/var/cache/apt,id=apt \
72
72
 
73
73
  UV_PYTHON_INSTALL_COMMAND = Template("""\
74
74
  RUN --mount=type=cache,sharing=locked,mode=0777,target=/root/.cache/uv,id=uv \
75
- uv pip install --prerelease=allow $PIP_INSTALL_ARGS
75
+ uv pip install $PIP_INSTALL_ARGS
76
76
  """)
77
77
 
78
78
  # uv pip install --python /root/env/bin/python
@@ -155,9 +155,18 @@ async def convert_from_native_to_outputs(o: Any, interface: NativeInterface, tas
155
155
  if not isinstance(o, tuple):
156
156
  o = (o,)
157
157
 
158
- assert len(interface.outputs) == len(interface.outputs), (
159
- f"Received {len(o)} outputs but interface has {len(interface.outputs)}"
160
- )
158
+ if len(interface.outputs) == 0:
159
+ if len(o) != 0:
160
+ if len(o) == 1 and o[0] is not None:
161
+ raise flyte.errors.RuntimeDataValidationError(
162
+ "o0",
163
+ f"Expected no outputs but got {o},did you miss a return type annotation?",
164
+ task_name,
165
+ )
166
+ else:
167
+ assert len(o) == len(interface.outputs), (
168
+ f"Received {len(o)} outputs but return annotation has {len(interface.outputs)} outputs specified. "
169
+ )
161
170
  named = []
162
171
  for (output_name, python_type), v in zip(interface.outputs.items(), o):
163
172
  try:
@@ -279,6 +288,29 @@ def generate_inputs_hash(serialized_inputs: str | bytes) -> str:
279
288
  return hash_data(serialized_inputs)
280
289
 
281
290
 
291
+ def generate_inputs_hash_from_proto(inputs: run_definition_pb2.Inputs) -> str:
292
+ """
293
+ Generate a hash for the inputs. This is used to uniquely identify the inputs for a task.
294
+ :param inputs: The inputs to hash.
295
+ :return: A hexadecimal string representation of the hash.
296
+ """
297
+ if not inputs:
298
+ return ""
299
+ return generate_inputs_hash(inputs.SerializeToString(deterministic=True))
300
+
301
+
302
+ def generate_interface_hash(task_interface: interface_pb2.TypedInterface) -> str:
303
+ """
304
+ Generate a hash for the task interface. This is used to uniquely identify the task interface.
305
+ :param task_interface: The interface of the task.
306
+ :return: A hexadecimal string representation of the hash.
307
+ """
308
+ if not task_interface:
309
+ return ""
310
+ serialized_interface = task_interface.SerializeToString(deterministic=True)
311
+ return hash_data(serialized_interface)
312
+
313
+
282
314
  def generate_cache_key_hash(
283
315
  task_name: str,
284
316
  inputs_hash: str,
@@ -302,10 +334,13 @@ def generate_cache_key_hash(
302
334
  if ignored_input_vars:
303
335
  filtered = [named_lit for named_lit in proto_inputs.literals if named_lit.name not in ignored_input_vars]
304
336
  final = run_definition_pb2.Inputs(literals=filtered)
305
- final_inputs = final.SerializeToString(deterministic=True)
337
+ final_inputs = generate_inputs_hash_from_proto(final)
306
338
  else:
307
339
  final_inputs = inputs_hash
308
- data = f"{final_inputs}{task_name}{task_interface.SerializeToString(deterministic=True)}{cache_version}"
340
+
341
+ interface_hash = generate_interface_hash(task_interface)
342
+
343
+ data = f"{final_inputs}{task_name}{interface_hash}{cache_version}"
309
344
  return hash_data(data)
310
345
 
311
346
 
@@ -143,7 +143,7 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
143
143
 
144
144
  # TODO Add support for SQL, extra_config, custom
145
145
  extra_config: typing.Dict[str, str] = {}
146
- custom = {} # type: ignore
146
+ custom = task.custom_config(serialize_context)
147
147
 
148
148
  sql = None
149
149
  if task.pod_template and not isinstance(task.pod_template, str):
@@ -220,9 +220,16 @@ def _get_urun_container(
220
220
  if not serialize_context.image_cache:
221
221
  # This computes the image uri, computing hashes as necessary so can fail if done remotely.
222
222
  img_uri = task_template.image.uri
223
- elif image_id not in serialize_context.image_cache.image_lookup:
223
+ elif serialize_context.image_cache and image_id not in serialize_context.image_cache.image_lookup:
224
224
  img_uri = task_template.image.uri
225
- logger.warning(f"Image uri {img_uri} not found in image_cache {serialize_context.image_cache.image_lookup}")
225
+ from flyte._version import __version__
226
+
227
+ logger.warning(
228
+ f"Image {task_template.image} not found in the image cache: {serialize_context.image_cache.image_lookup}.\n"
229
+ f"This typically occurs when the Flyte SDK version (`{__version__}`) used in the task environment "
230
+ f"differs from the version used to compile or deploy it.\n"
231
+ f"Ensure both environments use the same Flyte SDK version to avoid inconsistencies in image resolution."
232
+ )
226
233
  else:
227
234
  img_uri = serialize_context.image_cache.image_lookup[image_id]
228
235
 
flyte/_pod.py CHANGED
@@ -2,7 +2,8 @@ from dataclasses import dataclass, field
2
2
  from typing import TYPE_CHECKING, Dict, Optional
3
3
 
4
4
  if TYPE_CHECKING:
5
- from kubernetes.client import V1PodSpec
5
+ from flyteidl.core.tasks_pb2 import K8sPod
6
+ from kubernetes.client import ApiClient, V1PodSpec
6
7
 
7
8
 
8
9
  _PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"
@@ -17,3 +18,12 @@ class PodTemplate(object):
17
18
  primary_container_name: str = _PRIMARY_CONTAINER_DEFAULT_NAME
18
19
  labels: Optional[Dict[str, str]] = None
19
20
  annotations: Optional[Dict[str, str]] = None
21
+
22
+ def to_k8s_pod(self) -> "K8sPod":
23
+ from flyteidl.core.tasks_pb2 import K8sObjectMetadata, K8sPod
24
+
25
+ return K8sPod(
26
+ metadata=K8sObjectMetadata(labels=self.labels, annotations=self.annotations),
27
+ pod_spec=ApiClient().sanitize_for_serialization(self.pod_spec),
28
+ primary_container_name=self.primary_container_name,
29
+ )
flyte/_resources.py CHANGED
@@ -1,8 +1,15 @@
1
- from dataclasses import dataclass
2
- from typing import Literal, Optional, Tuple, Union, get_args
1
+ from dataclasses import dataclass, fields
2
+ from typing import TYPE_CHECKING, Literal, Optional, Tuple, Union, get_args
3
3
 
4
4
  import rich.repr
5
5
 
6
+ from flyte._pod import _PRIMARY_CONTAINER_DEFAULT_NAME
7
+
8
+ if TYPE_CHECKING:
9
+ from kubernetes.client import V1PodSpec
10
+
11
+ PRIMARY_CONTAINER_DEFAULT_NAME = "primary"
12
+
6
13
  GPUType = Literal["T4", "A100", "A100 80G", "H100", "L4", "L40s"]
7
14
  GPUQuantity = Literal[1, 2, 3, 4, 5, 6, 7, 8]
8
15
  A100Parts = Literal["1g.5gb", "2g.10gb", "3g.20gb", "4g.20gb", "7g.40gb"]
@@ -224,3 +231,61 @@ class Resources:
224
231
  if self.shm == "auto":
225
232
  return ""
226
233
  return self.shm
234
+
235
+
236
+ def _check_resource_is_singular(resource: Resources):
237
+ """
238
+ Raise a value error if the resource has a tuple.
239
+ """
240
+ for field in fields(resource):
241
+ value = getattr(resource, field.name)
242
+ if isinstance(value, (tuple, list)):
243
+ raise ValueError(f"{value} can not be a list or tuple")
244
+ return resource
245
+
246
+
247
+ def pod_spec_from_resources(
248
+ primary_container_name: str = _PRIMARY_CONTAINER_DEFAULT_NAME,
249
+ requests: Optional[Resources] = None,
250
+ limits: Optional[Resources] = None,
251
+ k8s_gpu_resource_key: str = "nvidia.com/gpu",
252
+ ) -> "V1PodSpec":
253
+ from kubernetes.client import V1Container, V1PodSpec, V1ResourceRequirements
254
+
255
+ def _construct_k8s_pods_resources(resources: Optional[Resources], k8s_gpu_resource_key: str):
256
+ if resources is None:
257
+ return None
258
+
259
+ resources_map = {
260
+ "cpu": "cpu",
261
+ "memory": "memory",
262
+ "gpu": k8s_gpu_resource_key,
263
+ "ephemeral_storage": "ephemeral-storage",
264
+ }
265
+
266
+ k8s_pod_resources = {}
267
+
268
+ _check_resource_is_singular(resources)
269
+ for resource in fields(resources):
270
+ resource_value = getattr(resources, resource.name)
271
+ if resource_value is not None:
272
+ k8s_pod_resources[resources_map[resource.name]] = resource_value
273
+
274
+ return k8s_pod_resources
275
+
276
+ requests = _construct_k8s_pods_resources(resources=requests, k8s_gpu_resource_key=k8s_gpu_resource_key)
277
+ limits = _construct_k8s_pods_resources(resources=limits, k8s_gpu_resource_key=k8s_gpu_resource_key)
278
+ requests = requests or limits
279
+ limits = limits or requests
280
+
281
+ return V1PodSpec(
282
+ containers=[
283
+ V1Container(
284
+ name=primary_container_name,
285
+ resources=V1ResourceRequirements(
286
+ requests=requests,
287
+ limits=limits,
288
+ ),
289
+ )
290
+ ]
291
+ )
flyte/_run.py CHANGED
@@ -6,13 +6,9 @@ import uuid
6
6
  from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union, cast
7
7
 
8
8
  import flyte.errors
9
- from flyte.errors import InitializationError
10
- from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath, SerializationContext, TaskContext
11
- from flyte.syncify import syncify
12
-
13
- from ._context import contextual_run, internal_ctx
14
- from ._environment import Environment
15
- from ._initialize import (
9
+ from flyte._context import contextual_run, internal_ctx
10
+ from flyte._environment import Environment
11
+ from flyte._initialize import (
16
12
  _get_init_config,
17
13
  get_client,
18
14
  get_common_config,
@@ -20,9 +16,19 @@ from ._initialize import (
20
16
  requires_initialization,
21
17
  requires_storage,
22
18
  )
23
- from ._logging import logger
24
- from ._task import P, R, TaskTemplate
25
- from ._tools import ipython_check
19
+ from flyte._logging import logger
20
+ from flyte._task import P, R, TaskTemplate
21
+ from flyte._tools import ipython_check
22
+ from flyte.errors import InitializationError
23
+ from flyte.models import (
24
+ ActionID,
25
+ Checkpoints,
26
+ CodeBundle,
27
+ RawDataPath,
28
+ SerializationContext,
29
+ TaskContext,
30
+ )
31
+ from flyte.syncify import syncify
26
32
 
27
33
  if TYPE_CHECKING:
28
34
  from flyte.remote import Run
@@ -132,7 +138,9 @@ class _Runner:
132
138
 
133
139
  if self._interactive_mode:
134
140
  code_bundle = await build_pkl_bundle(
135
- obj, upload_to_controlplane=not self._dry_run, copy_bundle_to=self._copy_bundle_to
141
+ obj,
142
+ upload_to_controlplane=not self._dry_run,
143
+ copy_bundle_to=self._copy_bundle_to,
136
144
  )
137
145
  else:
138
146
  if self._copy_files != "none":
@@ -253,7 +261,8 @@ class _Runner:
253
261
  pb2=run_definition_pb2.Run(
254
262
  action=run_definition_pb2.Action(
255
263
  id=run_definition_pb2.ActionIdentifier(
256
- name="a0", run=run_definition_pb2.RunIdentifier(name="dry-run")
264
+ name="a0",
265
+ run=run_definition_pb2.RunIdentifier(name="dry-run"),
257
266
  )
258
267
  )
259
268
  )
@@ -286,7 +295,7 @@ class _Runner:
286
295
  if obj.parent_env is None:
287
296
  raise ValueError("Task is not attached to an environment. Please attach the task to an environment.")
288
297
 
289
- image_cache = build_images.aio(cast(Environment, obj.parent_env()))
298
+ image_cache = await build_images.aio(cast(Environment, obj.parent_env()))
290
299
 
291
300
  code_bundle = None
292
301
  if self._name is not None:
@@ -296,7 +305,9 @@ class _Runner:
296
305
  if not code_bundle:
297
306
  if self._interactive_mode:
298
307
  code_bundle = await build_pkl_bundle(
299
- obj, upload_to_controlplane=not self._dry_run, copy_bundle_to=self._copy_bundle_to
308
+ obj,
309
+ upload_to_controlplane=not self._dry_run,
310
+ copy_bundle_to=self._copy_bundle_to,
300
311
  )
301
312
  else:
302
313
  if self._copy_files != "none":
@@ -381,7 +392,8 @@ class _Runner:
381
392
  tctx = TaskContext(
382
393
  action=action,
383
394
  checkpoints=Checkpoints(
384
- prev_checkpoint_path=internal_ctx().raw_data.path, checkpoint_path=internal_ctx().raw_data.path
395
+ prev_checkpoint_path=internal_ctx().raw_data.path,
396
+ checkpoint_path=internal_ctx().raw_data.path,
385
397
  ),
386
398
  code_bundle=None,
387
399
  output_path=self._metadata_path,
@@ -403,7 +415,10 @@ class _Runner:
403
415
 
404
416
  @syncify
405
417
  async def run(
406
- self, task: TaskTemplate[P, Union[R, Run]] | LazyEntity, *args: P.args, **kwargs: P.kwargs
418
+ self,
419
+ task: TaskTemplate[P, Union[R, Run]] | LazyEntity,
420
+ *args: P.args,
421
+ **kwargs: P.kwargs,
407
422
  ) -> Union[R, Run]:
408
423
  """
409
424
  Run an async `@env.task` or `TaskTemplate` instance. The existing async context will be used.
@@ -430,6 +445,10 @@ class _Runner:
430
445
 
431
446
  if isinstance(task, LazyEntity) and self._mode != "remote":
432
447
  raise ValueError("Remote task can only be run in remote mode.")
448
+
449
+ if not isinstance(task, TaskTemplate) and not isinstance(task, LazyEntity):
450
+ raise TypeError("On Flyte tasks can be run, not generic functions or methods.")
451
+
433
452
  if self._mode == "remote":
434
453
  return await self._run_remote(task, *args, **kwargs)
435
454
  task = cast(TaskTemplate, task)
flyte/_task.py CHANGED
@@ -361,6 +361,7 @@ class AsyncFunctionTaskTemplate(TaskTemplate[P, R]):
361
361
  """
362
362
 
363
363
  func: FunctionTypes
364
+ plugin_config: Optional[Any] = None # This is used to pass plugin specific configuration
364
365
 
365
366
  def __post_init__(self):
366
367
  super().__post_init__()
@@ -60,8 +60,9 @@ class TaskEnvironment(Environment):
60
60
  :param reusable: Reuse policy for the environment, if set, a python process may be reused for multiple tasks.
61
61
  """
62
62
 
63
- cache: Union[CacheRequest] = "auto"
63
+ cache: Union[CacheRequest] = "disable"
64
64
  reusable: ReusePolicy | None = None
65
+ plugin_config: Optional[Any] = None
65
66
  # TODO Shall we make this union of string or env? This way we can lookup the env by module/file:name
66
67
  # TODO also we could add list of files that are used by this environment
67
68
 
@@ -133,6 +134,8 @@ class TaskEnvironment(Environment):
133
134
  used.
134
135
  :param report: Optional Whether to generate the html report for the task, defaults to False.
135
136
  """
137
+ from ._task import P, R
138
+
136
139
  if self.reusable is not None:
137
140
  if pod_template is not None:
138
141
  raise ValueError("Cannot set pod_template when environment is reusable.")
@@ -141,7 +144,22 @@ class TaskEnvironment(Environment):
141
144
  friendly_name = name or func.__name__
142
145
  task_name = self.name + "." + func.__name__
143
146
 
144
- tmpl: AsyncFunctionTaskTemplate = AsyncFunctionTaskTemplate(
147
+ if self.plugin_config is not None:
148
+ from flyte.extend import TaskPluginRegistry
149
+
150
+ task_template_class: type[AsyncFunctionTaskTemplate[P, R]] | None = TaskPluginRegistry.find(
151
+ config_type=type(self.plugin_config)
152
+ )
153
+ if task_template_class is None:
154
+ raise ValueError(
155
+ f"No task plugin found for config type {type(self.plugin_config)}. "
156
+ f"Please register a plugin using flyte.extend.TaskPluginRegistry.register() api."
157
+ )
158
+ else:
159
+ task_template_class = AsyncFunctionTaskTemplate
160
+
161
+ task_template_class = cast(type[AsyncFunctionTaskTemplate[P, R]], task_template_class)
162
+ tmpl = task_template_class(
145
163
  func=func,
146
164
  name=task_name,
147
165
  image=self.image,
@@ -158,6 +176,7 @@ class TaskEnvironment(Environment):
158
176
  interface=NativeInterface.from_callable(func),
159
177
  report=report,
160
178
  friendly_name=friendly_name,
179
+ plugin_config=self.plugin_config,
161
180
  )
162
181
  self._tasks[task_name] = tmpl
163
182
  return tmpl
flyte/_task_plugins.py ADDED
@@ -0,0 +1,45 @@
1
+ import typing
2
+ from typing import Type
3
+
4
+ import rich.repr
5
+
6
+ if typing.TYPE_CHECKING:
7
+ from ._task import AsyncFunctionTaskTemplate
8
+
9
+ T = typing.TypeVar("T", bound="AsyncFunctionTaskTemplate")
10
+
11
+
12
+ class _Registry:
13
+ """
14
+ A registry for task plugins.
15
+ """
16
+
17
+ def __init__(self):
18
+ self._plugins: typing.Dict[Type, Type[T]] = {}
19
+
20
+ def register(self, config_type: Type, plugin: Type[T]):
21
+ """
22
+ Register a plugin.
23
+ """
24
+ self._plugins[config_type] = plugin
25
+
26
+ def find(self, config_type: Type) -> typing.Optional[Type[T]]:
27
+ """
28
+ Get the plugin registered for the given config type, or None if no plugin is registered for it.
29
+ """
30
+ return self._plugins.get(config_type)
31
+
32
+ def list_plugins(self):
33
+ """
34
+ List all registered plugins.
35
+ """
36
+ return list(self._plugins.keys())
37
+
38
+ def __rich_repr__(self) -> "rich.repr.Result":
39
+ yield from (("Name", i) for i in self.list_plugins())
40
+
41
+ def __repr__(self):
42
+ return f"TaskPluginRegistry(plugins={self.list_plugins()})"
43
+
44
+
45
+ TaskPluginRegistry = _Registry()
flyte/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.2.0b22'
21
- __version_tuple__ = version_tuple = (0, 2, 0, 'b22')
20
+ __version__ = version = '0.2.0b24'
21
+ __version_tuple__ = version_tuple = (0, 2, 0, 'b24')
flyte/cli/_common.py CHANGED
@@ -151,7 +151,9 @@ class InvokeBaseMixin:
151
151
  # If the user has requested verbose output, print the full traceback
152
152
  console = Console()
153
153
  console.print(Traceback.from_exception(type(e), e, e.__traceback__))
154
- raise click.ClickException(f"Error invoking command: {e}") from e
154
+ exit(1)
155
+ else:
156
+ raise click.ClickException(f"Error invoking command: {e}") from e
155
157
 
156
158
 
157
159
  class CommandBase(InvokeBaseMixin, click.RichCommand):
flyte/errors.py CHANGED
@@ -157,9 +157,9 @@ class RuntimeDataValidationError(RuntimeUserError):
157
157
  This error is raised when the user tries to access a resource that does not exist or is invalid.
158
158
  """
159
159
 
160
- def __init__(self, var: str, e: Exception, task_name: str = ""):
160
+ def __init__(self, var: str, e: Exception | str, task_name: str = ""):
161
161
  super().__init__(
162
- "DataValiationError", f"In task {task_name} variable {var}, failed to serialize/deserialize because {e}"
162
+ "DataValiationError", f"In task {task_name} variable {var}, failed to serialize/deserialize because of {e}"
163
163
  )
164
164
 
165
165
 
flyte/extend.py ADDED
@@ -0,0 +1,12 @@
1
+ from ._initialize import is_initialized
2
+ from ._resources import PRIMARY_CONTAINER_DEFAULT_NAME, pod_spec_from_resources
3
+ from ._task import AsyncFunctionTaskTemplate
4
+ from ._task_plugins import TaskPluginRegistry
5
+
6
+ __all__ = [
7
+ "PRIMARY_CONTAINER_DEFAULT_NAME",
8
+ "AsyncFunctionTaskTemplate",
9
+ "TaskPluginRegistry",
10
+ "is_initialized",
11
+ "pod_spec_from_resources",
12
+ ]
@@ -1315,7 +1315,8 @@ class TypeEngine(typing.Generic[T]):
1315
1315
  try:
1316
1316
  return transformer.guess_python_type(flyte_type)
1317
1317
  except ValueError:
1318
- logger.debug(f"Skipping transformer {transformer.name} for {flyte_type}")
1318
+ # Skipping transformer
1319
+ continue
1319
1320
 
1320
1321
  # Because the dataclass transformer is handled explicitly in the get_transformer code, we have to handle it
1321
1322
  # separately here too.
@@ -1438,6 +1439,19 @@ def _type_essence(x: types_pb2.LiteralType) -> types_pb2.LiteralType:
1438
1439
 
1439
1440
 
1440
1441
  def _are_types_castable(upstream: types_pb2.LiteralType, downstream: types_pb2.LiteralType) -> bool:
1442
+ if upstream.union_type is not None:
1443
+ # for each upstream variant, there must be a compatible type downstream
1444
+ for v in upstream.union_type.variants:
1445
+ if not _are_types_castable(v, downstream):
1446
+ return False
1447
+ return True
1448
+
1449
+ if downstream.union_type is not None:
1450
+ # there must be a compatible downstream type
1451
+ for v in downstream.union_type.variants:
1452
+ if _are_types_castable(upstream, v):
1453
+ return True
1454
+
1441
1455
  if upstream.HasField("collection_type"):
1442
1456
  if not downstream.HasField("collection_type"):
1443
1457
  return False
@@ -1483,7 +1497,7 @@ def _are_types_castable(upstream: types_pb2.LiteralType, downstream: types_pb2.L
1483
1497
 
1484
1498
  return True
1485
1499
 
1486
- if upstream.HasField("union_type"):
1500
+ if upstream.HasField("union_type") and upstream.union_type is not None:
1487
1501
  # for each upstream variant, there must be a compatible type downstream
1488
1502
  for v in upstream.union_type.variants:
1489
1503
  if not _are_types_castable(v, downstream):
@@ -0,0 +1,169 @@
1
+ """
2
+ Flyte runtime module, this is the entrypoint script for the Flyte runtime.
3
+
4
+ Caution: Startup time for this module is very important, as it is the entrypoint for the Flyte runtime.
5
+ Refrain from importing any modules here. If you need to import any modules, do it inside the main function.
6
+ """
7
+
8
+ import asyncio
9
+ import os
10
+ import sys
11
+ from typing import Any, List
12
+
13
+ import click
14
+
15
+ # Todo: work with pvditt to make these the names
16
+ # ACTION_NAME = "_U_ACTION_NAME"
17
+ # RUN_NAME = "_U_RUN_NAME"
18
+ # PROJECT_NAME = "_U_PROJECT_NAME"
19
+ # DOMAIN_NAME = "_U_DOMAIN_NAME"
20
+ # ORG_NAME = "_U_ORG_NAME"
21
+
22
+ ACTION_NAME = "ACTION_NAME"
23
+ RUN_NAME = "RUN_NAME"
24
+ PROJECT_NAME = "FLYTE_INTERNAL_TASK_PROJECT"
25
+ DOMAIN_NAME = "FLYTE_INTERNAL_TASK_DOMAIN"
26
+ ORG_NAME = "_U_ORG_NAME"
27
+ ENDPOINT_OVERRIDE = "_U_EP_OVERRIDE"
28
+ RUN_OUTPUT_BASE_DIR = "_U_RUN_BASE"
29
+ ENABLE_REF_TASKS = "_REF_TASKS" # This is a temporary flag to enable reference tasks in the runtime.
30
+
31
+ # TODO: Remove this after proper auth is implemented
32
+ _UNION_EAGER_API_KEY_ENV_VAR = "_UNION_EAGER_API_KEY"
33
+
34
+
35
+ @click.group()
36
+ def _pass_through():
37
+ pass
38
+
39
+
40
+ @_pass_through.command("a0")
41
+ @click.option("--inputs", "-i", required=True)
42
+ @click.option("--outputs-path", "-o", required=True)
43
+ @click.option("--version", "-v", required=True)
44
+ @click.option("--run-base-dir", envvar=RUN_OUTPUT_BASE_DIR, required=True)
45
+ @click.option("--raw-data-path", "-r", required=False)
46
+ @click.option("--checkpoint-path", "-c", required=False)
47
+ @click.option("--prev-checkpoint", "-p", required=False)
48
+ @click.option("--name", envvar=ACTION_NAME, required=False)
49
+ @click.option("--run-name", envvar=RUN_NAME, required=False)
50
+ @click.option("--project", envvar=PROJECT_NAME, required=False)
51
+ @click.option("--domain", envvar=DOMAIN_NAME, required=False)
52
+ @click.option("--org", envvar=ORG_NAME, required=False)
53
+ @click.option("--image-cache", required=False)
54
+ @click.option("--tgz", required=False)
55
+ @click.option("--pkl", required=False)
56
+ @click.option("--dest", required=False)
57
+ @click.option("--resolver", required=False)
58
+ @click.argument(
59
+ "resolver-args",
60
+ type=click.UNPROCESSED,
61
+ nargs=-1,
62
+ )
63
+ def main(
64
+ run_name: str,
65
+ name: str,
66
+ project: str,
67
+ domain: str,
68
+ org: str,
69
+ image_cache: str,
70
+ version: str,
71
+ inputs: str,
72
+ run_base_dir: str,
73
+ outputs_path: str,
74
+ raw_data_path: str,
75
+ checkpoint_path: str,
76
+ prev_checkpoint: str,
77
+ tgz: str,
78
+ pkl: str,
79
+ dest: str,
80
+ resolver: str,
81
+ resolver_args: List[str],
82
+ ):
83
+ sys.path.insert(0, ".")
84
+
85
+ import flyte
86
+ import flyte._utils as utils
87
+ from flyte._initialize import init
88
+ from flyte._internal.controllers import create_controller
89
+ from flyte._internal.imagebuild.image_builder import ImageCache
90
+ from flyte._internal.runtime.entrypoints import load_and_run_task
91
+ from flyte._logging import logger
92
+ from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath
93
+
94
+ logger.info(f"Initializing flyte runtime - version {flyte.__version__}")
95
+
96
+ assert org, "Org is required for now"
97
+ assert project, "Project is required"
98
+ assert domain, "Domain is required"
99
+ assert run_name, f"Run name is required {run_name}"
100
+ assert name, f"Action name is required {name}"
101
+
102
+ if run_name.startswith("{{"):
103
+ run_name = os.getenv("RUN_NAME", "")
104
+ if name.startswith("{{"):
105
+ name = os.getenv("ACTION_NAME", "")
106
+
107
+ # Figure out how to connect
108
+ # This detection of api key is a hack for now.
109
+ controller_kwargs: dict[str, Any] = {"insecure": False}
110
+ if api_key := os.getenv(_UNION_EAGER_API_KEY_ENV_VAR):
111
+ logger.info("Using api key from environment")
112
+ controller_kwargs["api_key"] = api_key
113
+ else:
114
+ ep = os.environ.get(ENDPOINT_OVERRIDE, "host.docker.internal:8090")
115
+ controller_kwargs["endpoint"] = ep
116
+ if "localhost" in ep or "docker" in ep:
117
+ controller_kwargs["insecure"] = True
118
+ logger.debug(f"Using controller endpoint: {ep} with kwargs: {controller_kwargs}")
119
+
120
+ bundle = CodeBundle(tgz=tgz, pkl=pkl, destination=dest, computed_version=version)
121
+ enable_ref_tasks = os.getenv(ENABLE_REF_TASKS, "false").lower() in ("true", "1", "yes")
122
+ # We init regular client here so that reference tasks can work
123
+ # Current reference tasks will not work with remote controller, because we create 2 different
124
+ # channels on different threads and this is not supported by grpcio or the auth system. It ends up leading to:
125
+ # File "src/python/grpcio/grpc/_cython/_cygrpc/aio/completion_queue.pyx.pxi", line 147,
126
+ # in grpc._cython.cygrpc.PollerCompletionQueue._handle_events
127
+ # BlockingIOError: [Errno 11] Resource temporarily unavailable
128
+ # TODO solution is to use a single channel for both controller and reference tasks, but this requires a refactor
129
+ if enable_ref_tasks:
130
+ logger.warning(
131
+ "Reference tasks are enabled. This will initialize client and you will see a BlockIOError. "
132
+ "This is harmless, but a nuisance. We are working on a fix."
133
+ )
134
+ init(org=org, project=project, domain=domain, **controller_kwargs)
135
+ else:
136
+ init()
137
+ # Controller is created with the same kwargs as init, so that it can be used to run tasks
138
+ controller = create_controller(ct="remote", **controller_kwargs)
139
+
140
+ ic = ImageCache.from_transport(image_cache) if image_cache else None
141
+
142
+ # Create a coroutine to load the task and run it
143
+ task_coroutine = load_and_run_task(
144
+ resolver=resolver,
145
+ resolver_args=resolver_args,
146
+ action=ActionID(name=name, run_name=run_name, project=project, domain=domain, org=org),
147
+ raw_data_path=RawDataPath(path=raw_data_path),
148
+ checkpoints=Checkpoints(checkpoint_path, prev_checkpoint),
149
+ code_bundle=bundle,
150
+ input_path=inputs,
151
+ output_path=outputs_path,
152
+ run_base_dir=run_base_dir,
153
+ version=version,
154
+ controller=controller,
155
+ image_cache=ic,
156
+ )
157
+ # Create a coroutine to watch for errors
158
+ controller_failure = controller.watch_for_errors()
159
+
160
+ # Run both coroutines concurrently and wait for first to finish and cancel the other
161
+ async def _run_and_stop():
162
+ await utils.run_coros(controller_failure, task_coroutine)
163
+ await controller.stop()
164
+
165
+ asyncio.run(_run_and_stop())
166
+
167
+
168
+ if __name__ == "__main__":
169
+ _pass_through()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flyte
3
- Version: 0.2.0b22
3
+ Version: 0.2.0b24
4
4
  Summary: Add your description here
5
5
  Author-email: Ketan Umare <kumare3@users.noreply.github.com>
6
6
  Requires-Python: >=3.10
@@ -1,35 +1,37 @@
1
- flyte/__init__.py,sha256=VjIdD4AyX7mYNoWDSqZ6tmNKo14ZFwfToKeb49SB60U,1733
1
+ flyte/__init__.py,sha256=jWSynBJyJ0WuSjE8HID2MT386ZNlRzx7LzCHW1y_XNw,1468
2
2
  flyte/_build.py,sha256=MkgfLAPeL56YeVrGRNZUCZgbwzlEzVP3wLbl5Qru4yk,578
3
3
  flyte/_context.py,sha256=K0-TCt-_pHOoE5Xni87_8uIe2vCBOhfNQEtjGT4Hu4k,5239
4
4
  flyte/_deploy.py,sha256=HU2ksZUvHj77DrJm7MryN0n26DUJqo4h7-op4fUTBUg,9593
5
5
  flyte/_doc.py,sha256=_OPCf3t_git6UT7kSJISFaWO9cfNzJhhoe6JjVdyCJo,706
6
6
  flyte/_docstring.py,sha256=SsG0Ab_YMAwy2ABJlEo3eBKlyC3kwPdnDJ1FIms-ZBQ,1127
7
- flyte/_environment.py,sha256=dmIFmFLRIklOEYL9gsP2IH3-MYcjHYyyOlqlcf3E6_A,2924
7
+ flyte/_environment.py,sha256=oKVXLBX0ky2eE_wjBdzvQGI_2LiT2Nbx58ur7GMt50c,3231
8
8
  flyte/_excepthook.py,sha256=nXts84rzEg6-7RtFarbKzOsRZTQR4plnbWVIFMAEprs,1310
9
9
  flyte/_group.py,sha256=7o1j16sZyUmYB50mOiq1ui4TBAKhRpDqLakV8Ya1kw4,803
10
10
  flyte/_hash.py,sha256=Of_Zl_DzzzF2jp4ZsLm-3o-xJFCCJ8_GubmLI1htx78,504
11
- flyte/_image.py,sha256=6d_bxy2KOn0_wj895iPOxXKN5uzygUeT2fOhxG95dg4,30633
11
+ flyte/_image.py,sha256=ksG_BuDqQyUrrgkoI-I2FWvHvj0Zt8h1x9FAO2DryXI,31439
12
12
  flyte/_initialize.py,sha256=xKl_LYMluRt21wWqa6RTKuLo0_DCbSaTfUk27_brtNk,18232
13
13
  flyte/_interface.py,sha256=1B9zIwFDjiVp_3l_mk8EpA4g3Re-6DUBEBi9z9vDvPs,3504
14
14
  flyte/_logging.py,sha256=_yNo-Nx2yzh0MLoZGbnIYHGKei4wtQmSGM0lE30Ev7w,3662
15
15
  flyte/_map.py,sha256=efPd8O-JKUg1OY3_MzU3KGbhsGYDVRNBwWr0ceNIXhQ,7444
16
- flyte/_pod.py,sha256=lNaQuWX22QG6Xji7-8GpuKUkqCmVFaRxOVU-eUEa-Vk,637
17
- flyte/_resources.py,sha256=UOLyEVhdxolvrHhddiBbYdJuE1RkM_l7xeS9G1abe6M,7583
16
+ flyte/_pod.py,sha256=--72b0c6IkOEbBwZPLmgl-ll-j7ECfG-kh75LzBnNN8,1068
17
+ flyte/_resources.py,sha256=L2JuvQDlMo1JLJeUmJPRwtWbunhR2xJEhFgQW5yc72c,9690
18
18
  flyte/_retry.py,sha256=rfLv0MvWxzPByKESTglEmjPsytEAKiIvvmzlJxXwsfE,941
19
19
  flyte/_reusable_environment.py,sha256=P4FBATVKAYcIKpdFN98sI8acPyKy8eIGx6V0kUb9YdM,1289
20
- flyte/_run.py,sha256=HDObXGfZKBafebuh1OR5e1NWLUSbBDcceIz3kX2GMss,22828
20
+ flyte/_run.py,sha256=ZdpnWNkMTWLA7QZLzczaBdcgWfzv6DMMUAe5OQdv_Ik,23227
21
21
  flyte/_secret.py,sha256=SqIHs6mi8hEkIIBZe3bI9jJsPt65Mt6dV5uh9_op1ME,2392
22
- flyte/_task.py,sha256=-3VLSROj8g3-ZWzV272Ww7mt5yIeA753kmpNyr75I58,18087
23
- flyte/_task_environment.py,sha256=B493b54xpnNFNtHek3cmT4XuNiPNVf2gx9Yylyb1ncw,6905
22
+ flyte/_task.py,sha256=uXKhOn5Avz4TmUUlDbJ5f-XnOr1C4ps-lB8aP5celKE,18181
23
+ flyte/_task_environment.py,sha256=u2JZCtP1A3C_xCILQGFkfRdPgsyjDSORcCznZFkUZ2E,7779
24
+ flyte/_task_plugins.py,sha256=9MH3nFPOH_e8_92BT4sFk4oyAnj6GJFvaPYWaraX7yE,1037
24
25
  flyte/_timeout.py,sha256=zx5sFcbYmjJAJbZWSGzzX-BpC9HC7Jfs35T7vVhKwkk,1571
25
26
  flyte/_tools.py,sha256=JewkQZBR_M85tS6QY8e4xXue75jbOE48nID4ZHnc9jY,632
26
27
  flyte/_trace.py,sha256=C788bgoSc3st8kE8Cae2xegnLx2CT6uuRKKfaDrDUys,5122
27
- flyte/_version.py,sha256=EDrKK4XJyvTSZe8-ZzLNFdqQWtheoqh-sXUajAe6Zug,521
28
- flyte/errors.py,sha256=skXcdexLisFZFcTnUmMvMmuh4Ty96oJkyLKaipzkyeI,4954
28
+ flyte/_version.py,sha256=FFwCuNAGsDblVk5R0UiFXYSEp_KeGkD1OMVoRLzyNbc,521
29
+ flyte/errors.py,sha256=IUXHuEgszis9ZyhBUOSSlZx7VPQk0BF4N875yj1FBKo,4963
30
+ flyte/extend.py,sha256=GB4ZedGzKa30vYWRVPOdxEeK62xnUVFY4z2tD6H9eEw,376
29
31
  flyte/models.py,sha256=my7Vxo-NK6gHGahyqtHr42wYjPGw0nl2SGBBoSb6hIc,14733
30
32
  flyte/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
31
33
  flyte/_bin/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
32
- flyte/_bin/runtime.py,sha256=_ADnpWY4nVvjaujfcx2pBpOskAvnm0WtXBf1uPnZIAQ,6230
34
+ flyte/_bin/runtime.py,sha256=kiRtbRxfFy2zQn6T2-Qfnb3BYopt7FIXBjSCBO63UL4,6334
33
35
  flyte/_cache/__init__.py,sha256=zhdO5UuHQRdzn8GHmSN40nrxfAmI4ihDRuHZM11U84Y,305
34
36
  flyte/_cache/cache.py,sha256=ErhWzzJdEjTIuEF4f-r6IBgko-3Al9iUs1Eq4O42TUE,5021
35
37
  flyte/_cache/defaults.py,sha256=gzJZW0QJPUfd2OPnGpv3tzIfwPtgFjAKoie3NP1P97U,217
@@ -44,14 +46,14 @@ flyte/_internal/controllers/__init__.py,sha256=5CBnS9lb1VFMzZuRXUiaPhlN3G9qh7Aq9
44
46
  flyte/_internal/controllers/_local_controller.py,sha256=Wpgtd50C_ovIHpQSZC6asQc7iKyKIraEf-MAHCwcNJI,7124
45
47
  flyte/_internal/controllers/_trace.py,sha256=biI-lXSIe3gXuWI-KT6T-jTtojQCQ7BLOHTCG3J6MQc,1145
46
48
  flyte/_internal/controllers/remote/__init__.py,sha256=9_azH1eHLqY6VULpDugXi7Kf1kK1ODqEnsQ_3wM6IqU,1919
47
- flyte/_internal/controllers/remote/_action.py,sha256=aGmTCciHkx9YSU6gPWWFsgjIan-mgelHfmZM4SzZJS4,5121
49
+ flyte/_internal/controllers/remote/_action.py,sha256=tB3GfcB4-PcqaupeoFKxqkodFuNCaruHbXHFGkPaTFM,6040
48
50
  flyte/_internal/controllers/remote/_client.py,sha256=HPbzbfaWZVv5wpOvKNtFXR6COiZDwd1cUJQqi60A7oU,1421
49
- flyte/_internal/controllers/remote/_controller.py,sha256=G4gEUl-2f4FqACz3IK7sK0-w93kkJndzkDpbISqo_4w,21247
50
- flyte/_internal/controllers/remote/_core.py,sha256=UKjiL3dGidXcR1dPzJswn67g38HQmVXV7n0zoX_8AZY,18422
51
+ flyte/_internal/controllers/remote/_controller.py,sha256=hTy1cfCVBUxBtQg4ibTzL6RUiJyVfO2xTmMU7fZzXNY,22692
52
+ flyte/_internal/controllers/remote/_core.py,sha256=FJe1ZAWu_0w5SQUFgAHY4Hjjf_AjuI8sYu_orFkPFeU,18889
51
53
  flyte/_internal/controllers/remote/_informer.py,sha256=StiPcQLLW0h36uEBhKsupMY79EeFCKA3QQzvv2IyvRo,14188
52
54
  flyte/_internal/controllers/remote/_service_protocol.py,sha256=B9qbIg6DiGeac-iSccLmX_AL2xUgX4ezNUOiAbSy4V0,1357
53
- flyte/_internal/imagebuild/__init__.py,sha256=4fPz5RN3QB1DGw5O2sZo7x71F3EWhyE1jmCCqy4iezk,573
54
- flyte/_internal/imagebuild/docker_builder.py,sha256=LIfqtLQ4NVxkjgWlASiIi_xqvxhztlUDTVt4P7McLDo,15946
55
+ flyte/_internal/imagebuild/__init__.py,sha256=dwXdJ1jMhw9RF8itF7jkPLanvX1yCviSns7hE5eoIts,102
56
+ flyte/_internal/imagebuild/docker_builder.py,sha256=4-AJLX9CgPmVblzV-caIn3uvHLIlBMntOUKUmC-ptM8,15927
55
57
  flyte/_internal/imagebuild/image_builder.py,sha256=dXBXl62qcPabus6dR3eP8P9mBGNhpZHZ2Xm12AymKkk,11150
56
58
  flyte/_internal/imagebuild/remote_builder.py,sha256=wm-D3xhdrk8OSQxeyI0nfMpw3P1gW7mqDKAOKOnwY6o,10089
57
59
  flyte/_internal/resolvers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -59,11 +61,11 @@ flyte/_internal/resolvers/_task_module.py,sha256=jwy1QYygUK7xmpCZLt1SPTfJCkfox3C
59
61
  flyte/_internal/resolvers/common.py,sha256=ADQLRoyGsJ4vuUkitffMGrMKKjy0vpk6X53g4FuKDLc,993
60
62
  flyte/_internal/resolvers/default.py,sha256=nX4DHUYod1nRvEsl_vSgutQVEdExu2xL8pRkyi4VWbY,981
61
63
  flyte/_internal/runtime/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
62
- flyte/_internal/runtime/convert.py,sha256=rvKXEDQxFLyAslSJeOrqeG1kkXaYxHWxkBMloaZNCPc,13795
64
+ flyte/_internal/runtime/convert.py,sha256=W0nMLZU3GkTGtIg5ZAz0ioIk3aitbaK2M2qbSe26VDQ,15079
63
65
  flyte/_internal/runtime/entrypoints.py,sha256=Kyi19i7LYk7YM3ZV_Y4FXGt5Pc1tIftGkIDohopblyY,5127
64
66
  flyte/_internal/runtime/io.py,sha256=Lgdy4iPjlKjUO-V_AkoPZff6lywaFjZUG-PErRukmx4,4248
65
67
  flyte/_internal/runtime/resources_serde.py,sha256=TObMVsSjVcQhcY8-nY81pbvrz7TP-adDD5xV-LqAaxM,4813
66
- flyte/_internal/runtime/task_serde.py,sha256=wjYZSr0aXCvo6egWLlTrrRgYbx-V6L1clY2jstJNQlk,14039
68
+ flyte/_internal/runtime/task_serde.py,sha256=x24z2U_StqeavhQ5U5ljrHLMlSOjz-jOrLPliqOKrxM,14479
67
69
  flyte/_internal/runtime/taskrunner.py,sha256=rHWS4t5qgZnzGdGrs0_O0sSs_PVGoE1CNPDb-fTwwmo,7332
68
70
  flyte/_internal/runtime/types_serde.py,sha256=EjRh9Yypx9-20XXQprtNgp766LeQVRoYWtY6XPGMZQg,1813
69
71
  flyte/_protos/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -153,7 +155,7 @@ flyte/_utils/uv_script_parser.py,sha256=PxqD8lSMi6xv0uDd1s8LKB2IPZr4ttZJCUweqlyM
153
155
  flyte/cli/__init__.py,sha256=aeCcumeP9xD_5aCmaRYUPCe2QRJSGCaxcUbTZ3co768,341
154
156
  flyte/cli/_abort.py,sha256=Ty-63Gtd2PUn6lCuL5AaasfBoPu7TDSU5EQKVbkF4qw,661
155
157
  flyte/cli/_build.py,sha256=SBgybTVWOZ22VBHFL8CVFB_oo34lF9wvlwNirYFFyk0,3543
156
- flyte/cli/_common.py,sha256=CFBmdc0c59XqWb7PuZfLr7Z55HXnlqKaxDCq1sgY9MU,11115
158
+ flyte/cli/_common.py,sha256=prQPs8_qSNCvhBPgHa6AY8c9OwutWw3js9jd5vdL_kk,11161
157
159
  flyte/cli/_create.py,sha256=8U4kePYmfW02j54184IurNYiTGvjdr3J-ERQM9goHdg,4715
158
160
  flyte/cli/_delete.py,sha256=VTmXv09PBjkdtyl23mbSjIQQlN7Y1AI_bO0GkHP-f9E,546
159
161
  flyte/cli/_deploy.py,sha256=Zxm7vn1zrqmll73CJTiVGYJI95P-XBI1AlhOlmbmkD0,4635
@@ -219,10 +221,11 @@ flyte/types/_interface.py,sha256=5y9EC5r897xz03Hh0ltF8QVGKMfMfAznws-hKSEO4Go,167
219
221
  flyte/types/_pickle.py,sha256=PjdR66OTDMZ3OYq6GvM_Ua0cIo5t2XQaIjmpJ9xo4Ys,4050
220
222
  flyte/types/_renderer.py,sha256=ygcCo5l60lHufyQISFddZfWwLlQ8kJAKxUT_XnR_6dY,4818
221
223
  flyte/types/_string_literals.py,sha256=NlG1xV8RSA-sZ-n-IFQCAsdB6jXJOAKkHWtnopxVVDk,4231
222
- flyte/types/_type_engine.py,sha256=CCjpqXNX2BMza2cKq42hJXwabWy8GWsimsgiGZJ_00M,96583
224
+ flyte/types/_type_engine.py,sha256=qIrXXBFqv4CydzLW-9iBeTKK2Xo08DHtEYbF9iYDkbU,97087
223
225
  flyte/types/_utils.py,sha256=pbts9E1_2LTdLygAY0UYTLYJ8AsN3BZyviSXvrtcutc,2626
224
- flyte-0.2.0b22.dist-info/METADATA,sha256=RAb1kOFofWEX_kiGOtmRH3HNRfNOu_1iUznOooQsrUI,5857
225
- flyte-0.2.0b22.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
226
- flyte-0.2.0b22.dist-info/entry_points.txt,sha256=MIq2z5dBurdCJfpXfMKzgBv7sJOakKRYxr8G0cMiTrg,75
227
- flyte-0.2.0b22.dist-info/top_level.txt,sha256=7dkyFbikvA12LEZEqawx8oDG1CMod6hTliPj7iWzgYo,6
228
- flyte-0.2.0b22.dist-info/RECORD,,
226
+ flyte-0.2.0b24.data/scripts/runtime.py,sha256=kiRtbRxfFy2zQn6T2-Qfnb3BYopt7FIXBjSCBO63UL4,6334
227
+ flyte-0.2.0b24.dist-info/METADATA,sha256=fi8nuO9osztTEnuZkk9tVSxMWKP3sC2xqgIZVJ_aW44,5857
228
+ flyte-0.2.0b24.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
229
+ flyte-0.2.0b24.dist-info/entry_points.txt,sha256=MIq2z5dBurdCJfpXfMKzgBv7sJOakKRYxr8G0cMiTrg,75
230
+ flyte-0.2.0b24.dist-info/top_level.txt,sha256=7dkyFbikvA12LEZEqawx8oDG1CMod6hTliPj7iWzgYo,6
231
+ flyte-0.2.0b24.dist-info/RECORD,,