tracdap-runtime 0.6.2__py3-none-any.whl → 0.6.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tracdap/rt/_exec/actors.py +87 -10
- tracdap/rt/_exec/context.py +207 -100
- tracdap/rt/_exec/dev_mode.py +52 -20
- tracdap/rt/_exec/engine.py +79 -14
- tracdap/rt/_exec/functions.py +14 -17
- tracdap/rt/_exec/runtime.py +83 -40
- tracdap/rt/_exec/server.py +306 -29
- tracdap/rt/_impl/config_parser.py +219 -49
- tracdap/rt/_impl/data.py +70 -5
- tracdap/rt/_impl/grpc/codec.py +60 -5
- tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2.py +19 -19
- tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2.pyi +11 -9
- tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2_grpc.py +25 -25
- tracdap/rt/_impl/grpc/tracdap/metadata/data_pb2.py +18 -18
- tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.py +28 -16
- tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.pyi +37 -6
- tracdap/rt/_impl/grpc/tracdap/metadata/object_pb2.py +8 -3
- tracdap/rt/_impl/grpc/tracdap/metadata/object_pb2.pyi +13 -2
- tracdap/rt/_impl/guard_rails.py +21 -0
- tracdap/rt/_impl/models.py +25 -0
- tracdap/rt/_impl/static_api.py +43 -13
- tracdap/rt/_impl/type_system.py +17 -0
- tracdap/rt/_impl/validation.py +47 -4
- tracdap/rt/_plugins/config_local.py +49 -0
- tracdap/rt/_version.py +1 -1
- tracdap/rt/api/hook.py +6 -5
- tracdap/rt/api/model_api.py +50 -7
- tracdap/rt/api/static_api.py +81 -23
- tracdap/rt/config/__init__.py +4 -4
- tracdap/rt/config/common.py +25 -15
- tracdap/rt/config/job.py +2 -2
- tracdap/rt/config/platform.py +25 -35
- tracdap/rt/config/result.py +2 -2
- tracdap/rt/config/runtime.py +4 -2
- tracdap/rt/ext/config.py +34 -0
- tracdap/rt/ext/embed.py +1 -3
- tracdap/rt/ext/plugins.py +47 -6
- tracdap/rt/launch/cli.py +11 -4
- tracdap/rt/launch/launch.py +53 -12
- tracdap/rt/metadata/__init__.py +17 -17
- tracdap/rt/metadata/common.py +2 -2
- tracdap/rt/metadata/custom.py +3 -3
- tracdap/rt/metadata/data.py +12 -12
- tracdap/rt/metadata/file.py +6 -6
- tracdap/rt/metadata/flow.py +6 -6
- tracdap/rt/metadata/job.py +8 -8
- tracdap/rt/metadata/model.py +21 -11
- tracdap/rt/metadata/object.py +3 -0
- tracdap/rt/metadata/object_id.py +8 -8
- tracdap/rt/metadata/search.py +5 -5
- tracdap/rt/metadata/stoarge.py +6 -6
- tracdap/rt/metadata/tag.py +1 -1
- tracdap/rt/metadata/tag_update.py +1 -1
- tracdap/rt/metadata/type.py +4 -4
- {tracdap_runtime-0.6.2.dist-info → tracdap_runtime-0.6.4.dist-info}/METADATA +4 -4
- tracdap_runtime-0.6.4.dist-info/RECORD +112 -0
- {tracdap_runtime-0.6.2.dist-info → tracdap_runtime-0.6.4.dist-info}/WHEEL +1 -1
- tracdap/rt/_impl/grpc/tracdap/config/common_pb2.py +0 -55
- tracdap/rt/_impl/grpc/tracdap/config/common_pb2.pyi +0 -103
- tracdap/rt/_impl/grpc/tracdap/config/job_pb2.py +0 -42
- tracdap/rt/_impl/grpc/tracdap/config/job_pb2.pyi +0 -44
- tracdap/rt/_impl/grpc/tracdap/config/platform_pb2.py +0 -71
- tracdap/rt/_impl/grpc/tracdap/config/platform_pb2.pyi +0 -197
- tracdap/rt/_impl/grpc/tracdap/config/result_pb2.py +0 -37
- tracdap/rt/_impl/grpc/tracdap/config/result_pb2.pyi +0 -35
- tracdap/rt/_impl/grpc/tracdap/config/runtime_pb2.py +0 -42
- tracdap/rt/_impl/grpc/tracdap/config/runtime_pb2.pyi +0 -46
- tracdap/rt/ext/_guard.py +0 -37
- tracdap_runtime-0.6.2.dist-info/RECORD +0 -121
- {tracdap_runtime-0.6.2.dist-info → tracdap_runtime-0.6.4.dist-info}/LICENSE +0 -0
- {tracdap_runtime-0.6.2.dist-info → tracdap_runtime-0.6.4.dist-info}/top_level.txt +0 -0
tracdap/rt/_exec/server.py
CHANGED
@@ -12,57 +12,334 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
+
import asyncio
|
16
|
+
import threading
|
15
17
|
import typing as tp
|
16
|
-
|
18
|
+
|
19
|
+
import tracdap.rt.config as config
|
20
|
+
import tracdap.rt.exceptions as ex
|
21
|
+
import tracdap.rt._exec.actors as actors
|
22
|
+
import tracdap.rt._impl.grpc.codec as codec # noqa
|
23
|
+
import tracdap.rt._impl.util as util # noqa
|
24
|
+
|
25
|
+
# Check whether gRPC is installed before trying to load any of the generated modules
|
26
|
+
try:
|
27
|
+
import grpc.aio # noqa
|
28
|
+
import google.protobuf.message as _msg # noqa
|
29
|
+
except ImportError:
|
30
|
+
raise ex.EStartup("The runtime API server cannot be enabled because gRPC libraries are not installed")
|
17
31
|
|
18
32
|
# Imports for gRPC generated code, these are managed by build_runtime.py for distribution
|
19
33
|
import tracdap.rt._impl.grpc.tracdap.api.internal.runtime_pb2 as runtime_pb2
|
20
34
|
import tracdap.rt._impl.grpc.tracdap.api.internal.runtime_pb2_grpc as runtime_grpc
|
21
|
-
import grpc
|
22
35
|
|
23
36
|
|
24
37
|
class RuntimeApiServer(runtime_grpc.TracRuntimeApiServicer):
|
25
38
|
|
26
|
-
|
27
|
-
|
28
|
-
__DEFAULT_SHUTDOWN_TIMEOUT = 10.0
|
39
|
+
# Default timeout values in seconds
|
40
|
+
__DEFAULT_STARTUP_TIMEOUT = 5.0
|
41
|
+
__DEFAULT_SHUTDOWN_TIMEOUT = 10.0
|
42
|
+
__DEFAULT_REQUEST_TIMEOUT = 10.0
|
43
|
+
|
44
|
+
def __init__(self, system: actors.ActorSystem, port: int):
|
45
|
+
|
46
|
+
self.__log = util.logger_for_object(self)
|
47
|
+
|
48
|
+
self.__system = system
|
49
|
+
self.__engine_id = system.main_id()
|
50
|
+
self.__agent: tp.Optional[ApiAgent] = None
|
29
51
|
|
30
|
-
def __init__(self, port: int, n_workers: int = None):
|
31
52
|
self.__port = port
|
32
|
-
self.
|
33
|
-
self.__server: tp.Optional[grpc.Server] = None
|
34
|
-
self.
|
53
|
+
self.__request_timeout = self.__DEFAULT_REQUEST_TIMEOUT # Not configurable atm
|
54
|
+
self.__server: tp.Optional[grpc.aio.Server] = None
|
55
|
+
self.__server_thread: tp.Optional[threading.Thread] = None
|
56
|
+
self.__event_loop: tp.Optional[asyncio.AbstractEventLoop] = None
|
57
|
+
|
58
|
+
self.__start_signal: tp.Optional[threading.Event] = None
|
59
|
+
self.__stop_signal: tp.Optional[asyncio.Event] = None
|
60
|
+
|
61
|
+
def start(self, startup_timeout: float = None):
|
62
|
+
|
63
|
+
if self.__start_signal is not None:
|
64
|
+
return
|
65
|
+
|
66
|
+
timeout = startup_timeout or self.__DEFAULT_SHUTDOWN_TIMEOUT
|
67
|
+
|
68
|
+
self.__start_signal = threading.Event()
|
69
|
+
self.__server_thread = threading.Thread(target=self.__server_main, name="api_server", daemon=True)
|
70
|
+
self.__server_thread.start()
|
71
|
+
|
72
|
+
try:
|
73
|
+
self.__start_signal.wait(timeout)
|
74
|
+
except TimeoutError as e:
|
75
|
+
raise ex.EStartup("Runtime API failed to start") from e
|
76
|
+
|
77
|
+
def stop(self, shutdown_timeout: float = None):
|
78
|
+
|
79
|
+
if self.__server is None:
|
80
|
+
return
|
81
|
+
|
82
|
+
timeout = shutdown_timeout or self.__DEFAULT_SHUTDOWN_TIMEOUT
|
35
83
|
|
36
|
-
|
37
|
-
|
84
|
+
self.__event_loop.call_soon_threadsafe(lambda: self.__stop_signal.set())
|
85
|
+
self.__server_thread.join(timeout)
|
38
86
|
|
39
|
-
|
40
|
-
|
87
|
+
if self.__server_thread.is_alive():
|
88
|
+
self.__log.warning("Runtime API server did not go down cleanly")
|
41
89
|
|
42
|
-
def
|
43
|
-
return super().getJobDetails(request, context)
|
90
|
+
def __server_main(self):
|
44
91
|
|
45
|
-
|
92
|
+
self.__event_loop = asyncio.new_event_loop()
|
93
|
+
self.__event_loop.run_until_complete(self.__server_main_async())
|
94
|
+
self.__event_loop.close()
|
46
95
|
|
47
|
-
|
48
|
-
max_workers=self.__n_workers,
|
49
|
-
thread_name_prefix=self.__THREAD_NAME_PREFIX)
|
96
|
+
async def __server_main_async(self):
|
50
97
|
|
51
|
-
|
98
|
+
server_address = f"[::]:{self.__port}"
|
52
99
|
|
53
|
-
|
54
|
-
self.
|
100
|
+
# Asyncio events must be created inside the event loop for Python <= 3.9
|
101
|
+
self.__stop_signal = asyncio.Event()
|
55
102
|
|
103
|
+
# Agent using asyncio, so must be created inside the event loop
|
104
|
+
self.__agent = ApiAgent()
|
105
|
+
self.__system.spawn_agent(self.__agent)
|
106
|
+
|
107
|
+
self.__server = grpc.aio.server()
|
108
|
+
self.__server.add_insecure_port(server_address)
|
56
109
|
runtime_grpc.add_TracRuntimeApiServicer_to_server(self, self.__server)
|
57
110
|
|
58
|
-
self.__server.start()
|
111
|
+
await self.__server.start()
|
112
|
+
await self.__agent.started()
|
59
113
|
|
60
|
-
|
114
|
+
self.__start_signal.set()
|
115
|
+
|
116
|
+
self.__log.info(f"Runtime API server is up and listening on port [{self.__port}]")
|
117
|
+
|
118
|
+
await asyncio.create_task(self.__stop_signal.wait())
|
119
|
+
|
120
|
+
self.__log.info(f"Shutdown signal received, runtime API server is going down...")
|
121
|
+
|
122
|
+
await self.__server.stop(self.__DEFAULT_SHUTDOWN_TIMEOUT)
|
123
|
+
self.__server = None
|
124
|
+
|
125
|
+
self.__log.info("Runtime API server has gone down cleanly")
|
126
|
+
|
127
|
+
async def listJobs(self, request: runtime_pb2.RuntimeListJobsRequest, context: grpc.aio.ServicerContext):
|
128
|
+
|
129
|
+
request_task = ListJobsRequest(self.__engine_id, request, context)
|
130
|
+
self.__agent.threadsafe().spawn(request_task)
|
131
|
+
|
132
|
+
return await request_task.complete(self.__request_timeout)
|
133
|
+
|
134
|
+
async def getJobStatus(self, request: runtime_pb2.RuntimeJobInfoRequest, context: grpc.ServicerContext):
|
135
|
+
|
136
|
+
request_task = GetJobStatusRequest(self.__engine_id, request, context)
|
137
|
+
self.__agent.threadsafe().spawn(request_task)
|
138
|
+
|
139
|
+
return await request_task.complete(self.__request_timeout)
|
140
|
+
|
141
|
+
async def getJobResult(self, request: runtime_pb2.RuntimeJobInfoRequest, context: grpc.ServicerContext):
|
142
|
+
|
143
|
+
request_task = GetJobResultRequest(self.__engine_id, request, context)
|
144
|
+
self.__agent.threadsafe().spawn(request_task)
|
145
|
+
|
146
|
+
return await request_task.complete(self.__request_timeout)
|
147
|
+
|
148
|
+
|
149
|
+
_T_REQUEST = tp.TypeVar("_T_REQUEST", bound=_msg.Message)
|
150
|
+
_T_RESPONSE = tp.TypeVar("_T_RESPONSE", bound=_msg.Message)
|
151
|
+
|
152
|
+
|
153
|
+
class ApiAgent(actors.ThreadsafeActor):
|
154
|
+
|
155
|
+
# API Agent is the parent actor that will be used to spawn API requests
|
156
|
+
# It must be created inside the asyncio event loop
|
157
|
+
|
158
|
+
def __init__(self):
|
159
|
+
super().__init__()
|
160
|
+
self._log = util.logger_for_object(self)
|
161
|
+
self._event_loop = asyncio.get_event_loop()
|
162
|
+
self.__start_signal = asyncio.Event()
|
163
|
+
|
164
|
+
def on_start(self):
|
165
|
+
self._event_loop.call_soon_threadsafe(lambda: self.__start_signal.set())
|
166
|
+
|
167
|
+
def on_signal(self, signal: actors.Signal) -> tp.Optional[bool]:
|
168
|
+
|
169
|
+
# Do not allow a failed request to bring down the API server
|
170
|
+
if signal.message == actors.SignalNames.FAILED:
|
171
|
+
error = signal.error if isinstance(signal, actors.ErrorSignal) else None
|
172
|
+
self._log.warning("Unhandled error during API request: " + str(error))
|
173
|
+
self._log.warning("The API agent will continue running")
|
174
|
+
return True
|
175
|
+
|
176
|
+
return False
|
177
|
+
|
178
|
+
async def started(self):
|
179
|
+
await self.__start_signal.wait()
|
180
|
+
|
181
|
+
|
182
|
+
class ApiRequest(actors.ThreadsafeActor, tp.Generic[_T_REQUEST, _T_RESPONSE]):
|
183
|
+
|
184
|
+
# API request is the bridge between asyncio events (gRPC) and actor messages (TRAC runtime engine)
|
185
|
+
# Requests objects must be created inside the asyncio event loop
|
186
|
+
|
187
|
+
_log = None
|
188
|
+
|
189
|
+
def __init__(
|
190
|
+
self, engine_id, method: str, request: _T_REQUEST,
|
191
|
+
context: grpc.aio.ServicerContext):
|
192
|
+
|
193
|
+
super().__init__()
|
194
|
+
|
195
|
+
self._engine_id = engine_id
|
196
|
+
self._method = method
|
197
|
+
self._request = request
|
198
|
+
self._response: tp.Optional[_T_RESPONSE] = None
|
199
|
+
self._error: tp.Optional[Exception] = None
|
200
|
+
self._grpc_code = grpc.StatusCode.OK
|
201
|
+
self._grpc_message = ""
|
202
|
+
|
203
|
+
self._context = context
|
204
|
+
self._event_loop = asyncio.get_event_loop()
|
205
|
+
self._completion = asyncio.Event()
|
206
|
+
|
207
|
+
self._log.info("API call start: %s()", self._method)
|
208
|
+
|
209
|
+
def _mark_complete(self):
|
210
|
+
|
211
|
+
self._event_loop.call_soon_threadsafe(lambda: self._completion.set())
|
212
|
+
|
213
|
+
def on_stop(self):
|
214
|
+
|
215
|
+
if self.state() == actors.ActorState.ERROR:
|
216
|
+
self._error = self.error()
|
217
|
+
|
218
|
+
self._mark_complete()
|
219
|
+
|
220
|
+
async def complete(self, request_timeout: float) -> _T_RESPONSE:
|
221
|
+
|
222
|
+
try:
|
223
|
+
|
224
|
+
completion_task = asyncio.create_task(self._completion.wait())
|
225
|
+
await asyncio.wait_for(completion_task, request_timeout)
|
226
|
+
|
227
|
+
if self._error:
|
228
|
+
raise self._error
|
229
|
+
|
230
|
+
elif self._grpc_code != grpc.StatusCode.OK:
|
231
|
+
self._log.info("API call failed: %s() %s %s", self._method, self._grpc_code.name, self._grpc_message)
|
232
|
+
self._context.set_code(self._grpc_code)
|
233
|
+
self._context.set_details(self._grpc_message)
|
234
|
+
|
235
|
+
elif self._response is not None:
|
236
|
+
self._log.info("API call succeeded: %s()", self._method)
|
237
|
+
return self._response
|
238
|
+
|
239
|
+
else:
|
240
|
+
raise ex.EUnexpected()
|
241
|
+
|
242
|
+
except TimeoutError:
|
243
|
+
self._completion.set()
|
244
|
+
self._context.set_code(grpc.StatusCode.DEADLINE_EXCEEDED)
|
245
|
+
self._context.set_details("The TRAC runtime engine did not respond")
|
246
|
+
self._log.error("API call failed: %s() %s", self._method, "The TRAC runtime engine did not respond")
|
247
|
+
raise
|
248
|
+
|
249
|
+
except Exception as e:
|
250
|
+
self._context.set_code(grpc.StatusCode.INTERNAL)
|
251
|
+
self._context.set_details("Internal server error")
|
252
|
+
self._log.error("API call failed: %s() %s", self._method, str(e))
|
253
|
+
self._log.exception(e)
|
254
|
+
raise
|
255
|
+
|
256
|
+
finally:
|
257
|
+
self.threadsafe().stop()
|
258
|
+
|
259
|
+
|
260
|
+
ApiRequest._log = util.logger_for_class(ApiRequest)
|
261
|
+
|
262
|
+
|
263
|
+
class ListJobsRequest(ApiRequest[runtime_pb2.RuntimeListJobsRequest, runtime_pb2.RuntimeListJobsResponse]):
|
264
|
+
|
265
|
+
def __init__(self, engine_id, request, context):
|
266
|
+
super().__init__(engine_id, "get_job_list", request, context)
|
267
|
+
|
268
|
+
def on_start(self):
|
269
|
+
self.actors().send(self._engine_id, "get_job_list")
|
270
|
+
|
271
|
+
@actors.Message
|
272
|
+
def job_list(self, job_list):
|
273
|
+
|
274
|
+
self._response = runtime_pb2.RuntimeListJobsResponse(
|
275
|
+
jobs=codec.encode(job_list))
|
276
|
+
|
277
|
+
self._mark_complete()
|
278
|
+
|
279
|
+
|
280
|
+
class GetJobStatusRequest(ApiRequest[runtime_pb2.RuntimeJobInfoRequest, runtime_pb2.RuntimeJobStatus]):
|
281
|
+
|
282
|
+
def __init__(self, engine_id, request, context):
|
283
|
+
|
284
|
+
super().__init__(engine_id, "get_job_status", request, context)
|
285
|
+
|
286
|
+
if request.HasField("jobKey"):
|
287
|
+
self._job_key = self._request.jobKey
|
288
|
+
elif request.HasField("jobSelector"):
|
289
|
+
self._job_key = util.object_key(self._request.jobSelector)
|
290
|
+
else:
|
291
|
+
raise ex.EValidation("Bad request: Neither jobKey nor jobSelector is specified")
|
292
|
+
|
293
|
+
def on_start(self):
|
294
|
+
self.actors().send(self._engine_id, "get_job_details", self._job_key, details=False)
|
295
|
+
|
296
|
+
@actors.Message
|
297
|
+
def job_details(self, job_details: tp.Optional[config.JobResult]):
|
298
|
+
|
299
|
+
if job_details is None:
|
300
|
+
self._grpc_code = grpc.StatusCode.NOT_FOUND
|
301
|
+
self._grpc_message = f"Job not found: [{self._job_key}]"
|
302
|
+
|
303
|
+
else:
|
304
|
+
self._response = runtime_pb2.RuntimeJobStatus(
|
305
|
+
jobId=codec.encode(job_details.jobId),
|
306
|
+
statusCode=codec.encode(job_details.statusCode),
|
307
|
+
statusMessage=codec.encode(job_details.statusMessage))
|
308
|
+
|
309
|
+
self._mark_complete()
|
310
|
+
|
311
|
+
|
312
|
+
class GetJobResultRequest(ApiRequest[runtime_pb2.RuntimeJobInfoRequest, runtime_pb2.RuntimeJobResult]):
|
313
|
+
|
314
|
+
def __init__(self, engine_id, request, context):
|
315
|
+
|
316
|
+
super().__init__(engine_id, "get_job_result", request, context)
|
317
|
+
|
318
|
+
if request.HasField("jobKey"):
|
319
|
+
self._job_key = self._request.jobKey
|
320
|
+
elif request.HasField("jobSelector"):
|
321
|
+
self._job_key = util.object_key(self._request.jobSelector)
|
322
|
+
else:
|
323
|
+
raise ex.EValidation("Bad request: Neither jobKey nor jobSelector is specified")
|
324
|
+
|
325
|
+
def on_start(self):
|
326
|
+
self.actors().send(self._engine_id, "get_job_details", self._job_key, details=True)
|
327
|
+
|
328
|
+
@actors.Message
|
329
|
+
def job_details(self, job_details: tp.Optional[config.JobResult]):
|
330
|
+
|
331
|
+
if job_details is None:
|
332
|
+
self._grpc_code = grpc.StatusCode.NOT_FOUND
|
333
|
+
self._grpc_message = f"Job not found: [{self._job_key}]"
|
334
|
+
|
335
|
+
else:
|
61
336
|
|
62
|
-
|
337
|
+
encoded_results = dict((k, codec.encode(v)) for k, v in job_details.results.items())
|
63
338
|
|
64
|
-
|
65
|
-
|
339
|
+
self._response = runtime_pb2.RuntimeJobResult(
|
340
|
+
jobId=codec.encode(job_details.jobId),
|
341
|
+
statusCode=codec.encode(job_details.statusCode),
|
342
|
+
statusMessage=codec.encode(job_details.statusMessage),
|
343
|
+
results=encoded_results)
|
66
344
|
|
67
|
-
|
68
|
-
self.__thread_pool.shutdown()
|
345
|
+
self._mark_complete()
|
@@ -14,89 +14,216 @@
|
|
14
14
|
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
|
-
import
|
18
|
-
import typing as tp
|
17
|
+
import dataclasses as _dc
|
19
18
|
import decimal
|
20
19
|
import enum
|
21
|
-
import
|
20
|
+
import io
|
22
21
|
import inspect
|
23
|
-
import
|
22
|
+
import json
|
23
|
+
import os
|
24
|
+
import pathlib
|
25
|
+
import re
|
26
|
+
import typing as tp
|
27
|
+
import urllib.parse as _urlp
|
28
|
+
import uuid
|
24
29
|
|
30
|
+
import tracdap.rt.config as _config
|
25
31
|
import tracdap.rt.exceptions as _ex
|
32
|
+
import tracdap.rt.ext.plugins as _plugins
|
33
|
+
import tracdap.rt.ext.config as _config_ext
|
26
34
|
import tracdap.rt._impl.util as _util
|
27
35
|
|
28
|
-
import pathlib
|
29
|
-
import json
|
30
36
|
import yaml
|
31
37
|
import yaml.parser
|
32
38
|
|
33
|
-
|
34
39
|
_T = tp.TypeVar('_T')
|
35
40
|
|
36
41
|
|
37
|
-
class
|
42
|
+
class ConfigManager:
|
38
43
|
|
39
|
-
|
40
|
-
|
41
|
-
|
44
|
+
@classmethod
|
45
|
+
def for_root_config(cls, root_config_file: tp.Union[str, pathlib.Path, None]) -> ConfigManager:
|
46
|
+
|
47
|
+
if isinstance(root_config_file, pathlib.Path):
|
48
|
+
root_file_path = cls._resolve_scheme(root_config_file)
|
49
|
+
root_dir_path = cls._resolve_scheme(root_config_file.parent)
|
50
|
+
if root_dir_path[-1] not in ["/", "\\"]:
|
51
|
+
root_dir_path += os.sep
|
52
|
+
root_file_url = _urlp.urlparse(root_file_path, scheme="file")
|
53
|
+
root_dir_url = _urlp.urlparse(root_dir_path, scheme="file")
|
54
|
+
return ConfigManager(root_dir_url, root_file_url, )
|
55
|
+
|
56
|
+
elif isinstance(root_config_file, str):
|
57
|
+
root_file_with_scheme = cls._resolve_scheme(root_config_file)
|
58
|
+
root_file_url = _urlp.urlparse(root_file_with_scheme, scheme="file")
|
59
|
+
root_dir_path = str(pathlib.Path(root_file_url.path).parent)
|
60
|
+
if root_dir_path[-1] not in ["/", "\\"]:
|
61
|
+
root_dir_path += os.sep if root_file_url.scheme == "file" else "/"
|
62
|
+
root_dir_url = _urlp.urlparse(_urlp.urljoin(root_file_url.geturl(), root_dir_path))
|
63
|
+
return ConfigManager(root_dir_url, root_file_url)
|
42
64
|
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
str: str,
|
48
|
-
decimal.Decimal: decimal.Decimal
|
49
|
-
# TODO: Date (requires type system)
|
50
|
-
# TODO: Datetime (requires type system)
|
51
|
-
}
|
65
|
+
else:
|
66
|
+
working_dir_path = str(pathlib.Path.cwd().resolve())
|
67
|
+
working_dir_url = _urlp.urlparse(str(working_dir_path), scheme="file")
|
68
|
+
return ConfigManager(working_dir_url, None)
|
52
69
|
|
53
|
-
|
70
|
+
@classmethod
|
71
|
+
def for_root_dir(cls, root_config_dir: tp.Union[str, pathlib.Path]) -> ConfigManager:
|
72
|
+
|
73
|
+
if isinstance(root_config_dir, pathlib.Path):
|
74
|
+
root_dir_path = cls._resolve_scheme(root_config_dir)
|
75
|
+
if root_dir_path[-1] not in ["/", "\\"]:
|
76
|
+
root_dir_path += os.sep
|
77
|
+
root_dir_url = _urlp.urlparse(root_dir_path, scheme="file")
|
78
|
+
return ConfigManager(root_dir_url, None)
|
79
|
+
|
80
|
+
elif isinstance(root_config_dir, str):
|
81
|
+
root_dir_with_scheme = cls._resolve_scheme(root_config_dir)
|
82
|
+
if root_dir_with_scheme[-1] not in ["/", "\\"]:
|
83
|
+
root_dir_with_scheme += "/"
|
84
|
+
root_dir_url = _urlp.urlparse(root_dir_with_scheme, scheme="file")
|
85
|
+
return ConfigManager(root_dir_url, None)
|
86
|
+
|
87
|
+
# Should never happen since root dir is specified explicitly
|
88
|
+
else:
|
89
|
+
raise _ex.ETracInternal("Wrong parameter type for root_config_dir")
|
90
|
+
|
91
|
+
@classmethod
|
92
|
+
def _resolve_scheme(cls, raw_url: tp.Union[str, pathlib.Path]) -> str:
|
93
|
+
|
94
|
+
if isinstance(raw_url, pathlib.Path):
|
95
|
+
return "file:" + str(raw_url.resolve())
|
96
|
+
|
97
|
+
# Look for drive letters on Windows - these can be mis-interpreted as URL scheme
|
98
|
+
# If there is a drive letter, explicitly set scheme = file instead
|
99
|
+
if len(raw_url) > 1 and raw_url[1] == ":":
|
100
|
+
return "file:" + raw_url
|
101
|
+
else:
|
102
|
+
return raw_url
|
103
|
+
|
104
|
+
def __init__(self, root_dir_url: _urlp.ParseResult, root_file_url: tp.Optional[_urlp.ParseResult]):
|
54
105
|
self._log = _util.logger_for_object(self)
|
55
|
-
self.
|
56
|
-
self.
|
57
|
-
|
106
|
+
self._root_dir_url = root_dir_url
|
107
|
+
self._root_file_url = root_file_url
|
108
|
+
|
109
|
+
def config_dir_path(self):
|
110
|
+
if self._root_dir_url.scheme == "file":
|
111
|
+
return pathlib.Path(self._root_dir_url.path).resolve()
|
112
|
+
else:
|
113
|
+
return None
|
114
|
+
|
115
|
+
def load_root_object(
|
116
|
+
self, config_class: type(_T),
|
117
|
+
dev_mode_locations: tp.List[str] = None,
|
118
|
+
config_file_name: tp.Optional[str] = None) -> _T:
|
58
119
|
|
59
|
-
|
120
|
+
# Root config not available normally means you're using embedded config
|
121
|
+
# In which case this method should not be called
|
122
|
+
if self._root_file_url is None:
|
123
|
+
message = f"Root config file not available"
|
124
|
+
self._log.error(message)
|
125
|
+
raise _ex.EConfigLoad(message)
|
126
|
+
|
127
|
+
resolved_url = self._root_file_url
|
60
128
|
|
61
129
|
if config_file_name is not None:
|
62
|
-
self._log.info(f"Loading {config_file_name} config: {
|
130
|
+
self._log.info(f"Loading {config_file_name} config: {self._url_to_str(resolved_url)}")
|
63
131
|
else:
|
64
|
-
self._log.info(f"Loading config file: {
|
132
|
+
self._log.info(f"Loading config file: {self._url_to_str(resolved_url)}")
|
133
|
+
|
134
|
+
config_dict = self._load_config_dict(resolved_url)
|
65
135
|
|
66
|
-
|
67
|
-
|
136
|
+
parser = ConfigParser(config_class, dev_mode_locations)
|
137
|
+
return parser.parse(config_dict, resolved_url.path)
|
68
138
|
|
69
|
-
|
70
|
-
|
139
|
+
def load_config_object(
|
140
|
+
self, config_url: tp.Union[str, pathlib.Path],
|
141
|
+
config_class: type(_T),
|
142
|
+
dev_mode_locations: tp.List[str] = None,
|
143
|
+
config_file_name: tp.Optional[str] = None) -> _T:
|
71
144
|
|
72
|
-
|
73
|
-
config_path = config_file
|
145
|
+
resolved_url = self._resolve_config_file(config_url)
|
74
146
|
|
147
|
+
if config_file_name is not None:
|
148
|
+
self._log.info(f"Loading {config_file_name} config: {self._url_to_str(resolved_url)}")
|
75
149
|
else:
|
76
|
-
|
77
|
-
err = f"Attempt to load an invalid config file, expected a path, got {config_file_type}"
|
78
|
-
self._log.error(err)
|
79
|
-
raise _ex.EConfigLoad(err)
|
150
|
+
self._log.info(f"Loading config file: {self._url_to_str(resolved_url)}")
|
80
151
|
|
81
|
-
|
82
|
-
msg = f"Config file not found: [{config_file}]"
|
83
|
-
self._log.error(msg)
|
84
|
-
raise _ex.EConfigLoad(msg)
|
152
|
+
config_dict = self._load_config_dict(resolved_url)
|
85
153
|
|
86
|
-
|
87
|
-
|
88
|
-
self._log.error(msg)
|
89
|
-
raise _ex.EConfigLoad(msg)
|
154
|
+
parser = ConfigParser(config_class, dev_mode_locations)
|
155
|
+
return parser.parse(config_dict, config_url)
|
90
156
|
|
91
|
-
|
157
|
+
def load_config_file(
|
158
|
+
self, config_url: tp.Union[str, pathlib.Path],
|
159
|
+
config_file_name: tp.Optional[str] = None) -> bytes:
|
92
160
|
|
93
|
-
|
161
|
+
resolved_url = self._resolve_config_file(config_url)
|
162
|
+
|
163
|
+
if config_file_name is not None:
|
164
|
+
self._log.info(f"Loading {config_file_name} config: {self._url_to_str(resolved_url)}")
|
165
|
+
else:
|
166
|
+
self._log.info(f"Loading config file: {self._url_to_str(resolved_url)}")
|
167
|
+
|
168
|
+
return self._load_config_file(resolved_url)
|
169
|
+
|
170
|
+
def _resolve_config_file(self, config_url: tp.Union[str, pathlib.Path]) -> _urlp.ParseResult:
|
171
|
+
|
172
|
+
# If the config URL defines a scheme, treat it as absolute
|
173
|
+
# (This also works for Windows paths, C:\ is an absolute path)
|
174
|
+
if ":" in str(config_url):
|
175
|
+
absolute_url = str(config_url)
|
176
|
+
# If the root URL is a path, resolve using path logic (this allows for config_url to be an absolute path)
|
177
|
+
elif self._root_dir_url.scheme == "file":
|
178
|
+
absolute_url = str(pathlib.Path(self._root_dir_url.path).joinpath(str(config_url)))
|
179
|
+
# Otherwise resolve relative to the root URL
|
180
|
+
else:
|
181
|
+
absolute_url = _urlp.urljoin(self._root_dir_url.geturl(), str(config_url))
|
182
|
+
|
183
|
+
# Look for drive letters on Windows - these can be mis-interpreted as URL scheme
|
184
|
+
# If there is a drive letter, explicitly set scheme = file instead
|
185
|
+
if len(absolute_url) > 1 and absolute_url[1] == ":":
|
186
|
+
absolute_url = "file:" + absolute_url
|
187
|
+
|
188
|
+
return _urlp.urlparse(absolute_url, scheme="file")
|
189
|
+
|
190
|
+
def _load_config_file(self, resolved_url: _urlp.ParseResult) -> bytes:
|
191
|
+
|
192
|
+
loader = self._get_loader(resolved_url)
|
193
|
+
config_url = self._url_to_str(resolved_url)
|
194
|
+
|
195
|
+
if not loader.has_config_file(config_url):
|
196
|
+
message = f"Config file not found: {config_url}"
|
197
|
+
self._log.error(message)
|
198
|
+
raise _ex.EConfigLoad(message)
|
199
|
+
|
200
|
+
return loader.load_config_file(config_url)
|
201
|
+
|
202
|
+
def _load_config_dict(self, resolved_url: _urlp.ParseResult) -> dict:
|
203
|
+
|
204
|
+
loader = self._get_loader(resolved_url)
|
205
|
+
config_url = self._url_to_str(resolved_url)
|
206
|
+
|
207
|
+
if loader.has_config_dict(config_url):
|
208
|
+
return loader.load_config_dict(config_url)
|
209
|
+
|
210
|
+
elif loader.has_config_file(config_url):
|
211
|
+
config_bytes = loader.load_config_file(config_url)
|
212
|
+
config_path = pathlib.Path(resolved_url.path)
|
213
|
+
return self._parse_config_dict(config_bytes, config_path)
|
214
|
+
|
215
|
+
else:
|
216
|
+
message = f"Config file not found: {config_url}"
|
217
|
+
self._log.error(message)
|
218
|
+
raise _ex.EConfigLoad(message)
|
219
|
+
|
220
|
+
def _parse_config_dict(self, config_bytes: bytes, config_path: pathlib.Path):
|
94
221
|
|
95
222
|
# Read in the raw config, use the file extension to decide which format to expect
|
96
223
|
|
97
224
|
try:
|
98
225
|
|
99
|
-
with
|
226
|
+
with io.BytesIO(config_bytes) as config_stream:
|
100
227
|
|
101
228
|
extension = config_path.suffix.lower()
|
102
229
|
|
@@ -123,11 +250,54 @@ class ConfigParser(tp.Generic[_T]):
|
|
123
250
|
self._log.error(err)
|
124
251
|
raise _ex.EConfigParse(err) from e
|
125
252
|
|
126
|
-
except yaml.parser.ParserError as e:
|
253
|
+
except (yaml.parser.ParserError, yaml.reader.ReaderError) as e:
|
127
254
|
err = f"Config file contains invalid YAML ({str(e)})"
|
128
255
|
self._log.error(err)
|
129
256
|
raise _ex.EConfigParse(err) from e
|
130
257
|
|
258
|
+
def _get_loader(self, resolved_url: _urlp.ParseResult) -> _config_ext.IConfigLoader:
|
259
|
+
|
260
|
+
protocol = resolved_url.scheme
|
261
|
+
loader_config = _config.PluginConfig(protocol)
|
262
|
+
|
263
|
+
if not _plugins.PluginManager.is_plugin_available(_config_ext.IConfigLoader, protocol):
|
264
|
+
message = f"No config loader available for protocol [{protocol}]: {self._url_to_str(resolved_url)}"
|
265
|
+
self._log.error(message)
|
266
|
+
raise _ex.EConfigLoad(message)
|
267
|
+
|
268
|
+
return _plugins.PluginManager.load_config_plugin(_config_ext.IConfigLoader, loader_config)
|
269
|
+
|
270
|
+
@staticmethod
|
271
|
+
def _url_to_str(url: _urlp.ParseResult) -> str:
|
272
|
+
|
273
|
+
if url.scheme == "file" and not url.netloc:
|
274
|
+
return url.path
|
275
|
+
else:
|
276
|
+
return url.geturl()
|
277
|
+
|
278
|
+
|
279
|
+
class ConfigParser(tp.Generic[_T]):
|
280
|
+
|
281
|
+
# The metaclass for generic types varies between versions of the typing library
|
282
|
+
# To work around this, detect the correct metaclass by inspecting a generic type variable
|
283
|
+
__generic_metaclass = type(tp.List[object])
|
284
|
+
|
285
|
+
__primitive_types: tp.Dict[type, callable] = {
|
286
|
+
bool: bool,
|
287
|
+
int: int,
|
288
|
+
float: float,
|
289
|
+
str: str,
|
290
|
+
decimal.Decimal: decimal.Decimal
|
291
|
+
# TODO: Date (requires type system)
|
292
|
+
# TODO: Datetime (requires type system)
|
293
|
+
}
|
294
|
+
|
295
|
+
def __init__(self, config_class: _T.__class__, dev_mode_locations: tp.List[str] = None):
|
296
|
+
self._log = _util.logger_for_object(self)
|
297
|
+
self._config_class = config_class
|
298
|
+
self._dev_mode_locations = dev_mode_locations or []
|
299
|
+
self._errors = []
|
300
|
+
|
131
301
|
def parse(self, config_dict: dict, config_file: tp.Union[str, pathlib.Path] = None) -> _T:
|
132
302
|
|
133
303
|
# If config is empty, return a default (blank) config
|