nnodely 1.5.4__py3-none-any.whl → 1.5.5.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nnodely/__init__.py +82 -32
- nnodely/basic/loss.py +20 -11
- nnodely/basic/model.py +178 -74
- nnodely/basic/modeldef.py +225 -110
- nnodely/basic/optimizer.py +42 -26
- nnodely/basic/relation.py +165 -69
- nnodely/exporter/emptyexporter.py +13 -11
- nnodely/exporter/export.py +281 -155
- nnodely/exporter/reporter.py +41 -15
- nnodely/exporter/standardexporter.py +126 -47
- nnodely/layers/activation.py +102 -62
- nnodely/layers/arithmetic.py +199 -124
- nnodely/layers/equationlearner.py +80 -15
- nnodely/layers/fir.py +110 -48
- nnodely/layers/fuzzify.py +149 -65
- nnodely/layers/input.py +145 -64
- nnodely/layers/interpolation.py +55 -25
- nnodely/layers/linear.py +105 -39
- nnodely/layers/localmodel.py +151 -64
- nnodely/layers/neuralODE.py +55 -37
- nnodely/layers/output.py +9 -6
- nnodely/layers/parameter.py +114 -51
- nnodely/layers/parametricfunction.py +209 -91
- nnodely/layers/part.py +254 -139
- nnodely/layers/rungekutta.py +30 -28
- nnodely/layers/timeoperation.py +55 -19
- nnodely/layers/trigonometric.py +121 -59
- nnodely/nnodely.py +110 -42
- nnodely/operators/composer.py +214 -81
- nnodely/operators/exporter.py +134 -49
- nnodely/operators/loader.py +132 -54
- nnodely/operators/network.py +428 -143
- nnodely/operators/trainer.py +308 -116
- nnodely/operators/validator.py +214 -97
- nnodely/support/earlystopping.py +28 -16
- nnodely/support/fixstepsolver.py +26 -13
- nnodely/support/initializer.py +60 -21
- nnodely/support/jsonutils.py +556 -260
- nnodely/support/logger.py +31 -20
- nnodely/support/mathutils.py +9 -2
- nnodely/support/odeint/__init__.py +0 -0
- nnodely/support/odeint/adjoint.py +406 -0
- nnodely/support/odeint/dopri5.py +60 -0
- nnodely/support/odeint/fixed_grid.py +18 -0
- nnodely/support/odeint/my_odeint.py +158 -0
- nnodely/support/odeint/rk_solvers.py +547 -0
- nnodely/support/odeint/solvers.py +233 -0
- nnodely/support/odeint/utils.py +279 -0
- nnodely/support/utils.py +48 -21
- nnodely/visualizer/__init__.py +1 -1
- nnodely/visualizer/dynamicmpl/functionplot.py +11 -11
- nnodely/visualizer/dynamicmpl/fuzzyplot.py +8 -7
- nnodely/visualizer/dynamicmpl/resultsplot.py +8 -7
- nnodely/visualizer/dynamicmpl/trainingplot.py +10 -7
- nnodely/visualizer/emptyvisualizer.py +7 -5
- nnodely/visualizer/mplnotebookvisualizer.py +63 -34
- nnodely/visualizer/mplvisualizer.py +152 -74
- nnodely/visualizer/textvisualizer.py +317 -143
- {nnodely-1.5.4.dist-info → nnodely-1.5.5.dev2.dist-info}/METADATA +39 -60
- nnodely-1.5.5.dev2.dist-info/RECORD +67 -0
- {nnodely-1.5.4.dist-info → nnodely-1.5.5.dev2.dist-info}/WHEEL +1 -2
- mplplots/__init__.py +0 -1
- mplplots/plots.py +0 -188
- nnodely-1.5.4.dist-info/RECORD +0 -62
- nnodely-1.5.4.dist-info/top_level.txt +0 -2
- {nnodely-1.5.4.dist-info → nnodely-1.5.5.dev2.dist-info}/licenses/LICENSE +0 -0
nnodely/__init__.py
CHANGED
|
@@ -14,7 +14,15 @@ from nnodely.layers.arithmetic import Add, Sum, Sub, Mul, Div, Pow, Neg, Sign
|
|
|
14
14
|
from nnodely.layers.trigonometric import Sin, Cos, Tan, Cosh, Tanh, Sech
|
|
15
15
|
from nnodely.layers.parametricfunction import ParamFun
|
|
16
16
|
from nnodely.layers.fuzzify import Fuzzify
|
|
17
|
-
from nnodely.layers.part import
|
|
17
|
+
from nnodely.layers.part import (
|
|
18
|
+
Part,
|
|
19
|
+
Select,
|
|
20
|
+
Concatenate,
|
|
21
|
+
SamplePart,
|
|
22
|
+
SampleSelect,
|
|
23
|
+
TimePart,
|
|
24
|
+
TimeConcatenate,
|
|
25
|
+
)
|
|
18
26
|
from nnodely.layers.localmodel import LocalModel
|
|
19
27
|
from nnodely.layers.equationlearner import EquationLearner
|
|
20
28
|
from nnodely.layers.timeoperation import Integrate, Differentiate
|
|
@@ -37,43 +45,85 @@ from nnodely.support import logger
|
|
|
37
45
|
major, minor = sys.version_info.major, sys.version_info.minor
|
|
38
46
|
logger.LOG_LEVEL = logging.INFO
|
|
39
47
|
|
|
40
|
-
__version__ =
|
|
48
|
+
__version__ = "1.5.4"
|
|
41
49
|
|
|
42
50
|
if major < 3:
|
|
43
|
-
sys.exit(
|
|
51
|
+
sys.exit(
|
|
52
|
+
"Sorry, Python 2 is not supported. You need Python >= 3.10 for "
|
|
53
|
+
+ __package__
|
|
54
|
+
+ "."
|
|
55
|
+
)
|
|
44
56
|
elif minor < 9:
|
|
45
|
-
sys.exit("Sorry, You need Python >= 3.10 for "+__package__+".")
|
|
57
|
+
sys.exit("Sorry, You need Python >= 3.10 for " + __package__ + ".")
|
|
46
58
|
else:
|
|
47
|
-
print(
|
|
48
|
-
|
|
49
|
-
|
|
59
|
+
print(
|
|
60
|
+
">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
|
|
61
|
+
+ f" {__package__}_v{__version__} ".center(20, "-")
|
|
62
|
+
+ "<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<"
|
|
63
|
+
)
|
|
50
64
|
|
|
51
65
|
|
|
52
66
|
__all__ = [
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
67
|
+
"nnodely",
|
|
68
|
+
"Modely",
|
|
69
|
+
"clearNames",
|
|
70
|
+
"Input",
|
|
71
|
+
"Connect",
|
|
72
|
+
"ClosedLoop",
|
|
73
|
+
"Parameter",
|
|
74
|
+
"Constant",
|
|
75
|
+
"SampleTime",
|
|
76
|
+
"Output",
|
|
77
|
+
"Relu",
|
|
78
|
+
"ELU",
|
|
79
|
+
"Softmax",
|
|
80
|
+
"Sigmoid",
|
|
81
|
+
"Identity",
|
|
82
|
+
"Fir",
|
|
83
|
+
"Linear",
|
|
84
|
+
"NeuralODE",
|
|
85
|
+
"Add",
|
|
86
|
+
"Sum",
|
|
87
|
+
"Sub",
|
|
88
|
+
"Mul",
|
|
89
|
+
"Div",
|
|
90
|
+
"Pow",
|
|
91
|
+
"Neg",
|
|
92
|
+
"Sign",
|
|
93
|
+
"Sin",
|
|
94
|
+
"Cos",
|
|
95
|
+
"Tan",
|
|
96
|
+
"Cosh",
|
|
97
|
+
"Tanh",
|
|
98
|
+
"Sech",
|
|
99
|
+
"ParamFun",
|
|
100
|
+
"Fuzzify",
|
|
101
|
+
"Part",
|
|
102
|
+
"Select",
|
|
103
|
+
"Concatenate",
|
|
104
|
+
"SamplePart",
|
|
105
|
+
"SampleSelect",
|
|
106
|
+
"TimePart",
|
|
107
|
+
"TimeConcatenate",
|
|
108
|
+
"LocalModel",
|
|
109
|
+
"EquationLearner",
|
|
110
|
+
"Integrate",
|
|
111
|
+
"Differentiate",
|
|
112
|
+
"Interpolation",
|
|
113
|
+
"ForwardEuler",
|
|
114
|
+
"RK2",
|
|
115
|
+
"RK4",
|
|
116
|
+
"TextVisualizer",
|
|
117
|
+
"MPLVisualizer",
|
|
118
|
+
"MPLNotebookVisualizer",
|
|
119
|
+
"StandardExporter",
|
|
120
|
+
"SGD",
|
|
121
|
+
"Adam",
|
|
122
|
+
"Optimizer",
|
|
123
|
+
"init_negexp",
|
|
124
|
+
"init_lin",
|
|
125
|
+
"init_constant",
|
|
126
|
+
"init_exp",
|
|
77
127
|
# Main nnodely classes
|
|
78
|
-
|
|
128
|
+
"__version__",
|
|
79
129
|
]
|
nnodely/basic/loss.py
CHANGED
|
@@ -2,26 +2,35 @@ import torch.nn as nn
|
|
|
2
2
|
import torch
|
|
3
3
|
from nnodely.support.utils import check
|
|
4
4
|
|
|
5
|
-
available_losses = [
|
|
5
|
+
available_losses = ["mse", "rmse", "mae", "cross_entropy"]
|
|
6
|
+
|
|
6
7
|
|
|
7
8
|
class CustomLoss(nn.Module):
|
|
8
|
-
def __init__(self, loss_type=
|
|
9
|
+
def __init__(self, loss_type="mse", **kwargs):
|
|
9
10
|
super(CustomLoss, self).__init__()
|
|
10
|
-
check(
|
|
11
|
+
check(
|
|
12
|
+
loss_type in available_losses,
|
|
13
|
+
TypeError,
|
|
14
|
+
f'The "{loss_type}" loss is not available. Possible losses are: {available_losses}.',
|
|
15
|
+
)
|
|
11
16
|
self.loss_type = loss_type
|
|
12
17
|
self.loss = nn.MSELoss(**kwargs)
|
|
13
18
|
if callable(loss_type):
|
|
14
19
|
self.loss = loss_type
|
|
15
|
-
elif self.loss_type ==
|
|
20
|
+
elif self.loss_type == "mae":
|
|
16
21
|
self.loss = nn.L1Loss(**kwargs)
|
|
17
|
-
elif self.loss_type ==
|
|
22
|
+
elif self.loss_type == "cross_entropy":
|
|
18
23
|
self.loss = nn.CrossEntropyLoss(**kwargs)
|
|
19
|
-
|
|
24
|
+
|
|
20
25
|
def forward(self, inA, inB):
|
|
21
|
-
if self.loss_type ==
|
|
22
|
-
inB =
|
|
26
|
+
if self.loss_type == "cross_entropy":
|
|
27
|
+
inB = (
|
|
28
|
+
inB.squeeze().float()
|
|
29
|
+
if inA.shape == inB.shape
|
|
30
|
+
else inB.squeeze().long()
|
|
31
|
+
)
|
|
23
32
|
inA = inA.squeeze()
|
|
24
|
-
res = self.loss(inA,inB)
|
|
25
|
-
if self.loss_type ==
|
|
33
|
+
res = self.loss(inA, inB)
|
|
34
|
+
if self.loss_type == "rmse":
|
|
26
35
|
res = torch.sqrt(res)
|
|
27
|
-
return res
|
|
36
|
+
return res
|
nnodely/basic/model.py
CHANGED
|
@@ -9,34 +9,46 @@ from itertools import product
|
|
|
9
9
|
from nnodely.support.utils import TORCH_DTYPE
|
|
10
10
|
from nnodely.support import initializer
|
|
11
11
|
|
|
12
|
+
|
|
12
13
|
@torch.fx.wrap
|
|
13
14
|
def update_state(data_in, rel):
|
|
14
|
-
#virtual = torch.roll(data_in, shifts=-1, dims=1)
|
|
15
|
+
# virtual = torch.roll(data_in, shifts=-1, dims=1)
|
|
15
16
|
max_dim = min(rel.size(1), data_in.size(1))
|
|
16
17
|
data_out = data_in.clone()
|
|
17
18
|
data_out[:, -max_dim:, :] = rel[:, -max_dim:, :]
|
|
18
19
|
return data_out
|
|
19
20
|
|
|
21
|
+
|
|
20
22
|
class Model(nn.Module):
|
|
21
23
|
def __init__(self, model_def):
|
|
22
24
|
super(Model, self).__init__()
|
|
23
25
|
model_def = copy.deepcopy(model_def)
|
|
24
26
|
|
|
25
|
-
self.states = {
|
|
27
|
+
self.states = {
|
|
28
|
+
key: value
|
|
29
|
+
for key, value in model_def["Inputs"].items()
|
|
30
|
+
if ("closedLoop" in value.keys() or "connect" in value.keys())
|
|
31
|
+
}
|
|
26
32
|
|
|
27
|
-
self.inputs = model_def[
|
|
28
|
-
self.outputs = model_def[
|
|
29
|
-
self.relations = model_def[
|
|
30
|
-
self.params = model_def[
|
|
31
|
-
self.constants = model_def[
|
|
32
|
-
self.sample_time = model_def[
|
|
33
|
-
self.functions = model_def[
|
|
33
|
+
self.inputs = model_def["Inputs"]
|
|
34
|
+
self.outputs = model_def["Outputs"]
|
|
35
|
+
self.relations = model_def["Relations"]
|
|
36
|
+
self.params = model_def["Parameters"]
|
|
37
|
+
self.constants = model_def["Constants"]
|
|
38
|
+
self.sample_time = model_def["Info"]["SampleTime"]
|
|
39
|
+
self.functions = model_def["Functions"]
|
|
34
40
|
|
|
35
|
-
self.minimizers = model_def[
|
|
36
|
-
self.minimizers_keys = [
|
|
41
|
+
self.minimizers = model_def["Minimizers"] if "Minimizers" in model_def else {}
|
|
42
|
+
self.minimizers_keys = [
|
|
43
|
+
self.minimizers[key]["A"] for key in self.minimizers
|
|
44
|
+
] + [self.minimizers[key]["B"] for key in self.minimizers]
|
|
37
45
|
|
|
38
|
-
self.input_ns_backward = {
|
|
39
|
-
|
|
46
|
+
self.input_ns_backward = {
|
|
47
|
+
key: value["ns"][0] for key, value in model_def["Inputs"].items()
|
|
48
|
+
}
|
|
49
|
+
self.input_n_samples = {
|
|
50
|
+
key: value["ntot"] for key, value in model_def["Inputs"].items()
|
|
51
|
+
}
|
|
40
52
|
|
|
41
53
|
## Build the network
|
|
42
54
|
self.all_parameters = {}
|
|
@@ -51,55 +63,89 @@ class Model(nn.Module):
|
|
|
51
63
|
|
|
52
64
|
## Define the correct slicing
|
|
53
65
|
for _, items in self.relations.items():
|
|
54
|
-
if items[0] ==
|
|
66
|
+
if items[0] == "SamplePart":
|
|
55
67
|
if items[1][0] in self.inputs.keys():
|
|
56
68
|
items[3][0] = self.input_ns_backward[items[1][0]] + items[3][0]
|
|
57
69
|
items[3][1] = self.input_ns_backward[items[1][0]] + items[3][1]
|
|
58
|
-
if len(items) > 4:
|
|
70
|
+
if len(items) > 4: ## Offset
|
|
59
71
|
items[4] = self.input_ns_backward[items[1][0]] + items[4]
|
|
60
|
-
if items[0] ==
|
|
72
|
+
if items[0] == "TimePart":
|
|
61
73
|
if items[1][0] in self.inputs.keys():
|
|
62
|
-
items[3][0] = self.input_ns_backward[items[1][0]] + round(
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
74
|
+
items[3][0] = self.input_ns_backward[items[1][0]] + round(
|
|
75
|
+
items[3][0] / self.sample_time
|
|
76
|
+
)
|
|
77
|
+
items[3][1] = self.input_ns_backward[items[1][0]] + round(
|
|
78
|
+
items[3][1] / self.sample_time
|
|
79
|
+
)
|
|
80
|
+
if len(items) > 4: ## Offset
|
|
81
|
+
items[4] = self.input_ns_backward[items[1][0]] + round(
|
|
82
|
+
items[4] / self.sample_time
|
|
83
|
+
)
|
|
66
84
|
else:
|
|
67
|
-
items[3][0] = round(items[3][0]/self.sample_time)
|
|
68
|
-
items[3][1] = round(items[3][1]/self.sample_time)
|
|
69
|
-
if len(items) > 4:
|
|
70
|
-
items[4] = round(items[4]/self.sample_time)
|
|
85
|
+
items[3][0] = round(items[3][0] / self.sample_time)
|
|
86
|
+
items[3][1] = round(items[3][1] / self.sample_time)
|
|
87
|
+
if len(items) > 4: ## Offset
|
|
88
|
+
items[4] = round(items[4] / self.sample_time)
|
|
71
89
|
|
|
72
90
|
## Create all the parameters
|
|
73
91
|
for name, param_data in self.params.items():
|
|
74
|
-
window =
|
|
75
|
-
|
|
76
|
-
|
|
92
|
+
window = (
|
|
93
|
+
"tw"
|
|
94
|
+
if "tw" in param_data.keys()
|
|
95
|
+
else ("sw" if "sw" in param_data.keys() else None)
|
|
96
|
+
)
|
|
97
|
+
aux_sample_time = self.sample_time if "tw" == window else 1
|
|
98
|
+
sample_window = (
|
|
99
|
+
round(param_data[window] / aux_sample_time) if window else None
|
|
100
|
+
)
|
|
77
101
|
if sample_window is None:
|
|
78
|
-
param_size =
|
|
102
|
+
param_size = (
|
|
103
|
+
tuple(param_data["dim"])
|
|
104
|
+
if type(param_data["dim"]) is list
|
|
105
|
+
else (param_data["dim"],)
|
|
106
|
+
)
|
|
79
107
|
else:
|
|
80
|
-
param_size = (
|
|
81
|
-
|
|
82
|
-
|
|
108
|
+
param_size = (
|
|
109
|
+
(sample_window,) + tuple(param_data["dim"])
|
|
110
|
+
if type(param_data["dim"]) is list
|
|
111
|
+
else (sample_window, param_data["dim"])
|
|
112
|
+
)
|
|
113
|
+
if "values" in param_data:
|
|
114
|
+
self.all_parameters[name] = nn.Parameter(
|
|
115
|
+
torch.tensor(param_data["values"], dtype=TORCH_DTYPE),
|
|
116
|
+
requires_grad=True,
|
|
117
|
+
)
|
|
83
118
|
# TODO clean code
|
|
84
|
-
elif
|
|
85
|
-
if
|
|
86
|
-
exec(param_data[
|
|
87
|
-
function_to_call = globals()[param_data[
|
|
119
|
+
elif "init_fun" in param_data:
|
|
120
|
+
if "code" in param_data["init_fun"].keys():
|
|
121
|
+
exec(param_data["init_fun"]["code"], globals())
|
|
122
|
+
function_to_call = globals()[param_data["init_fun"]["name"]]
|
|
88
123
|
else:
|
|
89
|
-
function_to_call = getattr(
|
|
124
|
+
function_to_call = getattr(
|
|
125
|
+
initializer, param_data["init_fun"]["name"]
|
|
126
|
+
)
|
|
90
127
|
values = np.zeros(param_size)
|
|
91
128
|
for indexes in product(*(range(v) for v in param_size)):
|
|
92
|
-
if
|
|
93
|
-
values[indexes] = function_to_call(
|
|
129
|
+
if "params" in param_data["init_fun"]:
|
|
130
|
+
values[indexes] = function_to_call(
|
|
131
|
+
indexes, param_size, param_data["init_fun"]["params"]
|
|
132
|
+
)
|
|
94
133
|
else:
|
|
95
134
|
values[indexes] = function_to_call(indexes, param_size)
|
|
96
|
-
self.all_parameters[name] = nn.Parameter(
|
|
135
|
+
self.all_parameters[name] = nn.Parameter(
|
|
136
|
+
torch.tensor(values.tolist(), dtype=TORCH_DTYPE), requires_grad=True
|
|
137
|
+
)
|
|
97
138
|
else:
|
|
98
|
-
self.all_parameters[name] = nn.Parameter(
|
|
139
|
+
self.all_parameters[name] = nn.Parameter(
|
|
140
|
+
torch.rand(size=param_size, dtype=TORCH_DTYPE), requires_grad=True
|
|
141
|
+
)
|
|
99
142
|
|
|
100
143
|
## Create all the constants
|
|
101
144
|
for name, param_data in self.constants.items():
|
|
102
|
-
self.all_constants[name] = nn.Parameter(
|
|
145
|
+
self.all_constants[name] = nn.Parameter(
|
|
146
|
+
torch.tensor(param_data["values"], dtype=TORCH_DTYPE),
|
|
147
|
+
requires_grad=False,
|
|
148
|
+
)
|
|
103
149
|
all_params_and_consts = self.all_parameters | self.all_constants
|
|
104
150
|
|
|
105
151
|
## Create all the relations
|
|
@@ -107,27 +153,41 @@ class Model(nn.Module):
|
|
|
107
153
|
## Take the relation name and the inputs needed to solve the relation
|
|
108
154
|
rel_name, input_var = inputs[0], inputs[1]
|
|
109
155
|
## Create All the Relations
|
|
110
|
-
func = getattr(self,rel_name)
|
|
156
|
+
func = getattr(self, rel_name)
|
|
111
157
|
if func:
|
|
112
158
|
layer_inputs = []
|
|
113
159
|
for item in inputs[2:]:
|
|
114
|
-
if item in list(
|
|
160
|
+
if item in list(
|
|
161
|
+
self.params.keys()
|
|
162
|
+
): ## the relation takes parameters
|
|
115
163
|
layer_inputs.append(self.all_parameters[item])
|
|
116
|
-
elif item in list(
|
|
164
|
+
elif item in list(
|
|
165
|
+
self.constants.keys()
|
|
166
|
+
): ## the relation takes a constant
|
|
117
167
|
layer_inputs.append(self.all_constants[item])
|
|
118
|
-
elif item in list(
|
|
168
|
+
elif item in list(
|
|
169
|
+
self.functions.keys()
|
|
170
|
+
): ## the relation takes a custom function
|
|
119
171
|
layer_inputs.append(self.functions[item])
|
|
120
|
-
if
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
172
|
+
if (
|
|
173
|
+
"params_and_consts" in self.functions[item].keys()
|
|
174
|
+
and len(self.functions[item]["params_and_consts"]) >= 0
|
|
175
|
+
): ## Parametric function that takes parameters
|
|
176
|
+
layer_inputs.append(
|
|
177
|
+
[
|
|
178
|
+
all_params_and_consts[par]
|
|
179
|
+
for par in self.functions[item]["params_and_consts"]
|
|
180
|
+
]
|
|
181
|
+
)
|
|
182
|
+
if "map_over_dim" in self.functions[item].keys():
|
|
183
|
+
layer_inputs.append(self.functions[item]["map_over_dim"])
|
|
124
184
|
else:
|
|
125
185
|
layer_inputs.append(item)
|
|
126
186
|
|
|
127
|
-
if rel_name ==
|
|
187
|
+
if rel_name == "SamplePart":
|
|
128
188
|
if layer_inputs[0] == -1:
|
|
129
189
|
layer_inputs[0] = self.input_n_samples[input_var[0]]
|
|
130
|
-
elif rel_name ==
|
|
190
|
+
elif rel_name == "TimePart":
|
|
131
191
|
if layer_inputs[0] == -1:
|
|
132
192
|
layer_inputs[0] = self.input_n_samples[input_var[0]]
|
|
133
193
|
else:
|
|
@@ -145,60 +205,96 @@ class Model(nn.Module):
|
|
|
145
205
|
self.network_output_predictions = set(self.outputs.values())
|
|
146
206
|
## list of network minimization outputs
|
|
147
207
|
self.network_output_minimizers = []
|
|
148
|
-
for _,value in self.minimizers.items():
|
|
149
|
-
self.network_output_minimizers.append(self.outputs[value[
|
|
150
|
-
|
|
208
|
+
for _, value in self.minimizers.items():
|
|
209
|
+
self.network_output_minimizers.append(self.outputs[value["A"]]) if value[
|
|
210
|
+
"A"
|
|
211
|
+
] in self.outputs.keys() else self.network_output_minimizers.append(
|
|
212
|
+
value["A"]
|
|
213
|
+
)
|
|
214
|
+
self.network_output_minimizers.append(self.outputs[value["B"]]) if value[
|
|
215
|
+
"B"
|
|
216
|
+
] in self.outputs.keys() else self.network_output_minimizers.append(
|
|
217
|
+
value["B"]
|
|
218
|
+
)
|
|
151
219
|
self.network_output_minimizers = set(self.network_output_minimizers)
|
|
152
220
|
## list of all the network Outputs
|
|
153
|
-
self.network_outputs = self.network_output_predictions.union(
|
|
221
|
+
self.network_outputs = self.network_output_predictions.union(
|
|
222
|
+
self.network_output_minimizers
|
|
223
|
+
)
|
|
154
224
|
|
|
155
225
|
def forward(self, kwargs):
|
|
156
226
|
result_dict = {}
|
|
157
227
|
|
|
158
228
|
## Initially i have only the inputs from the dataset, the parameters, and the constants
|
|
159
|
-
available_inputs = [
|
|
160
|
-
|
|
229
|
+
available_inputs = [
|
|
230
|
+
key for key in self.inputs.keys() if key not in self.connect_update.keys()
|
|
231
|
+
] ## remove connected inputs
|
|
232
|
+
available_keys = set(
|
|
233
|
+
available_inputs
|
|
234
|
+
+ list(self.all_parameters.keys())
|
|
235
|
+
+ list(self.all_constants.keys())
|
|
236
|
+
)
|
|
161
237
|
|
|
162
238
|
## Forward pass through the relations
|
|
163
|
-
while not self.network_outputs.issubset(
|
|
239
|
+
while not self.network_outputs.issubset(
|
|
240
|
+
available_keys
|
|
241
|
+
): ## i need to climb the relation tree until i get all the outputs
|
|
164
242
|
for relation in self.relations.keys():
|
|
165
243
|
## if i have all the variables i can calculate the relation
|
|
166
|
-
if set(self.relation_inputs[relation]).issubset(available_keys) and (
|
|
244
|
+
if set(self.relation_inputs[relation]).issubset(available_keys) and (
|
|
245
|
+
relation not in available_keys
|
|
246
|
+
):
|
|
167
247
|
## Collect all the necessary inputs for the relation
|
|
168
248
|
layer_inputs = []
|
|
169
249
|
for key in self.relation_inputs[relation]:
|
|
170
|
-
if
|
|
250
|
+
if (
|
|
251
|
+
key in self.all_constants.keys()
|
|
252
|
+
): ## relation that takes a constant
|
|
171
253
|
layer_inputs.append(self.all_constants[key])
|
|
172
254
|
elif key in available_inputs: ## relation that takes inputs
|
|
173
255
|
layer_inputs.append(kwargs[key])
|
|
174
|
-
elif
|
|
256
|
+
elif (
|
|
257
|
+
key in self.all_parameters.keys()
|
|
258
|
+
): ## relation that takes parameters
|
|
175
259
|
layer_inputs.append(self.all_parameters[key])
|
|
176
|
-
else:
|
|
260
|
+
else: ## relation than takes another relation or a connect variable
|
|
177
261
|
layer_inputs.append(result_dict[key])
|
|
178
262
|
|
|
179
263
|
## Execute the current relation
|
|
180
|
-
result_dict[relation] = self.relation_forward[relation](
|
|
264
|
+
result_dict[relation] = self.relation_forward[relation](
|
|
265
|
+
*layer_inputs
|
|
266
|
+
)
|
|
181
267
|
available_keys.add(relation)
|
|
182
268
|
|
|
183
269
|
## Check if the relation is inside the connect
|
|
184
270
|
for connect_input, connect_rel in self.connect_update.items():
|
|
185
271
|
if relation == connect_rel:
|
|
186
|
-
result_dict[connect_input] = update_state(
|
|
272
|
+
result_dict[connect_input] = update_state(
|
|
273
|
+
kwargs[connect_input], result_dict[relation]
|
|
274
|
+
)
|
|
187
275
|
available_keys.add(connect_input)
|
|
188
276
|
|
|
189
277
|
## Return a dictionary with all the connected inputs
|
|
190
|
-
connect_update_dict = {
|
|
278
|
+
connect_update_dict = {
|
|
279
|
+
key: result_dict[key] for key in self.connect_update.keys()
|
|
280
|
+
}
|
|
191
281
|
## Return a dictionary with all the relations that updates the state variables
|
|
192
|
-
closed_loop_update_dict = {
|
|
282
|
+
closed_loop_update_dict = {
|
|
283
|
+
key: result_dict[value] for key, value in self.closed_loop_update.items()
|
|
284
|
+
}
|
|
193
285
|
## Return a dictionary with all the outputs final values
|
|
194
286
|
output_dict = {key: result_dict[value] for key, value in self.outputs.items()}
|
|
195
287
|
## Return a dictionary with the minimization relations
|
|
196
288
|
minimize_dict = {}
|
|
197
289
|
for key in self.minimizers_keys:
|
|
198
|
-
minimize_dict[key] =
|
|
290
|
+
minimize_dict[key] = (
|
|
291
|
+
result_dict[self.outputs[key]]
|
|
292
|
+
if key in self.outputs.keys()
|
|
293
|
+
else result_dict[key]
|
|
294
|
+
)
|
|
199
295
|
return output_dict, minimize_dict, closed_loop_update_dict, connect_update_dict
|
|
200
296
|
|
|
201
|
-
def update(self, *, closed_loop
|
|
297
|
+
def update(self, *, closed_loop={}, connect={}, disconnect=False):
|
|
202
298
|
self.closed_loop_update = {}
|
|
203
299
|
self.connect_update = {}
|
|
204
300
|
|
|
@@ -206,15 +302,23 @@ class Model(nn.Module):
|
|
|
206
302
|
return
|
|
207
303
|
|
|
208
304
|
for key, state in self.states.items():
|
|
209
|
-
if
|
|
210
|
-
self.connect_update[key] = state[
|
|
211
|
-
elif
|
|
212
|
-
self.closed_loop_update[key] = state[
|
|
305
|
+
if "connect" in state.keys():
|
|
306
|
+
self.connect_update[key] = state["connect"]
|
|
307
|
+
elif "closedLoop" in state.keys():
|
|
308
|
+
self.closed_loop_update[key] = state["closedLoop"]
|
|
213
309
|
|
|
214
310
|
# Get relation from outputs
|
|
215
311
|
for connect_in, connect_rel in connect.items():
|
|
216
|
-
set_relation =
|
|
312
|
+
set_relation = (
|
|
313
|
+
self.outputs[connect_rel]
|
|
314
|
+
if connect_rel in self.outputs.keys()
|
|
315
|
+
else connect_rel
|
|
316
|
+
)
|
|
217
317
|
self.connect_update[connect_in] = set_relation
|
|
218
318
|
for close_in, close_rel in closed_loop.items():
|
|
219
|
-
set_relation =
|
|
319
|
+
set_relation = (
|
|
320
|
+
self.outputs[close_rel]
|
|
321
|
+
if close_rel in self.outputs.keys()
|
|
322
|
+
else close_rel
|
|
323
|
+
)
|
|
220
324
|
self.closed_loop_update[close_in] = set_relation
|