bigraph-schema 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,217 @@
1
+ from plum import dispatch
2
+ import numpy as np
3
+ from numpy.random.mtrand import RandomState
4
+ import traceback
5
+
6
+ from types import NoneType
7
+ from dataclasses import replace
8
+
9
+ from bigraph_schema.schema import (
10
+ Node,
11
+ Union,
12
+ Tuple,
13
+ Boolean,
14
+ Number,
15
+ Integer,
16
+ Float,
17
+ Delta,
18
+ Nonnegative,
19
+ NPRandom,
20
+ String,
21
+ Enum,
22
+ Wrap,
23
+ Maybe,
24
+ Overwrite,
25
+ List,
26
+ Map,
27
+ Tree,
28
+ Array,
29
+ Key,
30
+ Path,
31
+ Wires,
32
+ Schema,
33
+ Link,
34
+ )
35
+
36
+
37
+ from bigraph_schema.methods.serialize import serialize
38
+ from bigraph_schema.methods.realize import realize
39
+
40
# Registry filled by the generic `infer(object)` overload: maps
# str(type(value)) to the set of paths at which inference for that type
# could not be completed.
MISSING_TYPES: dict = dict()
41
+
42
+
43
def set_default(schema, value):
    """Attach the serialized form of ``value`` to ``schema`` as ``_default``.

    A ``None`` value leaves ``schema`` untouched. If serialization yields a
    dict that itself carries a ``_default`` entry, that inner entry is used.
    ``Node`` schemas are copied with ``dataclasses.replace``; dict schemas are
    updated in place. Any other schema is returned unchanged.

    Returns:
        The (possibly new) schema.
    """
    if value is None:
        return schema

    encoded = serialize(schema, value)
    if isinstance(encoded, dict) and '_default' in encoded:
        encoded = encoded['_default']

    if isinstance(schema, Node):
        return replace(schema, _default=encoded)
    if isinstance(schema, dict):
        schema['_default'] = encoded
    return schema
55
+
56
@dispatch
def infer(core,
          value: (int | np.int32 | np.int64 |
                  np.dtypes.Int32DType | np.dtypes.Int64DType),
          path: tuple = ()):
    """Infer an ``Integer`` schema for Python/NumPy integer values.

    The value becomes the schema default; no merges are produced.
    """
    return set_default(Integer(), value), []
63
+
64
@dispatch
def infer(core, value: bool, path: tuple = ()):
    """Infer a ``Boolean`` schema, using ``value`` as its default."""
    boolean_schema = Boolean()
    return set_default(boolean_schema, value), []
68
+
69
@dispatch
def infer(core,
          value: (float | np.float32 | np.float64 |
                  np.dtypes.Float32DType | np.dtypes.Float64DType),
          path: tuple = ()):
    """Infer a ``Float`` schema for Python/NumPy floating-point values.

    The value becomes the schema default; no merges are produced.
    """
    return set_default(Float(), value), []
76
+
77
@dispatch
def infer(core, value: str, path: tuple = ()):
    """Infer a ``String`` schema, using ``value`` as its default."""
    string_schema = String()
    return set_default(string_schema, value), []
81
+
82
@dispatch
def infer(core, value: np.ndarray, path: tuple = ()):
    """Infer an ``Array`` schema from an ndarray's shape and dtype.

    The array itself becomes the default; no merges are produced.
    """
    # NOTE: the raw numpy dtype is stored in `_data`; an earlier draft wrapped
    # it as `Dtype(_fields=value.dtype)`.
    array_schema = Array(_shape=value.shape, _data=value.dtype)
    return set_default(array_schema, value), []
89
+
90
@dispatch
def infer(core, value: RandomState, path: tuple = ()):
    """Infer an ``NPRandom`` schema from a NumPy ``RandomState``.

    The generator's ``get_state()`` result is inferred recursively. The
    generator itself becomes the schema default.

    Returns:
        ``(schema, merges)`` where ``merges`` come from inferring the state.
    """
    # Forward the location like the other recursive overloads do (e.g. the
    # list overload uses path + ('_element',)); the state is stored in the
    # `state` field, so nested inference is recorded under that name.
    state_schema, merges = infer(core, value.get_state(), path + ('state',))
    schema = NPRandom(state=state_schema)
    return set_default(schema, value), merges
97
+
98
@dispatch
def infer(core, value: list, path: tuple = ()):
    """Infer a ``List`` schema whose element type comes from the first item.

    Only ``value[0]`` is inspected (at ``path + ('_element',)``). An empty
    list gets an untyped ``Node`` element. The list becomes the default.

    Returns:
        ``(schema, merges)`` where ``merges`` come from the first element.
    """
    if not value:
        return set_default(List(_element=Node()), value), []

    element_schema, element_merges = infer(core, value[0], path + ('_element',))
    return set_default(List(_element=element_schema), value), element_merges
111
+
112
@dispatch
def infer(core, value: tuple, path: tuple = ()):
    """Infer a ``Tuple`` schema with one entry per position.

    ``np.str_`` items are placed in the schema verbatim instead of being
    inferred; every other item is inferred at ``path + (position,)``.
    Merges from all positions are concatenated in order.
    """
    values = []
    merges = []
    for position, item in enumerate(value):
        if not isinstance(item, np.str_):
            item_schema, item_merges = infer(core, item, path + (position,))
            merges.extend(item_merges)
            values.append(item_schema)
        else:
            values.append(item)

    return set_default(Tuple(_values=values), value), merges
126
+
127
@dispatch
def infer(core, value: NoneType, path: tuple = ()):
    """Infer an optional ``Maybe`` schema with an untyped inner value for ``None``."""
    optional_schema = Maybe(_value=Node())
    return set_default(optional_schema, value), []
131
+
132
@dispatch
def infer(core, value: set, path: tuple = ()):
    """Infer a set by converting it to a list and inferring that.

    The element type therefore comes from whichever member iterates first.
    """
    as_list = list(value)
    return infer(core, as_list, path)
138
+
139
+
140
def separate_keys(d):
    """Split a mapping into schema keys and state keys.

    Keys starting with an underscore (e.g. ``_type``) are schema keys; all
    others are state keys. Insertion order is preserved in both results.

    Returns:
        A ``(schema, state)`` pair of new dicts.
    """
    schema = {key: item for key, item in d.items() if key.startswith('_')}
    state = {key: item for key, item in d.items() if not key.startswith('_')}
    return schema, state
150
+
151
@dispatch
def infer(core, value: dict, path: tuple = ()):
    """Infer a schema for a dict.

    * With a ``_type`` key: the underscore-prefixed keys are resolved through
      ``core.access_type`` and the remaining keys become the schema default.
    * Else with a ``_default`` key: inference proceeds on that default value
      at the same ``path``.
    * Otherwise every entry is inferred at ``path + (key,)``. If there are at
      least two entries and all entry schemas compare equal, a ``Map`` with
      that value schema is returned (with ``value`` as default). Otherwise a
      plain dict of per-key schemas is returned.

    Returns:
        ``(schema, merges)``.
    """
    if '_type' in value:
        schema_keys, state = separate_keys(value)
        schema = core.access_type(schema_keys)
        return set_default(schema, state), []

    if '_default' in value:
        # Forward `path` so nested inference (and MISSING_TYPES records)
        # stay anchored at this location instead of restarting at the root.
        return infer(core, value['_default'], path)

    subvalues = {}
    distinct_subvalues = []  # at most two entries are needed to rule out a Map
    merges = []
    for key, subvalue in value.items():
        subschema, submerges = infer(core, subvalue, path + (key,))
        subvalues[key] = subschema
        merges += submerges

        if len(distinct_subvalues) < 2 and subschema not in distinct_subvalues:
            distinct_subvalues.append(subschema)

    if len(distinct_subvalues) == 1 and len(subvalues) > 1:
        schema = Map(_value=distinct_subvalues[0])
        return set_default(schema, value), merges

    return subvalues, merges
185
+
186
@dispatch
def infer(core, value: object, path: tuple = ()):
    """Fallback inference: build a dict schema from an object's public attributes.

    Each instance attribute not starting with ``_`` is inferred at
    ``path + (key,)``. An attribute that fails to infer is reported, gets a
    plain ``Node()`` schema, and has its owner's type recorded in
    ``MISSING_TYPES`` at ``path``. A value with no instance ``__dict__``
    (e.g. ``bytes``, ``complex``, ``__slots__`` classes) cannot be
    introspected. Its type is recorded in ``MISSING_TYPES`` and ``Node()``
    is returned.

    Returns:
        ``(schema, merges)``.
    """
    type_name = str(type(value))

    attributes = getattr(value, '__dict__', None)
    if attributes is None:
        MISSING_TYPES.setdefault(type_name, set()).add(path)
        return Node(), []

    value_schema = {}
    merges = []

    # Snapshot the keys: inferring an attribute must not be able to disturb
    # the iteration over the live __dict__ view.
    for key in list(attributes):
        if key.startswith('_'):
            continue
        try:
            value_schema[key], submerges = infer(
                core,
                getattr(value, key),
                path + (key,))
            merges += submerges
        except Exception:
            # Best-effort: report the failure but keep inferring the rest.
            traceback.print_exc()
            MISSING_TYPES.setdefault(type_name, set()).add(path)
            value_schema[key] = Node()

    return value_schema, merges
217
+
@@ -0,0 +1,432 @@
1
+ from plum import dispatch
2
+ import numpy as np
3
+
4
+ from dataclasses import replace, dataclass
5
+
6
+ from bigraph_schema.schema import (
7
+ Node,
8
+ Atom,
9
+ Empty,
10
+ Union,
11
+ Tuple,
12
+ Boolean,
13
+ Number,
14
+ Integer,
15
+ Float,
16
+ Delta,
17
+ Nonnegative,
18
+ String,
19
+ Enum,
20
+ Wrap,
21
+ Maybe,
22
+ Overwrite,
23
+ List,
24
+ Map,
25
+ Tree,
26
+ Array,
27
+ Key,
28
+ Path,
29
+ Wires,
30
+ Schema,
31
+ Link,
32
+ Star,
33
+ Index,
34
+ Jump,
35
+ convert_path,
36
+ walk_path,
37
+ )
38
+
39
+ from bigraph_schema.methods import default, check, serialize, resolve
40
+
41
+
42
@dispatch
def jump(schema: Empty, state, to, context):
    """Any step into an empty schema yields the empty schema and no state."""
    no_state = None
    return schema, no_state
45
+
46
+
47
+ @dispatch
48
+ def jump(schema: Maybe, state, to, context):
49
+ if state is None:
50
+ return Empty(), state
51
+ else:
52
+ return jump(schema._value, state, to, context)
53
+
54
+
55
+ @dispatch
56
+ def jump(schema: Wrap, state, to, context):
57
+ return jump(schema._value, state, to, context)
58
+
59
+
60
+ @dispatch
61
+ def jump(schema: Union, state, to, context):
62
+ for option in schema._options:
63
+ if check(option, state):
64
+ return jump(option, state, to, context)
65
+ return Empty(), None
66
+
67
+
68
+ @dispatch
69
+ def jump(schema: Tuple, state, to: Key, context):
70
+ index = Index(int(to._value))
71
+ return jump(schema, state, index, context)
72
+
73
+
74
+ @dispatch
75
+ def jump(schema: Tuple, state, to: Index, context):
76
+ return traverse(
77
+ schema._values[to._value],
78
+ state[to._value],
79
+ context['subpath'],
80
+ context)
81
+
82
+
83
+ @dispatch
84
+ def jump(schema: Tuple, state, to: Star, context):
85
+ value_schemas = []
86
+ values = []
87
+
88
+ for index, value in enumerate(schema._values):
89
+ subvalue_schema, subvalue = traverse(
90
+ value,
91
+ state[index],
92
+ context['subpath'],
93
+ context)
94
+
95
+ value_schemas.append(subvalue_schema)
96
+ values.append(subvalue)
97
+
98
+ subschema = Tuple(_values=value_schemas)
99
+ substate = tuple(values)
100
+
101
+ return subschema, substate
102
+
103
+
104
+ @dispatch
105
+ def jump(schema: Tuple, state, to: Jump, context):
106
+ # TODO: find general way to format exceptions (!)
107
+ raise Exception(f'cannot lookup index "{to._value}" in tuple {state}\ncontext:\n{context}')
108
+
109
+
110
+ @dispatch
111
+ def jump(schema: List, state, to: Key, context):
112
+ index = Index(int(to._value))
113
+ return jump(schema, state, index, context)
114
+
115
+
116
+ @dispatch
117
+ def jump(schema: List, state, to: Index, context):
118
+ return traverse(
119
+ schema._element,
120
+ state[to._value],
121
+ context['subpath'],
122
+ context)
123
+
124
+
125
+ @dispatch
126
+ def jump(schema: List, state, to: Star, context):
127
+ subelement = Node()
128
+ elements = []
129
+
130
+ for index, value in state:
131
+ subvalue_schema, subvalue = traverse(
132
+ schema._element,
133
+ state[index],
134
+ context['subpath'],
135
+ context)
136
+
137
+ subelement = resolve(subelement, subvalue_schema)
138
+ elements.append(subvalue)
139
+
140
+ subschema = List(_element=subelement)
141
+ return subschema, elements
142
+
143
+
144
+ @dispatch
145
+ def jump(schema: List, state, to: Jump, context):
146
+ raise Exception(f'cannot lookup index "{to._value}" in list {state}\ncontext:\n{context}')
147
+
148
+
149
+ @dispatch
150
+ def jump(schema: Map, state, to: Index, context):
151
+ key = Key(str(to._value))
152
+ return jump(schema, state, key, context)
153
+
154
+
155
+ @dispatch
156
+ def jump(schema: Map, state, to: Key, context):
157
+ state = state or {}
158
+
159
+ return traverse(
160
+ schema._value,
161
+ state.get(to._value),
162
+ context['subpath'],
163
+ context)
164
+
165
+
166
+ @dispatch
167
+ def jump(schema: Map, state, to: Star, context):
168
+ value_schema = Node()
169
+ values = {}
170
+
171
+ for key, value in state.items():
172
+ index = serialize(schema._key, key)
173
+ subvalue_schema, subvalue = traverse(
174
+ schema._value,
175
+ state[index],
176
+ context['subpath'],
177
+ context)
178
+
179
+ value_schema = resolve(value_schema, subvalue_schema)
180
+ values[index] = subvalue
181
+
182
+ subschema = Map(_key=schema._key, _value=value_schema)
183
+ return subschema, values
184
+
185
+
186
+ @dispatch
187
+ def jump(schema: Map, state, to: Jump, context):
188
+ key = serialize(schema._key, to._value)
189
+ return jump(schema, state, Key(_value=key), context)
190
+
191
+
192
+ @dispatch
193
+ def jump(schema: Tree, state, to: Key, context):
194
+ down = state[to._value]
195
+
196
+ subschema = schema
197
+ if check(schema._leaf, down):
198
+ subschema = schema._leaf
199
+
200
+ return traverse(
201
+ subschema,
202
+ down,
203
+ context['subpath'],
204
+ context)
205
+
206
+
207
+ @dispatch
208
+ def jump(schema: Tree, state, to: Star, context):
209
+ leaf_schema = Node()
210
+ branches = {}
211
+
212
+ for key, branch in state.items():
213
+ subschema = schema
214
+ if check(schema._leaf, branch):
215
+ subschema = schema._leaf
216
+
217
+ branch_schema, branch_value = traverse(
218
+ subschema,
219
+ branch,
220
+ context['subpath'],
221
+ context)
222
+
223
+ leaf_schema = resolve(leaf_schema, branch_schema)
224
+ branches[key] = branch_value
225
+
226
+ subschema = Tree(_leaf=leaf_schema)
227
+ return subschema, branches
228
+
229
+
230
+ @dispatch
231
+ def jump(schema: Tree, state, to: Jump, context):
232
+ raise Exception(f'cannot lookup key "{to._value}" in tree {state}\ncontext:\n{context}')
233
+
234
+
235
+ @dispatch
236
+ def jump(schema: Atom, state, to, context):
237
+ if to._value:
238
+ raise Exception(f'cannot jump in atom - key is "{to._value}" but state is an atom:\n{state}')
239
+ else:
240
+ return schema, state
241
+
242
+
243
def jump_link(schema: Link, state, to: Key, context):
    """Resolve a key step into a ``Link`` schema.

    * ``'inputs'`` / ``'outputs'``: traverse ``state[key]`` with the schema
      attribute of the same name. The context is extended with
      ``ports_key``, ``link_path`` (``context['path'][:-1]``) and the port
      schema from ``schema._inputs`` / ``schema._outputs``. Raises if the key
      is missing from ``state``.
    * any other attribute present on ``schema``: traverse that attribute
      with ``state.get(key)``.
    * otherwise: fall back to jumping through a bare ``Node()``.
    """
    key = to._value

    if key not in ('inputs', 'outputs'):
        if hasattr(schema, key):
            return traverse(
                getattr(schema, key),
                state.get(key),
                context['subpath'],
                context)
        return jump(Node(), state, to, context)

    if key not in state:
        raise Exception(f'no "{key}" key in state to jump to:\n{state}')

    ports_schema = getattr(schema, f'_{key}')
    wires_schema = getattr(schema, key)
    port_context = dict(context, **{
        'ports_key': key,
        'link_path': context['path'][:-1],
        f'_{key}': ports_schema})

    return traverse(wires_schema, state[key], context['subpath'], port_context)
272
+
273
+
274
+ @dispatch
275
+ def jump(schema: Link, state, to: Key, context):
276
+ return jump_link(schema, state, to, context)
277
+
278
+ @dispatch
279
+ def jump(schema: Wires, state, to: Key, context):
280
+ key = to._value
281
+
282
+ if not key in state:
283
+ raise Exception(f'no entry "{key}" for wires:\n{state}')
284
+
285
+ substate = state[key]
286
+ if isinstance(substate, list):
287
+ outer_path = context['link_path'][:-1]
288
+ subpath = tuple(convert_path(substate)) + tuple(context['subpath'])
289
+ target_path = outer_path + subpath
290
+ subcontext = dict(context, **{
291
+ 'path': outer_path,
292
+ 'subpath': subpath})
293
+
294
+ return traverse(
295
+ context['schema'],
296
+ context['state'],
297
+ target_path,
298
+ subcontext)
299
+
300
+ else:
301
+ return traverse(
302
+ schema,
303
+ substate,
304
+ context['subpath'],
305
+ context)
306
+
307
+ @dispatch
308
+ def jump(schema: Node, state, to: Star, context):
309
+ value_schema = {}
310
+ values = {}
311
+
312
+ for key, value in schema.__dataclass_fields__:
313
+ if key in state:
314
+ subschema, subvalue = traverse(
315
+ getattr(schema, key),
316
+ state[key],
317
+ context['subpath'],
318
+ context)
319
+
320
+ value_schema[key] = subschema
321
+ values[key] = subvalue
322
+ else:
323
+ raise Exception(f'traverse: no key "{key}" in state {state} at path {context["path"]}')
324
+
325
+ return value_schema, values
326
+
327
+
328
+ @dispatch
329
+ def jump(schema: Array, state, to: Jump, context):
330
+ index = to._value
331
+ if isinstance(index, int) and index < schema._shape[0]:
332
+ subschema = replace(schema, **{'_shape': schema._shape[1:]})
333
+ return traverse(
334
+ subschema,
335
+ state[index],
336
+ context['subpath'],
337
+ context)
338
+ else:
339
+ raise Exception(f'traverse: no index {index} in array state {state} at path {context["path"]}')
340
+
341
+
342
+ @dispatch
343
+ def jump(schema: Array, state, to: Star, context):
344
+ results = []
345
+ subschema = replace(schema, **{'_shape': schema._shape[1:]})
346
+ for index, row in enumerate(state):
347
+ result_schema, result_state = traverse(
348
+ subschema,
349
+ row,
350
+ context['subpath'],
351
+ context)
352
+ results.append(result_state)
353
+
354
+ if isinstance(subschema, Array):
355
+ result_shape = (schema._shape[0],) + tuple(result_schema._shape)
356
+ subschema = replace(schema, **{'_shape': result_shape})
357
+ return subschema, np.array(results)
358
+ else:
359
+ # TODO: handle array data of different types
360
+ import ipdb; ipdb.set_trace()
361
+
362
+
363
+ @dispatch
364
+ def jump(schema: Node, state, to: Jump, context):
365
+ key = to._value
366
+ if key in state:
367
+ return traverse(
368
+ getattr(schema, key),
369
+ state[key],
370
+ context['subpath'],
371
+ context)
372
+ else:
373
+ raise Exception(f'traverse: no key "{key}" in state {state} at path {context["path"]}')
374
+
375
+
376
+ @dispatch
377
+ def jump(schema: Node, state, to, context):
378
+ raise Exception(f'cannot lookup key "{to._value}" in state {state}\ncontext:\n{context}')
379
+
380
+
381
+ @dispatch
382
+ def jump(schema: dict, state, to: Key, context):
383
+ state = state or {}
384
+ key = to._value
385
+
386
+ if key in schema:
387
+ return traverse(
388
+ schema[key],
389
+ state.get(key),
390
+ context['subpath'],
391
+ context)
392
+
393
+ elif key in state:
394
+ return None, state[key]
395
+
396
+ else:
397
+ return Empty(), None
398
+
399
+
400
+ @dispatch
401
+ def jump(schema: dict, state, to: Star, context):
402
+ value_schema = {}
403
+ values = {}
404
+
405
+ for key in schema:
406
+ if key in state:
407
+ subschema, subvalue = traverse(
408
+ schema[key],
409
+ state[key],
410
+ context['subpath'],
411
+ context)
412
+
413
+ value_schema[key] = subschema
414
+ values[key] = subvalue
415
+
416
+ return value_schema, values
417
+
418
+
419
+ @dispatch
420
+ def jump(schema: dict, state, to: Jump, context):
421
+ raise Exception(f'cannot lookup key "{to._value}" in state {state}\ncontext:\n{context}')
422
+
423
+
424
def traverse(schema, state, path, context):
    """Follow ``path`` through ``schema`` and ``state`` one step at a time.

    An empty path returns ``(schema, state)`` unchanged. Otherwise the first
    step is passed to ``jump``, together with the context produced by
    ``walk_path(context, step, remaining_steps)``. The ``jump`` overloads
    continue from ``context['subpath']``.

    Returns:
        A ``(schema, state)`` pair for the location reached.
    """
    if not path:
        return schema, state

    head, rest = path[0], path[1:]
    return jump(schema, state, head, walk_path(context, head, rest))