dragon-ml-toolbox 13.1.0__py3-none-any.whl → 14.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dragon-ml-toolbox might be problematic.

@@ -1,13 +1,10 @@
  import torch
- from torch.utils.data import Dataset, Subset
+ from torch.utils.data import Dataset
  import pandas
  import numpy
  from sklearn.model_selection import train_test_split
  from typing import Literal, Union, Tuple, List, Optional
  from abc import ABC, abstractmethod
- from PIL import Image, ImageOps
- from torchvision.datasets import ImageFolder
- from torchvision import transforms
  import matplotlib.pyplot as plt
  from pathlib import Path
 
@@ -23,9 +20,7 @@ from ._schema import FeatureSchema
  __all__ = [
  "DatasetMaker",
  "DatasetMakerMulti",
- "VisionDatasetMaker",
- "SequenceMaker",
- "ResizeAspectFill",
+ "SequenceMaker"
  ]
 
 
@@ -126,8 +121,8 @@ class _BaseDatasetMaker(ABC):
  else:
  _LOGGER.info("No continuous features listed in schema. Scaler will not be fitted.")
 
- X_train_values = X_train.values
- X_test_values = X_test.values
+ X_train_values = X_train.to_numpy()
+ X_test_values = X_test.to_numpy()
 
  # continuous_feature_indices is derived
  if self.scaler is None and continuous_feature_indices:
@@ -253,26 +248,42 @@ class DatasetMaker(_BaseDatasetMaker):
  pandas_df: pandas.DataFrame,
  schema: FeatureSchema,
  kind: Literal["regression", "classification"],
+ scaler: Union[Literal["fit"], Literal["none"], PytorchScaler],
  test_size: float = 0.2,
- random_state: int = 42,
- scaler: Optional[PytorchScaler] = None):
+ random_state: int = 42):
  """
  Args:
  pandas_df (pandas.DataFrame):
  The pre-processed input DataFrame containing all columns (features and single target).
  schema (FeatureSchema):
  The definitive schema object from data_exploration.
- kind (Literal["regression", "classification"]):
+ kind ("regression" | "classification"):
  The type of ML task. This determines the data type of the labels.
+ scaler ("fit" | "none" | PytorchScaler):
+ Strategy for data scaling:
+ - "fit": Fit a new PytorchScaler on continuous features.
+ - "none": Do not scale data (e.g., for TabularTransformer).
+ - PytorchScaler instance: Use a pre-fitted scaler to transform data.
  test_size (float):
  The proportion of the dataset to allocate to the test split.
  random_state (int):
  The seed for the random number generator for reproducibility.
- scaler (PytorchScaler | None):
- A pre-fitted PytorchScaler instance, if None a new scaler will be created.
+
  """
  super().__init__()
- self.scaler = scaler
+
+ _apply_scaling: bool = False
+ if scaler == "fit":
+ self.scaler = None # To be created
+ _apply_scaling = True
+ elif scaler == "none":
+ self.scaler = None
+ elif isinstance(scaler, PytorchScaler):
+ self.scaler = scaler # Use the provided one
+ _apply_scaling = True
+ else:
+ _LOGGER.error(f"Invalid 'scaler' argument. Must be 'fit', 'none', or a PytorchScaler instance.")
+ raise ValueError()
 
  # --- 1. Identify features (from schema) ---
  self._feature_names = list(schema.feature_names)
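
For orientation, here is a minimal usage sketch of the new required `scaler` argument. The import path and the schema construction are assumptions (the dataset module's file header and the data_exploration step are not shown in this excerpt); only the `DatasetMaker` signature above is taken from the diff, and `DatasetMakerMulti` further down accepts the same three strategies.

```python
# Hedged sketch only: the module path and the build_feature_schema() helper are hypothetical;
# the three 'scaler' strategies ("fit" | "none" | PytorchScaler) come from the diff above.
import pandas
from ml_tools.ML_datasetmaster import DatasetMaker   # assumed module name (file header not shown in this diff)

df = pandas.read_csv("training_data.csv")             # pre-processed features + a single target column
schema = build_feature_schema(df)                     # hypothetical helper standing in for the data_exploration step

maker = DatasetMaker(df, schema=schema, kind="regression", scaler="fit")     # fit a new PytorchScaler
# maker = DatasetMaker(df, schema=schema, kind="regression", scaler="none")  # skip scaling (e.g. for TabularTransformer)
# maker = DatasetMaker(df, schema=schema, kind="regression", scaler=prefit)  # reuse a pre-fitted PytorchScaler instance
```
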
@@ -310,9 +321,14 @@ class DatasetMaker(_BaseDatasetMaker):
  label_dtype = torch.float32 if kind == "regression" else torch.int64
 
  # --- 4. Scale (using the schema) ---
- X_train_final, X_test_final = self._prepare_scaler(
- X_train, y_train, X_test, label_dtype, schema
- )
+ if _apply_scaling:
+ X_train_final, X_test_final = self._prepare_scaler(
+ X_train, y_train, X_test, label_dtype, schema
+ )
+ else:
+ _LOGGER.info("Features have not been scaled as specified.")
+ X_train_final = X_train.to_numpy()
+ X_test_final = X_test.to_numpy()
 
  # --- 5. Create Datasets ---
  self._train_ds = _PytorchDataset(X_train_final, y_train, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=self._target_names)
@@ -336,9 +352,9 @@ class DatasetMakerMulti(_BaseDatasetMaker):
  pandas_df: pandas.DataFrame,
  target_columns: List[str],
  schema: FeatureSchema,
+ scaler: Union[Literal["fit"], Literal["none"], PytorchScaler],
  test_size: float = 0.2,
- random_state: int = 42,
- scaler: Optional[PytorchScaler] = None):
+ random_state: int = 42):
  """
  Args:
  pandas_df (pandas.DataFrame):
@@ -348,20 +364,35 @@ class DatasetMakerMulti(_BaseDatasetMaker):
  List of target column names.
  schema (FeatureSchema):
  The definitive schema object from data_exploration.
+ scaler ("fit" | "none" | PytorchScaler):
+ Strategy for data scaling:
+ - "fit": Fit a new PytorchScaler on continuous features.
+ - "none": Do not scale data (e.g., for TabularTransformer).
+ - PytorchScaler instance: Use a pre-fitted scaler to transform data.
  test_size (float):
  The proportion of the dataset to allocate to the test split.
  random_state (int):
  The seed for the random number generator for reproducibility.
- scaler (PytorchScaler | None):
- A pre-fitted PytorchScaler instance.
 
  ## Note:
  For multi-binary classification, the most common PyTorch loss function is nn.BCEWithLogitsLoss.
  This loss function requires the labels to be torch.float32 which is the same type required for regression (multi-regression) tasks.
  """
  super().__init__()
- self.scaler = scaler
-
+
+ _apply_scaling: bool = False
+ if scaler == "fit":
+ self.scaler = None
+ _apply_scaling = True
+ elif scaler == "none":
+ self.scaler = None
+ elif isinstance(scaler, PytorchScaler):
+ self.scaler = scaler # Use the provided one
+ _apply_scaling = True
+ else:
+ _LOGGER.error(f"Invalid 'scaler' argument. Must be 'fit', 'none', or a PytorchScaler instance.")
+ raise ValueError()
+
  # --- 1. Get features and targets from schema/args ---
  self._feature_names = list(schema.feature_names)
  self._target_names = target_columns
@@ -403,9 +434,14 @@ class DatasetMakerMulti(_BaseDatasetMaker):
  label_dtype = torch.float32
 
  # --- 4. Scale (using the schema) ---
- X_train_final, X_test_final = self._prepare_scaler(
- X_train, y_train, X_test, label_dtype, schema
- )
+ if _apply_scaling:
+ X_train_final, X_test_final = self._prepare_scaler(
+ X_train, y_train, X_test, label_dtype, schema
+ )
+ else:
+ _LOGGER.info("Features have not been scaled as specified.")
+ X_train_final = X_train.to_numpy()
+ X_test_final = X_test.to_numpy()
 
  # --- 5. Create Datasets ---
  # _PytorchDataset now correctly handles y_train (a DataFrame)
@@ -432,149 +468,6 @@ class _BaseMaker(ABC):
  pass
 
 
- # --- VisionDatasetMaker ---
- class VisionDatasetMaker(_BaseMaker):
- """
- Creates processed PyTorch datasets for computer vision tasks from an
- image folder directory.
-
- Uses online augmentations per epoch (image augmentation without creating new files).
- """
- def __init__(self, full_dataset: ImageFolder):
- super().__init__()
- self.full_dataset = full_dataset
- self.labels = [s[1] for s in self.full_dataset.samples]
- self.class_map = full_dataset.class_to_idx
-
- self._is_split = False
- self._are_transforms_configured = False
-
- @classmethod
- def from_folder(cls, root_dir: str) -> 'VisionDatasetMaker':
- """Creates a maker instance from a root directory of images."""
- initial_transform = transforms.Compose([transforms.ToTensor()])
- full_dataset = ImageFolder(root=root_dir, transform=initial_transform)
- _LOGGER.info(f"Found {len(full_dataset)} images in {len(full_dataset.classes)} classes.")
- return cls(full_dataset)
-
- @staticmethod
- def inspect_folder(path: Union[str, Path]):
- """
- Logs a report of the types, sizes, and channels of image files
- found in the directory and its subdirectories.
- """
- path_obj = make_fullpath(path)
-
- non_image_files = set()
- img_types = set()
- img_sizes = set()
- img_channels = set()
- img_counter = 0
-
- _LOGGER.info(f"Inspecting folder: {path_obj}...")
- # Use rglob to recursively find all files
- for filepath in path_obj.rglob('*'):
- if filepath.is_file():
- try:
- # Using PIL to open is a more reliable check
- with Image.open(filepath) as img:
- img_types.add(img.format)
- img_sizes.add(img.size)
- img_channels.update(img.getbands())
- img_counter += 1
- except (IOError, SyntaxError):
- non_image_files.add(filepath.name)
-
- if non_image_files:
- _LOGGER.warning(f"Non-image or corrupted files found and ignored: {non_image_files}")
-
- report = (
- f"\n--- Inspection Report for '{path_obj.name}' ---\n"
- f"Total images found: {img_counter}\n"
- f"Image formats: {img_types or 'None'}\n"
- f"Image sizes (WxH): {img_sizes or 'None'}\n"
- f"Image channels (bands): {img_channels or 'None'}\n"
- f"--------------------------------------"
- )
- print(report)
-
- def split_data(self, val_size: float = 0.2, test_size: float = 0.0,
- stratify: bool = True, random_state: Optional[int] = None) -> 'VisionDatasetMaker':
- """Splits the dataset into training, validation, and optional test sets."""
- if self._is_split:
- _LOGGER.warning("Data has already been split.")
- return self
-
- if val_size + test_size >= 1.0:
- _LOGGER.error("The sum of val_size and test_size must be less than 1.")
- raise ValueError()
-
- indices = list(range(len(self.full_dataset)))
- labels_for_split = self.labels if stratify else None
-
- train_indices, val_test_indices = train_test_split(
- indices, test_size=(val_size + test_size), random_state=random_state, stratify=labels_for_split
- )
-
- if test_size > 0:
- val_test_labels = [self.labels[i] for i in val_test_indices]
- stratify_val_test = val_test_labels if stratify else None
- val_indices, test_indices = train_test_split(
- val_test_indices, test_size=(test_size / (val_size + test_size)),
- random_state=random_state, stratify=stratify_val_test
- )
- self._test_dataset = Subset(self.full_dataset, test_indices)
- _LOGGER.info(f"Test set created with {len(self._test_dataset)} images.")
- else:
- val_indices = val_test_indices
-
- self._train_dataset = Subset(self.full_dataset, train_indices)
- self._val_dataset = Subset(self.full_dataset, val_indices)
- self._is_split = True
-
- _LOGGER.info(f"Data split into: \n- Training: {len(self._train_dataset)} images \n- Validation: {len(self._val_dataset)} images")
- return self
-
- def configure_transforms(self, resize_size: int = 256, crop_size: int = 224,
- mean: List[float] = [0.485, 0.456, 0.406],
- std: List[float] = [0.229, 0.224, 0.225],
- extra_train_transforms: Optional[List] = None) -> 'VisionDatasetMaker':
- """Configures and applies the image transformations (augmentations)."""
- if not self._is_split:
- _LOGGER.error("Transforms must be configured AFTER splitting data. Call .split_data() first.")
- raise RuntimeError()
-
- base_train_transforms = [transforms.RandomResizedCrop(crop_size), transforms.RandomHorizontalFlip()]
- if extra_train_transforms:
- base_train_transforms.extend(extra_train_transforms)
-
- final_transforms = [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]
-
- val_transform = transforms.Compose([transforms.Resize(resize_size), transforms.CenterCrop(crop_size), *final_transforms])
- train_transform = transforms.Compose([*base_train_transforms, *final_transforms])
-
- self._train_dataset.dataset.transform = train_transform # type: ignore
- self._val_dataset.dataset.transform = val_transform # type: ignore
- if self._test_dataset:
- self._test_dataset.dataset.transform = val_transform # type: ignore
-
- self._are_transforms_configured = True
- _LOGGER.info("Image transforms configured and applied.")
- return self
-
- def get_datasets(self) -> Tuple[Dataset, ...]:
- """Returns the final train, validation, and optional test datasets."""
- if not self._is_split:
- _LOGGER.error("Data has not been split. Call .split_data() first.")
- raise RuntimeError()
- if not self._are_transforms_configured:
- _LOGGER.warning("Transforms have not been configured. Using default ToTensor only.")
-
- if self._test_dataset:
- return self._train_dataset, self._val_dataset, self._test_dataset
- return self._train_dataset, self._val_dataset
-
-
 
  # --- SequenceMaker ---
  class SequenceMaker(_BaseMaker):
  """
@@ -763,40 +656,5 @@ class SequenceMaker(_BaseMaker):
  return self._train_dataset, self._test_dataset
 
 
- # --- Custom Vision Transform Class ---
- class ResizeAspectFill:
- """
- Custom transformation to make an image square by padding it to match the
- longest side, preserving the aspect ratio. The image is finally centered.
-
- Args:
- pad_color (Union[str, int]): Color to use for the padding.
- Defaults to "black".
- """
- def __init__(self, pad_color: Union[str, int] = "black") -> None:
- self.pad_color = pad_color
-
- def __call__(self, image: Image.Image) -> Image.Image:
- if not isinstance(image, Image.Image):
- _LOGGER.error(f"Expected PIL.Image.Image, got {type(image).__name__}")
- raise TypeError()
-
- w, h = image.size
- if w == h:
- return image
-
- # Determine padding to center the image
- if w > h:
- top_padding = (w - h) // 2
- bottom_padding = w - h - top_padding
- padding = (0, top_padding, 0, bottom_padding)
- else: # h > w
- left_padding = (h - w) // 2
- right_padding = h - w - left_padding
- padding = (left_padding, 0, right_padding, 0)
-
- return ImageOps.expand(image, padding, fill=self.pad_color)
-
-
  def info():
  _script_info(__all__)
ml_tools/ML_evaluation.py CHANGED
@@ -24,7 +24,7 @@ import warnings
  from .path_manager import make_fullpath
  from ._logger import _LOGGER
  from ._script_info import _script_info
- from .keys import SHAPKeys
+ from .keys import SHAPKeys, PyTorchLogKeys
 
 
  __all__ = [
@@ -44,8 +44,8 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
  history (dict): A dictionary containing 'train_loss' and 'val_loss'.
  save_dir (str | Path): Directory to save the plot image.
  """
- train_loss = history.get('train_loss', [])
- val_loss = history.get('val_loss', [])
+ train_loss = history.get(PyTorchLogKeys.TRAIN_LOSS, [])
+ val_loss = history.get(PyTorchLogKeys.VAL_LOSS, [])
 
  if not train_loss and not val_loss:
  print("Warning: Loss history is empty or incomplete. Cannot plot.")
@@ -258,7 +258,7 @@ def shap_summary_plot(model,
  feature_names: Optional[list[str]],
  save_dir: Union[str, Path],
  device: torch.device = torch.device('cpu'),
- explainer_type: Literal['deep', 'kernel'] = 'deep'):
+ explainer_type: Literal['deep', 'kernel'] = 'kernel'):
  """
  Calculates SHAP values and saves summary plots and data.
 
@@ -270,7 +270,7 @@ def shap_summary_plot(model,
  save_dir (str | Path): Directory to save SHAP artifacts.
  device (torch.device): The torch device for SHAP calculations.
  explainer_type (Literal['deep', 'kernel']): The explainer to use.
- - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient for
+ - 'deep': Uses shap.DeepExplainer. Fast and efficient for
  PyTorch models.
  - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY
  slow and memory-intensive.
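
Note the behavioral change above: the default `explainer_type` flips from 'deep' to 'kernel', so callers who relied on the faster DeepExplainer must now ask for it explicitly. A hedged call sketch (the parameters before `feature_names` are inferred from the function body shown further down, not from a visible signature):

```python
# Sketch only: 'model', 'background_data' and 'instances_to_explain' are placeholders;
# their exact parameter names are an inference from the body of shap_summary_plot.
from ml_tools.ML_evaluation import shap_summary_plot

shap_summary_plot(
    model,                      # trained PyTorch model
    background_data,            # tensor/ndarray used to build the explainer
    instances_to_explain,       # tensor/ndarray of rows to explain
    feature_names=list(schema.feature_names),
    save_dir="shap_artifacts",
    explainer_type="deep",      # 14.x default is now 'kernel'; pass 'deep' to keep the old, faster behavior
)
```
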
@@ -285,7 +285,7 @@ def shap_summary_plot(model,
  instances_to_explain_np = None
 
  if explainer_type == 'deep':
- # --- 1. Use DeepExplainer (Preferred) ---
+ # --- 1. Use DeepExplainer ---
 
  # Ensure data is torch.Tensor
  if isinstance(background_data, np.ndarray):
@@ -309,10 +309,9 @@
  instances_to_explain_np = instances_to_explain.cpu().numpy()
 
  elif explainer_type == 'kernel':
- # --- 2. Use KernelExplainer (Slow Fallback) ---
+ # --- 2. Use KernelExplainer ---
  _LOGGER.warning(
- "Using KernelExplainer. This is memory-intensive and slow. "
- "Consider reducing 'n_samples' if the process terminates unexpectedly."
+ "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
  )
 
  # Ensure data is np.ndarray
@@ -348,14 +347,26 @@ def shap_summary_plot(model,
  else:
  _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
  raise ValueError()
+
+ if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1:
+ # _LOGGER.info("Squeezing SHAP values from (N, F, 1) to (N, F) for regression plot.")
+ shap_values = shap_values.squeeze(-1)
 
  # --- 3. Plotting and Saving ---
  save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
  plt.ioff()
 
+ # Convert instances to a DataFrame; a robust way to ensure SHAP correctly maps values to feature names.
+ if feature_names is None:
+ # Create generic names if none were provided
+ num_features = instances_to_explain_np.shape[1]
+ feature_names = [f'feature_{i}' for i in range(num_features)]
+
+ instances_df = pd.DataFrame(instances_to_explain_np, columns=feature_names)
+
  # Save Bar Plot
  bar_path = save_dir_path / "shap_bar_plot.svg"
- shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="bar", show=False)
+ shap.summary_plot(shap_values, instances_df, plot_type="bar", show=False)
  ax = plt.gca()
  ax.set_xlabel("SHAP Value Impact", labelpad=10)
  plt.title("SHAP Feature Importance")
@@ -366,7 +377,7 @@ def shap_summary_plot(model,
 
  # Save Dot Plot
  dot_path = save_dir_path / "shap_dot_plot.svg"
- shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
+ shap.summary_plot(shap_values, instances_df, plot_type="dot", show=False)
  ax = plt.gca()
  ax.set_xlabel("SHAP Value Impact", labelpad=10)
  if plt.gcf().axes and len(plt.gcf().axes) > 1:
@@ -389,9 +400,6 @@ def shap_summary_plot(model,
  mean_abs_shap = np.abs(shap_values).mean(axis=0)
 
  mean_abs_shap = mean_abs_shap.flatten()
-
- if feature_names is None:
- feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
 
  summary_df = pd.DataFrame({
  SHAPKeys.FEATURE_COLUMN: feature_names,
@@ -401,7 +409,7 @@ def shap_summary_plot(model,
  summary_df.to_csv(summary_path, index=False)
 
  _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
- plt.ion()
+ plt.ion()
 
 
  def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path], top_n: int = 10):
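
The two additions above are easy to check in isolation: squeezing a trailing singleton axis turns the (N, F, 1) array that some explainers return for a single regression output into the (N, F) matrix the summary plots expect, and wrapping the explained instances in a DataFrame lets SHAP pick the feature names up from the columns. A self-contained illustration using only numpy and pandas:

```python
import numpy as np
import pandas as pd

n_samples, n_features = 4, 3
shap_values = np.random.rand(n_samples, n_features, 1)  # shape (N, F, 1), single regression output

# Same guard as in the diff: only squeeze when the trailing axis is a singleton.
if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1:
    shap_values = shap_values.squeeze(-1)
print(shap_values.shape)  # -> (4, 3)

# A DataFrame carries the feature names with the data, so
# shap.summary_plot(shap_values, instances_df, ...) no longer needs feature_names=.
instances_np = np.random.rand(n_samples, n_features)
feature_names = [f"feature_{i}" for i in range(n_features)]
instances_df = pd.DataFrame(instances_np, columns=feature_names)
print(list(instances_df.columns))  # -> ['feature_0', 'feature_1', 'feature_2']
```
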
@@ -235,7 +235,7 @@ def multi_target_shap_summary_plot(
  target_names: List[str],
  save_dir: Union[str, Path],
  device: torch.device = torch.device('cpu'),
- explainer_type: Literal['deep', 'kernel'] = 'deep'
+ explainer_type: Literal['deep', 'kernel'] = 'kernel'
  ):
  """
  Calculates SHAP values for a multi-target model and saves summary plots and data for each target.
@@ -249,7 +249,7 @@ def multi_target_shap_summary_plot(
  save_dir (str | Path): Directory to save SHAP artifacts.
  device (torch.device): The torch device for SHAP calculations.
  explainer_type (Literal['deep', 'kernel']): The explainer to use.
- - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient.
+ - 'deep': Uses shap.DeepExplainer. Fast and efficient.
  - 'kernel': Uses shap.KernelExplainer. Model-agnostic but slow and memory-intensive.
  """
  _LOGGER.info(f"--- Multi-Target SHAP Value Explanation (Using: {explainer_type.upper()}Explainer) ---")
@@ -260,7 +260,7 @@ def multi_target_shap_summary_plot(
  instances_to_explain_np = None
 
  if explainer_type == 'deep':
- # --- 1. Use DeepExplainer (Preferred) ---
+ # --- 1. Use DeepExplainer ---
 
  # Ensure data is torch.Tensor
  if isinstance(background_data, np.ndarray):
@@ -285,10 +285,9 @@ def multi_target_shap_summary_plot(
  instances_to_explain_np = instances_to_explain.cpu().numpy()
 
  elif explainer_type == 'kernel':
- # --- 2. Use KernelExplainer (Slow Fallback) ---
+ # --- 2. Use KernelExplainer ---
  _LOGGER.warning(
- "Using KernelExplainer. This is memory-intensive and slow. "
- "Consider reducing 'n_samples' if the process terminates."
+ "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
  )
 
  # Convert all data to numpy
ml_tools/ML_inference.py CHANGED
@@ -82,7 +82,6 @@ class _BaseInferenceHandler(ABC):
  _LOGGER.warning("CUDA not available, switching to CPU.")
  device_lower = "cpu"
  elif device_lower == "mps" and not torch.backends.mps.is_available():
- # Your M-series Mac will appreciate this check!
  _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
  device_lower = "cpu"
  return torch.device(device_lower)
ml_tools/ML_models.py CHANGED
@@ -306,10 +306,10 @@ class TabularTransformer(nn.Module, _ArchitectureHandlerMixin):
  def __init__(self, *,
  schema: FeatureSchema,
  out_targets: int,
- embedding_dim: int = 32,
+ embedding_dim: int = 256,
  num_heads: int = 8,
  num_layers: int = 6,
- dropout: float = 0.1):
+ dropout: float = 0.2):
  """
  Args:
  schema (FeatureSchema):
@@ -317,14 +317,28 @@ class TabularTransformer(nn.Module, _ArchitectureHandlerMixin):
  out_targets (int):
  Number of output targets (1 for regression).
  embedding_dim (int):
- The dimension for all feature embeddings. Must be divisible
- by num_heads.
+ The dimension for all feature embeddings. Must be divisible by num_heads. Common values: (64, 128, 192, 256, etc.)
  num_heads (int):
- The number of heads in the multi-head attention mechanism.
+ The number of heads in the multi-head attention mechanism. Common values: (4, 8, 16)
  num_layers (int):
- The number of sub-encoder-layers in the transformer encoder.
+ The number of sub-encoder-layers in the transformer encoder. Common values: (4, 8, 12)
  dropout (float):
  The dropout value.
+
+ ## Note:
+
+ **Embedding Dimension:** "Width" of the model. It's the N-dimension vector that will be used to represent each one of the features.
+ - Each continuous feature gets its own learnable N-dimension vector.
+ - Each categorical feature gets an embedding table that maps every category (e.g., "color=red", "color=blue") to a unique N-dimension vector.
+
+ **Attention Heads:** Controls the "Multi-Head Attention" mechanism. Instead of looking at all the feature interactions at once, the model splits its attention into N parallel heads.
+ - Embedding Dimensions get divided by the number of Attention Heads, resulting in the dimensions assigned per head.
+
+ **Number of Layers:** "Depth" of the model. Number of identical `TransformerEncoderLayer` blocks that are stacked on top of each other.
+ - Layer 1: The attention heads find simple, direct interactions between the features.
+ - Layer 2: Takes the output of Layer 1 and finds interactions between those interactions and so on.
+ - Trade-off: More layers are more powerful but are slower to train and more prone to overfitting. If the training loss goes down but the validation loss goes up, you might have too many layers (or need more dropout).
+
  """
  super().__init__()
 
@@ -734,5 +748,7 @@ class SequencePredictorLSTM(nn.Module, _ArchitectureHandlerMixin):
  )
 
 
+ # ---- PyTorch models ---
+
  def info():
  _script_info(__all__)
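
Returning to the `TabularTransformer` defaults changed above (embedding_dim 32 → 256, dropout 0.1 → 0.2): the per-head width is simply embedding_dim / num_heads, so the new defaults give 256 / 8 = 32 dimensions per attention head, and any embedding_dim that is not a multiple of num_heads is invalid. A hedged construction sketch (the import path follows the ml_tools/ML_models.py header; the FeatureSchema instance is assumed):

```python
# Sketch only: 'schema' must be a FeatureSchema produced elsewhere (not shown in this diff).
from ml_tools.ML_models import TabularTransformer

embedding_dim, num_heads = 256, 8
assert embedding_dim % num_heads == 0   # 256 / 8 = 32 dimensions per attention head

model = TabularTransformer(
    schema=schema,              # FeatureSchema describing continuous + categorical features
    out_targets=1,              # e.g. single-target regression
    embedding_dim=embedding_dim,
    num_heads=num_heads,
    num_layers=6,               # depth: stacked TransformerEncoderLayer blocks
    dropout=0.2,
)
```
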