sier2 0.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sier2 might be problematic. Click here for more details.

sier2/_dag.py ADDED
@@ -0,0 +1,558 @@
1
+ from ._block import Block, BlockError, BlockState
2
+ from dataclasses import dataclass, field #, KW_ONLY, field
3
+ from collections import defaultdict, deque
4
+ import holoviews as hv
5
+ import threading
6
+ from typing import Any
7
+
8
# Module-wide switch checked by Dag.connect(): while it is True (the default),
# a connection that would close a loop in the dag is rejected.
#
_DISALLOW_CYCLES = True
11
+
12
@dataclass
class Connection:
    """Pair one source block output param with one destination input param.

    Output param names must begin with ``out_`` and input param names with
    ``in_``; a BlockError is raised on construction otherwise (the source
    name is checked first).
    """

    src_param_name: str
    dst_param_name: str

    def __post_init__(self):
        # Validate in a fixed order so the source name is always reported first.
        rules = (
            (self.src_param_name, 'out_', 'Output params must start with "out_"'),
            (self.dst_param_name, 'in_', 'Input params must start with "in_"'),
        )
        for param_name, required_prefix, problem in rules:
            if not param_name.startswith(required_prefix):
                raise BlockError(problem)
25
+
26
@dataclass
class _InputValues:
    """A queued batch of input param values for one destination block.

    When a watched output param of a source block changes, the dag does not
    set the destination's input immediately: it records the value in one of
    these queue items, and the values are applied to the destination block
    just before that block executes.
    """

    # Block whose input params will receive the values.
    #
    dst: Block

    # Input param name -> value to set. Each instance gets its own dict.
    #
    values: dict[str, Any] = field(default_factory=dict)
41
+
42
class _BlockContext:
    """A context manager to wrap the execution of a block within a dag.

    This default context manager handles the block state, the stopper,
    and converts block execution errors to BlockError exceptions.

    This could be done inline, but using a context manager allows
    the context manager to be replaced. For example, a panel-based
    dag runner could use a context manager that incorporates logging
    and displays information in a GUI.
    """

    def __init__(self, *, block: Block, dag: 'Dag', dag_logger=None):
        self.block = block
        self.dag = dag

        # Stored for replacement context managers; unused by this default one.
        #
        self.dag_logger = dag_logger

    def __enter__(self):
        """Mark the block as executing and return it."""

        self.block._block_state = BlockState.EXECUTING

        return self.block

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Record the block's outcome in its state.

        * No exception: the state becomes WAITING for a user_input block,
          else SUCCESSFUL.
        * KeyboardInterrupt: the state becomes INTERRUPTED, the dag's stopper
          is set, and the interrupt propagates unchanged.
        * Any other exception: the state becomes ERROR, the dag's stopper is
          set, and a BlockError chained from the original error is raised.

        Returns
        -------
        bool
            Always False: exceptions are never suppressed.
        """

        if exc_type is None:
            self.block._block_state = BlockState.WAITING if self.block.user_input else BlockState.SUCCESSFUL
        elif issubclass(exc_type, KeyboardInterrupt):
            # exc_type is a class, so issubclass() (not isinstance()) is needed.
            #
            self.block._block_state = BlockState.INTERRUPTED
            self.dag._stopper.event.set()
            print(f'KEYBOARD INTERRUPT IN BLOCK {self.block.name}')
        else:
            self.block._block_state = BlockState.ERROR
            self.dag._stopper.event.set()

            # Convert the error in the block to a BlockError.
            #
            raise BlockError(f'Block {self.block.name}: {str(exc_val)}') from exc_val

        return False
82
+
83
class _Stopper:
    """Wrap a threading.Event used as a dag's "stop executing" flag."""

    def __init__(self):
        self.event = threading.Event()

    @property
    def is_stopped(self) -> bool:
        """True while the event is set."""

        return self.event.is_set()

    def __repr__(self):
        return f'stopped={self.is_stopped}'
97
+
98
class Dag:
    """A directed acyclic graph of blocks."""

    def __init__(self, *, site: str='Block', title: str, doc: str):
        self._block_pairs: list[tuple[Block, Block]] = []
        self._stopper = _Stopper()
        self.site = site
        self.title = title
        self.doc = doc

        # We watch output params to be notified when they are set.
        # Events are queued here.
        #
        self._block_queue: deque[_InputValues] = deque()

        # The context manager class to use to run blocks.
        #
        self._block_context = _BlockContext

        # The param watcher created by each connect() call, kept with its
        # (src, dst) pair so that disconnect() can remove exactly the watcher
        # feeding a block without touching a source's other connections.
        #
        self._block_watchers: list[tuple[Block, Block, Any]] = []

    def _for_each_once(self):
        """Yield each connected block once, in order of first appearance."""

        seen = set()
        for s, d in self._block_pairs:
            for g in s, d:
                if g not in seen:
                    seen.add(g)
                    yield g

    def stop(self):
        """Stop further execution of Block instances in this dag."""

        self._stopper.event.set()

    def unstop(self):
        """Enable further execution of Block instances in this dag."""

        self._stopper.event.clear()

    def connect(self, src: Block, dst: Block, *connections: Connection):
        """Connect output params of ``src`` to input params of ``dst``.

        Parameters
        ----------
        src: Block
            The block whose ``out_`` params are watched.
        dst: Block
            The block whose ``in_`` params receive the values.
        *connections: Connection
            The (output param, input param) pairs to connect.

        Raises
        ------
        BlockError
            If an argument is not a Connection; if the connection would
            create a cycle (while cycles are disallowed); if a block name
            clashes with another block; if the blocks are already connected;
            if the dag is non-empty and neither block is already in it; or if
            a source param has ``allow_refs=True``.
        """

        if any(not isinstance(c, Connection) for c in connections):
            raise BlockError('All arguments must be Connection instances')

        if _DISALLOW_CYCLES:
            if _has_cycle(self._block_pairs + [(src, dst)]):
                raise BlockError('This connection would create a cycle')

        if src.name==dst.name:
            raise BlockError('Cannot add two blocks with the same name')

        for g in self._for_each_once():
            if (g is not src and g.name==src.name) or (g is not dst and g.name==dst.name):
                raise BlockError('A block with this name already exists')

        for s, d in self._block_pairs:
            if src is s and dst is d:
                raise BlockError('These blocks are already connected')

        if self._block_pairs:
            connected = any(src is s or src is d or dst is s or dst is d for s, d in self._block_pairs)
            if not connected:
                raise BlockError('A new block must connect to existing block')

        # Validate every source param before mutating any block state,
        # so a failed connect() leaves no partial name-map entries behind.
        #
        for conn in connections:
            src_param = getattr(src.param, conn.src_param_name)
            if src_param.allow_refs:
                raise BlockError(f'Source parameter {src}.{conn.src_param_name} must not be "allow_refs=True"')

        src_out_params = []
        for conn in connections:
            dst._block_name_map[src.name, conn.src_param_name] = conn.dst_param_name
            src_out_params.append(conn.src_param_name)

        # One watcher covers all of this connection's source params:
        # with a watcher per param, param.update() wouldn't batch the events.
        # onlychanged=False means setting an output to an equal value still
        # triggers the watcher.
        #
        watcher = src.param.watch(lambda *events: self._param_event(dst, *events), src_out_params, onlychanged=False)
        src._block_out_params.extend(src_out_params)

        self._block_pairs.append((src, dst))
        self._block_watchers.append((src, dst, watcher))

    def _param_event(self, dst: Block, *events):
        """The callback for a watch event.

        Queue the new values of the watched output params as values for the
        mapped input params of ``dst``.
        """

        for event in events:
            cls = event.cls.name
            name = event.name
            new = event.new

            # The input param in the dst block.
            #
            inp = dst._block_name_map[cls, name]

            # Look for the destination block in the event queue.
            # If found, update the param value dictionary,
            # else append a new item.
            # This ensures that all param updates for a destination
            # block are merged into a single queue item, even if the
            # updates come from different source blocks.
            #
            for item in self._block_queue:
                if dst is item.dst:
                    item.values[inp] = new
                    break
            else:
                item = _InputValues(dst)
                item.values[inp] = new
                self._block_queue.append(item)

    def execute(self, *, dag_logger=None):
        """Execute the dag.

        The dag is executed by iterating through the block events queue
        and popping events from the head of the queue. For each event,
        update the destination block's input parameters and call
        that block's execute() method.

        If the current destination block has ``user_input`` True, it is
        executed, then the loop continues to set param values until the
        queue is empty, but no further execute() method is called.
        The same happens once the dag's stop flag is set.

        To start (or restart) the dag, there must be something in the event queue.
        The first (or current) user_input block must have updated at least one
        output param before the dag's execute() is called.

        Raises
        ------
        BlockError
            If the queue is empty, if setting a destination's params raises
            ValueError, or (via the block context) if a block's execute() fails.
        """

        if not self._block_queue:
            # Attempting to execute a dag with no updates is probably a mistake.
            #
            raise BlockError('Nothing to execute')

        can_execute = True
        while self._block_queue:
            # The user has set the "stop executing" flag.
            # Continue to set params, but don't execute anything
            #
            if self._stopper.is_stopped:
                can_execute = False

            item = self._block_queue.popleft()
            try:
                item.dst.param.update(item.values)
            except ValueError as e:
                msg = f'While in {item.dst.name} setting a parameter: {e}'
                self._stopper.event.set()
                raise BlockError(msg) from e

            if can_execute:
                with self._block_context(block=item.dst, dag=self, dag_logger=dag_logger) as g:
                    g.execute()

            if item.dst.user_input:
                # If the current destination block requires user input,
                # continue to set params, but don't execute anything.
                #
                can_execute = False

    def disconnect(self, g: Block) -> None:
        """Disconnect block g from other blocks.

        All parameters (input and output) will be disconnected:
        every value watcher on g is removed, as is the watcher each source
        block uses to feed g. Sources' connections to other blocks are kept.

        Parameters
        ----------
        g: Block
            The block to be disconnected.
        """

        # Collect g's value watchers before unwatching: unwatch() removes a
        # watcher from the same lists we would otherwise be iterating, and a
        # single watcher is registered under every param it watches.
        #
        own_watchers = []
        for watchers in g.param.watchers.values():
            for watcher in watchers.get('value', []):
                if watcher not in own_watchers:
                    own_watchers.append(watcher)

        for watcher in own_watchers:
            g.param.unwatch(watcher)

        # Remove only the watchers that feed g.
        #
        for src, dst, watcher in self._block_watchers:
            if dst is g and src is not g:
                src.param.unwatch(watcher)

        # Blocks that g fed no longer need to map g's output params.
        #
        for src, dst in self._block_pairs:
            if src is g:
                for key in [k for k in dst._block_name_map if k[0]==g.name]:
                    del dst._block_name_map[key]

        # Remove this block from the dag.
        # Check for sources and destinations.
        #
        self._block_pairs[:] = [(src, dst) for src, dst in self._block_pairs if src is not g and dst is not g]
        self._block_watchers[:] = [(s, d, w) for s, d, w in self._block_watchers if s is not g and d is not g]

        # Because this block is no longer watching anything, the name map can be cleared.
        #
        g._block_name_map.clear()

    def block_by_name(self, name) -> Block | None:
        """Get a specific block by name, or None if there is no such block."""

        for s, d in self._block_pairs:
            if s.name==name:
                return s

            if d.name == name:
                return d

        return None

    def get_sorted(self):
        """Return the blocks in this dag in topological order.

        This is useful for arranging the blocks in a GUI, for example.

        Returns
        -------
        list[Block]
            The blocks, each appearing after every block that feeds it.

        Raises
        ------
        BlockError
            If the dag contains a cycle.
        """

        return _get_sorted(self._block_pairs)

    def has_cycle(self):
        """Return True if the connections in this dag contain a cycle."""

        return _has_cycle(self._block_pairs)

    def dump(self):
        """Dump the dag to a serialisable (eg to JSON) dictionary.

        The blocks and connections are reduced to simple representations.
        There is no need to serialize code: the blocks themselves are assumed
        to be available when loaded - it is just the attributes of the blocks
        that need to be saved.

        Two sets of attributes in particular are saved.

        * The name of the block class. Each block has a name by virtue of it
          being a Parameterized subclass.
        * The ``__init__`` parameters, where possible. For each parameter,
          if the block object has a matching instance name, the value of
          the name is saved.

        Returns
        -------
        dict
            A dictionary containing the serialised dag.
        """

        # Number the blocks in order of first appearance in the block pairs.
        #
        block_instances: dict[Block, int] = {g: i for i, g in enumerate(self._for_each_once())}

        blocks = []
        for g, i in block_instances.items():
            # We have to pass some arguments to the block when it is reconstituted.
            # `name` is mandatory - what else?
            #
            args = {'name': g.name}

            # What are __init__'s plain Python parameters?
            # The first parameter is always self - skip that.
            # NOTE(review): keyword-only __init__ parameters are not in
            # co_varnames[:co_argcount], so they are not saved - confirm intent.
            #
            init_code = g.__init__.__code__ # type: ignore[misc]
            init_arg_names = init_code.co_varnames[1:init_code.co_argcount]
            for arg_name in init_arg_names:
                if hasattr(g, arg_name):
                    args[arg_name] = getattr(g, arg_name)

            # TODO is there a better way of checking for user_input?
            if hasattr(g, 'user_input'):
                args['user_input'] = getattr(g, 'user_input')

            blocks.append({
                'block': g.block_key(),
                'instance': i,
                'args': args
            })

        connections = []
        for s, d in self._block_pairs:
            # The src params of s that have been connected to dst params of d.
            #
            conn_args = [
                {
                    'src_param_name': sname,
                    'dst_param_name': dname
                }
                for (gname, sname), dname in d._block_name_map.items()
                if gname==s.name
            ]

            connections.append({
                'src': block_instances[s],
                'dst': block_instances[d],
                'conn_args': conn_args
            })

        return {
            'dag': {
                'type': self.__class__.__name__,
                'doc': self.doc,
                'site': self.site,
                'title': self.title
            },
            'blocks': blocks,
            'connections': connections
        }

    def hv_graph(self):
        """Build a HoloViews Graph to visualise the block connections.

        Each block is placed in a layer equal to the length of the longest
        path reaching it from a root (no input) block, so every connection
        points to a deeper layer.

        Raises
        ------
        BlockError
            If the dag contains a cycle.
        """

        src = [s for s, _ in self._block_pairs]
        dst = [d for _, d in self._block_pairs]

        # Longest-path layering. Visiting blocks in topological order means a
        # block's rank is final before it is propagated to its destinations,
        # whatever order the pairs were connected in.
        #
        ranks: dict[str, int] = {}
        for g in _get_sorted(self._block_pairs):
            ranks.setdefault(g.name, 0)
            for s, d in self._block_pairs:
                if s is g:
                    ranks[d.name] = max(ranks.get(d.name, 0), ranks[g.name] + 1)

        n_layers = max(ranks.values(), default=0)

        def layout(_):
            """Arrange the graph nodes: y is the layer, x the position in it."""

            max_width = 0

            # Arrange the graph y by layer from top to bottom.
            # For x, for now we start at 0 and +1 in each layer.
            #
            yx = {y: 0 for y in ranks.values()}
            gxy = {}
            for name, y in ranks.items():
                gxy[name] = [yx[y], y]
                yx[y] += 1
                max_width = max(max_width, yx[y])

            # Balance out the x in each layer.
            # NOTE(review): narrower layers are rescaled to fractions of 1
            # while the widest layer keeps integer x positions - confirm intent.
            #
            for y in range(n_layers + 1):
                layer = {name: xy for name, xy in gxy.items() if xy[1]==y}
                if len(layer) < max_width:
                    for x, name in enumerate(layer, 1):
                        gxy[name][0] = x / max_width

            return gxy

        src_names = [g.name for g in src]
        dst_names = [g.name for g in dst]
        graph = hv.Graph(((src_names, dst_names),))

        return hv.element.graphs.layout_nodes(graph, layout=layout)
496
+
497
def topological_sort(pairs):
    """Implement a topological sort (Kahn's algorithm) as described at
    `Topological sorting <https://en.wikipedia.org/wiki/Topological_sorting>`_.

    .. code-block:: text

        L ← Empty list that will contain the sorted elements
        S ← Set of all nodes with no incoming edge

        while S is not empty do
            remove a node n from S
            add n to L
            for each node m with an edge e from n to m do
                remove edge e from the graph
                if m has no other incoming edges then
                    insert m into S

        if graph has edges then
            return error (graph has at least one cycle)
        else
            return L (a topologically sorted order)

    Parameters
    ----------
    pairs: list[tuple]
        The graph's edges as (source, destination) pairs. Not modified.

    Returns
    -------
    tuple[list, list]
        The nodes in topological order, and the edges that could not be
        removed. A non-empty second element means the graph has a cycle.
        Empty ``pairs`` gives ``([], [])``.
    """

    remaining = list(pairs)
    if not remaining:
        return [], remaining

    dsts = [d for _, d in remaining]

    # Start nodes: sources that are never a destination.
    # dict.fromkeys() removes duplicates while keeping first-seen order,
    # so the result is deterministic (a set would order nodes by hash).
    #
    S = deque(dict.fromkeys(s for s, _ in remaining if s not in dsts))
    L = []

    while S:
        n = S.popleft()
        L.append(n)

        # Remove every edge n -> m; m becomes a start node once it has
        # no incoming edges left. Iterate a snapshot while removing.
        #
        for pair in list(remaining):
            src, m = pair
            if src == n:
                remaining.remove(pair)
                if not any(d is m for _, d in remaining):
                    S.append(m)

    return L, remaining
546
+
547
def _has_cycle(block_pairs: list[tuple[Block, Block]]):
    """Return True if the connections in block_pairs contain a cycle.

    topological_sort() leaves behind the edges it could not remove;
    any leftover edge means the graph has a cycle.
    """

    leftover_edges = topological_sort(block_pairs)[1]
    return bool(leftover_edges)
551
+
552
def _get_sorted(block_pairs: list[tuple[Block, Block]]):
    """Return the blocks in block_pairs in topological order.

    Raises BlockError if the connections contain a cycle.
    """

    order, leftover_edges = topological_sort(block_pairs)
    if not leftover_edges:
        return order

    raise BlockError('Dag contains a cycle')