bigraph-schema 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,204 @@
1
+ from plum import dispatch
2
+ import numpy as np
3
+
4
+ from bigraph_schema.schema import (
5
+ Node,
6
+ Empty,
7
+ Union,
8
+ Tuple,
9
+ Boolean,
10
+ Number,
11
+ Integer,
12
+ Float,
13
+ Delta,
14
+ Nonnegative,
15
+ String,
16
+ Enum,
17
+ Wrap,
18
+ Maybe,
19
+ Overwrite,
20
+ List,
21
+ Map,
22
+ Tree,
23
+ Array,
24
+ Key,
25
+ Path,
26
+ Wires,
27
+ Protocol,
28
+ LocalProtocol,
29
+ Schema,
30
+ Link,
31
+ )
32
+
33
+
34
@dispatch
def default(schema: None):
    """Return ``None``: an absent schema has no default value."""
    return None
37
+
38
@dispatch
def default(schema: Empty):
    """Return ``None`` as the default value of an ``Empty`` schema."""
    return None
41
+
42
@dispatch
def default(schema: Wrap):
    """Return the wrapper's own ``_default``, else the default of ``_value``."""
    explicit = schema._default
    if explicit is None:
        # no default on the wrapper itself: defer to the wrapped schema
        return default(schema._value)
    return explicit
48
+
49
@dispatch
def default(schema: Union):
    """Return the union's ``_default``, else the default of its first option."""
    explicit = schema._default
    if explicit is None:
        return default(schema._options[0])
    return explicit
55
+
56
@dispatch
def default(schema: Tuple):
    """Return the tuple's ``_default``, else a list of each member's default.

    Note that the computed fallback is a ``list`` with one entry per
    schema in ``_values``, in order.
    """
    if schema._default is not None:
        return schema._default

    member_defaults = []
    for subschema in schema._values:
        member_defaults.append(default(subschema))
    return member_defaults
64
+
65
@dispatch
def default(schema: Boolean):
    """Return the schema's ``_default``, falling back to ``False``."""
    explicit = schema._default
    return False if explicit is None else explicit
71
+
72
@dispatch
def default(schema: Integer):
    """Return the schema's ``_default``, falling back to ``0``."""
    explicit = schema._default
    return 0 if explicit is None else explicit
78
+
79
@dispatch
def default(schema: Float):
    """Return the schema's ``_default``, falling back to ``0.0``."""
    explicit = schema._default
    return 0.0 if explicit is None else explicit
85
+
86
@dispatch
def default(schema: String):
    """Return the schema's ``_default``, falling back to the empty string."""
    explicit = schema._default
    return '' if explicit is None else explicit
92
+
93
@dispatch
def default(schema: Enum):
    """Return the enum's ``_default``, else the first entry of ``_values``."""
    explicit = schema._default
    if explicit is None:
        return schema._values[0]
    return explicit
99
+
100
@dispatch
def default(schema: List):
    """Return the list schema's ``_default``, else a fresh empty list."""
    explicit = schema._default
    if explicit is None:
        return []
    return explicit
106
+
107
@dispatch
def default(schema: Map):
    """Return the map schema's ``_default``, else a fresh empty dict."""
    explicit = schema._default
    if explicit is None:
        return {}
    return explicit
113
+
114
@dispatch
def default(schema: Tree):
    """Return the tree schema's ``_default``, else a fresh empty dict."""
    explicit = schema._default
    if explicit is None:
        return {}
    return explicit
120
+
121
+
122
@dispatch
def default(schema: Array):
    """Return the array's ``_default``, else zeros of ``_shape`` / ``_data``.

    The fallback is ``np.zeros(schema._shape, dtype=schema._data)``.
    """
    explicit = schema._default
    if explicit is None:
        return np.zeros(schema._shape, dtype=schema._data)
    return explicit
130
+
131
def default_wires(schema, path=None):
    """Build default wires that mirror the structure of ``schema``.

    Args:
        schema: a nested dict of schemas, or a ``Node``.
        path: list of keys leading to ``schema``; empty when omitted.

    Returns:
        For a dict, a dict with the same keys whose values are the wires
        of each subschema; for a ``Node``, the path that leads to it;
        ``None`` for anything else.
    """
    prefix = path or []

    if isinstance(schema, dict):
        return {
            key: default_wires(subschema, prefix + [key])
            for key, subschema in schema.items()}

    if isinstance(schema, Node):
        return prefix

    return None
145
+
146
+
147
@dispatch
def default(schema: Protocol):
    """Return the protocol's ``_default``, else the local ``edge`` protocol."""
    explicit = schema._default
    if explicit is None:
        return {
            'protocol': 'local',
            'data': 'edge'}
    return explicit
155
+
156
def default_link(schema: Link):
    """Build the default state for a ``Link`` schema.

    A truthy ``schema._default`` is returned as is. Otherwise a dict is
    assembled from the defaults of the link's ``address``, ``config``,
    ``inputs`` and ``outputs`` fields, with fallbacks used whenever one
    of those defaults is falsy, plus the ``_inputs`` / ``_outputs``
    schemas themselves.
    """
    # truthiness (not ``is not None``) decides here, so an empty default
    # also falls through to the computed state
    if schema._default:
        return schema._default

    address = default(schema.address) or 'local:edge'
    config = default(schema.config) or {}
    input_schema = schema._inputs
    output_schema = schema._outputs
    # with no explicit wires, wire every port to the path of the same name
    input_wires = default(schema.inputs) or default_wires(schema._inputs)
    output_wires = default(schema.outputs) or default_wires(schema._outputs)

    return {
        'address': address,
        'config': config,
        '_inputs': input_schema,
        '_outputs': output_schema,
        'inputs': input_wires,
        'outputs': output_wires}
167
+
168
@dispatch
def default(schema: Link):
    """Dispatch entry for ``Link`` schemas; delegates to ``default_link``."""
    link_state = default_link(schema)
    return link_state
171
+
172
def is_schema_key(key):
    """Return True when ``key`` is a string that starts with an underscore."""
    if not isinstance(key, str):
        return False
    return key.startswith('_')
174
+
175
@dispatch
def default(schema: dict):
    """Compute the default state for a dict (struct-like) schema.

    If the dict carries an explicit ``'_default'`` entry, that value is
    returned. Otherwise the result maps every non-schema key (a key that
    does not start with an underscore) to the default of its subschema.

    A leftover ``ipdb`` breakpoint that fired when a subschema was a
    ``float`` has been removed. Such a value is now passed straight to
    ``default``, which is what happened after the debugger resumed.
    """
    if '_default' in schema:
        return schema['_default']

    return {
        key: default(subschema)
        for key, subschema in schema.items()
        if not is_schema_key(key)}
190
+
191
@dispatch
def default(schema: Node):
    """Compute the default state for a generic ``Node`` schema.

    Returns ``schema._default`` when it is not ``None``. Otherwise
    returns a dict mapping each dataclass field whose name is not a
    schema key (underscore-prefixed) to the default of that field's
    value.
    """
    if schema._default is not None:
        return schema._default

    return {
        key: default(getattr(schema, key))
        for key in schema.__dataclass_fields__
        if not is_schema_key(key)}
204
+
@@ -0,0 +1,309 @@
1
+ import copy
2
+ from plum import dispatch
3
+ import numpy as np
4
+
5
+ from dataclasses import replace, dataclass
6
+
7
+ from bigraph_schema.schema import (
8
+ Node,
9
+ Empty,
10
+ Union,
11
+ Tuple,
12
+ Boolean,
13
+ Number,
14
+ Integer,
15
+ Float,
16
+ Delta,
17
+ Nonnegative,
18
+ String,
19
+ Enum,
20
+ Wrap,
21
+ Maybe,
22
+ Overwrite,
23
+ List,
24
+ Map,
25
+ Tree,
26
+ Array,
27
+ Key,
28
+ Path,
29
+ Wires,
30
+ Schema,
31
+ Link,
32
+ )
33
+
34
+
35
+ from bigraph_schema.methods.default import default
36
+ from bigraph_schema.methods.resolve import resolve
37
+ from bigraph_schema.methods.merge import merge, merge_update
38
+
39
+
40
# a bare Node, used to recognize the most general schema by equality
NODE_INSTANCE = Node()


def generalize_subclass(subclass, superclass):
    """Combine a schema with a schema of its superclass type.

    If ``superclass`` equals the bare ``Node`` instance, ``subclass`` is
    returned untouched. Otherwise a new instance of ``type(superclass)``
    is built field by field over the superclass's dataclass fields:

    * ``_default`` takes the superclass default when truthy, else the
      subclass default;
    * other underscore-prefixed fields are copied from ``subclass``;
    * remaining fields are the ``generalize`` of the two values.

    Raises:
        Exception: when generalizing a non-underscore field fails; the
            message names the field and includes the original error.
    """
    if superclass == NODE_INSTANCE:
        return subclass

    merged = {}
    for key in superclass.__dataclass_fields__:
        if key == '_default':
            merged[key] = superclass._default or subclass._default
            continue

        subattr = getattr(subclass, key)
        if key.startswith('_'):
            merged[key] = subattr
            continue

        superattr = getattr(superclass, key)
        try:
            outcome = generalize(subattr, superattr)
        except Exception as e:
            raise Exception(f'\ncannot generalize subtypes for attribute \'{key}\':\n{subattr}\n{superattr}\n\n due to\n{e}')
        merged[key] = outcome

    return type(superclass)(**merged)
62
+
63
+
64
@dispatch
def generalize(current: Integer, update: Float):
    """Generalize an Integer with a Float: the Float schema wins.

    ``update`` is modified in place: when its ``_default`` is falsy it
    inherits the default of ``current``.
    """
    needs_default = not update._default
    if needs_default:
        update._default = current._default
    return update
69
+
70
@dispatch
def generalize(current: Float, update: Integer):
    """Generalize a Float with an Integer: the Float schema wins.

    ``current`` is modified in place: when its ``_default`` is falsy it
    inherits the default of ``update``.
    """
    needs_default = not current._default
    if needs_default:
        current._default = update._default
    return current
75
+
76
@dispatch
def generalize(current: Empty, update: Node):
    """Generalizing ``Empty`` with any schema yields that schema."""
    return update
79
+
80
@dispatch
def generalize(current: Node, update: Empty):
    """Generalizing any schema with ``Empty`` yields the existing schema."""
    return current
83
+
84
@dispatch
def generalize(current: Wrap, update: Wrap):
    """Generalize two wrapped schemas by generalizing their inner values.

    The result is the generalized inner schema itself; neither wrapper
    is rebuilt around it.
    """
    return generalize(current._value, update._value)
96
+
97
@dispatch
def generalize(current: Wrap, update: Node):
    """Unwrap ``current`` and generalize its inner schema with ``update``."""
    return generalize(current._value, update)
101
+
102
@dispatch
def generalize(current: Node, update: Wrap):
    """Unwrap ``update`` and generalize ``current`` with its inner schema."""
    return generalize(current, update._value)
106
+
107
@dispatch
def generalize(current: Node, update: Node):
    """Generalize two schemas whose types are related by inheritance.

    Whichever schema has the more derived type is passed as the subclass
    to ``generalize_subclass``. When the types are equal, ``current``
    plays that role.

    Raises:
        Exception: if neither type is a subclass of the other.
    """
    current_type = type(current)
    update_type = type(update)

    if current_type == update_type or issubclass(current_type, update_type):
        return generalize_subclass(current, update)

    if issubclass(update_type, current_type):
        return generalize_subclass(update, current)

    raise Exception(f'\ncannot generalize types:\n{current}\n{update}\n')
117
+
118
+
119
def generalize_node_dict(current: Node, update: dict):
    """Choose between a node schema and a dict schema.

    Returns ``update`` if it has any key that is not a dataclass field
    of ``current``; otherwise returns ``current``.
    """
    known_fields = set(current.__dataclass_fields__)
    unknown_keys = set(update.keys()).difference(known_fields)
    return update if len(unknown_keys) > 0 else current
127
+
128
+
129
@dispatch
def generalize(current: Array, update: dict):
    """Generalize an array schema with a dict schema.

    The array schema is kept as long as every key of ``update`` is
    either an ``int`` below the first dimension of ``current._shape``
    or the wildcard ``'*'``. The first key that is neither hands the
    decision to ``generalize_node_dict``.
    """
    for key in update:
        in_bounds = isinstance(key, int) and key < current._shape[0]
        if in_bounds or key == '*':
            continue
        return generalize_node_dict(current, update)

    return current
137
+
138
+
139
@dispatch
def generalize(current: Map, update: dict):
    """Generalize a map schema with a dict schema.

    First tries to fold every subschema of ``update`` into the map's
    value schema, yielding a map with a more general ``_value``. If any
    of that fails, the map is upgraded to a struct-like dict: each key
    of the map's default state gets the map's value schema, then the
    entries of ``update`` are laid on top.

    The outcome is finally passed through ``merge_update`` together with
    the two original schemas.
    """
    value_schema = current._value

    try:
        for subschema in update.values():
            value_schema = generalize(value_schema, subschema)
        generalized = replace(current, _value=value_schema)

    # ``Exception`` rather than a bare ``except`` so that
    # KeyboardInterrupt / SystemExit are not swallowed
    except Exception:
        # upgrade from map to struct schema
        map_default = default(current)
        generalized = {
            key: current._value
            for key in map_default}
        generalized.update(update)

    return merge_update(generalized, current, update)
158
+
159
+
160
@dispatch
def generalize(current: dict, update: Map):
    """Generalize a dict schema with a map schema.

    First tries to fold every subschema of ``current`` into the map's
    value schema, yielding a map with a more general ``_value``. If any
    of that fails, a struct-like dict is built that gives each key of
    the map's default state the map's value schema.

    The outcome is finally passed through ``merge_update`` together with
    the two original schemas.
    """
    value_schema = update._value

    try:
        for subschema in current.values():
            value_schema = generalize(value_schema, subschema)
        generalized = replace(update, _value=value_schema)

    # ``Exception`` rather than a bare ``except`` so that
    # KeyboardInterrupt / SystemExit are not swallowed
    except Exception:
        # upgrade from map to struct schema
        map_default = default(update)
        generalized = {
            key: update._value
            for key in map_default}
        # NOTE(review): this mutates the caller's ``current`` and leaves
        # ``generalized`` with only the map-default keys, whereas the
        # (Map, dict) overload does ``generalized.update(update)``.
        # Confirm whether ``generalized.update(current)`` was intended.
        current.update(generalized)

    return merge_update(generalized, current, update)
179
+
180
@dispatch
def generalize(current: Tree, update: Map):
    """Generalize a tree schema with a map schema.

    The map's value schema is generalized with the tree's leaf schema,
    and a copy of the tree carrying that leaf is passed through
    ``merge_update`` with the two original schemas.
    """
    tree_leaf = current._leaf
    map_value = update._value
    combined_leaf = generalize(map_value, tree_leaf)

    tree_copy = copy.copy(current)
    generalized = replace(tree_copy, _leaf=combined_leaf)

    return merge_update(generalized, current, update)
190
+
191
@dispatch
def generalize(current: Tree, update: Tree):
    """Generalize two tree schemas by generalizing their leaf schemas.

    A copy of ``current`` with the combined leaf is passed through
    ``merge_update`` with the two original schemas.
    """
    combined_leaf = generalize(current._leaf, update._leaf)
    generalized_tree = replace(current, _leaf=combined_leaf)

    return merge_update(generalized_tree, current, update)
200
+
201
@dispatch
def generalize(current: Tree, update: Node):
    """Generalize a tree schema with a schema treated as its leaf.

    ``update`` is generalized with the tree's current leaf schema, and a
    copy of the tree carrying the generalized leaf is returned.

    Two defects in the previous version are fixed:

    * it raised a bare string (``raise(f'...')``), which itself fails
      with ``TypeError``; a real ``Exception`` is raised now, chained to
      the underlying error;
    * the result of ``dataclasses.replace`` was discarded, so the
      original tree was returned with its leaf unchanged.

    Raises:
        Exception: if ``update`` cannot be generalized with the leaf.
    """
    leaf = current._leaf
    try:
        generalized = generalize(leaf, update)
    except Exception as e:
        raise Exception(
            f'update schema is neither a tree or a leaf:\n{current}\n{update}') from e

    # replace() returns a new instance; it does not modify ``current``
    return replace(current, _leaf=generalized)
211
+
212
@dispatch
def generalize(current: Tree, update: dict):
    """Generalize a tree schema with a dict schema.

    Each value of ``update`` is first tried as a leaf and folded into
    the running leaf schema. If that fails, the value is instead
    generalized with the running tree. The final tree receives the
    accumulated leaf and is passed through ``merge_update`` with the two
    original schemas.
    """
    result = copy.copy(current)
    leaf = current._leaf

    for value in update.values():
        try:
            leaf = generalize(leaf, value)
        # ``Exception`` rather than a bare ``except`` so that
        # KeyboardInterrupt / SystemExit are not swallowed
        except Exception:
            # not a leaf: treat the value as a subtree instead
            result = generalize(result, value)

    generalized = replace(result, _leaf=leaf)

    return merge_update(generalized, current, update)
225
+
226
+
227
@dispatch
def generalize(current: dict, update: dict):
    """Generalize two dict schemas key by key.

    Every key present in either dict is generalized; a key missing from
    one side contributes ``None`` for that side.

    Raises:
        Exception: when generalizing any key fails; the message names
            the key and includes the original error.
    """
    current_keys = set(current.keys())
    update_keys = set(update.keys())

    merged = {}
    for key in current_keys.union(update_keys):
        try:
            merged[key] = generalize(
                current.get(key),
                update.get(key))
        except Exception as e:
            raise Exception(f'\ncannot generalize subtypes for key \'{key}\':\n{current}\n{update}\n\n due to\n{e}')

    return merged
241
+
242
+
243
@dispatch
def generalize(current: Node, update: dict):
    """Dispatch entry for (Node, dict); delegates to ``generalize_node_dict``."""
    chosen = generalize_node_dict(current, update)
    return chosen
246
+
247
+
248
@dispatch
def generalize(current: dict, update: Node):
    """Choose between a dict schema and a node schema.

    Returns ``current`` if it has any key that is not a dataclass field
    of ``update``; otherwise returns ``update``.
    """
    known_fields = set(update.__dataclass_fields__)
    unknown_keys = set(current.keys()).difference(known_fields)
    return current if len(unknown_keys) > 0 else update
257
+
258
+ # @dispatch
259
+ # def generalize(current: dict, update: Node):
260
+ # fields = set(update.__dataclass_fields__)
261
+ # keys = set(current.keys())
262
+
263
+ # for key in keys.intersect(fields):
264
+ # getattr(update, key)
265
+
266
+
267
+
268
@dispatch
def generalize(current: list, update: list):
    """For two lists, return the contents of ``update`` as a tuple."""
    replacement = tuple(update)
    return replacement
271
+
272
+
273
@dispatch
def generalize(current: String, update: Node):
    """Generalize a String with another schema: the other schema wins.

    ``update`` is modified in place: a truthy default on ``current``
    overwrites ``update._default``.
    """
    carried_default = current._default
    if carried_default:
        update._default = carried_default
    return update
278
+
279
@dispatch
def generalize(current: String, update: Wrap):
    """Unwrap ``update`` and generalize the String with its inner schema."""
    inner = update._value
    return generalize(current, inner)
282
+
283
@dispatch
def generalize(current: String, update: String):
    """Pick between two String schemas based on their defaults.

    ``current`` is kept only when it has a truthy default and ``update``
    does not; in every other case ``update`` is returned.
    """
    keep_current = current._default and not update._default
    if keep_current:
        return current
    return update
289
+
290
@dispatch
def generalize(current: Node, update: String):
    """Generalize a schema with a String: the existing schema wins.

    ``current`` is modified in place: a truthy default on ``update``
    overwrites ``current._default``.
    """
    carried_default = update._default
    if carried_default:
        current._default = carried_default
    return current
295
+
296
@dispatch
def generalize(current: Empty, update: Empty):
    """Generalizing two ``Empty`` schemas yields ``update``."""
    return update
299
+
300
@dispatch
def generalize(current, update):
    """Fallback for arguments no other overload handles.

    If one side is falsy (including ``None``), the other side is
    returned, with ``current`` checked first.

    Raises:
        Exception: when both sides are truthy.
    """
    # ``not x`` is already True for None, so no separate None check
    if not current:
        return update
    if not update:
        return current
    raise Exception(f'\ncannot generalize types, not schemas:\n{current}\n{update}\n')
308
+
309
+
@@ -0,0 +1,182 @@
1
+ from plum import dispatch
2
+ import numpy as np
3
+ import numpy.lib.format as nf
4
+
5
+ from types import NoneType
6
+ from dataclasses import replace
7
+
8
+ from bigraph_schema.schema import (
9
+ Node,
10
+ Union,
11
+ Tuple,
12
+ Boolean,
13
+ Number,
14
+ Integer,
15
+ Float,
16
+ Complex,
17
+ Delta,
18
+ Nonnegative,
19
+ NPRandom,
20
+ String,
21
+ Enum,
22
+ Wrap,
23
+ Maybe,
24
+ Overwrite,
25
+ List,
26
+ Map,
27
+ Tree,
28
+ Array,
29
+ Key,
30
+ Path,
31
+ Wires,
32
+ Schema,
33
+ Link,
34
+ schema_dtype,
35
+ )
36
+
37
+ # aligning parameters takes them from positioned arguments and gives them keys
38
+ # in a dict.
39
+
40
+ # reifying the schema takes a dict of representations and turns them into schemas
41
+ # according to the parameters
42
+
43
+ # handling parameters combines these operations to go from positioned arguments to schemas
44
+
45
+ # we need aligning when parsing, but only reifying when inferring from state
46
+ # hence the distinction here
47
+
48
def schema_keys(schema):
    """List the underscore-prefixed dataclass field names of ``schema``.

    Names are returned in the order they appear in
    ``schema.__dataclass_fields__``.
    """
    return [
        key
        for key in schema.__dataclass_fields__
        if key.startswith('_')]
55
+
56
@dispatch
def align_parameters(schema: Tuple, parameters):
    """Place all positional parameters of a Tuple under ``'_values'``."""
    aligned = {'_values': parameters}
    return aligned
60
+
61
@dispatch
def align_parameters(schema: Enum, parameters):
    """Place all positional parameters of an Enum under ``'_values'``."""
    aligned = {'_values': parameters}
    return aligned
65
+
66
@dispatch
def align_parameters(schema: Union, parameters):
    """Place all positional parameters of a Union under ``'_options'``."""
    aligned = {'_options': parameters}
    return aligned
70
+
71
@dispatch
def align_parameters(schema: Map, parameters):
    """Give keys to the positional parameters of a Map.

    One parameter is the value schema (``'_value'``); two parameters are
    the key and value schemas (``'_key'``, ``'_value'``). Any other
    count yields an empty dict.
    """
    count = len(parameters)

    if count == 1:
        return {'_value': parameters[0]}

    if count == 2:
        key_schema, value_schema = parameters
        return {'_key': key_schema, '_value': value_schema}

    return {}
81
+
82
@dispatch
def align_parameters(schema: Array, parameters):
    """Give keys to the positional parameters of an Array.

    The first parameter becomes ``'_shape'`` and the second ``'_data'``;
    both must be present.
    """
    shape = parameters[0]
    data = parameters[1]
    return {'_shape': shape, '_data': data}
87
+
88
@dispatch
def align_parameters(schema: Link, parameters):
    """Give keys to the positional parameters of a Link.

    The first parameter becomes ``'_inputs'`` and the second
    ``'_outputs'``; both must be present.
    """
    input_schema = parameters[0]
    output_schema = parameters[1]
    return {'_inputs': input_schema, '_outputs': output_schema}
95
+
96
@dispatch
def align_parameters(schema: Node, parameters):
    """Pair positional parameters with a node's schema keys.

    The first entry returned by ``schema_keys`` is skipped and the
    remaining keys are matched to ``parameters`` in order; surplus keys
    or parameters are ignored.
    """
    assignable_keys = schema_keys(schema)[1:]
    return dict(zip(assignable_keys, parameters))
103
+
104
@dispatch
def align_parameters(schema, parameters):
    """Fallback: no overload can align parameters for this schema.

    Raises:
        Exception: always, naming the schema and the parameters.
    """
    message = f'unknown parameters for schema {schema}: {parameters}'
    raise Exception(message)
107
+
108
@dispatch
def reify_schema(core, schema: Enum, parameters):
    """Apply aligned parameters to an Enum schema in place.

    If ``parameters`` contains ``'_values'``, it is stored on
    ``schema._values`` as given. The same schema object is returned.
    """
    try:
        schema._values = parameters['_values']
    except KeyError:
        # no values supplied: leave the schema untouched
        pass
    return schema
113
+
114
@dispatch
def reify_schema(core, schema: Array, parameters):
    """Apply aligned parameters to an Array schema.

    ``'_shape'`` (default ``(1,)``) is converted to a tuple of ints and
    stored on the schema. ``'_data'`` (default ``'float'``) is turned
    into a dtype with ``schema_dtype``. If that result is itself an
    ``Array``, its shape is appended to this schema's shape on a copy;
    otherwise it is stored as ``schema._data``.

    A leftover ``ipdb`` breakpoint for shapes containing ``'|'`` has
    been replaced by an explicit error. Such a shape failed anyway with
    ``ValueError`` during the ``int`` conversion once the debugger
    resumed.

    Raises:
        ValueError: if the shape contains ``'|'`` or an entry that
            cannot be converted to ``int``.
    """
    shape = parameters.get('_shape', (1,))
    if '|' in shape:
        raise ValueError(
            f'array shape must be a sequence of integers, not {shape!r}')

    schema._shape = tuple(int(value) for value in shape)

    data = parameters.get('_data', 'float')
    # NOTE(review): the accessed schema is not used below; the call is
    # kept in case ``core.access`` validates or registers ``data``.
    # Confirm whether ``schema_dtype`` should receive its result instead.
    core.access(data)

    dtype = schema_dtype(data)
    if isinstance(dtype, Array):
        # nested array data: extend this array's shape with the inner one
        schema = replace(schema, **{'_shape': schema._shape + dtype._shape})
    else:
        schema._data = dtype

    return schema
135
+
136
@dispatch
def reify_schema(core, schema: Union, parameters):
    """Return a copy of the Union schema with ``parameters`` applied as fields."""
    updated = replace(schema, **parameters)
    return updated
139
+
140
+
141
def reify_schema_link(core, schema, parameters):
    """Apply aligned parameters to a link schema in place.

    For each of ``address``, ``config``, ``inputs``, ``outputs``,
    ``_inputs`` and ``_outputs`` that appears in ``parameters``, the
    value is passed through ``core.access`` and stored on the attribute
    of the same name. Other keys are ignored. The same schema object is
    returned.
    """
    link_keys = (
        'address',
        'config',
        'inputs',
        'outputs',
        '_inputs',
        '_outputs')

    for key in link_keys:
        if key in parameters:
            setattr(schema, key, core.access(parameters[key]))

    return schema
156
+
157
@dispatch
def reify_schema(core, schema: Link, parameters):
    """Dispatch entry for Link schemas; delegates to ``reify_schema_link``."""
    reified = reify_schema_link(core, schema, parameters)
    return reified
160
+
161
@dispatch
def reify_schema(core, schema: Node, parameters):
    """Apply aligned parameters to a generic node schema in place.

    Each parameter is passed through ``core.access``. When the schema
    already has an attribute of that name, the accessed value is
    combined with the existing one via ``core.resolve``; otherwise it is
    used directly. The outcome is set on the schema, which is returned.
    """
    for key, parameter in parameters.items():
        accessed = core.access(parameter)

        if hasattr(schema, key):
            existing = getattr(schema, key)
            resolved = core.resolve(existing, accessed)
        else:
            resolved = accessed

        setattr(schema, key, resolved)

    return schema
175
+
176
@dispatch
def reify_schema(core, schema, parameters):
    """Fallback: no overload can reify this schema.

    This previously only opened an ``ipdb`` breakpoint and then returned
    ``None``. It now fails explicitly, in the same way as the
    ``align_parameters`` fallback.

    Raises:
        Exception: always, naming the schema and the parameters.
    """
    raise Exception(
        f'cannot reify schema {schema} with parameters: {parameters}')
179
+
180
def handle_parameters(core, schema, parameters):
    """Go from positional parameters to a reified schema.

    The parameters are first given keys with ``align_parameters`` and
    the resulting dict is then applied with ``reify_schema``.
    """
    aligned = align_parameters(schema, parameters)
    reified = reify_schema(core, schema, aligned)
    return reified