flyte 2.0.0b22__py3-none-any.whl → 2.0.0b23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (88)
  1. flyte/__init__.py +5 -0
  2. flyte/_bin/runtime.py +35 -5
  3. flyte/_cache/cache.py +4 -2
  4. flyte/_cache/local_cache.py +215 -0
  5. flyte/_code_bundle/bundle.py +1 -0
  6. flyte/_debug/constants.py +0 -1
  7. flyte/_debug/vscode.py +6 -1
  8. flyte/_deploy.py +193 -52
  9. flyte/_environment.py +5 -0
  10. flyte/_excepthook.py +1 -1
  11. flyte/_image.py +101 -72
  12. flyte/_initialize.py +23 -0
  13. flyte/_internal/controllers/_local_controller.py +64 -24
  14. flyte/_internal/controllers/remote/_action.py +4 -1
  15. flyte/_internal/controllers/remote/_controller.py +5 -2
  16. flyte/_internal/controllers/remote/_core.py +6 -3
  17. flyte/_internal/controllers/remote/_informer.py +1 -1
  18. flyte/_internal/imagebuild/docker_builder.py +92 -28
  19. flyte/_internal/imagebuild/image_builder.py +7 -13
  20. flyte/_internal/imagebuild/remote_builder.py +6 -1
  21. flyte/_internal/runtime/io.py +13 -1
  22. flyte/_internal/runtime/rusty.py +17 -2
  23. flyte/_internal/runtime/task_serde.py +14 -20
  24. flyte/_internal/runtime/taskrunner.py +1 -1
  25. flyte/_internal/runtime/trigger_serde.py +153 -0
  26. flyte/_logging.py +1 -1
  27. flyte/_protos/common/identifier_pb2.py +19 -1
  28. flyte/_protos/common/identifier_pb2.pyi +22 -0
  29. flyte/_protos/workflow/common_pb2.py +14 -3
  30. flyte/_protos/workflow/common_pb2.pyi +49 -0
  31. flyte/_protos/workflow/queue_service_pb2.py +41 -35
  32. flyte/_protos/workflow/queue_service_pb2.pyi +26 -12
  33. flyte/_protos/workflow/queue_service_pb2_grpc.py +34 -0
  34. flyte/_protos/workflow/run_definition_pb2.py +38 -38
  35. flyte/_protos/workflow/run_definition_pb2.pyi +4 -2
  36. flyte/_protos/workflow/run_service_pb2.py +60 -50
  37. flyte/_protos/workflow/run_service_pb2.pyi +24 -6
  38. flyte/_protos/workflow/run_service_pb2_grpc.py +34 -0
  39. flyte/_protos/workflow/task_definition_pb2.py +15 -11
  40. flyte/_protos/workflow/task_definition_pb2.pyi +19 -2
  41. flyte/_protos/workflow/task_service_pb2.py +18 -17
  42. flyte/_protos/workflow/task_service_pb2.pyi +5 -2
  43. flyte/_protos/workflow/trigger_definition_pb2.py +66 -0
  44. flyte/_protos/workflow/trigger_definition_pb2.pyi +117 -0
  45. flyte/_protos/workflow/trigger_definition_pb2_grpc.py +4 -0
  46. flyte/_protos/workflow/trigger_service_pb2.py +96 -0
  47. flyte/_protos/workflow/trigger_service_pb2.pyi +110 -0
  48. flyte/_protos/workflow/trigger_service_pb2_grpc.py +281 -0
  49. flyte/_run.py +42 -15
  50. flyte/_task.py +35 -4
  51. flyte/_task_environment.py +60 -15
  52. flyte/_trigger.py +382 -0
  53. flyte/_version.py +3 -3
  54. flyte/cli/_abort.py +3 -3
  55. flyte/cli/_build.py +1 -3
  56. flyte/cli/_common.py +15 -2
  57. flyte/cli/_create.py +74 -0
  58. flyte/cli/_delete.py +23 -1
  59. flyte/cli/_deploy.py +5 -9
  60. flyte/cli/_get.py +75 -34
  61. flyte/cli/_params.py +4 -2
  62. flyte/cli/_run.py +12 -3
  63. flyte/cli/_update.py +36 -0
  64. flyte/cli/_user.py +17 -0
  65. flyte/cli/main.py +9 -1
  66. flyte/errors.py +9 -0
  67. flyte/io/_dir.py +513 -115
  68. flyte/io/_file.py +495 -135
  69. flyte/models.py +32 -0
  70. flyte/remote/__init__.py +6 -1
  71. flyte/remote/_client/_protocols.py +36 -2
  72. flyte/remote/_client/controlplane.py +19 -3
  73. flyte/remote/_run.py +42 -2
  74. flyte/remote/_task.py +14 -1
  75. flyte/remote/_trigger.py +308 -0
  76. flyte/remote/_user.py +33 -0
  77. flyte/storage/__init__.py +6 -1
  78. flyte/storage/_storage.py +119 -101
  79. flyte/types/_pickle.py +16 -3
  80. {flyte-2.0.0b22.data → flyte-2.0.0b23.data}/scripts/runtime.py +35 -5
  81. {flyte-2.0.0b22.dist-info → flyte-2.0.0b23.dist-info}/METADATA +3 -1
  82. {flyte-2.0.0b22.dist-info → flyte-2.0.0b23.dist-info}/RECORD +87 -75
  83. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  84. {flyte-2.0.0b22.data → flyte-2.0.0b23.data}/scripts/debug.py +0 -0
  85. {flyte-2.0.0b22.dist-info → flyte-2.0.0b23.dist-info}/WHEEL +0 -0
  86. {flyte-2.0.0b22.dist-info → flyte-2.0.0b23.dist-info}/entry_points.txt +0 -0
  87. {flyte-2.0.0b22.dist-info → flyte-2.0.0b23.dist-info}/licenses/LICENSE +0 -0
  88. {flyte-2.0.0b22.dist-info → flyte-2.0.0b23.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,281 @@
1
+ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
2
+ """Client and server classes corresponding to protobuf-defined services."""
3
+ import grpc
4
+
5
+ from flyte._protos.workflow import trigger_service_pb2 as workflow_dot_trigger__service__pb2
6
+
7
+
8
+ class TriggerServiceStub(object):
9
+ """TriggerService provides an interface for managing triggers.
10
+ """
11
+
12
+ def __init__(self, channel):
13
+ """Constructor.
14
+
15
+ Args:
16
+ channel: A grpc.Channel.
17
+ """
18
+ self.DeployTrigger = channel.unary_unary(
19
+ '/cloudidl.workflow.TriggerService/DeployTrigger',
20
+ request_serializer=workflow_dot_trigger__service__pb2.DeployTriggerRequest.SerializeToString,
21
+ response_deserializer=workflow_dot_trigger__service__pb2.DeployTriggerResponse.FromString,
22
+ )
23
+ self.GetTriggerDetails = channel.unary_unary(
24
+ '/cloudidl.workflow.TriggerService/GetTriggerDetails',
25
+ request_serializer=workflow_dot_trigger__service__pb2.GetTriggerDetailsRequest.SerializeToString,
26
+ response_deserializer=workflow_dot_trigger__service__pb2.GetTriggerDetailsResponse.FromString,
27
+ )
28
+ self.GetTriggerRevisionDetails = channel.unary_unary(
29
+ '/cloudidl.workflow.TriggerService/GetTriggerRevisionDetails',
30
+ request_serializer=workflow_dot_trigger__service__pb2.GetTriggerRevisionDetailsRequest.SerializeToString,
31
+ response_deserializer=workflow_dot_trigger__service__pb2.GetTriggerRevisionDetailsResponse.FromString,
32
+ )
33
+ self.ListTriggers = channel.unary_unary(
34
+ '/cloudidl.workflow.TriggerService/ListTriggers',
35
+ request_serializer=workflow_dot_trigger__service__pb2.ListTriggersRequest.SerializeToString,
36
+ response_deserializer=workflow_dot_trigger__service__pb2.ListTriggersResponse.FromString,
37
+ )
38
+ self.GetTriggerRevisionHistory = channel.unary_unary(
39
+ '/cloudidl.workflow.TriggerService/GetTriggerRevisionHistory',
40
+ request_serializer=workflow_dot_trigger__service__pb2.GetTriggerRevisionHistoryRequest.SerializeToString,
41
+ response_deserializer=workflow_dot_trigger__service__pb2.GetTriggerRevisionHistoryResponse.FromString,
42
+ )
43
+ self.UpdateTriggers = channel.unary_unary(
44
+ '/cloudidl.workflow.TriggerService/UpdateTriggers',
45
+ request_serializer=workflow_dot_trigger__service__pb2.UpdateTriggersRequest.SerializeToString,
46
+ response_deserializer=workflow_dot_trigger__service__pb2.UpdateTriggersResponse.FromString,
47
+ )
48
+ self.DeleteTriggers = channel.unary_unary(
49
+ '/cloudidl.workflow.TriggerService/DeleteTriggers',
50
+ request_serializer=workflow_dot_trigger__service__pb2.DeleteTriggersRequest.SerializeToString,
51
+ response_deserializer=workflow_dot_trigger__service__pb2.DeleteTriggersResponse.FromString,
52
+ )
53
+
54
+
55
class TriggerServiceServicer(object):
    """TriggerService provides an interface for managing triggers.
    """

    @staticmethod
    def _unimplemented(context):
        # Shared stub body: flag the call as UNIMPLEMENTED and raise, exactly
        # as the generated per-method bodies do.
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def DeployTrigger(self, request, context):
        """Create if trigger didn't exist previously.
        Update if it already exists.
        Re-create (or undelete) if it was soft-deleted.
        Client must fetch the latest trigger in order to obtain the latest `trigger.id.revision`.
        If trigger is not found, client can set `trigger.id.revision` to 1; it is ignored and set automatically by backend.
        If trigger is found, client should set `trigger.id.revision` to the <latest>.
        Backend validates that version is the latest and creates a new revision of the trigger.
        Otherwise, the operation is rejected (optimistic locking) and the client must re-fetch the trigger again.
        """
        self._unimplemented(context)

    def GetTriggerDetails(self, request, context):
        """Get detailed info about the latest trigger revision
        """
        self._unimplemented(context)

    def GetTriggerRevisionDetails(self, request, context):
        """Get detailed info about a specific trigger revision
        """
        self._unimplemented(context)

    def ListTriggers(self, request, context):
        """List basic info about triggers based on various filtering and sorting rules.
        """
        self._unimplemented(context)

    def GetTriggerRevisionHistory(self, request, context):
        """GetTriggerRevisionHistory returns all revisions for a given trigger
        """
        self._unimplemented(context)

    def UpdateTriggers(self, request, context):
        """Update some trigger spec fields for multiple triggers at once
        """
        self._unimplemented(context)

    def DeleteTriggers(self, request, context):
        """Soft-delete multiple triggers at once.
        """
        self._unimplemented(context)
114
+
115
+
116
def add_TriggerServiceServicer_to_server(servicer, server):
    """Register *servicer*'s TriggerService RPC handlers on *server*."""
    # Map each RPC name to its (request, response) message classes; the
    # handler table is derived from this in one pass.
    method_messages = {
        'DeployTrigger': (
            workflow_dot_trigger__service__pb2.DeployTriggerRequest,
            workflow_dot_trigger__service__pb2.DeployTriggerResponse),
        'GetTriggerDetails': (
            workflow_dot_trigger__service__pb2.GetTriggerDetailsRequest,
            workflow_dot_trigger__service__pb2.GetTriggerDetailsResponse),
        'GetTriggerRevisionDetails': (
            workflow_dot_trigger__service__pb2.GetTriggerRevisionDetailsRequest,
            workflow_dot_trigger__service__pb2.GetTriggerRevisionDetailsResponse),
        'ListTriggers': (
            workflow_dot_trigger__service__pb2.ListTriggersRequest,
            workflow_dot_trigger__service__pb2.ListTriggersResponse),
        'GetTriggerRevisionHistory': (
            workflow_dot_trigger__service__pb2.GetTriggerRevisionHistoryRequest,
            workflow_dot_trigger__service__pb2.GetTriggerRevisionHistoryResponse),
        'UpdateTriggers': (
            workflow_dot_trigger__service__pb2.UpdateTriggersRequest,
            workflow_dot_trigger__service__pb2.UpdateTriggersResponse),
        'DeleteTriggers': (
            workflow_dot_trigger__service__pb2.DeleteTriggersRequest,
            workflow_dot_trigger__service__pb2.DeleteTriggersResponse),
    }
    handlers = {
        name: grpc.unary_unary_rpc_method_handler(
            getattr(servicer, name),
            request_deserializer=request_cls.FromString,
            response_serializer=response_cls.SerializeToString,
        )
        for name, (request_cls, response_cls) in method_messages.items()
    }
    generic_handler = grpc.method_handlers_generic_handler(
        'cloudidl.workflow.TriggerService', handlers)
    server.add_generic_rpc_handlers((generic_handler,))
157
+
158
+
159
# This class is part of an EXPERIMENTAL API.
class TriggerService(object):
    """TriggerService provides an interface for managing triggers.
    """

    # Fully-qualified service prefix shared by every RPC below.
    _SERVICE_PREFIX = '/cloudidl.workflow.TriggerService/'

    @staticmethod
    def _invoke(method, request_cls, response_cls, request, target, options,
                channel_credentials, call_credentials, insecure, compression,
                wait_for_ready, timeout, metadata):
        # Single funnel for all unary-unary calls; argument order matches
        # grpc.experimental.unary_unary exactly.
        return grpc.experimental.unary_unary(
            request, target, TriggerService._SERVICE_PREFIX + method,
            request_cls.SerializeToString,
            response_cls.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout,
            metadata)

    @staticmethod
    def DeployTrigger(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Invoke the DeployTrigger RPC."""
        return TriggerService._invoke(
            'DeployTrigger',
            workflow_dot_trigger__service__pb2.DeployTriggerRequest,
            workflow_dot_trigger__service__pb2.DeployTriggerResponse,
            request, target, options, channel_credentials, call_credentials,
            insecure, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def GetTriggerDetails(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Invoke the GetTriggerDetails RPC."""
        return TriggerService._invoke(
            'GetTriggerDetails',
            workflow_dot_trigger__service__pb2.GetTriggerDetailsRequest,
            workflow_dot_trigger__service__pb2.GetTriggerDetailsResponse,
            request, target, options, channel_credentials, call_credentials,
            insecure, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def GetTriggerRevisionDetails(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Invoke the GetTriggerRevisionDetails RPC."""
        return TriggerService._invoke(
            'GetTriggerRevisionDetails',
            workflow_dot_trigger__service__pb2.GetTriggerRevisionDetailsRequest,
            workflow_dot_trigger__service__pb2.GetTriggerRevisionDetailsResponse,
            request, target, options, channel_credentials, call_credentials,
            insecure, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def ListTriggers(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Invoke the ListTriggers RPC."""
        return TriggerService._invoke(
            'ListTriggers',
            workflow_dot_trigger__service__pb2.ListTriggersRequest,
            workflow_dot_trigger__service__pb2.ListTriggersResponse,
            request, target, options, channel_credentials, call_credentials,
            insecure, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def GetTriggerRevisionHistory(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Invoke the GetTriggerRevisionHistory RPC."""
        return TriggerService._invoke(
            'GetTriggerRevisionHistory',
            workflow_dot_trigger__service__pb2.GetTriggerRevisionHistoryRequest,
            workflow_dot_trigger__service__pb2.GetTriggerRevisionHistoryResponse,
            request, target, options, channel_credentials, call_credentials,
            insecure, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def UpdateTriggers(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Invoke the UpdateTriggers RPC."""
        return TriggerService._invoke(
            'UpdateTriggers',
            workflow_dot_trigger__service__pb2.UpdateTriggersRequest,
            workflow_dot_trigger__service__pb2.UpdateTriggersResponse,
            request, target, options, channel_credentials, call_credentials,
            insecure, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def DeleteTriggers(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Invoke the DeleteTriggers RPC."""
        return TriggerService._invoke(
            'DeleteTriggers',
            workflow_dot_trigger__service__pb2.DeleteTriggersRequest,
            workflow_dot_trigger__service__pb2.DeleteTriggersResponse,
            request, target, options, channel_credentials, call_credentials,
            insecure, compression, wait_for_ready, timeout, metadata)
flyte/_run.py CHANGED
@@ -90,9 +90,10 @@ class _Runner:
90
90
  env_vars: Dict[str, str] | None = None,
91
91
  labels: Dict[str, str] | None = None,
92
92
  annotations: Dict[str, str] | None = None,
93
- interruptible: bool = False,
93
+ interruptible: bool | None = None,
94
94
  log_level: int | None = None,
95
95
  disable_run_cache: bool = False,
96
+ queue: Optional[str] = None,
96
97
  ):
97
98
  from flyte._tools import ipython_check
98
99
 
@@ -111,8 +112,8 @@ class _Runner:
111
112
  self._copy_bundle_to = copy_bundle_to
112
113
  self._interactive_mode = interactive_mode if interactive_mode else ipython_check()
113
114
  self._raw_data_path = raw_data_path
114
- self._metadata_path = metadata_path or "/tmp"
115
- self._run_base_dir = run_base_dir or "/tmp/base"
115
+ self._metadata_path = metadata_path
116
+ self._run_base_dir = run_base_dir
116
117
  self._overwrite_cache = overwrite_cache
117
118
  self._project = project
118
119
  self._domain = domain
@@ -122,6 +123,7 @@ class _Runner:
122
123
  self._interruptible = interruptible
123
124
  self._log_level = log_level
124
125
  self._disable_run_cache = disable_run_cache
126
+ self._queue = queue
125
127
 
126
128
  @requires_initialization
127
129
  async def _run_remote(self, obj: TaskTemplate[P, R] | LazyEntity, *args: P.args, **kwargs: P.kwargs) -> Run:
@@ -150,6 +152,7 @@ class _Runner:
150
152
  version = task.pb2.task_id.version
151
153
  code_bundle = None
152
154
  else:
155
+ task = cast(TaskTemplate[P, R], obj)
153
156
  if obj.parent_env is None:
154
157
  raise ValueError("Task is not attached to an environment. Please attach the task to an environment")
155
158
 
@@ -204,10 +207,11 @@ class _Runner:
204
207
  inputs = await convert_from_native_to_inputs(obj.native_interface, *args, **kwargs)
205
208
 
206
209
  env = self._env_vars or {}
207
- if self._log_level:
208
- env["LOG_LEVEL"] = str(self._log_level)
209
- else:
210
- env["LOG_LEVEL"] = str(logger.getEffectiveLevel())
210
+ if env.get("LOG_LEVEL") is None:
211
+ if self._log_level:
212
+ env["LOG_LEVEL"] = str(self._log_level)
213
+ else:
214
+ env["LOG_LEVEL"] = str(logger.getEffectiveLevel())
211
215
 
212
216
  if not self._dry_run:
213
217
  if get_client() is None:
@@ -263,10 +267,13 @@ class _Runner:
263
267
  inputs=inputs.proto_inputs,
264
268
  run_spec=run_definition_pb2.RunSpec(
265
269
  overwrite_cache=self._overwrite_cache,
266
- interruptible=wrappers_pb2.BoolValue(value=self._interruptible),
270
+ interruptible=wrappers_pb2.BoolValue(value=self._interruptible)
271
+ if self._interruptible is not None
272
+ else None,
267
273
  annotations=annotations,
268
274
  labels=labels,
269
275
  envs=env_kv,
276
+ cluster=self._queue or task.queue,
270
277
  ),
271
278
  ),
272
279
  )
@@ -385,6 +392,7 @@ class _Runner:
385
392
  " flyte.with_runcontext(run_base_dir='s3://bucket/metadata/outputs')",
386
393
  )
387
394
  output_path = self._run_base_dir
395
+ run_base_dir = self._run_base_dir
388
396
  raw_data_path = f"{output_path}/rd/{random_id}"
389
397
  raw_data_path_obj = RawDataPath(path=raw_data_path)
390
398
  checkpoint_path = f"{raw_data_path}/checkpoint"
@@ -401,7 +409,7 @@ class _Runner:
401
409
  version=version if version else "na",
402
410
  raw_data_path=raw_data_path_obj,
403
411
  compiled_image_cache=image_cache,
404
- run_base_dir=self._run_base_dir,
412
+ run_base_dir=run_base_dir,
405
413
  report=flyte.report.Report(name=action.name),
406
414
  )
407
415
  async with ctx.replace_task_context(tctx):
@@ -426,6 +434,18 @@ class _Runner:
426
434
  else:
427
435
  action = ActionID(name=self._name)
428
436
 
437
+ metadata_path = self._metadata_path
438
+ if metadata_path is None:
439
+ metadata_path = pathlib.Path("/") / "tmp" / "flyte" / "metadata" / action.name
440
+ else:
441
+ metadata_path = pathlib.Path(metadata_path) / action.name
442
+ output_path = metadata_path / "a0"
443
+ if self._raw_data_path is None:
444
+ path = pathlib.Path("/") / "tmp" / "flyte" / "raw_data" / action.name
445
+ raw_data_path = RawDataPath(path=str(path))
446
+ else:
447
+ raw_data_path = RawDataPath(path=self._raw_data_path)
448
+
429
449
  ctx = internal_ctx()
430
450
  tctx = TaskContext(
431
451
  action=action,
@@ -434,10 +454,10 @@ class _Runner:
434
454
  checkpoint_path=internal_ctx().raw_data.path,
435
455
  ),
436
456
  code_bundle=None,
437
- output_path=self._metadata_path,
438
- run_base_dir=self._metadata_path,
457
+ output_path=str(output_path),
458
+ run_base_dir=str(metadata_path),
439
459
  version="na",
440
- raw_data_path=internal_ctx().raw_data,
460
+ raw_data_path=raw_data_path,
441
461
  compiled_image_cache=None,
442
462
  report=Report(name=action.name),
443
463
  mode="local",
@@ -469,7 +489,7 @@ class _Runner:
469
489
 
470
490
  @property
471
491
  def url(self) -> str:
472
- return "local-run"
492
+ return str(metadata_path)
473
493
 
474
494
  def wait(
475
495
  self,
@@ -550,9 +570,10 @@ def with_runcontext(
550
570
  env_vars: Dict[str, str] | None = None,
551
571
  labels: Dict[str, str] | None = None,
552
572
  annotations: Dict[str, str] | None = None,
553
- interruptible: bool = False,
573
+ interruptible: bool | None = None,
554
574
  log_level: int | None = None,
555
575
  disable_run_cache: bool = False,
576
+ queue: Optional[str] = None,
556
577
  ) -> _Runner:
557
578
  """
558
579
  Launch a new run with the given parameters as the context.
@@ -590,15 +611,20 @@ def with_runcontext(
590
611
  :param env_vars: Optional Environment variables to set for the run
591
612
  :param labels: Optional Labels to set for the run
592
613
  :param annotations: Optional Annotations to set for the run
593
- :param interruptible: Optional If true, the run can be interrupted by the user.
614
+ :param interruptible: Optional If true, the run can be scheduled on interruptible instances and false implies
615
+ that all tasks in the run should only be scheduled on non-interruptible instances. If not specified the
616
+ original setting on all tasks is retained.
594
617
  :param log_level: Optional Log level to set for the run. If not provided, it will be set to the default log level
595
618
  set using `flyte.init()`
596
619
  :param disable_run_cache: Optional If true, the run cache will be disabled. This is useful for testing purposes.
620
+ :param queue: Optional The queue to use for the run. This is used to specify the cluster to use for the run.
597
621
 
598
622
  :return: runner
599
623
  """
600
624
  if mode == "hybrid" and not name and not run_base_dir:
601
625
  raise ValueError("Run name and run base dir are required for hybrid mode")
626
+ if copy_style == "none" and not version:
627
+ raise ValueError("Version is required when copy_style is 'none'")
602
628
  return _Runner(
603
629
  force_mode=mode,
604
630
  name=name,
@@ -619,6 +645,7 @@ def with_runcontext(
619
645
  domain=domain,
620
646
  log_level=log_level,
621
647
  disable_run_cache=disable_run_cache,
648
+ queue=queue,
622
649
  )
623
650
 
624
651
 
flyte/_task.py CHANGED
@@ -15,6 +15,7 @@ from typing import (
15
15
  Literal,
16
16
  Optional,
17
17
  ParamSpec,
18
+ Tuple,
18
19
  TypeAlias,
19
20
  TypeVar,
20
21
  Union,
@@ -38,6 +39,8 @@ from .models import MAX_INLINE_IO_BYTES, NativeInterface, SerializationContext
38
39
  if TYPE_CHECKING:
39
40
  from flyteidl.core.tasks_pb2 import DataLoadingConfig
40
41
 
42
+ from flyte.trigger import Trigger
43
+
41
44
  from ._task_environment import TaskEnvironment
42
45
 
43
46
  P = ParamSpec("P") # capture the function's parameters
@@ -69,8 +72,8 @@ class TaskTemplate(Generic[P, R]):
69
72
  version with flyte installed
70
73
  :param resources: Optional The resources to use for the task
71
74
  :param cache: Optional The cache policy for the task, defaults to auto, which will cache the results of the task.
72
- :param interruptable: Optional The interruptable policy for the task, defaults to False, which means the task
73
- will not be scheduled on interruptable nodes. If set to True, the task will be scheduled on interruptable nodes,
75
+ :param interruptible: Optional The interruptible policy for the task, defaults to False, which means the task
76
+ will not be scheduled on interruptible nodes. If set to True, the task will be scheduled on interruptible nodes,
74
77
  and the code should handle interruptions and resumptions.
75
78
  :param retries: Optional The number of retries for the task, defaults to 0, which means no retries.
76
79
  :param reusable: Optional The reusability policy for the task, defaults to None, which means the task environment
@@ -81,6 +84,9 @@ class TaskTemplate(Generic[P, R]):
81
84
  :param timeout: Optional The timeout for the task.
82
85
  :param max_inline_io_bytes: Maximum allowed size (in bytes) for all inputs and outputs passed directly to the task
83
86
  (e.g., primitives, strings, dicts). Does not apply to files, directories, or dataframes.
87
+ :param pod_template: Optional The pod template to use for the task.
88
+ :param report: Optional Whether to report the task execution to the Flyte console, defaults to False.
89
+ :param queue: Optional The queue to use for the task. If not provided, the default queue will be used.
84
90
  """
85
91
 
86
92
  name: str
@@ -90,8 +96,8 @@ class TaskTemplate(Generic[P, R]):
90
96
  task_type_version: int = 0
91
97
  image: Union[str, Image, Literal["auto"]] = "auto"
92
98
  resources: Optional[Resources] = None
93
- cache: CacheRequest = "auto"
94
- interruptable: bool = False
99
+ cache: CacheRequest = "disable"
100
+ interruptible: bool = False
95
101
  retries: Union[int, RetryStrategy] = 0
96
102
  reusable: Union[ReusePolicy, None] = None
97
103
  docs: Optional[Documentation] = None
@@ -100,10 +106,12 @@ class TaskTemplate(Generic[P, R]):
100
106
  timeout: Optional[TimeoutType] = None
101
107
  pod_template: Optional[Union[str, PodTemplate]] = None
102
108
  report: bool = False
109
+ queue: Optional[str] = None
103
110
 
104
111
  parent_env: Optional[weakref.ReferenceType[TaskEnvironment]] = None
105
112
  ref: bool = field(default=False, init=False, repr=False, compare=False)
106
113
  max_inline_io_bytes: int = MAX_INLINE_IO_BYTES
114
+ triggers: Tuple[Trigger, ...] = field(default_factory=tuple)
107
115
 
108
116
  # Only used in python 3.10 and 3.11, where we cannot use markcoroutinefunction
109
117
  _call_as_synchronous: bool = False
@@ -327,11 +335,30 @@ class TaskTemplate(Generic[P, R]):
327
335
  secrets: Optional[SecretRequest] = None,
328
336
  max_inline_io_bytes: int | None = None,
329
337
  pod_template: Optional[Union[str, PodTemplate]] = None,
338
+ queue: Optional[str] = None,
339
+ interruptible: Optional[bool] = None,
330
340
  **kwargs: Any,
331
341
  ) -> TaskTemplate:
332
342
  """
333
343
  Override various parameters of the task template. This allows for dynamic configuration of the task
334
344
  when it is called, such as changing the image, resources, cache policy, etc.
345
+
346
+ :param short_name: Optional override for the short name of the task.
347
+ :param resources: Optional override for the resources to use for the task.
348
+ :param cache: Optional override for the cache policy for the task.
349
+ :param retries: Optional override for the number of retries for the task.
350
+ :param timeout: Optional override for the timeout for the task.
351
+ :param reusable: Optional override for the reusability policy for the task.
352
+ :param env_vars: Optional override for the environment variables to set for the task.
353
+ :param secrets: Optional override for the secrets that will be injected into the task at runtime.
354
+ :param max_inline_io_bytes: Optional override for the maximum allowed size (in bytes) for all inputs and outputs
355
+ passed directly to the task.
356
+ :param pod_template: Optional override for the pod template to use for the task.
357
+ :param queue: Optional override for the queue to use for the task.
358
+ :param kwargs: Additional keyword arguments for further overrides. Some fields like name, image, docs,
359
+ and interface cannot be overridden.
360
+
361
+ :return: A new TaskTemplate instance with the overridden parameters.
335
362
  """
336
363
  cache = cache or self.cache
337
364
  retries = retries or self.retries
@@ -366,6 +393,8 @@ class TaskTemplate(Generic[P, R]):
366
393
  env_vars = env_vars or self.env_vars
367
394
  secrets = secrets or self.secrets
368
395
 
396
+ interruptible = interruptible if interruptible is not None else self.interruptible
397
+
369
398
  for k, v in kwargs.items():
370
399
  if k == "name":
371
400
  raise ValueError("Name cannot be overridden")
@@ -388,6 +417,8 @@ class TaskTemplate(Generic[P, R]):
388
417
  secrets=secrets,
389
418
  max_inline_io_bytes=max_inline_io_bytes,
390
419
  pod_template=pod_template,
420
+ interruptible=interruptible,
421
+ queue=queue or self.queue,
391
422
  **kwargs,
392
423
  )
393
424