MEDfl 0.1.26__py3-none-any.whl → 0.1.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. MEDfl/LearningManager/__init__.py +13 -0
  2. MEDfl/LearningManager/client.py +150 -0
  3. MEDfl/LearningManager/dynamicModal.py +287 -0
  4. MEDfl/LearningManager/federated_dataset.py +60 -0
  5. MEDfl/LearningManager/flpipeline.py +192 -0
  6. MEDfl/LearningManager/model.py +223 -0
  7. MEDfl/LearningManager/params.yaml +14 -0
  8. MEDfl/LearningManager/params_optimiser.py +442 -0
  9. MEDfl/LearningManager/plot.py +229 -0
  10. MEDfl/LearningManager/server.py +181 -0
  11. MEDfl/LearningManager/strategy.py +82 -0
  12. MEDfl/LearningManager/utils.py +308 -0
  13. MEDfl/NetManager/__init__.py +9 -0
  14. MEDfl/NetManager/database_connector.py +48 -0
  15. MEDfl/NetManager/dataset.py +92 -0
  16. MEDfl/NetManager/flsetup.py +320 -0
  17. MEDfl/NetManager/net_helper.py +248 -0
  18. MEDfl/NetManager/net_manager_queries.py +137 -0
  19. MEDfl/NetManager/network.py +174 -0
  20. MEDfl/NetManager/node.py +178 -0
  21. MEDfl/__init__.py +2 -0
  22. {MEDfl-0.1.26.data → MEDfl-0.1.27.data}/scripts/setup_mysql.sh +22 -22
  23. {MEDfl-0.1.26.dist-info → MEDfl-0.1.27.dist-info}/METADATA +127 -127
  24. MEDfl-0.1.27.dist-info/RECORD +53 -0
  25. {MEDfl-0.1.26.dist-info → MEDfl-0.1.27.dist-info}/WHEEL +1 -1
  26. Medfl/LearningManager/__init__.py +13 -13
  27. Medfl/LearningManager/client.py +150 -150
  28. Medfl/LearningManager/dynamicModal.py +287 -287
  29. Medfl/LearningManager/federated_dataset.py +60 -60
  30. Medfl/LearningManager/flpipeline.py +192 -192
  31. Medfl/LearningManager/model.py +223 -223
  32. Medfl/LearningManager/params.yaml +14 -14
  33. Medfl/LearningManager/params_optimiser.py +442 -442
  34. Medfl/LearningManager/plot.py +229 -229
  35. Medfl/LearningManager/server.py +181 -181
  36. Medfl/LearningManager/strategy.py +82 -82
  37. Medfl/LearningManager/utils.py +308 -308
  38. Medfl/NetManager/__init__.py +9 -9
  39. Medfl/NetManager/database_connector.py +48 -48
  40. Medfl/NetManager/dataset.py +92 -92
  41. Medfl/NetManager/flsetup.py +320 -320
  42. Medfl/NetManager/net_helper.py +248 -248
  43. Medfl/NetManager/net_manager_queries.py +137 -137
  44. Medfl/NetManager/network.py +174 -174
  45. Medfl/NetManager/node.py +178 -178
  46. Medfl/__init__.py +1 -1
  47. alembic/env.py +61 -61
  48. scripts/base.py +29 -29
  49. scripts/config.ini +5 -5
  50. scripts/create_db.py +133 -133
  51. MEDfl-0.1.26.dist-info/RECORD +0 -32
  52. {MEDfl-0.1.26.dist-info → MEDfl-0.1.27.dist-info}/top_level.txt +0 -0
@@ -1,308 +1,308 @@
1
- #!/usr/bin/env python3
2
-
3
- import pkg_resources
4
- import torch
5
- import yaml
6
- from sklearn.metrics import *
7
- from yaml.loader import SafeLoader
8
-
9
-
10
- from MEDfl.NetManager.database_connector import DatabaseManager
11
-
12
- # from scripts.base import *
13
- import json
14
-
15
-
16
- import pandas as pd
17
- import numpy as np
18
-
19
- import os
20
- import configparser
21
-
22
- import subprocess
23
- import ast
24
-
25
- from sqlalchemy import text
26
-
27
-
28
# params.yaml lives next to this module; resolve it from __file__ so loading
# does not depend on the current working directory.
current_directory = os.path.dirname(os.path.abspath(__file__))
yaml_path = os.path.join(current_directory, 'params.yaml')

# Module-level parameters, parsed once at import time. yaml.safe_load is the
# same as yaml.load(..., Loader=SafeLoader): only plain Python objects are built.
with open(yaml_path) as g:
    params = yaml.safe_load(g)

# Default file name for the database config file (not referenced by the
# functions in this module's visible code — confirm before removing).
DEFAULT_CONFIG_PATH = 'db_config.ini'
44
-
45
-
46
def load_db_config():
    """Return the MEDfl database configuration stored in the environment.

    The configuration is read from the ``MEDfl_DB_CONFIG`` environment
    variable, which ``set_db_config`` fills with ``str()`` of a dict.

    Returns:
        dict: The configuration, parsed with ``ast.literal_eval`` (only Python
        literals are accepted, nothing is executed).

    Raises:
        ValueError: If the variable is unset or empty.
    """
    raw_config = os.environ.get('MEDfl_DB_CONFIG')
    if not raw_config:
        raise ValueError("MEDfl db config not found")
    return ast.literal_eval(raw_config)
53
-
54
# Function to allow users to set config path programmatically
def set_db_config(config_path):
    """Load the ``[mysql]`` section of an INI file into the environment.

    The section is stored as ``str()`` of a dict in the ``MEDfl_DB_CONFIG``
    environment variable, where ``load_db_config`` reads it back.

    Args:
        config_path (str): Path to an INI file containing a ``[mysql]`` section.

    Raises:
        ValueError: If the ``[mysql]`` section is missing or empty. This
            includes a non-existent/unreadable file, because
            ``ConfigParser.read`` silently skips files it cannot open.
    """
    config = configparser.ConfigParser()
    config.read(config_path)
    # Indexing config['mysql'] directly raises KeyError when the section is
    # absent, which made the intended ValueError unreachable; check first.
    if not config.has_section('mysql') or not config['mysql']:
        raise ValueError(f"mysql key not found in file '{config_path}'")
    os.environ['MEDfl_DB_CONFIG'] = str(dict(config['mysql']))
64
-
65
-
66
-
67
# Create database
def create_MEDfl_db():
    """Run the ``scripts/create_db.sh`` shell script located next to this module.

    The script is executed with ``sh``; a non-zero exit status raises
    ``subprocess.CalledProcessError`` because ``check=True`` is passed.

    NOTE(review): no ``create_db.sh`` appears in this release's file list
    (only ``scripts/create_db.py`` and ``scripts/setup_mysql.sh``) — confirm
    the script is actually shipped at this path.
    """
    module_dir = os.path.dirname(__file__)
    script = os.path.join(module_dir, 'scripts', 'create_db.sh')
    subprocess.run(['sh', script], check=True)
74
-
75
-
76
def custom_classification_report(y_true, y_pred_prob):
    """
    Compute a binary classification report from true labels and predicted probabilities.

    Hard predictions are obtained by rounding ``y_pred_prob`` to the nearest
    integer (NumPy and PyTorch round exact halves to even, so 0.5 becomes 0).
    The report covers the confusion matrix, accuracy, sensitivity,
    specificity, precision, NPV, F1-score, false positive rate, true positive
    rate and ROC AUC. Any ratio whose denominator is zero is reported as 0.0.

    Args:
        y_true (array-like): True binary labels.
        y_pred_prob (numpy.ndarray): Predicted probabilities for the positive
            class; must support ``.round()``.

    Returns:
        dict: The report. Confusion-matrix counts are plain Python ``int`` and
        metrics plain Python ``float`` (rounded to 3 decimals, except ``auc``),
        so ``str()`` of the dict is a parseable Python literal.

    Raises:
        ValueError: Propagated from ``roc_auc_score`` (e.g. when ``y_true``
            contains a single class), or from unpacking the confusion matrix
            when it is not 2x2.
    """
    def _ratio(numerator, denominator):
        # Undefined metrics (zero denominator) are reported as 0.0.
        return numerator / denominator if denominator != 0 else 0.0

    # Threshold the probabilities by rounding to the nearest integer.
    y_pred = y_pred_prob.round()

    auc = float(roc_auc_score(y_true, y_pred_prob))

    # Convert NumPy integer counts to Python ints: under NumPy >= 2 the repr of
    # np.int64 is "np.int64(5)", which breaks code that parses str(report)
    # back into a dict (e.g. json.loads after swapping quotes).
    tn, fp, fn, tp = (int(count) for count in confusion_matrix(y_true, y_pred).ravel())

    acc = _ratio(tp + tn, tp + tn + fp + fn)
    sen = _ratio(tp, tp + fn)   # sensitivity / recall
    sp = _ratio(tn, tn + fp)    # specificity
    ppv = _ratio(tp, tp + fp)   # precision
    npv = _ratio(tn, tn + fn)
    f1 = _ratio(2 * sen * ppv, sen + ppv)
    fpr = _ratio(fp, fp + tn)
    tpr = _ratio(tp, tp + fn)

    return {
        "confusion matrix": {"TP": tp, "FP": fp, "FN": fn, "TN": tn},
        "Accuracy": round(acc, 3),
        "Sensitivity/Recall": round(sen, 3),
        "Specificity": round(sp, 3),
        "PPV/Precision": round(ppv, 3),
        "NPV": round(npv, 3),
        "F1-score": round(f1, 3),
        "False positive rate": round(fpr, 3),
        "True positive rate": round(tpr, 3),
        "auc": auc,
    }
139
-
140
-
141
def test(model, test_loader, device=torch.device("cpu")):
    """
    Evaluate a model on the whole test set and return a custom classification report.

    The full dataset is fetched in a single batch via ``test_loader.dataset[:]``,
    so the dataset must support slice indexing returning ``(features, labels, ...)``
    (e.g. ``torch.utils.data.TensorDataset``); the loader's own batching is not used.

    Args:
        model (torch.nn.Module): PyTorch model to evaluate. It is switched to
            eval mode and left in eval mode.
        test_loader (torch.utils.data.DataLoader): DataLoader whose ``dataset``
            holds the test data.
        device (torch.device, optional): Device for the forward pass. Default is "cpu".

    Returns:
        dict: Result of ``custom_classification_report`` for the labels and
        the model's squeezed outputs.
    """
    model.eval()
    with torch.no_grad():
        # Index the dataset once instead of once per tensor.
        full_batch = test_loader.dataset[:]
        X_test = full_batch[0].to(device)
        # Labels are not needed on `device`; only the model input is.
        y_test = full_batch[1]
        # Drop the singleton output dimension (assumes output shape (N, 1) — confirm).
        y_hat_prob = torch.squeeze(model(X_test), 1).cpu()

    return custom_classification_report(y_test.cpu().numpy(), y_hat_prob.numpy())
161
-
162
-
163
# Maps pandas/NumPy dtype names (as strings) to SQL column type declarations.
column_map = {
    "object": "VARCHAR(255)",
    "int64": "INT",
    "float64": "FLOAT",
}
164
-
165
-
166
def empty_db():
    """
    Empty the MEDfl database.

    Deletes all rows from Nodes, FedDatasets, Networks, FLsetup, FLpipeline
    and testResults. Resets the AUTO_INCREMENT counter (MySQL syntax) of
    Nodes, Networks, FedDatasets, FLsetup and FLpipeline to 1. Drops the
    MasterDataset and DataSets tables if they exist.

    Statements run strictly in the order listed below. Presumably that order
    respects foreign-key dependencies — confirm before reordering.

    NOTE(review): no commit is issued here; whether these changes persist
    depends on what ``DatabaseManager.get_connection()`` returns (autocommit
    engine vs. transactional connection) — confirm.

    Returns:
        None
    """
    db_manager = DatabaseManager()
    db_manager.connect()
    my_eng = db_manager.get_connection()

    statements = (
        "DELETE FROM Nodes",
        "DELETE FROM FedDatasets",
        "DELETE FROM Networks",
        "DELETE FROM FLsetup",
        "DELETE FROM FLpipeline",
        "ALTER TABLE Nodes AUTO_INCREMENT = 1",
        "ALTER TABLE Networks AUTO_INCREMENT = 1",
        "ALTER TABLE FedDatasets AUTO_INCREMENT = 1",
        "ALTER TABLE FLsetup AUTO_INCREMENT = 1",
        "ALTER TABLE FLpipeline AUTO_INCREMENT = 1",
        "DELETE FROM testResults",
        "DROP TABLE IF EXISTS MasterDataset",
        "DROP TABLE IF EXISTS DataSets",
    )
    for statement in statements:
        my_eng.execute(text(statement))
192
-
193
-
194
def get_pipeline_from_name(name):
    """
    Get the pipeline ID from its name in the database.

    Args:
        name (str): Name of the pipeline.

    Returns:
        int: ``id`` of the first ``FLpipeline`` row returned for that name.

    Raises:
        IndexError: If no pipeline has that name.
    """
    db_manager = DatabaseManager()
    db_manager.connect()
    my_eng = db_manager.get_connection()

    # Bind the name as a query parameter instead of interpolating it into the
    # SQL text: a name containing a quote used to break the query and allowed
    # SQL injection.
    result = pd.read_sql(
        text("SELECT id FROM FLpipeline WHERE name = :name"),
        my_eng,
        params={"name": name},
    )
    pipeline_id = int(result.iloc[0, 0])
    return pipeline_id
214
-
215
-
216
def get_pipeline_confusion_matrix(pipeline_id):
    """
    Get the global confusion matrix for a pipeline based on test results.

    Sums the TP/FP/FN/TN counts of every ``testResults`` row of the pipeline.

    Args:
        pipeline_id (int): ID of the pipeline.

    Returns:
        dict: ``{'TP', 'FP', 'FN', 'TN'}`` totals (all 0 if the pipeline has
        no test results).
    """
    db_manager = DatabaseManager()
    db_manager.connect()
    my_eng = db_manager.get_connection()

    # Bound parameter instead of string interpolation (quoting/injection safe).
    data = pd.read_sql(
        text("SELECT confusionmatrix FROM testResults WHERE pipelineid = :pipeline_id"),
        my_eng,
        params={"pipeline_id": pipeline_id},
    )

    global_confusion_matrix = {'TP': 0, 'FP': 0, 'FN': 0, 'TN': 0}
    for raw_matrix in data['confusionmatrix']:
        # Stored matrices presumably come from str(dict) (single quotes);
        # swap to double quotes so they parse as JSON.
        matrix = json.loads(raw_matrix.replace("'", "\""))
        for key in global_confusion_matrix:
            global_confusion_matrix[key] += matrix[key]

    return global_confusion_matrix
259
-
260
-
261
def get_node_confusion_matrix(pipeline_id, node_name):
    """
    Get the confusion matrix for a specific node in a pipeline based on test results.

    Args:
        pipeline_id (int): ID of the pipeline.
        node_name (str): Name of the node.

    Returns:
        dict: The confusion matrix of the first matching ``testResults`` row.

    Raises:
        IndexError: If the node has no test result for that pipeline.
    """
    db_manager = DatabaseManager()
    db_manager.connect()
    my_eng = db_manager.get_connection()

    # Bound parameters instead of string interpolation: node names containing
    # a quote used to break the query and allowed SQL injection.
    data = pd.read_sql(
        text(
            "SELECT confusionmatrix FROM testResults "
            "WHERE pipelineid = :pipeline_id AND nodename = :node_name"
        ),
        my_eng,
        params={"pipeline_id": pipeline_id, "node_name": node_name},
    )

    # Only the first row is returned, so only that one is parsed. Stored
    # matrices presumably come from str(dict) (single quotes); swap to double
    # quotes so they parse as JSON.
    first_matrix = data['confusionmatrix'].iloc[0]
    return json.loads(first_matrix.replace("'", "\""))
288
-
289
-
290
def get_pipeline_result(pipeline_id):
    """
    Get the test results for a pipeline.

    Args:
        pipeline_id (int): ID of the pipeline.

    Returns:
        pandas.DataFrame: All ``testResults`` columns for rows whose
        ``pipelineid`` matches (empty if there are none).
    """
    db_manager = DatabaseManager()
    db_manager.connect()
    my_eng = db_manager.get_connection()

    # Bound parameter instead of string interpolation (quoting/injection safe).
    data = pd.read_sql(
        text("SELECT * FROM testResults WHERE pipelineid = :pipeline_id"),
        my_eng,
        params={"pipeline_id": pipeline_id},
    )
    return data
1
+ #!/usr/bin/env python3
2
+
3
+ import pkg_resources
4
+ import torch
5
+ import yaml
6
+ from sklearn.metrics import *
7
+ from yaml.loader import SafeLoader
8
+
9
+
10
+ from MEDfl.NetManager.database_connector import DatabaseManager
11
+
12
+ # from scripts.base import *
13
+ import json
14
+
15
+
16
+ import pandas as pd
17
+ import numpy as np
18
+
19
+ import os
20
+ import configparser
21
+
22
+ import subprocess
23
+ import ast
24
+
25
+ from sqlalchemy import text
26
+
27
+
28
# params.yaml lives next to this module; resolve it from __file__ so loading
# does not depend on the current working directory.
current_directory = os.path.dirname(os.path.abspath(__file__))
yaml_path = os.path.join(current_directory, 'params.yaml')

# Module-level parameters, parsed once at import time. yaml.safe_load is the
# same as yaml.load(..., Loader=SafeLoader): only plain Python objects are built.
with open(yaml_path) as g:
    params = yaml.safe_load(g)

# Default file name for the database config file (not referenced by the
# functions in this module's visible code — confirm before removing).
DEFAULT_CONFIG_PATH = 'db_config.ini'
44
+
45
+
46
def load_db_config():
    """Return the MEDfl database configuration stored in the environment.

    The configuration is read from the ``MEDfl_DB_CONFIG`` environment
    variable, which ``set_db_config`` fills with ``str()`` of a dict.

    Returns:
        dict: The configuration, parsed with ``ast.literal_eval`` (only Python
        literals are accepted, nothing is executed).

    Raises:
        ValueError: If the variable is unset or empty.
    """
    raw_config = os.environ.get('MEDfl_DB_CONFIG')
    if not raw_config:
        raise ValueError("MEDfl db config not found")
    return ast.literal_eval(raw_config)
53
+
54
# Function to allow users to set config path programmatically
def set_db_config(config_path):
    """Load the ``[mysql]`` section of an INI file into the environment.

    The section is stored as ``str()`` of a dict in the ``MEDfl_DB_CONFIG``
    environment variable, where ``load_db_config`` reads it back.

    Args:
        config_path (str): Path to an INI file containing a ``[mysql]`` section.

    Raises:
        ValueError: If the ``[mysql]`` section is missing or empty. This
            includes a non-existent/unreadable file, because
            ``ConfigParser.read`` silently skips files it cannot open.
    """
    config = configparser.ConfigParser()
    config.read(config_path)
    # Indexing config['mysql'] directly raises KeyError when the section is
    # absent, which made the intended ValueError unreachable; check first.
    if not config.has_section('mysql') or not config['mysql']:
        raise ValueError(f"mysql key not found in file '{config_path}'")
    os.environ['MEDfl_DB_CONFIG'] = str(dict(config['mysql']))
64
+
65
+
66
+
67
# Create database
def create_MEDfl_db():
    """Run the ``scripts/create_db.sh`` shell script located next to this module.

    The script is executed with ``sh``; a non-zero exit status raises
    ``subprocess.CalledProcessError`` because ``check=True`` is passed.

    NOTE(review): no ``create_db.sh`` appears in this release's file list
    (only ``scripts/create_db.py`` and ``scripts/setup_mysql.sh``) — confirm
    the script is actually shipped at this path.
    """
    module_dir = os.path.dirname(__file__)
    script = os.path.join(module_dir, 'scripts', 'create_db.sh')
    subprocess.run(['sh', script], check=True)
74
+
75
+
76
def custom_classification_report(y_true, y_pred_prob):
    """
    Compute a binary classification report from true labels and predicted probabilities.

    Hard predictions are obtained by rounding ``y_pred_prob`` to the nearest
    integer (NumPy and PyTorch round exact halves to even, so 0.5 becomes 0).
    The report covers the confusion matrix, accuracy, sensitivity,
    specificity, precision, NPV, F1-score, false positive rate, true positive
    rate and ROC AUC. Any ratio whose denominator is zero is reported as 0.0.

    Args:
        y_true (array-like): True binary labels.
        y_pred_prob (numpy.ndarray): Predicted probabilities for the positive
            class; must support ``.round()``.

    Returns:
        dict: The report. Confusion-matrix counts are plain Python ``int`` and
        metrics plain Python ``float`` (rounded to 3 decimals, except ``auc``),
        so ``str()`` of the dict is a parseable Python literal.

    Raises:
        ValueError: Propagated from ``roc_auc_score`` (e.g. when ``y_true``
            contains a single class), or from unpacking the confusion matrix
            when it is not 2x2.
    """
    def _ratio(numerator, denominator):
        # Undefined metrics (zero denominator) are reported as 0.0.
        return numerator / denominator if denominator != 0 else 0.0

    # Threshold the probabilities by rounding to the nearest integer.
    y_pred = y_pred_prob.round()

    auc = float(roc_auc_score(y_true, y_pred_prob))

    # Convert NumPy integer counts to Python ints: under NumPy >= 2 the repr of
    # np.int64 is "np.int64(5)", which breaks code that parses str(report)
    # back into a dict (e.g. json.loads after swapping quotes).
    tn, fp, fn, tp = (int(count) for count in confusion_matrix(y_true, y_pred).ravel())

    acc = _ratio(tp + tn, tp + tn + fp + fn)
    sen = _ratio(tp, tp + fn)   # sensitivity / recall
    sp = _ratio(tn, tn + fp)    # specificity
    ppv = _ratio(tp, tp + fp)   # precision
    npv = _ratio(tn, tn + fn)
    f1 = _ratio(2 * sen * ppv, sen + ppv)
    fpr = _ratio(fp, fp + tn)
    tpr = _ratio(tp, tp + fn)

    return {
        "confusion matrix": {"TP": tp, "FP": fp, "FN": fn, "TN": tn},
        "Accuracy": round(acc, 3),
        "Sensitivity/Recall": round(sen, 3),
        "Specificity": round(sp, 3),
        "PPV/Precision": round(ppv, 3),
        "NPV": round(npv, 3),
        "F1-score": round(f1, 3),
        "False positive rate": round(fpr, 3),
        "True positive rate": round(tpr, 3),
        "auc": auc,
    }
139
+
140
+
141
def test(model, test_loader, device=torch.device("cpu")):
    """
    Evaluate a model on the whole test set and return a custom classification report.

    The full dataset is fetched in a single batch via ``test_loader.dataset[:]``,
    so the dataset must support slice indexing returning ``(features, labels, ...)``
    (e.g. ``torch.utils.data.TensorDataset``); the loader's own batching is not used.

    Args:
        model (torch.nn.Module): PyTorch model to evaluate. It is switched to
            eval mode and left in eval mode.
        test_loader (torch.utils.data.DataLoader): DataLoader whose ``dataset``
            holds the test data.
        device (torch.device, optional): Device for the forward pass. Default is "cpu".

    Returns:
        dict: Result of ``custom_classification_report`` for the labels and
        the model's squeezed outputs.
    """
    model.eval()
    with torch.no_grad():
        # Index the dataset once instead of once per tensor.
        full_batch = test_loader.dataset[:]
        X_test = full_batch[0].to(device)
        # Labels are not needed on `device`; only the model input is.
        y_test = full_batch[1]
        # Drop the singleton output dimension (assumes output shape (N, 1) — confirm).
        y_hat_prob = torch.squeeze(model(X_test), 1).cpu()

    return custom_classification_report(y_test.cpu().numpy(), y_hat_prob.numpy())
161
+
162
+
163
# Maps pandas/NumPy dtype names (as strings) to SQL column type declarations.
column_map = {
    "object": "VARCHAR(255)",
    "int64": "INT",
    "float64": "FLOAT",
}
164
+
165
+
166
def empty_db():
    """
    Empty the MEDfl database.

    Deletes all rows from Nodes, FedDatasets, Networks, FLsetup, FLpipeline
    and testResults. Resets the AUTO_INCREMENT counter (MySQL syntax) of
    Nodes, Networks, FedDatasets, FLsetup and FLpipeline to 1. Drops the
    MasterDataset and DataSets tables if they exist.

    Statements run strictly in the order listed below. Presumably that order
    respects foreign-key dependencies — confirm before reordering.

    NOTE(review): no commit is issued here; whether these changes persist
    depends on what ``DatabaseManager.get_connection()`` returns (autocommit
    engine vs. transactional connection) — confirm.

    Returns:
        None
    """
    db_manager = DatabaseManager()
    db_manager.connect()
    my_eng = db_manager.get_connection()

    statements = (
        "DELETE FROM Nodes",
        "DELETE FROM FedDatasets",
        "DELETE FROM Networks",
        "DELETE FROM FLsetup",
        "DELETE FROM FLpipeline",
        "ALTER TABLE Nodes AUTO_INCREMENT = 1",
        "ALTER TABLE Networks AUTO_INCREMENT = 1",
        "ALTER TABLE FedDatasets AUTO_INCREMENT = 1",
        "ALTER TABLE FLsetup AUTO_INCREMENT = 1",
        "ALTER TABLE FLpipeline AUTO_INCREMENT = 1",
        "DELETE FROM testResults",
        "DROP TABLE IF EXISTS MasterDataset",
        "DROP TABLE IF EXISTS DataSets",
    )
    for statement in statements:
        my_eng.execute(text(statement))
192
+
193
+
194
def get_pipeline_from_name(name):
    """
    Get the pipeline ID from its name in the database.

    Args:
        name (str): Name of the pipeline.

    Returns:
        int: ``id`` of the first ``FLpipeline`` row returned for that name.

    Raises:
        IndexError: If no pipeline has that name.
    """
    db_manager = DatabaseManager()
    db_manager.connect()
    my_eng = db_manager.get_connection()

    # Bind the name as a query parameter instead of interpolating it into the
    # SQL text: a name containing a quote used to break the query and allowed
    # SQL injection.
    result = pd.read_sql(
        text("SELECT id FROM FLpipeline WHERE name = :name"),
        my_eng,
        params={"name": name},
    )
    pipeline_id = int(result.iloc[0, 0])
    return pipeline_id
214
+
215
+
216
def get_pipeline_confusion_matrix(pipeline_id):
    """
    Get the global confusion matrix for a pipeline based on test results.

    Sums the TP/FP/FN/TN counts of every ``testResults`` row of the pipeline.

    Args:
        pipeline_id (int): ID of the pipeline.

    Returns:
        dict: ``{'TP', 'FP', 'FN', 'TN'}`` totals (all 0 if the pipeline has
        no test results).
    """
    db_manager = DatabaseManager()
    db_manager.connect()
    my_eng = db_manager.get_connection()

    # Bound parameter instead of string interpolation (quoting/injection safe).
    data = pd.read_sql(
        text("SELECT confusionmatrix FROM testResults WHERE pipelineid = :pipeline_id"),
        my_eng,
        params={"pipeline_id": pipeline_id},
    )

    global_confusion_matrix = {'TP': 0, 'FP': 0, 'FN': 0, 'TN': 0}
    for raw_matrix in data['confusionmatrix']:
        # Stored matrices presumably come from str(dict) (single quotes);
        # swap to double quotes so they parse as JSON.
        matrix = json.loads(raw_matrix.replace("'", "\""))
        for key in global_confusion_matrix:
            global_confusion_matrix[key] += matrix[key]

    return global_confusion_matrix
259
+
260
+
261
def get_node_confusion_matrix(pipeline_id, node_name):
    """
    Get the confusion matrix for a specific node in a pipeline based on test results.

    Args:
        pipeline_id (int): ID of the pipeline.
        node_name (str): Name of the node.

    Returns:
        dict: The confusion matrix of the first matching ``testResults`` row.

    Raises:
        IndexError: If the node has no test result for that pipeline.
    """
    db_manager = DatabaseManager()
    db_manager.connect()
    my_eng = db_manager.get_connection()

    # Bound parameters instead of string interpolation: node names containing
    # a quote used to break the query and allowed SQL injection.
    data = pd.read_sql(
        text(
            "SELECT confusionmatrix FROM testResults "
            "WHERE pipelineid = :pipeline_id AND nodename = :node_name"
        ),
        my_eng,
        params={"pipeline_id": pipeline_id, "node_name": node_name},
    )

    # Only the first row is returned, so only that one is parsed. Stored
    # matrices presumably come from str(dict) (single quotes); swap to double
    # quotes so they parse as JSON.
    first_matrix = data['confusionmatrix'].iloc[0]
    return json.loads(first_matrix.replace("'", "\""))
288
+
289
+
290
def get_pipeline_result(pipeline_id):
    """
    Get the test results for a pipeline.

    Args:
        pipeline_id (int): ID of the pipeline.

    Returns:
        pandas.DataFrame: All ``testResults`` columns for rows whose
        ``pipelineid`` matches (empty if there are none).
    """
    db_manager = DatabaseManager()
    db_manager.connect()
    my_eng = db_manager.get_connection()

    # Bound parameter instead of string interpolation (quoting/injection safe).
    data = pd.read_sql(
        text("SELECT * FROM testResults WHERE pipelineid = :pipeline_id"),
        my_eng,
        params={"pipeline_id": pipeline_id},
    )
    return data
@@ -1,9 +1,9 @@
1
- # # MEDfl/NetManager/__init__.py
2
-
3
- # # Import modules from this package
4
- # from .dataset import *
5
- # from .flsetup import *
6
- # from .net_helper import *
7
- # from .net_manager_queries import *
8
- # from .network import *
9
- # from .node import *
1
+ # # MEDfl/NetManager/__init__.py
2
+
3
+ # # Import modules from this package
4
+ # from .dataset import *
5
+ # from .flsetup import *
6
+ # from .net_helper import *
7
+ # from .net_manager_queries import *
8
+ # from .network import *
9
+ # from .node import *