MEDfl: 0.2.1-py3-none-any.whl → 2.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- MEDfl/LearningManager/__init__.py +13 -13
- MEDfl/LearningManager/client.py +150 -181
- MEDfl/LearningManager/dynamicModal.py +287 -287
- MEDfl/LearningManager/federated_dataset.py +60 -60
- MEDfl/LearningManager/flpipeline.py +192 -192
- MEDfl/LearningManager/model.py +223 -223
- MEDfl/LearningManager/params.yaml +14 -14
- MEDfl/LearningManager/params_optimiser.py +442 -442
- MEDfl/LearningManager/plot.py +229 -229
- MEDfl/LearningManager/server.py +181 -189
- MEDfl/LearningManager/strategy.py +82 -138
- MEDfl/LearningManager/utils.py +331 -331
- MEDfl/NetManager/__init__.py +10 -10
- MEDfl/NetManager/database_connector.py +43 -43
- MEDfl/NetManager/dataset.py +92 -92
- MEDfl/NetManager/flsetup.py +320 -320
- MEDfl/NetManager/net_helper.py +254 -254
- MEDfl/NetManager/net_manager_queries.py +142 -142
- MEDfl/NetManager/network.py +194 -194
- MEDfl/NetManager/node.py +184 -184
- MEDfl/__init__.py +2 -2
- MEDfl/scripts/__init__.py +1 -1
- MEDfl/scripts/base.py +29 -29
- MEDfl/scripts/create_db.py +126 -126
- Medfl/LearningManager/__init__.py +13 -0
- Medfl/LearningManager/client.py +150 -0
- Medfl/LearningManager/dynamicModal.py +287 -0
- Medfl/LearningManager/federated_dataset.py +60 -0
- Medfl/LearningManager/flpipeline.py +192 -0
- Medfl/LearningManager/model.py +223 -0
- Medfl/LearningManager/params.yaml +14 -0
- Medfl/LearningManager/params_optimiser.py +442 -0
- Medfl/LearningManager/plot.py +229 -0
- Medfl/LearningManager/server.py +181 -0
- Medfl/LearningManager/strategy.py +82 -0
- Medfl/LearningManager/utils.py +331 -0
- Medfl/NetManager/__init__.py +10 -0
- Medfl/NetManager/database_connector.py +43 -0
- Medfl/NetManager/dataset.py +92 -0
- Medfl/NetManager/flsetup.py +320 -0
- Medfl/NetManager/net_helper.py +254 -0
- Medfl/NetManager/net_manager_queries.py +142 -0
- Medfl/NetManager/network.py +194 -0
- Medfl/NetManager/node.py +184 -0
- Medfl/__init__.py +3 -0
- Medfl/scripts/__init__.py +2 -0
- Medfl/scripts/base.py +30 -0
- Medfl/scripts/create_db.py +126 -0
- alembic/env.py +61 -61
- {MEDfl-0.2.1.dist-info → medfl-2.0.0.dist-info}/METADATA +120 -108
- medfl-2.0.0.dist-info/RECORD +55 -0
- {MEDfl-0.2.1.dist-info → medfl-2.0.0.dist-info}/WHEEL +1 -1
- {MEDfl-0.2.1.dist-info → medfl-2.0.0.dist-info/licenses}/LICENSE +674 -674
- MEDfl-0.2.1.dist-info/RECORD +0 -31
- {MEDfl-0.2.1.dist-info → medfl-2.0.0.dist-info}/top_level.txt +0 -0
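Note the duplication in the list above: 2.0.0 ships every module twice, once under `MEDfl/` and again under `Medfl/` (the `Medfl/` entries are all `+N -0`, i.e. newly added files). A quick sketch for confirming which top-level packages a wheel actually contains — the local filename here is assumed from the dist-info name, not taken from the diff:

```python
# Sketch: list the top-level entries of a wheel (a wheel is a zip archive).
# Assumes the 2.0.0 wheel was downloaded locally under this (hypothetical) name.
from zipfile import ZipFile

with ZipFile("medfl-2.0.0-py3-none-any.whl") as whl:
    top_level = sorted({name.split("/")[0] for name in whl.namelist()})

# Per the file list above, this should show both 'MEDfl' and 'Medfl',
# plus 'alembic' and 'medfl-2.0.0.dist-info'.
print(top_level)
```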
MEDfl/LearningManager/plot.py
CHANGED
@@ -1,229 +1,229 @@
All 229 lines are shown as removed and re-added, but the old and new contents are identical; the file reads:

```python
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from .utils import *

# Replace this with your actual code for data collection
results_dict = {
    ("LR: 0.001, Optimizer: Adam", "accuracy"): [0.85, 0.89, 0.92, 0.94, ...],
    ("LR: 0.001, Optimizer: Adam", "loss"): [0.2, 0.15, 0.1, 0.08, ...],
    ("LR: 0.01, Optimizer: SGD", "accuracy"): [0.88, 0.91, 0.93, 0.95, ...],
    ("LR: 0.01, Optimizer: SGD", "loss"): [0.18, 0.13, 0.09, 0.07, ...],
    ("LR: 0.1, Optimizer: Adam", "accuracy"): [0.82, 0.87, 0.91, 0.93, ...],
    ("LR: 0.1, Optimizer: Adam", "loss"): [0.25, 0.2, 0.15, 0.12, ...],
}
"""
server should have:
# len = num of rounds
self.accuracies
self.losses

Client should have
# len = num of epochs
self.accuracies
self.losses
self.epsilons
self.deltas

# common things: LR, SGD, Aggregation
"""


class AccuracyLossPlotter:
    """
    A utility class for plotting accuracy and loss metrics based on experiment results.

    Args:
        results_dict (dict): Dictionary containing experiment results organized by parameters and metrics.

    Attributes:
        results_dict (dict): Dictionary containing experiment results organized by parameters and metrics.
        parameters (list): List of unique parameters in the experiment results.
        metrics (list): List of unique metrics in the experiment results.
        iterations (range): Range of iterations (rounds or epochs) in the experiment.
    """

    def __init__(self, results_dict):
        """
        Initialize the AccuracyLossPlotter with experiment results.

        Args:
            results_dict (dict): Dictionary containing experiment results organized by parameters and metrics.
        """
        self.results_dict = results_dict
        self.parameters = list(
            set([param[0] for param in results_dict.keys()])
        )
        self.metrics = list(set([param[1] for param in results_dict.keys()]))
        self.iterations = range(1, len(list(results_dict.values())[0]) + 1)

    def plot_accuracy_loss(self):
        """
        Plot accuracy and loss metrics for different parameters.
        """
        plt.figure(figsize=(8, 6))

        for param in self.parameters:
            for metric in self.metrics:
                key = (param, metric)
                values = self.results_dict[key]
                plt.plot(
                    self.iterations,
                    values,
                    label=f"{param} ({metric})",
                    marker="o",
                    linestyle="-",
                )

        plt.xlabel("Rounds")
        plt.ylabel("Accuracy / Loss")
        plt.title("Accuracy and Loss by Parameters")
        plt.legend()
        plt.grid(True)
        plt.show()

    @staticmethod
    def plot_global_confusion_matrix(pipeline_name: str):
        """
        Plot a global confusion matrix based on pipeline results.

        Args:
            pipeline_name (str): Name of the pipeline.

        Returns:
            None
        """
        # Get the id of the pipeline by name
        pipeline_id = get_pipeline_from_name(pipeline_name)
        # Get the confusion matrix of the pipeline
        confusion_matrix = get_pipeline_confusion_matrix(pipeline_id)

        # Extract confusion matrix values
        TP = confusion_matrix['TP']
        FP = confusion_matrix['FP']
        FN = confusion_matrix['FN']
        TN = confusion_matrix['TN']

        # Create a matrix for visualization
        matrix = [[TN, FP],
                  [FN, TP]]

        # Plot the confusion matrix as a heatmap
        plt.figure(figsize=(6, 4))
        sns.heatmap(matrix, annot=True, fmt='d', cmap='Blues',
                    xticklabels=['Predicted Negative', 'Predicted Positive'],
                    yticklabels=['Actual Negative', 'Actual Positive'])
        plt.title('Global Confusion Matrix')
        plt.xlabel('Predicted label')
        plt.ylabel('True label')
        plt.tight_layout()

        # Display the confusion matrix heatmap
        plt.show()

    @staticmethod
    def plot_confusion_Matrix_by_node(node_name: str, pipeline_name: str):
        """
        Plot a confusion matrix for a specific node in the pipeline.

        Args:
            node_name (str): Name of the node.
            pipeline_name (str): Name of the pipeline.

        Returns:
            None
        """
        # Get the id of the pipeline by name
        pipeline_id = get_pipeline_from_name(pipeline_name)
        # Get the confusion matrix of the node
        confusion_matrix = get_node_confusion_matrix(
            pipeline_id, node_name=node_name)

        # Extract confusion matrix values
        TP = confusion_matrix['TP']
        FP = confusion_matrix['FP']
        FN = confusion_matrix['FN']
        TN = confusion_matrix['TN']

        # Create a matrix for visualization
        matrix = [[TN, FP],
                  [FN, TP]]

        # Plot the confusion matrix as a heatmap
        plt.figure(figsize=(6, 4))
        sns.heatmap(matrix, annot=True, fmt='d', cmap='Blues',
                    xticklabels=['Predicted Negative', 'Predicted Positive'],
                    yticklabels=['Actual Negative', 'Actual Positive'])
        plt.title('Confusion Matrix of node: ' + node_name)
        plt.xlabel('Predicted label')
        plt.ylabel('True label')
        plt.tight_layout()

        # Display the confusion matrix heatmap
        plt.show()
        return

    @staticmethod
    def plot_classification_report(pipeline_name: str):
        """
        Plot a comparison of classification report metrics between nodes.

        Args:
            pipeline_name (str): Name of the pipeline.

        Returns:
            None
        """
        colors = ['#FF5733', '#6A5ACD', '#3CB371', '#FFD700', '#FFA500',
                  '#8A2BE2', '#00FFFF', '#FF00FF', '#A52A2A', '#00FF00']

        # Get the id of the pipeline by name
        pipeline_id = get_pipeline_from_name(pipeline_name)

        pipeline_results = get_pipeline_result(pipeline_id)

        nodesList = pipeline_results['nodename']
        classificationReports = []

        for index, node in enumerate(nodesList):
            classificationReports.append({
                'Accuracy': pipeline_results['accuracy'][index],
                'Sensitivity/Recall': pipeline_results['sensivity'][index],
                'PPV/Precision': pipeline_results['ppv'][index],
                'NPV': pipeline_results['npv'][index],
                'F1-score': pipeline_results['f1score'][index],
                'False positive rate': pipeline_results['fpr'][index],
                'True positive rate': pipeline_results['tpr'][index]
            })

        # Assuming all reports have the same keys
        metric_labels = list(classificationReports[0].keys())

        # Set the positions of the bars on the x-axis
        x = np.arange(len(metric_labels))

        # Set the width of the bars
        width = 0.35

        plt.figure(figsize=(12, 6))

        for index, report in enumerate(classificationReports):
            metric = list(report.values())
            plt.bar(x + (index - len(nodesList) / 2) * width / len(nodesList),
                    metric, width / len(nodesList),
                    label=nodesList[index], color=colors[index % len(colors)])

        # Add labels, title, and legend
        plt.xlabel('Metrics')
        plt.ylabel('Values')
        plt.title('Comparison of Classification Report Metrics between Nodes')
        plt.xticks(ticks=x, labels=metric_labels, rotation=45)
        plt.legend()

        # Show plot
        plt.tight_layout()
        plt.show()

        return
```
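For orientation, a minimal usage sketch of the class above — the results dict shape follows the module-level example, the pipeline and node names are placeholders, and the two static helpers assume a populated MEDfl results database:

```python
from MEDfl.LearningManager.plot import AccuracyLossPlotter

# Keys are (parameter description, metric) tuples; values hold one point per round.
results = {
    ("LR: 0.001, Optimizer: Adam", "accuracy"): [0.85, 0.89, 0.92, 0.94],
    ("LR: 0.001, Optimizer: Adam", "loss"): [0.20, 0.15, 0.10, 0.08],
}

plotter = AccuracyLossPlotter(results)
plotter.plot_accuracy_loss()  # one curve per (parameter, metric) pair

# The static helpers resolve a pipeline by name and query its stored metrics:
AccuracyLossPlotter.plot_global_confusion_matrix("my_pipeline")              # placeholder name
AccuracyLossPlotter.plot_confusion_Matrix_by_node("node_1", "my_pipeline")   # placeholder names
```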