flyte 0.2.0b35__py3-none-any.whl → 0.2.0b37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (39)
  1. flyte/_image.py +1 -1
  2. flyte/_internal/controllers/_local_controller.py +3 -2
  3. flyte/_internal/controllers/_trace.py +14 -10
  4. flyte/_internal/controllers/remote/_action.py +37 -7
  5. flyte/_internal/controllers/remote/_controller.py +43 -21
  6. flyte/_internal/controllers/remote/_core.py +32 -16
  7. flyte/_internal/controllers/remote/_informer.py +18 -7
  8. flyte/_internal/runtime/task_serde.py +17 -6
  9. flyte/_protos/common/identifier_pb2.py +23 -1
  10. flyte/_protos/common/identifier_pb2.pyi +28 -0
  11. flyte/_protos/workflow/queue_service_pb2.py +33 -29
  12. flyte/_protos/workflow/queue_service_pb2.pyi +34 -16
  13. flyte/_protos/workflow/run_definition_pb2.py +64 -71
  14. flyte/_protos/workflow/run_definition_pb2.pyi +44 -31
  15. flyte/_protos/workflow/run_logs_service_pb2.py +10 -10
  16. flyte/_protos/workflow/run_logs_service_pb2.pyi +3 -3
  17. flyte/_protos/workflow/run_service_pb2.py +54 -46
  18. flyte/_protos/workflow/run_service_pb2.pyi +32 -18
  19. flyte/_protos/workflow/run_service_pb2_grpc.py +34 -0
  20. flyte/_protos/workflow/state_service_pb2.py +20 -19
  21. flyte/_protos/workflow/state_service_pb2.pyi +13 -12
  22. flyte/_run.py +11 -6
  23. flyte/_trace.py +4 -10
  24. flyte/_version.py +2 -2
  25. flyte/migrate/__init__.py +1 -0
  26. flyte/migrate/dynamic.py +13 -0
  27. flyte/migrate/task.py +99 -0
  28. flyte/migrate/workflow.py +13 -0
  29. flyte/remote/_action.py +56 -25
  30. flyte/remote/_logs.py +4 -3
  31. flyte/remote/_run.py +5 -4
  32. flyte-0.2.0b37.dist-info/METADATA +371 -0
  33. {flyte-0.2.0b35.dist-info → flyte-0.2.0b37.dist-info}/RECORD +38 -33
  34. flyte-0.2.0b37.dist-info/licenses/LICENSE +201 -0
  35. flyte-0.2.0b35.dist-info/METADATA +0 -249
  36. {flyte-0.2.0b35.data → flyte-0.2.0b37.data}/scripts/runtime.py +0 -0
  37. {flyte-0.2.0b35.dist-info → flyte-0.2.0b37.dist-info}/WHEEL +0 -0
  38. {flyte-0.2.0b35.dist-info → flyte-0.2.0b37.dist-info}/entry_points.txt +0 -0
  39. {flyte-0.2.0b35.dist-info → flyte-0.2.0b37.dist-info}/top_level.txt +0 -0
flyte/_image.py CHANGED
@@ -262,7 +262,7 @@ class Env(Layer):
262
262
 
263
263
  Architecture = Literal["linux/amd64", "linux/arm64"]
264
264
 
265
- _BASE_REGISTRY = "ghcr.io/unionai-oss"
265
+ _BASE_REGISTRY = "ghcr.io/flyteorg"
266
266
  _DEFAULT_IMAGE_NAME = "flyte"
267
267
 
268
268
 
flyte/_internal/controllers/_local_controller.py CHANGED
@@ -161,10 +161,10 @@ class LocalController:
161
161
  assert action_output_path
162
162
  return (
163
163
  TraceInfo(
164
+ name=_func.__name__,
164
165
  action=action_id,
165
166
  interface=_interface,
166
167
  inputs_path=action_output_path,
167
- name=_func.__name__,
168
168
  ),
169
169
  True,
170
170
  )
@@ -189,7 +189,8 @@ class LocalController:
189
189
  converted_error = convert.convert_from_native_to_error(info.error)
190
190
  assert converted_error
191
191
  assert info.action
192
- assert info.duration
192
+ assert info.start_time
193
+ assert info.end_time
193
194
 
194
195
  async def submit_task_ref(self, _task: task_definition_pb2.TaskDetails, *args, **kwargs) -> Any:
195
196
  raise flyte.errors.ReferenceTaskError("Reference tasks cannot be executed locally, only remotely.")
flyte/_internal/controllers/_trace.py CHANGED
@@ -1,5 +1,4 @@
1
- from dataclasses import dataclass
2
- from datetime import timedelta
1
+ from dataclasses import dataclass, field
3
2
  from typing import Any, Optional
4
3
 
5
4
  from flyte.models import ActionID, NativeInterface
@@ -12,30 +11,35 @@ class TraceInfo:
12
11
  the action is completed.
13
12
  """
14
13
 
14
+ name: str
15
15
  action: ActionID
16
16
  interface: NativeInterface
17
17
  inputs_path: str
18
- duration: Optional[timedelta] = None
18
+ start_time: float = field(init=False, default=0.0)
19
+ end_time: float = field(init=False, default=0.0)
19
20
  output: Optional[Any] = None
20
21
  error: Optional[Exception] = None
21
- name: str = ""
22
22
 
23
- def add_outputs(self, output: Any, duration: timedelta):
23
+ def add_outputs(self, output: Any, start_time: float, end_time: float):
24
24
  """
25
25
  Add outputs to the trace information.
26
26
  :param output: Output of the action
27
- :param duration: Duration of the action
27
+ :param start_time: Start time of the action
28
+ :param end_time: End time of the action
28
29
  :return:
29
30
  """
30
31
  self.output = output
31
- self.duration = duration
32
+ self.start_time = start_time
33
+ self.end_time = end_time
32
34
 
33
- def add_error(self, error: Exception, duration: timedelta):
35
+ def add_error(self, error: Exception, start_time: float, end_time: float):
34
36
  """
35
37
  Add error to the trace information.
36
38
  :param error: Error of the action
37
- :param duration: Duration of the action
39
+ :param start_time: Start time of the action
40
+ :param end_time: End time of the action
38
41
  :return:
39
42
  """
40
43
  self.error = error
41
- self.duration = duration
44
+ self.start_time = start_time
45
+ self.end_time = end_time
flyte/_internal/controllers/remote/_action.py CHANGED
@@ -1,11 +1,18 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import Any, Literal
4
+ from typing import Literal
5
5
 
6
6
  from flyteidl.core import execution_pb2
7
-
8
- from flyte._protos.workflow import run_definition_pb2, state_service_pb2, task_definition_pb2
7
+ from google.protobuf import timestamp_pb2
8
+
9
+ from flyte._protos.common import identifier_pb2
10
+ from flyte._protos.workflow import (
11
+ queue_service_pb2,
12
+ run_definition_pb2,
13
+ state_service_pb2,
14
+ task_definition_pb2,
15
+ )
9
16
  from flyte.models import GroupData
10
17
 
11
18
  ActionType = Literal["task", "trace"]
@@ -18,12 +25,13 @@ class Action:
18
25
  Holds the inmemory state of a task. It is combined representation of local and remote states.
19
26
  """
20
27
 
21
- action_id: run_definition_pb2.ActionIdentifier
28
+ action_id: identifier_pb2.ActionIdentifier
22
29
  parent_action_name: str
23
30
  type: ActionType = "task" # type of action, task or trace
24
31
  friendly_name: str | None = None
25
32
  group: GroupData | None = None
26
33
  task: task_definition_pb2.TaskSpec | None = None
34
+ trace: queue_service_pb2.TraceAction | None = None
27
35
  inputs_uri: str | None = None
28
36
  run_output_base: str | None = None
29
37
  realized_outputs_uri: str | None = None
@@ -108,7 +116,7 @@ class Action:
108
116
  def from_task(
109
117
  cls,
110
118
  parent_action_name: str,
111
- sub_action_id: run_definition_pb2.ActionIdentifier,
119
+ sub_action_id: identifier_pb2.ActionIdentifier,
112
120
  group_data: GroupData | None,
113
121
  task_spec: task_definition_pb2.TaskSpec,
114
122
  inputs_uri: str,
@@ -153,16 +161,27 @@ class Action:
153
161
  def from_trace(
154
162
  cls,
155
163
  parent_action_name: str,
156
- action_id: run_definition_pb2.ActionIdentifier,
164
+ action_id: identifier_pb2.ActionIdentifier,
157
165
  friendly_name: str,
158
166
  group_data: GroupData | None,
159
- trace_spec: Any, # TODO
160
167
  inputs_uri: str,
161
168
  outputs_uri: str,
169
+ start_time: float, # Unix timestamp in seconds with fractional seconds
170
+ end_time: float, # Unix timestamp in seconds with fractional seconds
171
+ run_output_base: str,
172
+ report_uri: str | None = None,
162
173
  ) -> Action:
163
174
  """
164
175
  This creates a new action for tracing purposes. It is used to track the execution of a trace.
165
176
  """
177
+ st = timestamp_pb2.Timestamp()
178
+ st.FromSeconds(int(start_time))
179
+ st.nanos = int((start_time % 1) * 1e9)
180
+
181
+ et = timestamp_pb2.Timestamp()
182
+ et.FromSeconds(int(end_time))
183
+ et.nanos = int((end_time % 1) * 1e9)
184
+
166
185
  return cls(
167
186
  action_id=action_id,
168
187
  parent_action_name=parent_action_name,
@@ -172,4 +191,15 @@ class Action:
172
191
  inputs_uri=inputs_uri,
173
192
  realized_outputs_uri=outputs_uri,
174
193
  phase=run_definition_pb2.Phase.PHASE_SUCCEEDED,
194
+ run_output_base=run_output_base,
195
+ trace=queue_service_pb2.TraceAction(
196
+ name=friendly_name,
197
+ phase=run_definition_pb2.Phase.PHASE_SUCCEEDED,
198
+ start_time=st,
199
+ end_time=et,
200
+ outputs=run_definition_pb2.OutputReferences(
201
+ output_uri=outputs_uri,
202
+ report_uri=report_uri,
203
+ ),
204
+ ),
175
205
  )
flyte/_internal/controllers/remote/_controller.py CHANGED
@@ -22,6 +22,7 @@ from flyte._internal.controllers.remote._service_protocol import ClientSet
22
22
  from flyte._internal.runtime import convert, io
23
23
  from flyte._internal.runtime.task_serde import translate_task_to_wire
24
24
  from flyte._logging import logger
25
+ from flyte._protos.common import identifier_pb2
25
26
  from flyte._protos.workflow import run_definition_pb2, task_definition_pb2
26
27
  from flyte._task import TaskTemplate
27
28
  from flyte._utils.helpers import _selector_policy
@@ -131,11 +132,18 @@ class RemoteController(Controller):
131
132
  Generate a task call sequence for the given task object and action ID.
132
133
  This is used to track the number of times a task is called within an action.
133
134
  """
134
- current_action_sequencer = self._parent_action_task_call_sequence[unique_action_name(action_id)]
135
+ uniq = unique_action_name(action_id)
136
+ current_action_sequencer = self._parent_action_task_call_sequence[uniq]
135
137
  current_task_id = id(task_obj)
136
138
  v = current_action_sequencer[current_task_id]
137
139
  new_seq = v + 1
138
140
  current_action_sequencer[current_task_id] = new_seq
141
+ name = ""
142
+ if hasattr(task_obj, "__name__"):
143
+ name = task_obj.__name__
144
+ elif hasattr(task_obj, "name"):
145
+ name = task_obj.name
146
+ logger.warning(f"For action {uniq}, task {name} call sequence is {new_seq}")
139
147
  return new_seq
140
148
 
141
149
  async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwargs) -> Any:
@@ -178,6 +186,7 @@ class RemoteController(Controller):
178
186
  sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
179
187
  tctx, task_spec, inputs_hash, _task_call_seq
180
188
  )
189
+ logger.warning(f"Sub action {sub_action_id} output path {sub_action_output_path}")
181
190
 
182
191
  serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
183
192
  inputs_uri = io.inputs_path(sub_action_output_path)
@@ -204,9 +213,9 @@ class RemoteController(Controller):
204
213
  inputs_hash = None # type: ignore
205
214
 
206
215
  action = Action.from_task(
207
- sub_action_id=run_definition_pb2.ActionIdentifier(
216
+ sub_action_id=identifier_pb2.ActionIdentifier(
208
217
  name=sub_action_id.name,
209
- run=run_definition_pb2.RunIdentifier(
218
+ run=identifier_pb2.RunIdentifier(
210
219
  name=current_action_id.run_name,
211
220
  project=current_action_id.project,
212
221
  domain=current_action_id.domain,
@@ -293,7 +302,9 @@ class RemoteController(Controller):
293
302
  self._submit_loop.set_exception_handler(exc_handler)
294
303
 
295
304
  self._submit_thread = threading.Thread(
296
- name=f"remote-controller-{os.getpid()}-submitter", daemon=True, target=self._sync_thread_loop_runner
305
+ name=f"remote-controller-{os.getpid()}-submitter",
306
+ daemon=True,
307
+ target=self._sync_thread_loop_runner,
297
308
  )
298
309
  self._submit_thread.start()
299
310
 
@@ -307,7 +318,7 @@ class RemoteController(Controller):
307
318
  This method is invoked when the parent action is finished. It will finalize the run and upload the outputs
308
319
  to the control plane.
309
320
  """
310
- run_id = run_definition_pb2.RunIdentifier(
321
+ run_id = identifier_pb2.RunIdentifier(
311
322
  name=action_id.run_name,
312
323
  project=action_id.project,
313
324
  domain=action_id.domain,
@@ -350,9 +361,9 @@ class RemoteController(Controller):
350
361
  serialized_inputs = None # type: ignore
351
362
 
352
363
  prev_action = await self.get_action(
353
- run_definition_pb2.ActionIdentifier(
364
+ identifier_pb2.ActionIdentifier(
354
365
  name=sub_action_id.name,
355
- run=run_definition_pb2.RunIdentifier(
366
+ run=identifier_pb2.RunIdentifier(
356
367
  name=current_action_id.run_name,
357
368
  project=current_action_id.project,
358
369
  domain=current_action_id.domain,
@@ -363,21 +374,27 @@ class RemoteController(Controller):
363
374
  )
364
375
 
365
376
  if prev_action is None:
366
- return TraceInfo(sub_action_id, _interface, inputs_uri), False
377
+ return TraceInfo(func_name, sub_action_id, _interface, inputs_uri), False
367
378
 
368
379
  if prev_action.phase == run_definition_pb2.PHASE_FAILED:
369
380
  if prev_action.has_error():
370
381
  exc = convert.convert_error_to_native(prev_action.err)
371
- return TraceInfo(sub_action_id, _interface, inputs_uri, error=exc), True
382
+ return (
383
+ TraceInfo(func_name, sub_action_id, _interface, inputs_uri, error=exc),
384
+ True,
385
+ )
372
386
  else:
373
387
  logger.warning(f"Action {prev_action.action_id.name} failed, but no error was found, re-running trace!")
374
388
  elif prev_action.realized_outputs_uri is not None:
375
389
  outputs_file_path = io.outputs_path(prev_action.realized_outputs_uri)
376
390
  o = await io.load_outputs(outputs_file_path)
377
391
  outputs = await convert.convert_outputs_to_native(_interface, o)
378
- return TraceInfo(sub_action_id, _interface, inputs_uri, output=outputs), True
392
+ return (
393
+ TraceInfo(func_name, sub_action_id, _interface, inputs_uri, output=outputs),
394
+ True,
395
+ )
379
396
 
380
- return TraceInfo(sub_action_id, _interface, inputs_uri), False
397
+ return TraceInfo(func_name, sub_action_id, _interface, inputs_uri), False
381
398
 
382
399
  async def record_trace(self, info: TraceInfo):
383
400
  """
@@ -391,26 +408,29 @@ class RemoteController(Controller):
391
408
  raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
392
409
 
393
410
  current_action_id = tctx.action
394
- current_output_path = tctx.output_path
395
- sub_run_output_path = storage.join(current_output_path, info.action.name)
411
+ sub_run_output_path = storage.join(tctx.run_base_dir, info.action.name)
412
+ print(f"Sub run output path for {info.name} is {sub_run_output_path}", flush=True)
396
413
 
397
414
  if info.interface.has_outputs():
398
415
  outputs_file_path: str = ""
399
416
  if info.output:
400
417
  outputs = await convert.convert_from_native_to_outputs(info.output, info.interface)
401
418
  outputs_file_path = io.outputs_path(sub_run_output_path)
402
- await io.upload_outputs(outputs, outputs_file_path)
419
+ print(
420
+ f"Uploading outputs for {info.name} Outputs file path: {outputs_file_path}",
421
+ flush=True,
422
+ )
423
+ await io.upload_outputs(outputs, sub_run_output_path)
403
424
  elif info.error:
404
425
  err = convert.convert_from_native_to_error(info.error)
405
- error_path = io.error_path(sub_run_output_path)
406
- await io.upload_error(err.err, error_path)
426
+ await io.upload_error(err.err, sub_run_output_path)
407
427
  else:
408
428
  raise flyte.errors.RuntimeSystemError("BadTraceInfo", "Trace info does not have output or error")
409
429
  trace_action = Action.from_trace(
410
430
  parent_action_name=current_action_id.name,
411
- action_id=run_definition_pb2.ActionIdentifier(
431
+ action_id=identifier_pb2.ActionIdentifier(
412
432
  name=info.action.name,
413
- run=run_definition_pb2.RunIdentifier(
433
+ run=identifier_pb2.RunIdentifier(
414
434
  name=current_action_id.run_name,
415
435
  project=current_action_id.project,
416
436
  domain=current_action_id.domain,
@@ -419,9 +439,11 @@ class RemoteController(Controller):
419
439
  ),
420
440
  inputs_uri=info.inputs_path,
421
441
  outputs_uri=outputs_file_path,
422
- trace_spec="",
423
442
  friendly_name=info.name,
424
443
  group_data=tctx.group_data,
444
+ run_output_base=tctx.run_base_dir,
445
+ start_time=info.start_time,
446
+ end_time=info.end_time,
425
447
  )
426
448
  try:
427
449
  logger.info(
@@ -478,9 +500,9 @@ class RemoteController(Controller):
478
500
  inputs_hash = None # type: ignore
479
501
 
480
502
  action = Action.from_task(
481
- sub_action_id=run_definition_pb2.ActionIdentifier(
503
+ sub_action_id=identifier_pb2.ActionIdentifier(
482
504
  name=sub_action_id.name,
483
- run=run_definition_pb2.RunIdentifier(
505
+ run=identifier_pb2.RunIdentifier(
484
506
  name=current_action_id.run_name,
485
507
  project=current_action_id.project,
486
508
  domain=current_action_id.domain,
flyte/_internal/controllers/remote/_core.py CHANGED
@@ -11,7 +11,11 @@ from google.protobuf.wrappers_pb2 import StringValue
11
11
 
12
12
  import flyte.errors
13
13
  from flyte._logging import log, logger
14
- from flyte._protos.workflow import queue_service_pb2, run_definition_pb2, task_definition_pb2
14
+ from flyte._protos.common import identifier_pb2
15
+ from flyte._protos.workflow import (
16
+ queue_service_pb2,
17
+ task_definition_pb2,
18
+ )
15
19
  from flyte.errors import RuntimeSystemError
16
20
 
17
21
  from ._action import Action
@@ -80,9 +84,7 @@ class Controller:
80
84
  """Public API to submit a resource and wait for completion"""
81
85
  return await self._run_coroutine_in_controller_thread(self._bg_submit_action(action))
82
86
 
83
- async def get_action(
84
- self, action_id: run_definition_pb2.ActionIdentifier, parent_action_name: str
85
- ) -> Optional[Action]:
87
+ async def get_action(self, action_id: identifier_pb2.ActionIdentifier, parent_action_name: str) -> Optional[Action]:
86
88
  """Get the action from the informer"""
87
89
  informer = await self._informers.get(run_name=action_id.run.name, parent_action_name=parent_action_name)
88
90
  if informer:
@@ -94,7 +96,10 @@ class Controller:
94
96
  return await self._run_coroutine_in_controller_thread(self._bg_cancel_action(action))
95
97
 
96
98
  async def _finalize_parent_action(
97
- self, run_id: run_definition_pb2.RunIdentifier, parent_action_name: str, timeout: Optional[float] = None
99
+ self,
100
+ run_id: identifier_pb2.RunIdentifier,
101
+ parent_action_name: str,
102
+ timeout: Optional[float] = None,
98
103
  ):
99
104
  """Finalize the parent run"""
100
105
  await self._run_coroutine_in_controller_thread(
@@ -124,7 +129,8 @@ class Controller:
124
129
  """Watch for errors in the background thread"""
125
130
  await self._run_coroutine_in_controller_thread(self._bg_watch_for_errors())
126
131
  raise RuntimeSystemError(
127
- code="InformerWatchFailure", message=f"Controller thread failed with exception: {self._get_exception()}"
132
+ code="InformerWatchFailure",
133
+ message=f"Controller thread failed with exception: {self._get_exception()}",
128
134
  )
129
135
 
130
136
  @log
@@ -162,7 +168,8 @@ class Controller:
162
168
 
163
169
  if self._get_exception():
164
170
  raise RuntimeSystemError(
165
- type(self._get_exception()).__name__, f"Controller thread startup failed: {self._get_exception()}"
171
+ type(self._get_exception()).__name__,
172
+ f"Controller thread startup failed: {self._get_exception()}",
166
173
  )
167
174
 
168
175
  logger.info(f"Controller started in thread: {self._thread.name}")
@@ -223,7 +230,7 @@ class Controller:
223
230
  logger.debug(f"Controller thread exiting: {threading.current_thread().name}")
224
231
 
225
232
  async def _bg_get_action(
226
- self, action_id: run_definition_pb2.ActionIdentifier, parent_action_name: str
233
+ self, action_id: identifier_pb2.ActionIdentifier, parent_action_name: str
227
234
  ) -> Optional[Action]:
228
235
  """Get the action from the informer"""
229
236
  # Ensure the informer is created and wait for it to be ready
@@ -240,7 +247,10 @@ class Controller:
240
247
  return None
241
248
 
242
249
  async def _bg_finalize_informer(
243
- self, run_id: run_definition_pb2.RunIdentifier, parent_action_name: str, timeout: Optional[float] = None
250
+ self,
251
+ run_id: identifier_pb2.RunIdentifier,
252
+ parent_action_name: str,
253
+ timeout: Optional[float] = None,
244
254
  ):
245
255
  informer = await self._informers.remove(run_name=run_id.name, parent_action_name=parent_action_name)
246
256
  if informer:
@@ -294,7 +304,10 @@ class Controller:
294
304
  # )
295
305
  logger.info(f"Successfully cancelled action: {action.name}")
296
306
  except grpc.aio.AioRpcError as e:
297
- if e.code() in [grpc.StatusCode.NOT_FOUND, grpc.StatusCode.FAILED_PRECONDITION]:
307
+ if e.code() in [
308
+ grpc.StatusCode.NOT_FOUND,
309
+ grpc.StatusCode.FAILED_PRECONDITION,
310
+ ]:
298
311
  logger.info(f"Action {action.name} not found, assumed completed or cancelled.")
299
312
  return
300
313
  else:
@@ -333,10 +346,8 @@ class Controller:
333
346
  spec=action.task,
334
347
  cache_key=cache_key,
335
348
  )
336
- else:
337
- trace = queue_service_pb2.TraceAction(
338
- name=action.friendly_name,
339
- )
349
+ elif action.type == "trace":
350
+ trace = action.trace
340
351
 
341
352
  logger.debug(f"Attempting to launch action: {action.name}")
342
353
  try:
@@ -380,7 +391,11 @@ class Controller:
380
391
  async def _bg_log_stats(self):
381
392
  """Periodically log resource stats if debug is enabled"""
382
393
  while self._running:
383
- async for started, pending, terminal in self._informers.count_started_pending_terminal_actions():
394
+ async for (
395
+ started,
396
+ pending,
397
+ terminal,
398
+ ) in self._informers.count_started_pending_terminal_actions():
384
399
  logger.info(f"Resource stats: Started={started}, Pending={pending}, Terminal={terminal}")
385
400
  await asyncio.sleep(self._resource_log_interval)
386
401
 
@@ -407,7 +422,8 @@ class Controller:
407
422
  err.__cause__ = e
408
423
  action.set_client_error(err)
409
424
  informer = await self._informers.get(
410
- run_name=action.run_name, parent_action_name=action.parent_action_name
425
+ run_name=action.run_name,
426
+ parent_action_name=action.parent_action_name,
411
427
  )
412
428
  if informer:
413
429
  await informer.fire_completion_event(action.name)
flyte/_internal/controllers/remote/_informer.py CHANGED
@@ -7,6 +7,7 @@ from typing import AsyncIterator, Callable, Dict, Optional, Tuple, cast
7
7
  import grpc.aio
8
8
 
9
9
  from flyte._logging import log, logger
10
+ from flyte._protos.common import identifier_pb2
10
11
  from flyte._protos.workflow import run_definition_pb2, state_service_pb2
11
12
 
12
13
  from ._action import Action
@@ -42,7 +43,9 @@ class ActionCache:
42
43
  if state.output_uri:
43
44
  logger.debug(f"Output URI: {state.output_uri}")
44
45
  else:
45
- logger.warning(f"{state.action_id.name} has no output URI")
46
+ logger.warning(
47
+ f"{state.action_id.name} has no output URI, in phase {run_definition_pb2.Phase.Name(state.phase)}"
48
+ )
46
49
  if state.phase == run_definition_pb2.Phase.PHASE_FAILED:
47
50
  logger.error(
48
51
  f"Action {state.action_id.name} failed with error (msg):"
@@ -125,7 +128,7 @@ class Informer:
125
128
 
126
129
  def __init__(
127
130
  self,
128
- run_id: run_definition_pb2.RunIdentifier,
131
+ run_id: identifier_pb2.RunIdentifier,
129
132
  parent_action_name: str,
130
133
  shared_queue: Queue,
131
134
  client: Optional[StateService] = None,
@@ -217,7 +220,7 @@ class Informer:
217
220
  try:
218
221
  watcher = self._client.Watch(
219
222
  state_service_pb2.WatchRequest(
220
- parent_action_id=run_definition_pb2.ActionIdentifier(
223
+ parent_action_id=identifier_pb2.ActionIdentifier(
221
224
  name=self.parent_action_name,
222
225
  run=self._run_id,
223
226
  ),
@@ -288,7 +291,7 @@ class InformerCache:
288
291
  @log
289
292
  async def get_or_create(
290
293
  self,
291
- run_id: run_definition_pb2.RunIdentifier,
294
+ run_id: identifier_pb2.RunIdentifier,
292
295
  parent_action_name: str,
293
296
  shared_queue: Queue,
294
297
  state_service: StateService,
@@ -330,20 +333,28 @@ class InformerCache:
330
333
  async def get(self, *, run_name: str, parent_action_name: str) -> Informer | None:
331
334
  """Get an informer by name"""
332
335
  async with self._lock:
333
- return self._cache.get(Informer.mkname(run_name=run_name, parent_action_name=parent_action_name), None)
336
+ return self._cache.get(
337
+ Informer.mkname(run_name=run_name, parent_action_name=parent_action_name),
338
+ None,
339
+ )
334
340
 
335
341
  @log
336
342
  async def remove(self, *, run_name: str, parent_action_name: str) -> Informer | None:
337
343
  """Remove an informer from the cache"""
338
344
  async with self._lock:
339
- return self._cache.pop(Informer.mkname(run_name=run_name, parent_action_name=parent_action_name), None)
345
+ return self._cache.pop(
346
+ Informer.mkname(run_name=run_name, parent_action_name=parent_action_name),
347
+ None,
348
+ )
340
349
 
341
350
  async def has(self, *, run_name: str, parent_action_name: str) -> bool:
342
351
  """Check if an informer exists in the cache"""
343
352
  async with self._lock:
344
353
  return Informer.mkname(run_name=run_name, parent_action_name=parent_action_name) in self._cache
345
354
 
346
- async def count_started_pending_terminal_actions(self) -> AsyncIterator[Tuple[int, int, int]]:
355
+ async def count_started_pending_terminal_actions(
356
+ self,
357
+ ) -> AsyncIterator[Tuple[int, int, int]]:
347
358
  """Log resource stats"""
348
359
  async with self._lock:
349
360
  for informer in self._cache.values():
flyte/_internal/runtime/task_serde.py CHANGED
@@ -59,7 +59,9 @@ def translate_task_to_wire(
59
59
  )
60
60
 
61
61
 
62
- def get_security_context(secrets: Optional[SecretRequest]) -> Optional[security_pb2.SecurityContext]:
62
+ def get_security_context(
63
+ secrets: Optional[SecretRequest],
64
+ ) -> Optional[security_pb2.SecurityContext]:
63
65
  """
64
66
  Get the security context from a list of secrets. This is a placeholder function.
65
67
 
@@ -86,7 +88,9 @@ def get_security_context(secrets: Optional[SecretRequest]) -> Optional[security_
86
88
  )
87
89
 
88
90
 
89
- def get_proto_retry_strategy(retries: RetryStrategy | int | None) -> Optional[literals_pb2.RetryStrategy]:
91
+ def get_proto_retry_strategy(
92
+ retries: RetryStrategy | int | None,
93
+ ) -> Optional[literals_pb2.RetryStrategy]:
90
94
  if retries is None:
91
95
  return None
92
96
 
@@ -158,7 +162,7 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
158
162
  discoverable=cache_enabled,
159
163
  discovery_version=cache_version,
160
164
  cache_serializable=task_cache.serialize,
161
- cache_ignore_input_vars=task_cache.get_ignored_inputs() if cache_enabled else None,
165
+ cache_ignore_input_vars=(task_cache.get_ignored_inputs() if cache_enabled else None),
162
166
  runtime=tasks_pb2.RuntimeMetadata(
163
167
  version=flyte.version(),
164
168
  type=tasks_pb2.RuntimeMetadata.RuntimeType.FLYTE_SDK,
@@ -166,7 +170,7 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
166
170
  ),
167
171
  retries=get_proto_retry_strategy(task.retries),
168
172
  timeout=get_proto_timeout(task.timeout),
169
- pod_template_name=task.pod_template if task.pod_template and isinstance(task.pod_template, str) else None,
173
+ pod_template_name=(task.pod_template if task.pod_template and isinstance(task.pod_template, str) else None),
170
174
  interruptible=task.interruptable,
171
175
  generates_deck=wrappers_pb2.BoolValue(value=task.report),
172
176
  ),
@@ -300,7 +304,9 @@ def _get_k8s_pod(primary_container: tasks_pb2.Container, pod_template: PodTempla
300
304
  return tasks_pb2.K8sPod(pod_spec=pod_spec, metadata=metadata)
301
305
 
302
306
 
303
- def extract_code_bundle(task_spec: task_definition_pb2.TaskSpec) -> Optional[CodeBundle]:
307
+ def extract_code_bundle(
308
+ task_spec: task_definition_pb2.TaskSpec,
309
+ ) -> Optional[CodeBundle]:
304
310
  """
305
311
  Extract the code bundle from the task spec.
306
312
  :param task_spec: The task spec to extract the code bundle from.
@@ -326,5 +332,10 @@ def extract_code_bundle(task_spec: task_definition_pb2.TaskSpec) -> Optional[Cod
326
332
  # Extract the version from the argument
327
333
  version = container.args[i + 1] if i + 1 < len(container.args) else ""
328
334
  if pkl_path or tgz_path:
329
- return CodeBundle(destination=dest_path, tgz=tgz_path, pkl=pkl_path, computed_version=version)
335
+ return CodeBundle(
336
+ destination=dest_path,
337
+ tgz=tgz_path,
338
+ pkl=pkl_path,
339
+ computed_version=version,
340
+ )
330
341
  return None
flyte/_protos/common/identifier_pb2.py CHANGED
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
14
14
  from flyte._protos.validate.validate import validate_pb2 as validate_dot_validate__pb2
15
15
 
16
16
 
17
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x63ommon/identifier.proto\x12\x0f\x63loudidl.common\x1a\x17validate/validate.proto\"~\n\x11ProjectIdentifier\x12+\n\x0corganization\x18\x01 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x0corganization\x12\x1f\n\x06\x64omain\x18\x02 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x06\x64omain\x12\x1b\n\x04name\x18\x03 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x04name\"T\n\x11\x43lusterIdentifier\x12\"\n\x0corganization\x18\x01 \x01(\tR\x0corganization\x12\x1b\n\x04name\x18\x02 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x04name\"O\n\x15\x43lusterPoolIdentifier\x12\"\n\x0corganization\x18\x01 \x01(\tR\x0corganization\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\"_\n\x17\x43lusterConfigIdentifier\x12+\n\x0corganization\x18\x01 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x0corganization\x12\x17\n\x02id\x18\x02 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x02id\"3\n\x0eUserIdentifier\x12!\n\x07subject\x18\x01 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x07subject\":\n\x15\x41pplicationIdentifier\x12!\n\x07subject\x18\x01 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x07subject\"Q\n\x0eRoleIdentifier\x12\"\n\x0corganization\x18\x01 \x01(\tR\x0corganization\x12\x1b\n\x04name\x18\x02 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x04name\"O\n\rOrgIdentifier\x12>\n\x04name\x18\x01 \x01(\tB*\xfa\x42\'r%\x10\x01\x18?2\x1f^[a-z0-9]([-a-z0-9]*[a-z0-9])?$R\x04name\"y\n\x18ManagedClusterIdentifier\x12\x1b\n\x04name\x18\x02 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x04name\x12:\n\x03org\x18\x03 \x01(\x0b\x32\x1e.cloudidl.common.OrgIdentifierB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x03orgJ\x04\x08\x01\x10\x02\"S\n\x10PolicyIdentifier\x12\"\n\x0corganization\x18\x01 \x01(\tR\x0corganization\x12\x1b\n\x04name\x18\x02 
\x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x04nameB\xb0\x01\n\x13\x63om.cloudidl.commonB\x0fIdentifierProtoH\x02P\x01Z)github.com/unionai/cloud/gen/pb-go/common\xa2\x02\x03\x43\x43X\xaa\x02\x0f\x43loudidl.Common\xca\x02\x0f\x43loudidl\\Common\xe2\x02\x1b\x43loudidl\\Common\\GPBMetadata\xea\x02\x10\x43loudidl::Commonb\x06proto3')
17
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x63ommon/identifier.proto\x12\x0f\x63loudidl.common\x1a\x17validate/validate.proto\"~\n\x11ProjectIdentifier\x12+\n\x0corganization\x18\x01 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x0corganization\x12\x1f\n\x06\x64omain\x18\x02 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x06\x64omain\x12\x1b\n\x04name\x18\x03 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x04name\"T\n\x11\x43lusterIdentifier\x12\"\n\x0corganization\x18\x01 \x01(\tR\x0corganization\x12\x1b\n\x04name\x18\x02 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x04name\"O\n\x15\x43lusterPoolIdentifier\x12\"\n\x0corganization\x18\x01 \x01(\tR\x0corganization\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\"_\n\x17\x43lusterConfigIdentifier\x12+\n\x0corganization\x18\x01 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x0corganization\x12\x17\n\x02id\x18\x02 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x02id\"3\n\x0eUserIdentifier\x12!\n\x07subject\x18\x01 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x07subject\":\n\x15\x41pplicationIdentifier\x12!\n\x07subject\x18\x01 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x07subject\"Q\n\x0eRoleIdentifier\x12\"\n\x0corganization\x18\x01 \x01(\tR\x0corganization\x12\x1b\n\x04name\x18\x02 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x04name\"O\n\rOrgIdentifier\x12>\n\x04name\x18\x01 \x01(\tB*\xfa\x42\'r%\x10\x01\x18?2\x1f^[a-z0-9]([-a-z0-9]*[a-z0-9])?$R\x04name\"y\n\x18ManagedClusterIdentifier\x12\x1b\n\x04name\x18\x02 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x04name\x12:\n\x03org\x18\x03 \x01(\x0b\x32\x1e.cloudidl.common.OrgIdentifierB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x03orgJ\x04\x08\x01\x10\x02\"S\n\x10PolicyIdentifier\x12\"\n\x0corganization\x18\x01 \x01(\tR\x0corganization\x12\x1b\n\x04name\x18\x02 \x01(\tB\x07\xfa\x42\x04r\x02\x10\x01R\x04name\"\x93\x01\n\rRunIdentifier\x12\x1b\n\x03org\x18\x01 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x03org\x12#\n\x07project\x18\x02 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x07project\x12!\n\x06\x64omain\x18\x03 
\x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18?R\x06\x64omain\x12\x1d\n\x04name\x18\x04 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18\x1eR\x04name\"m\n\x10\x41\x63tionIdentifier\x12:\n\x03run\x18\x01 \x01(\x0b\x32\x1e.cloudidl.common.RunIdentifierB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x03run\x12\x1d\n\x04name\x18\x02 \x01(\tB\t\xfa\x42\x06r\x04\x10\x01\x18\x1eR\x04name\"\x86\x01\n\x17\x41\x63tionAttemptIdentifier\x12H\n\taction_id\x18\x01 \x01(\x0b\x32!.cloudidl.common.ActionIdentifierB\x08\xfa\x42\x05\x8a\x01\x02\x10\x01R\x08\x61\x63tionId\x12!\n\x07\x61ttempt\x18\x02 \x01(\rB\x07\xfa\x42\x04*\x02 \x00R\x07\x61ttemptB\xb0\x01\n\x13\x63om.cloudidl.commonB\x0fIdentifierProtoH\x02P\x01Z)github.com/unionai/cloud/gen/pb-go/common\xa2\x02\x03\x43\x43X\xaa\x02\x0f\x43loudidl.Common\xca\x02\x0f\x43loudidl\\Common\xe2\x02\x1b\x43loudidl\\Common\\GPBMetadata\xea\x02\x10\x43loudidl::Commonb\x06proto3')
18
18
 
19
19
  _globals = globals()
20
20
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -48,6 +48,22 @@ if _descriptor._USE_C_DESCRIPTORS == False:
48
48
  _MANAGEDCLUSTERIDENTIFIER.fields_by_name['org']._serialized_options = b'\372B\005\212\001\002\020\001'
49
49
  _POLICYIDENTIFIER.fields_by_name['name']._options = None
50
50
  _POLICYIDENTIFIER.fields_by_name['name']._serialized_options = b'\372B\004r\002\020\001'
51
+ _RUNIDENTIFIER.fields_by_name['org']._options = None
52
+ _RUNIDENTIFIER.fields_by_name['org']._serialized_options = b'\372B\006r\004\020\001\030?'
53
+ _RUNIDENTIFIER.fields_by_name['project']._options = None
54
+ _RUNIDENTIFIER.fields_by_name['project']._serialized_options = b'\372B\006r\004\020\001\030?'
55
+ _RUNIDENTIFIER.fields_by_name['domain']._options = None
56
+ _RUNIDENTIFIER.fields_by_name['domain']._serialized_options = b'\372B\006r\004\020\001\030?'
57
+ _RUNIDENTIFIER.fields_by_name['name']._options = None
58
+ _RUNIDENTIFIER.fields_by_name['name']._serialized_options = b'\372B\006r\004\020\001\030\036'
59
+ _ACTIONIDENTIFIER.fields_by_name['run']._options = None
60
+ _ACTIONIDENTIFIER.fields_by_name['run']._serialized_options = b'\372B\005\212\001\002\020\001'
61
+ _ACTIONIDENTIFIER.fields_by_name['name']._options = None
62
+ _ACTIONIDENTIFIER.fields_by_name['name']._serialized_options = b'\372B\006r\004\020\001\030\036'
63
+ _ACTIONATTEMPTIDENTIFIER.fields_by_name['action_id']._options = None
64
+ _ACTIONATTEMPTIDENTIFIER.fields_by_name['action_id']._serialized_options = b'\372B\005\212\001\002\020\001'
65
+ _ACTIONATTEMPTIDENTIFIER.fields_by_name['attempt']._options = None
66
+ _ACTIONATTEMPTIDENTIFIER.fields_by_name['attempt']._serialized_options = b'\372B\004*\002 \000'
51
67
  _globals['_PROJECTIDENTIFIER']._serialized_start=69
52
68
  _globals['_PROJECTIDENTIFIER']._serialized_end=195
53
69
  _globals['_CLUSTERIDENTIFIER']._serialized_start=197
@@ -68,4 +84,10 @@ if _descriptor._USE_C_DESCRIPTORS == False:
68
84
  _globals['_MANAGEDCLUSTERIDENTIFIER']._serialized_end=859
69
85
  _globals['_POLICYIDENTIFIER']._serialized_start=861
70
86
  _globals['_POLICYIDENTIFIER']._serialized_end=944
87
+ _globals['_RUNIDENTIFIER']._serialized_start=947
88
+ _globals['_RUNIDENTIFIER']._serialized_end=1094
89
+ _globals['_ACTIONIDENTIFIER']._serialized_start=1096
90
+ _globals['_ACTIONIDENTIFIER']._serialized_end=1205
91
+ _globals['_ACTIONATTEMPTIDENTIFIER']._serialized_start=1208
92
+ _globals['_ACTIONATTEMPTIDENTIFIER']._serialized_end=1342
71
93
  # @@protoc_insertion_point(module_scope)