dragon-ml-toolbox 2.3.0__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ml_tools/ML_trainer.py ADDED
@@ -0,0 +1,344 @@
1
+ from typing import List, Literal, Union, Optional
2
+ from pathlib import Path
3
+ from torch.utils.data import DataLoader, Dataset
4
+ import torch
5
+ from torch import nn
6
+ import numpy as np
7
+
8
+ from .ML_callbacks import Callback, History, TqdmProgressBar
9
+ from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot
10
+ from .utilities import _script_info, LogKeys
11
+ from .logger import _LOGGER
12
+
13
+
14
# Public API of the trainer module.
__all__ = ["MyTrainer"]
17
+
18
+
19
class MyTrainer:
    def __init__(self, model: nn.Module, train_dataset: Dataset, test_dataset: Dataset,
                 kind: Literal["regression", "classification"],
                 criterion: nn.Module, optimizer: torch.optim.Optimizer,
                 device: Union[Literal['cuda', 'mps', 'cpu'], str], dataloader_workers: int = 2,
                 callbacks: Optional[List[Callback]] = None):
        """
        Automates the training process of a PyTorch Model.

        Built-in Callbacks: `History`, `TqdmProgressBar`

        Args:
            model (nn.Module): The PyTorch model to train.
            train_dataset (Dataset): The training dataset.
            test_dataset (Dataset): The testing/validation dataset.
            kind (str): The type of task, 'regression' or 'classification'.
            criterion (nn.Module): The loss function.
            optimizer (torch.optim.Optimizer): The optimizer.
            device (str): The device to run training on ('cpu', 'cuda', 'mps'). Matching is
                case-insensitive and indexed CUDA devices such as 'cuda:1' are accepted.
                Falls back to CPU with a warning if the requested accelerator is unavailable.
            dataloader_workers (int): Subprocesses for data loading. Defaults to 2.
                Forced to 0 on MPS devices.
            callbacks (List[Callback] | None): Extra callbacks; they run after the built-in ones.

        Raises:
            TypeError: If `kind` is not 'regression' or 'classification'.

        Note:
            For **regression** tasks, suggested criterions include `nn.MSELoss` or `nn.L1Loss`.

            For **classification** tasks, `nn.CrossEntropyLoss` (multi-class) or `nn.BCEWithLogitsLoss` (binary) are common choices.
        """
        if kind not in ["regression", "classification"]:
            raise TypeError("Kind must be 'regression' or 'classification'.")

        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.kind = kind
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = self._validate_device(device)
        self.dataloader_workers = dataloader_workers

        # Callback handler - History and TqdmProgressBar are added by default
        default_callbacks = [History(), TqdmProgressBar()]
        user_callbacks = callbacks if callbacks is not None else []
        self.callbacks = default_callbacks + user_callbacks
        self._set_trainer_on_callbacks()

        # Internal state
        self.train_loader = None
        self.test_loader = None
        self.history = {}
        self.epoch = 0
        self.epochs = 0  # Total epochs for the fit run
        self.stop_training = False

    def _validate_device(self, device: Union[str, torch.device]) -> torch.device:
        """
        Validate the requested device and return it as a `torch.device`.

        Falls back to CPU (with a logged warning) when CUDA or MPS is requested
        but not available on this machine.
        """
        # Normalise once and use the normalised string for BOTH the availability
        # checks and the returned device: torch.device() only understands
        # lower-case device types, so 'CUDA' / 'CPU' would otherwise fail.
        device_str = str(device).lower()
        if "cuda" in device_str and not torch.cuda.is_available():
            _LOGGER.warning("CUDA not available, switching to CPU.")
            device_str = "cpu"
        elif device_str == "mps" and not torch.backends.mps.is_available():
            _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
            device_str = "cpu"
        return torch.device(device_str)

    def _set_trainer_on_callbacks(self):
        """Gives each callback a reference to this trainer instance."""
        for callback in self.callbacks:
            callback.set_trainer(self)

    def _create_dataloaders(self, batch_size: int, shuffle: bool):
        """Initializes the train and test DataLoaders (the test loader is never shuffled)."""
        # Ensure stability on MPS devices by setting num_workers to 0
        loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers

        self.train_loader = DataLoader(
            dataset=self.train_dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=loader_workers,
            pin_memory=(self.device.type == "cuda")
        )
        self.test_loader = DataLoader(
            dataset=self.test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=loader_workers,
            pin_memory=(self.device.type == "cuda")
        )

    def fit(self, epochs: int = 10, batch_size: int = 32, shuffle: bool = True):
        """
        Starts the training-validation process of the model.

        Args:
            epochs (int): The total number of epochs to train for.
            batch_size (int): The number of samples per batch.
            shuffle (bool): Whether to shuffle the training data at each epoch.

        Returns:
            The trainer's `history` attribute after training.
        """
        self.epochs = epochs
        self._create_dataloaders(batch_size, shuffle)
        self.model.to(self.device)

        # Reset stop_training flag on the trainer
        self.stop_training = False

        self.callbacks_hook('on_train_begin')

        for epoch in range(1, self.epochs + 1):
            self.epoch = epoch
            epoch_logs = {}
            self.callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)

            train_logs = self._train_step()
            epoch_logs.update(train_logs)

            val_logs = self._validation_step()
            epoch_logs.update(val_logs)

            self.callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)

            # A callback (e.g. early stopping) may set this flag to end training early.
            if self.stop_training:
                break

        self.callbacks_hook('on_train_end')
        return self.history

    def _compute_loss(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Apply the criterion, reshaping the output to the target's shape for MSE/L1 losses."""
        # Element-wise regression losses would otherwise broadcast mismatched
        # shapes such as (N, 1) vs (N,) and produce a wrong loss value.
        if isinstance(self.criterion, (nn.MSELoss, nn.L1Loss)):
            output = output.view_as(target)
        return self.criterion(output, target)

    @staticmethod
    def _mean_loss(total_loss: float, n_samples: int, split: str) -> float:
        """Return the sample-weighted mean loss, failing clearly if no samples were seen."""
        if n_samples == 0:
            raise ValueError(f"Cannot compute the {split} loss: the {split} data yielded no samples.")
        return total_loss / n_samples

    def _train_step(self):
        """Run one training epoch and return `{LogKeys.TRAIN_LOSS: mean loss per sample}`."""
        self.model.train()
        running_loss = 0.0
        seen_samples = 0
        for batch_idx, (features, target) in enumerate(self.train_loader):  # type: ignore
            batch_size = features.size(0)
            batch_logs = {
                LogKeys.BATCH_INDEX: batch_idx,
                LogKeys.BATCH_SIZE: batch_size,
            }
            self.callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)

            features, target = features.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            output = self.model(features)
            loss = self._compute_loss(output, target)
            loss.backward()
            self.optimizer.step()

            # Weight each batch's mean loss by its size so a smaller final batch
            # does not skew the epoch average.
            batch_loss = loss.item()
            running_loss += batch_loss * batch_size
            seen_samples += batch_size

            batch_logs[LogKeys.BATCH_LOSS] = batch_loss
            self.callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)

        # Average over the samples actually seen (works for datasets without __len__).
        return {LogKeys.TRAIN_LOSS: self._mean_loss(running_loss, seen_samples, "training")}

    def _validation_step(self):
        """Evaluate on the test loader without gradients and return `{LogKeys.VAL_LOSS: mean loss}`."""
        self.model.eval()
        running_loss = 0.0
        seen_samples = 0
        with torch.no_grad():
            for features, target in self.test_loader:  # type: ignore
                batch_size = features.size(0)
                features, target = features.to(self.device), target.to(self.device)
                output = self.model(features)
                loss = self._compute_loss(output, target)
                running_loss += loss.item() * batch_size
                seen_samples += batch_size
        return {LogKeys.VAL_LOSS: self._mean_loss(running_loss, seen_samples, "validation")}

    @staticmethod
    def _classification_outputs(logits: torch.Tensor):
        """Convert a batch of CPU logits into (predicted labels, class probabilities) arrays."""
        if logits.dim() == 1 or logits.size(1) == 1:
            # Single-logit binary model (e.g. trained with BCEWithLogitsLoss):
            # softmax over one column is always 1.0 and argmax always 0, so use a
            # sigmoid and expose both class probabilities as columns [P(0), P(1)].
            positive = torch.sigmoid(logits.reshape(-1))
            probs = torch.stack([1.0 - positive, positive], dim=1)
        else:
            probs = nn.functional.softmax(logits, dim=1)
        preds = torch.argmax(probs, dim=1)
        return preds.numpy(), probs.numpy()

    def predict(self, dataloader: DataLoader):
        """
        Yields model predictions batch by batch, avoids loading all predictions into memory at once.

        Args:
            dataloader (DataLoader): The dataloader to predict on; each batch must be a
                `(features, target)` pair.

        Yields:
            tuple: A tuple containing (y_pred_batch, y_prob_batch, y_true_batch).
                For classification, y_prob_batch has one column per class; single-logit
                (binary) models yield two columns [P(class 0), P(class 1)].
                y_prob_batch is None for regression tasks.
        """
        self.model.eval()
        self.model.to(self.device)
        with torch.no_grad():
            for features, target in dataloader:
                features = features.to(self.device)
                output = self.model(features).cpu()
                y_true_batch = target.numpy()

                if self.kind == "classification":
                    y_pred_batch, y_prob_batch = self._classification_outputs(output)
                else:
                    y_pred_batch = output.numpy()
                    y_prob_batch = None

                yield y_pred_batch, y_prob_batch, y_true_batch

    def evaluate(self, data: Optional[Union[DataLoader, Dataset]] = None, save_dir: Optional[Union[str, Path]] = None):
        """
        Evaluates the model on the given data.

        Args:
            data (DataLoader | Dataset | None ): The data to evaluate on.
                Can be a DataLoader or a Dataset. If None, defaults to the trainer's internal test_dataset.
            save_dir (str | Path | None): Directory to save all reports and plots. If None, metrics are shown but not saved.

        Raises:
            ValueError: If no usable data is available, or the data yields no batches.
        """
        if isinstance(data, DataLoader):
            eval_loader = data
        else:
            # Determine which dataset to use (the one passed in, or the default test_dataset)
            dataset_to_use = data if data is not None else self.test_dataset
            if not isinstance(dataset_to_use, Dataset):
                raise ValueError("Cannot evaluate. No valid DataLoader or Dataset was provided, "
                                 "and no test_dataset is available in the trainer.")

            eval_loader = DataLoader(
                dataset=dataset_to_use,
                batch_size=32,  # A sensible default for evaluation
                shuffle=False,
                num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                pin_memory=(self.device.type == "cuda")
            )

        print("\n--- Model Evaluation ---")

        # Collect results from the predict generator
        all_preds, all_probs, all_true = [], [], []
        for y_pred_b, y_prob_b, y_true_b in self.predict(eval_loader):
            all_preds.append(y_pred_b)
            if y_prob_b is not None:
                all_probs.append(y_prob_b)
            all_true.append(y_true_b)

        # np.concatenate([]) would raise an opaque error; fail with a clear message instead.
        if not all_preds:
            raise ValueError("Cannot evaluate. The evaluation data yielded no batches.")

        y_pred = np.concatenate(all_preds)
        y_true = np.concatenate(all_true)
        y_prob = np.concatenate(all_probs) if self.kind == "classification" else None

        if self.kind == "classification":
            classification_metrics(y_true, y_pred, y_prob, save_dir=save_dir)
        else:
            regression_metrics(y_true.flatten(), y_pred.flatten(), save_dir=save_dir)

        print("\n--- Training History ---")
        plot_losses(self.history, save_dir=save_dir)

    def explain(self, explain_dataset: Optional[Dataset] = None, n_samples: int = 100,
                feature_names: Optional[List[str]] = None, save_dir: Optional[Union[str, Path]] = None):
        """
        Explains model predictions using SHAP and saves all artifacts.

        The background data is automatically sampled from the trainer's training dataset.

        Args:
            explain_dataset (Dataset, optional): A specific dataset to explain.
                If None, the trainer's test dataset is used.
            n_samples (int): The number of samples to use for both background and explanation.
            feature_names (List[str], optional): Names for the features.
            save_dir (str | Path, optional): Directory to save all SHAP artifacts.
        """
        from torch.utils.data import Subset

        def _get_random_sample(dataset: Optional[Dataset], num_samples: int):
            """Return up to `num_samples` randomly chosen feature rows of `dataset`, or None if empty."""
            if dataset is None:
                return None

            # When the dataset size is known, pick the random indices first so only
            # the sampled rows are loaded, instead of materialising the whole dataset.
            try:
                dataset_size = len(dataset)  # type: ignore[arg-type]
            except TypeError:
                dataset_size = None
            subsampled = False
            if dataset_size is not None and 0 < num_samples < dataset_size:
                sampled_indices = torch.randperm(dataset_size)[:num_samples].tolist()
                dataset = Subset(dataset, sampled_indices)
                subsampled = True

            # For MPS devices, num_workers must be 0 to ensure stability
            loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
            loader = DataLoader(
                dataset,
                batch_size=64,
                shuffle=False,
                num_workers=loader_workers
            )

            all_features = [features for features, _ in loader]
            if not all_features:
                return None
            full_data = torch.cat(all_features, dim=0)

            if subsampled or num_samples >= full_data.size(0):
                return full_data

            # Fallback for datasets without a known length: subsample after loading.
            rand_indices = torch.randperm(full_data.size(0))[:num_samples]
            return full_data[rand_indices]

        print(f"\n--- Preparing SHAP Data (sampling up to {n_samples} instances) ---")

        # 1. Get background data from the trainer's train_dataset
        background_data = _get_random_sample(self.train_dataset, n_samples)
        if background_data is None:
            print("Warning: Trainer's train_dataset is empty or invalid. Skipping SHAP analysis.")
            return

        # 2. Determine target dataset and get explanation instances
        target_dataset = explain_dataset if explain_dataset is not None else self.test_dataset
        instances_to_explain = _get_random_sample(target_dataset, n_samples)
        if instances_to_explain is None:
            print("Warning: Explanation dataset is empty or invalid. Skipping SHAP analysis.")
            return

        # 3. Call the plotting function
        shap_summary_plot(
            model=self.model,
            background_data=background_data,
            instances_to_explain=instances_to_explain,
            feature_names=feature_names,
            save_dir=save_dir
        )

    def callbacks_hook(self, method_name: str, *args, **kwargs):
        """Calls the method named `method_name` on every callback, in registration order."""
        for callback in self.callbacks:
            method = getattr(callback, method_name)
            method(*args, **kwargs)
342
+
343
def info():
    """Delegate to `_script_info`, passing this module's public names (`__all__`)."""
    public_api = __all__
    _script_info(public_api)
@@ -0,0 +1,300 @@
1
+ import json
2
+ from typing import Literal, Optional, Union
3
+ from pathlib import Path
4
+ from .logger import _LOGGER
5
+ from .utilities import make_fullpath, sanitize_filename
6
+
7
+
8
# Public API of the notebook-generator module.
__all__ = ["generate_notebook"]
11
+
12
+ def _get_notebook_content(kind: str):
13
+ """Helper function to generate the cell content for the notebook."""
14
+
15
+ # --- Common Cells ---
16
+ imports_cell = {
17
+ "cell_type": "code",
18
+ "source": [
19
+ "import torch\n",
20
+ "from torch import nn\n",
21
+ "from torch.utils.data import TensorDataset, DataLoader\n",
22
+ "import numpy as np\n",
23
+ "from pathlib import Path\n",
24
+ "\n",
25
+ "# Import from dragon_ml_toolbox\n",
26
+ "from ml_tools.ML_trainer import MyTrainer\n",
27
+ "from ml_tools.ML_callbacks import EarlyStopping, ModelCheckpoint"
28
+ "from ml_tools.utilities import LogKeys"
29
+ ]
30
+ }
31
+
32
+ device_cell = {
33
+ "cell_type": "code",
34
+ "source": [
35
+ "import torch\\n",
36
+ "if torch.cuda.is_available():\\n",
37
+ " device = 'cuda'\\n",
38
+ "elif torch.backends.mps.is_available():\\n",
39
+ " device = 'mps'\\n",
40
+ "else:\\n",
41
+ " device = 'cpu'\\n",
42
+ "\\n",
43
+ "print(f'Using device: {device}')"
44
+ ]
45
+ }
46
+
47
+ model_definition_cell = {
48
+ "cell_type": "markdown",
49
+ "source": [
50
+ "### 3. Define the Model, Criterion, and Optimizer\n",
51
+ "Next, we define a simple neural network for our task. We also need to choose a loss function (`criterion`) and an `optimizer`."
52
+ ]
53
+ }
54
+
55
+ callbacks_cell = {
56
+ "cell_type": "code",
57
+ "source": [
58
+ "# Define callbacks for training\n",
59
+ "model_filepath = 'best_model.pth'\n",
60
+ "monitor_metric = LogKeys.VAL_LOSS\n",
61
+ "\n",
62
+ "model_checkpoint = ModelCheckpoint(\n",
63
+ " filepath=model_filepath, \n",
64
+ " save_best_only=True, \n",
65
+ " monitor=monitor_metric, \n",
66
+ " mode='min'\n",
67
+ ")\n",
68
+ "\n",
69
+ "early_stopping = EarlyStopping(\n",
70
+ " patience=10, \n",
71
+ " monitor=monitor_metric, \n",
72
+ " mode='min'\n",
73
+ ")"
74
+ ]
75
+ }
76
+
77
+ trainer_instantiation_cell = {
78
+ "cell_type": "code",
79
+ "source": [
80
+ "trainer = MyTrainer(\n",
81
+ " model=model,\n",
82
+ " train_dataset=train_dataset,\n",
83
+ " test_dataset=test_dataset,\n",
84
+ f" kind='{kind}',\n",
85
+ " criterion=criterion,\n",
86
+ " optimizer=optimizer,\n",
87
+ " device=device,\\n",
88
+ " callbacks=[model_checkpoint, early_stopping]\n",
89
+ ")"
90
+ ]
91
+ }
92
+
93
+ fit_cell = {
94
+ "cell_type": "code",
95
+ "source": [
96
+ "history = trainer.fit(epochs=100, batch_size=16)"
97
+ ]
98
+ }
99
+
100
+ evaluation_cell = {
101
+ "cell_type": "code",
102
+ "source": [
103
+ "save_dir = Path('tutorial_results')\n",
104
+ "\n",
105
+ "# The evaluate method will automatically use the test_loader.\n",
106
+ "# First, we load the best weights saved by ModelCheckpoint.\n",
107
+ "model_path = Path(model_filepath)\n",
108
+ "if model_path.exists():\n",
109
+ " print(f'Loading best model from {model_path}')\n",
110
+ " trainer.model.load_state_dict(torch.load(model_path))\n",
111
+ "\n",
112
+ "print('\\n--- Evaluating Model ---')\n",
113
+ "# All evaluation artifacts will be saved in the 'evaluation' subdirectory.\n",
114
+ "trainer.evaluate(save_dir=save_dir / 'evaluation')"
115
+ ]
116
+ }
117
+
118
+ explanation_cell = {
119
+ "cell_type": "code",
120
+ "source": [
121
+ "print('\\n--- Explaining Model ---')\n",
122
+ "# We can also generate SHAP plots to explain the model's predictions.\n",
123
+ "# All SHAP artifacts will be saved in the 'explanation' subdirectory.\n",
124
+ "trainer.explain(\n",
125
+ " background_loader=trainer.train_loader,\n",
126
+ " explain_loader=trainer.test_loader,\n",
127
+ " save_dir=save_dir / 'explanation'\n",
128
+ ")"
129
+ ]
130
+ }
131
+
132
+
133
+ # --- Task-Specific Cells ---
134
+ if kind == 'classification':
135
+ title = "Classification Tutorial"
136
+ data_prep_source = [
137
+ "### 2. Prepare the Data\n",
138
+ "For this example, we'll generate some simple, linearly separable mock data for a binary classification task. We'll then wrap it in PyTorch `TensorDataset` objects."
139
+ ]
140
+ data_creation_source = [
141
+ "from sklearn.datasets import make_classification\n",
142
+ "from sklearn.model_selection import train_test_split\n",
143
+ "\n",
144
+ "X, y = make_classification(n_samples=200, n_features=10, n_informative=5, n_redundant=0, random_state=42)\n",
145
+ "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
146
+ "\n",
147
+ "# Convert to PyTorch tensors\n",
148
+ "X_train = torch.FloatTensor(X_train)\n",
149
+ "y_train = torch.LongTensor(y_train)\n",
150
+ "X_test = torch.FloatTensor(X_test)\n",
151
+ "y_test = torch.LongTensor(y_test)\n",
152
+ "\n",
153
+ "# Create TensorDatasets\n",
154
+ "train_dataset = TensorDataset(X_train, y_train)\n",
155
+ "test_dataset = TensorDataset(X_test, y_test)"
156
+ ]
157
+ model_creation_source = [
158
+ "class SimpleClassifier(nn.Module):\n",
159
+ " def __init__(self, input_features, num_classes):\n",
160
+ " super().__init__()\n",
161
+ " self.layer_1 = nn.Linear(input_features, 32)\n",
162
+ " self.layer_2 = nn.Linear(32, num_classes)\n",
163
+ " self.relu = nn.ReLU()\n",
164
+ " \n",
165
+ " def forward(self, x):\n",
166
+ " return self.layer_2(self.relu(self.layer_1(x)))\n",
167
+ "\n",
168
+ "model = SimpleClassifier(input_features=10, num_classes=2)\n",
169
+ "criterion = nn.CrossEntropyLoss()\n",
170
+ "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)"
171
+ ]
172
+
173
+ elif kind == 'regression':
174
+ title = "Regression Tutorial"
175
+ data_prep_source = [
176
+ "### 2. Prepare the Data\n",
177
+ "For this example, we'll generate some simple mock data for a regression task. We'll then wrap it in PyTorch `TensorDataset` objects."
178
+ ]
179
+ data_creation_source = [
180
+ "from sklearn.datasets import make_regression\n",
181
+ "from sklearn.model_selection import train_test_split\n",
182
+ "\n",
183
+ "X, y = make_regression(n_samples=200, n_features=5, noise=15, random_state=42)\n",
184
+ "y = y.reshape(-1, 1) # Reshape for compatibility with MSELoss\n",
185
+ "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
186
+ "\n",
187
+ "# Convert to PyTorch tensors\n",
188
+ "X_train = torch.FloatTensor(X_train)\n",
189
+ "y_train = torch.FloatTensor(y_train)\n",
190
+ "X_test = torch.FloatTensor(X_test)\n",
191
+ "y_test = torch.FloatTensor(y_test)\n",
192
+ "\n",
193
+ "# Create TensorDatasets\n",
194
+ "train_dataset = TensorDataset(X_train, y_train)\n",
195
+ "test_dataset = TensorDataset(X_test, y_test)"
196
+ ]
197
+ model_creation_source = [
198
+ "class SimpleRegressor(nn.Module):\n",
199
+ " def __init__(self, input_features, output_features):\n",
200
+ " super().__init__()\n",
201
+ " self.layer_1 = nn.Linear(input_features, 32)\n",
202
+ " self.layer_2 = nn.Linear(32, output_features)\n",
203
+ " self.relu = nn.ReLU()\n",
204
+ " \n",
205
+ " def forward(self, x):\n",
206
+ " return self.layer_2(self.relu(self.layer_1(x)))\n",
207
+ "\n",
208
+ "model = SimpleRegressor(input_features=5, output_features=1)\n",
209
+ "criterion = nn.MSELoss()\n",
210
+ "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)"
211
+ ]
212
+ else:
213
+ raise ValueError("kind must be 'classification' or 'regression'")
214
+
215
+ # --- Assemble Notebook ---
216
+ cells = [
217
+ {"cell_type": "markdown", "source": [f"# Dragon ML Toolbox - {title}\n", "This notebook demonstrates how to use the `MyTrainer` class for a complete training and evaluation workflow."]},
218
+ {"cell_type": "markdown", "source": ["### 1. Imports\n", "First, let's import all the necessary components."]},
219
+ imports_cell,
220
+ {"cell_type": "markdown", "source": data_prep_source},
221
+ {"cell_type": "code", "source": data_creation_source},
222
+ model_definition_cell,
223
+ {"cell_type": "code", "source": model_creation_source},
224
+ {"cell_type": "markdown", "source": ["### 4. Configure Callbacks\n", "We'll set up `ModelCheckpoint` to save the best model and `EarlyStopping` to prevent overfitting."]},
225
+ callbacks_cell,
226
+ {"cell_type": "markdown", "source": ["### 5. Initialize the Trainer\\n", "First, we'll determine the best device to run on. Then, we can instantiate `MyTrainer` with all our components."]},
227
+ device_cell,
228
+ trainer_instantiation_cell,
229
+ {"cell_type": "markdown", "source": ["### 6. Train the Model\n", "Call the `.fit()` method to start training."]},
230
+ fit_cell,
231
+ {"cell_type": "markdown", "source": ["### 7. Evaluate the Model\n", "Finally, call the `.evaluate()` method to see the performance report and save all plots and metrics."]},
232
+ evaluation_cell,
233
+ {"cell_type": "markdown", "source": ["### 8. Explain the Model\n", "We can also use the `.explain()` method to generate and save SHAP plots for model interpretability."]},
234
+ explanation_cell,
235
+ ]
236
+
237
+ # Add execution counts to code cells
238
+ for cell in cells:
239
+ if cell['cell_type'] == 'code':
240
+ cell['execution_count'] = None
241
+ cell['metadata'] = {}
242
+ cell['outputs'] = []
243
+
244
+ return cells
245
+
246
+
247
def generate_notebook(kind: Literal['classification', 'regression'] = 'classification', filepath: Optional[Union[str, Path]] = None):
    """
    Generates a tutorial Jupyter Notebook (.ipynb) file.

    This function creates a complete, runnable notebook with mock data,
    a simple model, and a full training/evaluation cycle using MyTrainer.

    Args:
        kind (str): The type of tutorial to generate, either 'classification' or 'regression'.
        filepath (str | Path | None): The path to save the notebook file.
            If None, defaults to 'classification_tutorial.ipynb' or
            'regression_tutorial.ipynb' in the current directory. The '.ipynb'
            suffix is appended when missing.

    Raises:
        ValueError: If `kind` is not 'classification' or 'regression'.

    Note:
        Errors raised while writing the file are logged rather than propagated.
    """
    if kind not in ["classification", "regression"]:
        raise ValueError("kind must be 'classification' or 'regression'")

    default_name = f"{kind}_tutorial.ipynb"
    if filepath is None:
        target = Path(default_name)
    else:
        raw_path = Path(filepath)
        # Sanitize only the file-name component, so directory separators in a
        # user-supplied path are not treated as invalid file-name characters.
        file_name = sanitize_filename(raw_path.name) if raw_path.name else default_name
        target = raw_path.parent / file_name

    # check suffix
    if not target.name.endswith(".ipynb"):
        target = target.with_name(target.name + ".ipynb")

    new_filepath = make_fullpath(str(target), make=True)

    # Log the resolved destination (the raw `filepath` argument may be None).
    _LOGGER.info(f"Generating {kind} tutorial notebook at: {new_filepath}")

    cells = _get_notebook_content(kind)

    notebook = {
        "cells": cells,
        "metadata": {
            "kernelspec": {
                "display_name": "Python 3",
                "language": "python",
                "name": "python3"
            },
            "language_info": {
                "name": "python",
                "version": "3.10.0"
            }
        },
        "nbformat": 4,
        "nbformat_minor": 2
    }

    try:
        with open(new_filepath, 'w', encoding='utf-8') as f:
            json.dump(notebook, f, indent=2)
        _LOGGER.info("Notebook generated successfully.")
    except Exception as e:
        # Best-effort file generation: report the failure instead of raising.
        _LOGGER.error(f"Error generating notebook: {e}")