flyte 0.2.0b1__py3-none-any.whl → 2.0.0b46__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (266)
  1. flyte/__init__.py +83 -30
  2. flyte/_bin/connect.py +61 -0
  3. flyte/_bin/debug.py +38 -0
  4. flyte/_bin/runtime.py +87 -19
  5. flyte/_bin/serve.py +351 -0
  6. flyte/_build.py +3 -2
  7. flyte/_cache/cache.py +6 -5
  8. flyte/_cache/local_cache.py +216 -0
  9. flyte/_code_bundle/_ignore.py +31 -5
  10. flyte/_code_bundle/_packaging.py +42 -11
  11. flyte/_code_bundle/_utils.py +57 -34
  12. flyte/_code_bundle/bundle.py +130 -27
  13. flyte/_constants.py +1 -0
  14. flyte/_context.py +21 -5
  15. flyte/_custom_context.py +73 -0
  16. flyte/_debug/constants.py +37 -0
  17. flyte/_debug/utils.py +17 -0
  18. flyte/_debug/vscode.py +315 -0
  19. flyte/_deploy.py +396 -75
  20. flyte/_deployer.py +109 -0
  21. flyte/_environment.py +94 -11
  22. flyte/_excepthook.py +37 -0
  23. flyte/_group.py +2 -1
  24. flyte/_hash.py +1 -16
  25. flyte/_image.py +544 -231
  26. flyte/_initialize.py +456 -316
  27. flyte/_interface.py +40 -5
  28. flyte/_internal/controllers/__init__.py +22 -8
  29. flyte/_internal/controllers/_local_controller.py +159 -35
  30. flyte/_internal/controllers/_trace.py +18 -10
  31. flyte/_internal/controllers/remote/__init__.py +38 -9
  32. flyte/_internal/controllers/remote/_action.py +82 -12
  33. flyte/_internal/controllers/remote/_client.py +6 -2
  34. flyte/_internal/controllers/remote/_controller.py +290 -64
  35. flyte/_internal/controllers/remote/_core.py +155 -95
  36. flyte/_internal/controllers/remote/_informer.py +40 -20
  37. flyte/_internal/controllers/remote/_service_protocol.py +2 -2
  38. flyte/_internal/imagebuild/__init__.py +2 -10
  39. flyte/_internal/imagebuild/docker_builder.py +391 -84
  40. flyte/_internal/imagebuild/image_builder.py +111 -55
  41. flyte/_internal/imagebuild/remote_builder.py +409 -0
  42. flyte/_internal/imagebuild/utils.py +79 -0
  43. flyte/_internal/resolvers/_app_env_module.py +92 -0
  44. flyte/_internal/resolvers/_task_module.py +5 -38
  45. flyte/_internal/resolvers/app_env.py +26 -0
  46. flyte/_internal/resolvers/common.py +8 -1
  47. flyte/_internal/resolvers/default.py +2 -2
  48. flyte/_internal/runtime/convert.py +319 -36
  49. flyte/_internal/runtime/entrypoints.py +106 -18
  50. flyte/_internal/runtime/io.py +71 -23
  51. flyte/_internal/runtime/resources_serde.py +21 -7
  52. flyte/_internal/runtime/reuse.py +125 -0
  53. flyte/_internal/runtime/rusty.py +196 -0
  54. flyte/_internal/runtime/task_serde.py +239 -66
  55. flyte/_internal/runtime/taskrunner.py +48 -8
  56. flyte/_internal/runtime/trigger_serde.py +162 -0
  57. flyte/_internal/runtime/types_serde.py +7 -16
  58. flyte/_keyring/file.py +115 -0
  59. flyte/_link.py +30 -0
  60. flyte/_logging.py +241 -42
  61. flyte/_map.py +312 -0
  62. flyte/_metrics.py +59 -0
  63. flyte/_module.py +74 -0
  64. flyte/_pod.py +30 -0
  65. flyte/_resources.py +296 -33
  66. flyte/_retry.py +1 -7
  67. flyte/_reusable_environment.py +72 -7
  68. flyte/_run.py +462 -132
  69. flyte/_secret.py +47 -11
  70. flyte/_serve.py +333 -0
  71. flyte/_task.py +245 -56
  72. flyte/_task_environment.py +219 -97
  73. flyte/_task_plugins.py +47 -0
  74. flyte/_tools.py +8 -8
  75. flyte/_trace.py +15 -24
  76. flyte/_trigger.py +1027 -0
  77. flyte/_utils/__init__.py +12 -1
  78. flyte/_utils/asyn.py +3 -1
  79. flyte/_utils/async_cache.py +139 -0
  80. flyte/_utils/coro_management.py +5 -4
  81. flyte/_utils/description_parser.py +19 -0
  82. flyte/_utils/docker_credentials.py +173 -0
  83. flyte/_utils/helpers.py +45 -19
  84. flyte/_utils/module_loader.py +123 -0
  85. flyte/_utils/org_discovery.py +57 -0
  86. flyte/_utils/uv_script_parser.py +8 -1
  87. flyte/_version.py +16 -3
  88. flyte/app/__init__.py +27 -0
  89. flyte/app/_app_environment.py +362 -0
  90. flyte/app/_connector_environment.py +40 -0
  91. flyte/app/_deploy.py +130 -0
  92. flyte/app/_parameter.py +343 -0
  93. flyte/app/_runtime/__init__.py +3 -0
  94. flyte/app/_runtime/app_serde.py +383 -0
  95. flyte/app/_types.py +113 -0
  96. flyte/app/extras/__init__.py +9 -0
  97. flyte/app/extras/_auth_middleware.py +217 -0
  98. flyte/app/extras/_fastapi.py +93 -0
  99. flyte/app/extras/_model_loader/__init__.py +3 -0
  100. flyte/app/extras/_model_loader/config.py +7 -0
  101. flyte/app/extras/_model_loader/loader.py +288 -0
  102. flyte/cli/__init__.py +12 -0
  103. flyte/cli/_abort.py +28 -0
  104. flyte/cli/_build.py +114 -0
  105. flyte/cli/_common.py +493 -0
  106. flyte/cli/_create.py +371 -0
  107. flyte/cli/_delete.py +45 -0
  108. flyte/cli/_deploy.py +401 -0
  109. flyte/cli/_gen.py +316 -0
  110. flyte/cli/_get.py +446 -0
  111. flyte/cli/_option.py +33 -0
  112. flyte/{_cli → cli}/_params.py +57 -17
  113. flyte/cli/_plugins.py +209 -0
  114. flyte/cli/_prefetch.py +292 -0
  115. flyte/cli/_run.py +690 -0
  116. flyte/cli/_serve.py +338 -0
  117. flyte/cli/_update.py +86 -0
  118. flyte/cli/_user.py +20 -0
  119. flyte/cli/main.py +246 -0
  120. flyte/config/__init__.py +2 -167
  121. flyte/config/_config.py +215 -163
  122. flyte/config/_internal.py +10 -1
  123. flyte/config/_reader.py +225 -0
  124. flyte/connectors/__init__.py +11 -0
  125. flyte/connectors/_connector.py +330 -0
  126. flyte/connectors/_server.py +194 -0
  127. flyte/connectors/utils.py +159 -0
  128. flyte/errors.py +134 -2
  129. flyte/extend.py +24 -0
  130. flyte/extras/_container.py +69 -56
  131. flyte/git/__init__.py +3 -0
  132. flyte/git/_config.py +279 -0
  133. flyte/io/__init__.py +8 -1
  134. flyte/io/{structured_dataset → _dataframe}/__init__.py +32 -30
  135. flyte/io/{structured_dataset → _dataframe}/basic_dfs.py +75 -68
  136. flyte/io/{structured_dataset/structured_dataset.py → _dataframe/dataframe.py} +207 -242
  137. flyte/io/_dir.py +575 -113
  138. flyte/io/_file.py +587 -141
  139. flyte/io/_hashing_io.py +342 -0
  140. flyte/io/extend.py +7 -0
  141. flyte/models.py +635 -0
  142. flyte/prefetch/__init__.py +22 -0
  143. flyte/prefetch/_hf_model.py +563 -0
  144. flyte/remote/__init__.py +14 -3
  145. flyte/remote/_action.py +879 -0
  146. flyte/remote/_app.py +346 -0
  147. flyte/remote/_auth_metadata.py +42 -0
  148. flyte/remote/_client/_protocols.py +62 -4
  149. flyte/remote/_client/auth/_auth_utils.py +19 -0
  150. flyte/remote/_client/auth/_authenticators/base.py +8 -2
  151. flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
  152. flyte/remote/_client/auth/_authenticators/factory.py +4 -0
  153. flyte/remote/_client/auth/_authenticators/passthrough.py +79 -0
  154. flyte/remote/_client/auth/_authenticators/pkce.py +17 -18
  155. flyte/remote/_client/auth/_channel.py +47 -18
  156. flyte/remote/_client/auth/_client_config.py +5 -3
  157. flyte/remote/_client/auth/_keyring.py +15 -2
  158. flyte/remote/_client/auth/_token_client.py +3 -3
  159. flyte/remote/_client/controlplane.py +206 -18
  160. flyte/remote/_common.py +66 -0
  161. flyte/remote/_data.py +107 -22
  162. flyte/remote/_logs.py +116 -33
  163. flyte/remote/_project.py +21 -19
  164. flyte/remote/_run.py +164 -631
  165. flyte/remote/_secret.py +72 -29
  166. flyte/remote/_task.py +387 -46
  167. flyte/remote/_trigger.py +368 -0
  168. flyte/remote/_user.py +43 -0
  169. flyte/report/_report.py +10 -6
  170. flyte/storage/__init__.py +13 -1
  171. flyte/storage/_config.py +237 -0
  172. flyte/storage/_parallel_reader.py +289 -0
  173. flyte/storage/_storage.py +268 -59
  174. flyte/syncify/__init__.py +56 -0
  175. flyte/syncify/_api.py +414 -0
  176. flyte/types/__init__.py +39 -0
  177. flyte/types/_interface.py +22 -7
  178. flyte/{io/pickle/transformer.py → types/_pickle.py} +37 -9
  179. flyte/types/_string_literals.py +8 -9
  180. flyte/types/_type_engine.py +226 -126
  181. flyte/types/_utils.py +1 -1
  182. flyte-2.0.0b46.data/scripts/debug.py +38 -0
  183. flyte-2.0.0b46.data/scripts/runtime.py +194 -0
  184. flyte-2.0.0b46.dist-info/METADATA +352 -0
  185. flyte-2.0.0b46.dist-info/RECORD +221 -0
  186. flyte-2.0.0b46.dist-info/entry_points.txt +8 -0
  187. flyte-2.0.0b46.dist-info/licenses/LICENSE +201 -0
  188. flyte/_api_commons.py +0 -3
  189. flyte/_cli/_common.py +0 -299
  190. flyte/_cli/_create.py +0 -42
  191. flyte/_cli/_delete.py +0 -23
  192. flyte/_cli/_deploy.py +0 -140
  193. flyte/_cli/_get.py +0 -235
  194. flyte/_cli/_run.py +0 -174
  195. flyte/_cli/main.py +0 -98
  196. flyte/_datastructures.py +0 -342
  197. flyte/_internal/controllers/pbhash.py +0 -39
  198. flyte/_protos/common/authorization_pb2.py +0 -66
  199. flyte/_protos/common/authorization_pb2.pyi +0 -108
  200. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  201. flyte/_protos/common/identifier_pb2.py +0 -71
  202. flyte/_protos/common/identifier_pb2.pyi +0 -82
  203. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  204. flyte/_protos/common/identity_pb2.py +0 -48
  205. flyte/_protos/common/identity_pb2.pyi +0 -72
  206. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  207. flyte/_protos/common/list_pb2.py +0 -36
  208. flyte/_protos/common/list_pb2.pyi +0 -69
  209. flyte/_protos/common/list_pb2_grpc.py +0 -4
  210. flyte/_protos/common/policy_pb2.py +0 -37
  211. flyte/_protos/common/policy_pb2.pyi +0 -27
  212. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  213. flyte/_protos/common/role_pb2.py +0 -37
  214. flyte/_protos/common/role_pb2.pyi +0 -53
  215. flyte/_protos/common/role_pb2_grpc.py +0 -4
  216. flyte/_protos/common/runtime_version_pb2.py +0 -28
  217. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  218. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  219. flyte/_protos/logs/dataplane/payload_pb2.py +0 -96
  220. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -168
  221. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  222. flyte/_protos/secret/definition_pb2.py +0 -49
  223. flyte/_protos/secret/definition_pb2.pyi +0 -93
  224. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  225. flyte/_protos/secret/payload_pb2.py +0 -62
  226. flyte/_protos/secret/payload_pb2.pyi +0 -94
  227. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  228. flyte/_protos/secret/secret_pb2.py +0 -38
  229. flyte/_protos/secret/secret_pb2.pyi +0 -6
  230. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  231. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  232. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  233. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  234. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  235. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  236. flyte/_protos/workflow/queue_service_pb2.py +0 -106
  237. flyte/_protos/workflow/queue_service_pb2.pyi +0 -141
  238. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  239. flyte/_protos/workflow/run_definition_pb2.py +0 -128
  240. flyte/_protos/workflow/run_definition_pb2.pyi +0 -310
  241. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  242. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  243. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  244. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  245. flyte/_protos/workflow/run_service_pb2.py +0 -133
  246. flyte/_protos/workflow/run_service_pb2.pyi +0 -175
  247. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -412
  248. flyte/_protos/workflow/state_service_pb2.py +0 -58
  249. flyte/_protos/workflow/state_service_pb2.pyi +0 -71
  250. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  251. flyte/_protos/workflow/task_definition_pb2.py +0 -72
  252. flyte/_protos/workflow/task_definition_pb2.pyi +0 -65
  253. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  254. flyte/_protos/workflow/task_service_pb2.py +0 -44
  255. flyte/_protos/workflow/task_service_pb2.pyi +0 -31
  256. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -104
  257. flyte/io/_dataframe.py +0 -0
  258. flyte/io/pickle/__init__.py +0 -0
  259. flyte/remote/_console.py +0 -18
  260. flyte-0.2.0b1.dist-info/METADATA +0 -179
  261. flyte-0.2.0b1.dist-info/RECORD +0 -204
  262. flyte-0.2.0b1.dist-info/entry_points.txt +0 -3
  263. /flyte/{_cli → _debug}/__init__.py +0 -0
  264. /flyte/{_protos → _keyring}/__init__.py +0 -0
  265. {flyte-0.2.0b1.dist-info → flyte-2.0.0b46.dist-info}/WHEEL +0 -0
  266. {flyte-0.2.0b1.dist-info → flyte-2.0.0b46.dist-info}/top_level.txt +0 -0
@@ -16,13 +16,13 @@ import typing
16
16
  from abc import ABC, abstractmethod
17
17
  from collections import OrderedDict
18
18
  from functools import lru_cache
19
- from types import GenericAlias
19
+ from types import GenericAlias, NoneType
20
20
  from typing import Any, Dict, NamedTuple, Optional, Type, cast
21
21
 
22
22
  import msgpack
23
- from flyteidl.core import interface_pb2, literals_pb2, types_pb2
24
- from flyteidl.core.literals_pb2 import Binary, Literal, LiteralCollection, LiteralMap, Primitive, Scalar, Union, Void
25
- from flyteidl.core.types_pb2 import LiteralType, SimpleType, TypeAnnotation, TypeStructure, UnionType
23
+ from flyteidl2.core import interface_pb2, literals_pb2, types_pb2
24
+ from flyteidl2.core.literals_pb2 import Binary, Literal, LiteralCollection, LiteralMap, Primitive, Scalar, Union, Void
25
+ from flyteidl2.core.types_pb2 import LiteralType, SimpleType, TypeAnnotation, TypeStructure, UnionType
26
26
  from fsspec.asyn import _run_coros_in_chunks # pylint: disable=W0212
27
27
  from google.protobuf import json_format as _json_format
28
28
  from google.protobuf import struct_pb2
@@ -35,13 +35,14 @@ from mashumaro.jsonschema.models import Context, JSONSchema
35
35
  from mashumaro.jsonschema.plugins import BasePlugin
36
36
  from mashumaro.jsonschema.schema import Instance
37
37
  from mashumaro.mixins.json import DataClassJSONMixin
38
+ from pydantic import BaseModel
38
39
  from typing_extensions import Annotated, get_args, get_origin
39
40
 
40
41
  import flyte.storage as storage
41
- from flyte._datastructures import NativeInterface
42
- from flyte._hash import HashMethod
43
42
  from flyte._logging import logger
44
43
  from flyte._utils.helpers import load_proto_from_file
44
+ from flyte.errors import RestrictedTypeError
45
+ from flyte.models import NativeInterface
45
46
 
46
47
  from ._utils import literal_types_match
47
48
 
@@ -306,6 +307,9 @@ class SimpleTransformer(TypeTransformer[T]):
306
307
  expected_python_type = get_underlying_type(expected_python_type)
307
308
 
308
309
  if expected_python_type is not self._type:
310
+ if expected_python_type is None and issubclass(self._type, NoneType):
311
+ # If the expected type is NoneType, we can return None
312
+ return None
309
313
  raise TypeTransformerFailedError(
310
314
  f"Cannot convert to type {expected_python_type}, only {self._type} is supported"
311
315
  )
@@ -328,10 +332,6 @@ class SimpleTransformer(TypeTransformer[T]):
328
332
  raise ValueError(f"Transformer {self} cannot reverse {literal_type}")
329
333
 
330
334
 
331
- class RestrictedTypeError(Exception):
332
- pass
333
-
334
-
335
335
  class RestrictedTypeTransformer(TypeTransformer[T], ABC):
336
336
  """
337
337
  Types registered with the RestrictedTypeTransformer are not allowed to be converted to and from literals.
@@ -352,6 +352,66 @@ class RestrictedTypeTransformer(TypeTransformer[T], ABC):
352
352
  raise RestrictedTypeError(f"Transformer for type {self.python_type} is restricted currently")
353
353
 
354
354
 
355
+ class PydanticTransformer(TypeTransformer[BaseModel]):
356
+ def __init__(self):
357
+ super().__init__("Pydantic Transformer", BaseModel, enable_type_assertions=False)
358
+
359
+ def get_literal_type(self, t: Type[BaseModel]) -> LiteralType:
360
+ schema = t.model_json_schema()
361
+
362
+ meta_struct = struct_pb2.Struct()
363
+ meta_struct.update(
364
+ {
365
+ CACHE_KEY_METADATA: {
366
+ SERIALIZATION_FORMAT: MESSAGEPACK,
367
+ }
368
+ }
369
+ )
370
+
371
+ # The type engine used to publish a type structure for attribute access. As of v2, this is no longer needed.
372
+ return LiteralType(
373
+ simple=SimpleType.STRUCT,
374
+ metadata=schema,
375
+ annotation=TypeAnnotation(annotations=meta_struct),
376
+ )
377
+
378
+ async def to_literal(
379
+ self,
380
+ python_val: BaseModel,
381
+ python_type: Type[BaseModel],
382
+ expected: LiteralType,
383
+ ) -> Literal:
384
+ json_str = python_val.model_dump_json()
385
+ dict_obj = json.loads(json_str)
386
+ msgpack_bytes = msgpack.dumps(dict_obj)
387
+ return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag=MESSAGEPACK)))
388
+
389
+ def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[BaseModel]) -> BaseModel:
390
+ if binary_idl_object.tag == MESSAGEPACK:
391
+ dict_obj = msgpack.loads(binary_idl_object.value, strict_map_key=False)
392
+ json_str = json.dumps(dict_obj)
393
+ python_val = expected_python_type.model_validate_json(
394
+ json_data=json_str, strict=False, context={"deserialize": True}
395
+ )
396
+ return python_val
397
+ else:
398
+ raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")
399
+
400
+ async def to_python_value(self, lv: Literal, expected_python_type: Type[BaseModel]) -> BaseModel:
401
+ """
402
+ There are two kinds of literal values to handle:
403
+ 1. Protobuf Structs (from the UI)
404
+ 2. Binary scalars (from other sources)
405
+ We need to account for both cases accordingly.
406
+ """
407
+ if lv and lv.HasField("scalar") and lv.scalar.HasField("binary"):
408
+ return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore
409
+
410
+ json_str = _json_format.MessageToJson(lv.scalar.generic)
411
+ python_val = expected_python_type.model_validate_json(json_str, strict=False, context={"deserialize": True})
412
+ return python_val
413
+
414
+
355
415
  class PydanticSchemaPlugin(BasePlugin):
356
416
  """This allows us to generate proper schemas for Pydantic models."""
357
417
 
@@ -392,35 +452,6 @@ class DataclassTransformer(TypeTransformer[object]):
392
452
  2. Deserialization: The dataclass transformer converts the MessagePack Bytes back to a dataclass.
393
453
  (1) Convert MessagePack Bytes to a dataclass using mashumaro.
394
454
  (2) Handle dataclass attributes to ensure they are of the correct types.
395
-
396
- TODO: Update the example using mashumaro instead of the older library
397
-
398
- Example
399
-
400
- .. code-block:: python
401
-
402
- @dataclass
403
- class Test:
404
- a: int
405
- b: str
406
-
407
- t = Test(a=10,b="e")
408
- JSONSchema().dump(t.schema())
409
-
410
- Output will look like
411
-
412
- .. code-block:: json
413
-
414
- {'$schema': 'http://json-schema.org/draft-07/schema#',
415
- 'definitions': {'TestSchema': {'properties': {'a': {'title': 'a',
416
- 'type': 'number',
417
- 'format': 'integer'},
418
- 'b': {'title': 'b', 'type': 'string'}},
419
- 'type': 'object',
420
- 'additionalProperties': False}},
421
- '$ref': '#/definitions/TestSchema'}
422
-
423
-
424
455
  """
425
456
 
426
457
  def __init__(self) -> None:
@@ -543,28 +574,7 @@ class DataclassTransformer(TypeTransformer[object]):
543
574
  f"Possibly remove `DataClassJsonMixin` and `dataclass_json` decorator from dataclass declaration"
544
575
  )
545
576
 
546
- # Recursively construct the dataclass_type which contains the literal type of each field
547
- literal_type = {}
548
-
549
- hints = typing.get_type_hints(t)
550
- # Get the type of each field from dataclass
551
- for field in t.__dataclass_fields__.values(): # type: ignore
552
- try:
553
- name = field.name
554
- python_type = hints.get(name, field.type)
555
- literal_type[name] = TypeEngine.to_literal_type(python_type)
556
- except Exception as e:
557
- logger.warning(
558
- "Field {} of type {} cannot be converted to a literal type. Error: {}".format(
559
- field.name, field.type, e
560
- )
561
- )
562
-
563
- # This is for attribute access in FlytePropeller.
564
- ts = TypeStructure(tag="", dataclass_type=literal_type)
565
- from google.protobuf.struct_pb2 import Struct
566
-
567
- meta_struct = Struct()
577
+ meta_struct = struct_pb2.Struct()
568
578
  meta_struct.update(
569
579
  {
570
580
  CACHE_KEY_METADATA: {
@@ -572,10 +582,11 @@ class DataclassTransformer(TypeTransformer[object]):
572
582
  }
573
583
  }
574
584
  )
585
+
586
+ # The type engine used to publish the type `structure` for attribute access. As of v2, this is no longer needed.
575
587
  return types_pb2.LiteralType(
576
588
  simple=types_pb2.SimpleType.STRUCT,
577
589
  metadata=schema,
578
- structure=ts,
579
590
  annotation=TypeAnnotation(annotations=meta_struct),
580
591
  )
581
592
 
@@ -627,7 +638,7 @@ class DataclassTransformer(TypeTransformer[object]):
627
638
  field.type = self._get_origin_type_in_annotation(cast(type, field.type))
628
639
  return python_type
629
640
 
630
- async def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> T:
641
+ def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> T:
631
642
  if binary_idl_object.tag == MESSAGEPACK:
632
643
  if issubclass(expected_python_type, DataClassJSONMixin):
633
644
  dict_obj = msgpack.loads(binary_idl_object.value, strict_map_key=False)
@@ -652,9 +663,10 @@ class DataclassTransformer(TypeTransformer[object]):
652
663
  "user defined datatypes in Flytekit"
653
664
  )
654
665
 
655
- if lv.scalar and lv.scalar.binary:
656
- return await self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore
666
+ if lv.HasField("scalar") and lv.scalar.HasField("binary"):
667
+ return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore
657
668
 
669
+ # todo: revisit this, it should always be a binary in v2.
658
670
  json_str = _json_format.MessageToJson(lv.scalar.generic)
659
671
 
660
672
  # The `from_json` function is provided from mashumaro's `DataClassJSONMixin`.
@@ -771,6 +783,13 @@ class EnumTransformer(TypeTransformer[enum.Enum]):
771
783
  return LiteralType(enum_type=types_pb2.EnumType(values=values))
772
784
 
773
785
  async def to_literal(self, python_val: enum.Enum, python_type: Type[T], expected: LiteralType) -> Literal:
786
+ if isinstance(python_val, str):
787
+ # this is the case when python Literals are used as enums
788
+ if python_val not in expected.enum_type.values:
789
+ raise TypeTransformerFailedError(
790
+ f"Value {python_val} is not valid value, expected - {expected.enum_type.values}"
791
+ )
792
+ return Literal(scalar=Scalar(primitive=Primitive(string_value=python_val))) # type: ignore
774
793
  if type(python_val).__class__ != enum.EnumMeta:
775
794
  raise TypeTransformerFailedError("Expected an enum")
776
795
  if type(python_val.value) is not str:
@@ -781,6 +800,12 @@ class EnumTransformer(TypeTransformer[enum.Enum]):
781
800
  async def to_python_value(self, lv: Literal, expected_python_type: Type[T]) -> T:
782
801
  if lv.HasField("scalar") and lv.scalar.HasField("binary"):
783
802
  return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore
803
+ from flyte._interface import LITERAL_ENUM
804
+
805
+ if expected_python_type.__name__ is LITERAL_ENUM:
806
+ # This is the case when python Literal types are used as enums. The class name is always LiteralEnum an
807
+ # hardcoded in flyte.models
808
+ return lv.scalar.primitive.string_value
784
809
  return expected_python_type(lv.scalar.primitive.string_value) # type: ignore
785
810
 
786
811
  def guess_python_type(self, literal_type: LiteralType) -> Type[enum.Enum]:
@@ -799,7 +824,37 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
799
824
  from flyte.io._file import File
800
825
 
801
826
  attribute_list: typing.List[typing.Tuple[Any, Any]] = []
802
- for property_key, property_val in schema["properties"].items():
827
+ nested_types: typing.Dict[str, type] = {} # Track nested model types for conversion
828
+
829
+ # Use 'required' field to preserve property order, as protobuf Struct doesn't preserve dict order
830
+ properties = schema["properties"]
831
+ property_order = schema.get("required", list(properties.keys()))
832
+
833
+ for property_key in property_order:
834
+ property_val = properties[property_key]
835
+ # Handle $ref for nested Pydantic models
836
+ if property_val.get("$ref"):
837
+ ref_path = property_val["$ref"]
838
+ # Extract the definition name from the $ref path (e.g., "#/$defs/MyNestedModel" -> "MyNestedModel")
839
+ ref_name = ref_path.split("/")[-1]
840
+ # Get the referenced schema from $defs (or definitions for older schemas)
841
+ defs = schema.get("$defs", schema.get("definitions", {}))
842
+ if ref_name in defs:
843
+ ref_schema = defs[ref_name].copy()
844
+ # Include $defs so nested models can resolve their own $refs
845
+ if "$defs" not in ref_schema and defs:
846
+ ref_schema["$defs"] = defs
847
+ nested_class: type = convert_mashumaro_json_schema_to_python_class(ref_schema, ref_name)
848
+ attribute_list.append(
849
+ (
850
+ property_key,
851
+ typing.cast(GenericAlias, nested_class),
852
+ )
853
+ )
854
+ # Track this as a nested type that needs dict-to-object conversion
855
+ nested_types[property_key] = nested_class
856
+ continue
857
+
803
858
  if property_val.get("anyOf"):
804
859
  property_type = property_val["anyOf"][0]["type"]
805
860
  elif property_val.get("enum"):
@@ -837,14 +892,14 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
837
892
  )
838
893
  )
839
894
  continue
895
+ nested_class = convert_mashumaro_json_schema_to_python_class(sub_schemea, sub_schemea_name)
840
896
  attribute_list.append(
841
897
  (
842
898
  property_key,
843
- typing.cast(
844
- GenericAlias, convert_mashumaro_json_schema_to_python_class(sub_schemea, sub_schemea_name)
845
- ),
899
+ typing.cast(GenericAlias, nested_class),
846
900
  )
847
901
  )
902
+ nested_types[property_key] = nested_class
848
903
  elif property_val.get("additionalProperties"):
849
904
  # For typing.Dict type
850
905
  elem_type = _get_element_type(property_val["additionalProperties"])
@@ -875,14 +930,14 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
875
930
  )
876
931
  )
877
932
  continue
933
+ nested_class = convert_mashumaro_json_schema_to_python_class(property_val, sub_schemea_name)
878
934
  attribute_list.append(
879
935
  (
880
936
  property_key,
881
- typing.cast(
882
- GenericAlias, convert_mashumaro_json_schema_to_python_class(property_val, sub_schemea_name)
883
- ),
937
+ typing.cast(GenericAlias, nested_class),
884
938
  )
885
939
  )
940
+ nested_types[property_key] = nested_class
886
941
  else:
887
942
  # For untyped dict
888
943
  attribute_list.append((property_key, dict)) # type: ignore
@@ -891,7 +946,7 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
891
946
  # Handle int, float, bool or str
892
947
  else:
893
948
  attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore
894
- return attribute_list
949
+ return attribute_list, nested_types
895
950
 
896
951
 
897
952
  class TypeEngine(typing.Generic[T]):
@@ -970,11 +1025,10 @@ class TypeEngine(typing.Generic[T]):
970
1025
  return cls._REGISTRY[python_type.__origin__]
971
1026
 
972
1027
  # Handling UnionType specially - PEP 604
973
- if sys.version_info >= (3, 10):
974
- import types
1028
+ import types
975
1029
 
976
- if isinstance(python_type, types.UnionType):
977
- return cls._REGISTRY[types.UnionType]
1030
+ if isinstance(python_type, types.UnionType):
1031
+ return cls._REGISTRY[types.UnionType]
978
1032
 
979
1033
  if python_type in cls._REGISTRY:
980
1034
  return cls._REGISTRY[python_type]
@@ -1008,7 +1062,7 @@ class TypeEngine(typing.Generic[T]):
1008
1062
  return cls._DATACLASS_TRANSFORMER
1009
1063
 
1010
1064
  display_pickle_warning(str(python_type))
1011
- from flyte.io.pickle.transformer import FlytePickleTransformer
1065
+ from flyte.types._pickle import FlytePickleTransformer
1012
1066
 
1013
1067
  return FlytePickleTransformer()
1014
1068
 
@@ -1021,10 +1075,10 @@ class TypeEngine(typing.Generic[T]):
1021
1075
  # Avoid a race condition where concurrent threads may exit lazy_import_transformers before the transformers
1022
1076
  # have been imported. This could be implemented without a lock if you assume python assignments are atomic
1023
1077
  # and re-registering transformers is acceptable, but I decided to play it safe.
1024
- from flyte.io.structured_dataset import lazy_import_structured_dataset_handler
1078
+ from flyte.io._dataframe import lazy_import_dataframe_handler
1025
1079
 
1026
1080
  # todo: bring in extras transformers (pytorch, etc.)
1027
- lazy_import_structured_dataset_handler()
1081
+ lazy_import_dataframe_handler()
1028
1082
 
1029
1083
  @classmethod
1030
1084
  def to_literal_type(cls, python_type: Type[T]) -> LiteralType:
@@ -1052,23 +1106,6 @@ class TypeEngine(typing.Generic[T]):
1052
1106
  ):
1053
1107
  raise TypeTransformerFailedError(f"Python value cannot be None, expected {python_type}/{expected}")
1054
1108
 
1055
- @classmethod
1056
- def calculate_hash(cls, python_val: typing.Any, python_type: Type[T]) -> Optional[str]:
1057
- # In case the value is an annotated type we inspect the annotations and look for hash-related annotations.
1058
- hsh = None
1059
- if is_annotated(python_type):
1060
- # We are now dealing with one of two cases:
1061
- # 1. The annotated type is a `HashMethod`, which indicates that we should produce the hash using
1062
- # the method indicated in the annotation.
1063
- # 2. The annotated type is being used for a different purpose other than calculating hash values,
1064
- # in which case we should just continue.
1065
- for annotation in get_args(python_type)[1:]:
1066
- if not isinstance(annotation, HashMethod):
1067
- continue
1068
- hsh = annotation.calculate(python_val)
1069
- break
1070
- return hsh
1071
-
1072
1109
  @classmethod
1073
1110
  async def to_literal(
1074
1111
  cls, python_val: typing.Any, python_type: Type[T], expected: types_pb2.LiteralType
@@ -1081,8 +1118,6 @@ class TypeEngine(typing.Generic[T]):
1081
1118
  lv = await transformer.to_literal(python_val, python_type, expected)
1082
1119
 
1083
1120
  modify_literal_uris(lv)
1084
- calculated_hash = cls.calculate_hash(python_val, python_type) or ""
1085
- lv.hash = calculated_hash
1086
1121
  return lv
1087
1122
 
1088
1123
  @classmethod
@@ -1160,20 +1195,26 @@ class TypeEngine(typing.Generic[T]):
1160
1195
  f"Received more input values {len(lm.literals)}"
1161
1196
  f" than allowed by the input spec {len(python_interface_inputs)}"
1162
1197
  )
1198
+ # Create tasks for converting each kwarg
1199
+ tasks = {}
1200
+ for k in lm.literals:
1201
+ tasks[k] = asyncio.create_task(TypeEngine.to_python_value(lm.literals[k], python_interface_inputs[k]))
1202
+
1203
+ # Gather all tasks, returning exceptions instead of raising them
1204
+ results = await asyncio.gather(*tasks.values(), return_exceptions=True)
1205
+
1206
+ # Check for exceptions and raise with specific kwarg name
1163
1207
  kwargs = {}
1164
- try:
1165
- for i, k in enumerate(lm.literals):
1166
- kwargs[k] = asyncio.create_task(TypeEngine.to_python_value(lm.literals[k], python_interface_inputs[k]))
1167
- await asyncio.gather(*kwargs.values())
1168
- except Exception as e:
1169
- raise TypeTransformerFailedError(
1170
- f"Error converting input:\n"
1171
- f"Literal value: {lm.literals[k]}\n"
1172
- f"Expected Python type: {python_interface_inputs[k]}\n"
1173
- f"Exception: {e}"
1174
- )
1208
+ for (key, task), result in zip(tasks.items(), results):
1209
+ if isinstance(result, Exception):
1210
+ raise TypeTransformerFailedError(
1211
+ f"Error converting input '{key}':\n"
1212
+ f"Literal value: {lm.literals[key]}\n"
1213
+ f"Expected Python type: {python_interface_inputs[key]}\n"
1214
+ f"Exception: {result}"
1215
+ ) from result
1216
+ kwargs[key] = result
1175
1217
 
1176
- kwargs = {k: v.result() for k, v in kwargs.items() if v is not None}
1177
1218
  return kwargs
1178
1219
 
1179
1220
  @classmethod
@@ -1207,7 +1248,12 @@ class TypeEngine(typing.Generic[T]):
1207
1248
  python_type = type_hints.get(k, type(d[k]))
1208
1249
  e: BaseException = literal_map[k].exception() # type: ignore
1209
1250
  if isinstance(e, TypeError):
1210
- raise TypeError(f"Error converting: {type(v)}, {python_type}, received_value {v}")
1251
+ raise TypeError(
1252
+ f"Type conversion failed for variable '{k}'.\n"
1253
+ f"Expected type: {python_type}\n"
1254
+ f"Actual type: {type(d[k])}\n"
1255
+ f"Value received: {d[k]!r}"
1256
+ ) from e
1211
1257
  else:
1212
1258
  raise e
1213
1259
  literal_map[k] = v.result()
@@ -1242,7 +1288,8 @@ class TypeEngine(typing.Generic[T]):
1242
1288
  try:
1243
1289
  return transformer.guess_python_type(flyte_type)
1244
1290
  except ValueError:
1245
- logger.debug(f"Skipping transformer {transformer.name} for {flyte_type}")
1291
+ # Skipping transformer
1292
+ continue
1246
1293
 
1247
1294
  # Because the dataclass transformer is handled explicitly in the get_transformer code, we have to handle it
1248
1295
  # separately here too.
@@ -1365,6 +1412,19 @@ def _type_essence(x: types_pb2.LiteralType) -> types_pb2.LiteralType:
1365
1412
 
1366
1413
 
1367
1414
  def _are_types_castable(upstream: types_pb2.LiteralType, downstream: types_pb2.LiteralType) -> bool:
1415
+ if upstream.union_type is not None:
1416
+ # for each upstream variant, there must be a compatible type downstream
1417
+ for v in upstream.union_type.variants:
1418
+ if not _are_types_castable(v, downstream):
1419
+ return False
1420
+ return True
1421
+
1422
+ if downstream.union_type is not None:
1423
+ # there must be a compatible downstream type
1424
+ for v in downstream.union_type.variants:
1425
+ if _are_types_castable(upstream, v):
1426
+ return True
1427
+
1368
1428
  if upstream.HasField("collection_type"):
1369
1429
  if not downstream.HasField("collection_type"):
1370
1430
  return False
@@ -1410,7 +1470,7 @@ def _are_types_castable(upstream: types_pb2.LiteralType, downstream: types_pb2.L
1410
1470
 
1411
1471
  return True
1412
1472
 
1413
- if upstream.HasField("union_type"):
1473
+ if upstream.HasField("union_type") and upstream.union_type is not None:
1414
1474
  # for each upstream variant, there must be a compatible type downstream
1415
1475
  for v in upstream.union_type.variants:
1416
1476
  if not _are_types_castable(v, downstream):
@@ -1504,6 +1564,18 @@ class UnionTransformer(TypeTransformer[T]):
1504
1564
  self, python_val: T, python_type: Type[T], expected: types_pb2.LiteralType
1505
1565
  ) -> literals_pb2.Literal:
1506
1566
  python_type = get_underlying_type(python_type)
1567
+ inferred_type = type(python_val)
1568
+ subtypes = get_args(python_type)
1569
+
1570
+ if inferred_type in subtypes:
1571
+ # If the Python value's type matches one of the types in the Union,
1572
+ # always use the transformer associated with that specific type.
1573
+ transformer = TypeEngine.get_transformer(inferred_type)
1574
+ res = await transformer.to_literal(
1575
+ python_val, inferred_type, expected.union_type.variants[subtypes.index(inferred_type)]
1576
+ )
1577
+ res_type = _add_tag_to_type(transformer.get_literal_type(inferred_type), transformer.name)
1578
+ return Literal(scalar=Scalar(union=Union(value=res, type=res_type)))
1507
1579
 
1508
1580
  potential_types = []
1509
1581
  found_res = False
@@ -1511,14 +1583,14 @@ class UnionTransformer(TypeTransformer[T]):
1511
1583
  res = None
1512
1584
  res_type = None
1513
1585
  t = None
1514
- for i in range(len(get_args(python_type))):
1586
+ for i in range(len(subtypes)):
1515
1587
  try:
1516
- t = get_args(python_type)[i]
1588
+ t = subtypes[i]
1517
1589
  trans: TypeTransformer[T] = TypeEngine.get_transformer(t)
1518
1590
  attempt = trans.to_literal(python_val, t, expected.union_type.variants[i])
1519
1591
  res = await attempt
1520
1592
  if found_res:
1521
- logger.debug(f"Current type {get_args(python_type)[i]} old res {res_type}")
1593
+ logger.debug(f"Current type {subtypes[i]} old res {res_type}")
1522
1594
  is_ambiguous = True
1523
1595
  res_type = _add_tag_to_type(trans.get_literal_type(t), trans.name)
1524
1596
  found_res = True
@@ -1657,7 +1729,7 @@ class DictTransformer(TypeTransformer[dict]):
1657
1729
  Converts a Python dictionary to a Flyte-specific ``Literal`` using MessagePack encoding.
1658
1730
  Falls back to Pickle if encoding fails and `allow_pickle` is True.
1659
1731
  """
1660
- from flyte.io.pickle.transformer import FlytePickle
1732
+ from flyte.types._pickle import FlytePickle
1661
1733
 
1662
1734
  try:
1663
1735
  # Handle dictionaries with non-string keys (e.g., Dict[int, Type])
@@ -1723,7 +1795,6 @@ class DictTransformer(TypeTransformer[dict]):
1723
1795
  for k, v in python_val.items():
1724
1796
  if type(k) is not str:
1725
1797
  raise ValueError("Flyte MapType expects all keys to be strings")
1726
- # TODO: log a warning for Annotated objects that contain HashMethod
1727
1798
 
1728
1799
  _, v_type = self.extract_types(python_type)
1729
1800
  lit_map[k] = TypeEngine.to_literal(v, cast(type, v_type), expected.map_value_type)
@@ -1763,7 +1834,7 @@ class DictTransformer(TypeTransformer[dict]):
1763
1834
  # pr: han-ru is this part still necessary?
1764
1835
  if lv and lv.HasField("scalar") and lv.scalar.HasField("generic"):
1765
1836
  if lv.metadata and lv.metadata.get("format", None) == "pickle":
1766
- from flyte.io.pickle.transformer import FlytePickle
1837
+ from flyte.types._pickle import FlytePickle
1767
1838
 
1768
1839
  uri = json.loads(_json_format.MessageToJson(lv.scalar.generic)).get("pickle_file")
1769
1840
  return await FlytePickle.from_pickle(uri)
@@ -1833,8 +1904,26 @@ def convert_mashumaro_json_schema_to_python_class(schema: dict, schema_name: typ
1833
1904
  :param schema_name: dataclass name of return type
1834
1905
  """
1835
1906
 
1836
- attribute_list = generate_attribute_list_from_dataclass_json_mixin(schema, schema_name)
1837
- return dataclasses.make_dataclass(schema_name, attribute_list)
1907
+ attribute_list, nested_types = generate_attribute_list_from_dataclass_json_mixin(schema, schema_name)
1908
+ cls = dataclasses.make_dataclass(schema_name, attribute_list)
1909
+
1910
+ # Wrap __init__ to convert dict inputs to nested types
1911
+ if nested_types:
1912
+ # Store the original __init__ from the class's __dict__ to avoid mypy error
1913
+ original_init = cls.__dict__["__init__"]
1914
+
1915
+ def __init__(self, *args, **kwargs): # type: ignore[misc]
1916
+ # Convert dict values to nested types before calling original __init__
1917
+ for field_name, field_type in nested_types.items():
1918
+ if field_name in kwargs:
1919
+ value = kwargs[field_name]
1920
+ if isinstance(value, dict):
1921
+ kwargs[field_name] = field_type(**value)
1922
+ original_init(self, *args, **kwargs)
1923
+
1924
+ cls.__init__ = __init__ # type: ignore[method-assign, misc]
1925
+
1926
+ return cls
1838
1927
 
1839
1928
 
1840
1929
  def _get_element_type(element_property: typing.Dict[str, str]) -> Type:
@@ -1870,7 +1959,6 @@ def _get_element_type(element_property: typing.Dict[str, str]) -> Type:
1870
1959
  return str
1871
1960
 
1872
1961
 
1873
- # pr: han-ru is this still needed?
1874
1962
  def dataclass_from_dict(cls: type, src: typing.Dict[str, typing.Any]) -> typing.Any:
1875
1963
  """
1876
1964
  Utility function to construct a dataclass object from dict
@@ -1950,7 +2038,7 @@ def _handle_flyte_console_float_input_to_int(lv: Literal) -> int:
1950
2038
 
1951
2039
  def _check_and_convert_void(lv: Literal) -> None:
1952
2040
  if not lv.scalar.HasField("none_type"):
1953
- raise TypeTransformerFailedError(f"Cannot convert literal {lv} to None")
2041
+ raise TypeTransformerFailedError(f"Cannot convert literal '{lv}' to None")
1954
2042
  return None
1955
2043
 
1956
2044
 
@@ -2011,7 +2099,9 @@ DateTransformer = SimpleTransformer(
2011
2099
  lambda x: Literal(
2012
2100
  scalar=Scalar(primitive=Primitive(datetime=datetime.datetime.combine(x, datetime.time.min)))
2013
2101
  ), # convert datetime to date
2014
- lambda x: x.scalar.primitive.datetime.date() if x.scalar.primitive.HasField("datetime") else None,
2102
+ lambda x: x.scalar.primitive.datetime.ToDatetime().replace(tzinfo=datetime.timezone.utc).date()
2103
+ if x.scalar.primitive.HasField("datetime")
2104
+ else None,
2015
2105
  )
2016
2106
 
2017
2107
  NoneTransformer = SimpleTransformer(
@@ -2035,10 +2125,20 @@ def _register_default_type_transformers():
2035
2125
  TypeEngine.register(BoolTransformer)
2036
2126
  TypeEngine.register(NoneTransformer, [None])
2037
2127
  TypeEngine.register(ListTransformer())
2038
- TypeEngine.register(UnionTransformer(), [UnionType])
2128
+
2129
+ if sys.version_info < (3, 14):
2130
+ TypeEngine.register(UnionTransformer(), [UnionType])
2131
+ else:
2132
+ # In Python 3.14+, types.UnionType and typing.Union are the same object.
2133
+ # UnionTransformer's python_type is already typing.Union, so only add UnionType
2134
+ # as an additional type if it's different from typing.Union.
2135
+ union_transformer = UnionTransformer()
2136
+ additional_union_types = [] if UnionType is union_transformer.python_type else [UnionType]
2137
+ TypeEngine.register(union_transformer, additional_union_types)
2039
2138
  TypeEngine.register(DictTransformer())
2040
2139
  TypeEngine.register(EnumTransformer())
2041
2140
  TypeEngine.register(ProtobufTransformer())
2141
+ TypeEngine.register(PydanticTransformer())
2042
2142
 
2043
2143
  # inner type is. Also unsupported are typing's Tuples. Even though you can look inside them, Flyte's type system
2044
2144
  # doesn't support these currently.
flyte/types/_utils.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import importlib
4
4
  import typing
5
5
 
6
- from flyteidl.core.types_pb2 import EnumType, LiteralType, UnionType
6
+ from flyteidl2.core.types_pb2 import EnumType, LiteralType, UnionType
7
7
 
8
8
  T = typing.TypeVar("T")
9
9
 
@@ -0,0 +1,38 @@
1
+ import click
2
+
3
+
4
+ @click.group()
5
+ def _debug():
6
+ """Debug commands for Flyte."""
7
+
8
+
9
+ @_debug.command("resume")
10
+ @click.option("--pid", "-m", type=int, required=True, help="PID of the vscode server.")
11
+ def resume(pid):
12
+ """
13
+ Resume a Flyte task for debugging purposes.
14
+
15
+ Args:
16
+ pid (int): PID of the vscode server.
17
+ """
18
+ import os
19
+ import signal
20
+
21
+ print("Terminating server and resuming task.")
22
+ answer = (
23
+ input(
24
+ "This operation will kill the server. All unsaved data will be lost,"
25
+ " and you will no longer be able to connect to it. Do you really want to terminate? (Y/N): "
26
+ )
27
+ .strip()
28
+ .upper()
29
+ )
30
+ if answer == "Y":
31
+ os.kill(pid, signal.SIGTERM)
32
+ print("The server has been terminated and the task has been resumed.")
33
+ else:
34
+ print("Operation canceled.")
35
+
36
+
37
+ if __name__ == "__main__":
38
+ _debug()