pytorch-kito 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kito/__init__.py +49 -0
- kito/callbacks/__init__.py +20 -0
- kito/callbacks/callback_base.py +107 -0
- kito/callbacks/csv_logger.py +66 -0
- kito/callbacks/ddp_aware_callback.py +60 -0
- kito/callbacks/early_stopping_callback.py +45 -0
- kito/callbacks/modelcheckpoint.py +78 -0
- kito/callbacks/tensorboard_callback_images.py +298 -0
- kito/callbacks/tensorboard_callbacks.py +132 -0
- kito/callbacks/txt_logger.py +57 -0
- kito/config/__init__.py +0 -0
- kito/config/moduleconfig.py +201 -0
- kito/data/__init__.py +35 -0
- kito/data/datapipeline.py +273 -0
- kito/data/datasets.py +166 -0
- kito/data/preprocessed_dataset.py +57 -0
- kito/data/preprocessing.py +318 -0
- kito/data/registry.py +96 -0
- kito/engine.py +841 -0
- kito/module.py +447 -0
- kito/strategies/__init__.py +0 -0
- kito/strategies/logger_strategy.py +51 -0
- kito/strategies/progress_bar_strategy.py +57 -0
- kito/strategies/readiness_validator.py +85 -0
- kito/utils/__init__.py +0 -0
- kito/utils/decorators.py +45 -0
- kito/utils/gpu_utils.py +94 -0
- kito/utils/loss_utils.py +38 -0
- kito/utils/ssim_utils.py +94 -0
- pytorch_kito-0.2.0.dist-info/METADATA +328 -0
- pytorch_kito-0.2.0.dist-info/RECORD +34 -0
- pytorch_kito-0.2.0.dist-info/WHEEL +5 -0
- pytorch_kito-0.2.0.dist-info/licenses/LICENSE +21 -0
- pytorch_kito-0.2.0.dist-info/top_level.txt +1 -0
kito/__init__.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Kito: Effortless PyTorch Training
|
|
3
|
+
|
|
4
|
+
Define your model, Kito handles the rest.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
__version__ = "0.2.0"
|
|
8
|
+
|
|
9
|
+
# Import main classes for top-level access
|
|
10
|
+
from kito.engine import Engine
|
|
11
|
+
from kito.module import KitoModule
|
|
12
|
+
|
|
13
|
+
# Import common data classes
|
|
14
|
+
from kito.data.datasets import H5Dataset, MemDataset, KitoDataset
|
|
15
|
+
from kito.data.datapipeline import GenericDataPipeline
|
|
16
|
+
from kito.data.preprocessing import (
|
|
17
|
+
Preprocessing,
|
|
18
|
+
Pipeline,
|
|
19
|
+
Normalize,
|
|
20
|
+
Standardization,
|
|
21
|
+
ToTensor
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
# Import registries
|
|
25
|
+
from kito.data.registry import DATASETS, PREPROCESSING
|
|
26
|
+
|
|
27
|
+
# Define what's available with "from kito import *"
|
|
28
|
+
__all__ = [
|
|
29
|
+
# Core
|
|
30
|
+
"Engine",
|
|
31
|
+
"KitoModule",
|
|
32
|
+
|
|
33
|
+
# Data
|
|
34
|
+
"H5Dataset",
|
|
35
|
+
"MemDataset",
|
|
36
|
+
"KitoDataset",
|
|
37
|
+
"GenericDataPipeline",
|
|
38
|
+
|
|
39
|
+
# Preprocessing
|
|
40
|
+
"Preprocessing",
|
|
41
|
+
"Pipeline",
|
|
42
|
+
"Normalize",
|
|
43
|
+
"Standardization",
|
|
44
|
+
"ToTensor",
|
|
45
|
+
|
|
46
|
+
# Registries
|
|
47
|
+
"DATASETS",
|
|
48
|
+
"PREPROCESSING",
|
|
49
|
+
]
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# src/kito/callbacks/__init__.py
"""Kito Callbacks - For custom training behavior"""

# Re-export the public callback API at package level so users can write
# `from kito.callbacks import ModelCheckpoint` etc.
from kito.callbacks.callback_base import Callback, CallbackList
from kito.callbacks.modelcheckpoint import ModelCheckpoint
from kito.callbacks.csv_logger import CSVLogger
from kito.callbacks.txt_logger import TextLogger
from kito.callbacks.tensorboard_callbacks import TensorBoardScalars, TensorBoardGraph, TensorBoardHistograms
# ... other callbacks

__all__ = [
    "Callback",
    "CallbackList",
    "ModelCheckpoint",
    "CSVLogger",
    "TextLogger",
    "TensorBoardScalars",
    "TensorBoardGraph",
    "TensorBoardHistograms"
]
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Modern callback system for BaseModule.
|
|
3
|
+
|
|
4
|
+
Inspired by Keras and PyTorch Lightning callback patterns.
|
|
5
|
+
Each callback is independent and handles a single concern.
|
|
6
|
+
"""
|
|
7
|
+
from abc import ABC
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Callback(ABC):
    """
    Base class for all training callbacks.

    Subclass this and override only the hooks you care about; every
    hook below is a harmless no-op by default. Hooks receive the
    Engine, the model, and — where applicable — the current epoch or
    batch index plus a ``logs`` dict of metrics and extra context in
    ``**kwargs``.
    """

    def on_train_begin(self, engine, model, **kwargs):
        """Invoked once, before the first epoch starts."""

    def on_train_end(self, engine, model, **kwargs):
        """Invoked once, after training finishes."""

    def on_epoch_begin(self, epoch, engine, model, **kwargs):
        """Invoked at the start of every epoch."""

    def on_epoch_end(self, epoch, engine, model, logs=None, **kwargs):
        """
        Invoked at the end of every epoch.

        Args:
            epoch: Current epoch number (1-indexed).
            engine: Reference to the Engine.
            model: The PyTorch model.
            logs: Dictionary of metrics (e.g., {'train_loss': 0.5, 'val_loss': 0.3}).
            **kwargs: Additional context (val_data, val_outputs, etc.).
        """

    def on_train_batch_begin(self, batch, engine, model, **kwargs):
        """Invoked before each training batch is processed."""

    def on_train_batch_end(self, batch, engine, model, logs=None, **kwargs):
        """Invoked after each training batch is processed."""

    def on_validation_begin(self, epoch, engine, model, **kwargs):
        """Invoked before the validation pass of an epoch."""

    def on_validation_end(self, epoch, engine, model, logs=None, **kwargs):
        """Invoked after the validation pass of an epoch."""
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class CallbackList:
    """
    Ordered collection of callbacks that fans each event out to every member.

    Calling an event method on the list invokes the same method, with the
    same arguments, on every registered callback in insertion order.
    """

    def __init__(self, callbacks=None):
        # Fall back to a fresh list so each instance owns its own storage.
        self.callbacks = callbacks or []

    def append(self, callback):
        """Register an additional callback."""
        self.callbacks.append(callback)

    def _fan_out(self, event, *args, **kwargs):
        # Forward `event` with identical arguments to each callback, in order.
        for cb in self.callbacks:
            getattr(cb, event)(*args, **kwargs)

    def on_train_begin(self, engine, model, **kwargs):
        self._fan_out('on_train_begin', engine, model, **kwargs)

    def on_train_end(self, engine, model, **kwargs):
        self._fan_out('on_train_end', engine, model, **kwargs)

    def on_epoch_begin(self, epoch, engine, model, **kwargs):
        self._fan_out('on_epoch_begin', epoch, engine, model, **kwargs)

    def on_epoch_end(self, epoch, engine, model, logs=None, **kwargs):
        self._fan_out('on_epoch_end', epoch, engine, model, logs, **kwargs)

    def on_train_batch_begin(self, batch, engine, model, **kwargs):
        self._fan_out('on_train_batch_begin', batch, engine, model, **kwargs)

    def on_train_batch_end(self, batch, engine, model, logs=None, **kwargs):
        self._fan_out('on_train_batch_end', batch, engine, model, logs, **kwargs)

    def on_validation_begin(self, epoch, engine, model, **kwargs):
        self._fan_out('on_validation_begin', epoch, engine, model, **kwargs)

    def on_validation_end(self, epoch, engine, model, logs=None, **kwargs):
        self._fan_out('on_validation_end', epoch, engine, model, logs, **kwargs)
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
import csv
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
from kito.callbacks.callback_base import Callback
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class CSVLogger(Callback):
    """
    Log training metrics to a CSV file, one row per epoch.

    The header row is derived from the keys of the first ``logs`` dict
    seen (plus a leading ``epoch`` column); later rows fill any missing
    key with an empty string so the column layout stays stable.

    Args:
        filename: Path to the CSV file (parent directories are created).
        separator: Column separator (default: ',').
        append: Append to an existing file instead of overwriting.

    Example:
        csv_logger = CSVLogger('logs/training_log.csv')
    """

    def __init__(
        self,
        filename: str,
        separator: str = ',',
        append: bool = False
    ):
        self.filename = filename
        self.separator = separator
        self.append = append
        self.writer = None
        self.file = None
        self.keys = None

        # Create the target directory up front so opening the file later
        # cannot fail with a missing-directory error.
        os.makedirs(os.path.dirname(filename) or '.', exist_ok=True)

    def on_train_begin(self, engine, model, **kwargs):
        """Open the CSV file; the header is written lazily on the first epoch."""
        mode = 'a' if self.append else 'w'
        self.file = open(self.filename, mode, newline='')
        self.writer = csv.writer(self.file, delimiter=self.separator)

        # When overwriting, forget any header from a previous run so a
        # fresh header row is emitted on the first epoch.
        if not self.append:
            self.keys = None

    def on_epoch_end(self, epoch, engine, model, logs=None, **kwargs):
        """Append one row of metrics for this epoch."""
        # Nothing to do if there are no metrics, or the file was never
        # opened (previously this raised AttributeError on writer=None).
        if logs is None or self.writer is None:
            return

        # Add epoch to logs
        row_dict = {'epoch': epoch, **logs}

        # First row: fix the column order and emit the header.
        if self.keys is None:
            self.keys = list(row_dict.keys())
            self.writer.writerow(self.keys)

        # Unknown columns default to '' so rows always match the header.
        self.writer.writerow([row_dict.get(k, '') for k in self.keys])
        # Flush so the log survives a crash mid-training.
        self.file.flush()

    def on_train_end(self, engine, model, **kwargs):
        """Close the CSV file and drop the stale handles."""
        if self.file:
            self.file.close()
            # Reset so reusing this callback for another run reopens the
            # file instead of writing to a closed handle.
            self.file = None
            self.writer = None
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
import torch.distributed as dist
|
|
2
|
+
|
|
3
|
+
from kito.callbacks.callback_base import Callback
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class DDPAwareCallback(Callback):
    """
    Wrapper that makes any callback DDP-safe.

    Every hook is forwarded to the wrapped callback only on the driver
    process (rank 0); on other ranks the hook is a no-op. Outside of
    distributed training, every process counts as the driver.

    Args:
        callback: The callback to wrap.

    Example:
        checkpoint = ModelCheckpoint('weights/best.pt')
        ddp_checkpoint = DDPAwareCallback(checkpoint)
    """

    def __init__(self, callback: Callback):
        self.callback = callback

    @property
    def is_driver(self):
        """True on rank 0, or when not running distributed.

        Evaluated lazily on each access rather than once in __init__:
        if the wrapper is constructed before init_process_group, a
        one-time check would report every rank as the driver and
        duplicate side effects (checkpoints, logs) across ranks.
        """
        return self._check_if_driver()

    def _check_if_driver(self):
        """Check if this is the driver process (rank 0)."""
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank() == 0
        return True  # Single GPU or CPU

    def on_train_begin(self, engine, model, **kwargs):
        if self.is_driver:
            self.callback.on_train_begin(engine, model, **kwargs)

    def on_train_end(self, engine, model, **kwargs):
        if self.is_driver:
            self.callback.on_train_end(engine, model, **kwargs)

    def on_epoch_begin(self, epoch, engine, model, **kwargs):
        if self.is_driver:
            self.callback.on_epoch_begin(epoch, engine, model, **kwargs)

    def on_epoch_end(self, epoch, engine, model, logs=None, **kwargs):
        if self.is_driver:
            self.callback.on_epoch_end(epoch, engine, model, logs, **kwargs)

    def on_train_batch_begin(self, batch, engine, model, **kwargs):
        if self.is_driver:
            self.callback.on_train_batch_begin(batch, engine, model, **kwargs)

    def on_train_batch_end(self, batch, engine, model, logs=None, **kwargs):
        if self.is_driver:
            self.callback.on_train_batch_end(batch, engine, model, logs, **kwargs)

    def on_validation_begin(self, epoch, engine, model, **kwargs):
        if self.is_driver:
            self.callback.on_validation_begin(epoch, engine, model, **kwargs)

    def on_validation_end(self, epoch, engine, model, logs=None, **kwargs):
        if self.is_driver:
            self.callback.on_validation_end(epoch, engine, model, logs, **kwargs)
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
from kito.callbacks.callback_base import Callback
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class EarlyStoppingCallback(Callback):
    """
    Stop training when the monitored metric stops improving.

    Sets ``engine.stop_training = True`` once the metric has failed to
    improve for ``patience`` consecutive epochs (requires Engine support).
    Epochs where the metric is missing from ``logs`` are ignored and do
    not count against patience.

    Args:
        monitor: Name of the metric in ``logs`` to monitor.
        patience: Number of epochs with no improvement before stopping.
        mode: 'min' if lower is better, 'max' if higher is better.

    Raises:
        ValueError: If ``mode`` is not 'min' or 'max'.
    """

    def __init__(self, monitor='val_loss', patience=10, mode='min'):
        if mode not in ('min', 'max'):
            # Fail fast: a typo here would otherwise silently disable
            # both improvement tracking and early stopping.
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.monitor = monitor
        self.patience = patience
        self.mode = mode
        # Seed 'best' so the first observed value always counts as an improvement.
        self.best = float('inf') if mode == 'min' else float('-inf')
        self.wait = 0
        self.stopped_epoch = 0

    def on_epoch_end(self, epoch, engine, model, logs=None, **kwargs):
        """Update the improvement counter and stop training if exhausted."""
        if logs is None:
            return

        current = logs.get(self.monitor)
        if current is None:
            return

        # Check if improved
        improved = (
            (self.mode == 'min' and current < self.best) or
            (self.mode == 'max' and current > self.best)
        )

        if improved:
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                engine.stop_training = True  # Requires Engine to support this
                print(f"\nEarly stopping at epoch {epoch}")
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
from kito.callbacks.callback_base import Callback
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ModelCheckpoint(Callback):
    """
    Save model weights during training.

    Args:
        filepath: Path template for saving weights (can include {epoch}, {val_loss}, etc.)
        monitor: Metric to monitor (e.g., 'val_loss')
        save_best_only: Only save when monitored metric improves
        mode: 'min' or 'max' depending on whether lower/higher is better
        verbose: Print message when saving

    Example:
        checkpoint = ModelCheckpoint(
            filepath='weights/model_epoch{epoch:02d}_valloss{val_loss:.4f}.pt',
            monitor='val_loss',
            save_best_only=True,
            mode='min'
        )
    """

    def __init__(
        self,
        filepath: str,
        monitor: str = 'val_loss',
        save_best_only: bool = True,
        mode: str = 'min',
        verbose: bool = False
    ):
        self.filepath = filepath
        self.monitor = monitor
        self.save_best_only = save_best_only
        self.mode = mode
        self.verbose = verbose

        # Track best metric; seeded so the first value always counts as improved.
        self.best = float('inf') if mode == 'min' else float('-inf')

        # Create directory if needed
        os.makedirs(os.path.dirname(filepath) or '.', exist_ok=True)

    def on_epoch_end(self, epoch, engine, model, logs=None, **kwargs):
        """Save model weights if the metric improved (or unconditionally
        when save_best_only=False)."""
        if logs is None:
            return

        current = logs.get(self.monitor)
        if current is None:
            return

        # Check if improved
        improved = (
            (self.mode == 'min' and current < self.best) or
            (self.mode == 'max' and current > self.best)
        )

        if improved or not self.save_best_only:
            # Only record a new best on genuine improvement. Previously
            # 'best' was overwritten on every save, so with
            # save_best_only=False a worse epoch could reset the baseline.
            if improved:
                self.best = current

            # Format filepath with epoch and metric values.
            filepath = self.filepath.format(epoch=epoch, **logs)

            # Save the underlying model's weights when wrapped in DDP.
            if hasattr(model, 'module'):
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            torch.save(state_dict, filepath)

            if self.verbose:
                # Report honestly: not every save is an improvement.
                change = "improved to" if improved else "was"
                print(f"\nEpoch {epoch}: {self.monitor} {change} {current:.4f}, "
                      f"saving model to {filepath}")
|