dragon-ml-toolbox 3.12.6__py3-none-any.whl → 4.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dragon-ml-toolbox was flagged as possibly problematic by the registry.

@@ -2,28 +2,24 @@ import numpy as np
  from pathlib import Path
  import xgboost as xgb
  import lightgbm as lgb
- from sklearn.ensemble import HistGradientBoostingRegressor
- from sklearn.base import ClassifierMixin
  from typing import Literal, Union, Tuple, Dict, Optional
  import pandas as pd
  from copy import deepcopy
  from .utilities import (
-     _script_info,
-     list_csv_paths,
      threshold_binary_values,
      threshold_binary_values_batch,
-     deserialize_object,
-     list_files_by_extension,
-     save_dataframe,
-     make_fullpath,
-     yield_dataframes_from_dir,
-     sanitize_filename)
+     deserialize_object,
+     yield_dataframes_from_dir)
+ from .path_manager import sanitize_filename, make_fullpath, list_files_by_extension, list_csv_paths
  import torch
  from tqdm import trange
  import matplotlib.pyplot as plt
  import seaborn as sns
- from .logger import _LOGGER
+ from ._logger import _LOGGER
  from .keys import ModelSaveKeys
+ from ._script_info import _script_info
+ from .SQL import DatabaseManager
+ from contextlib import nullcontext


  __all__ = [
@@ -125,7 +121,7 @@ class ObjectiveFunction():
          return features_array * noise

      def check_model(self):
-         if isinstance(self.model, ClassifierMixin) or isinstance(self.model, xgb.XGBClassifier) or isinstance(self.model, lgb.LGBMClassifier):
+         if isinstance(self.model, xgb.XGBClassifier) or isinstance(self.model, lgb.LGBMClassifier):
              raise ValueError(f"[Model Check Failed] ❌\nThe loaded model ({type(self.model).__name__}) is a Classifier.\nOptimization is not suitable for standard classification tasks.")
          if self.model is None:
              raise ValueError("Loaded model is None")
@@ -187,45 +183,73 @@ def _set_feature_names(size: int, names: Union[list[str], None]):
      else:
          assert len(names) == size, "List with feature names do not match the number of features"
      return names
-

- def _save_results(*dicts, save_dir: Union[str,Path], target_name: str):
-     combined_dict = dict()
-     for single_dict in dicts:
-         combined_dict.update(single_dict)
+
+ def _save_result(result_dict: dict,
+                  save_format: Literal['csv', 'sqlite', 'both'],
+                  csv_path: Path,
+                  db_manager: Optional[DatabaseManager] = None,
+                  db_table_name: Optional[str] = None):
+     """
+     Handles saving a single result to CSV, SQLite, or both.
+     """
+     # Save to CSV
+     if save_format in ['csv', 'both']:
+         _save_or_append_to_csv(result_dict, csv_path)
+
+     # Save to SQLite
+     if save_format in ['sqlite', 'both']:
+         if db_manager and db_table_name:
+             db_manager.insert_row(db_table_name, result_dict)
+         else:
+             _LOGGER.warning("SQLite saving requested but db_manager or table_name not provided.")
+
+
+ def _save_or_append_to_csv(data_dict: dict, save_path: Path):
+     """
+     Saves or appends a dictionary of data as a single row to a CSV file.
+
+     If the file doesn't exist, it creates it and writes the header.
+     If the file exists, it appends the new data without the header.
+     """
+     df_row = pd.DataFrame([data_dict])

-     df = pd.DataFrame(combined_dict)
+     file_exists = save_path.exists()

-     save_dataframe(df=df, save_dir=save_dir, filename=f"Optimization_{target_name}")
+     df_row.to_csv(
+         save_path,
+         mode='a',               # 'a' for append mode
+         index=False,            # Don't write the DataFrame index
+         header=not file_exists  # Write header only if file does NOT exist
+     )


- def _run_single_pso(objective_function: ObjectiveFunction, pso_args: dict, feature_names: list[str], target_name: str, random_state: int):
-     """Helper for a single PSO run."""
+ def _run_single_pso(objective_function: ObjectiveFunction, pso_args: dict, feature_names: list[str], target_name: str, random_state: int, save_format: Literal['csv', 'sqlite', 'both'], csv_path: Path, db_manager: Optional[DatabaseManager], db_table_name: str):
+     """Helper for a single PSO run that also handles saving."""
      pso_args.update({"seed": random_state})

      best_features, best_target, *_ = _pso(**pso_args)

-     # Flip best_target if maximization was used
      if objective_function.task == "maximization":
          best_target = -best_target

-     # Threshold binary features
      binary_number = objective_function.binary_features
      best_features_threshold = threshold_binary_values(best_features, binary_number)

-     # Name features and target
      best_features_named = {name: value for name, value in zip(feature_names, best_features_threshold)}
      best_target_named = {target_name: best_target}

+     # Save the result using the new helper
+     combined_dict = {**best_features_named, **best_target_named}
+     _save_result(combined_dict, save_format, csv_path, db_manager, db_table_name)
+
      return best_features_named, best_target_named


- def _run_post_hoc_pso(objective_function: ObjectiveFunction, pso_args: dict, feature_names: list[str], target_name: str, repetitions: int):
-     """Helper for post-hoc PSO analysis."""
-     all_best_targets = []
-     all_best_features = [[] for _ in range(len(feature_names))]
-
-     for _ in range(repetitions):
+ def _run_post_hoc_pso(objective_function: ObjectiveFunction, pso_args: dict, feature_names: list[str], target_name: str, repetitions: int, save_format: Literal['csv', 'sqlite', 'both'], csv_path: Path, db_manager: Optional[DatabaseManager], db_table_name: str):
+     """Helper for post-hoc analysis that saves results incrementally."""
+     progress = trange(repetitions, desc="Post-Hoc PSO", unit="run")
+     for _ in progress:
          best_features, best_target, *_ = _pso(**pso_args)

          if objective_function.task == "maximization":
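
The new `_save_or_append_to_csv` helper above leans on pandas' append mode, writing the header only when the file is first created, so repeated runs accumulate one row each. A minimal standalone sketch of the same pattern (the file name and fields here are illustrative, not from the package):

```python
from pathlib import Path
import pandas as pd

def append_row(row: dict, path: Path) -> None:
    # Header is written only when the file does not exist yet,
    # so each call adds exactly one row to a growing results log.
    pd.DataFrame([row]).to_csv(path, mode="a", index=False, header=not path.exists())

csv_path = Path("demo_results.csv")                       # hypothetical file
append_row({"feature_a": 0.1, "Target": 3.2}, csv_path)   # creates file + header
append_row({"feature_a": 0.4, "Target": 2.9}, csv_path)   # appends without header
```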
@@ -234,28 +258,25 @@ def _run_post_hoc_pso(objective_function: ObjectiveFunction, pso_args: dict, fea
          binary_number = objective_function.binary_features
          best_features_threshold = threshold_binary_values(best_features, binary_number)

-         for i, best_feature in enumerate(best_features_threshold):
-             all_best_features[i].append(best_feature)
-         all_best_targets.append(best_target)
-
-     # Name features and target
-     all_best_features_named = {name: lst for name, lst in zip(feature_names, all_best_features)}
-     all_best_targets_named = {target_name: all_best_targets}
-
-     return all_best_features_named, all_best_targets_named
+         result_dict = {name: value for name, value in zip(feature_names, best_features_threshold)}
+         result_dict[target_name] = best_target
+
+         # Save each result incrementally
+         _save_result(result_dict, save_format, csv_path, db_manager, db_table_name)


  def run_pso(lower_boundaries: list[float],
              upper_boundaries: list[float],
              objective_function: ObjectiveFunction,
              save_results_dir: Union[str,Path],
+             save_format: Literal['csv', 'sqlite', 'both'] = 'csv',
              auto_binary_boundaries: bool=True,
              target_name: Union[str, None]=None,
              feature_names: Union[list[str], None]=None,
              swarm_size: int=200,
              max_iterations: int=3000,
              random_state: int=101,
-             post_hoc_analysis: Optional[int]=10) -> Tuple[Dict[str, float | list[float]], Dict[str, float | list[float]]]:
+             post_hoc_analysis: Optional[int]=10) -> Optional[Tuple[Dict[str, float], Dict[str, float]]]:
      """
      Executes Particle Swarm Optimization (PSO) to optimize a given objective function and saves the results as a CSV file.

@@ -269,6 +290,11 @@ def run_pso(lower_boundaries: list[float],
          A callable object encapsulating a tree-based regression model.
      save_results_dir : str | Path
          Directory path to save the results CSV file.
+     save_format : {'csv', 'sqlite', 'both'}, default 'csv'
+         The format for saving optimization results.
+         - 'csv': Saves results to a CSV file.
+         - 'sqlite': Saves results to an SQLite database file. ⚠️ If a database exists, new tables will be created using the target name.
+         - 'both': Saves results to both formats.
      auto_binary_boundaries : bool
          Use `ObjectiveFunction.binary_features` to append as many binary boundaries as needed to `lower_boundaries` and `upper_boundaries` automatically.
      target_name : str or None, optional
@@ -284,14 +310,11 @@

      Returns
      -------
-     Tuple[Dict[str, float | list[float]], Dict[str, float | list[float]]]
-         If `post_hoc_analysis` is None, returns two dictionaries:
-         - feature_names: Feature values (after inverse scaling) that yield the best result.
-         - target_name: Best result obtained for the target variable.
-
-         If `post_hoc_analysis` is an integer, returns two dictionaries:
-         - feature_names: Lists of best feature values (after inverse scaling) for each repetition.
-         - target_name: List of best target values across repetitions.
+     Tuple[Dict[str, float], Dict[str, float]] or None
+         - If `post_hoc_analysis` is None, returns two dictionaries containing the
+           single best features and the corresponding target value.
+         - If `post_hoc_analysis` is active, results are streamed directly to a CSV file
+           and this function returns `None`.

      Notes
      -----
@@ -316,8 +339,9 @@
      # Append binary boundaries
      binary_number = objective_function.binary_features
      if auto_binary_boundaries and binary_number > 0:
-         local_lower_boundaries.extend([0] * binary_number)
-         local_upper_boundaries.extend([1] * binary_number)
+         # simplify binary search by constraining range
+         local_lower_boundaries.extend([0.45] * binary_number)
+         local_upper_boundaries.extend([0.55] * binary_number)

      # Set the total length of features
      size_of_features = len(local_lower_boundaries)
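
The switch from [0, 1] to [0.45, 0.55] boundaries narrows the continuous range the swarm explores for binary features; the candidate values are still snapped to 0 or 1 afterwards by `threshold_binary_values`. A toy sketch of that thresholding step, assuming a round-at-0.5 rule (the helper's exact implementation is not shown in this diff):

```python
import numpy as np

def threshold_last_n(values: np.ndarray, n_binary: int) -> np.ndarray:
    # Assumed behaviour: the last n features are binary and get rounded at 0.5.
    out = values.copy()
    if n_binary > 0:
        out[-n_binary:] = (out[-n_binary:] >= 0.5).astype(float)
    return out

particle = np.array([12.7, 0.46, 0.53])        # one continuous + two binary features
print(threshold_last_n(particle, n_binary=2))  # -> [12.7, 0.0, 1.0]
```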
@@ -333,7 +357,25 @@
      if target_name is None and objective_function.target_name is not None:
          target_name = objective_function.target_name
      if target_name is None:
-         target_name = "Target"
+         raise ValueError(f"'target' name was not provided and was not found in the .joblib object.")
+
+     # --- Setup: Saving Infrastructure ---
+     sanitized_target_name = sanitize_filename(target_name)
+     save_dir_path = make_fullpath(save_results_dir, make=True, enforce="directory")
+     base_filename = f"Optimization_{sanitized_target_name}"
+     csv_path = save_dir_path / f"{base_filename}.csv"
+     db_path = save_dir_path / "Optimization.db"
+     db_table_name = f"{sanitized_target_name}"
+
+     if save_format in ['sqlite', 'both']:
+         # Dynamically create the schema for the database table
+         schema = {name: "REAL" for name in names}
+         schema[target_name] = "REAL"
+         schema = {"result_id": "INTEGER PRIMARY KEY AUTOINCREMENT", **schema}
+
+         # Create table
+         with DatabaseManager(db_path) as db:
+             db.create_table(db_table_name, schema)

      pso_arguments = {
          "func":objective_function,
@@ -345,17 +387,29 @@
          "particle_output": False,
      }

-     # Dispatcher
-     if post_hoc_analysis is None or post_hoc_analysis <= 1:
-         features, target = _run_single_pso(objective_function, pso_arguments, names, target_name, random_state)
-     else:
-         features, target = _run_post_hoc_pso(objective_function, pso_arguments, names, target_name, post_hoc_analysis)
-
-     # --- Save Results ---
-     save_results_path = make_fullpath(save_results_dir, make=True)
-     _save_results(features, target, save_dir=save_results_path, target_name=target_name)
-
-     return features, target # type: ignore
+     # --- Dispatcher ---
+     # Use a real or dummy context manager to handle the DB connection cleanly
+     db_context = DatabaseManager(db_path) if save_format in ['sqlite', 'both'] else nullcontext()
+
+     with db_context as db_manager:
+         if post_hoc_analysis is None or post_hoc_analysis <= 1:
+             # --- Single Run Logic ---
+             features_dict, target_dict = _run_single_pso(
+                 objective_function, pso_arguments, names, target_name, random_state,
+                 save_format, csv_path, db_manager, db_table_name
+             )
+             _LOGGER.info(f"✅ Single optimization complete.")
+             return features_dict, target_dict
+
+         else:
+             # --- Post-Hoc Analysis Logic ---
+             _LOGGER.info(f"🏁 Starting post-hoc analysis with {post_hoc_analysis} repetitions...")
+             _run_post_hoc_pso(
+                 objective_function, pso_arguments, names, target_name, post_hoc_analysis,
+                 save_format, csv_path, db_manager, db_table_name
+             )
+             _LOGGER.info("✅ Post-hoc analysis complete. Results saved.")
+             return None


  def _pso(func: ObjectiveFunction,
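
Taken together, these changes turn `run_pso` into a streaming pipeline: each repetition is written out as soon as it finishes, to CSV, SQLite, or both, instead of being accumulated in memory and saved once at the end. A hedged usage sketch; the import path of `run_pso` and the `ObjectiveFunction` constructor arguments are assumptions, since neither appears in this section of the diff:

```python
from ml_tools.PSO_optimization import run_pso, ObjectiveFunction  # assumed module name

# Hypothetical setup: a serialized regressor with two continuous features.
objective = ObjectiveFunction("model_artifacts/model.joblib", task="maximization")  # assumed signature

run_pso(
    lower_boundaries=[0.0, 10.0],
    upper_boundaries=[1.0, 100.0],
    objective_function=objective,
    save_results_dir="pso_results",
    save_format="both",     # new in 4.1.0: CSV plus an Optimization.db SQLite file
    post_hoc_analysis=50,   # streams 50 runs incrementally and returns None
)
```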
ml_tools/RNN_forecast.py CHANGED
@@ -1,6 +1,7 @@
  import torch
  from torch import nn
  import numpy as np
+ from ._script_info import _script_info

  __all__ = [
      "rnn_forecast"
@@ -47,3 +48,7 @@ def rnn_forecast(model: nn.Module, start_sequence: torch.Tensor, steps: int, dev

      # Concatenate all predictions and flatten the array for easy use
      return np.concatenate(predictions).flatten()
+
+
+ def info():
+     _script_info
ml_tools/SQL.py ADDED
@@ -0,0 +1,272 @@
+ import sqlite3
+ import pandas as pd
+ from pathlib import Path
+ from typing import Union, Dict, Any, Optional, List, Literal
+ from ._logger import _LOGGER
+ from ._script_info import _script_info
+ from .path_manager import make_fullpath
+
+
+ __all__ = [
+     "DatabaseManager",
+ ]
+
+
+ class DatabaseManager:
+     """
+     A user-friendly context manager for handling SQLite database operations.
+
+     This class abstracts the underlying sqlite3 connection and cursor management,
+     providing simple methods to execute queries, create tables, and handle data
+     insertion and retrieval using pandas DataFrames.
+
+     Parameters
+     ----------
+     db_path : Union[str, Path]
+         The file path to the SQLite database. If the file does not exist,
+         it will be created upon connection.
+
+     Example
+     -------
+     >>> schema = {
+     ...     "id": "INTEGER PRIMARY KEY AUTOINCREMENT",
+     ...     "run_name": "TEXT NOT NULL",
+     ...     "feature_a": "REAL",
+     ...     "score": "REAL"
+     ... }
+     >>> with DatabaseManager("my_results.db") as db:
+     ...     db.create_table("experiments", schema)
+     ...     data = {"run_name": "first_run", "feature_a": 0.123, "score": 95.5}
+     ...     db.insert_row("experiments", data)
+     ...     df = db.query_to_dataframe("SELECT * FROM experiments")
+     ...     print(df)
+     """
+     def __init__(self, db_path: Union[str, Path]):
+         """Initializes the DatabaseManager with the path to the database file."""
+         if isinstance(db_path, str):
+             if not db_path.endswith(".db"):
+                 db_path = db_path + ".db"
+         elif isinstance(db_path, Path):
+             if db_path.suffix != ".db":
+                 db_path = db_path.with_suffix(".db")
+
+         self.db_path = make_fullpath(db_path, make=True, enforce="file")
+         self.conn: Optional[sqlite3.Connection] = None
+         self.cursor: Optional[sqlite3.Cursor] = None
+
+     def __enter__(self):
+         """Establishes the database connection and returns the manager instance."""
+         try:
+             self.conn = sqlite3.connect(self.db_path)
+             self.cursor = self.conn.cursor()
+             _LOGGER.info(f"✅ Successfully connected to database: {self.db_path}")
+             return self
+         except sqlite3.Error as e:
+             _LOGGER.error(f"❌ Database connection failed: {e}")
+             raise  # Re-raise the exception after logging
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         """Commits changes and closes the database connection."""
+         if self.conn:
+             if exc_type:  # If an exception occurred, rollback
+                 self.conn.rollback()
+                 _LOGGER.warning("⚠️ Rolling back transaction due to an error.")
+             else:  # Otherwise, commit the transaction
+                 self.conn.commit()
+             self.conn.close()
+             _LOGGER.info(f"❇️ Database connection closed: {self.db_path.name}")
+
+     def create_table(self, table_name: str, schema: Dict[str, str], if_not_exists: bool = True):
+         """
+         Creates a new table in the database based on a provided schema.
+
+         Parameters
+         ----------
+         table_name : str
+             The name of the table to create.
+         schema : Dict[str, str]
+             A dictionary where keys are column names and values are their SQL data types
+             (e.g., {"id": "INTEGER PRIMARY KEY", "name": "TEXT NOT NULL"}).
+         if_not_exists : bool, default=True
+             If True, adds "IF NOT EXISTS" to the SQL statement to prevent errors
+             if the table already exists.
+         """
+         if not self.cursor:
+             raise sqlite3.Error("Database connection is not open.")
+
+         columns_def = ", ".join([f'"{col_name}" {col_type}' for col_name, col_type in schema.items()])
+         exists_clause = "IF NOT EXISTS" if if_not_exists else ""
+
+         query = f"CREATE TABLE {exists_clause} {table_name} ({columns_def})"
+
+         _LOGGER.info(f"🗂️ Executing: {query}")
+         self.cursor.execute(query)
+
+     def insert_row(self, table_name: str, data: Dict[str, Any]):
+         """
+         Inserts a single row of data into the specified table.
+
+         Parameters
+         ----------
+         table_name : str
+             The name of the target table.
+         data : Dict[str, Any]
+             A dictionary where keys correspond to column names and values are the
+             data to be inserted.
+         """
+         if not self.cursor:
+             raise sqlite3.Error("Database connection is not open.")
+
+         columns = ', '.join(f'"{k}"' for k in data.keys())
+         placeholders = ', '.join(['?'] * len(data))
+         values = list(data.values())
+
+         query = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
+
+         self.cursor.execute(query, values)
+
+     def query_to_dataframe(self, query: str, params: Optional[tuple] = None) -> pd.DataFrame:
+         """
+         Executes a SELECT query and returns the results as a pandas DataFrame.
+
+         Parameters
+         ----------
+         query : str
+             The SQL SELECT statement to execute.
+         params : Optional[tuple], default=None
+             An optional tuple of parameters to pass to the query for safety
+             against SQL injection.
+
+         Returns
+         -------
+         pd.DataFrame
+             A DataFrame containing the query results.
+         """
+         if not self.conn:
+             raise sqlite3.Error("Database connection is not open.")
+
+         return pd.read_sql_query(query, self.conn, params=params)
+
+     def execute_sql(self, query: str, params: Optional[tuple] = None):
+         """
+         Executes an arbitrary SQL command that does not return data (e.g., UPDATE, DELETE).
+
+         Parameters
+         ----------
+         query : str
+             The SQL statement to execute.
+         params : Optional[tuple], default=None
+             An optional tuple of parameters for the query.
+         """
+         if not self.cursor:
+             raise sqlite3.Error("Database connection is not open.")
+
+         self.cursor.execute(query, params if params else ())
+
+     def insert_many(self, table_name: str, data: List[Dict[str, Any]]):
+         """
+         Inserts multiple rows into the specified table in a single, efficient transaction.
+
+         Parameters
+         ----------
+         table_name : str
+             The name of the target table.
+         data : List[Dict[str, Any]]
+             A list of dictionaries, where each dictionary represents a row to be inserted.
+             All dictionaries should have the same keys.
+         """
+         if not self.cursor:
+             raise sqlite3.Error("Database connection is not open.")
+         if not data:
+             _LOGGER.warning("⚠️ insert_many called with empty data list. No action taken.")
+             return
+
+         # Assume all dicts have the same keys as the first one
+         first_row = data[0]
+         columns = ', '.join(f'"{k}"' for k in first_row.keys())
+         placeholders = ', '.join(['?'] * len(first_row))
+
+         # Create a list of tuples, where each tuple is a row of values
+         values_to_insert = [list(row.values()) for row in data]
+
+         query = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
+
+         self.cursor.executemany(query, values_to_insert)
+         _LOGGER.info(f"✅ Bulk inserted {len(values_to_insert)} rows into '{table_name}'.")
+
+     def insert_from_dataframe(self, table_name: str, df: pd.DataFrame, if_exists: Literal['fail', 'replace', 'append'] = 'append'):
+         """
+         Writes records from a pandas DataFrame to the specified SQL table.
+
+         Parameters
+         ----------
+         table_name : str
+             The name of the target SQL table.
+         df : pd.DataFrame
+             The DataFrame to be written.
+         if_exists : str, default 'append'
+             How to behave if the table already exists.
+             - 'fail': Raise a ValueError.
+             - 'replace': Drop the table before inserting new values.
+             - 'append': Insert new values to the existing table.
+         """
+         if not self.conn:
+             raise sqlite3.Error("Database connection is not open.")
+
+         df.to_sql(
+             table_name,
+             self.conn,
+             if_exists=if_exists,
+             index=False  # Typically, we don't want to save the DataFrame index
+         )
+         _LOGGER.info(f"✅ Wrote {len(df)} rows from DataFrame to table '{table_name}' using mode '{if_exists}'.")
+
+     def list_tables(self) -> List[str]:
+         """Returns a list of all table names in the database."""
+         if not self.cursor:
+             raise sqlite3.Error("Database connection is not open.")
+
+         self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
+         # The result of the fetch is a list of tuples, e.g., [('table1',), ('table2',)]
+         return [table[0] for table in self.cursor.fetchall()]
+
+     def get_table_schema(self, table_name: str) -> pd.DataFrame:
+         """
+         Retrieves the schema of a specific table and returns it as a DataFrame.
+
+         Returns a DataFrame with columns: cid, name, type, notnull, dflt_value, pk
+         """
+         if not self.conn:
+             raise sqlite3.Error("Database connection is not open.")
+
+         # PRAGMA is a special SQL command in SQLite for database metadata
+         return pd.read_sql_query(f'PRAGMA table_info("{table_name}");', self.conn)
+
+     def create_index(self, table_name: str, column_name: str, unique: bool = False):
+         """
+         Creates an index on a column of a specified table to speed up queries.
+
+         Parameters
+         ----------
+         table_name : str
+             The name of the table containing the column.
+         column_name : str
+             The name of the column to be indexed.
+         unique : bool, default=False
+             If True, creates a unique index, which ensures all values in the
+             column are unique.
+         """
+         if not self.cursor:
+             raise sqlite3.Error("Database connection is not open.")
+
+         index_name = f"idx_{table_name}_{column_name}"
+         unique_clause = "UNIQUE" if unique else ""
+
+         query = f"CREATE {unique_clause} INDEX IF NOT EXISTS {index_name} ON {table_name} ({column_name})"
+
+         _LOGGER.info(f"🗂️ Executing: {query}")
+         self.cursor.execute(query)
+
+
+ def info():
+     _script_info(__all__)
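
Beyond the docstring example, the class also supports bulk inserts and parameterized reads. A short sketch combining `insert_many` with a parameterized `query_to_dataframe`, assuming the `experiments` table from the docstring example already exists (table and column names are illustrative):

```python
from ml_tools.SQL import DatabaseManager

rows = [
    {"run_name": "sweep_1", "feature_a": 0.10, "score": 91.2},
    {"run_name": "sweep_2", "feature_a": 0.25, "score": 94.7},
]

with DatabaseManager("my_results.db") as db:
    db.insert_many("experiments", rows)  # single executemany transaction
    # '?' placeholders keep user input out of the SQL string itself.
    top = db.query_to_dataframe(
        "SELECT run_name, score FROM experiments WHERE score >= ?", params=(93.0,)
    )
    print(top)
```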
ml_tools/VIF_factor.py CHANGED
@@ -7,9 +7,10 @@ from statsmodels.stats.outliers_influence import variance_inflation_factor
  from statsmodels.tools.tools import add_constant
  import warnings
  from pathlib import Path
- from .utilities import sanitize_filename, yield_dataframes_from_dir, save_dataframe, _script_info, make_fullpath
- from .logger import _LOGGER
-
+ from .utilities import yield_dataframes_from_dir, save_dataframe
+ from .path_manager import sanitize_filename, make_fullpath
+ from ._logger import _LOGGER
+ from ._script_info import _script_info

  __all__ = [
      "compute_vif",
ml_tools/_logger.py ADDED
@@ -0,0 +1,36 @@
+ import logging
+ import sys
+
+
+ def _get_logger(name: str = "ml_tools", level: int = logging.INFO):
+     """
+     Initializes and returns a configured logger instance.
+
+     - `logger.info()`
+     - `logger.warning()`
+     - `logger.error()` the program can potentially recover.
+     - `logger.critical()` the program is going to crash.
+     """
+     logger = logging.getLogger(name)
+     logger.setLevel(level)
+
+     # Prevents adding handlers multiple times if the function is called again
+     if not logger.handlers:
+         handler = logging.StreamHandler(sys.stdout)
+
+         # Define the format string and the date format separately
+         log_format = '\n🐉%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+         date_format = '%Y-%m-%d %H:%M'  # Format: Year-Month-Day Hour:Minute
+
+         # Pass both the format and the date format to the Formatter
+         formatter = logging.Formatter(log_format, datefmt=date_format)
+
+         handler.setFormatter(formatter)
+         logger.addHandler(handler)
+
+     logger.propagate = False
+
+     return logger
+
+ # Create a single logger instance to be imported by other modules
+ _LOGGER = _get_logger()
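
Every module touched in this release now imports the shared `_LOGGER` singleton instead of configuring its own logging. A minimal sketch of the intended usage pattern, following the severity guide in the docstring above:

```python
from ml_tools._logger import _LOGGER

_LOGGER.info("Starting optimization run...")       # routine progress messages
_LOGGER.warning("Falling back to default seed.")   # something unexpected but survivable
_LOGGER.error("Could not open results database.")  # the program can potentially recover
```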
@@ -1,6 +1,6 @@
  import torch
  from torch import nn
- from .utilities import _script_info
+ from ._script_info import _script_info


  __all__ = [
@@ -0,0 +1,8 @@
+
+ def _script_info(all_data: list[str]):
+     """
+     List available names.
+     """
+     print("Available functions and objects:")
+     for i, name in enumerate(all_data, start=1):
+         print(f"{i} - {name}")