sier2 0.29__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sier2 might be problematic. Click here for more details.

sier2/_dag.py ADDED
@@ -0,0 +1,694 @@
1
+ from ._block import Block, BlockError, BlockValidateError, BlockState
2
+ from dataclasses import dataclass, field #, KW_ONLY, field
3
+ from collections import defaultdict, deque
4
+ import holoviews as hv
5
+ from importlib.metadata import entry_points
6
+ import threading
7
+ import sys
8
+ from typing import Any
9
+
10
+ # By default, loops in a dag aren't allowed.
11
+ #
12
+ _DISALLOW_CYCLES = True
13
+
14
@dataclass
class Connection:
    """Pair an output parameter name with the input parameter name it feeds.

    The names are validated on construction: the source name must begin
    with ``out_`` and the destination name must begin with ``in_``,
    otherwise a BlockError is raised.
    """

    src_param_name: str
    dst_param_name: str

    def __post_init__(self):
        # Each row is (value to check, required prefix, error message).
        # The source name is checked first, so its error takes precedence.
        #
        rules = (
            (self.src_param_name, 'out_', 'Output params must start with "out_"'),
            (self.dst_param_name, 'in_', 'Input params must start with "in_"'),
        )
        for value, prefix, message in rules:
            if not value.startswith(prefix):
                raise BlockError(message)
27
+
28
@dataclass
class _InputValues:
    """One pending update for a destination block.

    Instances are the items held in the dag's block queue: each one
    names a block and carries the input param values to apply to it
    before it is run.
    """

    # The block whose input params are to be updated.
    #
    dst: Block

    # Maps input param names to the values to set before the block runs.
    # The dag may also store its restart marker key in this dictionary.
    # Every instance gets its own fresh dictionary by default.
    #
    values: dict[str, Any] = field(default_factory=dict)
47
+
48
class _BlockContext:
    """A context manager to wrap the execution of a block within a dag.

    This default context manager maintains the block state, sets the
    dag's stopper when execution fails, and converts exceptions that are
    not BlockErrors into BlockError exceptions.

    This could be done inline, but using a context manager allows
    the context manager to be replaced. For example, a panel-based
    dag runner could use a context manager that incorporates logging
    and displays information in a GUI.
    """

    def __init__(self, *, block: Block, dag: 'Dag', dag_logger=None):
        """Remember the block being run, its dag, and an optional logger.

        Parameters
        ----------
        block: Block
            The block that is about to be executed.
        dag: Dag
            The dag that owns the block; its stopper is set on failure.
        dag_logger:
            Optional object whose ``exception()`` method is called with
            ``block_name`` and ``block_state`` keywords when a block fails.
        """

        self.block = block
        self.dag = dag
        self.dag_logger = dag_logger

    def __enter__(self):
        """Mark the block as executing and return it."""

        self.block._block_state = BlockState.EXECUTING

        return self.block

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Record the final block state and translate failures.

        Returns False so that an exception raised inside the ``with``
        body is never suppressed.

        Raises
        ------
        BlockError
            When the body raised something that is not a BlockError
            (and not a KeyboardInterrupt); the original exception is
            chained as the cause.
        """

        if exc_type is None:
            # A block that pauses the dag is now waiting for input;
            # any other block has simply finished.
            #
            self.block._block_state = BlockState.WAITING if self.block.block_pause_execution else BlockState.SUCCESSFUL
        elif exc_type is KeyboardInterrupt:
            # The state lives on the block, and the name belongs to the block;
            # this context manager has neither attribute itself.
            #
            self.block._block_state = BlockState.INTERRUPTED
            if not self.dag._is_pyodide:
                self.dag._stopper.event.set()
            print(f'KEYBOARD INTERRUPT IN BLOCK {self.block.name}')
        else:
            state = BlockState.ERROR
            self.block._block_state = state
            if exc_type is not BlockValidateError:
                # Validation errors don't set the stopper;
                # they just stop execution.
                #
                if self.dag_logger:
                    self.dag_logger.exception(
                        block_name=self.block.name,
                        block_state=state
                    )

                if not self.dag._is_pyodide:
                    self.dag._stopper.event.set()

                if not issubclass(exc_type, BlockError):
                    # Convert non-BlockErrors in the block to a BlockError.
                    #
                    raise BlockError(f'Block {self.block.name}: {str(exc_val)}') from exc_val

        # Don't suppress the original exception.
        #
        return False
104
+
105
+ class _Stopper:
106
+ def __init__(self):
107
+ self.event = threading.Event()
108
+
109
+ @property
110
+ def is_stopped(self):
111
+ return self.event
112
+
113
+ @is_stopped.getter
114
+ def is_stopped(self) -> bool:
115
+ return self.event.is_set()
116
+
117
+ def __repr__(self):
118
+ return f'stopped={self.is_stopped}'
119
+
120
+ def _find_logging():
121
+ PLUGIN_GROUP = 'sier2.logging'
122
+ library = entry_points(group=PLUGIN_GROUP)
123
+ if (liblen:=len(library))==0:
124
+ # There is no logging plugin, so return a dummy.
125
+ #
126
+ return lambda f, *args, **kwargs: f
127
+ elif liblen>1:
128
+ raise BlockError(f'More than one plugin for {PLUGIN_GROUP}')
129
+
130
+ ep = next(iter(library))
131
+ try:
132
+ logging_func = ep.load()
133
+
134
+ return logging_func
135
+ except AttributeError as e:
136
+ e.add_note(f'While attempting to load logging function {ep.value}')
137
+ raise BlockError(e)
138
+
139
+ # A marker from Dag.execute_after_input() to tell Dag.execute()
140
+ # that this a restart.
141
+ #
142
+ _RESTART = ':restart:'
143
+
144
class Dag:
    """A directed acyclic graph of blocks.

    Blocks are joined pairwise with ``connect()``. When a source block
    sets a connected output param, the new value is queued for the
    destination block; ``execute()`` drains that queue, applying the
    values and running each destination block in turn.
    """

    def __init__(self, *, site: str='Block', title: str, doc: str, author: dict[str, str] | None=None):
        """Create an empty dag.

        Parameters
        ----------
        site: str
            Stored as ``self.site`` and included in ``dump()``.
        title: str
            Stored as ``self.title`` and included in ``dump()``.
        doc: str
            Stored as ``self.doc`` and included in ``dump()``.
        author: dict[str, str] | None
            Optional mapping that must contain ``name`` and ``email`` keys;
            only those two keys are kept.

        Raises
        ------
        ValueError
            If ``author`` is given but lacks a ``name`` or ``email`` key.
        """

        self._block_pairs: list[tuple[Block, Block]] = []

        self.site = site
        self.title = title
        self.doc = doc

        if author is not None:
            if 'name' in author and 'email' in author:
                # Keep only the two keys we care about.
                #
                self.author = {'name': author['name'], 'email': author['email']}
            else:
                raise ValueError('Author must contain name and email keys')
        else:
            self.author = None

        if not self._is_pyodide:
            self._stopper = _Stopper()

        # We watch output params to be notified when they are set.
        # Events are queued here.
        #
        self._block_queue: deque[_InputValues] = deque()

        # The context manager class to use to run blocks.
        #
        self._block_context = _BlockContext

        # Set up the logging hook.
        #
        self.logging = _find_logging()

    @property
    def _is_pyodide(self) -> bool:
        """True when the ``_pyodide`` module has been imported.

        No stopper is created or used while this is True.
        """

        return '_pyodide' in sys.modules

    def _for_each_once(self):
        """Yield each connected block once, in connection order."""

        seen = set()
        for s, d in self._block_pairs:
            for g in s, d:
                if g not in seen:
                    seen.add(g)
                    yield g

    def stop(self):
        """Stop further execution of Block instances in this dag."""
        if not self._is_pyodide:
            self._stopper.event.set()

    def unstop(self):
        """Enable further execution of Block instances in this dag."""
        if not self._is_pyodide:
            self._stopper.event.clear()

    def connect(self, src: Block, dst: Block, *connections: Connection):
        """Connect output params of ``src`` to input params of ``dst``.

        Parameters
        ----------
        src: Block
            The block whose output params are watched.
        dst: Block
            The block whose input params receive the values.
        connections: Connection
            The (output param, input param) name pairs to connect.

        Raises
        ------
        BlockError
            If an argument is not a Connection; if the connection would
            create a cycle (while cycles are disallowed); if block names
            clash; if the two blocks are already connected; if neither
            block is already part of a non-empty dag; or if a source
            param has ``allow_refs`` set.
        """

        if any(not isinstance(c, Connection) for c in connections):
            raise BlockError('All arguments must be Connection instances')

        if _DISALLOW_CYCLES:
            if _has_cycle(self._block_pairs + [(src, dst)]):
                raise BlockError('This connection would create a cycle')

        if src.name==dst.name:
            raise BlockError('Cannot add two blocks with the same name')

        for g in self._for_each_once():
            if (g is not src and g.name==src.name) or (g is not dst and g.name==dst.name):
                raise BlockError('A block with this name already exists')

        for s, d in self._block_pairs:
            if src is s and dst is d:
                raise BlockError('These blocks are already connected')

        if self._block_pairs:
            connected = any(src is s or src is d or dst is s or dst is d for s, d in self._block_pairs)
            if not connected:
                raise BlockError('A new block must connect to existing block')

        # Validate every connection before touching any block, so that a
        # rejected connection does not leave dst's name map half updated.
        #
        src_out_params = []
        name_map_updates = {}
        for conn in connections:
            src_param = getattr(src.param, conn.src_param_name)
            if src_param.allow_refs:
                raise BlockError(f'Source parameter {src}.{conn.src_param_name} must not be "allow_refs=True"')

            name_map_updates[src.name, conn.src_param_name] = conn.dst_param_name
            src_out_params.append(conn.src_param_name)

        for key, dst_param_name in name_map_updates.items():
            dst._block_name_map[key] = dst_param_name

        # Use a single watcher for all of the params, rather than one
        # watcher per param, so the events arrive together.
        #
        src.param.watch(lambda *events: self._param_event(dst, *events), src_out_params, onlychanged=False)
        src._block_out_params.extend(src_out_params)

        self._block_pairs.append((src, dst))

    def _param_event(self, dst: Block, *events):
        """The callback for a watch event.

        Each event's new value is recorded against the matching input
        param of ``dst`` in the block queue.
        """

        for event in events:
            cls = event.cls.name
            name = event.name
            new = event.new

            # The input param in the dst block.
            #
            inp = dst._block_name_map[cls, name]

            # Look for the destination block in the event queue.
            # If found, update the param value dictionary,
            # else append a new item.
            # This ensures that all param updates for a destination
            # block are merged into a single queue item, even if the
            # updates come from different source blocks.
            #
            for item in self._block_queue:
                if dst is item.dst:
                    item.values[inp] = new
                    break
            else:
                item = _InputValues(dst)
                item.values[inp] = new
                self._block_queue.append(item)

    def execute_after_input(self, block: Block, *, dag_logger=None) -> Block | None:
        """Execute the dag after running ``prepare()`` in an input block.

        After prepare() executes, and the user has possibly
        provided input, the dag must continue with execute() in the
        same block.

        This method puts the specified block at the head of the block
        queue, marked as a restart, and calls ``execute()``.

        Parameters
        ----------
        block: Block
            The block to restart the dag at.
        dag_logger:
            A logger adapter that will accept log messages.

        Returns
        -------
        Block | None
            Whatever ``execute()`` returns: the next block the dag has
            paused on, or None.

        Raises
        ------
        BlockError
            If ``block.block_pause_execution`` is not true.
        """

        if not block.block_pause_execution:
            raise BlockError(f'A dag can only restart a paused Block, not {block.name}')

        # Prime the block queue, using _RESTART
        # to indicate that this is a restart, and Block.execute()
        # must be called.
        #
        self._block_queue.appendleft(_InputValues(block, {_RESTART: True}))

        return self.execute(dag_logger=dag_logger)

    def execute(self, *, dag_logger=None) -> Block|None:
        """Execute the dag.

        The dag is executed by iterating through the block event queue
        and popping events from the head of the queue. For each event,
        update the destination block's input parameters and call
        that block's execute() method.

        If the current destination block's ``block_pause_execution`` is True,
        the loop will call ``block.prepare()`` instead of ``block.execute()``,
        then stop; execute() will return the block that is paused on.
        The dag can then be restarted with ``dag.execute_after_input()``,
        using the paused block as the parameter.

        To start the dag, either:

        - there must be something in the event queue - the dag must be "primed".
          A block must have updated at least one output param before the dag's
          execute() is called;
        - the first block in the dag must be an input block (block_pause_execution=True).

        Calling ``dag.execute()`` will then execute the dag starting with the relevant block.

        Parameters
        ----------
        dag_logger:
            Passed to the block context manager for each block that is run.

        Returns
        -------
        Block | None
            The block the dag has paused on, or None when the queue was
            drained without pausing.

        Raises
        ------
        BlockError
            If there is nothing to execute, or if setting a block's
            input params raises ValueError.
        """

        if not self._block_queue:
            # If there aren't any blocks on the queue, find the first block in the dag.
            # If this block is an input block, put it on the queue.
            #
            sorted_blocks = self.get_sorted()
            if sorted_blocks:
                first = sorted_blocks[0]
                if first.block_pause_execution:
                    self._block_queue.appendleft(_InputValues(first, {}))

        if not self._block_queue:
            # Attempting to execute a dag with no updates is probably a mistake.
            #
            raise BlockError('Nothing to execute')

        self.logging(None, sier2_dag_=self)

        can_execute = True
        while self._block_queue:
            # The user has set the "stop executing" flag.
            # Continue to set params, but don't execute anything
            #
            if not self._is_pyodide:
                if self._stopper.is_stopped:
                    can_execute = False

            item = self._block_queue.popleft()

            # The restart marker is not a param; remove it before updating.
            #
            is_restart = item.values.pop(_RESTART, False)
            try:
                item.dst.param.update(item.values)
            except ValueError as e:
                msg = f'While in {item.dst.name} setting a parameter: {e}'
                if not self._is_pyodide:
                    self._stopper.event.set()
                raise BlockError(msg) from e

            # Execute the block.
            # Don't execute input blocks when we get to them,
            # unless this is after the user has selected the "Continue"
            # button.
            #
            is_input_block = item.dst.block_pause_execution
            if can_execute:
                with self._block_context(block=item.dst, dag=self, dag_logger=dag_logger) as g:

                    logging_params = {
                        'sier2_dag_': self,
                        'sier2_block_': f'{item.dst}'
                    }

                    # An input block that is not being restarted is
                    # prepared rather than executed.
                    #
                    if is_input_block and not is_restart:
                        self.logging(g.prepare, **logging_params)()
                    else:
                        self.logging(g.execute, **logging_params)()

            if is_input_block and not is_restart:
                # If the current destination block requires user input,
                # stop executing the dag immediately, because we don't
                # want to be setting the input params of further blocks
                # and causing them to do things.
                #
                # This possibly leaves items on the queue, which will be
                # executed on the next call to execute().
                #
                return item.dst

        return None

    def disconnect(self, g: Block) -> None:
        """Disconnect block g from other blocks.

        All parameters (input and output) will be disconnected.

        Parameters
        ----------
        g: Block
            The block to be disconnected.
        """

        # Iterate over copies: unwatching may modify the watcher
        # collections while we are walking them.
        #
        for watchers in list(g.param.watchers.values()):
            for watcher in list(watchers['value']):
                g.param.unwatch(watcher)

        # NOTE(review): this removes every 'value' watcher on each source
        # of g, which presumably includes watchers that feed blocks other
        # than g - confirm that this is intended.
        #
        for src, dst in self._block_pairs:
            if dst is g:
                for watchers in list(src.param.watchers.values()):
                    for watcher in list(watchers['value']):
                        src.param.unwatch(watcher)

        # Remove this block from the dag.
        # Check for sources and destinations.
        #
        self._block_pairs[:] = [(src, dst) for src, dst in self._block_pairs if src is not g and dst is not g]

        # Because this block is no longer watching anything, the name map can be cleared.
        #
        g._block_name_map.clear()

    def block_by_name(self, name) -> Block | None:
        """Get a specific block by name.

        Returns the first connected block whose ``name`` matches,
        or None if there is no such block.
        """

        return next((g for g in self._for_each_once() if g.name==name), None)

    def get_sorted(self) -> list[Block]:
        """Return the blocks in this dag in topological order.

        This is useful for arranging the blocks in a GUI, for example.

        Returns
        -------
        list[Block]
            The connected blocks, each source before its destinations.

        Raises
        ------
        BlockError
            If the dag contains a cycle.
        """

        return _get_sorted(self._block_pairs)

    def has_cycle(self):
        """Return True if the connected blocks contain a cycle."""

        return _has_cycle(self._block_pairs)

    def dump(self):
        """Dump the dag to a serialisable (eg to JSON) dictionary.

        The blocks and connections are reduced to simple representations.
        There is no need to serialize code: the blocks themselves are assumed
        to be available when loaded - it is just the attributes of the blocks
        that need to be saved.

        Two sets of attributes in particular are saved.

        * The name of the block class. Each block has a name by virtue of it
          being a Parameterized subclass.
        * The ``__init__`` parameters, where possible. For each parameter,
          if the block object has a matching instance name, the value of
          the name is saved.

        Returns
        -------
        dict
            A dictionary containing the serialised dag.
        """

        # Number the blocks in the order they were first connected.
        #
        block_instances: dict[Block, int] = {g: i for i, g in enumerate(self._for_each_once())}

        blocks = []
        for g, i in block_instances.items():
            # We have to pass some arguments to the block when it is reconstituted.
            # `name` is mandatory - what else?
            #
            args = {'name': g.name}

            # What are __init__'s plain Python parameters?
            # The first parameter is always self - skip that.
            #
            init_code = g.__init__.__code__
            init_params = init_code.co_varnames[1:init_code.co_argcount] # type: ignore[misc]
            for var in init_params:
                if hasattr(g, var):
                    args[var] = getattr(g, var)

            block = {
                'block': g.block_key(),
                'instance': i,
                'args': args
            }
            blocks.append(block)

        connections = []
        for s, d in self._block_pairs:
            connection: dict[str, Any] = {
                'src': block_instances[s],
                'dst': block_instances[d],
                'conn_args': []
            }

            # Get src params that have been connected to dst params.
            #
            for (gname, sname), dname in d._block_name_map.items():
                if gname==s.name:
                    connection['conn_args'].append({
                        'src_param_name': sname,
                        'dst_param_name': dname
                    })

            connections.append(connection)

        return {
            'dag': {
                'type': self.__class__.__name__,
                'doc': self.doc,
                'site': self.site,
                'title': self.title
            },
            'blocks': blocks,
            'connections': connections
        }

    def hv_graph(self):
        """Build a HoloViews Graph to visualise the block connections.

        Blocks are arranged in layers: blocks with no inputs are in
        layer 0, and every other block is one layer deeper than its
        deepest source.

        Raises
        ------
        ValueError
            If no blocks have been connected.
        BlockError
            If the dag contains a cycle.
        """

        # Unzipping an empty list of pairs raises ValueError.
        #
        src, dst = zip(*self._block_pairs)

        def build_layers():
            """Assign each block name a layer number.

            The blocks are visited in topological order, so a block's
            own layer is final before it is used to place the blocks
            that it feeds.
            """

            ranks: dict[str, int] = {}
            n_layers = 1
            for block in self.get_sorted():
                rank = ranks.setdefault(block.name, 0)
                for s, d in self._block_pairs:
                    if s is block:
                        # This destination could be from sources at different layers.
                        # Make sure the deepest one is used.
                        #
                        ranks[d.name] = max(ranks.get(d.name, 0), rank + 1)
                        n_layers = max(n_layers, ranks[d.name])

            return n_layers, ranks

        def layout(_):
            """Arrange the graph nodes."""

            max_width = 0

            # Arrange the graph y by layer from top to bottom.
            # For x, for now we start at 0 and +1 in each layer.
            #
            yx = {y:0 for y in ranks.values()}
            gxy = {}
            for g, y in ranks.items():
                gxy[g] = [yx[y], y]
                yx[y] += 1
                max_width = max(max_width, yx[y])

            # Balance out the x in each layer.
            #
            for y in range(n_layers+1):
                layer = {name: xy for name,xy in gxy.items() if xy[1]==y}
                if len(layer)<max_width:
                    for x, (name, xy) in enumerate(layer.items(), 1):
                        gxy[name][0] = x/max_width

            return gxy

        n_layers, ranks = build_layers()

        src_names = [g.name for g in src]
        dst_names = [g.name for g in dst]
        g = hv.Graph(((src_names, dst_names),))

        return hv.element.graphs.layout_nodes(g, layout=layout)
632
+
633
def topological_sort(pairs):
    """Implement a topological sort as described at
    `Topological sorting <https://en.wikipedia.org/wiki/Topological_sorting>`_.

    .. code-block:: text

        L ← Empty list that will contain the sorted elements
        S ← Set of all nodes with no incoming edge

        while S is not empty do
            remove a node n from S
            add n to L
            for each node m with an edge e from n to m do
                remove edge e from the graph
                if m has no other incoming edges then
                    insert m into S

        if graph has edges then
            return error (graph has at least one cycle)
        else
            return L (a topologically sorted order)

    Parameters
    ----------
    pairs:
        A list of (source, destination) edges. The list is not modified.

    Returns
    -------
    tuple[list, list]
        The sorted nodes, and the edges that could not be removed.
        The second list is non-empty exactly when the graph has a cycle.
        An empty ``pairs`` gives two empty lists.
    """

    remaining = list(pairs)
    L = []

    # The nodes with no incoming edge, in the order they first appear
    # as a source, so that the result is the same from run to run.
    #
    dsts = [d for _, d in remaining]
    S = deque()
    for s, _ in remaining:
        if s not in dsts and s not in S:
            S.append(s)

    while S:
        n = S.popleft()
        L.append(n)

        # Remove the edges leaving n one at a time. A destination is
        # released only when its last incoming edge has gone, so a
        # duplicated edge cannot release the same node twice.
        #
        for e in [pair for pair in remaining if pair[0]==n]:
            remaining.remove(e)
            m = e[1]
            if not any(d is m for _, d in remaining):
                S.append(m)

    return L, remaining
682
+
683
def _has_cycle(block_pairs: list[tuple[Block, Block]]):
    """Report whether the given (source, destination) pairs contain a cycle.

    A cycle is present when the topological sort leaves edges behind.
    """

    leftover_edges = topological_sort(block_pairs)[1]

    return bool(leftover_edges)
687
+
688
def _get_sorted(block_pairs: list[tuple[Block, Block]]) -> list[Block]:
    """Return the blocks of the given pairs in topological order.

    Raises
    ------
    BlockError
        If the topological sort leaves edges behind, which means that
        the pairs contain a cycle.
    """

    ordering, unresolved = topological_sort(block_pairs)

    if not unresolved:
        return ordering

    raise BlockError('Dag contains a cycle')