torchloop 0.2.0__tar.gz → 0.2.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -27,4 +27,4 @@ jobs:
27
27
  run: hatch build
28
28
 
29
29
  - name: Publish to PyPI
30
- uses: pypa/gh-action-pypi-publish@release/v1
30
+ uses: pypa/gh-action-pypi-publish@release/v1
@@ -0,0 +1,40 @@
1
+ # Changelog
2
+
3
+ All notable changes to this project will be documented in this file.
4
+
5
+ The format is based on Keep a Changelog, and this project adheres to Semantic Versioning.
6
+
7
+ ## [0.2.3] - 2026-03-27
8
+
9
+ ### Added
10
+
11
+ - Trainer gradient accumulation via `accumulate_steps`.
12
+ - Callback system with `Callback`, `EarlyStopping`, `ModelCheckpoint`, `CSVLogger`, and `StopTraining`.
13
+ - Edge deployment submodule with `deploy_to_edge` and `estimate_model`.
14
+ - Example script for edge deployment in `examples/edge_deployment.py`.
15
+ - Expanded test suite for callbacks, exporter, and edge deployment logic.
16
+
17
+ ### Changed
18
+
19
+ - AMP configuration now prefers `use_amp`; legacy `amp` is kept as a deprecated alias.
20
+ - Public exports updated to include `Exporter`, callbacks, and edge APIs.
21
+ - Project metadata updated for v0.2.3 and edge/dev optional dependencies.
22
+ - README updated with callback usage and edge deployment documentation.
23
+
24
+ ### Fixed
25
+
26
+ - Pytest import resolution for src-layout test runs.
27
+ - Deterministic early-stopping test behavior in trainer test suite.
28
+
29
+ ## [0.2.0] - 2026-03-xx
30
+
31
+ ### Added
32
+
33
+ - Automatic mixed precision (AMP) support in Trainer.
34
+ - Learning-rate scheduler integration in Trainer.
35
+
36
+ ## [0.1.0] - 2026-03-xx
37
+
38
+ ### Added
39
+
40
+ - Core `Trainer`, `Evaluator`, and `Exporter` APIs.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torchloop
3
- Version: 0.2.0
3
+ Version: 0.2.3
4
4
  Summary: Lightweight PyTorch utility library for training, evaluation, and TFLite export — without the framework lock-in.
5
5
  Project-URL: Homepage, https://github.com/Tharun007-TK/torchloop
6
6
  Project-URL: Repository, https://github.com/Tharun007-TK/torchloop
@@ -40,24 +40,39 @@ Classifier: Programming Language :: Python :: 3.11
40
40
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
41
41
  Requires-Python: >=3.9
42
42
  Requires-Dist: matplotlib>=3.7.0
43
- Requires-Dist: numpy>=1.24.0
44
- Requires-Dist: scikit-learn>=1.3.0
45
- Requires-Dist: torch>=2.0.0
43
+ Requires-Dist: numpy>=1.20
44
+ Requires-Dist: scikit-learn>=1.0
45
+ Requires-Dist: torch<2.6,>=2.0
46
46
  Requires-Dist: torchvision>=0.15.0
47
47
  Requires-Dist: tqdm>=4.65.0
48
48
  Provides-Extra: all
49
+ Requires-Dist: black>=23.0; extra == 'all'
50
+ Requires-Dist: coremltools>=7.0; (platform_system == 'Darwin') and extra == 'all'
49
51
  Requires-Dist: hatch>=1.7.0; extra == 'all'
52
+ Requires-Dist: mypy>=1.0; extra == 'all'
50
53
  Requires-Dist: onnx>=1.14.0; extra == 'all'
54
+ Requires-Dist: onnx>=1.15; extra == 'all'
51
55
  Requires-Dist: onnxruntime>=1.15.0; extra == 'all'
52
- Requires-Dist: pytest-cov>=4.1.0; extra == 'all'
53
- Requires-Dist: pytest>=7.4.0; extra == 'all'
56
+ Requires-Dist: onnxruntime>=1.16; extra == 'all'
57
+ Requires-Dist: pytest-cov>=4.0; extra == 'all'
58
+ Requires-Dist: pytest>=7.0; extra == 'all'
54
59
  Requires-Dist: ruff>=0.1.0; extra == 'all'
60
+ Requires-Dist: tensorflow-lite-runtime>=2.15; (platform_system != 'Darwin') and extra == 'all'
55
61
  Requires-Dist: tensorflow>=2.13.0; extra == 'all'
62
+ Requires-Dist: torchinfo>=1.8; extra == 'all'
56
63
  Provides-Extra: dev
64
+ Requires-Dist: black>=23.0; extra == 'dev'
57
65
  Requires-Dist: hatch>=1.7.0; extra == 'dev'
58
- Requires-Dist: pytest-cov>=4.1.0; extra == 'dev'
59
- Requires-Dist: pytest>=7.4.0; extra == 'dev'
66
+ Requires-Dist: mypy>=1.0; extra == 'dev'
67
+ Requires-Dist: pytest-cov>=4.0; extra == 'dev'
68
+ Requires-Dist: pytest>=7.0; extra == 'dev'
60
69
  Requires-Dist: ruff>=0.1.0; extra == 'dev'
70
+ Requires-Dist: torchinfo>=1.8; extra == 'dev'
71
+ Provides-Extra: edge
72
+ Requires-Dist: coremltools>=7.0; (platform_system == 'Darwin') and extra == 'edge'
73
+ Requires-Dist: onnx>=1.15; extra == 'edge'
74
+ Requires-Dist: onnxruntime>=1.16; extra == 'edge'
75
+ Requires-Dist: tensorflow-lite-runtime>=2.15; (platform_system != 'Darwin') and extra == 'edge'
61
76
  Provides-Extra: export
62
77
  Requires-Dist: onnx>=1.14.0; extra == 'export'
63
78
  Requires-Dist: onnxruntime>=1.15.0; extra == 'export'
@@ -86,10 +101,17 @@ You write the same PyTorch training loop in every project. Same checkpoint logic
86
101
  ## Install
87
102
 
88
103
  ```bash
104
+ # Base installation
89
105
  pip install torchloop
90
106
 
91
107
  # With TFLite export support
92
108
  pip install torchloop[export]
109
+
110
+ # With edge deployment support
111
+ pip install torchloop[edge]
112
+
113
+ # Development setup
114
+ pip install torchloop[dev]
93
115
  ```
94
116
 
95
117
  ---
@@ -99,16 +121,21 @@ pip install torchloop[export]
99
121
  ### Training
100
122
 
101
123
  ```python
102
- from torchloop import Trainer
124
+ from torchloop import EarlyStopping, ModelCheckpoint, Trainer
103
125
 
104
126
  trainer = Trainer(
105
127
  model,
106
128
  optimizer=torch.optim.Adam(model.parameters()),
107
129
  criterion=torch.nn.CrossEntropyLoss(),
108
130
  device="cuda",
109
- patience=5, # early stopping
131
+ use_amp=True,
132
+ accumulate_steps=4,
133
+ patience=5,
110
134
  )
111
135
 
136
+ trainer.add_callback(EarlyStopping(patience=5))
137
+ trainer.add_callback(ModelCheckpoint(filepath="best.pt"))
138
+
112
139
  history = trainer.fit(train_loader, val_loader, epochs=30)
113
140
  trainer.save("last.pt")  # final weights; best.pt keeps the best-val checkpoint
114
141
  ```
@@ -139,6 +166,25 @@ exp.to_onnx("model.onnx")
139
166
  exp.to_tflite("model.tflite", quantize=True)
140
167
  ```
141
168
 
169
+ ### Edge Deployment
170
+
171
+ ```python
172
+ from torchloop.edge import deploy_to_edge, estimate_model
173
+
174
+ stats = estimate_model(model, (1, 3, 224, 224), target_device="esp32")
175
+ print(f"RAM: {stats['estimated_ram_mb']} MB")
176
+ print(f"Latency: {stats['estimated_latency_ms']} ms")
177
+
178
+ deploy_to_edge(
179
+ model,
180
+ target="esp32",
181
+ input_shape=(1, 3, 224, 224),
182
+ output_path="model.tflite",
183
+ quantize=True,
184
+ quantize_type="int8",
185
+ )
186
+ ```
187
+
142
188
  ---
143
189
 
144
190
  ## Design Principles
@@ -151,9 +197,12 @@ exp.to_tflite("model.tflite", quantize=True)
151
197
 
152
198
  ## Roadmap
153
199
 
154
- - [ ] `v0.1.0` — Trainer, Evaluator, Exporter
155
- - [ ] `v0.2.0` — LR scheduler support, mixed precision (AMP)
156
- - [ ] `v0.3.0` — W&B / MLflow logging hooks
200
+ - [x] `v0.1.0` — Trainer, Evaluator, Exporter
201
+ - [x] `v0.2.0` — LR scheduler support, mixed precision (AMP)
202
+ - [x] `v0.2.1` — Gradient accumulation + callbacks
203
+ - [x] `v0.2.2` — Edge submodule
204
+ - [ ] `v0.3.0` — W&B / MLflow hooks + CoreML export
205
+ - [ ] `v0.3.1` — Model pruning utilities
157
206
 
158
207
  ---
159
208
 
@@ -20,10 +20,17 @@ You write the same PyTorch training loop in every project. Same checkpoint logic
20
20
  ## Install
21
21
 
22
22
  ```bash
23
+ # Base installation
23
24
  pip install torchloop
24
25
 
25
26
  # With TFLite export support
26
27
  pip install torchloop[export]
28
+
29
+ # With edge deployment support
30
+ pip install torchloop[edge]
31
+
32
+ # Development setup
33
+ pip install torchloop[dev]
27
34
  ```
28
35
 
29
36
  ---
@@ -33,16 +40,21 @@ pip install torchloop[export]
33
40
  ### Training
34
41
 
35
42
  ```python
36
- from torchloop import Trainer
43
+ from torchloop import EarlyStopping, ModelCheckpoint, Trainer
37
44
 
38
45
  trainer = Trainer(
39
46
  model,
40
47
  optimizer=torch.optim.Adam(model.parameters()),
41
48
  criterion=torch.nn.CrossEntropyLoss(),
42
49
  device="cuda",
43
- patience=5, # early stopping
50
+ use_amp=True,
51
+ accumulate_steps=4,
52
+ patience=5,
44
53
  )
45
54
 
55
+ trainer.add_callback(EarlyStopping(patience=5))
56
+ trainer.add_callback(ModelCheckpoint(filepath="best.pt"))
57
+
46
58
  history = trainer.fit(train_loader, val_loader, epochs=30)
47
59
  trainer.save("last.pt")  # final weights; best.pt keeps the best-val checkpoint
48
60
  ```
@@ -73,6 +85,25 @@ exp.to_onnx("model.onnx")
73
85
  exp.to_tflite("model.tflite", quantize=True)
74
86
  ```
75
87
 
88
+ ### Edge Deployment
89
+
90
+ ```python
91
+ from torchloop.edge import deploy_to_edge, estimate_model
92
+
93
+ stats = estimate_model(model, (1, 3, 224, 224), target_device="esp32")
94
+ print(f"RAM: {stats['estimated_ram_mb']} MB")
95
+ print(f"Latency: {stats['estimated_latency_ms']} ms")
96
+
97
+ deploy_to_edge(
98
+ model,
99
+ target="esp32",
100
+ input_shape=(1, 3, 224, 224),
101
+ output_path="model.tflite",
102
+ quantize=True,
103
+ quantize_type="int8",
104
+ )
105
+ ```
106
+
76
107
  ---
77
108
 
78
109
  ## Design Principles
@@ -85,9 +116,12 @@ exp.to_tflite("model.tflite", quantize=True)
85
116
 
86
117
  ## Roadmap
87
118
 
88
- - [ ] `v0.1.0` — Trainer, Evaluator, Exporter
89
- - [ ] `v0.2.0` — LR scheduler support, mixed precision (AMP)
90
- - [ ] `v0.3.0` — W&B / MLflow logging hooks
119
+ - [x] `v0.1.0` — Trainer, Evaluator, Exporter
120
+ - [x] `v0.2.0` — LR scheduler support, mixed precision (AMP)
121
+ - [x] `v0.2.1` — Gradient accumulation + callbacks
122
+ - [x] `v0.2.2` — Edge submodule
123
+ - [ ] `v0.3.0` — W&B / MLflow hooks + CoreML export
124
+ - [ ] `v0.3.1` — Model pruning utilities
91
125
 
92
126
  ---
93
127
 
@@ -0,0 +1,67 @@
1
+ """Example: train a tiny CNN and prepare an ESP32-friendly artifact."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.utils.data import DataLoader, TensorDataset
8
+
9
+ from torchloop import EarlyStopping, Trainer
10
+ from torchloop.edge import deploy_to_edge, estimate_model
11
+
12
+
13
class SimpleCNN(nn.Module):
    """Tiny convolutional classifier used by this example.

    Pipeline: one 3x3 convolution (3 -> 16 channels, padding 1), ReLU,
    global average pooling to 1x1, then a linear head producing 10 logits.
    """

    def __init__(self) -> None:
        super().__init__()
        # Attribute names are kept stable: they become the state_dict keys.
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of 3-channel images to 10 class logits per sample."""
        activated = torch.relu(self.conv(x))
        pooled = self.pool(activated)
        batch_size = pooled.size(0)
        return self.fc(pooled.view(batch_size, -1))
27
+
28
+
29
def main() -> None:
    """Train the demo CNN on random data, then estimate and export it for ESP32."""
    model = SimpleCNN()

    # Synthetic data: 100 random 3x32x32 images with integer labels in [0, 10).
    images = torch.randn(100, 3, 32, 32)
    targets = torch.randint(0, 10, (100,))
    loader = DataLoader(TensorDataset(images, targets), batch_size=10, shuffle=True)

    trainer = Trainer(
        model,
        optimizer=torch.optim.Adam(model.parameters()),
        criterion=nn.CrossEntropyLoss(),
        device="cpu",
        use_amp=False,
        accumulate_steps=1,
    )
    trainer.add_callback(EarlyStopping(patience=3))
    # The training loader doubles as the validation loader to keep the demo short.
    trainer.fit(loader, loader, epochs=10)

    sample_shape = (1, 3, 32, 32)
    stats = estimate_model(model, sample_shape, target_device="esp32")
    for line in (
        "ESP32 estimate",
        f"Params: {stats['params']:,}",
        f"RAM: {stats['estimated_ram_mb']:.2f} MB",
        f"Latency: {stats['estimated_latency_ms']:.2f} ms",
    ):
        print(line)

    deploy_to_edge(
        model,
        target="esp32",
        input_shape=sample_shape,
        output_path="model_esp32.tflite",
        quantize=True,
        quantize_type="int8",
    )
    print("Export complete: model_esp32.tflite")


if __name__ == "__main__":
    main()
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "torchloop"
7
- version = "0.2.0"
7
+ version = "0.2.3"
8
8
  description = "Lightweight PyTorch utility library for training, evaluation, and TFLite export — without the framework lock-in."
9
9
  readme = "README.md"
10
10
  license = { file = "LICENSE" }
@@ -26,27 +26,36 @@ classifiers = [
26
26
  ]
27
27
 
28
28
  dependencies = [
29
- "torch>=2.0.0",
29
+ "torch>=2.0,<2.6",
30
30
  "torchvision>=0.15.0",
31
- "scikit-learn>=1.3.0",
32
- "numpy>=1.24.0",
31
+ "scikit-learn>=1.0",
32
+ "numpy>=1.20",
33
33
  "matplotlib>=3.7.0",
34
34
  "tqdm>=4.65.0",
35
35
  ]
36
36
 
37
37
  [project.optional-dependencies]
38
+ edge = [
39
+ "onnx>=1.15",
40
+ "onnxruntime>=1.16",
41
+ "tflite-runtime>=2.13; platform_system == 'Linux' and python_version < '3.12'",
42
+ "coremltools>=7.0; platform_system == 'Darwin'",
43
+ ]
38
44
  export = [
39
45
  "onnx>=1.14.0",
40
46
  "onnxruntime>=1.15.0",
41
47
  "tensorflow>=2.13.0", # needed for TFLite conversion
42
48
  ]
43
49
  dev = [
44
- "pytest>=7.4.0",
45
- "pytest-cov>=4.1.0",
50
+ "pytest>=7.0",
51
+ "pytest-cov>=4.0",
52
+ "black>=23.0",
53
+ "mypy>=1.0",
54
+ "torchinfo>=1.8",
46
55
  "ruff>=0.1.0", # linter + formatter
47
56
  "hatch>=1.7.0",
48
57
  ]
49
- all = ["torchloop[export,dev]"]
58
+ all = ["torchloop[edge,export,dev]"]
50
59
 
51
60
  [project.urls]
52
61
  Homepage = "https://github.com/Tharun007-TK/torchloop"
@@ -65,4 +74,9 @@ select = ["E", "F", "I"] # pycodestyle + pyflakes + isort
65
74
 
66
75
  [tool.pytest.ini_options]
67
76
  testpaths = ["tests"]
77
+ pythonpath = ["src"]
68
78
  addopts = "--cov=src/torchloop --cov-report=term-missing"
79
+ filterwarnings = [
80
+ "ignore:AMP requested on non-CUDA.*:UserWarning",
81
+ "ignore::DeprecationWarning",
82
+ ]
@@ -0,0 +1,6 @@
1
+ torch>=2.0,<2.6
2
+ torchvision>=0.15.0
3
+ scikit-learn>=1.0
4
+ numpy>=1.20
5
+ matplotlib>=3.7.0
6
+ tqdm>=4.65.0
@@ -0,0 +1,37 @@
1
"""
torchloop — Lightweight PyTorch utility library.

Modules:
    trainer   : Training loop, metric logging, checkpoint management
    evaluator : Classification report, confusion matrix, per-class F1
    exporter  : PyTorch → ONNX → TFLite with optional quantization
    callbacks : Training lifecycle hooks (Callback, EarlyStopping,
                ModelCheckpoint, CSVLogger) and the StopTraining signal
    edge      : Edge-device footprint estimation (estimate_model) and
                deployment export (deploy_to_edge)
"""

__version__ = "0.2.3"
__author__ = "Tharun Kumar"

from torchloop.callbacks import (
    CSVLogger,
    Callback,
    EarlyStopping,
    ModelCheckpoint,
    StopTraining,
)
from torchloop.edge import deploy_to_edge, estimate_model
from torchloop.evaluator import Evaluator
from torchloop.exporter import Exporter
from torchloop.trainer import Trainer

# Public API re-exported at the package root.
__all__ = [
    "Trainer",
    "Evaluator",
    "Exporter",
    "Callback",
    "EarlyStopping",
    "ModelCheckpoint",
    "CSVLogger",
    "StopTraining",
    "deploy_to_edge",
    "estimate_model",
    "__version__",
]
@@ -0,0 +1,216 @@
1
+ """Callback utilities for training lifecycle hooks."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import csv
6
+ from pathlib import Path
7
+ from typing import Any, Optional
8
+
9
+ import torch
10
+
11
+
12
class StopTraining(Exception):
    """Signal, raised from inside a callback hook, that training should halt early."""
14
+
15
+
16
class Callback:
    """No-op base class for training lifecycle hooks.

    Every hook does nothing by default; subclasses override only the hooks
    they need.
    """

    def on_train_begin(self, logs: Optional[dict[str, Any]] = None) -> None:
        """Hook run once, before the first epoch.

        Args:
            logs: Optional dictionary describing training state.
        """
        return None

    def on_train_end(self, logs: Optional[dict[str, Any]] = None) -> None:
        """Hook run once, after training finishes.

        Args:
            logs: Optional dictionary describing training state.
        """
        return None

    def on_epoch_begin(
        self,
        epoch: int,
        logs: Optional[dict[str, Any]] = None,
    ) -> None:
        """Hook run when an epoch starts.

        Args:
            epoch: Index of the epoch that is starting.
            logs: Optional dictionary describing training state.
        """
        return None

    def on_epoch_end(
        self,
        epoch: int,
        logs: Optional[dict[str, Any]] = None,
    ) -> None:
        """Hook run when an epoch completes.

        Args:
            epoch: Index of the epoch that just finished.
            logs: Optional dictionary describing training state.
        """
        return None

    def on_batch_end(
        self,
        batch: int,
        logs: Optional[dict[str, Any]] = None,
    ) -> None:
        """Hook run after every training batch.

        Args:
            batch: Zero-based index of the batch within the epoch.
            logs: Optional batch-level metrics.
        """
        return None
71
+
72
+
73
class EarlyStopping(Callback):
    """Stop training when a monitored metric stops improving.

    An epoch counts as an improvement only when the metric beats the best
    value seen so far by strictly more than ``min_delta``. An exactly equal
    value is a plateau, so a perfectly flat metric still exhausts patience.
    After ``patience`` consecutive non-improving epochs, ``StopTraining`` is
    raised. Epochs whose logs lack the monitored key are ignored.

    Args:
        patience: Number of epochs to wait for an improvement (>= 1).
        monitor: Metric name expected in logs.
        min_delta: Minimum improvement required to reset patience (>= 0).
        mode: ``"min"`` when lower values are better (losses, the default),
            ``"max"`` when higher values are better (e.g. accuracy).

    Raises:
        ValueError: If ``patience`` < 1, ``min_delta`` < 0, or ``mode`` is
            neither ``"min"`` nor ``"max"``.
    """

    def __init__(
        self,
        patience: int = 5,
        monitor: str = "val_loss",
        min_delta: float = 0.0,
        mode: str = "min",
    ) -> None:
        if patience < 1:
            raise ValueError("patience must be >= 1")
        if min_delta < 0:
            raise ValueError("min_delta must be >= 0")
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")

        self.patience = patience
        self.monitor = monitor
        self.min_delta = min_delta
        self.mode = mode
        self.best = float("inf")
        self.counter = 0
        self._reset()

    def _reset(self) -> None:
        """Forget the best value and the patience counter."""
        self.best = float("inf") if self.mode == "min" else float("-inf")
        self.counter = 0

    def _is_improvement(self, current: float) -> bool:
        """Return True when ``current`` beats ``best`` by more than ``min_delta``."""
        # Strict comparison (as in Keras/Lightning): equal values are a plateau.
        # NaN compares False, so it counts as "no improvement".
        if self.mode == "min":
            return current < self.best - self.min_delta
        return current > self.best + self.min_delta

    def on_train_begin(self, logs: Optional[dict[str, Any]] = None) -> None:
        """Reset state so a callback reused across ``fit`` calls starts fresh.

        Args:
            logs: Optional training state dictionary (unused).
        """
        self._reset()

    def on_epoch_end(
        self,
        epoch: int,
        logs: Optional[dict[str, Any]] = None,
    ) -> None:
        """Check the metric value and raise StopTraining when patience is exceeded.

        Args:
            epoch: Current epoch number.
            logs: Metrics dictionary from trainer.

        Raises:
            StopTraining: After ``patience`` consecutive non-improving epochs.
        """
        logs = logs or {}
        current = logs.get(self.monitor)
        if current is None:
            return

        value = float(current)
        if self._is_improvement(value):
            self.best = value
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            raise StopTraining(f"Early stopping triggered at epoch {epoch}")
123
+
124
+
125
class ModelCheckpoint(Callback):
    """Save the model's ``state_dict`` based on a monitored metric.

    The model is read from ``logs["model"]``; epochs without it are skipped.
    When ``save_best_only`` is False the checkpoint is written every epoch,
    whether or not the monitored metric is present. Otherwise it is written
    only when the metric strictly improves on the best value seen so far.
    Missing parent directories of ``filepath`` are created on save.

    Args:
        monitor: Metric name expected in logs.
        save_best_only: Save only when monitored metric improves.
        filepath: Path to write checkpoint file.
        mode: ``"min"`` when lower values are better (default), ``"max"``
            when higher values are better.

    Raises:
        ValueError: If ``mode`` is neither ``"min"`` nor ``"max"``.
    """

    def __init__(
        self,
        monitor: str = "val_loss",
        save_best_only: bool = True,
        filepath: str = "checkpoint.pt",
        mode: str = "min",
    ) -> None:
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.monitor = monitor
        self.save_best_only = save_best_only
        self.filepath = Path(filepath)
        self.mode = mode
        self.best = float("inf") if mode == "min" else float("-inf")

    def _save(self, model: Any) -> None:
        """Write ``model.state_dict()`` to ``filepath``, creating parent dirs."""
        self.filepath.parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), self.filepath)

    def on_epoch_end(
        self,
        epoch: int,
        logs: Optional[dict[str, Any]] = None,
    ) -> None:
        """Persist a model checkpoint if the save criteria are met.

        Args:
            epoch: Current epoch number.
            logs: Metrics dictionary from trainer; expected to contain the
                model under ``"model"``.
        """
        logs = logs or {}
        model = logs.get("model")
        if model is None:
            return

        # Unconditional saving does not depend on the monitored metric.
        if not self.save_best_only:
            self._save(model)
            return

        current = logs.get(self.monitor)
        if current is None:
            return

        value = float(current)
        improved = value < self.best if self.mode == "min" else value > self.best
        if improved:
            self.best = value
            self._save(model)
172
+
173
+
174
class CSVLogger(Callback):
    """Log epoch metrics to a CSV file.

    The first logged epoch truncates the file, writes a header of the sorted
    metric names (plus ``epoch``) and the first row. Later epochs append rows
    using those same columns; keys absent from a later epoch are left empty
    and keys that first appear later are not recorded.

    Non-metric objects are not written. In particular, ``nn.Module``
    instances are dropped: logs can carry the model itself, which
    ``ModelCheckpoint`` reads from ``logs["model"]``. Single-element tensors
    are unwrapped to Python numbers.

    Args:
        log_dir: Directory where CSV logs are written (created if missing).
        filename: Name of CSV file.
    """

    def __init__(self, log_dir: str = "./logs", filename: str = "metrics.csv") -> None:
        self.log_dir = Path(log_dir)
        self.filename = filename
        self._initialized = False
        self._fieldnames: list[str] = []

    @staticmethod
    def _to_row(logs: dict[str, Any]) -> dict[str, Any]:
        """Return a copy of ``logs`` that holds only CSV-friendly values."""
        row: dict[str, Any] = {}
        for key, value in logs.items():
            if isinstance(value, torch.nn.Module):
                continue  # e.g. the model object; its repr is not a metric
            if isinstance(value, torch.Tensor) and value.numel() == 1:
                value = value.item()
            row[key] = value
        return row

    def on_epoch_end(
        self,
        epoch: int,
        logs: Optional[dict[str, Any]] = None,
    ) -> None:
        """Append epoch metrics to CSV.

        Args:
            epoch: Current epoch number.
            logs: Metrics dictionary from trainer.
        """
        row = self._to_row(logs or {})
        row["epoch"] = epoch

        self.log_dir.mkdir(parents=True, exist_ok=True)
        file_path = self.log_dir / self.filename

        first_write = not self._initialized
        if first_write:
            self._fieldnames = sorted(row)

        with file_path.open("w" if first_write else "a", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=self._fieldnames)
            if first_write:
                writer.writeheader()
            writer.writerow({k: row.get(k) for k in self._fieldnames})
        self._initialized = True
@@ -0,0 +1,6 @@
1
"""Edge deployment utilities for torchloop.

Re-exports the subpackage's two public entry points:
``deploy_to_edge`` (defined in ``torchloop.edge.deploy``) and
``estimate_model`` (defined in ``torchloop.edge.estimate``).
"""

from torchloop.edge.deploy import deploy_to_edge
from torchloop.edge.estimate import estimate_model

# Public API of ``torchloop.edge``; also governs ``from torchloop.edge import *``.
__all__ = ["deploy_to_edge", "estimate_model"]