torchloop 0.2.3__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30) hide show
  1. {torchloop-0.2.3 → torchloop-0.3.0}/PKG-INFO +10 -5
  2. {torchloop-0.2.3 → torchloop-0.3.0}/pyproject.toml +11 -2
  3. {torchloop-0.2.3 → torchloop-0.3.0}/src/torchloop/__init__.py +4 -12
  4. torchloop-0.3.0/src/torchloop/callbacks/__init__.py +7 -0
  5. torchloop-0.3.0/src/torchloop/callbacks/base.py +19 -0
  6. torchloop-0.3.0/src/torchloop/callbacks/mlflow_logger.py +72 -0
  7. torchloop-0.3.0/src/torchloop/callbacks/wandb_logger.py +64 -0
  8. {torchloop-0.2.3 → torchloop-0.3.0}/src/torchloop/edge/deploy.py +14 -4
  9. {torchloop-0.2.3 → torchloop-0.3.0}/src/torchloop/edge/estimate.py +19 -4
  10. {torchloop-0.2.3 → torchloop-0.3.0}/src/torchloop/trainer.py +12 -18
  11. torchloop-0.3.0/tests/test_callbacks.py +98 -0
  12. {torchloop-0.2.3 → torchloop-0.3.0}/tests/test_trainer.py +12 -15
  13. torchloop-0.2.3/src/torchloop/callbacks.py +0 -216
  14. torchloop-0.2.3/tests/test_callbacks.py +0 -191
  15. {torchloop-0.2.3 → torchloop-0.3.0}/.github/workflows/ci.yml +0 -0
  16. {torchloop-0.2.3 → torchloop-0.3.0}/.github/workflows/publish.yml +0 -0
  17. {torchloop-0.2.3 → torchloop-0.3.0}/.gitignore +0 -0
  18. {torchloop-0.2.3 → torchloop-0.3.0}/CHANGELOG.md +0 -0
  19. {torchloop-0.2.3 → torchloop-0.3.0}/LICENSE +0 -0
  20. {torchloop-0.2.3 → torchloop-0.3.0}/README.md +0 -0
  21. {torchloop-0.2.3 → torchloop-0.3.0}/examples/edge_deployment.py +0 -0
  22. {torchloop-0.2.3 → torchloop-0.3.0}/requirements.txt +0 -0
  23. {torchloop-0.2.3 → torchloop-0.3.0}/src/torchloop/edge/__init__.py +0 -0
  24. {torchloop-0.2.3 → torchloop-0.3.0}/src/torchloop/evaluator.py +0 -0
  25. {torchloop-0.2.3 → torchloop-0.3.0}/src/torchloop/exporter.py +0 -0
  26. {torchloop-0.2.3 → torchloop-0.3.0}/tests/__init__.py +0 -0
  27. {torchloop-0.2.3 → torchloop-0.3.0}/tests/conftest.py +0 -0
  28. {torchloop-0.2.3 → torchloop-0.3.0}/tests/test_edge.py +0 -0
  29. {torchloop-0.2.3 → torchloop-0.3.0}/tests/test_evaluator.py +0 -0
  30. {torchloop-0.2.3 → torchloop-0.3.0}/tests/test_exporter.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torchloop
3
- Version: 0.2.3
3
+ Version: 0.3.0
4
4
  Summary: Lightweight PyTorch utility library for training, evaluation, and TFLite export — without the framework lock-in.
5
5
  Project-URL: Homepage, https://github.com/Tharun007-TK/torchloop
6
6
  Project-URL: Repository, https://github.com/Tharun007-TK/torchloop
@@ -47,19 +47,17 @@ Requires-Dist: torchvision>=0.15.0
47
47
  Requires-Dist: tqdm>=4.65.0
48
48
  Provides-Extra: all
49
49
  Requires-Dist: black>=23.0; extra == 'all'
50
- Requires-Dist: coremltools>=7.0; (platform_system == 'Darwin') and extra == 'all'
51
50
  Requires-Dist: hatch>=1.7.0; extra == 'all'
51
+ Requires-Dist: mlflow>=2.0.0; extra == 'all'
52
52
  Requires-Dist: mypy>=1.0; extra == 'all'
53
53
  Requires-Dist: onnx>=1.14.0; extra == 'all'
54
- Requires-Dist: onnx>=1.15; extra == 'all'
55
54
  Requires-Dist: onnxruntime>=1.15.0; extra == 'all'
56
- Requires-Dist: onnxruntime>=1.16; extra == 'all'
57
55
  Requires-Dist: pytest-cov>=4.0; extra == 'all'
58
56
  Requires-Dist: pytest>=7.0; extra == 'all'
59
57
  Requires-Dist: ruff>=0.1.0; extra == 'all'
60
- Requires-Dist: tensorflow-lite-runtime>=2.15; (platform_system != 'Darwin') and extra == 'all'
61
58
  Requires-Dist: tensorflow>=2.13.0; extra == 'all'
62
59
  Requires-Dist: torchinfo>=1.8; extra == 'all'
60
+ Requires-Dist: wandb>=0.15.0; extra == 'all'
63
61
  Provides-Extra: dev
64
62
  Requires-Dist: black>=23.0; extra == 'dev'
65
63
  Requires-Dist: hatch>=1.7.0; extra == 'dev'
@@ -68,6 +66,10 @@ Requires-Dist: pytest-cov>=4.0; extra == 'dev'
68
66
  Requires-Dist: pytest>=7.0; extra == 'dev'
69
67
  Requires-Dist: ruff>=0.1.0; extra == 'dev'
70
68
  Requires-Dist: torchinfo>=1.8; extra == 'dev'
69
+ Provides-Extra: docs
70
+ Requires-Dist: mkdocs-material>=9.0.0; extra == 'docs'
71
+ Requires-Dist: mkdocs>=1.5.0; extra == 'docs'
72
+ Requires-Dist: mkdocstrings[python]>=0.24.0; extra == 'docs'
71
73
  Provides-Extra: edge
72
74
  Requires-Dist: coremltools>=7.0; (platform_system == 'Darwin') and extra == 'edge'
73
75
  Requires-Dist: onnx>=1.15; extra == 'edge'
@@ -77,6 +79,9 @@ Provides-Extra: export
77
79
  Requires-Dist: onnx>=1.14.0; extra == 'export'
78
80
  Requires-Dist: onnxruntime>=1.15.0; extra == 'export'
79
81
  Requires-Dist: tensorflow>=2.13.0; extra == 'export'
82
+ Provides-Extra: logging
83
+ Requires-Dist: mlflow>=2.0.0; extra == 'logging'
84
+ Requires-Dist: wandb>=0.15.0; extra == 'logging'
80
85
  Description-Content-Type: text/markdown
81
86
 
82
87
  # torchloop
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "torchloop"
7
- version = "0.2.3"
7
+ version = "0.3.0"
8
8
  description = "Lightweight PyTorch utility library for training, evaluation, and TFLite export — without the framework lock-in."
9
9
  readme = "README.md"
10
10
  license = { file = "LICENSE" }
@@ -41,6 +41,15 @@ edge = [
41
41
  "tensorflow-lite-runtime>=2.15; platform_system != 'Darwin'",
42
42
  "coremltools>=7.0; platform_system == 'Darwin'",
43
43
  ]
44
+ logging = [
45
+ "wandb>=0.15.0",
46
+ "mlflow>=2.0.0",
47
+ ]
48
+ docs = [
49
+ "mkdocs>=1.5.0",
50
+ "mkdocs-material>=9.0.0",
51
+ "mkdocstrings[python]>=0.24.0",
52
+ ]
44
53
  export = [
45
54
  "onnx>=1.14.0",
46
55
  "onnxruntime>=1.15.0",
@@ -55,7 +64,7 @@ dev = [
55
64
  "ruff>=0.1.0", # linter + formatter
56
65
  "hatch>=1.7.0",
57
66
  ]
58
- all = ["torchloop[edge,export,dev]"]
67
+ all = ["torchloop[export,dev,logging]"]
59
68
 
60
69
  [project.urls]
61
70
  Homepage = "https://github.com/Tharun007-TK/torchloop"
@@ -7,16 +7,10 @@ Modules:
7
7
  exporter : PyTorch → ONNX → TFLite with optional quantization
8
8
  """
9
9
 
10
- __version__ = "0.2.3"
10
+ __version__ = "0.3.0"
11
11
  __author__ = "Tharun Kumar"
12
12
 
13
- from torchloop.callbacks import (
14
- CSVLogger,
15
- Callback,
16
- EarlyStopping,
17
- ModelCheckpoint,
18
- StopTraining,
19
- )
13
+ from torchloop.callbacks import Callback, MLflowLogger, WandBLogger
20
14
  from torchloop.edge import deploy_to_edge, estimate_model
21
15
  from torchloop.evaluator import Evaluator
22
16
  from torchloop.exporter import Exporter
@@ -27,10 +21,8 @@ __all__ = [
27
21
  "Evaluator",
28
22
  "Exporter",
29
23
  "Callback",
30
- "EarlyStopping",
31
- "ModelCheckpoint",
32
- "CSVLogger",
33
- "StopTraining",
24
+ "WandBLogger",
25
+ "MLflowLogger",
34
26
  "deploy_to_edge",
35
27
  "estimate_model",
36
28
  "__version__",
@@ -0,0 +1,7 @@
1
+ """Callback integrations for torchloop."""
2
+
3
+ from torchloop.callbacks.base import Callback
4
+ from torchloop.callbacks.mlflow_logger import MLflowLogger
5
+ from torchloop.callbacks.wandb_logger import WandBLogger
6
+
7
+ __all__ = ["Callback", "WandBLogger", "MLflowLogger"]
@@ -0,0 +1,19 @@
1
+ """Base callback abstractions for torchloop training events."""
2
+
3
+ from __future__ import annotations
4
+
5
+
6
+ class Callback:
7
+ """Base callback with optional training lifecycle hooks.
8
+
9
+ Subclasses can override any hook they need.
10
+ """
11
+
12
+ def on_train_begin(self, logs: dict) -> None:
13
+ """Run before training starts."""
14
+
15
+ def on_epoch_end(self, epoch: int, logs: dict) -> None:
16
+ """Run at the end of each epoch."""
17
+
18
+ def on_train_end(self, logs: dict) -> None:
19
+ """Run after training finishes."""
@@ -0,0 +1,72 @@
1
+ """MLflow callback integration for torchloop."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from importlib import import_module
6
+ from typing import Optional
7
+
8
+ from torchloop.callbacks.base import Callback
9
+
10
+
11
+ class MLflowLogger(Callback):
12
+ """Log training metrics to MLflow.
13
+
14
+ Args:
15
+ experiment_name: MLflow experiment name.
16
+ tracking_uri: Optional tracking server URI.
17
+ run_name: Optional MLflow run name.
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ experiment_name: str,
23
+ tracking_uri: Optional[str] = None,
24
+ run_name: Optional[str] = None,
25
+ ) -> None:
26
+ self.experiment_name = experiment_name
27
+ self.tracking_uri = tracking_uri
28
+ self.run_name = run_name
29
+
30
+ def on_train_begin(self, logs: dict) -> None:
31
+ """Initialize MLflow experiment and run."""
32
+ try:
33
+ mlflow = import_module("mlflow")
34
+ except ImportError as exc:
35
+ raise ImportError(
36
+ "mlflow is required for MLflowLogger. "
37
+ "Install with: pip install torchloop[logging]"
38
+ ) from exc
39
+
40
+ if self.tracking_uri:
41
+ mlflow.set_tracking_uri(self.tracking_uri)
42
+ mlflow.set_experiment(self.experiment_name)
43
+ mlflow.start_run(run_name=self.run_name)
44
+
45
+ def on_epoch_end(self, epoch: int, logs: dict) -> None:
46
+ """Log epoch metrics to MLflow."""
47
+ try:
48
+ mlflow = import_module("mlflow")
49
+ except ImportError as exc:
50
+ raise ImportError(
51
+ "mlflow is required for MLflowLogger. "
52
+ "Install with: pip install torchloop[logging]"
53
+ ) from exc
54
+
55
+ numeric_logs = {
56
+ key: value
57
+ for key, value in logs.items()
58
+ if isinstance(value, (int, float))
59
+ }
60
+ mlflow.log_metrics(numeric_logs, step=epoch)
61
+
62
+ def on_train_end(self, logs: dict) -> None:
63
+ """End the active MLflow run."""
64
+ try:
65
+ mlflow = import_module("mlflow")
66
+ except ImportError as exc:
67
+ raise ImportError(
68
+ "mlflow is required for MLflowLogger. "
69
+ "Install with: pip install torchloop[logging]"
70
+ ) from exc
71
+
72
+ mlflow.end_run()
@@ -0,0 +1,64 @@
1
+ """Weights & Biases callback integration for torchloop."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from importlib import import_module
6
+ from typing import Any, Optional
7
+
8
+ from torchloop.callbacks.base import Callback
9
+
10
+
11
+ class WandBLogger(Callback):
12
+ """Log training metrics to Weights & Biases.
13
+
14
+ Args:
15
+ project: Weights & Biases project name.
16
+ name: Optional run name.
17
+ config: Optional run configuration dictionary.
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ project: str,
23
+ name: Optional[str] = None,
24
+ config: Optional[dict[str, Any]] = None,
25
+ ) -> None:
26
+ self.project = project
27
+ self.name = name
28
+ self.config = config or {}
29
+
30
+ def on_train_begin(self, logs: dict) -> None:
31
+ """Initialize a W&B run."""
32
+ try:
33
+ wandb = import_module("wandb")
34
+ except ImportError as exc:
35
+ raise ImportError(
36
+ "wandb is required for WandBLogger. "
37
+ "Install with: pip install torchloop[logging]"
38
+ ) from exc
39
+
40
+ wandb.init(project=self.project, name=self.name, config=self.config)
41
+
42
+ def on_epoch_end(self, epoch: int, logs: dict) -> None:
43
+ """Log epoch metrics to W&B."""
44
+ try:
45
+ wandb = import_module("wandb")
46
+ except ImportError as exc:
47
+ raise ImportError(
48
+ "wandb is required for WandBLogger. "
49
+ "Install with: pip install torchloop[logging]"
50
+ ) from exc
51
+
52
+ wandb.log(logs, step=epoch)
53
+
54
+ def on_train_end(self, logs: dict) -> None:
55
+ """Finish the active W&B run."""
56
+ try:
57
+ wandb = import_module("wandb")
58
+ except ImportError as exc:
59
+ raise ImportError(
60
+ "wandb is required for WandBLogger. "
61
+ "Install with: pip install torchloop[logging]"
62
+ ) from exc
63
+
64
+ wandb.finish()
@@ -61,7 +61,10 @@ def deploy_to_edge(
61
61
  artifact_format = "onnx"
62
62
 
63
63
  if target in {"esp32", "android"}:
64
- artifact_path = output if output.suffix == ".tflite" else output.with_suffix(".tflite")
64
+ artifact_path = (
65
+ output if output.suffix == ".tflite"
66
+ else output.with_suffix(".tflite")
67
+ )
65
68
  if target == "esp32" and quantize_type != "int8":
66
69
  warnings.warn(
67
70
  "ESP32 is typically best with int8 quantization.",
@@ -79,7 +82,10 @@ def deploy_to_edge(
79
82
  _convert_to_tflite(onnx_path, artifact_path, quantize, quantize_type)
80
83
  artifact_format = "tflite"
81
84
  elif target == "ios":
82
- artifact_path = output if output.suffix == ".mlpackage" else output.with_suffix(".mlpackage")
85
+ artifact_path = (
86
+ output if output.suffix == ".mlpackage"
87
+ else output.with_suffix(".mlpackage")
88
+ )
83
89
  _convert_to_coreml(model, input_shape, artifact_path)
84
90
  artifact_format = "coreml"
85
91
  elif target == "jetson":
@@ -109,7 +115,9 @@ def _export_to_onnx(
109
115
  onnx_name = "".join(["on", "nx"])
110
116
  onnx = import_module(onnx_name)
111
117
  except ImportError as exc:
112
- raise ImportError("onnx is required. Install with: pip install torchloop[edge]") from exc
118
+ raise ImportError(
119
+ "onnx is required. Install with: pip install torchloop[edge]"
120
+ ) from exc
113
121
 
114
122
  model.eval()
115
123
  device = _get_model_device(model)
@@ -184,7 +192,9 @@ def _convert_to_coreml(
184
192
  coremltools_name = "".join(["coreml", "tools"])
185
193
  ct = import_module(coremltools_name)
186
194
  except ImportError as exc:
187
- raise ImportError("coremltools is required. Install with: pip install torchloop[edge]") from exc
195
+ raise ImportError(
196
+ "coremltools is required. Install with: pip install torchloop[edge]"
197
+ ) from exc
188
198
 
189
199
  model_cpu = model.to("cpu").eval()
190
200
  traced = torch.jit.trace(model_cpu, torch.randn(*input_shape))
@@ -63,19 +63,34 @@ def _estimate_activation_mb(input_shape: tuple[int, ...]) -> float:
63
63
  return (numel * 4 * 2) / (1024 * 1024)
64
64
 
65
65
 
66
- def _estimate_flops(model: torch.nn.Module, input_shape: tuple[int, ...]) -> int:
66
+ def _estimate_flops(
67
+ model: torch.nn.Module,
68
+ input_shape: tuple[int, ...],
69
+ ) -> int:
67
70
  """Estimate FLOPs using hooks for common modules (Conv2d, Linear)."""
68
71
  flops = 0
69
72
  hooks = []
70
73
 
71
- def conv_hook(module: torch.nn.Conv2d, _inp: tuple[torch.Tensor], out: torch.Tensor) -> None:
74
+ def conv_hook(
75
+ module: torch.nn.Conv2d,
76
+ _inp: tuple[torch.Tensor],
77
+ out: torch.Tensor,
78
+ ) -> None:
72
79
  nonlocal flops
73
80
  batch = out.shape[0]
74
81
  out_h, out_w = out.shape[2], out.shape[3]
75
- kernel_ops = module.kernel_size[0] * module.kernel_size[1] * (module.in_channels / module.groups)
82
+ kernel_ops = (
83
+ module.kernel_size[0]
84
+ * module.kernel_size[1]
85
+ * (module.in_channels / module.groups)
86
+ )
76
87
  flops += int(batch * out_h * out_w * module.out_channels * kernel_ops * 2)
77
88
 
78
- def linear_hook(module: torch.nn.Linear, inp: tuple[torch.Tensor], _out: torch.Tensor) -> None:
89
+ def linear_hook(
90
+ module: torch.nn.Linear,
91
+ inp: tuple[torch.Tensor],
92
+ _out: torch.Tensor,
93
+ ) -> None:
79
94
  nonlocal flops
80
95
  batch = inp[0].shape[0]
81
96
  flops += int(batch * module.in_features * module.out_features * 2)
@@ -32,7 +32,7 @@ import torch.nn as nn
32
32
  from torch.utils.data import DataLoader
33
33
  from tqdm import tqdm
34
34
 
35
- from torchloop.callbacks import Callback, StopTraining
35
+ from torchloop.callbacks import Callback
36
36
 
37
37
 
38
38
  class Trainer:
@@ -129,10 +129,9 @@ class Trainer:
129
129
  Returns:
130
130
  history dict with train_loss, val_loss, val_metric, lr per epoch.
131
131
  """
132
- self._trigger_hook("on_train_begin", logs=self.history)
132
+ self._run_callbacks("on_train_begin", dict(self.history))
133
133
  for epoch in range(1, epochs + 1):
134
134
  t0 = time.time()
135
- self._trigger_hook("on_epoch_begin", epoch=epoch, logs=self.history)
136
135
  train_loss = self._train_epoch(train_loader)
137
136
  self.history["train_loss"].append(train_loss)
138
137
 
@@ -153,13 +152,13 @@ class Trainer:
153
152
  )
154
153
 
155
154
  epoch_logs = {
155
+ "epoch": epoch,
156
156
  "train_loss": train_loss,
157
157
  "val_loss": val_loss,
158
158
  "val_metric": val_metric,
159
159
  "lr": current_lr,
160
- "model": self.model,
161
160
  }
162
- self._trigger_hook("on_epoch_end", epoch=epoch, logs=epoch_logs)
161
+ self._run_callbacks("on_epoch_end", epoch_logs)
163
162
 
164
163
  if self._should_stop():
165
164
  print(f" Early stopping triggered at epoch {epoch}.")
@@ -169,7 +168,7 @@ class Trainer:
169
168
  self.model.load_state_dict(self._best_state)
170
169
  print(" Restored best model weights.")
171
170
 
172
- self._trigger_hook("on_train_end", logs=self.history)
171
+ self._run_callbacks("on_train_end", dict(self.history))
173
172
  return self.history
174
173
 
175
174
  def add_callback(self, callback: Callback) -> None:
@@ -222,11 +221,6 @@ class Trainer:
222
221
  pending_steps = 0
223
222
 
224
223
  total_loss += raw_loss.item() * inputs.size(0)
225
- self._trigger_hook(
226
- "on_batch_end",
227
- batch=batch_idx,
228
- logs={"loss": float(raw_loss.item())},
229
- )
230
224
 
231
225
  if pending_steps > 0:
232
226
  self._optimizer_step()
@@ -296,15 +290,15 @@ class Trainer:
296
290
  )
297
291
  )
298
292
 
299
- def _trigger_hook(self, hook_name: str, **kwargs: Any) -> None:
300
- for callback in self.callbacks:
301
- hook = getattr(callback, hook_name, None)
293
+ def _run_callbacks(self, event: str, logs: dict[str, Any]) -> None:
294
+ for callback in (self.callbacks or []):
295
+ hook = getattr(callback, event, None)
302
296
  if hook is None:
303
297
  continue
304
- try:
305
- hook(**kwargs)
306
- except StopTraining:
307
- self._stop_early = True
298
+ if event == "on_epoch_end":
299
+ hook(int(logs.get("epoch", 0)), logs)
300
+ else:
301
+ hook(logs)
308
302
 
309
303
  @staticmethod
310
304
  def _log(
@@ -0,0 +1,98 @@
1
+ """Tests for torchloop callback integrations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import sys
6
+
7
+ import pytest
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.utils.data import DataLoader, TensorDataset
11
+
12
+ from torchloop import Trainer
13
+ from torchloop.callbacks import Callback, MLflowLogger, WandBLogger
14
+
15
+
16
+ def _make_loader(n: int = 32, features: int = 8, classes: int = 3) -> DataLoader:
17
+ x_data = torch.randn(n, features)
18
+ y_data = torch.randint(0, classes, (n,))
19
+ dataset = TensorDataset(x_data, y_data)
20
+ return DataLoader(dataset, batch_size=8)
21
+
22
+
23
+ def _make_model(features: int = 8, classes: int = 3) -> nn.Module:
24
+ return nn.Sequential(nn.Linear(features, 16), nn.ReLU(), nn.Linear(16, classes))
25
+
26
+
27
+ def test_callback_base_methods_are_noop() -> None:
28
+ """Ensure base callback methods are safe no-op defaults."""
29
+ callback = Callback()
30
+ callback.on_train_begin({"start": True})
31
+ callback.on_epoch_end(1, {"loss": 1.0})
32
+ callback.on_train_end({"done": True})
33
+
34
+
35
+ def test_wandb_logger_raises_if_not_installed(monkeypatch: pytest.MonkeyPatch) -> None:
36
+ """Raise ImportError when wandb is unavailable at runtime."""
37
+ monkeypatch.setitem(sys.modules, "wandb", None)
38
+ logger = WandBLogger(project="proj")
39
+
40
+ with pytest.raises(ImportError, match="wandb is required"):
41
+ logger.on_train_begin({"train": True})
42
+
43
+
44
+
45
+ def test_mlflow_logger_raises_if_not_installed(monkeypatch: pytest.MonkeyPatch) -> None:
46
+ """Raise ImportError when mlflow is unavailable at runtime."""
47
+ monkeypatch.setitem(sys.modules, "mlflow", None)
48
+ logger = MLflowLogger(experiment_name="exp")
49
+
50
+ with pytest.raises(ImportError, match="mlflow is required"):
51
+ logger.on_train_begin({"train": True})
52
+
53
+
54
+
55
+ def test_trainer_accepts_empty_callbacks_list() -> None:
56
+ """Allow explicit empty callback lists during fit execution."""
57
+ model = _make_model()
58
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
59
+ criterion = nn.CrossEntropyLoss()
60
+
61
+ trainer = Trainer(
62
+ model=model,
63
+ optimizer=optimizer,
64
+ criterion=criterion,
65
+ device="cpu",
66
+ callbacks=[],
67
+ )
68
+
69
+ history = trainer.fit(_make_loader(), _make_loader(), epochs=2)
70
+ assert len(history["train_loss"]) == 2
71
+
72
+
73
+
74
+ def test_trainer_calls_on_epoch_end() -> None:
75
+ """Invoke callback epoch-end hook once per epoch with epoch indices."""
76
+
77
+ class TestCallback(Callback):
78
+ def __init__(self) -> None:
79
+ self.epochs: list[int] = []
80
+
81
+ def on_epoch_end(self, epoch: int, logs: dict) -> None:
82
+ self.epochs.append(epoch)
83
+
84
+ callback = TestCallback()
85
+ model = _make_model()
86
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
87
+ criterion = nn.CrossEntropyLoss()
88
+
89
+ trainer = Trainer(
90
+ model=model,
91
+ optimizer=optimizer,
92
+ criterion=criterion,
93
+ device="cpu",
94
+ callbacks=[callback],
95
+ )
96
+ trainer.fit(_make_loader(), _make_loader(), epochs=3)
97
+
98
+ assert callback.epochs == [1, 2, 3]
@@ -4,7 +4,7 @@ import pytest
4
4
  from torch.utils.data import DataLoader, TensorDataset
5
5
 
6
6
  from torchloop import Trainer
7
- from torchloop.callbacks import Callback, EarlyStopping
7
+ from torchloop.callbacks import Callback
8
8
 
9
9
 
10
10
  def _make_loader(n=64, features=16, classes=3, batch=16):
@@ -140,20 +140,12 @@ def test_trainer_callbacks_triggered():
140
140
  class Recorder(Callback):
141
141
  def __init__(self):
142
142
  self.train_begin = 0
143
- self.epoch_begin = 0
144
- self.batch_end = 0
145
143
  self.epoch_end = 0
146
144
  self.train_end = 0
147
145
 
148
146
  def on_train_begin(self, logs=None):
149
147
  self.train_begin += 1
150
148
 
151
- def on_epoch_begin(self, epoch, logs=None):
152
- self.epoch_begin += 1
153
-
154
- def on_batch_end(self, batch, logs=None):
155
- self.batch_end += 1
156
-
157
149
  def on_epoch_end(self, epoch, logs=None):
158
150
  self.epoch_end += 1
159
151
 
@@ -175,17 +167,22 @@ def test_trainer_callbacks_triggered():
175
167
  trainer.fit(_make_loader(n=32, batch=8), epochs=2)
176
168
 
177
169
  assert recorder.train_begin == 1
178
- assert recorder.epoch_begin == 2
179
- assert recorder.batch_end == 8
180
170
  assert recorder.epoch_end == 2
181
171
  assert recorder.train_end == 1
182
172
 
183
173
 
184
- def test_early_stopping_callback_stops_training():
174
+ def test_callback_receives_epoch_end_for_each_epoch():
175
+ class EpochRecorder(Callback):
176
+ def __init__(self):
177
+ self.epochs = []
178
+
179
+ def on_epoch_end(self, epoch, logs=None):
180
+ self.epochs.append(epoch)
181
+
185
182
  model = _make_model()
186
183
  optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
187
184
  criterion = nn.CrossEntropyLoss()
188
- callback = EarlyStopping(patience=1, monitor="val_loss", min_delta=0.1)
185
+ callback = EpochRecorder()
189
186
 
190
187
  trainer = Trainer(
191
188
  model,
@@ -194,8 +191,8 @@ def test_early_stopping_callback_stops_training():
194
191
  device="cpu",
195
192
  callbacks=[callback],
196
193
  )
197
- history = trainer.fit(_make_loader(), _make_loader(), epochs=10)
198
- assert len(history["train_loss"]) < 10
194
+ trainer.fit(_make_loader(), _make_loader(), epochs=3)
195
+ assert callback.epochs == [1, 2, 3]
199
196
 
200
197
 
201
198
  @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -1,216 +0,0 @@
1
- """Callback utilities for training lifecycle hooks."""
2
-
3
- from __future__ import annotations
4
-
5
- import csv
6
- from pathlib import Path
7
- from typing import Any, Optional
8
-
9
- import torch
10
-
11
-
12
- class StopTraining(Exception):
13
- """Raised by callbacks to request early stop of training."""
14
-
15
-
16
- class Callback:
17
- """Base callback class.
18
-
19
- Subclass this and override hook methods to customize training behavior.
20
- """
21
-
22
- def on_train_begin(self, logs: Optional[dict[str, Any]] = None) -> None:
23
- """Called before the first epoch starts.
24
-
25
- Args:
26
- logs: Optional training state dictionary.
27
- """
28
-
29
- def on_train_end(self, logs: Optional[dict[str, Any]] = None) -> None:
30
- """Called after training ends.
31
-
32
- Args:
33
- logs: Optional training state dictionary.
34
- """
35
-
36
- def on_epoch_begin(
37
- self,
38
- epoch: int,
39
- logs: Optional[dict[str, Any]] = None,
40
- ) -> None:
41
- """Called at the start of each epoch.
42
-
43
- Args:
44
- epoch: Current epoch number.
45
- logs: Optional training state dictionary.
46
- """
47
-
48
- def on_epoch_end(
49
- self,
50
- epoch: int,
51
- logs: Optional[dict[str, Any]] = None,
52
- ) -> None:
53
- """Called at the end of each epoch.
54
-
55
- Args:
56
- epoch: Current epoch number.
57
- logs: Optional training state dictionary.
58
- """
59
-
60
- def on_batch_end(
61
- self,
62
- batch: int,
63
- logs: Optional[dict[str, Any]] = None,
64
- ) -> None:
65
- """Called after each training batch.
66
-
67
- Args:
68
- batch: Zero-based batch index.
69
- logs: Optional batch-level metrics.
70
- """
71
-
72
-
73
- class EarlyStopping(Callback):
74
- """Stop training when a monitored metric stops improving.
75
-
76
- Args:
77
- patience: Number of epochs to wait for an improvement.
78
- monitor: Metric name expected in logs.
79
- min_delta: Minimum improvement required to reset patience.
80
- """
81
-
82
- def __init__(
83
- self,
84
- patience: int = 5,
85
- monitor: str = "val_loss",
86
- min_delta: float = 0.0,
87
- ) -> None:
88
- if patience < 1:
89
- raise ValueError("patience must be >= 1")
90
- if min_delta < 0:
91
- raise ValueError("min_delta must be >= 0")
92
-
93
- self.patience = patience
94
- self.monitor = monitor
95
- self.min_delta = min_delta
96
- self.best = float("inf")
97
- self.counter = 0
98
-
99
- def on_epoch_end(
100
- self,
101
- epoch: int,
102
- logs: Optional[dict[str, Any]] = None,
103
- ) -> None:
104
- """Check metric value and raise StopTraining when patience is exceeded.
105
-
106
- Args:
107
- epoch: Current epoch number.
108
- logs: Metrics dictionary from trainer.
109
- """
110
- logs = logs or {}
111
- current = logs.get(self.monitor)
112
- if current is None:
113
- return
114
-
115
- if current <= self.best - self.min_delta:
116
- self.best = float(current)
117
- self.counter = 0
118
- return
119
-
120
- self.counter += 1
121
- if self.counter >= self.patience:
122
- raise StopTraining(f"Early stopping triggered at epoch {epoch}")
123
-
124
-
125
- class ModelCheckpoint(Callback):
126
- """Save model checkpoints based on monitored metric.
127
-
128
- Args:
129
- monitor: Metric name expected in logs.
130
- save_best_only: Save only when monitored metric improves.
131
- filepath: Path to write checkpoint file.
132
- """
133
-
134
- def __init__(
135
- self,
136
- monitor: str = "val_loss",
137
- save_best_only: bool = True,
138
- filepath: str = "checkpoint.pt",
139
- ) -> None:
140
- self.monitor = monitor
141
- self.save_best_only = save_best_only
142
- self.filepath = Path(filepath)
143
- self.best = float("inf")
144
-
145
- def on_epoch_end(
146
- self,
147
- epoch: int,
148
- logs: Optional[dict[str, Any]] = None,
149
- ) -> None:
150
- """Persist model checkpoint if criteria are met.
151
-
152
- Args:
153
- epoch: Current epoch number.
154
- logs: Metrics dictionary from trainer.
155
- """
156
- logs = logs or {}
157
- model = logs.get("model")
158
- if model is None:
159
- return
160
-
161
- current = logs.get(self.monitor)
162
- if current is None:
163
- return
164
-
165
- if not self.save_best_only:
166
- torch.save(model.state_dict(), self.filepath)
167
- return
168
-
169
- if float(current) < self.best:
170
- self.best = float(current)
171
- torch.save(model.state_dict(), self.filepath)
172
-
173
-
174
- class CSVLogger(Callback):
175
- """Log epoch metrics to a CSV file.
176
-
177
- Args:
178
- log_dir: Directory where CSV logs are written.
179
- filename: Name of CSV file.
180
- """
181
-
182
- def __init__(self, log_dir: str = "./logs", filename: str = "metrics.csv") -> None:
183
- self.log_dir = Path(log_dir)
184
- self.filename = filename
185
- self._initialized = False
186
- self._fieldnames: list[str] = []
187
-
188
- def on_epoch_end(
189
- self,
190
- epoch: int,
191
- logs: Optional[dict[str, Any]] = None,
192
- ) -> None:
193
- """Append epoch metrics to CSV.
194
-
195
- Args:
196
- epoch: Current epoch number.
197
- logs: Metrics dictionary from trainer.
198
- """
199
- logs = dict(logs or {})
200
- logs["epoch"] = epoch
201
-
202
- self.log_dir.mkdir(parents=True, exist_ok=True)
203
- file_path = self.log_dir / self.filename
204
-
205
- if not self._initialized:
206
- self._fieldnames = sorted(logs.keys())
207
- with file_path.open("w", newline="", encoding="utf-8") as f:
208
- writer = csv.DictWriter(f, fieldnames=self._fieldnames)
209
- writer.writeheader()
210
- writer.writerow({k: logs.get(k) for k in self._fieldnames})
211
- self._initialized = True
212
- return
213
-
214
- with file_path.open("a", newline="", encoding="utf-8") as f:
215
- writer = csv.DictWriter(f, fieldnames=self._fieldnames)
216
- writer.writerow({k: logs.get(k) for k in self._fieldnames})
@@ -1,191 +0,0 @@
1
- """Tests for callback utilities."""
2
-
3
- from __future__ import annotations
4
-
5
- import csv
6
- from pathlib import Path
7
-
8
- import pytest
9
- import torch
10
- import torch.nn as nn
11
- from torch.utils.data import DataLoader, TensorDataset
12
-
13
- from torchloop import Trainer
14
- from torchloop.callbacks import CSVLogger, Callback, EarlyStopping, ModelCheckpoint, StopTraining
15
-
16
-
17
class TinyNet(nn.Module):
    """Minimal linear classifier used by the callback tests."""

    def __init__(self) -> None:
        super().__init__()
        # Single affine layer: 10 input features -> 2 class logits.
        self.fc = nn.Linear(10, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project a (batch, 10) input to two logits per sample."""
        logits = self.fc(x)
        return logits
26
-
27
-
28
@pytest.fixture
def loaders() -> tuple[DataLoader, DataLoader]:
    """Build small train/validation DataLoaders over random tensors."""

    def _make(num_samples: int) -> DataLoader:
        # Random 10-dim features with binary labels; batch of 4 keeps tests fast.
        features = torch.randn(num_samples, 10)
        labels = torch.randint(0, 2, (num_samples,))
        return DataLoader(TensorDataset(features, labels), batch_size=4)

    return _make(12), _make(8)
39
-
40
-
41
@pytest.fixture
def trainer_parts() -> tuple[nn.Module, torch.optim.Optimizer, nn.Module]:
    """Provide the (model, optimizer, criterion) triple a Trainer needs."""
    net = TinyNet()
    return (
        net,
        torch.optim.Adam(net.parameters(), lr=1e-2),
        nn.CrossEntropyLoss(),
    )
48
-
49
-
50
def test_early_stopping_init_values() -> None:
    """Constructor should store its config and reset tracking state."""
    cb = EarlyStopping(patience=3, monitor="val_loss", min_delta=0.1)
    expected = {
        "patience": 3,
        "monitor": "val_loss",
        "min_delta": 0.1,
        "best": float("inf"),
        "counter": 0,
    }
    for attr, value in expected.items():
        assert getattr(cb, attr) == value
58
-
59
-
60
def test_early_stopping_validation_errors() -> None:
    """Invalid configuration should raise ValueError with a clear message."""
    for kwargs, message in (
        ({"patience": 0}, "patience must be >= 1"),
        ({"min_delta": -0.1}, "min_delta must be >= 0"),
    ):
        with pytest.raises(ValueError, match=message):
            EarlyStopping(**kwargs)
67
-
68
-
69
def test_early_stopping_triggers_after_patience() -> None:
    """StopTraining fires once val_loss fails to improve for `patience` epochs."""
    cb = EarlyStopping(patience=2, monitor="val_loss")
    # Epoch 0 sets the best; epoch 1 is the first non-improving epoch.
    for step, loss in enumerate([1.0, 1.1]):
        cb.on_epoch_end(epoch=step, logs={"val_loss": loss})

    # Second non-improving epoch exhausts patience.
    with pytest.raises(StopTraining):
        cb.on_epoch_end(epoch=2, logs={"val_loss": 1.2})
77
-
78
-
79
def test_model_checkpoint_save_best_only(tmp_path: Path) -> None:
    """Save checkpoint only when the monitored metric improves.

    The original assertion ``path.stat().st_mtime >= first_mtime`` was
    vacuous (mtime never decreases), so the improvement case could never
    fail. The file is instead removed between epochs, making
    ``path.exists()`` a direct signal of whether the callback re-saved.
    """
    model = TinyNet()
    path = tmp_path / "best.pt"
    callback = ModelCheckpoint(filepath=str(path), save_best_only=True)

    # First epoch establishes the best metric and must write a checkpoint.
    callback.on_epoch_end(epoch=0, logs={"model": model, "val_loss": 1.0})
    assert path.exists()

    # Worse metric: no new checkpoint expected.
    path.unlink()
    callback.on_epoch_end(epoch=1, logs={"model": model, "val_loss": 1.2})
    assert not path.exists()

    # Improved metric: checkpoint must be written again.
    callback.on_epoch_end(epoch=2, logs={"model": model, "val_loss": 0.8})
    assert path.exists()
93
-
94
-
95
def test_model_checkpoint_save_all(tmp_path: Path) -> None:
    """Always save checkpoint when save_best_only is disabled.

    Uses file existence after deletion as the save signal; the original
    ``st_mtime >= first_mtime`` comparison could never fail because
    mtime is monotonically non-decreasing.
    """
    model = TinyNet()
    path = tmp_path / "all.pt"
    callback = ModelCheckpoint(filepath=str(path), save_best_only=False)

    callback.on_epoch_end(epoch=0, logs={"model": model, "val_loss": 2.0})
    assert path.exists()

    # Even with a worse metric, the callback must re-save every epoch.
    path.unlink()
    callback.on_epoch_end(epoch=1, logs={"model": model, "val_loss": 3.0})
    assert path.exists()
105
-
106
-
107
def test_model_checkpoint_noop_when_logs_missing(tmp_path: Path) -> None:
    """Missing 'model' or monitored metric should be a silent no-op."""
    target = tmp_path / "noop.pt"
    callback = ModelCheckpoint(filepath=str(target))

    incomplete_logs = [
        {"val_loss": 1.0},     # no model to save
        {"model": TinyNet()},  # no metric to compare
    ]
    for step, logs in enumerate(incomplete_logs):
        callback.on_epoch_end(epoch=step, logs=logs)

    assert not target.exists()
115
-
116
-
117
def test_csv_logger_writes_header_and_rows(tmp_path: Path) -> None:
    """Header is written once; each epoch appends exactly one row."""
    logger = CSVLogger(log_dir=str(tmp_path), filename="metrics.csv")
    for step, (loss, val_loss) in enumerate([(1.0, 1.1), (0.9, 1.0)]):
        logger.on_epoch_end(epoch=step, logs={"loss": loss, "val_loss": val_loss})

    with (tmp_path / "metrics.csv").open("r", encoding="utf-8") as handle:
        rows = list(csv.reader(handle))

    assert len(rows) == 3  # header + two data rows
    assert "epoch" in rows[0]
129
-
130
-
131
def test_trainer_calls_custom_callback_hooks(
    trainer_parts: tuple[nn.Module, torch.optim.Optimizer, nn.Module],
    loaders: tuple[DataLoader, DataLoader],
) -> None:
    """A custom callback should see train, epoch, and batch lifecycle hooks."""

    class Recorder(Callback):
        """Collects the name of every hook invocation in order."""

        def __init__(self) -> None:
            self.events: list[str] = []

        def on_train_begin(self, logs=None) -> None:
            self.events.append("train_begin")

        def on_epoch_begin(self, epoch: int, logs=None) -> None:
            self.events.append(f"epoch_begin_{epoch}")

        def on_batch_end(self, batch: int, logs=None) -> None:
            self.events.append("batch_end")

        def on_epoch_end(self, epoch: int, logs=None) -> None:
            self.events.append(f"epoch_end_{epoch}")

        def on_train_end(self, logs=None) -> None:
            self.events.append("train_end")

    model, optimizer, criterion = trainer_parts
    train_loader, val_loader = loaders
    recorder = Recorder()
    trainer = Trainer(model, optimizer, criterion, callbacks=[recorder], device="cpu")
    trainer.fit(train_loader, val_loader, epochs=2)

    # Epoch hook names suggest 1-indexed epochs in Trainer — matches the
    # original assertions ("epoch_begin_1", "epoch_end_2").
    for expected in (
        "train_begin",
        "train_end",
        "epoch_begin_1",
        "epoch_end_2",
        "batch_end",
    ):
        assert expected in recorder.events
167
-
168
-
169
def test_trainer_handles_stop_training_callback(
    trainer_parts: tuple[nn.Module, torch.optim.Optimizer, nn.Module],
    loaders: tuple[DataLoader, DataLoader],
) -> None:
    """StopTraining raised by a callback should halt training gracefully."""

    class StopAfterFirstEpoch(Callback):
        """Raise StopTraining once the epoch counter reaches 1."""

        def on_epoch_end(self, epoch: int, logs=None) -> None:
            if epoch >= 1:
                raise StopTraining("stop")

    model, optimizer, criterion = trainer_parts
    train_loader, val_loader = loaders
    trainer = Trainer(
        model,
        optimizer,
        criterion,
        callbacks=[StopAfterFirstEpoch()],
        device="cpu",
    )

    history = trainer.fit(train_loader, val_loader, epochs=5)
    # Only one epoch's metrics should be recorded despite epochs=5.
    assert len(history["train_loss"]) == 1
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes