flyte 0.2.0b8__py3-none-any.whl → 0.2.0b10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic; see the package registry page for more details.

Files changed (42)
  1. flyte/__init__.py +4 -2
  2. flyte/_context.py +7 -1
  3. flyte/_deploy.py +3 -0
  4. flyte/_group.py +1 -0
  5. flyte/_initialize.py +15 -5
  6. flyte/_internal/controllers/__init__.py +13 -2
  7. flyte/_internal/controllers/_local_controller.py +67 -5
  8. flyte/_internal/controllers/remote/_controller.py +47 -2
  9. flyte/_internal/runtime/taskrunner.py +2 -1
  10. flyte/_map.py +215 -0
  11. flyte/_run.py +109 -64
  12. flyte/_task.py +56 -7
  13. flyte/_utils/helpers.py +15 -0
  14. flyte/_version.py +2 -2
  15. flyte/cli/__init__.py +0 -7
  16. flyte/cli/_abort.py +1 -1
  17. flyte/cli/_common.py +7 -7
  18. flyte/cli/_create.py +44 -29
  19. flyte/cli/_delete.py +2 -2
  20. flyte/cli/_deploy.py +3 -3
  21. flyte/cli/_gen.py +12 -4
  22. flyte/cli/_get.py +35 -27
  23. flyte/cli/_params.py +1 -1
  24. flyte/cli/main.py +32 -29
  25. flyte/extras/_container.py +29 -32
  26. flyte/io/__init__.py +17 -1
  27. flyte/io/_file.py +2 -0
  28. flyte/io/{structured_dataset → _structured_dataset}/basic_dfs.py +1 -1
  29. flyte/io/{structured_dataset → _structured_dataset}/structured_dataset.py +1 -1
  30. flyte/models.py +11 -1
  31. flyte/syncify/_api.py +43 -15
  32. flyte/types/__init__.py +23 -0
  33. flyte/{io/pickle/transformer.py → types/_pickle.py} +2 -1
  34. flyte/types/_type_engine.py +4 -4
  35. {flyte-0.2.0b8.dist-info → flyte-0.2.0b10.dist-info}/METADATA +7 -6
  36. {flyte-0.2.0b8.dist-info → flyte-0.2.0b10.dist-info}/RECORD +40 -41
  37. flyte/io/_dataframe.py +0 -0
  38. flyte/io/pickle/__init__.py +0 -0
  39. flyte/io/{structured_dataset → _structured_dataset}/__init__.py +0 -0
  40. {flyte-0.2.0b8.dist-info → flyte-0.2.0b10.dist-info}/WHEEL +0 -0
  41. {flyte-0.2.0b8.dist-info → flyte-0.2.0b10.dist-info}/entry_points.txt +0 -0
  42. {flyte-0.2.0b8.dist-info → flyte-0.2.0b10.dist-info}/top_level.txt +0 -0
flyte/cli/_gen.py CHANGED
@@ -1,3 +1,4 @@
1
+ import textwrap
1
2
  from os import getcwd
2
3
  from typing import Generator, Tuple
3
4
 
@@ -9,7 +10,7 @@ import flyte.cli._common as common
9
10
  @click.group(name="gen")
10
11
  def gen():
11
12
  """
12
- Generate documentation
13
+ Generate documentation.
13
14
  """
14
15
 
15
16
 
@@ -18,7 +19,7 @@ def gen():
18
19
  @click.pass_obj
19
20
  def docs(cfg: common.CLIConfig, doc_type: str, project: str | None = None, domain: str | None = None):
20
21
  """
21
- Generate documentation
22
+ Generate documentation.
22
23
  """
23
24
  if doc_type == "markdown":
24
25
  markdown(cfg)
@@ -82,7 +83,7 @@ def markdown(cfg: common.CLIConfig):
82
83
  output.append(f"{'#' * (len(cmd_path_parts) + 1)} {cmd_path}")
83
84
  if cmd.help:
84
85
  output.append("")
85
- output.append(f"{cmd.help.strip()}")
86
+ output.append(f"{dedent(cmd.help)}")
86
87
 
87
88
  if not cmd.params:
88
89
  continue
@@ -109,7 +110,7 @@ def markdown(cfg: common.CLIConfig):
109
110
  if param.default is not None:
110
111
  default_value = f"`{param.default}`"
111
112
  default_value = default_value.replace(f"{getcwd()}/", "")
112
- help_text = param.help.strip() if param.help else ""
113
+ help_text = dedent(param.help) if param.help else ""
113
114
  table_data.append([opts, f"`{param.type.name}`", default_value, help_text])
114
115
 
115
116
  if not table_data:
@@ -153,3 +154,10 @@ def markdown(cfg: common.CLIConfig):
153
154
  print("{{< /grid >}}")
154
155
  print()
155
156
  print("\n".join(output))
157
+
158
+
159
+ def dedent(text: str) -> str:
160
+ """
161
+ Remove leading whitespace from a string.
162
+ """
163
+ return textwrap.dedent(text).strip("\n")
flyte/cli/_get.py CHANGED
@@ -11,16 +11,23 @@ from . import _common as common
11
11
  @click.group(name="get")
12
12
  def get():
13
13
  """
14
- The `get` subcommand allows you to retrieve various resources from a Flyte deployment.
14
+ Retrieve resources from a Flyte deployment.
15
+
16
+ You can get information about projects, runs, tasks, actions, secrets, logs and input/output values.
15
17
 
16
- You can get information about projects, runs, tasks, actions, secrets, and more.
17
18
  Each command supports optional parameters to filter or specify the resource you want to retrieve.
18
19
 
19
- Every `get` subcommand for example ``get project` without any arguments will list all projects.
20
- `get project my_project` will return the details of the project named `my_project`.
20
+ Using a `get` subcommand without any arguments will retrieve a list of available resources to get.
21
+ For example:
22
+
23
+ * `get project` (without specifiying aproject), will list all projects.
24
+ * `get project my_project` will return the details of the project named `my_project`.
25
+
26
+ In some cases, a partially specified command will act as a filter and return available further parameters.
27
+ For example:
21
28
 
22
- In some cases `get action my_run` will return all actions for the run named `my_run` and
23
- `get action my_run my_action` will return the details of the action named `my_action` for the run `my_run`.
29
+ * `get action my_run` will return all actions for the run named `my_run`.
30
+ * `get action my_run my_action` will return the details of the action named `my_action` for the run `my_run`.
24
31
  """
25
32
 
26
33
 
@@ -29,7 +36,7 @@ def get():
29
36
  @click.pass_obj
30
37
  def project(cfg: common.CLIConfig, name: str | None = None):
31
38
  """
32
- Retrieve a list of all projects or details of a specific project by name.
39
+ Get a list of all projects, or details of a specific project by name.
33
40
  """
34
41
  from flyte.remote import Project
35
42
 
@@ -48,7 +55,7 @@ def project(cfg: common.CLIConfig, name: str | None = None):
48
55
  @click.pass_obj
49
56
  def run(cfg: common.CLIConfig, name: str | None = None, project: str | None = None, domain: str | None = None):
50
57
  """
51
- Get list of all runs or details of a specific run by name.
58
+ Get a list of all runs, or details of a specific run by name.
52
59
 
53
60
  The run details will include information about the run, its status, but only the root action will be shown.
54
61
 
@@ -78,9 +85,9 @@ def task(
78
85
  domain: str | None = None,
79
86
  ):
80
87
  """
81
- Retrieve a list of all tasks or details of a specific task by name and version.
88
+ Retrieve a list of all tasks, or details of a specific task by name and version.
82
89
 
83
- Currently name+version are required to get a specific task.
90
+ Currently, both `name` and `version` are required to get a specific task.
84
91
  """
85
92
  from flyte.remote import Task
86
93
 
@@ -156,23 +163,24 @@ def logs(
156
163
  ):
157
164
  """
158
165
  Stream logs for the provided run or action.
159
- If only the run is provided, only the logs for the parent action will be streamed.
166
+ If only the run is provided, only the logs for the parent action will be streamed:
160
167
 
161
- Example:
162
168
  ```bash
163
- flyte get logs my_run
169
+ $ flyte get logs my_run
164
170
  ```
165
171
 
166
- But, if you want to see the logs for a specific action, you can provide the action name as well:
172
+ If you want to see the logs for a specific action, you can provide the action name as well:
173
+
167
174
  ```bash
168
- flyte get logs my_run my_action
175
+ $ flyte get logs my_run my_action
169
176
  ```
170
- By default logs will be shown in the raw format, will scroll on the terminal. If automatic scrolling and only
171
- tailing --lines lines is desired, use the `--pretty` flag:
177
+
178
+ By default, logs will be shown in the raw format and will scroll the terminal.
179
+ If automatic scrolling and only tailing `--lines` number of lines is desired, use the `--pretty` flag:
180
+
172
181
  ```bash
173
- flyte get logs my_run my_action --pretty --lines 50
182
+ $ flyte get logs my_run my_action --pretty --lines 50
174
183
  ```
175
-
176
184
  """
177
185
  import flyte.remote as remote
178
186
 
@@ -206,7 +214,7 @@ def secret(
206
214
  domain: str | None = None,
207
215
  ):
208
216
  """
209
- Retrieve a list of all secrets or details of a specific secret by name.
217
+ Get a list of all secrets, or details of a specific secret by name.
210
218
  """
211
219
  import flyte.remote as remote
212
220
 
@@ -236,18 +244,18 @@ def io(
236
244
  ):
237
245
  """
238
246
  Get the inputs and outputs of a run or action.
239
- if only the run name is provided, it will show the inputs and outputs of the root action of that run.
247
+ If only the run name is provided, it will show the inputs and outputs of the root action of that run.
240
248
  If an action name is provided, it will show the inputs and outputs for that action.
241
-
242
249
  If `--inputs-only` or `--outputs-only` is specified, it will only show the inputs or outputs respectively.
243
250
 
244
- Example:
251
+ Examples:
252
+
245
253
  ```bash
246
- flyte get io my_run
254
+ $ flyte get io my_run
247
255
  ```
248
- or
256
+
249
257
  ```bash
250
- flyte get io my_run my_action
258
+ $ flyte get io my_run my_action
251
259
  ```
252
260
  """
253
261
  if inputs_only and outputs_only:
@@ -293,7 +301,7 @@ def io(
293
301
  @click.pass_obj
294
302
  def config(cfg: common.CLIConfig):
295
303
  """
296
- Shows the automatically detected configuration to connect with remote Flyte services.
304
+ Shows the automatically detected configuration to connect with the remote backend.
297
305
 
298
306
  The configuration will include the endpoint, organization, and other settings that are used by the CLI.
299
307
  """
flyte/cli/_params.py CHANGED
@@ -23,7 +23,7 @@ from mashumaro.codecs.json import JSONEncoder
23
23
 
24
24
  from flyte._logging import logger
25
25
  from flyte.io import Dir, File
26
- from flyte.io.pickle.transformer import FlytePickleTransformer
26
+ from flyte.types._pickle import FlytePickleTransformer
27
27
 
28
28
 
29
29
  class StructuredDataset:
flyte/cli/main.py CHANGED
@@ -13,7 +13,7 @@ from ._run import run
13
13
  click.rich_click.COMMAND_GROUPS = {
14
14
  "flyte": [
15
15
  {
16
- "name": "Running Workflows",
16
+ "name": "Running workflows",
17
17
  "commands": ["run", "abort"],
18
18
  },
19
19
  {
@@ -21,7 +21,7 @@ click.rich_click.COMMAND_GROUPS = {
21
21
  "commands": ["create", "deploy", "get"],
22
22
  },
23
23
  {
24
- "name": "Documentation Generation",
24
+ "name": "Documentation generation",
25
25
  "commands": ["gen"],
26
26
  },
27
27
  ]
@@ -58,13 +58,13 @@ def _verbosity_to_loglevel(verbosity: int) -> int | None:
58
58
  "--endpoint",
59
59
  type=str,
60
60
  required=False,
61
- help="The endpoint to connect to, this will override any config and simply used pkce to connect.",
61
+ help="The endpoint to connect to. This will override any configuration file and simply use `pkce` to connect.",
62
62
  )
63
63
  @click.option(
64
- "--insecure/--secure",
64
+ "--insecure",
65
65
  is_flag=True,
66
66
  required=False,
67
- help="Use insecure connection to the endpoint. If secure is specified, the CLI will use TLS.",
67
+ help="Use an insecure connection to the endpoint. If not specified, the CLI will use TLS.",
68
68
  type=bool,
69
69
  default=None,
70
70
  show_default=True,
@@ -73,7 +73,7 @@ def _verbosity_to_loglevel(verbosity: int) -> int | None:
73
73
  "-v",
74
74
  "--verbose",
75
75
  required=False,
76
- help="Show verbose messages and exception traces, multiple times increases verbosity (e.g., -vvv).",
76
+ help="Show verbose messages and exception traces. Repeating multiple times increases the verbosity (e.g., -vvv).",
77
77
  count=True,
78
78
  default=0,
79
79
  type=int,
@@ -82,7 +82,7 @@ def _verbosity_to_loglevel(verbosity: int) -> int | None:
82
82
  "--org",
83
83
  type=str,
84
84
  required=False,
85
- help="Organization to use",
85
+ help="The organization to which the command applies.",
86
86
  )
87
87
  @click.option(
88
88
  "-c",
@@ -90,8 +90,7 @@ def _verbosity_to_loglevel(verbosity: int) -> int | None:
90
90
  "config_file",
91
91
  required=False,
92
92
  type=click.Path(exists=True),
93
- help="Path to config file (YAML format) to use for the CLI. If not specified,"
94
- " the default config file will be used.",
93
+ help="Path to the configuration file to use. If not specified, the default configuration file is used.",
95
94
  )
96
95
  @click.rich_config(help_config=help_config)
97
96
  @click.pass_context
@@ -104,34 +103,38 @@ def main(
104
103
  config_file: str | None,
105
104
  ):
106
105
  """
107
- ### Flyte entrypoint for the CLI
108
- The Flyte CLI is a command line interface for interacting with Flyte.
106
+ The Flyte CLI is the the command line interface for working with the Flyte SDK and backend.
109
107
 
110
- The flyte cli follows a simple verb based structure, where the top-level commands are verbs that describe the action
111
- to be taken, and the subcommands are nouns that describe the object of the action.
108
+ It follows a simple verb/noun structure,
109
+ where the top-level commands are verbs that describe the action to be taken,
110
+ and the subcommands are nouns that describe the object of the action.
112
111
 
113
- The root command can be used to configure the CLI for most commands, such as setting the endpoint,
114
- organization, and verbosity level.
112
+ The root command can be used to configure the CLI for persistent settings,
113
+ such as the endpoint, organization, and verbosity level.
115
114
 
116
- Example: Set endpoint and organization
117
- ```bash
118
- flyte --endpoint <endpoint> --org <org> get project <project_name>
119
- ```
115
+ Set endpoint and organization:
120
116
 
121
- Example: Increase verbosity level (This is useful for debugging, this will show more logs and exception traces)
122
- ```bash
123
- flyte -vvv get logs <run-name>
124
- ```
117
+ ```bash
118
+ $ flyte --endpoint <endpoint> --org <org> get project <project_name>
119
+ ```
120
+
121
+ Increase verbosity level (This is useful for debugging,
122
+ this will show more logs and exception traces):
123
+
124
+ ```bash
125
+ $ flyte -vvv get logs <run-name>
126
+ ```
127
+
128
+ Override the default config file:
125
129
 
126
- Example: Override the default config file
127
130
  ```bash
128
- flyte --config /path/to/config.yaml run ...
131
+ $ flyte --config /path/to/config.yaml run ...
129
132
  ```
130
133
 
131
- 👉 [Documentation](https://www.union.ai/docs/flyte/user-guide/) \n
132
- 👉 [GitHub](https://github.com/flyteorg/flyte) - Please leave a ⭐. \n
133
- 👉 [Slack](https://slack.flyte.org) - Join the community and ask questions.
134
- 👉 [Issues](https://github.com/flyteorg/flyte/issues)
134
+ * [Documentation](https://www.union.ai/docs/flyte/user-guide/)
135
+ * [GitHub](https://github.com/flyteorg/flyte): Please leave a star if you like Flyte!
136
+ * [Slack](https://slack.flyte.org): Join the community and ask questions.
137
+ * [Issues](https://github.com/flyteorg/flyte/issues)
135
138
 
136
139
  """
137
140
  import flyte.config as config
@@ -12,17 +12,14 @@ from flyte.models import NativeInterface, SerializationContext
12
12
  _PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"
13
13
 
14
14
 
15
- def _extract_command_key(cmd: str, **kwargs) -> Any:
15
+ def _extract_command_key(cmd: str, **kwargs) -> List[Any] | None:
16
16
  """
17
17
  Extract the key from the command using regex.
18
18
  """
19
19
  import re
20
20
 
21
- input_regex = r"^\{\{\s*\.inputs\.(.*?)\s*\}\}$"
22
- match = re.match(input_regex, cmd)
23
- if match:
24
- return match.group(1)
25
- return None
21
+ input_regex = r"\{\{\.inputs\.([a-zA-Z0-9_]+)\}\}"
22
+ return re.findall(input_regex, cmd)
26
23
 
27
24
 
28
25
  def _extract_path_command_key(cmd: str, input_data_dir: Optional[str]) -> Optional[str]:
@@ -70,7 +67,7 @@ class ContainerTask(TaskTemplate):
70
67
  input_data_dir: str | pathlib.Path = "/var/inputs",
71
68
  output_data_dir: str | pathlib.Path = "/var/outputs",
72
69
  metadata_format: MetadataFormat = "JSON",
73
- local_logs: bool = False,
70
+ local_logs: bool = True,
74
71
  **kwargs,
75
72
  ):
76
73
  super().__init__(
@@ -106,34 +103,33 @@ class ContainerTask(TaskTemplate):
106
103
  For FlyteFile and FlyteDirectory commands, e.g., "/var/inputs/inputs", we extract the key from strings that
107
104
  begin with the specified `input_data_dir`.
108
105
  """
109
- # from flytekit.types.directory import FlyteDirectory
110
- # from flytekit.types.file import FlyteFile
106
+ from flyte.io import Dir, File
111
107
 
112
108
  volume_binding: Dict[str, Dict[str, str]] = {}
113
109
  path_k = _extract_path_command_key(cmd, str(self._input_data_dir))
114
- k = path_k if path_k else _extract_command_key(cmd)
115
-
116
- if k:
117
- input_val = kwargs.get(k)
118
- # TODO: Add support file and directory transformer first
119
- # if type(input_val) in [FlyteFile, FlyteDirectory]:
120
- # if not path_k:
121
- # raise AssertionError(
122
- # "FlyteFile and FlyteDirectory commands should not use the template syntax like this:
123
- # {{.inputs.infile}}\n"
124
- # "Please use a path-like syntax, such as: /var/inputs/infile.\n"
125
- # "This requirement is due to how Flyte Propeller processes template syntax inputs."
126
- # )
127
- # local_flyte_file_or_dir_path = str(input_val)
128
- # remote_flyte_file_or_dir_path = os.path.join(self._input_data_dir, k) # type: ignore
129
- # volume_binding[local_flyte_file_or_dir_path] = {
130
- # "bind": remote_flyte_file_or_dir_path,
131
- # "mode": "rw",
132
- # }
133
- # command = remote_flyte_file_or_dir_path
134
- command = str(input_val)
135
- else:
136
- command = cmd
110
+ keys = path_k if path_k else _extract_command_key(cmd)
111
+
112
+ if keys:
113
+ for k in keys:
114
+ input_val = kwargs.get(k)
115
+ # TODO: Add support file and directory transformer first
116
+ if type(input_val) in [File, Dir]:
117
+ if not path_k:
118
+ raise AssertionError(
119
+ "File and Directory commands should not use the template syntax "
120
+ "like this: {{.inputs.infile}}\n"
121
+ "Please use a path-like syntax, such as: /var/inputs/infile.\n"
122
+ "This requirement is due to how Flyte Propeller processes template syntax inputs."
123
+ )
124
+ local_flyte_file_or_dir_path = str(input_val)
125
+ remote_flyte_file_or_dir_path = os.path.join(self._input_data_dir, k) # type: ignore
126
+ volume_binding[local_flyte_file_or_dir_path] = {
127
+ "bind": remote_flyte_file_or_dir_path,
128
+ "mode": "rw",
129
+ }
130
+ command = remote_flyte_file_or_dir_path
131
+ else:
132
+ command = cmd
137
133
 
138
134
  return command, volume_binding
139
135
 
@@ -235,6 +231,7 @@ class ContainerTask(TaskTemplate):
235
231
  raise AssertionError(f"Only Image objects are supported, not strings. Got {self._image} instead.")
236
232
  uri = self._image.uri
237
233
  self._pull_image_if_not_exists(client, uri)
234
+ print(f"Command: {commands!r}")
238
235
 
239
236
  container = client.containers.run(uri, command=commands, remove=True, volumes=volume_bindings, detach=True)
240
237
 
flyte/io/__init__.py CHANGED
@@ -3,9 +3,25 @@
3
3
 
4
4
  This package contains additional data types beyond the primitive data types in python to abstract data flow
5
5
  of large datasets in Union.
6
+
6
7
  """
7
8
 
8
- __all__ = ["Dir", "File"]
9
+ __all__ = [
10
+ "Dir",
11
+ "File",
12
+ "StructuredDataset",
13
+ "StructuredDatasetDecoder",
14
+ "StructuredDatasetEncoder",
15
+ "StructuredDatasetTransformerEngine",
16
+ "lazy_import_structured_dataset_handler",
17
+ ]
9
18
 
10
19
  from ._dir import Dir
11
20
  from ._file import File
21
+ from ._structured_dataset import (
22
+ StructuredDataset,
23
+ StructuredDatasetDecoder,
24
+ StructuredDatasetEncoder,
25
+ StructuredDatasetTransformerEngine,
26
+ lazy_import_structured_dataset_handler,
27
+ )
flyte/io/_file.py CHANGED
@@ -232,6 +232,8 @@ class File(BaseModel, Generic[T], SerializableType):
232
232
  # This code is broadly similar to what storage.get_stream does, but without actually reading from the stream
233
233
  file_handle = None
234
234
  try:
235
+ if "b" not in mode:
236
+ raise ValueError("Mode must include 'b' for binary access, when using remote files.")
235
237
  if isinstance(fs, AsyncFileSystem):
236
238
  file_handle = await fs.open_async(self.path, mode)
237
239
  yield file_handle
@@ -9,7 +9,7 @@ from fsspec.core import split_protocol, strip_protocol
9
9
  import flyte.storage as storage
10
10
  from flyte._logging import logger
11
11
  from flyte._utils import lazy_module
12
- from flyte.io.structured_dataset.structured_dataset import (
12
+ from flyte.io._structured_dataset.structured_dataset import (
13
13
  CSV,
14
14
  PARQUET,
15
15
  StructuredDataset,
@@ -168,7 +168,7 @@ class StructuredDataset(SerializableType, DataClassJSONMixin):
168
168
  return self._literal_sd
169
169
 
170
170
  def open(self, dataframe_type: Type[DF]):
171
- from flyte.io.structured_dataset import lazy_import_structured_dataset_handler
171
+ from flyte.io._structured_dataset import lazy_import_structured_dataset_handler
172
172
 
173
173
  """
174
174
  Load the handler if needed. For the use case like:
flyte/models.py CHANGED
@@ -5,7 +5,9 @@ import os
5
5
  import pathlib
6
6
  import tempfile
7
7
  from dataclasses import dataclass, field, replace
8
- from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type
8
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Literal, Optional, Tuple, Type
9
+
10
+ import rich.repr
9
11
 
10
12
  from flyte._docstring import Docstring
11
13
  from flyte._interface import extract_return_annotation
@@ -27,6 +29,7 @@ def generate_random_name() -> str:
27
29
  return str(uuid4()) # Placeholder for actual random name generation logic
28
30
 
29
31
 
32
+ @rich.repr.auto
30
33
  @dataclass(frozen=True, kw_only=True)
31
34
  class ActionID:
32
35
  """
@@ -68,6 +71,7 @@ class ActionID:
68
71
  return self.new_sub_action(new_name)
69
72
 
70
73
 
74
+ @rich.repr.auto
71
75
  @dataclass(frozen=True, kw_only=True)
72
76
  class RawDataPath:
73
77
  """
@@ -91,6 +95,7 @@ class RawDataPath:
91
95
  # Create a temporary directory for data storage
92
96
  p = tempfile.mkdtemp()
93
97
  logger.debug(f"Creating temporary directory for data storage: {p}")
98
+ pathlib.Path(p).mkdir(parents=True, exist_ok=True)
94
99
  return RawDataPath(path=p)
95
100
  case str():
96
101
  return RawDataPath(path=local_folder)
@@ -133,11 +138,13 @@ class RawDataPath:
133
138
  return remote_path
134
139
 
135
140
 
141
+ @rich.repr.auto
136
142
  @dataclass(frozen=True)
137
143
  class GroupData:
138
144
  name: str
139
145
 
140
146
 
147
+ @rich.repr.auto
141
148
  @dataclass(frozen=True, kw_only=True)
142
149
  class TaskContext:
143
150
  """
@@ -160,6 +167,7 @@ class TaskContext:
160
167
  code_bundle: CodeBundle | None = None
161
168
  compiled_image_cache: ImageCache | None = None
162
169
  data: Dict[str, Any] = field(default_factory=dict)
170
+ mode: Literal["local", "remote", "hybrid"] = "remote"
163
171
 
164
172
  def replace(self, **kwargs) -> TaskContext:
165
173
  if "data" in kwargs:
@@ -177,6 +185,7 @@ class TaskContext:
177
185
  return self.data.get(key)
178
186
 
179
187
 
188
+ @rich.repr.auto
180
189
  @dataclass(frozen=True, kw_only=True)
181
190
  class CodeBundle:
182
191
  """
@@ -211,6 +220,7 @@ class CodeBundle:
211
220
  return replace(self, downloaded_path=path)
212
221
 
213
222
 
223
+ @rich.repr.auto
214
224
  @dataclass(frozen=True)
215
225
  class Checkpoints:
216
226
  """
flyte/syncify/_api.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
3
3
  import asyncio
4
4
  import atexit
5
5
  import concurrent.futures
6
+ import functools
6
7
  import inspect
7
8
  import threading
8
9
  from typing import (
@@ -14,7 +15,6 @@ from typing import (
14
15
  Iterator,
15
16
  ParamSpec,
16
17
  Protocol,
17
- Type,
18
18
  TypeVar,
19
19
  Union,
20
20
  cast,
@@ -92,7 +92,9 @@ class _BackgroundLoop:
92
92
 
93
93
  def iterate_in_loop_sync(self, async_gen: AsyncIterator[R_co]) -> Iterator[R_co]:
94
94
  # Create an iterator that pulls items from the async generator
95
- assert self.thread.name != threading.current_thread().name, "Cannot run coroutine in the same thread"
95
+ assert self.thread.name != threading.current_thread().name, (
96
+ f"Cannot run coroutine in the same thread {self.thread.name}"
97
+ )
96
98
  while True:
97
99
  try:
98
100
  # use __anext__() and cast to Coroutine so mypy is happy
@@ -104,12 +106,19 @@ class _BackgroundLoop:
104
106
  except (StopAsyncIteration, StopIteration):
105
107
  break
106
108
 
107
- def call_in_loop_sync(self, coro: Coroutine[Any, Any, R_co]) -> R_co:
109
+ def call_in_loop_sync(self, coro: Coroutine[Any, Any, R_co]) -> R_co | Iterator[R_co]:
108
110
  """
109
111
  Run the given coroutine in the background loop and return its result.
110
112
  """
111
- future: concurrent.futures.Future[R_co] = asyncio.run_coroutine_threadsafe(coro, self.loop)
112
- return future.result()
113
+ future: concurrent.futures.Future[R_co | AsyncIterator[R_co]] = asyncio.run_coroutine_threadsafe(
114
+ coro, self.loop
115
+ )
116
+ result = future.result()
117
+ if result is not None and hasattr(result, "__aiter__"):
118
+ # If the result is an async iterator, we need to convert it to a sync iterator
119
+ return cast(Iterator[R_co], self.iterate_in_loop_sync(cast(AsyncIterator[R_co], result)))
120
+ # Otherwise, just return the result
121
+ return result
113
122
 
114
123
  async def iterate_in_loop(self, async_gen: AsyncIterator[R_co]) -> AsyncIterator[R_co]:
115
124
  """
@@ -131,7 +140,9 @@ class _BackgroundLoop:
131
140
  # Wrap the future in an asyncio Future to yield it in an async context
132
141
  aio_future: asyncio.Future[R_co] = asyncio.wrap_future(future)
133
142
  # await for the future to complete and yield its result
134
- yield await aio_future
143
+ v = await aio_future
144
+ print(f"Yielding value: {v}")
145
+ yield v
135
146
  except StopAsyncIteration:
136
147
  break
137
148
 
@@ -159,8 +170,6 @@ class _SyncWrapper:
159
170
  self,
160
171
  fn: Any,
161
172
  bg_loop: _BackgroundLoop,
162
- instance: Any = None,
163
- owner: Type | None = None,
164
173
  underlying_obj: Any = None,
165
174
  ):
166
175
  self.fn = fn
@@ -168,6 +177,12 @@ class _SyncWrapper:
168
177
  self._underlying_obj = underlying_obj
169
178
 
170
179
  def __call__(self, *args: Any, **kwargs: Any) -> Any:
180
+ if threading.current_thread().name == self._bg_loop.thread.name:
181
+ # If we are already in the background loop thread, we can call the function directly
182
+ raise AssertionError(
183
+ f"Deadlock detected: blocking call used in syncify thread {self._bg_loop.thread.name} "
184
+ f"when calling function {self.fn}, use .aio() if in an async call."
185
+ )
171
186
  # bind method if needed
172
187
  coro_fn = self.fn
173
188
 
@@ -194,7 +209,9 @@ class _SyncWrapper:
194
209
  # If we have an owner, we need to bind the method to the owner (for classmethods or staticmethods)
195
210
  fn = self._underlying_obj.__get__(None, owner)
196
211
 
197
- return _SyncWrapper(fn, bg_loop=self._bg_loop, underlying_obj=self._underlying_obj)
212
+ wrapper = _SyncWrapper(fn, bg_loop=self._bg_loop, underlying_obj=self._underlying_obj)
213
+ functools.update_wrapper(wrapper, self.fn)
214
+ return wrapper
198
215
 
199
216
  def aio(self, *args: Any, **kwargs: Any) -> Any:
200
217
  fn = self.fn
@@ -262,15 +279,26 @@ class Syncify:
262
279
 
263
280
  def __call__(self, obj):
264
281
  if isinstance(obj, classmethod):
265
- return _SyncWrapper(obj.__func__, bg_loop=self._bg_loop, underlying_obj=obj)
282
+ wrapper = _SyncWrapper(obj.__func__, bg_loop=self._bg_loop, underlying_obj=obj)
283
+ functools.update_wrapper(wrapper, obj.__func__)
284
+ return wrapper
285
+
266
286
  if isinstance(obj, staticmethod):
267
- return staticmethod(cast(Any, _SyncWrapper(obj.__func__, bg_loop=self._bg_loop)))
287
+ fn = obj.__func__
288
+ wrapper = _SyncWrapper(fn, bg_loop=self._bg_loop)
289
+ functools.update_wrapper(wrapper, fn)
290
+ return staticmethod(wrapper)
291
+
268
292
  if inspect.isasyncgenfunction(obj):
269
- # If the function is an async generator, we need to handle it differently
270
- return cast(Callable[P, Iterator[R_co]], _SyncWrapper(obj, bg_loop=self._bg_loop))
293
+ wrapper = _SyncWrapper(obj, bg_loop=self._bg_loop)
294
+ functools.update_wrapper(wrapper, obj)
295
+ return cast(Callable[P, Iterator[R_co]], wrapper)
296
+
271
297
  if inspect.iscoroutinefunction(obj):
272
- # If the function is a coroutine, we can wrap it directly
273
- return _SyncWrapper(obj, bg_loop=self._bg_loop)
298
+ wrapper = _SyncWrapper(obj, bg_loop=self._bg_loop)
299
+ functools.update_wrapper(wrapper, obj)
300
+ return wrapper
301
+
274
302
  raise TypeError(
275
303
  "Syncify can only be applied to async functions, async generators, async classmethods or staticmethods."
276
304
  )