nnodely 1.5.4__tar.gz → 1.5.5.dev1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nnodely-1.5.4 → nnodely-1.5.5.dev1}/PKG-INFO +39 -60
- {nnodely-1.5.4 → nnodely-1.5.5.dev1}/README.md +29 -29
- nnodely-1.5.5.dev1/pyproject.toml +44 -0
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/__init__.py +82 -32
- nnodely-1.5.5.dev1/src/nnodely/basic/loss.py +36 -0
- nnodely-1.5.5.dev1/src/nnodely/basic/model.py +324 -0
- nnodely-1.5.5.dev1/src/nnodely/basic/modeldef.py +345 -0
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/basic/optimizer.py +42 -26
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/basic/relation.py +165 -69
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/exporter/emptyexporter.py +13 -11
- nnodely-1.5.5.dev1/src/nnodely/exporter/export.py +532 -0
- nnodely-1.5.5.dev1/src/nnodely/exporter/reporter.py +81 -0
- nnodely-1.5.5.dev1/src/nnodely/exporter/standardexporter.py +196 -0
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/layers/activation.py +102 -62
- nnodely-1.5.5.dev1/src/nnodely/layers/arithmetic.py +407 -0
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/layers/equationlearner.py +80 -15
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/layers/fir.py +110 -48
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/layers/fuzzify.py +149 -65
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/layers/input.py +145 -64
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/layers/interpolation.py +55 -25
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/layers/linear.py +105 -39
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/layers/localmodel.py +37 -17
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/layers/neuralODE.py +55 -37
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/layers/output.py +9 -6
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/layers/parameter.py +114 -51
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/layers/parametricfunction.py +209 -91
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/layers/part.py +254 -139
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/layers/rungekutta.py +30 -28
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/layers/timeoperation.py +55 -19
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/layers/trigonometric.py +121 -59
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/nnodely.py +110 -42
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/operators/composer.py +214 -81
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/operators/exporter.py +134 -49
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/operators/loader.py +132 -54
- nnodely-1.5.5.dev1/src/nnodely/operators/network.py +711 -0
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/operators/trainer.py +308 -116
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/operators/validator.py +214 -97
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/support/earlystopping.py +28 -16
- nnodely-1.5.5.dev1/src/nnodely/support/fixstepsolver.py +48 -0
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/support/initializer.py +60 -21
- nnodely-1.5.5.dev1/src/nnodely/support/jsonutils.py +678 -0
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/support/logger.py +31 -20
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/support/mathutils.py +9 -2
- nnodely-1.5.5.dev1/src/nnodely/support/odeint/__init__.py +0 -0
- nnodely-1.5.5.dev1/src/nnodely/support/odeint/adjoint.py +406 -0
- nnodely-1.5.5.dev1/src/nnodely/support/odeint/dopri5.py +60 -0
- nnodely-1.5.5.dev1/src/nnodely/support/odeint/fixed_grid.py +18 -0
- nnodely-1.5.5.dev1/src/nnodely/support/odeint/my_odeint.py +158 -0
- nnodely-1.5.5.dev1/src/nnodely/support/odeint/rk_solvers.py +547 -0
- nnodely-1.5.5.dev1/src/nnodely/support/odeint/solvers.py +233 -0
- nnodely-1.5.5.dev1/src/nnodely/support/odeint/utils.py +279 -0
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/support/utils.py +48 -21
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/visualizer/__init__.py +1 -1
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/visualizer/dynamicmpl/functionplot.py +11 -11
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/visualizer/dynamicmpl/fuzzyplot.py +8 -7
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/visualizer/dynamicmpl/resultsplot.py +8 -7
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/visualizer/dynamicmpl/trainingplot.py +10 -7
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/visualizer/emptyvisualizer.py +7 -5
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/visualizer/mplnotebookvisualizer.py +63 -34
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/visualizer/mplvisualizer.py +152 -74
- nnodely-1.5.5.dev1/src/nnodely/visualizer/textvisualizer.py +493 -0
- nnodely-1.5.4/mplplots/__init__.py +0 -1
- nnodely-1.5.4/mplplots/plots.py +0 -188
- nnodely-1.5.4/nnodely/basic/loss.py +0 -27
- nnodely-1.5.4/nnodely/basic/model.py +0 -220
- nnodely-1.5.4/nnodely/basic/modeldef.py +0 -230
- nnodely-1.5.4/nnodely/exporter/export.py +0 -406
- nnodely-1.5.4/nnodely/exporter/reporter.py +0 -55
- nnodely-1.5.4/nnodely/exporter/standardexporter.py +0 -117
- nnodely-1.5.4/nnodely/layers/arithmetic.py +0 -332
- nnodely-1.5.4/nnodely/operators/network.py +0 -426
- nnodely-1.5.4/nnodely/support/fixstepsolver.py +0 -35
- nnodely-1.5.4/nnodely/support/jsonutils.py +0 -459
- nnodely-1.5.4/nnodely/visualizer/textvisualizer.py +0 -319
- nnodely-1.5.4/nnodely.egg-info/PKG-INFO +0 -318
- nnodely-1.5.4/nnodely.egg-info/SOURCES.txt +0 -82
- nnodely-1.5.4/nnodely.egg-info/dependency_links.txt +0 -1
- nnodely-1.5.4/nnodely.egg-info/requires.txt +0 -14
- nnodely-1.5.4/nnodely.egg-info/top_level.txt +0 -2
- nnodely-1.5.4/pyproject.toml +0 -49
- nnodely-1.5.4/setup.cfg +0 -4
- nnodely-1.5.4/setup.py +0 -27
- nnodely-1.5.4/tests/test_dataset.py +0 -974
- nnodely-1.5.4/tests/test_documentation.py +0 -26
- nnodely-1.5.4/tests/test_export.py +0 -344
- nnodely-1.5.4/tests/test_export_recurrent.py +0 -1062
- nnodely-1.5.4/tests/test_input_dimensions.py +0 -438
- nnodely-1.5.4/tests/test_json.py +0 -741
- nnodely-1.5.4/tests/test_losses.py +0 -190
- nnodely-1.5.4/tests/test_model_predict.py +0 -1947
- nnodely-1.5.4/tests/test_model_predict_recurrent.py +0 -1595
- nnodely-1.5.4/tests/test_network_element.py +0 -516
- nnodely-1.5.4/tests/test_parameters_of_train.py +0 -1208
- nnodely-1.5.4/tests/test_results.py +0 -317
- nnodely-1.5.4/tests/test_train.py +0 -344
- nnodely-1.5.4/tests/test_train_recurrent.py +0 -1785
- nnodely-1.5.4/tests/test_utils.py +0 -25
- nnodely-1.5.4/tests/test_visualizer.py +0 -271
- {nnodely-1.5.4 → nnodely-1.5.5.dev1}/LICENSE +0 -0
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/basic/__init__.py +0 -0
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/exporter/__init__.py +0 -0
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/layers/__init__.py +0 -0
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/operators/__init__.py +0 -0
- {nnodely-1.5.4 → nnodely-1.5.5.dev1/src}/nnodely/support/__init__.py +0 -0
|
@@ -1,40 +1,16 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: nnodely
|
|
3
|
-
Version: 1.5.4
|
|
3
|
+
Version: 1.5.5.dev1
|
|
4
4
|
Summary: Model-structured neural network framework for the modeling and control of physical systems
|
|
5
|
+
Author: Gastone Pietro Rosati Papini
|
|
5
6
|
Author-email: Gastone Pietro Rosati Papini <tonegas@gmail.com>
|
|
6
|
-
License: MIT
|
|
7
|
-
|
|
8
|
-
Copyright (c) 2024 Gastone Pietro Rosati Papini
|
|
9
|
-
|
|
10
|
-
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
11
|
-
of this software and associated documentation files (the "Software"), to deal
|
|
12
|
-
in the Software without restriction, including without limitation the rights
|
|
13
|
-
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
14
|
-
copies of the Software, and to permit persons to whom the Software is
|
|
15
|
-
furnished to do so, subject to the following conditions:
|
|
16
|
-
|
|
17
|
-
The above copyright notice and this permission notice shall be included in all
|
|
18
|
-
copies or substantial portions of the Software.
|
|
19
|
-
|
|
20
|
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
21
|
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
22
|
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
23
|
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
24
|
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
25
|
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
26
|
-
SOFTWARE.
|
|
27
|
-
|
|
28
|
-
Project-URL: Homepage, https://github.com/tonegas/nnodely
|
|
7
|
+
License-Expression: MIT
|
|
8
|
+
License-File: LICENSE
|
|
29
9
|
Classifier: Programming Language :: Python :: 3
|
|
30
|
-
Classifier: License :: OSI Approved :: MIT License
|
|
31
10
|
Classifier: Operating System :: OS Independent
|
|
32
|
-
Requires-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
Requires-Dist: numpy==1.26.4; platform_machine == "x86_64" and python_version == "3.10"
|
|
36
|
-
Requires-Dist: torch==2.2.2; platform_machine == "x86_64" and python_version == "3.10"
|
|
37
|
-
Requires-Dist: torch==2.6.0; platform_machine != "x86_64" or python_version != "3.10"
|
|
11
|
+
Requires-Dist: numpy==1.26.4 ; python_full_version == '3.10.*' and platform_machine == 'x86_64'
|
|
12
|
+
Requires-Dist: torch==2.2.2 ; python_full_version == '3.10.*' and platform_machine == 'x86_64'
|
|
13
|
+
Requires-Dist: torch==2.6.0 ; python_full_version != '3.10.*' or platform_machine != 'x86_64'
|
|
38
14
|
Requires-Dist: numpy
|
|
39
15
|
Requires-Dist: onnx
|
|
40
16
|
Requires-Dist: pandas
|
|
@@ -42,7 +18,10 @@ Requires-Dist: reportlab
|
|
|
42
18
|
Requires-Dist: matplotlib
|
|
43
19
|
Requires-Dist: onnxruntime
|
|
44
20
|
Requires-Dist: graphviz
|
|
45
|
-
|
|
21
|
+
Requires-Python: >=3.10, <3.14
|
|
22
|
+
Project-URL: Homepage, https://github.com/tonegas/nnodely
|
|
23
|
+
Project-URL: Repository, https://github.com/tonegas/nnodely
|
|
24
|
+
Description-Content-Type: text/markdown
|
|
46
25
|
|
|
47
26
|
<a name="readme-top"></a>
|
|
48
27
|
<p align="center">
|
|
@@ -54,16 +33,16 @@ Dynamic: license-file
|
|
|
54
33
|
[](https://nnodely.readthedocs.io/)
|
|
55
34
|
[](https://pypi.org/project/nnodely/)
|
|
56
35
|
|
|
57
|
-
# Neural Network
|
|
36
|
+
# Neural Network Framework for Modelling, Control, and Estimation of Physical Systems
|
|
58
37
|
|
|
59
38
|
Modeling, control, and estimation of physical systems are central to many engineering disciplines. While data-driven methods like neural networks offer powerful tools, they often struggle to **incorporate prior domain knowledge**, limiting their interpretability, generalizability, and safety.
|
|
60
39
|
|
|
61
|
-
To bridge this gap, we present ***nnodely*** (where "nn" can be read as "m," forming *Modely*) — a framework that facilitates the creation and deployment of **Model-Structured Neural Networks** (**MS-NNs**).
|
|
40
|
+
To bridge this gap, we present ***nnodely*** (where "nn" can be read as "m," forming *Modely*) — a framework that facilitates the creation and deployment of **Model-Structured Neural Networks** (**MS-NNs**).
|
|
62
41
|
MS-NNs combine the learning capabilities of neural networks with structural **priors** grounded in **physics, control, and estimation theory**, enabling:
|
|
63
42
|
|
|
64
|
-
- **Reduced training data** requirements
|
|
65
|
-
- **Generalization** to unseen scenarios
|
|
66
|
-
- **Real-time** deployment in real-world applications
|
|
43
|
+
- **Reduced training data** requirements
|
|
44
|
+
- **Generalization** to unseen scenarios
|
|
45
|
+
- **Real-time** deployment in real-world applications
|
|
67
46
|
|
|
68
47
|
In short:
|
|
69
48
|
|
|
@@ -172,10 +151,10 @@ The `nnodely` main class defined in __nnodely.py__, it contains all the main pro
|
|
|
172
151
|
2. __loader.py__ contains the function for managing the dataset, the main function is `dataLoad`.
|
|
173
152
|
3. __trainer.py__ contains the function for training the network as the `trainModel`.
|
|
174
153
|
4. __exporter.py__ contains all the functions for import and export: `saveModel`, `loadModel`, `exportONNX`, etc.
|
|
175
|
-
5. __validator.py__ contains all the function for validate the model and the `resultsAnalysis`.
|
|
154
|
+
5. __validator.py__ contains all the function for validate the model and the `resultsAnalysis`.
|
|
176
155
|
6. All the operators derive from `Network` defined in __network.py__, that contains the shared support functions for all the operators.
|
|
177
156
|
|
|
178
|
-
The folder `basic/` contains the main classes for the low level functionalities:
|
|
157
|
+
The folder `basic/` contains the main classes for the low level functionalities:
|
|
179
158
|
1. __model.py__ contains the PyTorch template model for the structured network.
|
|
180
159
|
2. __modeldef.py__ contains the operations for working with the JSON model definition.
|
|
181
160
|
3. __loss.py__ contains the loss functions.
|
|
@@ -199,14 +178,14 @@ The main basic layers without parameters are:
|
|
|
199
178
|
2. __arithmetic.py__ this file contains the arithmetic functions such as: +, -, /, *, **.
|
|
200
179
|
3. __trigonometric.py__ this file contains all the trigonometric functions.
|
|
201
180
|
4. __part.py__ are used for selecting part of the data.
|
|
202
|
-
5. __fuzzify.py__ contains the operation for the fuzzification of a variable,
|
|
181
|
+
5. __fuzzify.py__ contains the operation for the fuzzification of a variable,
|
|
203
182
|
commonly used in the local model as an activation function, as in [[1]](#1) with rectangular activation functions or in [[3]](#3), [[4]](#4) and [[5]](#5) with triangular activation functions.
|
|
204
183
|
Using fuzzification it is also possible to create a channel coding as presented in [[2]](#2).
|
|
205
184
|
|
|
206
185
|
The main basic layers with parameters are:
|
|
207
|
-
1. __fir.py__ this file contains the finite impulse response filter function. It is a linear operation on the time dimension (second dimension).
|
|
186
|
+
1. __fir.py__ this file contains the finite impulse response filter function. It is a linear operation on the time dimension (second dimension).
|
|
208
187
|
This filter was introduced in [[1]](#1).
|
|
209
|
-
2. __linear.py__ this file contains the linear function. Typical Linear operation `W*x+b` operated on the space dimension (third dimension).
|
|
188
|
+
2. __linear.py__ this file contains the linear function. Typical Linear operation `W*x+b` operated on the space dimension (third dimension).
|
|
210
189
|
This operation is presented in [[1]](#1).
|
|
211
190
|
3. __localmodel.py__ this file contains the logic for build a local model. This operation is presented in [[1]](#1), [[3]](#3), [[4]](#4) and [[5]](#5).
|
|
212
191
|
4. __parametricfunction.py__ are the user custom function. The function can use the pytorch syntax. A parametric function is presented in [[3]](#3), [[4]](#4), [[5]](#5).
|
|
@@ -242,8 +221,8 @@ This folder contains the images used in the documentation.
|
|
|
242
221
|
|
|
243
222
|
To contribute to the nnodely framework, you can:
|
|
244
223
|
|
|
245
|
-
- Open a pull request if you have a new feature or bug fix.
|
|
246
|
-
- Open an issue if you have a question or suggestion.
|
|
224
|
+
- Open a pull request if you have a new feature or bug fix.
|
|
225
|
+
- Open an issue if you have a question or suggestion.
|
|
247
226
|
|
|
248
227
|
We welcome contributions and collaborations.
|
|
249
228
|
|
|
@@ -258,53 +237,53 @@ This project is released under the license [License: MIT](https://opensource.org
|
|
|
258
237
|
<a name="references"></a>
|
|
259
238
|
## References
|
|
260
239
|
|
|
261
|
-
<a id="1">[1]</a>
|
|
262
|
-
Mauro Da Lio, Daniele Bortoluzzi, Gastone Pietro Rosati Papini. (2019).
|
|
263
|
-
Modelling longitudinal vehicle dynamics with neural networks.
|
|
240
|
+
<a id="1">[1]</a>
|
|
241
|
+
Mauro Da Lio, Daniele Bortoluzzi, Gastone Pietro Rosati Papini. (2019).
|
|
242
|
+
Modelling longitudinal vehicle dynamics with neural networks.
|
|
264
243
|
Vehicle System Dynamics. https://doi.org/10.1080/00423114.2019.1638947 (look the [[code]](https://github.com/tonegas/nnodely-applications/blob/main/vehicle/model_longit_vehicle_dynamics/model_longit_vehicle_dynamics.py))
|
|
265
244
|
|
|
266
|
-
<a id="2">[2]</a>
|
|
267
|
-
Alice Plebe, Mauro Da Lio, Daniele Bortoluzzi. (2019).
|
|
268
|
-
On Reliable Neural Network Sensorimotor Control in Autonomous Vehicles.
|
|
245
|
+
<a id="2">[2]</a>
|
|
246
|
+
Alice Plebe, Mauro Da Lio, Daniele Bortoluzzi. (2019).
|
|
247
|
+
On Reliable Neural Network Sensorimotor Control in Autonomous Vehicles.
|
|
269
248
|
IEEE Transactions on Intelligent Transportation Systems. https://doi.org/10.1109/TITS.2019.2896375
|
|
270
249
|
|
|
271
|
-
<a id="3">[3]</a>
|
|
272
|
-
Mauro Da Lio, Riccardo Donà, Gastone Pietro Rosati Papini, Francesco Biral, Henrik Svensson. (2020).
|
|
250
|
+
<a id="3">[3]</a>
|
|
251
|
+
Mauro Da Lio, Riccardo Donà, Gastone Pietro Rosati Papini, Francesco Biral, Henrik Svensson. (2020).
|
|
273
252
|
A Mental Simulation Approach for Learning Neural-Network Predictive Control (in Self-Driving Cars).
|
|
274
253
|
IEEE Access. https://doi.org/10.1109/ACCESS.2020.3032780 (look the [[code]](https://github.com/tonegas/nnodely-applications/blob/main/vehicle/model_lateral_vehicle_dynamics/model_lateral_vehicle_dynamics.ipynb))
|
|
275
254
|
|
|
276
|
-
<a id="4">[4]</a>
|
|
277
|
-
Edoardo Pagot, Mattia Piccinini, Enrico Bertolazzi, Francesco Biral. (2023).
|
|
255
|
+
<a id="4">[4]</a>
|
|
256
|
+
Edoardo Pagot, Mattia Piccinini, Enrico Bertolazzi, Francesco Biral. (2023).
|
|
278
257
|
Fast Planning and Tracking of Complex Autonomous Parking Maneuvers With Optimal Control and Pseudo-Neural Networks.
|
|
279
258
|
IEEE Access. https://doi.org/10.1109/ACCESS.2023.3330431 (look the [[code]](https://github.com/tonegas/nnodely-applications/blob/main/vehicle/control_steer_car_parking/control_steer_car_parking.ipynb))
|
|
280
259
|
|
|
281
|
-
<a id="5">[5]</a>
|
|
260
|
+
<a id="5">[5]</a>
|
|
282
261
|
Mattia Piccinini, Sebastiano Taddei, Matteo Larcher, Mattia Piazza, Francesco Biral. (2023).
|
|
283
262
|
A Physics-Driven Artificial Agent for Online Time-Optimal Vehicle Motion Planning and Control.
|
|
284
263
|
IEEE Access. https://doi.org/10.1109/ACCESS.2023.3274836 (look [[code basic]](https://github.com/tonegas/nnodely-applications/blob/main/vehicle/control_steer_artificial_race_driver/control_steer_artificial_race_driver.ipynb)
|
|
285
264
|
and [[code extended]](https://github.com/tonegas/nnodely-applications/blob/main/vehicle/control_steer_artificial_race_driver_extended/control_steer_artificial_race_driver_extended.ipynb))
|
|
286
265
|
|
|
287
|
-
<a id="6">[6]</a>
|
|
266
|
+
<a id="6">[6]</a>
|
|
288
267
|
Hector Perez-Villeda, Justus Piater, Matteo Saveriano. (2023).
|
|
289
268
|
Learning and extrapolation of robotic skills using task-parameterized equation learner networks.
|
|
290
269
|
Robotics and Autonomous Systems. https://doi.org/10.1016/j.robot.2022.104309 (look the [[code]](https://github.com/tonegas/nnodely-applications/blob/main/equation_learner/equation_learner.ipynb))
|
|
291
270
|
|
|
292
|
-
<a id="7">[7]</a>
|
|
271
|
+
<a id="7">[7]</a>
|
|
293
272
|
M. Raissi, P. Perdikaris, G.E. Karniadakis. (2019).
|
|
294
273
|
Physics-informed neural networks: A deep learning framework for solving forward and inverse problems involving nonlinear partial differential equations
|
|
295
274
|
Journal of Computational Physics. https://doi.org/10.1016/j.jcp.2018.10.045 (look the [[example Burger's equation]](https://github.com/tonegas/nnodely-applications/blob/main/pinn/pinn_Burgers_equation.ipynb))
|
|
296
275
|
|
|
297
|
-
<a id="8">[8]</a>
|
|
276
|
+
<a id="8">[8]</a>
|
|
298
277
|
Wojciech Marian Czarnecki, Simon Osindero, Max Jaderberg, Grzegorz Świrszcz, Razvan Pascanu. (2017).
|
|
299
278
|
Sobolev Training for Neural Networks.
|
|
300
279
|
arXiv. https://doi.org/10.48550/arXiv.1706.04859 (look the [[code]](https://github.com/tonegas/nnodely-applications/blob/main/sobolev/Sobolev_learning.ipynb))
|
|
301
280
|
|
|
302
|
-
<a id="9">[9]</a>
|
|
281
|
+
<a id="9">[9]</a>
|
|
303
282
|
Mattia Piccinini, Matteo Zumerle, Johannes Betz, Gastone Pietro Rosati Papini. (2025).
|
|
304
283
|
A Road Friction-Aware Anti-Lock Braking System Based on Model-Structured Neural Networks.
|
|
305
284
|
IEEE Open Journal of Intelligent Transportation Systems. https://doi.org/10.1109/OJITS.2025.3563347 (look at the [[code]](https://github.com/tonegas/nnodely-applications/tree/main/vehicle/road_friction_aware_ABS))
|
|
306
285
|
|
|
307
|
-
<a id="10">[10]</a>
|
|
286
|
+
<a id="10">[10]</a>
|
|
308
287
|
Mauro Da Lio, Mattia Piccinini, Francesco Biral. (2023).
|
|
309
288
|
Robust and Sample-Efficient Estimation of Vehicle Lateral Velocity Using Neural Networks With Explainable Structure Informed by Kinematic Principles.
|
|
310
289
|
IEEE Transactions on Intelligent Transportation Systems. https://doi.org/10.1109/TITS.2023.3303776
|
|
@@ -8,16 +8,16 @@
|
|
|
8
8
|
[](https://nnodely.readthedocs.io/)
|
|
9
9
|
[](https://pypi.org/project/nnodely/)
|
|
10
10
|
|
|
11
|
-
# Neural Network
|
|
11
|
+
# Neural Network Framework for Modelling, Control, and Estimation of Physical Systems
|
|
12
12
|
|
|
13
13
|
Modeling, control, and estimation of physical systems are central to many engineering disciplines. While data-driven methods like neural networks offer powerful tools, they often struggle to **incorporate prior domain knowledge**, limiting their interpretability, generalizability, and safety.
|
|
14
14
|
|
|
15
|
-
To bridge this gap, we present ***nnodely*** (where "nn" can be read as "m," forming *Modely*) — a framework that facilitates the creation and deployment of **Model-Structured Neural Networks** (**MS-NNs**).
|
|
15
|
+
To bridge this gap, we present ***nnodely*** (where "nn" can be read as "m," forming *Modely*) — a framework that facilitates the creation and deployment of **Model-Structured Neural Networks** (**MS-NNs**).
|
|
16
16
|
MS-NNs combine the learning capabilities of neural networks with structural **priors** grounded in **physics, control, and estimation theory**, enabling:
|
|
17
17
|
|
|
18
|
-
- **Reduced training data** requirements
|
|
19
|
-
- **Generalization** to unseen scenarios
|
|
20
|
-
- **Real-time** deployment in real-world applications
|
|
18
|
+
- **Reduced training data** requirements
|
|
19
|
+
- **Generalization** to unseen scenarios
|
|
20
|
+
- **Real-time** deployment in real-world applications
|
|
21
21
|
|
|
22
22
|
In short:
|
|
23
23
|
|
|
@@ -126,10 +126,10 @@ The `nnodely` main class defined in __nnodely.py__, it contains all the main pro
|
|
|
126
126
|
2. __loader.py__ contains the function for managing the dataset, the main function is `dataLoad`.
|
|
127
127
|
3. __trainer.py__ contains the function for training the network as the `trainModel`.
|
|
128
128
|
4. __exporter.py__ contains all the functions for import and export: `saveModel`, `loadModel`, `exportONNX`, etc.
|
|
129
|
-
5. __validator.py__ contains all the function for validate the model and the `resultsAnalysis`.
|
|
129
|
+
5. __validator.py__ contains all the function for validate the model and the `resultsAnalysis`.
|
|
130
130
|
6. All the operators derive from `Network` defined in __network.py__, that contains the shared support functions for all the operators.
|
|
131
131
|
|
|
132
|
-
The folder `basic/` contains the main classes for the low level functionalities:
|
|
132
|
+
The folder `basic/` contains the main classes for the low level functionalities:
|
|
133
133
|
1. __model.py__ contains the PyTorch template model for the structured network.
|
|
134
134
|
2. __modeldef.py__ contains the operations for working with the JSON model definition.
|
|
135
135
|
3. __loss.py__ contains the loss functions.
|
|
@@ -153,14 +153,14 @@ The main basic layers without parameters are:
|
|
|
153
153
|
2. __arithmetic.py__ this file contains the arithmetic functions such as: +, -, /, *, **.
|
|
154
154
|
3. __trigonometric.py__ this file contains all the trigonometric functions.
|
|
155
155
|
4. __part.py__ are used for selecting part of the data.
|
|
156
|
-
5. __fuzzify.py__ contains the operation for the fuzzification of a variable,
|
|
156
|
+
5. __fuzzify.py__ contains the operation for the fuzzification of a variable,
|
|
157
157
|
commonly used in the local model as an activation function, as in [[1]](#1) with rectangular activation functions or in [[3]](#3), [[4]](#4) and [[5]](#5) with triangular activation functions.
|
|
158
158
|
Using fuzzification it is also possible to create a channel coding as presented in [[2]](#2).
|
|
159
159
|
|
|
160
160
|
The main basic layers with parameters are:
|
|
161
|
-
1. __fir.py__ this file contains the finite impulse response filter function. It is a linear operation on the time dimension (second dimension).
|
|
161
|
+
1. __fir.py__ this file contains the finite impulse response filter function. It is a linear operation on the time dimension (second dimension).
|
|
162
162
|
This filter was introduced in [[1]](#1).
|
|
163
|
-
2. __linear.py__ this file contains the linear function. Typical Linear operation `W*x+b` operated on the space dimension (third dimension).
|
|
163
|
+
2. __linear.py__ this file contains the linear function. Typical Linear operation `W*x+b` operated on the space dimension (third dimension).
|
|
164
164
|
This operation is presented in [[1]](#1).
|
|
165
165
|
3. __localmodel.py__ this file contains the logic for build a local model. This operation is presented in [[1]](#1), [[3]](#3), [[4]](#4) and [[5]](#5).
|
|
166
166
|
4. __parametricfunction.py__ are the user custom function. The function can use the pytorch syntax. A parametric function is presented in [[3]](#3), [[4]](#4), [[5]](#5).
|
|
@@ -196,8 +196,8 @@ This folder contains the images used in the documentation.
|
|
|
196
196
|
|
|
197
197
|
To contribute to the nnodely framework, you can:
|
|
198
198
|
|
|
199
|
-
- Open a pull request if you have a new feature or bug fix.
|
|
200
|
-
- Open an issue if you have a question or suggestion.
|
|
199
|
+
- Open a pull request if you have a new feature or bug fix.
|
|
200
|
+
- Open an issue if you have a question or suggestion.
|
|
201
201
|
|
|
202
202
|
We welcome contributions and collaborations.
|
|
203
203
|
|
|
@@ -212,53 +212,53 @@ This project is released under the license [License: MIT](https://opensource.org
|
|
|
212
212
|
<a name="references"></a>
|
|
213
213
|
## References
|
|
214
214
|
|
|
215
|
-
<a id="1">[1]</a>
|
|
216
|
-
Mauro Da Lio, Daniele Bortoluzzi, Gastone Pietro Rosati Papini. (2019).
|
|
217
|
-
Modelling longitudinal vehicle dynamics with neural networks.
|
|
215
|
+
<a id="1">[1]</a>
|
|
216
|
+
Mauro Da Lio, Daniele Bortoluzzi, Gastone Pietro Rosati Papini. (2019).
|
|
217
|
+
Modelling longitudinal vehicle dynamics with neural networks.
|
|
218
218
|
Vehicle System Dynamics. https://doi.org/10.1080/00423114.2019.1638947 (look the [[code]](https://github.com/tonegas/nnodely-applications/blob/main/vehicle/model_longit_vehicle_dynamics/model_longit_vehicle_dynamics.py))
|
|
219
219
|
|
|
220
|
-
<a id="2">[2]</a>
|
|
221
|
-
Alice Plebe, Mauro Da Lio, Daniele Bortoluzzi. (2019).
|
|
222
|
-
On Reliable Neural Network Sensorimotor Control in Autonomous Vehicles.
|
|
220
|
+
<a id="2">[2]</a>
|
|
221
|
+
Alice Plebe, Mauro Da Lio, Daniele Bortoluzzi. (2019).
|
|
222
|
+
On Reliable Neural Network Sensorimotor Control in Autonomous Vehicles.
|
|
223
223
|
IEEE Transactions on Intelligent Transportation Systems. https://doi.org/10.1109/TITS.2019.2896375
|
|
224
224
|
|
|
225
|
-
<a id="3">[3]</a>
|
|
226
|
-
Mauro Da Lio, Riccardo Donà, Gastone Pietro Rosati Papini, Francesco Biral, Henrik Svensson. (2020).
|
|
225
|
+
<a id="3">[3]</a>
|
|
226
|
+
Mauro Da Lio, Riccardo Donà, Gastone Pietro Rosati Papini, Francesco Biral, Henrik Svensson. (2020).
|
|
227
227
|
A Mental Simulation Approach for Learning Neural-Network Predictive Control (in Self-Driving Cars).
|
|
228
228
|
IEEE Access. https://doi.org/10.1109/ACCESS.2020.3032780 (look the [[code]](https://github.com/tonegas/nnodely-applications/blob/main/vehicle/model_lateral_vehicle_dynamics/model_lateral_vehicle_dynamics.ipynb))
|
|
229
229
|
|
|
230
|
-
<a id="4">[4]</a>
|
|
231
|
-
Edoardo Pagot, Mattia Piccinini, Enrico Bertolazzi, Francesco Biral. (2023).
|
|
230
|
+
<a id="4">[4]</a>
|
|
231
|
+
Edoardo Pagot, Mattia Piccinini, Enrico Bertolazzi, Francesco Biral. (2023).
|
|
232
232
|
Fast Planning and Tracking of Complex Autonomous Parking Maneuvers With Optimal Control and Pseudo-Neural Networks.
|
|
233
233
|
IEEE Access. https://doi.org/10.1109/ACCESS.2023.3330431 (look the [[code]](https://github.com/tonegas/nnodely-applications/blob/main/vehicle/control_steer_car_parking/control_steer_car_parking.ipynb))
|
|
234
234
|
|
|
235
|
-
<a id="5">[5]</a>
|
|
235
|
+
<a id="5">[5]</a>
|
|
236
236
|
Mattia Piccinini, Sebastiano Taddei, Matteo Larcher, Mattia Piazza, Francesco Biral. (2023).
|
|
237
237
|
A Physics-Driven Artificial Agent for Online Time-Optimal Vehicle Motion Planning and Control.
|
|
238
238
|
IEEE Access. https://doi.org/10.1109/ACCESS.2023.3274836 (look [[code basic]](https://github.com/tonegas/nnodely-applications/blob/main/vehicle/control_steer_artificial_race_driver/control_steer_artificial_race_driver.ipynb)
|
|
239
239
|
and [[code extended]](https://github.com/tonegas/nnodely-applications/blob/main/vehicle/control_steer_artificial_race_driver_extended/control_steer_artificial_race_driver_extended.ipynb))
|
|
240
240
|
|
|
241
|
-
<a id="6">[6]</a>
|
|
241
|
+
<a id="6">[6]</a>
|
|
242
242
|
Hector Perez-Villeda, Justus Piater, Matteo Saveriano. (2023).
|
|
243
243
|
Learning and extrapolation of robotic skills using task-parameterized equation learner networks.
|
|
244
244
|
Robotics and Autonomous Systems. https://doi.org/10.1016/j.robot.2022.104309 (look the [[code]](https://github.com/tonegas/nnodely-applications/blob/main/equation_learner/equation_learner.ipynb))
|
|
245
245
|
|
|
246
|
-
<a id="7">[7]</a>
|
|
246
|
+
<a id="7">[7]</a>
|
|
247
247
|
M. Raissi, P. Perdikaris, G.E. Karniadakis. (2019).
|
|
248
248
|
Physics-informed neural networks: A deep learning framework for solving forward and inverse problems involving nonlinear partial differential equations
|
|
249
249
|
Journal of Computational Physics. https://doi.org/10.1016/j.jcp.2018.10.045 (look the [[example Burger's equation]](https://github.com/tonegas/nnodely-applications/blob/main/pinn/pinn_Burgers_equation.ipynb))
|
|
250
250
|
|
|
251
|
-
<a id="8">[8]</a>
|
|
251
|
+
<a id="8">[8]</a>
|
|
252
252
|
Wojciech Marian Czarnecki, Simon Osindero, Max Jaderberg, Grzegorz Świrszcz, Razvan Pascanu. (2017).
|
|
253
253
|
Sobolev Training for Neural Networks.
|
|
254
254
|
arXiv. https://doi.org/10.48550/arXiv.1706.04859 (look the [[code]](https://github.com/tonegas/nnodely-applications/blob/main/sobolev/Sobolev_learning.ipynb))
|
|
255
255
|
|
|
256
|
-
<a id="9">[9]</a>
|
|
256
|
+
<a id="9">[9]</a>
|
|
257
257
|
Mattia Piccinini, Matteo Zumerle, Johannes Betz, Gastone Pietro Rosati Papini. (2025).
|
|
258
258
|
A Road Friction-Aware Anti-Lock Braking System Based on Model-Structured Neural Networks.
|
|
259
259
|
IEEE Open Journal of Intelligent Transportation Systems. https://doi.org/10.1109/OJITS.2025.3563347 (look at the [[code]](https://github.com/tonegas/nnodely-applications/tree/main/vehicle/road_friction_aware_ABS))
|
|
260
260
|
|
|
261
|
-
<a id="10">[10]</a>
|
|
261
|
+
<a id="10">[10]</a>
|
|
262
262
|
Mauro Da Lio, Mattia Piccinini, Francesco Biral. (2023).
|
|
263
263
|
Robust and Sample-Efficient Estimation of Vehicle Lateral Velocity Using Neural Networks With Explainable Structure Informed by Kinematic Principles.
|
|
264
264
|
IEEE Transactions on Intelligent Transportation Systems. https://doi.org/10.1109/TITS.2023.3303776
|
|
@@ -269,4 +269,4 @@ IEEE Transactions on Intelligent Transportation Systems. https://doi.org/10.1109
|
|
|
269
269
|
## Cite Us
|
|
270
270
|
|
|
271
271
|
> TODO: We could add a DOI for the repository via Zenodo and cite that instead; see the [guide](https://docs.github.com/en/repositories/archiving-a-github-repository/referencing-and-citing-content)
|
|
272
|
-
-->
|
|
272
|
+
-->
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["uv_build>=0.11.6,<0.12.0"]
|
|
3
|
+
build-backend = "uv_build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "nnodely"
|
|
7
|
+
version = "1.5.5.dev1"
|
|
8
|
+
description = "Model-structured neural network framework for the modeling and control of physical systems"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.10,<3.14"
|
|
11
|
+
license = "MIT"
|
|
12
|
+
license-files = ["LICENSE"]
|
|
13
|
+
authors = [
|
|
14
|
+
{ name = "Gastone Pietro Rosati Papini", email = "tonegas@gmail.com" },
|
|
15
|
+
]
|
|
16
|
+
classifiers = [
|
|
17
|
+
"Programming Language :: Python :: 3",
|
|
18
|
+
"Operating System :: OS Independent",
|
|
19
|
+
]
|
|
20
|
+
dependencies = [
|
|
21
|
+
"numpy == 1.26.4; platform_machine == 'x86_64' and python_version == '3.10'",
|
|
22
|
+
"torch == 2.2.2; platform_machine == 'x86_64' and python_version == '3.10'",
|
|
23
|
+
"torch == 2.6.0; platform_machine != 'x86_64' or python_version != '3.10'",
|
|
24
|
+
"numpy",
|
|
25
|
+
"onnx",
|
|
26
|
+
"pandas",
|
|
27
|
+
"reportlab",
|
|
28
|
+
"matplotlib",
|
|
29
|
+
"onnxruntime",
|
|
30
|
+
"graphviz",
|
|
31
|
+
]
|
|
32
|
+
|
|
33
|
+
[project.urls]
|
|
34
|
+
"Homepage" = "https://github.com/tonegas/nnodely"
|
|
35
|
+
"Repository" = "https://github.com/tonegas/nnodely"
|
|
36
|
+
|
|
37
|
+
[dependency-groups]
|
|
38
|
+
dev = ["pre-commit>=4.5.1", "pytest-cov>=7.1.0", "ruff>=0.15.10"]
|
|
39
|
+
|
|
40
|
+
[[tool.uv.index]]
|
|
41
|
+
name = "testpypi"
|
|
42
|
+
url = "https://test.pypi.org/simple/"
|
|
43
|
+
publish-url = "https://test.pypi.org/legacy/"
|
|
44
|
+
explicit = true
|
|
@@ -14,7 +14,15 @@ from nnodely.layers.arithmetic import Add, Sum, Sub, Mul, Div, Pow, Neg, Sign
|
|
|
14
14
|
from nnodely.layers.trigonometric import Sin, Cos, Tan, Cosh, Tanh, Sech
|
|
15
15
|
from nnodely.layers.parametricfunction import ParamFun
|
|
16
16
|
from nnodely.layers.fuzzify import Fuzzify
|
|
17
|
-
from nnodely.layers.part import
|
|
17
|
+
from nnodely.layers.part import (
|
|
18
|
+
Part,
|
|
19
|
+
Select,
|
|
20
|
+
Concatenate,
|
|
21
|
+
SamplePart,
|
|
22
|
+
SampleSelect,
|
|
23
|
+
TimePart,
|
|
24
|
+
TimeConcatenate,
|
|
25
|
+
)
|
|
18
26
|
from nnodely.layers.localmodel import LocalModel
|
|
19
27
|
from nnodely.layers.equationlearner import EquationLearner
|
|
20
28
|
from nnodely.layers.timeoperation import Integrate, Differentiate
|
|
@@ -37,43 +45,85 @@ from nnodely.support import logger
|
|
|
37
45
|
major, minor = sys.version_info.major, sys.version_info.minor
|
|
38
46
|
logger.LOG_LEVEL = logging.INFO
|
|
39
47
|
|
|
40
|
-
__version__ =
|
|
48
|
+
__version__ = "1.5.4"
|
|
41
49
|
|
|
42
50
|
# Import-time interpreter guard: abort on unsupported Python versions,
# otherwise print the package banner.
if major < 3:
    sys.exit(
        "Sorry, Python 2 is not supported. You need Python >= 3.10 for "
        + __package__
        + "."
    )
# The messages state that Python >= 3.10 is required, so the guard must
# reject 3.9 as well (the previous `minor < 9` test let it through).
# Comparing the (major, minor) pair also avoids judging the minor number
# in isolation from the major one.
elif (major, minor) < (3, 10):
    sys.exit("Sorry, You need Python >= 3.10 for " + __package__ + ".")
else:
    print(
        ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
        + f" {__package__}_v{__version__} ".center(20, "-")
        + "<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<"
    )
|
|
50
64
|
|
|
51
65
|
|
|
52
66
|
__all__ = [
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
67
|
+
"nnodely",
|
|
68
|
+
"Modely",
|
|
69
|
+
"clearNames",
|
|
70
|
+
"Input",
|
|
71
|
+
"Connect",
|
|
72
|
+
"ClosedLoop",
|
|
73
|
+
"Parameter",
|
|
74
|
+
"Constant",
|
|
75
|
+
"SampleTime",
|
|
76
|
+
"Output",
|
|
77
|
+
"Relu",
|
|
78
|
+
"ELU",
|
|
79
|
+
"Softmax",
|
|
80
|
+
"Sigmoid",
|
|
81
|
+
"Identity",
|
|
82
|
+
"Fir",
|
|
83
|
+
"Linear",
|
|
84
|
+
"NeuralODE",
|
|
85
|
+
"Add",
|
|
86
|
+
"Sum",
|
|
87
|
+
"Sub",
|
|
88
|
+
"Mul",
|
|
89
|
+
"Div",
|
|
90
|
+
"Pow",
|
|
91
|
+
"Neg",
|
|
92
|
+
"Sign",
|
|
93
|
+
"Sin",
|
|
94
|
+
"Cos",
|
|
95
|
+
"Tan",
|
|
96
|
+
"Cosh",
|
|
97
|
+
"Tanh",
|
|
98
|
+
"Sech",
|
|
99
|
+
"ParamFun",
|
|
100
|
+
"Fuzzify",
|
|
101
|
+
"Part",
|
|
102
|
+
"Select",
|
|
103
|
+
"Concatenate",
|
|
104
|
+
"SamplePart",
|
|
105
|
+
"SampleSelect",
|
|
106
|
+
"TimePart",
|
|
107
|
+
"TimeConcatenate",
|
|
108
|
+
"LocalModel",
|
|
109
|
+
"EquationLearner",
|
|
110
|
+
"Integrate",
|
|
111
|
+
"Differentiate",
|
|
112
|
+
"Interpolation",
|
|
113
|
+
"ForwardEuler",
|
|
114
|
+
"RK2",
|
|
115
|
+
"RK4",
|
|
116
|
+
"TextVisualizer",
|
|
117
|
+
"MPLVisualizer",
|
|
118
|
+
"MPLNotebookVisualizer",
|
|
119
|
+
"StandardExporter",
|
|
120
|
+
"SGD",
|
|
121
|
+
"Adam",
|
|
122
|
+
"Optimizer",
|
|
123
|
+
"init_negexp",
|
|
124
|
+
"init_lin",
|
|
125
|
+
"init_constant",
|
|
126
|
+
"init_exp",
|
|
77
127
|
# Main nnodely classes
|
|
78
|
-
|
|
128
|
+
"__version__",
|
|
79
129
|
]
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import torch.nn as nn
|
|
2
|
+
import torch
|
|
3
|
+
from nnodely.support.utils import check
|
|
4
|
+
|
|
5
|
+
available_losses = ["mse", "rmse", "mae", "cross_entropy"]
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class CustomLoss(nn.Module):
    """Loss module selected by name, or supplied directly as a callable.

    Args:
        loss_type: One of the names in ``available_losses`` ("mse", "rmse",
            "mae", "cross_entropy") or a callable taking ``(inA, inB)`` and
            returning the loss value.
        **kwargs: Forwarded to the constructor of the selected torch
            criterion. Ignored when ``loss_type`` is a callable.

    Raises:
        TypeError: Passed to ``check`` for a ``loss_type`` that is neither a
            callable nor a known name (presumably raised by ``check`` when
            the condition is false -- confirm against ``nnodely.support.utils``).
    """

    def __init__(self, loss_type="mse", **kwargs):
        super().__init__()
        # Accept callables explicitly: a bare membership test against
        # `available_losses` rejects them and makes the callable branch
        # below unreachable.
        check(
            callable(loss_type) or loss_type in available_losses,
            TypeError,
            f'The "{loss_type}" loss is not available. Possible losses are: {available_losses}.',
        )
        self.loss_type = loss_type
        # Build only the requested criterion. Instantiating nn.MSELoss with
        # kwargs meant for another criterion (e.g. CrossEntropyLoss's
        # `weight` or `ignore_index`) would fail before the right loss is
        # even selected.
        if callable(loss_type):
            self.loss = loss_type
        elif loss_type == "mae":
            self.loss = nn.L1Loss(**kwargs)
        elif loss_type == "cross_entropy":
            self.loss = nn.CrossEntropyLoss(**kwargs)
        else:
            # "mse" and "rmse" share the MSE criterion; the square root for
            # "rmse" is applied in forward().
            self.loss = nn.MSELoss(**kwargs)

    def forward(self, inA, inB):
        """Return the loss between prediction ``inA`` and target ``inB``.

        For "cross_entropy" both tensors are squeezed; the target is cast to
        float when it has the same shape as the prediction and to long
        otherwise. For "rmse" the square root of the MSE result is returned.
        """
        if self.loss_type == "cross_entropy":
            # Must be evaluated before inA is squeezed: the shape comparison
            # decides the target dtype.
            # NOTE(review): presumably same-shape targets are class
            # probabilities and other targets are class indices -- confirm.
            inB = (
                inB.squeeze().float()
                if inA.shape == inB.shape
                else inB.squeeze().long()
            )
            inA = inA.squeeze()
        res = self.loss(inA, inB)
        if self.loss_type == "rmse":
            res = torch.sqrt(res)
        return res
|