flyte 0.2.0b12__py3-none-any.whl → 0.2.0b14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

@@ -15,15 +15,15 @@ class StateServiceStub(object):
15
15
  Args:
16
16
  channel: A grpc.Channel.
17
17
  """
18
- self.Store = channel.unary_unary(
19
- '/cloudidl.workflow.StateService/Store',
20
- request_serializer=workflow_dot_state__service__pb2.StoreRequest.SerializeToString,
21
- response_deserializer=workflow_dot_state__service__pb2.StoreResponse.FromString,
18
+ self.Put = channel.stream_stream(
19
+ '/cloudidl.workflow.StateService/Put',
20
+ request_serializer=workflow_dot_state__service__pb2.PutRequest.SerializeToString,
21
+ response_deserializer=workflow_dot_state__service__pb2.PutResponse.FromString,
22
22
  )
23
- self.Load = channel.unary_unary(
24
- '/cloudidl.workflow.StateService/Load',
25
- request_serializer=workflow_dot_state__service__pb2.LoadRequest.SerializeToString,
26
- response_deserializer=workflow_dot_state__service__pb2.LoadResponse.FromString,
23
+ self.Get = channel.stream_stream(
24
+ '/cloudidl.workflow.StateService/Get',
25
+ request_serializer=workflow_dot_state__service__pb2.GetRequest.SerializeToString,
26
+ response_deserializer=workflow_dot_state__service__pb2.GetResponse.FromString,
27
27
  )
28
28
  self.Watch = channel.unary_stream(
29
29
  '/cloudidl.workflow.StateService/Watch',
@@ -36,15 +36,15 @@ class StateServiceServicer(object):
36
36
  """provides an interface for managing the state of actions.
37
37
  """
38
38
 
39
- def Store(self, request, context):
40
- """store the state of an action.
39
+ def Put(self, request_iterator, context):
40
+ """put the state of an action.
41
41
  """
42
42
  context.set_code(grpc.StatusCode.UNIMPLEMENTED)
43
43
  context.set_details('Method not implemented!')
44
44
  raise NotImplementedError('Method not implemented!')
45
45
 
46
- def Load(self, request, context):
47
- """load the state of an action.
46
+ def Get(self, request_iterator, context):
47
+ """get the state of an action.
48
48
  """
49
49
  context.set_code(grpc.StatusCode.UNIMPLEMENTED)
50
50
  context.set_details('Method not implemented!')
@@ -60,15 +60,15 @@ class StateServiceServicer(object):
60
60
 
61
61
  def add_StateServiceServicer_to_server(servicer, server):
62
62
  rpc_method_handlers = {
63
- 'Store': grpc.unary_unary_rpc_method_handler(
64
- servicer.Store,
65
- request_deserializer=workflow_dot_state__service__pb2.StoreRequest.FromString,
66
- response_serializer=workflow_dot_state__service__pb2.StoreResponse.SerializeToString,
63
+ 'Put': grpc.stream_stream_rpc_method_handler(
64
+ servicer.Put,
65
+ request_deserializer=workflow_dot_state__service__pb2.PutRequest.FromString,
66
+ response_serializer=workflow_dot_state__service__pb2.PutResponse.SerializeToString,
67
67
  ),
68
- 'Load': grpc.unary_unary_rpc_method_handler(
69
- servicer.Load,
70
- request_deserializer=workflow_dot_state__service__pb2.LoadRequest.FromString,
71
- response_serializer=workflow_dot_state__service__pb2.LoadResponse.SerializeToString,
68
+ 'Get': grpc.stream_stream_rpc_method_handler(
69
+ servicer.Get,
70
+ request_deserializer=workflow_dot_state__service__pb2.GetRequest.FromString,
71
+ response_serializer=workflow_dot_state__service__pb2.GetResponse.SerializeToString,
72
72
  ),
73
73
  'Watch': grpc.unary_stream_rpc_method_handler(
74
74
  servicer.Watch,
@@ -87,7 +87,7 @@ class StateService(object):
87
87
  """
88
88
 
89
89
  @staticmethod
90
- def Store(request,
90
+ def Put(request_iterator,
91
91
  target,
92
92
  options=(),
93
93
  channel_credentials=None,
@@ -97,14 +97,14 @@ class StateService(object):
97
97
  wait_for_ready=None,
98
98
  timeout=None,
99
99
  metadata=None):
100
- return grpc.experimental.unary_unary(request, target, '/cloudidl.workflow.StateService/Store',
101
- workflow_dot_state__service__pb2.StoreRequest.SerializeToString,
102
- workflow_dot_state__service__pb2.StoreResponse.FromString,
100
+ return grpc.experimental.stream_stream(request_iterator, target, '/cloudidl.workflow.StateService/Put',
101
+ workflow_dot_state__service__pb2.PutRequest.SerializeToString,
102
+ workflow_dot_state__service__pb2.PutResponse.FromString,
103
103
  options, channel_credentials,
104
104
  insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
105
105
 
106
106
  @staticmethod
107
- def Load(request,
107
+ def Get(request_iterator,
108
108
  target,
109
109
  options=(),
110
110
  channel_credentials=None,
@@ -114,9 +114,9 @@ class StateService(object):
114
114
  wait_for_ready=None,
115
115
  timeout=None,
116
116
  metadata=None):
117
- return grpc.experimental.unary_unary(request, target, '/cloudidl.workflow.StateService/Load',
118
- workflow_dot_state__service__pb2.LoadRequest.SerializeToString,
119
- workflow_dot_state__service__pb2.LoadResponse.FromString,
117
+ return grpc.experimental.stream_stream(request_iterator, target, '/cloudidl.workflow.StateService/Get',
118
+ workflow_dot_state__service__pb2.GetRequest.SerializeToString,
119
+ workflow_dot_state__service__pb2.GetResponse.FromString,
120
120
  options, channel_credentials,
121
121
  insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
122
122
 
flyte/_run.py CHANGED
@@ -63,6 +63,7 @@ class _Runner:
63
63
  raw_data_path: str | None = None,
64
64
  metadata_path: str | None = None,
65
65
  run_base_dir: str | None = None,
66
+ overwrite_cache: bool = False,
66
67
  ):
67
68
  init_config = _get_init_config()
68
69
  client = init_config.client if init_config else None
@@ -81,6 +82,7 @@ class _Runner:
81
82
  self._raw_data_path = raw_data_path
82
83
  self._metadata_path = metadata_path or "/tmp"
83
84
  self._run_base_dir = run_base_dir or "/tmp/base"
85
+ self._overwrite_cache = overwrite_cache
84
86
 
85
87
  @requires_initialization
86
88
  async def _run_remote(self, obj: TaskTemplate[P, R] | LazyEntity, *args: P.args, **kwargs: P.kwargs) -> Run:
@@ -182,6 +184,9 @@ class _Runner:
182
184
  project_id=project_id,
183
185
  task_spec=task_spec,
184
186
  inputs=inputs.proto_inputs,
187
+ run_spec=run_definition_pb2.RunSpec(
188
+ overwrite_cache=self._overwrite_cache,
189
+ ),
185
190
  ),
186
191
  )
187
192
  return Run(pb2=resp.run)
@@ -414,6 +419,7 @@ def with_runcontext(
414
419
  interactive_mode: bool | None = None,
415
420
  raw_data_path: str | None = None,
416
421
  run_base_dir: str | None = None,
422
+ overwrite_cache: bool = False,
417
423
  ) -> _Runner:
418
424
  """
419
425
  Launch a new run with the given parameters as the context.
flyte/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.2.0b12'
21
- __version_tuple__ = version_tuple = (0, 2, 0, 'b12')
20
+ __version__ = version = '0.2.0b14'
21
+ __version_tuple__ = version_tuple = (0, 2, 0, 'b14')
flyte/cli/_common.py CHANGED
@@ -137,8 +137,8 @@ class InvokeBaseMixin:
137
137
  if e.code() == grpc.StatusCode.INVALID_ARGUMENT:
138
138
  raise click.ClickException(f"Invalid argument provided. Please check your input. Error: {e.details()}")
139
139
  raise click.ClickException(f"RPC error invoking command: {e!s}") from e
140
- except flyte.errors.InitializationError:
141
- raise click.ClickException("Initialize the CLI with a remote configuration. For example, pass --endpoint")
140
+ except flyte.errors.InitializationError as e:
141
+ raise click.ClickException(f"Initialization failed. Pass remote config for CLI. (Reason: {e})")
142
142
  except flyte.errors.BaseRuntimeError as e:
143
143
  raise click.ClickException(f"{e.kind} failure, {e.code}. {e}") from e
144
144
  except click.exceptions.Exit as e:
flyte/cli/_delete.py CHANGED
@@ -10,7 +10,7 @@ def delete():
10
10
  """
11
11
 
12
12
 
13
- @click.command(cls=common.CommandBase)
13
+ @delete.command(cls=common.CommandBase)
14
14
  @click.argument("name", type=str, required=True)
15
15
  @click.pass_obj
16
16
  def secret(cfg: common.CLIConfig, name: str, project: str | None = None, domain: str | None = None):
flyte/cli/_deploy.py CHANGED
@@ -147,6 +147,6 @@ deploy = EnvFiles(
147
147
  name="deploy",
148
148
  help="""
149
149
  Deploy one or more environments from a python file.
150
- The deploy command will create or update environments in the Flyte system.
150
+ This command will create or update environments in the Flyte system.
151
151
  """,
152
152
  )
flyte/cli/_get.py CHANGED
@@ -20,7 +20,7 @@ def get():
20
20
  Using a `get` subcommand without any arguments will retrieve a list of available resources to get.
21
21
  For example:
22
22
 
23
- * `get project` (without specifiying aproject), will list all projects.
23
+ * `get project` (without specifying a project), will list all projects.
24
24
  * `get project my_project` will return the details of the project named `my_project`.
25
25
 
26
26
  In some cases, a partially specified command will act as a filter and return available further parameters.
@@ -143,7 +143,7 @@ def action(
143
143
  "--pretty",
144
144
  is_flag=True,
145
145
  default=False,
146
- help="Show logs in a auto scrolling box, where number of lines is limited to `--lines`",
146
+ help="Show logs in an auto-scrolling box, where number of lines is limited to `--lines`",
147
147
  )
148
148
  @click.option(
149
149
  "--attempt", "-a", type=int, default=None, help="Attempt number to show logs for, defaults to the latest attempt."
flyte/cli/_run.py CHANGED
@@ -66,7 +66,7 @@ class RunArguments:
66
66
  ["--follow", "-f"],
67
67
  is_flag=True,
68
68
  default=False,
69
- help="Wait and watch logs for the parent action. If not provided, the cli will exit after "
69
+ help="Wait and watch logs for the parent action. If not provided, the CLI will exit after "
70
70
  "successfully launching a remote execution with a link to the UI.",
71
71
  )
72
72
  },
@@ -99,8 +99,6 @@ class RunTaskCommand(click.Command):
99
99
 
100
100
  obj = CLIConfig(flyte.config.auto(), ctx)
101
101
 
102
- if not self.run_args.local:
103
- assert obj.endpoint, "CLI Config should have an endpoint"
104
102
  obj.init(self.run_args.project, self.run_args.domain)
105
103
 
106
104
  async def _run():
@@ -108,7 +106,6 @@ class RunTaskCommand(click.Command):
108
106
 
109
107
  r = flyte.with_runcontext(
110
108
  copy_style=self.run_args.copy_style,
111
- version=self.run_args.copy_style,
112
109
  mode="local" if self.run_args.local else "remote",
113
110
  name=self.run_args.name,
114
111
  ).run(self.obj, **ctx.params)
flyte/cli/main.py CHANGED
@@ -5,31 +5,31 @@ from flyte._logging import initialize_logger, logger
5
5
  from ._abort import abort
6
6
  from ._common import CLIConfig
7
7
  from ._create import create
8
+ from ._delete import delete
8
9
  from ._deploy import deploy
9
10
  from ._gen import gen
10
11
  from ._get import get
11
12
  from ._run import run
12
13
 
13
- click.rich_click.COMMAND_GROUPS = {
14
- "flyte": [
15
- {
16
- "name": "Running workflows",
17
- "commands": ["run", "abort"],
18
- },
19
- {
20
- "name": "Management",
21
- "commands": ["create", "deploy", "get"],
22
- },
23
- {
24
- "name": "Documentation generation",
25
- "commands": ["gen"],
26
- },
27
- ]
28
- }
29
-
30
14
  help_config = click.RichHelpConfiguration(
31
15
  use_markdown=True,
32
16
  use_markdown_emoji=True,
17
+ command_groups={
18
+ "flyte": [
19
+ {
20
+ "name": "Run and stop tasks",
21
+ "commands": ["run", "abort"],
22
+ },
23
+ {
24
+ "name": "Management",
25
+ "commands": ["create", "deploy", "get", "delete"],
26
+ },
27
+ {
28
+ "name": "Documentation generation",
29
+ "commands": ["gen"],
30
+ },
31
+ ]
32
+ },
33
33
  )
34
34
 
35
35
 
@@ -103,7 +103,7 @@ def main(
103
103
  config_file: str | None,
104
104
  ):
105
105
  """
106
- The Flyte CLI is the the command line interface for working with the Flyte SDK and backend.
106
+ The Flyte CLI is the command line interface for working with the Flyte SDK and backend.
107
107
 
108
108
  It follows a simple verb/noun structure,
109
109
  where the top-level commands are verbs that describe the action to be taken,
@@ -163,3 +163,4 @@ main.add_command(get) # type: ignore
163
163
  main.add_command(create) # type: ignore
164
164
  main.add_command(abort) # type: ignore
165
165
  main.add_command(gen) # type: ignore
166
+ main.add_command(delete) # type: ignore
@@ -1,3 +1,4 @@
1
+ import os
1
2
  import ssl
2
3
  import typing
3
4
 
@@ -15,6 +16,11 @@ from ._authenticators.factory import (
15
16
  get_async_proxy_authenticator,
16
17
  )
17
18
 
19
+ # Set environment variables for gRPC, this reduces log spew and avoids unnecessary warnings
20
+ if "GRPC_VERBOSITY" not in os.environ:
21
+ os.environ["GRPC_VERBOSITY"] = "ERROR"
22
+ os.environ["GRPC_CPP_MIN_LOG_LEVEL"] = "ERROR"
23
+
18
24
  # Initialize gRPC AIO early enough so it can be used in the main thread
19
25
  init_grpc_aio()
20
26
 
flyte/remote/_data.py CHANGED
@@ -16,7 +16,7 @@ from flyteidl.service import dataproxy_pb2
16
16
  from google.protobuf import duration_pb2
17
17
 
18
18
  from flyte._initialize import CommonInit, ensure_client, get_client, get_common_config
19
- from flyte.errors import RuntimeSystemError
19
+ from flyte.errors import InitializationError, RuntimeSystemError
20
20
 
21
21
  _UPLOAD_EXPIRES_IN = timedelta(seconds=60)
22
22
 
@@ -83,6 +83,8 @@ async def _upload_single_file(
83
83
  raise RuntimeSystemError(
84
84
  "PermissionDenied", f"Failed to get signed url for {fp}, please check your permissions."
85
85
  )
86
+ elif e.code() == grpc.StatusCode.UNAVAILABLE:
87
+ raise InitializationError("EndpointUnavailable", "user", "Service is unavailable.")
86
88
  else:
87
89
  raise RuntimeSystemError(e.code().value, f"Failed to get signed url for {fp}.")
88
90
  except Exception as e:
flyte/syncify/_api.py CHANGED
@@ -5,6 +5,7 @@ import atexit
5
5
  import concurrent.futures
6
6
  import functools
7
7
  import inspect
8
+ import logging
8
9
  import threading
9
10
  from typing import (
10
11
  Any,
@@ -21,6 +22,8 @@ from typing import (
21
22
  overload,
22
23
  )
23
24
 
25
+ from flyte._logging import logger
26
+
24
27
  P = ParamSpec("P")
25
28
  R_co = TypeVar("R_co", covariant=True)
26
29
  T = TypeVar("T")
@@ -50,7 +53,7 @@ class SyncGenFunction(Protocol[P, R_co]):
50
53
 
51
54
  class _BackgroundLoop:
52
55
  """
53
- A background event loop that runs in a separate thread and used the the Syncify decorator to run asynchronous
56
+ A background event loop that runs in a separate thread and used the `Syncify` decorator to run asynchronous
54
57
  functions or methods synchronously.
55
58
  """
56
59
 
@@ -105,6 +108,19 @@ class _BackgroundLoop:
105
108
  yield future.result()
106
109
  except (StopAsyncIteration, StopIteration):
107
110
  break
111
+ except Exception as e:
112
+ if logger.getEffectiveLevel() > logging.DEBUG:
113
+ # If the log level is not DEBUG, we will remove the extra stack frames to avoid confusion for the
114
+ # user
115
+ # This is because the stack trace will include the Syncify wrapper and the background loop thread
116
+ tb = e.__traceback__
117
+ while tb and tb.tb_next:
118
+ if tb.tb_frame.f_code.co_name == "":
119
+ break
120
+ tb = tb.tb_next
121
+ raise e.with_traceback(tb)
122
+ # If the log level is DEBUG, we will keep the extra stack frames to help with debugging
123
+ raise e
108
124
 
109
125
  def call_in_loop_sync(self, coro: Coroutine[Any, Any, R_co]) -> R_co | Iterator[R_co]:
110
126
  """
@@ -144,6 +160,19 @@ class _BackgroundLoop:
144
160
  yield v
145
161
  except StopAsyncIteration:
146
162
  break
163
+ except Exception as e:
164
+ if logger.getEffectiveLevel() > logging.DEBUG:
165
+ # If the log level is not DEBUG, we will remove the extra stack frames to avoid confusion for the
166
+ # user.
167
+ # This is because the stack trace will include the Syncify wrapper and the background loop thread
168
+ tb = e.__traceback__
169
+ while tb and tb.tb_next:
170
+ if tb.tb_frame.f_code.co_name == "":
171
+ break
172
+ tb = tb.tb_next
173
+ raise e.with_traceback(tb)
174
+ # If the log level is DEBUG, we will keep the extra stack frames to help with debugging
175
+ raise e
147
176
 
148
177
  async def aio(self, coro: Coroutine[Any, Any, R_co]) -> R_co:
149
178
  """
@@ -152,12 +181,25 @@ class _BackgroundLoop:
152
181
  if self.is_in_loop():
153
182
  # If we are already in the background loop, just run the coroutine
154
183
  return await coro
155
- # Otherwise, run it in the background loop and wait for the result
156
- future: concurrent.futures.Future[R_co] = asyncio.run_coroutine_threadsafe(coro, self.loop)
157
- # Wrap the future in an asyncio Future to await it in an async context
158
- aio_future: asyncio.Future[R_co] = asyncio.wrap_future(future)
159
- # await for the future to complete and return its result
160
- return await aio_future
184
+ try:
185
+ # Otherwise, run it in the background loop and wait for the result
186
+ future: concurrent.futures.Future[R_co] = asyncio.run_coroutine_threadsafe(coro, self.loop)
187
+ # Wrap the future in an asyncio Future to await it in an async context
188
+ aio_future: asyncio.Future[R_co] = asyncio.wrap_future(future)
189
+ # await for the future to complete and return its result
190
+ return await aio_future
191
+ except Exception as e:
192
+ if logger.getEffectiveLevel() > logging.DEBUG:
193
+ # If the log level is not DEBUG, we will remove the extra stack frames to avoid confusion for the user
194
+ # This is because the stack trace will include the Syncify wrapper and the background loop thread
195
+ tb = e.__traceback__
196
+ while tb and tb.tb_next:
197
+ if tb.tb_frame.f_code.co_name == "":
198
+ break
199
+ tb = tb.tb_next
200
+ raise e.with_traceback(tb)
201
+ # If the log level is DEBUG, we will keep the extra stack frames to help with debugging
202
+ raise e
161
203
 
162
204
 
163
205
  class _SyncWrapper:
@@ -175,23 +217,6 @@ class _SyncWrapper:
175
217
  self._bg_loop = bg_loop
176
218
  self._underlying_obj = underlying_obj
177
219
 
178
- def __call__(self, *args: Any, **kwargs: Any) -> Any:
179
- if threading.current_thread().name == self._bg_loop.thread.name:
180
- # If we are already in the background loop thread, we can call the function directly
181
- raise AssertionError(
182
- f"Deadlock detected: blocking call used in syncify thread {self._bg_loop.thread.name} "
183
- f"when calling function {self.fn}, use .aio() if in an async call."
184
- )
185
- # bind method if needed
186
- coro_fn = self.fn
187
-
188
- if inspect.isasyncgenfunction(coro_fn):
189
- # Handle async iterator by converting to sync iterator
190
- async_gen = coro_fn(*args, **kwargs)
191
- return self._bg_loop.iterate_in_loop_sync(async_gen)
192
- else:
193
- return self._bg_loop.call_in_loop_sync(coro_fn(*args, **kwargs))
194
-
195
220
  def __get__(self, instance: Any, owner: Any) -> Any:
196
221
  """
197
222
  This method is called when the wrapper is accessed as a method of a class instance.
@@ -212,20 +237,63 @@ class _SyncWrapper:
212
237
  functools.update_wrapper(wrapper, self.fn)
213
238
  return wrapper
214
239
 
240
+ def __call__(self, *args: Any, **kwargs: Any) -> Any:
241
+ if threading.current_thread().name == self._bg_loop.thread.name:
242
+ # If we are already in the background loop thread, we can call the function directly
243
+ raise AssertionError(
244
+ f"Deadlock detected: blocking call used in syncify thread {self._bg_loop.thread.name} "
245
+ f"when calling function {self.fn}, use .aio() if in an async call."
246
+ )
247
+ try:
248
+ # bind method if needed
249
+ coro_fn = self.fn
250
+
251
+ if inspect.isasyncgenfunction(coro_fn):
252
+ # Handle async iterator by converting to sync iterator
253
+ async_gen = coro_fn(*args, **kwargs)
254
+ return self._bg_loop.iterate_in_loop_sync(async_gen)
255
+ else:
256
+ return self._bg_loop.call_in_loop_sync(coro_fn(*args, **kwargs))
257
+ except Exception as e:
258
+ if logger.getEffectiveLevel() > logging.DEBUG:
259
+ # If the log level is not DEBUG, we will remove the extra stack frames to avoid confusion for the user
260
+ # This is because the stack trace will include the Syncify wrapper and the background loop thread
261
+ tb = e.__traceback__
262
+ while tb and tb.tb_next:
263
+ if tb.tb_frame.f_code.co_name == self.fn.__name__:
264
+ break
265
+ tb = tb.tb_next
266
+ raise e.with_traceback(tb)
267
+ # If the log level is DEBUG, we will keep the extra stack frames to help with debugging
268
+ raise e
269
+
215
270
  def aio(self, *args: Any, **kwargs: Any) -> Any:
216
271
  fn = self.fn
217
272
 
218
- if inspect.isasyncgenfunction(fn):
219
- # If the function is an async generator, we need to handle it differently
220
- async_iter = fn(*args, **kwargs)
221
- return self._bg_loop.iterate_in_loop(async_iter)
222
- else:
223
- # If we are already in the background loop, just return the coroutine
224
- coro = fn(*args, **kwargs)
225
- if hasattr(coro, "__aiter__"):
226
- # If the coroutine is an async iterator, we need to handle it differently
227
- return self._bg_loop.iterate_in_loop(coro)
228
- return self._bg_loop.aio(coro)
273
+ try:
274
+ if inspect.isasyncgenfunction(fn):
275
+ # If the function is an async generator, we need to handle it differently
276
+ async_iter = fn(*args, **kwargs)
277
+ return self._bg_loop.iterate_in_loop(async_iter)
278
+ else:
279
+ # If we are already in the background loop, just return the coroutine
280
+ coro = fn(*args, **kwargs)
281
+ if hasattr(coro, "__aiter__"):
282
+ # If the coroutine is an async iterator, we need to handle it differently
283
+ return self._bg_loop.iterate_in_loop(coro)
284
+ return self._bg_loop.aio(coro)
285
+ except Exception as e:
286
+ if logger.getEffectiveLevel() > logging.DEBUG:
287
+ # If the log level is not DEBUG, we will remove the extra stack frames to avoid confusion for the user
288
+ # This is because the stack trace will include the Syncify wrapper and the background loop thread
289
+ tb = e.__traceback__
290
+ while tb and tb.tb_next:
291
+ if tb.tb_frame.f_code.co_name == self.fn.__name__:
292
+ break
293
+ tb = tb.tb_next
294
+ raise e.with_traceback(tb)
295
+ # If the log level is DEBUG, we will keep the extra stack frames to help with debugging
296
+ raise e
229
297
 
230
298
 
231
299
  class Syncify:
@@ -35,6 +35,7 @@ from mashumaro.jsonschema.models import Context, JSONSchema
35
35
  from mashumaro.jsonschema.plugins import BasePlugin
36
36
  from mashumaro.jsonschema.schema import Instance
37
37
  from mashumaro.mixins.json import DataClassJSONMixin
38
+ from pydantic import BaseModel
38
39
  from typing_extensions import Annotated, get_args, get_origin
39
40
 
40
41
  import flyte.storage as storage
@@ -352,6 +353,79 @@ class RestrictedTypeTransformer(TypeTransformer[T], ABC):
352
353
  raise RestrictedTypeError(f"Transformer for type {self.python_type} is restricted currently")
353
354
 
354
355
 
356
+ class PydanticTransformer(TypeTransformer[BaseModel]):
357
+ def __init__(self):
358
+ super().__init__("Pydantic Transformer", BaseModel, enable_type_assertions=False)
359
+
360
+ def get_literal_type(self, t: Type[BaseModel]) -> LiteralType:
361
+ schema = t.model_json_schema()
362
+ fields = t.__annotations__.items()
363
+
364
+ literal_type = {}
365
+ for name, python_type in fields:
366
+ try:
367
+ literal_type[name] = TypeEngine.to_literal_type(python_type)
368
+ except Exception as e:
369
+ logger.warning(
370
+ "Field {} of type {} cannot be converted to a literal type. Error: {}".format(name, python_type, e)
371
+ )
372
+
373
+ # This is for attribute access in FlytePropeller.
374
+ ts = TypeStructure(tag="", dataclass_type=literal_type)
375
+
376
+ meta_struct = struct_pb2.Struct()
377
+ meta_struct.update(
378
+ {
379
+ CACHE_KEY_METADATA: {
380
+ SERIALIZATION_FORMAT: MESSAGEPACK,
381
+ }
382
+ }
383
+ )
384
+
385
+ return LiteralType(
386
+ simple=SimpleType.STRUCT,
387
+ metadata=schema,
388
+ structure=ts,
389
+ annotation=TypeAnnotation(annotations=meta_struct),
390
+ )
391
+
392
+ async def to_literal(
393
+ self,
394
+ python_val: BaseModel,
395
+ python_type: Type[BaseModel],
396
+ expected: LiteralType,
397
+ ) -> Literal:
398
+ json_str = python_val.model_dump_json()
399
+ dict_obj = json.loads(json_str)
400
+ msgpack_bytes = msgpack.dumps(dict_obj)
401
+ return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag=MESSAGEPACK)))
402
+
403
+ def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[BaseModel]) -> BaseModel:
404
+ if binary_idl_object.tag == MESSAGEPACK:
405
+ dict_obj = msgpack.loads(binary_idl_object.value, strict_map_key=False)
406
+ json_str = json.dumps(dict_obj)
407
+ python_val = expected_python_type.model_validate_json(
408
+ json_data=json_str, strict=False, context={"deserialize": True}
409
+ )
410
+ return python_val
411
+ else:
412
+ raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")
413
+
414
+ async def to_python_value(self, lv: Literal, expected_python_type: Type[BaseModel]) -> BaseModel:
415
+ """
416
+ There are two kinds of literal values to handle:
417
+ 1. Protobuf Structs (from the UI)
418
+ 2. Binary scalars (from other sources)
419
+ We need to account for both cases accordingly.
420
+ """
421
+ if lv and lv.HasField("scalar") and lv.scalar.HasField("binary"):
422
+ return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore
423
+
424
+ json_str = _json_format.MessageToJson(lv.scalar.generic)
425
+ python_val = expected_python_type.model_validate_json(json_str, strict=False, context={"deserialize": True})
426
+ return python_val
427
+
428
+
355
429
  class PydanticSchemaPlugin(BasePlugin):
356
430
  """This allows us to generate proper schemas for Pydantic models."""
357
431
 
@@ -562,9 +636,8 @@ class DataclassTransformer(TypeTransformer[object]):
562
636
 
563
637
  # This is for attribute access in FlytePropeller.
564
638
  ts = TypeStructure(tag="", dataclass_type=literal_type)
565
- from google.protobuf.struct_pb2 import Struct
566
639
 
567
- meta_struct = Struct()
640
+ meta_struct = struct_pb2.Struct()
568
641
  meta_struct.update(
569
642
  {
570
643
  CACHE_KEY_METADATA: {
@@ -627,7 +700,7 @@ class DataclassTransformer(TypeTransformer[object]):
627
700
  field.type = self._get_origin_type_in_annotation(cast(type, field.type))
628
701
  return python_type
629
702
 
630
- async def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> T:
703
+ def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> T:
631
704
  if binary_idl_object.tag == MESSAGEPACK:
632
705
  if issubclass(expected_python_type, DataClassJSONMixin):
633
706
  dict_obj = msgpack.loads(binary_idl_object.value, strict_map_key=False)
@@ -652,9 +725,10 @@ class DataclassTransformer(TypeTransformer[object]):
652
725
  "user defined datatypes in Flytekit"
653
726
  )
654
727
 
655
- if lv.scalar and lv.scalar.binary:
656
- return await self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore
728
+ if lv.HasField("scalar") and lv.scalar.HasField("binary"):
729
+ return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore
657
730
 
731
+ # todo: revisit this, it should always be a binary in v2.
658
732
  json_str = _json_format.MessageToJson(lv.scalar.generic)
659
733
 
660
734
  # The `from_json` function is provided from mashumaro's `DataClassJSONMixin`.
@@ -970,11 +1044,10 @@ class TypeEngine(typing.Generic[T]):
970
1044
  return cls._REGISTRY[python_type.__origin__]
971
1045
 
972
1046
  # Handling UnionType specially - PEP 604
973
- if sys.version_info >= (3, 10):
974
- import types
1047
+ import types
975
1048
 
976
- if isinstance(python_type, types.UnionType):
977
- return cls._REGISTRY[types.UnionType]
1049
+ if isinstance(python_type, types.UnionType):
1050
+ return cls._REGISTRY[types.UnionType]
978
1051
 
979
1052
  if python_type in cls._REGISTRY:
980
1053
  return cls._REGISTRY[python_type]
@@ -2041,6 +2114,7 @@ def _register_default_type_transformers():
2041
2114
  TypeEngine.register(DictTransformer())
2042
2115
  TypeEngine.register(EnumTransformer())
2043
2116
  TypeEngine.register(ProtobufTransformer())
2117
+ TypeEngine.register(PydanticTransformer())
2044
2118
 
2045
2119
  # inner type is. Also unsupported are typing's Tuples. Even though you can look inside them, Flyte's type system
2046
2120
  # doesn't support these currently.