indexify 0.3.15__py3-none-any.whl → 0.3.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. indexify/cli/cli.py +20 -91
  2. indexify/executor/api_objects.py +2 -0
  3. indexify/executor/executor.py +75 -84
  4. indexify/executor/function_executor/function_executor_state.py +43 -43
  5. indexify/executor/function_executor/function_executor_states_container.py +10 -4
  6. indexify/executor/function_executor/function_executor_status.py +91 -0
  7. indexify/executor/function_executor/metrics/function_executor.py +1 -1
  8. indexify/executor/function_executor/metrics/function_executor_state.py +36 -0
  9. indexify/executor/function_executor/server/function_executor_server_factory.py +8 -8
  10. indexify/executor/function_executor/single_task_runner.py +100 -37
  11. indexify/executor/grpc/channel_creator.py +53 -0
  12. indexify/executor/grpc/metrics/channel_creator.py +18 -0
  13. indexify/executor/grpc/metrics/state_reporter.py +17 -0
  14. indexify/executor/{state_reconciler.py → grpc/state_reconciler.py} +60 -31
  15. indexify/executor/grpc/state_reporter.py +199 -0
  16. indexify/executor/monitoring/health_checker/generic_health_checker.py +27 -12
  17. indexify/executor/task_runner.py +30 -6
  18. indexify/{task_scheduler/proto → proto}/task_scheduler.proto +23 -17
  19. indexify/proto/task_scheduler_pb2.py +64 -0
  20. indexify/{task_scheduler/proto → proto}/task_scheduler_pb2.pyi +28 -10
  21. indexify/{task_scheduler/proto → proto}/task_scheduler_pb2_grpc.py +16 -16
  22. {indexify-0.3.15.dist-info → indexify-0.3.16.dist-info}/METADATA +1 -1
  23. {indexify-0.3.15.dist-info → indexify-0.3.16.dist-info}/RECORD +25 -21
  24. indexify/executor/state_reporter.py +0 -127
  25. indexify/task_scheduler/proto/task_scheduler_pb2.py +0 -69
  26. {indexify-0.3.15.dist-info → indexify-0.3.16.dist-info}/WHEEL +0 -0
  27. {indexify-0.3.15.dist-info → indexify-0.3.16.dist-info}/entry_points.txt +0 -0
indexify/cli/cli.py CHANGED
@@ -9,11 +9,7 @@ configure_logging_early()
9
9
  import os
10
10
  import re
11
11
  import shutil
12
- import signal
13
- import subprocess
14
12
  import sys
15
- import threading
16
- import time
17
13
  from importlib.metadata import version
18
14
  from pathlib import Path
19
15
  from socket import gethostname
@@ -51,91 +47,6 @@ console = Console(theme=custom_theme)
51
47
  app = typer.Typer(pretty_exceptions_enable=False, no_args_is_help=True)
52
48
 
53
49
 
54
- @app.command(
55
- help="Run server and executor in dev mode (Not recommended for production.)"
56
- )
57
- def server_dev_mode():
58
- indexify_server_path = os.path.expanduser("~/.indexify/indexify-server")
59
- if not os.path.exists(indexify_server_path):
60
- print("indexify-server not found. Downloading...")
61
- try:
62
- download_command = subprocess.check_output(
63
- ["curl", "-s", "https://getindexify.ai"], universal_newlines=True
64
- )
65
- subprocess.run(download_command, shell=True, check=True)
66
- except subprocess.CalledProcessError as e:
67
- print(f"failed to download indexify-server: {e}")
68
- exit(1)
69
- try:
70
- os.makedirs(os.path.dirname(indexify_server_path), exist_ok=True)
71
- shutil.move("indexify-server", indexify_server_path)
72
- except Exception as e:
73
- print(f"failed to move indexify-server to {indexify_server_path}: {e}")
74
- exit(1)
75
- print("starting indexify server and executor in dev mode...")
76
- print("press Ctrl+C to stop the server and executor.")
77
- print(f"server binary path: {indexify_server_path}")
78
- commands: List[List[str]] = [
79
- [indexify_server_path, "--dev"],
80
- ["indexify-cli", "executor", "--dev"],
81
- ]
82
- processes = []
83
- stop_event = threading.Event()
84
-
85
- def handle_output(process):
86
- for line in iter(process.stdout.readline, ""):
87
- sys.stdout.write(line)
88
- sys.stdout.flush()
89
-
90
- def terminate_processes():
91
- print("Terminating processes...")
92
- stop_event.set()
93
- for process in processes:
94
- if process.poll() is None:
95
- try:
96
- process.terminate()
97
- process.wait(timeout=5)
98
- except subprocess.TimeoutExpired:
99
- print(f"Force killing process {process.pid}")
100
- process.kill()
101
-
102
- def signal_handler(sig, frame):
103
- print("\nCtrl+C pressed. Shutting down...")
104
- terminate_processes()
105
- sys.exit(0)
106
-
107
- signal.signal(signal.SIGINT, signal_handler)
108
- signal.signal(signal.SIGTERM, signal_handler)
109
-
110
- for cmd in commands:
111
- process = subprocess.Popen(
112
- cmd,
113
- stdout=subprocess.PIPE,
114
- stderr=subprocess.STDOUT,
115
- bufsize=1,
116
- universal_newlines=True,
117
- preexec_fn=os.setsid if os.name != "nt" else None,
118
- )
119
- processes.append(process)
120
-
121
- thread = threading.Thread(target=handle_output, args=(process,))
122
- thread.daemon = True
123
- thread.start()
124
-
125
- try:
126
- while True:
127
- time.sleep(1)
128
- if all(process.poll() is not None for process in processes):
129
- print("All processes have finished.")
130
- break
131
- except KeyboardInterrupt:
132
- signal_handler(None, None)
133
- finally:
134
- terminate_processes()
135
-
136
- print("Script execution completed.")
137
-
138
-
139
50
  @app.command(help="Build image for function names")
140
51
  def build_image(
141
52
  workflow_file_path: str,
@@ -208,17 +119,27 @@ def executor(
208
119
  help="Port where to run Executor Monitoring server",
209
120
  ),
210
121
  ] = 7000,
122
+ # TODO: Figure out mTLS for gRPC.
211
123
  grpc_server_addr: Annotated[
212
124
  Optional[str],
213
125
  typer.Option(
214
126
  "--grpc-server-addr",
215
127
  help=(
216
128
  "(exprimental) Address of server gRPC API to connect to, e.g. 'localhost:8901'.\n"
217
- "If set disables automatic Function Executor management on Executor and uses the Server gRPC API\n"
218
- "for Function Executor management and placement of tasks on them."
129
+ "Enables gRPC state reporter that will periodically report the state of the Function Executors to Server\n"
219
130
  ),
220
131
  ),
221
132
  ] = None,
133
+ enable_grpc_state_reconciler: Annotated[
134
+ bool,
135
+ typer.Option(
136
+ "--enable-grpc-state-reconciler",
137
+ help=(
138
+ "(exprimental) Enable gRPC state reconciler that will reconcile the state of the Function Executors and Task Allocations\n"
139
+ "with the desired state provided by Server. Required --grpc-server-addr to be set."
140
+ ),
141
+ ),
142
+ ] = False,
222
143
  ):
223
144
  if dev:
224
145
  configure_development_mode_logging()
@@ -236,6 +157,11 @@ def executor(
236
157
  "--executor-id should be at least 10 characters long and only include characters _-[0-9][a-z][A-Z]"
237
158
  )
238
159
 
160
+ if enable_grpc_state_reconciler and grpc_server_addr is None:
161
+ raise typer.BadParameter(
162
+ "--grpc-server-addr must be set when --enable-grpc-state-reconciler is set"
163
+ )
164
+
239
165
  executor_version = version("indexify")
240
166
  logger = structlog.get_logger(module=__name__, executor_id=executor_id)
241
167
 
@@ -252,6 +178,7 @@ def executor(
252
178
  monitoring_server_host=monitoring_server_host,
253
179
  monitoring_server_port=monitoring_server_port,
254
180
  grpc_server_addr=grpc_server_addr,
181
+ enable_grpc_state_reconciler=enable_grpc_state_reconciler,
255
182
  )
256
183
 
257
184
  executor_cache = Path(executor_cache).expanduser().absolute()
@@ -277,6 +204,7 @@ def executor(
277
204
 
278
205
  Executor(
279
206
  id=executor_id,
207
+ development_mode=dev,
280
208
  version=executor_version,
281
209
  health_checker=GenericHealthChecker(),
282
210
  code_path=executor_cache,
@@ -290,6 +218,7 @@ def executor(
290
218
  monitoring_server_host=monitoring_server_host,
291
219
  monitoring_server_port=monitoring_server_port,
292
220
  grpc_server_addr=grpc_server_addr,
221
+ enable_grpc_state_reconciler=enable_grpc_state_reconciler,
293
222
  ).run()
294
223
 
295
224
 
indexify/executor/api_objects.py CHANGED
@@ -14,6 +14,8 @@ class Task(BaseModel):
14
14
  graph_version: str
15
15
  image_uri: Optional[str] = None
16
16
  "image_uri defines the URI of the image of this task. Optional since some executors do not require it."
17
+ secret_names: Optional[List[str]] = None
18
+ "secret_names defines the names of the secrets to set on function executor. Optional for backward compatibility."
17
19
 
18
20
 
19
21
  class FunctionURI(BaseModel):
indexify/executor/executor.py CHANGED
@@ -5,11 +5,12 @@ from pathlib import Path
5
5
  from socket import gethostname
6
6
  from typing import Any, Dict, List, Optional
7
7
 
8
- import grpc
9
8
  import structlog
10
9
  from tensorlake.function_executor.proto.function_executor_pb2 import SerializedObject
11
10
  from tensorlake.utils.logging import suppress as suppress_logging
12
11
 
12
+ from indexify.proto.task_scheduler_pb2 import ExecutorStatus
13
+
13
14
  from .api_objects import FunctionURI, Task
14
15
  from .downloader import Downloader
15
16
  from .function_executor.function_executor_states_container import (
@@ -18,6 +19,9 @@ from .function_executor.function_executor_states_container import (
18
19
  from .function_executor.server.function_executor_server_factory import (
19
20
  FunctionExecutorServerFactory,
20
21
  )
22
+ from .grpc.channel_creator import ChannelCreator
23
+ from .grpc.state_reconciler import ExecutorStateReconciler
24
+ from .grpc.state_reporter import ExecutorStateReporter
21
25
  from .metrics.executor import (
22
26
  METRIC_TASKS_COMPLETED_OUTCOME_ALL,
23
27
  METRIC_TASKS_COMPLETED_OUTCOME_ERROR_CUSTOMER_CODE,
@@ -39,14 +43,10 @@ from .monitoring.health_checker.health_checker import HealthChecker
39
43
  from .monitoring.prometheus_metrics_handler import PrometheusMetricsHandler
40
44
  from .monitoring.server import MonitoringServer
41
45
  from .monitoring.startup_probe_handler import StartupProbeHandler
42
- from .state_reconciler import ExecutorStateReconciler
43
- from .state_reporter import ExecutorStateReporter
44
46
  from .task_fetcher import TaskFetcher
45
47
  from .task_reporter import TaskReporter
46
48
  from .task_runner import TaskInput, TaskOutput, TaskRunner
47
49
 
48
- EXECUTOR_GRPC_SERVER_READY_TIMEOUT_SEC = 10
49
-
50
50
  metric_executor_state.state("starting")
51
51
 
52
52
 
@@ -54,6 +54,7 @@ class Executor:
54
54
  def __init__(
55
55
  self,
56
56
  id: str,
57
+ development_mode: bool,
57
58
  version: str,
58
59
  code_path: Path,
59
60
  health_checker: HealthChecker,
@@ -64,10 +65,10 @@ class Executor:
64
65
  monitoring_server_host: str,
65
66
  monitoring_server_port: int,
66
67
  grpc_server_addr: Optional[str],
68
+ enable_grpc_state_reconciler: bool,
67
69
  ):
68
70
  self._logger = structlog.get_logger(module=__name__)
69
71
  self._is_shutdown: bool = False
70
- self._config_path = config_path
71
72
  protocol: str = "http"
72
73
  if config_path:
73
74
  self._logger.info("running the extractor with TLS enabled")
@@ -84,7 +85,9 @@ class Executor:
84
85
  health_probe_handler=HealthCheckHandler(health_checker),
85
86
  metrics_handler=PrometheusMetricsHandler(),
86
87
  )
87
- self._function_executor_states = FunctionExecutorStatesContainer()
88
+ self._function_executor_states = FunctionExecutorStatesContainer(
89
+ logger=self._logger
90
+ )
88
91
  health_checker.set_function_executor_states_container(
89
92
  self._function_executor_states
90
93
  )
@@ -94,24 +97,54 @@ class Executor:
94
97
  self._task_reporter = TaskReporter(
95
98
  base_url=self._base_url,
96
99
  executor_id=id,
97
- config_path=self._config_path,
100
+ config_path=config_path,
98
101
  )
99
- self._grpc_server_addr: Optional[str] = grpc_server_addr
100
- self._id = id
101
102
  self._function_allowlist: Optional[List[FunctionURI]] = function_allowlist
102
103
  self._function_executor_server_factory = function_executor_server_factory
104
+
105
+ # HTTP mode services
106
+ self._task_runner: Optional[TaskRunner] = None
107
+ self._task_fetcher: Optional[TaskFetcher] = None
108
+ # gRPC mode services
109
+ self._channel_creator: Optional[ChannelCreator] = None
103
110
  self._state_reporter: Optional[ExecutorStateReporter] = None
104
111
  self._state_reconciler: Optional[ExecutorStateReconciler] = None
105
112
 
106
- if self._grpc_server_addr is None:
107
- self._task_runner: Optional[TaskRunner] = TaskRunner(
113
+ if grpc_server_addr is not None:
114
+ self._channel_creator = ChannelCreator(grpc_server_addr, self._logger)
115
+ self._state_reporter = ExecutorStateReporter(
116
+ executor_id=id,
117
+ development_mode=development_mode,
118
+ function_allowlist=self._function_allowlist,
119
+ function_executor_states=self._function_executor_states,
120
+ channel_creator=self._channel_creator,
121
+ logger=self._logger,
122
+ )
123
+ self._state_reporter.update_executor_status(
124
+ ExecutorStatus.EXECUTOR_STATUS_STARTING_UP
125
+ )
126
+
127
+ if enable_grpc_state_reconciler:
128
+ self._state_reconciler = ExecutorStateReconciler(
129
+ executor_id=id,
130
+ function_executor_server_factory=self._function_executor_server_factory,
131
+ base_url=self._base_url,
132
+ function_executor_states=self._function_executor_states,
133
+ config_path=config_path,
134
+ downloader=self._downloader,
135
+ task_reporter=self._task_reporter,
136
+ channel_creator=self._channel_creator,
137
+ logger=self._logger,
138
+ )
139
+ else:
140
+ self._task_runner = TaskRunner(
108
141
  executor_id=id,
109
142
  function_executor_server_factory=function_executor_server_factory,
110
143
  base_url=self._base_url,
111
144
  function_executor_states=self._function_executor_states,
112
145
  config_path=config_path,
113
146
  )
114
- self._task_fetcher: Optional[TaskFetcher] = TaskFetcher(
147
+ self._task_fetcher = TaskFetcher(
115
148
  executor_id=id,
116
149
  executor_version=version,
117
150
  function_allowlist=function_allowlist,
@@ -122,11 +155,13 @@ class Executor:
122
155
 
123
156
  executor_info: Dict[str, str] = {
124
157
  "id": id,
158
+ "dev_mode": str(development_mode),
125
159
  "version": version,
126
160
  "code_path": str(code_path),
127
161
  "server_addr": server_addr,
128
162
  "config_path": str(config_path),
129
163
  "grpc_server_addr": str(grpc_server_addr),
164
+ "enable_grpc_state_reconciler": str(enable_grpc_state_reconciler),
130
165
  "hostname": gethostname(),
131
166
  }
132
167
  executor_info.update(function_allowlist_to_info_dict(function_allowlist))
@@ -146,84 +181,35 @@ class Executor:
146
181
  )
147
182
 
148
183
  asyncio.get_event_loop().create_task(self._monitoring_server.run())
184
+ if self._state_reporter is not None:
185
+ self._state_reporter.update_executor_status(
186
+ ExecutorStatus.EXECUTOR_STATUS_RUNNING
187
+ )
188
+ asyncio.get_event_loop().create_task(self._state_reporter.run())
149
189
 
150
- try:
151
- if self._grpc_server_addr is None:
152
- asyncio.get_event_loop().run_until_complete(self._http_mode_loop())
153
- else:
154
- asyncio.get_event_loop().run_until_complete(self._grpc_mode_loop())
155
- except asyncio.CancelledError:
156
- pass # Suppress this expected exception and return without error (normally).
157
-
158
- async def _grpc_mode_loop(self):
159
190
  metric_executor_state.state("running")
160
191
  self._startup_probe_handler.set_ready()
161
192
 
162
- while not self._is_shutdown:
163
- async with self._establish_grpc_server_channel() as server_channel:
164
- server_channel: grpc.aio.Channel
165
- await self._run_grpc_mode_services(server_channel)
166
- self._logger.warning(
167
- "grpc mode services exited, retrying in 5 seconds",
168
- )
169
- await asyncio.sleep(5)
170
-
171
- async def _establish_grpc_server_channel(self) -> grpc.aio.Channel:
172
193
  try:
173
- channel = grpc.aio.insecure_channel(self._grpc_server_addr)
174
- await asyncio.wait_for(
175
- channel.channel_ready(),
176
- timeout=EXECUTOR_GRPC_SERVER_READY_TIMEOUT_SEC,
177
- )
178
- return channel
179
- except Exception as e:
180
- self._logger.error("failed establishing grpc server channel", exc_info=e)
181
- raise
194
+ if self._state_reconciler is None:
195
+ asyncio.get_event_loop().run_until_complete(
196
+ self._http_task_runner_loop()
197
+ )
198
+ else:
199
+ asyncio.get_event_loop().run_until_complete(
200
+ self._grpc_state_reconciler_loop()
201
+ )
202
+ except asyncio.CancelledError:
203
+ pass # Suppress this expected exception and return without error (normally).
182
204
 
183
- async def _run_grpc_mode_services(self, server_channel: grpc.aio.Channel):
184
- """Runs the gRPC mode services.
205
+ async def _grpc_state_reconciler_loop(self):
206
+ """Runs the gRPC state reconciler and state reporter.
185
207
 
186
208
  Never raises any exceptions."""
187
- try:
188
- self._state_reporter = ExecutorStateReporter(
189
- executor_id=self._id,
190
- function_allowlist=self._function_allowlist,
191
- function_executor_states=self._function_executor_states,
192
- server_channel=server_channel,
193
- logger=self._logger,
194
- )
195
- self._state_reconciler = ExecutorStateReconciler(
196
- executor_id=self._id,
197
- function_executor_server_factory=self._function_executor_server_factory,
198
- base_url=self._base_url,
199
- function_executor_states=self._function_executor_states,
200
- config_path=self._config_path,
201
- downloader=self._downloader,
202
- task_reporter=self._task_reporter,
203
- server_channel=server_channel,
204
- logger=self._logger,
205
- )
209
+ asyncio.create_task(self._state_reporter.run())
210
+ await self._state_reconciler.run()
206
211
 
207
- # Task group ensures that:
208
- # 1. If one of the tasks fails then the other tasks are cancelled.
209
- # 2. If Executor shuts down then all the tasks are cancelled and this function returns.
210
- async with asyncio.TaskGroup() as tg:
211
- tg.create_task(self._state_reporter.run())
212
- tg.create_task(self._state_reconciler.run())
213
- except Exception as e:
214
- self._logger.error("failed running grpc mode services", exc_info=e)
215
- finally:
216
- # Handle task cancellation using finally.
217
- if self._state_reporter is not None:
218
- self._state_reporter.shutdown()
219
- self._state_reporter = None
220
- if self._state_reconciler is not None:
221
- self._state_reconciler.shutdown()
222
- self._state_reconciler = None
223
-
224
- async def _http_mode_loop(self):
225
- metric_executor_state.state("running")
226
- self._startup_probe_handler.set_ready()
212
+ async def _http_task_runner_loop(self):
227
213
  while not self._is_shutdown:
228
214
  try:
229
215
  async for task in self._task_fetcher.run():
@@ -341,6 +327,10 @@ class Executor:
341
327
 
342
328
  async def _shutdown(self, loop):
343
329
  self._logger.info("shutting down")
330
+ if self._state_reporter is not None:
331
+ self._state_reporter.update_executor_status(
332
+ ExecutorStatus.EXECUTOR_STATUS_STOPPING
333
+ )
344
334
  metric_executor_state.state("shutting_down")
345
335
  # There will be lots of task cancellation exceptions and "X is shutting down"
346
336
  # exceptions logged during Executor shutdown. Suppress their logs as they are
@@ -352,12 +342,13 @@ class Executor:
352
342
 
353
343
  if self._task_runner is not None:
354
344
  await self._task_runner.shutdown()
345
+
346
+ if self._channel_creator is not None:
347
+ await self._channel_creator.shutdown()
355
348
  if self._state_reporter is not None:
356
349
  await self._state_reporter.shutdown()
357
- self._state_reporter = None
358
350
  if self._state_reconciler is not None:
359
351
  await self._state_reconciler.shutdown()
360
- self._state_reconciler = None
361
352
 
362
353
  # We need to shutdown all users of FE states first,
363
354
  # otherwise states might disappear unexpectedly and we might
indexify/executor/function_executor/function_executor_state.py CHANGED
@@ -1,11 +1,11 @@
1
1
  import asyncio
2
- from typing import Optional
3
-
4
- from indexify.task_scheduler.proto.task_scheduler_pb2 import FunctionExecutorStatus
2
+ from typing import Any, List, Optional
5
3
 
6
4
  from .function_executor import FunctionExecutor
5
+ from .function_executor_status import FunctionExecutorStatus, is_status_change_allowed
7
6
  from .metrics.function_executor_state import (
8
7
  metric_function_executor_state_not_locked_errors,
8
+ metric_function_executors_with_status,
9
9
  )
10
10
 
11
11
 
@@ -25,6 +25,7 @@ class FunctionExecutorState:
25
25
  graph_version: str,
26
26
  function_name: str,
27
27
  image_uri: Optional[str],
28
+ logger: Any,
28
29
  ):
29
30
  # Read only fields.
30
31
  self.id: str = id
@@ -32,69 +33,68 @@ class FunctionExecutorState:
32
33
  self.graph_name: str = graph_name
33
34
  self.function_name: str = function_name
34
35
  self.image_uri: Optional[str] = image_uri
36
+ self._logger: Any = logger.bind(
37
+ module=__name__,
38
+ function_executor_id=id,
39
+ namespace=namespace,
40
+ graph_name=graph_name,
41
+ graph_version=graph_version,
42
+ function_name=function_name,
43
+ image_uri=image_uri,
44
+ )
35
45
  # The lock must be held while modifying the fields below.
36
46
  self.lock: asyncio.Lock = asyncio.Lock()
47
+ # TODO: Move graph_version to immutable fields once we migrate to gRPC State Reconciler.
37
48
  self.graph_version: str = graph_version
38
- self.is_shutdown: bool = False
39
- # Set to True if a Function Executor health check ever failed.
40
- self.health_check_failed: bool = False
41
- # TODO: remove fields that duplicate this status field.
42
- self.status: FunctionExecutorStatus = (
43
- FunctionExecutorStatus.FUNCTION_EXECUTOR_STATUS_STOPPED
44
- )
45
- self.function_executor: Optional[FunctionExecutor] = None
46
- self.running_tasks: int = 0
47
- self.running_tasks_change_notifier: asyncio.Condition = asyncio.Condition(
49
+ self.status: FunctionExecutorStatus = FunctionExecutorStatus.DESTROYED
50
+ self.status_change_notifier: asyncio.Condition = asyncio.Condition(
48
51
  lock=self.lock
49
52
  )
53
+ self.function_executor: Optional[FunctionExecutor] = None
54
+ metric_function_executors_with_status.labels(status=self.status.name).inc()
50
55
 
51
- def increment_running_tasks(self) -> None:
52
- """Increments the number of running tasks.
53
-
54
- The caller must hold the lock.
55
- """
56
- self.check_locked()
57
- self.running_tasks += 1
58
- self.running_tasks_change_notifier.notify_all()
59
-
60
- def decrement_running_tasks(self) -> None:
61
- """Decrements the number of running tasks.
56
+ async def wait_status(self, allowlist: List[FunctionExecutorStatus]) -> None:
57
+ """Waits until Function Executor status reaches one of the allowed values.
62
58
 
63
59
  The caller must hold the lock.
64
60
  """
65
61
  self.check_locked()
66
- self.running_tasks -= 1
67
- self.running_tasks_change_notifier.notify_all()
62
+ while self.status not in allowlist:
63
+ await self.status_change_notifier.wait()
68
64
 
69
- async def wait_running_tasks_less(self, value: int) -> None:
70
- """Waits until the number of running tasks is less than the supplied value.
65
+ async def set_status(self, new_status: FunctionExecutorStatus) -> None:
66
+ """Sets the status of the Function Executor.
71
67
 
72
68
  The caller must hold the lock.
69
+ Raises ValueError if the status change is not allowed.
73
70
  """
74
71
  self.check_locked()
75
- while self.running_tasks >= value:
76
- await self.running_tasks_change_notifier.wait()
72
+ if is_status_change_allowed(self.status, new_status):
73
+ self._logger.info(
74
+ "function executor status changed",
75
+ old_status=self.status.name,
76
+ new_status=new_status.name,
77
+ )
78
+ metric_function_executors_with_status.labels(status=self.status.name).dec()
79
+ metric_function_executors_with_status.labels(status=new_status.name).inc()
80
+ self.status = new_status
81
+ self.status_change_notifier.notify_all()
82
+ else:
83
+ raise ValueError(
84
+ f"Invalid status change from {self.status} to {new_status}"
85
+ )
77
86
 
78
87
  async def destroy_function_executor(self) -> None:
79
88
  """Destroys the Function Executor if it exists.
80
89
 
81
- The caller must hold the lock."""
90
+ The caller must hold the lock.
91
+ """
82
92
  self.check_locked()
93
+ await self.set_status(FunctionExecutorStatus.DESTROYING)
83
94
  if self.function_executor is not None:
84
95
  await self.function_executor.destroy()
85
96
  self.function_executor = None
86
-
87
- async def shutdown(self) -> None:
88
- """Shuts down the state.
89
-
90
- Called only during Executor shutdown so it's okay to fail all running and pending
91
- Function Executor tasks. The state is not valid anymore after this call.
92
- The caller must hold the lock.
93
- """
94
- self.check_locked()
95
- # Pending tasks will not create a new Function Executor and won't run.
96
- self.is_shutdown = True
97
- await self.destroy_function_executor()
97
+ await self.set_status(FunctionExecutorStatus.DESTROYED)
98
98
 
99
99
  def check_locked(self) -> None:
100
100
  """Raises an exception if the lock is not held."""
indexify/executor/function_executor/function_executor_states_container.py CHANGED
@@ -1,7 +1,8 @@
1
1
  import asyncio
2
- from typing import AsyncGenerator, Dict, Optional
2
+ from typing import Any, AsyncGenerator, Dict, Optional
3
3
 
4
4
  from .function_executor_state import FunctionExecutorState
5
+ from .function_executor_status import FunctionExecutorStatus
5
6
  from .metrics.function_executor_state_container import (
6
7
  metric_function_executor_states_count,
7
8
  )
@@ -10,11 +11,12 @@ from .metrics.function_executor_state_container import (
10
11
  class FunctionExecutorStatesContainer:
11
12
  """An asyncio concurrent container for the function executor states."""
12
13
 
13
- def __init__(self):
14
+ def __init__(self, logger: Any):
14
15
  # The fields below are protected by the lock.
15
16
  self._lock: asyncio.Lock = asyncio.Lock()
16
17
  self._states: Dict[str, FunctionExecutorState] = {}
17
18
  self._is_shutdown: bool = False
19
+ self._logger: Any = logger.bind(module=__name__)
18
20
 
19
21
  async def get_or_create_state(
20
22
  self,
@@ -43,6 +45,7 @@ class FunctionExecutorStatesContainer:
43
45
  graph_version=graph_version,
44
46
  function_name=function_name,
45
47
  image_uri=image_uri,
48
+ logger=self._logger,
46
49
  )
47
50
  self._states[id] = state
48
51
  metric_function_executor_states_count.set(len(self._states))
@@ -72,5 +75,8 @@ class FunctionExecutorStatesContainer:
72
75
  # Only ongoing tasks who have a reference to the state already can see it.
73
76
  # The state is unlocked while a task is running inside Function Executor.
74
77
  async with state.lock:
75
- await state.shutdown()
76
- # The task running inside the Function Executor will fail because it's destroyed.
78
+ await state.set_status(FunctionExecutorStatus.SHUTDOWN)
79
+ if state.function_executor is not None:
80
+ await state.function_executor.destroy()
81
+ state.function_executor = None
82
+ # The task running inside the Function Executor will fail because it's destroyed.