uniovi-simur-wearablepermed-ml 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of uniovi-simur-wearablepermed-ml might be problematic. Click here for more details.

@@ -0,0 +1,63 @@
1
+ from models import *
2
+ from basic_functions.address import *
3
+ import pandas as pd
4
+ from data import DataReader
5
+ import ast
6
+
7
# Model factory pattern
def modelGenerator(modelID: str, data, params: "dict | None" = None, verbose=False, debug=False):
    """Build the SiMuR model identified by ``modelID``.

    ARGUMENTS
        modelID (str)   ID that indicates the model type. One of
                        "ESANN", "CAPTURE24", "RandomForest", "XGBoost".
        data            Data object needed to train; passed unchanged to the
                        model constructor. Must expose ``dataID`` when the
                        hyperparameters are loaded from disk.
        params (dict)   The params that define the model. When empty/None and
                        ``debug`` is False, the best hyperparameters are read
                        from the CSV file returned by ``get_param_path(modelID)``.
        verbose (bool)  Print progress messages.
        debug (bool)    Never load stored hyperparameters, even if ``params``
                        is empty.

    RETURNS
        The constructed model instance.

    RAISES
        ValueError  if ``modelID`` is not one of the supported IDs
                    (ValueError is an Exception, so existing handlers still work).
    """
    # A fresh dict per call: a ``params={}`` default would be one object shared
    # by every call and handed to the model constructors, so any mutation
    # would leak into (and change the behaviour of) later calls.
    if params is None:
        params = {}

    if verbose:
        print("Building model")

    if not params and not debug:
        if verbose:
            print("loading best hyperparameters")
        params_path = get_param_path(modelID)
        df_params = pd.read_csv(params_path, index_col=0)
        # The 'params' cell holds a Python literal; literal_eval parses it
        # without executing code. Presumably a list whose first element is the
        # hyperparameter dict - TODO confirm against the CSV writer.
        params = ast.literal_eval(df_params.loc[data.dataID, 'params'])[0]

    # Dispatch on the model ID.
    if modelID == "ESANN":          # 1. CNN ESANN
        return SiMuRModel_ESANN(data, params)
    if modelID == "CAPTURE24":      # 2. CNN CAPTURE-24
        return SiMuRModel_CAPTURE24(data, params)
    if modelID == "RandomForest":   # 3. Random Forest
        return SiMuRModel_RandomForest(data, params)
    if modelID == "XGBoost":        # 4. XGBoost
        return SiMuRModel_XGBoost(data, params)

    raise ValueError(f"Model not implemented: {modelID!r}")
51
+
52
# Unit testing (manual smoke test)
if __name__ == "__main__":
    # Test time models. modelGenerator only recognises the bare IDs
    # "ESANN", "CAPTURE24", "RandomForest" and "XGBoost"; the former value
    # "SiMuRModel_CAPTURE24_data_tot" always hit "Model not implemented".
    # modelID = "ESANN"
    modelID = "CAPTURE24"
    # params = {"N_capas": 3}
    params = {"N_capas": 6}
    # NOTE(review): tester() in the testing module builds DataReader with
    # modelID=, create_superclasses=, p_train=<percent>, file_path=,
    # label_encoder_path=, ...; confirm this (p_train=0.7, dataset=...) call
    # still matches DataReader's current signature.
    data = DataReader(p_train=0.7, dataset='data_tot')

    model = modelGenerator(modelID=modelID, data=data, params=params, debug=False)
    model.train()
@@ -0,0 +1,130 @@
1
+ import os # Para manejar rutas de archivos
2
+ import subprocess # Para ejecutar scripts Python como procesos separados
3
+ import re # Para expresiones regulares, usado para extraer accuracy
4
+ import numpy as np # Para operaciones numéricas y guardar resultados
5
+
6
N_RUNS = 30  # Number of train+test runs

# Base folder holding the case data.
#   Windows:           case_id_folder = "D:\\DATA_PMP_File_Server\\output"
#   Linux (local SSD): case_id_folder = "/mnt/nvme1n2/git/uniovi-simur-wearablepermed-data/output"
case_id_folder = "/mnt/simur-fileserver/data/wearablepermed/output"

case_id = "cases_dataset_PI_M/case_PI_M_BRF_superclasses"  # Case identifier

# Settings that trainer.py and tester.py must agree on. They were previously
# repeated as literals in both argument lists, so changing one list but not
# the other would silently test a different configuration than was trained.
ML_MODEL = "RandomForest"
TRAINING_PERCENT = "70"
# To use a validation split, add "--validation-percent", "20" to _shared_tail_args.

# Script paths (Windows: "src\\wearablepermed_ml\\trainer.py" / "src\\wearablepermed_ml\\tester.py").
TRAINER_SCRIPT = "src/wearablepermed_ml/trainer.py"
TESTER_SCRIPT = "src/wearablepermed_ml/tester.py"

_shared_case_args = ["--case-id", case_id, "--case-id-folder", case_id_folder]
_shared_tail_args = ["--training-percent", TRAINING_PERCENT, "--create-superclasses"]

# Arguments for the training script.
train_args = [TRAINER_SCRIPT, *_shared_case_args, "--ml-models", ML_MODEL, *_shared_tail_args]

# Arguments for the test script (the run loop appends "--run-index <i>").
test_args = [TESTER_SCRIPT, *_shared_case_args, "--model-id", ML_MODEL, *_shared_tail_args]

# Python interpreter of the virtual environment, relative to the working directory.
# Windows: python_exe = os.path.join(".venv", "Scripts", "python.exe")
python_exe = os.path.join(".venv", "bin", "python")
48
+
49
# Per-run metrics captured from tester.py's stdout.
accuracies = []
# recalls = []  # recall capture disabled: tester.py does not print a recall line
f1_scores = []

# tester.py prints lines such as
#   "Global accuracy score (test) = 85.31 [%]"
#   "Global F1 score (test) = 80.02 [%]"
# The parentheses must be escaped: unescaped, "(test)" is a regex capture group
# that matches the bare word "test", so the previous patterns (which also had
# "F1 (test) score" in the wrong order) never matched the real output.
ACC_TEST_PATTERN = re.compile(r"Global accuracy score \(test\)\s*=\s*([0-9.]+)")
F1_TEST_PATTERN = re.compile(r"Global F1[-\s]?score \(test\)\s*=\s*([0-9.]+)")

for i in range(1, N_RUNS + 1):  # Main loop: N_RUNS train+test repetitions
    print(f"\n=== EJECUCIÓN {i} ===")

    # --- TRAIN ---
    print(f"\n--- TRAIN (ejecución {i}) ---")
    subprocess.run([python_exe] + train_args, check=True)

    # --- TEST ---
    # The run index makes tester.py write per-run output files.
    test_args_with_i = test_args + ["--run-index", str(i)]
    print(f"\n--- TEST (ejecución {i}) ---")

    result = subprocess.run(
        [python_exe] + test_args_with_i,
        check=True,             # raise CalledProcessError on a non-zero exit
        stdout=subprocess.PIPE,  # stdout is parsed for the metrics below;
                                 # stderr is left on the console so tracebacks
                                 # from a failing tester.py stay visible
        text=True,              # decode output as text, not bytes
    )

    print(result.stdout)  # Echo tester.py's full output

    # --- Extract metrics ---
    acc_match = ACC_TEST_PATTERN.search(result.stdout)
    f1_match = F1_TEST_PATTERN.search(result.stdout)

    if acc_match:
        acc = float(acc_match.group(1))
        accuracies.append(acc)
        print(f"Accuracy capturado en la ejecución {i}: {acc} [%]")
    else:
        print("No se encontró 'Global accuracy score' en la salida de tester.py")

    if f1_match:
        f1 = float(f1_match.group(1))
        f1_scores.append(f1)
        print(f"F1-score capturado en la ejecución {i}: {f1} [%]")
    else:
        print("No se encontró 'Global F1-score' en la salida de tester.py")
98
+
99
+
100
# --- FINAL SUMMARY ---
print("\n=== RESUMEN FINAL ===")
print("Accuracies:", accuracies)
print("F1-scores:", f1_scores)

# Mean and (population) standard deviation of every metric that was captured
# at least once; metrics with no values are skipped.
for _label, _values in (("Accuracy", accuracies), ("F1", f1_scores)):
    if not _values:
        continue
    print(f"{_label} mean: {np.mean(_values):.4f} | std: {np.std(_values):.4f}")
112
+
113
+
114
# --- SAVE TO .npz ---
accuracies_test_path = os.path.join(case_id_folder, case_id, "metrics_test.npz")


def _mean_std(values):
    """Return ``(mean, std)`` of ``values``, or ``(nan, nan)`` when it is empty.

    np.mean/np.std of an empty list also give nan, but they emit
    "Mean of empty slice" RuntimeWarnings; this makes the missing-data case
    explicit. np.std is the population standard deviation (ddof=0).
    """
    if len(values) == 0:
        return np.nan, np.nan
    return np.mean(values), np.std(values)


acc_mean, acc_std = _mean_std(accuracies)
f1_mean, f1_std = _mean_std(f1_scores)

np.savez(  # Uncompressed .npz archive with one array per keyword
    accuracies_test_path,
    accuracies=np.array(accuracies),  # Per-run accuracies [%]
    f1_scores=np.array(f1_scores),    # Per-run macro F1-scores [%]
    acc_mean=acc_mean,
    acc_std=acc_std,
    f1_mean=f1_mean,
    f1_std=f1_std,
)

print(f"\nResultados guardados en {accuracies_test_path}")
@@ -0,0 +1,156 @@
1
+ from enum import Enum
2
+ import os
3
+ import sys
4
+ import argparse
5
+ import logging
6
+
7
+ from testing import testing
8
+
9
__author__ = "Miguel Salinas <uo34525@uniovi.es>, Alejandro <uo265351@uniovi.es>"
__copyright__ = "Uniovi"
__license__ = "MIT"

# Module logger; its level and format come from setup_logging(), which main()
# calls with the level selected by the -v / -vv flags.
_logger = logging.getLogger(__name__)
14
+
15
+ class ML_Model(Enum):
16
+ ESANN = 'ESANN'
17
+ CAPTURE24 = 'CAPTURE24'
18
+ RANDOM_FOREST = 'RandomForest'
19
+ XGBOOST = 'XGBoost'
20
+
21
+ def parse_args(args):
22
+ """Parse command line parameters
23
+
24
+ Args:
25
+ args (List[str]): command line parameters as list of strings
26
+ (for example ``["--help"]``).
27
+
28
+ Returns:
29
+ :obj:`argparse.Namespace`: command line parameters namespace
30
+ """
31
+ parser = argparse.ArgumentParser(description="Machine Learning Model Trainer")
32
+ parser.add_argument(
33
+ "-case-id",
34
+ "--case-id",
35
+ dest="case_id",
36
+ required=True,
37
+ help="Case unique identifier."
38
+ )
39
+ parser.add_argument(
40
+ "-case-id-folder",
41
+ "--case-id-folder",
42
+ dest="case_id_folder",
43
+ required=True,
44
+ help="Choose the case id root folder."
45
+ )
46
+ parser.add_argument(
47
+ "-model-id",
48
+ "--model-id",
49
+ dest="model_id",
50
+ required=True,
51
+ help="Choose the model id."
52
+ )
53
+ parser.add_argument(
54
+ "-create-superclasses",
55
+ "--create-superclasses",
56
+ dest="create_superclasses",
57
+ action='store_true',
58
+ help="Create activity superclasses (true/false)."
59
+ )
60
+ parser.add_argument(
61
+ "-create-superclasses-CPA-METs",
62
+ "--create-superclasses-CPA-METs",
63
+ dest="create_superclasses_CPA_METs",
64
+ action='store_true',
65
+ help="Create activity superclasses (true/false) with the CPA/METs method."
66
+ )
67
+ parser.add_argument(
68
+ '-training-percent',
69
+ '--training-percent',
70
+ dest='training_percent',
71
+ type=int,
72
+ default=70,
73
+ required=True,
74
+ help="Training percent"
75
+ )
76
+ parser.add_argument(
77
+ '-validation-percent',
78
+ '--validation-percent',
79
+ dest='validation_percent',
80
+ type=int,
81
+ default=0,
82
+ help="Validation percent"
83
+ )
84
+ parser.add_argument(
85
+ '-run-index',
86
+ '--run-index',
87
+ dest='run_index',
88
+ type=str,
89
+ default=1,
90
+ help="Run index of each iteration of the test step."
91
+ )
92
+ parser.add_argument(
93
+ "-v",
94
+ "--verbose",
95
+ dest="loglevel",
96
+ help="set loglevel to DEBUG.",
97
+ action="store_const",
98
+ const=logging.DEBUG,
99
+ )
100
+ parser.add_argument(
101
+ "-vv",
102
+ "--very-verbose",
103
+ dest="loglevel",
104
+ help="set loglevel to INFO.",
105
+ action="store_const",
106
+ const=logging.INFO,
107
+ )
108
+ return parser.parse_args(args)
109
+
110
+ def setup_logging(loglevel):
111
+ """Setup basic logging
112
+
113
+ Args:
114
+ loglevel (int): minimum loglevel for emitting messages
115
+ """
116
+ logformat = "[%(asctime)s] %(levelname)s:%(name)s:%(message)s"
117
+ logging.basicConfig(
118
+ level=loglevel, stream=sys.stdout, format=logformat, datefmt="%Y-%m-%d %H:%M:%S"
119
+ )
120
+
121
+ def main(args):
122
+ """Wrapper allowing :func:`fib` to be called with string arguments in a CLI fashion
123
+
124
+ Instead of returning the value from :func:`fib`, it prints the result to the
125
+ ``stdout`` in a nicely formatted message.
126
+
127
+ Args:
128
+ args (List[str]): command line parameters as list of strings
129
+ (for example ``["--verbose", "42"]``).
130
+ """
131
+ args = parse_args(args)
132
+ setup_logging(args.loglevel)
133
+
134
+ # create the output case id folder if not exist
135
+ case_id_folder = os.path.join(args.case_id_folder, args.case_id)
136
+ os.makedirs(case_id_folder, exist_ok=True)
137
+
138
+ _logger.info("Tester starts here")
139
+ testing.tester(case_id_folder,
140
+ args.model_id,
141
+ args.create_superclasses,
142
+ args.create_superclasses_CPA_METs,
143
+ args.training_percent,
144
+ args.validation_percent,
145
+ args.run_index)
146
+ _logger.info("Script ends here")
147
+
148
+ def run():
149
+ """Calls :func:`main` passing the CLI arguments extracted from :obj:`sys.argv`
150
+
151
+ This function can be used as entry point to create console scripts with setuptools.
152
+ """
153
+ main(sys.argv[1:])
154
+
155
+ if __name__ == "__main__":
156
+ run()
File without changes
@@ -0,0 +1,203 @@
1
+ from enum import Enum
2
+ import json
3
+ from data import DataReader
4
+ from models.model_generator import modelGenerator
5
+ from basic_functions.address import *
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ import seaborn as sns
9
+ from sklearn.metrics import accuracy_score, f1_score, recall_score, classification_report, confusion_matrix
10
+ import joblib
11
+
12
+ class ML_Model(Enum):
13
+ ESANN = 'ESANN'
14
+ CAPTURE24 = 'CAPTURE24'
15
+ RANDOM_FOREST = 'RandomForest'
16
+ XGBOOST = 'XGBoost'
17
+
18
+ def tester(case_id_folder, model_id, create_superclasses, create_superclasses_CPA_METs, training_percent, validation_percent, run_index):
19
+ # Cargar el LabelEncoder
20
+ # Ver las clases asociadas a cada número
21
+ test_label_encoder_path = os.path.join(case_id_folder, "label_encoder.pkl")
22
+ label_encoder = joblib.load(test_label_encoder_path)
23
+
24
+ print(label_encoder.classes_)
25
+
26
+ # class_names_total = ['CAMINAR CON LA COMPRA', 'CAMINAR CON MÓVIL O LIBRO', 'CAMINAR USUAL SPEED',
27
+ # 'CAMINAR ZIGZAG', 'DE PIE BARRIENDO', 'DE PIE DOBLANDO TOALLAS',
28
+ # 'DE PIE MOVIENDO LIBROS', 'DE PIE USANDO PC', 'FASE REPOSO CON K5',
29
+ # 'INCREMENTAL CICLOERGOMETRO', 'SENTADO LEYENDO', 'SENTADO USANDO PC',
30
+ # 'SENTADO VIENDO LA TV', 'SIT TO STAND 30 s', 'SUBIR Y BAJAR ESCALERAS',
31
+ # 'TAPIZ RODANTE', 'TROTAR', 'YOGA']
32
+
33
+ # class_names_total = ['CAMINAR CON LA COMPRA', 'CAMINAR CON MÓVIL O LIBRO', 'CAMINAR USUAL SPEED',
34
+ # 'CAMINAR ZIGZAG', 'DE PIE BARRIENDO', 'DE PIE DOBLANDO TOALLAS',
35
+ # 'DE PIE MOVIENDO LIBROS', 'DE PIE USANDO PC', 'FASE REPOSO CON K5',
36
+ # 'INCREMENTAL CICLOERGOMETRO', 'SENTADO LEYENDO', 'SENTADO USANDO PC',
37
+ # 'SENTADO VIENDO LA TV', 'SUBIR Y BAJAR ESCALERAS',
38
+ # 'TAPIZ RODANTE', 'TROTAR', 'YOGA']
39
+
40
+ class_names_total = label_encoder.classes_
41
+
42
+ print(len(class_names_total))
43
+
44
+ # Obtener el mapeo de cada etiqueta a su número asignado
45
+ mapeo = dict(zip(label_encoder.classes_, range(len(label_encoder.classes_))))
46
+ print("Mapeo de etiquetas:", mapeo)
47
+
48
+ # Lectura de hiperparámetros óptimos de cada modelo, previamente buscados
49
+ # ------------------------------------------------------------------------------------
50
+ if (model_id == ML_Model.ESANN.value):
51
+ test_dataset_path = os.path.join(case_id_folder, "data_all.npz")
52
+ # Ruta al archivo de hiperparámetros guardados
53
+ hp_json_path = os.path.join(case_id_folder, "mejores_hiperparametros_ESANN.json")
54
+ # Cargar hiperparámetros desde el archivo JSON
55
+ with open(hp_json_path, "r") as f:
56
+ best_hp_values = json.load(f) # Diccionario: {param: valor}
57
+
58
+ elif (model_id == ML_Model.CAPTURE24.value):
59
+ test_dataset_path = os.path.join(case_id_folder, "data_all.npz")
60
+ # Ruta al archivo de hiperparámetros guardados
61
+ hp_json_path = os.path.join(case_id_folder, "mejores_hiperparametros_CAPTURE24.json")
62
+ # Cargar hiperparámetros desde el archivo JSON
63
+ with open(hp_json_path, "r") as f:
64
+ best_hp_values = json.load(f) # Diccionario: {param: valor}
65
+
66
+ elif (model_id == ML_Model.RANDOM_FOREST.value):
67
+ test_dataset_path = os.path.join(case_id_folder, "data_feature_all.npz")
68
+ # Ruta al archivo de hiperparámetros guardados
69
+ hp_json_path = os.path.join(case_id_folder, "mejores_hiperparametros_BRF.json")
70
+ # Cargar hiperparámetros desde el archivo JSON
71
+ with open(hp_json_path, "r") as f:
72
+ best_hp_values = json.load(f) # Diccionario: {param: valor}
73
+
74
+ elif (model_id == ML_Model.XGBOOST.value):
75
+ test_dataset_path = os.path.join(case_id_folder, "data_feature_all.npz")
76
+ # Ruta al archivo de hiperparámetros guardados
77
+ hp_json_path = os.path.join(case_id_folder, "mejores_hiperparametros_XGB.json")
78
+ # Cargar hiperparámetros desde el archivo JSON
79
+ with open(hp_json_path, "r") as f:
80
+ best_hp_values = json.load(f) # Diccionario: {param: valor}
81
+
82
+ # Testeamos el rendimiento del modelo de clasificación con los DATOS TOTALES
83
+ data = DataReader(modelID=model_id,
84
+ create_superclasses=create_superclasses,
85
+ create_superclasses_CPA_METs = create_superclasses_CPA_METs,
86
+ p_train = training_percent,
87
+ p_validation=validation_percent,
88
+ file_path=test_dataset_path,
89
+ label_encoder_path=test_label_encoder_path)
90
+
91
+ # Construir modelo usando modelGenerator y los hiperparámetros
92
+ model = modelGenerator(
93
+ modelID=model_id,
94
+ data=data,
95
+ params=best_hp_values, # Pasamos directamente el diccionario de hiperparámetros óptimos
96
+ debug=False
97
+ )
98
+
99
+ model.load(model_id, case_id_folder)
100
+
101
+ # print train/test sizes
102
+ print(model.X_test.shape)
103
+ print(model.X_train.shape)
104
+
105
+ # testing the model
106
+ y_predicted_train = model.predict(model.X_train)
107
+ y_predicted_validation = model.predict(model.X_validation)
108
+ y_predicted_test = model.predict(model.X_test)
109
+
110
+ # get the class with the highest probability
111
+ if (model_id == ML_Model.ESANN.value or model_id == ML_Model.CAPTURE24.value):
112
+ y_final_prediction_train = np.argmax(y_predicted_train, axis=1)
113
+ y_final_prediction_validation = np.argmax(y_predicted_validation, axis=1)
114
+ y_final_prediction_test = np.argmax(y_predicted_test, axis=1) # Trabajamos con clasificación multicategoría, no necesario para los bosques aleatorios
115
+
116
+ else:
117
+ y_final_prediction_train = y_predicted_train
118
+ y_final_prediction_validation = y_predicted_validation
119
+ y_final_prediction_test = y_predicted_test # esta línea solo es necesaria para los bosques aleatorios y XGBoost
120
+
121
+
122
+ print(model.y_test)
123
+ print(model.y_test.shape)
124
+
125
+ print(y_predicted_test)
126
+ print(y_predicted_test.shape)
127
+
128
+ # Matriz de confusión
129
+ # Obtener todas las clases posibles desde 0 hasta N-1
130
+ num_classes = len(class_names_total) # Asegurar que contiene todas las clases esperadas
131
+ all_classes = np.arange(num_classes) # Crear array con todas las clases (0, 1, 2, ..., N-1)
132
+
133
+ # Crear la matriz de confusión asegurando que todas las clases están representadas
134
+ cm = confusion_matrix(model.y_test, y_final_prediction_test, labels=all_classes)
135
+
136
+ # Graficar la matriz de confusión
137
+ confusion_matrix_test_path = os.path.join(case_id_folder, "confusion_matrix_test_"+run_index+".png")
138
+
139
+ plt.figure(figsize=(10,7))
140
+ sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names_total, yticklabels=class_names_total)
141
+ plt.xlabel('Predicted label')
142
+ plt.ylabel('True label')
143
+ plt.title('Confusion Matrix Test')
144
+ plt.savefig(confusion_matrix_test_path, bbox_inches='tight')
145
+
146
+ # -------------------------------------------------
147
+ # MÉTRICAS DE TEST GLOBALES
148
+ print("-------------------------------------------------\n")
149
+
150
+ # Accuracy
151
+ acc_score_train = accuracy_score(model.y_train, y_final_prediction_train)
152
+ print("Global accuracy score (train) = "+str(round(acc_score_train*100,2))+" [%]")
153
+
154
+ acc_score_validation = accuracy_score(model.y_validation, y_final_prediction_validation)
155
+ print("Global accuracy score (validation) = "+str(round(acc_score_validation*100,2))+" [%]")
156
+
157
+ acc_score_test = accuracy_score(model.y_test, y_final_prediction_test)
158
+ print("Global accuracy score (test) = "+str(round(acc_score_test*100,2))+" [%]")
159
+
160
+ # F1 Score
161
+ F1_score_train = f1_score(model.y_train, y_final_prediction_train, average='macro') # revisar las opciones de average
162
+ print("Global F1 score (train) = "+str(round(F1_score_train*100,2))+" [%]")
163
+
164
+ F1_score_validation = f1_score(model.y_validation, y_final_prediction_validation, average='macro') # revisar las opciones de average
165
+ print("Global F1 score (validation) = "+str(round(F1_score_validation*100,2))+" [%]")
166
+
167
+ F1_score_test = f1_score(model.y_test, y_final_prediction_test, average='macro') # revisar las opciones de average
168
+ print("Global F1 score (test) = "+str(round(F1_score_test*100,2))+" [%]")
169
+
170
+ # Recall global
171
+ # recall_score_global = recall_score(model.y_test, y_final_predicton, average='macro')
172
+ # print("Global recall score = "+str(round(recall_score_global*100,2))+" [%]")
173
+
174
+ # Save to a file
175
+ clasification_global_report_path = os.path.join(case_id_folder, "clasification_global_report_"+run_index+".txt")
176
+ with open(clasification_global_report_path, "w") as f:
177
+ f.write(f"Global F1 Score (train): {F1_score_train:.4f}\n")
178
+ f.write(f"Global accuracy score (train): {acc_score_train:.4f}\n")
179
+ f.write(f"Global F1 Score (validation): {F1_score_validation:.4f}\n")
180
+ f.write(f"Global accuracy score (validation): {acc_score_validation:.4f}\n")
181
+ f.write(f"Global F1 Score (test): {F1_score_test:.4f}\n")
182
+ f.write(f"Global accuracy score (test): {acc_score_test:.4f}\n")
183
+ # f.write(f"Global recall score: {recall_score_global:.4f}\n")
184
+
185
+ # -------------------------------------------------
186
+ # Obtener todas las clases posibles desde 0 hasta N-1
187
+ num_classes = len(class_names_total) # Número total de clases
188
+ all_classes = np.arange(num_classes) # Crea un array con todas las clases (0, 1, 2, ..., N-1)
189
+
190
+ # Tabla de métricas para cada clase
191
+ classification_per_class_report = classification_report(
192
+ model.y_test,
193
+ y_final_prediction_test,
194
+ labels=all_classes,
195
+ target_names=class_names_total,
196
+ zero_division=0
197
+ )
198
+ print(classification_per_class_report)
199
+
200
+ # Save per-class report to a file
201
+ clasification_per_class_report_path = os.path.join(case_id_folder, "clasification_per_class_report_"+run_index+".txt")
202
+ with open(clasification_per_class_report_path, "w") as f:
203
+ f.write(classification_per_class_report)