isolate 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61):
  1. isolate/__init__.py +3 -0
  2. isolate/_isolate_version.py +34 -0
  3. isolate/_version.py +6 -0
  4. isolate/backends/__init__.py +2 -0
  5. isolate/backends/_base.py +132 -0
  6. isolate/backends/common.py +259 -0
  7. isolate/backends/conda.py +215 -0
  8. isolate/backends/container.py +64 -0
  9. isolate/backends/local.py +46 -0
  10. isolate/backends/pyenv.py +143 -0
  11. isolate/backends/remote.py +141 -0
  12. isolate/backends/settings.py +121 -0
  13. isolate/backends/virtualenv.py +204 -0
  14. isolate/common/__init__.py +0 -0
  15. isolate/common/timestamp.py +15 -0
  16. isolate/connections/__init__.py +21 -0
  17. isolate/connections/_local/__init__.py +2 -0
  18. isolate/connections/_local/_base.py +190 -0
  19. isolate/connections/_local/agent_startup.py +53 -0
  20. isolate/connections/common.py +121 -0
  21. isolate/connections/grpc/__init__.py +1 -0
  22. isolate/connections/grpc/_base.py +175 -0
  23. isolate/connections/grpc/agent.py +284 -0
  24. isolate/connections/grpc/configuration.py +23 -0
  25. isolate/connections/grpc/definitions/__init__.py +11 -0
  26. isolate/connections/grpc/definitions/agent.proto +18 -0
  27. isolate/connections/grpc/definitions/agent_pb2.py +29 -0
  28. isolate/connections/grpc/definitions/agent_pb2.pyi +44 -0
  29. isolate/connections/grpc/definitions/agent_pb2_grpc.py +68 -0
  30. isolate/connections/grpc/definitions/common.proto +49 -0
  31. isolate/connections/grpc/definitions/common_pb2.py +35 -0
  32. isolate/connections/grpc/definitions/common_pb2.pyi +152 -0
  33. isolate/connections/grpc/definitions/common_pb2_grpc.py +4 -0
  34. isolate/connections/grpc/interface.py +71 -0
  35. isolate/connections/ipc/__init__.py +5 -0
  36. isolate/connections/ipc/_base.py +225 -0
  37. isolate/connections/ipc/agent.py +205 -0
  38. isolate/logger.py +53 -0
  39. isolate/logs.py +76 -0
  40. isolate/py.typed +0 -0
  41. isolate/registry.py +53 -0
  42. isolate/server/__init__.py +1 -0
  43. isolate/server/definitions/__init__.py +13 -0
  44. isolate/server/definitions/server.proto +80 -0
  45. isolate/server/definitions/server_pb2.py +56 -0
  46. isolate/server/definitions/server_pb2.pyi +241 -0
  47. isolate/server/definitions/server_pb2_grpc.py +205 -0
  48. isolate/server/health/__init__.py +11 -0
  49. isolate/server/health/health.proto +23 -0
  50. isolate/server/health/health_pb2.py +32 -0
  51. isolate/server/health/health_pb2.pyi +66 -0
  52. isolate/server/health/health_pb2_grpc.py +99 -0
  53. isolate/server/health_server.py +40 -0
  54. isolate/server/interface.py +27 -0
  55. isolate/server/server.py +735 -0
  56. isolate-0.22.0.dist-info/METADATA +88 -0
  57. isolate-0.22.0.dist-info/RECORD +61 -0
  58. isolate-0.22.0.dist-info/WHEEL +5 -0
  59. isolate-0.22.0.dist-info/entry_points.txt +7 -0
  60. isolate-0.22.0.dist-info/licenses/LICENSE +201 -0
  61. isolate-0.22.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,735 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import os
5
+ import signal
6
+ import threading
7
+ import time
8
+ import traceback
9
+ import uuid
10
+ from argparse import ArgumentParser
11
+ from collections import defaultdict
12
+ from concurrent import futures
13
+ from concurrent.futures import ThreadPoolExecutor
14
+ from contextlib import ExitStack, contextmanager
15
+ from dataclasses import dataclass, field, replace
16
+ from queue import Empty as QueueEmpty
17
+ from queue import Queue
18
+ from typing import Any, Callable, Iterator, cast
19
+
20
+ import grpc
21
+ from grpc import ServicerContext, StatusCode
22
+ from grpc.experimental import wrap_server_method_handler
23
+
24
+ from isolate.backends import (
25
+ EnvironmentCreationError,
26
+ IsolateSettings,
27
+ )
28
+ from isolate.backends.common import active_python
29
+ from isolate.backends.local import LocalPythonEnvironment
30
+ from isolate.backends.virtualenv import VirtualPythonEnvironment
31
+ from isolate.connections.grpc import AgentError, LocalPythonGRPC
32
+ from isolate.connections.grpc.configuration import get_default_options
33
+ from isolate.logger import IsolateLogger
34
+ from isolate.logs import Log, LogLevel, LogSource
35
+ from isolate.server import definitions, health
36
+ from isolate.server.health_server import HealthServicer
37
+ from isolate.server.interface import from_grpc, to_grpc
38
+
39
+ EMPTY_MESSAGE_INTERVAL = float(os.getenv("ISOLATE_EMPTY_MESSAGE_INTERVAL", "600"))
40
+ SKIP_EMPTY_LOGS = os.getenv("ISOLATE_SKIP_EMPTY_LOGS") == "1"
41
+ MAX_GRPC_WAIT_TIMEOUT = float(os.getenv("ISOLATE_MAX_GRPC_WAIT_TIMEOUT", "10.0"))
42
+
43
+ # Whether to inherit all the packages from the current environment or not.
44
+ INHERIT_FROM_LOCAL = os.getenv("ISOLATE_INHERIT_FROM_LOCAL") == "1"
45
+
46
+ # Number of threads that the gRPC server will use.
47
+ MAX_THREADS = int(os.getenv("MAX_THREADS", "5"))
48
+ _AGENT_REQUIREMENTS_TXT = os.getenv("AGENT_REQUIREMENTS_TXT")
49
+
50
+ if _AGENT_REQUIREMENTS_TXT is not None:
51
+ with open(_AGENT_REQUIREMENTS_TXT) as stream:
52
+ AGENT_REQUIREMENTS = stream.read().splitlines()
53
+ else:
54
+ AGENT_REQUIREMENTS = []
55
+
56
+
57
+ # Number of seconds to observe the queue before checking the termination
58
+ # event.
59
+ _Q_WAIT_DELAY = 0.1
60
+
61
+
62
class GRPCException(Exception):
    """Error meant to be reported to the gRPC client with a status ``code``.

    ``message`` is kept verbatim (handlers forward it as status details),
    while ``str()`` prefixes it with the status code's name.
    """

    def __init__(self, message: str, code: StatusCode = StatusCode.INVALID_ARGUMENT):
        super().__init__(message)
        self.message, self.code = message, code

    def __str__(self) -> str:
        return "{}: {}".format(self.code.name, self.message)
70
+
71
+
72
@dataclass
class RunnerAgent:
    """A bridge (gRPC stub plus the resources keeping it open) to an agent.

    Every connectivity transition of the underlying channel is recorded so
    that :meth:`check_connectivity` can tell whether a pooled agent is still
    usable.
    """

    stub: definitions.AgentStub
    # Queue the agent's partial results are forwarded into. A reused (cached)
    # agent keeps the queue it was created with.
    message_queue: Queue[definitions.PartialRunResult]
    # Owns the bridge context; closing it tears down the gRPC channel.
    _bound_context: ExitStack
    # Connectivity states reported by the channel, oldest first.
    _channel_state_history: list[grpc.ChannelConnectivity] = field(default_factory=list)
    _connection: LocalPythonGRPC | None = None
    _terminated: bool = False

    def __post_init__(self) -> None:
        def switch_state(connectivity_update: grpc.ChannelConnectivity) -> None:
            self._channel_state_history.append(connectivity_update)

        self.channel.subscribe(switch_state)

    @property
    def channel(self) -> grpc.Channel:
        """The gRPC channel behind :attr:`stub`."""
        return self.stub._channel  # type: ignore

    @property
    def is_accessible(self) -> bool:
        """True iff the last reported channel state is READY."""
        try:
            last_known_state = self._channel_state_history[-1]
        except IndexError:
            last_known_state = None

        return last_known_state is grpc.ChannelConnectivity.READY

    def check_connectivity(self) -> bool:
        # Check whether the server is ready.
        # TODO: This is more of a hack rather than a guaranteed health check,
        # we might have to introduce the proper protocol to the agents as well
        # to make sure that they are ready to receive requests.
        return self.is_accessible

    def terminate(self) -> None:
        """
        Abort the agent first, then close the bound context.

        Closing the ExitStack tears down the gRPC channel; doing that before
        terminating the agent triggers an asyncio.CancelledError mid-request and
        the agent never receives SIGTERM. By aborting first we deliver SIGTERM
        while the connection is still alive, then close it.

        The bound context is closed even if aborting the agent raises; that
        exception is then propagated to the caller.
        """
        self._terminated = True
        try:
            if self._connection:
                self._connection.abort_agent()
        finally:
            # Never leak the channel/bridge just because the abort failed.
            self._bound_context.close()
120
+
121
+
122
@dataclass
class BridgeManager:
    """Pool of :class:`RunnerAgent` bridges keyed by their environment paths.

    :meth:`establish` hands out a live pooled agent when one exists (creating
    a new one otherwise) and returns it to the pool afterwards, so later runs
    against the same environment can reuse it.
    """

    # Guards every access to ``_agents``.
    _agent_access_lock: threading.Lock = field(default_factory=threading.Lock)
    _agents: dict[tuple[Any, ...], list[RunnerAgent]] = field(
        default_factory=lambda: defaultdict(list)
    )
    _stack: ExitStack = field(default_factory=ExitStack)

    @contextmanager
    def establish(
        self,
        connection: LocalPythonGRPC,
        queue: Queue,
    ) -> Iterator[RunnerAgent]:
        """Yield an agent for ``connection`` and pool it again on exit.

        ``queue`` is only used when a new agent has to be created; a pooled
        agent keeps its original message queue.
        """
        agent = self._allocate_new_agent(connection, queue)

        try:
            yield agent
        finally:
            self._cache_agent(connection, agent)

    def _cache_agent(
        self,
        connection: LocalPythonGRPC,
        agent: RunnerAgent,
    ) -> None:
        """Return ``agent`` to the pool unless it has been terminated."""
        if agent._terminated:
            # A terminated agent already closed its channel; pooling it would
            # only hand a dead bridge to the next run.
            return
        with self._agent_access_lock:
            self._agents[self._identify(connection)].append(agent)

    def _allocate_new_agent(
        self,
        connection: LocalPythonGRPC,
        queue: Queue,
    ) -> RunnerAgent:
        """Pop a pooled agent that is still connected, or create a new one.

        Pooled agents that fail the connectivity check are terminated.
        """
        with self._agent_access_lock:
            available_agents = self._agents[self._identify(connection)]
            while available_agents:
                agent = available_agents.pop()
                if agent.check_connectivity():
                    return agent
                else:
                    agent.terminate()

        bound_context = ExitStack()
        stub = bound_context.enter_context(
            connection._establish_bridge(max_wait_timeout=MAX_GRPC_WAIT_TIMEOUT)
        )
        return RunnerAgent(stub, queue, bound_context, [], connection)

    def _identify(self, connection: LocalPythonGRPC) -> tuple[Any, ...]:
        """Pool key: the primary environment path plus inherited paths."""
        return (
            connection.environment_path,
            *connection.extra_inheritance_paths,
        )

    def __enter__(self) -> BridgeManager:
        return self

    def __exit__(self, *exc_info: Any) -> None:
        """Terminate every pooled agent, empty the pool and close ``_stack``."""
        # Detach the agents under the lock, then terminate them without
        # holding it.
        with self._agent_access_lock:
            agents = [agent for pool in self._agents.values() for agent in pool]
            self._agents.clear()
        try:
            for agent in agents:
                agent.terminate()
        finally:
            self._stack.close()
184
+
185
+
186
@dataclass
class RunTask:
    """Server-side state of a single Run/Submit invocation."""

    request: definitions.BoundFunction
    # Set by Submit; tasks started through Run leave it as None.
    future: futures.Future | None = None
    # Assigned once a bridge to the agent has been established.
    agent: RunnerAgent | None = None
    logger: IsolateLogger = field(default_factory=IsolateLogger.from_env)

    def cancel(self) -> None:
        """Cancel the task and terminate its agent.

        A pending future is cancelled outright. A running one cannot be, so
        the agent is terminated and the future is polled (every 0.1s,
        terminating the agent again on each pass) until it settles. A task
        without a future has nothing to poll: its agent, if any, is
        terminated once and the method returns.
        """
        while True:
            # Cancelling a running future is not possible, and it sometimes blocks,
            # which means we never terminate the agent. So check if it's not running
            if self.future is not None and not self.future.running():
                self.future.cancel()

            if self.agent is not None:
                self.agent.terminate()

            if self.future is None:
                # Nothing to wait on; looping here would never terminate.
                return

            try:
                self.future.exception(timeout=0.1)
            except futures.TimeoutError:
                continue
            except futures.CancelledError:
                # The future was cancelled before it started: that is success.
                pass
            return

    @property
    def stream_logs(self) -> bool:
        """Whether logs should be streamed back to the client."""
        return self.request.stream_logs
213
+
214
+
215
@dataclass
class IsolateServicer(definitions.IsolateServicer):
    """gRPC servicer behind the isolate server.

    For each Run/Submit request it builds the requested environments,
    connects to an agent in the primary environment through
    ``bridge_manager`` and streams the agent's partial results back.
    Submitted tasks execute on ``_thread_pool`` and are tracked in
    ``background_tasks`` until their future settles.
    """

    bridge_manager: BridgeManager
    default_settings: IsolateSettings = field(default_factory=IsolateSettings)
    # In-flight tasks keyed by task id; the Run RPC uses the fixed key "RUN".
    background_tasks: dict[str, RunTask] = field(default_factory=dict)
    _shutting_down: bool = field(default=False)

    _thread_pool: futures.ThreadPoolExecutor = field(
        default_factory=lambda: futures.ThreadPoolExecutor(max_workers=MAX_THREADS)
    )

    def _run_task(self, task: RunTask) -> Iterator[definitions.PartialRunResult]:
        """Build the task's environments and run its function in an agent.

        Yields partial results as they become available. These are
        environment-creation logs, the agent's streamed results, and
        error-level log lines when the function call fails unexpectedly.

        Raises:
            GRPCException: if an environment definition is invalid, no
                environment was given, an environment fails to build, or
                the function call fails.
        """
        messages: Queue[definitions.PartialRunResult] = Queue()
        environments = []
        for env in task.request.environments:
            try:
                environments.append((env.force, from_grpc(env)))
            except ValueError as exc:
                raise GRPCException(f"Unknown environment kind: {env.kind}") from exc
            except TypeError as exc:
                raise GRPCException(f"Invalid environment: {str(exc)}") from exc

        if not environments:
            raise GRPCException(
                "At least one environment must be specified for a run!",
                StatusCode.INVALID_ARGUMENT,
            )

        # Environment logs go to the task's logger and, when the task streams
        # logs, into ``messages``.
        log_handler = LogHandler(messages, task=task)

        run_settings = replace(
            self.default_settings,
            log_hook=log_handler.handle,
            serialization_method=task.request.function.method,
        )

        for _, environment in environments:
            environment.apply_settings(run_settings)

        _, primary_environment = environments[0]

        if AGENT_REQUIREMENTS:
            python_version = getattr(
                primary_environment, "python_version", active_python()
            )
            agent_environ = VirtualPythonEnvironment(
                requirements=AGENT_REQUIREMENTS,
                python_version=python_version,
            )
            agent_environ.apply_settings(run_settings)
            # Placed right after the primary environment (which stays at index
            # 0) and never force-recreated.
            environments.insert(1, (False, agent_environ))

        extra_inheritance_paths = []
        if INHERIT_FROM_LOCAL:
            local_environment = LocalPythonEnvironment()
            extra_inheritance_paths.append(local_environment.create())

        with ThreadPoolExecutor(max_workers=1) as local_pool:
            environment_paths = []
            for should_force_create, environment in environments:
                future = local_pool.submit(
                    environment.create, force=should_force_create
                )
                # Stream creation logs while the environment is being built.
                yield from self.watch_queue_until_completed(messages, future.done)
                try:
                    # Assuming that the iterator above only stops yielding once
                    # the future is completed, the timeout here should be redundant
                    # but it is just in case.
                    environment_paths.append(future.result(timeout=0.1))
                except EnvironmentCreationError as e:
                    raise GRPCException(f"{e}", StatusCode.INVALID_ARGUMENT) from e

            primary_path, *inheritance_paths = environment_paths
            inheritance_paths.extend(extra_inheritance_paths)

            connection = LocalPythonGRPC(
                primary_environment,
                primary_path,
                extra_inheritance_paths=inheritance_paths,
            )

            with self.bridge_manager.establish(connection, queue=messages) as agent:
                task.agent = agent
                function_call = definitions.FunctionCall(
                    function=task.request.function,
                    setup_func=task.request.setup_func,
                )
                if not task.request.HasField("setup_func"):
                    function_call.ClearField("setup_func")

                future = local_pool.submit(
                    _proxy_to_queue,
                    # The agent may have been cached, so use the agent's message queue
                    queue=agent.message_queue,
                    bridge=agent.stub,
                    input=function_call,
                )

                # Unlike above; we are not interested in the result value of future
                # here, since it will be already transferred to other side without
                # us even seeing (through the queue).
                yield from self.watch_queue_until_completed(
                    agent.message_queue, future.done
                )

                # But we still have to check whether there were any errors raised
                # during the execution, and handle them accordingly.
                exception = future.exception(timeout=0.1)
                if exception is None:
                    return

                # If this is an RPC error, propagate it as is without any
                # further processing.
                if isinstance(exception, grpc.RpcError):
                    # on abort, we terminate the process before we close the channel
                    # because we need to populate SIGTERM to the agent process
                    if (
                        agent._terminated
                        and exception.code() == StatusCode.UNAVAILABLE
                    ):
                        return
                    raise GRPCException(
                        str(exception),
                        exception.code(),
                    )

                # Otherwise this is a bug in the agent itself, so needs
                # to be propagated with more details.
                for line in traceback.format_exception(
                    type(exception), exception, exception.__traceback__
                ):
                    yield from self.log(line, level=LogLevel.ERROR)
                if isinstance(exception, AgentError):
                    raise GRPCException(str(exception), StatusCode.ABORTED)
                raise GRPCException(
                    f"An unexpected error occurred: {exception}.",
                    StatusCode.UNKNOWN,
                )

    def _run_task_in_background(self, task: RunTask) -> None:
        """Drive ``_run_task`` to completion, discarding its partial results."""
        for _ in self._run_task(task):
            pass

    def Submit(
        self,
        request: definitions.SubmitRequest,
        context: ServicerContext,
    ) -> definitions.SubmitResponse:
        """Start ``request.function`` in the background and return its id.

        The task stays in ``background_tasks`` until its future settles,
        whether successfully, with an error, or by cancellation.
        """
        task = RunTask(request=request.function)
        self.set_metadata(task, request.metadata)

        task.future = self._thread_pool.submit(self._run_task_in_background, task)
        task_id = str(uuid.uuid4())

        print(f"Submitted a task {task_id}")

        self.background_tasks[task_id] = task

        def _callback(future: futures.Future) -> None:
            try:
                # future.exception() raises CancelledError on a cancelled
                # future, so that case must be handled first.
                if future.cancelled():
                    msg = f"Task {task_id} was cancelled"
                elif exc := future.exception():
                    msg = f"Task {task_id} finished with error: {exc!r}"
                else:
                    msg = f"Task {task_id} finished with result: {future.result()!r}"
                print(msg)
            finally:
                # Always forget the task, whatever happened above.
                self.background_tasks.pop(task_id, None)

        task.future.add_done_callback(_callback)

        return definitions.SubmitResponse(task_id=task_id)

    def SetMetadata(
        self,
        request: definitions.SetMetadataRequest,
        context: ServicerContext,
    ) -> definitions.SetMetadataResponse:
        """Replace the logger labels of a tracked task.

        Raises:
            GRPCException: with NOT_FOUND if the task is not tracked.
        """
        # NOTE(review): unlike Run, this handler does not convert the
        # GRPCException through abort_with_msg; unless something outside this
        # file does, the client will not actually see NOT_FOUND — confirm.
        task = self.background_tasks.get(request.task_id)
        if task is None:
            raise GRPCException(
                f"Task {request.task_id} not found.",
                StatusCode.NOT_FOUND,
            )

        self.set_metadata(task, request.metadata)

        return definitions.SetMetadataResponse()

    def set_metadata(self, task: RunTask, metadata: definitions.TaskMetadata) -> None:
        """Apply ``metadata.logger_labels`` as the task logger's extra labels."""
        task.logger.extra_labels = dict(metadata.logger_labels)

    def Run(
        self,
        request: definitions.BoundFunction,
        context: ServicerContext,
    ) -> Iterator[definitions.PartialRunResult]:
        """Run ``request`` in the foreground, streaming its partial results.

        A GRPCException ends the stream with its code and message set on
        ``context`` instead of propagating.
        """
        try:
            task = RunTask(request=request)

            # HACK: we can support only one task at a time
            # TODO: move away from this when we use submit for env-aware tasks
            self.background_tasks["RUN"] = task
            yield from self._run_task(task)
        except GRPCException as exc:
            self.abort_with_msg(
                exc.message,
                context,
                code=exc.code,
            )
        finally:
            self.background_tasks.pop("RUN", None)

    def List(
        self,
        request: definitions.ListRequest,
        context: ServicerContext,
    ) -> definitions.ListResponse:
        """List the ids of all tracked tasks."""
        # Snapshot the keys first: done-callbacks of submitted tasks pop
        # entries from background_tasks on other threads.
        task_ids = list(self.background_tasks)
        return definitions.ListResponse(
            tasks=[definitions.TaskInfo(task_id=task_id) for task_id in task_ids]
        )

    def Cancel(
        self,
        request: definitions.CancelRequest,
        context: ServicerContext,
    ) -> definitions.CancelResponse:
        """Cancel the given task; unknown task ids are ignored."""
        task_id = request.task_id

        print(f"Canceling task {task_id}")
        task = self.background_tasks.get(task_id)
        if task is not None:
            task.cancel()

        return definitions.CancelResponse()

    def shutdown(self) -> None:
        """Cancel every tracked task; repeated calls only print a notice."""
        if self._shutting_down:
            print("Shutdown already in progress...")
            return

        self._shutting_down = True
        task_count = len(self.background_tasks)
        print(f"Shutting down, canceling {task_count} tasks...")
        self.cancel_tasks()
        print("All tasks canceled.")

    def watch_queue_until_completed(
        self, queue: Queue, is_completed: Callable[[], bool]
    ) -> Iterator[definitions.PartialRunResult]:
        """Watch the given queue until the is_completed function returns True.
        Note that even if the function is completed, this function might not
        finish until the queue is empty.

        While waiting, an empty PartialRunResult is yielded whenever more than
        EMPTY_MESSAGE_INTERVAL seconds pass without one being sent.
        """

        timer = time.monotonic()
        while not is_completed():
            try:
                yield queue.get(timeout=_Q_WAIT_DELAY)
            except QueueEmpty:
                # Send an empty (but 'real') packet to the client, currently a hacky way
                # to make sure the stream results are never ignored.
                if time.monotonic() - timer > EMPTY_MESSAGE_INTERVAL:
                    timer = time.monotonic()
                    yield definitions.PartialRunResult(
                        is_complete=False,
                        logs=[],
                        result=None,
                    )

        # Clear the final messages
        while not queue.empty():
            try:
                yield queue.get_nowait()
            except QueueEmpty:
                continue

    def log(
        self,
        message: str,
        level: LogLevel = LogLevel.TRACE,
        source: LogSource = LogSource.BRIDGE,
    ) -> Iterator[definitions.PartialRunResult]:
        """Yield one PartialRunResult carrying ``message`` as a single log."""
        log = to_grpc(Log(message, level=level, source=source))
        log = cast(definitions.Log, log)
        yield definitions.PartialRunResult(result=None, is_complete=False, logs=[log])

    def abort_with_msg(
        self,
        message: str,
        context: ServicerContext,
        *,
        code: StatusCode = StatusCode.INVALID_ARGUMENT,
    ) -> None:
        """Set ``code`` and ``message`` on ``context`` (does not raise)."""
        context.set_code(code)
        context.set_details(message)
        return None

    def cancel_tasks(self):
        """Cancel every tracked task, iterating over a snapshot."""
        tasks_copy = self.background_tasks.copy()
        for task in tasks_copy.values():
            task.cancel()
518
+
519
+
520
+ def _proxy_to_queue(
521
+ queue: Queue,
522
+ bridge: definitions.AgentStub,
523
+ input: definitions.FunctionCall,
524
+ ) -> None:
525
+ for message in bridge.Run(input):
526
+ queue.put_nowait(message)
527
+
528
+
529
@dataclass
class LogHandler:
    """Log hook for one task.

    Every accepted log is written to the task's logger and, when the task
    streams logs, also queued for the client as a PartialRunResult.
    """

    messages: Queue
    # Held by reference so later changes to the task's logger (e.g. labels
    # set through SetMetadata) are picked up.
    task: RunTask

    def handle(self, log: Log) -> None:
        """Log ``log`` and queue it for streaming (blank logs may be skipped)."""
        if SKIP_EMPTY_LOGS and not log.message.strip():
            return
        self.task.logger.log(log.level, log.message, source=log.source)
        self._add_log_to_queue(log)

    def _add_log_to_queue(self, log: Log) -> None:
        """Queue ``log`` as a PartialRunResult when the task streams logs."""
        # With streaming disabled the log has already reached the logger;
        # nothing is queued for the client.
        if self.task.stream_logs:
            self.messages.put_nowait(
                definitions.PartialRunResult(
                    is_complete=False,
                    logs=[cast(definitions.Log, to_grpc(log))],
                    result=None,
                )
            )
553
+
554
+
555
@dataclass
class ServerBoundInterceptor(grpc.ServerInterceptor):
    """Interceptor that is bound to exactly one server and one servicer.

    Each can be registered only once; reading them back through the
    properties fails loudly if registration has not happened yet.
    """

    _server: grpc.Server | None = None
    _servicer: IsolateServicer | None = None

    def register_server(self, server: grpc.Server) -> None:
        """Bind ``server``; raises RuntimeError if one is already bound."""
        if self._server is None:
            self._server = server
            return
        raise RuntimeError("A server is already bound to this interceptor.")

    @property
    def server(self) -> grpc.Server:
        """The bound server; raises RuntimeError if none was registered."""
        bound_server = self._server
        if bound_server is None:
            raise RuntimeError("No server was bound to this interceptor.")
        return bound_server

    def register_servicer(self, servicer: IsolateServicer) -> None:
        """Bind ``servicer``; raises RuntimeError if one is already bound."""
        if self._servicer is None:
            self._servicer = servicer
            return
        raise RuntimeError("A servicer is already bound to this interceptor.")

    @property
    def servicer(self) -> IsolateServicer:
        """The bound servicer; raises RuntimeError if none was registered."""
        bound_servicer = self._servicer
        if bound_servicer is None:
            raise RuntimeError("No servicer was bound to this interceptor.")
        return bound_servicer
585
+
586
+
587
@dataclass
class SingleTaskInterceptor(ServerBoundInterceptor):
    """Sets server to terminate after the first Submit/Run task.

    The first ``/Isolate/Submit`` or ``/Isolate/Run`` call is let through and
    wrapped so that the servicer and server are shut down once that task is
    over. Every later Submit/Run call is answered with RESOURCE_EXHAUSTED,
    while other methods (List, Cancel, SetMetadata, ...) pass through.
    """

    _done: bool = False
    _task_id: str | None = None

    def __init__(self):
        # NOTE: an explicitly defined __init__ is kept by @dataclass instead
        # of the generated one, so the fields above (and the inherited
        # _server/_servicer) start out as their class-level defaults.
        def terminate(request: Any, context: grpc.ServicerContext) -> Any:
            context.abort(
                grpc.StatusCode.RESOURCE_EXHAUSTED,
                "Server has already served one Run/Submit task.",
            )

        self._terminator = grpc.unary_unary_rpc_method_handler(terminate)

    def _stop_server(self, reason: str) -> None:
        """Print ``reason``, shut the servicer down and stop the server."""
        print(reason)
        self.servicer.shutdown()
        self.server.stop(grace=0.1)
        print("Server stopped")

    def intercept_service(self, continuation, handler_call_details):
        handler = continuation(handler_call_details)

        is_submit = handler_call_details.method == "/Isolate/Submit"
        is_run = handler_call_details.method == "/Isolate/Run"

        if not (is_submit or is_run):
            # Let other requests like List/Cancel/etc pass through
            return handler

        if self._done:
            # Fail the request if the server has already served or is serving
            # a Run/Submit task.
            return self._terminator

        self._done = True
        # Set once the wrapped implementation has returned or raised, so the
        # Submit termination path never waits for a task id that will not
        # come.
        call_finished = threading.Event()

        def stop_after_submitted_task() -> None:
            """Stop the server once the submitted task's future settles."""
            call_finished.wait()
            if self._task_id is None:
                # Submit raised before returning a task id: there is nothing
                # to wait for, and further tasks would be rejected anyway.
                self._stop_server("Stopping server since the submission failed")
                return

            # Get the task from the background tasks
            task = self.servicer.background_tasks.get(self._task_id)
            if task is None:
                # Submit's done-callback removes finished tasks, so the task
                # has already completed and no stop callback would ever fire.
                self._stop_server("Stopping server since the task is finished")
                return

            # Wait until the task future is assigned
            tries = 0
            while task.future is None:
                time.sleep(0.1)
                tries += 1
                if tries > 100:
                    raise RuntimeError("Task future was not assigned in time.")

            def _stop(*args: Any) -> None:
                # Small sleep to make sure the cancellation is processed
                time.sleep(0.1)
                self._stop_server("Stopping server since the task is finished")

            # Add a callback which will stop the server
            # after the task is finished
            task.future.add_done_callback(_stop)

        def termination() -> None:
            if is_run:
                # Stop the server after the Run task is finished
                self._stop_server("Stopping server since run is finished")
            else:
                stop_after_submitted_task()

        def wrapper(method_impl):
            @functools.wraps(method_impl)
            def _wrapper(request: Any, context: grpc.ServicerContext) -> Any:
                context.add_callback(termination)
                try:
                    res = method_impl(request, context)
                    if is_submit:
                        self._task_id = cast(definitions.SubmitResponse, res).task_id
                    return res
                finally:
                    call_finished.set()

            return _wrapper

        return wrap_server_method_handler(wrapper, handler)
674
+
675
+
676
+ def main(argv: list[str] | None = None) -> None:
677
+ parser = ArgumentParser()
678
+ parser.add_argument("--host", default="0.0.0.0")
679
+ parser.add_argument("--port", type=int, default=50001)
680
+ parser.add_argument(
681
+ "--single-use",
682
+ action="store_true",
683
+ help="Terminate the server after the first Run or Submit task is completed.",
684
+ )
685
+ parser.add_argument(
686
+ "--num-workers",
687
+ type=int,
688
+ default=MAX_THREADS,
689
+ help="Number of worker threads to use for the gRPC server.",
690
+ )
691
+
692
+ options = parser.parse_args(argv)
693
+ if options.num_workers is None:
694
+ options.num_workers = 1 if options.single_use else os.cpu_count()
695
+
696
+ interceptors: list[ServerBoundInterceptor] = []
697
+ if options.single_use:
698
+ interceptors.append(SingleTaskInterceptor())
699
+
700
+ server = grpc.server(
701
+ futures.ThreadPoolExecutor(max_workers=options.num_workers),
702
+ options=get_default_options(),
703
+ interceptors=interceptors, # type: ignore
704
+ )
705
+
706
+ for interceptor in interceptors:
707
+ interceptor.register_server(server)
708
+
709
+ with BridgeManager() as bridge_manager:
710
+ servicer = IsolateServicer(bridge_manager)
711
+
712
+ for interceptor in interceptors:
713
+ interceptor.register_servicer(servicer)
714
+
715
+ definitions.register_isolate(servicer, server)
716
+ health.register_health(HealthServicer(), server)
717
+
718
+ def handle_termination(*args):
719
+ print("Termination signal received, shutting down...")
720
+ servicer.shutdown()
721
+ server.stop(grace=0.1)
722
+
723
+ signal.signal(signal.SIGINT, handle_termination)
724
+ signal.signal(signal.SIGTERM, handle_termination)
725
+
726
+ server.add_insecure_port(f"[::]:{options.port}")
727
+ print(f"Started listening at {options.host}:{options.port}")
728
+
729
+ server.start()
730
+ server.wait_for_termination()
731
+ print("Server shut down")
732
+
733
+
734
+ if __name__ == "__main__":
735
+ main()