nnodely 1.5.4__py3-none-any.whl → 1.5.5.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. nnodely/__init__.py +82 -32
  2. nnodely/basic/loss.py +20 -11
  3. nnodely/basic/model.py +178 -74
  4. nnodely/basic/modeldef.py +225 -110
  5. nnodely/basic/optimizer.py +42 -26
  6. nnodely/basic/relation.py +165 -69
  7. nnodely/exporter/emptyexporter.py +13 -11
  8. nnodely/exporter/export.py +281 -155
  9. nnodely/exporter/reporter.py +41 -15
  10. nnodely/exporter/standardexporter.py +126 -47
  11. nnodely/layers/activation.py +102 -62
  12. nnodely/layers/arithmetic.py +199 -124
  13. nnodely/layers/equationlearner.py +80 -15
  14. nnodely/layers/fir.py +110 -48
  15. nnodely/layers/fuzzify.py +149 -65
  16. nnodely/layers/input.py +145 -64
  17. nnodely/layers/interpolation.py +55 -25
  18. nnodely/layers/linear.py +105 -39
  19. nnodely/layers/localmodel.py +151 -64
  20. nnodely/layers/neuralODE.py +55 -37
  21. nnodely/layers/output.py +9 -6
  22. nnodely/layers/parameter.py +114 -51
  23. nnodely/layers/parametricfunction.py +209 -91
  24. nnodely/layers/part.py +254 -139
  25. nnodely/layers/rungekutta.py +30 -28
  26. nnodely/layers/timeoperation.py +55 -19
  27. nnodely/layers/trigonometric.py +121 -59
  28. nnodely/nnodely.py +110 -42
  29. nnodely/operators/composer.py +214 -81
  30. nnodely/operators/exporter.py +134 -49
  31. nnodely/operators/loader.py +132 -54
  32. nnodely/operators/network.py +428 -143
  33. nnodely/operators/trainer.py +308 -116
  34. nnodely/operators/validator.py +214 -97
  35. nnodely/support/earlystopping.py +28 -16
  36. nnodely/support/fixstepsolver.py +26 -13
  37. nnodely/support/initializer.py +60 -21
  38. nnodely/support/jsonutils.py +556 -260
  39. nnodely/support/logger.py +31 -20
  40. nnodely/support/mathutils.py +9 -2
  41. nnodely/support/odeint/__init__.py +0 -0
  42. nnodely/support/odeint/adjoint.py +406 -0
  43. nnodely/support/odeint/dopri5.py +60 -0
  44. nnodely/support/odeint/fixed_grid.py +18 -0
  45. nnodely/support/odeint/my_odeint.py +158 -0
  46. nnodely/support/odeint/rk_solvers.py +547 -0
  47. nnodely/support/odeint/solvers.py +233 -0
  48. nnodely/support/odeint/utils.py +279 -0
  49. nnodely/support/utils.py +48 -21
  50. nnodely/visualizer/__init__.py +1 -1
  51. nnodely/visualizer/dynamicmpl/functionplot.py +11 -11
  52. nnodely/visualizer/dynamicmpl/fuzzyplot.py +8 -7
  53. nnodely/visualizer/dynamicmpl/resultsplot.py +8 -7
  54. nnodely/visualizer/dynamicmpl/trainingplot.py +10 -7
  55. nnodely/visualizer/emptyvisualizer.py +7 -5
  56. nnodely/visualizer/mplnotebookvisualizer.py +63 -34
  57. nnodely/visualizer/mplvisualizer.py +152 -74
  58. nnodely/visualizer/textvisualizer.py +317 -143
  59. {nnodely-1.5.4.dist-info → nnodely-1.5.5.dev2.dist-info}/METADATA +39 -60
  60. nnodely-1.5.5.dev2.dist-info/RECORD +67 -0
  61. {nnodely-1.5.4.dist-info → nnodely-1.5.5.dev2.dist-info}/WHEEL +1 -2
  62. mplplots/__init__.py +0 -1
  63. mplplots/plots.py +0 -188
  64. nnodely-1.5.4.dist-info/RECORD +0 -62
  65. nnodely-1.5.4.dist-info/top_level.txt +0 -2
  66. {nnodely-1.5.4.dist-info → nnodely-1.5.5.dev2.dist-info}/licenses/LICENSE +0 -0
nnodely/__init__.py CHANGED
@@ -14,7 +14,15 @@ from nnodely.layers.arithmetic import Add, Sum, Sub, Mul, Div, Pow, Neg, Sign
14
14
  from nnodely.layers.trigonometric import Sin, Cos, Tan, Cosh, Tanh, Sech
15
15
  from nnodely.layers.parametricfunction import ParamFun
16
16
  from nnodely.layers.fuzzify import Fuzzify
17
- from nnodely.layers.part import Part, Select, Concatenate, SamplePart, SampleSelect, TimePart, TimeConcatenate
17
+ from nnodely.layers.part import (
18
+ Part,
19
+ Select,
20
+ Concatenate,
21
+ SamplePart,
22
+ SampleSelect,
23
+ TimePart,
24
+ TimeConcatenate,
25
+ )
18
26
  from nnodely.layers.localmodel import LocalModel
19
27
  from nnodely.layers.equationlearner import EquationLearner
20
28
  from nnodely.layers.timeoperation import Integrate, Differentiate
@@ -37,43 +45,85 @@ from nnodely.support import logger
37
45
  major, minor = sys.version_info.major, sys.version_info.minor
38
46
  logger.LOG_LEVEL = logging.INFO
39
47
 
40
- __version__ = '1.5.4'
48
+ __version__ = "1.5.4"
41
49
 
42
50
  if major < 3:
43
- sys.exit("Sorry, Python 2 is not supported. You need Python >= 3.10 for "+__package__+".")
51
+ sys.exit(
52
+ "Sorry, Python 2 is not supported. You need Python >= 3.10 for "
53
+ + __package__
54
+ + "."
55
+ )
44
56
  elif minor < 9:
45
- sys.exit("Sorry, You need Python >= 3.10 for "+__package__+".")
57
+ sys.exit("Sorry, You need Python >= 3.10 for " + __package__ + ".")
46
58
  else:
47
- print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>' +
48
- f' {__package__}_v{__version__} '.center(20, '-') +
49
- '<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
59
+ print(
60
+ ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
61
+ + f" {__package__}_v{__version__} ".center(20, "-")
62
+ + "<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<"
63
+ )
50
64
 
51
65
 
52
66
  __all__ = [
53
- 'nnodely', 'Modely', 'clearNames',
54
- 'Input', 'Connect', 'ClosedLoop',
55
- 'Parameter', 'Constant', 'SampleTime',
56
- 'Output',
57
- 'Relu', 'ELU', 'Softmax', 'Sigmoid', 'Identity',
58
- 'Fir',
59
- 'Linear',
60
- 'NeuralODE',
61
- 'Add', 'Sum', 'Sub', 'Mul', 'Div', 'Pow', 'Neg', 'Sign',
62
- 'Sin', 'Cos', 'Tan', 'Cosh', 'Tanh', 'Sech',
63
- 'ParamFun',
64
- 'Fuzzify',
65
- 'Part', 'Select', 'Concatenate',
66
- 'SamplePart', 'SampleSelect',
67
- 'TimePart', 'TimeConcatenate',
68
- 'LocalModel',
69
- 'EquationLearner',
70
- 'Integrate', 'Differentiate',
71
- 'Interpolation',
72
- 'ForwardEuler', 'RK2', 'RK4',
73
- 'TextVisualizer', 'MPLVisualizer', 'MPLNotebookVisualizer',
74
- 'StandardExporter',
75
- 'SGD', 'Adam', 'Optimizer',
76
- 'init_negexp', 'init_lin', 'init_constant', 'init_exp',
67
+ "nnodely",
68
+ "Modely",
69
+ "clearNames",
70
+ "Input",
71
+ "Connect",
72
+ "ClosedLoop",
73
+ "Parameter",
74
+ "Constant",
75
+ "SampleTime",
76
+ "Output",
77
+ "Relu",
78
+ "ELU",
79
+ "Softmax",
80
+ "Sigmoid",
81
+ "Identity",
82
+ "Fir",
83
+ "Linear",
84
+ "NeuralODE",
85
+ "Add",
86
+ "Sum",
87
+ "Sub",
88
+ "Mul",
89
+ "Div",
90
+ "Pow",
91
+ "Neg",
92
+ "Sign",
93
+ "Sin",
94
+ "Cos",
95
+ "Tan",
96
+ "Cosh",
97
+ "Tanh",
98
+ "Sech",
99
+ "ParamFun",
100
+ "Fuzzify",
101
+ "Part",
102
+ "Select",
103
+ "Concatenate",
104
+ "SamplePart",
105
+ "SampleSelect",
106
+ "TimePart",
107
+ "TimeConcatenate",
108
+ "LocalModel",
109
+ "EquationLearner",
110
+ "Integrate",
111
+ "Differentiate",
112
+ "Interpolation",
113
+ "ForwardEuler",
114
+ "RK2",
115
+ "RK4",
116
+ "TextVisualizer",
117
+ "MPLVisualizer",
118
+ "MPLNotebookVisualizer",
119
+ "StandardExporter",
120
+ "SGD",
121
+ "Adam",
122
+ "Optimizer",
123
+ "init_negexp",
124
+ "init_lin",
125
+ "init_constant",
126
+ "init_exp",
77
127
  # Main nnodely classes
78
- '__version__'
128
+ "__version__",
79
129
  ]
nnodely/basic/loss.py CHANGED
@@ -2,26 +2,35 @@ import torch.nn as nn
2
2
  import torch
3
3
  from nnodely.support.utils import check
4
4
 
5
- available_losses = ['mse', 'rmse', 'mae', 'cross_entropy']
5
+ available_losses = ["mse", "rmse", "mae", "cross_entropy"]
6
+
6
7
 
7
8
  class CustomLoss(nn.Module):
8
- def __init__(self, loss_type='mse', **kwargs):
9
+ def __init__(self, loss_type="mse", **kwargs):
9
10
  super(CustomLoss, self).__init__()
10
- check(loss_type in available_losses, TypeError, f'The \"{loss_type}\" loss is not available. Possible losses are: {available_losses}.')
11
+ check(
12
+ loss_type in available_losses,
13
+ TypeError,
14
+ f'The "{loss_type}" loss is not available. Possible losses are: {available_losses}.',
15
+ )
11
16
  self.loss_type = loss_type
12
17
  self.loss = nn.MSELoss(**kwargs)
13
18
  if callable(loss_type):
14
19
  self.loss = loss_type
15
- elif self.loss_type == 'mae':
20
+ elif self.loss_type == "mae":
16
21
  self.loss = nn.L1Loss(**kwargs)
17
- elif self.loss_type == 'cross_entropy':
22
+ elif self.loss_type == "cross_entropy":
18
23
  self.loss = nn.CrossEntropyLoss(**kwargs)
19
-
24
+
20
25
  def forward(self, inA, inB):
21
- if self.loss_type == 'cross_entropy':
22
- inB = inB.squeeze().float() if inA.shape == inB.shape else inB.squeeze().long()
26
+ if self.loss_type == "cross_entropy":
27
+ inB = (
28
+ inB.squeeze().float()
29
+ if inA.shape == inB.shape
30
+ else inB.squeeze().long()
31
+ )
23
32
  inA = inA.squeeze()
24
- res = self.loss(inA,inB)
25
- if self.loss_type == 'rmse':
33
+ res = self.loss(inA, inB)
34
+ if self.loss_type == "rmse":
26
35
  res = torch.sqrt(res)
27
- return res
36
+ return res
nnodely/basic/model.py CHANGED
@@ -9,34 +9,46 @@ from itertools import product
9
9
  from nnodely.support.utils import TORCH_DTYPE
10
10
  from nnodely.support import initializer
11
11
 
12
+
12
13
  @torch.fx.wrap
13
14
  def update_state(data_in, rel):
14
- #virtual = torch.roll(data_in, shifts=-1, dims=1)
15
+ # virtual = torch.roll(data_in, shifts=-1, dims=1)
15
16
  max_dim = min(rel.size(1), data_in.size(1))
16
17
  data_out = data_in.clone()
17
18
  data_out[:, -max_dim:, :] = rel[:, -max_dim:, :]
18
19
  return data_out
19
20
 
21
+
20
22
  class Model(nn.Module):
21
23
  def __init__(self, model_def):
22
24
  super(Model, self).__init__()
23
25
  model_def = copy.deepcopy(model_def)
24
26
 
25
- self.states = {key: value for key, value in model_def['Inputs'].items() if ('closedLoop' in value.keys() or 'connect' in value.keys())}
27
+ self.states = {
28
+ key: value
29
+ for key, value in model_def["Inputs"].items()
30
+ if ("closedLoop" in value.keys() or "connect" in value.keys())
31
+ }
26
32
 
27
- self.inputs = model_def['Inputs']
28
- self.outputs = model_def['Outputs']
29
- self.relations = model_def['Relations']
30
- self.params = model_def['Parameters']
31
- self.constants = model_def['Constants']
32
- self.sample_time = model_def['Info']['SampleTime']
33
- self.functions = model_def['Functions']
33
+ self.inputs = model_def["Inputs"]
34
+ self.outputs = model_def["Outputs"]
35
+ self.relations = model_def["Relations"]
36
+ self.params = model_def["Parameters"]
37
+ self.constants = model_def["Constants"]
38
+ self.sample_time = model_def["Info"]["SampleTime"]
39
+ self.functions = model_def["Functions"]
34
40
 
35
- self.minimizers = model_def['Minimizers'] if 'Minimizers' in model_def else {}
36
- self.minimizers_keys = [self.minimizers[key]['A'] for key in self.minimizers] + [self.minimizers[key]['B'] for key in self.minimizers]
41
+ self.minimizers = model_def["Minimizers"] if "Minimizers" in model_def else {}
42
+ self.minimizers_keys = [
43
+ self.minimizers[key]["A"] for key in self.minimizers
44
+ ] + [self.minimizers[key]["B"] for key in self.minimizers]
37
45
 
38
- self.input_ns_backward = {key:value['ns'][0] for key, value in model_def['Inputs'].items()}
39
- self.input_n_samples = {key:value['ntot'] for key, value in model_def['Inputs'].items()}
46
+ self.input_ns_backward = {
47
+ key: value["ns"][0] for key, value in model_def["Inputs"].items()
48
+ }
49
+ self.input_n_samples = {
50
+ key: value["ntot"] for key, value in model_def["Inputs"].items()
51
+ }
40
52
 
41
53
  ## Build the network
42
54
  self.all_parameters = {}
@@ -51,55 +63,89 @@ class Model(nn.Module):
51
63
 
52
64
  ## Define the correct slicing
53
65
  for _, items in self.relations.items():
54
- if items[0] == 'SamplePart':
66
+ if items[0] == "SamplePart":
55
67
  if items[1][0] in self.inputs.keys():
56
68
  items[3][0] = self.input_ns_backward[items[1][0]] + items[3][0]
57
69
  items[3][1] = self.input_ns_backward[items[1][0]] + items[3][1]
58
- if len(items) > 4: ## Offset
70
+ if len(items) > 4: ## Offset
59
71
  items[4] = self.input_ns_backward[items[1][0]] + items[4]
60
- if items[0] == 'TimePart':
72
+ if items[0] == "TimePart":
61
73
  if items[1][0] in self.inputs.keys():
62
- items[3][0] = self.input_ns_backward[items[1][0]] + round(items[3][0]/self.sample_time)
63
- items[3][1] = self.input_ns_backward[items[1][0]] + round(items[3][1]/self.sample_time)
64
- if len(items) > 4: ## Offset
65
- items[4] = self.input_ns_backward[items[1][0]] + round(items[4]/self.sample_time)
74
+ items[3][0] = self.input_ns_backward[items[1][0]] + round(
75
+ items[3][0] / self.sample_time
76
+ )
77
+ items[3][1] = self.input_ns_backward[items[1][0]] + round(
78
+ items[3][1] / self.sample_time
79
+ )
80
+ if len(items) > 4: ## Offset
81
+ items[4] = self.input_ns_backward[items[1][0]] + round(
82
+ items[4] / self.sample_time
83
+ )
66
84
  else:
67
- items[3][0] = round(items[3][0]/self.sample_time)
68
- items[3][1] = round(items[3][1]/self.sample_time)
69
- if len(items) > 4: ## Offset
70
- items[4] = round(items[4]/self.sample_time)
85
+ items[3][0] = round(items[3][0] / self.sample_time)
86
+ items[3][1] = round(items[3][1] / self.sample_time)
87
+ if len(items) > 4: ## Offset
88
+ items[4] = round(items[4] / self.sample_time)
71
89
 
72
90
  ## Create all the parameters
73
91
  for name, param_data in self.params.items():
74
- window = 'tw' if 'tw' in param_data.keys() else ('sw' if 'sw' in param_data.keys() else None)
75
- aux_sample_time = self.sample_time if 'tw' == window else 1
76
- sample_window = round(param_data[window] / aux_sample_time) if window else None
92
+ window = (
93
+ "tw"
94
+ if "tw" in param_data.keys()
95
+ else ("sw" if "sw" in param_data.keys() else None)
96
+ )
97
+ aux_sample_time = self.sample_time if "tw" == window else 1
98
+ sample_window = (
99
+ round(param_data[window] / aux_sample_time) if window else None
100
+ )
77
101
  if sample_window is None:
78
- param_size = tuple(param_data['dim']) if type(param_data['dim']) is list else (param_data['dim'],)
102
+ param_size = (
103
+ tuple(param_data["dim"])
104
+ if type(param_data["dim"]) is list
105
+ else (param_data["dim"],)
106
+ )
79
107
  else:
80
- param_size = (sample_window,)+tuple(param_data['dim']) if type(param_data['dim']) is list else (sample_window, param_data['dim'])
81
- if 'values' in param_data:
82
- self.all_parameters[name] = nn.Parameter(torch.tensor(param_data['values'], dtype=TORCH_DTYPE), requires_grad=True)
108
+ param_size = (
109
+ (sample_window,) + tuple(param_data["dim"])
110
+ if type(param_data["dim"]) is list
111
+ else (sample_window, param_data["dim"])
112
+ )
113
+ if "values" in param_data:
114
+ self.all_parameters[name] = nn.Parameter(
115
+ torch.tensor(param_data["values"], dtype=TORCH_DTYPE),
116
+ requires_grad=True,
117
+ )
83
118
  # TODO clean code
84
- elif 'init_fun' in param_data:
85
- if 'code' in param_data['init_fun'].keys():
86
- exec(param_data['init_fun']['code'], globals())
87
- function_to_call = globals()[param_data['init_fun']['name']]
119
+ elif "init_fun" in param_data:
120
+ if "code" in param_data["init_fun"].keys():
121
+ exec(param_data["init_fun"]["code"], globals())
122
+ function_to_call = globals()[param_data["init_fun"]["name"]]
88
123
  else:
89
- function_to_call = getattr(initializer, param_data['init_fun']['name'])
124
+ function_to_call = getattr(
125
+ initializer, param_data["init_fun"]["name"]
126
+ )
90
127
  values = np.zeros(param_size)
91
128
  for indexes in product(*(range(v) for v in param_size)):
92
- if 'params' in param_data['init_fun']:
93
- values[indexes] = function_to_call(indexes, param_size, param_data['init_fun']['params'])
129
+ if "params" in param_data["init_fun"]:
130
+ values[indexes] = function_to_call(
131
+ indexes, param_size, param_data["init_fun"]["params"]
132
+ )
94
133
  else:
95
134
  values[indexes] = function_to_call(indexes, param_size)
96
- self.all_parameters[name] = nn.Parameter(torch.tensor(values.tolist(), dtype=TORCH_DTYPE), requires_grad=True)
135
+ self.all_parameters[name] = nn.Parameter(
136
+ torch.tensor(values.tolist(), dtype=TORCH_DTYPE), requires_grad=True
137
+ )
97
138
  else:
98
- self.all_parameters[name] = nn.Parameter(torch.rand(size=param_size, dtype=TORCH_DTYPE), requires_grad=True)
139
+ self.all_parameters[name] = nn.Parameter(
140
+ torch.rand(size=param_size, dtype=TORCH_DTYPE), requires_grad=True
141
+ )
99
142
 
100
143
  ## Create all the constants
101
144
  for name, param_data in self.constants.items():
102
- self.all_constants[name] = nn.Parameter(torch.tensor(param_data['values'], dtype=TORCH_DTYPE), requires_grad=False)
145
+ self.all_constants[name] = nn.Parameter(
146
+ torch.tensor(param_data["values"], dtype=TORCH_DTYPE),
147
+ requires_grad=False,
148
+ )
103
149
  all_params_and_consts = self.all_parameters | self.all_constants
104
150
 
105
151
  ## Create all the relations
@@ -107,27 +153,41 @@ class Model(nn.Module):
107
153
  ## Take the relation name and the inputs needed to solve the relation
108
154
  rel_name, input_var = inputs[0], inputs[1]
109
155
  ## Create All the Relations
110
- func = getattr(self,rel_name)
156
+ func = getattr(self, rel_name)
111
157
  if func:
112
158
  layer_inputs = []
113
159
  for item in inputs[2:]:
114
- if item in list(self.params.keys()): ## the relation takes parameters
160
+ if item in list(
161
+ self.params.keys()
162
+ ): ## the relation takes parameters
115
163
  layer_inputs.append(self.all_parameters[item])
116
- elif item in list(self.constants.keys()): ## the relation takes a constant
164
+ elif item in list(
165
+ self.constants.keys()
166
+ ): ## the relation takes a constant
117
167
  layer_inputs.append(self.all_constants[item])
118
- elif item in list(self.functions.keys()): ## the relation takes a custom function
168
+ elif item in list(
169
+ self.functions.keys()
170
+ ): ## the relation takes a custom function
119
171
  layer_inputs.append(self.functions[item])
120
- if 'params_and_consts' in self.functions[item].keys() and len(self.functions[item]['params_and_consts']) >= 0: ## Parametric function that takes parameters
121
- layer_inputs.append([all_params_and_consts[par] for par in self.functions[item]['params_and_consts']])
122
- if 'map_over_dim' in self.functions[item].keys():
123
- layer_inputs.append(self.functions[item]['map_over_dim'])
172
+ if (
173
+ "params_and_consts" in self.functions[item].keys()
174
+ and len(self.functions[item]["params_and_consts"]) >= 0
175
+ ): ## Parametric function that takes parameters
176
+ layer_inputs.append(
177
+ [
178
+ all_params_and_consts[par]
179
+ for par in self.functions[item]["params_and_consts"]
180
+ ]
181
+ )
182
+ if "map_over_dim" in self.functions[item].keys():
183
+ layer_inputs.append(self.functions[item]["map_over_dim"])
124
184
  else:
125
185
  layer_inputs.append(item)
126
186
 
127
- if rel_name == 'SamplePart':
187
+ if rel_name == "SamplePart":
128
188
  if layer_inputs[0] == -1:
129
189
  layer_inputs[0] = self.input_n_samples[input_var[0]]
130
- elif rel_name == 'TimePart':
190
+ elif rel_name == "TimePart":
131
191
  if layer_inputs[0] == -1:
132
192
  layer_inputs[0] = self.input_n_samples[input_var[0]]
133
193
  else:
@@ -145,60 +205,96 @@ class Model(nn.Module):
145
205
  self.network_output_predictions = set(self.outputs.values())
146
206
  ## list of network minimization outputs
147
207
  self.network_output_minimizers = []
148
- for _,value in self.minimizers.items():
149
- self.network_output_minimizers.append(self.outputs[value['A']]) if value['A'] in self.outputs.keys() else self.network_output_minimizers.append(value['A'])
150
- self.network_output_minimizers.append(self.outputs[value['B']]) if value['B'] in self.outputs.keys() else self.network_output_minimizers.append(value['B'])
208
+ for _, value in self.minimizers.items():
209
+ self.network_output_minimizers.append(self.outputs[value["A"]]) if value[
210
+ "A"
211
+ ] in self.outputs.keys() else self.network_output_minimizers.append(
212
+ value["A"]
213
+ )
214
+ self.network_output_minimizers.append(self.outputs[value["B"]]) if value[
215
+ "B"
216
+ ] in self.outputs.keys() else self.network_output_minimizers.append(
217
+ value["B"]
218
+ )
151
219
  self.network_output_minimizers = set(self.network_output_minimizers)
152
220
  ## list of all the network Outputs
153
- self.network_outputs = self.network_output_predictions.union(self.network_output_minimizers)
221
+ self.network_outputs = self.network_output_predictions.union(
222
+ self.network_output_minimizers
223
+ )
154
224
 
155
225
  def forward(self, kwargs):
156
226
  result_dict = {}
157
227
 
158
228
  ## Initially i have only the inputs from the dataset, the parameters, and the constants
159
- available_inputs = [key for key in self.inputs.keys() if key not in self.connect_update.keys()] ## remove connected inputs
160
- available_keys = set(available_inputs + list(self.all_parameters.keys()) + list(self.all_constants.keys()))
229
+ available_inputs = [
230
+ key for key in self.inputs.keys() if key not in self.connect_update.keys()
231
+ ] ## remove connected inputs
232
+ available_keys = set(
233
+ available_inputs
234
+ + list(self.all_parameters.keys())
235
+ + list(self.all_constants.keys())
236
+ )
161
237
 
162
238
  ## Forward pass through the relations
163
- while not self.network_outputs.issubset(available_keys): ## i need to climb the relation tree until i get all the outputs
239
+ while not self.network_outputs.issubset(
240
+ available_keys
241
+ ): ## i need to climb the relation tree until i get all the outputs
164
242
  for relation in self.relations.keys():
165
243
  ## if i have all the variables i can calculate the relation
166
- if set(self.relation_inputs[relation]).issubset(available_keys) and (relation not in available_keys):
244
+ if set(self.relation_inputs[relation]).issubset(available_keys) and (
245
+ relation not in available_keys
246
+ ):
167
247
  ## Collect all the necessary inputs for the relation
168
248
  layer_inputs = []
169
249
  for key in self.relation_inputs[relation]:
170
- if key in self.all_constants.keys(): ## relation that takes a constant
250
+ if (
251
+ key in self.all_constants.keys()
252
+ ): ## relation that takes a constant
171
253
  layer_inputs.append(self.all_constants[key])
172
254
  elif key in available_inputs: ## relation that takes inputs
173
255
  layer_inputs.append(kwargs[key])
174
- elif key in self.all_parameters.keys(): ## relation that takes parameters
256
+ elif (
257
+ key in self.all_parameters.keys()
258
+ ): ## relation that takes parameters
175
259
  layer_inputs.append(self.all_parameters[key])
176
- else: ## relation than takes another relation or a connect variable
260
+ else: ## relation than takes another relation or a connect variable
177
261
  layer_inputs.append(result_dict[key])
178
262
 
179
263
  ## Execute the current relation
180
- result_dict[relation] = self.relation_forward[relation](*layer_inputs)
264
+ result_dict[relation] = self.relation_forward[relation](
265
+ *layer_inputs
266
+ )
181
267
  available_keys.add(relation)
182
268
 
183
269
  ## Check if the relation is inside the connect
184
270
  for connect_input, connect_rel in self.connect_update.items():
185
271
  if relation == connect_rel:
186
- result_dict[connect_input] = update_state(kwargs[connect_input], result_dict[relation])
272
+ result_dict[connect_input] = update_state(
273
+ kwargs[connect_input], result_dict[relation]
274
+ )
187
275
  available_keys.add(connect_input)
188
276
 
189
277
  ## Return a dictionary with all the connected inputs
190
- connect_update_dict = {key: result_dict[key] for key in self.connect_update.keys()}
278
+ connect_update_dict = {
279
+ key: result_dict[key] for key in self.connect_update.keys()
280
+ }
191
281
  ## Return a dictionary with all the relations that updates the state variables
192
- closed_loop_update_dict = {key: result_dict[value] for key, value in self.closed_loop_update.items()}
282
+ closed_loop_update_dict = {
283
+ key: result_dict[value] for key, value in self.closed_loop_update.items()
284
+ }
193
285
  ## Return a dictionary with all the outputs final values
194
286
  output_dict = {key: result_dict[value] for key, value in self.outputs.items()}
195
287
  ## Return a dictionary with the minimization relations
196
288
  minimize_dict = {}
197
289
  for key in self.minimizers_keys:
198
- minimize_dict[key] = result_dict[self.outputs[key]] if key in self.outputs.keys() else result_dict[key]
290
+ minimize_dict[key] = (
291
+ result_dict[self.outputs[key]]
292
+ if key in self.outputs.keys()
293
+ else result_dict[key]
294
+ )
199
295
  return output_dict, minimize_dict, closed_loop_update_dict, connect_update_dict
200
296
 
201
- def update(self, *, closed_loop = {}, connect = {}, disconnect = False):
297
+ def update(self, *, closed_loop={}, connect={}, disconnect=False):
202
298
  self.closed_loop_update = {}
203
299
  self.connect_update = {}
204
300
 
@@ -206,15 +302,23 @@ class Model(nn.Module):
206
302
  return
207
303
 
208
304
  for key, state in self.states.items():
209
- if 'connect' in state.keys():
210
- self.connect_update[key] = state['connect']
211
- elif 'closedLoop' in state.keys():
212
- self.closed_loop_update[key] = state['closedLoop']
305
+ if "connect" in state.keys():
306
+ self.connect_update[key] = state["connect"]
307
+ elif "closedLoop" in state.keys():
308
+ self.closed_loop_update[key] = state["closedLoop"]
213
309
 
214
310
  # Get relation from outputs
215
311
  for connect_in, connect_rel in connect.items():
216
- set_relation = self.outputs[connect_rel] if connect_rel in self.outputs.keys() else connect_rel
312
+ set_relation = (
313
+ self.outputs[connect_rel]
314
+ if connect_rel in self.outputs.keys()
315
+ else connect_rel
316
+ )
217
317
  self.connect_update[connect_in] = set_relation
218
318
  for close_in, close_rel in closed_loop.items():
219
- set_relation = self.outputs[close_rel] if close_rel in self.outputs.keys() else close_rel
319
+ set_relation = (
320
+ self.outputs[close_rel]
321
+ if close_rel in self.outputs.keys()
322
+ else close_rel
323
+ )
220
324
  self.closed_loop_update[close_in] = set_relation