splime 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. spl/__init__.py +14 -0
  2. spl/client.py +1364 -0
  3. spl/core/__init__.py +23 -0
  4. spl/core/common.py +350 -0
  5. spl/core/entities/__init__.py +0 -0
  6. spl/core/entities/adapter.py +210 -0
  7. spl/core/entities/artifact.py +141 -0
  8. spl/core/entities/control.py +45 -0
  9. spl/core/entities/distribution.py +65 -0
  10. spl/core/entities/function.py +254 -0
  11. spl/core/entities/local_function.py +286 -0
  12. spl/core/entities/misc.py +14 -0
  13. spl/core/entities/module.py +88 -0
  14. spl/core/entities/node.py +286 -0
  15. spl/core/entities/node_function.py +79 -0
  16. spl/core/entities/node_remote.py +295 -0
  17. spl/core/entities/pipeline.py +436 -0
  18. spl/core/entities/scalar.py +55 -0
  19. spl/core/ir/__init__.py +0 -0
  20. spl/core/ir/common.py +34 -0
  21. spl/core/ir/parse.py +79 -0
  22. spl/core/ir/unparse.py +29 -0
  23. spl/core/ir/utils.py +163 -0
  24. spl/daemon/__init__.py +23 -0
  25. spl/daemon/__main__.py +11 -0
  26. spl/daemon/cli.py +582 -0
  27. spl/daemon/client.py +43 -0
  28. spl/daemon/docker_environment.py +329 -0
  29. spl/daemon/docker_pool.py +516 -0
  30. spl/daemon/environment.py +228 -0
  31. spl/daemon/environment_base.py +479 -0
  32. spl/daemon/heartbeat_service.py +119 -0
  33. spl/daemon/metadata.py +427 -0
  34. spl/daemon/remote_client.py +457 -0
  35. spl/daemon/repositories/__init__.py +17 -0
  36. spl/daemon/repositories/env.py +323 -0
  37. spl/daemon/repositories/library.py +181 -0
  38. spl/daemon/repositories/object.py +997 -0
  39. spl/daemon/repositories/run.py +279 -0
  40. spl/daemon/repositories/server_connection.py +657 -0
  41. spl/daemon/repositories/sync_event.py +129 -0
  42. spl/daemon/routes/__init__.py +1 -0
  43. spl/daemon/routes/_helpers.py +147 -0
  44. spl/daemon/routes/artifacts.py +77 -0
  45. spl/daemon/routes/diagnostics.py +114 -0
  46. spl/daemon/routes/envs.py +82 -0
  47. spl/daemon/routes/libraries.py +129 -0
  48. spl/daemon/routes/objects.py +174 -0
  49. spl/daemon/routes/remote.py +56 -0
  50. spl/daemon/routes/runs.py +96 -0
  51. spl/daemon/routes/server_connections.py +86 -0
  52. spl/daemon/runtime_backend.py +368 -0
  53. spl/daemon/runtime_config.py +133 -0
  54. spl/daemon/runtime_dependencies.py +459 -0
  55. spl/daemon/secret_store.py +187 -0
  56. spl/daemon/server.py +2224 -0
  57. spl/daemon/server_connection.py +267 -0
  58. spl/daemon/services/__init__.py +1 -0
  59. spl/daemon/services/sync.py +76 -0
  60. spl/daemon/signature.py +376 -0
  61. spl/daemon/storage_base.py +542 -0
  62. spl/daemon/store.py +436 -0
  63. spl/daemon/worker.py +526 -0
  64. spl/daemon_client.py +945 -0
  65. spl/pipeline_widget.py +1452 -0
  66. spl/py.typed +0 -0
  67. spl/server_client.py +787 -0
  68. splime-0.1.2.dist-info/METADATA +189 -0
  69. splime-0.1.2.dist-info/RECORD +74 -0
  70. splime-0.1.2.dist-info/WHEEL +5 -0
  71. splime-0.1.2.dist-info/entry_points.txt +2 -0
  72. splime-0.1.2.dist-info/licenses/LICENSE +201 -0
  73. splime-0.1.2.dist-info/licenses/NOTICE +8 -0
  74. splime-0.1.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,436 @@
1
+ import ast
2
+ from collections.abc import Callable
3
+ from dataclasses import dataclass, field, replace
4
+ from itertools import chain
5
+ from operator import itemgetter
6
+ from pathlib import Path
7
+ from typing import Any, Generator
8
+
9
+ import yaml
10
+
11
+ import spl.core.entities.adapter as m_adapter
12
+ import spl.core.entities.artifact as m_artifact
13
+ import spl.core.entities.distribution as m_distribution
14
+ import spl.core.entities.node as m_node
15
+ import spl.core.entities.node_function as m_node_function
16
+ import spl.core.entities.scalar as m_scalar
17
+ from spl.core.entities.adapter import Adapter, make_key
18
+ from spl.core.entities.node import (
19
+ FormattedOutputRef,
20
+ Node,
21
+ NodeInputRef,
22
+ NodeOutputRef,
23
+ )
24
+ from spl.core.ir.common import DBase
25
+ from spl.core.ir.parse import _branch, ir_parse
26
+ from spl.core.ir.unparse import ir_unparse
27
+
28
+
29
+ def _as_node_output_ref(value: Any) -> NodeOutputRef | None:
30
+ if isinstance(value, FormattedOutputRef):
31
+ return value.out_ref
32
+ if isinstance(value, NodeOutputRef):
33
+ return value
34
+ return None
35
+
36
+
37
@dataclass(frozen = True)
class Pipeline:
    """Immutable graph of nodes connected by input links.

    Every mutator (``|``, ``add_link``, ``add_alias``, ``add_adapter``)
    returns a new ``Pipeline`` that has been re-checked by
    ``_validate_consistency``; the receiver itself is never modified.
    """

    name: str | None = None
    nodes: set[Node] = field(default_factory = set)
    # (input ref, value) pairs. `value` is either an output ref of a node in
    # this pipeline (optionally wrapped in FormattedOutputRef) or any other
    # bound value, e.g. a Scalar.
    links: set[tuple[NodeInputRef, Any]] = field(default_factory = set)
    aliases: dict[str, Node] = field(default_factory = dict)
    # Keyed by `Adapter.key`; `_validate_consistency` enforces the match.
    adapters: dict[str, Adapter] = field(default_factory = dict)

    def __hash__(self):
        # Order-independent hash over nodes, links and adapters. `name` and
        # `aliases` are not hashed; this stays consistent with the generated
        # `__eq__` (equal pipelines still hash equal).
        return hash((
            tuple(sorted(map(hash, self.nodes))),
            tuple(sorted(map(hash, self.links))),
            tuple(sorted([
                (key, hash(adapter))
                for key, adapter in self.adapters.items()]))))

    def __or__(self, other):
        """Merge two pipelines into a new, unnamed one.

        Nodes and links are unioned; aliases and adapters are merged and must
        not conflict. Non-Pipeline operands get ``NotImplemented`` so Python
        raises its usual ``TypeError`` for unsupported ``|`` operands.

        Raises:
            ValueError: on an alias/adapter conflict, or when the merged links
                bind the same input more than once.
        """
        if not isinstance(other, Pipeline):
            return NotImplemented
        return Pipeline(
            nodes = set.union(self.nodes, other.nodes),
            links = set.union(self.links, other.links),
            aliases = self._merge_aliases(other),
            adapters = self._merge_adapters(other))._validate_consistency()

    def add_link(self, node_input_ref, value):
        """Return a copy of this pipeline with ``node_input_ref`` bound to ``value``.

        ``value`` may be an output ref (plain or FormattedOutputRef) of a node
        in this pipeline, or any other value. Linking an input again to an
        equal value is not treated as a conflict. All other fields, including
        ``name``, are carried over unchanged.

        Raises:
            ValueError: if the input (or the referenced output) does not belong
                to a node of this pipeline, or the input is already linked to a
                different value.
        """
        target = node_input_ref.node
        if target not in self.nodes:
            raise ValueError('pipeline does not contain input node ({})'.format(target))
        if node_input_ref.port not in target.inputs:
            raise ValueError('pipeline input ref does not belong to node ({})'.format(node_input_ref))

        if (output_ref := _as_node_output_ref(value)) is not None:
            source = output_ref.node
            if source not in self.nodes:
                raise ValueError('pipeline does not contain output node ({})'.format(source))
            if output_ref.port not in source.outputs:
                raise ValueError('pipeline output ref does not belong to node ({})'.format(output_ref))

        for existing_ref, existing_value in self.links:
            if existing_ref == node_input_ref and existing_value != value:
                raise ValueError(
                    'pipeline input `{}` is already linked'.format(node_input_ref))

        # `replace` keeps `name` (like add_alias/add_adapter do); building a
        # fresh Pipeline(...) here would reset the name to None.
        return replace(
            self, links = {*self.links, (node_input_ref, value)})._validate_consistency()

    def add_alias(self, node, name):
        """Return a copy of this pipeline where alias ``name`` refers to ``node``.

        Raises:
            ValueError: if ``name`` is not a non-empty string, ``node`` is not
                in the pipeline, or ``name`` already refers to another node.
        """
        if not isinstance(name, str) or not name:
            raise ValueError('pipeline alias name must be a non-empty string')
        if node not in self.nodes:
            raise ValueError('pipeline alias points to unknown node ({})'.format(node))
        if name in self.aliases and self.aliases[name] != node:
            raise ValueError('pipeline alias `{}` already points to another node'.format(name))
        return replace(self, aliases = {**self.aliases, name: node})._validate_consistency()

    def add_adapter(
            self,
            py_type: type[Any],
            format: str,
            *,
            save: Callable[..., Any],
            load: Callable[..., Any],
            distributions: tuple[Any, ...] = ()) -> 'Pipeline':
        """Return a copy of this pipeline with an adapter for ``(py_type, format)``.

        The adapter is stored under ``make_key(py_type, format)``. Registering
        an equal adapter under the same key again is accepted.

        Raises:
            ValueError: if a different adapter already uses that key.
        """
        key = make_key(py_type, format)
        adapter = Adapter(
            key = key,
            save = save,
            load = load,
            py_type = py_type,
            format = format,
            distributions = distributions)
        if key in self.adapters and self.adapters[key] != adapter:
            raise ValueError('pipeline adapter conflict: `{}`'.format(key))
        return replace(self, adapters = {**self.adapters, key: adapter})._validate_consistency()

    def resolve_adapter(
            self,
            *,
            py_type: type[Any] | None = None,
            format: str | None = None,
            key: str | None = None) -> Adapter | None:
        """Look up an adapter by ``key``, or by ``py_type`` (and optional ``format``).

        With ``py_type`` alone, every adapter whose key starts with
        ``'<module>.<qualname>@'`` is a candidate and at most one may match.
        NOTE(review): this assumes ``make_key`` builds keys of that shape —
        confirm in adapter.py.

        Returns:
            The matching Adapter, or None when nothing matches.

        Raises:
            ValueError: if ``key`` is combined with ``py_type``/``format``, if
                neither ``key`` nor ``py_type`` is given, if a type-only lookup
                is ambiguous, or if the key is empty.
            TypeError: if the key is not a string.
        """
        if key is not None and (py_type is not None or format is not None):
            raise ValueError('pipeline adapter lookup accepts key or python type and format')
        if key is None:
            if py_type is None:
                raise ValueError('pipeline adapter lookup requires key or python type')
            if format is not None:
                key = make_key(py_type, format)
            else:
                prefix = '{}.{}@'.format(py_type.__module__, py_type.__qualname__)
                candidates = [
                    adapter
                    for adapter_key, adapter in sorted(self.adapters.items())
                    if adapter_key.startswith(prefix)]
                if len(candidates) > 1:
                    raise ValueError(
                        'pipeline adapter lookup is ambiguous for python type ({})'.format(
                            py_type))
                return candidates[0] if candidates else None
        if not isinstance(key, str):
            raise TypeError('pipeline adapter key must be a string')
        if not key:
            raise ValueError('pipeline adapter key must be a non-empty string')
        return self.adapters.get(key)

    def get_free_inputs(self) -> set[NodeInputRef]:
        """Return every node input that is not the target of any link."""
        return ({
            NodeInputRef(node, port)
            for node in self.nodes
            for port in node.inputs} - set(map(itemgetter(0), self.links)))

    def get_unbound_inputs(self) -> set[NodeInputRef]:
        """Return unlinked inputs whose port has ``default is None``."""
        return ({
            NodeInputRef(node, port)
            for node in self.nodes
            for port in node.inputs
            if port.default is None} - set(map(itemgetter(0), self.links)))

    def get_outputs(self) -> set[NodeOutputRef]:
        """Return a ref for every output port of every node."""
        return {
            NodeOutputRef(node, port)
            for node in self.nodes
            for port in node.outputs}

    def get_node_by_alias(self, name):
        """Return the node registered under alias ``name`` (KeyError if absent)."""
        return self.aliases[name]

    def _merge_aliases(self, other):
        """Merge ``other.aliases`` into a copy of ours; ValueError on conflicts."""
        aliases = dict(self.aliases)
        for name, node in other.aliases.items():
            if name in aliases and aliases[name] != node:
                raise ValueError('pipeline alias conflict: `{}`'.format(name))
            aliases[name] = node
        return aliases

    def _merge_adapters(self, other):
        """Merge ``other.adapters`` into a copy of ours; ValueError on conflicts."""
        adapters = dict(self.adapters)
        for key, adapter in other.adapters.items():
            if key in adapters and adapters[key] != adapter:
                raise ValueError('pipeline adapter conflict: `{}`'.format(key))
            adapters[key] = adapter
        return adapters

    def _validate_consistency(self):
        """Check the pipeline's invariants and return ``self``.

        Every link must target an input port of a node in the pipeline, no
        input may be linked twice, and output-ref link values must come from
        an output port of a node in the pipeline. Aliases must point at
        pipeline nodes, and adapters must be Adapter instances stored under
        their own non-empty string ``key``.

        Raises:
            ValueError: for any structural violation.
            TypeError: if an adapter value is not an Adapter.
        """
        linked_inputs = set()
        for node_input_ref, value in self.links:
            if node_input_ref.node not in self.nodes:
                raise ValueError(
                    'pipeline link target node is not in pipeline ({})'.format(
                        node_input_ref.node))
            if node_input_ref.port not in node_input_ref.node.inputs:
                raise ValueError(
                    'pipeline link target port is not on node ({})'.format(
                        node_input_ref))
            if node_input_ref in linked_inputs:
                raise ValueError('pipeline input `{}` is linked more than once'.format(node_input_ref))
            linked_inputs.add(node_input_ref)

            if (output_ref := _as_node_output_ref(value)) is not None:
                if output_ref.node not in self.nodes:
                    raise ValueError(
                        'pipeline link source node is not in pipeline ({})'.format(
                            output_ref.node))
                if output_ref.port not in output_ref.node.outputs:
                    raise ValueError(
                        'pipeline link source port is not on node ({})'.format(
                            output_ref))

        for name, node in self.aliases.items():
            if node not in self.nodes:
                raise ValueError('pipeline alias `{}` points to unknown node'.format(name))
        for key, adapter in self.adapters.items():
            if not isinstance(key, str) or not key:
                raise ValueError('pipeline adapter key must be a non-empty string')
            if not isinstance(adapter, Adapter):
                raise TypeError('pipeline adapter `{}` must be Adapter'.format(key))
            if key != adapter.key:
                raise ValueError('pipeline adapter key mismatch: `{}`'.format(key))
        return self
222
+
223
+
224
@dataclass(frozen = True)
class DPipeline(DBase):
    """IR (serializable) form of a ``Pipeline``, built by its ``ir_parse`` handler."""

    name: str
    nodes: list[Any]
    # Each entry is a two-element list: [input-ref IR, value IR].
    links: list[Any]
    # [alias, str(node uuid)] pairs.
    aliases: list[list[str]]
    adapters: list[Any] = field(default_factory = list)

    def __hash__(self):
        # Order-independent: hashes are sorted before combining; `name` and
        # `aliases` do not participate.
        node_hashes = sorted(hash(node) for node in self.nodes)
        link_hashes = sorted(hash(part) for pair in self.links for part in pair)
        adapter_hashes = sorted(hash(adapter) for adapter in self.adapters)
        return hash((tuple(node_hashes), tuple(link_hashes), tuple(adapter_hashes)))
237
+
238
def _represent_dpipeline(dumper, data):
    # Dump the dataclass fields as a YAML mapping tagged `!DPipeline`.
    return dumper.represent_mapping('!DPipeline', data.__dict__)


def _construct_dpipeline(loader, node):
    # Rebuild a DPipeline from a `!DPipeline` mapping node.
    return DPipeline(**loader.construct_mapping(node))


# NOTE(review): with no Loader/Dumper argument PyYAML registers these on its
# default classes only (e.g. not SafeLoader) — confirm for the pinned version.
yaml.add_representer(DPipeline, _represent_dpipeline)
yaml.add_constructor('!DPipeline', _construct_dpipeline)
245
+
246
+
247
@ir_parse.register(
    lambda x: isinstance(x, Pipeline))
def _ir_parse__pipeline(
        x: Pipeline,
        name: str | None = None):
    """Lower a Pipeline into a ``_branch`` whose root is a ``DPipeline``.

    ``name`` is forwarded to the parse of every node; the DPipeline itself
    takes its name from ``x.name``. Adapters are visited in key order, nodes
    and links in the iteration order of their sets. The dependencies are
    those of the nodes followed by those of the adapters.
    """

    def adapters_by_key():
        return [adapter for _, adapter in sorted(x.adapters.items())]

    def mk_root():
        node_roots = [ir_parse(node, name = name).mk_root() for node in x.nodes]
        link_roots = [
            [ir_parse(input_ref).mk_root(), ir_parse(bound).mk_root()]
            for input_ref, bound in x.links]
        alias_pairs = [[alias, str(node.uuid)] for alias, node in x.aliases.items()]
        adapter_roots = [ir_parse(adapter).mk_root() for adapter in adapters_by_key()]
        return DPipeline(
            name = x.name,
            nodes = node_roots,
            links = link_roots,
            aliases = alias_pairs,
            adapters = adapter_roots)

    def mk_dependencies(frame_offset):
        per_item = [
            ir_parse(node, name = name).mk_dependencies(frame_offset)
            for node in x.nodes]
        per_item += [
            ir_parse(adapter).mk_dependencies(frame_offset)
            for adapter in adapters_by_key()]
        return chain.from_iterable(per_item)

    return _branch(x, mk_root, mk_dependencies)
278
+
279
+
280
@ir_unparse.register(
    lambda x: isinstance(x, DPipeline))
def _ir_unparse__pipeline(x: DPipeline, source: Path) -> Generator[ast.stmt]:
    """Yield statements that rebuild the pipeline described by ``x``.

    The generated code imports the runtime helpers, collects nodes into
    ``_nodes`` (keyed by ``uuid``), links into ``_links`` and adapters into
    ``_adapters``, then binds ``Pipeline(...)`` to the variable ``x.name``.
    It relies on the per-item unparsers binding ``_node``, ``_link_from``,
    ``_link_to`` and ``_adapter`` respectively.
    """

    def load(ident):
        return ast.Name(id = ident, ctx = ast.Load())

    def store(ident):
        return ast.Name(id = ident, ctx = ast.Store())

    def register_by_attr(registry, item, attr):
        # <registry>[<item>.<attr>] = <item>
        return ast.Assign(
            targets = [
                ast.Subscript(
                    value = load(registry),
                    slice = ast.Attribute(value = load(item), attr = attr, ctx = ast.Load()),
                    ctx = ast.Store())],
            value = load(item))

    # Importing helpers
    # TODO: move to corresponding modules
    runtime_imports = (
        ('uuid', ('UUID',)),
        (m_node.__name__, ('FormattedOutputRef', 'NodeInputRef', 'NodeOutputRef')),
        (m_scalar.__name__, ('Scalar',)),
        (m_artifact.__name__, ('ArtifactRef',)),
        (m_adapter.__name__, ('Adapter',)),
        (m_distribution.__name__, ('DDistribution',)),
        (m_node_function.__name__, ('NodeFunction',)),
        (__name__, ('Pipeline',)))
    for module, imported in runtime_imports:
        yield ast.ImportFrom(
            module = module,
            names = [ast.alias(name = ident) for ident in imported])

    # _nodes = {}, then per node: `_node = ...`; `_nodes[_node.uuid] = _node`
    yield ast.Assign(targets = [store('_nodes')], value = ast.Dict())
    for node in x.nodes:
        yield from ir_unparse(node, source = source)
        yield register_by_attr('_nodes', '_node', 'uuid')

    # _links = [], then per link: `_link_from = ...`; `_link_to = ...`;
    # `_links.append((_link_from, _link_to))`
    yield ast.Assign(targets = [store('_links')], value = ast.List())
    for link_from, link_to in x.links:
        yield from ir_unparse(link_from, source = source)
        yield from ir_unparse(link_to, source = source)
        yield ast.Expr(
            value = ast.Call(
                func = ast.Attribute(value = load('_links'), attr = 'append', ctx = ast.Load()),
                args = [
                    ast.Tuple(
                        elts = [load('_link_from'), load('_link_to')],
                        ctx = ast.Load())]))

    # _adapters = {}, then per adapter: `_adapter = ...`;
    # `_adapters[_adapter.key] = _adapter`
    yield ast.Assign(targets = [store('_adapters')], value = ast.Dict())
    for adapter in x.adapters:
        yield from ir_unparse(adapter, source = source)
        yield register_by_attr('_adapters', '_adapter', 'key')

    # <x.name> = Pipeline(name=..., nodes={*_nodes.values()}, links={*_links},
    #                     aliases={alias: _nodes[UUID(...)], ...}, adapters=_adapters)
    alias_keys = [ast.Constant(value = alias) for [alias, _] in x.aliases]
    alias_values = [
        ast.Subscript(
            value = load('_nodes'),
            slice = ast.Call(func = load('UUID'), args = [ast.Constant(value = node_uuid)]),
            ctx = ast.Load())
        for [_, node_uuid] in x.aliases]
    yield ast.Assign(
        targets = [store(x.name)],
        value = ast.Call(
            func = load('Pipeline'),
            keywords = [
                ast.keyword(arg = 'name', value = ast.Constant(value = x.name)),
                ast.keyword(
                    arg = 'nodes',
                    value = ast.Set(elts = [
                        ast.Starred(value = ast.Call(func = ast.Attribute(
                            value = load('_nodes'),
                            attr = 'values',
                            ctx = ast.Load())))])),
                ast.keyword(
                    arg = 'links',
                    value = ast.Set(elts = [ast.Starred(value = load('_links'))])),
                ast.keyword(
                    arg = 'aliases',
                    value = ast.Dict(keys = alias_keys, values = alias_values)),
                ast.keyword(arg = 'adapters', value = load('_adapters'))]))
@@ -0,0 +1,55 @@
1
+ import ast
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import Any, Generator
5
+
6
+ import yaml
7
+
8
+ from spl.core.ir.common import DBase
9
+ from spl.core.ir.parse import _branch, ir_parse
10
+ from spl.core.ir.unparse import ir_unparse
11
+
12
+
13
class Scalar:
    """Wrapper marking ``value`` as a plain (literal) value, e.g. for a link.

    Equality and hashing are the default identity-based ones.
    """

    value: Any

    def __init__(self, value):
        self.value = value
21
+
22
@dataclass(frozen = True)
class DScalar(DBase):
    """IR form of a ``Scalar``: the wrapped value, stored as-is."""

    value: Any
25
+
26
def _represent_dscalar(dumper, data):
    # Dump the dataclass fields as a YAML mapping tagged `!DScalar`.
    return dumper.represent_mapping('!DScalar', data.__dict__)


def _construct_dscalar(loader, node):
    # Rebuild a DScalar from a `!DScalar` mapping node.
    return DScalar(**loader.construct_mapping(node))


# NOTE(review): with no Loader/Dumper argument PyYAML registers these on its
# default classes only (e.g. not SafeLoader) — confirm for the pinned version.
yaml.add_representer(DScalar, _represent_dscalar)
yaml.add_constructor('!DScalar', _construct_dscalar)
33
+
34
@ir_parse.register(
    lambda x: isinstance(x, Scalar))
def _ir_parse__scalar(
        x: Scalar,
        name: str | None = None):
    """Lower a Scalar into a ``_branch`` with a DScalar root and no dependencies."""

    def mk_root():
        return DScalar(x.value)

    def mk_dependencies(frame_offset):
        return []

    return _branch(x, mk_root, mk_dependencies)
43
+
44
+
45
def _literal_ast(value):
    """Return an expression AST that evaluates to ``value``.

    Plain lists, tuples, sets and dicts are rebuilt element by element,
    because ``compile`` rejects an ``ast.Constant`` holding a list, set or dict
    (or a tuple that contains one). Every other value is wrapped in
    ``ast.Constant`` unchanged, as before.
    """
    # Exact type checks: subclasses (e.g. namedtuples) keep the old behaviour.
    if type(value) is list:
        return ast.List(elts = [_literal_ast(v) for v in value], ctx = ast.Load())
    if type(value) is tuple:
        return ast.Tuple(elts = [_literal_ast(v) for v in value], ctx = ast.Load())
    if type(value) is set:
        return ast.Set(elts = [_literal_ast(v) for v in value])
    if type(value) is dict:
        return ast.Dict(
            keys = [_literal_ast(k) for k in value],
            values = [_literal_ast(v) for v in value.values()])
    return ast.Constant(value = value)


@ir_unparse.register(
    lambda x: isinstance(x, DScalar))
def _ir_unparse__scalar(x: DScalar, source: Path) -> Generator[ast.stmt]:
    """Yield ``_link_to = Scalar(value=<literal>)`` for a DScalar.

    The target name is fixed to ``_link_to``, the variable the pipeline
    unparser (pipeline.py) reads after unparsing a link's value; this handler
    therefore assumes it is only used for link values.
    """
    yield ast.Assign(
        targets = [ast.Name(id = '_link_to', ctx = ast.Store())],
        value = ast.Call(
            func = ast.Name(id = 'Scalar', ctx = ast.Load()),
            args = [],
            keywords = [
                ast.keyword(
                    arg = 'value',
                    value = _literal_ast(x.value))]))
File without changes
spl/core/ir/common.py ADDED
@@ -0,0 +1,34 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, Callable, Protocol, TypeVar, cast
3
+
4
+
5
@dataclass(frozen = True)
class DBase:
    """Common base of the frozen IR dataclasses (``DPipeline``, ``DScalar``, ...)."""
7
+
8
+
9
+ _F = TypeVar('_F', bound = Callable[..., Any])
10
+
11
+
12
+ class Dispatcher(Protocol):
13
+ def register(self, p: Callable[[Any], bool]) -> Callable[[_F], _F]: ...
14
+
15
+ def __call__(self, x: Any, *args: Any, **kwargs: Any) -> Any: ...
16
+
17
+
18
def mk_dispatcher() -> Dispatcher:
    """Create a predicate-based dispatcher.

    Handlers are added with ``@dispatcher.register(predicate)`` and tried in
    registration order; the first whose predicate accepts ``x`` is called as
    ``handler(x, *args, **kwargs)`` and its result returned.

    Raises (when called):
        ValueError: with ``x`` as its argument if no predicate accepts ``x``.
    """
    registry: list[tuple[Callable[[Any], bool], Callable[..., Any]]] = []

    def register(p: Callable[[Any], bool]) -> Callable[[_F], _F]:
        def decorator(f: _F) -> _F:
            registry.append((p, f))
            return f
        return decorator

    def dispatch(x: Any, *args: Any, **kwargs: Any) -> Any:
        for accepts, handler in registry:
            if accepts(x):
                break
        else:
            raise ValueError(x)
        return handler(x, *args, **kwargs)

    # setattr keeps type checkers quiet about adding an attribute to a function.
    setattr(dispatch, 'register', register)
    return cast(Dispatcher, dispatch)
spl/core/ir/parse.py ADDED
@@ -0,0 +1,79 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, Callable, Generator
3
+
4
+ from spl.core.ir.common import DBase, mk_dispatcher
5
+
6
# ir_parse :: (x: Any, name: str | None = None) -> _branch
# Handlers are registered per value type with `@ir_parse.register(predicate)`;
# the handlers in this package return a `_branch` describing how to build the
# value's IR root and its dependencies.
ir_parse = mk_dispatcher()
8
+
9
+
10
+ @dataclass(frozen = True)
11
+ class _branch: # noqa: N801
12
+ x: Any
13
+ mk_root: Callable[[], Any]
14
+ mk_dependencies: Callable[[int], Generator[Any]] = lambda _: iter(int, int())
15
+
16
+
17
+ @dataclass(frozen = True)
18
+ class _attach: # noqa: N801
19
+ dependencies: Generator[Any]
20
+
21
+
22
+ @dataclass(frozen = True)
23
+ class _set_cursor: # noqa: N801
24
+ cursor: str
25
+
26
+
27
def stack_push(stack, *vs):
    """Return a new list holding the items of ``stack`` followed by ``vs``.

    ``stack`` itself is left untouched.
    """
    pushed = list(stack)
    pushed.extend(vs)
    return pushed
29
+
30
+
31
+ def stack_pop(stack):
32
+ match stack:
33
+ case []:
34
+ raise StopIteration()
35
+
36
+ case [*new_stack, v]:
37
+ return (new_stack, v)
38
+
39
+
40
+ def get_top_level_deps(frame_offset: int, xs: list[Any]) -> list[tuple[DBase, list[DBase]]]:
41
+ refs_root = {}
42
+ refs_dependencies = {}
43
+
44
+ cursor = None
45
+
46
+ stack = []
47
+ for x in xs:
48
+ stack = stack_push(stack, ir_parse(x))
49
+
50
+ while len(stack):
51
+ (stack, x) = stack_pop(stack)
52
+ match x:
53
+ case _set_cursor(new_cursor):
54
+ cursor = new_cursor
55
+
56
+ case _branch(x, mk_root, mk_dependencies):
57
+ if (x in refs_root) and (cursor is not None):
58
+ refs_dependencies[cursor] = [*refs_dependencies[cursor], refs_root[x]]
59
+ else:
60
+ root = mk_root()
61
+ dependencies = mk_dependencies(frame_offset)
62
+
63
+ refs_root[x] = root
64
+ refs_dependencies[x] = []
65
+ stack = stack_push(stack, root, _set_cursor(cursor), _attach(dependencies))
66
+ cursor = x
67
+
68
+ case _attach(dependencies):
69
+ stack = stack_push(stack, *dependencies)
70
+
71
+ case _:
72
+ if cursor is not None:
73
+ refs_dependencies[cursor] = [*refs_dependencies[cursor], x]
74
+
75
+ # python 3.7+ maintains order of insertions, we rely on it
76
+ return list(zip(
77
+ refs_root.values(),
78
+ refs_dependencies.values(),
79
+ strict = True))
spl/core/ir/unparse.py ADDED
@@ -0,0 +1,29 @@
1
+ import ast
2
+ from functools import partial
3
+ from itertools import chain
4
+ from pathlib import Path
5
+
6
+ from spl.core.ir.common import DBase, mk_dispatcher
7
+
8
# Name of the zero-argument function that `mk_top_level_ast` wraps around each
# top-level definition and then calls straight away.
IIFE_NAME = '_'
9
+
10
# ir_unparse :: (x: Any, source: Path) -> Generator[ast.stmt]
# Handlers are registered per IR type with `@ir_unparse.register(predicate)`
# and yield the statements that rebuild that value.
ir_unparse = mk_dispatcher()
12
+
13
+
14
def mk_top_level_ast(d: tuple[DBase, list[DBase]], source: Path):
    """Build a module AST that evaluates one top-level IR entry.

    ``d`` is a ``(root, dependencies)`` pair. The module defines a
    zero-argument function named ``IIFE_NAME`` whose body is the unparsed
    dependencies, then the unparsed root, then ``return <root.name>``; the
    module then assigns the result of calling that function to ``root.name``.
    List fields of the AST nodes built here are spelled out explicitly rather
    than relying on constructor defaults.

    Raises:
        ValueError: if ``root.name`` is not a string (it becomes a variable
            name in the generated code).
    """
    (root, dependencies) = d
    name = root.name
    if not isinstance(name, str):
        raise ValueError(
            'top-level IR root must have a string name, got {!r}'.format(name))

    body = [
        *chain.from_iterable(map(partial(ir_unparse, source = source), dependencies)),
        *ir_unparse(root, source = source),
        ast.Return(value = ast.Name(id = name, ctx = ast.Load()))]

    return ast.fix_missing_locations(ast.Module(
        body = [
            ast.FunctionDef(
                name = IIFE_NAME,
                args = ast.arguments(
                    posonlyargs = [],
                    args = [],
                    kwonlyargs = [],
                    kw_defaults = [],
                    defaults = []),
                body = body,
                decorator_list = []),

            ast.Assign(
                targets = [ast.Name(id = name, ctx = ast.Store())],
                value = ast.Call(
                    func = ast.Name(id = IIFE_NAME, ctx = ast.Load()),
                    args = [],
                    keywords = []))],
        type_ignores = []))