pytorch_kito-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kito/__init__.py +49 -0
- kito/callbacks/__init__.py +20 -0
- kito/callbacks/callback_base.py +107 -0
- kito/callbacks/csv_logger.py +66 -0
- kito/callbacks/ddp_aware_callback.py +60 -0
- kito/callbacks/early_stopping_callback.py +45 -0
- kito/callbacks/modelcheckpoint.py +78 -0
- kito/callbacks/tensorboard_callback_images.py +298 -0
- kito/callbacks/tensorboard_callbacks.py +132 -0
- kito/callbacks/txt_logger.py +57 -0
- kito/config/__init__.py +0 -0
- kito/config/moduleconfig.py +201 -0
- kito/data/__init__.py +35 -0
- kito/data/datapipeline.py +273 -0
- kito/data/datasets.py +166 -0
- kito/data/preprocessed_dataset.py +57 -0
- kito/data/preprocessing.py +318 -0
- kito/data/registry.py +96 -0
- kito/engine.py +841 -0
- kito/module.py +447 -0
- kito/strategies/__init__.py +0 -0
- kito/strategies/logger_strategy.py +51 -0
- kito/strategies/progress_bar_strategy.py +57 -0
- kito/strategies/readiness_validator.py +85 -0
- kito/utils/__init__.py +0 -0
- kito/utils/decorators.py +45 -0
- kito/utils/gpu_utils.py +94 -0
- kito/utils/loss_utils.py +38 -0
- kito/utils/ssim_utils.py +94 -0
- pytorch_kito-0.2.0.dist-info/METADATA +328 -0
- pytorch_kito-0.2.0.dist-info/RECORD +34 -0
- pytorch_kito-0.2.0.dist-info/WHEEL +5 -0
- pytorch_kito-0.2.0.dist-info/licenses/LICENSE +21 -0
- pytorch_kito-0.2.0.dist-info/top_level.txt +1 -0
kito/utils/gpu_utils.py
ADDED
@@ -0,0 +1,94 @@
+import torch
+
+
+def assign_device(device_type: str = "cuda", gpu_id: int = 0) -> torch.device:
+    """
+    Assign device based on preference and availability.
+
+    Args:
+        device_type: Preferred device ("cuda", "mps", or "cpu")
+        gpu_id: GPU ID for CUDA
+
+    Returns:
+        torch.device: Best available device
+
+    Behavior:
+        - "cuda": Use CUDA if available, else CPU (with warning)
+        - "mps": Use MPS if available, else CPU (with warning)
+        - "cpu": Always CPU (no fallback)
+    """
+    valid_devices = {"cuda", "mps", "cpu"}  # already validated earlier; this check could be removed
+    device_type_lower = device_type.lower()
+
+    if device_type_lower not in valid_devices:
+        raise ValueError(
+            f"Invalid device_type: '{device_type}'. "
+            f"Must be one of {valid_devices}."
+        )
+
+    if device_type_lower == "cuda":
+        if torch.cuda.is_available():
+            return torch.device(f"cuda:{gpu_id}")
+        else:
+            print("Warning: CUDA requested but not available. Falling back to CPU.")
+            return torch.device("cpu")
+
+    elif device_type_lower == "mps":
+        if torch.backends.mps.is_available():
+            return torch.device("mps")
+        else:
+            print("Warning: MPS requested but not available. Falling back to CPU.")
+            return torch.device("cpu")
+
+    elif device_type_lower == "cpu":
+        return torch.device("cpu")
+
+def get_available_devices() -> dict:
+    """
+    Get information about all available devices.
+
+    Returns:
+        dict: Dictionary with device availability information
+
+    Example:
+        >>> info = get_available_devices()
+        >>> print(info)
+        {
+            'cuda': True,
+            'cuda_count': 2,
+            'mps': False,
+            'cpu': True
+        }
+    """
+    return {
+        'cuda': torch.cuda.is_available(),
+        'cuda_count': torch.cuda.device_count() if torch.cuda.is_available() else 0,
+        'mps': torch.backends.mps.is_available(),
+        'cpu': True  # CPU is always available
+    }
+
+
+def validate_device_type(device_type: str, raise_error: bool = True) -> bool:
+    """
+    Validate device_type string.
+
+    Args:
+        device_type: Device type to validate
+        raise_error: If True, raise ValueError on invalid type
+
+    Returns:
+        bool: True if valid, False otherwise
+
+    Raises:
+        ValueError: If device_type is invalid and raise_error=True
+    """
+    valid_devices = {"cuda", "mps", "cpu"}
+    is_valid = device_type.lower() in valid_devices
+
+    if not is_valid and raise_error:
+        raise ValueError(
+            f"Invalid device_type: '{device_type}'. "
+            f"Must be one of {valid_devices}."
+        )
+
+    return is_valid
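
To make the documented fallback behavior concrete, here is a minimal usage sketch. It is an editorial example, not part of the wheel; it only assumes the module path `kito.utils.gpu_utils` listed in the RECORD below.

```python
# Minimal sketch (editorial example, not shipped in the wheel).
from kito.utils.gpu_utils import assign_device, get_available_devices, validate_device_type

print(get_available_devices())            # e.g. {'cuda': True, 'cuda_count': 1, 'mps': False, 'cpu': True}
validate_device_type("cuda")              # True; anything outside {"cuda", "mps", "cpu"} raises ValueError
device = assign_device("cuda", gpu_id=0)  # prints a warning and returns CPU if CUDA is unavailable
print(device)                             # cuda:0, or cpu after fallback
```
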
kito/utils/loss_utils.py
ADDED
@@ -0,0 +1,38 @@
+from torch import nn
+
+from kito.utils.ssim_utils import ssim_loss
+
+_torch_loss_dict = {
+    'mean_squared_error': 'MSELoss',
+    'mean_absolute_error': 'L1Loss',
+    'cross_entropy_loss': 'CrossEntropyLoss',
+    'ctc_loss': 'CTCLoss',
+    'negative_log_likelihood_loss': 'NLLLoss',
+    'negative_log_likelihood_poisson_loss': 'PoissonNLLLoss',
+    'negative_log_likelihood_gaussian_loss': 'GaussianNLLLoss',
+    'kullback_leibler_divergence_loss': 'KLDivLoss',
+    'binary_cross_entropy_loss': 'BCELoss',
+    'binary_cross_entropy_logits_loss': 'BCEWithLogitsLoss',
+    'margin_ranking_loss': 'MarginRankingLoss',
+    'hinge_embedding_loss': 'HingeEmbeddingLoss',
+    'multi_label_margin_loss': 'MultiLabelMarginLoss',
+    'huber_loss': 'HuberLoss',
+    'smooth_l1_loss': 'SmoothL1Loss',
+    'soft_margin_loss': 'SoftMarginLoss',
+    'multi_label_soft_margin_loss': 'MultiLabelSoftMarginLoss',
+    'cosine_embedding_loss': 'CosineEmbeddingLoss',
+    'multi_margin_loss': 'MultiMarginLoss',
+    'triplet_margin_loss': 'TripletMarginLoss',
+    'triplet_margin_distance_loss': 'TripletMarginWithDistanceLoss'
+}
+
+
+def get_loss(loss: str):
+    assert loss != ''
+    loss = loss.lower()
+    if loss in _torch_loss_dict:
+        return getattr(nn, _torch_loss_dict[loss])()  # returns the instantiated loss object
+    if loss == 'ssim_loss':
+        return ssim_loss  # returns the custom loss function itself, not an instantiated object
+    raise ValueError(
+        f"Loss '{loss}' not valid. Supported values are: {', '.join(map(repr, _torch_loss_dict.keys()))}")
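
A short sketch of how `get_loss` resolves names, based only on the mapping above (editorial example; module path as listed in the RECORD below):

```python
# Minimal sketch (editorial example, not shipped in the wheel).
import torch
from kito.utils.loss_utils import get_loss

mse = get_loss('mean_squared_error')   # instantiated nn.MSELoss()
print(mse(torch.randn(4, 10), torch.randn(4, 10)))  # scalar tensor

ssim_fn = get_loss('ssim_loss')        # the ssim_loss function itself (not an instantiated module)
# get_loss('unknown_name') raises ValueError listing the supported names
```
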
kito/utils/ssim_utils.py
ADDED
@@ -0,0 +1,94 @@
+import torch
+import torch.nn.functional as F
+
+
+def gaussian_kernel(window_size: int, sigma: float, device, dtype):
+    """Create a 2D Gaussian kernel (window_size×window_size)."""
+    coords = torch.arange(window_size, device=device, dtype=dtype) - (window_size - 1) / 2.0
+    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
+    g = g / g.sum()
+    # outer product to get 2D kernel
+    gauss2d = g.unsqueeze(1) @ g.unsqueeze(0)
+    return gauss2d
+
+
+def ssim_2d_single(img1, img2, window_size=11, sigma=1.5, K1=0.01, K2=0.03, L=1.0):
+    """
+    Compute windowed SSIM between two 2D images (float range [0, L]).
+    img1, img2: shape (1, 1, H, W) (batch=1, channel=1)
+    Returns mean SSIM over all pixels in that 2D image.
+    """
+    # Create Gaussian window
+    device, dtype = img1.device, img1.dtype
+    kernel = gaussian_kernel(window_size, sigma, device, dtype)
+    kernel = kernel.unsqueeze(0).unsqueeze(0)  # shape (1,1,win,win)
+
+    # Convolve to get local means
+    mu1 = F.conv2d(img1, kernel, padding=window_size // 2)
+    mu2 = F.conv2d(img2, kernel, padding=window_size // 2)
+
+    mu1_sq = mu1 * mu1
+    mu2_sq = mu2 * mu2
+    mu1_mu2 = mu1 * mu2
+
+    # local variances: E[x^2] - (E[x])^2
+    sigma1_sq = F.conv2d(img1 * img1, kernel, padding=window_size // 2) - mu1_sq
+    sigma2_sq = F.conv2d(img2 * img2, kernel, padding=window_size // 2) - mu2_sq
+    sigma12 = F.conv2d(img1 * img2, kernel, padding=window_size // 2) - mu1_mu2
+
+    sigma1_sq = torch.clamp(sigma1_sq, min=1e-6)
+    sigma2_sq = torch.clamp(sigma2_sq, min=1e-6)
+
+    # SSIM constants
+    C1 = (K1 * L) ** 2
+    C2 = (K2 * L) ** 2
+    C3 = C2 / 2.0
+
+    # Luminance term
+    luminance = (2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)
+    # Contrast term
+    contrast = (2 * torch.sqrt(sigma1_sq * sigma2_sq) + C2) / (sigma1_sq + sigma2_sq + C2)
+    # Structure term
+    structure = (sigma12 + C3) / (torch.sqrt(sigma1_sq * sigma2_sq) + C3)
+
+    # Full SSIM map
+    ssim_map = luminance * contrast * structure
+    return ssim_map.mean()
+
+
+def batch_time_ssim(pred, target, window_size=11, sigma=1.5):
+    """
+    pred, target: tensors of shape (B, T, H, W, 1), normalized to [0,1].
+    Returns SSIM per (batch, time) as a tensor of shape (B, T).
+    """
+    B, T, H, W, C = pred.shape
+    assert C == 1, "SSIM code below assumes single channel"
+    ssim_scores = torch.zeros((B, T), device=pred.device, dtype=pred.dtype)
+
+    for b in range(B):
+        for t in range(T):
+            im1 = target[b, t, ..., :].permute(2, 0, 1).unsqueeze(0)  # shape (1,1,H,W)
+            im2 = pred[b, t, ..., :].permute(2, 0, 1).unsqueeze(0)  # shape (1,1,H,W)
+            ssim_scores[b, t] = ssim_2d_single(im1, im2, window_size, sigma)
+    return ssim_scores  # shape (B, T)
+
+
+def minmax_normalize_sample(x):
+    # x: (T, H, W, 1)
+    mn = x.min(dim=0, keepdim=True)[0].min(dim=1, keepdim=True)[0].min(dim=2, keepdim=True)[0]
+    mx = x.max(dim=0, keepdim=True)[0].max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0]
+    denom = (mx - mn).clamp(min=1e-6)
+    return (x - mn) / denom
+
+def ssim_loss(target, predicted):
+    no_disp_mask = (target.sum(dim=tuple(range(1, target.dim()))) == 0)
+    valid_regress_targets = target[~no_disp_mask]
+    valid_regress_outputs = predicted[~no_disp_mask]
+    B = len(valid_regress_targets)
+    pred = torch.stack([minmax_normalize_sample(valid_regress_outputs[b]) for b in range(B)], dim=0)
+    target = torch.stack([minmax_normalize_sample(valid_regress_targets[b]) for b in range(B)], dim=0)
+
+    ssim_bt = batch_time_ssim(pred, target, window_size=11, sigma=1.5)  # shape (B, T)
+    ssim_per_sample = ssim_bt.mean(dim=1)  # average SSIM over time for each sample
+    loss_ssim = (1.0 - ssim_per_sample).mean()
+    return loss_ssim
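
The functions above compute a windowed SSIM as the product of luminance, contrast, and structure terms over an 11×11 Gaussian window, and turn it into a loss as 1 - mean SSIM. A minimal sketch on dummy tensors with the (B, T, H, W, 1) layout the docstrings describe (editorial example, not shipped in the wheel):

```python
# Minimal sketch (editorial example, not shipped in the wheel).
import torch
from kito.utils.ssim_utils import batch_time_ssim, ssim_loss

target = torch.rand(2, 3, 32, 32, 1)     # (B, T, H, W, 1), values already in [0, 1]
predicted = torch.rand(2, 3, 32, 32, 1)

loss = ssim_loss(target, predicted)              # scalar: 1 - mean SSIM over valid samples and time steps
per_frame = batch_time_ssim(predicted, target)   # shape (2, 3): SSIM per (sample, time step)
print(loss.item(), per_frame.shape)
```
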

pytorch_kito-0.2.0.dist-info/METADATA
ADDED
@@ -0,0 +1,328 @@
+Metadata-Version: 2.4
+Name: pytorch-kito
+Version: 0.2.0
+Summary: Effortless PyTorch training - define your model, Kito handles the rest
+Author-email: Giuseppe Costantino <giuseppe.costantino95@gmail.com>
+License: MIT License
+
+Copyright (c) 2026 Giuseppe Costantino
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+Project-URL: Homepage, https://github.com/gcostantino/kito
+Project-URL: Documentation, https://kito.readthedocs.io
+Project-URL: Repository, https://github.com/gcostantino/kito
+Project-URL: Bug Tracker, https://github.com/gcostantino/kito/issues
+Keywords: pytorch,deep-learning,machine-learning,training,automation
+Classifier: Development Status :: 3 - Alpha
+Classifier: Intended Audience :: Science/Research
+Classifier: Intended Audience :: Developers
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Topic :: Scientific/Engineering
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Requires-Python: >=3.9
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: torch>=2.0.0
+Requires-Dist: numpy>=1.20.0
+Requires-Dist: h5py>=3.0.0
+Requires-Dist: torchsummary>=1.5.1
+Requires-Dist: packaging>=20.0
+Requires-Dist: tensorboard>=2.10.0
+Requires-Dist: pkbar>=0.5
+Provides-Extra: dev
+Requires-Dist: pytest>=7.0; extra == "dev"
+Requires-Dist: pytest-cov>=4.0; extra == "dev"
+Requires-Dist: pytest-mock>=3.10; extra == "dev"
+Requires-Dist: black>=23.0; extra == "dev"
+Requires-Dist: ruff>=0.1.0; extra == "dev"
+Provides-Extra: tensorboard
+Requires-Dist: tensorboard>=2.10.0; extra == "tensorboard"
+Provides-Extra: docs
+Requires-Dist: sphinx>=5.0; extra == "docs"
+Requires-Dist: sphinx-rtd-theme>=1.0; extra == "docs"
+Provides-Extra: all
+Requires-Dist: tensorboard>=2.10.0; extra == "all"
+Requires-Dist: pytest>=7.0; extra == "all"
+Requires-Dist: pytest-cov>=4.0; extra == "all"
+Requires-Dist: pytest-mock>=3.10; extra == "all"
+Requires-Dist: black>=23.0; extra == "all"
+Requires-Dist: ruff>=0.1.0; extra == "all"
+Requires-Dist: sphinx>=5.0; extra == "all"
+Requires-Dist: sphinx-rtd-theme>=1.0; extra == "all"
+Dynamic: license-file
+
+# Kito
+
+**Effortless PyTorch training - define your model, Kito handles the rest.**
+
+[](https://github.com/gcostantino/kito/actions)
+[](https://pypi.org/project/pytorch-kito/)
+[](https://opensource.org/licenses/MIT)
+[](https://www.python.org/downloads/)
+[](https://pypi.org/project/pytorch-kito/)
+
+Kito is a lightweight PyTorch training library that eliminates boilerplate code. Define your model architecture and loss function - Kito automatically handles training loops, optimization, callbacks, distributed training, and more.
+
+## ✨ Key Features
+
+- **Zero Boilerplate** - No training loops, no optimizer setup, no device management
+- **Auto-Everything** - Automatic model building, optimizer binding, and device assignment
+- **Built-in DDP** - Distributed training works out of the box
+- **Smart Callbacks** - TensorBoard, checkpointing, logging, and custom callbacks
+- **Flexible** - Simple for beginners, powerful for experts
+- **Lightweight** - Minimal dependencies, pure PyTorch under the hood
+
+## Quick Start
+
+### Installation
+
+```bash
+pip install pytorch-kito
+```
+
+### Your First Model in 3 Steps
+
+```python
+import torch
+import torch.nn as nn
+from kito import Engine, KitoModule
+
+# 1. Define your model
+class MyModel(KitoModule):
+    def build_inner_model(self):
+        self.model = nn.Sequential(
+            nn.Linear(784, 128),
+            nn.ReLU(),
+            nn.Linear(128, 10)
+        )
+        self.model_input_size = (784,)
+
+    def bind_optimizer(self):
+        self.optimizer = torch.optim.Adam(
+            self.model.parameters(),
+            lr=self.learning_rate
+        )
+
+# 2. Initialize
+model = MyModel('MyModel', device, config)
+engine = Engine(model, config)
+
+# 3. Train! (That's it - everything else is automatic)
+engine.fit(train_loader, val_loader, max_epochs=10)
+```
+
+## Philosophy
+
+Kito follows a **"define once, train anywhere"** philosophy:
+
+1. **You focus on**: Model architecture and research ideas
+2. **Kito handles**: Training loops, optimization, distributed training, callbacks
+
+Perfect for researchers who want to iterate quickly without rewriting training code.
+
+## Core Concepts
+
+### KitoModule
+
+Your model inherits from `KitoModule` and implements two methods:
+
+```python
+class MyModel(KitoModule):
+    def build_inner_model(self):
+        # Define your architecture
+        self.model = nn.Sequential(...)
+        self.model_input_size = (C, H, W)  # Input shape
+
+    def bind_optimizer(self):
+        # Choose your optimizer
+        self.optimizer = torch.optim.Adam(
+            self.model.parameters(),
+            lr=self.learning_rate
+        )
+```
+
+### Engine
+
+The `Engine` orchestrates everything:
+
+```python
+engine = Engine(module, config)
+
+# Training
+engine.fit(train_loader, val_loader, max_epochs=100)
+
+# Inference
+predictions = engine.predict(test_loader)
+```
+
+### Data Pipeline
+
+Kito provides a clean data pipeline with preprocessing:
+
+```python
+from kito.data import H5Dataset, GenericDataPipeline
+from kito.data.preprocessing import Pipeline, Normalize, ToTensor
+
+# Create dataset
+dataset = H5Dataset('data.h5')
+
+# Add preprocessing
+preprocessing = Pipeline([
+    Normalize(min_val=0.0, max_val=1.0),
+    ToTensor()
+])
+
+# Setup data pipeline
+pipeline = GenericDataPipeline(
+    config=config,
+    dataset=dataset,
+    preprocessing=preprocessing
+)
+pipeline.setup()
+
+# Get dataloaders
+train_loader = pipeline.train_dataloader()
+val_loader = pipeline.val_dataloader()
+```
+
+### Callbacks
+
+Kito includes powerful callbacks for common tasks:
+
+```python
+from kito.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger
+
+callbacks = [
+    ModelCheckpoint('best_model.pt', monitor='val_loss', mode='min'),
+    EarlyStopping(patience=10, monitor='val_loss'),
+    CSVLogger('training.csv')
+]
+
+engine.fit(train_loader, val_loader, callbacks=callbacks)
+```
+
+Or create custom callbacks:
+
+```python
+from kito.callbacks import Callback
+
+class MyCallback(Callback):
+    def on_epoch_end(self, epoch, logs, **kwargs):
+        print(f"Epoch {epoch}: loss={logs['train_loss']:.4f}")
+```
+
+## Advanced Features
+
+### Distributed Training (DDP)
+
+Enable distributed training with one config change:
+
+```python
+config.training.distributed_training = True
+
+# Everything else stays the same!
+engine.fit(train_loader, val_loader, max_epochs=100)
+```
+
+### Custom Training Logic
+
+Override `training_step` for custom behavior:
+
+```python
+class MyModel(KitoModule):
+    def training_step(self, batch, pbar_handler=None):
+        inputs, targets = batch
+
+        # Custom forward pass
+        outputs = self.model(inputs)
+        loss = self.compute_loss((inputs, targets), outputs)
+
+        # Custom backward (e.g., gradient clipping)
+        self.optimizer.zero_grad()
+        loss.backward()
+        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+        self.optimizer.step()
+
+        return {'loss': loss}
+```
+
+### Multiple Datasets
+
+Kito supports HDF5 and in-memory datasets out of the box:
+
+```python
+from kito.data import H5Dataset, MemDataset
+
+# HDF5 dataset (lazy loading)
+dataset = H5Dataset('large_data.h5')
+
+# In-memory dataset (fast)
+dataset = MemDataset(x_train, y_train)
+```
+
+Register custom datasets easily:
+
+```python
+from kito.data.registry import DATASETS
+
+@DATASETS.register('my_custom_dataset')
+class MyDataset(KitoDataset):
+    def _load_sample(self, index):
+        return data, labels
+```
+
+## 📦 Installation Options
+
+```bash
+# Basic installation
+pip install pytorch-kito
+
+# With TensorBoard support
+pip install pytorch-kito[tensorboard]
+
+# Development installation
+pip install pytorch-kito[dev]
+
+# Everything
+pip install pytorch-kito[all]
+```
+
+## 🤝 Contributing
+
+Contributions are very welcome! Please check out our [Contributing Guide](CONTRIBUTING.md).
+
+## 📄 License
+
+MIT License - see [LICENSE](LICENSE) file for details.
+
+## 🙏 Acknowledgments
+
+Kito is inspired by PyTorch Lightning and Keras, aiming to bring similar ease-of-use to pure PyTorch workflows for researchers.
+
+## Contact
+
+- **GitHub Issues**: [Report bugs or request features](https://github.com/gcostantino/kito/issues)
+- **GitHub Discussions**: [Ask questions](https://github.com/gcostantino/kito/discussions)
+
+---
+
+**Made with ❤️ for the PyTorch community**

pytorch_kito-0.2.0.dist-info/RECORD
ADDED
@@ -0,0 +1,34 @@
+kito/__init__.py,sha256=j5He-bAp7SZYkaisFKyxHtByt-5t4x8tZ2uThTwFrTc,938
+kito/engine.py,sha256=bmuDndaWIXy4gOdsTVyvGjQxKiHrZczzlI09ozfBRwI,30617
+kito/module.py,sha256=Zc3doiqXfAIXNdS6ZYQqdudgRO_bXT2VaABVUbkRvVY,13931
+kito/callbacks/__init__.py,sha256=Yt9BBCepsgnGCbwTsrwLkBjqIbDdsCxrbHERBbQ-748,624
+kito/callbacks/callback_base.py,sha256=AW_h-PEHuOJJA7kEcv7eSl_Gq7k8dLgqdNQX4XvUBjs,3661
+kito/callbacks/csv_logger.py,sha256=8vfUt7NO_XyUoqBNIlpC7tGEnAhbRjGJsXwErFdspH0,1850
+kito/callbacks/ddp_aware_callback.py,sha256=nty5N86jE9Ha26FBbUQpaNrZ5i8ex9lPg_l6bgI6IIQ,2124
+kito/callbacks/early_stopping_callback.py,sha256=lqrV_FDy0xFbPAkfq64eQcDS-yWbJpgfKA7e0hqJaGA,1372
+kito/callbacks/modelcheckpoint.py,sha256=_HFrgdz2fjKsUIq_wj6I_eHbbUViVdgJF3Se48FCB9Q,2376
+kito/callbacks/tensorboard_callback_images.py,sha256=Vpx-63wXfHHy5MJbibJQtuqAkOBITtjKmmJZXfjn9Sk,9210
+kito/callbacks/tensorboard_callbacks.py,sha256=xguPSogXX_U1tyOsolryNcazqCdTBHDWun5FoSrl4Tk,3701
+kito/callbacks/txt_logger.py,sha256=dNDMrwEcbC5coUua4R8rJqiIb5xYY7lyIh9maamE48o,1687
+kito/config/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+kito/config/moduleconfig.py,sha256=qY49kA-zFA4OUuE9lCrJh9cGwf6UjuAoI1k04lsNer0,5631
+kito/data/__init__.py,sha256=h9HIKEHtC3CYD3aWsAUh_dcWYiQ6fmiTOcsXhnuhYGE,723
+kito/data/datapipeline.py,sha256=CgYAtjTN_URPjXi1VvS6GY7nLCa5n3ow-F406St78jI,8292
+kito/data/datasets.py,sha256=L3ClAw9rCkCuXavgaEtSUC6mVr-WLVv6DYhOpeOvIUE,4833
+kito/data/preprocessed_dataset.py,sha256=hiLdMvyvjNWpgAylV07bnifAPAhTeN6eOaeBmIePRRA,1733
+kito/data/preprocessing.py,sha256=VawgS0Kci5gvuvPwurWE-Dj3vV2s0s5jsCdz0FuepMQ,8325
+kito/data/registry.py,sha256=rU8qXlGavSVbwJWEG_1_rbh3imiNpYgPuCz8bqY8_VQ,2279
+kito/strategies/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+kito/strategies/logger_strategy.py,sha256=76mxBKs1ZlYoZ7MHjfujH9driye3Q_w8wW0BqVX15oE,1269
+kito/strategies/progress_bar_strategy.py,sha256=z2GzKnrXyDiCJahhdNTZobS-lTb2YuOoeYVSx5ybUHU,1791
+kito/strategies/readiness_validator.py,sha256=YUV201JsKtaMgKOV93OVpBzGJjbm9sYrb2SsBCnx_jw,2684
+kito/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+kito/utils/decorators.py,sha256=p0gWNNL4F0v1zafxI7BDxUVNZugda1Vw8d4lMDIFt_g,1650
+kito/utils/gpu_utils.py,sha256=MLsOWBLmu92E71amanSc95NSJ38448QBFBaajYh1eCM,2747
+kito/utils/loss_utils.py,sha256=u5rWWBJB1e9pGLSSQARHAe5fI9ZOuLzZqu_22VDebKY,1555
+kito/utils/ssim_utils.py,sha256=gBNCWi6X1KzWofQkJXuzJ6qkXx4Ja8faCKGyWSmJj-k,3771
+pytorch_kito-0.2.0.dist-info/licenses/LICENSE,sha256=yeTd9I65vu2Fo9VOKMHVhFt9UlTbRAAGxhR7OO9Dvbs,1075
+pytorch_kito-0.2.0.dist-info/METADATA,sha256=loEohuioKK1FLmMqWnGCGsAzJQv7s-ILdDl8xI8OQcw,9977
+pytorch_kito-0.2.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+pytorch_kito-0.2.0.dist-info/top_level.txt,sha256=hDmlE0Vyhj0KWg8m24LOrJKBeDOIqL8b6-Pr5kwQRGM,5
+pytorch_kito-0.2.0.dist-info/RECORD,,

pytorch_kito-0.2.0.dist-info/licenses/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2026 Giuseppe Costantino
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

pytorch_kito-0.2.0.dist-info/top_level.txt
ADDED
@@ -0,0 +1 @@
+kito