torchloop 0.2.3__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {torchloop-0.2.3 → torchloop-0.3.0}/PKG-INFO +10 -5
- {torchloop-0.2.3 → torchloop-0.3.0}/pyproject.toml +11 -2
- {torchloop-0.2.3 → torchloop-0.3.0}/src/torchloop/__init__.py +4 -12
- torchloop-0.3.0/src/torchloop/callbacks/__init__.py +7 -0
- torchloop-0.3.0/src/torchloop/callbacks/base.py +19 -0
- torchloop-0.3.0/src/torchloop/callbacks/mlflow_logger.py +72 -0
- torchloop-0.3.0/src/torchloop/callbacks/wandb_logger.py +64 -0
- {torchloop-0.2.3 → torchloop-0.3.0}/src/torchloop/edge/deploy.py +14 -4
- {torchloop-0.2.3 → torchloop-0.3.0}/src/torchloop/edge/estimate.py +19 -4
- {torchloop-0.2.3 → torchloop-0.3.0}/src/torchloop/trainer.py +12 -18
- torchloop-0.3.0/tests/test_callbacks.py +98 -0
- {torchloop-0.2.3 → torchloop-0.3.0}/tests/test_trainer.py +12 -15
- torchloop-0.2.3/src/torchloop/callbacks.py +0 -216
- torchloop-0.2.3/tests/test_callbacks.py +0 -191
- {torchloop-0.2.3 → torchloop-0.3.0}/.github/workflows/ci.yml +0 -0
- {torchloop-0.2.3 → torchloop-0.3.0}/.github/workflows/publish.yml +0 -0
- {torchloop-0.2.3 → torchloop-0.3.0}/.gitignore +0 -0
- {torchloop-0.2.3 → torchloop-0.3.0}/CHANGELOG.md +0 -0
- {torchloop-0.2.3 → torchloop-0.3.0}/LICENSE +0 -0
- {torchloop-0.2.3 → torchloop-0.3.0}/README.md +0 -0
- {torchloop-0.2.3 → torchloop-0.3.0}/examples/edge_deployment.py +0 -0
- {torchloop-0.2.3 → torchloop-0.3.0}/requirements.txt +0 -0
- {torchloop-0.2.3 → torchloop-0.3.0}/src/torchloop/edge/__init__.py +0 -0
- {torchloop-0.2.3 → torchloop-0.3.0}/src/torchloop/evaluator.py +0 -0
- {torchloop-0.2.3 → torchloop-0.3.0}/src/torchloop/exporter.py +0 -0
- {torchloop-0.2.3 → torchloop-0.3.0}/tests/__init__.py +0 -0
- {torchloop-0.2.3 → torchloop-0.3.0}/tests/conftest.py +0 -0
- {torchloop-0.2.3 → torchloop-0.3.0}/tests/test_edge.py +0 -0
- {torchloop-0.2.3 → torchloop-0.3.0}/tests/test_evaluator.py +0 -0
- {torchloop-0.2.3 → torchloop-0.3.0}/tests/test_exporter.py +0 -0
{torchloop-0.2.3 → torchloop-0.3.0}/PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torchloop
-Version: 0.2.3
+Version: 0.3.0
 Summary: Lightweight PyTorch utility library for training, evaluation, and TFLite export — without the framework lock-in.
 Project-URL: Homepage, https://github.com/Tharun007-TK/torchloop
 Project-URL: Repository, https://github.com/Tharun007-TK/torchloop
@@ -47,19 +47,17 @@ Requires-Dist: torchvision>=0.15.0
 Requires-Dist: tqdm>=4.65.0
 Provides-Extra: all
 Requires-Dist: black>=23.0; extra == 'all'
-Requires-Dist: coremltools>=7.0; (platform_system == 'Darwin') and extra == 'all'
 Requires-Dist: hatch>=1.7.0; extra == 'all'
+Requires-Dist: mlflow>=2.0.0; extra == 'all'
 Requires-Dist: mypy>=1.0; extra == 'all'
 Requires-Dist: onnx>=1.14.0; extra == 'all'
-Requires-Dist: onnx>=1.15; extra == 'all'
 Requires-Dist: onnxruntime>=1.15.0; extra == 'all'
-Requires-Dist: onnxruntime>=1.16; extra == 'all'
 Requires-Dist: pytest-cov>=4.0; extra == 'all'
 Requires-Dist: pytest>=7.0; extra == 'all'
 Requires-Dist: ruff>=0.1.0; extra == 'all'
-Requires-Dist: tensorflow-lite-runtime>=2.15; (platform_system != 'Darwin') and extra == 'all'
 Requires-Dist: tensorflow>=2.13.0; extra == 'all'
 Requires-Dist: torchinfo>=1.8; extra == 'all'
+Requires-Dist: wandb>=0.15.0; extra == 'all'
 Provides-Extra: dev
 Requires-Dist: black>=23.0; extra == 'dev'
 Requires-Dist: hatch>=1.7.0; extra == 'dev'
@@ -68,6 +66,10 @@ Requires-Dist: pytest-cov>=4.0; extra == 'dev'
 Requires-Dist: pytest>=7.0; extra == 'dev'
 Requires-Dist: ruff>=0.1.0; extra == 'dev'
 Requires-Dist: torchinfo>=1.8; extra == 'dev'
+Provides-Extra: docs
+Requires-Dist: mkdocs-material>=9.0.0; extra == 'docs'
+Requires-Dist: mkdocs>=1.5.0; extra == 'docs'
+Requires-Dist: mkdocstrings[python]>=0.24.0; extra == 'docs'
 Provides-Extra: edge
 Requires-Dist: coremltools>=7.0; (platform_system == 'Darwin') and extra == 'edge'
 Requires-Dist: onnx>=1.15; extra == 'edge'
@@ -77,6 +79,9 @@ Provides-Extra: export
 Requires-Dist: onnx>=1.14.0; extra == 'export'
 Requires-Dist: onnxruntime>=1.15.0; extra == 'export'
 Requires-Dist: tensorflow>=2.13.0; extra == 'export'
+Provides-Extra: logging
+Requires-Dist: mlflow>=2.0.0; extra == 'logging'
+Requires-Dist: wandb>=0.15.0; extra == 'logging'
 Description-Content-Type: text/markdown
 
 # torchloop
```
{torchloop-0.2.3 → torchloop-0.3.0}/pyproject.toml

```diff
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "torchloop"
-version = "0.2.3"
+version = "0.3.0"
 description = "Lightweight PyTorch utility library for training, evaluation, and TFLite export — without the framework lock-in."
 readme = "README.md"
 license = { file = "LICENSE" }
@@ -41,6 +41,15 @@ edge = [
     "tensorflow-lite-runtime>=2.15; platform_system != 'Darwin'",
     "coremltools>=7.0; platform_system == 'Darwin'",
 ]
+logging = [
+    "wandb>=0.15.0",
+    "mlflow>=2.0.0",
+]
+docs = [
+    "mkdocs>=1.5.0",
+    "mkdocs-material>=9.0.0",
+    "mkdocstrings[python]>=0.24.0",
+]
 export = [
     "onnx>=1.14.0",
     "onnxruntime>=1.15.0",
@@ -55,7 +64,7 @@ dev = [
     "ruff>=0.1.0", # linter + formatter
     "hatch>=1.7.0",
 ]
-all = ["torchloop[export,dev,edge]"]
+all = ["torchloop[export,dev,logging]"]
 
 [project.urls]
 Homepage = "https://github.com/Tharun007-TK/torchloop"
```
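The dependency changes track the callback rework: experiment-tracking requirements move into a new `logging` extra (`pip install "torchloop[logging]"` pulls in `wandb>=0.15.0` and `mlflow>=2.0.0`), a `docs` extra adds the MkDocs toolchain, and the `all` meta-extra now composes `export`, `dev`, and `logging` rather than the edge dependencies, which remain available separately through `torchloop[edge]`.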
{torchloop-0.2.3 → torchloop-0.3.0}/src/torchloop/__init__.py

```diff
@@ -7,16 +7,10 @@ Modules:
     exporter : PyTorch → ONNX → TFLite with optional quantization
 """
 
-__version__ = "0.2.3"
+__version__ = "0.3.0"
 __author__ = "Tharun Kumar"
 
-from torchloop.callbacks import (
-    CSVLogger,
-    Callback,
-    EarlyStopping,
-    ModelCheckpoint,
-    StopTraining,
-)
+from torchloop.callbacks import Callback, MLflowLogger, WandBLogger
 from torchloop.edge import deploy_to_edge, estimate_model
 from torchloop.evaluator import Evaluator
 from torchloop.exporter import Exporter
@@ -27,10 +21,8 @@ __all__ = [
     "Evaluator",
     "Exporter",
     "Callback",
-    "EarlyStopping",
-    "ModelCheckpoint",
-    "CSVLogger",
-    "StopTraining",
+    "WandBLogger",
+    "MLflowLogger",
     "deploy_to_edge",
     "estimate_model",
     "__version__",
```
torchloop-0.3.0/src/torchloop/callbacks/__init__.py (new file)

```diff
@@ -0,0 +1,7 @@
+"""Callback integrations for torchloop."""
+
+from torchloop.callbacks.base import Callback
+from torchloop.callbacks.mlflow_logger import MLflowLogger
+from torchloop.callbacks.wandb_logger import WandBLogger
+
+__all__ = ["Callback", "WandBLogger", "MLflowLogger"]
```
torchloop-0.3.0/src/torchloop/callbacks/base.py (new file)

```diff
@@ -0,0 +1,19 @@
+"""Base callback abstractions for torchloop training events."""
+
+from __future__ import annotations
+
+
+class Callback:
+    """Base callback with optional training lifecycle hooks.
+
+    Subclasses can override any hook they need.
+    """
+
+    def on_train_begin(self, logs: dict) -> None:
+        """Run before training starts."""
+
+    def on_epoch_end(self, epoch: int, logs: dict) -> None:
+        """Run at the end of each epoch."""
+
+    def on_train_end(self, logs: dict) -> None:
+        """Run after training finishes."""
```
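Compared with 0.2.3, the base class trims the hook surface to three events: `on_train_begin`, `on_epoch_end`, and `on_train_end` (`on_epoch_begin` and `on_batch_end` are dropped; see the trainer diff below). A minimal sketch of a custom callback against this interface; the class name is hypothetical and the `val_loss` key is an assumption taken from the `epoch_logs` dict built in trainer.py:

```python
from torchloop.callbacks import Callback


class BestValLossTracker(Callback):
    """Track the best validation loss across epochs (illustrative)."""

    def __init__(self) -> None:
        self.best: float = float("inf")
        self.best_epoch: int = -1

    def on_epoch_end(self, epoch: int, logs: dict) -> None:
        # "val_loss" is one of the keys the trainer places in epoch_logs
        # (see the trainer.py hunks below).
        val_loss = logs.get("val_loss")
        if val_loss is not None and val_loss < self.best:
            self.best = float(val_loss)
            self.best_epoch = epoch

    def on_train_end(self, logs: dict) -> None:
        print(f"Best val_loss {self.best:.4f} at epoch {self.best_epoch}")
```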
torchloop-0.3.0/src/torchloop/callbacks/mlflow_logger.py (new file)

```diff
@@ -0,0 +1,72 @@
+"""MLflow callback integration for torchloop."""
+
+from __future__ import annotations
+
+from importlib import import_module
+from typing import Optional
+
+from torchloop.callbacks.base import Callback
+
+
+class MLflowLogger(Callback):
+    """Log training metrics to MLflow.
+
+    Args:
+        experiment_name: MLflow experiment name.
+        tracking_uri: Optional tracking server URI.
+        run_name: Optional MLflow run name.
+    """
+
+    def __init__(
+        self,
+        experiment_name: str,
+        tracking_uri: Optional[str] = None,
+        run_name: Optional[str] = None,
+    ) -> None:
+        self.experiment_name = experiment_name
+        self.tracking_uri = tracking_uri
+        self.run_name = run_name
+
+    def on_train_begin(self, logs: dict) -> None:
+        """Initialize MLflow experiment and run."""
+        try:
+            mlflow = import_module("mlflow")
+        except ImportError as exc:
+            raise ImportError(
+                "mlflow is required for MLflowLogger. "
+                "Install with: pip install torchloop[logging]"
+            ) from exc
+
+        if self.tracking_uri:
+            mlflow.set_tracking_uri(self.tracking_uri)
+        mlflow.set_experiment(self.experiment_name)
+        mlflow.start_run(run_name=self.run_name)
+
+    def on_epoch_end(self, epoch: int, logs: dict) -> None:
+        """Log epoch metrics to MLflow."""
+        try:
+            mlflow = import_module("mlflow")
+        except ImportError as exc:
+            raise ImportError(
+                "mlflow is required for MLflowLogger. "
+                "Install with: pip install torchloop[logging]"
+            ) from exc
+
+        numeric_logs = {
+            key: value
+            for key, value in logs.items()
+            if isinstance(value, (int, float))
+        }
+        mlflow.log_metrics(numeric_logs, step=epoch)
+
+    def on_train_end(self, logs: dict) -> None:
+        """End the active MLflow run."""
+        try:
+            mlflow = import_module("mlflow")
+        except ImportError as exc:
+            raise ImportError(
+                "mlflow is required for MLflowLogger. "
+                "Install with: pip install torchloop[logging]"
+            ) from exc
+
+        mlflow.end_run()
```
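Note that the logger resolves mlflow with `import_module` inside each hook rather than at module import time, so constructing the callback never touches the optional dependency; it is only needed once a hook fires. A small sketch of that behavior, simulating a missing package the same way the new tests do:

```python
import sys

from torchloop.callbacks import MLflowLogger

logger = MLflowLogger(experiment_name="demo")  # fine without mlflow installed

sys.modules["mlflow"] = None  # simulate an absent dependency, as the tests do
try:
    logger.on_train_begin({})
except ImportError as err:
    print(err)  # "mlflow is required for MLflowLogger. Install with: ..."
```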
torchloop-0.3.0/src/torchloop/callbacks/wandb_logger.py (new file)

```diff
@@ -0,0 +1,64 @@
+"""Weights & Biases callback integration for torchloop."""
+
+from __future__ import annotations
+
+from importlib import import_module
+from typing import Any, Optional
+
+from torchloop.callbacks.base import Callback
+
+
+class WandBLogger(Callback):
+    """Log training metrics to Weights & Biases.
+
+    Args:
+        project: Weights & Biases project name.
+        name: Optional run name.
+        config: Optional run configuration dictionary.
+    """
+
+    def __init__(
+        self,
+        project: str,
+        name: Optional[str] = None,
+        config: Optional[dict[str, Any]] = None,
+    ) -> None:
+        self.project = project
+        self.name = name
+        self.config = config or {}
+
+    def on_train_begin(self, logs: dict) -> None:
+        """Initialize a W&B run."""
+        try:
+            wandb = import_module("wandb")
+        except ImportError as exc:
+            raise ImportError(
+                "wandb is required for WandBLogger. "
+                "Install with: pip install torchloop[logging]"
+            ) from exc
+
+        wandb.init(project=self.project, name=self.name, config=self.config)
+
+    def on_epoch_end(self, epoch: int, logs: dict) -> None:
+        """Log epoch metrics to W&B."""
+        try:
+            wandb = import_module("wandb")
+        except ImportError as exc:
+            raise ImportError(
+                "wandb is required for WandBLogger. "
+                "Install with: pip install torchloop[logging]"
+            ) from exc
+
+        wandb.log(logs, step=epoch)
+
+    def on_train_end(self, logs: dict) -> None:
+        """Finish the active W&B run."""
+        try:
+            wandb = import_module("wandb")
+        except ImportError as exc:
+            raise ImportError(
+                "wandb is required for WandBLogger. "
+                "Install with: pip install torchloop[logging]"
+            ) from exc
+
+        wandb.finish()
```
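A minimal end-to-end sketch wiring both loggers into the existing `Trainer`; the constructor keywords and the `fit(train_loader, val_loader, epochs=...)` call follow the tests in this diff, while the project and experiment names are placeholders (running this requires the `logging` extra):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from torchloop import Trainer
from torchloop.callbacks import MLflowLogger, WandBLogger

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))
data = TensorDataset(torch.randn(64, 8), torch.randint(0, 3, (64,)))
loader = DataLoader(data, batch_size=8)

trainer = Trainer(
    model=model,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    criterion=nn.CrossEntropyLoss(),
    device="cpu",
    callbacks=[
        WandBLogger(project="torchloop-demo"),           # placeholder project
        MLflowLogger(experiment_name="torchloop-demo"),  # placeholder name
    ],
)
history = trainer.fit(loader, loader, epochs=3)  # requires torchloop[logging]
```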
{torchloop-0.2.3 → torchloop-0.3.0}/src/torchloop/edge/deploy.py

```diff
@@ -61,7 +61,10 @@ def deploy_to_edge(
     artifact_format = "onnx"
 
     if target in {"esp32", "android"}:
-        artifact_path = output if output.suffix == ".tflite" else output.with_suffix(".tflite")
+        artifact_path = (
+            output if output.suffix == ".tflite"
+            else output.with_suffix(".tflite")
+        )
         if target == "esp32" and quantize_type != "int8":
             warnings.warn(
                 "ESP32 is typically best with int8 quantization.",
@@ -79,7 +82,10 @@ def deploy_to_edge(
         _convert_to_tflite(onnx_path, artifact_path, quantize, quantize_type)
         artifact_format = "tflite"
     elif target == "ios":
-        artifact_path = output if output.suffix == ".mlpackage" else output.with_suffix(".mlpackage")
+        artifact_path = (
+            output if output.suffix == ".mlpackage"
+            else output.with_suffix(".mlpackage")
+        )
         _convert_to_coreml(model, input_shape, artifact_path)
         artifact_format = "coreml"
     elif target == "jetson":
@@ -109,7 +115,9 @@ def _export_to_onnx(
         onnx_name = "".join(["on", "nx"])
         onnx = import_module(onnx_name)
     except ImportError as exc:
-        raise ImportError("onnx is required. Install with: pip install torchloop[edge]") from exc
+        raise ImportError(
+            "onnx is required. Install with: pip install torchloop[edge]"
+        ) from exc
 
     model.eval()
     device = _get_model_device(model)
@@ -184,7 +192,9 @@ def _convert_to_coreml(
         coremltools_name = "".join(["coreml", "tools"])
         ct = import_module(coremltools_name)
     except ImportError as exc:
-        raise ImportError("coremltools is required. Install with: pip install torchloop[edge]") from exc
+        raise ImportError(
+            "coremltools is required. Install with: pip install torchloop[edge]"
+        ) from exc
 
     model_cpu = model.to("cpu").eval()
     traced = torch.jit.trace(model_cpu, torch.randn(*input_shape))
```
{torchloop-0.2.3 → torchloop-0.3.0}/src/torchloop/edge/estimate.py

```diff
@@ -63,19 +63,34 @@ def _estimate_activation_mb(input_shape: tuple[int, ...]) -> float:
     return (numel * 4 * 2) / (1024 * 1024)
 
 
-def _estimate_flops(model: torch.nn.Module, input_shape: tuple[int, ...]) -> int:
+def _estimate_flops(
+    model: torch.nn.Module,
+    input_shape: tuple[int, ...],
+) -> int:
     """Estimate FLOPs using hooks for common modules (Conv2d, Linear)."""
     flops = 0
    hooks = []
 
-    def conv_hook(module: torch.nn.Conv2d, _inp: tuple[torch.Tensor], out: torch.Tensor) -> None:
+    def conv_hook(
+        module: torch.nn.Conv2d,
+        _inp: tuple[torch.Tensor],
+        out: torch.Tensor,
+    ) -> None:
         nonlocal flops
         batch = out.shape[0]
         out_h, out_w = out.shape[2], out.shape[3]
-        kernel_ops = module.kernel_size[0] * module.kernel_size[1] * (module.in_channels / module.groups)
+        kernel_ops = (
+            module.kernel_size[0]
+            * module.kernel_size[1]
+            * (module.in_channels / module.groups)
+        )
         flops += int(batch * out_h * out_w * module.out_channels * kernel_ops * 2)
 
-    def linear_hook(module: torch.nn.Linear, inp: tuple[torch.Tensor], _out: torch.Tensor) -> None:
+    def linear_hook(
+        module: torch.nn.Linear,
+        inp: tuple[torch.Tensor],
+        _out: torch.Tensor,
+    ) -> None:
         nonlocal flops
         batch = inp[0].shape[0]
         flops += int(batch * module.in_features * module.out_features * 2)
```
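These hunks only re-wrap long lines; the formulas are unchanged: a Conv2d contributes `batch * out_h * out_w * out_channels * (k_h * k_w * in_channels / groups) * 2` FLOPs (the factor of 2 counts the multiply and the add), and a Linear contributes `batch * in_features * out_features * 2`. A quick standalone check of the conv formula against a concrete layer:

```python
import torch
import torch.nn as nn

# A 3x3 conv, 16 -> 32 channels, on a 1x16x8x8 input (padding keeps 8x8 output).
conv = nn.Conv2d(16, 32, kernel_size=3, padding=1)
out = conv(torch.randn(1, 16, 8, 8))

batch, out_h, out_w = out.shape[0], out.shape[2], out.shape[3]
kernel_ops = 3 * 3 * (16 / conv.groups)  # k_h * k_w * in_channels / groups
flops = int(batch * out_h * out_w * 32 * kernel_ops * 2)
print(flops)  # 1 * 8 * 8 * 32 * 144 * 2 = 589824
```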
{torchloop-0.2.3 → torchloop-0.3.0}/src/torchloop/trainer.py

```diff
@@ -32,7 +32,7 @@ import torch.nn as nn
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
-from torchloop.callbacks import Callback
+from torchloop.callbacks import Callback
 
 
 class Trainer:
@@ -129,10 +129,9 @@ class Trainer:
         Returns:
             history dict with train_loss, val_loss, val_metric, lr per epoch.
         """
-        self._trigger_hook("on_train_begin", logs=self.history)
+        self._run_callbacks("on_train_begin", dict(self.history))
         for epoch in range(1, epochs + 1):
             t0 = time.time()
-            self._trigger_hook("on_epoch_begin", epoch=epoch, logs=self.history)
             train_loss = self._train_epoch(train_loader)
             self.history["train_loss"].append(train_loss)
 
@@ -153,13 +152,13 @@ class Trainer:
             )
 
             epoch_logs = {
+                "epoch": epoch,
                 "train_loss": train_loss,
                 "val_loss": val_loss,
                 "val_metric": val_metric,
                 "lr": current_lr,
-                "model": self.model,
             }
-            self._trigger_hook("on_epoch_end", epoch=epoch, logs=epoch_logs)
+            self._run_callbacks("on_epoch_end", epoch_logs)
 
             if self._should_stop():
                 print(f" Early stopping triggered at epoch {epoch}.")
@@ -169,7 +168,7 @@ class Trainer:
             self.model.load_state_dict(self._best_state)
             print(" Restored best model weights.")
 
-        self._trigger_hook("on_train_end", logs=self.history)
+        self._run_callbacks("on_train_end", dict(self.history))
         return self.history
 
     def add_callback(self, callback: Callback) -> None:
@@ -222,11 +221,6 @@ class Trainer:
                 pending_steps = 0
 
             total_loss += raw_loss.item() * inputs.size(0)
-            self._trigger_hook(
-                "on_batch_end",
-                batch=batch_idx,
-                logs={"loss": float(raw_loss.item())},
-            )
 
         if pending_steps > 0:
             self._optimizer_step()
@@ -296,15 +290,15 @@ class Trainer:
             )
         )
 
-    def _trigger_hook(self, name: str, **kwargs: Any) -> None:
-        for callback in self.callbacks:
-            hook = getattr(callback, name, None)
+    def _run_callbacks(self, event: str, logs: dict[str, Any]) -> None:
+        for callback in (self.callbacks or []):
+            hook = getattr(callback, event, None)
             if hook is None:
                 continue
-
-            hook(
-                **kwargs,
-            )
+            if event == "on_epoch_end":
+                hook(int(logs.get("epoch", 0)), logs)
+            else:
+                hook(logs)
 
     @staticmethod
     def _log(
```
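The kwargs-based `_trigger_hook` gives way to `_run_callbacks` with a fixed contract: the epoch number travels inside the logs under the `"epoch"` key, `on_epoch_end` hooks are invoked as `hook(epoch, logs)`, every other event gets `hook(logs)`, missing hooks are skipped, and the live model object is no longer passed through the logs. A condensed, self-contained sketch of that dispatch (the stub callback is illustrative):

```python
from typing import Any


class _PrintingCallback:
    """Stub illustrating the calling convention of Trainer._run_callbacks."""

    def on_train_begin(self, logs: dict) -> None:
        print("train begin:", sorted(logs))

    def on_epoch_end(self, epoch: int, logs: dict) -> None:
        print(f"epoch {epoch}:", logs["train_loss"])


def run_callbacks(callbacks: list[Any], event: str, logs: dict) -> None:
    # Mirrors the new dispatch: epoch-end hooks get (epoch, logs),
    # everything else gets (logs); callbacks lacking the hook are skipped.
    for callback in callbacks or []:
        hook = getattr(callback, event, None)
        if hook is None:
            continue
        if event == "on_epoch_end":
            hook(int(logs.get("epoch", 0)), logs)
        else:
            hook(logs)


run_callbacks([_PrintingCallback()], "on_epoch_end", {"epoch": 1, "train_loss": 0.42})
```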
torchloop-0.3.0/tests/test_callbacks.py (new file)

```diff
@@ -0,0 +1,98 @@
+"""Tests for torchloop callback integrations."""
+
+from __future__ import annotations
+
+import sys
+
+import pytest
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader, TensorDataset
+
+from torchloop import Trainer
+from torchloop.callbacks import Callback, MLflowLogger, WandBLogger
+
+
+def _make_loader(n: int = 32, features: int = 8, classes: int = 3) -> DataLoader:
+    x_data = torch.randn(n, features)
+    y_data = torch.randint(0, classes, (n,))
+    dataset = TensorDataset(x_data, y_data)
+    return DataLoader(dataset, batch_size=8)
+
+
+def _make_model(features: int = 8, classes: int = 3) -> nn.Module:
+    return nn.Sequential(nn.Linear(features, 16), nn.ReLU(), nn.Linear(16, classes))
+
+
+def test_callback_base_methods_are_noop() -> None:
+    """Ensure base callback methods are safe no-op defaults."""
+    callback = Callback()
+    callback.on_train_begin({"start": True})
+    callback.on_epoch_end(1, {"loss": 1.0})
+    callback.on_train_end({"done": True})
+
+
+def test_wandb_logger_raises_if_not_installed(monkeypatch: pytest.MonkeyPatch) -> None:
+    """Raise ImportError when wandb is unavailable at runtime."""
+    monkeypatch.setitem(sys.modules, "wandb", None)
+    logger = WandBLogger(project="proj")
+
+    with pytest.raises(ImportError, match="wandb is required"):
+        logger.on_train_begin({"train": True})
+
+
+
+def test_mlflow_logger_raises_if_not_installed(monkeypatch: pytest.MonkeyPatch) -> None:
+    """Raise ImportError when mlflow is unavailable at runtime."""
+    monkeypatch.setitem(sys.modules, "mlflow", None)
+    logger = MLflowLogger(experiment_name="exp")
+
+    with pytest.raises(ImportError, match="mlflow is required"):
+        logger.on_train_begin({"train": True})
+
+
+
+def test_trainer_accepts_empty_callbacks_list() -> None:
+    """Allow explicit empty callback lists during fit execution."""
+    model = _make_model()
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+    criterion = nn.CrossEntropyLoss()
+
+    trainer = Trainer(
+        model=model,
+        optimizer=optimizer,
+        criterion=criterion,
+        device="cpu",
+        callbacks=[],
+    )
+
+    history = trainer.fit(_make_loader(), _make_loader(), epochs=2)
+    assert len(history["train_loss"]) == 2
+
+
+
+def test_trainer_calls_on_epoch_end() -> None:
+    """Invoke callback epoch-end hook once per epoch with epoch indices."""
+
+    class TestCallback(Callback):
+        def __init__(self) -> None:
+            self.epochs: list[int] = []
+
+        def on_epoch_end(self, epoch: int, logs: dict) -> None:
+            self.epochs.append(epoch)
+
+    callback = TestCallback()
+    model = _make_model()
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+    criterion = nn.CrossEntropyLoss()
+
+    trainer = Trainer(
+        model=model,
+        optimizer=optimizer,
+        criterion=criterion,
+        device="cpu",
+        callbacks=[callback],
+    )
+    trainer.fit(_make_loader(), _make_loader(), epochs=3)
+
+    assert callback.epochs == [1, 2, 3]
```
{torchloop-0.2.3 → torchloop-0.3.0}/tests/test_trainer.py

```diff
@@ -4,7 +4,7 @@ import pytest
 from torch.utils.data import DataLoader, TensorDataset
 
 from torchloop import Trainer
-from torchloop.callbacks import Callback
+from torchloop.callbacks import Callback
 
 
 def _make_loader(n=64, features=16, classes=3, batch=16):
@@ -140,20 +140,12 @@ def test_trainer_callbacks_triggered():
     class Recorder(Callback):
         def __init__(self):
             self.train_begin = 0
-            self.epoch_begin = 0
-            self.batch_end = 0
             self.epoch_end = 0
             self.train_end = 0
 
         def on_train_begin(self, logs=None):
             self.train_begin += 1
 
-        def on_epoch_begin(self, epoch, logs=None):
-            self.epoch_begin += 1
-
-        def on_batch_end(self, batch, logs=None):
-            self.batch_end += 1
-
         def on_epoch_end(self, epoch, logs=None):
             self.epoch_end += 1
 
@@ -175,17 +167,22 @@ def test_trainer_callbacks_triggered():
     trainer.fit(_make_loader(n=32, batch=8), epochs=2)
 
     assert recorder.train_begin == 1
-    assert recorder.epoch_begin == 2
-    assert recorder.batch_end == 8
     assert recorder.epoch_end == 2
     assert recorder.train_end == 1
 
 
-def test_early_stopping_callback_stops_training():
+def test_callback_receives_epoch_end_for_each_epoch():
+    class EpochRecorder(Callback):
+        def __init__(self):
+            self.epochs = []
+
+        def on_epoch_end(self, epoch, logs=None):
+            self.epochs.append(epoch)
+
     model = _make_model()
     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
     criterion = nn.CrossEntropyLoss()
-    callback = EarlyStopping(...)
+    callback = EpochRecorder()
 
     trainer = Trainer(
         model,
@@ -194,8 +191,8 @@ def test_early_stopping_callback_stops_training():
         device="cpu",
         callbacks=[callback],
     )
-
-    assert ...
+    trainer.fit(_make_loader(), _make_loader(), epochs=3)
+    assert callback.epochs == [1, 2, 3]
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
```
@@ -1,216 +0,0 @@
|
|
|
1
|
-
"""Callback utilities for training lifecycle hooks."""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
4
|
-
|
|
5
|
-
import csv
|
|
6
|
-
from pathlib import Path
|
|
7
|
-
from typing import Any, Optional
|
|
8
|
-
|
|
9
|
-
import torch
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class StopTraining(Exception):
|
|
13
|
-
"""Raised by callbacks to request early stop of training."""
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
class Callback:
|
|
17
|
-
"""Base callback class.
|
|
18
|
-
|
|
19
|
-
Subclass this and override hook methods to customize training behavior.
|
|
20
|
-
"""
|
|
21
|
-
|
|
22
|
-
def on_train_begin(self, logs: Optional[dict[str, Any]] = None) -> None:
|
|
23
|
-
"""Called before the first epoch starts.
|
|
24
|
-
|
|
25
|
-
Args:
|
|
26
|
-
logs: Optional training state dictionary.
|
|
27
|
-
"""
|
|
28
|
-
|
|
29
|
-
def on_train_end(self, logs: Optional[dict[str, Any]] = None) -> None:
|
|
30
|
-
"""Called after training ends.
|
|
31
|
-
|
|
32
|
-
Args:
|
|
33
|
-
logs: Optional training state dictionary.
|
|
34
|
-
"""
|
|
35
|
-
|
|
36
|
-
def on_epoch_begin(
|
|
37
|
-
self,
|
|
38
|
-
epoch: int,
|
|
39
|
-
logs: Optional[dict[str, Any]] = None,
|
|
40
|
-
) -> None:
|
|
41
|
-
"""Called at the start of each epoch.
|
|
42
|
-
|
|
43
|
-
Args:
|
|
44
|
-
epoch: Current epoch number.
|
|
45
|
-
logs: Optional training state dictionary.
|
|
46
|
-
"""
|
|
47
|
-
|
|
48
|
-
def on_epoch_end(
|
|
49
|
-
self,
|
|
50
|
-
epoch: int,
|
|
51
|
-
logs: Optional[dict[str, Any]] = None,
|
|
52
|
-
) -> None:
|
|
53
|
-
"""Called at the end of each epoch.
|
|
54
|
-
|
|
55
|
-
Args:
|
|
56
|
-
epoch: Current epoch number.
|
|
57
|
-
logs: Optional training state dictionary.
|
|
58
|
-
"""
|
|
59
|
-
|
|
60
|
-
def on_batch_end(
|
|
61
|
-
self,
|
|
62
|
-
batch: int,
|
|
63
|
-
logs: Optional[dict[str, Any]] = None,
|
|
64
|
-
) -> None:
|
|
65
|
-
"""Called after each training batch.
|
|
66
|
-
|
|
67
|
-
Args:
|
|
68
|
-
batch: Zero-based batch index.
|
|
69
|
-
logs: Optional batch-level metrics.
|
|
70
|
-
"""
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
class EarlyStopping(Callback):
|
|
74
|
-
"""Stop training when a monitored metric stops improving.
|
|
75
|
-
|
|
76
|
-
Args:
|
|
77
|
-
patience: Number of epochs to wait for an improvement.
|
|
78
|
-
monitor: Metric name expected in logs.
|
|
79
|
-
min_delta: Minimum improvement required to reset patience.
|
|
80
|
-
"""
|
|
81
|
-
|
|
82
|
-
def __init__(
|
|
83
|
-
self,
|
|
84
|
-
patience: int = 5,
|
|
85
|
-
monitor: str = "val_loss",
|
|
86
|
-
min_delta: float = 0.0,
|
|
87
|
-
) -> None:
|
|
88
|
-
if patience < 1:
|
|
89
|
-
raise ValueError("patience must be >= 1")
|
|
90
|
-
if min_delta < 0:
|
|
91
|
-
raise ValueError("min_delta must be >= 0")
|
|
92
|
-
|
|
93
|
-
self.patience = patience
|
|
94
|
-
self.monitor = monitor
|
|
95
|
-
self.min_delta = min_delta
|
|
96
|
-
self.best = float("inf")
|
|
97
|
-
self.counter = 0
|
|
98
|
-
|
|
99
|
-
def on_epoch_end(
|
|
100
|
-
self,
|
|
101
|
-
epoch: int,
|
|
102
|
-
logs: Optional[dict[str, Any]] = None,
|
|
103
|
-
) -> None:
|
|
104
|
-
"""Check metric value and raise StopTraining when patience is exceeded.
|
|
105
|
-
|
|
106
|
-
Args:
|
|
107
|
-
epoch: Current epoch number.
|
|
108
|
-
logs: Metrics dictionary from trainer.
|
|
109
|
-
"""
|
|
110
|
-
logs = logs or {}
|
|
111
|
-
current = logs.get(self.monitor)
|
|
112
|
-
if current is None:
|
|
113
|
-
return
|
|
114
|
-
|
|
115
|
-
if current <= self.best - self.min_delta:
|
|
116
|
-
self.best = float(current)
|
|
117
|
-
self.counter = 0
|
|
118
|
-
return
|
|
119
|
-
|
|
120
|
-
self.counter += 1
|
|
121
|
-
if self.counter >= self.patience:
|
|
122
|
-
raise StopTraining(f"Early stopping triggered at epoch {epoch}")
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
class ModelCheckpoint(Callback):
|
|
126
|
-
"""Save model checkpoints based on monitored metric.
|
|
127
|
-
|
|
128
|
-
Args:
|
|
129
|
-
monitor: Metric name expected in logs.
|
|
130
|
-
save_best_only: Save only when monitored metric improves.
|
|
131
|
-
filepath: Path to write checkpoint file.
|
|
132
|
-
"""
|
|
133
|
-
|
|
134
|
-
def __init__(
|
|
135
|
-
self,
|
|
136
|
-
monitor: str = "val_loss",
|
|
137
|
-
save_best_only: bool = True,
|
|
138
|
-
filepath: str = "checkpoint.pt",
|
|
139
|
-
) -> None:
|
|
140
|
-
self.monitor = monitor
|
|
141
|
-
self.save_best_only = save_best_only
|
|
142
|
-
self.filepath = Path(filepath)
|
|
143
|
-
self.best = float("inf")
|
|
144
|
-
|
|
145
|
-
def on_epoch_end(
|
|
146
|
-
self,
|
|
147
|
-
epoch: int,
|
|
148
|
-
logs: Optional[dict[str, Any]] = None,
|
|
149
|
-
) -> None:
|
|
150
|
-
"""Persist model checkpoint if criteria are met.
|
|
151
|
-
|
|
152
|
-
Args:
|
|
153
|
-
epoch: Current epoch number.
|
|
154
|
-
logs: Metrics dictionary from trainer.
|
|
155
|
-
"""
|
|
156
|
-
logs = logs or {}
|
|
157
|
-
model = logs.get("model")
|
|
158
|
-
if model is None:
|
|
159
|
-
return
|
|
160
|
-
|
|
161
|
-
current = logs.get(self.monitor)
|
|
162
|
-
if current is None:
|
|
163
|
-
return
|
|
164
|
-
|
|
165
|
-
if not self.save_best_only:
|
|
166
|
-
torch.save(model.state_dict(), self.filepath)
|
|
167
|
-
return
|
|
168
|
-
|
|
169
|
-
if float(current) < self.best:
|
|
170
|
-
self.best = float(current)
|
|
171
|
-
torch.save(model.state_dict(), self.filepath)
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
class CSVLogger(Callback):
|
|
175
|
-
"""Log epoch metrics to a CSV file.
|
|
176
|
-
|
|
177
|
-
Args:
|
|
178
|
-
log_dir: Directory where CSV logs are written.
|
|
179
|
-
filename: Name of CSV file.
|
|
180
|
-
"""
|
|
181
|
-
|
|
182
|
-
def __init__(self, log_dir: str = "./logs", filename: str = "metrics.csv") -> None:
|
|
183
|
-
self.log_dir = Path(log_dir)
|
|
184
|
-
self.filename = filename
|
|
185
|
-
self._initialized = False
|
|
186
|
-
self._fieldnames: list[str] = []
|
|
187
|
-
|
|
188
|
-
def on_epoch_end(
|
|
189
|
-
self,
|
|
190
|
-
epoch: int,
|
|
191
|
-
logs: Optional[dict[str, Any]] = None,
|
|
192
|
-
) -> None:
|
|
193
|
-
"""Append epoch metrics to CSV.
|
|
194
|
-
|
|
195
|
-
Args:
|
|
196
|
-
epoch: Current epoch number.
|
|
197
|
-
logs: Metrics dictionary from trainer.
|
|
198
|
-
"""
|
|
199
|
-
logs = dict(logs or {})
|
|
200
|
-
logs["epoch"] = epoch
|
|
201
|
-
|
|
202
|
-
self.log_dir.mkdir(parents=True, exist_ok=True)
|
|
203
|
-
file_path = self.log_dir / self.filename
|
|
204
|
-
|
|
205
|
-
if not self._initialized:
|
|
206
|
-
self._fieldnames = sorted(logs.keys())
|
|
207
|
-
with file_path.open("w", newline="", encoding="utf-8") as f:
|
|
208
|
-
writer = csv.DictWriter(f, fieldnames=self._fieldnames)
|
|
209
|
-
writer.writeheader()
|
|
210
|
-
writer.writerow({k: logs.get(k) for k in self._fieldnames})
|
|
211
|
-
self._initialized = True
|
|
212
|
-
return
|
|
213
|
-
|
|
214
|
-
with file_path.open("a", newline="", encoding="utf-8") as f:
|
|
215
|
-
writer = csv.DictWriter(f, fieldnames=self._fieldnames)
|
|
216
|
-
writer.writerow({k: logs.get(k) for k in self._fieldnames})
|
|
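The monolithic callbacks module is deleted outright, taking `EarlyStopping`, `ModelCheckpoint`, `CSVLogger`, and the `StopTraining` exception with it, so 0.2.3 code importing those names breaks on upgrade (early stopping itself survives inside the trainer via `_should_stop()`). Equivalents can be rebuilt on the three-hook base; a sketch of a CSV logger against the 0.3.0 interface, with the log keys following the trainer's `epoch_logs`:

```python
import csv
from pathlib import Path
from typing import Optional

from torchloop.callbacks import Callback


class SimpleCSVLogger(Callback):
    """Illustrative re-creation of the removed CSVLogger on the 0.3.0 base."""

    def __init__(self, path: str = "logs/metrics.csv") -> None:
        self.path = Path(path)
        self._fieldnames: Optional[list[str]] = None

    def on_epoch_end(self, epoch: int, logs: dict) -> None:
        self.path.parent.mkdir(parents=True, exist_ok=True)
        if self._fieldnames is None:
            # First epoch: fix the column order and write the header row.
            self._fieldnames = sorted(logs)
            with self.path.open("w", newline="", encoding="utf-8") as f:
                writer = csv.DictWriter(f, fieldnames=self._fieldnames)
                writer.writeheader()
                writer.writerow({k: logs.get(k) for k in self._fieldnames})
            return
        with self.path.open("a", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=self._fieldnames)
            writer.writerow({k: logs.get(k) for k in self._fieldnames})
```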
@@ -1,191 +0,0 @@
|
|
|
1
|
-
"""Tests for callback utilities."""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
4
|
-
|
|
5
|
-
import csv
|
|
6
|
-
from pathlib import Path
|
|
7
|
-
|
|
8
|
-
import pytest
|
|
9
|
-
import torch
|
|
10
|
-
import torch.nn as nn
|
|
11
|
-
from torch.utils.data import DataLoader, TensorDataset
|
|
12
|
-
|
|
13
|
-
from torchloop import Trainer
|
|
14
|
-
from torchloop.callbacks import CSVLogger, Callback, EarlyStopping, ModelCheckpoint, StopTraining
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
class TinyNet(nn.Module):
|
|
18
|
-
"""Simple network for callback tests."""
|
|
19
|
-
|
|
20
|
-
def __init__(self) -> None:
|
|
21
|
-
super().__init__()
|
|
22
|
-
self.fc = nn.Linear(10, 2)
|
|
23
|
-
|
|
24
|
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
25
|
-
return self.fc(x)
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
@pytest.fixture
|
|
29
|
-
def loaders() -> tuple[DataLoader, DataLoader]:
|
|
30
|
-
"""Create tiny train and validation loaders."""
|
|
31
|
-
x_train = torch.randn(12, 10)
|
|
32
|
-
y_train = torch.randint(0, 2, (12,))
|
|
33
|
-
x_val = torch.randn(8, 10)
|
|
34
|
-
y_val = torch.randint(0, 2, (8,))
|
|
35
|
-
|
|
36
|
-
train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=4)
|
|
37
|
-
val_loader = DataLoader(TensorDataset(x_val, y_val), batch_size=4)
|
|
38
|
-
return train_loader, val_loader
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
@pytest.fixture
|
|
42
|
-
def trainer_parts() -> tuple[nn.Module, torch.optim.Optimizer, nn.Module]:
|
|
43
|
-
"""Create model, optimizer, and criterion for Trainer setup."""
|
|
44
|
-
model = TinyNet()
|
|
45
|
-
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
|
|
46
|
-
criterion = nn.CrossEntropyLoss()
|
|
47
|
-
return model, optimizer, criterion
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
def test_early_stopping_init_values() -> None:
|
|
51
|
-
"""Validate early stopping initialization fields."""
|
|
52
|
-
callback = EarlyStopping(patience=3, monitor="val_loss", min_delta=0.1)
|
|
53
|
-
assert callback.patience == 3
|
|
54
|
-
assert callback.monitor == "val_loss"
|
|
55
|
-
assert callback.min_delta == 0.1
|
|
56
|
-
assert callback.best == float("inf")
|
|
57
|
-
assert callback.counter == 0
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
def test_early_stopping_validation_errors() -> None:
|
|
61
|
-
"""Validate constructor error handling for invalid config."""
|
|
62
|
-
with pytest.raises(ValueError, match="patience must be >= 1"):
|
|
63
|
-
EarlyStopping(patience=0)
|
|
64
|
-
|
|
65
|
-
with pytest.raises(ValueError, match="min_delta must be >= 0"):
|
|
66
|
-
EarlyStopping(min_delta=-0.1)
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
def test_early_stopping_triggers_after_patience() -> None:
|
|
70
|
-
"""Raise StopTraining when monitored metric does not improve."""
|
|
71
|
-
callback = EarlyStopping(patience=2, monitor="val_loss")
|
|
72
|
-
callback.on_epoch_end(epoch=0, logs={"val_loss": 1.0})
|
|
73
|
-
callback.on_epoch_end(epoch=1, logs={"val_loss": 1.1})
|
|
74
|
-
|
|
75
|
-
with pytest.raises(StopTraining):
|
|
76
|
-
callback.on_epoch_end(epoch=2, logs={"val_loss": 1.2})
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
def test_model_checkpoint_save_best_only(tmp_path: Path) -> None:
|
|
80
|
-
"""Save checkpoint only when monitored metric improves."""
|
|
81
|
-
model = TinyNet()
|
|
82
|
-
path = tmp_path / "best.pt"
|
|
83
|
-
callback = ModelCheckpoint(filepath=str(path), save_best_only=True)
|
|
84
|
-
|
|
85
|
-
callback.on_epoch_end(epoch=0, logs={"model": model, "val_loss": 1.0})
|
|
86
|
-
first_mtime = path.stat().st_mtime
|
|
87
|
-
|
|
88
|
-
callback.on_epoch_end(epoch=1, logs={"model": model, "val_loss": 1.2})
|
|
89
|
-
assert path.stat().st_mtime == first_mtime
|
|
90
|
-
|
|
91
|
-
callback.on_epoch_end(epoch=2, logs={"model": model, "val_loss": 0.8})
|
|
92
|
-
assert path.stat().st_mtime >= first_mtime
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
def test_model_checkpoint_save_all(tmp_path: Path) -> None:
|
|
96
|
-
"""Always save checkpoint when save_best_only is disabled."""
|
|
97
|
-
model = TinyNet()
|
|
98
|
-
path = tmp_path / "all.pt"
|
|
99
|
-
callback = ModelCheckpoint(filepath=str(path), save_best_only=False)
|
|
100
|
-
|
|
101
|
-
callback.on_epoch_end(epoch=0, logs={"model": model, "val_loss": 2.0})
|
|
102
|
-
first_mtime = path.stat().st_mtime
|
|
103
|
-
callback.on_epoch_end(epoch=1, logs={"model": model, "val_loss": 3.0})
|
|
104
|
-
assert path.stat().st_mtime >= first_mtime
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
def test_model_checkpoint_noop_when_logs_missing(tmp_path: Path) -> None:
|
|
108
|
-
"""Do not crash when required keys are missing in logs."""
|
|
109
|
-
path = tmp_path / "noop.pt"
|
|
110
|
-
callback = ModelCheckpoint(filepath=str(path))
|
|
111
|
-
|
|
112
|
-
callback.on_epoch_end(epoch=0, logs={"val_loss": 1.0})
|
|
113
|
-
callback.on_epoch_end(epoch=1, logs={"model": TinyNet()})
|
|
114
|
-
assert not path.exists()
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
def test_csv_logger_writes_header_and_rows(tmp_path: Path) -> None:
|
|
118
|
-
"""Write CSV header once and append per-epoch values."""
|
|
119
|
-
callback = CSVLogger(log_dir=str(tmp_path), filename="metrics.csv")
|
|
120
|
-
callback.on_epoch_end(epoch=0, logs={"loss": 1.0, "val_loss": 1.1})
|
|
121
|
-
callback.on_epoch_end(epoch=1, logs={"loss": 0.9, "val_loss": 1.0})
|
|
122
|
-
|
|
123
|
-
csv_path = tmp_path / "metrics.csv"
|
|
124
|
-
with csv_path.open("r", encoding="utf-8") as f:
|
|
125
|
-
rows = list(csv.reader(f))
|
|
126
|
-
|
|
127
|
-
assert len(rows) == 3
|
|
128
|
-
assert "epoch" in rows[0]
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
def test_trainer_calls_custom_callback_hooks(
|
|
132
|
-
trainer_parts: tuple[nn.Module, torch.optim.Optimizer, nn.Module],
|
|
133
|
-
loaders: tuple[DataLoader, DataLoader],
|
|
134
|
-
) -> None:
|
|
135
|
-
"""Ensure custom callback receives train, epoch, and batch hooks."""
|
|
136
|
-
|
|
137
|
-
class Recorder(Callback):
|
|
138
|
-
def __init__(self) -> None:
|
|
139
|
-
self.events: list[str] = []
|
|
140
|
-
|
|
141
|
-
def on_train_begin(self, logs=None) -> None:
|
|
142
|
-
self.events.append("train_begin")
|
|
143
|
-
|
|
144
|
-
def on_epoch_begin(self, epoch: int, logs=None) -> None:
|
|
145
|
-
self.events.append(f"epoch_begin_{epoch}")
|
|
146
|
-
|
|
147
|
-
def on_batch_end(self, batch: int, logs=None) -> None:
|
|
148
|
-
self.events.append("batch_end")
|
|
149
|
-
|
|
150
|
-
def on_epoch_end(self, epoch: int, logs=None) -> None:
|
|
151
|
-
self.events.append(f"epoch_end_{epoch}")
|
|
152
|
-
|
|
153
|
-
def on_train_end(self, logs=None) -> None:
|
|
154
|
-
self.events.append("train_end")
|
|
155
|
-
|
|
156
|
-
model, optimizer, criterion = trainer_parts
|
|
157
|
-
train_loader, val_loader = loaders
|
|
158
|
-
callback = Recorder()
|
|
159
|
-
trainer = Trainer(model, optimizer, criterion, callbacks=[callback], device="cpu")
|
|
160
|
-
trainer.fit(train_loader, val_loader, epochs=2)
|
|
161
|
-
|
|
162
|
-
assert "train_begin" in callback.events
|
|
163
|
-
assert "train_end" in callback.events
|
|
164
|
-
assert "epoch_begin_1" in callback.events
|
|
165
|
-
assert "epoch_end_2" in callback.events
|
|
166
|
-
assert "batch_end" in callback.events
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
def test_trainer_handles_stop_training_callback(
|
|
170
|
-
trainer_parts: tuple[nn.Module, torch.optim.Optimizer, nn.Module],
|
|
171
|
-
loaders: tuple[DataLoader, DataLoader],
|
|
172
|
-
) -> None:
|
|
173
|
-
"""StopTraining raised by callback should halt training gracefully."""
|
|
174
|
-
|
|
175
|
-
class StopAfterFirstEpoch(Callback):
|
|
176
|
-
def on_epoch_end(self, epoch: int, logs=None) -> None:
|
|
177
|
-
if epoch >= 1:
|
|
178
|
-
raise StopTraining("stop")
|
|
179
|
-
|
|
180
|
-
model, optimizer, criterion = trainer_parts
|
|
181
|
-
train_loader, val_loader = loaders
|
|
182
|
-
trainer = Trainer(
|
|
183
|
-
model,
|
|
184
|
-
optimizer,
|
|
185
|
-
criterion,
|
|
186
|
-
callbacks=[StopAfterFirstEpoch()],
|
|
187
|
-
device="cpu",
|
|
188
|
-
)
|
|
189
|
-
|
|
190
|
-
history = trainer.fit(train_loader, val_loader, epochs=5)
|
|
191
|
-
assert len(history["train_loss"]) == 1
|
|