isolate 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isolate/__init__.py +3 -0
- isolate/_isolate_version.py +34 -0
- isolate/_version.py +6 -0
- isolate/backends/__init__.py +2 -0
- isolate/backends/_base.py +132 -0
- isolate/backends/common.py +259 -0
- isolate/backends/conda.py +215 -0
- isolate/backends/container.py +64 -0
- isolate/backends/local.py +46 -0
- isolate/backends/pyenv.py +143 -0
- isolate/backends/remote.py +141 -0
- isolate/backends/settings.py +121 -0
- isolate/backends/virtualenv.py +204 -0
- isolate/common/__init__.py +0 -0
- isolate/common/timestamp.py +15 -0
- isolate/connections/__init__.py +21 -0
- isolate/connections/_local/__init__.py +2 -0
- isolate/connections/_local/_base.py +190 -0
- isolate/connections/_local/agent_startup.py +53 -0
- isolate/connections/common.py +121 -0
- isolate/connections/grpc/__init__.py +1 -0
- isolate/connections/grpc/_base.py +175 -0
- isolate/connections/grpc/agent.py +284 -0
- isolate/connections/grpc/configuration.py +23 -0
- isolate/connections/grpc/definitions/__init__.py +11 -0
- isolate/connections/grpc/definitions/agent.proto +18 -0
- isolate/connections/grpc/definitions/agent_pb2.py +29 -0
- isolate/connections/grpc/definitions/agent_pb2.pyi +44 -0
- isolate/connections/grpc/definitions/agent_pb2_grpc.py +68 -0
- isolate/connections/grpc/definitions/common.proto +49 -0
- isolate/connections/grpc/definitions/common_pb2.py +35 -0
- isolate/connections/grpc/definitions/common_pb2.pyi +152 -0
- isolate/connections/grpc/definitions/common_pb2_grpc.py +4 -0
- isolate/connections/grpc/interface.py +71 -0
- isolate/connections/ipc/__init__.py +5 -0
- isolate/connections/ipc/_base.py +225 -0
- isolate/connections/ipc/agent.py +205 -0
- isolate/logger.py +53 -0
- isolate/logs.py +76 -0
- isolate/py.typed +0 -0
- isolate/registry.py +53 -0
- isolate/server/__init__.py +1 -0
- isolate/server/definitions/__init__.py +13 -0
- isolate/server/definitions/server.proto +80 -0
- isolate/server/definitions/server_pb2.py +56 -0
- isolate/server/definitions/server_pb2.pyi +241 -0
- isolate/server/definitions/server_pb2_grpc.py +205 -0
- isolate/server/health/__init__.py +11 -0
- isolate/server/health/health.proto +23 -0
- isolate/server/health/health_pb2.py +32 -0
- isolate/server/health/health_pb2.pyi +66 -0
- isolate/server/health/health_pb2_grpc.py +99 -0
- isolate/server/health_server.py +40 -0
- isolate/server/interface.py +27 -0
- isolate/server/server.py +735 -0
- isolate-0.22.0.dist-info/METADATA +88 -0
- isolate-0.22.0.dist-info/RECORD +61 -0
- isolate-0.22.0.dist-info/WHEEL +5 -0
- isolate-0.22.0.dist-info/entry_points.txt +7 -0
- isolate-0.22.0.dist-info/licenses/LICENSE +201 -0
- isolate-0.22.0.dist-info/top_level.txt +1 -0
isolate/server/server.py
ADDED
@@ -0,0 +1,735 @@
from __future__ import annotations

import functools
import os
import signal
import threading
import time
import traceback
import uuid
from argparse import ArgumentParser
from collections import defaultdict
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack, contextmanager
from dataclasses import dataclass, field, replace
from queue import Empty as QueueEmpty
from queue import Queue
from typing import Any, Callable, Iterator, cast

import grpc
from grpc import ServicerContext, StatusCode
from grpc.experimental import wrap_server_method_handler

from isolate.backends import (
    EnvironmentCreationError,
    IsolateSettings,
)
from isolate.backends.common import active_python
from isolate.backends.local import LocalPythonEnvironment
from isolate.backends.virtualenv import VirtualPythonEnvironment
from isolate.connections.grpc import AgentError, LocalPythonGRPC
from isolate.connections.grpc.configuration import get_default_options
from isolate.logger import IsolateLogger
from isolate.logs import Log, LogLevel, LogSource
from isolate.server import definitions, health
from isolate.server.health_server import HealthServicer
from isolate.server.interface import from_grpc, to_grpc

EMPTY_MESSAGE_INTERVAL = float(os.getenv("ISOLATE_EMPTY_MESSAGE_INTERVAL", "600"))
SKIP_EMPTY_LOGS = os.getenv("ISOLATE_SKIP_EMPTY_LOGS") == "1"
MAX_GRPC_WAIT_TIMEOUT = float(os.getenv("ISOLATE_MAX_GRPC_WAIT_TIMEOUT", "10.0"))

# Whether to inherit all the packages from the current environment or not.
INHERIT_FROM_LOCAL = os.getenv("ISOLATE_INHERIT_FROM_LOCAL") == "1"

# Number of threads that the gRPC server will use.
MAX_THREADS = int(os.getenv("MAX_THREADS", "5"))
_AGENT_REQUIREMENTS_TXT = os.getenv("AGENT_REQUIREMENTS_TXT")

if _AGENT_REQUIREMENTS_TXT is not None:
    with open(_AGENT_REQUIREMENTS_TXT) as stream:
        AGENT_REQUIREMENTS = stream.read().splitlines()
else:
    AGENT_REQUIREMENTS = []


# Number of seconds to observe the queue before checking the termination
# event.
_Q_WAIT_DELAY = 0.1
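
# --- Editor's illustration (not part of the packaged file) ---
# The knobs above are read once, at import time, so they must be present in
# the process environment before this module is imported. A hypothetical
# launch setting a few of them (values chosen for illustration only):
#
#   ISOLATE_EMPTY_MESSAGE_INTERVAL=300 \
#   ISOLATE_MAX_GRPC_WAIT_TIMEOUT=5.0 \
#   ISOLATE_INHERIT_FROM_LOCAL=1 \
#   MAX_THREADS=10 \
#   python -m isolate.server.server --port 50001
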
class GRPCException(Exception):
    def __init__(self, message: str, code: StatusCode = StatusCode.INVALID_ARGUMENT):
        super().__init__(message)
        self.message = message
        self.code = code

    def __str__(self) -> str:
        return f"{self.code.name}: {self.message}"


@dataclass
class RunnerAgent:
    stub: definitions.AgentStub
    message_queue: Queue[definitions.PartialRunResult]
    _bound_context: ExitStack
    _channel_state_history: list[grpc.ChannelConnectivity] = field(default_factory=list)
    _connection: LocalPythonGRPC | None = None
    _terminated: bool = False

    def __post_init__(self):
        def switch_state(connectivity_update: grpc.ChannelConnectivity) -> None:
            self._channel_state_history.append(connectivity_update)

        self.channel.subscribe(switch_state)

    @property
    def channel(self) -> grpc.Channel:
        return self.stub._channel  # type: ignore

    @property
    def is_accessible(self) -> bool:
        try:
            last_known_state = self._channel_state_history[-1]
        except IndexError:
            last_known_state = None

        return last_known_state is grpc.ChannelConnectivity.READY

    def check_connectivity(self) -> bool:
        # Check whether the server is ready.
        # TODO: This is more of a hack than a guaranteed health check;
        # we might have to introduce a proper protocol on the agents as well
        # to make sure that they are ready to receive requests.
        return self.is_accessible

    def terminate(self) -> None:
        """
        Abort the agent first, then close the bound context.

        Closing the ExitStack tears down the gRPC channel; doing that before
        terminating the agent triggers an asyncio.CancelledError mid-request and
        the agent never receives SIGTERM. By aborting first we deliver SIGTERM
        while the connection is still alive, then close it.
        """
        self._terminated = True
        if self._connection:
            self._connection.abort_agent()
        self._bound_context.close()
@dataclass
class BridgeManager:
    _agent_access_lock: threading.Lock = field(default_factory=threading.Lock)
    _agents: dict[tuple[Any, ...], list[RunnerAgent]] = field(
        default_factory=lambda: defaultdict(list)
    )
    _stack: ExitStack = field(default_factory=ExitStack)

    @contextmanager
    def establish(
        self,
        connection: LocalPythonGRPC,
        queue: Queue,
    ) -> Iterator[RunnerAgent]:
        agent = self._allocate_new_agent(connection, queue)

        try:
            yield agent
        finally:
            self._cache_agent(connection, agent)

    def _cache_agent(
        self,
        connection: LocalPythonGRPC,
        agent: RunnerAgent,
    ) -> None:
        with self._agent_access_lock:
            self._agents[self._identify(connection)].append(agent)

    def _allocate_new_agent(
        self,
        connection: LocalPythonGRPC,
        queue: Queue,
    ) -> RunnerAgent:
        with self._agent_access_lock:
            available_agents = self._agents[self._identify(connection)]
            while available_agents:
                agent = available_agents.pop()
                if agent.check_connectivity():
                    return agent
                else:
                    agent.terminate()

        bound_context = ExitStack()
        stub = bound_context.enter_context(
            connection._establish_bridge(max_wait_timeout=MAX_GRPC_WAIT_TIMEOUT)
        )
        return RunnerAgent(stub, queue, bound_context, [], connection)

    def _identify(self, connection: LocalPythonGRPC) -> tuple[Any, ...]:
        return (
            connection.environment_path,
            *connection.extra_inheritance_paths,
        )

    def __enter__(self) -> BridgeManager:
        return self

    def __exit__(self, *exc_info: Any) -> None:
        for agents in self._agents.values():
            for agent in agents:
                agent.terminate()
@dataclass
class RunTask:
    request: definitions.BoundFunction
    future: futures.Future | None = None
    agent: RunnerAgent | None = None
    logger: IsolateLogger = field(default_factory=IsolateLogger.from_env)

    def cancel(self):
        while True:
            # Cancelling a running future is not possible (and it sometimes
            # blocks, which means we would never terminate the agent), so only
            # cancel the future while it is not running.
            if self.future and not self.future.running():
                self.future.cancel()

            if self.agent:
                self.agent.terminate()

            try:
                if self.future:
                    self.future.exception(timeout=0.1)
                return
            except futures.TimeoutError:
                pass

    @property
    def stream_logs(self) -> bool:
        return self.request.stream_logs
@dataclass
class IsolateServicer(definitions.IsolateServicer):
    bridge_manager: BridgeManager
    default_settings: IsolateSettings = field(default_factory=IsolateSettings)
    background_tasks: dict[str, RunTask] = field(default_factory=dict)
    _shutting_down: bool = field(default=False)

    _thread_pool: futures.ThreadPoolExecutor = field(
        default_factory=lambda: futures.ThreadPoolExecutor(max_workers=MAX_THREADS)
    )

    def _run_task(self, task: RunTask) -> Iterator[definitions.PartialRunResult]:
        messages: Queue[definitions.PartialRunResult] = Queue()
        environments = []
        for env in task.request.environments:
            try:
                environments.append((env.force, from_grpc(env)))
            except ValueError:
                raise GRPCException(f"Unknown environment kind: {env.kind}")
            except TypeError as exc:
                raise GRPCException(f"Invalid environment: {str(exc)}")

        if not environments:
            raise GRPCException(
                "At least one environment must be specified for a run!",
                StatusCode.INVALID_ARGUMENT,
            )

        log_handler = LogHandler(messages, task=task)

        run_settings = replace(
            self.default_settings,
            log_hook=log_handler.handle,
            serialization_method=task.request.function.method,
        )

        for _, environment in environments:
            environment.apply_settings(run_settings)

        _, primary_environment = environments[0]

        if AGENT_REQUIREMENTS:
            python_version = getattr(
                primary_environment, "python_version", active_python()
            )
            agent_environ = VirtualPythonEnvironment(
                requirements=AGENT_REQUIREMENTS,
                python_version=python_version,
            )
            agent_environ.apply_settings(run_settings)
            environments.insert(1, (False, agent_environ))

        extra_inheritance_paths = []
        if INHERIT_FROM_LOCAL:
            local_environment = LocalPythonEnvironment()
            extra_inheritance_paths.append(local_environment.create())

        with ThreadPoolExecutor(max_workers=1) as local_pool:
            environment_paths = []
            for should_force_create, environment in environments:
                future = local_pool.submit(
                    environment.create, force=should_force_create
                )
                yield from self.watch_queue_until_completed(messages, future.done)
                try:
                    # Assuming that the iterator above only stops yielding once
                    # the future is completed, the timeout here should be
                    # redundant, but it is there just in case.
                    environment_paths.append(future.result(timeout=0.1))
                except EnvironmentCreationError as e:
                    raise GRPCException(f"{e}", StatusCode.INVALID_ARGUMENT)

            primary_path, *inheritance_paths = environment_paths
            inheritance_paths.extend(extra_inheritance_paths)
            _, primary_environment = environments[0]

            connection = LocalPythonGRPC(
                primary_environment,
                primary_path,
                extra_inheritance_paths=inheritance_paths,
            )

            with self.bridge_manager.establish(connection, queue=messages) as agent:
                task.agent = agent
                function_call = definitions.FunctionCall(
                    function=task.request.function,
                    setup_func=task.request.setup_func,
                )
                if not task.request.HasField("setup_func"):
                    function_call.ClearField("setup_func")

                future = local_pool.submit(
                    _proxy_to_queue,
                    # The agent may have been cached, so use the agent's own
                    # message queue.
                    queue=agent.message_queue,
                    bridge=agent.stub,
                    input=function_call,
                )

                # Unlike above, we are not interested in the result value of
                # the future here, since it is transferred to the other side
                # (through the queue) without us ever seeing it.
                yield from self.watch_queue_until_completed(
                    agent.message_queue, future.done
                )

                # But we still have to check whether there were any errors raised
                # during the execution, and handle them accordingly.
                exception = future.exception(timeout=0.1)
                if exception is not None:
                    # If this is an RPC error, propagate it as is without any
                    # further processing.
                    if isinstance(exception, grpc.RpcError):
                        # On abort, we terminate the process before we close the
                        # channel, because SIGTERM has to propagate to the agent
                        # process while the connection is still alive.
                        if (
                            agent._terminated
                            and exception.code() == StatusCode.UNAVAILABLE
                        ):
                            return
                        raise GRPCException(
                            str(exception),
                            exception.code(),
                        )

                    # Otherwise this is a bug in the agent itself, so it needs
                    # to be propagated with more details.
                    for line in traceback.format_exception(
                        type(exception), exception, exception.__traceback__
                    ):
                        yield from self.log(line, level=LogLevel.ERROR)
                    if isinstance(exception, AgentError):
                        raise GRPCException(str(exception), StatusCode.ABORTED)
                    else:
                        raise GRPCException(
                            f"An unexpected error occurred: {exception}.",
                            StatusCode.UNKNOWN,
                        )
    def _run_task_in_background(self, task: RunTask) -> None:
        for _ in self._run_task(task):
            pass

    def Submit(
        self,
        request: definitions.SubmitRequest,
        context: ServicerContext,
    ) -> definitions.SubmitResponse:
        task = RunTask(request=request.function)
        self.set_metadata(task, request.metadata)

        task.future = self._thread_pool.submit(self._run_task_in_background, task)
        task_id = str(uuid.uuid4())

        print(f"Submitted a task {task_id}")

        self.background_tasks[task_id] = task

        def _callback(future: futures.Future) -> None:
            msg = f"Task {task_id} finished with"
            if exc := future.exception():
                msg += f" error: {exc!r}"
            else:
                msg += f" result: {future.result()!r}"
            print(msg)
            self.background_tasks.pop(task_id, None)

        task.future.add_done_callback(_callback)

        return definitions.SubmitResponse(task_id=task_id)

    def SetMetadata(
        self,
        request: definitions.SetMetadataRequest,
        context: ServicerContext,
    ) -> definitions.SetMetadataResponse:
        if request.task_id not in self.background_tasks:
            raise GRPCException(
                f"Task {request.task_id} not found.",
                StatusCode.NOT_FOUND,
            )

        self.set_metadata(self.background_tasks[request.task_id], request.metadata)

        return definitions.SetMetadataResponse()

    def set_metadata(self, task: RunTask, metadata: definitions.TaskMetadata) -> None:
        task.logger.extra_labels = dict(metadata.logger_labels)

    def Run(
        self,
        request: definitions.BoundFunction,
        context: ServicerContext,
    ) -> Iterator[definitions.PartialRunResult]:
        try:
            task = RunTask(request=request)

            # HACK: we can support only one task at a time.
            # TODO: move away from this when we use Submit for env-aware tasks.
            self.background_tasks["RUN"] = task
            yield from self._run_task(task)
        except GRPCException as exc:
            return self.abort_with_msg(
                exc.message,
                context,
                code=exc.code,
            )
        finally:
            self.background_tasks.pop("RUN", None)

    def List(
        self,
        request: definitions.ListRequest,
        context: ServicerContext,
    ) -> definitions.ListResponse:
        return definitions.ListResponse(
            tasks=[
                definitions.TaskInfo(task_id=task_id)
                for task_id in self.background_tasks.keys()
            ]
        )

    def Cancel(
        self,
        request: definitions.CancelRequest,
        context: ServicerContext,
    ) -> definitions.CancelResponse:
        task_id = request.task_id

        print(f"Canceling task {task_id}")
        task = self.background_tasks.get(task_id)
        if task is not None:
            task.cancel()

        return definitions.CancelResponse()

    def shutdown(self) -> None:
        if self._shutting_down:
            print("Shutdown already in progress...")
            return

        self._shutting_down = True
        task_count = len(self.background_tasks)
        print(f"Shutting down, canceling {task_count} tasks...")
        self.cancel_tasks()
        print("All tasks canceled.")

    def watch_queue_until_completed(
        self, queue: Queue, is_completed: Callable[[], bool]
    ) -> Iterator[definitions.PartialRunResult]:
        """Watch the given queue until the is_completed function returns True.

        Note that even if the function is completed, this function might not
        finish until the queue is empty.
        """

        timer = time.monotonic()
        while not is_completed():
            try:
                yield queue.get(timeout=_Q_WAIT_DELAY)
            except QueueEmpty:
                # Send an empty (but 'real') packet to the client; currently a
                # hacky way to make sure the stream results are never ignored.
                if time.monotonic() - timer > EMPTY_MESSAGE_INTERVAL:
                    timer = time.monotonic()
                    yield definitions.PartialRunResult(
                        is_complete=False,
                        logs=[],
                        result=None,
                    )

        # Clear the final messages.
        while not queue.empty():
            try:
                yield queue.get_nowait()
            except QueueEmpty:
                continue
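    # --- Editor's illustration (not part of the packaged file) ---
    # Client-side view of the keep-alive behavior above: a consumer of the
    # streamed PartialRunResult packets may receive ones that carry no logs
    # and no result (emitted every EMPTY_MESSAGE_INTERVAL seconds) and should
    # simply skip them. Hypothetical sketch; the names below are not taken
    # from this file:
    #
    #   for partial in run_stream:
    #       if not partial.logs and not partial.is_complete:
    #           continue  # keep-alive packet
    #       handle(partial)
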
    def log(
        self,
        message: str,
        level: LogLevel = LogLevel.TRACE,
        source: LogSource = LogSource.BRIDGE,
    ) -> Iterator[definitions.PartialRunResult]:
        log = to_grpc(Log(message, level=level, source=source))
        log = cast(definitions.Log, log)
        yield definitions.PartialRunResult(result=None, is_complete=False, logs=[log])

    def abort_with_msg(
        self,
        message: str,
        context: ServicerContext,
        *,
        code: StatusCode = StatusCode.INVALID_ARGUMENT,
    ) -> None:
        context.set_code(code)
        context.set_details(message)
        return None

    def cancel_tasks(self):
        tasks_copy = self.background_tasks.copy()
        for task in tasks_copy.values():
            task.cancel()


def _proxy_to_queue(
    queue: Queue,
    bridge: definitions.AgentStub,
    input: definitions.FunctionCall,
) -> None:
    for message in bridge.Run(input):
        queue.put_nowait(message)


@dataclass
class LogHandler:
    messages: Queue
    # Reference to the task so we can change the logger.
    task: RunTask

    def handle(self, log: Log) -> None:
        if not SKIP_EMPTY_LOGS or log.message.strip():
            self.task.logger.log(log.level, log.message, source=log.source)
            self._add_log_to_queue(log)

    def _add_log_to_queue(self, log: Log) -> None:
        if not self.task.stream_logs:
            # We do not queue the logs if stream_logs is disabled,
            # but we still send them to the logger.
            return

        grpc_log = cast(definitions.Log, to_grpc(log))
        grpc_result = definitions.PartialRunResult(
            is_complete=False,
            logs=[grpc_log],
            result=None,
        )
        self.messages.put_nowait(grpc_result)
@dataclass
class ServerBoundInterceptor(grpc.ServerInterceptor):
    _server: grpc.Server | None = None
    _servicer: IsolateServicer | None = None

    def register_server(self, server: grpc.Server) -> None:
        if self._server is not None:
            raise RuntimeError("A server is already bound to this interceptor.")

        self._server = server

    @property
    def server(self) -> grpc.Server:
        if self._server is None:
            raise RuntimeError("No server was bound to this interceptor.")

        return self._server

    def register_servicer(self, servicer: IsolateServicer) -> None:
        if self._servicer is not None:
            raise RuntimeError("A servicer is already bound to this interceptor.")

        self._servicer = servicer

    @property
    def servicer(self) -> IsolateServicer:
        if self._servicer is None:
            raise RuntimeError("No servicer was bound to this interceptor.")

        return self._servicer
@dataclass
class SingleTaskInterceptor(ServerBoundInterceptor):
    """Sets the server to terminate after the first Submit/Run task."""

    _done: bool = False
    _task_id: str | None = None

    def __init__(self):
        def terminate(request: Any, context: grpc.ServicerContext) -> Any:
            context.abort(
                grpc.StatusCode.RESOURCE_EXHAUSTED,
                "Server has already served one Run/Submit task.",
            )

        self._terminator = grpc.unary_unary_rpc_method_handler(terminate)

    def intercept_service(self, continuation, handler_call_details):
        handler = continuation(handler_call_details)

        is_submit = handler_call_details.method == "/Isolate/Submit"
        is_run = handler_call_details.method == "/Isolate/Run"
        is_new_task = is_submit or is_run

        if not is_new_task:
            # Let other requests like List/Cancel/etc. pass through.
            return handler

        if self._done:
            # Fail the request if the server has already served or is serving
            # a Run/Submit task.
            return self._terminator

        self._done = True

        def wrapper(method_impl):
            @functools.wraps(method_impl)
            def _wrapper(request: Any, context: grpc.ServicerContext) -> Any:
                def termination() -> None:
                    if is_run:
                        print("Stopping server since run is finished")
                        self.servicer.shutdown()
                        # Stop the server after the Run task is finished.
                        self.server.stop(grace=0.1)
                        print("Server stopped")

                    elif is_submit:
                        # Wait until the task_id is assigned.
                        while self._task_id is None:
                            time.sleep(0.1)

                        # Get the task from the background tasks.
                        task = self.servicer.background_tasks.get(self._task_id)

                        if task is not None:
                            # Wait until the task future is assigned.
                            tries = 0
                            while task.future is None:
                                time.sleep(0.1)
                                tries += 1
                                if tries > 100:
                                    raise RuntimeError(
                                        "Task future was not assigned in time."
                                    )

                            def _stop(*args):
                                # Small sleep to make sure the cancellation
                                # is processed.
                                time.sleep(0.1)
                                print("Stopping server since the task is finished")
                                self.servicer.shutdown()
                                self.server.stop(grace=0.1)
                                print("Server stopped")

                            # Add a callback which will stop the server
                            # after the task is finished.
                            task.future.add_done_callback(_stop)

                context.add_callback(termination)
                res = method_impl(request, context)

                if is_submit:
                    self._task_id = cast(definitions.SubmitResponse, res).task_id

                return res

            return _wrapper

        return wrap_server_method_handler(wrapper, handler)
def main(argv: list[str] | None = None) -> None:
    parser = ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=50001)
    parser.add_argument(
        "--single-use",
        action="store_true",
        help="Terminate the server after the first Run or Submit task is completed.",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=MAX_THREADS,
        help="Number of worker threads to use for the gRPC server.",
    )

    options = parser.parse_args(argv)
    if options.num_workers is None:
        options.num_workers = 1 if options.single_use else os.cpu_count()

    interceptors: list[ServerBoundInterceptor] = []
    if options.single_use:
        interceptors.append(SingleTaskInterceptor())

    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=options.num_workers),
        options=get_default_options(),
        interceptors=interceptors,  # type: ignore
    )

    for interceptor in interceptors:
        interceptor.register_server(server)

    with BridgeManager() as bridge_manager:
        servicer = IsolateServicer(bridge_manager)

        for interceptor in interceptors:
            interceptor.register_servicer(servicer)

        definitions.register_isolate(servicer, server)
        health.register_health(HealthServicer(), server)

        def handle_termination(*args):
            print("Termination signal received, shutting down...")
            servicer.shutdown()
            server.stop(grace=0.1)

        signal.signal(signal.SIGINT, handle_termination)
        signal.signal(signal.SIGTERM, handle_termination)

        server.add_insecure_port(f"[::]:{options.port}")
        print(f"Started listening at {options.host}:{options.port}")

        server.start()
        server.wait_for_termination()
        print("Server shut down")


if __name__ == "__main__":
    main()
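
For reference, main() above exposes a plain argparse CLI, so the packaged module can be started directly (flags taken verbatim from the parser above):

    python -m isolate.server.server --port 50001
    python -m isolate.server.server --single-use --num-workers 1

Below is a minimal smoke-test sketch against such a server. This is an editor's illustration, not packaged code: the stub name IsolateStub is inferred from the "/Isolate/..." method paths and should be checked against isolate/server/definitions.

    import grpc

    from isolate.server import definitions

    # Connect to a locally running isolate server and list background tasks.
    with grpc.insecure_channel("localhost:50001") as channel:
        stub = definitions.IsolateStub(channel)  # assumed stub name
        response = stub.List(definitions.ListRequest())
        print(response.tasks)  # empty on a fresh server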