nnodely 0.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. mplplots/__init__.py +0 -0
  2. mplplots/plots.py +131 -0
  3. nnodely/__init__.py +42 -0
  4. nnodely/activation.py +85 -0
  5. nnodely/arithmetic.py +203 -0
  6. nnodely/earlystopping.py +81 -0
  7. nnodely/exporter/__init__.py +3 -0
  8. nnodely/exporter/export.py +275 -0
  9. nnodely/exporter/exporter.py +45 -0
  10. nnodely/exporter/reporter.py +48 -0
  11. nnodely/exporter/standardexporter.py +108 -0
  12. nnodely/fir.py +150 -0
  13. nnodely/fuzzify.py +221 -0
  14. nnodely/initializer.py +31 -0
  15. nnodely/input.py +131 -0
  16. nnodely/linear.py +130 -0
  17. nnodely/localmodel.py +82 -0
  18. nnodely/logger.py +94 -0
  19. nnodely/loss.py +30 -0
  20. nnodely/model.py +263 -0
  21. nnodely/modeldef.py +205 -0
  22. nnodely/nnodely.py +1295 -0
  23. nnodely/optimizer.py +91 -0
  24. nnodely/output.py +23 -0
  25. nnodely/parameter.py +103 -0
  26. nnodely/parametricfunction.py +329 -0
  27. nnodely/part.py +201 -0
  28. nnodely/relation.py +149 -0
  29. nnodely/trigonometric.py +67 -0
  30. nnodely/utils.py +101 -0
  31. nnodely/visualizer/__init__.py +4 -0
  32. nnodely/visualizer/dynamicmpl/functionplot.py +34 -0
  33. nnodely/visualizer/dynamicmpl/fuzzyplot.py +31 -0
  34. nnodely/visualizer/dynamicmpl/resultsplot.py +28 -0
  35. nnodely/visualizer/dynamicmpl/trainingplot.py +46 -0
  36. nnodely/visualizer/mplnotebookvisualizer.py +66 -0
  37. nnodely/visualizer/mplvisualizer.py +215 -0
  38. nnodely/visualizer/textvisualizer.py +320 -0
  39. nnodely/visualizer/visualizer.py +84 -0
  40. nnodely-0.14.0.dist-info/LICENSE +21 -0
  41. nnodely-0.14.0.dist-info/METADATA +401 -0
  42. nnodely-0.14.0.dist-info/RECORD +44 -0
  43. nnodely-0.14.0.dist-info/WHEEL +5 -0
  44. nnodely-0.14.0.dist-info/top_level.txt +2 -0
nnodely/logger.py ADDED
@@ -0,0 +1,94 @@
1
+ import logging
2
+ import sys
3
+
4
+ from nnodely import LOG_LEVEL
5
+
6
# Indexes of the eight basic ANSI colors: the foreground code of a color is 30 + index,
# the background code is 40 + index.
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)

# ANSI escape sequences needed to produce colored output (%d takes a color code).
RESET_SEQ = "\033[0m"
COLOR_SEQ = "\033[%dm"
COLOR_BOLD_SEQ = "\033[1;%dm"
BOLD_SEQ = "\033[1m"

# A level above CRITICAL: a logger set to it lets no standard record through.
SUPPRESS = logging.CRITICAL + 10
# NOTE(review): this runs at import time and raises the *root* logger's level to SUPPRESS, so any
# logger of the host application that inherits its level from the root is silenced.
logging.getLogger().setLevel(SUPPRESS)

# (level, color index, display name) in the order DEBUG, INFO, WARNING, CRITICAL, ERROR.
_LEVEL_TABLE = (
    (logging.DEBUG, MAGENTA, "DEBUG"),
    (logging.INFO, BLUE, "INFO"),
    (logging.WARNING, YELLOW, "WARNING"),
    (logging.CRITICAL, RED, "CRITICAL"),
    (logging.ERROR, RED, "ERROR"),
)
COLORS = {lvl: color for lvl, color, _ in _LEVEL_TABLE}
LEVEL_STRING = {lvl: label for lvl, _, label in _LEVEL_TABLE}
LEVEL_STRING[SUPPRESS] = "SUPPRESS"
33
+
34
+
35
+
36
class JsonFormatter(logging.Formatter):
    """Console formatter that picks a layout by level and wraps the line in an ANSI color.

    Despite its name it emits plain colored text, not JSON:
    WARNING records use FORMAT_WARNING, INFO records use FORMAT_INFO, every other level uses FORMAT.
    Levels without an entry in COLORS (custom levels) are emitted without color.
    """
    FORMAT = "[%(levelname)s][%(name)s:%(filename)s:%(funcName)s:%(lineno)d] %(message)s" # + ""
    FORMAT_WARNING = "[%(funcName)s] %(message)s"
    FORMAT_INFO = "%(message)s"

    def __init__(self):
        logging.Formatter.__init__(self, self.FORMAT)
        # One delegate formatter per special layout, so format() never mutates shared state
        # (the previous implementation rewrote self._style._fmt on every call).
        self._level_formatters = {
            logging.WARNING: logging.Formatter(self.FORMAT_WARNING),
            logging.INFO: logging.Formatter(self.FORMAT_INFO),
        }

    def format(self, record):
        """Return ``record`` rendered with its level's layout and wrapped in that level's color."""
        delegate = self._level_formatters.get(record.levelno)
        if delegate is not None:
            result = delegate.format(record)
        else:
            result = logging.Formatter.format(self, record)
        color = COLORS.get(record.levelno)
        if color is None:
            # Unknown/custom level: no color entry, leave the text uncolored instead of raising KeyError
            return result
        return COLOR_SEQ % (30 + color) + result + RESET_SEQ
53
+
54
+
55
# Custom logger class writing colored records to stdout
class nnLogger(logging.Logger):
    """Logger with a colored stdout handler whose level can be switched for all instances at once.

    ``levels``, ``loggers`` and ``params`` are class attributes shared by every nnLogger:
    ``loggers`` lists every instance created, ``levels`` the level each one was created with
    (same order), and ``params['level']`` is the level forced by setAllLevel (None if none).
    """
    levels = []
    loggers = []
    params = {'level': None}

    def __init__(self, name, level):
        logging.Logger.__init__(self, name)
        # Never go below the package-wide LOG_LEVEL floor
        effective_level = max(level, LOG_LEVEL)
        self.setLevel(effective_level)

        self.console = logging.StreamHandler(sys.stdout)
        self.console.setFormatter(JsonFormatter())
        self.addHandler(self.console)

        self.loggers.append(self)
        # Store the clamped level so resetAllLevel restores exactly what the constructor applied
        self.levels.append(effective_level)

    def setAllLevel(self, level):
        """Force every nnLogger to ``level``; a banner is logged only when the forced level changes."""
        if self.params['level'] != level:
            label = LEVEL_STRING.get(level, logging.getLevelName(level))
            # _log bypasses the isEnabledFor check, so the banner shows regardless of this logger's level
            self._log(logging.INFO,
                      COLOR_SEQ % (30 + BLUE) + (f" Loggers to {label} ").center(80, '=') + RESET_SEQ, None)
            self.params['level'] = level
        # Always (re)apply, so loggers created after a previous call are covered too
        for logger in self.loggers:
            logger.setLevel(level)

    def resetAllLevel(self):
        """Restore every nnLogger to the level it was created with and clear the forced level."""
        if self.params['level'] is not None:
            self._log(logging.INFO, COLOR_SEQ % (30 + BLUE) + (" Standard Level Log ").center(80, '=') + RESET_SEQ, None)
            self.params['level'] = None
        for logger, level in zip(self.loggers, self.levels):
            logger.setLevel(level)
93
+
94
+
nnodely/loss.py ADDED
@@ -0,0 +1,30 @@
1
+ import torch.nn as nn
2
+ import torch
3
+ from nnodely.utils import check
4
+
5
# Loss names accepted by CustomLoss(loss_type=...)
available_losses = ['mse', 'rmse', 'mae']


class CustomLoss(nn.Module):
    """Loss module selected by name.

    'mse' -> mean squared error, 'mae' -> mean absolute error (L1),
    'rmse' -> square root of the mean squared error.
    """

    def __init__(self, loss_type='mse'):
        super().__init__()
        check(loss_type in available_losses, TypeError, f'The \"{loss_type}\" loss is not available. Possible losses are: {available_losses}.')
        self.loss_type = loss_type
        # Only 'mae' uses the L1 criterion; 'mse' and 'rmse' both start from the squared error
        self.loss = nn.L1Loss() if loss_type == 'mae' else nn.MSELoss()

    def forward(self, inA, inB):
        """Return the configured loss between ``inA`` and ``inB``."""
        value = self.loss(inA, inB)
        return torch.sqrt(value) if self.loss_type == 'rmse' else value
nnodely/model.py ADDED
@@ -0,0 +1,263 @@
1
+ from itertools import product
2
+ import numpy as np
3
+
4
+ import torch.nn as nn
5
+ import torch
6
+
7
+ import copy
8
+
9
class Model(nn.Module):
    """Torch module that executes an nnodely network described by a JSON-like model definition.

    ``model_def`` must provide 'Inputs', 'Outputs', 'Relations', 'Parameters', 'Constants',
    'Info' (with 'SampleTime'), 'Functions', 'States' and 'Minimizers'. It is deep-copied, so the
    caller's dictionary is never modified (relation windows are rewritten on the copy).

    Every relation ``[rel_name, input_names, *args]`` is built by calling the attribute ``rel_name``
    of this object with the resolved ``args``; the result is stored in ``relation_forward``.
    """

    def __init__(self, model_def):
        super(Model, self).__init__()
        model_def = copy.deepcopy(model_def)
        self.inputs = model_def['Inputs']
        self.outputs = model_def['Outputs']
        self.relations = model_def['Relations']
        self.params = model_def['Parameters']
        self.constants = model_def['Constants']
        self.sample_time = model_def['Info']['SampleTime']
        self.functions = model_def['Functions']
        self.state_model_main = model_def['States']
        self.minimizers = model_def['Minimizers']
        self.state_model = copy.deepcopy(self.state_model_main)
        inputs_and_states = model_def['Inputs'] | model_def['States']
        # Backward sample count and total window length of every input and state
        self.input_ns_backward = {key: value['ns'][0] for key, value in inputs_and_states.items()}
        self.input_n_samples = {key: value['ntot'] for key, value in inputs_and_states.items()}
        # Names compared by the minimizers: all the 'A' operands first, then all the 'B' operands
        self.minimizers_keys = ([value['A'] for value in self.minimizers.values()]
                                + [value['B'] for value in self.minimizers.values()])

        ## Build the network
        self.all_parameters = {}
        self.all_constants = {}
        self.relation_forward = {}
        self.relation_inputs = {}
        self.states = {}
        self.states_closed_loop = {}
        self.states_connect = {}

        self.connect_variables = {}
        self.connect = {}
        self.initialize_connect = False

        self._convert_relation_windows()
        self._create_parameters()

        ## Create all the constants (registered as non-trainable parameters)
        for name, param_data in self.constants.items():
            self.all_constants[name] = nn.Parameter(torch.tensor(param_data['values'], dtype=torch.float32), requires_grad=False)

        ## Initialize state variables
        self.init_states(self.state_model_main, reset_states=True)

        self._create_relations()

        ## Register relations, constants and parameters with the module
        self.relation_forward = nn.ParameterDict(self.relation_forward)
        self.all_constants = nn.ParameterDict(self.all_constants)
        self.all_parameters = nn.ParameterDict(self.all_parameters)

        ## Relations behind the network outputs
        self.network_output_predictions = set(self.outputs.values())

        ## Relations (or raw names) used by the minimizers; output names are mapped to their relation
        self.network_output_minimizers = {
            self.outputs.get(value[side], value[side])
            for value in self.minimizers.values()
            for side in ('A', 'B')
        }

        ## Everything forward() has to compute
        self.network_outputs = self.network_output_predictions.union(self.network_output_minimizers)

    def _convert_relation_windows(self):
        """Rewrite, in place, SamplePart/TimePart windows of ``self.relations`` as sample indexes.

        When the sliced variable is an input or a state the indexes are shifted by its backward sample
        count; TimePart bounds (and offsets) are converted from time to samples with ``round``.
        """
        json_inputs = self.inputs | self.state_model
        for items in self.relations.values():
            if items[0] == 'SamplePart':
                if items[1][0] in json_inputs:
                    backward = self.input_ns_backward[items[1][0]]
                    items[2][0] = backward + items[2][0]
                    items[2][1] = backward + items[2][1]
                    if len(items) > 3:  ## Offset
                        items[3] = backward + items[3]
            if items[0] == 'TimePart':
                if items[1][0] in json_inputs:
                    backward = self.input_ns_backward[items[1][0]]
                    items[2][0] = backward + round(items[2][0] / self.sample_time)
                    items[2][1] = backward + round(items[2][1] / self.sample_time)
                    if len(items) > 3:  ## Offset
                        items[3] = backward + round(items[3] / self.sample_time)
                else:
                    items[2][0] = round(items[2][0] / self.sample_time)
                    items[2][1] = round(items[2][1] / self.sample_time)
                    if len(items) > 3:  ## Offset
                        items[3] = round(items[3] / self.sample_time)

    def _create_parameters(self):
        """Create a trainable ``nn.Parameter`` in ``self.all_parameters`` for every model parameter.

        Values come from 'values' if present, else from the 'init_fun' source code evaluated element by
        element, else from ``torch.rand``. The window ('tw' in time, 'sw' in samples) sets the first dim.
        """
        for name, param_data in self.params.items():
            window = 'tw' if 'tw' in param_data else ('sw' if 'sw' in param_data else None)
            aux_sample_time = self.sample_time if window == 'tw' else 1
            sample_window = round(param_data[window] / aux_sample_time) if window else 1
            dim = param_data['dim']
            param_size = (sample_window,) + tuple(dim) if isinstance(dim, (list, tuple)) else (sample_window, dim)
            if 'values' in param_data:
                self.all_parameters[name] = nn.Parameter(torch.tensor(param_data['values'], dtype=torch.float32), requires_grad=True)
            elif 'init_fun' in param_data:
                init_fun = param_data['init_fun']
                # SECURITY: exec runs source code stored in the model definition; only load model
                # definitions from trusted sources. A copy of the module globals is used so the
                # user function can still see np/torch/etc. without overwriting module names.
                namespace = dict(globals())
                exec(init_fun['code'], namespace)
                function_to_call = namespace[init_fun['name']]
                values = np.zeros(param_size)
                for indexes in product(*(range(v) for v in param_size)):
                    if 'params' in init_fun:
                        values[indexes] = function_to_call(indexes, param_size, init_fun['params'])
                    else:
                        values[indexes] = function_to_call(indexes, param_size)
                self.all_parameters[name] = nn.Parameter(torch.tensor(values.tolist(), dtype=torch.float32), requires_grad=True)
            else:
                self.all_parameters[name] = nn.Parameter(torch.rand(size=param_size, dtype=torch.float32), requires_grad=True)

    def _create_relations(self):
        """Instantiate every relation into ``relation_forward`` and record its inputs in ``relation_inputs``.

        Raises:
            AttributeError: if this object has no attribute named after a relation.
        """
        all_params_and_consts = self.all_parameters | self.all_constants
        # Lists (not dicts) on purpose: relation arguments may be unhashable (e.g. [start, end] windows)
        param_names = list(self.params.keys())
        const_names = list(self.constants.keys())
        function_names = list(self.functions.keys())
        for relation, inputs in self.relations.items():
            rel_name = inputs[0]
            func = getattr(self, rel_name, None)
            if not func:
                raise AttributeError(f"Key Error: [{rel_name}] Relation not defined")
            layer_inputs = []
            for item in inputs[2:]:
                if item in param_names:  ## the relation takes parameters
                    layer_inputs.append(self.all_parameters[item])
                elif item in const_names:  ## the relation takes constants
                    layer_inputs.append(self.all_constants[item])
                elif item in function_names:  ## the relation takes a custom function
                    function = self.functions[item]
                    layer_inputs.append(function)
                    if 'params_and_consts' in function:  ## parametric function: pass its parameters/constants
                        layer_inputs.append([all_params_and_consts[par] for par in function['params_and_consts']])
                    if 'map_over_dim' in function:
                        layer_inputs.append(function['map_over_dim'])
                else:
                    layer_inputs.append(item)

            self.relation_forward[relation] = func(*layer_inputs)
            self.relation_inputs[relation] = inputs[1]

    def forward(self, kwargs):
        """Run one step of the network.

        Args:
            kwargs: dict mapping each non-connected input name to its tensor (presumably
                (batch, window, dim) like the state buffers — TODO confirm against the caller).

        Returns:
            (output_dict, minimize_dict): the value of every model output and of every minimizer operand.

        Raises:
            RuntimeError: if some required output can never be computed from the available values.
        """
        result_dict = {}

        ## Initially only dataset inputs, parameters, constants and non-connected states are known
        available_inputs = {key for key in self.inputs.keys() if key not in self.connect.keys()}
        available_states = {key for key in self.state_model.keys() if key not in self.states_connect.keys()}
        available_keys = available_inputs | set(self.all_parameters.keys()) | set(self.all_constants.keys()) | available_states

        ## Climb the relation tree until every required output is available
        while not self.network_outputs.issubset(available_keys):
            n_known = len(available_keys)
            for relation in self.relations.keys():
                ## A relation runs once all its inputs are known and it has not run yet
                if not set(self.relation_inputs[relation]).issubset(available_keys) or relation in available_keys:
                    continue
                layer_inputs = []
                for key in self.relation_inputs[relation]:
                    if key in self.all_constants.keys():
                        layer_inputs.append(self.all_constants[key])
                    elif key in self.states.keys():
                        layer_inputs.append(self.states[key])
                    elif key in available_inputs:
                        layer_inputs.append(kwargs[key])
                    elif key in self.all_parameters.keys():
                        layer_inputs.append(self.all_parameters[key])
                    else:  ## output of another relation or a connect variable
                        layer_inputs.append(result_dict[key])

                result_dict[relation] = self.relation_forward[relation](*layer_inputs)
                available_keys.add(relation)

                ## Slide the new value into every input connected to this relation's output
                for connect_in, connect_out in self.connect.items():
                    if relation == self.outputs[connect_out]:
                        shift = result_dict[relation].shape[1]
                        self.connect_variables[connect_in] = torch.roll(self.connect_variables[connect_in], shifts=-1, dims=1)
                        self.connect_variables[connect_in][:, -shift:, :] = result_dict[relation]
                        result_dict[connect_in] = self.connect_variables[connect_in].clone()
                        available_keys.add(connect_in)

                ## Slide the new value into every state connected to this relation
                for state in [key for key, value in self.states_connect.items() if value == relation]:
                    shift = result_dict[relation].shape[1]
                    self.states[state] = torch.roll(self.states[state], shifts=-1, dims=1)
                    self.states[state][:, -shift:, :] = result_dict[relation]  # TODO: detach??
                    available_keys.add(state)

            if len(available_keys) == n_known:
                # A full pass produced nothing new: looping again would never terminate
                missing = sorted(self.network_outputs - available_keys)
                raise RuntimeError(f"The network outputs {missing} cannot be computed: their relations depend on values that are never produced.")

        ## Update closed loop states with this step's relation values
        for relation in self.relations.keys():
            if relation in self.states_closed_loop.values():
                for state in [key for key, value in self.states_closed_loop.items() if value == relation]:
                    shift = result_dict[relation].shape[1]
                    self.states[state] = torch.roll(self.states[state], shifts=-1, dims=1)
                    self.states[state][:, -shift:, :] = result_dict[relation]  # TODO: detach??

        output_dict = {key: result_dict[value] for key, value in self.outputs.items()}
        minimize_dict = {key: result_dict[self.outputs.get(key, key)] for key in self.minimizers_keys}
        return output_dict, minimize_dict

    def init_states(self, state_model, connect=None, reset_states=False):
        """Configure how states are updated and reset the connect variables.

        Args:
            state_model: dict of state name -> definition holding either 'connect' or 'closedLoop'.
            connect: optional dict of input name -> output name feeding that input (deep-copied).
            reset_states: if True, zero every state buffer first.
        """
        if reset_states:
            self.reset_states()
        self.reset_connect_variables(copy.deepcopy(connect) if connect is not None else {}, only=False)
        self.states_connect = {}
        self.states_closed_loop = {}
        for state, param in state_model.items():
            if 'connect' in param:
                self.states_connect[state] = param['connect']
            else:
                self.states_closed_loop[state] = param['closedLoop']

    def _zero_window(self, key, dim, batch):
        """Return a (batch, window of ``key``, dim) float32 zero tensor that does not require grad."""
        return torch.zeros(size=(batch, self.input_n_samples[key], dim), dtype=torch.float32, requires_grad=False)

    def reset_connect_variables(self, connect, values=None, only=True):
        """(Re)initialize the buffers of connected inputs.

        Args:
            connect: dict of input name -> output name.
            values: optional dict of input name -> tensor used as the new buffer (cloned).
            only: if False, also replace ``self.connect`` and zero every buffer not given in ``values``.
        """
        if not only:
            self.connect = connect
            self.connect_variables = {}
        self.initialize_connect = True
        for key in connect.keys():
            if values is not None and key in values:
                self.connect_variables[key] = values[key].clone()
            elif not only:
                # Batch size follows the first provided value; 1 when nothing (or an empty dict) is given
                batch = next(iter(values.values())).shape[0] if values else 1
                self.connect_variables[key] = self._zero_window(key, self.inputs[key]['dim'], batch)

    def reset_states(self, values=None, only=True):
        """Reset state buffers.

        Args:
            values: None -> zero every state; a set of names -> zero those states (keeping their batch
                size); a dict of name -> tensor -> use a detached copy of the tensor.
            only: with a dict, if False also zero the states missing from ``values``.
        """
        if values is None or isinstance(values, (set, frozenset)):
            for key, state in self.state_model.items():
                if values is None or key in values:
                    batch = self.states[key].shape[0] if key in self.states else 1
                    self.states[key] = self._zero_window(key, state['dim'], batch)
            return

        for key, state in self.state_model.items():
            if key in values:
                # detach() so a grad-tracking tensor can be used (setting requires_grad on a non-leaf raises)
                self.states[key] = values[key].detach().clone()
            elif not only:
                if values:
                    batch = next(iter(values.values())).shape[0]
                else:
                    batch = self.states[key].shape[0] if key in self.states else 1
                self.states[key] = self._zero_window(key, state['dim'], batch)
nnodely/modeldef.py ADDED
@@ -0,0 +1,205 @@
1
+ import copy
2
+
3
+ import numpy as np
4
+
5
+ from nnodely.input import closedloop_name, connect_name
6
+ from nnodely.utils import check, merge
7
+ from nnodely.relation import MAIN_JSON, Stream
8
+ from nnodely.output import Output
9
+
10
+ from nnodely.logger import logging, nnLogger
11
# Module-level logger for model-definition messages, created at INFO level
log = nnLogger(name=__name__, level=logging.INFO)
12
+
13
class ModelDef():
    """Incrementally built nnodely model definition.

    ``self.json`` is the current definition dictionary. Each ``add*``/``remove*`` method records its
    change in ``model_dict``, ``minimize_dict`` or ``update_state_dict`` and then calls :meth:`update`,
    which rebuilds ``self.json`` from ``json_base`` plus those records.
    """

    def __init__(self, model_def = MAIN_JSON):
        # Definition every update() starts from (deep-copied, so the shared default is never mutated)
        self.json_base = copy.deepcopy(model_def)

        # Working model definition
        self.json = copy.deepcopy(self.json_base)
        self.sample_time = self.json['Info'].get("SampleTime")
        self.model_dict = {}
        self.minimize_dict = {}
        self.update_state_dict = {}

    def __contains__(self, key):
        return key in self.json

    def __getitem__(self, key):
        return self.json[key]

    def __setitem__(self, key, value):
        self.json[key] = value

    def isDefined(self):
        """Return True when a definition dictionary is present."""
        return self.json is not None

    def update(self, model_def = None, model_dict = None, minimize_dict = None, update_state_dict = None):
        """Rebuild ``self.json``.

        Each argument left as None falls back to ``json_base`` / the stored dict; arguments that are
        given are deep-copied and used for this rebuild only (they are not stored).
        """
        self.json = copy.deepcopy(model_def) if model_def is not None else copy.deepcopy(self.json_base)
        model_dict = copy.deepcopy(model_dict) if model_dict is not None else self.model_dict
        minimize_dict = copy.deepcopy(minimize_dict) if minimize_dict is not None else self.minimize_dict
        update_state_dict = copy.deepcopy(update_state_dict) if update_state_dict is not None else self.update_state_dict

        # Add models to the model_def
        for stream_list in model_dict.values():
            for stream in stream_list:
                self.json = merge(self.json, stream.json)
        if len(model_dict) > 1:
            # A previous single-model definition stores 'Models' as a plain name, which cannot be extended
            if not isinstance(self.json.get('Models'), dict):
                self.json['Models'] = {}
            for model_name, model_params in model_dict.items():
                parameters, constants, inputs, states = set(), set(), set(), set()
                outputs = []
                for param in model_params:
                    outputs.append(param.name)
                    parameters |= set(param.json['Parameters'].keys())
                    constants |= set(param.json['Constants'].keys())
                    inputs |= set(param.json['Inputs'].keys())
                    states |= set(param.json['States'].keys())
                self.json['Models'][model_name] = {'Inputs': list(inputs), 'States': list(states), 'Outputs': outputs,
                                                   'Parameters': list(parameters), 'Constants': list(constants)}
        elif len(model_dict) == 1:
            # A single model is stored by name only
            self.json['Models'] = next(iter(model_dict))

        if 'Minimizers' not in self.json:
            self.json['Minimizers'] = {}
        for key, minimize in minimize_dict.items():
            self.json = merge(self.json, minimize['A'].json)
            self.json = merge(self.json, minimize['B'].json)
            self.json['Minimizers'][key] = {'A': minimize['A'].name, 'B': minimize['B'].name, 'loss': minimize['loss']}

        for update_state in update_state_dict.values():
            self.json = merge(self.json, update_state.json)

        if "SampleTime" in self.json['Info']:
            self.sample_time = self.json['Info']["SampleTime"]

    def __update_state(self, stream_out, state_list_in, UpdateState):
        """Register ``UpdateState(stream_out, state)`` for each State in ``state_list_in``.

        Raises:
            TypeError: if stream_out is not an Output/Stream or a target is not a State.
            ValueError: if the dimensions of stream_out and a state differ.
        """
        from nnodely.input import State
        if type(state_list_in) is not list:
            state_list_in = [state_list_in]
        for state_in in state_list_in:
            check(isinstance(stream_out, (Output, Stream)), TypeError,
                  f"The {stream_out} must be a Stream or Output and not a {type(stream_out)}.")
            check(type(state_in) is State, TypeError,
                  f"The {state_in} must be a State and not a {type(state_in)}.")
            check(stream_out.dim['dim'] == state_in.dim['dim'], ValueError,
                  f"The dimension of {stream_out.name} is not equal to the dimension of {state_in.name} ({stream_out.dim['dim']}!={state_in.dim['dim']}).")
            if type(stream_out) is Output:
                # Update through the relation behind the output
                stream_name = self.json['Outputs'][stream_out.name]
                stream_out = Stream(stream_name, stream_out.json, stream_out.dim, 0)
            self.update_state_dict[state_in.name] = UpdateState(stream_out, state_in)

    def addConnect(self, stream_out, state_list_in):
        from nnodely.input import Connect
        self.__update_state(stream_out, state_list_in, Connect)
        self.update()

    def addClosedLoop(self, stream_out, state_list_in):
        from nnodely.input import ClosedLoop
        self.__update_state(stream_out, state_list_in, ClosedLoop)
        self.update()

    def addModel(self, name, stream_list):
        """Register (a deep copy of) an Output/Stream or a list of them as the model ``name``."""
        if isinstance(stream_list, (Output, Stream)):
            stream_list = [stream_list]
        if type(stream_list) is list:
            self.model_dict[name] = copy.deepcopy(stream_list)
        else:
            raise TypeError(f'stream_list is type {type(stream_list)} but must be an Output or Stream or a list of them')
        self.update()

    def removeModel(self, name_list):
        """Remove one model name or a list/tuple/set of names.

        Raises:
            TypeError: if name_list is not a str or a list/tuple/set of names.
            IndexError: if a name is unknown; nothing is removed in that case.
        """
        if isinstance(name_list, str):
            name_list = [name_list]
        check(isinstance(name_list, (list, tuple, set)), TypeError,
              f'name_list is type {type(name_list)} but must be a str or a list of str')
        # Validate every name before deleting anything so a bad name leaves model_dict untouched
        for name in name_list:
            check(name in self.model_dict, IndexError, f"The name {name} is not part of the available models")
        for name in dict.fromkeys(name_list):
            del self.model_dict[name]
        self.update()

    def addMinimize(self, name, streamA, streamB, loss_function='mse'):
        """Register a minimizer comparing streamA and streamB with ``loss_function``.

        Raises:
            TypeError: if a stream is not an Output or Stream.
            ValueError: if the two streams have different dimensions.
        """
        check(isinstance(streamA, (Output, Stream)), TypeError, 'streamA must be an instance of Output or Stream')
        check(isinstance(streamB, (Output, Stream)), TypeError, 'streamB must be an instance of Output or Stream')
        check(streamA.dim == streamB.dim, ValueError, f'Dimension of streamA={streamA.dim} and streamB={streamB.dim} are not equal.')
        self.minimize_dict[name] = {'A': copy.deepcopy(streamA), 'B': copy.deepcopy(streamB), 'loss': loss_function}
        self.update()

    def removeMinimize(self, name_list):
        """Remove one minimizer name or a list/tuple/set of names.

        Raises:
            TypeError: if name_list is not a str or a list/tuple/set of names.
            IndexError: if a name is unknown; nothing is removed in that case.
        """
        if isinstance(name_list, str):
            name_list = [name_list]
        check(isinstance(name_list, (list, tuple, set)), TypeError,
              f'name_list is type {type(name_list)} but must be a str or a list of str')
        for name in name_list:
            check(name in self.minimize_dict, IndexError, f"The name {name} is not part of the available minimizers")
        for name in dict.fromkeys(name_list):
            del self.minimize_dict[name]
        self.update()

    def setBuildWindow(self, sample_time = None):
        """Fix the sample time and compute the sample windows ('ns', 'ntot') of every input and state.

        Raises:
            RuntimeError: no model, non-positive sample time.
            KeyError: a state has no Connect/ClosedLoop update.
            AssertionError: an input has neither a time window nor a sample window.
            ValueError: a parameter/constant has a number of values inconsistent with its time window.
        """
        check(self.json is not None, RuntimeError, "No model is defined!")
        if sample_time is not None:
            check(sample_time > 0, RuntimeError, 'Sample time must be strictly positive!')
            self.sample_time = sample_time
        elif self.sample_time is None:
            self.sample_time = 1
        # From here on self.sample_time is always a number

        self.json['Info'] = {"SampleTime": self.sample_time}

        json_inputs = self.json['Inputs'] | self.json['States']
        check(json_inputs != {}, RuntimeError, "No model is defined!")

        for key, value in self.json['States'].items():
            check(closedloop_name in value or connect_name in value,
                  KeyError, f'Update function is missing for state {key}. Use Connect or ClosedLoop to update the state.')

        input_ns_backward, input_ns_forward = {}, {}
        for key, value in json_inputs.items():
            # Raised explicitly (not with assert) so the check survives python -O
            check(not (value['sw'] == [0, 0] and value['tw'] == [0, 0]), AssertionError,
                  f"Input {key} has no time window or sample window")
            if value['sw'] == [0, 0]:
                input_ns_backward[key] = round(-value['tw'][0] / self.sample_time)
                input_ns_forward[key] = round(value['tw'][1] / self.sample_time)
            else:
                input_ns_backward[key] = max(round(-value['tw'][0] / self.sample_time), -value['sw'][0])
                input_ns_forward[key] = max(round(value['tw'][1] / self.sample_time), value['sw'][1])
            value['ns'] = [input_ns_backward[key], input_ns_forward[key]]
            value['ntot'] = sum(value['ns'])

        self.json['Info']['ns'] = [max(input_ns_backward.values()), max(input_ns_forward.values())]
        self.json['Info']['ntot'] = sum(self.json['Info']['ns'])
        if self.json['Info']['ns'][0] < 0:
            log.warning(
                f"The input is only in the far past the max_samples_backward is: {self.json['Info']['ns'][0]}")
        if self.json['Info']['ns'][1] < 0:
            log.warning(
                f"The input is only in the far future the max_sample_forward is: {self.json['Info']['ns'][1]}")

        for k, v in (self.json['Parameters'] | self.json['Constants']).items():
            if 'values' in v and 'tw' in v:
                # tw / sample_time is a float (e.g. 0.3 / 0.1 == 2.9999999999999996): compare with a tolerance
                check(np.isclose(np.array(v['values']).shape[0], v['tw'] / self.sample_time), ValueError,
                      f"{k} has a different number of values for this sample time.")

    def updateParameters(self, model):
        """Copy the trained values of ``model.all_parameters`` into the definition (dropping 'init_fun')."""
        if model is not None:
            for key in self.json['Parameters'].keys():
                if key in model.all_parameters:
                    self.json['Parameters'][key]['values'] = model.all_parameters[key].tolist()
                    if 'init_fun' in self.json['Parameters'][key]:
                        del self.json['Parameters'][key]['init_fun']