NeuralNetworks 0.1.12__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,45 @@
1
+ # NeuralNetworks - Multi-Layer Perceptrons avec encodage Fourier
2
+ # Copyright (C) 2026 Alexandre Brun
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+
8
+ from ..Dependances import plt, np
9
+
10
+ def learnings(*nets):
11
+
12
+ # --- Initialisation de la figure ---
13
+ fig, ax1 = plt.subplots()
14
+ fig.set_figheight(5)
15
+ fig.set_figwidth(5)
16
+
17
+ # --- Définition des limites des axes ---
18
+ all_learnings = [[lr for lr in net.learnings] for net in nets]
19
+ if max(len(lst) for lst in all_learnings) == 1:
20
+ lenlearnings = 2
21
+ else:
22
+ lenlearnings = max(len(lst) for lst in all_learnings)
23
+ plt.xlim(1, lenlearnings)
24
+
25
+ # --- Tracé des courbes de pertes pour chaque réseau ---
26
+ for k, net in enumerate(nets):
27
+ ax1.plot(
28
+ np.arange(1, len(all_learnings[k]) + 1),
29
+ all_learnings[k],
30
+ label=net.name
31
+ )
32
+ ax1.set_xlabel("Epochs")
33
+ ax1.set_ylabel("Learning rate")
34
+ ax1.legend(loc="upper left")
35
+ ax1.grid(True)
36
+
37
+ plt.yscale('log', nonpositive='mask')
38
+ # --- Affichage ---
39
+ plt.legend()
40
+ plt.xlabel("Epoch")
41
+ plt.ylabel("Learning rate")
42
+ fig.canvas.draw_idle()
43
+ plt.tight_layout()
44
+ plt.ion() # mode interactif
45
+ plt.show()
@@ -0,0 +1,45 @@
1
+ # NeuralNetworks - Multi-Layer Perceptrons avec encodage Fourier
2
+ # Copyright (C) 2026 Alexandre Brun
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+
8
+ from ..Dependances import plt, np
9
+
10
+ def losses(*nets):
11
+
12
+ # --- Initialisation de la figure ---
13
+ fig, ax1 = plt.subplots()
14
+ fig.set_figheight(5)
15
+ fig.set_figwidth(5)
16
+
17
+ # --- Définition des limites des axes ---
18
+ all_losses = [[loss for loss in net.losses] for net in nets]
19
+ if max(len(lst) for lst in all_losses) == 1:
20
+ lenlosses = 2
21
+ else:
22
+ lenlosses = max(len(lst) for lst in all_losses)
23
+ plt.xlim(1, lenlosses)
24
+
25
+ # --- Tracé des courbes de pertes pour chaque réseau ---
26
+ for k, net in enumerate(nets):
27
+ ax1.plot(
28
+ np.arange(1, len(all_losses[k]) + 1),
29
+ all_losses[k],
30
+ label=net.name
31
+ )
32
+ ax1.set_xlabel("Epochs")
33
+ ax1.set_ylabel("Loss")
34
+ ax1.legend(loc="upper left")
35
+ ax1.grid(True)
36
+
37
+ plt.yscale('log', nonpositive='mask')
38
+ # --- Affichage ---
39
+ plt.legend()
40
+ plt.xlabel("Epoch")
41
+ plt.ylabel("Résidus")
42
+ fig.canvas.draw_idle()
43
+ plt.tight_layout()
44
+ plt.ion() # mode interactif
45
+ plt.show()
@@ -0,0 +1,9 @@
1
+ # NeuralNetworks - Multi-Layer Perceptrons avec encodage Fourier
2
+ # Copyright (C) 2026 Alexandre Brun
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+
8
+ from .Losses import losses
9
+ from .Learnings import learnings
@@ -1,128 +1,23 @@
1
- # NeuralNetworksBeta - Multi-Layer Perceptrons avec encodage Fourier
2
- # Copyright (C) 2025 Alexandre Brun
1
+ # NeuralNetworks - Multi-Layer Perceptrons avec encodage Fourier
2
+ # Copyright (C) 2026 Alexandre Brun
3
3
  # This program is free software: you can redistribute it and/or modify
4
4
  # it under the terms of the GNU General Public License as published by
5
5
  # the Free Software Foundation, either version 3 of the License, or
6
6
  # (at your option) any later version.
7
7
 
8
- """
9
- NeuralNetworks Module
10
- ====================
11
-
12
- Module complet pour la création, l'entraînement et la visualisation de Multi-Layer Perceptrons (MLP)
13
- avec encodage optionnel Fourier, gestion automatique des pertes, compilation Torch et outils
14
- de traitement d'images pour l'apprentissage sur des images RGB.
15
-
16
- Contenu principal
17
- -----------------
18
-
19
- Classes
20
- -------
21
-
22
- MLP
23
- Multi-Layer Perceptron (MLP) avec options avancées :
24
- - Encodage Fourier gaussien (RFF) optionnel.
25
- - Stockage automatique des pertes.
26
- - Compilation Torch optionnelle pour accélérer l’inférence.
27
- - Gestion flexible de l’optimiseur, de la fonction de perte et de la normalisation.
28
-
29
- Méthodes principales :
30
- - __init__(layers, learning_rate, Fourier, optimizer, criterion, normalizer, name, Iscompiled)
31
- Initialise le réseau avec toutes les options.
32
- - train(inputs, outputs, num_epochs, batch_size)
33
- Entraîne le MLP sur des données (inputs → outputs) en utilisant AMP et mini-batchs.
34
- - plot(inputs, img_array)
35
- Affiche l'image originale, la prédiction du MLP et la courbe des pertes.
36
- - __call__(x)
37
- Applique l’encodage puis le MLP pour produire une prédiction.
38
- - Create_MLP(layers)
39
- Construit le MLP avec normalisation/activation et Sigmoid finale.
40
- - params()
41
- Retourne tous les poids du MLP (ligne par ligne) sous forme de liste de numpy.ndarray.
42
- - nb_params()
43
- Calcule le nombre total de poids dans le MLP.
44
- - neurons()
45
- Retourne la liste des biais (neurones) de toutes les couches linéaires.
46
- - __repr__()
47
- Affiche un schéma visuel du MLP via visualtorch et print des dimensions.
48
-
49
- Fonctions utilitaires
50
- --------------------
51
-
52
- tensorise(obj)
53
- Convertit un objet array-like ou tensor en torch.Tensor float32 sur le device actif.
54
-
55
- list_to_cpu(cuda_tensors)
56
- Copie une liste de tenseurs CUDA et les transfère sur le CPU.
57
-
58
- rglen(list)
59
- Renvoie un range correspondant aux indices d'une liste.
60
-
61
- fPrintDoc(obj)
62
- Crée une fonction lambda qui affiche le docstring d'un objet.
63
-
64
- image_from_url(url, img_size)
65
- Télécharge une image depuis une URL, la redimensionne et génère :
66
- - img_array : np.ndarray (H, W, 3) pour affichage.
67
- - inputs : tenseur (H*W, 2) coordonnées normalisées.
68
- - outputs : tenseur (H*W, 3) valeurs RGB cibles.
69
-
70
- Visualisation et comparaison
71
- ----------------------------
72
-
73
- plot(img_array, inputs, *nets)
74
- Affiche pour chaque réseau l'image reconstruite à partir des entrées.
75
-
76
- compare(img_array, inputs, *nets)
77
- Affiche pour chaque réseau l'erreur absolue entre l'image originale et la prédiction,
78
- et trace également les pertes cumulées. Chaque réseau doit posséder :
79
- - encoding(x) si RFF activé
80
- - model() retournant un tenseur (N, 3)
81
- - attribute losses
82
-
83
- Objets et dictionnaires
84
- -----------------------
85
-
86
- Norm_list : dict
87
- Contient les modules PyTorch correspondant aux fonctions de normalisation/activation
88
- disponibles (ReLU, GELU, Sigmoid, Tanh, etc.)
89
-
90
- Criterion_list : dict
91
- Contient les fonctions de perte PyTorch disponibles (MSE, L1, SmoothL1, BCE, CrossEntropy, etc.)
92
-
93
- Optim_list(self, learning_rate)
94
- Retourne un dictionnaire d’optimiseurs PyTorch initialisés avec `self.model.parameters()`.
95
-
96
- Device et configuration
97
- -----------------------
98
-
99
- device
100
- Device par défaut (GPU si disponible, sinon CPU).
8
+ # Import des dépendances et utilitaires globaux (device, settings, tensorise, etc.)
9
+ from .Dependances import norms, crits, optims, rglen, device, pi, e, tensorise
101
10
 
102
- Paramètres matplotlib et PyTorch
103
- - Style global pour fond transparent et texte gris.
104
- - Optimisations CUDA activées pour TF32, matmul et convolutions.
105
- - Autograd configuré pour privilégier les performances.
11
+ # Modèle MLP principal + fonction d'entraînement associée
12
+ from .MLP import MLP, losses
106
13
 
107
- Notes générales
108
- ---------------
14
+ from .Trainer import Trainer
109
15
 
110
- - Toutes les méthodes de MLP utilisent les tenseurs sur le device global (CPU ou GPU).
111
- - Les images doivent être normalisées entre 0 et 1.
112
- - Les fonctions interactives (plot, compare) utilisent matplotlib en mode interactif.
113
- - Le module est conçu pour fonctionner dans Jupyter et scripts Python classiques.
114
- """
16
+ from .UI import *
115
17
 
116
- # Import des dépendances et utilitaires globaux (device, settings, tensorise, etc.)
117
- from .Dependances import norms, crits, optims, rglen, device, pi, e, tensorise
18
+ from .Latent import Latent
118
19
 
119
20
  # Fonctions de chargement/preprocessing des images
120
- from .Image import image_from_url
121
-
122
- # Fonctions d'affichage : reconstruction, comparaison, courbes de pertes
123
- from .Plot import compare, plot, losses, train
124
-
125
- # Modèle MLP principal + fonction d'entraînement associée
126
- from .MLP import MLP
21
+ from .tools import image, MNIST, AirfRANS
127
22
 
128
- __version__ = "0.1.12"
23
+ __version__ = "0.2.2"
@@ -0,0 +1,36 @@
1
+ # NeuralNetworksBeta - Multi-Layer Perceptrons avec encodage Fourier
2
+ # Copyright (C) 2025 Alexandre Brun
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+
8
+ from ..Dependances import *
9
+ from airfrans import *
10
+
11
+ def download(path,unzip = True, OpenFOAM = False):
12
+ """
13
+ Télécharge le dataset AirfRANS dans le dossier spécifié.
14
+
15
+ Cette fonction est un simple wrapper autour :
16
+ dataset.download(root=path, file_name='AirfRANS', unzip=True, OpenFOAM=True)
17
+
18
+ Les arguments `unzip` et `OpenFOAM` sont actuellement ignorés par la fonction
19
+ et forcés à True dans l’appel interne.
20
+
21
+ Parameters
22
+ ----------
23
+ path : str
24
+ Chemin du dossier dans lequel le dataset doit être téléchargé.
25
+ unzip : bool, optional
26
+ Paramètre non utilisé. Le téléchargement interne force `unzip=True`.
27
+ OpenFOAM : bool, optional
28
+ Paramètre non utilisé. Le téléchargement interne force `OpenFOAM=True`.
29
+
30
+ Notes
31
+ -----
32
+ - Le fichier téléchargé s’appelle `'AirfRANS'`.
33
+ - Le dataset est automatiquement décompressé.
34
+ - Le format OpenFOAM est toujours inclus.
35
+ """
36
+ dataset.download(root = path, file_name = 'AirfRANS', unzip = True, OpenFOAM = True)
@@ -0,0 +1,118 @@
1
+ # NeuralNetworksBeta - Multi-Layer Perceptrons avec encodage Fourier
2
+ # Copyright (C) 2025 Alexandre Brun
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+
8
+ from ..Dependances import *
9
+ from torchvision.datasets import MNIST
10
+
11
+ def data(path):
12
+ """
13
+ Charge le dataset MNIST depuis `path`, applique une transformation en tenseur,
14
+ puis convertit les images en vecteurs numpy aplatis et les labels en tenseur PyTorch.
15
+
16
+ Parameters
17
+ ----------
18
+ path : str
19
+ Chemin du dossier où MNIST sera téléchargé ou chargé.
20
+
21
+ Returns
22
+ -------
23
+ inputs : np.ndarray
24
+ Tableau numpy de shape (N, 784) contenant les images MNIST aplaties.
25
+ Chaque pixel est normalisé dans [0, 1] via `ToTensor()`.
26
+ outputs : torch.Tensor
27
+ Tenseur PyTorch de shape (N, 1) contenant les labels entiers (0–9).
28
+
29
+ Notes
30
+ -----
31
+ - Le dataset MNIST est téléchargé si absent.
32
+ - Chaque image 28×28 est convertie via `ToTensor()` puis aplatie en vecteur de 784 valeurs.
33
+ - Les labels sont convertis en tenseur long et remis dans une dimension (N, 1)
34
+ pour compatibilité avec un MLP produisant une sortie scalaire.
35
+ """
36
+ transform = Compose([ToTensor()])
37
+ dataset = MNIST(path, transform=transform, download=True)
38
+
39
+ inputs, outputs = [], []
40
+ for data in dataset:
41
+ outputs.append(data[1])
42
+ value= data[0].numpy().flatten()
43
+ inputs.append(value)
44
+ outputs = torch.tensor(np.array(outputs)) # convert list → tensor
45
+ outputs = outputs.unsqueeze(1)
46
+ inputs = np.array(inputs)
47
+
48
+ return inputs, outputs
49
+
50
+ def evaluate (inputs, *nets):
51
+ """
52
+ Évalue visuellement un ou plusieurs réseaux sur un échantillon MNIST choisi
53
+ aléatoirement. La fonction affiche simultanément :
54
+
55
+ - l'image d'entrée (28×28),
56
+ - les courbes de perte de chaque réseau (échelle logarithmique),
57
+ - la prédiction de chaque réseau imprimée dans la console.
58
+
59
+ Parameters
60
+ ----------
61
+ inputs : np.ndarray
62
+ Tableau numpy contenant les images aplaties (N, 784).
63
+ Une image sera choisie aléatoirement parmi celles-ci.
64
+ nets : MLP
65
+ Un ou plusieurs réseaux entraînés, chacun possédant :
66
+ - net.losses : liste des pertes par époque,
67
+ - net.name : nom du modèle,
68
+ - net(x) : méthode d'inférence retournant une valeur prédite.
69
+
70
+ Notes
71
+ -----
72
+ - L'image affichée est l'entrée sélectionnée, remise en forme en 28×28.
73
+ - Les pertes sont tracées pour chaque réseau sur une échelle Y logarithmique.
74
+ - Les prédictions sont arrondies et converties en entiers pour un affichage clair.
75
+ - Une figure matplotlib avec deux sous-graphiques est générée via GridSpec :
76
+ * à gauche : l'image MNIST,
77
+ * à droite : les courbes de pertes.
78
+ - Les résultats (prédictions) sont également affichés dans la console.
79
+ """
80
+
81
+ # --- Configuration de la grille de figure ---
82
+ fig = plt.figure(figsize=(10, 5))
83
+ gs = GridSpec(1, 2, figure=fig)
84
+
85
+ index = np.random.randint(0,len(inputs)-1)
86
+
87
+ # --- Préparation du subplot pour les courbes de pertes ---
88
+ ax_loss = fig.add_subplot(gs[0, 1])
89
+ ax_loss.set_yscale('log', nonpositive='mask')
90
+ all_losses = [[loss for loss in net.losses] for net in nets]
91
+ if max(len(lst) for lst in all_losses) == 1:
92
+ lenlosses = 2
93
+ else:
94
+ lenlosses = max(len(lst) for lst in all_losses)
95
+ ax_loss.set_xlim(1, lenlosses)
96
+
97
+ preds = []
98
+ for k, net in enumerate(nets):
99
+ preds.append(int(np.round(net(inputs[index]))))
100
+ # Tracé des pertes cumulées
101
+ ax_loss.plot(np.arange(1, len(all_losses[k])+1), all_losses[k],label = net.name)
102
+ ax_loss.legend()
103
+
104
+ # --- Affichage de l'image originale ---
105
+ ax_orig = fig.add_subplot(gs[0, 0])
106
+ ax_orig.axis('off')
107
+ ax_orig.set_title("input")
108
+ show = inputs[index].reshape(28,28)
109
+ ax_orig.imshow(255*show)
110
+
111
+ # --- Affichage final ---
112
+ fig.canvas.draw_idle()
113
+ plt.tight_layout()
114
+ plt.ion()
115
+ plt.show()
116
+
117
+ for k in rglen(preds):
118
+ print(f"{nets[k].name} output : {preds[k]}")
@@ -0,0 +1,7 @@
1
+ # NeuralNetworksBeta - Multi-Layer Perceptrons avec encodage Fourier
2
+ # Copyright (C) 2025 Alexandre Brun
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+
@@ -0,0 +1,249 @@
1
+ # NeuralNetworksBeta - Multi-Layer Perceptrons avec encodage Fourier
2
+ # Copyright (C) 2025 Alexandre Brun
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+
8
+ from ..Dependances import *
9
+
10
+ def url(url, img_size=256):
11
+ """
12
+ Télécharge une image depuis une URL, la redimensionne et prépare les
13
+ données pour l'entraînement d'un MLP pixel-wise.
14
+
15
+ Cette fonction retourne :
16
+ - `img_array` : image RGB sous forme de tableau NumPy (H, W, 3), pour affichage.
17
+ - `inputs` : coordonnées normalisées (x, y) de chaque pixel, sous forme de tenseur (H*W, 2).
18
+ - `outputs` : valeurs RGB cibles pour chaque pixel, sous forme de tenseur (H*W, 3).
19
+
20
+ Paramètres
21
+ ----------
22
+ url : str
23
+ URL de l'image à télécharger.
24
+ img_size : int, optionnel
25
+ Taille finale carrée de l'image (img_size x img_size). Par défaut 256.
26
+
27
+ Retours
28
+ -------
29
+ img_array : numpy.ndarray of shape (H, W, 3)
30
+ Image sous forme de tableau NumPy, valeurs normalisées entre 0 et 1.
31
+ inputs : torch.Tensor of shape (H*W, 2)
32
+ Coordonnées normalisées des pixels pour l'entrée du MLP.
33
+ outputs : torch.Tensor of shape (H*W, 3)
34
+ Valeurs RGB cibles pour chaque pixel, pour la sortie du MLP.
35
+
36
+ Notes
37
+ -----
38
+ - La fonction utilise `PIL` pour le traitement de l'image et `torchvision.transforms`
39
+ pour la conversion en tenseur normalisé.
40
+ - Les coordonnées sont normalisées dans [0, 1] pour une utilisation optimale
41
+ avec des MLP utilisant Fourier Features ou activations standard.
42
+ - Les tenseurs `inputs` et `outputs` sont prêts à être envoyés sur GPU si nécessaire.
43
+ """
44
+
45
+ # --- Téléchargement et ouverture de l'image ---
46
+ response = requests.get(url)
47
+ img = Image.open(BytesIO(response.content)).convert("RGB")
48
+
49
+ # --- Redimensionnement et conversion en tenseur normalisé ---
50
+ transform = Compose([
51
+ Resize((img_size, img_size)),
52
+ ToTensor() # Donne un tenseur (3, H, W) normalisé entre 0 et 1
53
+ ])
54
+ img_tensor = transform(img)
55
+
56
+ # Récupération de la hauteur et largeur
57
+ h, w = img_tensor.shape[1:]
58
+
59
+ # Conversion en tableau NumPy (H, W, 3) pour affichage
60
+ img_array = img_tensor.permute(1, 2, 0).numpy()
61
+
62
+ # --- Création d'une grille normalisée des coordonnées des pixels ---
63
+ x_coords = torch.linspace(0, 1, w)
64
+ y_coords = torch.linspace(0, 1, h)
65
+ x_grid, y_grid = torch.meshgrid(x_coords, y_coords, indexing="ij")
66
+
67
+ # Flatten de la grille pour former les entrées du MLP : shape (H*W, 2)
68
+ inputs = torch.stack([x_grid.flatten(), y_grid.flatten()], dim=-1)
69
+
70
+ # Extraction des valeurs RGB comme sorties cibles : shape (H*W, 3)
71
+ outputs = img_tensor.view(3, -1).permute(1, 0)
72
+
73
+ return img_array, inputs, outputs
74
+ url.help = fPrintDoc(url)
75
+
76
+ def reshape(img_array, array):
77
+ """
78
+ Reshape un tenseur plat de prédiction en image (H, W, 3) en utilisant
79
+ les dimensions de l’image originale.
80
+
81
+ Parameters
82
+ ----------
83
+ img_array : np.ndarray of shape (H, W, 3)
84
+ Image originale servant de référence pour récupérer la hauteur (H)
85
+ et la largeur (W).
86
+ array : tensor-like or ndarray of shape (H*W, 3)
87
+ Tableau plat contenant les valeurs RGB prédites pour chaque pixel.
88
+
89
+ Returns
90
+ -------
91
+ np.ndarray of shape (H, W, 3)
92
+ Image reconstruite à partir du tableau plat.
93
+
94
+ Notes
95
+ -----
96
+ - Cette fonction ne modifie pas les valeurs, elle fait uniquement un reshape.
97
+ - Utile après une prédiction de type MLP qui renvoie un tableau (N, 3).
98
+ """
99
+
100
+ # Récupération de la hauteur et largeur à partir de l’image originale
101
+ h, w = img_array.shape[:2]
102
+
103
+ # Reconstruction en image RGB
104
+ return array.reshape(h, w, 3)
105
+ reshape.help = fPrintDoc(reshape)
106
+
107
+ def compare(img_array, inputs, *nets):
108
+ """
109
+ Affiche, pour chaque réseau, l’erreur absolue entre l’image originale
110
+ et l’image reconstruite par le réseau.
111
+
112
+ Chaque réseau doit posséder :
113
+ - une méthode `encoding(x)` (si RFF activé),
114
+ - un module `model` retournant un tenseur de shape (N, 3),
115
+ - une reconstruction compatible avec (H, W, 3).
116
+
117
+ Parameters
118
+ ----------
119
+ img_array : np.ndarray of shape (H, W, 3)
120
+ Image originale servant de référence.
121
+ inputs : tensor-like of shape (H*W, 2)
122
+ Coordonnées normalisées des pixels correspondant à chaque point de l'image.
123
+ *nets : *MLP
124
+ Un ou plusieurs réseaux.
125
+
126
+ Notes
127
+ -----
128
+ - L’affichage montre la différence absolue entre l’image originale et la prédiction du réseau.
129
+ - Les pertes cumulées sont également tracées pour chaque réseau.
130
+ - Utilise matplotlib en mode interactif.
131
+ """
132
+
133
+ # --- Conversion des inputs en tensor et récupération du nombre d'échantillons ---
134
+ inputs, n_samples = tensorise(inputs), inputs.size(0)
135
+ h, w = img_array.shape[:2]
136
+
137
+ # --- Configuration de la grille de figure ---
138
+ grid_length = 2 if len(nets) == 1 else len(nets)
139
+ fig = plt.figure(figsize=(5*grid_length, 10))
140
+ gs = GridSpec(2, grid_length, figure=fig)
141
+
142
+ # --- Affichage de l'image originale ---
143
+ ax_orig = fig.add_subplot(gs[0, 0])
144
+ ax_orig.axis('off')
145
+ ax_orig.set_title("Original Image")
146
+ ax_orig.imshow(img_array)
147
+
148
+ # --- Préparation du subplot pour les courbes de pertes ---
149
+ ax_loss = fig.add_subplot(gs[0, 1])
150
+ all_losses = [[loss for loss in net.losses] for net in nets]
151
+ if max(len(lst) for lst in all_losses) == 1:
152
+ lenlosses = 2
153
+ else:
154
+ lenlosses = max(len(lst) for lst in all_losses)
155
+ ax_loss.set_xlim(1, lenlosses)
156
+ ax_loss.set_yscale('log', nonpositive='mask')
157
+
158
+
159
+ # --- Boucle sur chaque réseau pour afficher l'erreur et les pertes ---
160
+ for k, net in enumerate(nets):
161
+ # Subplot pour l'erreur absolue
162
+ ax = fig.add_subplot(gs[1, k])
163
+ ax.axis('off')
164
+ ax.set_title(net.name)
165
+
166
+ # Prédiction et reconstruction de l'image
167
+ pred_img = net(inputs).reshape(h, w, 3)
168
+
169
+ # Tracé des pertes cumulées
170
+ ax_loss.plot(np.arange(1, len(all_losses[k])+1), all_losses[k],label = net.name)
171
+
172
+ # Affichage de l'erreur absolue
173
+ ax.imshow(np.abs(img_array - pred_img))
174
+ ax_loss.legend()
175
+
176
+ # --- Affichage final ---
177
+ fig.canvas.draw_idle()
178
+ plt.tight_layout()
179
+ plt.ion()
180
+ plt.show()
181
+ compare.help = fPrintDoc(compare)
182
+
183
+ def plot(img_array, inputs, *nets):
184
+ """
185
+ Affiche, pour chaque réseau, l’image reconstruite à partir de ses prédictions.
186
+
187
+ Parameters
188
+ ----------
189
+ img_array : np.ndarray of shape (H, W, 3)
190
+ Image originale, utilisée pour connaître les dimensions de reconstruction.
191
+ inputs : tensor-like of shape (H*W, 2)
192
+ Coordonnées normalisées des pixels correspondant à chaque point de l'image.
193
+ *nets : *MLP
194
+ Un ou plusieurs réseaux.
195
+ Notes
196
+ -----
197
+ - Cette fonction affiche la prédiction brute.
198
+ - Les pertes cumulées sont également tracées pour chaque réseau.
199
+ - Utilise matplotlib en mode interactif.
200
+ """
201
+
202
+ # --- Conversion des inputs en tensor et récupération du nombre d'échantillons ---
203
+ inputs, n_samples = tensorise(inputs), inputs.size(0)
204
+ h, w = img_array.shape[:2]
205
+
206
+ # --- Configuration de la grille de figure ---
207
+ grid_length = 2 if len(nets) == 1 else len(nets)
208
+ fig = plt.figure(figsize=(5*grid_length, 10))
209
+ gs = GridSpec(2, grid_length, figure=fig)
210
+
211
+ # --- Affichage de l'image originale ---
212
+ ax_orig = fig.add_subplot(gs[0, 0])
213
+ ax_orig.axis('off')
214
+ ax_orig.set_title("Original Image")
215
+ ax_orig.imshow(img_array)
216
+
217
+ # --- Préparation du subplot pour les courbes de pertes ---
218
+ ax_loss = fig.add_subplot(gs[0, 1])
219
+ all_losses = [[loss for loss in net.losses] for net in nets]
220
+ if max(len(lst) for lst in all_losses) == 1:
221
+ lenlosses = 2
222
+ else:
223
+ lenlosses = max(len(lst) for lst in all_losses)
224
+ ax_loss.set_xlim(1, lenlosses)
225
+
226
+ # --- Boucle sur chaque réseau pour afficher les prédictions et pertes ---
227
+ for k, net in enumerate(nets):
228
+ # Subplot pour l'image reconstruite
229
+ ax = fig.add_subplot(gs[1, k])
230
+ ax.axis('off')
231
+ ax.set_title(net.name)
232
+
233
+ # Prédiction et reconstruction de l'image
234
+ pred_img = net(inputs).reshape(h, w, 3)
235
+
236
+ # Tracé des pertes cumulées
237
+ ax_loss.plot(np.arange(1, len(all_losses[k])+1), all_losses[k],label = net.name)
238
+ ax_loss.set_yscale('log', nonpositive='mask')
239
+
240
+ # Affichage de l'image prédite
241
+ ax.imshow(pred_img)
242
+ ax_loss.legend()
243
+
244
+ # --- Affichage final ---
245
+ fig.canvas.draw_idle()
246
+ plt.tight_layout()
247
+ plt.ion()
248
+ plt.show()
249
+ plot.help = fPrintDoc(plot)