dragon-ml-toolbox 8.1.0__py3-none-any.whl → 9.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (34)
  1. {dragon_ml_toolbox-8.1.0.dist-info → dragon_ml_toolbox-9.0.0.dist-info}/METADATA +5 -1
  2. dragon_ml_toolbox-9.0.0.dist-info/RECORD +35 -0
  3. ml_tools/ETL_engineering.py +216 -81
  4. ml_tools/GUI_tools.py +5 -5
  5. ml_tools/MICE_imputation.py +12 -8
  6. ml_tools/ML_callbacks.py +6 -3
  7. ml_tools/ML_datasetmaster.py +37 -20
  8. ml_tools/ML_evaluation.py +4 -4
  9. ml_tools/ML_evaluation_multi.py +26 -17
  10. ml_tools/ML_inference.py +30 -23
  11. ml_tools/ML_models.py +14 -14
  12. ml_tools/ML_optimization.py +4 -3
  13. ml_tools/ML_scaler.py +7 -7
  14. ml_tools/ML_trainer.py +17 -15
  15. ml_tools/PSO_optimization.py +16 -8
  16. ml_tools/RNN_forecast.py +1 -1
  17. ml_tools/SQL.py +22 -13
  18. ml_tools/VIF_factor.py +7 -6
  19. ml_tools/_logger.py +105 -7
  20. ml_tools/custom_logger.py +12 -8
  21. ml_tools/data_exploration.py +20 -15
  22. ml_tools/ensemble_evaluation.py +10 -6
  23. ml_tools/ensemble_inference.py +18 -18
  24. ml_tools/ensemble_learning.py +8 -5
  25. ml_tools/handle_excel.py +15 -11
  26. ml_tools/optimization_tools.py +3 -4
  27. ml_tools/path_manager.py +21 -15
  28. ml_tools/utilities.py +35 -26
  29. dragon_ml_toolbox-8.1.0.dist-info/RECORD +0 -36
  30. ml_tools/_ML_optimization_multi.py +0 -231
  31. {dragon_ml_toolbox-8.1.0.dist-info → dragon_ml_toolbox-9.0.0.dist-info}/WHEEL +0 -0
  32. {dragon_ml_toolbox-8.1.0.dist-info → dragon_ml_toolbox-9.0.0.dist-info}/licenses/LICENSE +0 -0
  33. {dragon_ml_toolbox-8.1.0.dist-info → dragon_ml_toolbox-9.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  34. {dragon_ml_toolbox-8.1.0.dist-info → dragon_ml_toolbox-9.0.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_optimization.py CHANGED
@@ -127,7 +127,8 @@ def create_pytorch_problem(
         SearcherClass = GeneticAlgorithm
 
     else:
-        raise ValueError(f"Unknown algorithm '{algorithm}'.")
+        _LOGGER.error(f"Unknown algorithm '{algorithm}'.")
+        raise ValueError()
 
     # Create a factory function with all arguments pre-filled
     searcher_factory = partial(SearcherClass, problem, **searcher_kwargs)
@@ -242,7 +243,7 @@ def run_optimization(
         if verbose:
             _handle_pandas_log(pandas_logger, save_path=save_path, target_name=target_name)
 
-        _LOGGER.info(f"Optimization complete. Best solution saved to '{csv_path.name}'")
+        _LOGGER.info(f"Optimization complete. Best solution saved to '{csv_path.name}'")
         return result_dict
 
     # --- MULTIPLE REPETITIONS LOGIC ---
@@ -295,7 +296,7 @@ def run_optimization(
         if pandas_logger is not None:
             _handle_pandas_log(pandas_logger, save_path=save_path, target_name=target_name)
 
-        _LOGGER.info(f"Optimal solution space complete. Results saved to '{save_path}'")
+        _LOGGER.info(f"Optimal solution space complete. Results saved to '{save_path}'")
         return None
 
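The recurring change in 9.0.0 is already visible in the first hunk: instead of raising an exception that carries the message, the code now logs the reason through the package's `_LOGGER` and raises a bare exception. A minimal standalone sketch of the pattern, using the standard `logging` module in place of the package's private `_logger` helper (the accepted algorithm names are placeholders):

import logging

_LOGGER = logging.getLogger("ml_tools")

def create_problem(algorithm: str) -> None:
    # Log the human-readable reason first, then raise a bare exception,
    # so the formatted log record is the single source of the message.
    if algorithm not in ("genetic", "pso"):
        _LOGGER.error(f"Unknown algorithm '{algorithm}'.")
        raise ValueError()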
ml_tools/ML_scaler.py CHANGED
@@ -50,7 +50,7 @@ class PytorchScaler:
             PytorchScaler: A new, fitted instance of the scaler.
         """
         if not continuous_feature_indices:
-            _LOGGER.warning("⚠️ No continuous feature indices provided. Scaler will not be fitted.")
+            _LOGGER.error("No continuous feature indices provided. Scaler will not be fitted.")
             return cls()
 
         loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
@@ -72,7 +72,7 @@ class PytorchScaler:
             count += continuous_features.size(0)
 
         if count == 0:
-            _LOGGER.warning("⚠️ Dataset is empty. Scaler cannot be fitted.")
+            _LOGGER.error("Dataset is empty. Scaler cannot be fitted.")
             return cls(continuous_feature_indices=continuous_feature_indices)
 
         # Calculate mean
@@ -80,7 +80,7 @@ class PytorchScaler:
 
         # Calculate standard deviation
         if count < 2:
-            _LOGGER.warning(f"⚠️ Only one sample found. Standard deviation cannot be calculated and is set to 1.")
+            _LOGGER.warning(f"Only one sample found. Standard deviation cannot be calculated and is set to 1.")
             std = torch.ones_like(mean)
         else:
             # var = E[X^2] - (E[X])^2
@@ -101,7 +101,7 @@ class PytorchScaler:
             torch.Tensor: The transformed data tensor.
         """
         if self.mean_ is None or self.std_ is None or self.continuous_feature_indices is None:
-            _LOGGER.warning("⚠️ Scaler has not been fitted. Returning original data.")
+            _LOGGER.error("Scaler has not been fitted. Returning original data.")
             return data
 
         data_clone = data.clone()
@@ -132,7 +132,7 @@ class PytorchScaler:
             torch.Tensor: The original-scale data tensor.
         """
         if self.mean_ is None or self.std_ is None or self.continuous_feature_indices is None:
-            _LOGGER.warning("⚠️ Scaler has not been fitted. Returning original data.")
+            _LOGGER.error("Scaler has not been fitted. Returning original data.")
             return data
 
         data_clone = data.clone()
@@ -163,7 +163,7 @@ class PytorchScaler:
             'continuous_feature_indices': self.continuous_feature_indices
         }
         torch.save(state, path_obj)
-        _LOGGER.info(f"PytorchScaler state saved to '{path_obj.name}'.")
+        _LOGGER.info(f"PytorchScaler state saved to '{path_obj.name}'.")
 
     @staticmethod
     def load(filepath: Union[str, Path]) -> 'PytorchScaler':
@@ -178,7 +178,7 @@ class PytorchScaler:
         """
         path_obj = make_fullpath(filepath, enforce="file")
         state = torch.load(path_obj)
-        _LOGGER.info(f"PytorchScaler state loaded from '{path_obj.name}'.")
+        _LOGGER.info(f"PytorchScaler state loaded from '{path_obj.name}'.")
         return PytorchScaler(
             mean=state['mean'],
             std=state['std'],
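Read together, the PytorchScaler hunks outline a fit/transform/save/load lifecycle. A hypothetical usage sketch follows; the `fit` classmethod signature, the `inverse_transform` name, and the keyword arguments are inferred from the hunks rather than confirmed by them:

import torch
from torch.utils.data import TensorDataset

from ml_tools.ML_scaler import PytorchScaler  # import path per the file listing

# Three samples: columns 0 and 1 are continuous, column 2 is a binary flag.
features = torch.tensor([[1.0, 10.0, 0.0],
                         [2.0, 20.0, 1.0],
                         [3.0, 30.0, 0.0]])
dataset = TensorDataset(features, torch.zeros(3, 1))

scaler = PytorchScaler.fit(dataset, continuous_feature_indices=[0, 1])
scaled = scaler.transform(features)          # standardizes only columns 0 and 1
restored = scaler.inverse_transform(scaled)  # maps back to the original scale

scaler.save("scaler.pth")                    # torch.save of mean/std/indices
reloaded = PytorchScaler.load("scaler.pth")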
ml_tools/ML_trainer.py CHANGED
@@ -76,10 +76,10 @@ class MLTrainer:
         """Validates the selected device and returns a torch.device object."""
         device_lower = device.lower()
         if "cuda" in device_lower and not torch.cuda.is_available():
-            _LOGGER.warning("⚠️ CUDA not available, switching to CPU.")
+            _LOGGER.warning("CUDA not available, switching to CPU.")
             device = "cpu"
         elif device_lower == "mps" and not torch.backends.mps.is_available():
-            _LOGGER.warning("⚠️ Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
+            _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
             device = "cpu"
         return torch.device(device)
 
@@ -275,7 +275,8 @@
             dataset_for_names = data
         else: # data is None, use the trainer's default test dataset
             if self.test_dataset is None:
-                raise ValueError("Cannot evaluate. No data provided and no test_dataset available in the trainer.")
+                _LOGGER.error("Cannot evaluate. No data provided and no test_dataset available in the trainer.")
+                raise ValueError()
             # Create a fresh DataLoader from the test_dataset
             eval_loader = DataLoader(self.test_dataset,
                                      batch_size=32,
@@ -285,7 +286,8 @@
             dataset_for_names = self.test_dataset
 
         if eval_loader is None:
-            raise ValueError("Cannot evaluate. No valid data was provided or found.")
+            _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
+            raise ValueError()
 
         print("\n--- Model Evaluation ---")
 
@@ -296,7 +298,7 @@
             if y_true_b is not None: all_true.append(y_true_b)
 
         if not all_true:
-            _LOGGER.error("Evaluation failed: No data was processed.")
+            _LOGGER.error("Evaluation failed: No data was processed.")
             return
 
         y_pred = np.concatenate(all_preds)
@@ -316,7 +318,7 @@
             except AttributeError:
                 num_targets = y_true.shape[1]
                 target_names = [f"target_{i}" for i in range(num_targets)]
-                _LOGGER.warning(f"⚠️ Dataset has no 'target_names' attribute. Using generic names.")
+                _LOGGER.warning(f"Dataset has no 'target_names' attribute. Using generic names.")
             multi_target_regression_metrics(y_true, y_pred, target_names, save_dir)
 
         elif self.kind == "multi_label_classification":
@@ -325,10 +327,10 @@
             except AttributeError:
                 num_targets = y_true.shape[1]
                 target_names = [f"label_{i}" for i in range(num_targets)]
-                _LOGGER.warning(f"⚠️ Dataset has no 'target_names' attribute. Using generic names.")
+                _LOGGER.warning(f"Dataset has no 'target_names' attribute. Using generic names.")
 
             if y_prob is None:
-                _LOGGER.error("Evaluation for multi_label_classification requires probabilities (y_prob).")
+                _LOGGER.error("Evaluation for multi_label_classification requires probabilities (y_prob).")
                 return
             multi_label_classification_metrics(y_true, y_prob, target_names, save_dir, classification_threshold)
 
@@ -390,14 +392,14 @@
         # 1. Get background data from the trainer's train_dataset
         background_data = _get_random_sample(self.train_dataset, n_samples)
         if background_data is None:
-            _LOGGER.error("Trainer's train_dataset is empty or invalid. Skipping SHAP analysis.")
+            _LOGGER.error("Trainer's train_dataset is empty or invalid. Skipping SHAP analysis.")
             return
 
         # 2. Determine target dataset and get explanation instances
         target_dataset = explain_dataset if explain_dataset is not None else self.test_dataset
         instances_to_explain = _get_random_sample(target_dataset, n_samples)
         if instances_to_explain is None:
-            _LOGGER.error("Explanation dataset is empty or invalid. Skipping SHAP analysis.")
+            _LOGGER.error("Explanation dataset is empty or invalid. Skipping SHAP analysis.")
             return
 
         # attempt to get feature names
@@ -410,8 +412,8 @@
                 # Handle PyTorch Subset
                 feature_names = target_dataset.dataset.feature_names # type: ignore
             except AttributeError:
-                _LOGGER.error("Could not extract `feature_names` from the dataset.")
-                raise ValueError("`feature_names` must be provided if the dataset object does not have a `feature_names` attribute.")
+                _LOGGER.error("Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a `feature_names` attribute.")
+                raise ValueError()
 
         # 3. Call the plotting function
         if self.kind in ["regression", "classification"]:
@@ -490,13 +492,13 @@
 
         # --- Step 1: Check if the model supports this explanation ---
         if not hasattr(self.model, 'forward_attention'):
-            _LOGGER.error("Model does not have a `forward_attention` method. Skipping attention explanation.")
+            _LOGGER.error("Model does not have a `forward_attention` method. Skipping attention explanation.")
             return
 
         # --- Step 2: Set up the dataloader ---
         dataset_to_use = explain_dataset if explain_dataset is not None else self.test_dataset
         if not isinstance(dataset_to_use, Dataset):
-            _LOGGER.error("The explanation dataset is empty or invalid. Skipping attention analysis.")
+            _LOGGER.error("The explanation dataset is empty or invalid. Skipping attention analysis.")
             return
 
         explain_loader = DataLoader(
@@ -519,7 +521,7 @@
                 save_dir=save_dir
             )
         else:
-            _LOGGER.error("No attention weights were collected from the model.")
+            _LOGGER.error("No attention weights were collected from the model.")
 
     def callbacks_hook(self, method_name: str, *args, **kwargs):
         """Calls the specified method on all callbacks."""
ml_tools/PSO_optimization.py CHANGED
@@ -65,7 +65,9 @@ class ObjectiveFunction():
         np.ndarray
             1D array with length n_samples containing predicted target values.
         """
-        assert features_array.ndim == 2, f"Expected 2D array, got shape {features_array.shape}"
+        if features_array.ndim != 2:
+            _LOGGER.error(f"Expected 2D array, got shape {features_array.shape}.")
+            raise AssertionError()
 
         # Apply noise if enabled
         if self.use_noise:
@@ -101,7 +103,9 @@ class ObjectiveFunction():
         np.ndarray
             Noised array of same shape
         """
-        assert features_array.ndim == 2, "Expected 2D array for batch noise injection"
+        if features_array.ndim != 2:
+            _LOGGER.error(f"Expected 2D array for batch noise injection, got shape {features_array.shape}.")
+            raise AssertionError()
 
         if self.binary_features > 0:
             split_idx = -self.binary_features
@@ -118,13 +122,16 @@
 
     def check_model(self):
         if isinstance(self.model, xgb.XGBClassifier) or isinstance(self.model, lgb.LGBMClassifier):
-            raise ValueError(f"[Model Check Failed] ❌\nThe loaded model ({type(self.model).__name__}) is a Classifier.\nOptimization is not suitable for standard classification tasks.")
+            _LOGGER.error(f"[Model Check Failed]\nThe loaded model ({type(self.model).__name__}) is a Classifier.\nOptimization is not suitable for standard classification tasks.")
+            raise ValueError()
         if self.model is None:
-            raise ValueError("Loaded model is None")
+            _LOGGER.error("Loaded model is None")
+            raise ValueError()
 
     def _get_from_artifact(self, key: str):
         if self._artifact is None:
-            raise TypeError("Load model error")
+            _LOGGER.error("Load model error")
+            raise TypeError()
         val = self._artifact.get(key)
         if key == EnsembleKeys.FEATURES:
             result = val if isinstance(val, list) and val else None
@@ -314,7 +321,8 @@ def run_pso(lower_boundaries: list[float],
     if target_name is None and objective_function.target_name is not None:
         target_name = objective_function.target_name
     if target_name is None:
-        raise ValueError(f"'target' name was not provided and was not found in the .joblib object.")
+        _LOGGER.error(f"'target' name was not provided and was not found in the .joblib object.")
+        raise ValueError()
 
     # --- Setup: Saving Infrastructure ---
     sanitized_target_name = sanitize_filename(target_name)
@@ -355,7 +363,7 @@ def run_pso(lower_boundaries: list[float],
             objective_function, pso_arguments, names, target_name, random_state,
             save_format, csv_path, db_manager, db_table_name
        )
-        _LOGGER.info(f"Single optimization complete.")
+        _LOGGER.info(f"Single optimization complete.")
        return features_dict, target_dict
 
     else:
@@ -365,7 +373,7 @@ def run_pso(lower_boundaries: list[float],
             objective_function, pso_arguments, names, target_name, post_hoc_analysis,
             save_format, csv_path, db_manager, db_table_name
        )
-        _LOGGER.info("Post-hoc analysis complete. Results saved.")
+        _LOGGER.info("Post-hoc analysis complete. Results saved.")
        return None
 
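The `@@ -101,7 +103,9 @@` hunk above shows only the start of the noise routine: when `binary_features > 0`, the last `binary_features` columns are split off so that only the continuous block is perturbed. A sketch of that presumed split; the noise distribution itself is not shown in the diff and is assumed Gaussian here:

import numpy as np

def add_batch_noise(features_array: np.ndarray, binary_features: int, scale: float = 0.01) -> np.ndarray:
    if features_array.ndim != 2:
        raise AssertionError(f"Expected 2D array, got shape {features_array.shape}.")
    if binary_features > 0:
        split_idx = -binary_features
        continuous = features_array[:, :split_idx]
        binary = features_array[:, split_idx:]  # binary flags stay untouched
        noised = continuous + np.random.normal(0.0, scale, size=continuous.shape)
        return np.hstack([noised, binary])
    return features_array + np.random.normal(0.0, scale, size=features_array.shape)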
ml_tools/RNN_forecast.py CHANGED
@@ -51,4 +51,4 @@ def rnn_forecast(model: nn.Module, start_sequence: torch.Tensor, steps: int, dev
 
 
 def info():
-    _script_info
+    _script_info(__all__)
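The old body evaluated the bare name `_script_info` without calling it, a silent no-op; 9.0.0 actually invokes the helper with the module's `__all__`. A sketch of the presumed behavior (the real `_script_info` lives elsewhere in the package and its output format is not shown in this diff):

__all__ = ["rnn_forecast"]

def _script_info(all_names: list[str]) -> None:
    # Hypothetical stand-in: report the module's public API.
    print(f"Public API: {all_names}")

def info():
    _script_info(__all__)  # the fix: call the helper instead of merely naming it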
ml_tools/SQL.py CHANGED
@@ -62,7 +62,7 @@ class DatabaseManager:
             _LOGGER.info(f"❇️ Successfully connected to database: {self.db_path}")
             return self
         except sqlite3.Error as e:
-            _LOGGER.error(f"Database connection failed: {e}")
+            _LOGGER.error(f"Database connection failed: {e}")
             raise # Re-raise the exception after logging
 
     def __exit__(self, exc_type, exc_val, exc_tb):
@@ -70,11 +70,11 @@ class DatabaseManager:
         if self.conn:
             if exc_type: # If an exception occurred, rollback
                 self.conn.rollback()
-                _LOGGER.warning("⚠️ Rolling back transaction due to an error.")
+                _LOGGER.warning("Rolling back transaction due to an error.")
             else: # Otherwise, commit the transaction
                 self.conn.commit()
             self.conn.close()
-            _LOGGER.info(f"❇️ Database connection closed: {self.db_path.name}")
+            _LOGGER.info(f"Database connection closed: {self.db_path.name}")
 
     def create_table(self, table_name: str, schema: Dict[str, str], if_not_exists: bool = True):
         """
@@ -92,7 +92,8 @@ class DatabaseManager:
             if the table already exists.
         """
         if not self.cursor:
-            raise sqlite3.Error("Database connection is not open.")
+            _LOGGER.error("Database connection is not open.")
+            raise sqlite3.Error()
 
         columns_def = ", ".join([f'"{col_name}" {col_type}' for col_name, col_type in schema.items()])
         exists_clause = "IF NOT EXISTS" if if_not_exists else ""
@@ -115,7 +116,8 @@ class DatabaseManager:
             data to be inserted.
         """
         if not self.cursor:
-            raise sqlite3.Error("Database connection is not open.")
+            _LOGGER.error("Database connection is not open.")
+            raise sqlite3.Error()
 
         columns = ', '.join(f'"{k}"' for k in data.keys())
         placeholders = ', '.join(['?'] * len(data))
@@ -143,7 +145,8 @@ class DatabaseManager:
             A DataFrame containing the query results.
         """
         if not self.conn:
-            raise sqlite3.Error("Database connection is not open.")
+            _LOGGER.error("Database connection is not open.")
+            raise sqlite3.Error()
 
         return pd.read_sql_query(query, self.conn, params=params)
 
@@ -159,7 +162,8 @@ class DatabaseManager:
             An optional tuple of parameters for the query.
         """
         if not self.cursor:
-            raise sqlite3.Error("Database connection is not open.")
+            _LOGGER.error("Database connection is not open.")
+            raise sqlite3.Error()
 
         self.cursor.execute(query, params if params else ())
 
@@ -176,9 +180,10 @@ class DatabaseManager:
             All dictionaries should have the same keys.
         """
         if not self.cursor:
-            raise sqlite3.Error("Database connection is not open.")
+            _LOGGER.error("Database connection is not open.")
+            raise sqlite3.Error()
         if not data:
-            _LOGGER.warning("⚠️ insert_many called with empty data list. No action taken.")
+            _LOGGER.warning("'insert_many' called with empty data list. No action taken.")
             return
 
         # Assume all dicts have the same keys as the first one
@@ -211,7 +216,8 @@ class DatabaseManager:
             - 'append': Insert new values to the existing table.
         """
         if not self.conn:
-            raise sqlite3.Error("Database connection is not open.")
+            _LOGGER.error("Database connection is not open.")
+            raise sqlite3.Error()
 
         df.to_sql(
             table_name,
@@ -224,7 +230,8 @@ class DatabaseManager:
     def list_tables(self) -> List[str]:
         """Returns a list of all table names in the database."""
         if not self.cursor:
-            raise sqlite3.Error("Database connection is not open.")
+            _LOGGER.error("Database connection is not open.")
+            raise sqlite3.Error()
 
         self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
         # The result of the fetch is a list of tuples, e.g., [('table1',), ('table2',)]
@@ -237,7 +244,8 @@ class DatabaseManager:
         Returns a DataFrame with columns: cid, name, type, notnull, dflt_value, pk
         """
         if not self.conn:
-            raise sqlite3.Error("Database connection is not open.")
+            _LOGGER.error("Database connection is not open.")
+            raise sqlite3.Error()
 
         # PRAGMA is a special SQL command in SQLite for database metadata
         return pd.read_sql_query(f'PRAGMA table_info("{table_name}");', self.conn)
@@ -257,7 +265,8 @@ class DatabaseManager:
             column are unique.
         """
         if not self.cursor:
-            raise sqlite3.Error("Database connection is not open.")
+            _LOGGER.error("Database connection is not open.")
+            raise sqlite3.Error()
 
         index_name = f"idx_{table_name}_{column_name}"
         unique_clause = "UNIQUE" if unique else ""
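Every public method in these hunks is guarded by the same connection check, which implies the class is meant to be used inside its own context manager. A hypothetical usage sketch; the constructor argument and the `insert_many` parameter names are assumptions, while `create_table` and `list_tables` match the signatures visible above:

from ml_tools.SQL import DatabaseManager  # import path per the file listing

# __enter__ opens the connection; __exit__ commits, or rolls back on error.
with DatabaseManager("experiments.db") as db:
    db.create_table("results", schema={"id": "INTEGER PRIMARY KEY", "score": "REAL"})
    db.insert_many("results", data=[{"score": 0.91}, {"score": 0.88}])
    print(db.list_tables())  # ['results']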
ml_tools/VIF_factor.py CHANGED
@@ -55,19 +55,19 @@ def compute_vif(
         sanitized_columns = df.select_dtypes(include='number').columns.tolist()
         missing_features = set(ground_truth_cols) - set(sanitized_columns)
         if missing_features:
-            _LOGGER.warning(f"⚠️ These columns are not Numeric:\n{missing_features}")
+            _LOGGER.warning(f"These columns are not Numeric:\n{missing_features}")
     else:
         sanitized_columns = list()
         for feature in use_columns:
             if feature not in ground_truth_cols:
-                _LOGGER.warning(f"⚠️ The provided column '{feature}' is not in the DataFrame.")
+                _LOGGER.warning(f"The provided column '{feature}' is not in the DataFrame.")
             else:
                 sanitized_columns.append(feature)
 
     if ignore_columns is not None and use_columns is None:
         missing_ignore = set(ignore_columns) - set(ground_truth_cols)
         if missing_ignore:
-            _LOGGER.warning(f"⚠️ Warning: The following 'columns to ignore' are not found in the Dataframe:\n{missing_ignore}")
+            _LOGGER.warning(f"The following 'columns to ignore' are not found in the Dataframe:\n{missing_ignore}")
         sanitized_columns = [f for f in sanitized_columns if f not in ignore_columns]
 
     X = df[sanitized_columns].copy()
@@ -138,7 +138,7 @@ def compute_vif(
             filename += ".svg"
         full_save_path = save_path / filename
         plt.savefig(full_save_path, format='svg', bbox_inches='tight')
-        _LOGGER.info(f" Saved VIF plot: '{filename}'")
+        _LOGGER.info(f"📊 Saved VIF plot: '{filename}'")
 
     if show_plot:
         plt.show()
@@ -163,7 +163,8 @@ def drop_vif_based(df: pd.DataFrame, vif_df: pd.DataFrame, threshold: float = 10
     """
     # Ensure expected structure
     if 'feature' not in vif_df.columns or 'VIF' not in vif_df.columns:
-        raise ValueError("'vif_df' must contain 'feature' and 'VIF' columns.")
+        _LOGGER.error("'vif_df' must contain 'feature' and 'VIF' columns.")
+        raise ValueError()
 
     # Identify features to drop
     to_drop = vif_df[vif_df["VIF"] > threshold]["feature"].tolist()
@@ -177,7 +178,7 @@ def drop_vif_based(df: pd.DataFrame, vif_df: pd.DataFrame, threshold: float = 10
         result_df = df.drop(columns=to_drop)
 
         if result_df.empty:
-            _LOGGER.warning(f"⚠️ All columns were dropped.")
+            _LOGGER.warning(f"All columns were dropped.")
 
         return result_df, to_drop
 
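For context, a hypothetical end-to-end call: `compute_vif`'s argument list and defaults are assumptions, while `drop_vif_based(df, vif_df, threshold)` and its `(result_df, to_drop)` return value follow the signature visible in the hunks:

import numpy as np
import pandas as pd

from ml_tools.VIF_factor import compute_vif, drop_vif_based  # per the file listing

rng = np.random.default_rng(0)
df = pd.DataFrame({"a": rng.normal(size=100), "c": rng.normal(size=100)})
df["b"] = df["a"] * 2 + rng.normal(scale=0.01, size=100)  # nearly collinear with "a"

vif_df = compute_vif(df)  # expected to return a DataFrame with 'feature' and 'VIF' columns
reduced_df, dropped = drop_vif_based(df, vif_df, threshold=10.0)
print(dropped)  # one of the collinear pair should exceed the threshold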
ml_tools/_logger.py CHANGED
@@ -1,6 +1,73 @@
 import logging
 import sys
 
+# Step 1: Conditionally import colorlog
+try:
+    import colorlog # type: ignore
+except ImportError:
+    colorlog = None
+
+
+# --- Centralized Configuration ---
+LEVEL_EMOJIS = {
+    logging.INFO: "✅",
+    logging.WARNING: "⚠️ ",
+    logging.ERROR: "🚨",
+    logging.CRITICAL: "❌"
+}
+
+# Define base format strings.
+BASE_INFO_FORMAT = '\n🐉 %(asctime)s [%(emoji)s %(levelname)s] - %(message)s'
+BASE_WARN_FORMAT = '\n🐉 %(asctime)s [%(emoji)s %(levelname)s] [%(filename)s:%(lineno)d] - %(message)s'
+
+
+# --- Unified Formatter ---
+# Determine the base class and format strings based on colorlog availability
+if colorlog:
+    # If colorlog is available, use it as the base and use colorized formats.
+    _BaseFormatter = colorlog.ColoredFormatter
+    _INFO_FORMAT = BASE_INFO_FORMAT.replace('%(levelname)s', '%(log_color)s%(levelname)s%(reset)s')
+    _WARN_FORMAT = BASE_WARN_FORMAT.replace('%(levelname)s', '%(log_color)s%(levelname)s%(reset)s')
+else:
+    # Otherwise, fall back to the standard logging.Formatter.
+    _BaseFormatter = logging.Formatter
+    _INFO_FORMAT = BASE_INFO_FORMAT
+    _WARN_FORMAT = BASE_WARN_FORMAT
+
+
+class _UnifiedFormatter(_BaseFormatter): # type: ignore
+    """
+    A unified log formatter that adds emojis, uses level-specific formats,
+    and applies colors if colorlog is available.
+    """
+
+    def __init__(self, *args, **kwargs):
+        """Initializes the formatter, creating sub-formatters for each level."""
+        # The base class __init__ is called implicitly. We prepare our custom formatters here.
+        self.datefmt = kwargs.get('datefmt')
+
+        # We need to pass the correct arguments to the correct formatter type
+        if colorlog:
+            log_colors = kwargs.get('log_colors', {})
+            self.info_formatter = colorlog.ColoredFormatter(_INFO_FORMAT, datefmt=self.datefmt, log_colors=log_colors)
+            self.warn_formatter = colorlog.ColoredFormatter(_WARN_FORMAT, datefmt=self.datefmt, log_colors=log_colors)
+        else:
+            self.info_formatter = logging.Formatter(_INFO_FORMAT, datefmt=self.datefmt)
+            self.warn_formatter = logging.Formatter(_WARN_FORMAT, datefmt=self.datefmt)
+
+    def format(self, record):
+        """Adds a custom emoji attribute to the record before formatting."""
+        # Add the new attribute to the record. Use .get() for a safe default.
+        record.emoji = LEVEL_EMOJIS.get(record.levelno, "")
+
+        # Select the appropriate formatter and let it handle the rest.
+        if record.levelno >= logging.WARNING:
+            return self.warn_formatter.format(record)
+        else:
+            return self.info_formatter.format(record)
+
 
 def _get_logger(name: str = "ml_tools", level: int = logging.INFO):
     """
@@ -9,6 +76,7 @@ def _get_logger(name: str = "ml_tools", level: int = logging.INFO):
     - `logger.info()`
     - `logger.warning()`
     - `logger.error()` the program can potentially recover.
+    - `logger.exception()` inside an except block.
     - `logger.critical()` the program is going to crash.
     """
     logger = logging.getLogger(name)
@@ -16,15 +84,26 @@ def _get_logger(name: str = "ml_tools", level: int = logging.INFO):
 
     # Prevents adding handlers multiple times if the function is called again
     if not logger.handlers:
-        handler = logging.StreamHandler(sys.stdout)
-
-        # Define the format string and the date format separately
-        log_format = '\n🐉%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-        date_format = '%Y-%m-%d %H:%M' # Format: Year-Month-Day Hour:Minute
+        # Prepare arguments for the unified formatter
+        formatter_kwargs = {
+            'datefmt': '%Y-%m-%d %H:%M'
+        }
 
-        # Pass both the format and the date format to the Formatter
-        formatter = logging.Formatter(log_format, datefmt=date_format)
+        # Use colorlog's handler if available, and add color arguments
+        if colorlog:
+            handler = colorlog.StreamHandler()
+            formatter_kwargs["log_colors"] = { # type: ignore
+                'DEBUG': 'cyan',
+                'INFO': 'green',
+                'WARNING': 'yellow',
+                'ERROR': 'red',
+                'CRITICAL': 'red,bg_white',
+            }
+        else:
+            handler = logging.StreamHandler(sys.stdout)
 
+        # Create and set the single, unified formatter
+        formatter = _UnifiedFormatter(**formatter_kwargs)
         handler.setFormatter(formatter)
         logger.addHandler(handler)
 
@@ -32,5 +111,24 @@ def _get_logger(name: str = "ml_tools", level: int = logging.INFO):
 
     return logger
 
+
 # Create a single logger instance to be imported by other modules
 _LOGGER = _get_logger()
+
+
+def _log_and_exit(message: str, exit_code: int = 1):
+    """Logs a critical message inside an exception block and terminates the program."""
+    _LOGGER.exception(message)
+    sys.exit(exit_code)
+
+
+if __name__ == "__main__":
+    _LOGGER.info("Data loading process started.")
+    _LOGGER.warning("A non-critical configuration value is missing.")
+
+    try:
+        x = 1 / 0
+    except ZeroDivisionError:
+        _LOGGER.exception("Critical error during calculation.")
+
+    _LOGGER.critical("Total failure.")
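The updated docstring recommends `logger.exception()` inside except blocks, and the new `_log_and_exit` helper relies on it. Unlike `logger.error()`, `logger.exception()` logs at ERROR level and appends the active traceback automatically; a standard-library demonstration:

import logging
import sys

logger = logging.getLogger("demo")
logger.addHandler(logging.StreamHandler(sys.stdout))
logger.setLevel(logging.INFO)

try:
    1 / 0
except ZeroDivisionError:
    # Emits the message at ERROR level plus the traceback of the active exception.
    logger.exception("Critical error during calculation.")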
ml_tools/custom_logger.py CHANGED
@@ -76,12 +76,13 @@ def custom_logger(
             _log_exception_to_log(data, base_path.with_suffix(".log"))
 
         else:
-            raise ValueError("Unsupported data type. Must be list, dict, str, or BaseException.")
+            _LOGGER.error("Unsupported data type. Must be list, dict, str, or BaseException.")
+            raise ValueError()
 
-        _LOGGER.info(f"🗄️ Log saved to: '{base_path}'")
+        _LOGGER.info(f"Log saved to: '{base_path}'")
 
-    except Exception as e:
-        _LOGGER.error(f"Log not saved: {e}")
+    except Exception:
+        _LOGGER.exception(f"Log not saved.")
 
 
 def _log_list_to_txt(data: List[Any], path: Path) -> None:
@@ -102,7 +103,9 @@ def _log_dict_to_csv(data: Dict[Any, List[Any]], path: Path) -> None:
 
     for key, value in data.items():
         if not isinstance(value, list):
-            raise ValueError(f"Dictionary value for key '{key}' must be a list.")
+            _LOGGER.error(f"Dictionary value for key '{key}' must be a list.")
+            raise ValueError()
+
         sanitized_key = str(key).strip().replace('\n', '_').replace('\r', '_')
         padded_value = value + [None] * (max_length - len(value))
         sanitized_dict[sanitized_key] = padded_value
@@ -152,7 +155,7 @@ def save_list_strings(list_strings: list[str], directory: Union[str,Path], filen
             f.write(f"{string_data}\n")
 
     if verbose:
-        _LOGGER.info(f"Text file saved as '{full_path.name}'.")
+        _LOGGER.info(f"Text file saved as '{full_path.name}'.")
 
 
 def load_list_strings(text_file: Union[str,Path], verbose: bool=True) -> list[str]:
@@ -164,10 +167,11 @@ def load_list_strings(text_file: Union[str,Path], verbose: bool=True) -> list[st
         loaded_strings = [line.strip() for line in f]
 
         if len(loaded_strings) == 0:
-            raise ValueError("The text file is empty.")
+            _LOGGER.error("The text file is empty.")
+            raise ValueError()
 
         if verbose:
-            _LOGGER.info(f"Text file loaded as list of strings.")
+            _LOGGER.info(f"Text file loaded as list of strings.")
 
         return loaded_strings
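The `_log_dict_to_csv` hunk adds the log-then-raise guard; the unchanged lines around it sanitize keys and pad shorter lists with None so every CSV column ends up the same length. That padding step, extracted as a standalone snippet:

data = {"epoch": [1, 2, 3], "loss": [0.9, 0.5]}

max_length = max(len(v) for v in data.values())
padded = {
    str(key).strip().replace('\n', '_').replace('\r', '_'): value + [None] * (max_length - len(value))
    for key, value in data.items()
}
print(padded)  # {'epoch': [1, 2, 3], 'loss': [0.9, 0.5, None]}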