tracdap-runtime 0.6.2__py3-none-any.whl → 0.6.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71) hide show
  1. tracdap/rt/_exec/actors.py +87 -10
  2. tracdap/rt/_exec/context.py +207 -100
  3. tracdap/rt/_exec/dev_mode.py +52 -20
  4. tracdap/rt/_exec/engine.py +79 -14
  5. tracdap/rt/_exec/functions.py +14 -17
  6. tracdap/rt/_exec/runtime.py +83 -40
  7. tracdap/rt/_exec/server.py +306 -29
  8. tracdap/rt/_impl/config_parser.py +219 -49
  9. tracdap/rt/_impl/data.py +70 -5
  10. tracdap/rt/_impl/grpc/codec.py +60 -5
  11. tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2.py +19 -19
  12. tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2.pyi +11 -9
  13. tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2_grpc.py +25 -25
  14. tracdap/rt/_impl/grpc/tracdap/metadata/data_pb2.py +18 -18
  15. tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.py +28 -16
  16. tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.pyi +37 -6
  17. tracdap/rt/_impl/grpc/tracdap/metadata/object_pb2.py +8 -3
  18. tracdap/rt/_impl/grpc/tracdap/metadata/object_pb2.pyi +13 -2
  19. tracdap/rt/_impl/guard_rails.py +21 -0
  20. tracdap/rt/_impl/models.py +25 -0
  21. tracdap/rt/_impl/static_api.py +43 -13
  22. tracdap/rt/_impl/type_system.py +17 -0
  23. tracdap/rt/_impl/validation.py +47 -4
  24. tracdap/rt/_plugins/config_local.py +49 -0
  25. tracdap/rt/_version.py +1 -1
  26. tracdap/rt/api/hook.py +6 -5
  27. tracdap/rt/api/model_api.py +50 -7
  28. tracdap/rt/api/static_api.py +81 -23
  29. tracdap/rt/config/__init__.py +4 -4
  30. tracdap/rt/config/common.py +25 -15
  31. tracdap/rt/config/job.py +2 -2
  32. tracdap/rt/config/platform.py +25 -35
  33. tracdap/rt/config/result.py +2 -2
  34. tracdap/rt/config/runtime.py +4 -2
  35. tracdap/rt/ext/config.py +34 -0
  36. tracdap/rt/ext/embed.py +1 -3
  37. tracdap/rt/ext/plugins.py +47 -6
  38. tracdap/rt/launch/cli.py +11 -4
  39. tracdap/rt/launch/launch.py +53 -12
  40. tracdap/rt/metadata/__init__.py +17 -17
  41. tracdap/rt/metadata/common.py +2 -2
  42. tracdap/rt/metadata/custom.py +3 -3
  43. tracdap/rt/metadata/data.py +12 -12
  44. tracdap/rt/metadata/file.py +6 -6
  45. tracdap/rt/metadata/flow.py +6 -6
  46. tracdap/rt/metadata/job.py +8 -8
  47. tracdap/rt/metadata/model.py +21 -11
  48. tracdap/rt/metadata/object.py +3 -0
  49. tracdap/rt/metadata/object_id.py +8 -8
  50. tracdap/rt/metadata/search.py +5 -5
  51. tracdap/rt/metadata/stoarge.py +6 -6
  52. tracdap/rt/metadata/tag.py +1 -1
  53. tracdap/rt/metadata/tag_update.py +1 -1
  54. tracdap/rt/metadata/type.py +4 -4
  55. {tracdap_runtime-0.6.2.dist-info → tracdap_runtime-0.6.4.dist-info}/METADATA +4 -4
  56. tracdap_runtime-0.6.4.dist-info/RECORD +112 -0
  57. {tracdap_runtime-0.6.2.dist-info → tracdap_runtime-0.6.4.dist-info}/WHEEL +1 -1
  58. tracdap/rt/_impl/grpc/tracdap/config/common_pb2.py +0 -55
  59. tracdap/rt/_impl/grpc/tracdap/config/common_pb2.pyi +0 -103
  60. tracdap/rt/_impl/grpc/tracdap/config/job_pb2.py +0 -42
  61. tracdap/rt/_impl/grpc/tracdap/config/job_pb2.pyi +0 -44
  62. tracdap/rt/_impl/grpc/tracdap/config/platform_pb2.py +0 -71
  63. tracdap/rt/_impl/grpc/tracdap/config/platform_pb2.pyi +0 -197
  64. tracdap/rt/_impl/grpc/tracdap/config/result_pb2.py +0 -37
  65. tracdap/rt/_impl/grpc/tracdap/config/result_pb2.pyi +0 -35
  66. tracdap/rt/_impl/grpc/tracdap/config/runtime_pb2.py +0 -42
  67. tracdap/rt/_impl/grpc/tracdap/config/runtime_pb2.pyi +0 -46
  68. tracdap/rt/ext/_guard.py +0 -37
  69. tracdap_runtime-0.6.2.dist-info/RECORD +0 -121
  70. {tracdap_runtime-0.6.2.dist-info → tracdap_runtime-0.6.4.dist-info}/LICENSE +0 -0
  71. {tracdap_runtime-0.6.2.dist-info → tracdap_runtime-0.6.4.dist-info}/top_level.txt +0 -0
@@ -12,57 +12,334 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ import asyncio
16
+ import threading
15
17
  import typing as tp
16
- import concurrent.futures as futures
18
+
19
+ import tracdap.rt.config as config
20
+ import tracdap.rt.exceptions as ex
21
+ import tracdap.rt._exec.actors as actors
22
+ import tracdap.rt._impl.grpc.codec as codec # noqa
23
+ import tracdap.rt._impl.util as util # noqa
24
+
25
+ # Check whether gRPC is installed before trying to load any of the generated modules
26
+ try:
27
+ import grpc.aio # noqa
28
+ import google.protobuf.message as _msg # noqa
29
+ except ImportError:
30
+ raise ex.EStartup("The runtime API server cannot be enabled because gRPC libraries are not installed")
17
31
 
18
32
  # Imports for gRPC generated code, these are managed by build_runtime.py for distribution
19
33
  import tracdap.rt._impl.grpc.tracdap.api.internal.runtime_pb2 as runtime_pb2
20
34
  import tracdap.rt._impl.grpc.tracdap.api.internal.runtime_pb2_grpc as runtime_grpc
21
- import grpc
22
35
 
23
36
 
24
37
  class RuntimeApiServer(runtime_grpc.TracRuntimeApiServicer):
25
38
 
26
- __THREAD_POOL_DEFAULT_SIZE = 2
27
- __THREAD_NAME_PREFIX = "server-"
28
- __DEFAULT_SHUTDOWN_TIMEOUT = 10.0 # seconds
39
+ # Default timeout values in seconds
40
+ __DEFAULT_STARTUP_TIMEOUT = 5.0
41
+ __DEFAULT_SHUTDOWN_TIMEOUT = 10.0
42
+ __DEFAULT_REQUEST_TIMEOUT = 10.0
43
+
44
+ def __init__(self, system: actors.ActorSystem, port: int):
45
+
46
+ self.__log = util.logger_for_object(self)
47
+
48
+ self.__system = system
49
+ self.__engine_id = system.main_id()
50
+ self.__agent: tp.Optional[ApiAgent] = None
29
51
 
30
- def __init__(self, port: int, n_workers: int = None):
31
52
  self.__port = port
32
- self.__n_workers = n_workers or self.__THREAD_POOL_DEFAULT_SIZE
33
- self.__server: tp.Optional[grpc.Server] = None
34
- self.__thread_pool: tp.Optional[futures.ThreadPoolExecutor] = None
53
+ self.__request_timeout = self.__DEFAULT_REQUEST_TIMEOUT # Not configurable atm
54
+ self.__server: tp.Optional[grpc.aio.Server] = None
55
+ self.__server_thread: tp.Optional[threading.Thread] = None
56
+ self.__event_loop: tp.Optional[asyncio.AbstractEventLoop] = None
57
+
58
+ self.__start_signal: tp.Optional[threading.Event] = None
59
+ self.__stop_signal: tp.Optional[asyncio.Event] = None
60
+
61
+ def start(self, startup_timeout: float = None):
62
+
63
+ if self.__start_signal is not None:
64
+ return
65
+
66
+ timeout = startup_timeout or self.__DEFAULT_SHUTDOWN_TIMEOUT
67
+
68
+ self.__start_signal = threading.Event()
69
+ self.__server_thread = threading.Thread(target=self.__server_main, name="api_server", daemon=True)
70
+ self.__server_thread.start()
71
+
72
+ try:
73
+ self.__start_signal.wait(timeout)
74
+ except TimeoutError as e:
75
+ raise ex.EStartup("Runtime API failed to start") from e
76
+
77
+ def stop(self, shutdown_timeout: float = None):
78
+
79
+ if self.__server is None:
80
+ return
81
+
82
+ timeout = shutdown_timeout or self.__DEFAULT_SHUTDOWN_TIMEOUT
35
83
 
36
- def listJobs(self, request, context):
37
- return super().listJobs(request, context)
84
+ self.__event_loop.call_soon_threadsafe(lambda: self.__stop_signal.set())
85
+ self.__server_thread.join(timeout)
38
86
 
39
- def getJobStatus(self, request: runtime_pb2.BatchJobStatusRequest, context: grpc.ServicerContext):
40
- return super().getJobStatus(request, context)
87
+ if self.__server_thread.is_alive():
88
+ self.__log.warning("Runtime API server did not go down cleanly")
41
89
 
42
- def getJobDetails(self, request, context):
43
- return super().getJobDetails(request, context)
90
+ def __server_main(self):
44
91
 
45
- def start(self):
92
+ self.__event_loop = asyncio.new_event_loop()
93
+ self.__event_loop.run_until_complete(self.__server_main_async())
94
+ self.__event_loop.close()
46
95
 
47
- self.__thread_pool = futures.ThreadPoolExecutor(
48
- max_workers=self.__n_workers,
49
- thread_name_prefix=self.__THREAD_NAME_PREFIX)
96
+ async def __server_main_async(self):
50
97
 
51
- self.__server = grpc.server(self.__thread_pool)
98
+ server_address = f"[::]:{self.__port}"
52
99
 
53
- socket = f"[::]:{self.__port}"
54
- self.__server.add_insecure_port(socket)
100
+ # Asyncio events must be created inside the event loop for Python <= 3.9
101
+ self.__stop_signal = asyncio.Event()
55
102
 
103
+ # Agent using asyncio, so must be created inside the event loop
104
+ self.__agent = ApiAgent()
105
+ self.__system.spawn_agent(self.__agent)
106
+
107
+ self.__server = grpc.aio.server()
108
+ self.__server.add_insecure_port(server_address)
56
109
  runtime_grpc.add_TracRuntimeApiServicer_to_server(self, self.__server)
57
110
 
58
- self.__server.start()
111
+ await self.__server.start()
112
+ await self.__agent.started()
59
113
 
60
- def stop(self, shutdown_timeout: float = None):
114
+ self.__start_signal.set()
115
+
116
+ self.__log.info(f"Runtime API server is up and listening on port [{self.__port}]")
117
+
118
+ await asyncio.create_task(self.__stop_signal.wait())
119
+
120
+ self.__log.info(f"Shutdown signal received, runtime API server is going down...")
121
+
122
+ await self.__server.stop(self.__DEFAULT_SHUTDOWN_TIMEOUT)
123
+ self.__server = None
124
+
125
+ self.__log.info("Runtime API server has gone down cleanly")
126
+
127
+ async def listJobs(self, request: runtime_pb2.RuntimeListJobsRequest, context: grpc.aio.ServicerContext):
128
+
129
+ request_task = ListJobsRequest(self.__engine_id, request, context)
130
+ self.__agent.threadsafe().spawn(request_task)
131
+
132
+ return await request_task.complete(self.__request_timeout)
133
+
134
+ async def getJobStatus(self, request: runtime_pb2.RuntimeJobInfoRequest, context: grpc.ServicerContext):
135
+
136
+ request_task = GetJobStatusRequest(self.__engine_id, request, context)
137
+ self.__agent.threadsafe().spawn(request_task)
138
+
139
+ return await request_task.complete(self.__request_timeout)
140
+
141
+ async def getJobResult(self, request: runtime_pb2.RuntimeJobInfoRequest, context: grpc.ServicerContext):
142
+
143
+ request_task = GetJobResultRequest(self.__engine_id, request, context)
144
+ self.__agent.threadsafe().spawn(request_task)
145
+
146
+ return await request_task.complete(self.__request_timeout)
147
+
148
+
149
+ _T_REQUEST = tp.TypeVar("_T_REQUEST", bound=_msg.Message)
150
+ _T_RESPONSE = tp.TypeVar("_T_RESPONSE", bound=_msg.Message)
151
+
152
+
153
+ class ApiAgent(actors.ThreadsafeActor):
154
+
155
+ # API Agent is the parent actor that will be used to spawn API requests
156
+ # It must be created inside the asyncio event loop
157
+
158
+ def __init__(self):
159
+ super().__init__()
160
+ self._log = util.logger_for_object(self)
161
+ self._event_loop = asyncio.get_event_loop()
162
+ self.__start_signal = asyncio.Event()
163
+
164
+ def on_start(self):
165
+ self._event_loop.call_soon_threadsafe(lambda: self.__start_signal.set())
166
+
167
+ def on_signal(self, signal: actors.Signal) -> tp.Optional[bool]:
168
+
169
+ # Do not allow a failed request to bring down the API server
170
+ if signal.message == actors.SignalNames.FAILED:
171
+ error = signal.error if isinstance(signal, actors.ErrorSignal) else None
172
+ self._log.warning("Unhandled error during API request: " + str(error))
173
+ self._log.warning("The API agent will continue running")
174
+ return True
175
+
176
+ return False
177
+
178
+ async def started(self):
179
+ await self.__start_signal.wait()
180
+
181
+
182
+ class ApiRequest(actors.ThreadsafeActor, tp.Generic[_T_REQUEST, _T_RESPONSE]):
183
+
184
+ # API request is the bridge between asyncio events (gRPC) and actor messages (TRAC runtime engine)
185
+ # Requests objects must be created inside the asyncio event loop
186
+
187
+ _log = None
188
+
189
+ def __init__(
190
+ self, engine_id, method: str, request: _T_REQUEST,
191
+ context: grpc.aio.ServicerContext):
192
+
193
+ super().__init__()
194
+
195
+ self._engine_id = engine_id
196
+ self._method = method
197
+ self._request = request
198
+ self._response: tp.Optional[_T_RESPONSE] = None
199
+ self._error: tp.Optional[Exception] = None
200
+ self._grpc_code = grpc.StatusCode.OK
201
+ self._grpc_message = ""
202
+
203
+ self._context = context
204
+ self._event_loop = asyncio.get_event_loop()
205
+ self._completion = asyncio.Event()
206
+
207
+ self._log.info("API call start: %s()", self._method)
208
+
209
+ def _mark_complete(self):
210
+
211
+ self._event_loop.call_soon_threadsafe(lambda: self._completion.set())
212
+
213
+ def on_stop(self):
214
+
215
+ if self.state() == actors.ActorState.ERROR:
216
+ self._error = self.error()
217
+
218
+ self._mark_complete()
219
+
220
+ async def complete(self, request_timeout: float) -> _T_RESPONSE:
221
+
222
+ try:
223
+
224
+ completion_task = asyncio.create_task(self._completion.wait())
225
+ await asyncio.wait_for(completion_task, request_timeout)
226
+
227
+ if self._error:
228
+ raise self._error
229
+
230
+ elif self._grpc_code != grpc.StatusCode.OK:
231
+ self._log.info("API call failed: %s() %s %s", self._method, self._grpc_code.name, self._grpc_message)
232
+ self._context.set_code(self._grpc_code)
233
+ self._context.set_details(self._grpc_message)
234
+
235
+ elif self._response is not None:
236
+ self._log.info("API call succeeded: %s()", self._method)
237
+ return self._response
238
+
239
+ else:
240
+ raise ex.EUnexpected()
241
+
242
+ except TimeoutError:
243
+ self._completion.set()
244
+ self._context.set_code(grpc.StatusCode.DEADLINE_EXCEEDED)
245
+ self._context.set_details("The TRAC runtime engine did not respond")
246
+ self._log.error("API call failed: %s() %s", self._method, "The TRAC runtime engine did not respond")
247
+ raise
248
+
249
+ except Exception as e:
250
+ self._context.set_code(grpc.StatusCode.INTERNAL)
251
+ self._context.set_details("Internal server error")
252
+ self._log.error("API call failed: %s() %s", self._method, str(e))
253
+ self._log.exception(e)
254
+ raise
255
+
256
+ finally:
257
+ self.threadsafe().stop()
258
+
259
+
260
+ ApiRequest._log = util.logger_for_class(ApiRequest)
261
+
262
+
263
+ class ListJobsRequest(ApiRequest[runtime_pb2.RuntimeListJobsRequest, runtime_pb2.RuntimeListJobsResponse]):
264
+
265
+ def __init__(self, engine_id, request, context):
266
+ super().__init__(engine_id, "get_job_list", request, context)
267
+
268
+ def on_start(self):
269
+ self.actors().send(self._engine_id, "get_job_list")
270
+
271
+ @actors.Message
272
+ def job_list(self, job_list):
273
+
274
+ self._response = runtime_pb2.RuntimeListJobsResponse(
275
+ jobs=codec.encode(job_list))
276
+
277
+ self._mark_complete()
278
+
279
+
280
+ class GetJobStatusRequest(ApiRequest[runtime_pb2.RuntimeJobInfoRequest, runtime_pb2.RuntimeJobStatus]):
281
+
282
+ def __init__(self, engine_id, request, context):
283
+
284
+ super().__init__(engine_id, "get_job_status", request, context)
285
+
286
+ if request.HasField("jobKey"):
287
+ self._job_key = self._request.jobKey
288
+ elif request.HasField("jobSelector"):
289
+ self._job_key = util.object_key(self._request.jobSelector)
290
+ else:
291
+ raise ex.EValidation("Bad request: Neither jobKey nor jobSelector is specified")
292
+
293
+ def on_start(self):
294
+ self.actors().send(self._engine_id, "get_job_details", self._job_key, details=False)
295
+
296
+ @actors.Message
297
+ def job_details(self, job_details: tp.Optional[config.JobResult]):
298
+
299
+ if job_details is None:
300
+ self._grpc_code = grpc.StatusCode.NOT_FOUND
301
+ self._grpc_message = f"Job not found: [{self._job_key}]"
302
+
303
+ else:
304
+ self._response = runtime_pb2.RuntimeJobStatus(
305
+ jobId=codec.encode(job_details.jobId),
306
+ statusCode=codec.encode(job_details.statusCode),
307
+ statusMessage=codec.encode(job_details.statusMessage))
308
+
309
+ self._mark_complete()
310
+
311
+
312
+ class GetJobResultRequest(ApiRequest[runtime_pb2.RuntimeJobInfoRequest, runtime_pb2.RuntimeJobResult]):
313
+
314
+ def __init__(self, engine_id, request, context):
315
+
316
+ super().__init__(engine_id, "get_job_result", request, context)
317
+
318
+ if request.HasField("jobKey"):
319
+ self._job_key = self._request.jobKey
320
+ elif request.HasField("jobSelector"):
321
+ self._job_key = util.object_key(self._request.jobSelector)
322
+ else:
323
+ raise ex.EValidation("Bad request: Neither jobKey nor jobSelector is specified")
324
+
325
+ def on_start(self):
326
+ self.actors().send(self._engine_id, "get_job_details", self._job_key, details=True)
327
+
328
+ @actors.Message
329
+ def job_details(self, job_details: tp.Optional[config.JobResult]):
330
+
331
+ if job_details is None:
332
+ self._grpc_code = grpc.StatusCode.NOT_FOUND
333
+ self._grpc_message = f"Job not found: [{self._job_key}]"
334
+
335
+ else:
61
336
 
62
- grace = shutdown_timeout or self.__DEFAULT_SHUTDOWN_TIMEOUT
337
+ encoded_results = dict((k, codec.encode(v)) for k, v in job_details.results.items())
63
338
 
64
- if self.__server is not None:
65
- self.__server.stop(grace)
339
+ self._response = runtime_pb2.RuntimeJobResult(
340
+ jobId=codec.encode(job_details.jobId),
341
+ statusCode=codec.encode(job_details.statusCode),
342
+ statusMessage=codec.encode(job_details.statusMessage),
343
+ results=encoded_results)
66
344
 
67
- if self.__thread_pool is not None:
68
- self.__thread_pool.shutdown()
345
+ self._mark_complete()
@@ -14,89 +14,216 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
- import re
18
- import typing as tp
17
+ import dataclasses as _dc
19
18
  import decimal
20
19
  import enum
21
- import uuid
20
+ import io
22
21
  import inspect
23
- import dataclasses as _dc
22
+ import json
23
+ import os
24
+ import pathlib
25
+ import re
26
+ import typing as tp
27
+ import urllib.parse as _urlp
28
+ import uuid
24
29
 
30
+ import tracdap.rt.config as _config
25
31
  import tracdap.rt.exceptions as _ex
32
+ import tracdap.rt.ext.plugins as _plugins
33
+ import tracdap.rt.ext.config as _config_ext
26
34
  import tracdap.rt._impl.util as _util
27
35
 
28
- import pathlib
29
- import json
30
36
  import yaml
31
37
  import yaml.parser
32
38
 
33
-
34
39
  _T = tp.TypeVar('_T')
35
40
 
36
41
 
37
- class ConfigParser(tp.Generic[_T]):
42
+ class ConfigManager:
38
43
 
39
- # The metaclass for generic types varies between versions of the typing library
40
- # To work around this, detect the correct metaclass by inspecting a generic type variable
41
- __generic_metaclass = type(tp.List[object])
44
+ @classmethod
45
+ def for_root_config(cls, root_config_file: tp.Union[str, pathlib.Path, None]) -> ConfigManager:
46
+
47
+ if isinstance(root_config_file, pathlib.Path):
48
+ root_file_path = cls._resolve_scheme(root_config_file)
49
+ root_dir_path = cls._resolve_scheme(root_config_file.parent)
50
+ if root_dir_path[-1] not in ["/", "\\"]:
51
+ root_dir_path += os.sep
52
+ root_file_url = _urlp.urlparse(root_file_path, scheme="file")
53
+ root_dir_url = _urlp.urlparse(root_dir_path, scheme="file")
54
+ return ConfigManager(root_dir_url, root_file_url, )
55
+
56
+ elif isinstance(root_config_file, str):
57
+ root_file_with_scheme = cls._resolve_scheme(root_config_file)
58
+ root_file_url = _urlp.urlparse(root_file_with_scheme, scheme="file")
59
+ root_dir_path = str(pathlib.Path(root_file_url.path).parent)
60
+ if root_dir_path[-1] not in ["/", "\\"]:
61
+ root_dir_path += os.sep if root_file_url.scheme == "file" else "/"
62
+ root_dir_url = _urlp.urlparse(_urlp.urljoin(root_file_url.geturl(), root_dir_path))
63
+ return ConfigManager(root_dir_url, root_file_url)
42
64
 
43
- __primitive_types: tp.Dict[type, callable] = {
44
- bool: bool,
45
- int: int,
46
- float: float,
47
- str: str,
48
- decimal.Decimal: decimal.Decimal
49
- # TODO: Date (requires type system)
50
- # TODO: Datetime (requires type system)
51
- }
65
+ else:
66
+ working_dir_path = str(pathlib.Path.cwd().resolve())
67
+ working_dir_url = _urlp.urlparse(str(working_dir_path), scheme="file")
68
+ return ConfigManager(working_dir_url, None)
52
69
 
53
- def __init__(self, config_class: _T.__class__, dev_mode_locations: tp.List[str] = None):
70
+ @classmethod
71
+ def for_root_dir(cls, root_config_dir: tp.Union[str, pathlib.Path]) -> ConfigManager:
72
+
73
+ if isinstance(root_config_dir, pathlib.Path):
74
+ root_dir_path = cls._resolve_scheme(root_config_dir)
75
+ if root_dir_path[-1] not in ["/", "\\"]:
76
+ root_dir_path += os.sep
77
+ root_dir_url = _urlp.urlparse(root_dir_path, scheme="file")
78
+ return ConfigManager(root_dir_url, None)
79
+
80
+ elif isinstance(root_config_dir, str):
81
+ root_dir_with_scheme = cls._resolve_scheme(root_config_dir)
82
+ if root_dir_with_scheme[-1] not in ["/", "\\"]:
83
+ root_dir_with_scheme += "/"
84
+ root_dir_url = _urlp.urlparse(root_dir_with_scheme, scheme="file")
85
+ return ConfigManager(root_dir_url, None)
86
+
87
+ # Should never happen since root dir is specified explicitly
88
+ else:
89
+ raise _ex.ETracInternal("Wrong parameter type for root_config_dir")
90
+
91
+ @classmethod
92
+ def _resolve_scheme(cls, raw_url: tp.Union[str, pathlib.Path]) -> str:
93
+
94
+ if isinstance(raw_url, pathlib.Path):
95
+ return "file:" + str(raw_url.resolve())
96
+
97
+ # Look for drive letters on Windows - these can be mis-interpreted as URL scheme
98
+ # If there is a drive letter, explicitly set scheme = file instead
99
+ if len(raw_url) > 1 and raw_url[1] == ":":
100
+ return "file:" + raw_url
101
+ else:
102
+ return raw_url
103
+
104
+ def __init__(self, root_dir_url: _urlp.ParseResult, root_file_url: tp.Optional[_urlp.ParseResult]):
54
105
  self._log = _util.logger_for_object(self)
55
- self._config_class = config_class
56
- self._dev_mode_locations = dev_mode_locations or []
57
- self._errors = []
106
+ self._root_dir_url = root_dir_url
107
+ self._root_file_url = root_file_url
108
+
109
+ def config_dir_path(self):
110
+ if self._root_dir_url.scheme == "file":
111
+ return pathlib.Path(self._root_dir_url.path).resolve()
112
+ else:
113
+ return None
114
+
115
+ def load_root_object(
116
+ self, config_class: type(_T),
117
+ dev_mode_locations: tp.List[str] = None,
118
+ config_file_name: tp.Optional[str] = None) -> _T:
58
119
 
59
- def load_raw_config(self, config_file: tp.Union[str, pathlib.Path], config_file_name: str = None):
120
+ # Root config not available normally means you're using embedded config
121
+ # In which case this method should not be called
122
+ if self._root_file_url is None:
123
+ message = f"Root config file not available"
124
+ self._log.error(message)
125
+ raise _ex.EConfigLoad(message)
126
+
127
+ resolved_url = self._root_file_url
60
128
 
61
129
  if config_file_name is not None:
62
- self._log.info(f"Loading {config_file_name} config: {str(config_file)}")
130
+ self._log.info(f"Loading {config_file_name} config: {self._url_to_str(resolved_url)}")
63
131
  else:
64
- self._log.info(f"Loading config file: {str(config_file)}")
132
+ self._log.info(f"Loading config file: {self._url_to_str(resolved_url)}")
133
+
134
+ config_dict = self._load_config_dict(resolved_url)
65
135
 
66
- # Construct a Path for config_file and make sure the file exists
67
- # (For now, config must be on a locally mounted filesystem)
136
+ parser = ConfigParser(config_class, dev_mode_locations)
137
+ return parser.parse(config_dict, resolved_url.path)
68
138
 
69
- if isinstance(config_file, str):
70
- config_path = pathlib.Path(config_file)
139
+ def load_config_object(
140
+ self, config_url: tp.Union[str, pathlib.Path],
141
+ config_class: type(_T),
142
+ dev_mode_locations: tp.List[str] = None,
143
+ config_file_name: tp.Optional[str] = None) -> _T:
71
144
 
72
- elif isinstance(config_file, pathlib.Path):
73
- config_path = config_file
145
+ resolved_url = self._resolve_config_file(config_url)
74
146
 
147
+ if config_file_name is not None:
148
+ self._log.info(f"Loading {config_file_name} config: {self._url_to_str(resolved_url)}")
75
149
  else:
76
- config_file_type = type(config_file) if config_file is not None else "None"
77
- err = f"Attempt to load an invalid config file, expected a path, got {config_file_type}"
78
- self._log.error(err)
79
- raise _ex.EConfigLoad(err)
150
+ self._log.info(f"Loading config file: {self._url_to_str(resolved_url)}")
80
151
 
81
- if not config_path.exists():
82
- msg = f"Config file not found: [{config_file}]"
83
- self._log.error(msg)
84
- raise _ex.EConfigLoad(msg)
152
+ config_dict = self._load_config_dict(resolved_url)
85
153
 
86
- if not config_path.is_file():
87
- msg = f"Config path does not point to a regular file: [{config_file}]"
88
- self._log.error(msg)
89
- raise _ex.EConfigLoad(msg)
154
+ parser = ConfigParser(config_class, dev_mode_locations)
155
+ return parser.parse(config_dict, config_url)
90
156
 
91
- return self._parse_raw_config(config_path)
157
+ def load_config_file(
158
+ self, config_url: tp.Union[str, pathlib.Path],
159
+ config_file_name: tp.Optional[str] = None) -> bytes:
92
160
 
93
- def _parse_raw_config(self, config_path: pathlib.Path):
161
+ resolved_url = self._resolve_config_file(config_url)
162
+
163
+ if config_file_name is not None:
164
+ self._log.info(f"Loading {config_file_name} config: {self._url_to_str(resolved_url)}")
165
+ else:
166
+ self._log.info(f"Loading config file: {self._url_to_str(resolved_url)}")
167
+
168
+ return self._load_config_file(resolved_url)
169
+
170
+ def _resolve_config_file(self, config_url: tp.Union[str, pathlib.Path]) -> _urlp.ParseResult:
171
+
172
+ # If the config URL defines a scheme, treat it as absolute
173
+ # (This also works for Windows paths, C:\ is an absolute path)
174
+ if ":" in str(config_url):
175
+ absolute_url = str(config_url)
176
+ # If the root URL is a path, resolve using path logic (this allows for config_url to be an absolute path)
177
+ elif self._root_dir_url.scheme == "file":
178
+ absolute_url = str(pathlib.Path(self._root_dir_url.path).joinpath(str(config_url)))
179
+ # Otherwise resolve relative to the root URL
180
+ else:
181
+ absolute_url = _urlp.urljoin(self._root_dir_url.geturl(), str(config_url))
182
+
183
+ # Look for drive letters on Windows - these can be mis-interpreted as URL scheme
184
+ # If there is a drive letter, explicitly set scheme = file instead
185
+ if len(absolute_url) > 1 and absolute_url[1] == ":":
186
+ absolute_url = "file:" + absolute_url
187
+
188
+ return _urlp.urlparse(absolute_url, scheme="file")
189
+
190
+ def _load_config_file(self, resolved_url: _urlp.ParseResult) -> bytes:
191
+
192
+ loader = self._get_loader(resolved_url)
193
+ config_url = self._url_to_str(resolved_url)
194
+
195
+ if not loader.has_config_file(config_url):
196
+ message = f"Config file not found: {config_url}"
197
+ self._log.error(message)
198
+ raise _ex.EConfigLoad(message)
199
+
200
+ return loader.load_config_file(config_url)
201
+
202
+ def _load_config_dict(self, resolved_url: _urlp.ParseResult) -> dict:
203
+
204
+ loader = self._get_loader(resolved_url)
205
+ config_url = self._url_to_str(resolved_url)
206
+
207
+ if loader.has_config_dict(config_url):
208
+ return loader.load_config_dict(config_url)
209
+
210
+ elif loader.has_config_file(config_url):
211
+ config_bytes = loader.load_config_file(config_url)
212
+ config_path = pathlib.Path(resolved_url.path)
213
+ return self._parse_config_dict(config_bytes, config_path)
214
+
215
+ else:
216
+ message = f"Config file not found: {config_url}"
217
+ self._log.error(message)
218
+ raise _ex.EConfigLoad(message)
219
+
220
+ def _parse_config_dict(self, config_bytes: bytes, config_path: pathlib.Path):
94
221
 
95
222
  # Read in the raw config, use the file extension to decide which format to expect
96
223
 
97
224
  try:
98
225
 
99
- with config_path.open('r') as config_stream:
226
+ with io.BytesIO(config_bytes) as config_stream:
100
227
 
101
228
  extension = config_path.suffix.lower()
102
229
 
@@ -123,11 +250,54 @@ class ConfigParser(tp.Generic[_T]):
123
250
  self._log.error(err)
124
251
  raise _ex.EConfigParse(err) from e
125
252
 
126
- except yaml.parser.ParserError as e:
253
+ except (yaml.parser.ParserError, yaml.reader.ReaderError) as e:
127
254
  err = f"Config file contains invalid YAML ({str(e)})"
128
255
  self._log.error(err)
129
256
  raise _ex.EConfigParse(err) from e
130
257
 
258
+ def _get_loader(self, resolved_url: _urlp.ParseResult) -> _config_ext.IConfigLoader:
259
+
260
+ protocol = resolved_url.scheme
261
+ loader_config = _config.PluginConfig(protocol)
262
+
263
+ if not _plugins.PluginManager.is_plugin_available(_config_ext.IConfigLoader, protocol):
264
+ message = f"No config loader available for protocol [{protocol}]: {self._url_to_str(resolved_url)}"
265
+ self._log.error(message)
266
+ raise _ex.EConfigLoad(message)
267
+
268
+ return _plugins.PluginManager.load_config_plugin(_config_ext.IConfigLoader, loader_config)
269
+
270
+ @staticmethod
271
+ def _url_to_str(url: _urlp.ParseResult) -> str:
272
+
273
+ if url.scheme == "file" and not url.netloc:
274
+ return url.path
275
+ else:
276
+ return url.geturl()
277
+
278
+
279
+ class ConfigParser(tp.Generic[_T]):
280
+
281
+ # The metaclass for generic types varies between versions of the typing library
282
+ # To work around this, detect the correct metaclass by inspecting a generic type variable
283
+ __generic_metaclass = type(tp.List[object])
284
+
285
+ __primitive_types: tp.Dict[type, callable] = {
286
+ bool: bool,
287
+ int: int,
288
+ float: float,
289
+ str: str,
290
+ decimal.Decimal: decimal.Decimal
291
+ # TODO: Date (requires type system)
292
+ # TODO: Datetime (requires type system)
293
+ }
294
+
295
+ def __init__(self, config_class: _T.__class__, dev_mode_locations: tp.List[str] = None):
296
+ self._log = _util.logger_for_object(self)
297
+ self._config_class = config_class
298
+ self._dev_mode_locations = dev_mode_locations or []
299
+ self._errors = []
300
+
131
301
  def parse(self, config_dict: dict, config_file: tp.Union[str, pathlib.Path] = None) -> _T:
132
302
 
133
303
  # If config is empty, return a default (blank) config