flyte 0.2.0b8__py3-none-any.whl → 0.2.0b10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic (further details were linked from the original registry page).

Files changed (42)
  1. flyte/__init__.py +4 -2
  2. flyte/_context.py +7 -1
  3. flyte/_deploy.py +3 -0
  4. flyte/_group.py +1 -0
  5. flyte/_initialize.py +15 -5
  6. flyte/_internal/controllers/__init__.py +13 -2
  7. flyte/_internal/controllers/_local_controller.py +67 -5
  8. flyte/_internal/controllers/remote/_controller.py +47 -2
  9. flyte/_internal/runtime/taskrunner.py +2 -1
  10. flyte/_map.py +215 -0
  11. flyte/_run.py +109 -64
  12. flyte/_task.py +56 -7
  13. flyte/_utils/helpers.py +15 -0
  14. flyte/_version.py +2 -2
  15. flyte/cli/__init__.py +0 -7
  16. flyte/cli/_abort.py +1 -1
  17. flyte/cli/_common.py +7 -7
  18. flyte/cli/_create.py +44 -29
  19. flyte/cli/_delete.py +2 -2
  20. flyte/cli/_deploy.py +3 -3
  21. flyte/cli/_gen.py +12 -4
  22. flyte/cli/_get.py +35 -27
  23. flyte/cli/_params.py +1 -1
  24. flyte/cli/main.py +32 -29
  25. flyte/extras/_container.py +29 -32
  26. flyte/io/__init__.py +17 -1
  27. flyte/io/_file.py +2 -0
  28. flyte/io/{structured_dataset → _structured_dataset}/basic_dfs.py +1 -1
  29. flyte/io/{structured_dataset → _structured_dataset}/structured_dataset.py +1 -1
  30. flyte/models.py +11 -1
  31. flyte/syncify/_api.py +43 -15
  32. flyte/types/__init__.py +23 -0
  33. flyte/{io/pickle/transformer.py → types/_pickle.py} +2 -1
  34. flyte/types/_type_engine.py +4 -4
  35. {flyte-0.2.0b8.dist-info → flyte-0.2.0b10.dist-info}/METADATA +7 -6
  36. {flyte-0.2.0b8.dist-info → flyte-0.2.0b10.dist-info}/RECORD +40 -41
  37. flyte/io/_dataframe.py +0 -0
  38. flyte/io/pickle/__init__.py +0 -0
  39. flyte/io/{structured_dataset → _structured_dataset}/__init__.py +0 -0
  40. {flyte-0.2.0b8.dist-info → flyte-0.2.0b10.dist-info}/WHEEL +0 -0
  41. {flyte-0.2.0b8.dist-info → flyte-0.2.0b10.dist-info}/entry_points.txt +0 -0
  42. {flyte-0.2.0b8.dist-info → flyte-0.2.0b10.dist-info}/top_level.txt +0 -0
flyte/__init__.py CHANGED
@@ -38,7 +38,8 @@ __all__ = [
38
38
  "deploy",
39
39
  "group",
40
40
  "init",
41
- "init_auto_from_config",
41
+ "init_from_config",
42
+ "map",
42
43
  "run",
43
44
  "trace",
44
45
  "with_runcontext",
@@ -50,7 +51,8 @@ from ._deploy import deploy
50
51
  from ._environment import Environment
51
52
  from ._group import group
52
53
  from ._image import Image
53
- from ._initialize import init, init_auto_from_config
54
+ from ._initialize import init, init_from_config
55
+ from ._map import map
54
56
  from ._resources import GPU, TPU, Device, Resources
55
57
  from ._retry import RetryStrategy
56
58
  from ._reusable_environment import ReusePolicy
flyte/_context.py CHANGED
@@ -4,6 +4,7 @@ import contextvars
4
4
  from dataclasses import dataclass, replace
5
5
  from typing import TYPE_CHECKING, Awaitable, Callable, Optional, ParamSpec, TypeVar
6
6
 
7
+ from flyte._logging import logger
7
8
  from flyte.models import GroupData, RawDataPath, TaskContext
8
9
 
9
10
  if TYPE_CHECKING:
@@ -49,6 +50,7 @@ class Context:
49
50
  raise ValueError("Cannot create a new context without contextdata.")
50
51
  self._data = data
51
52
  self._id = id(self) # Immutable unique identifier
53
+ self._token = None # Context variable token to restore the previous context
52
54
 
53
55
  @property
54
56
  def data(self) -> ContextData:
@@ -106,7 +108,11 @@ class Context:
106
108
 
107
109
  def __exit__(self, exc_type, exc_val, exc_tb):
108
110
  """Exit the context, restoring the previous context."""
109
- root_context_var.reset(self._token)
111
+ try:
112
+ root_context_var.reset(self._token)
113
+ except Exception as e:
114
+ logger.warn(f"Failed to reset context: {e}")
115
+ raise e
110
116
 
111
117
  async def __aenter__(self):
112
118
  """Async version of context entry."""
flyte/_deploy.py CHANGED
@@ -128,6 +128,9 @@ async def apply(deployment: DeploymentPlan, copy_style: CopyFiles, dryrun: bool
128
128
  else:
129
129
  code_bundle = await build_code_bundle(from_dir=cfg.root_dir, dryrun=dryrun, copy_style=copy_style)
130
130
  deployment.version = code_bundle.computed_version
131
+ # TODO we should update the version to include the image cache digest and code bundle digest. This is
132
+ # to ensure that changes in image dependencies, cause an update to the deployment version.
133
+ # TODO Also hash the environment and tasks to ensure that changes in the environment or tasks
131
134
 
132
135
  sc = SerializationContext(
133
136
  project=cfg.project,
flyte/_group.py CHANGED
@@ -29,3 +29,4 @@ def group(name: str):
29
29
  new_tctx = tctx.replace(group_data=GroupData(name))
30
30
  with ctx.replace_task_context(new_tctx):
31
31
  yield
32
+ # Exit the context and restore the previous context
flyte/_initialize.py CHANGED
@@ -138,14 +138,13 @@ async def init(
138
138
  :param project: Optional project name (not used in this implementation)
139
139
  :param domain: Optional domain name (not used in this implementation)
140
140
  :param root_dir: Optional root directory from which to determine how to load files, and find paths to files.
141
+ This is useful for determining the root directory for the current project, and for locating files like config etc.
142
+ also use to determine all the code that needs to be copied to the remote location.
141
143
  defaults to the editable install directory if the cwd is in a Python editable install, else just the cwd.
142
144
  :param log_level: Optional logging level for the logger, default is set using the default initialization policies
143
145
  :param api_key: Optional API key for authentication
144
146
  :param endpoint: Optional API endpoint URL
145
147
  :param headless: Optional Whether to run in headless mode
146
- :param mode: Optional execution model (local, remote). Default is local. When local is used,
147
- the execution will be done locally. When remote is used, the execution will be sent to a remote server,
148
- In the remote case, the endpoint or api_key must be set.
149
148
  :param insecure_skip_verify: Whether to skip SSL certificate verification
150
149
  :param auth_client_config: Optional client configuration for authentication
151
150
  :param auth_type: The authentication type to use (Pkce, ClientSecret, ExternalCommand, DeviceFlow)
@@ -177,6 +176,9 @@ async def init(
177
176
 
178
177
  global _init_config # noqa: PLW0603
179
178
 
179
+ if endpoint and "://" not in endpoint:
180
+ endpoint = f"dns:///{endpoint}"
181
+
180
182
  with _init_lock:
181
183
  client = None
182
184
  if endpoint or api_key:
@@ -209,12 +211,16 @@ async def init(
209
211
 
210
212
 
211
213
  @syncify
212
- async def init_auto_from_config(path_or_config: str | Config | None = None) -> None:
214
+ async def init_from_config(path_or_config: str | Config | None = None, root_dir: Path | None = None) -> None:
213
215
  """
214
216
  Initialize the Flyte system using a configuration file or Config object. This method should be called before any
215
217
  other Flyte remote API methods are called. Thread-safe implementation.
216
218
 
217
219
  :param path_or_config: Path to the configuration file or Config object
220
+ :param root_dir: Optional root directory from which to determine how to load files, and find paths to
221
+ files like config etc. For example if one uses the copy-style=="all", it is essential to determine the
222
+ root directory for the current project. If not provided, it defaults to the editable install directory or
223
+ if not available, the current working directory.
218
224
  :return: None
219
225
  """
220
226
  import flyte.config as config
@@ -222,7 +228,10 @@ async def init_auto_from_config(path_or_config: str | Config | None = None) -> N
222
228
  cfg: config.Config
223
229
  if path_or_config is None or isinstance(path_or_config, str):
224
230
  # If a string is passed, treat it as a path to the config file
225
- cfg = config.auto(path_or_config)
231
+ if root_dir and path_or_config:
232
+ cfg = config.auto(str(root_dir / path_or_config))
233
+ else:
234
+ cfg = config.auto(path_or_config)
226
235
  else:
227
236
  # If a Config object is passed, use it directly
228
237
  cfg = path_or_config
@@ -241,6 +250,7 @@ async def init_auto_from_config(path_or_config: str | Config | None = None) -> N
241
250
  proxy_command=cfg.platform.proxy_command,
242
251
  client_id=cfg.platform.client_id,
243
252
  client_credentials_secret=cfg.platform.client_credentials_secret,
253
+ root_dir=root_dir,
244
254
  )
245
255
 
246
256
 
flyte/_internal/controllers/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
+ import concurrent.futures
1
2
  import threading
2
- from typing import Any, Callable, Literal, Optional, Protocol, Tuple, TypeVar
3
+ from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Protocol, Tuple, TypeVar
3
4
 
4
5
  from flyte._task import TaskTemplate
5
6
  from flyte.models import ActionID, NativeInterface
@@ -10,6 +11,9 @@ __all__ = ["Controller", "ControllerType", "TraceInfo", "create_controller", "ge
10
11
 
11
12
  from ..._protos.workflow import task_definition_pb2
12
13
 
14
+ if TYPE_CHECKING:
15
+ import concurrent.futures
16
+
13
17
  ControllerType = Literal["local", "remote"]
14
18
 
15
19
  R = TypeVar("R")
@@ -28,7 +32,14 @@ class Controller(Protocol):
28
32
  """
29
33
  ...
30
34
 
31
- def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> Any: ...
35
+ def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
36
+ """
37
+ This should call the async submit method above, but return a concurrent Future object that can be
38
+ used in a blocking wait or wrapped in an async future. This is called when
39
+ a) a synchronous task is kicked off locally,
40
+ b) a running task (of either kind) kicks off a downstream synchronous task.
41
+ """
42
+ ...
32
43
 
33
44
  async def submit_task_ref(self, _task: task_definition_pb2.TaskDetails, *args, **kwargs) -> Any:
34
45
  """
flyte/_internal/controllers/_local_controller.py CHANGED
@@ -1,3 +1,8 @@
1
+ import asyncio
2
+ import atexit
3
+ import concurrent.futures
4
+ import os
5
+ import threading
1
6
  from typing import Any, Callable, Tuple, TypeVar
2
7
 
3
8
  import flyte.errors
@@ -8,20 +13,67 @@ from flyte._internal.runtime.entrypoints import direct_dispatch
8
13
  from flyte._logging import log, logger
9
14
  from flyte._protos.workflow import task_definition_pb2
10
15
  from flyte._task import TaskTemplate
11
- from flyte._utils.asyn import loop_manager
12
- from flyte.models import ActionID, NativeInterface, RawDataPath
16
+ from flyte._utils.helpers import _selector_policy
17
+ from flyte.models import ActionID, NativeInterface
13
18
 
14
19
  R = TypeVar("R")
15
20
 
16
21
 
22
+ class _TaskRunner:
23
+ """A task runner that runs an asyncio event loop on a background thread."""
24
+
25
+ def __init__(self) -> None:
26
+ self.__loop: asyncio.AbstractEventLoop | None = None
27
+ self.__runner_thread: threading.Thread | None = None
28
+ self.__lock = threading.Lock()
29
+ atexit.register(self._close)
30
+
31
+ def _close(self) -> None:
32
+ if self.__loop:
33
+ self.__loop.stop()
34
+
35
+ def _execute(self) -> None:
36
+ loop = self.__loop
37
+ assert loop is not None
38
+ try:
39
+ loop.run_forever()
40
+ finally:
41
+ loop.close()
42
+
43
+ def get_exc_handler(self):
44
+ def exc_handler(loop, context):
45
+ logger.error(
46
+ f"Taskrunner for {self.__runner_thread.name if self.__runner_thread else 'no thread'} caught"
47
+ f" exception in {loop}: {context}"
48
+ )
49
+
50
+ return exc_handler
51
+
52
+ def get_run_future(self, coro: Any) -> concurrent.futures.Future:
53
+ """Synchronously run a coroutine on a background thread."""
54
+ name = f"{threading.current_thread().name} : loop-runner"
55
+ with self.__lock:
56
+ if self.__loop is None:
57
+ with _selector_policy():
58
+ self.__loop = asyncio.new_event_loop()
59
+
60
+ exc_handler = self.get_exc_handler()
61
+ self.__loop.set_exception_handler(exc_handler)
62
+ self.__runner_thread = threading.Thread(target=self._execute, daemon=True, name=name)
63
+ self.__runner_thread.start()
64
+ fut = asyncio.run_coroutine_threadsafe(coro, self.__loop)
65
+ return fut
66
+
67
+
17
68
  class LocalController:
18
69
  def __init__(self):
19
70
  logger.debug("LocalController init")
71
+ self._runner_map: dict[str, _TaskRunner] = {}
20
72
 
21
73
  @log
22
74
  async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any:
23
75
  """
24
- Submit a node to the controller
76
+ Main entrypoint for submitting a task to the local controller.
25
77
  """
26
78
  ctx = internal_ctx()
27
79
  tctx = ctx.data.task_context
@@ -34,7 +86,7 @@ class LocalController:
34
86
  sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
35
87
  tctx, _task.name, serialized_inputs, 0
36
88
  )
37
- sub_action_raw_data_path = RawDataPath(path=sub_action_output_path)
89
+ sub_action_raw_data_path = tctx.raw_data_path
38
90
 
39
91
  out, err = await direct_dispatch(
40
92
  _task,
@@ -59,7 +111,17 @@ class LocalController:
59
111
  return result
60
112
  return out
61
113
 
62
- submit_sync = loop_manager.synced(submit)
114
+ def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
115
+ name = threading.current_thread().name + f"PID:{os.getpid()}"
116
+ coro = self.submit(_task, *args, **kwargs)
117
+ if name not in self._runner_map:
118
+ if len(self._runner_map) > 100:
119
+ logger.warning(
120
+ "More than 100 event loop runners created!!! This could be a case of runaway recursion..."
121
+ )
122
+ self._runner_map[name] = _TaskRunner()
123
+
124
+ return self._runner_map[name].get_run_future(coro)
63
125
 
64
126
  async def finalize_parent_action(self, action: ActionID):
65
127
  pass
flyte/_internal/controllers/remote/_controller.py CHANGED
@@ -1,6 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import asyncio
4
+ import concurrent.futures
5
+ import os
6
+ import threading
4
7
  from collections import defaultdict
5
8
  from collections.abc import Callable
6
9
  from pathlib import Path
@@ -21,7 +24,7 @@ from flyte._internal.runtime.task_serde import translate_task_to_wire
21
24
  from flyte._logging import logger
22
25
  from flyte._protos.workflow import run_definition_pb2, task_definition_pb2
23
26
  from flyte._task import TaskTemplate
24
- from flyte._utils.asyn import loop_manager
27
+ from flyte._utils.helpers import _selector_policy
25
28
  from flyte.models import ActionID, NativeInterface, SerializationContext
26
29
 
27
30
  R = TypeVar("R")
@@ -120,6 +123,8 @@ class RemoteController(Controller):
120
123
  self._parent_action_task_call_sequence: DefaultDict[str, DefaultDict[int, int]] = defaultdict(
121
124
  lambda: defaultdict(int)
122
125
  )
126
+ self._submit_loop: asyncio.AbstractEventLoop | None = None
127
+ self._submit_thread: threading.Thread | None = None
123
128
 
124
129
  def generate_task_call_sequence(self, task_obj: object, action_id: ActionID) -> int:
125
130
  """
@@ -236,7 +241,47 @@ class RemoteController(Controller):
236
241
  async with self._parent_action_semaphore[unique_action_name(current_action_id)]:
237
242
  return await self._submit(task_call_seq, _task, *args, **kwargs)
238
243
 
239
- submit_sync = loop_manager.synced(submit)
244
+ def _sync_thread_loop_runner(self) -> None:
245
+ """This method runs the event loop and should be invoked in a separate thread."""
246
+
247
+ loop = self._submit_loop
248
+ assert loop is not None
249
+ try:
250
+ loop.run_forever()
251
+ finally:
252
+ loop.close()
253
+
254
+ def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
255
+ """
256
+ This function creates a cached thread and loop for the purpose of calling the submit method synchronously,
257
+ returning a concurrent Future that can be awaited. There's no need for a lock because this function itself is
258
+ single threaded and non-async. This pattern here is basically the trivial/degenerate case of the thread pool
259
+ in the LocalController.
260
+ Please see additional comments in protocol.
261
+
262
+ :param _task:
263
+ :param args:
264
+ :param kwargs:
265
+ :return:
266
+ """
267
+ if self._submit_thread is None:
268
+ # Please see LocalController for the general implementation of this pattern.
269
+ def exc_handler(loop, context):
270
+ logger.error(f"Remote controller submit sync loop caught exception in {loop}: {context}")
271
+
272
+ with _selector_policy():
273
+ self._submit_loop = asyncio.new_event_loop()
274
+ self._submit_loop.set_exception_handler(exc_handler)
275
+
276
+ self._submit_thread = threading.Thread(
277
+ name=f"remote-controller-{os.getpid()}-submitter", daemon=True, target=self._sync_thread_loop_runner
278
+ )
279
+ self._submit_thread.start()
280
+
281
+ coro = self.submit(_task, *args, **kwargs)
282
+ assert self._submit_loop is not None, "Submit loop should always have been initialized by now"
283
+ fut = asyncio.run_coroutine_threadsafe(coro, self._submit_loop)
284
+ return fut
240
285
 
241
286
  async def finalize_parent_action(self, action_id: ActionID):
242
287
  """
flyte/_internal/runtime/taskrunner.py CHANGED
@@ -136,8 +136,9 @@ async def convert_and_run(
136
136
  raw_data_path=raw_data_path,
137
137
  compiled_image_cache=image_cache,
138
138
  report=flyte.report.Report(name=action.name),
139
+ mode="remote" if not ctx.data.task_context else ctx.data.task_context.mode,
139
140
  )
140
- async with ctx.replace_task_context(tctx):
141
+ with ctx.replace_task_context(tctx):
141
142
  inputs_kwargs = await convert_inputs_to_native(inputs, task.native_interface)
142
143
  out, err = await run_task(tctx=tctx, controller=controller, task=task, inputs=inputs_kwargs)
143
144
  if err is not None:
flyte/_map.py ADDED
@@ -0,0 +1,215 @@
1
+ import asyncio
2
+ from typing import Any, AsyncGenerator, AsyncIterator, Generic, Iterable, Iterator, List, Union, cast
3
+
4
+ from flyte.syncify import syncify
5
+
6
+ from ._group import group
7
+ from ._logging import logger
8
+ from ._task import P, R, TaskTemplate
9
+
10
+
11
+ class MapAsyncIterator(Generic[P, R]):
12
+ """AsyncIterator implementation for the map function results"""
13
+
14
+ def __init__(self, func: TaskTemplate[P, R], args: tuple, name: str, concurrency: int, return_exceptions: bool):
15
+ self.func = func
16
+ self.args = args
17
+ self.name = name
18
+ self.concurrency = concurrency
19
+ self.return_exceptions = return_exceptions
20
+ self._tasks: List[asyncio.Task] = []
21
+ self._current_index = 0
22
+ self._completed_count = 0
23
+ self._exception_count = 0
24
+ self._task_count = 0
25
+ self._initialized = False
26
+
27
+ def __aiter__(self) -> AsyncIterator[Union[R, Exception]]:
28
+ """Return self as the async iterator"""
29
+ return self
30
+
31
+ async def __anext__(self) -> Union[R, Exception]:
32
+ """Get the next result"""
33
+ # Initialize on first call
34
+ if not self._initialized:
35
+ await self._initialize()
36
+
37
+ # Check if we've exhausted all tasks
38
+ if self._current_index >= self._task_count:
39
+ raise StopAsyncIteration
40
+
41
+ # Get the next task result
42
+ task = self._tasks[self._current_index]
43
+ self._current_index += 1
44
+
45
+ try:
46
+ result = await task
47
+ self._completed_count += 1
48
+ logger.debug(f"Task {self._current_index - 1} completed successfully")
49
+ return result
50
+ except Exception as e:
51
+ self._exception_count += 1
52
+ logger.debug(f"Task {self._current_index - 1} failed with exception: {e}")
53
+ if self.return_exceptions:
54
+ return e
55
+ else:
56
+ # Cancel remaining tasks
57
+ for remaining_task in self._tasks[self._current_index + 1 :]:
58
+ remaining_task.cancel()
59
+ raise e
60
+
61
+ async def _initialize(self):
62
+ """Initialize the tasks - called lazily on first iteration"""
63
+ # Create all tasks at once
64
+ tasks = []
65
+ task_count = 0
66
+
67
+ for arg_tuple in zip(*self.args):
68
+ task = asyncio.create_task(self.func.aio(*arg_tuple))
69
+ tasks.append(task)
70
+ task_count += 1
71
+
72
+ if task_count == 0:
73
+ logger.info(f"Group '{self.name}' has no tasks to process")
74
+ self._tasks = []
75
+ self._task_count = 0
76
+ else:
77
+ logger.info(f"Starting {task_count} tasks in group '{self.name}' with unlimited concurrency")
78
+ self._tasks = tasks
79
+ self._task_count = task_count
80
+
81
+ self._initialized = True
82
+
83
+ async def collect(self) -> List[Union[R, Exception]]:
84
+ """Convenience method to collect all results into a list"""
85
+ results = []
86
+ async for result in self:
87
+ results.append(result)
88
+ return results
89
+
90
+ def __repr__(self):
91
+ return f"MapAsyncIterator(group_name='{self.name}', concurrency={self.concurrency})"
92
+
93
+
94
+ class _Mapper(Generic[P, R]):
95
+ """
96
+ Internal mapper class to handle the mapping logic
97
+
98
+ NOTE: The reason why we do not use the `@syncify` decorator here is because, in `syncify` we cannot use
99
+ context managers like `group` directly in the function body. This is because the `__exit__` method of the
100
+ context manager is called after the function returns. An for `_context` the `__exit__` method releases the
101
+ token (for contextvar), which was created in a separate thread. This leads to an exception like:
102
+
103
+ """
104
+
105
+ @classmethod
106
+ def _get_name(cls, task_name: str, group_name: str | None) -> str:
107
+ """Get the name of the group, defaulting to 'map' if not provided."""
108
+ return f"{task_name}_{group_name or 'map'}"
109
+
110
+ def __call__(
111
+ self,
112
+ func: TaskTemplate[P, R],
113
+ *args: Iterable[Any],
114
+ group_name: str | None = None,
115
+ concurrency: int = 0,
116
+ return_exceptions: bool = True,
117
+ ) -> Iterator[Union[R, Exception]]:
118
+ """
119
+ Map a function over the provided arguments with concurrent execution.
120
+
121
+ :param func: The async function to map.
122
+ :param args: Positional arguments to pass to the function (iterables that will be zipped).
123
+ :param group_name: The name of the group for the mapped tasks.
124
+ :param concurrency: The maximum number of concurrent tasks to run. If 0, run all tasks concurrently.
125
+ :param return_exceptions: If True, yield exceptions instead of raising them.
126
+ :return: AsyncIterator yielding results in order.
127
+ """
128
+ if not args:
129
+ return
130
+
131
+ name = self._get_name(func.name, group_name)
132
+ logger.debug(f"Blocking Map for {name}")
133
+ with group(name):
134
+ import flyte
135
+
136
+ tctx = flyte.ctx()
137
+ if tctx is None or tctx.mode == "local":
138
+ logger.warning("Running map in local mode, which will run every task sequentially.")
139
+ for v in zip(*args):
140
+ try:
141
+ yield func(*v) # type: ignore
142
+ except Exception as e:
143
+ if return_exceptions:
144
+ yield e
145
+ else:
146
+ raise e
147
+ return
148
+
149
+ i = 0
150
+ for x in cast(
151
+ Iterator[R],
152
+ _map(
153
+ func,
154
+ *args,
155
+ name=name,
156
+ concurrency=concurrency,
157
+ return_exceptions=True,
158
+ ),
159
+ ):
160
+ logger.debug(f"Mapped {x}, task {i}")
161
+ i += 1
162
+ yield x
163
+
164
+ async def aio(
165
+ self,
166
+ func: TaskTemplate[P, R],
167
+ *args: Iterable[Any],
168
+ group_name: str | None = None,
169
+ concurrency: int = 0,
170
+ return_exceptions: bool = True,
171
+ ) -> AsyncGenerator[Union[R, Exception], None]:
172
+ if not args:
173
+ return
174
+ name = self._get_name(func.name, group_name)
175
+ with group(name):
176
+ import flyte
177
+
178
+ tctx = flyte.ctx()
179
+ if tctx is None or tctx.mode == "local":
180
+ logger.warning("Running map in local mode, which will run every task sequentially.")
181
+ for v in zip(*args):
182
+ try:
183
+ yield func(*v) # type: ignore
184
+ except Exception as e:
185
+ if return_exceptions:
186
+ yield e
187
+ else:
188
+ raise e
189
+ return
190
+ async for x in _map.aio(
191
+ func,
192
+ *args,
193
+ name=name,
194
+ concurrency=concurrency,
195
+ return_exceptions=return_exceptions,
196
+ ):
197
+ yield cast(Union[R, Exception], x)
198
+
199
+
200
+ @syncify
201
+ async def _map(
202
+ func: TaskTemplate[P, R],
203
+ *args: Iterable[Any],
204
+ name: str = "map",
205
+ concurrency: int = 0,
206
+ return_exceptions: bool = True,
207
+ ) -> AsyncIterator[Union[R, Exception]]:
208
+ iter = MapAsyncIterator(
209
+ func=func, args=args, name=name, concurrency=concurrency, return_exceptions=return_exceptions
210
+ )
211
+ async for result in iter:
212
+ yield result
213
+
214
+
215
+ map: _Mapper = _Mapper()