flyte 0.2.0b7__py3-none-any.whl → 0.2.0b9__py3-none-any.whl

This diff reflects the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

This version of flyte has been flagged as a potentially problematic release.

flyte/__init__.py CHANGED
@@ -39,6 +39,7 @@ __all__ = [
     "group",
     "init",
     "init_auto_from_config",
+    "map",
     "run",
     "trace",
     "with_runcontext",
@@ -51,6 +52,7 @@ from ._environment import Environment
 from ._group import group
 from ._image import Image
 from ._initialize import init, init_auto_from_config
+from ._map import map
 from ._resources import GPU, TPU, Device, Resources
 from ._retry import RetryStrategy
 from ._reusable_environment import ReusePolicy
flyte/_context.py CHANGED
@@ -4,6 +4,7 @@ import contextvars
 from dataclasses import dataclass, replace
 from typing import TYPE_CHECKING, Awaitable, Callable, Optional, ParamSpec, TypeVar
 
+from flyte._logging import logger
 from flyte.models import GroupData, RawDataPath, TaskContext
 
 if TYPE_CHECKING:
@@ -49,6 +50,7 @@ class Context:
             raise ValueError("Cannot create a new context without contextdata.")
         self._data = data
         self._id = id(self)  # Immutable unique identifier
+        self._token = None  # Context variable token to restore the previous context
 
     @property
     def data(self) -> ContextData:
@@ -106,7 +108,11 @@ class Context:
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         """Exit the context, restoring the previous context."""
-        root_context_var.reset(self._token)
+        try:
+            root_context_var.reset(self._token)
+        except Exception as e:
+            logger.warn(f"Failed to reset context: {e}")
+            raise e
 
     async def __aenter__(self):
         """Async version of context entry."""
flyte/_environment.py CHANGED
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
+import re
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
 
 import rich.repr
 
@@ -14,6 +15,10 @@ if TYPE_CHECKING:
     from kubernetes.client import V1PodTemplate
 
 
+def is_snake_or_kebab_with_numbers(s: str) -> bool:
+    return re.fullmatch(r"^[a-z0-9]+([_-][a-z0-9]+)*$", s) is not None
+
+
 @rich.repr.auto
 @dataclass(init=True, repr=True)
 class Environment:
@@ -36,8 +41,44 @@ class Environment:
     resources: Optional[Resources] = None
     image: Union[str, Image, Literal["auto"]] = "auto"
 
+    def __post_init__(self):
+        if not is_snake_or_kebab_with_numbers(self.name):
+            raise ValueError(f"Environment name '{self.name}' must be in snake_case or kebab-case format.")
+
     def add_dependency(self, *env: Environment):
         """
         Add a dependency to the environment.
         """
         self.env_dep_hints.extend(env)
+
+    def clone_with(
+        self,
+        name: str,
+        image: Optional[Union[str, Image, Literal["auto"]]] = None,
+        resources: Optional[Resources] = None,
+        env: Optional[Dict[str, str]] = None,
+        secrets: Optional[SecretRequest] = None,
+        env_dep_hints: Optional[List[Environment]] = None,
+        **kwargs: Any,
+    ) -> Environment:
+        raise NotImplementedError
+
+    def _get_kwargs(self) -> Dict[str, Any]:
+        """
+        Get the keyword arguments for the environment.
+        """
+        kwargs: Dict[str, Any] = {
+            "env_dep_hints": self.env_dep_hints,
+            "image": self.image,
+        }
+        if self.resources is not None:
+            kwargs["resources"] = self.resources
+        if self.secrets is not None:
+            kwargs["secrets"] = self.secrets
+        if self.env is not None:
+            kwargs["env"] = self.env
+        if self.pod_template is not None:
+            kwargs["pod_template"] = self.pod_template
+        if self.description is not None:
+            kwargs["description"] = self.description
+        return kwargs
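As an aside (not part of the diff), the new `__post_init__` validation accepts lowercase snake_case or kebab-case names with digits and rejects anything else. A quick sketch of how the `is_snake_or_kebab_with_numbers` pattern above behaves:

```python
import re

def is_snake_or_kebab_with_numbers(s: str) -> bool:
    # Same pattern as in flyte/_environment.py above.
    return re.fullmatch(r"^[a-z0-9]+([_-][a-z0-9]+)*$", s) is not None

# Accepted: lowercase snake_case or kebab-case, digits allowed.
assert is_snake_or_kebab_with_numbers("my_env_2")
assert is_snake_or_kebab_with_numbers("training-env")

# Rejected: uppercase, leading separators, spaces, empty string.
assert not is_snake_or_kebab_with_numbers("MyEnv")
assert not is_snake_or_kebab_with_numbers("_private")
assert not is_snake_or_kebab_with_numbers("my env")
assert not is_snake_or_kebab_with_numbers("")
```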
flyte/_group.py CHANGED
@@ -29,3 +29,4 @@ def group(name: str):
     new_tctx = tctx.replace(group_data=GroupData(name))
     with ctx.replace_task_context(new_tctx):
         yield
+    # Exit the context and restore the previous context
flyte/_image.py CHANGED
@@ -444,8 +444,7 @@ class Image:
     ```
 
     For more information on the uv script format, see the documentation:
-    <href="https://docs.astral.sh/uv/guides/scripts/#declaring-script-dependencies">
-    UV: Declaring script dependencies</href>
+    [UV: Declaring script dependencies](https://docs.astral.sh/uv/guides/scripts/#declaring-script-dependencies)
 
     :param name: name of the image
     :param registry: registry to use for the image
@@ -1,5 +1,6 @@
+import concurrent.futures
 import threading
-from typing import Any, Callable, Literal, Optional, Protocol, Tuple, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Protocol, Tuple, TypeVar
 
 from flyte._task import TaskTemplate
 from flyte.models import ActionID, NativeInterface
@@ -10,6 +11,9 @@ __all__ = ["Controller", "ControllerType", "TraceInfo", "create_controller", "ge
 
 from ..._protos.workflow import task_definition_pb2
 
+if TYPE_CHECKING:
+    import concurrent.futures
+
 ControllerType = Literal["local", "remote"]
 
 R = TypeVar("R")
@@ -28,6 +32,15 @@ class Controller(Protocol):
         """
         ...
 
+    def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
+        """
+        This should call the async submit method above, but return a concurrent Future object that can be
+        used in a blocking wait or wrapped in an async future. This is called when
+        a) a synchronous task is kicked off locally,
+        b) a running task (of either kind) kicks off a downstream synchronous task.
+        """
+        ...
+
     async def submit_task_ref(self, _task: task_definition_pb2.TaskDetails, *args, **kwargs) -> Any:
        """
        Submit a task reference to the controller asynchronously and wait for the result. This is async and will block
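For orientation (not part of the diff): a `concurrent.futures.Future` returned by a `submit_sync` implementation can be consumed from plain synchronous code or bridged into an event loop. A rough sketch, where `controller` and `my_task` are hypothetical placeholders:

```python
import asyncio
import concurrent.futures

def blocking_caller(controller, my_task):
    # Plain synchronous code: block until the task result is available.
    fut: concurrent.futures.Future = controller.submit_sync(my_task, x=1)
    return fut.result()

async def async_caller(controller, my_task):
    # Inside an event loop: bridge the concurrent future into an awaitable.
    fut = controller.submit_sync(my_task, x=1)
    return await asyncio.wrap_future(fut)
```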
@@ -1,3 +1,8 @@
+import asyncio
+import atexit
+import concurrent.futures
+import os
+import threading
 from typing import Any, Callable, Tuple, TypeVar
 
 import flyte.errors
@@ -8,19 +13,67 @@ from flyte._internal.runtime.entrypoints import direct_dispatch
 from flyte._logging import log, logger
 from flyte._protos.workflow import task_definition_pb2
 from flyte._task import TaskTemplate
+from flyte._utils.helpers import _selector_policy
 from flyte.models import ActionID, NativeInterface, RawDataPath
 
 R = TypeVar("R")
 
 
+class _TaskRunner:
+    """A task runner that runs an asyncio event loop on a background thread."""
+
+    def __init__(self) -> None:
+        self.__loop: asyncio.AbstractEventLoop | None = None
+        self.__runner_thread: threading.Thread | None = None
+        self.__lock = threading.Lock()
+        atexit.register(self._close)
+
+    def _close(self) -> None:
+        if self.__loop:
+            self.__loop.stop()
+
+    def _execute(self) -> None:
+        loop = self.__loop
+        assert loop is not None
+        try:
+            loop.run_forever()
+        finally:
+            loop.close()
+
+    def get_exc_handler(self):
+        def exc_handler(loop, context):
+            logger.error(
+                f"Taskrunner for {self.__runner_thread.name if self.__runner_thread else 'no thread'} caught"
+                f" exception in {loop}: {context}"
+            )
+
+        return exc_handler
+
+    def get_run_future(self, coro: Any) -> concurrent.futures.Future:
+        """Synchronously run a coroutine on a background thread."""
+        name = f"{threading.current_thread().name} : loop-runner"
+        with self.__lock:
+            if self.__loop is None:
+                with _selector_policy():
+                    self.__loop = asyncio.new_event_loop()
+
+                exc_handler = self.get_exc_handler()
+                self.__loop.set_exception_handler(exc_handler)
+                self.__runner_thread = threading.Thread(target=self._execute, daemon=True, name=name)
+                self.__runner_thread.start()
+        fut = asyncio.run_coroutine_threadsafe(coro, self.__loop)
+        return fut
+
+
 class LocalController:
     def __init__(self):
         logger.debug("LocalController init")
+        self._runner_map: dict[str, _TaskRunner] = {}
 
     @log
     async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
         """
-        Submit a node to the controller
+        Main entrypoint for submitting a task to the local controller.
         """
         ctx = internal_ctx()
         tctx = ctx.data.task_context
@@ -58,6 +111,18 @@ class LocalController:
             return result
         return out
 
+    def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
+        name = threading.current_thread().name + f"PID:{os.getpid()}"
+        coro = self.submit(_task, *args, **kwargs)
+        if name not in self._runner_map:
+            if len(self._runner_map) > 100:
+                logger.warning(
+                    "More than 100 event loop runners created!!! This could be a case of runaway recursion..."
+                )
+            self._runner_map[name] = _TaskRunner()
+
+        return self._runner_map[name].get_run_future(coro)
+
     async def finalize_parent_action(self, action: ActionID):
         pass
 
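As a side note (not part of the diff), the `_TaskRunner` added alongside `LocalController` above (the module imported later in this diff as `flyte._internal.controllers._local_controller`) boils down to a standard stdlib pattern: a daemon thread owns a private event loop, and synchronous callers hand coroutines to it via `asyncio.run_coroutine_threadsafe`. A stripped-down sketch, omitting the lock, exception handler, and `_selector_policy` used in the real class:

```python
import asyncio
import threading

# A background thread owns its own event loop and runs it forever.
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

async def work(x: int) -> int:
    await asyncio.sleep(0.1)
    return x * 2

# From synchronous code: schedule the coroutine on the background loop and
# get back a concurrent.futures.Future that can be waited on.
fut = asyncio.run_coroutine_threadsafe(work(21), loop)
print(fut.result())  # 42
```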
@@ -1,6 +1,9 @@
 from __future__ import annotations
 
 import asyncio
+import concurrent.futures
+import os
+import threading
 from collections import defaultdict
 from collections.abc import Callable
 from pathlib import Path
@@ -21,6 +24,7 @@ from flyte._internal.runtime.task_serde import translate_task_to_wire
 from flyte._logging import logger
 from flyte._protos.workflow import run_definition_pb2, task_definition_pb2
 from flyte._task import TaskTemplate
+from flyte._utils.helpers import _selector_policy
 from flyte.models import ActionID, NativeInterface, SerializationContext
 
 R = TypeVar("R")
@@ -119,6 +123,8 @@ class RemoteController(Controller):
         self._parent_action_task_call_sequence: DefaultDict[str, DefaultDict[int, int]] = defaultdict(
             lambda: defaultdict(int)
         )
+        self._submit_loop: asyncio.AbstractEventLoop | None = None
+        self._submit_thread: threading.Thread | None = None
 
     def generate_task_call_sequence(self, task_obj: object, action_id: ActionID) -> int:
         """
@@ -235,6 +241,48 @@
         async with self._parent_action_semaphore[unique_action_name(current_action_id)]:
             return await self._submit(task_call_seq, _task, *args, **kwargs)
 
+    def _sync_thread_loop_runner(self) -> None:
+        """This method runs the event loop and should be invoked in a separate thread."""
+
+        loop = self._submit_loop
+        assert loop is not None
+        try:
+            loop.run_forever()
+        finally:
+            loop.close()
+
+    def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
+        """
+        This function creates a cached thread and loop for the purpose of calling the submit method synchronously,
+        returning a concurrent Future that can be awaited. There's no need for a lock because this function itself is
+        single threaded and non-async. This pattern here is basically the trivial/degenerate case of the thread pool
+        in the LocalController.
+        Please see additional comments in protocol.
+
+        :param _task:
+        :param args:
+        :param kwargs:
+        :return:
+        """
+        if self._submit_thread is None:
+            # Please see LocalController for the general implementation of this pattern.
+            def exc_handler(loop, context):
+                logger.error(f"Remote controller submit sync loop caught exception in {loop}: {context}")
+
+            with _selector_policy():
+                self._submit_loop = asyncio.new_event_loop()
+            self._submit_loop.set_exception_handler(exc_handler)
+
+            self._submit_thread = threading.Thread(
+                name=f"remote-controller-{os.getpid()}-submitter", daemon=True, target=self._sync_thread_loop_runner
+            )
+            self._submit_thread.start()
+
+        coro = self.submit(_task, *args, **kwargs)
+        assert self._submit_loop is not None, "Submit loop should always have been initialized by now"
+        fut = asyncio.run_coroutine_threadsafe(coro, self._submit_loop)
+        return fut
+
     async def finalize_parent_action(self, action_id: ActionID):
         """
         This method is invoked when the parent action is finished. It will finalize the run and upload the outputs
@@ -38,11 +38,11 @@ class ActionCache:
         """
         Add an action to the cache if it doesn't exist. This is invoked by the watch.
         """
-        logger.info(f"Observing phase {run_definition_pb2.Phase.Name(state.phase)} for {state.action_id.name}")
+        logger.debug(f"Observing phase {run_definition_pb2.Phase.Name(state.phase)} for {state.action_id.name}")
         if state.output_uri:
-            logger.info(f"Output URI: {state.output_uri}")
+            logger.debug(f"Output URI: {state.output_uri}")
         else:
-            logger.info(f"{state.action_id.name} has no output URI")
+            logger.warning(f"{state.action_id.name} has no output URI")
         if state.phase == run_definition_pb2.Phase.PHASE_FAILED:
             logger.error(
                 f"Action {state.action_id.name} failed with error (msg):"
@@ -136,8 +136,9 @@ async def convert_and_run(
         raw_data_path=raw_data_path,
         compiled_image_cache=image_cache,
         report=flyte.report.Report(name=action.name),
+        mode="remote" if not ctx.data.task_context else ctx.data.task_context.mode,
     )
-    async with ctx.replace_task_context(tctx):
+    with ctx.replace_task_context(tctx):
         inputs_kwargs = await convert_inputs_to_native(inputs, task.native_interface)
         out, err = await run_task(tctx=tctx, controller=controller, task=task, inputs=inputs_kwargs)
         if err is not None:
flyte/_map.py ADDED
@@ -0,0 +1,215 @@
+import asyncio
+from typing import Any, AsyncGenerator, AsyncIterator, Generic, Iterable, Iterator, List, Union, cast
+
+from flyte.syncify import syncify
+
+from ._group import group
+from ._logging import logger
+from ._task import P, R, TaskTemplate
+
+
+class MapAsyncIterator(Generic[P, R]):
+    """AsyncIterator implementation for the map function results"""
+
+    def __init__(self, func: TaskTemplate[P, R], args: tuple, name: str, concurrency: int, return_exceptions: bool):
+        self.func = func
+        self.args = args
+        self.name = name
+        self.concurrency = concurrency
+        self.return_exceptions = return_exceptions
+        self._tasks: List[asyncio.Task] = []
+        self._current_index = 0
+        self._completed_count = 0
+        self._exception_count = 0
+        self._task_count = 0
+        self._initialized = False
+
+    def __aiter__(self) -> AsyncIterator[Union[R, Exception]]:
+        """Return self as the async iterator"""
+        return self
+
+    async def __anext__(self) -> Union[R, Exception]:
+        """Get the next result"""
+        # Initialize on first call
+        if not self._initialized:
+            await self._initialize()
+
+        # Check if we've exhausted all tasks
+        if self._current_index >= self._task_count:
+            raise StopAsyncIteration
+
+        # Get the next task result
+        task = self._tasks[self._current_index]
+        self._current_index += 1
+
+        try:
+            result = await task
+            self._completed_count += 1
+            logger.debug(f"Task {self._current_index - 1} completed successfully")
+            return result
+        except Exception as e:
+            self._exception_count += 1
+            logger.debug(f"Task {self._current_index - 1} failed with exception: {e}")
+            if self.return_exceptions:
+                return e
+            else:
+                # Cancel remaining tasks
+                for remaining_task in self._tasks[self._current_index + 1 :]:
+                    remaining_task.cancel()
+                raise e
+
+    async def _initialize(self):
+        """Initialize the tasks - called lazily on first iteration"""
+        # Create all tasks at once
+        tasks = []
+        task_count = 0
+
+        for arg_tuple in zip(*self.args):
+            task = asyncio.create_task(self.func.aio(*arg_tuple))
+            tasks.append(task)
+            task_count += 1
+
+        if task_count == 0:
+            logger.info(f"Group '{self.name}' has no tasks to process")
+            self._tasks = []
+            self._task_count = 0
+        else:
+            logger.info(f"Starting {task_count} tasks in group '{self.name}' with unlimited concurrency")
+            self._tasks = tasks
+            self._task_count = task_count
+
+        self._initialized = True
+
+    async def collect(self) -> List[Union[R, Exception]]:
+        """Convenience method to collect all results into a list"""
+        results = []
+        async for result in self:
+            results.append(result)
+        return results
+
+    def __repr__(self):
+        return f"MapAsyncIterator(group_name='{self.name}', concurrency={self.concurrency})"
+
+
+class _Mapper(Generic[P, R]):
+    """
+    Internal mapper class to handle the mapping logic
+
+    NOTE: The reason why we do not use the `@syncify` decorator here is because, in `syncify` we cannot use
+    context managers like `group` directly in the function body. This is because the `__exit__` method of the
+    context manager is called after the function returns. An for `_context` the `__exit__` method releases the
+    token (for contextvar), which was created in a separate thread. This leads to an exception like:
+
+    """
+
+    @classmethod
+    def _get_name(cls, task_name: str, group_name: str | None) -> str:
+        """Get the name of the group, defaulting to 'map' if not provided."""
+        return f"{task_name}_{group_name or 'map'}"
+
+    def __call__(
+        self,
+        func: TaskTemplate[P, R],
+        *args: Iterable[Any],
+        group_name: str | None = None,
+        concurrency: int = 0,
+        return_exceptions: bool = True,
+    ) -> Iterator[Union[R, Exception]]:
+        """
+        Map a function over the provided arguments with concurrent execution.
+
+        :param func: The async function to map.
+        :param args: Positional arguments to pass to the function (iterables that will be zipped).
+        :param group_name: The name of the group for the mapped tasks.
+        :param concurrency: The maximum number of concurrent tasks to run. If 0, run all tasks concurrently.
+        :param return_exceptions: If True, yield exceptions instead of raising them.
+        :return: AsyncIterator yielding results in order.
+        """
+        if not args:
+            return
+
+        name = self._get_name(func.name, group_name)
+        logger.debug(f"Blocking Map for {name}")
+        with group(name):
+            import flyte
+
+            tctx = flyte.ctx()
+            if tctx is None or tctx.mode == "local":
+                logger.warning("Running map in local mode, which will run every task sequentially.")
+                for v in zip(*args):
+                    try:
+                        yield func(*v)  # type: ignore
+                    except Exception as e:
+                        if return_exceptions:
+                            yield e
+                        else:
+                            raise e
+                return
+
+            i = 0
+            for x in cast(
+                Iterator[R],
+                _map(
+                    func,
+                    *args,
+                    name=name,
+                    concurrency=concurrency,
+                    return_exceptions=True,
+                ),
+            ):
+                logger.debug(f"Mapped {x}, task {i}")
+                i += 1
+                yield x
+
+    async def aio(
+        self,
+        func: TaskTemplate[P, R],
+        *args: Iterable[Any],
+        group_name: str | None = None,
+        concurrency: int = 0,
+        return_exceptions: bool = True,
+    ) -> AsyncGenerator[Union[R, Exception], None]:
+        if not args:
+            return
+        name = self._get_name(func.name, group_name)
+        with group(name):
+            import flyte
+
+            tctx = flyte.ctx()
+            if tctx is None or tctx.mode == "local":
+                logger.warning("Running map in local mode, which will run every task sequentially.")
+                for v in zip(*args):
+                    try:
+                        yield func(*v)  # type: ignore
+                    except Exception as e:
+                        if return_exceptions:
+                            yield e
+                        else:
+                            raise e
+                return
+            async for x in _map.aio(
+                func,
+                *args,
+                name=name,
+                concurrency=concurrency,
+                return_exceptions=return_exceptions,
+            ):
+                yield cast(Union[R, Exception], x)
+
+
+@syncify
+async def _map(
+    func: TaskTemplate[P, R],
+    *args: Iterable[Any],
+    name: str = "map",
+    concurrency: int = 0,
+    return_exceptions: bool = True,
+) -> AsyncIterator[Union[R, Exception]]:
+    iter = MapAsyncIterator(
+        func=func, args=args, name=name, concurrency=concurrency, return_exceptions=return_exceptions
+    )
+    async for result in iter:
+        yield result
+
+
+map: _Mapper = _Mapper()
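For context (not part of the diff): judging from the signatures above, the new `flyte.map` takes a task template plus one or more iterables that are zipped into per-task arguments, and yields results (or exceptions, since `return_exceptions` defaults to True) in input order. A rough usage sketch, where `my_task` is a placeholder for a task defined elsewhere:

```python
import flyte

# `my_task` stands in for an existing flyte task (a TaskTemplate).
# Blocking form: flyte.map(...) returns a generator of results/exceptions.
results = list(flyte.map(my_task, [1, 2, 3], group_name="batch"))

# Async form: map.aio(...) yields results as an async generator.
async def gather():
    out = []
    async for r in flyte.map.aio(my_task, [1, 2, 3]):
        out.append(r)
    return out
```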
flyte/_run.py CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 import pathlib
 import uuid
 from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple, Union, cast
@@ -287,40 +288,39 @@ class _Runner:
 
     async def _run_local(self, obj: TaskTemplate[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
         from flyte._internal.controllers import create_controller
-        from flyte._internal.runtime.convert import (
-            convert_error_to_native,
-            convert_from_native_to_inputs,
-            convert_outputs_to_native,
-        )
-        from flyte._internal.runtime.entrypoints import direct_dispatch
+        from flyte._internal.controllers._local_controller import LocalController
+        from flyte.report import Report
+
+        controller = cast(LocalController, create_controller("local"))
 
-        controller = create_controller("local")
-        inputs = await convert_from_native_to_inputs(obj.native_interface, *args, **kwargs)
         if self._name is None:
             action = ActionID.create_random()
         else:
             action = ActionID(name=self._name)
-        out, err = await direct_dispatch(
-            obj,
+
+        ctx = internal_ctx()
+        tctx = TaskContext(
             action=action,
-            raw_data_path=internal_ctx().raw_data,
-            version="na",
-            controller=controller,
-            inputs=inputs,
-            output_path=self._metadata_path,
-            run_base_dir=self._metadata_path,
             checkpoints=Checkpoints(
                 prev_checkpoint_path=internal_ctx().raw_data.path, checkpoint_path=internal_ctx().raw_data.path
             ),
-        )  # type: ignore
-        if err:
-            native_err = convert_error_to_native(err)
-            if native_err:
-                raise native_err
-        if obj.native_interface.outputs and len(obj.native_interface.outputs) > 0:
-            if out is not None:
-                return cast(R, await convert_outputs_to_native(obj.native_interface, out))
-        return cast(R, None)
+            code_bundle=None,
+            output_path=self._metadata_path,
+            run_base_dir=self._metadata_path,
+            version="na",
+            raw_data_path=internal_ctx().raw_data,
+            compiled_image_cache=None,
+            report=Report(name=action.name),
+            mode="local",
+        )
+        async with ctx.replace_task_context(tctx):
+            # make the local version always runs on a different thread, returns a wrapped future.
+            if obj._call_as_synchronous:
+                fut = controller.submit_sync(obj, *args, **kwargs)
+                awaitable = asyncio.wrap_future(fut)
+                return await awaitable
+            else:
+                return await controller.submit(obj, *args, **kwargs)
 
     @syncify
     async def run(self, task: TaskTemplate[P, Union[R, Run]], *args: P.args, **kwargs: P.kwargs) -> Union[R, Run]:
@@ -351,7 +351,7 @@ class _Runner:
             return await self._run_hybrid(task, *args, **kwargs)
 
         # TODO We could use this for remote as well and users could simply pass flyte:// or s3:// or file://
-        async with internal_ctx().new_raw_data_path(
+        with internal_ctx().new_raw_data_path(
            raw_data_path=RawDataPath.from_local_folder(local_folder=self._raw_data_path)
        ):
            return await self._run_local(task, *args, **kwargs)