flyte 0.2.0b35__py3-none-any.whl → 0.2.0b37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/_image.py +1 -1
- flyte/_internal/controllers/_local_controller.py +3 -2
- flyte/_internal/controllers/_trace.py +14 -10
- flyte/_internal/controllers/remote/_action.py +37 -7
- flyte/_internal/controllers/remote/_controller.py +43 -21
- flyte/_internal/controllers/remote/_core.py +32 -16
- flyte/_internal/controllers/remote/_informer.py +18 -7
- flyte/_internal/runtime/task_serde.py +17 -6
- flyte/_protos/common/identifier_pb2.py +23 -1
- flyte/_protos/common/identifier_pb2.pyi +28 -0
- flyte/_protos/workflow/queue_service_pb2.py +33 -29
- flyte/_protos/workflow/queue_service_pb2.pyi +34 -16
- flyte/_protos/workflow/run_definition_pb2.py +64 -71
- flyte/_protos/workflow/run_definition_pb2.pyi +44 -31
- flyte/_protos/workflow/run_logs_service_pb2.py +10 -10
- flyte/_protos/workflow/run_logs_service_pb2.pyi +3 -3
- flyte/_protos/workflow/run_service_pb2.py +54 -46
- flyte/_protos/workflow/run_service_pb2.pyi +32 -18
- flyte/_protos/workflow/run_service_pb2_grpc.py +34 -0
- flyte/_protos/workflow/state_service_pb2.py +20 -19
- flyte/_protos/workflow/state_service_pb2.pyi +13 -12
- flyte/_run.py +11 -6
- flyte/_trace.py +4 -10
- flyte/_version.py +2 -2
- flyte/migrate/__init__.py +1 -0
- flyte/migrate/dynamic.py +13 -0
- flyte/migrate/task.py +99 -0
- flyte/migrate/workflow.py +13 -0
- flyte/remote/_action.py +56 -25
- flyte/remote/_logs.py +4 -3
- flyte/remote/_run.py +5 -4
- flyte-0.2.0b37.dist-info/METADATA +371 -0
- {flyte-0.2.0b35.dist-info → flyte-0.2.0b37.dist-info}/RECORD +38 -33
- flyte-0.2.0b37.dist-info/licenses/LICENSE +201 -0
- flyte-0.2.0b35.dist-info/METADATA +0 -249
- {flyte-0.2.0b35.data → flyte-0.2.0b37.data}/scripts/runtime.py +0 -0
- {flyte-0.2.0b35.dist-info → flyte-0.2.0b37.dist-info}/WHEEL +0 -0
- {flyte-0.2.0b35.dist-info → flyte-0.2.0b37.dist-info}/entry_points.txt +0 -0
- {flyte-0.2.0b35.dist-info → flyte-0.2.0b37.dist-info}/top_level.txt +0 -0
flyte/_image.py
CHANGED
|
@@ -161,10 +161,10 @@ class LocalController:
|
|
|
161
161
|
assert action_output_path
|
|
162
162
|
return (
|
|
163
163
|
TraceInfo(
|
|
164
|
+
name=_func.__name__,
|
|
164
165
|
action=action_id,
|
|
165
166
|
interface=_interface,
|
|
166
167
|
inputs_path=action_output_path,
|
|
167
|
-
name=_func.__name__,
|
|
168
168
|
),
|
|
169
169
|
True,
|
|
170
170
|
)
|
|
@@ -189,7 +189,8 @@ class LocalController:
|
|
|
189
189
|
converted_error = convert.convert_from_native_to_error(info.error)
|
|
190
190
|
assert converted_error
|
|
191
191
|
assert info.action
|
|
192
|
-
assert info.
|
|
192
|
+
assert info.start_time
|
|
193
|
+
assert info.end_time
|
|
193
194
|
|
|
194
195
|
async def submit_task_ref(self, _task: task_definition_pb2.TaskDetails, *args, **kwargs) -> Any:
|
|
195
196
|
raise flyte.errors.ReferenceTaskError("Reference tasks cannot be executed locally, only remotely.")
|
|
@@ -1,5 +1,4 @@
|
|
|
1
|
-
from dataclasses import dataclass
|
|
2
|
-
from datetime import timedelta
|
|
1
|
+
from dataclasses import dataclass, field
|
|
3
2
|
from typing import Any, Optional
|
|
4
3
|
|
|
5
4
|
from flyte.models import ActionID, NativeInterface
|
|
@@ -12,30 +11,35 @@ class TraceInfo:
|
|
|
12
11
|
the action is completed.
|
|
13
12
|
"""
|
|
14
13
|
|
|
14
|
+
name: str
|
|
15
15
|
action: ActionID
|
|
16
16
|
interface: NativeInterface
|
|
17
17
|
inputs_path: str
|
|
18
|
-
|
|
18
|
+
start_time: float = field(init=False, default=0.0)
|
|
19
|
+
end_time: float = field(init=False, default=0.0)
|
|
19
20
|
output: Optional[Any] = None
|
|
20
21
|
error: Optional[Exception] = None
|
|
21
|
-
name: str = ""
|
|
22
22
|
|
|
23
|
-
def add_outputs(self, output: Any,
|
|
23
|
+
def add_outputs(self, output: Any, start_time: float, end_time: float):
|
|
24
24
|
"""
|
|
25
25
|
Add outputs to the trace information.
|
|
26
26
|
:param output: Output of the action
|
|
27
|
-
:param
|
|
27
|
+
:param start_time: Start time of the action
|
|
28
|
+
:param end_time: End time of the action
|
|
28
29
|
:return:
|
|
29
30
|
"""
|
|
30
31
|
self.output = output
|
|
31
|
-
self.
|
|
32
|
+
self.start_time = start_time
|
|
33
|
+
self.end_time = end_time
|
|
32
34
|
|
|
33
|
-
def add_error(self, error: Exception,
|
|
35
|
+
def add_error(self, error: Exception, start_time: float, end_time: float):
|
|
34
36
|
"""
|
|
35
37
|
Add error to the trace information.
|
|
36
38
|
:param error: Error of the action
|
|
37
|
-
:param
|
|
39
|
+
:param start_time: Start time of the action
|
|
40
|
+
:param end_time: End time of the action
|
|
38
41
|
:return:
|
|
39
42
|
"""
|
|
40
43
|
self.error = error
|
|
41
|
-
self.
|
|
44
|
+
self.start_time = start_time
|
|
45
|
+
self.end_time = end_time
|
|
@@ -1,11 +1,18 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
|
-
from typing import
|
|
4
|
+
from typing import Literal
|
|
5
5
|
|
|
6
6
|
from flyteidl.core import execution_pb2
|
|
7
|
-
|
|
8
|
-
|
|
7
|
+
from google.protobuf import timestamp_pb2
|
|
8
|
+
|
|
9
|
+
from flyte._protos.common import identifier_pb2
|
|
10
|
+
from flyte._protos.workflow import (
|
|
11
|
+
queue_service_pb2,
|
|
12
|
+
run_definition_pb2,
|
|
13
|
+
state_service_pb2,
|
|
14
|
+
task_definition_pb2,
|
|
15
|
+
)
|
|
9
16
|
from flyte.models import GroupData
|
|
10
17
|
|
|
11
18
|
ActionType = Literal["task", "trace"]
|
|
@@ -18,12 +25,13 @@ class Action:
|
|
|
18
25
|
Holds the inmemory state of a task. It is combined representation of local and remote states.
|
|
19
26
|
"""
|
|
20
27
|
|
|
21
|
-
action_id:
|
|
28
|
+
action_id: identifier_pb2.ActionIdentifier
|
|
22
29
|
parent_action_name: str
|
|
23
30
|
type: ActionType = "task" # type of action, task or trace
|
|
24
31
|
friendly_name: str | None = None
|
|
25
32
|
group: GroupData | None = None
|
|
26
33
|
task: task_definition_pb2.TaskSpec | None = None
|
|
34
|
+
trace: queue_service_pb2.TraceAction | None = None
|
|
27
35
|
inputs_uri: str | None = None
|
|
28
36
|
run_output_base: str | None = None
|
|
29
37
|
realized_outputs_uri: str | None = None
|
|
@@ -108,7 +116,7 @@ class Action:
|
|
|
108
116
|
def from_task(
|
|
109
117
|
cls,
|
|
110
118
|
parent_action_name: str,
|
|
111
|
-
sub_action_id:
|
|
119
|
+
sub_action_id: identifier_pb2.ActionIdentifier,
|
|
112
120
|
group_data: GroupData | None,
|
|
113
121
|
task_spec: task_definition_pb2.TaskSpec,
|
|
114
122
|
inputs_uri: str,
|
|
@@ -153,16 +161,27 @@ class Action:
|
|
|
153
161
|
def from_trace(
|
|
154
162
|
cls,
|
|
155
163
|
parent_action_name: str,
|
|
156
|
-
action_id:
|
|
164
|
+
action_id: identifier_pb2.ActionIdentifier,
|
|
157
165
|
friendly_name: str,
|
|
158
166
|
group_data: GroupData | None,
|
|
159
|
-
trace_spec: Any, # TODO
|
|
160
167
|
inputs_uri: str,
|
|
161
168
|
outputs_uri: str,
|
|
169
|
+
start_time: float, # Unix timestamp in seconds with fractional seconds
|
|
170
|
+
end_time: float, # Unix timestamp in seconds with fractional seconds
|
|
171
|
+
run_output_base: str,
|
|
172
|
+
report_uri: str | None = None,
|
|
162
173
|
) -> Action:
|
|
163
174
|
"""
|
|
164
175
|
This creates a new action for tracing purposes. It is used to track the execution of a trace.
|
|
165
176
|
"""
|
|
177
|
+
st = timestamp_pb2.Timestamp()
|
|
178
|
+
st.FromSeconds(int(start_time))
|
|
179
|
+
st.nanos = int((start_time % 1) * 1e9)
|
|
180
|
+
|
|
181
|
+
et = timestamp_pb2.Timestamp()
|
|
182
|
+
et.FromSeconds(int(end_time))
|
|
183
|
+
et.nanos = int((end_time % 1) * 1e9)
|
|
184
|
+
|
|
166
185
|
return cls(
|
|
167
186
|
action_id=action_id,
|
|
168
187
|
parent_action_name=parent_action_name,
|
|
@@ -172,4 +191,15 @@ class Action:
|
|
|
172
191
|
inputs_uri=inputs_uri,
|
|
173
192
|
realized_outputs_uri=outputs_uri,
|
|
174
193
|
phase=run_definition_pb2.Phase.PHASE_SUCCEEDED,
|
|
194
|
+
run_output_base=run_output_base,
|
|
195
|
+
trace=queue_service_pb2.TraceAction(
|
|
196
|
+
name=friendly_name,
|
|
197
|
+
phase=run_definition_pb2.Phase.PHASE_SUCCEEDED,
|
|
198
|
+
start_time=st,
|
|
199
|
+
end_time=et,
|
|
200
|
+
outputs=run_definition_pb2.OutputReferences(
|
|
201
|
+
output_uri=outputs_uri,
|
|
202
|
+
report_uri=report_uri,
|
|
203
|
+
),
|
|
204
|
+
),
|
|
175
205
|
)
|
|
@@ -22,6 +22,7 @@ from flyte._internal.controllers.remote._service_protocol import ClientSet
|
|
|
22
22
|
from flyte._internal.runtime import convert, io
|
|
23
23
|
from flyte._internal.runtime.task_serde import translate_task_to_wire
|
|
24
24
|
from flyte._logging import logger
|
|
25
|
+
from flyte._protos.common import identifier_pb2
|
|
25
26
|
from flyte._protos.workflow import run_definition_pb2, task_definition_pb2
|
|
26
27
|
from flyte._task import TaskTemplate
|
|
27
28
|
from flyte._utils.helpers import _selector_policy
|
|
@@ -131,11 +132,18 @@ class RemoteController(Controller):
|
|
|
131
132
|
Generate a task call sequence for the given task object and action ID.
|
|
132
133
|
This is used to track the number of times a task is called within an action.
|
|
133
134
|
"""
|
|
134
|
-
|
|
135
|
+
uniq = unique_action_name(action_id)
|
|
136
|
+
current_action_sequencer = self._parent_action_task_call_sequence[uniq]
|
|
135
137
|
current_task_id = id(task_obj)
|
|
136
138
|
v = current_action_sequencer[current_task_id]
|
|
137
139
|
new_seq = v + 1
|
|
138
140
|
current_action_sequencer[current_task_id] = new_seq
|
|
141
|
+
name = ""
|
|
142
|
+
if hasattr(task_obj, "__name__"):
|
|
143
|
+
name = task_obj.__name__
|
|
144
|
+
elif hasattr(task_obj, "name"):
|
|
145
|
+
name = task_obj.name
|
|
146
|
+
logger.warning(f"For action {uniq}, task {name} call sequence is {new_seq}")
|
|
139
147
|
return new_seq
|
|
140
148
|
|
|
141
149
|
async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwargs) -> Any:
|
|
@@ -178,6 +186,7 @@ class RemoteController(Controller):
|
|
|
178
186
|
sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
|
|
179
187
|
tctx, task_spec, inputs_hash, _task_call_seq
|
|
180
188
|
)
|
|
189
|
+
logger.warning(f"Sub action {sub_action_id} output path {sub_action_output_path}")
|
|
181
190
|
|
|
182
191
|
serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
|
|
183
192
|
inputs_uri = io.inputs_path(sub_action_output_path)
|
|
@@ -204,9 +213,9 @@ class RemoteController(Controller):
|
|
|
204
213
|
inputs_hash = None # type: ignore
|
|
205
214
|
|
|
206
215
|
action = Action.from_task(
|
|
207
|
-
sub_action_id=
|
|
216
|
+
sub_action_id=identifier_pb2.ActionIdentifier(
|
|
208
217
|
name=sub_action_id.name,
|
|
209
|
-
run=
|
|
218
|
+
run=identifier_pb2.RunIdentifier(
|
|
210
219
|
name=current_action_id.run_name,
|
|
211
220
|
project=current_action_id.project,
|
|
212
221
|
domain=current_action_id.domain,
|
|
@@ -293,7 +302,9 @@ class RemoteController(Controller):
|
|
|
293
302
|
self._submit_loop.set_exception_handler(exc_handler)
|
|
294
303
|
|
|
295
304
|
self._submit_thread = threading.Thread(
|
|
296
|
-
name=f"remote-controller-{os.getpid()}-submitter",
|
|
305
|
+
name=f"remote-controller-{os.getpid()}-submitter",
|
|
306
|
+
daemon=True,
|
|
307
|
+
target=self._sync_thread_loop_runner,
|
|
297
308
|
)
|
|
298
309
|
self._submit_thread.start()
|
|
299
310
|
|
|
@@ -307,7 +318,7 @@ class RemoteController(Controller):
|
|
|
307
318
|
This method is invoked when the parent action is finished. It will finalize the run and upload the outputs
|
|
308
319
|
to the control plane.
|
|
309
320
|
"""
|
|
310
|
-
run_id =
|
|
321
|
+
run_id = identifier_pb2.RunIdentifier(
|
|
311
322
|
name=action_id.run_name,
|
|
312
323
|
project=action_id.project,
|
|
313
324
|
domain=action_id.domain,
|
|
@@ -350,9 +361,9 @@ class RemoteController(Controller):
|
|
|
350
361
|
serialized_inputs = None # type: ignore
|
|
351
362
|
|
|
352
363
|
prev_action = await self.get_action(
|
|
353
|
-
|
|
364
|
+
identifier_pb2.ActionIdentifier(
|
|
354
365
|
name=sub_action_id.name,
|
|
355
|
-
run=
|
|
366
|
+
run=identifier_pb2.RunIdentifier(
|
|
356
367
|
name=current_action_id.run_name,
|
|
357
368
|
project=current_action_id.project,
|
|
358
369
|
domain=current_action_id.domain,
|
|
@@ -363,21 +374,27 @@ class RemoteController(Controller):
|
|
|
363
374
|
)
|
|
364
375
|
|
|
365
376
|
if prev_action is None:
|
|
366
|
-
return TraceInfo(sub_action_id, _interface, inputs_uri), False
|
|
377
|
+
return TraceInfo(func_name, sub_action_id, _interface, inputs_uri), False
|
|
367
378
|
|
|
368
379
|
if prev_action.phase == run_definition_pb2.PHASE_FAILED:
|
|
369
380
|
if prev_action.has_error():
|
|
370
381
|
exc = convert.convert_error_to_native(prev_action.err)
|
|
371
|
-
return
|
|
382
|
+
return (
|
|
383
|
+
TraceInfo(func_name, sub_action_id, _interface, inputs_uri, error=exc),
|
|
384
|
+
True,
|
|
385
|
+
)
|
|
372
386
|
else:
|
|
373
387
|
logger.warning(f"Action {prev_action.action_id.name} failed, but no error was found, re-running trace!")
|
|
374
388
|
elif prev_action.realized_outputs_uri is not None:
|
|
375
389
|
outputs_file_path = io.outputs_path(prev_action.realized_outputs_uri)
|
|
376
390
|
o = await io.load_outputs(outputs_file_path)
|
|
377
391
|
outputs = await convert.convert_outputs_to_native(_interface, o)
|
|
378
|
-
return
|
|
392
|
+
return (
|
|
393
|
+
TraceInfo(func_name, sub_action_id, _interface, inputs_uri, output=outputs),
|
|
394
|
+
True,
|
|
395
|
+
)
|
|
379
396
|
|
|
380
|
-
return TraceInfo(sub_action_id, _interface, inputs_uri), False
|
|
397
|
+
return TraceInfo(func_name, sub_action_id, _interface, inputs_uri), False
|
|
381
398
|
|
|
382
399
|
async def record_trace(self, info: TraceInfo):
|
|
383
400
|
"""
|
|
@@ -391,26 +408,29 @@ class RemoteController(Controller):
|
|
|
391
408
|
raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
|
|
392
409
|
|
|
393
410
|
current_action_id = tctx.action
|
|
394
|
-
|
|
395
|
-
|
|
411
|
+
sub_run_output_path = storage.join(tctx.run_base_dir, info.action.name)
|
|
412
|
+
print(f"Sub run output path for {info.name} is {sub_run_output_path}", flush=True)
|
|
396
413
|
|
|
397
414
|
if info.interface.has_outputs():
|
|
398
415
|
outputs_file_path: str = ""
|
|
399
416
|
if info.output:
|
|
400
417
|
outputs = await convert.convert_from_native_to_outputs(info.output, info.interface)
|
|
401
418
|
outputs_file_path = io.outputs_path(sub_run_output_path)
|
|
402
|
-
|
|
419
|
+
print(
|
|
420
|
+
f"Uploading outputs for {info.name} Outputs file path: {outputs_file_path}",
|
|
421
|
+
flush=True,
|
|
422
|
+
)
|
|
423
|
+
await io.upload_outputs(outputs, sub_run_output_path)
|
|
403
424
|
elif info.error:
|
|
404
425
|
err = convert.convert_from_native_to_error(info.error)
|
|
405
|
-
|
|
406
|
-
await io.upload_error(err.err, error_path)
|
|
426
|
+
await io.upload_error(err.err, sub_run_output_path)
|
|
407
427
|
else:
|
|
408
428
|
raise flyte.errors.RuntimeSystemError("BadTraceInfo", "Trace info does not have output or error")
|
|
409
429
|
trace_action = Action.from_trace(
|
|
410
430
|
parent_action_name=current_action_id.name,
|
|
411
|
-
action_id=
|
|
431
|
+
action_id=identifier_pb2.ActionIdentifier(
|
|
412
432
|
name=info.action.name,
|
|
413
|
-
run=
|
|
433
|
+
run=identifier_pb2.RunIdentifier(
|
|
414
434
|
name=current_action_id.run_name,
|
|
415
435
|
project=current_action_id.project,
|
|
416
436
|
domain=current_action_id.domain,
|
|
@@ -419,9 +439,11 @@ class RemoteController(Controller):
|
|
|
419
439
|
),
|
|
420
440
|
inputs_uri=info.inputs_path,
|
|
421
441
|
outputs_uri=outputs_file_path,
|
|
422
|
-
trace_spec="",
|
|
423
442
|
friendly_name=info.name,
|
|
424
443
|
group_data=tctx.group_data,
|
|
444
|
+
run_output_base=tctx.run_base_dir,
|
|
445
|
+
start_time=info.start_time,
|
|
446
|
+
end_time=info.end_time,
|
|
425
447
|
)
|
|
426
448
|
try:
|
|
427
449
|
logger.info(
|
|
@@ -478,9 +500,9 @@ class RemoteController(Controller):
|
|
|
478
500
|
inputs_hash = None # type: ignore
|
|
479
501
|
|
|
480
502
|
action = Action.from_task(
|
|
481
|
-
sub_action_id=
|
|
503
|
+
sub_action_id=identifier_pb2.ActionIdentifier(
|
|
482
504
|
name=sub_action_id.name,
|
|
483
|
-
run=
|
|
505
|
+
run=identifier_pb2.RunIdentifier(
|
|
484
506
|
name=current_action_id.run_name,
|
|
485
507
|
project=current_action_id.project,
|
|
486
508
|
domain=current_action_id.domain,
|
|
@@ -11,7 +11,11 @@ from google.protobuf.wrappers_pb2 import StringValue
|
|
|
11
11
|
|
|
12
12
|
import flyte.errors
|
|
13
13
|
from flyte._logging import log, logger
|
|
14
|
-
from flyte._protos.
|
|
14
|
+
from flyte._protos.common import identifier_pb2
|
|
15
|
+
from flyte._protos.workflow import (
|
|
16
|
+
queue_service_pb2,
|
|
17
|
+
task_definition_pb2,
|
|
18
|
+
)
|
|
15
19
|
from flyte.errors import RuntimeSystemError
|
|
16
20
|
|
|
17
21
|
from ._action import Action
|
|
@@ -80,9 +84,7 @@ class Controller:
|
|
|
80
84
|
"""Public API to submit a resource and wait for completion"""
|
|
81
85
|
return await self._run_coroutine_in_controller_thread(self._bg_submit_action(action))
|
|
82
86
|
|
|
83
|
-
async def get_action(
|
|
84
|
-
self, action_id: run_definition_pb2.ActionIdentifier, parent_action_name: str
|
|
85
|
-
) -> Optional[Action]:
|
|
87
|
+
async def get_action(self, action_id: identifier_pb2.ActionIdentifier, parent_action_name: str) -> Optional[Action]:
|
|
86
88
|
"""Get the action from the informer"""
|
|
87
89
|
informer = await self._informers.get(run_name=action_id.run.name, parent_action_name=parent_action_name)
|
|
88
90
|
if informer:
|
|
@@ -94,7 +96,10 @@ class Controller:
|
|
|
94
96
|
return await self._run_coroutine_in_controller_thread(self._bg_cancel_action(action))
|
|
95
97
|
|
|
96
98
|
async def _finalize_parent_action(
|
|
97
|
-
self,
|
|
99
|
+
self,
|
|
100
|
+
run_id: identifier_pb2.RunIdentifier,
|
|
101
|
+
parent_action_name: str,
|
|
102
|
+
timeout: Optional[float] = None,
|
|
98
103
|
):
|
|
99
104
|
"""Finalize the parent run"""
|
|
100
105
|
await self._run_coroutine_in_controller_thread(
|
|
@@ -124,7 +129,8 @@ class Controller:
|
|
|
124
129
|
"""Watch for errors in the background thread"""
|
|
125
130
|
await self._run_coroutine_in_controller_thread(self._bg_watch_for_errors())
|
|
126
131
|
raise RuntimeSystemError(
|
|
127
|
-
code="InformerWatchFailure",
|
|
132
|
+
code="InformerWatchFailure",
|
|
133
|
+
message=f"Controller thread failed with exception: {self._get_exception()}",
|
|
128
134
|
)
|
|
129
135
|
|
|
130
136
|
@log
|
|
@@ -162,7 +168,8 @@ class Controller:
|
|
|
162
168
|
|
|
163
169
|
if self._get_exception():
|
|
164
170
|
raise RuntimeSystemError(
|
|
165
|
-
type(self._get_exception()).__name__,
|
|
171
|
+
type(self._get_exception()).__name__,
|
|
172
|
+
f"Controller thread startup failed: {self._get_exception()}",
|
|
166
173
|
)
|
|
167
174
|
|
|
168
175
|
logger.info(f"Controller started in thread: {self._thread.name}")
|
|
@@ -223,7 +230,7 @@ class Controller:
|
|
|
223
230
|
logger.debug(f"Controller thread exiting: {threading.current_thread().name}")
|
|
224
231
|
|
|
225
232
|
async def _bg_get_action(
|
|
226
|
-
self, action_id:
|
|
233
|
+
self, action_id: identifier_pb2.ActionIdentifier, parent_action_name: str
|
|
227
234
|
) -> Optional[Action]:
|
|
228
235
|
"""Get the action from the informer"""
|
|
229
236
|
# Ensure the informer is created and wait for it to be ready
|
|
@@ -240,7 +247,10 @@ class Controller:
|
|
|
240
247
|
return None
|
|
241
248
|
|
|
242
249
|
async def _bg_finalize_informer(
|
|
243
|
-
self,
|
|
250
|
+
self,
|
|
251
|
+
run_id: identifier_pb2.RunIdentifier,
|
|
252
|
+
parent_action_name: str,
|
|
253
|
+
timeout: Optional[float] = None,
|
|
244
254
|
):
|
|
245
255
|
informer = await self._informers.remove(run_name=run_id.name, parent_action_name=parent_action_name)
|
|
246
256
|
if informer:
|
|
@@ -294,7 +304,10 @@ class Controller:
|
|
|
294
304
|
# )
|
|
295
305
|
logger.info(f"Successfully cancelled action: {action.name}")
|
|
296
306
|
except grpc.aio.AioRpcError as e:
|
|
297
|
-
if e.code() in [
|
|
307
|
+
if e.code() in [
|
|
308
|
+
grpc.StatusCode.NOT_FOUND,
|
|
309
|
+
grpc.StatusCode.FAILED_PRECONDITION,
|
|
310
|
+
]:
|
|
298
311
|
logger.info(f"Action {action.name} not found, assumed completed or cancelled.")
|
|
299
312
|
return
|
|
300
313
|
else:
|
|
@@ -333,10 +346,8 @@ class Controller:
|
|
|
333
346
|
spec=action.task,
|
|
334
347
|
cache_key=cache_key,
|
|
335
348
|
)
|
|
336
|
-
|
|
337
|
-
trace =
|
|
338
|
-
name=action.friendly_name,
|
|
339
|
-
)
|
|
349
|
+
elif action.type == "trace":
|
|
350
|
+
trace = action.trace
|
|
340
351
|
|
|
341
352
|
logger.debug(f"Attempting to launch action: {action.name}")
|
|
342
353
|
try:
|
|
@@ -380,7 +391,11 @@ class Controller:
|
|
|
380
391
|
async def _bg_log_stats(self):
|
|
381
392
|
"""Periodically log resource stats if debug is enabled"""
|
|
382
393
|
while self._running:
|
|
383
|
-
async for
|
|
394
|
+
async for (
|
|
395
|
+
started,
|
|
396
|
+
pending,
|
|
397
|
+
terminal,
|
|
398
|
+
) in self._informers.count_started_pending_terminal_actions():
|
|
384
399
|
logger.info(f"Resource stats: Started={started}, Pending={pending}, Terminal={terminal}")
|
|
385
400
|
await asyncio.sleep(self._resource_log_interval)
|
|
386
401
|
|
|
@@ -407,7 +422,8 @@ class Controller:
|
|
|
407
422
|
err.__cause__ = e
|
|
408
423
|
action.set_client_error(err)
|
|
409
424
|
informer = await self._informers.get(
|
|
410
|
-
run_name=action.run_name,
|
|
425
|
+
run_name=action.run_name,
|
|
426
|
+
parent_action_name=action.parent_action_name,
|
|
411
427
|
)
|
|
412
428
|
if informer:
|
|
413
429
|
await informer.fire_completion_event(action.name)
|
|
@@ -7,6 +7,7 @@ from typing import AsyncIterator, Callable, Dict, Optional, Tuple, cast
|
|
|
7
7
|
import grpc.aio
|
|
8
8
|
|
|
9
9
|
from flyte._logging import log, logger
|
|
10
|
+
from flyte._protos.common import identifier_pb2
|
|
10
11
|
from flyte._protos.workflow import run_definition_pb2, state_service_pb2
|
|
11
12
|
|
|
12
13
|
from ._action import Action
|
|
@@ -42,7 +43,9 @@ class ActionCache:
|
|
|
42
43
|
if state.output_uri:
|
|
43
44
|
logger.debug(f"Output URI: {state.output_uri}")
|
|
44
45
|
else:
|
|
45
|
-
logger.warning(
|
|
46
|
+
logger.warning(
|
|
47
|
+
f"{state.action_id.name} has no output URI, in phase {run_definition_pb2.Phase.Name(state.phase)}"
|
|
48
|
+
)
|
|
46
49
|
if state.phase == run_definition_pb2.Phase.PHASE_FAILED:
|
|
47
50
|
logger.error(
|
|
48
51
|
f"Action {state.action_id.name} failed with error (msg):"
|
|
@@ -125,7 +128,7 @@ class Informer:
|
|
|
125
128
|
|
|
126
129
|
def __init__(
|
|
127
130
|
self,
|
|
128
|
-
run_id:
|
|
131
|
+
run_id: identifier_pb2.RunIdentifier,
|
|
129
132
|
parent_action_name: str,
|
|
130
133
|
shared_queue: Queue,
|
|
131
134
|
client: Optional[StateService] = None,
|
|
@@ -217,7 +220,7 @@ class Informer:
|
|
|
217
220
|
try:
|
|
218
221
|
watcher = self._client.Watch(
|
|
219
222
|
state_service_pb2.WatchRequest(
|
|
220
|
-
parent_action_id=
|
|
223
|
+
parent_action_id=identifier_pb2.ActionIdentifier(
|
|
221
224
|
name=self.parent_action_name,
|
|
222
225
|
run=self._run_id,
|
|
223
226
|
),
|
|
@@ -288,7 +291,7 @@ class InformerCache:
|
|
|
288
291
|
@log
|
|
289
292
|
async def get_or_create(
|
|
290
293
|
self,
|
|
291
|
-
run_id:
|
|
294
|
+
run_id: identifier_pb2.RunIdentifier,
|
|
292
295
|
parent_action_name: str,
|
|
293
296
|
shared_queue: Queue,
|
|
294
297
|
state_service: StateService,
|
|
@@ -330,20 +333,28 @@ class InformerCache:
|
|
|
330
333
|
async def get(self, *, run_name: str, parent_action_name: str) -> Informer | None:
|
|
331
334
|
"""Get an informer by name"""
|
|
332
335
|
async with self._lock:
|
|
333
|
-
return self._cache.get(
|
|
336
|
+
return self._cache.get(
|
|
337
|
+
Informer.mkname(run_name=run_name, parent_action_name=parent_action_name),
|
|
338
|
+
None,
|
|
339
|
+
)
|
|
334
340
|
|
|
335
341
|
@log
|
|
336
342
|
async def remove(self, *, run_name: str, parent_action_name: str) -> Informer | None:
|
|
337
343
|
"""Remove an informer from the cache"""
|
|
338
344
|
async with self._lock:
|
|
339
|
-
return self._cache.pop(
|
|
345
|
+
return self._cache.pop(
|
|
346
|
+
Informer.mkname(run_name=run_name, parent_action_name=parent_action_name),
|
|
347
|
+
None,
|
|
348
|
+
)
|
|
340
349
|
|
|
341
350
|
async def has(self, *, run_name: str, parent_action_name: str) -> bool:
|
|
342
351
|
"""Check if an informer exists in the cache"""
|
|
343
352
|
async with self._lock:
|
|
344
353
|
return Informer.mkname(run_name=run_name, parent_action_name=parent_action_name) in self._cache
|
|
345
354
|
|
|
346
|
-
async def count_started_pending_terminal_actions(
|
|
355
|
+
async def count_started_pending_terminal_actions(
|
|
356
|
+
self,
|
|
357
|
+
) -> AsyncIterator[Tuple[int, int, int]]:
|
|
347
358
|
"""Log resource stats"""
|
|
348
359
|
async with self._lock:
|
|
349
360
|
for informer in self._cache.values():
|
|
@@ -59,7 +59,9 @@ def translate_task_to_wire(
|
|
|
59
59
|
)
|
|
60
60
|
|
|
61
61
|
|
|
62
|
-
def get_security_context(
|
|
62
|
+
def get_security_context(
|
|
63
|
+
secrets: Optional[SecretRequest],
|
|
64
|
+
) -> Optional[security_pb2.SecurityContext]:
|
|
63
65
|
"""
|
|
64
66
|
Get the security context from a list of secrets. This is a placeholder function.
|
|
65
67
|
|
|
@@ -86,7 +88,9 @@ def get_security_context(secrets: Optional[SecretRequest]) -> Optional[security_
|
|
|
86
88
|
)
|
|
87
89
|
|
|
88
90
|
|
|
89
|
-
def get_proto_retry_strategy(
|
|
91
|
+
def get_proto_retry_strategy(
|
|
92
|
+
retries: RetryStrategy | int | None,
|
|
93
|
+
) -> Optional[literals_pb2.RetryStrategy]:
|
|
90
94
|
if retries is None:
|
|
91
95
|
return None
|
|
92
96
|
|
|
@@ -158,7 +162,7 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
|
|
|
158
162
|
discoverable=cache_enabled,
|
|
159
163
|
discovery_version=cache_version,
|
|
160
164
|
cache_serializable=task_cache.serialize,
|
|
161
|
-
cache_ignore_input_vars=task_cache.get_ignored_inputs() if cache_enabled else None,
|
|
165
|
+
cache_ignore_input_vars=(task_cache.get_ignored_inputs() if cache_enabled else None),
|
|
162
166
|
runtime=tasks_pb2.RuntimeMetadata(
|
|
163
167
|
version=flyte.version(),
|
|
164
168
|
type=tasks_pb2.RuntimeMetadata.RuntimeType.FLYTE_SDK,
|
|
@@ -166,7 +170,7 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
|
|
|
166
170
|
),
|
|
167
171
|
retries=get_proto_retry_strategy(task.retries),
|
|
168
172
|
timeout=get_proto_timeout(task.timeout),
|
|
169
|
-
pod_template_name=task.pod_template if task.pod_template and isinstance(task.pod_template, str) else None,
|
|
173
|
+
pod_template_name=(task.pod_template if task.pod_template and isinstance(task.pod_template, str) else None),
|
|
170
174
|
interruptible=task.interruptable,
|
|
171
175
|
generates_deck=wrappers_pb2.BoolValue(value=task.report),
|
|
172
176
|
),
|
|
@@ -300,7 +304,9 @@ def _get_k8s_pod(primary_container: tasks_pb2.Container, pod_template: PodTempla
|
|
|
300
304
|
return tasks_pb2.K8sPod(pod_spec=pod_spec, metadata=metadata)
|
|
301
305
|
|
|
302
306
|
|
|
303
|
-
def extract_code_bundle(
|
|
307
|
+
def extract_code_bundle(
|
|
308
|
+
task_spec: task_definition_pb2.TaskSpec,
|
|
309
|
+
) -> Optional[CodeBundle]:
|
|
304
310
|
"""
|
|
305
311
|
Extract the code bundle from the task spec.
|
|
306
312
|
:param task_spec: The task spec to extract the code bundle from.
|
|
@@ -326,5 +332,10 @@ def extract_code_bundle(task_spec: task_definition_pb2.TaskSpec) -> Optional[Cod
|
|
|
326
332
|
# Extract the version from the argument
|
|
327
333
|
version = container.args[i + 1] if i + 1 < len(container.args) else ""
|
|
328
334
|
if pkl_path or tgz_path:
|
|
329
|
-
return CodeBundle(
|
|
335
|
+
return CodeBundle(
|
|
336
|
+
destination=dest_path,
|
|
337
|
+
tgz=tgz_path,
|
|
338
|
+
pkl=pkl_path,
|
|
339
|
+
computed_version=version,
|
|
340
|
+
)
|
|
330
341
|
return None
|
|
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
|
|
|
14
14
|
from flyte._protos.validate.validate import validate_pb2 as validate_dot_validate__pb2
|
|
15
15
|
|
|
16
16
|
|
|
17
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x63ommon/identifier.proto\x12\x0f\x63loudidl.common\x1a\x17validate/validate.proto\"~\n\x11ProjectIdentifier\x12+\n\x0corganization\x18\x01 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x0corganization\x12\x1f\n\x06\x64omain\x18\x02 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x06\x64omain\x12\x1b\n\x04name\x18\x03 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x04name\"T\n\x11\x43lusterIdentifier\x12\"\n\x0corganization\x18\x01 \x01(\tR\x0corganization\x12\x1b\n\x04name\x18\x02 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x04name\"O\n\x15\x43lusterPoolIdentifier\x12\"\n\x0corganization\x18\x01 \x01(\tR\x0corganization\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\"_\n\x17\x43lusterConfigIdentifier\x12+\n\x0corganization\x18\x01 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x0corganization\x12\x17\n\x02id\x18\x02 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x02id\"3\n\x0eUserIdentifier\x12!\n\x07subject\x18\x01 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x07subject\":\n\x15\x41pplicationIdentifier\x12!\n\x07subject\x18\x01 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x07subject\"Q\n\x0eRoleIdentifier\x12\"\n\x0corganization\x18\x01 \x01(\tR\x0corganization\x12\x1b\n\x04name\x18\x02 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x04name\"O\n\rOrgIdentifier\x12>\n\x04name\x18\x01 \x01(\tB*\xfa\x42\'r%\x10\x01\x18?2\x1f^[a-z0-9]([-a-z0-9]*[a-z0-9])?$R\x04name\"y\n\x18ManagedClusterIdentifier\x12\x1b\n\x04name\x18\x02 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x04name\x12:\n\x03org\x18\x03 \x01(\x0b\x32\x1e.cloudidl.common.OrgIdentifierB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x03orgJ\x04\x08\x01\x10\x02\"S\n\x10PolicyIdentifier\x12\"\n\x0corganization\x18\x01 \x01(\tR\x0corganization\x12\x1b\n\x04name\x18\x02 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\
|
|
17
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x63ommon/identifier.proto\x12\x0f\x63loudidl.common\x1a\x17validate/validate.proto\"~\n\x11ProjectIdentifier\x12+\n\x0corganization\x18\x01 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x0corganization\x12\x1f\n\x06\x64omain\x18\x02 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x06\x64omain\x12\x1b\n\x04name\x18\x03 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x04name\"T\n\x11\x43lusterIdentifier\x12\"\n\x0corganization\x18\x01 \x01(\tR\x0corganization\x12\x1b\n\x04name\x18\x02 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x04name\"O\n\x15\x43lusterPoolIdentifier\x12\"\n\x0corganization\x18\x01 \x01(\tR\x0corganization\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\"_\n\x17\x43lusterConfigIdentifier\x12+\n\x0corganization\x18\x01 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x0corganization\x12\x17\n\x02id\x18\x02 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x02id\"3\n\x0eUserIdentifier\x12!\n\x07subject\x18\x01 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x07subject\":\n\x15\x41pplicationIdentifier\x12!\n\x07subject\x18\x01 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x07subject\"Q\n\x0eRoleIdentifier\x12\"\n\x0corganization\x18\x01 \x01(\tR\x0corganization\x12\x1b\n\x04name\x18\x02 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x04name\"O\n\rOrgIdentifier\x12>\n\x04name\x18\x01 \x01(\tB*\xfa\x42\'r%\x10\x01\x18?2\x1f^[a-z0-9]([-a-z0-9]*[a-z0-9])?$R\x04name\"y\n\x18ManagedClusterIdentifier\x12\x1b\n\x04name\x18\x02 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x04name\x12:\n\x03org\x18\x03 \x01(\x0b\x32\x1e.cloudidl.common.OrgIdentifierB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x03orgJ\x04\x08\x01\x10\x02\"S\n\x10PolicyIdentifier\x12\"\n\x0corganization\x18\x01 \x01(\tR\x0corganization\x12\x1b\n\x04name\x18\x02 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x04name\"\x93\x01\n\rRunIdentifier\x12\x1b\n\x03org\x18\x01 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x03org\x12#\n\x07project\x18\x02 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x07project\x12!\n\x06\x64omain\x18\x03 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x06\x64omain\x12\x1d\n\x04name\x18\x04 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18\x1eR\x04name\"m\n\x10\x41\x63tionIdentifier\x12:\n\x03run\x18\x01 \x01(\x0b\x32\x1e.cloudidl.common.RunIdentifierB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x03run\x12\x1d\n\x04name\x18\x02 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18\x1eR\x04name\"\x86\x01\n\x17\x41\x63tionAttemptIdentifier\x12H\n\taction_id\x18\x01 \x01(\x0b\x32!.cloudidl.common.ActionIdentifierB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x08\x61\x63tionId\x12!\n\x07\x61ttempt\x18\x02 \x01(\rB\x07\xfa\x42\x04*\x02 \x00R\x07\x61ttemptB\xb0\x01\n\x13\x63om.cloudidl.commonB\x0fIdentifierProtoH\x02P\x01Z)github.com/unionai/cloud/gen/pb-go/common\xa2\x02\x03\x43\x43X\xaa\x02\x0f\x43loudidl.Common\xca\x02\x0f\x43loudidl\\Common\xe2\x02\x1b\x43loudidl\\Common\\GPBMetadata\xea\x02\x10\x43loudidl::Commonb\x06proto3')
|
|
18
18
|
|
|
19
19
|
_globals = globals()
|
|
20
20
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
@@ -48,6 +48,22 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|
|
48
48
|
_MANAGEDCLUSTERIDENTIFIER.fields_by_name['org']._serialized_options = b'\372B\005\212\001\002\020\001'
|
|
49
49
|
_POLICYIDENTIFIER.fields_by_name['name']._options = None
|
|
50
50
|
_POLICYIDENTIFIER.fields_by_name['name']._serialized_options = b'\372B\004r\002\020\001'
|
|
51
|
+
_RUNIDENTIFIER.fields_by_name['org']._options = None
|
|
52
|
+
_RUNIDENTIFIER.fields_by_name['org']._serialized_options = b'\372B\006r\004\020\001\030?'
|
|
53
|
+
_RUNIDENTIFIER.fields_by_name['project']._options = None
|
|
54
|
+
_RUNIDENTIFIER.fields_by_name['project']._serialized_options = b'\372B\006r\004\020\001\030?'
|
|
55
|
+
_RUNIDENTIFIER.fields_by_name['domain']._options = None
|
|
56
|
+
_RUNIDENTIFIER.fields_by_name['domain']._serialized_options = b'\372B\006r\004\020\001\030?'
|
|
57
|
+
_RUNIDENTIFIER.fields_by_name['name']._options = None
|
|
58
|
+
_RUNIDENTIFIER.fields_by_name['name']._serialized_options = b'\372B\006r\004\020\001\030\036'
|
|
59
|
+
_ACTIONIDENTIFIER.fields_by_name['run']._options = None
|
|
60
|
+
_ACTIONIDENTIFIER.fields_by_name['run']._serialized_options = b'\372B\005\212\001\002\020\001'
|
|
61
|
+
_ACTIONIDENTIFIER.fields_by_name['name']._options = None
|
|
62
|
+
_ACTIONIDENTIFIER.fields_by_name['name']._serialized_options = b'\372B\006r\004\020\001\030\036'
|
|
63
|
+
_ACTIONATTEMPTIDENTIFIER.fields_by_name['action_id']._options = None
|
|
64
|
+
_ACTIONATTEMPTIDENTIFIER.fields_by_name['action_id']._serialized_options = b'\372B\005\212\001\002\020\001'
|
|
65
|
+
_ACTIONATTEMPTIDENTIFIER.fields_by_name['attempt']._options = None
|
|
66
|
+
_ACTIONATTEMPTIDENTIFIER.fields_by_name['attempt']._serialized_options = b'\372B\004*\002 \000'
|
|
51
67
|
_globals['_PROJECTIDENTIFIER']._serialized_start=69
|
|
52
68
|
_globals['_PROJECTIDENTIFIER']._serialized_end=195
|
|
53
69
|
_globals['_CLUSTERIDENTIFIER']._serialized_start=197
|
|
@@ -68,4 +84,10 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|
|
68
84
|
_globals['_MANAGEDCLUSTERIDENTIFIER']._serialized_end=859
|
|
69
85
|
_globals['_POLICYIDENTIFIER']._serialized_start=861
|
|
70
86
|
_globals['_POLICYIDENTIFIER']._serialized_end=944
|
|
87
|
+
_globals['_RUNIDENTIFIER']._serialized_start=947
|
|
88
|
+
_globals['_RUNIDENTIFIER']._serialized_end=1094
|
|
89
|
+
_globals['_ACTIONIDENTIFIER']._serialized_start=1096
|
|
90
|
+
_globals['_ACTIONIDENTIFIER']._serialized_end=1205
|
|
91
|
+
_globals['_ACTIONATTEMPTIDENTIFIER']._serialized_start=1208
|
|
92
|
+
_globals['_ACTIONATTEMPTIDENTIFIER']._serialized_end=1342
|
|
71
93
|
# @@protoc_insertion_point(module_scope)
|