torchloop 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,32 @@
1
+ name: CI
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+ pull_request:
7
+ branches: [main]
8
+
9
+ jobs:
10
+ test:
11
+ runs-on: ubuntu-latest
12
+ strategy:
13
+ matrix:
14
+ python-version: ["3.9", "3.10", "3.11"]
15
+
16
+ steps:
17
+ - uses: actions/checkout@v4
18
+
19
+ - name: Set up Python ${{ matrix.python-version }}
20
+ uses: actions/setup-python@v5
21
+ with:
22
+ python-version: ${{ matrix.python-version }}
23
+
24
+ - name: Install dependencies
25
+ run: |
26
+ pip install -e ".[dev]"
27
+
28
+ - name: Lint with ruff
29
+ run: ruff check src/
30
+
31
+ - name: Run tests
32
+ run: pytest
@@ -0,0 +1,30 @@
1
+ name: Publish to PyPI
2
+
3
+ on:
4
+ push:
5
+ tags:
6
+ - "v*.*.*"
7
+
8
+ jobs:
9
+ publish:
10
+ runs-on: ubuntu-latest
11
+ environment: pypi
12
+ permissions:
13
+ id-token: write
14
+
15
+ steps:
16
+ - uses: actions/checkout@v4
17
+
18
+ - name: Set up Python
19
+ uses: actions/setup-python@v5
20
+ with:
21
+ python-version: "3.11"
22
+
23
+ - name: Install hatch
24
+ run: pip install hatch
25
+
26
+ - name: Build package
27
+ run: hatch build
28
+
29
+ - name: Publish to PyPI
30
+ uses: pypa/gh-action-pypi-publish@release/v1
@@ -0,0 +1,30 @@
1
+ # Python bytecode and caches
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # Virtual environments
7
+ venv/
8
+ .venv/
9
+ env/
10
+ ENV/
11
+
12
+ # Packaging/build artifacts
13
+ build/
14
+ dist/
15
+ *.egg-info/
16
+ .eggs/
17
+
18
+ # Test, coverage, and tooling artifacts
19
+ .pytest_cache/
20
+ .coverage
21
+ .coverage.*
22
+ htmlcov/
23
+ .mypy_cache/
24
+ .ruff_cache/
25
+
26
+ # Editor/OS files
27
+ .vscode/
28
+ .idea/
29
+ .DS_Store
30
+ Thumbs.db
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Tharun K
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,162 @@
1
+ Metadata-Version: 2.4
2
+ Name: torchloop
3
+ Version: 0.1.0
4
+ Summary: Lightweight PyTorch utility library for training, evaluation, and TFLite export — without the framework lock-in.
5
+ Project-URL: Homepage, https://github.com/Tharun007-TK/torchloop
6
+ Project-URL: Repository, https://github.com/Tharun007-TK/torchloop
7
+ Project-URL: Issues, https://github.com/Tharun007-TK/torchloop/issues
8
+ Author-email: Tharun Kumar <tharunkumarvmt@gmail.com>
9
+ License: MIT License
10
+
11
+ Copyright (c) 2025 Tharun K
12
+
13
+ Permission is hereby granted, free of charge, to any person obtaining a copy
14
+ of this software and associated documentation files (the "Software"), to deal
15
+ in the Software without restriction, including without limitation the rights
16
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
17
+ copies of the Software, and to permit persons to whom the Software is
18
+ furnished to do so, subject to the following conditions:
19
+
20
+ The above copyright notice and this permission notice shall be included in all
21
+ copies or substantial portions of the Software.
22
+
23
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
24
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
25
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
26
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
27
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
28
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
29
+ SOFTWARE.
30
+ License-File: LICENSE
31
+ Keywords: deep learning,export,ml utilities,pytorch,tflite,training
32
+ Classifier: Development Status :: 3 - Alpha
33
+ Classifier: Intended Audience :: Developers
34
+ Classifier: Intended Audience :: Science/Research
35
+ Classifier: License :: OSI Approved :: MIT License
36
+ Classifier: Programming Language :: Python :: 3
37
+ Classifier: Programming Language :: Python :: 3.9
38
+ Classifier: Programming Language :: Python :: 3.10
39
+ Classifier: Programming Language :: Python :: 3.11
40
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
41
+ Requires-Python: >=3.9
42
+ Requires-Dist: matplotlib>=3.7.0
43
+ Requires-Dist: numpy>=1.24.0
44
+ Requires-Dist: scikit-learn>=1.3.0
45
+ Requires-Dist: torch>=2.0.0
46
+ Requires-Dist: torchvision>=0.15.0
47
+ Requires-Dist: tqdm>=4.65.0
48
+ Provides-Extra: all
49
+ Requires-Dist: hatch>=1.7.0; extra == 'all'
50
+ Requires-Dist: onnx>=1.14.0; extra == 'all'
51
+ Requires-Dist: onnxruntime>=1.15.0; extra == 'all'
52
+ Requires-Dist: pytest-cov>=4.1.0; extra == 'all'
53
+ Requires-Dist: pytest>=7.4.0; extra == 'all'
54
+ Requires-Dist: ruff>=0.1.0; extra == 'all'
55
+ Requires-Dist: tensorflow>=2.13.0; extra == 'all'
56
+ Provides-Extra: dev
57
+ Requires-Dist: hatch>=1.7.0; extra == 'dev'
58
+ Requires-Dist: pytest-cov>=4.1.0; extra == 'dev'
59
+ Requires-Dist: pytest>=7.4.0; extra == 'dev'
60
+ Requires-Dist: ruff>=0.1.0; extra == 'dev'
61
+ Provides-Extra: export
62
+ Requires-Dist: onnx>=1.14.0; extra == 'export'
63
+ Requires-Dist: onnxruntime>=1.15.0; extra == 'export'
64
+ Requires-Dist: tensorflow>=2.13.0; extra == 'export'
65
+ Description-Content-Type: text/markdown
66
+
67
+ # torchloop
68
+
69
+ > Lightweight PyTorch utility library for training, evaluation, and TFLite export — without the framework lock-in.
70
+
71
+ [![CI](https://github.com/Tharun007-TK/torchloop/actions/workflows/ci.yml/badge.svg)](https://github.com/Tharun007-TK/torchloop/actions)
72
+ [![PyPI](https://img.shields.io/pypi/v/torchloop)](https://pypi.org/project/torchloop/)
73
+ [![Python](https://img.shields.io/pypi/pyversions/torchloop)](https://pypi.org/project/torchloop/)
74
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
75
+
76
+ ---
77
+
78
+ ## The Problem
79
+
80
+ You write the same PyTorch training loop in every project. Same checkpoint logic. Same metric assembly. Same TFLite export steps. It's tedious and inconsistent.
81
+
82
+ `torchloop` abstracts exactly that — nothing more.
83
+
84
+ ---
85
+
86
+ ## Install
87
+
88
+ ```bash
89
+ pip install torchloop
90
+
91
+ # With TFLite export support
92
+ pip install torchloop[export]
93
+ ```
94
+
95
+ ---
96
+
97
+ ## Usage
98
+
99
+ ### Training
100
+
101
+ ```python
102
+ from torchloop import Trainer
103
+
104
+ trainer = Trainer(
105
+ model,
106
+ optimizer=torch.optim.Adam(model.parameters()),
107
+ criterion=torch.nn.CrossEntropyLoss(),
108
+ device="cuda",
109
+ patience=5, # early stopping
110
+ )
111
+
112
+ history = trainer.fit(train_loader, val_loader, epochs=30)
113
+ trainer.save("best.pt")
114
+ ```
115
+
116
+ ### Evaluation
117
+
118
+ ```python
119
+ from torchloop import Evaluator
120
+
121
+ ev = Evaluator(model, device="cuda")
122
+ results = ev.report(val_loader, class_names=["No Damage", "Minor", "Major", "Destroyed"])
123
+ # prints sklearn classification report
124
+
125
+ fig = ev.confusion_matrix(val_loader)
126
+ fig.savefig("cm.png")
127
+
128
+ per_class = ev.f1_per_class(val_loader)
129
+ # {'No Damage': 0.91, 'Minor': 0.78, ...}
130
+ ```
131
+
132
+ ### Export
133
+
134
+ ```python
135
+ from torchloop.exporter import Exporter
136
+
137
+ exp = Exporter(model, input_shape=(1, 3, 224, 224))
138
+ exp.to_onnx("model.onnx")
139
+ exp.to_tflite("model.tflite", quantize=True)
140
+ ```
141
+
142
+ ---
143
+
144
+ ## Design Principles
145
+
146
+ - **No lock-in**: Works with any nn.Module. No subclassing required.
147
+ - **Minimal surface area**: Three modules. That's it.
148
+ - **You own the model**: torchloop wraps your loop, doesn't replace your architecture.
149
+
150
+ ---
151
+
152
+ ## Roadmap
153
+
154
+ - [ ] `v0.1.0` — Trainer, Evaluator, Exporter
155
+ - [ ] `v0.2.0` — LR scheduler support, mixed precision (AMP)
156
+ - [ ] `v0.3.0` — W&B / MLflow logging hooks
157
+
158
+ ---
159
+
160
+ ## License
161
+
162
+ MIT
@@ -0,0 +1,96 @@
1
+ # torchloop
2
+
3
+ > Lightweight PyTorch utility library for training, evaluation, and TFLite export — without the framework lock-in.
4
+
5
+ [![CI](https://github.com/Tharun007-TK/torchloop/actions/workflows/ci.yml/badge.svg)](https://github.com/Tharun007-TK/torchloop/actions)
6
+ [![PyPI](https://img.shields.io/pypi/v/torchloop)](https://pypi.org/project/torchloop/)
7
+ [![Python](https://img.shields.io/pypi/pyversions/torchloop)](https://pypi.org/project/torchloop/)
8
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
9
+
10
+ ---
11
+
12
+ ## The Problem
13
+
14
+ You write the same PyTorch training loop in every project. Same checkpoint logic. Same metric assembly. Same TFLite export steps. It's tedious and inconsistent.
15
+
16
+ `torchloop` abstracts exactly that — nothing more.
17
+
18
+ ---
19
+
20
+ ## Install
21
+
22
+ ```bash
23
+ pip install torchloop
24
+
25
+ # With ONNX/TFLite export support (TFLite conversion also needs onnx2tf)
25
+ pip install "torchloop[export]" onnx2tf
27
+ ```
28
+
29
+ ---
30
+
31
+ ## Usage
32
+
33
+ ### Training
34
+
35
+ ```python
36
+ from torchloop import Trainer
37
+
38
+ trainer = Trainer(
39
+ model,
40
+ optimizer=torch.optim.Adam(model.parameters()),
41
+ criterion=torch.nn.CrossEntropyLoss(),
42
+ device="cuda",
43
+ patience=5, # early stopping
44
+ )
45
+
46
+ history = trainer.fit(train_loader, val_loader, epochs=30)
47
+ trainer.save("best.pt")
48
+ ```
49
+
50
+ ### Evaluation
51
+
52
+ ```python
53
+ from torchloop import Evaluator
54
+
55
+ ev = Evaluator(model, device="cuda")
56
+ results = ev.report(val_loader, class_names=["No Damage", "Minor", "Major", "Destroyed"])
57
+ # prints sklearn classification report
58
+
59
+ fig = ev.confusion_matrix(val_loader)
60
+ fig.savefig("cm.png")
61
+
62
+ per_class = ev.f1_per_class(val_loader)
63
+ # {'No Damage': 0.91, 'Minor': 0.78, ...}
64
+ ```
65
+
66
+ ### Export
67
+
68
+ ```python
69
+ from torchloop.exporter import Exporter
70
+
71
+ exp = Exporter(model, input_shape=(1, 3, 224, 224))
72
+ exp.to_onnx("model.onnx")
73
+ exp.to_tflite("model.tflite", quantize=True)
74
+ ```
75
+
76
+ ---
77
+
78
+ ## Design Principles
79
+
80
+ - **No lock-in**: Works with any nn.Module. No subclassing required.
81
+ - **Minimal surface area**: Three modules. That's it.
82
+ - **You own the model**: torchloop wraps your loop, doesn't replace your architecture.
83
+
84
+ ---
85
+
86
+ ## Roadmap
87
+
88
+ - [x] `v0.1.0` — Trainer, Evaluator, Exporter
89
+ - [ ] `v0.2.0` — LR scheduler support, mixed precision (AMP)
90
+ - [ ] `v0.3.0` — W&B / MLflow logging hooks
91
+
92
+ ---
93
+
94
+ ## License
95
+
96
+ MIT
@@ -0,0 +1,68 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "torchloop"
7
+ version = "0.1.0"
8
+ description = "Lightweight PyTorch utility library for training, evaluation, and TFLite export — without the framework lock-in."
9
+ readme = "README.md"
10
+ license = { file = "LICENSE" }
11
+ requires-python = ">=3.9"
12
+ authors = [
13
+ { name = "Tharun Kumar", email = "tharunkumarvmt@gmail.com" }
14
+ ]
15
+ keywords = ["pytorch", "deep learning", "training", "tflite", "export", "ml utilities"]
16
+ classifiers = [
17
+ "Development Status :: 3 - Alpha",
18
+ "Intended Audience :: Developers",
19
+ "Intended Audience :: Science/Research",
20
+ "License :: OSI Approved :: MIT License",
21
+ "Programming Language :: Python :: 3",
22
+ "Programming Language :: Python :: 3.9",
23
+ "Programming Language :: Python :: 3.10",
24
+ "Programming Language :: Python :: 3.11",
25
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
26
+ ]
27
+
28
+ dependencies = [
29
+ "torch>=2.0.0",
30
+ "torchvision>=0.15.0",
31
+ "scikit-learn>=1.3.0",
32
+ "numpy>=1.24.0",
33
+ "matplotlib>=3.7.0",
34
+ "tqdm>=4.65.0",
35
+ ]
36
+
37
+ [project.optional-dependencies]
38
+ export = [
39
+ "onnx>=1.14.0",
40
+ "onnxruntime>=1.15.0",
41
+ "tensorflow>=2.13.0", # needed for TFLite conversion
42
+ ]
43
+ dev = [
44
+ "pytest>=7.4.0",
45
+ "pytest-cov>=4.1.0",
46
+ "ruff>=0.1.0", # linter + formatter
47
+ "hatch>=1.7.0",
48
+ ]
49
+ all = ["torchloop[export,dev]"]
50
+
51
+ [project.urls]
52
+ Homepage = "https://github.com/Tharun007-TK/torchloop"
53
+ Repository = "https://github.com/Tharun007-TK/torchloop"
54
+ Issues = "https://github.com/Tharun007-TK/torchloop/issues"
55
+
56
+ [tool.hatch.build.targets.wheel]
57
+ packages = ["src/torchloop"]
58
+
59
+ [tool.ruff]
60
+ line-length = 88
61
+ target-version = "py39"
62
+
63
+ [tool.ruff.lint]
64
+ select = ["E", "F", "I"] # pycodestyle + pyflakes + isort
65
+
66
+ [tool.pytest.ini_options]
67
+ testpaths = ["tests"]
68
+ addopts = "--cov=src/torchloop --cov-report=term-missing"
@@ -0,0 +1,16 @@
1
+ """
2
+ torchloop — Lightweight PyTorch utility library.
3
+
4
+ Modules:
5
+ trainer : Training loop, metric logging, checkpoint management
6
+ evaluator : Classification report, confusion matrix, per-class F1
7
+ exporter : PyTorch → ONNX → TFLite with optional quantization
8
+ """
9
+
10
from torchloop.evaluator import Evaluator
from torchloop.trainer import Trainer

# Package metadata.
__version__ = "0.1.0"
__author__ = "Tharun Kumar"

# Public names exposed by ``from torchloop import *``. Exporter is not
# re-exported here; import it from ``torchloop.exporter`` (as the README does).
__all__ = ["Trainer", "Evaluator", "__version__"]
@@ -0,0 +1,128 @@
1
+ """
2
+ torchloop.evaluator
3
+ -------------------
4
+ One-call classification diagnostics. No more assembling sklearn +
5
+ matplotlib calls manually across every project.
6
+
7
+ Usage:
8
+ from torchloop import Evaluator
9
+
10
+ ev = Evaluator(model, device="cuda")
11
+ ev.report(val_loader, class_names=["cat", "dog"])
12
+ ev.confusion_matrix(val_loader)
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ from typing import Optional
18
+
19
+ import matplotlib.pyplot as plt
20
+ import numpy as np
21
+ import torch
22
+ import torch.nn as nn
23
+ from sklearn.metrics import (
24
+ ConfusionMatrixDisplay,
25
+ classification_report,
26
+ confusion_matrix,
27
+ f1_score,
28
+ )
29
+ from torch.utils.data import DataLoader
30
+
31
+
32
+ class Evaluator:
33
+ """
34
+ Classification model evaluator.
35
+
36
+ Args:
37
+ model : Trained nn.Module.
38
+ device : 'cuda', 'cpu', or 'mps'. Auto-detects if None.
39
+ """
40
+
41
+ def __init__(self, model: nn.Module, device: Optional[str] = None):
42
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
43
+ self.model = model.to(self.device)
44
+
45
+ # ------------------------------------------------------------------
46
+ # Public API
47
+ # ------------------------------------------------------------------
48
+
49
+ def report(
50
+ self,
51
+ loader: DataLoader,
52
+ class_names: Optional[list[str]] = None,
53
+ ) -> dict:
54
+ """
55
+ Print full sklearn classification report.
56
+
57
+ Returns:
58
+ dict with keys: accuracy, macro_f1, weighted_f1, per_class_f1
59
+ """
60
+ preds, targets = self._infer(loader)
61
+ report = classification_report(
62
+ targets, preds, target_names=class_names, zero_division=0
63
+ )
64
+ print(report)
65
+ per_class = f1_score(targets, preds, average=None, zero_division=0).tolist()
66
+ return {
67
+ "accuracy": float((np.array(preds) == np.array(targets)).mean()),
68
+ "macro_f1": float(
69
+ f1_score(targets, preds, average="macro", zero_division=0)
70
+ ),
71
+ "weighted_f1": float(
72
+ f1_score(targets, preds, average="weighted", zero_division=0)
73
+ ),
74
+ "per_class_f1": {
75
+ (class_names[i] if class_names else str(i)): round(v, 4)
76
+ for i, v in enumerate(per_class)
77
+ },
78
+ }
79
+
80
+ def confusion_matrix(
81
+ self,
82
+ loader: DataLoader,
83
+ class_names: Optional[list[str]] = None,
84
+ normalize: Optional[str] = "true", # 'true' | 'pred' | 'all' | None
85
+ figsize: tuple = (8, 6),
86
+ ) -> plt.Figure:
87
+ """
88
+ Plot and return confusion matrix figure.
89
+ """
90
+ preds, targets = self._infer(loader)
91
+ cm = confusion_matrix(targets, preds, normalize=normalize)
92
+ fig, ax = plt.subplots(figsize=figsize)
93
+ disp = ConfusionMatrixDisplay(cm, display_labels=class_names)
94
+ disp.plot(ax=ax, colorbar=True, cmap="Blues")
95
+ ax.set_title("Confusion Matrix")
96
+ plt.tight_layout()
97
+ return fig
98
+
99
+ def f1_per_class(
100
+ self,
101
+ loader: DataLoader,
102
+ class_names: Optional[list[str]] = None,
103
+ ) -> dict[str, float]:
104
+ """
105
+ Returns per-class F1 as a dict. Clean for logging to W&B or MLflow.
106
+ """
107
+ preds, targets = self._infer(loader)
108
+ scores = f1_score(targets, preds, average=None, zero_division=0)
109
+ return {
110
+ (class_names[i] if class_names else str(i)): round(float(s), 4)
111
+ for i, s in enumerate(scores)
112
+ }
113
+
114
+ # ------------------------------------------------------------------
115
+ # Internal
116
+ # ------------------------------------------------------------------
117
+
118
+ def _infer(self, loader: DataLoader) -> tuple[list, list]:
119
+ self.model.eval()
120
+ all_preds, all_targets = [], []
121
+ with torch.no_grad():
122
+ for inputs, targets in loader:
123
+ inputs = inputs.to(self.device)
124
+ outputs = self.model(inputs)
125
+ preds = outputs.argmax(dim=1).cpu().tolist()
126
+ all_preds.extend(preds)
127
+ all_targets.extend(targets.tolist())
128
+ return all_preds, all_targets
@@ -0,0 +1,137 @@
1
+ """
2
+ torchloop.exporter
3
+ ------------------
4
+ PyTorch → ONNX → TFLite in one place.
5
+ Requires: pip install torchloop[export]
6
+
7
+ Usage:
8
+ from torchloop.exporter import Exporter
9
+
10
+ exp = Exporter(model, input_shape=(1, 3, 224, 224))
11
+ exp.to_onnx("model.onnx")
12
+ exp.to_tflite("model.tflite", quantize=True)
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ from pathlib import Path
18
+ from typing import Optional
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+
23
+
24
class Exporter:
    """
    Handles model export from PyTorch to ONNX and TFLite.

    Args:
        model       : Trained nn.Module (moved to ``device`` and set to eval mode).
        input_shape : Shape of the dummy input used for export, e.g.
                      (1, 3, 224, 224). Dim 0 is exported as a dynamic
                      batch axis in ONNX.
        device      : Device for the dummy forward pass. If None, uses 'cuda'
                      when available, otherwise 'cpu'.
    """

    def __init__(
        self,
        model: nn.Module,
        input_shape: tuple,
        device: Optional[str] = None,
    ):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device).eval()
        self.input_shape = input_shape
        # Random values; this tensor is only fed through the model during export.
        self._dummy = torch.randn(*input_shape).to(self.device)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def to_onnx(self, path: str | Path, opset: int = 17) -> Path:
        """
        Export the model to ONNX and validate the file with ``onnx.checker``.

        Args:
            path  : Output .onnx file path. Missing parent directories are created.
            opset : ONNX opset version passed to ``torch.onnx.export`` (default 17).

        Returns:
            ``Path(path)`` of the exported file.

        Raises:
            ImportError: if the ``onnx`` package is not installed.
        """
        try:
            import onnx
        except ImportError as err:
            raise ImportError(
                "onnx is not installed. Run: pip install torchloop[export]"
            ) from err

        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        torch.onnx.export(
            self.model,
            self._dummy,
            str(path),
            opset_version=opset,
            input_names=["input"],
            output_names=["output"],
            # Only the batch axis is dynamic; other dims stay fixed to input_shape.
            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
        )
        model_onnx = onnx.load(str(path))
        onnx.checker.check_model(model_onnx)
        print(f" ONNX export verified → {path}")
        return path

    def to_tflite(
        self,
        path: str | Path,
        quantize: bool = False,
        onnx_path: Optional[str | Path] = None,
    ) -> Path:
        """
        Export the model to TFLite via an ONNX → TF SavedModel → TFLite pipeline.

        Args:
            path      : Output .tflite file path. Missing parent directories
                        are created.
            quantize  : If True, sets ``tf.lite.Optimize.DEFAULT`` on the
                        converter (dynamic range quantization).
            onnx_path : Where to write the intermediate .onnx file. Defaults to
                        ``path`` with a ``.onnx`` suffix. The file is kept.

        Returns:
            ``Path(path)`` of the exported .tflite file.

        Raises:
            ImportError: if tensorflow or onnx2tf (or onnx) is not installed.

        Note:
            Requires tensorflow and onnx2tf installed.
            pip install torchloop[export] onnx2tf
        """
        try:
            import onnx2tf
            import tensorflow as tf
        except ImportError as err:
            raise ImportError(
                "tensorflow or onnx2tf not installed.\n"
                "Run: pip install torchloop[export] onnx2tf"
            ) from err

        import tempfile  # only needed for the intermediate SavedModel directory

        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        # Step 1: Export to ONNX first
        _onnx_path = Path(onnx_path) if onnx_path else path.with_suffix(".onnx")
        self.to_onnx(_onnx_path)

        # Steps 2-3 use a fresh temporary SavedModel directory that is removed
        # afterwards. Files left by an earlier run therefore cannot be picked up,
        # and nothing is left behind next to the output file.
        with tempfile.TemporaryDirectory(
            prefix="_tflite_savedmodel_"
        ) as saved_model_dir:
            # Step 2: ONNX → SavedModel via onnx2tf
            onnx2tf.convert(
                input_onnx_file_path=str(_onnx_path),
                output_folder_path=saved_model_dir,
                not_use_onnxsim=False,
                verbosity="error",
            )

            # Step 3: SavedModel → TFLite (must run before the dir is deleted)
            converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
            if quantize:
                converter.optimizations = [tf.lite.Optimize.DEFAULT]
                print(" Quantization: dynamic range enabled.")
            tflite_model = converter.convert()

        path.write_bytes(tflite_model)
        size_kb = path.stat().st_size / 1024
        print(f" TFLite export → {path} ({size_kb:.1f} KB)")
        return path
@@ -0,0 +1,176 @@
1
+ """
2
+ torchloop.trainer
3
+ -----------------
4
+ Wraps the PyTorch training loop so you stop rewriting it.
5
+
6
+ Usage:
7
+ from torchloop import Trainer
8
+
9
+ trainer = Trainer(model, optimizer, criterion, device="cuda")
10
+ trainer.fit(train_loader, val_loader, epochs=20)
11
+ trainer.save("best.pt")
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import time
17
+ from pathlib import Path
18
+ from typing import Callable, Optional
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ from torch.utils.data import DataLoader
23
+ from tqdm import tqdm
24
+
25
+
26
class Trainer:
    """
    Minimal, opinionated PyTorch training loop.

    Args:
        model     : nn.Module to train (moved to ``device``).
        optimizer : Any torch.optim optimizer.
        criterion : Loss function (nn.Module or callable) returning a scalar
                    tensor. Epoch losses weight it by batch size, which
                    presumes a batch-mean reduction (confirm for custom losses).
        device    : 'cuda', 'cpu', or 'mps'. If None, uses 'cuda' when
                    available, otherwise 'cpu'.
        metric_fn : Optional callable(preds, targets) → float for the val
                    metric. ``preds`` are the raw model outputs of the whole
                    validation set, concatenated on CPU.
        patience  : Early stopping patience, counted in epochs without
                    val-loss improvement. None = disabled. Only takes effect
                    when ``fit`` receives a val_loader.
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        criterion: Callable,
        device: Optional[str] = None,
        metric_fn: Optional[Callable] = None,
        patience: Optional[int] = None,
    ):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.optimizer = optimizer
        self.criterion = criterion
        self.metric_fn = metric_fn
        self.patience = patience

        self.history: dict[str, list] = {
            "train_loss": [],
            "val_loss": [],
            "val_metric": [],
        }
        self._best_val_loss = float("inf")
        self._best_state: Optional[dict] = None
        self._no_improve_count = 0

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def fit(
        self,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader] = None,
        epochs: int = 10,
    ) -> dict:
        """
        Train the model for up to ``epochs`` epochs.

        When a ``val_loader`` is given, the weights with the lowest validation
        loss so far are snapshotted after each epoch. When training ends
        (normally or by early stopping), those best weights are loaded back
        into the model.

        Returns:
            ``self.history``: per-epoch lists for train_loss, val_loss and
            val_metric. The lists keep accumulating across repeated ``fit``
            calls.
        """
        for epoch in range(1, epochs + 1):
            t0 = time.time()
            train_loss = self._train_epoch(train_loader)
            self.history["train_loss"].append(train_loss)

            val_loss, val_metric = None, None
            if val_loader is not None:
                val_loss, val_metric = self._val_epoch(val_loader)
                self.history["val_loss"].append(val_loss)
                self.history["val_metric"].append(val_metric)
                self._checkpoint(val_loss)

            self._log(epoch, epochs, train_loss, val_loss, val_metric, time.time() - t0)

            if self._should_stop():
                print(f" Early stopping triggered at epoch {epoch}.")
                break

        if self._best_state is not None:
            self.model.load_state_dict(self._best_state)
            print(" Restored best model weights.")

        return self.history

    def save(self, path: str | Path) -> None:
        """Save the model's state dict to ``path`` with ``torch.save``."""
        torch.save(self.model.state_dict(), path)
        print(f" Saved → {path}")

    def load(self, path: str | Path) -> None:
        """Load a state dict saved by ``save`` into the model."""
        # weights_only=True restricts unpickling to tensors and plain
        # containers, which is all a state dict needs. It avoids executing
        # arbitrary pickled objects from an untrusted checkpoint.
        state = torch.load(path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(state)
        print(f" Loaded ← {path}")

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    def _train_epoch(self, loader: DataLoader) -> float:
        """Run one optimisation pass over ``loader``; return mean per-sample loss."""
        self.model.train()
        total_loss, n_samples = 0.0, 0
        for inputs, targets in tqdm(loader, desc=" train", leave=False):
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)
            loss.backward()
            self.optimizer.step()
            batch_size = inputs.size(0)
            total_loss += loss.item() * batch_size
            n_samples += batch_size
        # Divide by the samples actually seen, not len(loader.dataset). The two
        # differ with drop_last=True or a subset sampler, and an
        # IterableDataset has no len().
        return self._mean_loss(total_loss, n_samples)

    def _val_epoch(self, loader: DataLoader) -> tuple[float, Optional[float]]:
        """
        Evaluate on ``loader`` without gradients.

        Returns:
            (mean per-sample loss, metric_fn result or None).
        """
        self.model.eval()
        total_loss, n_samples = 0.0, 0
        all_preds, all_targets = [], []
        with torch.no_grad():
            for inputs, targets in tqdm(loader, desc=" val ", leave=False):
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)
                batch_size = inputs.size(0)
                total_loss += loss.item() * batch_size
                n_samples += batch_size
                if self.metric_fn is not None:
                    all_preds.append(outputs.cpu())
                    all_targets.append(targets.cpu())
        avg_loss = self._mean_loss(total_loss, n_samples)
        metric = None
        if self.metric_fn is not None and all_preds:
            metric = self.metric_fn(
                torch.cat(all_preds), torch.cat(all_targets)
            )
        return avg_loss, metric

    @staticmethod
    def _mean_loss(total_loss: float, n_samples: int) -> float:
        """Return ``total_loss / n_samples``; raise if the loader yielded nothing."""
        if n_samples == 0:
            raise ValueError("DataLoader yielded no samples; cannot compute loss.")
        return total_loss / n_samples

    def _checkpoint(self, val_loss: float) -> None:
        """Snapshot weights on strict val-loss improvement, else bump the counter."""
        if val_loss < self._best_val_loss:
            self._best_val_loss = val_loss
            # Clone so later optimizer steps don't mutate the snapshot.
            self._best_state = {
                k: v.clone() for k, v in self.model.state_dict().items()
            }
            self._no_improve_count = 0
        else:
            self._no_improve_count += 1

    def _should_stop(self) -> bool:
        """True once ``patience`` consecutive epochs failed to improve val loss."""
        return (
            self.patience is not None
            and self._no_improve_count >= self.patience
        )

    @staticmethod
    def _log(epoch, epochs, train_loss, val_loss, val_metric, elapsed) -> None:
        """Print a one-line epoch summary; val fields are omitted when None."""
        parts = [f"Epoch [{epoch:>3}/{epochs}]", f"train_loss={train_loss:.4f}"]
        if val_loss is not None:
            parts.append(f"val_loss={val_loss:.4f}")
        if val_metric is not None:
            parts.append(f"val_metric={val_metric:.4f}")
        parts.append(f"({elapsed:.1f}s)")
        print(" " + " ".join(parts))
File without changes
@@ -0,0 +1,40 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.data import DataLoader, TensorDataset
4
+
5
+ from torchloop import Evaluator
6
+
7
+
8
+ def _make_loader():
9
+ X = torch.randn(64, 16)
10
+ y = torch.randint(0, 3, (64,))
11
+ return DataLoader(TensorDataset(X, y), batch_size=16)
12
+
13
+
14
+ def _make_model():
15
+ return nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 3))
16
+
17
+
18
def test_report_returns_keys():
    """report() returns every documented summary key, keyed by class name."""
    model = _make_model()
    ev = Evaluator(model, device="cpu")
    result = ev.report(_make_loader(), class_names=["a", "b", "c"])
    assert set(result) == {"accuracy", "macro_f1", "weighted_f1", "per_class_f1"}
    assert set(result["per_class_f1"].keys()) == {"a", "b", "c"}
    assert 0.0 <= result["accuracy"] <= 1.0
25
+
26
+
27
def test_f1_per_class_length():
    """Without class names, f1_per_class yields one score per class."""
    scores = Evaluator(_make_model(), device="cpu").f1_per_class(_make_loader())
    assert len(scores) == 3
32
+
33
+
34
def test_confusion_matrix_returns_figure():
    """confusion_matrix should hand back a matplotlib Figure."""
    import matplotlib.pyplot as plt

    evaluator = Evaluator(_make_model(), device="cpu")
    figure = evaluator.confusion_matrix(_make_loader())
    assert isinstance(figure, plt.Figure)
    plt.close(figure)
@@ -0,0 +1,62 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.data import DataLoader, TensorDataset
4
+
5
+ from torchloop import Trainer
6
+
7
+
8
+ def _make_loader(n=64, features=16, classes=3, batch=16):
9
+ X = torch.randn(n, features)
10
+ y = torch.randint(0, classes, (n,))
11
+ return DataLoader(TensorDataset(X, y), batch_size=batch)
12
+
13
+
14
+ def _make_model(features=16, classes=3):
15
+ return nn.Sequential(nn.Linear(features, 32), nn.ReLU(), nn.Linear(32, classes))
16
+
17
+
18
def test_trainer_fit_returns_history():
    """fit() records one entry per epoch in every history series."""
    model = _make_model()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    trainer = Trainer(model, optimizer, criterion, device="cpu")
    history = trainer.fit(_make_loader(), _make_loader(), epochs=3)
    for key in ("train_loss", "val_loss", "val_metric"):
        assert len(history[key]) == 3
    # No metric_fn was supplied, so each val_metric entry is None.
    assert history["val_metric"] == [None, None, None]
26
+
27
+
28
def test_trainer_early_stopping():
    """With frozen weights, val loss never improves after epoch 1, so
    patience=2 must stop training after exactly 3 epochs."""
    model = _make_model()
    # lr=0 leaves the weights unchanged, so every epoch's val loss is identical
    # and the strict "<" improvement check fails from epoch 2 onwards.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0)
    criterion = nn.CrossEntropyLoss()
    trainer = Trainer(model, optimizer, criterion, device="cpu", patience=2)
    history = trainer.fit(_make_loader(), _make_loader(), epochs=20)
    assert len(history["train_loss"]) == 3
35
+
36
+
37
def test_trainer_save_load(tmp_path):
    """Weights saved by one Trainer load back bit-for-bit into another model."""
    model = _make_model()
    optimizer = torch.optim.Adam(model.parameters())
    criterion = nn.CrossEntropyLoss()
    trainer = Trainer(model, optimizer, criterion, device="cpu")
    trainer.fit(_make_loader(), epochs=1)
    save_path = tmp_path / "model.pt"
    trainer.save(save_path)
    assert save_path.exists()

    fresh = _make_model()
    other = Trainer(
        fresh, torch.optim.Adam(fresh.parameters()), criterion, device="cpu"
    )
    other.load(save_path)
    loaded = fresh.state_dict()
    for key, value in model.state_dict().items():
        assert torch.equal(value, loaded[key])
47
+
48
def test_trainer_with_metric_fn():
    """A user metric_fn is evaluated once per epoch and recorded in history."""
    from sklearn.metrics import f1_score as skf1

    def metric_fn(preds, targets):
        # preds are raw model outputs; reduce to class indices before scoring.
        p = preds.argmax(dim=1).numpy()
        t = targets.numpy()
        return skf1(t, p, average="macro", zero_division=0)

    model = _make_model()
    optimizer = torch.optim.Adam(model.parameters())
    criterion = nn.CrossEntropyLoss()
    trainer = Trainer(model, optimizer, criterion, device="cpu", metric_fn=metric_fn)
    history = trainer.fit(_make_loader(), _make_loader(), epochs=2)
    assert "val_metric" in history
    assert len(history["val_metric"]) == 2
    # None would mean metric_fn was skipped; macro F1 always lies in [0, 1].
    assert all(m is not None and 0.0 <= m <= 1.0 for m in history["val_metric"])