lightly-train 0.2.3__tar.gz → 0.2.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {lightly_train-0.2.3 → lightly_train-0.2.4}/.github/workflows/build_docker_image.yml +2 -2
- {lightly_train-0.2.3 → lightly_train-0.2.4}/DOCS.md +26 -5
- {lightly_train-0.2.3 → lightly_train-0.2.4}/Makefile +3 -3
- {lightly_train-0.2.3/src/lightly_train.egg-info → lightly_train-0.2.4}/PKG-INFO +3 -1
- {lightly_train-0.2.3 → lightly_train-0.2.4}/docker/Makefile +2 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/pyproject.toml +3 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/__init__.py +1 -1
- lightly_train-0.2.4/src/lightly_train/_callbacks/callback_args.py +45 -0
- lightly_train-0.2.4/src/lightly_train/_callbacks/callback_helpers.py +75 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_callbacks/checkpoint.py +11 -1
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_cli.py +18 -4
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_commands/export.py +2 -2
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_commands/train.py +33 -1
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_commands/train_helpers.py +5 -46
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_loggers/jsonl.py +12 -3
- lightly_train-0.2.4/src/lightly_train/_loggers/logger_args.py +23 -0
- lightly_train-0.2.4/src/lightly_train/_loggers/logger_helpers.py +89 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_loggers/tensorboard.py +14 -0
- lightly_train-0.2.4/src/lightly_train/_loggers/wandb.py +32 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_methods/densecl.py +6 -6
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_methods/densecldino.py +5 -6
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_methods/dino.py +6 -6
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_methods/method.py +48 -1
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_methods/simclr.py +7 -8
- lightly_train-0.2.4/src/lightly_train/_plot.py +95 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4/src/lightly_train.egg-info}/PKG-INFO +3 -1
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train.egg-info/SOURCES.txt +10 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train.egg-info/requires.txt +3 -0
- lightly_train-0.2.4/tests/_callbacks/test_callback_helpers.py +154 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_commands/test_common_helpers.py +3 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_commands/test_embed.py +28 -5
- {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_commands/test_train.py +52 -3
- {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_commands/test_train_helpers.py +5 -4
- lightly_train-0.2.4/tests/_loggers/test_logger_helpers.py +158 -0
- lightly_train-0.2.4/tests/_optim/__init__.py +7 -0
- lightly_train-0.2.4/tests/test__plot.py +31 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/.github/pull_request_template.md +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/.github/workflows/release_dockerhub.yml +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/.github/workflows/release_pypi.yml +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/.github/workflows/test.yml +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/.github/workflows/test_code_format.yml +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/.github/workflows/test_docker.yml +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/.github/workflows/test_minimal_deps.yml +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/.github/workflows/weekly_dependency_tests.yml +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/.gitignore +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/CONTRIBUTE.md +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/LICENSE +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/README.md +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/dev_tools/licenseheader.tmpl +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/docker/Dockerfile-amd64-cuda +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/docker/README.md +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/setup.cfg +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_callbacks/__init__.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_checkpoint.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_commands/__init__.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_commands/_warnings.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_commands/common_helpers.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_commands/embed.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_commands/extract_video_frames.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_configs/__init__.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_configs/config.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_configs/omegaconf_utils.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_configs/validate.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_constants.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_embedding/__init__.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_embedding/embedding_format.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_embedding/embedding_predictor.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_embedding/embedding_transform.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_embedding/writers/__init__.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_embedding/writers/csv_writer.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_embedding/writers/embedding_writer.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_embedding/writers/torch_writer.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_embedding/writers/writer_helpers.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_loggers/__init__.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_logging.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_methods/__init__.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_methods/method_args.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_methods/method_helpers.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/__init__.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/custom/__init__.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/custom/custom.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/custom/custom_package.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/embedding/__init__.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/embedding/base.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/embedding_model.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/feature_extractor.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/package.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/package_helpers.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/super_gradients/__init__.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/super_gradients/customizable_detector.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/super_gradients/super_gradients.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/super_gradients/super_gradients_package.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/timm/__init__.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/timm/timm.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/timm/timm_package.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/torchvision/__init__.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/torchvision/convnext.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/torchvision/resnet.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/torchvision/torchvision.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_models/torchvision/torchvision_package.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_optim/__init__.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_optim/adamw_args.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_optim/optimizer.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_optim/optimizer_args.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_optim/optimizer_type.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_optim/trainable_modules.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_scaling.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/_transforms.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/errors.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train/types.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train.egg-info/dependency_links.txt +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train.egg-info/entry_points.txt +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/src/lightly_train.egg-info/top_level.txt +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/__init__.py +0 -0
- {lightly_train-0.2.3/tests/_commands → lightly_train-0.2.4/tests/_callbacks}/__init__.py +0 -0
- {lightly_train-0.2.3/tests/_configs → lightly_train-0.2.4/tests/_commands}/__init__.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_commands/test_export.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_commands/test_extract_video_frames.py +0 -0
- {lightly_train-0.2.3/tests/_embedding → lightly_train-0.2.4/tests/_configs}/__init__.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_configs/test_validate.py +0 -0
- {lightly_train-0.2.3/tests/_embedding/writers → lightly_train-0.2.4/tests/_embedding}/__init__.py +0 -0
- {lightly_train-0.2.3/tests/_methods → lightly_train-0.2.4/tests/_embedding/writers}/__init__.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_embedding/writers/test_csv_writer.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_embedding/writers/test_torch_writer.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_embedding/writers/test_writer_helpers.py +0 -0
- {lightly_train-0.2.3/tests/_models → lightly_train-0.2.4/tests/_methods}/__init__.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_methods/test_method_helpers.py +0 -0
- {lightly_train-0.2.3/tests/_models/custom → lightly_train-0.2.4/tests/_models}/__init__.py +0 -0
- {lightly_train-0.2.3/tests/_models/timm → lightly_train-0.2.4/tests/_models/custom}/__init__.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_models/custom/test_custom.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_models/custom/test_custom_package.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_models/super_gradients/test_super_gradients_package.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_models/test_package_helpers.py +0 -0
- {lightly_train-0.2.3/tests/_optim → lightly_train-0.2.4/tests/_models/timm}/__init__.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_models/timm/test_timm.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_models/timm/test_timm_package.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_models/torchvision/test_torchvision_package.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/_optim/test_optimizer.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/helpers.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/test__checkpoint.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/test__cli.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/test__logging.py +0 -0
- {lightly_train-0.2.3 → lightly_train-0.2.4}/tests/test__scaling.py +0 -0
|
@@ -52,7 +52,7 @@ jobs:
|
|
|
52
52
|
if: ${{ needs.build-docker-image.result == 'success' }}
|
|
53
53
|
uses: rtCamp/action-slack-notify@v2
|
|
54
54
|
env:
|
|
55
|
-
SLACK_WEBHOOK: ${{ secrets.
|
|
55
|
+
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK_RELEASES_DEV }}
|
|
56
56
|
SLACK_ICON_EMOJI: ":train:"
|
|
57
57
|
SLACK_USERNAME: "Build of LightlyTrain Docker Image"
|
|
58
58
|
SLACK_COLOR: "good"
|
|
@@ -71,7 +71,7 @@ jobs:
|
|
|
71
71
|
if: ${{ needs.build-docker-image.result != 'success' }}
|
|
72
72
|
uses: rtCamp/action-slack-notify@v2
|
|
73
73
|
env:
|
|
74
|
-
SLACK_WEBHOOK: ${{ secrets.
|
|
74
|
+
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK_RELEASES_DEV }}
|
|
75
75
|
SLACK_ICON_EMOJI: ":x:"
|
|
76
76
|
SLACK_USERNAME: "Build of LightlyTrain Docker Image"
|
|
77
77
|
SLACK_COLOR: "danger"
|
|
@@ -27,12 +27,20 @@ lightly_train.train(
|
|
|
27
27
|
|
|
28
28
|
In most cases you only have to specify `out`, `data`, and `model`. The rest is optional.
|
|
29
29
|
|
|
30
|
-
|
|
31
|
-
`pip install lightly-train[tensorboard]`):
|
|
30
|
+
You can monitor your training process with the help of `tensorboard` and `wandb` loggers:
|
|
32
31
|
```
|
|
33
|
-
tensorboard
|
|
32
|
+
pip install "lightly-train[tensorboard, wandb]"
|
|
34
33
|
```
|
|
35
34
|
|
|
35
|
+
Configure the loggers from Python:
|
|
36
|
+
* `loggers={"tensorboard": True}`: Enable `tensorboard` logger with default arguments.
|
|
37
|
+
* `loggers={"wandb": True}`: Enable `wandb` logger with default arguments.
|
|
38
|
+
* `loggers={"wandb": {"project": "my-project"}}`: Configure `wandb` logger with custom arguments.
|
|
39
|
+
|
|
40
|
+
LightlyTrain uses the PyTorch Lightning loggers under the hood. Learn more about their configuration:
|
|
41
|
+
* https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.wandb.html
|
|
42
|
+
* https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.tensorboard.html
|
|
43
|
+
|
|
36
44
|
#### Exporting
|
|
37
45
|
```python
|
|
38
46
|
import lightly_train
|
|
@@ -94,8 +102,21 @@ lightly-train train \
|
|
|
94
102
|
|
|
95
103
|
In most cases you only have to specify `out`, `data`, and `model`. The rest is optional.
|
|
96
104
|
|
|
97
|
-
|
|
98
|
-
|
|
105
|
+
You can monitor your training process with the help of `tensorboard` and `wandb` loggers:
|
|
106
|
+
```
|
|
107
|
+
pip install "lightly-train[tensorboard, wandb]"
|
|
108
|
+
```
|
|
109
|
+
|
|
110
|
+
Configure the loggers from the command-line:
|
|
111
|
+
* `loggers.tensorboard=True`: Enable `tensorboard` logger with default arguments.
|
|
112
|
+
* `loggers.wandb=True`: Enable `wandb` logger with default arguments.
|
|
113
|
+
* `loggers.wandb.project="my-project"`: Configure `wandb` logger with custom arguments.
|
|
114
|
+
|
|
115
|
+
LightlyTrain uses the PyTorch Lightning loggers under the hood. Learn more about their configuration:
|
|
116
|
+
* https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.wandb.html
|
|
117
|
+
* https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.tensorboard.html
|
|
118
|
+
|
|
119
|
+
|
|
99
120
|
```
|
|
100
121
|
tensorboard --logdir my_output_dir
|
|
101
122
|
```
|
|
@@ -72,7 +72,7 @@ test:
|
|
|
72
72
|
|
|
73
73
|
.PHONY: test-ci
|
|
74
74
|
test-ci:
|
|
75
|
-
pytest tests
|
|
75
|
+
pytest tests -v
|
|
76
76
|
|
|
77
77
|
|
|
78
78
|
### Virtual Environment
|
|
@@ -105,8 +105,8 @@ endif
|
|
|
105
105
|
|
|
106
106
|
# SuperGradients is not compatible with Python>=3.10. It is also not easy to install
|
|
107
107
|
# on MacOS. Therefore we exclude it from the default extras.
|
|
108
|
-
EXTRAS = [dev,tensorboard,timm]
|
|
109
|
-
DOCKER_EXTRAS = --extra tensorboard --extra timm
|
|
108
|
+
EXTRAS = [dev,tensorboard,timm,wandb]
|
|
109
|
+
DOCKER_EXTRAS = --extra tensorboard --extra timm --extra wandb
|
|
110
110
|
|
|
111
111
|
# Date until which dependencies installed with --exclude-newer must have been released.
|
|
112
112
|
# Dependencies released after this date are ignored.
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: lightly_train
|
|
3
|
-
Version: 0.2.3
|
|
3
|
+
Version: 0.2.4
|
|
4
4
|
Summary: Train models with self-supervised learning in a single command
|
|
5
5
|
Author: Lightly Team
|
|
6
6
|
License: AGPL-3.0
|
|
@@ -28,6 +28,8 @@ Provides-Extra: tensorboard
|
|
|
28
28
|
Requires-Dist: tensorboard>=2.10.0; extra == "tensorboard"
|
|
29
29
|
Provides-Extra: timm
|
|
30
30
|
Requires-Dist: timm>=1.0.3; extra == "timm"
|
|
31
|
+
Provides-Extra: wandb
|
|
32
|
+
Requires-Dist: wandb>=0.12.10; extra == "wandb"
|
|
31
33
|
|
|
32
34
|
# LightlyTrain
|
|
33
35
|
|
|
@@ -69,6 +69,7 @@ test:
|
|
|
69
69
|
-v $(LIGHTLY_TRAIN_OUT):/out \
|
|
70
70
|
-v $(LIGHTLY_TRAIN_DATA):/data \
|
|
71
71
|
-v ./Makefile:/home/lightly_train/docker/Makefile \
|
|
72
|
+
--gpus all \
|
|
72
73
|
lightly/$(IMAGE):$(TAG) make test-cli-from-within-docker -C docker
|
|
73
74
|
|
|
74
75
|
# This target is run from within the docker container
|
|
@@ -76,6 +77,7 @@ test-cli-from-within-docker:
|
|
|
76
77
|
@echo "Test train"
|
|
77
78
|
lightly-train train data=/data out=/out model="torchvision/convnext_small" epochs=2 batch_size=2 model_args.weights="IMAGENET1K_V1" devices=2
|
|
78
79
|
test -f /out/checkpoints/last.ckpt
|
|
80
|
+
test `grep -c "GPU available: True (cuda), used: True" /out/train.log` -gt 0
|
|
79
81
|
@echo "Test embed"
|
|
80
82
|
lightly-train embed data=/data out="/out/embeddings.csv" checkpoint="/out/checkpoints/last.ckpt" batch_size=2 format="csv"
|
|
81
83
|
test `wc -l < /out/embeddings.csv` -eq 6
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Copyright (c) Lightly AG and affiliates.
|
|
3
|
+
# All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# This source code is licensed under the license found in the
|
|
6
|
+
# LICENSE file in the root directory of this source tree.
|
|
7
|
+
#
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from typing import Optional
|
|
10
|
+
|
|
11
|
+
from lightly_train._callbacks.checkpoint import ModelCheckpointArgs
|
|
12
|
+
from lightly_train._configs.config import Config
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class LearningRateMonitorArgs(Config):
    """Arguments for the PyTorch Lightning ``LearningRateMonitor`` callback.

    No options are exposed yet. The callback is created whenever an instance
    of this class is present in ``CallbackArgs.learning_rate_monitor``.
    """
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class DeviceStatsMonitorArgs(Config):
    """Arguments for the PyTorch Lightning ``DeviceStatsMonitor`` callback.

    No options are exposed yet. The callback is created whenever an instance
    of this class is present in ``CallbackArgs.device_stats_monitor``.
    """
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class EarlyStoppingArgs(Config):
    """Keyword arguments forwarded to PyTorch Lightning's ``EarlyStopping``."""

    # Name of the logged metric that is watched.
    monitor: str = "train_loss"
    # 10**12 checks (same value as int(1e12)); presumably chosen so that
    # patience never triggers and only the finiteness check can stop training.
    patience: int = 10**12
    # Passed as ``check_finite``; per the Lightning docs this stops training
    # once the monitored metric becomes NaN or infinite.
    check_finite: bool = True
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
class CallbackArgs(Config):
    """Configuration of the callbacks used during training.

    Each field holds the arguments of one callback. By default every field is
    populated with a fresh instance of its argument class, so all callbacks
    are enabled. Setting a field to ``None`` disables that callback.
    """

    # Logs learning rates during training.
    learning_rate_monitor: Optional[LearningRateMonitorArgs] = field(
        default_factory=LearningRateMonitorArgs
    )
    # Logs device statistics during training.
    device_stats_monitor: Optional[DeviceStatsMonitorArgs] = field(
        default_factory=DeviceStatsMonitorArgs
    )
    # Stops training based on the monitored metric.
    early_stopping: Optional[EarlyStoppingArgs] = field(
        default_factory=EarlyStoppingArgs
    )
    # Writes checkpoints to ``<out>/checkpoints``.
    model_checkpoint: Optional[ModelCheckpointArgs] = field(
        default_factory=ModelCheckpointArgs
    )
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Copyright (c) Lightly AG and affiliates.
|
|
3
|
+
# All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# This source code is licensed under the license found in the
|
|
6
|
+
# LICENSE file in the root directory of this source tree.
|
|
7
|
+
#
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import dataclasses
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
from omegaconf import OmegaConf
|
|
15
|
+
from pytorch_lightning import Callback
|
|
16
|
+
from pytorch_lightning.callbacks import (
|
|
17
|
+
DeviceStatsMonitor,
|
|
18
|
+
EarlyStopping,
|
|
19
|
+
LearningRateMonitor,
|
|
20
|
+
)
|
|
21
|
+
from torch.nn import Module
|
|
22
|
+
|
|
23
|
+
from lightly_train._callbacks.callback_args import (
|
|
24
|
+
CallbackArgs,
|
|
25
|
+
)
|
|
26
|
+
from lightly_train._callbacks.checkpoint import ModelCheckpoint
|
|
27
|
+
from lightly_train._checkpoint import CheckpointLightlyTrainModels
|
|
28
|
+
from lightly_train._configs import validate
|
|
29
|
+
from lightly_train._models.embedding_model import EmbeddingModel
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def validate_callback_args(callback_dict: dict[str, Any] | None) -> CallbackArgs:
    """Validate a raw callback configuration and convert it to ``CallbackArgs``.

    Args:
        callback_dict: Mapping from callback name to either ``None`` (disable
            the callback) or a dictionary of callback arguments. ``None`` is
            treated like an empty mapping, i.e. the default callbacks are used.

    Returns:
        The validated ``CallbackArgs`` object.
    """
    raw_config = callback_dict if callback_dict is not None else {}
    checked = validate.validate_dictconfig(OmegaConf.create(raw_config), CallbackArgs)
    result = OmegaConf.to_object(checked)
    # ``to_object`` is typed loosely; narrow it for mypy.
    assert isinstance(result, CallbackArgs)
    return result
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def get_callbacks(
    callback_args: CallbackArgs,
    out: Path,
    model: Module,
    embedding_model: EmbeddingModel,
):
    """Instantiate every callback that is enabled in ``callback_args``.

    A callback is created only if its entry in ``callback_args`` is not
    ``None``. The result always follows the same order: learning-rate monitor,
    device-stats monitor, early stopping, model checkpoint.

    Args:
        callback_args: Validated callback configuration.
        out: Output directory. Checkpoints go to ``out / "checkpoints"``.
        model: Model handed to the checkpoint callback.
        embedding_model: Embedding model handed to the checkpoint callback.

    Returns:
        A ``list[Callback]`` with the instantiated callbacks.
    """
    enabled: list[Callback] = []

    lr_monitor_args = callback_args.learning_rate_monitor
    if lr_monitor_args is not None:
        enabled.append(LearningRateMonitor(**dataclasses.asdict(lr_monitor_args)))

    device_stats_args = callback_args.device_stats_monitor
    if device_stats_args is not None:
        enabled.append(DeviceStatsMonitor(**dataclasses.asdict(device_stats_args)))

    early_stopping_args = callback_args.early_stopping
    if early_stopping_args is not None:
        enabled.append(EarlyStopping(**dataclasses.asdict(early_stopping_args)))

    checkpoint_args = callback_args.model_checkpoint
    if checkpoint_args is not None:
        checkpoint_models = CheckpointLightlyTrainModels(
            model=model, embedding_model=embedding_model
        )
        enabled.append(
            ModelCheckpoint(
                models=checkpoint_models,
                dirpath=out / "checkpoints",
                **dataclasses.asdict(checkpoint_args),
            )
        )

    return enabled
|
|
@@ -8,8 +8,9 @@
|
|
|
8
8
|
from __future__ import annotations
|
|
9
9
|
|
|
10
10
|
import logging
|
|
11
|
+
from dataclasses import dataclass
|
|
11
12
|
from datetime import timedelta
|
|
12
|
-
from typing import Any
|
|
13
|
+
from typing import Any, Optional
|
|
13
14
|
|
|
14
15
|
from pytorch_lightning import LightningModule, Trainer
|
|
15
16
|
from pytorch_lightning.callbacks import ModelCheckpoint as _ModelCheckpoint
|
|
@@ -19,11 +20,20 @@ from lightly_train._checkpoint import (
|
|
|
19
20
|
CheckpointLightlyTrain,
|
|
20
21
|
CheckpointLightlyTrainModels,
|
|
21
22
|
)
|
|
23
|
+
from lightly_train._configs.config import Config
|
|
22
24
|
from lightly_train.types import PathLike
|
|
23
25
|
|
|
24
26
|
logger = logging.getLogger(__name__)
|
|
25
27
|
|
|
26
28
|
|
|
29
|
+
@dataclass
class ModelCheckpointArgs(Config):
    """User-configurable checkpointing options.

    The fields are forwarded as keyword arguments when the checkpoint
    callback is constructed.
    """

    # Also keep a "last" checkpoint.
    save_last: bool = True
    # Per the Lightning docs, False overwrites existing files instead of
    # creating versioned file names.
    enable_version_counter: bool = False
    save_top_k: int = 1
    # None leaves the epoch interval unset (Lightning decides).
    every_n_epochs: Optional[int] = None
|
|
35
|
+
|
|
36
|
+
|
|
27
37
|
class ModelCheckpoint(_ModelCheckpoint):
|
|
28
38
|
def __init__(
|
|
29
39
|
self,
|
|
@@ -34,7 +34,7 @@ _HELP_MSG = """
|
|
|
34
34
|
lightly-train extract_video_frames Extract frames from videos using ffmpeg.
|
|
35
35
|
lightly-train help Show help message.
|
|
36
36
|
|
|
37
|
-
Run `
|
|
37
|
+
Run `lightly-train <command> help` for more information on a specific command.
|
|
38
38
|
|
|
39
39
|
Optional arguments:
|
|
40
40
|
-v, --verbose Run the command in verbose mode for detailed output.
|
|
@@ -105,6 +105,20 @@ _TRAIN_HELP_MSG = f"""
|
|
|
105
105
|
For details, see: https://lightning.ai/docs/pytorch/stable/common/trainer.html#precision
|
|
106
106
|
seed (int):
|
|
107
107
|
Random seed for reproducibility. Default: {_train_cfg.seed}
|
|
108
|
+
loggers (dict):
|
|
109
|
+
Loggers for training. Either null or a dictionary of logger names to either
|
|
110
|
+
null or a dictionary of logger arguments. Null uses the default loggers.
|
|
111
|
+
To disable a logger, set it to null: `loggers.tensorboard=null`.
|
|
112
|
+
To configure a logger, pass the respective arguments:
|
|
113
|
+
`loggers.wandb.project="my_project"`.
|
|
114
|
+
Default: null
|
|
115
|
+
callbacks (dict):
|
|
116
|
+
Callbacks for training. Either null or a dictionary of callback names to
|
|
117
|
+
either null or a dictionary of callback arguments. Null uses the default
|
|
118
|
+
callbacks. To disable a callback, set it to null:
|
|
119
|
+
`callbacks.model_checkpoint=null`. To configure a callback, pass the
|
|
120
|
+
respective arguments: `callbacks.model_checkpoint.every_n_epochs=5`.
|
|
121
|
+
Default: null
|
|
108
122
|
optim_args (dict):
|
|
109
123
|
Arguments for AdamW optimizer. Available arguments are:
|
|
110
124
|
- lr (float)
|
|
@@ -159,10 +173,10 @@ _EXPORT_HELP_MSG = """
|
|
|
159
173
|
Format to save the model in. Valid options are 'torch_model' and
|
|
160
174
|
'torch_state_dict'. 'torch_model' saves the model as a torch module which
|
|
161
175
|
can be loaded with `model = torch.load(out)`. This requires that the same
|
|
162
|
-
|
|
176
|
+
LightlyTrain version is installed when the model is exported and when it is
|
|
163
177
|
loaded again. 'torch_state_dict' saves the model's state dict which can be
|
|
164
178
|
loaded with `model.load_state_dict(torch.load(out))`. This is more flexible
|
|
165
|
-
and can be used to load the model with different
|
|
179
|
+
and can be used to load the model with different LightlyTrain versions but
|
|
166
180
|
requires the model to already be instantiated.
|
|
167
181
|
|
|
168
182
|
Optional arguments:
|
|
@@ -399,7 +413,7 @@ def _show_invalid_command_help(command: str) -> None:
|
|
|
399
413
|
msg = _format_msg(
|
|
400
414
|
f"""
|
|
401
415
|
Unknown command '{command}':
|
|
402
|
-
|
|
416
|
+
lightly-train {command}
|
|
403
417
|
"""
|
|
404
418
|
)
|
|
405
419
|
msg += "\n"
|
|
@@ -52,10 +52,10 @@ def export(
|
|
|
52
52
|
Format to save the model in. Valid options are 'torch_model' and
|
|
53
53
|
'torch_state_dict'. 'torch_model' saves the model as a torch module which
|
|
54
54
|
can be loaded with `model = torch.load(out)`. This requires that the same
|
|
55
|
-
|
|
55
|
+
LightlyTrain version is installed when the model is exported and when it is
|
|
56
56
|
loaded again. 'torch_state_dict' saves the model's state dict which can be
|
|
57
57
|
loaded with `model.load_state_dict(torch.load(out))`. This is more flexible
|
|
58
|
-
and can be used to load the model with different
|
|
58
|
+
and can be used to load the model with different LightlyTrain versions but
|
|
59
59
|
requires the model to already be instantiated.
|
|
60
60
|
overwrite:
|
|
61
61
|
Overwrite the output file if it already exists.
|
|
@@ -21,7 +21,9 @@ from pytorch_lightning.trainer.connectors.accelerator_connector import _PRECISIO
|
|
|
21
21
|
from torch.nn import Module
|
|
22
22
|
from torch.utils.data import Dataset
|
|
23
23
|
|
|
24
|
+
import lightly_train._loggers.logger_helpers
|
|
24
25
|
from lightly_train import _logging
|
|
26
|
+
from lightly_train._callbacks import callback_helpers
|
|
25
27
|
from lightly_train._commands import _warnings, common_helpers, train_helpers
|
|
26
28
|
from lightly_train._configs import omegaconf_utils, validate
|
|
27
29
|
from lightly_train._configs.config import Config
|
|
@@ -49,6 +51,8 @@ def train(
|
|
|
49
51
|
strategy: str | Strategy = "auto",
|
|
50
52
|
precision: _PRECISION_INPUT = "32-true", # Default precision in PyTorch Lightning
|
|
51
53
|
seed: int = 0,
|
|
54
|
+
loggers: dict[str, dict[str, Any] | None] | None = None,
|
|
55
|
+
callbacks: dict[str, dict[str, Any] | None] | None = None,
|
|
52
56
|
optim_args: dict[str, Any] | None = None,
|
|
53
57
|
transform_args: dict[str, Any] | None = None,
|
|
54
58
|
loader_args: dict[str, Any] | None = None,
|
|
@@ -111,6 +115,19 @@ def train(
|
|
|
111
115
|
https://lightning.ai/docs/pytorch/stable/common/trainer.html#precision
|
|
112
116
|
seed:
|
|
113
117
|
Random seed for reproducibility.
|
|
118
|
+
loggers:
|
|
119
|
+
Loggers for training. Either None or a dictionary of logger names to either
|
|
120
|
+
None or a dictionary of logger arguments. None uses the default loggers.
|
|
121
|
+
To disable a logger, set it to None: `loggers={"tensorboard": None}`.
|
|
122
|
+
To configure a logger, pass the respective arguments:
|
|
123
|
+
`loggers={"wandb": {"project": "my_project"}}`.
|
|
124
|
+
callbacks:
|
|
125
|
+
Callbacks for training. Either None or a dictionary of callback names to
|
|
126
|
+
either None or a dictionary of callback arguments. None uses the default
|
|
127
|
+
callbacks. To disable a callback, set it to None:
|
|
128
|
+
`callbacks={"model_checkpoint": None}`. To configure a callback, pass the
|
|
129
|
+
respective arguments:
|
|
130
|
+
`callbacks={"model_checkpoint": {"every_n_epochs": 5}}`.
|
|
114
131
|
optim_args:
|
|
115
132
|
Arguments for AdamW optimizer. Available arguments are:
|
|
116
133
|
- lr: float
|
|
@@ -160,10 +177,21 @@ def train(
|
|
|
160
177
|
log_every_n_steps = train_helpers.get_lightning_logging_interval(
|
|
161
178
|
dataset_size=scaling_info.dataset_size, batch_size=batch_size
|
|
162
179
|
)
|
|
163
|
-
|
|
180
|
+
logger_args = lightly_train._loggers.logger_helpers.validate_logger_args(
|
|
181
|
+
loggers=loggers
|
|
182
|
+
)
|
|
183
|
+
logger_instances = lightly_train._loggers.logger_helpers.get_loggers(
|
|
184
|
+
logger_args=logger_args, out=out_dir
|
|
185
|
+
)
|
|
186
|
+
callback_args = callback_helpers.validate_callback_args(callback_dict=callbacks)
|
|
187
|
+
callback_instances = callback_helpers.get_callbacks(
|
|
188
|
+
callback_args=callback_args,
|
|
164
189
|
out=out_dir,
|
|
165
190
|
model=model_instance,
|
|
166
191
|
embedding_model=embedding_model,
|
|
192
|
+
)
|
|
193
|
+
trainer_instance = train_helpers.get_trainer(
|
|
194
|
+
out=out_dir,
|
|
167
195
|
epochs=epochs,
|
|
168
196
|
accelerator=accelerator,
|
|
169
197
|
strategy=strategy,
|
|
@@ -171,6 +199,8 @@ def train(
|
|
|
171
199
|
num_nodes=num_nodes,
|
|
172
200
|
precision=precision,
|
|
173
201
|
log_every_n_steps=log_every_n_steps,
|
|
202
|
+
loggers=logger_instances,
|
|
203
|
+
callbacks=callback_instances,
|
|
174
204
|
trainer_args=trainer_args,
|
|
175
205
|
)
|
|
176
206
|
dataloader = train_helpers.get_dataloader(
|
|
@@ -228,6 +258,8 @@ class TrainConfig(Config):
|
|
|
228
258
|
strategy: str = "auto"
|
|
229
259
|
precision: str = "32-true"
|
|
230
260
|
seed: int = 0
|
|
261
|
+
loggers: Optional[Dict[str, Optional[Dict[str, Any]]]] = None
|
|
262
|
+
callbacks: Optional[Dict[str, Optional[Dict[str, Any]]]] = None
|
|
231
263
|
optim_args: Optional[Dict[str, Any]] = None
|
|
232
264
|
transform_args: Optional[Dict[str, Any]] = None
|
|
233
265
|
loader_args: Optional[Dict[str, Any]] = None
|
|
@@ -7,32 +7,24 @@
|
|
|
7
7
|
#
|
|
8
8
|
from __future__ import annotations
|
|
9
9
|
|
|
10
|
+
import logging
|
|
10
11
|
from pathlib import Path
|
|
11
12
|
from typing import Any, Callable, Sized, Type
|
|
12
13
|
|
|
13
14
|
from lightly.data import LightlyDataset
|
|
14
|
-
from pytorch_lightning import Trainer
|
|
15
|
+
from pytorch_lightning import Callback, Trainer
|
|
15
16
|
from pytorch_lightning.accelerators.accelerator import Accelerator
|
|
16
17
|
from pytorch_lightning.accelerators.cpu import CPUAccelerator
|
|
17
18
|
from pytorch_lightning.accelerators.cuda import CUDAAccelerator
|
|
18
|
-
from pytorch_lightning.callbacks import (
|
|
19
|
-
DeviceStatsMonitor,
|
|
20
|
-
EarlyStopping,
|
|
21
|
-
LearningRateMonitor,
|
|
22
|
-
)
|
|
23
19
|
from pytorch_lightning.loggers import Logger
|
|
24
20
|
from pytorch_lightning.strategies.strategy import Strategy
|
|
25
21
|
from pytorch_lightning.trainer.connectors.accelerator_connector import _PRECISION_INPUT
|
|
26
22
|
from torch.nn import Module
|
|
27
23
|
from torch.utils.data import DataLoader, Dataset
|
|
28
24
|
|
|
29
|
-
from lightly_train._callbacks.checkpoint import ModelCheckpoint
|
|
30
|
-
from lightly_train._checkpoint import CheckpointLightlyTrainModels
|
|
31
25
|
from lightly_train._commands import common_helpers
|
|
32
26
|
from lightly_train._configs import validate
|
|
33
27
|
from lightly_train._constants import DATALOADER_TIMEOUT
|
|
34
|
-
from lightly_train._loggers.jsonl import JSONLLogger
|
|
35
|
-
from lightly_train._loggers.tensorboard import TensorBoardLogger
|
|
36
28
|
from lightly_train._methods import method_helpers
|
|
37
29
|
from lightly_train._methods.method import Method
|
|
38
30
|
from lightly_train._models import package_helpers
|
|
@@ -42,18 +34,6 @@ from lightly_train._optim.optimizer_type import OptimizerType
|
|
|
42
34
|
from lightly_train._scaling import IMAGENET_SIZE, ScalingInfo
|
|
43
35
|
from lightly_train.types import PathLike, Transform
|
|
44
36
|
|
|
45
|
-
try:
|
|
46
|
-
import timm
|
|
47
|
-
except ImportError:
|
|
48
|
-
timm = None
|
|
49
|
-
|
|
50
|
-
try:
|
|
51
|
-
import tensorboard
|
|
52
|
-
except ImportError:
|
|
53
|
-
tensorboard = None
|
|
54
|
-
|
|
55
|
-
import logging
|
|
56
|
-
|
|
57
37
|
logger = logging.getLogger(__name__)
|
|
58
38
|
|
|
59
39
|
|
|
@@ -166,8 +146,6 @@ def get_embedding_model(model: Module, embed_dim: int | None = None) -> Embeddin
|
|
|
166
146
|
|
|
167
147
|
def get_trainer(
|
|
168
148
|
out: Path,
|
|
169
|
-
model: Module,
|
|
170
|
-
embedding_model: EmbeddingModel,
|
|
171
149
|
epochs: int,
|
|
172
150
|
accelerator: str | Accelerator,
|
|
173
151
|
strategy: str | Strategy,
|
|
@@ -175,17 +153,11 @@ def get_trainer(
|
|
|
175
153
|
num_nodes: int,
|
|
176
154
|
log_every_n_steps: int,
|
|
177
155
|
precision: _PRECISION_INPUT | None,
|
|
156
|
+
loggers: list[Logger],
|
|
157
|
+
callbacks: list[Callback],
|
|
178
158
|
trainer_args: dict[str, Any] | None,
|
|
179
159
|
) -> Trainer:
|
|
180
160
|
logger.debug("Getting trainer.")
|
|
181
|
-
# Set version and name to empty string to save logs directly in the root
|
|
182
|
-
# directory.
|
|
183
|
-
loggers: list[Logger] = [
|
|
184
|
-
JSONLLogger(save_dir=out, name="", version=""),
|
|
185
|
-
]
|
|
186
|
-
if tensorboard is not None:
|
|
187
|
-
loggers.append(TensorBoardLogger(save_dir=out, name="", version=""))
|
|
188
|
-
logger.debug(f"Using loggers {[log.__class__.__name__ for log in loggers]}.")
|
|
189
161
|
|
|
190
162
|
accelerator = common_helpers.get_accelerator(accelerator=accelerator)
|
|
191
163
|
strategy = get_strategy(accelerator=accelerator, strategy=strategy, devices=devices)
|
|
@@ -200,20 +172,7 @@ def get_trainer(
|
|
|
200
172
|
num_nodes=num_nodes,
|
|
201
173
|
precision=precision,
|
|
202
174
|
log_every_n_steps=log_every_n_steps,
|
|
203
|
-
callbacks=
|
|
204
|
-
LearningRateMonitor(),
|
|
205
|
-
DeviceStatsMonitor(),
|
|
206
|
-
ModelCheckpoint(
|
|
207
|
-
save_last=True,
|
|
208
|
-
models=CheckpointLightlyTrainModels(
|
|
209
|
-
model=model, embedding_model=embedding_model
|
|
210
|
-
),
|
|
211
|
-
dirpath=out / "checkpoints",
|
|
212
|
-
enable_version_counter=False,
|
|
213
|
-
),
|
|
214
|
-
# Stop if training loss diverges.
|
|
215
|
-
EarlyStopping(monitor="train_loss", patience=int(1e12), check_finite=True),
|
|
216
|
-
],
|
|
175
|
+
callbacks=callbacks,
|
|
217
176
|
logger=loggers,
|
|
218
177
|
sync_batchnorm=sync_batchnorm,
|
|
219
178
|
)
|
|
@@ -5,9 +5,12 @@
|
|
|
5
5
|
# This source code is licensed under the license found in the
|
|
6
6
|
# LICENSE file in the root directory of this source tree.
|
|
7
7
|
#
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
8
10
|
import json
|
|
9
11
|
import logging
|
|
10
12
|
import os
|
|
13
|
+
from dataclasses import dataclass
|
|
11
14
|
from typing import Dict, List, Optional, Union
|
|
12
15
|
|
|
13
16
|
from lightning_fabric.loggers.logger import rank_zero_experiment
|
|
@@ -15,6 +18,7 @@ from lightning_fabric.utilities.rank_zero import rank_zero_warn
|
|
|
15
18
|
from pytorch_lightning.loggers import CSVLogger
|
|
16
19
|
from pytorch_lightning.loggers.csv_logs import ExperimentWriter as CSVExperimentWriter
|
|
17
20
|
|
|
21
|
+
from lightly_train._configs.config import Config
|
|
18
22
|
from lightly_train.types import PathLike
|
|
19
23
|
|
|
20
24
|
log = logging.getLogger(__name__)
|
|
@@ -55,6 +59,11 @@ class ExperimentWriter(CSVExperimentWriter):
|
|
|
55
59
|
)
|
|
56
60
|
|
|
57
61
|
|
|
62
|
+
@dataclass
|
|
63
|
+
class JSONLLoggerArgs(Config):
|
|
64
|
+
flush_logs_every_n_steps: int = 100
|
|
65
|
+
|
|
66
|
+
|
|
58
67
|
class JSONLLogger(CSVLogger):
|
|
59
68
|
"""Log to local file system in JSON Lines format.
|
|
60
69
|
|
|
@@ -69,7 +78,7 @@ class JSONLLogger(CSVLogger):
|
|
|
69
78
|
|
|
70
79
|
Args:
|
|
71
80
|
save_dir: Save directory
|
|
72
|
-
name: Experiment name.
|
|
81
|
+
name: Experiment name.
|
|
73
82
|
version:
|
|
74
83
|
Experiment version. If version is not specified the logger inspects the save
|
|
75
84
|
directory for existing versions, then automatically assigns the next
|
|
@@ -83,8 +92,8 @@ class JSONLLogger(CSVLogger):
|
|
|
83
92
|
def __init__(
|
|
84
93
|
self,
|
|
85
94
|
save_dir: PathLike,
|
|
86
|
-
name: str = "
|
|
87
|
-
version: Optional[Union[int, str]] =
|
|
95
|
+
name: str = "",
|
|
96
|
+
version: Optional[Union[int, str]] = "",
|
|
88
97
|
prefix: str = "",
|
|
89
98
|
flush_logs_every_n_steps: int = 100,
|
|
90
99
|
):
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Copyright (c) Lightly AG and affiliates.
|
|
3
|
+
# All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# This source code is licensed under the license found in the
|
|
6
|
+
# LICENSE file in the root directory of this source tree.
|
|
7
|
+
#
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from typing import Optional
|
|
10
|
+
|
|
11
|
+
from lightly_train._configs.config import Config
|
|
12
|
+
from lightly_train._loggers.jsonl import JSONLLoggerArgs
|
|
13
|
+
from lightly_train._loggers.tensorboard import TensorBoardLoggerArgs
|
|
14
|
+
from lightly_train._loggers.wandb import WandbLoggerArgs
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
|
|
18
|
+
class LoggerArgs(Config):
|
|
19
|
+
jsonl: Optional[JSONLLoggerArgs] = field(default_factory=JSONLLoggerArgs)
|
|
20
|
+
tensorboard: Optional[TensorBoardLoggerArgs] = field(
|
|
21
|
+
default_factory=TensorBoardLoggerArgs
|
|
22
|
+
)
|
|
23
|
+
wandb: Optional[WandbLoggerArgs] = None
|