indexify 0.3.19__py3-none-any.whl → 0.3.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. indexify/cli/cli.py +12 -0
  2. indexify/executor/api_objects.py +11 -6
  3. indexify/executor/blob_store/blob_store.py +69 -0
  4. indexify/executor/blob_store/local_fs_blob_store.py +48 -0
  5. indexify/executor/blob_store/metrics/blob_store.py +33 -0
  6. indexify/executor/blob_store/s3_blob_store.py +88 -0
  7. indexify/executor/downloader.py +192 -27
  8. indexify/executor/executor.py +29 -13
  9. indexify/executor/function_executor/function_executor.py +1 -1
  10. indexify/executor/function_executor/function_executor_states_container.py +5 -0
  11. indexify/executor/function_executor/function_executor_status.py +2 -0
  12. indexify/executor/function_executor/health_checker.py +7 -2
  13. indexify/executor/function_executor/invocation_state_client.py +4 -2
  14. indexify/executor/function_executor/single_task_runner.py +2 -0
  15. indexify/executor/function_executor/task_output.py +8 -1
  16. indexify/executor/grpc/channel_manager.py +4 -3
  17. indexify/executor/grpc/function_executor_controller.py +163 -193
  18. indexify/executor/grpc/metrics/state_reconciler.py +17 -0
  19. indexify/executor/grpc/metrics/task_controller.py +8 -0
  20. indexify/executor/grpc/state_reconciler.py +305 -188
  21. indexify/executor/grpc/state_reporter.py +18 -10
  22. indexify/executor/grpc/task_controller.py +247 -189
  23. indexify/executor/metrics/task_reporter.py +17 -0
  24. indexify/executor/task_reporter.py +217 -94
  25. indexify/executor/task_runner.py +1 -0
  26. indexify/proto/executor_api.proto +37 -11
  27. indexify/proto/executor_api_pb2.py +49 -47
  28. indexify/proto/executor_api_pb2.pyi +55 -15
  29. {indexify-0.3.19.dist-info → indexify-0.3.21.dist-info}/METADATA +2 -1
  30. {indexify-0.3.19.dist-info → indexify-0.3.21.dist-info}/RECORD +32 -27
  31. indexify/executor/grpc/completed_tasks_container.py +0 -26
  32. {indexify-0.3.19.dist-info → indexify-0.3.21.dist-info}/WHEEL +0 -0
  33. {indexify-0.3.19.dist-info → indexify-0.3.21.dist-info}/entry_points.txt +0 -0
indexify/executor/grpc/task_controller.py

@@ -1,5 +1,4 @@
 import asyncio
-import time
 from typing import Any, Optional

 import grpc
@@ -16,6 +15,7 @@ from tensorlake.function_executor.proto.message_validator import MessageValidato
 from indexify.proto.executor_api_pb2 import Task

 from ..downloader import Downloader
+from ..function_executor.function_executor import FunctionExecutor
 from ..function_executor.function_executor_state import FunctionExecutorState
 from ..function_executor.function_executor_status import FunctionExecutorStatus
 from ..function_executor.metrics.single_task_runner import (
@@ -36,10 +36,10 @@ from ..metrics.executor import (
     metric_task_outcome_report_retries,
     metric_task_outcome_reports,
     metric_tasks_completed,
+    metric_tasks_fetched,
     metric_tasks_reporting_outcome,
 )
 from ..metrics.task_runner import (
-    metric_task_policy_errors,
     metric_task_policy_latency,
     metric_task_policy_runs,
     metric_task_run_latency,
@@ -50,191 +50,197 @@ from ..metrics.task_runner import (
     metric_tasks_running,
 )
 from ..task_reporter import TaskReporter
-from .completed_tasks_container import CompletedTasksContainer
+from .metrics.task_controller import metric_task_cancellations

 _TASK_OUTCOME_REPORT_BACKOFF_SEC = 5.0


-class FunctionTimeoutError(Exception):
-    """Exception raised when a customer's task execution exceeds the allowed timeout."""
+def validate_task(task: Task) -> None:
+    """Validates the supplied Task.

-    def __init__(self, message: str):
-        super().__init__(message)
+    Raises ValueError if the Task is not valid.
+    """
+    validator = MessageValidator(task)
+    validator.required_field("id")
+    validator.required_field("namespace")
+    validator.required_field("graph_name")
+    validator.required_field("graph_version")
+    validator.required_field("function_name")
+    validator.required_field("graph_invocation_id")
+    if not (task.HasField("input_key") or task.HasField("input")):
+        raise ValueError(
+            "Task must have either input_key or input field set. " f"Got task: {task}"
+        )
+
+
+def task_logger(task: Task, logger: Any) -> Any:
+    """Returns a logger bound with the task's metadata.
+
+    The function assumes that the task might be invalid."""
+    return logger.bind(
+        task_id=task.id if task.HasField("id") else None,
+        namespace=task.namespace if task.HasField("namespace") else None,
+        graph_name=task.graph_name if task.HasField("graph_name") else None,
+        graph_version=task.graph_version if task.HasField("graph_version") else None,
+        function_name=task.function_name if task.HasField("function_name") else None,
+        graph_invocation_id=(
+            task.graph_invocation_id if task.HasField("graph_invocation_id") else None
+        ),
+    )


 class TaskController:
     def __init__(
         self,
         task: Task,
-        function_executor_state: FunctionExecutorState,
         downloader: Downloader,
         task_reporter: TaskReporter,
-        completed_tasks_container: CompletedTasksContainer,
+        function_executor_id: str,
+        function_executor_state: FunctionExecutorState,
         logger: Any,
     ):
         """Creates a new TaskController instance.

-        Raises ValueError if the supplied Task is not valid.
+        The supplied Task must already be validated by the caller using validate_task().
         """
-        _validate_task(task)
         self._task: Task = task
-        self._function_executor_state: FunctionExecutorState = function_executor_state
         self._downloader: Downloader = downloader
         self._task_reporter: TaskReporter = task_reporter
-        self._completed_tasks_container: CompletedTasksContainer = (
-            completed_tasks_container
-        )
-        self._logger: Any = logger.bind(
-            function_executor_id=function_executor_state.id,
-            task_id=task.id,
+        self._function_executor_id: str = function_executor_id
+        self._function_executor_state: FunctionExecutorState = function_executor_state
+        self._logger: Any = task_logger(task, logger).bind(
+            function_executor_id=function_executor_id,
             module=__name__,
-            namespace=task.namespace,
-            graph_name=task.graph_name,
-            graph_version=task.graph_version,
-            function_name=task.function_name,
-            invocation_id=task.graph_invocation_id,
         )
-        self._is_running: bool = False
-        self._is_cancelled: bool = False
+
         self._input: Optional[SerializedObject] = None
         self._init_value: Optional[SerializedObject] = None
-        self._output: Optional[TaskOutput] = None
+        self._is_timed_out: bool = False
+        # Automatically start the controller on creation.
+        self._task_runner: asyncio.Task = asyncio.create_task(
+            self._run(), name="task controller task runner"
+        )

-    async def cancel_task(self) -> None:
-        """Cancels the task."""
-        self._is_cancelled = True
+    def function_executor_id(self) -> str:
+        return self._function_executor_id

-        async with self._function_executor_state.lock:
-            if not self._is_running:
-                return
+    def task(self) -> Task:
+        return self._task

-            # Mark the Function Executor as unhealthy to destroy it to cancel the running function.
-            # If FE status changed, then it means that we're off normal task execution path, e.g.
-            # Server decided to do something with FE.
-            if (
-                self._function_executor_state.status
-                == FunctionExecutorStatus.RUNNING_TASK
-            ):
-                # TODO: Add a separate FE status for cancelled function so we don't lie to server that FE is unhealthy to destroy it.
-                await self._function_executor_state.set_status(
-                    FunctionExecutorStatus.UNHEALTHY,
-                )
-                self._logger.warning("task is cancelled")
-            else:
-                self._logger.warning(
-                    "skipping marking Function Executor unhealthy on task cancellation due to unexpected FE status",
-                    status=self._function_executor_state.status.name,
-                )
+    async def destroy(self) -> None:
+        """Destroys the controller and cancels the task if it didn't finish yet.

-    async def run_task(self) -> None:
+        A running task is cancelled by destroying its Function Executor.
+        Doesn't raise any exceptions.
+        """
+        if self._task_runner.done():
+            return  # Nothing to do, the task is finished already.
+
+        # The task runner code handles asyncio.CancelledError properly.
+        self._task_runner.cancel()
+        # Don't await the cancelled task to not block the caller unnecessarily.
+
+    async def _run(self) -> None:
+        metric_tasks_fetched.inc()
+        with metric_task_completion_latency.time():
+            await self._run_task()
+
+    async def _run_task(self) -> None:
         """Runs the supplied task and does full management of its lifecycle.

         Doesn't raise any exceptions."""
-        start_time: float = time.monotonic()
+        output: Optional[TaskOutput] = None

         try:
-            # The task can be cancelled at any time but we'll just wait until FE gets shutdown
-            # because we require this to happen from the cancel_task() caller.
-            self._input = await self._downloader.download_input(
-                namespace=self._task.namespace,
-                graph_name=self._task.graph_name,
-                graph_invocation_id=self._task.graph_invocation_id,
-                input_key=self._task.input_key,
-                logger=self._logger,
-            )
-            if self._task.HasField("reducer_output_key"):
-                self._init_value = await self._downloader.download_init_value(
-                    namespace=self._task.namespace,
-                    graph_name=self._task.graph_name,
-                    function_name=self._task.function_name,
-                    graph_invocation_id=self._task.graph_invocation_id,
-                    reducer_output_key=self._task.reducer_output_key,
-                    logger=self._logger,
-                )
+            await self._download_inputs()
+            output = await self._run_task_when_function_executor_is_available()
+            self._logger.info("task execution finished", success=output.success)
+            _log_function_metrics(output, self._logger)
+        except Exception as e:
+            metric_task_run_platform_errors.inc()
+            output = self._internal_error_output()
+            self._logger.error("task execution failed", exc_info=e)
+        except asyncio.CancelledError:
+            metric_task_cancellations.inc()
+            self._logger.info("task execution cancelled")
+            # Don't report task outcome according to the current policy.
+            # asyncio.CancelledError can't be suppressed, see Python docs.
+            raise
+
+        # Current task outcome reporting policy:
+        # Don't report task outcomes for tasks that didn't fail with internal or customer error.
+        # This is required to simplify the protocol so Server doesn't need to care about task states
+        # and cancel each task carefully to not get its outcome as failed.
+        with (
+            metric_tasks_reporting_outcome.track_inprogress(),
+            metric_task_outcome_report_latency.time(),
+        ):
+            metric_task_outcome_reports.inc()
+            await self._report_task_outcome(output)

-            await self._wait_for_idle_function_executor()
+    async def _download_inputs(self) -> None:
+        """Downloads the task inputs and init value.

-            with (
-                metric_task_run_platform_errors.count_exceptions(),
-                metric_tasks_running.track_inprogress(),
-                metric_task_run_latency.time(),
-            ):
-                metric_task_runs.inc()
-                await self._run_task()
+        Raises an Exception if the inputs failed to download.
+        """
+        self._input = await self._downloader.download_input(
+            namespace=self._task.namespace,
+            graph_name=self._task.graph_name,
+            graph_invocation_id=self._task.graph_invocation_id,
+            input_key=self._task.input_key,
+            data_payload=self._task.input if self._task.HasField("input") else None,
+            logger=self._logger,
+        )

-            self._logger.info("task execution finished", success=self._output.success)
-        except FunctionTimeoutError:
-            self._output = TaskOutput.function_timeout(
-                task_id=self._task.id,
-                namespace=self._task.namespace,
-                graph_name=self._task.graph_name,
-                function_name=self._task.function_name,
-                graph_version=self._task.graph_version,
-                graph_invocation_id=self._task.graph_invocation_id,
-            )
-            async with self._function_executor_state.lock:
-                # Mark the Function Executor as unhealthy to destroy it to cancel the running function.
-                # If FE status changed, then it means that we're off normal task execution path, e.g.
-                # Server decided to do something with FE.
-                if (
-                    self._function_executor_state.status
-                    == FunctionExecutorStatus.RUNNING_TASK
-                ):
-                    # TODO: Add a separate FE status for timed out function so we don't lie to server that FE is unhealthy to destroy it.
-                    await self._function_executor_state.set_status(
-                        FunctionExecutorStatus.UNHEALTHY,
-                    )
-                else:
-                    self._logger.warning(
-                        "skipping marking Function Executor unhealthy on task timeout due to unexpected FE status",
-                        status=self._function_executor_state.status.name,
-                    )
-        except Exception as e:
-            self._output = TaskOutput.internal_error(
-                task_id=self._task.id,
+        if self._task.HasField("reducer_output_key") or self._task.HasField(
+            "reducer_input"
+        ):
+            self._init_value = await self._downloader.download_init_value(
                 namespace=self._task.namespace,
                 graph_name=self._task.graph_name,
                 function_name=self._task.function_name,
-                graph_version=self._task.graph_version,
                 graph_invocation_id=self._task.graph_invocation_id,
+                reducer_output_key=(
+                    self._task.reducer_output_key
+                    if self._task.HasField("reducer_output_key")
+                    else ""
+                ),
+                data_payload=(
+                    self._task.reducer_input
+                    if self._task.HasField("reducer_input")
+                    else None
+                ),
+                logger=self._logger,
             )
-            self._logger.error("task execution failed", exc_info=e)
-        finally:
-            # Release the Function Executor so others can run tasks on it if FE status didn't change.
-            # If FE status changed, then it means that we're off normal task execution path, e.g.
-            # Server decided to do something with FE.
-            async with self._function_executor_state.lock:
-                if (
-                    self._function_executor_state.status
-                    == FunctionExecutorStatus.RUNNING_TASK
-                ):
-                    await self._function_executor_state.set_status(
-                        FunctionExecutorStatus.IDLE
-                    )
-                else:
-                    self._logger.warning(
-                        "skipping marking Function Executor IDLE due to unexpected FE status",
-                        status=self._function_executor_state.status,
-                    )

-        _log_function_metrics(self._output, self._logger)
+    async def _run_task_when_function_executor_is_available(self) -> TaskOutput:
+        """Runs the task on the Function Executor when it's available.

-        with (
-            metric_tasks_reporting_outcome.track_inprogress(),
-            metric_task_outcome_report_latency.time(),
-        ):
-            metric_task_outcome_reports.inc()
-            await self._report_task_outcome()
+        Raises an Exception if the task failed due to an internal error."""
+        await self._acquire_function_executor()

-        metric_task_completion_latency.observe(time.monotonic() - start_time)
+        next_status: FunctionExecutorStatus = FunctionExecutorStatus.IDLE
+        try:
+            return await self._run_task_on_acquired_function_executor()
+        except asyncio.CancelledError:
+            # This one is raised here when destroy() was called while we were running the task on this FE.
+            next_status = FunctionExecutorStatus.UNHEALTHY
+            # asyncio.CancelledError can't be suppressed, see Python docs.
+            raise
+        finally:
+            # If the task finished running on FE then put it into IDLE state so other tasks can run on it.
+            # Otherwise, mark the FE as unhealthy to force its destruction so the task stops running on it eventually
+            # and no other tasks run on this FE because it'd result in undefined behavior.
+            if self._is_timed_out:
+                next_status = FunctionExecutorStatus.UNHEALTHY
+            await self._release_function_executor(next_status=next_status)

-    async def _wait_for_idle_function_executor(self) -> None:
-        """Waits until the Function Executor is in IDLE state.
+    async def _acquire_function_executor(self) -> None:
+        """Waits until the Function Executor is in IDLE state and then locks it so the task can run on it.

-        Raises an Exception if the Function Executor is in SHUTDOWN state.
+        Doesn't raise any exceptions.
         """
         with (
-            metric_task_policy_errors.count_exceptions(),
             metric_tasks_blocked_by_policy.track_inprogress(),
             metric_tasks_blocked_by_policy_per_function_name.labels(
                 function_name=self._task.function_name
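
Taken together, this hunk inverts the controller's lifecycle: validation moves out of the constructor into the module-level validate_task(), and the controller now schedules its own runner task in __init__, so cancelling a task becomes cancelling that asyncio.Task via destroy(). A minimal sketch of the calling convention this implies (the handle_task driver and its arguments are hypothetical stand-ins, not code from this package):

async def handle_task(task, downloader, task_reporter, fe_id, fe_state, logger):
    # Validate before constructing; in 0.3.19 the constructor raised instead.
    try:
        validate_task(task)
    except ValueError:
        task_logger(task, logger).error("dropping invalid task")
        return

    # Constructing the controller schedules its runner task immediately.
    controller = TaskController(
        task=task,
        downloader=downloader,
        task_reporter=task_reporter,
        function_executor_id=fe_id,
        function_executor_state=fe_state,
        logger=logger,
    )

    # Later, e.g. when Server removes the task from the desired state:
    await controller.destroy()  # no-op if the runner already finished
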
@@ -247,18 +253,8 @@ class TaskController:
             )
             async with self._function_executor_state.lock:
                 await self._function_executor_state.wait_status(
-                    allowlist=[
-                        FunctionExecutorStatus.IDLE,
-                        FunctionExecutorStatus.SHUTDOWN,
-                    ]
+                    allowlist=[FunctionExecutorStatus.IDLE]
                 )
-                if (
-                    self._function_executor_state.status
-                    == FunctionExecutorStatus.SHUTDOWN
-                ):
-                    raise Exception(
-                        "Task's Function Executor got shutdown, can't run task"
-                    )
                 await self._function_executor_state.set_status(
                     FunctionExecutorStatus.RUNNING_TASK
                 )
@@ -266,7 +262,45 @@
         # At this point the Function Executor belongs to this task controller due to RUNNING_TASK status.
         # We can now unlock the FE state. We have to update the FE status once the task succeeds or fails.

-    async def _run_task(self) -> None:
+    async def _release_function_executor(
+        self, next_status: FunctionExecutorStatus
+    ) -> None:
+        # Release the Function Executor so others can run tasks on it if FE status didn't change.
+        # If FE status changed, then it means that we're off normal task execution path, e.g.
+        # Server decided to do something with FE.
+        async with self._function_executor_state.lock:
+            if (
+                self._function_executor_state.status
+                == FunctionExecutorStatus.RUNNING_TASK
+            ):
+                await self._function_executor_state.set_status(next_status)
+                if next_status == FunctionExecutorStatus.UNHEALTHY:
+                    # Destroy the unhealthy FE asap so it doesn't consume resources.
+                    # Don't do it under the state lock to not add unnecessary delays.
+                    asyncio.create_task(
+                        self._function_executor_state.function_executor.destroy()
+                    )
+                    self._function_executor_state.function_executor = None
+            else:
+                self._logger.warning(
+                    "skipping releasing Function Executor after running the task due to unexpected Function Executor status",
+                    status=self._function_executor_state.status.name,
+                    next_status=next_status.name,
+                )
+
+    async def _run_task_on_acquired_function_executor(self) -> TaskOutput:
+        """Runs the task on the Function Executor acquired by this task already and returns the output.
+
+        Raises an Exception if the task failed to run due to an internal error."""
+        with metric_tasks_running.track_inprogress(), metric_task_run_latency.time():
+            metric_task_runs.inc()
+            return await self._run_task_rpc_on_function_executor()
+
+    async def _run_task_rpc_on_function_executor(self) -> TaskOutput:
+        """Runs the task on the Function Executor and returns the output.
+
+        Raises an Exception if the task failed to run due to an internal error.
+        """
         request: RunTaskRequest = RunTaskRequest(
             namespace=self._task.namespace,
             graph_name=self._task.graph_name,
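
The acquire/release split amounts to a small state machine around FunctionExecutorStatus: IDLE -> RUNNING_TASK on acquire, then RUNNING_TASK -> IDLE on a clean finish, or RUNNING_TASK -> UNHEALTHY (followed by asynchronous FE destruction) on timeout or cancellation. A condensed sketch of that protocol, assuming the wait_status()/set_status() API used above (run_rpc and timed_out are hypothetical callables):

import asyncio

async def run_with_function_executor(state, run_rpc, timed_out):
    # Acquire: wait for IDLE, then claim the FE under its lock.
    async with state.lock:
        await state.wait_status(allowlist=[FunctionExecutorStatus.IDLE])
        await state.set_status(FunctionExecutorStatus.RUNNING_TASK)

    next_status = FunctionExecutorStatus.IDLE
    try:
        return await run_rpc()
    except asyncio.CancelledError:
        next_status = FunctionExecutorStatus.UNHEALTHY  # force FE destruction
        raise
    finally:
        if timed_out():
            next_status = FunctionExecutorStatus.UNHEALTHY
        async with state.lock:
            # Only transition if nobody else (e.g. Server) moved the status meanwhile.
            if state.status == FunctionExecutorStatus.RUNNING_TASK:
                await state.set_status(next_status)
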
@@ -276,8 +310,14 @@
             task_id=self._task.id,
             function_input=self._input,
         )
+        # Don't keep the input in memory after we started running the task.
+        self._input = None
+
         if self._init_value is not None:
             request.function_init_value.CopyFrom(self._init_value)
+            # Don't keep the init value in memory after we started running the task.
+            self._init_value = None
+
         channel: grpc.aio.Channel = (
             self._function_executor_state.function_executor.channel()
         )
@@ -289,7 +329,7 @@

         async with _RunningTaskContextManager(
             task=self._task,
-            function_executor_state=self._function_executor_state,
+            function_executor=self._function_executor_state.function_executor,
         ):
             with (
                 metric_function_executor_run_task_rpc_errors.count_exceptions(),
@@ -305,14 +345,16 @@
                     ).run_task(request, timeout=timeout_sec)
                 except grpc.aio.AioRpcError as e:
                     if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
-                        raise FunctionTimeoutError(
-                            f"Task execution timeout {timeout_sec} expired"
-                        ) from e
+                        # Not logging customer error.
+                        self._is_timed_out = True
+                        return self._function_timeout_output(timeout_sec=timeout_sec)
                     raise

-        self._output = _task_output(task=self._task, response=response)
+        return _task_output_from_function_executor_response(
+            task=self._task, response=response
+        )

-    async def _report_task_outcome(self) -> None:
+    async def _report_task_outcome(self, output: TaskOutput) -> None:
         """Reports the task with the given output to the server.

         Doesn't raise any Exceptions. Runs till the reporting is successful."""
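
The timeout handling now leans on gRPC deadlines instead of a custom exception: passing timeout=timeout_sec to the run_task RPC makes the client raise AioRpcError with DEADLINE_EXCEEDED when the function overruns, which the controller converts into a function-timeout TaskOutput and a flag to mark the FE unhealthy on release. A stripped-down sketch of that mapping (stub and make_timeout_output are hypothetical stand-ins):

import grpc

async def call_run_task(stub, request, timeout_sec, make_timeout_output):
    try:
        return await stub.run_task(request, timeout=timeout_sec)
    except grpc.aio.AioRpcError as e:
        if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
            # Customer code overran its budget: a customer error, not a platform error.
            return make_timeout_output(timeout_sec)
        raise  # any other RPC failure is treated as an internal error
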
@@ -320,14 +362,8 @@

         while True:
             logger = self._logger.bind(retries=reporting_retries)
-            if self._is_cancelled:
-                logger.warning(
-                    "task is cancelled, skipping its outcome reporting to workaround lack of server side retries"
-                )
-                break
-
             try:
-                await self._task_reporter.report(output=self._output, logger=logger)
+                await self._task_reporter.report(output=output, logger=logger)
                 break
             except Exception as e:
                 logger.error(
@@ -338,13 +374,12 @@
                 metric_task_outcome_report_retries.inc()
                 await asyncio.sleep(_TASK_OUTCOME_REPORT_BACKOFF_SEC)

-        await self._completed_tasks_container.add(self._task.id)
         metric_tasks_completed.labels(outcome=METRIC_TASKS_COMPLETED_OUTCOME_ALL).inc()
-        if self._output.is_internal_error:
+        if output.is_internal_error:
             metric_tasks_completed.labels(
                 outcome=METRIC_TASKS_COMPLETED_OUTCOME_ERROR_PLATFORM
             ).inc()
-        elif self._output.success:
+        elif output.success:
             metric_tasks_completed.labels(
                 outcome=METRIC_TASKS_COMPLETED_OUTCOME_SUCCESS
             ).inc()
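
_report_task_outcome() now receives the output explicitly and keeps the same report-until-success shape: unbounded retries with a fixed backoff, logging each failure but never raising. The pattern in isolation (report is a hypothetical awaitable; the real code uses TaskReporter.report() and _TASK_OUTCOME_REPORT_BACKOFF_SEC):

import asyncio

async def report_until_success(report, output, logger, backoff_sec=5.0):
    retries = 0
    while True:
        try:
            await report(output=output)
            return
        except Exception as e:
            # Log and retry forever; the outcome must eventually reach Server.
            logger.error("failed to report task outcome", exc_info=e, retries=retries)
            retries += 1
            await asyncio.sleep(backoff_sec)
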
@@ -353,23 +388,41 @@
                 outcome=METRIC_TASKS_COMPLETED_OUTCOME_ERROR_CUSTOMER_CODE
             ).inc()

+    def _internal_error_output(self) -> TaskOutput:
+        return TaskOutput.internal_error(
+            task_id=self._task.id,
+            namespace=self._task.namespace,
+            graph_name=self._task.graph_name,
+            function_name=self._task.function_name,
+            graph_version=self._task.graph_version,
+            graph_invocation_id=self._task.graph_invocation_id,
+            output_payload_uri_prefix=(
+                self._task.output_payload_uri_prefix
+                if self._task.HasField("output_payload_uri_prefix")
+                else None
+            ),
+        )

-def _validate_task(task: Task) -> None:
-    """Validates the supplied Task.
-
-    Raises ValueError if the Task is not valid.
-    """
-    validator = MessageValidator(task)
-    validator.required_field("id")
-    validator.required_field("namespace")
-    validator.required_field("graph_name")
-    validator.required_field("graph_version")
-    validator.required_field("function_name")
-    validator.required_field("graph_invocation_id")
-    validator.required_field("input_key")
+    def _function_timeout_output(self, timeout_sec: float) -> TaskOutput:
+        return TaskOutput.function_timeout(
+            task_id=self._task.id,
+            namespace=self._task.namespace,
+            graph_name=self._task.graph_name,
+            function_name=self._task.function_name,
+            graph_version=self._task.graph_version,
+            graph_invocation_id=self._task.graph_invocation_id,
+            timeout_sec=timeout_sec,
+            output_payload_uri_prefix=(
+                self._task.output_payload_uri_prefix
+                if self._task.HasField("output_payload_uri_prefix")
+                else None
+            ),
+        )


-def _task_output(task: Task, response: RunTaskResponse) -> TaskOutput:
+def _task_output_from_function_executor_response(
+    task: Task, response: RunTaskResponse
+) -> TaskOutput:
     response_validator = MessageValidator(response)
     response_validator.required_field("stdout")
     response_validator.required_field("stderr")
@@ -394,6 +447,11 @@ def _task_output(task: Task, response: RunTaskResponse) -> TaskOutput:
         reducer=response.is_reducer,
         success=response.success,
         metrics=metrics,
+        output_payload_uri_prefix=(
+            task.output_payload_uri_prefix
+            if task.HasField("output_payload_uri_prefix")
+            else None
+        ),
     )

     if response.HasField("function_output"):
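
The recurring HasField(...) else None guards in these hunks are the standard way to read optional protobuf fields, which distinguish "unset" from a type's default value (an unset string otherwise reads as ""). A tiny illustrative helper capturing the pattern (hypothetical, not part of the package):

from typing import Any, Optional

def optional_field(msg: Any, field: str) -> Optional[Any]:
    # None when the optional proto field is unset, instead of the proto default.
    return getattr(msg, field) if msg.HasField(field) else None

# e.g.: output_payload_uri_prefix = optional_field(task, "output_payload_uri_prefix")
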
@@ -430,20 +488,20 @@ class _RunningTaskContextManager:

     def __init__(
         self,
-        task_controller: TaskController,
+        task: Task,
+        function_executor: FunctionExecutor,
     ):
-        self._task_controller: TaskController = task_controller
+        self._task = task
+        self._function_executor: FunctionExecutor = function_executor

     async def __aenter__(self):
-        self._task_controller._function_executor_state.function_executor.invocation_state_client().add_task_to_invocation_id_entry(
-            task_id=self._task_controller._task.id,
-            invocation_id=self._task_controller._task.graph_invocation_id,
+        self._function_executor.invocation_state_client().add_task_to_invocation_id_entry(
+            task_id=self._task.id,
+            invocation_id=self._task.graph_invocation_id,
         )
-        self._task_controller._is_running = True
         return self

     async def __aexit__(self, exc_type, exc_val, exc_tb):
-        self._task_controller._is_running = False
-        self._task_controller._function_executor_state.function_executor.invocation_state_client().remove_task_to_invocation_id_entry(
-            task_id=self._task_controller._task.id,
+        self._function_executor.invocation_state_client().remove_task_to_invocation_id_entry(
+            task_id=self._task.id,
         )
indexify/executor/metrics/task_reporter.py

@@ -21,6 +21,23 @@ metric_server_ingest_files_latency: prometheus_client.Histogram = (
     )
 )

+metric_task_output_blob_store_uploads: prometheus_client.Counter = (
+    prometheus_client.Counter(
+        "task_output_blob_store_uploads", "Number of task output uploads to blob store"
+    )
+)
+metric_task_output_blob_store_upload_errors: prometheus_client.Counter = (
+    prometheus_client.Counter(
+        "task_output_blob_store_upload_errors",
+        "Number of failed task output uploads to blob store",
+    )
+)
+metric_task_output_blob_store_upload_latency: prometheus_client.Histogram = (
+    latency_metric_for_fast_operation(
+        "task_output_blob_store_upload", "Upload task output to blob store"
+    )
+)
+
 metric_report_task_outcome_rpcs = prometheus_client.Counter(
     "report_task_outcome_rpcs",
     "Number of report task outcome RPCs to Server",