dragon-ml-toolbox 20.0.0__py3-none-any.whl → 20.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dragon_ml_toolbox-20.0.0.dist-info/METADATA → dragon_ml_toolbox-20.1.0.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dragon-ml-toolbox
-Version: 20.0.0
+Version: 20.1.0
 Summary: Complete pipelines and helper tools for data science and machine learning projects.
 Author-email: Karl Luigi Loza Vidaurre <luigiloza@gmail.com>
 License-Expression: MIT
dragon_ml_toolbox-20.0.0.dist-info/RECORD → dragon_ml_toolbox-20.1.0.dist-info/RECORD RENAMED
@@ -1,5 +1,5 @@
-dragon_ml_toolbox-20.0.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
-dragon_ml_toolbox-20.0.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=0-HBRMMgKuwtGy6nMJZvIn1fLxhx_ksyyVB2U_iyYZU,2818
+dragon_ml_toolbox-20.1.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
+dragon_ml_toolbox-20.1.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=0-HBRMMgKuwtGy6nMJZvIn1fLxhx_ksyyVB2U_iyYZU,2818
 ml_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 ml_tools/constants.py,sha256=3br5Rk9cL2IUo638eJuMOGdbGQaWssaUecYEvSeRBLM,3322
 ml_tools/ETL_cleaning/__init__.py,sha256=TytE8RKmtW4KQlkaTxpYKlJAbCu-VAc82eDdHwVD3Jo,427
@@ -21,7 +21,7 @@ ml_tools/IO_tools/__init__.py,sha256=ZeEM5bbZ5udgRXFAL51uRXzoCzPLO8TWZ4AiME7NNy0
 ml_tools/IO_tools/_imprimir.py,sha256=eN-V60xtDNFINThuRTjXknMxtbK8Ah0MWgc8l2GTXMA,250
 ml_tools/MICE/_MICE_imputation.py,sha256=N1cDwVYfoHvIZz7FLLcW-guZUo8iFKedtkfS7CU6TVE,5318
 ml_tools/MICE/__init__.py,sha256=i5N_fd3rxpEgLsKKDoLbokW0rHm-ADEg8r3gBB5426E,313
-ml_tools/MICE/_dragon_mice.py,sha256=E6LyCe7JjEvDeKJfDfDd1iKJS86pDQLYGYoajahtuyg,17736
+ml_tools/MICE/_dragon_mice.py,sha256=qEOy9Gx1QzVBvkvGR8790TkvKw8-fp06vCDGWM6j9os,17806
 ml_tools/MICE/_imprimir.py,sha256=YVhgZlUQ-NrDUVhHTK3u8s1QEbZ_jvDVF7-0FptVsxs,215
 ml_tools/ML_callbacks/__init__.py,sha256=dF37KXezy6P3VArhZbm5CI6si65GA-qVY70jvZFZYkA,427
 ml_tools/ML_callbacks/_base.py,sha256=xLVAFOhBHjqnf8a_wKgW1F-tn2u6EqV3IHXsXKTn2NE,3269
@@ -29,10 +29,11 @@ ml_tools/ML_callbacks/_checkpoint.py,sha256=Ioj9wn8XlsR_S1NnmWbyT9lkO8o2_DcHVMrF
 ml_tools/ML_callbacks/_early_stop.py,sha256=qzTzxfDCDim0qj7QQ7ykJNIOBWbXtviDptMCczXXy_k,8073
 ml_tools/ML_callbacks/_imprimir.py,sha256=Wz6NXhiCFSJsAZh3JnQ4qt7tj2_qhu14DTwu-gkkzZs,257
 ml_tools/ML_callbacks/_scheduler.py,sha256=mn97_VH8Lp37KH3zSgmPemGQV8g-K8GfhRNHTftaNcg,7390
-ml_tools/ML_chain/__init__.py,sha256=rUBVwB96fAoq-Q9zY3s0fL_TFU5W2axlg7XZzrCXrSU,399
-ml_tools/ML_chain/_chaining_tools.py,sha256=ASi0Zr9WBVA7wd-pYVN69VIZFOIuB4QpGlrSl9Ob-90,13788
-ml_tools/ML_chain/_dragon_chain.py,sha256=wFlknv0rlL8P3K0ls8kj_oup4SvPNFqSxDmiBdPfGt4,5737
-ml_tools/ML_chain/_imprimir.py,sha256=JCVslxnrmvJ_LJOmexL2u5-OYykHFe1H49EkrJPpAIg,254
+ml_tools/ML_chain/__init__.py,sha256=UVD1xaJ59pft_ysg8z_ihqjEDQqPRQwmhui_zNRFp7I,491
+ml_tools/ML_chain/_chaining_tools.py,sha256=BDwTvgJFbJ-wgy3IkP6_SNpNaWpHGXV3PhAM7sYmHeU,13675
+ml_tools/ML_chain/_dragon_chain.py,sha256=x3fN136C5N9WcXJJW9zkNrBzP8QoBaXpxz7SPF3txjg,5601
+ml_tools/ML_chain/_imprimir.py,sha256=tHVXoGhMlbpkpcoGKwtkYVFlHFEllRCsYdpiAFI1aZk,285
+ml_tools/ML_chain/_update_schema.py,sha256=z1Us7lv6hy6GwSu1mcid50Jmqq3sh91hMQ0LnQjhte8,3806
 ml_tools/ML_configuration/__init__.py,sha256=wSpfk8bHRSoYjcKJmjd5ivB4Fw8UFjyOZL4hct9rJT0,2637
 ml_tools/ML_configuration/_base_model_config.py,sha256=95L3IfobNFMtnNr79zYpDGerC1q1v7M05tWZvTS2cwE,2247
 ml_tools/ML_configuration/_finalize.py,sha256=l_n13bLu0avMdJ8hNRrH8V_wOBQZM1UGsTydKBkTysM,15047
@@ -125,11 +126,11 @@ ml_tools/_core/__init__.py,sha256=m-VP0RW0tOTm9N5NI3kFNcpM7WtVgs0RK9pK3ZJRZQQ,14
 ml_tools/_core/_logger.py,sha256=xzhn_FouMDRVNwXGBGlPC9Ruq6i5uCrmNaS5jesguMU,4972
 ml_tools/_core/_schema_load_ops.py,sha256=KLs9vBzANz5ESe2wlP-C41N4VlgGil-ywcfvWKSOGss,1551
 ml_tools/_core/_script_info.py,sha256=LtFGt10gEvCnhIRMKJPi2yXkiGLcdr7lE-oIP2XGHzQ,234
-ml_tools/data_exploration/__init__.py,sha256=a4hlq6Pyc_cQjiys_2CUFd5nIvzqPc4g8asWEHJz9Es,1674
+ml_tools/data_exploration/__init__.py,sha256=w9dM6wjmxfbEXQCWGFVL_cIuLHtYVP364aQvzRwfZXY,1674
 ml_tools/data_exploration/_analysis.py,sha256=H6LryV56FFCHWjvQdkhZbtprZy6aP8EqU_hC2Cf9CLE,7832
 ml_tools/data_exploration/_cleaning.py,sha256=LpoOHOB6HVtdObZExg-B8SxZW-JUc51tblnkCFDZxKg,20846
 ml_tools/data_exploration/_features.py,sha256=wW-M8n2aLIy05DR2z4fI8wjpPjn3mOAnm9aSGYbMKwI,23363
-ml_tools/data_exploration/_imprimir.py,sha256=PkvDvQkYTQC_KnfI1gxxUxtC-XeSRePniM1TyJj8Caw,876
+ml_tools/data_exploration/_imprimir.py,sha256=0nXu60HpeJZ8s83mpVoRtdKILK3t8EHRFVk7d9vRVUo,876
 ml_tools/data_exploration/_plotting.py,sha256=zH1dPcIoAlOuww23xIoBCsQOAshPPv9OyGposOA2RvI,19883
 ml_tools/data_exploration/_schema_ops.py,sha256=PoFeHaS9dXI9gfL0SRD-8uSP4owqmbQFbtfA-HxkLnY,7108
 ml_tools/ensemble_evaluation/__init__.py,sha256=Xxx-F-_TvSVzMaocKXOo_tEXLibMJtf_YY85Ac3U0EI,483
@@ -146,7 +147,7 @@ ml_tools/excel_handler/_excel_handler.py,sha256=TODudmeQgDSdxUKzLfAzizs--VL-g8Wx
 ml_tools/excel_handler/_imprimir.py,sha256=QHazgqjRMzthRbDt33EVpvR7GqufSzng6jHw7IVCdtI,306
 ml_tools/keys/__init__.py,sha256=DV52KLOY5GfpLwJdDAHlFVz0qAmyh-KWg3gZorFdMSk,336
 ml_tools/keys/_imprimir.py,sha256=4qmwdia16DPq3OtlWGMkgLPT5R3lcM-ka3tQdCLx5qk,197
-ml_tools/keys/_keys.py,sha256=wyUpNY7iZIGIqvnT2BSahnkkNkK_vvZALOtRWZ7h50A,8800
+ml_tools/keys/_keys.py,sha256=fArSyT_UGGSH4PHjG-R0kefFznAtAxSAasDCQ7-89a8,8899
 ml_tools/math_utilities/__init__.py,sha256=NuTcb_Ogdwx5x-oDieBt1EAqCoZRnXbkZbUrwB6ItH0,337
 ml_tools/math_utilities/_imprimir.py,sha256=kk5DQb_BV9g767uTdXQiRjEEHgQwJpEXU3jxO3QV2Fw,238
 ml_tools/math_utilities/_math_utilities.py,sha256=BYHIVcM9tuKIhVrkgLLiM5QalJ39zx7dXYy_M9aGgiM,9012
@@ -162,7 +163,7 @@ ml_tools/plot_fonts/__init__.py,sha256=l-vSSpjZb6IeWjjgPTcNmEs7M-vbw0lqgEKD5jhtX
 ml_tools/plot_fonts/_imprimir.py,sha256=zNi6naa5eWBFfa_yV569MhUtSAL44H0xDjMcgrJSlXk,131
 ml_tools/plot_fonts/_plot_fonts.py,sha256=mfjXNT9P59ymHoTI85Q8CcvfxfK5BIFBWtTZH-hNIC4,2209
 ml_tools/schema/__init__.py,sha256=9LQtKz3OO9wm-1piUgAhCJZVZT-F-YSg5QLus9pxfgA,263
-ml_tools/schema/_feature_schema.py,sha256=QLsxBS3_CIJp4c4dknvMs7RHZl_GZDEBJQ0MxLrQo6Y,8536
+ml_tools/schema/_feature_schema.py,sha256=ICymTIL05n1qs61TvyY7rapDOJ9PlaOHi0F86N4tNlU,8547
 ml_tools/schema/_gui_schema.py,sha256=IVwN4THAdFrvh2TpV4SFd_zlzMX3eioF-w-qcSVTndE,7245
 ml_tools/schema/_imprimir.py,sha256=waNHozZmkCKKNFWSw0HFf9489FkSXogl6KuT5cn5V74,190
 ml_tools/serde/__init__.py,sha256=Gj6B8Sgf0-ad72jFXq2W_k5pXOT2iNx5Dvzwrd7Tj1U,229
@@ -172,7 +173,7 @@ ml_tools/utilities/__init__.py,sha256=pkR2HxUIlKZMDderP2awYXVIFxkU2Xt3FkJmcmuRIp
 ml_tools/utilities/_imprimir.py,sha256=sV3ASBOsTdVYvGojOTIpZYFyrnd4panS5h_4HcMzob4,432
 ml_tools/utilities/_utility_save_load.py,sha256=7skiiuYGVLVMK_nU9uLfUZw16ePvF3i9ub7G7LMyUgs,16085
 ml_tools/utilities/_utility_tools.py,sha256=bN0J9d1S0W5wNzNntBWqDsJcEAK7-1OgQg3X2fwXns0,6918
-dragon_ml_toolbox-20.0.0.dist-info/METADATA,sha256=ILeGioHn8qeLS5vaaqOs-zId8QvQxoWZcjKgHYmeuPo,7866
-dragon_ml_toolbox-20.0.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-dragon_ml_toolbox-20.0.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
-dragon_ml_toolbox-20.0.0.dist-info/RECORD,,
+dragon_ml_toolbox-20.1.0.dist-info/METADATA,sha256=g8BdKr-giBfa-J0TWjinoX1W4lzGaTFZEovm_Fv_43w,7866
+dragon_ml_toolbox-20.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dragon_ml_toolbox-20.1.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+dragon_ml_toolbox-20.1.0.dist-info/RECORD,,
ml_tools/MICE/_dragon_mice.py CHANGED
@@ -197,7 +197,7 @@ class DragonMICE:
                 _LOGGER.error(f"Index mismatch in dataset {subname}")
                 raise ValueError()
 
-        _LOGGER.info("Schema-based MICE imputation complete.")
+        _LOGGER.info("⬅️ Schema-based MICE imputation complete.")
 
         return kernel, imputed_datasets, imputed_dataset_names
 
@@ -237,9 +237,6 @@ class DragonMICE:
             # We pass an empty DF as 'targets' to save_imputed_datasets to prevent duplication.
             df_input = df
             df_targets_to_save = pd.DataFrame(index=df.index)
-
-            # Monitor all columns that had NaNs
-            imputed_column_names = [col for col in df.columns if df[col].isna().any()]
         else:
             # Explicitly cast tuple to list for Pandas indexing
             feature_cols = list(self._schema.feature_names)
@@ -253,8 +250,9 @@ class DragonMICE:
            df_input = df[feature_cols]
            # Drop features to get targets (more robust than explicit selection if targets vary)
            df_targets_to_save = df.drop(columns=feature_cols)
-
-        imputed_column_names = _get_na_column_names(df=df_input) # type: ignore
+
+        # Monitor all columns that had NaNs
+        imputed_column_names = [col for col in df_input.columns if df_input[col].isna().any()]
 
         # Run core logic
         kernel, imputed_datasets, imputed_dataset_names = self._run_mice(df=df_input, df_name=df_name) # type: ignore
@@ -316,35 +314,41 @@ def get_convergence_diagnostic(kernel: mf.ImputationKernel, imputed_dataset_name
 
     # iterate over each imputed dataset
     for dataset_id, imputed_dataset_name in zip(range(dataset_count), imputed_dataset_names):
-        #Check directory for current dataset
         dataset_file_dir = f"Convergence_Metrics_{imputed_dataset_name}"
         local_save_dir = make_fullpath(input_path=root_path / dataset_file_dir, make=True)
 
-        for feature_name in column_names:
-            means_per_iteration = []
-            for iteration in range(iterations_cap):
-                current_imputed = kernel.complete_data(dataset=dataset_id, iteration=iteration)
-                means_per_iteration.append(np.mean(current_imputed[feature_name])) # type: ignore
-
+        # 1. Pre-calculate means for all features across all iterations
+        # Structure: {feature_name: [mean_iter_0, mean_iter_1, ...]}
+        history = {col: [] for col in column_names}
+
+        for iteration in range(iterations_cap):
+            # Resolve dataset ONLY ONCE per iteration
+            current_imputed = kernel.complete_data(dataset=dataset_id, iteration=iteration)
+
+            for col in column_names:
+                # Fast lookup
+                val = np.mean(current_imputed[col])
+                history[col].append(val)
+
+        # 2. Plotting loop
+        for feature_name, means_per_iteration in history.items():
            plt.figure(figsize=(10, 8))
            plt.plot(means_per_iteration, marker='o')
            plt.xlabel("Iteration", **label_font)
            plt.ylabel("Mean of Imputed Values", **label_font)
            plt.title(f"Mean Convergence for '{feature_name}'", **label_font)
 
-            # Adjust plot display for the X axis
            _ticks = np.arange(iterations_cap)
            _labels = np.arange(1, iterations_cap + 1)
-            plt.xticks(ticks=_ticks, labels=_labels) # type: ignore
+            plt.xticks(ticks=_ticks, labels=_labels)
            plt.grid(True)
 
-            feature_save_name = sanitize_filename(feature_name)
-            feature_save_name = feature_save_name + ".svg"
+            feature_save_name = sanitize_filename(feature_name) + ".svg"
            save_path = local_save_dir / feature_save_name
            plt.savefig(save_path, bbox_inches='tight', format="svg")
            plt.close()
 
-        _LOGGER.info(f"{dataset_file_dir} process completed.")
+    _LOGGER.info(f"📉 Convergence diagnostics complete.")
 
 
 # Imputed distributions
@@ -431,5 +435,5 @@ def get_imputed_distributions(kernel: mf.ImputationKernel, df_name: str, root_di
         fig = kernel.plot_imputed_distributions(variables=[feature])
         _process_figure(fig, feature)
 
-    _LOGGER.info(f"{local_dir_name} completed.")
+    _LOGGER.info(f"📊 Imputed distributions complete.")
 
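For context on the `get_convergence_diagnostic` refactor above: the old loop resolved `kernel.complete_data` once per (feature, iteration) pair, while the new code resolves each iteration once and reads every monitored column from that single completed frame. A minimal sketch of the new aggregation pattern, assuming only that `complete_data(iteration)` returns a pandas DataFrame; the helper name and its arguments are illustrative, not part of the package:

import numpy as np

def collect_mean_history(complete_data, column_names, iterations_cap):
    # One running list of per-iteration means for each monitored column.
    history = {col: [] for col in column_names}
    for iteration in range(iterations_cap):
        # Materialize the completed dataset once per iteration...
        frame = complete_data(iteration)
        for col in column_names:
            # ...then read every monitored column from that same frame.
            history[col].append(np.mean(frame[col]))
    return history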
ml_tools/ML_chain/__init__.py CHANGED
@@ -8,11 +8,16 @@ from ._chaining_tools import (
     prepare_chaining_dataset,
 )
 
+from ._update_schema import (
+    derive_next_step_schema
+)
+
 from ._imprimir import info
 
 
 __all__ = [
     "DragonChainOrchestrator",
+    "derive_next_step_schema",
     "augment_dataset_with_predictions",
     "augment_dataset_with_predictions_multi",
     "prepare_chaining_dataset",
ml_tools/ML_chain/_chaining_tools.py CHANGED
@@ -5,7 +5,7 @@ from typing import Optional, Literal
 
 from ..ML_inference import DragonInferenceHandler
 
-from ..keys._keys import MLTaskKeys, PyTorchInferenceKeys
+from ..keys._keys import MLTaskKeys, PyTorchInferenceKeys, ChainKeys
 from .._core import get_logger
 
 
@@ -23,11 +23,10 @@ def augment_dataset_with_predictions(
     handler: DragonInferenceHandler,
     dataset: pd.DataFrame,
     ground_truth_targets: list[str],
-    prediction_col_prefix: str = "pred_",
     batch_size: int = 4096
 ) -> pd.DataFrame:
     """
-    Uses a DragonInferenceHandler to generate predictions for a dataset and appends them as new feature columns.
+    Uses a DragonInferenceHandler to generate predictions for a dataset and appends them as new feature columns with a standardized prefix.
 
     This function splits the features from the ground truth targets, runs inference in batches to ensure
     memory efficiency, and returns a unified DataFrame containing:
@@ -38,8 +37,6 @@ def augment_dataset_with_predictions(
         dataset (pd.DataFrame): The input pandas DataFrame containing features and ground truth targets.
         ground_truth_targets (List[str]): A list of column names in `dataset` representing the actual targets.
             These are removed from the input features during inference and appended to the end of the result.
-        prediction_col_prefix (str, optional): A string to prepend when creating the
-            new prediction columns.
         batch_size (int, optional): The number of samples to process in a single inference step.
             Prevents OOM errors on large datasets. Defaults to 4096.
 
@@ -107,7 +104,7 @@ def augment_dataset_with_predictions(
     full_prediction_array = np.vstack(all_predictions)
 
     # Generate new column names
-    new_col_names = [f"{prediction_col_prefix}{tid}" for tid in handler.target_ids]
+    new_col_names = [f"{ChainKeys.CHAIN_PREDICTION_PREFIX}{tid}" for tid in handler.target_ids]
 
     # Verify dimensions match
     if full_prediction_array.shape[1] != len(new_col_names):
ml_tools/ML_chain/_dragon_chain.py CHANGED
@@ -77,18 +77,16 @@ class DragonChainOrchestrator:
     def update_with_inference(
         self,
         handler: DragonInferenceHandler,
-        prefix: str = "pred_",
         batch_size: int = 4096
     ) -> None:
         """
         Runs inference using the provided handler on the full internal dataset and appends the results as new features.
 
         This updates the internal state of the Orchestrator. Subsequent calls to `get_training_data`
-        will include these new prediction columns as features.
+        will include these new prediction columns as features with a standardized prefix.
 
         Args:
             handler (DragonInferenceHandler): The trained model handler.
-            prefix (str): Prefix for the new prediction columns (e.g., "m1_", "step2_").
             batch_size (int): Batch size for inference.
         """
         _LOGGER.info(f"Orchestrator: Updating internal state with predictions from handler (Targets: {handler.target_ids})...")
@@ -99,7 +97,6 @@ class DragonChainOrchestrator:
             handler=handler,
             dataset=self.current_dataset,
             ground_truth_targets=self.all_targets,
-            prediction_col_prefix=prefix,
             batch_size=batch_size
         )
 
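Together with the `_chaining_tools.py` hunks above, this removes the caller-chosen prefix from the chaining API: prediction columns are now always named with `ChainKeys.CHAIN_PREDICTION_PREFIX`. A hedged migration sketch, where `orchestrator` and `handler` stand in for an already-constructed `DragonChainOrchestrator` and `DragonInferenceHandler`:

# 20.0.0: the prefix was a keyword argument.
# orchestrator.update_with_inference(handler, prefix="step1_", batch_size=4096)

# 20.1.0: the argument is gone; new columns are always named
# f"{ChainKeys.CHAIN_PREDICTION_PREFIX}{target_id}", i.e. "pred_<target_id>".
orchestrator.update_with_inference(handler, batch_size=4096)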
ml_tools/ML_chain/_imprimir.py CHANGED
@@ -2,6 +2,7 @@ from .._core import _imprimir_disponibles
 
 _GRUPOS = [
     "DragonChainOrchestrator",
+    "derive_next_step_schema",
     "augment_dataset_with_predictions",
     "augment_dataset_with_predictions_multi",
     "prepare_chaining_dataset",
ml_tools/ML_chain/_update_schema.py ADDED
@@ -0,0 +1,96 @@
+from ..schema import FeatureSchema
+from ..ML_inference import DragonInferenceHandler
+
+from ..keys._keys import MLTaskKeys, ChainKeys
+from .._core import get_logger
+
+
+_LOGGER = get_logger("Schema Updater")
+
+
+__all__ = [
+    "derive_next_step_schema",
+]
+
+
+def derive_next_step_schema(
+    current_schema: FeatureSchema,
+    handler: DragonInferenceHandler,
+    verbose: bool = True
+) -> FeatureSchema:
+    """
+    Creates the FeatureSchema for the NEXT step in the chain by appending the current handler's predictions as new features.
+
+    Args:
+        current_schema (FeatureSchema): The current FeatureSchema.
+        handler (DragonInferenceHandler): The inference handler of the model trained using the current schema.
+
+    Returns:
+        FeatureSchema: An updated schema including new predicted features.
+    """
+    # 1. Determine New Column Names
+    # Match logic from _chaining_tools.py
+    if handler.target_ids is None:
+        _LOGGER.error("Handler target_ids is None; cannot derive schema.")
+        raise ValueError()
+
+    new_cols = [f"{ChainKeys.CHAIN_PREDICTION_PREFIX}{tid}" for tid in handler.target_ids]
+
+    # 2. Base Lists (Convert tuples to lists for mutation)
+    new_feature_names = list(current_schema.feature_names) + new_cols
+    new_cont_names = list(current_schema.continuous_feature_names)
+    new_cat_names = list(current_schema.categorical_feature_names)
+
+    # Copy existing maps (handle None case)
+    new_cat_index_map = dict(current_schema.categorical_index_map) if current_schema.categorical_index_map else {}
+    new_cat_mappings = dict(current_schema.categorical_mappings) if current_schema.categorical_mappings else {}
+
+    # 3. Determine Feature Type based on Task
+    is_categorical = False
+    cardinality = 0
+
+    if handler.task in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
+        is_categorical = True
+        cardinality = 2
+
+    elif handler.task == MLTaskKeys.MULTICLASS_CLASSIFICATION:
+        is_categorical = True
+        # We rely on the class map to know the 'vocabulary' size
+        if handler._class_map is None:
+            _LOGGER.error("Handler class_map is None, cannot determine cardinality for multiclass classification model.")
+            raise ValueError()
+        cardinality = len(handler._class_map)
+
+    # 4. Append New Metadata
+    current_total_feats = len(current_schema.feature_names)
+
+    for i, col_name in enumerate(new_cols):
+        # Calculate the absolute index of this new column
+        # If we had 10 features (0-9), the new one is at index 10 + i
+        new_index = current_total_feats + i
+
+        if is_categorical:
+            new_cat_names.append(col_name)
+
+            # A. Update Cardinality for Embeddings
+            new_cat_index_map[new_index] = cardinality
+
+            # B. Create Identity Mapping (Dummy Encoding)
+            # Maps string representation of int back to the int.
+            identity_map = {str(k): k for k in range(cardinality)}
+            new_cat_mappings[col_name] = identity_map
+        else:
+            # Regression / Multitarget Regression
+            new_cont_names.append(col_name)
+
+    if verbose:
+        _LOGGER.info(f"Derived next step schema with {len(new_feature_names)} features:\n {len(new_cont_names)} continuous\n {len(new_cat_names)} categorical")
+
+    # 5. Return New Immutable Schema
+    return FeatureSchema(
+        feature_names=tuple(new_feature_names),
+        continuous_feature_names=tuple(new_cont_names),
+        categorical_feature_names=tuple(new_cat_names),
+        categorical_index_map=new_cat_index_map if new_cat_index_map else None,
+        categorical_mappings=new_cat_mappings if new_cat_mappings else None
+    )
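The new module closes a gap in multi-step chaining: after training step N, the schema for step N+1 can be derived from the step-N handler instead of being rebuilt by hand. A usage sketch, assuming `schema_step1` and `handler_step1` are the `FeatureSchema` and `DragonInferenceHandler` from the previous step (variable names are illustrative):

from ml_tools.ML_chain import derive_next_step_schema

schema_step2 = derive_next_step_schema(
    current_schema=schema_step1,
    handler=handler_step1,
)
# schema_step2 appends one "pred_<target_id>" feature per handler target:
# continuous for regression tasks, categorical (with cardinality and an
# identity mapping) for classification tasks.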
ml_tools/data_exploration/__init__.py CHANGED
@@ -53,13 +53,13 @@ __all__ = [
     "split_features_targets",
     "split_continuous_binary",
     "split_continuous_categorical_targets",
-    "encode_categorical_features",
     "clip_outliers_single",
     "clip_outliers_multi",
     "drop_outlier_samples",
     "plot_continuous_vs_target",
     "plot_categorical_vs_target",
     "plot_correlation_heatmap",
+    "encode_categorical_features",
     "finalize_feature_schema",
     "apply_feature_schema",
     "match_and_filter_columns_by_regex",
ml_tools/data_exploration/_imprimir.py CHANGED
@@ -12,13 +12,13 @@ _GRUPOS = [
     "split_features_targets",
     "split_continuous_binary",
     "split_continuous_categorical_targets",
-    "encode_categorical_features",
     "clip_outliers_single",
     "clip_outliers_multi",
     "drop_outlier_samples",
     "plot_continuous_vs_target",
     "plot_categorical_vs_target",
     "plot_correlation_heatmap",
+    "encode_categorical_features",
     "finalize_feature_schema",
     "apply_feature_schema",
     "match_and_filter_columns_by_regex",
ml_tools/keys/_keys.py CHANGED
@@ -278,6 +278,11 @@ class SchemaKeys:
     OPTIONAL_LABELS = "optional_labels"
 
 
+class ChainKeys:
+    """Used by the ML chaining module."""
+    CHAIN_PREDICTION_PREFIX = "pred_"
+
+
 class _EvaluationConfig:
     """Set config values for evaluation modules."""
     DPI = 400
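`ChainKeys` centralizes the prefix that 20.0.0 duplicated as the `"pred_"` default of two function signatures. A one-line illustration with a hypothetical target id `"price"`:

from ml_tools.keys._keys import ChainKeys

column = f"{ChainKeys.CHAIN_PREDICTION_PREFIX}price"  # -> "pred_price"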
ml_tools/schema/_feature_schema.py CHANGED
@@ -44,7 +44,7 @@ class FeatureSchema(NamedTuple):
         Handles conversion of Tuple->List and IntKeys->StrKeys automatically.
         """
         # validate path
-        dir_path = make_fullpath(directory, enforce="directory")
+        dir_path = make_fullpath(directory, make=True, enforce="directory")
         file_path = dir_path / SchemaKeys.SCHEMA_FILENAME
 
         try:
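Assuming `make_fullpath(..., make=True)` creates missing directories, as in its use by `get_convergence_diagnostic` above, this last change means writing a `FeatureSchema` to a directory that does not yet exist should now create that directory on demand instead of failing path validation.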