uniovi-simur-wearablepermed-ml 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of uniovi-simur-wearablepermed-ml might be problematic. Click here for more details.
- uniovi_simur_wearablepermed_ml-1.1.0.dist-info/METADATA +411 -0
- uniovi_simur_wearablepermed_ml-1.1.0.dist-info/RECORD +19 -0
- uniovi_simur_wearablepermed_ml-1.1.0.dist-info/WHEEL +5 -0
- uniovi_simur_wearablepermed_ml-1.1.0.dist-info/entry_points.txt +3 -0
- uniovi_simur_wearablepermed_ml-1.1.0.dist-info/licenses/LICENSE.txt +21 -0
- uniovi_simur_wearablepermed_ml-1.1.0.dist-info/top_level.txt +1 -0
- wearablepermed_ml/__init__.py +16 -0
- wearablepermed_ml/basic_functions/__init__.py +0 -0
- wearablepermed_ml/basic_functions/address.py +17 -0
- wearablepermed_ml/data/DataReader.py +388 -0
- wearablepermed_ml/data/__init__.py +1 -0
- wearablepermed_ml/models/SiMuR_Model.py +671 -0
- wearablepermed_ml/models/__init__.py +1 -0
- wearablepermed_ml/models/model_generator.py +63 -0
- wearablepermed_ml/run_trainer_and_tester_30_times.py +130 -0
- wearablepermed_ml/tester.py +156 -0
- wearablepermed_ml/testing/__init__.py +0 -0
- wearablepermed_ml/testing/testing.py +203 -0
- wearablepermed_ml/trainer.py +782 -0
|
@@ -0,0 +1,782 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import argparse
|
|
4
|
+
import logging
|
|
5
|
+
from enum import Enum
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from data import DataReader
|
|
9
|
+
from models.model_generator import modelGenerator
|
|
10
|
+
from basic_functions.address import *
|
|
11
|
+
|
|
12
|
+
import tensorflow as tf
|
|
13
|
+
|
|
14
|
+
# import keras_tuner
|
|
15
|
+
import json
|
|
16
|
+
|
|
17
|
+
from ray import tune
|
|
18
|
+
from ray.tune.schedulers import ASHAScheduler
|
|
19
|
+
from ray.air import RunConfig, CheckpointConfig
|
|
20
|
+
from ray.air.config import FailureConfig
|
|
21
|
+
from ray.air import session
|
|
22
|
+
|
|
23
|
+
from ray.tune.tuner import TuneConfig
|
|
24
|
+
|
|
25
|
+
from models import SiMuRModel_ESANN, SiMuRModel_CAPTURE24, SiMuRModel_RandomForest, SiMuRModel_XGBoost
|
|
26
|
+
from sklearn.metrics import accuracy_score
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Configuration of GPU
|
|
30
|
+
gpus = tf.config.list_physical_devices('GPU')
|
|
31
|
+
if gpus:
|
|
32
|
+
try:
|
|
33
|
+
for gpu in gpus:
|
|
34
|
+
tf.config.experimental.set_memory_growth(gpu, True)
|
|
35
|
+
print(f"{len(gpus)} GPU(s) detected and VRAM set to crossover mode..")
|
|
36
|
+
except RuntimeError as e:
|
|
37
|
+
print(f"GPU configuration error : {e}")
|
|
38
|
+
else:
|
|
39
|
+
print("⚠️ I also discovered the GPU. Training takes place on the CPU.")
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
__author__ = "Miguel Salinas <uo34525@uniovi.es>, Alejandro <uo265351@uniovi.es>"
|
|
43
|
+
__copyright__ = "Uniovi"
|
|
44
|
+
__license__ = "MIT"
|
|
45
|
+
|
|
46
|
+
_logger = logging.getLogger(__name__)
|
|
47
|
+
|
|
48
|
+
CONVOLUTIONAL_DATASET_FILE = "data_all.npz"
|
|
49
|
+
FEATURE_DATASET_FILE = "data_feature_all.npz"
|
|
50
|
+
LABEL_ENCODER_FILE = "label_encoder.pkl"
|
|
51
|
+
CONFIG_FILE = "config.cfg"
|
|
52
|
+
|
|
53
|
+
class ML_Model(Enum):
|
|
54
|
+
ESANN = 'ESANN'
|
|
55
|
+
CAPTURE24 = 'CAPTURE24'
|
|
56
|
+
RANDOM_FOREST = 'RandomForest'
|
|
57
|
+
XGBOOST = 'XGBoost'
|
|
58
|
+
|
|
59
|
+
class ML_Sensor(Enum):
|
|
60
|
+
PI = 'thigh'
|
|
61
|
+
M = 'wrist'
|
|
62
|
+
C = 'hip'
|
|
63
|
+
|
|
64
|
+
def parse_ml_model(value):
|
|
65
|
+
try:
|
|
66
|
+
"""Parse a comma-separated list of CML Models lor values into a list of ML_Sensor enums."""
|
|
67
|
+
values = [v.strip() for v in value.split(',') if v.strip()]
|
|
68
|
+
result = []
|
|
69
|
+
invalid = []
|
|
70
|
+
for v in values:
|
|
71
|
+
try:
|
|
72
|
+
result.append(ML_Model(v))
|
|
73
|
+
except ValueError:
|
|
74
|
+
invalid.append(v)
|
|
75
|
+
if invalid:
|
|
76
|
+
valid = ', '.join(c.value for c in ML_Model)
|
|
77
|
+
raise argparse.ArgumentTypeError(
|
|
78
|
+
f"Invalid color(s): {', '.join(invalid)}. "
|
|
79
|
+
f"Choose from: {valid}"
|
|
80
|
+
)
|
|
81
|
+
return result
|
|
82
|
+
except ValueError:
|
|
83
|
+
valid = ', '.join(ml_model.value for ml_model in ML_Model)
|
|
84
|
+
raise argparse.ArgumentTypeError(f"Invalid ML Model '{value}'. Choose from: {valid}")
|
|
85
|
+
|
|
86
|
+
def parse_args(args):
|
|
87
|
+
"""Parse command line parameters
|
|
88
|
+
|
|
89
|
+
Args:
|
|
90
|
+
args (List[str]): command line parameters as list of strings
|
|
91
|
+
(for example ``["--help"]``).
|
|
92
|
+
|
|
93
|
+
Returns:
|
|
94
|
+
:obj:`argparse.Namespace`: command line parameters namespace
|
|
95
|
+
"""
|
|
96
|
+
parser = argparse.ArgumentParser(description="Machine Learning Model Trainer")
|
|
97
|
+
parser.add_argument(
|
|
98
|
+
"-case-id",
|
|
99
|
+
"--case-id",
|
|
100
|
+
dest="case_id",
|
|
101
|
+
required=True,
|
|
102
|
+
help="Case unique identifier."
|
|
103
|
+
)
|
|
104
|
+
parser.add_argument(
|
|
105
|
+
"-case-id-folder",
|
|
106
|
+
"--case-id-folder",
|
|
107
|
+
dest="case_id_folder",
|
|
108
|
+
required=True,
|
|
109
|
+
help="Choose the case id root folder."
|
|
110
|
+
)
|
|
111
|
+
parser.add_argument(
|
|
112
|
+
"-ml-models",
|
|
113
|
+
"--ml-models",
|
|
114
|
+
type=parse_ml_model,
|
|
115
|
+
nargs='+',
|
|
116
|
+
dest="ml_models",
|
|
117
|
+
required=True,
|
|
118
|
+
help=f"Available ML models: {[c.value for c in ML_Model]}."
|
|
119
|
+
)
|
|
120
|
+
parser.add_argument(
|
|
121
|
+
"-create-superclasses",
|
|
122
|
+
"--create-superclasses",
|
|
123
|
+
dest="create_superclasses",
|
|
124
|
+
action='store_true',
|
|
125
|
+
help="Create activity superclasses (true/false)."
|
|
126
|
+
)
|
|
127
|
+
parser.add_argument(
|
|
128
|
+
"-create-superclasses-CPA-METs",
|
|
129
|
+
"--create-superclasses-CPA-METs",
|
|
130
|
+
dest="create_superclasses_CPA_METs",
|
|
131
|
+
action='store_true',
|
|
132
|
+
help="Create activity superclasses (true/false) with the CPA/METs method."
|
|
133
|
+
)
|
|
134
|
+
parser.add_argument(
|
|
135
|
+
'-training-percent',
|
|
136
|
+
'--training-percent',
|
|
137
|
+
dest='training_percent',
|
|
138
|
+
type=int,
|
|
139
|
+
default=70,
|
|
140
|
+
required=True,
|
|
141
|
+
help="Training percent"
|
|
142
|
+
)
|
|
143
|
+
parser.add_argument(
|
|
144
|
+
'-validation-percent',
|
|
145
|
+
'--validation-percent',
|
|
146
|
+
dest='validation_percent',
|
|
147
|
+
type=int,
|
|
148
|
+
default=0,
|
|
149
|
+
help="Validation percent"
|
|
150
|
+
)
|
|
151
|
+
parser.add_argument(
|
|
152
|
+
'-add-sintetic-data',
|
|
153
|
+
'--add-sintetic-data',
|
|
154
|
+
dest='add_sintetic_data',
|
|
155
|
+
type=bool,
|
|
156
|
+
default=False,
|
|
157
|
+
help="Add sintetic data for training"
|
|
158
|
+
)
|
|
159
|
+
parser.add_argument(
|
|
160
|
+
"-v",
|
|
161
|
+
"--verbose",
|
|
162
|
+
dest="loglevel",
|
|
163
|
+
help="set loglevel to INFO.",
|
|
164
|
+
action="store_const",
|
|
165
|
+
const=logging.INFO,
|
|
166
|
+
)
|
|
167
|
+
parser.add_argument(
|
|
168
|
+
"-vv",
|
|
169
|
+
"--very-verbose",
|
|
170
|
+
dest="loglevel",
|
|
171
|
+
help="set loglevel to DEBUG.",
|
|
172
|
+
action="store_const",
|
|
173
|
+
const=logging.DEBUG,
|
|
174
|
+
)
|
|
175
|
+
return parser.parse_args(args)
|
|
176
|
+
|
|
177
|
+
def setup_logging(loglevel):
|
|
178
|
+
"""Setup basic logging
|
|
179
|
+
|
|
180
|
+
Args:
|
|
181
|
+
loglevel (int): minimum loglevel for emitting messages
|
|
182
|
+
"""
|
|
183
|
+
logformat = "[%(asctime)s] %(levelname)s:%(name)s:%(message)s"
|
|
184
|
+
logging.basicConfig(
|
|
185
|
+
level=loglevel, stream=sys.stdout, format=logformat, datefmt="%Y-%m-%d %H:%M:%S"
|
|
186
|
+
)
|
|
187
|
+
|
|
188
|
+
def convolution_model_selected(models):
|
|
189
|
+
for model in models:
|
|
190
|
+
if model.value in [ML_Model.CAPTURE24.value, ML_Model.ESANN.value]:
|
|
191
|
+
return True
|
|
192
|
+
|
|
193
|
+
return False
|
|
194
|
+
|
|
195
|
+
def feature_model_selected(models):
|
|
196
|
+
for model in models:
|
|
197
|
+
if model.value in [ML_Model.RANDOM_FOREST.value, ML_Model.XGBOOST.value]:
|
|
198
|
+
return True
|
|
199
|
+
|
|
200
|
+
return False
|
|
201
|
+
|
|
202
|
+
# ------------------------------------------------------------------------
|
|
203
|
+
# if searching optimal hyperparameter:
|
|
204
|
+
def train_cnn_ray_tune(config, model_class, data):
|
|
205
|
+
params = {
|
|
206
|
+
"N_capas": config["N_capas"],
|
|
207
|
+
"optimizador": config["optimizador"],
|
|
208
|
+
"funcion_activacion": config["funcion_activacion"],
|
|
209
|
+
"tamanho_minilote": config["tamanho_minilote"],
|
|
210
|
+
"numero_filtros": config["numero_filtros"],
|
|
211
|
+
"tamanho_filtro": config["tamanho_filtro"],
|
|
212
|
+
"tasa_aprendizaje": config["tasa_aprendizaje"],
|
|
213
|
+
"epochs": config["epochs"]
|
|
214
|
+
}
|
|
215
|
+
model = model_class(data, params)
|
|
216
|
+
model.train(config["epochs"])
|
|
217
|
+
y_pred = model.predict(data.X_validation)
|
|
218
|
+
# Si devuelve probabilidades, convierte a clases
|
|
219
|
+
if y_pred.ndim > 1 and y_pred.shape[1] > 1:
|
|
220
|
+
y_pred = y_pred.argmax(axis=1)
|
|
221
|
+
val_acc = accuracy_score(data.y_validation, y_pred)
|
|
222
|
+
session.report({"val_accuracy": float(val_acc)})
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def train_brf_ray_tune(config, model_class, data):
|
|
226
|
+
params = { # Extraer hiperparámetros desde config
|
|
227
|
+
"n_estimators": config["n_estimators"], # Número de árboles en el bosque
|
|
228
|
+
"max_depth": config["max_depth"], # Profundidad máxima de los árboles
|
|
229
|
+
"min_samples_split": config["min_samples_split"], # Muestras mínimas para dividir un nodo
|
|
230
|
+
"min_samples_leaf": config["min_samples_leaf"], # Muestras mínimas por hoja
|
|
231
|
+
"max_features": config["max_features"], # Número de características consideradas por división
|
|
232
|
+
}
|
|
233
|
+
model = model_class(data, params) # Instanciar el modelo usando model_class
|
|
234
|
+
model.train() # Entrenar el modelo
|
|
235
|
+
y_pred = model.predict(data.X_test) # Predecir sobre conjunto de test
|
|
236
|
+
if y_pred.ndim > 1 and y_pred.shape[1] > 1: # Si devuelve probabilidades, convertir a clases
|
|
237
|
+
y_pred = y_pred.argmax(axis=1)
|
|
238
|
+
test_acc = accuracy_score(data.y_test, y_pred) # Calcular precisión
|
|
239
|
+
session.report({"test_accuracy": float(test_acc)}) # Reportar precisión a Ray Tune
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def train_xgb_ray_tune(config, model_class, data):
|
|
243
|
+
params = { # Extraer hiperparámetros desde config
|
|
244
|
+
"num_boost_round": config["num_boost_round"], # Número de árboles (rondas) de boosting
|
|
245
|
+
"max_depth": config["max_depth"], # Profundidad máxima
|
|
246
|
+
"learning_rate": config["learning_rate"], # Tasa de aprendizaje
|
|
247
|
+
"subsample": config["subsample"], # Fracción de muestras por árbol
|
|
248
|
+
"colsample_bytree": config["colsample_bytree"], # Fracción de columnas por árbol
|
|
249
|
+
"gamma": config["gamma"], # Regularización mínima de pérdida
|
|
250
|
+
"min_child_weight": config["min_child_weight"], # Peso mínimo de hijos
|
|
251
|
+
"reg_alpha": config["reg_alpha"], # L1 regularization
|
|
252
|
+
"reg_lambda": config["reg_lambda"] # L2 regularization
|
|
253
|
+
}
|
|
254
|
+
model = model_class(data, params) # Instanciar el modelo usando model_class (ej. SiMuRModel_XGBoost)
|
|
255
|
+
model.train() # Entrenar el modelo
|
|
256
|
+
y_pred = model.predict(data.X_validation) # Predecir sobre conjunto de validación
|
|
257
|
+
if y_pred.ndim > 1 and y_pred.shape[1] > 1: # Si devuelve probabilidades, convertir a clases
|
|
258
|
+
y_pred = y_pred.argmax(axis=1)
|
|
259
|
+
validation_acc = accuracy_score(data.y_validation, y_pred) # Calcular precisión
|
|
260
|
+
session.report({"validation_accuracy": float(validation_acc)}) # Reportar precisión a Ray Tune
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def main(args):
|
|
264
|
+
"""Wrapper allowing :func:`fib` to be called with string arguments in a CLI fashion
|
|
265
|
+
|
|
266
|
+
Instead of returning the value from :func:`fib`, it prints the result to the
|
|
267
|
+
``stdout`` in a nicely formatted message.
|
|
268
|
+
|
|
269
|
+
Args:
|
|
270
|
+
args (List[str]): command line parameters as list of strings
|
|
271
|
+
(for example ``["--verbose", "42"]``).
|
|
272
|
+
"""
|
|
273
|
+
args = parse_args(args)
|
|
274
|
+
setup_logging(args.loglevel)
|
|
275
|
+
|
|
276
|
+
_logger.info("Trainer starts here")
|
|
277
|
+
|
|
278
|
+
# create the output case id folder if not exist
|
|
279
|
+
case_id_folder = os.path.join(args.case_id_folder, args.case_id)
|
|
280
|
+
os.makedirs(case_id_folder, exist_ok=True)
|
|
281
|
+
|
|
282
|
+
for ml_model in args.ml_models[0]:
|
|
283
|
+
modelID = ml_model.value
|
|
284
|
+
|
|
285
|
+
# **********
|
|
286
|
+
# Modelo A *
|
|
287
|
+
# **********
|
|
288
|
+
if modelID == ML_Model.ESANN.value:
|
|
289
|
+
dataset_file = os.path.join(case_id_folder, CONVOLUTIONAL_DATASET_FILE)
|
|
290
|
+
label_encoder_file = os.path.join(case_id_folder, LABEL_ENCODER_FILE)
|
|
291
|
+
config_file = os.path.join(case_id_folder, CONFIG_FILE)
|
|
292
|
+
|
|
293
|
+
data_tot = DataReader(modelID=modelID,
|
|
294
|
+
create_superclasses=args.create_superclasses,
|
|
295
|
+
create_superclasses_CPA_METs = args.create_superclasses_CPA_METs,
|
|
296
|
+
p_train = args.training_percent,
|
|
297
|
+
p_validation = args.validation_percent,
|
|
298
|
+
file_path=dataset_file,
|
|
299
|
+
label_encoder_path=label_encoder_file,
|
|
300
|
+
config_path = config_file)
|
|
301
|
+
|
|
302
|
+
# Se entrenan y salvan los modelos (fichero .h5).
|
|
303
|
+
# Ruta al archivo de hiperparámetros guardados
|
|
304
|
+
hp_json_path = os.path.join(case_id_folder, "mejores_hiperparametros_ESANN.json")
|
|
305
|
+
# Verifica que el archivo existe
|
|
306
|
+
if os.path.isfile(hp_json_path):
|
|
307
|
+
# Cargar hiperparámetros desde el archivo JSON
|
|
308
|
+
with open(hp_json_path, "r") as f:
|
|
309
|
+
best_hp_values = json.load(f) # Diccionario: {param: valor}
|
|
310
|
+
# Construir modelo usando modelGenerator y los hiperparámetros
|
|
311
|
+
model_ESANN_data_tot = modelGenerator(
|
|
312
|
+
modelID=modelID,
|
|
313
|
+
data=data_tot,
|
|
314
|
+
params=best_hp_values, # Pasamos directamente el diccionario
|
|
315
|
+
debug=False
|
|
316
|
+
)
|
|
317
|
+
# Entrenar el modelo con todos los datos
|
|
318
|
+
model_ESANN_data_tot.train(best_hp_values['epochs'])
|
|
319
|
+
# Guardar los pesos del modelo en formato .weights.h5
|
|
320
|
+
model_ESANN_data_tot.store(modelID, case_id_folder)
|
|
321
|
+
else:
|
|
322
|
+
print(f"Se lanza la búsqueda de hiperparámetros óptimos del modelo")
|
|
323
|
+
# -----------------------------------------------------------------------------------------------
|
|
324
|
+
# Búsqueda de hiperparámetros óptimos del modelo, implementando el algoritmo ASHA según Ray Tune.
|
|
325
|
+
# -----------------------------------------------------------------------------------------------
|
|
326
|
+
# Espacio de búsqueda
|
|
327
|
+
search_space = {
|
|
328
|
+
"N_capas": tune.randint(2, 8), # Número de capas entre 2 y 7 (el límite superior es exclusivo)
|
|
329
|
+
"optimizador": tune.choice(["adam", "rmsprop", "SGD"]), # Algoritmo de optimización a usar
|
|
330
|
+
"funcion_activacion": tune.choice(["relu", "tanh", "sigmoid"]), # Función de activación en las capas
|
|
331
|
+
"tamanho_minilote": tune.choice([10, 17, 24, 31]), # Tamaño del minibatch (batch size)
|
|
332
|
+
"numero_filtros": tune.choice([12, 16, 20, 24, 28, 30]), # Cantidad de filtros para capas convolucionales
|
|
333
|
+
"tamanho_filtro": tune.choice([3, 5, 7, 9, 11, 13, 15]), # Tamaño del kernel (filtro) en capas convolucionales
|
|
334
|
+
"tasa_aprendizaje": tune.loguniform(1e-4, 1e-1), # Tasa de aprendizaje entre 0.0001 y 0.1 (escala logarítmica)
|
|
335
|
+
"epochs": tune.randint(5, 51) # Número de épocas de entrenamiento entre 5 y 50
|
|
336
|
+
}
|
|
337
|
+
|
|
338
|
+
# Configuración del scheduler
|
|
339
|
+
scheduler = ASHAScheduler(
|
|
340
|
+
metric="val_accuracy", # Métrica a optimizar: precisión en el conjunto de validación
|
|
341
|
+
mode="max", # Se busca maximizar la métrica especificada
|
|
342
|
+
max_t=10, # Número máximo de iteraciones (por ejemplo, épocas) por prueba
|
|
343
|
+
grace_period=1, # Número mínimo de iteraciones antes de detener una prueba prematuramente
|
|
344
|
+
reduction_factor=2 # Factor por el cual se reduce el número de pruebas en cada ronda de selección
|
|
345
|
+
)
|
|
346
|
+
|
|
347
|
+
# Envolver función con parámetros adicionales
|
|
348
|
+
wrapped_train_fn = tune.with_parameters(
|
|
349
|
+
train_cnn_ray_tune, # Función de entrenamiento base que se usará en la búsqueda
|
|
350
|
+
model_class=SiMuRModel_ESANN, # Clase del modelo a usar
|
|
351
|
+
data=data_tot # Conjunto de datos completo que se pasará a cada ejecución de entrenamiento
|
|
352
|
+
)
|
|
353
|
+
|
|
354
|
+
# Crear el tuner
|
|
355
|
+
tuner = tune.Tuner(
|
|
356
|
+
wrapped_train_fn, # Función de entrenamiento envuelta con parámetros fijos
|
|
357
|
+
param_space=search_space, # Espacio de búsqueda de hiperparámetros definido antes
|
|
358
|
+
tune_config=TuneConfig(
|
|
359
|
+
scheduler=scheduler, # Scheduler para manejar la parada temprana (ASHAScheduler)
|
|
360
|
+
num_samples=20, # Número de configuraciones (experimentos) a probar
|
|
361
|
+
trial_name_creator=lambda trial: f"trial_{trial.trial_id[:5]}", # Nombre personalizado para cada prueba
|
|
362
|
+
trial_dirname_creator=lambda trial: f"dir_{trial.trial_id[:5]}" # Carpeta personalizada para cada prueba
|
|
363
|
+
),
|
|
364
|
+
run_config=RunConfig(
|
|
365
|
+
name="ESANN_hyperparameters_tuning", # Nombre general del experimento
|
|
366
|
+
storage_path=case_id_folder, # Ruta donde se guardan los resultados y checkpoints
|
|
367
|
+
checkpoint_config=CheckpointConfig(num_to_keep=1), # Guardar solo el último checkpoint por prueba
|
|
368
|
+
failure_config=FailureConfig(fail_fast=False, max_failures=10), # Permite hasta 10 fallos antes de parar
|
|
369
|
+
verbose=2, # Nivel de detalle en los logs (más detallado)
|
|
370
|
+
log_to_file=False # No guardar logs en archivos (evita problemas con rutas largas)
|
|
371
|
+
)
|
|
372
|
+
)
|
|
373
|
+
|
|
374
|
+
# Ejecutar búsqueda de hiperparámetros
|
|
375
|
+
results = tuner.fit()
|
|
376
|
+
|
|
377
|
+
# Obtener mejor resultado
|
|
378
|
+
best_result = results.get_best_result(metric="val_accuracy", mode="max")
|
|
379
|
+
print("Mejores hiperparámetros:", best_result.config)
|
|
380
|
+
|
|
381
|
+
# Obtener la configuración óptima como diccionario
|
|
382
|
+
mejores_hiperparametros = best_result.config
|
|
383
|
+
|
|
384
|
+
# Guardar en un archivo JSON
|
|
385
|
+
with open(os.path.join(case_id_folder,"mejores_hiperparametros_ESANN.json"), "w") as f:
|
|
386
|
+
json.dump(mejores_hiperparametros, f, indent=4)
|
|
387
|
+
|
|
388
|
+
# Obtener los resultados como DataFrame
|
|
389
|
+
df = results.get_dataframe()
|
|
390
|
+
df.to_json(os.path.join(case_id_folder, "resultados_busqueda_ray_tune_ESANN.json"), orient="records", lines=True)
|
|
391
|
+
|
|
392
|
+
# Construir modelo usando modelGenerator y los mejores hiperparámetros
|
|
393
|
+
model_ESANN_data_tot = modelGenerator(
|
|
394
|
+
modelID=modelID,
|
|
395
|
+
data=data_tot,
|
|
396
|
+
params=mejores_hiperparametros, # Pasamos directamente el diccionario
|
|
397
|
+
debug=False
|
|
398
|
+
)
|
|
399
|
+
# Entrenar el modelo con todos los datos
|
|
400
|
+
model_ESANN_data_tot.train(mejores_hiperparametros['epochs'])
|
|
401
|
+
# Guardar los pesos del modelo en formato .weights.h5
|
|
402
|
+
model_ESANN_data_tot.store(modelID, case_id_folder)
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
# **********
|
|
406
|
+
# Modelo B *
|
|
407
|
+
# **********
|
|
408
|
+
elif modelID == ML_Model.CAPTURE24.value:
|
|
409
|
+
dataset_file = os.path.join(case_id_folder, CONVOLUTIONAL_DATASET_FILE)
|
|
410
|
+
label_encoder_file = os.path.join(case_id_folder, LABEL_ENCODER_FILE)
|
|
411
|
+
config_file = os.path.join(case_id_folder, CONFIG_FILE)
|
|
412
|
+
|
|
413
|
+
data_tot = DataReader(modelID=modelID,
|
|
414
|
+
create_superclasses=args.create_superclasses,
|
|
415
|
+
create_superclasses_CPA_METs=args.create_superclasses_CPA_METs,
|
|
416
|
+
p_train = args.training_percent,
|
|
417
|
+
p_validation = args.validation_percent,
|
|
418
|
+
file_path=dataset_file,
|
|
419
|
+
label_encoder_path=label_encoder_file,
|
|
420
|
+
config_path = config_file)
|
|
421
|
+
|
|
422
|
+
# Se entrenan y salvan los modelos (fichero .h5).
|
|
423
|
+
# Ruta al archivo de hiperparámetros guardados
|
|
424
|
+
hp_json_path = os.path.join(case_id_folder, "mejores_hiperparametros_CAPTURE24.json")
|
|
425
|
+
# Verifica que el archivo existe
|
|
426
|
+
if os.path.isfile(hp_json_path):
|
|
427
|
+
# Cargar hiperparámetros desde el archivo JSON
|
|
428
|
+
with open(hp_json_path, "r") as f:
|
|
429
|
+
best_hp_values = json.load(f) # Diccionario: {param: valor}
|
|
430
|
+
# Construir modelo usando modelGenerator y los hiperparámetros
|
|
431
|
+
model_CAPTURE24_data_tot = modelGenerator(
|
|
432
|
+
modelID=modelID,
|
|
433
|
+
data=data_tot,
|
|
434
|
+
params=best_hp_values, # Pasamos directamente el diccionario
|
|
435
|
+
debug=False
|
|
436
|
+
)
|
|
437
|
+
# Entrenar el modelo con todos los datos
|
|
438
|
+
model_CAPTURE24_data_tot.train(best_hp_values['epochs'])
|
|
439
|
+
# Guardar los pesos del modelo en formato .weights.h5
|
|
440
|
+
model_CAPTURE24_data_tot.store(modelID, case_id_folder)
|
|
441
|
+
else:
|
|
442
|
+
print(f"Se lanza la búsqueda de hiperparámetros óptimos del modelo")
|
|
443
|
+
# -----------------------------------------------------------------------------------------------
|
|
444
|
+
# Búsqueda de hiperparámetros óptimos del modelo, implementando el algoritmo ASHA según Ray Tune.
|
|
445
|
+
# -----------------------------------------------------------------------------------------------
|
|
446
|
+
# Espacio de búsqueda
|
|
447
|
+
# search_space = {
|
|
448
|
+
# "N_capas": tune.randint(2, 8), # Número de capas entre 2 y 7 (el límite superior es exclusivo)
|
|
449
|
+
# "optimizador": tune.choice(["adam", "rmsprop", "sgd"]), # Algoritmo de optimización a usar
|
|
450
|
+
# "funcion_activacion": tune.choice(["relu", "tanh", "sigmoid"]), # Función de activación en las capas
|
|
451
|
+
# "tamanho_minilote": tune.choice([4, 8, 16]), # Tamaño del minibatch (batch size)
|
|
452
|
+
# "numero_filtros": tune.choice([64, 96, 128]), # Cantidad de filtros para capas convolucionales
|
|
453
|
+
# "tamanho_filtro": tune.choice([3, 5, 7]), # Tamaño del kernel (filtro) en capas convolucionales
|
|
454
|
+
# "tasa_aprendizaje": tune.loguniform(1e-4, 1e-1), # Tasa de aprendizaje entre 0.0001 y 0.1 (escala logarítmica)
|
|
455
|
+
# "epochs": tune.randint(5, 51) # Número de épocas de entrenamiento entre 5 y 50
|
|
456
|
+
# }
|
|
457
|
+
|
|
458
|
+
search_space = {
|
|
459
|
+
"N_capas": tune.randint(2, 5), # 2–4 capas
|
|
460
|
+
"optimizador": tune.choice(["adam", "rmsprop", "sgd"]),
|
|
461
|
+
"funcion_activacion": tune.choice(["relu", "tanh"]), # activaciones que funcionan mejor en CNN
|
|
462
|
+
"tamanho_minilote": tune.choice([2, 4, 6]), # batch pequeño para memoria limitada
|
|
463
|
+
"numero_filtros": tune.choice([32, 48, 64]), # filtros moderados
|
|
464
|
+
"tamanho_filtro": tune.choice([3, 5, 7]), # tamaño de kernel razonable
|
|
465
|
+
"num_resblocks": tune.choice([1, 2]), # 1 o 2 ResBlocks por etapa
|
|
466
|
+
"tasa_aprendizaje": tune.loguniform(1e-4, 5e-4), # learning rate conservador
|
|
467
|
+
"epochs": tune.randint(10, 30) # número de epochs moderado
|
|
468
|
+
}
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
# Configuración del scheduler
|
|
472
|
+
scheduler = ASHAScheduler(
|
|
473
|
+
metric="val_accuracy", # Métrica a optimizar: precisión en el conjunto de validación
|
|
474
|
+
mode="max", # Se busca maximizar la métrica especificada
|
|
475
|
+
max_t=10, # Número máximo de iteraciones (por ejemplo, épocas) por prueba
|
|
476
|
+
grace_period=1, # Número mínimo de iteraciones antes de detener una prueba prematuramente
|
|
477
|
+
reduction_factor=2 # Factor por el cual se reduce el número de pruebas en cada ronda de selección
|
|
478
|
+
)
|
|
479
|
+
|
|
480
|
+
# Envolver función con parámetros adicionales
|
|
481
|
+
wrapped_train_fn = tune.with_parameters(
|
|
482
|
+
train_cnn_ray_tune, # Función de entrenamiento base que se usará en la búsqueda
|
|
483
|
+
model_class=SiMuRModel_CAPTURE24, # Clase del modelo a usar
|
|
484
|
+
data=data_tot # Conjunto de datos completo que se pasará a cada ejecución de entrenamiento
|
|
485
|
+
)
|
|
486
|
+
|
|
487
|
+
tuner = tune.Tuner(
|
|
488
|
+
wrapped_train_fn,
|
|
489
|
+
param_space=search_space,
|
|
490
|
+
tune_config=TuneConfig(
|
|
491
|
+
num_samples=20,
|
|
492
|
+
trial_name_creator=lambda trial: f"trial_{trial.trial_id[:5]}",
|
|
493
|
+
trial_dirname_creator=lambda trial: f"dir_{trial.trial_id[:5]}",
|
|
494
|
+
),
|
|
495
|
+
run_config=RunConfig(
|
|
496
|
+
name="CAPTURE24_hyperparameters_tuning",
|
|
497
|
+
storage_path=case_id_folder,
|
|
498
|
+
checkpoint_config=CheckpointConfig(num_to_keep=1),
|
|
499
|
+
failure_config=FailureConfig(fail_fast=False, max_failures=10),
|
|
500
|
+
verbose=2,
|
|
501
|
+
log_to_file=False
|
|
502
|
+
)
|
|
503
|
+
)
|
|
504
|
+
results = tuner.fit()
|
|
505
|
+
|
|
506
|
+
# Obtener mejor resultado
|
|
507
|
+
best_result = results.get_best_result(metric="val_accuracy", mode="max")
|
|
508
|
+
print("Mejores hiperparámetros:", best_result.config)
|
|
509
|
+
|
|
510
|
+
# Obtener la configuración óptima como diccionario
|
|
511
|
+
mejores_hiperparametros = best_result.config
|
|
512
|
+
|
|
513
|
+
# Guardar en un archivo JSON
|
|
514
|
+
with open(os.path.join(case_id_folder,"mejores_hiperparametros_CAPTURE24.json"), "w") as f:
|
|
515
|
+
json.dump(mejores_hiperparametros, f, indent=4)
|
|
516
|
+
|
|
517
|
+
# Obtener los resultados como DataFrame
|
|
518
|
+
df = results.get_dataframe()
|
|
519
|
+
df.to_json(os.path.join(case_id_folder, "resultados_busqueda_ray_tune_CAPTURE24.json"), orient="records", lines=True)
|
|
520
|
+
|
|
521
|
+
# Construir modelo usando modelGenerator y los mejores hiperparámetros
|
|
522
|
+
model_CAPTURE24_data_tot = modelGenerator(
|
|
523
|
+
modelID=modelID,
|
|
524
|
+
data=data_tot,
|
|
525
|
+
params=mejores_hiperparametros, # Pasamos directamente el diccionario
|
|
526
|
+
debug=False
|
|
527
|
+
)
|
|
528
|
+
# Entrenar el modelo con todos los datos
|
|
529
|
+
model_CAPTURE24_data_tot.train(mejores_hiperparametros['epochs'])
|
|
530
|
+
# Guardar los pesos del modelo en formato .weights.h5
|
|
531
|
+
model_CAPTURE24_data_tot.store(modelID, case_id_folder)
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
# **********
|
|
535
|
+
# Modelo C *
|
|
536
|
+
# **********
|
|
537
|
+
elif modelID == ML_Model.RANDOM_FOREST.value:
|
|
538
|
+
dataset_file = os.path.join(case_id_folder, FEATURE_DATASET_FILE)
|
|
539
|
+
label_encoder_file = os.path.join(case_id_folder, LABEL_ENCODER_FILE)
|
|
540
|
+
config_file = os.path.join(case_id_folder, CONFIG_FILE)
|
|
541
|
+
|
|
542
|
+
data_tot = DataReader(modelID=modelID,
|
|
543
|
+
create_superclasses=args.create_superclasses,
|
|
544
|
+
create_superclasses_CPA_METs= args.create_superclasses_CPA_METs,
|
|
545
|
+
p_train = args.training_percent,
|
|
546
|
+
p_validation = args.validation_percent,
|
|
547
|
+
file_path=dataset_file,
|
|
548
|
+
label_encoder_path=label_encoder_file,
|
|
549
|
+
config_path = config_file)
|
|
550
|
+
|
|
551
|
+
# Se entrenan y salvan los modelos (fichero .pkl).
|
|
552
|
+
# Ruta al archivo de hiperparámetros guardados
|
|
553
|
+
hp_json_path = os.path.join(case_id_folder, "mejores_hiperparametros_BRF.json")
|
|
554
|
+
# Verifica que el archivo existe
|
|
555
|
+
if os.path.isfile(hp_json_path):
|
|
556
|
+
# Cargar hiperparámetros desde el archivo JSON
|
|
557
|
+
with open(hp_json_path, "r") as f:
|
|
558
|
+
best_hp_values = json.load(f) # Diccionario: {param: valor}
|
|
559
|
+
# Construir modelo usando modelGenerator y los hiperparámetros
|
|
560
|
+
model_RandomForest_data_tot = modelGenerator(
|
|
561
|
+
modelID=modelID,
|
|
562
|
+
data=data_tot,
|
|
563
|
+
params=best_hp_values, # Pasamos directamente el diccionario
|
|
564
|
+
debug=False
|
|
565
|
+
)
|
|
566
|
+
# Entrenar el modelo con todos los datos
|
|
567
|
+
model_RandomForest_data_tot.train()
|
|
568
|
+
# Guardar los pesos del modelo en formato .weights.h5
|
|
569
|
+
model_RandomForest_data_tot.store(modelID, case_id_folder)
|
|
570
|
+
else:
|
|
571
|
+
print(f"Se lanza la búsqueda de hiperparámetros óptimos del modelo")
|
|
572
|
+
# ------------------------------------------------------------------------------------------------------
|
|
573
|
+
# Búsqueda de hiperparámetros óptimos del modelo BalancedRandomForestClassifier, usando Ray Tune y ASHA
|
|
574
|
+
# ------------------------------------------------------------------------------------------------------
|
|
575
|
+
# Espacio de búsqueda
|
|
576
|
+
search_space = {
|
|
577
|
+
"n_estimators": tune.randint(50, 301), # Número de árboles entre 50 y 300
|
|
578
|
+
"max_depth": tune.choice([5, 10, 15, 20, None]), # Profundidad máxima del árbol
|
|
579
|
+
"min_samples_split": tune.randint(2, 11), # Muestras mínimas para dividir un nodo
|
|
580
|
+
"min_samples_leaf": tune.randint(1, 11), # Muestras mínimas por hoja
|
|
581
|
+
"max_features": tune.choice([None, "sqrt", "log2"]), # Número de características por división
|
|
582
|
+
"random_state": tune.randint(0, 10000)
|
|
583
|
+
}
|
|
584
|
+
|
|
585
|
+
# Configuración del scheduler
|
|
586
|
+
scheduler = ASHAScheduler(
|
|
587
|
+
metric="test_accuracy", # Métrica a optimizar: precisión en el conjunto de test
|
|
588
|
+
mode="max", # Se busca maximizar la métrica especificada
|
|
589
|
+
max_t=10, # Número máximo de iteraciones (no se usa directamente en Random Forest, pero requerido por ASHA)
|
|
590
|
+
grace_period=1, # Número mínimo de iteraciones antes de detener una prueba prematuramente
|
|
591
|
+
reduction_factor=2 # Factor por el cual se reduce el número de pruebas en cada ronda de selección
|
|
592
|
+
)
|
|
593
|
+
|
|
594
|
+
# Envolver función con parámetros adicionales
|
|
595
|
+
wrapped_train_fn = tune.with_parameters(
|
|
596
|
+
train_brf_ray_tune, # Función de entrenamiento base que se usará en la búsqueda
|
|
597
|
+
model_class=SiMuRModel_RandomForest, # Clase del modelo a usar
|
|
598
|
+
data=data_tot # Conjunto de datos completo que se pasará a cada ejecución de entrenamiento
|
|
599
|
+
)
|
|
600
|
+
|
|
601
|
+
# Crear el tuner
|
|
602
|
+
tuner = tune.Tuner(
|
|
603
|
+
wrapped_train_fn, # Función de entrenamiento envuelta con parámetros fijos
|
|
604
|
+
param_space=search_space, # Espacio de búsqueda de hiperparámetros definido antes
|
|
605
|
+
tune_config=TuneConfig(
|
|
606
|
+
scheduler=scheduler, # Scheduler para manejar la parada temprana (ASHAScheduler)
|
|
607
|
+
num_samples=20, # Número de configuraciones (experimentos) a probar
|
|
608
|
+
trial_name_creator=lambda trial: f"trial_{trial.trial_id[:5]}", # Nombre personalizado para cada prueba
|
|
609
|
+
trial_dirname_creator=lambda trial: f"dir_{trial.trial_id[:5]}" # Carpeta personalizada para cada prueba
|
|
610
|
+
),
|
|
611
|
+
run_config=RunConfig(
|
|
612
|
+
name="BalancedRF_hyperparameters_tuning", # Nombre general del experimento
|
|
613
|
+
storage_path=case_id_folder, # Ruta donde se guardan los resultados y checkpoints
|
|
614
|
+
checkpoint_config=CheckpointConfig(num_to_keep=1), # Guardar solo el último checkpoint por prueba
|
|
615
|
+
failure_config=FailureConfig(fail_fast=False, max_failures=10), # Permite hasta 10 fallos antes de parar
|
|
616
|
+
verbose=2, # Nivel de detalle en los logs (más detallado)
|
|
617
|
+
log_to_file=False # No guardar logs en archivos (evita problemas con rutas largas)
|
|
618
|
+
)
|
|
619
|
+
)
|
|
620
|
+
|
|
621
|
+
# Ejecutar búsqueda de hiperparámetros
|
|
622
|
+
results = tuner.fit()
|
|
623
|
+
|
|
624
|
+
# Obtener mejor resultado
|
|
625
|
+
best_result = results.get_best_result(metric="test_accuracy", mode="max")
|
|
626
|
+
print("Mejores hiperparámetros:", best_result.config)
|
|
627
|
+
|
|
628
|
+
# Obtener la configuración óptima como diccionario
|
|
629
|
+
mejores_hiperparametros = best_result.config
|
|
630
|
+
|
|
631
|
+
# Guardar en un archivo JSON
|
|
632
|
+
with open(os.path.join(case_id_folder,"mejores_hiperparametros_BRF.json"), "w") as f:
|
|
633
|
+
json.dump(mejores_hiperparametros, f, indent=4)
|
|
634
|
+
|
|
635
|
+
# Obtener los resultados como DataFrame
|
|
636
|
+
df = results.get_dataframe()
|
|
637
|
+
df.to_json(os.path.join(case_id_folder, "resultados_busqueda_ray_tune_BRF.json"), orient="records", lines=True)
|
|
638
|
+
|
|
639
|
+
# Construir modelo usando modelGenerator y los mejores hiperparámetros
|
|
640
|
+
model_RandomForest_data_tot = modelGenerator(
|
|
641
|
+
modelID=modelID,
|
|
642
|
+
data=data_tot,
|
|
643
|
+
params=mejores_hiperparametros, # Pasamos directamente el diccionario
|
|
644
|
+
debug=False
|
|
645
|
+
)
|
|
646
|
+
# Entrenar el modelo con todos los datos
|
|
647
|
+
model_RandomForest_data_tot.train()
|
|
648
|
+
# Guardar los pesos del modelo en formato .weights.h5
|
|
649
|
+
model_RandomForest_data_tot.store(modelID, case_id_folder)
|
|
650
|
+
|
|
651
|
+
|
|
652
|
+
# **********
|
|
653
|
+
# Modelo D *
|
|
654
|
+
# **********
|
|
655
|
+
elif modelID == ML_Model.XGBOOST.value:
|
|
656
|
+
dataset_file = os.path.join(case_id_folder, FEATURE_DATASET_FILE)
|
|
657
|
+
label_encoder_file = os.path.join(case_id_folder, LABEL_ENCODER_FILE)
|
|
658
|
+
config_file = os.path.join(case_id_folder, CONFIG_FILE)
|
|
659
|
+
|
|
660
|
+
data_tot = DataReader(modelID=modelID,
|
|
661
|
+
create_superclasses=args.create_superclasses,
|
|
662
|
+
create_superclasses_CPA_METs = args.create_superclasses_CPA_METs,
|
|
663
|
+
p_train = args.training_percent,
|
|
664
|
+
p_validation = args.validation_percent,
|
|
665
|
+
file_path=dataset_file,
|
|
666
|
+
label_encoder_path=label_encoder_file,
|
|
667
|
+
config_path = config_file)
|
|
668
|
+
|
|
669
|
+
# Se entrenan y salvan los modelos (fichero .pkl).
|
|
670
|
+
# Ruta al archivo de hiperparámetros guardados
|
|
671
|
+
hp_json_path = os.path.join(case_id_folder, "mejores_hiperparametros_XGB.json")
|
|
672
|
+
# Verifica que el archivo existe
|
|
673
|
+
if os.path.isfile(hp_json_path):
|
|
674
|
+
# Cargar hiperparámetros desde el archivo JSON
|
|
675
|
+
with open(hp_json_path, "r") as f:
|
|
676
|
+
best_hp_values = json.load(f) # Diccionario: {param: valor}
|
|
677
|
+
# Construir modelo usando modelGenerator y los hiperparámetros
|
|
678
|
+
model_XGBoost_data_tot = modelGenerator(
|
|
679
|
+
modelID=modelID,
|
|
680
|
+
data=data_tot,
|
|
681
|
+
# params=best_hp_values, # Pasamos directamente el diccionario
|
|
682
|
+
params=best_hp_values,
|
|
683
|
+
debug=False
|
|
684
|
+
)
|
|
685
|
+
# Entrenar el modelo con todos los datos de (X_train, y_train), implementando la validación con (X_validation, y_validation)
|
|
686
|
+
model_XGBoost_data_tot.train()
|
|
687
|
+
# Guardar los pesos del modelo en formato .weights.h5
|
|
688
|
+
model_XGBoost_data_tot.store(modelID, case_id_folder)
|
|
689
|
+
|
|
690
|
+
else:
|
|
691
|
+
print(f"Se lanza la búsqueda de hiperparámetros óptimos del modelo.")
|
|
692
|
+
# ------------------------------------------------------------------------------------------------------
|
|
693
|
+
# Búsqueda de hiperparámetros óptimos del modelo XGBoost, usando Ray Tune (ASHA)
|
|
694
|
+
# ------------------------------------------------------------------------------------------------------
|
|
695
|
+
# Espacio de búsqueda para XGBoost
|
|
696
|
+
search_space_xgb = {
|
|
697
|
+
"num_boost_round": tune.randint(50, 3001), # Árboles (rondas) de boosting
|
|
698
|
+
"max_depth": tune.randint(3, 11), # Profundidad máxima
|
|
699
|
+
"learning_rate": tune.uniform(0.01, 0.3), # Tasa de aprendizaje
|
|
700
|
+
"subsample": tune.uniform(0.5, 1.0), # Fracción de muestras por árbol
|
|
701
|
+
"colsample_bytree": tune.uniform(0.5, 1.0), # Fracción de columnas por árbol
|
|
702
|
+
"gamma": tune.uniform(0, 5), # Regularización mínima de pérdida
|
|
703
|
+
"min_child_weight": tune.randint(1, 10), # Peso mínimo de hijos
|
|
704
|
+
"reg_alpha": tune.uniform(0, 1), # L1 regularization
|
|
705
|
+
"reg_lambda": tune.uniform(0, 1), # L2 regularization
|
|
706
|
+
"random_state": tune.randint(0, 10000)
|
|
707
|
+
}
|
|
708
|
+
|
|
709
|
+
# Configuración del scheduler (igual que en RF)
|
|
710
|
+
scheduler = ASHAScheduler( # Crea una instancia del scheduler ASHA para optimizar entrenamientos
|
|
711
|
+
metric="validation_accuracy", # Métrica a optimizar (precisión en validación)
|
|
712
|
+
mode="max", # Indica que la métrica debe maximizarse
|
|
713
|
+
max_t=10, # Número máximo de iteraciones/épocas por configuración
|
|
714
|
+
grace_period=1, # Número mínimo de iteraciones antes de detener un trial por bajo rendimiento
|
|
715
|
+
reduction_factor=2 # Factor de reducción para descartar configuraciones poco prometedoras
|
|
716
|
+
)
|
|
717
|
+
|
|
718
|
+
# Envolver función de entrenamiento
|
|
719
|
+
wrapped_train_fn_xgb = tune.with_parameters( # Crea una versión de la función con parámetros fijos predefinidos
|
|
720
|
+
train_xgb_ray_tune, # Función de entrenamiento adaptada a XGBoost
|
|
721
|
+
model_class=SiMuRModel_XGBoost, # Clase del modelo a utilizar (implementación XGBoost personalizada)
|
|
722
|
+
data=data_tot # Conjunto de datos completo que se usará en el entrenamiento
|
|
723
|
+
)
|
|
724
|
+
|
|
725
|
+
# Crear el tuner
|
|
726
|
+
tuner_xgb = tune.Tuner( # Crea un objeto Tuner para ejecutar la búsqueda de hiperparámetros
|
|
727
|
+
wrapped_train_fn_xgb, # Función de entrenamiento envuelta con parámetros fijos
|
|
728
|
+
param_space=search_space_xgb, # Espacio de búsqueda de hiperparámetros
|
|
729
|
+
tune_config=TuneConfig( # Configuración de la optimización
|
|
730
|
+
scheduler=scheduler, # Planificador (scheduler) para gestionar recursos y early stopping
|
|
731
|
+
num_samples=20, # Número de configuraciones distintas a probar
|
|
732
|
+
trial_name_creator=lambda trial: f"trial_{trial.trial_id[:5]}", # Nombre personalizado para cada experimento
|
|
733
|
+
trial_dirname_creator=lambda trial: f"dir_{trial.trial_id[:5]}" # Carpeta personalizada para cada experimento
|
|
734
|
+
),
|
|
735
|
+
run_config=RunConfig( # Configuración de ejecución de los experimentos
|
|
736
|
+
name="XGBoost_hyperparameters_tuning", # Nombre general de la ejecución
|
|
737
|
+
storage_path=case_id_folder, # Carpeta donde guardar resultados y checkpoints
|
|
738
|
+
checkpoint_config=CheckpointConfig(num_to_keep=1), # Mantener solo el último checkpoint por trial
|
|
739
|
+
failure_config=FailureConfig(fail_fast=False, max_failures=10), # Permitir hasta 10 fallos sin abortar
|
|
740
|
+
verbose=2, # Nivel de detalle en la salida por consola
|
|
741
|
+
log_to_file=False # No guardar logs en archivo (solo consola)
|
|
742
|
+
)
|
|
743
|
+
)
|
|
744
|
+
|
|
745
|
+
# Ejecutar la búsqueda
|
|
746
|
+
results_xgb = tuner_xgb.fit()
|
|
747
|
+
|
|
748
|
+
# Mejor resultado
|
|
749
|
+
best_result_xgb = results_xgb.get_best_result(metric="validation_accuracy", mode="max")
|
|
750
|
+
print("Mejores hiperparámetros XGBoost:", best_result_xgb.config)
|
|
751
|
+
|
|
752
|
+
# Guardar en JSON
|
|
753
|
+
mejores_hiperparametros_xgb = best_result_xgb.config
|
|
754
|
+
with open(os.path.join(case_id_folder, "mejores_hiperparametros_XGB.json"), "w") as f:
|
|
755
|
+
json.dump(mejores_hiperparametros_xgb, f, indent=4)
|
|
756
|
+
|
|
757
|
+
# Guardar todos los resultados
|
|
758
|
+
df_xgb = results_xgb.get_dataframe()
|
|
759
|
+
df_xgb.to_json(os.path.join(case_id_folder, "resultados_busqueda_ray_tune_XGB.json"), orient="records", lines=True)
|
|
760
|
+
|
|
761
|
+
# Construir y entrenar modelo con mejores hiperparámetros
|
|
762
|
+
model_XGB_data_tot = modelGenerator(
|
|
763
|
+
modelID=modelID,
|
|
764
|
+
data=data_tot,
|
|
765
|
+
params=mejores_hiperparametros_xgb,
|
|
766
|
+
debug=False
|
|
767
|
+
)
|
|
768
|
+
model_XGB_data_tot.train()
|
|
769
|
+
model_XGB_data_tot.store(modelID, case_id_folder)
|
|
770
|
+
|
|
771
|
+
|
|
772
|
+
_logger.info("Script ends here")
|
|
773
|
+
|
|
774
|
+
def run():
|
|
775
|
+
"""Calls :func:`main` passing the CLI arguments extracted from :obj:`sys.argv`
|
|
776
|
+
|
|
777
|
+
This function can be used as entry point to create console scripts with setuptools.
|
|
778
|
+
"""
|
|
779
|
+
main(sys.argv[1:])
|
|
780
|
+
|
|
781
|
+
if __name__ == "__main__":
|
|
782
|
+
run()
|