pytorch-kito 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kito/__init__.py ADDED
@@ -0,0 +1,49 @@
+ """
+ Kito: Effortless PyTorch Training
+
+ Define your model, Kito handles the rest.
+ """
+
+ __version__ = "0.2.0"
+
+ # Import main classes for top-level access
+ from kito.engine import Engine
+ from kito.module import KitoModule
+
+ # Import common data classes
+ from kito.data.datasets import H5Dataset, MemDataset, KitoDataset
+ from kito.data.datapipeline import GenericDataPipeline
+ from kito.data.preprocessing import (
+     Preprocessing,
+     Pipeline,
+     Normalize,
+     Standardization,
+     ToTensor
+ )
+
+ # Import registries
+ from kito.data.registry import DATASETS, PREPROCESSING
+
+ # Define what's available with "from kito import *"
+ __all__ = [
+     # Core
+     "Engine",
+     "KitoModule",
+
+     # Data
+     "H5Dataset",
+     "MemDataset",
+     "KitoDataset",
+     "GenericDataPipeline",
+
+     # Preprocessing
+     "Preprocessing",
+     "Pipeline",
+     "Normalize",
+     "Standardization",
+     "ToTensor",
+
+     # Registries
+     "DATASETS",
+     "PREPROCESSING",
+ ]
kito/callbacks/__init__.py ADDED
@@ -0,0 +1,20 @@
+ # src/kito/callbacks/__init__.py
+ """Kito Callbacks - For custom training behavior"""
+
+ from kito.callbacks.callback_base import Callback, CallbackList
+ from kito.callbacks.modelcheckpoint import ModelCheckpoint
+ from kito.callbacks.csv_logger import CSVLogger
+ from kito.callbacks.txt_logger import TextLogger
+ from kito.callbacks.tensorboard_callbacks import TensorBoardScalars, TensorBoardGraph, TensorBoardHistograms
+ # ... other callbacks
+
+ __all__ = [
+     "Callback",
+     "CallbackList",
+     "ModelCheckpoint",
+     "CSVLogger",
+     "TextLogger",
+     "TensorBoardScalars",
+     "TensorBoardGraph",
+     "TensorBoardHistograms"
+ ]
kito/callbacks/callback_base.py ADDED
@@ -0,0 +1,107 @@
+ """
+ Modern callback system for BaseModule.
+
+ Inspired by Keras and PyTorch Lightning callback patterns.
+ Each callback is independent and handles a single concern.
+ """
+ from abc import ABC
+
+
+ class Callback(ABC):
+     """
+     Base class for all callbacks.
+
+     Callbacks allow you to customize the training loop behavior
+     by hooking into specific events (epoch start/end, batch start/end, etc.).
+
+     All methods have default no-op implementations, so you only need to
+     override the ones you care about.
+     """
+
+     def on_train_begin(self, engine, model, **kwargs):
+         """Called at the beginning of training."""
+         pass
+
+     def on_train_end(self, engine, model, **kwargs):
+         """Called at the end of training."""
+         pass
+
+     def on_epoch_begin(self, epoch, engine, model, **kwargs):
+         """Called at the beginning of each epoch."""
+         pass
+
+     def on_epoch_end(self, epoch, engine, model, logs=None, **kwargs):
+         """
+         Called at the end of each epoch.
+
+         Args:
+             epoch: Current epoch number (1-indexed)
+             engine: Reference to the Engine
+             model: The PyTorch model
+             logs: Dictionary of metrics (e.g., {'train_loss': 0.5, 'val_loss': 0.3})
+             **kwargs: Additional context (val_data, val_outputs, etc.)
+         """
+         pass
+
+     def on_train_batch_begin(self, batch, engine, model, **kwargs):
+         """Called at the beginning of each training batch."""
+         pass
+
+     def on_train_batch_end(self, batch, engine, model, logs=None, **kwargs):
+         """Called at the end of each training batch."""
+         pass
+
+     def on_validation_begin(self, epoch, engine, model, **kwargs):
+         """Called at the beginning of validation."""
+         pass
+
+     def on_validation_end(self, epoch, engine, model, logs=None, **kwargs):
+         """Called at the end of validation."""
+         pass
+
+
+ class CallbackList:
+     """
+     Container for managing multiple callbacks.
+
+     Iterates through all callbacks and calls the appropriate method.
+     """
+
+     def __init__(self, callbacks=None):
+         self.callbacks = callbacks or []
+
+     def append(self, callback):
+         """Add a callback to the list."""
+         self.callbacks.append(callback)
+
+     def on_train_begin(self, engine, model, **kwargs):
+         for callback in self.callbacks:
+             callback.on_train_begin(engine, model, **kwargs)
+
+     def on_train_end(self, engine, model, **kwargs):
+         for callback in self.callbacks:
+             callback.on_train_end(engine, model, **kwargs)
+
+     def on_epoch_begin(self, epoch, engine, model, **kwargs):
+         for callback in self.callbacks:
+             callback.on_epoch_begin(epoch, engine, model, **kwargs)
+
+     def on_epoch_end(self, epoch, engine, model, logs=None, **kwargs):
+         for callback in self.callbacks:
+             callback.on_epoch_end(epoch, engine, model, logs, **kwargs)
+
+     def on_train_batch_begin(self, batch, engine, model, **kwargs):
+         for callback in self.callbacks:
+             callback.on_train_batch_begin(batch, engine, model, **kwargs)
+
+     def on_train_batch_end(self, batch, engine, model, logs=None, **kwargs):
+         for callback in self.callbacks:
+             callback.on_train_batch_end(batch, engine, model, logs, **kwargs)
+
+     def on_validation_begin(self, epoch, engine, model, **kwargs):
+         for callback in self.callbacks:
+             callback.on_validation_begin(epoch, engine, model, **kwargs)
+
+     def on_validation_end(self, epoch, engine, model, logs=None, **kwargs):
+         for callback in self.callbacks:
+             callback.on_validation_end(epoch, engine, model, logs, **kwargs)
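A custom callback only overrides the hooks it needs, and CallbackList fans each event out to every registered callback. A minimal sketch of that dispatch, using stand-in engine/model objects purely for illustration (in practice the Engine passes itself and the wrapped PyTorch model):

from types import SimpleNamespace
from kito.callbacks import Callback, CallbackList

class PrintLoss(Callback):
    # Only on_epoch_end is overridden; every other hook stays a no-op.
    def on_epoch_end(self, epoch, engine, model, logs=None, **kwargs):
        logs = logs or {}
        print(f"epoch {epoch}: train_loss={logs.get('train_loss')}")

callbacks = CallbackList([PrintLoss()])
engine, model = SimpleNamespace(), SimpleNamespace()  # stand-ins for illustration only

callbacks.on_train_begin(engine, model)
callbacks.on_epoch_end(1, engine, model, logs={"train_loss": 0.42})
callbacks.on_train_end(engine, model)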
kito/callbacks/csv_logger.py ADDED
@@ -0,0 +1,66 @@
+ import csv
+ import os
+
+ from kito.callbacks.callback_base import Callback
+
+
+ class CSVLogger(Callback):
+     """
+     Log training metrics to a CSV file.
+
+     Args:
+         filename: Path to CSV file
+         separator: Column separator (default: ',')
+         append: Append to existing file or overwrite
+
+     Example:
+         csv_logger = CSVLogger('logs/training_log.csv')
+     """
+
+     def __init__(
+         self,
+         filename: str,
+         separator: str = ',',
+         append: bool = False
+     ):
+         self.filename = filename
+         self.separator = separator
+         self.append = append
+         self.writer = None
+         self.file = None
+         self.keys = None
+
+         # Create directory
+         os.makedirs(os.path.dirname(filename) if os.path.dirname(filename) else '.', exist_ok=True)
+
+     def on_train_begin(self, engine, model, **kwargs):
+         """Open CSV file and write header."""
+         mode = 'a' if self.append else 'w'
+         self.file = open(self.filename, mode, newline='')
+         self.writer = csv.writer(self.file, delimiter=self.separator)
+
+         # If not appending, we'll write header on first epoch
+         if not self.append:
+             self.keys = None
+
+     def on_epoch_end(self, epoch, engine, model, logs=None, **kwargs):
+         """Write metrics to CSV."""
+         if logs is None:
+             return
+
+         # Add epoch to logs
+         row_dict = {'epoch': epoch, **logs}
+
+         # Write header if first time
+         if self.keys is None:
+             self.keys = list(row_dict.keys())
+             self.writer.writerow(self.keys)
+
+         # Write values
+         self.writer.writerow([row_dict.get(k, '') for k in self.keys])
+         self.file.flush()
+
+     def on_train_end(self, engine, model, **kwargs):
+         """Close CSV file."""
+         if self.file:
+             self.file.close()
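A short sketch of what CSVLogger produces, driving the hooks by hand (in normal use the Engine calls them); the engine/model arguments are placeholder objects:

from types import SimpleNamespace
from kito.callbacks import CSVLogger

logger = CSVLogger("logs/training_log.csv")
engine, model = SimpleNamespace(), SimpleNamespace()  # placeholders

logger.on_train_begin(engine, model)
logger.on_epoch_end(1, engine, model, logs={"train_loss": 0.81, "val_loss": 0.77})
logger.on_epoch_end(2, engine, model, logs={"train_loss": 0.54, "val_loss": 0.61})
logger.on_train_end(engine, model)

# logs/training_log.csv now contains:
# epoch,train_loss,val_loss
# 1,0.81,0.77
# 2,0.54,0.61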
@@ -0,0 +1,60 @@
+ import torch.distributed as dist
+
+ from kito.callbacks.callback_base import Callback
+
+
+ class DDPAwareCallback(Callback):
+     """
+     Wrapper that makes any callback DDP-safe.
+
+     Only executes on rank 0 in distributed training.
+
+     Args:
+         callback: The callback to wrap
+
+     Example:
+         checkpoint = ModelCheckpoint('weights/best.pt')
+         ddp_checkpoint = DDPAwareCallback(checkpoint)
+     """
+
+     def __init__(self, callback: Callback):
+         self.callback = callback
+         self.is_driver = self._check_if_driver()
+
+     def _check_if_driver(self):
+         """Check if this is the driver process (rank 0)."""
+         if dist.is_available() and dist.is_initialized():
+             return dist.get_rank() == 0
+         return True  # Single GPU or CPU
+
+     def on_train_begin(self, engine, model, **kwargs):
+         if self.is_driver:
+             self.callback.on_train_begin(engine, model, **kwargs)
+
+     def on_train_end(self, engine, model, **kwargs):
+         if self.is_driver:
+             self.callback.on_train_end(engine, model, **kwargs)
+
+     def on_epoch_begin(self, epoch, engine, model, **kwargs):
+         if self.is_driver:
+             self.callback.on_epoch_begin(epoch, engine, model, **kwargs)
+
+     def on_epoch_end(self, epoch, engine, model, logs=None, **kwargs):
+         if self.is_driver:
+             self.callback.on_epoch_end(epoch, engine, model, logs, **kwargs)
+
+     def on_train_batch_begin(self, batch, engine, model, **kwargs):
+         if self.is_driver:
+             self.callback.on_train_batch_begin(batch, engine, model, **kwargs)
+
+     def on_train_batch_end(self, batch, engine, model, logs=None, **kwargs):
+         if self.is_driver:
+             self.callback.on_train_batch_end(batch, engine, model, logs, **kwargs)
+
+     def on_validation_begin(self, epoch, engine, model, **kwargs):
+         if self.is_driver:
+             self.callback.on_validation_begin(epoch, engine, model, **kwargs)
+
+     def on_validation_end(self, epoch, engine, model, logs=None, **kwargs):
+         if self.is_driver:
+             self.callback.on_validation_end(epoch, engine, model, logs, **kwargs)
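Usage is a thin wrap around any other callback; when torch.distributed is not initialized, every process counts as the driver and each hook is simply forwarded. A sketch (the module path for DDPAwareCallback is assumed, since this diff neither shows the filename nor re-exports the class from kito.callbacks):

from kito.callbacks import ModelCheckpoint
from kito.callbacks.ddp import DDPAwareCallback  # assumed import path; not shown in this diff

checkpoint = DDPAwareCallback(ModelCheckpoint("weights/best.pt", monitor="val_loss"))
# Under a torchrun/DDP launch only rank 0 writes checkpoints;
# on a single GPU or CPU every hook is forwarded unchanged.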
@@ -0,0 +1,45 @@
+ from kito.callbacks.callback_base import Callback
+
+
+ class EarlyStoppingCallback(Callback):
+     """
+     Stop training when monitored metric stops improving.
+
+     Args:
+         monitor: Metric to monitor
+         patience: Number of epochs with no improvement before stopping
+         mode: 'min' or 'max'
+     """
+
+     def __init__(self, monitor='val_loss', patience=10, mode='min'):
+         self.monitor = monitor
+         self.patience = patience
+         self.mode = mode
+         self.best = float('inf') if mode == 'min' else float('-inf')
+         self.wait = 0
+         self.stopped_epoch = 0
+
+     def on_epoch_end(self, epoch, engine, model, logs=None, **kwargs):
+         """Check if training should stop."""
+         if logs is None:
+             return
+
+         current = logs.get(self.monitor)
+         if current is None:
+             return
+
+         # Check if improved
+         improved = (
+             (self.mode == 'min' and current < self.best) or
+             (self.mode == 'max' and current > self.best)
+         )
+
+         if improved:
+             self.best = current
+             self.wait = 0
+         else:
+             self.wait += 1
+             if self.wait >= self.patience:
+                 self.stopped_epoch = epoch
+                 engine.stop_training = True  # Requires Engine to support this
+                 print(f"\nEarly stopping at epoch {epoch}")
kito/callbacks/modelcheckpoint.py ADDED
@@ -0,0 +1,78 @@
+ import os
+ import torch
+
+ from kito.callbacks.callback_base import Callback
+
+
+ class ModelCheckpoint(Callback):
+     """
+     Save model weights during training.
+
+     Args:
+         filepath: Path template for saving weights (can include {epoch}, {val_loss}, etc.)
+         monitor: Metric to monitor (e.g., 'val_loss')
+         save_best_only: Only save when monitored metric improves
+         mode: 'min' or 'max' depending on whether lower/higher is better
+         verbose: Print message when saving
+
+     Example:
+         checkpoint = ModelCheckpoint(
+             filepath='weights/model_epoch{epoch:02d}_valloss{val_loss:.4f}.pt',
+             monitor='val_loss',
+             save_best_only=True,
+             mode='min'
+         )
+     """
+
+     def __init__(
+         self,
+         filepath: str,
+         monitor: str = 'val_loss',
+         save_best_only: bool = True,
+         mode: str = 'min',
+         verbose: bool = False
+     ):
+         self.filepath = filepath
+         self.monitor = monitor
+         self.save_best_only = save_best_only
+         self.mode = mode
+         self.verbose = verbose
+
+         # Track best metric
+         self.best = float('inf') if mode == 'min' else float('-inf')
+
+         # Create directory if needed
+         os.makedirs(os.path.dirname(filepath) if os.path.dirname(filepath) else '.', exist_ok=True)
+
+     def on_epoch_end(self, epoch, engine, model, logs=None, **kwargs):
+         """Save model if metric improved."""
+         if logs is None:
+             return
+
+         current = logs.get(self.monitor)
+         if current is None:
+             return
+
+         # Check if improved
+         improved = (
+             (self.mode == 'min' and current < self.best) or
+             (self.mode == 'max' and current > self.best)
+         )
+
+         if improved or not self.save_best_only:
+             # Only update the tracked best when the metric actually improved
+             if improved:
+                 self.best = current
+             # Format filepath
+             filepath = self.filepath.format(epoch=epoch, **logs)
+
+             # Save model (handle DDP)
+             if hasattr(model, 'module'):
+                 state_dict = model.module.state_dict()
+             else:
+                 state_dict = model.state_dict()
+
+             torch.save(state_dict, filepath)
+
+             if self.verbose:
+                 print(f"\nEpoch {epoch}: {self.monitor} improved to {current:.4f}, "
+                       f"saving model to {filepath}")