flyte 2.0.0b13__py3-none-any.whl → 2.0.0b30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211) hide show
  1. flyte/__init__.py +18 -2
  2. flyte/_bin/debug.py +38 -0
  3. flyte/_bin/runtime.py +62 -8
  4. flyte/_cache/cache.py +4 -2
  5. flyte/_cache/local_cache.py +216 -0
  6. flyte/_code_bundle/_ignore.py +12 -4
  7. flyte/_code_bundle/_packaging.py +13 -9
  8. flyte/_code_bundle/_utils.py +18 -10
  9. flyte/_code_bundle/bundle.py +17 -9
  10. flyte/_constants.py +1 -0
  11. flyte/_context.py +4 -1
  12. flyte/_custom_context.py +73 -0
  13. flyte/_debug/constants.py +38 -0
  14. flyte/_debug/utils.py +17 -0
  15. flyte/_debug/vscode.py +307 -0
  16. flyte/_deploy.py +235 -61
  17. flyte/_environment.py +20 -6
  18. flyte/_excepthook.py +1 -1
  19. flyte/_hash.py +1 -16
  20. flyte/_image.py +178 -81
  21. flyte/_initialize.py +132 -51
  22. flyte/_interface.py +39 -2
  23. flyte/_internal/controllers/__init__.py +4 -5
  24. flyte/_internal/controllers/_local_controller.py +70 -29
  25. flyte/_internal/controllers/_trace.py +1 -1
  26. flyte/_internal/controllers/remote/__init__.py +0 -2
  27. flyte/_internal/controllers/remote/_action.py +14 -16
  28. flyte/_internal/controllers/remote/_client.py +1 -1
  29. flyte/_internal/controllers/remote/_controller.py +68 -70
  30. flyte/_internal/controllers/remote/_core.py +127 -99
  31. flyte/_internal/controllers/remote/_informer.py +19 -10
  32. flyte/_internal/controllers/remote/_service_protocol.py +7 -7
  33. flyte/_internal/imagebuild/docker_builder.py +181 -69
  34. flyte/_internal/imagebuild/image_builder.py +0 -5
  35. flyte/_internal/imagebuild/remote_builder.py +155 -64
  36. flyte/_internal/imagebuild/utils.py +51 -2
  37. flyte/_internal/resolvers/_task_module.py +5 -38
  38. flyte/_internal/resolvers/default.py +2 -2
  39. flyte/_internal/runtime/convert.py +110 -21
  40. flyte/_internal/runtime/entrypoints.py +27 -1
  41. flyte/_internal/runtime/io.py +21 -8
  42. flyte/_internal/runtime/resources_serde.py +20 -6
  43. flyte/_internal/runtime/reuse.py +1 -1
  44. flyte/_internal/runtime/rusty.py +20 -5
  45. flyte/_internal/runtime/task_serde.py +34 -19
  46. flyte/_internal/runtime/taskrunner.py +22 -4
  47. flyte/_internal/runtime/trigger_serde.py +160 -0
  48. flyte/_internal/runtime/types_serde.py +1 -1
  49. flyte/_keyring/__init__.py +0 -0
  50. flyte/_keyring/file.py +115 -0
  51. flyte/_logging.py +201 -39
  52. flyte/_map.py +111 -14
  53. flyte/_module.py +70 -0
  54. flyte/_pod.py +4 -3
  55. flyte/_resources.py +213 -31
  56. flyte/_run.py +110 -39
  57. flyte/_task.py +75 -16
  58. flyte/_task_environment.py +105 -29
  59. flyte/_task_plugins.py +4 -2
  60. flyte/_trace.py +5 -0
  61. flyte/_trigger.py +1000 -0
  62. flyte/_utils/__init__.py +2 -1
  63. flyte/_utils/asyn.py +3 -1
  64. flyte/_utils/coro_management.py +2 -1
  65. flyte/_utils/docker_credentials.py +173 -0
  66. flyte/_utils/module_loader.py +17 -2
  67. flyte/_version.py +3 -3
  68. flyte/cli/_abort.py +3 -3
  69. flyte/cli/_build.py +3 -6
  70. flyte/cli/_common.py +78 -7
  71. flyte/cli/_create.py +182 -4
  72. flyte/cli/_delete.py +23 -1
  73. flyte/cli/_deploy.py +63 -16
  74. flyte/cli/_get.py +79 -34
  75. flyte/cli/_params.py +26 -10
  76. flyte/cli/_plugins.py +209 -0
  77. flyte/cli/_run.py +151 -26
  78. flyte/cli/_serve.py +64 -0
  79. flyte/cli/_update.py +37 -0
  80. flyte/cli/_user.py +17 -0
  81. flyte/cli/main.py +30 -4
  82. flyte/config/_config.py +10 -6
  83. flyte/config/_internal.py +1 -0
  84. flyte/config/_reader.py +29 -8
  85. flyte/connectors/__init__.py +11 -0
  86. flyte/connectors/_connector.py +270 -0
  87. flyte/connectors/_server.py +197 -0
  88. flyte/connectors/utils.py +135 -0
  89. flyte/errors.py +22 -2
  90. flyte/extend.py +8 -1
  91. flyte/extras/_container.py +6 -1
  92. flyte/git/__init__.py +3 -0
  93. flyte/git/_config.py +21 -0
  94. flyte/io/__init__.py +2 -0
  95. flyte/io/_dataframe/__init__.py +2 -0
  96. flyte/io/_dataframe/basic_dfs.py +17 -8
  97. flyte/io/_dataframe/dataframe.py +98 -132
  98. flyte/io/_dir.py +575 -113
  99. flyte/io/_file.py +582 -139
  100. flyte/io/_hashing_io.py +342 -0
  101. flyte/models.py +74 -15
  102. flyte/remote/__init__.py +6 -1
  103. flyte/remote/_action.py +34 -26
  104. flyte/remote/_client/_protocols.py +39 -4
  105. flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
  106. flyte/remote/_client/auth/_authenticators/pkce.py +1 -1
  107. flyte/remote/_client/auth/_channel.py +10 -6
  108. flyte/remote/_client/controlplane.py +17 -5
  109. flyte/remote/_console.py +3 -2
  110. flyte/remote/_data.py +6 -6
  111. flyte/remote/_logs.py +3 -3
  112. flyte/remote/_run.py +64 -8
  113. flyte/remote/_secret.py +26 -17
  114. flyte/remote/_task.py +75 -33
  115. flyte/remote/_trigger.py +306 -0
  116. flyte/remote/_user.py +33 -0
  117. flyte/report/_report.py +1 -1
  118. flyte/storage/__init__.py +6 -1
  119. flyte/storage/_config.py +5 -1
  120. flyte/storage/_parallel_reader.py +274 -0
  121. flyte/storage/_storage.py +200 -103
  122. flyte/types/__init__.py +16 -0
  123. flyte/types/_interface.py +2 -2
  124. flyte/types/_pickle.py +35 -8
  125. flyte/types/_string_literals.py +8 -9
  126. flyte/types/_type_engine.py +40 -70
  127. flyte/types/_utils.py +1 -1
  128. flyte-2.0.0b30.data/scripts/debug.py +38 -0
  129. {flyte-2.0.0b13.data → flyte-2.0.0b30.data}/scripts/runtime.py +62 -8
  130. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/METADATA +11 -3
  131. flyte-2.0.0b30.dist-info/RECORD +192 -0
  132. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/entry_points.txt +3 -0
  133. flyte/_protos/common/authorization_pb2.py +0 -66
  134. flyte/_protos/common/authorization_pb2.pyi +0 -108
  135. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  136. flyte/_protos/common/identifier_pb2.py +0 -93
  137. flyte/_protos/common/identifier_pb2.pyi +0 -110
  138. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  139. flyte/_protos/common/identity_pb2.py +0 -48
  140. flyte/_protos/common/identity_pb2.pyi +0 -72
  141. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  142. flyte/_protos/common/list_pb2.py +0 -36
  143. flyte/_protos/common/list_pb2.pyi +0 -71
  144. flyte/_protos/common/list_pb2_grpc.py +0 -4
  145. flyte/_protos/common/policy_pb2.py +0 -37
  146. flyte/_protos/common/policy_pb2.pyi +0 -27
  147. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  148. flyte/_protos/common/role_pb2.py +0 -37
  149. flyte/_protos/common/role_pb2.pyi +0 -53
  150. flyte/_protos/common/role_pb2_grpc.py +0 -4
  151. flyte/_protos/common/runtime_version_pb2.py +0 -28
  152. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  153. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  154. flyte/_protos/imagebuilder/definition_pb2.py +0 -59
  155. flyte/_protos/imagebuilder/definition_pb2.pyi +0 -140
  156. flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
  157. flyte/_protos/imagebuilder/payload_pb2.py +0 -32
  158. flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
  159. flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
  160. flyte/_protos/imagebuilder/service_pb2.py +0 -29
  161. flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
  162. flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
  163. flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
  164. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
  165. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  166. flyte/_protos/secret/definition_pb2.py +0 -49
  167. flyte/_protos/secret/definition_pb2.pyi +0 -93
  168. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  169. flyte/_protos/secret/payload_pb2.py +0 -62
  170. flyte/_protos/secret/payload_pb2.pyi +0 -94
  171. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  172. flyte/_protos/secret/secret_pb2.py +0 -38
  173. flyte/_protos/secret/secret_pb2.pyi +0 -6
  174. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  175. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  176. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  177. flyte/_protos/workflow/common_pb2.py +0 -27
  178. flyte/_protos/workflow/common_pb2.pyi +0 -14
  179. flyte/_protos/workflow/common_pb2_grpc.py +0 -4
  180. flyte/_protos/workflow/environment_pb2.py +0 -29
  181. flyte/_protos/workflow/environment_pb2.pyi +0 -12
  182. flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
  183. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  184. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  185. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  186. flyte/_protos/workflow/queue_service_pb2.py +0 -109
  187. flyte/_protos/workflow/queue_service_pb2.pyi +0 -166
  188. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  189. flyte/_protos/workflow/run_definition_pb2.py +0 -121
  190. flyte/_protos/workflow/run_definition_pb2.pyi +0 -327
  191. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  192. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  193. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  194. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  195. flyte/_protos/workflow/run_service_pb2.py +0 -137
  196. flyte/_protos/workflow/run_service_pb2.pyi +0 -185
  197. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -446
  198. flyte/_protos/workflow/state_service_pb2.py +0 -67
  199. flyte/_protos/workflow/state_service_pb2.pyi +0 -76
  200. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  201. flyte/_protos/workflow/task_definition_pb2.py +0 -79
  202. flyte/_protos/workflow/task_definition_pb2.pyi +0 -81
  203. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  204. flyte/_protos/workflow/task_service_pb2.py +0 -60
  205. flyte/_protos/workflow/task_service_pb2.pyi +0 -59
  206. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
  207. flyte-2.0.0b13.dist-info/RECORD +0 -239
  208. /flyte/{_protos → _debug}/__init__.py +0 -0
  209. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/WHEEL +0 -0
  210. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/licenses/LICENSE +0 -0
  211. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/top_level.txt +0 -0
flyte/remote/_task.py CHANGED
@@ -1,24 +1,23 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import asyncio
3
4
  import functools
4
5
  from dataclasses import dataclass
5
- from threading import Lock
6
6
  from typing import Any, AsyncIterator, Callable, Coroutine, Dict, Iterator, Literal, Optional, Tuple, Union, cast
7
7
 
8
8
  import rich.repr
9
- from flyteidl.core import literals_pb2
10
- from google.protobuf import timestamp
9
+ from flyteidl2.common import identifier_pb2, list_pb2
10
+ from flyteidl2.core import literals_pb2
11
+ from flyteidl2.task import task_definition_pb2, task_service_pb2
11
12
 
12
13
  import flyte
13
14
  import flyte.errors
14
15
  from flyte._cache.cache import CacheBehavior
15
16
  from flyte._context import internal_ctx
16
- from flyte._initialize import ensure_client, get_client, get_common_config
17
+ from flyte._initialize import ensure_client, get_client, get_init_config
17
18
  from flyte._internal.runtime.resources_serde import get_proto_resources
18
19
  from flyte._internal.runtime.task_serde import get_proto_retry_strategy, get_proto_timeout, get_security_context
19
20
  from flyte._logging import logger
20
- from flyte._protos.common import identifier_pb2, list_pb2
21
- from flyte._protos.workflow import task_definition_pb2, task_service_pb2
22
21
  from flyte.models import NativeInterface
23
22
  from flyte.syncify import syncify
24
23
 
@@ -35,7 +34,7 @@ def _repr_task_metadata(metadata: task_definition_pb2.TaskMetadata) -> rich.repr
35
34
  else:
36
35
  yield "deployed_by", f"App: {metadata.deployed_by.application.spec.name}"
37
36
  yield "short_name", metadata.short_name
38
- yield "deployed_at", timestamp.to_datetime(metadata.deployed_at)
37
+ yield "deployed_at", metadata.deployed_at.ToDatetime()
39
38
  yield "environment_name", metadata.environment_name
40
39
 
41
40
 
@@ -49,7 +48,7 @@ class LazyEntity:
49
48
  self._task: Optional[TaskDetails] = None
50
49
  self._getter = getter
51
50
  self._name = name
52
- self._mutex = Lock()
51
+ self._mutex = asyncio.Lock()
53
52
 
54
53
  @property
55
54
  def name(self) -> str:
@@ -60,11 +59,11 @@ class LazyEntity:
60
59
  """
61
60
  Forwards all other attributes to task, causing the task to be fetched!
62
61
  """
63
- with self._mutex:
62
+ async with self._mutex:
64
63
  if self._task is None:
65
64
  self._task = await self._getter()
66
- if self._task is None:
67
- raise RuntimeError(f"Error downloading the task {self._name}, (check original exception...)")
65
+ if self._task is None:
66
+ raise RuntimeError(f"Error downloading the task {self._name}, (check original exception...)")
68
67
  return self._task
69
68
 
70
69
  @syncify
@@ -73,8 +72,10 @@ class LazyEntity:
73
72
  **kwargs: Any,
74
73
  ) -> LazyEntity:
75
74
  task_details = cast(TaskDetails, await self.fetch.aio())
76
- task_details.override(**kwargs)
77
- return self
75
+ new_task_details = task_details.override(**kwargs)
76
+ new_entity = LazyEntity(self._name, self._getter)
77
+ new_entity._task = new_task_details
78
+ return new_entity
78
79
 
79
80
  async def __call__(self, *args, **kwargs):
80
81
  """
@@ -93,10 +94,11 @@ class LazyEntity:
93
94
  AutoVersioning = Literal["latest", "current"]
94
95
 
95
96
 
96
- @dataclass
97
+ @dataclass(frozen=True)
97
98
  class TaskDetails(ToJSONMixin):
98
99
  pb2: task_definition_pb2.TaskDetails
99
100
  max_inline_io_bytes: int = 10 * 1024 * 1024 # 10 MB
101
+ overriden_queue: Optional[str] = None
100
102
 
101
103
  @classmethod
102
104
  def get(
@@ -148,7 +150,7 @@ class TaskDetails(ToJSONMixin):
148
150
  if ctx is None:
149
151
  raise ValueError("auto_version=current can only be used within a task context.")
150
152
  _version = ctx.version
151
- cfg = get_common_config()
153
+ cfg = get_init_config()
152
154
  task_id = task_definition_pb2.TaskIdentifier(
153
155
  org=cfg.org,
154
156
  project=project or cfg.project,
@@ -261,12 +263,6 @@ class TaskDetails(ToJSONMixin):
261
263
  f"Reference task {self.name} does not support positional arguments"
262
264
  f"currently. Please use keyword arguments."
263
265
  )
264
- if len(self.required_args) > 0:
265
- if len(args) + len(kwargs) < len(self.required_args):
266
- raise ValueError(
267
- f"Task {self.name} requires at least {self.required_args} arguments, "
268
- f"but only received args:{args} kwargs{kwargs}."
269
- )
270
266
 
271
267
  ctx = internal_ctx()
272
268
  if ctx.is_task_context():
@@ -276,19 +272,37 @@ class TaskDetails(ToJSONMixin):
276
272
  from flyte._internal.controllers import get_controller
277
273
 
278
274
  controller = get_controller()
275
+ if len(self.required_args) > 0:
276
+ if len(args) + len(kwargs) < len(self.required_args):
277
+ raise ValueError(
278
+ f"Task {self.name} requires at least {self.required_args} arguments, "
279
+ f"but only received args:{args} kwargs{kwargs}."
280
+ )
279
281
  if controller:
280
- return await controller.submit_task_ref(self.pb2, self.max_inline_io_bytes, *args, **kwargs)
281
- raise flyte.errors
282
+ return await controller.submit_task_ref(self, *args, **kwargs)
283
+ raise flyte.errors.ReferenceTaskError(
284
+ f"Reference tasks [{self.name}] cannot be executed locally, only remotely."
285
+ )
286
+
287
+ @property
288
+ def queue(self) -> Optional[str]:
289
+ """
290
+ The queue to use for the task.
291
+ """
292
+ return self.overriden_queue
282
293
 
283
294
  def override(
284
295
  self,
285
296
  *,
286
- friendly_name: Optional[str] = None,
297
+ short_name: Optional[str] = None,
287
298
  resources: Optional[flyte.Resources] = None,
288
299
  retries: Union[int, flyte.RetryStrategy] = 0,
289
300
  timeout: Optional[flyte.TimeoutType] = None,
290
301
  env_vars: Optional[Dict[str, str]] = None,
291
302
  secrets: Optional[flyte.SecretRequest] = None,
303
+ max_inline_io_bytes: Optional[int] = None,
304
+ cache: Optional[flyte.Cache] = None,
305
+ queue: Optional[str] = None,
292
306
  **kwargs: Any,
293
307
  ) -> TaskDetails:
294
308
  if len(kwargs) > 0:
@@ -296,29 +310,57 @@ class TaskDetails(ToJSONMixin):
296
310
  f"ReferenceTasks [{self.name}] do not support overriding with kwargs: {kwargs}, "
297
311
  f"Check the parameters for override method."
298
312
  )
299
- template = self.pb2.spec.task_template
300
- if friendly_name:
301
- self.pb2.metadata.short_name = friendly_name
313
+ pb2 = task_definition_pb2.TaskDetails()
314
+ pb2.CopyFrom(self.pb2)
315
+
316
+ if short_name:
317
+ pb2.metadata.short_name = short_name
318
+
319
+ template = pb2.spec.task_template
302
320
  if secrets:
303
321
  template.security_context.CopyFrom(get_security_context(secrets))
322
+
304
323
  if template.HasField("container"):
305
324
  if env_vars:
306
325
  template.container.env.clear()
307
326
  template.container.env.extend([literals_pb2.KeyValuePair(key=k, value=v) for k, v in env_vars.items()])
308
327
  if resources:
309
328
  template.container.resources.CopyFrom(get_proto_resources(resources))
329
+
330
+ md = template.metadata
310
331
  if retries:
311
- template.metadata.retries.CopyFrom(get_proto_retry_strategy(retries))
312
- if timeout:
313
- template.metadata.timeout.CopyFrom(get_proto_timeout(timeout))
332
+ md.retries.CopyFrom(get_proto_retry_strategy(retries))
314
333
 
315
- return self
334
+ if timeout:
335
+ md.timeout.CopyFrom(get_proto_timeout(timeout))
336
+
337
+ if cache:
338
+ if cache.behavior == "disable":
339
+ md.discoverable = False
340
+ md.discovery_version = ""
341
+ elif cache.behavior == "override":
342
+ md.discoverable = True
343
+ if not cache.version_override:
344
+ raise ValueError("cache.version_override must be set when cache.behavior is 'override'")
345
+ md.discovery_version = cache.version_override
346
+ else:
347
+ if cache.behavior == "auto":
348
+ raise ValueError("cache.behavior must be 'disable' or 'override' for reference tasks")
349
+ raise ValueError(f"Invalid cache behavior: {cache.behavior}.")
350
+ md.cache_serializable = cache.serialize
351
+ md.cache_ignore_input_vars[:] = list(cache.ignored_inputs or ())
352
+
353
+ return TaskDetails(
354
+ pb2,
355
+ max_inline_io_bytes=max_inline_io_bytes or self.max_inline_io_bytes,
356
+ overriden_queue=queue,
357
+ )
316
358
 
317
359
  def __rich_repr__(self) -> rich.repr.Result:
318
360
  """
319
361
  Rich representation of the task.
320
362
  """
321
- yield "friendly_name", self.pb2.spec.short_name
363
+ yield "short_name", self.pb2.spec.short_name
322
364
  yield "environment", self.pb2.spec.environment
323
365
  yield "default_inputs_keys", self.default_input_args
324
366
  yield "required_args", self.required_args
@@ -408,7 +450,7 @@ class Task(ToJSONMixin):
408
450
  sort_pb2 = list_pb2.Sort(
409
451
  key=sort_by[0], direction=list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING
410
452
  )
411
- cfg = get_common_config()
453
+ cfg = get_init_config()
412
454
  filters = []
413
455
  if by_task_name:
414
456
  filters.append(
flyte/remote/_trigger.py ADDED
@@ -0,0 +1,306 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from functools import cached_property
5
+ from typing import AsyncIterator
6
+
7
+ import grpc.aio
8
+ from flyteidl2.common import identifier_pb2, list_pb2
9
+ from flyteidl2.task import common_pb2, task_definition_pb2
10
+ from flyteidl2.trigger import trigger_definition_pb2, trigger_service_pb2
11
+
12
+ import flyte
13
+ from flyte._initialize import ensure_client, get_client, get_init_config
14
+ from flyte._internal.runtime import trigger_serde
15
+ from flyte.syncify import syncify
16
+
17
+ from ._common import ToJSONMixin
18
+ from ._task import Task, TaskDetails
19
+
20
+
21
+ @dataclass
22
+ class TriggerDetails(ToJSONMixin):
23
+ pb2: trigger_definition_pb2.TriggerDetails
24
+
25
+ @syncify
26
+ @classmethod
27
+ async def get(cls, *, name: str, task_name: str) -> TriggerDetails:
28
+ """
29
+ Retrieve detailed information about a specific trigger by its name.
30
+ """
31
+ ensure_client()
32
+ cfg = get_init_config()
33
+ resp = await get_client().trigger_service.GetTriggerDetails(
34
+ request=trigger_service_pb2.GetTriggerDetailsRequest(
35
+ name=identifier_pb2.TriggerName(
36
+ task_name=task_name,
37
+ name=name,
38
+ org=cfg.org,
39
+ project=cfg.project,
40
+ domain=cfg.domain,
41
+ ),
42
+ )
43
+ )
44
+ return cls(pb2=resp.trigger)
45
+
46
+ @property
47
+ def name(self) -> str:
48
+ return self.id.name.name
49
+
50
+ @property
51
+ def id(self) -> identifier_pb2.TriggerIdentifier:
52
+ return self.pb2.id
53
+
54
+ @property
55
+ def task_name(self) -> str:
56
+ return self.pb2.id.name.task_name
57
+
58
+ @property
59
+ def automation_spec(self) -> common_pb2.TriggerAutomationSpec:
60
+ return self.pb2.automation_spec
61
+
62
+ @property
63
+ def metadata(self) -> trigger_definition_pb2.TriggerMetadata:
64
+ return self.pb2.metadata
65
+
66
+ @property
67
+ def status(self) -> trigger_definition_pb2.TriggerStatus:
68
+ return self.pb2.status
69
+
70
+ @property
71
+ def is_active(self) -> bool:
72
+ return self.pb2.spec.active
73
+
74
+ @cached_property
75
+ def trigger(self) -> trigger_definition_pb2.Trigger:
76
+ return trigger_definition_pb2.Trigger(
77
+ id=self.pb2.id,
78
+ automation_spec=self.automation_spec,
79
+ metadata=self.metadata,
80
+ status=self.status,
81
+ active=self.is_active,
82
+ )
83
+
84
+
85
+ @dataclass
86
+ class Trigger(ToJSONMixin):
87
+ pb2: trigger_definition_pb2.Trigger
88
+ details: TriggerDetails | None = None
89
+
90
+ @syncify
91
+ @classmethod
92
+ async def create(
93
+ cls,
94
+ trigger: flyte.Trigger,
95
+ task_name: str,
96
+ task_version: str | None = None,
97
+ ) -> Trigger:
98
+ """
99
+ Create a new trigger in the Flyte platform.
100
+
101
+ :param trigger: The flyte.Trigger object containing the trigger definition.
102
+ :param task_name: Name of the task to associate with the trigger (required; the latest task version is used when task_version is not given).
103
+ """
104
+ ensure_client()
105
+ cfg = get_init_config()
106
+
107
+ # Fetch the task to ensure it exists and to get its input definitions
108
+ try:
109
+ lazy = (
110
+ Task.get(name=task_name, version=task_version)
111
+ if task_version
112
+ else Task.get(name=task_name, auto_version="latest")
113
+ )
114
+ task: TaskDetails = await lazy.fetch.aio()
115
+
116
+ task_trigger = await trigger_serde.to_task_trigger(
117
+ t=trigger,
118
+ task_name=task_name,
119
+ task_inputs=task.pb2.spec.task_template.interface.inputs,
120
+ task_default_inputs=list(task.pb2.spec.default_inputs),
121
+ )
122
+
123
+ resp = await get_client().trigger_service.DeployTrigger(
124
+ request=trigger_service_pb2.DeployTriggerRequest(
125
+ name=identifier_pb2.TriggerName(
126
+ name=trigger.name,
127
+ task_name=task_name,
128
+ org=cfg.org,
129
+ project=cfg.project,
130
+ domain=cfg.domain,
131
+ ),
132
+ spec=trigger_definition_pb2.TriggerSpec(
133
+ active=task_trigger.spec.active,
134
+ inputs=task_trigger.spec.inputs,
135
+ run_spec=task_trigger.spec.run_spec,
136
+ task_version=task.version,
137
+ ),
138
+ automation_spec=task_trigger.automation_spec,
139
+ )
140
+ )
141
+
142
+ details = TriggerDetails(pb2=resp.trigger)
143
+
144
+ return cls(pb2=details.trigger, details=details)
145
+ except grpc.aio.AioRpcError as e:
146
+ if e.code() == grpc.StatusCode.NOT_FOUND:
147
+ raise ValueError(f"Task {task_name}:{task_version or 'latest'} not found") from e
148
+ raise
149
+
150
+ @syncify
151
+ @classmethod
152
+ async def get(cls, *, name: str, task_name: str) -> TriggerDetails:
153
+ """
154
+ Retrieve the details of a trigger by its name and associated task name.
155
+ """
156
+ return await TriggerDetails.get.aio(name=name, task_name=task_name)
157
+
158
+ @syncify
159
+ @classmethod
160
+ async def listall(
161
+ cls, task_name: str | None = None, task_version: str | None = None, limit: int = 100
162
+ ) -> AsyncIterator[Trigger]:
163
+ """
164
+ List all triggers associated with a specific task or all tasks if no task name is provided.
165
+ """
166
+ ensure_client()
167
+ cfg = get_init_config()
168
+ token = None
169
+ task_name_id = None
170
+ project_id = None
171
+ task_id = None
172
+ if task_name and task_version:
173
+ task_id = task_definition_pb2.TaskIdentifier(
174
+ name=task_name,
175
+ project=cfg.project,
176
+ domain=cfg.domain,
177
+ org=cfg.org,
178
+ version=task_version,
179
+ )
180
+ elif task_name:
181
+ task_name_id = task_definition_pb2.TaskName(
182
+ name=task_name,
183
+ project=cfg.project,
184
+ domain=cfg.domain,
185
+ org=cfg.org,
186
+ )
187
+ else:
188
+ project_id = identifier_pb2.ProjectIdentifier(
189
+ organization=cfg.org,
190
+ domain=cfg.domain,
191
+ name=cfg.project,
192
+ )
193
+
194
+ while True:
195
+ resp = await get_client().trigger_service.ListTriggers(
196
+ request=trigger_service_pb2.ListTriggersRequest(
197
+ project_id=project_id,
198
+ task_id=task_id,
199
+ task_name=task_name_id,
200
+ request=list_pb2.ListRequest(
201
+ limit=limit,
202
+ token=token,
203
+ ),
204
+ )
205
+ )
206
+ token = resp.token
207
+ for r in resp.triggers:
208
+ yield cls(r)
209
+ if not token:
210
+ break
211
+
212
+ @syncify
213
+ @classmethod
214
+ async def update(cls, name: str, task_name: str, active: bool):
215
+ """
216
+ Activate or pause a trigger, identified by its name and associated task name, by setting its active flag.
217
+ """
218
+ ensure_client()
219
+ cfg = get_init_config()
220
+ await get_client().trigger_service.UpdateTriggers(
221
+ request=trigger_service_pb2.UpdateTriggersRequest(
222
+ names=[
223
+ identifier_pb2.TriggerName(
224
+ org=cfg.org,
225
+ project=cfg.project,
226
+ domain=cfg.domain,
227
+ name=name,
228
+ task_name=task_name,
229
+ )
230
+ ],
231
+ active=active,
232
+ )
233
+ )
234
+
235
+ @syncify
236
+ @classmethod
237
+ async def delete(cls, name: str, task_name: str):
238
+ """
239
+ Delete a trigger by its name and associated task name.
240
+ """
241
+ ensure_client()
242
+ cfg = get_init_config()
243
+ await get_client().trigger_service.DeleteTriggers(
244
+ request=trigger_service_pb2.DeleteTriggersRequest(
245
+ names=[
246
+ identifier_pb2.TriggerName(
247
+ org=cfg.org,
248
+ project=cfg.project,
249
+ domain=cfg.domain,
250
+ name=name,
251
+ task_name=task_name,
252
+ )
253
+ ],
254
+ )
255
+ )
256
+
257
+ @property
258
+ def id(self) -> identifier_pb2.TriggerIdentifier:
259
+ return self.pb2.id
260
+
261
+ @property
262
+ def name(self) -> str:
263
+ return self.id.name.name
264
+
265
+ @property
266
+ def task_name(self) -> str:
267
+ return self.id.name.task_name
268
+
269
+ @property
270
+ def automation_spec(self) -> common_pb2.TriggerAutomationSpec:
271
+ return self.pb2.automation_spec
272
+
273
+ async def get_details(self) -> TriggerDetails:
274
+ """
275
+ Get detailed information about this trigger.
276
+ """
277
+ if not self.details:
278
+ details = await TriggerDetails.get.aio(name=self.pb2.id.name.name)
279
+ self.details = details
280
+ return self.details
281
+
282
+ @property
283
+ def is_active(self) -> bool:
284
+ return self.pb2.active
285
+
286
+ def _rich_automation(self, automation: common_pb2.TriggerAutomationSpec):
287
+ if automation.type == common_pb2.TriggerAutomationSpec.type.TYPE_NONE:
288
+ yield "none", None
289
+ elif automation.type == common_pb2.TriggerAutomationSpec.type.TYPE_SCHEDULE:
290
+ if automation.schedule.cron is not None:
291
+ yield "cron", automation.schedule.cron
292
+ elif automation.schedule.rate is not None:
293
+ r = automation.schedule.rate
294
+ yield (
295
+ "fixed_rate",
296
+ (
297
+ f"Every [{r.value}] {r.unit} starting at "
298
+ f"{r.start_time.ToDatetime() if automation.HasField('start_time') else 'now'}"
299
+ ),
300
+ )
301
+
302
+ def __rich_repr__(self):
303
+ yield "task_name", self.task_name
304
+ yield "name", self.name
305
+ yield from self._rich_automation(self.pb2.automation_spec)
306
+ yield "auto_activate", self.is_active
flyte/remote/_user.py ADDED
@@ -0,0 +1,33 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ from flyteidl.service import identity_pb2
6
+ from flyteidl.service.identity_pb2 import UserInfoResponse
7
+
8
+ from .._initialize import ensure_client, get_client
9
+ from ..syncify import syncify
10
+ from ._common import ToJSONMixin
11
+
12
+
13
+ @dataclass
14
+ class User(ToJSONMixin):
15
+ pb2: UserInfoResponse
16
+
17
+ @syncify
18
+ @classmethod
19
+ async def get(cls) -> User:
20
+ """
21
+ Fetches information about the currently logged in user.
22
+ Returns: A User object containing details about the user.
23
+ """
24
+ ensure_client()
25
+
26
+ resp = await get_client().identity_service.UserInfo(identity_pb2.UserInfoRequest())
27
+ return cls(resp)
28
+
29
+ def subject(self) -> str:
30
+ return self.pb2.subject
31
+
32
+ def name(self) -> str:
33
+ return self.pb2.name
flyte/report/_report.py CHANGED
@@ -4,7 +4,6 @@ import string
4
4
  from dataclasses import dataclass, field
5
5
  from typing import TYPE_CHECKING, Dict, List, Union
6
6
 
7
- from flyte._internal.runtime import io
8
7
  from flyte._logging import logger
9
8
  from flyte._tools import ipython_check
10
9
  from flyte.syncify import syncify
@@ -133,6 +132,7 @@ async def flush():
133
132
  """
134
133
  import flyte.storage as storage
135
134
  from flyte._context import internal_ctx
135
+ from flyte._internal.runtime import io
136
136
 
137
137
  if not internal_ctx().is_task_context():
138
138
  return
flyte/storage/__init__.py CHANGED
@@ -3,6 +3,8 @@ __all__ = [
3
3
  "GCS",
4
4
  "S3",
5
5
  "Storage",
6
+ "exists",
7
+ "exists_sync",
6
8
  "get",
7
9
  "get_configured_fsspec_kwargs",
8
10
  "get_random_local_directory",
@@ -11,13 +13,15 @@ __all__ = [
11
13
  "get_underlying_filesystem",
12
14
  "is_remote",
13
15
  "join",
16
+ "open",
14
17
  "put",
15
18
  "put_stream",
16
- "put_stream",
17
19
  ]
18
20
 
19
21
  from ._config import ABFS, GCS, S3, Storage
20
22
  from ._storage import (
23
+ exists,
24
+ exists_sync,
21
25
  get,
22
26
  get_configured_fsspec_kwargs,
23
27
  get_random_local_directory,
@@ -26,6 +30,7 @@ from ._storage import (
26
30
  get_underlying_filesystem,
27
31
  is_remote,
28
32
  join,
33
+ open,
29
34
  put,
30
35
  put_stream,
31
36
  )
flyte/storage/_config.py CHANGED
@@ -61,6 +61,7 @@ class S3(Storage):
61
61
  endpoint: typing.Optional[str] = None
62
62
  access_key_id: typing.Optional[str] = None
63
63
  secret_access_key: typing.Optional[str] = None
64
+ region: typing.Optional[str] = None
64
65
 
65
66
  _KEY_ENV_VAR_MAPPING: ClassVar[typing.Dict[str, str]] = {
66
67
  "endpoint": "FLYTE_AWS_ENDPOINT",
@@ -76,7 +77,7 @@ class S3(Storage):
76
77
  _KEY_SKIP_SIGNATURE: ClassVar = "skip_signature"
77
78
 
78
79
  @classmethod
79
- def auto(cls) -> S3:
80
+ def auto(cls, region: str | None = None) -> S3:
80
81
  """
81
82
  :return: Config
82
83
  """
@@ -88,6 +89,7 @@ class S3(Storage):
88
89
  kwargs = set_if_exists(kwargs, "endpoint", endpoint)
89
90
  kwargs = set_if_exists(kwargs, "access_key_id", access_key_id)
90
91
  kwargs = set_if_exists(kwargs, "secret_access_key", secret_access_key)
92
+ kwargs = set_if_exists(kwargs, "region", region)
91
93
 
92
94
  return S3(**kwargs)
93
95
 
@@ -141,6 +143,8 @@ class S3(Storage):
141
143
  kwargs["config"] = config
142
144
  kwargs["client_options"] = client_options or None
143
145
  kwargs["retry_config"] = retry_config or None
146
+ if self.region:
147
+ kwargs["region"] = self.region
144
148
 
145
149
  return kwargs
146
150