nnodely 0.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mplplots/__init__.py +0 -0
- mplplots/plots.py +131 -0
- nnodely/__init__.py +42 -0
- nnodely/activation.py +85 -0
- nnodely/arithmetic.py +203 -0
- nnodely/earlystopping.py +81 -0
- nnodely/exporter/__init__.py +3 -0
- nnodely/exporter/export.py +275 -0
- nnodely/exporter/exporter.py +45 -0
- nnodely/exporter/reporter.py +48 -0
- nnodely/exporter/standardexporter.py +108 -0
- nnodely/fir.py +150 -0
- nnodely/fuzzify.py +221 -0
- nnodely/initializer.py +31 -0
- nnodely/input.py +131 -0
- nnodely/linear.py +130 -0
- nnodely/localmodel.py +82 -0
- nnodely/logger.py +94 -0
- nnodely/loss.py +30 -0
- nnodely/model.py +263 -0
- nnodely/modeldef.py +205 -0
- nnodely/nnodely.py +1295 -0
- nnodely/optimizer.py +91 -0
- nnodely/output.py +23 -0
- nnodely/parameter.py +103 -0
- nnodely/parametricfunction.py +329 -0
- nnodely/part.py +201 -0
- nnodely/relation.py +149 -0
- nnodely/trigonometric.py +67 -0
- nnodely/utils.py +101 -0
- nnodely/visualizer/__init__.py +4 -0
- nnodely/visualizer/dynamicmpl/functionplot.py +34 -0
- nnodely/visualizer/dynamicmpl/fuzzyplot.py +31 -0
- nnodely/visualizer/dynamicmpl/resultsplot.py +28 -0
- nnodely/visualizer/dynamicmpl/trainingplot.py +46 -0
- nnodely/visualizer/mplnotebookvisualizer.py +66 -0
- nnodely/visualizer/mplvisualizer.py +215 -0
- nnodely/visualizer/textvisualizer.py +320 -0
- nnodely/visualizer/visualizer.py +84 -0
- nnodely-0.14.0.dist-info/LICENSE +21 -0
- nnodely-0.14.0.dist-info/METADATA +401 -0
- nnodely-0.14.0.dist-info/RECORD +44 -0
- nnodely-0.14.0.dist-info/WHEEL +5 -0
- nnodely-0.14.0.dist-info/top_level.txt +2 -0
nnodely/logger.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import sys
|
|
3
|
+
|
|
4
|
+
from nnodely import LOG_LEVEL
|
|
5
|
+
|
|
6
|
+
# ANSI color indexes; added to 30 for foreground codes and 40 for background codes.
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)

# The background is set with 40 plus the number of the color, and the foreground with 30

# These are the escape sequences needed to get colored output
RESET_SEQ = "\033[0m"
COLOR_SEQ = "\033[%dm"
COLOR_BOLD_SEQ = "\033[1;%dm"
BOLD_SEQ = "\033[1m"

# Sentinel level strictly above CRITICAL: a logger set to SUPPRESS emits nothing.
SUPPRESS = logging.CRITICAL + 10
# Silence the root logger so only explicitly created nnLogger instances produce output.
logging.getLogger().setLevel(SUPPRESS)
# Foreground color used by JsonFormatter for each standard logging level.
COLORS = {
    logging.DEBUG: MAGENTA,
    logging.INFO: BLUE,
    logging.WARNING: YELLOW,
    logging.CRITICAL: RED,
    logging.ERROR: RED
}
# Human-readable name for each level, including the package-specific SUPPRESS.
LEVEL_STRING = {
    logging.DEBUG: "DEBUG",
    logging.INFO: "INFO",
    logging.WARNING: "WARNING",
    logging.CRITICAL: "CRITICAL",
    logging.ERROR: "ERROR",
    SUPPRESS: "SUPPRESS"
}
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class JsonFormatter(logging.Formatter):
    """Formatter that selects a layout per log level and wraps the line in ANSI color.

    WARNING records show only the function name and message, INFO records show
    the bare message, and every other level uses the verbose layout with
    logger/file/function/line information.
    """

    FORMAT = "[%(levelname)s][%(name)s:%(filename)s:%(funcName)s:%(lineno)d] %(message)s" # + ""
    FORMAT_WARNING = "[%(funcName)s] %(message)s"
    FORMAT_INFO = "%(message)s"

    def __init__(self):
        logging.Formatter.__init__(self, self.FORMAT)

    def format(self, record):
        # Pick the layout for this record's severity.
        severity = record.levelno
        if severity == logging.WARNING:
            layout = self.FORMAT_WARNING
        elif severity == logging.INFO:
            layout = self.FORMAT_INFO
        else:
            layout = self.FORMAT
        # Swap the active format string before delegating to the base formatter.
        self._style._fmt = layout
        rendered = logging.Formatter.format(self, record)
        # Color the whole line with the foreground mapped to this level
        # (levels absent from COLORS raise KeyError, as in the original).
        return COLOR_SEQ % (30 + COLORS[severity]) + rendered + RESET_SEQ
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# Custom logger class with multiple destinations
|
|
56
|
+
# Custom logger class with multiple destinations
class nnLogger(logging.Logger):
    # Class-level registries deliberately shared by ALL nnLogger instances:
    # every constructed logger and its original level are recorded so that
    # setAllLevel/resetAllLevel can act on the whole set at once.
    levels = []
    loggers = []
    # Shared record of the last level applied by setAllLevel (None = per-logger levels).
    params = {'level':None}

    def __init__(self, name, level):
        """Create a colored stdout logger and register it in the class registry.

        The effective level is the less verbose of the requested `level` and
        the package-wide LOG_LEVEL (higher numeric value = less verbose).
        """
        logging.Logger.__init__(self, name)
        self.setLevel(max(level, LOG_LEVEL))

        #file = logging.FileHandler('example.log')
        #color_formatter = ColoredFormatter(self.COLOR_FORMAT)

        # Single destination: colored output on stdout.
        self.console = logging.StreamHandler(sys.stdout)
        color_formatter = JsonFormatter()

        self.console.setFormatter(color_formatter)
        #self.console.setLevel(logging.CRITICAL)

        #logging.getLogger().addHandler(self.console)
        self.addHandler(self.console)
        # Register this instance so the class-wide level switches reach it.
        self.loggers.append(self)
        self.levels.append(level)
        #self.addHandler(file)

    def setAllLevel(self, level):
        """Force every registered nnLogger to `level` (no-op if already applied)."""
        if self.params['level'] is None or self.params['level'] != level:
            # _log is used directly so the banner is emitted regardless of this
            # logger's own level filter.
            self._log(logging.INFO,
                      COLOR_SEQ % (30 + BLUE) + (f" Loggers to {LEVEL_STRING[level]} ").center(80, '=') + RESET_SEQ, None)
            self.params['level'] = level
            for ind, logger in enumerate(self.loggers):
                logger.setLevel(level)

    def resetAllLevel(self):
        """Restore each registered logger to the level it was created with."""
        # NOTE(review): the guard compares against 0, not None — the banner is
        # also printed when levels were never overridden; confirm intended.
        if self.params['level'] != 0:
            self._log(logging.INFO, COLOR_SEQ % (30 + BLUE) + (" Standard Level Log ").center(80, '=') + RESET_SEQ, None)
            self.params['level'] = None
            for ind, logger in enumerate(self.loggers):
                logger.setLevel(self.levels[ind])
|
|
93
|
+
|
|
94
|
+
|
nnodely/loss.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
import torch.nn as nn
|
|
2
|
+
import torch
|
|
3
|
+
from nnodely.utils import check
|
|
4
|
+
|
|
5
|
+
# Loss-function identifiers accepted by CustomLoss (validated in its __init__).
available_losses = ['mse', 'rmse', 'mae']
|
|
6
|
+
|
|
7
|
+
# class CustomRMSE(nn.Module):
|
|
8
|
+
# def __init__(self):
|
|
9
|
+
# super(CustomRMSE, self).__init__()
|
|
10
|
+
# self.mse = nn.MSELoss()
|
|
11
|
+
#
|
|
12
|
+
# def forward(self, inA, inB):
|
|
13
|
+
# #assert predictions.keys() == labels.keys(), "Keys of predictions and labels must match"
|
|
14
|
+
# #loss = torch.sqrt(self.mse(inA, inB))
|
|
15
|
+
# return self.mse(inA, inB)
|
|
16
|
+
|
|
17
|
+
class CustomLoss(nn.Module):
    """Configurable loss wrapper over torch losses.

    Supported `loss_type` values (see `available_losses`):
    'mse' -> nn.MSELoss, 'mae' -> nn.L1Loss, 'rmse' -> sqrt of nn.MSELoss.
    Raises TypeError (via `check`) for any other identifier.
    """

    def __init__(self, loss_type='mse'):
        super(CustomLoss, self).__init__()
        check(loss_type in available_losses, TypeError, f'The \"{loss_type}\" loss is not available. Possible losses are: {available_losses}.')
        self.loss_type = loss_type
        # 'mae' is backed by L1; both 'mse' and 'rmse' share an MSE backend,
        # with the square root applied at forward time for 'rmse'.
        self.loss = nn.L1Loss() if loss_type == 'mae' else nn.MSELoss()

    def forward(self, inA, inB):
        """Return the configured loss between the two tensors."""
        value = self.loss(inA, inB)
        return torch.sqrt(value) if self.loss_type == 'rmse' else value
|
nnodely/model.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
from itertools import product
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
import copy
|
|
8
|
+
|
|
9
|
+
class Model(nn.Module):
    """Torch network assembled at runtime from a nnodely JSON model definition.

    The JSON sections consumed here are: Inputs, Outputs, Relations,
    Parameters, Constants, Functions, States, Minimizers and Info.SampleTime.
    One torch layer is built per relation; forward() then evaluates the
    relation graph until every requested output is available.
    """

    def __init__(self, model_def):
        super(Model, self).__init__()
        # Work on a private copy: the slicing adjustments below mutate the JSON in place.
        model_def = copy.deepcopy(model_def)
        self.inputs = model_def['Inputs']
        self.outputs = model_def['Outputs']
        self.relations = model_def['Relations']
        self.params = model_def['Parameters']
        self.constants = model_def['Constants']
        self.sample_time = model_def['Info']['SampleTime']
        self.functions = model_def['Functions']
        self.state_model_main = model_def['States']
        self.minimizers = model_def['Minimizers']
        # Working copy of the states section; the main one is kept for re-initialization.
        self.state_model = copy.deepcopy(self.state_model_main)
        # Per-input backward sample count and total window length
        # ('ns'/'ntot' are filled earlier by ModelDef.setBuildWindow).
        self.input_ns_backward = {key:value['ns'][0] for key, value in (model_def['Inputs']|model_def['States']).items()}
        self.input_n_samples = {key:value['ntot'] for key, value in (model_def['Inputs']|model_def['States']).items()}
        # Flat list of every stream name referenced by a minimizer (both sides).
        self.minimizers_keys = [self.minimizers[key]['A'] for key in self.minimizers] + [self.minimizers[key]['B'] for key in self.minimizers]

        ## Build the network
        self.all_parameters = {}
        self.all_constants = {}
        self.relation_forward = {}
        self.relation_inputs = {}
        self.states = {}
        self.states_closed_loop = {}
        self.states_connect = {}
        #self.constants = set()

        self.connect_variables = {}
        self.connect = {}
        self.initialize_connect = False

        ## Define the correct slicing
        # Shift SamplePart/TimePart windows from time-relative to buffer-relative
        # indices by adding the input's backward sample count (TimePart bounds are
        # first converted from time units to samples via the sample time).
        json_inputs = self.inputs | self.state_model
        for _, items in self.relations.items():
            if items[0] == 'SamplePart':
                if items[1][0] in json_inputs.keys():
                    items[2][0] = self.input_ns_backward[items[1][0]] + items[2][0]
                    items[2][1] = self.input_ns_backward[items[1][0]] + items[2][1]
                    if len(items) > 3: ## Offset
                        items[3] = self.input_ns_backward[items[1][0]] + items[3]
            if items[0] == 'TimePart':
                if items[1][0] in json_inputs.keys():
                    items[2][0] = self.input_ns_backward[items[1][0]] + round(items[2][0]/self.sample_time)
                    items[2][1] = self.input_ns_backward[items[1][0]] + round(items[2][1]/self.sample_time)
                    if len(items) > 3: ## Offset
                        items[3] = self.input_ns_backward[items[1][0]] + round(items[3]/self.sample_time)
                else:
                    items[2][0] = round(items[2][0]/self.sample_time)
                    items[2][1] = round(items[2][1]/self.sample_time)
                    if len(items) > 3: ## Offset
                        items[3] = round(items[3]/self.sample_time)

        ## Create all the parameters
        for name, param_data in self.params.items():
            # A parameter window is expressed either in time ('tw') or samples ('sw').
            window = 'tw' if 'tw' in param_data.keys() else ('sw' if 'sw' in param_data.keys() else None)
            aux_sample_time = self.sample_time if 'tw' == window else 1
            sample_window = round(param_data[window] / aux_sample_time) if window else 1
            param_size = (sample_window,)+tuple(param_data['dim']) if type(param_data['dim']) is list else (sample_window, param_data['dim'])
            if 'values' in param_data:
                # Explicit initial values take precedence over any initializer.
                self.all_parameters[name] = nn.Parameter(torch.tensor(param_data['values'], dtype=torch.float32), requires_grad=True)
            # TODO clean code
            elif 'init_fun' in param_data:
                # NOTE(review): exec of code carried inside the model JSON —
                # safe only for trusted model files.
                exec(param_data['init_fun']['code'], globals())
                function_to_call = globals()[param_data['init_fun']['name']]
                values = np.zeros(param_size)
                # The initializer is evaluated element-wise over every index of the tensor.
                for indexes in product(*(range(v) for v in param_size)):
                    if 'params' in param_data['init_fun']:
                        values[indexes] = function_to_call(indexes, param_size, param_data['init_fun']['params'])
                    else:
                        values[indexes] = function_to_call(indexes, param_size)
                self.all_parameters[name] = nn.Parameter(torch.tensor(values.tolist(), dtype=torch.float32), requires_grad=True)
            else:
                # No values and no initializer: start from uniform random values.
                self.all_parameters[name] = nn.Parameter(torch.rand(size=param_size, dtype=torch.float32), requires_grad=True)

        ## Create all the constants
        # Constants are stored as non-trainable Parameters so they travel with the module.
        for name, param_data in self.constants.items():
            self.all_constants[name] = nn.Parameter(torch.tensor(param_data['values'], dtype=torch.float32), requires_grad=False)

        ## Initialize state variables
        self.init_states(self.state_model_main, reset_states=True)
        all_params_and_consts = self.all_parameters | self.all_constants

        ## Create all the relations
        for relation, inputs in self.relations.items():
            ## Take the relation name
            rel_name = inputs[0]
            ## collect the inputs needed for the relation
            input_var = inputs[1]
            ## collect the constants of the model
            #self.constants.update([item for item in inputs[1] if not isinstance(item, str)])

            ## Create All the Relations
            # Each relation name must resolve to a builder method attached to this
            # class (registered elsewhere in the package — not visible in this file).
            func = getattr(self,rel_name)
            if func:
                layer_inputs = []
                for item in inputs[2:]:
                    if item in list(self.params.keys()): ## the relation takes parameters
                        layer_inputs.append(self.all_parameters[item])
                    elif item in list(self.constants.keys()): ## the relation takes constants
                        layer_inputs.append(self.all_constants[item])
                    elif item in list(self.functions.keys()): ## the relation takes a custom function
                        layer_inputs.append(self.functions[item])
                        # NOTE(review): `len(...) >= 0` is always true, so the list is
                        # appended whenever the key exists — confirm an empty list is
                        # meant to be passed as well.
                        if 'params_and_consts' in self.functions[item].keys() and len(self.functions[item]['params_and_consts']) >= 0: ## Parametric function that takes parameters
                            layer_inputs.append([all_params_and_consts[par] for par in self.functions[item]['params_and_consts']])
                        if 'map_over_dim' in self.functions[item].keys():
                            layer_inputs.append(self.functions[item]['map_over_dim'])
                    else:
                        layer_inputs.append(item)

                ## Initialize the relation
                self.relation_forward[relation] = func(*layer_inputs)
                ## Save the inputs needed for the relative relation
                self.relation_inputs[relation] = input_var

            else:
                print(f"Key Error: [{rel_name}] Relation not defined")

        ## Add the gradient to all the relations and parameters that requires it
        self.relation_forward = nn.ParameterDict(self.relation_forward)
        self.all_constants = nn.ParameterDict(self.all_constants)
        self.all_parameters = nn.ParameterDict(self.all_parameters)

        ## list of network outputs
        self.network_output_predictions = set(self.outputs.values())

        ## list of network minimization outputs
        # A minimizer side may name either an Output (mapped to its relation)
        # or a relation directly.
        self.network_output_minimizers = []
        for _,value in self.minimizers.items():
            self.network_output_minimizers.append(self.outputs[value['A']]) if value['A'] in self.outputs.keys() else self.network_output_minimizers.append(value['A'])
            self.network_output_minimizers.append(self.outputs[value['B']]) if value['B'] in self.outputs.keys() else self.network_output_minimizers.append(value['B'])
        self.network_output_minimizers = set(self.network_output_minimizers)

        ## list of all the network Outputs
        self.network_outputs = self.network_output_predictions.union(self.network_output_minimizers)

    def forward(self, kwargs):
        """Evaluate the relation graph.

        `kwargs` is a dict mapping input names to tensors (indexing below
        assumes shape (batch, window, dim) — TODO confirm against callers).
        Returns (output_dict, minimize_dict): outputs keyed by Output name,
        plus the tensors referenced by the minimizers.
        """
        result_dict = {}

        ## Initially i have only the inputs from the dataset, the parameters, and the constants
        available_inputs = [key for key in self.inputs.keys() if key not in self.connect.keys()] ## remove connected inputs
        available_states = [key for key in self.state_model.keys() if key not in self.states_connect.keys()] ## remove connected states
        available_keys = set(available_inputs + list(self.all_parameters.keys()) + list(self.all_constants.keys()) + available_states)

        ## Forward pass through the relations
        # NOTE(review): termination relies on the relation graph being resolvable
        # from the available keys; an unresolvable graph would loop forever.
        while not self.network_outputs.issubset(available_keys): ## i need to climb the relation tree until i get all the outputs
            for relation in self.relations.keys():
                ## if i have all the variables i can calculate the relation
                if set(self.relation_inputs[relation]).issubset(available_keys) and (relation not in available_keys):
                    ## Collect all the necessary inputs for the relation
                    layer_inputs = []
                    for key in self.relation_inputs[relation]:
                        if key in self.all_constants.keys(): ## relation that takes a constant
                            layer_inputs.append(self.all_constants[key])
                        elif key in self.states.keys(): ## relation that takes a state
                            layer_inputs.append(self.states[key])
                        elif key in available_inputs: ## relation that takes inputs (self.inputs.keys())
                            layer_inputs.append(kwargs[key])
                        elif key in self.all_parameters.keys(): ## relation that takes parameters
                            layer_inputs.append(self.all_parameters[key])
                        else: ## relation than takes another relation or a connect variable
                            layer_inputs.append(result_dict[key])

                    ## Execute the current relation
                    result_dict[relation] = self.relation_forward[relation](*layer_inputs)
                    available_keys.add(relation)

                    ## Update the connect variables if necessary
                    for connect_in, connect_out in self.connect.items():
                        if relation == self.outputs[connect_out]: ## we have to save the output
                            shift = result_dict[relation].shape[1]
                            # NOTE(review): the buffer is rolled by one sample while the last
                            # `shift` samples are overwritten — confirm intended when shift > 1.
                            self.connect_variables[connect_in] = torch.roll(self.connect_variables[connect_in], shifts=-1, dims=1)
                            self.connect_variables[connect_in][:, -shift:, :] = result_dict[relation]
                            result_dict[connect_in] = self.connect_variables[connect_in].clone()
                            available_keys.add(connect_in)

                    ## Update connect state if necessary
                    if relation in self.states_connect.values():
                        for state in [key for key, value in self.states_connect.items() if value == relation]:
                            shift = result_dict[relation].shape[1]
                            self.states[state] = torch.roll(self.states[state], shifts=-1, dims=1)
                            self.states[state][:, -shift:, :] = result_dict[relation]#.detach() ## TODO: detach??
                            available_keys.add(state)

        ## Update closed loop state if necessary
        # Closed-loop states are refreshed only after the whole graph has been evaluated.
        for relation in self.relations.keys():
            if relation in self.states_closed_loop.values():
                for state in [key for key, value in self.states_closed_loop.items() if value == relation]:
                    shift = result_dict[relation].shape[1]
                    self.states[state] = torch.roll(self.states[state], shifts=-1, dims=1) # shifts=-shift, dims=1)
                    self.states[state][:, -shift:, :] = result_dict[relation] # .detach() ## TODO: detach??

        ## Return a dictionary with all the outputs final values
        output_dict = {key: result_dict[value] for key, value in self.outputs.items()}
        ## Return a dictionary with the minimization relations
        minimize_dict = {}
        for key in self.minimizers_keys:
            minimize_dict[key] = result_dict[self.outputs[key]] if key in self.outputs.keys() else result_dict[key]

        return output_dict, minimize_dict


    def init_states(self, state_model, connect = {}, reset_states = False):
        """Register each state's update rule (connect vs closed-loop).

        NOTE(review): `connect` is a mutable default argument; it is only read
        here, but callers should not rely on it being a fresh object.
        """
        ## Initialize state variables
        if reset_states:
            self.reset_states()
            self.reset_connect_variables(copy.deepcopy(connect), only=False)
        self.states_connect = {}
        self.states_closed_loop = {}
        ## save the states updates
        for state, param in state_model.items():
            if 'connect' in param.keys():
                self.states_connect[state] = param['connect']
            else:
                self.states_closed_loop[state] = param['closedLoop']

    def reset_connect_variables(self, connect, values = None, only = True):
        """Reset the buffers backing connected inputs.

        With only=False the connect map itself is replaced and every buffer is
        (re)allocated; with only=True just the buffers present in `values` are
        overwritten with a clone of the given tensor.
        """
        if only == False:
            self.connect = connect
            self.connect_variables = {}
            self.initialize_connect = True
        for key in connect.keys():
            if values is not None and key in values.keys():
                self.connect_variables[key] = values[key].clone()
            elif only == False:
                # Batch size is taken from the first provided tensor, defaulting to 1.
                batch = values[list(values)[0]].shape[0] if values is not None else 1
                window_size = self.input_n_samples[key]
                self.connect_variables[key] = torch.zeros(size=(batch, window_size, self.inputs[key]['dim']),
                                                          dtype=torch.float32, requires_grad=False)

    def reset_states(self, values = None, only = True):
        """Reset state buffers.

        values=None: zero every state. values as a set: zero only the named
        states. values as a dict: clone the given tensors into the matching
        states (and, when only=False, zero the remaining ones).
        """
        if values is None:
            for key, value in self.state_model.items():
                # Preserve the current batch size when the state already exists.
                batch = self.states[key].shape[0] if key in self.states else 1
                window_size = self.input_n_samples[key]
                self.states[key] = torch.zeros(size=(batch, window_size, self.state_model[key]['dim']),
                                               dtype=torch.float32, requires_grad=False)
        else:
            if type(values) is set:
                for key in self.state_model.keys():
                    if key in values:
                        batch = self.states[key].shape[0] if key in self.states else 1
                        window_size = self.input_n_samples[key]
                        self.states[key] = torch.zeros(size=(batch, window_size, self.state_model[key]['dim']),
                                                       dtype=torch.float32, requires_grad=False)
            else:
                for key in self.state_model.keys():
                    if key in values.keys():
                        self.states[key] = values[key].clone()
                        self.states[key].requires_grad = False
                    elif only == False:
                        batch = values[list(values)[0]].shape[0]
                        window_size = self.input_n_samples[key]
                        self.states[key] = torch.zeros(size=(batch, window_size, self.state_model[key]['dim']),
                                                       dtype=torch.float32, requires_grad=False)
|
nnodely/modeldef.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from nnodely.input import closedloop_name, connect_name
|
|
6
|
+
from nnodely.utils import check, merge
|
|
7
|
+
from nnodely.relation import MAIN_JSON, Stream
|
|
8
|
+
from nnodely.output import Output
|
|
9
|
+
|
|
10
|
+
from nnodely.logger import logging, nnLogger
|
|
11
|
+
# Module-level logger; nnLogger registers itself in a class-wide registry.
log = nnLogger(__name__, logging.INFO)
|
|
12
|
+
|
|
13
|
+
class ModelDef():
    """Mutable container for a nnodely model definition (JSON).

    Holds the base JSON plus user-registered models, minimizers and state
    updates, and rebuilds the merged JSON (`self.json`) via update() whenever
    any of them change.
    """

    def __init__(self, model_def = MAIN_JSON):
        # Models definition: keep an untouched base copy to rebuild from.
        self.json_base = copy.deepcopy(model_def)

        # Initialize the working model definition.
        self.json = copy.deepcopy(self.json_base)
        if "SampleTime" in self.json['Info']:
            self.sample_time = self.json['Info']["SampleTime"]
        else:
            self.sample_time = None
        # name -> list of Output/Stream registered via addModel
        self.model_dict = {}
        # name -> {'A', 'B', 'loss'} registered via addMinimize
        self.minimize_dict = {}
        # state name -> Connect/ClosedLoop registered via addConnect/addClosedLoop
        self.update_state_dict = {}

    def __contains__(self, key):
        return key in self.json

    def __getitem__(self, key):
        return self.json[key]

    def __setitem__(self, key, value):
        self.json[key] = value

    def isDefined(self):
        """Return True when a JSON definition is present."""
        return self.json is not None

    def update(self, model_def = None, model_dict = None, minimize_dict = None, update_state_dict = None):
        """Rebuild self.json from the base JSON plus the registered models,
        minimizers and state updates (each overridable per call)."""
        self.json = copy.deepcopy(model_def) if model_def is not None else copy.deepcopy(self.json_base)
        model_dict = copy.deepcopy(model_dict) if model_dict is not None else self.model_dict
        minimize_dict = copy.deepcopy(minimize_dict) if minimize_dict is not None else self.minimize_dict
        update_state_dict = copy.deepcopy(update_state_dict) if update_state_dict is not None else self.update_state_dict

        # Add models to the model_def
        for key, stream_list in model_dict.items():
            for stream in stream_list:
                self.json = merge(self.json, stream.json)
        if len(model_dict) > 1:
            # Multiple models: record which inputs/states/outputs/parameters/constants
            # each named model owns.
            if 'Models' not in self.json:
                self.json['Models'] = {}
            for model_name, model_params in model_dict.items():
                self.json['Models'][model_name] = {'Inputs': [], 'States': [], 'Outputs': [], 'Parameters': [],
                                                   'Constants': []}
                parameters, constants, inputs, states = set(), set(), set(), set()
                for param in model_params:
                    self.json['Models'][model_name]['Outputs'].append(param.name)
                    parameters |= set(param.json['Parameters'].keys())
                    constants |= set(param.json['Constants'].keys())
                    inputs |= set(param.json['Inputs'].keys())
                    states |= set(param.json['States'].keys())
                self.json['Models'][model_name]['Parameters'] = list(parameters)
                self.json['Models'][model_name]['Constants'] = list(constants)
                self.json['Models'][model_name]['Inputs'] = list(inputs)
                self.json['Models'][model_name]['States'] = list(states)
        elif len(model_dict) == 1:
            # Single model: 'Models' is just the model name.
            self.json['Models'] = list(model_dict.keys())[0]

        if 'Minimizers' not in self.json:
            self.json['Minimizers'] = {}
        for key, minimize in minimize_dict.items():
            self.json = merge(self.json, minimize['A'].json)
            self.json = merge(self.json, minimize['B'].json)
            self.json['Minimizers'][key] = {}
            self.json['Minimizers'][key]['A'] = minimize['A'].name
            self.json['Minimizers'][key]['B'] = minimize['B'].name
            self.json['Minimizers'][key]['loss'] = minimize['loss']

        for key, update_state in update_state_dict.items():
            self.json = merge(self.json, update_state.json)

        if "SampleTime" in self.json['Info']:
            self.sample_time = self.json['Info']["SampleTime"]

    def __update_state(self, stream_out, state_list_in, UpdateState):
        """Validate and register `stream_out` as the update source of each state
        in `state_list_in`, using the given UpdateState relation class."""
        from nnodely.input import State
        if type(state_list_in) is not list:
            state_list_in = [state_list_in]
        for state_in in state_list_in:
            check(isinstance(stream_out, (Output, Stream)), TypeError,
                  f"The {stream_out} must be a Stream or Output and not a {type(stream_out)}.")
            check(type(state_in) is State, TypeError,
                  f"The {state_in} must be a State and not a {type(state_in)}.")
            check(stream_out.dim['dim'] == state_in.dim['dim'], ValueError,
                  f"The dimension of {stream_out.name} is not equal to the dimension of {state_in.name} ({stream_out.dim['dim']}!={state_in.dim['dim']}).")
            if type(stream_out) is Output:
                # Resolve the Output to its underlying relation stream.
                stream_name = self.json['Outputs'][stream_out.name]
                stream_out = Stream(stream_name,stream_out.json,stream_out.dim, 0)
            self.update_state_dict[state_in.name] = UpdateState(stream_out, state_in)

    def addConnect(self, stream_out, state_list_in):
        """Feed `stream_out` into the given state(s) within the same forward pass."""
        from nnodely.input import Connect
        self.__update_state(stream_out, state_list_in, Connect)
        self.update()

    def addClosedLoop(self, stream_out, state_list_in):
        """Feed `stream_out` back into the given state(s) at the next step."""
        from nnodely.input import ClosedLoop
        self.__update_state(stream_out, state_list_in, ClosedLoop)
        self.update()

    def addModel(self, name, stream_list):
        """Register a named model as an Output/Stream or a list of them."""
        if isinstance(stream_list, (Output,Stream)):
            stream_list = [stream_list]
        if type(stream_list) is list:
            self.model_dict[name] = copy.deepcopy(stream_list)
        else:
            raise TypeError(f'stream_list is type {type(stream_list)} but must be an Output or Stream or a list of them')
        self.update()

    def removeModel(self, name_list):
        """Remove one or more registered models by name; raises IndexError for unknown names."""
        if type(name_list) is str:
            name_list = [name_list]
        if type(name_list) is list:
            for name in name_list:
                check(name in self.model_dict, IndexError, f"The name {name} is not part of the available models")
                del self.model_dict[name]
        self.update()

    def addMinimize(self, name, streamA, streamB, loss_function='mse'):
        """Register a minimization objective between two streams of equal dimension."""
        check(isinstance(streamA, (Output, Stream)), TypeError, 'streamA must be an instance of Output or Stream')
        # Fixed copy-paste bug: the message previously reported 'streamA' for this check.
        check(isinstance(streamB, (Output, Stream)), TypeError, 'streamB must be an instance of Output or Stream')
        check(streamA.dim == streamB.dim, ValueError, f'Dimension of streamA={streamA.dim} and streamB={streamB.dim} are not equal.')
        self.minimize_dict[name]={'A':copy.deepcopy(streamA), 'B': copy.deepcopy(streamB), 'loss':loss_function}
        self.update()

    def removeMinimize(self, name_list):
        """Remove one or more registered minimizers by name; raises IndexError for unknown names."""
        if type(name_list) is str:
            name_list = [name_list]
        if type(name_list) is list:
            for name in name_list:
                # Fixed typo in the error message ("minimuzes" -> "minimizes").
                check(name in self.minimize_dict, IndexError, f"The name {name} is not part of the available minimizes")
                del self.minimize_dict[name]
        self.update()

    def setBuildWindow(self, sample_time = None):
        """Fix the sample time and compute per-input sample windows
        ('ns' = [backward, forward] samples, 'ntot' = total) into the JSON."""
        check(self.json is not None, RuntimeError, "No model is defined!")
        if sample_time is not None:
            check(sample_time > 0, RuntimeError, 'Sample time must be strictly positive!')
            self.sample_time = sample_time
        else:
            if self.sample_time is None:
                self.sample_time = 1

        self.json['Info'] = {"SampleTime": self.sample_time}

        check(self.json['Inputs'] | self.json['States'] != {}, RuntimeError, "No model is defined!")
        json_inputs = self.json['Inputs'] | self.json['States']

        # Every state must have an update rule attached, otherwise it can never change.
        for key,value in self.json['States'].items():
            check(closedloop_name in self.json['States'][key] or connect_name in self.json['States'][key],
                  KeyError, f'Update function is missing for state {key}. Use Connect or ClosedLoop to update the state.')

        # (Removed unused input_tw_backward/input_tw_forward locals.)
        input_ns_backward, input_ns_forward = {}, {}
        for key, value in json_inputs.items():
            if value['sw'] == [0,0] and value['tw'] == [0,0]:
                # NOTE(review): assert is stripped under -O; kept for backward
                # compatibility with callers expecting AssertionError.
                assert(False), f"Input {key} has no time window or sample window"
            if value['sw'] == [0, 0] and self.sample_time is not None:
                # Only a time window: convert it to samples.
                input_ns_backward[key] = round(-value['tw'][0] / self.sample_time)
                input_ns_forward[key] = round(value['tw'][1] / self.sample_time)
            elif self.sample_time is not None:
                # Both windows present: take the wider of the two, per direction.
                input_ns_backward[key] = max(round(-value['tw'][0] / self.sample_time),-value['sw'][0])
                input_ns_forward[key] = max(round(value['tw'][1] / self.sample_time),value['sw'][1])
            else:
                # No sample time: only sample windows are meaningful.
                check(value['tw'] == [0,0], RuntimeError, f"Sample time is not defined for input {key}")
                input_ns_backward[key] = -value['sw'][0]
                input_ns_forward[key] = value['sw'][1]
            value['ns'] = [input_ns_backward[key], input_ns_forward[key]]
            value['ntot'] = sum(value['ns'])

        self.json['Info']['ns'] = [max(input_ns_backward.values()), max(input_ns_forward.values())]
        self.json['Info']['ntot'] = sum(self.json['Info']['ns'])
        if self.json['Info']['ns'][0] < 0:
            log.warning(
                f"The input is only in the far past the max_samples_backward is: {self.json['Info']['ns'][0]}")
        if self.json['Info']['ns'][1] < 0:
            log.warning(
                f"The input is only in the far future the max_sample_forward is: {self.json['Info']['ns'][1]}")

        # Values given for time-windowed parameters must match the new sample count.
        for k, v in (self.json['Parameters']|self.json['Constants']).items():
            if 'values' in v:
                window = 'tw' if 'tw' in v.keys() else ('sw' if 'sw' in v.keys() else None)
                if window == 'tw':
                    check(np.array(v['values']).shape[0] == v['tw']/self.sample_time, ValueError,
                          f"{k} has a different number of values for this sample time.")

    def updateParameters(self, model):
        """Copy trained parameter values back from a torch Model into the JSON,
        dropping any now-superseded 'init_fun' entries."""
        if model is not None:
            for key in self.json['Parameters'].keys():
                if key in model.all_parameters:
                    self.json['Parameters'][key]['values'] = model.all_parameters[key].tolist()
                    if 'init_fun' in self.json['Parameters'][key]:
                        del self.json['Parameters'][key]['init_fun']
|