splime 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spl/__init__.py +14 -0
- spl/client.py +1364 -0
- spl/core/__init__.py +23 -0
- spl/core/common.py +350 -0
- spl/core/entities/__init__.py +0 -0
- spl/core/entities/adapter.py +210 -0
- spl/core/entities/artifact.py +141 -0
- spl/core/entities/control.py +45 -0
- spl/core/entities/distribution.py +65 -0
- spl/core/entities/function.py +254 -0
- spl/core/entities/local_function.py +286 -0
- spl/core/entities/misc.py +14 -0
- spl/core/entities/module.py +88 -0
- spl/core/entities/node.py +286 -0
- spl/core/entities/node_function.py +79 -0
- spl/core/entities/node_remote.py +295 -0
- spl/core/entities/pipeline.py +436 -0
- spl/core/entities/scalar.py +55 -0
- spl/core/ir/__init__.py +0 -0
- spl/core/ir/common.py +34 -0
- spl/core/ir/parse.py +79 -0
- spl/core/ir/unparse.py +29 -0
- spl/core/ir/utils.py +163 -0
- spl/daemon/__init__.py +23 -0
- spl/daemon/__main__.py +11 -0
- spl/daemon/cli.py +582 -0
- spl/daemon/client.py +43 -0
- spl/daemon/docker_environment.py +329 -0
- spl/daemon/docker_pool.py +516 -0
- spl/daemon/environment.py +228 -0
- spl/daemon/environment_base.py +479 -0
- spl/daemon/heartbeat_service.py +119 -0
- spl/daemon/metadata.py +427 -0
- spl/daemon/remote_client.py +457 -0
- spl/daemon/repositories/__init__.py +17 -0
- spl/daemon/repositories/env.py +323 -0
- spl/daemon/repositories/library.py +181 -0
- spl/daemon/repositories/object.py +997 -0
- spl/daemon/repositories/run.py +279 -0
- spl/daemon/repositories/server_connection.py +657 -0
- spl/daemon/repositories/sync_event.py +129 -0
- spl/daemon/routes/__init__.py +1 -0
- spl/daemon/routes/_helpers.py +147 -0
- spl/daemon/routes/artifacts.py +77 -0
- spl/daemon/routes/diagnostics.py +114 -0
- spl/daemon/routes/envs.py +82 -0
- spl/daemon/routes/libraries.py +129 -0
- spl/daemon/routes/objects.py +174 -0
- spl/daemon/routes/remote.py +56 -0
- spl/daemon/routes/runs.py +96 -0
- spl/daemon/routes/server_connections.py +86 -0
- spl/daemon/runtime_backend.py +368 -0
- spl/daemon/runtime_config.py +133 -0
- spl/daemon/runtime_dependencies.py +459 -0
- spl/daemon/secret_store.py +187 -0
- spl/daemon/server.py +2224 -0
- spl/daemon/server_connection.py +267 -0
- spl/daemon/services/__init__.py +1 -0
- spl/daemon/services/sync.py +76 -0
- spl/daemon/signature.py +376 -0
- spl/daemon/storage_base.py +542 -0
- spl/daemon/store.py +436 -0
- spl/daemon/worker.py +526 -0
- spl/daemon_client.py +945 -0
- spl/pipeline_widget.py +1452 -0
- spl/py.typed +0 -0
- spl/server_client.py +787 -0
- splime-0.1.2.dist-info/METADATA +189 -0
- splime-0.1.2.dist-info/RECORD +74 -0
- splime-0.1.2.dist-info/WHEEL +5 -0
- splime-0.1.2.dist-info/entry_points.txt +2 -0
- splime-0.1.2.dist-info/licenses/LICENSE +201 -0
- splime-0.1.2.dist-info/licenses/NOTICE +8 -0
- splime-0.1.2.dist-info/top_level.txt +1 -0
spl/core/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import spl.core.entities.adapter
|
|
2
|
+
import spl.core.entities.artifact
|
|
3
|
+
import spl.core.entities.distribution
|
|
4
|
+
import spl.core.entities.function
|
|
5
|
+
# Registered before `module` so local user functions are inlined instead of
|
|
6
|
+
# being captured as a bare `from local_module import ...` (see local_function).
|
|
7
|
+
import spl.core.entities.local_function
|
|
8
|
+
import spl.core.entities.misc
|
|
9
|
+
import spl.core.entities.module
|
|
10
|
+
import spl.core.entities.node
|
|
11
|
+
import spl.core.entities.node_function
|
|
12
|
+
import spl.core.entities.node_remote
|
|
13
|
+
import spl.core.entities.pipeline
|
|
14
|
+
import spl.core.entities.scalar # noqa: F401
|
|
15
|
+
from spl.core.entities.node_remote import NodeRemote
|
|
16
|
+
from spl.core.ir.utils import spl_export_to_dir, spl_export_to_file, spl_import_from_file
|
|
17
|
+
|
|
18
|
+
__all__ = [
|
|
19
|
+
'NodeRemote',
|
|
20
|
+
'spl_export_to_dir',
|
|
21
|
+
'spl_export_to_file',
|
|
22
|
+
'spl_import_from_file'
|
|
23
|
+
]
|
spl/core/common.py
ADDED
|
@@ -0,0 +1,350 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import shutil
|
|
3
|
+
import tempfile
|
|
4
|
+
import weakref
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from dataclasses import dataclass, replace
|
|
7
|
+
from functools import reduce
|
|
8
|
+
from itertools import groupby
|
|
9
|
+
from operator import itemgetter
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from types import FunctionType
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
from spl.core.entities.adapter import Adapter
|
|
15
|
+
from spl.core.entities.artifact import ArtifactRef, compute_sha256
|
|
16
|
+
from spl.core.entities.node import (
|
|
17
|
+
DEFAULT_PORT,
|
|
18
|
+
FormattedOutputRef,
|
|
19
|
+
Node,
|
|
20
|
+
NodeInputRef,
|
|
21
|
+
NodeOutputRef,
|
|
22
|
+
)
|
|
23
|
+
from spl.core.entities.node_function import NodeFunction
|
|
24
|
+
from spl.core.entities.node_remote import NodeRemote
|
|
25
|
+
from spl.core.entities.pipeline import Pipeline
|
|
26
|
+
from spl.core.entities.scalar import Scalar
|
|
27
|
+
|
|
28
|
+
# Exact Python types that `Run._round_trip_artifact` returns unchanged when no
# adapter format is requested. Membership is tested with `type(value) in ...`,
# so subclasses of these types are not covered.
_JSON_NATIVE_TYPES = {str, int, float, bool, dict, list}
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass(frozen = True)
class PipelineBuilder:
    """Immutable pairing of a pipeline with the node that acts as its root.

    The optional `format` is applied to the root's default output when this
    builder is bound to an input of another builder (see `_update_pipeline`).
    """

    pipeline: Pipeline
    root: Node
    format: str | None = None

    @staticmethod
    def lift(x):
        """Turn `x` into a builder.

        A `PipelineBuilder` is returned as-is. A plain function is wrapped in
        a `NodeFunction`; `NodeFunction` and `NodeRemote` values are used
        directly. In those cases the result is a single-node pipeline without
        links. Any other value raises `ValueError`.
        """

        if isinstance(x, PipelineBuilder):
            return x

        if isinstance(x, FunctionType):
            root = NodeFunction(x)
        elif isinstance(x, (NodeFunction, NodeRemote)):
            root = x
        else:
            raise ValueError(x)

        return PipelineBuilder(
            Pipeline(nodes = {root}, links = set()),
            root)

    def get_input_node_refs(self, port_name: str, is_free: bool):
        """Return a `NodeInputRef` for every node input port named `port_name`.

        When `is_free` is true, refs that already occur as the first element
        of a pipeline link are filtered out.
        """

        matches = []
        for node in self.pipeline.nodes:
            for port in node.inputs:
                if port.name == port_name:
                    matches.append(NodeInputRef(node, port))

        if not is_free:
            return matches

        already_linked = {link[0] for link in self.pipeline.links}
        return [ref for ref in matches if ref not in already_linked]

    def bind(self, **kwargs):
        """Bind each keyword to the single input port carrying that name.

        Linked ports are considered too; more than one match raises
        `ValueError`.
        """

        return self._bind(kwargs, is_strict = True, is_free = False)

    def bind_all(self, **kwargs):
        """Bind each keyword to every not-yet-linked input port of that name."""

        return self._bind(kwargs, is_strict = False, is_free = True)

    def _bind(self, kwargs, is_strict: bool, is_free: bool):
        """Shared implementation of `bind` and `bind_all`.

        Raises `ValueError` when a name matches no port, or matches several
        ports while `is_strict` is set. Returns a new builder with the same
        root and format.
        """

        bound = self.pipeline
        for port_name, value in kwargs.items():
            targets = self.get_input_node_refs(port_name, is_free)

            if not targets:
                raise ValueError('node(s) for port `{}` is not found'.format(port_name))

            if is_strict and len(targets) > 1:
                raise ValueError('ambigious node for port `{}`'.format(port_name))

            for target in targets:
                bound = self._update_pipeline(bound, target, value)

        return PipelineBuilder(
            pipeline = bound,
            root = self.root,
            format = self.format)

    def alias(self, name):
        """Return a copy whose pipeline has `name` added as an alias of the root."""

        aliased = self.pipeline.add_alias(self.root, name)
        return replace(self, pipeline = aliased)

    def as_format(self, format: str) -> 'PipelineBuilder':
        """Return a builder whose output edge uses an artifact format."""

        if not isinstance(format, str):
            raise TypeError('pipeline builder format must be a string')
        if format == '':
            raise ValueError('pipeline builder format must be a non-empty string')
        return replace(self, format = format)

    @staticmethod
    def _update_pipeline(pipeline, ref, v):
        """Link input `ref` to `v` and return the resulting pipeline.

        A builder value is merged into `pipeline` and linked through its
        root's default output port (wrapped in `FormattedOutputRef` when the
        builder carries a format). Any other value is linked as a `Scalar`.
        """

        if not isinstance(v, PipelineBuilder):
            return pipeline.add_link(ref, Scalar(v))

        source = NodeOutputRef(
            v.root,
            v.root.get_output_port(DEFAULT_PORT))
        if v.format is not None:
            source = FormattedOutputRef(source, v.format)

        merged = pipeline | v.pipeline
        return merged.add_link(ref, source)

    def render(self, name: str | None = None):
        """Return the underlying pipeline with its `name` replaced by `name`."""

        return replace(self.pipeline, name = name)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
# Module-level shorthand for `PipelineBuilder.lift`.
lift = PipelineBuilder.lift
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def encode(value: Any, adapter: Adapter, artifacts_dir: Path) -> ArtifactRef:
    """Materialize a value with an adapter and return its artifact reference.

    A uniquely named `artifact-*` file is created inside `artifacts_dir`,
    `adapter.save` writes `value` to it, and the returned `ArtifactRef`
    records the adapter key, the file path, its size and its SHA-256 digest.

    If any step after the file is created fails (saving, reading the size,
    hashing or building the reference), the file is removed before the
    exception propagates, so no orphaned artifact is left behind.
    """

    fd, artifact_path_value = tempfile.mkstemp(
        prefix = 'artifact-',
        dir = artifacts_dir)
    # Only the reserved name is needed; the adapter opens the path itself.
    os.close(fd)
    artifact_path = Path(artifact_path_value)

    try:
        adapter.save(str(artifact_path), value)
        size = artifact_path.stat().st_size
        sha256 = compute_sha256(artifact_path)
        return ArtifactRef(
            key = adapter.key,
            uri = str(artifact_path),
            sha256 = sha256,
            size = size)
    except BaseException:
        # BaseException on purpose: clean up on KeyboardInterrupt as well.
        artifact_path.unlink(missing_ok = True)
        raise
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def decode(ref: ArtifactRef, adapter: Adapter) -> Any:
    """Load the value behind `ref` with `adapter` after verifying the file.

    Raises `ValueError` when the reference key differs from the adapter key,
    or when the file's size or SHA-256 digest differs from what the
    reference records. The size is checked before the digest is computed.
    """

    if ref.key != adapter.key:
        raise ValueError('artifact ref key does not match adapter')

    location = Path(ref.uri)

    actual_size = location.stat().st_size
    if actual_size != ref.size:
        raise ValueError('artifact ref size does not match file')

    actual_digest = compute_sha256(location)
    if actual_digest != ref.sha256:
        raise ValueError('artifact ref sha256 does not match file')

    return adapter.load(str(location))
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
class Deployment:
    """Runs a pipeline by executing nodes locally or through a client.

    Accepts either `Deployment(pipeline)` or `Deployment(client, pipeline)`.
    """

    def __init__(self, client=None, pipeline=None):
        # A lone positional argument is the pipeline, not the client.
        if pipeline is None:
            client, pipeline = None, client
        self._client = client
        self._pipeline = pipeline

    def setup(self):
        """No-op hook."""

    def teardown(self):
        """No-op hook."""

    def run(self, **kwargs):
        """Create a `Run` over the pipeline that executes nodes via `_callback`."""

        return Run(self._callback, self._pipeline, **kwargs)

    def _callback(self, node, kwargs):
        """Execute one node and return `{output port name: result}`.

        `kwargs` maps input ports to values; it is re-keyed by port name
        before the call. `NodeFunction` nodes call `node.func` directly,
        `NodeRemote` nodes go through the client's `run_node` (a missing
        client raises `RuntimeError`), and any other node raises `ValueError`.
        """

        final_kwargs = {port.name: v for port, v in kwargs.items()}
        output_port = self._single_output_port(node)

        if isinstance(node, NodeFunction):
            return {output_port.name: node.func(**final_kwargs)}

        if isinstance(node, NodeRemote):
            if self._client is None:
                raise RuntimeError('remote node execution requires a client')
            return {output_port.name: self._client.run_node(node, final_kwargs)}

        raise ValueError(node)

    @staticmethod
    def _single_output_port(node):
        """Return the node's only output port, or raise `RuntimeError`.

        A falsy `node.outputs` is treated as zero outputs.
        """

        outputs = node.outputs or []
        if len(outputs) == 1:
            return outputs[0]

        raise RuntimeError(
            'node {} has {} outputs; local Deployment currently supports '
            'exactly one output and requires an explicit daemon/server '
            'output selector for multi-output pipelines'.format(
                node,
                len(outputs)))
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
class Run:
|
|
230
|
+
def __init__(
|
|
231
|
+
self,
|
|
232
|
+
callback: Callable[..., dict[str, Any]],
|
|
233
|
+
pipeline: Pipeline,
|
|
234
|
+
**kwargs: Any) -> None:
|
|
235
|
+
self._callback = callback
|
|
236
|
+
self._pipeline = pipeline
|
|
237
|
+
self._kwargs = kwargs
|
|
238
|
+
self._deps: dict[Node, dict[Any, Any]] = {
|
|
239
|
+
k: dict(map(itemgetter(slice(1, None)), vs))
|
|
240
|
+
for k, vs in groupby(
|
|
241
|
+
sorted(
|
|
242
|
+
[(x.node, x.port, y) for (x, y) in pipeline.links],
|
|
243
|
+
key = lambda x: hash(x[0])),
|
|
244
|
+
itemgetter(0))}
|
|
245
|
+
self._results: dict[Node, dict[str, Any]] = dict()
|
|
246
|
+
self._artifact_refs: dict[tuple[Node, str, str], ArtifactRef] = dict()
|
|
247
|
+
self._artifacts_dir: Path | None = None
|
|
248
|
+
self._artifacts_finalizer: Any = None
|
|
249
|
+
self._closed = False
|
|
250
|
+
|
|
251
|
+
def __enter__(self) -> 'Run':
|
|
252
|
+
self._ensure_open()
|
|
253
|
+
return self
|
|
254
|
+
|
|
255
|
+
def __exit__(self, exc_type: Any, exc: Any, traceback: Any) -> None:
|
|
256
|
+
self.close()
|
|
257
|
+
|
|
258
|
+
def close(self) -> None:
|
|
259
|
+
if not self._closed:
|
|
260
|
+
if self._artifacts_finalizer is not None:
|
|
261
|
+
self._artifacts_finalizer()
|
|
262
|
+
self._closed = True
|
|
263
|
+
|
|
264
|
+
def _ensure_open(self) -> None:
|
|
265
|
+
if self._closed:
|
|
266
|
+
raise RuntimeError('run is closed')
|
|
267
|
+
|
|
268
|
+
def _get_artifacts_dir(self) -> Path:
|
|
269
|
+
self._ensure_open()
|
|
270
|
+
if self._artifacts_dir is None:
|
|
271
|
+
artifacts_dir = Path(tempfile.mkdtemp(prefix = 'spl-run-'))
|
|
272
|
+
self._artifacts_dir = artifacts_dir
|
|
273
|
+
self._artifacts_finalizer = weakref.finalize(
|
|
274
|
+
self,
|
|
275
|
+
shutil.rmtree,
|
|
276
|
+
artifacts_dir,
|
|
277
|
+
ignore_errors = True)
|
|
278
|
+
return self._artifacts_dir
|
|
279
|
+
|
|
280
|
+
def _round_trip_artifact(
|
|
281
|
+
self,
|
|
282
|
+
value: Any,
|
|
283
|
+
source_ref: NodeOutputRef | None = None,
|
|
284
|
+
adapter_format: str | None = None) -> Any:
|
|
285
|
+
if adapter_format is None and type(value) in _JSON_NATIVE_TYPES:
|
|
286
|
+
return value
|
|
287
|
+
|
|
288
|
+
adapter = self._pipeline.resolve_adapter(
|
|
289
|
+
py_type = type(value),
|
|
290
|
+
format = adapter_format)
|
|
291
|
+
if adapter is None:
|
|
292
|
+
if adapter_format is not None:
|
|
293
|
+
raise ValueError(
|
|
294
|
+
'pipeline adapter is not found for python type ({}) '
|
|
295
|
+
'and format `{}`'.format(type(value), adapter_format))
|
|
296
|
+
return value
|
|
297
|
+
|
|
298
|
+
self._ensure_open()
|
|
299
|
+
if source_ref is None:
|
|
300
|
+
ref = encode(value, adapter, self._get_artifacts_dir())
|
|
301
|
+
else:
|
|
302
|
+
cache_key = (source_ref.node, source_ref.port.name, adapter.key)
|
|
303
|
+
if cache_key not in self._artifact_refs:
|
|
304
|
+
self._artifact_refs[cache_key] = encode(
|
|
305
|
+
value,
|
|
306
|
+
adapter,
|
|
307
|
+
self._get_artifacts_dir())
|
|
308
|
+
ref = self._artifact_refs[cache_key]
|
|
309
|
+
return decode(ref, adapter)
|
|
310
|
+
|
|
311
|
+
def _get_input(self, x: Any) -> Any:
|
|
312
|
+
match x:
|
|
313
|
+
case Scalar():
|
|
314
|
+
return self._round_trip_artifact(x.value)
|
|
315
|
+
|
|
316
|
+
case NodeOutputRef():
|
|
317
|
+
return self._round_trip_artifact(
|
|
318
|
+
(self._get_result(x.node))[x.port.name],
|
|
319
|
+
x)
|
|
320
|
+
|
|
321
|
+
case FormattedOutputRef():
|
|
322
|
+
return self._round_trip_artifact(
|
|
323
|
+
(self._get_result(x.out_ref.node))[x.out_ref.port.name],
|
|
324
|
+
source_ref = x.out_ref,
|
|
325
|
+
adapter_format = x.format)
|
|
326
|
+
|
|
327
|
+
case _: raise ValueError(x)
|
|
328
|
+
|
|
329
|
+
def _get_result(self, node: Node) -> dict[str, Any]:
|
|
330
|
+
if node not in self._results:
|
|
331
|
+
self._ensure_open()
|
|
332
|
+
kwargs = {
|
|
333
|
+
port: self._round_trip_artifact(self._kwargs[port.name])
|
|
334
|
+
for port in node.inputs
|
|
335
|
+
if port.name in self._kwargs}
|
|
336
|
+
|
|
337
|
+
if node in self._deps:
|
|
338
|
+
kwargs = kwargs | {
|
|
339
|
+
port: self._get_input(v)
|
|
340
|
+
for port, v in self._deps[node].items()}
|
|
341
|
+
|
|
342
|
+
self._results[node] = self._callback(node, kwargs)
|
|
343
|
+
return self._results[node]
|
|
344
|
+
|
|
345
|
+
def __getitem__(self, node: Node) -> dict[str, Any]:
|
|
346
|
+
try:
|
|
347
|
+
return self._get_result(node)
|
|
348
|
+
except BaseException:
|
|
349
|
+
self.close()
|
|
350
|
+
raise
|
|
spl/core/entities/__init__.py
File without changes
spl/core/entities/adapter.py
ADDED
|
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
from collections.abc import Callable
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from types import FunctionType
|
|
6
|
+
from typing import Any, Generator, cast
|
|
7
|
+
|
|
8
|
+
import yaml
|
|
9
|
+
|
|
10
|
+
from spl.core.entities.distribution import DDistribution
|
|
11
|
+
from spl.core.entities.function import get_function_metadata
|
|
12
|
+
from spl.core.ir.common import DBase
|
|
13
|
+
from spl.core.ir.parse import _branch, ir_parse
|
|
14
|
+
from spl.core.ir.unparse import ir_unparse
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _validate_non_empty_string(name: str, value: str) -> None:
|
|
18
|
+
if not isinstance(value, str):
|
|
19
|
+
raise TypeError('adapter {} must be a string'.format(name))
|
|
20
|
+
if not value:
|
|
21
|
+
raise ValueError('adapter {} must be a non-empty string'.format(name))
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _format_from_key(key: str) -> str:
    """Return the format part of a `<python_type>@<format>` adapter key.

    The key is split at its last `@`; a missing separator or an empty part
    on either side raises `ValueError`.
    """

    _validate_non_empty_string('key', key)
    parts = key.rpartition('@')
    if not all(parts):
        raise ValueError('adapter key must be `<python_type>@<format>`')
    return parts[2]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def make_key(py_type: type[Any], format: str) -> str:
    """Build the adapter key `<module>.<qualname>@<format>` for `py_type`.

    Raises `TypeError` when `py_type` is not a type; `format` must be a
    non-empty string.
    """

    if isinstance(py_type, type):
        _validate_non_empty_string('format', format)
        module_name = py_type.__module__
        qualified_name = py_type.__qualname__
        return '{}.{}@{}'.format(module_name, qualified_name, format)
    raise TypeError('adapter python type must be a type')
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _validate_function(name: str, value: Callable[..., Any]) -> None:
|
|
42
|
+
if not isinstance(value, FunctionType):
|
|
43
|
+
raise TypeError('adapter {} must be a function'.format(name))
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _normalize_distributions(value: Any) -> tuple[DDistribution, ...]:
    """Return `value` as a sorted tuple of `DDistribution` items.

    Accepts a tuple or a list. Raises `TypeError` for any other container or
    when an element is not a `DDistribution`.
    """

    if not isinstance(value, (tuple, list)):
        raise TypeError('adapter distributions must be a tuple')

    items = tuple(value)
    if not all(isinstance(item, DDistribution) for item in items):
        raise TypeError('adapter distributions must contain DDistribution values')

    ordered = sorted(items)
    return tuple(ordered)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _validate_py_type(value: type[Any] | None) -> None:
|
|
57
|
+
if value is not None and not isinstance(value, type):
|
|
58
|
+
raise TypeError('adapter py_type must be a type or None')
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _validate_key_format(key: str, py_type: type[Any] | None, format: str) -> None:
    """Check that `key` agrees with `format` and, when given, with `py_type`.

    Raises `ValueError` when the format embedded in `key` differs from
    `format`, or when `py_type` is set and `key` is not exactly
    `make_key(py_type, format)`.
    """

    _validate_non_empty_string('format', format)

    embedded_format = _format_from_key(key)
    if embedded_format != format:
        raise ValueError('adapter key format does not match format')

    if py_type is None:
        return

    expected_key = make_key(py_type, format)
    if key != expected_key:
        raise ValueError('adapter key does not match python type and format')
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _function_name(func: Callable[..., Any]) -> str:
    """Return the `name` recorded in the function's metadata."""

    return cast(str, get_function_metadata(cast(FunctionType, func)).name)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@dataclass(frozen = True)
class Adapter:
    """Save/load function pair that materializes values as artifacts.

    Fields are validated on construction and `distributions` is normalized
    to a sorted tuple.
    """

    # Adapter key of the form `<python_type>@<format>`.
    key: str
    # Called as `save(path, value)` by `encode`.
    save: Callable[..., Any]
    # Called as `load(path)` by `decode`.
    load: Callable[..., Any]
    # When not None, `key` must equal `make_key(py_type, format)`.
    py_type: type[Any] | None
    # Must equal the format part of `key`.
    format: str
    distributions: tuple[DDistribution, ...] = ()

    def __post_init__(self) -> None:
        for role in ('save', 'load'):
            _validate_function(role, getattr(self, role))
        _validate_py_type(self.py_type)
        _validate_key_format(
            key = self.key,
            py_type = self.py_type,
            format = self.format)

        # The dataclass is frozen, so the normalized tuple is written through
        # object.__setattr__.
        normalized = _normalize_distributions(self.distributions)
        object.__setattr__(self, 'distributions', normalized)

    def __hash__(self) -> int:
        # Functions contribute through their metadata names, not their identity.
        save_name = _function_name(self.save)
        load_name = _function_name(self.load)
        return hash((self.key, save_name, load_name, self.distributions))
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@dataclass(frozen = True)
class DAdapter(DBase):
    """Serialized form of an `Adapter` for pipeline YAML.

    `save` and `load` hold function names instead of function objects.
    """

    key: str
    save: str
    load: str
    distributions: tuple[DDistribution, ...] = ()

    def __post_init__(self) -> None:
        # Called only for validation; raises when the key has no `@<format>` part.
        _format_from_key(self.key)
        for field_name in ('save', 'load'):
            _validate_non_empty_string(field_name, getattr(self, field_name))

        # Frozen dataclass: store the normalized tuple via object.__setattr__.
        normalized = _normalize_distributions(self.distributions)
        object.__setattr__(self, 'distributions', normalized)

    def __hash__(self) -> int:
        identity = (self.key, self.save, self.load, self.distributions)
        return hash(identity)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _represent_dadapter(dumper: Any, data: DAdapter) -> Any:
    """Represent a `DAdapter` as a YAML mapping tagged `!DAdapter`."""

    payload = {
        'key': data.key,
        'save': data.save,
        'load': data.load,
        'distributions': list(data.distributions)}
    return dumper.represent_mapping('!DAdapter', payload)


yaml.add_representer(DAdapter, _represent_dadapter)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def _construct_dadapter(loader: Any, node: Any) -> DAdapter:
    """Build a `DAdapter` from a `!DAdapter` YAML mapping node."""

    fields = loader.construct_mapping(node, deep = True)
    return DAdapter(**fields)


yaml.add_constructor('!DAdapter', _construct_dadapter)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
@ir_parse.register(lambda x: isinstance(x, Adapter))
def _ir_parse__adapter(
        x: Adapter,
        name: str | None = None) -> _branch:
    """Parse an `Adapter` into an IR branch.

    The branch carries a factory that produces the matching `DAdapter`
    (functions replaced by their names) and a generator that parses the
    save and load functions as dependencies.
    """

    def mk_dependencies(frame_offset: int) -> Generator[Any]:
        # `frame_offset` is accepted but not used here.
        for func in (x.save, x.load):
            yield ir_parse(func)

    def build() -> DAdapter:
        return DAdapter(
            key = x.key,
            save = _function_name(x.save),
            load = _function_name(x.load),
            distributions = x.distributions)

    return _branch(x, build, mk_dependencies)
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def _unparse_distributions(distributions: tuple[DDistribution, ...]) -> ast.Tuple:
    """Build the AST of a tuple of `DDistribution(package=..., version=...)` calls."""

    calls = []
    for item in distributions:
        package_kw = ast.keyword(
            arg = 'package',
            value = ast.Constant(value = item.package))
        version_kw = ast.keyword(
            arg = 'version',
            value = ast.Constant(value = item.version))
        calls.append(
            ast.Call(
                func = ast.Name(id = 'DDistribution', ctx = ast.Load()),
                keywords = [package_kw, version_kw]))

    return ast.Tuple(elts = calls, ctx = ast.Load())
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
@ir_unparse.register(lambda x: isinstance(x, DAdapter))
def _ir_unparse__adapter(x: DAdapter, source: Path) -> Generator[ast.stmt]:
    """Yield the statement `_adapter = Adapter(...)` rebuilding `x`.

    `save` and `load` are emitted as name references, `py_type` is always
    `None`, and `format` is taken from the key. `source` is not used here.
    """

    arguments = (
        ('key', ast.Constant(value = x.key)),
        ('save', ast.Name(id = x.save, ctx = ast.Load())),
        ('load', ast.Name(id = x.load, ctx = ast.Load())),
        ('py_type', ast.Constant(value = None)),
        ('format', ast.Constant(value = _format_from_key(x.key))),
        ('distributions', _unparse_distributions(x.distributions)))

    constructor_call = ast.Call(
        func = ast.Name(id = 'Adapter', ctx = ast.Load()),
        keywords = [
            ast.keyword(arg = arg_name, value = arg_value)
            for arg_name, arg_value in arguments])

    yield ast.Assign(
        targets = [ast.Name(id = '_adapter', ctx = ast.Store())],
        value = constructor_call)
|