dragon-ml-toolbox 19.14.0__py3-none-any.whl → 20.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
  2. dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
  3. ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
  4. ml_tools/ETL_cleaning/_basic_clean.py +351 -0
  5. ml_tools/ETL_cleaning/_clean_tools.py +128 -0
  6. ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
  7. ml_tools/ETL_cleaning/_imprimir.py +13 -0
  8. ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
  9. ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
  10. ml_tools/ETL_engineering/_imprimir.py +24 -0
  11. ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
  12. ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
  13. ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
  14. ml_tools/GUI_tools/_imprimir.py +12 -0
  15. ml_tools/IO_tools/_IO_loggers.py +235 -0
  16. ml_tools/IO_tools/_IO_save_load.py +151 -0
  17. ml_tools/IO_tools/_IO_utils.py +140 -0
  18. ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
  19. ml_tools/IO_tools/_imprimir.py +14 -0
  20. ml_tools/MICE/_MICE_imputation.py +132 -0
  21. ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
  22. ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
  23. ml_tools/MICE/_imprimir.py +11 -0
  24. ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
  25. ml_tools/ML_callbacks/_base.py +101 -0
  26. ml_tools/ML_callbacks/_checkpoint.py +232 -0
  27. ml_tools/ML_callbacks/_early_stop.py +208 -0
  28. ml_tools/ML_callbacks/_imprimir.py +12 -0
  29. ml_tools/ML_callbacks/_scheduler.py +197 -0
  30. ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
  31. ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
  32. ml_tools/ML_chain/_dragon_chain.py +140 -0
  33. ml_tools/ML_chain/_imprimir.py +11 -0
  34. ml_tools/ML_configuration/__init__.py +90 -0
  35. ml_tools/ML_configuration/_base_model_config.py +69 -0
  36. ml_tools/ML_configuration/_finalize.py +366 -0
  37. ml_tools/ML_configuration/_imprimir.py +47 -0
  38. ml_tools/ML_configuration/_metrics.py +593 -0
  39. ml_tools/ML_configuration/_models.py +206 -0
  40. ml_tools/ML_configuration/_training.py +124 -0
  41. ml_tools/ML_datasetmaster/__init__.py +28 -0
  42. ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
  43. ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
  44. ml_tools/ML_datasetmaster/_imprimir.py +15 -0
  45. ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
  46. ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
  47. ml_tools/ML_evaluation/__init__.py +53 -0
  48. ml_tools/ML_evaluation/_classification.py +629 -0
  49. ml_tools/ML_evaluation/_feature_importance.py +409 -0
  50. ml_tools/ML_evaluation/_imprimir.py +25 -0
  51. ml_tools/ML_evaluation/_loss.py +92 -0
  52. ml_tools/ML_evaluation/_regression.py +273 -0
  53. ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
  54. ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
  55. ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
  56. ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
  57. ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
  58. ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
  59. ml_tools/ML_finalize_handler/__init__.py +10 -0
  60. ml_tools/ML_finalize_handler/_imprimir.py +8 -0
  61. ml_tools/ML_inference/__init__.py +22 -0
  62. ml_tools/ML_inference/_base_inference.py +166 -0
  63. ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
  64. ml_tools/ML_inference/_dragon_inference.py +332 -0
  65. ml_tools/ML_inference/_imprimir.py +11 -0
  66. ml_tools/ML_inference/_multi_inference.py +180 -0
  67. ml_tools/ML_inference_sequence/__init__.py +10 -0
  68. ml_tools/ML_inference_sequence/_imprimir.py +8 -0
  69. ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
  70. ml_tools/ML_inference_vision/__init__.py +10 -0
  71. ml_tools/ML_inference_vision/_imprimir.py +8 -0
  72. ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
  73. ml_tools/ML_models/__init__.py +32 -0
  74. ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
  75. ml_tools/ML_models/_base_mlp_attention.py +198 -0
  76. ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
  77. ml_tools/ML_models/_dragon_tabular.py +248 -0
  78. ml_tools/ML_models/_imprimir.py +18 -0
  79. ml_tools/ML_models/_mlp_attention.py +134 -0
  80. ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
  81. ml_tools/ML_models_sequence/__init__.py +10 -0
  82. ml_tools/ML_models_sequence/_imprimir.py +8 -0
  83. ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
  84. ml_tools/ML_models_vision/__init__.py +29 -0
  85. ml_tools/ML_models_vision/_base_wrapper.py +254 -0
  86. ml_tools/ML_models_vision/_image_classification.py +182 -0
  87. ml_tools/ML_models_vision/_image_segmentation.py +108 -0
  88. ml_tools/ML_models_vision/_imprimir.py +16 -0
  89. ml_tools/ML_models_vision/_object_detection.py +135 -0
  90. ml_tools/ML_optimization/__init__.py +21 -0
  91. ml_tools/ML_optimization/_imprimir.py +13 -0
  92. ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
  93. ml_tools/ML_optimization/_single_dragon.py +203 -0
  94. ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
  95. ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
  96. ml_tools/ML_scaler/__init__.py +10 -0
  97. ml_tools/ML_scaler/_imprimir.py +8 -0
  98. ml_tools/ML_trainer/__init__.py +20 -0
  99. ml_tools/ML_trainer/_base_trainer.py +297 -0
  100. ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
  101. ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
  102. ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
  103. ml_tools/ML_trainer/_imprimir.py +10 -0
  104. ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
  105. ml_tools/ML_utilities/_artifact_finder.py +382 -0
  106. ml_tools/ML_utilities/_imprimir.py +16 -0
  107. ml_tools/ML_utilities/_inspection.py +325 -0
  108. ml_tools/ML_utilities/_train_tools.py +205 -0
  109. ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
  110. ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
  111. ml_tools/ML_vision_transformers/_imprimir.py +14 -0
  112. ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
  113. ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
  114. ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
  115. ml_tools/PSO_optimization/_imprimir.py +10 -0
  116. ml_tools/SQL/__init__.py +7 -0
  117. ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
  118. ml_tools/SQL/_imprimir.py +8 -0
  119. ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
  120. ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
  121. ml_tools/VIF/_imprimir.py +10 -0
  122. ml_tools/_core/__init__.py +7 -1
  123. ml_tools/_core/_logger.py +8 -18
  124. ml_tools/_core/_schema_load_ops.py +43 -0
  125. ml_tools/_core/_script_info.py +2 -2
  126. ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
  127. ml_tools/data_exploration/_analysis.py +214 -0
  128. ml_tools/data_exploration/_cleaning.py +566 -0
  129. ml_tools/data_exploration/_features.py +583 -0
  130. ml_tools/data_exploration/_imprimir.py +32 -0
  131. ml_tools/data_exploration/_plotting.py +487 -0
  132. ml_tools/data_exploration/_schema_ops.py +176 -0
  133. ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
  134. ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
  135. ml_tools/ensemble_evaluation/_imprimir.py +14 -0
  136. ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
  137. ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
  138. ml_tools/ensemble_inference/_imprimir.py +9 -0
  139. ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
  140. ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
  141. ml_tools/ensemble_learning/_imprimir.py +10 -0
  142. ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
  143. ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
  144. ml_tools/excel_handler/_imprimir.py +13 -0
  145. ml_tools/{keys.py → keys/__init__.py} +4 -1
  146. ml_tools/keys/_imprimir.py +11 -0
  147. ml_tools/{_core → keys}/_keys.py +2 -0
  148. ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
  149. ml_tools/math_utilities/_imprimir.py +11 -0
  150. ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
  151. ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
  152. ml_tools/optimization_tools/_imprimir.py +13 -0
  153. ml_tools/optimization_tools/_optimization_bounds.py +236 -0
  154. ml_tools/optimization_tools/_optimization_plots.py +218 -0
  155. ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
  156. ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
  157. ml_tools/path_manager/_imprimir.py +15 -0
  158. ml_tools/path_manager/_path_tools.py +346 -0
  159. ml_tools/plot_fonts/__init__.py +8 -0
  160. ml_tools/plot_fonts/_imprimir.py +8 -0
  161. ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
  162. ml_tools/schema/__init__.py +15 -0
  163. ml_tools/schema/_feature_schema.py +223 -0
  164. ml_tools/schema/_gui_schema.py +191 -0
  165. ml_tools/schema/_imprimir.py +10 -0
  166. ml_tools/{serde.py → serde/__init__.py} +4 -2
  167. ml_tools/serde/_imprimir.py +10 -0
  168. ml_tools/{_core → serde}/_serde.py +3 -8
  169. ml_tools/{utilities.py → utilities/__init__.py} +11 -6
  170. ml_tools/utilities/_imprimir.py +18 -0
  171. ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
  172. ml_tools/utilities/_utility_tools.py +192 -0
  173. dragon_ml_toolbox-19.14.0.dist-info/RECORD +0 -111
  174. ml_tools/ML_chaining_inference.py +0 -8
  175. ml_tools/ML_configuration.py +0 -86
  176. ml_tools/ML_configuration_pytab.py +0 -14
  177. ml_tools/ML_datasetmaster.py +0 -10
  178. ml_tools/ML_evaluation.py +0 -16
  179. ml_tools/ML_evaluation_multi.py +0 -12
  180. ml_tools/ML_finalize_handler.py +0 -8
  181. ml_tools/ML_inference.py +0 -12
  182. ml_tools/ML_models.py +0 -14
  183. ml_tools/ML_models_advanced.py +0 -14
  184. ml_tools/ML_models_pytab.py +0 -14
  185. ml_tools/ML_optimization.py +0 -14
  186. ml_tools/ML_optimization_pareto.py +0 -8
  187. ml_tools/ML_scaler.py +0 -8
  188. ml_tools/ML_sequence_datasetmaster.py +0 -8
  189. ml_tools/ML_sequence_evaluation.py +0 -10
  190. ml_tools/ML_sequence_inference.py +0 -8
  191. ml_tools/ML_sequence_models.py +0 -8
  192. ml_tools/ML_trainer.py +0 -12
  193. ml_tools/ML_vision_datasetmaster.py +0 -12
  194. ml_tools/ML_vision_evaluation.py +0 -10
  195. ml_tools/ML_vision_inference.py +0 -8
  196. ml_tools/ML_vision_models.py +0 -18
  197. ml_tools/SQL.py +0 -8
  198. ml_tools/_core/_ETL_cleaning.py +0 -694
  199. ml_tools/_core/_IO_tools.py +0 -498
  200. ml_tools/_core/_ML_callbacks.py +0 -702
  201. ml_tools/_core/_ML_configuration.py +0 -1332
  202. ml_tools/_core/_ML_configuration_pytab.py +0 -102
  203. ml_tools/_core/_ML_evaluation.py +0 -867
  204. ml_tools/_core/_ML_evaluation_multi.py +0 -544
  205. ml_tools/_core/_ML_inference.py +0 -646
  206. ml_tools/_core/_ML_models.py +0 -668
  207. ml_tools/_core/_ML_models_pytab.py +0 -693
  208. ml_tools/_core/_ML_trainer.py +0 -2323
  209. ml_tools/_core/_ML_utilities.py +0 -886
  210. ml_tools/_core/_ML_vision_models.py +0 -644
  211. ml_tools/_core/_data_exploration.py +0 -1909
  212. ml_tools/_core/_optimization_tools.py +0 -493
  213. ml_tools/_core/_schema.py +0 -359
  214. ml_tools/plot_fonts.py +0 -8
  215. ml_tools/schema.py +0 -12
  216. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
  217. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
  218. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  219. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
@@ -1,498 +0,0 @@
1
- from pathlib import Path
2
- from datetime import datetime
3
- from typing import Union, List, Dict, Any, Literal, overload
4
- import traceback
5
- import json
6
- import csv
7
- from itertools import zip_longest
8
- from collections import Counter
9
-
10
- from ._path_manager import sanitize_filename, make_fullpath
11
- from ._script_info import _script_info
12
- from ._logger import get_logger
13
-
14
-
15
- _LOGGER = get_logger("IO")
16
-
17
-
18
# Public API of this module; also consumed by `info()` for self-reporting.
__all__ = [
    "custom_logger",
    "train_logger",
    "save_json",
    "load_json",
    "save_list_strings",
    "load_list_strings",
    "compare_lists"
]
27
-
28
-
29
def custom_logger(
    data: Union[
        List[Any],
        Dict[Any, Any],
        str,
        BaseException
    ],
    save_directory: Union[str, Path],
    log_name: str,
    add_timestamp: bool=True,
    dict_as: Literal['auto', 'json', 'csv'] = 'auto',
) -> None:
    """
    Logs various data types to corresponding output formats:

    - list[Any] → .txt
        Each element is written on a new line.

    - dict[str, list[Any]] → .csv (if dict_as='auto' or 'csv')
        Dictionary is treated as tabular data; keys become columns, values become rows.

    - dict[str, scalar] → .json (if dict_as='auto' or 'json')
        Dictionary is treated as structured data and serialized as JSON.

    - str → .log
        Plain text string is written to a .log file.

    - BaseException → .log
        Full traceback is logged for debugging purposes.

    Args:
        data (Any): The data to be logged. Must be one of the supported types.
        save_directory (str | Path): Directory where the log will be saved. Created if it does not exist.
        log_name (str): Base name for the log file.
        add_timestamp (bool): Whether to add a timestamp to the filename.
        dict_as ('auto'|'json'|'csv'):
            - 'auto': Guesses format (JSON or CSV) based on dictionary content.
            - 'json': Forces .json format for any dictionary.
            - 'csv': Forces .csv format. Will fail if dict values are not all lists.

    Note:
        This function is best-effort and never raises: every error (including
        an unsupported data type) is caught by the outer handler and logged.
    """
    try:
        # Empty containers / strings produce no file; exceptions are always logged.
        if not isinstance(data, BaseException) and not data:
            _LOGGER.warning("Empty data received. No log file will be saved.")
            return

        save_path = make_fullpath(save_directory, make=True)
        sanitized_log_name = sanitize_filename(log_name)

        if add_timestamp:
            timestamp = datetime.now().strftime(r"%Y%m%d_%H%M%S")
            base_path = save_path / f"{sanitized_log_name}_{timestamp}"
        else:
            base_path = save_path / sanitized_log_name

        # Router: pick the writer matching the runtime type of `data`
        if isinstance(data, list):
            _log_list_to_txt(data, base_path.with_suffix(".txt"))

        elif isinstance(data, dict):
            if dict_as == 'json':
                _log_dict_to_json(data, base_path.with_suffix(".json"))

            elif dict_as == 'csv':
                # This will raise a ValueError if data is not all lists
                _log_dict_to_csv(data, base_path.with_suffix(".csv"))

            else:  # 'auto' mode
                if all(isinstance(v, list) for v in data.values()):
                    _log_dict_to_csv(data, base_path.with_suffix(".csv"))
                else:
                    _log_dict_to_json(data, base_path.with_suffix(".json"))

        elif isinstance(data, str):
            _log_string_to_log(data, base_path.with_suffix(".log"))

        elif isinstance(data, BaseException):
            _log_exception_to_log(data, base_path.with_suffix(".log"))

        else:
            _LOGGER.error("Unsupported data type. Must be list, dict, str, or BaseException.")
            raise ValueError()

        _LOGGER.info(f"Log saved as: '{base_path.name}'")

    except Exception:
        # Fixed: was f"Log not saved." — an f-string with no placeholders (ruff F541).
        _LOGGER.exception("Log not saved.")
119
-
120
-
121
- def _log_list_to_txt(data: List[Any], path: Path) -> None:
122
- log_lines = []
123
- for item in data:
124
- try:
125
- log_lines.append(str(item).strip())
126
- except Exception:
127
- log_lines.append(f"(unrepresentable item of type {type(item)})")
128
-
129
- with open(path, 'w', encoding='utf-8') as f:
130
- f.write('\n'.join(log_lines))
131
-
132
-
133
- def _log_dict_to_csv(data: Dict[Any, List[Any]], path: Path) -> None:
134
- sanitized_dict = {}
135
- max_length = max(len(v) for v in data.values()) if data else 0
136
-
137
- for key, value in data.items():
138
- if not isinstance(value, list):
139
- _LOGGER.error(f"Dictionary value for key '{key}' must be a list.")
140
- raise ValueError()
141
-
142
- sanitized_key = str(key).strip().replace('\n', '_').replace('\r', '_')
143
- padded_value = value + [None] * (max_length - len(value))
144
- sanitized_dict[sanitized_key] = padded_value
145
-
146
- # The `newline=''` argument is important to prevent extra blank rows
147
- with open(path, 'w', newline='', encoding='utf-8') as csv_file:
148
- writer = csv.writer(csv_file)
149
-
150
- # 1. Write the header row from the sanitized dictionary keys
151
- header = list(sanitized_dict.keys())
152
- writer.writerow(header)
153
-
154
- # 2. Transpose columns to rows and write them
155
- # zip(*sanitized_dict.values()) elegantly converts the column data
156
- # (lists in the dict) into row-by-row tuples.
157
- rows_to_write = zip(*sanitized_dict.values())
158
- writer.writerows(rows_to_write)
159
-
160
-
161
- def _log_string_to_log(data: str, path: Path) -> None:
162
- with open(path, 'w', encoding='utf-8') as f:
163
- f.write(data.strip() + '\n')
164
-
165
-
166
- def _log_exception_to_log(exc: BaseException, path: Path) -> None:
167
- with open(path, 'w', encoding='utf-8') as f:
168
- f.write("Exception occurred:\n")
169
- traceback.print_exception(type(exc), exc, exc.__traceback__, file=f)
170
-
171
-
172
- def _log_dict_to_json(data: Dict[Any, Any], path: Path) -> None:
173
- with open(path, 'w', encoding='utf-8') as f:
174
- json.dump(data, f, indent=4, ensure_ascii=False)
175
-
176
-
177
def save_json(
    data: Union[Dict[Any, Any], List[Any]],
    directory: Union[str, Path],
    filename: str,
    verbose: bool = True
) -> None:
    """
    Saves a dictionary or list as a JSON file.

    Args:
        data (dict | list): The data to save.
        directory (str | Path): The directory to save the file in.
        filename (str): The name of the file (extension .json will be added if missing).
        verbose (bool): Whether to log success messages.
    """
    target_dir = make_fullpath(directory, make=True, enforce="directory")
    sanitized_name = sanitize_filename(filename)
    if not sanitized_name.endswith(".json"):
        sanitized_name = f"{sanitized_name}.json"

    full_path = target_dir / sanitized_name
    try:
        # _RobustEncoder keeps non-standard values (e.g. `type` objects) serializable
        with open(full_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=4, ensure_ascii=False, cls=_RobustEncoder)
    except Exception as e:
        _LOGGER.error(f"Failed to save JSON to '{full_path}': {e}")
        raise

    if verbose:
        _LOGGER.info(f"JSON file saved as '{full_path.name}'.")
211
-
212
-
213
- # 1. Define Overloads (for the type checker)
214
# 1. Define Overloads (for the type checker)
@overload
def load_json(
    file_path: Union[str, Path],
    expected_type: Literal["dict"] = "dict",
    verbose: bool = True
) -> Dict[Any, Any]: ...

@overload
def load_json(
    file_path: Union[str, Path],
    expected_type: Literal["list"],
    verbose: bool = True
) -> List[Any]: ...


def load_json(
    file_path: Union[str, Path],
    expected_type: Literal["dict", "list"] = "dict",
    verbose: bool = True
) -> Union[Dict[Any, Any], List[Any]]:
    """
    Loads a JSON file.

    Args:
        file_path (str | Path): The path to the JSON file.
        expected_type ('dict' | 'list'): strict check for the root type of the JSON.
        verbose (bool): Whether to log success/failure messages.

    Returns:
        dict | list: The loaded JSON data.

    Raises:
        ValueError: If the file is not valid JSON, or its root type does not
            match `expected_type`.
    """
    target_path = make_fullpath(file_path, enforce="file")

    # Map string literals to actual python types
    type_map = {"dict": dict, "list": list}
    target_type = type_map.get(expected_type, dict)

    # Keep the try body minimal: in the original, the root-type ValueError was
    # raised inside this try, re-caught by `except Exception`, and logged a
    # second time before propagating.
    try:
        with open(target_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except json.JSONDecodeError as e:
        _LOGGER.error(f"Failed to decode JSON from '{target_path}': {e.msg}")
        raise ValueError()
    except Exception as e:
        _LOGGER.error(f"Error loading JSON from '{target_path}': {e}")
        raise

    if not isinstance(data, target_type):
        _LOGGER.error(f"JSON root is type {type(data)}, expected {expected_type}.")
        raise ValueError()

    if verbose:
        _LOGGER.info(f"Loaded JSON data from '{target_path.name}'.")

    return data
271
-
272
-
273
def save_list_strings(list_strings: list[str], directory: Union[str,Path], filename: str, verbose: bool=True):
    """Saves a list of strings as a UTF-8 text file, one string per line.

    Args:
        list_strings (list[str]): Strings to write.
        directory (str | Path): Target directory (created if missing).
        filename (str): Base file name; '.txt' is appended if missing.
        verbose (bool): Whether to log a success message.
    """
    target_dir = make_fullpath(directory, make=True, enforce="directory")
    sanitized_name = sanitize_filename(filename)

    if not sanitized_name.endswith(".txt"):
        sanitized_name = sanitized_name + ".txt"

    full_path = target_dir / sanitized_name
    # Fixed: explicit UTF-8 — the original used the platform-default encoding,
    # inconsistent with the rest of this module and unsafe for non-ASCII text.
    with open(full_path, 'w', encoding='utf-8') as f:
        for string_data in list_strings:
            f.write(f"{string_data}\n")

    if verbose:
        _LOGGER.info(f"Text file saved as '{full_path.name}'.")
288
-
289
-
290
def load_list_strings(text_file: Union[str,Path], verbose: bool=True) -> list[str]:
    """Loads a text file as a list of stripped strings, one per line.

    Args:
        text_file (str | Path): Path to the text file.
        verbose (bool): Whether to log a success message.

    Returns:
        list[str]: One entry per line of the file.

    Raises:
        ValueError: If the file contains no lines.
    """
    target_path = make_fullpath(text_file, enforce="file")

    # Fixed: explicit UTF-8 to match the writer (save_list_strings) and the
    # rest of this module; the original relied on the platform default.
    with open(target_path, 'r', encoding='utf-8') as f:
        loaded_strings = [line.strip() for line in f]

    if len(loaded_strings) == 0:
        _LOGGER.error("The text file is empty.")
        raise ValueError()

    if verbose:
        _LOGGER.info(f"Loaded '{target_path.name}' as list of strings.")

    return loaded_strings
306
-
307
-
308
- class _RobustEncoder(json.JSONEncoder):
309
- """
310
- Custom JSON encoder to handle non-serializable objects.
311
-
312
- This handles:
313
- 1. `type` objects (e.g., <class 'int'>) which result from
314
- `check_type_only=True`.
315
- 2. Any other custom class or object by falling back to its
316
- string representation.
317
- """
318
- def default(self, o):
319
- if isinstance(o, type):
320
- return str(o)
321
- try:
322
- return super().default(o)
323
- except TypeError:
324
- return str(o)
325
-
326
def compare_lists(
    list_A: list,
    list_B: list,
    save_dir: Union[str, Path],
    strict: bool = False,
    check_type_only: bool = False
) -> dict:
    """
    Compares two lists and saves a JSON report of the differences.

    Args:
        list_A (list): The first list to compare.
        list_B (list): The second list to compare.
        save_dir (str | Path): The directory where the resulting report will be saved.
        strict (bool):
            - If False: Performs a "bag" comparison. Order does not matter, but duplicates do.
            - If True: Performs a strict, positional comparison.

        check_type_only (bool):
            - If False: Compares items using `==` (`__eq__` operator).
            - If True: Compares only the `type()` of the items.

    Returns:
        dict: A dictionary detailing the differences. (saved to `save_dir`).
    """
    MISSING_A_KEY = "missing_in_A"
    MISSING_B_KEY = "missing_in_B"
    MISMATCH_KEY = "mismatch"

    results: dict[str, list] = {MISSING_A_KEY: [], MISSING_B_KEY: []}

    # make directory
    save_path = make_fullpath(input_path=save_dir, make=True, enforce="directory")

    if strict:
        # --- STRICT (Positional) Mode ---
        results[MISMATCH_KEY] = []
        sentinel = object()

        # Fixed: named defs instead of lambda assignment (PEP 8 / ruff E731),
        # and identity (`is`) for type objects — the idiomatic exact-type check.
        if check_type_only:
            def items_match(a, b):
                return type(a) is type(b)
        else:
            def items_match(a, b):
                return a == b

        for index, (item_a, item_b) in enumerate(
            zip_longest(list_A, list_B, fillvalue=sentinel)
        ):
            if item_a is sentinel:
                results[MISSING_A_KEY].append({"index": index, "item": item_b})
            elif item_b is sentinel:
                results[MISSING_B_KEY].append({"index": index, "item": item_a})
            elif not items_match(item_a, item_b):
                results[MISMATCH_KEY].append(
                    {
                        "index": index,
                        "list_A_item": item_a,
                        "list_B_item": item_b,
                    }
                )

    else:
        # --- NON-STRICT (Bag) Mode ---
        if check_type_only:
            # Types are hashable, we can use Counter (O(N))
            types_A_counts = Counter(type(item) for item in list_A)
            types_B_counts = Counter(type(item) for item in list_B)

            for item_type, count in (types_A_counts - types_B_counts).items():
                results[MISSING_B_KEY].extend([item_type] * count)

            for item_type, count in (types_B_counts - types_A_counts).items():
                results[MISSING_A_KEY].extend([item_type] * count)

        else:
            # Items may be unhashable. Use O(N*M) .remove() method
            temp_B = list(list_B)
            missing_in_B = []

            for item_a in list_A:
                try:
                    temp_B.remove(item_a)
                except ValueError:
                    missing_in_B.append(item_a)

            results[MISSING_A_KEY] = temp_B
            results[MISSING_B_KEY] = missing_in_B

    # --- Save the Report (best-effort: a failed save does not lose the result) ---
    try:
        full_path = save_path / "list_comparison.json"

        with open(full_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=4, cls=_RobustEncoder)

    except Exception as e:
        _LOGGER.error(f"Failed to save comparison report to {save_path}: \n{e}")

    return results
427
-
428
-
429
def train_logger(train_config: Union[dict, Any],
                 model_parameters: Union[dict, Any],
                 train_history: Union[dict, None],
                 save_directory: Union[str, Path]):
    """
    Logs training data to JSON, adding a timestamp to the filename.

    Args:
        train_config (dict | Any): Training configuration parameters. If object, must have a `.to_log()` method returning a dict.
        model_parameters (dict | Any): Model parameters. If object, must have a `.to_log()` method returning a dict.
        train_history (dict | None): Training history log.
        save_directory (str | Path): Directory to save the log file.

    Raises:
        ValueError: If an argument is empty, or cannot be coerced to a dict.
    """
    # The two arguments share identical validation rules; the original
    # duplicated the whole stanza verbatim for each.
    train_config_dict = _coerce_log_dict(train_config, "train_config")
    model_parameters_dict = _coerce_log_dict(model_parameters, "model_parameters")

    # make base dictionary
    data: dict = train_config_dict | model_parameters_dict

    # add training history if provided and is not empty
    if train_history is not None:
        if not train_history:
            _LOGGER.error("'train_history' dictionary was provided but is empty.")
            raise ValueError()
        data.update(train_history)

    custom_logger(
        data=data,
        save_directory=save_directory,
        log_name="training_log",
        add_timestamp=True,
        dict_as='json'
    )


def _coerce_log_dict(source: Union[dict, Any], arg_name: str) -> dict:
    """Return *source* as a non-empty dict, calling its `.to_log()` if it is not one.

    Error messages are built with `arg_name` so they match the originals
    byte-for-byte ('train_config' / 'model_parameters').
    """
    if isinstance(source, dict):
        if not source:
            _LOGGER.error(f"'{arg_name}' dictionary is empty.")
            raise ValueError()
        return source

    if hasattr(source, "to_log") and callable(getattr(source, "to_log")):
        result: dict = source.to_log()
        if not isinstance(result, dict):
            _LOGGER.error(f"'{arg_name}.to_log()' did not return a dictionary.")
            raise ValueError()
        return result

    _LOGGER.error(f"'{arg_name}' must be a dict or an object with a 'to_log()' method.")
    raise ValueError()
495
-
496
-
497
def info():
    """Report this module's public API (`__all__`) via the shared `_script_info` helper."""
    _script_info(__all__)