laco-lightning 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- laco/integrations/lightning/__init__.py +34 -0
- laco/integrations/lightning/_core.py +177 -0
- laco_lightning-1.0.0.dist-info/METADATA +29 -0
- laco_lightning-1.0.0.dist-info/RECORD +7 -0
- laco_lightning-1.0.0.dist-info/WHEEL +5 -0
- laco_lightning-1.0.0.dist-info/entry_points.txt +2 -0
- laco_lightning-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""PyTorch Lightning integration for laco.

Provides:

- :class:`LacoConfigCallback` — logs the resolved laco DictConfig to every
  logger attached to the ``Trainer`` (W&B, TensorBoard, MLflow, CSV, …).
- :func:`trainer` — build a ``lightning.Trainer`` from a laco DictConfig.
- :func:`fit` — convenience wrapper that wires config logging and training
  into a single call.

Examples
--------
::

    import laco
    import laco.integrations.lightning as laco_lightning
    import lightning as L

    cfg = laco.load("configs/train.py")
    model = laco.instantiate(cfg.model)
    dm = MyDataModule()  # placeholder: any LightningDataModule

    trainer = laco_lightning.trainer(cfg.trainer)
    laco_lightning.fit(trainer, model, cfg=cfg, datamodule=dm)

    # Or just attach the callback manually:
    trainer = L.Trainer(callbacks=[laco_lightning.LacoConfigCallback(cfg)])
"""

from laco.integrations.lightning._core import LacoConfigCallback as LacoConfigCallback
from laco.integrations.lightning._core import _bootstrap as _bootstrap
from laco.integrations.lightning._core import fit as fit
from laco.integrations.lightning._core import trainer as trainer

# ``_bootstrap`` is re-exported at package level (presumably so the
# ``laco.plugins`` entry point can target it) but is intentionally absent from
# ``__all__``: it is plugin plumbing, not public API.
__all__ = ["LacoConfigCallback", "fit", "trainer"]
|
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
"""Implementation for laco-lightning."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import typing
|
|
6
|
+
|
|
7
|
+
if typing.TYPE_CHECKING:
|
|
8
|
+
import lightning
|
|
9
|
+
from omegaconf import DictConfig
|
|
10
|
+
|
|
11
|
+
# Module-level cache for the real ``Callback`` subclass behind
# LacoConfigCallback; ``None`` until ``_make_callback_cls`` first runs.
# Building that class requires importing lightning, so construction is deferred
# until first use. Caching it means every call returns the *same* class object,
# so all callback instances share one type.
_CallbackCls: type | None = None
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _bootstrap() -> None:
|
|
17
|
+
"""Register Lightning classes as ``@L.configurable`` targets.
|
|
18
|
+
|
|
19
|
+
Called automatically via the ``laco.plugins`` entry point when
|
|
20
|
+
``laco-lightning`` is installed — no explicit import required.
|
|
21
|
+
Silently does nothing if ``lightning`` is not importable.
|
|
22
|
+
"""
|
|
23
|
+
try:
|
|
24
|
+
import lightning as _L
|
|
25
|
+
from laco.language import configurable
|
|
26
|
+
|
|
27
|
+
configurable(_L.Trainer)
|
|
28
|
+
configurable(_L.LightningModule)
|
|
29
|
+
configurable(_L.LightningDataModule)
|
|
30
|
+
except ImportError:
|
|
31
|
+
pass
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _make_callback_cls() -> type:
    """Return the real Lightning ``Callback`` subclass behind LacoConfigCallback.

    The subclass is built on the first call, which is also when ``lightning``
    and ``omegaconf`` are imported. It is then stored in the module-level
    ``_CallbackCls`` so every later call returns the same class object.
    """
    global _CallbackCls  # noqa: PLW0603
    if _CallbackCls is not None:
        return _CallbackCls

    from lightning.pytorch.callbacks import Callback  # type: ignore[import-untyped]
    from omegaconf import OmegaConf

    class _LacoConfigCallback(Callback):
        """Logs the laco DictConfig to all Trainer loggers on fit start."""

        def __init__(self, cfg: DictConfig) -> None:
            super().__init__()
            self.cfg = cfg

        def __reduce__(self) -> tuple[typing.Any, ...]:
            # This class is created inside a function and renamed below to
            # "LacoConfigCallback", so pickle's by-name lookup resolves to the
            # public factory class in this module, not to this class, and
            # refuses to pickle the instance. Rebuild through the factory
            # instead (it returns an instance of this cached class) and restore
            # the instance __dict__ so no attribute is lost. This matters
            # whenever the callback, or a trainer holding it, must be pickled,
            # e.g. when sent to spawned worker processes.
            return (LacoConfigCallback, (self.cfg,), self.__dict__)

        def on_fit_start(
            self,
            trainer: lightning.Trainer,
            pl_module: lightning.LightningModule,
        ) -> None:
            # Resolve interpolations into plain containers; with
            # ``throw_on_missing=False`` unset mandatory values do not raise.
            flat = OmegaConf.to_container(
                self.cfg, resolve=True, throw_on_missing=False
            )
            for logger in trainer.loggers:
                if hasattr(logger, "log_hyperparams"):
                    logger.log_hyperparams(flat)  # type: ignore[arg-type]

    # Present the class under its public name in reprs and Lightning's
    # callback ``state_key``.
    _LacoConfigCallback.__name__ = "LacoConfigCallback"
    _LacoConfigCallback.__qualname__ = "LacoConfigCallback"
    _CallbackCls = _LacoConfigCallback
    return _CallbackCls
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class _LacoConfigCallbackMeta(type):
    """Metaclass that routes ``isinstance``/``issubclass`` to the real callback class.

    :class:`LacoConfigCallback` is only a factory: calling it returns an
    instance of the lazily-built ``Callback`` subclass cached in
    ``_CallbackCls``. Without this hook, ``isinstance(cb, LacoConfigCallback)``
    would be ``False`` for every object the factory produces.
    """

    def __instancecheck__(cls, instance: object) -> bool:
        # Read the cache directly: a type check must never be what triggers the
        # lightning import. While the cache is empty, no real instance exists.
        real_cls = _CallbackCls
        if real_cls is not None and isinstance(instance, real_cls):
            return True
        return super().__instancecheck__(instance)

    def __subclasscheck__(cls, subclass: type) -> bool:
        real_cls = _CallbackCls
        if (
            real_cls is not None
            and isinstance(subclass, type)
            and issubclass(subclass, real_cls)
        ):
            return True
        return super().__subclasscheck__(subclass)


class LacoConfigCallback(metaclass=_LacoConfigCallbackMeta):
    """Lightning ``Callback`` that logs the laco config to all active loggers.

    Runs on ``on_fit_start`` so the config is recorded before any training
    step, regardless of which loggers are attached to the ``Trainer``.

    This class is a lazy factory. ``LacoConfigCallback(cfg)`` returns an
    instance of a ``lightning.pytorch.callbacks.Callback`` subclass that is
    built, importing ``lightning``, on first use. ``isinstance(obj,
    LacoConfigCallback)`` is ``True`` for those instances.

    Parameters
    ----------
    cfg : DictConfig
        laco DictConfig to log as hyperparameters; interpolations are resolved
        when the config is logged.

    Examples
    --------
    ::

        trainer = L.Trainer(
            max_epochs=10,
            callbacks=[laco_lightning.LacoConfigCallback(cfg)],
        )
    """

    def __new__(cls, cfg: DictConfig) -> LacoConfigCallback:  # type: ignore[misc]
        # Build (or fetch the cached) real Callback subclass and return an
        # instance of it. The result is not a genuine LacoConfigCallback, so
        # Python does not go on to call ``__init__`` below.
        return _make_callback_cls()(cfg)  # type: ignore[return-value]

    def __init__(self, cfg: DictConfig) -> None:
        # Never executed at runtime (see ``__new__``); kept so type checkers and
        # IDEs see the constructor signature and the ``cfg`` attribute.
        self.cfg = cfg  # type: ignore[misc]
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def trainer(
    cfg: DictConfig | None = None,
    **kwargs: typing.Any,
) -> lightning.Trainer:
    """Construct a ``lightning.Trainer`` whose arguments come from a laco config.

    The config, when given, is converted to a plain dict with interpolations
    resolved, and each top-level key becomes a ``Trainer`` keyword argument.
    Explicit *kwargs* are layered on top, so a call-site value replaces the
    config value for the same key.

    Parameters
    ----------
    cfg : DictConfig | None
        Config whose keys are ``Trainer`` constructor parameters
        (``max_epochs``, ``accelerator``, ``devices``, …). ``None`` means
        "use only *kwargs*".
    **kwargs
        Extra ``Trainer`` arguments; these take precedence over *cfg*.

    Returns
    -------
    lightning.Trainer
        The newly constructed trainer.

    Examples
    --------
    ::

        t = laco_lightning.trainer(cfg.trainer_cfg, logger=my_logger)
    """
    import lightning as _L
    from omegaconf import OmegaConf

    # Config values first; call-site overrides are merged last so they win.
    from_cfg: dict[str, typing.Any] = (
        {}
        if cfg is None
        else dict(OmegaConf.to_container(cfg, resolve=True))  # type: ignore[arg-type]
    )
    return _L.Trainer(**{**from_cfg, **kwargs})
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def fit(
    trainer: lightning.Trainer,
    model: lightning.LightningModule,
    *,
    cfg: DictConfig,
    log_config: bool = True,
    **fit_kwargs: typing.Any,
) -> None:
    """Fit a Lightning model, optionally logging the laco config first.

    When *log_config* is ``True`` (the default), a :class:`LacoConfigCallback`
    for *cfg* is appended to ``trainer.callbacks``. If the trainer already
    carries one (for example passed to the ``Trainer`` constructor, or added by
    an earlier ``fit`` call on the same trainer), nothing is added: the
    existing callback's config is logged and the config is not logged twice.
    Then ``trainer.fit(model, **fit_kwargs)`` is called.

    Parameters
    ----------
    trainer : lightning.Trainer
        Configured ``Trainer`` instance.
    model : lightning.LightningModule
        ``LightningModule`` to train.
    cfg : DictConfig
        laco DictConfig, logged to all loggers if *log_config* is ``True``.
    log_config : bool
        Whether to attach :class:`LacoConfigCallback`. Set to ``False`` if
        you handle config logging yourself. Default: ``True``.
    **fit_kwargs
        Extra keyword arguments forwarded to ``trainer.fit()``
        (e.g. ``datamodule``, ``train_dataloaders``, ``val_dataloaders``).

    Examples
    --------
    ::

        laco_lightning.fit(trainer, model, cfg=cfg, datamodule=dm)
    """
    if log_config:
        callback_cls = _make_callback_cls()
        callbacks = getattr(trainer, "callbacks", None)
        if not isinstance(callbacks, list):
            # Defensive path for trainer-like objects without a callback list:
            # start a fresh list, keeping the entries if it was a tuple rather
            # than silently dropping them.
            callbacks = list(callbacks) if isinstance(callbacks, tuple) else []
            trainer.callbacks = callbacks
        # Re-logging the same hyperparameters is redundant, and some loggers
        # may reject a parameter re-logged with a different value.
        if not any(isinstance(cb, callback_cls) for cb in callbacks):
            callbacks.append(callback_cls(cfg))
    trainer.fit(model, **fit_kwargs)
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: laco-lightning
|
|
3
|
+
Version: 1.0.0
|
|
4
|
+
Summary: PyTorch Lightning integration for laco.
|
|
5
|
+
Author-email: Kurt Stolle <kurt@khws.io>
|
|
6
|
+
Requires-Python: >=3.13
|
|
7
|
+
Description-Content-Type: text/markdown
|
|
8
|
+
Requires-Dist: laco>=1.0.0
|
|
9
|
+
Requires-Dist: lightning>=2.0
|
|
10
|
+
|
|
11
|
+
# Laco-Lightning
|
|
12
|
+
|
|
13
|
+
PyTorch Lightning integration for laco.
|
|
14
|
+
|
|
15
|
+
Part of the [laco](https://github.com/khwstolle/laco) project — see the [project README](https://github.com/khwstolle/laco) for an overview.
|
|
16
|
+
|
|
17
|
+
## Installation
|
|
18
|
+
|
|
19
|
+
```bash
|
|
20
|
+
pip install laco-lightning
|
|
21
|
+
```
|
|
22
|
+
|
|
23
|
+
## Features
|
|
24
|
+
|
|
25
|
+
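- `LacoConfigCallback` — logs the resolved laco config to every logger attached to the `Trainer` (W&B, TensorBoard, MLflow, CSV, …).
- `trainer()` — builds a `lightning.Trainer` from a laco `DictConfig`.
- `fit()` — attaches the config-logging callback and runs `trainer.fit()` in a single call.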
|
|
26
|
+
|
|
27
|
+
## Usage
|
|
28
|
+
|
|
29
|
+
See the package's `docs/index.md` in the [laco repository](https://github.com/khwstolle/laco) for the full guide.
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
laco/integrations/lightning/__init__.py,sha256=rmwiss6rzvka4z5JxNyb0mnIyuG8W8eSkDfBwxejjTg,1157
|
|
2
|
+
laco/integrations/lightning/_core.py,sha256=z0SMDdrUmCu-TnS_LRuw7qNtd0FQlBXd5Nyl85W3t-A,5373
|
|
3
|
+
laco_lightning-1.0.0.dist-info/METADATA,sha256=HuFLcRpOLaHLQ65Pfnc2lX-EtWAet-6xBw-Yt7ppPA0,633
|
|
4
|
+
laco_lightning-1.0.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
|
|
5
|
+
laco_lightning-1.0.0.dist-info/entry_points.txt,sha256=owt1zWRxoMjOL-ydLbfOqZKf7tx59cOmWTemxN8mGHU,66
|
|
6
|
+
laco_lightning-1.0.0.dist-info/top_level.txt,sha256=G2kLu09Aje44OkSqu-Tae3mjmTYhyRc2VrTyh3OmxFw,5
|
|
7
|
+
laco_lightning-1.0.0.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
laco
|