indexify 0.3.31-py3-none-any.whl → 0.4.2-py3-none-any.whl

This diff shows the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (74)
  1. indexify/cli/__init__.py +18 -0
  2. indexify/cli/build_image.py +51 -0
  3. indexify/cli/deploy.py +57 -0
  4. indexify/cli/executor.py +205 -0
  5. indexify/executor/{grpc/channel_manager.py → channel_manager.py} +17 -11
  6. indexify/executor/executor.py +57 -313
  7. indexify/executor/function_allowlist.py +59 -0
  8. indexify/executor/function_executor/function_executor.py +12 -6
  9. indexify/executor/function_executor/invocation_state_client.py +25 -3
  10. indexify/executor/function_executor/server/function_executor_server_factory.py +3 -3
  11. indexify/executor/function_executor/server/subprocess_function_executor_server_factory.py +22 -11
  12. indexify/executor/function_executor_controller/__init__.py +13 -0
  13. indexify/executor/function_executor_controller/completed_task_metrics.py +82 -0
  14. indexify/executor/function_executor_controller/create_function_executor.py +154 -0
  15. indexify/executor/function_executor_controller/debug_event_loop.py +37 -0
  16. indexify/executor/function_executor_controller/destroy_function_executor.py +28 -0
  17. indexify/executor/function_executor_controller/downloads.py +199 -0
  18. indexify/executor/function_executor_controller/events.py +172 -0
  19. indexify/executor/function_executor_controller/function_executor_controller.py +759 -0
  20. indexify/executor/function_executor_controller/loggers.py +57 -0
  21. indexify/executor/function_executor_controller/message_validators.py +65 -0
  22. indexify/executor/function_executor_controller/metrics/completed_task_metrics.py +68 -0
  23. indexify/executor/{metrics/downloader.py → function_executor_controller/metrics/downloads.py} +1 -3
  24. indexify/executor/function_executor_controller/metrics/function_executor_controller.py +60 -0
  25. indexify/executor/{function_executor/metrics/single_task_runner.py → function_executor_controller/metrics/run_task.py} +9 -3
  26. indexify/executor/function_executor_controller/metrics/upload_task_output.py +39 -0
  27. indexify/executor/function_executor_controller/prepare_task.py +38 -0
  28. indexify/executor/function_executor_controller/run_task.py +201 -0
  29. indexify/executor/function_executor_controller/task_info.py +33 -0
  30. indexify/executor/function_executor_controller/task_output.py +122 -0
  31. indexify/executor/function_executor_controller/upload_task_output.py +234 -0
  32. indexify/executor/host_resources/host_resources.py +20 -25
  33. indexify/executor/{grpc/metrics → metrics}/channel_manager.py +1 -1
  34. indexify/executor/metrics/executor.py +0 -47
  35. indexify/executor/{grpc/metrics → metrics}/state_reconciler.py +1 -1
  36. indexify/executor/{grpc/metrics → metrics}/state_reporter.py +1 -1
  37. indexify/executor/monitoring/health_checker/generic_health_checker.py +6 -59
  38. indexify/executor/monitoring/health_checker/health_checker.py +0 -11
  39. indexify/executor/{grpc/state_reconciler.py → state_reconciler.py} +139 -141
  40. indexify/executor/state_reporter.py +364 -0
  41. indexify/proto/executor_api.proto +67 -59
  42. indexify/proto/executor_api_pb2.py +52 -52
  43. indexify/proto/executor_api_pb2.pyi +125 -104
  44. indexify/proto/executor_api_pb2_grpc.py +0 -47
  45. {indexify-0.3.31.dist-info → indexify-0.4.2.dist-info}/METADATA +1 -3
  46. indexify-0.4.2.dist-info/RECORD +68 -0
  47. indexify-0.4.2.dist-info/entry_points.txt +3 -0
  48. indexify/cli/cli.py +0 -268
  49. indexify/executor/api_objects.py +0 -92
  50. indexify/executor/downloader.py +0 -417
  51. indexify/executor/executor_flavor.py +0 -7
  52. indexify/executor/function_executor/function_executor_state.py +0 -107
  53. indexify/executor/function_executor/function_executor_states_container.py +0 -93
  54. indexify/executor/function_executor/function_executor_status.py +0 -95
  55. indexify/executor/function_executor/metrics/function_executor_state.py +0 -46
  56. indexify/executor/function_executor/metrics/function_executor_state_container.py +0 -10
  57. indexify/executor/function_executor/single_task_runner.py +0 -345
  58. indexify/executor/function_executor/task_input.py +0 -21
  59. indexify/executor/function_executor/task_output.py +0 -105
  60. indexify/executor/grpc/function_executor_controller.py +0 -418
  61. indexify/executor/grpc/metrics/task_controller.py +0 -8
  62. indexify/executor/grpc/state_reporter.py +0 -317
  63. indexify/executor/grpc/task_controller.py +0 -508
  64. indexify/executor/metrics/task_fetcher.py +0 -21
  65. indexify/executor/metrics/task_reporter.py +0 -53
  66. indexify/executor/metrics/task_runner.py +0 -52
  67. indexify/executor/monitoring/function_allowlist.py +0 -25
  68. indexify/executor/runtime_probes.py +0 -68
  69. indexify/executor/task_fetcher.py +0 -96
  70. indexify/executor/task_reporter.py +0 -459
  71. indexify/executor/task_runner.py +0 -177
  72. indexify-0.3.31.dist-info/RECORD +0 -68
  73. indexify-0.3.31.dist-info/entry_points.txt +0 -3
  74. {indexify-0.3.31.dist-info → indexify-0.4.2.dist-info}/WHEEL +0 -0
@@ -1,417 +0,0 @@
- import asyncio
- import os
- from typing import Any, Optional, Union
-
- import httpx
- import nanoid
- from tensorlake.function_executor.proto.function_executor_pb2 import SerializedObject
- from tensorlake.function_executor.proto.message_validator import MessageValidator
- from tensorlake.utils.http_client import get_httpx_client
-
- from indexify.proto.executor_api_pb2 import DataPayload as DataPayloadProto
- from indexify.proto.executor_api_pb2 import DataPayloadEncoding
-
- from .api_objects import DataPayload
- from .blob_store.blob_store import BLOBStore
- from .metrics.downloader import (
-     metric_graph_download_errors,
-     metric_graph_download_latency,
-     metric_graph_downloads,
-     metric_graphs_from_cache,
-     metric_reducer_init_value_download_errors,
-     metric_reducer_init_value_download_latency,
-     metric_reducer_init_value_downloads,
-     metric_task_input_download_errors,
-     metric_task_input_download_latency,
-     metric_task_input_downloads,
-     metric_tasks_downloading_graphs,
-     metric_tasks_downloading_inputs,
-     metric_tasks_downloading_reducer_init_value,
- )
-
-
- class Downloader:
-     def __init__(
-         self,
-         code_path: str,
-         base_url: str,
-         blob_store: BLOBStore,
-         config_path: Optional[str] = None,
-     ):
-         self._code_path = code_path
-         self._base_url = base_url
-         self._client = get_httpx_client(config_path, make_async=True)
-         self._blob_store: BLOBStore = blob_store
-
-     async def download_graph(
-         self,
-         namespace: str,
-         graph_name: str,
-         graph_version: str,
-         data_payload: Optional[Union[DataPayload, DataPayloadProto]],
-         logger: Any,
-     ) -> SerializedObject:
-         logger = logger.bind(module=__name__)
-         with (
-             metric_graph_download_errors.count_exceptions(),
-             metric_tasks_downloading_graphs.track_inprogress(),
-             metric_graph_download_latency.time(),
-         ):
-             metric_graph_downloads.inc()
-             return await self._download_graph(
-                 namespace=namespace,
-                 graph_name=graph_name,
-                 graph_version=graph_version,
-                 data_payload=data_payload,
-                 logger=logger,
-             )
-
-     async def download_input(
-         self,
-         namespace: str,
-         graph_name: str,
-         graph_invocation_id: str,
-         input_key: str,
-         data_payload: Optional[DataPayload],
-         logger: Any,
-     ) -> SerializedObject:
-         logger = logger.bind(module=__name__)
-         with (
-             metric_task_input_download_errors.count_exceptions(),
-             metric_tasks_downloading_inputs.track_inprogress(),
-             metric_task_input_download_latency.time(),
-         ):
-             metric_task_input_downloads.inc()
-             return await self._download_input(
-                 namespace=namespace,
-                 graph_name=graph_name,
-                 graph_invocation_id=graph_invocation_id,
-                 input_key=input_key,
-                 data_payload=data_payload,
-                 logger=logger,
-             )
-
-     async def download_init_value(
-         self,
-         namespace: str,
-         graph_name: str,
-         function_name: str,
-         graph_invocation_id: str,
-         reducer_output_key: str,
-         data_payload: Optional[Union[DataPayload, DataPayloadProto]],
-         logger: Any,
-     ) -> SerializedObject:
-         logger = logger.bind(module=__name__)
-         with (
-             metric_reducer_init_value_download_errors.count_exceptions(),
-             metric_tasks_downloading_reducer_init_value.track_inprogress(),
-             metric_reducer_init_value_download_latency.time(),
-         ):
-             metric_reducer_init_value_downloads.inc()
-             return await self._download_init_value(
-                 namespace=namespace,
-                 graph_name=graph_name,
-                 function_name=function_name,
-                 graph_invocation_id=graph_invocation_id,
-                 reducer_output_key=reducer_output_key,
-                 data_payload=data_payload,
-                 logger=logger,
-             )
-
-     async def _download_graph(
-         self,
-         namespace: str,
-         graph_name: str,
-         graph_version: str,
-         data_payload: Optional[Union[DataPayload, DataPayloadProto]],
-         logger: Any,
-     ) -> SerializedObject:
-         # Cache graph to reduce load on the server.
-         graph_path = os.path.join(
-             self._code_path,
-             "graph_cache",
-             namespace,
-             graph_name,
-             graph_version,
-         )
-         # Filesystem operations are synchronous.
-         # Run in a separate thread to not block the main event loop.
-         graph: Optional[SerializedObject] = await asyncio.to_thread(
-             self._read_cached_graph, graph_path
-         )
-         if graph is not None:
-             metric_graphs_from_cache.inc()
-             return graph
-
-         if data_payload is None:
-             graph: SerializedObject = await self._fetch_graph_from_server(
-                 namespace=namespace,
-                 graph_name=graph_name,
-                 graph_version=graph_version,
-                 logger=logger,
-             )
-         elif isinstance(data_payload, DataPayloadProto):
-             (
-                 MessageValidator(data_payload)
-                 .required_field("uri")
-                 .required_field("encoding")
-             )
-             data: bytes = await self._blob_store.get(
-                 uri=data_payload.uri, logger=logger
-             )
-             return _serialized_object_from_data_payload_proto(
-                 data_payload=data_payload,
-                 data=data,
-             )
-         elif isinstance(data_payload, DataPayload):
-             data: bytes = await self._blob_store.get(
-                 uri=data_payload.path, logger=logger
-             )
-             return _serialized_object_from_data_payload(
-                 data_payload=data_payload,
-                 data=data,
-             )
-
-         # Filesystem operations are synchronous.
-         # Run in a separate thread to not block the main event loop.
-         # We don't need to wait for the write completion so we use create_task.
-         asyncio.create_task(
-             asyncio.to_thread(self._write_cached_graph, graph_path, graph),
-             name="graph cache write",
-         )
-
-         return graph
-
-     def _read_cached_graph(self, path: str) -> Optional[SerializedObject]:
-         if not os.path.exists(path):
-             return None
-
-         with open(path, "rb") as f:
-             return SerializedObject.FromString(f.read())
-
-     def _write_cached_graph(self, path: str, graph: SerializedObject) -> None:
-         if os.path.exists(path):
-             # Another task already cached the graph.
-             return None
-
-         tmp_path = os.path.join(self._code_path, "task_graph_cache", nanoid.generate())
-         os.makedirs(os.path.dirname(tmp_path), exist_ok=True)
-         with open(tmp_path, "wb") as f:
-             f.write(graph.SerializeToString())
-         os.makedirs(os.path.dirname(path), exist_ok=True)
-         # Atomically rename the fully written file at tmp path.
-         # This allows us to not use any locking because file link/unlink
-         # are atomic operations at filesystem level.
-         # This also allows to share the same cache between multiple Executors.
-         os.replace(tmp_path, path)
-
-     async def _download_input(
-         self,
-         namespace: str,
-         graph_name: str,
-         graph_invocation_id: str,
-         input_key: str,
-         data_payload: Optional[Union[DataPayload, DataPayloadProto]],
-         logger: Any,
-     ) -> SerializedObject:
-         if data_payload is None:
-             first_function_in_graph = graph_invocation_id == input_key.split("|")[-1]
-             if first_function_in_graph:
-                 # The first function in Graph gets its input from graph invocation payload.
-                 return await self._fetch_graph_invocation_payload_from_server(
-                     namespace=namespace,
-                     graph_name=graph_name,
-                     graph_invocation_id=graph_invocation_id,
-                     logger=logger,
-                 )
-             else:
-                 return await self._fetch_function_input_from_server(
-                     input_key=input_key, logger=logger
-                 )
-         elif isinstance(data_payload, DataPayloadProto):
-             (
-                 MessageValidator(data_payload)
-                 .required_field("uri")
-                 .required_field("encoding")
-             )
-             data: bytes = await self._blob_store.get(
-                 uri=data_payload.uri, logger=logger
-             )
-             return _serialized_object_from_data_payload_proto(
-                 data_payload=data_payload,
-                 data=data,
-             )
-         elif isinstance(data_payload, DataPayload):
-             data: bytes = await self._blob_store.get(
-                 uri=data_payload.path, logger=logger
-             )
-             return _serialized_object_from_data_payload(
-                 data_payload=data_payload,
-                 data=data,
-             )
-
-     async def _download_init_value(
-         self,
-         namespace: str,
-         graph_name: str,
-         function_name: str,
-         graph_invocation_id: str,
-         reducer_output_key: str,
-         data_payload: Optional[Union[DataPayload, DataPayloadProto]],
-         logger: Any,
-     ) -> SerializedObject:
-         if data_payload is None:
-             return await self._fetch_function_init_value_from_server(
-                 namespace=namespace,
-                 graph_name=graph_name,
-                 function_name=function_name,
-                 graph_invocation_id=graph_invocation_id,
-                 reducer_output_key=reducer_output_key,
-                 logger=logger,
-             )
-         elif isinstance(data_payload, DataPayloadProto):
-             (
-                 MessageValidator(data_payload)
-                 .required_field("uri")
-                 .required_field("encoding")
-             )
-             data: bytes = await self._blob_store.get(
-                 uri=data_payload.uri, logger=logger
-             )
-             return _serialized_object_from_data_payload_proto(
-                 data_payload=data_payload,
-                 data=data,
-             )
-         elif isinstance(data_payload, DataPayload):
-             data: bytes = await self._blob_store.get(
-                 uri=data_payload.path, logger=logger
-             )
-             return _serialized_object_from_data_payload(
-                 data_payload=data_payload,
-                 data=data,
-             )
-
-     async def _fetch_graph_from_server(
-         self, namespace: str, graph_name: str, graph_version: str, logger: Any
-     ) -> SerializedObject:
-         """Downloads the compute graph for the task and returns it."""
-         return await self._fetch_url(
-             url=f"{self._base_url}/internal/namespaces/{namespace}/compute_graphs/{graph_name}/versions/{graph_version}/code",
-             resource_description=f"compute graph: {graph_name}",
-             logger=logger,
-         )
-
-     async def _fetch_graph_invocation_payload_from_server(
-         self, namespace: str, graph_name: str, graph_invocation_id: str, logger: Any
-     ) -> SerializedObject:
-         return await self._fetch_url(
-             url=f"{self._base_url}/namespaces/{namespace}/compute_graphs/{graph_name}/invocations/{graph_invocation_id}/payload",
-             resource_description=f"graph invocation payload: {graph_invocation_id}",
-             logger=logger,
-         )
-
-     async def _fetch_function_input_from_server(
-         self, input_key: str, logger: Any
-     ) -> SerializedObject:
-         return await self._fetch_url(
-             url=f"{self._base_url}/internal/fn_outputs/{input_key}",
-             resource_description=f"function input: {input_key}",
-             logger=logger,
-         )
-
-     async def _fetch_function_init_value_from_server(
-         self,
-         namespace: str,
-         graph_name: str,
-         function_name: str,
-         graph_invocation_id: str,
-         reducer_output_key: str,
-         logger: Any,
-     ) -> SerializedObject:
-         return await self._fetch_url(
-             url=f"{self._base_url}/namespaces/{namespace}/compute_graphs/{graph_name}"
-             f"/invocations/{graph_invocation_id}/fn/{function_name}/output/{reducer_output_key}",
-             resource_description=f"reducer output: {reducer_output_key}",
-             logger=logger,
-         )
-
-     async def _fetch_url(
-         self, url: str, resource_description: str, logger: Any
-     ) -> SerializedObject:
-         logger.warning(
-             f"downloading resource from Server",
-             url=url,
-             resource_description=resource_description,
-         )
-         response: httpx.Response = await self._client.get(url)
-         try:
-             response.raise_for_status()
-         except httpx.HTTPStatusError as e:
-             logger.error(
-                 f"failed to download {resource_description}",
-                 error=response.text,
-                 exc_info=e,
-             )
-             raise
-
-         return serialized_object_from_http_response(response)
-
-
- def serialized_object_from_http_response(response: httpx.Response) -> SerializedObject:
-     # We're hardcoding the content type currently used by Python SDK. It might change in the future.
-     # There's no other way for now to determine if the response is a bytes or string.
-     if response.headers["content-type"] in [
-         "application/octet-stream",
-         "application/pickle",
-     ]:
-         return SerializedObject(
-             bytes=response.content, content_type=response.headers["content-type"]
-         )
-     else:
-         return SerializedObject(
-             string=response.text, content_type=response.headers["content-type"]
-         )
-
-
- def _serialized_object_from_data_payload(
-     data_payload: DataPayload, data: bytes
- ) -> SerializedObject:
-     """Converts the given data payload and its data into SerializedObject accepted by Function Executor."""
-     if data_payload.content_type in [
-         "application/octet-stream",
-         "application/pickle",
-     ]:
-         return SerializedObject(bytes=data, content_type=data_payload.content_type)
-     else:
-         return SerializedObject(
-             string=data.decode("utf-8"), content_type=data_payload.content_type
-         )
-
-
- def _serialized_object_from_data_payload_proto(
-     data_payload: DataPayloadProto, data: bytes
- ) -> SerializedObject:
-     """Converts the given data payload and its data into SerializedObject accepted by Function Executor.
-
-     Raises ValueError if the supplied data payload can't be converted into serialized object.
-     """
-     if data_payload.encoding == DataPayloadEncoding.DATA_PAYLOAD_ENCODING_BINARY_PICKLE:
-         return SerializedObject(
-             bytes=data,
-             content_type="application/octet-stream",
-         )
-     elif data_payload.encoding == DataPayloadEncoding.DATA_PAYLOAD_ENCODING_UTF8_TEXT:
-         return SerializedObject(
-             content_type="text/plain",
-             string=data.decode("utf-8"),
-         )
-     elif data_payload.encoding == DataPayloadEncoding.DATA_PAYLOAD_ENCODING_UTF8_JSON:
-         result = SerializedObject(
-             content_type="application/json",
-             string=data.decode("utf-8"),
-         )
-         return result
-
-     raise ValueError(
-         f"Can't convert data payload {data_payload} into serialized object"
-     )
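
Note: the removed `Downloader._write_cached_graph` above relies on a write-to-temp-then-rename trick so that multiple Executors can share one graph cache with no locking: `os.replace` is atomic within a filesystem, so a partially written entry is never visible under the final path. A minimal standalone sketch of that pattern (names are hypothetical and not part of this package; it uses `tempfile` where the original used `nanoid`):

import os
import tempfile

def write_cache_entry(cache_path: str, data: bytes) -> None:
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    # Write to a unique temp file first; readers can never observe a
    # half-written entry under cache_path.
    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(cache_path))
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
        # Atomic when source and destination are on the same filesystem,
        # so concurrent writers and readers need no locks.
        os.replace(tmp_path, cache_path)
    except BaseException:
        os.unlink(tmp_path)
        raise

Readers just `open(cache_path, "rb")`: if the file exists it is complete, which is why `_read_cached_graph` above needs no coordination.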
@@ -1,7 +0,0 @@
- from enum import Enum
-
-
- class ExecutorFlavor(Enum):
-     UNKNOWN = "unknown"
-     OSS = "oss"
-     PLATFORM = "platform"
@@ -1,107 +0,0 @@
- import asyncio
- from typing import Any, List, Optional
-
- from .function_executor import FunctionExecutor
- from .function_executor_status import FunctionExecutorStatus, is_status_change_allowed
- from .metrics.function_executor_state import (
-     metric_function_executor_state_not_locked_errors,
-     metric_function_executors_with_status,
- )
-
-
- class FunctionExecutorState:
-     """State of a Function Executor with a particular ID.
-
-     The Function Executor might not exist, i.e. not yet created or destroyed.
-     This object represents all such states. Any state modification must be done
-     under the lock.
-     """
-
-     def __init__(
-         self,
-         id: str,
-         namespace: str,
-         graph_name: str,
-         graph_version: str,
-         function_name: str,
-         image_uri: Optional[str],
-         secret_names: List[str],
-         logger: Any,
-     ):
-         # Read only fields.
-         self.id: str = id
-         self.namespace: str = namespace
-         self.graph_name: str = graph_name
-         self.function_name: str = function_name
-         self.image_uri: Optional[str] = image_uri
-         self.secret_names: List[str] = secret_names
-         self._logger: Any = logger.bind(
-             module=__name__,
-             function_executor_id=id,
-             namespace=namespace,
-             graph_name=graph_name,
-             graph_version=graph_version,
-             function_name=function_name,
-             image_uri=image_uri,
-         )
-         # The lock must be held while modifying the fields below.
-         self.lock: asyncio.Lock = asyncio.Lock()
-         # TODO: Move graph_version to immutable fields once we migrate to gRPC State Reconciler.
-         self.graph_version: str = graph_version
-         self.status: FunctionExecutorStatus = FunctionExecutorStatus.DESTROYED
-         self.status_change_notifier: asyncio.Condition = asyncio.Condition(
-             lock=self.lock
-         )
-         self.function_executor: Optional[FunctionExecutor] = None
-         metric_function_executors_with_status.labels(status=self.status.name).inc()
-
-     async def wait_status(self, allowlist: List[FunctionExecutorStatus]) -> None:
-         """Waits until Function Executor status reaches one of the allowed values.
-
-         The caller must hold the lock.
-         """
-         self.check_locked()
-         while self.status not in allowlist:
-             await self.status_change_notifier.wait()
-
-     async def set_status(self, new_status: FunctionExecutorStatus) -> None:
-         """Sets the status of the Function Executor.
-
-         The caller must hold the lock.
-         Raises ValueError if the status change is not allowed.
-         """
-         self.check_locked()
-         if is_status_change_allowed(self.status, new_status):
-             # If status didn't change then still log it for visibility.
-             self._logger.info(
-                 "function executor status changed",
-                 old_status=self.status.name,
-                 new_status=new_status.name,
-             )
-             metric_function_executors_with_status.labels(status=self.status.name).dec()
-             metric_function_executors_with_status.labels(status=new_status.name).inc()
-             self.status = new_status
-             self.status_change_notifier.notify_all()
-         else:
-             raise ValueError(
-                 f"Invalid status change from {self.status} to {new_status}"
-             )
-
-     # TODO: Delete this method once HTTP protocol is removed as it's used only there.
-     async def destroy_function_executor(self) -> None:
-         """Destroys the Function Executor if it exists.
-
-         The caller must hold the lock.
-         """
-         self.check_locked()
-         await self.set_status(FunctionExecutorStatus.DESTROYING)
-         if self.function_executor is not None:
-             await self.function_executor.destroy()
-             self.function_executor = None
-         await self.set_status(FunctionExecutorStatus.DESTROYED)
-
-     def check_locked(self) -> None:
-         """Raises an exception if the lock is not held."""
-         if not self.lock.locked():
-             metric_function_executor_state_not_locked_errors.inc()
-             raise RuntimeError("The FunctionExecutorState lock must be held.")
@@ -1,93 +0,0 @@
- import asyncio
- from typing import Any, AsyncGenerator, Dict, List, Optional
-
- from .function_executor_state import FunctionExecutorState
- from .function_executor_status import FunctionExecutorStatus
- from .metrics.function_executor_state_container import (
-     metric_function_executor_states_count,
- )
-
-
- class FunctionExecutorStatesContainer:
-     """An asyncio concurrent container for the function executor states."""
-
-     def __init__(self, logger: Any):
-         # The fields below are protected by the lock.
-         self._lock: asyncio.Lock = asyncio.Lock()
-         self._states: Dict[str, FunctionExecutorState] = {}
-         self._is_shutdown: bool = False
-         self._logger: Any = logger.bind(module=__name__)
-
-     async def get_or_create_state(
-         self,
-         id: str,
-         namespace: str,
-         graph_name: str,
-         graph_version: str,
-         function_name: str,
-         image_uri: Optional[str],
-         secret_names: List[str],
-     ) -> FunctionExecutorState:
-         """Get or create a function executor state with the given ID.
-
-         If the state already exists, it is returned. Otherwise, a new state is created from the supplied task.
-         Raises Exception if it's not possible to create a new state at this time."""
-         async with self._lock:
-             if self._is_shutdown:
-                 raise RuntimeError(
-                     "Function Executor states container is shutting down."
-                 )
-
-             if id not in self._states:
-                 state = FunctionExecutorState(
-                     id=id,
-                     namespace=namespace,
-                     graph_name=graph_name,
-                     graph_version=graph_version,
-                     function_name=function_name,
-                     image_uri=image_uri,
-                     secret_names=secret_names,
-                     logger=self._logger,
-                 )
-                 self._states[id] = state
-                 metric_function_executor_states_count.set(len(self._states))
-
-             return self._states[id]
-
-     async def get(self, id: str) -> FunctionExecutorState:
-         """Get the state with the given ID. Raises Exception if the state does not exist."""
-         async with self._lock:
-             return self._states[id]
-
-     async def __aiter__(self) -> AsyncGenerator[FunctionExecutorState, None]:
-         async with self._lock:
-             for state in self._states.values():
-                 yield state
-
-     async def pop(self, id: str) -> FunctionExecutorState:
-         """Removes the state with the given ID and returns it."""
-         async with self._lock:
-             state = self._states.pop(id)
-             metric_function_executor_states_count.set(len(self._states))
-             return state
-
-     def exists(self, id: str) -> bool:
-         """Check if the state with the given ID exists."""
-         return id in self._states
-
-     async def shutdown(self):
-         # Function Executors are outside the Executor process
-         # so they need to get cleaned up explicitly and reliably.
-         async with self._lock:
-             self._is_shutdown = True  # No new Function Executor States can be created.
-             while self._states:
-                 id, state = self._states.popitem()
-                 metric_function_executor_states_count.set(len(self._states))
-                 # Only ongoing tasks who have a reference to the state already can see it.
-                 # The state is unlocked while a task is running inside Function Executor.
-                 async with state.lock:
-                     await state.set_status(FunctionExecutorStatus.SHUTDOWN)
-                     if state.function_executor is not None:
-                         await state.function_executor.destroy()
-                         state.function_executor = None
-                     # The task running inside the Function Executor will fail because it's destroyed.
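
Note: the removed `shutdown` above is a drain-under-flag pattern: `_is_shutdown` is flipped before the drain loop, and `get_or_create_state` checks it under the same lock, so no new entry can appear while `popitem` empties the map. The same idea sketched in isolation (hypothetical names, trivial resources in place of Function Executors):

import asyncio
from typing import Dict

class ResourceContainer:
    def __init__(self):
        self._lock = asyncio.Lock()
        self._resources: Dict[str, object] = {}
        self._is_shutdown = False

    async def get_or_create(self, key: str) -> object:
        async with self._lock:
            if self._is_shutdown:
                # Rejecting creation here is what makes the drain in
                # shutdown() final rather than racy.
                raise RuntimeError("container is shutting down")
            return self._resources.setdefault(key, object())

    async def shutdown(self) -> None:
        async with self._lock:
            self._is_shutdown = True  # no new resources past this point
            while self._resources:
                _key, resource = self._resources.popitem()
                # Per-resource cleanup (e.g. destroying an external
                # process) would go here.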