pytorch-kito 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,298 @@
+ """
+ Base and example callbacks for image plotting to TensorBoard.
+
+ The BaseImagePlotter provides a foundation for custom image plotting callbacks.
+ """
+ from abc import abstractmethod
+ from typing import Optional, List
+
+ import matplotlib.pyplot as plt
+ from torch.utils.tensorboard import SummaryWriter
+
+ from kito.callbacks.callback_base import Callback
+
+
+ class BaseImagePlotter(Callback):
+     """
+     Base class for plotting images to TensorBoard.
+
+     Subclasses must implement `create_figure()` to define custom plotting logic.
+
+     Args:
+         log_dir: Directory for TensorBoard logs
+         tag: Tag for the image in TensorBoard
+         freq: Frequency of plotting (every N epochs)
+         batch_indices: Which batch indices to visualize (e.g., [0, 1, 2])
+
+     Example:
+         class MyCustomPlotter(BaseImagePlotter):
+             def create_figure(self, val_data, val_outputs, epoch, **kwargs):
+                 fig, ax = plt.subplots()
+                 # Your custom plotting logic here
+                 ax.imshow(val_outputs[0].cpu().numpy())
+                 return fig
+
+         plotter = MyCustomPlotter('logs/images', tag='predictions')
+     """
+
+     def __init__(
+         self,
+         log_dir: str,
+         tag: str = 'validation_images',
+         freq: int = 1,
+         batch_indices: Optional[List[int]] = None
+     ):
+         self.log_dir = log_dir
+         self.tag = tag
+         self.freq = freq
+         self.batch_indices = batch_indices or [0]
+         self.writer = None
+
+     def on_train_begin(self, engine, model, **kwargs):
+         """Initialize TensorBoard writer."""
+         self.writer = SummaryWriter(self.log_dir)
+
+     @abstractmethod
+     def create_figure(self, val_data, val_outputs, epoch, **kwargs):
+         """
+         Create matplotlib figure for visualization.
+
+         Must be implemented by subclasses.
+
+         Args:
+             val_data: Validation data (tuple of inputs and targets)
+             val_outputs: Model predictions on validation data
+             epoch: Current epoch number
+             **kwargs: Additional context from engine
+
+         Returns:
+             matplotlib.figure.Figure or list of figures
+         """
+         pass
+
+     def on_epoch_end(self, epoch, engine, model, logs=None, **kwargs):
+         """Plot images to TensorBoard."""
+         if epoch % self.freq != 0:
+             return
+
+         # Get validation data and outputs from kwargs
+         val_data = kwargs.get('val_data')
+         val_outputs = kwargs.get('val_outputs')
+
+         if val_data is None or val_outputs is None:
+             return
+
+         # Create figure(s)
+         figures = self.create_figure(val_data, val_outputs, epoch, **kwargs)
+
+         # Log to TensorBoard
+         if isinstance(figures, (list, tuple)):
+             # Multiple figures
+             for i, fig in enumerate(figures):
+                 idx = self.batch_indices[i] if i < len(self.batch_indices) else i
+                 self.writer.add_figure(
+                     f'{self.tag}_batch_{idx}',
+                     fig,
+                     global_step=epoch,
+                     close=True
+                 )
+         else:
+             # Single figure
+             self.writer.add_figure(
+                 self.tag,
+                 figures,
+                 global_step=epoch,
+                 close=True
+             )
+
+         self.writer.flush()
+
+     def on_train_end(self, engine, model, **kwargs):
+         """Close TensorBoard writer."""
+         if self.writer:
+             self.writer.close()
+
+
+ # ============================================================================
+ # EXAMPLE IMAGE PLOTTERS
+ # ============================================================================
+
+ class SimpleImagePlotter(BaseImagePlotter):
+     """
+     Simple image plotter for input/output comparison.
+
+     Shows input, ground truth, and prediction side by side.
+
+     Args:
+         log_dir: Directory for TensorBoard logs
+         tag: Tag for the image in TensorBoard
+         freq: Frequency of plotting (every N epochs)
+         batch_indices: Which batch indices to visualize
+         cmap: Colormap for matplotlib (default: 'viridis')
+
+     Example:
+         plotter = SimpleImagePlotter(
+             'logs/images',
+             tag='validation',
+             batch_indices=[0, 1, 2]
+         )
+     """
+
+     def __init__(
+         self,
+         log_dir: str,
+         tag: str = 'validation',
+         freq: int = 1,
+         batch_indices: Optional[List[int]] = None,
+         cmap: str = 'viridis'
+     ):
+         super().__init__(log_dir, tag, freq, batch_indices)
+         self.cmap = cmap
+
+     def create_figure(self, val_data, val_outputs, epoch, **kwargs):
+         """Create side-by-side comparison figure."""
+         # Extract input and target
+         val_input = val_data[0]  # (B, C, H, W) or (B, T, C, H, W)
+         val_target = val_data[1]
+
+         # Create figures for each batch index
+         figures = []
+         for idx in self.batch_indices:
+             if idx >= val_input.shape[0]:
+                 break
+
+             fig, axes = plt.subplots(1, 3, figsize=(15, 5))
+
+             # Get data for this batch index
+             input_img = self._prepare_image(val_input[idx])
+             target_img = self._prepare_image(val_target[idx])
+             pred_img = self._prepare_image(val_outputs[idx])
+
+             # Plot
+             axes[0].imshow(input_img, cmap=self.cmap)
+             axes[0].set_title('Input')
+             axes[0].axis('off')
+
+             axes[1].imshow(target_img, cmap=self.cmap)
+             axes[1].set_title('Ground Truth')
+             axes[1].axis('off')
+
+             axes[2].imshow(pred_img, cmap=self.cmap)
+             axes[2].set_title('Prediction')
+             axes[2].axis('off')
+
+             plt.tight_layout()
+             figures.append(fig)
+
+         return figures if len(figures) > 1 else figures[0]
+
+     def _prepare_image(self, tensor):
+         """
+         Prepare tensor for visualization.
+
+         Handles:
+         - Time series: (T, C, H, W) -> take middle frame
+         - Multi-channel: (C, H, W) -> take first channel
+         - Single channel: (1, H, W) -> squeeze
+         """
+         # Move to CPU and detach
+         img = tensor.detach().cpu()
+
+         # Handle time series
+         if img.ndim == 4:  # (T, C, H, W)
+             img = img[img.shape[0] // 2]  # Take middle frame
+
+         # Handle channels
+         if img.ndim == 3:  # (C, H, W)
+             if img.shape[0] == 1:
+                 img = img[0]  # Single channel
+             elif img.shape[0] == 3:
+                 img = img.permute(1, 2, 0)  # RGB
+             else:
+                 img = img[0]  # Take first channel
+
+         return img.numpy()
+
+
+ class DifferencePlotter(BaseImagePlotter):
+     """
+     Plot input, prediction, and difference (error) map.
+
+     Useful for regression tasks to visualize prediction errors.
+
+     Args:
+         log_dir: Directory for TensorBoard logs
+         tag: Tag for the image in TensorBoard
+         freq: Frequency of plotting (every N epochs)
+         batch_indices: Which batch indices to visualize
+
+     Example:
+         plotter = DifferencePlotter('logs/images', tag='errors')
+     """
+
+     def __init__(
+         self,
+         log_dir: str,
+         tag: str = 'difference',
+         freq: int = 1,
+         batch_indices: Optional[List[int]] = None
+     ):
+         super().__init__(log_dir, tag, freq, batch_indices)
+
+     def create_figure(self, val_data, val_outputs, epoch, **kwargs):
+         """Create figure with input, prediction, and difference."""
+         val_input = val_data[0]
+         val_target = val_data[1]
+
+         figures = []
+         for idx in self.batch_indices:
+             if idx >= val_input.shape[0]:
+                 break
+
+             fig, axes = plt.subplots(1, 4, figsize=(20, 5))
+
+             # Get data
+             input_img = self._prepare_image(val_input[idx])
+             target_img = self._prepare_image(val_target[idx])
+             pred_img = self._prepare_image(val_outputs[idx])
+             diff_img = target_img - pred_img
+
+             # Plot
+             axes[0].imshow(input_img, cmap='gray')
+             axes[0].set_title('Input')
+             axes[0].axis('off')
+
+             axes[1].imshow(target_img, cmap='gray')
+             axes[1].set_title('Target')
+             axes[1].axis('off')
+
+             axes[2].imshow(pred_img, cmap='gray')
+             axes[2].set_title('Prediction')
+             axes[2].axis('off')
+
+             im = axes[3].imshow(diff_img, cmap='RdBu_r', vmin=-diff_img.std(), vmax=diff_img.std())
+             axes[3].set_title('Difference (Target - Pred)')
+             axes[3].axis('off')
+             plt.colorbar(im, ax=axes[3])
+
+             plt.tight_layout()
+             figures.append(fig)
+
+         return figures if len(figures) > 1 else figures[0]
+
+     def _prepare_image(self, tensor):
+         """Prepare tensor for visualization."""
+         img = tensor.detach().cpu()
+
+         # Handle time series
+         if img.ndim == 4:
+             img = img[img.shape[0] // 2]
+
+         # Handle channels
+         if img.ndim == 3:
+             img = img[0]  # Take first channel
+
+         return img.numpy()
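
For orientation, a minimal sketch of defining and registering a custom plotter. The import path and the engine wiring (get_default_callbacks, fit(..., callbacks=...)) are assumptions taken from the CallbacksConfig docstring further down in this release; adjust them to the actual kito API.

    import matplotlib.pyplot as plt
    # Assumed module path; file names are not shown in this diff.
    from kito.callbacks.image_plotters import BaseImagePlotter

    class ErrorMapPlotter(BaseImagePlotter):
        """Plot the absolute error of the first validation sample."""
        def create_figure(self, val_data, val_outputs, epoch, **kwargs):
            # val_data is the (inputs, targets) tuple described above
            target = val_data[1][0].detach().cpu().squeeze().numpy()
            pred = val_outputs[0].detach().cpu().squeeze().numpy()
            fig, ax = plt.subplots()
            im = ax.imshow(abs(target - pred), cmap='magma')
            fig.colorbar(im, ax=ax)
            ax.set_title(f'Absolute error (epoch {epoch})')
            return fig

    # Assumed engine wiring, mirroring the CallbacksConfig docstring:
    # callbacks = engine.get_default_callbacks()
    # callbacks.append(ErrorMapPlotter('logs/images', tag='error_map', freq=5))
    # engine.fit(..., callbacks=callbacks)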
@@ -0,0 +1,132 @@
+ from typing import Optional, Callable
+
+ from kito.callbacks.callback_base import Callback
+ from torch.utils.tensorboard import SummaryWriter
+
+
+ class TensorBoardScalars(Callback):
+     """
+     Log scalar metrics to TensorBoard.
+
+     Args:
+         log_dir: Directory for TensorBoard logs
+
+     Example:
+         tb_scalars = TensorBoardScalars('logs/tensorboard')
+     """
+
+     def __init__(self, log_dir: str):
+         self.log_dir = log_dir
+         self.writer = None
+
+     def on_train_begin(self, engine, model, **kwargs):
+         """Initialize TensorBoard writer."""
+         self.writer = SummaryWriter(self.log_dir)
+
+     def on_epoch_end(self, epoch, engine, model, logs=None, **kwargs):
+         """Log scalars to TensorBoard."""
+         if logs is None:
+             return
+
+         for key, value in logs.items():
+             self.writer.add_scalar(key, value, epoch)
+
+         self.writer.flush()
+
+     def on_train_end(self, engine, model, **kwargs):
+         """Close TensorBoard writer."""
+         if self.writer:
+             self.writer.close()
+
+
+ class TensorBoardHistograms(Callback):
+     """
+     Log model parameter histograms to TensorBoard.
+
+     Args:
+         log_dir: Directory for TensorBoard logs
+         freq: Frequency of histogram logging (every N epochs)
+
+     Example:
+         tb_histograms = TensorBoardHistograms('logs/tensorboard', freq=5)
+     """
+
+     def __init__(self, log_dir: str, freq: int = 1):
+         self.log_dir = log_dir
+         self.freq = freq
+         self.writer = None
+
+     def on_train_begin(self, engine, model, **kwargs):
+         """Initialize TensorBoard writer."""
+         self.writer = SummaryWriter(self.log_dir)
+
+     def on_epoch_end(self, epoch, engine, model, logs=None, **kwargs):
+         """Log parameter histograms."""
+         if epoch % self.freq != 0:
+             return
+
+         # Handle DDP
+         model_to_log = model.module if hasattr(model, 'module') else model
+
+         for name, param in model_to_log.named_parameters():
+             self.writer.add_histogram(name, param, epoch)
+
+         self.writer.flush()
+
+     def on_train_end(self, engine, model, **kwargs):
+         """Close TensorBoard writer."""
+         if self.writer:
+             self.writer.close()
+
+
+ class TensorBoardGraph(Callback):
+     """
+     Log model graph to TensorBoard.
+
+     Args:
+         log_dir: Directory for TensorBoard logs
+         input_to_model: Function that returns sample input for the model
+
+     Example:
+         def get_input():
+             return torch.randn(1, 3, 64, 64)
+
+         tb_graph = TensorBoardGraph('logs/tensorboard', input_to_model=get_input)
+     """
+
+     def __init__(
+         self,
+         log_dir: str,
+         input_to_model: Optional[Callable] = None
+     ):
+         self.log_dir = log_dir
+         self.input_to_model = input_to_model
+         self.writer = None
+         self.logged = False
+
+     def on_train_begin(self, engine, model, **kwargs):
+         """Initialize TensorBoard writer and log graph."""
+         if self.logged:
+             return
+
+         self.writer = SummaryWriter(self.log_dir)
+
+         # Get sample input
+         if self.input_to_model:
+             sample_input = self.input_to_model()
+         elif hasattr(engine, 'get_sample_input'):
+             sample_input = engine.get_sample_input()
+         else:
+             return  # Can't log graph without input
+
+         # Handle DDP
+         model_to_log = model.module if hasattr(model, 'module') else model
+
+         self.writer.add_graph(model_to_log, sample_input)
+         self.writer.flush()
+         self.logged = True
+
+     def on_train_end(self, engine, model, **kwargs):
+         """Close TensorBoard writer."""
+         if self.writer:
+             self.writer.close()
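
Taken together, these three callbacks cover scalar metrics, weight histograms, and the model graph. Below is a hedged sketch of constructing all three with a shared log directory; the constructor arguments follow the docstrings above, while the import path is an assumption since file names are not shown in this diff.

    import torch
    from kito.callbacks.tensorboard import (  # assumed module path
        TensorBoardScalars,
        TensorBoardHistograms,
        TensorBoardGraph,
    )

    log_dir = 'logs/tensorboard'
    tb_callbacks = [
        TensorBoardScalars(log_dir),
        TensorBoardHistograms(log_dir, freq=5),  # histograms every 5 epochs
        TensorBoardGraph(log_dir, input_to_model=lambda: torch.randn(1, 3, 64, 64)),
    ]
    # Append tb_callbacks to the callback list passed to the training engine.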
@@ -0,0 +1,57 @@
+ import datetime
+ import os
+
+ from kito.callbacks.callback_base import Callback
+
+
+ class TextLogger(Callback):
+     """
+     Log training metrics to a text file.
+
+     Args:
+         filename: Path to log file
+         append: Append to existing file or overwrite
+
+     Example:
+         text_logger = TextLogger('logs/training.log')
+     """
+
+     def __init__(
+         self,
+         filename: str,
+         append: bool = True
+     ):
+         self.filename = filename
+         self.append = append
+         self.file = None
+
+         # Create directory
+         os.makedirs(os.path.dirname(filename) or '.', exist_ok=True)
+
+     def on_train_begin(self, engine, model, **kwargs):
+         """Open log file."""
+         mode = 'a' if self.append else 'w'
+         self.file = open(self.filename, mode)
+
+         timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+         self.file.write(f"{'=' * 60}\n")
+         self.file.write(f"Training started at {timestamp}\n")
+         self.file.write(f"{'=' * 60}\n")
+         self.file.flush()
+
+     def on_epoch_end(self, epoch, engine, model, logs=None, **kwargs):
+         """Write metrics to log file."""
+         if logs is None:
+             return
+
+         self.file.write(f"Epoch {epoch}: ")
+         metrics_str = ", ".join(f"{k}={v:.4f}" for k, v in logs.items())
+         self.file.write(metrics_str + "\n")
+         self.file.flush()
+
+     def on_train_end(self, engine, model, **kwargs):
+         """Close log file."""
+         if self.file:
+             timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+             self.file.write(f"\nTraining ended at {timestamp}\n")
+             self.file.close()
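
A short usage sketch follows; the sample file contents are inferred from the write calls above, and the metric names, values, and timestamps are illustrative.

    text_logger = TextLogger('logs/training.log', append=False)

    # After a run, logs/training.log would look roughly like:
    # ============================================================
    # Training started at 2025-01-01 12:00:00
    # ============================================================
    # Epoch 0: loss=0.1523, val_loss=0.1718
    # Epoch 1: loss=0.0981, val_loss=0.1204
    #
    # Training ended at 2025-01-01 12:34:56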
File without changes
@@ -0,0 +1,201 @@
+ """
+ Base configuration system for KitoModule framework.
+
+ Users can extend these base configs with their own custom parameters.
+ """
+ from dataclasses import dataclass, field
+ from typing import Tuple, List, Optional, Dict, Any
+
+
+ @dataclass
+ class PreprocessingStepConfig:
+     """
+     Configuration for a single preprocessing step.
+
+     Args:
+         type: Name of preprocessing class (e.g., 'detrend', 'standardization')
+         params: Dictionary of parameters to pass to preprocessing class
+
+     Example:
+         >>> step = PreprocessingStepConfig(
+         ...     type='standardization',
+         ...     params={'mean': 0.5, 'std': 0.2}
+         ... )
+     """
+     type: str
+     params: Dict[str, Any] = field(default_factory=dict)
+
+
+ @dataclass
+ class DataConfig:
+     """
+     Data loading and preprocessing configuration.
+
+     This config defines:
+     - What dataset to use (H5, memory, custom)
+     - Where data is located
+     - How to initialize the dataset (flexible args)
+     - Memory loading strategy
+     - How to split data (train/val/test)
+     - What preprocessing to apply
+     - DataLoader settings
+     """
+
+     # Dataset configuration
+     dataset_type: str = 'h5dataset'  # 'h5dataset', 'memdataset', or custom
+
+     # Simple path (backward compatible)
+     dataset_path: str = ''
+
+     # Flexible initialization args (for custom datasets)
+     # If provided, takes precedence over dataset_path
+     # Allows any constructor signature: Dataset(**dataset_init_args)
+     dataset_init_args: Dict[str, Any] = field(default_factory=dict)
+
+     # Memory management
+     load_into_memory: bool = False  # Load entire dataset into RAM for faster training
+
+     # Splitting ratios
+     train_ratio: float = 0.8
+     val_ratio: float = 0.1
+     # test_ratio is implicit: 1 - train_ratio - val_ratio
+
+     # Total samples to use (None = use all available)
+     total_samples: Optional[int] = None
+
+     # Preprocessing pipeline
+     # List of preprocessing steps applied in order
+     preprocessing: List[PreprocessingStepConfig] = field(default_factory=list)
+
+     # DataLoader settings
+     num_workers: int = 0
+     prefetch_factor: int = 2
+     pin_memory: bool = False
+     persistent_workers: bool = False
+
+
+ @dataclass
+ class TrainingConfig:
+     """Core training parameters required by KitoModule."""
+
+     # Essential training parameters
+     learning_rate: float
+     n_train_epochs: int
+     batch_size: int
+     train_mode: bool  # True for training, False for inference
+
+     # Verbosity (0=silent, 1=progress bar, 2=detailed)
+     train_verbosity_level: int = 2
+     val_verbosity_level: int = 2
+     test_verbosity_level: int = 2
+
+     # Distributed training
+     distributed_training: bool = False
+     master_gpu_id: int = 0  # Only used if not distributed
+
+     # Weight initialization
+     initialize_model_with_saved_weights: bool = False
+
+     # Device type
+     device_type: str = "cuda"  # "cuda", "mps", or "cpu"
+
+     def __post_init__(self):
+         """Normalize and validate device_type after initialization."""
+         # Normalize to lowercase (user-friendly) before validating
+         self.device_type = self.device_type.lower()
+
+         valid_devices = {"cuda", "mps", "cpu"}
+         if self.device_type not in valid_devices:
+             raise ValueError(
+                 f"Invalid device_type: '{self.device_type}'. "
+                 f"Must be one of {valid_devices}."
+             )
+
+
+ @dataclass
+ class ModelConfig:
+     """Core model parameters required by KitoModule."""
+
+     # Data dimensions
+     input_data_size: Tuple[int, ...]  # Flexible shape
+
+     # Loss and optimization
+     loss: str = ""
+
+     # Callbacks and logging
+     log_to_tensorboard: bool = False
+     save_model_weights: bool = False
+     text_logging: bool = False
+     csv_logging: bool = False
+     train_codename: str = "experiment"
+
+     # Weights
+     weight_load_path: str = ""
+
+     # Inference
+     save_inference_to_disk: bool = False
+     inference_filename: str = ""
+
+     # TensorBoard visualization (optional)
+     tensorboard_img_id: str = "training_viz"
+     batch_idx_viz: List[int] = field(default_factory=lambda: [0])
+
+
+ @dataclass
+ class WorkDirConfig:
+     """Working directory configuration."""
+     work_directory: str
+
+
+ @dataclass
+ class CallbacksConfig:
+     """
+     Configuration for Kito's built-in default callbacks.
+
+     For custom callbacks, use instead:
+         callbacks = engine.get_default_callbacks()
+         callbacks.append(MyCustomCallback())
+         engine.fit(..., callbacks=callbacks)
+     """
+
+     # === CSV Logger ===
+     enable_csv_logger: bool = True
+
+     # === Text Logger ===
+     enable_text_logger: bool = True
+
+     # === Model Checkpoint ===
+     enable_model_checkpoint: bool = True
+     checkpoint_monitor: str = 'val_loss'
+     checkpoint_mode: str = 'min'  # 'min' or 'max'
+     checkpoint_save_best_only: bool = True
+     checkpoint_verbose: bool = False
+
+     # === TensorBoard ===
+     enable_tensorboard: bool = False  # Master switch
+     tensorboard_scalars: bool = True
+     tensorboard_histograms: bool = True
+     tensorboard_histogram_freq: int = 5
+     tensorboard_graph: bool = True
+     tensorboard_images: bool = False
+     tensorboard_image_freq: int = 1
+     tensorboard_batch_indices: List[int] = field(default_factory=lambda: [0])
+
+
+ @dataclass
+ class KitoModuleConfig:
+     """
+     Base configuration container for KitoModule.
+
+     Contains all configuration sections:
+     - training: Training parameters
+     - model: Model architecture and settings
+     - workdir: Output directories
+     - data: Dataset and preprocessing
+     - callbacks: Built-in default callback settings
+     """
+     training: TrainingConfig
+     model: ModelConfig
+     workdir: WorkDirConfig
+     data: DataConfig
+     callbacks: CallbacksConfig
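
To show how the sections compose, here is a hedged sketch of assembling a full KitoModuleConfig. The concrete values (paths, the 'mse' loss name, the codename) are illustrative assumptions; whether a given loss string is recognized depends on the rest of the framework.

    config = KitoModuleConfig(
        training=TrainingConfig(
            learning_rate=1e-3,
            n_train_epochs=50,
            batch_size=32,
            train_mode=True,
            device_type='cuda',
        ),
        model=ModelConfig(
            input_data_size=(1, 64, 64),
            loss='mse',
            log_to_tensorboard=True,
            train_codename='baseline_run',
        ),
        workdir=WorkDirConfig(work_directory='runs/baseline_run'),
        data=DataConfig(
            dataset_type='h5dataset',
            dataset_path='data/train.h5',
            train_ratio=0.8,
            val_ratio=0.1,
            preprocessing=[
                PreprocessingStepConfig(type='standardization', params={'mean': 0.5, 'std': 0.2}),
            ],
        ),
        callbacks=CallbacksConfig(enable_tensorboard=True),
    )

    # Because every section is a plain dataclass, user projects can extend them
    # (e.g., subclassing ModelConfig to add architecture fields) without changing
    # the container, as the module docstring notes.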