indexify 0.3.13__py3-none-any.whl → 0.3.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
indexify/cli/cli.py CHANGED
@@ -208,13 +208,17 @@ def executor(
             help="Port where to run Executor Monitoring server",
         ),
     ] = 7000,
-    disable_automatic_function_executor_management: Annotated[
-        bool,
+    grpc_server_addr: Annotated[
+        Optional[str],
         typer.Option(
-            "--disable-automatic-function-executor-management",
-            help="Disable automatic Function Executor management by Executor",
+            "--grpc-server-addr",
+            help=(
+                "(exprimental) Address of server gRPC API to connect to, e.g. 'localhost:8901'.\n"
+                "If set disables automatic Function Executor management on Executor and uses the Server gRPC API\n"
+                "for Function Executor management and placement of tasks on them."
+            ),
         ),
-    ] = False,
+    ] = None,
 ):
     if dev:
         configure_development_mode_logging()
@@ -247,7 +251,7 @@ def executor(
         dev_mode=dev,
         monitoring_server_host=monitoring_server_host,
         monitoring_server_port=monitoring_server_port,
-        disable_automatic_function_executor_management=disable_automatic_function_executor_management,
+        grpc_server_addr=grpc_server_addr,
     )

     executor_cache = Path(executor_cache).expanduser().absolute()
@@ -285,7 +289,7 @@ def executor(
         config_path=config_path,
         monitoring_server_host=monitoring_server_host,
         monitoring_server_port=monitoring_server_port,
-        disable_automatic_function_executor_management=disable_automatic_function_executor_management,
+        grpc_server_addr=grpc_server_addr,
     ).run()

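The CLI change above replaces a boolean toggle with an optional address: leaving --grpc-server-addr unset keeps the existing HTTP task-fetching behavior, while any value switches the Executor into the experimental gRPC mode. Below is a minimal, self-contained sketch of the same Typer pattern; everything other than the grpc_server_addr option itself is illustrative, not taken from indexify.

    from typing import Annotated, Optional

    import typer

    app = typer.Typer()

    @app.command()
    def executor(
        grpc_server_addr: Annotated[
            Optional[str],
            typer.Option(
                "--grpc-server-addr",
                help="Server gRPC API address, e.g. 'localhost:8901'.",
            ),
        ] = None,
    ):
        # None selects the legacy HTTP polling mode; any value selects gRPC mode.
        mode = "http polling" if grpc_server_addr is None else f"gRPC via {grpc_server_addr}"
        typer.echo(f"Executor mode: {mode}")

    if __name__ == "__main__":
        app()
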
indexify/executor/downloader.py CHANGED
@@ -3,7 +3,7 @@ import os
 from typing import Any, Optional

 import httpx
-import structlog
+import nanoid
 from tensorlake.function_executor.proto.function_executor_pb2 import SerializedObject
 from tensorlake.utils.http_client import get_httpx_client

@@ -33,41 +33,81 @@ class Downloader:
         self._base_url = base_url
         self._client = get_httpx_client(config_path, make_async=True)

-    async def download_graph(self, task: Task) -> SerializedObject:
+    async def download_graph(
+        self, namespace: str, graph_name: str, graph_version: str, logger: Any
+    ) -> SerializedObject:
+        logger = logger.bind(module=__name__)
         with (
             metric_graph_download_errors.count_exceptions(),
             metric_tasks_downloading_graphs.track_inprogress(),
             metric_graph_download_latency.time(),
         ):
             metric_graph_downloads.inc()
-            return await self._download_graph(task)
+            return await self._download_graph(
+                namespace=namespace,
+                graph_name=graph_name,
+                graph_version=graph_version,
+                logger=logger,
+            )

-    async def download_input(self, task: Task) -> SerializedObject:
+    async def download_input(
+        self,
+        namespace: str,
+        graph_name: str,
+        graph_invocation_id: str,
+        input_key: str,
+        logger: Any,
+    ) -> SerializedObject:
+        logger = logger.bind(module=__name__)
         with (
             metric_task_input_download_errors.count_exceptions(),
             metric_tasks_downloading_inputs.track_inprogress(),
             metric_task_input_download_latency.time(),
         ):
             metric_task_input_downloads.inc()
-            return await self._download_input(task)
+            return await self._download_input(
+                namespace=namespace,
+                graph_name=graph_name,
+                graph_invocation_id=graph_invocation_id,
+                input_key=input_key,
+                logger=logger,
+            )

-    async def download_init_value(self, task: Task) -> SerializedObject:
+    async def download_init_value(
+        self,
+        namespace: str,
+        graph_name: str,
+        function_name: str,
+        graph_invocation_id: str,
+        reducer_output_key: str,
+        logger: Any,
+    ) -> SerializedObject:
+        logger = logger.bind(module=__name__)
         with (
             metric_reducer_init_value_download_errors.count_exceptions(),
             metric_tasks_downloading_reducer_init_value.track_inprogress(),
             metric_reducer_init_value_download_latency.time(),
         ):
             metric_reducer_init_value_downloads.inc()
-            return await self._download_init_value(task)
+            return await self._fetch_function_init_value(
+                namespace=namespace,
+                graph_name=graph_name,
+                function_name=function_name,
+                graph_invocation_id=graph_invocation_id,
+                reducer_output_key=reducer_output_key,
+                logger=logger,
+            )

-    async def _download_graph(self, task: Task) -> SerializedObject:
+    async def _download_graph(
+        self, namespace: str, graph_name: str, graph_version: str, logger: Any
+    ) -> SerializedObject:
         # Cache graph to reduce load on the server.
         graph_path = os.path.join(
             self.code_path,
             "graph_cache",
-            task.namespace,
-            task.compute_graph,
-            task.graph_version,
+            namespace,
+            graph_name,
+            graph_version,
         )
         # Filesystem operations are synchronous.
         # Run in a separate thread to not block the main event loop.
@@ -78,13 +118,17 @@ class Downloader:
             metric_graphs_from_cache.inc()
             return graph

-        logger = self._task_logger(task)
-        graph: SerializedObject = await self._fetch_graph(task, logger)
+        graph: SerializedObject = await self._fetch_graph(
+            namespace=namespace,
+            graph_name=graph_name,
+            graph_version=graph_version,
+            logger=logger,
+        )
         # Filesystem operations are synchronous.
         # Run in a separate thread to not block the main event loop.
         # We don't need to wait for the write completion so we use create_task.
         asyncio.create_task(
-            asyncio.to_thread(self._write_cached_graph, task, graph_path, graph)
+            asyncio.to_thread(self._write_cached_graph, graph_path, graph)
         )

         return graph
@@ -96,14 +140,12 @@ class Downloader:
         with open(path, "rb") as f:
             return SerializedObject.FromString(f.read())

-    def _write_cached_graph(
-        self, task: Task, path: str, graph: SerializedObject
-    ) -> None:
+    def _write_cached_graph(self, path: str, graph: SerializedObject) -> None:
         if os.path.exists(path):
             # Another task already cached the graph.
             return None

-        tmp_path = os.path.join(self.code_path, "task_graph_cache", task.id)
+        tmp_path = os.path.join(self.code_path, "task_graph_cache", nanoid.generate())
         os.makedirs(os.path.dirname(tmp_path), exist_ok=True)
         with open(tmp_path, "wb") as f:
             f.write(graph.SerializeToString())
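The temporary file is now named with nanoid.generate() instead of task.id, which removes the cache writer's last dependency on Task while keeping names unique across concurrent writers. Together with the os.replace() call at the start of the next hunk, this is a write-then-rename pattern; here is a hedged standalone sketch (the function name and paths are illustrative):

    import os

    import nanoid

    def write_cached(path: str, payload: bytes, tmp_dir: str) -> None:
        if os.path.exists(path):
            return  # another writer already published this file
        # Unique temp name: concurrent writers never touch each other's files.
        tmp_path = os.path.join(tmp_dir, nanoid.generate())
        os.makedirs(tmp_dir, exist_ok=True)
        with open(tmp_path, "wb") as f:
            f.write(payload)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        # Atomic on POSIX when tmp_path and path are on the same filesystem,
        # so readers only ever observe a missing file or a complete one.
        os.replace(tmp_path, path)
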
@@ -114,60 +156,67 @@ class Downloader:
         # This also allows to share the same cache between multiple Executors.
         os.replace(tmp_path, path)

-    async def _download_input(self, task: Task) -> SerializedObject:
-        logger = self._task_logger(task)
-
-        first_function_in_graph = task.invocation_id == task.input_key.split("|")[-1]
+    async def _download_input(
+        self,
+        namespace: str,
+        graph_name: str,
+        graph_invocation_id: str,
+        input_key: str,
+        logger: Any,
+    ) -> SerializedObject:
+        first_function_in_graph = graph_invocation_id == input_key.split("|")[-1]
         if first_function_in_graph:
             # The first function in Graph gets its input from graph invocation payload.
-            return await self._fetch_graph_invocation_payload(task, logger)
+            return await self._fetch_graph_invocation_payload(
+                namespace=namespace,
+                graph_name=graph_name,
+                graph_invocation_id=graph_invocation_id,
+                logger=logger,
+            )
         else:
-            return await self._fetch_function_input(task, logger)
-
-    async def _download_init_value(self, task: Task) -> SerializedObject:
-        logger = self._task_logger(task)
-        return await self._fetch_function_init_value(task, logger)
-
-    def _task_logger(self, task: Task) -> Any:
-        return structlog.get_logger(
-            module=__name__,
-            namespace=task.namespace,
-            name=task.compute_graph,
-            version=task.graph_version,
-            task_id=task.id,
-        )
+            return await self._fetch_function_input(input_key=input_key, logger=logger)

-    async def _fetch_graph(self, task: Task, logger: Any) -> SerializedObject:
+    async def _fetch_graph(
+        self, namespace: str, graph_name: str, graph_version: str, logger: Any
+    ) -> SerializedObject:
         """Downloads the compute graph for the task and returns it."""
         return await self._fetch_url(
-            url=f"{self._base_url}/internal/namespaces/{task.namespace}/compute_graphs/{task.compute_graph}/versions/{task.graph_version}/code",
-            resource_description=f"compute graph: {task.compute_graph}",
+            url=f"{self._base_url}/internal/namespaces/{namespace}/compute_graphs/{graph_name}/versions/{graph_version}/code",
+            resource_description=f"compute graph: {graph_name}",
             logger=logger,
         )

     async def _fetch_graph_invocation_payload(
-        self, task: Task, logger: Any
+        self, namespace: str, graph_name: str, graph_invocation_id: str, logger: Any
     ) -> SerializedObject:
         return await self._fetch_url(
-            url=f"{self._base_url}/namespaces/{task.namespace}/compute_graphs/{task.compute_graph}/invocations/{task.invocation_id}/payload",
-            resource_description=f"graph invocation payload: {task.invocation_id}",
+            url=f"{self._base_url}/namespaces/{namespace}/compute_graphs/{graph_name}/invocations/{graph_invocation_id}/payload",
+            resource_description=f"graph invocation payload: {graph_invocation_id}",
             logger=logger,
         )

-    async def _fetch_function_input(self, task: Task, logger: Any) -> SerializedObject:
+    async def _fetch_function_input(
+        self, input_key: str, logger: Any
+    ) -> SerializedObject:
         return await self._fetch_url(
-            url=f"{self._base_url}/internal/fn_outputs/{task.input_key}",
-            resource_description=f"function input: {task.input_key}",
+            url=f"{self._base_url}/internal/fn_outputs/{input_key}",
+            resource_description=f"function input: {input_key}",
             logger=logger,
         )

     async def _fetch_function_init_value(
-        self, task: Task, logger: Any
+        self,
+        namespace: str,
+        graph_name: str,
+        function_name: str,
+        graph_invocation_id: str,
+        reducer_output_key: str,
+        logger: Any,
     ) -> SerializedObject:
         return await self._fetch_url(
-            url=f"{self._base_url}/namespaces/{task.namespace}/compute_graphs/{task.compute_graph}"
-            f"/invocations/{task.invocation_id}/fn/{task.compute_fn}/output/{task.reducer_output_id}",
-            resource_description=f"reducer output: {task.reducer_output_id}",
+            url=f"{self._base_url}/namespaces/{namespace}/compute_graphs/{graph_name}"
+            f"/invocations/{graph_invocation_id}/fn/{function_name}/output/{reducer_output_key}",
+            resource_description=f"reducer output: {reducer_output_key}",
             logger=logger,
         )

indexify/executor/executor.py CHANGED
@@ -5,6 +5,7 @@ from pathlib import Path
 from socket import gethostname
 from typing import Any, Dict, List, Optional

+import grpc
 import structlog
 from tensorlake.function_executor.proto.function_executor_pb2 import SerializedObject
 from tensorlake.utils.logging import suppress as suppress_logging
@@ -38,10 +39,14 @@ from .monitoring.health_checker.health_checker import HealthChecker
 from .monitoring.prometheus_metrics_handler import PrometheusMetricsHandler
 from .monitoring.server import MonitoringServer
 from .monitoring.startup_probe_handler import StartupProbeHandler
+from .state_reconciler import ExecutorStateReconciler
+from .state_reporter import ExecutorStateReporter
 from .task_fetcher import TaskFetcher
 from .task_reporter import TaskReporter
 from .task_runner import TaskInput, TaskOutput, TaskRunner

+EXECUTOR_GRPC_SERVER_READY_TIMEOUT_SEC = 10
+
 metric_executor_state.state("starting")

@@ -58,7 +63,7 @@ class Executor:
         config_path: Optional[str],
         monitoring_server_host: str,
         monitoring_server_port: int,
-        disable_automatic_function_executor_management: bool,
+        grpc_server_addr: Optional[str],
     ):
         self._logger = structlog.get_logger(module=__name__)
         self._is_shutdown: bool = False
@@ -83,39 +88,45 @@ class Executor:
         health_checker.set_function_executor_states_container(
             self._function_executor_states
         )
-        self._task_runner = TaskRunner(
-            executor_id=id,
-            function_executor_server_factory=function_executor_server_factory,
-            base_url=self._base_url,
-            disable_automatic_function_executor_management=disable_automatic_function_executor_management,
-            function_executor_states=self._function_executor_states,
-            config_path=config_path,
-        )
         self._downloader = Downloader(
             code_path=code_path, base_url=self._base_url, config_path=config_path
         )
-        self._task_fetcher = TaskFetcher(
-            executor_id=id,
-            executor_version=version,
-            function_allowlist=function_allowlist,
-            protocol=protocol,
-            indexify_server_addr=self._server_addr,
-            config_path=config_path,
-        )
         self._task_reporter = TaskReporter(
             base_url=self._base_url,
             executor_id=id,
             config_path=self._config_path,
         )
+        self._grpc_server_addr: Optional[str] = grpc_server_addr
+        self._id = id
+        self._function_allowlist: Optional[List[FunctionURI]] = function_allowlist
+        self._function_executor_server_factory = function_executor_server_factory
+        self._state_reporter: Optional[ExecutorStateReporter] = None
+        self._state_reconciler: Optional[ExecutorStateReconciler] = None
+
+        if self._grpc_server_addr is None:
+            self._task_runner: Optional[TaskRunner] = TaskRunner(
+                executor_id=id,
+                function_executor_server_factory=function_executor_server_factory,
+                base_url=self._base_url,
+                function_executor_states=self._function_executor_states,
+                config_path=config_path,
+            )
+            self._task_fetcher: Optional[TaskFetcher] = TaskFetcher(
+                executor_id=id,
+                executor_version=version,
+                function_allowlist=function_allowlist,
+                protocol=protocol,
+                indexify_server_addr=self._server_addr,
+                config_path=config_path,
+            )
+
         executor_info: Dict[str, str] = {
             "id": id,
             "version": version,
             "code_path": str(code_path),
             "server_addr": server_addr,
             "config_path": str(config_path),
-            "disable_automatic_function_executor_management": str(
-                disable_automatic_function_executor_management
-            ),
+            "grpc_server_addr": str(grpc_server_addr),
             "hostname": gethostname(),
         }
         executor_info.update(function_allowlist_to_info_dict(function_allowlist))
@@ -137,18 +148,88 @@ class Executor:
         asyncio.get_event_loop().create_task(self._monitoring_server.run())

         try:
-            asyncio.get_event_loop().run_until_complete(self._run_tasks_loop())
+            if self._grpc_server_addr is None:
+                asyncio.get_event_loop().run_until_complete(self._http_mode_loop())
+            else:
+                asyncio.get_event_loop().run_until_complete(self._grpc_mode_loop())
         except asyncio.CancelledError:
             pass  # Suppress this expected exception and return without error (normally).

-    async def _run_tasks_loop(self):
+    async def _grpc_mode_loop(self):
+        metric_executor_state.state("running")
+        self._startup_probe_handler.set_ready()
+
+        while not self._is_shutdown:
+            async with self._establish_grpc_server_channel() as server_channel:
+                server_channel: grpc.aio.Channel
+                await self._run_grpc_mode_services(server_channel)
+            self._logger.warning(
+                "grpc mode services exited, retrying in 5 seconds",
+            )
+            await asyncio.sleep(5)
+
+    async def _establish_grpc_server_channel(self) -> grpc.aio.Channel:
+        try:
+            channel = grpc.aio.insecure_channel(self._grpc_server_addr)
+            await asyncio.wait_for(
+                channel.channel_ready(),
+                timeout=EXECUTOR_GRPC_SERVER_READY_TIMEOUT_SEC,
+            )
+            return channel
+        except Exception as e:
+            self._logger.error("failed establishing grpc server channel", exc_info=e)
+            raise
+
+    async def _run_grpc_mode_services(self, server_channel: grpc.aio.Channel):
+        """Runs the gRPC mode services.
+
+        Never raises any exceptions."""
+        try:
+            self._state_reporter = ExecutorStateReporter(
+                executor_id=self._id,
+                function_allowlist=self._function_allowlist,
+                function_executor_states=self._function_executor_states,
+                server_channel=server_channel,
+                logger=self._logger,
+            )
+            self._state_reconciler = ExecutorStateReconciler(
+                executor_id=self._id,
+                function_executor_server_factory=self._function_executor_server_factory,
+                base_url=self._base_url,
+                function_executor_states=self._function_executor_states,
+                config_path=self._config_path,
+                downloader=self._downloader,
+                task_reporter=self._task_reporter,
+                server_channel=server_channel,
+                logger=self._logger,
+            )
+
+            # Task group ensures that:
+            # 1. If one of the tasks fails then the other tasks are cancelled.
+            # 2. If Executor shuts down then all the tasks are cancelled and this function returns.
+            async with asyncio.TaskGroup() as tg:
+                tg.create_task(self._state_reporter.run())
+                tg.create_task(self._state_reconciler.run())
+        except Exception as e:
+            self._logger.error("failed running grpc mode services", exc_info=e)
+        finally:
+            # Handle task cancellation using finally.
+            if self._state_reporter is not None:
+                self._state_reporter.shutdown()
+                self._state_reporter = None
+            if self._state_reconciler is not None:
+                self._state_reconciler.shutdown()
+                self._state_reconciler = None
+
+    async def _http_mode_loop(self):
         metric_executor_state.state("running")
         self._startup_probe_handler.set_ready()
         while not self._is_shutdown:
             try:
                 async for task in self._task_fetcher.run():
                     metric_tasks_fetched.inc()
-                    asyncio.create_task(self._run_task(task))
+                    if not self._is_shutdown:
+                        asyncio.create_task(self._run_task(task))
             except Exception as e:
                 self._logger.error(
                     "failed fetching tasks, retrying in 5 seconds", exc_info=e
@@ -167,7 +248,14 @@ class Executor:
             output = await self._run_task_and_get_output(task, logger)
             logger.info("task execution finished", success=output.success)
         except Exception as e:
-            output = TaskOutput.internal_error(task)
+            output = TaskOutput.internal_error(
+                task_id=task.id,
+                namespace=task.namespace,
+                graph_name=task.compute_graph,
+                function_name=task.compute_fn,
+                graph_version=task.graph_version,
+                graph_invocation_id=task.invocation_id,
+            )
             logger.error("task execution failed", exc_info=e)

         with (
@@ -180,12 +268,32 @@ class Executor:
             metric_task_completion_latency.observe(time.monotonic() - start_time)

     async def _run_task_and_get_output(self, task: Task, logger: Any) -> TaskOutput:
-        graph: SerializedObject = await self._downloader.download_graph(task)
-        input: SerializedObject = await self._downloader.download_input(task)
+        graph: SerializedObject = await self._downloader.download_graph(
+            namespace=task.namespace,
+            graph_name=task.compute_graph,
+            graph_version=task.graph_version,
+            logger=logger,
+        )
+        input: SerializedObject = await self._downloader.download_input(
+            namespace=task.namespace,
+            graph_name=task.compute_graph,
+            graph_invocation_id=task.invocation_id,
+            input_key=task.input_key,
+            logger=logger,
+        )
         init_value: Optional[SerializedObject] = (
             None
             if task.reducer_output_id is None
-            else (await self._downloader.download_init_value(task))
+            else (
+                await self._downloader.download_init_value(
+                    namespace=task.namespace,
+                    graph_name=task.compute_graph,
+                    function_name=task.compute_fn,
+                    graph_invocation_id=task.invocation_id,
+                    reducer_output_key=task.reducer_output_id,
+                    logger=logger,
+                )
+            )
         )
         return await self._task_runner.run(
             TaskInput(
@@ -241,11 +349,24 @@ class Executor:

         self._is_shutdown = True
         await self._monitoring_server.shutdown()
-        await self._task_runner.shutdown()
+
+        if self._task_runner is not None:
+            await self._task_runner.shutdown()
+        if self._state_reporter is not None:
+            await self._state_reporter.shutdown()
+            self._state_reporter = None
+        if self._state_reconciler is not None:
+            await self._state_reconciler.shutdown()
+            self._state_reconciler = None
+
+        # We need to shutdown all users of FE states first,
+        # otherwise states might disappear unexpectedly and we might
+        # report errors, etc that are expected.
         await self._function_executor_states.shutdown()
-        # We mainly need to cancel the task that runs _run_tasks_loop().
+        # We mainly need to cancel the task that runs _.*_mode_loop().
         for task in asyncio.all_tasks(loop):
             task.cancel()
+        # The current task is cancelled, the code after this line will not run.

     def shutdown(self, loop):
         loop.create_task(self._shutdown(loop))
indexify/executor/function_executor/function_executor_state.py CHANGED
@@ -1,6 +1,8 @@
 import asyncio
 from typing import Optional

+from indexify.task_scheduler.proto.task_scheduler_pb2 import FunctionExecutorStatus
+
 from .function_executor import FunctionExecutor
 from .metrics.function_executor_state import (
     metric_function_executor_state_not_locked_errors,
@@ -15,14 +17,31 @@ class FunctionExecutorState:
     under the lock.
     """

-    def __init__(self, function_id_with_version: str, function_id_without_version: str):
-        self.function_id_with_version: str = function_id_with_version
-        self.function_id_without_version: str = function_id_without_version
-        # All the fields below are protected by the lock.
+    def __init__(
+        self,
+        id: str,
+        namespace: str,
+        graph_name: str,
+        graph_version: str,
+        function_name: str,
+        image_uri: Optional[str],
+    ):
+        # Read only fields.
+        self.id: str = id
+        self.namespace: str = namespace
+        self.graph_name: str = graph_name
+        self.function_name: str = function_name
+        self.image_uri: Optional[str] = image_uri
+        # The lock must be held while modifying the fields below.
         self.lock: asyncio.Lock = asyncio.Lock()
+        self.graph_version: str = graph_version
         self.is_shutdown: bool = False
         # Set to True if a Function Executor health check ever failed.
         self.health_check_failed: bool = False
+        # TODO: remove fields that duplicate this status field.
+        self.status: FunctionExecutorStatus = (
+            FunctionExecutorStatus.FUNCTION_EXECUTOR_STATUS_STOPPED
+        )
         self.function_executor: Optional[FunctionExecutor] = None
         self.running_tasks: int = 0
         self.running_tasks_change_notifier: asyncio.Condition = asyncio.Condition(
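The rewritten comments split FunctionExecutorState into read-only identity fields and lock-protected mutable fields; graph_version moves below the lock, presumably because it can now change over the state's lifetime. A hedged sketch of the locking discipline the docstring requires (mark_stopped is an illustrative helper, not part of the diff):

    from indexify.task_scheduler.proto.task_scheduler_pb2 import FunctionExecutorStatus

    async def mark_stopped(state) -> None:  # state: FunctionExecutorState
        # Mutable fields (status, graph_version, function_executor, ...) may
        # only be changed while holding state.lock.
        async with state.lock:
            state.status = FunctionExecutorStatus.FUNCTION_EXECUTOR_STATUS_STOPPED
            state.function_executor = None
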
indexify/executor/function_executor/function_executor_states_container.py CHANGED
@@ -1,7 +1,6 @@
 import asyncio
-from typing import AsyncGenerator, Dict
+from typing import AsyncGenerator, Dict, Optional

-from ..api_objects import Task
 from .function_executor_state import FunctionExecutorState
 from .metrics.function_executor_state_container import (
     metric_function_executor_states_count,
@@ -17,19 +16,33 @@ class FunctionExecutorStatesContainer:
         self._states: Dict[str, FunctionExecutorState] = {}
         self._is_shutdown: bool = False

-    async def get_or_create_state(self, task: Task) -> FunctionExecutorState:
-        """Get or create a function executor state for the given task.
+    async def get_or_create_state(
+        self,
+        id: str,
+        namespace: str,
+        graph_name: str,
+        graph_version: str,
+        function_name: str,
+        image_uri: Optional[str],
+    ) -> FunctionExecutorState:
+        """Get or create a function executor state with the given ID.

+        If the state already exists, it is returned. Otherwise, a new state is created from the supplied task.
         Raises Exception if it's not possible to create a new state at this time."""
         async with self._lock:
             if self._is_shutdown:
-                raise RuntimeError("Task runner is shutting down.")
+                raise RuntimeError(
+                    "Function Executor states container is shutting down."
+                )

-            id = function_id_without_version(task)
             if id not in self._states:
                 state = FunctionExecutorState(
-                    function_id_with_version=function_id_with_version(task),
-                    function_id_without_version=id,
+                    id=id,
+                    namespace=namespace,
+                    graph_name=graph_name,
+                    graph_version=graph_version,
+                    function_name=function_name,
+                    image_uri=image_uri,
                 )
                 self._states[id] = state
                 metric_function_executor_states_count.set(len(self._states))
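Call sites now supply the identity fields explicitly instead of deriving them from a Task, with the ID chosen by the caller (in gRPC mode, presumably assigned by the server). A hedged usage sketch; all values are illustrative:

    async def example(container) -> None:
        state = await container.get_or_create_state(
            id="fe-123",            # e.g. a server-assigned Function Executor ID
            namespace="default",
            graph_name="my_graph",
            graph_version="1",
            function_name="my_fn",
            image_uri=None,
        )
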
@@ -41,6 +54,13 @@ class FunctionExecutorStatesContainer:
         for state in self._states.values():
             yield state

+    async def pop(self, id: str) -> FunctionExecutorState:
+        """Removes the state with the given ID and returns it."""
+        async with self._lock:
+            state = self._states.pop(id)
+            metric_function_executor_states_count.set(len(self._states))
+            return state
+
     async def shutdown(self):
         # Function Executors are outside the Executor process
         # so they need to get cleaned up explicitly and reliably.
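The new pop() complements get_or_create_state() for the removal path. Because it delegates to dict.pop() without a default, an unknown ID raises KeyError, which a caller would need to handle; a hedged caller sketch (remove_function_executor is illustrative):

    async def remove_function_executor(container, fe_id: str) -> None:
        try:
            state = await container.pop(fe_id)
        except KeyError:
            return  # already removed
        async with state.lock:
            await state.shutdown()
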
@@ -54,11 +74,3 @@ class FunctionExecutorStatesContainer:
             async with state.lock:
                 await state.shutdown()
             # The task running inside the Function Executor will fail because it's destroyed.
-
-
-def function_id_with_version(task: Task) -> str:
-    return f"versioned/{task.namespace}/{task.compute_graph}/{task.graph_version}/{task.compute_fn}"
-
-
-def function_id_without_version(task: Task) -> str:
-    return f"not_versioned/{task.namespace}/{task.compute_graph}/{task.compute_fn}"