bigraph-schema 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,692 @@
1
+ import copy
2
+ from plum import dispatch
3
+ import numpy as np
4
+
5
+ from dataclasses import replace, dataclass
6
+
7
+ from bigraph_schema.schema import (
8
+ Node,
9
+ Empty,
10
+ Union,
11
+ Tuple,
12
+ Boolean,
13
+ Number,
14
+ Integer,
15
+ Float,
16
+ Delta,
17
+ Nonnegative,
18
+ String,
19
+ Enum,
20
+ Wrap,
21
+ Maybe,
22
+ Overwrite,
23
+ List,
24
+ Map,
25
+ Tree,
26
+ Array,
27
+ Key,
28
+ Path,
29
+ Wires,
30
+ Schema,
31
+ Link,
32
+ is_empty,
33
+ dtype_schema,
34
+ schema_dtype,
35
+ )
36
+
37
+
38
+ from bigraph_schema.methods.default import default
39
+ from bigraph_schema.methods.merge import merge, merge_update
40
+
41
+
42
def resolve_subclass(subclass, superclass):
    """Combine two related schema dataclasses into a new ``type(subclass)``.

    For every dataclass field of ``subclass``:

    - ``_default`` takes the subclass default, falling back to the
      superclass default when the subclass one is falsy;
    - a field that ``superclass`` also has, and whose superclass value is a
      ``Node`` or ``dict``, is combined recursively with ``resolve``;
    - every other field keeps the subclass value.

    Args:
        subclass: schema whose type determines the type of the result.
        superclass: schema supplying fallback / merge values.

    Returns:
        A new instance of ``type(subclass)`` built from the combined fields.

    Raises:
        Exception: if resolving a nested attribute fails; the underlying
            error is chained as ``__cause__``.
    """
    result = {}
    for key in subclass.__dataclass_fields__:
        if key == '_default':
            # a falsy subclass default defers to the superclass default
            result[key] = subclass._default or superclass._default
            continue

        subattr = getattr(subclass, key)
        if not hasattr(superclass, key):
            result[key] = subattr
            continue

        superattr = getattr(superclass, key)
        if isinstance(superattr, (Node, dict)):
            try:
                result[key] = resolve(subattr, superattr)
            except Exception as e:
                raise Exception(f'\ncannot resolve subtypes for attribute \'{key}\':\n{subattr}\n{superattr}\n\n due to\n{e}') from e
        else:
            result[key] = subattr

    return type(subclass)(**result)
65
+
66
+
67
def resolve_empty(empty, update, path=None):
    """Resolve ``update`` against an empty schema, optionally at ``path``.

    With no path (``None`` or empty), ``update`` is returned unchanged.
    Otherwise a one-key dict is built whose key is the first path element
    and whose value is ``resolve(empty, update, path[1:])``.
    """
    if not path:
        return update

    head, rest = path[0], path[1:]
    return {head: resolve(empty, update, rest)}
75
+
76
@dispatch
def resolve(current: Empty, update: Empty, path=None):
    """Both schemas are empty: defer to ``resolve_empty`` with ``current`` as the empty side."""
    empty_side, other_side = current, update
    return resolve_empty(empty_side, other_side, path=path)
79
+
80
@dispatch
def resolve(current: Empty, update: Node, path=None):
    """An empty current schema simply adopts ``update`` (placed at ``path`` if given)."""
    empty_side, other_side = current, update
    return resolve_empty(empty_side, other_side, path=path)
83
+
84
@dispatch
def resolve(current: Node, update: Empty, path=None):
    """An empty update keeps ``current`` (placed at ``path`` if given).

    The arguments are swapped so that ``resolve_empty`` sees the empty
    schema first.
    """
    empty_side, other_side = update, current
    return resolve_empty(empty_side, other_side, path=path)
87
+
88
@dispatch
def resolve(current: Wrap, update: Wrap, path=None):
    """Resolve two wrapper schemas.

    If both wrappers have the same type, their wrapped ``_value`` schemas
    are resolved and re-wrapped in that type. If the wrapper types differ,
    ``update`` (including its wrapper) is resolved against the wrapped value
    of ``current`` and kept inside ``current``'s wrapper. In that case
    neither wrapper is dropped; earlier code returned ``None`` here.

    Returns:
        A new instance of ``type(current)``.
    """
    if type(current) == type(update):
        value = resolve(current._value, update._value, path=path)
    else:
        # different wrapper kinds: nest ``update`` inside ``current``'s wrapper
        value = resolve(current._value, update, path=path)
    return type(current)(_value=value)
93
+
94
@dispatch
def resolve(current: Wrap, update: Node, path=None):
    """Resolve a plain schema into a wrapper while keeping the wrapper type.

    The wrapped ``_value`` of ``current`` is resolved with ``update`` and the
    result is re-wrapped in ``type(current)``.
    """
    inner = resolve(current._value, update, path=path)
    wrapper_type = type(current)
    return wrapper_type(_value=inner)
98
+
99
@dispatch
def resolve(current: Integer, update: Float, path=None):
    """Integer vs Float: the Float schema wins.

    The integer's default is carried over only when the float has no
    default of its own and the integer does.
    """
    keep_update_as_is = (
        not is_empty(update._default)
        or is_empty(current._default))
    if keep_update_as_is:
        return update
    return replace(update, **{'_default': current._default})
108
+
109
@dispatch
def resolve(current: Float, update: Integer, path=None):
    """Float vs Integer: the Float schema wins.

    The integer's default is adopted only when the integer has one and the
    float does not.
    """
    keep_current_as_is = (
        is_empty(update._default)
        or not is_empty(current._default))
    if keep_current_as_is:
        return current
    return replace(current, **{'_default': update._default})
117
+
118
@dispatch
def resolve(current: Node, update: Wrap, path=None):
    """Resolve a plain schema with a wrapper update, keeping the update's wrapper.

    ``current`` is resolved with the wrapped ``_value`` of ``update`` and the
    result is re-wrapped in ``type(update)``.
    """
    inner = resolve(current, update._value, path=path)
    wrapper_type = type(update)
    return wrapper_type(_value=inner)
122
+
123
@dispatch
def resolve(current: Node, update: Node, path=None):
    """Resolve two generic schema nodes.

    With a ``path``:

    - a bare ``Node()`` is replaced by ``{head: resolve({}, update, rest)}``;
    - otherwise the attribute ``head`` of ``current`` (or ``None`` if it is
      missing) is resolved with ``update`` at the remaining path and set
      back on ``current``. Note that this mutates ``current`` in place.

    Without a path:

    - if one node's type is a subclass of the other's, their fields are
      merged by ``resolve_subclass`` (the more specific type wins);
    - a ``String`` update contributes only its own (instance) default value;
    - any other combination is an error.

    Raises:
        Exception: when the two node types cannot be related.
    """
    if path:
        head = path[0]
        if current == Node():
            return {head: resolve({}, update, path[1:])}

        down_current = getattr(current, head, None)
        down_resolve = resolve(down_current, update, path[1:])
        setattr(current, head, down_resolve)
        return current

    current_type = type(current)
    update_type = type(update)

    if current_type == update_type or issubclass(current_type, update_type):
        return resolve_subclass(current, update)

    if issubclass(update_type, current_type):
        return resolve_subclass(update, current)

    if isinstance(update, String):
        # Read the update *instance's* default. ``type(update)._default`` is
        # only the class-level field default, so a String's own default was
        # never propagated.
        default_value = update._default
        if default_value:
            return replace(current, **{'_default': default_value})
        return current

    raise Exception(f'\ncannot resolve types:\n{current}\n{update}\n')
159
+
160
def resolve_map(current: Map, update, path=None):
    """Resolve a ``Map`` schema with another schema when no path applies.

    - If one type is a subclass of the other, their fields are merged by
      ``resolve_subclass`` (the more specific type wins).
    - A ``String`` update contributes only its own (instance) default value.
    - Any other combination is an error.

    ``path`` is accepted for signature compatibility and is not used.

    Raises:
        Exception: when the two types cannot be related.
    """
    current_type = type(current)
    update_type = type(update)

    if current_type == update_type or issubclass(current_type, update_type):
        return resolve_subclass(current, update)

    if issubclass(update_type, current_type):
        return resolve_subclass(update, current)

    if isinstance(update, String):
        # Use the instance default. ``type(update)._default`` is only the
        # class-level field default and ignores this String's own value.
        default_value = update._default
        if default_value:
            return replace(current, **{'_default': default_value})
        return current

    raise Exception(f'\ncannot resolve types:\n{current}\n{update}\n')
181
+
182
+
183
@dispatch
def resolve(current: Map, update: Map, path=None):
    """Resolve two ``Map`` schemas, optionally at ``path``.

    With a path whose head is ``'*'``, the value schemas of both maps are
    resolved at the remaining path. For any other head, ``update`` itself is
    resolved into ``current``'s value schema at the remaining path. Either
    way a copy of ``current`` carrying the new ``_value`` is returned.
    Without a path, the work is delegated to ``resolve_map``.
    """
    if not path:
        return resolve_map(current, update, path=path)

    head, rest = path[0], path[1:]
    incoming = update._value if head == '*' else update
    merged_value = resolve(current._value, incoming, rest)
    return replace(current, **{'_value': merged_value})
204
+
205
@dispatch
def resolve(current: Map, update: Node, path=None):
    """Resolve a generic node into a ``Map`` schema, optionally at ``path``.

    With head ``'*'``, every public (non-underscore) dataclass field of
    ``update`` is resolved in turn into the map's value schema at the
    remaining path, accumulating across fields. Earlier code re-read
    ``current._value`` on every iteration, so only one field's resolution
    survived. With any other head, ``update`` is resolved into the value
    schema directly. Without a path, the work is delegated to
    ``resolve_map``.

    Returns:
        A copy of ``current`` with the resolved ``_value`` (path case), or
        the result of ``resolve_map``.
    """
    if path:
        head = path[0]

        if head == '*':
            value = current._value
            for key in update.__dataclass_fields__:
                if key.startswith('_'):
                    continue
                value = resolve(value, getattr(update, key), path[1:])
            return replace(current, **{'_value': value})

        down_resolve = resolve(current._value, update, path[1:])
        return replace(current, **{'_value': down_resolve})

    return resolve_map(current, update, path=path)
231
+
232
+
233
@dispatch
def resolve(current: Map, update: dict, path=None):
    """Resolve a dict of schemas into a ``Map`` schema.

    With a path, ``update`` is resolved into the map's value schema at the
    remaining path. Without a path, every value of ``update`` is folded into
    the map's value schema. If any of them cannot be resolved, the map is
    upgraded to a struct-like dict: each key from ``default(current)`` maps
    to the map's value schema, and the entries of ``update`` override them.
    The result is passed through ``merge_update(resolved, current, update)``.
    """
    if path:
        down_resolve = resolve(current._value, update, path[1:])
        return replace(current, **{'_value': down_resolve})

    result = current._value
    try:
        for value in update.values():
            result = resolve(result, value)
        resolved = replace(current, _value=result)

    except Exception:
        # upgrade from map to struct schema
        map_default = default(current)
        resolved = {key: current._value for key in map_default}
        resolved.update(update)

    return merge_update(resolved, current, update)
256
+
257
@dispatch
def resolve(current: dict, update: Map, path=None):
    """Resolve a ``Map`` schema into a dict of schemas.

    With a path:

    - head ``'*'`` and a non-empty ``current``: every entry is resolved with
      the map's value schema at the remaining path (``current`` is updated
      in place and returned);
    - head ``'*'`` and an empty ``current``: a copy of ``update`` is returned
      whose value schema has been resolved at the remaining path;
    - any other head: ``update`` is resolved into ``current[head]``
      (``current`` is updated in place and returned).

    Without a path, every value of ``current`` is folded into the map's
    value schema. If that fails, the map is upgraded to a struct-like dict.
    This mirrors ``resolve(Map, dict)``: each key of ``default(update)`` maps
    to the map's value schema, and the entries explicitly present in
    ``current`` override them. Earlier code instead overwrote ``current``'s
    own entries (mutating the caller's dict) with the map value schema.
    """
    if path:
        head = path[0]
        if head == '*':
            if current:
                for key, subcurrent in current.items():
                    current[key] = resolve(subcurrent, update._value, path[1:])
                return current

            subvalue = resolve(current, update._value, path[1:])
            return replace(update, **{'_value': subvalue})

        current[head] = resolve(current.get(head, {}), update, path[1:])
        return current

    result = update._value
    try:
        for value in current.values():
            result = resolve(result, value)
        resolved = replace(update, _value=result)

    except Exception:
        # upgrade from map to struct schema; explicit dict entries win
        map_default = default(update)
        resolved = {key: update._value for key in map_default}
        resolved.update(current)

    return merge_update(resolved, current, update)
305
+
306
+
307
def tree_path(current, update, path):
    """Resolve ``update`` at ``path`` inside the tree schema ``current``.

    The first path element does not select anything. ``current`` itself is
    resolved against ``update`` at ``path[1:]``, presumably because every
    branch of a tree shares the tree's own schema. A ``Tree`` result is
    returned directly; any other result becomes the new ``_leaf`` of
    ``current``. Callers pass a non-empty ``path``.
    """
    descended = resolve(current, update, path[1:])
    if not isinstance(descended, Tree):
        return replace(current, **{'_leaf': descended})
    return descended
314
+
315
@dispatch
def resolve(current: Tree, update: Map, path=None):
    """Resolve a ``Map`` schema into a ``Tree`` schema.

    With a path, delegates to ``tree_path``. Otherwise the map's value schema
    is resolved with the tree's leaf schema. The outcome becomes the new leaf
    and is passed through ``merge_update``.
    """
    if path:
        return tree_path(current, update, path)

    # NOTE(review): the map value is passed first here, whereas
    # resolve(Tree, Tree) passes the current leaf first; argument order
    # decides default precedence for some types — confirm it is intended.
    new_leaf = resolve(update._value, current._leaf)
    with_leaf = replace(current, _leaf=new_leaf)
    return merge_update(with_leaf, current, update)
328
+
329
@dispatch
def resolve(current: Tree, update: Tree, path=None):
    """Resolve two ``Tree`` schemas by resolving their leaf schemas.

    With a path, delegates to ``tree_path``. Otherwise the two leaves are
    resolved (current first), stored as the new leaf of ``current``, and the
    result is passed through ``merge_update``.
    """
    if path:
        return tree_path(current, update, path)

    merged_leaf = resolve(current._leaf, update._leaf)
    return merge_update(
        replace(current, _leaf=merged_leaf),
        current,
        update)
341
+
342
@dispatch
def resolve(current: Tree, update: Node, path=None):
    """Resolve a generic node into a ``Tree`` schema as a leaf schema.

    With a path, delegates to ``tree_path``. Otherwise ``update`` is resolved
    against the tree's leaf schema and a copy of ``current`` with the new
    leaf is returned. Earlier code built that copy, discarded it, and
    returned ``current`` unchanged.

    Raises:
        Exception: if ``update`` cannot be resolved with the leaf schema;
            the underlying error is chained as ``__cause__``.
    """
    if path:
        return tree_path(current, update, path)

    leaf = current._leaf
    try:
        resolved = resolve(leaf, update)
    except Exception as e:
        raise Exception(f'update schema is neither a tree or a leaf:\n{current}\n{update}') from e

    return replace(current, _leaf=resolved)
355
+
356
@dispatch
def resolve(current: Tree, update: dict, path=None):
    """Resolve a dict of schemas into a ``Tree`` schema.

    With a path, delegates to ``tree_path``. Otherwise each value of
    ``update`` is first tried against the leaf schema. If that raises, the
    value is resolved against the (copied) tree as a whole instead. The
    final leaf is stored on the result, which is passed through
    ``merge_update``.
    """
    if path:
        return tree_path(current, update, path)

    result = copy.copy(current)
    leaf = current._leaf
    for value in update.values():
        try:
            leaf = resolve(leaf, value)
        except Exception:
            # not compatible with the leaf schema: resolve against the tree
            result = resolve(result, value)

    resolved = replace(result, _leaf=leaf)
    return merge_update(resolved, current, update)
372
+
373
@dispatch
def resolve(current: dict, update: dict, path=None):
    """Resolve two dict schemas key by key.

    With a path, ``update`` is resolved into ``current[head]`` at the
    remaining path; ``current`` is updated in place and returned.

    Without a path, a new dict is built over the union of keys (``current``'s
    keys first, then keys only in ``update``), skipping ``'_inherit'``. Each
    value is ``resolve(current.get(key), update.get(key))``.

    Raises:
        Exception: naming the failing key when a sub-resolution fails; the
            underlying error is chained as ``__cause__``.
    """
    if path:
        head = path[0]
        current[head] = resolve(current.get(head, {}), update, path[1:])
        return current

    all_keys = list(current)
    all_keys.extend(key for key in update if key not in current)

    result = {}
    for key in all_keys:
        if key == '_inherit':
            continue

        try:
            value = resolve(current.get(key), update.get(key))
        except Exception as e:
            raise Exception(f'\ncannot resolve subtypes for key \'{key}\':\n{current}\n{update}\n\n due to\n{e}') from e

        result[key] = value
    return result
405
+
406
+
407
def resolve_array_path(array: Array, update, path=None):
    """Resolve ``update`` at ``path`` inside the ``Array`` schema ``array``.

    Each path element consumes one axis of ``array._shape``:

    - while more than one axis remains, the schema for the remaining axes is
      resolved recursively and the leading axis is prepended again;
    - at the last axis, ``update`` is resolved against the element (dtype)
      schema. An ``Array`` update must then be zero-dimensional, and a dict
      update contributes its entry for the path head. The resolved element
      schema either adds axes (if it converts to an ``Array``) or replaces
      ``_data``.

    Shapes are concatenated as tuples, so list-valued ``_shape``s are
    accepted too.

    Returns:
        ``array`` unchanged when there is no path, otherwise a new ``Array``.

    Raises:
        Exception: if an ``Array`` update still has dimensions at the last axis.
    """
    if not path:
        return array

    head = path[0]
    subshape = array._shape[1:]

    if subshape:
        down_schema = replace(array, **{'_shape': subshape})
        down_resolve = resolve(down_schema, update, path=path[1:])
        return replace(down_resolve, **{
            '_shape': (array._shape[0],) + tuple(down_resolve._shape)})

    data_schema = dtype_schema(array._data)

    if isinstance(update, Array):
        if update._shape:
            raise Exception(f'resolving arrays but they have different dimensions:\n\n{array}\n\n{update}')
        subupdate = dtype_schema(update._data)
    elif isinstance(update, dict):
        subupdate = update.get(head)
    else:
        subupdate = update

    subschema = resolve(data_schema, subupdate, path=path[1:])
    dtype = schema_dtype(subschema)
    if isinstance(dtype, Array):
        # the element became an array itself: append its axes to ours
        return replace(array, **{
            '_shape': tuple(array._shape) + tuple(dtype._shape)})

    return replace(array, **{'_data': dtype})
445
+
446
+
447
@dispatch
def resolve(current: Array, update: Array, path=None):
    """Resolve two ``Array`` schemas into one that can hold both.

    With a path, delegates to ``resolve_array_path``. Otherwise the shapes
    are combined axis by axis with ``max``, and the longer shape contributes
    its trailing axes unchanged. The resulting ``_shape`` is a tuple, which
    matches the tuple concatenation used by the other array resolvers
    (earlier code produced a list).
    """
    if path:
        return resolve_array_path(current, update, path=path)

    current_shape = tuple(current._shape)
    update_shape = tuple(update._shape)

    common = tuple(
        max(current_size, update_size)
        for current_size, update_size in zip(current_shape, update_shape))

    # at most one of these slices is non-empty
    rank = len(common)
    tail = current_shape[rank:] or update_shape[rank:]

    return replace(current, **{'_shape': common + tail})
460
+
461
+
462
@dispatch
def resolve(current: Array, update: Node, path=None):
    """Resolve a generic node into an ``Array`` schema.

    With a path, delegates to ``resolve_array_path``. Without one the array
    schema is returned unchanged.
    """
    if not path:
        # TODO: finish array behavior
        return current
    return resolve_array_path(current, update, path=path)
471
+
472
+ # for key, subschema in update.items():
473
+ # if isinstance(key, int):
474
+
475
+ # @dispatch
476
+ # def resolve(current: Node, update: Array, path=None):
477
+ # if path:
478
+ # import ipdb; ipdb.set_trace()
479
+
480
+ # return resolve_array_path(update, current, path=path)
481
+
482
+ # return update
483
+
484
@dispatch
def resolve(current: Array, update: dict, path=None):
    """Resolve a dict into an ``Array`` schema.

    Only meaningful with a path (delegated to ``resolve_array_path``);
    without one the array schema is returned unchanged.
    """
    if not path:
        return current
    return resolve_array_path(current, update, path=path)
490
+
491
def resolve_dict_path(current, update, path=None):
    """Resolve ``update`` into the dict schema ``current`` at ``path``.

    Used by ``resolve(dict, Array)`` and ``resolve(dict, Node)`` when a path
    is given.

    - No path: ``update`` is returned unchanged.
    - Head other than ``'*'``: ``update`` is resolved into ``current[head]``
      at the remaining path; ``current`` is mutated in place and returned.
    - Head ``'*'`` with an ``Array`` update: ``current`` must have only
      integer keys (treated as row indexes). Each entry is resolved against
      the schema of one row of ``update``, and a new ``Array`` schema derived
      from ``update`` via ``replace`` is returned.

    Raises:
        Exception: if ``current`` has non-integer keys while being resolved
            against an ``Array`` under ``'*'``.
    """
    if path:
        head = path[0]
        if head == '*':
            # NOTE(review): only Array updates are handled under '*'; any
            # other update type falls through and this returns None. Confirm
            # whether resolve(dict, Node) can reach this case.
            if isinstance(update, Array):
                # row count: at least the array's first axis ...
                row_shape = update._shape[0]
                if not all([isinstance(key, int) for key in current.keys()]):
                    raise Exception(f'trying to resolve a dict and array but the keys are not all indexes:\n\n{current}\n\n{update}')

                # ... grown to cover the largest index present in ``current``
                if current:
                    row_shape = max(row_shape, max(current.keys()) + 1)

                # schema of a single row: the remaining axes, or the element
                # dtype schema when the array is one-dimensional
                subshape = update._shape[1:]
                if subshape:
                    subschema = replace(update, **{'_shape': subshape})
                else:
                    subschema = dtype_schema(update._data)

                # fold every existing row's schema together with the row schema
                if current:
                    resolve_schema = subschema
                    for key, subcurrent in current.items():
                        merge_schema = resolve(
                            subcurrent,
                            subschema,
                            path=path[1:])
                        resolve_schema = resolve(resolve_schema, merge_schema)

                else:
                    resolve_schema = resolve(
                        {},
                        subschema,
                        path=path[1:])

                # an index-keyed dict result is turned back into an Array
                # schema whose first axis spans the largest index
                # NOTE(review): max() raises ValueError if this dict is empty
                if isinstance(resolve_schema, dict):
                    inner_index = max(resolve_schema.keys())
                    inner = resolve_schema[inner_index]
                    if isinstance(inner, Array):
                        inner_shape = (inner_index+1,) + inner._shape
                    else:
                        inner_shape = (inner_index+1,)
                    resolve_schema = replace(update, **{'_shape': inner_shape})

                if isinstance(resolve_schema, Array):
                    # prepend the row axis to the resolved row shape
                    resolve_shape = (row_shape,) + resolve_schema._shape
                    result_schema = replace(update, **{'_shape': resolve_shape})
                else:
                    # non-array row schema: convert it with schema_dtype
                    dtype = schema_dtype(resolve_schema)

                    if isinstance(dtype, Array):
                        # NOTE(review): uses update._shape rather than
                        # row_shape here, unlike the branch above — confirm
                        result_schema = replace(update, **{'_shape': update._shape + dtype._shape})
                    else:
                        result_schema = replace(update, **{'_data': dtype})

                return result_schema

        else:
            # descend into the named entry, creating it if missing
            down_schema = current.get(head, {})
            down_resolve = resolve(down_schema, update, path=path[1:])
            current[head] = down_resolve
            return current

    else:
        return update
554
+
555
@dispatch
def resolve(current: dict, update: Array, path=None):
    """Resolve an ``Array`` schema into a dict schema.

    With a path, delegates to ``resolve_dict_path``; otherwise the array
    schema replaces the dict.
    """
    if not path:
        return update
    return resolve_dict_path(current, update, path=path)
560
+
561
def resolve_link(link: Link, update, path=None):
    """Resolve ``update`` into the ``Link`` schema ``link``.

    With a path, the attribute named by the first element (or ``{}`` if the
    link lacks it) is resolved with ``update`` at the remaining path. A copy
    of ``link`` with that attribute replaced is returned.

    Without a path, the ``'_inputs'`` and ``'_outputs'`` entries of
    ``update``, when present, are resolved against the link's attributes of
    the same name, and a copy carrying the results is returned.
    """
    if path:
        head = path[0]
        down_schema = getattr(link, head) if hasattr(link, head) else {}
        down_resolve = resolve(down_schema, update, path[1:])
        return replace(link, **{head: down_resolve})

    schema = link
    for field_name in ('_inputs', '_outputs'):
        if field_name not in update:
            continue
        field_update = update[field_name]
        field_current = getattr(schema, field_name)
        schema = replace(schema, **{field_name: resolve(field_current, field_update)})

    return schema
582
+
583
@dispatch
def resolve(current: Link, update: dict, path=None):
    """Resolve a dict into a ``Link`` schema via ``resolve_link``."""
    link_schema, incoming = current, update
    return resolve_link(link_schema, incoming, path=path)
586
+
587
@dispatch
def resolve(current: dict, update: Link, path=None):
    """Resolve a ``Link`` with a dict, treating the ``Link`` as the base schema."""
    link_schema, incoming = update, current
    return resolve_link(link_schema, incoming, path=path)
590
+
591
@dispatch
def resolve(current: Node, update: dict, path=None):
    """Resolve a dict into a generic node schema.

    With a path, the attribute named by the first element (or ``{}`` if
    absent) is resolved at the remaining path and a copy of ``current``
    with that attribute replaced is returned.

    Without a path, the dict wins if it has any key that is not a dataclass
    field of ``current``; otherwise ``current`` is kept.
    """
    if path:
        head = path[0]
        down_schema = getattr(current, head) if hasattr(current, head) else {}
        return replace(current, **{head: resolve(down_schema, update, path[1:])})

    known_fields = set(current.__dataclass_fields__)
    unknown_keys = set(update.keys()) - known_fields
    return update if unknown_keys else current
608
+
609
@dispatch
def resolve(current: dict, update: Node, path=None):
    """Resolve a generic node into a dict schema.

    With a path, delegates to ``resolve_dict_path``. An empty dict adopts
    ``update``. Otherwise the dict is kept if it has any key that is not a
    dataclass field of ``update``, and ``update`` wins if it does not.
    """
    if path:
        return resolve_dict_path(current, update, path=path)

    if not current:
        return update

    known_fields = set(update.__dataclass_fields__)
    unknown_keys = set(current.keys()) - known_fields
    return current if unknown_keys else update
624
+
625
@dispatch
def resolve(current: String, update: Wrap, path=None):
    """Resolve a ``String`` into the value of a wrapper, keeping the wrapper."""
    wrapped = resolve(current, update._value, path=path)
    return replace(update, **{'_value': wrapped})
628
+
629
@dispatch
def resolve(current: String, update: Node, path=None):
    """Resolve a ``String`` schema with another node: the node wins.

    A truthy default on the ``String`` is carried over to the result. A new
    node is built with ``replace`` instead of assigning ``_default`` on the
    ``update`` argument. The previous in-place assignment modified a schema
    object the caller (or a type registry) may share.
    """
    if current._default:
        return replace(update, _default=current._default)
    return update
634
+
635
+ # @dispatch
636
+ # def resolve(current: Node, update: String):
637
+ # if update._default:
638
+ # current._default = update._default
639
+ # return current
640
+
641
+ # @dispatch
642
+ # def resolve(current: String, update: Wrap):
643
+ # return resolve(current, update._value)
644
+
645
+ # @dispatch
646
+ # def resolve(current: String, update: String):
647
+ # if update._default or not current._default:
648
+ # return update
649
+ # else:
650
+ # return current
651
+
652
+ # @dispatch
653
+ # def resolve(current: Node, update: String):
654
+ # # import ipdb; ipdb.set_trace()
655
+ # if update._default:
656
+ # current = replace(current, **{'_default': update._default})
657
+ # return current
658
+
659
+
660
+ # @dispatch
661
+ # def resolve(current: dict, update: Node):
662
+ # fields = set(update.__dataclass_fields__)
663
+ # keys = set(current.keys())
664
+
665
+ # for key in keys.intersect(fields):
666
+ # getattr(update, key)
667
+
668
+
669
+
670
@dispatch
def resolve(current: List, update: Tuple, path=None):
    """Resolve a ``List`` schema with a ``Tuple`` schema: the tuple wins.

    If the tuple has no default and the list has one, the list default is
    carried over as a tuple. A new schema is built with ``replace`` rather
    than mutating the ``update`` argument in place.
    """
    if not update._default and current._default:
        return replace(update, _default=tuple(current._default))
    return update
675
+
676
+
677
@dispatch
def resolve(current: list, update: list, path=None):
    """Resolve two plain lists: ``update`` wins and is returned as a tuple.

    ``current`` and ``path`` are ignored.
    """
    # NOTE(review): the original author marked this case as uncertain
    as_tuple = tuple(update)
    return as_tuple
681
+
682
+
683
@dispatch
def resolve(current, update, path=None):
    """Fallback for values that are not schema nodes.

    Returns whichever side is non-empty according to ``is_empty``, checking
    ``current`` first.

    Raises:
        Exception: when neither side is empty.
    """
    if is_empty(current):
        return update
    if is_empty(update):
        return current
    raise Exception(f'\ncannot resolve types, not schemas:\n{current}\n{update}\n')
691
+
692
+