flyte-0.2.0b23-py3-none-any.whl → flyte-0.2.0b25-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic; see the registry page for this release for more details.

flyte/_run.py CHANGED
@@ -6,13 +6,9 @@ import uuid
6
6
  from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union, cast
7
7
 
8
8
  import flyte.errors
9
- from flyte.errors import InitializationError
10
- from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath, SerializationContext, TaskContext
11
- from flyte.syncify import syncify
12
-
13
- from ._context import contextual_run, internal_ctx
14
- from ._environment import Environment
15
- from ._initialize import (
9
+ from flyte._context import contextual_run, internal_ctx
10
+ from flyte._environment import Environment
11
+ from flyte._initialize import (
16
12
  _get_init_config,
17
13
  get_client,
18
14
  get_common_config,
@@ -20,9 +16,19 @@ from ._initialize import (
20
16
  requires_initialization,
21
17
  requires_storage,
22
18
  )
23
- from ._logging import logger
24
- from ._task import P, R, TaskTemplate
25
- from ._tools import ipython_check
19
+ from flyte._logging import logger
20
+ from flyte._task import P, R, TaskTemplate
21
+ from flyte._tools import ipython_check
22
+ from flyte.errors import InitializationError
23
+ from flyte.models import (
24
+ ActionID,
25
+ Checkpoints,
26
+ CodeBundle,
27
+ RawDataPath,
28
+ SerializationContext,
29
+ TaskContext,
30
+ )
31
+ from flyte.syncify import syncify
26
32
 
27
33
  if TYPE_CHECKING:
28
34
  from flyte.remote import Run
@@ -132,7 +138,9 @@ class _Runner:
132
138
 
133
139
  if self._interactive_mode:
134
140
  code_bundle = await build_pkl_bundle(
135
- obj, upload_to_controlplane=not self._dry_run, copy_bundle_to=self._copy_bundle_to
141
+ obj,
142
+ upload_to_controlplane=not self._dry_run,
143
+ copy_bundle_to=self._copy_bundle_to,
136
144
  )
137
145
  else:
138
146
  if self._copy_files != "none":
@@ -253,7 +261,8 @@ class _Runner:
253
261
  pb2=run_definition_pb2.Run(
254
262
  action=run_definition_pb2.Action(
255
263
  id=run_definition_pb2.ActionIdentifier(
256
- name="a0", run=run_definition_pb2.RunIdentifier(name="dry-run")
264
+ name="a0",
265
+ run=run_definition_pb2.RunIdentifier(name="dry-run"),
257
266
  )
258
267
  )
259
268
  )
@@ -286,7 +295,7 @@ class _Runner:
286
295
  if obj.parent_env is None:
287
296
  raise ValueError("Task is not attached to an environment. Please attach the task to an environment.")
288
297
 
289
- image_cache = build_images.aio(cast(Environment, obj.parent_env()))
298
+ image_cache = await build_images.aio(cast(Environment, obj.parent_env()))
290
299
 
291
300
  code_bundle = None
292
301
  if self._name is not None:
@@ -296,7 +305,9 @@ class _Runner:
296
305
  if not code_bundle:
297
306
  if self._interactive_mode:
298
307
  code_bundle = await build_pkl_bundle(
299
- obj, upload_to_controlplane=not self._dry_run, copy_bundle_to=self._copy_bundle_to
308
+ obj,
309
+ upload_to_controlplane=not self._dry_run,
310
+ copy_bundle_to=self._copy_bundle_to,
300
311
  )
301
312
  else:
302
313
  if self._copy_files != "none":
@@ -381,7 +392,8 @@ class _Runner:
381
392
  tctx = TaskContext(
382
393
  action=action,
383
394
  checkpoints=Checkpoints(
384
- prev_checkpoint_path=internal_ctx().raw_data.path, checkpoint_path=internal_ctx().raw_data.path
395
+ prev_checkpoint_path=internal_ctx().raw_data.path,
396
+ checkpoint_path=internal_ctx().raw_data.path,
385
397
  ),
386
398
  code_bundle=None,
387
399
  output_path=self._metadata_path,
@@ -403,7 +415,10 @@ class _Runner:
403
415
 
404
416
  @syncify
405
417
  async def run(
406
- self, task: TaskTemplate[P, Union[R, Run]] | LazyEntity, *args: P.args, **kwargs: P.kwargs
418
+ self,
419
+ task: TaskTemplate[P, Union[R, Run]] | LazyEntity,
420
+ *args: P.args,
421
+ **kwargs: P.kwargs,
407
422
  ) -> Union[R, Run]:
408
423
  """
409
424
  Run an async `@env.task` or `TaskTemplate` instance. The existing async context will be used.
@@ -430,6 +445,10 @@ class _Runner:
430
445
 
431
446
  if isinstance(task, LazyEntity) and self._mode != "remote":
432
447
  raise ValueError("Remote task can only be run in remote mode.")
448
+
449
+ if not isinstance(task, TaskTemplate) and not isinstance(task, LazyEntity):
450
+ raise TypeError("On Flyte tasks can be run, not generic functions or methods.")
451
+
433
452
  if self._mode == "remote":
434
453
  return await self._run_remote(task, *args, **kwargs)
435
454
  task = cast(TaskTemplate, task)
flyte/_secret.py CHANGED
@@ -45,6 +45,27 @@ class Secret:
45
45
  if not re.match(pattern, self.as_env_var):
46
46
  raise ValueError(f"Invalid environment variable name: {self.as_env_var}, must match {pattern}")
47
47
 
48
+ def stable_hash(self) -> str:
49
+ """
50
+ Deterministic, process-independent hash (as hex string).
51
+ """
52
+ import hashlib
53
+
54
+ data = (
55
+ self.key,
56
+ self.group or "",
57
+ str(self.mount) if self.mount else "",
58
+ self.as_env_var or "",
59
+ )
60
+ joined = "|".join(data)
61
+ return hashlib.sha256(joined.encode("utf-8")).hexdigest()
62
+
63
+ def __hash__(self) -> int:
64
+ """
65
+ Deterministic hash function for the Secret class.
66
+ """
67
+ return int(self.stable_hash()[:16], 16)
68
+
48
69
 
49
70
  SecretRequest = Union[str, Secret, List[str | Secret]]
50
71
 
@@ -59,3 +80,12 @@ def secrets_from_request(secrets: SecretRequest) -> List[Secret]:
59
80
  return [secrets]
60
81
  else:
61
82
  return [Secret(key=s) if isinstance(s, str) else s for s in secrets]
83
+
84
+
85
+ if __name__ == "__main__":
86
+ # Example usage
87
+ secret1 = Secret(key="MY_SECRET", mount=pathlib.Path("/path/to/secret"), as_env_var="MY_SECRET_ENV")
88
+ secret2 = Secret(
89
+ key="ANOTHER_SECRET",
90
+ )
91
+ print(hash(secret1), hash(secret2))
flyte/_task.py CHANGED
@@ -3,7 +3,6 @@ from __future__ import annotations
3
3
  import asyncio
4
4
  import weakref
5
5
  from dataclasses import dataclass, field, replace
6
- from functools import cached_property
7
6
  from inspect import iscoroutinefunction
8
7
  from typing import (
9
8
  TYPE_CHECKING,
@@ -361,16 +360,13 @@ class AsyncFunctionTaskTemplate(TaskTemplate[P, R]):
361
360
  """
362
361
 
363
362
  func: FunctionTypes
363
+ plugin_config: Optional[Any] = None # This is used to pass plugin specific configuration
364
364
 
365
365
  def __post_init__(self):
366
366
  super().__post_init__()
367
367
  if not iscoroutinefunction(self.func):
368
368
  self._call_as_synchronous = True
369
369
 
370
- @cached_property
371
- def native_interface(self) -> NativeInterface:
372
- return NativeInterface.from_callable(self.func)
373
-
374
370
  def forward(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, R] | R:
375
371
  # In local execution, we want to just call the function. Note we're not awaiting anything here.
376
372
  # If the function was a coroutine function, the coroutine is returned and the await that the caller has
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import inspect
3
4
  import weakref
4
5
  from dataclasses import dataclass, field, replace
5
6
  from datetime import timedelta
@@ -60,13 +61,19 @@ class TaskEnvironment(Environment):
60
61
  :param reusable: Reuse policy for the environment, if set, a python process may be reused for multiple tasks.
61
62
  """
62
63
 
63
- cache: Union[CacheRequest] = "auto"
64
+ cache: Union[CacheRequest] = "disable"
64
65
  reusable: ReusePolicy | None = None
66
+ plugin_config: Optional[Any] = None
65
67
  # TODO Shall we make this union of string or env? This way we can lookup the env by module/file:name
66
68
  # TODO also we could add list of files that are used by this environment
67
69
 
68
70
  _tasks: Dict[str, TaskTemplate] = field(default_factory=dict, init=False)
69
71
 
72
+ def __post_init__(self) -> None:
73
+ super().__post_init__()
74
+ if self.reusable is not None and self.plugin_config is not None:
75
+ raise ValueError("Cannot set plugin_config when environment is reusable.")
76
+
70
77
  def clone_with(
71
78
  self,
72
79
  name: str,
@@ -133,6 +140,8 @@ class TaskEnvironment(Environment):
133
140
  used.
134
141
  :param report: Optional Whether to generate the html report for the task, defaults to False.
135
142
  """
143
+ from ._task import P, R
144
+
136
145
  if self.reusable is not None:
137
146
  if pod_template is not None:
138
147
  raise ValueError("Cannot set pod_template when environment is reusable.")
@@ -141,7 +150,29 @@ class TaskEnvironment(Environment):
141
150
  friendly_name = name or func.__name__
142
151
  task_name = self.name + "." + func.__name__
143
152
 
144
- tmpl: AsyncFunctionTaskTemplate = AsyncFunctionTaskTemplate(
153
+ if not inspect.iscoroutinefunction(func) and self.reusable is not None:
154
+ if self.reusable.concurrency > 1:
155
+ raise ValueError(
156
+ "Reusable environments with concurrency greater than 1 are only supported for async tasks. "
157
+ "Please use an async function or set concurrency to 1."
158
+ )
159
+
160
+ if self.plugin_config is not None:
161
+ from flyte.extend import TaskPluginRegistry
162
+
163
+ task_template_class: type[AsyncFunctionTaskTemplate[P, R]] | None = TaskPluginRegistry.find(
164
+ config_type=type(self.plugin_config)
165
+ )
166
+ if task_template_class is None:
167
+ raise ValueError(
168
+ f"No task plugin found for config type {type(self.plugin_config)}. "
169
+ f"Please register a plugin using flyte.extend.TaskPluginRegistry.register() api."
170
+ )
171
+ else:
172
+ task_template_class = AsyncFunctionTaskTemplate[P, R]
173
+
174
+ task_template_class = cast(type[AsyncFunctionTaskTemplate[P, R]], task_template_class)
175
+ tmpl = task_template_class(
145
176
  func=func,
146
177
  name=task_name,
147
178
  image=self.image,
@@ -158,6 +189,7 @@ class TaskEnvironment(Environment):
158
189
  interface=NativeInterface.from_callable(func),
159
190
  report=report,
160
191
  friendly_name=friendly_name,
192
+ plugin_config=self.plugin_config,
161
193
  )
162
194
  self._tasks[task_name] = tmpl
163
195
  return tmpl
flyte/_task_plugins.py ADDED
@@ -0,0 +1,45 @@
1
+ import typing
2
+ from typing import Type
3
+
4
+ import rich.repr
5
+
6
+ if typing.TYPE_CHECKING:
7
+ from ._task import AsyncFunctionTaskTemplate
8
+
9
+ T = typing.TypeVar("T", bound="AsyncFunctionTaskTemplate")
10
+
11
+
12
+ class _Registry:
13
+ """
14
+ A registry for task plugins.
15
+ """
16
+
17
+ def __init__(self):
18
+ self._plugins: typing.Dict[Type, Type[T]] = {}
19
+
20
+ def register(self, config_type: Type, plugin: Type[T]):
21
+ """
22
+ Register a plugin.
23
+ """
24
+ self._plugins[config_type] = plugin
25
+
26
+ def find(self, config_type: Type) -> typing.Optional[Type[T]]:
27
+ """
28
+ Get a plugin by name.
29
+ """
30
+ return self._plugins.get(config_type)
31
+
32
+ def list_plugins(self):
33
+ """
34
+ List all registered plugins.
35
+ """
36
+ return list(self._plugins.keys())
37
+
38
+ def __rich_repr__(self) -> "rich.repr.Result":
39
+ yield from (("Name", i) for i in self.list_plugins())
40
+
41
+ def __repr__(self):
42
+ return f"TaskPluginRegistry(plugins={self.list_plugins()})"
43
+
44
+
45
+ TaskPluginRegistry = _Registry()
flyte/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.2.0b23'
21
- __version_tuple__ = version_tuple = (0, 2, 0, 'b23')
20
+ __version__ = version = '0.2.0b25'
21
+ __version_tuple__ = version_tuple = (0, 2, 0, 'b25')
flyte/errors.py CHANGED
@@ -157,9 +157,9 @@ class RuntimeDataValidationError(RuntimeUserError):
157
157
  This error is raised when the user tries to access a resource that does not exist or is invalid.
158
158
  """
159
159
 
160
- def __init__(self, var: str, e: Exception, task_name: str = ""):
160
+ def __init__(self, var: str, e: Exception | str, task_name: str = ""):
161
161
  super().__init__(
162
- "DataValiationError", f"In task {task_name} variable {var}, failed to serialize/deserialize because {e}"
162
+ "DataValiationError", f"In task {task_name} variable {var}, failed to serialize/deserialize because of {e}"
163
163
  )
164
164
 
165
165
 
flyte/extend.py ADDED
@@ -0,0 +1,12 @@
1
+ from ._initialize import is_initialized
2
+ from ._resources import PRIMARY_CONTAINER_DEFAULT_NAME, pod_spec_from_resources
3
+ from ._task import AsyncFunctionTaskTemplate
4
+ from ._task_plugins import TaskPluginRegistry
5
+
6
+ __all__ = [
7
+ "PRIMARY_CONTAINER_DEFAULT_NAME",
8
+ "AsyncFunctionTaskTemplate",
9
+ "TaskPluginRegistry",
10
+ "is_initialized",
11
+ "pod_spec_from_resources",
12
+ ]
flyte/models.py CHANGED
@@ -256,6 +256,7 @@ class NativeInterface:
256
256
  _remote_defaults: Optional[Dict[str, literals_pb2.Literal]] = field(default=None, repr=False)
257
257
 
258
258
  has_default: ClassVar[Type[_has_default]] = _has_default # This can be used to indicate if a specific input
259
+
259
260
  # has a default value or not, in the case when the default value is not known. An example would be remote tasks.
260
261
 
261
262
  def has_outputs(self) -> bool:
@@ -298,7 +299,15 @@ class NativeInterface:
298
299
  sig = inspect.signature(func)
299
300
 
300
301
  # Extract parameter details (name, type, default value)
301
- param_info = {name: (param.annotation, param.default) for name, param in sig.parameters.items()}
302
+ param_info = {}
303
+ for name, param in sig.parameters.items():
304
+ if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
305
+ raise ValueError(f"Function {func.__name__} cannot have variable positional or keyword arguments.")
306
+ if param.annotation is inspect.Parameter.empty:
307
+ logger.warning(
308
+ f"Function {func.__name__} has parameter {name} without type annotation. Data will be pickled."
309
+ )
310
+ param_info[name] = (param.annotation, param.default)
302
311
 
303
312
  # Get return type
304
313
  outputs = extract_return_annotation(sig.return_annotation)
@@ -1315,7 +1315,8 @@ class TypeEngine(typing.Generic[T]):
1315
1315
  try:
1316
1316
  return transformer.guess_python_type(flyte_type)
1317
1317
  except ValueError:
1318
- logger.debug(f"Skipping transformer {transformer.name} for {flyte_type}")
1318
+ # Skipping transformer
1319
+ continue
1319
1320
 
1320
1321
  # Because the dataclass transformer is handled explicitly in the get_transformer code, we have to handle it
1321
1322
  # separately here too.
@@ -1438,6 +1439,19 @@ def _type_essence(x: types_pb2.LiteralType) -> types_pb2.LiteralType:
1438
1439
 
1439
1440
 
1440
1441
  def _are_types_castable(upstream: types_pb2.LiteralType, downstream: types_pb2.LiteralType) -> bool:
1442
+ if upstream.union_type is not None:
1443
+ # for each upstream variant, there must be a compatible type downstream
1444
+ for v in upstream.union_type.variants:
1445
+ if not _are_types_castable(v, downstream):
1446
+ return False
1447
+ return True
1448
+
1449
+ if downstream.union_type is not None:
1450
+ # there must be a compatible downstream type
1451
+ for v in downstream.union_type.variants:
1452
+ if _are_types_castable(upstream, v):
1453
+ return True
1454
+
1441
1455
  if upstream.HasField("collection_type"):
1442
1456
  if not downstream.HasField("collection_type"):
1443
1457
  return False
@@ -1483,7 +1497,7 @@ def _are_types_castable(upstream: types_pb2.LiteralType, downstream: types_pb2.L
1483
1497
 
1484
1498
  return True
1485
1499
 
1486
- if upstream.HasField("union_type"):
1500
+ if upstream.HasField("union_type") and upstream.union_type is not None:
1487
1501
  # for each upstream variant, there must be a compatible type downstream
1488
1502
  for v in upstream.union_type.variants:
1489
1503
  if not _are_types_castable(v, downstream):
@@ -0,0 +1,169 @@
1
+ """
2
+ Flyte runtime module, this is the entrypoint script for the Flyte runtime.
3
+
4
+ Caution: Startup time for this module is very important, as it is the entrypoint for the Flyte runtime.
5
+ Refrain from importing any modules here. If you need to import any modules, do it inside the main function.
6
+ """
7
+
8
+ import asyncio
9
+ import os
10
+ import sys
11
+ from typing import Any, List
12
+
13
+ import click
14
+
15
+ # Todo: work with pvditt to make these the names
16
+ # ACTION_NAME = "_U_ACTION_NAME"
17
+ # RUN_NAME = "_U_RUN_NAME"
18
+ # PROJECT_NAME = "_U_PROJECT_NAME"
19
+ # DOMAIN_NAME = "_U_DOMAIN_NAME"
20
+ # ORG_NAME = "_U_ORG_NAME"
21
+
22
+ ACTION_NAME = "ACTION_NAME"
23
+ RUN_NAME = "RUN_NAME"
24
+ PROJECT_NAME = "FLYTE_INTERNAL_TASK_PROJECT"
25
+ DOMAIN_NAME = "FLYTE_INTERNAL_TASK_DOMAIN"
26
+ ORG_NAME = "_U_ORG_NAME"
27
+ ENDPOINT_OVERRIDE = "_U_EP_OVERRIDE"
28
+ RUN_OUTPUT_BASE_DIR = "_U_RUN_BASE"
29
+ ENABLE_REF_TASKS = "_REF_TASKS" # This is a temporary flag to enable reference tasks in the runtime.
30
+
31
+ # TODO: Remove this after proper auth is implemented
32
+ _UNION_EAGER_API_KEY_ENV_VAR = "_UNION_EAGER_API_KEY"
33
+
34
+
35
+ @click.group()
36
+ def _pass_through():
37
+ pass
38
+
39
+
40
+ @_pass_through.command("a0")
41
+ @click.option("--inputs", "-i", required=True)
42
+ @click.option("--outputs-path", "-o", required=True)
43
+ @click.option("--version", "-v", required=True)
44
+ @click.option("--run-base-dir", envvar=RUN_OUTPUT_BASE_DIR, required=True)
45
+ @click.option("--raw-data-path", "-r", required=False)
46
+ @click.option("--checkpoint-path", "-c", required=False)
47
+ @click.option("--prev-checkpoint", "-p", required=False)
48
+ @click.option("--name", envvar=ACTION_NAME, required=False)
49
+ @click.option("--run-name", envvar=RUN_NAME, required=False)
50
+ @click.option("--project", envvar=PROJECT_NAME, required=False)
51
+ @click.option("--domain", envvar=DOMAIN_NAME, required=False)
52
+ @click.option("--org", envvar=ORG_NAME, required=False)
53
+ @click.option("--image-cache", required=False)
54
+ @click.option("--tgz", required=False)
55
+ @click.option("--pkl", required=False)
56
+ @click.option("--dest", required=False)
57
+ @click.option("--resolver", required=False)
58
+ @click.argument(
59
+ "resolver-args",
60
+ type=click.UNPROCESSED,
61
+ nargs=-1,
62
+ )
63
+ def main(
64
+ run_name: str,
65
+ name: str,
66
+ project: str,
67
+ domain: str,
68
+ org: str,
69
+ image_cache: str,
70
+ version: str,
71
+ inputs: str,
72
+ run_base_dir: str,
73
+ outputs_path: str,
74
+ raw_data_path: str,
75
+ checkpoint_path: str,
76
+ prev_checkpoint: str,
77
+ tgz: str,
78
+ pkl: str,
79
+ dest: str,
80
+ resolver: str,
81
+ resolver_args: List[str],
82
+ ):
83
+ sys.path.insert(0, ".")
84
+
85
+ import flyte
86
+ import flyte._utils as utils
87
+ from flyte._initialize import init
88
+ from flyte._internal.controllers import create_controller
89
+ from flyte._internal.imagebuild.image_builder import ImageCache
90
+ from flyte._internal.runtime.entrypoints import load_and_run_task
91
+ from flyte._logging import logger
92
+ from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath
93
+
94
+ logger.info(f"Initializing flyte runtime - version {flyte.__version__}")
95
+
96
+ assert org, "Org is required for now"
97
+ assert project, "Project is required"
98
+ assert domain, "Domain is required"
99
+ assert run_name, f"Run name is required {run_name}"
100
+ assert name, f"Action name is required {name}"
101
+
102
+ if run_name.startswith("{{"):
103
+ run_name = os.getenv("RUN_NAME", "")
104
+ if name.startswith("{{"):
105
+ name = os.getenv("ACTION_NAME", "")
106
+
107
+ # Figure out how to connect
108
+ # This detection of api key is a hack for now.
109
+ controller_kwargs: dict[str, Any] = {"insecure": False}
110
+ if api_key := os.getenv(_UNION_EAGER_API_KEY_ENV_VAR):
111
+ logger.info("Using api key from environment")
112
+ controller_kwargs["api_key"] = api_key
113
+ else:
114
+ ep = os.environ.get(ENDPOINT_OVERRIDE, "host.docker.internal:8090")
115
+ controller_kwargs["endpoint"] = ep
116
+ if "localhost" in ep or "docker" in ep:
117
+ controller_kwargs["insecure"] = True
118
+ logger.debug(f"Using controller endpoint: {ep} with kwargs: {controller_kwargs}")
119
+
120
+ bundle = CodeBundle(tgz=tgz, pkl=pkl, destination=dest, computed_version=version)
121
+ enable_ref_tasks = os.getenv(ENABLE_REF_TASKS, "false").lower() in ("true", "1", "yes")
122
+ # We init regular client here so that reference tasks can work
123
+ # Current reference tasks will not work with remote controller, because we create 2 different
124
+ # channels on different threads and this is not supported by grpcio or the auth system. It ends up leading
125
+ # File "src/python/grpcio/grpc/_cython/_cygrpc/aio/completion_queue.pyx.pxi", line 147,
126
+ # in grpc._cython.cygrpc.PollerCompletionQueue._handle_events
127
+ # BlockingIOError: [Errno 11] Resource temporarily unavailable
128
+ # TODO solution is to use a single channel for both controller and reference tasks, but this requires a refactor
129
+ if enable_ref_tasks:
130
+ logger.warning(
131
+ "Reference tasks are enabled. This will initialize client and you will see a BlockIOError. "
132
+ "This is harmless, but a nuisance. We are working on a fix."
133
+ )
134
+ init(org=org, project=project, domain=domain, **controller_kwargs)
135
+ else:
136
+ init()
137
+ # Controller is created with the same kwargs as init, so that it can be used to run tasks
138
+ controller = create_controller(ct="remote", **controller_kwargs)
139
+
140
+ ic = ImageCache.from_transport(image_cache) if image_cache else None
141
+
142
+ # Create a coroutine to load the task and run it
143
+ task_coroutine = load_and_run_task(
144
+ resolver=resolver,
145
+ resolver_args=resolver_args,
146
+ action=ActionID(name=name, run_name=run_name, project=project, domain=domain, org=org),
147
+ raw_data_path=RawDataPath(path=raw_data_path),
148
+ checkpoints=Checkpoints(checkpoint_path, prev_checkpoint),
149
+ code_bundle=bundle,
150
+ input_path=inputs,
151
+ output_path=outputs_path,
152
+ run_base_dir=run_base_dir,
153
+ version=version,
154
+ controller=controller,
155
+ image_cache=ic,
156
+ )
157
+ # Create a coroutine to watch for errors
158
+ controller_failure = controller.watch_for_errors()
159
+
160
+ # Run both coroutines concurrently and wait for first to finish and cancel the other
161
+ async def _run_and_stop():
162
+ await utils.run_coros(controller_failure, task_coroutine)
163
+ await controller.stop()
164
+
165
+ asyncio.run(_run_and_stop())
166
+
167
+
168
+ if __name__ == "__main__":
169
+ _pass_through()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flyte
3
- Version: 0.2.0b23
3
+ Version: 0.2.0b25
4
4
  Summary: Add your description here
5
5
  Author-email: Ketan Umare <kumare3@users.noreply.github.com>
6
6
  Requires-Python: >=3.10