torchloop 0.2.0__tar.gz → 0.2.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {torchloop-0.2.0 → torchloop-0.2.3}/.github/workflows/publish.yml +1 -1
- torchloop-0.2.3/CHANGELOG.md +40 -0
- {torchloop-0.2.0 → torchloop-0.2.3}/PKG-INFO +62 -13
- {torchloop-0.2.0 → torchloop-0.2.3}/README.md +39 -5
- torchloop-0.2.3/examples/edge_deployment.py +67 -0
- {torchloop-0.2.0 → torchloop-0.2.3}/pyproject.toml +21 -7
- torchloop-0.2.3/requirements.txt +6 -0
- torchloop-0.2.3/src/torchloop/__init__.py +37 -0
- torchloop-0.2.3/src/torchloop/callbacks.py +216 -0
- torchloop-0.2.3/src/torchloop/edge/__init__.py +6 -0
- torchloop-0.2.3/src/torchloop/edge/deploy.py +209 -0
- torchloop-0.2.3/src/torchloop/edge/estimate.py +105 -0
- {torchloop-0.2.0 → torchloop-0.2.3}/src/torchloop/trainer.py +113 -27
- torchloop-0.2.3/tests/conftest.py +12 -0
- torchloop-0.2.3/tests/test_callbacks.py +191 -0
- torchloop-0.2.3/tests/test_edge.py +296 -0
- torchloop-0.2.3/tests/test_exporter.py +193 -0
- {torchloop-0.2.0 → torchloop-0.2.3}/tests/test_trainer.py +118 -3
- torchloop-0.2.0/src/torchloop/__init__.py +0 -16
- {torchloop-0.2.0 → torchloop-0.2.3}/.github/workflows/ci.yml +0 -0
- {torchloop-0.2.0 → torchloop-0.2.3}/.gitignore +0 -0
- {torchloop-0.2.0 → torchloop-0.2.3}/LICENSE +0 -0
- {torchloop-0.2.0 → torchloop-0.2.3}/src/torchloop/evaluator.py +0 -0
- {torchloop-0.2.0 → torchloop-0.2.3}/src/torchloop/exporter.py +0 -0
- {torchloop-0.2.0 → torchloop-0.2.3}/tests/__init__.py +0 -0
- {torchloop-0.2.0 → torchloop-0.2.3}/tests/test_evaluator.py +0 -0
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
# Changelog
|
|
2
|
+
|
|
3
|
+
All notable changes to this project will be documented in this file.
|
|
4
|
+
|
|
5
|
+
The format is based on Keep a Changelog, and this project adheres to Semantic Versioning.
|
|
6
|
+
|
|
7
|
+
## [0.2.3] - 2026-03-27
|
|
8
|
+
|
|
9
|
+
### Added
|
|
10
|
+
|
|
11
|
+
- Trainer gradient accumulation via `accumulate_steps`.
|
|
12
|
+
- Callback system with `Callback`, `EarlyStopping`, `ModelCheckpoint`, `CSVLogger`, and `StopTraining`.
|
|
13
|
+
- Edge deployment submodule with `deploy_to_edge` and `estimate_model`.
|
|
14
|
+
- Example script for edge deployment in `examples/edge_deployment.py`.
|
|
15
|
+
- Expanded test suite for callbacks, exporter, and edge deployment logic.
|
|
16
|
+
|
|
17
|
+
### Changed
|
|
18
|
+
|
|
19
|
+
- AMP configuration now prefers `use_amp`; legacy `amp` is kept as a deprecated alias.
|
|
20
|
+
- Public exports updated to include `Exporter`, callbacks, and edge APIs.
|
|
21
|
+
- Project metadata updated for v0.2.3 and edge/dev optional dependencies.
|
|
22
|
+
- README updated with callback usage and edge deployment documentation.
|
|
23
|
+
|
|
24
|
+
### Fixed
|
|
25
|
+
|
|
26
|
+
- Pytest import resolution for src-layout test runs.
|
|
27
|
+
- Deterministic early-stopping test behavior in trainer test suite.
|
|
28
|
+
|
|
29
|
+
## [0.2.0] - 2026-03-xx
|
|
30
|
+
|
|
31
|
+
### Added
|
|
32
|
+
|
|
33
|
+
- Automatic mixed precision (AMP) support in Trainer.
|
|
34
|
+
- Learning-rate scheduler integration in Trainer.
|
|
35
|
+
|
|
36
|
+
## [0.1.0] - 2026-03-xx
|
|
37
|
+
|
|
38
|
+
### Added
|
|
39
|
+
|
|
40
|
+
- Core `Trainer`, `Evaluator`, and `Exporter` APIs.
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: torchloop
|
|
3
|
-
Version: 0.2.0
|
|
3
|
+
Version: 0.2.3
|
|
4
4
|
Summary: Lightweight PyTorch utility library for training, evaluation, and TFLite export — without the framework lock-in.
|
|
5
5
|
Project-URL: Homepage, https://github.com/Tharun007-TK/torchloop
|
|
6
6
|
Project-URL: Repository, https://github.com/Tharun007-TK/torchloop
|
|
@@ -40,24 +40,39 @@ Classifier: Programming Language :: Python :: 3.11
|
|
|
40
40
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
41
41
|
Requires-Python: >=3.9
|
|
42
42
|
Requires-Dist: matplotlib>=3.7.0
|
|
43
|
-
Requires-Dist: numpy>=1.
|
|
44
|
-
Requires-Dist: scikit-learn>=1.
|
|
45
|
-
Requires-Dist: torch
|
|
43
|
+
Requires-Dist: numpy>=1.20
|
|
44
|
+
Requires-Dist: scikit-learn>=1.0
|
|
45
|
+
Requires-Dist: torch<2.6,>=2.0
|
|
46
46
|
Requires-Dist: torchvision>=0.15.0
|
|
47
47
|
Requires-Dist: tqdm>=4.65.0
|
|
48
48
|
Provides-Extra: all
|
|
49
|
+
Requires-Dist: black>=23.0; extra == 'all'
|
|
50
|
+
Requires-Dist: coremltools>=7.0; (platform_system == 'Darwin') and extra == 'all'
|
|
49
51
|
Requires-Dist: hatch>=1.7.0; extra == 'all'
|
|
52
|
+
Requires-Dist: mypy>=1.0; extra == 'all'
|
|
50
53
|
Requires-Dist: onnx>=1.14.0; extra == 'all'
|
|
54
|
+
Requires-Dist: onnx>=1.15; extra == 'all'
|
|
51
55
|
Requires-Dist: onnxruntime>=1.15.0; extra == 'all'
|
|
52
|
-
Requires-Dist:
|
|
53
|
-
Requires-Dist: pytest>=
|
|
56
|
+
Requires-Dist: onnxruntime>=1.16; extra == 'all'
|
|
57
|
+
Requires-Dist: pytest-cov>=4.0; extra == 'all'
|
|
58
|
+
Requires-Dist: pytest>=7.0; extra == 'all'
|
|
54
59
|
Requires-Dist: ruff>=0.1.0; extra == 'all'
|
|
60
|
+
Requires-Dist: tensorflow-lite-runtime>=2.15; (platform_system != 'Darwin') and extra == 'all'
|
|
55
61
|
Requires-Dist: tensorflow>=2.13.0; extra == 'all'
|
|
62
|
+
Requires-Dist: torchinfo>=1.8; extra == 'all'
|
|
56
63
|
Provides-Extra: dev
|
|
64
|
+
Requires-Dist: black>=23.0; extra == 'dev'
|
|
57
65
|
Requires-Dist: hatch>=1.7.0; extra == 'dev'
|
|
58
|
-
Requires-Dist:
|
|
59
|
-
Requires-Dist: pytest>=
|
|
66
|
+
Requires-Dist: mypy>=1.0; extra == 'dev'
|
|
67
|
+
Requires-Dist: pytest-cov>=4.0; extra == 'dev'
|
|
68
|
+
Requires-Dist: pytest>=7.0; extra == 'dev'
|
|
60
69
|
Requires-Dist: ruff>=0.1.0; extra == 'dev'
|
|
70
|
+
Requires-Dist: torchinfo>=1.8; extra == 'dev'
|
|
71
|
+
Provides-Extra: edge
|
|
72
|
+
Requires-Dist: coremltools>=7.0; (platform_system == 'Darwin') and extra == 'edge'
|
|
73
|
+
Requires-Dist: onnx>=1.15; extra == 'edge'
|
|
74
|
+
Requires-Dist: onnxruntime>=1.16; extra == 'edge'
|
|
75
|
+
Requires-Dist: tensorflow-lite-runtime>=2.15; (platform_system != 'Darwin') and extra == 'edge'
|
|
61
76
|
Provides-Extra: export
|
|
62
77
|
Requires-Dist: onnx>=1.14.0; extra == 'export'
|
|
63
78
|
Requires-Dist: onnxruntime>=1.15.0; extra == 'export'
|
|
@@ -86,10 +101,17 @@ You write the same PyTorch training loop in every project. Same checkpoint logic
|
|
|
86
101
|
## Install
|
|
87
102
|
|
|
88
103
|
```bash
|
|
104
|
+
# Base installation
|
|
89
105
|
pip install torchloop
|
|
90
106
|
|
|
91
107
|
# With TFLite export support
|
|
92
108
|
pip install torchloop[export]
|
|
109
|
+
|
|
110
|
+
# With edge deployment support
|
|
111
|
+
pip install torchloop[edge]
|
|
112
|
+
|
|
113
|
+
# Development setup
|
|
114
|
+
pip install torchloop[dev]
|
|
93
115
|
```
|
|
94
116
|
|
|
95
117
|
---
|
|
@@ -99,16 +121,21 @@ pip install torchloop[export]
|
|
|
99
121
|
### Training
|
|
100
122
|
|
|
101
123
|
```python
|
|
102
|
-
from torchloop import Trainer
|
|
124
|
+
from torchloop import EarlyStopping, ModelCheckpoint, Trainer
|
|
103
125
|
|
|
104
126
|
trainer = Trainer(
|
|
105
127
|
model,
|
|
106
128
|
optimizer=torch.optim.Adam(model.parameters()),
|
|
107
129
|
criterion=torch.nn.CrossEntropyLoss(),
|
|
108
130
|
device="cuda",
|
|
109
|
-
|
|
131
|
+
use_amp=True,
|
|
132
|
+
accumulate_steps=4,
|
|
133
|
+
patience=5,
|
|
110
134
|
)
|
|
111
135
|
|
|
136
|
+
trainer.add_callback(EarlyStopping(patience=5))
|
|
137
|
+
trainer.add_callback(ModelCheckpoint(filepath="best.pt"))
|
|
138
|
+
|
|
112
139
|
history = trainer.fit(train_loader, val_loader, epochs=30)
|
|
113
140
|
trainer.save("best.pt")
|
|
114
141
|
```
|
|
@@ -139,6 +166,25 @@ exp.to_onnx("model.onnx")
|
|
|
139
166
|
exp.to_tflite("model.tflite", quantize=True)
|
|
140
167
|
```
|
|
141
168
|
|
|
169
|
+
### Edge Deployment
|
|
170
|
+
|
|
171
|
+
```python
|
|
172
|
+
from torchloop.edge import deploy_to_edge, estimate_model
|
|
173
|
+
|
|
174
|
+
stats = estimate_model(model, (1, 3, 224, 224), target_device="esp32")
|
|
175
|
+
print(f"RAM: {stats['estimated_ram_mb']} MB")
|
|
176
|
+
print(f"Latency: {stats['estimated_latency_ms']} ms")
|
|
177
|
+
|
|
178
|
+
deploy_to_edge(
|
|
179
|
+
model,
|
|
180
|
+
target="esp32",
|
|
181
|
+
input_shape=(1, 3, 224, 224),
|
|
182
|
+
output_path="model.tflite",
|
|
183
|
+
quantize=True,
|
|
184
|
+
quantize_type="int8",
|
|
185
|
+
)
|
|
186
|
+
```
|
|
187
|
+
|
|
142
188
|
---
|
|
143
189
|
|
|
144
190
|
## Design Principles
|
|
@@ -151,9 +197,12 @@ exp.to_tflite("model.tflite", quantize=True)
|
|
|
151
197
|
|
|
152
198
|
## Roadmap
|
|
153
199
|
|
|
154
|
-
- [
|
|
155
|
-
- [
|
|
156
|
-
- [
|
|
200
|
+
- [x] `v0.1.0` — Trainer, Evaluator, Exporter
|
|
201
|
+
- [x] `v0.2.0` — LR scheduler support, mixed precision (AMP)
|
|
202
|
+
- [x] `v0.2.1` — Gradient accumulation + callbacks
|
|
203
|
+
- [x] `v0.2.2` — Edge submodule
|
|
204
|
+
- [ ] `v0.3.0` — W&B / MLflow hooks + CoreML export
|
|
205
|
+
- [ ] `v0.3.1` — Model pruning utilities
|
|
157
206
|
|
|
158
207
|
---
|
|
159
208
|
|
|
@@ -20,10 +20,17 @@ You write the same PyTorch training loop in every project. Same checkpoint logic
|
|
|
20
20
|
## Install
|
|
21
21
|
|
|
22
22
|
```bash
|
|
23
|
+
# Base installation
|
|
23
24
|
pip install torchloop
|
|
24
25
|
|
|
25
26
|
# With TFLite export support
|
|
26
27
|
pip install torchloop[export]
|
|
28
|
+
|
|
29
|
+
# With edge deployment support
|
|
30
|
+
pip install torchloop[edge]
|
|
31
|
+
|
|
32
|
+
# Development setup
|
|
33
|
+
pip install torchloop[dev]
|
|
27
34
|
```
|
|
28
35
|
|
|
29
36
|
---
|
|
@@ -33,16 +40,21 @@ pip install torchloop[export]
|
|
|
33
40
|
### Training
|
|
34
41
|
|
|
35
42
|
```python
|
|
36
|
-
from torchloop import Trainer
|
|
43
|
+
from torchloop import EarlyStopping, ModelCheckpoint, Trainer
|
|
37
44
|
|
|
38
45
|
trainer = Trainer(
|
|
39
46
|
model,
|
|
40
47
|
optimizer=torch.optim.Adam(model.parameters()),
|
|
41
48
|
criterion=torch.nn.CrossEntropyLoss(),
|
|
42
49
|
device="cuda",
|
|
43
|
-
|
|
50
|
+
use_amp=True,
|
|
51
|
+
accumulate_steps=4,
|
|
52
|
+
patience=5,
|
|
44
53
|
)
|
|
45
54
|
|
|
55
|
+
trainer.add_callback(EarlyStopping(patience=5))
|
|
56
|
+
trainer.add_callback(ModelCheckpoint(filepath="best.pt"))
|
|
57
|
+
|
|
46
58
|
history = trainer.fit(train_loader, val_loader, epochs=30)
|
|
47
59
|
trainer.save("best.pt")
|
|
48
60
|
```
|
|
@@ -73,6 +85,25 @@ exp.to_onnx("model.onnx")
|
|
|
73
85
|
exp.to_tflite("model.tflite", quantize=True)
|
|
74
86
|
```
|
|
75
87
|
|
|
88
|
+
### Edge Deployment
|
|
89
|
+
|
|
90
|
+
```python
|
|
91
|
+
from torchloop.edge import deploy_to_edge, estimate_model
|
|
92
|
+
|
|
93
|
+
stats = estimate_model(model, (1, 3, 224, 224), target_device="esp32")
|
|
94
|
+
print(f"RAM: {stats['estimated_ram_mb']} MB")
|
|
95
|
+
print(f"Latency: {stats['estimated_latency_ms']} ms")
|
|
96
|
+
|
|
97
|
+
deploy_to_edge(
|
|
98
|
+
model,
|
|
99
|
+
target="esp32",
|
|
100
|
+
input_shape=(1, 3, 224, 224),
|
|
101
|
+
output_path="model.tflite",
|
|
102
|
+
quantize=True,
|
|
103
|
+
quantize_type="int8",
|
|
104
|
+
)
|
|
105
|
+
```
|
|
106
|
+
|
|
76
107
|
---
|
|
77
108
|
|
|
78
109
|
## Design Principles
|
|
@@ -85,9 +116,12 @@ exp.to_tflite("model.tflite", quantize=True)
|
|
|
85
116
|
|
|
86
117
|
## Roadmap
|
|
87
118
|
|
|
88
|
-
- [
|
|
89
|
-
- [
|
|
90
|
-
- [
|
|
119
|
+
- [x] `v0.1.0` — Trainer, Evaluator, Exporter
|
|
120
|
+
- [x] `v0.2.0` — LR scheduler support, mixed precision (AMP)
|
|
121
|
+
- [x] `v0.2.1` — Gradient accumulation + callbacks
|
|
122
|
+
- [x] `v0.2.2` — Edge submodule
|
|
123
|
+
- [ ] `v0.3.0` — W&B / MLflow hooks + CoreML export
|
|
124
|
+
- [ ] `v0.3.1` — Model pruning utilities
|
|
91
125
|
|
|
92
126
|
---
|
|
93
127
|
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""Example: train a tiny CNN and prepare an ESP32-friendly artifact."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn as nn
|
|
7
|
+
from torch.utils.data import DataLoader, TensorDataset
|
|
8
|
+
|
|
9
|
+
from torchloop import EarlyStopping, Trainer
|
|
10
|
+
from torchloop.edge import deploy_to_edge, estimate_model
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class SimpleCNN(nn.Module):
|
|
14
|
+
"""Small CNN for demonstration."""
|
|
15
|
+
|
|
16
|
+
def __init__(self) -> None:
|
|
17
|
+
super().__init__()
|
|
18
|
+
self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
|
|
19
|
+
self.pool = nn.AdaptiveAvgPool2d(1)
|
|
20
|
+
self.fc = nn.Linear(16, 10)
|
|
21
|
+
|
|
22
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
23
|
+
x = torch.relu(self.conv(x))
|
|
24
|
+
x = self.pool(x)
|
|
25
|
+
x = x.view(x.size(0), -1)
|
|
26
|
+
return self.fc(x)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def main() -> None:
|
|
30
|
+
"""Run training, estimation, and deployment."""
|
|
31
|
+
model = SimpleCNN()
|
|
32
|
+
x_data = torch.randn(100, 3, 32, 32)
|
|
33
|
+
y_data = torch.randint(0, 10, (100,))
|
|
34
|
+
|
|
35
|
+
dataset = TensorDataset(x_data, y_data)
|
|
36
|
+
loader = DataLoader(dataset, batch_size=10, shuffle=True)
|
|
37
|
+
|
|
38
|
+
trainer = Trainer(
|
|
39
|
+
model,
|
|
40
|
+
optimizer=torch.optim.Adam(model.parameters()),
|
|
41
|
+
criterion=nn.CrossEntropyLoss(),
|
|
42
|
+
device="cpu",
|
|
43
|
+
use_amp=False,
|
|
44
|
+
accumulate_steps=1,
|
|
45
|
+
)
|
|
46
|
+
trainer.add_callback(EarlyStopping(patience=3))
|
|
47
|
+
trainer.fit(loader, loader, epochs=10)
|
|
48
|
+
|
|
49
|
+
stats = estimate_model(model, (1, 3, 32, 32), target_device="esp32")
|
|
50
|
+
print("ESP32 estimate")
|
|
51
|
+
print(f"Params: {stats['params']:,}")
|
|
52
|
+
print(f"RAM: {stats['estimated_ram_mb']:.2f} MB")
|
|
53
|
+
print(f"Latency: {stats['estimated_latency_ms']:.2f} ms")
|
|
54
|
+
|
|
55
|
+
deploy_to_edge(
|
|
56
|
+
model,
|
|
57
|
+
target="esp32",
|
|
58
|
+
input_shape=(1, 3, 32, 32),
|
|
59
|
+
output_path="model_esp32.tflite",
|
|
60
|
+
quantize=True,
|
|
61
|
+
quantize_type="int8",
|
|
62
|
+
)
|
|
63
|
+
print("Export complete: model_esp32.tflite")
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
if __name__ == "__main__":
|
|
67
|
+
main()
|
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "torchloop"
|
|
7
|
-
version = "0.2.0"
|
|
7
|
+
version = "0.2.3"
|
|
8
8
|
description = "Lightweight PyTorch utility library for training, evaluation, and TFLite export — without the framework lock-in."
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
license = { file = "LICENSE" }
|
|
@@ -26,27 +26,36 @@ classifiers = [
|
|
|
26
26
|
]
|
|
27
27
|
|
|
28
28
|
dependencies = [
|
|
29
|
-
"torch>=2.0.
|
|
29
|
+
"torch>=2.0,<2.6",
|
|
30
30
|
"torchvision>=0.15.0",
|
|
31
|
-
"scikit-learn>=1.
|
|
32
|
-
"numpy>=1.
|
|
31
|
+
"scikit-learn>=1.0",
|
|
32
|
+
"numpy>=1.20",
|
|
33
33
|
"matplotlib>=3.7.0",
|
|
34
34
|
"tqdm>=4.65.0",
|
|
35
35
|
]
|
|
36
36
|
|
|
37
37
|
[project.optional-dependencies]
|
|
38
|
+
edge = [
|
|
39
|
+
"onnx>=1.15",
|
|
40
|
+
"onnxruntime>=1.16",
|
|
41
|
+
"tensorflow-lite-runtime>=2.15; platform_system != 'Darwin'",
|
|
42
|
+
"coremltools>=7.0; platform_system == 'Darwin'",
|
|
43
|
+
]
|
|
38
44
|
export = [
|
|
39
45
|
"onnx>=1.14.0",
|
|
40
46
|
"onnxruntime>=1.15.0",
|
|
41
47
|
"tensorflow>=2.13.0", # needed for TFLite conversion
|
|
42
48
|
]
|
|
43
49
|
dev = [
|
|
44
|
-
"pytest>=7.
|
|
45
|
-
"pytest-cov>=4.
|
|
50
|
+
"pytest>=7.0",
|
|
51
|
+
"pytest-cov>=4.0",
|
|
52
|
+
"black>=23.0",
|
|
53
|
+
"mypy>=1.0",
|
|
54
|
+
"torchinfo>=1.8",
|
|
46
55
|
"ruff>=0.1.0", # linter + formatter
|
|
47
56
|
"hatch>=1.7.0",
|
|
48
57
|
]
|
|
49
|
-
all = ["torchloop[export,dev]"]
|
|
58
|
+
all = ["torchloop[edge,export,dev]"]
|
|
50
59
|
|
|
51
60
|
[project.urls]
|
|
52
61
|
Homepage = "https://github.com/Tharun007-TK/torchloop"
|
|
@@ -65,4 +74,9 @@ select = ["E", "F", "I"] # pycodestyle + pyflakes + isort
|
|
|
65
74
|
|
|
66
75
|
[tool.pytest.ini_options]
|
|
67
76
|
testpaths = ["tests"]
|
|
77
|
+
pythonpath = ["src"]
|
|
68
78
|
addopts = "--cov=src/torchloop --cov-report=term-missing"
|
|
79
|
+
filterwarnings = [
|
|
80
|
+
"ignore:AMP requested on non-CUDA.*:UserWarning",
|
|
81
|
+
"ignore::DeprecationWarning",
|
|
82
|
+
]
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""
|
|
2
|
+
torchloop — Lightweight PyTorch utility library.
|
|
3
|
+
|
|
4
|
+
Modules:
|
|
5
|
+
trainer : Training loop, metric logging, checkpoint management
|
|
6
|
+
evaluator : Classification report, confusion matrix, per-class F1
|
|
7
|
+
exporter : PyTorch → ONNX → TFLite with optional quantization
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
__version__ = "0.2.3"
|
|
11
|
+
__author__ = "Tharun Kumar"
|
|
12
|
+
|
|
13
|
+
from torchloop.callbacks import (
|
|
14
|
+
CSVLogger,
|
|
15
|
+
Callback,
|
|
16
|
+
EarlyStopping,
|
|
17
|
+
ModelCheckpoint,
|
|
18
|
+
StopTraining,
|
|
19
|
+
)
|
|
20
|
+
from torchloop.edge import deploy_to_edge, estimate_model
|
|
21
|
+
from torchloop.evaluator import Evaluator
|
|
22
|
+
from torchloop.exporter import Exporter
|
|
23
|
+
from torchloop.trainer import Trainer
|
|
24
|
+
|
|
25
|
+
__all__ = [
|
|
26
|
+
"Trainer",
|
|
27
|
+
"Evaluator",
|
|
28
|
+
"Exporter",
|
|
29
|
+
"Callback",
|
|
30
|
+
"EarlyStopping",
|
|
31
|
+
"ModelCheckpoint",
|
|
32
|
+
"CSVLogger",
|
|
33
|
+
"StopTraining",
|
|
34
|
+
"deploy_to_edge",
|
|
35
|
+
"estimate_model",
|
|
36
|
+
"__version__",
|
|
37
|
+
]
|
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
"""Callback utilities for training lifecycle hooks."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import csv
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Optional
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class StopTraining(Exception):
|
|
13
|
+
"""Raised by callbacks to request early stop of training."""
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class Callback:
|
|
17
|
+
"""Base callback class.
|
|
18
|
+
|
|
19
|
+
Subclass this and override hook methods to customize training behavior.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
def on_train_begin(self, logs: Optional[dict[str, Any]] = None) -> None:
|
|
23
|
+
"""Called before the first epoch starts.
|
|
24
|
+
|
|
25
|
+
Args:
|
|
26
|
+
logs: Optional training state dictionary.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
def on_train_end(self, logs: Optional[dict[str, Any]] = None) -> None:
|
|
30
|
+
"""Called after training ends.
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
logs: Optional training state dictionary.
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
def on_epoch_begin(
|
|
37
|
+
self,
|
|
38
|
+
epoch: int,
|
|
39
|
+
logs: Optional[dict[str, Any]] = None,
|
|
40
|
+
) -> None:
|
|
41
|
+
"""Called at the start of each epoch.
|
|
42
|
+
|
|
43
|
+
Args:
|
|
44
|
+
epoch: Current epoch number.
|
|
45
|
+
logs: Optional training state dictionary.
|
|
46
|
+
"""
|
|
47
|
+
|
|
48
|
+
def on_epoch_end(
|
|
49
|
+
self,
|
|
50
|
+
epoch: int,
|
|
51
|
+
logs: Optional[dict[str, Any]] = None,
|
|
52
|
+
) -> None:
|
|
53
|
+
"""Called at the end of each epoch.
|
|
54
|
+
|
|
55
|
+
Args:
|
|
56
|
+
epoch: Current epoch number.
|
|
57
|
+
logs: Optional training state dictionary.
|
|
58
|
+
"""
|
|
59
|
+
|
|
60
|
+
def on_batch_end(
|
|
61
|
+
self,
|
|
62
|
+
batch: int,
|
|
63
|
+
logs: Optional[dict[str, Any]] = None,
|
|
64
|
+
) -> None:
|
|
65
|
+
"""Called after each training batch.
|
|
66
|
+
|
|
67
|
+
Args:
|
|
68
|
+
batch: Zero-based batch index.
|
|
69
|
+
logs: Optional batch-level metrics.
|
|
70
|
+
"""
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class EarlyStopping(Callback):
|
|
74
|
+
"""Stop training when a monitored metric stops improving.
|
|
75
|
+
|
|
76
|
+
Args:
|
|
77
|
+
patience: Number of epochs to wait for an improvement.
|
|
78
|
+
monitor: Metric name expected in logs.
|
|
79
|
+
min_delta: Minimum improvement required to reset patience.
|
|
80
|
+
"""
|
|
81
|
+
|
|
82
|
+
def __init__(
|
|
83
|
+
self,
|
|
84
|
+
patience: int = 5,
|
|
85
|
+
monitor: str = "val_loss",
|
|
86
|
+
min_delta: float = 0.0,
|
|
87
|
+
) -> None:
|
|
88
|
+
if patience < 1:
|
|
89
|
+
raise ValueError("patience must be >= 1")
|
|
90
|
+
if min_delta < 0:
|
|
91
|
+
raise ValueError("min_delta must be >= 0")
|
|
92
|
+
|
|
93
|
+
self.patience = patience
|
|
94
|
+
self.monitor = monitor
|
|
95
|
+
self.min_delta = min_delta
|
|
96
|
+
self.best = float("inf")
|
|
97
|
+
self.counter = 0
|
|
98
|
+
|
|
99
|
+
def on_epoch_end(
|
|
100
|
+
self,
|
|
101
|
+
epoch: int,
|
|
102
|
+
logs: Optional[dict[str, Any]] = None,
|
|
103
|
+
) -> None:
|
|
104
|
+
"""Check metric value and raise StopTraining when patience is exceeded.
|
|
105
|
+
|
|
106
|
+
Args:
|
|
107
|
+
epoch: Current epoch number.
|
|
108
|
+
logs: Metrics dictionary from trainer.
|
|
109
|
+
"""
|
|
110
|
+
logs = logs or {}
|
|
111
|
+
current = logs.get(self.monitor)
|
|
112
|
+
if current is None:
|
|
113
|
+
return
|
|
114
|
+
|
|
115
|
+
if current <= self.best - self.min_delta:
|
|
116
|
+
self.best = float(current)
|
|
117
|
+
self.counter = 0
|
|
118
|
+
return
|
|
119
|
+
|
|
120
|
+
self.counter += 1
|
|
121
|
+
if self.counter >= self.patience:
|
|
122
|
+
raise StopTraining(f"Early stopping triggered at epoch {epoch}")
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class ModelCheckpoint(Callback):
|
|
126
|
+
"""Save model checkpoints based on monitored metric.
|
|
127
|
+
|
|
128
|
+
Args:
|
|
129
|
+
monitor: Metric name expected in logs.
|
|
130
|
+
save_best_only: Save only when monitored metric improves.
|
|
131
|
+
filepath: Path to write checkpoint file.
|
|
132
|
+
"""
|
|
133
|
+
|
|
134
|
+
def __init__(
|
|
135
|
+
self,
|
|
136
|
+
monitor: str = "val_loss",
|
|
137
|
+
save_best_only: bool = True,
|
|
138
|
+
filepath: str = "checkpoint.pt",
|
|
139
|
+
) -> None:
|
|
140
|
+
self.monitor = monitor
|
|
141
|
+
self.save_best_only = save_best_only
|
|
142
|
+
self.filepath = Path(filepath)
|
|
143
|
+
self.best = float("inf")
|
|
144
|
+
|
|
145
|
+
def on_epoch_end(
|
|
146
|
+
self,
|
|
147
|
+
epoch: int,
|
|
148
|
+
logs: Optional[dict[str, Any]] = None,
|
|
149
|
+
) -> None:
|
|
150
|
+
"""Persist model checkpoint if criteria are met.
|
|
151
|
+
|
|
152
|
+
Args:
|
|
153
|
+
epoch: Current epoch number.
|
|
154
|
+
logs: Metrics dictionary from trainer.
|
|
155
|
+
"""
|
|
156
|
+
logs = logs or {}
|
|
157
|
+
model = logs.get("model")
|
|
158
|
+
if model is None:
|
|
159
|
+
return
|
|
160
|
+
|
|
161
|
+
current = logs.get(self.monitor)
|
|
162
|
+
if current is None:
|
|
163
|
+
return
|
|
164
|
+
|
|
165
|
+
if not self.save_best_only:
|
|
166
|
+
torch.save(model.state_dict(), self.filepath)
|
|
167
|
+
return
|
|
168
|
+
|
|
169
|
+
if float(current) < self.best:
|
|
170
|
+
self.best = float(current)
|
|
171
|
+
torch.save(model.state_dict(), self.filepath)
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
class CSVLogger(Callback):
|
|
175
|
+
"""Log epoch metrics to a CSV file.
|
|
176
|
+
|
|
177
|
+
Args:
|
|
178
|
+
log_dir: Directory where CSV logs are written.
|
|
179
|
+
filename: Name of CSV file.
|
|
180
|
+
"""
|
|
181
|
+
|
|
182
|
+
def __init__(self, log_dir: str = "./logs", filename: str = "metrics.csv") -> None:
|
|
183
|
+
self.log_dir = Path(log_dir)
|
|
184
|
+
self.filename = filename
|
|
185
|
+
self._initialized = False
|
|
186
|
+
self._fieldnames: list[str] = []
|
|
187
|
+
|
|
188
|
+
def on_epoch_end(
|
|
189
|
+
self,
|
|
190
|
+
epoch: int,
|
|
191
|
+
logs: Optional[dict[str, Any]] = None,
|
|
192
|
+
) -> None:
|
|
193
|
+
"""Append epoch metrics to CSV.
|
|
194
|
+
|
|
195
|
+
Args:
|
|
196
|
+
epoch: Current epoch number.
|
|
197
|
+
logs: Metrics dictionary from trainer.
|
|
198
|
+
"""
|
|
199
|
+
logs = dict(logs or {})
|
|
200
|
+
logs["epoch"] = epoch
|
|
201
|
+
|
|
202
|
+
self.log_dir.mkdir(parents=True, exist_ok=True)
|
|
203
|
+
file_path = self.log_dir / self.filename
|
|
204
|
+
|
|
205
|
+
if not self._initialized:
|
|
206
|
+
self._fieldnames = sorted(logs.keys())
|
|
207
|
+
with file_path.open("w", newline="", encoding="utf-8") as f:
|
|
208
|
+
writer = csv.DictWriter(f, fieldnames=self._fieldnames)
|
|
209
|
+
writer.writeheader()
|
|
210
|
+
writer.writerow({k: logs.get(k) for k in self._fieldnames})
|
|
211
|
+
self._initialized = True
|
|
212
|
+
return
|
|
213
|
+
|
|
214
|
+
with file_path.open("a", newline="", encoding="utf-8") as f:
|
|
215
|
+
writer = csv.DictWriter(f, fieldnames=self._fieldnames)
|
|
216
|
+
writer.writerow({k: logs.get(k) for k in self._fieldnames})
|