tracdap-runtime 0.6.1.dev3__py3-none-any.whl → 0.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102) hide show
  1. tracdap/rt/_exec/actors.py +87 -10
  2. tracdap/rt/_exec/context.py +25 -1
  3. tracdap/rt/_exec/dev_mode.py +277 -221
  4. tracdap/rt/_exec/engine.py +79 -14
  5. tracdap/rt/_exec/functions.py +37 -8
  6. tracdap/rt/_exec/graph.py +2 -0
  7. tracdap/rt/_exec/graph_builder.py +118 -56
  8. tracdap/rt/_exec/runtime.py +108 -37
  9. tracdap/rt/_exec/server.py +345 -0
  10. tracdap/rt/_impl/config_parser.py +219 -49
  11. tracdap/rt/_impl/data.py +14 -0
  12. tracdap/rt/_impl/grpc/__init__.py +13 -0
  13. tracdap/rt/_impl/grpc/codec.py +99 -0
  14. tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2.py +51 -0
  15. tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2.pyi +61 -0
  16. tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2_grpc.py +183 -0
  17. tracdap/rt/_impl/grpc/tracdap/metadata/common_pb2.py +33 -0
  18. tracdap/rt/_impl/grpc/tracdap/metadata/common_pb2.pyi +34 -0
  19. tracdap/rt/{metadata → _impl/grpc/tracdap/metadata}/custom_pb2.py +5 -5
  20. tracdap/rt/_impl/grpc/tracdap/metadata/custom_pb2.pyi +15 -0
  21. tracdap/rt/_impl/grpc/tracdap/metadata/data_pb2.py +51 -0
  22. tracdap/rt/_impl/grpc/tracdap/metadata/data_pb2.pyi +115 -0
  23. tracdap/rt/_impl/grpc/tracdap/metadata/file_pb2.py +28 -0
  24. tracdap/rt/_impl/grpc/tracdap/metadata/file_pb2.pyi +22 -0
  25. tracdap/rt/_impl/grpc/tracdap/metadata/flow_pb2.py +59 -0
  26. tracdap/rt/_impl/grpc/tracdap/metadata/flow_pb2.pyi +109 -0
  27. tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.py +76 -0
  28. tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.pyi +177 -0
  29. tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.py +63 -0
  30. tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.pyi +119 -0
  31. tracdap/rt/_impl/grpc/tracdap/metadata/object_id_pb2.py +32 -0
  32. tracdap/rt/_impl/grpc/tracdap/metadata/object_id_pb2.pyi +68 -0
  33. tracdap/rt/_impl/grpc/tracdap/metadata/object_pb2.py +40 -0
  34. tracdap/rt/_impl/grpc/tracdap/metadata/object_pb2.pyi +46 -0
  35. tracdap/rt/_impl/grpc/tracdap/metadata/search_pb2.py +39 -0
  36. tracdap/rt/_impl/grpc/tracdap/metadata/search_pb2.pyi +83 -0
  37. tracdap/rt/_impl/grpc/tracdap/metadata/stoarge_pb2.py +50 -0
  38. tracdap/rt/_impl/grpc/tracdap/metadata/stoarge_pb2.pyi +89 -0
  39. tracdap/rt/_impl/grpc/tracdap/metadata/tag_pb2.py +34 -0
  40. tracdap/rt/_impl/grpc/tracdap/metadata/tag_pb2.pyi +26 -0
  41. tracdap/rt/_impl/grpc/tracdap/metadata/tag_update_pb2.py +30 -0
  42. tracdap/rt/_impl/grpc/tracdap/metadata/tag_update_pb2.pyi +34 -0
  43. tracdap/rt/_impl/grpc/tracdap/metadata/type_pb2.py +47 -0
  44. tracdap/rt/_impl/grpc/tracdap/metadata/type_pb2.pyi +101 -0
  45. tracdap/rt/_impl/guard_rails.py +26 -6
  46. tracdap/rt/_impl/models.py +25 -0
  47. tracdap/rt/_impl/static_api.py +27 -9
  48. tracdap/rt/_impl/type_system.py +17 -0
  49. tracdap/rt/_impl/validation.py +10 -0
  50. tracdap/rt/_plugins/config_local.py +49 -0
  51. tracdap/rt/_version.py +1 -1
  52. tracdap/rt/api/hook.py +10 -3
  53. tracdap/rt/api/model_api.py +22 -0
  54. tracdap/rt/api/static_api.py +79 -19
  55. tracdap/rt/config/__init__.py +3 -3
  56. tracdap/rt/config/common.py +10 -0
  57. tracdap/rt/config/platform.py +9 -19
  58. tracdap/rt/config/runtime.py +2 -0
  59. tracdap/rt/ext/config.py +34 -0
  60. tracdap/rt/ext/embed.py +1 -3
  61. tracdap/rt/ext/plugins.py +47 -6
  62. tracdap/rt/launch/cli.py +7 -5
  63. tracdap/rt/launch/launch.py +49 -12
  64. tracdap/rt/metadata/__init__.py +24 -24
  65. tracdap/rt/metadata/common.py +7 -7
  66. tracdap/rt/metadata/custom.py +2 -0
  67. tracdap/rt/metadata/data.py +28 -5
  68. tracdap/rt/metadata/file.py +2 -0
  69. tracdap/rt/metadata/flow.py +66 -4
  70. tracdap/rt/metadata/job.py +56 -16
  71. tracdap/rt/metadata/model.py +10 -0
  72. tracdap/rt/metadata/object.py +3 -0
  73. tracdap/rt/metadata/object_id.py +9 -9
  74. tracdap/rt/metadata/search.py +35 -13
  75. tracdap/rt/metadata/stoarge.py +64 -6
  76. tracdap/rt/metadata/tag_update.py +21 -7
  77. tracdap/rt/metadata/type.py +28 -13
  78. {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.3.dist-info}/METADATA +22 -19
  79. tracdap_runtime-0.6.3.dist-info/RECORD +112 -0
  80. {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.3.dist-info}/WHEEL +1 -1
  81. tracdap/rt/config/common_pb2.py +0 -55
  82. tracdap/rt/config/job_pb2.py +0 -42
  83. tracdap/rt/config/platform_pb2.py +0 -71
  84. tracdap/rt/config/result_pb2.py +0 -37
  85. tracdap/rt/config/runtime_pb2.py +0 -42
  86. tracdap/rt/ext/_guard.py +0 -37
  87. tracdap/rt/metadata/common_pb2.py +0 -33
  88. tracdap/rt/metadata/data_pb2.py +0 -51
  89. tracdap/rt/metadata/file_pb2.py +0 -28
  90. tracdap/rt/metadata/flow_pb2.py +0 -55
  91. tracdap/rt/metadata/job_pb2.py +0 -76
  92. tracdap/rt/metadata/model_pb2.py +0 -51
  93. tracdap/rt/metadata/object_id_pb2.py +0 -32
  94. tracdap/rt/metadata/object_pb2.py +0 -35
  95. tracdap/rt/metadata/search_pb2.py +0 -39
  96. tracdap/rt/metadata/stoarge_pb2.py +0 -50
  97. tracdap/rt/metadata/tag_pb2.py +0 -34
  98. tracdap/rt/metadata/tag_update_pb2.py +0 -30
  99. tracdap/rt/metadata/type_pb2.py +0 -48
  100. tracdap_runtime-0.6.1.dev3.dist-info/RECORD +0 -96
  101. {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.3.dist-info}/LICENSE +0 -0
  102. {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.3.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,7 @@ from __future__ import annotations
16
16
 
17
17
  import dataclasses as dc
18
18
  import datetime as dt
19
+ import signal
19
20
  import threading
20
21
 
21
22
  import sys
@@ -54,6 +55,8 @@ class TracRuntime:
54
55
  _engine.ModelNodeProcessor: "model",
55
56
  _engine.DataNodeProcessor: "data"}
56
57
 
58
+ __DEFAULT_API_PORT = 9000
59
+
57
60
  def __init__(
58
61
  self,
59
62
  sys_config: tp.Union[str, pathlib.Path, _cfg.RuntimeConfig],
@@ -61,6 +64,7 @@ class TracRuntime:
61
64
  job_result_format: tp.Optional[str] = None,
62
65
  scratch_dir: tp.Union[str, pathlib.Path, None] = None,
63
66
  scratch_dir_persist: bool = False,
67
+ plugin_packages: tp.List[str] = None,
64
68
  dev_mode: bool = False):
65
69
 
66
70
  trac_version = _version.__version__
@@ -83,22 +87,35 @@ class TracRuntime:
83
87
  self._log.info(f"TRAC D.A.P. Python Runtime {trac_version}")
84
88
 
85
89
  self._sys_config = sys_config if isinstance(sys_config, _cfg.RuntimeConfig) else None
86
- self._sys_config_path = pathlib.Path(sys_config) if not self._sys_config else None
90
+ self._sys_config_path = sys_config if not self._sys_config else None
87
91
  self._job_result_dir = job_result_dir
88
92
  self._job_result_format = job_result_format
89
93
  self._scratch_dir = scratch_dir
90
94
  self._scratch_dir_provided = True if scratch_dir is not None else False
91
95
  self._scratch_dir_persist = scratch_dir_persist
96
+ self._plugin_packages = plugin_packages or []
92
97
  self._dev_mode = dev_mode
93
98
 
99
+ # Runtime control
100
+ self._runtime_lock = threading.Lock()
101
+ self._runtime_event = threading.Condition(self._runtime_lock)
102
+ self._pre_start_complete = False
103
+ self._shutdown_requested = False
104
+ self._oneshot_job = None
105
+
94
106
  # Top level resources
107
+ self._config_mgr: tp.Optional[_cparse.ConfigManager] = None
95
108
  self._models: tp.Optional[_models.ModelLoader] = None
96
109
  self._storage: tp.Optional[_storage.StorageManager] = None
97
110
 
98
111
  # The execution engine
99
112
  self._system: tp.Optional[_actors.ActorSystem] = None
100
113
  self._engine: tp.Optional[_engine.TracEngine] = None
101
- self._engine_event = threading.Condition()
114
+
115
+ # Runtime API server
116
+ self._server_enabled = False
117
+ self._server_port = 0
118
+ self._server = None
102
119
 
103
120
  self._jobs: tp.Dict[str, _RuntimeJobInfo] = dict()
104
121
 
@@ -127,21 +144,28 @@ class TracRuntime:
127
144
 
128
145
  self._prepare_scratch_dir()
129
146
 
130
- # Plugin manager and static API impl are singletons
131
- # If these methods are called multiple times, the second and subsequent calls are ignored
147
+ # Plugin manager, static API and guard rails are singletons
148
+ # Calling these methods multiple times is safe (e.g. for embedded or testing scenarios)
149
+ # However, plugins are never un-registered for the lifetime of the processes
132
150
 
133
151
  _plugins.PluginManager.register_core_plugins()
152
+
153
+ for plugin_package in self._plugin_packages:
154
+ _plugins.PluginManager.register_plugin_package(plugin_package)
155
+
134
156
  _static_api.StaticApiImpl.register_impl()
135
157
  _guard.PythonGuardRails.protect_dangerous_functions()
136
158
 
137
159
  # Load sys config (or use embedded), config errors are detected before start()
138
160
  # Job config can also be checked before start() by using load_job_config()
139
161
 
162
+ self._config_mgr = _cparse.ConfigManager.for_root_config(self._sys_config_path)
163
+
140
164
  if self._sys_config is None:
141
165
  sys_config_dev_mode = _dev_mode.DEV_MODE_SYS_CONFIG if self._dev_mode else None
142
- sys_config_parser = _cparse.ConfigParser(_cfg.RuntimeConfig, sys_config_dev_mode)
143
- sys_config_raw = sys_config_parser.load_raw_config(self._sys_config_path, config_file_name="system")
144
- self._sys_config = sys_config_parser.parse(sys_config_raw, self._sys_config_path)
166
+ self._sys_config = self._config_mgr.load_root_object(
167
+ _cfg.RuntimeConfig, sys_config_dev_mode,
168
+ config_file_name="system")
145
169
  else:
146
170
  self._log.info("Using embedded system config")
147
171
 
@@ -149,8 +173,17 @@ class TracRuntime:
149
173
  # I.e. it can be applied to embedded configs
150
174
 
151
175
  if self._dev_mode:
152
- config_dir = self._sys_config_path.parent if self._sys_config_path is not None else None
153
- self._sys_config = _dev_mode.DevModeTranslator.translate_sys_config(self._sys_config, config_dir)
176
+ self._sys_config = _dev_mode.DevModeTranslator.translate_sys_config(self._sys_config, self._config_mgr)
177
+
178
+ # Runtime API server is controlled by the sys config
179
+
180
+ if self._sys_config.runtimeApi is not None:
181
+ api_config = self._sys_config.runtimeApi
182
+ if api_config.enabled:
183
+ self._server_enabled = True
184
+ self._server_port = api_config.port or self.__DEFAULT_API_PORT
185
+
186
+ self._pre_start_complete = True
154
187
 
155
188
  except Exception as e:
156
189
  self._handle_startup_error(e)
@@ -159,6 +192,10 @@ class TracRuntime:
159
192
 
160
193
  try:
161
194
 
195
+ # Ensure pre-start has been run
196
+ if not self._pre_start_complete:
197
+ self.pre_start()
198
+
162
199
  self._log.info("Starting the engine...")
163
200
 
164
201
  self._models = _models.ModelLoader(self._sys_config, self._scratch_dir)
@@ -175,11 +212,26 @@ class TracRuntime:
175
212
 
176
213
  self._system.start(wait=wait)
177
214
 
215
+ # If the runtime server has been enabled, start it up
216
+ if self._server_enabled:
217
+
218
+ self._log.info("Starting the runtime API server...")
219
+
220
+ # The server module pulls in all the gRPC dependencies, don't import it unless we have to
221
+ import tracdap.rt._exec.server as _server
222
+
223
+ self._server = _server.RuntimeApiServer(self._system, self._server_port)
224
+ self._server.start()
225
+
178
226
  except Exception as e:
179
227
  self._handle_startup_error(e)
180
228
 
181
229
  def stop(self, due_to_error=False):
182
230
 
231
+ if self._server is not None:
232
+ self._log.info("Stopping the runtime API server...")
233
+ self._server.stop()
234
+
183
235
  if due_to_error:
184
236
  self._log.info("Shutting down the engine in response to an error")
185
237
  else:
@@ -209,6 +261,28 @@ class TracRuntime:
209
261
  else:
210
262
  self._log.info("TRAC runtime has gone down cleanly")
211
263
 
264
def is_oneshot(self):
    """Return True when the runtime API server is not enabled, False otherwise."""

    server_enabled = self._server_enabled
    return not server_enabled
266
+
267
def run_until_done(self):
    """
    Block the calling thread until a shutdown of the runtime has been requested.

    Shutdown is requested by _request_shutdown(), which is installed here as the
    handler for SIGTERM and SIGINT when running on the main thread.

    :raises EStartup: if the runtime API server is disabled and no job has been submitted
    """

    # With no server and no jobs there is nothing to wait for
    if not self._server_enabled and not self._jobs:
        self._log.error("No job config supplied, TRAC runtime will not run")
        raise _ex.EStartup("No job config supplied")

    # signal.signal() raises ValueError when called outside the main thread
    # Only install the handlers where that is possible, other threads just wait
    if threading.current_thread() is threading.main_thread():
        signal.signal(signal.SIGTERM, self._request_shutdown)
        signal.signal(signal.SIGINT, self._request_shutdown)

    with self._runtime_lock:
        # Loop to guard against spurious wake-ups of the condition
        while not self._shutdown_requested:
            self._runtime_event.wait()
279
+
280
+ def _request_shutdown(self, _signum = None, _frame = None):
281
+
282
+ with self._runtime_lock:
283
+ self._shutdown_requested = True
284
+ self._runtime_event.notify()
285
+
212
286
  def _prepare_scratch_dir(self):
213
287
 
214
288
  if not self._scratch_dir_provided:
@@ -246,20 +320,18 @@ class TracRuntime:
246
320
 
247
321
  if isinstance(job_config, _cfg.JobConfig):
248
322
  self._log.info("Using embedded job config")
249
- job_config_path = None
250
323
 
251
324
  else:
252
- job_config_path = job_config
253
325
  job_config_dev_mode = _dev_mode.DEV_MODE_JOB_CONFIG if self._dev_mode else None
254
- job_config_parser = _cparse.ConfigParser(_cfg.JobConfig, job_config_dev_mode)
255
- job_config_raw = job_config_parser.load_raw_config(job_config_path, config_file_name="job")
256
- job_config = job_config_parser.parse(job_config_raw, job_config_path)
326
+ job_config = self._config_mgr.load_config_object(
327
+ job_config, _cfg.JobConfig,
328
+ job_config_dev_mode,
329
+ config_file_name="job")
257
330
 
258
331
  if self._dev_mode:
259
- config_dir = job_config_path.parent if job_config_path is not None else None
260
332
  job_config = _dev_mode.DevModeTranslator.translate_job_config(
261
333
  self._sys_config, job_config,
262
- self._scratch_dir, config_dir,
334
+ self._scratch_dir, self._config_mgr,
263
335
  model_class)
264
336
 
265
337
  return job_config
@@ -269,7 +341,7 @@ class TracRuntime:
269
341
  job_key = _util.object_key(job_config.jobId)
270
342
  self._jobs[job_key] = _RuntimeJobInfo()
271
343
 
272
- self._system.send(
344
+ self._system.send_main(
273
345
  "submit_job", job_config,
274
346
  str(self._job_result_dir) if self._job_result_dir else "",
275
347
  self._job_result_format if self._job_result_format else "")
@@ -281,35 +353,34 @@ class TracRuntime:
281
353
  if job_key not in self._jobs:
282
354
  raise _ex.ETracInternal(f"Attempt to wait for a job that was never started")
283
355
 
284
- with self._engine_event:
285
- while True:
356
+ self._oneshot_job = job_key
286
357
 
287
- job_info = self._jobs[job_key]
358
+ self.run_until_done()
288
359
 
289
- if job_info.error is not None:
290
- raise job_info.error
360
+ job_info = self._jobs[job_key]
291
361
 
292
- if job_info.result is not None:
293
- return job_info.result
362
+ if job_info.error is not None:
363
+ raise job_info.error
294
364
 
295
- # TODO: Timeout / heartbeat
365
+ elif job_info.result is not None:
366
+ return job_info.result
296
367
 
297
- self._engine_event.wait(1)
368
+ else:
369
+ err = f"No result or error information is available for job [{job_key}]"
370
+ self._log.error(err)
371
+ raise _ex.ETracInternal(err)
298
372
 
299
373
  def _engine_callback(self, job_key, job_result, job_error):
300
374
 
301
- with self._engine_event:
302
-
303
- if job_result is not None:
304
- self._jobs[job_key].done = True
305
- self._jobs[job_key].result = job_result
306
- elif job_error is not None:
307
- self._jobs[job_key].done = True
308
- self._jobs[job_key].error = job_error
309
- else:
310
- pass
375
+ if job_result is not None:
376
+ self._jobs[job_key].done = True
377
+ self._jobs[job_key].result = job_result
378
+ elif job_error is not None:
379
+ self._jobs[job_key].done = True
380
+ self._jobs[job_key].error = job_error
311
381
 
312
- self._engine_event.notify()
382
+ if self._oneshot_job == job_key:
383
+ self._request_shutdown()
313
384
 
314
385
  # ------------------------------------------------------------------------------------------------------------------
315
386
  # Error handling
@@ -0,0 +1,345 @@
1
+ # Copyright 2024 Accenture Global Solutions Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import asyncio
16
+ import threading
17
+ import typing as tp
18
+
19
+ import tracdap.rt.config as config
20
+ import tracdap.rt.exceptions as ex
21
+ import tracdap.rt._exec.actors as actors
22
+ import tracdap.rt._impl.grpc.codec as codec # noqa
23
+ import tracdap.rt._impl.util as util # noqa
24
+
25
+ # Check whether gRPC is installed before trying to load any of the generated modules
26
+ try:
27
+ import grpc.aio # noqa
28
+ import google.protobuf.message as _msg # noqa
29
+ except ImportError:
30
+ raise ex.EStartup("The runtime API server cannot be enabled because gRPC libraries are not installed")
31
+
32
+ # Imports for gRPC generated code, these are managed by build_runtime.py for distribution
33
+ import tracdap.rt._impl.grpc.tracdap.api.internal.runtime_pb2 as runtime_pb2
34
+ import tracdap.rt._impl.grpc.tracdap.api.internal.runtime_pb2_grpc as runtime_grpc
35
+
36
+
37
class RuntimeApiServer(runtime_grpc.TracRuntimeApiServicer):

    """
    gRPC servicer and server wrapper for the TRAC runtime API.

    The gRPC server runs on its own asyncio event loop in a dedicated daemon thread.
    Each API call creates a request actor, spawned under an ApiAgent, which sends a
    message to the actor identified by system.main_id() and completes the call
    once a reply arrives or the request timeout expires.
    """

    # Default timeout values in seconds
    __DEFAULT_STARTUP_TIMEOUT = 5.0
    __DEFAULT_SHUTDOWN_TIMEOUT = 10.0
    __DEFAULT_REQUEST_TIMEOUT = 10.0

    def __init__(self, system: actors.ActorSystem, port: int):
        """
        :param system: Actor system used to spawn the API agent and look up the main actor ID
        :param port: Port the gRPC server will listen on
        """

        self.__log = util.logger_for_object(self)

        self.__system = system
        self.__engine_id = system.main_id()
        self.__agent: tp.Optional[ApiAgent] = None

        self.__port = port
        self.__request_timeout = self.__DEFAULT_REQUEST_TIMEOUT  # Not configurable atm
        self.__server: tp.Optional[grpc.aio.Server] = None
        self.__server_thread: tp.Optional[threading.Thread] = None
        self.__event_loop: tp.Optional[asyncio.AbstractEventLoop] = None

        # The start signal is set from the server thread, the stop signal lives on the event loop
        self.__start_signal: tp.Optional[threading.Event] = None
        self.__stop_signal: tp.Optional[asyncio.Event] = None

    def start(self, startup_timeout: float = None):
        """
        Start the server thread and wait until the server reports that it is up.

        Calling start() again after the first call has no effect.

        :param startup_timeout: Seconds to wait for startup, defaults to the startup timeout constant
        :raises EStartup: if the server does not come up within the timeout
        """

        if self.__start_signal is not None:
            return

        timeout = startup_timeout or self.__DEFAULT_STARTUP_TIMEOUT

        self.__start_signal = threading.Event()
        self.__server_thread = threading.Thread(target=self.__server_main, name="api_server", daemon=True)
        self.__server_thread.start()

        # Event.wait() does not raise on timeout, it returns False
        if not self.__start_signal.wait(timeout):
            raise ex.EStartup("Runtime API failed to start")

    def stop(self, shutdown_timeout: float = None):
        """
        Signal the server to stop and wait for the server thread to finish.

        Does nothing if the server is not currently set. A warning is logged if the
        server thread is still alive after the timeout.

        :param shutdown_timeout: Seconds to wait for the server thread, defaults to the shutdown timeout constant
        """

        if self.__server is None:
            return

        timeout = shutdown_timeout or self.__DEFAULT_SHUTDOWN_TIMEOUT

        # The stop signal is an asyncio event, so it must be set from inside the event loop
        self.__event_loop.call_soon_threadsafe(lambda: self.__stop_signal.set())
        self.__server_thread.join(timeout)

        if self.__server_thread.is_alive():
            self.__log.warning("Runtime API server did not go down cleanly")

    def __server_main(self):
        """Entry point of the server thread: run the async server on a new event loop."""

        self.__event_loop = asyncio.new_event_loop()
        self.__event_loop.run_until_complete(self.__server_main_async())
        self.__event_loop.close()

    async def __server_main_async(self):
        """Bring up the agent and gRPC server, wait for the stop signal, then shut the server down."""

        server_address = f"[::]:{self.__port}"

        # Asyncio events must be created inside the event loop for Python <= 3.9
        self.__stop_signal = asyncio.Event()

        # Agent using asyncio, so must be created inside the event loop
        self.__agent = ApiAgent()
        self.__system.spawn_agent(self.__agent)

        self.__server = grpc.aio.server()
        self.__server.add_insecure_port(server_address)
        runtime_grpc.add_TracRuntimeApiServicer_to_server(self, self.__server)

        await self.__server.start()
        await self.__agent.started()

        # Release the thread blocked in start()
        self.__start_signal.set()

        self.__log.info(f"Runtime API server is up and listening on port [{self.__port}]")

        await asyncio.create_task(self.__stop_signal.wait())

        self.__log.info(f"Shutdown signal received, runtime API server is going down...")

        await self.__server.stop(self.__DEFAULT_SHUTDOWN_TIMEOUT)
        self.__server = None

        self.__log.info("Runtime API server has gone down cleanly")

    async def listJobs(self, request: runtime_pb2.RuntimeListJobsRequest, context: grpc.aio.ServicerContext):
        """Handle the listJobs API call by spawning a ListJobsRequest and awaiting its completion."""

        request_task = ListJobsRequest(self.__engine_id, request, context)
        self.__agent.threadsafe().spawn(request_task)

        return await request_task.complete(self.__request_timeout)

    async def getJobStatus(self, request: runtime_pb2.RuntimeJobInfoRequest, context: grpc.ServicerContext):
        """Handle the getJobStatus API call by spawning a GetJobStatusRequest and awaiting its completion."""

        request_task = GetJobStatusRequest(self.__engine_id, request, context)
        self.__agent.threadsafe().spawn(request_task)

        return await request_task.complete(self.__request_timeout)

    async def getJobResult(self, request: runtime_pb2.RuntimeJobInfoRequest, context: grpc.ServicerContext):
        """Handle the getJobResult API call by spawning a GetJobResultRequest and awaiting its completion."""

        request_task = GetJobResultRequest(self.__engine_id, request, context)
        self.__agent.threadsafe().spawn(request_task)

        return await request_task.complete(self.__request_timeout)
147
+
148
+
149
+ _T_REQUEST = tp.TypeVar("_T_REQUEST", bound=_msg.Message)
150
+ _T_RESPONSE = tp.TypeVar("_T_RESPONSE", bound=_msg.Message)
151
+
152
+
153
class ApiAgent(actors.ThreadsafeActor):

    """
    Parent actor under which API request actors are spawned.

    Instances must be created inside the asyncio event loop, because the
    constructor captures the current loop and creates an asyncio event.
    """

    def __init__(self):
        super().__init__()
        self._log = util.logger_for_object(self)
        self._event_loop = asyncio.get_event_loop()
        self.__start_signal = asyncio.Event()

    def on_start(self):
        # Hand the notification over to the event loop that owns the asyncio event
        self._event_loop.call_soon_threadsafe(self.__start_signal.set)

    def on_signal(self, signal: actors.Signal) -> tp.Optional[bool]:

        if signal.message != actors.SignalNames.FAILED:
            return False

        # Do not allow a failed request to bring down the API server
        error = None
        if isinstance(signal, actors.ErrorSignal):
            error = signal.error

        self._log.warning(f"Unhandled error during API request: {error!s}")
        self._log.warning("The API agent will continue running")

        return True

    async def started(self):
        """Wait until on_start() has been called for this agent."""
        await self.__start_signal.wait()
180
+
181
+
182
class ApiRequest(actors.ThreadsafeActor, tp.Generic[_T_REQUEST, _T_RESPONSE]):

    """
    Base class for API request actors.

    API request is the bridge between asyncio events (gRPC) and actor messages (TRAC runtime engine).
    Request objects must be created inside the asyncio event loop.

    Subclasses send a message from on_start() and, in their message handlers, set either
    _response or _grpc_code / _grpc_message before calling _mark_complete().
    """

    # Class-level logger, assigned at module level after the class definition
    _log = None

    def __init__(
            self, engine_id, method: str, request: _T_REQUEST,
            context: grpc.aio.ServicerContext):
        """
        :param engine_id: ID of the actor that request messages are sent to
        :param method: Method name used in log messages
        :param request: The gRPC request message
        :param context: The gRPC servicer context, used to set status code and details
        """

        super().__init__()

        self._engine_id = engine_id
        self._method = method
        self._request = request
        self._response: tp.Optional[_T_RESPONSE] = None
        self._error: tp.Optional[Exception] = None
        self._grpc_code = grpc.StatusCode.OK
        self._grpc_message = ""

        self._context = context
        self._event_loop = asyncio.get_event_loop()
        self._completion = asyncio.Event()

        self._log.info("API call start: %s()", self._method)

    def _mark_complete(self):
        """Set the completion event from any thread, by scheduling it on the event loop."""

        self._event_loop.call_soon_threadsafe(lambda: self._completion.set())

    def on_stop(self):

        # Keep hold of the actor error so complete() can raise it
        if self.state() == actors.ActorState.ERROR:
            self._error = self.error()

        self._mark_complete()

    async def complete(self, request_timeout: float) -> _T_RESPONSE:
        """
        Wait for the request to complete and produce the outcome of the gRPC call.

        The request actor is always stopped when this method exits.

        :param request_timeout: Seconds to wait for the completion event
        :return: The response message; nothing is returned when a non-OK gRPC status was set
        :raises asyncio.TimeoutError: if the completion event is not set within the timeout
        :raises Exception: the recorded actor error, or EUnexpected if there is no outcome at all
        """

        try:

            completion_task = asyncio.create_task(self._completion.wait())
            await asyncio.wait_for(completion_task, request_timeout)

            if self._error:
                raise self._error

            elif self._grpc_code != grpc.StatusCode.OK:
                self._log.info("API call failed: %s() %s %s", self._method, self._grpc_code.name, self._grpc_message)
                self._context.set_code(self._grpc_code)
                self._context.set_details(self._grpc_message)

            elif self._response is not None:
                self._log.info("API call succeeded: %s()", self._method)
                return self._response

            else:
                raise ex.EUnexpected()

        # asyncio.wait_for() raises asyncio.TimeoutError
        # This is only an alias of the builtin TimeoutError from Python 3.11 onwards
        except asyncio.TimeoutError:
            self._completion.set()
            self._context.set_code(grpc.StatusCode.DEADLINE_EXCEEDED)
            self._context.set_details("The TRAC runtime engine did not respond")
            self._log.error("API call failed: %s() %s", self._method, "The TRAC runtime engine did not respond")
            raise

        except Exception as e:
            self._context.set_code(grpc.StatusCode.INTERNAL)
            self._context.set_details("Internal server error")
            self._log.error("API call failed: %s() %s", self._method, str(e))
            self._log.exception(e)
            raise

        finally:
            self.threadsafe().stop()
258
+
259
+
260
+ ApiRequest._log = util.logger_for_class(ApiRequest)
261
+
262
+
263
class ListJobsRequest(ApiRequest[runtime_pb2.RuntimeListJobsRequest, runtime_pb2.RuntimeListJobsResponse]):

    """Request actor for the listJobs API call."""

    def __init__(self, engine_id, request, context):
        super().__init__(engine_id, "get_job_list", request, context)

    def on_start(self):
        # Ask the engine for its job list, the reply arrives as a job_list message
        target_id = self._engine_id
        self.actors().send(target_id, "get_job_list")

    @actors.Message
    def job_list(self, job_list):

        encoded_jobs = codec.encode(job_list)
        self._response = runtime_pb2.RuntimeListJobsResponse(jobs=encoded_jobs)

        self._mark_complete()
278
+
279
+
280
class GetJobStatusRequest(ApiRequest[runtime_pb2.RuntimeJobInfoRequest, runtime_pb2.RuntimeJobStatus]):

    """Request actor for the getJobStatus API call."""

    def __init__(self, engine_id, request, context):

        super().__init__(engine_id, "get_job_status", request, context)

        # The job can be identified either directly by key or by a selector
        if request.HasField("jobKey"):
            job_key = self._request.jobKey
        elif request.HasField("jobSelector"):
            job_key = util.object_key(self._request.jobSelector)
        else:
            raise ex.EValidation("Bad request: Neither jobKey nor jobSelector is specified")

        self._job_key = job_key

    def on_start(self):
        target_id = self._engine_id
        self.actors().send(target_id, "get_job_details", self._job_key, details=False)

    @actors.Message
    def job_details(self, job_details: tp.Optional[config.JobResult]):

        if job_details is not None:

            job_id = codec.encode(job_details.jobId)
            status_code = codec.encode(job_details.statusCode)
            status_message = codec.encode(job_details.statusMessage)

            self._response = runtime_pb2.RuntimeJobStatus(
                jobId=job_id,
                statusCode=status_code,
                statusMessage=status_message)

        else:
            # No details came back for this job key, report it as not found
            self._grpc_code = grpc.StatusCode.NOT_FOUND
            self._grpc_message = f"Job not found: [{self._job_key}]"

        self._mark_complete()
310
+
311
+
312
class GetJobResultRequest(ApiRequest[runtime_pb2.RuntimeJobInfoRequest, runtime_pb2.RuntimeJobResult]):

    """Request actor for the getJobResult API call."""

    def __init__(self, engine_id, request, context):

        super().__init__(engine_id, "get_job_result", request, context)

        # The job can be identified either directly by key or by a selector
        if request.HasField("jobKey"):
            job_key = self._request.jobKey
        elif request.HasField("jobSelector"):
            job_key = util.object_key(self._request.jobSelector)
        else:
            raise ex.EValidation("Bad request: Neither jobKey nor jobSelector is specified")

        self._job_key = job_key

    def on_start(self):
        target_id = self._engine_id
        self.actors().send(target_id, "get_job_details", self._job_key, details=True)

    @actors.Message
    def job_details(self, job_details: tp.Optional[config.JobResult]):

        if job_details is not None:

            encoded_results = {
                key: codec.encode(value)
                for key, value in job_details.results.items()}

            job_id = codec.encode(job_details.jobId)
            status_code = codec.encode(job_details.statusCode)
            status_message = codec.encode(job_details.statusMessage)

            self._response = runtime_pb2.RuntimeJobResult(
                jobId=job_id,
                statusCode=status_code,
                statusMessage=status_message,
                results=encoded_results)

        else:
            # No details came back for this job key, report it as not found
            self._grpc_code = grpc.StatusCode.NOT_FOUND
            self._grpc_message = f"Job not found: [{self._job_key}]"

        self._mark_complete()