sier2 1.0.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sier2/_dag.py ADDED
@@ -0,0 +1,648 @@
1
+ from ._block import Block, BlockError, BlockValidateError, BlockState
2
+ from dataclasses import dataclass, field #, KW_ONLY, field
3
+ from collections import deque
4
+ from importlib.metadata import entry_points
5
+ import threading
6
+ import sys
7
+ from typing import Any
8
+
9
# By default, loops in a dag aren't allowed.
# Dag.connect() consults this flag before running its cycle check.
#
_DISALLOW_CYCLES = True
12
+
13
@dataclass
class Connection:
    """Pair an output parameter name with the input parameter name it feeds.

    The source name must begin with ``out_`` and the destination name
    must begin with ``in_``; anything else raises ``BlockError`` as soon
    as the instance is created.
    """

    src_param_name: str
    dst_param_name: str

    def __post_init__(self):
        # (value, required prefix, error message), checked in order:
        # the source name is validated before the destination name.
        #
        rules = (
            (self.src_param_name, 'out_', 'Output params must start with "out_"'),
            (self.dst_param_name, 'in_', 'Input params must start with "in_"'),
        )
        for value, prefix, message in rules:
            if not value.startswith(prefix):
                raise BlockError(message)
26
+
27
@dataclass
class _InputValues:
    """A pending set of input values for one destination block.

    Output param updates are not applied immediately; they are held
    on the dag's queue until the dag gets around to the destination
    block. Each queue entry is an instance of this class.
    """

    # The block whose input params will be updated.
    #
    dst: Block

    # Input param name -> value, applied to ``dst`` before it runs.
    # Each instance gets its own dictionary.
    #
    values: dict[str, Any] = field(default_factory=dict)
44
+
45
class _BlockContext:
    """A context manager to wrap the execution of a block within a dag.

    This default context manager handles the block state, the stopper,
    and converts block execution errors to BlockError exceptions.

    This could be done inline, but using a context manager allows
    the context manager to be replaced. For example, a panel-based
    dag runner could use a context manager that incorporates logging
    and displays information in a GUI.
    """

    def __init__(self, *, block: Block, dag: 'Dag', dag_logger=None):
        """Remember the block being run, its dag, and an optional logger.

        Parameters
        ----------
        block: Block
            The block about to be executed.
        dag: Dag
            The dag that owns the block; its stopper is set on failure.
        dag_logger:
            Optional logger; if truthy, its ``exception()`` method is
            called when the block raises a non-validation error.
        """

        self.block = block
        self.dag = dag
        self.dag_logger = dag_logger

    def __enter__(self):
        """Mark the block as executing and return it."""

        self.block._block_state = BlockState.EXECUTING

        return self.block

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Record the outcome of the block in its state.

        Non-``BlockError`` exceptions are re-raised as ``BlockError``
        (chained to the original). Other exceptions are never suppressed.
        """

        if exc_type is None:
            # A block that pauses for input is left waiting rather than
            # being marked as successful.
            #
            self.block._block_state = BlockState.WAITING if self.block.block_pause_execution else BlockState.SUCCESSFUL
        elif exc_type is KeyboardInterrupt:
            # The state lives on the block, and so does the name.
            #
            self.block._block_state = BlockState.INTERRUPTED
            if not self.dag._is_pyodide:
                self.dag._stopper.event.set()
            print(f'KEYBOARD INTERRUPT IN BLOCK {self.block.name}')
        else:
            state = BlockState.ERROR
            self.block._block_state = state
            if exc_type is not BlockValidateError:
                # Validation errors don't set the stopper;
                # they just stop execution.
                #
                if self.dag_logger:
                    self.dag_logger.exception(
                        block_name=self.block.name,
                        block_state=state
                    )

                if not self.dag._is_pyodide:
                    self.dag._stopper.event.set()

            if not issubclass(exc_type, BlockError):
                # Convert non-BlockErrors in the block to a BlockError.
                #
                raise BlockError(f'Block {self.block.name}: {str(exc_val)}') from exc_val

        # Don't suppress the original exception.
        #
        return False
101
+
102
+ class _Stopper:
103
+ def __init__(self):
104
+ self.event = threading.Event()
105
+
106
+ @property
107
+ def is_stopped(self):
108
+ return self.event
109
+
110
+ @is_stopped.getter
111
+ def is_stopped(self) -> bool:
112
+ return self.event.is_set()
113
+
114
+ def __repr__(self):
115
+ return f'stopped={self.is_stopped}'
116
+
117
+ def _find_logging():
118
+ PLUGIN_GROUP = 'sier2.logging'
119
+ library = entry_points(group=PLUGIN_GROUP)
120
+ if (liblen:=len(library))==0:
121
+ # There is no logging plugin, so return a dummy.
122
+ #
123
+ return lambda f, *args, **kwargs: f
124
+ elif liblen>1:
125
+ raise BlockError(f'More than one plugin for {PLUGIN_GROUP}')
126
+
127
+ ep = next(iter(library))
128
+ try:
129
+ logging_func = ep.load()
130
+
131
+ return logging_func
132
+ except AttributeError as e:
133
+ e.add_note(f'While attempting to load logging function {ep.value}')
134
+ raise BlockError(e)
135
+
136
# A marker from Dag.execute_after_input() to tell Dag.execute()
# that this is a restart. It is used as a key in the queued values
# dictionary and is removed again before the values are applied.
#
_RESTART = ':restart:'
140
+
141
class Dag:
    """A directed acyclic graph of blocks."""

    def __init__(self, *, site: str='Block', title: str, doc: str, author: dict[str, str]|None=None, show_doc: bool=True):
        """A new dag.

        Parameters
        ----------
        site: str
            Name of the site.
        title: str
            A title to show in the header.
        doc: str
            Dag documentation.
        author: dict[str, str]
            The dag author, as a dictionary containing at least
            the keys ``name`` and ``email``. Other keys are dropped.
        show_doc: bool
            Show the dag docstring if True.

        Raises
        ------
        ValueError
            If ``author`` is given but lacks a ``name`` or ``email`` key.
        """

        self._block_pairs: list[tuple[Block, Block]] = []

        self.site = site
        self.title = title
        self.doc = doc
        self.show_doc = show_doc

        if author is not None:
            if 'name' in author and 'email' in author:
                self.author = {'name': author['name'], 'email': author['email']}
            else:
                raise ValueError('Author must contain name and email keys')
        else:
            self.author = None

        if not self._is_pyodide:
            self._stopper = _Stopper()

        # We watch output params to be notified when they are set.
        # Events are queued here.
        #
        self._block_queue: deque[_InputValues] = deque()

        # The context manager class to use to run blocks.
        #
        self._block_context = _BlockContext

        # Set up the logging hook.
        #
        self.logging = _find_logging()

    @property
    def _is_pyodide(self) -> bool:
        """True if the ``_pyodide`` module has been imported."""

        return '_pyodide' in sys.modules

    def _for_each_once(self):
        """Yield each connected block once."""

        seen = set()
        for s, d in self._block_pairs:
            for g in s, d:
                if g not in seen:
                    seen.add(g)
                    yield g

    def stop(self):
        """Stop further execution of Block instances in this dag."""

        if not self._is_pyodide:
            self._stopper.event.set()

    def unstop(self):
        """Enable further execution of Block instances in this dag."""

        if not self._is_pyodide:
            self._stopper.event.clear()

    def connect(self, src: Block, dst: Block, *connections: Connection):
        """Connect output params of ``src`` to input params of ``dst``.

        Parameters
        ----------
        src: Block
            The block providing the output params.
        dst: Block
            The block receiving the input params.
        connections: Connection
            The (output param, input param) pairs to connect.

        Raises
        ------
        BlockError
            If an argument is not a Connection, a block was not
            initialised, the connection would create a cycle, block names
            clash, the blocks are already connected, the new pair doesn't
            touch the existing dag, or a source param has ``allow_refs`` set.
        """

        if any(not isinstance(c, Connection) for c in connections):
            raise BlockError('All arguments must be Connection instances')

        # Because this is probably the first place that the Block instance is used,
        # this is a convenient place to check that the block was correctly initialised.
        #
        # Pick an arbitrary attribute that should be present.
        #
        for b in src, dst:
            if not hasattr(b, 'block_doc'):
                raise BlockError(f'Did you call super().__init__() in {b}?')

        if _DISALLOW_CYCLES:
            if _has_cycle(self._block_pairs + [(src, dst)]):
                raise BlockError('This connection would create a cycle')

        if src.name==dst.name:
            raise BlockError('Cannot add two blocks with the same name')

        for g in self._for_each_once():
            if (g is not src and g.name==src.name) or (g is not dst and g.name==dst.name):
                raise BlockError('A block with this name already exists')

        for s, d in self._block_pairs:
            if src is s and dst is d:
                raise BlockError('These blocks are already connected')

        if self._block_pairs:
            connected = any(src is s or src is d or dst is s or dst is d for s, d in self._block_pairs)
            if not connected:
                raise BlockError('A new block must connect to existing block')

        # Group watchers by their attributes.
        # This optimises the number of watchers.
        #
        # If we just add a watcher per param in the loop, then
        # param.update() won't batch the events.
        #
        src_out_params = []

        for conn in connections:
            src_param = getattr(src.param, conn.src_param_name)
            if src_param.allow_refs:
                raise BlockError(f'Source parameter {src}.{conn.src_param_name} must not be "allow_refs=True"')

            # _param_event() uses this map to translate
            # (source block name, output param) to the input param.
            #
            dst._block_name_map[src.name, conn.src_param_name] = conn.dst_param_name
            src_out_params.append(conn.src_param_name)

        # A single watcher covers all of the connected output params.
        #
        src.param.watch(lambda *events: self._param_event(dst, *events), src_out_params, onlychanged=False)
        src._block_out_params.extend(src_out_params)

        self._block_pairs.append((src, dst))

    def _param_event(self, dst: Block, *events):
        """The callback for a watch event.

        Each event's new value is recorded on the block queue against
        the matching input param of ``dst``.
        """

        for event in events:
            # NOTE(review): the name map is keyed by (src.name, param name)
            # in connect(), so event.cls.name is presumably the source
            # block's name - confirm against param's Event documentation.
            #
            cls = event.cls.name
            name = event.name
            new = event.new

            # The input param in the dst block.
            #
            inp = dst._block_name_map[cls, name]

            # Look for the destination block in the event queue.
            # If found, update the param value dictionary,
            # else append a new item.
            # This ensures that all param updates for a destination
            # block are merged into a single queue item, even if the
            # updates come from different source blocks.
            #
            for item in self._block_queue:
                if dst is item.dst:
                    item.values[inp] = new
                    break
            else:
                item = _InputValues(dst)
                item.values[inp] = new
                self._block_queue.append(item)

    def execute_after_input(self, block: Block, *, dag_logger=None) -> Block|None:
        """Execute the dag after running ``prepare()``.

        After prepare() executes, and the user has possibly
        provided input, the dag must continue with execute() in the
        same block.

        This method will prime the block queue with the specified block,
        and call execute().

        Parameters
        ----------
        block: Block
            The block to restart the dag at.
        dag_logger:
            A logger adapter that will accept log messages.

        Returns
        -------
        Block|None
            Whatever ``execute()`` returns: the next block that the dag
            paused on, or None if the queue was emptied.

        Raises
        ------
        BlockError
            If ``block`` is not a block that pauses execution.
        """

        if not block.block_pause_execution:
            raise BlockError(f'A dag can only restart a paused Block, not {block.name}')

        # Prime the block queue, using _RESTART
        # to indicate that this is a restart, and Block.execute()
        # must be called.
        #
        self._block_queue.appendleft(_InputValues(block, {_RESTART: True}))

        # Pass the result on, so the caller finds out about
        # a later block that the dag has paused on.
        #
        return self.execute(dag_logger=dag_logger)

    def execute(self, *, dag_logger=None) -> Block|None:
        """Execute the dag.

        The dag is executed by iterating through the block event queue
        and popping events from the head of the queue. For each event,
        update the destination block's input parameters and call
        that block's execute() method.

        If the current destination block's ``block_pause_execution`` is True,
        the loop will call ``block.prepare()``, then stop; execute()
        will return the block that is paused on.
        The dag can then be restarted with ``dag.execute_after_input()``,
        using the paused block as the parameter.

        To start the dag, either:
        - there must be something in the event queue - the dag must be "primed". A block must have updated at least one output param before the dag's execute() is called;
        - the first block in the dag must be an input block (block_pause_execution=True).

        Calling ``dag.execute()`` will then execute the dag starting with the relevant block.

        Returns
        -------
        Block|None
            The block that the dag has paused on, or None if the
            queue was emptied.

        Raises
        ------
        BlockError
            If there is nothing to execute, or setting a block's
            input params raises a ValueError.
        """

        if not self._block_queue:
            # If there aren't any blocks on the queue, find the first block in the dag.
            # If this block is an input block, put it on the queue.
            #
            # A dag with no connections has no first block; don't ask
            # for a sort, fall through to "Nothing to execute" instead.
            #
            sorted_blocks = self.get_sorted() if self._block_pairs else []
            if sorted_blocks:
                first = sorted_blocks[0]
                if first.block_pause_execution:
                    self._block_queue.appendleft(_InputValues(first, {}))

        if not self._block_queue:
            # Attempting to execute a dag with no updates is probably a mistake.
            #
            raise BlockError('Nothing to execute')

        self.logging(None, sier2_dag_=self)

        can_execute = True
        while self._block_queue:
            # The user has set the "stop executing" flag.
            # Continue to set params, but don't execute anything
            #
            if not self._is_pyodide:
                if self._stopper.is_stopped:
                    can_execute = False

            item = self._block_queue.popleft()

            # The restart marker is not an input param;
            # remove it before the values are applied.
            #
            is_restart = item.values.pop(_RESTART, False)
            try:
                item.dst.param.update(item.values)
            except ValueError as e:
                msg = f'While in {item.dst.name} setting a parameter: {e}'
                if not self._is_pyodide:
                    self._stopper.event.set()
                raise BlockError(msg) from e

            # Execute the block.
            # Don't execute input blocks when we get to them,
            # unless this is after the user has selected the "Continue"
            # button.
            #
            is_input_block = item.dst.block_pause_execution
            if can_execute:
                with self._block_context(block=item.dst, dag=self, dag_logger=dag_logger) as g:

                    logging_params = {
                        'sier2_dag_': self,
                        'sier2_block_': f'{item.dst}'
                    }

                    # If we need to wait for a user, just run prepare().
                    # If we are restarting, just run execute().
                    # Otherwise, run both.
                    if is_input_block and not is_restart:
                        self.logging(g.prepare, **logging_params)()
                    elif is_restart:
                        self.logging(g.execute, **logging_params)()
                    else:
                        self.logging(g.prepare, **logging_params)()
                        self.logging(g.execute, **logging_params)()

            if is_input_block and not is_restart:
                # If the current destination block requires user input,
                # stop executing the dag immediately, because we don't
                # want to be setting the input params of further blocks
                # and causing them to do things.
                #
                # This possibly leaves items on the queue, which will be
                # executed on the next call to execute().
                #
                return item.dst

        return None

    @staticmethod
    def _unwatch_values(block: Block) -> None:
        """Remove every 'value' watcher registered on a block's params."""

        for watchers in block.param.watchers.values():
            # Iterate over a copy: unwatch() presumably removes the
            # watcher from this same list, and removing from a list
            # while iterating over it skips entries.
            #
            for watcher in list(watchers.get('value', ())):
                block.param.unwatch(watcher)

    def disconnect(self, g: Block) -> None:
        """Disconnect block g from other blocks.

        All parameters (input and output) will be disconnected.

        Parameters
        ----------
        g: Block
            The block to be disconnected.
        """

        self._unwatch_values(g)

        # NOTE(review): this removes every watcher on each source block
        # that feeds g, which would include watchers feeding other
        # destination blocks - confirm that this is intended.
        #
        for src, dst in self._block_pairs:
            if dst is g:
                self._unwatch_values(src)

        # Remove this block from the dag.
        # Check for sources and destinations.
        #
        self._block_pairs[:] = [(src, dst) for src, dst in self._block_pairs if src is not g and dst is not g]

        # Because this block is no longer watching anything, the name map can be cleared.
        #
        g._block_name_map.clear()

    def block_by_name(self, name) -> Block | None:
        """Get a specific block by name.

        Returns None if no connected block has the name.
        """

        for s, d in self._block_pairs:
            if s.name==name:
                return s

            if d.name == name:
                return d

        return None

    def get_sorted(self) -> list[Block]:
        """Return the blocks in this dag in topological order.

        This is useful for arranging the blocks in a GUI, for example.

        Returns
        -------
        list[Block]
            The connected blocks; each block appears before
            the blocks that it feeds.

        Raises
        ------
        BlockError
            If the dag contains a cycle.
        """

        return _get_sorted(self._block_pairs)

    def has_cycle(self):
        """Return True if the connections in this dag form a cycle."""

        return _has_cycle(self._block_pairs)

    def dump(self):
        """Dump the dag to a serialisable (eg to JSON) dictionary.

        The blocks and connections are reduced to simple representations.
        There is no need to serialize code: the blocks themselves are assumed
        to be available when loaded - it is just the attributes of the blocks
        that need to be saved.

        Two sets of attributes in particular are saved.

        * The name of the block class. Each block has a name by virtue of it
          being a Parameterized subclass.
        * The ``__init__`` parameters, where possible. For each parameter,
          if the block object has a matching instance name, the value of
          the name is saved.

        Returns
        -------
        dict
            A dictionary containing the serialised dag.
        """

        # Number the blocks in the order that they first appear in a connection.
        #
        block_instances: dict[Block, int] = {}
        for pair in self._block_pairs:
            for g in pair:
                if g not in block_instances:
                    block_instances[g] = len(block_instances)

        blocks = []
        for g, i in block_instances.items():
            # We have to pass some arguments to the block when it is reconstituted.
            # `name` is mandatory - what else?
            #
            args = {'name': g.name}

            # What are __init__'s plain Python parameters?
            # The first parameter is always self - skip that.
            #
            init_code = g.__init__.__code__ # type: ignore[misc]
            init_params = init_code.co_varnames[1:init_code.co_argcount]
            for var in init_params:
                if hasattr(g, var):
                    args[var] = getattr(g, var)

            blocks.append({
                'block': g.block_key(),
                'instance': i,
                'args': args
            })

        connections = []
        for s, d in self._block_pairs:
            # Get src params that have been connected to dst params.
            #
            conn_args = [
                {
                    'src_param_name': sname,
                    'dst_param_name': dname
                }
                for (gname, sname), dname in d._block_name_map.items()
                if gname==s.name
            ]

            connection: dict[str, Any] = {
                'src': block_instances[s],
                'dst': block_instances[d],
                'conn_args': conn_args
            }
            connections.append(connection)

        return {
            'dag': {
                'type': self.__class__.__name__,
                'doc': self.doc,
                'site': self.site,
                'title': self.title
            },
            'blocks': blocks,
            'connections': connections
        }
586
+
587
def topological_sort(pairs):
    """Implement a topological sort as described at
    `Topological sorting <https://en.wikipedia.org/wiki/Topological_sorting>`_.

    .. code-block:: text

        L ← Empty list that will contain the sorted elements
        S ← Set of all nodes with no incoming edge

        while S is not empty do
            remove a node n from S
            add n to L
            for each node m with an edge e from n to m do
                remove edge e from the graph
                if m has no other incoming edges then
                    insert m into S

        if graph has edges then
            return error (graph has at least one cycle)
        else
            return L (a topologically sorted order)

    Parameters
    ----------
    pairs:
        A list of (source, destination) edges. The list is not modified.

    Returns
    -------
    tuple[list, list]
        The sorted nodes, and the edges that could not be removed.
        The second list is non-empty if and only if the graph has a cycle.
        An empty ``pairs`` gives ``([], [])``.
    """

    remaining = list(pairs)
    ordered = []
    if not remaining:
        # No edges means no nodes; there is nothing to sort.
        #
        return ordered, remaining

    # The nodes with no incoming edge, without duplicates.
    # dict.fromkeys() keeps first-seen order, so the result
    # is the same from run to run.
    #
    dsts = [d for _, d in remaining]
    ready = deque(dict.fromkeys(s for s, _ in remaining if s not in dsts))

    while ready:
        n = ready.popleft()
        ordered.append(n)

        # Iterate over a snapshot, because edges are removed as we go.
        #
        for _, m in list(remaining):
            if (n, m) in remaining:
                remaining.remove((n, m))
                if not any(d is m for _, d in remaining):
                    ready.append(m)

    return ordered, remaining
636
+
637
def _has_cycle(block_pairs: list[tuple[Block, Block]]):
    """Return True if the given connections contain a cycle.

    Any edges left over after a topological sort belong to a cycle.
    """

    leftover = topological_sort(block_pairs)[1]

    return bool(leftover)
641
+
642
def _get_sorted(block_pairs: list[tuple[Block, Block]]) -> list[Block]:
    """Return the blocks of the given connections in topological order.

    Raises BlockError if the connections contain a cycle.
    """

    result = topological_sort(block_pairs)

    # Leftover edges mean that the sort could not complete.
    #
    if result[1]:
        raise BlockError('Dag contains a cycle')

    return result[0]