torchloop 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchloop-0.1.0/.github/workflows/ci.yml +32 -0
- torchloop-0.1.0/.github/workflows/publish.yml +30 -0
- torchloop-0.1.0/.gitignore +30 -0
- torchloop-0.1.0/LICENSE +21 -0
- torchloop-0.1.0/PKG-INFO +162 -0
- torchloop-0.1.0/README.md +96 -0
- torchloop-0.1.0/pyproject.toml +68 -0
- torchloop-0.1.0/src/torchloop/__init__.py +16 -0
- torchloop-0.1.0/src/torchloop/evaluator.py +128 -0
- torchloop-0.1.0/src/torchloop/exporter.py +137 -0
- torchloop-0.1.0/src/torchloop/trainer.py +176 -0
- torchloop-0.1.0/tests/__init__.py +0 -0
- torchloop-0.1.0/tests/test_evaluator.py +40 -0
- torchloop-0.1.0/tests/test_trainer.py +62 -0
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
name: CI
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
branches: [main]
|
|
6
|
+
pull_request:
|
|
7
|
+
branches: [main]
|
|
8
|
+
|
|
9
|
+
jobs:
|
|
10
|
+
test:
|
|
11
|
+
runs-on: ubuntu-latest
|
|
12
|
+
strategy:
|
|
13
|
+
matrix:
|
|
14
|
+
python-version: ["3.9", "3.10", "3.11"]
|
|
15
|
+
|
|
16
|
+
steps:
|
|
17
|
+
- uses: actions/checkout@v4
|
|
18
|
+
|
|
19
|
+
- name: Set up Python ${{ matrix.python-version }}
|
|
20
|
+
uses: actions/setup-python@v5
|
|
21
|
+
with:
|
|
22
|
+
python-version: ${{ matrix.python-version }}
|
|
23
|
+
|
|
24
|
+
- name: Install dependencies
|
|
25
|
+
run: |
|
|
26
|
+
pip install -e ".[dev]"
|
|
27
|
+
|
|
28
|
+
- name: Lint with ruff
|
|
29
|
+
run: ruff check src/
|
|
30
|
+
|
|
31
|
+
- name: Run tests
|
|
32
|
+
run: pytest
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
name: Publish to PyPI
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
tags:
|
|
6
|
+
- "v*.*.*"
|
|
7
|
+
|
|
8
|
+
jobs:
|
|
9
|
+
publish:
|
|
10
|
+
runs-on: ubuntu-latest
|
|
11
|
+
environment: pypi
|
|
12
|
+
permissions:
|
|
13
|
+
id-token: write
|
|
14
|
+
|
|
15
|
+
steps:
|
|
16
|
+
- uses: actions/checkout@v4
|
|
17
|
+
|
|
18
|
+
- name: Set up Python
|
|
19
|
+
uses: actions/setup-python@v5
|
|
20
|
+
with:
|
|
21
|
+
python-version: "3.11"
|
|
22
|
+
|
|
23
|
+
- name: Install hatch
|
|
24
|
+
run: pip install hatch
|
|
25
|
+
|
|
26
|
+
- name: Build package
|
|
27
|
+
run: hatch build
|
|
28
|
+
|
|
29
|
+
- name: Publish to PyPI
|
|
30
|
+
uses: pypa/gh-action-pypi-publish@release/v1
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# Python bytecode and caches
|
|
2
|
+
__pycache__/
|
|
3
|
+
*.py[cod]
|
|
4
|
+
*$py.class
|
|
5
|
+
|
|
6
|
+
# Virtual environments
|
|
7
|
+
venv/
|
|
8
|
+
.venv/
|
|
9
|
+
env/
|
|
10
|
+
ENV/
|
|
11
|
+
|
|
12
|
+
# Packaging/build artifacts
|
|
13
|
+
build/
|
|
14
|
+
dist/
|
|
15
|
+
*.egg-info/
|
|
16
|
+
.eggs/
|
|
17
|
+
|
|
18
|
+
# Test, coverage, and tooling artifacts
|
|
19
|
+
.pytest_cache/
|
|
20
|
+
.coverage
|
|
21
|
+
.coverage.*
|
|
22
|
+
htmlcov/
|
|
23
|
+
.mypy_cache/
|
|
24
|
+
.ruff_cache/
|
|
25
|
+
|
|
26
|
+
# Editor/OS files
|
|
27
|
+
.vscode/
|
|
28
|
+
.idea/
|
|
29
|
+
.DS_Store
|
|
30
|
+
Thumbs.db
|
torchloop-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Tharun K
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
torchloop-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: torchloop
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Lightweight PyTorch utility library for training, evaluation, and TFLite export — without the framework lock-in.
|
|
5
|
+
Project-URL: Homepage, https://github.com/Tharun007-TK/torchloop
|
|
6
|
+
Project-URL: Repository, https://github.com/Tharun007-TK/torchloop
|
|
7
|
+
Project-URL: Issues, https://github.com/Tharun007-TK/torchloop/issues
|
|
8
|
+
Author-email: Tharun Kumar <tharunkumarvmt@gmail.com>
|
|
9
|
+
License: MIT License
|
|
10
|
+
|
|
11
|
+
Copyright (c) 2025 Tharun K
|
|
12
|
+
|
|
13
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
14
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
15
|
+
in the Software without restriction, including without limitation the rights
|
|
16
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
17
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
18
|
+
furnished to do so, subject to the following conditions:
|
|
19
|
+
|
|
20
|
+
The above copyright notice and this permission notice shall be included in all
|
|
21
|
+
copies or substantial portions of the Software.
|
|
22
|
+
|
|
23
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
24
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
25
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
26
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
27
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
28
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
29
|
+
SOFTWARE.
|
|
30
|
+
License-File: LICENSE
|
|
31
|
+
Keywords: deep learning,export,ml utilities,pytorch,tflite,training
|
|
32
|
+
Classifier: Development Status :: 3 - Alpha
|
|
33
|
+
Classifier: Intended Audience :: Developers
|
|
34
|
+
Classifier: Intended Audience :: Science/Research
|
|
35
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
36
|
+
Classifier: Programming Language :: Python :: 3
|
|
37
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
38
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
39
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
40
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
41
|
+
Requires-Python: >=3.9
|
|
42
|
+
Requires-Dist: matplotlib>=3.7.0
|
|
43
|
+
Requires-Dist: numpy>=1.24.0
|
|
44
|
+
Requires-Dist: scikit-learn>=1.3.0
|
|
45
|
+
Requires-Dist: torch>=2.0.0
|
|
46
|
+
Requires-Dist: torchvision>=0.15.0
|
|
47
|
+
Requires-Dist: tqdm>=4.65.0
|
|
48
|
+
Provides-Extra: all
|
|
49
|
+
Requires-Dist: hatch>=1.7.0; extra == 'all'
|
|
50
|
+
Requires-Dist: onnx>=1.14.0; extra == 'all'
|
|
51
|
+
Requires-Dist: onnxruntime>=1.15.0; extra == 'all'
|
|
52
|
+
Requires-Dist: pytest-cov>=4.1.0; extra == 'all'
|
|
53
|
+
Requires-Dist: pytest>=7.4.0; extra == 'all'
|
|
54
|
+
Requires-Dist: ruff>=0.1.0; extra == 'all'
|
|
55
|
+
Requires-Dist: tensorflow>=2.13.0; extra == 'all'
|
|
56
|
+
Provides-Extra: dev
|
|
57
|
+
Requires-Dist: hatch>=1.7.0; extra == 'dev'
|
|
58
|
+
Requires-Dist: pytest-cov>=4.1.0; extra == 'dev'
|
|
59
|
+
Requires-Dist: pytest>=7.4.0; extra == 'dev'
|
|
60
|
+
Requires-Dist: ruff>=0.1.0; extra == 'dev'
|
|
61
|
+
Provides-Extra: export
|
|
62
|
+
Requires-Dist: onnx>=1.14.0; extra == 'export'
|
|
63
|
+
Requires-Dist: onnxruntime>=1.15.0; extra == 'export'
|
|
64
|
+
Requires-Dist: tensorflow>=2.13.0; extra == 'export'
|
|
65
|
+
Description-Content-Type: text/markdown
|
|
66
|
+
|
|
67
|
+
# torchloop
|
|
68
|
+
|
|
69
|
+
> Lightweight PyTorch utility library for training, evaluation, and TFLite export — without the framework lock-in.
|
|
70
|
+
|
|
71
|
+
[](https://github.com/Tharun007-TK/torchloop/actions)
|
|
72
|
+
[](https://pypi.org/project/torchloop/)
|
|
73
|
+
[](https://pypi.org/project/torchloop/)
|
|
74
|
+
[](LICENSE)
|
|
75
|
+
|
|
76
|
+
---
|
|
77
|
+
|
|
78
|
+
## The Problem
|
|
79
|
+
|
|
80
|
+
You write the same PyTorch training loop in every project. Same checkpoint logic. Same metric assembly. Same TFLite export steps. It's tedious and inconsistent.
|
|
81
|
+
|
|
82
|
+
`torchloop` abstracts exactly that — nothing more.
|
|
83
|
+
|
|
84
|
+
---
|
|
85
|
+
|
|
86
|
+
## Install
|
|
87
|
+
|
|
88
|
+
```bash
|
|
89
|
+
pip install torchloop
|
|
90
|
+
|
|
91
|
+
# With TFLite export support
|
|
92
|
+
pip install torchloop[export]
|
|
93
|
+
```
|
|
94
|
+
|
|
95
|
+
---
|
|
96
|
+
|
|
97
|
+
## Usage
|
|
98
|
+
|
|
99
|
+
### Training
|
|
100
|
+
|
|
101
|
+
```python
|
|
102
|
+
from torchloop import Trainer
|
|
103
|
+
|
|
104
|
+
trainer = Trainer(
|
|
105
|
+
model,
|
|
106
|
+
optimizer=torch.optim.Adam(model.parameters()),
|
|
107
|
+
criterion=torch.nn.CrossEntropyLoss(),
|
|
108
|
+
device="cuda",
|
|
109
|
+
patience=5, # early stopping
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
history = trainer.fit(train_loader, val_loader, epochs=30)
|
|
113
|
+
trainer.save("best.pt")
|
|
114
|
+
```
|
|
115
|
+
|
|
116
|
+
### Evaluation
|
|
117
|
+
|
|
118
|
+
```python
|
|
119
|
+
from torchloop import Evaluator
|
|
120
|
+
|
|
121
|
+
ev = Evaluator(model, device="cuda")
|
|
122
|
+
results = ev.report(val_loader, class_names=["No Damage", "Minor", "Major", "Destroyed"])
|
|
123
|
+
# prints sklearn classification report
|
|
124
|
+
|
|
125
|
+
fig = ev.confusion_matrix(val_loader)
|
|
126
|
+
fig.savefig("cm.png")
|
|
127
|
+
|
|
128
|
+
per_class = ev.f1_per_class(val_loader)
|
|
129
|
+
# {'No Damage': 0.91, 'Minor': 0.78, ...}
|
|
130
|
+
```
|
|
131
|
+
|
|
132
|
+
### Export
|
|
133
|
+
|
|
134
|
+
```python
|
|
135
|
+
from torchloop.exporter import Exporter
|
|
136
|
+
|
|
137
|
+
exp = Exporter(model, input_shape=(1, 3, 224, 224))
|
|
138
|
+
exp.to_onnx("model.onnx")
|
|
139
|
+
exp.to_tflite("model.tflite", quantize=True)
|
|
140
|
+
```
|
|
141
|
+
|
|
142
|
+
---
|
|
143
|
+
|
|
144
|
+
## Design Principles
|
|
145
|
+
|
|
146
|
+
- **No lock-in**: Works with any nn.Module. No subclassing required.
|
|
147
|
+
- **Minimal surface area**: Three modules. That's it.
|
|
148
|
+
- **You own the model**: torchloop wraps your loop, doesn't replace your architecture.
|
|
149
|
+
|
|
150
|
+
---
|
|
151
|
+
|
|
152
|
+
## Roadmap
|
|
153
|
+
|
|
154
|
+
- [ ] `v0.1.0` — Trainer, Evaluator, Exporter
|
|
155
|
+
- [ ] `v0.2.0` — LR scheduler support, mixed precision (AMP)
|
|
156
|
+
- [ ] `v0.3.0` — W&B / MLflow logging hooks
|
|
157
|
+
|
|
158
|
+
---
|
|
159
|
+
|
|
160
|
+
## License
|
|
161
|
+
|
|
162
|
+
MIT
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
# torchloop
|
|
2
|
+
|
|
3
|
+
> Lightweight PyTorch utility library for training, evaluation, and TFLite export — without the framework lock-in.
|
|
4
|
+
|
|
5
|
+
[](https://github.com/Tharun007-TK/torchloop/actions)
|
|
6
|
+
[](https://pypi.org/project/torchloop/)
|
|
7
|
+
[](https://pypi.org/project/torchloop/)
|
|
8
|
+
[](LICENSE)
|
|
9
|
+
|
|
10
|
+
---
|
|
11
|
+
|
|
12
|
+
## The Problem
|
|
13
|
+
|
|
14
|
+
You write the same PyTorch training loop in every project. Same checkpoint logic. Same metric assembly. Same TFLite export steps. It's tedious and inconsistent.
|
|
15
|
+
|
|
16
|
+
`torchloop` abstracts exactly that — nothing more.
|
|
17
|
+
|
|
18
|
+
---
|
|
19
|
+
|
|
20
|
+
## Install
|
|
21
|
+
|
|
22
|
+
```bash
|
|
23
|
+
pip install torchloop
|
|
24
|
+
|
|
25
|
+
# With TFLite export support
|
|
26
|
+
pip install torchloop[export]
|
|
27
|
+
```
|
|
28
|
+
|
|
29
|
+
---
|
|
30
|
+
|
|
31
|
+
## Usage
|
|
32
|
+
|
|
33
|
+
### Training
|
|
34
|
+
|
|
35
|
+
```python
|
|
36
|
+
from torchloop import Trainer
|
|
37
|
+
|
|
38
|
+
trainer = Trainer(
|
|
39
|
+
model,
|
|
40
|
+
optimizer=torch.optim.Adam(model.parameters()),
|
|
41
|
+
criterion=torch.nn.CrossEntropyLoss(),
|
|
42
|
+
device="cuda",
|
|
43
|
+
patience=5, # early stopping
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
history = trainer.fit(train_loader, val_loader, epochs=30)
|
|
47
|
+
trainer.save("best.pt")
|
|
48
|
+
```
|
|
49
|
+
|
|
50
|
+
### Evaluation
|
|
51
|
+
|
|
52
|
+
```python
|
|
53
|
+
from torchloop import Evaluator
|
|
54
|
+
|
|
55
|
+
ev = Evaluator(model, device="cuda")
|
|
56
|
+
results = ev.report(val_loader, class_names=["No Damage", "Minor", "Major", "Destroyed"])
|
|
57
|
+
# prints sklearn classification report
|
|
58
|
+
|
|
59
|
+
fig = ev.confusion_matrix(val_loader)
|
|
60
|
+
fig.savefig("cm.png")
|
|
61
|
+
|
|
62
|
+
per_class = ev.f1_per_class(val_loader)
|
|
63
|
+
# {'No Damage': 0.91, 'Minor': 0.78, ...}
|
|
64
|
+
```
|
|
65
|
+
|
|
66
|
+
### Export
|
|
67
|
+
|
|
68
|
+
```python
|
|
69
|
+
from torchloop.exporter import Exporter
|
|
70
|
+
|
|
71
|
+
exp = Exporter(model, input_shape=(1, 3, 224, 224))
|
|
72
|
+
exp.to_onnx("model.onnx")
|
|
73
|
+
exp.to_tflite("model.tflite", quantize=True)
|
|
74
|
+
```
|
|
75
|
+
|
|
76
|
+
---
|
|
77
|
+
|
|
78
|
+
## Design Principles
|
|
79
|
+
|
|
80
|
+
- **No lock-in**: Works with any nn.Module. No subclassing required.
|
|
81
|
+
- **Minimal surface area**: Three modules. That's it.
|
|
82
|
+
- **You own the model**: torchloop wraps your loop, doesn't replace your architecture.
|
|
83
|
+
|
|
84
|
+
---
|
|
85
|
+
|
|
86
|
+
## Roadmap
|
|
87
|
+
|
|
88
|
+
- [ ] `v0.1.0` — Trainer, Evaluator, Exporter
|
|
89
|
+
- [ ] `v0.2.0` — LR scheduler support, mixed precision (AMP)
|
|
90
|
+
- [ ] `v0.3.0` — W&B / MLflow logging hooks
|
|
91
|
+
|
|
92
|
+
---
|
|
93
|
+
|
|
94
|
+
## License
|
|
95
|
+
|
|
96
|
+
MIT
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "torchloop"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "Lightweight PyTorch utility library for training, evaluation, and TFLite export — without the framework lock-in."
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
license = { file = "LICENSE" }
|
|
11
|
+
requires-python = ">=3.9"
|
|
12
|
+
authors = [
|
|
13
|
+
{ name = "Tharun Kumar", email = "tharunkumarvmt@gmail.com" }
|
|
14
|
+
]
|
|
15
|
+
keywords = ["pytorch", "deep learning", "training", "tflite", "export", "ml utilities"]
|
|
16
|
+
classifiers = [
|
|
17
|
+
"Development Status :: 3 - Alpha",
|
|
18
|
+
"Intended Audience :: Developers",
|
|
19
|
+
"Intended Audience :: Science/Research",
|
|
20
|
+
"License :: OSI Approved :: MIT License",
|
|
21
|
+
"Programming Language :: Python :: 3",
|
|
22
|
+
"Programming Language :: Python :: 3.9",
|
|
23
|
+
"Programming Language :: Python :: 3.10",
|
|
24
|
+
"Programming Language :: Python :: 3.11",
|
|
25
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
26
|
+
]
|
|
27
|
+
|
|
28
|
+
dependencies = [
|
|
29
|
+
"torch>=2.0.0",
|
|
30
|
+
"torchvision>=0.15.0",
|
|
31
|
+
"scikit-learn>=1.3.0",
|
|
32
|
+
"numpy>=1.24.0",
|
|
33
|
+
"matplotlib>=3.7.0",
|
|
34
|
+
"tqdm>=4.65.0",
|
|
35
|
+
]
|
|
36
|
+
|
|
37
|
+
[project.optional-dependencies]
|
|
38
|
+
export = [
|
|
39
|
+
"onnx>=1.14.0",
|
|
40
|
+
"onnxruntime>=1.15.0",
|
|
41
|
+
"tensorflow>=2.13.0", # needed for TFLite conversion
|
|
42
|
+
]
|
|
43
|
+
dev = [
|
|
44
|
+
"pytest>=7.4.0",
|
|
45
|
+
"pytest-cov>=4.1.0",
|
|
46
|
+
"ruff>=0.1.0", # linter + formatter
|
|
47
|
+
"hatch>=1.7.0",
|
|
48
|
+
]
|
|
49
|
+
all = ["torchloop[export,dev]"]
|
|
50
|
+
|
|
51
|
+
[project.urls]
|
|
52
|
+
Homepage = "https://github.com/Tharun007-TK/torchloop"
|
|
53
|
+
Repository = "https://github.com/Tharun007-TK/torchloop"
|
|
54
|
+
Issues = "https://github.com/Tharun007-TK/torchloop/issues"
|
|
55
|
+
|
|
56
|
+
[tool.hatch.build.targets.wheel]
|
|
57
|
+
packages = ["src/torchloop"]
|
|
58
|
+
|
|
59
|
+
[tool.ruff]
|
|
60
|
+
line-length = 88
|
|
61
|
+
target-version = "py39"
|
|
62
|
+
|
|
63
|
+
[tool.ruff.lint]
|
|
64
|
+
select = ["E", "F", "I"] # pycodestyle + pyflakes + isort
|
|
65
|
+
|
|
66
|
+
[tool.pytest.ini_options]
|
|
67
|
+
testpaths = ["tests"]
|
|
68
|
+
addopts = "--cov=src/torchloop --cov-report=term-missing"
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""
|
|
2
|
+
torchloop — Lightweight PyTorch utility library.
|
|
3
|
+
|
|
4
|
+
Modules:
|
|
5
|
+
trainer : Training loop, metric logging, checkpoint management
|
|
6
|
+
evaluator : Classification report, confusion matrix, per-class F1
|
|
7
|
+
exporter : PyTorch → ONNX → TFLite with optional quantization
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
__version__ = "0.1.0"
|
|
11
|
+
__author__ = "Tharun Kumar"
|
|
12
|
+
|
|
13
|
+
from torchloop.evaluator import Evaluator
|
|
14
|
+
from torchloop.trainer import Trainer
|
|
15
|
+
|
|
16
|
+
__all__ = ["Trainer", "Evaluator", "__version__"]
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
"""
|
|
2
|
+
torchloop.evaluator
|
|
3
|
+
-------------------
|
|
4
|
+
One-call classification diagnostics. No more assembling sklearn +
|
|
5
|
+
matplotlib calls manually across every project.
|
|
6
|
+
|
|
7
|
+
Usage:
|
|
8
|
+
from torchloop import Evaluator
|
|
9
|
+
|
|
10
|
+
ev = Evaluator(model, device="cuda")
|
|
11
|
+
ev.report(val_loader, class_names=["cat", "dog"])
|
|
12
|
+
ev.confusion_matrix(val_loader)
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from typing import Optional
|
|
18
|
+
|
|
19
|
+
import matplotlib.pyplot as plt
|
|
20
|
+
import numpy as np
|
|
21
|
+
import torch
|
|
22
|
+
import torch.nn as nn
|
|
23
|
+
from sklearn.metrics import (
|
|
24
|
+
ConfusionMatrixDisplay,
|
|
25
|
+
classification_report,
|
|
26
|
+
confusion_matrix,
|
|
27
|
+
f1_score,
|
|
28
|
+
)
|
|
29
|
+
from torch.utils.data import DataLoader
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Evaluator:
    """
    One-stop classification diagnostics for a trained model.

    Args:
        model : Trained nn.Module.
        device : 'cuda', 'cpu', or 'mps'. Auto-detects if None.
    """

    def __init__(self, model: nn.Module, device: Optional[str] = None):
        # Prefer an explicit device; otherwise pick CUDA when available.
        target = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.device = target
        self.model = model.to(target)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def report(
        self,
        loader: DataLoader,
        class_names: Optional[list[str]] = None,
    ) -> dict:
        """
        Print full sklearn classification report.

        Returns:
            dict with keys: accuracy, macro_f1, weighted_f1, per_class_f1
        """
        preds, targets = self._infer(loader)
        print(
            classification_report(
                targets, preds, target_names=class_names, zero_division=0
            )
        )

        def averaged_f1(mode: str) -> float:
            # zero_division=0 keeps classes absent from preds from raising.
            return float(f1_score(targets, preds, average=mode, zero_division=0))

        per_class = f1_score(targets, preds, average=None, zero_division=0).tolist()
        hits = np.array(preds) == np.array(targets)
        return {
            "accuracy": float(hits.mean()),
            "macro_f1": averaged_f1("macro"),
            "weighted_f1": averaged_f1("weighted"),
            "per_class_f1": {
                self._label(class_names, idx): round(score, 4)
                for idx, score in enumerate(per_class)
            },
        }

    def confusion_matrix(
        self,
        loader: DataLoader,
        class_names: Optional[list[str]] = None,
        normalize: Optional[str] = "true",  # 'true' | 'pred' | 'all' | None
        figsize: tuple = (8, 6),
    ) -> plt.Figure:
        """
        Plot and return confusion matrix figure.
        """
        preds, targets = self._infer(loader)
        matrix = confusion_matrix(targets, preds, normalize=normalize)
        fig, ax = plt.subplots(figsize=figsize)
        ConfusionMatrixDisplay(matrix, display_labels=class_names).plot(
            ax=ax, colorbar=True, cmap="Blues"
        )
        ax.set_title("Confusion Matrix")
        plt.tight_layout()
        return fig

    def f1_per_class(
        self,
        loader: DataLoader,
        class_names: Optional[list[str]] = None,
    ) -> dict[str, float]:
        """
        Returns per-class F1 as a dict. Clean for logging to W&B or MLflow.
        """
        preds, targets = self._infer(loader)
        per_class = f1_score(targets, preds, average=None, zero_division=0)
        return {
            self._label(class_names, idx): round(float(score), 4)
            for idx, score in enumerate(per_class)
        }

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    @staticmethod
    def _label(class_names: Optional[list[str]], index: int) -> str:
        # Fall back to the numeric class index when no names were supplied.
        return class_names[index] if class_names else str(index)

    def _infer(self, loader: DataLoader) -> tuple[list, list]:
        """Run the model over `loader`; return (predicted_labels, true_labels)."""
        self.model.eval()
        predictions: list = []
        ground_truth: list = []
        with torch.no_grad():
            for batch_x, batch_y in loader:
                logits = self.model(batch_x.to(self.device))
                predictions.extend(logits.argmax(dim=1).cpu().tolist())
                ground_truth.extend(batch_y.tolist())
        return predictions, ground_truth
|
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
"""
|
|
2
|
+
torchloop.exporter
|
|
3
|
+
------------------
|
|
4
|
+
PyTorch → ONNX → TFLite in one place.
|
|
5
|
+
Requires: pip install torchloop[export]
|
|
6
|
+
|
|
7
|
+
Usage:
|
|
8
|
+
from torchloop.exporter import Exporter
|
|
9
|
+
|
|
10
|
+
exp = Exporter(model, input_shape=(1, 3, 224, 224))
|
|
11
|
+
exp.to_onnx("model.onnx")
|
|
12
|
+
exp.to_tflite("model.tflite", quantize=True)
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from typing import Optional
|
|
19
|
+
|
|
20
|
+
import torch
|
|
21
|
+
import torch.nn as nn
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class Exporter:
    """
    Handles model export from PyTorch to ONNX and TFLite.

    Args:
        model : Trained nn.Module (will be set to eval mode).
        input_shape : Tuple describing one sample input e.g. (1, 3, 224, 224).
        device : Device to run dummy forward pass on. Auto-detects if None.
    """

    def __init__(
        self,
        model: nn.Module,
        input_shape: tuple,
        device: Optional[str] = None,
    ):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device).eval()
        self.input_shape = input_shape
        # Random dummy tensor used to trace the graph during export.
        self._dummy = torch.randn(*input_shape).to(self.device)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def to_onnx(self, path: str | Path, opset: int = 17) -> Path:
        """
        Export model to ONNX format.

        Args:
            path : Output .onnx file path.
            opset : ONNX opset version. Default 17 covers most torch ops.

        Returns:
            Resolved path to exported file.

        Raises:
            ImportError : If onnx is not installed.
        """
        try:
            import onnx
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(
                "onnx is not installed. Run: pip install torchloop[export]"
            ) from err

        path = Path(path)
        torch.onnx.export(
            self.model,
            self._dummy,
            str(path),
            opset_version=opset,
            input_names=["input"],
            output_names=["output"],
            # Allow arbitrary batch size at inference time.
            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
        )
        # Reload and structurally validate the exported graph.
        model_onnx = onnx.load(str(path))
        onnx.checker.check_model(model_onnx)
        print(f" ONNX export verified → {path}")
        return path

    def to_tflite(
        self,
        path: str | Path,
        quantize: bool = False,
        onnx_path: Optional[str | Path] = None,
    ) -> Path:
        """
        Export model to TFLite via ONNX → TF → TFLite pipeline.

        Args:
            path : Output .tflite file path.
            quantize : If True, applies dynamic range quantization.
            onnx_path : Intermediate .onnx file path. Auto-generated if None.

        Returns:
            Resolved path to exported .tflite file.

        Raises:
            ImportError : If tensorflow or onnx2tf is not installed.

        Note:
            Requires tensorflow and onnx2tf installed.
            pip install torchloop[export] onnx2tf
        """
        import shutil

        try:
            import onnx2tf
            import tensorflow as tf
        except ImportError as err:
            raise ImportError(
                "tensorflow or onnx2tf not installed.\n"
                "Run: pip install torchloop[export] onnx2tf"
            ) from err

        path = Path(path)

        # Step 1: Export to ONNX first
        _onnx_path = Path(onnx_path) if onnx_path else path.with_suffix(".onnx")
        self.to_onnx(_onnx_path)

        # Step 2: ONNX → SavedModel via onnx2tf
        saved_model_dir = path.parent / "_tflite_savedmodel_tmp"
        try:
            onnx2tf.convert(
                input_onnx_file_path=str(_onnx_path),
                output_folder_path=str(saved_model_dir),
                not_use_onnxsim=False,
                verbosity="error",
            )

            # Step 3: SavedModel → TFLite
            converter = tf.lite.TFLiteConverter.from_saved_model(str(saved_model_dir))
            if quantize:
                converter.optimizations = [tf.lite.Optimize.DEFAULT]
                print(" Quantization: dynamic range enabled.")

            tflite_model = converter.convert()
        finally:
            # BUGFIX: the intermediate SavedModel directory was previously
            # leaked on every call (and on conversion failure). Remove it.
            shutil.rmtree(saved_model_dir, ignore_errors=True)

        path.write_bytes(tflite_model)
        size_kb = path.stat().st_size / 1024
        print(f" TFLite export → {path} ({size_kb:.1f} KB)")
        return path
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
"""
|
|
2
|
+
torchloop.trainer
|
|
3
|
+
-----------------
|
|
4
|
+
Wraps the PyTorch training loop so you stop rewriting it.
|
|
5
|
+
|
|
6
|
+
Usage:
|
|
7
|
+
from torchloop import Trainer
|
|
8
|
+
|
|
9
|
+
trainer = Trainer(model, optimizer, criterion, device="cuda")
|
|
10
|
+
trainer.fit(train_loader, val_loader, epochs=20)
|
|
11
|
+
trainer.save("best.pt")
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import time
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from typing import Callable, Optional
|
|
19
|
+
|
|
20
|
+
import torch
|
|
21
|
+
import torch.nn as nn
|
|
22
|
+
from torch.utils.data import DataLoader
|
|
23
|
+
from tqdm import tqdm
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Trainer:
    """
    Minimal, opinionated PyTorch training loop.

    Args:
        model : nn.Module to train.
        optimizer : Any torch.optim optimizer.
        criterion : Loss function (nn.Module or callable).
        device : 'cuda', 'cpu', or 'mps'. Auto-detects if None.
        metric_fn : Optional callable(preds, targets) → float for val metric.
        patience : Early stopping patience (epochs). None = disabled.
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        criterion: Callable,
        device: Optional[str] = None,
        metric_fn: Optional[Callable] = None,
        patience: Optional[int] = None,
    ):
        chosen = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.device = chosen
        self.model = model.to(chosen)
        self.optimizer = optimizer
        self.criterion = criterion
        self.metric_fn = metric_fn
        self.patience = patience

        # Per-epoch records, filled in by fit().
        self.history: dict[str, list] = {
            key: [] for key in ("train_loss", "val_loss", "val_metric")
        }
        # Best-checkpoint bookkeeping for early stopping / weight restore.
        self._best_val_loss = float("inf")
        self._best_state: Optional[dict] = None
        self._no_improve_count = 0

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def fit(
        self,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader] = None,
        epochs: int = 10,
    ) -> dict:
        """
        Train the model.

        Returns:
            history dict with train_loss, val_loss, val_metric per epoch.
        """
        for epoch in range(1, epochs + 1):
            started = time.time()
            tr_loss = self._train_epoch(train_loader)
            self.history["train_loss"].append(tr_loss)

            v_loss = v_metric = None
            if val_loader is not None:
                v_loss, v_metric = self._val_epoch(val_loader)
                self.history["val_loss"].append(v_loss)
                self.history["val_metric"].append(v_metric)
                self._checkpoint(v_loss)

            self._log(epoch, epochs, tr_loss, v_loss, v_metric, time.time() - started)

            if self._should_stop():
                print(f" Early stopping triggered at epoch {epoch}.")
                break

        # Put the best-seen weights back before returning control.
        if self._best_state is not None:
            self.model.load_state_dict(self._best_state)
            print(" Restored best model weights.")

        return self.history

    def save(self, path: str | Path) -> None:
        """Save model state dict to path."""
        torch.save(self.model.state_dict(), path)
        print(f" Saved → {path}")

    def load(self, path: str | Path) -> None:
        """Load model state dict from path."""
        state = torch.load(path, map_location=self.device)
        self.model.load_state_dict(state)
        print(f" Loaded ← {path}")

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    def _train_epoch(self, loader: DataLoader) -> float:
        """One optimization pass over `loader`; returns mean sample loss."""
        self.model.train()
        running = 0.0
        for batch_x, batch_y in tqdm(loader, desc=" train", leave=False):
            batch_x = batch_x.to(self.device)
            batch_y = batch_y.to(self.device)
            self.optimizer.zero_grad()
            loss = self.criterion(self.model(batch_x), batch_y)
            loss.backward()
            self.optimizer.step()
            # Weight by batch size so partial final batches average correctly.
            running += loss.item() * batch_x.size(0)
        return running / len(loader.dataset)

    def _val_epoch(self, loader: DataLoader) -> tuple[float, Optional[float]]:
        """Evaluate on `loader`; returns (mean loss, metric or None)."""
        self.model.eval()
        running = 0.0
        pred_chunks: list = []
        target_chunks: list = []
        with torch.no_grad():
            for batch_x, batch_y in tqdm(loader, desc=" val ", leave=False):
                batch_x = batch_x.to(self.device)
                batch_y = batch_y.to(self.device)
                outputs = self.model(batch_x)
                running += self.criterion(outputs, batch_y).item() * batch_x.size(0)
                if self.metric_fn is not None:
                    pred_chunks.append(outputs.cpu())
                    target_chunks.append(batch_y.cpu())
        metric = None
        if self.metric_fn is not None and pred_chunks:
            metric = self.metric_fn(torch.cat(pred_chunks), torch.cat(target_chunks))
        return running / len(loader.dataset), metric

    def _checkpoint(self, val_loss: float) -> None:
        """Snapshot weights when validation loss improves; else count a miss."""
        if val_loss >= self._best_val_loss:
            self._no_improve_count += 1
            return
        self._best_val_loss = val_loss
        self._best_state = {
            name: tensor.clone() for name, tensor in self.model.state_dict().items()
        }
        self._no_improve_count = 0

    def _should_stop(self) -> bool:
        """True when early stopping is enabled and patience is exhausted."""
        if self.patience is None:
            return False
        return self._no_improve_count >= self.patience

    @staticmethod
    def _log(epoch, epochs, train_loss, val_loss, val_metric, elapsed) -> None:
        """Print a one-line progress summary for the finished epoch."""
        pieces = [f"Epoch [{epoch:>3}/{epochs}]", f"train_loss={train_loss:.4f}"]
        if val_loss is not None:
            pieces.append(f"val_loss={val_loss:.4f}")
        if val_metric is not None:
            pieces.append(f"val_metric={val_metric:.4f}")
        pieces.append(f"({elapsed:.1f}s)")
        print(" " + " ".join(pieces))
|
|
File without changes
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from torch.utils.data import DataLoader, TensorDataset
|
|
4
|
+
|
|
5
|
+
from torchloop import Evaluator
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _make_loader():
|
|
9
|
+
X = torch.randn(64, 16)
|
|
10
|
+
y = torch.randint(0, 3, (64,))
|
|
11
|
+
return DataLoader(TensorDataset(X, y), batch_size=16)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _make_model():
|
|
15
|
+
return nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 3))
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def test_report_returns_keys():
    """report() must expose macro_f1 and per-class f1 keyed by class name."""
    evaluator = Evaluator(_make_model(), device="cpu")
    names = ["a", "b", "c"]
    result = evaluator.report(_make_loader(), class_names=names)
    assert "macro_f1" in result
    assert "per_class_f1" in result
    assert set(result["per_class_f1"]) == set(names)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def test_f1_per_class_length():
    """f1_per_class() returns exactly one score per class (3 here)."""
    evaluator = Evaluator(_make_model(), device="cpu")
    per_class = evaluator.f1_per_class(_make_loader())
    assert len(per_class) == 3
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def test_confusion_matrix_returns_figure():
    """confusion_matrix() hands back a matplotlib Figure for the caller."""
    import matplotlib.pyplot as plt

    evaluator = Evaluator(_make_model(), device="cpu")
    figure = evaluator.confusion_matrix(_make_loader())
    assert isinstance(figure, plt.Figure)
    # Close explicitly so figures don't accumulate across tests.
    plt.close(figure)
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from torch.utils.data import DataLoader, TensorDataset
|
|
4
|
+
|
|
5
|
+
from torchloop import Trainer
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _make_loader(n=64, features=16, classes=3, batch=16):
|
|
9
|
+
X = torch.randn(n, features)
|
|
10
|
+
y = torch.randint(0, classes, (n,))
|
|
11
|
+
return DataLoader(TensorDataset(X, y), batch_size=batch)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _make_model(features=16, classes=3):
|
|
15
|
+
return nn.Sequential(nn.Linear(features, 32), nn.ReLU(), nn.Linear(32, classes))
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def test_trainer_fit_returns_history():
    """fit() records exactly one train-loss entry per epoch."""
    net = _make_model()
    trainer = Trainer(
        net,
        torch.optim.Adam(net.parameters(), lr=1e-3),
        nn.CrossEntropyLoss(),
        device="cpu",
    )
    history = trainer.fit(_make_loader(), _make_loader(), epochs=3)
    assert "train_loss" in history
    assert len(history["train_loss"]) == 3
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def test_trainer_early_stopping():
    """With patience set, fit() must never exceed the epoch budget and must
    keep the train/val histories aligned.

    The original assertion ``len(...) <= 20`` alone is vacuous (fit can never
    run more than ``epochs`` epochs), and unseeded random data makes the run
    nondeterministic. Seed for reproducibility and also check that early
    stopping keeps val_loss in lockstep with train_loss.
    """
    torch.manual_seed(0)  # deterministic data/model init across runs
    model = _make_model()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    trainer = Trainer(model, optimizer, criterion, device="cpu", patience=2)
    history = trainer.fit(_make_loader(), _make_loader(), epochs=20)
    assert len(history["train_loss"]) <= 20
    # Each completed epoch must record both a train and a val loss.
    assert len(history["val_loss"]) == len(history["train_loss"])
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def test_trainer_save_load(tmp_path):
    """save() writes a checkpoint file that load() can read back."""
    net = _make_model()
    trainer = Trainer(
        net,
        torch.optim.Adam(net.parameters()),
        nn.CrossEntropyLoss(),
        device="cpu",
    )
    trainer.fit(_make_loader(), epochs=1)
    checkpoint = tmp_path / "model.pt"
    trainer.save(checkpoint)
    assert checkpoint.exists()
    trainer.load(checkpoint)
|
|
47
|
+
|
|
48
|
+
def test_trainer_with_metric_fn():
    """A user-supplied metric_fn is evaluated each epoch under val_metric."""
    from sklearn.metrics import f1_score as skf1

    def metric_fn(logits, labels):
        predicted = logits.argmax(dim=1).numpy()
        return skf1(labels.numpy(), predicted, average="macro", zero_division=0)

    net = _make_model()
    trainer = Trainer(
        net,
        torch.optim.Adam(net.parameters()),
        nn.CrossEntropyLoss(),
        device="cpu",
        metric_fn=metric_fn,
    )
    history = trainer.fit(_make_loader(), _make_loader(), epochs=2)
    assert "val_metric" in history
    assert len(history["val_metric"]) == 2
|